Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow scanning instrumented reactor publishers and only allow registe… #5755

Merged
merged 5 commits into from
Apr 7, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
import io.opentelemetry.context.Context;
import io.opentelemetry.context.Scope;
import io.opentelemetry.instrumentation.api.annotation.support.async.AsyncOperationEndStrategies;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiFunction;
import java.util.function.Function;
import javax.annotation.Nullable;
import org.reactivestreams.Publisher;
import reactor.core.CoreSubscriber;
import reactor.core.Fuseable;
Expand All @@ -39,6 +41,8 @@
/** Based on Spring Sleuth's Reactor instrumentation. */
public final class ContextPropagationOperator {

private static final Object VALUE = new Object();

public static ContextPropagationOperator create() {
return builder().build();
}
Expand All @@ -57,7 +61,7 @@ public String toString() {
}
};

private static volatile boolean enabled = false;
private static final AtomicBoolean enabled = new AtomicBoolean();

/**
* Stores Trace {@link io.opentelemetry.context.Context} in Reactor {@link
Expand Down Expand Up @@ -99,16 +103,19 @@ public static Context getOpenTelemetryContext(
* application.
*/
public void registerOnEachOperator() {
Hooks.onEachOperator(TracingSubscriber.class.getName(), tracingLift(asyncOperationEndStrategy));
AsyncOperationEndStrategies.instance().registerStrategy(asyncOperationEndStrategy);
enabled = true;
if (enabled.compareAndSet(false, true)) {
Hooks.onEachOperator(
TracingSubscriber.class.getName(), tracingLift(asyncOperationEndStrategy));
AsyncOperationEndStrategies.instance().registerStrategy(asyncOperationEndStrategy);
}
}

/** Unregisters the hook registered by {@link #registerOnEachOperator()}. */
public void resetOnEachOperator() {
Hooks.resetOnEachOperator(TracingSubscriber.class.getName());
AsyncOperationEndStrategies.instance().unregisterStrategy(asyncOperationEndStrategy);
enabled = false;
if (enabled.compareAndSet(true, false)) {
Hooks.resetOnEachOperator(TracingSubscriber.class.getName());
AsyncOperationEndStrategies.instance().unregisterStrategy(asyncOperationEndStrategy);
}
}

private static <T> Function<? super Publisher<T>, ? extends Publisher<T>> tracingLift(
Expand All @@ -118,29 +125,27 @@ public void resetOnEachOperator() {

/** Forces Mono to run in traceContext scope. */
public static <T> Mono<T> runWithContext(Mono<T> publisher, Context tracingContext) {
if (!enabled) {
if (!enabled.get()) {
return publisher;
}

// this hack forces 'publisher' to run in the onNext callback of `TracingSubscriber`
// (created for this publisher) and with current() span that refers to span created here
// without the hack, publisher runs in the onAssembly stage, before traceContext is made current
return ScalarPropagatingMono.INSTANCE
.flatMap(i -> publisher)
return ScalarPropagatingMono.create(publisher)
.subscriberContext(ctx -> storeOpenTelemetryContext(ctx, tracingContext));
}

/** Forces Flux to run in traceContext scope. */
public static <T> Flux<T> runWithContext(Flux<T> publisher, Context tracingContext) {
if (!enabled) {
if (!enabled.get()) {
return publisher;
}

// this hack forces 'publisher' to run in the onNext callback of `TracingSubscriber`
// (created for this publisher) and with current() span that refers to span created here
// without the hack, publisher runs in the onAssembly stage, before traceContext is made current
return ScalarPropagatingFlux.INSTANCE
.flatMap(i -> publisher)
return ScalarPropagatingFlux.create(publisher)
.subscriberContext(ctx -> storeOpenTelemetryContext(ctx, tracingContext));
}

Expand Down Expand Up @@ -177,29 +182,61 @@ static void subscribeInActiveSpan(CoreSubscriber<? super Object> actual, Object
}
}

static class ScalarPropagatingMono extends Mono<Object> {
public static final Mono<Object> INSTANCE = new ScalarPropagatingMono();
static class ScalarPropagatingMono extends Mono<Object> implements Scannable {

private final Object value = new Object();
static <T> Mono<T> create(Mono<T> source) {
return new ScalarPropagatingMono(source).flatMap(unused -> source);
}

private ScalarPropagatingMono() {}
private final Mono<?> source;

private ScalarPropagatingMono(Mono<?> source) {
this.source = source;
}

@Override
public void subscribe(CoreSubscriber<? super Object> actual) {
subscribeInActiveSpan(actual, value);
subscribeInActiveSpan(actual, VALUE);
}

@Override
@Nullable
// Interface method doesn't have type parameter so we can't add it either.
@SuppressWarnings("rawtypes")
public Object scanUnsafe(Attr attr) {
if (attr == Attr.PARENT) {
return source;
}
return null;
}
}

static class ScalarPropagatingFlux extends Flux<Object> {
public static final Flux<Object> INSTANCE = new ScalarPropagatingFlux();
static class ScalarPropagatingFlux extends Flux<Object> implements Scannable {

static <T> Flux<T> create(Flux<T> source) {
return new ScalarPropagatingFlux(source).flatMap(unused -> source);
}

private final Object value = new Object();
private final Flux<?> source;

private ScalarPropagatingFlux() {}
private ScalarPropagatingFlux(Flux<?> source) {
this.source = source;
}

@Override
public void subscribe(CoreSubscriber<? super Object> actual) {
subscribeInActiveSpan(actual, value);
subscribeInActiveSpan(actual, VALUE);
}

@Override
@Nullable
// Interface method doesn't have type parameter so we can't add it either.
@SuppressWarnings("rawtypes")
public Object scanUnsafe(Scannable.Attr attr) {
if (attr == Scannable.Attr.PARENT) {
return source;
}
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import reactor.core.Scannable;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.UnicastProcessor;

class ReactorCoreTest extends AbstractReactorCoreTest {

Expand Down Expand Up @@ -69,17 +71,16 @@ void monoInNonBlockingPublisherAssembly() {

@Test
void fluxInNonBlockingPublisherAssembly() {
Flux<Integer> source =
Flux.defer(
() -> {
Span.current().setAttribute("inner", "foo");
return Flux.just(5, 6);
});
testing.runWithSpan(
"parent",
() ->
ContextPropagationOperator.ScalarPropagatingFlux.INSTANCE
.flatMap(
unused ->
Flux.defer(
() -> {
Span.current().setAttribute("inner", "foo");
return Flux.just(5, 6);
}))
ContextPropagationOperator.ScalarPropagatingFlux.create(source)
.doOnEach(
signal -> {
if (signal.isOnError()) {
Expand Down Expand Up @@ -199,9 +200,37 @@ void noTracingBeforeRegistration() {
trace -> trace.hasSpansSatisfyingExactly(span -> span.hasName("after").hasNoParent()));
}

@Test
void monoParentsAccessible() {
UnicastProcessor<String> source = UnicastProcessor.create();
Mono<String> mono =
ContextPropagationOperator.runWithContext(source.singleOrEmpty(), Context.root());

source.onNext("foo");
source.onComplete();

assertThat(mono.block()).isEqualTo("foo");

assertThat(((Scannable) mono).parents().filter(UnicastProcessor.class::isInstance).findFirst())
.isPresent();
}

@Test
void fluxParentsAccessible() {
UnicastProcessor<String> source = UnicastProcessor.create();
Flux<String> mono = ContextPropagationOperator.runWithContext(source, Context.root());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit

Suggested change
Flux<String> mono = ContextPropagationOperator.runWithContext(source, Context.root());
Flux<String> flux = ContextPropagationOperator.runWithContext(source, Context.root());


source.onNext("foo");
source.onComplete();

assertThat(mono.collectList().block()).containsExactly("foo");

assertThat(((Scannable) mono).parents().filter(UnicastProcessor.class::isInstance).findFirst())
.isPresent();
}

private <T> Mono<T> monoSpan(Mono<T> mono, String spanName) {
return ContextPropagationOperator.ScalarPropagatingMono.INSTANCE
.flatMap(unused -> mono)
return ContextPropagationOperator.ScalarPropagatingMono.create(mono)
.doOnEach(
signal -> {
if (signal.isOnError()) {
Expand Down