Commit
Add nano pytorch training bf16 example (#5600)
* rollback requirement-doc
* Add Nano PyTorch BF16 Example
* Add comments and remove testing loop
* Update
Showing 1 changed file with 119 additions and 0 deletions.
119 changes: 119 additions & 0 deletions
python/nano/tutorial/training/pytorch/pytorch_train_bf16.py
@@ -0,0 +1,119 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import resnet18
from torchvision.datasets import OxfordIIITPet
from torchmetrics import Accuracy

from bigdl.nano.pytorch.torch_nano import TorchNano


class MyPytorchModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = resnet18(pretrained=True)
        num_ftrs = self.model.fc.in_features
        # Here the size of each output sample is set to 37.
        self.model.fc = nn.Linear(num_ftrs, 37)

    def forward(self, x):
        return self.model(x)


def create_dataloaders():
    train_transform = transforms.Compose([transforms.Resize(256),
                                          transforms.RandomCrop(224),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.ColorJitter(brightness=.5, hue=.3),
                                          transforms.ToTensor(),
                                          transforms.Normalize([0.485, 0.456, 0.406],
                                                               [0.229, 0.224, 0.225])])
    val_transform = transforms.Compose([transforms.Resize(256),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize([0.485, 0.456, 0.406],
                                                             [0.229, 0.224, 0.225])])

    # Apply data augmentation to the train_dataset
    train_dataset = OxfordIIITPet(root="/tmp/data", transform=train_transform, download=True)
    val_dataset = OxfordIIITPet(root="/tmp/data", transform=val_transform)

    # obtain training indices that will be used for validation
    indices = torch.randperm(len(train_dataset))
    val_size = len(train_dataset) // 4
    train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size])
    val_dataset = torch.utils.data.Subset(val_dataset, indices[-val_size:])

    # prepare data loaders
    train_dataloader = DataLoader(train_dataset, batch_size=32)
    val_dataloader = DataLoader(val_dataset, batch_size=32)

    return train_dataloader, val_dataloader


# subclass TorchNano and override its train() method
class MyNano(TorchNano):
    # move the body of your existing train function into the TorchNano train() method
    def train(self):
        model = MyPytorchModule()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
        loss_func = torch.nn.CrossEntropyLoss()
        train_loader, val_loader = create_dataloaders()

        # call `setup` to prepare the model, optimizer(s) and dataloader(s) for accelerated training
        model, optimizer, (train_loader, val_loader) = self.setup(model, optimizer,
                                                                  train_loader, val_loader)
        num_epochs = 5

        # EPOCH LOOP
        for epoch in range(num_epochs):

            # TRAINING LOOP
            model.train()
            train_loss, num = 0, 0
            for data, target in train_loader:
                optimizer.zero_grad()
                output = model(data)
                loss = loss_func(output, target)
                # replace loss.backward() with self.backward(loss)
                self.backward(loss)
                optimizer.step()

                train_loss += loss.sum()
                num += 1
            print(f'Train Epoch: {epoch}, avg_loss: {train_loss / num}')


if __name__ == '__main__':
    # BF16 Training
    #
    # BFloat16 is a custom 16-bit floating point format for machine learning,
    # comprised of one sign bit, eight exponent bits, and seven mantissa bits.
    # BFloat16 has a greater "dynamic range" than FP16, which gives it better
    # numerical stability than FP16 while still delivering increased performance
    # and reduced memory usage.
    #
    # In BigDL-Nano, you can easily use BFloat16 mixed precision to accelerate
    # PyTorch training through TorchNano by setting precision='bf16'.
    #
    # Note: Using BFloat16 precision with torch < 1.12 may result in extremely slow training.
    MyNano(precision='bf16').train()
    # You can also set use_ipex=True together with precision='bf16' to enable ipex optimizer
    # fusion for bf16 and gain more acceleration from the BFloat16 data type.
    MyNano(use_ipex=True, precision='bf16').train()
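
Aside for readers new to the format: the "dynamic range" claim in the comments above can be checked directly with torch.finfo. The snippet below is purely illustrative and not part of the committed file.

import torch

# bf16 keeps FP32's eight exponent bits, so its representable range is close to float32's,
# while fp16 tops out around 65504; the price is a coarser eps from having only 7 mantissa bits.
print(torch.finfo(torch.bfloat16).max)   # ~3.39e+38
print(torch.finfo(torch.float16).max)    # 65504.0
print(torch.finfo(torch.bfloat16).eps)   # 0.0078125
print(torch.finfo(torch.float16).eps)    # ~0.00098

That trade-off (wide range, reduced precision) is why bf16 mixed precision tends to be numerically more forgiving than FP16 for training, as the example's comments note.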