Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Torch ORT Callback, re-write callback section #720

Merged
merged 4 commits into from
Aug 31, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added Advantage Actor-Critic (A2C) Model [#598](https://github.com/PyTorchLightning/lightning-bolts/pull/598))


- Added Torch ORT Callback [#720](https://github.com/PyTorchLightning/lightning-bolts/pull/720))


### Changed

- Changed the default values `pin_memory=False`, `shuffle=False` and `num_workers=16` to `pin_memory=True`, `shuffle=True` and `num_workers=0` of datamodules ([#701](https://github.com/PyTorchLightning/lightning-bolts/pull/701))
Expand Down
40 changes: 0 additions & 40 deletions docs/source/callbacks.rst

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
.. role:: hidden
:class: hidden-section

Info Callbacks
==============
Monitoring Callbacks
====================

These callbacks give all sorts of useful information during training.

Expand Down
35 changes: 35 additions & 0 deletions docs/source/callbacks/torch_ort.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
==================
Torch ORT Callback
==================

`Torch ORT <https://cloudblogs.microsoft.com/opensource/2021/07/13/accelerate-pytorch-training-with-torch-ort/>`__ converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. See installation instructions `here <https://github.com/pytorch/ort#install-in-a-local-python-environment>`__.

This is primarily useful for when training with a Transformer model. The ORT callback works when a single model is specified as `self.model` within the ``LightningModule`` as shown below.

.. note::

Not all Transformer models are supported. See `this table <https://github.com/microsoft/onnxruntime-training-examples#examples>`__ for supported models + branches containing fixes for certain models.

.. code-block:: python

from pytorch_lightning import LightningModule, Trainer
from transformers import AutoModel

from pl_bolts.callbacks import ORTCallback


class MyTransformerModel(LightningModule):

def __init__(self):
super().__init__()
self.model = AutoModel.from_pretrained('bert-base-cased')

...


model = MyTransformerModel()
trainer = Trainer(gpus=1, callbacks=ORTCallback())
trainer.fit(model)


For even easier setup and integration, have a look at our Lightning Flash integration for `Text Classification <https://lightning-flash.readthedocs.io/en/latest/reference/text_classification.html#accelerate-training-inference-with-torch-ort>`__, `Summarization <https://lightning-flash.readthedocs.io/en/latest/reference/summarization.html#accelerate-training-inference-with-torch-ort>`__ and `Translation <https://lightning-flash.readthedocs.io/en/latest/reference/translation.html#accelerate-training-inference-with-torch-ort>`__.
SeanNaren marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Interpolates latent dims.

Example output:

.. image:: _images/gans/basic_gan_interpolate.jpg
.. image:: ../_images/gans/basic_gan_interpolate.jpg
:width: 400
:alt: Example latent space interpolation

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Shows how the input would have to change to move the prediction from one logit t

Example outputs:

.. image:: _images/vision/confused_logit.png
.. image:: ../_images/vision/confused_logit.png
:width: 400
:alt: Example of prediction confused between 5 and 8

Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def find_source():
# This value determines the text for the permalink; it defaults to "¶". Set it to None or the empty
# string to disable permalinks.
# https://www.sphinx-doc.org/en/master/usage/configuration.html#confval-html_add_permalinks
html_add_permalinks = "¶"
html_permalinks_icon = "¶"

# True to prefix each section label with the name of the document it is in, followed by a colon.
# For example, index:Introduction for a section called Introduction that appears in document index.rst.
Expand Down
10 changes: 5 additions & 5 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ Lightning-Bolts documentation
:name: callbacks
:caption: Callbacks

callbacks
info_callbacks
self_supervised_callbacks
variational_callbacks
vision_callbacks
callbacks/monitor_callbacks
callbacks/self_supervised_callbacks
callbacks/variational_callbacks
callbacks/vision_callbacks
callbacks/torch_ort

.. toctree::
:maxdepth: 2
Expand Down
2 changes: 2 additions & 0 deletions pl_bolts/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pl_bolts.callbacks.data_monitor import ModuleDataMonitor, TrainingDataMonitor
from pl_bolts.callbacks.printing import PrintTableMetricsCallback
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
from pl_bolts.callbacks.torch_ort import ORTCallback
from pl_bolts.callbacks.variational import LatentDimInterpolator
from pl_bolts.callbacks.verification.batch_gradient import BatchGradientVerificationCallback # type: ignore
from pl_bolts.callbacks.vision.confused_logit import ConfusedLogitCallback
Expand All @@ -18,4 +19,5 @@
"LatentDimInterpolator",
"ConfusedLogitCallback",
"TensorboardGenerativeModelImageSampler",
"ORTCallback",
]
51 changes: 51 additions & 0 deletions pl_bolts/callbacks/torch_ort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from pl_bolts.utils import _TORCH_ORT_AVAILABLE

if _TORCH_ORT_AVAILABLE:
from torch_ort import ORTModule


class ORTCallback(Callback):
"""Enables Torch ORT: Accelerate PyTorch models with ONNX Runtime.

Wraps a model with the ORT wrapper, lazily converting your module into an ONNX export, to optimize for
training and inference.

Usage:

# via Transformer Tasks
model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=True)

# or via the trainer
trainer = flash.Trainer(callbacks=ORTCallback())
"""

def __init__(self):
if not _TORCH_ORT_AVAILABLE:
raise MisconfigurationException(
"Torch ORT is required to use ORT. See here for installation: https://github.com/pytorch/ort"
)

def on_before_accelerator_backend_setup(self, trainer: Trainer, pl_module: LightningModule) -> None:
if not hasattr(pl_module, "model"):
raise MisconfigurationException(
"Torch ORT requires to wrap a single model that defines a forward function "
"assigned as `model` inside the `LightningModule`."
)
if not isinstance(pl_module.model, ORTModule):
pl_module.model = ORTModule(pl_module.model)
1 change: 1 addition & 0 deletions pl_bolts/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,6 @@ def _compare_version(package: str, op, version) -> bool:
_MATPLOTLIB_AVAILABLE: bool = _module_available("matplotlib")
_TORCHVISION_LESS_THAN_0_9_1: bool = _compare_version("torchvision", operator.lt, "0.9.1")
_PL_GREATER_EQUAL_1_4 = _compare_version("pytorch_lightning", operator.ge, "1.4.0")
_TORCH_ORT_AVAILABLE = _module_available("torch_ort")

__all__ = ["BatchGradientVerification"]
55 changes: 55 additions & 0 deletions tests/callbacks/test_ort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from pl_bolts.callbacks import ORTCallback
from pl_bolts.utils import _TORCH_ORT_AVAILABLE
from tests.helpers.boring_model import BoringModel

if _TORCH_ORT_AVAILABLE:
from torch_ort import ORTModule


@pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed.")
def test_init_train_enable_ort(tmpdir):
class TestCallback(Callback):
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
assert isinstance(pl_module.model, ORTModule)

class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.model = self.layer

def forward(self, x):
return self.model(x)

model = TestModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=[ORTCallback(), TestCallback()])
trainer.fit(model)
trainer.test(model)


@pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed.")
def test_ort_callback_fails_no_model(tmpdir):
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=ORTCallback())
with pytest.raises(MisconfigurationException, match="Torch ORT requires to wrap a single model"):
trainer.fit(
model,
)
Empty file added tests/helpers/__init__.py
Empty file.
Loading