From f51950a95277723b776730689e2f05123d0909cc Mon Sep 17 00:00:00 2001 From: Deepankar Mahapatro Date: Mon, 25 Sep 2023 19:31:54 +0530 Subject: [PATCH] feat(sagemaker): support batch-transform --- .../docarray_v2/sagemaker/invalid_input.csv | 10 ++++ .../docarray_v2/sagemaker/test_sagemaker.py | 48 +++++++++++++++++-- .../sagemaker/{input.csv => valid_input.csv} | 0 3 files changed, 54 insertions(+), 4 deletions(-) create mode 100644 tests/integration/docarray_v2/sagemaker/invalid_input.csv rename tests/integration/docarray_v2/sagemaker/{input.csv => valid_input.csv} (100%) diff --git a/tests/integration/docarray_v2/sagemaker/invalid_input.csv b/tests/integration/docarray_v2/sagemaker/invalid_input.csv new file mode 100644 index 0000000000000..514f99c8a0fc9 --- /dev/null +++ b/tests/integration/docarray_v2/sagemaker/invalid_input.csv @@ -0,0 +1,10 @@ +abcd +efgh +ijkl +mnop +qrst +uvwx +yzab +cdef +ghij +klmn \ No newline at end of file diff --git a/tests/integration/docarray_v2/sagemaker/test_sagemaker.py b/tests/integration/docarray_v2/sagemaker/test_sagemaker.py index bddd943fa31ef..1ac481a48e600 100644 --- a/tests/integration/docarray_v2/sagemaker/test_sagemaker.py +++ b/tests/integration/docarray_v2/sagemaker/test_sagemaker.py @@ -57,7 +57,7 @@ def test_provider_sagemaker_pod_inference(): assert len(resp_json['data'][0]['embeddings'][0]) == 64 -def test_provider_sagemaker_pod_batch_transform(): +def test_provider_sagemaker_pod_batch_transform_valid(): with chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')): args, _ = set_pod_parser().parse_known_args( [ @@ -71,8 +71,10 @@ def test_provider_sagemaker_pod_batch_transform(): with Pod(args): # provider=sagemaker would set the port to 8080 port = 8080 - # Test the `POST /invocations` endpoint for batch-transform - with open(os.path.join(os.path.dirname(__file__), 'input.csv'), 'r') as f: + # Test `POST /invocations` endpoint for batch-transform with valid input + with open( + os.path.join(os.path.dirname(__file__), 'valid_input.csv'), 'r' + ) as f: csv_data = f.read() resp = requests.post( @@ -90,6 +92,42 @@ def test_provider_sagemaker_pod_batch_transform(): assert len(d['embeddings'][0]) == 64 +def test_provider_sagemaker_pod_batch_transform_invalid(): + with chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')): + args, _ = set_pod_parser().parse_known_args( + [ + '--uses', + 'config.yml', + '--provider', + 'sagemaker', + 'serve', # This is added by sagemaker + ] + ) + with Pod(args): + # provider=sagemaker would set the port to 8080 + port = 8080 + # Test `POST /invocations` endpoint for batch-transform with invalid input + with open( + os.path.join(os.path.dirname(__file__), 'invalid_input.csv'), 'r' + ) as f: + csv_data = f.read() + + resp = requests.post( + f'http://localhost:{port}/invocations', + headers={ + 'accept': 'application/json', + 'content-type': 'text/csv', + }, + data=csv_data, + ) + assert resp.status_code == 400 + assert ( + resp.json()['detail'] + == "Invalid CSV format. Line ['abcd'] doesn't match the expected field " + "order ['id', 'text']." + ) + + def test_provider_sagemaker_deployment_inference(): with chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')): dep_port = 12345 @@ -121,7 +159,9 @@ def test_provider_sagemaker_deployment_batch(): dep_port = 12345 with Deployment(uses='config.yml', provider='sagemaker', port=dep_port): # Test the `POST /invocations` endpoint for batch-transform - with open(os.path.join(os.path.dirname(__file__), 'input.csv'), 'r') as f: + with open( + os.path.join(os.path.dirname(__file__), 'valid_input.csv'), 'r' + ) as f: csv_data = f.read() rsp = requests.post( diff --git a/tests/integration/docarray_v2/sagemaker/input.csv b/tests/integration/docarray_v2/sagemaker/valid_input.csv similarity index 100% rename from tests/integration/docarray_v2/sagemaker/input.csv rename to tests/integration/docarray_v2/sagemaker/valid_input.csv