From 1c63835887d08a064a9cdf360f68a75345b9f9ca Mon Sep 17 00:00:00 2001 From: Purvin Patel Date: Wed, 29 May 2024 11:30:01 -0700 Subject: [PATCH] update for confused deputy parameters --- .../iam/internals/MSKCredentialProvider.java | 106 ++++++++---------- .../internals/MSKCredentialProviderTest.java | 26 +++++ 2 files changed, 75 insertions(+), 57 deletions(-) diff --git a/src/main/java/software/amazon/msk/auth/iam/internals/MSKCredentialProvider.java b/src/main/java/software/amazon/msk/auth/iam/internals/MSKCredentialProvider.java index 3945400..a7b5fb1 100644 --- a/src/main/java/software/amazon/msk/auth/iam/internals/MSKCredentialProvider.java +++ b/src/main/java/software/amazon/msk/auth/iam/internals/MSKCredentialProvider.java @@ -60,6 +60,7 @@ import software.amazon.awssdk.services.sts.endpoints.StsEndpointProvider; import software.amazon.awssdk.services.sts.model.AssumeRoleRequest; import software.amazon.awssdk.services.sts.model.GetCallerIdentityResponse; +import software.amazon.awssdk.utils.StringUtils; /** @@ -97,7 +98,9 @@ public class MSKCredentialProvider implements AwsCredentialsProvider, AutoClosea private static final int DEFAULT_MAX_RETRIES = 3; private static final int DEFAULT_MAX_BACK_OFF_TIME_MS = 5000; private static final Duration BASE_DELAY = Duration.ofMillis(500); - + // Confused deputy headers + private static final String HEADER_X_AMZ_SOURCE_ACCOUNT = "x-amz-source-account"; + private static final String HEADER_X_AMZ_SOURCE_ARN = "x-amz-source-arn"; private final List closeableProviders; private final AwsCredentialsProvider compositeDelegate; @Getter(AccessLevel.PACKAGE) @@ -324,74 +327,63 @@ private Optional getStsRoleProvider() { } String sessionName = Optional.ofNullable((String) optionsMap.get(AWS_ROLE_SESSION_KEY)) .orElse("aws-msk-iam-auth"); - String stsRegion = getStsRegion(); + String externalId = (String) optionsMap.getOrDefault(AWS_ROLE_EXTERNAL_ID, null); + String sourceAccount = (String) optionsMap.getOrDefault(HEADER_X_AMZ_SOURCE_ACCOUNT, null); + String sourceArn = (String) optionsMap.getOrDefault(HEADER_X_AMZ_SOURCE_ARN, null); + AssumeRoleRequest assumeRoleRequest = getAssumeRoleRequest((String) p, sessionName, externalId, sourceAccount, sourceArn); + String stsRegion = getStsRegion(); String accessKey = (String) optionsMap.getOrDefault(AWS_ROLE_ACCESS_KEY_ID, null); String secretKey = (String) optionsMap.getOrDefault(AWS_ROLE_SECRET_ACCESS_KEY, null); String sessionToken = (String) optionsMap.getOrDefault(AWS_ROLE_SESSION_TOKEN, null); - String externalId = (String) optionsMap.getOrDefault(AWS_ROLE_EXTERNAL_ID, null); - if (accessKey != null && secretKey != null) { - AwsCredentialsProvider credentials = StaticCredentialsProvider.create( - sessionToken != null - ? AwsSessionCredentials.create(accessKey, secretKey, sessionToken) - : AwsBasicCredentials.create(accessKey, secretKey)); - return createSTSRoleCredentialProvider((String) p, sessionName, stsRegion, credentials); - } - else if (externalId != null) { - return createSTSRoleCredentialProvider((String) p, externalId, sessionName, stsRegion); - } - return createSTSRoleCredentialProvider((String) p, sessionName, stsRegion); + return createSTSRoleCredentialProvider(assumeRoleRequest, stsRegion, accessKey, secretKey, sessionToken); }); } - StsAssumeRoleCredentialsProvider createSTSRoleCredentialProvider( - String roleArn, - String sessionName, - String stsRegion) { - AssumeRoleRequest roleRequest = AssumeRoleRequest.builder() - .roleArn(roleArn) - .roleSessionName(sessionName) - .build(); - StsClient stsClient = getStsClientBuilder(Region.of(stsRegion)) - .build(); - return StsAssumeRoleCredentialsProvider.builder() - .stsClient(stsClient) - .refreshRequest(roleRequest) - .build(); - } - - StsAssumeRoleCredentialsProvider createSTSRoleCredentialProvider( - String roleArn, - String sessionName, String stsRegion, - AwsCredentialsProvider credentials) { - AssumeRoleRequest roleRequest = AssumeRoleRequest.builder() - .roleArn(roleArn) - .roleSessionName(sessionName) - .build(); - StsClient stsClient = getStsClientBuilder(Region.of(stsRegion)) - .credentialsProvider(credentials) - .build(); - return StsAssumeRoleCredentialsProvider.builder() - .stsClient(stsClient) - .refreshRequest(roleRequest) - .build(); + AssumeRoleRequest getAssumeRoleRequest( + String roleArn, + String sessionName, + String externalId, + String sourceAccount, + String sourceArn) { + AssumeRoleRequest.Builder roleRequestBuilder = AssumeRoleRequest.builder() + .roleArn(roleArn) + .roleSessionName(sessionName); + if (!StringUtils.isEmpty(externalId)) { + roleRequestBuilder.externalId(externalId); + } + if (!StringUtils.isEmpty(sourceAccount) || !StringUtils.isEmpty(sourceArn)) { + roleRequestBuilder.overrideConfiguration(c -> { + if (!StringUtils.isEmpty(sourceAccount)) { + c.putHeader(HEADER_X_AMZ_SOURCE_ACCOUNT, sourceAccount); + } + if (!StringUtils.isEmpty(sourceArn)) { + c.putHeader(HEADER_X_AMZ_SOURCE_ARN, sourceArn); + } + }); + } + return roleRequestBuilder.build(); } StsAssumeRoleCredentialsProvider createSTSRoleCredentialProvider( - String roleArn, - String externalId, - String sessionName, - String stsRegion) { - AssumeRoleRequest roleRequest = AssumeRoleRequest.builder() - .externalId(externalId) - .roleArn(roleArn) - .roleSessionName(sessionName) - .build(); + AssumeRoleRequest assumeRoleRequest, + String stsRegion, + String accessKey, + String secretKey, + String sessionToken) { + StsClientBuilder stsClientBuilder = getStsClientBuilder(Region.of(stsRegion)); + if (accessKey != null && secretKey != null) { + AwsCredentialsProvider credentials = StaticCredentialsProvider.create( + sessionToken != null + ? AwsSessionCredentials.create(accessKey, secretKey, sessionToken) + : AwsBasicCredentials.create(accessKey, secretKey)); + stsClientBuilder.credentialsProvider(credentials); + } return StsAssumeRoleCredentialsProvider.builder() - .stsClient(getStsClientBuilder(Region.of(stsRegion)).build()) - .refreshRequest(roleRequest) - .build(); + .stsClient(stsClientBuilder.build()) + .refreshRequest(assumeRoleRequest) + .build(); } } } diff --git a/src/test/java/software/amazon/msk/auth/iam/internals/MSKCredentialProviderTest.java b/src/test/java/software/amazon/msk/auth/iam/internals/MSKCredentialProviderTest.java index 728f6ac..b8b2bc4 100644 --- a/src/test/java/software/amazon/msk/auth/iam/internals/MSKCredentialProviderTest.java +++ b/src/test/java/software/amazon/msk/auth/iam/internals/MSKCredentialProviderTest.java @@ -74,6 +74,8 @@ public class MSKCredentialProviderTest { private static final String AWS_ROLE_SECRET_ACCESS_KEY = "awsRoleSecretAccessKey"; private static final String AWS_PROFILE_NAME = "awsProfileName"; private static final String AWS_DEBUG_CREDS_NAME = "awsDebugCreds"; + private static final String HEADER_X_AMZ_SOURCE_ACCOUNT = "x-amz-source-account"; + private static final String HEADER_X_AMZ_SOURCE_ARN = "x-amz-source-arn"; /** * If no options are passed in it should use the default credentials provider @@ -281,6 +283,30 @@ StsClient getStsClientForDebuggingCreds(AwsCredentials credentials) { Mockito.verifyNoMoreInteractions(mockEcsCredsProvider); } + @Test + public void testWithConfusedDeputyParameters() { + StsAssumeRoleCredentialsProvider mockStsRoleProvider = Mockito + .mock(StsAssumeRoleCredentialsProvider.class); + Mockito.when(mockStsRoleProvider.resolveIdentity()) + .thenAnswer(i -> CompletableFuture.completedFuture(AwsSessionCredentials.create(ACCESS_KEY_VALUE, SECRET_KEY_VALUE, SESSION_TOKEN))); + + Map optionsMap = new HashMap<>(); + optionsMap.put(AWS_ROLE_ARN, TEST_ROLE_ARN); + optionsMap.put(HEADER_X_AMZ_SOURCE_ACCOUNT, "ACCT"); + optionsMap.put(HEADER_X_AMZ_SOURCE_ARN, "ARN"); + + MSKCredentialProvider.ProviderBuilder providerBuilder = getProviderBuilder(mockStsRoleProvider, optionsMap, + "aws-msk-iam-auth"); + MSKCredentialProvider provider = new MSKCredentialProvider(providerBuilder); + assertFalse(provider.getShouldDebugCreds()); + + AwsCredentials credentials = provider.resolveCredentials(); + validateBasicSessionCredentials(credentials); + + provider.close(); + Mockito.verify(mockStsRoleProvider, times(1)).close(); + } + @Test public void testEc2CredsWithDebugCredsNoAccessToSts_Succeed() { Map optionsMap = new HashMap<>();