From 6cd3311427d367cd134df062c7e2a4bc18b62a6c Mon Sep 17 00:00:00 2001 From: Joan Fontanals Date: Tue, 21 Nov 2023 17:08:20 +0100 Subject: [PATCH] refactor: remove warmup (#6114) Signed-off-by: Joan Fontanals Martinez Co-authored-by: Joan Fontanals Martinez --- jina/serve/networking/__init__.py | 65 ------------ jina/serve/networking/replica_list.py | 16 --- .../runtimes/gateway/request_handling.py | 26 ----- jina/serve/runtimes/gateway/streamer.py | 32 ------ jina/serve/runtimes/head/request_handling.py | 42 -------- jina/serve/runtimes/servers/__init__.py | 1 - .../test_flow_instrumentation.py | 4 +- tests/integration/runtimes/test_warmup.py | 100 ------------------ .../serve/networking/test_replica_list.py | 4 - 9 files changed, 2 insertions(+), 288 deletions(-) delete mode 100644 tests/integration/runtimes/test_warmup.py diff --git a/jina/serve/networking/__init__.py b/jina/serve/networking/__init__.py index f6a089b43d15c..cdfd907359078 100644 --- a/jina/serve/networking/__init__.py +++ b/jina/serve/networking/__init__.py @@ -615,71 +615,6 @@ async def task_coroutine(): return task_coroutine() - async def warmup( - self, - deployment: str, - stop_event: 'threading.Event', - ): - """Executes JinaInfoRPC against the provided deployment. A single task is created for each replica connection. - :param deployment: deployment name and the replicas that needs to be warmed up. - :param stop_event: signal to indicate if an early termination of the task is required for graceful teardown. - """ - self._logger.debug(f'starting warmup task for deployment {deployment}') - - async def task_wrapper(target_warmup_responses, stub): - try: - call_result = stub.send_info_rpc(timeout=0.5) - await call_result - target_warmup_responses[stub.address] = True - except asyncio.CancelledError: - self._logger.debug(f'warmup task got cancelled') - target_warmup_responses[stub.address] = False - raise - except Exception: - target_warmup_responses[stub.address] = False - - - try: - start_time = time.time() - timeout = start_time + 60 * 5 # 5 minutes from now - warmed_up_targets = set() - replicas = self._get_all_replicas(deployment) - - while not stop_event.is_set(): - replica_warmup_responses = {} - tasks = [] - try: - for replica in replicas: - for stub in replica.warmup_stubs: - if stub.address not in warmed_up_targets: - tasks.append( - asyncio.create_task( - task_wrapper(replica_warmup_responses, stub) - ) - ) - - await asyncio.gather(*tasks, return_exceptions=True) - for target, response in replica_warmup_responses.items(): - if response: - warmed_up_targets.add(target) - - now = time.time() - if now > timeout or all(list(replica_warmup_responses.values())): - self._logger.debug( - f'completed warmup task in {now - start_time}s.' - ) - return - await asyncio.sleep(0.2) - except asyncio.CancelledError: - self._logger.debug(f'warmup task got cancelled') - if tasks: - for task in tasks: - task.cancel() - raise - except Exception as ex: - self._logger.error(f'error with warmup up task: {ex}') - return - def _get_all_replicas(self, deployment): replica_set = set() replica_set.update(self._connections.get_replicas_all_shards(deployment)) diff --git a/jina/serve/networking/replica_list.py b/jina/serve/networking/replica_list.py index ece40bc9fff19..e23a07aaaae96 100644 --- a/jina/serve/networking/replica_list.py +++ b/jina/serve/networking/replica_list.py @@ -45,9 +45,6 @@ def __init__( self.tracing_client_interceptors = tracing_client_interceptor self._deployment_name = deployment_name self.channel_options = channel_options - # a set containing all the ConnectionStubs that will be created using add_connection - # this set is not updated in reset_connection and remove_connection - self._warmup_stubs = set() async def reset_connection(self, address: str, deployment_name: str): """ @@ -90,10 +87,7 @@ def add_connection(self, address: str, deployment_name: str): stubs, channel = self._create_connection(address, deployment_name) self._address_to_channel[resolved_address] = channel self._connections.append(stubs) - # create a new set of stubs and channels for warmup to avoid - # loosing channel during remove_connection or reset_connection stubs, _ = self._create_connection(address, deployment_name) - self._warmup_stubs.add(stubs) async def remove_connection(self, address: str): """ @@ -213,13 +207,3 @@ async def close(self): self._address_to_connection_idx.clear() self._connections.clear() self._rr_counter = 0 - for stub in self._warmup_stubs: - await stub.channel.close(0.5) - self._warmup_stubs.clear() - - @property - def warmup_stubs(self): - """Return set of warmup stubs - :returns: Set of stubs. The set doesn't remove any items once added. - """ - return self._warmup_stubs diff --git a/jina/serve/runtimes/gateway/request_handling.py b/jina/serve/runtimes/gateway/request_handling.py index 1e4a47a5f2410..ec3b936c0c23e 100644 --- a/jina/serve/runtimes/gateway/request_handling.py +++ b/jina/serve/runtimes/gateway/request_handling.py @@ -30,7 +30,6 @@ def __init__( meter=None, aio_tracing_client_interceptors=None, tracing_client_interceptor=None, - works_as_load_balancer: bool = False, **kwargs, ): import json @@ -102,37 +101,12 @@ def __init__( if isinstance(addresses, Dict): servers.extend(addresses.get(ProtocolType.HTTP.to_string(), [])) self.load_balancer_servers = itertools.cycle(servers) - self.warmup_stop_event = threading.Event() - self.warmup_task = None - if not works_as_load_balancer: - try: - self.warmup_task = asyncio.create_task( - self.streamer.warmup(self.warmup_stop_event) - ) - except RuntimeError: - # when Gateway is started locally, it may not have loop - pass - - def cancel_warmup_task(self): - """Cancel warmup task if exists and is not completed. Cancellation is required if the Flow is being terminated before the - task is successful or hasn't reached the max timeout. - """ - if self.warmup_task: - try: - if not self.warmup_task.done(): - self.logger.debug(f'Cancelling warmup task.') - self.warmup_stop_event.set() # this event is useless if simply cancel - self.warmup_task.cancel() - except Exception as ex: - self.logger.debug(f'exception during warmup task cancellation: {ex}') - pass async def close(self): """ Gratefully closes the object making sure all the floating requests are taken care and the connections are closed gracefully """ self.logger.debug(f'Closing Request Handler') - self.cancel_warmup_task() await self.streamer.close() self.logger.debug(f'Request Handler closed') diff --git a/jina/serve/runtimes/gateway/streamer.py b/jina/serve/runtimes/gateway/streamer.py index d024f5a07bc06..be31358d266d8 100644 --- a/jina/serve/runtimes/gateway/streamer.py +++ b/jina/serve/runtimes/gateway/streamer.py @@ -427,38 +427,6 @@ def get_streamer(): def _set_env_streamer_args(**kwargs): os.environ['JINA_STREAMER_ARGS'] = json.dumps(kwargs) - async def warmup(self, stop_event: threading.Event): - """Executes warmup task on each deployment. This forces the gateway to establish connection and open a - gRPC channel to each executor so that the first request doesn't need to experience the penalty of - eastablishing a brand new gRPC channel. - :param stop_event: signal to indicate if an early termination of the task is required for graceful teardown. - """ - self.logger.debug(f'Running GatewayRuntime warmup') - deployments = {key for key in self._executor_addresses.keys()} - - try: - deployment_warmup_tasks = [] - try: - for deployment in deployments: - deployment_warmup_tasks.append( - asyncio.create_task( - self._connection_pool.warmup( - deployment=deployment, stop_event=stop_event - ) - ) - ) - - await asyncio.gather(*deployment_warmup_tasks, return_exceptions=True) - except asyncio.CancelledError: - self.logger.debug(f'Warmup task got cancelled') - if deployment_warmup_tasks: - for task in deployment_warmup_tasks: - task.cancel() - raise - except Exception as ex: - self.logger.error(f'error with GatewayRuntime warmup up task: {ex}') - return - class _ExecutorStreamer: def __init__(self, connection_pool: GrpcConnectionPool, executor_name: str) -> None: diff --git a/jina/serve/runtimes/head/request_handling.py b/jina/serve/runtimes/head/request_handling.py index ecdb393f7b5c6..6bab920f03ca6 100644 --- a/jina/serve/runtimes/head/request_handling.py +++ b/jina/serve/runtimes/head/request_handling.py @@ -139,14 +139,6 @@ def __init__( self._executor_endpoint_mapping = None self._gathering_endpoints = False self.runtime_name = runtime_name - self.warmup_stop_event = threading.Event() - self.warmup_task = asyncio.create_task( - self.warmup( - connection_pool=self.connection_pool, - stop_event=self.warmup_stop_event, - deployment=self._deployment_name, - ) - ) self._pydantic_models_by_endpoint = None self.endpoints_discovery_stop_event = threading.Event() self.endpoints_discovery_task = None @@ -383,39 +375,6 @@ async def task(): return task() - async def warmup( - self, - connection_pool: GrpcConnectionPool, - stop_event: 'threading.Event', - deployment: str, - ): - """Executes warmup task against the deployments from the connection pool. - :param connection_pool: GrpcConnectionPool that implements the warmup to the connected deployments. - :param stop_event: signal to indicate if an early termination of the task is required for graceful teardown. - :param deployment: deployment name that need to be warmed up. - """ - self.logger.debug(f'Running HeadRuntime warmup') - - try: - await connection_pool.warmup(deployment=deployment, stop_event=stop_event) - except Exception as ex: - self.logger.error(f'error with HeadRuntime warmup up task: {ex}') - return - - def cancel_warmup_task(self): - """Cancel warmup task if exists and is not completed. Cancellation is required if the Flow is being terminated before the - task is successful or hasn't reached the max timeout. - """ - if self.warmup_task: - try: - if not self.warmup_task.done(): - self.logger.debug(f'Cancelling warmup task.') - self.warmup_stop_event.set() # this event is useless if simply cancel - self.warmup_task.cancel() - except Exception as ex: - self.logger.debug(f'exception during warmup task cancellation: {ex}') - pass - def cancel_endpoint_discovery_from_workers_task(self): """Cancel endpoint_discovery_from_worker task if exists and is not completed. Cancellation is required if the Flow is being terminated before the task is successful or hasn't reached the max timeout. @@ -433,7 +392,6 @@ def cancel_endpoint_discovery_from_workers_task(self): async def close(self): """Close the data request handler, by closing the executor and the batch queues.""" self.logger.debug(f'Closing Request Handler') - self.cancel_warmup_task() self.cancel_endpoint_discovery_from_workers_task() await self.connection_pool.close() self.logger.debug(f'Request Handler closed') diff --git a/jina/serve/runtimes/servers/__init__.py b/jina/serve/runtimes/servers/__init__.py index 4a03ef71651c8..e335739bada59 100644 --- a/jina/serve/runtimes/servers/__init__.py +++ b/jina/serve/runtimes/servers/__init__.py @@ -97,7 +97,6 @@ def _get_request_handler(self): aio_tracing_client_interceptors=self.aio_tracing_client_interceptors(), tracing_client_interceptor=self.tracing_client_interceptor(), deployment_name=self.name.split('/')[0], - works_as_load_balancer=self.works_as_load_balancer, ) def _add_gateway_args(self): diff --git a/tests/integration/instrumentation/test_flow_instrumentation.py b/tests/integration/instrumentation/test_flow_instrumentation.py index f5e221bbfe696..27f27c9389b80 100644 --- a/tests/integration/instrumentation/test_flow_instrumentation.py +++ b/tests/integration/instrumentation/test_flow_instrumentation.py @@ -135,8 +135,8 @@ def test_multiprotocol_gateway_instrumentation( (server_spans, client_spans, executor_spans) = partition_spans_by_kind( gateway_traces ) - assert len(client_spans) == 11 - assert len(server_spans) == 12 + assert len(client_spans) == 9 + assert len(server_spans) == 10 def test_executor_instrumentation(jaeger_port, otlp_collector, otlp_receiver_port): diff --git a/tests/integration/runtimes/test_warmup.py b/tests/integration/runtimes/test_warmup.py deleted file mode 100644 index 33c5e308c9d7e..0000000000000 --- a/tests/integration/runtimes/test_warmup.py +++ /dev/null @@ -1,100 +0,0 @@ -import time - -import pytest - -from jina import Executor, Flow - -SLOW_EXECUTOR_SLEEP_TIME = 3 - - -class SlowExecutor(Executor): - def __init__(self, **kwargs): - super().__init__(**kwargs) - time.sleep(SLOW_EXECUTOR_SLEEP_TIME) - - -@pytest.mark.asyncio -@pytest.mark.parametrize('protocol', ['grpc', 'http', 'websocket']) -async def test_gateway_warmup_fast_executor(protocol, capfd): - flow = Flow(protocol=protocol).add() - - with flow: - time.sleep(1) - out, _ = capfd.readouterr() - assert 'recv _status' in out - assert out.count('recv _status') == 1 - - -@pytest.mark.asyncio -@pytest.mark.parametrize('protocol', ['grpc', 'http', 'websocket']) -async def test_gateway_warmup_with_replicas_and_shards(protocol, capfd): - flow = ( - Flow(protocol=protocol) - .add(name='executor0', shards=2) - .add(name='executor1', replicas=2) - ) - - with flow: - time.sleep(1) - out, _ = capfd.readouterr() - assert 'recv _status' in out - # 2 calls from gateway runtime to deployments - # 2 calls from head to shards - # 1 call from the gateway to the head runtime warmup adds an additional call to any shard - assert out.count('recv _status') == 5 - - -@pytest.mark.asyncio -@pytest.mark.parametrize('protocol', ['grpc', 'http', 'websocket']) -async def test_gateway_warmup_slow_executor(protocol, capfd): - flow = Flow(protocol=protocol).add(name='slowExecutor', uses='SlowExecutor') - - with flow: - # requires high sleep time to account for Flow readiness and properly capture the output logs - time.sleep(SLOW_EXECUTOR_SLEEP_TIME * 3) - out, _ = capfd.readouterr() - assert 'recv _status' in out - assert out.count('recv _status') == 1 - - -@pytest.mark.asyncio -async def test_multi_protocol_gateway_warmup_fast_executor(port_generator, capfd): - http_port = port_generator() - grpc_port = port_generator() - websocket_port = port_generator() - flow = ( - Flow() - .config_gateway( - port=[http_port, grpc_port, websocket_port], - protocol=['http', 'grpc', 'websocket'], - ) - .add() - ) - - with flow: - time.sleep(1) - out, _ = capfd.readouterr() - assert 'recv _status' in out - assert out.count('recv _status') == 1 - - -@pytest.mark.asyncio -async def test_multi_protocol_gateway_warmup_slow_executor(port_generator, capfd): - http_port = port_generator() - grpc_port = port_generator() - websocket_port = port_generator() - flow = ( - Flow() - .config_gateway( - port=[http_port, grpc_port, websocket_port], - protocol=['http', 'grpc', 'websocket'], - ) - .add(name='slowExecutor', uses='SlowExecutor') - ) - - with flow: - # requires high sleep time to account for Flow readiness and properly capture the output logs - time.sleep(SLOW_EXECUTOR_SLEEP_TIME * 3) - out, _ = capfd.readouterr() - assert 'recv _status' in out - assert out.count('recv _status') == 1 diff --git a/tests/unit/serve/networking/test_replica_list.py b/tests/unit/serve/networking/test_replica_list.py index 4d0cc580b7fa3..3852156d42ad1 100644 --- a/tests/unit/serve/networking/test_replica_list.py +++ b/tests/unit/serve/networking/test_replica_list.py @@ -22,7 +22,6 @@ def test_add_connection(replica_list): replica_list.add_connection('executor0', 'executor-0') assert replica_list.has_connections() assert replica_list.has_connection('executor0') - assert len(replica_list.warmup_stubs) assert not replica_list.has_connection('random-address') assert len(replica_list.get_all_connections()) == 1 @@ -34,8 +33,6 @@ async def test_remove_connection(replica_list): await replica_list.remove_connection('executor0') assert not replica_list.has_connections() assert not replica_list.has_connection('executor0') - # warmup stubs are not updated in the remove_connection method - assert len(replica_list.warmup_stubs) # unknown/unmanaged connections removed_connection_invalid = await replica_list.remove_connection('random-address') assert removed_connection_invalid is None @@ -64,7 +61,6 @@ async def test_close(replica_list): assert replica_list.has_connection('executor1') await replica_list.close() assert not replica_list.has_connections() - assert not len(replica_list.warmup_stubs) async def _print_channel_attributes(connection_stub: _ConnectionStubs):