Skip to content

Commit

Permalink
Fix context propagation in Executor#execute() for non-capturing lambd…
Browse files Browse the repository at this point in the history
…as (#9179)

Co-authored-by: Trask Stalnaker <[email protected]>
  • Loading branch information
Mateusz Rzeszutek and trask authored Aug 11, 2023
1 parent 54a2a6a commit 32c5d4c
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/

package io.opentelemetry.javaagent.bootstrap.executors;

import io.opentelemetry.context.Context;
import io.opentelemetry.context.Scope;

public final class ContextPropagatingRunnable implements Runnable {

public static boolean shouldDecorateRunnable(Runnable task) {
// We wrap only lambdas' anonymous classes and if given object has not already been wrapped.
// Anonymous classes have '/' in class name which is not allowed in 'normal' classes.
// note: it is always safe to decorate lambdas since downstream code cannot be expecting a
// specific runnable implementation anyways
return task.getClass().getName().contains("/") && !(task instanceof ContextPropagatingRunnable);
}

public static Runnable propagateContext(Runnable task, Context context) {
return new ContextPropagatingRunnable(task, context);
}

private final Runnable delegate;
private final Context context;

private ContextPropagatingRunnable(Runnable delegate, Context context) {
this.delegate = delegate;
this.context = context;
}

@Override
public void run() {
try (Scope ignored = context.makeCurrent()) {
delegate.run();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import io.opentelemetry.context.Context;
import io.opentelemetry.instrumentation.api.util.VirtualField;
import io.opentelemetry.javaagent.bootstrap.Java8BytecodeBridge;
import io.opentelemetry.javaagent.bootstrap.executors.ContextPropagatingRunnable;
import io.opentelemetry.javaagent.bootstrap.executors.ExecutorAdviceHelper;
import io.opentelemetry.javaagent.bootstrap.executors.PropagatedContext;
import io.opentelemetry.javaagent.extension.instrumentation.TypeTransformer;
Expand Down Expand Up @@ -80,12 +81,16 @@ public static class SetExecuteRunnableStateAdvice {
public static PropagatedContext enterJobSubmit(
@Advice.Argument(value = 0, readOnly = false) Runnable task) {
Context context = Java8BytecodeBridge.currentContext();
if (ExecutorAdviceHelper.shouldPropagateContext(context, task)) {
VirtualField<Runnable, PropagatedContext> virtualField =
VirtualField.find(Runnable.class, PropagatedContext.class);
return ExecutorAdviceHelper.attachContextToTask(context, virtualField, task);
if (!ExecutorAdviceHelper.shouldPropagateContext(context, task)) {
return null;
}
return null;
if (ContextPropagatingRunnable.shouldDecorateRunnable(task)) {
task = ContextPropagatingRunnable.propagateContext(task, context);
return null;
}
VirtualField<Runnable, PropagatedContext> virtualField =
VirtualField.find(Runnable.class, PropagatedContext.class);
return ExecutorAdviceHelper.attachContextToTask(context, virtualField, task);
}

@Advice.OnMethodExit(onThrowable = Throwable.class, suppress = Throwable.class)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/

package io.opentelemetry.javaagent.instrumentation.executors;

import static org.assertj.core.api.Assertions.assertThat;

import io.opentelemetry.api.baggage.Baggage;
import io.opentelemetry.context.Scope;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.jupiter.api.Test;

// regression test for #9175
class LambdaContextPropagationTest {

// must be static! the lambda that uses that must be non-capturing
private static final AtomicInteger failureCounter = new AtomicInteger();

@Test
void shouldCorrectlyPropagateContextToRunnables() {
ExecutorService executor = Executors.newSingleThreadExecutor();

Baggage baggage = Baggage.builder().put("test", "test").build();
try (Scope ignored = baggage.makeCurrent()) {
for (int i = 0; i < 20; i++) {
// must text execute() -- other methods like submit() decorate the Runnable with a
// FutureTask
executor.execute(LambdaContextPropagationTest::assertBaggage);
}
}

assertThat(failureCounter).hasValue(0);
}

private static void assertBaggage() {
if (Baggage.current().getEntryValue("test") == null) {
failureCounter.incrementAndGet();
}
}
}

0 comments on commit 32c5d4c

Please sign in to comment.