-
Notifications
You must be signed in to change notification settings - Fork 352
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(//cpp/ptq/training): Training recipe for VGG16 Classifier on
CIFAR10 for ptq example Gets about 90-91% accuracy, initial LR 0.01, dropout 0.15 Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
- Loading branch information
1 parent
8580106
commit 676bf56
Showing
3 changed files
with
270 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,4 +15,7 @@ py/tmp/ | |
py/.eggs | ||
.vscode/ | ||
.DS_Store | ||
._DS_Store | ||
._DS_Store | ||
*.pth | ||
*.pyc | ||
cpp/ptq/training/vgg16/data/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,207 @@ | ||
import argparse | ||
import os | ||
import random | ||
from datetime import datetime | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torch.optim as optim | ||
import torch.utils.data as data | ||
import torchvision.transforms as transforms | ||
import torchvision.datasets as datasets | ||
|
||
from torch.utils.tensorboard import SummaryWriter | ||
|
||
from vgg16 import vgg16 | ||
|
||
PARSER = argparse.ArgumentParser(description="VGG16 example to use with TRTorch PTQ") | ||
PARSER.add_argument('--epochs', default=300, type=int, help="Number of total epochs to train") | ||
PARSER.add_argument('--batch-size', default=128, type=int, help="Batch size to use when training") | ||
PARSER.add_argument('--lr', default=0.1, type=float, help="Initial learning rate") | ||
PARSER.add_argument('--drop-ratio', default=0., type=float, help="Dropout ratio") | ||
PARSER.add_argument('--momentum', default=0.9, type=float, help="Momentum") | ||
PARSER.add_argument('--weight-decay', default=5e-4, type=float, help="Weight decay") | ||
PARSER.add_argument('--ckpt-dir', default="/tmp/vgg16_ckpts", type=str, help="Path to save checkpoints (saved every 10 epochs)") | ||
PARSER.add_argument('--start-from', default=0, type=int, help="Epoch to resume from (requires a checkpoin in the providied checkpoi") | ||
PARSER.add_argument('--seed', type=int, help='Seed value for rng') | ||
PARSER.add_argument('--tensorboard', type=str, default='/tmp/vgg16_logs', help='Location for tensorboard info') | ||
|
||
args = PARSER.parse_args() | ||
for arg in vars(args): | ||
print(' {} {}'.format(arg, getattr(args, arg))) | ||
state = {k: v for k, v in args._get_kwargs()} | ||
|
||
if args.seed is None: | ||
args.seed = random.randint(1, 10000) | ||
random.seed(args.seed) | ||
torch.manual_seed(args.seed) | ||
torch.cuda.manual_seed_all(args.seed) | ||
print("RNG seed used: ", args.seed) | ||
|
||
now = datetime.now() | ||
|
||
timestamp = datetime.timestamp(now) | ||
|
||
writer = SummaryWriter(args.tensorboard + '/test_' + str(timestamp)) | ||
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') | ||
|
||
|
||
def main(): | ||
global state | ||
global classes | ||
global writer | ||
if not os.path.isdir(args.ckpt_dir): | ||
os.makedirs(args.ckpt_dir) | ||
|
||
training_dataset = datasets.CIFAR10(root='./data', train=True, | ||
download=True, transform=transforms.Compose([ | ||
transforms.RandomCrop(32, padding=4), | ||
transforms.RandomHorizontalFlip(), | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.4914, 0.4822, 0.4465), | ||
(0.2023, 0.1994, 0.2010)), | ||
])) | ||
training_dataloader = torch.utils.data.DataLoader(training_dataset, batch_size=args.batch_size, | ||
shuffle=True, num_workers=2) | ||
|
||
testing_dataset = datasets.CIFAR10(root='./data', train=False, download=True, | ||
transform=transforms.Compose([ | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.4914, 0.4822, 0.4465), | ||
(0.2023, 0.1994, 0.2010)), | ||
])) | ||
|
||
testing_dataloader = torch.utils.data.DataLoader(testing_dataset, batch_size=args.batch_size, | ||
shuffle=False, num_workers=2) | ||
|
||
num_classes = len(classes) | ||
|
||
model = vgg16(num_classes=num_classes, init_weights=False) | ||
model = model.cuda() | ||
|
||
data = iter(training_dataloader) | ||
images, _ = data.next() | ||
|
||
writer.add_graph(model, images.cuda()) | ||
writer.close() | ||
|
||
crit = nn.CrossEntropyLoss() | ||
opt = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) | ||
|
||
if args.start_from != 0: | ||
ckpt_file = args.ckpt_dir + '/ckpt_epoch' + str(args.start_from) + '.pth' | ||
print('Loading from checkpoint {}'.format(ckpt_file)) | ||
assert(os.path.isfile(ckpt_file)) | ||
ckpt = torch.load(ckpt_file) | ||
model.load_state_dict(ckpt["model_state_dict"]) | ||
opt.load_state_dict(ckpt["opt_state_dict"]) | ||
state = ckpt["state"] | ||
|
||
if torch.cuda.device_count() > 1: | ||
model = nn.DataParallel(model) | ||
|
||
for epoch in range(args.start_from, args.epochs): | ||
adjust_lr(opt, epoch) | ||
writer.add_scalar('Learning Rate', state["lr"], epoch) | ||
writer.close() | ||
print('Epoch: [%5d / %5d] LR: %f' % (epoch + 1, args.epochs, state['lr'])) | ||
|
||
train(model, training_dataloader, crit, opt, epoch) | ||
test_loss, test_acc = test(model, testing_dataloader, crit, epoch) | ||
|
||
print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc)) | ||
|
||
if epoch % 10 == 9: | ||
save_checkpoint({ | ||
'epoch': epoch + 1, | ||
'model_state_dict': model.state_dict(), | ||
'acc': test_acc, | ||
'opt_state_dict' : opt.state_dict(), | ||
'state': state | ||
}, ckpt_dir=args.ckpt_dir) | ||
|
||
def train(model, dataloader, crit, opt, epoch): | ||
global writer | ||
model.train() | ||
running_loss = 0.0 | ||
for batch, (data, labels) in enumerate(dataloader): | ||
data, labels = data.cuda(), labels.cuda(async=True) | ||
opt.zero_grad() | ||
out = model(data) | ||
loss = crit(out, labels) | ||
loss.backward() | ||
opt.step() | ||
|
||
running_loss += loss.item() | ||
if batch % 50 == 49: | ||
writer.add_scalar('Training Loss', running_loss / 100, epoch * len(dataloader) + batch) | ||
writer.close() | ||
print("Batch: [%5d | %5d] loss: %.3f" % (batch + 1, len(dataloader), running_loss / 100)) | ||
running_loss = 0.0 | ||
|
||
def test(model, dataloader, crit, epoch): | ||
global writer | ||
global classes | ||
total = 0 | ||
correct = 0 | ||
loss = 0.0 | ||
class_probs = [] | ||
class_preds = [] | ||
model.eval() | ||
with torch.no_grad(): | ||
for data, labels in dataloader: | ||
data, labels = data.cuda(), labels.cuda(async=True) | ||
out = model(data) | ||
loss += crit(out, labels) | ||
preds = torch.max(out, 1)[1] | ||
class_probs.append([F.softmax(i, dim=0) for i in out]) | ||
class_preds.append(preds) | ||
total += labels.size(0) | ||
correct += (preds == labels).sum().item() | ||
|
||
writer.add_scalar('Testing Loss', loss / total, epoch) | ||
writer.close() | ||
|
||
writer.add_scalar('Testing Accuracy', correct / total * 100, epoch) | ||
writer.close() | ||
|
||
test_probs = torch.cat([torch.stack(batch) for batch in class_probs]) | ||
test_preds = torch.cat(class_preds) | ||
for i in range(len(classes)): | ||
add_pr_curve_tensorboard(i, test_probs, test_preds, epoch) | ||
return loss / total, correct / total | ||
|
||
|
||
def save_checkpoint(state, ckpt_dir='checkpoint'): | ||
print("Checkpoint {} saved".format(state['epoch'])) | ||
filename = "ckpt_epoch" + str(state['epoch']) + ".pth" | ||
filepath = os.path.join(ckpt_dir, filename) | ||
torch.save(state, filepath) | ||
|
||
def adjust_lr(optimizer, epoch): | ||
global state | ||
new_lr = state["lr"] * (0.5 ** (epoch // 50)) if state["lr"] > 1e-7 else state["lr"] | ||
if new_lr != state["lr"]: | ||
state["lr"] = new_lr | ||
print("Updating learning rate: {}".format(state["lr"])) | ||
for param_group in optimizer.param_groups: | ||
param_group["lr"] = state["lr"] | ||
|
||
def add_pr_curve_tensorboard(class_index, test_probs, test_preds, global_step=0): | ||
global classes | ||
''' | ||
Takes in a "class_index" from 0 to 9 and plots the corresponding | ||
precision-recall curve | ||
''' | ||
tensorboard_preds = test_preds == class_index | ||
tensorboard_probs = test_probs[:, class_index] | ||
|
||
writer.add_pr_curve(classes[class_index], | ||
tensorboard_preds, | ||
tensorboard_probs, | ||
global_step=global_step) | ||
writer.close() | ||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from functools import reduce | ||
|
||
class VGG(nn.Module): | ||
def __init__(self, layer_spec, num_classes=1000, init_weights=False): | ||
super(VGG, self).__init__() | ||
|
||
layers = [] | ||
in_channels = 3 | ||
for l in layer_spec: | ||
if l == 'pool': | ||
layers.append(nn.MaxPool2d(kernel_size=2, stride=2)) | ||
else: | ||
layers += [ | ||
nn.Conv2d(in_channels, l, kernel_size=3, padding=1), | ||
nn.BatchNorm2d(l), | ||
nn.ReLU() | ||
] | ||
in_channels = l | ||
|
||
self.features = nn.Sequential(*layers) | ||
self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) | ||
self.classifier = nn.Sequential( | ||
nn.Linear(512 * 7 * 7, 4096), | ||
nn.ReLU(), | ||
nn.Dropout(), | ||
nn.Linear(4096, 4096), | ||
nn.ReLU(), | ||
nn.Dropout(), | ||
nn.Linear(4096, num_classes) | ||
) | ||
if init_weights: | ||
self._initialize_weights() | ||
|
||
def _initialize_weights(self): | ||
for m in self.modules(): | ||
if isinstance(m, nn.Conv2d): | ||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') | ||
if m.bias is not None: | ||
nn.init.constant_(m.bias, 0) | ||
elif isinstance(m, nn.BatchNorm2d): | ||
nn.init.constant_(m.weight, 1) | ||
nn.init.constant_(m.bias, 0) | ||
elif isinstance(m, nn.Linear): | ||
nn.init.normal_(m.weight, 0, 0.01) | ||
nn.init.constant_(m.bias, 0) | ||
|
||
def forward(self, x): | ||
x = self.features(x) | ||
x = self.avgpool(x) | ||
x = torch.flatten(x,1) | ||
x = self.classifier(x) | ||
return x | ||
|
||
def vgg16(num_classes=1000, init_weights=False): | ||
vgg16_cfg = [64, 64, 'pool', 128, 128, 'pool', 256, 256, 256, 256, 'pool', 512, 512, 512, 512, 'pool', 512, 512, 512, 512, 'pool'] | ||
return VGG(vgg16_cfg, num_classes, init_weights) |