Migrate the trainer to Lightning #43

Merged
9 commits, merged on Jan 29, 2021
1 change: 1 addition & 0 deletions .github/workflows/ci_test.yml

@@ -40,6 +40,7 @@ jobs:
           pip install pytest-cov
           if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
           pip install opencv-python
+          pip install pycocotools>=2.0.2
       - name: Install PyTorch ${{ matrix.torch }} Version
         run: |
           pip install ${{ matrix.pip_address }}
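One caveat on the added line: the default shell for `run:` steps is bash, where an unquoted `>` is output redirection, so `pip install pycocotools>=2.0.2` actually executes `pip install pycocotools` with stdout written to a file named `=2.0.2`. If a version floor is intended, the specifier needs quoting:

    pip install "pycocotools>=2.0.2"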
19 changes: 19 additions & 0 deletions datasets/__init__.py

@@ -2,6 +2,8 @@
 import torch.utils.data
 import torchvision
 
+from models.transform import nested_tensor_from_tensor_list
+
 from .coco import build as build_coco
 from .voc import build as build_voc
 
@@ -16,6 +18,23 @@ def get_coco_api_from_dataset(dataset):
     return dataset.coco
 
 
+def collate_fn(batch):
+    batch = list(zip(*batch))
+    samples = nested_tensor_from_tensor_list(batch[0])
+
+    targets = []
+    for i, target in enumerate(batch[1]):
+        num_objects = len(target['labels'])
+        if num_objects > 0:
+            targets_merged = torch.full((num_objects, 6), i, dtype=torch.float32)
+            targets_merged[:, 1] = target['labels']
+            targets_merged[:, 2:] = target['boxes']
+            targets.append(targets_merged)
+    targets = torch.cat(targets, dim=0)
+
+    return samples, targets
+
+
 def build_dataset(image_set, dataset_year, args):
 
     datasets = []
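For readers of the new collate_fn: it pads and stacks the images via nested_tensor_from_tensor_list, and flattens the per-image target dicts into one (total_objects, 6) tensor whose columns are [image index within batch, class label, four box values], the flat layout YOLO-style losses consume. A standalone sketch of that layout with made-up targets (the numbers are illustrative only; the box format is whatever the dataset emits):

    import torch

    # Two images: the first has two objects, the second has one.
    targets_per_image = [
        {'labels': torch.tensor([3, 7]),
         'boxes': torch.tensor([[0.5, 0.5, 0.2, 0.3],
                                [0.1, 0.2, 0.1, 0.1]])},
        {'labels': torch.tensor([1]),
         'boxes': torch.tensor([[0.7, 0.4, 0.3, 0.2]])},
    ]

    merged = []
    for i, target in enumerate(targets_per_image):
        num_objects = len(target['labels'])
        if num_objects > 0:
            rows = torch.full((num_objects, 6), i, dtype=torch.float32)  # column 0: image index
            rows[:, 1] = target['labels']                                # column 1: class label
            rows[:, 2:] = target['boxes']                                # columns 2-5: box values
            merged.append(rows)

    print(torch.cat(merged, dim=0))
    # tensor([[0., 3., 0.5000, 0.5000, 0.2000, 0.3000],
    #         [0., 7., 0.1000, 0.2000, 0.1000, 0.1000],
    #         [1., 1., 0.7000, 0.4000, 0.3000, 0.2000]])

One edge case worth noting: if no image in a batch contains any object, the list passed to torch.cat is empty and the call raises, so the function implicitly assumes every batch carries at least one box.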
102 changes: 0 additions & 102 deletions engine.py

This file was deleted.
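The per-epoch loops that lived here (train_one_epoch, evaluate) become the responsibility of LightningModule hooks; the actual wrapper lives in models/lightning_wrapper.py and its body is not part of this diff. Purely to illustrate the mapping, a minimal sketch (the class name, constructor arguments, and the loss-returning forward are all assumptions, not YOLOLitWrapper's real code):

    import pytorch_lightning as pl
    import torch
    from torch import nn


    class DetectorLitSketch(pl.LightningModule):
        """Illustrative only: shows where engine.py's logic goes."""

        def __init__(self, model: nn.Module, lr: float = 0.01):
            super().__init__()
            self.model = model
            self.lr = lr

        def training_step(self, batch, batch_idx):
            samples, targets = batch             # the (samples, targets) pair built by collate_fn
            loss = self.model(samples, targets)  # assumed: forward returns the training loss
            self.log('train_loss', loss)         # replaces engine.py's print-based logging
            return loss                          # Lightning runs backward() and optimizer.step()

        def configure_optimizers(self):
            # Replaces the hand-rolled SGD/scheduler setup deleted from main.py below.
            return torch.optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.9)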

166 changes: 25 additions & 141 deletions main.py

@@ -9,94 +9,41 @@
 import torch
 from torch.utils.data import DataLoader, DistributedSampler
 
-import utils.misc as utils
+import pytorch_lightning as pl
 
-from datasets import build_dataset, get_coco_api_from_dataset
-from models import yolov5s
-from engine import train_one_epoch, evaluate
+from datasets import build_dataset, get_coco_api_from_dataset, collate_fn
+from models import YOLOLitWrapper
 
 
 def get_args_parser():
     parser = argparse.ArgumentParser('You only look once detector', add_help=False)
 
-    parser.add_argument('--arch', default='yolov5s',
-                        help='model architecture')
-    parser.add_argument('--return-criterion', action='store_true',
-                        help='Should be enabled in training mode')
-    parser.add_argument('--data-path', default='./data-bin',
+    parser.add_argument('--data_path', default='./data-bin',
                         help='dataset')
-    parser.add_argument('--dataset-file', default='coco',
+    parser.add_argument('--dataset_file', default='coco',
                         help='dataset')
-    parser.add_argument('--dataset-mode', default='instances',
+    parser.add_argument('--dataset_mode', default='instances',
                         help='dataset mode')
-    parser.add_argument('--dataset-year', default=['2017'], nargs='+',
+    parser.add_argument('--dataset_year', default=['2017'], nargs='+',
                         help='dataset year')
-    parser.add_argument('--train-set', default='train',
+    parser.add_argument('--train_set', default='train',
                         help='set of train')
-    parser.add_argument('--val-set', default='val',
+    parser.add_argument('--val_set', default='val',
                         help='set of val')
-    parser.add_argument('--model', default='ssd',
-                        help='model')
-    parser.add_argument("--masks", action="store_true",
-                        help="semantic segmentation")
-    parser.add_argument('--device', default='cuda',
-                        help='device')
-    parser.add_argument('--score-thresh', default=0.01, type=float,
-                        help='inference score threshold')
-    parser.add_argument('--image-size', default=300, type=int,
-                        help='input size of models')
-    parser.add_argument('--num-classes', default=80, type=int,
-                        help='number classes of datasets')
-    parser.add_argument('--batch-size', default=32, type=int,
+    parser.add_argument('--batch_size', default=32, type=int,
                         help='images per gpu, the total batch size is $NGPU x batch_size')
     parser.add_argument('--epochs', default=26, type=int, metavar='N',
                         help='number of total epochs to run')
-    parser.add_argument('--num-workers', default=4, type=int, metavar='N',
+    parser.add_argument('--num_workers', default=4, type=int, metavar='N',
                         help='number of data loading workers (default: 4)')
-    parser.add_argument('--lr', default=0.02, type=float,
-                        help='initial learning rate, 0.02 is the default value for training '
-                             'on 8 gpus and 2 images_per_gpu')
-    parser.add_argument('--lr-backbone', default=1e-5, type=float)
-    parser.add_argument('--lr-scheduler', default='cosine',
-                        help='Scheduler for SGD, It can be chosed to multi-step or cosine')
-    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
-                        help='momentum')
-    parser.add_argument('--weight-decay', default=5e-4, type=float,
-                        metavar='W', help='weight decay (default: 5e-4)')
-    parser.add_argument('--lr-step-size', default=8, type=int,
-                        help='decrease lr every step-size epochs')
-    parser.add_argument('--lr-steps', default=[16, 70], nargs='+', type=int,
-                        help='decrease lr every step-size epochs')
-    parser.add_argument('--lr-gamma', default=0.1, type=float,
-                        help='decrease lr by a factor of lr-gamma')
-    parser.add_argument('--t-max', default=200, type=int,
-                        help='T_max value for Cosine Annealing Scheduler')
-    parser.add_argument('--print-freq', default=20, type=int,
+    parser.add_argument('--print_freq', default=20, type=int,
                         help='print frequency')
-    parser.add_argument('--output-dir', default='.',
+    parser.add_argument('--output_dir', default='.',
                         help='path where to save')
-    parser.add_argument('--resume', default='',
-                        help='resume from checkpoint')
-    parser.add_argument('--start-epoch', default=0, type=int,
-                        help='start epoch')
-    parser.add_argument('--test-only', action='store_true',
-                        help='Only test the model')
-    parser.add_argument('--pretrained', action='store_true',
-                        help='Use pre-trained models from the modelzoo')
-
-    # distributed training parameters
-    parser.add_argument('--world-size', default=1, type=int,
-                        help='number of distributed processes')
-    parser.add_argument('--dist-url', default='env://',
-                        help='url used to set up distributed training')
     return parser
 
 
 def main(args):
-    utils.init_distributed_mode(args)
     print(args)
 
-    device = torch.device(args.device)
-
     # Data loading code
     print('Loading data')
@@ -105,12 +52,8 @@ def main(args):
     base_ds = get_coco_api_from_dataset(dataset_val)
 
     print('Creating data loaders')
-    if args.distributed:
-        sampler_train = DistributedSampler(dataset_train)
-        sampler_val = DistributedSampler(dataset_val, shuffle=False)
-    else:
-        sampler_train = torch.utils.data.RandomSampler(dataset_train)
-        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+    sampler_train = torch.utils.data.RandomSampler(dataset_train)
+    sampler_val = torch.utils.data.SequentialSampler(dataset_val)
 
     batch_sampler_train = torch.utils.data.BatchSampler(
         sampler_train, args.batch_size, drop_last=True,
@@ -119,88 +62,29 @@
     data_loader_train = DataLoader(
         dataset_train,
         batch_sampler=batch_sampler_train,
-        collate_fn=utils.collate_fn,
+        collate_fn=collate_fn,
         num_workers=args.num_workers,
     )
     data_loader_val = DataLoader(
         dataset_val,
         args.batch_size,
         sampler=sampler_val,
         drop_last=False,
-        collate_fn=utils.collate_fn,
+        collate_fn=collate_fn,
         num_workers=args.num_workers,
     )
 
-    print('Creating model, always set args.return_criterion be True')
-    args.return_criterion = True
-    model = yolov5s(num_classes=args.num_classes)
-    model.to(device)
-
-    model_without_ddp = model
-    if args.distributed:
-        model = torch.nn.parallel.DistributedDataParallel(
-            model,
-            device_ids=[args.gpu],
-        )
-        model_without_ddp = model.module
-
-    params = [p for p in model.parameters() if p.requires_grad]
-    optimizer = torch.optim.SGD(
-        params,
-        lr=args.lr,
-        momentum=args.momentum,
-        weight_decay=args.weight_decay,
-    )
-
-    if args.lr_scheduler == 'cosine':
-        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.t_max)
-    elif args.lr_scheduler == 'multi-step':
-        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
-            optimizer,
-            milestones=args.lr_steps,
-            gamma=args.lr_gamma,
-        )
-    else:
-        raise ValueError(f'scheduler {args.lr_scheduler} not supported')
-
-    output_dir = Path(args.output_dir)
-    if args.resume:
-        checkpoint = torch.load(args.resume, map_location='cpu')
-        model_without_ddp.load_state_dict(checkpoint['model'])
-        optimizer.load_state_dict(checkpoint['optimizer'])
-        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
-        args.start_epoch = checkpoint['epoch'] + 1
-
-    if args.test_only:
-        evaluate(model, data_loader_val, base_ds, device)
-        return
-
-    print('Start training')
-    start_time = time.time()
-    for epoch in range(args.start_epoch, args.epochs):
-        if args.distributed:
-            sampler_train.set_epoch(epoch)
-        train_one_epoch(model, optimizer, data_loader_train, device, epoch, args.print_freq)
-
-        lr_scheduler.step()
-        if args.output_dir:
-            utils.save_on_master(
-                {
-                    'model': model_without_ddp.state_dict(),
-                    'optimizer': optimizer.state_dict(),
-                    'lr_scheduler': lr_scheduler.state_dict(),
-                    'args': args,
-                    'epoch': epoch,
-                },
-                output_dir.joinpath(f'model_{epoch}.pth'),
-            )
-
-        # evaluate after every epoch
-        # evaluate(model, criterion, data_loader_val, device=device)
-
-    total_time = time.time() - start_time
-    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
-    print(f'Training time {total_time_str}')
+    # Load model
+    model = YOLOLitWrapper()
+    model.train()
+
+    # train
+    # trainer = pl.Trainer().from_argparse_args(args)
+    trainer = pl.Trainer(max_epochs=1, gpus=1)
+    trainer.fit(model, data_loader_train, data_loader_val)
 
 
 if __name__ == "__main__":
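The hard-coded `pl.Trainer(max_epochs=1, gpus=1)` plus the commented-out `from_argparse_args` line suggest the intended follow-up: letting Lightning own the trainer flags. In the Lightning 1.x API both `add_argparse_args` and `from_argparse_args` are classmethods on `Trainer`, so a sketch of how that wiring could look (not part of this diff):

    import argparse

    import pytorch_lightning as pl

    parser = argparse.ArgumentParser('You only look once detector', add_help=False)
    parser = pl.Trainer.add_argparse_args(parser)   # adds --max_epochs, --gpus, ...
    args = parser.parse_args(['--max_epochs', '1', '--gpus', '1'])

    trainer = pl.Trainer.from_argparse_args(args)   # classmethod: no Trainer() instance needed

This also plausibly motivates the wholesale rename of the script's flags from hyphens to underscores: it keeps the hand-written options consistent with the underscore-style flags Lightning generates.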
23 changes: 17 additions & 6 deletions models/__init__.py

@@ -1,13 +1,24 @@
 # Copyright (c) 2020, Zhiqiang Wang. All Rights Reserved.
 from torch import nn
 from .common import Conv
 from utils.activations import Hardswish
 
-from .yolo import (yolov5_darknet_pan_s_r31 as yolov5s,
-                   yolov5_darknet_pan_m_r31 as yolov5m,
-                   yolov5_darknet_pan_l_r31 as yolov5l,
-                   yolov5_darknet_pan_s_r40,
-                   yolov5_darknet_pan_m_r40,
-                   yolov5_darknet_pan_l_r40)
+from .lightning_wrapper import YOLOLitWrapper
+
+
+def yolov5s(**kwargs):
+    model = YOLOLitWrapper(arch="yolov5_darknet_pan_s_r31", **kwargs)
+    return model
+
+
+def yolov5m(**kwargs):
+    model = YOLOLitWrapper(arch="yolov5_darknet_pan_m_r31", **kwargs)
+    return model
+
+
+def yolov5l(**kwargs):
+    model = YOLOLitWrapper(arch="yolov5_darknet_pan_l_r31", **kwargs)
+    return model
+
+
 def yolov5_onnx(pretrained=False, progress=True, num_classes=80, **kwargs):
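With these factories, the r3.1 variants are now constructed through the Lightning wrapper instead of being re-exported from .yolo; note that the r4.0 builders (yolov5_darknet_pan_*_r40) are no longer re-exported here. A short usage sketch (assuming YOLOLitWrapper forwards **kwargs to the underlying builder):

    from models import yolov5s

    model = yolov5s()   # a LightningModule wrapping yolov5_darknet_pan_s_r31
    model.eval()        # or model.train() before handing it to pl.Trainer.fit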