Skip to content

Commit

Permalink
Skip Context Propagation in special operators
Browse files Browse the repository at this point in the history
This change should improve performance of Automatic Context Propagation
in certain cases when doOnDiscard, onErrorContinue, and onErrorStop are
used.

The context-propagation integration requires contextWrite and tap
operators to be barriers for restoring ThreadLocal values. Some internal
usage of contextWrite does not require us to treat the operators the
same way and we can skip the ceremony of restoring ThreadLocal state as
we know that no ThreadLocalAccessor can be registered for them.
Therefore, a private variant is introduced to avoid unnecessary overhead
when not required.

Related #3840
  • Loading branch information
chemicL committed Jul 12, 2024
1 parent 1b2b5e9 commit 9fead48
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 8 deletions.
16 changes: 12 additions & 4 deletions reactor-core/src/main/java/reactor/core/publisher/Flux.java
Original file line number Diff line number Diff line change
Expand Up @@ -4393,6 +4393,14 @@ public final Flux<T> contextWrite(Function<Context, Context> contextModifier) {
return onAssembly(new FluxContextWrite<>(this, contextModifier));
}

private final Flux<T> contextWriteSkippingContextPropagation(ContextView contextToAppend) {
return contextWriteSkippingContextPropagation(c -> c.putAll(contextToAppend));
}

private final Flux<T> contextWriteSkippingContextPropagation(Function<Context, Context> contextModifier) {
return onAssembly(new FluxContextWrite<>(this, contextModifier));
}

/**
* Counts the number of values in this {@link Flux}.
* The count will be emitted when onComplete is observed.
Expand Down Expand Up @@ -4866,7 +4874,7 @@ public final Flux<T> doOnComplete(Runnable onComplete) {
* @return a {@link Flux} that cleans up matching elements that get discarded upstream of it.
*/
public final <R> Flux<T> doOnDiscard(final Class<R> type, final Consumer<? super R> discardHook) {
return contextWrite(Operators.discardLocalAdapter(type, discardHook));
return contextWriteSkippingContextPropagation(Operators.discardLocalAdapter(type, discardHook));
}

/**
Expand Down Expand Up @@ -7147,7 +7155,7 @@ public final Flux<T> onErrorComplete(Predicate<? super Throwable> predicate) {
*/
public final Flux<T> onErrorContinue(BiConsumer<Throwable, Object> errorConsumer) {
BiConsumer<Throwable, Object> genericConsumer = errorConsumer;
return contextWrite(Context.of(
return contextWriteSkippingContextPropagation(Context.of(
OnNextFailureStrategy.KEY_ON_NEXT_ERROR_STRATEGY,
OnNextFailureStrategy.resume(genericConsumer)
));
Expand Down Expand Up @@ -7231,7 +7239,7 @@ public final <E extends Throwable> Flux<T> onErrorContinue(Predicate<E> errorPre
@SuppressWarnings("unchecked")
Predicate<Throwable> genericPredicate = (Predicate<Throwable>) errorPredicate;
BiConsumer<Throwable, Object> genericErrorConsumer = errorConsumer;
return contextWrite(Context.of(
return contextWriteSkippingContextPropagation(Context.of(
OnNextFailureStrategy.KEY_ON_NEXT_ERROR_STRATEGY,
OnNextFailureStrategy.resumeIf(genericPredicate, genericErrorConsumer)
));
Expand All @@ -7248,7 +7256,7 @@ public final <E extends Throwable> Flux<T> onErrorContinue(Predicate<E> errorPre
* was used downstream
*/
public final Flux<T> onErrorStop() {
return contextWrite(Context.of(
return contextWriteSkippingContextPropagation(Context.of(
OnNextFailureStrategy.KEY_ON_NEXT_ERROR_STRATEGY,
OnNextFailureStrategy.stop()));
}
Expand Down
16 changes: 12 additions & 4 deletions reactor-core/src/main/java/reactor/core/publisher/Mono.java
Original file line number Diff line number Diff line change
Expand Up @@ -2424,6 +2424,14 @@ public final Mono<T> contextWrite(Function<Context, Context> contextModifier) {
return onAssembly(new MonoContextWrite<>(this, contextModifier));
}

private final Mono<T> contextWriteSkippingContextPropagation(ContextView contextToAppend) {
return contextWriteSkippingContextPropagation(c -> c.putAll(contextToAppend));
}

private final Mono<T> contextWriteSkippingContextPropagation(Function<Context, Context> contextModifier) {
return onAssembly(new MonoContextWrite<>(this, contextModifier));
}

/**
* Provide a default single value if this mono is completed without any data
*
Expand Down Expand Up @@ -2713,7 +2721,7 @@ public final Mono<T> doOnCancel(Runnable onCancel) {
* @return a {@link Mono} that cleans up matching elements that get discarded upstream of it.
*/
public final <R> Mono<T> doOnDiscard(final Class<R> type, final Consumer<? super R> discardHook) {
return contextWrite(Operators.discardLocalAdapter(type, discardHook));
return contextWriteSkippingContextPropagation(Operators.discardLocalAdapter(type, discardHook));
}

/**
Expand Down Expand Up @@ -3712,7 +3720,7 @@ public final Mono<T> onErrorComplete(Predicate<? super Throwable> predicate) {
*/
public final Mono<T> onErrorContinue(BiConsumer<Throwable, Object> errorConsumer) {
BiConsumer<Throwable, Object> genericConsumer = errorConsumer;
return contextWrite(Context.of(
return contextWriteSkippingContextPropagation(Context.of(
OnNextFailureStrategy.KEY_ON_NEXT_ERROR_STRATEGY,
OnNextFailureStrategy.resume(genericConsumer)
));
Expand Down Expand Up @@ -3802,7 +3810,7 @@ public final <E extends Throwable> Mono<T> onErrorContinue(Predicate<E> errorPre
@SuppressWarnings("unchecked")
Predicate<Throwable> genericPredicate = (Predicate<Throwable>) errorPredicate;
BiConsumer<Throwable, Object> genericErrorConsumer = errorConsumer;
return contextWrite(Context.of(
return contextWriteSkippingContextPropagation(Context.of(
OnNextFailureStrategy.KEY_ON_NEXT_ERROR_STRATEGY,
OnNextFailureStrategy.resumeIf(genericPredicate, genericErrorConsumer)
));
Expand All @@ -3819,7 +3827,7 @@ public final <E extends Throwable> Mono<T> onErrorContinue(Predicate<E> errorPre
* was used downstream
*/
public final Mono<T> onErrorStop() {
return contextWrite(Context.of(
return contextWriteSkippingContextPropagation(Context.of(
OnNextFailureStrategy.KEY_ON_NEXT_ERROR_STRATEGY,
OnNextFailureStrategy.stop()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,15 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import io.micrometer.context.ContextRegistry;
import io.micrometer.context.ThreadLocalAccessor;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
Expand Down Expand Up @@ -2453,4 +2455,124 @@ void fluxToIterable() {
assertThat(value.get()).isEqualTo("present");
}
}

@Nested
class SpecialContextAlteringOperators {

// The cases here consider operators like doOnDiscard(), which underneath
// utilize contextWrite() for its purpose. They are special in that we use them
// internally and do not anticipate the registered keys to be corresponding to
// any ThreadLocal values. That expectation is reasonable in user facing code
// as we don't know what keys are used and whether a ThreadLocalAccessor is
// registered for these keys. Therefore, in specific cases that are internal to
// reactor-core, we can skip ThreadLocal restoration in fragments of the chain.

// Explanation of GET/SET operations on TL in the scenarios here:
// When going UP, we use the value "present".
// When going DOWN, we clear the value or restore a captured empty value.
//
// 1 x GET in block() with implicit context capture
//
// 1 x GET going UP from contextWrite (read current to restore later)
// + 2 x SET going UP from contextWrite + SET restoring current later
//
// 1 x GET going DOWN from contextWrite with subscription (read current)
// + 2 x SET going DOWN from contextWrite + SET restoring current later
//
// 1 x GET going UP to request (read current)
// + 2 x SET going UP from contextWrite + SET restoring current later
//
// 1 x GET going DOWN to deliver onComplete (read current)
// + 2 x SET going DOWN from contextWrite + SET restoring current later

@Test
void discardFlux() {
CountingThreadLocalAccessor accessor = new CountingThreadLocalAccessor();
ContextRegistry.getInstance().registerThreadLocalAccessor(accessor);

AtomicInteger tlPresent = new AtomicInteger();
AtomicInteger discards = new AtomicInteger();

Flux.just("a")
.doOnEach(signal -> {
if (CountingThreadLocalAccessor.TL.get().equals("present")) {
tlPresent.incrementAndGet();
}
})
.filter(s -> false)
.doOnDiscard(String.class, s -> discards.incrementAndGet())
.count()
.contextWrite(ctx -> ctx.put(CountingThreadLocalAccessor.KEY, "present"))
.block();

assertThat(tlPresent.get()).isEqualTo(2); // 1 x onNext + 1 x onComplete
assertThat(discards.get()).isEqualTo(1);
// 5 with doOnDiscard skipping TL restoration, 9 with restoring
assertThat(accessor.reads.get()).isEqualTo(5);
// 8 with doOnDiscard skipping TL restoration, 16 with restoring
assertThat(accessor.writes.get()).isEqualTo(8);

ContextRegistry.getInstance().removeThreadLocalAccessor(CountingThreadLocalAccessor.KEY);
}

@Test
void discardMono() {
CountingThreadLocalAccessor accessor = new CountingThreadLocalAccessor();
ContextRegistry.getInstance().registerThreadLocalAccessor(accessor);

AtomicInteger tlPresent = new AtomicInteger();
AtomicInteger discards = new AtomicInteger();

Mono.just("a")
.doOnEach(signal -> {
if (CountingThreadLocalAccessor.TL.get().equals("present")) {
tlPresent.incrementAndGet();
}
})
.filter(s -> false)
.doOnDiscard(String.class, s -> discards.incrementAndGet())
.contextWrite(ctx -> ctx.put(CountingThreadLocalAccessor.KEY, "present"))
.block();

assertThat(tlPresent.get()).isEqualTo(2); // 1 x onNext + 1 x onComplete
assertThat(discards.get()).isEqualTo(1);
// 5 with doOnDiscard skipping TL restoration, 9 with restoring
assertThat(accessor.reads.get()).isEqualTo(5);
// 8 with doOnDiscard skipping TL restoration, 16 with restoring
assertThat(accessor.writes.get()).isEqualTo(8);

ContextRegistry.getInstance().removeThreadLocalAccessor(CountingThreadLocalAccessor.KEY);
}
}

private static class CountingThreadLocalAccessor implements ThreadLocalAccessor<String> {
static final String KEY = "CTLA";
static final ThreadLocal<String> TL = new ThreadLocal<>();

AtomicInteger reads = new AtomicInteger();
AtomicInteger writes = new AtomicInteger();

@Override
public Object key() {
return KEY;
}

@Override
public String getValue() {
reads.incrementAndGet();
return TL.get();
}

@Override
public void setValue(String s) {
writes.incrementAndGet();
TL.set(s);
}

@Override
public void setValue() {
writes.incrementAndGet();
TL.remove();
}
}
}

0 comments on commit 9fead48

Please sign in to comment.