Skip to content

Commit

Permalink
Spring scheduling: run error handler with the same context as task (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
laurit authored Apr 6, 2023
1 parent fbdd611 commit 7c9cf7a
Show file tree
Hide file tree
Showing 8 changed files with 243 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/

package io.opentelemetry.javaagent.instrumentation.spring.scheduling.v3_1;

import static net.bytebuddy.matcher.ElementMatchers.isConstructor;
import static net.bytebuddy.matcher.ElementMatchers.isPublic;
import static net.bytebuddy.matcher.ElementMatchers.named;
import static net.bytebuddy.matcher.ElementMatchers.takesArgument;

import io.opentelemetry.context.Context;
import io.opentelemetry.context.Scope;
import io.opentelemetry.javaagent.bootstrap.Java8BytecodeBridge;
import io.opentelemetry.javaagent.extension.instrumentation.TypeInstrumentation;
import io.opentelemetry.javaagent.extension.instrumentation.TypeTransformer;
import net.bytebuddy.asm.Advice;
import net.bytebuddy.description.type.TypeDescription;
import net.bytebuddy.matcher.ElementMatcher;
import org.springframework.util.ErrorHandler;

public class DelegatingErrorHandlingRunnableInstrumentation implements TypeInstrumentation {
@Override
public ElementMatcher<TypeDescription> typeMatcher() {
return named("org.springframework.scheduling.support.DelegatingErrorHandlingRunnable");
}

@Override
public void transform(TypeTransformer transformer) {
transformer.applyAdviceToMethod(
isConstructor().and(takesArgument(1, named("org.springframework.util.ErrorHandler"))),
this.getClass().getName() + "$WrapErrorHandlerAdvice");

transformer.applyAdviceToMethod(
isPublic().and(named("run")), this.getClass().getName() + "$RunAdvice");
}

@SuppressWarnings("unused")
public static class WrapErrorHandlerAdvice {

@Advice.OnMethodEnter(suppress = Throwable.class)
public static void onEnter(
@Advice.Argument(value = 1, readOnly = false) ErrorHandler errorHandler) {
if (errorHandler != null) {
errorHandler = new ErrorHandlerWrapper(errorHandler);
}
}
}

@SuppressWarnings("unused")
public static class RunAdvice {

@Advice.OnMethodEnter(suppress = Throwable.class)
public static Scope onEnter() {
Context parentContext = Java8BytecodeBridge.currentContext();
return TaskContextHolder.init(parentContext).makeCurrent();
}

@Advice.OnMethodExit(onThrowable = Throwable.class, suppress = Throwable.class)
public static void onExit(@Advice.Enter Scope scope) {
if (scope != null) {
scope.close();
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/

package io.opentelemetry.javaagent.instrumentation.spring.scheduling.v3_1;

import io.opentelemetry.context.Context;
import io.opentelemetry.context.Scope;
import org.springframework.util.ErrorHandler;

public final class ErrorHandlerWrapper implements ErrorHandler {
private final ErrorHandler errorHandler;

public ErrorHandlerWrapper(ErrorHandler errorHandler) {
this.errorHandler = errorHandler;
}

@Override
public void handleError(Throwable throwable) {
Context taskContext = TaskContextHolder.getTaskContext(Context.current());
// run the error handler with the same context as task execution
try (Scope ignore = taskContext != null ? taskContext.makeCurrent() : null) {
errorHandler.handleError(throwable);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

package io.opentelemetry.javaagent.instrumentation.spring.scheduling.v3_1;

import static java.util.Collections.singletonList;
import static java.util.Arrays.asList;

import com.google.auto.service.AutoService;
import io.opentelemetry.javaagent.extension.instrumentation.InstrumentationModule;
Expand All @@ -21,6 +21,7 @@ public SpringSchedulingInstrumentationModule() {

@Override
public List<TypeInstrumentation> typeInstrumentations() {
return singletonList(new TaskSchedulerInstrumentation());
return asList(
new TaskSchedulerInstrumentation(), new DelegatingErrorHandlingRunnableInstrumentation());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ public void run() {
}

Context context = instrumenter().start(parentContext, runnable);
// remember the context, so it could be reused in error handler
TaskContextHolder.set(context);
try (Scope ignored = context.makeCurrent()) {
runnable.run();
instrumenter().end(context, runnable, null, null);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/

package io.opentelemetry.javaagent.instrumentation.spring.scheduling.v3_1;

import static io.opentelemetry.context.ContextKey.named;

import io.opentelemetry.context.Context;
import io.opentelemetry.context.ContextKey;
import io.opentelemetry.context.ImplicitContextKeyed;
import javax.annotation.Nullable;

public final class TaskContextHolder implements ImplicitContextKeyed {

private static final ContextKey<TaskContextHolder> KEY =
named("opentelemetry-spring-scheduling-task");

private Context taskContext;

private TaskContextHolder() {}

public static Context init(Context context) {
if (context.get(KEY) != null) {
return context;
}
return context.with(new TaskContextHolder());
}

public static void set(Context taskContext) {
TaskContextHolder holder = taskContext.get(KEY);
if (holder != null) {
holder.taskContext = taskContext;
}
}

@Nullable
public static Context getTaskContext(Context context) {
Context taskContext = null;
TaskContextHolder holder = context.get(KEY);
if (holder != null) {
taskContext = holder.taskContext;
}
return taskContext;
}

@Override
public Context storeInContext(Context context) {
return context.with(KEY, this);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
* SPDX-License-Identifier: Apache-2.0
*/

import io.opentelemetry.api.trace.StatusCode
import io.opentelemetry.instrumentation.test.AgentInstrumentationSpecification
import io.opentelemetry.semconv.trace.attributes.SemanticAttributes
import org.springframework.context.annotation.AnnotationConfigApplicationContext

import java.util.concurrent.CountDownLatch
Expand Down Expand Up @@ -121,4 +123,41 @@ class SpringSchedulingTest extends AgentInstrumentationSpecification {
}
}
}

def "task with error test"() {
setup:
def context = new AnnotationConfigApplicationContext(TaskWithErrorConfig)
def task = context.getBean(TaskWithError)

task.blockUntilExecute()

expect:
assert task != null
assertTraces(1) {
trace(0, 2) {
span(0) {
name "TaskWithError.run"
hasNoParent()
status StatusCode.ERROR
attributes {
"job.system" "spring_scheduling"
"code.namespace" "TaskWithError"
"code.function" "run"
}
event(0) {
eventName "$SemanticAttributes.EXCEPTION_EVENT_NAME"
attributes {
"$SemanticAttributes.EXCEPTION_TYPE" IllegalStateException.getName()
"$SemanticAttributes.EXCEPTION_MESSAGE" "failure"
"$SemanticAttributes.EXCEPTION_STACKTRACE" String
}
}
}
span(1) {
name "error-handler"
childOf(span(0))
}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;

@Component
public class TaskWithError implements Runnable {

private final CountDownLatch latch = new CountDownLatch(1);

@Scheduled(fixedRate = 5000)
@Override
public void run() {
latch.countDown();
throw new IllegalStateException("failure");
}

public void blockUntilExecute() throws InterruptedException {
latch.await(5, TimeUnit.SECONDS);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/

import io.opentelemetry.instrumentation.testing.GlobalTraceUtil;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.scheduling.annotation.EnableScheduling;
import org.springframework.scheduling.concurrent.ConcurrentTaskScheduler;

@Configuration
@EnableScheduling
public class TaskWithErrorConfig {
@Bean
public TaskWithError task() {
return new TaskWithError();
}

@Bean
public TaskScheduler taskScheduler() {
ConcurrentTaskScheduler scheduler = new ConcurrentTaskScheduler();
scheduler.setErrorHandler(throwable -> GlobalTraceUtil.runWithSpan("error-handler", () -> {}));
return scheduler;
}
}

0 comments on commit 7c9cf7a

Please sign in to comment.