From 7ee48612860740f3d25fbc7d02f0b93528591468 Mon Sep 17 00:00:00 2001 From: Sara Zan Date: Fri, 22 Jul 2022 16:29:30 +0200 Subject: [PATCH] Simplify `language_modeling.py` and `tokenization.py` (#2703) * Simplification of language_model.py and tokenization.py to remove code duplication Co-authored-by: vblagoje --- docs/_src/api/api/retriever.md | 13 +- haystack/document_stores/memory.py | 9 +- haystack/errors.py | 7 + .../haystack-pipeline-master.schema.json | 15 +- haystack/modeling/data_handler/data_silo.py | 11 +- haystack/modeling/data_handler/processor.py | 64 +- haystack/modeling/data_handler/samples.py | 10 +- haystack/modeling/evaluation/eval.py | 8 +- haystack/modeling/infer.py | 14 +- haystack/modeling/model/adaptive_model.py | 66 +- haystack/modeling/model/biadaptive_model.py | 85 +- haystack/modeling/model/language_model.py | 1925 ++++++----------- haystack/modeling/model/tokenization.py | 542 ++--- haystack/modeling/model/triadaptive_model.py | 76 +- haystack/modeling/training/base.py | 45 +- haystack/modeling/visual.py | 2 +- haystack/nodes/retriever/dense.py | 112 +- .../{test_modeling_dpr.py => test_dpr.py} | 180 +- ...odeling_inference.py => test_inference.py} | 0 test/modeling/test_language.py | 34 + ...iction_head.py => test_prediction_head.py} | 4 +- ...odeling_processor.py => test_processor.py} | 6 +- ...loading.py => test_processor_save_load.py} | 4 +- ...nswering.py => test_question_answering.py} | 0 test/modeling/test_tokenization.py | 747 +++---- test/nodes/test_question_generator.py | 4 +- test/nodes/test_retriever.py | 9 +- test/samples/squad/tiny_augmented.json | 2 +- 28 files changed, 1533 insertions(+), 2461 deletions(-) rename test/modeling/{test_modeling_dpr.py => test_dpr.py} (86%) rename test/modeling/{test_modeling_inference.py => test_inference.py} (100%) create mode 100644 test/modeling/test_language.py rename test/modeling/{test_modeling_prediction_head.py => test_prediction_head.py} (87%) rename test/modeling/{test_modeling_processor.py => test_processor.py} (98%) rename test/modeling/{test_modeling_processor_saving_loading.py => test_processor_save_load.py} (89%) rename test/modeling/{test_modeling_question_answering.py => test_question_answering.py} (100%) diff --git a/docs/_src/api/api/retriever.md b/docs/_src/api/api/retriever.md index a600983127..715c2e2156 100644 --- a/docs/_src/api/api/retriever.md +++ b/docs/_src/api/api/retriever.md @@ -519,7 +519,7 @@ Karpukhin, Vladimir, et al. 
(2020): "Dense Passage Retrieval for Open-Domain Que #### DensePassageRetriever.\_\_init\_\_ ```python -def __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, top_k: int = 10, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, progress_bar: bool = True, devices: Optional[List[Union[str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None, scale_score: bool = True) +def __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, top_k: int = 10, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, progress_bar: bool = True, devices: Optional[List[Union[str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None, scale_score: bool = True) ``` Init the Retriever incl. the two encoder models from a local or remote model checkpoint. @@ -561,8 +561,6 @@ The title is expected to be present in doc.meta["name"] and can be supplied in t before writing them to the DocumentStore like this: {"text": "my text", "meta": {"name": "my title"}}. - `use_fast_tokenizers`: Whether to use fast Rust tokenizers -- `infer_tokenizer_classes`: Whether to infer tokenizer class from the model config / name. -If `False`, the class always loads `DPRQuestionEncoderTokenizer` and `DPRContextEncoderTokenizer`. - `similarity_function`: Which function to apply for calculating the similarity of query and passage embeddings during training. Options: `dot_product` (Default) or `cosine` - `global_loss_buffer_size`: Buffer size for all_gather() in DDP. @@ -871,7 +869,7 @@ None ```python @classmethod -def load(cls, load_dir: Union[Path, str], document_store: BaseDocumentStore, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, similarity_function: str = "dot_product", query_encoder_dir: str = "query_encoder", passage_encoder_dir: str = "passage_encoder", infer_tokenizer_classes: bool = False) +def load(cls, load_dir: Union[Path, str], document_store: BaseDocumentStore, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, similarity_function: str = "dot_product", query_encoder_dir: str = "query_encoder", passage_encoder_dir: str = "passage_encoder") ``` Load DensePassageRetriever from the specified directory. @@ -895,7 +893,7 @@ Kostić, Bogdan, et al. 
(2021): "Multi-modal Retrieval of Tables and Texts Using #### TableTextRetriever.\_\_init\_\_ ```python -def __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-question_encoder", passage_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-passage_encoder", table_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-table_encoder", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, max_seq_len_table: int = 256, top_k: int = 10, use_gpu: bool = True, batch_size: int = 16, embed_meta_fields: List[str] = ["name", "section_title", "caption"], use_fast_tokenizers: bool = True, infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, progress_bar: bool = True, devices: Optional[List[Union[str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None, scale_score: bool = True) +def __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-question_encoder", passage_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-passage_encoder", table_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-table_encoder", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, max_seq_len_table: int = 256, top_k: int = 10, use_gpu: bool = True, batch_size: int = 16, embed_meta_fields: List[str] = ["name", "section_title", "caption"], use_fast_tokenizers: bool = True, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, progress_bar: bool = True, devices: Optional[List[Union[str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None, scale_score: bool = True, use_fast: bool = True) ``` Init the Retriever incl. the two encoder models from a local or remote model checkpoint. @@ -923,8 +921,6 @@ This is the approach used in the original paper and is likely to improve performance if your titles contain meaningful information for retrieval (topic, entities etc.). - `use_fast_tokenizers`: Whether to use fast Rust tokenizers -- `infer_tokenizer_classes`: Whether to infer tokenizer class from the model config / name. -If `False`, the class always loads `DPRQuestionEncoderTokenizer` and `DPRContextEncoderTokenizer`. - `similarity_function`: Which function to apply for calculating the similarity of query and passage embeddings during training. Options: `dot_product` (Default) or `cosine` - `global_loss_buffer_size`: Buffer size for all_gather() in DDP. @@ -942,6 +938,7 @@ Additional information can be found here https://huggingface.co/transformers/mai - `scale_score`: Whether to scale the similarity score to the unit interval (range of [0,1]). If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant. Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. +- `use_fast`: Whether to use the fast version of DPR tokenizers or fallback to the standard version. Defaults to True. 
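As an illustration of this change (not part of the patch itself), a minimal usage sketch of the new flag; the document store and its settings below are assumptions made only for the example:

```python
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import TableTextRetriever

# Any document store works here; embedding_dim must match the encoders' output size
# (512 is assumed for the default bert-small retrieval models).
document_store = InMemoryDocumentStore(embedding_dim=512)

retriever = TableTextRetriever(
    document_store=document_store,
    use_fast=False,  # fall back to the standard (non-Rust) DPR tokenizers
)
```

After this change, `infer_tokenizer_classes` is no longer accepted by these retrievers; `use_fast` is the remaining tokenizer-related switch.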
@@ -1153,7 +1150,7 @@ None ```python @classmethod -def load(cls, load_dir: Union[Path, str], document_store: BaseDocumentStore, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, max_seq_len_table: int = 256, use_gpu: bool = True, batch_size: int = 16, embed_meta_fields: List[str] = ["name", "section_title", "caption"], use_fast_tokenizers: bool = True, similarity_function: str = "dot_product", query_encoder_dir: str = "query_encoder", passage_encoder_dir: str = "passage_encoder", table_encoder_dir: str = "table_encoder", infer_tokenizer_classes: bool = False) +def load(cls, load_dir: Union[Path, str], document_store: BaseDocumentStore, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, max_seq_len_table: int = 256, use_gpu: bool = True, batch_size: int = 16, embed_meta_fields: List[str] = ["name", "section_title", "caption"], use_fast_tokenizers: bool = True, similarity_function: str = "dot_product", query_encoder_dir: str = "query_encoder", passage_encoder_dir: str = "passage_encoder", table_encoder_dir: str = "table_encoder") ``` Load TableTextRetriever from the specified directory. diff --git a/haystack/document_stores/memory.py b/haystack/document_stores/memory.py index c86144c771..760df00ccc 100644 --- a/haystack/document_stores/memory.py +++ b/haystack/document_stores/memory.py @@ -10,7 +10,7 @@ from tqdm import tqdm from haystack.schema import Document, Label -from haystack.errors import DuplicateDocumentError +from haystack.errors import DuplicateDocumentError, DocumentStoreError from haystack.document_stores import BaseDocumentStore from haystack.document_stores.base import get_batches_from_generator from haystack.modeling.utils import initialize_device_settings @@ -448,8 +448,11 @@ def update_embeddings( ) as progress_bar: for document_batch in batched_documents: embeddings = retriever.embed_documents(document_batch) # type: ignore - assert len(document_batch) == len(embeddings) - + if not len(document_batch) == len(embeddings): + raise DocumentStoreError( + "The number of embeddings does not match the number of documents in the batch " + f"({len(embeddings)} != {len(document_batch)})" + ) if embeddings[0].shape[0] != self.embedding_dim: raise RuntimeError( f"Embedding dim. 
of model ({embeddings[0].shape[0]})" diff --git a/haystack/errors.py b/haystack/errors.py index 88d6de4222..bc81faf0f8 100644 --- a/haystack/errors.py +++ b/haystack/errors.py @@ -35,6 +35,13 @@ def __repr__(self): return str(self) +class ModelingError(HaystackError): + """Exception for issues raised by the modeling module""" + + def __init__(self, message: Optional[str] = None, docs_link: Optional[str] = "https://haystack.deepset.ai/"): + super().__init__(message=message, docs_link=docs_link) + + class PipelineError(HaystackError): """Exception for issues raised within a pipeline""" diff --git a/haystack/json-schemas/haystack-pipeline-master.schema.json b/haystack/json-schemas/haystack-pipeline-master.schema.json index 1625f3069d..9776d870de 100644 --- a/haystack/json-schemas/haystack-pipeline-master.schema.json +++ b/haystack/json-schemas/haystack-pipeline-master.schema.json @@ -2116,11 +2116,6 @@ "default": true, "type": "boolean" }, - "infer_tokenizer_classes": { - "title": "Infer Tokenizer Classes", - "default": false, - "type": "boolean" - }, "similarity_function": { "title": "Similarity Function", "default": "dot_product", @@ -4338,11 +4333,6 @@ "default": true, "type": "boolean" }, - "infer_tokenizer_classes": { - "title": "Infer Tokenizer Classes", - "default": false, - "type": "boolean" - }, "similarity_function": { "title": "Similarity Function", "default": "dot_product", @@ -4387,6 +4377,11 @@ "title": "Scale Score", "default": true, "type": "boolean" + }, + "use_fast": { + "title": "Use Fast", + "default": true, + "type": "boolean" } }, "required": [ diff --git a/haystack/modeling/data_handler/data_silo.py b/haystack/modeling/data_handler/data_silo.py index 435e1ef686..f7237b8d28 100644 --- a/haystack/modeling/data_handler/data_silo.py +++ b/haystack/modeling/data_handler/data_silo.py @@ -812,7 +812,16 @@ def _run_teacher(self, batch: dict) -> List[torch.Tensor]: """ Run the teacher model on the given batch. """ - return self.teacher.inferencer.model(**batch) + params = { + "input_ids": batch["input_ids"], + "segment_ids": batch["segment_ids"], + "padding_mask": batch["padding_mask"], + } + if "output_hidden_states" in batch.keys(): + params["output_hidden_states"] = batch["output_hidden_states"] + if "output_attentions" in batch.keys(): + params["output_attentions"] = batch["output_attentions"] + return self.teacher.inferencer.model(**params) def _pass_batches( self, diff --git a/haystack/modeling/data_handler/processor.py b/haystack/modeling/data_handler/processor.py index dd90a00a46..e9584bddc3 100644 --- a/haystack/modeling/data_handler/processor.py +++ b/haystack/modeling/data_handler/processor.py @@ -1,4 +1,4 @@ -from typing import Optional, Dict, List, Union, Any, Iterable +from typing import Optional, Dict, List, Union, Any, Iterable, Type import os import json @@ -16,9 +16,11 @@ import requests from tqdm import tqdm from torch.utils.data import TensorDataset +import transformers +from transformers import PreTrainedTokenizer from haystack.modeling.model.tokenization import ( - Tokenizer, + get_tokenizer, tokenize_batch_question_answering, tokenize_with_metadata, truncate_sequences, @@ -176,11 +178,9 @@ def load_from_dir(cls, load_dir: str): "Loading tokenizer from deprecated config. " "If you used `custom_vocab` or `never_split_chars`, this won't work anymore." 
) - tokenizer = Tokenizer.load( - load_dir, tokenizer_class=config["tokenizer"], do_lower_case=config["lower_case"] - ) + tokenizer = get_tokenizer(load_dir, tokenizer_class=config["tokenizer"], do_lower_case=config["lower_case"]) else: - tokenizer = Tokenizer.load(load_dir, tokenizer_class=config["tokenizer"]) + tokenizer = get_tokenizer(load_dir, tokenizer_class=config["tokenizer"]) # we have to delete the tokenizer string from config, because we pass it as Object del config["tokenizer"] @@ -216,7 +216,7 @@ def convert_from_transformers( **kwargs, ): tokenizer_args = tokenizer_args or {} - tokenizer = Tokenizer.load( + tokenizer = get_tokenizer( tokenizer_name_or_path, tokenizer_class=tokenizer_class, use_fast=use_fast, @@ -308,7 +308,9 @@ def file_to_dicts(self, file: str) -> List[dict]: raise NotImplementedError() @abstractmethod - def dataset_from_dicts(self, dicts: List[dict], indices: Optional[List[int]] = None, return_baskets: bool = False): + def dataset_from_dicts( + self, dicts: List[Dict], indices: List[int] = [], return_baskets: bool = False, debug: bool = False + ): raise NotImplementedError() @abstractmethod @@ -445,7 +447,9 @@ def __init__( "using the default task or add a custom task later via processor.add_task()" ) - def dataset_from_dicts(self, dicts: List[dict], indices: Optional[List[int]] = None, return_baskets: bool = False): + def dataset_from_dicts( + self, dicts: List[Dict], indices: List[int] = [], return_baskets: bool = False, debug: bool = False + ): """ Convert input dictionaries into a pytorch dataset for Question Answering. For this we have an internal representation called "baskets". @@ -492,7 +496,7 @@ def file_to_dicts(self, file: str) -> List[dict]: return dicts # TODO use Input Objects instead of this function, remove Natural Questions (NQ) related code - def convert_qa_input_dict(self, infer_dict: dict): + def convert_qa_input_dict(self, infer_dict: dict) -> Dict[str, Any]: """Input dictionaries in QA can either have ["context", "qas"] (internal format) as keys or ["text", "questions"] (api format). This function converts the latter into the former. It also converts the is_impossible field to answer_type so that NQ and SQuAD dicts have the same format. 
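# Illustrative example (not part of this patch): the two input shapes that
# convert_qa_input_dict() translates between, as described in the docstring above.
# The question and context strings are invented placeholders.

# Internal format, using the "context" and "qas" keys:
internal_format = {
    "context": "Berlin is the capital of Germany.",
    "qas": [{"question": "What is the capital of Germany?"}],
}

# API format, using the "text" and "questions" keys; convert_qa_input_dict()
# turns this shape into the internal format shown above.
api_format = {
    "text": "Berlin is the capital of Germany.",
    "questions": ["What is the capital of Germany?"],
}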
@@ -929,9 +933,15 @@ def load_from_dir(cls, load_dir: str): # read config processor_config_file = Path(load_dir) / "processor_config.json" config = json.load(open(processor_config_file)) - # init tokenizer - query_tokenizer = Tokenizer.load(load_dir, tokenizer_class=config["query_tokenizer"], subfolder="query") - passage_tokenizer = Tokenizer.load(load_dir, tokenizer_class=config["passage_tokenizer"], subfolder="passage") + # init tokenizers + query_tokenizer_class: Type[PreTrainedTokenizer] = getattr(transformers, config["query_tokenizer"]) + query_tokenizer = query_tokenizer_class.from_pretrained( + pretrained_model_name_or_path=load_dir, subfolder="query" + ) + passage_tokenizer_class: Type[PreTrainedTokenizer] = getattr(transformers, config["passage_tokenizer"]) + passage_tokenizer = passage_tokenizer_class.from_pretrained( + pretrained_model_name_or_path=load_dir, subfolder="passage" + ) # we have to delete the tokenizer string from config, because we pass it as Object del config["query_tokenizer"] @@ -978,7 +988,9 @@ def save(self, save_dir: Union[str, Path]): with open(output_config_file, "w") as file: json.dump(config, file) - def dataset_from_dicts(self, dicts: List[dict], indices: Optional[List[int]] = None, return_baskets: bool = False): + def dataset_from_dicts( + self, dicts: List[Dict], indices: List[int] = [], return_baskets: bool = False, debug: bool = False + ): """ Convert input dictionaries into a pytorch dataset for TextSimilarity (e.g. DPR). For conversion we have an internal representation called "baskets". @@ -1334,9 +1346,9 @@ def load_from_dir(cls, load_dir: str): processor_config_file = Path(load_dir) / "processor_config.json" config = json.load(open(processor_config_file)) # init tokenizer - query_tokenizer = Tokenizer.load(load_dir, tokenizer_class=config["query_tokenizer"], subfolder="query") - passage_tokenizer = Tokenizer.load(load_dir, tokenizer_class=config["passage_tokenizer"], subfolder="passage") - table_tokenizer = Tokenizer.load(load_dir, tokenizer_class=config["table_tokenizer"], subfolder="table") + query_tokenizer = get_tokenizer(load_dir, tokenizer_class=config["query_tokenizer"], subfolder="query") + passage_tokenizer = get_tokenizer(load_dir, tokenizer_class=config["passage_tokenizer"], subfolder="passage") + table_tokenizer = get_tokenizer(load_dir, tokenizer_class=config["table_tokenizer"], subfolder="table") # we have to delete the tokenizer string from config, because we pass it as Object del config["query_tokenizer"] @@ -1488,7 +1500,9 @@ def _read_multimodal_dpr_json(self, file: str, max_samples: Optional[int] = None standard_dicts.append(sample) return standard_dicts - def dataset_from_dicts(self, dicts: List[Dict], indices: Optional[List[int]] = None, return_baskets: bool = False): + def dataset_from_dicts( + self, dicts: List[Dict], indices: List[int] = [], return_baskets: bool = False, debug: bool = False + ): """ Convert input dictionaries into a pytorch dataset for TextSimilarity. For conversion we have an internal representation called "baskets". 
@@ -1836,7 +1850,9 @@ def __init__( def file_to_dicts(self, file: str) -> List[Dict]: raise NotImplementedError - def dataset_from_dicts(self, dicts, indices=None, return_baskets=False, debug=False): + def dataset_from_dicts( + self, dicts: List[Dict], indices: List[int] = [], return_baskets: bool = False, debug: bool = False + ): self.baskets = [] # Tokenize in batches texts = [x["text"] for x in dicts] @@ -1958,7 +1974,7 @@ def load_from_dir(cls, load_dir: str): processor_config_file = Path(load_dir) / "processor_config.json" config = json.load(open(processor_config_file)) # init tokenizer - tokenizer = Tokenizer.load(load_dir, tokenizer_class=config["tokenizer"]) + tokenizer = get_tokenizer(load_dir, tokenizer_class=config["tokenizer"]) # we have to delete the tokenizer string from config, because we pass it as Object del config["tokenizer"] @@ -1979,7 +1995,9 @@ def convert_labels(self, dictionary: Dict): ret: Dict = {} return ret - def dataset_from_dicts(self, dicts: List[Dict], indices=None, return_baskets: bool = False, debug: bool = False): + def dataset_from_dicts( + self, dicts: List[Dict], indices: List[int] = [], return_baskets: bool = False, debug: bool = False + ): """ Function to convert input dictionaries containing text into a torch dataset. For normal operation with Language Models it calls the superclass' TextClassification.dataset_from_dicts method. @@ -2067,7 +2085,9 @@ def file_to_dicts(self, file: str) -> List[dict]: dicts.append({"text": line}) return dicts - def dataset_from_dicts(self, dicts: List[dict], indices: Optional[List[int]] = None, return_baskets: bool = False): + def dataset_from_dicts( + self, dicts: List[Dict], indices: List[int] = [], return_baskets: bool = False, debug: bool = False + ): if return_baskets: raise NotImplementedError("return_baskets is not supported by UnlabeledTextProcessor") texts = [dict_["text"] for dict_ in dicts] diff --git a/haystack/modeling/data_handler/samples.py b/haystack/modeling/data_handler/samples.py index 443295ea64..6335490ec7 100644 --- a/haystack/modeling/data_handler/samples.py +++ b/haystack/modeling/data_handler/samples.py @@ -1,4 +1,4 @@ -from typing import Union, Optional, List +from typing import Any, Union, Optional, List, Dict import logging import numpy as np @@ -13,7 +13,13 @@ class Sample: the human readable clear_text. Over the course of data preprocessing, this object is populated with tokenized and featurized versions of the data.""" - def __init__(self, id: str, clear_text: dict, tokenized: Optional[dict] = None, features: Optional[dict] = None): + def __init__( + self, + id: str, + clear_text: dict, + tokenized: Optional[dict] = None, + features: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + ): """ :param id: The unique id of the sample :param clear_text: A dictionary containing various human readable fields (e.g. text, label). 
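# Not part of this patch: a minimal usage sketch of the module-level get_tokenizer()
# helper that the processor hunks above switch to, mirroring the keyword arguments
# visible in the diff. The model name is an illustrative placeholder.
from haystack.modeling.model.tokenization import get_tokenizer

# Previously: tokenizer = Tokenizer.load("bert-base-uncased", do_lower_case=True)
tokenizer = get_tokenizer("bert-base-uncased", do_lower_case=True)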
diff --git a/haystack/modeling/evaluation/eval.py b/haystack/modeling/evaluation/eval.py index 4cdba7409f..d73d77213c 100644 --- a/haystack/modeling/evaluation/eval.py +++ b/haystack/modeling/evaluation/eval.py @@ -69,7 +69,13 @@ def eval( with torch.no_grad(): - logits = model.forward(**batch) + logits = model.forward( + input_ids=batch.get("input_ids", None), + segment_ids=batch.get("segment_ids", None), + padding_mask=batch.get("padding_mask", None), + output_hidden_states=batch.get("output_hidden_states", False), + output_attentions=batch.get("output_attentions", False), + ) losses_per_head = model.logits_to_loss_per_head(logits=logits, **batch) preds = model.logits_to_preds(logits=logits, **batch) labels = model.prepare_labels(**batch) diff --git a/haystack/modeling/infer.py b/haystack/modeling/infer.py index 85b828b22b..adfddf1d50 100644 --- a/haystack/modeling/infer.py +++ b/haystack/modeling/infer.py @@ -470,11 +470,7 @@ def _get_predictions(self, dataset: Dataset, tensor_names: List, baskets): with torch.no_grad(): logits = self.model.forward(**batch) preds = self.model.formatted_preds( - logits=logits, - samples=batch_samples, - tokenizer=self.processor.tokenizer, - return_class_probs=self.return_class_probs, - **batch, + logits=logits, samples=batch_samples, padding_mask=batch.get("padding_mask", None) ) preds_all += preds return preds_all @@ -511,7 +507,13 @@ def _get_predictions_and_aggregate(self, dataset: Dataset, tensor_names: List, b with torch.no_grad(): # Aggregation works on preds, not logits. We want as much processing happening in one batch + on GPU # So we transform logits to preds here as well - logits = self.model.forward(**batch) + logits = self.model.forward( + input_ids=batch["input_ids"], + segment_ids=batch["segment_ids"], + padding_mask=batch["padding_mask"], + output_hidden_states=batch.get("output_hidden_states", False), + output_attentions=batch.get("output_attentions", False), + ) # preds = self.model.logits_to_preds(logits, **batch)[0] (This must somehow be useful for SQuAD) preds = self.model.logits_to_preds(logits, **batch) unaggregated_preds_all.append(preds) diff --git a/haystack/modeling/model/adaptive_model.py b/haystack/modeling/model/adaptive_model.py index ac126e485b..1d01dc4671 100644 --- a/haystack/modeling/model/adaptive_model.py +++ b/haystack/modeling/model/adaptive_model.py @@ -13,7 +13,7 @@ from transformers.convert_graph_to_onnx import convert, quantize as quantize_model from haystack.modeling.data_handler.processor import Processor -from haystack.modeling.model.language_model import LanguageModel +from haystack.modeling.model.language_model import get_language_model, LanguageModel from haystack.modeling.model.prediction_head import PredictionHead, QuestionAnsweringHead from haystack.utils.experiment_tracking import Tracker as tracker @@ -196,7 +196,7 @@ def __init__( super(AdaptiveModel, self).__init__() # type: ignore self.device = device self.language_model = language_model.to(device) - self.lm_output_dims = language_model.get_output_dims() + self.lm_output_dims = language_model.output_dims self.prediction_heads = nn.ModuleList([ph.to(device) for ph in prediction_heads]) self.fit_heads_to_lm() self.dropout = nn.Dropout(embeds_dropout_prob) @@ -262,7 +262,6 @@ def load( # type: ignore load_dir: Union[str, Path], device: Union[str, torch.device], strict: bool = True, - lm_name: Optional[str] = None, processor: Optional[Processor] = None, ): """ @@ -277,17 +276,12 @@ def load( # type: ignore :param load_dir: Location where the 
AdaptiveModel is stored. :param device: To which device we want to sent the model, either torch.device("cpu") or torch.device("cuda"). - :param lm_name: The name to assign to the loaded language model. :param strict: Whether to strictly enforce that the keys loaded from saved model match the ones in the PredictionHead (see torch.nn.module.load_state_dict()). :param processor: Processor to populate prediction head with information coming from tasks. """ device = torch.device(device) - # Language Model - if lm_name: - language_model = LanguageModel.load(load_dir, haystack_lm_name=lm_name) - else: - language_model = LanguageModel.load(load_dir) + language_model = get_language_model(load_dir) # Prediction heads _, ph_config_files = cls._get_prediction_head_files(load_dir) @@ -334,7 +328,9 @@ def convert_from_transformers( :return: AdaptiveModel """ - lm = LanguageModel.load(model_name_or_path, revision=revision, use_auth_token=use_auth_token, **kwargs) + lm = get_language_model( + model_name_or_path, revision=revision, use_auth_token=use_auth_token, model_kwargs=kwargs + ) if task_type is None: # Infer task type from config architecture = lm.model.config.architectures[0] @@ -462,31 +458,44 @@ def prepare_labels(self, **kwargs): all_labels.append(labels) return all_labels - def forward(self, output_hidden_states: bool = False, output_attentions: bool = False, **kwargs): + def forward( + self, + input_ids: torch.Tensor, + segment_ids: torch.Tensor, + padding_mask: torch.Tensor, + output_hidden_states: bool = False, + output_attentions: bool = False, + ): """ Push data through the whole model and returns logits. The data will propagate through the language model and each of the attached prediction heads. - :param kwargs: Holds all arguments that need to be passed to the language model - and prediction head(s). + :param input_ids: The IDs of each token in the input sequence. It's a tensor of shape [batch_size, max_seq_len]. + :param segment_ids: The ID of the segment. For example, in next sentence prediction, the tokens in the + first sentence are marked with 0 and the tokens in the second sentence are marked with 1. + It is a tensor of shape [batch_size, max_seq_len]. + :param padding_mask: A mask that assigns 1 to valid input tokens and 0 to padding tokens + of shape [batch_size, max_seq_len]. :param output_hidden_states: Whether to output hidden states :param output_attentions: Whether to output attentions :return: All logits as torch.tensor or multiple tensors. 
""" # Run forward pass of language model output_tuple = self.language_model.forward( - **kwargs, output_hidden_states=output_hidden_states, output_attentions=output_attentions + input_ids=input_ids, + segment_ids=segment_ids, + attention_mask=padding_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, ) - if output_hidden_states: - if output_attentions: - sequence_output, pooled_output, hidden_states, attentions = output_tuple - else: - sequence_output, pooled_output, hidden_states = output_tuple + if output_hidden_states and output_attentions: + sequence_output, pooled_output, hidden_states, attentions = output_tuple + elif output_hidden_states: + sequence_output, pooled_output, hidden_states = output_tuple + elif output_attentions: + sequence_output, pooled_output, attentions = output_tuple else: - if output_attentions: - sequence_output, pooled_output, attentions = output_tuple - else: - sequence_output, pooled_output = output_tuple + sequence_output, pooled_output = output_tuple # Run forward pass of (multiple) prediction heads using the output from above all_logits = [] if len(self.prediction_heads) > 0: @@ -509,12 +518,11 @@ def forward(self, output_hidden_states: bool = False, output_attentions: bool = # just return LM output (e.g. useful for extracting embeddings at inference time) all_logits.append((sequence_output, pooled_output)) + if output_hidden_states and output_attentions: + return all_logits, hidden_states, attentions if output_hidden_states: - if output_attentions: - return all_logits, hidden_states, attentions - else: - return all_logits, hidden_states - elif output_attentions: + return all_logits, hidden_states + if output_attentions: return all_logits, attentions return all_logits @@ -570,7 +578,7 @@ def verify_vocab_size(self, vocab_size: int): msg = ( f"Vocab size of tokenizer {vocab_size} doesn't match with model {model_vocab_len}. " "If you added a custom vocabulary to the tokenizer, " - "make sure to supply 'n_added_tokens' to LanguageModel.load() and BertStyleLM.load()" + "make sure to supply 'n_added_tokens' to get_language_model() and BertStyleLM.load()" ) assert vocab_size == model_vocab_len, msg diff --git a/haystack/modeling/model/biadaptive_model.py b/haystack/modeling/model/biadaptive_model.py index e960fb01dd..d80f009578 100644 --- a/haystack/modeling/model/biadaptive_model.py +++ b/haystack/modeling/model/biadaptive_model.py @@ -6,9 +6,10 @@ import torch from torch import nn +from transformers import DPRContextEncoder, DPRQuestionEncoder, AutoModel from haystack.modeling.data_handler.processor import Processor -from haystack.modeling.model.language_model import LanguageModel +from haystack.modeling.model.language_model import get_language_model, LanguageModel from haystack.modeling.model.prediction_head import PredictionHead, TextSimilarityHead from haystack.utils.experiment_tracking import Tracker as tracker @@ -28,8 +29,11 @@ def loss_per_head_sum( class BiAdaptiveModel(nn.Module): - """PyTorch implementation containing all the modelling needed for your NLP task. Combines 2 language - models for representation of 2 sequences and a prediction head. Allows for gradient flow back to the 2 language model components.""" + """ + PyTorch implementation containing all the modelling needed for your NLP task. + Combines 2 language models for representation of 2 sequences and a prediction head. + Allows for gradient flow back to the 2 language model components. 
+ """ def __init__( self, @@ -74,9 +78,9 @@ def __init__( self.device = device self.language_model1 = language_model1.to(device) - self.lm1_output_dims = language_model1.get_output_dims() + self.lm1_output_dims = language_model1.output_dims self.language_model2 = language_model2.to(device) - self.lm2_output_dims = language_model2.get_output_dims() + self.lm2_output_dims = language_model2.output_dims self.dropout1 = nn.Dropout(embeds_dropout_prob) self.dropout2 = nn.Dropout(embeds_dropout_prob) self.prediction_heads = nn.ModuleList([ph.to(device) for ph in prediction_heads]) @@ -140,13 +144,13 @@ def load( """ # Language Model if lm1_name: - language_model1 = LanguageModel.load(os.path.join(load_dir, lm1_name)) + language_model1 = get_language_model(os.path.join(load_dir, lm1_name)) else: - language_model1 = LanguageModel.load(load_dir) + language_model1 = get_language_model(load_dir) if lm2_name: - language_model2 = LanguageModel.load(os.path.join(load_dir, lm2_name)) + language_model2 = get_language_model(os.path.join(load_dir, lm2_name)) else: - language_model2 = LanguageModel.load(load_dir) + language_model2 = get_language_model(load_dir) # Prediction heads ph_config_files = cls._get_prediction_head_files(load_dir) @@ -258,7 +262,15 @@ def prepare_labels(self, **kwargs): all_labels.append(labels) return all_labels - def forward(self, **kwargs): + def forward( + self, + query_input_ids: Optional[torch.Tensor] = None, + query_segment_ids: Optional[torch.Tensor] = None, + query_attention_mask: Optional[torch.Tensor] = None, + passage_input_ids: Optional[torch.Tensor] = None, + passage_segment_ids: Optional[torch.Tensor] = None, + passage_attention_mask: Optional[torch.Tensor] = None, + ): """ Push data through the whole model and returns logits. The data will propagate through the first language model and second language model based on the tensor names and both the @@ -269,7 +281,14 @@ def forward(self, **kwargs): """ # Run forward pass of both language models - pooled_output = self.forward_lm(**kwargs) + pooled_output = self.forward_lm( + query_input_ids=query_input_ids, + query_segment_ids=query_segment_ids, + query_attention_mask=query_attention_mask, + passage_input_ids=passage_input_ids, + passage_segment_ids=passage_segment_ids, + passage_attention_mask=passage_attention_mask, + ) # Run forward pass of (multiple) prediction heads using the output from above all_logits = [] @@ -304,7 +323,15 @@ def forward(self, **kwargs): return all_logits - def forward_lm(self, **kwargs): + def forward_lm( + self, + query_input_ids: Optional[torch.Tensor] = None, + query_segment_ids: Optional[torch.Tensor] = None, + query_attention_mask: Optional[torch.Tensor] = None, + passage_input_ids: Optional[torch.Tensor] = None, + passage_segment_ids: Optional[torch.Tensor] = None, + passage_attention_mask: Optional[torch.Tensor] = None, + ): """ Forward pass for the BiAdaptive model. @@ -312,11 +339,23 @@ def forward_lm(self, **kwargs): :return: 2 tensors of pooled_output from the 2 language models. 
""" pooled_output = [None, None] - if "query_input_ids" in kwargs.keys(): - pooled_output1, hidden_states1 = self.language_model1(**kwargs) + + if query_input_ids is not None and query_segment_ids is not None and query_attention_mask is not None: + pooled_output1, _ = self.language_model1( + input_ids=query_input_ids, segment_ids=query_segment_ids, attention_mask=query_attention_mask + ) pooled_output[0] = pooled_output1 - if "passage_input_ids" in kwargs.keys(): - pooled_output2, hidden_states2 = self.language_model2(**kwargs) + + if passage_input_ids is not None and passage_segment_ids is not None and passage_attention_mask is not None: + + max_seq_len = passage_input_ids.shape[-1] + passage_input_ids = passage_input_ids.view(-1, max_seq_len) + passage_attention_mask = passage_attention_mask.view(-1, max_seq_len) + passage_segment_ids = passage_segment_ids.view(-1, max_seq_len) + + pooled_output2, _ = self.language_model2( + input_ids=passage_input_ids, segment_ids=passage_segment_ids, attention_mask=passage_attention_mask + ) pooled_output[1] = pooled_output2 return tuple(pooled_output) @@ -350,7 +389,7 @@ def verify_vocab_size(self, vocab_size1: int, vocab_size2: int): msg = ( f"Vocab size of tokenizer {vocab_size1} doesn't match with model {model1_vocab_len}. " "If you added a custom vocabulary to the tokenizer, " - "make sure to supply 'n_added_tokens' to LanguageModel.load() and BertStyleLM.load()" + "make sure to supply 'n_added_tokens' to get_language_model() and BertStyleLM.load()" ) assert vocab_size1 == model1_vocab_len, msg @@ -359,7 +398,7 @@ def verify_vocab_size(self, vocab_size1: int, vocab_size2: int): msg = ( f"Vocab size of tokenizer {vocab_size1} doesn't match with model {model2_vocab_len}. " "If you added a custom vocabulary to the tokenizer, " - "make sure to supply 'n_added_tokens' to LanguageModel.load() and BertStyleLM.load()" + "make sure to supply 'n_added_tokens' to get_language_model() and BertStyleLM.load()" ) assert vocab_size2 == model2_vocab_len, msg @@ -395,8 +434,6 @@ def _get_prediction_head_files(cls, load_dir: Union[str, Path]): return config_files def convert_to_transformers(self): - from transformers import DPRContextEncoder, DPRQuestionEncoder, AutoModel - if len(self.prediction_heads) != 1: raise ValueError( f"Currently conversion only works for models with a SINGLE prediction head. " @@ -458,12 +495,8 @@ def convert_from_transformers( :type processor: Processor :return: AdaptiveModel """ - lm1 = LanguageModel.load( - pretrained_model_name_or_path=model_name_or_path1, language_model_class="DPRQuestionEncoder" - ) - lm2 = LanguageModel.load( - pretrained_model_name_or_path=model_name_or_path2, language_model_class="DPRContextEncoder" - ) + lm1 = get_language_model(pretrained_model_name_or_path=model_name_or_path1) + lm2 = get_language_model(pretrained_model_name_or_path=model_name_or_path2) prediction_head = TextSimilarityHead(similarity_function=similarity_function) # TODO Infer type of head automatically from config if task_type == "text_similarity": diff --git a/haystack/modeling/model/language_model.py b/haystack/modeling/model/language_model.py index 1247a5dcf6..34a4565768 100644 --- a/haystack/modeling/model/language_model.py +++ b/haystack/modeling/model/language_model.py @@ -17,47 +17,47 @@ Acknowledgements: Many of the modeling parts here come from the great transformers repository: https://github.com/huggingface/transformers. Thanks for the great work! 
""" -from __future__ import absolute_import, division, print_function, unicode_literals -from typing import Optional, Dict, Any, Union +from typing import Type, Optional, Dict, Any, Union, List + +import re import json import logging import os +from abc import ABC, abstractmethod from pathlib import Path from functools import wraps import numpy as np import torch from torch import nn import transformers -from transformers import ( - BertModel, - BertConfig, - RobertaModel, - RobertaConfig, - XLNetModel, - XLNetConfig, - AlbertModel, - AlbertConfig, - XLMRobertaModel, - XLMRobertaConfig, - DistilBertModel, - DistilBertConfig, - ElectraModel, - ElectraConfig, - CamembertModel, - CamembertConfig, - BigBirdModel, - BigBirdConfig, - DebertaV2Model, - DebertaV2Config, -) +from transformers import PretrainedConfig, PreTrainedModel from transformers import AutoModel, AutoConfig from transformers.modeling_utils import SequenceSummary +from haystack.errors import ModelingError + logger = logging.getLogger(__name__) +LANGUAGE_HINTS = ( + ("german", "german"), + ("english", "english"), + ("chinese", "chinese"), + ("indian", "indian"), + ("french", "french"), + ("camembert", "french"), + ("polish", "polish"), + ("spanish", "spanish"), + ("umberto", "italian"), + ("multilingual", "multilingual"), +) + +#: Names of the attributes in various model configs which refer to the number of dimensions in the output vectors +OUTPUT_DIM_NAMES = ["dim", "hidden_size", "d_model"] + + def silence_transformers_logs(from_pretrained_func): """ A wrapper that raises the log level of Transformers to @@ -82,240 +82,77 @@ def quiet_from_pretrained_func(cls, *args, **kwargs): return quiet_from_pretrained_func -# These are the names of the attributes in various model configs which refer to the number of dimensions -# in the output vectors -OUTPUT_DIM_NAMES = ["dim", "hidden_size", "d_model"] - # TODO analyse if LMs can be completely used through HF transformers -class LanguageModel(nn.Module): +class LanguageModel(nn.Module, ABC): """ - The parent class for any kind of model that can embed language into a semantic vector space. Practically - speaking, these models read in tokenized sentences and return vectors that capture the meaning of sentences - or of tokens. + The parent class for any kind of model that can embed language into a semantic vector space. + These models read in tokenized sentences and return vectors that capture the meaning of sentences or of tokens. """ - subclasses: dict = {} + def __init__(self, model_type: str): + super().__init__() + self._output_dims = None + self.name = model_type - def __init_subclass__(cls, **kwargs): - """ - This automatically keeps track of all available subclasses. - Enables generic load() or all specific LanguageModel implementation. 
- """ - super().__init_subclass__(**kwargs) - cls.subclasses[cls.__name__] = cls + @property + def encoder(self): + return self.model.encoder - def forward(self, input_ids: torch.Tensor, segment_ids: torch.Tensor, padding_mask: torch.Tensor, **kwargs): + @abstractmethod + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + segment_ids: Optional[torch.Tensor], # DistilBERT does not use them, see DistilBERTLanguageModel + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: bool = False, + ): raise NotImplementedError - @classmethod - def load( - cls, - pretrained_model_name_or_path: Union[Path, str], - language: str = None, - use_auth_token: Union[bool, str] = None, - **kwargs, - ): + @property + def output_hidden_states(self): """ - Load a pretrained language model by doing one of the following: - - 1. Specifying its name and downloading the model. - 2. Pointing to the directory the model is saved in. - - Available remote models: - - * bert-base-uncased - * bert-large-uncased - * bert-base-cased - * bert-large-cased - * bert-base-multilingual-uncased - * bert-base-multilingual-cased - * bert-base-chinese - * bert-base-german-cased - * roberta-base - * roberta-large - * xlnet-base-cased - * xlnet-large-cased - * xlm-roberta-base - * xlm-roberta-large - * albert-base-v2 - * albert-large-v2 - * distilbert-base-german-cased - * distilbert-base-multilingual-cased - * google/electra-small-discriminator - * google/electra-base-discriminator - * google/electra-large-discriminator - * facebook/dpr-question_encoder-single-nq-base - * facebook/dpr-ctx_encoder-single-nq-base - - See all supported model variations at: https://huggingface.co/models. - - The appropriate language model class is inferred automatically from model configuration - or can be manually supplied using `language_model_class`. - - :param pretrained_model_name_or_path: The path of the saved pretrained model or its name. - :param revision: The version of the model to use from the Hugging Face model hub. This can be a tag name, a branch name, or a commit hash. - :param language_model_class: (Optional) Name of the language model class to load (for example `Bert`). + Controls whether the model outputs the hidden states or not """ - n_added_tokens = kwargs.pop("n_added_tokens", 0) - language_model_class = kwargs.pop("language_model_class", None) - kwargs["revision"] = kwargs.get("revision", None) - logger.info("LOADING MODEL") - logger.info("=============") - config_file = Path(pretrained_model_name_or_path) / "language_model_config.json" - if os.path.exists(config_file): - logger.info(f"Model found locally at {pretrained_model_name_or_path}") - # it's a local directory in Haystack format - config = json.load(open(config_file)) - language_model = cls.subclasses[config["name"]].load(pretrained_model_name_or_path) - else: - logger.info(f"Could not find {pretrained_model_name_or_path} locally.") - logger.info(f"Looking on Transformers Model Hub (in local cache and online)...") - if language_model_class is None: - language_model_class = cls.get_language_model_class( - pretrained_model_name_or_path, use_auth_token=use_auth_token, **kwargs - ) - - if language_model_class: - language_model = cls.subclasses[language_model_class].load( - pretrained_model_name_or_path, use_auth_token=use_auth_token, **kwargs - ) - else: - language_model = None - - if not language_model: - raise Exception( - f"Model not found for {pretrained_model_name_or_path}. 
Either supply the local path for a saved " - f"model or one of bert/roberta/xlnet/albert/distilbert models that can be downloaded from remote. " - f"Ensure that the model class name can be inferred from the directory name when loading a " - f"Transformers' model." - ) - logger.info(f"Loaded {pretrained_model_name_or_path}") + self.encoder.config.output_hidden_states = True - # resize embeddings in case of custom vocab - if n_added_tokens != 0: - # TODO verify for other models than BERT - model_emb_size = language_model.model.resize_token_embeddings(new_num_tokens=None).num_embeddings - vocab_size = model_emb_size + n_added_tokens - logger.info( - f"Resizing embedding layer of LM from {model_emb_size} to {vocab_size} to cope with custom vocab." - ) - language_model.model.resize_token_embeddings(vocab_size) - # verify - model_emb_size = language_model.model.resize_token_embeddings(new_num_tokens=None).num_embeddings - assert vocab_size == model_emb_size - - return language_model - - @staticmethod - def get_language_model_class(model_name_or_path, use_auth_token: Union[str, bool] = None, **kwargs): - # it's transformers format (either from model hub or local) - model_name_or_path = str(model_name_or_path) - - config = AutoConfig.from_pretrained(model_name_or_path, use_auth_token=use_auth_token, **kwargs) - model_type = config.model_type - if model_type == "xlm-roberta": - language_model_class = "XLMRoberta" - elif model_type == "roberta": - if "mlm" in model_name_or_path.lower(): - raise NotImplementedError("MLM part of codebert is currently not supported in Haystack") - language_model_class = "Roberta" - elif model_type == "camembert": - language_model_class = "Camembert" - elif model_type == "albert": - language_model_class = "Albert" - elif model_type == "distilbert": - language_model_class = "DistilBert" - elif model_type == "bert": - language_model_class = "Bert" - elif model_type == "xlnet": - language_model_class = "XLNet" - elif model_type == "electra": - language_model_class = "Electra" - elif model_type == "dpr": - if config.architectures[0] == "DPRQuestionEncoder": - language_model_class = "DPRQuestionEncoder" - elif config.architectures[0] == "DPRContextEncoder": - language_model_class = "DPRContextEncoder" - elif config.archictectures[0] == "DPRReader": - raise NotImplementedError("DPRReader models are currently not supported.") - elif model_type == "big_bird": - language_model_class = "BigBird" - elif model_type == "deberta-v2": - language_model_class = "DebertaV2" - else: - # Fall back to inferring type from model name - logger.warning( - "Could not infer LanguageModel class from config. Trying to infer " - "LanguageModel class from model name." - ) - language_model_class = LanguageModel._infer_language_model_class_from_string(model_name_or_path) - - return language_model_class - - @staticmethod - def _infer_language_model_class_from_string(model_name_or_path): - # If inferring Language model class from config doesn't succeed, - # fall back to inferring Language model class from model name. 
- if "xlm" in model_name_or_path.lower() and "roberta" in model_name_or_path.lower(): - language_model_class = "XLMRoberta" - elif "bigbird" in model_name_or_path.lower(): - language_model_class = "BigBird" - elif "roberta" in model_name_or_path.lower(): - language_model_class = "Roberta" - elif "codebert" in model_name_or_path.lower(): - if "mlm" in model_name_or_path.lower(): - raise NotImplementedError("MLM part of codebert is currently not supported in Haystack") - language_model_class = "Roberta" - elif "camembert" in model_name_or_path.lower() or "umberto" in model_name_or_path.lower(): - language_model_class = "Camembert" - elif "albert" in model_name_or_path.lower(): - language_model_class = "Albert" - elif "distilbert" in model_name_or_path.lower(): - language_model_class = "DistilBert" - elif "bert" in model_name_or_path.lower(): - language_model_class = "Bert" - elif "xlnet" in model_name_or_path.lower(): - language_model_class = "XLNet" - elif "electra" in model_name_or_path.lower(): - language_model_class = "Electra" - elif "word2vec" in model_name_or_path.lower() or "glove" in model_name_or_path.lower(): - language_model_class = "WordEmbedding_LM" - elif "minilm" in model_name_or_path.lower(): - language_model_class = "Bert" - elif "dpr-question_encoder" in model_name_or_path.lower(): - language_model_class = "DPRQuestionEncoder" - elif "dpr-ctx_encoder" in model_name_or_path.lower(): - language_model_class = "DPRContextEncoder" - else: - language_model_class = None + @output_hidden_states.setter + def output_hidden_states(self, value: bool): + """ + Sets the model to output the hidden states or not + """ + self.encoder.config.output_hidden_states = value - return language_model_class + @property + def output_dims(self): + """ + The output dimension of this language model + """ + if self._output_dims: + return self._output_dims - def get_output_dims(self): - config = self.model.config for odn in OUTPUT_DIM_NAMES: - if odn in dir(config): - return getattr(config, odn) - raise Exception("Could not infer the output dimensions of the language model") - - def freeze(self, layers): - """To be implemented""" - raise NotImplementedError() + try: + value = getattr(self.model.config, odn, None) + if value: + self._output_dims = value + return value + except AttributeError as e: + raise ModelingError("Can't get the output dimension before loading the model.") - def unfreeze(self): - """To be implemented""" - raise NotImplementedError() + raise ModelingError("Could not infer the output dimensions of the language model.") - def save_config(self, save_dir): + def save_config(self, save_dir: Union[Path, str]): + """ + Save the configuration of the language model in Haystack format. 
+ """ save_filename = Path(save_dir) / "language_model_config.json" + setattr(self.model.config, "name", self.name) + setattr(self.model.config, "language", self.language) + + string = self.model.config.to_json_string() with open(save_filename, "w") as file: - setattr(self.model.config, "name", self.__class__.__name__) - setattr(self.model.config, "language", self.language) - # For DPR models, transformers overwrites the model_type with the one set in DPRConfig - # Therefore, we copy the model_type from the model config to DPRConfig - if self.__class__.__name__ == "DPRQuestionEncoder" or self.__class__.__name__ == "DPRContextEncoder": - setattr(transformers.DPRConfig, "model_type", self.model.config.model_type) - string = self.model.config.to_json_string() file.write(string) def save(self, save_dir: Union[str, Path], state_dict: Dict[Any, Any] = None): @@ -327,43 +164,16 @@ def save(self, save_dir: Union[str, Path], state_dict: Dict[Any, Any] = None): """ # Save Weights save_name = Path(save_dir) / "language_model.bin" - model_to_save = ( - self.model.module if hasattr(self.model, "module") else self.model - ) # Only save the model it-self + model_to_save = self.model.module if hasattr(self.model, "module") else self.model # Only save the model itself if not state_dict: state_dict = model_to_save.state_dict() torch.save(state_dict, save_name) self.save_config(save_dir) - @classmethod - def _get_or_infer_language_from_name(cls, language, name): - if language is not None: - return language - else: - return cls._infer_language_from_name(name) - - @classmethod - def _infer_language_from_name(cls, name): - known_languages = ("german", "english", "chinese", "indian", "french", "polish", "spanish", "multilingual") - matches = [lang for lang in known_languages if lang in name] - if "camembert" in name: - language = "french" - logger.info(f"Automatically detected language from language model name: {language}") - elif "umberto" in name: - language = "italian" - logger.info(f"Automatically detected language from language model name: {language}") - elif len(matches) == 0: - language = "english" - elif len(matches) > 1: - language = matches[0] - else: - language = matches[0] - logger.info(f"Automatically detected language from language model name: {language}") - - return language - - def formatted_preds(self, logits, samples, ignore_first_token=True, padding_mask=None, input_ids=None, **kwargs): + def formatted_preds( + self, logits, samples, ignore_first_token: bool = True, padding_mask: torch.Tensor = None + ) -> List[Dict[str, Any]]: """ Extracting vectors from a language model (for example, for extracting sentence embeddings). You can use different pooling strategies and layers by specifying them in the object attributes @@ -382,7 +192,7 @@ def formatted_preds(self, logits, samples, ignore_first_token=True, padding_mask :return: A list of dictionaries containing predictions, for example: [{"context": "some text", "vec": [-0.01, 0.5 ...]}]. """ if not hasattr(self, "extraction_layer") or not hasattr(self, "extraction_strategy"): - raise ValueError( + raise ModelingError( "`extraction_layer` or `extraction_strategy` not specified for LM. " "Make sure to set both, e.g. 
via Inferencer(extraction_strategy='cls_token', extraction_layer=-1)`" ) @@ -394,12 +204,15 @@ def formatted_preds(self, logits, samples, ignore_first_token=True, padding_mask # aggregate vectors if self.extraction_strategy == "pooled": if self.extraction_layer != -1: - raise ValueError( - f"Pooled output only works for the last layer, but got extraction_layer = {self.extraction_layer}. Please set `extraction_layer=-1`.)" + raise ModelingError( + f"Pooled output only works for the last layer, but got extraction_layer={self.extraction_layer}. " + "Please set `extraction_layer=-1`" ) vecs = pooled_output.cpu().numpy() + elif self.extraction_strategy == "per_token": vecs = sequence_output.cpu().numpy() + elif self.extraction_strategy == "reduce_mean": vecs = self._pool_tokens( sequence_output, padding_mask, self.extraction_strategy, ignore_first_token=ignore_first_token @@ -411,7 +224,9 @@ def formatted_preds(self, logits, samples, ignore_first_token=True, padding_mask elif self.extraction_strategy == "cls_token": vecs = sequence_output[:, 0, :].cpu().numpy() else: - raise NotImplementedError + raise NotImplementedError( + f"This extraction strategy ({self.extraction_strategy}) is not supported by Haystack." + ) preds = [] for vec, sample in zip(vecs, samples): @@ -421,7 +236,9 @@ def formatted_preds(self, logits, samples, ignore_first_token=True, padding_mask preds.append(pred) return preds - def _pool_tokens(self, sequence_output, padding_mask, strategy, ignore_first_token): + def _pool_tokens( + self, sequence_output: torch.Tensor, padding_mask: torch.Tensor, strategy: str, ignore_first_token: bool + ): token_vecs = sequence_output.cpu().numpy() # we only take the aggregated value of non-padding tokens padding_mask = padding_mask.cpu().numpy() @@ -439,30 +256,22 @@ def _pool_tokens(self, sequence_output, padding_mask, strategy, ignore_first_tok return pooled_vecs -class Bert(LanguageModel): +class HFLanguageModel(LanguageModel): """ - A BERT model that wraps Hugging Face's implementation + A model that wraps Hugging Face's implementation (https://github.com/huggingface/transformers) to fit the LanguageModel class. - Paper: https://arxiv.org/abs/1810.04805. """ - def __init__(self): - super(Bert, self).__init__() - self.model = None - self.name = "bert" - - @classmethod - def from_scratch(cls, vocab_size, name="bert", language="en"): - bert = cls() - bert.name = name - bert.language = language - config = BertConfig(vocab_size=vocab_size) - bert.model = BertModel(config) - return bert - - @classmethod @silence_transformers_logs - def load(cls, pretrained_model_name_or_path: Union[Path, str], language: str = None, **kwargs): + def __init__( + self, + pretrained_model_name_or_path: Union[Path, str], + model_type: str, + language: str = None, + n_added_tokens: int = 0, + use_auth_token: Optional[Union[str, bool]] = None, + model_kwargs: Optional[Dict[str, Any]] = None, + ): """ Load a pretrained model by supplying one of the following: @@ -470,362 +279,110 @@ def load(cls, pretrained_model_name_or_path: Union[Path, str], language: str = N * A local path of a model trained using transformers (for example, "some_dir/huggingface_model"). * A local path of a model trained using Haystack (for example, "some_dir/haystack_model"). - :param pretrained_model_name_or_path: The path of the saved pretrained model or the name of the model. 
- """ - bert = cls() - if "haystack_lm_name" in kwargs: - bert.name = kwargs["haystack_lm_name"] - else: - bert.name = pretrained_model_name_or_path - # We need to differentiate between loading model using Haystack format and Pytorch-Transformers format - haystack_lm_config = Path(pretrained_model_name_or_path) / "language_model_config.json" - if os.path.exists(haystack_lm_config): - # Haystack style - bert_config = BertConfig.from_pretrained(haystack_lm_config) - haystack_lm_model = Path(pretrained_model_name_or_path) / "language_model.bin" - bert.model = BertModel.from_pretrained(haystack_lm_model, config=bert_config, **kwargs) - bert.language = bert.model.config.language - else: - # Pytorch-transformer Style - bert.model = BertModel.from_pretrained(str(pretrained_model_name_or_path), **kwargs) - bert.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path) - return bert - - def forward( - self, - input_ids: torch.Tensor, - segment_ids: torch.Tensor, - padding_mask: torch.Tensor, - output_hidden_states: Optional[bool] = None, - output_attentions: Optional[bool] = None, - **kwargs, - ): - """ - Perform the forward pass of the BERT model. - - :param input_ids: The IDs of each token in the input sequence. It's a tensor of shape [batch_size, max_seq_len]. - :param segment_ids: The ID of the segment. For example, in next sentence prediction, the tokens in the - first sentence are marked with 0 and the tokens in the second sentence are marked with 1. - It is a tensor of shape [batch_size, max_seq_len]. - :param padding_mask: A mask that assigns 1 to valid input tokens and 0 to padding tokens - of shape [batch_size, max_seq_len]. - :param output_hidden_states: When set to `True`, outputs hidden states in addition to the embeddings. - :param output_attentions: When set to `True`, outputs attentions in addition to the embeddings. - :return: Embeddings for each token in the input sequence. Can also return hidden states and attentions if specified using the arguments `output_hidden_states` and `output_attentions`. - """ - if output_hidden_states is None: - output_hidden_states = self.model.encoder.config.output_hidden_states - if output_attentions is None: - output_attentions = self.model.encoder.config.output_attentions - - output_tuple = self.model( - input_ids, - token_type_ids=segment_ids, - attention_mask=padding_mask, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=False, - ) - return output_tuple - - def enable_hidden_states_output(self): - self.model.encoder.config.output_hidden_states = True - - def disable_hidden_states_output(self): - self.model.encoder.config.output_hidden_states = False - - -class Albert(LanguageModel): - """ - An ALBERT model that wraps the Hugging Face's implementation - (https://github.com/huggingface/transformers) to fit the LanguageModel class. - """ - - def __init__(self): - super(Albert, self).__init__() - self.model = None - self.name = "albert" - - @classmethod - @silence_transformers_logs - def load(cls, pretrained_model_name_or_path: Union[Path, str], language: str = None, **kwargs): - """ - Load a language model by supplying one of the following: - - * The name of a remote model on s3 (for example: "albert-base"). - * A local path of a model trained using transformers (for example: "some_dir/huggingface_model") - * A local path of a model trained using Haystack (for example: "some_dir/Haystack_model") - - :param pretrained_model_name_or_path: Name or path of a model. 
- :param language: (Optional) The language the model was trained for (for example "german"). - If not supplied, Haystack tries to infer it from the model name. - :return: Language Model - """ - albert = cls() - if "haystack_lm_name" in kwargs: - albert.name = kwargs["haystack_lm_name"] - else: - albert.name = pretrained_model_name_or_path - # We need to differentiate between loading model using Haystack format and Pytorch-Transformers format - haystack_lm_config = Path(pretrained_model_name_or_path) / "language_model_config.json" - if os.path.exists(haystack_lm_config): - # Haystack style - config = AlbertConfig.from_pretrained(haystack_lm_config) - haystack_lm_model = Path(pretrained_model_name_or_path) / "language_model.bin" - albert.model = AlbertModel.from_pretrained(haystack_lm_model, config=config, **kwargs) - albert.language = albert.model.config.language - else: - # Huggingface transformer Style - albert.model = AlbertModel.from_pretrained(str(pretrained_model_name_or_path), **kwargs) - albert.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path) - return albert - - def forward( - self, - input_ids: torch.Tensor, - segment_ids: torch.Tensor, - padding_mask: torch.Tensor, - output_hidden_states: Optional[bool] = None, - output_attentions: Optional[bool] = None, - **kwargs, - ): - """ - Perform the forward pass of the Albert model. - - :param input_ids: The IDs of each token in the input sequence. Is a tensor of shape [batch_size, max_seq_len]. - :param segment_ids: The ID of the segment. For example, in next sentence prediction, the tokens in the - first sentence are marked with 0 and the tokens in the second sentence are marked with 1. - It is a tensor of shape [batch_size, max_seq_len]. - :param padding_mask: A mask that assigns 1 to valid input tokens and 0 to padding tokens - of shape [batch_size, max_seq_len]. - :param output_hidden_states: When set to `True`, outputs hidden states in addition to the embeddings. - :param output_attentions: When set to `True`, outputs attentions in addition to the embeddings. - :return: Embeddings for each token in the input sequence. - """ - if output_hidden_states is None: - output_hidden_states = self.model.encoder.config.output_hidden_states - if output_attentions is None: - output_attentions = self.model.encoder.config.output_attentions - - output_tuple = self.model( - input_ids, - token_type_ids=segment_ids, - attention_mask=padding_mask, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=False, - ) - return output_tuple - - def enable_hidden_states_output(self): - self.model.encoder.config.output_hidden_states = True - - def disable_hidden_states_output(self): - self.model.encoder.config.output_hidden_states = False - - -class Roberta(LanguageModel): - """ - A roberta model that wraps the Hugging Face's implementation - (https://github.com/huggingface/transformers) to fit the LanguageModel class. - Paper: https://arxiv.org/abs/1907.11692 - """ - - def __init__(self): - super(Roberta, self).__init__() - self.model = None - self.name = "roberta" + You can also use `get_language_model()` for a uniform interface across different model types. - @classmethod - @silence_transformers_logs - def load(cls, pretrained_model_name_or_path: Union[Path, str], language: str = None, **kwargs): + :param pretrained_model_name_or_path: The path of the saved pretrained model or the name of the model. 
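The docstring above points to `get_language_model()` as the uniform entry point. A hedged sketch of how it could be called; the factory itself is defined elsewhere in this file, so its exact signature is assumed here.

```python
from haystack.modeling.model.language_model import get_language_model

# Assumed usage: the factory should pick the matching wrapper (HFLanguageModel,
# HFLanguageModelWithPooler, DPREncoder, ...) from the model's config or name.
language_model = get_language_model("roberta-base")
print(type(language_model).__name__, language_model.language)
```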
+ :param model_type: the HuggingFace class name prefix (for example 'Bert', 'Roberta', etc...) + :param language: the model's language ('multilingual' is also accepted) + :param use_auth_token: the HF token or False """ - Load a language model by supplying one of the following: + super().__init__(model_type=model_type) - * The name of a remote model on s3 (for example: "roberta-base"). - * A local path of a model trained using transformers (for example: "some_dir/huggingface_model"). - * A local path of a model trained using Haystack (for example: "some_dir/haystack_model"). + config_class: PretrainedConfig = getattr(transformers, model_type + "Config", None) + model_class: PreTrainedModel = getattr(transformers, model_type + "Model", None) - :param pretrained_model_name_or_path: Name or path of a model. - :param language: (Optional) The language the model was trained for (for example: "german"). - If not supplied, Haystack tries to infer it from the model name. - :return: Language Model - """ - roberta = cls() - if "haystack_lm_name" in kwargs: - roberta.name = kwargs["haystack_lm_name"] - else: - roberta.name = pretrained_model_name_or_path - # We need to differentiate between loading model using Haystack format and Pytorch-Transformers format haystack_lm_config = Path(pretrained_model_name_or_path) / "language_model_config.json" if os.path.exists(haystack_lm_config): # Haystack style - config = RobertaConfig.from_pretrained(haystack_lm_config) haystack_lm_model = Path(pretrained_model_name_or_path) / "language_model.bin" - roberta.model = RobertaModel.from_pretrained(haystack_lm_model, config=config, **kwargs) - roberta.language = roberta.model.config.language + model_config = config_class.from_pretrained(haystack_lm_config) + self.model = model_class.from_pretrained( + haystack_lm_model, config=model_config, use_auth_token=use_auth_token, **(model_kwargs or {}) + ) + self.language = self.model.config.language else: - # Huggingface transformer Style - roberta.model = RobertaModel.from_pretrained(str(pretrained_model_name_or_path), **kwargs) - roberta.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path) - return roberta - - def forward( - self, - input_ids: torch.Tensor, - segment_ids: torch.Tensor, - padding_mask: torch.Tensor, - output_hidden_states: Optional[bool] = None, - output_attentions: Optional[bool] = None, - **kwargs, - ): - """ - Perform the forward pass of the Roberta model. - - :param input_ids: The IDs of each token in the input sequence. It's a tensor of shape [batch_size, max_seq_len]. - :param segment_ids: The ID of the segment. For example, in next sentence prediction, the tokens in the - first sentence are marked with 0 and the tokens in the second sentence are marked with 1. - It is a tensor of shape [batch_size, max_seq_len]. - :param padding_mask: A mask that assigns 1 to valid input tokens and 0 to padding tokens - of shape [batch_size, max_seq_len]. - :param output_hidden_states: When set to `True`, outputs hidden states in addition to the embeddings. - :param output_attentions: When set to `True`, outputs attentions in addition to the embeddings. - :return: Embeddings for each token in the input sequence. 
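A save/reload round trip, sketched from the `save()`/`save_config()` methods earlier in this hunk and the Haystack-style branch above. The directory path is a placeholder.

```python
from pathlib import Path
from haystack.modeling.model.language_model import HFLanguageModel

save_dir = Path("some_dir/haystack_model")  # placeholder path
save_dir.mkdir(parents=True, exist_ok=True)

# save() writes "language_model.bin" plus "language_model_config.json"
bert = HFLanguageModel("bert-base-uncased", model_type="Bert", language="english")
bert.save(save_dir)

# The presence of language_model_config.json triggers the "Haystack style" branch above
reloaded = HFLanguageModel(save_dir, model_type="Bert")
assert reloaded.language == "english"
```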
- """ - if output_hidden_states is None: - output_hidden_states = self.model.encoder.config.output_hidden_states - if output_attentions is None: - output_attentions = self.model.encoder.config.output_attentions - - output_tuple = self.model( - input_ids, - token_type_ids=segment_ids, - attention_mask=padding_mask, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=False, - ) - return output_tuple - - def enable_hidden_states_output(self): - self.model.encoder.config.output_hidden_states = True - - def disable_hidden_states_output(self): - self.model.encoder.config.output_hidden_states = False - - -class XLMRoberta(LanguageModel): - """ - A roberta model that wraps the Hugging Face's implementation - (https://github.com/huggingface/transformers) to fit the LanguageModel class. - Paper: https://arxiv.org/abs/1907.11692 - """ - - def __init__(self): - super(XLMRoberta, self).__init__() - self.model = None - self.name = "xlm_roberta" - - @classmethod - @silence_transformers_logs - def load(cls, pretrained_model_name_or_path: Union[Path, str], language: str = None, **kwargs): - """ - Load a language model by supplying one fo the following: - - * The name of a remote model on s3 (for example: "xlm-roberta-base") - * A local path of a model trained using transformers (for example: "some_dir/huggingface_model"). - * A local path of a model trained using Haystack (for example: "some_dir/haystack_model"). + # Pytorch-transformer Style + self.model = model_class.from_pretrained( + str(pretrained_model_name_or_path), use_auth_token=use_auth_token, **(model_kwargs or {}) + ) + self.language = language or _guess_language(str(pretrained_model_name_or_path)) - :param pretrained_model_name_or_path: Name or path of a model. - :param language: (Optional) The language the model was trained for (for example, "german"). - If not supplied, Haystack tries to infer it from the model name. - :return: Language Model - """ - xlm_roberta = cls() - if "haystack_lm_name" in kwargs: - xlm_roberta.name = kwargs["haystack_lm_name"] - else: - xlm_roberta.name = pretrained_model_name_or_path - # We need to differentiate between loading model using Haystack format and Pytorch-Transformers format - haystack_lm_config = Path(pretrained_model_name_or_path) / "language_model_config.json" - if os.path.exists(haystack_lm_config): - # Haystack style - config = XLMRobertaConfig.from_pretrained(haystack_lm_config) - haystack_lm_model = Path(pretrained_model_name_or_path) / "language_model.bin" - xlm_roberta.model = XLMRobertaModel.from_pretrained(haystack_lm_model, config=config, **kwargs) - xlm_roberta.language = xlm_roberta.model.config.language - else: - # Huggingface transformer Style - xlm_roberta.model = XLMRobertaModel.from_pretrained(str(pretrained_model_name_or_path), **kwargs) - xlm_roberta.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path) - return xlm_roberta + # resize embeddings in case of custom vocab + if n_added_tokens != 0: + # TODO verify for other models than BERT + model_emb_size = self.model.resize_token_embeddings(new_num_tokens=None).num_embeddings + vocab_size = model_emb_size + n_added_tokens + logger.info( + f"Resizing embedding layer of LM from {model_emb_size} to {vocab_size} to cope with custom vocab." 
+ ) + self.model.resize_token_embeddings(vocab_size) + # verify + model_emb_size = self.model.resize_token_embeddings(new_num_tokens=None).num_embeddings + assert vocab_size == model_emb_size def forward( self, input_ids: torch.Tensor, + attention_mask: torch.Tensor, segment_ids: torch.Tensor, - padding_mask: torch.Tensor, output_hidden_states: Optional[bool] = None, output_attentions: Optional[bool] = None, - **kwargs, + return_dict: bool = False, ): """ - Perform the forward pass of the XLMRoberta model. + Perform the forward pass of the model. :param input_ids: The IDs of each token in the input sequence. It's a tensor of shape [batch_size, max_seq_len]. :param segment_ids: The ID of the segment. For example, in next sentence prediction, the tokens in the first sentence are marked with 0 and the tokens in the second sentence are marked with 1. It is a tensor of shape [batch_size, max_seq_len]. - :param padding_mask: A mask that assigns 1 to valid input tokens and 0 to padding tokens - of shape [batch_size, max_seq_len]. + :param attention_mask: A mask that assigns 1 to valid input tokens and 0 to padding tokens + of shape [batch_size, max_seq_len]. Different models call this parameter differently (padding/attention mask). :param output_hidden_states: When set to `True`, outputs hidden states in addition to the embeddings. :param output_attentions: When set to `True`, outputs attentions in addition to the embeddings. - :return: Embeddings for each token in the input sequence. + :return: Embeddings for each token in the input sequence. Can also return hidden states and attentions if specified using the arguments `output_hidden_states` and `output_attentions`. """ - if output_hidden_states is None: - output_hidden_states = self.model.encoder.config.output_hidden_states - if output_attentions is None: - output_attentions = self.model.encoder.config.output_attentions - - output_tuple = self.model( - input_ids, - token_type_ids=segment_ids, - attention_mask=padding_mask, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=False, - ) - return output_tuple - - def enable_hidden_states_output(self): - self.model.encoder.config.output_hidden_states = True - - def disable_hidden_states_output(self): - self.model.encoder.config.output_hidden_states = False - - -class DistilBert(LanguageModel): + if hasattr(self, "encoder"): # Not all models have an encoder + if output_hidden_states is None: + output_hidden_states = self.model.encoder.config.output_hidden_states + if output_attentions is None: + output_attentions = self.model.encoder.config.output_attentions + + params = {} + if input_ids is not None: + params["input_ids"] = input_ids + if segment_ids is not None: + # Some models don't take this (see DistilBERT) + params["token_type_ids"] = segment_ids + if attention_mask is not None: + params["attention_mask"] = attention_mask + if output_hidden_states: + params["output_hidden_states"] = output_hidden_states + if output_attentions: + params["output_attentions"] = output_attentions + + return self.model(**params, return_dict=return_dict) + + +class HFLanguageModelWithPooler(HFLanguageModel): """ - A DistilBERT model that wraps Hugging Face's implementation - (https://github.com/huggingface/transformers) to fit the LanguageModel class. + A model that wraps Hugging Face's implementation + (https://github.com/huggingface/transformers) to fit the LanguageModel class, + with an extra pooler. 
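A sketch of the unified forward pass defined above, using dummy token IDs. With `return_dict=False` (the default here), the wrapped BERT model returns a tuple whose first two entries are the sequence output and the pooled output.

```python
import torch
from haystack.modeling.model.language_model import HFLanguageModel

bert = HFLanguageModel("bert-base-uncased", model_type="Bert", language="english")

input_ids = torch.tensor([[101, 7592, 2088, 102]])  # [batch_size, max_seq_len], dummy ids
attention_mask = torch.ones_like(input_ids)         # 1 = real token, 0 = padding
segment_ids = torch.zeros_like(input_ids)           # single-segment input

sequence_output, pooled_output = bert.forward(
    input_ids=input_ids, attention_mask=attention_mask, segment_ids=segment_ids
)[:2]
print(sequence_output.shape)  # torch.Size([1, 4, 768])
```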
NOTE: - - DistilBert doesn’t have `token_type_ids`, you don’t need to indicate which - token belongs to which segment. Just separate your segments with the separation - token `tokenizer.sep_token` (or [SEP]). - - Unlike the other BERT variants, DistilBert does not output the - `pooled_output`. An additional pooler is initialized. + - Unlike the other BERT variants, these don't output the `pooled_output`. An additional pooler is initialized. """ - def __init__(self): - super(DistilBert, self).__init__() - self.model = None - self.name = "distilbert" - self.pooler = None - - @classmethod - @silence_transformers_logs - def load(cls, pretrained_model_name_or_path: Union[Path, str], language: str = None, **kwargs): + def __init__( + self, + pretrained_model_name_or_path: Union[Path, str], + model_type: str, + language: str = None, + n_added_tokens: int = 0, + use_auth_token: Optional[Union[str, bool]] = None, + model_kwargs: Optional[Dict[str, Any]] = None, + ): """ Load a pretrained model by supplying one of the following: @@ -835,840 +392,576 @@ def load(cls, pretrained_model_name_or_path: Union[Path, str], language: str = N :param pretrained_model_name_or_path: The path of the saved pretrained model or its name. """ - distilbert = cls() - if "haystack_lm_name" in kwargs: - distilbert.name = kwargs["haystack_lm_name"] - else: - distilbert.name = pretrained_model_name_or_path - # We need to differentiate between loading model using Haystack format and Pytorch-Transformers format - haystack_lm_config = Path(pretrained_model_name_or_path) / "language_model_config.json" - if os.path.exists(haystack_lm_config): - # Haystack style - config = DistilBertConfig.from_pretrained(haystack_lm_config) - haystack_lm_model = Path(pretrained_model_name_or_path) / "language_model.bin" - distilbert.model = DistilBertModel.from_pretrained(haystack_lm_model, config=config, **kwargs) - distilbert.language = distilbert.model.config.language - else: - # Pytorch-transformer Style - distilbert.model = DistilBertModel.from_pretrained(str(pretrained_model_name_or_path), **kwargs) - distilbert.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path) - config = distilbert.model.config + super().__init__( + pretrained_model_name_or_path=pretrained_model_name_or_path, + model_type=model_type, + language=language, + n_added_tokens=n_added_tokens, + use_auth_token=use_auth_token, + model_kwargs=model_kwargs, + ) + config = self.model.config - # DistilBERT does not provide a pooled_output by default. Therefore, we need to initialize an extra pooler. + # These models do not provide a pooled_output by default. Therefore, we need to initialize an extra pooler. # The pooler takes the first hidden representation & feeds it to a dense layer of (hidden_dim x hidden_dim). # We don't want a dropout in the end of the pooler, since we do that already in the adaptive model before we # feed everything to the prediction head - config.summary_last_dropout = 0 - config.summary_type = "first" - config.summary_activation = "tanh" - distilbert.pooler = SequenceSummary(config) - distilbert.pooler.apply(distilbert.model._init_weights) - return distilbert - - def forward( # type: ignore - self, - input_ids: torch.Tensor, - padding_mask: torch.Tensor, - output_hidden_states: Optional[bool] = None, - output_attentions: Optional[bool] = None, - **kwargs, - ): - """ - Perform the forward pass of the DistilBERT model. - - :param input_ids: The IDs of each token in the input sequence. 
It's a tensor of shape [batch_size, max_seq_len]. - :param padding_mask: A mask that assigns 1 to valid input tokens and 0 to padding tokens - of shape [batch_size, max_seq_len]. - :param output_hidden_states: When set to `True`, outputs hidden states in addition to the embeddings. - :param output_attentions: When set to `True`, outputs attentions in addition to the embeddings. - :return: Embeddings for each token in the input sequence. - """ - if output_hidden_states is None: - output_hidden_states = self.model.encoder.config.output_hidden_states - if output_attentions is None: - output_attentions = self.model.encoder.config.output_attentions - - output_tuple = self.model( - input_ids, - attention_mask=padding_mask, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=False, - ) - # We need to manually aggregate that to get a pooled output (one vec per seq) - pooled_output = self.pooler(output_tuple[0]) - return (output_tuple[0], pooled_output) + output_tuple[1:] - - def enable_hidden_states_output(self): - self.model.config.output_hidden_states = True - - def disable_hidden_states_output(self): - self.model.config.output_hidden_states = False - - -class XLNet(LanguageModel): - """ - A XLNet model that wraps the Hugging Face's implementation - (https://github.com/huggingface/transformers) to fit the LanguageModel class. - Paper: https://arxiv.org/abs/1906.08237 - """ + sequence_summary_config = POOLER_PARAMETERS.get(self.name.lower(), {}) + for key, value in sequence_summary_config.items(): + setattr(config, key, value) - def __init__(self): - super(XLNet, self).__init__() - self.model = None - self.name = "xlnet" - self.pooler = None - - @classmethod - @silence_transformers_logs - def load(cls, pretrained_model_name_or_path: Union[Path, str], language: str = None, **kwargs): - """ - Load a language model by supplying one of the following: - - * The name of a remote model on s3 (for example, "xlnet-base-cased"). - * A local path of a model trained using transformers (for example, "some_dir/huggingface_model"). - * Alocal path of a model trained using Haystack (for example, "some_dir/haystack_model"). - - :param pretrained_model_name_or_path: Name or path of a model. - :param language: (Optional) The language the model was trained for (for example, "german"). - If not supplied, Haystack tries to infer it from the model name. - :return: Language Model - """ - xlnet = cls() - if "haystack_lm_name" in kwargs: - xlnet.name = kwargs["haystack_lm_name"] - else: - xlnet.name = pretrained_model_name_or_path - # We need to differentiate between loading model using Haystack format and Pytorch-Transformers format - haystack_lm_config = Path(pretrained_model_name_or_path) / "language_model_config.json" - if os.path.exists(haystack_lm_config): - # Haystack style - config = XLNetConfig.from_pretrained(haystack_lm_config) - haystack_lm_model = Path(pretrained_model_name_or_path) / "language_model.bin" - xlnet.model = XLNetModel.from_pretrained(haystack_lm_model, config=config, **kwargs) - xlnet.language = xlnet.model.config.language - else: - # Pytorch-transformer Style - xlnet.model = XLNetModel.from_pretrained(str(pretrained_model_name_or_path), **kwargs) - xlnet.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path) - config = xlnet.model.config - # XLNet does not provide a pooled_output by default. Therefore, we need to initialize an extra pooler. 
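A rough sketch of the pooling mechanism these wrappers rely on: the relevant `summary_*` values are copied onto the config (see `POOLER_PARAMETERS` further down in this file) and a `SequenceSummary` head condenses the first token's hidden state. The ELECTRA checkpoint name is an assumption for illustration.

```python
import torch
from transformers import ElectraConfig
from transformers.modeling_utils import SequenceSummary

config = ElectraConfig.from_pretrained("google/electra-small-discriminator")
config.summary_last_dropout = 0
config.summary_type = "first"       # pool the first token's representation
config.summary_activation = "gelu"
config.summary_use_proj = False

pooler = SequenceSummary(config)
hidden_states = torch.randn(2, 16, config.hidden_size)  # [batch_size, seq_len, hidden_dim]
pooled = pooler(hidden_states)                           # [batch_size, hidden_dim]
print(pooled.shape)
```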
- # The pooler takes the last hidden representation & feeds it to a dense layer of (hidden_dim x hidden_dim). - # We don't want a dropout in the end of the pooler, since we do that already in the adaptive model before we - # feed everything to the prediction head - config.summary_last_dropout = 0 - xlnet.pooler = SequenceSummary(config) - xlnet.pooler.apply(xlnet.model._init_weights) - return xlnet + self.pooler = SequenceSummary(config) + self.pooler.apply(self.model._init_weights) def forward( self, input_ids: torch.Tensor, - segment_ids: torch.Tensor, - padding_mask: torch.Tensor, + attention_mask: torch.Tensor, + segment_ids: Optional[torch.Tensor], output_hidden_states: Optional[bool] = None, output_attentions: Optional[bool] = None, - **kwargs, + return_dict: bool = False, ): """ - Perform the forward pass of the XLNet model. + Perform the forward pass of the model. :param input_ids: The IDs of each token in the input sequence. It's a tensor of shape [batch_size, max_seq_len]. :param segment_ids: The ID of the segment. For example, in next sentence prediction, the tokens in the first sentence are marked with 0 and the tokens in the second sentence are marked with 1. - It is a tensor of shape [batch_size, max_seq_len]. - :param padding_mask: A mask that assigns 1 to valid input tokens and 0 to padding tokens - of shape [batch_size, max_seq_len]. + It is a tensor of shape [batch_size, max_seq_len]. Optional, some models don't need it (DistilBERT for example) + :param padding_mask/attention_mask: A mask that assigns 1 to valid input tokens and 0 to padding tokens + of shape [batch_size, max_seq_len]. Different models call this parameter differently (padding/attention mask). :param output_hidden_states: When set to `True`, outputs hidden states in addition to the embeddings. :param output_attentions: When set to `True`, outputs attentions in addition to the embeddings. :return: Embeddings for each token in the input sequence. """ - if output_hidden_states is None: - output_hidden_states = self.model.encoder.config.output_hidden_states - if output_attentions is None: - output_attentions = self.model.encoder.config.output_attentions - - # Note: XLNet has a couple of special input tensors for pretraining / text generation (perm_mask, target_mapping ...) - # We will need to implement them, if we wanna support LM adaptation - output_tuple = self.model( - input_ids, - attention_mask=padding_mask, + output_tuple = super().forward( + input_ids=input_ids, + segment_ids=segment_ids, + attention_mask=attention_mask, output_hidden_states=output_hidden_states, output_attentions=output_attentions, - return_dict=False, + return_dict=return_dict, ) - # XLNet also only returns the sequence_output (one vec per token) - # We need to manually aggregate that to get a pooled output (one vec per seq) - # TODO verify that this is really doing correct pooling pooled_output = self.pooler(output_tuple[0]) return (output_tuple[0], pooled_output) + output_tuple[1:] - def enable_hidden_states_output(self): - self.model.output_hidden_states = True - - def disable_hidden_states_output(self): - self.model.output_hidden_states = False - -class Electra(LanguageModel): +class HFLanguageModelNoSegmentIds(HFLanguageModelWithPooler): """ - ELECTRA is a new pre-training approach which trains two transformer models: - the generator and the discriminator. The generator replaces tokens in a sequence, - and is therefore trained as a masked language model. 
The discriminator, which is - the model we're interested in, tries to identify which tokens were replaced by - the generator in the sequence. - - The ELECTRA model here wraps Hugging Face's implementation + A model that wraps Hugging Face's implementation of a model that does not need segment ids. (https://github.com/huggingface/transformers) to fit the LanguageModel class. - NOTE: - - Electra does not output the `pooled_output`. An additional pooler is initialized. + These are for now kept in a separate subclass to show a proper warning. """ - def __init__(self): - super(Electra, self).__init__() - self.model = None - self.name = "electra" - self.pooler = None - - @classmethod - @silence_transformers_logs - def load(cls, pretrained_model_name_or_path: Union[Path, str], language: str = None, **kwargs): - """ - Load a pretrained model by supplying one of the following - - * The name of a remote model on s3 (for example, "google/electra-base-discriminator"). - * A local path of a model trained using transformers ("some_dir/huggingface_model"). - * A local path of a model trained using Haystack ("some_dir/haystack_model"). - - :param pretrained_model_name_or_path: The path of the saved pretrained model or its name. - """ - electra = cls() - if "haystack_lm_name" in kwargs: - electra.name = kwargs["haystack_lm_name"] - else: - electra.name = pretrained_model_name_or_path - # We need to differentiate between loading model using Haystack format and Transformers format - haystack_lm_config = Path(pretrained_model_name_or_path) / "language_model_config.json" - if os.path.exists(haystack_lm_config): - # Haystack style - config = ElectraConfig.from_pretrained(haystack_lm_config) - haystack_lm_model = Path(pretrained_model_name_or_path) / "language_model.bin" - electra.model = ElectraModel.from_pretrained(haystack_lm_model, config=config, **kwargs) - electra.language = electra.model.config.language - else: - # Transformers Style - electra.model = ElectraModel.from_pretrained(str(pretrained_model_name_or_path), **kwargs) - electra.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path) - config = electra.model.config - - # ELECTRA does not provide a pooled_output by default. Therefore, we need to initialize an extra pooler. - # The pooler takes the first hidden representation & feeds it to a dense layer of (hidden_dim x hidden_dim). - # We don't want a dropout in the end of the pooler, since we do that already in the adaptive model before we - # feed everything to the prediction head. - # Note: ELECTRA uses gelu as activation (BERT uses tanh instead) - config.summary_last_dropout = 0 - config.summary_type = "first" - config.summary_activation = "gelu" - config.summary_use_proj = False - electra.pooler = SequenceSummary(config) - electra.pooler.apply(electra.model._init_weights) - return electra - def forward( self, input_ids: torch.Tensor, - segment_ids: torch.Tensor, - padding_mask: torch.Tensor, + attention_mask: torch.Tensor, + segment_ids: Optional[torch.Tensor] = None, output_hidden_states: Optional[bool] = None, output_attentions: Optional[bool] = None, - **kwargs, + return_dict: bool = False, ): """ - Perform the forward pass of the ELECTRA model. + Perform the forward pass of the model. :param input_ids: The IDs of each token in the input sequence. It's a tensor of shape [batch_size, max_seq_len]. - :param padding_mask: A mask that assigns 1 to valid input tokens and 0 to padding tokens - of shape [batch_size, max_seq_len]. 
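A short sketch of the segment-id-free wrapper: DistilBERT-style models have no `token_type_ids`, so `segment_ids` is dropped (with a warning if it is not `None`). The checkpoint name is illustrative.

```python
import torch
from haystack.modeling.model.language_model import HFLanguageModelNoSegmentIds

distilbert = HFLanguageModelNoSegmentIds(
    pretrained_model_name_or_path="distilbert-base-uncased", model_type="DistilBert"
)
input_ids = torch.tensor([[101, 7592, 102]])
sequence_output, pooled_output = distilbert.forward(
    input_ids=input_ids,
    attention_mask=torch.ones_like(input_ids),
    segment_ids=None,  # ignored by this model family
)[:2]
```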
+ :param attention_mask: A mask that assigns 1 to valid input tokens and 0 to padding tokens + of shape [batch_size, max_seq_len]. Different models call this parameter differently (padding/attention mask). + :param segment_ids: Unused. See DistilBERT documentation. :param output_hidden_states: When set to `True`, outputs hidden states in addition to the embeddings. :param output_attentions: When set to `True`, outputs attentions in addition to the embeddings. - :return: Embeddings for each token in the input sequence. + :return: Embeddings for each token in the input sequence. Can also return hidden states and attentions if + specified using the arguments `output_hidden_states` and `output_attentions`. """ - output_tuple = self.model(input_ids, token_type_ids=segment_ids, attention_mask=padding_mask, return_dict=False) - - if output_hidden_states is None: - output_hidden_states = self.model.encoder.config.output_hidden_states - if output_attentions is None: - output_attentions = self.model.encoder.config.output_attentions + if segment_ids is not None: + logging.warning(f"`segment_ids` is not None, but {self.name} does not use them. They will be ignored.") - output_tuple = self.model( - input_ids, - attention_mask=padding_mask, + return super().forward( + input_ids=input_ids, + segment_ids=None, + attention_mask=attention_mask, output_hidden_states=output_hidden_states, output_attentions=output_attentions, + return_dict=return_dict, ) - # We need to manually aggregate that to get a pooled output (one vec per seq) - pooled_output = self.pooler(output_tuple[0]) - return (output_tuple[0], pooled_output) + output_tuple[1:] - - def disable_hidden_states_output(self): - self.model.config.output_hidden_states = False - - -class Camembert(Roberta): - """ - A Camembert model that wraps the Hugging Face's implementation - (https://github.com/huggingface/transformers) to fit the LanguageModel class. - """ - - def __init__(self): - super(Camembert, self).__init__() - self.model = None - self.name = "camembert" - - @classmethod - @silence_transformers_logs - def load(cls, pretrained_model_name_or_path: Union[Path, str], language: str = None, **kwargs): - """ - Load a language model by supplying one of the following: - - * The name of a remote model on s3 (for example, "camembert-base"). - * A local path of a model trained using transformers (for example, "some_dir/huggingface_model"). - * A local path of a model trained using Haystack (for example, "some_dir/haystack_model"). - - :param pretrained_model_name_or_path: Name or path of a model. - :param language: (Optional) The language the model was trained for (for example, "german"). - If not supplied, Haystack tries to infer it from the model name. 
- :return: Language Model - """ - camembert = cls() - if "haystack_lm_name" in kwargs: - camembert.name = kwargs["haystack_lm_name"] - else: - camembert.name = pretrained_model_name_or_path - # We need to differentiate between loading model using Haystack format and Pytorch-Transformers format - haystack_lm_config = Path(pretrained_model_name_or_path) / "language_model_config.json" - if os.path.exists(haystack_lm_config): - # Haystack style - config = CamembertConfig.from_pretrained(haystack_lm_config) - haystack_lm_model = Path(pretrained_model_name_or_path) / "language_model.bin" - camembert.model = CamembertModel.from_pretrained(haystack_lm_model, config=config, **kwargs) - camembert.language = camembert.model.config.language - else: - # Huggingface transformer Style - camembert.model = CamembertModel.from_pretrained(str(pretrained_model_name_or_path), **kwargs) - camembert.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path) - return camembert -class DPRQuestionEncoder(LanguageModel): +class DPREncoder(LanguageModel): """ - A DPRQuestionEncoder model that wraps Hugging Face's implementation. + A DPREncoder model that wraps Hugging Face's implementation. """ - def __init__(self): - super(DPRQuestionEncoder, self).__init__() - self.model = None - self.name = "dpr_question_encoder" - - @classmethod @silence_transformers_logs - def load( - cls, + def __init__( + self, pretrained_model_name_or_path: Union[Path, str], + model_type: str, language: str = None, - use_auth_token: Union[str, bool] = None, - **kwargs, + n_added_tokens: int = 0, + use_auth_token: Optional[Union[str, bool]] = None, + model_kwargs: Optional[Dict[str, Any]] = None, ): """ Load a pretrained model by supplying one of the following: - * The name of a remote model on s3 (for example, "facebook/dpr-question_encoder-single-nq-base"). * A local path of a model trained using transformers (for example, "some_dir/huggingface_model"). * A local path of a model trained using Haystack (for example, "some_dir/haystack_model"). :param pretrained_model_name_or_path: The path of the base pretrained language model whose weights are used to initialize DPRQuestionEncoder. - """ - dpr_question_encoder = cls() - if "haystack_lm_name" in kwargs: - dpr_question_encoder.name = kwargs["haystack_lm_name"] - else: - dpr_question_encoder.name = pretrained_model_name_or_path + :param model_type: the type of model (see `HUGGINGFACE_TO_HAYSTACK`) + :param model_kwargs: any kwarg to pass to the model at init + :param language: the model's language. If not given, it will be inferred. Defaults to english. 
+ :param n_added_tokens: unused for `DPREncoder` + :param use_auth_token: useful if the model is from the HF Hub and private + :param model_kwargs: any kwarg to pass to the model at init + """ + super().__init__(model_type=model_type) + self.role = "question" if "question" in model_type.lower() else "context" + self._encoder = None + + model_classname = f"DPR{self.role.capitalize()}Encoder" + try: + model_class: Type[PreTrainedModel] = getattr(transformers, model_classname) + except AttributeError as e: + raise ModelingError(f"Model class of type '{model_classname}' not found.") - # We need to differentiate between loading model using Haystack format and Pytorch-Transformers format haystack_lm_config = Path(pretrained_model_name_or_path) / "language_model_config.json" if os.path.exists(haystack_lm_config): - # Haystack style - original_model_config = AutoConfig.from_pretrained(haystack_lm_config) - haystack_lm_model = Path(pretrained_model_name_or_path) / "language_model.bin" - - if original_model_config.model_type == "dpr": - dpr_config = transformers.DPRConfig.from_pretrained(haystack_lm_config) - dpr_question_encoder.model = transformers.DPRQuestionEncoder.from_pretrained( - haystack_lm_model, config=dpr_config, **kwargs - ) - else: - if original_model_config.model_type != "bert": - logger.warning( - f"Using a model of type '{original_model_config.model_type}' which might be incompatible with DPR encoders." - f"Bert based encoders are supported that need input_ids,token_type_ids,attention_mask as input tensors." - ) - original_config_dict = vars(original_model_config) - original_config_dict.update(kwargs) - dpr_question_encoder.model = transformers.DPRQuestionEncoder( - config=transformers.DPRConfig(**original_config_dict) - ) - language_model_class = cls.get_language_model_class(haystack_lm_config, use_auth_token, **kwargs) - dpr_question_encoder.model.base_model.bert_model = ( - cls.subclasses[language_model_class].load(str(pretrained_model_name_or_path)).model - ) - dpr_question_encoder.language = dpr_question_encoder.model.config.language - else: - original_model_config = AutoConfig.from_pretrained( - pretrained_model_name_or_path, use_auth_token=use_auth_token + self._init_model_haystack_style( + haystack_lm_config=haystack_lm_config, + model_name_or_path=pretrained_model_name_or_path, + model_class=model_class, + model_kwargs=model_kwargs or {}, + use_auth_token=use_auth_token, ) - if original_model_config.model_type == "dpr": - # "pretrained dpr model": load existing pretrained DPRQuestionEncoder model - dpr_question_encoder.model = transformers.DPRQuestionEncoder.from_pretrained( - str(pretrained_model_name_or_path), use_auth_token=use_auth_token, **kwargs - ) - else: - # "from scratch": load weights from different architecture (e.g. bert) into DPRQuestionEncoder - # but keep config values from original architecture - # TODO test for architectures other than BERT, e.g. Electra - if original_model_config.model_type != "bert": - logger.warning( - f"Using a model of type '{original_model_config.model_type}' which might be incompatible with DPR encoders." - f"Bert based encoders are supported that need input_ids,token_type_ids,attention_mask as input tensors." 
- ) - original_config_dict = vars(original_model_config) - original_config_dict.update(kwargs) - dpr_question_encoder.model = transformers.DPRQuestionEncoder( - config=transformers.DPRConfig(**original_config_dict) - ) - dpr_question_encoder.model.base_model.bert_model = AutoModel.from_pretrained( - str(pretrained_model_name_or_path), use_auth_token=use_auth_token, **original_config_dict - ) - dpr_question_encoder.language = cls._get_or_infer_language_from_name( - language, pretrained_model_name_or_path + else: + self._init_model_transformers_style( + model_name_or_path=pretrained_model_name_or_path, + model_class=model_class, + model_kwargs=model_kwargs or {}, + use_auth_token=use_auth_token, + language=language, ) - return dpr_question_encoder - - def save(self, save_dir: Union[str, Path], state_dict: Optional[Dict[Any, Any]] = None): + def _init_model_haystack_style( + self, + haystack_lm_config: Path, + model_name_or_path: Union[str, Path], + model_class: Type[PreTrainedModel], + model_kwargs: Dict[str, Any], + use_auth_token: Optional[Union[str, bool]] = None, + ): """ - Save the model `state_dict` and its configuration file so that it can be loaded again. + Init a Haystack-style DPR model. - :param save_dir: The directory in which the model should be saved. - :param state_dict: A dictionary containing the whole state of the module including names of layers. - By default, the unchanged state dictionary of the module is used. + :param haystack_lm_config: path to the language model config file + :param model_name_or_path: name or path of the model to load + :param model_class: The HuggingFace model class name + :param model_kwargs: any kwarg to pass to the model at init + :param use_auth_token: useful if the model is from the HF Hub and private """ - model_to_save = self.model.module if hasattr(self.model, "module") else self.model # Only save the model itself + original_model_config = AutoConfig.from_pretrained(haystack_lm_config) + haystack_lm_model = Path(model_name_or_path) / "language_model.bin" - if self.model.config.model_type != "dpr" and model_to_save.base_model_prefix.startswith("question_"): - state_dict = model_to_save.state_dict() - if state_dict: - keys = state_dict.keys() - for key in list(keys): - new_key = key - if key.startswith("question_encoder.bert_model.model."): - new_key = key.split("_encoder.bert_model.model.", 1)[1] - elif key.startswith("question_encoder.bert_model."): - new_key = key.split("_encoder.bert_model.", 1)[1] - state_dict[new_key] = state_dict.pop(key) + original_model_type = original_model_config.model_type + if original_model_type and "dpr" in original_model_type.lower(): + dpr_config = transformers.DPRConfig.from_pretrained(haystack_lm_config) + self.model = model_class.from_pretrained(haystack_lm_model, config=dpr_config, **model_kwargs) - super(DPRQuestionEncoder, self).save(save_dir=save_dir, state_dict=state_dict) + else: + self.model = self._init_model_through_config( + model_config=original_model_config, model_class=model_class, model_kwargs=model_kwargs + ) + original_model_type = capitalize_model_type(original_model_type) + language_model_class = get_language_model_class(original_model_type) + if not language_model_class: + raise ValueError( + f"The type of model supplied ({model_name_or_path} , " + f"({original_model_type}) is not supported by Haystack. 
" + f"Supported model categories are: {', '.join(HUGGINGFACE_TO_HAYSTACK.keys())}" + ) + # Instantiate the class for this model + self.model.base_model.bert_model = language_model_class( + pretrained_model_name_or_path=model_name_or_path, + model_type=original_model_type, + use_auth_token=use_auth_token, + **model_kwargs, + ).model - def forward( # type: ignore + self.language = self.model.config.language + + def _init_model_transformers_style( self, - query_input_ids: torch.Tensor, - query_segment_ids: torch.Tensor, - query_attention_mask: torch.Tensor, - **kwargs, + model_name_or_path: Union[str, Path], + model_class: Type[PreTrainedModel], + model_kwargs: Dict[str, Any], + use_auth_token: Optional[Union[str, bool]] = None, + language: Optional[str] = None, ): """ - Perform the forward pass of the DPRQuestionEncoder model. + Init a Transformers-style DPR model. - :param query_input_ids: The IDs of each token in the input sequence. It's a tensor of shape [batch_size, max_seq_len]. - :param query_segment_ids: The ID of the segment. For example, in next sentence prediction, the tokens in the - first sentence are marked with 0 and the tokens in the second sentence are marked with 1. - It is a tensor of shape [batch_size, max_seq_len]. - :param query_attention_mask: A mask that assigns 1 to valid input tokens and 0 to padding tokens - of shape [batch_size, max_seq_len]. - :return: Embeddings for each token in the input sequence. + :param model_name_or_path: name or path of the model to load + :param model_class: The HuggingFace model class name + :param model_kwargs: any kwarg to pass to the model at init + :param use_auth_token: useful if the model is from the HF Hub and private + :param language: the model's language. If not given, it will be inferred. Defaults to english. """ - output_tuple = self.model( - input_ids=query_input_ids, - token_type_ids=query_segment_ids, - attention_mask=query_attention_mask, - return_dict=True, - ) - if self.model.question_encoder.config.output_hidden_states == True: - pooled_output, all_hidden_states = output_tuple.pooler_output, output_tuple.hidden_states - return pooled_output, all_hidden_states + original_model_config = AutoConfig.from_pretrained(model_name_or_path, use_auth_token=use_auth_token) + if "dpr" in original_model_config.model_type.lower(): + # "pretrained dpr model": load existing pretrained DPRQuestionEncoder model + self.model = model_class.from_pretrained( + str(model_name_or_path), use_auth_token=use_auth_token, **model_kwargs + ) else: - pooled_output = output_tuple.pooler_output - return pooled_output, None - - def enable_hidden_states_output(self): - self.model.question_encoder.config.output_hidden_states = True - - def disable_hidden_states_output(self): - self.model.question_encoder.config.output_hidden_states = False - - -class DPRContextEncoder(LanguageModel): - """ - A DPRContextEncoder model that wraps Hugging Face's implementation. - """ - - def __init__(self): - super(DPRContextEncoder, self).__init__() - self.model = None - self.name = "dpr_context_encoder" + # "from scratch": load weights from different architecture (e.g. bert) into DPRQuestionEncoder + # but keep config values from original architecture + # TODO test for architectures other than BERT, e.g. 
Electra + self.model = self._init_model_through_config( + model_config=original_model_config, model_class=model_class, model_kwargs=model_kwargs + ) + self.model.base_model.bert_model = AutoModel.from_pretrained( + str(model_name_or_path), use_auth_token=use_auth_token, **vars(original_model_config) + ) + self.language = language or _guess_language(str(model_name_or_path)) - @classmethod - @silence_transformers_logs - def load( - cls, - pretrained_model_name_or_path: Union[Path, str], - language: str = None, - use_auth_token: Union[str, bool] = None, - **kwargs, + def _init_model_through_config( + self, model_config: AutoConfig, model_class: Type[PreTrainedModel], model_kwargs: Optional[Dict[str, Any]] ): """ - Load a pretrained model by supplying one of the following: - - * The name of a remote model on s3 (for example, "facebook/dpr-ctx_encoder-single-nq-base"). - * A local path of a model trained using transformers (for example, "some_dir/huggingface_model"). - * A local path of a model trained using Haystack (for example, "some_dir/haystack_model"). - - :param pretrained_model_name_or_path: The path of the base pretrained language model whose weights are used to initialize DPRContextEncoder. + Init a DPR model using a config object. """ - dpr_context_encoder = cls() - if "haystack_lm_name" in kwargs: - dpr_context_encoder.name = kwargs["haystack_lm_name"] - else: - dpr_context_encoder.name = pretrained_model_name_or_path - # We need to differentiate between loading model using Haystack format and Pytorch-Transformers format - haystack_lm_config = Path(pretrained_model_name_or_path) / "language_model_config.json" - - if os.path.exists(haystack_lm_config): - # Haystack style - original_model_config = AutoConfig.from_pretrained(haystack_lm_config) - haystack_lm_model = Path(pretrained_model_name_or_path) / "language_model.bin" + if model_config.model_type.lower() != "bert": + logger.warning( + f"Using a model of type '{model_config.model_type}' which might be incompatible with DPR encoders. " + f"Only Bert-based encoders are supported. They need input_ids, token_type_ids, attention_mask as input tensors." + ) + config_dict = vars(model_config) + if model_kwargs: + config_dict.update(model_kwargs) + return model_class(config=transformers.DPRConfig(**config_dict)) - if original_model_config.model_type == "dpr": - dpr_config = transformers.DPRConfig.from_pretrained(haystack_lm_config) - dpr_context_encoder.model = transformers.DPRContextEncoder.from_pretrained( - haystack_lm_model, config=dpr_config, use_auth_token=use_auth_token, **kwargs - ) - else: - if original_model_config.model_type != "bert": - logger.warning( - f"Using a model of type '{original_model_config.model_type}' which might be incompatible with DPR encoders." - f"Bert based encoders are supported that need input_ids,token_type_ids,attention_mask as input tensors." 
- ) - original_config_dict = vars(original_model_config) - original_config_dict.update(kwargs) - dpr_context_encoder.model = transformers.DPRContextEncoder( - config=transformers.DPRConfig(**original_config_dict) - ) - language_model_class = cls.get_language_model_class(haystack_lm_config, **kwargs) - dpr_context_encoder.model.base_model.bert_model = ( - cls.subclasses[language_model_class] - .load(str(pretrained_model_name_or_path), use_auth_token=use_auth_token) - .model - ) - dpr_context_encoder.language = dpr_context_encoder.model.config.language + @property + def encoder(self): + if not self._encoder: + self._encoder = self.model.question_encoder if self.role == "question" else self.model.ctx_encoder + return self._encoder - else: - # Pytorch-transformer Style - original_model_config = AutoConfig.from_pretrained( - pretrained_model_name_or_path, use_auth_token=use_auth_token - ) - if original_model_config.model_type == "dpr": - # "pretrained dpr model": load existing pretrained DPRContextEncoder model - dpr_context_encoder.model = transformers.DPRContextEncoder.from_pretrained( - str(pretrained_model_name_or_path), use_auth_token=use_auth_token, **kwargs - ) - else: - # "from scratch": load weights from different architecture (e.g. bert) into DPRContextEncoder - # but keep config values from original architecture - # TODO test for architectures other than BERT, e.g. Electra - if original_model_config.model_type != "bert": - logger.warning( - f"Using a model of type '{original_model_config.model_type}' which might be incompatible with DPR encoders." - f"Bert based encoders are supported that need input_ids,token_type_ids,attention_mask as input tensors." - ) - original_config_dict = vars(original_model_config) - original_config_dict.update(kwargs) - dpr_context_encoder.model = transformers.DPRContextEncoder( - config=transformers.DPRConfig(**original_config_dict) - ) - dpr_context_encoder.model.base_model.bert_model = AutoModel.from_pretrained( - str(pretrained_model_name_or_path), use_auth_token=use_auth_token, **original_config_dict - ) - dpr_context_encoder.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path) + def save_config(self, save_dir: Union[Path, str]) -> None: + """ + Save the configuration of the language model in Haystack format. - return dpr_context_encoder + :param save_dir: the path to save the model at + """ + # For DPR models, transformers overwrites the model_type with the one set in DPRConfig + # Therefore, we copy the model_type from the model config to DPRConfig + setattr(transformers.DPRConfig, "model_type", self.model.config.model_type) + super().save_config(save_dir=save_dir) - def save(self, save_dir: Union[str, Path], state_dict: Optional[Dict[Any, Any]] = None): + def save(self, save_dir: Union[str, Path], state_dict: Optional[Dict[Any, Any]] = None) -> None: """ Save the model `state_dict` and its configuration file so that it can be loaded again. :param save_dir: The directory in which the model should be saved. - :param state_dict: A dictionary containing the whole state of the module including names of layers. By default, the unchanged state dictionary of the module is used. + :param state_dict: A dictionary containing the whole state of the module including names of layers. + By default, the unchanged state dictionary of the module is used. 
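A usage sketch of the merged DPR wrapper: the role (question vs. context) is derived from `model_type`, which in turn selects `transformers.DPRQuestionEncoder` or `transformers.DPRContextEncoder` under the hood. The checkpoint names are the standard public DPR models and serve as examples only.

```python
from haystack.modeling.model.language_model import DPREncoder

query_encoder = DPREncoder(
    pretrained_model_name_or_path="facebook/dpr-question_encoder-single-nq-base",
    model_type="DPRQuestionEncoder",
)
passage_encoder = DPREncoder(
    pretrained_model_name_or_path="facebook/dpr-ctx_encoder-single-nq-base",
    model_type="DPRContextEncoder",
)
```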
""" - model_to_save = ( - self.model.module if hasattr(self.model, "module") else self.model - ) # Only save the model it-self + model_to_save = self.model.module if hasattr(self.model, "module") else self.model # Only save the model itself + + if "dpr" not in self.model.config.model_type.lower(): + prefix = "question" if self.role == "question" else "ctx" - if self.model.config.model_type != "dpr" and model_to_save.base_model_prefix.startswith("ctx_"): state_dict = model_to_save.state_dict() if state_dict: - keys = state_dict.keys() - for key in list(keys): + for key in list(state_dict.keys()): # list() here performs a copy and allows editing the dict new_key = key - if key.startswith("ctx_encoder.bert_model.model."): + + if key.startswith(f"{prefix}_encoder.bert_model.model."): new_key = key.split("_encoder.bert_model.model.", 1)[1] - elif key.startswith("ctx_encoder.bert_model."): + + elif key.startswith(f"{prefix}_encoder.bert_model."): new_key = key.split("_encoder.bert_model.", 1)[1] + state_dict[new_key] = state_dict.pop(key) - super(DPRContextEncoder, self).save(save_dir=save_dir, state_dict=state_dict) + super().save(save_dir=save_dir, state_dict=state_dict) - def forward( # type: ignore + def forward( self, - passage_input_ids: torch.Tensor, - passage_segment_ids: torch.Tensor, - passage_attention_mask: torch.Tensor, - **kwargs, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + segment_ids: Optional[torch.Tensor], + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: bool = True, ): """ - Perform the forward pass of the DPRContextEncoder model. + Perform the forward pass of the DPR encoder model. - :param passage_input_ids: The IDs of each token in the input sequence. It's a tensor of shape [batch_size, number_of_hard_negative_passages, max_seq_len]. - :param passage_segment_ids: The ID of the segment. For example, in next sentence prediction, the tokens in the + :param input_ids: The IDs of each token in the input sequence. It's a tensor of shape [batch_size, number_of_hard_negative, max_seq_len]. + :param segment_ids: The ID of the segment. For example, in next sentence prediction, the tokens in the first sentence are marked with 0 and the tokens in the second sentence are marked with 1. It is a tensor of shape [batch_size, number_of_hard_negative_passages, max_seq_len]. - :param passage_attention_mask: A mask that assigns 1 to valid input tokens and 0 to padding tokens + :param attention_mask: A mask that assigns 1 to valid input tokens and 0 to padding tokens of shape [batch_size, number_of_hard_negative_passages, max_seq_len]. + :param output_hidden_states: whether to add the hidden states along with the pooled output + :param output_attentions: unused :return: Embeddings for each token in the input sequence. 
""" - max_seq_len = passage_input_ids.shape[-1] - passage_input_ids = passage_input_ids.view(-1, max_seq_len) - passage_segment_ids = passage_segment_ids.view(-1, max_seq_len) - passage_attention_mask = passage_attention_mask.view(-1, max_seq_len) - output_tuple = self.model( - input_ids=passage_input_ids, - token_type_ids=passage_segment_ids, - attention_mask=passage_attention_mask, - return_dict=True, + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.encoder.config.output_hidden_states ) - if self.model.ctx_encoder.config.output_hidden_states == True: - pooled_output, all_hidden_states = output_tuple.pooler_output, output_tuple.hidden_states - return pooled_output, all_hidden_states - else: - pooled_output = output_tuple.pooler_output - return pooled_output, None - - def enable_hidden_states_output(self): - self.model.ctx_encoder.config.output_hidden_states = True - - def disable_hidden_states_output(self): - self.model.ctx_encoder.config.output_hidden_states = False + model_output = self.model( + input_ids=input_ids, + token_type_ids=segment_ids, + attention_mask=attention_mask, + output_hidden_states=output_hidden_states, + output_attentions=False, + return_dict=return_dict, + ) -class BigBird(LanguageModel): + if output_hidden_states: + return model_output.pooler_output, model_output.hidden_states + return model_output.pooler_output, None + + +#: Match the name of the HuggingFace Model class to the corresponding Haystack wrapper +HUGGINGFACE_TO_HAYSTACK: Dict[str, Union[Type[HFLanguageModel], Type[DPREncoder]]] = { + "Auto": HFLanguageModel, + "Albert": HFLanguageModel, + "Bert": HFLanguageModel, + "BigBird": HFLanguageModel, + "Camembert": HFLanguageModel, + "Codebert": HFLanguageModel, + "DebertaV2": HFLanguageModelWithPooler, + "DistilBert": HFLanguageModelNoSegmentIds, + "DPRContextEncoder": DPREncoder, + "DPRQuestionEncoder": DPREncoder, + "Electra": HFLanguageModelWithPooler, + "GloVe": HFLanguageModel, + "MiniLM": HFLanguageModel, + "Roberta": HFLanguageModel, + "Umberto": HFLanguageModel, + "Word2Vec": HFLanguageModel, + "WordEmbedding_LM": HFLanguageModel, + "XLMRoberta": HFLanguageModel, + "XLNet": HFLanguageModelWithPooler, +} +#: HF Capitalization pairs +HUGGINGFACE_CAPITALIZE = { + "xlm-roberta": "XLMRoberta", + "deberta-v2": "DebertaV2", + **{k.lower(): k for k in HUGGINGFACE_TO_HAYSTACK.keys()}, +} + +#: Regex to match variants of the HF class name, to enhance our mode type guessing abilities. 
+NAME_HINTS: Dict[str, str] = { + "xlm.*roberta": "XLMRoberta", + "roberta.*xml": "XLMRoberta", + "codebert.*mlm": "Roberta", + "mlm.*codebert": "Roberta", + "[dpr]?.*question.*encoder": "DPRQuestionEncoder", + "[dpr]?.*query.*encoder": "DPRQuestionEncoder", + "[dpr]?.*passage.*encoder": "DPRContextEncoder", + "[dpr]?.*context.*encoder": "DPRContextEncoder", + "[dpr]?.*ctx.*encoder": "DPRContextEncoder", + "deberta-v2": "DebertaV2", +} + +#: Parameters for the pooler of models that don't have their own pooler +POOLER_PARAMETERS: Dict[str, Dict[str, Any]] = { + "DistilBert": {"summary_last_dropout": 0, "summary_type": "first", "summary_activation": "tanh"}, + "XLNet": {"summary_last_dropout": 0}, + "Electra": { + "summary_last_dropout": 0, + "summary_type": "first", + "summary_activation": "gelu", + "summary_use_proj": False, + }, + "DebertaV2": { + "summary_last_dropout": 0, + "summary_type": "first", + "summary_activation": "tanh", + "summary_use_proj": False, + }, +} + + +def capitalize_model_type(model_type: str) -> str: """ - A BERT model that wraps Hugging Face's implementation - (https://github.com/huggingface/transformers) to fit the LanguageModel class. - Paper: https://arxiv.org/abs/1810.04805 + Returns the proper capitalized version of the model type, which can be used to + retrieve the model class from transformers. + :param model_type: the model_type as found in the config file + :return: the capitalized version of the model type, or the original name if not found. """ + return HUGGINGFACE_CAPITALIZE.get(model_type.lower(), model_type) - def __init__(self): - super(BigBird, self).__init__() - self.model = None - self.name = "big_bird" - - @classmethod - def from_scratch(cls, vocab_size, name="big_bird", language="en"): - big_bird = cls() - big_bird.name = name - big_bird.language = language - config = BigBirdConfig(vocab_size=vocab_size) - big_bird.model = BigBirdModel(config) - return big_bird - - @classmethod - @silence_transformers_logs - def load(cls, pretrained_model_name_or_path: Union[Path, str], language: str = None, **kwargs): - """ - Load a pretrained model by supplying one of the following: - * The name of a remote model on s3 (for example, "bert-base-cased"). - * A local path of a model trained using transformers (for example, "some_dir/huggingface_model"). - * A local path of a model trained using Haystack (for example, "some_dir/haystack_model"). - - :param pretrained_model_name_or_path: The path of the saved pretrained model or its name.
- """ - big_bird = cls() - if "haystack_lm_name" in kwargs: - big_bird.name = kwargs["haystack_lm_name"] - else: - big_bird.name = pretrained_model_name_or_path - # We need to differentiate between loading model using Haystack format and Pytorch-Transformers format - haystack_lm_config = Path(pretrained_model_name_or_path) / "language_model_config.json" - if os.path.exists(haystack_lm_config): - # Haystack style - big_bird_config = BigBirdConfig.from_pretrained(haystack_lm_config) - haystack_lm_model = Path(pretrained_model_name_or_path) / "language_model.bin" - big_bird.model = BigBirdModel.from_pretrained(haystack_lm_model, config=big_bird_config, **kwargs) - big_bird.language = big_bird.model.config.language - else: - # Pytorch-transformer Style - big_bird.model = BigBirdModel.from_pretrained(str(pretrained_model_name_or_path), **kwargs) - big_bird.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path) - return big_bird - - def forward( - self, - input_ids: torch.Tensor, - segment_ids: torch.Tensor, - padding_mask: torch.Tensor, - output_hidden_states: Optional[bool] = None, - output_attentions: Optional[bool] = None, - **kwargs, - ): - """ - Perform the forward pass of the BigBird model. +def is_supported_model(model_type: Optional[str]): + """ + Returns whether the model type is supported by Haystack + :param model_type: the model_type as found in the config file + :return: whether the model type is supported by the Haystack + """ + return model_type and model_type.lower() in HUGGINGFACE_CAPITALIZE - :param input_ids: The IDs of each token in the input sequence. It's a tensor of shape [batch_size, max_seq_len]. - :param segment_ids: The ID of the segment. For example, in next sentence prediction, the tokens in the - first sentence are marked with 0 and the tokens in the second sentence are marked with 1. - It is a tensor of shape [batch_size, max_seq_len]. - :param padding_mask: A mask that assigns 1 to valid input tokens and 0 to padding tokens - of shape [batch_size, max_seq_len]. - :param output_hidden_states: When set to `True`, outputs hidden states in addition to the embeddings. - :param output_attentions: When set to `True`, outputs attentions in addition to the embeddings. - :return: Embeddings for each token in the input sequence. - """ - if output_hidden_states is None: - output_hidden_states = self.model.encoder.config.output_hidden_states - if output_attentions is None: - output_attentions = self.model.encoder.config.output_attentions - output_tuple = self.model( - input_ids, - token_type_ids=segment_ids, - attention_mask=padding_mask, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=False, - ) - return output_tuple +def get_language_model_class(model_type: str) -> Optional[Type[Union[HFLanguageModel, DPREncoder]]]: + """ + Returns the corresponding Haystack LanguageModel subclass. + :param model_type: the model_type , properly capitalized (see `capitalize_model_type()`) + :return: the wrapper class, or `None` if `model_type` was `None` or was not recognized. 
+ Lower case model_type values will return `None` as well + """ + return HUGGINGFACE_TO_HAYSTACK.get(model_type) + + +def get_language_model( + pretrained_model_name_or_path: Union[Path, str], + language: str = None, + n_added_tokens: int = 0, + use_auth_token: Optional[Union[str, bool]] = None, + revision: Optional[str] = None, + autoconfig_kwargs: Optional[Dict[str, Any]] = None, + model_kwargs: Optional[Dict[str, Any]] = None, +) -> LanguageModel: + """ + Load a pretrained language model by doing one of the following: - def enable_hidden_states_output(self): - self.model.encoder.config.output_hidden_states = True + 1. Specifying its name and downloading the model. + 2. Pointing to the directory the model is saved in. - def disable_hidden_states_output(self): - self.model.encoder.config.output_hidden_states = False + See all supported model variations at: https://huggingface.co/models. + The appropriate language model class is inferred automatically from model configuration. -class DebertaV2(LanguageModel): + :param pretrained_model_name_or_path: The path of the saved pretrained model or its name. + :param language: The language of the model (for example, English). + :param n_added_tokens: The number of added tokens to the model. + :param use_auth_token: Whether to use the Hugging Face auth token for private repos or not. + :param revision: The version of the model to use from the Hugging Face model hub. This can be a tag name, + a branch name, or a commit hash. + :param autoconfig_kwargs: Additional keyword arguments to pass to the autoconfig function. + :param model_kwargs: Additional keyword arguments to pass to the language model constructor. """ - This is a wrapper around the DebertaV2 model from Hugging Face's transformers library. - It is also compatible with DebertaV3 as DebertaV3 only changes the pretraining procedure. - NOTE: - - DebertaV2 does not output the `pooled_output`. An additional pooler is initialized. - """ + if not pretrained_model_name_or_path or not isinstance(pretrained_model_name_or_path, (str, Path)): + raise ValueError(f"{pretrained_model_name_or_path} is not a valid pretrained_model_name_or_path parameter") - def __init__(self): - super().__init__() - self.model = None - self.name = "deberta-v2" - self.pooler = None + config_file = Path(pretrained_model_name_or_path) / "language_model_config.json" - @classmethod - @silence_transformers_logs - def load(cls, pretrained_model_name_or_path: Union[Path, str], language: str = None, **kwargs): - """ - Load a pretrained model by supplying one of the following: + model_type = None + config_file_exists = os.path.exists(config_file) + if config_file_exists: + # it's a local directory in Haystack format + config = json.load(open(config_file)) + model_type = config["name"] - * A remote name from the Hugging Face's model hub (for example: microsoft/deberta-v3-base). - * A local path of a model trained using transformers (for example: some_dir/huggingface_model). - * A local path of a model trained using Haystack (for example: some_dir/haystack_model). + if not model_type: + model_type = _get_model_type( + pretrained_model_name_or_path, + use_auth_token=use_auth_token, + revision=revision, + autoconfig_kwargs=autoconfig_kwargs, + ) - :param pretrained_model_name_or_path: The path to the saved pretrained model or the name of the model.
- """ - debertav2 = cls() - if "haystack_lm_name" in kwargs: - debertav2.name = kwargs["haystack_lm_name"] - else: - debertav2.name = pretrained_model_name_or_path - # We need to differentiate between loading model using Haystack format and Transformers format - haystack_lm_config = Path(pretrained_model_name_or_path) / "language_model_config.json" - if os.path.exists(haystack_lm_config): - # Haystack style - config = DebertaV2Config.from_pretrained(haystack_lm_config) - haystack_lm_model = Path(pretrained_model_name_or_path) / "language_model.bin" - debertav2.model = DebertaV2Model.from_pretrained(haystack_lm_model, config=config, **kwargs) - debertav2.language = debertav2.model.config.language - else: - # Transformers Style - debertav2.model = DebertaV2Model.from_pretrained(str(pretrained_model_name_or_path), **kwargs) - debertav2.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path) - config = debertav2.model.config + if not model_type: + logger.error( + f"Model type not understood for '{pretrained_model_name_or_path}' " + f"({model_type if model_type else 'model_type not set'}). " + "Either supply the local path for a saved model, " + "or the name of a model that can be downloaded from the Model Hub. " + "Ensure that the model class name can be inferred from the directory name " + "when loading a Transformers model." + ) + logger.error(f"Using the AutoModel class for '{pretrained_model_name_or_path}'. This can cause crashes!") + model_type = "Auto" + + # Find the class corresponding to this model type + model_type = capitalize_model_type(model_type) + language_model_class = get_language_model_class(model_type) + if not language_model_class: + raise ValueError( + f"The type of model supplied ({model_type}) is not supported by Haystack or was not correctly identified. " + f"Supported model types are: {', '.join(HUGGINGFACE_TO_HAYSTACK.keys())}" + ) - # DebertaV2 does not provide a pooled_output by default. Therefore, we need to initialize an extra pooler. - # The pooler takes the first hidden representation & feeds it to a dense layer of (hidden_dim x hidden_dim). - # We don't want a dropout in the end of the pooler, since we do that already in the adaptive model before we - # feed everything to the prediction head. - config.summary_last_dropout = 0 - config.summary_type = "first" - config.summary_activation = "tanh" - config.summary_use_proj = False - debertav2.pooler = SequenceSummary(config) - debertav2.pooler.apply(debertav2.model._init_weights) - return debertav2 + logger.info(f" * LOADING MODEL: '{pretrained_model_name_or_path}' {'(' + model_type + ')' if model_type else ''}") + + # Instantiate the class for this model + language_model = language_model_class( + pretrained_model_name_or_path=pretrained_model_name_or_path, + model_type=model_type, + language=language, + n_added_tokens=n_added_tokens, + use_auth_token=use_auth_token, + model_kwargs=model_kwargs, + ) + logger.info( + f"Loaded '{pretrained_model_name_or_path}' ({model_type} model) " + f"from {'local file system' if config_file_exists else 'model hub'}." + ) + return language_model + + +def _get_model_type( + model_name_or_path: Union[str, Path], + use_auth_token: Optional[Union[str, bool]] = None, + revision: Optional[str] = None, + autoconfig_kwargs: Optional[Dict[str, Any]] = None, +) -> Optional[str]: + """ + Given a model name, try to use AutoConfig to understand which model type it is. + In case it's not successful, tries to infer the type from the name of the model. 
+ """ + model_name_or_path = str(model_name_or_path) + + model_type: Optional[str] = None + # Use AutoConfig to understand the model class + try: + config = AutoConfig.from_pretrained( + pretrained_model_name_or_path=model_name_or_path, + use_auth_token=use_auth_token, + revision=revision, + **(autoconfig_kwargs or {}), + ) + model_type = config.model_type + # if unsupported model, try to infer from config.architectures + if not is_supported_model(model_type) and config.architectures: + model_type = config.architectures[0] if is_supported_model(config.architectures[0]) else None - def forward( - self, - input_ids: torch.Tensor, - segment_ids: torch.Tensor, - padding_mask: torch.Tensor, - output_hidden_states: Optional[bool] = None, - output_attentions: Optional[bool] = None, - **kwargs, - ): - """ - Perform the forward pass of the DebertaV2 model. + except Exception as e: + logger.error(f"AutoConfig failed to load on '{model_name_or_path}': {str(e)}") - :param input_ids: The IDs of each token in the input sequence. Is a tensor of shape [batch_size, max_seq_len]. - :param padding_mask: A mask that assigns 1 to valid input tokens and 0 to padding tokens - of shape [batch_size, max_seq_len]. - :param output_hidden_states: When set to `True`, outputs hidden states in addition to the embeddings. - :param output_attentions: When set to `True`, outputs attentions in addition to the embeddings. - :return: Embeddings for each token in the input sequence. - """ - output_tuple = self.model(input_ids, token_type_ids=segment_ids, attention_mask=padding_mask, return_dict=False) + if not model_type: + logger.warning("Could not infer the model type from its config. Looking for clues in the model name.") - if output_hidden_states is None: - output_hidden_states = self.model.encoder.config.output_hidden_states - if output_attentions is None: - output_attentions = self.model.encoder.config.output_attentions + # Look for other patterns and variation that hints at the model type + for regex, model_name in NAME_HINTS.items(): + if re.match(f".*{regex}.*", model_name_or_path): + model_type = model_name + break - output_tuple = self.model( - input_ids, - attention_mask=padding_mask, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, + if model_type and model_type.lower() == "roberta" and "mlm" in model_name_or_path.lower(): + logger.error( + f"MLM part of codebert is currently not supported in Haystack: '{model_name_or_path}' may crash later." ) - # We need to manually aggregate that to get a pooled output (one vec per seq) - pooled_output = self.pooler(output_tuple[0]) - return (output_tuple[0], pooled_output) + output_tuple[1:] - def disable_hidden_states_output(self): - self.model.config.output_hidden_states = False + return model_type + + +def _guess_language(name: str) -> str: + """ + Looks for clues about the model language in the model name. + """ + languages = [lang for hint, lang in LANGUAGE_HINTS if hint.lower() in name.lower()] + if len(languages) > 0: + language = languages[0] + else: + language = "english" + logger.info(f"Auto-detected model language: {language}") + return language diff --git a/haystack/modeling/model/tokenization.py b/haystack/modeling/model/tokenization.py index 3c0ed9a961..9467d38132 100644 --- a/haystack/modeling/model/tokenization.py +++ b/haystack/modeling/model/tokenization.py @@ -12,308 +12,65 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -""" -Tokenization classes. -""" -from __future__ import absolute_import, division, print_function, unicode_literals -from typing import Dict, Any, Tuple, Optional, List, Union + +from typing import Dict, Any, Union, Tuple, Optional, List import re import logging import numpy as np -from transformers import ( - AutoTokenizer, - AlbertTokenizer, - AlbertTokenizerFast, - BertTokenizer, - BertTokenizerFast, - DistilBertTokenizer, - DistilBertTokenizerFast, - ElectraTokenizer, - ElectraTokenizerFast, - RobertaTokenizer, - RobertaTokenizerFast, - XLMRobertaTokenizer, - XLMRobertaTokenizerFast, - XLNetTokenizer, - XLNetTokenizerFast, - CamembertTokenizer, - CamembertTokenizerFast, - DPRContextEncoderTokenizer, - DPRContextEncoderTokenizerFast, - DPRQuestionEncoderTokenizer, - DPRQuestionEncoderTokenizerFast, - BigBirdTokenizer, - BigBirdTokenizerFast, - DebertaV2Tokenizer, - DebertaV2TokenizerFast, -) -from transformers import AutoConfig +from transformers import AutoTokenizer, PreTrainedTokenizer, RobertaTokenizer +from haystack.errors import ModelingError from haystack.modeling.data_handler.samples import SampleBasket logger = logging.getLogger(__name__) -# Special characters used by the different tokenizers to indicate start of word / whitespace +#: Special characters used by the different tokenizers to indicate start of word / whitespace SPECIAL_TOKENIZER_CHARS = r"^(##|Ġ|▁)" -# TODO analyse if tokenizers can be completely used through HF transformers -class Tokenizer: + +def get_tokenizer( + pretrained_model_name_or_path: str, + revision: str = None, + use_fast: bool = True, + use_auth_token: Optional[Union[str, bool]] = None, + **kwargs, +) -> PreTrainedTokenizer: """ - Simple Wrapper for Tokenizers from the transformers package. Enables loading of different Tokenizer classes with a uniform interface. + Enables loading of different Tokenizer classes with a uniform interface. + Right now it always returns an instance of `AutoTokenizer`. + + :param pretrained_model_name_or_path: The path of the saved pretrained model or its name (e.g. `bert-base-uncased`) + :param revision: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash. + :param use_fast: Indicate if Haystack should try to load the fast version of the tokenizer (True) or use the Python one (False). Defaults to True. + :param use_auth_token: The auth_token to use in `PretrainedTokenizer.from_pretrained()`, or False + :param kwargs: other kwargs to pass on to `PretrainedTokenizer.from_pretrained()` + :return: AutoTokenizer instance """ + model_name_or_path = str(pretrained_model_name_or_path) - @classmethod - def load( - cls, - pretrained_model_name_or_path, - revision=None, - tokenizer_class=None, - use_fast=True, - use_auth_token: Union[bool, str] = None, - **kwargs, - ): - """ - Enables loading of different Tokenizer classes with a uniform interface. Either infer the class from - model config or define it manually via `tokenizer_class`. - - :param pretrained_model_name_or_path: The path of the saved pretrained model or its name (e.g. `bert-base-uncased`) - :type pretrained_model_name_or_path: str - :param revision: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash. - :type revision: str - :param tokenizer_class: (Optional) Name of the tokenizer class to load (e.g. 
`BertTokenizer`) - :type tokenizer_class: str - :param use_fast: (Optional, False by default) Indicate if Haystack should try to load the fast version of the tokenizer (True) or - use the Python one (False). - Only DistilBERT, BERT and Electra fast tokenizers are supported. - :type use_fast: bool - :param kwargs: - :return: Tokenizer - """ - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - kwargs["revision"] = revision - - if tokenizer_class is None: - tokenizer_class = cls._infer_tokenizer_class(pretrained_model_name_or_path, use_auth_token=use_auth_token) - - logger.debug(f"Loading tokenizer of type '{tokenizer_class}'") - # return appropriate tokenizer object - ret = None - if "AutoTokenizer" in tokenizer_class: - ret = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, use_fast=use_fast, **kwargs) - elif "AlbertTokenizer" in tokenizer_class: - if use_fast: - ret = AlbertTokenizerFast.from_pretrained( - pretrained_model_name_or_path, keep_accents=True, use_auth_token=use_auth_token, **kwargs - ) - else: - ret = AlbertTokenizer.from_pretrained( - pretrained_model_name_or_path, keep_accents=True, use_auth_token=use_auth_token, **kwargs - ) - elif "XLMRobertaTokenizer" in tokenizer_class: - if use_fast: - ret = XLMRobertaTokenizerFast.from_pretrained( - pretrained_model_name_or_path, use_auth_token=use_auth_token, **kwargs - ) - else: - ret = XLMRobertaTokenizer.from_pretrained( - pretrained_model_name_or_path, use_auth_token=use_auth_token, **kwargs - ) - elif "RobertaTokenizer" in tokenizer_class: - if use_fast: - ret = RobertaTokenizerFast.from_pretrained( - pretrained_model_name_or_path, use_auth_token=use_auth_token, **kwargs - ) - else: - ret = RobertaTokenizer.from_pretrained( - pretrained_model_name_or_path, use_auth_token=use_auth_token, **kwargs - ) - elif "DistilBertTokenizer" in tokenizer_class: - if use_fast: - ret = DistilBertTokenizerFast.from_pretrained( - pretrained_model_name_or_path, use_auth_token=use_auth_token, **kwargs - ) - else: - ret = DistilBertTokenizer.from_pretrained( - pretrained_model_name_or_path, use_auth_token=use_auth_token, **kwargs - ) - elif "BertTokenizer" in tokenizer_class: - if use_fast: - ret = BertTokenizerFast.from_pretrained( - pretrained_model_name_or_path, use_auth_token=use_auth_token, **kwargs - ) - else: - ret = BertTokenizer.from_pretrained( - pretrained_model_name_or_path, use_auth_token=use_auth_token, **kwargs - ) - elif "XLNetTokenizer" in tokenizer_class: - if use_fast: - ret = XLNetTokenizerFast.from_pretrained( - pretrained_model_name_or_path, keep_accents=True, use_auth_token=use_auth_token, **kwargs - ) - else: - ret = XLNetTokenizer.from_pretrained( - pretrained_model_name_or_path, keep_accents=True, use_auth_token=use_auth_token, **kwargs - ) - elif "ElectraTokenizer" in tokenizer_class: - if use_fast: - ret = ElectraTokenizerFast.from_pretrained( - pretrained_model_name_or_path, use_auth_token=use_auth_token, **kwargs - ) - else: - ret = ElectraTokenizer.from_pretrained( - pretrained_model_name_or_path, use_auth_token=use_auth_token, **kwargs - ) - elif "CamembertTokenizer" in tokenizer_class: - if use_fast: - ret = CamembertTokenizerFast.from_pretrained( - pretrained_model_name_or_path, use_auth_token=use_auth_token, **kwargs - ) - else: - ret = CamembertTokenizer.from_pretrained( - pretrained_model_name_or_path, use_auth_token=use_auth_token, **kwargs - ) - elif "DPRQuestionEncoderTokenizer" in tokenizer_class: - if use_fast: - ret = DPRQuestionEncoderTokenizerFast.from_pretrained( - 
pretrained_model_name_or_path, use_auth_token=use_auth_token, **kwargs - ) - else: - ret = DPRQuestionEncoderTokenizer.from_pretrained( - pretrained_model_name_or_path, use_auth_token=use_auth_token, **kwargs - ) - elif "DPRContextEncoderTokenizer" in tokenizer_class: - if use_fast: - ret = DPRContextEncoderTokenizerFast.from_pretrained( - pretrained_model_name_or_path, use_auth_token=use_auth_token, **kwargs - ) - else: - ret = DPRContextEncoderTokenizer.from_pretrained( - pretrained_model_name_or_path, use_auth_token=use_auth_token, **kwargs - ) - elif "BigBirdTokenizer" in tokenizer_class: - if use_fast: - ret = BigBirdTokenizerFast.from_pretrained( - pretrained_model_name_or_path, use_auth_token=use_auth_token, **kwargs - ) - else: - ret = BigBirdTokenizer.from_pretrained( - pretrained_model_name_or_path, use_auth_token=use_auth_token, **kwargs - ) - elif "DebertaV2Tokenizer" in tokenizer_class: - if use_fast: - ret = DebertaV2TokenizerFast.from_pretrained( - pretrained_model_name_or_path, use_auth_token=use_auth_token, **kwargs - ) - else: - ret = DebertaV2Tokenizer.from_pretrained( - pretrained_model_name_or_path, use_auth_token=use_auth_token, **kwargs - ) - if ret is None: - raise Exception("Unable to load tokenizer") - return ret - - @staticmethod - def _infer_tokenizer_class(pretrained_model_name_or_path, use_auth_token: Union[bool, str] = None): - # Infer Tokenizer from model type in config - try: - config = AutoConfig.from_pretrained(pretrained_model_name_or_path, use_auth_token=use_auth_token) - except OSError: - # Haystack model (no 'config.json' file) - try: - config = AutoConfig.from_pretrained( - pretrained_model_name_or_path + "/language_model_config.json", use_auth_token=use_auth_token - ) - except Exception as e: - logger.warning("No config file found. Trying to infer Tokenizer type from model name") - tokenizer_class = Tokenizer._infer_tokenizer_class_from_string(pretrained_model_name_or_path) - return tokenizer_class - - model_type = config.model_type - - if model_type == "xlm-roberta": - tokenizer_class = "XLMRobertaTokenizer" - elif model_type == "roberta": - if "mlm" in pretrained_model_name_or_path.lower(): - raise NotImplementedError("MLM part of codebert is currently not supported in Haystack") - tokenizer_class = "RobertaTokenizer" - elif model_type == "camembert": - tokenizer_class = "CamembertTokenizer" - elif model_type == "albert": - tokenizer_class = "AlbertTokenizer" - elif model_type == "distilbert": - tokenizer_class = "DistilBertTokenizer" - elif model_type == "bert": - tokenizer_class = "BertTokenizer" - elif model_type == "xlnet": - tokenizer_class = "XLNetTokenizer" - elif model_type == "electra": - tokenizer_class = "ElectraTokenizer" - elif model_type == "dpr": - if config.architectures[0] == "DPRQuestionEncoder": - tokenizer_class = "DPRQuestionEncoderTokenizer" - elif config.architectures[0] == "DPRContextEncoder": - tokenizer_class = "DPRContextEncoderTokenizer" - elif config.architectures[0] == "DPRReader": - raise NotImplementedError("DPRReader models are currently not supported.") - elif model_type == "big_bird": - tokenizer_class = "BigBirdTokenizer" - elif model_type == "deberta-v2": - tokenizer_class = "DebertaV2Tokenizer" - else: - # Fall back to inferring type from model name - logger.warning( - "Could not infer Tokenizer type from config. Trying to infer Tokenizer type from model name." 
- ) - tokenizer_class = Tokenizer._infer_tokenizer_class_from_string(pretrained_model_name_or_path) - - return tokenizer_class - - @staticmethod - def _infer_tokenizer_class_from_string(pretrained_model_name_or_path): - # If inferring tokenizer class from config doesn't succeed, - # fall back to inferring tokenizer class from model name. - if "albert" in pretrained_model_name_or_path.lower(): - tokenizer_class = "AlbertTokenizer" - elif "bigbird" in pretrained_model_name_or_path.lower(): - tokenizer_class = "BigBirdTokenizer" - elif "xlm-roberta" in pretrained_model_name_or_path.lower(): - tokenizer_class = "XLMRobertaTokenizer" - elif "roberta" in pretrained_model_name_or_path.lower(): - tokenizer_class = "RobertaTokenizer" - elif "codebert" in pretrained_model_name_or_path.lower(): - if "mlm" in pretrained_model_name_or_path.lower(): - raise NotImplementedError("MLM part of codebert is currently not supported in Haystack") - tokenizer_class = "RobertaTokenizer" - elif "camembert" in pretrained_model_name_or_path.lower() or "umberto" in pretrained_model_name_or_path.lower(): - tokenizer_class = "CamembertTokenizer" - elif "distilbert" in pretrained_model_name_or_path.lower(): - tokenizer_class = "DistilBertTokenizer" - elif ( - "debertav2" in pretrained_model_name_or_path.lower() or "debertav3" in pretrained_model_name_or_path.lower() - ): - tokenizer_class = "DebertaV2Tokenizer" - elif "bert" in pretrained_model_name_or_path.lower(): - tokenizer_class = "BertTokenizer" - elif "xlnet" in pretrained_model_name_or_path.lower(): - tokenizer_class = "XLNetTokenizer" - elif "electra" in pretrained_model_name_or_path.lower(): - tokenizer_class = "ElectraTokenizer" - elif "minilm" in pretrained_model_name_or_path.lower(): - tokenizer_class = "BertTokenizer" - elif "dpr-question_encoder" in pretrained_model_name_or_path.lower(): - tokenizer_class = "DPRQuestionEncoderTokenizer" - elif "dpr-ctx_encoder" in pretrained_model_name_or_path.lower(): - tokenizer_class = "DPRContextEncoderTokenizer" - else: - tokenizer_class = "AutoTokenizer" + if "mlm" in model_name_or_path.lower(): + logging.error("MLM part of codebert is currently not supported in Haystack. Proceed at your own risk.") + + params = {} + if any(tokenizer_type in model_name_or_path for tokenizer_type in ["albert", "xlnet"]): + params["keep_accents"] = True - return tokenizer_class + return AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=model_name_or_path, + revision=revision, + use_fast=use_fast, + use_auth_token=use_auth_token, + **params, + **kwargs, + ) -def tokenize_batch_question_answering(pre_baskets, tokenizer, indices): +def tokenize_batch_question_answering( + pre_baskets: List[Dict[str, Any]], tokenizer: PreTrainedTokenizer, indices: List[Any] +) -> List[SampleBasket]: """ Tokenizes text data for question answering tasks. Tokenization means splitting words into subwords, depending on the tokenizer's vocabulary. 
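To illustrate the simplified loading path that replaces the old `Tokenizer.load()` branching, here is a minimal usage sketch (not part of the patch); the model name is only an example, and any checkpoint whose config declares a supported `model_type` should behave the same way:

```python
from haystack.modeling.model.language_model import get_language_model
from haystack.modeling.model.tokenization import get_tokenizer

# Example checkpoint, chosen only for illustration.
model_name = "bert-base-uncased"

# AutoConfig-based detection maps the checkpoint to the right Haystack wrapper
# (HFLanguageModel for plain BERT, DPREncoder for DPR checkpoints, and so on).
language_model = get_language_model(model_name)

# get_tokenizer() now always delegates to AutoTokenizer; `keep_accents=True` is
# added automatically when the model name hints at an ALBERT or XLNet checkpoint.
tokenizer = get_tokenizer(model_name, use_fast=True)

print(type(language_model).__name__, type(tokenizer).__name__)
```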
@@ -322,16 +79,20 @@ def tokenize_batch_question_answering(pre_baskets, tokenizer, indices): - Then we tokenize each question individually - We construct dicts with question and corresponding document text + tokens + offsets + ids - :param pre_baskets: input dicts with QA info #todo change to input objects + :param pre_baskets: input dicts with QA info #TODO change to input objects :param tokenizer: tokenizer to be used - :param indices: list, indices used during multiprocessing so that IDs assigned to our baskets are unique + :param indices: indices used during multiprocessing so that IDs assigned to our baskets are unique :return: baskets, list containing question and corresponding document information """ - assert len(indices) == len(pre_baskets) - assert tokenizer.is_fast, ( - "Processing QA data is only supported with fast tokenizers for now.\n" - "Please load Tokenizers with 'use_fast=True' option." - ) + if not len(indices) == len(pre_baskets): + raise ValueError("indices and pre_baskets must have the same length") + + if not tokenizer.is_fast: + raise ModelingError( + "Processing QA data is only supported with fast tokenizers for now. " + "Please load Tokenizers with 'use_fast=True' option." + ) + baskets = [] # # Tokenize texts in batch mode texts = [d["context"] for d in pre_baskets] @@ -385,80 +146,13 @@ def tokenize_batch_question_answering(pre_baskets, tokenizer, indices): def _get_start_of_word_QA(word_ids): - words = np.array(word_ids) - start_of_word_single = [1] + list(np.ediff1d(words)) - return start_of_word_single - - -def tokenize_with_metadata(text: str, tokenizer) -> Dict[str, Any]: - """ - Performing tokenization while storing some important metadata for each token: - - * offsets: (int) Character index where the token begins in the original text - * start_of_word: (bool) If the token is the start of a word. Particularly helpful for NER and QA tasks. - - We do this by first doing whitespace tokenization and then applying the model specific tokenizer to each "word". - - .. note:: We don't assume to preserve exact whitespaces in the tokens! - This means: tabs, new lines, multiple whitespace etc will all resolve to a single " ". - This doesn't make a difference for BERT + XLNet but it does for RoBERTa. - For RoBERTa it has the positive effect of a shorter sequence length, but some information about whitespace - type is lost which might be helpful for certain NLP tasks ( e.g tab for tables). - - :param text: Text to tokenize - :param tokenizer: Tokenizer (e.g.
from Tokenizer.load()) - :return: Dictionary with "tokens", "offsets" and "start_of_word" - """ - # normalize all other whitespace characters to " " - # Note: using text.split() directly would destroy the offset, - # since \n\n\n would be treated similarly as a single \n - text = re.sub(r"\s", " ", text) - # Fast Tokenizers return offsets, so we don't need to calculate them ourselves - if tokenizer.is_fast: - # tokenized = tokenizer(text, return_offsets_mapping=True, return_special_tokens_mask=True) - tokenized2 = tokenizer.encode_plus(text, return_offsets_mapping=True, return_special_tokens_mask=True) - - tokens2 = tokenized2["input_ids"] - offsets2 = np.array([x[0] for x in tokenized2["offset_mapping"]]) - # offsets2 = [x[0] for x in tokenized2["offset_mapping"]] - words = np.array(tokenized2.encodings[0].words) - - # TODO check for validity for all tokenizer and special token types - words[0] = -1 - words[-1] = words[-2] - words += 1 - start_of_word2 = [0] + list(np.ediff1d(words)) - ####### - - # start_of_word3 = [] - # last_word = -1 - # for word_id in tokenized2.encodings[0].words: - # if word_id is None or word_id == last_word: - # start_of_word3.append(0) - # else: - # start_of_word3.append(1) - # last_word = word_id - - tokenized_dict = {"tokens": tokens2, "offsets": offsets2, "start_of_word": start_of_word2} - else: - # split text into "words" (here: simple whitespace tokenizer). - words = text.split(" ") - word_offsets = [] - cumulated = 0 - for idx, word in enumerate(words): - word_offsets.append(cumulated) - cumulated += len(word) + 1 # 1 because we so far have whitespace tokenizer - - # split "words" into "subword tokens" - tokens, offsets, start_of_word = _words_to_tokens(words, word_offsets, tokenizer) - tokenized_dict = {"tokens": tokens, "offsets": offsets, "start_of_word": start_of_word} - return tokenized_dict + return [1] + list(np.ediff1d(np.array(word_ids))) def truncate_sequences( seq_a: list, seq_b: Optional[list], - tokenizer, + tokenizer: AutoTokenizer, max_seq_len: int, truncation_strategy: str = "longest_first", with_special_tokens: bool = True, @@ -467,21 +161,27 @@ def truncate_sequences( """ Reduces a single sequence or a pair of sequences to a maximum sequence length. The sequences can contain tokens or any other elements (offsets, masks ...). - If `with_special_tokens` is enabled, it'll remove some additional tokens to have exactly enough space for later adding special tokens (CLS, SEP etc.) + If `with_special_tokens` is enabled, it'll remove some additional tokens to have exactly + enough space for later adding special tokens (CLS, SEP etc.) Supported truncation strategies: - - longest_first: (default) Iteratively reduce the inputs sequence until the input is under max_length starting from the longest one at each token (when there is a pair of input sequences). Overflowing tokens only contains overflow from the first sequence. - - only_first: Only truncate the first sequence. raise an error if the first sequence is shorter or equal to than num_tokens_to_remove. + - longest_first: (default) Iteratively reduce the inputs sequence until the input is under + max_length starting from the longest one at each token (when there is a pair of input sequences). + Overflowing tokens only contains overflow from the first sequence. + - only_first: Only truncate the first sequence. raise an error if the first sequence is + shorter or equal to than num_tokens_to_remove. 
- only_second: Only truncate the second sequence - do_not_truncate: Does not truncate (raise an error if the input sequence is longer than max_length) :param seq_a: First sequence of tokens/offsets/... :param seq_b: Optional second sequence of tokens/offsets/... - :param tokenizer: Tokenizer (e.g. from Tokenizer.load()) + :param tokenizer: Tokenizer (e.g. from get_tokenizer)) :param max_seq_len: - :param truncation_strategy: how the sequence(s) should be truncated down. Default: "longest_first" (see above for other options). - :param with_special_tokens: If true, it'll remove some additional tokens to have exactly enough space for later adding special tokens (CLS, SEP etc.) + :param truncation_strategy: how the sequence(s) should be truncated down. + Default: "longest_first" (see above for other options). + :param with_special_tokens: If true, it'll remove some additional tokens to have exactly enough space + for later adding special tokens (CLS, SEP etc.) :param stride: optional stride of the window during truncation :return: truncated seq_a, truncated seq_b, overflowing tokens """ @@ -503,59 +203,119 @@ def truncate_sequences( return (seq_a, seq_b, overflowing_tokens) -def _words_to_tokens(words, word_offsets, tokenizer): +# +# FIXME this is a relic from FARM. If there's the occasion, remove it! +# +def tokenize_with_metadata(text: str, tokenizer: PreTrainedTokenizer) -> Dict[str, Any]: + """ + Performing tokenization while storing some important metadata for each token: + + * offsets: (int) Character index where the token begins in the original text + * start_of_word: (bool) If the token is the start of a word. Particularly helpful for NER and QA tasks. + + We do this by first doing whitespace tokenization and then applying the model specific tokenizer to each "word". + + .. note:: We don't assume to preserve exact whitespaces in the tokens! + This means: tabs, new lines, multiple whitespace etc will all resolve to a single " ". + This doesn't make a difference for BERT + XLNet but it does for RoBERTa. + For RoBERTa it has the positive effect of a shorter sequence length, but some information about whitespace + type is lost which might be helpful for certain NLP tasks ( e.g tab for tables). + + :param text: Text to tokenize + :param tokenizer: Tokenizer (e.g. 
from get_tokenizer)) + :return: Dictionary with "tokens", "offsets" and "start_of_word" + """ + # normalize all other whitespace characters to " " + # Note: using text.split() directly would destroy the offset, + # since \n\n\n would be treated similarly as a single \n + text = re.sub(r"\s", " ", text) + + words: Union[List[str], np.ndarray] = [] + word_offsets: Union[List[int], np.ndarray] = [] + start_of_word: List[Union[int, bool]] = [] + + # Fast Tokenizers return offsets, so we don't need to calculate them ourselves + if tokenizer.is_fast: + # tokenized = tokenizer(text, return_offsets_mapping=True, return_special_tokens_mask=True) + tokenized = tokenizer.encode_plus(text, return_offsets_mapping=True, return_special_tokens_mask=True) + + tokens = tokenized["input_ids"] + offsets = np.array([x[0] for x in tokenized["offset_mapping"]]) + # offsets2 = [x[0] for x in tokenized2["offset_mapping"]] + words = np.array(tokenized.encodings[0].words) + + # TODO check for validity for all tokenizer and special token types + words[0] = -1 + words[-1] = words[-2] + words += 1 + start_of_word = [0] + list(np.ediff1d(words)) + return {"tokens": tokens, "offsets": offsets, "start_of_word": start_of_word} + + # split text into "words" (here: simple whitespace tokenizer). + words = text.split(" ") + cumulated = 0 + for word in words: + word_offsets.append(cumulated) + cumulated += len(word) + 1 # 1 because we so far have whitespace tokenizer + + # split "words" into "subword tokens" + tokens, offsets, start_of_word = _words_to_tokens(words, word_offsets, tokenizer) # type: ignore + return {"tokens": tokens, "offsets": offsets, "start_of_word": start_of_word} + + +# Note: only used by tokenize_with_metadata() +def _words_to_tokens( + words: List[str], word_offsets: List[int], tokenizer: PreTrainedTokenizer +) -> Tuple[List[str], List[int], List[bool]]: """ Tokenize "words" into subword tokens while keeping track of offsets and if a token is the start of a word. :param words: list of words. - :type words: list :param word_offsets: Character indices where each word begins in the original text - :type word_offsets: list - :param tokenizer: Tokenizer (e.g. from Tokenizer.load()) - :return: tokens, offsets, start_of_word + :param tokenizer: Tokenizer (e.g. from get_tokenizer)) + :return: Tuple of (tokens, offsets, start_of_word) """ - tokens = [] - token_offsets = [] - start_of_word = [] - idx = 0 - for w, w_off in zip(words, word_offsets): - idx += 1 - if idx % 500000 == 0: - logger.info(idx) + tokens: List[str] = [] + token_offsets: List[int] = [] + start_of_word: List[bool] = [] + index = 0 + for index, (word, word_offset) in enumerate(zip(words, word_offsets)): + if index % 500000 == 0: + logger.info(index) # Get (subword) tokens of single word. # empty / pure whitespace - if len(w) == 0: + if len(word) == 0: continue # For the first word of a text: we just call the regular tokenize function. # For later words: we need to call it with add_prefix_space=True to get the same results with roberta / gpt2 tokenizer # see discussion here. 
https://github.com/huggingface/transformers/issues/1196 if len(tokens) == 0: - tokens_word = tokenizer.tokenize(w) + tokens_word = tokenizer.tokenize(word) else: if type(tokenizer) == RobertaTokenizer: - tokens_word = tokenizer.tokenize(w, add_prefix_space=True) + tokens_word = tokenizer.tokenize(word, add_prefix_space=True) else: - tokens_word = tokenizer.tokenize(w) + tokens_word = tokenizer.tokenize(word) # Sometimes the tokenizer returns no tokens if len(tokens_word) == 0: continue tokens += tokens_word # get global offset for each token in word + save marker for first tokens of a word - first_tok = True - for tok in tokens_word: - token_offsets.append(w_off) + first_token = True + for token in tokens_word: + token_offsets.append(word_offset) # Depending on the tokenizer type special chars are added to distinguish tokens with preceeding # whitespace (=> "start of a word"). We need to get rid of these to calculate the original length of the token - orig_tok = re.sub(SPECIAL_TOKENIZER_CHARS, "", tok) + original_token = re.sub(SPECIAL_TOKENIZER_CHARS, "", token) # Don't use length of unk token for offset calculation - if orig_tok == tokenizer.special_tokens_map["unk_token"]: - w_off += 1 + if original_token == tokenizer.special_tokens_map["unk_token"]: + word_offset += 1 else: - w_off += len(orig_tok) - if first_tok: + word_offset += len(original_token) + if first_token: start_of_word.append(True) - first_tok = False + first_token = False else: start_of_word.append(False) diff --git a/haystack/modeling/model/triadaptive_model.py b/haystack/modeling/model/triadaptive_model.py index 9d3e8cfe63..9a76dab0d3 100644 --- a/haystack/modeling/model/triadaptive_model.py +++ b/haystack/modeling/model/triadaptive_model.py @@ -7,7 +7,7 @@ from torch import nn from haystack.modeling.data_handler.processor import Processor -from haystack.modeling.model.language_model import LanguageModel +from haystack.modeling.model.language_model import get_language_model, LanguageModel from haystack.modeling.model.prediction_head import PredictionHead from haystack.utils.experiment_tracking import Tracker as tracker @@ -87,11 +87,11 @@ def __init__( super(TriAdaptiveModel, self).__init__() self.device = device self.language_model1 = language_model1.to(device) - self.lm1_output_dims = language_model1.get_output_dims() + self.lm1_output_dims = language_model1.output_dims self.language_model2 = language_model2.to(device) - self.lm2_output_dims = language_model2.get_output_dims() + self.lm2_output_dims = language_model2.output_dims self.language_model3 = language_model3.to(device) - self.lm3_output_dims = language_model3.get_output_dims() + self.lm3_output_dims = language_model3.output_dims self.dropout1 = nn.Dropout(embeds_dropout_prob) self.dropout2 = nn.Dropout(embeds_dropout_prob) self.dropout3 = nn.Dropout(embeds_dropout_prob) @@ -165,17 +165,17 @@ def load( """ # Language Model if lm1_name: - language_model1 = LanguageModel.load(os.path.join(load_dir, lm1_name)) + language_model1 = get_language_model(os.path.join(load_dir, lm1_name)) else: - language_model1 = LanguageModel.load(load_dir) + language_model1 = get_language_model(load_dir) if lm2_name: - language_model2 = LanguageModel.load(os.path.join(load_dir, lm2_name)) + language_model2 = get_language_model(os.path.join(load_dir, lm2_name)) else: - language_model2 = LanguageModel.load(load_dir) + language_model2 = get_language_model(load_dir) if lm3_name: - language_model3 = LanguageModel.load(os.path.join(load_dir, lm3_name)) + language_model3 = 
get_language_model(os.path.join(load_dir, lm3_name)) else: - language_model3 = LanguageModel.load(load_dir) + language_model3 = get_language_model(load_dir) # Prediction heads ph_config_files = cls._get_prediction_head_files(load_dir) @@ -294,19 +294,30 @@ def forward_lm(self, **kwargs): pooled_output = [None, None] # Forward pass for the queries if "query_input_ids" in kwargs.keys(): - pooled_output1, hidden_states1 = self.language_model1(**kwargs) + pooled_output1, _ = self.language_model1( + input_ids=kwargs.get("query_input_ids"), + segment_ids=kwargs.get("query_segment_ids"), + attention_mask=kwargs.get("query_attention_mask"), + output_hidden_states=False, + output_attentions=False, + ) pooled_output[0] = pooled_output1 + # Forward pass for text passages and tables if "passage_input_ids" in kwargs.keys(): table_mask = torch.flatten(kwargs["is_table"]) == True + # Current batch consists of only tables if all(table_mask): - pooled_output2, hidden_states2 = self.language_model3( + pooled_output2, _ = self.language_model3( passage_input_ids=kwargs["passage_input_ids"], passage_segment_ids=kwargs["table_segment_ids"], passage_attention_mask=kwargs["passage_attention_mask"], + output_hidden_states=False, + output_attentions=False, ) pooled_output[1] = pooled_output2 + # Current batch consists of tables and texts elif any(table_mask): @@ -320,17 +331,31 @@ def forward_lm(self, **kwargs): table_input_ids = passage_input_ids[table_mask] table_segment_ids = table_segment_ids[table_mask] table_attention_mask = passage_attention_mask[table_mask] - pooled_output_tables, _ = self.language_model3(table_input_ids, table_segment_ids, table_attention_mask) + + pooled_output_tables, _ = self.language_model3( + input_ids=table_input_ids, + segment_ids=table_segment_ids, + attention_mask=table_attention_mask, + output_hidden_states=False, + output_attentions=False, + ) text_input_ids = passage_input_ids[~table_mask] text_segment_ids = passage_segment_ids[~table_mask] text_attention_mask = passage_attention_mask[~table_mask] - pooled_output_text, _ = self.language_model2(text_input_ids, text_segment_ids, text_attention_mask) + + pooled_output_text, _ = self.language_model2( + input_ids=text_input_ids, + segment_ids=text_segment_ids, + attention_mask=text_attention_mask, + output_hidden_states=False, + output_attentions=False, + ) last_table_idx = 0 last_text_idx = 0 combined_outputs = [] - for idx, mask in enumerate(table_mask): + for mask in table_mask: if mask: combined_outputs.append(pooled_output_tables[last_table_idx]) last_table_idx += 1 @@ -345,9 +370,22 @@ def forward_lm(self, **kwargs): ), "Passage embedding model and table embedding model use different embedding sizes" pooled_output_combined = combined_outputs.view(-1, embedding_size) pooled_output[1] = pooled_output_combined + # Current batch consists of only texts else: - pooled_output2, hidden_states2 = self.language_model2(**kwargs) + # Make input two-dimensional + max_seq_len = kwargs["passage_input_ids"].shape[-1] + input_ids = kwargs["passage_input_ids"].view(-1, max_seq_len) + attention_mask = kwargs["passage_attention_mask"].view(-1, max_seq_len) + segment_ids = kwargs["passage_segment_ids"].view(-1, max_seq_len) + + pooled_output2, _ = self.language_model2( + input_ids=input_ids, + attention_mask=attention_mask, + segment_ids=segment_ids, + output_hidden_states=False, + output_attentions=False, + ) pooled_output[1] = pooled_output2 return tuple(pooled_output) @@ -382,7 +420,7 @@ def verify_vocab_size(self, vocab_size1: int, 
vocab_size2: int, vocab_size3: int msg = ( f"Vocab size of tokenizer {vocab_size1} doesn't match with model {model1_vocab_len}. " "If you added a custom vocabulary to the tokenizer, " - "make sure to supply 'n_added_tokens' to LanguageModel.load() and BertStyleLM.load()" + "make sure to supply 'n_added_tokens' to get_language_model() and BertStyleLM.load()" ) assert vocab_size1 == model1_vocab_len, msg @@ -391,7 +429,7 @@ def verify_vocab_size(self, vocab_size1: int, vocab_size2: int, vocab_size3: int msg = ( f"Vocab size of tokenizer {vocab_size1} doesn't match with model {model2_vocab_len}. " "If you added a custom vocabulary to the tokenizer, " - "make sure to supply 'n_added_tokens' to LanguageModel.load() and BertStyleLM.load()" + "make sure to supply 'n_added_tokens' to get_language_model() and BertStyleLM.load()" ) assert vocab_size2 == model2_vocab_len, msg @@ -400,7 +438,7 @@ def verify_vocab_size(self, vocab_size1: int, vocab_size2: int, vocab_size3: int msg = ( f"Vocab size of tokenizer {vocab_size3} doesn't match with model {model3_vocab_len}. " "If you added a custom vocabulary to the tokenizer, " - "make sure to supply 'n_added_tokens' to LanguageModel.load() and BertStyleLM.load()" + "make sure to supply 'n_added_tokens' to get_language_model() and BertStyleLM.load()" ) assert vocab_size3 == model1_vocab_len, msg diff --git a/haystack/modeling/training/base.py b/haystack/modeling/training/base.py index 67a126c2fd..c8ab06ce79 100644 --- a/haystack/modeling/training/base.py +++ b/haystack/modeling/training/base.py @@ -18,7 +18,6 @@ from haystack.modeling.evaluation.eval import Evaluator from haystack.modeling.model.adaptive_model import AdaptiveModel from haystack.modeling.model.optimization import get_scheduler -from haystack.modeling.model.language_model import DebertaV2 from haystack.modeling.utils import GracefulKiller from haystack.utils.experiment_tracking import Tracker as tracker @@ -251,8 +250,8 @@ def train(self): vocab_size1=len(self.data_silo.processor.query_tokenizer), vocab_size2=len(self.data_silo.processor.passage_tokenizer), ) - elif not isinstance( - self.model.language_model, DebertaV2 + elif ( + self.model.language_model.name != "debertav2" ): # DebertaV2 has mismatched vocab size on purpose (see https://github.com/huggingface/transformers/issues/12428) self.model.verify_vocab_size(vocab_size=len(self.data_silo.processor.tokenizer)) self.model.train() @@ -767,7 +766,15 @@ def compute_loss(self, batch: dict, step: int) -> torch.Tensor: keys = list(batch.keys()) keys = [key for key in keys if key.startswith("teacher_output")] teacher_logits = [batch.pop(key) for key in keys] - logits = self.model.forward(**batch) + + logits = self.model.forward( + input_ids=batch.get("input_ids"), + segment_ids=batch.get("segment_ids"), + padding_mask=batch.get("padding_mask"), + output_hidden_states=batch.get("output_hidden_states"), + output_attentions=batch.get("output_attentions"), + ) + student_loss = self.model.logits_to_loss(logits=logits, global_step=self.global_step, **batch) distillation_loss = self.distillation_loss_fn( student_logits=logits[0] / self.temperature, teacher_logits=teacher_logits[0] / self.temperature @@ -899,7 +906,16 @@ def __init__( self.loss = DataParallel(self.loss).to(device) def compute_loss(self, batch: dict, step: int) -> torch.Tensor: - return self.backward_propagate(torch.sum(self.loss(batch)), step) + return self.backward_propagate( + torch.sum( + self.loss( + input_ids=batch.get("input_ids"), + segment_ids=batch.get("segment_ids"), + 
padding_mask=batch.get("padding_mask"), + ) + ), + step, + ) class DistillationLoss(Module): @@ -945,14 +961,23 @@ def __init__(self, model: Union[DataParallel, AdaptiveModel], teacher_model: Mod else: self.dim_mappings.append(None) - def forward(self, batch): + def forward(self, input_ids: torch.Tensor, segment_ids: torch.Tensor, padding_mask: torch.Tensor): with torch.no_grad(): _, teacher_hidden_states, teacher_attentions = self.teacher_model.forward( - **batch, output_attentions=True, output_hidden_states=True + input_ids=input_ids, + segment_ids=segment_ids, + padding_mask=padding_mask, + output_attentions=True, + output_hidden_states=True, ) - - _, hidden_states, attentions = self.model.forward(**batch, output_attentions=True, output_hidden_states=True) - loss = torch.tensor(0.0, device=batch["input_ids"].device) + _, hidden_states, attentions = self.model.forward( + input_ids=input_ids, + segment_ids=segment_ids, + padding_mask=padding_mask, + output_attentions=True, + output_hidden_states=True, + ) + loss = torch.tensor(0.0, device=input_ids.device) # calculating attention loss for student_attention, teacher_attention, dim_mapping in zip( diff --git a/haystack/modeling/visual.py b/haystack/modeling/visual.py index e45f4d7786..d2084bdc5e 100644 --- a/haystack/modeling/visual.py +++ b/haystack/modeling/visual.py @@ -91,7 +91,7 @@ """ WORKER_M = r""" 0 -/|\ +/|\ /'\ """ WORKER_F = r""" 0 diff --git a/haystack/nodes/retriever/dense.py b/haystack/nodes/retriever/dense.py index b608a39301..f7c7312e60 100644 --- a/haystack/nodes/retriever/dense.py +++ b/haystack/nodes/retriever/dense.py @@ -13,15 +13,20 @@ from torch.utils.data.sampler import SequentialSampler import pandas as pd from huggingface_hub import hf_hub_download -from transformers import AutoConfig +from transformers import ( + AutoConfig, + DPRContextEncoderTokenizerFast, + DPRQuestionEncoderTokenizerFast, + DPRContextEncoderTokenizer, + DPRQuestionEncoderTokenizer, +) from haystack.errors import HaystackError from haystack.schema import Document from haystack.document_stores import BaseDocumentStore from haystack.nodes.retriever.base import BaseRetriever from haystack.nodes.retriever._embedding_encoder import _EMBEDDING_ENCODERS -from haystack.modeling.model.tokenization import Tokenizer -from haystack.modeling.model.language_model import LanguageModel +from haystack.modeling.model.language_model import get_language_model from haystack.modeling.model.biadaptive_model import BiAdaptiveModel from haystack.modeling.model.triadaptive_model import TriAdaptiveModel from haystack.modeling.model.prediction_head import TextSimilarityHead @@ -57,7 +62,6 @@ def __init__( batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, - infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, progress_bar: bool = True, @@ -102,8 +106,6 @@ def __init__( before writing them to the DocumentStore like this: {"text": "my text", "meta": {"name": "my title"}}. :param use_fast_tokenizers: Whether to use fast Rust tokenizers - :param infer_tokenizer_classes: Whether to infer tokenizer class from the model config / name. - If `False`, the class always loads `DPRQuestionEncoderTokenizer` and `DPRContextEncoderTokenizer`. :param similarity_function: Which function to apply for calculating the similarity of query and passage embeddings during training. Options: `dot_product` (Default) or `cosine` :param global_loss_buffer_size: Buffer size for all_gather() in DDP. 
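Because `infer_tokenizer_classes` has been dropped from the retriever signature, constructing a retriever no longer involves any tokenizer-class configuration. The following is a hedged usage sketch (not part of the patch), assuming the stock public DPR checkpoints and an in-memory store used purely for illustration:

```python
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import DensePassageRetriever

document_store = InMemoryDocumentStore()

# `infer_tokenizer_classes` no longer exists: the DPR tokenizer classes are now
# chosen unconditionally inside __init__, so there is nothing left to configure here.
retriever = DensePassageRetriever(
    document_store=document_store,
    query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
    passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
    use_gpu=False,
    embed_title=True,
)
```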
@@ -151,40 +153,26 @@ def __init__( "This can be set when initializing the DocumentStore" ) - self.infer_tokenizer_classes = infer_tokenizer_classes - tokenizers_default_classes = {"query": "DPRQuestionEncoderTokenizer", "passage": "DPRContextEncoderTokenizer"} - if self.infer_tokenizer_classes: - tokenizers_default_classes["query"] = None # type: ignore - tokenizers_default_classes["passage"] = None # type: ignore - # Init & Load Encoders - self.query_tokenizer = Tokenizer.load( + self.query_tokenizer = DPRQuestionEncoderTokenizerFast.from_pretrained( pretrained_model_name_or_path=query_embedding_model, revision=model_version, do_lower_case=True, use_fast=use_fast_tokenizers, - tokenizer_class=tokenizers_default_classes["query"], use_auth_token=use_auth_token, ) - self.query_encoder = LanguageModel.load( - pretrained_model_name_or_path=query_embedding_model, - revision=model_version, - language_model_class="DPRQuestionEncoder", - use_auth_token=use_auth_token, + self.query_encoder = get_language_model( + pretrained_model_name_or_path=query_embedding_model, revision=model_version, use_auth_token=use_auth_token ) - self.passage_tokenizer = Tokenizer.load( + self.passage_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained( pretrained_model_name_or_path=passage_embedding_model, revision=model_version, do_lower_case=True, use_fast=use_fast_tokenizers, - tokenizer_class=tokenizers_default_classes["passage"], use_auth_token=use_auth_token, ) - self.passage_encoder = LanguageModel.load( - pretrained_model_name_or_path=passage_embedding_model, - revision=model_version, - language_model_class="DPRContextEncoder", - use_auth_token=use_auth_token, + self.passage_encoder = get_language_model( + pretrained_model_name_or_path=passage_embedding_model, revision=model_version, use_auth_token=use_auth_token ) self.processor = TextSimilarityProcessor( @@ -493,12 +481,19 @@ def _get_predictions(self, dicts): leave=False, disable=disable_tqdm, ) as progress_bar: - for batch in data_loader: - batch = {key: batch[key].to(self.devices[0]) for key in batch} + for raw_batch in data_loader: + batch = {key: raw_batch[key].to(self.devices[0]) for key in raw_batch} # get logits with torch.no_grad(): - query_embeddings, passage_embeddings = self.model.forward(**batch)[0] + query_embeddings, passage_embeddings = self.model.forward( + query_input_ids=batch.get("query_input_ids", None), + query_segment_ids=batch.get("query_segment_ids", None), + query_attention_mask=batch.get("query_attention_mask", None), + passage_input_ids=batch.get("passage_input_ids", None), + passage_segment_ids=batch.get("passage_segment_ids", None), + passage_attention_mask=batch.get("passage_attention_mask", None), + )[0] if query_embeddings is not None: all_embeddings["query"].append(query_embeddings.cpu().numpy()) if passage_embeddings is not None: @@ -550,7 +545,6 @@ def embed_documents(self, docs: List[Document]) -> List[np.ndarray]: for d in docs ] embeddings = self._get_predictions(passages)["passages"] - return embeddings def train( @@ -726,7 +720,6 @@ def load( similarity_function: str = "dot_product", query_encoder_dir: str = "query_encoder", passage_encoder_dir: str = "passage_encoder", - infer_tokenizer_classes: bool = False, ): """ Load DensePassageRetriever from the specified directory. 
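For reference, the hunk above boils down to roughly the following stand-alone sketch (illustrative only, not code from the patch) of how the query and passage encoders are now built: the tokenizers are fixed fast DPR classes from transformers, and `get_language_model()` resolves both checkpoints to the `DPREncoder` wrapper via `AutoConfig`, falling back to the config's `architectures` field and the name hints:

```python
from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast
from haystack.modeling.model.language_model import get_language_model

query_model = "facebook/dpr-question_encoder-single-nq-base"
passage_model = "facebook/dpr-ctx_encoder-single-nq-base"

# Fixed fast tokenizer classes replace the old Tokenizer.load() indirection.
query_tokenizer = DPRQuestionEncoderTokenizerFast.from_pretrained(query_model, do_lower_case=True)
passage_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(passage_model, do_lower_case=True)

# Both encoders end up wrapped in DPREncoder; the model type is read from the
# checkpoint config instead of being passed in as `language_model_class`.
query_encoder = get_language_model(query_model)
passage_encoder = get_language_model(passage_model)
```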
@@ -743,7 +736,6 @@ def load( embed_title=embed_title, use_fast_tokenizers=use_fast_tokenizers, similarity_function=similarity_function, - infer_tokenizer_classes=infer_tokenizer_classes, ) logger.info(f"DPR model loaded from {load_dir}") @@ -774,13 +766,13 @@ def __init__( batch_size: int = 16, embed_meta_fields: List[str] = ["name", "section_title", "caption"], use_fast_tokenizers: bool = True, - infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, progress_bar: bool = True, devices: Optional[List[Union[str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None, scale_score: bool = True, + use_fast: bool = True, ): """ Init the Retriever incl. the two encoder models from a local or remote model checkpoint. @@ -805,8 +797,6 @@ def __init__( performance if your titles contain meaningful information for retrieval (topic, entities etc.). :param use_fast_tokenizers: Whether to use fast Rust tokenizers - :param infer_tokenizer_classes: Whether to infer tokenizer class from the model config / name. - If `False`, the class always loads `DPRQuestionEncoderTokenizer` and `DPRContextEncoderTokenizer`. :param similarity_function: Which function to apply for calculating the similarity of query and passage embeddings during training. Options: `dot_product` (Default) or `cosine` :param global_loss_buffer_size: Buffer size for all_gather() in DDP. @@ -824,6 +814,7 @@ def __init__( :param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]). If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant. Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. + :param use_fast: Whether to use the fast version of DPR tokenizers or fallback to the standard version. Defaults to True. 
""" super().__init__() @@ -855,59 +846,40 @@ def __init__( "This can be set when initializing the DocumentStore" ) - self.infer_tokenizer_classes = infer_tokenizer_classes - tokenizers_default_classes = { - "query": "DPRQuestionEncoderTokenizer", - "passage": "DPRContextEncoderTokenizer", - "table": "DPRContextEncoderTokenizer", - } - if self.infer_tokenizer_classes: - tokenizers_default_classes["query"] = None # type: ignore - tokenizers_default_classes["passage"] = None # type: ignore - tokenizers_default_classes["table"] = None # type: ignore + query_tokenizer_class = DPRQuestionEncoderTokenizerFast if use_fast else DPRQuestionEncoderTokenizer + passage_tokenizer_class = DPRContextEncoderTokenizerFast if use_fast else DPRContextEncoderTokenizer + table_tokenizer_class = DPRContextEncoderTokenizerFast if use_fast else DPRContextEncoderTokenizer # Init & Load Encoders - self.query_tokenizer = Tokenizer.load( - pretrained_model_name_or_path=query_embedding_model, + self.query_tokenizer = query_tokenizer_class.from_pretrained( + query_embedding_model, revision=model_version, do_lower_case=True, use_fast=use_fast_tokenizers, - tokenizer_class=tokenizers_default_classes["query"], use_auth_token=use_auth_token, ) - self.query_encoder = LanguageModel.load( - pretrained_model_name_or_path=query_embedding_model, - revision=model_version, - language_model_class="DPRQuestionEncoder", - use_auth_token=use_auth_token, + self.query_encoder = get_language_model( + pretrained_model_name_or_path=query_embedding_model, revision=model_version, use_auth_token=use_auth_token ) - self.passage_tokenizer = Tokenizer.load( - pretrained_model_name_or_path=passage_embedding_model, + self.passage_tokenizer = passage_tokenizer_class.from_pretrained( + passage_embedding_model, revision=model_version, do_lower_case=True, use_fast=use_fast_tokenizers, - tokenizer_class=tokenizers_default_classes["passage"], use_auth_token=use_auth_token, ) - self.passage_encoder = LanguageModel.load( - pretrained_model_name_or_path=passage_embedding_model, - revision=model_version, - language_model_class="DPRContextEncoder", - use_auth_token=use_auth_token, + self.passage_encoder = get_language_model( + pretrained_model_name_or_path=passage_embedding_model, revision=model_version, use_auth_token=use_auth_token ) - self.table_tokenizer = Tokenizer.load( - pretrained_model_name_or_path=table_embedding_model, + self.table_tokenizer = table_tokenizer_class.from_pretrained( + table_embedding_model, revision=model_version, do_lower_case=True, use_fast=use_fast_tokenizers, - tokenizer_class=tokenizers_default_classes["table"], use_auth_token=use_auth_token, ) - self.table_encoder = LanguageModel.load( - pretrained_model_name_or_path=table_embedding_model, - revision=model_version, - language_model_class="DPRContextEncoder", - use_auth_token=use_auth_token, + self.table_encoder = get_language_model( + pretrained_model_name_or_path=table_embedding_model, revision=model_version, use_auth_token=use_auth_token ) self.processor = TableTextSimilarityProcessor( @@ -1419,7 +1391,6 @@ def load( query_encoder_dir: str = "query_encoder", passage_encoder_dir: str = "passage_encoder", table_encoder_dir: str = "table_encoder", - infer_tokenizer_classes: bool = False, ): """ Load TableTextRetriever from the specified directory. 
@@ -1439,7 +1410,6 @@ def load( embed_meta_fields=embed_meta_fields, use_fast_tokenizers=use_fast_tokenizers, similarity_function=similarity_function, - infer_tokenizer_classes=infer_tokenizer_classes, ) logger.info(f"TableTextRetriever model loaded from {load_dir}") diff --git a/test/modeling/test_modeling_dpr.py b/test/modeling/test_dpr.py similarity index 86% rename from test/modeling/test_modeling_dpr.py rename to test/modeling/test_dpr.py index c6a30c0212..af1cf0e91a 100644 --- a/test/modeling/test_modeling_dpr.py +++ b/test/modeling/test_dpr.py @@ -1,3 +1,6 @@ +from typing import Tuple + +import os import logging from pathlib import Path @@ -6,13 +9,14 @@ import torch from torch.utils.data import SequentialSampler from tqdm import tqdm +from transformers import DPRQuestionEncoder from haystack.modeling.data_handler.dataloader import NamedDataLoader from haystack.modeling.data_handler.processor import TextSimilarityProcessor from haystack.modeling.model.biadaptive_model import BiAdaptiveModel -from haystack.modeling.model.language_model import LanguageModel, DPRContextEncoder, DPRQuestionEncoder +from haystack.modeling.model.language_model import get_language_model, DPREncoder from haystack.modeling.model.prediction_head import TextSimilarityHead -from haystack.modeling.model.tokenization import Tokenizer +from haystack.modeling.model.tokenization import get_tokenizer from haystack.modeling.utils import set_all_seeds, initialize_device_settings @@ -24,10 +28,10 @@ def test_dpr_modules(caplog=None): devices, n_gpu = initialize_device_settings(use_cuda=True) # 1.Create question and passage tokenizers - query_tokenizer = Tokenizer.load( + query_tokenizer = get_tokenizer( pretrained_model_name_or_path="facebook/dpr-question_encoder-single-nq-base", do_lower_case=True, use_fast=True ) - passage_tokenizer = Tokenizer.load( + passage_tokenizer = get_tokenizer( pretrained_model_name_or_path="facebook/dpr-ctx_encoder-single-nq-base", do_lower_case=True, use_fast=True ) @@ -46,17 +50,15 @@ def test_dpr_modules(caplog=None): num_hard_negatives=1, ) - question_language_model = LanguageModel.load( + question_language_model = DPREncoder( pretrained_model_name_or_path="bert-base-uncased", - language_model_class="DPRQuestionEncoder", - hidden_dropout_prob=0, - attention_probs_dropout_prob=0, + model_type="DPRQuestionEncoder", + model_kwargs={"hidden_dropout_prob": 0, "attention_probs_dropout_prob": 0}, ) - passage_language_model = LanguageModel.load( + passage_language_model = DPREncoder( pretrained_model_name_or_path="bert-base-uncased", - language_model_class="DPRContextEncoder", - hidden_dropout_prob=0, - attention_probs_dropout_prob=0, + model_type="DPRContextEncoder", + model_kwargs={"hidden_dropout_prob": 0, "attention_probs_dropout_prob": 0}, ) prediction_head = TextSimilarityHead(similarity_function="dot_product") @@ -75,8 +77,8 @@ def test_dpr_modules(caplog=None): assert type(model) == BiAdaptiveModel assert type(processor) == TextSimilarityProcessor - assert type(question_language_model) == DPRQuestionEncoder - assert type(passage_language_model) == DPRContextEncoder + assert type(question_language_model) == DPREncoder + assert type(passage_language_model) == DPREncoder # check embedding layer weights assert list(model.named_parameters())[0][1][0, 0].item() - -0.010200000368058681 < 0.0001 @@ -131,9 +133,17 @@ def test_dpr_modules(caplog=None): torch.eq(features["passage_attention_mask"][0][1].nonzero().cpu().squeeze(), torch.tensor(list(range(143)))) ) + features_query = 
{key.replace("query_", ""): value for key, value in features.items() if key.startswith("query_")} + features_passage = { + key.replace("passage_", ""): value for key, value in features.items() if key.startswith("passage_") + } + max_seq_len = features_passage.get("input_ids").shape[-1] + features_passage = {key: value.view(-1, max_seq_len) for key, value in features_passage.items()} + # test model encodings - query_vector = model.language_model1(**features)[0] - passage_vector = model.language_model2(**features)[0] + query_vector = model.language_model1(**features_query)[0] + passage_vector = model.language_model2(**features_passage)[0] + assert torch.all( torch.le( query_vector[0, :10].cpu() @@ -157,7 +167,14 @@ def test_dpr_modules(caplog=None): ) # test logits and loss - embeddings = model(**features) + embeddings = model( + query_input_ids=features.get("query_input_ids", None), + query_segment_ids=features.get("query_segment_ids", None), + query_attention_mask=features.get("query_attention_mask", None), + passage_input_ids=features.get("passage_input_ids", None), + passage_segment_ids=features.get("passage_segment_ids", None), + passage_attention_mask=features.get("passage_attention_mask", None), + ) query_emb, passage_emb = embeddings[0] assert torch.all(torch.eq(query_emb.cpu(), query_vector.cpu())) assert torch.all(torch.eq(passage_emb.cpu(), passage_vector.cpu())) @@ -343,9 +360,9 @@ def test_dpr_processor(embed_title, passage_ids, passage_attns, use_fast, num_ha ] query_tok = "facebook/dpr-question_encoder-single-nq-base" - query_tokenizer = Tokenizer.load(query_tok, use_fast=use_fast) + query_tokenizer = get_tokenizer(query_tok, use_fast=use_fast) passage_tok = "facebook/dpr-ctx_encoder-single-nq-base" - passage_tokenizer = Tokenizer.load(passage_tok, use_fast=use_fast) + passage_tokenizer = get_tokenizer(passage_tok, use_fast=use_fast) processor = TextSimilarityProcessor( query_tokenizer=query_tokenizer, passage_tokenizer=passage_tokenizer, @@ -400,9 +417,9 @@ def test_dpr_processor_empty_title(use_fast, embed_title): } query_tok = "facebook/dpr-question_encoder-single-nq-base" - query_tokenizer = Tokenizer.load(query_tok, use_fast=use_fast) + query_tokenizer = get_tokenizer(query_tok, use_fast=use_fast) passage_tok = "facebook/dpr-ctx_encoder-single-nq-base" - passage_tokenizer = Tokenizer.load(passage_tok, use_fast=use_fast) + passage_tokenizer = get_tokenizer(passage_tok, use_fast=use_fast) processor = TextSimilarityProcessor( query_tokenizer=query_tokenizer, passage_tokenizer=passage_tokenizer, @@ -485,9 +502,9 @@ def test_dpr_problematic(): ] query_tok = "facebook/dpr-question_encoder-single-nq-base" - query_tokenizer = Tokenizer.load(query_tok, use_fast=True) + query_tokenizer = get_tokenizer(query_tok) passage_tok = "facebook/dpr-ctx_encoder-single-nq-base" - passage_tokenizer = Tokenizer.load(passage_tok, use_fast=True) + passage_tokenizer = get_tokenizer(passage_tok) processor = TextSimilarityProcessor( query_tokenizer=query_tokenizer, passage_tokenizer=passage_tokenizer, @@ -516,9 +533,9 @@ def test_dpr_query_only(): ] query_tok = "facebook/dpr-question_encoder-single-nq-base" - query_tokenizer = Tokenizer.load(query_tok, use_fast=True) + query_tokenizer = get_tokenizer(query_tok) passage_tok = "facebook/dpr-ctx_encoder-single-nq-base" - passage_tokenizer = Tokenizer.load(passage_tok, use_fast=True) + passage_tokenizer = get_tokenizer(passage_tok) processor = TextSimilarityProcessor( query_tokenizer=query_tokenizer, passage_tokenizer=passage_tokenizer, @@ -578,9 +595,9 
@@ def test_dpr_context_only(): ] query_tok = "facebook/dpr-question_encoder-single-nq-base" - query_tokenizer = Tokenizer.load(query_tok, use_fast=True) + query_tokenizer = get_tokenizer(query_tok) passage_tok = "facebook/dpr-ctx_encoder-single-nq-base" - passage_tokenizer = Tokenizer.load(passage_tok, use_fast=True) + passage_tokenizer = get_tokenizer(passage_tok) processor = TextSimilarityProcessor( query_tokenizer=query_tokenizer, passage_tokenizer=passage_tokenizer, @@ -629,9 +646,9 @@ def test_dpr_processor_save_load(tmp_path): } query_tok = "facebook/dpr-question_encoder-single-nq-base" - query_tokenizer = Tokenizer.load(query_tok, use_fast=True) + query_tokenizer = get_tokenizer(query_tok) passage_tok = "facebook/dpr-ctx_encoder-single-nq-base" - passage_tokenizer = Tokenizer.load(passage_tok, use_fast=True) + passage_tokenizer = get_tokenizer(passage_tok) processor = TextSimilarityProcessor( query_tokenizer=query_tokenizer, passage_tokenizer=passage_tokenizer, @@ -646,9 +663,10 @@ def test_dpr_processor_save_load(tmp_path): metric="text_similarity_metric", shuffle_negatives=False, ) - processor.save(save_dir=f"{tmp_path}/testsave/dpr_processor") + save_dir = f"{tmp_path}/testsave/dpr_processor" + processor.save(save_dir=save_dir) dataset, tensor_names, _ = processor.dataset_from_dicts(dicts=[d], return_baskets=False) - loadedprocessor = TextSimilarityProcessor.load_from_dir(load_dir=f"{tmp_path}/testsave/dpr_processor") + loadedprocessor = TextSimilarityProcessor.load_from_dir(load_dir=save_dir) dataset2, tensor_names, _ = loadedprocessor.dataset_from_dicts(dicts=[d], return_baskets=False) assert np.array_equal(dataset.tensors[0], dataset2.tensors[0]) @@ -667,7 +685,7 @@ def test_dpr_processor_save_load(tmp_path): {"query": "facebook/dpr-question_encoder-single-nq-base", "passage": "facebook/dpr-ctx_encoder-single-nq-base"}, ], ) -def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path, query_and_passage_model): +def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path: Path, query_and_passage_model: Tuple[str, str]): """ This test compares 1) a model that was loaded from model hub with 2) a model from model hub that was saved to disk and then loaded from disk and @@ -679,7 +697,24 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path, query_and_passage_ "passages": [ { "title": "Etalab", - "text": "Etalab est une administration publique française qui fait notamment office de Chief Data Officer de l'État et coordonne la conception et la mise en œuvre de sa stratégie dans le domaine de la donnée (ouverture et partage des données publiques ou open data, exploitation des données et intelligence artificielle...). Ainsi, Etalab développe et maintient le portail des données ouvertes du gouvernement français data.gouv.fr. Etalab promeut également une plus grande ouverture l'administration sur la société (gouvernement ouvert) : transparence de l'action publique, innovation ouverte, participation citoyenne... elle promeut l’innovation, l’expérimentation, les méthodes de travail ouvertes, agiles et itératives, ainsi que les synergies avec la société civile pour décloisonner l’administration et favoriser l’adoption des meilleures pratiques professionnelles dans le domaine du numérique. À ce titre elle étudie notamment l’opportunité de recourir à des technologies en voie de maturation issues du monde de la recherche. Cette entité chargée de l'innovation au sein de l'administration doit contribuer à l'amélioration du service public grâce au numérique. 
Elle est rattachée à la Direction interministérielle du numérique, dont les missions et l’organisation ont été fixées par le décret du 30 octobre 2019.  Dirigé par Laure Lucchesi depuis 2016, elle rassemble une équipe pluridisciplinaire d'une trentaine de personnes.", + "text": "Etalab est une administration publique française qui fait notamment office " + "de Chief Data Officer de l'État et coordonne la conception et la mise en œuvre " + "de sa stratégie dans le domaine de la donnée (ouverture et partage des données " + "publiques ou open data, exploitation des données et intelligence artificielle...). " + "Ainsi, Etalab développe et maintient le portail des données ouvertes du gouvernement " + "français data.gouv.fr. Etalab promeut également une plus grande ouverture " + "l'administration sur la société (gouvernement ouvert) : transparence de l'action " + "publique, innovation ouverte, participation citoyenne... elle promeut l’innovation, " + "l’expérimentation, les méthodes de travail ouvertes, agiles et itératives, ainsi que " + "les synergies avec la société civile pour décloisonner l’administration et favoriser " + "l’adoption des meilleures pratiques professionnelles dans le domaine du numérique. " + "À ce titre elle étudie notamment l’opportunité de recourir à des technologies en voie " + "de maturation issues du monde de la recherche. Cette entité chargée de l'innovation " + "au sein de l'administration doit contribuer à l'amélioration du service public grâce " + "au numérique. Elle est rattachée à la Direction interministérielle du numérique, dont " + "les missions et l’organisation ont été fixées par le décret du 30 octobre 2019.  Dirigé " + "par Laure Lucchesi depuis 2016, elle rassemble une équipe pluridisciplinaire d'une " + "trentaine de personnes.", "label": "positive", "external_id": "1", } @@ -689,16 +724,12 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path, query_and_passage_ # load model from model hub query_embedding_model = query_and_passage_model["query"] passage_embedding_model = query_and_passage_model["passage"] - query_tokenizer = Tokenizer.load( + query_tokenizer = get_tokenizer( pretrained_model_name_or_path=query_embedding_model ) # tokenizer class is inferred automatically - query_encoder = LanguageModel.load( - pretrained_model_name_or_path=query_embedding_model, language_model_class="DPRQuestionEncoder" - ) - passage_tokenizer = Tokenizer.load(pretrained_model_name_or_path=passage_embedding_model) - passage_encoder = LanguageModel.load( - pretrained_model_name_or_path=passage_embedding_model, language_model_class="DPRContextEncoder" - ) + query_encoder = get_language_model(pretrained_model_name_or_path=query_embedding_model) + passage_tokenizer = get_tokenizer(pretrained_model_name_or_path=passage_embedding_model) + passage_encoder = get_language_model(pretrained_model_name_or_path=passage_embedding_model) processor = TextSimilarityProcessor( query_tokenizer=query_tokenizer, @@ -737,18 +768,14 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path, query_and_passage_ passage_tokenizer.save_pretrained(save_dir + f"/{passage_encoder_dir}") # load model from disk - loaded_query_tokenizer = Tokenizer.load( + loaded_query_tokenizer = get_tokenizer( pretrained_model_name_or_path=Path(save_dir) / query_encoder_dir, use_fast=True ) # tokenizer class is inferred automatically - loaded_query_encoder = LanguageModel.load( - pretrained_model_name_or_path=Path(save_dir) / query_encoder_dir, language_model_class="DPRQuestionEncoder" - ) - 
loaded_passage_tokenizer = Tokenizer.load( + loaded_query_encoder = get_language_model(pretrained_model_name_or_path=Path(save_dir) / query_encoder_dir) + loaded_passage_tokenizer = get_tokenizer( pretrained_model_name_or_path=Path(save_dir) / passage_encoder_dir, use_fast=True ) - loaded_passage_encoder = LanguageModel.load( - pretrained_model_name_or_path=Path(save_dir) / passage_encoder_dir, language_model_class="DPRContextEncoder" - ) + loaded_passage_encoder = get_language_model(pretrained_model_name_or_path=Path(save_dir) / passage_encoder_dir) loaded_processor = TextSimilarityProcessor( query_tokenizer=loaded_query_tokenizer, @@ -794,12 +821,19 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path, query_and_passage_ all_embeddings = {"query": [], "passages": []} model.eval() - for i, batch in enumerate(tqdm(data_loader, desc=f"Creating Embeddings", unit=" Batches", disable=True)): + for batch in tqdm(data_loader, desc=f"Creating Embeddings", unit=" Batches", disable=True): batch = {key: batch[key].to(device) for key in batch} # get logits with torch.no_grad(): - query_embeddings, passage_embeddings = model.forward(**batch)[0] + query_embeddings, passage_embeddings = model.forward( + query_input_ids=batch.get("query_input_ids", None), + query_segment_ids=batch.get("query_segment_ids", None), + query_attention_mask=batch.get("query_attention_mask", None), + passage_input_ids=batch.get("passage_input_ids", None), + passage_segment_ids=batch.get("passage_segment_ids", None), + passage_attention_mask=batch.get("passage_attention_mask", None), + )[0] if query_embeddings is not None: all_embeddings["query"].append(query_embeddings.cpu().numpy()) if passage_embeddings is not None: @@ -826,7 +860,14 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path, query_and_passage_ # get logits with torch.no_grad(): - query_embeddings, passage_embeddings = loaded_model.forward(**batch)[0] + query_embeddings, passage_embeddings = loaded_model.forward( + query_input_ids=batch.get("query_input_ids", None), + query_segment_ids=batch.get("query_segment_ids", None), + query_attention_mask=batch.get("query_attention_mask", None), + passage_input_ids=batch.get("passage_input_ids", None), + passage_segment_ids=batch.get("passage_segment_ids", None), + passage_attention_mask=batch.get("passage_attention_mask", None), + )[0] if query_embeddings is not None: all_embeddings2["query"].append(query_embeddings.cpu().numpy()) if passage_embeddings is not None: @@ -849,16 +890,12 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path, query_and_passage_ loaded_passage_tokenizer.save_pretrained(save_dir + f"/{passage_encoder_dir}") # load model from disk - query_tokenizer = Tokenizer.load( + query_tokenizer = get_tokenizer( pretrained_model_name_or_path=Path(save_dir) / query_encoder_dir ) # tokenizer class is inferred automatically - query_encoder = LanguageModel.load( - pretrained_model_name_or_path=Path(save_dir) / query_encoder_dir, language_model_class="DPRQuestionEncoder" - ) - passage_tokenizer = Tokenizer.load(pretrained_model_name_or_path=Path(save_dir) / passage_encoder_dir) - passage_encoder = LanguageModel.load( - pretrained_model_name_or_path=Path(save_dir) / passage_encoder_dir, language_model_class="DPRContextEncoder" - ) + query_encoder = get_language_model(pretrained_model_name_or_path=Path(save_dir) / query_encoder_dir) + passage_tokenizer = get_tokenizer(pretrained_model_name_or_path=Path(save_dir) / passage_encoder_dir) + passage_encoder = 
get_language_model(pretrained_model_name_or_path=Path(save_dir) / passage_encoder_dir) processor = TextSimilarityProcessor( query_tokenizer=query_tokenizer, @@ -910,7 +947,14 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path, query_and_passage_ # get logits with torch.no_grad(): - query_embeddings, passage_embeddings = loaded_model.forward(**batch)[0] + query_embeddings, passage_embeddings = loaded_model.forward( + query_input_ids=batch.get("query_input_ids", None), + query_segment_ids=batch.get("query_segment_ids", None), + query_attention_mask=batch.get("query_attention_mask", None), + passage_input_ids=batch.get("passage_input_ids", None), + passage_segment_ids=batch.get("passage_segment_ids", None), + passage_attention_mask=batch.get("passage_attention_mask", None), + )[0] if query_embeddings is not None: all_embeddings3["query"].append(query_embeddings.cpu().numpy()) if passage_embeddings is not None: @@ -942,9 +986,9 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path, query_and_passage_ # # device, n_gpu = initialize_device_settings(use_cuda=False) # -# query_tokenizer = Tokenizer.load(pretrained_model_name_or_path=question_lang_model, +# query_tokenizer = get_tokenizer(pretrained_model_name_or_path=question_lang_model, # do_lower_case=do_lower_case, use_fast=use_fast) -# passage_tokenizer = Tokenizer.load(pretrained_model_name_or_path=passage_lang_model, +# passage_tokenizer = get_tokenizer(pretrained_model_name_or_path=passage_lang_model, # do_lower_case=do_lower_case, use_fast=use_fast) # label_list = ["hard_negative", "positive"] # @@ -965,9 +1009,9 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path, query_and_passage_ # # data_silo = DataSilo(processor=processor, batch_size=batch_size, distributed=False) # -# question_language_model = LanguageModel.load(pretrained_model_name_or_path=question_lang_model, +# question_language_model = get_language_model(pretrained_model_name_or_path=question_lang_model, # language_model_class="DPRQuestionEncoder") -# passage_language_model = LanguageModel.load(pretrained_model_name_or_path=passage_lang_model, +# passage_language_model = get_language_model(pretrained_model_name_or_path=passage_lang_model, # language_model_class="DPRContextEncoder") # # prediction_head = TextSimilarityHead(similarity_function=similarity_function) @@ -1038,9 +1082,3 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path, query_and_passage_ # ) # # trainer2.train() - - -if __name__ == "__main__": - # test_dpr_training() - test_dpr_context_only() - # test_dpr_modules() diff --git a/test/modeling/test_modeling_inference.py b/test/modeling/test_inference.py similarity index 100% rename from test/modeling/test_modeling_inference.py rename to test/modeling/test_inference.py diff --git a/test/modeling/test_language.py b/test/modeling/test_language.py new file mode 100644 index 0000000000..844f2302b7 --- /dev/null +++ b/test/modeling/test_language.py @@ -0,0 +1,34 @@ +import pytest + +from haystack.modeling.model.language_model import get_language_model + + +@pytest.mark.parametrize( + "pretrained_model_name_or_path, lm_class", + [ + ("google/bert_uncased_L-2_H-128_A-2", "HFLanguageModel"), + ("google/electra-small-generator", "HFLanguageModelWithPooler"), + ("distilbert-base-uncased", "HFLanguageModelNoSegmentIds"), + ("deepset/bert-small-mm_retrieval-passage_encoder", "DPREncoder"), + ], +) +def test_basic_loading(pretrained_model_name_or_path, lm_class): + lm = get_language_model(pretrained_model_name_or_path) + mod = 
__import__("haystack.modeling.model.language_model", fromlist=[lm_class]) + klass = getattr(mod, lm_class) + assert isinstance(lm, klass) + + +def test_basic_loading_unknown_model(): + with pytest.raises(OSError): + get_language_model("model_that_doesnt_exist") + + +def test_basic_loading_with_empty_string(): + with pytest.raises(ValueError): + get_language_model("") + + +def test_basic_loading_invalid_params(): + with pytest.raises(ValueError): + get_language_model(None) diff --git a/test/modeling/test_modeling_prediction_head.py b/test/modeling/test_prediction_head.py similarity index 87% rename from test/modeling/test_modeling_prediction_head.py rename to test/modeling/test_prediction_head.py index e607bce7cc..368afc5022 100644 --- a/test/modeling/test_modeling_prediction_head.py +++ b/test/modeling/test_prediction_head.py @@ -1,7 +1,7 @@ import logging from haystack.modeling.model.adaptive_model import AdaptiveModel -from haystack.modeling.model.language_model import LanguageModel +from haystack.modeling.model.language_model import get_language_model from haystack.modeling.model.prediction_head import QuestionAnsweringHead from haystack.modeling.utils import set_all_seeds, initialize_device_settings @@ -14,7 +14,7 @@ def test_prediction_head_load_save(tmp_path, caplog=None): devices, n_gpu = initialize_device_settings(use_cuda=False) lang_model = "bert-base-german-cased" - language_model = LanguageModel.load(lang_model) + language_model = get_language_model(lang_model) prediction_head = QuestionAnsweringHead() model = AdaptiveModel( diff --git a/test/modeling/test_modeling_processor.py b/test/modeling/test_processor.py similarity index 98% rename from test/modeling/test_modeling_processor.py rename to test/modeling/test_processor.py index 8744aeb6cb..79308d80f8 100644 --- a/test/modeling/test_modeling_processor.py +++ b/test/modeling/test_processor.py @@ -4,7 +4,7 @@ from transformers import AutoTokenizer from haystack.modeling.data_handler.processor import SquadProcessor -from haystack.modeling.model.tokenization import Tokenizer +from haystack.modeling.model.tokenization import get_tokenizer from ..conftest import SAMPLES_PATH @@ -24,7 +24,7 @@ def test_dataset_from_dicts_qa_inference(caplog=None): sample_types = ["answer-wrong", "answer-offset-wrong", "noanswer", "vanilla"] for model in models: - tokenizer = Tokenizer.load(pretrained_model_name_or_path=model, use_fast=True) + tokenizer = get_tokenizer(pretrained_model_name_or_path=model) processor = SquadProcessor(tokenizer, max_seq_len=256, data_dir=None) for sample_type in sample_types: @@ -251,7 +251,7 @@ def test_dataset_from_dicts_qa_labelconversion(caplog=None): sample_types = ["answer-wrong", "answer-offset-wrong", "noanswer", "vanilla"] for model in models: - tokenizer = Tokenizer.load(pretrained_model_name_or_path=model, use_fast=True) + tokenizer = get_tokenizer(pretrained_model_name_or_path=model) processor = SquadProcessor(tokenizer, max_seq_len=256, data_dir=None) for sample_type in sample_types: diff --git a/test/modeling/test_modeling_processor_saving_loading.py b/test/modeling/test_processor_save_load.py similarity index 89% rename from test/modeling/test_modeling_processor_saving_loading.py rename to test/modeling/test_processor_save_load.py index 8972422364..154b303f70 100644 --- a/test/modeling/test_modeling_processor_saving_loading.py +++ b/test/modeling/test_processor_save_load.py @@ -2,7 +2,7 @@ from pathlib import Path from haystack.modeling.data_handler.processor import SquadProcessor -from 
haystack.modeling.model.tokenization import Tokenizer +from haystack.modeling.model.tokenization import get_tokenizer from haystack.modeling.utils import set_all_seeds import torch @@ -16,7 +16,7 @@ def test_processor_saving_loading(tmp_path, caplog): set_all_seeds(seed=42) lang_model = "roberta-base" - tokenizer = Tokenizer.load(pretrained_model_name_or_path=lang_model, do_lower_case=False) + tokenizer = get_tokenizer(pretrained_model_name_or_path=lang_model, do_lower_case=False) processor = SquadProcessor( tokenizer=tokenizer, diff --git a/test/modeling/test_modeling_question_answering.py b/test/modeling/test_question_answering.py similarity index 100% rename from test/modeling/test_modeling_question_answering.py rename to test/modeling/test_question_answering.py diff --git a/test/modeling/test_tokenization.py b/test/modeling/test_tokenization.py index 486b338f77..5758eeedec 100644 --- a/test/modeling/test_tokenization.py +++ b/test/modeling/test_tokenization.py @@ -1,500 +1,325 @@ -import logging -import pytest +from typing import Tuple + import re -from transformers import ( - BertTokenizer, - BertTokenizerFast, - RobertaTokenizer, - RobertaTokenizerFast, - XLNetTokenizer, - XLNetTokenizerFast, - ElectraTokenizerFast, -) + +import pytest +import numpy as np +from unittest.mock import MagicMock from tokenizers.pre_tokenizers import WhitespaceSplit -from haystack.modeling.model.tokenization import Tokenizer +import haystack +from haystack.modeling.model.tokenization import get_tokenizer -import numpy as np + +BERT = "bert-base-cased" +ROBERTA = "roberta-base" +XLNET = "xlnet-base-cased" + +TOKENIZERS_TO_TEST = [BERT, ROBERTA, XLNET] +TOKENIZERS_TO_TEST_WITH_TOKEN_MARKER = [(BERT, "##"), (ROBERTA, "Ġ"), (XLNET, "▁")] -TEXTS = [ - "This is a sentence", - "Der entscheidende Pass", - "This is a sentence with multiple spaces", - "力加勝北区ᴵᴺᵀᵃছজটডণত", - "Thiso text is included tolod makelio sure Unicodeel is handled properly:", - "This is a sentence...", - "Let's see all on this text and. !23# neverseenwordspossible", - """This is a sentence. - With linebreak""", - """Sentence with multiple +REGULAR_SENTENCE = "This is a sentence" +GERMAN_SENTENCE = "Der entscheidende Pass" +OTHER_ALPHABETS = "力加勝北区ᴵᴺᵀᵃছজটডণত" +GIBBERISH_SENTENCE = "Thiso text is included tolod makelio sure Unicodeel is handled properly:" +SENTENCE_WITH_ELLIPSIS = "This is a sentence..." +SENTENCE_WITH_LINEBREAK_1 = "and another one\n\n\nwithout space" +SENTENCE_WITH_LINEBREAK_2 = """This is a sentence. 
+ With linebreak""" +SENTENCE_WITH_LINEBREAKS = """Sentence + with + multiple newlines - """, - "and another one\n\n\nwithout space", - "This is a sentence with tab", - "This is a sentence with multiple tabs", -] - - -def test_basic_loading(caplog): - caplog.set_level(logging.CRITICAL) - # slow tokenizers - tokenizer = Tokenizer.load(pretrained_model_name_or_path="bert-base-cased", do_lower_case=True, use_fast=False) - assert type(tokenizer) == BertTokenizer - assert tokenizer.basic_tokenizer.do_lower_case == True - - tokenizer = Tokenizer.load(pretrained_model_name_or_path="xlnet-base-cased", do_lower_case=True, use_fast=False) - assert type(tokenizer) == XLNetTokenizer - assert tokenizer.do_lower_case == True - - tokenizer = Tokenizer.load(pretrained_model_name_or_path="roberta-base", use_fast=False) - assert type(tokenizer) == RobertaTokenizer - - # fast tokenizers - tokenizer = Tokenizer.load(pretrained_model_name_or_path="bert-base-cased", do_lower_case=True) - assert type(tokenizer) == BertTokenizerFast - assert tokenizer.do_lower_case == True - - tokenizer = Tokenizer.load(pretrained_model_name_or_path="xlnet-base-cased", do_lower_case=True) - assert type(tokenizer) == XLNetTokenizerFast - assert tokenizer.do_lower_case == True - - tokenizer = Tokenizer.load(pretrained_model_name_or_path="roberta-base") - assert type(tokenizer) == RobertaTokenizerFast - - -def test_bert_tokenizer_all_meta(caplog): - caplog.set_level(logging.CRITICAL) - - lang_model = "bert-base-cased" - - tokenizer = Tokenizer.load(pretrained_model_name_or_path=lang_model, do_lower_case=False) - - basic_text = "Some Text with neverseentokens plus !215?#. and a combined-token_with/chars" - - tokenized = tokenizer.tokenize(basic_text) - assert tokenized == [ - "Some", - "Text", - "with", - "never", - "##see", - "##nto", - "##ken", - "##s", - "plus", - "!", - "215", - "?", - "#", - ".", - "and", - "a", - "combined", - "-", - "token", - "_", - "with", - "/", - "ch", - "##ars", - ] + """ +SENTENCE_WITH_EXCESS_WHITESPACE = "This is a sentence with multiple spaces" +SENTENCE_WITH_TABS = "This is a sentence with multiple tabs" +SENTENCE_WITH_CUSTOM_TOKEN = "Let's see all on this text and. 
!23# neverseenwordspossible" - encoded_batch = tokenizer.encode_plus(basic_text) - encoded = encoded_batch.encodings[0] - words = np.array(encoded.words) - words[words == None] = -1 - start_of_word_single = [False] + list(np.ediff1d(words) > 0) - assert encoded.tokens == [ - "[CLS]", - "Some", - "Text", - "with", - "never", - "##see", - "##nto", - "##ken", - "##s", - "plus", - "!", - "215", - "?", - "#", - ".", - "and", - "a", - "combined", - "-", - "token", - "_", - "with", - "/", - "ch", - "##ars", - "[SEP]", - ] - assert [x[0] for x in encoded.offsets] == [ - 0, - 0, - 5, - 10, - 15, - 20, - 23, - 26, - 29, - 31, - 36, - 37, - 40, - 41, - 42, - 44, - 48, - 50, - 58, - 59, - 64, - 65, - 69, - 70, - 72, - 0, - ] - assert start_of_word_single == [ - False, - True, - True, - True, - True, - False, - False, - False, - False, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - False, - False, - ] +class AutoTokenizer: + mocker: MagicMock = MagicMock() -def test_save_load(tmp_path, caplog): - caplog.set_level(logging.CRITICAL) - - lang_names = ["bert-base-cased", "roberta-base", "xlnet-base-cased"] - tokenizers = [] - for lang_name in lang_names: - if "xlnet" in lang_name.lower(): - t = Tokenizer.load(lang_name, lower_case=False, use_fast=True, from_slow=True) - else: - t = Tokenizer.load(lang_name, lower_case=False) - t.add_tokens(new_tokens=["neverseentokens"]) - tokenizers.append(t) - - basic_text = "Some Text with neverseentokens plus !215?#. and a combined-token_with/chars" - - for tokenizer in tokenizers: - tokenizer_type = tokenizer.__class__.__name__ - save_dir = f"{tmp_path}/testsave/{tokenizer_type}" - tokenizer.save_pretrained(save_dir) - tokenizer_loaded = Tokenizer.load(save_dir, tokenizer_class=tokenizer_type) - encoded_before = tokenizer.encode_plus(basic_text).encodings[0] - encoded_after = tokenizer_loaded.encode_plus(basic_text).encodings[0] - data_before = { - "tokens": encoded_before.tokens, - "offsets": encoded_before.offsets, - "words": encoded_before.words, - } - data_after = {"tokens": encoded_after.tokens, "offsets": encoded_after.offsets, "words": encoded_after.words} - assert data_before == data_after - - -@pytest.mark.parametrize("model_name", ["bert-base-german-cased", "google/electra-small-discriminator"]) -def test_fast_tokenizer_with_examples(caplog, model_name): - fast_tokenizer = Tokenizer.load(model_name, lower_case=False, use_fast=True) - tokenizer = Tokenizer.load(model_name, lower_case=False, use_fast=False) - - for text in TEXTS: - # plain tokenize function - tokenized = tokenizer.tokenize(text) - fast_tokenized = fast_tokenizer.tokenize(text) - - assert tokenized == fast_tokenized - - -def test_all_tokenizer_on_special_cases(caplog): - caplog.set_level(logging.CRITICAL) - - lang_names = ["bert-base-cased", "roberta-base", "xlnet-base-cased"] - - tokenizers = [] - for lang_name in lang_names: - if "roberta" in lang_name: - add_prefix_space = True - else: - add_prefix_space = False - t = Tokenizer.load(lang_name, lower_case=False, add_prefix_space=add_prefix_space) - tokenizers.append(t) - - texts = [ - "This is a sentence", - "Der entscheidende Pass", - "力加勝北区ᴵᴺᵀᵃছজটডণত", - "Thiso text is included tolod makelio sure Unicodeel is handled properly:", - "This is a sentence...", - "Let's see all on this text and. !23# neverseenwordspossible" "This is a sentence with multiple spaces", - """This is a sentence. 
- With linebreak""", - """Sentence with multiple - newlines - """, - "and another one\n\n\nwithout space", - "This is a sentence with multiple tabs", - ] + @classmethod + def from_pretrained(cls, *args, **kwargs): + cls.mocker.from_pretrained(*args, **kwargs) + return cls() - expected_to_fail = {(2, 1), (2, 5)} - - for i_tok, tokenizer in enumerate(tokenizers): - for i_text, text in enumerate(texts): - # Important: we don't assume to preserve whitespaces after tokenization. - # This means: \t, \n " " etc will all resolve to a single " ". - # This doesn't make a difference for BERT + XLNet but it does for roBERTa - - test_passed = True - - # 1. original tokenize function from transformer repo on full sentence - standardized_whitespace_text = " ".join(text.split()) # remove multiple whitespaces - tokenized = tokenizer.tokenize(standardized_whitespace_text) - - # 2. Our tokenization method using a pretokenizer which can normalize multiple white spaces - # This approach is used in NER - pre_tokenizer = WhitespaceSplit() - words_and_spans = pre_tokenizer.pre_tokenize_str(text) - words = [x[0] for x in words_and_spans] - word_spans = [x[1] for x in words_and_spans] - - encoded = tokenizer.encode_plus(words, is_split_into_words=True, add_special_tokens=False).encodings[0] - - # verify that tokenization on full sequence is the same as the one on "whitespace tokenized words" - if encoded.tokens != tokenized: - test_passed = False - - # token offsets are originally relative to the beginning of the word - # These lines convert them so they are relative to the beginning of the sentence - token_offsets = [] - for ((start, end), w_index) in zip(encoded.offsets, encoded.words): - word_start_ch = word_spans[w_index][0] - token_offsets.append((start + word_start_ch, end + word_start_ch)) - - # verify that offsets align back to original text - if text == "力加勝北区ᴵᴺᵀᵃছজটডণত": - # contains [UNK] that are impossible to match back to original text space - continue - for tok, (start, end) in zip(encoded.tokens, token_offsets): - # subword-tokens have special chars depending on model type. In order to align with original text we need to get rid of them - tok = re.sub(r"^(##|Ġ|▁)", "", tok) - # tok = tokenizer.decode(tokenizer.convert_tokens_to_ids(tok)) - original_tok = text[start:end] - if tok != original_tok: - test_passed = False - if (i_tok, i_text) in expected_to_fail: - assert not test_passed, f"Behaviour of {tokenizer.__class__.__name__} has changed on text {text}'" - else: - assert test_passed, f"Behaviour of {tokenizer.__class__.__name__} has changed on text {text}'" - - -def test_bert_custom_vocab(caplog): - caplog.set_level(logging.CRITICAL) - - lang_model = "bert-base-cased" - - tokenizer = Tokenizer.load(pretrained_model_name_or_path=lang_model, do_lower_case=False) - - # deprecated: tokenizer.add_custom_vocab("samples/tokenizer/custom_vocab.txt") - tokenizer.add_tokens(new_tokens=["neverseentokens"]) - basic_text = "Some Text with neverseentokens plus !215?#. 
and a combined-token_with/chars" - - # original tokenizer from transformer repo - tokenized = tokenizer.tokenize(basic_text) - assert tokenized == [ - "Some", - "Text", - "with", - "neverseentokens", - "plus", - "!", - "215", - "?", - "#", - ".", - "and", - "a", - "combined", - "-", - "token", - "_", - "with", - "/", - "ch", - "##ars", - ] +@pytest.fixture(autouse=True) +def mock_autotokenizer(request, monkeypatch): + # Do not patch integration tests + if "integration" in request.keywords: + return + monkeypatch.setattr(haystack.modeling.model.tokenization, "AutoTokenizer", AutoTokenizer) + + +def convert_offset_from_word_reference_to_text_reference(offsets, words, word_spans): + """ + Token offsets are originally relative to the beginning of the word + We make them relative to the beginning of the sentence. + + Not a fixture, just a utility. + """ + token_offsets = [] + for ((start, end), word_index) in zip(offsets, words): + word_start = word_spans[word_index][0] + token_offsets.append((start + word_start, end + word_start)) + return token_offsets + + +# +# Unit tests +# + + +def test_get_tokenizer_str(): + tokenizer = get_tokenizer(pretrained_model_name_or_path="test-model-name") + tokenizer.mocker.from_pretrained.assert_called_with( + pretrained_model_name_or_path="test-model-name", revision=None, use_fast=True, use_auth_token=None + ) - # ours with metadata - encoded = tokenizer.encode_plus(basic_text, add_special_tokens=False).encodings[0] - offsets = [x[0] for x in encoded.offsets] - start_of_word_single = [True] + list(np.ediff1d(encoded.words) > 0) - assert encoded.tokens == tokenized - assert offsets == [0, 5, 10, 15, 31, 36, 37, 40, 41, 42, 44, 48, 50, 58, 59, 64, 65, 69, 70, 72] - assert start_of_word_single == [ - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - False, - ] +def test_get_tokenizer_path(tmp_path): + tokenizer = get_tokenizer(pretrained_model_name_or_path=tmp_path / "test-path") + tokenizer.mocker.from_pretrained.assert_called_with( + pretrained_model_name_or_path=str(tmp_path / "test-path"), revision=None, use_fast=True, use_auth_token=None + ) -def test_fast_bert_custom_vocab(caplog): - caplog.set_level(logging.CRITICAL) - lang_model = "bert-base-cased" +def test_get_tokenizer_keep_accents(): + tokenizer = get_tokenizer(pretrained_model_name_or_path="test-model-name-albert") + tokenizer.mocker.from_pretrained.assert_called_with( + pretrained_model_name_or_path="test-model-name-albert", + revision=None, + use_fast=True, + use_auth_token=None, + keep_accents=True, + ) - tokenizer = Tokenizer.load(pretrained_model_name_or_path=lang_model, do_lower_case=False, use_fast=True) - # deprecated: tokenizer.add_custom_vocab("samples/tokenizer/custom_vocab.txt") +def test_get_tokenizer_mlm_warning(caplog): + tokenizer = get_tokenizer(pretrained_model_name_or_path="test-model-name-mlm") + tokenizer.mocker.from_pretrained.assert_called_with( + pretrained_model_name_or_path="test-model-name-mlm", revision=None, use_fast=True, use_auth_token=None + ) + assert "MLM part of codebert is currently not supported in Haystack".lower() in caplog.text.lower() + + +# +# Integration tests +# + + +@pytest.mark.integration +@pytest.mark.parametrize("model_name", TOKENIZERS_TO_TEST) +def test_save_load(tmp_path, model_name: str): + tokenizer = get_tokenizer(pretrained_model_name_or_path=model_name, do_lower_case=False) + text = "Some Text with neverseentokens plus !215?#. 
and a combined-token_with/chars" + tokenizer.add_tokens(new_tokens=["neverseentokens"]) + original_encoding = tokenizer.encode_plus(text) - basic_text = "Some Text with neverseentokens plus !215?#. and a combined-token_with/chars" - - # original tokenizer from transformer repo - tokenized = tokenizer.tokenize(basic_text) - assert tokenized == [ - "Some", - "Text", - "with", - "neverseentokens", - "plus", - "!", - "215", - "?", - "#", - ".", - "and", - "a", - "combined", - "-", - "token", - "_", - "with", - "/", - "ch", - "##ars", - ] + save_dir = tmp_path / "saved_tokenizer" + tokenizer.save_pretrained(save_dir) - # ours with metadata - encoded = tokenizer.encode_plus(basic_text, add_special_tokens=False).encodings[0] - offsets = [x[0] for x in encoded.offsets] - start_of_word_single = [True] + list(np.ediff1d(encoded.words) > 0) - assert encoded.tokens == tokenized - assert offsets == [0, 5, 10, 15, 31, 36, 37, 40, 41, 42, 44, 48, 50, 58, 59, 64, 65, 69, 70, 72] - assert start_of_word_single == [ - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - False, - ] + tokenizer_loaded = get_tokenizer(pretrained_model_name_or_path=save_dir) + new_encoding = tokenizer_loaded.encode_plus(text) + assert original_encoding == new_encoding + + +@pytest.mark.integration +def test_tokenize_custom_vocab_bert(): + tokenizer = get_tokenizer(pretrained_model_name_or_path=BERT, do_lower_case=False) + tokenizer.add_tokens(new_tokens=["neverseentokens"]) + text = "Some Text with neverseentokens plus !215?#. and a combined-token_with/chars" + tokenized = tokenizer.tokenize(text) + assert ( + tokenized == f"Some Text with neverseentokens plus ! 215 ? # . and a combined - token _ with / ch ##ars".split() + ) + + +@pytest.mark.integration @pytest.mark.parametrize( - "model_name, tokenizer_type", - [("bert-base-german-cased", BertTokenizerFast), ("google/electra-small-discriminator", ElectraTokenizerFast)], + "edge_case", + [ + REGULAR_SENTENCE, + OTHER_ALPHABETS, + GIBBERISH_SENTENCE, + SENTENCE_WITH_ELLIPSIS, + SENTENCE_WITH_LINEBREAK_1, + SENTENCE_WITH_LINEBREAK_2, + SENTENCE_WITH_LINEBREAKS, + SENTENCE_WITH_EXCESS_WHITESPACE, + SENTENCE_WITH_TABS, + ], ) -def test_fast_tokenizer_type(caplog, model_name, tokenizer_type): - caplog.set_level(logging.CRITICAL) +@pytest.mark.parametrize("model_name", TOKENIZERS_TO_TEST) +def test_tokenization_on_edge_cases_full_sequence_tokenization(model_name: str, edge_case: str): + """ + Verify that tokenization on full sequence is the same as the one on "whitespace tokenized words" + """ + tokenizer = get_tokenizer(pretrained_model_name_or_path=model_name, do_lower_case=False, add_prefix_space=True) - tokenizer = Tokenizer.load(model_name, use_fast=True) - assert type(tokenizer) is tokenizer_type + pre_tokenizer = WhitespaceSplit() + words_and_spans = pre_tokenizer.pre_tokenize_str(edge_case) + words = [x[0] for x in words_and_spans] + encoded = tokenizer.encode_plus(words, is_split_into_words=True, add_special_tokens=False).encodings[0] + expected_tokenization = tokenizer.tokenize(" ".join(edge_case.split())) # remove multiple whitespaces -# See discussion in https://github.com/deepset-ai/FARM/pull/624 for reason to remove the test -# def test_fast_bert_tokenizer_strip_accents(caplog): -# caplog.set_level(logging.CRITICAL) -# -# tokenizer = Tokenizer.load("dbmdz/bert-base-german-uncased", -# use_fast=True, -# strip_accents=False) -# assert type(tokenizer) is BertTokenizerFast -# 
assert tokenizer.do_lower_case -# assert tokenizer._tokenizer._parameters['strip_accents'] is False + assert encoded.tokens == expected_tokenization + + +@pytest.mark.integration +@pytest.mark.parametrize("edge_case", [SENTENCE_WITH_CUSTOM_TOKEN, GERMAN_SENTENCE]) +@pytest.mark.parametrize("model_name", [t for t in TOKENIZERS_TO_TEST if t != ROBERTA]) +def test_tokenization_on_edge_cases_full_sequence_tokenization_roberta_exceptions(model_name: str, edge_case: str): + """ + Verify that tokenization on full sequence is the same as the one on "whitespace tokenized words". + These test cases work for all tokenizers under test except for RoBERTa. + """ + tokenizer = get_tokenizer(pretrained_model_name_or_path=model_name, do_lower_case=False, add_prefix_space=True) + pre_tokenizer = WhitespaceSplit() + words_and_spans = pre_tokenizer.pre_tokenize_str(edge_case) + words = [x[0] for x in words_and_spans] -def test_fast_electra_tokenizer(caplog): - caplog.set_level(logging.CRITICAL) + encoded = tokenizer.encode_plus(words, is_split_into_words=True, add_special_tokens=False).encodings[0] + expected_tokenization = tokenizer.tokenize(" ".join(edge_case.split())) # remove multiple whitespaces - tokenizer = Tokenizer.load("dbmdz/electra-base-german-europeana-cased-discriminator", use_fast=True) - assert type(tokenizer) is ElectraTokenizerFast + assert encoded.tokens == expected_tokenization -@pytest.mark.parametrize("model_name", ["bert-base-cased", "distilbert-base-uncased", "deepset/electra-base-squad2"]) -def test_detokenization_in_fast_tokenizers(model_name): - tokenizer = Tokenizer.load(pretrained_model_name_or_path=model_name, use_fast=True) - for text in TEXTS: - encoded = tokenizer.encode_plus(text, add_special_tokens=False).encodings[0] +@pytest.mark.integration +@pytest.mark.parametrize( + "edge_case", + [ + REGULAR_SENTENCE, + # OTHER_ALPHABETS, # contains [UNK] that are impossible to match back to original text space + GIBBERISH_SENTENCE, + SENTENCE_WITH_ELLIPSIS, + SENTENCE_WITH_LINEBREAK_1, + SENTENCE_WITH_LINEBREAK_2, + SENTENCE_WITH_LINEBREAKS, + SENTENCE_WITH_EXCESS_WHITESPACE, + SENTENCE_WITH_TABS, + ], +) +@pytest.mark.parametrize("model_name,marker", TOKENIZERS_TO_TEST_WITH_TOKEN_MARKER) +def test_tokenization_on_edge_cases_full_sequence_verify_spans(model_name: str, marker: str, edge_case: str): + tokenizer = get_tokenizer(pretrained_model_name_or_path=model_name, do_lower_case=False, add_prefix_space=True) + + pre_tokenizer = WhitespaceSplit() + words_and_spans = pre_tokenizer.pre_tokenize_str(edge_case) + words = [x[0] for x in words_and_spans] + word_spans = [x[1] for x in words_and_spans] + + encoded = tokenizer.encode_plus(words, is_split_into_words=True, add_special_tokens=False).encodings[0] + + # subword-tokens have special chars depending on model type. 
To align with original text we get rid of them + tokens = [token.replace(marker, "") for token in encoded.tokens] + token_offsets = convert_offset_from_word_reference_to_text_reference(encoded.offsets, encoded.words, word_spans) + + for token, (start, end) in zip(tokens, token_offsets): + assert token == edge_case[start:end] + + +@pytest.mark.integration +@pytest.mark.parametrize( + "edge_case", + [ + REGULAR_SENTENCE, + GERMAN_SENTENCE, + SENTENCE_WITH_EXCESS_WHITESPACE, + OTHER_ALPHABETS, + GIBBERISH_SENTENCE, + SENTENCE_WITH_ELLIPSIS, + SENTENCE_WITH_CUSTOM_TOKEN, + SENTENCE_WITH_LINEBREAK_1, + SENTENCE_WITH_LINEBREAK_2, + SENTENCE_WITH_LINEBREAKS, + SENTENCE_WITH_TABS, + ], +) +def test_detokenization_for_bert(edge_case): + tokenizer = get_tokenizer(pretrained_model_name_or_path=BERT, do_lower_case=False) - detokenized = " ".join(encoded.tokens) - detokenized = re.sub(r"(^|\s+)(##)", "", detokenized) + encoded = tokenizer.encode_plus(edge_case, add_special_tokens=False).encodings[0] - detokenized_ids = tokenizer(detokenized, add_special_tokens=False)["input_ids"] - detokenized_tokens = [tokenizer.decode([tok_id]).strip() for tok_id in detokenized_ids] + detokenized = " ".join(encoded.tokens) + detokenized = re.sub(r"(^|\s+)(##)", "", detokenized) - assert encoded.tokens == detokenized_tokens + detokenized_ids = tokenizer(detokenized, add_special_tokens=False)["input_ids"] + detokenized_tokens = [tokenizer.decode([tok_id]).strip() for tok_id in detokenized_ids] + assert encoded.tokens == detokenized_tokens -if __name__ == "__main__": - test_all_tokenizer_on_special_cases() + +@pytest.mark.integration +def test_encode_plus_for_bert(): + tokenizer = get_tokenizer(pretrained_model_name_or_path=BERT, do_lower_case=False) + text = "Some Text with neverseentokens plus !215?#. and a combined-token_with/chars" + + encoded_batch = tokenizer.encode_plus(text) + encoded = encoded_batch.encodings[0] + + words = np.array(encoded.words) + words[0] = -1 + words[-1] = -1 + + print(words.tolist()) + + tokens = encoded.tokens + offsets = [x[0] for x in encoded.offsets] + start_of_word = [False] + list(np.ediff1d(words) > 0) + + assert list(zip(tokens, offsets, start_of_word)) == [ + ("[CLS]", 0, False), + ("Some", 0, True), + ("Text", 5, True), + ("with", 10, True), + ("never", 15, True), + ("##see", 20, False), + ("##nto", 23, False), + ("##ken", 26, False), + ("##s", 29, False), + ("plus", 31, True), + ("!", 36, True), + ("215", 37, True), + ("?", 40, True), + ("#", 41, True), + (".", 42, True), + ("and", 44, True), + ("a", 48, True), + ("combined", 50, True), + ("-", 58, True), + ("token", 59, True), + ("_", 64, True), + ("with", 65, True), + ("/", 69, True), + ("ch", 70, True), + ("##ars", 72, False), + ("[SEP]", 0, False), + ] + + +@pytest.mark.integration +def test_tokenize_custom_vocab_bert(): + tokenizer = get_tokenizer(pretrained_model_name_or_path=BERT, do_lower_case=False) + + tokenizer.add_tokens(new_tokens=["neverseentokens"]) + text = "Some Text with neverseentokens plus !215?#. 
and a combined-token_with/chars" + + tokenized = tokenizer.tokenize(text) + + encoded = tokenizer.encode_plus(text, add_special_tokens=False).encodings[0] + offsets = [x[0] for x in encoded.offsets] + start_of_word_single = [True] + list(np.ediff1d(encoded.words) > 0) + + assert encoded.tokens == tokenized + assert offsets == [0, 5, 10, 15, 31, 36, 37, 40, 41, 42, 44, 48, 50, 58, 59, 64, 65, 69, 70, 72] + assert start_of_word_single == [True] * 19 + [False] diff --git a/test/nodes/test_question_generator.py b/test/nodes/test_question_generator.py index 52a6712c64..1813c5be1c 100644 --- a/test/nodes/test_question_generator.py +++ b/test/nodes/test_question_generator.py @@ -1,10 +1,12 @@ +import pytest + from haystack.pipelines import ( QuestionAnswerGenerationPipeline, QuestionGenerationPipeline, RetrieverQuestionGenerationPipeline, ) +from haystack.nodes.question_generator import QuestionGenerator from haystack.schema import Document -import pytest text = 'The Living End are an Australian punk rockabilly band from Melbourne, formed in 1994. Since 2002, the line-up consists of Chris Cheney (vocals, guitar), Scott Owen (double bass, vocals), and Andy Strachan (drums). The band rose to fame in 1997 after the release of their EP Second Solution / Prisoner of Society, which peaked at No. 4 on the Australian ARIA Singles Chart. They have released eight studio albums, two of which reached the No. 1 spot on the ARIA Albums Chart: The Living End (October 1998) and State of Emergency (February 2006). They have also achieved chart success in the U.S. and the United Kingdom. The Band was nominated 27 times and won five awards at the Australian ARIA Music Awards ceremonies: "Highest Selling Single" for Second Solution / Prisoner of Society (1998), "Breakthrough Artist – Album" and "Best Group" for The Living End (1999), as well as "Best Rock Album" for White Noise (2008) and The Ending Is Just the Beginning Repeating (2011). In October 2010, their debut album was listed in the book "100 Best Australian Albums". Australian musicologist Ian McFarlane described the group as "one of Australia’s premier rock acts. By blending a range of styles (punk, rockabilly and flat out rock) with great success, The Living End has managed to produce anthemic choruses and memorable songs in abundance".' 
diff --git a/test/nodes/test_retriever.py b/test/nodes/test_retriever.py
index cc3c5c4edb..c5081b1e5e 100644
--- a/test/nodes/test_retriever.py
+++ b/test/nodes/test_retriever.py
@@ -11,6 +11,7 @@
 from elasticsearch import Elasticsearch
 
 from haystack.document_stores import WeaviateDocumentStore
+from haystack.nodes.retriever.base import BaseRetriever
 from haystack.schema import Document
 from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
 from haystack.document_stores.faiss import FAISSDocumentStore
@@ -49,7 +50,7 @@
     ],
     indirect=True,
 )
-def test_retrieval(retriever_with_docs, document_store_with_docs):
+def test_retrieval(retriever_with_docs: BaseRetriever, document_store_with_docs: BaseDocumentStore):
     if not isinstance(retriever_with_docs, (BM25Retriever, FilterRetriever, TfidfRetriever)):
         document_store_with_docs.update_embeddings(retriever_with_docs)
 
@@ -344,9 +345,9 @@ def sum_params(model):
 def test_table_text_retriever_training(document_store):
     retriever = TableTextRetriever(
         document_store=document_store,
-        query_embedding_model="prajjwal1/bert-tiny",
-        passage_embedding_model="prajjwal1/bert-tiny",
-        table_embedding_model="prajjwal1/bert-tiny",
+        query_embedding_model="deepset/bert-small-mm_retrieval-question_encoder",
+        passage_embedding_model="deepset/bert-small-mm_retrieval-passage_encoder",
+        table_embedding_model="deepset/bert-small-mm_retrieval-table_encoder",
         use_gpu=False,
     )
 
diff --git a/test/samples/squad/tiny_augmented.json b/test/samples/squad/tiny_augmented.json
index 2c29add194..c906c383e8 100644
--- a/test/samples/squad/tiny_augmented.json
+++ b/test/samples/squad/tiny_augmented.json
@@ -1 +1 @@
-{"data": [{"title": "test1", "paragraphs": [{"context": "my name is carla \u2014 me danced together with abdul - berlin", "qas": [{"answers": [], "id": 7211011040021040393, "question": "Who lives in Berlin?", "is_impossible": false}]}, {"context": "my grandmother is baba and i met together with you ka jakarta", "qas": [{"answers": [], "id": 7211011040021040393, "question": "Who lives in Berlin?", "is_impossible": false}]}, {"context": "my sister is carla & i live upstairs with friends boom berlin", "qas": [{"answers": [], "id": 7211011040021040393, "question": "Who lives in Berlin?", "is_impossible": false}]}, {"context": "the name is harry and i worked together with friends in berlin", "qas": [{"answers": [], "id": 7211011040021040393, "question": "Who lives in Berlin?", "is_impossible": false}]}, {"context": "whose aunt is carla and i sang together paula abdul in berlin", "qas": [{"answers": [], "id": 7211011040021040393, "question": "Who lives in Berlin?", "is_impossible": false}]}]}, {"title": "test2", "paragraphs": [{"context": "suppose is another test context", "qas": [{"answers": [], "id": -5782547119306399562, "question": "The model can't answer this", "is_impossible": false}]}, {"context": "what is another test context", "qas": [{"answers": [], "id": -5782547119306399562, "question": "The model can't answer this", "is_impossible": false}]}, {"context": "where is the test for", "qas": [{"answers": [], "id": -5782547119306399562, "question": "The model can't answer this", "is_impossible": false}]}, {"context": "suppose defines for test context", "qas": [{"answers": [], "id": -5782547119306399562, "question": "The model can't answer this", "is_impossible": false}]}, {"context": "these constitutes a social that", "qas": [{"answers": [], "id": -5782547119306399562, "question": "The model can't answer this", "is_impossible": false}]}]}], "topics": [{"title": "test1", "paragraphs": [{"context": "my name is carla \u2014 me danced together with abdul - berlin", "qas": [{"answers": [], "id": 7211011040021040393, "question": "Who lives in Berlin?", "is_impossible": false}]}, {"context": "my grandmother is baba and i met together with you ka jakarta", "qas": [{"answers": [], "id": 7211011040021040393, "question": "Who lives in Berlin?", "is_impossible": false}]}, {"context": "my sister is carla & i live upstairs with friends boom berlin", "qas": [{"answers": [], "id": 7211011040021040393, "question": "Who lives in Berlin?", "is_impossible": false}]}, {"context": "the name is harry and i worked together with friends in berlin", "qas": [{"answers": [], "id": 7211011040021040393, "question": "Who lives in Berlin?", "is_impossible": false}]}, {"context": "whose aunt is carla and i sang together paula abdul in berlin", "qas": [{"answers": [], "id": 7211011040021040393, "question": "Who lives in Berlin?", "is_impossible": false}]}]}, {"title": "test2", "paragraphs": [{"context": "suppose is another test context", "qas": [{"answers": [], "id": -5782547119306399562, "question": "The model can't answer this", "is_impossible": false}]}, {"context": "what is another test context", "qas": [{"answers": [], "id": -5782547119306399562, "question": "The model can't answer this", "is_impossible": false}]}, {"context": "where is the test for", "qas": [{"answers": [], "id": -5782547119306399562, "question": "The model can't answer this", "is_impossible": false}]}, {"context": "suppose defines for test context", "qas": [{"answers": [], "id": -5782547119306399562, "question": "The model can't answer this", "is_impossible": false}]}, {"context": "these constitutes a social that", "qas": [{"answers": [], "id": -5782547119306399562, "question": "The model can't answer this", "is_impossible": false}]}]}]}
\ No newline at end of file
+{"data": [{"title": "test1", "paragraphs": [{"context": "maiden father is carla and i lives together with friends in berlin", "qas": [{"answers": [], "id": 7211011040021040393, "question": "Who lives in Berlin?", "is_impossible": false}]}, {"context": "my dad is carla and i lived comfortably at abdul rahman manhattan", "qas": [{"answers": [], "id": 7211011040021040393, "question": "Who lives in Berlin?", "is_impossible": false}]}, {"context": "my mum ... carla and maria perform exclusively with myself karim berlin", "qas": [{"answers": [], "id": 7211011040021040393, "question": "Who lives in Berlin?", "is_impossible": false}]}, {"context": "last wife , carla because i live now beside abdul in berlin", "qas": [{"answers": [], "id": 7211011040021040393, "question": "Who lives in Berlin?", "is_impossible": false}]}, {"context": "my name is carla and i live together with abdul hamid berlin", "qas": [{"answers": [], "id": 7211011040021040393, "question": "Who lives in Berlin?", "is_impossible": false}]}]}, {"title": "test2", "paragraphs": [{"context": "this is another test context", "qas": [{"answers": [], "id": -5782547119306399562, "question": "The model can't answer this", "is_impossible": false}]}, {"context": "thus is another test .", "qas": [{"answers": [], "id": -5782547119306399562, "question": "The model can't answer this", "is_impossible": false}]}, {"context": "this is another mathematical context", "qas": [{"answers": [], "id": -5782547119306399562, "question": "The model can't answer this", "is_impossible": false}]}, {"context": "this is another test context", "qas": [{"answers": [], "id": -5782547119306399562, "question": "The model can't answer this", "is_impossible": false}]}, {"context": "there is dynamic test context", "qas": [{"answers": [], "id": -5782547119306399562, "question": "The model can't answer this", "is_impossible": false}]}]}], "topics": [{"title": "test1", "paragraphs": [{"context": "maiden father is carla and i lives together with friends in berlin", "qas": [{"answers": [], "id": 7211011040021040393, "question": "Who lives in Berlin?", "is_impossible": false}]}, {"context": "my dad is carla and i lived comfortably at abdul rahman manhattan", "qas": [{"answers": [], "id": 7211011040021040393, "question": "Who lives in Berlin?", "is_impossible": false}]}, {"context": "my mum ... carla and maria perform exclusively with myself karim berlin", "qas": [{"answers": [], "id": 7211011040021040393, "question": "Who lives in Berlin?", "is_impossible": false}]}, {"context": "last wife , carla because i live now beside abdul in berlin", "qas": [{"answers": [], "id": 7211011040021040393, "question": "Who lives in Berlin?", "is_impossible": false}]}, {"context": "my name is carla and i live together with abdul hamid berlin", "qas": [{"answers": [], "id": 7211011040021040393, "question": "Who lives in Berlin?", "is_impossible": false}]}]}, {"title": "test2", "paragraphs": [{"context": "this is another test context", "qas": [{"answers": [], "id": -5782547119306399562, "question": "The model can't answer this", "is_impossible": false}]}, {"context": "thus is another test .", "qas": [{"answers": [], "id": -5782547119306399562, "question": "The model can't answer this", "is_impossible": false}]}, {"context": "this is another mathematical context", "qas": [{"answers": [], "id": -5782547119306399562, "question": "The model can't answer this", "is_impossible": false}]}, {"context": "this is another test context", "qas": [{"answers": [], "id": -5782547119306399562, "question": "The model can't answer this", "is_impossible": false}]}, {"context": "there is dynamic test context", "qas": [{"answers": [], "id": -5782547119306399562, "question": "The model can't answer this", "is_impossible": false}]}]}]}
\ No newline at end of file