diff --git a/CHANGELOG.md b/CHANGELOG.md index 58569e04fd..707b37b885 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added Advantage Actor-Critic (A2C) Model [#598](https://github.com/PyTorchLightning/lightning-bolts/pull/598)) +- Added Torch ORT Callback [#720](https://github.com/PyTorchLightning/lightning-bolts/pull/720)) + + ### Changed - Changed the default values `pin_memory=False`, `shuffle=False` and `num_workers=16` to `pin_memory=True`, `shuffle=True` and `num_workers=0` of datamodules ([#701](https://github.com/PyTorchLightning/lightning-bolts/pull/701)) diff --git a/docs/source/callbacks.rst b/docs/source/callbacks.rst deleted file mode 100644 index 0ba954a093..0000000000 --- a/docs/source/callbacks.rst +++ /dev/null @@ -1,40 +0,0 @@ -.. role:: hidden - :class: hidden-section - -Build a Callback -================ -This module houses a collection of callbacks that can be passed into the trainer - -.. code-block:: python - - from pl_bolts.callbacks import PrintTableMetricsCallback - - trainer = pl.Trainer(callbacks=[PrintTableMetricsCallback()]) - - # loss│train_loss│val_loss│epoch - # ────────────────────────────── - # 2.2541470527648926│2.2541470527648926│2.2158432006835938│0 - ------------------- - -What is a Callback ------------------- -A callback is a self-contained program that can be intertwined into a training pipeline without polluting the main -research logic. - ---------------- - -Create a Callback ------------------ -Creating a callback is simple: - -.. code-block:: python - - from pytorch_lightning.callbacks import Callback - - class MyCallback(Callback) - def on_epoch_end(self, trainer, pl_module): - # do something - -Please refer to `Callback docs `_ -for a full list of the 20+ hooks available. diff --git a/docs/source/info_callbacks.rst b/docs/source/callbacks/monitor_callbacks.rst similarity index 99% rename from docs/source/info_callbacks.rst rename to docs/source/callbacks/monitor_callbacks.rst index b65136675d..20ea42d074 100644 --- a/docs/source/info_callbacks.rst +++ b/docs/source/callbacks/monitor_callbacks.rst @@ -1,8 +1,8 @@ .. role:: hidden :class: hidden-section -Info Callbacks -============== +Monitoring Callbacks +==================== These callbacks give all sorts of useful information during training. diff --git a/docs/source/self_supervised_callbacks.rst b/docs/source/callbacks/self_supervised_callbacks.rst similarity index 100% rename from docs/source/self_supervised_callbacks.rst rename to docs/source/callbacks/self_supervised_callbacks.rst diff --git a/docs/source/callbacks/torch_ort.rst b/docs/source/callbacks/torch_ort.rst new file mode 100644 index 0000000000..a533eacaa8 --- /dev/null +++ b/docs/source/callbacks/torch_ort.rst @@ -0,0 +1,35 @@ +================== +Torch ORT Callback +================== + +`Torch ORT `__ converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. See installation instructions `here `__. + +This is primarily useful for when training with a Transformer model. The ORT callback works when a single model is specified as `self.model` within the ``LightningModule`` as shown below. + +.. note:: + + Not all Transformer models are supported. See `this table `__ for supported models + branches containing fixes for certain models. + +.. code-block:: python + + from pytorch_lightning import LightningModule, Trainer + from transformers import AutoModel + + from pl_bolts.callbacks import ORTCallback + + + class MyTransformerModel(LightningModule): + + def __init__(self): + super().__init__() + self.model = AutoModel.from_pretrained('bert-base-cased') + + ... + + + model = MyTransformerModel() + trainer = Trainer(gpus=1, callbacks=ORTCallback()) + trainer.fit(model) + + +For even easier setup and integration, have a look at our Lightning Flash integration for :ref:`Text Classification `, :ref:`Translation ` and :ref:`Summarization `. diff --git a/docs/source/variational_callbacks.rst b/docs/source/callbacks/variational_callbacks.rst similarity index 88% rename from docs/source/variational_callbacks.rst rename to docs/source/callbacks/variational_callbacks.rst index 3f33e1f20d..1a514c1584 100644 --- a/docs/source/variational_callbacks.rst +++ b/docs/source/callbacks/variational_callbacks.rst @@ -13,7 +13,7 @@ Interpolates latent dims. Example output: - .. image:: _images/gans/basic_gan_interpolate.jpg + .. image:: ../_images/gans/basic_gan_interpolate.jpg :width: 400 :alt: Example latent space interpolation diff --git a/docs/source/vision_callbacks.rst b/docs/source/callbacks/vision_callbacks.rst similarity index 93% rename from docs/source/vision_callbacks.rst rename to docs/source/callbacks/vision_callbacks.rst index 23bfb1861a..5662bf13f0 100644 --- a/docs/source/vision_callbacks.rst +++ b/docs/source/callbacks/vision_callbacks.rst @@ -14,7 +14,7 @@ Shows how the input would have to change to move the prediction from one logit t Example outputs: - .. image:: _images/vision/confused_logit.png + .. image:: ../_images/vision/confused_logit.png :width: 400 :alt: Example of prediction confused between 5 and 8 diff --git a/docs/source/conf.py b/docs/source/conf.py index d7c2c3499a..ea38d1de68 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -261,6 +261,7 @@ # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = { "pytorch_lightning": ("https://pytorch-lightning.readthedocs.io/en/stable/", None), + "lightning_flash": ("https://lightning-flash.readthedocs.io/en/latest/", None), "python": ("https://docs.python.org/3", None), "torch": ("https://pytorch.org/docs/stable/", None), "numpy": ("https://numpy.org/doc/stable/", None), @@ -385,7 +386,7 @@ def find_source(): # This value determines the text for the permalink; it defaults to "¶". Set it to None or the empty # string to disable permalinks. # https://www.sphinx-doc.org/en/master/usage/configuration.html#confval-html_add_permalinks -html_add_permalinks = "¶" +html_permalinks_icon = "¶" # True to prefix each section label with the name of the document it is in, followed by a colon. # For example, index:Introduction for a section called Introduction that appears in document index.rst. diff --git a/docs/source/index.rst b/docs/source/index.rst index 9304469781..860b82ba12 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -19,11 +19,11 @@ Lightning-Bolts documentation :name: callbacks :caption: Callbacks - callbacks - info_callbacks - self_supervised_callbacks - variational_callbacks - vision_callbacks + callbacks/monitor_callbacks + callbacks/self_supervised_callbacks + callbacks/variational_callbacks + callbacks/vision_callbacks + callbacks/torch_ort .. toctree:: :maxdepth: 2 diff --git a/pl_bolts/callbacks/__init__.py b/pl_bolts/callbacks/__init__.py index 0909c2638c..9558aed816 100644 --- a/pl_bolts/callbacks/__init__.py +++ b/pl_bolts/callbacks/__init__.py @@ -3,6 +3,7 @@ from pl_bolts.callbacks.data_monitor import ModuleDataMonitor, TrainingDataMonitor from pl_bolts.callbacks.printing import PrintTableMetricsCallback from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator +from pl_bolts.callbacks.torch_ort import ORTCallback from pl_bolts.callbacks.variational import LatentDimInterpolator from pl_bolts.callbacks.verification.batch_gradient import BatchGradientVerificationCallback # type: ignore from pl_bolts.callbacks.vision.confused_logit import ConfusedLogitCallback @@ -18,4 +19,5 @@ "LatentDimInterpolator", "ConfusedLogitCallback", "TensorboardGenerativeModelImageSampler", + "ORTCallback", ] diff --git a/pl_bolts/callbacks/torch_ort.py b/pl_bolts/callbacks/torch_ort.py new file mode 100644 index 0000000000..fec900e1e1 --- /dev/null +++ b/pl_bolts/callbacks/torch_ort.py @@ -0,0 +1,51 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning import Callback, LightningModule, Trainer +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +from pl_bolts.utils import _TORCH_ORT_AVAILABLE + +if _TORCH_ORT_AVAILABLE: + from torch_ort import ORTModule + + +class ORTCallback(Callback): + """Enables Torch ORT: Accelerate PyTorch models with ONNX Runtime. + + Wraps a model with the ORT wrapper, lazily converting your module into an ONNX export, to optimize for + training and inference. + + Usage: + + # via Transformer Tasks + model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=True) + + # or via the trainer + trainer = flash.Trainer(callbacks=ORTCallback()) + """ + + def __init__(self): + if not _TORCH_ORT_AVAILABLE: + raise MisconfigurationException( + "Torch ORT is required to use ORT. See here for installation: https://github.com/pytorch/ort" + ) + + def on_before_accelerator_backend_setup(self, trainer: Trainer, pl_module: LightningModule) -> None: + if not hasattr(pl_module, "model"): + raise MisconfigurationException( + "Torch ORT requires to wrap a single model that defines a forward function " + "assigned as `model` inside the `LightningModule`." + ) + if not isinstance(pl_module.model, ORTModule): + pl_module.model = ORTModule(pl_module.model) diff --git a/pl_bolts/utils/__init__.py b/pl_bolts/utils/__init__.py index d0f032fc02..8a77cd66e7 100644 --- a/pl_bolts/utils/__init__.py +++ b/pl_bolts/utils/__init__.py @@ -39,5 +39,6 @@ def _compare_version(package: str, op, version) -> bool: _MATPLOTLIB_AVAILABLE: bool = _module_available("matplotlib") _TORCHVISION_LESS_THAN_0_9_1: bool = _compare_version("torchvision", operator.lt, "0.9.1") _PL_GREATER_EQUAL_1_4 = _compare_version("pytorch_lightning", operator.ge, "1.4.0") +_TORCH_ORT_AVAILABLE = _module_available("torch_ort") __all__ = ["BatchGradientVerification"] diff --git a/tests/callbacks/test_ort.py b/tests/callbacks/test_ort.py new file mode 100644 index 0000000000..38897b1f3b --- /dev/null +++ b/tests/callbacks/test_ort.py @@ -0,0 +1,55 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from pytorch_lightning import Callback, Trainer +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +from pl_bolts.callbacks import ORTCallback +from pl_bolts.utils import _TORCH_ORT_AVAILABLE +from tests.helpers.boring_model import BoringModel + +if _TORCH_ORT_AVAILABLE: + from torch_ort import ORTModule + + +@pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed.") +def test_init_train_enable_ort(tmpdir): + class TestCallback(Callback): + def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + assert isinstance(pl_module.model, ORTModule) + + class TestModel(BoringModel): + def __init__(self): + super().__init__() + self.model = self.layer + + def forward(self, x): + return self.model(x) + + model = TestModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=[ORTCallback(), TestCallback()]) + trainer.fit(model) + trainer.test(model) + + +@pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed.") +def test_ort_callback_fails_no_model(tmpdir): + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=ORTCallback()) + with pytest.raises(MisconfigurationException, match="Torch ORT requires to wrap a single model"): + trainer.fit( + model, + ) diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/helpers/boring_model.py b/tests/helpers/boring_model.py new file mode 100644 index 0000000000..513c7ab227 --- /dev/null +++ b/tests/helpers/boring_model.py @@ -0,0 +1,183 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import torch +from pytorch_lightning import LightningDataModule, LightningModule +from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset + + +class RandomDictDataset(Dataset): + def __init__(self, size: int, length: int): + self.len = length + self.data = torch.randn(length, size) + + def __getitem__(self, index): + a = self.data[index] + b = a + 2 + return {"a": a, "b": b} + + def __len__(self): + return self.len + + +class RandomDataset(Dataset): + def __init__(self, size: int, length: int): + self.len = length + self.data = torch.randn(length, size) + + def __getitem__(self, index): + return self.data[index] + + def __len__(self): + return self.len + + +class RandomIterableDataset(IterableDataset): + def __init__(self, size: int, count: int): + self.count = count + self.size = size + + def __iter__(self): + for _ in range(self.count): + yield torch.randn(self.size) + + +class RandomIterableDatasetWithLen(IterableDataset): + def __init__(self, size: int, count: int): + self.count = count + self.size = size + + def __iter__(self): + for _ in range(len(self)): + yield torch.randn(self.size) + + def __len__(self): + return self.count + + +class BoringModel(LightningModule): + def __init__(self): + """Testing PL Module. + + Use as follows: + - subclass + - modify the behavior for what you want + + class TestModel(BaseTestModel): + def training_step(...): + # do your own thing + + or: + + model = BaseTestModel() + model.training_epoch_end = None + """ + super().__init__() + self.layer = torch.nn.Linear(32, 2) + + def forward(self, x): + return self.layer(x) + + def loss(self, batch, prediction): + # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls + return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction)) + + def step(self, x): + x = self(x) + out = torch.nn.functional.mse_loss(x, torch.ones_like(x)) + return out + + def training_step(self, batch, batch_idx): + output = self(batch) + loss = self.loss(batch, output) + return {"loss": loss} + + def training_step_end(self, training_step_outputs): + return training_step_outputs + + def training_epoch_end(self, outputs) -> None: + torch.stack([x["loss"] for x in outputs]).mean() + + def validation_step(self, batch, batch_idx): + output = self(batch) + loss = self.loss(batch, output) + return {"x": loss} + + def validation_epoch_end(self, outputs) -> None: + torch.stack([x["x"] for x in outputs]).mean() + + def test_step(self, batch, batch_idx): + output = self(batch) + loss = self.loss(batch, output) + return {"y": loss} + + def test_epoch_end(self, outputs) -> None: + torch.stack([x["y"] for x in outputs]).mean() + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + return [optimizer], [lr_scheduler] + + def train_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def val_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def test_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def predict_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + +class BoringDataModule(LightningDataModule): + def __init__(self, data_dir: str = "./"): + super().__init__() + self.data_dir = data_dir + self.non_picklable = None + self.checkpoint_state: Optional[str] = None + + def prepare_data(self): + self.random_full = RandomDataset(32, 64 * 4) + + def setup(self, stage: Optional[str] = None): + if stage == "fit" or stage is None: + self.random_train = Subset(self.random_full, indices=range(64)) + self.dims = self.random_train[0].shape + + if stage in ("fit", "validate") or stage is None: + self.random_val = Subset(self.random_full, indices=range(64, 64 * 2)) + + if stage == "test" or stage is None: + self.random_test = Subset(self.random_full, indices=range(64 * 2, 64 * 3)) + self.dims = getattr(self, "dims", self.random_test[0].shape) + + if stage == "predict" or stage is None: + self.random_predict = Subset(self.random_full, indices=range(64 * 3, 64 * 4)) + self.dims = getattr(self, "dims", self.random_predict[0].shape) + + def train_dataloader(self): + return DataLoader(self.random_train) + + def val_dataloader(self): + return DataLoader(self.random_val) + + def test_dataloader(self): + return DataLoader(self.random_test) + + def predict_dataloader(self): + return DataLoader(self.random_predict)