From 26d60938c303fba826a7ddb1f00260ca45b0e108 Mon Sep 17 00:00:00 2001
From: fabiocapsouza
Date: Sun, 15 Nov 2020 12:30:46 -0300
Subject: [PATCH] Revert "Upgrade PyTorch Lightning to 1.0.2 (#7852)"

This reverts commit 7af83fc63357fc88484904138ea55c168e7f6ba4.
---
 examples/lightning_base.py                          | 5 ++---
 examples/requirements.txt                           | 2 +-
 examples/seq2seq/callbacks.py                       | 1 +
 examples/seq2seq/finetune.py                        | 5 +++--
 examples/seq2seq/test_bash_script.py                | 2 +-
 examples/seq2seq/test_seq2seq_examples_multi_gpu.py | 1 +
 examples/text-classification/run_pl_glue.py         | 2 +-
 examples/token-classification/run_pl_ner.py         | 6 +++---
 8 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/examples/lightning_base.py b/examples/lightning_base.py
index 739e5dc59650dd..6ff4a08fc4ac9d 100644
--- a/examples/lightning_base.py
+++ b/examples/lightning_base.py
@@ -337,7 +337,7 @@ def add_generic_args(parser, root_dir) -> None:
 def generic_train(
     model: BaseTransformer,
     args: argparse.Namespace,
-    early_stopping_callback=None,
+    early_stopping_callback=False,
     logger=True,  # can pass WandbLogger() here
     extra_callbacks=[],
     checkpoint_callback=None,
@@ -355,8 +355,6 @@ def generic_train(
         checkpoint_callback = pl.callbacks.ModelCheckpoint(
             filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
         )
-    if early_stopping_callback:
-        extra_callbacks.append(early_stopping_callback)
     if logging_callback is None:
         logging_callback = LoggingCallback()
@@ -378,6 +376,7 @@ def generic_train(
         callbacks=[logging_callback] + extra_callbacks,
         logger=logger,
         checkpoint_callback=checkpoint_callback,
+        early_stop_callback=early_stopping_callback,
         **train_params,
     )

diff --git a/examples/requirements.txt b/examples/requirements.txt
index 9c270479678916..120a3ab5e06cfa 100644
--- a/examples/requirements.txt
+++ b/examples/requirements.txt
@@ -5,7 +5,7 @@ psutil
 sacrebleu
 rouge-score
 tensorflow_datasets
-pytorch-lightning==1.0.4
+pytorch-lightning==0.9.0
 matplotlib
 git-python==1.0.3
 faiss-cpu

diff --git a/examples/seq2seq/callbacks.py b/examples/seq2seq/callbacks.py
index 64560487496dcf..c6cd2014ded49b 100644
--- a/examples/seq2seq/callbacks.py
+++ b/examples/seq2seq/callbacks.py
@@ -102,6 +102,7 @@ def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=Fa
         monitor=f"val_{metric}",
         mode="min" if "loss" in metric else "max",
         save_top_k=save_top_k,
+        period=0,  # maybe save a checkpoint every time val is run, not just end of epoch.
     )
     return checkpoint_callback

diff --git a/examples/seq2seq/finetune.py b/examples/seq2seq/finetune.py
index 7e57f7ba40f1c4..9da761db73b42e 100755
--- a/examples/seq2seq/finetune.py
+++ b/examples/seq2seq/finetune.py
@@ -182,6 +182,7 @@ def validation_step(self, batch, batch_idx) -> Dict:
         return self._generative_step(batch)

     def validation_epoch_end(self, outputs, prefix="val") -> Dict:
+        self.step_count += 1
         losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
         loss = losses["loss"]
@@ -251,7 +252,7 @@ def get_dataset(self, type_path) -> Seq2SeqDataset:
     def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
         dataset = self.get_dataset(type_path)

-        if self.hparams.sortish_sampler and type_path != "test" and type_path != "val":
+        if self.hparams.sortish_sampler and type_path != "test":
             sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
             return DataLoader(
                 dataset,
@@ -262,7 +263,7 @@ def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False)
                 sampler=sampler,
             )

-        elif self.hparams.max_tokens_per_batch is not None and type_path != "test" and type_path != "val":
+        elif self.hparams.max_tokens_per_batch is not None and type_path != "test":
             batch_sampler = dataset.make_dynamic_sampler(
                 self.hparams.max_tokens_per_batch, distributed=self.hparams.gpus > 1
             )

diff --git a/examples/seq2seq/test_bash_script.py b/examples/seq2seq/test_bash_script.py
index 71861ef4dbc6a3..24ce9bfe6b49c5 100644
--- a/examples/seq2seq/test_bash_script.py
+++ b/examples/seq2seq/test_bash_script.py
@@ -144,7 +144,6 @@ def test_opus_mt_distill_script(self):
                 f"--num_train_epochs={epochs}",
                 "--warmup_steps=10",
                 "--val_check_interval=1.0",
-                "--do_predict",
             ]
         )
         with patch.object(sys, "argv", testargs):
@@ -152,6 +151,7 @@ def test_opus_mt_distill_script(self):
             parser = pl.Trainer.add_argparse_args(parser)
             parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
             args = parser.parse_args()
+            args.do_predict = False
             # assert args.gpus == gpus THIS BREAKS for multigpu

             model = distill_main(args)

diff --git a/examples/seq2seq/test_seq2seq_examples_multi_gpu.py b/examples/seq2seq/test_seq2seq_examples_multi_gpu.py
index 03ec39037c15b4..a6b76a4c530a6f 100644
--- a/examples/seq2seq/test_seq2seq_examples_multi_gpu.py
+++ b/examples/seq2seq/test_seq2seq_examples_multi_gpu.py
@@ -176,6 +176,7 @@ def convert(k, v):
         print(metrics)
         last_step_stats = metrics["val"][-1]
         self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
+        self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"])
         self.assertIsInstance(last_step_stats[f"val_avg_{val_metric}"], float)
         self.assertEqual(len(metrics["test"]), 1)
         desired_n_evals = int(args_d["max_epochs"] * (1 / args_d["val_check_interval"]) / 2 + 1)

diff --git a/examples/text-classification/run_pl_glue.py b/examples/text-classification/run_pl_glue.py
index 500a0bd627643d..80315abc56bbb9 100644
--- a/examples/text-classification/run_pl_glue.py
+++ b/examples/text-classification/run_pl_glue.py
@@ -192,7 +192,7 @@ def main():

     # Optionally, predict on dev set and write to output_dir
     if args.do_predict:
-        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)))
+        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
         model = model.load_from_checkpoint(checkpoints[-1])
         return trainer.test(model)

diff --git a/examples/token-classification/run_pl_ner.py b/examples/token-classification/run_pl_ner.py
index 1066c6fed48cc9..c82cff74d8ef4c 100644
--- a/examples/token-classification/run_pl_ner.py
+++ b/examples/token-classification/run_pl_ner.py
@@ -207,9 +207,9 @@ def add_model_specific_args(parser, root_dir):

     if args.do_predict:
         # See https://github.com/huggingface/transformers/issues/3159
-        # pl use this default format to create a checkpoint:
+        # pl use this format to create a checkpoint:
         # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
-        # /pytorch_lightning/callbacks/model_checkpoint.py#L322
-        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)))
+        # /pytorch_lightning/callbacks/model_checkpoint.py#L169
+        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
         model = model.load_from_checkpoint(checkpoints[-1])
         trainer.test(model)