Skip to content

Commit

Permalink
[aws#36] Adding a proxy to delegate to appropriate IAMSaslClientProvi…
Browse files Browse the repository at this point in the history
…der based on ClassLoader
  • Loading branch information
dannycranmer committed Oct 14, 2022
1 parent 5e2cb95 commit f3069d3
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
.gradle
bin
internals
build
lombok.config
out/
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package software.amazon.msk.auth.iam;

import software.amazon.msk.auth.iam.internals.ClassLoaderAwareIAMSaslClientProvider;
import software.amazon.msk.auth.iam.internals.IAMSaslClientProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -35,6 +36,7 @@ public class IAMLoginModule implements LoginModule {
private static final Logger log = LoggerFactory.getLogger(IAMLoginModule.class);

static {
ClassLoaderAwareIAMSaslClientProvider.initialize();
IAMSaslClientProvider.initialize();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License").
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package software.amazon.msk.auth.iam.internals;

import software.amazon.msk.auth.iam.IAMLoginModule;
import software.amazon.msk.auth.iam.internals.IAMSaslClient.ClassLoaderAwareIAMSaslClientFactory;

import java.security.Provider;
import java.security.Security;

public class ClassLoaderAwareIAMSaslClientProvider extends Provider {
/**
* Constructs an IAM Sasl Client provider that installs a {@link ClassLoaderAwareIAMSaslClientFactory}.
*/
protected ClassLoaderAwareIAMSaslClientProvider() {
super("ClassLoader Aware SASL/IAM Client Provider", 1.0, "SASL/IAM Client Provider for Kafka");
put("SaslClientFactory." + IAMLoginModule.MECHANISM, ClassLoaderAwareIAMSaslClientFactory.class.getName());
}

public static void initialize() {
Security.addProvider(new ClassLoaderAwareIAMSaslClientProvider());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslClientFactory;
import javax.security.sasl.SaslException;
Expand Down Expand Up @@ -203,28 +204,58 @@ private static boolean isChallengeEmpty(byte[] challenge) {
return true;
}

public static class ClassLoaderAwareIAMSaslClientFactory implements SaslClientFactory {

@Override
public SaslClient createSaslClient(String[] mechanisms,
String authorizationId,
String protocol,
String serverName,
Map<String, ?> props,
CallbackHandler cbh) throws SaslException {
String mechanismName = getMechanismNameForClassLoader(cbh.getClass().getClassLoader());

// Create a client by delegating to the SaslClientFactory for the classloader of the CallbackHandler
return Sasl.createSaslClient(
new String[] { mechanismName },
authorizationId, protocol, serverName, props, cbh);
}

@Override
public String[] getMechanismNames(Map<String, ?> props) {
return new String[] { IAMLoginModule.MECHANISM };
}
}

public static class IAMSaslClientFactory implements SaslClientFactory {

@Override
public SaslClient createSaslClient(String[] mechanisms,
String authorizationId,
String protocol,
String serverName,
Map<String, ?> props,
CallbackHandler cbh) throws SaslException {
String mechanismName = getMechanismNameForClassLoader(getClass().getClassLoader());

for (String mechanism : mechanisms) {
if (IAMLoginModule.MECHANISM.equals(mechanism)) {
if (mechanismName.equals(mechanism)) {
return new IAMSaslClient(mechanism, cbh, serverName, new AWS4SignedPayloadGenerator());
}
}

throw new SaslException(
"Requested mechanisms " + Arrays.asList(mechanisms) + " not supported. The supported" +
"mechanism is " + IAMLoginModule.MECHANISM);
"Requested mechanisms " + Arrays.asList(mechanisms) + " not supported. " +
"The supported mechanism is " + mechanismName);
}

@Override
public String[] getMechanismNames(Map<String, ?> props) {
return new String[] { IAMLoginModule.MECHANISM };
return new String[] { getMechanismNameForClassLoader(getClass().getClassLoader()) };
}
}

public static String getMechanismNameForClassLoader(ClassLoader classLoader) {
return IAMLoginModule.MECHANISM + "." + classLoader.hashCode();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,23 @@
*/
package software.amazon.msk.auth.iam.internals;

import software.amazon.msk.auth.iam.IAMLoginModule;
import software.amazon.msk.auth.iam.internals.IAMSaslClient.IAMSaslClientFactory;

import java.security.Provider;
import java.security.Security;

import static software.amazon.msk.auth.iam.internals.IAMSaslClient.getMechanismNameForClassLoader;

public class IAMSaslClientProvider extends Provider {
/**
* Constructs a IAM Sasl Client provider with a fixed name, version number,
* and information.
*/
protected IAMSaslClientProvider() {
super("SASL/IAM Client Provider", 1.0, "SASL/IAM Client Provider for Kafka");
put("SaslClientFactory." + IAMLoginModule.MECHANISM, IAMSaslClient.IAMSaslClientFactory.class.getName());
super("SASL/IAM Client Provider (" +
IAMSaslClientProvider.class.getClassLoader().hashCode(), 1.0,
") SASL/IAM Client Provider for Kafka");
put("SaslClientFactory." + getMechanismNameForClassLoader(getClass().getClassLoader()), IAMSaslClientFactory.class.getName());
}

public static void initialize() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
package software.amazon.msk.auth.iam.internals;

import com.amazonaws.auth.BasicAWSCredentials;
import org.junit.jupiter.api.BeforeEach;
import software.amazon.msk.auth.iam.IAMClientCallbackHandler;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.commons.lang3.RandomStringUtils;
import org.apache.kafka.common.errors.IllegalSaslStateException;
import org.junit.jupiter.api.Test;
import software.amazon.msk.auth.iam.internals.IAMSaslClient.ClassLoaderAwareIAMSaslClientFactory;

import static java.util.Collections.emptyMap;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
Expand All @@ -47,14 +50,19 @@ public class IAMSaslClientTest {

private static final BasicAWSCredentials BASIC_AWS_CREDENTIALS = new BasicAWSCredentials(ACCESS_KEY_VALUE, SECRET_KEY_VALUE);

@BeforeEach
public void setUp() {
IAMSaslClientProvider.initialize();
}

@Test
public void testCompleteValidExchange() throws IOException, ParseException {
public void testCompleteValidExchange() throws IOException {
IAMSaslClient saslClient = getSuccessfulIAMClient(getIamClientCallbackHandler());
runValidExchangeForSaslClient(saslClient, ACCESS_KEY_VALUE, SECRET_KEY_VALUE);
}

private void runValidExchangeForSaslClient(IAMSaslClient saslClient, String accessKey, String secretKey) {
assertEquals(AWS_MSK_IAM, saslClient.getMechanismName());
assertEquals(getMechanismName(), saslClient.getMechanismName());
assertTrue(saslClient.hasInitialResponse());
SystemPropertyCredentialsUtils.runTestWithSystemPropertyCredentials(() -> {
try {
Expand Down Expand Up @@ -97,7 +105,7 @@ public void testMultipleSaslClients() throws IOException, ParseException {

private IAMClientCallbackHandler getIamClientCallbackHandler() {
IAMClientCallbackHandler cbh = new IAMClientCallbackHandler();
cbh.configure(Collections.EMPTY_MAP, AWS_MSK_IAM, Collections.emptyList());
cbh.configure(emptyMap(), AWS_MSK_IAM, Collections.emptyList());
return cbh;
}

Expand Down Expand Up @@ -127,11 +135,11 @@ public void testThrowingCallback() throws SaslException {
@Test
public void testInvalidServerResponse() throws SaslException {
SaslClient saslClient = getSuccessfulIAMClient(getIamClientCallbackHandler());
assertEquals(AWS_MSK_IAM, saslClient.getMechanismName());
assertEquals(getMechanismName(), saslClient.getMechanismName());
assertTrue(saslClient.hasInitialResponse());
SystemPropertyCredentialsUtils.runTestWithSystemPropertyCredentials(() -> {
try {
byte[] response = saslClient.evaluateChallenge(new byte[]{});
saslClient.evaluateChallenge(new byte[]{});
} catch (SaslException e) {
throw new RuntimeException("Test failed", e);
}
Expand All @@ -149,7 +157,7 @@ public void testInvalidResponseVersion() throws SaslException {
SaslClient saslClient = getSuccessfulIAMClient(getIamClientCallbackHandler());
SystemPropertyCredentialsUtils.runTestWithSystemPropertyCredentials(() -> {
try {
byte[] response = saslClient.evaluateChallenge(new byte[]{});
saslClient.evaluateChallenge(new byte[]{});
} catch (SaslException e) {
throw new RuntimeException("Test failed", e);
}
Expand All @@ -174,11 +182,11 @@ private byte[] getResponseWithInvalidVersion() {
@Test
public void testEmptyServerResponse() throws SaslException {
SaslClient saslClient = getSuccessfulIAMClient(getIamClientCallbackHandler());
assertEquals(AWS_MSK_IAM, saslClient.getMechanismName());
assertEquals(getMechanismName(), saslClient.getMechanismName());
assertTrue(saslClient.hasInitialResponse());
SystemPropertyCredentialsUtils.runTestWithSystemPropertyCredentials(() -> {
try {
byte[] response = saslClient.evaluateChallenge(new byte[]{});
saslClient.evaluateChallenge(new byte[]{});
} catch (SaslException e) {
throw new RuntimeException("Test failed", e);
}
Expand All @@ -191,19 +199,25 @@ public void testEmptyServerResponse() throws SaslException {

@Test
public void testFactoryMechanisms() {
assertArrayEquals(new String[]{AWS_MSK_IAM},
new IAMSaslClient.IAMSaslClientFactory().getMechanismNames(Collections.emptyMap()));
assertArrayEquals(new String[] { getMechanismName() },
new IAMSaslClient.IAMSaslClientFactory().getMechanismNames(emptyMap()));
}

@Test
public void testInvalidMechanism() {

assertThrows(SaslException.class, () -> new IAMSaslClient.IAMSaslClientFactory()
.createSaslClient(new String[]{AWS_MSK_IAM + "BAD"}, "AUTH_ID", "PROTOCOL", VALID_HOSTNAME,
Collections.emptyMap(),
emptyMap(),
new SuccessfulIAMCallbackHandler(BASIC_AWS_CREDENTIALS)));
}

@Test
public void testClassLoaderAwareIAMSaslClientFactoryMechanisms() {
assertArrayEquals(new String[] { AWS_MSK_IAM },
new ClassLoaderAwareIAMSaslClientFactory().getMechanismNames(emptyMap()));
}

private static class SuccessfulIAMCallbackHandler extends IAMClientCallbackHandler {
private final BasicAWSCredentials basicAWSCredentials;

Expand Down Expand Up @@ -240,10 +254,14 @@ protected void handleCallback(AWSCredentialsCallback callback) throws IOExceptio
}

private IAMSaslClient getIAMClient(Supplier<IAMClientCallbackHandler> handlerSupplier) throws SaslException {
return (IAMSaslClient )new IAMSaslClient.IAMSaslClientFactory()
.createSaslClient(new String[]{AWS_MSK_IAM}, "AUTH_ID", "PROTOCOL", VALID_HOSTNAME,
Collections.emptyMap(),
return (IAMSaslClient) new IAMSaslClient.ClassLoaderAwareIAMSaslClientFactory()
.createSaslClient(new String[] { AWS_MSK_IAM }, "AUTH_ID", "PROTOCOL", VALID_HOSTNAME,
emptyMap(),
handlerSupplier.get());
}

private String getMechanismName() {
return AWS_MSK_IAM + "." + getClass().getClassLoader().hashCode();
}

}

0 comments on commit f3069d3

Please sign in to comment.