diff --git a/jina/orchestrate/deployments/__init__.py b/jina/orchestrate/deployments/__init__.py index 27327402ebdcf..ae19d52d88ff1 100644 --- a/jina/orchestrate/deployments/__init__.py +++ b/jina/orchestrate/deployments/__init__.py @@ -17,7 +17,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Set, Type, Union, overload from hubble.executor.helper import replace_secret_of_hub_uri -from hubble.executor.hubio import HubIO from rich import print from rich.panel import Panel @@ -54,7 +53,6 @@ from jina.orchestrate.pods.factory import PodFactory from jina.parsers import set_deployment_parser, set_gateway_parser from jina.parsers.helper import _update_gateway_args -from jina.serve.networking import GrpcConnectionPool from jina.serve.networking.utils import host_is_local, in_docker WRAPPED_SLICE_BASE = r'\[[-\d:]+\]' @@ -103,7 +101,7 @@ def _call_add_voters(leader, voters, replica_ids, name, event_signal=None): logger.success( f'Replica-{str(replica_id)} successfully added as voter with address {voter_address} to leader at {leader}' ) - logger.debug(f'Adding voters to leader finished') + logger.debug('Adding voters to leader finished') if event_signal: event_signal.set() @@ -138,7 +136,7 @@ def _add_voter_to_leader(self): voter_addresses = [pod.runtime_ctrl_address for pod in self._pods[1:]] replica_ids = [pod.args.replica_id for pod in self._pods[1:]] event_signal = multiprocessing.Event() - self.logger.debug(f'Starting process to call Add Voters') + self.logger.debug('Starting process to call Add Voters') process = multiprocessing.Process( target=_call_add_voters, kwargs={ @@ -159,19 +157,19 @@ def _add_voter_to_leader(self): else: time.sleep(1.0) if properly_closed: - self.logger.debug(f'Add Voters process finished') + self.logger.debug('Add Voters process finished') process.terminate() else: - self.logger.error(f'Add Voters process did not finish successfully') + self.logger.error('Add Voters process did not finish successfully') process.kill() - self.logger.debug(f'Add Voters process finished') + self.logger.debug('Add Voters process finished') async def _async_add_voter_to_leader(self): leader_address = f'{self._pods[0].runtime_ctrl_address}' voter_addresses = [pod.runtime_ctrl_address for pod in self._pods[1:]] replica_ids = [pod.args.replica_id for pod in self._pods[1:]] event_signal = multiprocessing.Event() - self.logger.debug(f'Starting process to call Add Voters') + self.logger.debug('Starting process to call Add Voters') process = multiprocessing.Process( target=_call_add_voters, kwargs={ @@ -192,10 +190,10 @@ async def _async_add_voter_to_leader(self): else: await asyncio.sleep(1.0) if properly_closed: - self.logger.debug(f'Add Voters process finished') + self.logger.debug('Add Voters process finished') process.terminate() else: - self.logger.error(f'Add Voters process did not finish successfully') + self.logger.error('Add Voters process did not finish successfully') process.kill() @property @@ -214,23 +212,23 @@ def join(self): pod.join() def wait_start_success(self): - self.logger.debug(f'Waiting for ReplicaSet to start successfully') + self.logger.debug('Waiting for ReplicaSet to start successfully') for pod in self._pods: pod.wait_start_success() # should this be done only when the cluster is started ? 
if self._pods[0].args.stateful: self._add_voter_to_leader() - self.logger.debug(f'ReplicaSet started successfully') + self.logger.debug('ReplicaSet started successfully') async def async_wait_start_success(self): - self.logger.debug(f'Waiting for ReplicaSet to start successfully') + self.logger.debug('Waiting for ReplicaSet to start successfully') await asyncio.gather( *[pod.async_wait_start_success() for pod in self._pods] ) # should this be done only when the cluster is started ? if self._pods[0].args.stateful: await self._async_add_voter_to_leader() - self.logger.debug(f'ReplicaSet started successfully') + self.logger.debug('ReplicaSet started successfully') def __enter__(self): for _args in self.args: @@ -481,16 +479,19 @@ def __init__( if self.args.provider == ProviderType.SAGEMAKER: if self._gateway_kwargs.get('port', 0) == 8080: raise ValueError( - f'Port 8080 is reserved for Sagemaker deployment. Please use another port' + 'Port 8080 is reserved for Sagemaker deployment. ' + 'Please use another port' ) if self.args.port != [8080]: warnings.warn( - f'Port is changed to 8080 for Sagemaker deployment. Port {self.args.port} is ignored' + 'Port is changed to 8080 for Sagemaker deployment. ' + f'Port {self.args.port} is ignored' ) self.args.port = [8080] if self.args.protocol != [ProtocolType.HTTP]: warnings.warn( - f'Protocol is changed to HTTP for Sagemaker deployment. Protocol {self.args.protocol} is ignored' + 'Protocol is changed to HTTP for Sagemaker deployment. ' + f'Protocol {self.args.protocol} is ignored' ) self.args.protocol = [ProtocolType.HTTP] if self._include_gateway and ProtocolType.HTTP in self.args.protocol: @@ -529,10 +530,10 @@ def __init__( if self.args.stateful and (is_windows_os or (is_mac_os and is_37)): if is_windows_os: - raise RuntimeError(f'Stateful feature is not available on Windows') + raise RuntimeError('Stateful feature is not available on Windows') if is_mac_os: raise RuntimeError( - f'Stateful feature when running on MacOS requires Python3.8 or newer version' + 'Stateful feature when running on MacOS requires Python3.8 or newer version' ) if self.args.stateful and ( ProtocolType.WEBSOCKET in self.args.protocol @@ -805,7 +806,7 @@ def _copy_to_head_args(args: Namespace) -> Namespace: if args.name: _head_args.name = f'{args.name}/head' else: - _head_args.name = f'head' + _head_args.name = 'head' return _head_args @@ -1209,7 +1210,7 @@ async def async_wait_start_success(self) -> None: coros.append(self.shards[shard_id].async_wait_start_success()) await asyncio.gather(*coros) - self.logger.debug(f'Deployment started successfully') + self.logger.debug('Deployment started successfully') except: self.close() raise @@ -1374,7 +1375,7 @@ def _set_pod_args(self) -> Dict[int, List[Namespace]]: peer_ports = peer_ports_all_shards.get(str(shard_id), []) if len(peer_ports) > 0 and len(peer_ports) != replicas: raise ValueError( - f'peer-ports argument does not match number of replicas, it will be ignored' + 'peer-ports argument does not match number of replicas, it will be ignored' ) elif len(peer_ports) == 0: peer_ports = [random_port() for _ in range(replicas)] @@ -1506,12 +1507,12 @@ def _parse_base_deployment_args(self, args): if self.args.stateful and self.args.replicas in [1, 2]: self.logger.debug( - f'Stateful Executor is not recommended to be used less than 3 replicas' + 'Stateful Executor is not recommended to be used less than 3 replicas' ) if self.args.stateful and self.args.workspace is None: raise ValueError( - f'Stateful Executors need to be provided 
`workspace` when used in a Deployment' + 'Stateful Executors need to be provided `workspace` when used in a Deployment' ) # a gateway has no heads and uses @@ -1568,7 +1569,7 @@ def _mermaid_str(self) -> List[str]: mermaid_graph = [] secret = '<secret>' if self.role != DeploymentRoleType.GATEWAY and not self.external: - mermaid_graph = [f'subgraph {self.name};', f'\ndirection LR;\n'] + mermaid_graph = [f'subgraph {self.name};', '\ndirection LR;\n'] uses_before_name = ( self.uses_before_args.name @@ -1596,7 +1597,7 @@ def _mermaid_str(self) -> List[str]: shard_names.append(shard_name) shard_mermaid_graph = [ f'subgraph {shard_name};', - f'\ndirection TB;\n', + '\ndirection TB;\n', ] names = [ args.name for args in pod_args diff --git a/jina/serve/runtimes/worker/http_sagemaker_app.py b/jina/serve/runtimes/worker/http_sagemaker_app.py index f095b8cfc9000..c2fc0df8de0d2 100644 --- a/jina/serve/runtimes/worker/http_sagemaker_app.py +++ b/jina/serve/runtimes/worker/http_sagemaker_app.py @@ -29,15 +29,14 @@ def get_fastapi_app( :return: fastapi app """ with ImportExtensions(required=True): - from fastapi import FastAPI, Response, HTTPException import pydantic + from fastapi import FastAPI, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware + from pydantic import BaseModel, Field + from pydantic.config import BaseConfig, inherit_config import os - from pydantic import BaseModel, Field - from pydantic.config import BaseConfig, inherit_config - from jina.proto import jina_pb2 from jina.serve.runtimes.gateway.models import _to_camel_case @@ -83,12 +82,25 @@ def add_post_route( path=f'/{endpoint_path.strip("/")}', methods=['POST'], summary=f'Endpoint {endpoint_path}', - response_model=output_model, + response_model=Union[output_model, List[output_model]], response_class=DocArrayResponse, ) - @app.api_route(**app_kwargs) - async def post(body: input_model, response: Response): + def is_valid_csv(content: str) -> bool: + import csv + from io import StringIO + + try: + f = StringIO(content) + reader = csv.DictReader(f) + for _ in reader: + pass + + return True + except Exception: + return False + + async def process(body) -> output_model: req = DataRequest() if body.header is not None: req.header.request_id = body.header.request_id @@ -96,13 +108,12 @@ async def post(body: input_model, response: Response): if body.parameters is not None: req.parameters = body.parameters req.header.exec_endpoint = endpoint_path + req.document_array_cls = DocList[input_doc_model] data = body.data if isinstance(data, list): - req.document_array_cls = DocList[input_doc_model] req.data.docs = DocList[input_doc_list_model](data) else: - req.document_array_cls = DocList[input_doc_model] req.data.docs = DocList[input_doc_list_model]([data]) if body.header is None: req.header.request_id = req.docs[0].id @@ -115,6 +126,48 @@ async def post(body: input_model, response: Response): else: return output_model(data=resp.docs, parameters=resp.parameters) + @app.api_route(**app_kwargs) + async def post(request: Request): + content_type = request.headers.get('content-type') + if content_type == 'application/json': + json_body = await request.json() + return await process(input_model(**json_body)) + + elif content_type in ('text/csv', 'application/csv'): + bytes_body = await request.body() + csv_body = bytes_body.decode('utf-8') + if not is_valid_csv(csv_body): + raise HTTPException( + status_code=400, + detail='Invalid CSV input. 
Please check your input.', + ) + + # NOTE: Sagemaker only supports csv files without header, so we enforce + # the header by getting the field names from the input model. + # This will also enforce the order of the fields in the csv file. + # This also means, all fields in the input model must be present in the + # csv file including the optional ones. + field_names = [f for f in input_doc_list_model.__fields__] + data = [] + for line in csv_body.splitlines(): + fields = line.split(',') + if len(fields) != len(field_names): + raise HTTPException( + status_code=400, + detail=f'Invalid CSV format. Line {fields} doesn\'t match ' + f'the expected field order {field_names}.', + ) + data.append(input_doc_list_model(**dict(zip(field_names, fields)))) + + return await process(input_model(data=data)) + + else: + raise HTTPException( + status_code=400, + detail=f'Invalid content-type: {content_type}. ' + f'Please use either application/json or text/csv.', + ) + for endpoint, input_output_map in request_models_map.items(): if endpoint != '_jina_dry_run_': input_doc_model = input_output_map['input']['model'] diff --git a/tests/integration/docarray_v2/sagemaker/Dockerfile b/tests/integration/docarray_v2/sagemaker/Dockerfile new file mode 100644 index 0000000000000..b04f20b347f45 --- /dev/null +++ b/tests/integration/docarray_v2/sagemaker/Dockerfile @@ -0,0 +1,7 @@ +FROM jinaai/jina:test-pip + +COPY . /executor_root/ + +WORKDIR /executor_root/SampleExecutor + +ENTRYPOINT ["jina", "executor", "--uses", "config.yml"] diff --git a/tests/integration/docarray_v2/sagemaker/SampleExecutor/executor.py b/tests/integration/docarray_v2/sagemaker/SampleExecutor/executor.py index 4bccb70388561..7e86906f9873e 100644 --- a/tests/integration/docarray_v2/sagemaker/SampleExecutor/executor.py +++ b/tests/integration/docarray_v2/sagemaker/SampleExecutor/executor.py @@ -24,5 +24,7 @@ class SampleExecutor(Executor): def foo(self, docs: DocList[TextDoc], **kwargs) -> DocList[EmbeddingResponseModel]: ret = [] for doc in docs: - ret.append(EmbeddingResponseModel(embeddings=np.random.random((1, 64)))) + ret.append( + EmbeddingResponseModel(id=doc.id, embeddings=np.random.random((1, 64))) + ) return DocList[EmbeddingResponseModel](ret) diff --git a/tests/integration/docarray_v2/sagemaker/invalid_input.csv b/tests/integration/docarray_v2/sagemaker/invalid_input.csv new file mode 100644 index 0000000000000..514f99c8a0fc9 --- /dev/null +++ b/tests/integration/docarray_v2/sagemaker/invalid_input.csv @@ -0,0 +1,10 @@ +abcd +efgh +ijkl +mnop +qrst +uvwx +yzab +cdef +ghij +klmn \ No newline at end of file diff --git a/tests/integration/docarray_v2/sagemaker/test_sagemaker.py b/tests/integration/docarray_v2/sagemaker/test_sagemaker.py index 2d1d0c6d88cb3..af15d241a27d8 100644 --- a/tests/integration/docarray_v2/sagemaker/test_sagemaker.py +++ b/tests/integration/docarray_v2/sagemaker/test_sagemaker.py @@ -1,13 +1,31 @@ import os +import time from contextlib import AbstractContextManager import pytest import requests from jina import Deployment +from jina.helper import random_port from jina.orchestrate.pods import Pod from jina.parsers import set_pod_parser +cur_dir = os.path.dirname(os.path.abspath(__file__)) +sagemaker_port = 8080 + + +@pytest.fixture +def replica_docker_image_built(): + import docker + + client = docker.from_env() + client.images.build(path=cur_dir, tag='sampler-executor') + client.close() + yield + time.sleep(2) + client = docker.from_env() + client.containers.prune() + class chdir(AbstractContextManager): def 
__init__(self, path): @@ -22,7 +40,7 @@ def __exit__(self, *excinfo): os.chdir(self._old_cwd.pop()) -def test_provider_sagemaker_pod(): +def test_provider_sagemaker_pod_inference(): with chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')): args, _ = set_pod_parser().parse_known_args( [ @@ -34,17 +52,15 @@ def test_provider_sagemaker_pod(): ] ) with Pod(args): - # provider=sagemaker would set the port to 8080 - port = 8080 # Test the `GET /ping` endpoint (added by jina for sagemaker) - resp = requests.get(f'http://localhost:{port}/ping') + resp = requests.get(f'http://localhost:{sagemaker_port}/ping') assert resp.status_code == 200 assert resp.json() == {} - # Test the `POST /invocations` endpoint + # Test the `POST /invocations` endpoint for inference # Note: this endpoint is not implemented in the sample executor resp = requests.post( - f'http://localhost:{port}/invocations', + f'http://localhost:{sagemaker_port}/invocations', json={ 'data': [ {'text': 'hello world'}, @@ -57,10 +73,77 @@ def test_provider_sagemaker_pod(): assert len(resp_json['data'][0]['embeddings'][0]) == 64 -def test_provider_sagemaker_deployment(): +def test_provider_sagemaker_pod_batch_transform_valid(): + with chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')): + args, _ = set_pod_parser().parse_known_args( + [ + '--uses', + 'config.yml', + '--provider', + 'sagemaker', + 'serve', # This is added by sagemaker + ] + ) + with Pod(args): + # Test `POST /invocations` endpoint for batch-transform with valid input + with open( + os.path.join(os.path.dirname(__file__), 'valid_input.csv'), 'r' + ) as f: + csv_data = f.read() + + resp = requests.post( + f'http://localhost:{sagemaker_port}/invocations', + headers={ + 'accept': 'application/json', + 'content-type': 'text/csv', + }, + data=csv_data, + ) + assert resp.status_code == 200 + resp_json = resp.json() + assert len(resp_json['data']) == 10 + for d in resp_json['data']: + assert len(d['embeddings'][0]) == 64 + + +def test_provider_sagemaker_pod_batch_transform_invalid(): + with chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')): + args, _ = set_pod_parser().parse_known_args( + [ + '--uses', + 'config.yml', + '--provider', + 'sagemaker', + 'serve', # This is added by sagemaker + ] + ) + with Pod(args): + # Test `POST /invocations` endpoint for batch-transform with invalid input + with open( + os.path.join(os.path.dirname(__file__), 'invalid_input.csv'), 'r' + ) as f: + csv_data = f.read() + + resp = requests.post( + f'http://localhost:{sagemaker_port}/invocations', + headers={ + 'accept': 'application/json', + 'content-type': 'text/csv', + }, + data=csv_data, + ) + assert resp.status_code == 400 + assert ( + resp.json()['detail'] + == "Invalid CSV format. Line ['abcd'] doesn't match the expected field " + "order ['id', 'text']." 
+ ) + + +def test_provider_sagemaker_deployment_inference(): with chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')): - dep_port = 12345 - with Deployment(uses='config.yml', provider='sagemaker', port=dep_port) as dep: + dep_port = random_port() + with Deployment(uses='config.yml', provider='sagemaker', port=dep_port): # Test the `GET /ping` endpoint (added by jina for sagemaker) rsp = requests.get(f'http://localhost:{dep_port}/ping') assert rsp.status_code == 200 @@ -82,8 +165,62 @@ def test_provider_sagemaker_deployment(): assert len(resp_json['data'][0]['embeddings'][0]) == 64 +def test_provider_sagemaker_deployment_inference_docker(replica_docker_image_built): + dep_port = random_port() + with Deployment( + uses='docker://sampler-executor', provider='sagemaker', port=dep_port + ): + # Test the `GET /ping` endpoint (added by jina for sagemaker) + rsp = requests.get(f'http://localhost:{dep_port}/ping') + assert rsp.status_code == 200 + assert rsp.json() == {} + + # Test the `POST /invocations` endpoint + # Note: this endpoint is not implemented in the sample executor + rsp = requests.post( + f'http://localhost:{dep_port}/invocations', + json={ + 'data': [ + {'text': 'hello world'}, + ] + }, + ) + assert rsp.status_code == 200 + resp_json = rsp.json() + assert len(resp_json['data']) == 1 + assert len(resp_json['data'][0]['embeddings'][0]) == 64 + + +@pytest.mark.skip('Sagemaker with Deployment for batch-transform is not supported yet') +def test_provider_sagemaker_deployment_batch(): + with chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')): + dep_port = random_port() + with Deployment(uses='config.yml', provider='sagemaker', port=dep_port): + # Test the `POST /invocations` endpoint for batch-transform + with open( + os.path.join(os.path.dirname(__file__), 'valid_input.csv'), 'r' + ) as f: + csv_data = f.read() + + rsp = requests.post( + f'http://localhost:{dep_port}/invocations', + headers={ + 'accept': 'application/json', + 'content-type': 'text/csv', + }, + data=csv_data, + ) + assert rsp.status_code == 200 + resp_json = rsp.json() + assert len(resp_json['data']) == 10 + for d in resp_json['data']: + assert len(d['embeddings'][0]) == 64 + + def test_provider_sagemaker_deployment_wrong_port(): + # Sagemaker executor would start on 8080. + # If we use the same port for deployment, it should raise an error. 
with chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')): with pytest.raises(ValueError): - with Deployment(uses='config.yml', provider='sagemaker', port=8080) as dep: + with Deployment(uses='config.yml', provider='sagemaker', port=8080): pass diff --git a/tests/integration/docarray_v2/sagemaker/valid_input.csv b/tests/integration/docarray_v2/sagemaker/valid_input.csv new file mode 100644 index 0000000000000..03bdaf5ff576e --- /dev/null +++ b/tests/integration/docarray_v2/sagemaker/valid_input.csv @@ -0,0 +1,10 @@ +1,abcd +2,efgh +3,ijkl +4,mnop +5,qrst +6,uvwx +7,yzab +8,cdef +9,ghij +10,klmn \ No newline at end of file diff --git a/tests/integration/docarray_v2/test_parameters_as_pydantic.py b/tests/integration/docarray_v2/test_parameters_as_pydantic.py index 60ea26eabc96c..94045b5a64e83 100644 --- a/tests/integration/docarray_v2/test_parameters_as_pydantic.py +++ b/tests/integration/docarray_v2/test_parameters_as_pydantic.py @@ -1,11 +1,13 @@ -import pytest from typing import Dict -from jina import Flow, Deployment, Executor, requests -from docarray import DocList, BaseDoc + +import pytest +from docarray import BaseDoc, DocList from docarray.documents import TextDoc -from jina.helper import random_port from pydantic import BaseModel +from jina import Deployment, Executor, Flow, requests +from jina.helper import random_port + @pytest.mark.parametrize('protocol', ['grpc', 'http', 'websocket']) @pytest.mark.parametrize('ctxt_manager', ['deployment', 'flow']) @@ -20,7 +22,9 @@ class Parameters(BaseModel): class FooParameterExecutor(Executor): @requests(on='/hello') - def foo(self, docs: DocList[TextDoc], parameters: Parameters, **kwargs) -> DocList[TextDoc]: + def foo( + self, docs: DocList[TextDoc], parameters: Parameters, **kwargs + ) -> DocList[TextDoc]: for doc in docs: doc.text += f'Processed by foo with param: {parameters.param} and num: {parameters.num}' @@ -33,7 +37,11 @@ def bar(self, doc: TextDoc, parameters: Parameters, **kwargs) -> TextDoc: else: ctxt_mgr = Deployment(protocol=protocol, uses=FooParameterExecutor) - params_to_send = {'param': 'value'} if parameters_in_client == 'dict' else Parameters(param='value') + params_to_send = ( + {'param': 'value'} + if parameters_in_client == 'dict' + else Parameters(param='value') + ) with ctxt_mgr: ret = ctxt_mgr.post( on='/hello', @@ -52,23 +60,31 @@ def bar(self, doc: TextDoc, parameters: Parameters, **kwargs) -> TextDoc: assert ret[0].text == 'Processed by bar with param: value and num: 5' if protocol == 'http': import requests as global_requests + for endpoint in {'hello', 'hello_single'}: processed_by = 'foo' if endpoint == 'hello' else 'bar' url = f'http://localhost:{ctxt_mgr.port}/{endpoint}' myobj = {'data': {'text': ''}, 'parameters': {'param': 'value'}} resp = global_requests.post(url, json=myobj) resp_json = resp.json() - assert resp_json['data'][0]['text'] == f'Processed by {processed_by} with param: value and num: 5' + assert ( + resp_json['data'][0]['text'] + == f'Processed by {processed_by} with param: value and num: 5' + ) myobj = {'data': [{'text': ''}], 'parameters': {'param': 'value'}} resp = global_requests.post(url, json=myobj) resp_json = resp.json() - assert resp_json['data'][0]['text'] == f'Processed by {processed_by} with param: value and num: 5' + assert ( + resp_json['data'][0]['text'] + == f'Processed by {processed_by} with param: value and num: 5' + ) @pytest.mark.parametrize('protocol', ['http', 'websocket', 'grpc']) @pytest.mark.parametrize('ctxt_manager', ['deployment', 'flow']) def 
test_parameters_invalid(protocol, ctxt_manager): - if ctxt_manager == 'deployment' and protocol == 'websocket': + # TODO: websocket test is failing with flow as well + if protocol == 'websocket': return class Parameters(BaseModel): @@ -77,7 +93,9 @@ class Parameters(BaseModel): class FooInvalidParameterExecutor(Executor): @requests(on='/hello') - def foo(self, docs: DocList[TextDoc], parameters: Parameters, **kwargs) -> DocList[TextDoc]: + def foo( + self, docs: DocList[TextDoc], parameters: Parameters, **kwargs + ) -> DocList[TextDoc]: for doc in docs: doc.text += f'Processed by foo with param: {parameters.param} and num: {parameters.num}' @@ -112,7 +130,9 @@ class ParametersFirst(BaseModel): class Exec1Chain(Executor): @requests(on='/bar') - def bar(self, docs: DocList[Input1], parameters: ParametersFirst, **kwargs) -> DocList[Output1]: + def bar( + self, docs: DocList[Input1], parameters: ParametersFirst, **kwargs + ) -> DocList[Output1]: docs_return = DocList[Output1]( [Output1(price=5 * parameters.mult) for _ in range(len(docs))] ) @@ -122,16 +142,18 @@ class Exec2Chain(Executor): @requests(on='/bar') def bar(self, docs: DocList[Output1], **kwargs) -> DocList[Output2]: docs_return = DocList[Output2]( - [ - Output2(a=f'final price {docs[0].price}') - for _ in range(len(docs)) - ] + [Output2(a=f'final price {docs[0].price}') for _ in range(len(docs))] ) return docs_return f = Flow(protocol=protocol).add(uses=Exec1Chain).add(uses=Exec2Chain) with f: - docs = f.post(on='/bar', inputs=Input1(text='ignored'), parameters={'mult': 10}, return_type=DocList[Output2]) + docs = f.post( + on='/bar', + inputs=Input1(text='ignored'), + parameters={'mult': 10}, + return_type=DocList[Output2], + ) assert docs[0].a == 'final price 50' @@ -152,14 +174,14 @@ class ParametersSecond(BaseModel): class Exec1Chain(Executor): @requests(on='/bar') def bar(self, docs: DocList[Input1], **kwargs) -> DocList[Output1]: - docs_return = DocList[Output1]( - [Output1(price=5) for _ in range(len(docs))] - ) + docs_return = DocList[Output1]([Output1(price=5) for _ in range(len(docs))]) return docs_return class Exec2Chain(Executor): @requests(on='/bar') - def bar(self, docs: DocList[Output1], parameters: ParametersSecond, **kwargs) -> DocList[Output2]: + def bar( + self, docs: DocList[Output1], parameters: ParametersSecond, **kwargs + ) -> DocList[Output2]: docs_return = DocList[Output2]( [ Output2(a=f'final price {docs[0].price * parameters.mult}') @@ -170,7 +192,12 @@ def bar(self, docs: DocList[Output1], parameters: ParametersSecond, **kwargs) -> f = Flow(protocol=protocol).add(uses=Exec1Chain).add(uses=Exec2Chain) with f: - docs = f.post(on='/bar', inputs=Input1(text='ignored'), parameters={'mult': 10}, return_type=DocList[Output2]) + docs = f.post( + on='/bar', + inputs=Input1(text='ignored'), + parameters={'mult': 10}, + return_type=DocList[Output2], + ) assert docs[0].a == 'final price 50' @@ -179,27 +206,33 @@ def bar(self, docs: DocList[Output1], parameters: ParametersSecond, **kwargs) -> def test_openai(ctxt_manager, include_gateway): if ctxt_manager == 'flow' and include_gateway: return - import string import random + import string random_example = ''.join(random.choices(string.ascii_letters, k=10)) random_description = ''.join(random.choices(string.ascii_letters, k=10)) - from pydantic.fields import Field from pydantic import BaseModel + from pydantic.fields import Field + class MyDocWithExample(BaseDoc): """This test should be in description""" + t: str = Field(examples=[random_example], 
description=random_description) + class Config: title: str = 'MyDocWithExampleTitle' schema_extra: Dict = {'extra_key': 'extra_value'} class MyConfigParam(BaseModel): """Configuration for Executor endpoint""" + param1: int = Field(description='batch size', example=256) class MyExecDocWithExample(Executor): @requests - def foo(self, docs: DocList[MyDocWithExample], parameters: MyConfigParam, **kwargs) -> DocList[MyDocWithExample]: + def foo( + self, docs: DocList[MyDocWithExample], parameters: MyConfigParam, **kwargs + ) -> DocList[MyDocWithExample]: pass port = random_port() @@ -207,10 +240,16 @@ def foo(self, docs: DocList[MyDocWithExample], parameters: MyConfigParam, **kwar if ctxt_manager == 'flow': ctxt = Flow(protocol='http', port=port).add(uses=MyExecDocWithExample) else: - ctxt = Deployment(uses=MyExecDocWithExample, protocol='http', port=port, include_gateway=include_gateway) + ctxt = Deployment( + uses=MyExecDocWithExample, + protocol='http', + port=port, + include_gateway=include_gateway, + ) with ctxt: import requests as general_requests + resp = general_requests.get(f'http://localhost:{port}/openapi.json') resp_str = str(resp.json()) assert random_example in resp_str @@ -222,4 +261,3 @@ def foo(self, docs: DocList[MyDocWithExample], parameters: MyConfigParam, **kwar assert 'Configuration for Executor endpoint' in resp_str assert 'batch size' in resp_str assert '256' in resp_str -
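
A minimal client-side sketch, not part of the patch, of how the SampleExecutor from the tests is served with provider='sagemaker': any free port except the reserved 8080 works (passing port=8080 raises ValueError, see test_provider_sagemaker_deployment_wrong_port), GET /ping answers the SageMaker health check, and JSON inference goes through POST /invocations. Assumes it is run from the SampleExecutor directory so that config.yml resolves; port 12345 is an arbitrary free port.

import requests

from jina import Deployment

with Deployment(uses='config.yml', provider='sagemaker', port=12345):
    # Health-check endpoint added by jina for SageMaker.
    assert requests.get('http://localhost:12345/ping').json() == {}

    # Real-time inference: JSON body against POST /invocations.
    resp = requests.post(
        'http://localhost:12345/invocations',
        json={'data': [{'text': 'hello world'}]},
    )
    print(resp.json()['data'][0]['embeddings'])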
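The batch-transform path added in http_sagemaker_app.py accepts headerless CSV on the same endpoint: one document per row, columns in the exact order of the input model's fields (['id', 'text'] for the SampleExecutor), every field present including optional ones; a row with a different column count is rejected with a 400. A sketch of such a request, assuming an executor is serving with --provider sagemaker on port 8080 as in the pod-level batch-transform test (batch transform through Deployment is not supported yet, per the skipped test):

import requests

# Two documents, columns ordered as ['id', 'text'] -- the field order of the
# input model, since SageMaker sends CSV without a header row.
csv_body = '1,hello world\n2,goodbye world'

resp = requests.post(
    'http://localhost:8080/invocations',
    headers={'accept': 'application/json', 'content-type': 'text/csv'},
    data=csv_body,
)
assert resp.status_code == 200
for doc in resp.json()['data']:
    print(doc['embeddings'])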