diff --git a/instrumentation/executors/bootstrap/src/main/java/io/opentelemetry/javaagent/bootstrap/executors/ContextPropagatingRunnable.java b/instrumentation/executors/bootstrap/src/main/java/io/opentelemetry/javaagent/bootstrap/executors/ContextPropagatingRunnable.java new file mode 100644 index 000000000000..f3e4846c9adb --- /dev/null +++ b/instrumentation/executors/bootstrap/src/main/java/io/opentelemetry/javaagent/bootstrap/executors/ContextPropagatingRunnable.java @@ -0,0 +1,39 @@ +/* + * Copyright The OpenTelemetry Authors + * SPDX-License-Identifier: Apache-2.0 + */ + +package io.opentelemetry.javaagent.bootstrap.executors; + +import io.opentelemetry.context.Context; +import io.opentelemetry.context.Scope; + +public final class ContextPropagatingRunnable implements Runnable { + + public static boolean shouldDecorateRunnable(Runnable task) { + // We wrap only lambdas' anonymous classes and if given object has not already been wrapped. + // Anonymous classes have '/' in class name which is not allowed in 'normal' classes. + // note: it is always safe to decorate lambdas since downstream code cannot be expecting a + // specific runnable implementation anyways + return task.getClass().getName().contains("/") && !(task instanceof ContextPropagatingRunnable); + } + + public static Runnable propagateContext(Runnable task, Context context) { + return new ContextPropagatingRunnable(task, context); + } + + private final Runnable delegate; + private final Context context; + + private ContextPropagatingRunnable(Runnable delegate, Context context) { + this.delegate = delegate; + this.context = context; + } + + @Override + public void run() { + try (Scope ignored = context.makeCurrent()) { + delegate.run(); + } + } +} diff --git a/instrumentation/executors/javaagent/src/main/java/io/opentelemetry/javaagent/instrumentation/executors/JavaExecutorInstrumentation.java b/instrumentation/executors/javaagent/src/main/java/io/opentelemetry/javaagent/instrumentation/executors/JavaExecutorInstrumentation.java index 07949cbe16b1..a1835d71948e 100644 --- a/instrumentation/executors/javaagent/src/main/java/io/opentelemetry/javaagent/instrumentation/executors/JavaExecutorInstrumentation.java +++ b/instrumentation/executors/javaagent/src/main/java/io/opentelemetry/javaagent/instrumentation/executors/JavaExecutorInstrumentation.java @@ -16,6 +16,7 @@ import io.opentelemetry.context.Context; import io.opentelemetry.instrumentation.api.util.VirtualField; import io.opentelemetry.javaagent.bootstrap.Java8BytecodeBridge; +import io.opentelemetry.javaagent.bootstrap.executors.ContextPropagatingRunnable; import io.opentelemetry.javaagent.bootstrap.executors.ExecutorAdviceHelper; import io.opentelemetry.javaagent.bootstrap.executors.PropagatedContext; import io.opentelemetry.javaagent.extension.instrumentation.TypeTransformer; @@ -80,12 +81,16 @@ public static class SetExecuteRunnableStateAdvice { public static PropagatedContext enterJobSubmit( @Advice.Argument(value = 0, readOnly = false) Runnable task) { Context context = Java8BytecodeBridge.currentContext(); - if (ExecutorAdviceHelper.shouldPropagateContext(context, task)) { - VirtualField virtualField = - VirtualField.find(Runnable.class, PropagatedContext.class); - return ExecutorAdviceHelper.attachContextToTask(context, virtualField, task); + if (!ExecutorAdviceHelper.shouldPropagateContext(context, task)) { + return null; } - return null; + if (ContextPropagatingRunnable.shouldDecorateRunnable(task)) { + task = ContextPropagatingRunnable.propagateContext(task, context); + return null; + } + VirtualField virtualField = + VirtualField.find(Runnable.class, PropagatedContext.class); + return ExecutorAdviceHelper.attachContextToTask(context, virtualField, task); } @Advice.OnMethodExit(onThrowable = Throwable.class, suppress = Throwable.class) diff --git a/instrumentation/executors/javaagent/src/test/java/io/opentelemetry/javaagent/instrumentation/executors/LambdaContextPropagationTest.java b/instrumentation/executors/javaagent/src/test/java/io/opentelemetry/javaagent/instrumentation/executors/LambdaContextPropagationTest.java new file mode 100644 index 000000000000..ad46803d30ab --- /dev/null +++ b/instrumentation/executors/javaagent/src/test/java/io/opentelemetry/javaagent/instrumentation/executors/LambdaContextPropagationTest.java @@ -0,0 +1,44 @@ +/* + * Copyright The OpenTelemetry Authors + * SPDX-License-Identifier: Apache-2.0 + */ + +package io.opentelemetry.javaagent.instrumentation.executors; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.opentelemetry.api.baggage.Baggage; +import io.opentelemetry.context.Scope; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.Test; + +// regression test for #9175 +class LambdaContextPropagationTest { + + // must be static! the lambda that uses that must be non-capturing + private static final AtomicInteger failureCounter = new AtomicInteger(); + + @Test + void shouldCorrectlyPropagateContextToRunnables() { + ExecutorService executor = Executors.newSingleThreadExecutor(); + + Baggage baggage = Baggage.builder().put("test", "test").build(); + try (Scope ignored = baggage.makeCurrent()) { + for (int i = 0; i < 20; i++) { + // must text execute() -- other methods like submit() decorate the Runnable with a + // FutureTask + executor.execute(LambdaContextPropagationTest::assertBaggage); + } + } + + assertThat(failureCounter).hasValue(0); + } + + private static void assertBaggage() { + if (Baggage.current().getEntryValue("test") == null) { + failureCounter.incrementAndGet(); + } + } +}