This repository has been archived by the owner on Nov 21, 2022. It is now read-only.

CI Fixes + Fix Token Classification (#265)
Sean Naren authored Jun 23, 2022
1 parent e51ce01 commit 1d81230
Showing 12 changed files with 27 additions and 55 deletions.
2 changes: 1 addition & 1 deletion docs/source/tasks/nlp/token_classification.rst
@@ -35,7 +35,7 @@ Training
revision="master",
tokenizer=tokenizer,
)
-model = TokenClassificationTransformer(pretrained_model_name_or_path="bert-base-uncased", labels=dm.labels)
+model = TokenClassificationTransformer(pretrained_model_name_or_path="bert-base-uncased", labels=dm.num_classes)
trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)
trainer.fit(model, dm)
2 changes: 1 addition & 1 deletion examples/token_classification.py
@@ -17,7 +17,7 @@
revision="master",
tokenizer=tokenizer,
)
-model = TokenClassificationTransformer(pretrained_model_name_or_path="bert-base-uncased", labels=dm.labels)
+model = TokenClassificationTransformer(pretrained_model_name_or_path="bert-base-uncased", labels=dm.num_classes)
trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)

trainer.fit(model, dm)
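
For reference, a minimal sketch of the corrected end-to-end usage after this fix. Only the revision/tokenizer arguments and the model/trainer lines appear in the diff above; the remaining data-module arguments (dataset name, task name, batch size) are assumptions taken from the stock example, not part of this change.

import pytorch_lightning as pl
from transformers import AutoTokenizer
from lightning_transformers.task.nlp.token_classification import (
    TokenClassificationDataModule,
    TokenClassificationTransformer,
)

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
dm = TokenClassificationDataModule(
    batch_size=1,              # assumed
    task_name="ner",           # assumed
    dataset_name="conll2003",  # assumed
    revision="master",
    tokenizer=tokenizer,
)
# The fix: the model takes the class count (dm.num_classes), not the label list (dm.labels).
model = TokenClassificationTransformer(
    pretrained_model_name_or_path="bert-base-uncased",
    labels=dm.num_classes,
)
trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)
trainer.fit(model, dm)
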
3 changes: 2 additions & 1 deletion examples/translation_wmt.py
@@ -4,7 +4,7 @@
from lightning_transformers.task.nlp.translation import TranslationTransformer, WMT16TranslationDataModule

if __name__ == "__main__":
-tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="t5-base")
+tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="t5-base", model_max_length=512)
model = TranslationTransformer(
pretrained_model_name_or_path="t5-base",
n_gram=4,
@@ -20,6 +20,7 @@
target_language="ro",
max_source_length=128,
max_target_length=128,
+padding="max_length",
tokenizer=tokenizer,
)
trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)
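
The two additions above work together: model_max_length pins the tokenizer's sequence-length cap explicitly instead of relying on the checkpoint default, and padding="max_length" pads every example to a fixed length so collated batches have uniform shapes. A hedged sketch of the resulting setup; the dataset config and source language are assumptions from the stock example.

from transformers import AutoTokenizer
from lightning_transformers.task.nlp.translation import WMT16TranslationDataModule

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="t5-base",
    model_max_length=512,  # explicit cap on tokenized sequence length
)
dm = WMT16TranslationDataModule(
    dataset_config_name="ro-en",  # assumed
    source_language="en",         # assumed
    target_language="ro",
    max_source_length=128,
    max_target_length=128,
    padding="max_length",  # fixed-length padding -> uniform batch shapes
    tokenizer=tokenizer,
)
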
9 changes: 0 additions & 9 deletions lightning_transformers/core/data.py
@@ -212,12 +212,3 @@ def predict_dataloader(self) -> Optional[DataLoader]:
@property
def collate_fn(self) -> Optional[Callable]:
return None

-@property
-def model_data_kwargs(self) -> Dict:
-    """Override to provide the model with additional kwargs.
-    This is useful to provide the number of classes/pixels to the model or any other data specific args
-    Returns: Dict of args
-    """
-    return {}
1 change: 0 additions & 1 deletion lightning_transformers/core/model.py
@@ -42,7 +42,6 @@ class TaskTransformer(pl.LightningModule):
pretrained_model_name_or_path: Huggingface model to use if backbone config not passed.
tokenizer: The pre-trained tokenizer.
pipeline_kwargs: Arguments required for the HuggingFace inference pipeline class.
-**model_data_kwargs: Arguments passed from the data module to the class.
"""

def __init__(
Expand Down
21 changes: 11 additions & 10 deletions lightning_transformers/core/seq2seq/data.py
@@ -18,10 +18,9 @@ class Seq2SeqDataModule(TransformerDataModule):
def __init__(
self, *args, max_target_length: int = 128, max_source_length: int = 1024, padding: str = "longest", **kwargs
) -> None:
-super().__init__(*args, **kwargs)
+super().__init__(*args, padding=padding, **kwargs)
self.max_target_length = max_target_length
self.max_source_length = max_source_length
-self.padding = padding

def process_data(self, dataset: Dataset, stage: Optional[str] = None) -> Dataset:
src_text_column_name, tgt_text_column_name = self.source_target_column_names
@@ -60,14 +59,16 @@ def convert_to_features(
src_text_column_name: str,
tgt_text_column_name: str,
):
-encoded_results = tokenizer.prepare_seq2seq_batch(
-    src_texts=examples[src_text_column_name],
-    tgt_texts=examples[tgt_text_column_name],
-    max_length=max_source_length,
-    max_target_length=max_target_length,
-    padding=padding,
-)
-return encoded_results
+inputs = examples[src_text_column_name]
+targets = examples[tgt_text_column_name]
+model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)
+
+# Setup the tokenizer for targets
+with tokenizer.as_target_tokenizer():
+    labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
+
+model_inputs["labels"] = labels["input_ids"]
+return model_inputs

@property
def collate_fn(self) -> Callable:
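
The rewritten convert_to_features drops the deprecated prepare_seq2seq_batch helper in favour of two ordinary tokenizer calls. A small standalone sketch of the same pattern; the t5-small checkpoint and the sample sentences are illustrative only.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")

sources = ["translate English to Romanian: Hello world"]
targets = ["Salut lume"]

# Source texts go through the ordinary tokenizer call.
model_inputs = tokenizer(sources, max_length=128, padding="max_length", truncation=True)

# Target texts are encoded inside as_target_tokenizer(), so tokenizers with separate
# target vocabularies or special tokens produce correct label ids.
with tokenizer.as_target_tokenizer():
    labels = tokenizer(targets, max_length=128, padding="max_length", truncation=True)

# The encoded target ids become the labels the seq2seq model trains against.
model_inputs["labels"] = labels["input_ids"]
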
5 changes: 0 additions & 5 deletions lightning_transformers/task/nlp/multiple_choice/data.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict

from transformers import default_data_collator

@@ -35,7 +34,3 @@ def collate_fn(self) -> callable:
@property
def num_classes(self) -> int:
raise NotImplementedError

-@property
-def model_data_kwargs(self) -> Dict[str, int]:
-    return {"num_labels": self.num_classes}
6 changes: 1 addition & 5 deletions lightning_transformers/task/nlp/text_classification/data.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional
+from typing import Any, List, Optional

from datasets import ClassLabel, Dataset
from pytorch_lightning.utilities import rank_zero_warn
@@ -53,10 +53,6 @@ def num_classes(self) -> int:
self.setup("fit")
return self.labels.num_classes

-@property
-def model_data_kwargs(self) -> Dict[str, int]:
-    return {"num_labels": self.num_classes}

@staticmethod
def convert_to_features(
example_batch: Any, _, tokenizer: PreTrainedTokenizerBase, input_feature_fields: List[str], **tokenizer_kwargs
6 changes: 1 addition & 5 deletions lightning_transformers/task/nlp/token_classification/data.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Optional

from datasets import ClassLabel, Dataset
from pytorch_lightning.utilities import rank_zero_warn
@@ -90,10 +90,6 @@ def num_classes(self) -> int:
self.setup("fit")
return len(self.labels)

-@property
-def model_data_kwargs(self) -> Dict[str, Any]:
-    return {"labels": self.labels}

@staticmethod
def convert_to_features(
examples: Any,
19 changes: 8 additions & 11 deletions lightning_transformers/task/nlp/translation/datasets/wmt16.py
@@ -32,16 +32,13 @@ def convert_to_features(
src_text_column_name: str,
tgt_text_column_name: str,
):
-translations = examples["translation"]  # Extract translations from dict
+inputs = [ex[src_text_column_name] for ex in examples["translation"]]
+targets = [ex[tgt_text_column_name] for ex in examples["translation"]]
+model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)

-def extract_text(lang):
-    return [text[lang] for text in translations]
+# Setup the tokenizer for targets
+with tokenizer.as_target_tokenizer():
+    labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)

-encoded_results = tokenizer.prepare_seq2seq_batch(
-    src_texts=extract_text(src_text_column_name),
-    tgt_texts=extract_text(tgt_text_column_name),
-    max_length=max_source_length,
-    max_target_length=max_target_length,
-    padding=padding,
-)
-return encoded_results
+model_inputs["labels"] = labels["input_ids"]
+return model_inputs
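
For context, wmt16 batches arrive as a list of per-example dicts keyed by language code, and the two comprehensions above simply split that structure into source and target lists. An illustrative, made-up batch:

examples = {
    "translation": [
        {"en": "Hello world", "ro": "Salut lume"},
        {"en": "Good morning", "ro": "Bună dimineața"},
    ]
}

src_text_column_name, tgt_text_column_name = "en", "ro"
inputs = [ex[src_text_column_name] for ex in examples["translation"]]   # ["Hello world", "Good morning"]
targets = [ex[tgt_text_column_name] for ex in examples["translation"]]  # ["Salut lume", "Bună dimineața"]
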
2 changes: 1 addition & 1 deletion lightning_transformers/task/nlp/translation/model.py
@@ -53,7 +53,7 @@ def compute_generate_metrics(self, batch, prefix):
tgt_lns = self.tokenize_labels(batch["labels"])
pred_lns = self.generate(batch["input_ids"], batch["attention_mask"])
# wrap targets in list as score expects a list of potential references
-result = self.bleu(pred_lns, tgt_lns)
+result = self.bleu(preds=pred_lns, target=tgt_lns)
self.log(f"{prefix}_bleu_score", result, on_step=False, on_epoch=True, prog_bar=True)

def configure_metrics(self, stage: str):
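
The metric is now called with keyword arguments, matching the preds/target names used by torchmetrics' BLEUScore in recent releases. A minimal standalone sketch, assuming a torchmetrics version with the preds/target signature (roughly >=0.8); the example strings are illustrative only.

from torchmetrics import BLEUScore

bleu = BLEUScore(n_gram=4)
pred_lns = ["the cat sat on the mat"]
# each prediction is scored against a list of candidate references, hence the nesting
tgt_lns = [["the cat sat on the mat", "a cat was sitting on the mat"]]
score = bleu(preds=pred_lns, target=tgt_lns)
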
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Optional

from datasets import ClassLabel, Dataset
from pytorch_lightning.utilities import rank_zero_warn
@@ -52,7 +52,3 @@ def num_classes(self) -> int:
rank_zero_warn("Labels has not been set, calling `setup('fit')`.")
self.setup("fit")
return self.labels.num_classes

-@property
-def model_data_kwargs(self) -> Dict[str, int]:
-    return {"num_labels": self.num_classes}
