diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py index 7d5c1e4905..b6602087c1 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py @@ -52,7 +52,41 @@ def sorted_dict_str(d): } +def get_nested_value(d: Dict[str, Any], keys: list[str]) -> Any: + """ + Retrieve the nested value from a dictionary based on a list of keys. + """ + for key in keys: + if key not in d: + raise ValueError(f"Could not find the key {key} in {d}.") + d = d[key] + return d + + +def replace_placeholder( + service: str, + original_dict: str, + placeholder: str, + replacement: str, +) -> str: + """ + Replace a placeholder in the original string and handle the specific logic for the sagemaker service and idempotence token. + """ + temp_dict = original_dict.replace(f"{{{placeholder}}}", replacement) + if service == "sagemaker" and placeholder in [ + "inputs.idempotence_token", + "idempotence_token", + ]: + if len(temp_dict) > 63: + truncated_token = replacement[: 63 - len(original_dict.replace(f"{{{placeholder}}}", ""))] + return original_dict.replace(f"{{{placeholder}}}", truncated_token) + else: + return temp_dict + return temp_dict + + def update_dict_fn( + service: str, original_dict: Any, update_dict: Dict[str, Any], idempotence_token: Optional[str] = None, @@ -63,6 +97,7 @@ def update_dict_fn( and update_dict is {"endpoint_config_name": "my-endpoint-config"}, then the result will be {"EndpointConfigName": "my-endpoint-config"}. + :param service: The AWS service to use :param original_dict: The dictionary to update (in place) :param update_dict: The dictionary to use for updating :param idempotence_token: Hash of config -- this is to ensure the execution ID is deterministic @@ -71,55 +106,27 @@ def update_dict_fn( if original_dict is None: return None - # If the original value is a string and contains placeholder curly braces - if isinstance(original_dict, str): - if "{" in original_dict and "}" in original_dict: - matches = re.findall(r"\{([^}]+)\}", original_dict) - for match in matches: - # Check if there are nested keys - if "." in match: - # Create a copy of update_dict - update_dict_copy = update_dict.copy() - - # Fetch keys from the original_dict - keys = match.split(".") - - # Get value from the nested dictionary - for key in keys: - try: - update_dict_copy = update_dict_copy[key] - except Exception: - raise ValueError(f"Could not find the key {key} in {update_dict_copy}.") - - if f"{{{match}}}" == original_dict: - # If there's only one match, it needn't always be a string, so not replacing the original dict. - return update_dict_copy - else: - # Replace the placeholder in the original_dict - original_dict = original_dict.replace(f"{{{match}}}", update_dict_copy) - elif match == "idempotence_token" and idempotence_token: - temp_dict = original_dict.replace(f"{{{match}}}", idempotence_token) - if len(temp_dict) > 63: - truncated_idempotence_token = idempotence_token[ - : (63 - len(original_dict.replace("{idempotence_token}", ""))) - ] - original_dict = original_dict.replace(f"{{{match}}}", truncated_idempotence_token) - else: - original_dict = temp_dict - - # If the string does not contain placeholders or if there are multiple placeholders, return the original dict. + if isinstance(original_dict, str) and "{" in original_dict and "}" in original_dict: + matches = re.findall(r"\{([^}]+)\}", original_dict) + for match in matches: + if "." in match: + keys = match.split(".") + nested_value = get_nested_value(update_dict, keys) + if f"{{{match}}}" == original_dict: + return nested_value + else: + original_dict = replace_placeholder(service, original_dict, match, nested_value) + elif match == "idempotence_token" and idempotence_token: + original_dict = replace_placeholder(service, original_dict, match, idempotence_token) return original_dict - # If the original value is a list, recursively update each element in the list if isinstance(original_dict, list): - return [update_dict_fn(item, update_dict, idempotence_token) for item in original_dict] + return [update_dict_fn(service, item, update_dict, idempotence_token) for item in original_dict] - # If the original value is a dictionary, recursively update each key-value pair if isinstance(original_dict, dict): for key, value in original_dict.items(): - original_dict[key] = update_dict_fn(value, update_dict, idempotence_token) + original_dict[key] = update_dict_fn(service, value, update_dict, idempotence_token) - # Return the updated original dict return original_dict @@ -192,13 +199,13 @@ async def _call( } args["images"] = images - updated_config = update_dict_fn(config, args) + updated_config = update_dict_fn(self._service, config, args) hash = "" if "idempotence_token" in str(updated_config): # compute hash of the config hash = xxhash.xxh64(sorted_dict_str(updated_config)).hexdigest() - updated_config = update_dict_fn(updated_config, args, idempotence_token=hash) + updated_config = update_dict_fn(self._service, updated_config, args, idempotence_token=hash) # Asynchronous Boto3 session session = aioboto3.Session() diff --git a/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py b/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py index 60d0dd45af..304ae49a01 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py @@ -51,6 +51,7 @@ def test_inputs(): ) result = update_dict_fn( + service="s3", original_dict=original_dict, update_dict={"inputs": literal_map_string_repr(inputs)}, ) @@ -74,14 +75,16 @@ def test_container(): original_dict = {"a": "{images.primary_container_image}"} images = {"primary_container_image": "cr.flyte.org/flyteorg/flytekit:py3.11-1.10.3"} - result = update_dict_fn(original_dict=original_dict, update_dict={"images": images}) + result = update_dict_fn( + service="sagemaker", original_dict=original_dict, update_dict={"images": images} + ) assert result == {"a": "cr.flyte.org/flyteorg/flytekit:py3.11-1.10.3"} @pytest.mark.asyncio @patch("flytekitplugins.awssagemaker_inference.boto3_mixin.aioboto3.Session") -async def test_call(mock_session): +async def test_call_with_no_idempotence_token(mock_session): mixin = Boto3AgentMixin(service="sagemaker") mock_client = AsyncMock() @@ -118,3 +121,127 @@ async def test_call(mock_session): assert result == mock_method.return_value assert idempotence_token == "" + + +@pytest.mark.asyncio +@patch("flytekitplugins.awssagemaker_inference.boto3_mixin.aioboto3.Session") +async def test_call_with_idempotence_token(mock_session): + mixin = Boto3AgentMixin(service="sagemaker") + + mock_client = AsyncMock() + mock_session.return_value.client.return_value.__aenter__.return_value = mock_client + mock_method = mock_client.create_model + + config = { + "ModelName": "{inputs.model_name}-{idempotence_token}", + "PrimaryContainer": { + "Image": "{images.image}", + "ModelDataUrl": "s3://sagemaker-agent-xgboost/model.tar.gz", + }, + } + inputs = TypeEngine.dict_to_literal_map( + FlyteContext.current_context(), + {"model_name": "xgboost", "region": "us-west-2"}, + {"model_name": str, "region": str}, + ) + + result, idempotence_token = await mixin._call( + method="create_model", + config=config, + inputs=inputs, + images={"image": triton_image_uri(version="21.08")}, + ) + + mock_method.assert_called_with( + ModelName="xgboost-23dba5d7c5aa79a8", + PrimaryContainer={ + "Image": "301217895009.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tritonserver:21.08-py3", + "ModelDataUrl": "s3://sagemaker-agent-xgboost/model.tar.gz", + }, + ) + + assert result == mock_method.return_value + assert idempotence_token == "23dba5d7c5aa79a8" + + +@pytest.mark.asyncio +@patch("flytekitplugins.awssagemaker_inference.boto3_mixin.aioboto3.Session") +async def test_call_with_truncated_idempotence_token(mock_session): + mixin = Boto3AgentMixin(service="sagemaker") + + mock_client = AsyncMock() + mock_session.return_value.client.return_value.__aenter__.return_value = mock_client + mock_method = mock_client.create_model + + config = { + "ModelName": "{inputs.model_name}-{idempotence_token}", + "PrimaryContainer": { + "Image": "{images.image}", + "ModelDataUrl": "s3://sagemaker-agent-xgboost/model.tar.gz", + }, + } + inputs = TypeEngine.dict_to_literal_map( + FlyteContext.current_context(), + { + "model_name": "xgboost-random-string-1234567890123456789012345678", # length=50 + "region": "us-west-2", + }, + {"model_name": str, "region": str}, + ) + + result, idempotence_token = await mixin._call( + method="create_model", + config=config, + inputs=inputs, + images={"image": triton_image_uri(version="21.08")}, + ) + + mock_method.assert_called_with( + ModelName="xgboost-random-string-1234567890123456789012345678-432aa64034f3", + PrimaryContainer={ + "Image": "301217895009.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tritonserver:21.08-py3", + "ModelDataUrl": "s3://sagemaker-agent-xgboost/model.tar.gz", + }, + ) + + assert result == mock_method.return_value + assert idempotence_token == "432aa64034f37edb" + + +@pytest.mark.asyncio +@patch("flytekitplugins.awssagemaker_inference.boto3_mixin.aioboto3.Session") +async def test_call_with_truncated_idempotence_token_as_input(mock_session): + mixin = Boto3AgentMixin(service="sagemaker") + + mock_client = AsyncMock() + mock_session.return_value.client.return_value.__aenter__.return_value = mock_client + mock_method = mock_client.create_endpoint + + config = { + "EndpointName": "{inputs.endpoint_name}-{idempotence_token}", + "EndpointConfigName": "{inputs.endpoint_config_name}-{inputs.idempotence_token}", + } + inputs = TypeEngine.dict_to_literal_map( + FlyteContext.current_context(), + { + "endpoint_name": "xgboost", + "endpoint_config_name": "xgboost-random-string-1234567890123456789012345678", # length=50 + "idempotence_token": "432aa64034f37edb", + "region": "us-west-2", + }, + {"model_name": str, "region": str}, + ) + + result, idempotence_token = await mixin._call( + method="create_endpoint", + config=config, + inputs=inputs, + ) + + mock_method.assert_called_with( + EndpointName="xgboost-ce735d6a183643f1", + EndpointConfigName="xgboost-random-string-1234567890123456789012345678-432aa64034f3", + ) + + assert result == mock_method.return_value + assert idempotence_token == "ce735d6a183643f1"