diff --git a/examples/pytorch/image_recognition/MobileNetV2-0.35/distillation/eager/README.md b/examples/pytorch/image_recognition/MobileNetV2-0.35/distillation/eager/README.md
index 14841061fdc..d449d5f797b 100644
--- a/examples/pytorch/image_recognition/MobileNetV2-0.35/distillation/eager/README.md
+++ b/examples/pytorch/image_recognition/MobileNetV2-0.35/distillation/eager/README.md
@@ -8,4 +8,15 @@ pip install -r requirements.txt
python train_without_distillation.py --epochs 200 --lr 0.1 --layers 40 --widen-factor 2 --name WideResNet-40-2 --tensorboard
# for distillation of the teacher model WideResNet40-2 to the student model MobileNetV2-0.35
python main.py --epochs 200 --lr 0.02 --name MobileNetV2-0.35-distillation --teacher_model runs/WideResNet-40-2/model_best.pth.tar --tensorboard --seed 9
+```
+
+We also support Distributed Data Parallel training in both single-node and multi-node settings for distillation. To use Distributed Data Parallel to speed up training, the bash command needs a small adjustment.
+
+For example, the bash command will look like the following, where *`<MASTER_ADDRESS>`* is the address of the master node (not needed in the single-node case), *`<NUM_PROCESSES_PER_NODE>`* is the number of processes to launch on the current node (for a node with GPUs, usually set to the number of GPUs on that node; for a CPU-only node, 1 is recommended), *`<NUM_NODES>`* is the number of nodes to use, and *`<NODE_RANK>`* is the rank of the current node, ranging from 0 to *`<NUM_NODES>`*`-1`.
+
+Also note that to use CPU for training on each node in the multi-node setting, the argument `--no_cuda` is mandatory. In the multi-node setting, the following command needs to be launched on each node, and all the commands should be identical except for *`<NODE_RANK>`*, which should be an integer from 0 to *`<NUM_NODES>`*`-1` assigned to each node.
+
+```bash
+python -m torch.distributed.launch --master_addr=<MASTER_ADDRESS> --nproc_per_node=<NUM_PROCESSES_PER_NODE> --nnodes=<NUM_NODES> --node_rank=<NODE_RANK> \
+ main.py --epochs 200 --lr 0.02 --name MobileNetV2-0.35-distillation --teacher_model runs/WideResNet-40-2/model_best.pth.tar --tensorboard --seed 9
```
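+
+As a concrete illustration (the master address below is hypothetical), a two-node, CPU-only run would launch the command below on node 0, and the same command with `--node_rank=1` on node 1:
+
+```bash
+python -m torch.distributed.launch --master_addr=10.0.0.1 --nproc_per_node=1 --nnodes=2 --node_rank=0 \
+  main.py --epochs 200 --lr 0.02 --name MobileNetV2-0.35-distillation --teacher_model runs/WideResNet-40-2/model_best.pth.tar --tensorboard --no_cuda --seed 9
+```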
\ No newline at end of file
diff --git a/examples/pytorch/image_recognition/MobileNetV2-0.35/distillation/eager/main.py b/examples/pytorch/image_recognition/MobileNetV2-0.35/distillation/eager/main.py
index e7f4e56888b..3778162d968 100644
--- a/examples/pytorch/image_recognition/MobileNetV2-0.35/distillation/eager/main.py
+++ b/examples/pytorch/image_recognition/MobileNetV2-0.35/distillation/eager/main.py
@@ -10,6 +10,7 @@
import torchvision.datasets as datasets
import torchvision.transforms as transforms
+from accelerate import Accelerator
from wideresnet import WideResNet
# used for logging to TensorBoard
@@ -60,6 +61,7 @@
help='loss weights of distillation, should be a list of length 2, '
'and sum to 1.0, first for student targets loss weight, '
'second for teacher student loss weight.')
+parser.add_argument("--no_cuda", action='store_true', help='use cpu for training.')
parser.set_defaults(augment=True)
def set_seed(seed):
@@ -73,10 +75,13 @@ def set_seed(seed):
def main():
global args, best_prec1
args, _ = parser.parse_known_args()
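+    # Accelerator handles device placement and, when launched via torch.distributed.launch, the process-group setup; cpu=True forces CPU even if CUDA is available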
+ accelerator = Accelerator(cpu=args.no_cuda)
+
best_prec1 = 0
if args.seed is not None:
set_seed(args.seed)
- if args.tensorboard: configure("runs/%s"%(args.name))
+ with accelerator.local_main_process_first():
+ if args.tensorboard: configure("runs/%s"%(args.name))
# Data loading code
normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
@@ -111,9 +116,9 @@ def main():
student_model = mobilenet.MobileNetV2(num_classes=10, width_mult=0.35)
# get the number of model parameters
- print('Number of teacher model parameters: {}'.format(
+ accelerator.print('Number of teacher model parameters: {}'.format(
sum([p.data.nelement() for p in teacher_model.parameters()])))
- print('Number of student model parameters: {}'.format(
+ accelerator.print('Number of student model parameters: {}'.format(
sum([p.data.nelement() for p in student_model.parameters()])))
kwargs = {'num_workers': 0, 'pin_memory': True}
@@ -125,10 +130,10 @@ def main():
if args.loss_weights[1] > 0:
from tqdm import tqdm
def get_logits(teacher_model, train_dataset):
- print("***** Getting logits of teacher model *****")
- print(f" Num examples = {len(train_dataset) }")
+ accelerator.print("***** Getting logits of teacher model *****")
+ accelerator.print(f" Num examples = {len(train_dataset) }")
logits_file = os.path.join(os.path.dirname(args.teacher_model), 'teacher_logits.npy')
- if not os.path.exists(logits_file):
+ if not os.path.exists(logits_file) and accelerator.is_local_main_process:
teacher_model.eval()
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, **kwargs)
train_dataloader = tqdm(train_dataloader, desc="Evaluating")
@@ -137,8 +142,8 @@ def get_logits(teacher_model, train_dataset):
outputs = teacher_model(input)
teacher_logits += [x for x in outputs.numpy()]
np.save(logits_file, np.array(teacher_logits))
- else:
- teacher_logits = np.load(logits_file)
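+        # only the local main process computes and saves the teacher logits; everyone else waits for the file, then all processes load it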
+ accelerator.wait_for_everyone()
+ teacher_logits = np.load(logits_file)
train_dataset.targets = [{'labels':l, 'teacher_logits':tl} \
for l, tl in zip(train_dataset.targets, teacher_logits)]
return train_dataset
@@ -153,15 +158,15 @@ def get_logits(teacher_model, train_dataset):
# optionally resume from a checkpoint
if args.resume:
if os.path.isfile(args.resume):
- print("=> loading checkpoint '{}'".format(args.resume))
+ accelerator.print("=> loading checkpoint '{}'".format(args.resume))
checkpoint = torch.load(args.resume)
args.start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1']
student_model.load_state_dict(checkpoint['state_dict'])
- print("=> loaded checkpoint '{}' (epoch {})"
+ accelerator.print("=> loaded checkpoint '{}' (epoch {})"
.format(args.resume, checkpoint['epoch']))
else:
- print("=> no checkpoint found at '{}'".format(args.resume))
+ accelerator.print("=> no checkpoint found at '{}'".format(args.resume))
# define optimizer
optimizer = torch.optim.SGD(student_model.parameters(), args.lr,
@@ -169,13 +174,18 @@ def get_logits(teacher_model, train_dataset):
weight_decay=args.weight_decay)
# cosine learning rate
- scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(train_loader)*args.epochs)
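+    # after accelerator.prepare, each process iterates over roughly len(train_loader) / num_processes batches per epoch, so scale the schedule length accordingly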
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+ optimizer, len(train_loader) * args.epochs // accelerator.num_processes
+ )
+
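+    # prepare() moves the models to the right device, wraps them for DDP when running distributed, and shards the dataloaders across processes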
+ student_model, teacher_model, train_loader, val_loader, optimizer = \
+ accelerator.prepare(student_model, teacher_model, train_loader, val_loader, optimizer)
def train_func(model):
- return train(train_loader, model, scheduler, distiller, best_prec1)
+ return train(train_loader, model, scheduler, distiller, best_prec1, accelerator)
def eval_func(model):
- return validate(val_loader, model, distiller)
+ return validate(val_loader, model, distiller, accelerator)
from neural_compressor.experimental import Distillation, common
from neural_compressor.experimental.common.criterion import PyTorchKnowledgeDistillationLoss
@@ -194,11 +204,12 @@ def eval_func(model):
directory = "runs/%s/"%(args.name)
os.makedirs(directory, exist_ok=True)
+ model._model = accelerator.unwrap_model(model.model)
model.save(directory)
# change to framework model for further use
model = model.model
-def train(train_loader, model, scheduler, distiller, best_prec1):
+def train(train_loader, model, scheduler, distiller, best_prec1, accelerator):
distiller.on_train_begin()
for epoch in range(args.start_epoch, args.epochs):
"""Train for one epoch on the training set"""
@@ -222,13 +233,15 @@ def train(train_loader, model, scheduler, distiller, best_prec1):
loss = distiller.on_after_compute_loss(input, output, loss, teacher_logits)
# measure accuracy and record loss
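+            # gather outputs and targets from all processes so the metrics reflect the global batch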
+ output = accelerator.gather(output)
+ target = accelerator.gather(target)
prec1 = accuracy(output.data, target, topk=(1,))[0]
- losses.update(loss.data.item(), input.size(0))
- top1.update(prec1.item(), input.size(0))
+        losses.update(accelerator.gather(loss).mean().data.item(), input.size(0)*accelerator.num_processes)  # mean over processes keeps the logged loss on the same scale as single-process training
+ top1.update(prec1.item(), input.size(0)*accelerator.num_processes)
# compute gradient and do SGD step
distiller.optimizer.zero_grad()
- loss.backward()
+        accelerator.backward(loss)  # replaces loss.backward() so Accelerate can handle the distributed backward pass
distiller.optimizer.step()
scheduler.step()
@@ -237,7 +250,7 @@ def train(train_loader, model, scheduler, distiller, best_prec1):
end = time.time()
if i % args.print_freq == 0:
- print('Epoch: [{0}][{1}/{2}]\t'
+ accelerator.print('Epoch: [{0}][{1}/{2}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
@@ -249,19 +262,20 @@ def train(train_loader, model, scheduler, distiller, best_prec1):
# remember best prec@1 and save checkpoint
is_best = distiller.best_score > best_prec1
best_prec1 = max(distiller.best_score, best_prec1)
- save_checkpoint({
- 'epoch': distiller._epoch_runned + 1,
- 'state_dict': model.state_dict(),
- 'best_prec1': best_prec1,
- }, is_best)
- # log to TensorBoard
- if args.tensorboard:
- log_value('train_loss', losses.avg, epoch)
- log_value('train_acc', top1.avg, epoch)
- log_value('learning_rate', scheduler._last_lr[0], epoch)
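+        # only the local main process writes checkpoints and TensorBoard logs, avoiding write conflicts between processes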
+ if accelerator.is_local_main_process:
+ save_checkpoint({
+ 'epoch': distiller._epoch_runned + 1,
+ 'state_dict': model.state_dict(),
+ 'best_prec1': best_prec1,
+ }, is_best)
+ # log to TensorBoard
+ if args.tensorboard:
+ log_value('train_loss', losses.avg, epoch)
+ log_value('train_acc', top1.avg, epoch)
+ log_value('learning_rate', scheduler._last_lr[0], epoch)
-def validate(val_loader, model, distiller):
+def validate(val_loader, model, distiller, accelerator):
"""Perform validation on the validation set"""
batch_time = AverageMeter()
top1 = AverageMeter()
@@ -276,6 +290,8 @@ def validate(val_loader, model, distiller):
output = model(input)
# measure accuracy
+ output = accelerator.gather(output)
+ target = accelerator.gather(target)
prec1 = accuracy(output.data, target, topk=(1,))[0]
top1.update(prec1.item(), input.size(0))
@@ -284,15 +300,15 @@ def validate(val_loader, model, distiller):
end = time.time()
if i % args.print_freq == 0:
- print('Test: [{0}/{1}]\t'
+ accelerator.print('Test: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
i, len(val_loader), batch_time=batch_time,
top1=top1))
- print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))
+ accelerator.print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))
# log to TensorBoard
- if args.tensorboard:
+ if accelerator.is_local_main_process and args.tensorboard:
log_value('val_acc', top1.avg, distiller._epoch_runned)
return top1.avg
diff --git a/examples/pytorch/image_recognition/MobileNetV2-0.35/distillation/eager/requirements.txt b/examples/pytorch/image_recognition/MobileNetV2-0.35/distillation/eager/requirements.txt
index 8db2f310ef5..71252629880 100644
--- a/examples/pytorch/image_recognition/MobileNetV2-0.35/distillation/eager/requirements.txt
+++ b/examples/pytorch/image_recognition/MobileNetV2-0.35/distillation/eager/requirements.txt
@@ -2,3 +2,4 @@
torch==1.5.0+cpu
torchvision==0.6.0+cpu
tensorboard_logger
+accelerate
\ No newline at end of file