Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update for confused deputy parameters #173

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import software.amazon.awssdk.services.sts.endpoints.StsEndpointProvider;
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;
import software.amazon.awssdk.services.sts.model.GetCallerIdentityResponse;
import software.amazon.awssdk.utils.StringUtils;


/**
Expand Down Expand Up @@ -97,7 +98,9 @@ public class MSKCredentialProvider implements AwsCredentialsProvider, AutoClosea
private static final int DEFAULT_MAX_RETRIES = 3;
private static final int DEFAULT_MAX_BACK_OFF_TIME_MS = 5000;
private static final Duration BASE_DELAY = Duration.ofMillis(500);

// Confused deputy headers
private static final String HEADER_X_AMZ_SOURCE_ACCOUNT = "x-amz-source-account";
private static final String HEADER_X_AMZ_SOURCE_ARN = "x-amz-source-arn";
private final List<AutoCloseable> closeableProviders;
private final AwsCredentialsProvider compositeDelegate;
@Getter(AccessLevel.PACKAGE)
Expand Down Expand Up @@ -324,74 +327,63 @@ private Optional<StsAssumeRoleCredentialsProvider> getStsRoleProvider() {
}
String sessionName = Optional.ofNullable((String) optionsMap.get(AWS_ROLE_SESSION_KEY))
.orElse("aws-msk-iam-auth");
String stsRegion = getStsRegion();
String externalId = (String) optionsMap.getOrDefault(AWS_ROLE_EXTERNAL_ID, null);
String sourceAccount = (String) optionsMap.getOrDefault(HEADER_X_AMZ_SOURCE_ACCOUNT, null);
String sourceArn = (String) optionsMap.getOrDefault(HEADER_X_AMZ_SOURCE_ARN, null);
AssumeRoleRequest assumeRoleRequest = getAssumeRoleRequest((String) p, sessionName, externalId, sourceAccount, sourceArn);

String stsRegion = getStsRegion();
String accessKey = (String) optionsMap.getOrDefault(AWS_ROLE_ACCESS_KEY_ID, null);
String secretKey = (String) optionsMap.getOrDefault(AWS_ROLE_SECRET_ACCESS_KEY, null);
String sessionToken = (String) optionsMap.getOrDefault(AWS_ROLE_SESSION_TOKEN, null);
String externalId = (String) optionsMap.getOrDefault(AWS_ROLE_EXTERNAL_ID, null);
if (accessKey != null && secretKey != null) {
AwsCredentialsProvider credentials = StaticCredentialsProvider.create(
sessionToken != null
? AwsSessionCredentials.create(accessKey, secretKey, sessionToken)
: AwsBasicCredentials.create(accessKey, secretKey));
return createSTSRoleCredentialProvider((String) p, sessionName, stsRegion, credentials);
}
else if (externalId != null) {
return createSTSRoleCredentialProvider((String) p, externalId, sessionName, stsRegion);
}

return createSTSRoleCredentialProvider((String) p, sessionName, stsRegion);
return createSTSRoleCredentialProvider(assumeRoleRequest, stsRegion, accessKey, secretKey, sessionToken);
});
}

StsAssumeRoleCredentialsProvider createSTSRoleCredentialProvider(
String roleArn,
String sessionName,
String stsRegion) {
AssumeRoleRequest roleRequest = AssumeRoleRequest.builder()
.roleArn(roleArn)
.roleSessionName(sessionName)
.build();
StsClient stsClient = getStsClientBuilder(Region.of(stsRegion))
.build();
return StsAssumeRoleCredentialsProvider.builder()
.stsClient(stsClient)
.refreshRequest(roleRequest)
.build();
}

StsAssumeRoleCredentialsProvider createSTSRoleCredentialProvider(
String roleArn,
String sessionName, String stsRegion,
AwsCredentialsProvider credentials) {
AssumeRoleRequest roleRequest = AssumeRoleRequest.builder()
.roleArn(roleArn)
.roleSessionName(sessionName)
.build();
StsClient stsClient = getStsClientBuilder(Region.of(stsRegion))
.credentialsProvider(credentials)
.build();
return StsAssumeRoleCredentialsProvider.builder()
.stsClient(stsClient)
.refreshRequest(roleRequest)
.build();
AssumeRoleRequest getAssumeRoleRequest(
String roleArn,
String sessionName,
String externalId,
String sourceAccount,
String sourceArn) {
AssumeRoleRequest.Builder roleRequestBuilder = AssumeRoleRequest.builder()
.roleArn(roleArn)
.roleSessionName(sessionName);
if (!StringUtils.isEmpty(externalId)) {
roleRequestBuilder.externalId(externalId);
}
if (!StringUtils.isEmpty(sourceAccount) || !StringUtils.isEmpty(sourceArn)) {
roleRequestBuilder.overrideConfiguration(c -> {
if (!StringUtils.isEmpty(sourceAccount)) {
c.putHeader(HEADER_X_AMZ_SOURCE_ACCOUNT, sourceAccount);
}
if (!StringUtils.isEmpty(sourceArn)) {
c.putHeader(HEADER_X_AMZ_SOURCE_ARN, sourceArn);
}
});
}
return roleRequestBuilder.build();
}

StsAssumeRoleCredentialsProvider createSTSRoleCredentialProvider(
String roleArn,
String externalId,
String sessionName,
String stsRegion) {
AssumeRoleRequest roleRequest = AssumeRoleRequest.builder()
.externalId(externalId)
.roleArn(roleArn)
.roleSessionName(sessionName)
.build();
AssumeRoleRequest assumeRoleRequest,
String stsRegion,
String accessKey,
String secretKey,
String sessionToken) {
StsClientBuilder stsClientBuilder = getStsClientBuilder(Region.of(stsRegion));
if (accessKey != null && secretKey != null) {
AwsCredentialsProvider credentials = StaticCredentialsProvider.create(
sessionToken != null
? AwsSessionCredentials.create(accessKey, secretKey, sessionToken)
: AwsBasicCredentials.create(accessKey, secretKey));
stsClientBuilder.credentialsProvider(credentials);
}
return StsAssumeRoleCredentialsProvider.builder()
.stsClient(getStsClientBuilder(Region.of(stsRegion)).build())
.refreshRequest(roleRequest)
.build();
.stsClient(stsClientBuilder.build())
.refreshRequest(assumeRoleRequest)
.build();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ public class MSKCredentialProviderTest {
private static final String AWS_ROLE_SECRET_ACCESS_KEY = "awsRoleSecretAccessKey";
private static final String AWS_PROFILE_NAME = "awsProfileName";
private static final String AWS_DEBUG_CREDS_NAME = "awsDebugCreds";
private static final String HEADER_X_AMZ_SOURCE_ACCOUNT = "x-amz-source-account";
private static final String HEADER_X_AMZ_SOURCE_ARN = "x-amz-source-arn";

/**
* If no options are passed in it should use the default credentials provider
Expand Down Expand Up @@ -281,6 +283,30 @@ StsClient getStsClientForDebuggingCreds(AwsCredentials credentials) {
Mockito.verifyNoMoreInteractions(mockEcsCredsProvider);
}

@Test
public void testWithConfusedDeputyParameters() {
StsAssumeRoleCredentialsProvider mockStsRoleProvider = Mockito
.mock(StsAssumeRoleCredentialsProvider.class);
Mockito.when(mockStsRoleProvider.resolveIdentity())
.thenAnswer(i -> CompletableFuture.completedFuture(AwsSessionCredentials.create(ACCESS_KEY_VALUE, SECRET_KEY_VALUE, SESSION_TOKEN)));

Map<String, String> optionsMap = new HashMap<>();
optionsMap.put(AWS_ROLE_ARN, TEST_ROLE_ARN);
optionsMap.put(HEADER_X_AMZ_SOURCE_ACCOUNT, "ACCT");
optionsMap.put(HEADER_X_AMZ_SOURCE_ARN, "ARN");

MSKCredentialProvider.ProviderBuilder providerBuilder = getProviderBuilder(mockStsRoleProvider, optionsMap,
"aws-msk-iam-auth");
MSKCredentialProvider provider = new MSKCredentialProvider(providerBuilder);
assertFalse(provider.getShouldDebugCreds());

AwsCredentials credentials = provider.resolveCredentials();
validateBasicSessionCredentials(credentials);

provider.close();
Mockito.verify(mockStsRoleProvider, times(1)).close();
}

@Test
public void testEc2CredsWithDebugCredsNoAccessToSts_Succeed() {
Map<String, String> optionsMap = new HashMap<>();
Expand Down