From 86751d2f1b854fa2165c8bef78f09afc35c0e2d2 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 24 Jul 2024 16:26:33 +0530 Subject: [PATCH] add tests Signed-off-by: Samhita Alla --- .../awssagemaker_inference/boto3_mixin.py | 1 + .../tests/test_boto3_mixin.py | 131 +++++++++++++++++- 2 files changed, 130 insertions(+), 2 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py index 0bf4e71dbb..b6602087c1 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py @@ -97,6 +97,7 @@ def update_dict_fn( and update_dict is {"endpoint_config_name": "my-endpoint-config"}, then the result will be {"EndpointConfigName": "my-endpoint-config"}. + :param service: The AWS service to use :param original_dict: The dictionary to update (in place) :param update_dict: The dictionary to use for updating :param idempotence_token: Hash of config -- this is to ensure the execution ID is deterministic diff --git a/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py b/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py index 60d0dd45af..304ae49a01 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py @@ -51,6 +51,7 @@ def test_inputs(): ) result = update_dict_fn( + service="s3", original_dict=original_dict, update_dict={"inputs": literal_map_string_repr(inputs)}, ) @@ -74,14 +75,16 @@ def test_container(): original_dict = {"a": "{images.primary_container_image}"} images = {"primary_container_image": "cr.flyte.org/flyteorg/flytekit:py3.11-1.10.3"} - result = update_dict_fn(original_dict=original_dict, update_dict={"images": images}) + result = update_dict_fn( + service="sagemaker", original_dict=original_dict, update_dict={"images": images} + ) assert result == {"a": "cr.flyte.org/flyteorg/flytekit:py3.11-1.10.3"} @pytest.mark.asyncio @patch("flytekitplugins.awssagemaker_inference.boto3_mixin.aioboto3.Session") -async def test_call(mock_session): +async def test_call_with_no_idempotence_token(mock_session): mixin = Boto3AgentMixin(service="sagemaker") mock_client = AsyncMock() @@ -118,3 +121,127 @@ async def test_call(mock_session): assert result == mock_method.return_value assert idempotence_token == "" + + +@pytest.mark.asyncio +@patch("flytekitplugins.awssagemaker_inference.boto3_mixin.aioboto3.Session") +async def test_call_with_idempotence_token(mock_session): + mixin = Boto3AgentMixin(service="sagemaker") + + mock_client = AsyncMock() + mock_session.return_value.client.return_value.__aenter__.return_value = mock_client + mock_method = mock_client.create_model + + config = { + "ModelName": "{inputs.model_name}-{idempotence_token}", + "PrimaryContainer": { + "Image": "{images.image}", + "ModelDataUrl": "s3://sagemaker-agent-xgboost/model.tar.gz", + }, + } + inputs = TypeEngine.dict_to_literal_map( + FlyteContext.current_context(), + {"model_name": "xgboost", "region": "us-west-2"}, + {"model_name": str, "region": str}, + ) + + result, idempotence_token = await mixin._call( + method="create_model", + config=config, + inputs=inputs, + images={"image": triton_image_uri(version="21.08")}, + ) + + mock_method.assert_called_with( + ModelName="xgboost-23dba5d7c5aa79a8", + PrimaryContainer={ + "Image": "301217895009.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tritonserver:21.08-py3", + "ModelDataUrl": "s3://sagemaker-agent-xgboost/model.tar.gz", + }, + ) + + assert result == mock_method.return_value + assert idempotence_token == "23dba5d7c5aa79a8" + + +@pytest.mark.asyncio +@patch("flytekitplugins.awssagemaker_inference.boto3_mixin.aioboto3.Session") +async def test_call_with_truncated_idempotence_token(mock_session): + mixin = Boto3AgentMixin(service="sagemaker") + + mock_client = AsyncMock() + mock_session.return_value.client.return_value.__aenter__.return_value = mock_client + mock_method = mock_client.create_model + + config = { + "ModelName": "{inputs.model_name}-{idempotence_token}", + "PrimaryContainer": { + "Image": "{images.image}", + "ModelDataUrl": "s3://sagemaker-agent-xgboost/model.tar.gz", + }, + } + inputs = TypeEngine.dict_to_literal_map( + FlyteContext.current_context(), + { + "model_name": "xgboost-random-string-1234567890123456789012345678", # length=50 + "region": "us-west-2", + }, + {"model_name": str, "region": str}, + ) + + result, idempotence_token = await mixin._call( + method="create_model", + config=config, + inputs=inputs, + images={"image": triton_image_uri(version="21.08")}, + ) + + mock_method.assert_called_with( + ModelName="xgboost-random-string-1234567890123456789012345678-432aa64034f3", + PrimaryContainer={ + "Image": "301217895009.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tritonserver:21.08-py3", + "ModelDataUrl": "s3://sagemaker-agent-xgboost/model.tar.gz", + }, + ) + + assert result == mock_method.return_value + assert idempotence_token == "432aa64034f37edb" + + +@pytest.mark.asyncio +@patch("flytekitplugins.awssagemaker_inference.boto3_mixin.aioboto3.Session") +async def test_call_with_truncated_idempotence_token_as_input(mock_session): + mixin = Boto3AgentMixin(service="sagemaker") + + mock_client = AsyncMock() + mock_session.return_value.client.return_value.__aenter__.return_value = mock_client + mock_method = mock_client.create_endpoint + + config = { + "EndpointName": "{inputs.endpoint_name}-{idempotence_token}", + "EndpointConfigName": "{inputs.endpoint_config_name}-{inputs.idempotence_token}", + } + inputs = TypeEngine.dict_to_literal_map( + FlyteContext.current_context(), + { + "endpoint_name": "xgboost", + "endpoint_config_name": "xgboost-random-string-1234567890123456789012345678", # length=50 + "idempotence_token": "432aa64034f37edb", + "region": "us-west-2", + }, + {"model_name": str, "region": str}, + ) + + result, idempotence_token = await mixin._call( + method="create_endpoint", + config=config, + inputs=inputs, + ) + + mock_method.assert_called_with( + EndpointName="xgboost-ce735d6a183643f1", + EndpointConfigName="xgboost-random-string-1234567890123456789012345678-432aa64034f3", + ) + + assert result == mock_method.return_value + assert idempotence_token == "ce735d6a183643f1"