diff --git a/jina/serve/runtimes/worker/batch_queue.py b/jina/serve/runtimes/worker/batch_queue.py
index 572fcf6fe2743..9368e81b7d8fd 100644
--- a/jina/serve/runtimes/worker/batch_queue.py
+++ b/jina/serve/runtimes/worker/batch_queue.py
@@ -1,4 +1,5 @@
 import asyncio
+import copy
 from asyncio import Event, Task
 from typing import Callable, Dict, List, Optional, TYPE_CHECKING
 from jina._docarray import docarray_v2
@@ -63,6 +64,7 @@ def _reset(self) -> None:
             self._big_doc = self._request_docarray_cls()

         self._flush_task: Optional[Task] = None
+        self._flush_trigger: Event = Event()

     def _cancel_timer_if_pending(self):
         if (
@@ -102,20 +104,19 @@ async def push(self, request: DataRequest, http = False) -> asyncio.Queue:
         # this push requests the data lock. The order of accessing the data lock guarantees that this request will be put in the `big_doc`
         # before the `flush` task processes it.
         self._start_timer()
-        async with self._data_lock:
-            if not self._flush_task:
-                self._flush_task = asyncio.create_task(self._await_then_flush(http))
-
-            self._big_doc.extend(docs)
-            next_req_idx = len(self._requests)
-            num_docs = len(docs)
-            self._request_idxs.extend([next_req_idx] * num_docs)
-            self._request_lens.append(len(docs))
-            self._requests.append(request)
-            queue = asyncio.Queue()
-            self._requests_completed.append(queue)
-            if len(self._big_doc) >= self._preferred_batch_size:
-                self._flush_trigger.set()
+        if not self._flush_task:
+            self._flush_task = asyncio.create_task(self._await_then_flush(http))
+
+        self._big_doc.extend(docs)
+        next_req_idx = len(self._requests)
+        num_docs = len(docs)
+        self._request_idxs.extend([next_req_idx] * num_docs)
+        self._request_lens.append(len(docs))
+        self._requests.append(request)
+        queue = asyncio.Queue()
+        self._requests_completed.append(queue)
+        if len(self._big_doc) >= self._preferred_batch_size:
+            self._flush_trigger.set()

         return queue
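
The `push()` hunk above is the heart of the change: all of these mutations happen on a single asyncio event loop with no `await` between them, so the flush task can only observe the queue state before or after a complete push, never in the middle, and the explicit `_data_lock` adds nothing. A minimal sketch of the same handshake (hypothetical `MiniQueue`, not Jina's actual class):

```python
import asyncio


class MiniQueue:
    """Toy stand-in for the batch queue's push/flush handshake."""

    def __init__(self, preferred_batch_size: int = 3):
        self._preferred_batch_size = preferred_batch_size
        self._docs = []
        self._flush_trigger = asyncio.Event()
        self._flush_task = None

    async def push(self, doc):
        # No lock needed: nothing below awaits, so these statements run
        # atomically with respect to the flush task.
        if self._flush_task is None:
            self._flush_task = asyncio.create_task(self._flush())
        self._docs.append(doc)
        if len(self._docs) >= self._preferred_batch_size:
            self._flush_trigger.set()

    async def _flush(self):
        await self._flush_trigger.wait()
        print('flushing', self._docs)


async def main():
    queue = MiniQueue()
    for i in range(3):
        await queue.push(i)
    await queue._flush_task  # prints: flushing [0, 1, 2]


asyncio.run(main())
```
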
@@ -128,6 +129,7 @@ def _get_docs_groups_completed_request_indexes(
             non_assigned_docs,
             non_assigned_docs_reqs_idx,
             sum_from_previous_mini_batch_in_first_req_idx,
+            requests_lens_in_batch,
         ):
             """
             This method groups all the `non_assigned_docs` into groups of docs according to the `req_idx` they belong to.
@@ -136,6 +138,7 @@ def _get_docs_groups_completed_request_indexes(
             :param non_assigned_docs: The documents that have already been processed but have not been assigned to a request result
             :param non_assigned_docs_reqs_idx: The request IDX that are not yet completed (not all of its docs have been processed)
             :param sum_from_previous_mini_batch_in_first_req_idx: The number of docs from previous iteration that belong to the first non_assigned_req_idx. This is useful to make sure we know when a request is completed.
+            :param requests_lens_in_batch: List of lengths of the documents for each request in the batch.
             :return: list of document groups and a list of request Idx to which each of these groups belong
             """
@@ -164,7 +167,7 @@ def _get_docs_groups_completed_request_indexes(
                 if (
                     req_idx not in completed_req_idx
                     and num_docs_in_req_idx + sum_from_previous_mini_batch_in_first_req_idx
-                    == self._request_lens[req_idx]
+                    == requests_lens_in_batch[req_idx]
                 ):
                     completed_req_idx.append(req_idx)
                     request_bucket = non_assigned_docs[
@@ -178,6 +181,9 @@ async def _assign_results(
             non_assigned_docs,
             non_assigned_docs_reqs_idx,
             sum_from_previous_mini_batch_in_first_req_idx,
+            requests_lens_in_batch,
+            requests_in_batch,
+            requests_completed_in_batch,
         ):
             """
             This method aims to assign to the corresponding request objects the resulting documents from the mini batches.
@@ -187,6 +193,9 @@ async def _assign_results(
             :param non_assigned_docs: The documents that have already been processed but have not been assigned to a request result
             :param non_assigned_docs_reqs_idx: The request IDX that are not yet completed (not all of its docs have been processed)
             :param sum_from_previous_mini_batch_in_first_req_idx: The number of docs from previous iteration that belong to the first non_assigned_req_idx. This is useful to make sure we know when a request is completed.
+            :param requests_lens_in_batch: List of lengths of the documents for each request in the batch.
+            :param requests_in_batch: List of requests in the batch
+            :param requests_completed_in_batch: List of queues for the requests to be completed
             :return: amount of assigned documents so that some documents can come back in the next iteration
             """
@@ -197,12 +206,13 @@ async def _assign_results(
                 non_assigned_docs,
                 non_assigned_docs_reqs_idx,
                 sum_from_previous_mini_batch_in_first_req_idx,
+                requests_lens_in_batch
             )
             num_assigned_docs = sum(len(group) for group in docs_grouped)

             for docs_group, request_idx in zip(docs_grouped, completed_req_idxs):
-                request = self._requests[request_idx]
-                request_completed = self._requests_completed[request_idx]
+                request = requests_in_batch[request_idx]
+                request_completed = requests_completed_in_batch[request_idx]
                 if http is False or self._output_array_type is not None:
                     request.direct_docs = None  # batch queue will work in place, therefore result will need to read from data.
                     request.data.set_docs_convert_arrays(
@@ -226,91 +236,100 @@ def batch(iterable_1, iterable_2, n:Optional[int] = 1):

         await self._flush_trigger.wait()
         # writes to shared data between tasks need to be mutually exclusive
-        async with self._data_lock:
-            # At this moment, we have documents concatenated in self._big_doc corresponding to requests in
-            # self._requests with its lengths stored in self._requests_len. For each requests, there is a queue to
-            # communicate that the request has been processed properly. At this stage the data_lock is ours and
-            # therefore no-one can add requests to this list.
-            self._flush_trigger: Event = Event()
-            try:
-                if not docarray_v2:
-                    non_assigned_to_response_docs: DocumentArray = DocumentArray.empty()
-                else:
-                    non_assigned_to_response_docs = self._response_docarray_cls()
+        big_doc_in_batch = copy.copy(self._big_doc)
+        requests_idxs_in_batch = copy.copy(self._request_idxs)
+        requests_lens_in_batch = copy.copy(self._request_lens)
+        requests_in_batch = copy.copy(self._requests)
+        requests_completed_in_batch = copy.copy(self._requests_completed)

-                non_assigned_to_response_request_idxs = []
-                sum_from_previous_first_req_idx = 0
-                for docs_inner_batch, req_idxs in batch(
-                    self._big_doc, self._request_idxs, self._preferred_batch_size if not self._flush_all else None
-                ):
-                    involved_requests_min_indx = req_idxs[0]
-                    involved_requests_max_indx = req_idxs[-1]
-                    input_len_before_call: int = len(docs_inner_batch)
-                    batch_res_docs = None
-                    try:
-                        batch_res_docs = await self.func(
-                            docs=docs_inner_batch,
-                            parameters=self.params,
-                            docs_matrix=None,  # joining manually with batch queue is not supported right now
-                            tracing_context=None,
-                        )
-                        # Output validation
-                        if (docarray_v2 and isinstance(batch_res_docs, DocList)) or (
-                            not docarray_v2
-                            and isinstance(batch_res_docs, DocumentArray)
-                        ):
-                            if not len(batch_res_docs) == input_len_before_call:
-                                raise ValueError(
-                                    f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(batch_res_docs)}'
-                                )
-                        elif batch_res_docs is None:
-                            if not len(docs_inner_batch) == input_len_before_call:
-                                raise ValueError(
-                                    f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(docs_inner_batch)}'
-                                )
-                        else:
-                            array_name = (
-                                'DocumentArray' if not docarray_v2 else 'DocList'
-                            )
-                            raise TypeError(
-                                f'The return type must be {array_name} / `None` when using dynamic batching, '
-                                f'but getting {batch_res_docs!r}'
-                            )
-                    except Exception as exc:
-                        # All the requests containing docs in this Exception should be raising it
-                        for request_full in self._requests_completed[
-                            involved_requests_min_indx : involved_requests_max_indx + 1
-                        ]:
-                            await request_full.put(exc)
-                    else:
-                        # We need to attribute the docs to their requests
-                        non_assigned_to_response_docs.extend(
-                            batch_res_docs or docs_inner_batch
-                        )
-                        non_assigned_to_response_request_idxs.extend(req_idxs)
-                        num_assigned_docs = await _assign_results(
-                            non_assigned_to_response_docs,
-                            non_assigned_to_response_request_idxs,
-                            sum_from_previous_first_req_idx,
-                        )
+        self._reset()
+
+        # At this moment, we have documents concatenated in big_doc_in_batch corresponding to requests in
+        # requests_idxs_in_batch with their lengths stored in requests_lens_in_batch. For each request, there is a queue to
+        # communicate that the request has been processed properly.
+
+        if not docarray_v2:
+            non_assigned_to_response_docs: DocumentArray = DocumentArray.empty()
+        else:
+            non_assigned_to_response_docs = self._response_docarray_cls()

-                        sum_from_previous_first_req_idx = (
-                            len(non_assigned_to_response_docs) - num_assigned_docs
+        non_assigned_to_response_request_idxs = []
+        sum_from_previous_first_req_idx = 0
+        for docs_inner_batch, req_idxs in batch(
+            big_doc_in_batch, requests_idxs_in_batch, self._preferred_batch_size if not self._flush_all else None
+        ):
+            involved_requests_min_indx = req_idxs[0]
+            involved_requests_max_indx = req_idxs[-1]
+            input_len_before_call: int = len(docs_inner_batch)
+            batch_res_docs = None
+            try:
+                batch_res_docs = await self.func(
+                    docs=docs_inner_batch,
+                    parameters=self.params,
+                    docs_matrix=None,  # joining manually with batch queue is not supported right now
+                    tracing_context=None,
+                )
+                # Output validation
+                if (docarray_v2 and isinstance(batch_res_docs, DocList)) or (
+                    not docarray_v2
+                    and isinstance(batch_res_docs, DocumentArray)
+                ):
+                    if not len(batch_res_docs) == input_len_before_call:
+                        raise ValueError(
+                            f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(batch_res_docs)}'
                        )
-                        non_assigned_to_response_docs = non_assigned_to_response_docs[
-                            num_assigned_docs:
-                        ]
-                        non_assigned_to_response_request_idxs = (
-                            non_assigned_to_response_request_idxs[num_assigned_docs:]
+                elif batch_res_docs is None:
+                    if not len(docs_inner_batch) == input_len_before_call:
+                        raise ValueError(
+                            f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(docs_inner_batch)}'
                        )
-                if len(non_assigned_to_response_request_idxs) > 0:
-                    _ = await _assign_results(
-                        non_assigned_to_response_docs,
-                        non_assigned_to_response_request_idxs,
-                        sum_from_previous_first_req_idx,
+                else:
+                    array_name = (
+                        'DocumentArray' if not docarray_v2 else 'DocList'
+                    )
+                    raise TypeError(
+                        f'The return type must be {array_name} / `None` when using dynamic batching, '
+                        f'but getting {batch_res_docs!r}'
                    )
-            finally:
-                self._reset()
+            except Exception as exc:
+                # All the requests containing docs in this Exception should be raising it
+                for request_full in requests_completed_in_batch[
+                    involved_requests_min_indx : involved_requests_max_indx + 1
+                ]:
+                    await request_full.put(exc)
+            else:
+                # We need to attribute the docs to their requests
+                non_assigned_to_response_docs.extend(
+                    batch_res_docs or docs_inner_batch
+                )
+                non_assigned_to_response_request_idxs.extend(req_idxs)
+                num_assigned_docs = await _assign_results(
+                    non_assigned_to_response_docs,
+                    non_assigned_to_response_request_idxs,
+                    sum_from_previous_first_req_idx,
+                    requests_lens_in_batch,
+                    requests_in_batch,
+                    requests_completed_in_batch,
+                )
+
+                sum_from_previous_first_req_idx = (
+                    len(non_assigned_to_response_docs) - num_assigned_docs
+                )
+                non_assigned_to_response_docs = non_assigned_to_response_docs[
+                    num_assigned_docs:
+                ]
+                non_assigned_to_response_request_idxs = (
+                    non_assigned_to_response_request_idxs[num_assigned_docs:]
+                )
+        if len(non_assigned_to_response_request_idxs) > 0:
+            _ = await _assign_results(
+                non_assigned_to_response_docs,
+                non_assigned_to_response_request_idxs,
+                sum_from_previous_first_req_idx,
+                requests_lens_in_batch,
+                requests_in_batch,
+                requests_completed_in_batch,
+            )

     async def close(self):
         """Closes the batch queue by flushing pending requests."""
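
The rewritten flush above swaps hold-the-lock-while-processing for snapshot-then-reset: `copy.copy` takes shallow snapshots of the shared lists, `_reset()` immediately hands a clean queue back to `push()`, and the rest of the coroutine touches only the snapshot. A stripped-down sketch of the pattern (illustrative names, not the real `BatchQueue` API):

```python
import asyncio
import copy


class Snapshotter:
    """Toy model of copying batch state before resetting it."""

    def __init__(self):
        self._docs = ['a', 'b', 'c']
        self._requests = [1, 2]

    def _reset(self):
        self._docs = []
        self._requests = []

    async def flush(self):
        docs_in_batch = copy.copy(self._docs)          # snapshot ...
        requests_in_batch = copy.copy(self._requests)
        self._reset()                                  # ... then reset
        # From here on only the snapshot is used, so concurrent push()
        # calls are free to start accumulating the next batch.
        await asyncio.sleep(0)  # stand-in for the real batched call
        return docs_in_batch, requests_in_batch


async def main():
    s = Snapshotter()
    print(await s.flush())  # (['a', 'b', 'c'], [1, 2])
    print(s._docs)          # [] -- ready for the next batch


asyncio.run(main())
```
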
diff --git a/tests/integration/dynamic_batching/test_dynamic_batching.py b/tests/integration/dynamic_batching/test_dynamic_batching.py
index 0a9bf57847e8c..483f247db7892 100644
--- a/tests/integration/dynamic_batching/test_dynamic_batching.py
+++ b/tests/integration/dynamic_batching/test_dynamic_batching.py
@@ -636,7 +636,14 @@ def test_failure_propagation():
         True
     ],
 )
-def test_exception_handling_in_dynamic_batch(flush_all):
+@pytest.mark.parametrize(
+    'allow_concurrent',
+    [
+        False,
+        True
+    ],
+)
+def test_exception_handling_in_dynamic_batch(flush_all, allow_concurrent):
     class SlowExecutorWithException(Executor):

         @dynamic_batching(preferred_batch_size=3, timeout=5000, flush_all=flush_all)
@@ -646,7 +653,7 @@ def foo(self, docs, **kwargs):
             if doc.text == 'fail':
                 raise Exception('Fail is in the Batch')

-    depl = Deployment(uses=SlowExecutorWithException)
+    depl = Deployment(uses=SlowExecutorWithException, allow_concurrent=allow_concurrent)

     with depl:
         da = DocumentArray([Document(text='good') for _ in range(50)])
@@ -670,6 +677,7 @@ def foo(self, docs, **kwargs):
     else:
         assert 1 <= num_failed_requests <= len(da)  # 3 requests in the dynamic batch failing

+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     'flush_all',
@@ -694,11 +702,11 @@ def foo(self, docs, **kwargs):
         cl = Client(protocol=depl.protocol, port=depl.port, asyncio=True)
         res = []
         async for r in cl.post(
-                on='/foo',
-                inputs=da,
-                request_size=7,
-                continue_on_error=True,
-                results_in_order=True,
+            on='/foo',
+            inputs=da,
+            request_size=7,
+            continue_on_error=True,
+            results_in_order=True,
         ):
             res.extend(r)
         assert len(res) == 50  # 1 request per input
@@ -707,8 +715,11 @@ def foo(self, docs, **kwargs):
                 assert int(d.text) <= 5
         else:
             larger_than_5 = 0
+            smaller_than_5 = 0
             for d in res:
                 if int(d.text) > 5:
                     larger_than_5 += 1
-                assert int(d.text) >= 5
+                if int(d.text) < 5:
+                    smaller_than_5 += 1
+            assert smaller_than_5 == 1
             assert larger_than_5 > 0
diff --git a/tests/unit/serve/dynamic_batching/test_batch_queue.py b/tests/unit/serve/dynamic_batching/test_batch_queue.py
index bb922ed60d970..9db1958b86e05 100644
--- a/tests/unit/serve/dynamic_batching/test_batch_queue.py
+++ b/tests/unit/serve/dynamic_batching/test_batch_queue.py
@@ -64,11 +64,13 @@ async def process_request(req):
 @pytest.mark.parametrize('flush_all', [False, True])
 async def test_batch_queue_timeout_does_not_wait_previous_batch(flush_all):
     batches_lengths_computed = []
+    lock = asyncio.Lock()

     async def foo(docs, **kwargs):
-        await asyncio.sleep(4)
-        batches_lengths_computed.append(len(docs))
-        return DocumentArray([Document(text='Done') for _ in docs])
+        async with lock:
+            await asyncio.sleep(4)
+            batches_lengths_computed.append(len(docs))
+            return DocumentArray([Document(text='Done') for _ in docs])

     bq: BatchQueue = BatchQueue(
         foo,
@@ -109,7 +111,7 @@ async def process_request(req, sleep=0):
     assert time_spent >= 8000
     assert time_spent <= 8500
     if flush_all is False:
-        assert batches_lengths_computed == [5, 1, 2]
+        assert batches_lengths_computed == [5, 2, 1]
     else:
         assert batches_lengths_computed == [6, 2]
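
The unit-test change follows from the same refactor: once flushes are no longer serialized by the queue's data lock, two timed-out batches can be in flight at once, so the test serializes `foo` itself with an `asyncio.Lock` to keep `batches_lengths_computed` deterministic (the expected order shifting to `[5, 2, 1]`). A toy illustration of how such a lock pins completion order, with made-up names:

```python
import asyncio


async def worker(lock, order, name, delay):
    async with lock:               # batches run strictly one at a time ...
        await asyncio.sleep(delay)
        order.append(name)         # ... so completion order matches start order


async def main():
    lock = asyncio.Lock()
    order = []
    await asyncio.gather(
        worker(lock, order, 'first', 0.02),
        worker(lock, order, 'second', 0.01),
    )
    print(order)  # ['first', 'second'] even though 'second' sleeps less


asyncio.run(main())
```
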