Skip to content

Commit

Permalink
CognitoIDP: Fix Limit and Filter params for list_users() (#8392)
Browse files Browse the repository at this point in the history
  • Loading branch information
changchaishi authored Dec 12, 2024
1 parent ad7de35 commit 8896093
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 49 deletions.
48 changes: 46 additions & 2 deletions moto/cognitoidp/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1354,10 +1354,54 @@ def get_user(self, access_token: str) -> CognitoIdpUser:
raise NotAuthorizedError("Invalid token")

@paginate(pagination_model=PAGINATION_MODEL)
def list_users(self, user_pool_id: str) -> List[CognitoIdpUser]:
def list_users(self, user_pool_id: str, filt: str) -> List[CognitoIdpUser]:
user_pool = self.describe_user_pool(user_pool_id)
users = list(user_pool.users.values())
if filt:
inherent_attributes: Dict[str, Any] = {
"cognito:user_status": lambda u: u.status,
"status": lambda u: "Enabled" if u.enabled else "Disabled",
"username": lambda u: u.username,
}
comparisons: Dict[str, Any] = {
"=": lambda x, y: x == y,
"^=": lambda x, y: x.startswith(y),
}
allowed_attributes = [
"username",
"email",
"phone_number",
"name",
"given_name",
"family_name",
"preferred_username",
"cognito:user_status",
"status",
"sub",
]

return list(user_pool.users.values())
match = re.match(r"([\w:]+)\s*(=|\^=)\s*\"(.*)\"", filt)
if match:
name, op, value = match.groups()
else:
raise InvalidParameterException("Error while parsing filter")
if name not in allowed_attributes:
raise InvalidParameterException(f"Invalid search attribute: {name}")
compare = comparisons[op]
users = [
user
for user in users
if [
attr
for attr in user.attributes
if attr["Name"] == name and compare(attr["Value"], value)
]
or (
name in inherent_attributes
and compare(inherent_attributes[name](user), value)
)
]
return users

def admin_disable_user(self, user_pool_id: str, username: str) -> None:
user = self.admin_get_user(user_pool_id, username)
Expand Down
49 changes: 2 additions & 47 deletions moto/cognitoidp/responses.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
import re
from typing import Any, Dict

from moto.core.responses import TYPE_RESPONSE, BaseResponse
Expand Down Expand Up @@ -371,52 +370,8 @@ def list_users(self) -> str:
filt = self._get_param("Filter")
attributes_to_get = self._get_param("AttributesToGet")
users, token = self.backend.list_users(
user_pool_id, limit=limit, pagination_token=token
)
if filt:
inherent_attributes: Dict[str, Any] = {
"cognito:user_status": lambda u: u.status,
"status": lambda u: "Enabled" if u.enabled else "Disabled",
"username": lambda u: u.username,
}
comparisons: Dict[str, Any] = {
"=": lambda x, y: x == y,
"^=": lambda x, y: x.startswith(y),
}
allowed_attributes = [
"username",
"email",
"phone_number",
"name",
"given_name",
"family_name",
"preferred_username",
"cognito:user_status",
"status",
"sub",
]

match = re.match(r"([\w:]+)\s*(=|\^=)\s*\"(.*)\"", filt)
if match:
name, op, value = match.groups()
else:
raise InvalidParameterException("Error while parsing filter")
if name not in allowed_attributes:
raise InvalidParameterException(f"Invalid search attribute: {name}")
compare = comparisons[op]
users = [
user
for user in users
if [
attr
for attr in user.attributes
if attr["Name"] == name and compare(attr["Value"], value)
]
or (
name in inherent_attributes
and compare(inherent_attributes[name](user), value)
)
]
user_pool_id, filt, limit=limit, pagination_token=token
)
response: Dict[str, Any] = {
"Users": [
user.to_json(extended=True, attributes_to_get=attributes_to_get)
Expand Down
22 changes: 22 additions & 0 deletions tests/test_cognitoidp/test_cognitoidp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2519,6 +2519,28 @@ def test_list_users():
result = conn.list_users(UserPoolId=user_pool_id, Filter='family_name=""')
assert len(result["Users"]) == 0

# checking Limit and Filter work correctly together
user1_username = "[email protected]"
conn.admin_create_user(
UserPoolId=user_pool_id,
Username=user1_username,
UserAttributes=[{"Name": "phone_number", "Value": "+48555555555"}],
)

result = conn.list_users(
UserPoolId=user_pool_id, Filter='phone_number ^= "+48"', Limit=1
)
assert len(result["Users"]) == 1
assert result["PaginationToken"] is not None

result = conn.list_users(
UserPoolId=user_pool_id,
Filter='phone_number ^= "+48"',
Limit=1,
PaginationToken=result["PaginationToken"],
)
assert len(result["Users"]) == 1


@mock_aws
def test_list_users_incorrect_filter():
Expand Down

0 comments on commit 8896093

Please sign in to comment.