diff --git a/docs/source/tasks/nlp/token_classification.rst b/docs/source/tasks/nlp/token_classification.rst
index 0d60ffc7..b35ae844 100644
--- a/docs/source/tasks/nlp/token_classification.rst
+++ b/docs/source/tasks/nlp/token_classification.rst
@@ -35,7 +35,7 @@ Training
         revision="master",
         tokenizer=tokenizer,
     )
-    model = TokenClassificationTransformer(pretrained_model_name_or_path="bert-base-uncased", labels=dm.labels)
+    model = TokenClassificationTransformer(pretrained_model_name_or_path="bert-base-uncased", labels=dm.num_classes)
     trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)

     trainer.fit(model, dm)
diff --git a/examples/token_classification.py b/examples/token_classification.py
index 5846f760..9bd9a5e8 100644
--- a/examples/token_classification.py
+++ b/examples/token_classification.py
@@ -17,7 +17,7 @@
         revision="master",
         tokenizer=tokenizer,
     )
-    model = TokenClassificationTransformer(pretrained_model_name_or_path="bert-base-uncased", labels=dm.labels)
+    model = TokenClassificationTransformer(pretrained_model_name_or_path="bert-base-uncased", labels=dm.num_classes)
     trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)

     trainer.fit(model, dm)
diff --git a/examples/translation_wmt.py b/examples/translation_wmt.py
index 62688852..89f956aa 100644
--- a/examples/translation_wmt.py
+++ b/examples/translation_wmt.py
@@ -4,7 +4,7 @@
 from lightning_transformers.task.nlp.translation import TranslationTransformer, WMT16TranslationDataModule

 if __name__ == "__main__":
-    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="t5-base")
+    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="t5-base", model_max_length=512)
     model = TranslationTransformer(
         pretrained_model_name_or_path="t5-base",
         n_gram=4,
@@ -20,6 +20,7 @@
         target_language="ro",
         max_source_length=128,
         max_target_length=128,
+        padding="max_length",
         tokenizer=tokenizer,
     )
     trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)
diff --git a/lightning_transformers/core/data.py b/lightning_transformers/core/data.py
index c38398af..f13b5a6c 100644
--- a/lightning_transformers/core/data.py
+++ b/lightning_transformers/core/data.py
@@ -212,12 +212,3 @@ def predict_dataloader(self) -> Optional[DataLoader]:
     @property
     def collate_fn(self) -> Optional[Callable]:
         return None
-
-    @property
-    def model_data_kwargs(self) -> Dict:
-        """Override to provide the model with additional kwargs.
-
-        This is useful to provide the number of classes/pixels to the model or any other data specific args
-        Returns: Dict of args
-        """
-        return {}
diff --git a/lightning_transformers/core/model.py b/lightning_transformers/core/model.py
index eaedb9d4..f956b2bc 100644
--- a/lightning_transformers/core/model.py
+++ b/lightning_transformers/core/model.py
@@ -42,7 +42,6 @@ class TaskTransformer(pl.LightningModule):
         pretrained_model_name_or_path: Huggingface model to use if backbone config not passed.
         tokenizer: The pre-trained tokenizer.
         pipeline_kwargs: Arguments required for the HuggingFace inference pipeline class.
-        **model_data_kwargs: Arguments passed from the data module to the class.
     """

     def __init__(
diff --git a/lightning_transformers/core/seq2seq/data.py b/lightning_transformers/core/seq2seq/data.py
index 2c007ff6..5d6dd38d 100644
--- a/lightning_transformers/core/seq2seq/data.py
+++ b/lightning_transformers/core/seq2seq/data.py
@@ -18,10 +18,9 @@ class Seq2SeqDataModule(TransformerDataModule):
     def __init__(
         self, *args, max_target_length: int = 128, max_source_length: int = 1024, padding: str = "longest", **kwargs
     ) -> None:
-        super().__init__(*args, **kwargs)
+        super().__init__(*args, padding=padding, **kwargs)
         self.max_target_length = max_target_length
         self.max_source_length = max_source_length
-        self.padding = padding

     def process_data(self, dataset: Dataset, stage: Optional[str] = None) -> Dataset:
         src_text_column_name, tgt_text_column_name = self.source_target_column_names
@@ -60,14 +59,16 @@ def convert_to_features(
         src_text_column_name: str,
         tgt_text_column_name: str,
     ):
-        encoded_results = tokenizer.prepare_seq2seq_batch(
-            src_texts=examples[src_text_column_name],
-            tgt_texts=examples[tgt_text_column_name],
-            max_length=max_source_length,
-            max_target_length=max_target_length,
-            padding=padding,
-        )
-        return encoded_results
+        inputs = examples[src_text_column_name]
+        targets = examples[tgt_text_column_name]
+        model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)
+
+        # Setup the tokenizer for targets
+        with tokenizer.as_target_tokenizer():
+            labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
+
+        model_inputs["labels"] = labels["input_ids"]
+        return model_inputs

     @property
     def collate_fn(self) -> Callable:
diff --git a/lightning_transformers/task/nlp/multiple_choice/data.py b/lightning_transformers/task/nlp/multiple_choice/data.py
index a218d4f2..c23fbd7e 100644
--- a/lightning_transformers/task/nlp/multiple_choice/data.py
+++ b/lightning_transformers/task/nlp/multiple_choice/data.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict

 from transformers import default_data_collator

@@ -35,7 +34,3 @@ def collate_fn(self) -> callable:
     @property
     def num_classes(self) -> int:
         raise NotImplementedError
-
-    @property
-    def model_data_kwargs(self) -> Dict[str, int]:
-        return {"num_labels": self.num_classes}
diff --git a/lightning_transformers/task/nlp/text_classification/data.py b/lightning_transformers/task/nlp/text_classification/data.py
index 46a16fa8..2f244b6a 100644
--- a/lightning_transformers/task/nlp/text_classification/data.py
+++ b/lightning_transformers/task/nlp/text_classification/data.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, List, Optional
+from typing import Any, List, Optional

 from datasets import ClassLabel, Dataset
 from pytorch_lightning.utilities import rank_zero_warn
@@ -53,10 +53,6 @@ def num_classes(self) -> int:
             self.setup("fit")
         return self.labels.num_classes

-    @property
-    def model_data_kwargs(self) -> Dict[str, int]:
-        return {"num_labels": self.num_classes}
-
     @staticmethod
     def convert_to_features(
         example_batch: Any, _, tokenizer: PreTrainedTokenizerBase, input_feature_fields: List[str], **tokenizer_kwargs
diff --git a/lightning_transformers/task/nlp/token_classification/data.py b/lightning_transformers/task/nlp/token_classification/data.py
index b918a7d0..351e8938 100644
--- a/lightning_transformers/task/nlp/token_classification/data.py
+++ b/lightning_transformers/task/nlp/token_classification/data.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import partial
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Optional

 from datasets import ClassLabel, Dataset
 from pytorch_lightning.utilities import rank_zero_warn
@@ -90,10 +90,6 @@ def num_classes(self) -> int:
             self.setup("fit")
         return len(self.labels)

-    @property
-    def model_data_kwargs(self) -> Dict[str, Any]:
-        return {"labels": self.labels}
-
     @staticmethod
     def convert_to_features(
         examples: Any,
diff --git a/lightning_transformers/task/nlp/translation/datasets/wmt16.py b/lightning_transformers/task/nlp/translation/datasets/wmt16.py
index 1f7157e3..59f553f3 100644
--- a/lightning_transformers/task/nlp/translation/datasets/wmt16.py
+++ b/lightning_transformers/task/nlp/translation/datasets/wmt16.py
@@ -32,16 +32,13 @@ def convert_to_features(
         src_text_column_name: str,
         tgt_text_column_name: str,
     ):
-        translations = examples["translation"]  # Extract translations from dict
+        inputs = [ex[src_text_column_name] for ex in examples["translation"]]
+        targets = [ex[tgt_text_column_name] for ex in examples["translation"]]
+        model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)

-        def extract_text(lang):
-            return [text[lang] for text in translations]
+        # Setup the tokenizer for targets
+        with tokenizer.as_target_tokenizer():
+            labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)

-        encoded_results = tokenizer.prepare_seq2seq_batch(
-            src_texts=extract_text(src_text_column_name),
-            tgt_texts=extract_text(tgt_text_column_name),
-            max_length=max_source_length,
-            max_target_length=max_target_length,
-            padding=padding,
-        )
-        return encoded_results
+        model_inputs["labels"] = labels["input_ids"]
+        return model_inputs
diff --git a/lightning_transformers/task/nlp/translation/model.py b/lightning_transformers/task/nlp/translation/model.py
index 4a60af10..54f1a68d 100644
--- a/lightning_transformers/task/nlp/translation/model.py
+++ b/lightning_transformers/task/nlp/translation/model.py
@@ -53,7 +53,7 @@ def compute_generate_metrics(self, batch, prefix):
         tgt_lns = self.tokenize_labels(batch["labels"])
         pred_lns = self.generate(batch["input_ids"], batch["attention_mask"])
         # wrap targets in list as score expects a list of potential references
-        result = self.bleu(pred_lns, tgt_lns)
+        result = self.bleu(preds=pred_lns, target=tgt_lns)
         self.log(f"{prefix}_bleu_score", result, on_step=False, on_epoch=True, prog_bar=True)

     def configure_metrics(self, stage: str):
diff --git a/lightning_transformers/task/vision/image_classification/data.py b/lightning_transformers/task/vision/image_classification/data.py
index f6db3193..5acf53ab 100644
--- a/lightning_transformers/task/vision/image_classification/data.py
+++ b/lightning_transformers/task/vision/image_classification/data.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Optional

 from datasets import ClassLabel, Dataset
 from pytorch_lightning.utilities import rank_zero_warn
@@ -52,7 +52,3 @@ def num_classes(self) -> int:
         rank_zero_warn("Labels has not been set, calling `setup('fit')`.")
         self.setup("fit")
         return self.labels.num_classes
-
-    @property
-    def model_data_kwargs(self) -> Dict[str, int]:
-        return {"num_labels": self.num_classes}
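
For context, a minimal standalone sketch of the tokenization pattern the Seq2SeqDataModule and WMT16 hunks move to: the deprecated `prepare_seq2seq_batch` call is replaced by a plain tokenizer call for the source text plus an `as_target_tokenizer()` block for the labels. The checkpoint name, example sentences, and lengths below are illustrative assumptions, not values from this diff.

    # Sketch only: assumes `transformers` is installed and the "t5-small" checkpoint is reachable.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-small", model_max_length=512)

    inputs = ["translate English to Romanian: The weather is nice."]  # hypothetical source batch
    targets = ["Vremea este frumoasa."]  # hypothetical reference batch

    # Source side: an ordinary tokenizer call, mirroring the new convert_to_features
    model_inputs = tokenizer(inputs, max_length=128, padding="max_length", truncation=True)

    # Target side: tokenize inside the target-tokenizer context, then attach as labels
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=128, padding="max_length", truncation=True)

    model_inputs["labels"] = labels["input_ids"]

The resulting `model_inputs` dict carries `input_ids`, `attention_mask`, and `labels`, which is the shape the seq2seq collator and model in this repository expect.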