diff --git a/.gitignore b/.gitignore index 1118716..cb2897b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ .gradle bin -internals build lombok.config +out/ diff --git a/src/main/java/software/amazon/msk/auth/iam/IAMLoginModule.java b/src/main/java/software/amazon/msk/auth/iam/IAMLoginModule.java index f4b1ab3..9e725d1 100644 --- a/src/main/java/software/amazon/msk/auth/iam/IAMLoginModule.java +++ b/src/main/java/software/amazon/msk/auth/iam/IAMLoginModule.java @@ -15,6 +15,7 @@ */ package software.amazon.msk.auth.iam; +import software.amazon.msk.auth.iam.internals.ClassLoaderAwareIAMSaslClientProvider; import software.amazon.msk.auth.iam.internals.IAMSaslClientProvider; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -35,6 +36,7 @@ public class IAMLoginModule implements LoginModule { private static final Logger log = LoggerFactory.getLogger(IAMLoginModule.class); static { + ClassLoaderAwareIAMSaslClientProvider.initialize(); IAMSaslClientProvider.initialize(); } diff --git a/src/main/java/software/amazon/msk/auth/iam/internals/ClassLoaderAwareIAMSaslClientProvider.java b/src/main/java/software/amazon/msk/auth/iam/internals/ClassLoaderAwareIAMSaslClientProvider.java new file mode 100644 index 0000000..712fc70 --- /dev/null +++ b/src/main/java/software/amazon/msk/auth/iam/internals/ClassLoaderAwareIAMSaslClientProvider.java @@ -0,0 +1,36 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +package software.amazon.msk.auth.iam.internals; + +import software.amazon.msk.auth.iam.IAMLoginModule; +import software.amazon.msk.auth.iam.internals.IAMSaslClient.ClassLoaderAwareIAMSaslClientFactory; + +import java.security.Provider; +import java.security.Security; + +public class ClassLoaderAwareIAMSaslClientProvider extends Provider { + /** + * Constructs an IAM Sasl Client provider that installs a {@link ClassLoaderAwareIAMSaslClientFactory}. + */ + protected ClassLoaderAwareIAMSaslClientProvider() { + super("ClassLoader Aware SASL/IAM Client Provider", 1.0, "SASL/IAM Client Provider for Kafka"); + put("SaslClientFactory." + IAMLoginModule.MECHANISM, ClassLoaderAwareIAMSaslClientFactory.class.getName()); + } + + public static void initialize() { + Security.addProvider(new ClassLoaderAwareIAMSaslClientProvider()); + } +} diff --git a/src/main/java/software/amazon/msk/auth/iam/internals/IAMSaslClient.java b/src/main/java/software/amazon/msk/auth/iam/internals/IAMSaslClient.java index 1c865cd..9b4794a 100644 --- a/src/main/java/software/amazon/msk/auth/iam/internals/IAMSaslClient.java +++ b/src/main/java/software/amazon/msk/auth/iam/internals/IAMSaslClient.java @@ -26,6 +26,7 @@ import javax.security.auth.callback.Callback; import javax.security.auth.callback.CallbackHandler; import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.sasl.Sasl; import javax.security.sasl.SaslClient; import javax.security.sasl.SaslClientFactory; import javax.security.sasl.SaslException; @@ -203,7 +204,31 @@ private static boolean isChallengeEmpty(byte[] challenge) { return true; } + public static class ClassLoaderAwareIAMSaslClientFactory implements SaslClientFactory { + + @Override + public SaslClient createSaslClient(String[] mechanisms, + String authorizationId, + String protocol, + String serverName, + Map props, + CallbackHandler cbh) throws SaslException { + String mechanismName = getMechanismNameForClassLoader(cbh.getClass().getClassLoader()); + + // Create a client by delegating to the SaslClientFactory for the classloader of the CallbackHandler + return Sasl.createSaslClient( + new String[] { mechanismName }, + authorizationId, protocol, serverName, props, cbh); + } + + @Override + public String[] getMechanismNames(Map props) { + return new String[] { IAMLoginModule.MECHANISM }; + } + } + public static class IAMSaslClientFactory implements SaslClientFactory { + @Override public SaslClient createSaslClient(String[] mechanisms, String authorizationId, @@ -211,20 +236,26 @@ public SaslClient createSaslClient(String[] mechanisms, String serverName, Map props, CallbackHandler cbh) throws SaslException { + String mechanismName = getMechanismNameForClassLoader(getClass().getClassLoader()); + for (String mechanism : mechanisms) { - if (IAMLoginModule.MECHANISM.equals(mechanism)) { + if (mechanismName.equals(mechanism)) { return new IAMSaslClient(mechanism, cbh, serverName, new AWS4SignedPayloadGenerator()); } } + throw new SaslException( - "Requested mechanisms " + Arrays.asList(mechanisms) + " not supported. The supported" + - "mechanism is " + IAMLoginModule.MECHANISM); + "Requested mechanisms " + Arrays.asList(mechanisms) + " not supported. " + + "The supported mechanism is " + mechanismName); } @Override public String[] getMechanismNames(Map props) { - return new String[] { IAMLoginModule.MECHANISM }; + return new String[] { getMechanismNameForClassLoader(getClass().getClassLoader()) }; } } + public static String getMechanismNameForClassLoader(ClassLoader classLoader) { + return IAMLoginModule.MECHANISM + "." + classLoader.hashCode(); + } } diff --git a/src/main/java/software/amazon/msk/auth/iam/internals/IAMSaslClientProvider.java b/src/main/java/software/amazon/msk/auth/iam/internals/IAMSaslClientProvider.java index df32691..15535fb 100644 --- a/src/main/java/software/amazon/msk/auth/iam/internals/IAMSaslClientProvider.java +++ b/src/main/java/software/amazon/msk/auth/iam/internals/IAMSaslClientProvider.java @@ -15,19 +15,23 @@ */ package software.amazon.msk.auth.iam.internals; -import software.amazon.msk.auth.iam.IAMLoginModule; +import software.amazon.msk.auth.iam.internals.IAMSaslClient.IAMSaslClientFactory; import java.security.Provider; import java.security.Security; +import static software.amazon.msk.auth.iam.internals.IAMSaslClient.getMechanismNameForClassLoader; + public class IAMSaslClientProvider extends Provider { /** * Constructs a IAM Sasl Client provider with a fixed name, version number, * and information. */ protected IAMSaslClientProvider() { - super("SASL/IAM Client Provider", 1.0, "SASL/IAM Client Provider for Kafka"); - put("SaslClientFactory." + IAMLoginModule.MECHANISM, IAMSaslClient.IAMSaslClientFactory.class.getName()); + super("SASL/IAM Client Provider (" + + IAMSaslClientProvider.class.getClassLoader().hashCode(), 1.0, + ") SASL/IAM Client Provider for Kafka"); + put("SaslClientFactory." + getMechanismNameForClassLoader(getClass().getClassLoader()), IAMSaslClientFactory.class.getName()); } public static void initialize() { diff --git a/src/test/java/software/amazon/msk/auth/iam/internals/IAMSaslClientTest.java b/src/test/java/software/amazon/msk/auth/iam/internals/IAMSaslClientTest.java index 3dd82b5..578e841 100644 --- a/src/test/java/software/amazon/msk/auth/iam/internals/IAMSaslClientTest.java +++ b/src/test/java/software/amazon/msk/auth/iam/internals/IAMSaslClientTest.java @@ -16,13 +16,16 @@ package software.amazon.msk.auth.iam.internals; import com.amazonaws.auth.BasicAWSCredentials; +import org.junit.jupiter.api.BeforeEach; import software.amazon.msk.auth.iam.IAMClientCallbackHandler; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.commons.lang3.RandomStringUtils; import org.apache.kafka.common.errors.IllegalSaslStateException; import org.junit.jupiter.api.Test; +import software.amazon.msk.auth.iam.internals.IAMSaslClient.ClassLoaderAwareIAMSaslClientFactory; +import static java.util.Collections.emptyMap; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -47,14 +50,19 @@ public class IAMSaslClientTest { private static final BasicAWSCredentials BASIC_AWS_CREDENTIALS = new BasicAWSCredentials(ACCESS_KEY_VALUE, SECRET_KEY_VALUE); + @BeforeEach + public void setUp() { + IAMSaslClientProvider.initialize(); + } + @Test - public void testCompleteValidExchange() throws IOException, ParseException { + public void testCompleteValidExchange() throws IOException { IAMSaslClient saslClient = getSuccessfulIAMClient(getIamClientCallbackHandler()); runValidExchangeForSaslClient(saslClient, ACCESS_KEY_VALUE, SECRET_KEY_VALUE); } private void runValidExchangeForSaslClient(IAMSaslClient saslClient, String accessKey, String secretKey) { - assertEquals(AWS_MSK_IAM, saslClient.getMechanismName()); + assertEquals(getMechanismName(), saslClient.getMechanismName()); assertTrue(saslClient.hasInitialResponse()); SystemPropertyCredentialsUtils.runTestWithSystemPropertyCredentials(() -> { try { @@ -97,7 +105,7 @@ public void testMultipleSaslClients() throws IOException, ParseException { private IAMClientCallbackHandler getIamClientCallbackHandler() { IAMClientCallbackHandler cbh = new IAMClientCallbackHandler(); - cbh.configure(Collections.EMPTY_MAP, AWS_MSK_IAM, Collections.emptyList()); + cbh.configure(emptyMap(), AWS_MSK_IAM, Collections.emptyList()); return cbh; } @@ -127,11 +135,11 @@ public void testThrowingCallback() throws SaslException { @Test public void testInvalidServerResponse() throws SaslException { SaslClient saslClient = getSuccessfulIAMClient(getIamClientCallbackHandler()); - assertEquals(AWS_MSK_IAM, saslClient.getMechanismName()); + assertEquals(getMechanismName(), saslClient.getMechanismName()); assertTrue(saslClient.hasInitialResponse()); SystemPropertyCredentialsUtils.runTestWithSystemPropertyCredentials(() -> { try { - byte[] response = saslClient.evaluateChallenge(new byte[]{}); + saslClient.evaluateChallenge(new byte[]{}); } catch (SaslException e) { throw new RuntimeException("Test failed", e); } @@ -149,7 +157,7 @@ public void testInvalidResponseVersion() throws SaslException { SaslClient saslClient = getSuccessfulIAMClient(getIamClientCallbackHandler()); SystemPropertyCredentialsUtils.runTestWithSystemPropertyCredentials(() -> { try { - byte[] response = saslClient.evaluateChallenge(new byte[]{}); + saslClient.evaluateChallenge(new byte[]{}); } catch (SaslException e) { throw new RuntimeException("Test failed", e); } @@ -174,11 +182,11 @@ private byte[] getResponseWithInvalidVersion() { @Test public void testEmptyServerResponse() throws SaslException { SaslClient saslClient = getSuccessfulIAMClient(getIamClientCallbackHandler()); - assertEquals(AWS_MSK_IAM, saslClient.getMechanismName()); + assertEquals(getMechanismName(), saslClient.getMechanismName()); assertTrue(saslClient.hasInitialResponse()); SystemPropertyCredentialsUtils.runTestWithSystemPropertyCredentials(() -> { try { - byte[] response = saslClient.evaluateChallenge(new byte[]{}); + saslClient.evaluateChallenge(new byte[]{}); } catch (SaslException e) { throw new RuntimeException("Test failed", e); } @@ -191,8 +199,8 @@ public void testEmptyServerResponse() throws SaslException { @Test public void testFactoryMechanisms() { - assertArrayEquals(new String[]{AWS_MSK_IAM}, - new IAMSaslClient.IAMSaslClientFactory().getMechanismNames(Collections.emptyMap())); + assertArrayEquals(new String[] { getMechanismName() }, + new IAMSaslClient.IAMSaslClientFactory().getMechanismNames(emptyMap())); } @Test @@ -200,10 +208,16 @@ public void testInvalidMechanism() { assertThrows(SaslException.class, () -> new IAMSaslClient.IAMSaslClientFactory() .createSaslClient(new String[]{AWS_MSK_IAM + "BAD"}, "AUTH_ID", "PROTOCOL", VALID_HOSTNAME, - Collections.emptyMap(), + emptyMap(), new SuccessfulIAMCallbackHandler(BASIC_AWS_CREDENTIALS))); } + @Test + public void testClassLoaderAwareIAMSaslClientFactoryMechanisms() { + assertArrayEquals(new String[] { AWS_MSK_IAM }, + new ClassLoaderAwareIAMSaslClientFactory().getMechanismNames(emptyMap())); + } + private static class SuccessfulIAMCallbackHandler extends IAMClientCallbackHandler { private final BasicAWSCredentials basicAWSCredentials; @@ -240,10 +254,14 @@ protected void handleCallback(AWSCredentialsCallback callback) throws IOExceptio } private IAMSaslClient getIAMClient(Supplier handlerSupplier) throws SaslException { - return (IAMSaslClient )new IAMSaslClient.IAMSaslClientFactory() - .createSaslClient(new String[]{AWS_MSK_IAM}, "AUTH_ID", "PROTOCOL", VALID_HOSTNAME, - Collections.emptyMap(), + return (IAMSaslClient) new IAMSaslClient.ClassLoaderAwareIAMSaslClientFactory() + .createSaslClient(new String[] { AWS_MSK_IAM }, "AUTH_ID", "PROTOCOL", VALID_HOSTNAME, + emptyMap(), handlerSupplier.get()); } + private String getMechanismName() { + return AWS_MSK_IAM + "." + getClass().getClassLoader().hashCode(); + } + }