Skip to content

Commit

Permalink
Synchronization/cross thread visibility of variables #94
Browse files Browse the repository at this point in the history
Unify connection state tracking into one field (lifecycleState) to prevent race conditions when reading the state. Use synchronized blocks to adopt JMM semantics and address thread visibility issues
  • Loading branch information
mp911de committed Jul 3, 2015
1 parent 6c29bb7 commit ad4d2c1
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ public void close() {
this.connections.invalidateAll();
resetPartitions();
for (RedisAsyncConnection<K, V> kvRedisAsyncConnection : copy.values()) {
kvRedisAsyncConnection.close();
if (kvRedisAsyncConnection.isOpen()) {
kvRedisAsyncConnection.close();
}
}
}

Expand Down
94 changes: 65 additions & 29 deletions src/main/java/com/lambdaworks/redis/protocol/CommandHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,16 @@
package com.lambdaworks.redis.protocol;

import java.nio.charset.Charset;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.locks.ReentrantLock;

import com.lambdaworks.redis.ClientOptions;
import com.lambdaworks.redis.ConnectionEvents;
import com.lambdaworks.redis.RedisChannelHandler;
import com.lambdaworks.redis.RedisChannelWriter;
import com.lambdaworks.redis.RedisException;
import com.lambdaworks.redis.*;

import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.channel.*;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

Expand All @@ -38,13 +30,14 @@ public class CommandHandler<K, V> extends ChannelDuplexHandler implements RedisC

protected ClientOptions clientOptions;
protected Queue<RedisCommand<K, V, ?>> queue;
protected Queue<RedisCommand<K, V, ?>> commandBuffer = new LinkedBlockingQueue<RedisCommand<K, V, ?>>();
protected Queue<RedisCommand<K, V, ?>> commandBuffer = new ArrayDeque<RedisCommand<K, V, ?>>();
protected ByteBuf buffer;
protected RedisStateMachine<K, V> rsm;

private LifecycleState lifecycleState = LifecycleState.NOT_CONNECTED;
private Object stateLock = new Object();
private Channel channel;
private boolean closed;
private boolean connected;

private RedisChannelHandler<K, V> redisChannelHandler;
private final ReentrantLock writeLock = new ReentrantLock();
private Throwable connectionError;
Expand Down Expand Up @@ -79,19 +72,24 @@ public CommandHandler(ClientOptions clientOptions, Queue<RedisCommand<K, V, ?>>
*/
@Override
public void channelRegistered(ChannelHandlerContext ctx) throws Exception {
closed = false;
setState(LifecycleState.REGISTERED);
buffer = ctx.alloc().heapBuffer();
rsm = new RedisStateMachine<K, V>();
channel = ctx.channel();
synchronized (stateLock) {
channel = ctx.channel();
}
}

@Override
public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
releaseBuffer();
if (closed) {

if (lifecycleState == LifecycleState.CLOSED) {
cancelCommands("Connection closed");
}
channel = null;
synchronized (stateLock) {
channel = null;
}
}

/**
Expand Down Expand Up @@ -133,17 +131,18 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf buffer) throws Interrup
@Override
public <T> RedisCommand<K, V, T> write(RedisCommand<K, V, T> command) {

if (closed) {
if (lifecycleState == LifecycleState.CLOSED) {
throw new RedisException("Connection is closed");
}
try {
/**
* This lock causes safety for connection activation and somehow netty gets more stable and predictable performance
* than without a lock and all threads are hammering towards writeAndFlush.
*/

writeLock.lock();
Channel channel = this.channel;
if (channel != null && connected && channel.isActive()) {
if (channel != null && isConnected() && channel.isActive()) {
if (debugEnabled) {
logger.debug("{} write() writeAndFlush Command {}", logPrefix(), command);
}
Expand All @@ -159,10 +158,14 @@ public <T> RedisCommand<K, V, T> write(RedisCommand<K, V, T> command) {
*/

if (!channel.isActive()) {
write(command);
return write(command);
}
} else {

if (commandBuffer.contains(command) || queue.contains(command)) {
return command;
}

if (connectionError != null) {
if (debugEnabled) {
logger.debug("{} write() completing Command {} due to connection error", logPrefix(), command);
Expand All @@ -187,6 +190,11 @@ public <T> RedisCommand<K, V, T> write(RedisCommand<K, V, T> command) {
return command;
}

private boolean isConnected() {
return lifecycleState.ordinal() >= LifecycleState.CONNECTED.ordinal()
&& lifecycleState.ordinal() <= LifecycleState.DISCONNECTED.ordinal();
}

/**
*
* @see io.netty.channel.ChannelDuplexHandler#write(io.netty.channel.ChannelHandlerContext, java.lang.Object,
Expand Down Expand Up @@ -214,8 +222,7 @@ public void channelActive(ChannelHandlerContext ctx) throws Exception {
if (debugEnabled) {
logger.debug("{} channelActive()", logPrefix());
}
connected = true;
closed = false;
setStateIfNotClosed(LifecycleState.CONNECTED);

try {
executeQueuedCommands(ctx);
Expand Down Expand Up @@ -251,14 +258,19 @@ protected void executeQueuedCommands(ChannelHandlerContext ctx) {
if (debugEnabled) {
logger.debug("{} executeQueuedCommands {} command(s) queued", logPrefix(), queue.size());
}
channel = ctx.channel();

synchronized (stateLock) {
channel = ctx.channel();
}

if (redisChannelHandler != null) {
if (debugEnabled) {
logger.debug("{} activating channel handler", logPrefix());
}
setStateIfNotClosed(LifecycleState.ACTIVATING);
redisChannelHandler.activated();
}
setStateIfNotClosed(LifecycleState.ACTIVE);

for (RedisCommand<K, V, ?> cmd : tmp) {
if (!cmd.isCancelled()) {
Expand Down Expand Up @@ -289,21 +301,40 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
if (debugEnabled) {
logger.debug("{} channelInactive()", logPrefix());
}
connected = false;
setStateIfNotClosed(LifecycleState.DISCONNECTED);

if (redisChannelHandler != null) {
if (debugEnabled) {
logger.debug("{} deactivating channel handler", logPrefix());
}
setStateIfNotClosed(LifecycleState.DEACTIVATING);
redisChannelHandler.deactivated();
}
setStateIfNotClosed(LifecycleState.DEACTIVATED);

if (buffer != null) {
rsm.reset();
buffer.clear();
}

if (debugEnabled) {
logger.debug("{} channelInactive() done", logPrefix());
}
super.channelInactive(ctx);
}

protected void setStateIfNotClosed(LifecycleState lifecycleState) {
if (this.lifecycleState != LifecycleState.CLOSED) {
setState(lifecycleState);
}
}

protected void setState(LifecycleState lifecycleState) {
synchronized (stateLock) {
this.lifecycleState = lifecycleState;
}
}

private void cancelCommands(String message) {
int size = 0;
if (queue != null) {
Expand Down Expand Up @@ -342,7 +373,7 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E
command.complete();
}

if (channel == null || !connected) {
if (channel == null || !channel.isActive() || !isConnected()) {
connectionError = cause;
return;
}
Expand All @@ -359,11 +390,11 @@ public void close() {
logger.debug("{} close()", logPrefix());
}

if (closed) {
if (lifecycleState == LifecycleState.CLOSED) {
return;
}

closed = true;
setStateIfNotClosed(LifecycleState.CLOSED);
Channel currentChannel = this.channel;
if (currentChannel != null) {
currentChannel.pipeline().fireUserEventTriggered(new ConnectionEvents.PrepareClose());
Expand All @@ -380,7 +411,7 @@ private void releaseBuffer() {
}

public boolean isClosed() {
return closed;
return lifecycleState == LifecycleState.CLOSED;
}

/**
Expand Down Expand Up @@ -425,4 +456,9 @@ private String logPrefix() {
return logPrefix = buffer.toString();
}

enum LifecycleState {

NOT_CONNECTED, REGISTERED, CONNECTED, ACTIVATING, ACTIVE, DISCONNECTED, DEACTIVATING, DEACTIVATED, CLOSED,
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.*;

import java.util.ArrayDeque;
import java.util.Queue;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;

import com.lambdaworks.redis.ClientOptions;
import com.lambdaworks.redis.RedisException;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
Expand All @@ -24,10 +25,13 @@
@RunWith(MockitoJUnitRunner.class)
public class CommandHandlerTest {

private BlockingQueue<RedisCommand<String, String, ?>> q = new ArrayBlockingQueue<RedisCommand<String, String, ?>>(10);
private Queue<RedisCommand<String, String, ?>> q = new ArrayDeque<RedisCommand<String, String, ?>>(10);

private CommandHandler<String, String> sut = new CommandHandler<String, String>(new ClientOptions.Builder().build(), q);

private Command<String, String, String> command = new Command<String, String, String>(CommandType.APPEND,
new StatusOutput<String, String>(new Utf8StringCodec()), null);

@Mock
private ChannelHandlerContext context;

Expand All @@ -37,7 +41,10 @@ public class CommandHandlerTest {
@Test
public void testExceptionChannelActive() throws Exception {

sut.setState(CommandHandler.LifecycleState.ACTIVE);

when(context.channel()).thenReturn(channel);
when(channel.isActive()).thenReturn(true);

sut.channelActive(context);
sut.exceptionCaught(context, new Exception());
Expand All @@ -46,19 +53,20 @@ public void testExceptionChannelActive() throws Exception {

@Test
public void testExceptionChannelInactive() throws Exception {
sut.setState(CommandHandler.LifecycleState.DISCONNECTED);
sut.exceptionCaught(context, new Exception());
verify(context, never()).fireExceptionCaught(any(Exception.class));
}

@Test
public void testExceptionWithQueue() throws Exception {
sut.setState(CommandHandler.LifecycleState.ACTIVE);
q.clear();
when(context.channel()).thenReturn(channel);

sut.channelActive(context);
when(channel.isActive()).thenReturn(true);

Command<String, String, String> command = new Command<String, String, String>(CommandType.APPEND,
new StatusOutput<String, String>(new Utf8StringCodec()), null);
q.add(command);
sut.exceptionCaught(context, new Exception());

Expand All @@ -68,4 +76,21 @@ public void testExceptionWithQueue() throws Exception {
verify(context).fireExceptionCaught(any(Exception.class));
}

@Test(expected = RedisException.class)
public void testWriteWhenClosed() throws Exception {

sut.setState(CommandHandler.LifecycleState.CLOSED);

sut.write(command);
}

@Test
public void testExceptionWhenClosed() throws Exception {

sut.setState(CommandHandler.LifecycleState.CLOSED);

sut.exceptionCaught(context, new Exception());
verifyZeroInteractions(context);
}

}

0 comments on commit ad4d2c1

Please sign in to comment.