diff --git a/scripts/format-all.sh b/scripts/format-all.sh index 94c9ed2a..9209c354 100755 --- a/scripts/format-all.sh +++ b/scripts/format-all.sh @@ -39,6 +39,7 @@ SRC_FILES=(src/main/java/com/xiaomi/infra/pegasus/client/*.java src/test/java/com/xiaomi/infra/pegasus/rpc/async/*.java src/test/java/com/xiaomi/infra/pegasus/tools/*.java src/test/java/com/xiaomi/infra/pegasus/base/*.java + src/test/java/com/xiaomi/infra/pegasus/security/*.java ) if [ ! -f "${PROJECT_DIR}"/google-java-format-1.7-all-deps.jar ]; then diff --git a/src/main/java/com/xiaomi/infra/pegasus/rpc/async/ReplicaSession.java b/src/main/java/com/xiaomi/infra/pegasus/rpc/async/ReplicaSession.java index 95de53ca..a3ed99a3 100644 --- a/src/main/java/com/xiaomi/infra/pegasus/rpc/async/ReplicaSession.java +++ b/src/main/java/com/xiaomi/infra/pegasus/rpc/async/ReplicaSession.java @@ -341,6 +341,10 @@ void tryNotifyFailureWithSeqID(int seqID, error_types errno, boolean isTimeoutTa } private void write(final RequestEntry entry, VolatileFields cache) { + if (!interceptorManager.onSendMessage(this, entry)) { + return; + } + cache .nettyChannel .writeAndFlush(entry) @@ -381,6 +385,26 @@ public void run() { TimeUnit.MILLISECONDS); } + // return value: + // true - pend succeed + // false - pend failed + public boolean tryPendRequest(RequestEntry entry) { + // double check. the first one doesn't lock the lock. + // Because authSucceed only transfered from false to true. + // So if it is true now, it will not change in the later. + // But if it is false now, maybe it will change soon. So we should use lock to protect it. + if (!this.authSucceed) { + synchronized (authPendingSend) { + if (!this.authSucceed) { + authPendingSend.offer(entry); + return true; + } + } + } + + return false; + } + final class DefaultHandler extends SimpleChannelInboundHandler { @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { @@ -445,6 +469,8 @@ static final class VolatileFields { private Bootstrap boot; private EventLoopGroup rpcGroup; private ReplicaSessionInterceptorManager interceptorManager; + private boolean authSucceed; + final Queue authPendingSend = new LinkedList<>(); // Session will be actively closed if all the rpcs across `sessionResetTimeWindowMs` // are timed out, in that case we suspect that the server is unavailable. diff --git a/src/main/java/com/xiaomi/infra/pegasus/rpc/interceptor/ReplicaSessionInterceptor.java b/src/main/java/com/xiaomi/infra/pegasus/rpc/interceptor/ReplicaSessionInterceptor.java index 99c5e66c..e1bd0f28 100644 --- a/src/main/java/com/xiaomi/infra/pegasus/rpc/interceptor/ReplicaSessionInterceptor.java +++ b/src/main/java/com/xiaomi/infra/pegasus/rpc/interceptor/ReplicaSessionInterceptor.java @@ -23,4 +23,8 @@ public interface ReplicaSessionInterceptor { // The behavior when a rpc session is connected. void onConnected(ReplicaSession session); + + // The behavior when rpc session is sending a message. + // @returns false if this message shouldn't be sent. + boolean onSendMessage(ReplicaSession session, final ReplicaSession.RequestEntry entry); } diff --git a/src/main/java/com/xiaomi/infra/pegasus/rpc/interceptor/ReplicaSessionInterceptorManager.java b/src/main/java/com/xiaomi/infra/pegasus/rpc/interceptor/ReplicaSessionInterceptorManager.java index 86bb0d6a..817bb021 100644 --- a/src/main/java/com/xiaomi/infra/pegasus/rpc/interceptor/ReplicaSessionInterceptorManager.java +++ b/src/main/java/com/xiaomi/infra/pegasus/rpc/interceptor/ReplicaSessionInterceptorManager.java @@ -39,4 +39,13 @@ public void onConnected(ReplicaSession session) { interceptor.onConnected(session); } } + + public boolean onSendMessage(ReplicaSession session, final ReplicaSession.RequestEntry entry) { + for (ReplicaSessionInterceptor interceptor : interceptors) { + if (!interceptor.onSendMessage(session, entry)) { + return false; + } + } + return true; + } } diff --git a/src/main/java/com/xiaomi/infra/pegasus/security/AuthProtocol.java b/src/main/java/com/xiaomi/infra/pegasus/security/AuthProtocol.java index b309df29..ed390455 100644 --- a/src/main/java/com/xiaomi/infra/pegasus/security/AuthProtocol.java +++ b/src/main/java/com/xiaomi/infra/pegasus/security/AuthProtocol.java @@ -24,4 +24,6 @@ public interface AuthProtocol { /** start the authentiate process */ void authenticate(ReplicaSession session); + + boolean isAuthRequest(final ReplicaSession.RequestEntry entry); } diff --git a/src/main/java/com/xiaomi/infra/pegasus/security/AuthReplicaSessionInterceptor.java b/src/main/java/com/xiaomi/infra/pegasus/security/AuthReplicaSessionInterceptor.java index 877c7c9d..f3c68e9c 100644 --- a/src/main/java/com/xiaomi/infra/pegasus/security/AuthReplicaSessionInterceptor.java +++ b/src/main/java/com/xiaomi/infra/pegasus/security/AuthReplicaSessionInterceptor.java @@ -29,7 +29,14 @@ public AuthReplicaSessionInterceptor(ClientOptions options) throws IllegalArgume this.protocol = options.getCredential().getProtocol(); } + @Override public void onConnected(ReplicaSession session) { protocol.authenticate(session); } + + @Override + public boolean onSendMessage(ReplicaSession session, final ReplicaSession.RequestEntry entry) { + // tryPendRequest returns false means that the negotiation is succeed now + return protocol.isAuthRequest(entry) || !session.tryPendRequest(entry); + } } diff --git a/src/main/java/com/xiaomi/infra/pegasus/security/KerberosProtocol.java b/src/main/java/com/xiaomi/infra/pegasus/security/KerberosProtocol.java index 6ee8779d..15cbb5e5 100644 --- a/src/main/java/com/xiaomi/infra/pegasus/security/KerberosProtocol.java +++ b/src/main/java/com/xiaomi/infra/pegasus/security/KerberosProtocol.java @@ -19,6 +19,7 @@ package com.xiaomi.infra.pegasus.security; import com.sun.security.auth.callback.TextCallbackHandler; +import com.xiaomi.infra.pegasus.operator.negotiation_operator; import com.xiaomi.infra.pegasus.rpc.async.ReplicaSession; import java.util.HashMap; import java.util.Map; @@ -75,6 +76,11 @@ public void authenticate(ReplicaSession session) { negotiation.start(); } + @Override + public boolean isAuthRequest(final ReplicaSession.RequestEntry entry) { + return entry.op instanceof negotiation_operator; + } + private static Configuration getLoginContextConfiguration(String keyTab, String principal) { return new Configuration() { @Override diff --git a/src/main/java/com/xiaomi/infra/pegasus/security/Negotiation.java b/src/main/java/com/xiaomi/infra/pegasus/security/Negotiation.java index 400ad57e..8ae185ea 100644 --- a/src/main/java/com/xiaomi/infra/pegasus/security/Negotiation.java +++ b/src/main/java/com/xiaomi/infra/pegasus/security/Negotiation.java @@ -26,30 +26,26 @@ import com.xiaomi.infra.pegasus.operator.negotiation_operator; import com.xiaomi.infra.pegasus.rpc.ReplicationException; import com.xiaomi.infra.pegasus.rpc.async.ReplicaSession; -import java.util.HashMap; +import java.nio.charset.Charset; +import java.util.Collections; +import java.util.List; import javax.security.auth.Subject; -import javax.security.sasl.Sasl; import org.slf4j.Logger; class Negotiation { private static final Logger logger = org.slf4j.LoggerFactory.getLogger(Negotiation.class); - private negotiation_status status; - private ReplicaSession session; - private String serviceName; // used for SASL authentication - private String serviceFqdn; // name used for SASL authentication - private final HashMap props = new HashMap(); - private final Subject subject; - // Because negotiation message is always the first rpc sent to pegasus server, // which will cost much more time. so we set negotiation timeout to 10s here private static final int negotiationTimeoutMS = 10000; + private static final List expectedMechanisms = Collections.singletonList("GSSAPI"); + + private negotiation_status status; + private ReplicaSession session; + SaslWrapper saslWrapper; - Negotiation(ReplicaSession session, Subject subject, String serviceName, String serviceFqdn) { + Negotiation(ReplicaSession session, Subject subject, String serviceName, String serviceFQDN) { + this.saslWrapper = new SaslWrapper(subject, serviceName, serviceFQDN); this.session = session; - this.subject = subject; - this.serviceName = serviceName; - this.serviceFqdn = serviceFqdn; - this.props.put(Sasl.QOP, "auth"); } void start() { @@ -60,10 +56,11 @@ void start() { void send(negotiation_status status, blob msg) { negotiation_request request = new negotiation_request(status, msg); negotiation_operator operator = new negotiation_operator(request); - session.asyncSend(operator, new RecvHandler(operator), negotiationTimeoutMS, false); + session.asyncSend( + operator, new RecvHandler(operator), negotiationTimeoutMS, /* isBackupRequest */ false); } - private static class RecvHandler implements Runnable { + private class RecvHandler implements Runnable { negotiation_operator op; RecvHandler(negotiation_operator op) { @@ -79,6 +76,7 @@ public void run() { handleResponse(); } catch (Exception e) { logger.error("Negotiation failed", e); + negotiationFailed(); } } @@ -88,19 +86,61 @@ private void handleResponse() throws Exception { throw new Exception("RecvHandler received a null response, abandon it"); } - switch (resp.status) { - case SASL_LIST_MECHANISMS_RESP: - case SASL_SELECT_MECHANISMS_RESP: - case SASL_CHALLENGE: - case SASL_SUCC: + switch (status) { + case SASL_LIST_MECHANISMS: + onRecvMechanisms(resp); + break; + case SASL_SELECT_MECHANISMS: + case SASL_INITIATE: + case SASL_CHALLENGE_RESP: + // TBD(zlw): break; default: - throw new Exception("Received an unexpected response, status " + resp.status); + throw new Exception("unexpected negotiation status: " + resp.status); + } + } + } + + public void onRecvMechanisms(negotiation_response response) throws Exception { + checkStatus(response.status, negotiation_status.SASL_LIST_MECHANISMS_RESP); + + String[] matchMechanisms = new String[1]; + matchMechanisms[0] = getMatchMechanism(new String(response.msg.data, Charset.defaultCharset())); + if (matchMechanisms[0].equals("")) { + throw new Exception("No matching mechanism was found"); + } + + status = negotiation_status.SASL_SELECT_MECHANISMS; + blob msg = new blob(saslWrapper.init(matchMechanisms)); + send(status, msg); + } + + public String getMatchMechanism(String respString) { + String matchMechanism = ""; + String[] serverSupportMechanisms = respString.split(","); + for (String serverSupportMechanism : serverSupportMechanisms) { + if (expectedMechanisms.contains(serverSupportMechanism)) { + matchMechanism = serverSupportMechanism; + break; } } + + return matchMechanism; + } + + public void checkStatus(negotiation_status status, negotiation_status expected_status) + throws Exception { + if (status != expected_status) { + throw new Exception("status is " + status + " while expect " + expected_status); + } + } + + private void negotiationFailed() { + status = negotiation_status.SASL_AUTH_FAIL; + session.closeSession(); } - negotiation_status get_status() { + negotiation_status getStatus() { return status; } } diff --git a/src/main/java/com/xiaomi/infra/pegasus/security/SaslWrapper.java b/src/main/java/com/xiaomi/infra/pegasus/security/SaslWrapper.java new file mode 100644 index 00000000..81ca4e58 --- /dev/null +++ b/src/main/java/com/xiaomi/infra/pegasus/security/SaslWrapper.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package com.xiaomi.infra.pegasus.security; + +import java.nio.charset.Charset; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; +import java.util.HashMap; +import javax.security.auth.Subject; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslClient; + +class SaslWrapper { + private SaslClient saslClient; + private Subject subject; + private String serviceName; + private String serviceFQDN; + private HashMap properties = new HashMap<>(); + + SaslWrapper(Subject subject, String serviceName, String serviceFQDN) { + this.subject = subject; + this.serviceName = serviceName; + this.serviceFQDN = serviceFQDN; + this.properties.put(Sasl.QOP, "auth"); + } + + byte[] init(String[] mechanims) throws PrivilegedActionException { + return Subject.doAs( + subject, + (PrivilegedExceptionAction) + () -> { + saslClient = + Sasl.createSaslClient( + mechanims, null, serviceName, serviceFQDN, properties, null); + return saslClient.getMechanismName().getBytes(Charset.defaultCharset()); + }); + } +} diff --git a/src/test/java/com/xiaomi/infra/pegasus/security/NegotiationTest.java b/src/test/java/com/xiaomi/infra/pegasus/security/NegotiationTest.java index d47755d9..9bde4210 100644 --- a/src/test/java/com/xiaomi/infra/pegasus/security/NegotiationTest.java +++ b/src/test/java/com/xiaomi/infra/pegasus/security/NegotiationTest.java @@ -18,22 +18,95 @@ */ package com.xiaomi.infra.pegasus.security; -import static com.xiaomi.infra.pegasus.apps.negotiation_status.SASL_LIST_MECHANISMS; import static org.mockito.ArgumentMatchers.any; -import com.xiaomi.infra.pegasus.security.Negotiation; +import com.xiaomi.infra.pegasus.apps.negotiation_response; +import com.xiaomi.infra.pegasus.apps.negotiation_status; +import com.xiaomi.infra.pegasus.base.blob; +import java.nio.charset.Charset; +import javax.security.auth.Subject; import org.junit.Assert; import org.junit.Test; +import org.junit.jupiter.api.Assertions; import org.mockito.Mockito; public class NegotiationTest { + private Negotiation negotiation = new Negotiation(null, new Subject(), "", ""); + @Test public void testStart() { - Negotiation negotiation = new Negotiation(null, null, "", ""); Negotiation mockNegotiation = Mockito.spy(negotiation); Mockito.doNothing().when(mockNegotiation).send(any(), any()); mockNegotiation.start(); - Assert.assertEquals(mockNegotiation.get_status(), SASL_LIST_MECHANISMS); + Assert.assertEquals(mockNegotiation.getStatus(), negotiation_status.SASL_LIST_MECHANISMS); + } + + @Test + public void tetGetMatchMechanism() { + String matchMechanism = negotiation.getMatchMechanism("GSSAPI,ABC"); + Assert.assertEquals(matchMechanism, "GSSAPI"); + + matchMechanism = negotiation.getMatchMechanism("TEST,ABC"); + Assert.assertEquals(matchMechanism, ""); + } + + @Test + public void testCheckStatus() { + negotiation_status expectedStatus = negotiation_status.SASL_LIST_MECHANISMS; + + Assertions.assertDoesNotThrow( + () -> negotiation.checkStatus(negotiation_status.SASL_LIST_MECHANISMS, expectedStatus)); + + Assertions.assertThrows( + Exception.class, + () -> + negotiation.checkStatus(negotiation_status.SASL_LIST_MECHANISMS_RESP, expectedStatus)); + } + + @Test + public void testRecvMechanisms() { + Negotiation mockNegotiation = Mockito.spy(negotiation); + SaslWrapper mockSaslWrapper = Mockito.mock(SaslWrapper.class); + mockNegotiation.saslWrapper = mockSaslWrapper; + + Mockito.doNothing().when(mockNegotiation).send(any(), any()); + Assertions.assertDoesNotThrow( + () -> { + Mockito.when(mockNegotiation.saslWrapper.init(any())).thenReturn(new byte[0]); + }); + + // normal case + Assertions.assertDoesNotThrow( + () -> { + negotiation_response response = + new negotiation_response( + negotiation_status.SASL_LIST_MECHANISMS_RESP, + new blob("GSSAPI".getBytes(Charset.defaultCharset()))); + mockNegotiation.onRecvMechanisms(response); + Assert.assertEquals( + mockNegotiation.getStatus(), negotiation_status.SASL_SELECT_MECHANISMS); + }); + + // deal with wrong response.msg + Assertions.assertThrows( + Exception.class, + () -> { + negotiation_response response = + new negotiation_response( + negotiation_status.SASL_LIST_MECHANISMS, + new blob("NOTSUPPORTED".getBytes(Charset.defaultCharset()))); + mockNegotiation.onRecvMechanisms(response); + }); + + // deal with wrong response.status + Assertions.assertThrows( + Exception.class, + () -> { + negotiation_response response = + new negotiation_response( + negotiation_status.SASL_LIST_MECHANISMS, new blob(new byte[0])); + mockNegotiation.onRecvMechanisms(response); + }); } }