diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/HBaseSaslRpcClient.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/HBaseSaslRpcClient.java index 94e7b7edb508..03af94ddad96 100644 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/HBaseSaslRpcClient.java +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/HBaseSaslRpcClient.java @@ -148,6 +148,16 @@ public boolean saslConnect(InputStream inS, OutputStream outS) throws IOExceptio inStream.readFully(saslToken); } } + + try { + readStatus(inStream); + } + catch (IOException e){ + if(e instanceof RemoteException){ + LOG.debug("Sasl connection failed: ", e); + throw e; + } + } if (LOG.isDebugEnabled()) { LOG.debug("SASL client context established. Negotiated QoP: " + saslClient.getNegotiatedProperty(Sasl.QOP)); diff --git a/hbase-client/src/test/java/org/apache/hadoop/hbase/security/TestHBaseSaslRpcClient.java b/hbase-client/src/test/java/org/apache/hadoop/hbase/security/TestHBaseSaslRpcClient.java index 9fc510c365a1..2252c215fa68 100644 --- a/hbase-client/src/test/java/org/apache/hadoop/hbase/security/TestHBaseSaslRpcClient.java +++ b/hbase-client/src/test/java/org/apache/hadoop/hbase/security/TestHBaseSaslRpcClient.java @@ -52,10 +52,12 @@ import org.apache.hadoop.hbase.util.Bytes; import org.apache.hadoop.io.DataInputBuffer; import org.apache.hadoop.io.DataOutputBuffer; +import org.apache.hadoop.io.WritableUtils; import org.apache.hadoop.security.token.Token; import org.apache.hadoop.security.token.TokenIdentifier; import org.apache.log4j.Level; import org.apache.log4j.Logger; +import org.junit.Assert; import org.junit.BeforeClass; import org.junit.ClassRule; import org.junit.Rule; @@ -318,4 +320,33 @@ private HBaseSaslRpcClient createSaslRpcClientSimple(String principal, String pa private Token<? extends TokenIdentifier> createTokenMock() { return mock(Token.class); } + + @Test(expected = IOException.class) + public void testFailedEvaluateResponse() throws IOException { + //prep mockin the SaslClient + SimpleSaslClientAuthenticationProvider mockProvider = + Mockito.mock(SimpleSaslClientAuthenticationProvider.class); + SaslClient mockClient = Mockito.mock(SaslClient.class); + Assert.assertNotNull(mockProvider); + Assert.assertNotNull(mockClient); + Mockito.when(mockProvider.createClient(Mockito.any(), Mockito.any(), Mockito.any(), + Mockito.any(), Mockito.anyBoolean(), Mockito.any())).thenReturn(mockClient); + HBaseSaslRpcClient rpcClient = new HBaseSaslRpcClient(HBaseConfiguration.create(), + mockProvider, createTokenMock(), + Mockito.mock(InetAddress.class), Mockito.mock(SecurityInfo.class), false); + + //simulate getting an error from a failed saslServer.evaluateResponse + DataOutputBuffer errorBuffer = new DataOutputBuffer(); + errorBuffer.writeInt(SaslStatus.ERROR.state); + WritableUtils.writeString(errorBuffer, IOException.class.getName()); + WritableUtils.writeString(errorBuffer, "Invalid Token"); + + DataInputBuffer in = new DataInputBuffer(); + in.reset(errorBuffer.getData(), 0, errorBuffer.getLength()); + DataOutputBuffer out = new DataOutputBuffer(); + + //simulate that authentication exchange has completed quickly after sending the token + Mockito.when(mockClient.isComplete()).thenReturn(true); + rpcClient.saslConnect(in, out); + } }