diff --git a/python/nano/src/bigdl/nano/deps/ray/ray_distributed.py b/python/nano/src/bigdl/nano/deps/ray/ray_distributed.py index da16615cb8c..32d7cc7bcb6 100644 --- a/python/nano/src/bigdl/nano/deps/ray/ray_distributed.py +++ b/python/nano/src/bigdl/nano/deps/ray/ray_distributed.py @@ -176,7 +176,7 @@ def __init__(self, num_cpus_per_worker: int = 1, use_gpu: bool = False, use_ipex: bool = False, - enable_bf16: bool = False, + dtype=None, init_hook: Callable = None, auto_lr: Union[bool, dict] = True, **ddp_kwargs: Any): @@ -207,7 +207,7 @@ def __init__(self, self.num_cpus_per_worker = num_cpus_per_worker self.use_gpu = use_gpu self.use_ipex = use_ipex - self.enable_bf16 = enable_bf16 + self.dtype = dtype self.auto_lr = auto_lr invalidInputError(not self.use_gpu or not self.use_ipex, @@ -328,14 +328,12 @@ def _unpack_lightning_optimizer(opt): ] if self.use_ipex and not TORCH_VERSION_LESS_1_10: - dtype = torch.bfloat16 if self.enable_bf16 else None num_optimizers = len(self.optimizers) - if num_optimizers == 1: optimizer = self.optimizers[0] - ipex_optimize(self.model, optimizer=optimizer, inplace=True, dtype=dtype) + ipex_optimize(self.model, optimizer=optimizer, inplace=True, dtype=self.dtype) elif num_optimizers == 0: - ipex_optimize(self.model, inplace=True, dtype=dtype) + ipex_optimize(self.model, inplace=True, dtype=self.dtype) else: warnings.warn(f"IPEX currently only support single optimizers, " f"but got {num_optimizers}. Skip IPEX") diff --git a/python/nano/src/bigdl/nano/pytorch/strategies/ddp_spawn.py b/python/nano/src/bigdl/nano/pytorch/strategies/ddp_spawn.py index 84e6d29f1d2..558e1b58109 100644 --- a/python/nano/src/bigdl/nano/pytorch/strategies/ddp_spawn.py +++ b/python/nano/src/bigdl/nano/pytorch/strategies/ddp_spawn.py @@ -168,7 +168,7 @@ def __init__( num_processes: int = 1, cpu_for_each_process: Optional[List[List[int]]] = None, use_ipex=False, - enable_bf16=False, + dtype=None, auto_lr=False, **kwargs: Any ): @@ -180,14 +180,24 @@ def __init__( if use_ipex and TORCH_VERSION_LESS_1_10 and 'accelerator' not in kwargs: super().__init__(accelerator=create_IPEXAccelerator(), parallel_devices=parallel_devices, - cluster_environment=cluster_environment, **kwargs) + cluster_environment=cluster_environment, + **kwargs) + if dtype == torch.bfloat16: + import intel_pytorch_extension as ipex + # Automatically mix precision + ipex.enable_auto_mixed_precision(mixed_dtype=torch.bfloat16) + elif use_ipex and dtype == torch.bfloat16 and 'precision_plugin' not in kwargs: + from bigdl.nano.pytorch.strategies.ipex.ipex_strategy import IPEXBF16Precision + super().__init__(parallel_devices=parallel_devices, + cluster_environment=cluster_environment, + precision_plugin=IPEXBF16Precision(), **kwargs) else: super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment, **kwargs) self.cpu_for_each_process = cpu_for_each_process self.is_distributed = True self.use_ipex = use_ipex - self.enable_bf16 = enable_bf16 + self.dtype = dtype self.auto_lr = auto_lr def _configure_launcher(self): @@ -249,14 +259,13 @@ def _unpack_lightning_optimizer(opt): ] if self.use_ipex and not TORCH_VERSION_LESS_1_10: - dtype = torch.bfloat16 if self.enable_bf16 else None num_optimizers = len(self.optimizers) if num_optimizers == 1: optimizer = self.optimizers[0] - ipex_optimize(self.model, optimizer=optimizer, inplace=True, dtype=dtype) + ipex_optimize(self.model, optimizer=optimizer, inplace=True, dtype=self.dtype) elif num_optimizers == 0: - ipex_optimize(self.model, inplace=True, 
dtype=dtype) + ipex_optimize(self.model, inplace=True, dtype=self.dtype) else: warnings.warn(f"IPEX currently only support single optimizers, " f"but got {num_optimizers}. Skip IPEX") diff --git a/python/nano/src/bigdl/nano/pytorch/strategies/ipex/ipex_strategy.py b/python/nano/src/bigdl/nano/pytorch/strategies/ipex/ipex_strategy.py index 26d23855c44..9b987f51d2e 100644 --- a/python/nano/src/bigdl/nano/pytorch/strategies/ipex/ipex_strategy.py +++ b/python/nano/src/bigdl/nano/pytorch/strategies/ipex/ipex_strategy.py @@ -14,13 +14,25 @@ # limitations under the License. # +from contextlib import contextmanager +from functools import partial +from logging import warning +from typing import Any, Union, Callable + import torch +from torch.nn import Module +from torch.optim import Optimizer, LBFGS +from bigdl.nano.pytorch.utils import TORCH_VERSION_LESS_1_12 + import pytorch_lightning as pl from pytorch_lightning.strategies import SingleDeviceStrategy from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.plugins.precision import PrecisionPlugin +from pytorch_lightning.utilities import AMPType + from bigdl.nano.utils.log4Error import invalidInputError import intel_extension_for_pytorch as ipex +from intel_extension_for_pytorch.optim._optimizer_utils import IPEX_FUSED_OPTIMIZER_LIST from .ipex_accelerator import IPEXAccelerator @@ -34,7 +46,7 @@ def __init__( self, accelerator: Accelerator = IPEXAccelerator(), precision_plugin: PrecisionPlugin = PrecisionPlugin(), - enable_bf16=False, + dtype=None, ) -> None: """ Create a IPEXStrategy. @@ -42,8 +54,10 @@ def __init__( :param accelerator: the accelerator to handle hardware :param precision_plugin: the plugin to handle precision-specific parts """ - self.enable_bf16 = enable_bf16 + self.dtype = dtype + if self.dtype == torch.bfloat16 and isinstance(precision_plugin, PrecisionPlugin): + precision_plugin = IPEXBF16Precision() super().__init__(accelerator=accelerator, precision_plugin=precision_plugin) def setup(self, trainer: pl.Trainer) -> None: @@ -56,10 +70,56 @@ def setup(self, trainer: pl.Trainer) -> None: """ super().setup(trainer) - dtype = torch.bfloat16 if self.enable_bf16 else None if len(self.optimizers) == 0: - ipex.optimize(self.model, inplace=True, dtype=dtype) + ipex.optimize(self.model, inplace=True, dtype=self.dtype) elif len(self.optimizers) == 1: - ipex.optimize(self.model, optimizer=self.optimizers[0], inplace=True, dtype=dtype) + ipex.optimize(self.model, optimizer=self.optimizers[0], inplace=True, dtype=self.dtype) else: invalidInputError(False, "Ipex does not support more than one optimizers.") + + +class IPEXBF16Precision(PrecisionPlugin): + """Create Precision Plugin for IPEX BFloat16.""" + + precision: Union[str, int] = 'bf16' + + @contextmanager + def forward_context(self): + """AMP for managing model forward/training_step/evaluation_step/predict_step.""" + with torch.cpu.amp.autocast(): + yield + + def optimizer_step(self, + model: Union["pl.LightningModule", Module], + optimizer: Optimizer, + optimizer_idx: int, + closure: Callable[[], Any], + **kwargs: Any) -> Any: + """Hook to run the optimizer step.""" + if type(optimizer) in IPEX_FUSED_OPTIMIZER_LIST: + return super().optimizer_step(model, optimizer, optimizer_idx, closure, **kwargs) + + if isinstance(model, pl.LightningModule): + closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure) + + # Only `torch.optim.LBFGS` need to reevaluate closure multiple times + # in optimizer.step(...) now. 
+ if isinstance(optimizer, LBFGS): + invalidInputError(False, + "IPEX BFloat16 and the LBFGS optimizer are not compatible " + f"(optimizer {optimizer_idx}", + "Hint: Set 'use_ipex' to False or not set 'precision' to 'bf16'" + " if LBFGS optimizer is necessary") + + # Detect custom optimzer + if type(optimizer).__name__ not in dir(torch.optim): + warning("Seems like you are using a custom optimizer," + "please make sure that 'optimizer.step(closure)'" + " does not need to be called in training stage") + + # For optimizer not in IPEX_FUSED_OPTIMIZER_LIST, + # `closure()` needs to be called to backward the loss to avoid `.grad` being None + closure_result = closure() + optimizer.step(**kwargs) + + return closure_result diff --git a/python/nano/src/bigdl/nano/pytorch/strategies/ipex/version_1_9/ipex_strategy_1_9.py b/python/nano/src/bigdl/nano/pytorch/strategies/ipex/version_1_9/ipex_strategy_1_9.py index 71d3b72dd44..d7d93a32351 100644 --- a/python/nano/src/bigdl/nano/pytorch/strategies/ipex/version_1_9/ipex_strategy_1_9.py +++ b/python/nano/src/bigdl/nano/pytorch/strategies/ipex/version_1_9/ipex_strategy_1_9.py @@ -39,7 +39,7 @@ def __init__( self, accelerator: Accelerator = IPEXAccelerator(), # type: ignore precision_plugin: PrecisionPlugin = PrecisionPlugin(), - enable_bf16=False, + dtype=None, ) -> None: """ Create a IPEXStrategy. @@ -47,7 +47,7 @@ def __init__( :param accelerator: the accelerator to handle hardware :param precision_plugin: the plugin to handle precision-specific parts """ - if enable_bf16: + if dtype == torch.bfloat16: # Automatically mix precision ipex.enable_auto_mixed_precision(mixed_dtype=torch.bfloat16) diff --git a/python/nano/src/bigdl/nano/pytorch/torch_nano.py b/python/nano/src/bigdl/nano/pytorch/torch_nano.py index eee05e66142..4d9a461393f 100644 --- a/python/nano/src/bigdl/nano/pytorch/torch_nano.py +++ b/python/nano/src/bigdl/nano/pytorch/torch_nano.py @@ -52,45 +52,60 @@ class TorchNano(LightningLite): def __init__(self, num_processes: int = 1, use_ipex: bool = False, - enable_bf16: bool = False, strategy: str = "subprocess", + precision: Union[str, int] = 32, *args, **kwargs) -> None: """ Create a TorchNano with nano acceleration. :param num_processes: number of processes in distributed training, defaults to 1 :param use_ipex: whether use ipex acceleration, defaults to False - :param enable_bf16: whether use bf16 acceleration, defaults to False :param strategy: use which backend in distributed mode, defaults to "subprocess", \ now avaiable strategies are 'spawn', 'subprocess' and 'ray' + :param precision: Double precision (64), full precision (32), half precision (16) + or bfloat16 precision (bf16), defaults to 32. + Enable ipex bfloat16 weight prepack when `use_ipex=True` and `precision='bf16'` """ self.num_processes = num_processes self.use_ipex = use_ipex - self.enable_bf16 = enable_bf16 - - if TORCH_VERSION_LESS_1_11 and use_ipex and not check_avx512(): - warning("Enable ipex<=1.10 in a cpu instruction set" - " without avx512 will crash." - "Fall back to regular pytorch.") - self.use_ipex = False + self.dtype = None + if self.use_ipex and precision == 'bf16': + # Enable ipex bfloat16 weight prepack and disable native AMP + self.dtype = torch.bfloat16 + precision = 32 + + # Confirm if cpu supports AVX512 + if self.use_ipex and not check_avx512(): + if TORCH_VERSION_LESS_1_11: + warning("Enable ipex<=1.10 in a cpu instruction set" + " without avx512 will crash." 
+ "Fall back to regular pytorch.") + self.use_ipex = False + elif self.dtype == torch.bfloat16: + warning("Enable IPEX bfloat16 in a cpu instruction set" + " without avx512 will crash. " + "Using 32-bit precision") + self.dtype = None + + kwargs['precision'] = precision if self.num_processes == 1: if self.use_ipex: - strategy = create_IPEXStrategy(enable_bf16=self.enable_bf16) + strategy = create_IPEXStrategy(dtype=self.dtype) else: strategy = None # type: ignore elif strategy == "spawn": strategy = DDPSpawnStrategy(num_processes=self.num_processes, # type: ignore use_ipex=self.use_ipex, - enable_bf16=self.enable_bf16) + dtype=self.dtype) elif strategy == "subprocess": strategy = DDPSubprocessStrategy(num_processes=self.num_processes, # type: ignore use_ipex=self.use_ipex, - enable_bf16=self.enable_bf16) + dtype=self.dtype) elif strategy == "ray": strategy = create_RayStrategy(num_workers=self.num_processes, use_ipex=self.use_ipex, - enable_bf16=self.enable_bf16) + dtype=self.dtype) else: warning(f"Bigdl-nano doesn't support '{strategy}' strategy now, " f"'{strategy}' strategy of pytorch_lightning will be used. " @@ -118,16 +133,6 @@ def _setup( # so we have to add optimizations in this method, which will be called in # user defined `train()` method. - # add IPEX 1.11's optimization - if self.use_ipex and not TORCH_VERSION_LESS_1_10: - dtype = torch.bfloat16 if self.enable_bf16 else None - if len(optimizers) == 0: - ipex_optimize(model, inplace=True, dtype=dtype) - elif len(optimizers) == 1: - ipex_optimize(model, optimizer=optimizers[0], inplace=True, dtype=dtype) - else: - invalidInputError(False, "Ipex does not support more than one optimizers.") - # the following codes are copied from pl's LightningLite's `setup` method, # ipex 1.9 requires `_move_model_to_device` after `_setup_model_and_optimizers`, but # pl's `setup` method calls `_move_model_to_device` before `_setup_model_and_optimizers`, @@ -135,6 +140,18 @@ def _setup( self._validate_setup(model, optimizers) model, optimizers = self._strategy._setup_model_and_optimizers(model, optimizers) + + # IPEX bfloat16 optimization will cast model parameters to `torch.bfloat16` + # which is not supported by ddp currently, + # so add IPEX 1.11's optimization after `_setup_model` + if self.use_ipex and not TORCH_VERSION_LESS_1_10: + if len(optimizers) == 0: + ipex_optimize(model, inplace=True, dtype=self.dtype) + elif len(optimizers) == 1: + ipex_optimize(model, optimizer=optimizers[0], inplace=True, dtype=self.dtype) + else: + invalidInputError(False, "Ipex does not support more than one optimizers.") + if move_to_device: model = self._move_model_to_device(model=model, optimizers=optimizers) model = _TorchNanoModule(model, self._precision_plugin) diff --git a/python/nano/src/bigdl/nano/pytorch/trainer/Trainer.py b/python/nano/src/bigdl/nano/pytorch/trainer/Trainer.py index de7ba402026..85cbcecdc4f 100644 --- a/python/nano/src/bigdl/nano/pytorch/trainer/Trainer.py +++ b/python/nano/src/bigdl/nano/pytorch/trainer/Trainer.py @@ -57,12 +57,12 @@ class Trainer(pl.Trainer): def __init__(self, num_processes: int = 1, use_ipex: bool = False, - enable_bf16=False, distributed_backend="subprocess", cpu_for_each_process: Optional[List[List[int]]] = None, use_hpo=False, channels_last: bool = False, auto_lr: Union[int, bool] = True, + precision: Union[str, int] = 32, *args: Any, **kwargs: Any) -> None: """ A pytorch lightning trainer that uses bigdl-nano optimization. 
@@ -72,6 +72,9 @@ def __init__(self, num_processes: int = 1, :param cpu_for_each_process: A list of length `num_processes`, each containing a list of indices of cpus each process will be using. default: None, and the cpu will be automatically and evenly distributed among processes. + :param precision: Double precision (64), full precision (32), half precision (16) + or bfloat16 precision (bf16), defaults to 32. + Enable ipex bfloat16 weight prepack when `use_ipex=True` and `precision='bf16'` """ # Check keyword arguments if "accelerator" in kwargs: @@ -103,17 +106,31 @@ def __init__(self, num_processes: int = 1, else: kwargs["callbacks"] = [ChannelsLastCallback()] - if TORCH_VERSION_LESS_1_11 and use_ipex and not check_avx512(): - warning("Enable ipex<=1.10 in a cpu instruction set" - " without avx512 will crash." - "Fall back to regular pytorch.") - use_ipex = False - self.use_ipex = use_ipex + dtype = None + if self.use_ipex and precision == 'bf16': + # Enable ipex bfloat16 weight prepack and disable pytorch-lightning native AMP + dtype = torch.bfloat16 + precision = 32 + + # Confirm if cpu supports avx512 + if self.use_ipex and not check_avx512(): + if TORCH_VERSION_LESS_1_11: + warning("Enable ipex<=1.11 in a cpu instruction set" + " without avx512 will crash." + "Fall back to regular pytorch.") + self.use_ipex = False + elif dtype == torch.bfloat16: + warning("Enable IPEX bfloat16 in a cpu instruction set" + " without avx512 will crash. " + "Using 32-bit precision") + dtype = None + + kwargs['precision'] = precision if num_processes == 1: from bigdl.nano.pytorch.strategies import create_IPEXStrategy - strategy = create_IPEXStrategy(enable_bf16=enable_bf16) if self.use_ipex else None + strategy = create_IPEXStrategy(dtype=dtype) if self.use_ipex else None kwargs["strategy"] = strategy super().__init__(*args, **kwargs) else: @@ -132,20 +149,20 @@ def __init__(self, num_processes: int = 1, strategy = DDPSpawnStrategy(num_processes=num_processes, cpu_for_each_process=cpu_for_each_process, use_ipex=self.use_ipex, - enable_bf16=enable_bf16, + dtype=dtype, auto_lr=auto_lr) elif distributed_backend == "subprocess": from bigdl.nano.pytorch.strategies import DDPSubprocessStrategy strategy = DDPSubprocessStrategy(num_processes=num_processes, cpu_for_each_process=cpu_for_each_process, use_ipex=self.use_ipex, - enable_bf16=enable_bf16, + dtype=dtype, auto_lr=auto_lr) elif distributed_backend == "ray": from bigdl.nano.pytorch.strategies import create_RayStrategy strategy = create_RayStrategy(num_workers=num_processes, use_ipex=self.use_ipex, - enable_bf16=enable_bf16, + dtype=dtype, auto_lr=auto_lr) kwargs["strategy"] = strategy super().__init__(*args, **kwargs) diff --git a/python/nano/test/pytorch/tests/test_plugin_ipex.py b/python/nano/test/pytorch/tests/test_plugin_ipex.py index 609e1e449ad..7d9015aca7d 100644 --- a/python/nano/test/pytorch/tests/test_plugin_ipex.py +++ b/python/nano/test/pytorch/tests/test_plugin_ipex.py @@ -24,10 +24,12 @@ from bigdl.nano.pytorch.lightning import LightningModule from bigdl.nano.pytorch import Trainer +from bigdl.nano.common import check_avx512 +from bigdl.nano.pytorch.utils import TORCH_VERSION_LESS_1_10 from test.pytorch.utils._train_torch_lightning import create_data_loader, data_transform from test.pytorch.utils._train_torch_lightning import create_test_data_loader -from test.pytorch.utils._train_ipex_callback import CheckIPEXCallback +from test.pytorch.utils._train_ipex_callback import CheckIPEXCallback, CheckIPEXFusedStepCallback from 
test.pytorch.tests.test_lightning import ResNet18 num_classes = 10 @@ -64,6 +66,26 @@ def test_trainer_subprocess_plugin(self): trainer.fit(pl_model, self.data_loader, self.test_data_loader) trainer.test(pl_model, self.test_data_loader) + def test_trainer_spawn_plugin_bf16(self): + # IPEX BF16 weight prepack needs the cpu support avx512bw, avx512vl and avx512dq + model = ResNet18(pretrained=False, include_top=False, freeze=True) + loss = nn.CrossEntropyLoss() + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + pl_model = LightningModule( + model, loss, optimizer, + metrics=[torchmetrics.F1(num_classes), torchmetrics.Accuracy(num_classes=10)] + ) + trainer = Trainer(num_processes=2, distributed_backend="spawn", + max_epochs=4, use_ipex=True, precision="bf16", + callbacks=[CheckIPEXCallback(), CheckIPEXFusedStepCallback()]) + trainer.fit(pl_model, self.data_loader, self.test_data_loader) + trainer.test(pl_model, self.test_data_loader) + if trainer.use_ipex and TORCH_VERSION_LESS_1_10: + import intel_pytorch_extension as ipex + # Diable IPEX AMP + # Avoid affecting other tests + ipex.enable_auto_mixed_precision(None) + if __name__ == '__main__': pytest.main([__file__]) diff --git a/python/nano/test/pytorch/tests/test_torch_nano_ipex.py b/python/nano/test/pytorch/tests/test_torch_nano_ipex.py index 3c01e1e639e..d9c25590a7b 100644 --- a/python/nano/test/pytorch/tests/test_torch_nano_ipex.py +++ b/python/nano/test/pytorch/tests/test_torch_nano_ipex.py @@ -45,10 +45,13 @@ def forward(self, x): class MyNano(TorchNano): - def train(self): + def train(self, optimizer_supported: bool = False): model = ResNet18(10, pretrained=False, include_top=False, freeze=True) loss_func = nn.CrossEntropyLoss() - optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + if optimizer_supported: + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + else: + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) train_loader = create_data_loader(data_dir, batch_size, num_workers, data_transform) model, optimizer, train_loader = self.setup(model, optimizer, train_loader) @@ -63,7 +66,6 @@ def train(self): loss = loss_func(model(X), y) self.backward(loss) optimizer.step() - total_loss += loss.sum() num += 1 print(f'avg_loss: {total_loss / num}') @@ -132,6 +134,18 @@ def test_torch_nano_spawn_correctness(self): def test_torch_nano_subprocess_correctness(self): MyNanoCorrectness(use_ipex=True, num_processes=2, strategy="subprocess").train(0.5) + def test_torch_nano_bf16_support_opt(self): + MyNano(use_ipex=True, precision='bf16').train(optimizer_supported=True) + + def test_torch_nano_bf16_unsupport_opt(self): + MyNano(use_ipex=True, precision='bf16').train() + + def test_torch_nano_bf16_spawn(self): + MyNano(use_ipex=True, precision='bf16', num_processes=2, strategy="spawn").train() + + def test_torch_nano_bf16_subprocess(self): + MyNano(use_ipex=True, precision='bf16', num_processes=2, strategy="subprocess").train() + if __name__ == '__main__': pytest.main([__file__]) diff --git a/python/nano/test/pytorch/tests/test_trainer_ipex.py b/python/nano/test/pytorch/tests/test_trainer_ipex.py index 16847bcc35a..74f4e23fd0d 100644 --- a/python/nano/test/pytorch/tests/test_trainer_ipex.py +++ b/python/nano/test/pytorch/tests/test_trainer_ipex.py @@ -22,10 +22,13 @@ import torch from torch.optim.lr_scheduler import OneCycleLR from test.pytorch.utils._train_torch_lightning import create_data_loader, data_transform +from test.pytorch.utils._train_ipex_callback import CheckIPEXFusedStepCallback from torch import 
nn from bigdl.nano.pytorch import Trainer from bigdl.nano.pytorch.vision.models import vision +from bigdl.nano.pytorch.utils import TORCH_VERSION_LESS_1_10 +from bigdl.nano.common import check_avx512 batch_size = 256 max_epochs = 2 @@ -52,11 +55,11 @@ class TestTrainer(TestCase): optimizer = torch.optim.Adam(model.parameters(), lr=0.01) scheduler_dict = { "scheduler": OneCycleLR( - optimizer, - 0.1, - epochs=max_epochs, - steps_per_epoch=len(train_loader), - ), + optimizer, + 0.1, + epochs=max_epochs, + steps_per_epoch=len(train_loader), + ), "interval": "step", } @@ -66,6 +69,63 @@ def test_trainer_save_checkpoint(self): pl_model = Trainer.compile(self.model, self.loss, self.optimizer, self.scheduler_dict) trainer.fit(pl_model, self.train_loader) + def test_trainer_ipex_bf16(self): + # IPEX BF16 weight prepack needs the cpu support avx512bw, avx512vl and avx512dq + trainer = Trainer(max_epochs=max_epochs, use_ipex=True, precision="bf16", + callbacks=[CheckIPEXFusedStepCallback()]) + + # use_ipex=True will perform inplace optimization + model = ResNet18(10, pretrained=False, include_top=False, freeze=True) + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + loss = nn.CrossEntropyLoss() + scheduler_dict = { + "scheduler": OneCycleLR( + optimizer, + 0.1, + epochs=max_epochs, + steps_per_epoch=len(self.train_loader), + ), + "interval": "step", + } + + pl_model = Trainer.compile(model, loss, optimizer, scheduler_dict) + trainer.fit(pl_model, self.train_loader) + trainer.test(pl_model, self.train_loader) + + if trainer.use_ipex and TORCH_VERSION_LESS_1_10: + import intel_pytorch_extension as ipex + # Diable IPEX AMP + # Avoid affecting other tests + ipex.enable_auto_mixed_precision(None) + + def test_trainer_ipex_bf16_unspport_optim(self): + # IPEX BF16 weight prepack needs the cpu support avx512bw, avx512vl and avx512dq + trainer = Trainer(max_epochs=max_epochs, use_ipex=True, precision="bf16", + callbacks=[CheckIPEXFusedStepCallback()]) + + model = ResNet18(10, pretrained=False, include_top=False, freeze=True) + optimizer = torch.optim.AdamW(model.parameters(), lr=0.01, weight_decay=5e-4) + loss = nn.CrossEntropyLoss() + scheduler_dict = { + "scheduler": OneCycleLR( + optimizer, + 0.1, + epochs=max_epochs, + steps_per_epoch=len(self.train_loader), + ), + "interval": "step", + } + + pl_model = Trainer.compile(model, loss, optimizer, scheduler_dict) + trainer.fit(pl_model, self.train_loader) + trainer.test(pl_model, self.train_loader) + + if trainer.use_ipex and TORCH_VERSION_LESS_1_10: + import intel_pytorch_extension as ipex + # Diable IPEX AMP + # Avoid affecting other tests + ipex.enable_auto_mixed_precision(None) + if __name__ == '__main__': pytest.main([__file__]) diff --git a/python/nano/test/pytorch/tests/test_trainer_precision.py b/python/nano/test/pytorch/tests/test_trainer_precision.py new file mode 100644 index 00000000000..d30b2a9c3d9 --- /dev/null +++ b/python/nano/test/pytorch/tests/test_trainer_precision.py @@ -0,0 +1,85 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import os +from unittest import TestCase + +import pytest +import torch +from torch import nn +from torch.utils.data import DataLoader, TensorDataset +from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin +from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin + +from bigdl.nano.pytorch import Trainer +from bigdl.nano.pytorch.vision.models import vision +from bigdl.nano.pytorch.utils import TORCH_VERSION_LESS_1_10 + +from test.pytorch.tests.test_scale_lr import ResNetBase +from test.pytorch.utils._train_torch_lightning import (create_data_loader, + create_test_data_loader, + data_transform) + +batch_size = 32 +dataset_size = 256 +num_workers = 0 +data_dir = os.path.join(os.path.dirname(__file__), "../data") + + +class ResNet18(nn.Module): + def __init__(self, num_classes, pretrained=True, include_top=False, freeze=True): + super().__init__() + backbone = vision.resnet18(pretrained=pretrained, include_top=include_top, freeze=freeze) + output_size = backbone.get_output_size() + head = nn.Linear(output_size, num_classes) + self.model = nn.Sequential(backbone, head) + + def forward(self, x): + return self.model(x) + + +class TestTrainer(TestCase): + train_loader = create_data_loader(data_dir, batch_size, num_workers, + data_transform, dataset_size) + test_loader = create_test_data_loader(data_dir, batch_size, num_workers, + data_transform, dataset_size) + + def test_trainer_precision(self): + model = ResNet18(10, pretrained=False, include_top=False, freeze=True) + loss = nn.CrossEntropyLoss() + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + pl_model = Trainer.compile(model, loss, optimizer) + # torch must be greater or euqal to 1.10 to use native amp for bfloat16 precision + if TORCH_VERSION_LESS_1_10: + trainer = Trainer(max_epochs=2, precision=64) + trainer.fit(pl_model, self.train_loader) + assert isinstance(trainer.strategy.precision_plugin, DoublePrecisionPlugin) + opt = pl_model.optimizers() + assert opt.param_groups[0]['params'][0].dtype is torch.float64 + else: + trainer = Trainer(max_epochs=2, precision='bf16') + trainer.fit(pl_model, self.train_loader) + assert isinstance(trainer.strategy.precision_plugin, NativeMixedPrecisionPlugin) + # model is not converted to bfloat16 precision + input = TensorDataset(torch.rand(1, 3, 32, 32)) + train_loader = DataLoader(input) + y_hat = trainer.predict(pl_model, train_loader) + assert y_hat[0].dtype is torch.bfloat16 + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/python/nano/test/pytorch/utils/_train_ipex_callback.py b/python/nano/test/pytorch/utils/_train_ipex_callback.py index 2dd36b59c07..1b972108892 100644 --- a/python/nano/test/pytorch/utils/_train_ipex_callback.py +++ b/python/nano/test/pytorch/utils/_train_ipex_callback.py @@ -17,19 +17,23 @@ import torch import warnings from typing import Dict +import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback from pytorch_lightning.plugins.training_type import SingleDevicePlugin, DDPSpawnPlugin -from bigdl.nano.pytorch.utils import TORCH_VERSION_LESS_1_10 from pytorch_lightning.accelerators.cpu import CPUAccelerator +from bigdl.nano.pytorch.utils import TORCH_VERSION_LESS_1_10 +from bigdl.nano.common import check_avx512 + class CheckIPEXCallback(Callback): def on_train_start(self, trainer, pl_module): - if trainer.use_ipex == False: - warnings.warn("CheckIPEXCallback is 
used, but ipex is disabled. ")
-            return
+        if not trainer.use_ipex:
+            warnings.warn("CheckIPEXCallback is used, but ipex is disabled. ")
+            return
         if TORCH_VERSION_LESS_1_10:
             from bigdl.nano.deps.ipex.version_1_9.ipex_torchfunctional import RESTORE_TYPE
+
             def check_device(obj):
                 if torch.is_tensor(obj):
                     if obj.device.type == 'xpu':
@@ -45,15 +49,18 @@ def check_device(obj):
             assert check_device(pl_module.state_dict())
         else:
             from intel_extension_for_pytorch.nn.utils._model_convert import _LSTM
-            from intel_extension_for_pytorch.nn.utils._weight_prepack import _IPEXConvNd, _IPEXLinear, _IPEXConvTransposeNd
+            from intel_extension_for_pytorch.nn.utils._weight_prepack import (_IPEXConvNd,
+                                                                              _IPEXLinear,
+                                                                              _IPEXConvTransposeNd)
             IPEX_LAYERS = (_LSTM,
                            _IPEXConvNd,
                            _IPEXLinear,
                            _IPEXConvTransposeNd)
-            IPEX_ATTR = ('master_weight',
-                        'weight_trail',
-                        'master_bias',
-                        'bias_trail')
+            IPEX_ATTR = ('master_weight',
+                         'weight_trail',
+                         'master_bias',
+                         'bias_trail')
+
             def check_ipex_layers(m):
                 if isinstance(m, IPEX_LAYERS):
                     print("model is optimized by IPEX")
@@ -68,3 +75,19 @@ def check_ipex_layers(m):
                 return False
         assert check_ipex_layers(pl_module)
 
+
+class CheckIPEXFusedStepCallback(Callback):
+    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
+        if not check_avx512():
+            # IPEX BF16 weight prepack needs the CPU to support avx512bw, avx512vl and avx512dq
+            return
+        if not TORCH_VERSION_LESS_1_10:
+            from intel_extension_for_pytorch.optim._optimizer_utils import IPEX_FUSED_OPTIMIZER_LIST
+            # IPEX only supports one optimizer
+            opt = trainer.optimizers[0]
+            if type(opt) in IPEX_FUSED_OPTIMIZER_LIST:
+                assert opt.fused  # type: ignore
+            else:
+                # Check the non-fused step
+                assert hasattr(opt, '_original_step')
+                assert getattr(opt, 'step') is not getattr(type(opt), 'step')
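
Usage note (not part of the patch): with this change the `enable_bf16` flag is replaced by the standard `precision` argument across `Trainer`, `TorchNano` and the strategies, so IPEX bfloat16 is now requested as `use_ipex=True, precision='bf16'`. Below is a minimal sketch of the `Trainer` path, mirroring `test_trainer_ipex.py`; the toy model, data and hyper-parameters are illustrative placeholders, not taken from this patch.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from bigdl.nano.pytorch import Trainer

# Placeholder model and random data, for illustration only
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
loss = nn.CrossEntropyLoss()
# Per the tests in this patch, SGD takes IPEX's fused optimizer step
# (cf. CheckIPEXFusedStepCallback); AdamW exercises the non-fused fallback.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

x = torch.rand(64, 3, 32, 32)
y = torch.randint(0, 10, (64,))
train_loader = DataLoader(TensorDataset(x, y), batch_size=16)

# New API: IPEX bfloat16 weight prepack is enabled via precision='bf16'
# (previously enable_bf16=True); on CPUs without AVX512 the Trainer warns
# and falls back to 32-bit precision.
trainer = Trainer(max_epochs=1, use_ipex=True, precision="bf16")
pl_model = Trainer.compile(model, loss, optimizer)
trainer.fit(pl_model, train_loader)
```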
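The same option applies to `TorchNano` (mirroring `MyNano` in `test_torch_nano_ipex.py`): per the change in `torch_nano.py`, `ipex.optimize(..., dtype=torch.bfloat16)` now runs inside the setup call after the model and optimizers are wired up, so user training code only changes the constructor arguments. Again a sketch with placeholder model and data, assuming a bigdl-nano build that includes this patch.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from bigdl.nano.pytorch.torch_nano import TorchNano


class MyNano(TorchNano):
    def train(self):
        # Placeholder model and random data, for illustration only
        model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
        loss_func = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        x = torch.rand(64, 3, 32, 32)
        y = torch.randint(0, 10, (64,))
        loader = DataLoader(TensorDataset(x, y), batch_size=16)

        # When use_ipex=True and precision='bf16', the IPEX bfloat16
        # optimization is applied during setup (see _setup in torch_nano.py)
        model, optimizer, loader = self.setup(model, optimizer, loader)

        model.train()
        for X, target in loader:
            optimizer.zero_grad()
            loss = loss_func(model(X), target)
            self.backward(loss)
            optimizer.step()


# bf16 now goes through the `precision` argument instead of the removed `enable_bf16`
MyNano(use_ipex=True, precision='bf16').train()
```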