diff --git a/.github/workflows/ci_dockers.yml b/.github/workflows/ci_dockers.yml index 8e8d56b04d501..1cbc99d1b68b8 100644 --- a/.github/workflows/ci_dockers.yml +++ b/.github/workflows/ci_dockers.yml @@ -108,8 +108,11 @@ jobs: pytorch_version: 1.6 - python_version: 3.6 pytorch_version: 1.4 - #- python_version: 3.7 - # pytorch_version: 1.8 # todo + - python_version: 3.7 + pytorch_version: 1.7 + # TODO + # - python_version: 3.7 + # pytorch_version: 1.8 steps: - name: Checkout uses: actions/checkout@v2 diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 0ba6f701f65d6..673e65d5ccfae 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -14,7 +14,7 @@ jobs: fail-fast: false matrix: python_version: [3.6, 3.7, 3.8] - pytorch_version: [1.3, 1.4, 1.5, 1.6] + pytorch_version: [1.3, 1.4, 1.5, 1.6, 1.7] exclude: # excludes PT 1.3 as it is missing on pypi - python_version: 3.8 diff --git a/CHANGELOG.md b/CHANGELOG.md index ccd296d3b0e6e..95417dd8aa9ae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,10 +17,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added multiclass AUROC metric ([#4236](https://github.com/PyTorchLightning/pytorch-lightning/pull/4236)) +- Added timeout for `tpu_device_exists` to ensure process does not hang indefinitely ([#4340](https://github.com/PyTorchLightning/pytorch-lightning/pull/4340)) + +- Added global step indexing to the checkpoint name for a better sub-epoch checkpointing experience ([#3807](https://github.com/PyTorchLightning/pytorch-lightning/pull/3807)) + ### Changed - W&B log in sync with Trainer step ([#4405](https://github.com/PyTorchLightning/pytorch-lightning/pull/4405)) +- Hook `on_after_backward` is called only when `optimizer_step` is being called ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439)) + +- Moved `track_and_norm_grad` into `training loop` and called only when `optimizer_step` is being called ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439)) + ### Deprecated - Deprecated passing `ModelCheckpoint` instance to `checkpoint_callback` Trainer argument ([#4336](https://github.com/PyTorchLightning/pytorch-lightning/pull/4336)) @@ -31,6 +39,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed error using `auto_select_gpus=True` with `gpus=-1` ([#4209](https://github.com/PyTorchLightning/pytorch-lightning/pull/4209)) +- Fixed that metrics do not store computational graph for all seen data ([#4313](https://github.com/PyTorchLightning/pytorch-lightning/pull/4313)) + +- Fixed AMP unscale for `on_after_backward` ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439)) ## [1.0.4] - 2020-10-27 @@ -74,7 +85,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed WandbLogger not uploading checkpoint artifacts at the end of training ([#4341](https://github.com/PyTorchLightning/pytorch-lightning/pull/4341)) - ## [1.0.3] - 2020-10-20 ### Added diff --git a/README.md b/README.md index fea674bc396b0..33ee544802914 100644 --- a/README.md +++ b/README.md @@ -183,7 +183,7 @@ trainer = pl.Trainer() trainer.fit(autoencoder, DataLoader(train), DataLoader(val)) ``` -#### And without changing a single line of code, you could run on GPU/TPUss +#### And without changing a single line of code, you could run on GPUs/TPUs ```python # 8 GPUs trainer = Trainer(max_epochs=1, gpus=8) diff --git a/dockers/base-conda/Dockerfile b/dockers/base-conda/Dockerfile index 6a7f03970cf75..d11e61d92edbd 100644 --- a/dockers/base-conda/Dockerfile +++ b/dockers/base-conda/Dockerfile @@ -74,7 +74,7 @@ ENV CONDA_ENV=lightning COPY environment.yml environment.yml # conda init -RUN conda create -y --name $CONDA_ENV && \ +RUN conda create -y --name $CONDA_ENV cudatoolkit=${CUDA_VERSION} && \ conda init bash && \ # NOTE: this requires that the channel is presented in the yaml before packages # replace channel to nightly if needed, fix PT version and remove Horovod as it will be installed later diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index de3cd01c33e9b..4fadfaa507168 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -150,6 +150,19 @@ Example implementation: def compute(self): return self.correct.float() / self.total +Metrics support backpropagation if all computations involved in the metric calculation +are differentiable. However, note that the cached state is detached from the computational +graph and cannot be backpropagated. Not doing this would mean storing the computational +graph for each update call, which can lead to out-of-memory errors. +In practice this means that: + +.. code-block:: python + + metric = MyMetric() + val = metric(pred, target) # this value can be backpropagated + val = metric.compute() # this value cannot be backpropagated + + ********** Metric API ********** @@ -453,4 +466,3 @@ embedding_similarity [func] .. autofunction:: pytorch_lightning.metrics.functional.self_supervised.embedding_similarity :noindex: - diff --git a/docs/source/optimizers.rst b/docs/source/optimizers.rst index 2b8025959c9ca..1e7baadb64480 100644 --- a/docs/source/optimizers.rst +++ b/docs/source/optimizers.rst @@ -48,6 +48,10 @@ to manually manage the optimization process. To do so, do the following: opt_d.step() opt_d.zero_grad() + # log losses + self.log('loss_a', loss_a) + self.log('loss_b', loss_b) + ..
note:: This is only recommended for experts who need ultimate flexibility Manual optimization does not yet support accumulated gradients but will be live in 1.1.0 @@ -108,7 +112,7 @@ Every optimizer you use can be paired with any `LearningRateScheduler =3.6 - pip>20.1 - numpy>=1.16.4 - - pytorch>=1.3 + - pytorch>=1.3,<1.8 - future>=0.17.1 - PyYAML>=5.1 - tqdm>=4.41.0 @@ -41,7 +41,7 @@ dependencies: - torchtext>=0.3.1 # Examples - - torchvision>=0.4.1 + - torchvision>=0.4.1,<0.9.0 - pip: - test-tube>=0.7.5 diff --git a/notebooks/05-trainer-flags-overview.ipynb b/notebooks/05-trainer-flags-overview.ipynb index a070ce03629ba..4589887ecb986 100644 --- a/notebooks/05-trainer-flags-overview.ipynb +++ b/notebooks/05-trainer-flags-overview.ipynb @@ -2223,7 +2223,7 @@ "source": [ "from pytorch_lightning.callbacks import ModelCheckpoint\n", "\n", - "trainer = pl.Trainer(checkpoint_callback=ModelCheckpoint(monitor='val_loss'))\n", + "trainer = pl.Trainer(callbacks=[ModelCheckpoint(monitor='val_loss')])\n", "\n", "trainer.fit(model, train_loader, val_loader)" ], @@ -2265,7 +2265,7 @@ " prefix='',\n", ")\n", "\n", - "trainer = Trainer(checkpoint_callback=checkpoint_callback)\n", + "trainer = Trainer(callbacks=[checkpoint_callback])\n", "\n", "trainer.fit(model, train_loader, val_loader)" ], @@ -2471,7 +2471,7 @@ "# **NOTE: this saves weights to some/path NOT my/path\n", "checkpoint = ModelCheckpoint(filepath='some/path')\n", "trainer = pl.Trainer(\n", - " checkpoint_callback=checkpoint,\n", + " callbacks=[checkpoint],\n", " weights_save_path='my/path'\n", ")\n", "trainer.fit(model, train_loader, val_loader)" diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 4785371a3d24f..e69addf234a36 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -132,11 +132,6 @@ def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx): model_ref.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) def clip_gradients(self, optimizer, clip_val=None): - - if self.trainer.amp_backend == AMPType.NATIVE: - self.trainer.scaler.unscale_(optimizer) - - # apply clip gradients # TODO: separate TPU case from here self._clip_gradients(optimizer, clip_val) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 1a81ca0980339..f3eabf5611cf0 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -101,7 +101,7 @@ class ModelCheckpoint(Callback): ... filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}' ... ) - By default, filename is ``None`` and will be set to ``'{epoch}'``. + By default, filename is ``None`` and will be set to ``'{epoch}-{step}'``. Example:: @@ -111,7 +111,7 @@ class ModelCheckpoint(Callback): # saves checkpoints to 'my/path/' at every epoch >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/') - >>> trainer = Trainer(checkpoint_callback=checkpoint_callback) + >>> trainer = Trainer(callbacks=[checkpoint_callback]) # save epoch and val_loss in name # saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt @@ -123,7 +123,7 @@ class ModelCheckpoint(Callback): # retrieve the best checkpoint after training checkpoint_callback = ModelCheckpoint(dirpath='my/path/') - trainer = Trainer(checkpoint_callback=checkpoint_callback) + trainer = Trainer(callbacks=[checkpoint_callback]) model = ... 
trainer.fit(model) checkpoint_callback.best_model_path @@ -222,16 +222,16 @@ def save_checkpoint(self, trainer, pl_module): monitor_candidates = self._monitor_candidates(trainer) # ie: path/val_loss=0.5.ckpt - filepath = self._get_metric_interpolated_filepath_name(epoch, monitor_candidates) + filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, epoch, global_step) # callback supports multiple simultaneous modes # here we call each mode sequentially # Mode 1: save all checkpoints OR only the top k if self.save_top_k: - self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, epoch, filepath) + self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, filepath) # Mode 2: save the last checkpoint - self._save_last_checkpoint(trainer, pl_module, epoch, monitor_candidates, filepath) + self._save_last_checkpoint(trainer, pl_module, monitor_candidates, filepath) def __validate_init_configuration(self): if self.save_top_k is not None and self.save_top_k < -1: @@ -360,16 +360,17 @@ def _format_checkpoint_name( cls, filename: Optional[str], epoch: int, + step: int, metrics: Dict[str, Any], prefix: str = "", ) -> str: if not filename: # filename is not set, use default name - filename = "{epoch}" + filename = "{epoch}-{step}" # check and parse user passed keys in the string groups = re.findall(r"(\{.*?)[:\}]", filename) if len(groups) >= 0: - metrics["epoch"] = epoch + metrics.update({"epoch": epoch, 'step': step}) for group in groups: name = group[1:] filename = filename.replace(group, name + "={" + name) @@ -379,7 +380,7 @@ def _format_checkpoint_name( return cls.CHECKPOINT_JOIN_CHAR.join([txt for txt in (prefix, filename) if txt]) def format_checkpoint_name( - self, epoch: int, metrics: Dict[str, Any], ver: Optional[int] = None + self, epoch: int, step: int, metrics: Dict[str, Any], ver: Optional[int] = None ) -> str: """Generate a filename according to the defined template. 
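To make the step-aware naming concrete for review, here is a small sketch of what the `_format_checkpoint_name` changes above produce. The `format_checkpoint_name(epoch, step, metrics)` signature and the `'{epoch}-{step}'` default come from this diff; the directory and the numeric values are illustrative:

```python
import os

from pytorch_lightning.callbacks import ModelCheckpoint

# filename=None now falls back to '{epoch}-{step}', so two checkpoints saved
# within the same epoch no longer collide on the same name
ckpt = ModelCheckpoint(dirpath="my/path")
print(os.path.basename(ckpt.format_checkpoint_name(2, 512, metrics={})))
# epoch=2-step=512.ckpt

# 'step' is also available as a key in custom templates
ckpt = ModelCheckpoint(dirpath="my/path", filename="{step}-{val_loss:.2f}")
print(os.path.basename(ckpt.format_checkpoint_name(2, 512, metrics=dict(val_loss=0.123456))))
# step=512-val_loss=0.12.ckpt
```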
@@ -387,24 +388,24 @@ def format_checkpoint_name( >>> tmpdir = os.path.dirname(__file__) >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}') - >>> os.path.basename(ckpt.format_checkpoint_name(0, {})) + >>> os.path.basename(ckpt.format_checkpoint_name(0, 1, metrics={})) 'epoch=0.ckpt' >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch:03d}') - >>> os.path.basename(ckpt.format_checkpoint_name(5, {})) + >>> os.path.basename(ckpt.format_checkpoint_name(5, 2, metrics={})) 'epoch=005.ckpt' >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}') - >>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456))) + >>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456))) 'epoch=2-val_loss=0.12.ckpt' >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}') - >>> os.path.basename(ckpt.format_checkpoint_name(0, {})) + >>> os.path.basename(ckpt.format_checkpoint_name(0, 4, metrics={})) 'missing=0.ckpt' - >>> ckpt = ModelCheckpoint(filename='{epoch}') - >>> os.path.basename(ckpt.format_checkpoint_name(0, {})) - 'epoch=0.ckpt' + >>> ckpt = ModelCheckpoint(filename='{step}') + >>> os.path.basename(ckpt.format_checkpoint_name(0, 0, {})) + 'step=0.ckpt' """ filename = self._format_checkpoint_name( - self.filename, epoch, metrics, prefix=self.prefix + self.filename, epoch, step, metrics, prefix=self.prefix ) if ver is not None: filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}")) @@ -479,13 +480,11 @@ def _validate_monitor_key(self, trainer): ) raise MisconfigurationException(m) - def _get_metric_interpolated_filepath_name(self, epoch, ckpt_name_metrics): - filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics) + def _get_metric_interpolated_filepath_name(self, ckpt_name_metrics: Dict[str, Any], epoch: int, step: int): + filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics) version_cnt = 0 while self._fs.exists(filepath): - filepath = self.format_checkpoint_name( - epoch, ckpt_name_metrics, ver=version_cnt - ) + filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt) # this epoch called before version_cnt += 1 return filepath @@ -494,9 +493,10 @@ def _monitor_candidates(self, trainer): ckpt_name_metrics = deepcopy(trainer.logger_connector.logged_metrics) ckpt_name_metrics.update(trainer.logger_connector.callback_metrics) ckpt_name_metrics.update(trainer.logger_connector.progress_bar_metrics) + ckpt_name_metrics.update({"step": trainer.global_step, "epoch": trainer.current_epoch}) return ckpt_name_metrics - def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath): + def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath): should_save_last = self.monitor is None or self.save_last if not should_save_last: return @@ -506,7 +506,11 @@ def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, fi # when user ALSO asked for the 'last.ckpt' change the name if self.save_last: last_filepath = self._format_checkpoint_name( - self.CHECKPOINT_NAME_LAST, epoch, ckpt_name_metrics, prefix=self.prefix + self.CHECKPOINT_NAME_LAST, + trainer.current_epoch, + trainer.global_step, + ckpt_name_metrics, + prefix=self.prefix ) last_filepath = os.path.join(self.dirpath, f"{last_filepath}.ckpt") @@ -523,17 +527,19 @@ def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, fi if self.monitor is None: self.best_model_path = self.last_model_path - def 
_save_top_k_checkpoints(self, metrics, trainer, pl_module, epoch, filepath): + def _save_top_k_checkpoints(self, metrics, trainer, pl_module, filepath): current = metrics.get(self.monitor) + epoch = metrics.get("epoch") + step = metrics.get("step") if not isinstance(current, torch.Tensor) and current is not None: current = torch.tensor(current, device=pl_module.device) if self.check_monitor_top_k(current): - self._update_best_and_save(filepath, current, epoch, trainer, pl_module) + self._update_best_and_save(filepath, current, epoch, step, trainer, pl_module) elif self.verbose: rank_zero_info( - f"Epoch {epoch:d}: {self.monitor} was not in top {self.save_top_k}" + f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}" ) def _is_valid_monitor_key(self, metrics): @@ -544,11 +550,11 @@ def _update_best_and_save( filepath: str, current: torch.Tensor, epoch: int, + step: int, trainer, pl_module, ): - - k = epoch + 1 if self.save_top_k == -1 else self.save_top_k + k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k del_list = [] if len(self.best_k_models) == k and k > 0: @@ -575,9 +581,8 @@ def _update_best_and_save( if self.verbose: rank_zero_info( - f"Epoch {epoch:d}: {self.monitor} reached" - f" {current:0.5f} (best {self.best_model_score:0.5f})," - f" saving model to {filepath} as top {k}" + f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}" + f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}' ) self._save_model(filepath, trainer, pl_module) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 010b4efa028e9..f185626646803 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -11,13 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
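A usage note on the `_update_best_and_save` change above: with `save_top_k=-1`, the capacity `k` used to be derived from the epoch number (`epoch + 1`), so saving several times per epoch could evict earlier files; it now grows with `len(self.best_k_models) + 1`. A hedged sketch of the intended sub-epoch checkpointing setup (the monitor key and validation interval are illustrative):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# save_top_k=-1 keeps every checkpoint; with this fix, four mid-epoch saves
# (val_check_interval=0.25) all survive instead of competing for epoch+1 slots
checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_top_k=-1)
trainer = Trainer(callbacks=[checkpoint_callback], val_check_interval=0.25)
```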
- +import os +import tempfile import collections import copy import inspect -import os import re -import tempfile from abc import ABC from argparse import Namespace from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, Mapping @@ -28,16 +27,17 @@ from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO +from pytorch_lightning.core.step_result import Result from pytorch_lightning.utilities import rank_zero_warn, AMPType from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.core.step_result import Result from pytorch_lightning.utilities.parsing import ( AttributeDict, collect_init_args, get_init_args, ) +from pytorch_lightning.callbacks import Callback from torch import ScriptModule, Tensor from torch.nn import Module from torch.optim.optimizer import Optimizer @@ -111,6 +111,8 @@ def __init__(self, *args, **kwargs): self._datamodule = None self._results: Optional[Result] = None self._current_fx_name = '' + self._current_hook_fx_name = None + self._current_dataloader_idx = None def optimizers(self): opts = self.trainer.optimizers @@ -244,6 +246,18 @@ def log( on_step = self.__auto_choose_log_on_step(on_step) on_epoch = self.__auto_choose_log_on_epoch(on_epoch) + if self._current_hook_fx_name is not None: + self.trainer.logger_connector.check_logging_in_callbacks( + self._current_hook_fx_name, + on_step=on_step, + on_epoch=on_epoch + ) + + # make sure user doesn't introduce logic for multi-dataloaders + if "/dataloader_idx_" in name: + raise MisconfigurationException( + f"Logged key: {name} should not contain information about dataloader_idx.") + self._results.log( name, value, @@ -257,7 +271,8 @@ def log( enable_graph, sync_dist, sync_dist_op, - sync_dist_group + sync_dist_group, + self._current_dataloader_idx, ) def log_dict( @@ -1101,7 +1116,6 @@ def backward(self, loss, optimizer, optimizer_idx): """ loss.backward(*args, **kwargs) - self.trainer.train_loop.track_and_norm_grad(optimizer=optimizer) def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int): """ @@ -1284,11 +1298,11 @@ def tbptt_split_batch(self, batch, split_size): batch_split = [] for i, x in enumerate(batch): if isinstance(x, torch.Tensor): - split_x = x[:, t : t + split_size] + split_x = x[:, t: t + split_size] elif isinstance(x, collections.Sequence): split_x = [None] * len(x) for batch_idx in range(len(x)): - split_x[batch_idx] = x[batch_idx][t : t + split_size] + split_x[batch_idx] = x[batch_idx][t: t + split_size] batch_split.append(split_x) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 650c1876d0cd0..059c724aa75a9 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -124,6 +124,7 @@ def log( sync_dist: bool = False, sync_dist_op: Union[Any, str] = 'mean', sync_dist_group: Optional[Any] = None, + dataloader_idx: Optional[int] = None, ): # no metrics should be logged with graphs if not enable_graph and isinstance(value, torch.Tensor): @@ -144,6 +145,7 @@ def log( # set step version step_name = f'{name}_step' + self.__set_meta( step_name, value, @@ -154,12 +156,15 @@ def log( reduce_fx=reduce_fx, tbptt_reduce_fx=tbptt_reduce_fx, 
tbptt_pad_token=tbptt_pad_token, - forked=False + forked=False, + dataloader_idx=dataloader_idx, ) + self.__setitem__(step_name, value) # set epoch version epoch_name = f'{name}_epoch' + self.__set_meta( epoch_name, value, @@ -170,7 +175,8 @@ def log( reduce_fx=reduce_fx, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token, - forked=False + forked=False, + dataloader_idx=dataloader_idx, ) self.__setitem__(epoch_name, value) @@ -185,7 +191,8 @@ def log( reduce_fx, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token, - forked=was_forked + forked=was_forked, + dataloader_idx=dataloader_idx, ) # set the value @@ -202,7 +209,8 @@ def __set_meta( reduce_fx: Callable, tbptt_pad_token: int, tbptt_reduce_fx: Callable, - forked: bool + forked: bool, + dataloader_idx: Union[int, None] ): # set the meta for the item meta_value = value @@ -215,7 +223,8 @@ def __set_meta( value=meta_value, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token, - forked=forked + forked=forked, + dataloader_idx=dataloader_idx, ) self['meta'][name] = meta @@ -225,13 +234,22 @@ def __set_meta( _internal['_reduce_on_epoch'] = max(_internal['_reduce_on_epoch'], on_epoch) def track_batch_size(self, batch): + batch_size = Result.extract_batch_size(batch) + Result.attach_batch_size(batch_size, self) + + @staticmethod + def extract_batch_size(batch): try: batch_size = Result.unpack_batch_size(batch) except RecursionError as re: batch_size = 1 + return batch_size - meta = self['meta'] - meta['_internal']['batch_sizes'].append(batch_size) + @staticmethod + def attach_batch_size(batch_size: Union[int, None], result: 'Result') -> None: + if batch_size is not None: + meta = result['meta'] + meta['_internal']['batch_sizes'].append(batch_size) def get_batch_sizes(self): meta = self['meta'] @@ -242,7 +260,12 @@ def get_callback_metrics(self) -> dict: return result - def get_batch_log_metrics(self, include_forked_originals=True) -> dict: + def _add_dataloader_idx(self, k: str, dataloader_idx: Union[int, None], add_dataloader_idx: bool) -> str: + if dataloader_idx is not None and add_dataloader_idx: + return f"{k}/dataloader_idx_{dataloader_idx}" + return k + + def get_batch_log_metrics(self, include_forked_originals=True, add_dataloader_idx=False) -> dict: """ Gets the metrics to log at the end of the batch step @@ -257,15 +280,17 @@ def get_batch_log_metrics(self, include_forked_originals=True) -> dict: if options['forked'] and not include_forked_originals: continue + dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx) + if options['logger'] and options['on_step']: if isinstance(self[k], Metric): - result[k] = self[k]._forward_cache + result[dl_key] = self[k]._forward_cache.detach() else: - result[k] = self[k] + result[dl_key] = self[k] return result - def get_epoch_log_metrics(self) -> dict: + def get_epoch_log_metrics(self, add_dataloader_idx=False) -> dict: """ Gets the metrics to log at the end of epoch """ @@ -279,11 +304,13 @@ def get_epoch_log_metrics(self) -> dict: if options['forked']: continue + dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx) + if options['logger'] and options['on_epoch']: if isinstance(self[k], Metric): - result[k] = self[k].compute() + result[dl_key] = self[k].compute().detach() else: - result[k] = self[k] + result[dl_key] = self[k] if k in self and not options['on_epoch'] and isinstance(self[k], Metric): # compute metric on epoch anyway so state does not accumulate @@ -291,7 +318,7 @@ def 
get_epoch_log_metrics(self) -> dict: return result - def get_epoch_pbar_metrics(self): + def get_epoch_pbar_metrics(self, add_dataloader_idx=False): """ Gets the metrics to log at the end of epoch """ @@ -305,11 +332,13 @@ def get_epoch_pbar_metrics(self): if options['forked']: continue + dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx) + if options['prog_bar'] and options['on_epoch']: if isinstance(self[k], Metric): - result[k] = self[k].compute() + result[dl_key] = self[k].compute().detach() else: - result[k] = self[k] + result[dl_key] = self[k] if k in self and not options['on_epoch'] and isinstance(self[k], Metric): # compute metric on epoch anyway so state does not accumulate @@ -317,7 +346,7 @@ def get_epoch_pbar_metrics(self): return result - def get_forked_metrics(self): + def get_forked_metrics(self, add_dataloader_idx=False): """ Gets the metrics to log at the end of epoch """ @@ -328,12 +357,14 @@ def get_forked_metrics(self): if k == '_internal': continue + dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx) + if options['forked']: - result[k] = self[k] + result[dl_key] = self[k] return result - def get_batch_pbar_metrics(self, include_forked_originals=True): + def get_batch_pbar_metrics(self, include_forked_originals=True, add_dataloader_idx=False): """ Gets the metrics to log at the end of the batch step """ @@ -347,11 +378,13 @@ def get_batch_pbar_metrics(self, include_forked_originals=True): if options['forked'] and not include_forked_originals: continue + dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx) + if options['prog_bar'] and options['on_step']: if isinstance(self[k], Metric): - result[k] = self[k]._forward_cache + result[dl_key] = self[k]._forward_cache else: - result[k] = self[k] + result[dl_key] = self[k] return result @@ -473,6 +506,8 @@ def reduce_on_epoch_end(cls, outputs): if option['on_epoch']: fx = option['reduce_fx'] if fx == torch.mean: + if isinstance(result[k], list): + result[k] = torch.tensor(result[k]).float() try: reduced_val = weighted_mean(result[k], batch_sizes) except Exception as e: diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index cf0b22d7d446f..2246d02bc9bcb 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -188,7 +188,7 @@ def _sanitize_callable(val): return val.__name__ return _val except Exception: - return val.__name__ + return getattr(val, "__name__", None) return val return {key: _sanitize_callable(val) for key, val in params.items()} diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 3a853be0ebdd5..f003e0d3da72a 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -150,7 +150,8 @@ def forward(self, *args, **kwargs): Automatically calls ``update()``. Returns the metric value over inputs if ``compute_on_step`` is True. 
""" # add current step - self.update(*args, **kwargs) + with torch.no_grad(): + self.update(*args, **kwargs) self._forward_cache = None if self.compute_on_step: diff --git a/pytorch_lightning/plugins/native_amp.py b/pytorch_lightning/plugins/native_amp.py index 6506540bde6e1..b016b6c5d24fb 100644 --- a/pytorch_lightning/plugins/native_amp.py +++ b/pytorch_lightning/plugins/native_amp.py @@ -38,6 +38,11 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs): # once backward has been applied, release graph closure_loss = closure_loss.detach() + + # unscale gradient to allow analyze within `on_after_backward` + if not self.trainer.train_loop.should_accumulate(): + self.trainer.scaler.unscale_(optimizer) + return closure_loss def training_step(self, fx, args): diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index 2d0c3e46f2fe0..c0a3744f162fd 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -484,10 +484,7 @@ def training_step(self, batch, batch_idx, optimizer_idx): | -Add a list of :class:`~pytorch_lightning.callbacks.Callback`. These callbacks DO NOT replace the explicit callbacks -(loggers or :class:`~pytorch_lightning.callbacks.ModelCheckpoint`). - -.. note:: Only user defined callbacks (ie: Not :class:`~pytorch_lightning.callbacks.ModelCheckpoint`) +Add a list of :class:`~pytorch_lightning.callbacks.Callback`. .. code-block:: python @@ -537,34 +534,27 @@ def on_train_end(self, trainer, pl_module): | -Pass in a callback for checkpointing. Checkpoints capture the exact value of all parameters used by a model. By default Lightning saves a checkpoint for you in your current working directory, with the state of your last training epoch, -but you can override the default behavior by Initializing the :class:`~pytorch_lightning.callbacks.ModelCheckpoint` callback, -and passing it to :class:`~pytorch_lightning.trainer.Trainer` `checkpoint_callback` flag. +Checkpoints capture the exact value of all parameters used by a model. +To disable automatic checkpointing, set this to `False`. .. code-block:: python - from pytorch_lightning.callbacks import ModelCheckpoint + # default used by Trainer + trainer = Trainer(checkpoint_callback=True) - # default used by the Trainer - checkpoint_callback = ModelCheckpoint( - dirpath=os.getcwd(), - save_top_k=True, - verbose=True, - monitor='checkpoint_on', - mode='min', - prefix='' - ) + # turn off automatic checkpointing + trainer = Trainer(checkpoint_callback=False) - trainer = Trainer(checkpoint_callback=checkpoint_callback) -To disable automatic checkpointing, set this to `False`. +You can override the default behavior by initializing the :class:`~pytorch_lightning.callbacks.ModelCheckpoint` +callback, and adding it to the :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks` list. +See :ref:`Saving and Loading Weights ` for how to customize checkpointing. -.. code-block:: python - trainer = Trainer(checkpoint_callback=False) +.. warning:: Passing a ModelCheckpoint instance to this argument is deprecated since + v1.1.0 and will be unsupported from v1.3.0. -See also :ref:`Saving and Loading Weights `. 
default_root_dir ^^^^^^^^^^^^^^^^ @@ -1529,7 +1519,7 @@ def tbptt_split_batch(self, batch, split_size): # **NOTE: this saves weights to some/path NOT my/path checkpoint = ModelCheckpoint(dirpath='some/path') trainer = Trainer( - checkpoint_callback=checkpoint, + callbacks=[checkpoint], weights_save_path='my/path' ) diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index b8a4276a2d747..c9ef4ae32be77 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -56,10 +56,10 @@ def on_trainer_init( def configure_checkpoint_callbacks(self, checkpoint_callback: Union[ModelCheckpoint, bool]): if isinstance(checkpoint_callback, ModelCheckpoint): - # TODO: deprecated, remove this block in v1.4.0 + # TODO: deprecated, remove this block in v1.3.0 rank_zero_warn( "Passing a ModelCheckpoint instance to Trainer(checkpoint_callbacks=...)" - " is deprecated since v1.1 and will no longer be supported in v1.4.", + " is deprecated since v1.1 and will no longer be supported in v1.3.", DeprecationWarning ) self.trainer.callbacks.append(checkpoint_callback) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/__init__.py b/pytorch_lightning/trainer/connectors/logger_connector/__init__.py new file mode 100644 index 0000000000000..4034840a09b97 --- /dev/null +++ b/pytorch_lightning/trainer/connectors/logger_connector/__init__.py @@ -0,0 +1 @@ +from pytorch_lightning.trainer.connectors.logger_connector.logger_connector import LoggerConnector diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py new file mode 100644 index 0000000000000..3ce4b523545c3 --- /dev/null +++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py @@ -0,0 +1,220 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +class CallbackHookNameValidator: + + @staticmethod + def check_logging_in_callbacks(current_hook_fx_name: str = None, on_step: bool = None, + on_epoch: bool = None) -> None: + if current_hook_fx_name is None: + return + + internal_func = getattr(CallbackHookNameValidator, f"_{current_hook_fx_name}_log", None) + + if internal_func is None: + return + + current_callback_hook_auth_args = internal_func() + + if current_callback_hook_auth_args is not None: + m = "{} function supports only {} in {}. 
Provided {}" + if on_step not in current_callback_hook_auth_args["on_step"]: + msg = m.format(current_hook_fx_name, "on_step", current_callback_hook_auth_args["on_step"], on_step) + raise MisconfigurationException(msg) + + if on_epoch not in current_callback_hook_auth_args["on_epoch"]: + msg = m.format(current_hook_fx_name, "on_epoch", current_callback_hook_auth_args["on_epoch"], on_epoch) + raise MisconfigurationException(msg) + else: + raise MisconfigurationException( + f"{current_hook_fx_name} function doesn't support logging using self.log() yet." + ) + + @staticmethod + def _setup_log(): + """Called when fit or test begins""" + return None + + @staticmethod + def _teardown_log(): + """Called at the end of fit and test""" + return None + + @staticmethod + def _on_init_start_log(): + """Called when the trainer initialization begins, model has not yet been set.""" + return None + + @staticmethod + def _on_init_end_log(): + """Called when the trainer initialization ends, model has not yet been set.""" + return None + + @staticmethod + def _on_fit_start_log(): + """Called when the trainer initialization begins, model has not yet been set.""" + return None + + @staticmethod + def _on_fit_end_log(): + """Called when the trainer initialization begins, model has not yet been set.""" + return None + + @staticmethod + def _on_sanity_check_start_log(): + """Called when the validation sanity check starts.""" + return None + + @staticmethod + def _on_sanity_check_end_log(): + """Called when the validation sanity check ends.""" + return None + + @staticmethod + def _on_train_epoch_start_log(): + """Called when the epoch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_train_epoch_end_log(): + """Called when the epoch ends.""" + return {"on_step": [False], "on_epoch": [False, True]} + + @staticmethod + def _on_validation_epoch_start_log(): + """Called when the epoch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_validation_epoch_end_log(): + """Called when the epoch ends.""" + return {"on_step": [False], "on_epoch": [False, True]} + + @staticmethod + def _on_test_epoch_start_log(): + """Called when the epoch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_test_epoch_end_log(): + """Called when the epoch ends.""" + return {"on_step": [False], "on_epoch": [False, True]} + + @staticmethod + def _on_epoch_start_log(): + """Called when the epoch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_epoch_end_log(): + """Called when the epoch ends.""" + return {"on_step": [False], "on_epoch": [False, True]} + + @staticmethod + def _on_train_start_log(): + """Called when the train begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_train_end_log(): + """Called when the train ends.""" + return None + + @staticmethod + def _on_pretrain_routine_start_log(): + """Called when the train begins.""" + return None + + @staticmethod + def _on_pretrain_routine_end_log(): + """Called when the train ends.""" + return None + + @staticmethod + def _on_batch_start_log(): + """Called when the training batch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_batch_end_log(): + """Called when the training batch ends.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def 
_on_train_batch_start_log(): + """Called when the training batch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_train_batch_end_log(): + """Called when the training batch ends.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_validation_batch_start_log(): + """Called when the validation batch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_validation_batch_end_log(): + """Called when the validation batch ends.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_test_batch_start_log(): + """Called when the test batch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_test_batch_end_log(): + """Called when the test batch ends.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_validation_start_log(): + """Called when the validation loop begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_validation_end_log(): + """Called when the validation loop ends.""" + return {"on_step": [False], "on_epoch": [False, True]} + + @staticmethod + def _on_test_start_log(): + """Called when the test begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_test_end_log(): + """Called when the test ends.""" + return None + + @staticmethod + def _on_keyboard_interrupt_log(): + """Called when the training is interrupted by KeyboardInterrupt.""" + return None + + @staticmethod + def _on_save_checkpoint_log(): + """Called when saving a model checkpoint.""" + return None + + @staticmethod + def _on_load_checkpoint_log(): + """Called when loading a model checkpoint.""" + return None diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py new file mode 100644 index 0000000000000..2a9d68807e694 --- /dev/null +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -0,0 +1,528 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
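For review context, a sketch of how the `CallbackHookNameValidator` above behaves when exercised directly; the hook name and flag values follow the per-hook tables in the new module:

```python
from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import (
    CallbackHookNameValidator,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# `_on_train_epoch_end_log` allows on_step only as False, so this passes silently
CallbackHookNameValidator.check_logging_in_callbacks(
    "on_train_epoch_end", on_step=False, on_epoch=True
)

try:
    CallbackHookNameValidator.check_logging_in_callbacks(
        "on_train_epoch_end", on_step=True, on_epoch=True
    )
except MisconfigurationException as err:
    # "on_train_epoch_end function supports only on_step in [False]. Provided True"
    print(err)
```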
+ +from collections import defaultdict +from copy import deepcopy +from enum import Enum +from typing import Union, Tuple, Any, Mapping + +from pytorch_lightning.core.step_result import Result + + +# used to map a boolean to the right LoggerStages values +class FrozenDict(dict): + def __init__(self, *args, **kwargs): + self._hash = None + super(FrozenDict, self).__init__(*args, **kwargs) + + def __hash__(self): + if self._hash is None: + self._hash = hash(tuple(sorted(self.items()))) # iteritems() on py2 + return self._hash + + def _immutable(self, *args, **kws): + raise TypeError('cannot change object - object is immutable') + + __setitem__ = _immutable + __delitem__ = _immutable + pop = _immutable + popitem = _immutable + clear = _immutable + update = _immutable + setdefault = _immutable + + +LOOKUP_TABLE = FrozenDict({"1": "test", "0": "validation", "True": "test", "False": "validation"}) + + +class LoggerStages(Enum): + TRAIN = "train" + VAL = "validation" + TEST = "test" + + +class ResultStoreType(Enum): + INSIDE_BATCH_TRAIN_LOOP = "inside_batch_train_loop" + OUTSIDE_BATCH_TRAIN_LOOP = "outside_batch_train_loop" + + +class HookResultStore: + """ + This class is defined for internal usage. + It holds all metrics logged using the self.log function + in the scope of ModelHooks or Callback functions. + + We need to differentiate 3 different scenarios: + - (1): We are outside of a batch loop + * It means no dataloader_idx, no optimizer idx, etc. + - (2): We are inside the training batch loop + * We have an optimizer idx and split idx to track + - (3): We are inside the evaluation loop + * We have a dataloader_idx to track + + The store holds `Result` objects for those 3 scenarios in `self._internals`. + + (1): self._internals = {"dataloader_idx": [Result(), ..., Result()]} + * if dataloader_idx is not defined, it is set to 0 by default + (2): self._internals = {"dataloader_idx": + {"optimizer_idx": + {"batch_idx": + [Result(), Result()] + } + } + } + (3): Same as (1) for simplicity + + These data structures enable us to properly reduce the Result objects when the batch loop is finished.
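To illustrate scenario (2) concretely, a hypothetical snapshot of `self._internals` after two optimizers have each logged two tbptt splits on batch 0; the stringified keys match `append` below, and the values are placeholders only:

```python
self._internals = {
    "0": {                                 # dataloader_idx
        "0": {"0": [Result(), Result()]},  # optimizer 0, batch 0: one entry per tbptt split
        "1": {"0": [Result(), Result()]},  # optimizer 1, batch 0
    }
}
```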
+ """ + def __init__(self, fx_name): + self._fx_name = fx_name + self._internals = {} + self._internals_reduced = {} + self._internal_type = None + self.has_reduced = False + + def get_reduced_metrics(self): + return self._internals_reduced + + def add_dataloader_idx(self): + return len(self._internals) > 1 + + @property + def num_dataloaders(self): + return len(self._internals) + + def get_latest_from_dict(self, dl_idx): + num_opt_idx = len(self._internals[dl_idx]) - 1 + assert num_opt_idx >= 0 + num_opt_idx = str(num_opt_idx) + num_batch_idx = len(self._internals[dl_idx][num_opt_idx]) - 1 + batch_indexes = [*self._internals[dl_idx][num_opt_idx].keys()] + # sort them by increasing order + batch_indexes.sort(key=float) + assert num_batch_idx >= 0 + return self._internals[dl_idx][num_opt_idx][batch_indexes[-1]][-1] + + def check_dataloader_idx(self, result: Result) -> bool: + add_dataloader_idx = False + try: + if len(result.keys()) > 1: + random_key = [*result.keys()][-1] + add_dataloader_idx = result["meta"][random_key]["dataloader_idx"] is not None + return add_dataloader_idx + return add_dataloader_idx + except Exception: + return add_dataloader_idx + + def get_lastest_from_func_name(self, func_name, *args, latest=True, **kwargs): + results = {} + if latest: + for dl_idx in range(self.num_dataloaders): + dl_idx = str(dl_idx) + if self._internal_type == ResultStoreType.OUTSIDE_BATCH_TRAIN_LOOP: + latest_result = self._internals[dl_idx][-1] + else: + latest_result = self.get_latest_from_dict(dl_idx) + add_dataloader_idx = self.check_dataloader_idx(latest_result) + func = getattr(latest_result, func_name) + results.update(func(*args, add_dataloader_idx=add_dataloader_idx, **kwargs)) + return results + raise NotImplementedError + + def get_batch_pbar_metrics(self, latest=True, *args, **kwargs): + return self.get_lastest_from_func_name("get_batch_pbar_metrics", *args, latest=latest, **kwargs) + + def get_batch_log_metrics(self, latest=True, *args, **kwargs): + return self.get_lastest_from_func_name("get_batch_log_metrics", *args, latest=latest, **kwargs) + + def run_epoch_func(self, results, opt_metric, func_name, *args, **kwargs) -> None: + if isinstance(opt_metric, Result): + func = getattr(opt_metric, func_name) + metrics_to_log = func( + *args, + add_dataloader_idx=self.add_dataloader_idx, + **kwargs) + results.update(metrics_to_log) + else: + raise Exception("The provided opt_metric should be a Result Object. 
Something is wrong") + + def get_epoch_from_func_name(self, func_name, *args, **kwargs) -> Mapping: + results = {} + for dl_idx in range(self.num_dataloaders): + dl_idx = str(dl_idx) + opt_metrics = self._internals_reduced[dl_idx] + if isinstance(opt_metrics, defaultdict): + for opt_metric in opt_metrics.values(): + self.run_epoch_func(results, opt_metric, func_name, *args, **kwargs) + else: + self.run_epoch_func(results, opt_metrics, func_name, *args, **kwargs) + return results + + def get_epoch_pbar_metrics(self, *args, **kwargs) -> Mapping: + return self.get_epoch_from_func_name("get_epoch_pbar_metrics") + + def get_epoch_log_metrics(self, *args, **kwargs) -> Mapping: + return self.get_epoch_from_func_name("get_epoch_log_metrics") + + def get_forked_metrics(self, *args, **kwargs) -> Mapping: + return self.get_epoch_from_func_name("get_forked_metrics") + + @staticmethod + def _append_to_structure(primary_dict, opt_idx, batch_idx, result) -> None: + if opt_idx not in primary_dict: + primary_dict[opt_idx] = {} + + if batch_idx not in primary_dict[opt_idx]: + primary_dict[opt_idx][batch_idx] = [] + + primary_dict[opt_idx][batch_idx].append(result) + + def append(self, result, dataloader_idx=None, extra_info: dict = {}) -> None: + + assert isinstance(result, Result) + + if dataloader_idx is None: + dataloader_idx = 0 + + primary_key = f"{dataloader_idx}" + + # [dataloader_idx][optimizer_idx][training_step_idx] is a list + if len(extra_info) > 0: + self._internal_type = ResultStoreType.INSIDE_BATCH_TRAIN_LOOP + # initialize dictionary + if primary_key not in self._internals: + self._internals[primary_key] = {} + self._internals_reduced[primary_key] = defaultdict(dict) + + # extract infos + opt_idx = str(extra_info["opt_idx"]) + batch_idx = str(extra_info["batch_idx"]) + + self._append_to_structure(self._internals[primary_key], opt_idx, batch_idx, result) + + # [dataloader_idx] is a list + else: + self._internal_type = ResultStoreType.OUTSIDE_BATCH_TRAIN_LOOP + if primary_key not in self._internals: + self._internals[primary_key] = [] + self._internals[primary_key].append(result) + + def auto_reduce_results_on_epoch_end(self) -> None: + """ + This function is called to reduce `self._internals` Result object. + The reduced Result object will be saved into `self._internals_reduced` + The `self._internals` stored Result objects will be deleted to save memory. 
+ """ + if not self.has_reduced: + epoch_log_metrics = {} + epoch_progress_bar_metrics = {} + + for dl_idx in range(self.num_dataloaders): + dl_idx = str(dl_idx) + epoch_metrics = self._internals[dl_idx] + + if self._internal_type == ResultStoreType.INSIDE_BATCH_TRAIN_LOOP: + + num_opt_idx = len(self._internals[dl_idx]) - 1 + + # Make sure we didn't create key + assert num_opt_idx >= 0 + + for opt_idx in range(num_opt_idx + 1): + opt_idx = str(opt_idx) + # TODO: Figure out to reduce memory + # TODO: How to start training in middle of epoch + opt_outputs = epoch_metrics[opt_idx] + + num_batch_idx = len(self._internals[dl_idx][str(num_opt_idx)]) - 1 + assert num_batch_idx >= 0 + batch_indexes = self._internals[dl_idx][str(num_opt_idx)].keys() + + # reduce across time first + time_reduced_outputs = [] + for batch_idx in batch_indexes: + batch_idx = str(batch_idx) + tbptt_outs = opt_outputs[str(batch_idx)] + tbptt_outs = tbptt_outs[0].__class__.reduce_across_time(tbptt_outs) + if len(tbptt_outs) > 1: + time_reduced_outputs.append(tbptt_outs) + + if len(time_reduced_outputs) == 0: + continue + + # reduce across training steps + opt_outputs = time_reduced_outputs[0].__class__.reduce_on_epoch_end(time_reduced_outputs) + + # with manual opt need 1 + metrics because meta is always there + if opt_outputs.minimize is not None: + opt_outputs.minimize = opt_outputs.minimize.mean() + + self._internals_reduced[dl_idx][str(opt_idx)] = opt_outputs + + # free memory + del self._internals[dl_idx] + else: + # no need to reduce as called only once + if len(epoch_metrics) == 1: + reduced_epoch_metrics = epoch_metrics[0] + else: + reduced_epoch_metrics = epoch_metrics[0].__class__.reduce_on_epoch_end(epoch_metrics) + + self._internals_reduced[dl_idx] = reduced_epoch_metrics + + # free memory + del self._internals[dl_idx] + + self.has_reduced = True + + def __getitem__(self, key: str) -> Any: + try: + if key in self._internals: + return self._internals[key] + return self[key] + except KeyError: + return None + + def __repr__(self): + return self._internals.__repr__() + + +class EpochResultStore: + """ + This class is defined for internal usage. + + It holds all metrics logged using the self.log function using `HookResultStore` object. + + The internal datastructure is as follow: + + self._internals = {"fx_name_0": HookResultStore(), ..., "fx_name_n": HookResultStore()} + + Pseudo Code Example: + ``` + model._current_fx_name = 'something' + model._results = Result() + model.log('a', ...) 
+ epoch_result_store.cache_result() + ``` + + """ + def __init__(self, trainer, stage): + self.trainer = trainer + self._stage = stage + self.reset() + + def __getitem__(self, key: str) -> Any: + try: + if key in self._internals: + return self._internals[key] + return None + except KeyError: + return None + + @property + def has_split_and_opt_idx(self): + """ + This property tells whether we are running within the training batch loop + """ + if self._split_idx is not None and self._opt_idx is not None: + return True + return False + + @property + def extra_info(self): + """ + This property provides the parameters needed to properly configure a HookResultStore object + """ + return {"batch_idx": self.trainer.batch_idx, + "split_idx": self._split_idx, + "opt_idx": self._opt_idx} + + def reset_model(self): + """ + This function is used to reset the model state at the end of the capture + """ + model_ref = self.trainer.get_model() + model_ref._results = Result() + model_ref._current_hook_fx_name = None + model_ref._current_fx_name = '' + + def current_model_info(self): + """ + This function is used to extract + information related to the current function scoping the `self.log` call. + """ + model_ref = self.trainer.get_model() + # extract hook information + fx_name = model_ref._current_hook_fx_name + if fx_name == '': + fx_name = model_ref._current_fx_name + dataloader_idx = model_ref._current_dataloader_idx + return fx_name, dataloader_idx + + def cache_result(self) -> None: + """ + This function is called after every hook + and stores the result object + """ + model_ref = self.trainer.get_model() + + # extract hook results + hook_result = model_ref._results + + # extract model information + fx_name, dataloader_idx = self.current_model_info() + + # add only if anything has been logged + # default len is 1 due to _internals + if len(hook_result) > 1: + + if fx_name not in self._internals: + self._internals[fx_name] = HookResultStore(fx_name) + + extra_info = {} + if self.has_split_and_opt_idx: + extra_info = self.extra_info + + # attach the captured batch_size + Result.attach_batch_size(self._batch_size, hook_result) + + self._internals[fx_name].append( + deepcopy(hook_result), + dataloader_idx=dataloader_idx, + extra_info=extra_info) + + # update logged_metrics, progress_bar_metrics, callback_metrics + self.update_logger_connector(fx_name) + + # reset _results, fx_name + self.reset_model() + + def update_logger_connector(self, fx_name: str = None) -> None: + """ + This function is called every time we capture a hook + It automatically updates the following logger_connector attributes: + - progress_bar_metrics with pbar_metrics + - logged_metrics with log_metrics + - callback_metrics with progress_bar_metrics + logged_metrics + """ + + logger_connector = self.trainer.logger_connector + + callback_metrics = {} + + if not self._has_batch_loop_finished: + # get pbar + batch_pbar_metrics = self.get_latest_batch_pbar_metrics() + logger_connector.add_progress_bar_metrics(batch_pbar_metrics) + + if self._stage in LoggerStages.TRAIN.value: + # Only log and add to callback epoch step during evaluation, test.
+ batch_log_metrics = self.get_latest_batch_log_metrics() + logger_connector.logged_metrics.update(batch_log_metrics) + + callback_metrics.update(batch_pbar_metrics) + callback_metrics.update(batch_log_metrics) + else: + epoch_dict = {"epoch": self.trainer.current_epoch} + + # get pbar + epoch_pbar_metrics = self.get_epoch_pbar_metrics() + logger_connector.add_progress_bar_metrics(epoch_pbar_metrics) + + # get logged_metrics + epoch_log_metrics = self.get_epoch_log_metrics() + logger_connector.logged_metrics.update(epoch_log_metrics) + logger_connector.logged_metrics.update(epoch_dict) + + # get forked_metrics + forked_metrics = self.get_forked_metrics() + + callback_metrics.update(epoch_pbar_metrics) + callback_metrics.update(epoch_log_metrics) + callback_metrics.update(forked_metrics) + + # update callback_metrics + logger_connector.callback_metrics.update(callback_metrics) + logger_connector.callback_metrics.pop("epoch", None) + + def run_batch_from_func_name(self, func_name) -> Mapping: + results = {} + for fx_name, hook_result in self._internals.items(): + func = getattr(hook_result, func_name) + results.update(func(latest=True, include_forked_originals=False)) + return results + + def get_latest_batch_log_metrics(self) -> Mapping: + return self.run_batch_from_func_name("get_batch_log_metrics") + + def get_latest_batch_pbar_metrics(self) -> Mapping: + return self.run_batch_from_func_name("get_batch_pbar_metrics") + + @property + def has_reduced(self) -> bool: + hook_results = self._internals.values() + return len(hook_results) == sum([h.has_reduced for h in hook_results]) + + def auto_reduce_results_on_epoch_end(self) -> None: + if not self.has_reduced: + for fx_name, hook_result in self._internals.items(): + hook_result.auto_reduce_results_on_epoch_end() + + @property + def has_batch_loop_finished(self) -> bool: + return self._has_batch_loop_finished + + @has_batch_loop_finished.setter + def has_batch_loop_finished(self, has_batch_loop_finished): + if has_batch_loop_finished: + # If batch loop has finished, reduce metrics + self.auto_reduce_results_on_epoch_end() + + # batch_size should be none as we finished batch loop + self._batch_size = None + + self._has_batch_loop_finished = has_batch_loop_finished + self.update_logger_connector() + + def run_epoch_by_func_name(self, func_name) -> Mapping: + if not self.has_reduced: + self.auto_reduce_results_on_epoch_end() + results = {} + for fx_name, hook_result in self._internals.items(): + func = getattr(hook_result, func_name) + results.update(func()) + return results + + def get_epoch_pbar_metrics(self) -> Mapping: + return self.run_epoch_by_func_name("get_epoch_pbar_metrics") + + def get_epoch_log_metrics(self) -> Mapping: + return self.run_epoch_by_func_name("get_epoch_log_metrics") + + def get_forked_metrics(self) -> Mapping: + return self.run_epoch_by_func_name("get_forked_metrics") + + def get_reduced_metrics(self) -> Mapping: + return self.run_epoch_by_func_name("get_reduced_metrics") + + def reset(self): + self._internals = {} + self._dataloader_idx: Union[int, None] = None + self._split_idx: Union[int, None] = None + self._opt_idx: Union[int, None] = None + self._batch_size: Union[int, None] = None + self._has_batch_loop_finished = False + + def __repr__(self): + return f"{self.__class__.__name__}(stage={self._stage}, internals={self._internals})" diff --git a/pytorch_lightning/trainer/connectors/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py similarity index 84% rename from 
pytorch_lightning/trainer/connectors/logger_connector.py rename to pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 893eab5a16a3d..5c699ecffa464 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from pprint import pprint +from typing import Iterable, Union, cast +from copy import deepcopy +from collections import ChainMap import torch from pytorch_lightning.core import memory from pytorch_lightning.loggers import TensorBoardLogger, LoggerCollection @@ -19,10 +23,12 @@ from pytorch_lightning.utilities.model_utils import is_overridden from pytorch_lightning.core.step_result import EvalResult, Result from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pprint import pprint -from typing import Iterable -from copy import deepcopy -from collections import ChainMap +from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator +from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import ( + EpochResultStore, + LoggerStages, + LOOKUP_TABLE +) class LoggerConnector: @@ -33,6 +39,70 @@ def __init__(self, trainer): self.logged_metrics = {} self.progress_bar_metrics = {} self.eval_loop_results = [] + self._stages = sorted([s.value for s in LoggerStages]) + self._cached_results = {stage: EpochResultStore(trainer, stage) for stage in self._stages} + self._callback_hook_validator = CallbackHookNameValidator() + self._current_stage = None + + def cached_results(self, stage_or_testing: Union[str, bool]) -> Union[EpochResultStore, None]: + """ Function to access cached_results using str or bool. 
Bool is used only for testing""" + stage_or_testing = str(stage_or_testing) + stages = self._stages + if stage_or_testing in self._stages: + return self._cached_results[stage_or_testing] + if stage_or_testing in LOOKUP_TABLE: + # Access using trainer.testing + stage = LOOKUP_TABLE[stage_or_testing] + return self._cached_results[stage] + raise MisconfigurationException( + f"Provided stage_or_testing {stage_or_testing} doesn't belong to either {self._stages}" + f" or {LOOKUP_TABLE.keys()}" + ) + + def set_stage(self, stage_or_testing: str, reset: bool = False) -> None: + self._current_stage = self._determine_stage(stage_or_testing) + if reset: + self.cached_results(stage_or_testing).reset() + + def check_logging_in_callbacks(self, hook_fx_name, on_step: bool = None, on_epoch: bool = None) -> None: + self._callback_hook_validator.check_logging_in_callbacks(current_hook_fx_name=hook_fx_name, + on_step=on_step, + on_epoch=on_epoch) + + def on_evaluation_batch_start(self, testing, batch, dataloader_idx, num_dataloaders): + # reset the result of the PL module + model = self.trainer.get_model() + model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None + + # track batch_size + self.cached_results(testing)._batch_size = Result.extract_batch_size(batch) + + def on_batch_start(self, split_idx: int, opt_idx: int, split_batch) -> None: + self._cached_results["train"]._split_idx = split_idx + self._cached_results["train"]._opt_idx = opt_idx + self._cached_results["train"]._batch_size = Result.extract_batch_size(split_batch) + + def on_train_batch_end(self) -> None: + self._cached_results["train"]._split_idx = None + self._cached_results["train"]._opt_idx = None + self._cached_results["train"]._batch_size = None + + def _determine_stage(self, stage_or_testing: Union[str, bool]) -> str: + stage_or_testing = str(stage_or_testing) + stages = self._stages + if stage_or_testing in stages: + return stage_or_testing + if stage_or_testing in LOOKUP_TABLE: + # Access using trainer.testing + return LOOKUP_TABLE[stage_or_testing] + raise MisconfigurationException( + f"Provided stage_or_testing {stage_or_testing} doesn't belong to either {stages}" + f" or {LOOKUP_TABLE.keys()}" + ) + + def cache_logged_metrics(self) -> None: + if self._current_stage is not None: + self._cached_results[self._current_stage].cache_result() def on_trainer_init(self, logger, flush_logs_every_n_steps, log_every_n_steps): # logging @@ -179,12 +249,15 @@ def _log_on_evaluation_epoch_end_metrics(self, epoch_logs): continue reduced_epoch_metrics = dl_metrics[0].__class__.reduce_on_epoch_end(dl_metrics) - # make the keys 'k/dl' - reduced_epoch_metrics = self.__rename_keys_by_dataloader_idx(reduced_epoch_metrics, dl_idx, num_loaders) # track the metrics logger_metrics = reduced_epoch_metrics.get_epoch_log_metrics() pbar_metrics = reduced_epoch_metrics.get_epoch_pbar_metrics() + + # make the keys 'k/dl' + logger_metrics = self.__rename_keys_by_dataloader_idx(logger_metrics, dl_idx, num_loaders) + pbar_metrics = self.__rename_keys_by_dataloader_idx(pbar_metrics, dl_idx, num_loaders) + self.logged_metrics.update(logger_metrics) self.add_progress_bar_metrics(pbar_metrics) @@ -229,6 +302,7 @@ def _track_callback_metrics(self, eval_results, using_eval_result): else: self.trainer.logger_connector.callback_metrics.update(eval_results.callback_metrics) else: + flat = {} if isinstance(eval_results, list): for eval_result in eval_results: # with a scalar return, auto set it to "val_loss" for callbacks @@ -451,8 
+525,7 @@ def __auto_reduce_results_on_epoch_end(self, epoch_output): for opt_outputs in epoch_output: # reduce across time first time_reduced_outputs = [] - for train_step_idx in range(len(opt_outputs)): - tbptt_outs = opt_outputs[train_step_idx] + for tbptt_outs in opt_outputs: tbptt_outs = tbptt_outs[0].__class__.reduce_across_time(tbptt_outs) if len(tbptt_outs) > 1: time_reduced_outputs.append(tbptt_outs) @@ -482,8 +555,7 @@ def __prepare_epoch_end_inputs(self, epoch_output): for opt_outputs in epoch_output: # gather across time first time_gathered_outputs = [] - for train_step_idx in range(len(opt_outputs)): - tbptt_outs = opt_outputs[train_step_idx] + for tbptt_outs in opt_outputs: result = [] for x in tbptt_outs: out = x.extra @@ -511,8 +583,7 @@ def __gather_result_across_time_and_optimizers(self, epoch_output): for opt_outputs in epoch_output: # gather across time first time_gathered_outputs = [] - for train_step_idx in range(len(opt_outputs)): - tbptt_outs = opt_outputs[train_step_idx] + for tbptt_outs in opt_outputs: tbptt_outs = tbptt_outs[0].__class__.gather(tbptt_outs) time_gathered_outputs.append(tbptt_outs) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 9dab036583dd8..89a242dbfd886 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -29,6 +29,7 @@ def __init__(self, trainer): self.predictions = None self.max_batches = None self.warning_cache = WarningCache() + self.num_dataloaders = None def on_trainer_init(self): self.trainer.num_val_batches = [] @@ -108,6 +109,9 @@ def on_evaluation_end(self, *args, **kwargs): else: self.trainer.call_hook('on_validation_end', *args, **kwargs) + # reset stage to train + self.trainer.logger_connector.set_stage("train") + def reload_evaluation_dataloaders(self): model = self.trainer.get_model() if self.testing: @@ -133,6 +137,7 @@ def setup(self, model, max_batches, dataloaders): max_batches = [max_batches] * len(dataloaders) self.max_batches = max_batches + self.num_dataloaders = self._get_num_dataloaders(dataloaders) def on_evaluation_epoch_start(self, *args, **kwargs): if self.testing: @@ -250,9 +255,10 @@ def __run_eval_epoch_end(self, num_dataloaders, using_eval_result): # depre warning if eval_results is not None and user_reduced: step = 'testing_epoch_end' if self.testing else 'validation_epoch_end' - m = f'The {step} should not return anything as of 9.1.' \ - f'to log, use self.log(...) or self.write(...) directly in the LightningModule' - self.warning_cache.warn(m) + self.warning_cache.warn( + f'The {step} should not return anything as of 9.1.' + ' To log, use self.log(...) or self.write(...) 
directly in the LightningModule' + ) if using_eval_result and not user_reduced: eval_results = self.__auto_reduce_result_objs(outputs) @@ -292,16 +298,20 @@ def __auto_reduce_result_objs(self, outputs): return eval_results - def on_evaluation_batch_start(self, *args, **kwargs): + def on_evaluation_batch_start(self, batch, batch_idx, dataloader_idx): # reset the result of the PL module model = self.trainer.get_model() model._results = Result() model._current_fx_name = 'evaluation_step' + # set dataloader_idx and track batch_size + self.trainer.logger_connector.on_evaluation_batch_start( + self.testing, batch, dataloader_idx, self.num_dataloaders) + if self.testing: - self.trainer.call_hook('on_test_batch_start', *args, **kwargs) + self.trainer.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx) else: - self.trainer.call_hook('on_validation_batch_start', *args, **kwargs) + self.trainer.call_hook('on_validation_batch_start', batch, batch_idx, dataloader_idx) def on_evaluation_batch_end(self, *args, **kwargs): if self.testing: diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index b585647fb5a0e..ae4d280d54649 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -14,7 +14,7 @@ from abc import ABC import inspect -from typing import Union, Iterable +from typing import Union, Iterable, Mapping import torch @@ -92,7 +92,7 @@ def process_dict_result(self, output, train=False): # --------------- # all keys not progress_bar or log are candidates for callbacks callback_metrics = {} - if output: + if isinstance(output, Mapping): for k, v in output.items(): if k not in ['progress_bar', 'log', 'hiddens']: callback_metrics[k] = v @@ -156,7 +156,7 @@ def process_dict_result(self, output, train=False): # --------------- # EXTRACT HIDDEN # --------------- - hiddens = output.get('hiddens') if output else None + hiddens = output.get('hiddens', None) if isinstance(output, Mapping) else None # use every metric passed in as a candidate for callback callback_metrics.update(progress_bar_metrics) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 8d509d41d52bf..42b5f7a36641c 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -15,7 +15,7 @@ import os from abc import ABC from argparse import ArgumentParser, Namespace -from typing import List, Optional, Union, Type, TypeVar +from typing import List, Optional, Union, Type, TypeVar, cast from pytorch_lightning.callbacks import Callback, ProgressBarBase, ModelCheckpoint from pytorch_lightning.core.lightning import LightningModule @@ -154,6 +154,7 @@ def progress_bar_callback(self): def progress_bar_dict(self) -> dict: """ Read-only for progress bar metrics. 
""" ref_model = self.model if not self.data_parallel else self.model.module + ref_model = cast(LightningModule, ref_model) return dict(**ref_model.get_progress_bar_dict(), **self.logger_connector.progress_bar_metrics) @property diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 008633273a0d1..d3cc2f2e7278f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -22,7 +22,8 @@ from pytorch_lightning.callbacks import Callback, ModelCheckpoint from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.core.step_result import EvalResult +from pytorch_lightning.core.memory import ModelSummary +from pytorch_lightning.core.step_result import Result, EvalResult from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.profiler import BaseProfiler from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin @@ -174,7 +175,7 @@ def __init__( :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`. Default: ``True``. .. warning:: Passing a ModelCheckpoint instance to this argument is deprecated since - v1.1.0 and will be unsupported from v1.4.0. + v1.1.0 and will be unsupported from v1.3.0. check_val_every_n_epoch: Check val every n train epochs. @@ -465,6 +466,9 @@ def fit( def train(self): self.run_sanity_check(self.get_model()) + # set stage for logging + self.logger_connector.set_stage("train") + self.checkpoint_connector.has_trained = False # enable train mode @@ -528,16 +532,25 @@ def train(self): self.train_loop.on_train_end() def run_evaluation(self, test_mode: bool = False, max_batches=None): + + # used to know if we are logging for val, test + reset cached results + self.logger_connector.set_stage(test_mode, reset=True) + # bookkeeping self.evaluation_loop.testing = test_mode + + # prepare dataloaders dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(max_batches) + + # check if we want to skip this evaluation if self.evaluation_loop.should_skip_evaluation(dataloaders, max_batches): return [], [] - # enable eval mode + no grads + # ref model model = self.get_model() - self.evaluation_loop.on_evaluation_model_eval() + # enable eval mode + no grads + self.evaluation_loop.on_evaluation_model_eval() model.zero_grad() torch.set_grad_enabled(False) @@ -701,6 +714,8 @@ def test( # -------------------- self.verbose_test = verbose + self.logger_connector.set_stage("test") + # If you supply a datamodule you can't supply train_dataloader or val_dataloaders if test_dataloaders and datamodule: raise MisconfigurationException( diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index d1dfb3eec3733..10bab2843b7eb 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -652,31 +652,18 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): if response == -1: return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic) - # checks if backward or backward + optimizer step (via closure) - accumulation_done = self._accumulated_batches_reached() - is_final_batch = self._num_training_batches_reached() - should_accumulate = not (accumulation_done or is_final_batch) - # lightning module hook splits = self.tbptt_split_batch(batch) for split_idx, split_batch in enumerate(splits): - self.trainer.split_idx = split_idx - - # in manual optimization we loop over all 
optimizers at once - optimizers = self.get_optimizers_iterable() - if not self.automatic_optimization: - optimizers = [optimizers[0]] - - # loop over optimizers - for opt_idx, optimizer in optimizers: - # make sure only the gradients of the current optimizer's parameters are calculated - # in the training step to prevent dangling gradients in multiple-optimizer setup. - if self.automatic_optimization and len(self.trainer.optimizers) > 1: - model = self.trainer.get_model() - model.toggle_optimizer(optimizer, opt_idx) - - if should_accumulate: + + # create an iterable for optimizers and loop over them + for opt_idx, optimizer in self.prepare_optimizers(): + + # toggle model params + set info to logger_connector + self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) + + if self.should_accumulate(): # For gradient accumulation # ------------------- @@ -729,6 +716,7 @@ def train_step_and_backward_closure(): opt_idx=opt_idx, ) + # todo: Properly aggregate grad_norm across opt_idx and split_idx grad_norm_dic = self._cur_grad_norm_dict self._cur_grad_norm_dict = None @@ -738,14 +726,8 @@ # clear gradients self.optimizer_zero_grad(batch_idx, optimizer, opt_idx) - accumulated_loss = self.accumulated_loss.mean() - - if accumulated_loss is not None: - # calculate running loss for display - self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches) - - # reset for next set of accumulated grads - self.accumulated_loss.reset() + # update running loss + reset accumulated loss + self.update_running_loss() # collapse all metrics into one dict batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()} @@ -767,7 +749,7 @@ @contextmanager def block_ddp_sync_behaviour(self): if isinstance(self.trainer.model, torch.nn.parallel.DistributedDataParallel): - yield from self.trainer.model.no_sync() + yield self.trainer.model.no_sync() else: yield @@ -817,8 +799,10 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, with self.trainer.profiler.profile("model_backward"): self.backward(result, optimizer, opt_idx) - # hook - self.on_after_backward(result.training_step_output, batch_idx, result.loss) + # hook - call this hook only + # when gradients have finished accumulating + if not self.should_accumulate(): + self.on_after_backward(result.training_step_output, batch_idx, result.loss) # check if loss or model weights are nan if self.trainer.terminate_on_nan: @@ -837,6 +821,10 @@ def backward(self, result, optimizer, opt_idx, *args, **kwargs): result.closure_loss, optimizer, opt_idx, *args, **kwargs ) + if not self.should_accumulate(): + # track gradients + self.track_and_norm_grad(optimizer=optimizer) + def update_train_loop_lr_schedulers(self, monitor_metrics=None): num_accumulated_batches_reached = self._accumulated_batches_reached() num_training_batches_reached = self._num_training_batches_reached() @@ -863,6 +851,12 @@ def _accumulated_batches_reached(self): def _num_training_batches_reached(self): return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches + def should_accumulate(self): + # checks if backward or backward + optimizer step (via closure) + accumulation_done = self._accumulated_batches_reached() + is_final_batch = self._num_training_batches_reached() + return not (accumulation_done or is_final_batch) + def should_check_val_fx(self, batch_idx, is_last_batch): # decide if we should run validation is_val_check_batch 
= (batch_idx + 1) % self.trainer.val_check_batch == 0 @@ -934,3 +928,33 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accu epoch_end_outputs.append(optimizer_idx_outputs) return epoch_end_outputs + + def prepare_optimizers(self): + # in manual optimization we loop over all optimizers at once + optimizers = self.get_optimizers_iterable() + if not self.automatic_optimization: + optimizers = [optimizers[0]] + return optimizers + + def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer): + # set split_idx to trainer for tracking + self.trainer.split_idx = split_idx + + # make sure only the gradients of the current optimizer's parameters are calculated + # in the training step to prevent dangling gradients in multiple-optimizer setup. + if self.automatic_optimization and len(self.trainer.optimizers) > 1: + model = self.trainer.get_model() + model.toggle_optimizer(optimizer, opt_idx) + + # use to track metrics internally + self.trainer.logger_connector.on_batch_start(split_idx, opt_idx, split_batch) + + def update_running_loss(self): + accumulated_loss = self.accumulated_loss.mean() + + if accumulated_loss is not None: + # calculate running loss for display + self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches) + + # reset for next set of accumulated grads + self.accumulated_loss.reset() diff --git a/pytorch_lightning/utilities/xla_device_utils.py b/pytorch_lightning/utilities/xla_device_utils.py index 5687992981ae6..14a59fd105c5a 100644 --- a/pytorch_lightning/utilities/xla_device_utils.py +++ b/pytorch_lightning/utilities/xla_device_utils.py @@ -13,6 +13,7 @@ # limitations under the License. import functools import importlib +import queue as q from multiprocessing import Process, Queue import torch @@ -24,10 +25,10 @@ xm = None -def inner_f(queue, func, **kwargs): # pragma: no cover +def inner_f(queue, func, *args, **kwargs): # pragma: no cover try: - queue.put(func(**kwargs)) - except Exception as _e: + queue.put(func(*args, **kwargs)) + except Exception: import traceback traceback.print_exc() @@ -38,10 +39,13 @@ def pl_multi_process(func): @functools.wraps(func) def wrapper(*args, **kwargs): queue = Queue() - proc = Process(target=inner_f, args=(queue, func,), kwargs=kwargs) + proc = Process(target=inner_f, args=(queue, func, *args), kwargs=kwargs) proc.start() - proc.join() - return queue.get() + proc.join(10) + try: + return queue.get_nowait() + except q.Empty: + return False return wrapper diff --git a/requirements.txt b/requirements.txt index 0f8423e0860f0..d270e2bc5d854 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ # the default package dependencies numpy>=1.16.4 -torch>=1.3 +torch>=1.3,<1.8 future>=0.17.1 # required for builtins in setup.py # pyyaml>=3.13 PyYAML>=5.1 # OmegaConf requirement >=5.1 diff --git a/requirements/examples.txt b/requirements/examples.txt index c87d10a39346f..e930579b8b369 100644 --- a/requirements/examples.txt +++ b/requirements/examples.txt @@ -1,2 +1,2 @@ -torchvision>=0.4.1 +torchvision>=0.4.1,<0.9.0 gym>=0.17.0 \ No newline at end of file diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 19705a6ebc9a2..d3d5f67bcfeaa 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -100,7 +100,7 @@ def test_model_checkpoint_to_yaml(tmpdir, save_top_k): path_yaml = os.path.join(tmpdir, 'best_k_models.yaml') checkpoint.to_yaml(path_yaml) d 
= yaml.full_load(open(path_yaml, 'r')) - best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()} + best_k = {k: v for k, v in checkpoint.best_k_models.items()} assert d == best_k @@ -185,67 +185,72 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir): def test_model_checkpoint_format_checkpoint_name(tmpdir): # empty filename: - ckpt_name = ModelCheckpoint._format_checkpoint_name('', 3, {}) - assert ckpt_name == 'epoch=3' + ckpt_name = ModelCheckpoint._format_checkpoint_name('', 3, 2, {}) + assert ckpt_name == 'epoch=3-step=2' - ckpt_name = ModelCheckpoint._format_checkpoint_name(None, 3, {}, prefix='test') - assert ckpt_name == 'test-epoch=3' + ckpt_name = ModelCheckpoint._format_checkpoint_name(None, 3, 2, {}, prefix='test') + assert ckpt_name == 'test-epoch=3-step=2' # no groups case: - ckpt_name = ModelCheckpoint._format_checkpoint_name('ckpt', 3, {}, prefix='test') + ckpt_name = ModelCheckpoint._format_checkpoint_name('ckpt', 3, 2, {}, prefix='test') assert ckpt_name == 'test-ckpt' # no prefix - ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch:03d}-{acc}', 3, {'acc': 0.03}) + ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch:03d}-{acc}', 3, 2, {'acc': 0.03}) assert ckpt_name == 'epoch=003-acc=0.03' # prefix char_org = ModelCheckpoint.CHECKPOINT_JOIN_CHAR ModelCheckpoint.CHECKPOINT_JOIN_CHAR = '@' - ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch},{acc:.5f}', 3, {'acc': 0.03}, prefix='test') + ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch},{acc:.5f}', 3, 2, {'acc': 0.03}, prefix='test') assert ckpt_name == 'test@epoch=3,acc=0.03000' ModelCheckpoint.CHECKPOINT_JOIN_CHAR = char_org # no dirpath set - ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath=None).format_checkpoint_name(3, {}) - assert ckpt_name == 'epoch=3.ckpt' - ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='').format_checkpoint_name(5, {}) - assert ckpt_name == 'epoch=5.ckpt' + ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath=None).format_checkpoint_name(3, 2, {}) + assert ckpt_name == 'epoch=3-step=2.ckpt' + ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='').format_checkpoint_name(5, 4, {}) + assert ckpt_name == 'epoch=5-step=4.ckpt' # CWD - ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='.').format_checkpoint_name(3, {}) - assert ckpt_name == str(Path('.').resolve() / 'epoch=3.ckpt') + ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='.').format_checkpoint_name(3, 4, {}) + assert ckpt_name == str(Path('.').resolve() / 'epoch=3-step=4.ckpt') # with ver ckpt_name = ModelCheckpoint( monitor='early_stop_on', dirpath=tmpdir, filename='name', prefix='test' - ).format_checkpoint_name(3, {}, ver=3) + ).format_checkpoint_name(3, 2, {}, ver=3) assert ckpt_name == tmpdir / 'test-name-v3.ckpt' # using slashes ckpt_name = ModelCheckpoint( monitor='early_stop_on', dirpath=None, filename='{epoch}_{val/loss:.5f}' - ).format_checkpoint_name(4, {'val/loss': 0.03}) + ).format_checkpoint_name(4, 3, {'val/loss': 0.03}) assert ckpt_name == 'epoch=4_val/loss=0.03000.ckpt' # TODO: Checks with filepath. 
To be removed in v1.2 # CWD - ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='.').format_checkpoint_name(3, {}) - assert ckpt_name == str(Path('.').resolve() / 'epoch=3.ckpt') + ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='.').format_checkpoint_name(3, 2, {}) + assert ckpt_name == str(Path('.').resolve() / 'epoch=3-step=2.ckpt') # dir does not exist so it is used as filename filepath = tmpdir / 'dir' - ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, {}) + ckpt_name = ModelCheckpoint( + monitor='early_stop_on', filepath=filepath, prefix='test' + ).format_checkpoint_name(3, 2, {}) assert ckpt_name == tmpdir / 'test-dir.ckpt' # now, dir exists os.mkdir(filepath) - ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, {}) - assert ckpt_name == filepath / 'test-epoch=3.ckpt' + ckpt_name = ModelCheckpoint( + monitor='early_stop_on', filepath=filepath, prefix='test' + ).format_checkpoint_name(3, 2, {}) + assert ckpt_name == filepath / 'test-epoch=3-step=2.ckpt' def test_model_checkpoint_save_last(tmpdir): """Tests that save_last produces only one last checkpoint.""" + seed_everything() model = EvalModelTemplate() epochs = 3 ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}' @@ -257,10 +262,15 @@ def test_model_checkpoint_save_last(tmpdir): logger=False, ) trainer.fit(model) - last_filename = model_checkpoint._format_checkpoint_name(ModelCheckpoint.CHECKPOINT_NAME_LAST, epochs - 1, {}) + last_filename = model_checkpoint._format_checkpoint_name( + ModelCheckpoint.CHECKPOINT_NAME_LAST, trainer.current_epoch, trainer.global_step, {} + ) last_filename = last_filename + '.ckpt' assert str(tmpdir / last_filename) == model_checkpoint.last_model_path - assert set(os.listdir(tmpdir)) == set([f'epoch={i}.ckpt' for i in range(epochs)] + [last_filename]) + assert set(os.listdir(tmpdir)) == set( + [f"epoch={i}-step={j}.ckpt" for i, j in zip(range(epochs), [9, 19, 29])] + [last_filename] + ) + ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last' @@ -295,6 +305,7 @@ def test_none_monitor_save_last(tmpdir): def test_model_checkpoint_none_monitor(tmpdir): """ Test that it is possible to save all checkpoints when monitor=None. 
""" + seed_everything() model = EvalModelTemplate() model.validation_step = model.validation_step_no_monitor model.validation_epoch_end = model.validation_epoch_end_no_monitor @@ -311,13 +322,13 @@ def test_model_checkpoint_none_monitor(tmpdir): # these should not be set if monitor is None assert checkpoint_callback.monitor is None - assert checkpoint_callback.best_model_path == checkpoint_callback.last_model_path == tmpdir / 'epoch=1.ckpt' + assert checkpoint_callback.best_model_path == checkpoint_callback.last_model_path == tmpdir / 'epoch=1-step=19.ckpt' assert checkpoint_callback.best_model_score == 0 assert checkpoint_callback.best_k_models == {} assert checkpoint_callback.kth_best_model_path == '' # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs)] + expected = [f'epoch={i}-step={j}.ckpt' for i, j in zip(range(epochs), [9, 19])] assert set(os.listdir(tmpdir)) == set(expected) @@ -325,13 +336,14 @@ def test_model_checkpoint_none_monitor(tmpdir): def test_model_checkpoint_period(tmpdir, period): model = EvalModelTemplate() epochs = 5 - checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, period=period) + checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}', save_top_k=-1, period=period) trainer = Trainer( default_root_dir=tmpdir, checkpoint_callback=checkpoint_callback, max_epochs=epochs, limit_train_batches=0.1, limit_val_batches=0.1, + val_check_interval=1.0, logger=False, ) trainer.fit(model) @@ -372,12 +384,19 @@ def validation_epoch_end(self, outputs): return {'epoch': self.current_epoch} model = CustomModel() - checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="epoch", mode='max', save_top_k=-1) + checkpoint_callback = ModelCheckpoint( + dirpath=tmpdir, + filename="{epoch}", + monitor="epoch", + mode='max', + save_top_k=-1, + ) trainer = Trainer( default_root_dir=tmpdir, checkpoint_callback=checkpoint_callback, max_epochs=epochs, logger=False, + val_check_interval=1.0, ) trainer.fit(model) @@ -439,7 +458,7 @@ def test_default_checkpoint_behavior(tmpdir): # make sure the checkpoint we saved has the metric in the name ckpts = os.listdir(os.path.join(tmpdir, 'lightning_logs', 'version_0', 'checkpoints')) assert len(ckpts) == 1 - assert ckpts[0] == 'epoch=2.ckpt' + assert ckpts[0] == 'epoch=2-step=14.ckpt' def test_ckpt_metric_names_results(tmpdir): @@ -497,7 +516,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): model = EvalModelTemplate() num_epochs = 3 model_checkpoint = ModelCheckpoint( - monitor='early_stop_on', dirpath=tmpdir, save_top_k=num_epochs, save_last=True + monitor='early_stop_on', dirpath=tmpdir, filename="{epoch}", save_top_k=num_epochs, save_last=True ) trainer = Trainer( default_root_dir=tmpdir, @@ -509,6 +528,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): path_last_epoch = str(tmpdir / f"epoch={num_epochs - 1}.ckpt") path_last = str(tmpdir / "last.ckpt") assert path_last == model_checkpoint.last_model_path + assert os.path.isfile(path_last_epoch) ckpt_last_epoch = torch.load(path_last_epoch) ckpt_last = torch.load(path_last) @@ -779,15 +799,37 @@ def test_configure_model_checkpoint(tmpdir): assert trainer.checkpoint_callback == callback1 assert trainer.checkpoint_callbacks == [callback1, callback2] - with pytest.warns(DeprecationWarning, match='will no longer be supported in v1.4'): + with pytest.warns(DeprecationWarning, match='will no longer be supported in v1.3'): trainer = Trainer(checkpoint_callback=callback1, 
callbacks=[], **kwargs) assert [c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)] == [callback1] assert trainer.checkpoint_callback == callback1 - with pytest.warns(DeprecationWarning, match="will no longer be supported in v1.4"): + with pytest.warns(DeprecationWarning, match="will no longer be supported in v1.3"): trainer = Trainer(checkpoint_callback=callback1, callbacks=[callback2], **kwargs) assert trainer.checkpoint_callback == callback2 assert trainer.checkpoint_callbacks == [callback2, callback1] with pytest.raises(MisconfigurationException, match="checkpoint_callback=False but found ModelCheckpoint"): Trainer(checkpoint_callback=False, callbacks=[callback1], **kwargs) + + +def test_val_check_interval_checkpoint_files(tmpdir): + """ Test correct checkpoint naming when validating/checkpointing multiple times per epoch. """ + model = EvalModelTemplate() + model_checkpoint = ModelCheckpoint( + dirpath=tmpdir, + save_top_k=-1, + monitor="val_acc", + mode="max", + verbose=True + ) + trainer = Trainer( + default_root_dir=tmpdir, + val_check_interval=0.2, + max_epochs=1, + limit_train_batches=10, + callbacks=[model_checkpoint] + ) + trainer.fit(model) + files = sorted([p.name for p in Path(tmpdir).glob("*.ckpt")]) + assert files == [f"epoch=0-step={s}.ckpt" for s in [1, 3, 5, 7, 9]] diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py index 87af510e49219..fc61829645b6e 100644 --- a/tests/loggers/test_comet.py +++ b/tests/loggers/test_comet.py @@ -159,7 +159,7 @@ def test_comet_logger_dirs_creation(comet, comet_experiment, tmpdir, monkeypatch trainer.fit(model) assert trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / "1" / 'checkpoints') - assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'} + assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'} @patch('pytorch_lightning.loggers.comet.comet_ml') diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index db2c353dc4e2c..a200fbf549e6a 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -115,7 +115,7 @@ def test_mlflow_log_dir(client, mlflow, tmpdir): ) trainer.fit(model) assert trainer.checkpoint_callback.dirpath == (tmpdir / "exp-id" / "run-id" / 'checkpoints') - assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'} + assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=0.ckpt'} def test_mlflow_logger_dirs_creation(tmpdir): @@ -143,7 +143,7 @@ def test_mlflow_logger_dirs_creation(tmpdir): assert 'epoch' in os.listdir(tmpdir / exp_id / run_id / 'metrics') assert set(os.listdir(tmpdir / exp_id / run_id / 'params')) == model.hparams.keys() assert trainer.checkpoint_callback.dirpath == (tmpdir / exp_id / run_id / 'checkpoints') - assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'} + assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'} @mock.patch('pytorch_lightning.loggers.mlflow.mlflow') diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index cfb6533bd913b..468ca819f91b1 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -19,7 +19,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import WandbLogger -from tests.base import EvalModelTemplate +from tests.base import EvalModelTemplate, BoringModel @mock.patch('pytorch_lightning.loggers.wandb.wandb') @@ -116,7 +116,7 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir): 
trainer.fit(model) assert trainer.checkpoint_callback.dirpath == str(tmpdir / 'project' / version / 'checkpoints') - assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'} + assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'} def test_wandb_sanitize_callable_params(tmpdir): @@ -135,6 +135,8 @@ def return_something(): def wrapper_something(): return return_something + + params.wrapper_something_wo_name = lambda: lambda: '1' params.wrapper_something = wrapper_something assert isinstance(params.gpus, types.FunctionType) @@ -144,3 +146,4 @@ def wrapper_something(): assert params["gpus"] == '_gpus_arg_default' assert params["something"] == "something" assert params["wrapper_something"] == "wrapper_something" + assert params["wrapper_something_wo_name"] == "" diff --git a/tests/plugins/test_amp_plugin.py b/tests/plugins/test_amp_plugin.py index c0d5747b5fc7e..6fd000b61d97f 100644 --- a/tests/plugins/test_amp_plugin.py +++ b/tests/plugins/test_amp_plugin.py @@ -84,3 +84,65 @@ def on_fit_start(self, trainer, pl_module): with pytest.raises(SystemExit): trainer.fit(model) + + +@pytest.mark.skipif( + LooseVersion(torch.__version__) < LooseVersion("1.6.0"), + reason="Minimal PT version is set to 1.6") +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_amp_gradient_unscale(tmpdir): + + class ExtendedBoringModel(BoringModel): + + def on_after_backward(self): + norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2) + if not (torch.isinf(norm) or torch.isnan(norm)): + assert norm.item() < 15. + + model = ExtendedBoringModel() + + trainer = Trainer( + max_epochs=2, + default_root_dir=os.getcwd(), + limit_train_batches=2, + limit_test_batches=2, + limit_val_batches=2, + amp_backend='native', + distributed_backend='ddp_spawn', + gpus=2, + precision=16, + track_grad_norm=2, + log_every_n_steps=1 + ) + trainer.fit(model) + + +@pytest.mark.skipif( + LooseVersion(torch.__version__) < LooseVersion("1.6.0"), reason="Minimal PT version is set to 1.6") +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_amp_gradient_unscale_accumulate_grad_batches(tmpdir): + + class ExtendedBoringModel(BoringModel): + + def on_after_backward(self): + norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2) + if not (torch.isinf(norm) or torch.isnan(norm)): + assert norm.item() < 15. 
+ + model = ExtendedBoringModel() + + trainer = Trainer( + max_epochs=2, + default_root_dir=os.getcwd(), + limit_train_batches=2, + limit_test_batches=2, + limit_val_batches=2, + amp_backend='native', + distributed_backend='ddp_spawn', + gpus=2, + precision=16, + track_grad_norm=2, + log_every_n_steps=1, + accumulate_grad_batches=2, + ) + trainer.fit(model) diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index 67f38568e2103..de3fb63fe9664 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -15,8 +15,8 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException -def test_tbd_remove_in_v1_4_0(tmpdir): - with pytest.deprecated_call(match='will no longer be supported in v1.4'): +def test_tbd_remove_in_v1_3_0(tmpdir): + with pytest.deprecated_call(match='will no longer be supported in v1.3'): callback = ModelCheckpoint() Trainer(checkpoint_callback=callback, callbacks=[], default_root_dir=tmpdir) diff --git a/tests/trainer/logging/test_logger_connector.py b/tests/trainer/logging/test_logger_connector.py new file mode 100644 index 0000000000000..0f27f2ca4fef4 --- /dev/null +++ b/tests/trainer/logging/test_logger_connector.py @@ -0,0 +1,249 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
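The tests added below read cached results through a nested layout: hook name, then dataloader index, then optimizer index, then batch index, with reduced values exposed under `<metric>_epoch` keys once `has_batch_loop_finished` is set. As a mental model only, here is a simplified sketch; `TinyResultStore` is a hypothetical stand-in, not the actual `EpochResultStore` implementation:

from collections import defaultdict

import torch


class TinyResultStore:
    """Illustrative stand-in for EpochResultStore (hypothetical, simplified)."""

    def __init__(self):
        # fx_name -> dataloader_idx -> opt_idx -> batch_idx -> {metric: tensor}
        self._internals = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
        self._internals_reduced = {}
        self.has_reduced = False

    def cache_result(self, fx_name, dl_idx, opt_idx, batch_idx, metrics):
        # store one batch worth of logged metrics under string keys,
        # matching the ['training_step']['0']['0']['0'] access pattern below
        self._internals[fx_name][str(dl_idx)][str(opt_idx)][str(batch_idx)] = metrics

    def auto_reduce_results_on_epoch_end(self):
        # mean-reduce every metric over its batches, producing the
        # `train_loss_epoch` / `test_loss_epoch` keys asserted in the tests
        for fx_name, dataloaders in self._internals.items():
            reduced = {}
            for dl_idx, optimizers in dataloaders.items():
                reduced[dl_idx] = {}
                for opt_idx, batches in optimizers.items():
                    names = {n for metrics in batches.values() for n in metrics}
                    reduced[dl_idx][opt_idx] = {
                        f"{n}_epoch": torch.stack([m[n] for m in batches.values() if n in m]).mean()
                        for n in names
                    }
            self._internals_reduced[fx_name] = reduced
        self.has_reduced = True


store = TinyResultStore()
store.cache_result("training_step", 0, 0, 0, {"train_loss": torch.tensor(0.5)})
store.cache_result("training_step", 0, 0, 1, {"train_loss": torch.tensor(1.0)})
store.auto_reduce_results_on_epoch_end()
assert store._internals_reduced["training_step"]["0"]["0"]["train_loss_epoch"] == 0.75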
+""" +Tests to ensure that the training loop works with a dict (1.0) +""" +import os +import torch +import pytest + +from pytorch_lightning.trainer import Trainer +from pytorch_lightning.core.step_result import Result +from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector +from tests.base.boring_model import BoringModel, RandomDataset + + +class Helper: + def decorator_with_arguments(fx_name='', hook_fx_name=''): + def decorator(func): + def wrapper(self, *args, **kwargs): + # Set information + self._current_fx_name = fx_name + self._current_hook_fx_name = hook_fx_name + self._results = Result() + + result = func(self, *args, **kwargs) + + # cache metrics + self.trainer.logger_connector.cache_logged_metrics() + return result + return wrapper + + return decorator + + +def test__logger_connector__epoch_result_store__train(tmpdir): + """ + Tests that LoggerConnector will properly capture logged information + and reduce them + """ + + os.environ['PL_DEV_DEBUG'] = '1' + + class TestModel(BoringModel): + + train_losses = [] + + @Helper.decorator_with_arguments(fx_name="training_step") + def training_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + + self.train_losses.append(loss) + + self.log("train_loss", loss, on_step=True, on_epoch=True) + return {"loss": loss} + + def val_dataloader(self): + return [torch.utils.data.DataLoader(RandomDataset(32, 64)), + torch.utils.data.DataLoader(RandomDataset(32, 64))] + + model = TestModel() + model.val_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=4, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.fit(model) + + assert len(trainer.logger_connector.cached_results("train")['training_step']['0']['0']) == 2 + assert trainer.logger_connector.cached_results("train")['training_step']['0']['0']['0'][0]["train_loss"] == model.train_losses[0] + assert trainer.logger_connector.cached_results("train")['training_step']['0']['0']['1'][0]["train_loss"] == model.train_losses[1] + + # assert reduction didn't happen yet + assert trainer.logger_connector.cached_results("train").has_reduced is False + + # Launch reduction + trainer.logger_connector.cached_results("train").has_batch_loop_finished = True + + # assert reduction did happen + assert trainer.logger_connector.cached_results("train").has_reduced is True + + assert trainer.logger_connector.cached_results("train")["training_step"]\ + ._internals_reduced["0"]["0"]['train_loss_epoch'].item() == torch.stack(model.train_losses).mean().item() + + +def test__logger_connector__epoch_result_store__train__ttbt(tmpdir): + """ + Tests that LoggerConnector will properly capture logged information with ttbt + and reduce them + """ + truncated_bptt_steps = 2 + sequence_size = 30 + batch_size = 30 + + x_seq = torch.rand(batch_size, sequence_size, 1) + y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist() + + class MockSeq2SeqDataset(torch.utils.data.Dataset): + def __getitem__(self, i): + return x_seq, y_seq_list + + def __len__(self): + return 1 + + class TestModel(BoringModel): + + train_losses = [] + + def __init__(self): + super().__init__() + self.test_hidden = None + self.layer = torch.nn.Linear(2, 2) + + @Helper.decorator_with_arguments(fx_name="training_step") + def training_step(self, batch, batch_idx, hiddens): + try: + assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps" + except Exception as e: + print(e) + + 
self.test_hidden = torch.rand(1) + + x_tensor, y_list = batch + assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed" + + y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype) + assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed" + + pred = self(x_tensor.view(batch_size, truncated_bptt_steps)) + loss = torch.nn.functional.mse_loss( + pred, y_tensor.view(batch_size, truncated_bptt_steps)) + + self.train_losses.append(loss) + + self.log('a', loss, on_epoch=True) + + return {'loss': loss, 'hiddens': self.test_hidden} + + def on_train_epoch_start(self) -> None: + self.test_hidden = None + + def train_dataloader(self): + return torch.utils.data.DataLoader( + dataset=MockSeq2SeqDataset(), + batch_size=batch_size, + shuffle=False, + sampler=None, + ) + + model = TestModel() + model.training_epoch_end = None + model.example_input_array = torch.randn(5, truncated_bptt_steps) + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=10, + limit_val_batches=0, + truncated_bptt_steps=truncated_bptt_steps, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.fit(model) + + assert len(trainer.logger_connector.cached_results("train")['training_step']['0']['0']['0']) == len(model.train_losses) + + # assert reduction didn't happen yet + assert trainer.logger_connector.cached_results("train").has_reduced is False + + # Launch reduction + trainer.logger_connector.cached_results("train").has_batch_loop_finished = True + + # assert reduction did happen + assert trainer.logger_connector.cached_results("train").has_reduced is True + + assert trainer.logger_connector.cached_results("train")['training_step']\ + ._internals_reduced['0']['0']["a_epoch"].item() == torch.stack(model.train_losses).mean().item() + + +@pytest.mark.parametrize('num_dataloaders', [1, 2]) +def test__logger_connector__epoch_result_store__test_multi_dataloaders(tmpdir, num_dataloaders): + """ + Tests that LoggerConnector will properly capture logged information in multi_dataloaders scenario + """ + + os.environ['PL_DEV_DEBUG'] = '1' + + class TestModel(BoringModel): + + test_losses = {} + + @Helper.decorator_with_arguments(fx_name="test_step") + def test_step(self, batch, batch_idx, dataloader_idx=0): + output = self.layer(batch) + loss = self.loss(batch, output) + + primary_key = str(dataloader_idx) + if primary_key not in self.test_losses: + self.test_losses[primary_key] = [] + + self.test_losses[primary_key].append(loss) + + self.log("test_loss", loss, on_step=True, on_epoch=True) + return {"test_loss": loss} + + def test_dataloader(self): + return [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(num_dataloaders)] + + model = TestModel() + model.val_dataloader = None + model.test_epoch_end = None + + limit_test_batches = 4 + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=0, + limit_val_batches=0, + limit_test_batches=limit_test_batches, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.test(model) + + assert len(trainer.logger_connector.cached_results("test")["test_step"]._internals) == num_dataloaders + for dl_idx in range(num_dataloaders): + assert len(trainer.logger_connector.cached_results("test")["test_step"]._internals[str(dl_idx)]) == limit_test_batches + trainer.logger_connector.cached_results("test").has_batch_loop_finished = True + for dl_idx in range(num_dataloaders): + expected = torch.stack(model.test_losses[str(dl_idx)]).mean() + generated = 
trainer.logger_connector.cached_results("test")["test_step"]._internals_reduced[str(dl_idx)]["test_loss_epoch"] + assert abs(expected.item() - generated.item()) < 1e-6 diff --git a/tests/trainer/logging/__init__.py b/tests/trainer/logging_tests/__init__.py similarity index 100% rename from tests/trainer/logging/__init__.py rename to tests/trainer/logging_tests/__init__.py diff --git a/tests/trainer/logging/test_distributed_logging.py b/tests/trainer/logging_tests/test_distributed_logging.py similarity index 100% rename from tests/trainer/logging/test_distributed_logging.py rename to tests/trainer/logging_tests/test_distributed_logging.py diff --git a/tests/trainer/logging/test_eval_loop_logging_1_0.py b/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py similarity index 100% rename from tests/trainer/logging/test_eval_loop_logging_1_0.py rename to tests/trainer/logging_tests/test_eval_loop_logging_1_0.py diff --git a/tests/trainer/logging/test_train_loop_logging_1_0.py b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py similarity index 100% rename from tests/trainer/logging/test_train_loop_logging_1_0.py rename to tests/trainer/logging_tests/test_train_loop_logging_1_0.py diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 35257e28704ba..6fceae4b5e59d 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -430,7 +430,7 @@ def mock_save_function(filepath, *args): losses = [10, 9, 2.8, 5, 2.5] checkpoint_callback = ModelCheckpoint( - dirpath=tmpdir, monitor='checkpoint_on', save_top_k=save_top_k, + dirpath=tmpdir, filename='{epoch}', monitor='checkpoint_on', save_top_k=save_top_k, save_last=save_last, prefix=file_prefix, verbose=1 ) checkpoint_callback.save_function = mock_save_function diff --git a/tests/trainer/warnings/__init__.py b/tests/trainer/warnings_tests/__init__.py similarity index 100% rename from tests/trainer/warnings/__init__.py rename to tests/trainer/warnings_tests/__init__.py diff --git a/tests/trainer/warnings/test_flow_warnings.py b/tests/trainer/warnings_tests/test_flow_warnings.py similarity index 100% rename from tests/trainer/warnings/test_flow_warnings.py rename to tests/trainer/warnings_tests/test_flow_warnings.py diff --git a/tests/utilities/test_xla_device_utils.py b/tests/utilities/test_xla_device_utils.py index b0a3497a0f3be..10de63db049e7 100644 --- a/tests/utilities/test_xla_device_utils.py +++ b/tests/utilities/test_xla_device_utils.py @@ -11,13 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import time + import pytest -from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils as xdu +import pytorch_lightning.utilities.xla_device_utils as xla_utils from tests.base.develop_utils import pl_multi_process_test try: import torch_xla.core.xla_model as xm + XLA_AVAILABLE = True except ImportError as e: XLA_AVAILABLE = False @@ -26,13 +29,13 @@ @pytest.mark.skipif(XLA_AVAILABLE, reason="test requires torch_xla to be absent") def test_tpu_device_absence(): """Check tpu_device_exists returns None when torch_xla is not available""" - assert xdu.tpu_device_exists() is None + assert xla_utils.XLADeviceUtils.tpu_device_exists() is None @pytest.mark.skipif(not XLA_AVAILABLE, reason="test requires torch_xla to be installed") def test_tpu_device_presence(): """Check tpu_device_exists returns True when TPU is available""" - assert xdu.tpu_device_exists() is True + assert xla_utils.XLADeviceUtils.tpu_device_exists() is True @pytest.mark.skipif(not XLA_AVAILABLE, reason="test requires torch_xla to be installed") @@ -42,3 +45,14 @@ def test_xla_device_is_a_tpu(): device = xm.xla_device() device_type = xm.xla_device_hw(device) return device_type == "TPU" + + +def test_result_returns_within_10_seconds(): + """Check that pl_multi_process returns within 10 seconds""" + + start = time.time() + result = xla_utils.pl_multi_process(time.sleep)(25) + end = time.time() + elapsed_time = int(end - start) + assert elapsed_time <= 10 + assert result is False
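The 10-second bound asserted in `test_result_returns_within_10_seconds` follows from the reworked `pl_multi_process`: join the worker process with a timeout and treat an empty result queue as failure. The same pattern condensed into a self-contained sketch; `bounded` and `_produce` are hypothetical names, not the library API:

import functools
import queue as q
import time
from multiprocessing import Process, Queue


def _produce(out, func, *args, **kwargs):
    # runs inside the child process and ships the result back
    out.put(func(*args, **kwargs))


def bounded(func, timeout=10):
    """Call `func` in a subprocess; return its result, or False when nothing
    arrives within `timeout` seconds (the child may still be running)."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        out = Queue()
        proc = Process(target=_produce, args=(out, func, *args), kwargs=kwargs)
        proc.start()
        proc.join(timeout)           # bounded wait instead of a blocking join()
        try:
            return out.get_nowait()  # empty queue: the call hung or crashed
        except q.Empty:
            return False
    return wrapper


if __name__ == "__main__":  # guard required on spawn-based platforms
    assert bounded(time.sleep)(25) is False  # returns after ~10 seconds, not 25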