
Commit

Update
y199387 committed Aug 23, 2022
1 parent ff461c7 commit 5bcfab0
Showing 2 changed files with 12 additions and 6 deletions.
9 changes: 6 additions & 3 deletions python/nano/src/bigdl/nano/pytorch/torch_nano.py
@@ -66,8 +66,8 @@ def __init__(self, num_processes: int = 1,
         self.use_ipex = use_ipex
         self.enable_bf16 = self.use_ipex and kwargs.get('precision', None) == 'bf16'
 
-        # Set 'precision' for strategy without precision_plugin,
-        # Strategy > accelerator/precision/plugin
+        # Strategy has a higher priority than accelerator/precision/plugin,
+        # set precision for strategy without precision_plugin(e.g. ddp-spawn, ddp-subprocess)
         # torch must be greater or equal to 1.10 to use native amp for bfloat16 precision
         if TORCH_VERSION_LESS_1_10 and self.enable_bf16:
             kwargs['precision'] = 32
@@ -83,6 +83,7 @@ def __init__(self, num_processes: int = 1,
                         " without avx512 will crash. "
                         "Will use PyTorch Lightning Native AMP for BFloat16 precision")
                 self.enable_bf16 = False
+                kwargs['precision'] = 32
 
         if self.num_processes == 1:
             if self.use_ipex:
@@ -136,7 +137,9 @@ def _setup(
 
         model, optimizers = self._strategy._setup_model_and_optimizers(model, list(optimizers))
 
-        # add IPEX 1.11's optimization
+        # IPEX bfloat16 optimization will cast model parameters to `torch.bfloat16`
+        # which is not supported by ddp currently,
+        # so add IPEX 1.11's optimization after `_setup_model`
         if self.use_ipex and not TORCH_VERSION_LESS_1_10:
             dtype = torch.bfloat16 if self.enable_bf16 else None
             if len(optimizers) == 0:
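The new comment in the last hunk is the heart of this change: `ipex.optimize(..., dtype=torch.bfloat16)` rewrites the model and casts its parameters to `torch.bfloat16`, which ddp cannot currently handle, so the IPEX step has to run after the strategy's `_setup_model`. Below is a minimal standalone sketch of that cast, assuming intel_extension_for_pytorch (1.11+) is installed; the toy nn.Linear model is illustrative only and not part of the commit.

import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex  # assumed installed, IPEX >= 1.11

# Stand-in for a model the ddp strategy has already set up in fp32.
model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
print({p.dtype for p in model.parameters()})  # {torch.float32}

# dtype=torch.bfloat16 makes IPEX cast/prepack the weights of supported modules,
# which is exactly the cast the new comment says ddp cannot handle -- hence the
# "add IPEX 1.11's optimization after `_setup_model`" ordering in _setup().
model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)
print({p.dtype for p in model.parameters()})  # expected: bfloat16 params, per the comment above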
9 changes: 6 additions & 3 deletions python/nano/src/bigdl/nano/pytorch/trainer/Trainer.py
@@ -105,8 +105,8 @@ def __init__(self, num_processes: int = 1,
         self.use_ipex = use_ipex
         enable_bf16 = self.use_ipex and kwargs.get('precision', None) == 'bf16'
 
-        # Set 'precision' for strategy without precision_plugin,
-        # Strategy > accelerator/precision/plugin
+        # Strategy has a higher priority than accelerator/precision/plugin,
+        # set precision for strategy without precision_plugin(e.g. ddp-spawn, ddp-subprocess)
         # torch must be greater or equal to 1.10 to use native amp for bfloat16 precision
         if TORCH_VERSION_LESS_1_10 and enable_bf16:
             kwargs['precision'] = 32
@@ -120,8 +120,11 @@ def __init__(self, num_processes: int = 1,
             elif enable_bf16:
                 warning("Enable IPEX bfloat16 in a cpu instruction set"
                         " without avx512 will crash. "
-                        "Will use PyTorch Lightning Native AMP for BFloat16 precision")
+                        "Using 32-bit precision")
                 enable_bf16 = False
+                # IPEX-optimized model is incompatible with PL Native AMP,
+                # so fall back to 32-bit precision instead of staying at bfloat16 precision
+                kwargs['precision'] = 32
 
         if num_processes == 1:
             from bigdl.nano.pytorch.strategies import create_IPEXStrategy
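Both files now make the same decision when IPEX bfloat16 is requested on a CPU without avx512: disable bf16 and force precision back to 32 instead of handing the model to PyTorch Lightning's Native AMP, because a model already rewritten by `ipex.optimize` is incompatible with that plugin. A rough sketch of the resulting fallback, where resolve_precision and has_avx512 are hypothetical names used only for illustration:

import warnings

def resolve_precision(use_ipex, requested_precision, has_avx512):
    # Hypothetical helper mirroring the hunks above, not BigDL-Nano's actual API.
    enable_bf16 = use_ipex and requested_precision == 'bf16'
    if enable_bf16 and not has_avx512:
        # IPEX bfloat16 would crash without avx512, and an IPEX-optimized model
        # cannot fall back to PL Native AMP either, so force plain 32-bit precision
        # (the new kwargs['precision'] = 32 lines in both files).
        warnings.warn("Enable IPEX bfloat16 in a cpu instruction set without avx512 "
                      "will crash. Using 32-bit precision")
        return 32, False
    return requested_precision, enable_bf16

# The resolved value is what ends up in kwargs['precision'] for the pl.Trainer.
precision, enable_bf16 = resolve_precision(use_ipex=True, requested_precision='bf16',
                                           has_avx512=False)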
