Skip to content

Commit

Permalink
feat(sagemaker): support batch-transform
Browse files Browse the repository at this point in the history
  • Loading branch information
deepankarm committed Sep 25, 2023
1 parent 054785f commit f51950a
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 4 deletions.
10 changes: 10 additions & 0 deletions tests/integration/docarray_v2/sagemaker/invalid_input.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
abcd
efgh
ijkl
mnop
qrst
uvwx
yzab
cdef
ghij
klmn
48 changes: 44 additions & 4 deletions tests/integration/docarray_v2/sagemaker/test_sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_provider_sagemaker_pod_inference():
assert len(resp_json['data'][0]['embeddings'][0]) == 64


def test_provider_sagemaker_pod_batch_transform():
def test_provider_sagemaker_pod_batch_transform_valid():
with chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')):
args, _ = set_pod_parser().parse_known_args(
[
Expand All @@ -71,8 +71,10 @@ def test_provider_sagemaker_pod_batch_transform():
with Pod(args):
# provider=sagemaker would set the port to 8080
port = 8080
# Test the `POST /invocations` endpoint for batch-transform
with open(os.path.join(os.path.dirname(__file__), 'input.csv'), 'r') as f:
# Test `POST /invocations` endpoint for batch-transform with valid input
with open(
os.path.join(os.path.dirname(__file__), 'valid_input.csv'), 'r'
) as f:
csv_data = f.read()

resp = requests.post(
Expand All @@ -90,6 +92,42 @@ def test_provider_sagemaker_pod_batch_transform():
assert len(d['embeddings'][0]) == 64


def test_provider_sagemaker_pod_batch_transform_invalid():
with chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')):
args, _ = set_pod_parser().parse_known_args(
[
'--uses',
'config.yml',
'--provider',
'sagemaker',
'serve', # This is added by sagemaker
]
)
with Pod(args):
# provider=sagemaker would set the port to 8080
port = 8080
# Test `POST /invocations` endpoint for batch-transform with invalid input
with open(
os.path.join(os.path.dirname(__file__), 'invalid_input.csv'), 'r'
) as f:
csv_data = f.read()

resp = requests.post(
f'http://localhost:{port}/invocations',
headers={
'accept': 'application/json',
'content-type': 'text/csv',
},
data=csv_data,
)
assert resp.status_code == 400
assert (
resp.json()['detail']
== "Invalid CSV format. Line ['abcd'] doesn't match the expected field "
"order ['id', 'text']."
)


def test_provider_sagemaker_deployment_inference():
with chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')):
dep_port = 12345
Expand Down Expand Up @@ -121,7 +159,9 @@ def test_provider_sagemaker_deployment_batch():
dep_port = 12345
with Deployment(uses='config.yml', provider='sagemaker', port=dep_port):
# Test the `POST /invocations` endpoint for batch-transform
with open(os.path.join(os.path.dirname(__file__), 'input.csv'), 'r') as f:
with open(
os.path.join(os.path.dirname(__file__), 'valid_input.csv'), 'r'
) as f:
csv_data = f.read()

rsp = requests.post(
Expand Down

0 comments on commit f51950a

Please sign in to comment.