Skip to content

Commit

Permalink
CognitoIDP: Ensure MFA functions work with non-python SDK's (#8241)
Browse files Browse the repository at this point in the history
  • Loading branch information
bblommers authored Oct 19, 2024
1 parent eea6b16 commit fe37392
Show file tree
Hide file tree
Showing 6 changed files with 250 additions and 8 deletions.
41 changes: 41 additions & 0 deletions moto/cognitoidp/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2298,6 +2298,20 @@ def _find_backend_by_access_token(self, access_token: str) -> CognitoIdpBackend:
return backend
return cognitoidp_backends[self.account_id][self.region_name]

def _find_backend_by_access_token_or_session(
self, access_token: str, session: str
) -> CognitoIdpBackend:
for account_specific_backends in cognitoidp_backends.values():
for region, backend in account_specific_backends.items():
if region == "global":
continue
if session and session in backend.sessions:
return backend
for p in backend.user_pools.values():
if access_token and access_token in p.access_tokens:
return backend
return cognitoidp_backends[self.account_id][self.region_name]

def _find_backend_for_clientid(self, client_id: str) -> CognitoIdpBackend:
for account_specific_backends in cognitoidp_backends.values():
for region, backend in account_specific_backends.items():
Expand Down Expand Up @@ -2356,6 +2370,33 @@ def respond_to_auth_challenge(
session, client_id, challenge_name, challenge_responses
)

def associate_software_token(
self, access_token: str, session: str
) -> Dict[str, str]:
backend = self._find_backend_by_access_token_or_session(access_token, session)
return backend.associate_software_token(access_token, session)

def verify_software_token(self, access_token: str, session: str) -> Dict[str, str]:
backend = self._find_backend_by_access_token_or_session(access_token, session)
return backend.verify_software_token(access_token, session)

def set_user_mfa_preference(
self,
access_token: str,
software_token_mfa_settings: Dict[str, bool],
sms_mfa_settings: Dict[str, bool],
) -> None:
backend = self._find_backend_by_access_token(access_token)
return backend.set_user_mfa_preference(
access_token, software_token_mfa_settings, sms_mfa_settings
)

def update_user_attributes(
self, access_token: str, attributes: List[Dict[str, str]]
) -> None:
backend = self._find_backend_by_access_token(access_token)
return backend.update_user_attributes(access_token, attributes)


cognitoidp_backends = BackendDict(CognitoIdpBackend, "cognito-idp")

Expand Down
14 changes: 10 additions & 4 deletions moto/cognitoidp/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,20 +624,24 @@ def initiate_auth(self) -> str:
def associate_software_token(self) -> str:
access_token = self._get_param("AccessToken")
session = self._get_param("Session")
result = self.backend.associate_software_token(access_token, session)
result = self._get_region_agnostic_backend().associate_software_token(
access_token, session
)
return json.dumps(result)

def verify_software_token(self) -> str:
access_token = self._get_param("AccessToken")
session = self._get_param("Session")
result = self.backend.verify_software_token(access_token, session)
result = self._get_region_agnostic_backend().verify_software_token(
access_token, session
)
return json.dumps(result)

def set_user_mfa_preference(self) -> str:
access_token = self._get_param("AccessToken")
software_token_mfa_settings = self._get_param("SoftwareTokenMfaSettings")
sms_mfa_settings = self._get_param("SMSMfaSettings")
self.backend.set_user_mfa_preference(
self._get_region_agnostic_backend().set_user_mfa_preference(
access_token, software_token_mfa_settings, sms_mfa_settings
)
return ""
Expand Down Expand Up @@ -671,7 +675,9 @@ def add_custom_attributes(self) -> str:
def update_user_attributes(self) -> str:
access_token = self._get_param("AccessToken")
attributes = self._get_param("UserAttributes")
self.backend.update_user_attributes(access_token, attributes)
self._get_region_agnostic_backend().update_user_attributes(
access_token, attributes
)
return json.dumps({})


Expand Down
10 changes: 9 additions & 1 deletion other_langs/tests_java/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>bom</artifactId>
<version>2.28.13</version>
<version>2.28.26</version>
<type>pom</type>
<scope>import</scope>
</dependency>
Expand All @@ -36,6 +36,14 @@
<version>4.13.2</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>cognitoidentity</artifactId>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>cognitoidentityprovider</artifactId>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>dynamodb</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.http.urlconnection.UrlConnectionHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.cognitoidentityprovider.CognitoIdentityProviderClient;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.ses.SesClient;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
Expand All @@ -22,9 +23,14 @@ public class DependencyFactory {

private DependencyFactory() {}

/**
* @return an instance of S3Client
*/
public static CognitoIdentityProviderClient cognitoIdpClient() {
return CognitoIdentityProviderClient.builder()
.region(Region.US_EAST_1)
.httpClientBuilder(ApacheHttpClient.builder())
.endpointOverride(MOTO_URI)
.build();
}

public static S3Client s3Client() {
return S3Client.builder()
.region(Region.US_EAST_1)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package moto.tests;

import static moto.tests.DependencyFactory.cognitoIdpClient;
import static org.junit.Assert.*;

import org.junit.Test;
import software.amazon.awssdk.services.cognitoidentityprovider.CognitoIdentityProviderClient;
import software.amazon.awssdk.services.cognitoidentityprovider.model.AssociateSoftwareTokenRequest;
import software.amazon.awssdk.services.cognitoidentityprovider.model.AssociateSoftwareTokenResponse;
import software.amazon.awssdk.services.cognitoidentityprovider.model.AttributeType;
import software.amazon.awssdk.services.cognitoidentityprovider.model.AuthFlowType;
import software.amazon.awssdk.services.cognitoidentityprovider.model.CreateUserPoolRequest;
import software.amazon.awssdk.services.cognitoidentityprovider.model.CreateUserPoolClientRequest;
import software.amazon.awssdk.services.cognitoidentityprovider.model.ConfirmSignUpRequest;
import software.amazon.awssdk.services.cognitoidentityprovider.model.InitiateAuthRequest;
import software.amazon.awssdk.services.cognitoidentityprovider.model.InitiateAuthResponse;
import software.amazon.awssdk.services.cognitoidentityprovider.model.SignUpRequest;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;


public class CognitoIDPTest {

@Test
public void testAssociateSoftwareToken() {
CognitoIdentityProviderClient client = cognitoIdpClient();

CreateUserPoolRequest up_request = CreateUserPoolRequest.builder()
.poolName("myfirstuserpool")
.build();

String userPoolId = client.createUserPool(up_request).userPool().id();
System.out.println(userPoolId);

CreateUserPoolClientRequest upc_request = CreateUserPoolClientRequest.builder()
.clientName("myfirstuserpoolclient")
.userPoolId(userPoolId)
.build();

String clientId = client.createUserPoolClient(upc_request).userPoolClient().clientId();
System.out.println(clientId);

// Create User
AttributeType userAttrs = AttributeType.builder()
.name("email")
.value("[email protected]")
.build();
String password = "P@ssw0rdWithTonsOfCharacters";

List<AttributeType> userAttrsList = new ArrayList<>();
userAttrsList.add(userAttrs);
SignUpRequest signUpRequest = SignUpRequest.builder()
.userAttributes(userAttrsList)
.username("myuser")
.clientId(clientId)
.password(password)
.build();

client.signUp(signUpRequest);

ConfirmSignUpRequest confirmSignUpRequest = ConfirmSignUpRequest.builder()
.clientId(clientId)
.confirmationCode("code")
.username("myuser")
.build();

client.confirmSignUp(confirmSignUpRequest);

InitiateAuthResponse initiateAuthResponse = client.initiateAuth(
InitiateAuthRequest.builder()
.authFlow(AuthFlowType.USER_PASSWORD_AUTH)
.authParameters(Map.of(
"USERNAME", "myuser",
"SECRET_HASH", "n/a",
"PASSWORD", password))
.clientId(clientId)
.build());
String accessToken = initiateAuthResponse.authenticationResult().accessToken();

AssociateSoftwareTokenRequest softwareTokenRequest = AssociateSoftwareTokenRequest.builder()
.accessToken(accessToken)
.build();
AssociateSoftwareTokenResponse tokenResponse = client
.associateSoftwareToken(softwareTokenRequest);
String secretCode = tokenResponse.secretCode();

assertNotNull(secretCode);
}
}
91 changes: 91 additions & 0 deletions tests/test_cognitoidp/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,94 @@ def test_admin_create_user_without_authentication():
assert "AuthenticationResult" in response
assert "IdToken" in response["AuthenticationResult"]
assert "AccessToken" in response["AuthenticationResult"]


def test_associate_software_token():
backend = server.create_backend_app("cognito-idp")
test_client = backend.test_client()

# Create User Pool
res = test_client.post(
"/",
data='{"PoolName": "test-pool"}',
headers={
"X-Amz-Target": "AWSCognitoIdentityProviderService.CreateUserPool",
"Authorization": "AWS4-HMAC-SHA256 Credential=abcd/20010101/us-east-2/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=...",
},
)
user_pool_id = json.loads(res.data)["UserPool"]["Id"]

# Create User Pool Client
data = {
"UserPoolId": user_pool_id,
"ClientName": "some-client",
"GenerateSecret": False,
"ExplicitAuthFlows": ["ALLOW_USER_PASSWORD_AUTH"],
}
res = test_client.post(
"/",
data=json.dumps(data),
headers={
"X-Amz-Target": "AWSCognitoIdentityProviderService.CreateUserPoolClient",
"Authorization": "AWS4-HMAC-SHA256 Credential=abcd/20010101/us-east-2/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=...",
},
)
client_id = json.loads(res.data)["UserPoolClient"]["ClientId"]

# Sign Up User
data = {
"ClientId": client_id,
"Username": "user_2_mfa",
"Password": "12312sdfasASDFDSF$",
}
res = test_client.post(
"/",
data=json.dumps(data),
headers={"X-Amz-Target": "AWSCognitoIdentityProviderService.SignUp"},
)
assert res.status_code == 200
assert json.loads(res.data)["UserConfirmed"] is False

# Confirm Sign Up User
data = {"ClientId": client_id, "Username": "user_2_mfa", "ConfirmationCode": "sth"}
res = test_client.post(
"/",
data=json.dumps(data),
headers={"X-Amz-Target": "AWSCognitoIdentityProviderService.ConfirmSignUp"},
)

# Initiate Auth
data = {
"AuthFlow": "USER_PASSWORD_AUTH",
"AuthParameters": {
"USERNAME": "user_2_mfa",
"PASSWORD": "12312sdfasASDFDSF$",
"SECRET_HASH": "kIWuIv6ElVe9ahZHJ+gqvZe6CgEkVE/BjQmJcMSgF3E=",
},
"ClientId": client_id,
}
res = test_client.post(
"/",
data=json.dumps(data),
headers={"X-Amz-Target": "AWSCognitoIdentityProviderService.InitiateAuth"},
)
auth_data = json.loads(res.data.decode("utf-8"))["AuthenticationResult"]

# Get User
data = {"AccessToken": auth_data["AccessToken"]}
res = test_client.post(
"/",
data=json.dumps(data),
headers={"X-Amz-Target": "AWSCognitoIdentityProviderService.GetUser"},
)

# Associate Software Token
data = {"AccessToken": auth_data["AccessToken"]}
res = test_client.post(
"/",
data=json.dumps(data),
headers={
"X-Amz-Target": "AWSCognitoIdentityProviderService.AssociateSoftwareToken"
},
)
assert json.loads(res.data.decode("utf-8")) == {"SecretCode": "asdfasdfasdf"}

0 comments on commit fe37392

Please sign in to comment.