Skip to content

Commit

Permalink
Ensure single completion signal through SubscriptionCommand #1576
Browse files Browse the repository at this point in the history
We now ensure that CommandWrapper and its subclass SubscriptionCommand trigger only a single completion signal. Previously, a race could happen inside the completion notification which could cause a completion and error signal or error + completion because of improper guarding.
  • Loading branch information
mp911de committed Jan 12, 2021
1 parent d79fc24 commit 04de69d
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 79 deletions.
78 changes: 23 additions & 55 deletions src/main/java/io/lettuce/core/RedisPublisher.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;

import io.lettuce.core.internal.ExceptionFactory;
import reactor.core.CoreSubscriber;
import reactor.core.Exceptions;
import reactor.util.context.Context;
import io.lettuce.core.api.StatefulConnection;
import io.lettuce.core.api.StatefulRedisConnection;
import io.lettuce.core.internal.ExceptionFactory;
import io.lettuce.core.internal.LettuceAssert;
import io.lettuce.core.output.StreamingOutput;
import io.lettuce.core.protocol.CommandWrapper;
Expand Down Expand Up @@ -494,7 +494,7 @@ void dispatch(RedisSubscription<?> redisSubscription) {
*
* Refer to the individual states for more information.
*/
private enum State {
enum State {

/**
* The initial unsubscribed state. Will respond to {@link #subscribe(RedisSubscription, Subscriber)} by changing state
Expand Down Expand Up @@ -724,14 +724,12 @@ void onError(RedisSubscription<?> subscription, Throwable t) {
* @param <V> value type
* @param <T> response type
*/
private static class SubscriptionCommand<K, V, T> extends CommandWrapper<K, V, T> implements DemandAware.Sink {
static class SubscriptionCommand<K, V, T> extends CommandWrapper<K, V, T> implements DemandAware.Sink {

private final boolean dissolve;

private final RedisSubscription<T> subscription;

private volatile boolean completed = false;

private volatile DemandAware.Source source;

public SubscriptionCommand(RedisCommand<K, V, T> command, RedisSubscription<T> subscription, boolean dissolve) {
Expand All @@ -744,50 +742,40 @@ public SubscriptionCommand(RedisCommand<K, V, T> command, RedisSubscription<T> s

@Override
public boolean hasDemand() {
return completed || subscription.state() == State.COMPLETED || subscription.data.isEmpty();
return isDone() || subscription.state() == State.COMPLETED || subscription.data.isEmpty();
}

@Override
@SuppressWarnings("unchecked")
public void complete() {
@SuppressWarnings({ "unchecked", "CastCanBeRemovedNarrowingVariableType" })
protected void doOnComplete() {

if (completed) {
return;
}
if (getOutput() != null) {

try {
super.complete();
Object result = getOutput().get();

if (getOutput() != null) {
Object result = getOutput().get();

if (getOutput().hasError()) {
onError(ExceptionFactory.createExecutionException(getOutput().getError()));
completed = true;
return;
}
if (getOutput().hasError()) {
onError(ExceptionFactory.createExecutionException(getOutput().getError()));
return;
}

if (!(getOutput() instanceof StreamingOutput<?>) && result != null) {
if (!(getOutput() instanceof StreamingOutput<?>) && result != null) {

if (dissolve && result instanceof Collection) {
if (dissolve && result instanceof Collection) {

Collection<T> collection = (Collection<T>) result;
Collection<T> collection = (Collection<T>) result;

for (T t : collection) {
if (t != null) {
subscription.onNext(t);
}
for (T t : collection) {
if (t != null) {
subscription.onNext(t);
}
} else {
subscription.onNext((T) result);
}
} else {
subscription.onNext((T) result);
}
}

subscription.onAllDataRead();
} finally {
completed = true;
}

subscription.onAllDataRead();
}

@Override
Expand All @@ -801,28 +789,8 @@ public void removeSource() {
}

@Override
public void cancel() {

if (completed) {
return;
}

super.cancel();

completed = true;
}

@Override
public boolean completeExceptionally(Throwable throwable) {

if (completed) {
return false;
}

boolean b = super.completeExceptionally(throwable);
protected void doOnError(Throwable throwable) {
onError(throwable);
completed = true;
return b;
}

private void onError(Throwable throwable) {
Expand Down
73 changes: 52 additions & 21 deletions src/main/java/io/lettuce/core/protocol/CommandWrapper.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,14 @@
*/
public class CommandWrapper<K, V, T> implements RedisCommand<K, V, T>, CompleteableCommand<T>, DecoratedCommand<K, V, T> {

@SuppressWarnings({ "rawtypes", "unchecked" })
@SuppressWarnings({ "rawtypes" })
private static final AtomicReferenceFieldUpdater<CommandWrapper, Object[]> ONCOMPLETE = AtomicReferenceFieldUpdater
.newUpdater(CommandWrapper.class, Object[].class, "onComplete");

@SuppressWarnings({ "rawtypes", "unchecked" })
private static final Object[] EMPTY = new Object[0];

private static final Object[] COMPLETE = new Object[0];

protected final RedisCommand<K, V, T> command;

// accessed via AtomicReferenceFieldUpdater.
Expand All @@ -53,15 +54,30 @@ public CommandOutput<K, V, T> getOutput() {
}

@Override
@SuppressWarnings({ "rawtypes", "unchecked" })
public void complete() {

command.complete();

Object[] consumers = ONCOMPLETE.get(this);
if (!expireCallbacks(consumers)) {
return;

if (consumers != COMPLETE && ONCOMPLETE.compareAndSet(this, consumers, COMPLETE)) {

command.complete();

doOnComplete();
notifyConsumers(consumers);
}
}

/**
* Callback method called after successful completion and before notifying downstream consumers.
*
* @since 6.0.2
*/
protected void doOnComplete() {

}

@SuppressWarnings({ "rawtypes", "unchecked" })
private void notifyConsumers(Object[] consumers) {

for (Object callback : consumers) {

Expand All @@ -88,27 +104,46 @@ public void complete() {
@Override
public void cancel() {

command.cancel();
notifyBiConsumer(new CancellationException());
Object[] consumers = ONCOMPLETE.get(this);

if (consumers != COMPLETE && ONCOMPLETE.compareAndSet(this, consumers, COMPLETE)) {

command.cancel();
CancellationException exception = new CancellationException();
doOnError(exception);
notifyBiConsumer(consumers, exception);
}

}

@Override
public boolean completeExceptionally(Throwable throwable) {

boolean result = command.completeExceptionally(throwable);
notifyBiConsumer(throwable);
Object[] consumers = ONCOMPLETE.get(this);

boolean result = false;
if (consumers != COMPLETE && ONCOMPLETE.compareAndSet(this, consumers, COMPLETE)) {

result = command.completeExceptionally(throwable);
doOnError(throwable);
notifyBiConsumer(consumers, throwable);
}

return result;
}

@SuppressWarnings({ "rawtypes", "unchecked" })
private void notifyBiConsumer(Throwable exception) {
/**
* Callback method called after error completion and before notifying downstream consumers.
*
* @param throwable
* @since 6.0.2
*/
protected void doOnError(Throwable throwable) {

Object[] consumers = ONCOMPLETE.get(this);
}

if (!expireCallbacks(consumers)) {
return;
}
@SuppressWarnings({ "rawtypes", "unchecked" })
private void notifyBiConsumer(Object[] consumers, Throwable exception) {

for (Object callback : consumers) {

Expand All @@ -125,10 +160,6 @@ private void notifyBiConsumer(Throwable exception) {
}
}

private boolean expireCallbacks(Object[] consumers) {
return consumers != EMPTY && ONCOMPLETE.compareAndSet(this, consumers, EMPTY);
}

@Override
public CommandArgs<K, V> getArgs() {
return command.getArgs();
Expand Down
84 changes: 84 additions & 0 deletions src/test/java/io/lettuce/core/SubscriptionCommandUnitTests.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright 2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.lettuce.core;

import static org.mockito.Mockito.*;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.reactivestreams.Subscriber;

import io.lettuce.core.api.StatefulConnection;
import io.lettuce.core.codec.RedisCodec;
import io.lettuce.core.codec.StringCodec;
import io.lettuce.core.output.CommandOutput;
import io.lettuce.core.output.StatusOutput;
import io.lettuce.core.protocol.Command;
import io.lettuce.core.protocol.CommandType;
import io.netty.util.concurrent.ImmediateEventExecutor;

/**
* Unit tests for {@link io.lettuce.core.RedisPublisher.SubscriptionCommand}.
*
* @author Mark Paluch
*/
class SubscriptionCommandUnitTests {

private RedisCodec<String, String> codec = StringCodec.UTF8;

private Command<String, String, String> command;

private RedisPublisher.RedisSubscription<String> subscription;

private Subscriber<String> subscriber = mock(Subscriber.class);

@BeforeEach
final void createCommand() {

CommandOutput<String, String, String> output = new StatusOutput<>(codec);
command = new Command<>(CommandType.INFO, output, null);
}

@Test
void shouldCompleteOnlyOnce() {

subscription = new RedisPublisher.RedisSubscription<>(mock(StatefulConnection.class), command, false,
ImmediateEventExecutor.INSTANCE);
subscription.subscribe(subscriber);
subscription.changeState(RedisPublisher.State.NO_DEMAND, RedisPublisher.State.DEMAND);
subscription.request(1);

RedisPublisher.SubscriptionCommand<String, String, String> wrapper = new RedisPublisher.SubscriptionCommand<>(command,
subscription, false);
command.getOutput().setSingle(ByteBuffer.wrap("Hello".getBytes(StandardCharsets.UTF_8)));

wrapper.onComplete((s, throwable) -> {

wrapper.completeExceptionally(new IllegalStateException());
});

wrapper.complete();

verify(subscriber).onSubscribe(any());
verify(subscriber).onNext("Hello");
verify(subscriber).onComplete();
verifyNoMoreInteractions(subscriber);
}

}
Loading

0 comments on commit 04de69d

Please sign in to comment.