Commit

added logger, save model only at last epoch, compatible with latest transformers + pytorch lightning
Shivanandroy committed Feb 15, 2022
1 parent 94ad898 commit cb4d892
Showing 4 changed files with 90 additions and 83 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -27,6 +27,7 @@ Here's a link to [Medium article](https://snrspeaks.medium.com/simplet5-train-t5

## Install
```bash
# It's advisable to create a new python environment and install simplet5
pip install --upgrade simplet5
```
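
For orientation, here is a minimal quickstart sketch of the API this commit touches, pieced together from the `simplet5/simplet5.py` diff below. The toy dataframe is hypothetical and the single-epoch CPU run is purely illustrative:

```python
# Minimal quickstart sketch, assuming the SimpleT5 API shown in the
# simplet5/simplet5.py diff below; the toy dataframe is hypothetical.
import glob

import pandas as pd
from simplet5 import SimpleT5

# SimpleT5 expects exactly two columns: "source_text" and "target_text"
df = pd.DataFrame(
    {
        "source_text": ["summarize: The quick brown fox jumps over the lazy dog."],
        "target_text": ["A fox jumps over a dog."],
    }
)

model = SimpleT5()
model.from_pretrained(model_type="t5", model_name="t5-base")
model.train(train_df=df, eval_df=df, max_epochs=1, use_gpu=False)

# train() writes checkpoint folders under "outputs/"; pick the newest
# one and load it back for inference
last_ckpt = sorted(glob.glob("outputs/simplet5-epoch-*"))[-1]
model.load_model("t5", last_ckpt, use_gpu=False)
print(model.predict("summarize: The quick brown fox jumps over the lazy dog."))
```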

5 changes: 3 additions & 2 deletions requirements.txt
@@ -1,5 +1,6 @@
numpy
pandas
sentencepiece
torch>=1.7.0,!=1.8.0
transformers==4.10.0
pytorch-lightning==1.4.5
transformers==4.16.2
pytorch-lightning==1.5.10
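
As an optional sanity check (not part of this commit), you can confirm the upgraded pins took effect in your environment; both libraries expose a `__version__` attribute:

```python
# Optional sanity check that the pinned versions are what got installed
import pytorch_lightning
import transformers

print(transformers.__version__)       # expect 4.16.2
print(pytorch_lightning.__version__)  # expect 1.5.10
```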
10 changes: 5 additions & 5 deletions setup.py
@@ -8,7 +8,7 @@

setuptools.setup(
name="simplet5",
version="0.1.3",
version="0.1.4",
license="apache-2.0",
author="Shivanand Roy",
author_email="[email protected]",
@@ -39,12 +39,12 @@
packages=setuptools.find_packages(),
python_requires=">=3.5",
install_requires=[
"numpy",
"pandas",
"sentencepiece",
"torch>=1.7.0,!=1.8.0", # excludes torch v1.8.0
"transformers==4.10.0",
"pytorch-lightning==1.4.5",
"tqdm"
# "fastt5==0.0.7",
"transformers==4.16.2",
"pytorch-lightning==1.5.10",
],
classifiers=[
"Intended Audience :: Developers",
157 changes: 81 additions & 76 deletions simplet5/simplet5.py
@@ -1,9 +1,7 @@
import torch
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from transformers import (
AdamW,
T5ForConditionalGeneration,
MT5ForConditionalGeneration,
ByT5Tokenizer,
@@ -12,14 +10,13 @@
MT5TokenizerFast as MT5Tokenizer,
)
from transformers import AutoTokenizer

# from fastT5 import export_and_get_onnx_model
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelWithLMHead, AutoTokenizer
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.progress import TQDMProgressBar

torch.cuda.empty_cache()
pl.seed_everything(42)
@@ -37,7 +34,6 @@ def __init__(
):
"""
initiates a PyTorch Dataset Module for input data
Args:
data (pd.DataFrame): input pandas dataframe. Dataframe must have 2 columns --> "source_text" and "target_text"
tokenizer (PreTrainedTokenizer): a PreTrainedTokenizer (T5Tokenizer, MT5Tokenizer, or ByT5Tokenizer)
@@ -85,8 +81,6 @@ def __getitem__(self, index: int):
] = -100 # to make sure we have correct labels for T5 text generation

return dict(
source_text=source_text,
target_text=data_row["target_text"],
source_text_input_ids=source_text_encoding["input_ids"].flatten(),
source_text_attention_mask=source_text_encoding["attention_mask"].flatten(),
labels=labels.flatten(),
@@ -105,10 +99,10 @@ def __init__(
batch_size: int = 4,
source_max_token_len: int = 512,
target_max_token_len: int = 512,
num_workers: int = 2,
):
"""
initiates a PyTorch Lightning Data Module
Args:
train_df (pd.DataFrame): training dataframe. Dataframe must contain 2 columns --> "source_text" & "target_text"
test_df (pd.DataFrame): validation dataframe. Dataframe must contain 2 columns --> "source_text" & "target_text"
@@ -125,6 +119,7 @@ def __init__(
self.tokenizer = tokenizer
self.source_max_token_len = source_max_token_len
self.target_max_token_len = target_max_token_len
self.num_workers = num_workers

def setup(self, stage=None):
self.train_dataset = PyTorchDataModule(
@@ -143,38 +138,56 @@ def setup(self, stage=None):
def train_dataloader(self):
""" training dataloader """
return DataLoader(
self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=2
self.train_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
)

def test_dataloader(self):
""" test dataloader """
return DataLoader(
self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2
self.test_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
)

def val_dataloader(self):
""" validation dataloader """
return DataLoader(
self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2
self.test_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
)


class LightningModel(pl.LightningModule):
""" PyTorch Lightning Model class"""

def __init__(self, tokenizer, model, outputdir: str = "outputs"):
def __init__(
self,
tokenizer,
model,
outputdir: str = "outputs",
save_only_last_epoch: bool = False,
):
"""
initiates a PyTorch Lightning Model
Args:
tokenizer : T5/MT5/ByT5 tokenizer
model : T5/MT5/ByT5 model
outputdir (str, optional): output directory to save model checkpoints. Defaults to "outputs".
save_only_last_epoch (bool, optional): If True, saves the model only at the last epoch; otherwise a model is saved at every epoch. Defaults to False.
"""
super().__init__()
self.model = model
self.tokenizer = tokenizer
self.outputdir = outputdir
self.average_training_loss = None
self.average_validation_loss = None
self.save_only_last_epoch = save_only_last_epoch

def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
""" forward step """
@@ -201,7 +214,9 @@ def training_step(self, batch, batch_size):
labels=labels,
)

self.log("train_loss", loss, prog_bar=True, logger=True)
self.log(
"train_loss", loss, prog_bar=True, logger=True, on_epoch=True, on_step=True
)
return loss

def validation_step(self, batch, batch_size):
@@ -218,7 +233,9 @@ def validation_step(self, batch, batch_size):
labels=labels,
)

self.log("val_loss", loss, prog_bar=True, logger=True)
self.log(
"val_loss", loss, prog_bar=True, logger=True, on_epoch=True, on_step=True
)
return loss

def test_step(self, batch, batch_size):
@@ -244,19 +261,25 @@ def configure_optimizers(self):

def training_epoch_end(self, training_step_outputs):
""" save tokenizer and model on epoch end """
avg_traning_loss = np.round(
self.average_training_loss = np.round(
torch.mean(torch.stack([x["loss"] for x in training_step_outputs])).item(),
4,
)
path = f"{self.outputdir}/simplet5-epoch-{self.current_epoch}-train-loss-{str(avg_traning_loss)}"
self.tokenizer.save_pretrained(path)
self.model.save_pretrained(path)
path = f"{self.outputdir}/simplet5-epoch-{self.current_epoch}-train-loss-{str(self.average_training_loss)}-val-loss-{str(self.average_validation_loss)}"
if self.save_only_last_epoch:
if self.current_epoch == self.trainer.max_epochs - 1:
self.tokenizer.save_pretrained(path)
self.model.save_pretrained(path)
else:
self.tokenizer.save_pretrained(path)
self.model.save_pretrained(path)

# def validation_epoch_end(self, validation_step_outputs):
# # val_loss = torch.stack([x['loss'] for x in validation_step_outputs]).mean()
# path = f"{self.outputdir}/T5-epoch-{self.current_epoch}"
# self.tokenizer.save_pretrained(path)
# # self.model.save_pretrained(path)
def validation_epoch_end(self, validation_step_outputs):
_loss = [x.cpu() for x in validation_step_outputs]
self.average_validation_loss = np.round(
torch.mean(torch.stack(_loss)).item(),
4,
)


class SimpleT5:
@@ -269,7 +292,6 @@ def __init__(self) -> None:
def from_pretrained(self, model_type="t5", model_name="t5-base") -> None:
"""
loads a T5/MT5 model for training/fine-tuning
Args:
model_type (str, optional): "t5" or "mt5" . Defaults to "t5".
model_name (str, optional): exact model architecture name, "t5-base" or "t5-large". Defaults to "t5-base".
@@ -302,10 +324,12 @@ def train(
outputdir: str = "outputs",
early_stopping_patience_epochs: int = 0, # 0 to disable early stopping feature
precision=32,
logger="default",
dataloader_num_workers: int = 2,
save_only_last_epoch: bool = False,
):
"""
trains a T5/MT5 model on a custom dataset
Args:
train_df (pd.DataFrame): training dataframe. Dataframe must have 2 columns --> "source_text" and "target_text"
eval_df ([type], optional): validation dataframe. Dataframe must have 2 columns --> "source_text" and "target_text"
@@ -317,65 +341,64 @@
outputdir (str, optional): output directory to save model checkpoints. Defaults to "outputs".
early_stopping_patience_epochs (int, optional): monitors val_loss on epoch end and stops training if val_loss does not improve after the specified number of epochs. Set 0 to disable early stopping. Defaults to 0 (disabled)
precision (int, optional): sets precision training - Double precision (64), full precision (32) or half precision (16). Defaults to 32.
logger (pytorch_lightning.loggers, optional): any logger supported by PyTorch Lightning. Defaults to "default", which uses the PyTorch Lightning default logger.
dataloader_num_workers (int, optional): number of workers for the train/test/val dataloaders. Defaults to 2.
save_only_last_epoch (bool, optional): If True, saves the model only at the last epoch; otherwise a model is saved at every epoch. Defaults to False.
"""
self.target_max_token_len = target_max_token_len
self.data_module = LightningDataModule(
train_df,
eval_df,
self.tokenizer,
batch_size=batch_size,
source_max_token_len=source_max_token_len,
target_max_token_len=target_max_token_len,
num_workers=dataloader_num_workers,
)

self.T5Model = LightningModel(
tokenizer=self.tokenizer, model=self.model, outputdir=outputdir
tokenizer=self.tokenizer,
model=self.model,
outputdir=outputdir,
save_only_last_epoch=save_only_last_epoch,
)

# checkpoint_callback = ModelCheckpoint(
# dirpath="checkpoints",
# filename="best-checkpoint-{epoch}-{train_loss:.2f}",
# save_top_k=-1,
# verbose=True,
# monitor="train_loss",
# mode="min",
# )

# logger = TensorBoardLogger("SimpleT5", name="SimpleT5-Logger")

early_stop_callback = (
[
EarlyStopping(
monitor="val_loss",
min_delta=0.00,
patience=early_stopping_patience_epochs,
verbose=True,
mode="min",
)
]
if early_stopping_patience_epochs > 0
else None
)
# add callbacks
callbacks = [TQDMProgressBar(refresh_rate=5)]

if early_stopping_patience_epochs > 0:
early_stop_callback = EarlyStopping(
monitor="val_loss",
min_delta=0.00,
patience=early_stopping_patience_epochs,
verbose=True,
mode="min",
)
callbacks.append(early_stop_callback)

# add gpu support
gpus = 1 if use_gpu else 0

# add logger
loggers = True if logger == "default" else logger

# prepare trainer
trainer = pl.Trainer(
# logger=logger,
callbacks=early_stop_callback,
logger=loggers,
callbacks=callbacks,
max_epochs=max_epochs,
gpus=gpus,
progress_bar_refresh_rate=5,
precision=precision,
log_every_n_steps=1,
)

# fit trainer
trainer.fit(self.T5Model, self.data_module)

def load_model(
self, model_type: str = "t5", model_dir: str = "outputs", use_gpu: bool = False
):
"""
loads a saved model checkpoint for inference/prediction
Args:
model_type (str, optional): "t5" or "mt5". Defaults to "t5".
model_dir (str, optional): path to model directory. Defaults to "outputs".
@@ -418,7 +441,6 @@ def predict(
):
"""
generates prediction for T5/MT5 model
Args:
source_text (str): any text for generating predictions
max_length (int, optional): max token length of prediction. Defaults to 512.
@@ -432,7 +454,6 @@
early_stopping (bool, optional): Defaults to True.
skip_special_tokens (bool, optional): Defaults to True.
clean_up_tokenization_spaces (bool, optional): Defaults to True.
Returns:
list[str]: returns predictions
"""
@@ -459,20 +480,4 @@ def predict(
)
for g in generated_ids
]
return preds

# def convert_and_load_onnx_model(self, model_dir: str):
# """ returns ONNX model """
# self.onnx_model = export_and_get_onnx_model(model_dir)
# self.onnx_tokenizer = AutoTokenizer.from_pretrained(model_dir)

# def onnx_predict(self, source_text: str):
# """ generates prediction from ONNX model """
# token = self.onnx_tokenizer(source_text, return_tensors="pt")
# tokens = self.onnx_model.generate(
# input_ids=token["input_ids"],
# attention_mask=token["attention_mask"],
# num_beams=2,
# )
# output = self.onnx_tokenizer.decode(tokens.squeeze(), skip_special_tokens=True)
# return output
return preds
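
To close, a usage sketch of the three behaviors this commit adds to `SimpleT5.train()`: a pluggable logger, configurable dataloader workers, and last-epoch-only checkpointing. The `TensorBoardLogger` arguments are illustrative, not prescribed by the commit:

```python
# Sketch of the options added to SimpleT5.train() in this commit. The
# logger configuration is illustrative; per the docstring, any logger
# supported by PyTorch Lightning should work.
import pandas as pd
from pytorch_lightning.loggers import TensorBoardLogger
from simplet5 import SimpleT5

df = pd.DataFrame(
    {
        "source_text": ["summarize: some long input text"],
        "target_text": ["a short summary"],
    }
)

model = SimpleT5()
model.from_pretrained(model_type="t5", model_name="t5-base")
model.train(
    train_df=df,
    eval_df=df,
    max_epochs=3,
    use_gpu=False,
    logger=TensorBoardLogger("logs", name="simplet5"),  # replaces the "default" logger
    dataloader_num_workers=4,   # forwarded to the train/val/test DataLoaders
    save_only_last_epoch=True,  # write a checkpoint only after the final epoch
)
```

With `save_only_last_epoch=True`, `training_epoch_end` above skips `save_pretrained` until `current_epoch == max_epochs - 1`, so only one checkpoint folder is written.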
