From ba50badd68405915b607aef0d3e23e40df6f022f Mon Sep 17 00:00:00 2001 From: coincheung <867153576@qq.com> Date: Tue, 2 Aug 2022 11:04:06 +0000 Subject: [PATCH] add ade20k and a bit small fixes --- .gitignore | 2 + README.md | 27 ++++- configs/bisenetv1_ade20k.py | 24 ++++ configs/bisenetv2_ade20k.py | 25 ++++ datasets/ade20k/annotations | 1 + datasets/ade20k/images | 1 + lib/data/ade20k.py | 40 +++++++ lib/data/get_dataloader.py | 1 + openvino/README.md | 5 +- tools/check_dataset_info.py | 7 +- tools/evaluate.py | 21 +++- tools/gen_coco_annos.py | 42 ------- tools/gen_dataset_annos.py | 88 ++++++++++++++ tools/train.py | 223 ------------------------------------ tools/train_amp.py | 7 +- 15 files changed, 232 insertions(+), 282 deletions(-) create mode 100644 configs/bisenetv1_ade20k.py create mode 100644 configs/bisenetv2_ade20k.py create mode 120000 datasets/ade20k/annotations create mode 120000 datasets/ade20k/images create mode 100644 lib/data/ade20k.py delete mode 100644 tools/gen_coco_annos.py create mode 100644 tools/gen_dataset_annos.py delete mode 100644 tools/train.py diff --git a/.gitignore b/.gitignore index 41115ee..7a1c8d1 100644 --- a/.gitignore +++ b/.gitignore @@ -111,6 +111,8 @@ adj.md tensorrt/build/* datasets/coco/train.txt datasets/coco/val.txt +datasets/ade20k/train.txt +datasets/ade20k/val.txt pretrained/* run.sh openvino/build/* diff --git a/README.md b/README.md index 4971277..827a675 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,12 @@ mIOUs on cocostuff val2017 set: | bisenetv1 | 31.49 | 31.42 | 32.46 | 32.55 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v1_coco_new.pth) | | bisenetv2 | 30.49 | 30.55 | 31.81 | 31.73 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v2_coco.pth) | +mIOUs on ade20k val set: +| none | ss | ssc | msf | mscf | link | +|------|:--:|:---:|:---:|:----:|:----:| +| bisenetv1 | 36.15 | 36.04 | 37.27 | 36.58 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v1_ade20k.pth) | +| bisenetv2 | 32.53 | 32.43 | 33.23 | 31.72 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v2_ade20k.pth) | + Tips: 1. **ss** means single scale evaluation, **ssc** means single scale crop evaluation, **msf** means multi-scale evaluation with flip augment, and **mscf** means multi-scale crop evaluation with flip evaluation. The eval scales and crop size of multi-scales evaluation can be found in [configs](./configs/). @@ -23,7 +29,9 @@ Tips: 3. The authors of bisenetv2 used cocostuff-10k, while I used cocostuff-123k(do not know how to say, just same 118k train and 5k val images as object detection). Thus the results maybe different from paper. -4. The model has a big variance, which means that the results of training for many times would vary within a relatively big margin. For example, if you train bisenetv2 for many times, you will observe that the result of **ss** evaluation of bisenetv2 varies between 73.1-75.1. +4. The authors did not do experiments on ade20k, thus there is no official training settings, here I simply provide a "make it work" result. Maybe the results on ade20k can be boosted with better settings. + +5. The model has a big variance, which means that the results of training for many times would vary within a relatively big margin. For example, if you train bisenetv2 on cityscapes for many times, you will observe that the result of **ss** evaluation of bisenetv2 varies between 73.1-75.1. ## deploy trained models @@ -85,7 +93,7 @@ $ unzip gtFine_trainvaltest.zip 2.cocostuff -Download `train2017.zip`, `val2017.zip` and `stuffthingmaps_trainval2017.zip` split from official [website](https://cocodataset.org/#download). Then do as following: +Download `train2017.zip`, `val2017.zip` and `stuffthingmaps_trainval2017.zip` split from official [website](https://cocodataset.org/#download). Then do as following: ``` $ unzip train2017.zip $ unzip val2017.zip @@ -97,10 +105,21 @@ $ mv train2017/ /path/to/BiSeNet/datasets/coco/labels $ mv val2017/ /path/to/BiSeNet/datasets/coco/labels $ cd /path/to/BiSeNet -$ python tools/gen_coco_annos.py +$ python tools/gen_dataset_annos.py --dataset coco ``` -3.custom dataset +3. ade20k + +Download `ADEChallengeData2016.zip` and unzip it. Then we can move the uncompressed folders to `datasets/ade20k`, and generate the txt files with the script I prepared for you: +``` +$ unzip ADEChallengeData2016.zip +$ mv ADEChallengeData2016/images /path/to/BiSeNet/datasets/ade20k/ +$ mv ADEChallengeData2016/annotations /path/to/BiSeNet/datasets/ade20k/ +$ python tools/gen_dataset_annos.py --ade20k +``` + + +4. custom dataset If you want to train on your own dataset, you should generate annotation files first with the format like this: ``` diff --git a/configs/bisenetv1_ade20k.py b/configs/bisenetv1_ade20k.py new file mode 100644 index 0000000..03619e7 --- /dev/null +++ b/configs/bisenetv1_ade20k.py @@ -0,0 +1,24 @@ + +cfg = dict( + model_type='bisenetv1', + n_cats=150, + num_aux_heads=2, + lr_start=4e-2, + weight_decay=1e-4, + warmup_iters=1000, + max_iter=40000, + dataset='ADE20k', + im_root='./datasets/ade20k', + train_im_anns='./datasets/ade20k/train.txt', + val_im_anns='./datasets/ade20k/val.txt', + scales=[0.5, 2.], + cropsize=[512, 512], + eval_crop=[512, 512], + eval_scales=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + eval_start_shortside=512, + ims_per_gpu=8, + eval_ims_per_gpu=1, + use_fp16=True, + use_sync_bn=True, + respth='./res', +) diff --git a/configs/bisenetv2_ade20k.py b/configs/bisenetv2_ade20k.py new file mode 100644 index 0000000..cee5812 --- /dev/null +++ b/configs/bisenetv2_ade20k.py @@ -0,0 +1,25 @@ + +## bisenetv2 +cfg = dict( + model_type='bisenetv2', + n_cats=150, + num_aux_heads=4, + lr_start=5e-3, + weight_decay=1e-4, + warmup_iters=1000, + max_iter=160000, + dataset='ADE20k', + im_root='./datasets/ade20k', + train_im_anns='./datasets/ade20k/train.txt', + val_im_anns='./datasets/ade20k/val.txt', + scales=[0.5, 2.], + cropsize=[640, 640], + eval_crop=[640, 640], + eval_start_shortside=640, + eval_scales=[0.5, 0.75, 1, 1.25, 1.5, 1.75], + ims_per_gpu=2, + eval_ims_per_gpu=1, + use_fp16=True, + use_sync_bn=True, + respth='./res', +) diff --git a/datasets/ade20k/annotations b/datasets/ade20k/annotations new file mode 120000 index 0000000..5bb1166 --- /dev/null +++ b/datasets/ade20k/annotations @@ -0,0 +1 @@ +/data/zzy/.datasets/ADEChallengeData2016/annotations/ \ No newline at end of file diff --git a/datasets/ade20k/images b/datasets/ade20k/images new file mode 120000 index 0000000..2a282b1 --- /dev/null +++ b/datasets/ade20k/images @@ -0,0 +1 @@ +/data/zzy/.datasets/ADEChallengeData2016/images/ \ No newline at end of file diff --git a/lib/data/ade20k.py b/lib/data/ade20k.py new file mode 100644 index 0000000..f7938bd --- /dev/null +++ b/lib/data/ade20k.py @@ -0,0 +1,40 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + +import os +import os.path as osp +import json + +import torch +from torch.utils.data import Dataset, DataLoader +import torch.distributed as dist +import cv2 +import numpy as np + +import lib.data.transform_cv2 as T +from lib.data.base_dataset import BaseDataset + +''' +proportion of each class label pixels: + [0.1692778570779725, 0.11564757275917185, 0.0952101638485813, 0.06663867349694136, 0.05213595836428788, 0.04856869977177328, 0.04285300460652723, 0.024667459730413076, 0.021459432596108052, 0.01951911788079975, 0.019458422169334556, 0.017972951662770457, 0.017102797922112795, 0.016127154995430226, 0.012743318904507446, 0.011871312183986243, 0.01169223174996906, 0.010873715499098895, 0.01119535711707017, 0.01106824347921356, 0.010700814956159628, 0.00792769980935508, 0.007320940186670243, 0.007101978087028939, 0.006652130884336369, 0.0065129268341813954, 0.005905601374046595, 0.005655465856321791, 0.00485152244584825, 0.004812313401121428, 0.004808430157907591, 0.004852065319115992, 0.0035166264746248105, 0.0034049293812196796, 0.0031501695661207163, 0.003200865983720736, 0.0027563053654176255, 0.0026019635559833536, 0.002535207367187799, 0.0024709898687369503, 0.002511264681160722, 0.002349575022340693, 0.0022952289072600395, 0.0021756144527500325, 0.0020667410351909894, + 0.002019785482875027, 0.001971430263652598, 0.0019830032929254865, 0.0019170129596070547, 0.0019400873699042965, 0.0019177214046286212, 0.001992758707175458, 0.0019064211898405371, 0.001794991169874655, 0.0017086228805355563, 0.001816450049952539, 0.0018115561530790863, 0.0017526224833158293, 0.0016693853602227783, 0.001690968246884664, 0.001672815290479542, 0.0016435338913693607, 0.0015994805524026869, 0.001415586825791652, 0.0015309535955159497, 0.0015066783881302896, 0.0015584265652761034, 0.0014294452504793305, 0.0014381224963739522, 0.0013854752714941247, 0.001299217899155161, 0.0012526667460881378, 0.0013178209535318454, 0.0012941402888239277, 0.0010893388225083507, 0.0011300189527483507, 0.0010488809855522653, 0.0009206912461167046, 0.0009957668988478528, 0.0009413381127111981, 0.0009365154048026355, 0.0009059601825045681, 0.0008541199189880419, 0.0008971791385063005, 0.0008428502465623139, 0.0008056902958152122, 0.0008098830962054097, 0.0007822564960661871, 0.0007982742428082544, 0.0007502832355158758, 0.0007779780392762995, 0.0007712568824233966, 0.0007453305503359334, 0.0006837047894907241, 0.0007144561259049724, 0.0006892632697976981, + 0.0006652429648347085, 0.0006708271650257716, 0.0006737982709217282, 0.0006266153732017621, 0.0006591083131957701, 0.0006729084088606035, 0.0006615025588342957, 0.0005978453864296776, 0.0005662905332794616, 0.0005832571600309656, 0.000558171776296493, 0.0005270943484946844, 0.0005918616094679417, 0.0005653340750898915, 0.0005626451989934503, 0.0005906185582842337, 0.0005217418569022469, 0.0005282586325333688, 0.0005198277923139954, 0.0004861910064034809, 0.0005218504774841597, 0.0005172358250665335, 0.0005247616468645153, 0.0005357304885031275, 0.0004276964118043196, 0.0004607179872730913, 0.00041193838996318965, 0.00042133234798497776, 0.000374820234027733, 0.00041071531761801536, 0.0003664373889492048, 0.00043033958917813777, 0.00037797413481418125, 0.0004129435322190717, 0.00037504252731164754, 0.0003633328611545351, 0.00039741354470741193, 0.0003815260048785467, 0.00037395769934345317, 0.00037914990094397704, 0.000360210650939554, 0.0003641708241638368, 0.0003354311501122861, 0.0003386525655944687, 0.0003593692433029189, 0.00034422115014162057, 0.00032131529694189243, 0.00031263024322531515, 0.0003252564098949305, 0.00034751306566322646, 0.0002711341955909471, 0.00022987904222809388, 0.000242549759411221, 0.0002045743505533957] +''' + + + +class ADE20k(BaseDataset): + + def __init__(self, dataroot, annpath, trans_func=None, mode='train'): + super(ADE20k, self).__init__( + dataroot, annpath, trans_func, mode) + self.n_cats = 150 + self.lb_ignore = 255 + self.lb_map = np.arange(200) - 1 # label range from 1 to 149, 0 is ignored + self.lb_map[0] = 255 + + self.to_tensor = T.ToTensor( + mean=(0.49343230, 0.46819794, 0.43106043), # ade20k, rgb + std=(0.25680755, 0.25506608, 0.27422913), + ) + diff --git a/lib/data/get_dataloader.py b/lib/data/get_dataloader.py index 6bd1a86..654de73 100644 --- a/lib/data/get_dataloader.py +++ b/lib/data/get_dataloader.py @@ -8,6 +8,7 @@ from lib.data.cityscapes_cv2 import CityScapes from lib.data.coco import CocoStuff +from lib.data.ade20k import ADE20k from lib.data.customer_dataset import CustomerDataset diff --git a/openvino/README.md b/openvino/README.md index 331ce8d..0f1eda8 100644 --- a/openvino/README.md +++ b/openvino/README.md @@ -4,7 +4,10 @@ Openvino is used to deploy model on intel cpus or "gpu inside cpu". -My cpu is Intel(R) Xeon(R) Gold 6240 CPU @ 2.60GHz. +My platform: +* Ubuntu 18.04 +* Intel(R) Xeon(R) Gold 6240 CPU @ 2.60GHz +* openvino_2021.4.689 ### preparation diff --git a/tools/check_dataset_info.py b/tools/check_dataset_info.py index 183ba44..37aaee5 100644 --- a/tools/check_dataset_info.py +++ b/tools/check_dataset_info.py @@ -62,15 +62,12 @@ max_lb_val = max(max_lb_val, np.max(lb)) min_lb_val = min(min_lb_val, np.min(lb)) -min_lb_val = 0 -max_lb_val = 181 -lb_minlength = 182 ## label info lb_minlength = max_lb_val+1-min_lb_val lb_hist = np.zeros(lb_minlength) for lbpth in tqdm(lbpaths): lb = cv2.imread(lbpth, 0) - lb = lb[lb != lb_ignore] + min_lb_val + lb = lb[lb != lb_ignore] - min_lb_val lb_hist += np.bincount(lb, minlength=lb_minlength) lb_missing_vals = [ind + min_lb_val @@ -113,7 +110,7 @@ print(f'we ignore label value of {args.lb_ignore} in label images') print(f'label values are within range of [{min_lb_val}, {max_lb_val}]') print(f'label values that are missing: {lb_missing_vals}') -print('ratios of each label value: ') +print('ratios of each label value(from small to big, without ignored): ') print('\t', lb_ratios) print('\n') diff --git a/tools/evaluate.py b/tools/evaluate.py index 0fe81e1..3c8e9e3 100644 --- a/tools/evaluate.py +++ b/tools/evaluate.py @@ -31,9 +31,10 @@ def get_round_size(size, divisor=32): class SizePreprocessor(object): - def __init__(self, shape=None, shortside=None): + def __init__(self, shape=None, shortside=None, longside=None): self.shape = shape self.shortside = shortside + self.longside = longside def __call__(self, imgs): new_size = None @@ -45,6 +46,13 @@ def __call__(self, imgs): if h < w: h, w = ss, int(ss / h * w) else: h, w = int(ss / w * h), ss new_size = h, w + elif not self.longside is None: # long size limit + h, w = imgs.size()[2:] + if max(h, w) > self.longside: + ls = self.longside + if h < w: h, w = int(ls / w * h), ls + else: h, w = ls, int(ls / h * w) + new_size = h, w if not new_size is None: imgs = F.interpolate(imgs, size=new_size, @@ -125,6 +133,7 @@ def __init__(self, n_classes, scales=(0.5, ), flip=False, lb_ignore=255, size_pr self.sp = size_processor self.metric_observer = Metrics(n_classes, lb_ignore) + @torch.no_grad() def __call__(self, net, dl): ## evaluate n_classes = self.n_classes @@ -305,7 +314,9 @@ def eval_model(cfg, net): size_processor = SizePreprocessor( cfg.get('eval_start_shape'), - cfg.get('eval_start_shortside')) + cfg.get('eval_start_shortside'), + cfg.get('eval_start_longside'), + ) single_scale = MscEvalV0( n_classes=cfg.n_cats, @@ -411,7 +422,9 @@ def evaluate(cfg, weight_pth): ## evaluator iou_heads, iou_content, f1_heads, f1_content = eval_model(cfg, net) + logger.info('\neval results of f1 score metric:') logger.info('\n' + tabulate(f1_content, headers=f1_heads, tablefmt='orgtbl')) + logger.info('\neval results of miou metric:') logger.info('\n' + tabulate(iou_content, headers=iou_heads, tablefmt='orgtbl')) @@ -432,11 +445,9 @@ def main(): torch.cuda.set_device(local_rank) dist.init_process_group(backend='nccl') if not osp.exists(cfg.respth): os.makedirs(cfg.respth) - setup_logger('{}-eval'.format(cfg.model_type), cfg.respth) + setup_logger(f'{cfg.model_type}-{cfg.dataset.lower()}-eval', cfg.respth) evaluate(cfg, args.weight_pth) if __name__ == "__main__": main() - # 0.70646 | 0.719953 | 0.715522 | 0.712184 - # 0.70646 | 0.719953 | 0.715522 | 0.712184 diff --git a/tools/gen_coco_annos.py b/tools/gen_coco_annos.py deleted file mode 100644 index 5974f96..0000000 --- a/tools/gen_coco_annos.py +++ /dev/null @@ -1,42 +0,0 @@ - -import os -import os.path as osp - - -def gen_coco(): - ''' - root_path: - |- images - |- train2017 - |- val2017 - |- labels - |- train2017 - |- val2017 - ''' - root_path = './datasets/coco' - save_path = './datasets/coco/' - for mode in ('train', 'val'): - im_root = osp.join(root_path, f'images/{mode}2017') - lb_root = osp.join(root_path, f'labels/{mode}2017') - - ims = os.listdir(im_root) - lbs = os.listdir(lb_root) - - print(len(ims)) - print(len(lbs)) - - im_names = [el.replace('.jpg', '') for el in ims] - lb_names = [el.replace('.png', '') for el in lbs] - common_names = list(set(im_names) & set(lb_names)) - - lines = [ - f'images/{mode}2017/{name}.jpg,labels/{mode}2017/{name}.png' - for name in common_names - ] - - with open(f'{save_path}/{mode}.txt', 'w') as fw: - fw.write('\n'.join(lines)) - - - -gen_coco() diff --git a/tools/gen_dataset_annos.py b/tools/gen_dataset_annos.py new file mode 100644 index 0000000..bcba969 --- /dev/null +++ b/tools/gen_dataset_annos.py @@ -0,0 +1,88 @@ + +import os +import os.path as osp +import argparse + + +def gen_coco(): + ''' + root_path: + |- images + |- train2017 + |- val2017 + |- labels + |- train2017 + |- val2017 + ''' + root_path = './datasets/coco' + save_path = './datasets/coco/' + for mode in ('train', 'val'): + im_root = osp.join(root_path, f'images/{mode}2017') + lb_root = osp.join(root_path, f'labels/{mode}2017') + + ims = os.listdir(im_root) + lbs = os.listdir(lb_root) + + print(len(ims)) + print(len(lbs)) + + im_names = [el.replace('.jpg', '') for el in ims] + lb_names = [el.replace('.png', '') for el in lbs] + common_names = list(set(im_names) & set(lb_names)) + + lines = [ + f'images/{mode}2017/{name}.jpg,labels/{mode}2017/{name}.png' + for name in common_names + ] + + with open(f'{save_path}/{mode}.txt', 'w') as fw: + fw.write('\n'.join(lines)) + + +def gen_ade20k(): + ''' + root_path: + |- images + |- training + |- validation + |- annotations + |- training + |- validation + ''' + root_path = './datasets/ade20k/' + save_path = './datasets/ade20k/' + folder_map = {'train': 'training', 'val': 'validation'} + for mode in ('train', 'val'): + folder = folder_map[mode] + im_root = osp.join(root_path, f'images/{folder}') + lb_root = osp.join(root_path, f'annotations/{folder}') + + ims = os.listdir(im_root) + lbs = os.listdir(lb_root) + + print(len(ims)) + print(len(lbs)) + + im_names = [el.replace('.jpg', '') for el in ims] + lb_names = [el.replace('.png', '') for el in lbs] + common_names = list(set(im_names) & set(lb_names)) + + lines = [ + f'images/{folder}/{name}.jpg,annotations/{folder}/{name}.png' + for name in common_names + ] + + with open(f'{save_path}/{mode}.txt', 'w') as fw: + fw.write('\n'.join(lines)) + + + +if __name__ == '__main__': + parse = argparse.ArgumentParser() + parse.add_argument('--dataset', dest='dataset', type=str, default='coco') + args = parse.parse_args() + + if args.dataset == 'coco': + gen_coco() + elif args.dataset == 'ade20k': + gen_ade20k() diff --git a/tools/train.py b/tools/train.py deleted file mode 100644 index 35a2103..0000000 --- a/tools/train.py +++ /dev/null @@ -1,223 +0,0 @@ -#!/usr/bin/python -# -*- encoding: utf-8 -*- - -import sys -sys.path.insert(0, '.') -import os -import os.path as osp -import random -import logging -import time -import argparse -import numpy as np -from tabulate import tabulate - -import torch -import torch.nn as nn -import torch.distributed as dist -from torch.utils.data import DataLoader - -from lib.models import model_factory -from configs import set_cfg_from_file -from lib.data import get_data_loader -from tools.evaluate import eval_model -from lib.ohem_ce_loss import OhemCELoss -from lib.lr_scheduler import WarmupPolyLrScheduler -from lib.meters import TimeMeter, AvgMeter -from lib.logger import setup_logger, print_log_msg - -# apex -has_apex = True -try: - from apex import amp, parallel -except ImportError: - has_apex = False - - -## fix all random seeds -torch.manual_seed(123) -torch.cuda.manual_seed(123) -np.random.seed(123) -random.seed(123) -torch.backends.cudnn.deterministic = True -# torch.backends.cudnn.benchmark = True -# torch.multiprocessing.set_sharing_strategy('file_system') - - - - -def parse_args(): - parse = argparse.ArgumentParser() - parse.add_argument('--local_rank', dest='local_rank', type=int, default=-1,) - parse.add_argument('--port', dest='port', type=int, default=44554,) - parse.add_argument('--config', dest='config', type=str, - default='configs/bisenetv2.py',) - parse.add_argument('--finetune-from', type=str, default=None,) - return parse.parse_args() - -args = parse_args() -cfg = set_cfg_from_file(args.config) - - -def set_model(): - net = model_factory[cfg.model_type](19) - if not args.finetune_from is None: - net.load_state_dict(torch.load(args.finetune_from, map_location='cpu')) - if cfg.use_sync_bn: net = set_syncbn(net) - net.cuda() - net.train() - criteria_pre = OhemCELoss(0.7) - criteria_aux = [OhemCELoss(0.7) for _ in range(cfg.num_aux_heads)] - return net, criteria_pre, criteria_aux - -def set_syncbn(net): - if has_apex: - net = parallel.convert_syncbn_model(net) - else: - net = nn.SyncBatchNorm.convert_sync_batchnorm(net) - return net - - -def set_optimizer(model): - if hasattr(model, 'get_params'): - wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = model.get_params() - params_list = [ - {'params': wd_params, }, - {'params': nowd_params, 'weight_decay': 0}, - {'params': lr_mul_wd_params, 'lr': cfg.lr_start * 10}, - {'params': lr_mul_nowd_params, 'weight_decay': 0, 'lr': cfg.lr_start * 10}, - ] - else: - wd_params, non_wd_params = [], [] - for name, param in model.named_parameters(): - if param.dim() == 1: - non_wd_params.append(param) - elif param.dim() == 2 or param.dim() == 4: - wd_params.append(param) - params_list = [ - {'params': wd_params, }, - {'params': non_wd_params, 'weight_decay': 0}, - ] - optim = torch.optim.SGD( - params_list, - lr=cfg.lr_start, - momentum=0.9, - weight_decay=cfg.weight_decay, - ) - return optim - - -def set_model_dist(net): - if has_apex: - net = parallel.DistributedDataParallel(net, delay_allreduce=True) - else: - local_rank = dist.get_rank() - net = nn.parallel.DistributedDataParallel( - net, - device_ids=[local_rank, ], - output_device=local_rank) - return net - - -def set_meters(): - time_meter = TimeMeter(cfg.max_iter) - loss_meter = AvgMeter('loss') - loss_pre_meter = AvgMeter('loss_prem') - loss_aux_meters = [AvgMeter('loss_aux{}'.format(i)) - for i in range(cfg.num_aux_heads)] - return time_meter, loss_meter, loss_pre_meter, loss_aux_meters - - -def train(): - logger = logging.getLogger() - is_dist = dist.is_initialized() - - ## dataset - dl = get_data_loader(cfg, mode='train', distributed=is_dist) - - ## model - net, criteria_pre, criteria_aux = set_model() - - ## optimizer - optim = set_optimizer(net) - - ## fp16 - if has_apex: - opt_level = 'O1' if cfg.use_fp16 else 'O0' - net, optim = amp.initialize(net, optim, opt_level=opt_level) - - ## ddp training - net = set_model_dist(net) - - ## meters - time_meter, loss_meter, loss_pre_meter, loss_aux_meters = set_meters() - - ## lr scheduler - lr_schdr = WarmupPolyLrScheduler(optim, power=0.9, - max_iter=cfg.max_iter, warmup_iter=cfg.warmup_iters, - warmup_ratio=0.1, warmup='exp', last_epoch=-1,) - - ## train loop - for it, (im, lb) in enumerate(dl): - im = im.cuda() - lb = lb.cuda() - - lb = torch.squeeze(lb, 1) - - optim.zero_grad() - logits, *logits_aux = net(im) - loss_pre = criteria_pre(logits, lb) - loss_aux = [crit(lgt, lb) for crit, lgt in zip(criteria_aux, logits_aux)] - loss = loss_pre + sum(loss_aux) - if has_apex: - with amp.scale_loss(loss, optim) as scaled_loss: - scaled_loss.backward() - else: - loss.backward() - optim.step() - torch.cuda.synchronize() - lr_schdr.step() - - time_meter.update() - loss_meter.update(loss.item()) - loss_pre_meter.update(loss_pre.item()) - _ = [mter.update(lss.item()) for mter, lss in zip(loss_aux_meters, loss_aux)] - - ## print training log message - if (it + 1) % 100 == 0: - lr = lr_schdr.get_lr() - lr = sum(lr) / len(lr) - print_log_msg( - it, cfg.max_iter, lr, time_meter, loss_meter, - loss_pre_meter, loss_aux_meters) - - ## dump the final model and evaluate the result - save_pth = osp.join(cfg.respth, 'model_final.pth') - logger.info('\nsave models to {}'.format(save_pth)) - state = net.module.state_dict() - if dist.get_rank() == 0: - torch.save(state, save_pth, _use_new_zipfile_serialization=False) - - logger.info('\nevaluating the final model') - torch.cuda.empty_cache() - heads, mious = eval_model(cfg, net.module) - logger.info(tabulate([mious, ], headers=heads, tablefmt='orgtbl')) - - return - - -def main(): - torch.cuda.set_device(args.local_rank) - dist.init_process_group( - backend='nccl', - init_method='tcp://127.0.0.1:{}'.format(args.port), - world_size=torch.cuda.device_count(), - rank=args.local_rank - ) - if not osp.exists(cfg.respth): os.makedirs(cfg.respth) - setup_logger('{}-train'.format(cfg.model_type), cfg.respth) - train() - - -if __name__ == "__main__": - main() diff --git a/tools/train_amp.py b/tools/train_amp.py index 95be433..c661c2a 100644 --- a/tools/train_amp.py +++ b/tools/train_amp.py @@ -185,8 +185,11 @@ def train(): logger.info('\nevaluating the final model') torch.cuda.empty_cache() - heads, mious = eval_model(cfg, net.module) - logger.info(tabulate([mious, ], headers=heads, tablefmt='orgtbl')) + iou_heads, iou_content, f1_heads, f1_content = eval_model(cfg, net.module) + logger.info('\neval results of f1 score metric:') + logger.info('\n' + tabulate(f1_content, headers=f1_heads, tablefmt='orgtbl')) + logger.info('\neval results of miou metric:') + logger.info('\n' + tabulate(iou_content, headers=iou_heads, tablefmt='orgtbl')) return