Update lite bf16 training
y199387 committed Aug 23, 2022
1 parent 070fe97 commit c5dd357
Showing 4 changed files with 43 additions and 13 deletions.
@@ -27,7 +27,8 @@
import pytorch_lightning as pl
from pytorch_lightning.strategies import SingleDeviceStrategy
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins.precision import PrecisionPlugin, NativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision import PrecisionPlugin, MixedPrecisionPlugin
from pytorch_lightning.utilities import AMPType

from bigdl.nano.utils.log4Error import invalidInputError
import intel_extension_for_pytorch as ipex
@@ -78,9 +79,12 @@ def setup(self, trainer: pl.Trainer) -> None:
invalidInputError(False, "Ipex does not support more than one optimizers.")


class IPEXBF16Precision(PrecisionPlugin):
class IPEXBF16Precision(MixedPrecisionPlugin):
"""Create Precision Plugin for IPEX BFloat16."""

backend: "AMPType" = AMPType.NATIVE
precision: Union[str, int] = 'bf16'

@contextmanager
def forward_context(self):
"""AMP for managing model forward/training_step/evaluation_step/predict_step."""
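The body of forward_context is collapsed in the diff above. As a rough sketch (an assumption based on PyTorch's CPU autocast API, not the exact code in this commit), an IPEX BF16 forward context typically looks like:

from contextlib import contextmanager

import torch

@contextmanager
def forward_context(self):
    # torch.cpu.amp.autocast defaults to bfloat16, so the wrapped
    # forward/training/evaluation/predict steps run in BF16 where supported.
    with torch.cpu.amp.autocast():
        yield
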
27 changes: 19 additions & 8 deletions python/nano/src/bigdl/nano/pytorch/torch_nano.py
@@ -52,7 +52,6 @@ class TorchNano(LightningLite):

def __init__(self, num_processes: int = 1,
use_ipex: bool = False,
enable_bf16: bool = False,
strategy: str = "subprocess",
*args, **kwargs) -> None:
"""
@@ -66,13 +65,25 @@ def __init__(self, num_processes: int = 1,
"""
self.num_processes = num_processes
self.use_ipex = use_ipex
self.enable_bf16 = enable_bf16

if TORCH_VERSION_LESS_1_11 and use_ipex and not check_avx512():
warning("Enable ipex<=1.10 in a cpu instruction set"
" without avx512 will crash."
"Fall back to regular pytorch.")
self.use_ipex = False
self.enable_bf16 = self.use_ipex and kwargs.get('precision', None) == 'bf16'

# Set 'precision' for strategy without precision_plugin,
# Strategy > accelerator/precision/plugin
# torch must be greater or equal to 1.10 to use native amp for bfloat16 precision
if TORCH_VERSION_LESS_1_10 and self.enable_bf16:
kwargs['precision'] = 32

if self.use_ipex and not check_avx512():
if TORCH_VERSION_LESS_1_11:
warning("Enable ipex<=1.10 in a cpu instruction set"
" without avx512 will crash."
"Fall back to regular pytorch.")
self.use_ipex = False
elif self.enable_bf16:
warning("Enable IPEX bfloat16 in a cpu instruction set"
" without avx512 will crash. "
"Will use PyTorch Lightning Native AMP for BFloat16 precision")
self.enable_bf16 = False

if self.num_processes == 1:
if self.use_ipex:
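Taken together, the torch_nano.py changes drop the separate enable_bf16 flag: BF16 is now requested through Lightning's standard precision argument and is combined with IPEX only when AVX-512 is available, otherwise it falls back to Lightning's native AMP. A minimal usage sketch, mirroring the new tests below rather than any documented API:

# MyNano is any TorchNano subclass that defines a train() method,
# e.g. the one in test_torch_nano_ipex.py further down.
MyNano(use_ipex=True, precision='bf16').train()
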
2 changes: 1 addition & 1 deletion python/nano/src/bigdl/nano/pytorch/trainer/Trainer.py
@@ -107,7 +107,7 @@ def __init__(self, num_processes: int = 1,

# Set 'precision' for strategy without precision_plugin,
# Strategy > accelerator/precision/plugin
# torch must be greater or equal to 1.10 to use natice amp for bfloat16 precision
# torch must be greater or equal to 1.10 to use native amp for bfloat16 precision
if TORCH_VERSION_LESS_1_10 and enable_bf16:
kwargs['precision'] = 32

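The unchanged fallback in Trainer.py (and its new counterpart in torch_nano.py) means a BF16 request is silently downgraded to FP32 on torch < 1.10, since native AMP for bfloat16 only exists from 1.10 onward. Expressed as a standalone sketch (resolve_precision is a hypothetical helper, written here for illustration only):

def resolve_precision(requested_precision, enable_bf16, torch_version_less_1_10):
    # Native AMP for bfloat16 requires torch >= 1.10; otherwise fall back to FP32.
    if torch_version_less_1_10 and enable_bf16:
        return 32
    return requested_precision
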
19 changes: 17 additions & 2 deletions python/nano/test/pytorch/tests/test_torch_nano_ipex.py
@@ -45,10 +45,13 @@ def forward(self, x):


class MyNano(TorchNano):
def train(self):
def train(self, optimizer_supported: bool = False):
model = ResNet18(10, pretrained=False, include_top=False, freeze=True)
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
if optimizer_supported:
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
else:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
train_loader = create_data_loader(data_dir, batch_size, num_workers, data_transform)

model, optimizer, train_loader = self.setup(model, optimizer, train_loader)
@@ -132,6 +135,18 @@ def test_torch_nano_spawn_correctness(self):
def test_torch_nano_subprocess_correctness(self):
MyNanoCorrectness(use_ipex=True, num_processes=2, strategy="subprocess").train(0.5)

def test_torch_nano_bf16_support_opt(self):
MyNano(use_ipex=True, precision='bf16').train(optimizer_supported=True)

def test_torch_nano_bf16_unsupport_opt(self):
MyNano(use_ipex=True, precision='bf16').train()

def test_torch_nano_bf16_spawn(self):
MyNano(use_ipex=True, precision='bf16', num_processes=2, strategy="spawn").train()

def test_torch_nano_bf16_subprocess(self):
MyNano(use_ipex=True, precision='bf16', num_processes=2, strategy="subprocess").train()


if __name__ == '__main__':
pytest.main([__file__])
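
To exercise only the new BF16 cases, one option is pytest's keyword filter (assuming pytest and the nano test dependencies are installed):

pytest.main([__file__, '-k', 'bf16'])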
