Add batch prediction for pipelines (#3432)
* Add batch prediction for pipelines

* Fix some hardcoded values & update comments
w5688414 authored Oct 12, 2022
1 parent ee696c3 commit ddb59bf
Showing 9 changed files with 548 additions and 37 deletions.
16 changes: 16 additions & 0 deletions pipelines/examples/semantic-search/semantic_search_example.py
@@ -209,6 +209,22 @@ def semantic_search_tutorial():
})

print_documents(prediction)
# Batch prediction
predictions = pipe.run_batch(queries=["亚马逊河流的介绍", '期货交易手续费指的是什么?'],
params={
"Retriever": {
"top_k": 50
},
"Ranker": {
"top_k": 5
}
})
for i in range(len(predictions['queries'])):
result = {
'documents': predictions['documents'][i],
'query': predictions['queries'][i]
}
print_documents(result)


if __name__ == "__main__":
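For reference, the dictionary returned by `pipe.run_batch()` carries parallel lists rather than a single query/document pair. A hedged sketch of its shape as consumed above, inferred from this example and the node changes below (the exact key set depends on the concrete pipeline, and the Document entries are elided):

```python
# Illustrative only: approximate shape of the batched prediction consumed above.
predictions = {
    "queries": ["亚马逊河流的介绍", "期货交易手续费指的是什么?"],
    "documents": [
        [],  # ranked Documents for the first query (the Ranker's top_k of them)
        [],  # ranked Documents for the second query
    ],
    "params": {"Retriever": {"top_k": 50}, "Ranker": {"top_k": 5}},
}
```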
31 changes: 24 additions & 7 deletions pipelines/pipelines/nodes/base.py
@@ -127,16 +127,33 @@ def _dispatch_run(self, **kwargs) -> Tuple[Dict, str]:
- collate `_debug` information if present
- merge component output with the preceding output and pass it on to the subsequent Component in the Pipeline
"""
return self._dispatch_run_general(self.run, **kwargs)

def _dispatch_run_batch(self, **kwargs):
"""
The Pipelines call this method when run_batch() is executed. This method in turn executes the
_dispatch_run_general() method with the correct run method.
"""
return self._dispatch_run_general(self.run_batch, **kwargs)

def _dispatch_run_general(self, run_method: Callable, **kwargs):
"""
This method takes care of the following:
- inspect run_method's signature to validate that all necessary arguments are available
- pop `debug` and set it on the instance to control debug output
- call run_method with the corresponding arguments and gather output
- collate `_debug` information if present
- merge component output with the preceding output and pass it on to the subsequent Component in the Pipeline
"""
arguments = deepcopy(kwargs)
params = arguments.get("params") or {}

run_signature_args = inspect.signature(self.run).parameters.keys()
run_signature_args = inspect.signature(run_method).parameters.keys()

run_params: Dict[str, Any] = {}
for key, value in params.items():
if key == self.name: # targeted params for this node
if isinstance(value, dict):

# Extract debug attributes
if "debug" in value.keys():
self.debug = value.pop("debug")
@@ -156,19 +173,19 @@ def _dispatch_run(self, **kwargs) -> Tuple[Dict, str]:
if key in run_signature_args:
run_inputs[key] = value

output, stream = self.run(**run_inputs, **run_params)
output, stream = run_method(**run_inputs, **run_params)

# Collect debug information
debug_info = {}
if getattr(self, "debug", None):
# Include input
debug_info["input"] = {**run_inputs, **run_params}
debug_info["input"]["debug"] = self.debug
# Include output
# Include output, exclude _debug to avoid recursion
filtered_output = {
key: value
for key, value in output.items() if key != "_debug"
} # Exclude _debug to avoid recursion
}
debug_info["output"] = filtered_output
# Include custom debug info
custom_debug = output.get("_debug", {})
@@ -182,9 +199,9 @@ def _dispatch_run(self, **kwargs) -> Tuple[Dict, str]:
if all_debug:
output["_debug"] = all_debug

# add "extra" args that were not used by the node
# add "extra" args that were not used by the node, but not the 'inputs' value
for k, v in arguments.items():
if k not in output.keys():
if k not in output.keys() and k != "inputs":
output[k] = v

output["params"] = params
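As the docstring above notes, `_dispatch_run_general()` inspects the signature of the chosen run method and forwards only the matching keyword arguments, carrying everything else through as "extra" output. A minimal hedged sketch of a custom node that satisfies both dispatch paths (the class name is illustrative; the `BaseComponent` import path is assumed from this repository layout):

```python
from typing import Dict, List, Tuple

from pipelines.nodes.base import BaseComponent  # import path assumed from this repo layout


class EchoNode(BaseComponent):
    outgoing_edges = 1

    def run(self, query: str) -> Tuple[Dict, str]:
        # Called via _dispatch_run(): only `query` is forwarded, because it is
        # the argument that matches this signature.
        return {"echo": [query]}, "output_1"

    def run_batch(self, queries: List[str]) -> Tuple[Dict, str]:
        # Called via _dispatch_run_batch(): `queries` matches this signature;
        # unused kwargs are appended to the output as "extra" args.
        return {"echo": list(queries)}, "output_1"
```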
26 changes: 24 additions & 2 deletions pipelines/pipelines/nodes/ranker/base.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional
from typing import List, Optional, Union

import logging
from abc import abstractmethod
@@ -48,7 +48,7 @@ def predict_batch(self,
def run(self,
query: str,
documents: List[Document],
top_k: Optional[int] = None): # type: ignore
top_k: Optional[int] = None):
self.query_count += 1
if documents:
predict = self.timing(self.predict, "query_time")
@@ -62,6 +62,28 @@ def run(self,

return output, "output_1"

def run_batch(
self,
queries: List[str],
documents: Union[List[Document], List[List[Document]]],
top_k: Optional[int] = None,
batch_size: Optional[int] = None,
):
self.query_count += len(queries)
predict_batch = self.timing(self.predict_batch, "query_time")
results = predict_batch(queries=queries,
documents=documents,
top_k=top_k,
batch_size=batch_size)

for doc_list in results:
document_ids = [doc.id for doc in doc_list]
logger.debug("Ranked documents with IDs: %s", document_ids)

output = {"documents": results}

return output, "output_1"

def timing(self, fn, attr_name):
"""Wrapper method used to time functions."""

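A hedged usage sketch of the new batched entry point on a ranker node outside a pipeline (the `ErnieRanker` class name is assumed from `ernie_ranker.py` below, the `Document` import path from this package, and the model name is purely illustrative):

```python
from pipelines.nodes.ranker import ErnieRanker  # class added in this commit; import path assumed
from pipelines.schema import Document           # import path assumed

# Illustrative model name; substitute any cross-encoder supported by ErnieCrossEncoder.
ranker = ErnieRanker(model_name_or_path="rocketqa-base-cross-encoder", use_gpu=False)

queries = ["亚马逊河流的介绍", "期货交易手续费指的是什么?"]
docs = [
    [Document(content="亚马逊河是南美洲流量最大的河流。"),
     Document(content="期货手续费由交易所和期货公司收取。")],
    [Document(content="期货交易手续费是开仓和平仓时支付的费用。")],
]

# One ranked Document list per query; "output_1" is the node's single outgoing edge.
output, edge = ranker.run_batch(queries=queries, documents=docs, top_k=2)
for query, ranked in zip(queries, output["documents"]):
    print(query, [doc.content for doc in ranked])
```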
169 changes: 150 additions & 19 deletions pipelines/pipelines/nodes/ranker/ernie_ranker.py
@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Union
from typing import List, Optional, Union, Tuple, Iterator
import logging
from pathlib import Path
from tqdm import tqdm

import paddle
from paddlenlp.transformers import ErnieCrossEncoder, AutoTokenizer
@@ -44,6 +45,9 @@ def __init__(
model_name_or_path: Union[str, Path],
top_k: int = 10,
use_gpu: bool = True,
max_seq_len: int = 256,
progress_bar: bool = True,
batch_size: int = 1000,
):
"""
:param model_name_or_path: Directory of a saved model or the name of a public model e.g.
@@ -66,26 +70,13 @@ def __init__(
self.transformer_model = ErnieCrossEncoder(model_name_or_path)
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
self.transformer_model.eval()
self.progress_bar = progress_bar
self.batch_size = batch_size
self.max_seq_len = max_seq_len

if len(self.devices) > 1:
self.model = paddle.DataParallel(self.transformer_model)

def predict_batch(self,
query_doc_list: List[dict],
top_k: int = None,
batch_size: int = None):
"""
Use loaded Ranker model to, for a list of queries, rank each query's supplied list of Document.
Returns list of dictionary of query and list of document sorted by (desc.) similarity with query
:param query_doc_list: List of dictionaries containing queries with their retrieved documents
:param top_k: The maximum number of answers to return for each query
:param batch_size: Number of samples the model receives in one batch for inference
:return: List of dictionaries containing query and ranked list of Document
"""
raise NotImplementedError

def predict(self,
query: str,
documents: List[Document],
@@ -105,7 +96,7 @@ def predict(self,

features = self.tokenizer([query for doc in documents],
[doc.content for doc in documents],
max_seq_len=256,
max_seq_len=self.max_seq_len,
pad_to_max_seq_len=True,
truncation_strategy="longest_first")

@@ -125,6 +116,146 @@ def predict(self,
reverse=True,
)

# rank documents according to scores
# Rank documents according to scores
sorted_documents = [doc for _, doc in sorted_scores_and_documents]
return sorted_documents[:top_k]

def predict_batch(
self,
queries: List[str],
documents: Union[List[Document], List[List[Document]]],
top_k: Optional[int] = None,
batch_size: Optional[int] = None,
) -> Union[List[Document], List[List[Document]]]:
"""
Use the loaded ranker model to re-rank the supplied lists of Documents.
Returns lists of Documents sorted by (desc.) similarity with the corresponding queries.
:param queries: List of query strings; a single query is applied to every list of Documents
:param documents: Single list of Documents or list of lists of Documents to be reranked.
:param top_k: The maximum number of documents to return per Document list.
:param batch_size: Number of Documents to process at a time.
"""
if top_k is None:
top_k = self.top_k

if batch_size is None:
batch_size = self.batch_size

number_of_docs, all_queries, all_docs, single_list_of_docs = self._preprocess_batch_queries_and_docs(
queries=queries, documents=documents)
batches = self._get_batches(all_queries=all_queries,
all_docs=all_docs,
batch_size=batch_size)
pb = tqdm(total=len(all_docs),
disable=not self.progress_bar,
desc="Ranking")

preds = []
for cur_queries, cur_docs in batches:
features = self.tokenizer(cur_queries,
[doc.content for doc in cur_docs],
max_seq_len=self.max_seq_len,
pad_to_max_seq_len=True,
truncation_strategy="longest_first")

tensors = {k: paddle.to_tensor(v) for (k, v) in features.items()}

with paddle.no_grad():
similarity_scores = self.transformer_model.matching(
**tensors).numpy()
preds.extend(similarity_scores)

for doc, rank_score in zip(cur_docs, similarity_scores):
doc.rank_score = rank_score
doc.score = rank_score
pb.update(len(cur_docs))
pb.close()
if single_list_of_docs:
sorted_scores_and_documents = sorted(
zip(preds, documents),
key=lambda similarity_document_tuple: similarity_document_tuple[
0],
reverse=True,
)
sorted_documents = [doc for _, doc in sorted_scores_and_documents]
return sorted_documents[:top_k]
else:
grouped_predictions = []
left_idx = 0
right_idx = 0
for number in number_of_docs:
right_idx = left_idx + number
grouped_predictions.append(
preds[left_idx:right_idx])
left_idx = right_idx
result = []
for pred_group, doc_group in zip(grouped_predictions, documents):
sorted_scores_and_documents = sorted(
zip(pred_group, doc_group),
key=lambda similarity_document_tuple:
similarity_document_tuple[0],
reverse=True,
)
sorted_documents = [
doc for _, doc in sorted_scores_and_documents
]
result.append(sorted_documents[:top_k])

return result

def _preprocess_batch_queries_and_docs(
self, queries: List[str], documents: Union[List[Document],
List[List[Document]]]
) -> Tuple[List[int], List[str], List[Document], bool]:
number_of_docs = []
all_queries = []
all_docs: List[Document] = []
single_list_of_docs = False

# Docs case 1: single list of Documents -> rerank single list of Documents based on single query
if len(documents) > 0 and isinstance(documents[0], Document):
if len(queries) != 1:
raise Exception(
"Number of queries must be 1 if a single list of Documents is provided."
)
query = queries[0]
number_of_docs = [len(documents)]
all_queries = [query] * len(documents)
all_docs = documents # type: ignore
single_list_of_docs = True

# Docs case 2: list of lists of Documents -> rerank each list of Documents based on corresponding query
# If queries contains a single query, apply it to each list of Documents
if len(documents) > 0 and isinstance(documents[0], list):
if len(queries) == 1:
queries = queries * len(documents)
if len(queries) != len(documents):
raise Exception(
"Number of queries must be equal to number of provided Document lists."
)
for query, cur_docs in zip(queries, documents):
if not isinstance(cur_docs, list):
raise Exception(
f"cur_docs was of type {type(cur_docs)}, but expected a list of Documents."
)
number_of_docs.append(len(cur_docs))
all_queries.extend([query] * len(cur_docs))
all_docs.extend(cur_docs)

return number_of_docs, all_queries, all_docs, single_list_of_docs

@staticmethod
def _get_batches(
all_queries: List[str], all_docs: List[Document],
batch_size: Optional[int]
) -> Iterator[Tuple[List[str], List[Document]]]:
if batch_size is None:
yield all_queries, all_docs
return
else:
for index in range(0, len(all_queries), batch_size):
yield all_queries[index:index +
batch_size], all_docs[index:index +
batch_size]