
Commit

Update
y199387 committed Aug 8, 2022
1 parent c918ebc commit 1d82838
Showing 2 changed files with 79 additions and 0 deletions.
@@ -14,13 +14,22 @@
# limitations under the License.
#

from contextlib import contextmanager
from functools import partial
from typing import Any, Union, Callable

import torch
from torch.nn import Module
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.strategies import SingleDeviceStrategy
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins.precision import PrecisionPlugin

from bigdl.nano.utils.log4Error import invalidInputError
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.optim._optimizer_utils import IPEX_FUSED_OPTIMIZER_LIST

from .ipex_accelerator import IPEXAccelerator

@@ -44,6 +53,9 @@ def __init__(
"""
self.enable_bf16 = enable_bf16

if enable_bf16 and isinstance(precision_plugin, PrecisionPlugin):
precision_plugin = IPEXBF16Precision()

super().__init__(accelerator=accelerator, precision_plugin=precision_plugin)

    def setup(self, trainer: pl.Trainer) -> None:
@@ -63,3 +75,29 @@ def setup(self, trainer: pl.Trainer) -> None:
            ipex.optimize(self.model, optimizer=self.optimizers[0], inplace=True, dtype=dtype)
        else:
            invalidInputError(False, "IPEX does not support more than one optimizer.")


class IPEXBF16Precision(PrecisionPlugin):
    """Precision plugin for IPEX BFloat16."""

    @contextmanager
    def forward_context(self):
        """PyTorch AMP for managing model forward/training_step/evaluation_step/predict_step."""
        with torch.cpu.amp.autocast():
            yield

    def optimizer_step(self,
                       model: Union["pl.LightningModule", Module],
                       optimizer: Optimizer,
                       optimizer_idx: int,
                       closure: Callable[[], Any],
                       **kwargs: Any) -> Any:
        """Hook to run the optimizer step."""
        if isinstance(model, pl.LightningModule):
            closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)

        # Automatically call the closure for optimizers that IPEX does not fuse
        if type(optimizer) not in IPEX_FUSED_OPTIMIZER_LIST:
            closure()

        return optimizer.step(closure, **kwargs)
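
A quick illustration of what the forward_context above wraps around each model step (a minimal sketch using only stock PyTorch, not part of this commit): under torch.cpu.amp.autocast(), autocast-eligible CPU ops such as nn.Linear run in BFloat16.

# Minimal sketch (not part of this commit): the effect of torch.cpu.amp.autocast()
# on a CPU forward pass. Eligible ops such as nn.Linear produce bfloat16 outputs.
import torch

linear = torch.nn.Linear(8, 4)
x = torch.randn(2, 8)
with torch.cpu.amp.autocast():
    y = linear(x)
print(y.dtype)  # torch.bfloat16
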
41 changes: 41 additions & 0 deletions python/nano/test/pytorch/tests/test_trainer_ipex.py
@@ -66,6 +66,47 @@ def test_trainer_save_checkpoint(self):
        pl_model = Trainer.compile(self.model, self.loss, self.optimizer, self.scheduler_dict)
        trainer.fit(pl_model, self.train_loader)

    def test_trainer_ipex_bf16(self):
        trainer = Trainer(max_epochs=max_epochs, use_ipex=True, enable_bf16=True)

        # use_ipex=True optimizes the model in place, so build a fresh model for this test
        model = ResNet18(10, pretrained=False, include_top=False, freeze=True)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        loss = nn.CrossEntropyLoss()
        scheduler_dict = {
            "scheduler": OneCycleLR(
                optimizer,
                0.1,
                epochs=max_epochs,
                steps_per_epoch=len(self.train_loader),
            ),
            "interval": "step",
        }

        pl_model = Trainer.compile(model, loss, optimizer, scheduler_dict)
        trainer.fit(pl_model, self.train_loader)
        trainer.test(pl_model, self.train_loader)

    def test_trainer_ipex_bf16_unsupported_optim(self):
        trainer = Trainer(max_epochs=max_epochs, use_ipex=True, enable_bf16=True)

        model = ResNet18(10, pretrained=False, include_top=False, freeze=True)
        optimizer = torch.optim.AdamW(model.parameters(), lr=0.01, weight_decay=5e-4)
        loss = nn.CrossEntropyLoss()
        scheduler_dict = {
            "scheduler": OneCycleLR(
                optimizer,
                0.1,
                epochs=max_epochs,
                steps_per_epoch=len(self.train_loader),
            ),
            "interval": "step",
        }

        pl_model = Trainer.compile(model, loss, optimizer, scheduler_dict)
        trainer.fit(pl_model, self.train_loader)
        trainer.test(pl_model, self.train_loader)


if __name__ == '__main__':
    pytest.main([__file__])
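
The two tests above differ only in the optimizer: the test names suggest SGD is expected to be in IPEX's fused-optimizer list while AdamW is not (for the IPEX release targeted by this commit), which exercises the manual-closure branch of IPEXBF16Precision.optimizer_step. A minimal sketch of checking that membership, assuming intel_extension_for_pytorch is installed; the list contents depend on the IPEX version:

# Sketch (not part of this commit): membership in IPEX_FUSED_OPTIMIZER_LIST is
# what IPEXBF16Precision.optimizer_step checks before deciding whether to run
# the closure itself. Results vary with the installed IPEX version.
import torch
from intel_extension_for_pytorch.optim._optimizer_utils import IPEX_FUSED_OPTIMIZER_LIST

model = torch.nn.Linear(4, 2)
for opt in (torch.optim.SGD(model.parameters(), lr=0.01),
            torch.optim.AdamW(model.parameters(), lr=0.01)):
    print(type(opt).__name__, "fused by IPEX:", type(opt) in IPEX_FUSED_OPTIMIZER_LIST)
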
