Commit

Update
y199387 committed Aug 23, 2022
1 parent c5dd357 commit ff461c7
Showing 2 changed files with 12 additions and 13 deletions.
22 changes: 11 additions & 11 deletions python/nano/src/bigdl/nano/pytorch/torch_nano.py
@@ -59,7 +59,6 @@ def __init__(self, num_processes: int = 1,
:param num_processes: number of processes in distributed training, defaults to 1
:param use_ipex: whether use ipex acceleration, defaults to False
:param enable_bf16: whether use bf16 acceleration, defaults to False
:param strategy: use which backend in distributed mode, defaults to "subprocess", \
now avaiable strategies are 'spawn', 'subprocess' and 'ray'
"""
@@ -70,7 +69,7 @@ def __init__(self, num_processes: int = 1,
# Set 'precision' for strategy without precision_plugin,
# Strategy > accelerator/precision/plugin
# torch must be greater or equal to 1.10 to use native amp for bfloat16 precision
if TORCH_VERSION_LESS_1_10 and enable_bf16:
if TORCH_VERSION_LESS_1_10 and self.enable_bf16:
kwargs['precision'] = 32

if self.use_ipex and not check_avx512():
@@ -79,11 +78,11 @@ def __init__(self, num_processes: int = 1,
" without avx512 will crash."
"Fall back to regular pytorch.")
self.use_ipex = False
elif enable_bf16:
elif self.enable_bf16:
warning("Enable IPEX bfloat16 in a cpu instruction set"
" without avx512 will crash. "
"Will use PyTorch Lightning Native AMP for BFloat16 precision")
enable_bf16 = False
self.enable_bf16 = False

if self.num_processes == 1:
if self.use_ipex:
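
The hunks above switch the fallback logic from the local `enable_bf16` argument to the `self.enable_bf16` attribute, so the downgrade is recorded on the instance and later checks (such as the `dtype` selection in `_setup` below) see the downgraded value rather than the original argument. As a rough standalone sketch of that fallback pattern — not the TorchNano constructor itself: the version test stands in for the repo's `TORCH_VERSION_LESS_1_10` helper, the `has_avx512` flag stands in for `check_avx512()`, and the inner version check hidden between the hunks is an assumption:

```python
# Standalone sketch of the constructor's fallback decisions (illustration only).
import torch
from packaging.version import Version


def resolve_settings(use_ipex: bool, enable_bf16: bool, has_avx512: bool) -> dict:
    torch_less_than_1_10 = Version(torch.__version__) < Version("1.10.0")

    kwargs = {}
    if torch_less_than_1_10 and enable_bf16:
        # torch must be >= 1.10 to use native AMP for bfloat16 precision
        kwargs["precision"] = 32

    if use_ipex and not has_avx512:
        if torch_less_than_1_10:
            # assumed inner check: IPEX 1.9 without AVX-512 would crash,
            # so fall back to regular pytorch
            use_ipex = False
        elif enable_bf16:
            # keep IPEX, but hand bfloat16 over to PyTorch Lightning's native AMP
            enable_bf16 = False

    return {"use_ipex": use_ipex, "enable_bf16": enable_bf16, **kwargs}
```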
@@ -129,6 +128,14 @@ def _setup(
# so we have to add optimizations in this method, which will be called in
# user defined `train()` method.

# the following codes are copied from pl's LightningLite's `setup` method,
# ipex 1.9 requires `_move_model_to_device` after `_setup_model_and_optimizers`, but
# pl's `setup` method calls `_move_model_to_device` before `_setup_model_and_optimizers`,
# so we copy the codes and swap their order.
self._validate_setup(model, optimizers)

model, optimizers = self._strategy._setup_model_and_optimizers(model, list(optimizers))

# add IPEX 1.11's optimization
if self.use_ipex and not TORCH_VERSION_LESS_1_10:
dtype = torch.bfloat16 if self.enable_bf16 else None
@@ -139,13 +146,6 @@
else:
invalidInputError(False, "Ipex does not support more than one optimizers.")

# the following codes are copied from pl's LightningLite's `setup` method,
# ipex 1.9 requires `_move_model_to_device` after `_setup_model_and_optimizers`, but
# pl's `setup` method calls `_move_model_to_device` before `_setup_model_and_optimizers`,
# so we copy the codes and swap their order.
self._validate_setup(model, optimizers)

model, optimizers = self._strategy._setup_model_and_optimizers(model, optimizers)
if move_to_device:
model = self._move_model_to_device(model=model, optimizers=optimizers)
model = _TorchNanoModule(model, self._precision_plugin)
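
The reordered `_setup` above now has the strategy wrap the model and optimizers before the IPEX optimization is applied (and passes `list(optimizers)` to `_setup_model_and_optimizers`), while the device move still happens last, as the relocated comment block requires for ipex 1.9. For reference, a minimal standalone sketch of the IPEX (>= 1.10) optimization step that the `dtype` line feeds into — this assumes `intel_extension_for_pytorch` is installed and shows the generic `ipex.optimize` call rather than bigdl-nano's wrapper; it takes at most one optimizer, which is why the diff rejects multiple optimizers:

```python
# Minimal illustration of the IPEX >= 1.10 optimization step (not bigdl-nano code).
import torch
import intel_extension_for_pytorch as ipex  # assumed installed

model = torch.nn.Linear(16, 4).train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# dtype=torch.bfloat16 enables the bf16 path; dtype=None keeps fp32,
# mirroring `dtype = torch.bfloat16 if self.enable_bf16 else None` above.
model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)
```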
3 changes: 1 addition & 2 deletions python/nano/test/pytorch/tests/test_torch_nano_ipex.py
@@ -49,7 +49,7 @@ def train(self, optimizer_supported: bool = False):
model = ResNet18(10, pretrained=False, include_top=False, freeze=True)
loss_func = nn.CrossEntropyLoss()
if optimizer_supported:
optimizer = torch.optim.SGD(model.parameters, lr=0.01)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
else:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
train_loader = create_data_loader(data_dir, batch_size, num_workers, data_transform)
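
The one-line test fix above is the usual `model.parameters` vs `model.parameters()` slip: `torch.optim.SGD` expects an iterable of parameters (or parameter groups), not the bound method object, so the old line raised a `TypeError` before the training loop ever ran. A quick standalone check:

```python
# Demonstrates why the optimizer must receive model.parameters(), not model.parameters.
import torch

model = torch.nn.Linear(4, 2)

try:
    torch.optim.SGD(model.parameters, lr=0.01)   # bound method: not an iterable of tensors
except TypeError as exc:
    print(f"rejected as expected: {exc}")

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # generator of Parameters: OK
```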
@@ -66,7 +66,6 @@ def train(self, optimizer_supported: bool = False):
loss = loss_func(model(X), y)
self.backward(loss)
optimizer.step()

total_loss += loss.sum()
num += 1
print(f'avg_loss: {total_loss / num}')
