Update bf16 api
y199387 committed Aug 23, 2022
1 parent 5bcfab0 commit 3fad3ba
Showing 6 changed files with 59 additions and 63 deletions.
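In short, the commit replaces the boolean `enable_bf16` flag on the nano strategies with a `dtype` argument (a `torch.dtype` or `None`) that is forwarded directly to `ipex.optimize`, while `Trainer` and `TorchNano` gain an explicit `precision` parameter. A minimal sketch of the call-site change, assuming the `create_IPEXStrategy` import path that appears in Trainer.py below:

    import torch
    from bigdl.nano.pytorch.strategies import create_IPEXStrategy

    # Before this commit the caller passed a boolean flag:
    #   strategy = create_IPEXStrategy(enable_bf16=True)
    # After this commit the caller passes the target dtype (or None for FP32):
    strategy = create_IPEXStrategy(dtype=torch.bfloat16)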
10 changes: 4 additions & 6 deletions python/nano/src/bigdl/nano/deps/ray/ray_distributed.py
@@ -176,7 +176,7 @@ def __init__(self,
num_cpus_per_worker: int = 1,
use_gpu: bool = False,
use_ipex: bool = False,
enable_bf16: bool = False,
dtype=None,
init_hook: Callable = None,
auto_lr: Union[bool, dict] = True,
**ddp_kwargs: Any):
@@ -207,7 +207,7 @@ def __init__(self,
self.num_cpus_per_worker = num_cpus_per_worker
self.use_gpu = use_gpu
self.use_ipex = use_ipex
self.enable_bf16 = enable_bf16
self.dtype = dtype
self.auto_lr = auto_lr

invalidInputError(not self.use_gpu or not self.use_ipex,
@@ -328,14 +328,12 @@ def _unpack_lightning_optimizer(opt):
]

if self.use_ipex and not TORCH_VERSION_LESS_1_10:
dtype = torch.bfloat16 if self.enable_bf16 else None
num_optimizers = len(self.optimizers)

if num_optimizers == 1:
optimizer = self.optimizers[0]
ipex_optimize(self.model, optimizer=optimizer, inplace=True, dtype=dtype)
ipex_optimize(self.model, optimizer=optimizer, inplace=True, dtype=self.dtype)
elif num_optimizers == 0:
ipex_optimize(self.model, inplace=True, dtype=dtype)
ipex_optimize(self.model, inplace=True, dtype=self.dtype)
else:
warnings.warn(f"IPEX currently only support single optimizers, "
f"but got {num_optimizers}. Skip IPEX")
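With the flag gone, the strategy simply hands `self.dtype` through to IPEX. A minimal sketch of the underlying call, assuming `intel_extension_for_pytorch` >= 1.10 is installed and that the `ipex_optimize` alias used above wraps `ipex.optimize`:

    import torch
    import intel_extension_for_pytorch as ipex

    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # dtype=None keeps FP32 weights; dtype=torch.bfloat16 additionally prepacks
    # the weights in bfloat16, which is what self.dtype now selects.
    model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)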
13 changes: 6 additions & 7 deletions python/nano/src/bigdl/nano/pytorch/strategies/ddp_spawn.py
@@ -168,7 +168,7 @@ def __init__(
num_processes: int = 1,
cpu_for_each_process: Optional[List[List[int]]] = None,
use_ipex=False,
enable_bf16=False,
dtype=None,
auto_lr=False,
**kwargs: Any
):
@@ -182,11 +182,11 @@ def __init__(
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
**kwargs)
if enable_bf16:
if dtype == torch.bfloat16:
import intel_pytorch_extension as ipex
# Automatically mix precision
ipex.enable_auto_mixed_precision(mixed_dtype=torch.bfloat16)
elif use_ipex and enable_bf16 and 'precision_plugin' not in kwargs:
elif use_ipex and dtype == torch.bfloat16 and 'precision_plugin' not in kwargs:
from bigdl.nano.pytorch.strategies.ipex.ipex_strategy import IPEXBF16Precision
super().__init__(parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
@@ -197,7 +197,7 @@ def __init__(
self.cpu_for_each_process = cpu_for_each_process
self.is_distributed = True
self.use_ipex = use_ipex
self.enable_bf16 = enable_bf16
self.dtype = dtype
self.auto_lr = auto_lr

def _configure_launcher(self):
@@ -259,14 +259,13 @@ def _unpack_lightning_optimizer(opt):
]

if self.use_ipex and not TORCH_VERSION_LESS_1_10:
dtype = torch.bfloat16 if self.enable_bf16 else None
num_optimizers = len(self.optimizers)

if num_optimizers == 1:
optimizer = self.optimizers[0]
ipex_optimize(self.model, optimizer=optimizer, inplace=True, dtype=dtype)
ipex_optimize(self.model, optimizer=optimizer, inplace=True, dtype=self.dtype)
elif num_optimizers == 0:
ipex_optimize(self.model, inplace=True, dtype=dtype)
ipex_optimize(self.model, inplace=True, dtype=self.dtype)
else:
warnings.warn(f"IPEX currently only support single optimizers, "
f"but got {num_optimizers}. Skip IPEX")
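The same `dtype` argument is now part of the spawn strategy's constructor. A minimal construction sketch, assuming `DDPSpawnStrategy` is exported from `bigdl.nano.pytorch.strategies` like the other strategies imported in Trainer.py below:

    import torch
    from bigdl.nano.pytorch.strategies import DDPSpawnStrategy

    # Depending on the torch/ipex version, dtype=torch.bfloat16 with use_ipex=True either
    # installs the IPEXBF16Precision plugin or falls back to the legacy
    # ipex.enable_auto_mixed_precision path; both branches appear in the diff above.
    strategy = DDPSpawnStrategy(num_processes=2, use_ipex=True, dtype=torch.bfloat16)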
23 changes: 10 additions & 13 deletions python/nano/src/bigdl/nano/pytorch/strategies/ipex/ipex_strategy.py
@@ -27,7 +27,7 @@
import pytorch_lightning as pl
from pytorch_lightning.strategies import SingleDeviceStrategy
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins.precision import PrecisionPlugin, MixedPrecisionPlugin
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.utilities import AMPType

from bigdl.nano.utils.log4Error import invalidInputError
@@ -46,17 +46,17 @@ def __init__(
self,
accelerator: Accelerator = IPEXAccelerator(),
precision_plugin: PrecisionPlugin = PrecisionPlugin(),
enable_bf16=False,
dtype=None,
) -> None:
"""
Create a IPEXStrategy.
:param accelerator: the accelerator to handle hardware
:param precision_plugin: the plugin to handle precision-specific parts
"""
self.enable_bf16 = enable_bf16
self.dtype = dtype

if enable_bf16 and isinstance(precision_plugin, PrecisionPlugin):
if self.dtype == torch.bfloat16 and isinstance(precision_plugin, PrecisionPlugin):
precision_plugin = IPEXBF16Precision()
super().__init__(accelerator=accelerator, precision_plugin=precision_plugin)

@@ -70,28 +70,22 @@ def setup(self, trainer: pl.Trainer) -> None:
"""
super().setup(trainer)

dtype = torch.bfloat16 if self.enable_bf16 else None
if len(self.optimizers) == 0:
ipex.optimize(self.model, inplace=True, dtype=dtype)
ipex.optimize(self.model, inplace=True, dtype=self.dtype)
elif len(self.optimizers) == 1:
ipex.optimize(self.model, optimizer=self.optimizers[0], inplace=True, dtype=dtype)
ipex.optimize(self.model, optimizer=self.optimizers[0], inplace=True, dtype=self.dtype)
else:
invalidInputError(False, "Ipex does not support more than one optimizers.")


class IPEXBF16Precision(MixedPrecisionPlugin):
class IPEXBF16Precision(PrecisionPlugin):
"""Create Precision Plugin for IPEX BFloat16."""

backend: "AMPType" = AMPType.NATIVE
precision: Union[str, int] = 'bf16'

@contextmanager
def forward_context(self):
"""AMP for managing model forward/training_step/evaluation_step/predict_step."""
# Using IPEX bf16 and torch.autocast(...) reports a segmentation fault
# in PyTorch 1.11.
# torch.autocast("cpu", args...) is equivalent to torch.cpu.amp.autocast(args...)
# in PyTorch 1.12.
with torch.cpu.amp.autocast():
yield

@@ -122,6 +116,9 @@ def optimizer_step(self,
warning("Seems like you are using a custom optimizer,"
"please make sure that 'optimizer.step(closure)'"
" does not need to be called in training stage")

# For optimizer not in IPEX_FUSED_OPTIMIZER_LIST,
# `closure()` needs to be called to backward the loss to avoid `.grad` being None
closure_result = closure()
optimizer.step(**kwargs)

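`IPEXBF16Precision.forward_context` now relies on CPU autocast rather than the mixed-precision plugin base class. A minimal illustration of what that context does to a forward pass, assuming PyTorch >= 1.10 running on CPU:

    import torch

    model = torch.nn.Linear(8, 8)
    x = torch.randn(2, 8)

    # This is the context the plugin enters around forward/training_step:
    # eligible ops run in bfloat16 while the parameters stay in FP32.
    with torch.cpu.amp.autocast():
        y = model(x)

    print(y.dtype)  # torch.bfloat16 for autocast-eligible ops such as nn.Linear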
Original file line number Diff line number Diff line change
@@ -39,15 +39,15 @@ def __init__(
self,
accelerator: Accelerator = IPEXAccelerator(), # type: ignore
precision_plugin: PrecisionPlugin = PrecisionPlugin(),
enable_bf16=False,
dtype=None,
) -> None:
"""
Create a IPEXStrategy.
:param accelerator: the accelerator to handle hardware
:param precision_plugin: the plugin to handle precision-specific parts
"""
if enable_bf16:
if dtype == torch.bfloat16:
# Automatically mix precision
ipex.enable_auto_mixed_precision(mixed_dtype=torch.bfloat16)

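On the pre-1.10 path shown above, BF16 is still enabled globally through the legacy package rather than per-model via `ipex.optimize`. A minimal sketch, assuming the legacy `intel_pytorch_extension` package (for torch < 1.10) is installed:

    import torch
    import intel_pytorch_extension as ipex  # legacy package name used on this code path

    # Globally enables BF16 auto mixed precision for subsequently executed ops,
    # which is what the strategy does when dtype == torch.bfloat16.
    ipex.enable_auto_mixed_precision(mixed_dtype=torch.bfloat16)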
36 changes: 18 additions & 18 deletions python/nano/src/bigdl/nano/pytorch/torch_nano.py
@@ -53,6 +53,7 @@ class TorchNano(LightningLite):
def __init__(self, num_processes: int = 1,
use_ipex: bool = False,
strategy: str = "subprocess",
precision: Union[str, int] = 32,
*args, **kwargs) -> None:
"""
Create a TorchNano with nano acceleration.
@@ -64,44 +65,44 @@ def __init__(self, num_processes: int = 1,
"""
self.num_processes = num_processes
self.use_ipex = use_ipex
self.enable_bf16 = self.use_ipex and kwargs.get('precision', None) == 'bf16'

# Strategy has a higher priority than accelerator/precision/plugin,
# set precision for strategy without precision_plugin(e.g. ddp-spawn, ddp-subprocess)
# torch must be greater or equal to 1.10 to use native amp for bfloat16 precision
if TORCH_VERSION_LESS_1_10 and self.enable_bf16:
kwargs['precision'] = 32
self.dtype = None
if self.use_ipex and precision == 'bf16':
# Enable ipex bfloat16 weight prepack and disable native AMP
self.dtype = torch.bfloat16
precision = 32

# Confirm if cpu supports AVX512
if self.use_ipex and not check_avx512():
if TORCH_VERSION_LESS_1_11:
warning("Enable ipex<=1.10 in a cpu instruction set"
" without avx512 will crash."
"Fall back to regular pytorch.")
self.use_ipex = False
elif self.enable_bf16:
elif self.dtype == torch.bfloat16:
warning("Enable IPEX bfloat16 in a cpu instruction set"
" without avx512 will crash. "
"Will use PyTorch Lightning Native AMP for BFloat16 precision")
self.enable_bf16 = False
kwargs['precision'] = 32
"Using 32-bit precision")
self.dtype = None

kwargs['precision'] = precision

if self.num_processes == 1:
if self.use_ipex:
strategy = create_IPEXStrategy(enable_bf16=self.enable_bf16)
strategy = create_IPEXStrategy(dtype=self.dtype)
else:
strategy = None # type: ignore
elif strategy == "spawn":
strategy = DDPSpawnStrategy(num_processes=self.num_processes, # type: ignore
use_ipex=self.use_ipex,
enable_bf16=self.enable_bf16)
dtype=self.dtype)
elif strategy == "subprocess":
strategy = DDPSubprocessStrategy(num_processes=self.num_processes, # type: ignore
use_ipex=self.use_ipex,
enable_bf16=self.enable_bf16)
dtype=self.dtype)
elif strategy == "ray":
strategy = create_RayStrategy(num_workers=self.num_processes,
use_ipex=self.use_ipex,
enable_bf16=self.enable_bf16)
dtype=self.dtype)
else:
warning(f"Bigdl-nano doesn't support '{strategy}' strategy now, "
f"'{strategy}' strategy of pytorch_lightning will be used. "
@@ -141,11 +142,10 @@ def _setup(
# which is not supported by ddp currently,
# so add IPEX 1.11's optimization after `_setup_model`
if self.use_ipex and not TORCH_VERSION_LESS_1_10:
dtype = torch.bfloat16 if self.enable_bf16 else None
if len(optimizers) == 0:
ipex_optimize(model, inplace=True, dtype=dtype)
ipex_optimize(model, inplace=True, dtype=self.dtype)
elif len(optimizers) == 1:
ipex_optimize(model, optimizer=optimizers[0], inplace=True, dtype=dtype)
ipex_optimize(model, optimizer=optimizers[0], inplace=True, dtype=self.dtype)
else:
invalidInputError(False, "Ipex does not support more than one optimizers.")

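On the user side, bfloat16 is now requested through the new `precision` parameter instead of a Lightning-style keyword. A minimal sketch; the `bigdl.nano.pytorch` import path and the `train` hook follow bigdl-nano's usual TorchNano usage pattern and are assumptions, not part of this diff:

    from bigdl.nano.pytorch import TorchNano  # assumed import path

    class MyNano(TorchNano):
        def train(self):  # assumed user-defined entry point
            # define model/optimizer/dataloader and the training loop here,
            # wrapping them with self.setup(...) as in bigdl-nano's examples
            ...

    # precision='bf16' together with use_ipex=True now sets self.dtype = torch.bfloat16
    # and hands Lightning plain 32-bit precision, so BF16 comes from IPEX weight
    # prepack rather than native AMP.
    MyNano(num_processes=1, use_ipex=True, precision="bf16").train()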
36 changes: 19 additions & 17 deletions python/nano/src/bigdl/nano/pytorch/trainer/Trainer.py
@@ -62,6 +62,7 @@ def __init__(self, num_processes: int = 1,
use_hpo=False,
channels_last: bool = False,
auto_lr: Union[int, bool] = True,
precision: Union[str, int] = 32,
*args: Any, **kwargs: Any) -> None:
"""
A pytorch lightning trainer that uses bigdl-nano optimization.
@@ -71,6 +72,9 @@
:param cpu_for_each_process: A list of length `num_processes`, each containing a list of
indices of cpus each process will be using. default: None, and the cpu will be
automatically and evenly distributed among processes.
:param precision: Double precision (64), full precision (32), half precision (16)
or bfloat16 precision (bf16). Enables ipex bfloat16 weight prepack when `use_ipex=True`
and `precision='bf16'`.
"""
# Check keyword arguments
if "accelerator" in kwargs:
@@ -103,32 +107,30 @@ def __init__(self, num_processes: int = 1,
kwargs["callbacks"] = [ChannelsLastCallback()]

self.use_ipex = use_ipex
enable_bf16 = self.use_ipex and kwargs.get('precision', None) == 'bf16'

# Strategy has a higher priority than accelerator/precision/plugin,
# set precision for strategy without precision_plugin(e.g. ddp-spawn, ddp-subprocess)
# torch must be greater or equal to 1.10 to use native amp for bfloat16 precision
if TORCH_VERSION_LESS_1_10 and enable_bf16:
kwargs['precision'] = 32
dtype = None
if self.use_ipex and precision == 'bf16':
# Enable ipex bfloat16 weight prepack and disable pytorch-lightning native AMP
dtype = torch.bfloat16
precision = 32

# Confirm if cpu supports avx512
if self.use_ipex and not check_avx512():
if TORCH_VERSION_LESS_1_11:
warning("Enable ipex<=1.10 in a cpu instruction set"
warning("Enable ipex<=1.11 in a cpu instruction set"
" without avx512 will crash."
"Fall back to regular pytorch.")
self.use_ipex = False
elif enable_bf16:
elif dtype == torch.bfloat16:
warning("Enable IPEX bfloat16 in a cpu instruction set"
" without avx512 will crash. "
"Using 32-bit precision")
enable_bf16 = False
# IPEX-optimized model is incompatible with PL Native AMP,
# so fall back to 32-bit precision instead of staying at bfloat16 precision
kwargs['precision'] = 32
dtype = None

kwargs['precision'] = precision

if num_processes == 1:
from bigdl.nano.pytorch.strategies import create_IPEXStrategy
strategy = create_IPEXStrategy(enable_bf16=enable_bf16) if self.use_ipex else None
strategy = create_IPEXStrategy(dtype=dtype) if self.use_ipex else None
kwargs["strategy"] = strategy
super().__init__(*args, **kwargs)
else:
@@ -147,20 +149,20 @@ def __init__(self, num_processes: int = 1,
strategy = DDPSpawnStrategy(num_processes=num_processes,
cpu_for_each_process=cpu_for_each_process,
use_ipex=self.use_ipex,
enable_bf16=enable_bf16,
dtype=dtype,
auto_lr=auto_lr)
elif distributed_backend == "subprocess":
from bigdl.nano.pytorch.strategies import DDPSubprocessStrategy
strategy = DDPSubprocessStrategy(num_processes=num_processes,
cpu_for_each_process=cpu_for_each_process,
use_ipex=self.use_ipex,
enable_bf16=enable_bf16,
dtype=dtype,
auto_lr=auto_lr)
elif distributed_backend == "ray":
from bigdl.nano.pytorch.strategies import create_RayStrategy
strategy = create_RayStrategy(num_workers=num_processes,
use_ipex=self.use_ipex,
enable_bf16=enable_bf16,
dtype=dtype,
auto_lr=auto_lr)
kwargs["strategy"] = strategy
super().__init__(*args, **kwargs)
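The Trainer gains the same explicit `precision` parameter, so the BF16 request no longer travels through Lightning's own kwargs. A minimal usage sketch; the `bigdl.nano.pytorch` import path and the model are placeholders, not part of this diff:

    from bigdl.nano.pytorch import Trainer  # assumed import path

    model = ...  # your pl.LightningModule (placeholder)

    # With use_ipex=True and precision='bf16', the trainer passes dtype=torch.bfloat16
    # to the chosen nano strategy and keeps Lightning itself at 32-bit precision.
    trainer = Trainer(num_processes=2, use_ipex=True, precision="bf16", max_epochs=1)
    trainer.fit(model)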
