diff --git a/src/java/main/org/apache/zookeeper/LoginFactory.java b/src/java/main/org/apache/zookeeper/LoginFactory.java new file mode 100644 index 00000000000..dac622324d6 --- /dev/null +++ b/src/java/main/org/apache/zookeeper/LoginFactory.java @@ -0,0 +1,27 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zookeeper; + +import org.apache.zookeeper.common.ZKConfig; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.login.LoginException; + +public interface LoginFactory { + Login createLogin(final String loginContextName, CallbackHandler callbackHandler, final ZKConfig zkConfig) throws LoginException; +} diff --git a/src/java/main/org/apache/zookeeper/LoginFactoryImpl.java b/src/java/main/org/apache/zookeeper/LoginFactoryImpl.java new file mode 100644 index 00000000000..12e78a96706 --- /dev/null +++ b/src/java/main/org/apache/zookeeper/LoginFactoryImpl.java @@ -0,0 +1,31 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zookeeper; + +import org.apache.zookeeper.common.ZKConfig; + +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.login.LoginException; + +public class LoginFactoryImpl implements LoginFactory { + @Override + public Login createLogin(String loginContextName, CallbackHandler callbackHandler, ZKConfig zkConfig) throws LoginException { + return new Login(loginContextName, callbackHandler, zkConfig); + } +} diff --git a/src/java/main/org/apache/zookeeper/server/quorum/QuorumPeer.java b/src/java/main/org/apache/zookeeper/server/quorum/QuorumPeer.java index cec1f948901..ae0d3166967 100644 --- a/src/java/main/org/apache/zookeeper/server/quorum/QuorumPeer.java +++ b/src/java/main/org/apache/zookeeper/server/quorum/QuorumPeer.java @@ -45,6 +45,7 @@ import javax.security.sasl.SaslException; import org.apache.zookeeper.KeeperException.BadArgumentsException; +import org.apache.zookeeper.LoginFactoryImpl; import org.apache.zookeeper.common.AtomicFileWritingIdiom; import org.apache.zookeeper.common.AtomicFileWritingIdiom.WriterStatement; import org.apache.zookeeper.common.Time; @@ -833,7 +834,7 @@ public void initialize() throws SaslException { authServer = new SaslQuorumAuthServer(isQuorumServerSaslAuthRequired(), quorumServerLoginContext, authzHosts); authLearner = new SaslQuorumAuthLearner(isQuorumLearnerSaslAuthRequired(), - quorumServicePrincipal, quorumLearnerLoginContext); + quorumServicePrincipal, quorumLearnerLoginContext, new LoginFactoryImpl()); } else { authServer = new NullQuorumAuthServer(); authLearner = new NullQuorumAuthLearner(); diff --git a/src/java/main/org/apache/zookeeper/server/quorum/auth/SaslQuorumAuthLearner.java b/src/java/main/org/apache/zookeeper/server/quorum/auth/SaslQuorumAuthLearner.java index 31f4f55c81f..f9828a134ab 100644 --- a/src/java/main/org/apache/zookeeper/server/quorum/auth/SaslQuorumAuthLearner.java +++ b/src/java/main/org/apache/zookeeper/server/quorum/auth/SaslQuorumAuthLearner.java @@ -36,6 +36,7 @@ import org.apache.jute.BinaryInputArchive; import org.apache.jute.BinaryOutputArchive; import org.apache.zookeeper.Login; +import org.apache.zookeeper.LoginFactory; import org.apache.zookeeper.SaslClientCallbackHandler; import org.apache.zookeeper.common.ZKConfig; import org.apache.zookeeper.server.quorum.QuorumAuthPacket; @@ -52,7 +53,7 @@ public class SaslQuorumAuthLearner implements QuorumAuthLearner { private final String quorumServicePrincipal; public SaslQuorumAuthLearner(boolean quorumRequireSasl, - String quorumServicePrincipal, String loginContext) + String quorumServicePrincipal, String loginContext, LoginFactory loginFactory) throws SaslException { this.quorumRequireSasl = quorumRequireSasl; this.quorumServicePrincipal = quorumServicePrincipal; @@ -66,8 +67,8 @@ public SaslQuorumAuthLearner(boolean quorumRequireSasl, + "section '" + loginContext + "' could not be found."); } - this.learnerLogin = new Login(loginContext, - new SaslClientCallbackHandler(null, "QuorumLearner"), new ZKConfig()); + this.learnerLogin = loginFactory.createLogin(loginContext, + new SaslClientCallbackHandler(null, "QuorumLearner"), new ZKConfig()); this.learnerLogin.startThreadIfNeeded(); } catch (LoginException e) { throw new SaslException("Failed to initialize authentication mechanism using SASL", e); @@ -94,7 +95,10 @@ public void authenticate(Socket sock, String hostName) throws IOException { principalConfig, QuorumAuth.QUORUM_SERVER_PROTOCOL_NAME, QuorumAuth.QUORUM_SERVER_SASL_DIGEST, LOG, "QuorumLearner"); - + if (sc == null) { + LOG.error("SaslClient object is null while trying to create SASL client"); + throw new SaslException("Exception while trying to create SASL client"); + } if (sc.hasInitialResponse()) { responseToken = createSaslToken(new byte[0], sc, learnerLogin); } diff --git a/src/java/test/org/apache/zookeeper/server/quorum/auth/SaslQuorumAuthLearnerTest.java b/src/java/test/org/apache/zookeeper/server/quorum/auth/SaslQuorumAuthLearnerTest.java new file mode 100644 index 00000000000..31ffe952120 --- /dev/null +++ b/src/java/test/org/apache/zookeeper/server/quorum/auth/SaslQuorumAuthLearnerTest.java @@ -0,0 +1,81 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zookeeper.server.quorum.auth; + +import org.apache.zookeeper.Login; +import org.apache.zookeeper.LoginFactory; +import org.apache.zookeeper.common.ZKConfig; +import org.junit.Before; +import org.junit.Test; + +import javax.security.auth.Subject; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.login.AppConfigurationEntry; +import javax.security.auth.login.Configuration; +import javax.security.auth.login.LoginException; +import javax.security.sasl.SaslException; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.net.Socket; +import java.security.Principal; + +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hamcrest.core.Is.is; +import static org.junit.Assert.assertThat; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class SaslQuorumAuthLearnerTest { + + private SaslQuorumAuthLearner learner; + + @Before + public void setUp() throws SaslException, LoginException { + Configuration configMock = mock(Configuration.class); + when(configMock.getAppConfigurationEntry(any(String.class))).thenReturn(new AppConfigurationEntry[1]); + Configuration.setConfiguration(configMock); + //mock object + Login loginMock = mock(Login.class); + Subject subjectMock = new Subject(); + Principal principalMock = mock(Principal.class); + when(principalMock.getName()).thenReturn("hello"); + subjectMock.getPrincipals().add(principalMock); + when(loginMock.getSubject()).thenReturn(subjectMock); + + LoginFactory loginFactoryMock = mock(LoginFactory.class); + when(loginFactoryMock.createLogin(any(String.class), any(CallbackHandler.class), any(ZKConfig.class))).thenReturn(loginMock); + + learner = new SaslQuorumAuthLearner(true, null, "andorContext", loginFactoryMock); + } + + @Test(expected = SaslException.class) + public void testNullCheckSc() throws IOException { + assertThat(learner, is(notNullValue())); + + Socket socketMock = mock(Socket.class); + final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + final ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(new byte[0]); + when(socketMock.getOutputStream()).thenReturn(byteArrayOutputStream); + when(socketMock.getInputStream()).thenReturn(byteArrayInputStream); + + learner.authenticate(socketMock, null); + } +}