From 3e92091618f51888c80c343f9edae7633531e488 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Tue, 25 Jul 2023 11:54:36 +0200 Subject: [PATCH] Add fallback to RESP2 upon `NOPROTO` response #2455 --- .../java/io/lettuce/core/RedisHandshake.java | 27 +++-- .../lettuce/core/RedisHandshakeUnitTests.java | 104 ++++++++++++++++++ 2 files changed, 121 insertions(+), 10 deletions(-) create mode 100644 src/test/java/io/lettuce/core/RedisHandshakeUnitTests.java diff --git a/src/main/java/io/lettuce/core/RedisHandshake.java b/src/main/java/io/lettuce/core/RedisHandshake.java index 5de946e93c..979774ad85 100644 --- a/src/main/java/io/lettuce/core/RedisHandshake.java +++ b/src/main/java/io/lettuce/core/RedisHandshake.java @@ -104,7 +104,7 @@ private CompletionStage tryHandshakeResp3(Channel channel) { } if (throwable != null) { - if (isUnknownCommand(throwable)) { + if (isUnknownCommand(throwable) || isNoProto(throwable)) { try { fallbackToResp2(channel, handshake); } catch (Exception e) { @@ -115,6 +115,7 @@ private CompletionStage tryHandshakeResp3(Channel channel) { handshake.completeExceptionally(throwable); } } else { + onHelloResponse(settings); handshake.complete(null); } }); @@ -145,19 +146,20 @@ private CompletableFuture initializeResp2(Channel channel) { } private CompletionStage initializeResp3(Channel channel) { + return initiateHandshakeResp3(channel, connectionState.getCredentialsProvider()).thenAccept(this::onHelloResponse); + } - return initiateHandshakeResp3(channel, connectionState.getCredentialsProvider()).thenAccept(response -> { + private void onHelloResponse(Map response) { - Long id = (Long) response.get("id"); - String mode = (String) response.get("mode"); - String version = (String) response.get("version"); - String role = (String) response.get("role"); + Long id = (Long) response.get("id"); + String mode = (String) response.get("mode"); + String version = (String) response.get("version"); + String role = (String) response.get("role"); - negotiatedProtocolVersion = ProtocolVersion.RESP3; + negotiatedProtocolVersion = ProtocolVersion.RESP3; - connectionState.setHandshakeResponse( - new ConnectionState.HandshakeResponse(negotiatedProtocolVersion, id, version, mode, role)); - }); + connectionState.setHandshakeResponse( + new ConnectionState.HandshakeResponse(negotiatedProtocolVersion, id, version, mode, role)); } /** @@ -272,4 +274,9 @@ private static boolean isUnknownCommand(Throwable error) { && ((error.getMessage().startsWith("ERR") && error.getMessage().contains("unknown"))); } + private static boolean isNoProto(Throwable error) { + return error instanceof RedisException && LettuceStrings.isNotEmpty(error.getMessage()) + && error.getMessage().startsWith("NOPROTO"); + } + } diff --git a/src/test/java/io/lettuce/core/RedisHandshakeUnitTests.java b/src/test/java/io/lettuce/core/RedisHandshakeUnitTests.java new file mode 100644 index 0000000000..37ad837042 --- /dev/null +++ b/src/test/java/io/lettuce/core/RedisHandshakeUnitTests.java @@ -0,0 +1,104 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.lettuce.core; + +import java.nio.ByteBuffer; +import java.util.Map; + +import io.lettuce.core.output.CommandOutput; +import io.lettuce.core.protocol.AsyncCommand; +import io.lettuce.core.protocol.ProtocolVersion; +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.*; + +/** + * Unit tests for {@link RedisHandshake}. + * + * @author Mark Paluch + */ +class RedisHandshakeUnitTests { + + @Test + void handshakeWithResp3ShouldPass() { + + EmbeddedChannel channel = new EmbeddedChannel(true, false); + + ConnectionState state = new ConnectionState(); + state.setCredentialsProvider(new StaticCredentialsProvider("foo", "bar".toCharArray())); + RedisHandshake handshake = new RedisHandshake(ProtocolVersion.RESP3, false, state); + handshake.initialize(channel); + + AsyncCommand> hello = channel.readOutbound(); + helloResponse(hello.getOutput()); + hello.complete(); + + assertThat(state.getNegotiatedProtocolVersion()).isEqualTo(ProtocolVersion.RESP3); + } + + @Test + void handshakeWithDiscoveryShouldPass() { + + EmbeddedChannel channel = new EmbeddedChannel(true, false); + + ConnectionState state = new ConnectionState(); + state.setCredentialsProvider(new StaticCredentialsProvider("foo", "bar".toCharArray())); + RedisHandshake handshake = new RedisHandshake(null, false, state); + handshake.initialize(channel); + + AsyncCommand> hello = channel.readOutbound(); + helloResponse(hello.getOutput()); + hello.complete(); + + assertThat(state.getNegotiatedProtocolVersion()).isEqualTo(ProtocolVersion.RESP3); + } + + @Test + void handshakeWithDiscoveryShouldDowngrade() { + + EmbeddedChannel channel = new EmbeddedChannel(true, false); + + ConnectionState state = new ConnectionState(); + state.setCredentialsProvider(new StaticCredentialsProvider(null, null)); + RedisHandshake handshake = new RedisHandshake(null, false, state); + handshake.initialize(channel); + + AsyncCommand> hello = channel.readOutbound(); + hello.getOutput().setError("NOPROTO"); + hello.completeExceptionally(new RedisException("NOPROTO")); + hello.complete(); + + assertThat(state.getNegotiatedProtocolVersion()).isEqualTo(ProtocolVersion.RESP2); + } + + private static void helloResponse(CommandOutput> output) { + + output.multi(8); + output.set(ByteBuffer.wrap("id".getBytes())); + output.set(1); + + output.set(ByteBuffer.wrap("mode".getBytes())); + output.set(ByteBuffer.wrap("master".getBytes())); + + output.set(ByteBuffer.wrap("role".getBytes())); + output.set(ByteBuffer.wrap("master".getBytes())); + + output.set(ByteBuffer.wrap("version".getBytes())); + output.set(ByteBuffer.wrap("1.2.3".getBytes())); + } + +}