
Commit

Nano: Add nano pytorch example (#5570)
* rollback requirement-doc

* Add example

* Update

* Update

* Update

* Update

* Add comments and remove testing loop

* Add comments about linear scale and warmup

* remove .keep file
y199387 authored Sep 7, 2022
1 parent 41eb2a0 commit 543d0a4
Showing 4 changed files with 356 additions and 0 deletions.
Empty file.
114 changes: 114 additions & 0 deletions python/nano/tutorial/training/pytorch/pytorch_cv_data_pipeline.py
@@ -0,0 +1,114 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchmetrics import Accuracy

from bigdl.nano.pytorch.torch_nano import TorchNano


class MyPytorchModule(nn.Module):
def __init__(self):
super().__init__()
self.model = resnet18(pretrained=True)
num_ftrs = self.model.fc.in_features
        # replace the final fully-connected layer; the Oxford-IIIT Pet dataset has 37 classes
self.model.fc = nn.Linear(num_ftrs, 37)

def forward(self, x):
return self.model(x)


def create_dataloaders():
    # CV Data Pipelines
    #
    # Computer vision tasks often need a data processing pipeline that can form a
    # non-trivial part of the whole training pipeline.
    # BigDL-Nano can accelerate such pipelines by providing drop-in replacements
    # for torchvision's datasets and transforms.
    #
from bigdl.nano.pytorch.vision import transforms
from bigdl.nano.pytorch.vision.datasets import OxfordIIITPet
train_transform = transforms.Compose([transforms.Resize(256),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=.5, hue=.3),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
val_transform = transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])

    # Apply data augmentation to the train_dataset
train_dataset = OxfordIIITPet(root="/tmp/data", transform=train_transform, download=True)
val_dataset = OxfordIIITPet(root="/tmp/data", transform=val_transform)

    # randomly split the indices into training and validation subsets
indices = torch.randperm(len(train_dataset))
val_size = len(train_dataset) // 4
train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size])
val_dataset = torch.utils.data.Subset(val_dataset, indices[-val_size:])

# prepare data loaders
train_dataloader = DataLoader(train_dataset, batch_size=32)
val_dataloader = DataLoader(val_dataset, batch_size=32)

return train_dataloader, val_dataloader


# subclass TorchNano and override its train() method
class MyNano(TorchNano):
    # move the body of your existing training loop into the TorchNano train() method
def train(self):
model = MyPytorchModule()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
        loss_func = torch.nn.CrossEntropyLoss()

train_loader, val_loader = create_dataloaders()

        # call `setup` to prepare the model, optimizer(s) and dataloader(s) for accelerated training
model, optimizer, (train_loader, val_loader) = self.setup(model, optimizer,
train_loader, val_loader)
num_epochs = 5

# EPOCH LOOP
for epoch in range(num_epochs):

# TRAINING LOOP
model.train()
train_loss, num = 0, 0
for data, target in train_loader:
optimizer.zero_grad()
output = model(data)
                loss = loss_func(output, target)
# replace the loss.backward() with self.backward(loss)
self.backward(loss)
optimizer.step()

                # use .item() so the running loss does not keep the autograd graph
                train_loss += loss.item()
                num += 1
            print(f'Train Epoch: {epoch}, avg_loss: {train_loss / num}')


if __name__ == '__main__':
MyNano().train()
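
For comparison, here is a minimal sketch, not part of this commit, of what the plain-PyTorch training function mentioned in the comments above might look like before being moved into a TorchNano subclass. It assumes the MyPytorchModule and create_dataloaders definitions from the file above; the only Nano-specific changes in the tutorial are the self.setup() call and replacing loss.backward() with self.backward(loss).

# Hypothetical baseline, assuming the MyPytorchModule and create_dataloaders
# definitions from the tutorial file above are available in scope.
def plain_train(num_epochs=5):
    model = MyPytorchModule()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01,
                                momentum=0.9, weight_decay=5e-4)
    loss_func = torch.nn.CrossEntropyLoss()
    train_loader, _ = create_dataloaders()

    for epoch in range(num_epochs):
        model.train()
        train_loss, num = 0, 0
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_func(output, target)
            loss.backward()  # TorchNano replaces this call with self.backward(loss)
            optimizer.step()
            train_loss += loss.item()
            num += 1
        print(f'Train Epoch: {epoch}, avg_loss: {train_loss / num}')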
113 changes: 113 additions & 0 deletions python/nano/tutorial/training/pytorch/pytorch_train_ipex.py
@@ -0,0 +1,113 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import resnet18
from torchvision.datasets import OxfordIIITPet
from torchmetrics import Accuracy

from bigdl.nano.pytorch.torch_nano import TorchNano


class MyPytorchModule(nn.Module):
def __init__(self):
super().__init__()
self.model = resnet18(pretrained=True)
num_ftrs = self.model.fc.in_features
        # replace the final fully-connected layer; the Oxford-IIIT Pet dataset has 37 classes
self.model.fc = nn.Linear(num_ftrs, 37)

def forward(self, x):
return self.model(x)


def create_dataloaders():
train_transform = transforms.Compose([transforms.Resize(256),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=.5, hue=.3),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
val_transform = transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])

    # Apply data augmentation to the train_dataset
train_dataset = OxfordIIITPet(root="/tmp/data", transform=train_transform, download=True)
val_dataset = OxfordIIITPet(root="/tmp/data", transform=val_transform)

    # randomly split the indices into training and validation subsets
indices = torch.randperm(len(train_dataset))
val_size = len(train_dataset) // 4
train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size])
val_dataset = torch.utils.data.Subset(val_dataset, indices[-val_size:])

# prepare data loaders
train_dataloader = DataLoader(train_dataset, batch_size=32)
val_dataloader = DataLoader(val_dataset, batch_size=32)

return train_dataloader, val_dataloader


# subclass TorchNano and override its train() method
class MyNano(TorchNano):
    # move the body of your existing training loop into the TorchNano train() method
def train(self):
model = MyPytorchModule()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
        loss_func = torch.nn.CrossEntropyLoss()
train_loader, val_loader = create_dataloaders()

        # call `setup` to prepare the model, optimizer(s) and dataloader(s) for accelerated training
model, optimizer, (train_loader, val_loader) = self.setup(model, optimizer,
train_loader, val_loader)
num_epochs = 5

# EPOCH LOOP
for epoch in range(num_epochs):

# TRAINING LOOP
model.train()
train_loss, num = 0, 0
for data, target in train_loader:
optimizer.zero_grad()
output = model(data)
                loss = loss_func(output, target)
# replace the loss.backward() with self.backward(loss)
self.backward(loss)
optimizer.step()

                # use .item() so the running loss does not keep the autograd graph
                train_loss += loss.item()
                num += 1
            print(f'Train Epoch: {epoch}, avg_loss: {train_loss / num}')


if __name__ == '__main__':
    # IPEX Accelerated Training
    #
    # Intel Extension for PyTorch (a.k.a. IPEX) encapsulates
    # several optimizations for PyTorch and offers an extra
    # performance boost on Intel hardware.
    #
    # In BigDL-Nano, you can easily use IPEX to accelerate custom PyTorch
    # training loops through TorchNano by setting use_ipex=True.
    # (A rough sketch of calling IPEX directly, for comparison, follows this file.)
    #
MyNano(use_ipex=True).train()
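
As a rough illustration of the idea behind use_ipex=True (a hedged sketch, not BigDL-Nano's actual implementation), IPEX can also be applied directly to a model and optimizer through intel_extension_for_pytorch's optimize() call, assuming the intel_extension_for_pytorch package is installed and reusing the MyPytorchModule definition from the file above:

# Hypothetical sketch: applying IPEX directly, outside of TorchNano.
# Assumes `intel_extension_for_pytorch` is installed; not taken from this commit.
import torch
import intel_extension_for_pytorch as ipex

model = MyPytorchModule()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

# ipex.optimize returns an optimized (model, optimizer) pair when an optimizer is passed
model, optimizer = ipex.optimize(model, optimizer=optimizer)
# ... the rest of the training loop stays the same as in the tutorial above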
129 changes: 129 additions & 0 deletions python/nano/tutorial/training/pytorch/pytorch_train_multi_instance.py
@@ -0,0 +1,129 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision import transforms
from torchvision.models import resnet18
from torchvision.datasets import OxfordIIITPet
from torchmetrics import Accuracy

from pytorch_lightning import seed_everything
from bigdl.nano.pytorch.torch_nano import TorchNano


class MyPytorchModule(nn.Module):
def __init__(self):
super().__init__()
self.model = resnet18(pretrained=True)
num_ftrs = self.model.fc.in_features
        # replace the final fully-connected layer; the Oxford-IIIT Pet dataset has 37 classes
self.model.fc = nn.Linear(num_ftrs, 37)

def forward(self, x):
return self.model(x)


def create_dataloaders():
train_transform = transforms.Compose([transforms.Resize(256),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=.5, hue=.3),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
val_transform = transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])

    # Apply data augmentation to the train_dataset
train_dataset = OxfordIIITPet(root="/tmp/data", transform=train_transform, download=True)
val_dataset = OxfordIIITPet(root="/tmp/data", transform=val_transform)

    # randomly split the indices into training and validation subsets
indices = torch.randperm(len(train_dataset))
val_size = len(train_dataset) // 4
train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size])
val_dataset = torch.utils.data.Subset(val_dataset, indices[-val_size:])

# prepare data loaders
train_dataloader = DataLoader(train_dataset, batch_size=32)
val_dataloader = DataLoader(val_dataset, batch_size=32)

return train_dataloader, val_dataloader


# subclass TorchNano and override its train() method
class MyNano(TorchNano):
    # move the body of your existing training loop into the TorchNano train() method
def train(self):
seed_everything(42)
model = MyPytorchModule()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
        loss_func = torch.nn.CrossEntropyLoss()
train_loader, val_loader = create_dataloaders()

        # call `setup` to prepare the model, optimizer(s) and dataloader(s) for accelerated training
model, optimizer, (train_loader, val_loader) = self.setup(model, optimizer,
train_loader, val_loader)
num_epochs = 5

# EPOCH LOOP
for epoch in range(num_epochs):

# TRAINING LOOP
model.train()
train_loss, num = 0, 0
for data, target in train_loader:
optimizer.zero_grad()
output = model(data)
                loss = loss_func(output, target)
# replace the loss.backward() with self.backward(loss)
self.backward(loss)
optimizer.step()

                # use .item() so the running loss does not keep the autograd graph
                train_loss += loss.item()
                num += 1
            print(f'Train Epoch: {epoch}, avg_loss: {train_loss / num}')


if __name__ == '__main__':
    # Multi-instance Training
    #
    # It is often beneficial to use multiple instances for training
    # if a server contains multiple sockets or many cores, so that
    # the workload can make full use of all CPU cores.
    #
    # With data-parallel training, the effective batch size becomes
    # n times larger, where n is the number of parallel processes.
    # The learning rate should therefore be scaled to n times as well
    # to achieve the same effect as single-instance training.
    # However, scaling the learning rate linearly may lead to poor
    # convergence at the beginning of training, so the learning rate
    # should instead be increased to n times gradually; this is called
    # 'learning rate warmup'.
    #
    # Fortunately, BigDL-Nano makes it very easy to conduct multi-instance
    # training correctly: it handles all of this for you.
    #
    # Simply set num_processes in TorchNano to enable multi-instance training;
    # learning rate scaling and warmup are then applied automatically.
    # (A plain-PyTorch sketch of linear scaling with warmup follows this file.)
    #
MyNano(num_processes=2).train()
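
To make the linear scaling and warmup described above concrete, here is a hedged, plain-PyTorch sketch, not BigDL-Nano's implementation, of scaling a base learning rate by the number of processes and warming it up over the first few epochs; base_lr, warmup_epochs and num_processes are example values:

# Hypothetical illustration of linear learning-rate scaling with warmup.
# All names and values (base_lr, warmup_epochs, num_processes) are examples,
# not taken from BigDL-Nano.
import torch

def warmup_scaled_lr(base_lr, num_processes, epoch, warmup_epochs=3):
    """Return the learning rate for `epoch` when training with `num_processes`.

    The target rate is base_lr * num_processes (linear scaling); during the
    first `warmup_epochs` epochs it grows linearly from base_lr to the target.
    """
    target_lr = base_lr * num_processes
    if epoch >= warmup_epochs:
        return target_lr
    return base_lr + (target_lr - base_lr) * epoch / warmup_epochs

# usage: update the optimizer's learning rate at the start of every epoch
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for epoch in range(5):
    lr = warmup_scaled_lr(base_lr=0.01, num_processes=2, epoch=epoch)
    for group in optimizer.param_groups:
        group['lr'] = lr
    # ... run the usual training loop for this epoch ...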
