diff --git a/pipelines/examples/semantic-search/semantic_search_example.py b/pipelines/examples/semantic-search/semantic_search_example.py
index 1c01de93879f..3f2df2ab10a6 100644
--- a/pipelines/examples/semantic-search/semantic_search_example.py
+++ b/pipelines/examples/semantic-search/semantic_search_example.py
@@ -209,6 +209,22 @@ def semantic_search_tutorial():
                           })
     print_documents(prediction)
 
+    # Batch prediction
+    predictions = pipe.run_batch(queries=["亚马逊河流的介绍", '期货交易手续费指的是什么?'],
+                                 params={
+                                     "Retriever": {
+                                         "top_k": 50
+                                     },
+                                     "Ranker": {
+                                         "top_k": 5
+                                     }
+                                 })
+    for i in range(len(predictions['queries'])):
+        result = {
+            'documents': predictions['documents'][i],
+            'query': predictions['queries'][i]
+        }
+        print_documents(result)
 
 
 if __name__ == "__main__":
diff --git a/pipelines/pipelines/nodes/base.py b/pipelines/pipelines/nodes/base.py
index 797568daf627..3e5d22456ba1 100644
--- a/pipelines/pipelines/nodes/base.py
+++ b/pipelines/pipelines/nodes/base.py
@@ -127,16 +127,33 @@ def _dispatch_run(self, **kwargs) -> Tuple[Dict, str]:
         - collate `_debug` information if present
         - merge component output with the preceding output and pass it on to the subsequent Component in the Pipeline
         """
+        return self._dispatch_run_general(self.run, **kwargs)
+
+    def _dispatch_run_batch(self, **kwargs):
+        """
+        The Pipeline calls this method when run_batch() is executed. It in turn executes
+        _dispatch_run_general() with the correct run method.
+        """
+        return self._dispatch_run_general(self.run_batch, **kwargs)
+
+    def _dispatch_run_general(self, run_method: Callable, **kwargs):
+        """
+        This method takes care of the following:
+        - inspect run_method's signature to validate if all necessary arguments are available
+        - pop `debug` and set it on the instance to control debug output
+        - call run_method with the corresponding arguments and gather output
+        - collate `_debug` information if present
+        - merge component output with the preceding output and pass it on to the subsequent Component in the Pipeline
+        """
         arguments = deepcopy(kwargs)
         params = arguments.get("params") or {}
 
-        run_signature_args = inspect.signature(self.run).parameters.keys()
+        run_signature_args = inspect.signature(run_method).parameters.keys()
 
         run_params: Dict[str, Any] = {}
         for key, value in params.items():
             if key == self.name:  # targeted params for this node
                 if isinstance(value, dict):
-
                     # Extract debug attributes
                     if "debug" in value.keys():
                         self.debug = value.pop("debug")
@@ -156,7 +173,7 @@ def _dispatch_run(self, **kwargs) -> Tuple[Dict, str]:
             if key in run_signature_args:
                 run_inputs[key] = value
 
-        output, stream = self.run(**run_inputs, **run_params)
+        output, stream = run_method(**run_inputs, **run_params)
 
         # Collect debug information
         debug_info = {}
@@ -164,11 +181,11 @@ def _dispatch_run(self, **kwargs) -> Tuple[Dict, str]:
             # Include input
             debug_info["input"] = {**run_inputs, **run_params}
             debug_info["input"]["debug"] = self.debug
-            # Include output
+            # Include output, exclude _debug to avoid recursion
            filtered_output = {
                 key: value
                 for key, value in output.items() if key != "_debug"
-            }  # Exclude _debug to avoid recursion
+            }
             debug_info["output"] = filtered_output
             # Include custom debug info
             custom_debug = output.get("_debug", {})
@@ -182,9 +199,9 @@ def _dispatch_run(self, **kwargs) -> Tuple[Dict, str]:
         if all_debug:
             output["_debug"] = all_debug
 
-        # add "extra" args that were not used by the node
+        # add "extra" args that were not used by the node, but not the 'inputs' value
         for k, v in arguments.items():
-            if k not in output.keys():
+            if k not in output.keys() and k != "inputs":
                 output[k] = v
 
         output["params"] = params
diff --git a/pipelines/pipelines/nodes/ranker/base.py b/pipelines/pipelines/nodes/ranker/base.py
index 216c917bb6e3..555b3fa46f4b 100644
--- a/pipelines/pipelines/nodes/ranker/base.py
+++ b/pipelines/pipelines/nodes/ranker/base.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional
+from typing import List, Optional, Union
 
 import logging
 from abc import abstractmethod
@@ -48,7 +48,7 @@ def predict_batch(self,
     def run(self,
             query: str,
             documents: List[Document],
-            top_k: Optional[int] = None):  # type: ignore
+            top_k: Optional[int] = None):
         self.query_count += 1
         if documents:
             predict = self.timing(self.predict, "query_time")
@@ -62,6 +62,28 @@ def run(self,
 
         return output, "output_1"
 
+    def run_batch(
+        self,
+        queries: List[str],
+        documents: Union[List[Document], List[List[Document]]],
+        top_k: Optional[int] = None,
+        batch_size: Optional[int] = None,
+    ):
+        self.query_count += len(queries)
+        predict_batch = self.timing(self.predict_batch, "query_time")
+        results = predict_batch(queries=queries,
+                                documents=documents,
+                                top_k=top_k,
+                                batch_size=batch_size)
+
+        for doc_list in results:
+            document_ids = [doc.id for doc in doc_list]
+            logger.debug("Ranked documents with IDs: %s", document_ids)
+
+        output = {"documents": results}
+
+        return output, "output_1"
+
     def timing(self, fn, attr_name):
         """Wrapper method used to time functions."""
diff --git a/pipelines/pipelines/nodes/ranker/ernie_ranker.py b/pipelines/pipelines/nodes/ranker/ernie_ranker.py
index 0d9f825c852a..8146e246bf06 100644
--- a/pipelines/pipelines/nodes/ranker/ernie_ranker.py
+++ b/pipelines/pipelines/nodes/ranker/ernie_ranker.py
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Tuple, Iterator
 import logging
 from pathlib import Path
+from tqdm import tqdm
 
 import paddle
 from paddlenlp.transformers import ErnieCrossEncoder, AutoTokenizer
@@ -44,6 +45,9 @@ def __init__(
         model_name_or_path: Union[str, Path],
         top_k: int = 10,
         use_gpu: bool = True,
+        max_seq_len: int = 256,
+        progress_bar: bool = True,
+        batch_size: int = 1000,
     ):
         """
         :param model_name_or_path: Directory of a saved model or the name of a public model e.g.
@@ -66,26 +70,13 @@ def __init__(
         self.transformer_model = ErnieCrossEncoder(model_name_or_path)
         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
         self.transformer_model.eval()
+        self.progress_bar = progress_bar
+        self.batch_size = batch_size
+        self.max_seq_len = max_seq_len
 
         if len(self.devices) > 1:
             self.model = paddle.DataParallel(self.transformer_model)
 
-    def predict_batch(self,
-                      query_doc_list: List[dict],
-                      top_k: int = None,
-                      batch_size: int = None):
-        """
-        Use loaded Ranker model to, for a list of queries, rank each query's supplied list of Document.
-
-        Returns list of dictionary of query and list of document sorted by (desc.) similarity with query
-
-        :param query_doc_list: List of dictionaries containing queries with their retrieved documents
-        :param top_k: The maximum number of answers to return for each query
-        :param batch_size: Number of samples the model receives in one batch for inference
-        :return: List of dictionaries containing query and ranked list of Document
-        """
-        raise NotImplementedError
-
     def predict(self,
                 query: str,
                 documents: List[Document],
@@ -105,7 +96,7 @@ def predict(self,
 
         features = self.tokenizer([query for doc in documents],
                                   [doc.content for doc in documents],
-                                  max_seq_len=256,
+                                  max_seq_len=self.max_seq_len,
                                   pad_to_max_seq_len=True,
                                   truncation_strategy="longest_first")
 
@@ -125,6 +116,146 @@ def predict(self,
             reverse=True,
         )
 
-        # rank documents according to scores
+        # Rank documents according to scores
         sorted_documents = [doc for _, doc in sorted_scores_and_documents]
         return sorted_documents[:top_k]
+
+    def predict_batch(
+        self,
+        queries: List[str],
+        documents: Union[List[Document], List[List[Document]]],
+        top_k: Optional[int] = None,
+        batch_size: Optional[int] = None,
+    ) -> Union[List[Document], List[List[Document]]]:
+        """
+        Use the loaded ranker model to re-rank the supplied lists of Documents.
+
+        Returns lists of Documents sorted by (desc.) similarity with the corresponding queries.
+
+        :param queries: List of query strings.
+        :param documents: Single list of Documents or list of lists of Documents to be reranked.
+        :param top_k: The maximum number of documents to return per Document list.
+        :param batch_size: Number of Documents to process at a time.
+        """
+        if top_k is None:
+            top_k = self.top_k
+
+        if batch_size is None:
+            batch_size = self.batch_size
+
+        number_of_docs, all_queries, all_docs, single_list_of_docs = self._preprocess_batch_queries_and_docs(
+            queries=queries, documents=documents)
+        batches = self._get_batches(all_queries=all_queries,
+                                    all_docs=all_docs,
+                                    batch_size=batch_size)
+        pb = tqdm(total=len(all_docs),
+                  disable=not self.progress_bar,
+                  desc="Ranking")
+
+        preds = []
+        for cur_queries, cur_docs in batches:
+            features = self.tokenizer(cur_queries,
+                                      [doc.content for doc in cur_docs],
+                                      max_seq_len=self.max_seq_len,
+                                      pad_to_max_seq_len=True,
+                                      truncation_strategy="longest_first")
+
+            tensors = {k: paddle.to_tensor(v) for (k, v) in features.items()}
+
+            with paddle.no_grad():
+                similarity_scores = self.transformer_model.matching(
+                    **tensors).numpy()
+                preds.extend(similarity_scores)
+
+            for doc, rank_score in zip(cur_docs, similarity_scores):
+                doc.rank_score = rank_score
+                doc.score = rank_score
+            pb.update(len(cur_docs))
+        pb.close()
+        if single_list_of_docs:
+            sorted_scores_and_documents = sorted(
+                zip(preds, documents),
+                key=lambda similarity_document_tuple: similarity_document_tuple[
+                    0],
+                reverse=True,
+            )
+            sorted_documents = [doc for _, doc in sorted_scores_and_documents]
+            return sorted_documents[:top_k]
+        else:
+            grouped_predictions = []
+            left_idx = 0
+            right_idx = 0
+            for number in number_of_docs:
+                right_idx = left_idx + number
+                grouped_predictions.append(preds[left_idx:right_idx])
+                left_idx = right_idx
+            result = []
+            for pred_group, doc_group in zip(grouped_predictions, documents):
+                sorted_scores_and_documents = sorted(
+                    zip(pred_group, doc_group),
+                    key=lambda similarity_document_tuple:
+                    similarity_document_tuple[0],
+                    reverse=True,
+                )
+                sorted_documents = [
+                    doc for _, doc in sorted_scores_and_documents
+                ]
+                result.append(sorted_documents[:top_k])
+
+            return result
+
+    def _preprocess_batch_queries_and_docs(
+        self, queries: List[str], documents: Union[List[Document],
+                                                    List[List[Document]]]
+    ) -> Tuple[List[int], List[str], List[Document], bool]:
+        number_of_docs = []
+        all_queries = []
+        all_docs: List[Document] = []
+        single_list_of_docs = False
+
+        # Docs case 1: single list of Documents -> rerank single list of Documents based on single query
+        if len(documents) > 0 and isinstance(documents[0], Document):
+            if len(queries) != 1:
+                raise Exception(
+                    "Number of queries must be 1 if a single list of Documents is provided."
+                )
+            query = queries[0]
+            number_of_docs = [len(documents)]
+            all_queries = [query] * len(documents)
+            all_docs = documents  # type: ignore
+            single_list_of_docs = True
+
+        # Docs case 2: list of lists of Documents -> rerank each list of Documents based on corresponding query
+        # If queries contains a single query, apply it to each list of Documents
+        if len(documents) > 0 and isinstance(documents[0], list):
+            if len(queries) == 1:
+                queries = queries * len(documents)
+            if len(queries) != len(documents):
+                raise Exception(
+                    "Number of queries must be equal to number of provided Document lists."
+                )
+            for query, cur_docs in zip(queries, documents):
+                if not isinstance(cur_docs, list):
+                    raise Exception(
+                        f"cur_docs was of type {type(cur_docs)}, but expected a list of Documents."
+                    )
+                number_of_docs.append(len(cur_docs))
+                all_queries.extend([query] * len(cur_docs))
+                all_docs.extend(cur_docs)
+
+        return number_of_docs, all_queries, all_docs, single_list_of_docs
+
+    @staticmethod
+    def _get_batches(
+        all_queries: List[str], all_docs: List[Document],
+        batch_size: Optional[int]
+    ) -> Iterator[Tuple[List[str], List[Document]]]:
+        if batch_size is None:
+            yield all_queries, all_docs
+            return
+        else:
+            for index in range(0, len(all_queries), batch_size):
+                yield all_queries[index:index +
+                                  batch_size], all_docs[index:index +
+                                                        batch_size]
diff --git a/pipelines/pipelines/nodes/retriever/base.py b/pipelines/pipelines/nodes/retriever/base.py
index 723175dcdfe5..41e3d490b94e 100644
--- a/pipelines/pipelines/nodes/retriever/base.py
+++ b/pipelines/pipelines/nodes/retriever/base.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union, Iterator
 
 import logging
 from abc import abstractmethod
@@ -84,6 +84,20 @@ def retrieve(
         """
         pass
 
+    @abstractmethod
+    def retrieve_batch(
+        self,
+        queries: List[str],
+        filters: Optional[Dict[str, Union[Dict, List, str, int, float,
+                                          bool]]] = None,
+        top_k: Optional[int] = None,
+        index: str = None,
+        headers: Optional[Dict[str, str]] = None,
+        batch_size: Optional[int] = None,
+        scale_score: bool = None,
+    ) -> List[List[Document]]:
+        pass
+
     def timing(self, fn, attr_name):
         """Wrapper method used to time functions."""
 
@@ -125,6 +139,33 @@ def run(  # type: ignore
             raise Exception(f"Invalid root_node '{root_node}'.")
         return output, stream
 
+    def run_batch(  # type: ignore
+        self,
+        root_node: str,
+        queries: Optional[List[str]] = None,
+        filters: Optional[Union[dict, List[dict]]] = None,
+        top_k: Optional[int] = None,
+        documents: Optional[Union[List[Document], List[List[Document]]]] = None,
+        index: Optional[str] = None,
+        headers: Optional[Dict[str, str]] = None,
+    ):
+        if root_node == "Query":
+            self.query_count += len(queries) if isinstance(queries, list) else 1
+            run_query_batch_timed = self.timing(self.run_query_batch,
+                                                "query_time")
+            output, stream = run_query_batch_timed(queries=queries,
+                                                   filters=filters,
+                                                   top_k=top_k,
+                                                   index=index,
+                                                   headers=headers)
+        elif root_node == "File":
+            self.index_count += len(documents)  # type: ignore
+            run_indexing = self.timing(self.run_indexing, "index_time")
+            output, stream = run_indexing(documents=documents)
+        else:
+            raise Exception(f"Invalid root_node '{root_node}'.")
+        return output, stream
+
     def run_query(
         self,
         query: str,
@@ -144,6 +185,33 @@ def run_query(
 
         return output, "output_1"
 
+    def run_query_batch(
+        self,
+        queries: List[str],
+        filters: Optional[dict] = None,
+        top_k: Optional[int] = None,
+        index: Optional[str] = None,
+        headers: Optional[Dict[str, str]] = None,
+        batch_size: Optional[int] = None,
+    ):
+        documents = self.retrieve_batch(queries=queries,
+                                        filters=filters,
+                                        top_k=top_k,
+                                        index=index,
+                                        headers=headers,
+                                        batch_size=batch_size)
+        if isinstance(queries, str):
+            document_ids = []
+            for doc in documents:
+                document_ids.append(doc.id)
+            logger.debug("Retrieved documents with IDs: %s", document_ids)
+        else:
+            for doc_list in documents:
+                document_ids = [doc.id for doc in doc_list]
+                logger.debug("Retrieved documents with IDs: %s", document_ids)
+        output = {"documents": documents}
+        return output, "output_1"
+
     def run_indexing(self, documents: List[dict]):
         if self.__class__.__name__ in [
                 "DensePassageRetriever", "EmbeddingRetriever"
         ]:
@@ -171,3 +239,13 @@ def print_time(self):
         print(f"Queries Performed: {self.query_count}")
         print(f"Query time: {self.query_time}s")
         print(f"{self.query_time / self.query_count} seconds per query")
+
+    @staticmethod
+    def _get_batches(queries: List[str],
+                     batch_size: Optional[int]) -> Iterator[List[str]]:
+        if batch_size is None:
+            yield queries
+            return
+        else:
+            for index in range(0, len(queries), batch_size):
+                yield queries[index:index + batch_size]
diff --git a/pipelines/pipelines/nodes/retriever/dense.py b/pipelines/pipelines/nodes/retriever/dense.py
index 6040938faf29..3f7bfadae8c3 100644
--- a/pipelines/pipelines/nodes/retriever/dense.py
+++ b/pipelines/pipelines/nodes/retriever/dense.py
@@ -206,6 +206,60 @@ def retrieve(
                                                            return_embedding=False)
         return documents
 
+    def retrieve_batch(
+        self,
+        queries: List[str],
+        filters: Optional[Union[Dict[str, Union[Dict, List, str, int, float,
+                                                bool]],
+                                List[Dict[str, Union[Dict, List, str, int,
+                                                     float, bool]]], ]] = None,
+        top_k: Optional[int] = None,
+        index: str = None,
+        headers: Optional[Dict[str, str]] = None,
+        batch_size: Optional[int] = None,
+        scale_score: bool = None,
+    ) -> List[List[Document]]:
+        if top_k is None:
+            top_k = self.top_k
+        if batch_size is None:
+            batch_size = self.batch_size
+
+        if isinstance(filters, list):
+            if len(filters) != len(queries):
+                raise Exception(
+                    "Number of filters does not match number of queries. Please provide as many filters"
+                    " as queries or a single filter that will be applied to each query."
+                )
+        else:
+            filters = [filters] * len(
+                queries) if filters is not None else [{}] * len(queries)
+        if not self.document_store:
+            logger.error(
+                "Cannot perform retrieve_batch() since DensePassageRetriever initialized with document_store=None"
+            )
+            return [[] for _ in range(len(queries))]  # type: ignore
+        if index is None:
+            index = self.document_store.index
+        documents = []
+        query_embs: List[np.ndarray] = []
+        for batch in self._get_batches(queries=queries, batch_size=batch_size):
+            query_embs.extend(self.embed_queries(texts=batch))
+        for query_emb, cur_filters in tqdm(zip(query_embs, filters),
+                                           total=len(query_embs),
+                                           disable=not self.progress_bar,
+                                           desc="Querying"):
+            cur_docs = self.document_store.query_by_embedding(
+                query_emb=query_emb,
+                top_k=top_k,
+                filters=cur_filters,
+                index=index,
+                headers=headers,
+                return_embedding=False,
+            )
+            documents.append(cur_docs)
+
+        return documents
+
     def _get_predictions(self, dicts):
         """
         Feed a preprocessed dataset to the model and get the actual predictions (forward pass + formatting).
diff --git a/pipelines/pipelines/pipelines/base.py b/pipelines/pipelines/pipelines/base.py
index 37c8f08da5d6..2ef81c4d5f49 100644
--- a/pipelines/pipelines/pipelines/base.py
+++ b/pipelines/pipelines/pipelines/base.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 from __future__ import annotations
-from typing import Dict, List, Optional, Any
+from typing import Dict, List, Optional, Any, Union
 
 import copy
 import json
@@ -72,6 +72,9 @@ class RootNode(BaseComponent):
     def run(self, root_node: str):  # type: ignore
         return {}, "output_1"
 
+    def run_batch(self):  # type: ignore
+        return {}, "output_1"
+
 
 class BasePipeline:
     """
@@ -513,6 +516,179 @@ def run(  # type: ignore
                 i += 1  # attempt executing next node in the queue as current `node_id` has unprocessed predecessors
         return node_output
 
+    def run_batch(  # type: ignore
+        self,
+        queries: List[str] = None,
+        file_paths: Optional[List[str]] = None,
+        labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
+        documents: Optional[Union[List[Document], List[List[Document]]]] = None,
+        meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
+        params: Optional[dict] = None,
+        debug: Optional[bool] = None,
+    ):
+        if file_paths is not None or meta is not None:
+            logger.info(
+                "It seems that an indexing Pipeline is run, so using the nodes' run method instead of run_batch."
+            )
+            if isinstance(queries, list):
+                raise Exception(
+                    "For indexing, only a single query can be provided.")
+            if isinstance(labels, list):
+                raise Exception(
+                    "For indexing, only one MultiLabel object can be provided as labels."
+                )
+            flattened_documents: List[Document] = []
+            if documents and isinstance(documents[0], list):
+                for doc_list in documents:
+                    assert isinstance(doc_list, list)
+                    flattened_documents.extend(doc_list)
+            return self.run(
+                query=queries,
+                file_paths=file_paths,
+                labels=labels,
+                documents=flattened_documents,
+                meta=meta,
+                params=params,
+                debug=debug,
+            )
+        # Validate node names
+        self._validate_node_names_in_params(params=params)
+
+        root_node = self.root_node
+        if not root_node:
+            raise Exception("Cannot run a pipeline with no nodes.")
+
+        node_output = None
+        queue: Dict[str, Any] = {
+            root_node: {
+                "root_node": root_node,
+                "params": params
+            }
+        }  # ordered dict with "node_id" -> "input" mapping that acts as a FIFO queue
+        if queries:
+            queue[root_node]["queries"] = queries
+        if file_paths:
+            queue[root_node]["file_paths"] = file_paths
+        if labels:
+            queue[root_node]["labels"] = labels
+        if documents:
+            queue[root_node]["documents"] = documents
+        if meta:
+            queue[root_node]["meta"] = meta
+
+        i = 0  # the first item is popped off the queue unless it is a "join" node with unprocessed predecessors
+        while queue:
+            node_id = list(queue.keys())[i]
+            node_input = queue[node_id]
+            node_input["node_id"] = node_id
+
+            # Apply debug attributes to the node input params
+            # NOTE: global debug attributes will override the value specified in each node's params dictionary.
+            if debug is None and node_input:
+                if node_input.get("params", {}):
+                    debug = params.get("debug", None)  # type: ignore
+            if debug is not None:
+                if not node_input.get("params", None):
+                    node_input["params"] = {}
+                if node_id not in node_input["params"].keys():
+                    node_input["params"][node_id] = {}
+                node_input["params"][node_id]["debug"] = debug
+
+            predecessors = set(nx.ancestors(self.graph, node_id))
+            if predecessors.isdisjoint(set(queue.keys(
+            ))):  # only execute if predecessor nodes are executed
+                try:
+                    logger.debug("Running node '%s' with input: %s", node_id,
+                                 node_input)
+                    node_output, stream_id = self.graph.nodes[node_id][
+                        "component"]._dispatch_run_batch(**node_input)
+                except Exception as e:
+                    # The input might be a really large object with thousands of embeddings.
+                    # If you really want to see it, raise the log level.
+                    logger.debug(
+                        "Exception while running node '%s' with input %s",
+                        node_id, node_input)
+                    raise Exception(
+                        f"Exception while running node '{node_id}': {e}\nEnable debug logging to see the data that was passed when the pipeline failed."
+                    ) from e
+                queue.pop(node_id)
+
+                if stream_id == "split":
+                    for stream_id in [
+                            key for key in node_output.keys()
+                            if key.startswith("output_")
+                    ]:
+                        current_node_output = {
+                            k: v
+                            for k, v in node_output.items()
+                            if not k.startswith("output_")
+                        }
+                        current_docs = node_output.pop(stream_id)
+                        current_node_output["documents"] = current_docs
+                        next_nodes = self.get_next_nodes(node_id, stream_id)
+                        for n in next_nodes:
+                            queue[n] = current_node_output
+                else:
+                    next_nodes = self.get_next_nodes(node_id, stream_id)
+                    for n in next_nodes:
+                        if queue.get(
+                                n):  # concatenate inputs if it's a join node
+                            existing_input = queue[n]
+                            if "inputs" not in existing_input.keys():
+                                updated_input: Dict = {
+                                    "inputs": [existing_input, node_output],
+                                    "params": params
+                                }
+                                if queries:
+                                    updated_input["queries"] = queries
+                                if file_paths:
+                                    updated_input["file_paths"] = file_paths
+                                if labels:
+                                    updated_input["labels"] = labels
+                                if documents:
+                                    updated_input["documents"] = documents
+                                if meta:
+                                    updated_input["meta"] = meta
+                            else:
+                                existing_input["inputs"].append(node_output)
+                                updated_input = existing_input
+                            queue[n] = updated_input
+                        else:
+                            queue[n] = node_output
+                i = 0
+            else:
+                i += 1  # attempt executing next node in the queue as current `node_id` has unprocessed predecessors
+        return node_output
+
+    def _validate_node_names_in_params(self, params: Optional[Dict]):
+        """
+        Validates the node names provided in the 'params' arg of run/run_batch method.
+        """
+        if params:
+            if not all(node_id in self.graph.nodes
+                       for node_id in params.keys()):
+
+                # Might be a non-targeted param. Verify that too
+                not_a_node = set(params.keys()) - set(self.graph.nodes)
+                valid_global_params = set([
+                    "debug"
+                ])  # Debug will be picked up by _dispatch_run, see its code
+                for node_id in self.graph.nodes:
+                    run_signature_args = self._get_run_node_signature(node_id)
+                    valid_global_params |= set(run_signature_args)
+                invalid_keys = [
+                    key for key in not_a_node if key not in valid_global_params
+                ]
+
+                if invalid_keys:
+                    raise ValueError(
+                        f"No node(s) or global parameter(s) named {', '.join(invalid_keys)} found in pipeline."
+                    )
+
+    def _get_run_node_signature(self, node_id: str):
+        return inspect.signature(
+            self.graph.nodes[node_id]["component"].run).parameters.keys()
+
     def _reorder_columns(self, df: DataFrame,
                          desired_order: List[str]) -> DataFrame:
         filtered_order = [col for col in desired_order if col in df.columns]
diff --git a/pipelines/pipelines/pipelines/standard_pipelines.py b/pipelines/pipelines/pipelines/standard_pipelines.py
index 597bda4c8a5d..d459c33db7c2 100644
--- a/pipelines/pipelines/pipelines/standard_pipelines.py
+++ b/pipelines/pipelines/pipelines/standard_pipelines.py
@@ -166,6 +166,26 @@ def get_document_store(self) -> Optional[BaseDocumentStore]:
         """
         return self.pipeline.get_document_store()
 
+    def run_batch(self,
+                  queries: List[str],
+                  params: Optional[dict] = None,
+                  debug: Optional[bool] = None):
+        """
+        Run a batch of queries through the pipeline.
+        :param queries: List of query strings.
+        :param params: Parameters for the individual nodes of the pipeline. For instance,
+                       `params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}}`
+        :param debug: Whether the pipeline should instruct nodes to collect debug information
+                      about their execution. By default these include the input parameters
+                      they received and the output they generated.
+                      All debug information can then be found in the dict returned
+                      by this method under the key "_debug"
+        """
+        output = self.pipeline.run_batch(queries=queries,
+                                         params=params,
+                                         debug=debug)
+        return output
+
 
 class ExtractiveQAPipeline(BaseStandardPipeline):
     """
diff --git a/pipelines/pipelines/utils/__init__.py b/pipelines/pipelines/utils/__init__.py
index 32ddc1f50f41..9502492b1032 100644
--- a/pipelines/pipelines/utils/__init__.py
+++ b/pipelines/pipelines/utils/__init__.py
@@ -23,10 +23,7 @@
     stop_opensearch,
     stop_service,
 )
-from pipelines.utils.export_utils import (
-    print_answers,
-    print_documents,
-    print_questions,
-    export_answers_to_csv,
-    convert_labels_to_squad,
-)
+from pipelines.utils.export_utils import (print_answers, print_documents,
+                                          print_questions,
+                                          export_answers_to_csv,
+                                          convert_labels_to_squad)
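Note (not part of the patch): the query-batching behavior that both `BaseRetriever._get_batches` and `ErnieRanker._get_batches` introduce above can be illustrated with a minimal, self-contained sketch. The helper name `chunk_queries` below is illustrative only and does not exist in the codebase; it mirrors the retriever variant, which yields everything in one pass when no batch size is given and fixed-size slices otherwise.

from typing import Iterator, List, Optional

def chunk_queries(queries: List[str],
                  batch_size: Optional[int]) -> Iterator[List[str]]:
    # No batch_size: hand back the whole list in a single pass.
    if batch_size is None:
        yield queries
        return
    # Otherwise yield consecutive fixed-size slices; the last slice may be shorter.
    for index in range(0, len(queries), batch_size):
        yield queries[index:index + batch_size]

print(list(chunk_queries(["q1", "q2", "q3"], batch_size=2)))
# -> [['q1', 'q2'], ['q3']]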