diff --git a/python/nano/src/bigdl/nano/pytorch/strategies/ipex/ipex_strategy.py b/python/nano/src/bigdl/nano/pytorch/strategies/ipex/ipex_strategy.py
index 26d23855c446..ec0f80dee03b 100644
--- a/python/nano/src/bigdl/nano/pytorch/strategies/ipex/ipex_strategy.py
+++ b/python/nano/src/bigdl/nano/pytorch/strategies/ipex/ipex_strategy.py
@@ -14,13 +14,22 @@
 # limitations under the License.
 #
 
+from contextlib import contextmanager
+from functools import partial
+from typing import Any, Union, Callable
+
 import torch
+from torch.nn import Module
+from torch.optim import Optimizer
+
 import pytorch_lightning as pl
 from pytorch_lightning.strategies import SingleDeviceStrategy
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.plugins.precision import PrecisionPlugin
+
 from bigdl.nano.utils.log4Error import invalidInputError
 import intel_extension_for_pytorch as ipex
+from intel_extension_for_pytorch.optim._optimizer_utils import IPEX_FUSED_OPTIMIZER_LIST
 
 from .ipex_accelerator import IPEXAccelerator
 
@@ -44,6 +53,9 @@ def __init__(
         """
         self.enable_bf16 = enable_bf16
 
+        if enable_bf16 and isinstance(precision_plugin, PrecisionPlugin):
+            precision_plugin = IPEXBF16Precision()
+
         super().__init__(accelerator=accelerator, precision_plugin=precision_plugin)
 
     def setup(self, trainer: pl.Trainer) -> None:
@@ -63,3 +75,29 @@ def setup(self, trainer: pl.Trainer) -> None:
             ipex.optimize(self.model, optimizer=self.optimizers[0], inplace=True, dtype=dtype)
         else:
             invalidInputError(False, "Ipex does not support more than one optimizers.")
+
+
+class IPEXBF16Precision(PrecisionPlugin):
+    """Precision plugin for IPEX BFloat16."""
+
+    @contextmanager
+    def forward_context(self):
+        """Wrap model forward/training_step/evaluation_step/predict_step in PyTorch CPU AMP."""
+        with torch.cpu.amp.autocast():
+            yield
+
+    def optimizer_step(self,
+                       model: Union["pl.LightningModule", Module],
+                       optimizer: Optimizer,
+                       optimizer_idx: int,
+                       closure: Callable[[], Any],
+                       **kwargs: Any) -> Any:
+        """Hook to run the optimizer step."""
+        if isinstance(model, pl.LightningModule):
+            closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
+
+        # Automatically call the closure for optimizers that IPEX does not fuse
+        if type(optimizer) not in IPEX_FUSED_OPTIMIZER_LIST:
+            closure()
+
+        return optimizer.step(closure, **kwargs)
diff --git a/python/nano/test/pytorch/tests/test_trainer_ipex.py b/python/nano/test/pytorch/tests/test_trainer_ipex.py
index 16847bcc35a2..b5f24be99167 100644
--- a/python/nano/test/pytorch/tests/test_trainer_ipex.py
+++ b/python/nano/test/pytorch/tests/test_trainer_ipex.py
@@ -66,6 +66,47 @@ def test_trainer_save_checkpoint(self):
         pl_model = Trainer.compile(self.model, self.loss, self.optimizer, self.scheduler_dict)
         trainer.fit(pl_model, self.train_loader)
 
+    def test_trainer_ipex_bf16(self):
+        trainer = Trainer(max_epochs=max_epochs, use_ipex=True, enable_bf16=True)
+
+        # use_ipex=True will perform in-place optimization
+        model = ResNet18(10, pretrained=False, include_top=False, freeze=True)
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+        loss = nn.CrossEntropyLoss()
+        scheduler_dict = {
+            "scheduler": OneCycleLR(
+                optimizer,
+                0.1,
+                epochs=max_epochs,
+                steps_per_epoch=len(self.train_loader),
+            ),
+            "interval": "step",
+        }
+
+        pl_model = Trainer.compile(model, loss, optimizer, scheduler_dict)
+        trainer.fit(pl_model, self.train_loader)
+        trainer.test(pl_model, self.train_loader)
+
+    def test_trainer_ipex_bf16_unsupported_optim(self):
+        trainer = Trainer(max_epochs=max_epochs, use_ipex=True, enable_bf16=True)
+
+        model = ResNet18(10, pretrained=False, include_top=False, freeze=True)
+        optimizer = torch.optim.AdamW(model.parameters(), lr=0.01, weight_decay=5e-4)
+        loss = nn.CrossEntropyLoss()
+        scheduler_dict = {
+            "scheduler": OneCycleLR(
+                optimizer,
+                0.1,
+                epochs=max_epochs,
+                steps_per_epoch=len(self.train_loader),
+            ),
+            "interval": "step",
+        }
+
+        pl_model = Trainer.compile(model, loss, optimizer, scheduler_dict)
+        trainer.fit(pl_model, self.train_loader)
+        trainer.test(pl_model, self.train_loader)
+
 
 if __name__ == '__main__':
     pytest.main([__file__])
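
Reviewer note: below is a minimal standalone sketch of the two mechanisms the new IPEXBF16Precision plugin relies on, for checking the behavior outside a full bigdl-nano setup. The nn.Linear model, the random tensor, and the printed expectations are illustrative assumptions, not part of this PR; only torch (with CPU autocast support) and intel_extension_for_pytorch are required.

    # Sketch only: forward_context() wraps the forward pass in
    # torch.cpu.amp.autocast(), so autocast-eligible CPU ops (e.g. linear/matmul)
    # compute in bfloat16.
    import torch
    from torch import nn

    model = nn.Linear(8, 4)         # hypothetical placeholder model, not from this PR
    x = torch.randn(2, 8)

    with torch.cpu.amp.autocast():  # same context manager used by forward_context()
        y = model(x)
    print(y.dtype)                  # torch.bfloat16

    # optimizer_step() branches on IPEX_FUSED_OPTIMIZER_LIST; the second test
    # above uses AdamW because it is expected not to be in that list, so the
    # closure is invoked manually before optimizer.step().
    from intel_extension_for_pytorch.optim._optimizer_utils import IPEX_FUSED_OPTIMIZER_LIST

    opt = torch.optim.AdamW(model.parameters(), lr=0.01)
    print(type(opt) in IPEX_FUSED_OPTIMIZER_LIST)  # expected: False, per the test name

This mirrors the split the plugin makes: autocast handles the forward-pass dtype, while optimizer_step decides whether IPEX's fused step or a manual closure call drives the training step.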