Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CognitoIDP: Ensure MFA functions work with non-python SDK's #8241

Merged
merged 1 commit into from
Oct 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions moto/cognitoidp/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2298,6 +2298,20 @@
return backend
return cognitoidp_backends[self.account_id][self.region_name]

def _find_backend_by_access_token_or_session(
self, access_token: str, session: str
) -> CognitoIdpBackend:
for account_specific_backends in cognitoidp_backends.values():
for region, backend in account_specific_backends.items():
if region == "global":
continue

Check warning on line 2307 in moto/cognitoidp/models.py

View check run for this annotation

Codecov / codecov/patch

moto/cognitoidp/models.py#L2307

Added line #L2307 was not covered by tests
if session and session in backend.sessions:
return backend
for p in backend.user_pools.values():
if access_token and access_token in p.access_tokens:
return backend
return cognitoidp_backends[self.account_id][self.region_name]

Check warning on line 2313 in moto/cognitoidp/models.py

View check run for this annotation

Codecov / codecov/patch

moto/cognitoidp/models.py#L2313

Added line #L2313 was not covered by tests

def _find_backend_for_clientid(self, client_id: str) -> CognitoIdpBackend:
for account_specific_backends in cognitoidp_backends.values():
for region, backend in account_specific_backends.items():
Expand Down Expand Up @@ -2356,6 +2370,33 @@
session, client_id, challenge_name, challenge_responses
)

def associate_software_token(
self, access_token: str, session: str
) -> Dict[str, str]:
backend = self._find_backend_by_access_token_or_session(access_token, session)
return backend.associate_software_token(access_token, session)

def verify_software_token(self, access_token: str, session: str) -> Dict[str, str]:
backend = self._find_backend_by_access_token_or_session(access_token, session)
return backend.verify_software_token(access_token, session)

def set_user_mfa_preference(
self,
access_token: str,
software_token_mfa_settings: Dict[str, bool],
sms_mfa_settings: Dict[str, bool],
) -> None:
backend = self._find_backend_by_access_token(access_token)
return backend.set_user_mfa_preference(
access_token, software_token_mfa_settings, sms_mfa_settings
)

def update_user_attributes(
self, access_token: str, attributes: List[Dict[str, str]]
) -> None:
backend = self._find_backend_by_access_token(access_token)
return backend.update_user_attributes(access_token, attributes)


cognitoidp_backends = BackendDict(CognitoIdpBackend, "cognito-idp")

Expand Down
14 changes: 10 additions & 4 deletions moto/cognitoidp/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,20 +624,24 @@ def initiate_auth(self) -> str:
def associate_software_token(self) -> str:
access_token = self._get_param("AccessToken")
session = self._get_param("Session")
result = self.backend.associate_software_token(access_token, session)
result = self._get_region_agnostic_backend().associate_software_token(
access_token, session
)
return json.dumps(result)

def verify_software_token(self) -> str:
access_token = self._get_param("AccessToken")
session = self._get_param("Session")
result = self.backend.verify_software_token(access_token, session)
result = self._get_region_agnostic_backend().verify_software_token(
access_token, session
)
return json.dumps(result)

def set_user_mfa_preference(self) -> str:
access_token = self._get_param("AccessToken")
software_token_mfa_settings = self._get_param("SoftwareTokenMfaSettings")
sms_mfa_settings = self._get_param("SMSMfaSettings")
self.backend.set_user_mfa_preference(
self._get_region_agnostic_backend().set_user_mfa_preference(
access_token, software_token_mfa_settings, sms_mfa_settings
)
return ""
Expand Down Expand Up @@ -671,7 +675,9 @@ def add_custom_attributes(self) -> str:
def update_user_attributes(self) -> str:
access_token = self._get_param("AccessToken")
attributes = self._get_param("UserAttributes")
self.backend.update_user_attributes(access_token, attributes)
self._get_region_agnostic_backend().update_user_attributes(
access_token, attributes
)
return json.dumps({})


Expand Down
10 changes: 9 additions & 1 deletion other_langs/tests_java/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>bom</artifactId>
<version>2.28.13</version>
<version>2.28.26</version>
<type>pom</type>
<scope>import</scope>
</dependency>
Expand All @@ -36,6 +36,14 @@
<version>4.13.2</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>cognitoidentity</artifactId>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>cognitoidentityprovider</artifactId>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>dynamodb</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.http.urlconnection.UrlConnectionHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.cognitoidentityprovider.CognitoIdentityProviderClient;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.ses.SesClient;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
Expand All @@ -22,9 +23,14 @@ public class DependencyFactory {

private DependencyFactory() {}

/**
* @return an instance of S3Client
*/
public static CognitoIdentityProviderClient cognitoIdpClient() {
return CognitoIdentityProviderClient.builder()
.region(Region.US_EAST_1)
.httpClientBuilder(ApacheHttpClient.builder())
.endpointOverride(MOTO_URI)
.build();
}

public static S3Client s3Client() {
return S3Client.builder()
.region(Region.US_EAST_1)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package moto.tests;

import static moto.tests.DependencyFactory.cognitoIdpClient;
import static org.junit.Assert.*;

import org.junit.Test;
import software.amazon.awssdk.services.cognitoidentityprovider.CognitoIdentityProviderClient;
import software.amazon.awssdk.services.cognitoidentityprovider.model.AssociateSoftwareTokenRequest;
import software.amazon.awssdk.services.cognitoidentityprovider.model.AssociateSoftwareTokenResponse;
import software.amazon.awssdk.services.cognitoidentityprovider.model.AttributeType;
import software.amazon.awssdk.services.cognitoidentityprovider.model.AuthFlowType;
import software.amazon.awssdk.services.cognitoidentityprovider.model.CreateUserPoolRequest;
import software.amazon.awssdk.services.cognitoidentityprovider.model.CreateUserPoolClientRequest;
import software.amazon.awssdk.services.cognitoidentityprovider.model.ConfirmSignUpRequest;
import software.amazon.awssdk.services.cognitoidentityprovider.model.InitiateAuthRequest;
import software.amazon.awssdk.services.cognitoidentityprovider.model.InitiateAuthResponse;
import software.amazon.awssdk.services.cognitoidentityprovider.model.SignUpRequest;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;


public class CognitoIDPTest {

@Test
public void testAssociateSoftwareToken() {
CognitoIdentityProviderClient client = cognitoIdpClient();

CreateUserPoolRequest up_request = CreateUserPoolRequest.builder()
.poolName("myfirstuserpool")
.build();

String userPoolId = client.createUserPool(up_request).userPool().id();
System.out.println(userPoolId);

CreateUserPoolClientRequest upc_request = CreateUserPoolClientRequest.builder()
.clientName("myfirstuserpoolclient")
.userPoolId(userPoolId)
.build();

String clientId = client.createUserPoolClient(upc_request).userPoolClient().clientId();
System.out.println(clientId);

// Create User
AttributeType userAttrs = AttributeType.builder()
.name("email")
.value("[email protected]")
.build();
String password = "P@ssw0rdWithTonsOfCharacters";

List<AttributeType> userAttrsList = new ArrayList<>();
userAttrsList.add(userAttrs);
SignUpRequest signUpRequest = SignUpRequest.builder()
.userAttributes(userAttrsList)
.username("myuser")
.clientId(clientId)
.password(password)
.build();

client.signUp(signUpRequest);

ConfirmSignUpRequest confirmSignUpRequest = ConfirmSignUpRequest.builder()
.clientId(clientId)
.confirmationCode("code")
.username("myuser")
.build();

client.confirmSignUp(confirmSignUpRequest);

InitiateAuthResponse initiateAuthResponse = client.initiateAuth(
InitiateAuthRequest.builder()
.authFlow(AuthFlowType.USER_PASSWORD_AUTH)
.authParameters(Map.of(
"USERNAME", "myuser",
"SECRET_HASH", "n/a",
"PASSWORD", password))
.clientId(clientId)
.build());
String accessToken = initiateAuthResponse.authenticationResult().accessToken();

AssociateSoftwareTokenRequest softwareTokenRequest = AssociateSoftwareTokenRequest.builder()
.accessToken(accessToken)
.build();
AssociateSoftwareTokenResponse tokenResponse = client
.associateSoftwareToken(softwareTokenRequest);
String secretCode = tokenResponse.secretCode();

assertNotNull(secretCode);
}
}
91 changes: 91 additions & 0 deletions tests/test_cognitoidp/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,94 @@ def test_admin_create_user_without_authentication():
assert "AuthenticationResult" in response
assert "IdToken" in response["AuthenticationResult"]
assert "AccessToken" in response["AuthenticationResult"]


def test_associate_software_token():
backend = server.create_backend_app("cognito-idp")
test_client = backend.test_client()

# Create User Pool
res = test_client.post(
"/",
data='{"PoolName": "test-pool"}',
headers={
"X-Amz-Target": "AWSCognitoIdentityProviderService.CreateUserPool",
"Authorization": "AWS4-HMAC-SHA256 Credential=abcd/20010101/us-east-2/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=...",
},
)
user_pool_id = json.loads(res.data)["UserPool"]["Id"]

# Create User Pool Client
data = {
"UserPoolId": user_pool_id,
"ClientName": "some-client",
"GenerateSecret": False,
"ExplicitAuthFlows": ["ALLOW_USER_PASSWORD_AUTH"],
}
res = test_client.post(
"/",
data=json.dumps(data),
headers={
"X-Amz-Target": "AWSCognitoIdentityProviderService.CreateUserPoolClient",
"Authorization": "AWS4-HMAC-SHA256 Credential=abcd/20010101/us-east-2/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=...",
},
)
client_id = json.loads(res.data)["UserPoolClient"]["ClientId"]

# Sign Up User
data = {
"ClientId": client_id,
"Username": "user_2_mfa",
"Password": "12312sdfasASDFDSF$",
}
res = test_client.post(
"/",
data=json.dumps(data),
headers={"X-Amz-Target": "AWSCognitoIdentityProviderService.SignUp"},
)
assert res.status_code == 200
assert json.loads(res.data)["UserConfirmed"] is False

# Confirm Sign Up User
data = {"ClientId": client_id, "Username": "user_2_mfa", "ConfirmationCode": "sth"}
res = test_client.post(
"/",
data=json.dumps(data),
headers={"X-Amz-Target": "AWSCognitoIdentityProviderService.ConfirmSignUp"},
)

# Initiate Auth
data = {
"AuthFlow": "USER_PASSWORD_AUTH",
"AuthParameters": {
"USERNAME": "user_2_mfa",
"PASSWORD": "12312sdfasASDFDSF$",
"SECRET_HASH": "kIWuIv6ElVe9ahZHJ+gqvZe6CgEkVE/BjQmJcMSgF3E=",
},
"ClientId": client_id,
}
res = test_client.post(
"/",
data=json.dumps(data),
headers={"X-Amz-Target": "AWSCognitoIdentityProviderService.InitiateAuth"},
)
auth_data = json.loads(res.data.decode("utf-8"))["AuthenticationResult"]

# Get User
data = {"AccessToken": auth_data["AccessToken"]}
res = test_client.post(
"/",
data=json.dumps(data),
headers={"X-Amz-Target": "AWSCognitoIdentityProviderService.GetUser"},
)

# Associate Software Token
data = {"AccessToken": auth_data["AccessToken"]}
res = test_client.post(
"/",
data=json.dumps(data),
headers={
"X-Amz-Target": "AWSCognitoIdentityProviderService.AssociateSoftwareToken"
},
)
assert json.loads(res.data.decode("utf-8")) == {"SecretCode": "asdfasdfasdf"}
Loading