diff --git a/README.md b/README.md index 8abe997..a85b4b4 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,95 @@ # SoCo -[NeurIPS 2021 Spotlight] Aligning Pretraining for Detection via Object-Level Contrastive Learning +[NeurIPS 2021 Spotlight] [Aligning Pretraining for Detection via Object-Level Contrastive Learning](https://arxiv.org/abs/2106.02637) + +By [Fangyun Wei](https://scholar.google.com/citations?user=-ncz2s8AAAAJ&hl=en)\*, [Yue Gao](https://yuegao.me)\*, [Zhirong Wu](https://scholar.google.com/citations?user=lH4zgcIAAAAJ&hl=en), [Han Hu](https://ancientmooner.github.io), [Stephen Lin](https://www.microsoft.com/en-us/research/people/stevelin/). +> \* Equal contribution. + + +## Introduction +Image-level contrastive representation learning has proven to be highly effective as a generic model for transfer learning. +Such generality for transfer learning, however, sacrifices specificity if we are interested in a certain downstream task. +We argue that this could be sub-optimal and thus advocate a design principle which encourages alignment between the self-supervised pretext task and the downstream task. +In this paper, we follow this principle with a pretraining method specifically designed for the task of object detection. +We attain alignment in the following three aspects: +1) object-level representations are introduced via selective search bounding boxes as object proposals; +2) the pretraining network architecture incorporates the same dedicated modules used in the detection pipeline (e.g. FPN); +3) the pretraining is equipped with object detection properties such as object-level translation invariance and scale invariance. +Our method, called Selective Object COntrastive learning (SoCo), achieves state-of-the-art results for transfer performance on COCO detection using a Mask R-CNN framework. 
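+
+The object-level design above can be pictured with a minimal, self-contained sketch (illustrative only, not the actual code in `contrast/`): object embeddings are pooled from FPN features with RoIAlign over the proposal boxes, and embeddings of the same proposal seen in two augmented views are pulled together with a BYOL-style loss. All function names, strides and dimensions below are assumptions made for the example; only `torchvision.ops.roi_align` is a real API.
+```python
+import torch
+import torch.nn.functional as F
+from torchvision.ops import roi_align
+
+
+def box_embeddings(feat, boxes, spatial_scale, output_size=7):
+    """Pool one object-level embedding per proposal from a single FPN level.
+
+    feat:  (N, C, H, W) feature map, e.g. a P4-like level at stride 1/spatial_scale.
+    boxes: (K, 5) rows of (batch_index, x1, y1, x2, y2) in image coordinates.
+    """
+    pooled = roi_align(feat, boxes, output_size=output_size,
+                       spatial_scale=spatial_scale, aligned=True)  # (K, C, 7, 7)
+    return pooled.mean(dim=(2, 3))                                 # (K, C)
+
+
+def byol_box_loss(online_pred, target_proj):
+    """BYOL-style loss between embeddings of the same boxes in two views."""
+    p = F.normalize(online_pred, dim=-1)
+    z = F.normalize(target_proj.detach(), dim=-1)  # no gradient to the target branch
+    return (2.0 - 2.0 * (p * z).sum(dim=-1)).mean()
+
+
+# Toy usage: one image, a stride-16 feature map from each view, three proposals.
+feat_v1 = torch.randn(1, 256, 14, 14)
+feat_v2 = torch.randn(1, 256, 14, 14)
+boxes = torch.tensor([[0, 10.0, 20.0, 120.0, 180.0],
+                      [0,  5.0,  5.0,  60.0,  60.0],
+                      [0, 30.0, 40.0, 200.0, 210.0]])
+loss = byol_box_loss(box_embeddings(feat_v1, boxes, 1.0 / 16),
+                     box_embeddings(feat_v2, boxes, 1.0 / 16))
+```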
+
+
+### Architecture
+![](figures/overview.png)
+
+
+## Main results
+
+### SoCo pre-trained models
+| Model | Arch | Epochs | Scripts | Download |
+|:-----:|:------------:|:------:|:---------------------------------------------------:|:--------:|
+| SoCo | ResNet50-C4 | 100 | [SoCo_C4_100ep](tools/SoCo_C4_100ep.sh) | |
+| SoCo | ResNet50-C4 | 400 | [SoCo_C4_400ep](tools/SoCo_C4_400ep.sh) | |
+| SoCo | ResNet50-FPN | 100 | [SoCo_FPN_100ep](tools/SoCo_FPN_100ep.sh) | |
+| SoCo | ResNet50-FPN | 400 | [SoCo_FPN_400ep](tools/SoCo_FPN_400ep.sh) | |
+| SoCo* | ResNet50-FPN | 400 | [SoCo_FPN_Star_400ep](tools/SoCo_FPN_Star_400ep.sh) | |
+
+
+### Results on COCO with Mask R-CNN **R50-FPN**
+| Methods | Epoch | AP<sup>bb</sup> | AP<sup>bb</sup><sub>50</sub> | AP<sup>bb</sup><sub>75</sub> | AP<sup>mk</sup> | AP<sup>mk</sup><sub>50</sub> | AP<sup>mk</sup><sub>75</sub> | Detectron2 trained |
+|------------|-------|------|------|------|------|------|------|--------------------|
+| Scratch | - | 31.0 | 49.5 | 33.2 | 28.5 | 46.8 | 30.4 | |
+| Supervised | 90 | 38.9 | 59.6 | 42.7 | 35.4 | 56.5 | 38.1 | |
+| SoCo | 100 | 42.3 | 62.5 | 46.5 | 37.6 | 59.1 | 40.5 | |
+| SoCo | 400 | 43.0 | 63.3 | 47.1 | 38.2 | 60.2 | 41.0 | |
+| SoCo* | 400 | 43.2 | 63.5 | 47.4 | 38.4 | 60.2 | 41.4 | |
+
+
+### Results on COCO with Mask R-CNN **R50-C4**
+| Methods | Epoch | AP<sup>bb</sup> | AP<sup>bb</sup><sub>50</sub> | AP<sup>bb</sup><sub>75</sub> | AP<sup>mk</sup> | AP<sup>mk</sup><sub>50</sub> | AP<sup>mk</sup><sub>75</sub> | Detectron2 trained |
+|------------|-------|------|------|------|------|------|------|--------------------|
+| Scratch | - | 26.4 | 44.0 | 27.8 | 29.3 | 46.9 | 30.8 | |
+| Supervised | 90 | 38.2 | 58.2 | 41.2 | 33.3 | 54.7 | 35.2 | |
+| SoCo | 100 | 40.4 | 60.4 | 43.7 | 34.9 | 56.8 | 37.0 | |
+| SoCo | 400 | 40.9 | 60.9 | 44.3 | 35.3 | 57.5 | 37.3 | |
+
+
+## Get started
+### Requirements
+A [Dockerfile](docker/Dockerfile) is included; please refer to it for the required environment.
+
+
+### Prepare data with Selective Search
+1. Generate Selective Search proposals
+   ```shell
+   python selective_search/generate_imagenet_ss_proposals.py
+   ```
+2. Filter out invalid proposals with the filtering strategy
+   ```shell
+   python selective_search/filter_ss_proposals_json.py
+   ```
+3. Post-process images that have no valid proposals
+   ```shell
+   python selective_search/filter_ss_proposals_json_post_no_prop.py
+   ```
+
+
+### Pretrain with SoCo
+> The 100-epoch SoCo FPN setting is used as an example.
+```shell
+bash ./tools/SoCo_FPN_100ep.sh
+```
+
+
+### Finetune detector
+1. Copy the folder `detectron2_configs` to the root folder of `Detectron2`
+2. 
Train the detectors with `Detectron2` + + +## Citation +```bib +@article{wei2021aligning, + title={Aligning Pretraining for Detection via Object-Level Contrastive Learning}, + author={Wei, Fangyun and Gao, Yue and Wu, Zhirong and Hu, Han and Lin, Stephen}, + journal={arXiv preprint arXiv:2106.02637}, + year={2021} +} +``` \ No newline at end of file diff --git a/contrast/__init__.py b/contrast/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/contrast/data/__init__.py b/contrast/data/__init__.py new file mode 100644 index 0000000..ab515b5 --- /dev/null +++ b/contrast/data/__init__.py @@ -0,0 +1,138 @@ +# -------------------------------------------------------- +# SoCo +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Yue Gao +# -------------------------------------------------------- + + +import os + +import numpy as np +import torch.distributed as dist +from torch.utils.data import DataLoader, SubsetRandomSampler +from torch.utils.data.distributed import DistributedSampler + +from .dataset import (ImageFolder, ImageFolderImageAsymBboxAwareMulti3ResizeExtraJitter1, + ImageFolderImageAsymBboxAwareMultiJitter1, + ImageFolderImageAsymBboxAwareMultiJitter1Cutout, + ImageFolderImageAsymBboxCutout) +from .sampler import SubsetSlidingWindowSampler +from .transform import get_transform + + +def get_loader(aug_type, args, two_crop=False, prefix='train', return_coord=False): + transform = get_transform(args, aug_type, args.crop, args.image_size, crop1=args.crop1, + cutout_prob=args.cutout_prob, cutout_ratio=args.cutout_ratio, + image3_size=args.image3_size, image4_size=args.image4_size) + + # dataset + if args.zip: + if args.dataset == 'ImageNet': + train_ann_file = prefix + f"_{args.split_map}.txt" + train_prefix = prefix + ".zip@/" + if args.ss_props: + train_props_file = prefix + f"_{args.filter_strategy}.json" + elif args.rpn_props: + train_props_file = f"rpn_props_nms_{args.nms_threshold}.json" + elif args.dataset == 'COCO': # NOTE: for coco, we use same scheme as ImageNet + prefix = 'trainunlabeled2017' + train_ann_file = prefix + "_map.txt" + train_prefix = prefix + ".zip@/" + train_props_file = prefix + f"_{args.filter_strategy}.json" + elif args.dataset == 'Object365': + prefix = 'train' + train_ann_file = prefix + "_map.txt" + train_prefix = prefix + ".zip@/" + train_props_file = prefix + f"_{args.filter_strategy}.json" + elif args.dataset == 'ImageNetObject365': + prefix = 'train' + train_ann_file = prefix + "_map.txt" + train_prefix = prefix + ".zip@/" + train_props_file = prefix + f"_{args.filter_strategy}.json" + elif args.dataset == 'OpenImage': + prefix = 'train' + train_ann_file = prefix + "_map.txt" + train_prefix = prefix + ".zip@/" + train_props_file = prefix + f"_{args.filter_strategy}.json" + elif args.dataset == 'ImageNetOpenImage': + prefix = 'train' + train_ann_file = prefix + "_map.txt" + train_prefix = prefix + ".zip@/" + train_props_file = prefix + f"_{args.filter_strategy}.json" + elif args.dataset == 'ImageNetObject365OpenImage': + prefix = 'train' + train_ann_file = prefix + "_map.txt" + train_prefix = prefix + ".zip@/" + train_props_file = prefix + f"_{args.filter_strategy}.json" + else: + raise NotImplementedError('Dataset {} is not supported. 
We only support ImageNet'.format(args.dataset)) + + if aug_type == 'ImageAsymBboxCutout': + train_dataset = ImageFolderImageAsymBboxCutout(args.data_dir, train_ann_file, train_prefix, + train_props_file, image_size=args.image_size, select_strategy=args.select_strategy, + select_k=args.select_k, weight_strategy=args.weight_strategy, + jitter_ratio=args.jitter_ratio, padding_k=args.padding_k, + aware_range=args.aware_range, aware_start=args.aware_start, aware_end=args.aware_end, + max_tries=args.max_tries, + transform=transform, cache_mode=args.cache_mode, + dataset=args.dataset) + + elif aug_type == 'ImageAsymBboxAwareMultiJitter1': + train_dataset = ImageFolderImageAsymBboxAwareMultiJitter1(args.data_dir, train_ann_file, train_prefix, + train_props_file, image_size=args.image_size, select_strategy=args.select_strategy, + select_k=args.select_k, weight_strategy=args.weight_strategy, + jitter_ratio=args.jitter_ratio, padding_k=args.padding_k, + aware_range=args.aware_range, aware_start=args.aware_start, aware_end=args.aware_end, + max_tries=args.max_tries, + transform=transform, cache_mode=args.cache_mode, + dataset=args.dataset) + + elif aug_type == 'ImageAsymBboxAwareMultiJitter1Cutout': + train_dataset = ImageFolderImageAsymBboxAwareMultiJitter1Cutout(args.data_dir, train_ann_file, train_prefix, + train_props_file, image_size=args.image_size, select_strategy=args.select_strategy, + select_k=args.select_k, weight_strategy=args.weight_strategy, + jitter_ratio=args.jitter_ratio, padding_k=args.padding_k, + aware_range=args.aware_range, aware_start=args.aware_start, aware_end=args.aware_end, + max_tries=args.max_tries, + transform=transform, cache_mode=args.cache_mode, + dataset=args.dataset) + + elif aug_type == 'ImageAsymBboxAwareMulti3ResizeExtraJitter1': + train_dataset = ImageFolderImageAsymBboxAwareMulti3ResizeExtraJitter1(args.data_dir, train_ann_file, train_prefix, + train_props_file, image_size=args.image_size, image3_size=args.image3_size, + image4_size=args.image4_size, + select_strategy=args.select_strategy, + select_k=args.select_k, weight_strategy=args.weight_strategy, + jitter_ratio=args.jitter_ratio, padding_k=args.padding_k, + aware_range=args.aware_range, aware_start=args.aware_start, aware_end=args.aware_end, + max_tries=args.max_tries, + transform=transform, cache_mode=args.cache_mode, + dataset=args.dataset) + elif aug_type == 'NULL': + train_dataset = ImageFolder(args.data_dir, train_ann_file, train_prefix, + transform, two_crop=two_crop, cache_mode=args.cache_mode, + dataset=args.dataset, return_coord=return_coord) + else: + raise NotImplementedError + + else: + train_folder = os.path.join(args.data_dir, prefix) + train_dataset = ImageFolder(train_folder, transform=transform, two_crop=two_crop, return_coord=return_coord) + raise NotImplementedError + + # sampler + indices = np.arange(dist.get_rank(), len(train_dataset), dist.get_world_size()) + if args.use_sliding_window_sampler: + sampler = SubsetSlidingWindowSampler(indices, + window_stride=args.window_stride // dist.get_world_size(), + window_size=args.window_size // dist.get_world_size(), + shuffle_per_epoch=args.shuffle_per_epoch) + elif args.zip and args.cache_mode == 'part': + sampler = SubsetRandomSampler(indices) + else: + sampler = DistributedSampler(train_dataset) + + # # dataloader + return DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False, + num_workers=args.num_workers, pin_memory=True, sampler=sampler, drop_last=True) diff --git a/contrast/data/bboxs_utils.py 
b/contrast/data/bboxs_utils.py new file mode 100644 index 0000000..61f4b6e --- /dev/null +++ b/contrast/data/bboxs_utils.py @@ -0,0 +1,483 @@ +# -------------------------------------------------------- +# SoCo +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Yue Gao +# -------------------------------------------------------- + + +import math +import random + +import numpy as np +import torch +import torchvision.transforms as T +from mmdet.core import anchor_inside_flags + + +def overlap_image_bbox_w_id(image_bbox, bbox): + ix1, iy1, ix2, iy2 = image_bbox + bx1, by1, bx2, by2, bid = bbox + + if ix1 < bx2 and bx1 < ix2 and iy1 < by2 and by1 < iy2: + nx1, ny1, nx2, ny2 = max(ix1, bx1), max(iy1, by1), min(ix2, bx2), min(iy2, by2) + # minus the crop image x, y + nx1, ny1 = nx1 - ix1, ny1 - iy1 + nx2, ny2 = nx2 - ix1, ny2 - iy1 + return np.array([nx1, ny1, nx2, ny2, bid]) + else: + return None + + +def cal_overlap_params(params1, params2): + y11, x11, h1, w1, _, _ = params1 + y21, x21, h2, w2, _, _ = params2 + y12, x12 = y11 + h1 - 1, x11 + w1 - 1 + y22, x22 = y21 + h2 - 1, x21 + w2 - 1 + + if x11 < x22 and x21 < x12 and y11 < y22 and y21 < y12: + nx1, ny1, nx2, ny2 = max(x11, x21), max(y11, y21), min(x12, x22), min(y12, y22) + return np.array([nx1, ny1, nx2, ny2]) + else: + return None + + +def is_overlap(params_overlap, bbox): + px1, py1, px2, py2 = params_overlap + bx1, by1, bx2, by2, _ = bbox + + if px1 < bx2 and bx1 < px2 and py1 < by2 and by1 < py2: + return True + else: + return False + + +def clip_bboxs(bboxs, top, left, height, width): + clipped_bboxs_w_id_list = [] + clip_image_bbox = np.array([left, top, left + width - 1, top + height - 1]) + for cur_bbox in bboxs: + overlap_bbox = overlap_image_bbox_w_id(clip_image_bbox, cur_bbox) + # assert overlap_bbox is not None, print("clip_image_bbox", clip_image_bbox, "cur_bbox", cur_bbox) + if overlap_bbox is not None: + clipped_bboxs_w_id_list.append(overlap_bbox.reshape(1, overlap_bbox.shape[0])) + + if len(clipped_bboxs_w_id_list) > 0: + clipped_bboxs_w_id = np.concatenate(clipped_bboxs_w_id_list, axis=0) + else: + clipped_bboxs_w_id = np.array([[0, 0, width - 1, height - 1, 1]]) # just whole view + return clipped_bboxs_w_id + + +def clip_bboxs_in_jitter(bboxs, top, left, height, width): + clipped_bboxs_w_id_list = [] + clip_image_bbox = np.array([left, top, left + width - 1, top + height - 1]) + for cur_bbox in bboxs: + overlap_bbox = overlap_image_bbox_w_id(clip_image_bbox, cur_bbox) + if overlap_bbox is not None: + clipped_bboxs_w_id_list.append(overlap_bbox.reshape(1, overlap_bbox.shape[0])) + + if len(clipped_bboxs_w_id_list) > 0: + clipped_bboxs_w_id = np.concatenate(clipped_bboxs_w_id_list, axis=0) + return clipped_bboxs_w_id + else: + return None + + +def get_overlap_props(proposals, overlap_region): + if overlap_region is None: + return np.array([]) + common_props = [] + for prop in proposals: + if is_overlap(overlap_region, prop): + common_props.append(prop.reshape(1, prop.shape[0])) + if len(common_props) > 0: + common_props = np.concatenate(common_props, axis=0) + return common_props + else: + return np.array([]) + + +def resize_bboxs(clipped_bboxs_w_id, height, width, size): + bboxs_w_id = np.copy(clipped_bboxs_w_id).astype(float) # !!! 
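+    # Rescale the clipped box coordinates from the crop's (width, height) to the resized view `size` = (W, H).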
+ bboxs_w_id[:, 0] = bboxs_w_id[:, 0] / width * size[0] # x1 + bboxs_w_id[:, 1] = bboxs_w_id[:, 1] / height * size[1] # y1 + bboxs_w_id[:, 2] = bboxs_w_id[:, 2] / width * size[0] # x2 + bboxs_w_id[:, 3] = bboxs_w_id[:, 3] / height * size[1] # y2 + return bboxs_w_id + + +def resize_bboxs_vis(clipped_bboxs_w_id, size): + bboxs_w_id = np.copy(clipped_bboxs_w_id).astype(float) # !!! + bboxs_w_id[:, 0] = bboxs_w_id[:, 0] * size[0] # x1 + bboxs_w_id[:, 1] = bboxs_w_id[:, 1] * size[1] # y1 + bboxs_w_id[:, 2] = bboxs_w_id[:, 2] * size[0] # x2 + bboxs_w_id[:, 3] = bboxs_w_id[:, 3] * size[1] # y2 + return bboxs_w_id + + +def resize_bboxs_and_assign_labels(cropped_bboxs, height, width, size, bbox_size_range): + resized_bboxs_w_labels = torch.empty((cropped_bboxs.shape[0], cropped_bboxs.shape[1]+2), requires_grad=False) + resized_bboxs_vis = np.zeros_like(cropped_bboxs) + # 2 is used for p5, p4 one hot, determinted by bbox_size_range + min_size = bbox_size_range[0] # (32, 112, 224) + mid_size = bbox_size_range[1] + for i in range(cropped_bboxs.shape[0]): + cur_bbox = cropped_bboxs[i] + if cur_bbox[2] <= 0 or cur_bbox[3] <= 0: + # valid coordinates but we do not use it + resized_bboxs_w_labels[i] = torch.Tensor([0.0, 0.0, 1.0, 1.0, 0.0, 0.0]) + continue + x, y, w, h = cur_bbox + nx = x / width * size[0] + ny = y / height * size[1] + nw = w / width * size[0] + nh = h / height * size[1] + resized_bboxs_vis[i] = np.array([nx, ny, nw, nh]) + # NOTE: we swap x, y and w, h to align with image, and turn into 0~1 + short_side = min(nw, nh) + long_side = max(nw, nh) + if short_side >= min_size: + if long_side < mid_size: # use p4 + resized_bboxs_w_labels[i] = torch.Tensor([ny / size[1], nx / size[0], (ny + nh) / size[1], (nx + nw) / size[0], 1.0, 0.0]) + else: # use p5 + resized_bboxs_w_labels[i] = torch.Tensor([ny / size[1], nx / size[0], (ny + nh) / size[1], (nx + nw) / size[0], 0.0, 1.0]) + else: + resized_bboxs_w_labels[i] = torch.Tensor([0.0, 0.0, 1.0, 1.0, 0.0, 0.0]) + # print("resized bboxs size tensor", resized_bboxs_w_labels) + # print("resized bboxs size vis numpy", resized_bboxs_vis) + return resized_bboxs_w_labels, resized_bboxs_vis + + +def resized_crop_bboxs_assign_labels_with_ids(bboxs_w_id, top, left, height, width, size, bbox_size_range): + # bboxs_w_id, x, y, w, h, are in image size range, not in [0, 1] + cropped_bboxs_w_id = clip_bboxs(bboxs_w_id, top, left, height, width) + resized_cropped_bboxs, resized_bboxs_w_id_vis = resize_bboxs_and_assign_labels(cropped_bboxs_w_id, height, width, size, bbox_size_range) + # resize_cropped_bboxs_w_id is same shape with bboxs_w_id, we pad zeros to do data batching + return resized_cropped_bboxs, resized_bboxs_w_id_vis + + +def clip_and_resize_bboxs_w_ids(bboxs_w_id, top, left, height, width, size): + clipped_bboxs_w_id = clip_bboxs(bboxs_w_id, top, left, height, width) + resized_clipped_bboxs_w_id = resize_bboxs(clipped_bboxs_w_id, height, width, size) + return resized_clipped_bboxs_w_id + + +def get_common_bboxs_ids(bboxs1, bboxs2): + w = np.where(np.in1d(bboxs1[:, 4], bboxs2[:, 4]))[0] # intersect of ids, here bboxs ids are unique + common_bboxs_ids = bboxs1[w][:, 4] + return common_bboxs_ids + + +def jitter_bboxs(bboxs, common_bboxs_ids, jitter_ratio, pad_num, height, width): + common_indices = np.isin(bboxs[:, 4], common_bboxs_ids) + common_bboxs = bboxs[common_indices] + clipped_jittered_bboxs_list = [] + remaining_pad = pad_num + while remaining_pad > 0: + selected_bboxs = common_bboxs[np.random.choice(common_bboxs.shape[0], remaining_pad)] + 
jitters = np.random.uniform(low=-jitter_ratio, high=jitter_ratio, size=(remaining_pad, 4)) + + jittered_bboxs = np.copy(selected_bboxs).astype(float) + selected_bboxs_w = selected_bboxs[:, 2] - selected_bboxs[:, 0] + 1 # w = x2 - x1 + 1 + selected_bboxs_h = selected_bboxs[:, 3] - selected_bboxs[:, 1] + 1 # h = y2 - y1 + 1 + jittered_w = selected_bboxs_w + jitters[:, 2] * selected_bboxs_w # nw = w + j * w + jittered_h = selected_bboxs_h + jitters[:, 3] * selected_bboxs_h # nh = h + j * h + jittered_bboxs[:, 0] = selected_bboxs[:, 0] + jitters[:, 0] * selected_bboxs_w # nx1 = x1 + j * w + jittered_bboxs[:, 1] = selected_bboxs[:, 1] + jitters[:, 1] * selected_bboxs_h # ny1 = y1 + j * h + jittered_bboxs[:, 2] = jittered_bboxs[:, 0] + jittered_w - 1 # nx2 = nx1 + nw - 1 + jittered_bboxs[:, 3] = jittered_bboxs[:, 1] + jittered_h - 1 # ny2 = ny1 + nh - 1 + + clipped_jittered_bboxs = clip_bboxs_in_jitter(jittered_bboxs, 0, 0, height, width) + if clipped_jittered_bboxs is not None and clipped_jittered_bboxs.shape[0] > 0: + clipped_jittered_bboxs_list.append(clipped_jittered_bboxs) + remaining_pad -= clipped_jittered_bboxs.shape[0] + + padded_clipped_jittered_bboxs = np.concatenate(clipped_jittered_bboxs_list, axis=0) + + return padded_clipped_jittered_bboxs + + +def jitter_props(selected_image_props, jitter_prob, jitter_ratio): + jittered_props = [] + for prop in selected_image_props: + jitter_r = random.random() + jittered_prop = np.copy(prop).astype(float) + if jitter_r < jitter_prob: + jitter = np.random.uniform(low=-jitter_ratio, high=jitter_ratio, size=(4, )) + w = prop[2] - prop[0] + 1 + h = prop[3] - prop[1] + 1 + + jittered_w = w + jitter[2] * w # nw = w + j * w + jittered_h = h + jitter[3] * h # nh = h + j * h + + jittered_prop[0] = prop[0] + jitter[0] * w # nx1 = x1 + j * w + jittered_prop[1] = prop[1] + jitter[1] * h # ny1 = y1 + j * h + + jittered_prop[2] = jittered_prop[0] + jittered_w - 1 # nx2 = nx1 + nw - 1 + jittered_prop[3] = jittered_prop[1] + jittered_h - 1 # ny2 = ny1 + nh - 1 + + jittered_prop = jittered_prop.reshape(1, jittered_prop.shape[0]) + jittered_props.append(jittered_prop) + + if len(jittered_props) > 0: + jittered_props_np = np.concatenate(jittered_props, axis=0) + return jittered_props_np + else: + return np.array([]) + + +def random_generate_props(image_size, r=3.0, min_ratio=0.3, max_ratio=0.8, max_props=32): + props = [] + for _ in range(max_props): + sqrt_area = math.sqrt(image_size[0] * image_size[1]) + target_sqrt_area = random.uniform(min_ratio, max_ratio) * sqrt_area + target_area = target_sqrt_area * target_sqrt_area + aspect_ratio = random.uniform(1/r, r) + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + if w < 32 or h < 32 or w > image_size[0] or h > image_size[1]: + continue + center_x = random.randint(w // 2, image_size[0]- w // 2) + center_y = random.randint(h // 2, image_size[1]- h // 2) + x1 = max(center_x, 0) + x2 = min(center_x + w // 2, image_size[0]) + y1 = max(center_y, 0) + y2 = min(center_y + h // 2, image_size[1]) + prop = np.array([x1, y1, x2, y2]).reshape((1, 4)) + props.append(prop) + if len(props) > 0: + proposals = np.concatenate(props) + else: + proposals = np.array([]) + return proposals + + +def pad_bboxs_with_common(bboxs, common_bboxs_ids, jitter_ratio, pad_num, height, width): + common_indices = np.isin(bboxs[:, 4], common_bboxs_ids) + common_bboxs = bboxs[common_indices] + selected_bboxs = common_bboxs[np.random.choice(common_bboxs.shape[0], pad_num)] + return 
selected_bboxs + + +def get_correspondence_matrix(bboxs1, bboxs2): + # intersect of ids, here bboxs ids can be duplicate + assert bboxs1.shape[0] == bboxs2.shape[0] + L = bboxs1.shape[0] + bboxs1_ids = bboxs1[:, 4] + bboxs2_ids = bboxs2[:, 4] + bboxs1_ids = np.reshape(bboxs1_ids, (L, 1)) + bboxs2_ids = np.reshape(bboxs2_ids, (1, L)) + bboxs1_m = np.tile(bboxs1_ids, (1, L)) + bboxs2_m = np.tile(bboxs2_ids, (L, 1)) + correspondence_matrix = bboxs1_m == bboxs2_m + correspondence_matrix = correspondence_matrix.astype(float) + correspondence_matrix = torch.Tensor(correspondence_matrix) + return correspondence_matrix + + +def calculate_centerness_targets_from_bbox(bbox): + x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3] + w = x2 - x1 + 1 + h = y2 - y1 + 1 + left = np.tile(np.array([i for i in range(w)]).reshape(w, 1), (1, h)) + right = np.tile(np.array([w - i - 1 for i in range(w)]).reshape(w, 1), (1, h)) + top = np.tile(np.array([i for i in range(h)]).reshape(1, h), (w, 1)) + bottom = np.tile(np.array([h - i - 1 for i in range(h)]).reshape(1, h), (w, 1)) + + left_right_min = np.minimum(left, right) + left_right_max = np.maximum(left, right) + 1e-6 + top_bottom_min = np.minimum(top, bottom) + top_bottom_max = np.maximum(top, bottom) + 1e-6 + + centerness_targets = (left_right_min / left_right_max) * (top_bottom_min / top_bottom_max) + centerness_targets = np.sqrt(centerness_targets) + return centerness_targets + + +def calculate_weight_map_bboxs(bboxs, size, weight_strategy): + if weight_strategy == 'no_weights': + weight_map = np.ones(size).astype(float) + return weight_map + + weight_map = np.zeros(size).astype(float) + for bbox in bboxs: + # print("calculate_weight_map_bboxs bbox", bbox) + x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3] + w = x2 - x1 + 1 + h = y2 - y1 + 1 + if w < 1 or h < 1: + continue + if weight_strategy == 'bbox': + weight_map[x1:x2+1, y1:y2+1] += 1.0 + elif weight_strategy == 'center': + center_x = (x1 + x2) // 2 # ?? + center_y = (y1 + y2) // 2 # ?? 
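+            # 'center' strategy: only the (floored) box-center pixel accumulates weight.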
+ weight_map[center_x, center_y] += 1.0 + elif weight_strategy == 'gaussian': + centerness_targets = calculate_centerness_targets_from_bbox(bbox) + weight_map[x1:x2+1, y1:y2+1] += centerness_targets + else: + raise NotImplementedError + + return weight_map + + +def proposals_to_tensor(props): + props = props.astype(float) + props_tensor = torch.Tensor(props) + return props_tensor + + +def bboxs_to_tensor(bboxs, params): + """ x1y1x2y2, -> 0, 1 + """ + _, _, height, width, _, _ = params + bboxs = bboxs.astype(float) + bboxs_new = np.copy(bboxs) # keep ids + + bboxs_new[:, 0] = bboxs_new[:, 0] / width # x1 + bboxs_new[:, 1] = bboxs_new[:, 1] / height # y1 + bboxs_new[:, 2] = bboxs_new[:, 2] / width # x2 + bboxs_new[:, 3] = bboxs_new[:, 3] / height # y2 + + bboxs_tensor = torch.Tensor(bboxs_new) + + return bboxs_tensor + + +def bboxs_to_tensor_dynamic(bboxs, params, dynamic_params, image_size): + """ x1y1x2y2, -> 0, 1 + """ + _, _, height, width, _, _ = params + dynamic_resize = dynamic_params[3] + bboxs = bboxs.astype(float) + bboxs_new = np.copy(bboxs) # keep ids + + # raw_size -> raw_ratio -> dynamic_size -> padded_ratio + + bboxs_new[:, 0] = bboxs_new[:, 0] / width * dynamic_resize[0] / image_size[0] # x1 + bboxs_new[:, 1] = bboxs_new[:, 1] / height * dynamic_resize[1] / image_size[1] # y1 + bboxs_new[:, 2] = bboxs_new[:, 2] / width * dynamic_resize[0] / image_size[0] # x2 + bboxs_new[:, 3] = bboxs_new[:, 3] / height * dynamic_resize[1] / image_size[1] # y2 + + bboxs_tensor = torch.Tensor(bboxs_new) + + return bboxs_tensor + + +def weight_to_tensor(weight): + """ w, h -> h, w + """ + weight_tensor = torch.Tensor(weight) + + weight_tensor = weight_tensor.unsqueeze(0) # channel 1 + weight_tensor = torch.transpose(weight_tensor, 1, 2) + return weight_tensor + + +def assign_bboxs_to_feature_map(resized_bboxs, aware_range, aware_start, aware_end, not_used_value=-1): + """aware_range + """ + L = resized_bboxs.shape[0] + P = aware_end - aware_start + assert P > 0 + + bboxs_id_assigned = np.ones((P, L)) * not_used_value # the not used value is use to be different from 0 + for i, bbox in enumerate(resized_bboxs): + x1, y1, x2, y2, bid = bbox + w = x2 - x1 + 1 + h = y2 - y1 + 1 + size = math.sqrt(h * w) + for j in range(aware_start, aware_end): + if size <= aware_range[j]: + bboxs_id_assigned[j - aware_start, i] = bid + # print(f"{bid} assigned to feature {j - aware_start}") + break # assign bbox to only one feature map + + bboxs_id_assigned = bboxs_id_assigned.reshape((P * L, )) + return bboxs_id_assigned + + +def get_aware_correspondence_matrix(bboxs_id_assigned1, bboxs_id_assigned2): + PL = bboxs_id_assigned1.shape[0] + bboxs1_ids = np.reshape(bboxs_id_assigned1, (PL, 1)) + bboxs2_ids = np.reshape(bboxs_id_assigned2, (1, PL)) + bboxs1_m = np.tile(bboxs1_ids, (1, PL)) + bboxs2_m = np.tile(bboxs2_ids, (PL, 1)) + correspondence_matrix = bboxs1_m == bboxs2_m + correspondence_matrix = correspondence_matrix.astype(float) + correspondence_matrix = torch.Tensor(correspondence_matrix) + return correspondence_matrix + + +def get_aware_correspondence_matrix_torch(bboxs_id_assigned1, bboxs_id_assigned2): + PL = bboxs_id_assigned1.size(0) + bboxs1_ids = torch.reshape(bboxs_id_assigned1, (PL, 1)) + bboxs2_ids = torch.reshape(bboxs_id_assigned2, (1, PL)) + bboxs1_m = bboxs1_ids.repeat(1, PL) + bboxs2_m = bboxs2_ids.repeat(PL, 1) + correspondence_matrix = bboxs1_m == bboxs2_m + correspondence_matrix = correspondence_matrix.float() + + return correspondence_matrix + + +def visualize_image_tensor(image_tensor, 
visualize_name): + image_pil = T.functional.to_pil_image(image_tensor) + path = f'self_det/visualization/{visualize_name}.jpg' + image_pil.save(path) + + +def assign_gt_bboxs_to_feature_map_with_anchors(anchor_generator, assigner, gt_bboxs, view_size, img_meta, levels=4, not_used_value=-1): + # assign gt bbox based on number of anchors in the feature map + # SINGLE image version + + P = levels + L = gt_bboxs.size(0) # each image number of gts + + featmap_sizes = [] + for l in range(levels): + feat_size_0 = int(math.ceil(view_size[0]/2**(l+2))) # p2 - p5 + feat_size_1 = int(math.ceil(view_size[1]/2**(l+2))) # p2 - p5 + feat_size_tensor = torch.tensor((feat_size_0, feat_size_1)) + featmap_sizes.append(feat_size_tensor) + + # we compute multi level anchors and valid flags + multi_level_anchors = anchor_generator.grid_anchors(featmap_sizes, device='cpu') + multi_level_flags = anchor_generator.valid_flags(featmap_sizes, img_meta['pad_shape'], device='cpu') + + num_level_anchors = [anchors.size(0) for anchors in multi_level_anchors] + num_level_anchors_agg = [0] + for level_num in num_level_anchors: + num_level_anchors_agg.append(num_level_anchors_agg[-1] + level_num) + + # concat all level anchors to a single tensor + flat_anchors = torch.cat(multi_level_anchors) + flat_valid_flags = torch.cat(multi_level_flags) + + inside_flags = anchor_inside_flags(flat_anchors, flat_valid_flags, + img_meta['img_shape'][:2], + allowed_border=-1) # -1 is come from the default value of train_cfg.rpn.allowed_border + + anchors = flat_anchors[inside_flags, :] + + + gt_bboxs_coord = gt_bboxs[:, :4].clone() # must clone + cur_img_assign_result = assigner.assign(anchors, gt_bboxs_coord, None, None) + assigned_gts = cur_img_assign_result.gt_inds + + bboxs_id_assigned = torch.ones((P, L)) * not_used_value + for gt_idx in range(L): + cur_gt_assign = assigned_gts == (gt_idx+1) # the assign results index are 1-based + cur_gt_level_assign_sum = [None] * P + for level_i in range(P): + level_start = num_level_anchors_agg[level_i] + level_end = num_level_anchors_agg[level_i+1] + cur_gt_level_assign_sum[level_i] = torch.sum(cur_gt_assign[level_start:level_end]).item() + cur_gt_assign_level = cur_gt_level_assign_sum.index(max(cur_gt_level_assign_sum)) + + bboxs_id_assigned[cur_gt_assign_level, gt_idx] = gt_bboxs[gt_idx, 4] # assign bbox id + + bboxs_id_assigned = bboxs_id_assigned.reshape((P * L, )) + + return bboxs_id_assigned diff --git a/contrast/data/dataset.py b/contrast/data/dataset.py new file mode 100644 index 0000000..abebd24 --- /dev/null +++ b/contrast/data/dataset.py @@ -0,0 +1,943 @@ +import io +import json +import logging +import os +import time +from matplotlib.pyplot import winter + +import numpy as np +import torch +import torch.distributed as dist +import torch.utils.data as data +from PIL import Image + +from .bboxs_utils import (cal_overlap_params, clip_bboxs, + get_common_bboxs_ids, get_overlap_props, get_correspondence_matrix, + pad_bboxs_with_common, bboxs_to_tensor, resize_bboxs, + assign_bboxs_to_feature_map, get_aware_correspondence_matrix, jitter_props) +from .props_utils import select_props, convert_props +from .selective_search_utils import append_prop_id +from .zipreader import ZipReader, is_zip_path + + +def has_file_allowed_extension(filename, extensions): + """Checks if a file is an allowed extension. 
+ Args: + filename (string): path to a file + Returns: + bool: True if the filename ends with a known image extension + """ + filename_lower = filename.lower() + return any(filename_lower.endswith(ext) for ext in extensions) + + +def find_classes(dir): + classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] + classes.sort() + class_to_idx = {classes[i]: i for i in range(len(classes))} + return classes, class_to_idx + + +def make_dataset(dir, class_to_idx, extensions): + images = [] + dir = os.path.expanduser(dir) + for target in sorted(os.listdir(dir)): + d = os.path.join(dir, target) + if not os.path.isdir(d): + continue + + for root, _, fnames in sorted(os.walk(d)): + for fname in sorted(fnames): + if has_file_allowed_extension(fname, extensions): + path = os.path.join(root, fname) + item = (path, class_to_idx[target]) + images.append(item) + + return images + + +def make_dataset_with_ann(ann_file, img_prefix, extensions, dataset='ImageNet'): + images = [] + + with open(ann_file, "r") as f: + contents = f.readlines() + for line_str in contents: + path_contents = [c for c in line_str.split()] + im_file_name = path_contents[0] + class_index = int(path_contents[1]) + + assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions + item = (os.path.join(img_prefix, im_file_name), class_index) + + images.append(item) + + return images + + +def make_props_dataset_with_ann(ann_file, props_file, select_strategy, select_k, dataset='ImageNet', rpn_props=False, rpn_score_thres=0.5): + with open(props_file, "r") as f: + props_dict = json.load(f) + # make ImageNet or VOC dataset + with open(ann_file, "r") as f: + contents = f.readlines() + images_props = [None] * len(contents) + for i, line_str in enumerate(contents): + path_contents = [c for c in line_str.split('\t')] + im_file_name = path_contents[0] + basename = os.path.basename(im_file_name).split('.')[0] + all_props = props_dict[basename] + converted_props = convert_props(all_props) + images_props[i] = converted_props # keep all propos + + del contents + del props_dict + return images_props + + +class DatasetFolder(data.Dataset): + """A generic data loader where the samples are arranged in this way: :: + root/class_x/xxx.ext + root/class_x/xxy.ext + root/class_x/xxz.ext + root/class_y/123.ext + root/class_y/nsdf3.ext + root/class_y/asd932_.ext + Args: + root (string): Root directory path. + loader (callable): A function to load a sample given its path. + extensions (list[string]): A list of allowed extensions. + transform (callable, optional): A function/transform that takes in + a sample and returns a transformed version. + E.g, ``transforms.RandomCrop`` for images. + target_transform (callable, optional): A function/transform that takes + in the target and transforms it. 
+ Attributes: + samples (list): List of (sample path, class_index) tuples + """ + + def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None, + cache_mode="no", dataset='ImageNet'): + # image folder mode + if ann_file == '': + _, class_to_idx = find_classes(root) + samples = make_dataset(root, class_to_idx, extensions) + # zip mode + else: + samples = make_dataset_with_ann(os.path.join(root, ann_file), + os.path.join(root, img_prefix), + extensions, + dataset) + + if len(samples) == 0: + raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n" + "Supported extensions are: " + ",".join(extensions))) + + self.root = root + self.loader = loader + self.extensions = extensions + + self.samples = samples + self.labels = [y_1k for _, y_1k in samples] + self.classes = list(set(self.labels)) + + self.transform = transform + self.target_transform = target_transform + + self.cache_mode = cache_mode + if self.cache_mode != "no": + self.init_cache() + + def init_cache(self): + assert self.cache_mode in ["part", "full"] + n_sample = len(self.samples) + global_rank = dist.get_rank() + world_size = dist.get_world_size() + + samples_bytes = [None for _ in range(n_sample)] + start_time = time.time() + for index in range(n_sample): + if index % (n_sample//10) == 0: + t = time.time() - start_time + logger = logging.getLogger(__name__) + logger.info(f'cached {index}/{n_sample} takes {t:.2f}s per block') + start_time = time.time() + path, target = self.samples[index] + if self.cache_mode == "full" or index % world_size == global_rank: + samples_bytes[index] = (ZipReader.read(path), target) + else: + samples_bytes[index] = (path, target) + self.samples = samples_bytes + + def __getitem__(self, index): + """ + Args: + index (int): Index + Returns: + tuple: (sample, target) where target is class_index of the target class. 
+ """ + path, target = self.samples[index] + sample = self.loader(path) + if self.transform is not None: + sample = self.transform(sample) + if self.target_transform is not None: + target = self.target_transform(target) + + return sample, target + + def __len__(self): + return len(self.samples) + + def __repr__(self): + fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' + fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) + fmt_str += ' Root Location: {}\n'.format(self.root) + tmp = ' Transforms (if any): ' + fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) + tmp = ' Target Transforms (if any): ' + fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) + return fmt_str + + +class DatasetFolderProps(data.Dataset): + def __init__(self, root, loader, extensions, ann_file='', img_prefix='', train_props_file='', + select_strategy='', select_k=0, + transform=None, target_transform=None, + cache_mode="no", dataset='ImageNet', rpn_props=False, rpn_score_thres=0.5): + # image folder mode + if ann_file == '': + _, class_to_idx = find_classes(root) + samples = make_dataset(root, class_to_idx, extensions) + # zip mode + else: + samples = make_dataset_with_ann(os.path.join(root, ann_file), + os.path.join(root, img_prefix), + extensions, + dataset) + samples_props = make_props_dataset_with_ann(os.path.join(root, ann_file), + os.path.join(root, train_props_file), + select_strategy, select_k, + dataset=dataset, rpn_props=rpn_props, rpn_score_thres=rpn_score_thres) + + if len(samples) == 0: + raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n" + "Supported extensions are: " + ",".join(extensions))) + if len(samples_props) == 0: + raise(RuntimeError("Not found the proposal files")) + + self.root = root + self.loader = loader + self.extensions = extensions + + self.samples = samples + self.samples_props = samples_props + self.labels = [y_1k for _, y_1k in samples] + self.classes = list(set(self.labels)) + + self.transform = transform + self.target_transform = target_transform + + self.cache_mode = cache_mode + if self.cache_mode != "no": + self.init_cache() + + def init_cache(self): + assert self.cache_mode in ["part", "full"] + n_sample = len(self.samples) + global_rank = dist.get_rank() + world_size = dist.get_world_size() + + samples_bytes = [None for _ in range(n_sample)] + start_time = time.time() + for index in range(n_sample): + if index % (n_sample//10) == 0: + t = time.time() - start_time + logger = logging.getLogger(__name__) + logger.info(f'cached {index}/{n_sample} takes {t:.2f}s per block') + start_time = time.time() + path, target = self.samples[index] + if self.cache_mode == "full" or index % world_size == global_rank: + samples_bytes[index] = (ZipReader.read(path), target) + else: + samples_bytes[index] = (path, target) + self.samples = samples_bytes + + def __getitem__(self, index): + """ + Args: + index (int): Index + Returns: + tuple: (sample, target) where target is class_index of the target class. 
+ """ + path, target = self.samples[index] + sample = self.loader(path) + if self.transform is not None: + sample = self.transform(sample) + if self.target_transform is not None: + target = self.target_transform(target) + + return sample, target + + def __len__(self): + return len(self.samples) + + def __repr__(self): + fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' + fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) + fmt_str += ' Root Location: {}\n'.format(self.root) + tmp = ' Transforms (if any): ' + fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) + tmp = ' Target Transforms (if any): ' + fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) + return fmt_str + + +IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] + + +def pil_loader(path): + # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) + if isinstance(path, bytes): + img = Image.open(io.BytesIO(path)) + elif is_zip_path(path): + data = ZipReader.read(path) + img = Image.open(io.BytesIO(data)) + else: + with open(path, 'rb') as f: + img = Image.open(f) + return img.convert('RGB') + + +def accimage_loader(path): + import accimage # type: ignore + try: + return accimage.Image(path) + except IOError: + # Potentially a decoding problem, fall back to PIL.Image + return pil_loader(path) + + +def default_img_loader(path): + from torchvision import get_image_backend + if get_image_backend() == 'accimage': + return accimage_loader(path) + else: + return pil_loader(path) + + + +class ImageFolder(DatasetFolder): + """A generic data loader where the images are arranged in this way: :: + root/dog/xxx.png + root/dog/xxy.png + root/dog/xxz.png + root/cat/123.png + root/cat/nsdf3.png + root/cat/asd932_.png + Args: + root (string): Root directory path. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + loader (callable, optional): A function to load an image given its path. + Attributes: + imgs (list): List of (image path, class_index) tuples + """ + + def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None, + loader=default_img_loader, cache_mode="no", dataset='ImageNet', + two_crop=False, return_coord=False): + super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS, + ann_file=ann_file, img_prefix=img_prefix, + transform=transform, target_transform=target_transform, + cache_mode=cache_mode, dataset=dataset) + self.imgs = self.samples + self.two_crop = two_crop + self.return_coord = return_coord + + def __getitem__(self, index): + """ + Args: + index (int): Index + Returns: + tuple: (image, target) where target is class_index of the target class. 
+ """ + path, target = self.samples[index] + image = self.loader(path) + if self.transform is not None: + if isinstance(self.transform, tuple) and len(self.transform) == 2: + img = self.transform[0](image) + else: + img = self.transform(image) + else: + img = image + if self.target_transform is not None: + target = self.target_transform(target) + + if self.two_crop: + if isinstance(self.transform, tuple) and len(self.transform) == 2: + img2 = self.transform[1](image) + else: + img2 = self.transform(image) + + if self.return_coord: + assert isinstance(img, tuple) + img, coord = img + + if self.two_crop: + img2, coord2 = img2 + return img, img2, coord, coord2, index, target + else: + return img, coord, index, target + else: + if isinstance(img, tuple): + img, coord = img + + if self.two_crop: + if isinstance(img2, tuple): + img2, coord2 = img2 + return img, img2, index, target + else: + return img, index, target + + +class ImageFolderImageAsymBboxCutout(DatasetFolderProps): + def __init__(self, root, ann_file='', img_prefix='', train_props_file='', + image_size=0, select_strategy='', select_k=0, weight_strategy='', + jitter_ratio=0.0, padding_k='', aware_range=[], aware_start=0, aware_end=4, + max_tries=0, + transform=None, target_transform=None, + loader=default_img_loader, cache_mode="no", dataset='ImageNet'): + super(ImageFolderImageAsymBboxCutout, self).__init__(root, loader, IMG_EXTENSIONS, + ann_file=ann_file, img_prefix=img_prefix, + train_props_file=train_props_file, + select_strategy=select_strategy, select_k=select_k, + transform=transform, target_transform=target_transform, + cache_mode=cache_mode, dataset=dataset) + self.imgs = self.samples + self.props = self.samples_props + self.select_strategy = select_strategy + self.select_k = select_k + self.weight_strategy = weight_strategy + self.jitter_ratio = jitter_ratio + self.padding_k = padding_k + self.view_size = (image_size, image_size) + self.debug = False + self.max_tries = max_tries + self.least_common = max(self.padding_k // 2, 1) + self.aware_range = aware_range + assert len(self.aware_range) == 5, 'Must give P2 P3 P4 P5 P6 size range' + self.aware_start = aware_start # starting from 0 means use p2 + self.aware_end = aware_end # end, if use P6 might be 5 + + def __getitem__(self, index): + """ + Args: + index (int): Index + Returns: + tuple: (image, target) where target is class_index of the target class. 
+ """ + path, target = self.samples[index] + image = self.loader(path) + image_size = image.size + image_proposals = self.props[index] # for cur image, numpy array type, [[x1, y1, x2, y2]] x2 = x1 + w - 1 + if image_proposals.shape[0] == 0: # if no proposals, insert one single proposal, the whole raw image + image_proposals = np.array([[0, 0, image_size[0] - 1, image_size[1] - 1]]) + + image_proposals_w_id = append_prop_id(image_proposals) # start from 1 + + assert len(self.transform) == 6 + # transform = (transform_whole_img, transform_img, transform_flip, transform_post_1, transform_post_2, transform_cutout) + + tries = 0 + least_common = self.least_common + + while tries < self.max_tries: + img, params = self.transform[0](image) # whole image resize + img2, params2 = self.transform[1](image) # random crop resize + + params_overlap = cal_overlap_params(params, params2) + overlap_props = get_overlap_props(image_proposals_w_id, params_overlap) + selected_image_props = select_props(overlap_props, self.select_strategy, self.select_k) # check paras are + + # TODO: ensure clipped bboxs width and height are greater than 32 + if selected_image_props.shape[0] >= least_common: # ok + break + least_common = max(least_common // 2, 1) + tries += 1 + + bboxs = clip_bboxs(selected_image_props, params[0], params[1], params[2], params[3]) + bboxs2 = clip_bboxs(selected_image_props, params2[0], params2[1], params2[2], params2[3]) + common_bboxs_ids = get_common_bboxs_ids(bboxs, bboxs2) + + + pad1 = self.padding_k - bboxs.shape[0] + if pad1 > 0: + # pad_bboxs = jitter_bboxs(bboxs, common_bboxs_ids, self.jitter_ratio, pad1, params[2], params[3]) + pad_bboxs = pad_bboxs_with_common(bboxs, common_bboxs_ids, self.jitter_ratio, pad1, params[2], params[3]) + bboxs = np.concatenate([bboxs, pad_bboxs], axis=0) + pad2 = self.padding_k - bboxs2.shape[0] + if pad2 > 0: + # pad_bboxs2 = jitter_bboxs(bboxs2, common_bboxs_ids, self.jitter_ratio, pad2, params2[2], params2[3]) + pad_bboxs2 = pad_bboxs_with_common(bboxs2, common_bboxs_ids, self.jitter_ratio, pad2, params2[2], params2[3]) + bboxs2 = np.concatenate([bboxs2, pad_bboxs2], axis=0) + correspondence = get_correspondence_matrix(bboxs, bboxs2) + + resized_bboxs = resize_bboxs(bboxs, params[2], params[3], self.view_size) + resized_bboxs2 = resize_bboxs(bboxs2, params2[2], params2[3], self.view_size) + resized_bboxs = resized_bboxs.astype(int) + resized_bboxs2 = resized_bboxs2.astype(int) + + bboxs = bboxs_to_tensor(bboxs, params) # x1y1x2y2 -> x1y1x2y2 (0, 1) + bboxs2 = bboxs_to_tensor(bboxs2, params2) # x1y1x2y2 -> x1y1x2y2 (0, 1) + + + img, bboxs, params = self.transform[2](img, bboxs, params) # flip + img2, bboxs2, params2 = self.transform[2](img2, bboxs2, params2) + + img1_1x = self.transform[3](img) # color + img2_1x = self.transform[4](img2) # color + + img2_1x_cut = self.transform[5](img2_1x, resized_bboxs2) # cutout + + return img1_1x, img2_1x_cut, bboxs, bboxs2, correspondence, index, target + + +class ImageFolderImageAsymBboxAwareMultiJitter1(DatasetFolderProps): + def __init__(self, root, ann_file='', img_prefix='', train_props_file='', + image_size=0, select_strategy='', select_k=0, weight_strategy='', + jitter_prob=0.0, jitter_ratio=0.0, + padding_k='', aware_range=[], aware_start=0, aware_end=4, max_tries=0, + transform=None, target_transform=None, + loader=default_img_loader, cache_mode="no", dataset='ImageNet'): + super(ImageFolderImageAsymBboxAwareMultiJitter1, self).__init__(root, loader, IMG_EXTENSIONS, + ann_file=ann_file, img_prefix=img_prefix, 
+ train_props_file=train_props_file, + select_strategy=select_strategy, select_k=select_k, + transform=transform, target_transform=target_transform, + cache_mode=cache_mode, dataset=dataset) + self.imgs = self.samples + self.props = self.samples_props + self.select_strategy = select_strategy + self.select_k = select_k + self.weight_strategy = weight_strategy + self.jitter_prob = jitter_prob + self.jitter_ratio = jitter_ratio + self.padding_k = padding_k + self.view_size = (image_size, image_size) + self.view_size_3 = (image_size//2, image_size//2) + self.debug = False + self.max_tries = max_tries + self.least_common = max(self.padding_k // 2, 1) + self.aware_range = aware_range + assert len(self.aware_range) == 5, 'Must give P2 P3 P4 P5 P6 size range' + self.aware_start = aware_start # starting from 0 means use p2 + self.aware_end = aware_end # end, if use P6 might be 5 + + def __getitem__(self, index): + """ + Args: + index (int): Index + Returns: + tuple: (image, target) where target is class_index of the target class. + """ + path, target = self.samples[index] + image = self.loader(path) + image_size = image.size + image_proposals = self.props[index] # for cur image, numpy array type, [[x1, y1, x2, y2]] x2 = x1 + w - 1 + if image_proposals.shape[0] == 0: # if no proposals, insert one single proposal, the whole raw image + image_proposals = np.array([[0, 0, image_size[0] - 1, image_size[1] - 1]]) + + image_proposals_w_id = append_prop_id(image_proposals) # start from 1 + + assert len(self.transform) == 7 + # transform = (transform_whole_img, transform_img, transform_img_small, transform_flip_flip, transform_flip, transform_post_1, transform_post_2) + + tries = 0 + least_common = self.least_common + + while tries < self.max_tries: + img, params = self.transform[0](image) # whole image resize + img2, params2 = self.transform[1](image) # random crop resize + img3, params3 = self.transform[2](image) # small random crop resize + + params_overlap12 = cal_overlap_params(params, params2) + overlap_props12 = get_overlap_props(image_proposals_w_id, params_overlap12) + selected_image_props12 = select_props(overlap_props12, self.select_strategy, self.select_k) # check paras are + + params_overlap13 = cal_overlap_params(params, params3) + overlap_props13 = get_overlap_props(image_proposals_w_id, params_overlap13) + selected_image_props13 = select_props(overlap_props13, self.select_strategy, self.select_k) # check paras are + + # TODO: ensure clipped bboxs width and height are greater than 32 + if selected_image_props12.shape[0] >= least_common and selected_image_props13.shape[0] >= least_common: # ok + break + least_common = max(least_common // 2, 1) + tries += 1 + + + jittered_selected_image_props12 = jitter_props(selected_image_props12, self.jitter_prob, self.jitter_ratio) + jittered_selected_image_props13 = jitter_props(selected_image_props13, self.jitter_prob, self.jitter_ratio) + + bboxs1_12 = clip_bboxs(jittered_selected_image_props12, params[0], params[1], params[2], params[3]) + bboxs1_13 = clip_bboxs(jittered_selected_image_props13, params[0], params[1], params[2], params[3]) + bboxs2 = clip_bboxs(selected_image_props12, params2[0], params2[1], params2[2], params2[3]) + bboxs3 = clip_bboxs(selected_image_props13, params3[0], params3[1], params3[2], params3[3]) + common_bboxs_ids12 = get_common_bboxs_ids(bboxs1_12, bboxs2) + common_bboxs_ids13 = get_common_bboxs_ids(bboxs1_13, bboxs3) + + + pad1_12 = self.padding_k - bboxs1_12.shape[0] + if pad1_12 > 0: + pad_bboxs1_12 = 
pad_bboxs_with_common(bboxs1_12, common_bboxs_ids12, self.jitter_ratio, pad1_12, params[2], params[3]) + bboxs1_12 = np.concatenate([bboxs1_12, pad_bboxs1_12], axis=0) + + pad1_13 = self.padding_k - bboxs1_13.shape[0] + if pad1_13 > 0: + pad_bboxs1_13 = pad_bboxs_with_common(bboxs1_13, common_bboxs_ids13, self.jitter_ratio, pad1_13, params[2], params[3]) + bboxs1_13 = np.concatenate([bboxs1_13, pad_bboxs1_13], axis=0) + + pad2 = self.padding_k - bboxs2.shape[0] + if pad2 > 0: + pad_bboxs2 = pad_bboxs_with_common(bboxs2, common_bboxs_ids12, self.jitter_ratio, pad2, params2[2], params2[3]) + bboxs2 = np.concatenate([bboxs2, pad_bboxs2], axis=0) + + pad3 = self.padding_k - bboxs3.shape[0] + if pad3 > 0: + pad_bboxs3 = pad_bboxs_with_common(bboxs3, common_bboxs_ids13, self.jitter_ratio, pad3, params3[2], params3[3]) + bboxs3 = np.concatenate([bboxs3, pad_bboxs3], axis=0) + + + resized_bboxs1_12 = resize_bboxs(bboxs1_12, params[2], params[3], self.view_size) + resized_bboxs1_13 = resize_bboxs(bboxs1_13, params[2], params[3], self.view_size) + resized_bboxs2 = resize_bboxs(bboxs2, params2[2], params2[3], self.view_size) + resized_bboxs3 = resize_bboxs(bboxs3, params3[2], params3[3], self.view_size_3) + resized_bboxs1_12 = resized_bboxs1_12.astype(int) + resized_bboxs1_13 = resized_bboxs1_13.astype(int) + resized_bboxs2 = resized_bboxs2.astype(int) + resized_bboxs3 = resized_bboxs3.astype(int) + + bboxs1_12_with_feature_assign = assign_bboxs_to_feature_map(resized_bboxs1_12, self.aware_range, self.aware_start, self.aware_end, -1) + bboxs1_13_with_feature_assign = assign_bboxs_to_feature_map(resized_bboxs1_13, self.aware_range, self.aware_start, self.aware_end, -1) + bboxs2_with_feature_assign = assign_bboxs_to_feature_map(resized_bboxs2, self.aware_range, self.aware_start, self.aware_end, -2) + bboxs3_with_feature_assign = assign_bboxs_to_feature_map(resized_bboxs3, self.aware_range, self.aware_start, self.aware_end, -3) + + + aware_corres_12 = get_aware_correspondence_matrix(bboxs1_12_with_feature_assign, bboxs2_with_feature_assign) + aware_corres_13 = get_aware_correspondence_matrix(bboxs1_13_with_feature_assign, bboxs3_with_feature_assign) + + bboxs1_12 = bboxs_to_tensor(bboxs1_12, params) # x1y1x2y2 -> x1y1x2y2 (0, 1) + bboxs1_13 = bboxs_to_tensor(bboxs1_13, params) # x1y1x2y2 -> x1y1x2y2 (0, 1) + bboxs2 = bboxs_to_tensor(bboxs2, params2) # x1y1x2y2 -> x1y1x2y2 (0, 1) + bboxs3 = bboxs_to_tensor(bboxs3, params3) # x1y1x2y2 -> x1y1x2y2 (0, 1) + + img, bboxs1_12, bboxs1_13, params = self.transform[3](img, bboxs1_12, bboxs1_13, params) # flip + img2, bboxs2, params2 = self.transform[4](img2, bboxs2, params2) # flip + img3, bboxs3, params3 = self.transform[4](img3, bboxs3, params3) # flip + + img1 = self.transform[5](img) # color + img2 = self.transform[6](img2) # color + img3 = self.transform[6](img3) # color + + return img1, img2, img3, bboxs1_12, bboxs1_13, bboxs2, bboxs3, aware_corres_12, aware_corres_13, index, target + + +class ImageFolderImageAsymBboxAwareMultiJitter1Cutout(DatasetFolderProps): + def __init__(self, root, ann_file='', img_prefix='', train_props_file='', + image_size=0, select_strategy='', select_k=0, weight_strategy='', + jitter_prob=0.0, jitter_ratio=0.0, + padding_k='', aware_range=[], aware_start=0, aware_end=4, max_tries=0, + transform=None, target_transform=None, + loader=default_img_loader, cache_mode="no", dataset='ImageNet'): + super(ImageFolderImageAsymBboxAwareMultiJitter1Cutout, self).__init__(root, loader, IMG_EXTENSIONS, + ann_file=ann_file, 
img_prefix=img_prefix, + train_props_file=train_props_file, + select_strategy=select_strategy, select_k=select_k, + transform=transform, target_transform=target_transform, + cache_mode=cache_mode, dataset=dataset) + self.imgs = self.samples + self.props = self.samples_props + self.select_strategy = select_strategy + self.select_k = select_k + self.weight_strategy = weight_strategy + self.jitter_prob = jitter_prob + self.jitter_ratio = jitter_ratio + self.padding_k = padding_k + self.view_size = (image_size, image_size) + self.view_size_3 = (image_size//2, image_size//2) + self.debug = False + self.max_tries = max_tries + self.least_common = max(self.padding_k // 2, 1) + self.aware_range = aware_range + assert len(self.aware_range) == 5, 'Must give P2 P3 P4 P5 P6 size range' + self.aware_start = aware_start # starting from 0 means use p2 + self.aware_end = aware_end # end, if use P6 might be 5 + + def __getitem__(self, index): + """ + Args: + index (int): Index + Returns: + tuple: (image, target) where target is class_index of the target class. + """ + path, target = self.samples[index] + image = self.loader(path) + image_size = image.size + image_proposals = self.props[index] # for cur image, numpy array type, [[x1, y1, x2, y2]] x2 = x1 + w - 1 + if image_proposals.shape[0] == 0: # if no proposals, insert one single proposal, the whole raw image + image_proposals = np.array([[0, 0, image_size[0] - 1, image_size[1] - 1]]) + + image_proposals_w_id = append_prop_id(image_proposals) # start from 1 + + assert len(self.transform) == 8 + # transform = (transform_whole_img, transform_img, transform_img_small, transform_flip_flip, transform_flip, transform_post_1, transform_post_2, transform_cutout) + + tries = 0 + least_common = self.least_common + + while tries < self.max_tries: + img, params = self.transform[0](image) # whole image resize + img2, params2 = self.transform[1](image) # random crop resize + img3, params3 = self.transform[2](image) # small random crop resize + + params_overlap12 = cal_overlap_params(params, params2) + overlap_props12 = get_overlap_props(image_proposals_w_id, params_overlap12) + selected_image_props12 = select_props(overlap_props12, self.select_strategy, self.select_k) # check paras are + + params_overlap13 = cal_overlap_params(params, params3) + overlap_props13 = get_overlap_props(image_proposals_w_id, params_overlap13) + selected_image_props13 = select_props(overlap_props13, self.select_strategy, self.select_k) # check paras are + + # TODO: ensure clipped bboxs width and height are greater than 32 + if selected_image_props12.shape[0] >= least_common and selected_image_props13.shape[0] >= least_common: # ok + break + least_common = max(least_common // 2, 1) + tries += 1 + + + jittered_selected_image_props12 = jitter_props(selected_image_props12, self.jitter_prob, self.jitter_ratio) + jittered_selected_image_props13 = jitter_props(selected_image_props13, self.jitter_prob, self.jitter_ratio) + + bboxs1_12 = clip_bboxs(jittered_selected_image_props12, params[0], params[1], params[2], params[3]) + bboxs1_13 = clip_bboxs(jittered_selected_image_props13, params[0], params[1], params[2], params[3]) + bboxs2 = clip_bboxs(selected_image_props12, params2[0], params2[1], params2[2], params2[3]) + bboxs3 = clip_bboxs(selected_image_props13, params3[0], params3[1], params3[2], params3[3]) + common_bboxs_ids12 = get_common_bboxs_ids(bboxs1_12, bboxs2) + common_bboxs_ids13 = get_common_bboxs_ids(bboxs1_13, bboxs3) + + + pad1_12 = self.padding_k - bboxs1_12.shape[0] + if pad1_12 > 
0: + pad_bboxs1_12 = pad_bboxs_with_common(bboxs1_12, common_bboxs_ids12, self.jitter_ratio, pad1_12, params[2], params[3]) + bboxs1_12 = np.concatenate([bboxs1_12, pad_bboxs1_12], axis=0) + + pad1_13 = self.padding_k - bboxs1_13.shape[0] + if pad1_13 > 0: + pad_bboxs1_13 = pad_bboxs_with_common(bboxs1_13, common_bboxs_ids13, self.jitter_ratio, pad1_13, params[2], params[3]) + bboxs1_13 = np.concatenate([bboxs1_13, pad_bboxs1_13], axis=0) + + pad2 = self.padding_k - bboxs2.shape[0] + if pad2 > 0: + pad_bboxs2 = pad_bboxs_with_common(bboxs2, common_bboxs_ids12, self.jitter_ratio, pad2, params2[2], params2[3]) + bboxs2 = np.concatenate([bboxs2, pad_bboxs2], axis=0) + + pad3 = self.padding_k - bboxs3.shape[0] + if pad3 > 0: + pad_bboxs3 = pad_bboxs_with_common(bboxs3, common_bboxs_ids13, self.jitter_ratio, pad3, params3[2], params3[3]) + bboxs3 = np.concatenate([bboxs3, pad_bboxs3], axis=0) + + + resized_bboxs1_12 = resize_bboxs(bboxs1_12, params[2], params[3], self.view_size) + resized_bboxs1_13 = resize_bboxs(bboxs1_13, params[2], params[3], self.view_size) + resized_bboxs2 = resize_bboxs(bboxs2, params2[2], params2[3], self.view_size) + resized_bboxs3 = resize_bboxs(bboxs3, params3[2], params3[3], self.view_size_3) + resized_bboxs1_12 = resized_bboxs1_12.astype(int) + resized_bboxs1_13 = resized_bboxs1_13.astype(int) + resized_bboxs2 = resized_bboxs2.astype(int) + resized_bboxs3 = resized_bboxs3.astype(int) + + bboxs1_12_with_feature_assign = assign_bboxs_to_feature_map(resized_bboxs1_12, self.aware_range, self.aware_start, self.aware_end, -1) + bboxs1_13_with_feature_assign = assign_bboxs_to_feature_map(resized_bboxs1_13, self.aware_range, self.aware_start, self.aware_end, -1) + bboxs2_with_feature_assign = assign_bboxs_to_feature_map(resized_bboxs2, self.aware_range, self.aware_start, self.aware_end, -2) + bboxs3_with_feature_assign = assign_bboxs_to_feature_map(resized_bboxs3, self.aware_range, self.aware_start, self.aware_end, -3) + + + aware_corres_12 = get_aware_correspondence_matrix(bboxs1_12_with_feature_assign, bboxs2_with_feature_assign) + aware_corres_13 = get_aware_correspondence_matrix(bboxs1_13_with_feature_assign, bboxs3_with_feature_assign) + + bboxs1_12 = bboxs_to_tensor(bboxs1_12, params) # x1y1x2y2 -> x1y1x2y2 (0, 1) + bboxs1_13 = bboxs_to_tensor(bboxs1_13, params) # x1y1x2y2 -> x1y1x2y2 (0, 1) + bboxs2 = bboxs_to_tensor(bboxs2, params2) # x1y1x2y2 -> x1y1x2y2 (0, 1) + bboxs3 = bboxs_to_tensor(bboxs3, params3) # x1y1x2y2 -> x1y1x2y2 (0, 1) + + img, bboxs1_12, bboxs1_13, params = self.transform[3](img, bboxs1_12, bboxs1_13, params) # flip + img2, bboxs2, params2 = self.transform[4](img2, bboxs2, params2) # flip + img3, bboxs3, params3 = self.transform[4](img3, bboxs3, params3) # flip + + img1 = self.transform[5](img) # color + img2 = self.transform[6](img2) # color + img3 = self.transform[6](img3) # color + + img2_cutout = self.transform[7](img2, resized_bboxs2) + img3_cutout = self.transform[7](img3, resized_bboxs3) + + return img1, img2_cutout, img3_cutout, bboxs1_12, bboxs1_13, bboxs2, bboxs3, aware_corres_12, aware_corres_13, index, target + + +class ImageFolderImageAsymBboxAwareMulti3ResizeExtraJitter1(DatasetFolderProps): + def __init__(self, root, ann_file='', img_prefix='', train_props_file='', + image_size=0, image3_size=0, image4_size=0, select_strategy='', select_k=0, weight_strategy='', + jitter_prob=0.0, jitter_ratio=0.0, + padding_k='', aware_range=[], aware_start=0, aware_end=4, max_tries=0, + transform=None, target_transform=None, + 
loader=default_img_loader, cache_mode="no", dataset='ImageNet'): + super(ImageFolderImageAsymBboxAwareMulti3ResizeExtraJitter1, self).__init__(root, loader, IMG_EXTENSIONS, + ann_file=ann_file, img_prefix=img_prefix, + train_props_file=train_props_file, + select_strategy=select_strategy, select_k=select_k, + transform=transform, target_transform=target_transform, + cache_mode=cache_mode, dataset=dataset) + self.imgs = self.samples + self.props = self.samples_props + self.select_strategy = select_strategy + self.select_k = select_k + self.weight_strategy = weight_strategy + self.jitter_prob = jitter_prob + self.jitter_ratio = jitter_ratio + self.padding_k = padding_k + self.view_size = (image_size, image_size) + self.view_size_3 = (image3_size, image3_size) + self.view_size_4 = (image4_size, image4_size) + assert image3_size > 0 + assert image4_size > 0 + self.debug = False + self.max_tries = max_tries + self.least_common = max(self.padding_k // 2, 1) + self.aware_range = aware_range + assert len(self.aware_range) == 5, 'Must give P2 P3 P4 P5 P6 size range' + self.aware_start = aware_start # starting from 0 means use p2 + self.aware_end = aware_end # end, if use P6 might be 5 + + def __getitem__(self, index): + """ + Args: + index (int): Index + Returns: + tuple: (image, target) where target is class_index of the target class. + """ + path, target = self.samples[index] + image = self.loader(path) + image_size = image.size + image_proposals = self.props[index] # for cur image, numpy array type, [[x1, y1, x2, y2]] x2 = x1 + w - 1 + if image_proposals.shape[0] == 0: # if no proposals, insert one single proposal, the whole raw image + image_proposals = np.array([[0, 0, image_size[0] - 1, image_size[1] - 1]]) + + image_proposals_w_id = append_prop_id(image_proposals) # start from 1 + + assert len(self.transform) == 8 + # transform = (transform_whole_img, transform_img, transform_img_small, transform_img_resize, transform_flip_flip, transform_flip, transform_post_1, transform_post_2) + + tries = 0 + least_common = self.least_common + + while tries < self.max_tries: + img, params = self.transform[0](image) # whole image resize + img2, params2 = self.transform[1](image) # random crop resize + img3, params3 = self.transform[2](image) # small random crop resize + + params_overlap12 = cal_overlap_params(params, params2) + overlap_props12 = get_overlap_props(image_proposals_w_id, params_overlap12) + selected_image_props12 = select_props(overlap_props12, self.select_strategy, self.select_k) # check paras are + + params_overlap13 = cal_overlap_params(params, params3) + overlap_props13 = get_overlap_props(image_proposals_w_id, params_overlap13) + selected_image_props13 = select_props(overlap_props13, self.select_strategy, self.select_k) # check paras are + + + # TODO: ensure clipped bboxs width and height are greater than 32 + if selected_image_props12.shape[0] >= least_common and selected_image_props13.shape[0] >= least_common: # ok + break + least_common = max(least_common // 2, 1) + tries += 1 + + img4 = self.transform[3](img2) # image4 are resized from image 2 + + jittered_selected_image_props12 = jitter_props(selected_image_props12, self.jitter_prob, self.jitter_ratio) + jittered_selected_image_props13 = jitter_props(selected_image_props13, self.jitter_prob, self.jitter_ratio) + + bboxs1_12 = clip_bboxs(jittered_selected_image_props12, params[0], params[1], params[2], params[3]) + bboxs1_13 = clip_bboxs(jittered_selected_image_props13, params[0], params[1], params[2], params[3]) + bboxs2 = 
clip_bboxs(selected_image_props12, params2[0], params2[1], params2[2], params2[3]) + bboxs3 = clip_bboxs(selected_image_props13, params3[0], params3[1], params3[2], params3[3]) + common_bboxs_ids12 = get_common_bboxs_ids(bboxs1_12, bboxs2) + common_bboxs_ids13 = get_common_bboxs_ids(bboxs1_13, bboxs3) + + + pad1_12 = self.padding_k - bboxs1_12.shape[0] + if pad1_12 > 0: + pad_bboxs1_12 = pad_bboxs_with_common(bboxs1_12, common_bboxs_ids12, self.jitter_ratio, pad1_12, params[2], params[3]) + bboxs1_12 = np.concatenate([bboxs1_12, pad_bboxs1_12], axis=0) + + pad1_13 = self.padding_k - bboxs1_13.shape[0] + if pad1_13 > 0: + pad_bboxs1_13 = pad_bboxs_with_common(bboxs1_13, common_bboxs_ids13, self.jitter_ratio, pad1_13, params[2], params[3]) + bboxs1_13 = np.concatenate([bboxs1_13, pad_bboxs1_13], axis=0) + + pad2 = self.padding_k - bboxs2.shape[0] + if pad2 > 0: + pad_bboxs2 = pad_bboxs_with_common(bboxs2, common_bboxs_ids12, self.jitter_ratio, pad2, params2[2], params2[3]) + bboxs2 = np.concatenate([bboxs2, pad_bboxs2], axis=0) + + pad3 = self.padding_k - bboxs3.shape[0] + if pad3 > 0: + pad_bboxs3 = pad_bboxs_with_common(bboxs3, common_bboxs_ids13, self.jitter_ratio, pad3, params3[2], params3[3]) + bboxs3 = np.concatenate([bboxs3, pad_bboxs3], axis=0) + + bboxs1_14 = np.copy(bboxs1_12) + + bboxs4 = np.copy(bboxs2) + params4 = np.copy(params2) + + resized_bboxs1_12 = resize_bboxs(bboxs1_12, params[2], params[3], self.view_size) + resized_bboxs1_13 = resize_bboxs(bboxs1_13, params[2], params[3], self.view_size) + resized_bboxs1_14 = resize_bboxs(bboxs1_14, params[2], params[3], self.view_size) + resized_bboxs2 = resize_bboxs(bboxs2, params2[2], params2[3], self.view_size) + resized_bboxs3 = resize_bboxs(bboxs3, params3[2], params3[3], self.view_size_3) + resized_bboxs4 = resize_bboxs(bboxs4, params4[2], params4[3], self.view_size_4) + resized_bboxs1_12 = resized_bboxs1_12.astype(int) + resized_bboxs1_13 = resized_bboxs1_13.astype(int) + resized_bboxs1_14 = resized_bboxs1_14.astype(int) + resized_bboxs2 = resized_bboxs2.astype(int) + resized_bboxs3 = resized_bboxs3.astype(int) + resized_bboxs4 = resized_bboxs4.astype(int) + + bboxs1_12_with_feature_assign = assign_bboxs_to_feature_map(resized_bboxs1_12, self.aware_range, self.aware_start, self.aware_end, -1) + bboxs1_13_with_feature_assign = assign_bboxs_to_feature_map(resized_bboxs1_13, self.aware_range, self.aware_start, self.aware_end, -1) + bboxs1_14_with_feature_assign = assign_bboxs_to_feature_map(resized_bboxs1_14, self.aware_range, self.aware_start, self.aware_end, -1) + bboxs2_with_feature_assign = assign_bboxs_to_feature_map(resized_bboxs2, self.aware_range, self.aware_start, self.aware_end, -2) + bboxs3_with_feature_assign = assign_bboxs_to_feature_map(resized_bboxs3, self.aware_range, self.aware_start, self.aware_end, -3) + bboxs4_with_feature_assign = assign_bboxs_to_feature_map(resized_bboxs4, self.aware_range, self.aware_start, self.aware_end, -4) + + + aware_corres_12 = get_aware_correspondence_matrix(bboxs1_12_with_feature_assign, bboxs2_with_feature_assign) + aware_corres_13 = get_aware_correspondence_matrix(bboxs1_13_with_feature_assign, bboxs3_with_feature_assign) + aware_corres_14 = get_aware_correspondence_matrix(bboxs1_14_with_feature_assign, bboxs4_with_feature_assign) + + bboxs1_12 = bboxs_to_tensor(bboxs1_12, params) # x1y1x2y2 -> x1y1x2y2 (0, 1) + bboxs1_13 = bboxs_to_tensor(bboxs1_13, params) # x1y1x2y2 -> x1y1x2y2 (0, 1) + bboxs1_14 = bboxs_to_tensor(bboxs1_14, params) # x1y1x2y2 -> x1y1x2y2 (0, 1) + bboxs2 = 
bboxs_to_tensor(bboxs2, params2) # x1y1x2y2 -> x1y1x2y2 (0, 1) + bboxs3 = bboxs_to_tensor(bboxs3, params3) # x1y1x2y2 -> x1y1x2y2 (0, 1) + bboxs4 = bboxs_to_tensor(bboxs4, params4) # x1y1x2y2 -> x1y1x2y2 (0, 1) + + img, bboxs1_12, bboxs1_13, bboxs1_14, params = self.transform[4](img, bboxs1_12, bboxs1_13, bboxs1_14, params) # flip + img2, bboxs2, params2 = self.transform[5](img2, bboxs2, params2) # flip + img3, bboxs3, params3 = self.transform[5](img3, bboxs3, params3) # flip + img4, bboxs4, params4 = self.transform[5](img4, bboxs4, params4) # flip + + img1 = self.transform[6](img) # color + img2 = self.transform[7](img2) # color + img3 = self.transform[7](img3) # color + img4 = self.transform[7](img4) # color + + return img1, img2, img3, img4, bboxs1_12, bboxs1_13, bboxs1_14, bboxs2, bboxs3, bboxs4, aware_corres_12, aware_corres_13, aware_corres_14, index, target diff --git a/contrast/data/props_utils.py b/contrast/data/props_utils.py new file mode 100644 index 0000000..6b58044 --- /dev/null +++ b/contrast/data/props_utils.py @@ -0,0 +1,43 @@ +# -------------------------------------------------------- +# SoCo +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Yue Gao +# -------------------------------------------------------- + + +import numpy as np + + +def select_props(all_props, select_strategy, select_k): + # all_props: numpy array + if select_k > all_props.shape[0]: # if we do not have k proposals, we just select all + select_strategy = 'none' + selected_proposals = None + if select_strategy == 'none': + selected_proposals = all_props + elif select_strategy == 'random': + selected_idx = np.random.choice(all_props.shape[0], size=select_k, replace=False) + selected_proposals = all_props[selected_idx] + elif select_strategy == 'top': + selected_proposals = all_props[:select_k] + elif select_strategy == 'area': + areas = np.zeros((all_props.shape[0], )) + for i in range(all_props.shape[0]): + area = all_props[i][2] * all_props[i][3] + areas[i] = area + areas_index = areas.argsort() + selected_proposals = all_props[areas_index[::-1]][:select_k] + else: + raise NotImplementedError + return selected_proposals + + +def convert_props(all_props): + all_props_np = np.array(all_props) + if all_props_np.shape[0] == 0: + return all_props_np + all_props_np = all_props_np[:, :4] + all_props_np[:, 2] = all_props_np[:, 0] + all_props_np[:, 2] - 1 # x2 = x1 + w - 1 + all_props_np[:, 3] = all_props_np[:, 1] + all_props_np[:, 3] - 1 # y2 = y1 + h - 1 + return all_props_np diff --git a/contrast/data/rand_augment.py b/contrast/data/rand_augment.py new file mode 100644 index 0000000..104809e --- /dev/null +++ b/contrast/data/rand_augment.py @@ -0,0 +1,448 @@ +""" AutoAugment and RandAugment +Implementation adapted from: + https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py +Papers: https://arxiv.org/abs/1805.09501, https://arxiv.org/abs/1906.11172, and https://arxiv.org/abs/1909.13719 +Hacked together by Ross Wightman +""" +import random +import math +import re +from PIL import Image, ImageOps, ImageEnhance +import PIL +import numpy as np + + +_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]]) + +_FILL = (128, 128, 128) + +# This signifies the max integer that the controller RNN could predict for the +# augmentation scheme. +_MAX_LEVEL = 10. 
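For reference, a minimal usage sketch of the proposal utilities added in `contrast/data/props_utils.py` above. The input boxes are made up for illustration, and the snippet assumes it is run from the repo root: `convert_props` turns selective-search `[x, y, w, h]` rows into inclusive `[x1, y1, x2, y2]` boxes, and `select_props` keeps at most `select_k` of them under the chosen strategy.

```python
# Illustrative only: fake selective-search rows in [x, y, w, h] format
# (any extra columns would be dropped by convert_props).
from contrast.data.props_utils import convert_props, select_props

raw_props = [[10, 20, 50, 40],
             [0, 0, 100, 80],
             [5, 5, 10, 10]]

boxes = convert_props(raw_props)      # inclusive corners: x2 = x1 + w - 1, y2 = y1 + h - 1
print(boxes)                          # rows become [10 20 59 59], [0 0 99 79], [5 5 14 14]

picked = select_props(boxes, select_strategy='random', select_k=2)
print(picked.shape)                   # (2, 4); if select_k > boxes.shape[0], all boxes are kept
```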
+ +_HPARAMS_DEFAULT = dict( + translate_const=250, + img_mean=_FILL, +) + +_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) + + +def _interpolation(kwargs): + interpolation = kwargs.pop('resample', Image.BILINEAR) + if isinstance(interpolation, (list, tuple)): + return random.choice(interpolation) + else: + return interpolation + + +def _check_args_tf(kwargs): + if 'fillcolor' in kwargs and _PIL_VER < (5, 0): + kwargs.pop('fillcolor') + kwargs['resample'] = _interpolation(kwargs) + + +def shear_x(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs) + + +def shear_y(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs) + + +def translate_x_rel(img, pct, **kwargs): + pixels = pct * img.size[0] + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) + + +def translate_y_rel(img, pct, **kwargs): + pixels = pct * img.size[1] + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) + + +def translate_x_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) + + +def translate_y_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) + + +def rotate(img, degrees, **kwargs): + _check_args_tf(kwargs) + if _PIL_VER >= (5, 2): + return img.rotate(degrees, **kwargs) + elif _PIL_VER >= (5, 0): + w, h = img.size + post_trans = (0, 0) + rotn_center = (w / 2.0, h / 2.0) + angle = -math.radians(degrees) + matrix = [ + round(math.cos(angle), 15), + round(math.sin(angle), 15), + 0.0, + round(-math.sin(angle), 15), + round(math.cos(angle), 15), + 0.0, + ] + + def transform(x, y, matrix): + (a, b, c, d, e, f) = matrix + return a * x + b * y + c, d * x + e * y + f + + matrix[2], matrix[5] = transform( + -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix + ) + matrix[2] += rotn_center[0] + matrix[5] += rotn_center[1] + return img.transform(img.size, Image.AFFINE, matrix, **kwargs) + else: + return img.rotate(degrees, resample=kwargs['resample']) + + +def auto_contrast(img, **__): + return ImageOps.autocontrast(img) + + +def invert(img, **__): + return ImageOps.invert(img) + + +def identity(img, **__): + return img + + +def equalize(img, **__): + return ImageOps.equalize(img) + + +def solarize(img, thresh, **__): + return ImageOps.solarize(img, thresh) + + +def solarize_add(img, add, thresh=128, **__): + lut = [] + for i in range(256): + if i < thresh: + lut.append(min(255, i + add)) + else: + lut.append(i) + if img.mode in ("L", "RGB"): + if img.mode == "RGB" and len(lut) == 256: + lut = lut + lut + lut + return img.point(lut) + else: + return img + + +def posterize(img, bits_to_keep, **__): + if bits_to_keep >= 8: + return img + return ImageOps.posterize(img, bits_to_keep) + + +def contrast(img, factor, **__): + return ImageEnhance.Contrast(img).enhance(factor) + + +def color(img, factor, **__): + return ImageEnhance.Color(img).enhance(factor) + + +def brightness(img, factor, **__): + return ImageEnhance.Brightness(img).enhance(factor) + + +def sharpness(img, factor, **__): + return ImageEnhance.Sharpness(img).enhance(factor) + + +def _randomly_negate(v): + """With 50% prob, negate the value""" + return -v if random.random() > 0.5 else v + + +def _rotate_level_to_arg(level, 
_hparams): + # range [-30, 30] + level = (level / _MAX_LEVEL) * 30. + level = _randomly_negate(level) + return level, + + +def _enhance_level_to_arg(level, _hparams): + # range [0.1, 1.9] + return (level / _MAX_LEVEL) * 1.8 + 0.1, + + +def _shear_level_to_arg(level, _hparams): + # range [-0.3, 0.3] + level = (level / _MAX_LEVEL) * 0.3 + level = _randomly_negate(level) + return level, + + +def _translate_abs_level_to_arg(level, hparams): + translate_const = hparams['translate_const'] + level = (level / _MAX_LEVEL) * float(translate_const) + level = _randomly_negate(level) + return level, + + +def _translate_rel_level_to_arg(level, _hparams): + # range [-0.45, 0.45] + level = (level / _MAX_LEVEL) * 0.45 + level = _randomly_negate(level) + return level, + + +def _posterize_original_level_to_arg(level, _hparams): + # As per original AutoAugment paper description + # range [4, 8], 'keep 4 up to 8 MSB of image' + return int((level / _MAX_LEVEL) * 4) + 4, + + +def _posterize_research_level_to_arg(level, _hparams): + # As per Tensorflow models research and UDA impl + # range [4, 0], 'keep 4 down to 0 MSB of original image' + return 4 - int((level / _MAX_LEVEL) * 4), + + +def _posterize_tpu_level_to_arg(level, _hparams): + # As per Tensorflow TPU EfficientNet impl + # range [0, 4], 'keep 0 up to 4 MSB of original image' + return int((level / _MAX_LEVEL) * 4), + + +def _solarize_level_to_arg(level, _hparams): + # range [0, 256] + return int((level / _MAX_LEVEL) * 256), + + +def _solarize_add_level_to_arg(level, _hparams): + # range [0, 110] + return int((level / _MAX_LEVEL) * 110), + + +LEVEL_TO_ARG = { + 'AutoContrast': None, + 'Equalize': None, + 'Invert': None, + 'Identity': None, + 'Rotate': _rotate_level_to_arg, + # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers + 'PosterizeOriginal': _posterize_original_level_to_arg, + 'PosterizeResearch': _posterize_research_level_to_arg, + 'PosterizeTpu': _posterize_tpu_level_to_arg, + 'Solarize': _solarize_level_to_arg, + 'SolarizeAdd': _solarize_add_level_to_arg, + 'Color': _enhance_level_to_arg, + 'Contrast': _enhance_level_to_arg, + 'Brightness': _enhance_level_to_arg, + 'Sharpness': _enhance_level_to_arg, + 'ShearX': _shear_level_to_arg, + 'ShearY': _shear_level_to_arg, + 'TranslateX': _translate_abs_level_to_arg, + 'TranslateY': _translate_abs_level_to_arg, + 'TranslateXRel': _translate_rel_level_to_arg, + 'TranslateYRel': _translate_rel_level_to_arg, +} + + +NAME_TO_OP = { + 'AutoContrast': auto_contrast, + 'Equalize': equalize, + 'Invert': invert, + 'Identity': identity, + 'Rotate': rotate, + 'PosterizeOriginal': posterize, + 'PosterizeResearch': posterize, + 'PosterizeTpu': posterize, + 'Solarize': solarize, + 'SolarizeAdd': solarize_add, + 'Color': color, + 'Contrast': contrast, + 'Brightness': brightness, + 'Sharpness': sharpness, + 'ShearX': shear_x, + 'ShearY': shear_y, + 'TranslateX': translate_x_abs, + 'TranslateY': translate_y_abs, + 'TranslateXRel': translate_x_rel, + 'TranslateYRel': translate_y_rel, +} + + +class AutoAugmentOp: + + def __init__(self, name, prob=0.5, magnitude=10, hparams=None): + hparams = hparams or _HPARAMS_DEFAULT + self.aug_fn = NAME_TO_OP[name] + self.level_fn = LEVEL_TO_ARG[name] + self.prob = prob + self.magnitude = magnitude + self.hparams = hparams.copy() + self.kwargs = dict( + fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL, + resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION, + ) + + # If 
magnitude_std is > 0, we introduce some randomness + # in the usually fixed policy and sample magnitude from a normal distribution + # with mean `magnitude` and std-dev of `magnitude_std`. + # NOTE This is my own hack, being tested, not in papers or reference impls. + self.magnitude_std = self.hparams.get('magnitude_std', 0) + + def __call__(self, img): + if random.random() > self.prob: + return img + magnitude = self.magnitude + if self.magnitude_std and self.magnitude_std > 0: + magnitude = random.gauss(magnitude, self.magnitude_std) + magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range + level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple() + return self.aug_fn(img, *level_args, **self.kwargs) + + +_RAND_TRANSFORMS = [ + 'AutoContrast', + 'Equalize', + 'Invert', + 'Rotate', + 'PosterizeTpu', + 'Solarize', + 'SolarizeAdd', + 'Color', + 'Contrast', + 'Brightness', + 'Sharpness', + 'ShearX', + 'ShearY', + 'TranslateXRel', + 'TranslateYRel', + # 'Cutout' # FIXME I implement this as random erasing separately +] + +_RAND_TRANSFORMS_CMC = [ + 'AutoContrast', + 'Identity', + 'Rotate', + 'Sharpness', + 'ShearX', + 'ShearY', + 'TranslateXRel', + 'TranslateYRel', + # 'Cutout' # FIXME I implement this as random erasing separately +] + + +# These experimental weights are based loosely on the relative improvements mentioned in paper. +# They may not result in increased performance, but could likely be tuned to so. +_RAND_CHOICE_WEIGHTS_0 = { + 'Rotate': 0.3, + 'ShearX': 0.2, + 'ShearY': 0.2, + 'TranslateXRel': 0.1, + 'TranslateYRel': 0.1, + 'Color': .025, + 'Sharpness': 0.025, + 'AutoContrast': 0.025, + 'Solarize': .005, + 'SolarizeAdd': .005, + 'Contrast': .005, + 'Brightness': .005, + 'Equalize': .005, + 'PosterizeTpu': 0, + 'Invert': 0, +} + + +def _select_rand_weights(weight_idx=0, transforms=None): + transforms = transforms or _RAND_TRANSFORMS + assert weight_idx == 0 # only one set of weights currently + rand_weights = _RAND_CHOICE_WEIGHTS_0 + probs = [rand_weights[k] for k in transforms] + probs /= np.sum(probs) + return probs + + +def rand_augment_ops(magnitude=10, hparams=None, transforms=None): + """rand augment ops for RGB images""" + hparams = hparams or _HPARAMS_DEFAULT + transforms = transforms or _RAND_TRANSFORMS + return [AutoAugmentOp( + name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms] + + +def rand_augment_ops_cmc(magnitude=10, hparams=None, transforms=None): + """rand augment ops for CMC images (removing color ops)""" + hparams = hparams or _HPARAMS_DEFAULT + transforms = transforms or _RAND_TRANSFORMS_CMC + return [AutoAugmentOp( + name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms] + + +class RandAugment: + def __init__(self, ops, num_layers=2, choice_weights=None): + self.ops = ops + self.num_layers = num_layers + self.choice_weights = choice_weights + + def __call__(self, img): + # no replacement when using weighted choice + ops = np.random.choice( + self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights) + for op in ops: + img = op(img) + return img + + +def rand_augment_transform(config_str, hparams, use_cmc=False): + """ + Create a RandAugment transform + :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by + dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). 
The remaining + sections, not order sepecific determine + 'm' - integer magnitude of rand augment + 'n' - integer num layers (number of transform ops selected per image) + 'w' - integer probabiliy weight index (index of a set of weights to influence choice of op) + 'mstd' - float std deviation of magnitude noise applied + Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 + 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 + :param hparams: Other hparams (kwargs) for the RandAugmentation scheme + :param use_cmc: Flag indicates removing augmentation for coloring ops. + :return: A PyTorch compatible Transform + """ + magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10) + num_layers = 2 # default to 2 ops per image + weight_idx = None # default to no probability weights for op choice + config = config_str.split('-') + assert config[0] == 'rand' + config = config[1:] + for c in config: + cs = re.split(r'(\d.*)', c) + if len(cs) < 2: + continue + key, val = cs[:2] + if key == 'mstd': + # noise param injected via hparams for now + hparams.setdefault('magnitude_std', float(val)) + elif key == 'm': + magnitude = int(val) + elif key == 'n': + num_layers = int(val) + elif key == 'w': + weight_idx = int(val) + else: + assert False, 'Unknown RandAugment config section' + if use_cmc: + ra_ops = rand_augment_ops_cmc(magnitude=magnitude, hparams=hparams) + else: + ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams) + choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx) + return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) diff --git a/contrast/data/sampler.py b/contrast/data/sampler.py new file mode 100644 index 0000000..5290225 --- /dev/null +++ b/contrast/data/sampler.py @@ -0,0 +1,53 @@ +import numpy as np +from torch.utils.data import Sampler + + +class SubsetSlidingWindowSampler(Sampler): + r"""Samples elements randomly from a given list of indices, without replacement. + + Arguments: + indices (sequence): a sequence of indices + """ + + def __init__(self, indices, window_stride, window_size, shuffle_per_epoch=False): + self.window_stride = window_stride + self.window_size = window_size + self.shuffle_per_epoch = shuffle_per_epoch + self.indices = indices + np.random.shuffle(self.indices) + self.start_index = 0 + + def __iter__(self): + # optionally shuffle all indices per epoch + if self.shuffle_per_epoch and self.start_index + self.window_size > len(self): + np.random.shuffle(self.indices) + + # get indices of sampler in the current window + indices = np.mod(np.arange(self.window_size, dtype=np.int) + self.start_index, len(self)) + window_indices = self.indices[indices] + + # shuffle the current window + np.random.shuffle(window_indices) + + # move start index to next window + self.start_index = (self.start_index + self.window_stride) % len(self) + + return iter(window_indices.tolist()) + + def __len__(self): + return len(self.indices) + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return {"start_index": self.start_index} + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + Arguments: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. 
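As a quick illustration of the config-string format parsed above, a hypothetical call (the magnitude, op count and hparams values are arbitrary):

```python
# 'rand' variant, magnitude 9, 2 ops per image, magnitude noise std 0.5.
from PIL import Image

from contrast.data.rand_augment import rand_augment_transform

hparams = dict(translate_const=100, img_mean=(124, 116, 104))  # fill colour / abs-translate range
ra = rand_augment_transform('rand-m9-n2-mstd0.5', hparams)

img = Image.new('RGB', (224, 224), (128, 128, 128))            # dummy grey image
aug = ra(img)                                                   # PIL Image, randomly augmented
```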
+ """ + self.__dict__.update(state_dict) diff --git a/contrast/data/selective_search_utils.py b/contrast/data/selective_search_utils.py new file mode 100644 index 0000000..898f7c1 --- /dev/null +++ b/contrast/data/selective_search_utils.py @@ -0,0 +1,17 @@ +# -------------------------------------------------------- +# SoCo +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Yue Gao +# -------------------------------------------------------- + + +import numpy as np + + +def append_prop_id(image_proposals): + prop_ids = np.arange(1, image_proposals.shape[0]+1) # we start from 1, refer to clip_bboxs, which set + prop_ids = np.reshape(prop_ids, (image_proposals.shape[0], 1)) + image_proposals_with_prop_id = np.concatenate((image_proposals, prop_ids), axis=1) + + return image_proposals_with_prop_id diff --git a/contrast/data/transform.py b/contrast/data/transform.py new file mode 100644 index 0000000..4f28477 --- /dev/null +++ b/contrast/data/transform.py @@ -0,0 +1,146 @@ +# -------------------------------------------------------- +# SoCo +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Yue Gao +# -------------------------------------------------------- + + +import numpy as np +from PIL import ImageFilter, ImageOps +from torchvision import transforms + +from . import transform_ops + + +class GaussianBlur(object): + """Gaussian Blur version 2""" + + def __call__(self, x): + sigma = np.random.uniform(0.1, 2.0) + x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) + return x + + +def get_transform(args, aug_type, crop, image_size=224, crop1=0.9, cutout_prob=0.5, cutout_ratio=(0.1, 0.2), + image3_size=224, image4_size=224): + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + if aug_type == 'ImageAsymBboxCutout': + transform_whole_img = transform_ops.WholeImageResizedParams(image_size) + transform_img = transform_ops.RandomResizedCropParams(image_size, scale=(crop, 1.)) + transform_flip = transform_ops.RandomHorizontalFlipImageBbox() + + transform_post_1 = transform_ops.ComposeImage([ + transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8), + transforms.RandomGrayscale(p=0.2), + transforms.RandomApply([GaussianBlur()], p=1.0), + transforms.ToTensor(), + normalize, + ]) + transform_post_2 = transform_ops.ComposeImage([ + transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8), + transforms.RandomGrayscale(p=0.2), + transforms.RandomApply([GaussianBlur()], p=0.1), + transforms.RandomApply([ImageOps.solarize], p=0.2), + transforms.ToTensor(), + normalize, + ]) + transform_cutout = transform_ops.RandomCutoutInBbox(image_size, cutout_prob=cutout_prob, cutout_ratio=cutout_ratio) + transform = (transform_whole_img, transform_img, transform_flip, transform_post_1, transform_post_2, transform_cutout) + + + elif aug_type == 'ImageAsymBboxAwareMultiJitter1': + transform_whole_img = transform_ops.WholeImageResizedParams(image_size) + transform_img = transform_ops.RandomResizedCropParams(image_size, scale=(crop, 1.)) + transform_img_small = transform_ops.RandomResizedCropParams(image_size//2, scale=(crop, 1.)) + transform_flip_flip = transform_ops.RandomHorizontalFlipImageBboxBbox() + transform_flip = transform_ops.RandomHorizontalFlipImageBbox() + transform_post_1 = transform_ops.ComposeImage([ + transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8), + transforms.RandomGrayscale(p=0.2), + 
transforms.RandomApply([GaussianBlur()], p=1.0), + transforms.ToTensor(), + normalize, + ]) + transform_post_2 = transform_ops.ComposeImage([ + transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8), + transforms.RandomGrayscale(p=0.2), + transforms.RandomApply([GaussianBlur()], p=0.1), + transforms.RandomApply([ImageOps.solarize], p=0.2), + transforms.ToTensor(), + normalize, + ]) + transform = (transform_whole_img, transform_img, transform_img_small, transform_flip_flip, transform_flip, transform_post_1, transform_post_2) + + + elif aug_type == 'ImageAsymBboxAwareMultiJitter1Cutout': + transform_whole_img = transform_ops.WholeImageResizedParams(image_size) + transform_img = transform_ops.RandomResizedCropParams(image_size, scale=(crop, 1.)) + transform_img_small = transform_ops.RandomResizedCropParams(image_size//2, scale=(crop, 1.)) + transform_flip_flip = transform_ops.RandomHorizontalFlipImageBboxBbox() + transform_flip = transform_ops.RandomHorizontalFlipImageBbox() + transform_post_1 = transform_ops.ComposeImage([ + transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8), + transforms.RandomGrayscale(p=0.2), + transforms.RandomApply([GaussianBlur()], p=1.0), + transforms.ToTensor(), + normalize, + ]) + transform_post_2 = transform_ops.ComposeImage([ + transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8), + transforms.RandomGrayscale(p=0.2), + transforms.RandomApply([GaussianBlur()], p=0.1), + transforms.RandomApply([ImageOps.solarize], p=0.2), + transforms.ToTensor(), + normalize, + ]) + transform_cutout = transform_ops.RandomCutoutInBbox(image_size, cutout_prob=cutout_prob, cutout_ratio=cutout_ratio) + transform = (transform_whole_img, transform_img, transform_img_small, transform_flip_flip, transform_flip, transform_post_1, transform_post_2, transform_cutout) + + + elif aug_type == 'ImageAsymBboxAwareMulti3ResizeExtraJitter1': + transform_whole_img = transform_ops.WholeImageResizedParams(image_size) + transform_img = transform_ops.RandomResizedCropParams(image_size, scale=(crop, 1.)) + transform_img_small = transform_ops.RandomResizedCropParams(image3_size, scale=(crop, 1.)) + transform_img_resize = transforms.Resize(image4_size) + transform_flip_flip_flip = transform_ops.RandomHorizontalFlipImageBboxBboxBbox() + transform_flip = transform_ops.RandomHorizontalFlipImageBbox() + transform_post_1 = transform_ops.ComposeImage([ + transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8), + transforms.RandomGrayscale(p=0.2), + transforms.RandomApply([GaussianBlur()], p=1.0), + transforms.ToTensor(), + normalize, + ]) + transform_post_2 = transform_ops.ComposeImage([ + transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8), + transforms.RandomGrayscale(p=0.2), + transforms.RandomApply([GaussianBlur()], p=0.1), + transforms.RandomApply([ImageOps.solarize], p=0.2), + transforms.ToTensor(), + normalize, + ]) + transform = (transform_whole_img, transform_img, transform_img_small, transform_img_resize, transform_flip_flip_flip, transform_flip, transform_post_1, transform_post_2) + + + elif aug_type == 'NULL': # used in linear evaluation + transform = transform_ops.Compose([ + transform_ops.RandomResizedCropCoord(image_size, scale=(crop, 1.)), + transform_ops.RandomHorizontalFlipCoord(), + transforms.ToTensor(), + normalize, + ]) + + elif aug_type == 'val': # used in validate + transform = transforms.Compose([ + transforms.Resize(image_size + 32), + transforms.CenterCrop(image_size), + 
transforms.ToTensor(), + normalize + ]) + else: + supported = '[ImageAsymBboxCutout, ImageAsymBboxAwareMultiJitter1, ImageAsymBboxAwareMultiJitter1Cutout, ImageAsymBboxAwareMulti3ResizeExtraJitter1, NULL]' + raise NotImplementedError(f'aug_type "{aug_type}" not supported. Should in {supported}') + + return transform diff --git a/contrast/data/transform_ops.py b/contrast/data/transform_ops.py new file mode 100644 index 0000000..ed06716 --- /dev/null +++ b/contrast/data/transform_ops.py @@ -0,0 +1,566 @@ +# -------------------------------------------------------- +# SoCo +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Yue Gao +# -------------------------------------------------------- + + +import math +import random +import warnings + +import numpy as np +import torch +import torch.nn.functional as TF +from PIL import Image +from torchvision.transforms import functional as F + +_pil_interpolation_to_str = { + Image.NEAREST: 'PIL.Image.NEAREST', + Image.BILINEAR: 'PIL.Image.BILINEAR', + Image.BICUBIC: 'PIL.Image.BICUBIC', + Image.LANCZOS: 'PIL.Image.LANCZOS', + Image.HAMMING: 'PIL.Image.HAMMING', + Image.BOX: 'PIL.Image.BOX', +} + + +def _get_image_size(img): + if F._is_pil_image(img): + return img.size + elif isinstance(img, torch.Tensor) and img.dim() > 2: + return img.shape[-2:][::-1] + else: + raise TypeError("Unexpected type {}".format(type(img))) + + +def crop_tensor(img_tensor, top, left, height, width): + return img_tensor[:, :, left:left+width, top:top+height] + + +def resize_tensor(img_tensor, size, interpolation='bilinear'): + return TF.interpolate(img_tensor, size, mode=interpolation, align_corners=False) + + +def resized_crop_tensor(img_tensor, top, left, height, width, size, interpolation='bilinear'): + """ + tensor version of F.resized_crop + """ + assert isinstance(img_tensor, torch.Tensor) + img_tensor = crop_tensor(img_tensor, top, left, height, width) + img_tensor = resize_tensor(img_tensor, size, interpolation) + return img_tensor + + + +class ComposeImage(object): + """Composes several transforms together. + + Args: + transforms (list of ``Transform`` objects): list of transforms to compose. + + Example: + >>> transforms.Compose([ + >>> transforms.CenterCrop(10), + >>> transforms.ToTensor(), + >>> ]) + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, img): + for t in self.transforms: + img = t(img) + return img + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += '\n' + format_string += ' {0}'.format(t) + format_string += '\n)' + return format_string + + +class Compose(object): + """Composes several transforms together. + + Args: + transforms (list of ``Transform`` objects): list of transforms to compose. 
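For orientation, a hypothetical call into `get_transform` above; in the branches shown, `args` is not consulted, so `None` suffices for this sketch.

```python
from contrast.data.transform import get_transform

transform = get_transform(None, 'ImageAsymBboxAwareMultiJitter1', crop=0.08, image_size=224)
# Tuple of 7 ops: (whole-image resize, random crop, small random crop,
#                  flip for two bbox sets, flip for one bbox set,
#                  colour pipeline 1, colour pipeline 2),
# presumably consumed positionally (self.transform[0] .. self.transform[6])
# by the matching ImageFolder* dataset class.
print(len(transform))   # 7
```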
+ + Example: + >>> transforms.Compose([ + >>> transforms.CenterCrop(10), + >>> transforms.ToTensor(), + >>> ]) + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, img): + coord = None + for t in self.transforms: + if 'RandomResizedCropCoord' in t.__class__.__name__: + img, coord = t(img) + elif 'FlipCoord' in t.__class__.__name__: + img, coord = t(img, coord) + else: + img = t(img) + return img, coord + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += '\n' + format_string += ' {0}'.format(t) + format_string += '\n)' + return format_string + + +class WholeImageResizedParams(object): + """Crop the given PIL Image to random size and aspect ratio. + + A crop of random size (default: of 0.08 to 1.0) of the original size and a random + aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop + is finally resized to given size. + This is popularly used to train the Inception networks. + + Args: + size: expected output size of each edge + scale: range of size of the origin size cropped + ratio: range of aspect ratio of the origin aspect ratio cropped + interpolation: Default: PIL.Image.BILINEAR + """ + + def __init__(self, size, interpolation=Image.BILINEAR): + if isinstance(size, (tuple, list)): + self.size = size + else: + self.size = (size, size) + + self.interpolation = interpolation + + @staticmethod + def get_params(img): + """Get parameters for ``crop`` for a random sized crop. + + Args: + img (PIL Image): Image to be cropped. + scale (tuple): range of size of the origin size cropped + ratio (tuple): range of aspect ratio of the origin aspect ratio cropped + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. + """ + width, height = _get_image_size(img) + + return 0, 0, height, width, height, width + + def __call__(self, img): + """ + Args: + img (PIL Image): Image to be cropped and resized. + + Returns: + PIL Image: Randomly cropped and resized image. + """ + params = self.get_params(img) + params_np = np.array(params) + i, j, h, w, _, _ = params + + return F.resized_crop(img, i, j, h, w, self.size, self.interpolation), params_np + + def __repr__(self): + interpolate_str = _pil_interpolation_to_str[self.interpolation] + format_string = self.__class__.__name__ + '(size={0}'.format(self.size) + format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) + format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) + format_string += ', interpolation={0})'.format(interpolate_str) + return format_string + + + + + +class RandomResizedCropParams(object): + """Crop the given PIL Image to random size and aspect ratio. + + A crop of random size (default: of 0.08 to 1.0) of the original size and a random + aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop + is finally resized to given size. + This is popularly used to train the Inception networks. + + Args: + size: expected output size of each edge + scale: range of size of the origin size cropped + ratio: range of aspect ratio of the origin aspect ratio cropped + interpolation: Default: PIL.Image.BILINEAR + """ + + def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. 
/ 3.), interpolation=Image.BILINEAR): + if isinstance(size, (tuple, list)): + self.size = size + else: + self.size = (size, size) + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + warnings.warn("range should be of kind (min, max)") + + self.interpolation = interpolation + self.scale = scale + self.ratio = ratio + + @staticmethod + def get_params(img, scale, ratio): + """Get parameters for ``crop`` for a random sized crop. + + Args: + img (PIL Image): Image to be cropped. + scale (tuple): range of size of the origin size cropped + ratio (tuple): range of aspect ratio of the origin aspect ratio cropped + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. + """ + width, height = _get_image_size(img) + area = height * width + + for attempt in range(10): + target_area = random.uniform(*scale) * area + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if 0 < w <= width and 0 < h <= height: + i = random.randint(0, height - h) + j = random.randint(0, width - w) + return i, j, h, w, height, width + + # Fallback to central crop + in_ratio = float(width) / float(height) + if (in_ratio < min(ratio)): + w = width + h = int(round(w / min(ratio))) + elif (in_ratio > max(ratio)): + h = height + w = int(round(h * max(ratio))) + else: # whole image + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + return i, j, h, w, height, width + + def __call__(self, img): + """ + Args: + img (PIL Image): Image to be cropped and resized. + + Returns: + PIL Image: Randomly cropped and resized image. + """ + params = self.get_params(img, self.scale, self.ratio) + params_np = np.array(params) + i, j, h, w, _, _ = params + + return F.resized_crop(img, i, j, h, w, self.size, self.interpolation), params_np + + def __repr__(self): + interpolate_str = _pil_interpolation_to_str[self.interpolation] + format_string = self.__class__.__name__ + '(size={0}'.format(self.size) + format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) + format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) + format_string += ', interpolation={0})'.format(interpolate_str) + return format_string + + + + + +class RandomHorizontalFlipImageBbox(object): + """Horizontally flip the given PIL Image randomly with a given probability. + + Args: + p (float): probability of the image being flipped. Default value is 0.5 + """ + + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img, bboxs, params): + """ + Args: + img (PIL Image): Image to be flipped. + bboxs (torch tensor): [[y1, x1, y2, x2]] in [0, 1] + weight (torch tensor): (1, h, w), in [0, 1] + params (numpy array), i, j, h(crop image), w(crop image), height(raw image), width(raw image) + Returns: + PIL Image: Randomly flipped image. 
+ """ + if random.random() < self.p: + bboxs_new = bboxs.clone() + bboxs_new[:, 0] = 1.0 - bboxs[:, 2] # x1 = x2 + bboxs_new[:, 2] = 1.0 - bboxs[:, 0] # x2 = x1 + # change x, keep y, w, h + params_new = np.copy(params) + params_new[1] = params[5] - params[3] - params[1] + return F.hflip(img), bboxs_new, params_new + return img, bboxs, params + + def __repr__(self): + return self.__class__.__name__ + '(p={})'.format(self.p) + + + + +class RandomCutoutInBbox(object): + """RandomCutout in bboxs + """ + def __init__(self, size, cutout_prob, cutout_ratio=(0.1, 0.2)): + if isinstance(size, (tuple, list)): + self.size = size + else: + self.size = (size, size) + assert isinstance(cutout_ratio, tuple) + self.cutout_prob = cutout_prob + self.cutout_ratio = cutout_ratio + self.width = self.size[0] + self.height = self.size[1] + + def __call__(self, img, resized_bboxs): + """ img is tensor + """ + new_img = img.clone() + for bbox in resized_bboxs: + cutout_r = random.random() + if cutout_r < self.cutout_prob: + x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3] + bbox_w = x2 - x1 + 1 + bbox_h = y2 - y1 + 1 + bbox_area = bbox_w * bbox_h + + target_area = random.uniform(*self.cutout_ratio) * bbox_area + + w = int(round(math.sqrt(target_area))) + h = w + center_cut_x = random.randint(x1, x2) + center_cut_y = random.randint(y1, y2) + cut_x1 = max(center_cut_x - w // 2, 0) + cut_x2 = min(center_cut_x + w // 2, self.width) + cut_y1 = max(center_cut_y - h // 2, 0) + cut_y2 = min(center_cut_y + h // 2, self.height) + + # img is tensor 3, H, W + new_img[:, cut_y1:cut_y2+1, cut_x1:cut_x2+1] = 0.0 + return new_img + + +class RandomHorizontalFlipImageBboxBbox(object): + """Horizontally flip the given PIL Image randomly with a given probability. + + Args: + p (float): probability of the image being flipped. Default value is 0.5 + """ + + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img, bboxs, bboxs_p, params): + """ + Args: + img (PIL Image): Image to be flipped. + bboxs (torch tensor): [[y1, x1, y2, x2]] in [0, 1] + bboxs_p (torch tensor): [[y1, x1, y2, x2]] in [0, 1] another bboxs for current img + weight (torch tensor): (1, h, w), in [0, 1] + params (numpy array), i, j, h(crop image), w(crop image), height(raw image), width(raw image) + Returns: + PIL Image: Randomly flipped image. + """ + if random.random() < self.p: + bboxs_new = bboxs.clone() + bboxs_new[:, 0] = 1.0 - bboxs[:, 2] # x1 = x2 + bboxs_new[:, 2] = 1.0 - bboxs[:, 0] # x2 = x1 + + bboxs_p_new = bboxs_p.clone() + bboxs_p_new[:, 0] = 1.0 - bboxs_p[:, 2] # x1 = x2 + bboxs_p_new[:, 2] = 1.0 - bboxs_p[:, 0] # x2 = x1 + # change x, keep y, w, h + params_new = np.copy(params) + params_new[1] = params[5] - params[3] - params[1] + return F.hflip(img), bboxs_new, bboxs_p_new, params_new + return img, bboxs, bboxs_p, params + + def __repr__(self): + return self.__class__.__name__ + '(p={})'.format(self.p) + + +class RandomHorizontalFlipImageBboxBboxBbox(object): + """Horizontally flip the given PIL Image randomly with a given probability. + + Args: + p (float): probability of the image being flipped. Default value is 0.5 + """ + + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img, bboxs, bboxs_p, bboxs_q, params): + """ + Args: + img (PIL Image): Image to be flipped. 
+ bboxs (torch tensor): [[y1, x1, y2, x2]] in [0, 1] + bboxs_p (torch tensor): [[y1, x1, y2, x2]] in [0, 1] another bboxs for current img + bboxs_q (torch tensor): [[y1, x1, y2, x2]] in [0, 1] another bboxs for current img + weight (torch tensor): (1, h, w), in [0, 1] + params (numpy array), i, j, h(crop image), w(crop image), height(raw image), width(raw image) + Returns: + PIL Image: Randomly flipped image. + """ + if random.random() < self.p: + bboxs_new = bboxs.clone() + bboxs_new[:, 0] = 1.0 - bboxs[:, 2] # x1 = x2 + bboxs_new[:, 2] = 1.0 - bboxs[:, 0] # x2 = x1 + + bboxs_p_new = bboxs_p.clone() + bboxs_p_new[:, 0] = 1.0 - bboxs_p[:, 2] # x1 = x2 + bboxs_p_new[:, 2] = 1.0 - bboxs_p[:, 0] # x2 = x1 + + bboxs_q_new = bboxs_q.clone() + bboxs_q_new[:, 0] = 1.0 - bboxs_q[:, 2] # x1 = x2 + bboxs_q_new[:, 2] = 1.0 - bboxs_q[:, 0] # x2 = x1 + # change x, keep y, w, h + params_new = np.copy(params) + params_new[1] = params[5] - params[3] - params[1] + return F.hflip(img), bboxs_new, bboxs_p_new, bboxs_q_new, params_new + return img, bboxs, bboxs_p, bboxs_q, params + + def __repr__(self): + return self.__class__.__name__ + '(p={})'.format(self.p) + + +class RandomResizedCropCoord(object): + """Crop the given PIL Image to random size and aspect ratio. + + A crop of random size (default: of 0.08 to 1.0) of the original size and a random + aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop + is finally resized to given size. + This is popularly used to train the Inception networks. + + Args: + size: expected output size of each edge + scale: range of size of the origin size cropped + ratio: range of aspect ratio of the origin aspect ratio cropped + interpolation: Default: PIL.Image.BILINEAR + """ + + def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR): + if isinstance(size, (tuple, list)): + self.size = size + else: + self.size = (size, size) + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + warnings.warn("range should be of kind (min, max)") + + self.interpolation = interpolation + self.scale = scale + self.ratio = ratio + + @staticmethod + def get_params(img, scale, ratio): + """Get parameters for ``crop`` for a random sized crop. + + Args: + img (PIL Image): Image to be cropped. + scale (tuple): range of size of the origin size cropped + ratio (tuple): range of aspect ratio of the origin aspect ratio cropped + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. + """ + width, height = _get_image_size(img) + area = height * width + + for attempt in range(10): + target_area = random.uniform(*scale) * area + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if 0 < w <= width and 0 < h <= height: + i = random.randint(0, height - h) + j = random.randint(0, width - w) + return i, j, h, w, height, width + + # Fallback to central crop + in_ratio = float(width) / float(height) + if (in_ratio < min(ratio)): + w = width + h = int(round(w / min(ratio))) + elif (in_ratio > max(ratio)): + h = height + w = int(round(h * max(ratio))) + else: # whole image + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + return i, j, h, w, height, width + + def __call__(self, img): + """ + Args: + img (PIL Image): Image to be cropped and resized. 
+ + Returns: + PIL Image: Randomly cropped and resized image. + """ + i, j, h, w, height, width = self.get_params(img, self.scale, self.ratio) + coord = torch.Tensor([float(j) / (width - 1), float(i) / (height - 1), + float(j + w - 1) / (width - 1), float(i + h - 1) / (height - 1)]) + return F.resized_crop(img, i, j, h, w, self.size, self.interpolation), coord + + def __repr__(self): + interpolate_str = _pil_interpolation_to_str[self.interpolation] + format_string = self.__class__.__name__ + '(size={0}'.format(self.size) + format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) + format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) + format_string += ', interpolation={0})'.format(interpolate_str) + return format_string + + +class RandomHorizontalFlipCoord(object): + """Horizontally flip the given PIL Image randomly with a given probability. + + Args: + p (float): probability of the image being flipped. Default value is 0.5 + """ + + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img, coord): + """ + Args: + img (PIL Image): Image to be flipped. + + Returns: + PIL Image: Randomly flipped image. + """ + if random.random() < self.p: + coord_new = coord.clone() + coord_new[0] = coord[2] + coord_new[2] = coord[0] + return F.hflip(img), coord_new + return img, coord + + def __repr__(self): + return self.__class__.__name__ + '(p={})'.format(self.p) diff --git a/contrast/data/zipreader.py b/contrast/data/zipreader.py new file mode 100644 index 0000000..ce0db9f --- /dev/null +++ b/contrast/data/zipreader.py @@ -0,0 +1,85 @@ +# -------------------------------------------------------- +# SoCo +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Yue Gao +# -------------------------------------------------------- + + +import os +import zipfile + + +def is_zip_path(img_or_path): + """judge if this is a zip path""" + return '.zip@' in img_or_path + + +class ZipReader(object): + """A class to read zipped files""" + zip_bank = dict() + + def __init__(self): + super(ZipReader, self).__init__() + + @staticmethod + def get_zipfile(path): + zip_bank = ZipReader.zip_bank + if path not in zip_bank: + zfile = zipfile.ZipFile(path, 'r') + zip_bank[path] = zfile + return zip_bank[path] + + @staticmethod + def split_zip_style_path(path): + pos_at = path.index('@') + assert pos_at != -1, "character '@' is not found from the given path '%s'" % path + + zip_path = path[0: pos_at] + folder_path = path[pos_at + 1:] + folder_path = str.strip(folder_path, '/') + return zip_path, folder_path + + @staticmethod + def list_folder(path): + zip_path, folder_path = ZipReader.split_zip_style_path(path) + + zfile = ZipReader.get_zipfile(zip_path) + folder_list = [] + for file_folder_name in zfile.namelist(): + file_folder_name = str.strip(file_folder_name, '/') + if file_folder_name.startswith(folder_path) and \ + len(os.path.splitext(file_folder_name)[-1]) == 0 and \ + file_folder_name != folder_path: + if len(folder_path) == 0: + folder_list.append(file_folder_name) + else: + folder_list.append(file_folder_name[len(folder_path)+1:]) + + return folder_list + + @staticmethod + def list_files(path, extension=None): + if extension is None: + extension = ['.*'] + zip_path, folder_path = ZipReader.split_zip_style_path(path) + + zfile = ZipReader.get_zipfile(zip_path) + file_lists = [] + for file_folder_name in zfile.namelist(): + file_folder_name = str.strip(file_folder_name, '/') + if 
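To make the parameter conventions of the ops above concrete, an illustrative sketch with made-up sizes: `RandomResizedCropParams` returns the crop plus `params = [i, j, h, w, raw_height, raw_width]`, and `RandomHorizontalFlipImageBbox` mirrors the normalized `x1`/`x2` of each box whenever it flips.

```python
import torch
from PIL import Image

from contrast.data.transform_ops import (RandomHorizontalFlipImageBbox,
                                          RandomResizedCropParams)

img = Image.new('RGB', (640, 480))                       # dummy 640x480 image
crop = RandomResizedCropParams(224, scale=(0.5, 1.0))
view, params = crop(img)
print(view.size, params)                                 # (224, 224) [i j h w 480 640]

flip = RandomHorizontalFlipImageBbox(p=1.0)              # p=1 so the demo always flips
bboxs = torch.tensor([[0.1, 0.2, 0.4, 0.6]])             # normalized [x1, y1, x2, y2]
view, bboxs, params = flip(view, bboxs, params)
print(bboxs)                                             # tensor([[0.6000, 0.2000, 0.9000, 0.6000]])
```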
file_folder_name.startswith(folder_path) and \ + str.lower(os.path.splitext(file_folder_name)[-1]) in extension: + if len(folder_path) == 0: + file_lists.append(file_folder_name) + else: + file_lists.append(file_folder_name[len(folder_path)+1:]) + + return file_lists + + @staticmethod + def read(path): + zip_path, path_img = ZipReader.split_zip_style_path(path) + zfile = ZipReader.get_zipfile(zip_path) + data = zfile.read(path_img) + return data diff --git a/contrast/lars.py b/contrast/lars.py new file mode 100644 index 0000000..2267146 --- /dev/null +++ b/contrast/lars.py @@ -0,0 +1,161 @@ +# -------------------------------------------------------- +# SoCo +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Yue Gao +# -------------------------------------------------------- + + +import torch +from torch.optim.optimizer import Optimizer + +__all__ = ['LARS'] + + +def add_weight_decay(model, weight_decay=1e-5, skip_list=()): + """Splits param group into weight_decay / non-weight decay. + Tweaked from https://bit.ly/3dzyqod + :param model: the torch.nn model + :param weight_decay: weight decay term + :param skip_list: extra modules (besides BN/bias) to skip + :returns: split param group into weight_decay/not-weight decay + :rtype: list(dict) + """ + # if weight_decay == 0: + # return model.parameters() + + decay, no_decay = [], [] + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + + if len(param.shape) == 1 or name in skip_list: + # print(name) + no_decay.append(param) + else: + decay.append(param) + return [ + {'params': no_decay, 'weight_decay': 0, 'ignore': True}, + {'params': decay, 'weight_decay': weight_decay, 'ignore': False}] + + +class LARS(Optimizer): + """Implements 'LARS (Layer-wise Adaptive Rate Scaling)'__ as Optimizer a + :class:`~torch.optim.Optimizer` wrapper. + + __ : https://arxiv.org/abs/1708.03888 + + Wraps an arbitrary optimizer like :class:`torch.optim.SGD` to use LARS. 
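A hypothetical path layout for the zip-backed loading above (the archive and member names are invented): everything before `@` is the zip file and everything after it is the member path, exactly as `split_zip_style_path` parses it.

```python
from contrast.data.zipreader import ZipReader, is_zip_path

path = 'data/imagenet/train.zip@n01440764/n01440764_10026.JPEG'
assert is_zip_path(path)

raw = ZipReader.read(path)                                 # raw JPEG bytes from inside train.zip
names = ZipReader.list_files('data/imagenet/train.zip@n01440764',
                             extension=['.jpeg'])          # extensions are compared lowercase
```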
If + you want to the same performance obtained with small-batch training when + you use large-batch training, LARS will be helpful:: + + Args: + optimizer (Optimizer): + optimizer to wrap + eps (float, optional): + epsilon to help with numerical stability while calculating the + adaptive learning rate + trust_coef (float, optional): + trust coefficient for calculating the adaptive learning rate + + Example:: + base_optimizer = optim.SGD(model.parameters(), lr=0.1) + optimizer = LARS(optimizer=base_optimizer) + + output = model(input) + loss = loss_fn(output, target) + loss.backward() + + optimizer.step() + + """ + + def __init__(self, optimizer, eps=1e-8, trust_coef=0.001): + if eps < 0.0: + raise ValueError('invalid epsilon value: , %f' % eps) + + if trust_coef < 0.0: + raise ValueError("invalid trust coefficient: %f" % trust_coef) + + self.optim = optimizer + self.eps = eps + self.trust_coef = trust_coef + + def __getstate__(self): + lars_dict = {} + lars_dict['eps'] = self.eps + lars_dict['trust_coef'] = self.trust_coef + return (self.optim, lars_dict) + + def __setstate__(self, state): + self.optim, lars_dict = state + self.eps = lars_dict['eps'] + self.trust_coef = lars_dict['trust_coef'] + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, self.optim) + + @property + def param_groups(self): + return self.optim.param_groups + + @property + def state(self): + return self.optim.state + + def state_dict(self): + return self.optim.state_dict() + + def load_state_dict(self, state_dict): + self.optim.load_state_dict(state_dict) + + def zero_grad(self): + self.optim.zero_grad() + + def add_param_group(self, param_group): + self.optim.add_param_group(param_group) + + def apply_adaptive_lrs(self): + with torch.no_grad(): + for group in self.optim.param_groups: + weight_decay = group['weight_decay'] + ignore = group.get('ignore', None) # NOTE: this is set by add_weight_decay + + for p in group['params']: + if p.grad is None: + continue + + # Add weight decay before computing adaptive LR + # Seems to be pretty important in SIMclr style models. + if weight_decay > 0: + p.grad = p.grad.add(p, alpha=weight_decay) + + # Ignore bias / bn terms for LARS update + if ignore is not None and not ignore: + # compute the parameter and gradient norms + param_norm = p.norm() + grad_norm = p.grad.norm() + + # compute our adaptive learning rate + adaptive_lr = 1.0 + if param_norm > 0 and grad_norm > 0: + adaptive_lr = self.trust_coef * param_norm / (grad_norm + self.eps) + + # print("applying {} lr scaling to param of shape {}".format(adaptive_lr, p.shape)) + p.grad = p.grad.mul(adaptive_lr) + + def step(self, *args, **kwargs): + self.apply_adaptive_lrs() + + # Zero out weight decay as we do it in LARS + weight_decay_orig = [group['weight_decay'] for group in self.optim.param_groups] + for group in self.optim.param_groups: + group['weight_decay'] = 0 + + loss = self.optim.step(*args, **kwargs) # Normal optimizer + + # Restore weight decay + for group, wo in zip(self.optim.param_groups, weight_decay_orig): + group['weight_decay'] = wo + + return loss diff --git a/contrast/logger.py b/contrast/logger.py new file mode 100644 index 0000000..b4c853a --- /dev/null +++ b/contrast/logger.py @@ -0,0 +1,94 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import functools +import logging +import os +import sys +from termcolor import colored + + +class _ColorfulFormatter(logging.Formatter): + def __init__(self, *args, **kwargs): + self._root_name = kwargs.pop("root_name") + "." 
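Extending the docstring example above, a sketch of the intended pairing of `add_weight_decay` with the LARS wrapper (model and hyperparameters are arbitrary): bias/BN parameters land in the `weight_decay=0, ignore=True` group, so `apply_adaptive_lrs` skips them.

```python
import torch
import torchvision

from contrast.lars import LARS, add_weight_decay

model = torchvision.models.resnet18()
param_groups = add_weight_decay(model, weight_decay=1e-5)      # two groups: no-decay / decay
base_optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9)
optimizer = LARS(base_optimizer)

x = torch.randn(2, 3, 224, 224)                                # dummy batch
loss = model(x).sum()
loss.backward()
optimizer.step()                                               # LARS scaling, then the SGD update
optimizer.zero_grad()
```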
+ self._abbrev_name = kwargs.pop("abbrev_name", "") + if len(self._abbrev_name): + self._abbrev_name = self._abbrev_name + "." + super(_ColorfulFormatter, self).__init__(*args, **kwargs) + + def formatMessage(self, record): + record.name = record.name.replace(self._root_name, self._abbrev_name) + log = super(_ColorfulFormatter, self).formatMessage(record) + if record.levelno == logging.WARNING: + prefix = colored("WARNING", "red", attrs=["blink"]) + elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: + prefix = colored("ERROR", "red", attrs=["blink", "underline"]) + else: + return log + return prefix + " " + log + + +# so that calling setup_logger multiple times won't add many handlers +@functools.lru_cache() +def setup_logger( + output=None, distributed_rank=0, *, color=True, name="contrast", abbrev_name=None +): + """ + Initialize the detectron2 logger and set its verbosity level to "INFO". + + Args: + output (str): a file name or a directory to save log. If None, will not save log file. + If ends with ".txt" or ".log", assumed to be a file name. + Otherwise, logs will be saved to `output/log.txt`. + name (str): the root module name of this logger + + Returns: + logging.Logger: a logger + """ + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + logger.propagate = False + + if abbrev_name is None: + abbrev_name = name + + plain_formatter = logging.Formatter( + "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" + ) + # stdout logging: master only + if distributed_rank == 0: + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(logging.DEBUG) + if color: + formatter = _ColorfulFormatter( + colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", + datefmt="%m/%d %H:%M:%S", + root_name=name, + abbrev_name=str(abbrev_name), + ) + else: + formatter = plain_formatter + ch.setFormatter(formatter) + logger.addHandler(ch) + + # file logging: all workers + if output is not None: + if output.endswith(".txt") or output.endswith(".log"): + filename = output + else: + filename = os.path.join(output, "log.txt") + if distributed_rank > 0: + filename = filename + f".rank{distributed_rank}" + os.makedirs(os.path.dirname(filename), exist_ok=True) + + fh = logging.StreamHandler(_cached_log_stream(filename)) + fh.setLevel(logging.DEBUG) + fh.setFormatter(plain_formatter) + logger.addHandler(fh) + + return logger + + +# cache the opened file object, so that different calls to `setup_logger` +# with the same file name can safely write to the same file. +@functools.lru_cache(maxsize=None) +def _cached_log_stream(filename): + return open(filename, "a") diff --git a/contrast/lr_scheduler.py b/contrast/lr_scheduler.py new file mode 100644 index 0000000..0b8b67d --- /dev/null +++ b/contrast/lr_scheduler.py @@ -0,0 +1,93 @@ +# -------------------------------------------------------- +# SoCo +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Yue Gao +# -------------------------------------------------------- + + +from torch.optim.lr_scheduler import (CosineAnnealingLR, MultiStepLR, + _LRScheduler) + + +# noinspection PyAttributeOutsideInit +class GradualWarmupScheduler(_LRScheduler): + """ Gradually warm-up(increasing) learning rate in optimizer. + Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. + Args: + optimizer (Optimizer): Wrapped optimizer. 
+ multiplier: init learning rate = base lr / multiplier + warmup_epoch: target learning rate is reached at warmup_epoch, gradually + after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau) + """ + + def __init__(self, optimizer, multiplier, warmup_epoch, after_scheduler, last_epoch=-1): + self.multiplier = multiplier + if self.multiplier <= 1.: + raise ValueError('multiplier should be greater than 1.') + self.warmup_epoch = warmup_epoch + self.after_scheduler = after_scheduler + self.finished = False + super().__init__(optimizer, last_epoch=last_epoch) + + def get_lr(self): + if self.last_epoch > self.warmup_epoch: + return self.after_scheduler.get_lr() + else: + return [base_lr / self.multiplier * ((self.multiplier - 1.) * self.last_epoch / self.warmup_epoch + 1.) + for base_lr in self.base_lrs] + + def step(self, epoch=None): + if epoch is None: + epoch = self.last_epoch + 1 + self.last_epoch = epoch + if epoch > self.warmup_epoch: + self.after_scheduler.step(epoch - self.warmup_epoch) + else: + super(GradualWarmupScheduler, self).step(epoch) + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + + state = {key: value for key, value in self.__dict__.items() if key != 'optimizer' and key != 'after_scheduler'} + state['after_scheduler'] = self.after_scheduler.state_dict() + return state + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Arguments: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + + after_scheduler_state = state_dict.pop('after_scheduler') + self.__dict__.update(state_dict) + self.after_scheduler.load_state_dict(after_scheduler_state) + + +def get_scheduler(optimizer, n_iter_per_epoch, args): + if "cosine" in args.lr_scheduler: + scheduler = CosineAnnealingLR( + optimizer=optimizer, + eta_min=0.000001, + T_max=(args.epochs - args.warmup_epoch) * n_iter_per_epoch) + elif "step" in args.lr_scheduler: + scheduler = MultiStepLR( + optimizer=optimizer, + gamma=args.lr_decay_rate, + milestones=[(m - args.warmup_epoch)*n_iter_per_epoch for m in args.lr_decay_epochs]) + else: + raise NotImplementedError(f"scheduler {args.lr_scheduler} not supported") + + if args.warmup_epoch > 0: + scheduler = GradualWarmupScheduler( + optimizer, + multiplier=args.warmup_multiplier, + after_scheduler=scheduler, + warmup_epoch=args.warmup_epoch * n_iter_per_epoch) + return scheduler diff --git a/contrast/models/SoCo_C4.py b/contrast/models/SoCo_C4.py new file mode 100644 index 0000000..d4924c2 --- /dev/null +++ b/contrast/models/SoCo_C4.py @@ -0,0 +1,139 @@ +# -------------------------------------------------------- +# SoCo +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Yue Gao +# -------------------------------------------------------- + + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.ops as tvops +from torch.distributed import get_world_size + +from .base import BaseModel +from .mlps import Pred_Head, Proj_Head + + +class SoCo_C4(BaseModel): + """ + Build a MoCo model with: a query encoder, a key encoder, and a queue + https://arxiv.org/abs/1911.05self.output_size22 + """ + + def __init__(self, base_encoder, args): + """ + dim: feature dimension (default: 128) + K: queue size; number of negative keys (default: 65536) + m: moco momentum 
of updating key encoder (default: 0.999) + T: softmax temperature (default: 0.07) + """ + super(SoCo_C4, self).__init__(base_encoder, args) + + self.contrast_num_negative = args.contrast_num_negative + self.contrast_momentum = args.contrast_momentum + self.contrast_temperature = args.contrast_temperature + self.output_size = args.output_size + self.aligned = args.aligned + + # create the encoder + self.encoder = base_encoder(low_dim=args.feature_dim, head_type='pass', use_roi_align_on_c4=True) + self.projector = Proj_Head() + self.predictor = Pred_Head() + + # create the encoder_k + self.encoder_k = base_encoder(low_dim=args.feature_dim, head_type='pass', use_roi_align_on_c4=True) + self.projector_k = Proj_Head() + + self.roi_avg_pool = nn.AvgPool2d(self.output_size, stride=1) + + for param_q, param_k in zip(self.encoder.parameters(), self.encoder_k.parameters()): + param_k.data.copy_(param_q.data) # initialize + param_k.requires_grad = False # not update by gradient + + for param_q, param_k in zip(self.projector.parameters(), self.projector_k.parameters()): + param_k.data.copy_(param_q.data) + param_k.requires_grad = False + + nn.SyncBatchNorm.convert_sync_batchnorm(self.encoder) + nn.SyncBatchNorm.convert_sync_batchnorm(self.encoder_k) + nn.SyncBatchNorm.convert_sync_batchnorm(self.projector) + nn.SyncBatchNorm.convert_sync_batchnorm(self.projector_k) + nn.SyncBatchNorm.convert_sync_batchnorm(self.predictor) + + # create the queue + self.register_buffer("queue", torch.randn(args.feature_dim, self.contrast_num_negative)) + self.queue = F.normalize(self.queue, dim=0) + + self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) + + self.K = int(args.num_instances * 1. / get_world_size() / args.batch_size * args.epochs) + self.k = int(args.num_instances * 1. / get_world_size() / args.batch_size * (args.start_epoch - 1)) + # print('Initial', get_rank(), self.k, self.K) + + @torch.no_grad() + def _momentum_update_key_encoder(self): + """ + Momentum update of the key encoder + """ + _contrast_momentum = 1. - (1. - self.contrast_momentum) * (np.cos(np.pi * self.k / self.K) + 1) / 2. + self.k = self.k + 1 + # print('Update', get_rank(), self.k, _contrast_momentum) + + for param_q, param_k in zip(self.encoder.parameters(), self.encoder_k.parameters()): + param_k.data = param_k.data * _contrast_momentum + param_q.data * (1. - _contrast_momentum) + + for param_q, param_k in zip(self.projector.parameters(), self.projector_k.parameters()): + param_k.data = param_k.data * _contrast_momentum + param_q.data * (1. 
- _contrast_momentum) + + def regression_loss_bboxs(self, vectors_q, vectors_k, correspondence): + M, C = vectors_q.shape + N, L, P = correspondence.shape + assert L == P + assert N * L == M + vectors_q = vectors_q.view(N, L, C) + vectors_k = vectors_k.view(N, L, C) + # vectors_q: N, L, C + # vectors_k: N, L, C + vectors_k = torch.transpose(vectors_k, 1, 2) + sim = torch.bmm(vectors_q, vectors_k) # N, L, L + loss = (sim * correspondence).sum(-1).sum(-1) / (correspondence.sum(-1).sum(-1) + 1e-6) + return -2 * loss.mean() + + def forward(self, im_1, im_2, bboxs1, bboxs2, corres): + """ + Input: + im_q: a batch of query images + im_k: a batch of key images + Output: + logits, targets + """ + # compute query features + feat_1 = self.encoder(im_1, bboxs=bboxs1) # queries: NxC + proj_1 = self.projector(feat_1) + pred_1 = self.predictor(proj_1) + pred_1 = F.normalize(pred_1, dim=1) + + feat_2 = self.encoder(im_2, bboxs=bboxs2) + proj_2 = self.projector(feat_2) + pred_2 = self.predictor(proj_2) + pred_2 = F.normalize(pred_2, dim=1) + + # compute key features + with torch.no_grad(): # no gradient to keys + self._momentum_update_key_encoder() # update the key encoder + + feat_1_ng = self.encoder_k(im_1, bboxs=bboxs1) # keys: NxC + proj_1_ng = self.projector_k(feat_1_ng) + proj_1_ng = F.normalize(proj_1_ng, dim=1) + + feat_2_ng = self.encoder_k(im_2, bboxs=bboxs2) + proj_2_ng = self.projector_k(feat_2_ng) + proj_2_ng = F.normalize(proj_2_ng, dim=1) + + # compute loss + corres_2_1 = corres.transpose(1, 2) # transpose dim 1 dim 2, map bboxs2 to bboxs1 + loss = self.regression_loss_bboxs(pred_1, proj_2_ng, corres) + self.regression_loss_bboxs(pred_2, proj_1_ng, corres_2_1) + return loss diff --git a/contrast/models/SoCo_FPN.py b/contrast/models/SoCo_FPN.py new file mode 100644 index 0000000..fe34a0c --- /dev/null +++ b/contrast/models/SoCo_FPN.py @@ -0,0 +1,286 @@ +# -------------------------------------------------------- +# SoCo +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Yue Gao +# -------------------------------------------------------- + + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.ops as tvops +from torch.distributed import get_world_size + +from .base import BaseModel +from .box_util import append_batch_index_to_bboxs_and_scale +from .fast_rcnn_conv_fc_head import FastRCNNConvFCHead +from .fpn import FPN +from .mlps import Pred_Head, Proj_Head + + +class SoCo_FPN(BaseModel): + """ + Build a MoCo model with: a query encoder, a key encoder, and a queue + https://arxiv.org/abs/1911.05self.output_size22 + """ + + def __init__(self, base_encoder, args): + """ + dim: feature dimension (default: 128) + K: queue size; number of negative keys (default: 65536) + m: moco momentum of updating key encoder (default: 0.999) + T: softmax temperature (default: 0.07) + Just cleaned unused forward tensors + """ + super(SoCo_FPN, self).__init__(base_encoder, args) + + self.contrast_num_negative = args.contrast_num_negative + self.contrast_momentum = args.contrast_momentum + self.contrast_temperature = args.contrast_temperature + self.output_size = args.output_size + self.aligned = args.aligned + + norm_cfg = dict(type='BN', requires_grad=True) + + # align to detectron2 ! 
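To make the shapes in `regression_loss_bboxs` concrete, the following self-contained sketch evaluates the same correspondence-weighted similarity on random tensors; the sizes and the identity correspondence matrix are invented for illustration.

```python
# Standalone sketch of the correspondence-weighted loss computed by
# regression_loss_bboxs: N images, L proposals per image, C-dim embeddings.
import torch
import torch.nn.functional as F

N, L, C = 2, 4, 128
vectors_q = F.normalize(torch.randn(N * L, C), dim=1)   # predictor outputs (query branch)
vectors_k = F.normalize(torch.randn(N * L, C), dim=1)   # projector outputs (key branch)
correspondence = torch.eye(L).expand(N, L, L)            # toy case: box i in view 1 <-> box i in view 2

q = vectors_q.view(N, L, C)
k = vectors_k.view(N, L, C).transpose(1, 2)              # N, C, L
sim = torch.bmm(q, k)                                    # N, L, L pairwise cosine similarities
loss = (sim * correspondence).sum(-1).sum(-1) / (correspondence.sum(-1).sum(-1) + 1e-6)
loss = -2 * loss.mean()                                  # scaled negative cosine similarity
print(loss)
```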
+ # create the encoder + self.encoder = base_encoder(low_dim=args.feature_dim, head_type='multi_layer') + self.neck = FPN(in_channels=args.in_channels, out_channels=args.out_channels, num_outs=args.num_outs, start_level=args.start_level, + end_level=args.end_level, add_extra_convs=args.add_extra_convs, extra_convs_on_inputs=args.extra_convs_on_inputs, + relu_before_extra_convs=args.relu_before_extra_convs, norm_cfg=norm_cfg) + self.head = FastRCNNConvFCHead() + self.projector = Proj_Head(in_dim=1024) # head channel + self.predictor = Pred_Head() + + # create the encoder_k + self.encoder_k = base_encoder(low_dim=args.feature_dim, head_type='multi_layer') + self.neck_k = FPN(in_channels=args.in_channels, out_channels=args.out_channels, num_outs=args.num_outs, start_level=args.start_level, + end_level=args.end_level, add_extra_convs=args.add_extra_convs, extra_convs_on_inputs=args.extra_convs_on_inputs, + relu_before_extra_convs=args.relu_before_extra_convs, norm_cfg=norm_cfg) + self.head_k = FastRCNNConvFCHead() + self.projector_k = Proj_Head(in_dim=1024) # head channel + + self.roi_avg_pool = nn.AvgPool2d(self.output_size, stride=1) + + for param_q, param_k in zip(self.encoder.parameters(), self.encoder_k.parameters()): + param_k.data.copy_(param_q.data) # initialize + param_k.requires_grad = False # not update by gradient + + for param_q, param_k in zip(self.neck.parameters(), self.neck_k.parameters()): + param_k.data.copy_(param_q.data) # initialize + param_k.requires_grad = False # not update by gradient + + for param_q, param_k in zip(self.head.parameters(), self.head_k.parameters()): + param_k.data.copy_(param_q.data) # initialize + param_k.requires_grad = False # not update by gradient + + for param_q, param_k in zip(self.projector.parameters(), self.projector_k.parameters()): + param_k.data.copy_(param_q.data) + param_k.requires_grad = False + + nn.SyncBatchNorm.convert_sync_batchnorm(self.encoder) + nn.SyncBatchNorm.convert_sync_batchnorm(self.encoder_k) + nn.SyncBatchNorm.convert_sync_batchnorm(self.neck) + nn.SyncBatchNorm.convert_sync_batchnorm(self.neck_k) + nn.SyncBatchNorm.convert_sync_batchnorm(self.head) + nn.SyncBatchNorm.convert_sync_batchnorm(self.head_k) + nn.SyncBatchNorm.convert_sync_batchnorm(self.projector) + nn.SyncBatchNorm.convert_sync_batchnorm(self.projector_k) + nn.SyncBatchNorm.convert_sync_batchnorm(self.predictor) + + self.K = int(args.num_instances * 1. / get_world_size() / args.batch_size * args.epochs) + self.k = int(args.num_instances * 1. / get_world_size() / args.batch_size * (args.start_epoch - 1)) + # print('Initial', get_rank(), self.k, self.K) + + @torch.no_grad() + def _momentum_update_key_encoder(self): + """ + Momentum update of the key encoder + """ + _contrast_momentum = 1. - (1. - self.contrast_momentum) * (np.cos(np.pi * self.k / self.K) + 1) / 2. + self.k = self.k + 1 + # print('Update', get_rank(), self.k, _contrast_momentum) + + for param_q, param_k in zip(self.encoder.parameters(), self.encoder_k.parameters()): + param_k.data = param_k.data * _contrast_momentum + param_q.data * (1. - _contrast_momentum) + + for param_q, param_k in zip(self.neck.parameters(), self.neck_k.parameters()): + param_k.data = param_k.data * _contrast_momentum + param_q.data * (1. - _contrast_momentum) + + for param_q, param_k in zip(self.head.parameters(), self.head_k.parameters()): + param_k.data = param_k.data * _contrast_momentum + param_q.data * (1. 
- _contrast_momentum) + + for param_q, param_k in zip(self.projector.parameters(), self.projector_k.parameters()): + param_k.data = param_k.data * _contrast_momentum + param_q.data * (1. - _contrast_momentum) + + + def regression_loss_bboxs_aware(self, vectors_q, vectors_k, correspondence): + N, M, C = vectors_q.shape # N, P * L, C + N, M1, M2 = correspondence.shape # N, P * L, P * L + assert M == M1 == M2 + # vectors_q: N, L, C + # vectors_k: N, L, C + vectors_k = torch.transpose(vectors_k, 1, 2) # N, C, P * L + sim = torch.bmm(vectors_q, vectors_k) # N, P * L, P * L + loss = (sim * correspondence).sum(-1).sum(-1) / (correspondence.sum(-1).sum(-1) + 1e-6) + return -2 * loss.mean() + + + def roi_align_feature_map(self, feature_map, bboxs): + feature_map = feature_map.type(dtype=bboxs.dtype) # feature map will be convert to HalfFloat in favor of amp + N, C, H, W = feature_map.shape + N, L, _ = bboxs.shape + + output_size = (self.output_size, self.output_size) + + bboxs_q_with_batch_index = append_batch_index_to_bboxs_and_scale(bboxs, H, W) + aligned_features = tvops.roi_align(input=feature_map, boxes=bboxs_q_with_batch_index, output_size=output_size, aligned=self.aligned) + # N*L, C, output_size, output_size + return aligned_features + + def forward(self, im_1, im_2, im_3, bboxs1_12, bboxs1_13, bboxs2, bboxs3, corres_12, corres_13): + """ + Input: + im_q: a batch of query images + im_k: a batch of key images + Output: + logits, targets + """ + # compute query features + N, L, _ = bboxs1_12.shape + feats_1 = self.encoder(im_1) + fpn_feats_1 = self.neck(feats_1) # p2, p3, p4, p5, num_outs = 4 + assert len(fpn_feats_1) == 4 + + preds_1_12 = [None] * len(fpn_feats_1) + for i, feat_1_12 in enumerate(fpn_feats_1): + feat_roi_1_12 = self.roi_align_feature_map(feat_1_12, bboxs1_12) + feat_vec_1_12 = self.head(feat_roi_1_12) + proj_1_12 = self.projector(feat_vec_1_12) + pred_1_12 = self.predictor(proj_1_12) + pred_1_12 = F.normalize(pred_1_12, dim=1) # N * L, C + pred_1_12 = pred_1_12.reshape((N, L, -1)) # N, L, C + preds_1_12[i] = pred_1_12 + + preds_1_12 = torch.cat(preds_1_12, dim=1) # N, P * L, C + + + preds_1_13 = [None] * len(fpn_feats_1) + for i, feat_1_13 in enumerate(fpn_feats_1): + feat_roi_1_13 = self.roi_align_feature_map(feat_1_13, bboxs1_13) + feat_vec_1_13 = self.head(feat_roi_1_13) + proj_1_13 = self.projector(feat_vec_1_13) + pred_1_13 = self.predictor(proj_1_13) + pred_1_13 = F.normalize(pred_1_13, dim=1) # N * L, C + pred_1_13 = pred_1_13.reshape((N, L, -1)) # N, L, C + preds_1_13[i] = pred_1_13 + + preds_1_13 = torch.cat(preds_1_13, dim=1) # N, P * L, C + + + feats_2 = self.encoder(im_2) + fpn_feats_2 = self.neck(feats_2) # p2, p3, p4, p5, num_outs = 4 + assert len(fpn_feats_2) == 4 + + preds_2 = [None] * len(fpn_feats_2) + for i, feat_2 in enumerate(fpn_feats_2): + feat_roi_2 = self.roi_align_feature_map(feat_2, bboxs2) + feat_vec_2 = self.head(feat_roi_2) + proj_2 = self.projector(feat_vec_2) + pred_2 = self.predictor(proj_2) + pred_2 = F.normalize(pred_2, dim=1) + pred_2 = pred_2.reshape((N, L, -1)) # N, L, C + preds_2[i] = pred_2 + + preds_2 = torch.cat(preds_2, dim=1) # N, P * L, C + + + feats_3 = self.encoder(im_3) + fpn_feats_3 = self.neck(feats_3) # p2, p3, p4, p5, num_outs = 4 + assert len(fpn_feats_3) == 4 + + preds_3 = [None] * len(fpn_feats_3) + for i, feat_3 in enumerate(fpn_feats_3): + feat_roi_3 = self.roi_align_feature_map(feat_3, bboxs3) + feat_vec_3 = self.head(feat_roi_3) + proj_3 = self.projector(feat_vec_3) + pred_3 = self.predictor(proj_3) + pred_3 = 
F.normalize(pred_3, dim=1) + pred_3 = pred_3.reshape((N, L, -1)) # N, L, C + preds_3[i] = pred_3 + + preds_3 = torch.cat(preds_3, dim=1) # N, P * L, C + + + # compute key features + with torch.no_grad(): # no gradient to keys + self._momentum_update_key_encoder() # update the key encoder + + feats_1_ng = self.encoder_k(im_1) + fpn_feats_1_ng = self.neck_k(feats_1_ng) # p2, p3, p4, p5, num_outs = 4 + assert len(fpn_feats_1_ng) == 4 + projs_1_12_ng = [None] * len(fpn_feats_1_ng) + for i, feat_1_12_ng in enumerate(fpn_feats_1_ng): + feat_roi_1_12_ng = self.roi_align_feature_map(feat_1_12_ng, bboxs1_12) + feat_vec_1_12_ng = self.head_k(feat_roi_1_12_ng) + proj_1_12_ng = self.projector_k(feat_vec_1_12_ng) + proj_1_12_ng = F.normalize(proj_1_12_ng, dim=1) + proj_1_12_ng = proj_1_12_ng.reshape((N, L, -1)) + projs_1_12_ng[i] = proj_1_12_ng + + projs_1_12_ng = torch.cat(projs_1_12_ng, dim=1) # N, P * L, C + + + projs_1_13_ng = [None] * len(fpn_feats_1_ng) + for i, feat_1_13_ng in enumerate(fpn_feats_1_ng): + feat_roi_1_13_ng = self.roi_align_feature_map(feat_1_13_ng, bboxs1_13) + feat_vec_1_13_ng = self.head_k(feat_roi_1_13_ng) + proj_1_13_ng = self.projector_k(feat_vec_1_13_ng) + proj_1_13_ng = F.normalize(proj_1_13_ng, dim=1) + proj_1_13_ng = proj_1_13_ng.reshape((N, L, -1)) + projs_1_13_ng[i] = proj_1_13_ng + + projs_1_13_ng = torch.cat(projs_1_13_ng, dim=1) # N, P * L, C + + + feats_2_ng = self.encoder_k(im_2) + fpn_feats_2_ng = self.neck_k(feats_2_ng) # p2, p3, p4, p5, num_outs = 4 + assert len(fpn_feats_2_ng) == 4 + projs_2_ng = [None] * len(fpn_feats_2_ng) + for i, feat_2_ng in enumerate(fpn_feats_2_ng): + feat_roi_2_ng = self.roi_align_feature_map(feat_2_ng, bboxs2) + feat_vec_2_ng = self.head_k(feat_roi_2_ng) + proj_2_ng = self.projector_k(feat_vec_2_ng) + proj_2_ng = F.normalize(proj_2_ng, dim=1) + proj_2_ng = proj_2_ng.reshape((N, L, -1)) + projs_2_ng[i] = proj_2_ng + + projs_2_ng = torch.cat(projs_2_ng, dim=1) # N, P * L, C + + + feats_3_ng = self.encoder_k(im_3) + fpn_feats_3_ng = self.neck_k(feats_3_ng) # p2, p3, p4, p5, num_outs = 4 + assert len(fpn_feats_3_ng) == 4 + projs_3_ng = [None] * len(fpn_feats_3_ng) + for i, feat_3_ng in enumerate(fpn_feats_3_ng): + feat_roi_3_ng = self.roi_align_feature_map(feat_3_ng, bboxs3) + feat_vec_3_ng = self.head_k(feat_roi_3_ng) + proj_3_ng = self.projector_k(feat_vec_3_ng) + proj_3_ng = F.normalize(proj_3_ng, dim=1) + proj_3_ng = proj_3_ng.reshape((N, L, -1)) + projs_3_ng[i] = proj_3_ng + + projs_3_ng = torch.cat(projs_3_ng, dim=1) # N, P * L, C + + + # compute loss + corres_12_2to1 = corres_12.transpose(1, 2) # transpose dim 1 dim 2, map bboxs2 to bboxs1 + corres_13_3to1 = corres_13.transpose(1, 2) # transpose dim 1 dim 2, map bboxs3 to bboxs1 + loss_bbox_aware_12 = self.regression_loss_bboxs_aware(preds_1_12, projs_2_ng, corres_12) + self.regression_loss_bboxs_aware(preds_2, projs_1_12_ng, corres_12_2to1) + loss_bbox_aware_13 = self.regression_loss_bboxs_aware(preds_1_13, projs_3_ng, corres_13) + self.regression_loss_bboxs_aware(preds_3, projs_1_13_ng, corres_13_3to1) + + loss = loss_bbox_aware_12 + loss_bbox_aware_13 + + return loss diff --git a/contrast/models/SoCo_FPN_Star.py b/contrast/models/SoCo_FPN_Star.py new file mode 100644 index 0000000..3fa9081 --- /dev/null +++ b/contrast/models/SoCo_FPN_Star.py @@ -0,0 +1,343 @@ +# -------------------------------------------------------- +# SoCo +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Yue Gao +# 
-------------------------------------------------------- + + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.ops as tvops +from torch.distributed import get_world_size + +from .base import BaseModel +from .box_util import append_batch_index_to_bboxs_and_scale +from .fast_rcnn_conv_fc_head import FastRCNNConvFCHead +from .fpn import FPN +from .mlps import Pred_Head, Proj_Head + + +class SoCo_FPN_Star(BaseModel): + """ + Build a MoCo model with: a query encoder, a key encoder, and a queue + https://arxiv.org/abs/1911.05self.output_size22 + """ + + def __init__(self, base_encoder, args): + """ + dim: feature dimension (default: 128) + K: queue size; number of negative keys (default: 65536) + m: moco momentum of updating key encoder (default: 0.999) + T: softmax temperature (default: 0.07) + Just cleaned unused forward tensors + """ + super(SoCo_FPN_Star, self).__init__(base_encoder, args) + + self.contrast_num_negative = args.contrast_num_negative + self.contrast_momentum = args.contrast_momentum + self.contrast_temperature = args.contrast_temperature + self.output_size = args.output_size + self.aligned = args.aligned + + norm_cfg = dict(type='BN', requires_grad=True) + + # align to detectron2 ! + # create the encoder + self.encoder = base_encoder(low_dim=args.feature_dim, head_type='multi_layer') + self.neck = FPN(in_channels=args.in_channels, out_channels=args.out_channels, num_outs=args.num_outs, start_level=args.start_level, + end_level=args.end_level, add_extra_convs=args.add_extra_convs, extra_convs_on_inputs=args.extra_convs_on_inputs, + relu_before_extra_convs=args.relu_before_extra_convs, norm_cfg=norm_cfg) + self.head = FastRCNNConvFCHead() + self.projector = Proj_Head(in_dim=1024) # head channel + self.predictor = Pred_Head() + + # create the encoder_k + self.encoder_k = base_encoder(low_dim=args.feature_dim, head_type='multi_layer') + self.neck_k = FPN(in_channels=args.in_channels, out_channels=args.out_channels, num_outs=args.num_outs, start_level=args.start_level, + end_level=args.end_level, add_extra_convs=args.add_extra_convs, extra_convs_on_inputs=args.extra_convs_on_inputs, + relu_before_extra_convs=args.relu_before_extra_convs, norm_cfg=norm_cfg) + self.head_k = FastRCNNConvFCHead() + self.projector_k = Proj_Head(in_dim=1024) # head channel + + self.roi_avg_pool = nn.AvgPool2d(self.output_size, stride=1) + + for param_q, param_k in zip(self.encoder.parameters(), self.encoder_k.parameters()): + param_k.data.copy_(param_q.data) # initialize + param_k.requires_grad = False # not update by gradient + + for param_q, param_k in zip(self.neck.parameters(), self.neck_k.parameters()): + param_k.data.copy_(param_q.data) # initialize + param_k.requires_grad = False # not update by gradient + + for param_q, param_k in zip(self.head.parameters(), self.head_k.parameters()): + param_k.data.copy_(param_q.data) # initialize + param_k.requires_grad = False # not update by gradient + + for param_q, param_k in zip(self.projector.parameters(), self.projector_k.parameters()): + param_k.data.copy_(param_q.data) + param_k.requires_grad = False + + nn.SyncBatchNorm.convert_sync_batchnorm(self.encoder) + nn.SyncBatchNorm.convert_sync_batchnorm(self.encoder_k) + nn.SyncBatchNorm.convert_sync_batchnorm(self.neck) + nn.SyncBatchNorm.convert_sync_batchnorm(self.neck_k) + nn.SyncBatchNorm.convert_sync_batchnorm(self.head) + nn.SyncBatchNorm.convert_sync_batchnorm(self.head_k) + nn.SyncBatchNorm.convert_sync_batchnorm(self.projector) + 
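The key branch above is created as a frozen copy of the query branch and is then updated only through `_momentum_update_key_encoder`, whose momentum follows a cosine ramp from `contrast_momentum` to 1 over the K training steps. A small numeric sketch (step counts made up):

```python
# Cosine momentum ramp used by _momentum_update_key_encoder: the EMA
# coefficient starts at contrast_momentum and reaches 1.0 (frozen keys)
# at the end of training. K below is an arbitrary total step count.
import numpy as np

contrast_momentum = 0.99   # --contrast_momentum (default 0.999 in option.py)
K = 1000

def momentum_at_step(k):
    return 1. - (1. - contrast_momentum) * (np.cos(np.pi * k / K) + 1) / 2.

print(momentum_at_step(0))       # == contrast_momentum on the first step
print(momentum_at_step(K // 2))  # halfway between contrast_momentum and 1.0
print(momentum_at_step(K))       # == 1.0 once training finishes
```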
nn.SyncBatchNorm.convert_sync_batchnorm(self.projector_k) + nn.SyncBatchNorm.convert_sync_batchnorm(self.predictor) + + self.K = int(args.num_instances * 1. / get_world_size() / args.batch_size * args.epochs) + self.k = int(args.num_instances * 1. / get_world_size() / args.batch_size * (args.start_epoch - 1)) + # print('Initial', get_rank(), self.k, self.K) + + @torch.no_grad() + def _momentum_update_key_encoder(self): + """ + Momentum update of the key encoder + """ + _contrast_momentum = 1. - (1. - self.contrast_momentum) * (np.cos(np.pi * self.k / self.K) + 1) / 2. + self.k = self.k + 1 + # print('Update', get_rank(), self.k, _contrast_momentum) + + for param_q, param_k in zip(self.encoder.parameters(), self.encoder_k.parameters()): + param_k.data = param_k.data * _contrast_momentum + param_q.data * (1. - _contrast_momentum) + + for param_q, param_k in zip(self.neck.parameters(), self.neck_k.parameters()): + param_k.data = param_k.data * _contrast_momentum + param_q.data * (1. - _contrast_momentum) + + for param_q, param_k in zip(self.head.parameters(), self.head_k.parameters()): + param_k.data = param_k.data * _contrast_momentum + param_q.data * (1. - _contrast_momentum) + + for param_q, param_k in zip(self.projector.parameters(), self.projector_k.parameters()): + param_k.data = param_k.data * _contrast_momentum + param_q.data * (1. - _contrast_momentum) + + + def regression_loss_bboxs_aware(self, vectors_q, vectors_k, correspondence): + N, M, C = vectors_q.shape # N, P * L, C + N, M1, M2 = correspondence.shape # N, P * L, P * L + assert M == M1 == M2 + # vectors_q: N, L, C + # vectors_k: N, L, C + vectors_k = torch.transpose(vectors_k, 1, 2) # N, C, P * L + sim = torch.bmm(vectors_q, vectors_k) # N, P * L, P * L + loss = (sim * correspondence).sum(-1).sum(-1) / (correspondence.sum(-1).sum(-1) + 1e-6) + return -2 * loss.mean() + + def roi_align_feature_map(self, feature_map, bboxs): + feature_map = feature_map.type(dtype=bboxs.dtype) # feature map will be convert to HalfFloat in favor of amp + N, C, H, W = feature_map.shape + N, L, _ = bboxs.shape + + output_size = (self.output_size, self.output_size) + + bboxs_q_with_batch_index = append_batch_index_to_bboxs_and_scale(bboxs, H, W) + aligned_features = tvops.roi_align(input=feature_map, boxes=bboxs_q_with_batch_index, output_size=output_size, aligned=self.aligned) + # N*L, C, output_size, output_size + return aligned_features + + def forward(self, im_1, im_2, im_3, im_4, bboxs1_12, bboxs1_13, bboxs1_14, bboxs2, bboxs3, bboxs4, corres_12, corres_13, corres_14): + """ + Input: + im_q: a batch of query images + im_k: a batch of key images + Output: + logits, targets + """ + # compute query features + N, L, _ = bboxs1_12.shape + feats_1 = self.encoder(im_1) + fpn_feats_1 = self.neck(feats_1) # p2, p3, p4, p5, num_outs = 4 + assert len(fpn_feats_1) == 4 + + preds_1_12 = [None] * len(fpn_feats_1) + for i, feat_1_12 in enumerate(fpn_feats_1): + feat_roi_1_12 = self.roi_align_feature_map(feat_1_12, bboxs1_12) + feat_vec_1_12 = self.head(feat_roi_1_12) + proj_1_12 = self.projector(feat_vec_1_12) + pred_1_12 = self.predictor(proj_1_12) + pred_1_12 = F.normalize(pred_1_12, dim=1) # N * L, C + pred_1_12 = pred_1_12.reshape((N, L, -1)) # N, L, C + preds_1_12[i] = pred_1_12 + + preds_1_12 = torch.cat(preds_1_12, dim=1) # N, P * L, C + + + preds_1_13 = [None] * len(fpn_feats_1) + for i, feat_1_13 in enumerate(fpn_feats_1): + feat_roi_1_13 = self.roi_align_feature_map(feat_1_13, bboxs1_13) + feat_vec_1_13 = self.head(feat_roi_1_13) + proj_1_13 = 
self.projector(feat_vec_1_13) + pred_1_13 = self.predictor(proj_1_13) + pred_1_13 = F.normalize(pred_1_13, dim=1) # N * L, C + pred_1_13 = pred_1_13.reshape((N, L, -1)) # N, L, C + preds_1_13[i] = pred_1_13 + + preds_1_13 = torch.cat(preds_1_13, dim=1) # N, P * L, C + + + preds_1_14 = [None] * len(fpn_feats_1) + for i, feat_1_14 in enumerate(fpn_feats_1): + feat_roi_1_14 = self.roi_align_feature_map(feat_1_14, bboxs1_14) + feat_vec_1_14 = self.head(feat_roi_1_14) + proj_1_14 = self.projector(feat_vec_1_14) + pred_1_14 = self.predictor(proj_1_14) + pred_1_14 = F.normalize(pred_1_14, dim=1) # N * L, C + pred_1_14 = pred_1_14.reshape((N, L, -1)) # N, L, C + preds_1_14[i] = pred_1_14 + + preds_1_14 = torch.cat(preds_1_14, dim=1) # N, P * L, C + + + feats_2 = self.encoder(im_2) + fpn_feats_2 = self.neck(feats_2) # p2, p3, p4, p5, num_outs = 4 + assert len(fpn_feats_2) == 4 + + preds_2 = [None] * len(fpn_feats_2) + for i, feat_2 in enumerate(fpn_feats_2): + feat_roi_2 = self.roi_align_feature_map(feat_2, bboxs2) + feat_vec_2 = self.head(feat_roi_2) + proj_2 = self.projector(feat_vec_2) + pred_2 = self.predictor(proj_2) + pred_2 = F.normalize(pred_2, dim=1) + pred_2 = pred_2.reshape((N, L, -1)) # N, L, C + preds_2[i] = pred_2 + + preds_2 = torch.cat(preds_2, dim=1) # N, P * L, C + + + feats_3 = self.encoder(im_3) + fpn_feats_3 = self.neck(feats_3) # p2, p3, p4, p5, num_outs = 4 + assert len(fpn_feats_3) == 4 + + preds_3 = [None] * len(fpn_feats_3) + for i, feat_3 in enumerate(fpn_feats_3): + feat_roi_3 = self.roi_align_feature_map(feat_3, bboxs3) + feat_vec_3 = self.head(feat_roi_3) + proj_3 = self.projector(feat_vec_3) + pred_3 = self.predictor(proj_3) + pred_3 = F.normalize(pred_3, dim=1) + pred_3 = pred_3.reshape((N, L, -1)) # N, L, C + preds_3[i] = pred_3 + + preds_3 = torch.cat(preds_3, dim=1) # N, P * L, C + + + feats_4 = self.encoder(im_4) + fpn_feats_4 = self.neck(feats_4) # p2, p3, p4, p5, num_outs = 4 + assert len(fpn_feats_4) == 4 + + preds_4 = [None] * len(fpn_feats_4) + for i, feat_4 in enumerate(fpn_feats_4): + feat_roi_4 = self.roi_align_feature_map(feat_4, bboxs4) + feat_vec_4 = self.head(feat_roi_4) + proj_4 = self.projector(feat_vec_4) + pred_4 = self.predictor(proj_4) + pred_4 = F.normalize(pred_4, dim=1) + pred_4 = pred_4.reshape((N, L, -1)) # N, L, C + preds_4[i] = pred_4 + + preds_4 = torch.cat(preds_4, dim=1) # N, P * L, C + + + # compute key features + with torch.no_grad(): # no gradient to keys + self._momentum_update_key_encoder() # update the key encoder + + feats_1_ng = self.encoder_k(im_1) + fpn_feats_1_ng = self.neck_k(feats_1_ng) # p2, p3, p4, p5, num_outs = 4 + assert len(fpn_feats_1_ng) == 4 + projs_1_12_ng = [None] * len(fpn_feats_1_ng) + for i, feat_1_12_ng in enumerate(fpn_feats_1_ng): + feat_roi_1_12_ng = self.roi_align_feature_map(feat_1_12_ng, bboxs1_12) + feat_vec_1_12_ng = self.head_k(feat_roi_1_12_ng) + proj_1_12_ng = self.projector_k(feat_vec_1_12_ng) + proj_1_12_ng = F.normalize(proj_1_12_ng, dim=1) + proj_1_12_ng = proj_1_12_ng.reshape((N, L, -1)) + projs_1_12_ng[i] = proj_1_12_ng + + projs_1_12_ng = torch.cat(projs_1_12_ng, dim=1) # N, P * L, C + + + projs_1_13_ng = [None] * len(fpn_feats_1_ng) + for i, feat_1_13_ng in enumerate(fpn_feats_1_ng): + feat_roi_1_13_ng = self.roi_align_feature_map(feat_1_13_ng, bboxs1_13) + feat_vec_1_13_ng = self.head_k(feat_roi_1_13_ng) + proj_1_13_ng = self.projector_k(feat_vec_1_13_ng) + proj_1_13_ng = F.normalize(proj_1_13_ng, dim=1) + proj_1_13_ng = proj_1_13_ng.reshape((N, L, -1)) + projs_1_13_ng[i] = proj_1_13_ng + 
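For reference, here is a hedged shape walk-through of the per-level pattern used in these loops: normalized proposals are scaled to pixel coordinates, RoI-aligned to 7×7, passed through `FastRCNNConvFCHead`, then projected and predicted to 256-d vectors. All sizes are invented, and `aligned=True` stands in for the `--aligned` flag.

```python
# Shape sketch of one FPN level -> RoIAlign -> head -> projector/predictor.
# Tensor sizes are hypothetical; output_size=7 matches the option.py default.
import torch
import torchvision.ops as tvops

from contrast.models.box_util import append_batch_index_to_bboxs_and_scale
from contrast.models.fast_rcnn_conv_fc_head import FastRCNNConvFCHead
from contrast.models.mlps import Pred_Head, Proj_Head

N, L, C, H, W = 2, 4, 256, 56, 56                 # batch, boxes per image, FPN channels, map size
feature_map = torch.randn(N, C, H, W)

x1, y1 = torch.rand(N, L, 1) * 0.5, torch.rand(N, L, 1) * 0.5
bboxs = torch.cat([x1, y1, x1 + 0.3, y1 + 0.3], dim=2)       # normalized (x1, y1, x2, y2) in [0, 1]

boxes = append_batch_index_to_bboxs_and_scale(bboxs, H, W)   # (N*L, 5): batch index + pixel coords
rois = tvops.roi_align(feature_map, boxes, output_size=(7, 7), aligned=True)
print(rois.shape)                                            # (N*L, 256, 7, 7)

head, projector, predictor = FastRCNNConvFCHead(), Proj_Head(in_dim=1024), Pred_Head()
vec = head(rois)                                             # (N*L, 1024) after 4 convs + flatten + fc
pred = predictor(projector(vec))                             # (N*L, 256)
pred = pred.reshape(N, L, -1)                                # (N, L, 256); 4 levels concat to (N, 4*L, 256)
print(pred.shape)
```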
+ projs_1_13_ng = torch.cat(projs_1_13_ng, dim=1) # N, P * L, C + + + projs_1_14_ng = [None] * len(fpn_feats_1_ng) + for i, feat_1_14_ng in enumerate(fpn_feats_1_ng): + feat_roi_1_14_ng = self.roi_align_feature_map(feat_1_14_ng, bboxs1_14) + feat_vec_1_14_ng = self.head_k(feat_roi_1_14_ng) + proj_1_14_ng = self.projector_k(feat_vec_1_14_ng) + proj_1_14_ng = F.normalize(proj_1_14_ng, dim=1) + proj_1_14_ng = proj_1_14_ng.reshape((N, L, -1)) + projs_1_14_ng[i] = proj_1_14_ng + + projs_1_14_ng = torch.cat(projs_1_14_ng, dim=1) # N, P * L, C + + + feats_2_ng = self.encoder_k(im_2) + fpn_feats_2_ng = self.neck_k(feats_2_ng) # p2, p3, p4, p5, num_outs = 4 + assert len(fpn_feats_2_ng) == 4 + projs_2_ng = [None] * len(fpn_feats_2_ng) + for i, feat_2_ng in enumerate(fpn_feats_2_ng): + feat_roi_2_ng = self.roi_align_feature_map(feat_2_ng, bboxs2) + feat_vec_2_ng = self.head_k(feat_roi_2_ng) + proj_2_ng = self.projector_k(feat_vec_2_ng) + proj_2_ng = F.normalize(proj_2_ng, dim=1) + proj_2_ng = proj_2_ng.reshape((N, L, -1)) + projs_2_ng[i] = proj_2_ng + + projs_2_ng = torch.cat(projs_2_ng, dim=1) # N, P * L, C + + + feats_3_ng = self.encoder_k(im_3) + fpn_feats_3_ng = self.neck_k(feats_3_ng) # p2, p3, p4, p5, num_outs = 4 + assert len(fpn_feats_3_ng) == 4 + projs_3_ng = [None] * len(fpn_feats_3_ng) + for i, feat_3_ng in enumerate(fpn_feats_3_ng): + feat_roi_3_ng = self.roi_align_feature_map(feat_3_ng, bboxs3) + feat_vec_3_ng = self.head_k(feat_roi_3_ng) + proj_3_ng = self.projector_k(feat_vec_3_ng) + proj_3_ng = F.normalize(proj_3_ng, dim=1) + proj_3_ng = proj_3_ng.reshape((N, L, -1)) + projs_3_ng[i] = proj_3_ng + + projs_3_ng = torch.cat(projs_3_ng, dim=1) # N, P * L, C + + + feats_4_ng = self.encoder_k(im_4) + fpn_feats_4_ng = self.neck_k(feats_4_ng) # p2, p3, p4, p5, num_outs = 4 + assert len(fpn_feats_4_ng) == 4 + projs_4_ng = [None] * len(fpn_feats_4_ng) + for i, feat_4_ng in enumerate(fpn_feats_4_ng): + feat_roi_4_ng = self.roi_align_feature_map(feat_4_ng, bboxs4) + feat_vec_4_ng = self.head_k(feat_roi_4_ng) + proj_4_ng = self.projector_k(feat_vec_4_ng) + proj_4_ng = F.normalize(proj_4_ng, dim=1) + proj_4_ng = proj_4_ng.reshape((N, L, -1)) + projs_4_ng[i] = proj_4_ng + + projs_4_ng = torch.cat(projs_4_ng, dim=1) # N, P * L, C + + # compute loss + corres_12_2to1 = corres_12.transpose(1, 2) # transpose dim 1 dim 2, map bboxs2 to bboxs1 + corres_13_3to1 = corres_13.transpose(1, 2) # transpose dim 1 dim 2, map bboxs3 to bboxs1 + corres_14_4to1 = corres_14.transpose(1, 2) # transpose dim 1 dim 2, map bboxs4 to bboxs1 + loss_bbox_aware_12 = self.regression_loss_bboxs_aware(preds_1_12, projs_2_ng, corres_12) + self.regression_loss_bboxs_aware(preds_2, projs_1_12_ng, corres_12_2to1) + loss_bbox_aware_13 = self.regression_loss_bboxs_aware(preds_1_13, projs_3_ng, corres_13) + self.regression_loss_bboxs_aware(preds_3, projs_1_13_ng, corres_13_3to1) + loss_bbox_aware_14 = self.regression_loss_bboxs_aware(preds_1_14, projs_4_ng, corres_14) + self.regression_loss_bboxs_aware(preds_4, projs_1_14_ng, corres_14_4to1) + + loss = loss_bbox_aware_12 + loss_bbox_aware_13 + loss_bbox_aware_14 + + return loss diff --git a/contrast/models/__init__.py b/contrast/models/__init__.py new file mode 100644 index 0000000..ec53b68 --- /dev/null +++ b/contrast/models/__init__.py @@ -0,0 +1,13 @@ +# -------------------------------------------------------- +# SoCo +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Yue Gao +# 
-------------------------------------------------------- + + +from .SoCo_C4 import SoCo_C4 +from .SoCo_FPN import SoCo_FPN +from .SoCo_FPN_Star import SoCo_FPN_Star + +__all__ = [] diff --git a/contrast/models/base.py b/contrast/models/base.py new file mode 100644 index 0000000..ade293d --- /dev/null +++ b/contrast/models/base.py @@ -0,0 +1,29 @@ +# -------------------------------------------------------- +# SoCo +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Yue Gao +# -------------------------------------------------------- + + +import torch.nn as nn +import torch.nn.functional as F + + +class BaseModel(nn.Module): + """ + Base model with: a encoder + """ + + def __init__(self, base_encoder, args): + super(BaseModel, self).__init__() + + # create the encoders + self.encoder = base_encoder(low_dim=args.feature_dim, head_type=args.head_type) + + def forward(self, x1, x2): + """ + Input: x1, x2 or x, y_idx + Output: logits, labels + """ + raise NotImplementedError diff --git a/contrast/models/box_util.py b/contrast/models/box_util.py new file mode 100644 index 0000000..0a418e3 --- /dev/null +++ b/contrast/models/box_util.py @@ -0,0 +1,25 @@ +# -------------------------------------------------------- +# SoCo +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Yue Gao +# -------------------------------------------------------- + +import torch + + +def append_batch_index_to_bboxs_and_scale(bboxs, H, W): + N, L, _ = bboxs.size() + batch_index = torch.arange(N, device=bboxs.device) + batch_index = batch_index.reshape(N, 1) + batch_index = batch_index.repeat(1, L) + batch_index = batch_index.reshape(N, L, 1) + + bboxs_coord = bboxs[:, :, :4].clone() # must clone + bboxs_coord[:, :, 0] = bboxs_coord[:, :, 0] * W # x1 + bboxs_coord[:, :, 2] = bboxs_coord[:, :, 2] * W # x2 + bboxs_coord[:, :, 1] = bboxs_coord[:, :, 1] * H # y1 + bboxs_coord[:, :, 3] = bboxs_coord[:, :, 3] * H # y2 + bboxs_with_index = torch.cat([batch_index, bboxs_coord], dim=2) + bboxs_with_index = bboxs_with_index.reshape(N*L, 5) + return bboxs_with_index diff --git a/contrast/models/fast_rcnn_conv_fc_head.py b/contrast/models/fast_rcnn_conv_fc_head.py new file mode 100644 index 0000000..b04bb98 --- /dev/null +++ b/contrast/models/fast_rcnn_conv_fc_head.py @@ -0,0 +1,56 @@ +# -------------------------------------------------------- +# SoCo +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Yue Gao +# -------------------------------------------------------- + + +import torch.nn as nn + + +class FastRCNNConvFCHead(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + self.bn1 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + self.ac1 = nn.ReLU() + + self.conv2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + self.bn2 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + self.ac2 = nn.ReLU() + + self.conv3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + self.bn3 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + self.ac3 = nn.ReLU() + + self.conv4 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), 
stride=(1, 1), padding=(1, 1), bias=False) + self.bn4 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + self.ac4 = nn.ReLU() + + self.flatten = nn.Flatten() + self.fc1 = nn.Linear(in_features=12544, out_features=1024, bias=True) + self.fc_relu1 = nn.ReLU() + + def forward(self, roi_feature_map): + conv1_out = self.conv1(roi_feature_map) + bn1_out = self.bn1(conv1_out) + ac1_out = self.ac1(bn1_out) + + conv2_out = self.conv2(ac1_out) + bn2_out = self.bn2(conv2_out) + ac2_out = self.ac2(bn2_out) + + conv3_out = self.conv3(ac2_out) + bn3_out = self.bn3(conv3_out) + ac3_out = self.ac3(bn3_out) + + conv4_out = self.conv4(ac3_out) + bn4_out = self.bn4(conv4_out) + ac4_out = self.ac4(bn4_out) + + flat = self.flatten(ac4_out) + fc1_out = self.fc1(flat) + fc_relu1_out = self.fc_relu1(fc1_out) + + return fc_relu1_out diff --git a/contrast/models/fpn.py b/contrast/models/fpn.py new file mode 100644 index 0000000..28d31ef --- /dev/null +++ b/contrast/models/fpn.py @@ -0,0 +1,172 @@ +# -------------------------------------------------------- +# SoCo +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Yue Gao +# -------------------------------------------------------- + + +import warnings + +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, xavier_init + + +class FPN(nn.Module): + def __init__(self, + in_channels, + out_channels, + num_outs, + start_level=0, + end_level=-1, + add_extra_convs=False, + extra_convs_on_inputs=True, + relu_before_extra_convs=False, + no_norm_on_lateral=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=None, + upsample_cfg=dict(mode='nearest')): + + super(FPN, self).__init__() + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + self.num_outs = num_outs + self.relu_before_extra_convs = relu_before_extra_convs + self.no_norm_on_lateral = no_norm_on_lateral + self.fp16_enabled = False + self.upsample_cfg = upsample_cfg.copy() + + if end_level == -1: + self.backbone_end_level = self.num_ins + assert num_outs >= self.num_ins - start_level + else: + # if end_level < inputs, no extra level is allowed + self.backbone_end_level = end_level + assert end_level <= len(in_channels) + assert num_outs == end_level - start_level + self.start_level = start_level + self.end_level = end_level + self.add_extra_convs = add_extra_convs + assert isinstance(add_extra_convs, (str, bool)) + if isinstance(add_extra_convs, str): + # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output' + assert add_extra_convs in ('on_input', 'on_lateral', 'on_output') + elif add_extra_convs: # True + if extra_convs_on_inputs: + # TODO: deprecate `extra_convs_on_inputs` + warnings.simplefilter('once') + warnings.warn( + '"extra_convs_on_inputs" will be deprecated in v2.9.0,' + 'Please use "add_extra_convs"', DeprecationWarning) + self.add_extra_convs = 'on_input' + else: + self.add_extra_convs = 'on_output' + + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + + for i in range(self.start_level, self.backbone_end_level): + l_conv = ConvModule( + in_channels[i], + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg if not self.no_norm_on_lateral else None, + act_cfg=act_cfg, + inplace=False) + fpn_conv = ConvModule( + out_channels, + out_channels, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + + 
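The `forward` further below combines these lateral and output convolutions into the usual top-down pathway. A simplified, dependency-free sketch of that pattern (normalization and the extra-conv options are omitted for brevity):

```python
# Simplified sketch of the FPN pattern built above: 1x1 laterals,
# top-down upsample-and-add, then 3x3 smoothing convs. No norm layers,
# no extra levels; channel/stride numbers follow the ResNet-50 defaults.
import torch
import torch.nn as nn
import torch.nn.functional as F

in_channels, out_channels = [256, 512, 1024, 2048], 256
lateral = nn.ModuleList(nn.Conv2d(c, out_channels, 1) for c in in_channels)
smooth = nn.ModuleList(nn.Conv2d(out_channels, out_channels, 3, padding=1) for _ in in_channels)

# c2..c5 feature maps from a ResNet on a 224x224 image (strides 4/8/16/32)
inputs = [torch.randn(1, c, 224 // s, 224 // s) for c, s in zip(in_channels, [4, 8, 16, 32])]

laterals = [conv(x) for conv, x in zip(lateral, inputs)]
for i in range(len(laterals) - 1, 0, -1):            # coarse-to-fine: add the upsampled coarser level
    laterals[i - 1] = laterals[i - 1] + F.interpolate(laterals[i], size=laterals[i - 1].shape[2:], mode='nearest')

outs = [conv(lat) for conv, lat in zip(smooth, laterals)]   # p2..p5, each with 256 channels
print([tuple(o.shape) for o in outs])
```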
self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + # add extra conv layers (e.g., RetinaNet) + extra_levels = num_outs - self.backbone_end_level + self.start_level + if self.add_extra_convs and extra_levels >= 1: + for i in range(extra_levels): + if i == 0 and self.add_extra_convs == 'on_input': + in_channels = self.in_channels[self.backbone_end_level - 1] + else: + in_channels = out_channels + extra_fpn_conv = ConvModule( + in_channels, + out_channels, + 3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + self.fpn_convs.append(extra_fpn_conv) + + # default init_weights for conv(msra) and norm in ConvModule + def init_weights(self): + """Initialize the weights of FPN module.""" + for m in self.modules(): + if isinstance(m, nn.Conv2d): + xavier_init(m, distribution='uniform') + + def forward(self, inputs): + """Forward function.""" + assert len(inputs) == len(self.in_channels) + + # build laterals + laterals = [ + lateral_conv(inputs[i + self.start_level]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + # In some cases, fixing `scale factor` (e.g. 2) is preferred, but + # it cannot co-exist with `size` in `F.interpolate`. + if 'scale_factor' in self.upsample_cfg: + laterals[i - 1] += F.interpolate(laterals[i], + **self.upsample_cfg) + else: + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] += F.interpolate( + laterals[i], size=prev_shape, **self.upsample_cfg) + + # build outputs + # part 1: from original levels + outs = [ + self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels) + ] + # part 2: add extra levels + if self.num_outs > len(outs): + # use max pool to get more levels on top of outputs + # (e.g., Faster R-CNN, Mask R-CNN) + if not self.add_extra_convs: + for i in range(self.num_outs - used_backbone_levels): + outs.append(F.max_pool2d(outs[-1], 1, stride=2)) + # add conv layers on top of original feature maps (RetinaNet) + else: + if self.add_extra_convs == 'on_input': + extra_source = inputs[self.backbone_end_level - 1] + elif self.add_extra_convs == 'on_lateral': + extra_source = laterals[-1] + elif self.add_extra_convs == 'on_output': + extra_source = outs[-1] + else: + raise NotImplementedError + outs.append(self.fpn_convs[used_backbone_levels](extra_source)) + for i in range(used_backbone_levels + 1, self.num_outs): + if self.relu_before_extra_convs: + outs.append(self.fpn_convs[i](F.relu(outs[-1]))) + else: + outs.append(self.fpn_convs[i](outs[-1])) + return tuple(outs) diff --git a/contrast/models/mlps.py b/contrast/models/mlps.py new file mode 100644 index 0000000..d78aede --- /dev/null +++ b/contrast/models/mlps.py @@ -0,0 +1,100 @@ +# -------------------------------------------------------- +# SoCo +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Yue Gao +# -------------------------------------------------------- + + +import torch.nn as nn + + +class MLP(nn.Module): + def __init__(self, in_dim, inner_dim=4096, out_dim=256): + super(MLP, self).__init__() + + self.linear1 = nn.Linear(in_dim, inner_dim) + self.bn1 = nn.BatchNorm1d(inner_dim) + self.relu1 = nn.ReLU(inplace=True) + + self.linear2 = nn.Linear(inner_dim, out_dim) + + def forward(self, x): + x = self.linear1(x) + x = x.unsqueeze(-1) + x = self.bn1(x) + x = x.squeeze(-1) + x = self.relu1(x) + + x = self.linear2(x) + + return x + + +def 
conv1x1(in_planes, out_planes): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=True) + + +class MLP2d(nn.Module): + def __init__(self, in_dim, inner_dim=4096, out_dim=256): + super(MLP2d, self).__init__() + + self.linear1 = conv1x1(in_dim, inner_dim) + self.bn1 = nn.BatchNorm2d(inner_dim) + self.relu1 = nn.ReLU(inplace=True) + + self.linear2 = conv1x1(inner_dim, out_dim) + + def forward(self, x): + x = self.linear1(x) + x = self.bn1(x) + x = self.relu1(x) + + x = self.linear2(x) + + return x + + +class MLP2d_3Layer(nn.Module): + def __init__(self, in_dim, inner_dim=4096, out_dim=256): + super(MLP2d_3Layer, self).__init__() + + self.linear1 = conv1x1(in_dim, inner_dim) + self.bn1 = nn.BatchNorm2d(inner_dim) + self.relu1 = nn.ReLU(inplace=True) + + self.linear2 = conv1x1(inner_dim, inner_dim) + self.bn2 = nn.BatchNorm2d(inner_dim) + self.relu2 = nn.ReLU(inplace=True) + + self.linear3 = conv1x1(inner_dim, out_dim) + + def forward(self, x): + x = self.linear1(x) + x = self.bn1(x) + x = self.relu1(x) + + x = self.linear2(x) + x = self.bn2(x) + x = self.relu2(x) + + x = self.linear3(x) + + return x + + +def Proj_Head(in_dim=2048, inner_dim=4096, out_dim=256): + return MLP(in_dim, inner_dim, out_dim) + + +def Pred_Head(in_dim=256, inner_dim=4096, out_dim=256): + return MLP(in_dim, inner_dim, out_dim) + + +def Proj_Head2d(in_dim=2048, inner_dim=4096, out_dim=256): + return MLP2d(in_dim, inner_dim, out_dim) + + +def Pred_Head2d(in_dim=256, inner_dim=4096, out_dim=256): + return MLP2d(in_dim, inner_dim, out_dim) diff --git a/contrast/option.py b/contrast/option.py new file mode 100644 index 0000000..1eaeb69 --- /dev/null +++ b/contrast/option.py @@ -0,0 +1,161 @@ +# -------------------------------------------------------- +# SoCo +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Yue Gao +# -------------------------------------------------------- + + +import argparse +import os + +from contrast import resnet +from contrast.util import MyHelpFormatter + +model_names = sorted(name for name in resnet.__all__ + if name.islower() and callable(resnet.__dict__[name])) + + +def parse_option(stage='pre_train'): + parser = argparse.ArgumentParser(f'contrast {stage} stage', formatter_class=MyHelpFormatter) + # develop + parser.add_argument('--debug', action='store_true', help='enable debug mode') + + # dataset + parser.add_argument('--data_dir', type=str, default='./data', help='dataset director') + parser.add_argument('--crop', type=float, default=0.2 if stage == 'pre_train' else 0.08, help='minimum crop') + parser.add_argument('--crop1', type=float, default=1.0, help='minimum crop for view1 when asym asym crop') + parser.add_argument('--aug', type=str, default='NULL', choices=['NULL', 'ImageAsymBboxCutout', 'ImageAsymBboxAwareMultiJitter1', + 'ImageAsymBboxAwareMultiJitter1Cutout', 'ImageAsymBboxAwareMulti3ResizeExtraJitter1'], + help='which augmentation to use.') + parser.add_argument('--zip', action='store_true', help='use zipped dataset') + parser.add_argument('--split_map', type=str, default='map') + parser.add_argument('--cache_mode', type=str, default='part', choices=['no', 'full', 'part'], + help='cache mode: no for no cache, full for cache all data, part for only cache part of data') + parser.add_argument('--dataset', type=str, default='ImageNet', choices=['ImageNet', 'VOC', 'COCO'], help='dataset type') + parser.add_argument('--ann_file', type=str, default='', help='annotation file') + 
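A quick shape check of the projection/prediction heads defined in `mlps.py` above; the dimensions are the file's defaults (`Proj_Head`: 2048 → 4096 → 256, `Pred_Head`: 256 → 4096 → 256) and the batch size is arbitrary.

```python
# Shape check for the MLP heads above (defaults from mlps.py).
import torch

from contrast.models.mlps import Pred_Head, Proj_Head, Proj_Head2d

proj, pred = Proj_Head(), Pred_Head()
x = torch.randn(16, 2048)          # e.g. pooled backbone features
z = proj(x)                        # (16, 256); BatchNorm1d is applied on the 4096-d hidden layer
p = pred(z)                        # (16, 256)
print(z.shape, p.shape)

proj2d = Proj_Head2d()             # 1x1-conv variant for dense (N, C, H, W) feature maps
print(proj2d(torch.randn(16, 2048, 7, 7)).shape)   # (16, 256, 7, 7)
```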
parser.add_argument('--image_size', type=int, default=224, help='image crop size') + parser.add_argument('--image3_size', type=int, default=112, help='image crop size') + parser.add_argument('--image4_size', type=int, default=112, help='image crop size') + + parser.add_argument('--num_workers', type=int, default=4, help='num of workers per GPU to use') + # sliding window sampler + parser.add_argument('--swin_', type=int, default=131072, help='window size in sliding window sampler') + parser.add_argument('--window_stride', type=int, default=16384, help='window stride in sliding window sampler') + parser.add_argument('--use_sliding_window_sampler', action='store_true', + help='whether to use sliding window sampler') + parser.add_argument('--shuffle_per_epoch', action='store_true', + help='shuffle indices in sliding window sampler per epoch') + if stage == 'linear': + parser.add_argument('--total_batch_size', type=int, default=256, help='total train batch size for all GPU') + else: + parser.add_argument('--batch_size', type=int, default=64, help='batch_size for single gpu') + + # model + parser.add_argument('--arch', type=str, default='resnet50', choices=model_names, + help="backbone architecture") + if stage == 'pre_train': + parser.add_argument('--model', type=str, required=True, help='which model to use') + parser.add_argument('--contrast_temperature', type=float, default=0.07, help='temperature in instance cls loss') + parser.add_argument('--contrast_momentum', type=float, default=0.999, + help='momentum parameter used in MoCo and InstDisc') + parser.add_argument('--contrast_num_negative', type=int, default=65536, + help='number of negative samples used in MoCo and InstDisc') + parser.add_argument('--feature_dim', type=int, default=128, help='feature dimension') + parser.add_argument('--head_type', type=str, default='mlp_head', help='choose head type') + parser.add_argument('--lambda_img', type=float, default=0., help='loss weight of image_to_image loss') + parser.add_argument('--lambda_cross', type=float, default=1., help='loss weight of image_to_point loss') + + # FPN default args + parser.add_argument('--in_channels', type=list, default=[256, 512, 1024, 2048], help='FPN feature map input channels') + parser.add_argument('--out_channels', type=int, default=256, help='FPN feature map output channels') + parser.add_argument('--start_level', type=int, default=1, help='FPN start level') + parser.add_argument('--end_level', type=int, default=-1, help='FPN end level') + parser.add_argument('--add_extra_convs', type=int, default=1, help='FPN add extra convs') + parser.add_argument('--extra_convs_on_inputs', type=int, default=0, help='FPN extra_convs_on_inputs') + parser.add_argument('--relu_before_extra_convs', type=int, default=1, help='FPN relu_before_extra_convs') + parser.add_argument('--no_norm_on_lateral', type=int, default=0, help='FPN no_norm_on_lateral') + parser.add_argument('--num_outs', type=int, default=3, help='FPN num_outs, use p3~p5') + + # Head default args + parser.add_argument('--head_in_channels', type=int, default=256, help='Head feature map input channels') + parser.add_argument('--head_feat_channels', type=int, default=256, help='Head feature map feat channels') + parser.add_argument('--head_stacked_convs', type=int, default=4, help='Head stacked convs') + + # optimization + if stage == 'pre_train': + parser.add_argument('--base_learning_rate', '--base_lr', type=float, default=0.03, + help='base learning when batch size = 256. 
final lr is determined by linear scale') + else: + parser.add_argument('--learning_rate', type=float, default=30, help='learning rate') + parser.add_argument('--optimizer', type=str, choices=['sgd', 'lars'], default='sgd', + help='for optimizer choice.') + parser.add_argument('--lr_scheduler', type=str, default='cosine', + choices=["step", "cosine"], help="learning rate scheduler") + parser.add_argument('--warmup_epoch', type=int, default=5, help='warmup epoch') + parser.add_argument('--warmup_multiplier', type=int, default=100, help='warmup multiplier') + parser.add_argument('--lr_decay_epochs', type=int, default=[120, 160, 200], nargs='+', + help='for step scheduler. where to decay lr, can be a list') + parser.add_argument('--lr_decay_rate', type=float, default=0.1, + help='for step scheduler. decay rate for learning rate') + parser.add_argument('--weight_decay', type=float, default=1e-4 if stage == 'pre_train' else 0, help='weight decay') + parser.add_argument('--momentum', type=float, default=0.9, help='momentum for SGD') + parser.add_argument('--amp_opt_level', type=str, default='O1', choices=['O0', 'O1', 'O2'], + help='mixed precision opt level, if O0, no amp is used') + parser.add_argument('--start_epoch', type=int, default=1, help='used for resume') + parser.add_argument('--epochs', type=int, default=100, help='number of training epochs') + + # misc + parser.add_argument('--output_dir', type=str, default='./output', help='output director') + parser.add_argument('--auto_resume', action='store_true', help='auto resume from current.pth') + parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='path to latest checkpoint') + parser.add_argument('--print_freq', type=int, default=100, help='print frequency') + parser.add_argument('--save_freq', type=int, default=10, help='save frequency') + parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel') + if stage == 'linear': + parser.add_argument('--pretrained_model', type=str, required=True, help="pretrained model path") + parser.add_argument('-e', '--eval', action='store_true', help='only evaluate') + else: + parser.add_argument('--pretrained_model', type=str, default="", help="pretrained model path") + + # selective search + parser.add_argument('--ss_props', action='store_true', help='use selective search propos to calculate weight map') + parser.add_argument('--filter_strategy', type=str, default='none', help='filter strategy') + parser.add_argument('--select_strategy', type=str, default='none', help='select strategy') + parser.add_argument('--select_k', type=int, default=0, help='select strategy k, required when select strategy is not none') + parser.add_argument('--weight_strategy', type=str, default='bbox', help='weight map strategy') + # we use same select_strategy for weight map and bbox to ins + parser.add_argument('--bbox_size_range', type=tuple, default=(32, 112, 224)) + parser.add_argument('--iou_thres', type=float, default=0.5) + parser.add_argument('--output_size', type=int, default=7) + parser.add_argument('--aligned', action='store_true') + parser.add_argument('--jitter_prob', type=float, default=0.5) + parser.add_argument('--jitter_ratio', type=float, default=0.2) + parser.add_argument('--padding_k', type=int, default=32) + parser.add_argument('--max_tries', type=int, default=5) + parser.add_argument('--aware_range', type=list, default=[48, 96, 192, 224, 0]) + parser.add_argument('--aware_start', type=int, default=0, help="starting from using P?") + 
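These flags are consumed by `get_scheduler` in `contrast/lr_scheduler.py`: warmup for `--warmup_epoch` epochs, then cosine (or step) decay, with the scheduler stepped once per iteration. A hedged sketch with made-up values:

```python
# Sketch of building the warmup + cosine schedule from the flags above.
# The Namespace values, learning rate and iteration counts are illustrative.
from argparse import Namespace

import torch
from contrast.lr_scheduler import get_scheduler

args = Namespace(lr_scheduler='cosine', epochs=100, warmup_epoch=5, warmup_multiplier=100,
                 lr_decay_rate=0.1, lr_decay_epochs=[120, 160, 200])

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.4)   # illustrative final lr (scaled from --base_lr)
n_iter_per_epoch = 10
scheduler = get_scheduler(optimizer, n_iter_per_epoch, args)

for _ in range(3 * n_iter_per_epoch):   # the training loop steps the scheduler every iteration
    optimizer.step()
    scheduler.step()
print(optimizer.param_groups[0]['lr'])
```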
parser.add_argument('--aware_end', type=int, default=4, help="ending from using P?, not included") + + parser.add_argument('--cutout_prob', type=float, default=0.5) + parser.add_argument('--cutout_ratio', type=tuple) + parser.add_argument('--cutout_ratio_min', type=float, default=0.1) + parser.add_argument('--cutout_ratio_max', type=float, default=0.2) + + parser.add_argument('--max_props', type=int, default=32) + parser.add_argument('--aspect_ratio', type=float, default=3) + parser.add_argument('--min_size_ratio', type=float, default=0.3) + parser.add_argument('--max_size_ratio', type=float, default=0.8) + + args = parser.parse_args() + + if stage == 'pre_train': + # Due to base command line can not directly pass bool values, we use int, 0 -> False, 1 -> True + args.add_extra_convs = bool(args.add_extra_convs) + args.extra_convs_on_inputs = bool(args.extra_convs_on_inputs) + args.relu_before_extra_convs = bool(args.relu_before_extra_convs) + args.no_norm_on_lateral = bool(args.no_norm_on_lateral) + args.cutout_ratio = (args.cutout_ratio_min, args.cutout_ratio_max) + + return args diff --git a/contrast/resnet.py b/contrast/resnet.py new file mode 100644 index 0000000..ab1140f --- /dev/null +++ b/contrast/resnet.py @@ -0,0 +1,341 @@ +import math +import os + +import torch +import torch.nn as nn +import torchvision.ops as tvops + +from .models.box_util import append_batch_index_to_bboxs_and_scale + + +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', + 'resnet18_d', 'resnet34_d', 'resnet50_d', 'resnet101_d', 'resnet152_d', + 'resnet50_16s', 'resnet50_w2x', 'resnext101_32x8d', 'resnext152_32x8d'] + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + + +def conv3x3_bn_relu(in_planes, out_planes, stride=1): + return nn.Sequential( + conv3x3(in_planes, out_planes, stride), + nn.BatchNorm2d(out_planes), + nn.ReLU() + ) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, + groups=1, base_width=64, dilation=-1): + super(BasicBlock, self).__init__() + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, + groups=1, base_width=64, dilation=1): + super(Bottleneck, self).__init__() + width = int(planes * (base_width / 64.)) * groups + self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(width) + self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, dilation=dilation, + padding=dilation, groups=groups, bias=False) + self.bn2 = nn.BatchNorm2d(width) + self.conv3 = nn.Conv2d(width, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + 
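+        # bottleneck layout: 1x1 reduce -> 3x3 (carries the stride / dilation) -> 1x1 expand to
+        # planes * expansion (= 4x) channels; downsample, when given, reshapes the residual branch to match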
self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, in_channel=3, width=1, + groups=1, width_per_group=64, + mid_dim=1024, low_dim=128, + avg_down=False, deep_stem=False, + head_type='mlp_head', layer4_dilation=1, use_roi_align_on_c4=False, + pretrained=None): + super(ResNet, self).__init__() + self.avg_down = avg_down + self.inplanes = 64 * width + self.base = int(64 * width) + self.groups = groups + self.base_width = width_per_group + self.use_roi_align_on_c4 = use_roi_align_on_c4 + + mid_dim = self.base * 8 * block.expansion + + if deep_stem: + self.conv1 = nn.Sequential( + conv3x3_bn_relu(in_channel, 32, stride=2), + conv3x3_bn_relu(32, 32, stride=1), + conv3x3(32, 64, stride=1) + ) + else: + self.conv1 = nn.Conv2d(in_channel, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) + + self.bn1 = nn.BatchNorm2d(self.inplanes) + self.relu = nn.ReLU(inplace=True) + + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, self.base, layers[0]) + self.layer2 = self._make_layer(block, self.base * 2, layers[1], stride=2) + self.layer3 = self._make_layer(block, self.base * 4, layers[2], stride=2) + if layer4_dilation == 1: + self.layer4 = self._make_layer(block, self.base * 8, layers[3], stride=2) + elif layer4_dilation == 2: + self.layer4 = self._make_layer(block, self.base * 8, layers[3], stride=1, dilation=2) + else: + raise NotImplementedError + self.avgpool = nn.AvgPool2d(7, stride=1) + self.small_avgpool = nn.AvgPool2d(4, stride=1) + + if self.use_roi_align_on_c4: + self.output_size = 14 + self.aligned = True + + self.head_type = head_type + if head_type == 'mlp_head': + self.fc1 = nn.Linear(mid_dim, mid_dim) + self.relu2 = nn.ReLU(inplace=True) + self.fc2 = nn.Linear(mid_dim, low_dim) + elif head_type == 'reduce': + self.fc = nn.Linear(mid_dim, low_dim) + elif head_type == 'conv_head': + self.fc1 = nn.Conv2d(mid_dim, mid_dim, kernel_size=1, bias=False) + self.bn2 = nn.BatchNorm2d(2048) + self.relu2 = nn.ReLU(inplace=True) + self.fc2 = nn.Linear(mid_dim, low_dim) + elif head_type in ['pass', 'early_return', 'multi_layer']: + pass + else: + raise NotImplementedError + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + # zero gamma for batch norm: reference bag of tricks + if block is Bottleneck: + gamma_name = "bn3.weight" + elif block is BasicBlock: + gamma_name = "bn2.weight" + else: + raise RuntimeError(f"block {block} not supported") + for name, value in self.named_parameters(): + if name.endswith(gamma_name): + value.data.zero_() + + if pretrained: + self.pretrained = pretrained + self._init_weights_pretrained() + + def _init_weights_pretrained(self): + print(f" => loading pretrained resnet: {self.pretrained}") + assert os.path.isfile(self.pretrained) + ckpt = torch.load(self.pretrained, map_location='cpu') + missing_keys, unexpected_keys = self.load_state_dict(ckpt['state_dict'], strict=False) + print(f"missing_keys: {missing_keys}") + print(f"unexpected_keys: {unexpected_keys}") + + def _make_layer(self, block, planes, blocks, stride=1, dilation=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + if self.avg_down: + downsample = nn.Sequential( + nn.AvgPool2d(kernel_size=stride, stride=stride), + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=1, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + else: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, dilation)] + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=dilation)) + + return nn.Sequential(*layers) + + def roi_align_feature_map(self, feature_map, bboxs, small=False): + feature_map = feature_map.type(dtype=bboxs.dtype) # feature map will be convert to HalfFloat in favor of amp + N, C, H, W = feature_map.shape + N, L, _ = bboxs.shape + + if small: + output_size = (self.output_size // 2, self.output_size // 2) + else: + output_size = (self.output_size, self.output_size) + + bboxs_q_with_batch_index = append_batch_index_to_bboxs_and_scale(bboxs, H, W) + aligned_features = tvops.roi_align(input=feature_map, boxes=bboxs_q_with_batch_index, output_size=output_size, aligned=self.aligned) + # N*L, C, output_size, output_size + return aligned_features + + def forward(self, x, bboxs=None, small=False): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + c2 = self.layer1(x) + c3 = self.layer2(c2) + c4 = self.layer3(c3) + if self.use_roi_align_on_c4 and bboxs is not None: + c4align = self.roi_align_feature_map(c4, bboxs, small=small) + c5 = self.layer4(c4align) + else: + c5 = self.layer4(c4) + + if self.head_type == 'multi_layer': + return c2, c3, c4, c5 + + if self.head_type == 'early_return': + return c5 + + if self.head_type != 'conv_head': + if small: + c5 = self.small_avgpool(c5) + else: + c5 = self.avgpool(c5) + c5 = c5.view(c5.size(0), -1) + + if self.head_type == 'mlp_head': + out = self.fc1(c5) + out = self.relu2(out) + out = self.fc2(out) + elif self.head_type == 'reduce': + out = self.fc(c5) + elif self.head_type == 'conv_head': + out = self.fc1(c5) + out = self.bn2(out) + out = self.relu2(out) + out = self.avgpool(out) + out = out.view(out.size(0), -1) + out = self.fc2(out) + elif self.head_type == 'pass': + return c5 + else: + raise NotImplementedError + + return out + + +def resnet18(**kwargs): + return 
ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + + +def resnet18_d(**kwargs): + return ResNet(BasicBlock, [2, 2, 2, 2], deep_stem=True, avg_down=True, **kwargs) + + +def resnet34(**kwargs): + return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + + +def resnet34_d(**kwargs): + return ResNet(BasicBlock, [3, 4, 6, 3], deep_stem=True, avg_down=True, **kwargs) + + +def resnet50(**kwargs): + return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + + +def resnet50_w2x(**kwargs): + return ResNet(Bottleneck, [3, 4, 6, 3], width=2, **kwargs) + + +def resnet50_16s(**kwargs): + return ResNet(Bottleneck, [3, 4, 6, 3], layer4_dilation=2, **kwargs) + + +def resnet50_d(**kwargs): + return ResNet(Bottleneck, [3, 4, 6, 3], deep_stem=True, avg_down=True, **kwargs) + + +def resnet101(**kwargs): + return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + + +def resnet101_d(**kwargs): + return ResNet(Bottleneck, [3, 4, 23, 3], deep_stem=True, avg_down=True, **kwargs) + + +def resnext101_32x8d(**kwargs): + return ResNet(Bottleneck, [3, 4, 23, 3], groups=32, width_per_group=8, **kwargs) + + +def resnet152(**kwargs): + return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) + + +def resnet152_d(**kwargs): + return ResNet(Bottleneck, [3, 8, 36, 3], deep_stem=True, avg_down=True, **kwargs) + + +def resnext152_32x8d(**kwargs): + return ResNet(Bottleneck, [3, 8, 36, 3], groups=32, width_per_group=8, **kwargs) + diff --git a/contrast/util.py b/contrast/util.py new file mode 100644 index 0000000..4b39198 --- /dev/null +++ b/contrast/util.py @@ -0,0 +1,146 @@ +# -------------------------------------------------------- +# SoCo +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Yue Gao +# -------------------------------------------------------- + + +import argparse + +import torch +import torch.distributed as dist + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def dist_collect(x): + """ collect all tensor from all GPUs + args: + x: shape (mini_batch, ...) + returns: + shape (mini_batch * num_gpu, ...) + """ + x = x.contiguous() + out_list = [torch.zeros_like(x, device=x.device, dtype=x.dtype) + for _ in range(dist.get_world_size())] + dist.all_gather(out_list, x) + return torch.cat(out_list, dim=0) + + +def reduce_tensor(tensor): + rt = tensor.clone() + dist.all_reduce(rt, op=dist.ReduceOp.SUM) + rt /= dist.get_world_size() + return rt + + +class MyHelpFormatter(argparse.MetavarTypeHelpFormatter, argparse.ArgumentDefaultsHelpFormatter): + pass + +class DistributedShuffle: + + @staticmethod + def forward_shuffle(x): + """ + Batch shuffle, for making use of BatchNorm. + *** Only support DistributedDataParallel (DDP) model. 
*** + """ + # gather from all gpus + batch_size_this = x.shape[0] + x_gather = dist_collect(x) + batch_size_all = x_gather.shape[0] + + num_gpus = batch_size_all // batch_size_this + + # random shuffle index + idx_shuffle = torch.randperm(batch_size_all).cuda() + + # broadcast to all gpus + dist.broadcast(idx_shuffle, src=0) + + # index for restoring + idx_unshuffle = torch.argsort(idx_shuffle) + + # shuffled index for this gpu + gpu_idx = dist.get_rank() + idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] + + return x_gather[idx_this], idx_unshuffle + + @staticmethod + def backward_shuffle(x, idx_unshuffle, return_local=True): + """ + Undo batch shuffle. + *** Only support DistributedDataParallel (DDP) model. *** + """ + # gather from all gpus + batch_size_this = x.shape[0] + x_gather = dist_collect(x) + batch_size_all = x_gather.shape[0] + + num_gpus = batch_size_all // batch_size_this + + if return_local: + # restored index for this gpu + gpu_idx = dist.get_rank() + idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] + return x_gather[idx_unshuffle], x_gather[idx_this] + else: + return x_gather[idx_unshuffle] + + @staticmethod + def get_local_id(ids): + return ids.chunk(dist.get_world_size())[dist.get_rank()] + + @staticmethod + def get_shuffle_ids(bsz, epoch): + """generate shuffle ids for ShuffleBN""" + torch.manual_seed(epoch) + # global forward shuffle id for all process + forward_inds = torch.randperm(bsz).long().cuda() + + # global backward shuffle id + backward_inds = torch.zeros(forward_inds.shape[0]).long().cuda() + value = torch.arange(bsz).long().cuda() + backward_inds.index_copy_(0, forward_inds, value) + + return forward_inds, backward_inds diff --git a/converter_detectron2/convert_detectron2_C4.py b/converter_detectron2/convert_detectron2_C4.py new file mode 100644 index 0000000..619ec4d --- /dev/null +++ b/converter_detectron2/convert_detectron2_C4.py @@ -0,0 +1,63 @@ +# -------------------------------------------------------- +# SoCo +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Yue Gao +# -------------------------------------------------------- + + +import pickle as pkl +import torch +import argparse + + +def convert_detectron2_C4(input_file_name, output_file_name, ema=False): + ckpt = torch.load(input_file_name, map_location="cpu") + if ema: + state_dict = ckpt["model_ema"] + prefix = "encoder." + else: + state_dict = ckpt["model"] + prefix = "module.encoder." + + new_state_dict = {} + for k, v in state_dict.items(): + if not k.startswith(prefix): + continue + old_k = k + k = k.replace(prefix, "") + if "layer" not in k: + k = "stem." + k + # k = "backbone." 
+ k + k = k.replace("layer1", "res2") + k = k.replace("layer2", "res3") + k = k.replace("layer3", "res4") + k = k.replace("layer4", "res5") + k = k.replace("bn1", "conv1.norm") + k = k.replace("bn2", "conv2.norm") + k = k.replace("bn3", "conv3.norm") + k = k.replace("downsample.0", "shortcut") + k = k.replace("downsample.1", "shortcut.norm") + print(old_k, "->", k) + new_state_dict[k] = v.numpy() + + res = {"model": new_state_dict, + "__author__": "PixPro", + "matching_heuristics": True} + + with open(output_file_name, "wb") as f: + pkl.dump(res, f) + print(f"Saved converted detectron2 C4 checkpoint to {output_file_name}") + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description='Convert Models') + parser.add_argument('input', metavar='I', + help='input model path') + parser.add_argument('output', metavar='O', + help='output path') + parser.add_argument('--ema', action='store_true', + help='using ema model') + args = parser.parse_args() + convert_detectron2_C4(args.input, args.output, args.ema) diff --git a/converter_detectron2/convert_detectron2_Head.py b/converter_detectron2/convert_detectron2_Head.py new file mode 100644 index 0000000..6501d6f --- /dev/null +++ b/converter_detectron2/convert_detectron2_Head.py @@ -0,0 +1,97 @@ +# -------------------------------------------------------- +# SoCo +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Yue Gao +# -------------------------------------------------------- + + +import pickle as pkl +import torch +import argparse + + +def convert_detectron2_Head(input_file_name, output_file_name, start, num_outs, ema=False): + ckpt = torch.load(input_file_name, map_location="cpu") + if ema: + state_dict = ckpt["model_ema"] + backbone_prefix = "encoder." + fpn_prefix = "neck." + head_prefix = "head." + else: + state_dict = ckpt["model"] + backbone_prefix = "module.encoder." + fpn_prefix = "module.neck." + head_prefix = "module.head." + + new_state_dict = {} + for k, v in state_dict.items(): + if k.startswith(backbone_prefix): + old_k = k + k = k.replace(backbone_prefix, "") + if "layer" not in k: + k = "stem." 
+ k + k = k.replace("layer1", "res2") + k = k.replace("layer2", "res3") + k = k.replace("layer3", "res4") + k = k.replace("layer4", "res5") + k = k.replace("bn1", "conv1.norm") + k = k.replace("bn2", "conv2.norm") + k = k.replace("bn3", "conv3.norm") + k = k.replace("downsample.0", "shortcut") + k = k.replace("downsample.1", "shortcut.norm") + print(old_k, "->", k) + new_state_dict[k] = v.numpy() + + elif k.startswith(fpn_prefix): + old_k = k + k = k.replace(fpn_prefix, "") + + for i in range(num_outs): + k = k.replace(f"lateral_convs.{i}.conv", f"fpn_lateral{start+i}") + k = k.replace(f"lateral_convs.{i}.bn", f"fpn_lateral{start+i}.norm") + + k = k.replace(f"fpn_convs.{i}.conv", f"fpn_output{start+i}") + k = k.replace(f"fpn_convs.{i}.bn", f"fpn_output{start+i}.norm") + + print(old_k, "->", k) + new_state_dict[k] = v.numpy() + + elif k.startswith(head_prefix): + old_k = k + k = k.replace(head_prefix, "box_head.") + k = k.replace("bn1", "conv1.norm") + k = k.replace("ac1", "conv1.activation") + k = k.replace("bn2", "conv2.norm") + k = k.replace("ac2", "conv2.activation") + k = k.replace("bn3", "conv3.norm") + k = k.replace("ac3", "conv3.activation") + k = k.replace("bn4", "conv4.norm") + k = k.replace("ac4", "conv4.activation") + print(old_k, "->", k) + new_state_dict[k] = v.numpy() + + res = {"model": new_state_dict, + "__author__": "Yue", + "matching_heuristics": True} + + with open(output_file_name, "wb") as f: + pkl.dump(res, f) + print(f"Saved converted detectron2 Head checkpoint to {output_file_name}") + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description='Convert Models') + parser.add_argument('input', metavar='I', + help='input model path') + parser.add_argument('output', metavar='O', + help='output path') + parser.add_argument('start', metavar='S', type=int, + help='FPN start') + parser.add_argument('num_outs', metavar='N', type=int, + help='FPN number of outputs') + parser.add_argument('--ema', action='store_true', + help='using ema model') + args = parser.parse_args() + convert_detectron2_Head(args.input, args.output, args.start, args.num_outs, args.ema) diff --git a/converter_mmdetection/convert_mmdetection_Backbone.py b/converter_mmdetection/convert_mmdetection_Backbone.py new file mode 100644 index 0000000..f331562 --- /dev/null +++ b/converter_mmdetection/convert_mmdetection_Backbone.py @@ -0,0 +1,40 @@ +# disclaimer: inspired by MoCo official repo. +import torch +import argparse + + +def convert_mmdetection_Backbone(input_file_name, output_file_name, ema=False): + ckpt = torch.load(input_file_name, map_location="cpu") + if ema: + state_dict = ckpt["model_ema"] + backbone_prefix = "encoder." + else: + state_dict = ckpt["model"] + backbone_prefix = "module.encoder." 
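+    # Only backbone weights are exported: the loop below strips the DDP prefix
+    # ('module.encoder.' for the trained model, 'encoder.' for the EMA copy) and re-roots the keys
+    # under mmdetection's 'backbone.' namespace; all other weights are dropped.
+    # Illustrative usage (paths are placeholders):
+    #   python converter_mmdetection/convert_mmdetection_Backbone.py ./output/current.pth \
+    #       ./output/current_mmdetection_Backbone.pth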
+ + new_state_dict = {} + for k, v in state_dict.items(): + if k.startswith(backbone_prefix): + old_k = k + k = k.replace(backbone_prefix, "backbone.") + print(old_k, "->", k) + new_state_dict[k] = v + + res = {"state_dict": new_state_dict, + "meta": {}} + + torch.save(res, output_file_name) + print(f"Saved converted mmdetection load checkpoint to {output_file_name}") + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description='Convert Models') + parser.add_argument('input', metavar='I', + help='input model path') + parser.add_argument('output', metavar='O', + help='output path') + parser.add_argument('--ema', action='store_true', + help='using ema model') + args = parser.parse_args() + convert_mmdetection_Backbone(args.input, args.output, args.ema) diff --git a/converter_mmdetection/convert_mmdetection_Head.py b/converter_mmdetection/convert_mmdetection_Head.py new file mode 100644 index 0000000..a83c595 --- /dev/null +++ b/converter_mmdetection/convert_mmdetection_Head.py @@ -0,0 +1,73 @@ +# disclaimer: inspired by MoCo official repo. +import torch +import argparse + + +def convert_mmdetection_Head(input_file_name, output_file_name, ema=False): + ckpt = torch.load(input_file_name, map_location="cpu") + if ema: + state_dict = ckpt["model_ema"] + backbone_prefix = "encoder." + fpn_prefix = "neck." + head_prefix = "head." + else: + state_dict = ckpt["model"] + backbone_prefix = "module.encoder." + fpn_prefix = "module.neck." + head_prefix = "module.head." + + new_state_dict = {} + for k, v in state_dict.items(): + if k.startswith(backbone_prefix): + old_k = k + k = k.replace(backbone_prefix, "backbone.") + print(old_k, "->", k) + new_state_dict[k] = v + elif k.startswith(fpn_prefix): + old_k = k + k = k.replace(fpn_prefix, "neck.") + + print(old_k, "->", k) + new_state_dict[k] = v + elif k.startswith(head_prefix): + old_k = k + k = k.replace(head_prefix, "roi_head.bbox_head.") + k = k.replace("conv1", "shared_convs.0.conv") + k = k.replace("bn1", "shared_convs.0.bn") + k = k.replace("ac1", "shared_convs.0.activate") + + k = k.replace("conv2", "shared_convs.1.conv") + k = k.replace("bn2", "shared_convs.1.bn") + k = k.replace("ac2", "shared_convs.1.activate") + + k = k.replace("conv3", "shared_convs.2.conv") + k = k.replace("bn3", "shared_convs.2.bn") + k = k.replace("ac3", "shared_convs.2.activate") + + k = k.replace("conv4", "shared_convs.3.conv") + k = k.replace("bn4", "shared_convs.3.bn") + k = k.replace("ac4", "shared_convs.3.activate") + + k = k.replace("fc1", "shared_fcs.0") + + print(old_k, "->", k) + new_state_dict[k] = v + + res = {"state_dict": new_state_dict, + "meta": {}} + + torch.save(res, output_file_name) + print(f"Saved converted mmdetection load checkpoint to {output_file_name}") + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description='Convert Models') + parser.add_argument('input', metavar='I', + help='input model path') + parser.add_argument('output', metavar='O', + help='output path') + parser.add_argument('--ema', action='store_true', + help='using ema model') + args = parser.parse_args() + convert_mmdetection_Head(args.input, args.output, args.ema) diff --git a/detectron2_configs/R_50_C4_1x.yaml b/detectron2_configs/R_50_C4_1x.yaml new file mode 100644 index 0000000..aaf5930 --- /dev/null +++ b/detectron2_configs/R_50_C4_1x.yaml @@ -0,0 +1,13 @@ +_BASE_: "Base-RCNN-C4-BN.yaml" +MODEL: + MASK_ON: True + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" +INPUT: + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) + 
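+  # multi-scale training: the shorter image side is sampled from 640-800 px; testing uses the single 800 px scale below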
MIN_SIZE_TEST: 800 +DATASETS: + TRAIN: ("coco_2017_train",) + TEST: ("coco_2017_val",) +SOLVER: + STEPS: (60000, 80000) + MAX_ITER: 90000 diff --git a/detectron2_configs/R_50_FPN_1x.yaml b/detectron2_configs/R_50_FPN_1x.yaml new file mode 100644 index 0000000..1c87ae8 --- /dev/null +++ b/detectron2_configs/R_50_FPN_1x.yaml @@ -0,0 +1,24 @@ +_BASE_: "Base-RCNN-FPN.yaml" +MODEL: + MASK_ON: True + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + BACKBONE: + FREEZE_AT: 0 + RESNETS: + DEPTH: 50 + NORM: "SyncBN" + FPN: + NORM: "SyncBN" + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_CONV: 4 + NUM_FC: 1 + NORM: "SyncBN" + ROI_MASK_HEAD: + NORM: "SyncBN" +TEST: + PRECISE_BN: + ENABLED: True +SOLVER: + STEPS: (60000, 80000) + MAX_ITER: 90000 diff --git a/detectron2_configs/SoCo_C4_100ep.yaml b/detectron2_configs/SoCo_C4_100ep.yaml new file mode 100644 index 0000000..550f46a --- /dev/null +++ b/detectron2_configs/SoCo_C4_100ep.yaml @@ -0,0 +1,11 @@ +_BASE_: "R_50_C4_1x.yaml" +MODEL: + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + WEIGHTS: "./SoCo_output/SoCo_C4_100ep/current_detectron2_C4.pkl" + RESNETS: + STRIDE_IN_1X1: False +INPUT: + FORMAT: "RGB" +SOLVER: + WEIGHT_DECAY: 0.000025 diff --git a/detectron2_configs/SoCo_C4_400ep.yaml b/detectron2_configs/SoCo_C4_400ep.yaml new file mode 100644 index 0000000..2c4bd1b --- /dev/null +++ b/detectron2_configs/SoCo_C4_400ep.yaml @@ -0,0 +1,11 @@ +_BASE_: "R_50_C4_1x.yaml" +MODEL: + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + WEIGHTS: "./SoCo_output/SoCo_C4_400ep/current_detectron2_C4.pkl" + RESNETS: + STRIDE_IN_1X1: False +INPUT: + FORMAT: "RGB" +SOLVER: + WEIGHT_DECAY: 0.000025 diff --git a/detectron2_configs/SoCo_FPN_100ep.yaml b/detectron2_configs/SoCo_FPN_100ep.yaml new file mode 100644 index 0000000..bfcdb0e --- /dev/null +++ b/detectron2_configs/SoCo_FPN_100ep.yaml @@ -0,0 +1,11 @@ +_BASE_: "R_50_FPN_1x.yaml" +MODEL: + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + WEIGHTS: "./SoCo_output/SoCo_FPN_100ep/current_detectron2_Head.pkl" + RESNETS: + STRIDE_IN_1X1: False +INPUT: + FORMAT: "RGB" +SOLVER: + WEIGHT_DECAY: 0.000025 diff --git a/detectron2_configs/SoCo_FPN_400ep.yaml b/detectron2_configs/SoCo_FPN_400ep.yaml new file mode 100644 index 0000000..7b7bf32 --- /dev/null +++ b/detectron2_configs/SoCo_FPN_400ep.yaml @@ -0,0 +1,11 @@ +_BASE_: "R_50_FPN_1x.yaml" +MODEL: + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + WEIGHTS: "./SoCo_output/SoCo_FPN_400ep/current_detectron2_Head.pkl" + RESNETS: + STRIDE_IN_1X1: False +INPUT: + FORMAT: "RGB" +SOLVER: + WEIGHT_DECAY: 0.000025 diff --git a/detectron2_configs/SoCo_FPN_Star_400ep.yaml b/detectron2_configs/SoCo_FPN_Star_400ep.yaml new file mode 100644 index 0000000..e376271 --- /dev/null +++ b/detectron2_configs/SoCo_FPN_Star_400ep.yaml @@ -0,0 +1,11 @@ +_BASE_: "R_50_FPN_1x.yaml" +MODEL: + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + WEIGHTS: "./SoCo_output/SoCo_FPN_Star_400ep/current_detectron2_Head.pkl" + RESNETS: + STRIDE_IN_1X1: False +INPUT: + FORMAT: "RGB" +SOLVER: + WEIGHT_DECAY: 0.000025 diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000..de022d4 --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,52 @@ +# -------------------------------------------------------- +# SoCo +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Yue Gao +# 
-------------------------------------------------------- + + +FROM nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04 + +# ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0+PTX" +ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all" +ENV CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" + +RUN apt-get update && apt-get install -y vim curl wget ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libxrender-dev \ + && apt-get install -y libblacs-mpi-dev \ + && apt-get install software-properties-common \ + && add-apt-repository ppa:deadsnakes/ppa \ + && apt-get install python3.7 \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +# Install python 3.7 to python and install pip +RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.7 10 +RUN wget https://bootstrap.pypa.io/get-pip.py +RUN python get-pip.py + +# Torch +RUN pip install torch==1.6.0 torchvision==0.7.0 + +# accimage +RUN pip install --prefix=/opt/intel/ipp ipp-devel +RUN pip install git+https://github.com/pytorch/accimage + +# apex +RUN git clone https://github.com/NVIDIA/apex +RUN cd apex/ +# P100, P40, V100 and 2080Ti +ENV TORCH_CUDA_ARCH_LIST="6.0;6.1;7.0;7.5" +RUN pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ + +# detectron2 +RUN python -m pip install 'git+https://github.com/hologerry/detectron2.git' --upgrade --force-reinstall + +# Install MMCV and MMdetection +RUN pip uninstall pycocotools +RUN pip install mmcv-full==1.2.4 -f https://download.openmmlab.com/mmcv/dist/cu102/torch1.6.0/index.html --upgrade --force-reinstall + +RUN git clone https://github.com/hologerry/mmdetection.git +RUN cd mmdetection +RUN pip install -r requirements/build.txt --upgrade --force-reinstall +RUN pip install -v -e . --upgrade --force-reinstall diff --git a/figures/overview.png b/figures/overview.png new file mode 100644 index 0000000..15e4d21 Binary files /dev/null and b/figures/overview.png differ diff --git a/main_linear.py b/main_linear.py new file mode 100644 index 0000000..1ce8018 --- /dev/null +++ b/main_linear.py @@ -0,0 +1,303 @@ +# -------------------------------------------------------- +# SoCo +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Yue Gao +# -------------------------------------------------------- + + +import json +import os +import time + +import torch +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data.distributed import DistributedSampler +from torch.utils.tensorboard import SummaryWriter + +from contrast import resnet +from contrast.data import get_loader +from contrast.logger import setup_logger +from contrast.lr_scheduler import get_scheduler +from contrast.option import parse_option +from contrast.util import AverageMeter, accuracy, reduce_tensor + +try: + from apex import amp # type: ignore +except ImportError: + amp = None + + +def build_model(args, num_class): + # create model + model = resnet.__dict__[args.arch](low_dim=num_class, head_type='reduce').cuda() + + # set requires_grad of parameters except last fc layer to False + for name, p in model.named_parameters(): + if 'fc' not in name: + p.requires_grad = False + + optimizer = torch.optim.SGD(model.fc.parameters(), + lr=args.learning_rate, + momentum=args.momentum, + weight_decay=args.weight_decay) + + if args.amp_opt_level != "O0": + model, optimizer = amp.initialize(model, optimizer, opt_level=args.amp_opt_level) + + model = 
DistributedDataParallel(model, device_ids=[args.local_rank], broadcast_buffers=False) + + return model, optimizer + + +def load_pretrained(model, pretrained_model): + ckpt = torch.load(pretrained_model, map_location='cpu') + model_dict = model.state_dict() + + base_fix = False + for key in ckpt['model'].keys(): + if key.startswith('module.base.'): + base_fix = True + break + + if base_fix: + state_dict = {k.replace("module.base.", "module."): v + for k, v in ckpt['model'].items() + if k.startswith('module.base.')} + logger.info(f"==> load checkpoint from Module.Base") + else: + state_dict = {k.replace("module.encoder.", "module."): v + for k, v in ckpt['model'].items() + if k.startswith('module.encoder.')} + logger.info(f"==> load checkpoint from Module.Encoder") + + state_dict = {k: v for k, v in state_dict.items() + if k in model_dict and v.size() == model_dict[k].size()} + + model_dict.update(state_dict) + model.load_state_dict(model_dict) + logger.info(f"==> loaded checkpoint '{pretrained_model}' (epoch {ckpt['epoch']})") + + +def load_checkpoint(args, model, optimizer, scheduler): + logger.info("=> loading checkpoint '{args.resume'") + + checkpoint = torch.load(args.resume, map_location='cpu') + + global best_acc1 + best_acc1 = checkpoint['best_acc1'] + args.start_epoch = checkpoint['epoch'] + 1 + model.load_state_dict(checkpoint['model']) + optimizer.load_state_dict(checkpoint['optimizer']) + scheduler.load_state_dict(checkpoint['scheduler']) + if args.amp_opt_level != "O0" and checkpoint['args'].amp_opt_level != "O0": + amp.load_state_dict(checkpoint['amp']) + + logger.info(f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})") + + +def save_checkpoint(args, epoch, model, test_acc, optimizer, scheduler): + state = { + 'args': args, + 'epoch': epoch, + 'model': model.state_dict(), + 'best_acc1': test_acc, + 'optimizer': optimizer.state_dict(), + 'scheduler': scheduler.state_dict(), + } + if args.amp_opt_level != "O0": + state['amp'] = amp.state_dict() + torch.save(state, os.path.join(args.output_dir, f'ckpt_epoch_{epoch}.pth')) + torch.save(state, os.path.join(args.output_dir, f'current.pth')) + + +def main(args): + global best_acc1 + + args.batch_size = args.total_batch_size // dist.get_world_size() + train_loader = get_loader(args.aug, args, prefix='train') + val_loader = get_loader('val', args, prefix='val') + logger.info(f"length of training dataset: {len(train_loader.dataset)}") + + model, optimizer = build_model(args, num_class=len(train_loader.dataset.classes)) + scheduler = get_scheduler(optimizer, len(train_loader), args) + + # load pre-trained model + load_pretrained(model, args.pretrained_model) + + # optionally resume from a checkpoint + if args.auto_resume: + resume_file = os.path.join(args.output_dir, "current.pth") + if os.path.exists(resume_file): + logger.info(f'auto resume from {resume_file}') + args.resume = resume_file + else: + logger.info(f'no checkpoint found in {args.output_dir}, ignoring auto resume') + if args.resume: + assert os.path.isfile(args.resume), f"no checkpoint found at '{args.resume}'" + load_checkpoint(args, model, optimizer, scheduler) + + if args.eval: + logger.info("==> testing...") + validate(val_loader, model, args) + return + + # tensorboard + if dist.get_rank() == 0: + summary_writer = SummaryWriter(log_dir=args.output_dir) + else: + summary_writer = None + + # routine + for epoch in range(args.start_epoch, args.epochs + 1): + if isinstance(train_loader.sampler, DistributedSampler): + train_loader.sampler.set_epoch(epoch) + 
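+        # DistributedSampler uses the epoch as its shuffle seed, so the ordering changes every epoch
+        # while staying consistent across processes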
+ tic = time.time() + train(epoch, train_loader, model, optimizer, scheduler, args) + logger.info(f'epoch {epoch}, total time {time.time() - tic:.2f}') + + logger.info("==> testing...") + test_acc, test_acc5, test_loss = validate(val_loader, model, args) + if summary_writer is not None: + summary_writer.add_scalar('test_acc', test_acc, epoch) + summary_writer.add_scalar('test_acc5', test_acc5, epoch) + summary_writer.add_scalar('test_loss', test_loss, epoch) + + # save model + if dist.get_rank() == 0 and epoch % args.save_freq == 0: + logger.info('==> Saving...') + save_checkpoint(args, epoch, model, test_acc, optimizer, scheduler) + + +def train(epoch, train_loader, model, optimizer, scheduler, args): + """ + one epoch training + """ + + model.train() + + batch_time = AverageMeter() + data_time = AverageMeter() + loss_meter = AverageMeter() + acc1_meter = AverageMeter() + acc5_meter = AverageMeter() + + end = time.time() + for idx, (x, _, y) in enumerate(train_loader): + x = x.cuda(non_blocking=True) + y = y.cuda(non_blocking=True) + + # measure data loading time + data_time.update(time.time() - end) + + # forward + output = model(x) + loss = F.cross_entropy(output, y) + + # backward + optimizer.zero_grad() + if args.amp_opt_level != "O0": + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() + optimizer.step() + scheduler.step() + + # update meters + acc1, acc5 = accuracy(output, y, topk=(1, 5)) + loss_meter.update(loss.item(), x.size(0)) + acc1_meter.update(acc1[0], x.size(0)) + acc5_meter.update(acc5[0], x.size(0)) + batch_time.update(time.time() - end) + end = time.time() + + # print info + if idx % args.print_freq == 0: + logger.info( + f'Epoch: [{epoch}][{idx}/{len(train_loader)}]\t' + f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + f'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' + f'Lr {optimizer.param_groups[0]["lr"]:.3f} \t' + f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' + f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t' + f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})') + + return acc1_meter.avg, acc5_meter.avg, loss_meter.avg + + +def validate(val_loader, model, args): + batch_time = AverageMeter() + loss_meter = AverageMeter() + acc1_meter = AverageMeter() + acc5_meter = AverageMeter() + + # switch to evaluate mode + model.eval() + + with torch.no_grad(): + end = time.time() + for idx, (x, _, y) in enumerate(val_loader): + x = x.cuda(non_blocking=True) + y = y.cuda(non_blocking=True) + + # compute output + output = model(x) + loss = F.cross_entropy(output, y) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, y, topk=(1, 5)) + + acc1 = reduce_tensor(acc1) + acc5 = reduce_tensor(acc5) + loss = reduce_tensor(loss) + + loss_meter.update(loss.item(), x.size(0)) + acc1_meter.update(acc1[0], x.size(0)) + acc5_meter.update(acc5[0], x.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if idx % args.print_freq == 0: + logger.info( + f'Test: [{idx}/{len(val_loader)}]\t' + f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' + f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t' + f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})') + + logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}') + + return acc1_meter.avg, acc5_meter.avg, loss_meter.avg + + +if __name__ == '__main__': + opt = parse_option(stage='linear') + + if opt.amp_opt_level != "O0": + assert amp 
is not None, "amp not installed!" + + torch.cuda.set_device(opt.local_rank) + torch.distributed.init_process_group(backend='nccl', init_method='env://') + cudnn.benchmark = True + best_acc1 = 0 + + os.makedirs(opt.output_dir, exist_ok=True) + logger = setup_logger(output=opt.output_dir, distributed_rank=dist.get_rank(), name="contrast") + if dist.get_rank() == 0: + path = os.path.join(opt.output_dir, "config.json") + with open(path, "w") as f: + json.dump(vars(opt), f, indent=2) + logger.info("Full config saved to {}".format(path)) + + # print args + # TODO: check format + logger.info(vars(opt)) + + main(opt) diff --git a/main_pretrain.py b/main_pretrain.py new file mode 100644 index 0000000..3743856 --- /dev/null +++ b/main_pretrain.py @@ -0,0 +1,279 @@ +# -------------------------------------------------------- +# SoCo +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Yue Gao +# -------------------------------------------------------- + + +import json +import math +import os +import time +from shutil import copyfile + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.backends import cudnn +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data.distributed import DistributedSampler +from torch.utils.tensorboard import SummaryWriter + +from contrast import models, resnet +from contrast.data import get_loader +from contrast.lars import LARS, add_weight_decay +from contrast.logger import setup_logger +from contrast.lr_scheduler import get_scheduler +from contrast.option import parse_option +from contrast.util import AverageMeter + +from converter_detectron2.convert_detectron2_C4 import convert_detectron2_C4 +from converter_detectron2.convert_detectron2_Head import convert_detectron2_Head +from converter_mmdetection.convert_mmdetection_Head import convert_mmdetection_Head + +try: + from apex import amp # type: ignore +except ImportError: + amp = None + + +def set_random_seed(seed, deterministic=False): + """Set random seed. + Args: + seed (int): Seed to be used. + deterministic (bool): Whether to set the deterministic option for + CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` + to True and `torch.backends.cudnn.benchmark` to False. + Default: False. 
+ """ + import random + + import numpy as np + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + if deterministic: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def build_model(args): + encoder = resnet.__dict__[args.arch] + model = models.__dict__[args.model](encoder, args).cuda() + + if args.optimizer == 'sgd': + optimizer = torch.optim.SGD(model.parameters(), + lr=args.batch_size * dist.get_world_size() / 256 * args.base_learning_rate, + momentum=args.momentum, + weight_decay=args.weight_decay) + elif args.optimizer == 'lars': + params = add_weight_decay(model, args.weight_decay) + optimizer = torch.optim.SGD(params, + lr=args.batch_size * dist.get_world_size() / 256 * args.base_learning_rate, + momentum=args.momentum) + optimizer = LARS(optimizer) + else: + raise NotImplementedError + + if args.amp_opt_level != "O0": + model, optimizer = amp.initialize(model, optimizer, opt_level=args.amp_opt_level) + + model = DistributedDataParallel(model, device_ids=[args.local_rank], broadcast_buffers=False, find_unused_parameters=True) + return model, optimizer + + +def load_pretrained(model, pretrained_model): + ckpt = torch.load(pretrained_model, map_location='cpu') + state_dict = ckpt['model'] + model_dict = model.state_dict() + + model_dict.update(state_dict) + model.load_state_dict(model_dict) + logger.info(f"==> loaded checkpoint '{pretrained_model}' (epoch {ckpt['epoch']})") + + +def load_checkpoint(args, model, optimizer, scheduler, sampler=None): + logger.info(f"=> loading checkpoint '{args.resume}'") + + checkpoint = torch.load(args.resume, map_location='cpu') + args.start_epoch = checkpoint['epoch'] + 1 + model.load_state_dict(checkpoint['model']) + optimizer.load_state_dict(checkpoint['optimizer']) + scheduler.load_state_dict(checkpoint['scheduler']) + if args.amp_opt_level != "O0" and checkpoint['opt'].amp_opt_level != "O0": + amp.load_state_dict(checkpoint['amp']) + if args.use_sliding_window_sampler: + sampler.load_state_dict(checkpoint['sampler']) + + logger.info(f"=> loaded successfully '{args.resume}' (epoch {checkpoint['epoch']})") + + del checkpoint + torch.cuda.empty_cache() + + +def save_checkpoint(args, epoch, model, optimizer, scheduler, sampler=None): + logger.info('==> Saving...') + state = { + 'opt': args, + 'model': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'scheduler': scheduler.state_dict(), + 'epoch': epoch, + } + if args.amp_opt_level != "O0": + state['amp'] = amp.state_dict() + if args.use_sliding_window_sampler: + state['sampler'] = sampler.state_dict() + file_name = os.path.join(args.output_dir, f'ckpt_epoch_{epoch}.pth') + torch.save(state, file_name) + copyfile(file_name, os.path.join(args.output_dir, 'current.pth')) + + +def convert_checkpoint(args): + file_name = os.path.join(args.output_dir, 'current.pth') + output_file_name_C4 = os.path.join(args.output_dir, 'current_detectron2_C4.pkl') + output_file_name_Head = os.path.join(args.output_dir, 'current_detectron2_Head.pkl') + output_file_name_mmdet_Head = os.path.join(args.output_dir, 'current_mmdetection_Head.pth') + + convert_detectron2_C4(file_name, output_file_name_C4) + convert_detectron2_Head(file_name, output_file_name_Head, start=2, num_outs=4) + convert_mmdetection_Head(file_name, output_file_name_mmdet_Head) + + +def main(args): + train_prefix = 'train2017' if args.dataset == 'COCO' else 'train' + train_loader = get_loader(args.aug, args, prefix=train_prefix, return_coord=True) + 
args.num_instances = len(train_loader.dataset) + logger.info(f"length of training dataset: {args.num_instances}") + + model, optimizer = build_model(args) + if dist.get_rank() == 0: + print(model) + + scheduler = get_scheduler(optimizer, len(train_loader), args) + + # optionally resume from a checkpoint + if args.pretrained_model: + assert os.path.isfile(args.pretrained_model) + load_pretrained(model, args.pretrained_model) + if args.auto_resume: + resume_file = os.path.join(args.output_dir, "current.pth") + if os.path.exists(resume_file): + logger.info(f'auto resume from {resume_file}') + args.resume = resume_file + else: + logger.info(f'no checkpoint found in {args.output_dir}, ignoring auto resume') + if args.resume: + assert os.path.isfile(args.resume) + load_checkpoint(args, model, optimizer, scheduler, sampler=train_loader.sampler) + + # tensorboard + if dist.get_rank() == 0: + summary_writer = SummaryWriter(log_dir=args.output_dir) + else: + summary_writer = None + + if args.use_sliding_window_sampler: + args.epochs = math.ceil(args.epochs * len(train_loader.dataset) / args.window_size) + for epoch in range(args.start_epoch, args.epochs + 1): + if isinstance(train_loader.sampler, DistributedSampler): + train_loader.sampler.set_epoch(epoch) + + train(epoch, train_loader, model, optimizer, scheduler, args, summary_writer) + + if dist.get_rank() == 0 and (epoch % args.save_freq == 0 or epoch == args.epochs): + save_checkpoint(args, epoch, model, optimizer, scheduler, sampler=train_loader.sampler) + if dist.get_rank() == 0 and epoch == args.epochs: + convert_checkpoint(args) + + +def train(epoch, train_loader, model, optimizer, scheduler, args, summary_writer): + """ + one epoch training + """ + model.train() + + batch_time = AverageMeter() + data_time = AverageMeter() + loss_meter = AverageMeter() + + end = time.time() + for idx, data in enumerate(train_loader): + data = [item.cuda(non_blocking=True) for item in data] + data_time.update(time.time() - end) + + if args.model in ['SoCo_C4']: + loss = model(data[0], data[1], data[2], data[3], data[4]) + elif args.model in ['SoCo_FPN',]: + loss = model(data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7], data[8]) + elif args.model in ['SoCo_FPN_Star']: + loss = model(data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7], data[8], data[9], data[10], data[11], data[12]) + else: + logit, label = model(data[0], data[1]) + loss = F.cross_entropy(logit, label) + + # backward + optimizer.zero_grad() + if args.amp_opt_level != "O0": + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() + + optimizer.step() + scheduler.step() + + # update meters and print info + loss_meter.update(loss.item(), data[0].size(0)) + batch_time.update(time.time() - end) + end = time.time() + + train_len = len(train_loader) + if args.use_sliding_window_sampler: + train_len = int(args.window_size / args.batch_size / dist.get_world_size()) + if idx % args.print_freq == 0: + lr = optimizer.param_groups[0]['lr'] + logger.info( + f'Train: [{epoch}/{args.epochs}][{idx}/{train_len}] ' + f'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) ' + f'Data Time {data_time.val:.3f} ({data_time.avg:.3f}) ' + f'lr {lr:.3f} ' + f'loss {loss_meter.val:.3f} ({loss_meter.avg:.3f})') + # tensorboard logger + if summary_writer is not None: + step = (epoch - 1) * len(train_loader) + idx + summary_writer.add_scalar('lr', lr, step) + summary_writer.add_scalar('loss', loss_meter.val, step) + + +if __name__ == 
'__main__': + opt = parse_option(stage='pre_train') + + if opt.amp_opt_level != "O0": + assert amp is not None, "amp not installed!" + + torch.cuda.set_device(opt.local_rank) + torch.distributed.init_process_group(backend='nccl', init_method='env://') + cudnn.benchmark = True + + # setup logger + os.makedirs(opt.output_dir, exist_ok=True) + logger = setup_logger(output=opt.output_dir, distributed_rank=dist.get_rank(), name="SoCo") + if dist.get_rank() == 0: + path = os.path.join(opt.output_dir, "config.json") + with open(path, 'w') as f: + json.dump(vars(opt), f, indent=2) + logger.info("Full config saved to {}".format(path)) + + # print args + logger.info( + "\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(opt)).items())) + ) + + if opt.debug: + logger.info('enable debug mode, set seed to 0') + set_random_seed(0) + + main(opt) diff --git a/selective_search/filter_ss_proposals_json.py b/selective_search/filter_ss_proposals_json.py new file mode 100644 index 0000000..e0c241d --- /dev/null +++ b/selective_search/filter_ss_proposals_json.py @@ -0,0 +1,73 @@ +import os +import pickle +import multiprocessing as mp +import numpy as np +import json + +from filters import filter_none, filter_ratio, filter_size, filter_ratio_size + + +imagenet_root = './imagenet_root' +imagenet_root_proposals = './imagenet_root_proposals_mp' +filter_strategy = 'ratio3size0308' # 'ratio2', 'ratio3', 'ratio4', 'size01' +print("filter_strategy", filter_strategy) +filtered_proposals = './imagenet_filtered_proposals' + +split = 'train' + + +filtered_proposals_dict = {} + + +os.makedirs(filtered_proposals, exist_ok=True) +json_path = os.path.join(filtered_proposals, f'{split}_{filter_strategy}.json') +source_path = os.path.join(imagenet_root_proposals, split) + +class_names = sorted(os.listdir(os.path.join(imagenet_root_proposals, split))) + +no_props_images = [] + + +for ci, class_name in enumerate(class_names): + filenames = sorted(os.listdir(os.path.join(source_path, class_name))) + for fi, filename in enumerate(filenames): + base_filename = os.path.splitext(filename)[0] + cur_img_pro_path = os.path.join(source_path, class_name, filename) + + with open(cur_img_pro_path, 'rb') as f: + cur_img_proposal = pickle.load(f) + if filter_strategy == 'none': + filtered_img_rects = filter_none(cur_img_proposal['regions']) + elif 'ratio' in filter_strategy and 'size' in filter_strategy: + ratio = float(filter_strategy[5]) + min_size_ratio = float(filter_strategy[-4:-2]) / 10 + max_size_ratio = float(filter_strategy[-2:]) / 10 + filtered_img_rects = filter_ratio_size(cur_img_proposal['regions'], cur_img_proposal['label'].shape, ratio, min_size_ratio, max_size_ratio) + elif 'ratio' in filter_strategy: + ratio = float(filter_strategy[-1]) + filtered_img_rects = filter_ratio(cur_img_proposal['regions'], r=ratio) + elif 'size' in filter_strategy: + min_size_ratio = float(filter_strategy[-2:]) / 10 + filtered_img_rects = filter_size(cur_img_proposal['regions'], cur_img_proposal['label'].shape, min_size_ratio) + else: + raise NotImplementedError + + filtered_proposals_dict[base_filename] = filtered_img_rects + if len(filtered_img_rects) == 0: + no_props_images.append(base_filename) + print(f"with strategy {filter_strategy}, image {base_filename} has no proposals") + if (fi + 1) % 100 == 0: + print(f"Processed [{ci}/{len(class_names)}] classes, [{fi+1}/{len(filenames)}] images") + + +print(f"Finished filtering with strategy {filter_strategy}, there are {len(no_props_images)} images have no proposals.") + + +with 
open(json_path, 'w') as f: + json.dump(filtered_proposals_dict, f) + + +with open(json_path.replace('.json', 'no_props_images.txt'), 'w') as f: + for image_id in no_props_images: + f.write(image_id+'\n') + diff --git a/selective_search/filter_ss_proposals_json_post_no_prop.py b/selective_search/filter_ss_proposals_json_post_no_prop.py new file mode 100644 index 0000000..1ed1ef3 --- /dev/null +++ b/selective_search/filter_ss_proposals_json_post_no_prop.py @@ -0,0 +1,36 @@ +import json +import os +import pickle + +from filters import filter_none, filter_ratio + + +json_path = './imagenet_filtered_proposals/train_ratio3size0308.json' +json_path_post = './imagenet_filtered_proposals/train_ratio3size0308post.json' +no_props_images = open('./imagenet_filtered_proposals/train_ratio3size0308no_props_images.txt').readlines() +imagenet_root_proposals = './imagenet_root_proposals_mp' + + +with open(json_path, 'r') as f: + json_dict = json.load(f) + + +for no_props_image in no_props_images: + filename = no_props_image.strip() + class_name = filename.split('_')[0] + cur_img_pro_path = os.path.join(imagenet_root_proposals, 'train', class_name, filename+'.pkl') + + with open(cur_img_pro_path, 'rb') as f: + cur_img_proposal = pickle.load(f) + props_size_ratio = filter_ratio(cur_img_proposal['regions'], r=3) + props_none = filter_none(cur_img_proposal['regions']) + print("props_size_ratio", len(props_size_ratio)) + print("props_none", len(props_none)) + if len(props_size_ratio) > 0: + json_dict[filename] = props_size_ratio + elif len(props_none) > 0: + json_dict[filename] = filter_none(cur_img_proposal['regions']) + + +with open(json_path_post, 'w') as f: + json.dump(json_dict, f) diff --git a/selective_search/filters.py b/selective_search/filters.py new file mode 100644 index 0000000..59fe26c --- /dev/null +++ b/selective_search/filters.py @@ -0,0 +1,66 @@ +# -------------------------------------------------------- +# SoCo +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Yue Gao +# -------------------------------------------------------- + + +from math import sqrt + +''' +regions : array of dict + [ + { + 'rect': (left, top, width, height), + 'labels': [...], + 'size': component_size + }, + ... 
+ ] +''' + +def filter_none(regions): + rects = [] + for region in regions: + cur_rect = region['rect'] + w, h = cur_rect[2], cur_rect[3] + if w >= 32 and h >= 32: + rects.append(cur_rect) + return rects + + +def filter_ratio(regions, r=2): + rects = [] + for region in regions: + cur_rect = region['rect'] + w, h = cur_rect[2], cur_rect[3] + if w >= 32 and h >= 32 and (1.0 / r) <= (w / h) <= r: + rects.append(cur_rect) + return rects + + +def filter_size(regions, image_size, min_ratio=0.1, max_ratio=1.0): + rects = [] + ih, iw = image_size # get from ss_props label, which is h, w + img_sqrt_size = sqrt(iw * ih) + for region in regions: + cur_rect = region['rect'] + w, h = cur_rect[2], cur_rect[3] + prop_sqrt_size = sqrt(w * h) + if w >= 32 and h >= 32 and min_ratio <= (prop_sqrt_size / img_sqrt_size) <= max_ratio: + rects.append(cur_rect) + return rects + + +def filter_ratio_size(regions, image_size, r=2, min_ratio=0.1, max_ratio=1.0): + rects = [] + ih, iw = image_size # get from ss_props label, which is h, w + img_sqrt_size = sqrt(iw * ih) + for region in regions: + cur_rect = region['rect'] + w, h = cur_rect[2], cur_rect[3] + prop_sqrt_size = sqrt(w * h) + if w >= 32 and h >= 32 and (1.0 / r) <= (w / h) <= r and min_ratio <= (prop_sqrt_size / img_sqrt_size) <= max_ratio: + rects.append(cur_rect) + return rects diff --git a/selective_search/generate_imagenet_ss_proposals.py b/selective_search/generate_imagenet_ss_proposals.py new file mode 100644 index 0000000..ef7814e --- /dev/null +++ b/selective_search/generate_imagenet_ss_proposals.py @@ -0,0 +1,71 @@ +# -------------------------------------------------------- +# SoCo +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Yue Gao +# -------------------------------------------------------- + + +import multiprocessing as mp +import os +import pickle + +import numpy as np +import PIL.Image as Image + +from .selective_search import selective_search + +imagenet_root = './imagenet_root' +imagenet_root_proposals = './imagenet_root_proposals_mp' + +split = 'train' +scale = 300 +min_size = 100 + +processes_num = 48 +class_names = sorted(os.listdir(os.path.join(imagenet_root, split))) +classes_num = len(class_names) +classes_per_process = classes_num // processes_num + 1 + +source_path = os.path.join(imagenet_root, split) +target_path = os.path.join(imagenet_root_proposals, split) + + +def process_one_class(process_id, classes_per_process, class_names, source_path, target_path): + print(f"Process id: {process_id} started") + for i in range(process_id*classes_per_process, process_id*classes_per_process + classes_per_process): + if i >= len(class_names): + break + class_name = class_names[i] + filenames = sorted(os.listdir(os.path.join(source_path, class_name))) + os.makedirs(os.path.join(target_path, class_name)) + for filename in filenames: + base_filename = os.path.splitext(filename)[0] + img_path = os.path.join(source_path, class_name, filename) + img = np.array(Image.open(img_path).convert('RGB')) + + img_with_lbl, regions, bboxs = selective_search(img, scale=scale, sigma=0.9, min_size=min_size) + + region_label = img_with_lbl[:, :, 3] + cur_img_proposal = {} + cur_img_proposal['label'] = region_label + cur_img_proposal['regions'] = regions + + cur_img_pro_path = os.path.join(target_path, class_name, base_filename+'.pkl') + + with open(cur_img_pro_path, 'wb') as f: + pickle.dump(cur_img_proposal, f) + print("Process ", process_id, "processed class:", class_name) + + +processes = 
[mp.Process(target=process_one_class, + args=(process_id, classes_per_process, class_names, source_path, target_path)) + for process_id in range(processes_num)] + +# Run processes +for p in processes: + p.start() + +# Exit the completed processes +for p in processes: + p.join() diff --git a/selective_search/selective_search.py b/selective_search/selective_search.py new file mode 100644 index 0000000..4c8f62e --- /dev/null +++ b/selective_search/selective_search.py @@ -0,0 +1,319 @@ +# -*- coding: utf-8 -*- +from __future__ import division + +import numpy +import skimage.color +import skimage.feature +import skimage.io +import skimage.segmentation +import skimage.transform +import skimage.util + +# "Selective Search for Object Recognition" by J.R.R. Uijlings et al. +# +# - Modified version with LBP extractor for texture vectorization + + +def _generate_segments(im_orig, scale, sigma, min_size): + """ + segment smallest regions by the algorithm of Felzenswalb and + Huttenlocher + """ + + # open the Image + im_mask = skimage.segmentation.felzenszwalb( + skimage.util.img_as_float(im_orig), scale=scale, sigma=sigma, + min_size=min_size) + + # merge mask channel to the image as a 4th channel + im_orig = numpy.append( + im_orig, numpy.zeros(im_orig.shape[:2])[:, :, numpy.newaxis], axis=2) + im_orig[:, :, 3] = im_mask + + return im_orig + + +def _sim_colour(r1, r2): + """ + calculate the sum of histogram intersection of colour + """ + return sum([min(a, b) for a, b in zip(r1["hist_c"], r2["hist_c"])]) + + +def _sim_texture(r1, r2): + """ + calculate the sum of histogram intersection of texture + """ + return sum([min(a, b) for a, b in zip(r1["hist_t"], r2["hist_t"])]) + + +def _sim_size(r1, r2, imsize): + """ + calculate the size similarity over the image + """ + return 1.0 - (r1["size"] + r2["size"]) / imsize + + +def _sim_fill(r1, r2, imsize): + """ + calculate the fill similarity over the image + """ + bbsize = ( + (max(r1["max_x"], r2["max_x"]) - min(r1["min_x"], r2["min_x"])) + * (max(r1["max_y"], r2["max_y"]) - min(r1["min_y"], r2["min_y"])) + ) + return 1.0 - (bbsize - r1["size"] - r2["size"]) / imsize + + +def _calc_sim(r1, r2, imsize): + return (_sim_colour(r1, r2) + _sim_texture(r1, r2) + + _sim_size(r1, r2, imsize) + _sim_fill(r1, r2, imsize)) + + +def _calc_colour_hist(img): + """ + calculate colour histogram for each region + + the size of output histogram will be BINS * COLOUR_CHANNELS(3) + + number of bins is 25 as same as [uijlings_ijcv2013_draft.pdf] + + extract HSV + """ + + BINS = 25 + hist = numpy.array([]) + + for colour_channel in (0, 1, 2): + + # extracting one colour channel + c = img[:, colour_channel] + + # calculate histogram for each colour and join to the result + hist = numpy.concatenate( + [hist] + [numpy.histogram(c, BINS, (0.0, 255.0))[0]]) + + # L1 normalize + hist = hist / len(img) + + return hist + + +def _calc_texture_gradient(img): + """ + calculate texture gradient for entire image + + The original SelectiveSearch algorithm proposed Gaussian derivative + for 8 orientations, but we use LBP instead. 
+ + output will be [height(*)][width(*)] + """ + ret = numpy.zeros((img.shape[0], img.shape[1], img.shape[2])) + + for colour_channel in (0, 1, 2): + ret[:, :, colour_channel] = skimage.feature.local_binary_pattern( + img[:, :, colour_channel], 8, 1.0) + + return ret + + +def _calc_texture_hist(img): + """ + calculate texture histogram for each region + + calculate the histogram of gradient for each colours + the size of output histogram will be + BINS * ORIENTATIONS * COLOUR_CHANNELS(3) + """ + BINS = 10 + + hist = numpy.array([]) + + for colour_channel in (0, 1, 2): + + # mask by the colour channel + fd = img[:, colour_channel] + + # calculate histogram for each orientation and concatenate them all + # and join to the result + hist = numpy.concatenate( + [hist] + [numpy.histogram(fd, BINS, (0.0, 255.0))[0]]) # there is a bug https://github.com/AlpacaDB/selectivesearch/issues/30 + + # L1 Normalize + hist = hist / len(img) + + return hist + + +def _extract_regions(img): + + R = {} + + # get hsv image + hsv = skimage.color.rgb2hsv(img[:, :, :3]) + + # pass 1: count pixel positions + for y, i in enumerate(img): + + for x, (r, g, b, l) in enumerate(i): + + # initialize a new region + if l not in R: + R[l] = { + "min_x": 0xffff, "min_y": 0xffff, + "max_x": 0, "max_y": 0, "labels": [l]} + + # bounding box + if R[l]["min_x"] > x: + R[l]["min_x"] = x + if R[l]["min_y"] > y: + R[l]["min_y"] = y + if R[l]["max_x"] < x: + R[l]["max_x"] = x + if R[l]["max_y"] < y: + R[l]["max_y"] = y + + # pass 2: calculate texture gradient + tex_grad = _calc_texture_gradient(img) + + # pass 3: calculate colour, texture histogram of each region + for k, v in list(R.items()): + # k is the region label + # colour histogram + masked_pixels = hsv[:, :, :][img[:, :, 3] == k] + R[k]["size"] = len(masked_pixels / 4) # number of pixels in the region, why /4 ?, in the repo, the question is not answered + R[k]["hist_c"] = _calc_colour_hist(masked_pixels) # calculated on HSV space + + # texture histogram + R[k]["hist_t"] = _calc_texture_hist(tex_grad[:, :][img[:, :, 3] == k]) # calculated on RGB space + + return R + + +def _extract_neighbors(regions): + + def intersect(a, b): + if (a["min_x"] < b["min_x"] < a["max_x"] + and a["min_y"] < b["min_y"] < a["max_y"]) or ( + a["min_x"] < b["max_x"] < a["max_x"] + and a["min_y"] < b["max_y"] < a["max_y"]) or ( + a["min_x"] < b["min_x"] < a["max_x"] + and a["min_y"] < b["max_y"] < a["max_y"]) or ( + a["min_x"] < b["max_x"] < a["max_x"] + and a["min_y"] < b["min_y"] < a["max_y"]): + return True + return False + + R = list(regions.items()) + neighbors = [] + for cur, a in enumerate(R[:-1]): + for b in R[cur + 1:]: + if intersect(a[1], b[1]): + neighbors.append((a, b)) + + return neighbors + + +def _merge_regions(r1, r2): + new_size = r1["size"] + r2["size"] + rt = { + "min_x": min(r1["min_x"], r2["min_x"]), + "min_y": min(r1["min_y"], r2["min_y"]), + "max_x": max(r1["max_x"], r2["max_x"]), + "max_y": max(r1["max_y"], r2["max_y"]), + "size": new_size, + "hist_c": ( + r1["hist_c"] * r1["size"] + r2["hist_c"] * r2["size"]) / new_size, + "hist_t": ( + r1["hist_t"] * r1["size"] + r2["hist_t"] * r2["size"]) / new_size, + "labels": r1["labels"] + r2["labels"] + } + return rt + + +def selective_search(im_orig, scale=1.0, sigma=0.8, min_size=50): + '''Selective Search + + Parameters + ---------- + im_orig : ndarray + Input image + scale : int + Free parameter. Higher means larger clusters in felzenszwalb segmentation. + sigma : float + Width of Gaussian kernel for felzenszwalb segmentation. 
+ min_size : int + Minimum component size for felzenszwalb segmentation. + Returns + ------- + img : ndarray + image with region label + region label is stored in the 4th value of each pixel [r,g,b,(region)] + regions : array of dict + [ + { + 'rect': (left, top, width, height), + 'labels': [...], + 'size': component_size + }, + ... + ] + ''' + assert im_orig.shape[2] == 3, "3ch image is expected" + + # load image and get smallest regions + # region label is stored in the 4th value of each pixel [r,g,b,(region)] + img = _generate_segments(im_orig, scale, sigma, min_size) # img first 3 channel values are in [0, 255] + + if img is None: + return None, {} + + imsize = img.shape[0] * img.shape[1] + R = _extract_regions(img) + + # extract neighboring information + neighbors = _extract_neighbors(R) + + # calculate initial similarities + S = {} + for (ai, ar), (bi, br) in neighbors: + S[(ai, bi)] = _calc_sim(ar, br, imsize) + + # hierarchal search + while S != {}: + + # get highest similarity + i, j = sorted(S.items(), key=lambda i: i[1])[-1][0] + + # merge corresponding regions + t = max(R.keys()) + 1.0 + R[t] = _merge_regions(R[i], R[j]) + + # mark similarities for regions to be removed + key_to_delete = [] + for k, v in list(S.items()): + if (i in k) or (j in k): + key_to_delete.append(k) + + # remove old similarities of related regions + for k in key_to_delete: + del S[k] + + # calculate similarity set with the new region + for k in [a for a in key_to_delete if a != (i, j)]: + n = k[1] if k[0] in (i, j) else k[0] + S[(t, n)] = _calc_sim(R[t], R[n], imsize) + + regions = [] + boxes = [] + for k, r in list(R.items()): + regions.append({ + 'rect': ( + r['min_x'], r['min_y'], + r['max_x'] - r['min_x'], r['max_y'] - r['min_y']), + 'size': r['size'], + 'labels': r['labels'] + }) + boxes.append([int(r['min_x']), int(r['min_y']), int(r['max_x'] - r['min_x']), int(r['max_y'] - r['min_y'])]) + + return img, regions, boxes diff --git a/tools/SoCo_C4_100ep.sh b/tools/SoCo_C4_100ep.sh new file mode 100755 index 0000000..8f72c87 --- /dev/null +++ b/tools/SoCo_C4_100ep.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +set -e +set -x + +warmup=${1:-10} +epochs=${2:-100} + +data_dir="./data/ImageNet-Zip" +output_dir="./SoCo_output/SoCo_C4_100ep" + +master_addr=${MASTER_IP} +master_port=28652 + + +python -m torch.distributed.launch --nproc_per_node=8 \ + --nnodes ${OMPI_COMM_WORLD_SIZE} --node_rank ${OMPI_COMM_WORLD_RANK} --master_addr ${master_addr} --master_port ${master_port} \ + SoCo/main_pretrain.py \ + --data_dir ${data_dir} \ + --crop 0.6 \ + --base_lr 1.0 \ + --optimizer lars \ + --weight_decay 1e-5 \ + --amp_opt_level O1 \ + --ss_props \ + --auto_resume \ + --aug ImageAsymBboxCutout \ + --zip --cache_mode no \ + --arch resnet50 \ + --model SoCo_C4 \ + --warmup_epoch ${warmup} \ + --epochs ${epochs} \ + --output_dir ${output_dir} \ + --save_freq 10 \ + --batch_size 128 \ + --contrast_momentum 0.99 \ + --filter_strategy ratio3size0308post \ + --select_strategy random \ + --select_k 4 \ + --output_size 7 \ + --aligned \ + --jitter_ratio 0.1 \ + --padding_k 4 \ + --cutout_prob 0.5 \ + --cutout_ratio_min 0.1 \ + --cutout_ratio_max 0.3 diff --git a/tools/SoCo_C4_100ep_linear.sh b/tools/SoCo_C4_100ep_linear.sh new file mode 100755 index 0000000..ede931c --- /dev/null +++ b/tools/SoCo_C4_100ep_linear.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +set -e +set -x + +warmup=${1:-10} +epochs=${2:-100} + +data_dir="./data/ImageNet-Zip" +output_dir="./SoCo_output/SoCo_C4_100ep" + +master_addr=${MASTER_IP} +master_port=28652 + + +python -m 
torch.distributed.launch --master_port ${master_port} --nproc_per_node=8 \ + SoCo/main_linear.py \ + --data_dir ${data_dir} \ + --zip --cache_mode part \ + --arch resnet50 \ + --output_dir ${output_dir}_linear_eval \ + --pretrained_model ${output_dir}/current.pth \ + --save_freq 10 \ + --auto_resume diff --git a/tools/SoCo_C4_400ep.sh b/tools/SoCo_C4_400ep.sh new file mode 100755 index 0000000..17859a9 --- /dev/null +++ b/tools/SoCo_C4_400ep.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +set -e +set -x + +warmup=${1:-10} +epochs=${2:-400} + +data_dir="./data/ImageNet-Zip" +output_dir="./SoCo_output/SoCo_C4_400ep" + +master_addr=${MASTER_IP} +master_port=28652 + +python -m torch.distributed.launch --nproc_per_node=8 \ + --nnodes ${OMPI_COMM_WORLD_SIZE} --node_rank ${OMPI_COMM_WORLD_RANK} --master_addr ${master_addr} --master_port ${master_port} \ + SoCo/main_pretrain.py \ + --data_dir ${data_dir} \ + --crop 0.6 \ + --base_lr 1.0 \ + --optimizer lars \ + --weight_decay 1e-5 \ + --amp_opt_level O1 \ + --ss_props \ + --auto_resume \ + --aug ImageAsymBboxCutout \ + --zip --cache_mode no \ + --arch resnet50 \ + --model SoCo_C4 \ + --warmup_epoch ${warmup} \ + --epochs ${epochs} \ + --output_dir ${output_dir} \ + --save_freq 10 \ + --batch_size 128 \ + --contrast_momentum 0.99 \ + --filter_strategy ratio3size0308post \ + --select_strategy random \ + --select_k 4 \ + --output_size 7 \ + --aligned \ + --jitter_ratio 0.1 \ + --padding_k 4 \ + --cutout_prob 0.5 \ + --cutout_ratio_min 0.1 \ + --cutout_ratio_max 0.3 + diff --git a/tools/SoCo_C4_400ep_linear.sh b/tools/SoCo_C4_400ep_linear.sh new file mode 100755 index 0000000..9457962 --- /dev/null +++ b/tools/SoCo_C4_400ep_linear.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +set -e +set -x + +warmup=${1:-10} +epochs=${2:-400} + +data_dir="./data/ImageNet-Zip" +output_dir="./self_det_output/SoCo_C4_400ep" + +master_addr=${MASTER_IP} +master_port=28652 + + +python -m torch.distributed.launch --master_port ${master_port} --nproc_per_node=8 \ + SoCo/main_linear.py \ + --data_dir ${data_dir} \ + --zip --cache_mode part \ + --arch resnet50 \ + --output_dir ${output_dir}_linear_eval \ + --pretrained_model ${output_dir}/current.pth \ + --save_freq 10 \ + --auto_resume diff --git a/tools/SoCo_FPN_100ep.sh b/tools/SoCo_FPN_100ep.sh new file mode 100755 index 0000000..8c2f22e --- /dev/null +++ b/tools/SoCo_FPN_100ep.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +set -e +set -x + +warmup=${1:-10} +epochs=${2:-100} + +data_dir="./data/ImageNet-Zip" +output_dir="./SoCo_output/SoCo_FPN_100ep" + +master_addr=${MASTER_IP} +master_port=28652 + +python -m torch.distributed.launch --nproc_per_node=8 \ + --nnodes ${OMPI_COMM_WORLD_SIZE} --node_rank ${OMPI_COMM_WORLD_RANK} --master_addr ${master_addr} --master_port ${master_port} \ + SoCo/main_pretrain.py \ + --data_dir ${data_dir} \ + --crop 0.5 \ + --base_lr 1.0 \ + --optimizer lars \ + --weight_decay 1e-5 \ + --amp_opt_level O1 \ + --ss_props \ + --auto_resume \ + --aug ImageAsymBboxAwareMultiJitter1Cutout \ + --zip --cache_mode no \ + --arch resnet50 \ + --model SoCo_FPN \ + --warmup_epoch ${warmup} \ + --epochs ${epochs} \ + --output_dir ${output_dir} \ + --save_freq 1 \ + --batch_size 128 \ + --contrast_momentum 0.99 \ + --filter_strategy ratio3size0308post \ + --select_strategy random \ + --select_k 4 \ + --output_size 7 \ + --aligned \ + --jitter_prob 0.5 \ + --jitter_ratio 0.1 \ + --padding_k 4 \ + --start_level 0 \ + --num_outs 4 \ + --add_extra_convs 0 \ + --extra_convs_on_inputs 0 \ + --relu_before_extra_convs 0 \ + --aware_start 0 \ + 
--aware_end 4 \ + --cutout_prob 0.5 \ + --cutout_ratio_min 0.1 \ + --cutout_ratio_max 0.3 diff --git a/tools/SoCo_FPN_100ep_linear.sh b/tools/SoCo_FPN_100ep_linear.sh new file mode 100755 index 0000000..de06bf2 --- /dev/null +++ b/tools/SoCo_FPN_100ep_linear.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +set -e +set -x + +warmup=${1:-10} +epochs=${2:-100} + +data_dir="./data/ImageNet-Zip" +output_dir="./SoCo_output/SoCo_FPN_100ep" + +master_addr=${MASTER_IP} +master_port=28652 + + +python -m torch.distributed.launch --master_port ${master_port} --nproc_per_node=8 \ + SoCo/main_linear.py \ + --data_dir ${data_dir} \ + --zip --cache_mode part \ + --arch resnet50 \ + --output_dir ${output_dir}_linear_eval \ + --pretrained_model ${output_dir}/current.pth \ + --save_freq 10 \ + --auto_resume diff --git a/tools/SoCo_FPN_400ep.sh b/tools/SoCo_FPN_400ep.sh new file mode 100755 index 0000000..10121a6 --- /dev/null +++ b/tools/SoCo_FPN_400ep.sh @@ -0,0 +1,61 @@ +#!/bin/bash + +set -e +set -x + +warmup=${1:-10} +epochs=${2:-400} + +data_dir="./data/ImageNet-Zip" +output_dir="./SoCo_output/SoCo_FPN_400ep" + +master_addr=${MASTER_IP} +master_port=28652 + +python -m torch.distributed.launch --nproc_per_node=8 \ + --nnodes ${OMPI_COMM_WORLD_SIZE} --node_rank ${OMPI_COMM_WORLD_RANK} --master_addr ${master_addr} --master_port ${master_port} \ + SoCo/main_pretrain.py \ + --data_dir ${data_dir} \ + --crop 0.5 \ + --base_lr 1.0 \ + --optimizer lars \ + --weight_decay 1e-5 \ + --amp_opt_level O1 \ + --ss_props \ + --auto_resume \ + --aug ImageAsymBboxAwareMultiJitter1 \ + --zip --cache_mode no \ + --arch resnet50 \ + --model SoCo_FPN \ + --warmup_epoch ${warmup} \ + --epochs ${epochs} \ + --output_dir ${output_dir} \ + --save_freq 1 \ + --batch_size 128 \ + --contrast_momentum 0.99 \ + --filter_strategy ratio3size0308post \ + --select_strategy random \ + --select_k 4 \ + --output_size 7 \ + --aligned \ + --jitter_prob 0.5 \ + --jitter_ratio 0.1 \ + --padding_k 4 \ + --start_level 0 \ + --num_outs 4 \ + --add_extra_convs 0 \ + --extra_convs_on_inputs 0 \ + --relu_before_extra_convs 0 \ + --aware_start 0 \ + --aware_end 4 \ + --cutout_prob 0.5 \ + --cutout_ratio_min 0.1 \ + --cutout_ratio_max 0.3 + +# python -m torch.distributed.launch --master_port 12348 --nproc_per_node=4 \ +# main_linear.py \ +# --data_dir ${data_dir} \ +# --zip --cache_mode part \ +# --arch resnet50 \ +# --output_dir ${output_dir}/eval \ +# --pretrained_model ${output_dir}/current.pth \ diff --git a/tools/SoCo_FPN_400ep_linear.sh b/tools/SoCo_FPN_400ep_linear.sh new file mode 100755 index 0000000..12111dd --- /dev/null +++ b/tools/SoCo_FPN_400ep_linear.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +set -e +set -x + +warmup=${1:-10} +epochs=${2:-400} + +data_dir="./data/ImageNet-Zip" +output_dir="./SoCo_output/SoCo_FPN_400ep" + +master_addr=${MASTER_IP} +master_port=28652 + + +python -m torch.distributed.launch --master_port ${master_port} --nproc_per_node=8 \ + SoCo/main_linear.py \ + --data_dir ${data_dir} \ + --zip --cache_mode part \ + --arch resnet50 \ + --output_dir ${output_dir}_linear_eval \ + --pretrained_model ${output_dir}/current.pth \ + --save_freq 10 \ + --auto_resume diff --git a/tools/SoCo_FPN_Star_400ep.sh b/tools/SoCo_FPN_Star_400ep.sh new file mode 100755 index 0000000..fea5df5 --- /dev/null +++ b/tools/SoCo_FPN_Star_400ep.sh @@ -0,0 +1,55 @@ +#!/bin/bash + +set -e +set -x + +warmup=${1:-10} +epochs=${2:-400} + +data_dir="./data/ImageNet-Zip" +output_dir="./SoCo_output/SoCo_FPN_Star_400ep" + +master_addr=${MASTER_IP} +master_port=28652 + +python 
-m torch.distributed.launch --nproc_per_node=8 \ + --nnodes ${OMPI_COMM_WORLD_SIZE} --node_rank ${OMPI_COMM_WORLD_RANK} --master_addr ${master_addr} --master_port ${master_port} \ + SoCo/main_pretrain.py \ + --data_dir ${data_dir} \ + --crop 0.5 \ + --base_lr 1.0 \ + --optimizer lars \ + --weight_decay 1e-5 \ + --amp_opt_level O1 \ + --ss_props \ + --auto_resume \ + --aug ImageAsymBboxAwareMulti3ResizeExtraJitter1 \ + --zip --cache_mode no \ + --arch resnet50 \ + --model SoCo_FPN_Star \ + --warmup_epoch ${warmup} \ + --epochs ${epochs} \ + --output_dir ${output_dir} \ + --save_freq 1 \ + --batch_size 128 \ + --contrast_momentum 0.99 \ + --filter_strategy ratio3size0308post \ + --select_strategy random \ + --select_k 4 \ + --output_size 7 \ + --aligned \ + --jitter_prob 0.5 \ + --jitter_ratio 0.1 \ + --padding_k 4 \ + --start_level 0 \ + --num_outs 4 \ + --add_extra_convs 0 \ + --extra_convs_on_inputs 0 \ + --relu_before_extra_convs 0 \ + --aware_start 0 \ + --aware_end 4 \ + --cutout_prob 0.5 \ + --cutout_ratio_min 0.1 \ + --cutout_ratio_max 0.3 \ + --image3_size 192 \ + --image4_size 112 \ diff --git a/tools/SoCo_FPN_Star_400ep_linear.sh b/tools/SoCo_FPN_Star_400ep_linear.sh new file mode 100755 index 0000000..e03b80e --- /dev/null +++ b/tools/SoCo_FPN_Star_400ep_linear.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +set -e +set -x + +warmup=${1:-10} +epochs=${2:-400} + +data_dir="./data/ImageNet-Zip" +output_dir="./SoCo_output/SoCo_FPN_Star_400ep" + +master_addr=${MASTER_IP} +master_port=28652 + + +python -m torch.distributed.launch --master_port ${master_port} --nproc_per_node=8 \ + SoCo/main_linear.py \ + --data_dir ${data_dir} \ + --zip --cache_mode part \ + --arch resnet50 \ + --output_dir ${output_dir}_linear_eval \ + --pretrained_model ${output_dir}/current.pth \ + --save_freq 10 \ + --auto_resume
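
---

For reference, here is a minimal sketch of the per-image proposal-generation step that `selective_search/generate_imagenet_ss_proposals.py` performs inside its worker processes. The image path and the import layout are assumptions (adjust them to your checkout); the `scale`, `sigma`, and `min_size` values mirror the ones hard-coded in that script.

```python
import numpy as np
import PIL.Image as Image

# assumes the selective_search/ directory is importable as a package from the repo root
from selective_search.selective_search import selective_search

img = np.array(Image.open('example.jpg').convert('RGB'))  # placeholder input image

# scale=300, sigma=0.9, min_size=100 match the values used in
# generate_imagenet_ss_proposals.py
img_with_label, regions, bboxes = selective_search(img, scale=300, sigma=0.9, min_size=100)

region_label = img_with_label[:, :, 3]  # Felzenszwalb segment id stored in the 4th channel
print(len(regions), 'raw regions,', len(bboxes), 'boxes')
# each region is {'rect': (left, top, width, height), 'labels': [...], 'size': ...}
print(regions[0]['rect'] if regions else 'no regions')
```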
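The filtering scripts then load each pickled proposal file and keep only boxes passing the chosen strategy. A sketch of that step using the helpers in `selective_search/filters.py` is below; the concrete thresholds of the `ratio3size0308` strategy are an assumption read off its name (aspect-ratio bound 3, relative size in [0.3, 0.8]) — the authoritative values live in `filter_ss_proposals_json.py`.

```python
import pickle

from selective_search.filters import filter_none, filter_ratio_size

with open('n01440764_10026.pkl', 'rb') as f:  # placeholder proposal file
    proposal = pickle.load(f)

image_size = proposal['label'].shape  # (h, w) of the saved region-label map
rects = filter_ratio_size(proposal['regions'], image_size,
                          r=3, min_ratio=0.3, max_ratio=0.8)

# filter_ss_proposals_json_post_no_prop.py relaxes the filter for images left
# with no proposals; fall back the same way here
if not rects:
    rects = filter_none(proposal['regions'])
print(len(rects), 'boxes kept')
```

The surviving rectangles are what end up in the `train_ratio3size0308post.json` dictionary keyed by image filename, which the `--filter_strategy ratio3size0308post` flag in the pretraining scripts refers to.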
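The pretraining scripts read the node count and rank from OpenMPI environment variables and the master address from `MASTER_IP`, i.e. they are written for an mpirun-style multi-node launch. A hedged single-node sketch (one 8-GPU machine, no mpirun) would simply set those variables by hand before invoking the script:

```shell
export MASTER_IP=127.0.0.1     # master_addr read by every tools/*.sh script
export OMPI_COMM_WORLD_SIZE=1  # number of nodes
export OMPI_COMM_WORLD_RANK=0  # rank of this node

# positional arguments override warmup epochs (default 10) and total epochs (default 100)
bash ./tools/SoCo_FPN_100ep.sh 10 100
```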