Skip to content

Commit

Permalink
validate idempotence token length in subsequent tasks (flyteorg#2604)
Browse files Browse the repository at this point in the history
* validate idempotence token length in subsequent tasks

Signed-off-by: Samhita Alla <[email protected]>

* remove redundant param

Signed-off-by: Samhita Alla <[email protected]>

* add tests

Signed-off-by: Samhita Alla <[email protected]>

---------

Signed-off-by: Samhita Alla <[email protected]>
  • Loading branch information
samhita-alla authored and mao3267 committed Aug 1, 2024
1 parent c775fac commit 5259b5c
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,41 @@ def sorted_dict_str(d):
}


def get_nested_value(d: Dict[str, Any], keys: list[str]) -> Any:
"""
Retrieve the nested value from a dictionary based on a list of keys.
"""
for key in keys:
if key not in d:
raise ValueError(f"Could not find the key {key} in {d}.")
d = d[key]
return d


def replace_placeholder(
service: str,
original_dict: str,
placeholder: str,
replacement: str,
) -> str:
"""
Replace a placeholder in the original string and handle the specific logic for the sagemaker service and idempotence token.
"""
temp_dict = original_dict.replace(f"{{{placeholder}}}", replacement)
if service == "sagemaker" and placeholder in [
"inputs.idempotence_token",
"idempotence_token",
]:
if len(temp_dict) > 63:
truncated_token = replacement[: 63 - len(original_dict.replace(f"{{{placeholder}}}", ""))]
return original_dict.replace(f"{{{placeholder}}}", truncated_token)
else:
return temp_dict
return temp_dict


def update_dict_fn(
service: str,
original_dict: Any,
update_dict: Dict[str, Any],
idempotence_token: Optional[str] = None,
Expand All @@ -63,6 +97,7 @@ def update_dict_fn(
and update_dict is {"endpoint_config_name": "my-endpoint-config"},
then the result will be {"EndpointConfigName": "my-endpoint-config"}.
:param service: The AWS service to use
:param original_dict: The dictionary to update (in place)
:param update_dict: The dictionary to use for updating
:param idempotence_token: Hash of config -- this is to ensure the execution ID is deterministic
Expand All @@ -71,55 +106,27 @@ def update_dict_fn(
if original_dict is None:
return None

# If the original value is a string and contains placeholder curly braces
if isinstance(original_dict, str):
if "{" in original_dict and "}" in original_dict:
matches = re.findall(r"\{([^}]+)\}", original_dict)
for match in matches:
# Check if there are nested keys
if "." in match:
# Create a copy of update_dict
update_dict_copy = update_dict.copy()

# Fetch keys from the original_dict
keys = match.split(".")

# Get value from the nested dictionary
for key in keys:
try:
update_dict_copy = update_dict_copy[key]
except Exception:
raise ValueError(f"Could not find the key {key} in {update_dict_copy}.")

if f"{{{match}}}" == original_dict:
# If there's only one match, it needn't always be a string, so not replacing the original dict.
return update_dict_copy
else:
# Replace the placeholder in the original_dict
original_dict = original_dict.replace(f"{{{match}}}", update_dict_copy)
elif match == "idempotence_token" and idempotence_token:
temp_dict = original_dict.replace(f"{{{match}}}", idempotence_token)
if len(temp_dict) > 63:
truncated_idempotence_token = idempotence_token[
: (63 - len(original_dict.replace("{idempotence_token}", "")))
]
original_dict = original_dict.replace(f"{{{match}}}", truncated_idempotence_token)
else:
original_dict = temp_dict

# If the string does not contain placeholders or if there are multiple placeholders, return the original dict.
if isinstance(original_dict, str) and "{" in original_dict and "}" in original_dict:
matches = re.findall(r"\{([^}]+)\}", original_dict)
for match in matches:
if "." in match:
keys = match.split(".")
nested_value = get_nested_value(update_dict, keys)
if f"{{{match}}}" == original_dict:
return nested_value
else:
original_dict = replace_placeholder(service, original_dict, match, nested_value)
elif match == "idempotence_token" and idempotence_token:
original_dict = replace_placeholder(service, original_dict, match, idempotence_token)
return original_dict

# If the original value is a list, recursively update each element in the list
if isinstance(original_dict, list):
return [update_dict_fn(item, update_dict, idempotence_token) for item in original_dict]
return [update_dict_fn(service, item, update_dict, idempotence_token) for item in original_dict]

# If the original value is a dictionary, recursively update each key-value pair
if isinstance(original_dict, dict):
for key, value in original_dict.items():
original_dict[key] = update_dict_fn(value, update_dict, idempotence_token)
original_dict[key] = update_dict_fn(service, value, update_dict, idempotence_token)

# Return the updated original dict
return original_dict


Expand Down Expand Up @@ -192,13 +199,13 @@ async def _call(
}
args["images"] = images

updated_config = update_dict_fn(config, args)
updated_config = update_dict_fn(self._service, config, args)

hash = ""
if "idempotence_token" in str(updated_config):
# compute hash of the config
hash = xxhash.xxh64(sorted_dict_str(updated_config)).hexdigest()
updated_config = update_dict_fn(updated_config, args, idempotence_token=hash)
updated_config = update_dict_fn(self._service, updated_config, args, idempotence_token=hash)

# Asynchronous Boto3 session
session = aioboto3.Session()
Expand Down
131 changes: 129 additions & 2 deletions plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def test_inputs():
)

result = update_dict_fn(
service="s3",
original_dict=original_dict,
update_dict={"inputs": literal_map_string_repr(inputs)},
)
Expand All @@ -74,14 +75,16 @@ def test_container():
original_dict = {"a": "{images.primary_container_image}"}
images = {"primary_container_image": "cr.flyte.org/flyteorg/flytekit:py3.11-1.10.3"}

result = update_dict_fn(original_dict=original_dict, update_dict={"images": images})
result = update_dict_fn(
service="sagemaker", original_dict=original_dict, update_dict={"images": images}
)

assert result == {"a": "cr.flyte.org/flyteorg/flytekit:py3.11-1.10.3"}


@pytest.mark.asyncio
@patch("flytekitplugins.awssagemaker_inference.boto3_mixin.aioboto3.Session")
async def test_call(mock_session):
async def test_call_with_no_idempotence_token(mock_session):
mixin = Boto3AgentMixin(service="sagemaker")

mock_client = AsyncMock()
Expand Down Expand Up @@ -118,3 +121,127 @@ async def test_call(mock_session):

assert result == mock_method.return_value
assert idempotence_token == ""


@pytest.mark.asyncio
@patch("flytekitplugins.awssagemaker_inference.boto3_mixin.aioboto3.Session")
async def test_call_with_idempotence_token(mock_session):
mixin = Boto3AgentMixin(service="sagemaker")

mock_client = AsyncMock()
mock_session.return_value.client.return_value.__aenter__.return_value = mock_client
mock_method = mock_client.create_model

config = {
"ModelName": "{inputs.model_name}-{idempotence_token}",
"PrimaryContainer": {
"Image": "{images.image}",
"ModelDataUrl": "s3://sagemaker-agent-xgboost/model.tar.gz",
},
}
inputs = TypeEngine.dict_to_literal_map(
FlyteContext.current_context(),
{"model_name": "xgboost", "region": "us-west-2"},
{"model_name": str, "region": str},
)

result, idempotence_token = await mixin._call(
method="create_model",
config=config,
inputs=inputs,
images={"image": triton_image_uri(version="21.08")},
)

mock_method.assert_called_with(
ModelName="xgboost-23dba5d7c5aa79a8",
PrimaryContainer={
"Image": "301217895009.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tritonserver:21.08-py3",
"ModelDataUrl": "s3://sagemaker-agent-xgboost/model.tar.gz",
},
)

assert result == mock_method.return_value
assert idempotence_token == "23dba5d7c5aa79a8"


@pytest.mark.asyncio
@patch("flytekitplugins.awssagemaker_inference.boto3_mixin.aioboto3.Session")
async def test_call_with_truncated_idempotence_token(mock_session):
mixin = Boto3AgentMixin(service="sagemaker")

mock_client = AsyncMock()
mock_session.return_value.client.return_value.__aenter__.return_value = mock_client
mock_method = mock_client.create_model

config = {
"ModelName": "{inputs.model_name}-{idempotence_token}",
"PrimaryContainer": {
"Image": "{images.image}",
"ModelDataUrl": "s3://sagemaker-agent-xgboost/model.tar.gz",
},
}
inputs = TypeEngine.dict_to_literal_map(
FlyteContext.current_context(),
{
"model_name": "xgboost-random-string-1234567890123456789012345678", # length=50
"region": "us-west-2",
},
{"model_name": str, "region": str},
)

result, idempotence_token = await mixin._call(
method="create_model",
config=config,
inputs=inputs,
images={"image": triton_image_uri(version="21.08")},
)

mock_method.assert_called_with(
ModelName="xgboost-random-string-1234567890123456789012345678-432aa64034f3",
PrimaryContainer={
"Image": "301217895009.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tritonserver:21.08-py3",
"ModelDataUrl": "s3://sagemaker-agent-xgboost/model.tar.gz",
},
)

assert result == mock_method.return_value
assert idempotence_token == "432aa64034f37edb"


@pytest.mark.asyncio
@patch("flytekitplugins.awssagemaker_inference.boto3_mixin.aioboto3.Session")
async def test_call_with_truncated_idempotence_token_as_input(mock_session):
mixin = Boto3AgentMixin(service="sagemaker")

mock_client = AsyncMock()
mock_session.return_value.client.return_value.__aenter__.return_value = mock_client
mock_method = mock_client.create_endpoint

config = {
"EndpointName": "{inputs.endpoint_name}-{idempotence_token}",
"EndpointConfigName": "{inputs.endpoint_config_name}-{inputs.idempotence_token}",
}
inputs = TypeEngine.dict_to_literal_map(
FlyteContext.current_context(),
{
"endpoint_name": "xgboost",
"endpoint_config_name": "xgboost-random-string-1234567890123456789012345678", # length=50
"idempotence_token": "432aa64034f37edb",
"region": "us-west-2",
},
{"model_name": str, "region": str},
)

result, idempotence_token = await mixin._call(
method="create_endpoint",
config=config,
inputs=inputs,
)

mock_method.assert_called_with(
EndpointName="xgboost-ce735d6a183643f1",
EndpointConfigName="xgboost-random-string-1234567890123456789012345678-432aa64034f3",
)

assert result == mock_method.return_value
assert idempotence_token == "ce735d6a183643f1"

0 comments on commit 5259b5c

Please sign in to comment.