Update nano training ipex bf16 (#5333)
* Update

* Fix code style

* re-run action

* Fix code style

* re-run action

* Update

* Fix code style

* support bf16 multi training

* reduce ut time and re-run action

* track avx512

* Update lite bf16 training

* Update

* Update

* Update bf16 api

* Update

* Fix typo
y199387 authored Aug 24, 2022
1 parent c74b6d8 commit 2966d49
Showing 11 changed files with 376 additions and 71 deletions.
10 changes: 4 additions & 6 deletions python/nano/src/bigdl/nano/deps/ray/ray_distributed.py
@@ -176,7 +176,7 @@ def __init__(self,
num_cpus_per_worker: int = 1,
use_gpu: bool = False,
use_ipex: bool = False,
enable_bf16: bool = False,
dtype=None,
init_hook: Callable = None,
auto_lr: Union[bool, dict] = True,
**ddp_kwargs: Any):
@@ -207,7 +207,7 @@ def __init__(self,
self.num_cpus_per_worker = num_cpus_per_worker
self.use_gpu = use_gpu
self.use_ipex = use_ipex
self.enable_bf16 = enable_bf16
self.dtype = dtype
self.auto_lr = auto_lr

invalidInputError(not self.use_gpu or not self.use_ipex,
@@ -328,14 +328,12 @@ def _unpack_lightning_optimizer(opt):
]

if self.use_ipex and not TORCH_VERSION_LESS_1_10:
dtype = torch.bfloat16 if self.enable_bf16 else None
num_optimizers = len(self.optimizers)

if num_optimizers == 1:
optimizer = self.optimizers[0]
ipex_optimize(self.model, optimizer=optimizer, inplace=True, dtype=dtype)
ipex_optimize(self.model, optimizer=optimizer, inplace=True, dtype=self.dtype)
elif num_optimizers == 0:
ipex_optimize(self.model, inplace=True, dtype=dtype)
ipex_optimize(self.model, inplace=True, dtype=self.dtype)
else:
warnings.warn(f"IPEX currently only support single optimizers, "
f"but got {num_optimizers}. Skip IPEX")
21 changes: 15 additions & 6 deletions python/nano/src/bigdl/nano/pytorch/strategies/ddp_spawn.py
@@ -168,7 +168,7 @@ def __init__(
num_processes: int = 1,
cpu_for_each_process: Optional[List[List[int]]] = None,
use_ipex=False,
enable_bf16=False,
dtype=None,
auto_lr=False,
**kwargs: Any
):
@@ -180,14 +180,24 @@ def __init__(
if use_ipex and TORCH_VERSION_LESS_1_10 and 'accelerator' not in kwargs:
super().__init__(accelerator=create_IPEXAccelerator(),
parallel_devices=parallel_devices,
cluster_environment=cluster_environment, **kwargs)
cluster_environment=cluster_environment,
**kwargs)
if dtype == torch.bfloat16:
import intel_pytorch_extension as ipex
# Automatically mix precision
ipex.enable_auto_mixed_precision(mixed_dtype=torch.bfloat16)
elif use_ipex and dtype == torch.bfloat16 and 'precision_plugin' not in kwargs:
from bigdl.nano.pytorch.strategies.ipex.ipex_strategy import IPEXBF16Precision
super().__init__(parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
precision_plugin=IPEXBF16Precision(), **kwargs)
else:
super().__init__(parallel_devices=parallel_devices,
cluster_environment=cluster_environment, **kwargs)
self.cpu_for_each_process = cpu_for_each_process
self.is_distributed = True
self.use_ipex = use_ipex
self.enable_bf16 = enable_bf16
self.dtype = dtype
self.auto_lr = auto_lr

def _configure_launcher(self):
@@ -249,14 +259,13 @@ def _unpack_lightning_optimizer(opt):
]

if self.use_ipex and not TORCH_VERSION_LESS_1_10:
dtype = torch.bfloat16 if self.enable_bf16 else None
num_optimizers = len(self.optimizers)

if num_optimizers == 1:
optimizer = self.optimizers[0]
ipex_optimize(self.model, optimizer=optimizer, inplace=True, dtype=dtype)
ipex_optimize(self.model, optimizer=optimizer, inplace=True, dtype=self.dtype)
elif num_optimizers == 0:
ipex_optimize(self.model, inplace=True, dtype=dtype)
ipex_optimize(self.model, inplace=True, dtype=self.dtype)
else:
warnings.warn(f"IPEX currently only support single optimizers, "
f"but got {num_optimizers}. Skip IPEX")
@@ -14,13 +14,25 @@
# limitations under the License.
#

from contextlib import contextmanager
from functools import partial
from logging import warning
from typing import Any, Union, Callable

import torch
from torch.nn import Module
from torch.optim import Optimizer, LBFGS
from bigdl.nano.pytorch.utils import TORCH_VERSION_LESS_1_12

import pytorch_lightning as pl
from pytorch_lightning.strategies import SingleDeviceStrategy
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.utilities import AMPType

from bigdl.nano.utils.log4Error import invalidInputError
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.optim._optimizer_utils import IPEX_FUSED_OPTIMIZER_LIST

from .ipex_accelerator import IPEXAccelerator

@@ -34,16 +46,18 @@ def __init__(
self,
accelerator: Accelerator = IPEXAccelerator(),
precision_plugin: PrecisionPlugin = PrecisionPlugin(),
enable_bf16=False,
dtype=None,
) -> None:
"""
Create an IPEXStrategy.
:param accelerator: the accelerator to handle hardware
:param precision_plugin: the plugin to handle precision-specific parts
"""
self.enable_bf16 = enable_bf16
self.dtype = dtype

if self.dtype == torch.bfloat16 and isinstance(precision_plugin, PrecisionPlugin):
precision_plugin = IPEXBF16Precision()
super().__init__(accelerator=accelerator, precision_plugin=precision_plugin)

def setup(self, trainer: pl.Trainer) -> None:
@@ -56,10 +70,56 @@ def setup(self, trainer: pl.Trainer) -> None:
"""
super().setup(trainer)

dtype = torch.bfloat16 if self.enable_bf16 else None
if len(self.optimizers) == 0:
ipex.optimize(self.model, inplace=True, dtype=dtype)
ipex.optimize(self.model, inplace=True, dtype=self.dtype)
elif len(self.optimizers) == 1:
ipex.optimize(self.model, optimizer=self.optimizers[0], inplace=True, dtype=dtype)
ipex.optimize(self.model, optimizer=self.optimizers[0], inplace=True, dtype=self.dtype)
else:
invalidInputError(False, "Ipex does not support more than one optimizers.")


class IPEXBF16Precision(PrecisionPlugin):
"""Create Precision Plugin for IPEX BFloat16."""

precision: Union[str, int] = 'bf16'

@contextmanager
def forward_context(self):
"""AMP for managing model forward/training_step/evaluation_step/predict_step."""
with torch.cpu.amp.autocast():
yield

def optimizer_step(self,
model: Union["pl.LightningModule", Module],
optimizer: Optimizer,
optimizer_idx: int,
closure: Callable[[], Any],
**kwargs: Any) -> Any:
"""Hook to run the optimizer step."""
if type(optimizer) in IPEX_FUSED_OPTIMIZER_LIST:
return super().optimizer_step(model, optimizer, optimizer_idx, closure, **kwargs)

if isinstance(model, pl.LightningModule):
closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)

# Only `torch.optim.LBFGS` needs to reevaluate the closure multiple times
# in optimizer.step(...) now.
if isinstance(optimizer, LBFGS):
invalidInputError(False,
"IPEX BFloat16 and the LBFGS optimizer are not compatible "
f"(optimizer {optimizer_idx}",
"Hint: Set 'use_ipex' to False or not set 'precision' to 'bf16'"
" if LBFGS optimizer is necessary")

# Detect custom optimizer
if type(optimizer).__name__ not in dir(torch.optim):
warning("Seems like you are using a custom optimizer,"
"please make sure that 'optimizer.step(closure)'"
" does not need to be called in training stage")

# For optimizer not in IPEX_FUSED_OPTIMIZER_LIST,
# `closure()` needs to be called to backward the loss to avoid `.grad` being None
closure_result = closure()
optimizer.step(**kwargs)

return closure_result
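
To make the `forward_context` above concrete: it simply runs each forward/step under CPU autocast, so matmul-heavy ops execute in bfloat16 while the rest stays in float32. A small illustrative sketch in plain PyTorch (the toy model and shapes are made up):

import torch

model = torch.nn.Linear(16, 4)
x = torch.randn(2, 16)

# Roughly what IPEXBF16Precision.forward_context wraps around a forward pass.
with torch.cpu.amp.autocast():
    y = model(x)

print(y.dtype)  # torch.bfloat16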
@@ -39,15 +39,15 @@ def __init__(
self,
accelerator: Accelerator = IPEXAccelerator(), # type: ignore
precision_plugin: PrecisionPlugin = PrecisionPlugin(),
enable_bf16=False,
dtype=None,
) -> None:
"""
Create an IPEXStrategy.
:param accelerator: the accelerator to handle hardware
:param precision_plugin: the plugin to handle precision-specific parts
"""
if enable_bf16:
if dtype == torch.bfloat16:
# Automatically mix precision
ipex.enable_auto_mixed_precision(mixed_dtype=torch.bfloat16)
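
For context, the BF16 handling now branches on the IPEX generation: the legacy package flips a global auto-mixed-precision switch, while newer IPEX optimizes each model/optimizer pair. A hedged summary sketch mirroring the dispatch shown in the strategies above (package names taken from the diff; the helper name is made up):

import torch

def apply_bf16(model, optimizer, torch_version_less_1_10: bool):
    if torch_version_less_1_10:
        # Legacy IPEX (torch < 1.10): global auto mixed precision switch.
        import intel_pytorch_extension as ipex
        ipex.enable_auto_mixed_precision(mixed_dtype=torch.bfloat16)
        return model, optimizer
    # IPEX >= 1.10: per-model optimization with BF16 weight prepacking.
    import intel_extension_for_pytorch as ipex
    return ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)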

63 changes: 40 additions & 23 deletions python/nano/src/bigdl/nano/pytorch/torch_nano.py
@@ -52,45 +52,60 @@ class TorchNano(LightningLite):

def __init__(self, num_processes: int = 1,
use_ipex: bool = False,
enable_bf16: bool = False,
strategy: str = "subprocess",
precision: Union[str, int] = 32,
*args, **kwargs) -> None:
"""
Create a TorchNano with nano acceleration.
:param num_processes: number of processes in distributed training, defaults to 1
:param use_ipex: whether use ipex acceleration, defaults to False
:param enable_bf16: whether use bf16 acceleration, defaults to False
:param strategy: use which backend in distributed mode, defaults to "subprocess", \
now available strategies are 'spawn', 'subprocess' and 'ray'
:param precision: Double precision (64), full precision (32), half precision (16)
or bfloat16 precision (bf16), defaults to 32.
Enable ipex bfloat16 weight prepack when `use_ipex=True` and `precision='bf16'`
"""
self.num_processes = num_processes
self.use_ipex = use_ipex
self.enable_bf16 = enable_bf16

if TORCH_VERSION_LESS_1_11 and use_ipex and not check_avx512():
warning("Enable ipex<=1.10 in a cpu instruction set"
" without avx512 will crash."
"Fall back to regular pytorch.")
self.use_ipex = False
self.dtype = None
if self.use_ipex and precision == 'bf16':
# Enable ipex bfloat16 weight prepack and disable native AMP
self.dtype = torch.bfloat16
precision = 32

# Confirm if cpu supports AVX512
if self.use_ipex and not check_avx512():
if TORCH_VERSION_LESS_1_11:
warning("Enable ipex<=1.10 in a cpu instruction set"
" without avx512 will crash."
"Fall back to regular pytorch.")
self.use_ipex = False
elif self.dtype == torch.bfloat16:
warning("Enable IPEX bfloat16 in a cpu instruction set"
" without avx512 will crash. "
"Using 32-bit precision")
self.dtype = None

kwargs['precision'] = precision

if self.num_processes == 1:
if self.use_ipex:
strategy = create_IPEXStrategy(enable_bf16=self.enable_bf16)
strategy = create_IPEXStrategy(dtype=self.dtype)
else:
strategy = None # type: ignore
elif strategy == "spawn":
strategy = DDPSpawnStrategy(num_processes=self.num_processes, # type: ignore
use_ipex=self.use_ipex,
enable_bf16=self.enable_bf16)
dtype=self.dtype)
elif strategy == "subprocess":
strategy = DDPSubprocessStrategy(num_processes=self.num_processes, # type: ignore
use_ipex=self.use_ipex,
enable_bf16=self.enable_bf16)
dtype=self.dtype)
elif strategy == "ray":
strategy = create_RayStrategy(num_workers=self.num_processes,
use_ipex=self.use_ipex,
enable_bf16=self.enable_bf16)
dtype=self.dtype)
else:
warning(f"Bigdl-nano doesn't support '{strategy}' strategy now, "
f"'{strategy}' strategy of pytorch_lightning will be used. "
@@ -118,23 +133,25 @@ def _setup(
# so we have to add optimizations in this method, which will be called in
# user defined `train()` method.

# add IPEX 1.11's optimization
if self.use_ipex and not TORCH_VERSION_LESS_1_10:
dtype = torch.bfloat16 if self.enable_bf16 else None
if len(optimizers) == 0:
ipex_optimize(model, inplace=True, dtype=dtype)
elif len(optimizers) == 1:
ipex_optimize(model, optimizer=optimizers[0], inplace=True, dtype=dtype)
else:
invalidInputError(False, "Ipex does not support more than one optimizers.")

# the following codes are copied from pl's LightningLite's `setup` method,
# ipex 1.9 requires `_move_model_to_device` after `_setup_model_and_optimizers`, but
# pl's `setup` method calls `_move_model_to_device` before `_setup_model_and_optimizers`,
# so we copy the codes and swap their order.
self._validate_setup(model, optimizers)

model, optimizers = self._strategy._setup_model_and_optimizers(model, optimizers)

# IPEX bfloat16 optimization will cast model parameters to `torch.bfloat16`
# which is not supported by ddp currently,
# so add IPEX 1.11's optimization after `_setup_model`
if self.use_ipex and not TORCH_VERSION_LESS_1_10:
if len(optimizers) == 0:
ipex_optimize(model, inplace=True, dtype=self.dtype)
elif len(optimizers) == 1:
ipex_optimize(model, optimizer=optimizers[0], inplace=True, dtype=self.dtype)
else:
invalidInputError(False, "IPEX does not support more than one optimizer.")

if move_to_device:
model = self._move_model_to_device(model=model, optimizers=optimizers)
model = _TorchNanoModule(model, self._precision_plugin)
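
Taken together, the torch_nano.py changes replace `enable_bf16` with the standard `precision` argument. A hypothetical end-to-end usage sketch; the `bigdl.nano.pytorch.TorchNano` import path and the subclass-with-`train()` pattern are assumptions based on this repository's API rather than part of this diff, and the toy model/data are made up:

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from bigdl.nano.pytorch import TorchNano  # assumed import path

class MyNano(TorchNano):
    def train(self):
        model = nn.Linear(16, 4)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        loader = DataLoader(TensorDataset(torch.randn(64, 16), torch.randn(64, 4)),
                            batch_size=8)
        # setup() applies the IPEX optimization (with dtype=torch.bfloat16 when
        # precision='bf16') after DDP wrapping, as in _setup() above.
        model, optimizer = self.setup(model, optimizer)
        loader = self.setup_dataloaders(loader)
        loss_fn = nn.MSELoss()
        for x, y in loader:
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            self.backward(loss)
            optimizer.step()

# use_ipex=True plus precision='bf16' requests IPEX BF16 weight prepacking;
# on CPUs without AVX512 the constructor falls back to FP32 with a warning.
MyNano(num_processes=2, use_ipex=True, precision='bf16').train()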