Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

validate idempotence token length in subsequent tasks #2604

Merged
merged 3 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,41 @@ def sorted_dict_str(d):
}


def get_nested_value(d: Dict[str, Any], keys: list[str]) -> Any:
"""
Retrieve the nested value from a dictionary based on a list of keys.
"""
for key in keys:
if key not in d:
raise ValueError(f"Could not find the key {key} in {d}.")
d = d[key]
return d


def replace_placeholder(
service: str,
original_dict: str,
placeholder: str,
replacement: str,
) -> str:
"""
Replace a placeholder in the original string and handle the specific logic for the sagemaker service and idempotence token.
"""
temp_dict = original_dict.replace(f"{{{placeholder}}}", replacement)
if service == "sagemaker" and placeholder in [
"inputs.idempotence_token",
"idempotence_token",
]:
if len(temp_dict) > 63:
truncated_token = replacement[: 63 - len(original_dict.replace(f"{{{placeholder}}}", ""))]
return original_dict.replace(f"{{{placeholder}}}", truncated_token)
else:
return temp_dict
return temp_dict


def update_dict_fn(
service: str,
original_dict: Any,
update_dict: Dict[str, Any],
idempotence_token: Optional[str] = None,
Expand All @@ -63,6 +97,7 @@ def update_dict_fn(
and update_dict is {"endpoint_config_name": "my-endpoint-config"},
then the result will be {"EndpointConfigName": "my-endpoint-config"}.

:param service: The AWS service to use
:param original_dict: The dictionary to update (in place)
:param update_dict: The dictionary to use for updating
:param idempotence_token: Hash of config -- this is to ensure the execution ID is deterministic
Expand All @@ -71,55 +106,27 @@ def update_dict_fn(
if original_dict is None:
return None

# If the original value is a string and contains placeholder curly braces
if isinstance(original_dict, str):
if "{" in original_dict and "}" in original_dict:
matches = re.findall(r"\{([^}]+)\}", original_dict)
for match in matches:
# Check if there are nested keys
if "." in match:
# Create a copy of update_dict
update_dict_copy = update_dict.copy()

# Fetch keys from the original_dict
keys = match.split(".")

# Get value from the nested dictionary
for key in keys:
try:
update_dict_copy = update_dict_copy[key]
except Exception:
raise ValueError(f"Could not find the key {key} in {update_dict_copy}.")

if f"{{{match}}}" == original_dict:
# If there's only one match, it needn't always be a string, so not replacing the original dict.
return update_dict_copy
else:
# Replace the placeholder in the original_dict
original_dict = original_dict.replace(f"{{{match}}}", update_dict_copy)
elif match == "idempotence_token" and idempotence_token:
temp_dict = original_dict.replace(f"{{{match}}}", idempotence_token)
if len(temp_dict) > 63:
truncated_idempotence_token = idempotence_token[
: (63 - len(original_dict.replace("{idempotence_token}", "")))
]
original_dict = original_dict.replace(f"{{{match}}}", truncated_idempotence_token)
else:
original_dict = temp_dict

# If the string does not contain placeholders or if there are multiple placeholders, return the original dict.
if isinstance(original_dict, str) and "{" in original_dict and "}" in original_dict:
matches = re.findall(r"\{([^}]+)\}", original_dict)
for match in matches:
if "." in match:
keys = match.split(".")
nested_value = get_nested_value(update_dict, keys)
if f"{{{match}}}" == original_dict:
return nested_value
else:
original_dict = replace_placeholder(service, original_dict, match, nested_value)
elif match == "idempotence_token" and idempotence_token:
original_dict = replace_placeholder(service, original_dict, match, idempotence_token)
return original_dict

# If the original value is a list, recursively update each element in the list
if isinstance(original_dict, list):
return [update_dict_fn(item, update_dict, idempotence_token) for item in original_dict]
return [update_dict_fn(service, item, update_dict, idempotence_token) for item in original_dict]

# If the original value is a dictionary, recursively update each key-value pair
if isinstance(original_dict, dict):
for key, value in original_dict.items():
original_dict[key] = update_dict_fn(value, update_dict, idempotence_token)
original_dict[key] = update_dict_fn(service, value, update_dict, idempotence_token)

# Return the updated original dict
return original_dict


Expand Down Expand Up @@ -192,13 +199,13 @@ async def _call(
}
args["images"] = images

updated_config = update_dict_fn(config, args)
updated_config = update_dict_fn(self._service, config, args)

hash = ""
if "idempotence_token" in str(updated_config):
# compute hash of the config
hash = xxhash.xxh64(sorted_dict_str(updated_config)).hexdigest()
updated_config = update_dict_fn(updated_config, args, idempotence_token=hash)
updated_config = update_dict_fn(self._service, updated_config, args, idempotence_token=hash)

# Asynchronous Boto3 session
session = aioboto3.Session()
Expand Down
131 changes: 129 additions & 2 deletions plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def test_inputs():
)

result = update_dict_fn(
service="s3",
original_dict=original_dict,
update_dict={"inputs": literal_map_string_repr(inputs)},
)
Expand All @@ -74,14 +75,16 @@ def test_container():
original_dict = {"a": "{images.primary_container_image}"}
images = {"primary_container_image": "cr.flyte.org/flyteorg/flytekit:py3.11-1.10.3"}

result = update_dict_fn(original_dict=original_dict, update_dict={"images": images})
result = update_dict_fn(
service="sagemaker", original_dict=original_dict, update_dict={"images": images}
)

assert result == {"a": "cr.flyte.org/flyteorg/flytekit:py3.11-1.10.3"}


@pytest.mark.asyncio
@patch("flytekitplugins.awssagemaker_inference.boto3_mixin.aioboto3.Session")
async def test_call(mock_session):
async def test_call_with_no_idempotence_token(mock_session):
mixin = Boto3AgentMixin(service="sagemaker")

mock_client = AsyncMock()
Expand Down Expand Up @@ -118,3 +121,127 @@ async def test_call(mock_session):

assert result == mock_method.return_value
assert idempotence_token == ""


@pytest.mark.asyncio
@patch("flytekitplugins.awssagemaker_inference.boto3_mixin.aioboto3.Session")
async def test_call_with_idempotence_token(mock_session):
mixin = Boto3AgentMixin(service="sagemaker")

mock_client = AsyncMock()
mock_session.return_value.client.return_value.__aenter__.return_value = mock_client
mock_method = mock_client.create_model

config = {
"ModelName": "{inputs.model_name}-{idempotence_token}",
"PrimaryContainer": {
"Image": "{images.image}",
"ModelDataUrl": "s3://sagemaker-agent-xgboost/model.tar.gz",
},
}
inputs = TypeEngine.dict_to_literal_map(
FlyteContext.current_context(),
{"model_name": "xgboost", "region": "us-west-2"},
{"model_name": str, "region": str},
)

result, idempotence_token = await mixin._call(
method="create_model",
config=config,
inputs=inputs,
images={"image": triton_image_uri(version="21.08")},
)

mock_method.assert_called_with(
ModelName="xgboost-23dba5d7c5aa79a8",
PrimaryContainer={
"Image": "301217895009.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tritonserver:21.08-py3",
"ModelDataUrl": "s3://sagemaker-agent-xgboost/model.tar.gz",
},
)

assert result == mock_method.return_value
assert idempotence_token == "23dba5d7c5aa79a8"


@pytest.mark.asyncio
@patch("flytekitplugins.awssagemaker_inference.boto3_mixin.aioboto3.Session")
async def test_call_with_truncated_idempotence_token(mock_session):
mixin = Boto3AgentMixin(service="sagemaker")

mock_client = AsyncMock()
mock_session.return_value.client.return_value.__aenter__.return_value = mock_client
mock_method = mock_client.create_model

config = {
"ModelName": "{inputs.model_name}-{idempotence_token}",
"PrimaryContainer": {
"Image": "{images.image}",
"ModelDataUrl": "s3://sagemaker-agent-xgboost/model.tar.gz",
},
}
inputs = TypeEngine.dict_to_literal_map(
FlyteContext.current_context(),
{
"model_name": "xgboost-random-string-1234567890123456789012345678", # length=50
"region": "us-west-2",
},
{"model_name": str, "region": str},
)

result, idempotence_token = await mixin._call(
method="create_model",
config=config,
inputs=inputs,
images={"image": triton_image_uri(version="21.08")},
)

mock_method.assert_called_with(
ModelName="xgboost-random-string-1234567890123456789012345678-432aa64034f3",
PrimaryContainer={
"Image": "301217895009.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tritonserver:21.08-py3",
"ModelDataUrl": "s3://sagemaker-agent-xgboost/model.tar.gz",
},
)

assert result == mock_method.return_value
assert idempotence_token == "432aa64034f37edb"


@pytest.mark.asyncio
@patch("flytekitplugins.awssagemaker_inference.boto3_mixin.aioboto3.Session")
async def test_call_with_truncated_idempotence_token_as_input(mock_session):
mixin = Boto3AgentMixin(service="sagemaker")

mock_client = AsyncMock()
mock_session.return_value.client.return_value.__aenter__.return_value = mock_client
mock_method = mock_client.create_endpoint

config = {
"EndpointName": "{inputs.endpoint_name}-{idempotence_token}",
"EndpointConfigName": "{inputs.endpoint_config_name}-{inputs.idempotence_token}",
}
inputs = TypeEngine.dict_to_literal_map(
FlyteContext.current_context(),
{
"endpoint_name": "xgboost",
"endpoint_config_name": "xgboost-random-string-1234567890123456789012345678", # length=50
"idempotence_token": "432aa64034f37edb",
"region": "us-west-2",
},
{"model_name": str, "region": str},
)

result, idempotence_token = await mixin._call(
method="create_endpoint",
config=config,
inputs=inputs,
)

mock_method.assert_called_with(
EndpointName="xgboost-ce735d6a183643f1",
EndpointConfigName="xgboost-random-string-1234567890123456789012345678-432aa64034f3",
)

assert result == mock_method.return_value
assert idempotence_token == "ce735d6a183643f1"
Loading