From a85242f4251222731708a9a014f348e0210899b9 Mon Sep 17 00:00:00 2001 From: Jiarui XU Date: Wed, 27 May 2020 15:01:40 +0800 Subject: [PATCH 01/13] add lvis dataset --- configs/_base_/datasets/lvis_instance.py | 19 + configs/lvis/mask_rcnn_r50_fpn_2x_lvis.py | 8 + .../lvis/mask_rcnn_r50_fpn_mstrain_2x_lvis.py | 19 + .../mask_rcnn_r50_fpn_sample1e-3_2x_lvis.py | 25 + ...rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis.py | 30 ++ docs/install.md | 2 +- mmdet/datasets/__init__.py | 9 +- mmdet/datasets/coco.py | 22 +- mmdet/datasets/lvis.py | 429 ++++++++++++++++++ 9 files changed, 547 insertions(+), 16 deletions(-) create mode 100644 configs/_base_/datasets/lvis_instance.py create mode 100644 configs/lvis/mask_rcnn_r50_fpn_2x_lvis.py create mode 100644 configs/lvis/mask_rcnn_r50_fpn_mstrain_2x_lvis.py create mode 100644 configs/lvis/mask_rcnn_r50_fpn_sample1e-3_2x_lvis.py create mode 100644 configs/lvis/mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis.py create mode 100644 mmdet/datasets/lvis.py diff --git a/configs/_base_/datasets/lvis_instance.py b/configs/_base_/datasets/lvis_instance.py new file mode 100644 index 00000000000..e0f672f137e --- /dev/null +++ b/configs/_base_/datasets/lvis_instance.py @@ -0,0 +1,19 @@ +_base_ = 'coco_instance.py' +dataset_type = 'LvisDataset' +data_root = 'data/lvis/' +data = dict( + samples_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + ann_file=data_root + 'annotations/lvis_v0.5_train.json', + img_prefix=data_root + 'train2017/'), + val=dict( + type=dataset_type, + ann_file=data_root + 'annotations/lvis_v0.5_val.json', + img_prefix=data_root + 'val2017/'), + test=dict( + type=dataset_type, + ann_file=data_root + 'annotations/lvis_v0.5_val.json', + img_prefix=data_root + 'val2017/')) +evaluation = dict(metric=['bbox', 'segm']) diff --git a/configs/lvis/mask_rcnn_r50_fpn_2x_lvis.py b/configs/lvis/mask_rcnn_r50_fpn_2x_lvis.py new file mode 100644 index 00000000000..78e897c2aa4 --- /dev/null +++ b/configs/lvis/mask_rcnn_r50_fpn_2x_lvis.py @@ -0,0 +1,8 @@ +_base_ = [ + '../_base_/models/mask_rcnn_r50_fpn.py', + '../_base_/datasets/lvis_instance.py', + '../_base_/schedules/schedule_2x.py', '../_base_/default_runtime.py' +] +model = dict( + roi_head=dict( + bbox_head=dict(num_classes=1230), mask_head=dict(num_classes=1230))) diff --git a/configs/lvis/mask_rcnn_r50_fpn_mstrain_2x_lvis.py b/configs/lvis/mask_rcnn_r50_fpn_mstrain_2x_lvis.py new file mode 100644 index 00000000000..79d7f3f16f4 --- /dev/null +++ b/configs/lvis/mask_rcnn_r50_fpn_mstrain_2x_lvis.py @@ -0,0 +1,19 @@ +_base_ = './mask_rcnn_r50_fpn_2x_lvis.py' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True, with_mask=True), + dict( + type='Resize', + img_scale=[(1333, 640), (1333, 672), (1333, 704), (1333, 736), + (1333, 768), (1333, 800)], + multiscale_mode='value', + keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), +] +data = dict(train=dict(pipeline=train_pipeline)) diff --git a/configs/lvis/mask_rcnn_r50_fpn_sample1e-3_2x_lvis.py b/configs/lvis/mask_rcnn_r50_fpn_sample1e-3_2x_lvis.py new file mode 100644 index 00000000000..da7ddd03cb2 --- /dev/null +++ b/configs/lvis/mask_rcnn_r50_fpn_sample1e-3_2x_lvis.py @@ -0,0 +1,25 @@ +_base_ = './mask_rcnn_r50_fpn_2x_lvis.py' +dataset_type = 'LvisDataset' +data_root = 'data/lvis/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True, with_mask=True), + dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), +] +data = dict( + train=dict( + _delete_=True, + type='ClassBalancedDataset', + oversample_thr=1e-3, + dataset=dict( + type=dataset_type, + ann_file=data_root + 'annotations/lvis_v0.5_train.json', + img_prefix=data_root + 'train2017/', + pipeline=train_pipeline))) diff --git a/configs/lvis/mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis.py b/configs/lvis/mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis.py new file mode 100644 index 00000000000..51b15671fa5 --- /dev/null +++ b/configs/lvis/mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis.py @@ -0,0 +1,30 @@ +_base_ = './mask_rcnn_r50_fpn_2x_lvis.py' +dataset_type = 'LvisDataset' +data_root = 'data/lvis/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True, with_mask=True), + dict( + type='Resize', + img_scale=[(1333, 640), (1333, 672), (1333, 704), (1333, 736), + (1333, 768), (1333, 800)], + multiscale_mode='value', + keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), +] +data = dict( + train=dict( + _delete_=True, + type='ClassBalancedDataset', + oversample_thr=1e-3, + dataset=dict( + type=dataset_type, + ann_file=data_root + 'annotations/lvis_v0.5_train.json', + img_prefix=data_root + 'train2017/', + pipeline=train_pipeline))) diff --git a/docs/install.md b/docs/install.md index 1f490a80d19..cabdab0f648 100644 --- a/docs/install.md +++ b/docs/install.md @@ -129,7 +129,7 @@ conda install -c pytorch pytorch torchvision -y git clone https://github.com/open-mmlab/mmdetection.git cd mmdetection pip install -r requirements/build.txt -pip install "git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI" +pip install "git+https://github.com/open-mmlab/cocoapi.git#subdirectory=PythonAPI" pip install -v -e . ``` diff --git a/mmdet/datasets/__init__.py b/mmdet/datasets/__init__.py index 311b8f24e76..8dc885fec03 100644 --- a/mmdet/datasets/__init__.py +++ b/mmdet/datasets/__init__.py @@ -4,6 +4,7 @@ from .custom import CustomDataset from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset, RepeatDataset) +from .lvis import LvisDataset from .samplers import DistributedGroupSampler, DistributedSampler, GroupSampler from .voc import VOCDataset from .wider_face import WIDERFaceDataset @@ -11,8 +12,8 @@ __all__ = [ 'CustomDataset', 'XMLDataset', 'CocoDataset', 'VOCDataset', - 'CityscapesDataset', 'GroupSampler', 'DistributedGroupSampler', - 'DistributedSampler', 'build_dataloader', 'ConcatDataset', 'RepeatDataset', - 'ClassBalancedDataset', 'WIDERFaceDataset', 'DATASETS', 'PIPELINES', - 'build_dataset' + 'CityscapesDataset', 'LvisDataset', 'GroupSampler', + 'DistributedGroupSampler', 'DistributedSampler', 'build_dataloader', + 'ConcatDataset', 'RepeatDataset', 'ClassBalancedDataset', + 'WIDERFaceDataset', 'DATASETS', 'PIPELINES', 'build_dataset' ] diff --git a/mmdet/datasets/coco.py b/mmdet/datasets/coco.py index 345de1b6a15..8ed3313e6d4 100644 --- a/mmdet/datasets/coco.py +++ b/mmdet/datasets/coco.py @@ -35,26 +35,26 @@ class CocoDataset(CustomDataset): def load_annotations(self, ann_file): self.coco = COCO(ann_file) - self.cat_ids = self.coco.getCatIds(catNms=self.CLASSES) + self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES) self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)} - self.img_ids = self.coco.getImgIds() + self.img_ids = self.coco.get_img_ids() data_infos = [] for i in self.img_ids: - info = self.coco.loadImgs([i])[0] + info = self.coco.load_imgs([i])[0] info['filename'] = info['file_name'] data_infos.append(info) return data_infos def get_ann_info(self, idx): img_id = self.data_infos[idx]['id'] - ann_ids = self.coco.getAnnIds(imgIds=[img_id]) - ann_info = self.coco.loadAnns(ann_ids) + ann_ids = self.coco.get_ann_ids(img_ids=[img_id]) + ann_info = self.coco.load_anns(ann_ids) return self._parse_ann_info(self.data_infos[idx], ann_info) def get_cat_ids(self, idx): img_id = self.data_infos[idx]['id'] - ann_ids = self.coco.getAnnIds(imgIds=[img_id]) - ann_info = self.coco.loadAnns(ann_ids) + ann_ids = self.coco.get_ann_ids(img_ids=[img_id]) + ann_info = self.coco.load_anns(ann_ids) return [ann['category_id'] for ann in ann_info] def _filter_imgs(self, min_size=32): @@ -83,12 +83,12 @@ def get_subset_by_classes(self): ids = set() for i, class_id in enumerate(self.cat_ids): - ids |= set(self.coco.catToImgs[class_id]) + ids |= set(self.coco.cat_img_map[class_id]) self.img_ids = list(ids) data_infos = [] for i in self.img_ids: - info = self.coco.loadImgs([i])[0] + info = self.coco.load_imgs([i])[0] info['filename'] = info['file_name'] data_infos.append(info) return data_infos @@ -268,8 +268,8 @@ def results2json(self, results, outfile_prefix): def fast_eval_recall(self, results, proposal_nums, iou_thrs, logger=None): gt_bboxes = [] for i in range(len(self.img_ids)): - ann_ids = self.coco.getAnnIds(imgIds=self.img_ids[i]) - ann_info = self.coco.loadAnns(ann_ids) + ann_ids = self.coco.get_ann_ids(img_ids=self.img_ids[i]) + ann_info = self.coco.load_anns(ann_ids) if len(ann_info) == 0: gt_bboxes.append(np.zeros((0, 4))) continue diff --git a/mmdet/datasets/lvis.py b/mmdet/datasets/lvis.py new file mode 100644 index 00000000000..2e2de618d05 --- /dev/null +++ b/mmdet/datasets/lvis.py @@ -0,0 +1,429 @@ +import itertools +import logging +import os.path as osp +import tempfile + +import numpy as np +from mmcv.utils import print_log +from terminaltables import AsciiTable + +from .builder import DATASETS +from .coco import CocoDataset + + +@DATASETS.register_module() +class LvisDataset(CocoDataset): + + CLASSES = ( + 'acorn', 'aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock', + 'alcohol', 'alligator', 'almond', 'ambulance', 'amplifier', 'anklet', + 'antenna', 'apple', 'apple_juice', 'applesauce', 'apricot', 'apron', + 'aquarium', 'armband', 'armchair', 'armoire', 'armor', 'artichoke', + 'trash_can', 'ashtray', 'asparagus', 'atomizer', 'avocado', 'award', + 'awning', 'ax', 'baby_buggy', 'basketball_backboard', 'backpack', + 'handbag', 'suitcase', 'bagel', 'bagpipe', 'baguet', 'bait', 'ball', + 'ballet_skirt', 'balloon', 'bamboo', 'banana', 'Band_Aid', 'bandage', + 'bandanna', 'banjo', 'banner', 'barbell', 'barge', 'barrel', + 'barrette', 'barrow', 'baseball_base', 'baseball', 'baseball_bat', + 'baseball_cap', 'baseball_glove', 'basket', 'basketball_hoop', + 'basketball', 'bass_horn', 'bat_(animal)', 'bath_mat', 'bath_towel', + 'bathrobe', 'bathtub', 'batter_(food)', 'battery', 'beachball', 'bead', + 'beaker', 'bean_curd', 'beanbag', 'beanie', 'bear', 'bed', + 'bedspread', 'cow', 'beef_(food)', 'beeper', 'beer_bottle', 'beer_can', + 'beetle', 'bell', 'bell_pepper', 'belt', 'belt_buckle', 'bench', + 'beret', 'bib', 'Bible', 'bicycle', 'visor', 'binder', 'binoculars', + 'bird', 'birdfeeder', 'birdbath', 'birdcage', 'birdhouse', + 'birthday_cake', 'birthday_card', 'biscuit_(bread)', 'pirate_flag', + 'black_sheep', 'blackboard', 'blanket', 'blazer', 'blender', 'blimp', + 'blinker', 'blueberry', 'boar', 'gameboard', 'boat', 'bobbin', + 'bobby_pin', 'boiled_egg', 'bolo_tie', 'deadbolt', 'bolt', 'bonnet', + 'book', 'book_bag', 'bookcase', 'booklet', 'bookmark', + 'boom_microphone', 'boot', 'bottle', 'bottle_opener', 'bouquet', + 'bow_(weapon)', 'bow_(decorative_ribbons)', 'bow-tie', 'bowl', + 'pipe_bowl', 'bowler_hat', 'bowling_ball', 'bowling_pin', + 'boxing_glove', 'suspenders', 'bracelet', 'brass_plaque', 'brassiere', + 'bread-bin', 'breechcloth', 'bridal_gown', 'briefcase', + 'bristle_brush', 'broccoli', 'broach', 'broom', 'brownie', + 'brussels_sprouts', 'bubble_gum', 'bucket', 'horse_buggy', 'bull', + 'bulldog', 'bulldozer', 'bullet_train', 'bulletin_board', + 'bulletproof_vest', 'bullhorn', 'corned_beef', 'bun', 'bunk_bed', + 'buoy', 'burrito', 'bus_(vehicle)', 'business_card', 'butcher_knife', + 'butter', 'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car', + 'cabinet', 'locker', 'cake', 'calculator', 'calendar', 'calf', + 'camcorder', 'camel', 'camera', 'camera_lens', 'camper_(vehicle)', + 'can', 'can_opener', 'candelabrum', 'candle', 'candle_holder', + 'candy_bar', 'candy_cane', 'walking_cane', 'canister', 'cannon', + 'canoe', 'cantaloup', 'canteen', 'cap_(headwear)', 'bottle_cap', + 'cape', 'cappuccino', 'car_(automobile)', 'railcar_(part_of_a_train)', + 'elevator_car', 'car_battery', 'identity_card', 'card', 'cardigan', + 'cargo_ship', 'carnation', 'horse_carriage', 'carrot', 'tote_bag', + 'cart', 'carton', 'cash_register', 'casserole', 'cassette', 'cast', + 'cat', 'cauliflower', 'caviar', 'cayenne_(spice)', 'CD_player', + 'celery', 'cellular_telephone', 'chain_mail', 'chair', 'chaise_longue', + 'champagne', 'chandelier', 'chap', 'checkbook', 'checkerboard', + 'cherry', 'chessboard', 'chest_of_drawers_(furniture)', + 'chicken_(animal)', 'chicken_wire', 'chickpea', 'Chihuahua', + 'chili_(vegetable)', 'chime', 'chinaware', 'crisp_(potato_chip)', + 'poker_chip', 'chocolate_bar', 'chocolate_cake', 'chocolate_milk', + 'chocolate_mousse', 'choker', 'chopping_board', 'chopstick', + 'Christmas_tree', 'slide', 'cider', 'cigar_box', 'cigarette', + 'cigarette_case', 'cistern', 'clarinet', 'clasp', 'cleansing_agent', + 'clementine', 'clip', 'clipboard', 'clock', 'clock_tower', + 'clothes_hamper', 'clothespin', 'clutch_bag', 'coaster', 'coat', + 'coat_hanger', 'coatrack', 'cock', 'coconut', 'coffee_filter', + 'coffee_maker', 'coffee_table', 'coffeepot', 'coil', 'coin', + 'colander', 'coleslaw', 'coloring_material', 'combination_lock', + 'pacifier', 'comic_book', 'computer_keyboard', 'concrete_mixer', + 'cone', 'control', 'convertible_(automobile)', 'sofa_bed', 'cookie', + 'cookie_jar', 'cooking_utensil', 'cooler_(for_food)', + 'cork_(bottle_plug)', 'corkboard', 'corkscrew', 'edible_corn', + 'cornbread', 'cornet', 'cornice', 'cornmeal', 'corset', + 'romaine_lettuce', 'costume', 'cougar', 'coverall', 'cowbell', + 'cowboy_hat', 'crab_(animal)', 'cracker', 'crape', 'crate', 'crayon', + 'cream_pitcher', 'credit_card', 'crescent_roll', 'crib', 'crock_pot', + 'crossbar', 'crouton', 'crow', 'crown', 'crucifix', 'cruise_ship', + 'police_cruiser', 'crumb', 'crutch', 'cub_(animal)', 'cube', + 'cucumber', 'cufflink', 'cup', 'trophy_cup', 'cupcake', 'hair_curler', + 'curling_iron', 'curtain', 'cushion', 'custard', 'cutting_tool', + 'cylinder', 'cymbal', 'dachshund', 'dagger', 'dartboard', + 'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk', + 'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table', 'tux', + 'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher', + 'dishwasher_detergent', 'diskette', 'dispenser', 'Dixie_cup', 'dog', + 'dog_collar', 'doll', 'dollar', 'dolphin', 'domestic_ass', 'eye_mask', + 'doorbell', 'doorknob', 'doormat', 'doughnut', 'dove', 'dragonfly', + 'drawer', 'underdrawers', 'dress', 'dress_hat', 'dress_suit', + 'dresser', 'drill', 'drinking_fountain', 'drone', 'dropper', + 'drum_(musical_instrument)', 'drumstick', 'duck', 'duckling', + 'duct_tape', 'duffel_bag', 'dumbbell', 'dumpster', 'dustpan', + 'Dutch_oven', 'eagle', 'earphone', 'earplug', 'earring', 'easel', + 'eclair', 'eel', 'egg', 'egg_roll', 'egg_yolk', 'eggbeater', + 'eggplant', 'electric_chair', 'refrigerator', 'elephant', 'elk', + 'envelope', 'eraser', 'escargot', 'eyepatch', 'falcon', 'fan', + 'faucet', 'fedora', 'ferret', 'Ferris_wheel', 'ferry', 'fig_(fruit)', + 'fighter_jet', 'figurine', 'file_cabinet', 'file_(tool)', 'fire_alarm', + 'fire_engine', 'fire_extinguisher', 'fire_hose', 'fireplace', + 'fireplug', 'fish', 'fish_(food)', 'fishbowl', 'fishing_boat', + 'fishing_rod', 'flag', 'flagpole', 'flamingo', 'flannel', 'flash', + 'flashlight', 'fleece', 'flip-flop_(sandal)', 'flipper_(footwear)', + 'flower_arrangement', 'flute_glass', 'foal', 'folding_chair', + 'food_processor', 'football_(American)', 'football_helmet', + 'footstool', 'fork', 'forklift', 'freight_car', 'French_toast', + 'freshener', 'frisbee', 'frog', 'fruit_juice', 'fruit_salad', + 'frying_pan', 'fudge', 'funnel', 'futon', 'gag', 'garbage', + 'garbage_truck', 'garden_hose', 'gargle', 'gargoyle', 'garlic', + 'gasmask', 'gazelle', 'gelatin', 'gemstone', 'giant_panda', + 'gift_wrap', 'ginger', 'giraffe', 'cincture', + 'glass_(drink_container)', 'globe', 'glove', 'goat', 'goggles', + 'goldfish', 'golf_club', 'golfcart', 'gondola_(boat)', 'goose', + 'gorilla', 'gourd', 'surgical_gown', 'grape', 'grasshopper', 'grater', + 'gravestone', 'gravy_boat', 'green_bean', 'green_onion', 'griddle', + 'grillroom', 'grinder_(tool)', 'grits', 'grizzly', 'grocery_bag', + 'guacamole', 'guitar', 'gull', 'gun', 'hair_spray', 'hairbrush', + 'hairnet', 'hairpin', 'ham', 'hamburger', 'hammer', 'hammock', + 'hamper', 'hamster', 'hair_dryer', 'hand_glass', 'hand_towel', + 'handcart', 'handcuff', 'handkerchief', 'handle', 'handsaw', + 'hardback_book', 'harmonium', 'hat', 'hatbox', 'hatch', 'veil', + 'headband', 'headboard', 'headlight', 'headscarf', 'headset', + 'headstall_(for_horses)', 'hearing_aid', 'heart', 'heater', + 'helicopter', 'helmet', 'heron', 'highchair', 'hinge', 'hippopotamus', + 'hockey_stick', 'hog', 'home_plate_(baseball)', 'honey', 'fume_hood', + 'hook', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce', + 'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear', + 'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate', + 'ice_tea', 'igniter', 'incense', 'inhaler', 'iPod', + 'iron_(for_clothing)', 'ironing_board', 'jacket', 'jam', 'jean', + 'jeep', 'jelly_bean', 'jersey', 'jet_plane', 'jewelry', 'joystick', + 'jumpsuit', 'kayak', 'keg', 'kennel', 'kettle', 'key', 'keycard', + 'kilt', 'kimono', 'kitchen_sink', 'kitchen_table', 'kite', 'kitten', + 'kiwi_fruit', 'knee_pad', 'knife', 'knight_(chess_piece)', + 'knitting_needle', 'knob', 'knocker_(on_a_door)', 'koala', 'lab_coat', + 'ladder', 'ladle', 'ladybug', 'lamb_(animal)', 'lamb-chop', 'lamp', + 'lamppost', 'lampshade', 'lantern', 'lanyard', 'laptop_computer', + 'lasagna', 'latch', 'lawn_mower', 'leather', 'legging_(clothing)', + 'Lego', 'lemon', 'lemonade', 'lettuce', 'license_plate', 'life_buoy', + 'life_jacket', 'lightbulb', 'lightning_rod', 'lime', 'limousine', + 'linen_paper', 'lion', 'lip_balm', 'lipstick', 'liquor', 'lizard', + 'Loafer_(type_of_shoe)', 'log', 'lollipop', 'lotion', + 'speaker_(stero_equipment)', 'loveseat', 'machine_gun', 'magazine', + 'magnet', 'mail_slot', 'mailbox_(at_home)', 'mallet', 'mammoth', + 'mandarin_orange', 'manger', 'manhole', 'map', 'marker', 'martini', + 'mascot', 'mashed_potato', 'masher', 'mask', 'mast', + 'mat_(gym_equipment)', 'matchbox', 'mattress', 'measuring_cup', + 'measuring_stick', 'meatball', 'medicine', 'melon', 'microphone', + 'microscope', 'microwave_oven', 'milestone', 'milk', 'minivan', + 'mint_candy', 'mirror', 'mitten', 'mixer_(kitchen_tool)', 'money', + 'monitor_(computer_equipment) computer_monitor', 'monkey', 'motor', + 'motor_scooter', 'motor_vehicle', 'motorboat', 'motorcycle', + 'mound_(baseball)', 'mouse_(animal_rodent)', + 'mouse_(computer_equipment)', 'mousepad', 'muffin', 'mug', 'mushroom', + 'music_stool', 'musical_instrument', 'nailfile', 'nameplate', 'napkin', + 'neckerchief', 'necklace', 'necktie', 'needle', 'nest', 'newsstand', + 'nightshirt', 'nosebag_(for_animals)', 'noseband_(for_animals)', + 'notebook', 'notepad', 'nut', 'nutcracker', 'oar', 'octopus_(food)', + 'octopus_(animal)', 'oil_lamp', 'olive_oil', 'omelet', 'onion', + 'orange_(fruit)', 'orange_juice', 'oregano', 'ostrich', 'ottoman', + 'overalls_(clothing)', 'owl', 'packet', 'inkpad', 'pad', 'paddle', + 'padlock', 'paintbox', 'paintbrush', 'painting', 'pajamas', 'palette', + 'pan_(for_cooking)', 'pan_(metal_container)', 'pancake', 'pantyhose', + 'papaya', 'paperclip', 'paper_plate', 'paper_towel', 'paperback_book', + 'paperweight', 'parachute', 'parakeet', 'parasail_(sports)', + 'parchment', 'parka', 'parking_meter', 'parrot', + 'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport', + 'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter', + 'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'pegboard', + 'pelican', 'pen', 'pencil', 'pencil_box', 'pencil_sharpener', + 'pendulum', 'penguin', 'pennant', 'penny_(coin)', 'pepper', + 'pepper_mill', 'perfume', 'persimmon', 'baby', 'pet', 'petfood', + 'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano', + 'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow', + 'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball', + 'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)', + 'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat', + 'plate', 'platter', 'playing_card', 'playpen', 'pliers', + 'plow_(farm_equipment)', 'pocket_watch', 'pocketknife', + 'poker_(fire_stirring_tool)', 'pole', 'police_van', 'polo_shirt', + 'poncho', 'pony', 'pool_table', 'pop_(soda)', 'portrait', + 'postbox_(public)', 'postcard', 'poster', 'pot', 'flowerpot', 'potato', + 'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn', 'printer', + 'projectile_(weapon)', 'projector', 'propeller', 'prune', 'pudding', + 'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin', 'puncher', 'puppet', + 'puppy', 'quesadilla', 'quiche', 'quilt', 'rabbit', 'race_car', + 'racket', 'radar', 'radiator', 'radio_receiver', 'radish', 'raft', + 'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry', 'rat', + 'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt', + 'recliner', 'record_player', 'red_cabbage', 'reflector', + 'remote_control', 'rhinoceros', 'rib_(food)', 'rifle', 'ring', + 'river_boat', 'road_map', 'robe', 'rocking_chair', 'roller_skate', + 'Rollerblade', 'rolling_pin', 'root_beer', + 'router_(computer_equipment)', 'rubber_band', 'runner_(carpet)', + 'plastic_bag', 'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag', + 'safety_pin', 'sail', 'salad', 'salad_plate', 'salami', + 'salmon_(fish)', 'salmon_(food)', 'salsa', 'saltshaker', + 'sandal_(type_of_shoe)', 'sandwich', 'satchel', 'saucepan', 'saucer', + 'sausage', 'sawhorse', 'saxophone', 'scale_(measuring_instrument)', + 'scarecrow', 'scarf', 'school_bus', 'scissors', 'scoreboard', + 'scrambled_eggs', 'scraper', 'scratcher', 'screwdriver', + 'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane', + 'seashell', 'seedling', 'serving_dish', 'sewing_machine', 'shaker', + 'shampoo', 'shark', 'sharpener', 'Sharpie', 'shaver_(electric)', + 'shaving_cream', 'shawl', 'shears', 'sheep', 'shepherd_dog', + 'sherbert', 'shield', 'shirt', 'shoe', 'shopping_bag', 'shopping_cart', + 'short_pants', 'shot_glass', 'shoulder_bag', 'shovel', 'shower_head', + 'shower_curtain', 'shredder_(for_paper)', 'sieve', 'signboard', 'silo', + 'sink', 'skateboard', 'skewer', 'ski', 'ski_boot', 'ski_parka', + 'ski_pole', 'skirt', 'sled', 'sleeping_bag', 'sling_(bandage)', + 'slipper_(footwear)', 'smoothie', 'snake', 'snowboard', 'snowman', + 'snowmobile', 'soap', 'soccer_ball', 'sock', 'soda_fountain', + 'carbonated_water', 'sofa', 'softball', 'solar_array', 'sombrero', + 'soup', 'soup_bowl', 'soupspoon', 'sour_cream', 'soya_milk', + 'space_shuttle', 'sparkler_(fireworks)', 'spatula', 'spear', + 'spectacles', 'spice_rack', 'spider', 'sponge', 'spoon', 'sportswear', + 'spotlight', 'squirrel', 'stapler_(stapling_machine)', 'starfish', + 'statue_(sculpture)', 'steak_(food)', 'steak_knife', + 'steamer_(kitchen_appliance)', 'steering_wheel', 'stencil', + 'stepladder', 'step_stool', 'stereo_(sound_system)', 'stew', 'stirrer', + 'stirrup', 'stockings_(leg_wear)', 'stool', 'stop_sign', 'brake_light', + 'stove', 'strainer', 'strap', 'straw_(for_drinking)', 'strawberry', + 'street_sign', 'streetlight', 'string_cheese', 'stylus', 'subwoofer', + 'sugar_bowl', 'sugarcane_(plant)', 'suit_(clothing)', 'sunflower', + 'sunglasses', 'sunhat', 'sunscreen', 'surfboard', 'sushi', 'mop', + 'sweat_pants', 'sweatband', 'sweater', 'sweatshirt', 'sweet_potato', + 'swimsuit', 'sword', 'syringe', 'Tabasco_sauce', 'table-tennis_table', + 'table', 'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag', + 'taillight', 'tambourine', 'army_tank', 'tank_(storage_vessel)', + 'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure', + 'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup', + 'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth', + 'telephone_pole', 'telephoto_lens', 'television_camera', + 'television_set', 'tennis_ball', 'tennis_racket', 'tequila', + 'thermometer', 'thermos_bottle', 'thermostat', 'thimble', 'thread', + 'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer', 'tinfoil', + 'tinsel', 'tissue_paper', 'toast_(food)', 'toaster', 'toaster_oven', + 'toilet', 'toilet_tissue', 'tomato', 'tongs', 'toolbox', 'toothbrush', + 'toothpaste', 'toothpick', 'cover', 'tortilla', 'tow_truck', 'towel', + 'towel_rack', 'toy', 'tractor_(farm_equipment)', 'traffic_light', + 'dirt_bike', 'trailer_truck', 'train_(railroad_vehicle)', 'trampoline', + 'tray', 'tree_house', 'trench_coat', 'triangle_(musical_instrument)', + 'tricycle', 'tripod', 'trousers', 'truck', 'truffle_(chocolate)', + 'trunk', 'vat', 'turban', 'turkey_(bird)', 'turkey_(food)', 'turnip', + 'turtle', 'turtleneck_(clothing)', 'typewriter', 'umbrella', + 'underwear', 'unicycle', 'urinal', 'urn', 'vacuum_cleaner', 'valve', + 'vase', 'vending_machine', 'vent', 'videotape', 'vinegar', 'violin', + 'vodka', 'volleyball', 'vulture', 'waffle', 'waffle_iron', 'wagon', + 'wagon_wheel', 'walking_stick', 'wall_clock', 'wall_socket', 'wallet', + 'walrus', 'wardrobe', 'wasabi', 'automatic_washer', 'watch', + 'water_bottle', 'water_cooler', 'water_faucet', 'water_filter', + 'water_heater', 'water_jug', 'water_gun', 'water_scooter', 'water_ski', + 'water_tower', 'watering_can', 'watermelon', 'weathervane', 'webcam', + 'wedding_cake', 'wedding_ring', 'wet_suit', 'wheel', 'wheelchair', + 'whipped_cream', 'whiskey', 'whistle', 'wick', 'wig', 'wind_chime', + 'windmill', 'window_box_(for_plants)', 'windshield_wiper', 'windsock', + 'wine_bottle', 'wine_bucket', 'wineglass', 'wing_chair', + 'blinder_(for_horses)', 'wok', 'wolf', 'wooden_spoon', 'wreath', + 'wrench', 'wristband', 'wristlet', 'yacht', 'yak', 'yogurt', + 'yoke_(animal_equipment)', 'zebra', 'zucchini') + + def load_annotations(self, ann_file): + try: + from lvis import LVIS + except ImportError: + raise ImportError('Please run "pip install lvis" to ' + 'install lvis first.') + self.coco = LVIS(ann_file) + assert not self.custom_classes, 'LVIS custom classes is not supported' + self.cat_ids = self.coco.get_cat_ids() + self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)} + self.img_ids = self.coco.get_img_ids() + data_infos = [] + for i in self.img_ids: + info = self.coco.load_imgs([i])[0] + info['filename'] = info['file_name'] + data_infos.append(info) + return data_infos + + def evaluate(self, + results, + metric='bbox', + logger=None, + jsonfile_prefix=None, + classwise=False, + proposal_nums=(100, 300, 1000), + iou_thrs=np.arange(0.5, 0.96, 0.05)): + """Evaluation in LVIS protocol. + Args: + results (list): Testing results of the dataset. + metric (str | list[str]): Metrics to be evaluated. + logger (logging.Logger | str | None): Logger used for printing + related information during evaluation. Default: None. + jsonfile_prefix (str | None): + classwise (bool): Whether to evaluating the AP for each class. + proposal_nums (Sequence[int]): Proposal number used for evaluating + recalls, such as recall@100, recall@1000. + Default: (100, 300, 1000). + iou_thrs (Sequence[float]): IoU threshold used for evaluating + recalls. If set to a list, the average recall of all IoUs will + also be computed. Default: 0.5. + Returns: + dict[str: float] + """ + try: + from lvis import LVISResults, LVISEval + except ImportError: + raise ImportError('Please run "pip install lvis" to ' + 'install lvis first.') + assert isinstance(results, list), 'results must be a list' + assert len(results) == len(self), ( + 'The length of results is not equal to the dataset len: {} != {}'. + format(len(results), len(self))) + + metrics = metric if isinstance(metric, list) else [metric] + allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast'] + for metric in metrics: + if metric not in allowed_metrics: + raise KeyError('metric {} is not supported'.format(metric)) + + if jsonfile_prefix is None: + tmp_dir = tempfile.TemporaryDirectory() + jsonfile_prefix = osp.join(tmp_dir.name, 'results') + else: + tmp_dir = None + result_files = self.results2json(results, jsonfile_prefix) + + eval_results = {} + # get original api + lvis_gt = self.api.api + for metric in metrics: + msg = 'Evaluating {}...'.format(metric) + if logger is None: + msg = '\n' + msg + print_log(msg, logger=logger) + + if metric == 'proposal_fast': + ar = self.fast_eval_recall( + results, proposal_nums, iou_thrs, logger='silent') + log_msg = [] + for i, num in enumerate(proposal_nums): + eval_results['AR@{}'.format(num)] = ar[i] + log_msg.append('\nAR@{}\t{:.4f}'.format(num, ar[i])) + log_msg = ''.join(log_msg) + print_log(log_msg, logger=logger) + continue + + if metric not in result_files: + raise KeyError('{} is not in results'.format(metric)) + try: + lvis_dt = LVISResults(lvis_gt, result_files[metric]) + except IndexError: + print_log( + 'The testing results of the whole dataset is empty.', + logger=logger, + level=logging.ERROR) + break + + iou_type = 'bbox' if metric == 'proposal' else metric + lvis_eval = LVISEval(lvis_gt, lvis_dt, iou_type) + lvis_eval.params.imgIds = self.img_ids + if metric == 'proposal': + lvis_eval.params.useCats = 0 + lvis_eval.params.maxDets = list(proposal_nums) + lvis_eval.evaluate() + lvis_eval.accumulate() + lvis_eval.summarize() + for k, v in lvis_eval.get_results().items(): + if k.startswith('AR'): + val = float('{:.3f}'.format(float(v))) + eval_results[k] = val + else: + lvis_eval.evaluate() + lvis_eval.accumulate() + lvis_eval.summarize() + lvis_results = lvis_eval.get_results() + if classwise: # Compute per-category AP + # Compute per-category AP + # from https://github.com/facebookresearch/detectron2/ + precisions = lvis_eval.eval['precision'] + # precision: (iou, recall, cls, area range, max dets) + assert len(self.cat_ids) == precisions.shape[2] + + results_per_category = [] + for idx, catId in enumerate(self.cat_ids): + # area range index 0: all area ranges + # max dets index -1: typically 100 per image + nm = self.coco.load_cats(catId)[0] + precision = precisions[:, :, idx, 0, -1] + precision = precision[precision > -1] + if precision.size: + ap = np.mean(precision) + else: + ap = float('nan') + results_per_category.append( + (f'{nm["name"]}', f'{float(ap):0.3f}')) + + num_columns = min(6, len(results_per_category) * 2) + results_flatten = list( + itertools.chain(*results_per_category)) + headers = ['category', 'AP'] * (num_columns // 2) + results_2d = itertools.zip_longest(*[ + results_flatten[i::num_columns] + for i in range(num_columns) + ]) + table_data = [headers] + table_data += [result for result in results_2d] + table = AsciiTable(table_data) + print_log('\n' + table.table, logger=logger) + + for k, v in lvis_results.items(): + if k.startswith('AP'): + key = '{}_{}'.format(metric, k) + val = float('{:.3f}'.format(float(v))) + eval_results[key] = val + ap_summary = ' '.join([ + '{}:{:.3f}'.format(k, float(v)) + for k, v in lvis_results.items() if k.startswith('AP') + ]) + eval_results['{}_mAP_copypaste'.format(metric)] = ap_summary + lvis_eval.print_results() + if tmp_dir is None: + tmp_dir.cleanup() + return eval_results From 7a3355cfefbaab2e4febe6705640cc412781011a Mon Sep 17 00:00:00 2001 From: Jiarui XU Date: Fri, 29 May 2020 09:14:16 +0800 Subject: [PATCH 02/13] fixed eval --- mmdet/datasets/lvis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmdet/datasets/lvis.py b/mmdet/datasets/lvis.py index 2e2de618d05..c587f2cc6cd 100644 --- a/mmdet/datasets/lvis.py +++ b/mmdet/datasets/lvis.py @@ -332,7 +332,7 @@ def evaluate(self, eval_results = {} # get original api - lvis_gt = self.api.api + lvis_gt = self.coco for metric in metrics: msg = 'Evaluating {}...'.format(metric) if logger is None: From 441bb382ff1b089b6b27d698d1c037090a93899c Mon Sep 17 00:00:00 2001 From: Jiarui XU Date: Sun, 31 May 2020 09:42:42 +0800 Subject: [PATCH 03/13] fixed test cfg --- configs/lvis/mask_rcnn_r50_fpn_2x_lvis.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/configs/lvis/mask_rcnn_r50_fpn_2x_lvis.py b/configs/lvis/mask_rcnn_r50_fpn_2x_lvis.py index 78e897c2aa4..03d2efd59eb 100644 --- a/configs/lvis/mask_rcnn_r50_fpn_2x_lvis.py +++ b/configs/lvis/mask_rcnn_r50_fpn_2x_lvis.py @@ -6,3 +6,8 @@ model = dict( roi_head=dict( bbox_head=dict(num_classes=1230), mask_head=dict(num_classes=1230))) +test_cfg = dict( + rcnn=dict( + score_thr=0.0001, + # LVIS allows up to 300 + max_per_img=300)) From 93aa38cc0466da938752dceccf15b180cb87bbee Mon Sep 17 00:00:00 2001 From: Jiarui XU Date: Sun, 31 May 2020 17:36:46 +0800 Subject: [PATCH 04/13] add resnext config --- ...cnn_r101_fpn_sample1e-3_mstrain_2x_coco.py | 2 ++ configs/lvis/mask_rcnn_r50_fpn_2x_lvis.py | 13 ---------- .../lvis/mask_rcnn_r50_fpn_mstrain_2x_lvis.py | 19 -------------- .../mask_rcnn_r50_fpn_sample1e-3_2x_lvis.py | 25 ------------------- ...rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis.py | 14 ++++++++++- ...01_32x4d_fpn_sample1e-3_mstrain_2x_coco.py | 13 ++++++++++ ...01_64x4d_fpn_sample1e-3_mstrain_2x_coco.py | 13 ++++++++++ 7 files changed, 41 insertions(+), 58 deletions(-) create mode 100644 configs/lvis/mask_rcnn_r101_fpn_sample1e-3_mstrain_2x_coco.py delete mode 100644 configs/lvis/mask_rcnn_r50_fpn_2x_lvis.py delete mode 100644 configs/lvis/mask_rcnn_r50_fpn_mstrain_2x_lvis.py delete mode 100644 configs/lvis/mask_rcnn_r50_fpn_sample1e-3_2x_lvis.py create mode 100644 configs/lvis/mask_rcnn_x101_32x4d_fpn_sample1e-3_mstrain_2x_coco.py create mode 100644 configs/lvis/mask_rcnn_x101_64x4d_fpn_sample1e-3_mstrain_2x_coco.py diff --git a/configs/lvis/mask_rcnn_r101_fpn_sample1e-3_mstrain_2x_coco.py b/configs/lvis/mask_rcnn_r101_fpn_sample1e-3_mstrain_2x_coco.py new file mode 100644 index 00000000000..e99ca92a08a --- /dev/null +++ b/configs/lvis/mask_rcnn_r101_fpn_sample1e-3_mstrain_2x_coco.py @@ -0,0 +1,2 @@ +_base_ = './mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis.py' +model = dict(pretrained='torchvision://resnet101', backbone=dict(depth=101)) diff --git a/configs/lvis/mask_rcnn_r50_fpn_2x_lvis.py b/configs/lvis/mask_rcnn_r50_fpn_2x_lvis.py deleted file mode 100644 index 03d2efd59eb..00000000000 --- a/configs/lvis/mask_rcnn_r50_fpn_2x_lvis.py +++ /dev/null @@ -1,13 +0,0 @@ -_base_ = [ - '../_base_/models/mask_rcnn_r50_fpn.py', - '../_base_/datasets/lvis_instance.py', - '../_base_/schedules/schedule_2x.py', '../_base_/default_runtime.py' -] -model = dict( - roi_head=dict( - bbox_head=dict(num_classes=1230), mask_head=dict(num_classes=1230))) -test_cfg = dict( - rcnn=dict( - score_thr=0.0001, - # LVIS allows up to 300 - max_per_img=300)) diff --git a/configs/lvis/mask_rcnn_r50_fpn_mstrain_2x_lvis.py b/configs/lvis/mask_rcnn_r50_fpn_mstrain_2x_lvis.py deleted file mode 100644 index 79d7f3f16f4..00000000000 --- a/configs/lvis/mask_rcnn_r50_fpn_mstrain_2x_lvis.py +++ /dev/null @@ -1,19 +0,0 @@ -_base_ = './mask_rcnn_r50_fpn_2x_lvis.py' -img_norm_cfg = dict( - mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) -train_pipeline = [ - dict(type='LoadImageFromFile'), - dict(type='LoadAnnotations', with_bbox=True, with_mask=True), - dict( - type='Resize', - img_scale=[(1333, 640), (1333, 672), (1333, 704), (1333, 736), - (1333, 768), (1333, 800)], - multiscale_mode='value', - keep_ratio=True), - dict(type='RandomFlip', flip_ratio=0.5), - dict(type='Normalize', **img_norm_cfg), - dict(type='Pad', size_divisor=32), - dict(type='DefaultFormatBundle'), - dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), -] -data = dict(train=dict(pipeline=train_pipeline)) diff --git a/configs/lvis/mask_rcnn_r50_fpn_sample1e-3_2x_lvis.py b/configs/lvis/mask_rcnn_r50_fpn_sample1e-3_2x_lvis.py deleted file mode 100644 index da7ddd03cb2..00000000000 --- a/configs/lvis/mask_rcnn_r50_fpn_sample1e-3_2x_lvis.py +++ /dev/null @@ -1,25 +0,0 @@ -_base_ = './mask_rcnn_r50_fpn_2x_lvis.py' -dataset_type = 'LvisDataset' -data_root = 'data/lvis/' -img_norm_cfg = dict( - mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) -train_pipeline = [ - dict(type='LoadImageFromFile'), - dict(type='LoadAnnotations', with_bbox=True, with_mask=True), - dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), - dict(type='RandomFlip', flip_ratio=0.5), - dict(type='Normalize', **img_norm_cfg), - dict(type='Pad', size_divisor=32), - dict(type='DefaultFormatBundle'), - dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), -] -data = dict( - train=dict( - _delete_=True, - type='ClassBalancedDataset', - oversample_thr=1e-3, - dataset=dict( - type=dataset_type, - ann_file=data_root + 'annotations/lvis_v0.5_train.json', - img_prefix=data_root + 'train2017/', - pipeline=train_pipeline))) diff --git a/configs/lvis/mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis.py b/configs/lvis/mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis.py index 51b15671fa5..8d84e4ed9c6 100644 --- a/configs/lvis/mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis.py +++ b/configs/lvis/mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis.py @@ -1,4 +1,16 @@ -_base_ = './mask_rcnn_r50_fpn_2x_lvis.py' +_base_ = [ + '../_base_/models/mask_rcnn_r50_fpn.py', + '../_base_/datasets/lvis_instance.py', + '../_base_/schedules/schedule_2x.py', '../_base_/default_runtime.py' +] +model = dict( + roi_head=dict( + bbox_head=dict(num_classes=1230), mask_head=dict(num_classes=1230))) +test_cfg = dict( + rcnn=dict( + score_thr=0.0001, + # LVIS allows up to 300 + max_per_img=300)) dataset_type = 'LvisDataset' data_root = 'data/lvis/' img_norm_cfg = dict( diff --git a/configs/lvis/mask_rcnn_x101_32x4d_fpn_sample1e-3_mstrain_2x_coco.py b/configs/lvis/mask_rcnn_x101_32x4d_fpn_sample1e-3_mstrain_2x_coco.py new file mode 100644 index 00000000000..f79c3815a9c --- /dev/null +++ b/configs/lvis/mask_rcnn_x101_32x4d_fpn_sample1e-3_mstrain_2x_coco.py @@ -0,0 +1,13 @@ +_base_ = './mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis.py' +model = dict( + pretrained='open-mmlab://resnext101_32x4d', + backbone=dict( + type='ResNeXt', + depth=101, + groups=32, + base_width=4, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + style='pytorch')) diff --git a/configs/lvis/mask_rcnn_x101_64x4d_fpn_sample1e-3_mstrain_2x_coco.py b/configs/lvis/mask_rcnn_x101_64x4d_fpn_sample1e-3_mstrain_2x_coco.py new file mode 100644 index 00000000000..cadc575cc45 --- /dev/null +++ b/configs/lvis/mask_rcnn_x101_64x4d_fpn_sample1e-3_mstrain_2x_coco.py @@ -0,0 +1,13 @@ +_base_ = './mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis.py' +model = dict( + pretrained='open-mmlab://resnext101_64x4d', + backbone=dict( + type='ResNeXt', + depth=101, + groups=64, + base_width=4, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + style='pytorch')) From 64e69ceb4138be4650af782a4a66a6d1fdfeccbd Mon Sep 17 00:00:00 2001 From: Jiarui XU Date: Sun, 31 May 2020 21:04:46 +0800 Subject: [PATCH 05/13] update md --- configs/lvis/README.md | 24 ++++++++++++++++++++++++ docs/install.md | 5 +++-- docs/tutorials/new_dataset.md | 2 +- 3 files changed, 28 insertions(+), 3 deletions(-) create mode 100644 configs/lvis/README.md diff --git a/configs/lvis/README.md b/configs/lvis/README.md new file mode 100644 index 00000000000..6b61da8eea0 --- /dev/null +++ b/configs/lvis/README.md @@ -0,0 +1,24 @@ +# LVIS dataset + +## Introduction +``` +@inproceedings{gupta2019lvis, + title={{LVIS}: A Dataset for Large Vocabulary Instance Segmentation}, + author={Gupta, Agrim and Dollar, Piotr and Girshick, Ross}, + booktitle={Proceedings of the {IEEE} Conference on Computer Vision and Pattern Recognition}, + year={2019} +} +``` + +## Common Setting +* All experiments use oversample strategy [here](../../docs/tutorials/new_dataset.md#class-balanced-dataset) with oversample threshold `1e-3`. +* The size of LVIS v0.5 is half of COCO, so schedule `2x` in LVIS is roughly the same iterations as `1x` in COCO. + +## Results and models + +| Backbone | Style | Lr schd | Mem (GB) | Inf time (fps) | box AP | mask AP | Download | +| :-------------: | :-----: | :-----: | :------: | :------------: | :----: | :-----: | :------: | +| R-50-FPN | pytorch | 2x | - | - | | | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_r50_fpn_2x_coco/mask_rcnn_r50_fpn_2x_coco_bbox_mAP-0.392__segm_mAP-0.354_20200505_003907-3e542a40.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_r50_fpn_2x_coco/mask_rcnn_r50_fpn_2x_coco_20200505_003907.log.json) | +| R-101-FPN | pytorch | 2x | - | - | | | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_r101_fpn_2x_coco/mask_rcnn_r101_fpn_2x_coco_bbox_mAP-0.408__segm_mAP-0.366_20200505_071027-14b391c7.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_r101_fpn_2x_coco/mask_rcnn_r101_fpn_2x_coco_20200505_071027.log.json) | +| X-101-32x4d-FPN | pytorch | 2x | - | - | | | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_x101_32x4d_fpn_2x_coco/mask_rcnn_x101_32x4d_fpn_2x_coco_bbox_mAP-0.422__segm_mAP-0.378_20200506_004702-faef898c.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_x101_32x4d_fpn_2x_coco/mask_rcnn_x101_32x4d_fpn_2x_coco_20200506_004702.log.json) | +| X-101-64x4d-FPN | pytorch | 2x | - | - | | | | diff --git a/docs/install.md b/docs/install.md index cabdab0f648..b4ba1a40175 100644 --- a/docs/install.md +++ b/docs/install.md @@ -53,11 +53,12 @@ cd mmdetection ``` d. Install build requirements and then install mmdetection. -(We install pycocotools via the github repo instead of pypi because the pypi version is old and not compatible with the latest numpy.) +(We install our forked version of pycocotools via the github repo instead of pypi +for better compatibility with our repo.) ```shell pip install -r requirements/build.txt -pip install "git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI" +pip install "git+https://github.com/open-mmlab/cocoapi.git#subdirectory=PythonAPI" pip install -v -e . # or "python setup.py develop" ``` diff --git a/docs/tutorials/new_dataset.md b/docs/tutorials/new_dataset.md index 4109e5de7af..87b3e853406 100644 --- a/docs/tutorials/new_dataset.md +++ b/docs/tutorials/new_dataset.md @@ -228,7 +228,7 @@ dataset_A_train = dict( ) ``` -### Repeat factor dataset +### Class balanced dataset We use `ClassBalancedDataset` as wrapper to repeat the dataset based on category frequency. The dataset to repeat needs to instantiate function `self.get_cat_ids(idx)` From 4d820a8e14f8d935c36e33013f016c8d31d9091b Mon Sep 17 00:00:00 2001 From: Jiarui XU Date: Mon, 1 Jun 2020 13:47:38 +0800 Subject: [PATCH 06/13] fixed name --- ...x_coco.py => mask_rcnn_r101_fpn_sample1e-3_mstrain_2x_lvis.py} | 0 ....py => mask_rcnn_x101_32x4d_fpn_sample1e-3_mstrain_2x_lvis.py} | 0 ....py => mask_rcnn_x101_64x4d_fpn_sample1e-3_mstrain_2x_lvis.py} | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename configs/lvis/{mask_rcnn_r101_fpn_sample1e-3_mstrain_2x_coco.py => mask_rcnn_r101_fpn_sample1e-3_mstrain_2x_lvis.py} (100%) rename configs/lvis/{mask_rcnn_x101_32x4d_fpn_sample1e-3_mstrain_2x_coco.py => mask_rcnn_x101_32x4d_fpn_sample1e-3_mstrain_2x_lvis.py} (100%) rename configs/lvis/{mask_rcnn_x101_64x4d_fpn_sample1e-3_mstrain_2x_coco.py => mask_rcnn_x101_64x4d_fpn_sample1e-3_mstrain_2x_lvis.py} (100%) diff --git a/configs/lvis/mask_rcnn_r101_fpn_sample1e-3_mstrain_2x_coco.py b/configs/lvis/mask_rcnn_r101_fpn_sample1e-3_mstrain_2x_lvis.py similarity index 100% rename from configs/lvis/mask_rcnn_r101_fpn_sample1e-3_mstrain_2x_coco.py rename to configs/lvis/mask_rcnn_r101_fpn_sample1e-3_mstrain_2x_lvis.py diff --git a/configs/lvis/mask_rcnn_x101_32x4d_fpn_sample1e-3_mstrain_2x_coco.py b/configs/lvis/mask_rcnn_x101_32x4d_fpn_sample1e-3_mstrain_2x_lvis.py similarity index 100% rename from configs/lvis/mask_rcnn_x101_32x4d_fpn_sample1e-3_mstrain_2x_coco.py rename to configs/lvis/mask_rcnn_x101_32x4d_fpn_sample1e-3_mstrain_2x_lvis.py diff --git a/configs/lvis/mask_rcnn_x101_64x4d_fpn_sample1e-3_mstrain_2x_coco.py b/configs/lvis/mask_rcnn_x101_64x4d_fpn_sample1e-3_mstrain_2x_lvis.py similarity index 100% rename from configs/lvis/mask_rcnn_x101_64x4d_fpn_sample1e-3_mstrain_2x_coco.py rename to configs/lvis/mask_rcnn_x101_64x4d_fpn_sample1e-3_mstrain_2x_lvis.py From 1028fd418bf05f591b45f08ab3f08cc5dc385b1e Mon Sep 17 00:00:00 2001 From: Jiarui XU Date: Wed, 3 Jun 2020 10:00:11 +0800 Subject: [PATCH 07/13] update model urls --- configs/lvis/README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/configs/lvis/README.md b/configs/lvis/README.md index 6b61da8eea0..07fa8da720f 100644 --- a/configs/lvis/README.md +++ b/configs/lvis/README.md @@ -18,7 +18,7 @@ | Backbone | Style | Lr schd | Mem (GB) | Inf time (fps) | box AP | mask AP | Download | | :-------------: | :-----: | :-----: | :------: | :------------: | :----: | :-----: | :------: | -| R-50-FPN | pytorch | 2x | - | - | | | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_r50_fpn_2x_coco/mask_rcnn_r50_fpn_2x_coco_bbox_mAP-0.392__segm_mAP-0.354_20200505_003907-3e542a40.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_r50_fpn_2x_coco/mask_rcnn_r50_fpn_2x_coco_20200505_003907.log.json) | -| R-101-FPN | pytorch | 2x | - | - | | | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_r101_fpn_2x_coco/mask_rcnn_r101_fpn_2x_coco_bbox_mAP-0.408__segm_mAP-0.366_20200505_071027-14b391c7.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_r101_fpn_2x_coco/mask_rcnn_r101_fpn_2x_coco_20200505_071027.log.json) | -| X-101-32x4d-FPN | pytorch | 2x | - | - | | | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_x101_32x4d_fpn_2x_coco/mask_rcnn_x101_32x4d_fpn_2x_coco_bbox_mAP-0.422__segm_mAP-0.378_20200506_004702-faef898c.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_x101_32x4d_fpn_2x_coco/mask_rcnn_x101_32x4d_fpn_2x_coco_20200506_004702.log.json) | -| X-101-64x4d-FPN | pytorch | 2x | - | - | | | | +| R-50-FPN | pytorch | 2x | - | - | 26.1 | 25.9 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/lvis/mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis/mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis-dbd06831.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/lvis/mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis/mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis_20200531_160435.log.json) | +| R-101-FPN | pytorch | 2x | - | - | 27.1 | 27.0 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/lvis/mask_rcnn_r101_fpn_sample1e-3_mstrain_2x_lvis/mask_rcnn_r101_fpn_sample1e-3_mstrain_2x_lvis-54582ee2.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/lvis/mask_rcnn_r101_fpn_sample1e-3_mstrain_2x_lvis/mask_rcnn_r101_fpn_sample1e-3_mstrain_2x_lvis_20200601_134748.log.json) | +| X-101-32x4d-FPN | pytorch | 2x | - | - | 26.7 | 26.9 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/lvis/mask_rcnn_x101_32x4d_fpn_sample1e-3_mstrain_2x_lvis/mask_rcnn_x101_32x4d_fpn_sample1e-3_mstrain_2x_lvis-3cf55ea2.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/lvis/mask_rcnn_x101_32x4d_fpn_sample1e-3_mstrain_2x_lvis/mask_rcnn_x101_32x4d_fpn_sample1e-3_mstrain_2x_lvis_20200531_221749.log.json) | +| X-101-64x4d-FPN | pytorch | 2x | - | - | 26.4 | 26.0 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/lvis/mask_rcnn_x101_64x4d_fpn_sample1e-3_mstrain_2x_lvis/mask_rcnn_x101_64x4d_fpn_sample1e-3_mstrain_2x_lvis-1c99a5ad.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/lvis/mask_rcnn_x101_64x4d_fpn_sample1e-3_mstrain_2x_lvis/mask_rcnn_x101_64x4d_fpn_sample1e-3_mstrain_2x_lvis_20200601_194651.log.json) | From c7e55d82d82825fdb8c3704bec5bf89cb137c388 Mon Sep 17 00:00:00 2001 From: Jiarui XU Date: Thu, 4 Jun 2020 01:24:29 +0800 Subject: [PATCH 08/13] minor fix --- mmdet/datasets/lvis.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mmdet/datasets/lvis.py b/mmdet/datasets/lvis.py index c587f2cc6cd..08a07444f18 100644 --- a/mmdet/datasets/lvis.py +++ b/mmdet/datasets/lvis.py @@ -12,7 +12,7 @@ @DATASETS.register_module() -class LvisDataset(CocoDataset): +class LVISDataset(CocoDataset): CLASSES = ( 'acorn', 'aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock', @@ -424,6 +424,6 @@ def evaluate(self, ]) eval_results['{}_mAP_copypaste'.format(metric)] = ap_summary lvis_eval.print_results() - if tmp_dir is None: + if tmp_dir is not None: tmp_dir.cleanup() return eval_results From b8c334093eab4be63c3446b06fe4252e916b9086 Mon Sep 17 00:00:00 2001 From: Jiarui XU Date: Thu, 4 Jun 2020 01:47:07 +0800 Subject: [PATCH 09/13] fixed typo --- configs/_base_/datasets/lvis_instance.py | 2 +- configs/lvis/mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis.py | 2 +- mmdet/datasets/__init__.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/configs/_base_/datasets/lvis_instance.py b/configs/_base_/datasets/lvis_instance.py index e0f672f137e..77cdd8c8373 100644 --- a/configs/_base_/datasets/lvis_instance.py +++ b/configs/_base_/datasets/lvis_instance.py @@ -1,5 +1,5 @@ _base_ = 'coco_instance.py' -dataset_type = 'LvisDataset' +dataset_type = 'LVISDataset' data_root = 'data/lvis/' data = dict( samples_per_gpu=2, diff --git a/configs/lvis/mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis.py b/configs/lvis/mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis.py index 8d84e4ed9c6..180231d3071 100644 --- a/configs/lvis/mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis.py +++ b/configs/lvis/mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis.py @@ -11,7 +11,7 @@ score_thr=0.0001, # LVIS allows up to 300 max_per_img=300)) -dataset_type = 'LvisDataset' +dataset_type = 'LVISDataset' data_root = 'data/lvis/' img_norm_cfg = dict( mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) diff --git a/mmdet/datasets/__init__.py b/mmdet/datasets/__init__.py index 8dc885fec03..d10cf9c685a 100644 --- a/mmdet/datasets/__init__.py +++ b/mmdet/datasets/__init__.py @@ -4,7 +4,7 @@ from .custom import CustomDataset from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset, RepeatDataset) -from .lvis import LvisDataset +from .lvis import LVISDataset from .samplers import DistributedGroupSampler, DistributedSampler, GroupSampler from .voc import VOCDataset from .wider_face import WIDERFaceDataset @@ -12,7 +12,7 @@ __all__ = [ 'CustomDataset', 'XMLDataset', 'CocoDataset', 'VOCDataset', - 'CityscapesDataset', 'LvisDataset', 'GroupSampler', + 'CityscapesDataset', 'LVISDataset', 'GroupSampler', 'DistributedGroupSampler', 'DistributedSampler', 'build_dataloader', 'ConcatDataset', 'RepeatDataset', 'ClassBalancedDataset', 'WIDERFaceDataset', 'DATASETS', 'PIPELINES', 'build_dataset' From e1ffacbf3b722004eac8bdaab7c9757e0d682802 Mon Sep 17 00:00:00 2001 From: Jiarui XU Date: Thu, 4 Jun 2020 18:09:40 +0800 Subject: [PATCH 10/13] use open-mmlab lvis --- .isort.cfg | 2 +- mmdet/datasets/lvis.py | 11 +---------- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/.isort.cfg b/.isort.cfg index 0fff944ee29..9a52437bc31 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -3,6 +3,6 @@ line_length = 79 multi_line_output = 0 known_standard_library = setuptools known_first_party = mmdet -known_third_party = PIL,asynctest,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnx,pycocotools,pytest,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision +known_third_party = PIL,asynctest,cityscapesscripts,cv2,lvis,matplotlib,mmcv,numpy,onnx,pycocotools,pytest,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY diff --git a/mmdet/datasets/lvis.py b/mmdet/datasets/lvis.py index 08a07444f18..7ef43ab067c 100644 --- a/mmdet/datasets/lvis.py +++ b/mmdet/datasets/lvis.py @@ -4,6 +4,7 @@ import tempfile import numpy as np +from lvis import LVIS, LVISEval, LVISResults from mmcv.utils import print_log from terminaltables import AsciiTable @@ -265,11 +266,6 @@ class LVISDataset(CocoDataset): 'yoke_(animal_equipment)', 'zebra', 'zucchini') def load_annotations(self, ann_file): - try: - from lvis import LVIS - except ImportError: - raise ImportError('Please run "pip install lvis" to ' - 'install lvis first.') self.coco = LVIS(ann_file) assert not self.custom_classes, 'LVIS custom classes is not supported' self.cat_ids = self.coco.get_cat_ids() @@ -307,11 +303,6 @@ def evaluate(self, Returns: dict[str: float] """ - try: - from lvis import LVISResults, LVISEval - except ImportError: - raise ImportError('Please run "pip install lvis" to ' - 'install lvis first.') assert isinstance(results, list), 'results must be a list' assert len(results) == len(self), ( 'The length of results is not equal to the dataset len: {} != {}'. From e0f5a1e5963da0e6085fcfafb3c3f42db5178944 Mon Sep 17 00:00:00 2001 From: Jiarui XU Date: Thu, 4 Jun 2020 18:32:45 +0800 Subject: [PATCH 11/13] update travis --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 4f238b181ab..34f8995102c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -36,7 +36,7 @@ install: - pip install Pillow==6.2.2 # remove this line when torchvision>=0.5 - pip install torch==${TORCH} torchvision==${TORCHVISION} - pip install mmcv-nightly - - pip install "git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI" + - pip install "git+https://github.com/open-mmlab/cocoapi.git#subdirectory=PythonAPI" - pip install -r requirements.txt before_script: From c76b4369fa79c4d2ebe81486761b6dddc52b96d9 Mon Sep 17 00:00:00 2001 From: Jiarui XU Date: Thu, 4 Jun 2020 19:07:38 +0800 Subject: [PATCH 12/13] fixed install --- .isort.cfg | 2 +- mmdet/datasets/lvis.py | 11 ++++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/.isort.cfg b/.isort.cfg index 9a52437bc31..0fff944ee29 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -3,6 +3,6 @@ line_length = 79 multi_line_output = 0 known_standard_library = setuptools known_first_party = mmdet -known_third_party = PIL,asynctest,cityscapesscripts,cv2,lvis,matplotlib,mmcv,numpy,onnx,pycocotools,pytest,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision +known_third_party = PIL,asynctest,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnx,pycocotools,pytest,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY diff --git a/mmdet/datasets/lvis.py b/mmdet/datasets/lvis.py index 7ef43ab067c..e8a16f2ae0d 100644 --- a/mmdet/datasets/lvis.py +++ b/mmdet/datasets/lvis.py @@ -4,7 +4,6 @@ import tempfile import numpy as np -from lvis import LVIS, LVISEval, LVISResults from mmcv.utils import print_log from terminaltables import AsciiTable @@ -266,6 +265,11 @@ class LVISDataset(CocoDataset): 'yoke_(animal_equipment)', 'zebra', 'zucchini') def load_annotations(self, ann_file): + try: + from lvis import LVIS + except ImportError: + raise ImportError('Please follow install.md to ' + 'install open-mmlab forked cocoapi first.') self.coco = LVIS(ann_file) assert not self.custom_classes, 'LVIS custom classes is not supported' self.cat_ids = self.coco.get_cat_ids() @@ -303,6 +307,11 @@ def evaluate(self, Returns: dict[str: float] """ + try: + from lvis import LVISResults, LVISEval + except ImportError: + raise ImportError('Please follow install.md to ' + 'install open-mmlab forked cocoapi first.') assert isinstance(results, list), 'results must be a list' assert len(results) == len(self), ( 'The length of results is not equal to the dataset len: {} != {}'. From d6a59db3b3f951c167aa956e895958d0c834a1d7 Mon Sep 17 00:00:00 2001 From: Jiarui XU Date: Fri, 5 Jun 2020 20:06:06 +0800 Subject: [PATCH 13/13] make class balance as default --- configs/_base_/datasets/lvis_instance.py | 9 ++++++--- .../mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis.py | 13 +------------ 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/configs/_base_/datasets/lvis_instance.py b/configs/_base_/datasets/lvis_instance.py index 77cdd8c8373..7c9a30a349e 100644 --- a/configs/_base_/datasets/lvis_instance.py +++ b/configs/_base_/datasets/lvis_instance.py @@ -5,9 +5,12 @@ samples_per_gpu=2, workers_per_gpu=2, train=dict( - type=dataset_type, - ann_file=data_root + 'annotations/lvis_v0.5_train.json', - img_prefix=data_root + 'train2017/'), + type='ClassBalancedDataset', + oversample_thr=1e-3, + dataset=dict( + type=dataset_type, + ann_file=data_root + 'annotations/lvis_v0.5_train.json', + img_prefix=data_root + 'train2017/')), val=dict( type=dataset_type, ann_file=data_root + 'annotations/lvis_v0.5_val.json', diff --git a/configs/lvis/mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis.py b/configs/lvis/mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis.py index 180231d3071..9a9ebf0578a 100644 --- a/configs/lvis/mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis.py +++ b/configs/lvis/mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis.py @@ -11,8 +11,6 @@ score_thr=0.0001, # LVIS allows up to 300 max_per_img=300)) -dataset_type = 'LVISDataset' -data_root = 'data/lvis/' img_norm_cfg = dict( mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) train_pipeline = [ @@ -30,13 +28,4 @@ dict(type='DefaultFormatBundle'), dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), ] -data = dict( - train=dict( - _delete_=True, - type='ClassBalancedDataset', - oversample_thr=1e-3, - dataset=dict( - type=dataset_type, - ann_file=data_root + 'annotations/lvis_v0.5_train.json', - img_prefix=data_root + 'train2017/', - pipeline=train_pipeline))) +data = dict(train=dict(dataset=dict(pipeline=train_pipeline)))