refine training API for TensorFlow and PyTorch framework (intel#1055)
PenghuiCheng authored Jul 11, 2022
1 parent 909f22e commit f082424
Showing 46 changed files with 853 additions and 448 deletions.
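
Taken together, the refinement renames the training-loop hooks and attributes to step-oriented, framework-neutral names: `on_batch_begin`/`on_batch_end` become `on_step_begin`/`on_step_end`, `on_post_grad` becomes `on_before_optimizer_step`, `pre_epoch_begin` becomes `on_train_begin`, `on_post_forward` becomes `on_after_compute_loss`, and the `pruning_func` setter becomes `train_func`. Below is a minimal sketch, not code from this commit, of a PyTorch pruning loop written against the refined names; the toy model, data, optimizer, and the "conf.yaml" path are placeholders.

```python
# Minimal sketch (not from this commit): refined hook names, with the old names
# noted in comments. The toy model/data and "conf.yaml" are placeholders.
import torch
from torch.utils.data import DataLoader, TensorDataset
from neural_compressor.experimental import Pruning

model = torch.nn.Linear(8, 2)
train_dataloader = DataLoader(
    TensorDataset(torch.randn(32, 8), torch.randint(0, 2, (32,))), batch_size=8)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

prune = Pruning("conf.yaml")   # placeholder pruning config
prune.model = model

def pruning_func(model):
    for epoch in range(2):
        model.train()
        prune.on_epoch_begin(epoch)
        for step, (data, target) in enumerate(train_dataloader):
            prune.on_step_begin(step)          # was: on_batch_begin(step)
            loss = criterion(model(data), target)
            optimizer.zero_grad()
            loss.backward()
            prune.on_before_optimizer_step()   # was: on_post_grad()
            optimizer.step()
            prune.on_step_end()                # was: on_batch_end()
        prune.on_epoch_end()

prune.train_func = pruning_func                # was: prune.pruning_func = ...
model = prune.fit()
```
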
4 changes: 2 additions & 2 deletions docs/api-introduction.md
@@ -145,10 +145,10 @@ class Pruning(object):
def on_epoch_begin(self, epoch):
...

def on_batch_begin(self, batch_id):
def on_step_begin(self, batch_id):
...

def on_batch_end(self):
def on_step_end(self):
...

def on_epoch_end(self):
21 changes: 11 additions & 10 deletions docs/distillation.md
@@ -69,15 +69,15 @@ class Distillation():
# The criterion used in the training phase. It is optional if the criterion is configured in a user-defined yaml.
...

def pre_epoch_begin(self):
def on_train_begin(self):
# The hook point used by distillation algorithm
...

def on_epoch_end(self):
# The hook point used by distillation algorithm
...

def on_post_forward(self, batch, teacher_output=None):
def on_after_compute_loss(self, input, student_output, student_loss, teacher_output=None):
# The hook point used by distillation algorithm
...

@@ -194,16 +194,16 @@ User can pass the customized training/evaluation functions to `Distillation` for
Neural Compressor defines several hooks for users to use:

```
pre_epoch_begin() : Hook executed before training begins
on_post_forward(batch) : Hook executed after each batch inference of student model
on_train_begin() : Hook executed before training begins
on_after_compute_loss(input, student_output, student_loss) : Hook executed after each batch inference of student model
on_epoch_end() : Hook executed at each epoch end
```
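
Before the full BlendCNN snippet below, here is a compact sketch of the refined flow end to end; the toy model, data, optimizer, and "conf.yaml" path are placeholders, and the `train_func` setter on `Distillation` is assumed here by analogy with the Pruning API in this commit.

```python
# Compact sketch of the refined Distillation hooks (placeholder model/data;
# "conf.yaml" stands in for a distillation config).
import torch
from torch.utils.data import DataLoader, TensorDataset
from neural_compressor.experimental import Distillation, common

teacher_model = torch.nn.Linear(8, 2)
student_model = torch.nn.Linear(8, 2)
train_dataloader = DataLoader(
    TensorDataset(torch.randn(32, 8), torch.randint(0, 2, (32,))), batch_size=8)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(student_model.parameters(), lr=0.1)

distiller = Distillation("conf.yaml")            # placeholder config path
distiller.teacher_model = common.Model(teacher_model)
distiller.student_model = common.Model(student_model)

def train_func(model):
    distiller.on_train_begin()                   # was: pre_epoch_begin()
    for epoch in range(2):
        model.train()
        for data, target in train_dataloader:
            output = model(data)
            loss = criterion(output, target)
            # was: on_post_forward(batch); the refined hook also takes the
            # student loss and returns the combined distillation loss
            loss = distiller.on_after_compute_loss(data, output, loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        distiller.on_epoch_end()

distiller.train_func = train_func  # attribute name assumed, mirroring Pruning
model = distiller.fit()
```
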
The following section shows how to use these hooks in a user-provided training function; the snippet is part of the BlendCNN distillation example:
```python
def train_func(model):
distiller.pre_epoch_begin()
distiller.on_train_begin()
for nepoch in range(epochs):
model.train()
cnt = 0
@@ -213,11 +213,12 @@ def train_func(model):
teacher_logits, input_ids, segment_ids, input_mask, target = batch
cnt += 1
output = model(input_ids, segment_ids, input_mask)
distiller.on_post_forward({'input_ids':input_ids,
'segment_ids':segment_ids,
'input_mask':input_mask}, \
teacher_logits)
loss = distiller.criterion(output, target)
loss = criterion(output, target)
loss = distiller.on_after_compute_loss(
{'input_ids':input_ids, 'segment_ids':segment_ids, 'input_mask':input_mask},
output,
loss,
teacher_logits)
optimizer.zero_grad()
loss.backward()
optimizer.step()
14 changes: 7 additions & 7 deletions docs/pruning.md
@@ -141,7 +141,7 @@ In this case, the launcher code is like the following:
from neural_compressor.experimental import Pruning, common
prune = Pruning(args.config)
prune.model = model
prune.pruning_func = pruning_func
prune.train_func = pruning_func
model = prune.fit()
```

@@ -153,10 +153,10 @@ Neural Compressor defines several hooks for user use:

```
on_epoch_begin(epoch) : Hook executed at each epoch beginning
on_batch_begin(batch) : Hook executed at each batch beginning
on_batch_end() : Hook executed at each batch end
on_step_begin(batch) : Hook executed at each batch beginning
on_step_end() : Hook executed at each batch end
on_epoch_end() : Hook executed at each epoch end
on_post_grad() : Hook executed after gradients calculated and before backward
on_before_optimizer_step() : Hook executed after gradients are calculated and before the optimizer step
```

The following section shows how to use these hooks in a user-provided training function; the snippet is part of the BERT training example:
@@ -168,7 +168,7 @@ def pruning_func(model):
model.train()
prune.on_epoch_begin(epoch)
for step, batch in enumerate(train_dataloader):
prune.on_batch_begin(step)
prune.on_step_begin(step)
batch = tuple(t.to(args.device) for t in batch)
inputs = {'input_ids': batch[0],
'attention_mask': batch[1],
Expand All @@ -185,13 +185,13 @@ def pruning_func(model):
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)


if (step + 1) % args.gradient_accumulation_steps == 0:
prune.on_before_optimizer_step()
optimizer.step()
scheduler.step() # Update learning rate schedule
model.zero_grad()

prune.on_batch_end()
prune.on_step_end()
...
```

10 changes: 5 additions & 5 deletions docs/pruning_api.md
@@ -22,8 +22,8 @@ class Pruning():
# This attribute needs to be set before invoking self.__call__().
...
@pruning_func.setter
def pruning_func(self, user_pruning_func)
@train_func.setter
def train_func(self, user_pruning_func)
# The training function provided by the user. It takes the framework runtime model object as its input parameter
# and executes the entire training process with self-contained training hyper-parameters.
# It is optional if training can be configured with the neural_compressor built-in dataloader/optimizer/criterion.
@@ -53,15 +53,15 @@ class Pruning():
# The hook point used by pruning algorithm
...
def on_batch_begin(self, batch):
def on_step_begin(self, batch):
# The hook point used by pruning algorithm
...
def on_batch_end(self):
def on_step_end(self):
# The hook point used by pruning algorithm
...
def on_post_grad(self):
def on_before_optimizer_step(self):
# The hook point used by pruning algorithm
...
24 changes: 13 additions & 11 deletions examples/pytorch/image_recognition/CNN-2/distillation/eager/main.py
@@ -62,6 +62,7 @@
'second for teacher student loss weight.')
parser.set_defaults(augment=True)


def set_seed(seed):
import random
import numpy as np
@@ -70,22 +71,23 @@ def set_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


def main():
global args, best_prec1
args, _ = parser.parse_known_args()
best_prec1 = 0
if args.seed is not None:
set_seed(args.seed)
if args.tensorboard: configure("runs/%s"%(args.name))
if args.tensorboard: configure("runs/%s" % (args.name))

# Data loading code
normalize = transforms.Normalize(mean=[0.5071, 0.4866, 0.4409], std=[0.2675, 0.2565, 0.2761])

if args.augment:
transform_train = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
(4,4,4,4),mode='reflect').squeeze()),
transforms.ToTensor(),
transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
(4, 4, 4, 4),mode='reflect').squeeze()),
transforms.ToPILImage(),
transforms.RandomCrop(32),
transforms.RandomHorizontalFlip(),
@@ -110,7 +112,7 @@ def main():
else:
raise NotImplementedError('Unsupported teacher model type')
teacher_model.load_state_dict(torch.load(args.teacher_model)['state_dict'])

if args.student_type == 'CNN-2':
student_model = ConvNetMaker(plane_cifar100_book['2'])
elif args.student_type == 'VGG-8':
@@ -187,7 +189,7 @@ def eval_func(model):

from neural_compressor.experimental import Distillation, common
from neural_compressor.experimental.common.criterion import PyTorchKnowledgeDistillationLoss

distiller = Distillation(args.config)
distiller.teacher_model = common.Model(teacher_model)
distiller.student_model = common.Model(student_model)
@@ -199,15 +201,15 @@ def eval_func(model):
loss_types=args.loss_types,
loss_weights=args.loss_weights)
model = distiller.fit()

directory = "runs/%s/"%(args.name)
os.makedirs(directory, exist_ok=True)
model.save(directory)
# change to framework model for further use
model = model.model

def train(train_loader, model, scheduler, distiller, best_prec1):
distiller.pre_epoch_begin()
distiller.on_train_begin()
for epoch in range(args.start_epoch, args.epochs):
"""Train for one epoch on the training set"""
batch_time = AverageMeter()
@@ -226,9 +228,9 @@ def train(train_loader, model, scheduler, distiller, best_prec1):

# compute output
output = model(input)

distiller.on_post_forward(input, teacher_logits)

loss = distiller.criterion(output, target)
loss = distiller.on_after_compute_loss(input, output, loss, teacher_logits)

# measure accuracy and record loss
prec1 = accuracy(output.data, target, topk=(1,))[0]
@@ -348,4 +350,4 @@ def accuracy(output, target, topk=(1,)):
return res

if __name__ == '__main__':
main()
main()
@@ -199,7 +199,7 @@ def eval_func(model):
model = model.model

def train(train_loader, model, scheduler, distiller, best_prec1):
distiller.pre_epoch_begin()
distiller.on_train_begin()
for epoch in range(args.start_epoch, args.epochs):
"""Train for one epoch on the training set"""
batch_time = AverageMeter()
@@ -218,9 +218,8 @@ def train(train_loader, model, scheduler, distiller, best_prec1):

# compute output
output = model(input)

distiller.on_post_forward(input, teacher_logits)
loss = distiller.criterion(output, target)
loss = distiller.on_after_compute_loss(input, output, loss, teacher_logits)

# measure accuracy and record loss
prec1 = accuracy(output.data, target, topk=(1,))[0]
@@ -340,4 +339,4 @@ def accuracy(output, target, topk=(1,)):
return res

if __name__ == '__main__':
main()
main()
@@ -207,7 +207,7 @@ def eval_func(model):
model = model.model

def train(train_loader, model, scheduler, distiller, best_prec1):
distiller.pre_epoch_begin()
distiller.on_train_begin()
for epoch in range(args.start_epoch, args.epochs):
"""Train for one epoch on the training set"""
batch_time = AverageMeter()
@@ -226,9 +226,8 @@ def train(train_loader, model, scheduler, distiller, best_prec1):

# compute output
output = model(input)

distiller.on_post_forward(input, teacher_logits)
loss = distiller.criterion(output, target)
loss = distiller.on_after_compute_loss(input, output, loss, teacher_logits)

# measure accuracy and record loss
prec1 = accuracy(output.data, target, topk=(1,))[0]
@@ -268,7 +267,7 @@ def train(train_loader, model, scheduler, distiller, best_prec1):
log_value('train_loss', losses.avg, epoch)
log_value('train_acc', top1.avg, epoch)
log_value('learning_rate', scheduler._last_lr[0], epoch)


def validate(val_loader, model, distiller):
"""Perform validation on the validation set"""
@@ -348,4 +347,4 @@ def accuracy(output, target, topk=(1,)):
return res

if __name__ == '__main__':
main()
main()