Skip to content

Commit

Permalink
Fix races in emission vs. request #1140
Browse files Browse the repository at this point in the history
onNext now uses decrement instead of compareAndSet to avoid races with request(n) calls ensuring that all elements get emitted.
This allows simplification of readAndPublish() and eliminating another race where completion was potentially dropped.

Completion now also considers the guard through READING instead of completing from any state to protect against active drain loops.
  • Loading branch information
mp911de committed Oct 28, 2019
1 parent c05dc0d commit 8734e0c
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 81 deletions.
133 changes: 52 additions & 81 deletions src/main/java/io/lettuce/core/RedisPublisher.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import java.util.Collection;
import java.util.Objects;
import java.util.Queue;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import java.util.concurrent.atomic.AtomicLongFieldUpdater;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
Expand All @@ -32,6 +32,7 @@
import reactor.core.CoreSubscriber;
import reactor.core.Exceptions;
import reactor.util.context.Context;
import sun.rmi.runtime.Log;
import io.lettuce.core.api.StatefulConnection;
import io.lettuce.core.api.StatefulRedisConnection;
import io.lettuce.core.internal.LettuceAssert;
Expand Down Expand Up @@ -141,21 +142,17 @@ static class RedisSubscription<T> extends StreamingOutput.Subscriber<T> implemen
static final int ST_COMPLETED = 1;

@SuppressWarnings({ "rawtypes", "unchecked" })
static final AtomicLongFieldUpdater<RedisSubscription> DEMAND = AtomicLongFieldUpdater.newUpdater(
RedisSubscription.class, "demand");
static final AtomicLongFieldUpdater<RedisSubscription> DEMAND = AtomicLongFieldUpdater
.newUpdater(RedisSubscription.class, "demand");

@SuppressWarnings({ "rawtypes", "unchecked" })
static final AtomicReferenceFieldUpdater<RedisSubscription, State> STATE = AtomicReferenceFieldUpdater.newUpdater(
RedisSubscription.class, State.class, "state");
static final AtomicReferenceFieldUpdater<RedisSubscription, State> STATE = AtomicReferenceFieldUpdater
.newUpdater(RedisSubscription.class, State.class, "state");

@SuppressWarnings({ "rawtypes", "unchecked" })
static final AtomicReferenceFieldUpdater<RedisSubscription, CommandDispatch> COMMAND_DISPATCH = AtomicReferenceFieldUpdater
.newUpdater(RedisSubscription.class, CommandDispatch.class, "commandDispatch");

@SuppressWarnings({ "rawtypes", "unchecked" })
static final AtomicIntegerFieldUpdater<RedisSubscription> COMPLETION = AtomicIntegerFieldUpdater.newUpdater(
RedisSubscription.class, "completion");

private final SubscriptionCommand<?, ?, T> subscriptionCommand;
private final boolean traceEnabled = LOG.isTraceEnabled();

Expand All @@ -164,15 +161,14 @@ static class RedisSubscription<T> extends StreamingOutput.Subscriber<T> implemen
final RedisCommand<?, ?, T> command;
final boolean dissolve;
private final Executor executor;
final ArrayBlockingQueue<Log> logs = new ArrayBlockingQueue<>(1024 * 4);

// accessed via AtomicLongFieldUpdater
@SuppressWarnings("unused")
volatile long demand;
@SuppressWarnings("unused")
volatile State state = State.UNSUBSCRIBED;
@SuppressWarnings("unused")
volatile int completion = ST_PROGRESS;
@SuppressWarnings("unused")
volatile CommandDispatch commandDispatch = CommandDispatch.UNDISPATCHED;

volatile boolean allDataRead = false;
Expand Down Expand Up @@ -274,13 +270,14 @@ public void onNext(T t) {
}

// Fast-path publishing, preserve ordering
if (state == State.DEMAND && data.isEmpty()) {
if (data.isEmpty() && state() == State.DEMAND) {

long initial = getDemand();

if (initial > 0 && DEMAND.compareAndSet(this, initial, initial - 1)) {
if (initial > 0) {

try {
DEMAND.decrementAndGet(this);
this.subscriber.onNext(t);
} catch (Exception e) {
onError(e);
Expand Down Expand Up @@ -330,11 +327,12 @@ final void onAllDataRead() {
LOG.trace("{} onAllDataRead()", state);
}

state.onAllDataRead(this);
allDataRead = true;
onDataAvailable();
}

/**
* Called by a listener interface to indicate that as error has occured.
* Called by a listener interface to indicate that as error has occurred.
*
* @param t the error
*/
Expand Down Expand Up @@ -370,15 +368,12 @@ boolean changeState(State oldState, State newState) {
return STATE.compareAndSet(this, oldState, newState);
}

public boolean complete() {

if (COMPLETION.compareAndSet(this, ST_PROGRESS, ST_COMPLETED)) {

STATE.set(this, State.COMPLETED);
return true;
}
boolean afterRead() {
return changeState(State.READING, getDemand() > 0 ? State.DEMAND : State.NO_DEMAND);
}

return false;
public boolean complete() {
return changeState(State.READING, State.COMPLETED);
}

void checkCommandDispatch() {
Expand Down Expand Up @@ -411,37 +406,20 @@ void potentiallyReadMore() {
/**
* Reads and publishes data from the input. Continues until either there is no more demand, or until there is no more
* data to be read.
*
* @return {@literal true} if there is more demand, {@literal false} otherwise.
*/
boolean readAndPublish() {
void readAndPublish() {

while (hasDemand()) {

long initial = getDemand();

if (!hasDemand(initial)) {
return false;
}

T data = read();

if (data == null) {
return hasDemand(initial);
return;
}

boolean success = DEMAND.compareAndSet(this, initial, initial - 1);

if (success) {
this.subscriber.onNext(data);
}
DEMAND.decrementAndGet(this);
this.subscriber.onNext(data);
}

return false;
}

private static boolean hasDemand(long n) {
return n > 0;
}

RedisPublisher.State state() {
Expand Down Expand Up @@ -548,10 +526,7 @@ void request(RedisSubscription<?> subscription, long n) {
}

subscription.potentiallyReadMore();

if (subscription.allDataRead) {
onAllDataRead(subscription);
}
onDataAvailable(subscription);
} else {
onError(subscription, Exceptions.nullOrNegativeRequestException(n));
}
Expand All @@ -566,15 +541,15 @@ void request(RedisSubscription<?> subscription, long n) {
@Override
void onDataAvailable(RedisSubscription<?> subscription) {

while (subscription.hasDemand()) {

if (subscription.state() == NO_DEMAND && !subscription.changeState(NO_DEMAND, DEMAND)) {
return;
}
try {
do {

if (!read(subscription)) {
return;
}
if (!read(subscription)) {
return;
}
} while (subscription.hasDemand());
} catch (Exception e) {
subscription.onError(e);
}
}

Expand All @@ -583,9 +558,7 @@ void request(RedisSubscription<?> subscription, long n) {

if (Operators.request(RedisSubscription.DEMAND, subscription, n)) {

if (subscription.changeState(NO_DEMAND, DEMAND)) {
read(subscription);
}
onDataAvailable(subscription);

subscription.potentiallyReadMore();
} else {
Expand All @@ -600,34 +573,32 @@ void request(RedisSubscription<?> subscription, long n) {
*/
private boolean read(RedisSubscription<?> subscription) {

State state = subscription.state();

// concurrency/entry guard
if (!subscription.changeState(this, READING)) {
if (state == NO_DEMAND || state == DEMAND) {
if (!subscription.changeState(state, READING)) {
return false;
}
} else {
return false;
}

boolean hasDemand = subscription.readAndPublish();

try {

if (subscription.data.isEmpty()) {
subscription.readAndPublish();

if (subscription.allDataRead) {
subscription.onAllDataRead();
}
if (subscription.allDataRead && subscription.data.isEmpty()) {
state.onAllDataRead(subscription);
return false;
}

return false;
}
// concurrency/leave guard
subscription.afterRead();

if (subscription.allDataRead || !subscription.data.isEmpty()) {
return true;
} finally {

// concurrency/leave guard
if (hasDemand) {
subscription.changeState(READING, DEMAND);
} else {
subscription.changeState(READING, NO_DEMAND);
}
}

return false;
}
},

Expand Down Expand Up @@ -695,8 +666,6 @@ void onDataAvailable(RedisSubscription<?> subscription) {

void onAllDataRead(RedisSubscription<?> subscription) {

subscription.allDataRead = true;

if (subscription.data.isEmpty() && subscription.complete()) {

readData(subscription);
Expand All @@ -711,13 +680,15 @@ void onAllDataRead(RedisSubscription<?> subscription) {

void onError(RedisSubscription<?> subscription, Throwable t) {

if (subscription.changeState(this, COMPLETED)) {
State state;
while ((state = subscription.state()) != COMPLETED && subscription.changeState(state, COMPLETED)) {

readData(subscription);

Subscriber<?> subscriber = subscription.subscriber;
if (subscriber != null) {
subscriber.onError(t);
return;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ void keysDoesNotRunIntoRaceConditions() {
for (int i = 0; i < 1000; i++) {
CompletableFuture<Long> future = commands.keys("*").count().toFuture();
Futures.await(future);
assertThat(future).isCompletedWithValue(1000L);
}
}

Expand Down

0 comments on commit 8734e0c

Please sign in to comment.