Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sagemaker): support batch-transform #6055

Merged
merged 9 commits into from
Sep 28, 2023
53 changes: 27 additions & 26 deletions jina/orchestrate/deployments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Type, Union, overload

from hubble.executor.helper import replace_secret_of_hub_uri
from hubble.executor.hubio import HubIO
from rich import print
from rich.panel import Panel

Expand Down Expand Up @@ -54,7 +53,6 @@
from jina.orchestrate.pods.factory import PodFactory
from jina.parsers import set_deployment_parser, set_gateway_parser
from jina.parsers.helper import _update_gateway_args
from jina.serve.networking import GrpcConnectionPool
from jina.serve.networking.utils import host_is_local, in_docker

WRAPPED_SLICE_BASE = r'\[[-\d:]+\]'
Expand Down Expand Up @@ -103,7 +101,7 @@ def _call_add_voters(leader, voters, replica_ids, name, event_signal=None):
logger.success(
f'Replica-{str(replica_id)} successfully added as voter with address {voter_address} to leader at {leader}'
)
logger.debug(f'Adding voters to leader finished')
logger.debug('Adding voters to leader finished')
if event_signal:
event_signal.set()

Expand Down Expand Up @@ -138,7 +136,7 @@ def _add_voter_to_leader(self):
voter_addresses = [pod.runtime_ctrl_address for pod in self._pods[1:]]
replica_ids = [pod.args.replica_id for pod in self._pods[1:]]
event_signal = multiprocessing.Event()
self.logger.debug(f'Starting process to call Add Voters')
self.logger.debug('Starting process to call Add Voters')
process = multiprocessing.Process(
target=_call_add_voters,
kwargs={
Expand All @@ -159,19 +157,19 @@ def _add_voter_to_leader(self):
else:
time.sleep(1.0)
if properly_closed:
self.logger.debug(f'Add Voters process finished')
self.logger.debug('Add Voters process finished')
process.terminate()
else:
self.logger.error(f'Add Voters process did not finish successfully')
self.logger.error('Add Voters process did not finish successfully')
process.kill()
self.logger.debug(f'Add Voters process finished')
self.logger.debug('Add Voters process finished')

async def _async_add_voter_to_leader(self):
leader_address = f'{self._pods[0].runtime_ctrl_address}'
voter_addresses = [pod.runtime_ctrl_address for pod in self._pods[1:]]
replica_ids = [pod.args.replica_id for pod in self._pods[1:]]
event_signal = multiprocessing.Event()
self.logger.debug(f'Starting process to call Add Voters')
self.logger.debug('Starting process to call Add Voters')
process = multiprocessing.Process(
target=_call_add_voters,
kwargs={
Expand All @@ -192,10 +190,10 @@ async def _async_add_voter_to_leader(self):
else:
await asyncio.sleep(1.0)
if properly_closed:
self.logger.debug(f'Add Voters process finished')
self.logger.debug('Add Voters process finished')
process.terminate()
else:
self.logger.error(f'Add Voters process did not finish successfully')
self.logger.error('Add Voters process did not finish successfully')
process.kill()

@property
Expand All @@ -214,23 +212,23 @@ def join(self):
pod.join()

def wait_start_success(self):
self.logger.debug(f'Waiting for ReplicaSet to start successfully')
self.logger.debug('Waiting for ReplicaSet to start successfully')
for pod in self._pods:
pod.wait_start_success()
# should this be done only when the cluster is started ?
if self._pods[0].args.stateful:
self._add_voter_to_leader()
self.logger.debug(f'ReplicaSet started successfully')
self.logger.debug('ReplicaSet started successfully')

async def async_wait_start_success(self):
self.logger.debug(f'Waiting for ReplicaSet to start successfully')
self.logger.debug('Waiting for ReplicaSet to start successfully')
await asyncio.gather(
*[pod.async_wait_start_success() for pod in self._pods]
)
# should this be done only when the cluster is started ?
if self._pods[0].args.stateful:
await self._async_add_voter_to_leader()
self.logger.debug(f'ReplicaSet started successfully')
self.logger.debug('ReplicaSet started successfully')

def __enter__(self):
for _args in self.args:
Expand Down Expand Up @@ -481,16 +479,19 @@ def __init__(
if self.args.provider == ProviderType.SAGEMAKER:
if self._gateway_kwargs.get('port', 0) == 8080:
raise ValueError(
f'Port 8080 is reserved for Sagemaker deployment. Please use another port'
'Port 8080 is reserved for Sagemaker deployment. '
'Please use another port'
)
if self.args.port != [8080]:
warnings.warn(
f'Port is changed to 8080 for Sagemaker deployment. Port {self.args.port} is ignored'
'Port is changed to 8080 for Sagemaker deployment. '
f'Port {self.args.port} is ignored'
)
self.args.port = [8080]
if self.args.protocol != [ProtocolType.HTTP]:
warnings.warn(
f'Protocol is changed to HTTP for Sagemaker deployment. Protocol {self.args.protocol} is ignored'
'Protocol is changed to HTTP for Sagemaker deployment. '
f'Protocol {self.args.protocol} is ignored'
)
self.args.protocol = [ProtocolType.HTTP]
if self._include_gateway and ProtocolType.HTTP in self.args.protocol:
Expand Down Expand Up @@ -529,10 +530,10 @@ def __init__(

if self.args.stateful and (is_windows_os or (is_mac_os and is_37)):
if is_windows_os:
raise RuntimeError(f'Stateful feature is not available on Windows')
raise RuntimeError('Stateful feature is not available on Windows')
if is_mac_os:
raise RuntimeError(
f'Stateful feature when running on MacOS requires Python3.8 or newer version'
'Stateful feature when running on MacOS requires Python3.8 or newer version'
)
if self.args.stateful and (
ProtocolType.WEBSOCKET in self.args.protocol
Expand Down Expand Up @@ -805,7 +806,7 @@ def _copy_to_head_args(args: Namespace) -> Namespace:
if args.name:
_head_args.name = f'{args.name}/head'
else:
_head_args.name = f'head'
_head_args.name = 'head'

return _head_args

Expand Down Expand Up @@ -1209,7 +1210,7 @@ async def async_wait_start_success(self) -> None:
coros.append(self.shards[shard_id].async_wait_start_success())

await asyncio.gather(*coros)
self.logger.debug(f'Deployment started successfully')
self.logger.debug('Deployment started successfully')
except:
self.close()
raise
Expand Down Expand Up @@ -1374,7 +1375,7 @@ def _set_pod_args(self) -> Dict[int, List[Namespace]]:
peer_ports = peer_ports_all_shards.get(str(shard_id), [])
if len(peer_ports) > 0 and len(peer_ports) != replicas:
raise ValueError(
f'peer-ports argument does not match number of replicas, it will be ignored'
'peer-ports argument does not match number of replicas, it will be ignored'
)
elif len(peer_ports) == 0:
peer_ports = [random_port() for _ in range(replicas)]
Expand Down Expand Up @@ -1506,12 +1507,12 @@ def _parse_base_deployment_args(self, args):

if self.args.stateful and self.args.replicas in [1, 2]:
self.logger.debug(
f'Stateful Executor is not recommended to be used less than 3 replicas'
'Stateful Executor is not recommended to be used less than 3 replicas'
)

if self.args.stateful and self.args.workspace is None:
raise ValueError(
f'Stateful Executors need to be provided `workspace` when used in a Deployment'
'Stateful Executors need to be provided `workspace` when used in a Deployment'
)

# a gateway has no heads and uses
Expand Down Expand Up @@ -1568,7 +1569,7 @@ def _mermaid_str(self) -> List[str]:
mermaid_graph = []
secret = '&ltsecret&gt'
if self.role != DeploymentRoleType.GATEWAY and not self.external:
mermaid_graph = [f'subgraph {self.name};', f'\ndirection LR;\n']
mermaid_graph = [f'subgraph {self.name};', '\ndirection LR;\n']

uses_before_name = (
self.uses_before_args.name
Expand Down Expand Up @@ -1596,7 +1597,7 @@ def _mermaid_str(self) -> List[str]:
shard_names.append(shard_name)
shard_mermaid_graph = [
f'subgraph {shard_name};',
f'\ndirection TB;\n',
'\ndirection TB;\n',
]
names = [
args.name for args in pod_args
Expand Down
71 changes: 62 additions & 9 deletions jina/serve/runtimes/worker/http_sagemaker_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,14 @@ def get_fastapi_app(
:return: fastapi app
"""
with ImportExtensions(required=True):
from fastapi import FastAPI, Response, HTTPException
import pydantic
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from pydantic.config import BaseConfig, inherit_config

import os

from pydantic import BaseModel, Field
from pydantic.config import BaseConfig, inherit_config

from jina.proto import jina_pb2
from jina.serve.runtimes.gateway.models import _to_camel_case

Expand Down Expand Up @@ -83,26 +82,38 @@ def add_post_route(
path=f'/{endpoint_path.strip("/")}',
methods=['POST'],
summary=f'Endpoint {endpoint_path}',
response_model=output_model,
response_model=Union[output_model, List[output_model]],
response_class=DocArrayResponse,
)

@app.api_route(**app_kwargs)
async def post(body: input_model, response: Response):
def is_valid_csv(content: str) -> bool:
    """Check that ``content`` can be read row by row by the stdlib ``csv`` parser.

    This is a syntactic sanity check only: the default (non-strict) dialect is
    used and no header or column count is enforced here.

    :param content: decoded text of the CSV payload
    :return: ``True`` if every row could be read, ``False`` if the parser raised
        or ``content`` is not text
    """
    import csv
    from io import StringIO

    try:
        # newline='' hands line splitting to the csv module (as its docs
        # recommend), so `\r`, `\n` and `\r\n` terminated payloads as well as
        # newlines embedded in quoted fields are all handled consistently.
        # A plain reader is used because the payload carries no header row.
        for _ in csv.reader(StringIO(content, newline='')):
            pass
    except (csv.Error, TypeError):
        # csv.Error: the parser rejected the data (e.g. oversized field);
        # TypeError: `content` was not a str and StringIO refused it.
        return False
    return True

async def process(body) -> output_model:
req = DataRequest()
if body.header is not None:
req.header.request_id = body.header.request_id

if body.parameters is not None:
req.parameters = body.parameters
req.header.exec_endpoint = endpoint_path
req.document_array_cls = DocList[input_doc_model]

data = body.data
if isinstance(data, list):
req.document_array_cls = DocList[input_doc_model]
req.data.docs = DocList[input_doc_list_model](data)
else:
req.document_array_cls = DocList[input_doc_model]
req.data.docs = DocList[input_doc_list_model]([data])
if body.header is None:
req.header.request_id = req.docs[0].id
Expand All @@ -115,6 +126,48 @@ async def post(body: input_model, response: Response):
else:
return output_model(data=resp.docs, parameters=resp.parameters)

@app.api_route(**app_kwargs)
async def post(request: Request):
    """Serve one invocation whose body is either JSON or header-less CSV.

    The request body is converted into ``input_model`` and handed to
    ``process``. Malformed client input is reported as HTTP 400.

    :param request: the incoming HTTP request
    :return: whatever ``process`` returns for the parsed body
    :raises HTTPException: with status 400 when the content type is not
        supported or the payload cannot be decoded, parsed or validated
    """
    import csv
    from io import StringIO

    content_type = request.headers.get('content-type')
    # Compare only the media type: drop parameters such as "; charset=utf-8"
    # and ignore case, so standard-conforming clients are not rejected.
    media_type = (content_type or '').split(';', 1)[0].strip().lower()

    if media_type == 'application/json':
        try:
            json_body = await request.json()
            body = input_model(**json_body)
        except (TypeError, ValueError) as exc:
            # ValueError covers JSON decoding errors and model validation
            # errors; TypeError covers a JSON document that is not an object.
            raise HTTPException(
                status_code=400,
                detail=f'Invalid JSON input: {exc}',
            ) from exc
        return await process(body)

    if media_type in ('text/csv', 'application/csv'):
        bytes_body = await request.body()
        try:
            csv_body = bytes_body.decode('utf-8')
        except UnicodeDecodeError as exc:
            raise HTTPException(
                status_code=400,
                detail='Invalid CSV input. Please check your input.',
            ) from exc
        if not is_valid_csv(csv_body):
            raise HTTPException(
                status_code=400,
                detail='Invalid CSV input. Please check your input.',
            )

        # NOTE: Sagemaker only supports csv files without header, so the header
        # is taken from the field names of the input model. This also fixes the
        # order of the columns, and means every field of the input model must
        # be present in the csv file, including the optional ones.
        field_names = list(input_doc_list_model.__fields__)
        data = []
        try:
            # csv.reader (instead of str.split) keeps quoted fields that
            # contain commas or line breaks intact.
            for row in csv.reader(StringIO(csv_body, newline='')):
                # The reader yields [] for a blank line; treat it as one empty
                # field, which is what plain splitting on ',' produced.
                fields = row or ['']
                if len(fields) != len(field_names):
                    raise HTTPException(
                        status_code=400,
                        detail=f'Invalid CSV format. Line {fields} doesn\'t match '
                        f'the expected field order {field_names}.',
                    )
                data.append(input_doc_list_model(**dict(zip(field_names, fields))))
            body = input_model(data=data)
        except (csv.Error, ValueError) as exc:
            # Parser failures and model validation errors are client errors.
            raise HTTPException(
                status_code=400,
                detail=f'Invalid CSV input. Please check your input. {exc}',
            ) from exc

        return await process(body)

    raise HTTPException(
        status_code=400,
        detail=f'Invalid content-type: {content_type}. '
        f'Please use either application/json or text/csv.',
    )

for endpoint, input_output_map in request_models_map.items():
if endpoint != '_jina_dry_run_':
input_doc_model = input_output_map['input']['model']
Expand Down
7 changes: 7 additions & 0 deletions tests/integration/docarray_v2/sagemaker/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Image used by the SageMaker integration test to run the sample executor.
FROM jinaai/jina:test-pip

# Copy the whole build context into the image.
COPY . /executor_root/

# Run from the executor's directory so the relative config.yml below resolves.
WORKDIR /executor_root/SampleExecutor

# Start the executor described by config.yml.
ENTRYPOINT ["jina", "executor", "--uses", "config.yml"]
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,7 @@ class SampleExecutor(Executor):
def foo(self, docs: DocList[TextDoc], **kwargs) -> DocList[EmbeddingResponseModel]:
ret = []
for doc in docs:
ret.append(EmbeddingResponseModel(embeddings=np.random.random((1, 64))))
ret.append(
EmbeddingResponseModel(id=doc.id, embeddings=np.random.random((1, 64)))
)
return DocList[EmbeddingResponseModel](ret)
10 changes: 10 additions & 0 deletions tests/integration/docarray_v2/sagemaker/invalid_input.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
abcd
efgh
ijkl
mnop
qrst
uvwx
yzab
cdef
ghij
klmn
Loading