diff --git a/broker/src/main/java/io/moquette/broker/MQTTConnection.java b/broker/src/main/java/io/moquette/broker/MQTTConnection.java index 2498a8fd8..1df55b9c4 100644 --- a/broker/src/main/java/io/moquette/broker/MQTTConnection.java +++ b/broker/src/main/java/io/moquette/broker/MQTTConnection.java @@ -20,6 +20,7 @@ import io.netty.buffer.ByteBuf; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelPipeline; import io.netty.handler.codec.mqtt.*; import io.netty.handler.timeout.IdleStateHandler; @@ -27,7 +28,6 @@ import org.slf4j.LoggerFactory; import java.net.InetSocketAddress; -import java.nio.charset.StandardCharsets; import java.util.*; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -47,7 +47,7 @@ final class MQTTConnection { private IAuthenticator authenticator; private SessionRegistry sessionRegistry; private final PostOffice postOffice; - private boolean connected; + private volatile boolean connected; private final AtomicInteger lastPacketId = new AtomicInteger(0); MQTTConnection(Channel channel, BrokerConfiguration brokerConfig, IAuthenticator authenticator, @@ -164,21 +164,54 @@ void processConnect(MqttConnectMessage msg) { return; } + final SessionRegistry.SessionCreationResult result; try { LOG.trace("Binding MQTTConnection (channel: {}) to session", channel); - sessionRegistry.bindToSession(this, msg, clientId); - - initializeKeepAliveTimeout(channel, msg, clientId); - setupInflightResender(channel); - - NettyUtils.clientID(channel, clientId); - LOG.trace("CONNACK sent, channel: {}", channel); - postOffice.dispatchConnection(msg); - LOG.trace("dispatch connection: {}", msg.toString()); + result = sessionRegistry.createOrReopenSession(msg, clientId, this.getUsername()); + result.session.bind(this); } catch (SessionCorruptedException scex) { LOG.warn("MQTT session for client ID {} cannot be created, channel: {}", clientId, channel); abortConnection(CONNECTION_REFUSED_SERVER_UNAVAILABLE); + return; } + + final boolean msgCleanSessionFlag = msg.variableHeader().isCleanSession(); + boolean isSessionAlreadyPresent = !msgCleanSessionFlag && result.alreadyStored; + final String clientIdUsed = clientId; + sendConnAck(isSessionAlreadyPresent).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + LOG.trace("CONNACK sent, channel: {}", channel); + if (!result.session.completeConnection()) { + // send DISCONNECT and close the channel + final MqttMessage disconnectMsg = disconnect(); + channel.writeAndFlush(disconnectMsg).addListener(CLOSE); + LOG.warn("CONNACK is sent but the session created can't transition in CONNECTED state"); + } else { + NettyUtils.clientID(channel, clientIdUsed); + connected = true; + // OK continue with sending queued messages and normal flow + + if (result.mode == SessionRegistry.CreationModeEnum.REOPEN_EXISTING) { + result.session.sendQueuedMessagesWhileOffline(); + } + + initializeKeepAliveTimeout(channel, msg, clientIdUsed); + setupInflightResender(channel); + + postOffice.dispatchConnection(msg); + LOG.trace("dispatch connection: {}", msg.toString()); + } + } else { + sessionRegistry.disconnect(clientIdUsed); + sessionRegistry.remove(clientIdUsed); + LOG.error("CONNACK send failed, cleanup session and close the connection", future.cause()); + channel.close(); + } + + } + }); } private void setupInflightResender(Channel channel) { @@ -222,12 +255,18 @@ private MqttConnAckMessage connAck(MqttConnectReturnCode returnCode, boolean ses return new MqttConnAckMessage(mqttFixedHeader, mqttConnAckVariableHeader); } + private MqttMessage disconnect() { + MqttFixedHeader mqttFixedHeader = new MqttFixedHeader(MqttMessageType.DISCONNECT, false, MqttQoS.AT_MOST_ONCE, + false, 0); + return new MqttMessage(mqttFixedHeader); + } + private boolean login(MqttConnectMessage msg, final String clientId) { // handle user authentication if (msg.variableHeader().hasUserName()) { byte[] pwd = null; if (msg.variableHeader().hasPassword()) { - pwd = msg.payload().password().getBytes(StandardCharsets.UTF_8); + pwd = msg.payload().passwordInBytes(); } else if (!brokerConfig.isAllowAnonymous()) { LOG.info("Client didn't supply any password and MQTT anonymous mode is disabled CId={}", clientId); return false; @@ -267,10 +306,9 @@ void handleConnectionLost() { LOG.trace("dispatch disconnection: clientId={}, userName={}", clientID, userName); } - void sendConnAck(boolean isSessionAlreadyPresent) { - connected = true; + private ChannelFuture sendConnAck(boolean isSessionAlreadyPresent) { final MqttConnAckMessage ackMessage = connAck(CONNECTION_ACCEPTED, isSessionAlreadyPresent); - channel.writeAndFlush(ackMessage).addListener(FIRE_EXCEPTION_ON_FAILURE); + return channel.writeAndFlush(ackMessage); } boolean isConnected() { @@ -293,7 +331,7 @@ void processDisconnect(MqttMessage msg) { channel.close().addListener(FIRE_EXCEPTION_ON_FAILURE); LOG.trace("Processed DISCONNECT CId={}, channel: {}", clientID, channel); String userName = NettyUtils.userName(channel); - postOffice.dispatchDisconnection(clientID,userName); + postOffice.dispatchDisconnection(clientID, userName); LOG.trace("dispatch disconnection: clientId={}, userName={}", clientID, userName); } diff --git a/broker/src/main/java/io/moquette/broker/PostOffice.java b/broker/src/main/java/io/moquette/broker/PostOffice.java index f416666bf..e647741f6 100644 --- a/broker/src/main/java/io/moquette/broker/PostOffice.java +++ b/broker/src/main/java/io/moquette/broker/PostOffice.java @@ -297,16 +297,16 @@ public void internalPublish(MqttPublishMessage msg) { * notify MqttConnectMessage after connection established (already pass login). * @param msg */ - void dispatchConnection(MqttConnectMessage msg){ + void dispatchConnection(MqttConnectMessage msg) { interceptor.notifyClientConnected(msg); } - void dispatchDisconnection(String clientId,String userName){ - interceptor.notifyClientDisconnected(clientId,userName); + void dispatchDisconnection(String clientId,String userName) { + interceptor.notifyClientDisconnected(clientId, userName); } - void dispatchConnectionLost(String clientId,String userName){ - interceptor.notifyClientConnectionLost(clientId,userName); + void dispatchConnectionLost(String clientId,String userName) { + interceptor.notifyClientConnectionLost(clientId, userName); } void flushInFlight(MQTTConnection mqttConnection) { diff --git a/broker/src/main/java/io/moquette/broker/Session.java b/broker/src/main/java/io/moquette/broker/Session.java index 8d2580f01..55c1c903d 100644 --- a/broker/src/main/java/io/moquette/broker/Session.java +++ b/broker/src/main/java/io/moquette/broker/Session.java @@ -113,8 +113,12 @@ void update(boolean clean, Will will) { this.will = will; } - void markConnected() { - assignState(SessionStatus.DISCONNECTED, SessionStatus.CONNECTED); + void markConnecting() { + assignState(SessionStatus.DISCONNECTED, SessionStatus.CONNECTING); + } + + boolean completeConnection() { + return assignState(Session.SessionStatus.CONNECTING, Session.SessionStatus.CONNECTED); } void bind(MQTTConnection mqttConnection) { diff --git a/broker/src/main/java/io/moquette/broker/SessionRegistry.java b/broker/src/main/java/io/moquette/broker/SessionRegistry.java index d6d26d68a..b3f05a8b7 100644 --- a/broker/src/main/java/io/moquette/broker/SessionRegistry.java +++ b/broker/src/main/java/io/moquette/broker/SessionRegistry.java @@ -55,8 +55,21 @@ static class PublishedMessage extends EnqueuedMessage { static final class PubRelMarker extends EnqueuedMessage { } - private enum PostConnectAction { - NONE, SEND_STORED_MESSAGES + public enum CreationModeEnum { + CREATED_CLEAN_NEW, REOPEN_EXISTING, DROP_EXISTING; + } + + public static class SessionCreationResult { + + final Session session; + final CreationModeEnum mode; + final boolean alreadyStored; + + public SessionCreationResult(Session session, CreationModeEnum mode, boolean alreadyStored) { + this.session = session; + this.mode = mode; + this.alreadyStored = alreadyStored; + } } private static final Logger LOG = LoggerFactory.getLogger(SessionRegistry.class); @@ -75,12 +88,12 @@ private enum PostConnectAction { this.authorizator = authorizator; } - void bindToSession(MQTTConnection mqttConnection, MqttConnectMessage msg, String clientId) { - boolean isSessionAlreadyStored = false; - PostConnectAction postConnectAction = PostConnectAction.NONE; + SessionCreationResult createOrReopenSession(MqttConnectMessage msg, String clientId, String username) { + SessionCreationResult postConnectAction; if (!pool.containsKey(clientId)) { // case 1 - final Session newSession = createNewSession(mqttConnection, msg, clientId); + final Session newSession = createNewSession(msg, clientId); + postConnectAction = new SessionCreationResult(newSession, CreationModeEnum.CREATED_CLEAN_NEW, false); // publish the session final Session previous = pool.putIfAbsent(clientId, newSession); @@ -89,88 +102,66 @@ void bindToSession(MQTTConnection mqttConnection, MqttConnectMessage msg, String if (success) { LOG.trace("case 1, not existing session with CId {}", clientId); } else { - postConnectAction = bindToExistingSession(mqttConnection, msg, clientId, newSession); - isSessionAlreadyStored = true; + postConnectAction = reopenExistingSession(msg, clientId, newSession, username); } } else { - final Session newSession = createNewSession(mqttConnection, msg, clientId); - postConnectAction = bindToExistingSession(mqttConnection, msg, clientId, newSession); - isSessionAlreadyStored = true; - } - final boolean msgCleanSessionFlag = msg.variableHeader().isCleanSession(); - boolean isSessionAlreadyPresent = !msgCleanSessionFlag && isSessionAlreadyStored; - mqttConnection.sendConnAck(isSessionAlreadyPresent); - - if (postConnectAction == PostConnectAction.SEND_STORED_MESSAGES) { - final Session session = pool.get(clientId); - session.sendQueuedMessagesWhileOffline(); + final Session newSession = createNewSession(msg, clientId); + postConnectAction = reopenExistingSession(msg, clientId, newSession, username); } + return postConnectAction; } - private PostConnectAction bindToExistingSession(MQTTConnection mqttConnection, MqttConnectMessage msg, - String clientId, Session newSession) { - PostConnectAction postConnectAction = PostConnectAction.NONE; + private SessionCreationResult reopenExistingSession(MqttConnectMessage msg, String clientId, + Session newSession, String username) { final boolean newIsClean = msg.variableHeader().isCleanSession(); final Session oldSession = pool.get(clientId); - if (newIsClean && oldSession.disconnected()) { - // case 2 - dropQueuesForClient(clientId); - unsubscribe(oldSession); - - // publish new session - boolean result = oldSession.assignState(SessionStatus.DISCONNECTED, SessionStatus.CONNECTING); - if (!result) { - throw new SessionCorruptedException("old session was already changed state"); - } - copySessionConfig(msg, oldSession); - oldSession.bind(mqttConnection); - - result = oldSession.assignState(SessionStatus.CONNECTING, SessionStatus.CONNECTED); - if (!result) { - throw new SessionCorruptedException("old session moved in connected state by other thread"); - } - final boolean published = pool.replace(clientId, oldSession, oldSession); - if (!published) { - throw new SessionCorruptedException("old session was already removed"); - } - LOG.trace("case 2, oldSession with same CId {} disconnected", clientId); - } else if (!newIsClean && oldSession.disconnected()) { - // case 3 - final String username = mqttConnection.getUsername(); - reactivateSubscriptions(oldSession, username); - - // mark as connected - final boolean connecting = oldSession.assignState(SessionStatus.DISCONNECTED, SessionStatus.CONNECTING); - if (!connecting) { - throw new SessionCorruptedException("old session moved in connected state by other thread"); - } - oldSession.bind(mqttConnection); - - final boolean connected = oldSession.assignState(SessionStatus.CONNECTING, SessionStatus.CONNECTED); - if (!connected) { - throw new SessionCorruptedException("old session moved in other state state by other thread"); - } - - // publish new session - final boolean published = pool.replace(clientId, oldSession, oldSession); - if (!published) { - throw new SessionCorruptedException("old session was already removed"); + final SessionCreationResult creationResult; + if (oldSession.disconnected()) { + if (newIsClean) { + boolean result = oldSession.assignState(SessionStatus.DISCONNECTED, SessionStatus.CONNECTING); + if (!result) { + throw new SessionCorruptedException("old session was already changed state"); + } + + // case 2 + // publish new session + dropQueuesForClient(clientId); + unsubscribe(oldSession); + copySessionConfig(msg, oldSession); + + LOG.trace("case 2, oldSession with same CId {} disconnected", clientId); + creationResult = new SessionCreationResult(oldSession, CreationModeEnum.CREATED_CLEAN_NEW, true); + } else { + final boolean connecting = oldSession.assignState(SessionStatus.DISCONNECTED, SessionStatus.CONNECTING); + if (!connecting) { + throw new SessionCorruptedException("old session moved in connected state by other thread"); + } + // case 3 + reactivateSubscriptions(oldSession, username); + + LOG.trace("case 3, oldSession with same CId {} disconnected", clientId); + creationResult = new SessionCreationResult(oldSession, CreationModeEnum.REOPEN_EXISTING, true); } - postConnectAction = PostConnectAction.SEND_STORED_MESSAGES; - LOG.trace("case 3, oldSession with same CId {} disconnected", clientId); - } else if (oldSession.connected()) { + } else { // case 4 LOG.trace("case 4, oldSession with same CId {} still connected, force to close", clientId); oldSession.closeImmediately(); //remove(clientId); - // publish new session - final boolean published = pool.replace(clientId, oldSession, newSession); - if (!published) { - throw new SessionCorruptedException("old session was already removed"); - } + creationResult = new SessionCreationResult(newSession, CreationModeEnum.DROP_EXISTING, true); } + + final boolean published; + if (creationResult.mode == CreationModeEnum.DROP_EXISTING) { + published = pool.replace(clientId, oldSession, newSession); + } else { + published = pool.replace(clientId, oldSession, oldSession); + } + if (!published) { + throw new SessionCorruptedException("old session was already removed"); + } + // case not covered new session is clean true/false and old session not in CONNECTED/DISCONNECTED - return postConnectAction; + return creationResult; } private void reactivateSubscriptions(Session session, String username) { @@ -192,7 +183,7 @@ private void unsubscribe(Session session) { } } - private Session createNewSession(MQTTConnection mqttConnection, MqttConnectMessage msg, String clientId) { + private Session createNewSession(MqttConnectMessage msg, String clientId) { final boolean clean = msg.variableHeader().isCleanSession(); final Queue sessionQueue = queues.computeIfAbsent(clientId, (String cli) -> queueRepository.createQueue(cli, clean)); @@ -204,9 +195,7 @@ private Session createNewSession(MQTTConnection mqttConnection, MqttConnectMessa newSession = new Session(clientId, clean, sessionQueue); } - newSession.markConnected(); - newSession.bind(mqttConnection); - + newSession.markConnecting(); return newSession; } diff --git a/broker/src/test/java/io/moquette/broker/SessionRegistryTest.java b/broker/src/test/java/io/moquette/broker/SessionRegistryTest.java index 8fa105bcc..b970ed6b7 100644 --- a/broker/src/test/java/io/moquette/broker/SessionRegistryTest.java +++ b/broker/src/test/java/io/moquette/broker/SessionRegistryTest.java @@ -32,6 +32,7 @@ import static io.netty.handler.codec.mqtt.MqttConnectReturnCode.CONNECTION_ACCEPTED; import static java.util.Collections.singleton; import static java.util.Collections.singletonMap; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -87,17 +88,17 @@ public void testConnAckContainsSessionPresentFlag() { NettyUtils.cleanSession(channel, false); // Connect a first time - sut.bindToSession(connection, msg, FAKE_CLIENT_ID); + sut.createOrReopenSession(msg, FAKE_CLIENT_ID, connection.getUsername()); // disconnect sut.disconnect(FAKE_CLIENT_ID); // Exercise, reconnect EmbeddedChannel anotherChannel = new EmbeddedChannel(); MQTTConnection anotherConnection = createMQTTConnection(ALLOW_ANONYMOUS_AND_ZEROBYTE_CLIENT_ID, anotherChannel); - sut.bindToSession(anotherConnection, msg, FAKE_CLIENT_ID); + final SessionRegistry.SessionCreationResult result = sut.createOrReopenSession(msg, FAKE_CLIENT_ID, anotherConnection.getUsername()); // Verify - assertEqualsConnAck(CONNECTION_ACCEPTED, anotherChannel.readOutbound()); + assertEquals(SessionRegistry.CreationModeEnum.CREATED_CLEAN_NEW, result.mode); assertTrue("Connection is accepted and therefore should remain open", anotherChannel.isOpen()); } diff --git a/broker/src/test/java/io/moquette/broker/SessionTest.java b/broker/src/test/java/io/moquette/broker/SessionTest.java index 4a0e7694f..b3cd600f7 100644 --- a/broker/src/test/java/io/moquette/broker/SessionTest.java +++ b/broker/src/test/java/io/moquette/broker/SessionTest.java @@ -22,8 +22,9 @@ public void testPubAckDrainMessagesRemainingInQueue() { final EmbeddedChannel testChannel = new EmbeddedChannel(); BrokerConfiguration brokerConfiguration = new BrokerConfiguration(true, false, false, false); MQTTConnection mqttConnection = new MQTTConnection(testChannel, brokerConfiguration, null, null, null); - client.markConnected(); + client.markConnecting(); client.bind(mqttConnection); + client.completeConnection(); final Topic destinationTopic = new Topic("/a/b"); sendQoS1To(client, destinationTopic, "Hello World!");