diff --git a/README.md b/README.md
index 8abe997..a85b4b4 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,95 @@
# SoCo
-[NeurIPS 2021 Spotlight] Aligning Pretraining for Detection via Object-Level Contrastive Learning
+[NeurIPS 2021 Spotlight] [Aligning Pretraining for Detection via Object-Level Contrastive Learning](https://arxiv.org/abs/2106.02637)
+
+By [Fangyun Wei](https://scholar.google.com/citations?user=-ncz2s8AAAAJ&hl=en)\*, [Yue Gao](https://yuegao.me)\*, [Zhirong Wu](https://scholar.google.com/citations?user=lH4zgcIAAAAJ&hl=en), [Han Hu](https://ancientmooner.github.io), [Stephen Lin](https://www.microsoft.com/en-us/research/people/stevelin/).
+> \* Equal contribution.
+
+
+## Introduction
+Image-level contrastive representation learning has proven to be highly effective as a generic model for transfer learning.
+Such generality for transfer learning, however, sacrifices specificity if we are interested in a certain downstream task.
+We argue that this could be sub-optimal and thus advocate a design principle which encourages alignment between the self-supervised pretext task and the downstream task.
+In this paper, we follow this principle with a pretraining method specifically designed for the task of object detection.
+We attain alignment in the following three aspects:
+1) object-level representations are introduced via selective search bounding boxes as object proposals;
+2) the pretraining network architecture incorporates the same dedicated modules used in the detection pipeline (e.g. FPN);
+3) the pretraining is equipped with object detection properties such as object-level translation invariance and scale invariance.
+Our method, called Selective Object COntrastive learning (SoCo), achieves state-of-the-art results for transfer performance on COCO detection using a Mask R-CNN framework.
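+
+As a concrete illustration of the object-level objective (a hedged, self-contained sketch, not the training code in this repository), the snippet below extracts box-level embeddings from two augmented views with RoIAlign and pulls corresponding boxes together with a simple cosine-similarity loss. The toy `backbone_q`/`backbone_k` encoders, the 224x224 inputs, and the reuse of identical boxes for both views are illustrative assumptions only.
+```python
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision.ops import roi_align
+
+
+def object_level_loss(feat_q, feat_k, boxes, out_size=7, image_size=224):
+    """feat_q / feat_k: (N, C, H, W) feature maps from the two branches.
+    boxes: list of (K_i, 4) tensors in x1y1x2y2 image coordinates, one per image."""
+    scale = feat_q.shape[-1] / image_size  # map image coordinates to feature-map coordinates
+    q = roi_align(feat_q, boxes, output_size=out_size, spatial_scale=scale)
+    k = roi_align(feat_k, boxes, output_size=out_size, spatial_scale=scale)
+    q = F.normalize(q.flatten(1), dim=1)
+    k = F.normalize(k.flatten(1), dim=1)
+    # negative cosine similarity between matched object-level embeddings
+    return 2 - 2 * (q * k).sum(dim=1).mean()
+
+
+if __name__ == "__main__":
+    backbone_q = nn.Conv2d(3, 64, 3, stride=32, padding=1)  # toy stand-ins for the ResNet-FPN encoders
+    backbone_k = nn.Conv2d(3, 64, 3, stride=32, padding=1)
+    view1, view2 = torch.randn(2, 3, 224, 224), torch.randn(2, 3, 224, 224)
+    boxes = [torch.tensor([[10., 10., 100., 120.]]),  # one proposal per image
+             torch.tensor([[30., 40., 200., 180.]])]
+    print(object_level_loss(backbone_q(view1), backbone_k(view2), boxes).item())
+```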
+
+
+### Architecture
+![](figures/overview.png)
+
+
+## Main results
+
+### SoCo pre-trained models
+| Model | Arch | Epochs | Scripts | Download |
+|:-----:|:------------:|:------:|:---------------------------------------------------:|:--------:|
+| SoCo | ResNet50-C4 | 100 | [SoCo_C4_100ep](tools/SoCo_C4_100ep.sh) | |
+| SoCo | ResNet50-C4 | 400 | [SoCo_C4_400ep](tools/SoCo_C4_400ep.sh) | |
+| SoCo | ResNet50-FPN | 100 | [SoCo_FPN_100ep](tools/SoCo_FPN_100ep.sh) | |
+| SoCo | ResNet50-FPN | 400 | [SoCo_FPN_400ep](tools/SoCo_FPN_400ep.sh) | |
+| SoCo* | ResNet50-FPN | 400 | [SoCo_FPN_Star_400ep](tools/SoCo_FPN_Star_400ep.sh) | |
+
+
+### Results on COCO with Mask R-CNN **R50-FPN**
+| Methods | Epoch | AP<sup>bb</sup> | AP<sup>bb</sup><sub>50</sub> | AP<sup>bb</sup><sub>75</sub> | AP<sup>mk</sup> | AP<sup>mk</sup><sub>50</sub> | AP<sup>mk</sup><sub>75</sub> | Detectron2 trained |
+|------------|-------|------|--------|--------|------|--------|--------|--------------------|
+| Scratch | - | 31.0 | 49.5 | 33.2 | 28.5 | 46.8 | 30.4 | |
+| Supervised | 90 | 38.9 | 59.6 | 42.7 | 35.4 | 56.5 | 38.1 | |
+| SoCo | 100 | 42.3 | 62.5 | 46.5 | 37.6 | 59.1 | 40.5 | |
+| SoCo | 400 | 43.0 | 63.3 | 47.1 | 38.2 | 60.2 | 41.0 | |
+| SoCo* | 400 | 43.2 | 63.5 | 47.4 | 38.4 | 60.2 | 41.4 | |
+
+
+### Results on COCO with Mask R-CNN **R50-C4**
+| Methods | Epoch | AP<sup>bb</sup> | AP<sup>bb</sup><sub>50</sub> | AP<sup>bb</sup><sub>75</sub> | AP<sup>mk</sup> | AP<sup>mk</sup><sub>50</sub> | AP<sup>mk</sup><sub>75</sub> | Detectron2 trained |
+|------------|-------|------|--------|--------|------|--------|--------|--------------------|
+| Scratch | - | 26.4 | 44.0 | 27.8 | 29.3 | 46.9 | 30.8 | |
+| Supervised | 90 | 38.2 | 58.2 | 41.2 | 33.3 | 54.7 | 35.2 | |
+| SoCo | 100 | 40.4 | 60.4 | 43.7 | 34.9 | 56.8 | 37.0 | |
+| SoCo | 400 | 40.9 | 60.9 | 44.3 | 35.3 | 57.5 | 37.3 | |
+
+
+## Get started
+### Requirements
+A [Dockerfile](docker/Dockerfile) is included; please refer to it for the full environment setup.
+
+
+### Prepare data with Selective Search
+1. Generate Selective Search proposals (a minimal sketch of this step follows the list)
+   ```shell
+   python selective_search/generate_imagenet_ss_proposals.py
+   ```
+2. Filter out invalid proposals with the filtering strategy
+   ```shell
+   python selective_search/filter_ss_proposals_json.py
+   ```
+3. Post-process images that are left with no proposals
+   ```shell
+   python selective_search/filter_ss_proposals_json_post_no_prop.py
+   ```
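+
+The following standalone sketch (not the scripts above) shows roughly what step 1 produces: Selective Search proposals for one image, stored in a JSON dict keyed by the image basename, which is how `make_props_dataset_with_ann` reads them back. The exact schema and filtering rules are defined by the scripts; the x1y1x2y2 box format and the file names `example.jpg` / `proposals.json` are assumptions for illustration.
+```python
+import json
+import os
+
+import cv2  # requires opencv-contrib-python for cv2.ximgproc
+
+
+def selective_search_boxes(image_path, fast=True):
+    ss = cv2.ximgproc.segmentation.createSelectiveSearchSegmentation()
+    ss.setBaseImage(cv2.imread(image_path))
+    if fast:
+        ss.switchToSelectiveSearchFast()
+    else:
+        ss.switchToSelectiveSearchQuality()
+    rects = ss.process()  # each rect is (x, y, w, h)
+    # convert to x1, y1, x2, y2 with x2 = x1 + w - 1, matching the dataset comments
+    return [[int(x), int(y), int(x + w - 1), int(y + h - 1)] for x, y, w, h in rects]
+
+
+if __name__ == "__main__":
+    image_path = "example.jpg"  # placeholder image
+    basename = os.path.splitext(os.path.basename(image_path))[0]
+    with open("proposals.json", "w") as f:
+        json.dump({basename: selective_search_boxes(image_path)}, f)
+```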
+
+
+### Pretrain with SoCo
+> The 100-epoch SoCo FPN setting is used as an example below.
+```shell
+bash ./tools/SoCo_FPN_100ep.sh
+```
+
+
+### Finetune detector
+1. Copy the folder `detectron2_configs` to the root folder of `Detectron2`
+2. Train the detectors with `Detectron2` (a hedged sketch of this step is shown below)
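+
+Below is a hedged sketch of step 2 using Detectron2's Python API rather than its `train_net.py` script; the config filename is a placeholder for one of the files copied from `detectron2_configs`, and `MODEL.WEIGHTS` is assumed to point to a checkpoint converted from SoCo pretraining.
+```python
+from detectron2.config import get_cfg
+from detectron2.engine import DefaultTrainer
+
+
+def main():
+    cfg = get_cfg()
+    cfg.merge_from_file("detectron2_configs/YOUR_CONFIG.yaml")  # placeholder config path
+    cfg.MODEL.WEIGHTS = "soco_pretrained_converted.pkl"         # placeholder converted checkpoint
+    cfg.freeze()
+    trainer = DefaultTrainer(cfg)
+    trainer.resume_or_load(resume=False)
+    trainer.train()
+
+
+if __name__ == "__main__":
+    main()  # single-process sketch; multi-GPU training would go through detectron2.engine.launch
+```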
+
+
+## Citation
+```bib
+@article{wei2021aligning,
+ title={Aligning Pretraining for Detection via Object-Level Contrastive Learning},
+ author={Wei, Fangyun and Gao, Yue and Wu, Zhirong and Hu, Han and Lin, Stephen},
+ journal={arXiv preprint arXiv:2106.02637},
+ year={2021}
+}
+```
\ No newline at end of file
diff --git a/contrast/__init__.py b/contrast/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/contrast/data/__init__.py b/contrast/data/__init__.py
new file mode 100644
index 0000000..ab515b5
--- /dev/null
+++ b/contrast/data/__init__.py
@@ -0,0 +1,138 @@
+# --------------------------------------------------------
+# SoCo
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Yue Gao
+# --------------------------------------------------------
+
+
+import os
+
+import numpy as np
+import torch.distributed as dist
+from torch.utils.data import DataLoader, SubsetRandomSampler
+from torch.utils.data.distributed import DistributedSampler
+
+from .dataset import (ImageFolder, ImageFolderImageAsymBboxAwareMulti3ResizeExtraJitter1,
+ ImageFolderImageAsymBboxAwareMultiJitter1,
+ ImageFolderImageAsymBboxAwareMultiJitter1Cutout,
+ ImageFolderImageAsymBboxCutout)
+from .sampler import SubsetSlidingWindowSampler
+from .transform import get_transform
+
+
+def get_loader(aug_type, args, two_crop=False, prefix='train', return_coord=False):
+ transform = get_transform(args, aug_type, args.crop, args.image_size, crop1=args.crop1,
+ cutout_prob=args.cutout_prob, cutout_ratio=args.cutout_ratio,
+ image3_size=args.image3_size, image4_size=args.image4_size)
+
+ # dataset
+ if args.zip:
+ if args.dataset == 'ImageNet':
+ train_ann_file = prefix + f"_{args.split_map}.txt"
+ train_prefix = prefix + ".zip@/"
+ if args.ss_props:
+ train_props_file = prefix + f"_{args.filter_strategy}.json"
+ elif args.rpn_props:
+ train_props_file = f"rpn_props_nms_{args.nms_threshold}.json"
+        elif args.dataset in ('COCO', 'Object365', 'ImageNetObject365', 'OpenImage',
+                              'ImageNetOpenImage', 'ImageNetObject365OpenImage'):
+            # NOTE: these datasets follow the same zip/annotation scheme as ImageNet
+            prefix = 'trainunlabeled2017' if args.dataset == 'COCO' else 'train'
+            train_ann_file = prefix + "_map.txt"
+            train_prefix = prefix + ".zip@/"
+            train_props_file = prefix + f"_{args.filter_strategy}.json"
+ else:
+            raise NotImplementedError('Dataset {} is not supported.'.format(args.dataset))
+
+ if aug_type == 'ImageAsymBboxCutout':
+ train_dataset = ImageFolderImageAsymBboxCutout(args.data_dir, train_ann_file, train_prefix,
+ train_props_file, image_size=args.image_size, select_strategy=args.select_strategy,
+ select_k=args.select_k, weight_strategy=args.weight_strategy,
+ jitter_ratio=args.jitter_ratio, padding_k=args.padding_k,
+ aware_range=args.aware_range, aware_start=args.aware_start, aware_end=args.aware_end,
+ max_tries=args.max_tries,
+ transform=transform, cache_mode=args.cache_mode,
+ dataset=args.dataset)
+
+ elif aug_type == 'ImageAsymBboxAwareMultiJitter1':
+ train_dataset = ImageFolderImageAsymBboxAwareMultiJitter1(args.data_dir, train_ann_file, train_prefix,
+ train_props_file, image_size=args.image_size, select_strategy=args.select_strategy,
+ select_k=args.select_k, weight_strategy=args.weight_strategy,
+ jitter_ratio=args.jitter_ratio, padding_k=args.padding_k,
+ aware_range=args.aware_range, aware_start=args.aware_start, aware_end=args.aware_end,
+ max_tries=args.max_tries,
+ transform=transform, cache_mode=args.cache_mode,
+ dataset=args.dataset)
+
+ elif aug_type == 'ImageAsymBboxAwareMultiJitter1Cutout':
+ train_dataset = ImageFolderImageAsymBboxAwareMultiJitter1Cutout(args.data_dir, train_ann_file, train_prefix,
+ train_props_file, image_size=args.image_size, select_strategy=args.select_strategy,
+ select_k=args.select_k, weight_strategy=args.weight_strategy,
+ jitter_ratio=args.jitter_ratio, padding_k=args.padding_k,
+ aware_range=args.aware_range, aware_start=args.aware_start, aware_end=args.aware_end,
+ max_tries=args.max_tries,
+ transform=transform, cache_mode=args.cache_mode,
+ dataset=args.dataset)
+
+ elif aug_type == 'ImageAsymBboxAwareMulti3ResizeExtraJitter1':
+ train_dataset = ImageFolderImageAsymBboxAwareMulti3ResizeExtraJitter1(args.data_dir, train_ann_file, train_prefix,
+ train_props_file, image_size=args.image_size, image3_size=args.image3_size,
+ image4_size=args.image4_size,
+ select_strategy=args.select_strategy,
+ select_k=args.select_k, weight_strategy=args.weight_strategy,
+ jitter_ratio=args.jitter_ratio, padding_k=args.padding_k,
+ aware_range=args.aware_range, aware_start=args.aware_start, aware_end=args.aware_end,
+ max_tries=args.max_tries,
+ transform=transform, cache_mode=args.cache_mode,
+ dataset=args.dataset)
+ elif aug_type == 'NULL':
+ train_dataset = ImageFolder(args.data_dir, train_ann_file, train_prefix,
+ transform, two_crop=two_crop, cache_mode=args.cache_mode,
+ dataset=args.dataset, return_coord=return_coord)
+ else:
+ raise NotImplementedError
+
+ else:
+        # raw image-folder (non-zip) loading is not supported for SoCo pretraining
+        raise NotImplementedError('Only zip-mode datasets are supported; please enable args.zip')
+
+ # sampler
+ indices = np.arange(dist.get_rank(), len(train_dataset), dist.get_world_size())
+ if args.use_sliding_window_sampler:
+ sampler = SubsetSlidingWindowSampler(indices,
+ window_stride=args.window_stride // dist.get_world_size(),
+ window_size=args.window_size // dist.get_world_size(),
+ shuffle_per_epoch=args.shuffle_per_epoch)
+ elif args.zip and args.cache_mode == 'part':
+ sampler = SubsetRandomSampler(indices)
+ else:
+ sampler = DistributedSampler(train_dataset)
+
+    # dataloader
+ return DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False,
+ num_workers=args.num_workers, pin_memory=True, sampler=sampler, drop_last=True)
diff --git a/contrast/data/bboxs_utils.py b/contrast/data/bboxs_utils.py
new file mode 100644
index 0000000..61f4b6e
--- /dev/null
+++ b/contrast/data/bboxs_utils.py
@@ -0,0 +1,483 @@
+# --------------------------------------------------------
+# SoCo
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Yue Gao
+# --------------------------------------------------------
+
+
+import math
+import random
+
+import numpy as np
+import torch
+import torchvision.transforms as T
+from mmdet.core import anchor_inside_flags
+
+
+def overlap_image_bbox_w_id(image_bbox, bbox):
+ ix1, iy1, ix2, iy2 = image_bbox
+ bx1, by1, bx2, by2, bid = bbox
+
+ if ix1 < bx2 and bx1 < ix2 and iy1 < by2 and by1 < iy2:
+ nx1, ny1, nx2, ny2 = max(ix1, bx1), max(iy1, by1), min(ix2, bx2), min(iy2, by2)
+ # minus the crop image x, y
+ nx1, ny1 = nx1 - ix1, ny1 - iy1
+ nx2, ny2 = nx2 - ix1, ny2 - iy1
+ return np.array([nx1, ny1, nx2, ny2, bid])
+ else:
+ return None
+
+
+def cal_overlap_params(params1, params2):
+ y11, x11, h1, w1, _, _ = params1
+ y21, x21, h2, w2, _, _ = params2
+ y12, x12 = y11 + h1 - 1, x11 + w1 - 1
+ y22, x22 = y21 + h2 - 1, x21 + w2 - 1
+
+ if x11 < x22 and x21 < x12 and y11 < y22 and y21 < y12:
+ nx1, ny1, nx2, ny2 = max(x11, x21), max(y11, y21), min(x12, x22), min(y12, y22)
+ return np.array([nx1, ny1, nx2, ny2])
+ else:
+ return None
+
+
+def is_overlap(params_overlap, bbox):
+ px1, py1, px2, py2 = params_overlap
+ bx1, by1, bx2, by2, _ = bbox
+
+ if px1 < bx2 and bx1 < px2 and py1 < by2 and by1 < py2:
+ return True
+ else:
+ return False
+
+
+def clip_bboxs(bboxs, top, left, height, width):
+ clipped_bboxs_w_id_list = []
+ clip_image_bbox = np.array([left, top, left + width - 1, top + height - 1])
+ for cur_bbox in bboxs:
+ overlap_bbox = overlap_image_bbox_w_id(clip_image_bbox, cur_bbox)
+ # assert overlap_bbox is not None, print("clip_image_bbox", clip_image_bbox, "cur_bbox", cur_bbox)
+ if overlap_bbox is not None:
+ clipped_bboxs_w_id_list.append(overlap_bbox.reshape(1, overlap_bbox.shape[0]))
+
+ if len(clipped_bboxs_w_id_list) > 0:
+ clipped_bboxs_w_id = np.concatenate(clipped_bboxs_w_id_list, axis=0)
+ else:
+ clipped_bboxs_w_id = np.array([[0, 0, width - 1, height - 1, 1]]) # just whole view
+ return clipped_bboxs_w_id
+
+
+def clip_bboxs_in_jitter(bboxs, top, left, height, width):
+ clipped_bboxs_w_id_list = []
+ clip_image_bbox = np.array([left, top, left + width - 1, top + height - 1])
+ for cur_bbox in bboxs:
+ overlap_bbox = overlap_image_bbox_w_id(clip_image_bbox, cur_bbox)
+ if overlap_bbox is not None:
+ clipped_bboxs_w_id_list.append(overlap_bbox.reshape(1, overlap_bbox.shape[0]))
+
+ if len(clipped_bboxs_w_id_list) > 0:
+ clipped_bboxs_w_id = np.concatenate(clipped_bboxs_w_id_list, axis=0)
+ return clipped_bboxs_w_id
+ else:
+ return None
+
+
+def get_overlap_props(proposals, overlap_region):
+ if overlap_region is None:
+ return np.array([])
+ common_props = []
+ for prop in proposals:
+ if is_overlap(overlap_region, prop):
+ common_props.append(prop.reshape(1, prop.shape[0]))
+ if len(common_props) > 0:
+ common_props = np.concatenate(common_props, axis=0)
+ return common_props
+ else:
+ return np.array([])
+
+
+def resize_bboxs(clipped_bboxs_w_id, height, width, size):
+    bboxs_w_id = np.copy(clipped_bboxs_w_id).astype(float)  # copy so the caller's boxes are not modified in place
+ bboxs_w_id[:, 0] = bboxs_w_id[:, 0] / width * size[0] # x1
+ bboxs_w_id[:, 1] = bboxs_w_id[:, 1] / height * size[1] # y1
+ bboxs_w_id[:, 2] = bboxs_w_id[:, 2] / width * size[0] # x2
+ bboxs_w_id[:, 3] = bboxs_w_id[:, 3] / height * size[1] # y2
+ return bboxs_w_id
+
+
+def resize_bboxs_vis(clipped_bboxs_w_id, size):
+    bboxs_w_id = np.copy(clipped_bboxs_w_id).astype(float)  # copy so the caller's boxes are not modified in place
+ bboxs_w_id[:, 0] = bboxs_w_id[:, 0] * size[0] # x1
+ bboxs_w_id[:, 1] = bboxs_w_id[:, 1] * size[1] # y1
+ bboxs_w_id[:, 2] = bboxs_w_id[:, 2] * size[0] # x2
+ bboxs_w_id[:, 3] = bboxs_w_id[:, 3] * size[1] # y2
+ return bboxs_w_id
+
+
+def resize_bboxs_and_assign_labels(cropped_bboxs, height, width, size, bbox_size_range):
+ resized_bboxs_w_labels = torch.empty((cropped_bboxs.shape[0], cropped_bboxs.shape[1]+2), requires_grad=False)
+ resized_bboxs_vis = np.zeros_like(cropped_bboxs)
+    # 2 extra columns hold a one-hot p4/p5 assignment, determined by bbox_size_range
+ min_size = bbox_size_range[0] # (32, 112, 224)
+ mid_size = bbox_size_range[1]
+ for i in range(cropped_bboxs.shape[0]):
+ cur_bbox = cropped_bboxs[i]
+ if cur_bbox[2] <= 0 or cur_bbox[3] <= 0:
+ # valid coordinates but we do not use it
+ resized_bboxs_w_labels[i] = torch.Tensor([0.0, 0.0, 1.0, 1.0, 0.0, 0.0])
+ continue
+ x, y, w, h = cur_bbox
+ nx = x / width * size[0]
+ ny = y / height * size[1]
+ nw = w / width * size[0]
+ nh = h / height * size[1]
+ resized_bboxs_vis[i] = np.array([nx, ny, nw, nh])
+ # NOTE: we swap x, y and w, h to align with image, and turn into 0~1
+ short_side = min(nw, nh)
+ long_side = max(nw, nh)
+ if short_side >= min_size:
+ if long_side < mid_size: # use p4
+ resized_bboxs_w_labels[i] = torch.Tensor([ny / size[1], nx / size[0], (ny + nh) / size[1], (nx + nw) / size[0], 1.0, 0.0])
+ else: # use p5
+ resized_bboxs_w_labels[i] = torch.Tensor([ny / size[1], nx / size[0], (ny + nh) / size[1], (nx + nw) / size[0], 0.0, 1.0])
+ else:
+ resized_bboxs_w_labels[i] = torch.Tensor([0.0, 0.0, 1.0, 1.0, 0.0, 0.0])
+ # print("resized bboxs size tensor", resized_bboxs_w_labels)
+ # print("resized bboxs size vis numpy", resized_bboxs_vis)
+ return resized_bboxs_w_labels, resized_bboxs_vis
+
+
+def resized_crop_bboxs_assign_labels_with_ids(bboxs_w_id, top, left, height, width, size, bbox_size_range):
+ # bboxs_w_id, x, y, w, h, are in image size range, not in [0, 1]
+ cropped_bboxs_w_id = clip_bboxs(bboxs_w_id, top, left, height, width)
+ resized_cropped_bboxs, resized_bboxs_w_id_vis = resize_bboxs_and_assign_labels(cropped_bboxs_w_id, height, width, size, bbox_size_range)
+ # resize_cropped_bboxs_w_id is same shape with bboxs_w_id, we pad zeros to do data batching
+ return resized_cropped_bboxs, resized_bboxs_w_id_vis
+
+
+def clip_and_resize_bboxs_w_ids(bboxs_w_id, top, left, height, width, size):
+ clipped_bboxs_w_id = clip_bboxs(bboxs_w_id, top, left, height, width)
+ resized_clipped_bboxs_w_id = resize_bboxs(clipped_bboxs_w_id, height, width, size)
+ return resized_clipped_bboxs_w_id
+
+
+def get_common_bboxs_ids(bboxs1, bboxs2):
+ w = np.where(np.in1d(bboxs1[:, 4], bboxs2[:, 4]))[0] # intersect of ids, here bboxs ids are unique
+ common_bboxs_ids = bboxs1[w][:, 4]
+ return common_bboxs_ids
+
+
+def jitter_bboxs(bboxs, common_bboxs_ids, jitter_ratio, pad_num, height, width):
+ common_indices = np.isin(bboxs[:, 4], common_bboxs_ids)
+ common_bboxs = bboxs[common_indices]
+ clipped_jittered_bboxs_list = []
+ remaining_pad = pad_num
+ while remaining_pad > 0:
+ selected_bboxs = common_bboxs[np.random.choice(common_bboxs.shape[0], remaining_pad)]
+ jitters = np.random.uniform(low=-jitter_ratio, high=jitter_ratio, size=(remaining_pad, 4))
+
+ jittered_bboxs = np.copy(selected_bboxs).astype(float)
+ selected_bboxs_w = selected_bboxs[:, 2] - selected_bboxs[:, 0] + 1 # w = x2 - x1 + 1
+ selected_bboxs_h = selected_bboxs[:, 3] - selected_bboxs[:, 1] + 1 # h = y2 - y1 + 1
+ jittered_w = selected_bboxs_w + jitters[:, 2] * selected_bboxs_w # nw = w + j * w
+ jittered_h = selected_bboxs_h + jitters[:, 3] * selected_bboxs_h # nh = h + j * h
+ jittered_bboxs[:, 0] = selected_bboxs[:, 0] + jitters[:, 0] * selected_bboxs_w # nx1 = x1 + j * w
+ jittered_bboxs[:, 1] = selected_bboxs[:, 1] + jitters[:, 1] * selected_bboxs_h # ny1 = y1 + j * h
+ jittered_bboxs[:, 2] = jittered_bboxs[:, 0] + jittered_w - 1 # nx2 = nx1 + nw - 1
+ jittered_bboxs[:, 3] = jittered_bboxs[:, 1] + jittered_h - 1 # ny2 = ny1 + nh - 1
+
+ clipped_jittered_bboxs = clip_bboxs_in_jitter(jittered_bboxs, 0, 0, height, width)
+ if clipped_jittered_bboxs is not None and clipped_jittered_bboxs.shape[0] > 0:
+ clipped_jittered_bboxs_list.append(clipped_jittered_bboxs)
+ remaining_pad -= clipped_jittered_bboxs.shape[0]
+
+ padded_clipped_jittered_bboxs = np.concatenate(clipped_jittered_bboxs_list, axis=0)
+
+ return padded_clipped_jittered_bboxs
+
+
+def jitter_props(selected_image_props, jitter_prob, jitter_ratio):
+ jittered_props = []
+ for prop in selected_image_props:
+ jitter_r = random.random()
+ jittered_prop = np.copy(prop).astype(float)
+ if jitter_r < jitter_prob:
+ jitter = np.random.uniform(low=-jitter_ratio, high=jitter_ratio, size=(4, ))
+ w = prop[2] - prop[0] + 1
+ h = prop[3] - prop[1] + 1
+
+ jittered_w = w + jitter[2] * w # nw = w + j * w
+ jittered_h = h + jitter[3] * h # nh = h + j * h
+
+ jittered_prop[0] = prop[0] + jitter[0] * w # nx1 = x1 + j * w
+ jittered_prop[1] = prop[1] + jitter[1] * h # ny1 = y1 + j * h
+
+ jittered_prop[2] = jittered_prop[0] + jittered_w - 1 # nx2 = nx1 + nw - 1
+ jittered_prop[3] = jittered_prop[1] + jittered_h - 1 # ny2 = ny1 + nh - 1
+
+ jittered_prop = jittered_prop.reshape(1, jittered_prop.shape[0])
+ jittered_props.append(jittered_prop)
+
+ if len(jittered_props) > 0:
+ jittered_props_np = np.concatenate(jittered_props, axis=0)
+ return jittered_props_np
+ else:
+ return np.array([])
+
+
+def random_generate_props(image_size, r=3.0, min_ratio=0.3, max_ratio=0.8, max_props=32):
+ props = []
+ for _ in range(max_props):
+ sqrt_area = math.sqrt(image_size[0] * image_size[1])
+ target_sqrt_area = random.uniform(min_ratio, max_ratio) * sqrt_area
+ target_area = target_sqrt_area * target_sqrt_area
+ aspect_ratio = random.uniform(1/r, r)
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
+ if w < 32 or h < 32 or w > image_size[0] or h > image_size[1]:
+ continue
+        center_x = random.randint(w // 2, image_size[0] - w // 2)
+        center_y = random.randint(h // 2, image_size[1] - h // 2)
+        # box of size (w, h) centered at (center_x, center_y), clipped to the image
+        x1 = max(center_x - w // 2, 0)
+        x2 = min(center_x + w // 2, image_size[0])
+        y1 = max(center_y - h // 2, 0)
+        y2 = min(center_y + h // 2, image_size[1])
+ prop = np.array([x1, y1, x2, y2]).reshape((1, 4))
+ props.append(prop)
+ if len(props) > 0:
+ proposals = np.concatenate(props)
+ else:
+ proposals = np.array([])
+ return proposals
+
+
+def pad_bboxs_with_common(bboxs, common_bboxs_ids, jitter_ratio, pad_num, height, width):
+ common_indices = np.isin(bboxs[:, 4], common_bboxs_ids)
+ common_bboxs = bboxs[common_indices]
+ selected_bboxs = common_bboxs[np.random.choice(common_bboxs.shape[0], pad_num)]
+ return selected_bboxs
+
+
+def get_correspondence_matrix(bboxs1, bboxs2):
+ # intersect of ids, here bboxs ids can be duplicate
+ assert bboxs1.shape[0] == bboxs2.shape[0]
+ L = bboxs1.shape[0]
+ bboxs1_ids = bboxs1[:, 4]
+ bboxs2_ids = bboxs2[:, 4]
+ bboxs1_ids = np.reshape(bboxs1_ids, (L, 1))
+ bboxs2_ids = np.reshape(bboxs2_ids, (1, L))
+ bboxs1_m = np.tile(bboxs1_ids, (1, L))
+ bboxs2_m = np.tile(bboxs2_ids, (L, 1))
+ correspondence_matrix = bboxs1_m == bboxs2_m
+ correspondence_matrix = correspondence_matrix.astype(float)
+ correspondence_matrix = torch.Tensor(correspondence_matrix)
+ return correspondence_matrix
+
+
+def calculate_centerness_targets_from_bbox(bbox):
+ x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
+ w = x2 - x1 + 1
+ h = y2 - y1 + 1
+ left = np.tile(np.array([i for i in range(w)]).reshape(w, 1), (1, h))
+ right = np.tile(np.array([w - i - 1 for i in range(w)]).reshape(w, 1), (1, h))
+ top = np.tile(np.array([i for i in range(h)]).reshape(1, h), (w, 1))
+ bottom = np.tile(np.array([h - i - 1 for i in range(h)]).reshape(1, h), (w, 1))
+
+ left_right_min = np.minimum(left, right)
+ left_right_max = np.maximum(left, right) + 1e-6
+ top_bottom_min = np.minimum(top, bottom)
+ top_bottom_max = np.maximum(top, bottom) + 1e-6
+
+ centerness_targets = (left_right_min / left_right_max) * (top_bottom_min / top_bottom_max)
+ centerness_targets = np.sqrt(centerness_targets)
+ return centerness_targets
+
+
+def calculate_weight_map_bboxs(bboxs, size, weight_strategy):
+ if weight_strategy == 'no_weights':
+ weight_map = np.ones(size).astype(float)
+ return weight_map
+
+ weight_map = np.zeros(size).astype(float)
+ for bbox in bboxs:
+ # print("calculate_weight_map_bboxs bbox", bbox)
+ x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
+ w = x2 - x1 + 1
+ h = y2 - y1 + 1
+ if w < 1 or h < 1:
+ continue
+ if weight_strategy == 'bbox':
+ weight_map[x1:x2+1, y1:y2+1] += 1.0
+ elif weight_strategy == 'center':
+            center_x = (x1 + x2) // 2  # integer center of the box
+            center_y = (y1 + y2) // 2
+ weight_map[center_x, center_y] += 1.0
+ elif weight_strategy == 'gaussian':
+ centerness_targets = calculate_centerness_targets_from_bbox(bbox)
+ weight_map[x1:x2+1, y1:y2+1] += centerness_targets
+ else:
+ raise NotImplementedError
+
+ return weight_map
+
+
+def proposals_to_tensor(props):
+ props = props.astype(float)
+ props_tensor = torch.Tensor(props)
+ return props_tensor
+
+
+def bboxs_to_tensor(bboxs, params):
+ """ x1y1x2y2, -> 0, 1
+ """
+ _, _, height, width, _, _ = params
+ bboxs = bboxs.astype(float)
+ bboxs_new = np.copy(bboxs) # keep ids
+
+ bboxs_new[:, 0] = bboxs_new[:, 0] / width # x1
+ bboxs_new[:, 1] = bboxs_new[:, 1] / height # y1
+ bboxs_new[:, 2] = bboxs_new[:, 2] / width # x2
+ bboxs_new[:, 3] = bboxs_new[:, 3] / height # y2
+
+ bboxs_tensor = torch.Tensor(bboxs_new)
+
+ return bboxs_tensor
+
+
+def bboxs_to_tensor_dynamic(bboxs, params, dynamic_params, image_size):
+ """ x1y1x2y2, -> 0, 1
+ """
+ _, _, height, width, _, _ = params
+ dynamic_resize = dynamic_params[3]
+ bboxs = bboxs.astype(float)
+ bboxs_new = np.copy(bboxs) # keep ids
+
+ # raw_size -> raw_ratio -> dynamic_size -> padded_ratio
+
+ bboxs_new[:, 0] = bboxs_new[:, 0] / width * dynamic_resize[0] / image_size[0] # x1
+ bboxs_new[:, 1] = bboxs_new[:, 1] / height * dynamic_resize[1] / image_size[1] # y1
+ bboxs_new[:, 2] = bboxs_new[:, 2] / width * dynamic_resize[0] / image_size[0] # x2
+ bboxs_new[:, 3] = bboxs_new[:, 3] / height * dynamic_resize[1] / image_size[1] # y2
+
+ bboxs_tensor = torch.Tensor(bboxs_new)
+
+ return bboxs_tensor
+
+
+def weight_to_tensor(weight):
+ """ w, h -> h, w
+ """
+ weight_tensor = torch.Tensor(weight)
+
+ weight_tensor = weight_tensor.unsqueeze(0) # channel 1
+ weight_tensor = torch.transpose(weight_tensor, 1, 2)
+ return weight_tensor
+
+
+def assign_bboxs_to_feature_map(resized_bboxs, aware_range, aware_start, aware_end, not_used_value=-1):
+ """aware_range
+ """
+ L = resized_bboxs.shape[0]
+ P = aware_end - aware_start
+ assert P > 0
+
+    bboxs_id_assigned = np.ones((P, L)) * not_used_value  # not_used_value marks unassigned slots; it must differ from 0 and from real box ids
+ for i, bbox in enumerate(resized_bboxs):
+ x1, y1, x2, y2, bid = bbox
+ w = x2 - x1 + 1
+ h = y2 - y1 + 1
+ size = math.sqrt(h * w)
+ for j in range(aware_start, aware_end):
+ if size <= aware_range[j]:
+ bboxs_id_assigned[j - aware_start, i] = bid
+ # print(f"{bid} assigned to feature {j - aware_start}")
+ break # assign bbox to only one feature map
+
+ bboxs_id_assigned = bboxs_id_assigned.reshape((P * L, ))
+ return bboxs_id_assigned
+
+
+def get_aware_correspondence_matrix(bboxs_id_assigned1, bboxs_id_assigned2):
+ PL = bboxs_id_assigned1.shape[0]
+ bboxs1_ids = np.reshape(bboxs_id_assigned1, (PL, 1))
+ bboxs2_ids = np.reshape(bboxs_id_assigned2, (1, PL))
+ bboxs1_m = np.tile(bboxs1_ids, (1, PL))
+ bboxs2_m = np.tile(bboxs2_ids, (PL, 1))
+ correspondence_matrix = bboxs1_m == bboxs2_m
+ correspondence_matrix = correspondence_matrix.astype(float)
+ correspondence_matrix = torch.Tensor(correspondence_matrix)
+ return correspondence_matrix
+
+
+def get_aware_correspondence_matrix_torch(bboxs_id_assigned1, bboxs_id_assigned2):
+ PL = bboxs_id_assigned1.size(0)
+ bboxs1_ids = torch.reshape(bboxs_id_assigned1, (PL, 1))
+ bboxs2_ids = torch.reshape(bboxs_id_assigned2, (1, PL))
+ bboxs1_m = bboxs1_ids.repeat(1, PL)
+ bboxs2_m = bboxs2_ids.repeat(PL, 1)
+ correspondence_matrix = bboxs1_m == bboxs2_m
+ correspondence_matrix = correspondence_matrix.float()
+
+ return correspondence_matrix
+
+
+def visualize_image_tensor(image_tensor, visualize_name):
+ image_pil = T.functional.to_pil_image(image_tensor)
+ path = f'self_det/visualization/{visualize_name}.jpg'
+ image_pil.save(path)
+
+
+def assign_gt_bboxs_to_feature_map_with_anchors(anchor_generator, assigner, gt_bboxs, view_size, img_meta, levels=4, not_used_value=-1):
+ # assign gt bbox based on number of anchors in the feature map
+ # SINGLE image version
+
+ P = levels
+ L = gt_bboxs.size(0) # each image number of gts
+
+ featmap_sizes = []
+ for l in range(levels):
+ feat_size_0 = int(math.ceil(view_size[0]/2**(l+2))) # p2 - p5
+ feat_size_1 = int(math.ceil(view_size[1]/2**(l+2))) # p2 - p5
+ feat_size_tensor = torch.tensor((feat_size_0, feat_size_1))
+ featmap_sizes.append(feat_size_tensor)
+
+ # we compute multi level anchors and valid flags
+ multi_level_anchors = anchor_generator.grid_anchors(featmap_sizes, device='cpu')
+ multi_level_flags = anchor_generator.valid_flags(featmap_sizes, img_meta['pad_shape'], device='cpu')
+
+ num_level_anchors = [anchors.size(0) for anchors in multi_level_anchors]
+ num_level_anchors_agg = [0]
+ for level_num in num_level_anchors:
+ num_level_anchors_agg.append(num_level_anchors_agg[-1] + level_num)
+
+ # concat all level anchors to a single tensor
+ flat_anchors = torch.cat(multi_level_anchors)
+ flat_valid_flags = torch.cat(multi_level_flags)
+
+ inside_flags = anchor_inside_flags(flat_anchors, flat_valid_flags,
+ img_meta['img_shape'][:2],
+                                       allowed_border=-1)  # -1 matches the default value of train_cfg.rpn.allowed_border
+
+ anchors = flat_anchors[inside_flags, :]
+
+
+ gt_bboxs_coord = gt_bboxs[:, :4].clone() # must clone
+ cur_img_assign_result = assigner.assign(anchors, gt_bboxs_coord, None, None)
+ assigned_gts = cur_img_assign_result.gt_inds
+
+ bboxs_id_assigned = torch.ones((P, L)) * not_used_value
+ for gt_idx in range(L):
+        cur_gt_assign = assigned_gts == (gt_idx+1)  # the assign result indices are 1-based
+ cur_gt_level_assign_sum = [None] * P
+ for level_i in range(P):
+ level_start = num_level_anchors_agg[level_i]
+ level_end = num_level_anchors_agg[level_i+1]
+ cur_gt_level_assign_sum[level_i] = torch.sum(cur_gt_assign[level_start:level_end]).item()
+ cur_gt_assign_level = cur_gt_level_assign_sum.index(max(cur_gt_level_assign_sum))
+
+ bboxs_id_assigned[cur_gt_assign_level, gt_idx] = gt_bboxs[gt_idx, 4] # assign bbox id
+
+ bboxs_id_assigned = bboxs_id_assigned.reshape((P * L, ))
+
+ return bboxs_id_assigned
diff --git a/contrast/data/dataset.py b/contrast/data/dataset.py
new file mode 100644
index 0000000..abebd24
--- /dev/null
+++ b/contrast/data/dataset.py
@@ -0,0 +1,943 @@
+import io
+import json
+import logging
+import os
+import time
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.utils.data as data
+from PIL import Image
+
+from .bboxs_utils import (cal_overlap_params, clip_bboxs,
+ get_common_bboxs_ids, get_overlap_props, get_correspondence_matrix,
+ pad_bboxs_with_common, bboxs_to_tensor, resize_bboxs,
+ assign_bboxs_to_feature_map, get_aware_correspondence_matrix, jitter_props)
+from .props_utils import select_props, convert_props
+from .selective_search_utils import append_prop_id
+from .zipreader import ZipReader, is_zip_path
+
+
+def has_file_allowed_extension(filename, extensions):
+ """Checks if a file is an allowed extension.
+ Args:
+ filename (string): path to a file
+ Returns:
+ bool: True if the filename ends with a known image extension
+ """
+ filename_lower = filename.lower()
+ return any(filename_lower.endswith(ext) for ext in extensions)
+
+
+def find_classes(dir):
+ classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
+ classes.sort()
+ class_to_idx = {classes[i]: i for i in range(len(classes))}
+ return classes, class_to_idx
+
+
+def make_dataset(dir, class_to_idx, extensions):
+ images = []
+ dir = os.path.expanduser(dir)
+ for target in sorted(os.listdir(dir)):
+ d = os.path.join(dir, target)
+ if not os.path.isdir(d):
+ continue
+
+ for root, _, fnames in sorted(os.walk(d)):
+ for fname in sorted(fnames):
+ if has_file_allowed_extension(fname, extensions):
+ path = os.path.join(root, fname)
+ item = (path, class_to_idx[target])
+ images.append(item)
+
+ return images
+
+
+def make_dataset_with_ann(ann_file, img_prefix, extensions, dataset='ImageNet'):
+ images = []
+
+ with open(ann_file, "r") as f:
+ contents = f.readlines()
+ for line_str in contents:
+ path_contents = [c for c in line_str.split()]
+ im_file_name = path_contents[0]
+ class_index = int(path_contents[1])
+
+ assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions
+ item = (os.path.join(img_prefix, im_file_name), class_index)
+
+ images.append(item)
+
+ return images
+
+
+def make_props_dataset_with_ann(ann_file, props_file, select_strategy, select_k, dataset='ImageNet', rpn_props=False, rpn_score_thres=0.5):
+ with open(props_file, "r") as f:
+ props_dict = json.load(f)
+ # make ImageNet or VOC dataset
+ with open(ann_file, "r") as f:
+ contents = f.readlines()
+ images_props = [None] * len(contents)
+ for i, line_str in enumerate(contents):
+ path_contents = [c for c in line_str.split('\t')]
+ im_file_name = path_contents[0]
+ basename = os.path.basename(im_file_name).split('.')[0]
+ all_props = props_dict[basename]
+ converted_props = convert_props(all_props)
+        images_props[i] = converted_props  # keep all proposals
+
+ del contents
+ del props_dict
+ return images_props
+
+
+class DatasetFolder(data.Dataset):
+ """A generic data loader where the samples are arranged in this way: ::
+ root/class_x/xxx.ext
+ root/class_x/xxy.ext
+ root/class_x/xxz.ext
+ root/class_y/123.ext
+ root/class_y/nsdf3.ext
+ root/class_y/asd932_.ext
+ Args:
+ root (string): Root directory path.
+ loader (callable): A function to load a sample given its path.
+ extensions (list[string]): A list of allowed extensions.
+ transform (callable, optional): A function/transform that takes in
+ a sample and returns a transformed version.
+ E.g, ``transforms.RandomCrop`` for images.
+ target_transform (callable, optional): A function/transform that takes
+ in the target and transforms it.
+ Attributes:
+ samples (list): List of (sample path, class_index) tuples
+ """
+
+ def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None,
+ cache_mode="no", dataset='ImageNet'):
+ # image folder mode
+ if ann_file == '':
+ _, class_to_idx = find_classes(root)
+ samples = make_dataset(root, class_to_idx, extensions)
+ # zip mode
+ else:
+ samples = make_dataset_with_ann(os.path.join(root, ann_file),
+ os.path.join(root, img_prefix),
+ extensions,
+ dataset)
+
+ if len(samples) == 0:
+ raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
+ "Supported extensions are: " + ",".join(extensions)))
+
+ self.root = root
+ self.loader = loader
+ self.extensions = extensions
+
+ self.samples = samples
+ self.labels = [y_1k for _, y_1k in samples]
+ self.classes = list(set(self.labels))
+
+ self.transform = transform
+ self.target_transform = target_transform
+
+ self.cache_mode = cache_mode
+ if self.cache_mode != "no":
+ self.init_cache()
+
+ def init_cache(self):
+ assert self.cache_mode in ["part", "full"]
+ n_sample = len(self.samples)
+ global_rank = dist.get_rank()
+ world_size = dist.get_world_size()
+
+ samples_bytes = [None for _ in range(n_sample)]
+ start_time = time.time()
+ for index in range(n_sample):
+ if index % (n_sample//10) == 0:
+ t = time.time() - start_time
+ logger = logging.getLogger(__name__)
+ logger.info(f'cached {index}/{n_sample} takes {t:.2f}s per block')
+ start_time = time.time()
+ path, target = self.samples[index]
+ if self.cache_mode == "full" or index % world_size == global_rank:
+ samples_bytes[index] = (ZipReader.read(path), target)
+ else:
+ samples_bytes[index] = (path, target)
+ self.samples = samples_bytes
+
+ def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+ Returns:
+ tuple: (sample, target) where target is class_index of the target class.
+ """
+ path, target = self.samples[index]
+ sample = self.loader(path)
+ if self.transform is not None:
+ sample = self.transform(sample)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return sample, target
+
+ def __len__(self):
+ return len(self.samples)
+
+ def __repr__(self):
+ fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
+ fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
+ fmt_str += ' Root Location: {}\n'.format(self.root)
+ tmp = ' Transforms (if any): '
+ fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
+ tmp = ' Target Transforms (if any): '
+ fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
+ return fmt_str
+
+
+class DatasetFolderProps(data.Dataset):
+ def __init__(self, root, loader, extensions, ann_file='', img_prefix='', train_props_file='',
+ select_strategy='', select_k=0,
+ transform=None, target_transform=None,
+ cache_mode="no", dataset='ImageNet', rpn_props=False, rpn_score_thres=0.5):
+ # image folder mode
+ if ann_file == '':
+ _, class_to_idx = find_classes(root)
+ samples = make_dataset(root, class_to_idx, extensions)
+ # zip mode
+ else:
+ samples = make_dataset_with_ann(os.path.join(root, ann_file),
+ os.path.join(root, img_prefix),
+ extensions,
+ dataset)
+ samples_props = make_props_dataset_with_ann(os.path.join(root, ann_file),
+ os.path.join(root, train_props_file),
+ select_strategy, select_k,
+ dataset=dataset, rpn_props=rpn_props, rpn_score_thres=rpn_score_thres)
+
+ if len(samples) == 0:
+ raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
+ "Supported extensions are: " + ",".join(extensions)))
+ if len(samples_props) == 0:
+ raise(RuntimeError("Not found the proposal files"))
+
+ self.root = root
+ self.loader = loader
+ self.extensions = extensions
+
+ self.samples = samples
+ self.samples_props = samples_props
+ self.labels = [y_1k for _, y_1k in samples]
+ self.classes = list(set(self.labels))
+
+ self.transform = transform
+ self.target_transform = target_transform
+
+ self.cache_mode = cache_mode
+ if self.cache_mode != "no":
+ self.init_cache()
+
+ def init_cache(self):
+ assert self.cache_mode in ["part", "full"]
+ n_sample = len(self.samples)
+ global_rank = dist.get_rank()
+ world_size = dist.get_world_size()
+
+ samples_bytes = [None for _ in range(n_sample)]
+ start_time = time.time()
+ for index in range(n_sample):
+ if index % (n_sample//10) == 0:
+ t = time.time() - start_time
+ logger = logging.getLogger(__name__)
+ logger.info(f'cached {index}/{n_sample} takes {t:.2f}s per block')
+ start_time = time.time()
+ path, target = self.samples[index]
+ if self.cache_mode == "full" or index % world_size == global_rank:
+ samples_bytes[index] = (ZipReader.read(path), target)
+ else:
+ samples_bytes[index] = (path, target)
+ self.samples = samples_bytes
+
+ def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+ Returns:
+ tuple: (sample, target) where target is class_index of the target class.
+ """
+ path, target = self.samples[index]
+ sample = self.loader(path)
+ if self.transform is not None:
+ sample = self.transform(sample)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return sample, target
+
+ def __len__(self):
+ return len(self.samples)
+
+ def __repr__(self):
+ fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
+ fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
+ fmt_str += ' Root Location: {}\n'.format(self.root)
+ tmp = ' Transforms (if any): '
+ fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
+ tmp = ' Target Transforms (if any): '
+ fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
+ return fmt_str
+
+
+IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
+
+
+def pil_loader(path):
+ # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
+ if isinstance(path, bytes):
+ img = Image.open(io.BytesIO(path))
+ elif is_zip_path(path):
+ data = ZipReader.read(path)
+ img = Image.open(io.BytesIO(data))
+ else:
+ with open(path, 'rb') as f:
+ img = Image.open(f)
+ return img.convert('RGB')
+
+
+def accimage_loader(path):
+ import accimage # type: ignore
+ try:
+ return accimage.Image(path)
+ except IOError:
+ # Potentially a decoding problem, fall back to PIL.Image
+ return pil_loader(path)
+
+
+def default_img_loader(path):
+ from torchvision import get_image_backend
+ if get_image_backend() == 'accimage':
+ return accimage_loader(path)
+ else:
+ return pil_loader(path)
+
+
+
+class ImageFolder(DatasetFolder):
+ """A generic data loader where the images are arranged in this way: ::
+ root/dog/xxx.png
+ root/dog/xxy.png
+ root/dog/xxz.png
+ root/cat/123.png
+ root/cat/nsdf3.png
+ root/cat/asd932_.png
+ Args:
+ root (string): Root directory path.
+ transform (callable, optional): A function/transform that takes in an PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ loader (callable, optional): A function to load an image given its path.
+ Attributes:
+ imgs (list): List of (image path, class_index) tuples
+ """
+
+ def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None,
+ loader=default_img_loader, cache_mode="no", dataset='ImageNet',
+ two_crop=False, return_coord=False):
+ super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
+ ann_file=ann_file, img_prefix=img_prefix,
+ transform=transform, target_transform=target_transform,
+ cache_mode=cache_mode, dataset=dataset)
+ self.imgs = self.samples
+ self.two_crop = two_crop
+ self.return_coord = return_coord
+
+ def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+ Returns:
+ tuple: (image, target) where target is class_index of the target class.
+ """
+ path, target = self.samples[index]
+ image = self.loader(path)
+ if self.transform is not None:
+ if isinstance(self.transform, tuple) and len(self.transform) == 2:
+ img = self.transform[0](image)
+ else:
+ img = self.transform(image)
+ else:
+ img = image
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ if self.two_crop:
+ if isinstance(self.transform, tuple) and len(self.transform) == 2:
+ img2 = self.transform[1](image)
+ else:
+ img2 = self.transform(image)
+
+ if self.return_coord:
+ assert isinstance(img, tuple)
+ img, coord = img
+
+ if self.two_crop:
+ img2, coord2 = img2
+ return img, img2, coord, coord2, index, target
+ else:
+ return img, coord, index, target
+ else:
+ if isinstance(img, tuple):
+ img, coord = img
+
+ if self.two_crop:
+ if isinstance(img2, tuple):
+ img2, coord2 = img2
+ return img, img2, index, target
+ else:
+ return img, index, target
+
+
+class ImageFolderImageAsymBboxCutout(DatasetFolderProps):
+ def __init__(self, root, ann_file='', img_prefix='', train_props_file='',
+ image_size=0, select_strategy='', select_k=0, weight_strategy='',
+ jitter_ratio=0.0, padding_k='', aware_range=[], aware_start=0, aware_end=4,
+ max_tries=0,
+ transform=None, target_transform=None,
+ loader=default_img_loader, cache_mode="no", dataset='ImageNet'):
+ super(ImageFolderImageAsymBboxCutout, self).__init__(root, loader, IMG_EXTENSIONS,
+ ann_file=ann_file, img_prefix=img_prefix,
+ train_props_file=train_props_file,
+ select_strategy=select_strategy, select_k=select_k,
+ transform=transform, target_transform=target_transform,
+ cache_mode=cache_mode, dataset=dataset)
+ self.imgs = self.samples
+ self.props = self.samples_props
+ self.select_strategy = select_strategy
+ self.select_k = select_k
+ self.weight_strategy = weight_strategy
+ self.jitter_ratio = jitter_ratio
+ self.padding_k = padding_k
+ self.view_size = (image_size, image_size)
+ self.debug = False
+ self.max_tries = max_tries
+ self.least_common = max(self.padding_k // 2, 1)
+ self.aware_range = aware_range
+ assert len(self.aware_range) == 5, 'Must give P2 P3 P4 P5 P6 size range'
+ self.aware_start = aware_start # starting from 0 means use p2
+ self.aware_end = aware_end # end, if use P6 might be 5
+
+ def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+ Returns:
+ tuple: (image, target) where target is class_index of the target class.
+ """
+ path, target = self.samples[index]
+ image = self.loader(path)
+ image_size = image.size
+ image_proposals = self.props[index] # for cur image, numpy array type, [[x1, y1, x2, y2]] x2 = x1 + w - 1
+ if image_proposals.shape[0] == 0: # if no proposals, insert one single proposal, the whole raw image
+ image_proposals = np.array([[0, 0, image_size[0] - 1, image_size[1] - 1]])
+
+ image_proposals_w_id = append_prop_id(image_proposals) # start from 1
+
+ assert len(self.transform) == 6
+ # transform = (transform_whole_img, transform_img, transform_flip, transform_post_1, transform_post_2, transform_cutout)
+
+ tries = 0
+ least_common = self.least_common
+
+ while tries < self.max_tries:
+ img, params = self.transform[0](image) # whole image resize
+ img2, params2 = self.transform[1](image) # random crop resize
+
+ params_overlap = cal_overlap_params(params, params2)
+ overlap_props = get_overlap_props(image_proposals_w_id, params_overlap)
+            selected_image_props = select_props(overlap_props, self.select_strategy, self.select_k)  # select proposals according to select_strategy / select_k
+
+ # TODO: ensure clipped bboxs width and height are greater than 32
+ if selected_image_props.shape[0] >= least_common: # ok
+ break
+ least_common = max(least_common // 2, 1)
+ tries += 1
+
+ bboxs = clip_bboxs(selected_image_props, params[0], params[1], params[2], params[3])
+ bboxs2 = clip_bboxs(selected_image_props, params2[0], params2[1], params2[2], params2[3])
+ common_bboxs_ids = get_common_bboxs_ids(bboxs, bboxs2)
+
+
+ pad1 = self.padding_k - bboxs.shape[0]
+ if pad1 > 0:
+ # pad_bboxs = jitter_bboxs(bboxs, common_bboxs_ids, self.jitter_ratio, pad1, params[2], params[3])
+ pad_bboxs = pad_bboxs_with_common(bboxs, common_bboxs_ids, self.jitter_ratio, pad1, params[2], params[3])
+ bboxs = np.concatenate([bboxs, pad_bboxs], axis=0)
+ pad2 = self.padding_k - bboxs2.shape[0]
+ if pad2 > 0:
+ # pad_bboxs2 = jitter_bboxs(bboxs2, common_bboxs_ids, self.jitter_ratio, pad2, params2[2], params2[3])
+ pad_bboxs2 = pad_bboxs_with_common(bboxs2, common_bboxs_ids, self.jitter_ratio, pad2, params2[2], params2[3])
+ bboxs2 = np.concatenate([bboxs2, pad_bboxs2], axis=0)
+ correspondence = get_correspondence_matrix(bboxs, bboxs2)
+
+ resized_bboxs = resize_bboxs(bboxs, params[2], params[3], self.view_size)
+ resized_bboxs2 = resize_bboxs(bboxs2, params2[2], params2[3], self.view_size)
+ resized_bboxs = resized_bboxs.astype(int)
+ resized_bboxs2 = resized_bboxs2.astype(int)
+
+ bboxs = bboxs_to_tensor(bboxs, params) # x1y1x2y2 -> x1y1x2y2 (0, 1)
+ bboxs2 = bboxs_to_tensor(bboxs2, params2) # x1y1x2y2 -> x1y1x2y2 (0, 1)
+
+
+ img, bboxs, params = self.transform[2](img, bboxs, params) # flip
+ img2, bboxs2, params2 = self.transform[2](img2, bboxs2, params2)
+
+ img1_1x = self.transform[3](img) # color
+ img2_1x = self.transform[4](img2) # color
+
+ img2_1x_cut = self.transform[5](img2_1x, resized_bboxs2) # cutout
+
+ return img1_1x, img2_1x_cut, bboxs, bboxs2, correspondence, index, target
+
+
+class ImageFolderImageAsymBboxAwareMultiJitter1(DatasetFolderProps):
+ def __init__(self, root, ann_file='', img_prefix='', train_props_file='',
+ image_size=0, select_strategy='', select_k=0, weight_strategy='',
+ jitter_prob=0.0, jitter_ratio=0.0,
+ padding_k='', aware_range=[], aware_start=0, aware_end=4, max_tries=0,
+ transform=None, target_transform=None,
+ loader=default_img_loader, cache_mode="no", dataset='ImageNet'):
+ super(ImageFolderImageAsymBboxAwareMultiJitter1, self).__init__(root, loader, IMG_EXTENSIONS,
+ ann_file=ann_file, img_prefix=img_prefix,
+ train_props_file=train_props_file,
+ select_strategy=select_strategy, select_k=select_k,
+ transform=transform, target_transform=target_transform,
+ cache_mode=cache_mode, dataset=dataset)
+ self.imgs = self.samples
+ self.props = self.samples_props
+ self.select_strategy = select_strategy
+ self.select_k = select_k
+ self.weight_strategy = weight_strategy
+ self.jitter_prob = jitter_prob
+ self.jitter_ratio = jitter_ratio
+ self.padding_k = padding_k
+ self.view_size = (image_size, image_size)
+ self.view_size_3 = (image_size//2, image_size//2)
+ self.debug = False
+ self.max_tries = max_tries
+ self.least_common = max(self.padding_k // 2, 1)
+ self.aware_range = aware_range
+ assert len(self.aware_range) == 5, 'Must give P2 P3 P4 P5 P6 size range'
+ self.aware_start = aware_start # starting from 0 means use p2
+ self.aware_end = aware_end # end, if use P6 might be 5
+
+ def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+ Returns:
+ tuple: (image, target) where target is class_index of the target class.
+ """
+ path, target = self.samples[index]
+ image = self.loader(path)
+ image_size = image.size
+ image_proposals = self.props[index] # for cur image, numpy array type, [[x1, y1, x2, y2]] x2 = x1 + w - 1
+ if image_proposals.shape[0] == 0: # if no proposals, insert one single proposal, the whole raw image
+ image_proposals = np.array([[0, 0, image_size[0] - 1, image_size[1] - 1]])
+
+ image_proposals_w_id = append_prop_id(image_proposals) # start from 1
+
+ assert len(self.transform) == 7
+ # transform = (transform_whole_img, transform_img, transform_img_small, transform_flip_flip, transform_flip, transform_post_1, transform_post_2)
+
+ tries = 0
+ least_common = self.least_common
+
+ while tries < self.max_tries:
+ img, params = self.transform[0](image) # whole image resize
+ img2, params2 = self.transform[1](image) # random crop resize
+ img3, params3 = self.transform[2](image) # small random crop resize
+
+ params_overlap12 = cal_overlap_params(params, params2)
+ overlap_props12 = get_overlap_props(image_proposals_w_id, params_overlap12)
+            selected_image_props12 = select_props(overlap_props12, self.select_strategy, self.select_k)  # select proposals according to select_strategy / select_k
+
+ params_overlap13 = cal_overlap_params(params, params3)
+ overlap_props13 = get_overlap_props(image_proposals_w_id, params_overlap13)
+            selected_image_props13 = select_props(overlap_props13, self.select_strategy, self.select_k)  # select proposals according to select_strategy / select_k
+
+ # TODO: ensure clipped bboxs width and height are greater than 32
+ if selected_image_props12.shape[0] >= least_common and selected_image_props13.shape[0] >= least_common: # ok
+ break
+ least_common = max(least_common // 2, 1)
+ tries += 1
+
+
+ jittered_selected_image_props12 = jitter_props(selected_image_props12, self.jitter_prob, self.jitter_ratio)
+ jittered_selected_image_props13 = jitter_props(selected_image_props13, self.jitter_prob, self.jitter_ratio)
+
+ bboxs1_12 = clip_bboxs(jittered_selected_image_props12, params[0], params[1], params[2], params[3])
+ bboxs1_13 = clip_bboxs(jittered_selected_image_props13, params[0], params[1], params[2], params[3])
+ bboxs2 = clip_bboxs(selected_image_props12, params2[0], params2[1], params2[2], params2[3])
+ bboxs3 = clip_bboxs(selected_image_props13, params3[0], params3[1], params3[2], params3[3])
+ common_bboxs_ids12 = get_common_bboxs_ids(bboxs1_12, bboxs2)
+ common_bboxs_ids13 = get_common_bboxs_ids(bboxs1_13, bboxs3)
+
+
+ pad1_12 = self.padding_k - bboxs1_12.shape[0]
+ if pad1_12 > 0:
+ pad_bboxs1_12 = pad_bboxs_with_common(bboxs1_12, common_bboxs_ids12, self.jitter_ratio, pad1_12, params[2], params[3])
+ bboxs1_12 = np.concatenate([bboxs1_12, pad_bboxs1_12], axis=0)
+
+ pad1_13 = self.padding_k - bboxs1_13.shape[0]
+ if pad1_13 > 0:
+ pad_bboxs1_13 = pad_bboxs_with_common(bboxs1_13, common_bboxs_ids13, self.jitter_ratio, pad1_13, params[2], params[3])
+ bboxs1_13 = np.concatenate([bboxs1_13, pad_bboxs1_13], axis=0)
+
+ pad2 = self.padding_k - bboxs2.shape[0]
+ if pad2 > 0:
+ pad_bboxs2 = pad_bboxs_with_common(bboxs2, common_bboxs_ids12, self.jitter_ratio, pad2, params2[2], params2[3])
+ bboxs2 = np.concatenate([bboxs2, pad_bboxs2], axis=0)
+
+ pad3 = self.padding_k - bboxs3.shape[0]
+ if pad3 > 0:
+ pad_bboxs3 = pad_bboxs_with_common(bboxs3, common_bboxs_ids13, self.jitter_ratio, pad3, params3[2], params3[3])
+ bboxs3 = np.concatenate([bboxs3, pad_bboxs3], axis=0)
+
+
+ resized_bboxs1_12 = resize_bboxs(bboxs1_12, params[2], params[3], self.view_size)
+ resized_bboxs1_13 = resize_bboxs(bboxs1_13, params[2], params[3], self.view_size)
+ resized_bboxs2 = resize_bboxs(bboxs2, params2[2], params2[3], self.view_size)
+ resized_bboxs3 = resize_bboxs(bboxs3, params3[2], params3[3], self.view_size_3)
+ resized_bboxs1_12 = resized_bboxs1_12.astype(int)
+ resized_bboxs1_13 = resized_bboxs1_13.astype(int)
+ resized_bboxs2 = resized_bboxs2.astype(int)
+ resized_bboxs3 = resized_bboxs3.astype(int)
+
+ bboxs1_12_with_feature_assign = assign_bboxs_to_feature_map(resized_bboxs1_12, self.aware_range, self.aware_start, self.aware_end, -1)
+ bboxs1_13_with_feature_assign = assign_bboxs_to_feature_map(resized_bboxs1_13, self.aware_range, self.aware_start, self.aware_end, -1)
+ bboxs2_with_feature_assign = assign_bboxs_to_feature_map(resized_bboxs2, self.aware_range, self.aware_start, self.aware_end, -2)
+ bboxs3_with_feature_assign = assign_bboxs_to_feature_map(resized_bboxs3, self.aware_range, self.aware_start, self.aware_end, -3)
+
+
+ aware_corres_12 = get_aware_correspondence_matrix(bboxs1_12_with_feature_assign, bboxs2_with_feature_assign)
+ aware_corres_13 = get_aware_correspondence_matrix(bboxs1_13_with_feature_assign, bboxs3_with_feature_assign)
+
+ bboxs1_12 = bboxs_to_tensor(bboxs1_12, params) # x1y1x2y2 -> x1y1x2y2 (0, 1)
+ bboxs1_13 = bboxs_to_tensor(bboxs1_13, params) # x1y1x2y2 -> x1y1x2y2 (0, 1)
+ bboxs2 = bboxs_to_tensor(bboxs2, params2) # x1y1x2y2 -> x1y1x2y2 (0, 1)
+ bboxs3 = bboxs_to_tensor(bboxs3, params3) # x1y1x2y2 -> x1y1x2y2 (0, 1)
+
+ img, bboxs1_12, bboxs1_13, params = self.transform[3](img, bboxs1_12, bboxs1_13, params) # flip
+ img2, bboxs2, params2 = self.transform[4](img2, bboxs2, params2) # flip
+ img3, bboxs3, params3 = self.transform[4](img3, bboxs3, params3) # flip
+
+ img1 = self.transform[5](img) # color
+ img2 = self.transform[6](img2) # color
+ img3 = self.transform[6](img3) # color
+
+ return img1, img2, img3, bboxs1_12, bboxs1_13, bboxs2, bboxs3, aware_corres_12, aware_corres_13, index, target
+
+
+class ImageFolderImageAsymBboxAwareMultiJitter1Cutout(DatasetFolderProps):
+ def __init__(self, root, ann_file='', img_prefix='', train_props_file='',
+ image_size=0, select_strategy='', select_k=0, weight_strategy='',
+ jitter_prob=0.0, jitter_ratio=0.0,
+ padding_k='', aware_range=[], aware_start=0, aware_end=4, max_tries=0,
+ transform=None, target_transform=None,
+ loader=default_img_loader, cache_mode="no", dataset='ImageNet'):
+ super(ImageFolderImageAsymBboxAwareMultiJitter1Cutout, self).__init__(root, loader, IMG_EXTENSIONS,
+ ann_file=ann_file, img_prefix=img_prefix,
+ train_props_file=train_props_file,
+ select_strategy=select_strategy, select_k=select_k,
+ transform=transform, target_transform=target_transform,
+ cache_mode=cache_mode, dataset=dataset)
+ self.imgs = self.samples
+ self.props = self.samples_props
+ self.select_strategy = select_strategy
+ self.select_k = select_k
+ self.weight_strategy = weight_strategy
+ self.jitter_prob = jitter_prob
+ self.jitter_ratio = jitter_ratio
+ self.padding_k = padding_k
+ self.view_size = (image_size, image_size)
+ self.view_size_3 = (image_size//2, image_size//2)
+ self.debug = False
+ self.max_tries = max_tries
+ self.least_common = max(self.padding_k // 2, 1)
+ self.aware_range = aware_range
+ assert len(self.aware_range) == 5, 'Must give P2 P3 P4 P5 P6 size range'
+ self.aware_start = aware_start # starting from 0 means use p2
+ self.aware_end = aware_end # end, if use P6 might be 5
+
+ def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+ Returns:
+ tuple: (image, target) where target is class_index of the target class.
+ """
+ path, target = self.samples[index]
+ image = self.loader(path)
+ image_size = image.size
+ image_proposals = self.props[index] # for cur image, numpy array type, [[x1, y1, x2, y2]] x2 = x1 + w - 1
+ if image_proposals.shape[0] == 0: # if no proposals, insert one single proposal, the whole raw image
+ image_proposals = np.array([[0, 0, image_size[0] - 1, image_size[1] - 1]])
+
+ image_proposals_w_id = append_prop_id(image_proposals) # start from 1
+
+ assert len(self.transform) == 8
+ # transform = (transform_whole_img, transform_img, transform_img_small, transform_flip_flip, transform_flip, transform_post_1, transform_post_2, transform_cutout)
+
+ tries = 0
+ least_common = self.least_common
+
+ while tries < self.max_tries:
+ img, params = self.transform[0](image) # whole image resize
+ img2, params2 = self.transform[1](image) # random crop resize
+ img3, params3 = self.transform[2](image) # small random crop resize
+
+ params_overlap12 = cal_overlap_params(params, params2)
+ overlap_props12 = get_overlap_props(image_proposals_w_id, params_overlap12)
+            selected_image_props12 = select_props(overlap_props12, self.select_strategy, self.select_k)  # select proposals according to select_strategy / select_k
+
+ params_overlap13 = cal_overlap_params(params, params3)
+ overlap_props13 = get_overlap_props(image_proposals_w_id, params_overlap13)
+            selected_image_props13 = select_props(overlap_props13, self.select_strategy, self.select_k)  # select proposals according to select_strategy / select_k
+
+ # TODO: ensure clipped bboxs width and height are greater than 32
+ if selected_image_props12.shape[0] >= least_common and selected_image_props13.shape[0] >= least_common: # ok
+ break
+ least_common = max(least_common // 2, 1)
+ tries += 1
+
+
+ jittered_selected_image_props12 = jitter_props(selected_image_props12, self.jitter_prob, self.jitter_ratio)
+ jittered_selected_image_props13 = jitter_props(selected_image_props13, self.jitter_prob, self.jitter_ratio)
+
+ bboxs1_12 = clip_bboxs(jittered_selected_image_props12, params[0], params[1], params[2], params[3])
+ bboxs1_13 = clip_bboxs(jittered_selected_image_props13, params[0], params[1], params[2], params[3])
+ bboxs2 = clip_bboxs(selected_image_props12, params2[0], params2[1], params2[2], params2[3])
+ bboxs3 = clip_bboxs(selected_image_props13, params3[0], params3[1], params3[2], params3[3])
+ common_bboxs_ids12 = get_common_bboxs_ids(bboxs1_12, bboxs2)
+ common_bboxs_ids13 = get_common_bboxs_ids(bboxs1_13, bboxs3)
+
+
+ pad1_12 = self.padding_k - bboxs1_12.shape[0]
+ if pad1_12 > 0:
+ pad_bboxs1_12 = pad_bboxs_with_common(bboxs1_12, common_bboxs_ids12, self.jitter_ratio, pad1_12, params[2], params[3])
+ bboxs1_12 = np.concatenate([bboxs1_12, pad_bboxs1_12], axis=0)
+
+ pad1_13 = self.padding_k - bboxs1_13.shape[0]
+ if pad1_13 > 0:
+ pad_bboxs1_13 = pad_bboxs_with_common(bboxs1_13, common_bboxs_ids13, self.jitter_ratio, pad1_13, params[2], params[3])
+ bboxs1_13 = np.concatenate([bboxs1_13, pad_bboxs1_13], axis=0)
+
+ pad2 = self.padding_k - bboxs2.shape[0]
+ if pad2 > 0:
+ pad_bboxs2 = pad_bboxs_with_common(bboxs2, common_bboxs_ids12, self.jitter_ratio, pad2, params2[2], params2[3])
+ bboxs2 = np.concatenate([bboxs2, pad_bboxs2], axis=0)
+
+ pad3 = self.padding_k - bboxs3.shape[0]
+ if pad3 > 0:
+ pad_bboxs3 = pad_bboxs_with_common(bboxs3, common_bboxs_ids13, self.jitter_ratio, pad3, params3[2], params3[3])
+ bboxs3 = np.concatenate([bboxs3, pad_bboxs3], axis=0)
+
+
+ resized_bboxs1_12 = resize_bboxs(bboxs1_12, params[2], params[3], self.view_size)
+ resized_bboxs1_13 = resize_bboxs(bboxs1_13, params[2], params[3], self.view_size)
+ resized_bboxs2 = resize_bboxs(bboxs2, params2[2], params2[3], self.view_size)
+ resized_bboxs3 = resize_bboxs(bboxs3, params3[2], params3[3], self.view_size_3)
+ resized_bboxs1_12 = resized_bboxs1_12.astype(int)
+ resized_bboxs1_13 = resized_bboxs1_13.astype(int)
+ resized_bboxs2 = resized_bboxs2.astype(int)
+ resized_bboxs3 = resized_bboxs3.astype(int)
+
+ bboxs1_12_with_feature_assign = assign_bboxs_to_feature_map(resized_bboxs1_12, self.aware_range, self.aware_start, self.aware_end, -1)
+ bboxs1_13_with_feature_assign = assign_bboxs_to_feature_map(resized_bboxs1_13, self.aware_range, self.aware_start, self.aware_end, -1)
+ bboxs2_with_feature_assign = assign_bboxs_to_feature_map(resized_bboxs2, self.aware_range, self.aware_start, self.aware_end, -2)
+ bboxs3_with_feature_assign = assign_bboxs_to_feature_map(resized_bboxs3, self.aware_range, self.aware_start, self.aware_end, -3)
+
+
+ aware_corres_12 = get_aware_correspondence_matrix(bboxs1_12_with_feature_assign, bboxs2_with_feature_assign)
+ aware_corres_13 = get_aware_correspondence_matrix(bboxs1_13_with_feature_assign, bboxs3_with_feature_assign)
+
+ bboxs1_12 = bboxs_to_tensor(bboxs1_12, params) # x1y1x2y2 -> x1y1x2y2 (0, 1)
+ bboxs1_13 = bboxs_to_tensor(bboxs1_13, params) # x1y1x2y2 -> x1y1x2y2 (0, 1)
+ bboxs2 = bboxs_to_tensor(bboxs2, params2) # x1y1x2y2 -> x1y1x2y2 (0, 1)
+ bboxs3 = bboxs_to_tensor(bboxs3, params3) # x1y1x2y2 -> x1y1x2y2 (0, 1)
+
+ img, bboxs1_12, bboxs1_13, params = self.transform[3](img, bboxs1_12, bboxs1_13, params) # flip
+ img2, bboxs2, params2 = self.transform[4](img2, bboxs2, params2) # flip
+ img3, bboxs3, params3 = self.transform[4](img3, bboxs3, params3) # flip
+
+ img1 = self.transform[5](img) # color
+ img2 = self.transform[6](img2) # color
+ img3 = self.transform[6](img3) # color
+
+ img2_cutout = self.transform[7](img2, resized_bboxs2)
+ img3_cutout = self.transform[7](img3, resized_bboxs3)
+
+ return img1, img2_cutout, img3_cutout, bboxs1_12, bboxs1_13, bboxs2, bboxs3, aware_corres_12, aware_corres_13, index, target
+
+
+class ImageFolderImageAsymBboxAwareMulti3ResizeExtraJitter1(DatasetFolderProps):
+ def __init__(self, root, ann_file='', img_prefix='', train_props_file='',
+ image_size=0, image3_size=0, image4_size=0, select_strategy='', select_k=0, weight_strategy='',
+ jitter_prob=0.0, jitter_ratio=0.0,
+ padding_k='', aware_range=[], aware_start=0, aware_end=4, max_tries=0,
+ transform=None, target_transform=None,
+ loader=default_img_loader, cache_mode="no", dataset='ImageNet'):
+ super(ImageFolderImageAsymBboxAwareMulti3ResizeExtraJitter1, self).__init__(root, loader, IMG_EXTENSIONS,
+ ann_file=ann_file, img_prefix=img_prefix,
+ train_props_file=train_props_file,
+ select_strategy=select_strategy, select_k=select_k,
+ transform=transform, target_transform=target_transform,
+ cache_mode=cache_mode, dataset=dataset)
+ self.imgs = self.samples
+ self.props = self.samples_props
+ self.select_strategy = select_strategy
+ self.select_k = select_k
+ self.weight_strategy = weight_strategy
+ self.jitter_prob = jitter_prob
+ self.jitter_ratio = jitter_ratio
+ self.padding_k = padding_k
+ self.view_size = (image_size, image_size)
+ self.view_size_3 = (image3_size, image3_size)
+ self.view_size_4 = (image4_size, image4_size)
+ assert image3_size > 0
+ assert image4_size > 0
+ self.debug = False
+ self.max_tries = max_tries
+ self.least_common = max(self.padding_k // 2, 1)
+ self.aware_range = aware_range
+ assert len(self.aware_range) == 5, 'Must give P2 P3 P4 P5 P6 size range'
+ self.aware_start = aware_start # starting from 0 means use p2
+ self.aware_end = aware_end # end, if use P6 might be 5
+
+ def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+ Returns:
+ tuple: (image, target) where target is class_index of the target class.
+ """
+ path, target = self.samples[index]
+ image = self.loader(path)
+ image_size = image.size
+        image_proposals = self.props[index]  # proposals for the current image: numpy array of [[x1, y1, x2, y2]], with x2 = x1 + w - 1
+        if image_proposals.shape[0] == 0:  # if there are no proposals, fall back to a single proposal covering the whole image
+ image_proposals = np.array([[0, 0, image_size[0] - 1, image_size[1] - 1]])
+
+ image_proposals_w_id = append_prop_id(image_proposals) # start from 1
+
+ assert len(self.transform) == 8
+        # transform = (transform_whole_img, transform_img, transform_img_small, transform_img_resize, transform_flip_flip_flip, transform_flip, transform_post_1, transform_post_2)
+
+ tries = 0
+ least_common = self.least_common
+
+ while tries < self.max_tries:
+ img, params = self.transform[0](image) # whole image resize
+ img2, params2 = self.transform[1](image) # random crop resize
+ img3, params3 = self.transform[2](image) # small random crop resize
+
+ params_overlap12 = cal_overlap_params(params, params2)
+ overlap_props12 = get_overlap_props(image_proposals_w_id, params_overlap12)
+            selected_image_props12 = select_props(overlap_props12, self.select_strategy, self.select_k)  # keep at most select_k proposals from the view1/view2 overlap
+
+ params_overlap13 = cal_overlap_params(params, params3)
+ overlap_props13 = get_overlap_props(image_proposals_w_id, params_overlap13)
+            selected_image_props13 = select_props(overlap_props13, self.select_strategy, self.select_k)  # keep at most select_k proposals from the view1/view3 overlap
+
+
+ # TODO: ensure clipped bboxs width and height are greater than 32
+ if selected_image_props12.shape[0] >= least_common and selected_image_props13.shape[0] >= least_common: # ok
+ break
+ least_common = max(least_common // 2, 1)
+ tries += 1
+
+        img4 = self.transform[3](img2)  # image4 is resized from image2
+
+ jittered_selected_image_props12 = jitter_props(selected_image_props12, self.jitter_prob, self.jitter_ratio)
+ jittered_selected_image_props13 = jitter_props(selected_image_props13, self.jitter_prob, self.jitter_ratio)
+
+ bboxs1_12 = clip_bboxs(jittered_selected_image_props12, params[0], params[1], params[2], params[3])
+ bboxs1_13 = clip_bboxs(jittered_selected_image_props13, params[0], params[1], params[2], params[3])
+ bboxs2 = clip_bboxs(selected_image_props12, params2[0], params2[1], params2[2], params2[3])
+ bboxs3 = clip_bboxs(selected_image_props13, params3[0], params3[1], params3[2], params3[3])
+ common_bboxs_ids12 = get_common_bboxs_ids(bboxs1_12, bboxs2)
+ common_bboxs_ids13 = get_common_bboxs_ids(bboxs1_13, bboxs3)
+
+
+ pad1_12 = self.padding_k - bboxs1_12.shape[0]
+ if pad1_12 > 0:
+ pad_bboxs1_12 = pad_bboxs_with_common(bboxs1_12, common_bboxs_ids12, self.jitter_ratio, pad1_12, params[2], params[3])
+ bboxs1_12 = np.concatenate([bboxs1_12, pad_bboxs1_12], axis=0)
+
+ pad1_13 = self.padding_k - bboxs1_13.shape[0]
+ if pad1_13 > 0:
+ pad_bboxs1_13 = pad_bboxs_with_common(bboxs1_13, common_bboxs_ids13, self.jitter_ratio, pad1_13, params[2], params[3])
+ bboxs1_13 = np.concatenate([bboxs1_13, pad_bboxs1_13], axis=0)
+
+ pad2 = self.padding_k - bboxs2.shape[0]
+ if pad2 > 0:
+ pad_bboxs2 = pad_bboxs_with_common(bboxs2, common_bboxs_ids12, self.jitter_ratio, pad2, params2[2], params2[3])
+ bboxs2 = np.concatenate([bboxs2, pad_bboxs2], axis=0)
+
+ pad3 = self.padding_k - bboxs3.shape[0]
+ if pad3 > 0:
+ pad_bboxs3 = pad_bboxs_with_common(bboxs3, common_bboxs_ids13, self.jitter_ratio, pad3, params3[2], params3[3])
+ bboxs3 = np.concatenate([bboxs3, pad_bboxs3], axis=0)
+
+ bboxs1_14 = np.copy(bboxs1_12)
+
+ bboxs4 = np.copy(bboxs2)
+ params4 = np.copy(params2)
+
+ resized_bboxs1_12 = resize_bboxs(bboxs1_12, params[2], params[3], self.view_size)
+ resized_bboxs1_13 = resize_bboxs(bboxs1_13, params[2], params[3], self.view_size)
+ resized_bboxs1_14 = resize_bboxs(bboxs1_14, params[2], params[3], self.view_size)
+ resized_bboxs2 = resize_bboxs(bboxs2, params2[2], params2[3], self.view_size)
+ resized_bboxs3 = resize_bboxs(bboxs3, params3[2], params3[3], self.view_size_3)
+ resized_bboxs4 = resize_bboxs(bboxs4, params4[2], params4[3], self.view_size_4)
+ resized_bboxs1_12 = resized_bboxs1_12.astype(int)
+ resized_bboxs1_13 = resized_bboxs1_13.astype(int)
+ resized_bboxs1_14 = resized_bboxs1_14.astype(int)
+ resized_bboxs2 = resized_bboxs2.astype(int)
+ resized_bboxs3 = resized_bboxs3.astype(int)
+ resized_bboxs4 = resized_bboxs4.astype(int)
+
+ bboxs1_12_with_feature_assign = assign_bboxs_to_feature_map(resized_bboxs1_12, self.aware_range, self.aware_start, self.aware_end, -1)
+ bboxs1_13_with_feature_assign = assign_bboxs_to_feature_map(resized_bboxs1_13, self.aware_range, self.aware_start, self.aware_end, -1)
+ bboxs1_14_with_feature_assign = assign_bboxs_to_feature_map(resized_bboxs1_14, self.aware_range, self.aware_start, self.aware_end, -1)
+ bboxs2_with_feature_assign = assign_bboxs_to_feature_map(resized_bboxs2, self.aware_range, self.aware_start, self.aware_end, -2)
+ bboxs3_with_feature_assign = assign_bboxs_to_feature_map(resized_bboxs3, self.aware_range, self.aware_start, self.aware_end, -3)
+ bboxs4_with_feature_assign = assign_bboxs_to_feature_map(resized_bboxs4, self.aware_range, self.aware_start, self.aware_end, -4)
+
+
+ aware_corres_12 = get_aware_correspondence_matrix(bboxs1_12_with_feature_assign, bboxs2_with_feature_assign)
+ aware_corres_13 = get_aware_correspondence_matrix(bboxs1_13_with_feature_assign, bboxs3_with_feature_assign)
+ aware_corres_14 = get_aware_correspondence_matrix(bboxs1_14_with_feature_assign, bboxs4_with_feature_assign)
+
+ bboxs1_12 = bboxs_to_tensor(bboxs1_12, params) # x1y1x2y2 -> x1y1x2y2 (0, 1)
+ bboxs1_13 = bboxs_to_tensor(bboxs1_13, params) # x1y1x2y2 -> x1y1x2y2 (0, 1)
+ bboxs1_14 = bboxs_to_tensor(bboxs1_14, params) # x1y1x2y2 -> x1y1x2y2 (0, 1)
+ bboxs2 = bboxs_to_tensor(bboxs2, params2) # x1y1x2y2 -> x1y1x2y2 (0, 1)
+ bboxs3 = bboxs_to_tensor(bboxs3, params3) # x1y1x2y2 -> x1y1x2y2 (0, 1)
+ bboxs4 = bboxs_to_tensor(bboxs4, params4) # x1y1x2y2 -> x1y1x2y2 (0, 1)
+
+ img, bboxs1_12, bboxs1_13, bboxs1_14, params = self.transform[4](img, bboxs1_12, bboxs1_13, bboxs1_14, params) # flip
+ img2, bboxs2, params2 = self.transform[5](img2, bboxs2, params2) # flip
+ img3, bboxs3, params3 = self.transform[5](img3, bboxs3, params3) # flip
+ img4, bboxs4, params4 = self.transform[5](img4, bboxs4, params4) # flip
+
+ img1 = self.transform[6](img) # color
+ img2 = self.transform[7](img2) # color
+ img3 = self.transform[7](img3) # color
+ img4 = self.transform[7](img4) # color
+
+ return img1, img2, img3, img4, bboxs1_12, bboxs1_13, bboxs1_14, bboxs2, bboxs3, bboxs4, aware_corres_12, aware_corres_13, aware_corres_14, index, target
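+
+
+# A minimal usage sketch (not part of the original training scripts; the file names, aug_type
+# pairing and hyper-parameter values below are illustrative assumptions): these dataset classes
+# are meant to receive the matching transform tuple from `contrast.data.transform.get_transform`
+# and are then wrapped in an ordinary PyTorch DataLoader, e.g.
+#
+#   from torch.utils.data import DataLoader
+#   from contrast.data.transform import get_transform
+#
+#   transform = get_transform(None, 'ImageAsymBboxAwareMultiJitter1Cutout', crop=0.2, image_size=224)
+#   dataset = ImageFolderImageAsymBboxAwareMultiJitter1Cutout(
+#       root='data/imagenet', ann_file='train.txt', img_prefix='train',
+#       train_props_file='train_props.json', image_size=224,
+#       select_strategy='random', select_k=4, jitter_prob=0.5, jitter_ratio=0.1,
+#       padding_k=4, aware_range=[48, 96, 192, 384, 1e6], aware_start=0, aware_end=4,
+#       max_tries=5, transform=transform)
+#   loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)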
diff --git a/contrast/data/props_utils.py b/contrast/data/props_utils.py
new file mode 100644
index 0000000..6b58044
--- /dev/null
+++ b/contrast/data/props_utils.py
@@ -0,0 +1,43 @@
+# --------------------------------------------------------
+# SoCo
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Yue Gao
+# --------------------------------------------------------
+
+
+import numpy as np
+
+
+def select_props(all_props, select_strategy, select_k):
+ # all_props: numpy array
+ if select_k > all_props.shape[0]: # if we do not have k proposals, we just select all
+ select_strategy = 'none'
+ selected_proposals = None
+ if select_strategy == 'none':
+ selected_proposals = all_props
+ elif select_strategy == 'random':
+ selected_idx = np.random.choice(all_props.shape[0], size=select_k, replace=False)
+ selected_proposals = all_props[selected_idx]
+ elif select_strategy == 'top':
+ selected_proposals = all_props[:select_k]
+ elif select_strategy == 'area':
+ areas = np.zeros((all_props.shape[0], ))
+ for i in range(all_props.shape[0]):
+ area = all_props[i][2] * all_props[i][3]
+ areas[i] = area
+ areas_index = areas.argsort()
+ selected_proposals = all_props[areas_index[::-1]][:select_k]
+ else:
+ raise NotImplementedError
+ return selected_proposals
+
+
+def convert_props(all_props):
+ all_props_np = np.array(all_props)
+ if all_props_np.shape[0] == 0:
+ return all_props_np
+ all_props_np = all_props_np[:, :4]
+ all_props_np[:, 2] = all_props_np[:, 0] + all_props_np[:, 2] - 1 # x2 = x1 + w - 1
+ all_props_np[:, 3] = all_props_np[:, 1] + all_props_np[:, 3] - 1 # y2 = y1 + h - 1
+ return all_props_np
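+
+
+# A minimal usage sketch (the raw proposal values are made up): `convert_props` turns
+# selective-search style [x, y, w, h] rows into [x1, y1, x2, y2] with x2 = x1 + w - 1,
+# and `select_props` then keeps at most `select_k` of them under the chosen strategy.
+#
+#   raw_props = [[10, 20, 30, 40], [0, 0, 50, 60], [5, 5, 100, 80]]
+#   props_xyxy = convert_props(raw_props)   # -> [[10, 20, 39, 59], [0, 0, 49, 59], [5, 5, 104, 84]]
+#   top2 = select_props(props_xyxy, select_strategy='top', select_k=2)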
diff --git a/contrast/data/rand_augment.py b/contrast/data/rand_augment.py
new file mode 100644
index 0000000..104809e
--- /dev/null
+++ b/contrast/data/rand_augment.py
@@ -0,0 +1,448 @@
+""" AutoAugment and RandAugment
+Implementation adapted from:
+ https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
+Papers: https://arxiv.org/abs/1805.09501, https://arxiv.org/abs/1906.11172, and https://arxiv.org/abs/1909.13719
+Hacked together by Ross Wightman
+"""
+import random
+import math
+import re
+from PIL import Image, ImageOps, ImageEnhance
+import PIL
+import numpy as np
+
+
+_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
+
+_FILL = (128, 128, 128)
+
+# This signifies the max integer that the controller RNN could predict for the
+# augmentation scheme.
+_MAX_LEVEL = 10.
+
+_HPARAMS_DEFAULT = dict(
+ translate_const=250,
+ img_mean=_FILL,
+)
+
+_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
+
+
+def _interpolation(kwargs):
+ interpolation = kwargs.pop('resample', Image.BILINEAR)
+ if isinstance(interpolation, (list, tuple)):
+ return random.choice(interpolation)
+ else:
+ return interpolation
+
+
+def _check_args_tf(kwargs):
+ if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
+ kwargs.pop('fillcolor')
+ kwargs['resample'] = _interpolation(kwargs)
+
+
+def shear_x(img, factor, **kwargs):
+ _check_args_tf(kwargs)
+ return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)
+
+
+def shear_y(img, factor, **kwargs):
+ _check_args_tf(kwargs)
+ return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)
+
+
+def translate_x_rel(img, pct, **kwargs):
+ pixels = pct * img.size[0]
+ _check_args_tf(kwargs)
+ return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
+
+
+def translate_y_rel(img, pct, **kwargs):
+ pixels = pct * img.size[1]
+ _check_args_tf(kwargs)
+ return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
+
+
+def translate_x_abs(img, pixels, **kwargs):
+ _check_args_tf(kwargs)
+ return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
+
+
+def translate_y_abs(img, pixels, **kwargs):
+ _check_args_tf(kwargs)
+ return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
+
+
+def rotate(img, degrees, **kwargs):
+ _check_args_tf(kwargs)
+ if _PIL_VER >= (5, 2):
+ return img.rotate(degrees, **kwargs)
+ elif _PIL_VER >= (5, 0):
+ w, h = img.size
+ post_trans = (0, 0)
+ rotn_center = (w / 2.0, h / 2.0)
+ angle = -math.radians(degrees)
+ matrix = [
+ round(math.cos(angle), 15),
+ round(math.sin(angle), 15),
+ 0.0,
+ round(-math.sin(angle), 15),
+ round(math.cos(angle), 15),
+ 0.0,
+ ]
+
+ def transform(x, y, matrix):
+ (a, b, c, d, e, f) = matrix
+ return a * x + b * y + c, d * x + e * y + f
+
+ matrix[2], matrix[5] = transform(
+ -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
+ )
+ matrix[2] += rotn_center[0]
+ matrix[5] += rotn_center[1]
+ return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
+ else:
+ return img.rotate(degrees, resample=kwargs['resample'])
+
+
+def auto_contrast(img, **__):
+ return ImageOps.autocontrast(img)
+
+
+def invert(img, **__):
+ return ImageOps.invert(img)
+
+
+def identity(img, **__):
+ return img
+
+
+def equalize(img, **__):
+ return ImageOps.equalize(img)
+
+
+def solarize(img, thresh, **__):
+ return ImageOps.solarize(img, thresh)
+
+
+def solarize_add(img, add, thresh=128, **__):
+ lut = []
+ for i in range(256):
+ if i < thresh:
+ lut.append(min(255, i + add))
+ else:
+ lut.append(i)
+ if img.mode in ("L", "RGB"):
+ if img.mode == "RGB" and len(lut) == 256:
+ lut = lut + lut + lut
+ return img.point(lut)
+ else:
+ return img
+
+
+def posterize(img, bits_to_keep, **__):
+ if bits_to_keep >= 8:
+ return img
+ return ImageOps.posterize(img, bits_to_keep)
+
+
+def contrast(img, factor, **__):
+ return ImageEnhance.Contrast(img).enhance(factor)
+
+
+def color(img, factor, **__):
+ return ImageEnhance.Color(img).enhance(factor)
+
+
+def brightness(img, factor, **__):
+ return ImageEnhance.Brightness(img).enhance(factor)
+
+
+def sharpness(img, factor, **__):
+ return ImageEnhance.Sharpness(img).enhance(factor)
+
+
+def _randomly_negate(v):
+ """With 50% prob, negate the value"""
+ return -v if random.random() > 0.5 else v
+
+
+def _rotate_level_to_arg(level, _hparams):
+ # range [-30, 30]
+ level = (level / _MAX_LEVEL) * 30.
+ level = _randomly_negate(level)
+ return level,
+
+
+def _enhance_level_to_arg(level, _hparams):
+ # range [0.1, 1.9]
+ return (level / _MAX_LEVEL) * 1.8 + 0.1,
+
+
+def _shear_level_to_arg(level, _hparams):
+ # range [-0.3, 0.3]
+ level = (level / _MAX_LEVEL) * 0.3
+ level = _randomly_negate(level)
+ return level,
+
+
+def _translate_abs_level_to_arg(level, hparams):
+ translate_const = hparams['translate_const']
+ level = (level / _MAX_LEVEL) * float(translate_const)
+ level = _randomly_negate(level)
+ return level,
+
+
+def _translate_rel_level_to_arg(level, _hparams):
+ # range [-0.45, 0.45]
+ level = (level / _MAX_LEVEL) * 0.45
+ level = _randomly_negate(level)
+ return level,
+
+
+def _posterize_original_level_to_arg(level, _hparams):
+ # As per original AutoAugment paper description
+ # range [4, 8], 'keep 4 up to 8 MSB of image'
+ return int((level / _MAX_LEVEL) * 4) + 4,
+
+
+def _posterize_research_level_to_arg(level, _hparams):
+ # As per Tensorflow models research and UDA impl
+ # range [4, 0], 'keep 4 down to 0 MSB of original image'
+ return 4 - int((level / _MAX_LEVEL) * 4),
+
+
+def _posterize_tpu_level_to_arg(level, _hparams):
+ # As per Tensorflow TPU EfficientNet impl
+ # range [0, 4], 'keep 0 up to 4 MSB of original image'
+ return int((level / _MAX_LEVEL) * 4),
+
+
+def _solarize_level_to_arg(level, _hparams):
+ # range [0, 256]
+ return int((level / _MAX_LEVEL) * 256),
+
+
+def _solarize_add_level_to_arg(level, _hparams):
+ # range [0, 110]
+ return int((level / _MAX_LEVEL) * 110),
+
+
+LEVEL_TO_ARG = {
+ 'AutoContrast': None,
+ 'Equalize': None,
+ 'Invert': None,
+ 'Identity': None,
+ 'Rotate': _rotate_level_to_arg,
+ # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
+ 'PosterizeOriginal': _posterize_original_level_to_arg,
+ 'PosterizeResearch': _posterize_research_level_to_arg,
+ 'PosterizeTpu': _posterize_tpu_level_to_arg,
+ 'Solarize': _solarize_level_to_arg,
+ 'SolarizeAdd': _solarize_add_level_to_arg,
+ 'Color': _enhance_level_to_arg,
+ 'Contrast': _enhance_level_to_arg,
+ 'Brightness': _enhance_level_to_arg,
+ 'Sharpness': _enhance_level_to_arg,
+ 'ShearX': _shear_level_to_arg,
+ 'ShearY': _shear_level_to_arg,
+ 'TranslateX': _translate_abs_level_to_arg,
+ 'TranslateY': _translate_abs_level_to_arg,
+ 'TranslateXRel': _translate_rel_level_to_arg,
+ 'TranslateYRel': _translate_rel_level_to_arg,
+}
+
+
+NAME_TO_OP = {
+ 'AutoContrast': auto_contrast,
+ 'Equalize': equalize,
+ 'Invert': invert,
+ 'Identity': identity,
+ 'Rotate': rotate,
+ 'PosterizeOriginal': posterize,
+ 'PosterizeResearch': posterize,
+ 'PosterizeTpu': posterize,
+ 'Solarize': solarize,
+ 'SolarizeAdd': solarize_add,
+ 'Color': color,
+ 'Contrast': contrast,
+ 'Brightness': brightness,
+ 'Sharpness': sharpness,
+ 'ShearX': shear_x,
+ 'ShearY': shear_y,
+ 'TranslateX': translate_x_abs,
+ 'TranslateY': translate_y_abs,
+ 'TranslateXRel': translate_x_rel,
+ 'TranslateYRel': translate_y_rel,
+}
+
+
+class AutoAugmentOp:
+
+ def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
+ hparams = hparams or _HPARAMS_DEFAULT
+ self.aug_fn = NAME_TO_OP[name]
+ self.level_fn = LEVEL_TO_ARG[name]
+ self.prob = prob
+ self.magnitude = magnitude
+ self.hparams = hparams.copy()
+ self.kwargs = dict(
+ fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
+ resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
+ )
+
+ # If magnitude_std is > 0, we introduce some randomness
+ # in the usually fixed policy and sample magnitude from a normal distribution
+ # with mean `magnitude` and std-dev of `magnitude_std`.
+ # NOTE This is my own hack, being tested, not in papers or reference impls.
+ self.magnitude_std = self.hparams.get('magnitude_std', 0)
+
+ def __call__(self, img):
+ if random.random() > self.prob:
+ return img
+ magnitude = self.magnitude
+ if self.magnitude_std and self.magnitude_std > 0:
+ magnitude = random.gauss(magnitude, self.magnitude_std)
+ magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
+ level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
+ return self.aug_fn(img, *level_args, **self.kwargs)
+
+
+_RAND_TRANSFORMS = [
+ 'AutoContrast',
+ 'Equalize',
+ 'Invert',
+ 'Rotate',
+ 'PosterizeTpu',
+ 'Solarize',
+ 'SolarizeAdd',
+ 'Color',
+ 'Contrast',
+ 'Brightness',
+ 'Sharpness',
+ 'ShearX',
+ 'ShearY',
+ 'TranslateXRel',
+ 'TranslateYRel',
+ # 'Cutout' # FIXME I implement this as random erasing separately
+]
+
+_RAND_TRANSFORMS_CMC = [
+ 'AutoContrast',
+ 'Identity',
+ 'Rotate',
+ 'Sharpness',
+ 'ShearX',
+ 'ShearY',
+ 'TranslateXRel',
+ 'TranslateYRel',
+ # 'Cutout' # FIXME I implement this as random erasing separately
+]
+
+
+# These experimental weights are based loosely on the relative improvements mentioned in the paper.
+# They may not result in increased performance, but could likely be tuned to do so.
+_RAND_CHOICE_WEIGHTS_0 = {
+ 'Rotate': 0.3,
+ 'ShearX': 0.2,
+ 'ShearY': 0.2,
+ 'TranslateXRel': 0.1,
+ 'TranslateYRel': 0.1,
+ 'Color': .025,
+ 'Sharpness': 0.025,
+ 'AutoContrast': 0.025,
+ 'Solarize': .005,
+ 'SolarizeAdd': .005,
+ 'Contrast': .005,
+ 'Brightness': .005,
+ 'Equalize': .005,
+ 'PosterizeTpu': 0,
+ 'Invert': 0,
+}
+
+
+def _select_rand_weights(weight_idx=0, transforms=None):
+ transforms = transforms or _RAND_TRANSFORMS
+ assert weight_idx == 0 # only one set of weights currently
+ rand_weights = _RAND_CHOICE_WEIGHTS_0
+ probs = [rand_weights[k] for k in transforms]
+ probs /= np.sum(probs)
+ return probs
+
+
+def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
+ """rand augment ops for RGB images"""
+ hparams = hparams or _HPARAMS_DEFAULT
+ transforms = transforms or _RAND_TRANSFORMS
+ return [AutoAugmentOp(
+ name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
+
+
+def rand_augment_ops_cmc(magnitude=10, hparams=None, transforms=None):
+ """rand augment ops for CMC images (removing color ops)"""
+ hparams = hparams or _HPARAMS_DEFAULT
+ transforms = transforms or _RAND_TRANSFORMS_CMC
+ return [AutoAugmentOp(
+ name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
+
+
+class RandAugment:
+ def __init__(self, ops, num_layers=2, choice_weights=None):
+ self.ops = ops
+ self.num_layers = num_layers
+ self.choice_weights = choice_weights
+
+ def __call__(self, img):
+ # no replacement when using weighted choice
+ ops = np.random.choice(
+ self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights)
+ for op in ops:
+ img = op(img)
+ return img
+
+
+def rand_augment_transform(config_str, hparams, use_cmc=False):
+ """
+ Create a RandAugment transform
+ :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
+ dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
+    sections, which are not order specific, determine
+ 'm' - integer magnitude of rand augment
+ 'n' - integer num layers (number of transform ops selected per image)
+        'w' - integer probability weight index (index of a set of weights to influence choice of op)
+ 'mstd' - float std deviation of magnitude noise applied
+ Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
+ 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
+ :param hparams: Other hparams (kwargs) for the RandAugmentation scheme
+    :param use_cmc: If True, color-related augmentation ops are removed (CMC-style transform list).
+ :return: A PyTorch compatible Transform
+ """
+ magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10)
+ num_layers = 2 # default to 2 ops per image
+ weight_idx = None # default to no probability weights for op choice
+ config = config_str.split('-')
+ assert config[0] == 'rand'
+ config = config[1:]
+ for c in config:
+ cs = re.split(r'(\d.*)', c)
+ if len(cs) < 2:
+ continue
+ key, val = cs[:2]
+ if key == 'mstd':
+ # noise param injected via hparams for now
+ hparams.setdefault('magnitude_std', float(val))
+ elif key == 'm':
+ magnitude = int(val)
+ elif key == 'n':
+ num_layers = int(val)
+ elif key == 'w':
+ weight_idx = int(val)
+ else:
+ assert False, 'Unknown RandAugment config section'
+ if use_cmc:
+ ra_ops = rand_augment_ops_cmc(magnitude=magnitude, hparams=hparams)
+ else:
+ ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams)
+ choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
+ return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
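+
+
+# A minimal usage sketch (the config values are arbitrary examples): build a RandAugment
+# transform from a config string and apply it to a PIL image.
+#
+#   from PIL import Image
+#   ra = rand_augment_transform('rand-m9-n2-mstd0.5', hparams={})
+#   augmented = ra(Image.new('RGB', (224, 224)))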
diff --git a/contrast/data/sampler.py b/contrast/data/sampler.py
new file mode 100644
index 0000000..5290225
--- /dev/null
+++ b/contrast/data/sampler.py
@@ -0,0 +1,53 @@
+import numpy as np
+from torch.utils.data import Sampler
+
+
+class SubsetSlidingWindowSampler(Sampler):
+ r"""Samples elements randomly from a given list of indices, without replacement.
+
+ Arguments:
+ indices (sequence): a sequence of indices
+ """
+
+ def __init__(self, indices, window_stride, window_size, shuffle_per_epoch=False):
+ self.window_stride = window_stride
+ self.window_size = window_size
+ self.shuffle_per_epoch = shuffle_per_epoch
+ self.indices = indices
+ np.random.shuffle(self.indices)
+ self.start_index = 0
+
+ def __iter__(self):
+ # optionally shuffle all indices per epoch
+ if self.shuffle_per_epoch and self.start_index + self.window_size > len(self):
+ np.random.shuffle(self.indices)
+
+ # get indices of sampler in the current window
+        indices = np.mod(np.arange(self.window_size, dtype=np.int64) + self.start_index, len(self))
+ window_indices = self.indices[indices]
+
+ # shuffle the current window
+ np.random.shuffle(window_indices)
+
+ # move start index to next window
+ self.start_index = (self.start_index + self.window_stride) % len(self)
+
+ return iter(window_indices.tolist())
+
+ def __len__(self):
+ return len(self.indices)
+
+ def state_dict(self):
+ """Returns the state of the scheduler as a :class:`dict`.
+ It contains an entry for every variable in self.__dict__ which
+ is not the optimizer.
+ """
+ return {"start_index": self.start_index}
+
+ def load_state_dict(self, state_dict):
+ """Loads the schedulers state.
+ Arguments:
+ state_dict (dict): scheduler state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+ self.__dict__.update(state_dict)
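+
+
+# A minimal usage sketch (window sizes are arbitrary examples): pass the sampler to a
+# DataLoader so that each epoch draws from a window of the shuffled index list that slides
+# forward by `window_stride`; `state_dict` / `load_state_dict` let the window position be
+# checkpointed together with the rest of the training state.
+#
+#   import numpy as np
+#   from torch.utils.data import DataLoader
+#   sampler = SubsetSlidingWindowSampler(np.arange(100000), window_stride=5000, window_size=20000)
+#   # loader = DataLoader(dataset, batch_size=64, sampler=sampler, num_workers=8)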
diff --git a/contrast/data/selective_search_utils.py b/contrast/data/selective_search_utils.py
new file mode 100644
index 0000000..898f7c1
--- /dev/null
+++ b/contrast/data/selective_search_utils.py
@@ -0,0 +1,17 @@
+# --------------------------------------------------------
+# SoCo
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Yue Gao
+# --------------------------------------------------------
+
+
+import numpy as np
+
+
+def append_prop_id(image_proposals):
+    prop_ids = np.arange(1, image_proposals.shape[0]+1)  # proposal ids start from 1 (see clip_bboxs for how the ids are used)
+ prop_ids = np.reshape(prop_ids, (image_proposals.shape[0], 1))
+ image_proposals_with_prop_id = np.concatenate((image_proposals, prop_ids), axis=1)
+
+ return image_proposals_with_prop_id
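+
+
+# Example: appending proposal ids (starting from 1) as an extra column.
+#
+#   props = np.array([[0, 0, 9, 9], [5, 5, 20, 20]])
+#   append_prop_id(props)
+#   # -> array([[ 0,  0,  9,  9,  1],
+#   #           [ 5,  5, 20, 20,  2]])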
diff --git a/contrast/data/transform.py b/contrast/data/transform.py
new file mode 100644
index 0000000..4f28477
--- /dev/null
+++ b/contrast/data/transform.py
@@ -0,0 +1,146 @@
+# --------------------------------------------------------
+# SoCo
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Yue Gao
+# --------------------------------------------------------
+
+
+import numpy as np
+from PIL import ImageFilter, ImageOps
+from torchvision import transforms
+
+from . import transform_ops
+
+
+class GaussianBlur(object):
+ """Gaussian Blur version 2"""
+
+ def __call__(self, x):
+ sigma = np.random.uniform(0.1, 2.0)
+ x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
+ return x
+
+
+def get_transform(args, aug_type, crop, image_size=224, crop1=0.9, cutout_prob=0.5, cutout_ratio=(0.1, 0.2),
+ image3_size=224, image4_size=224):
+ normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+ if aug_type == 'ImageAsymBboxCutout':
+ transform_whole_img = transform_ops.WholeImageResizedParams(image_size)
+ transform_img = transform_ops.RandomResizedCropParams(image_size, scale=(crop, 1.))
+ transform_flip = transform_ops.RandomHorizontalFlipImageBbox()
+
+ transform_post_1 = transform_ops.ComposeImage([
+ transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
+ transforms.RandomGrayscale(p=0.2),
+ transforms.RandomApply([GaussianBlur()], p=1.0),
+ transforms.ToTensor(),
+ normalize,
+ ])
+ transform_post_2 = transform_ops.ComposeImage([
+ transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
+ transforms.RandomGrayscale(p=0.2),
+ transforms.RandomApply([GaussianBlur()], p=0.1),
+ transforms.RandomApply([ImageOps.solarize], p=0.2),
+ transforms.ToTensor(),
+ normalize,
+ ])
+ transform_cutout = transform_ops.RandomCutoutInBbox(image_size, cutout_prob=cutout_prob, cutout_ratio=cutout_ratio)
+ transform = (transform_whole_img, transform_img, transform_flip, transform_post_1, transform_post_2, transform_cutout)
+
+
+ elif aug_type == 'ImageAsymBboxAwareMultiJitter1':
+ transform_whole_img = transform_ops.WholeImageResizedParams(image_size)
+ transform_img = transform_ops.RandomResizedCropParams(image_size, scale=(crop, 1.))
+ transform_img_small = transform_ops.RandomResizedCropParams(image_size//2, scale=(crop, 1.))
+ transform_flip_flip = transform_ops.RandomHorizontalFlipImageBboxBbox()
+ transform_flip = transform_ops.RandomHorizontalFlipImageBbox()
+ transform_post_1 = transform_ops.ComposeImage([
+ transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
+ transforms.RandomGrayscale(p=0.2),
+ transforms.RandomApply([GaussianBlur()], p=1.0),
+ transforms.ToTensor(),
+ normalize,
+ ])
+ transform_post_2 = transform_ops.ComposeImage([
+ transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
+ transforms.RandomGrayscale(p=0.2),
+ transforms.RandomApply([GaussianBlur()], p=0.1),
+ transforms.RandomApply([ImageOps.solarize], p=0.2),
+ transforms.ToTensor(),
+ normalize,
+ ])
+ transform = (transform_whole_img, transform_img, transform_img_small, transform_flip_flip, transform_flip, transform_post_1, transform_post_2)
+
+
+ elif aug_type == 'ImageAsymBboxAwareMultiJitter1Cutout':
+ transform_whole_img = transform_ops.WholeImageResizedParams(image_size)
+ transform_img = transform_ops.RandomResizedCropParams(image_size, scale=(crop, 1.))
+ transform_img_small = transform_ops.RandomResizedCropParams(image_size//2, scale=(crop, 1.))
+ transform_flip_flip = transform_ops.RandomHorizontalFlipImageBboxBbox()
+ transform_flip = transform_ops.RandomHorizontalFlipImageBbox()
+ transform_post_1 = transform_ops.ComposeImage([
+ transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
+ transforms.RandomGrayscale(p=0.2),
+ transforms.RandomApply([GaussianBlur()], p=1.0),
+ transforms.ToTensor(),
+ normalize,
+ ])
+ transform_post_2 = transform_ops.ComposeImage([
+ transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
+ transforms.RandomGrayscale(p=0.2),
+ transforms.RandomApply([GaussianBlur()], p=0.1),
+ transforms.RandomApply([ImageOps.solarize], p=0.2),
+ transforms.ToTensor(),
+ normalize,
+ ])
+ transform_cutout = transform_ops.RandomCutoutInBbox(image_size, cutout_prob=cutout_prob, cutout_ratio=cutout_ratio)
+ transform = (transform_whole_img, transform_img, transform_img_small, transform_flip_flip, transform_flip, transform_post_1, transform_post_2, transform_cutout)
+
+
+ elif aug_type == 'ImageAsymBboxAwareMulti3ResizeExtraJitter1':
+ transform_whole_img = transform_ops.WholeImageResizedParams(image_size)
+ transform_img = transform_ops.RandomResizedCropParams(image_size, scale=(crop, 1.))
+ transform_img_small = transform_ops.RandomResizedCropParams(image3_size, scale=(crop, 1.))
+ transform_img_resize = transforms.Resize(image4_size)
+ transform_flip_flip_flip = transform_ops.RandomHorizontalFlipImageBboxBboxBbox()
+ transform_flip = transform_ops.RandomHorizontalFlipImageBbox()
+ transform_post_1 = transform_ops.ComposeImage([
+ transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
+ transforms.RandomGrayscale(p=0.2),
+ transforms.RandomApply([GaussianBlur()], p=1.0),
+ transforms.ToTensor(),
+ normalize,
+ ])
+ transform_post_2 = transform_ops.ComposeImage([
+ transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
+ transforms.RandomGrayscale(p=0.2),
+ transforms.RandomApply([GaussianBlur()], p=0.1),
+ transforms.RandomApply([ImageOps.solarize], p=0.2),
+ transforms.ToTensor(),
+ normalize,
+ ])
+ transform = (transform_whole_img, transform_img, transform_img_small, transform_img_resize, transform_flip_flip_flip, transform_flip, transform_post_1, transform_post_2)
+
+
+ elif aug_type == 'NULL': # used in linear evaluation
+ transform = transform_ops.Compose([
+ transform_ops.RandomResizedCropCoord(image_size, scale=(crop, 1.)),
+ transform_ops.RandomHorizontalFlipCoord(),
+ transforms.ToTensor(),
+ normalize,
+ ])
+
+ elif aug_type == 'val': # used in validate
+ transform = transforms.Compose([
+ transforms.Resize(image_size + 32),
+ transforms.CenterCrop(image_size),
+ transforms.ToTensor(),
+ normalize
+ ])
+ else:
+ supported = '[ImageAsymBboxCutout, ImageAsymBboxAwareMultiJitter1, ImageAsymBboxAwareMultiJitter1Cutout, ImageAsymBboxAwareMulti3ResizeExtraJitter1, NULL]'
+        raise NotImplementedError(f'aug_type "{aug_type}" not supported. Should be in {supported}')
+
+ return transform
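+
+
+# A minimal usage sketch: `args` is not referenced inside this function, so a single
+# augmentation pipeline can be built directly for inspection (the crop / image_size values
+# below are arbitrary examples).
+#
+#   transform = get_transform(None, 'ImageAsymBboxAwareMultiJitter1', crop=0.2, image_size=224)
+#   # -> (transform_whole_img, transform_img, transform_img_small,
+#   #     transform_flip_flip, transform_flip, transform_post_1, transform_post_2)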
diff --git a/contrast/data/transform_ops.py b/contrast/data/transform_ops.py
new file mode 100644
index 0000000..ed06716
--- /dev/null
+++ b/contrast/data/transform_ops.py
@@ -0,0 +1,566 @@
+# --------------------------------------------------------
+# SoCo
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Yue Gao
+# --------------------------------------------------------
+
+
+import math
+import random
+import warnings
+
+import numpy as np
+import torch
+import torch.nn.functional as TF
+from PIL import Image
+from torchvision.transforms import functional as F
+
+_pil_interpolation_to_str = {
+ Image.NEAREST: 'PIL.Image.NEAREST',
+ Image.BILINEAR: 'PIL.Image.BILINEAR',
+ Image.BICUBIC: 'PIL.Image.BICUBIC',
+ Image.LANCZOS: 'PIL.Image.LANCZOS',
+ Image.HAMMING: 'PIL.Image.HAMMING',
+ Image.BOX: 'PIL.Image.BOX',
+}
+
+
+def _get_image_size(img):
+ if F._is_pil_image(img):
+ return img.size
+ elif isinstance(img, torch.Tensor) and img.dim() > 2:
+ return img.shape[-2:][::-1]
+ else:
+ raise TypeError("Unexpected type {}".format(type(img)))
+
+
+def crop_tensor(img_tensor, top, left, height, width):
+    # NCHW layout: rows (dim 2) are indexed by top/height, columns (dim 3) by left/width,
+    # matching the semantics of torchvision's F.resized_crop
+    return img_tensor[:, :, top:top+height, left:left+width]
+
+
+def resize_tensor(img_tensor, size, interpolation='bilinear'):
+ return TF.interpolate(img_tensor, size, mode=interpolation, align_corners=False)
+
+
+def resized_crop_tensor(img_tensor, top, left, height, width, size, interpolation='bilinear'):
+ """
+ tensor version of F.resized_crop
+ """
+ assert isinstance(img_tensor, torch.Tensor)
+ img_tensor = crop_tensor(img_tensor, top, left, height, width)
+ img_tensor = resize_tensor(img_tensor, size, interpolation)
+ return img_tensor
+
+
+
+class ComposeImage(object):
+ """Composes several transforms together.
+
+ Args:
+ transforms (list of ``Transform`` objects): list of transforms to compose.
+
+ Example:
+ >>> transforms.Compose([
+ >>> transforms.CenterCrop(10),
+ >>> transforms.ToTensor(),
+ >>> ])
+ """
+
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __call__(self, img):
+ for t in self.transforms:
+ img = t(img)
+ return img
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + '('
+ for t in self.transforms:
+ format_string += '\n'
+ format_string += ' {0}'.format(t)
+ format_string += '\n)'
+ return format_string
+
+
+class Compose(object):
+ """Composes several transforms together.
+
+ Args:
+ transforms (list of ``Transform`` objects): list of transforms to compose.
+
+ Example:
+ >>> transforms.Compose([
+ >>> transforms.CenterCrop(10),
+ >>> transforms.ToTensor(),
+ >>> ])
+ """
+
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __call__(self, img):
+ coord = None
+ for t in self.transforms:
+ if 'RandomResizedCropCoord' in t.__class__.__name__:
+ img, coord = t(img)
+ elif 'FlipCoord' in t.__class__.__name__:
+ img, coord = t(img, coord)
+ else:
+ img = t(img)
+ return img, coord
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + '('
+ for t in self.transforms:
+ format_string += '\n'
+ format_string += ' {0}'.format(t)
+ format_string += '\n)'
+ return format_string
+
+
+class WholeImageResizedParams(object):
+ """Crop the given PIL Image to random size and aspect ratio.
+
+ A crop of random size (default: of 0.08 to 1.0) of the original size and a random
+ aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
+ is finally resized to given size.
+ This is popularly used to train the Inception networks.
+
+ Args:
+ size: expected output size of each edge
+ scale: range of size of the origin size cropped
+ ratio: range of aspect ratio of the origin aspect ratio cropped
+ interpolation: Default: PIL.Image.BILINEAR
+ """
+
+ def __init__(self, size, interpolation=Image.BILINEAR):
+ if isinstance(size, (tuple, list)):
+ self.size = size
+ else:
+ self.size = (size, size)
+
+ self.interpolation = interpolation
+
+ @staticmethod
+ def get_params(img):
+ """Get parameters for ``crop`` for a random sized crop.
+
+ Args:
+ img (PIL Image): Image to be cropped.
+ scale (tuple): range of size of the origin size cropped
+ ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
+
+ Returns:
+ tuple: params (i, j, h, w) to be passed to ``crop`` for a random
+ sized crop.
+ """
+ width, height = _get_image_size(img)
+
+ return 0, 0, height, width, height, width
+
+ def __call__(self, img):
+ """
+ Args:
+            img (PIL Image): Image to be resized.
+
+        Returns:
+            tuple: (resized PIL Image, resize params as a numpy array).
+ """
+ params = self.get_params(img)
+ params_np = np.array(params)
+ i, j, h, w, _, _ = params
+
+ return F.resized_crop(img, i, j, h, w, self.size, self.interpolation), params_np
+
+ def __repr__(self):
+ interpolate_str = _pil_interpolation_to_str[self.interpolation]
+        # this transform has no scale/ratio attributes, so only size and interpolation are reported
+        format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
+        format_string += ', interpolation={0})'.format(interpolate_str)
+ return format_string
+
+
+
+
+
+class RandomResizedCropParams(object):
+ """Crop the given PIL Image to random size and aspect ratio.
+
+ A crop of random size (default: of 0.08 to 1.0) of the original size and a random
+ aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
+ is finally resized to given size.
+ This is popularly used to train the Inception networks.
+
+ Args:
+ size: expected output size of each edge
+ scale: range of size of the origin size cropped
+ ratio: range of aspect ratio of the origin aspect ratio cropped
+ interpolation: Default: PIL.Image.BILINEAR
+ """
+
+ def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
+ if isinstance(size, (tuple, list)):
+ self.size = size
+ else:
+ self.size = (size, size)
+ if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
+ warnings.warn("range should be of kind (min, max)")
+
+ self.interpolation = interpolation
+ self.scale = scale
+ self.ratio = ratio
+
+ @staticmethod
+ def get_params(img, scale, ratio):
+ """Get parameters for ``crop`` for a random sized crop.
+
+ Args:
+ img (PIL Image): Image to be cropped.
+ scale (tuple): range of size of the origin size cropped
+ ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
+
+ Returns:
+ tuple: params (i, j, h, w) to be passed to ``crop`` for a random
+ sized crop.
+ """
+ width, height = _get_image_size(img)
+ area = height * width
+
+ for attempt in range(10):
+ target_area = random.uniform(*scale) * area
+ log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
+ aspect_ratio = math.exp(random.uniform(*log_ratio))
+
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+ if 0 < w <= width and 0 < h <= height:
+ i = random.randint(0, height - h)
+ j = random.randint(0, width - w)
+ return i, j, h, w, height, width
+
+ # Fallback to central crop
+ in_ratio = float(width) / float(height)
+ if (in_ratio < min(ratio)):
+ w = width
+ h = int(round(w / min(ratio)))
+ elif (in_ratio > max(ratio)):
+ h = height
+ w = int(round(h * max(ratio)))
+ else: # whole image
+ w = width
+ h = height
+ i = (height - h) // 2
+ j = (width - w) // 2
+ return i, j, h, w, height, width
+
+ def __call__(self, img):
+ """
+ Args:
+ img (PIL Image): Image to be cropped and resized.
+
+ Returns:
+            tuple: (randomly cropped and resized PIL Image, crop params as a numpy array).
+ """
+ params = self.get_params(img, self.scale, self.ratio)
+ params_np = np.array(params)
+ i, j, h, w, _, _ = params
+
+ return F.resized_crop(img, i, j, h, w, self.size, self.interpolation), params_np
+
+ def __repr__(self):
+ interpolate_str = _pil_interpolation_to_str[self.interpolation]
+ format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
+ format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
+ format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
+ format_string += ', interpolation={0})'.format(interpolate_str)
+ return format_string
+
+
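+# A minimal usage sketch: both param-returning transforms hand back (PIL image, params),
+# where params = (i, j, h, w, height, width) describes the crop inside the original image.
+#
+#   from PIL import Image
+#   raw = Image.new('RGB', (640, 480))
+#   view1, params1 = WholeImageResizedParams(224)(raw)               # params1 == [0, 0, 480, 640, 480, 640]
+#   view2, params2 = RandomResizedCropParams(224, scale=(0.2, 1.0))(raw)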
+
+
+
+class RandomHorizontalFlipImageBbox(object):
+ """Horizontally flip the given PIL Image randomly with a given probability.
+
+ Args:
+ p (float): probability of the image being flipped. Default value is 0.5
+ """
+
+ def __init__(self, p=0.5):
+ self.p = p
+
+ def __call__(self, img, bboxs, params):
+ """
+ Args:
+ img (PIL Image): Image to be flipped.
+            bboxs (torch tensor): [[x1, y1, x2, y2]] normalized to [0, 1]
+            params (numpy array): i, j, h (crop image), w (crop image), height (raw image), width (raw image)
+        Returns:
+            tuple: (possibly flipped PIL Image, flipped bboxs, updated params).
+ """
+ if random.random() < self.p:
+ bboxs_new = bboxs.clone()
+ bboxs_new[:, 0] = 1.0 - bboxs[:, 2] # x1 = x2
+ bboxs_new[:, 2] = 1.0 - bboxs[:, 0] # x2 = x1
+ # change x, keep y, w, h
+ params_new = np.copy(params)
+ params_new[1] = params[5] - params[3] - params[1]
+ return F.hflip(img), bboxs_new, params_new
+ return img, bboxs, params
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(p={})'.format(self.p)
+
+
+
+
+class RandomCutoutInBbox(object):
+ """RandomCutout in bboxs
+ """
+ def __init__(self, size, cutout_prob, cutout_ratio=(0.1, 0.2)):
+ if isinstance(size, (tuple, list)):
+ self.size = size
+ else:
+ self.size = (size, size)
+ assert isinstance(cutout_ratio, tuple)
+ self.cutout_prob = cutout_prob
+ self.cutout_ratio = cutout_ratio
+ self.width = self.size[0]
+ self.height = self.size[1]
+
+ def __call__(self, img, resized_bboxs):
+ """ img is tensor
+ """
+ new_img = img.clone()
+ for bbox in resized_bboxs:
+ cutout_r = random.random()
+ if cutout_r < self.cutout_prob:
+ x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
+ bbox_w = x2 - x1 + 1
+ bbox_h = y2 - y1 + 1
+ bbox_area = bbox_w * bbox_h
+
+ target_area = random.uniform(*self.cutout_ratio) * bbox_area
+
+ w = int(round(math.sqrt(target_area)))
+ h = w
+ center_cut_x = random.randint(x1, x2)
+ center_cut_y = random.randint(y1, y2)
+ cut_x1 = max(center_cut_x - w // 2, 0)
+ cut_x2 = min(center_cut_x + w // 2, self.width)
+ cut_y1 = max(center_cut_y - h // 2, 0)
+ cut_y2 = min(center_cut_y + h // 2, self.height)
+
+ # img is tensor 3, H, W
+ new_img[:, cut_y1:cut_y2+1, cut_x1:cut_x2+1] = 0.0
+ return new_img
+
+
+class RandomHorizontalFlipImageBboxBbox(object):
+ """Horizontally flip the given PIL Image randomly with a given probability.
+
+ Args:
+ p (float): probability of the image being flipped. Default value is 0.5
+ """
+
+ def __init__(self, p=0.5):
+ self.p = p
+
+ def __call__(self, img, bboxs, bboxs_p, params):
+ """
+ Args:
+ img (PIL Image): Image to be flipped.
+            bboxs (torch tensor): [[x1, y1, x2, y2]] normalized to [0, 1]
+            bboxs_p (torch tensor): [[x1, y1, x2, y2]] normalized to [0, 1], another set of bboxs for the same image
+            params (numpy array): i, j, h (crop image), w (crop image), height (raw image), width (raw image)
+        Returns:
+            tuple: (possibly flipped PIL Image, flipped bboxs, flipped bboxs_p, updated params).
+ """
+ if random.random() < self.p:
+ bboxs_new = bboxs.clone()
+ bboxs_new[:, 0] = 1.0 - bboxs[:, 2] # x1 = x2
+ bboxs_new[:, 2] = 1.0 - bboxs[:, 0] # x2 = x1
+
+ bboxs_p_new = bboxs_p.clone()
+ bboxs_p_new[:, 0] = 1.0 - bboxs_p[:, 2] # x1 = x2
+ bboxs_p_new[:, 2] = 1.0 - bboxs_p[:, 0] # x2 = x1
+ # change x, keep y, w, h
+ params_new = np.copy(params)
+ params_new[1] = params[5] - params[3] - params[1]
+ return F.hflip(img), bboxs_new, bboxs_p_new, params_new
+ return img, bboxs, bboxs_p, params
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(p={})'.format(self.p)
+
+
+class RandomHorizontalFlipImageBboxBboxBbox(object):
+ """Horizontally flip the given PIL Image randomly with a given probability.
+
+ Args:
+ p (float): probability of the image being flipped. Default value is 0.5
+ """
+
+ def __init__(self, p=0.5):
+ self.p = p
+
+ def __call__(self, img, bboxs, bboxs_p, bboxs_q, params):
+ """
+ Args:
+ img (PIL Image): Image to be flipped.
+            bboxs (torch tensor): [[x1, y1, x2, y2]] normalized to [0, 1]
+            bboxs_p (torch tensor): [[x1, y1, x2, y2]] normalized to [0, 1], another set of bboxs for the same image
+            bboxs_q (torch tensor): [[x1, y1, x2, y2]] normalized to [0, 1], another set of bboxs for the same image
+            params (numpy array): i, j, h (crop image), w (crop image), height (raw image), width (raw image)
+        Returns:
+            tuple: (possibly flipped PIL Image, flipped bboxs, flipped bboxs_p, flipped bboxs_q, updated params).
+ """
+ if random.random() < self.p:
+ bboxs_new = bboxs.clone()
+ bboxs_new[:, 0] = 1.0 - bboxs[:, 2] # x1 = x2
+ bboxs_new[:, 2] = 1.0 - bboxs[:, 0] # x2 = x1
+
+ bboxs_p_new = bboxs_p.clone()
+ bboxs_p_new[:, 0] = 1.0 - bboxs_p[:, 2] # x1 = x2
+ bboxs_p_new[:, 2] = 1.0 - bboxs_p[:, 0] # x2 = x1
+
+ bboxs_q_new = bboxs_q.clone()
+ bboxs_q_new[:, 0] = 1.0 - bboxs_q[:, 2] # x1 = x2
+ bboxs_q_new[:, 2] = 1.0 - bboxs_q[:, 0] # x2 = x1
+ # change x, keep y, w, h
+ params_new = np.copy(params)
+ params_new[1] = params[5] - params[3] - params[1]
+ return F.hflip(img), bboxs_new, bboxs_p_new, bboxs_q_new, params_new
+ return img, bboxs, bboxs_p, bboxs_q, params
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(p={})'.format(self.p)
+
+
+class RandomResizedCropCoord(object):
+ """Crop the given PIL Image to random size and aspect ratio.
+
+ A crop of random size (default: of 0.08 to 1.0) of the original size and a random
+ aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
+ is finally resized to given size.
+ This is popularly used to train the Inception networks.
+
+ Args:
+ size: expected output size of each edge
+ scale: range of size of the origin size cropped
+ ratio: range of aspect ratio of the origin aspect ratio cropped
+ interpolation: Default: PIL.Image.BILINEAR
+ """
+
+ def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
+ if isinstance(size, (tuple, list)):
+ self.size = size
+ else:
+ self.size = (size, size)
+ if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
+ warnings.warn("range should be of kind (min, max)")
+
+ self.interpolation = interpolation
+ self.scale = scale
+ self.ratio = ratio
+
+ @staticmethod
+ def get_params(img, scale, ratio):
+ """Get parameters for ``crop`` for a random sized crop.
+
+ Args:
+ img (PIL Image): Image to be cropped.
+ scale (tuple): range of size of the origin size cropped
+ ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
+
+ Returns:
+ tuple: params (i, j, h, w) to be passed to ``crop`` for a random
+ sized crop.
+ """
+ width, height = _get_image_size(img)
+ area = height * width
+
+ for attempt in range(10):
+ target_area = random.uniform(*scale) * area
+ log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
+ aspect_ratio = math.exp(random.uniform(*log_ratio))
+
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+ if 0 < w <= width and 0 < h <= height:
+ i = random.randint(0, height - h)
+ j = random.randint(0, width - w)
+ return i, j, h, w, height, width
+
+ # Fallback to central crop
+ in_ratio = float(width) / float(height)
+ if (in_ratio < min(ratio)):
+ w = width
+ h = int(round(w / min(ratio)))
+ elif (in_ratio > max(ratio)):
+ h = height
+ w = int(round(h * max(ratio)))
+ else: # whole image
+ w = width
+ h = height
+ i = (height - h) // 2
+ j = (width - w) // 2
+ return i, j, h, w, height, width
+
+ def __call__(self, img):
+ """
+ Args:
+ img (PIL Image): Image to be cropped and resized.
+
+ Returns:
+ PIL Image: Randomly cropped and resized image.
+ """
+ i, j, h, w, height, width = self.get_params(img, self.scale, self.ratio)
+ coord = torch.Tensor([float(j) / (width - 1), float(i) / (height - 1),
+ float(j + w - 1) / (width - 1), float(i + h - 1) / (height - 1)])
+ return F.resized_crop(img, i, j, h, w, self.size, self.interpolation), coord
+
+ def __repr__(self):
+ interpolate_str = _pil_interpolation_to_str[self.interpolation]
+ format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
+ format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
+ format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
+ format_string += ', interpolation={0})'.format(interpolate_str)
+ return format_string
+
+
+class RandomHorizontalFlipCoord(object):
+ """Horizontally flip the given PIL Image randomly with a given probability.
+
+ Args:
+ p (float): probability of the image being flipped. Default value is 0.5
+ """
+
+ def __init__(self, p=0.5):
+ self.p = p
+
+ def __call__(self, img, coord):
+ """
+ Args:
+ img (PIL Image): Image to be flipped.
+
+ Returns:
+ PIL Image: Randomly flipped image.
+ """
+ if random.random() < self.p:
+ coord_new = coord.clone()
+ coord_new[0] = coord[2]
+ coord_new[2] = coord[0]
+ return F.hflip(img), coord_new
+ return img, coord
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(p={})'.format(self.p)
diff --git a/contrast/data/zipreader.py b/contrast/data/zipreader.py
new file mode 100644
index 0000000..ce0db9f
--- /dev/null
+++ b/contrast/data/zipreader.py
@@ -0,0 +1,85 @@
+# --------------------------------------------------------
+# SoCo
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Yue Gao
+# --------------------------------------------------------
+
+
+import os
+import zipfile
+
+
+def is_zip_path(img_or_path):
+ """judge if this is a zip path"""
+ return '.zip@' in img_or_path
+
+
+class ZipReader(object):
+ """A class to read zipped files"""
+ zip_bank = dict()
+
+ def __init__(self):
+ super(ZipReader, self).__init__()
+
+ @staticmethod
+ def get_zipfile(path):
+ zip_bank = ZipReader.zip_bank
+ if path not in zip_bank:
+ zfile = zipfile.ZipFile(path, 'r')
+ zip_bank[path] = zfile
+ return zip_bank[path]
+
+ @staticmethod
+ def split_zip_style_path(path):
+        pos_at = path.find('@')  # find() returns -1 (instead of raising) so the assert below can report a missing '@'
+        assert pos_at != -1, "character '@' is not found from the given path '%s'" % path
+
+ zip_path = path[0: pos_at]
+ folder_path = path[pos_at + 1:]
+ folder_path = str.strip(folder_path, '/')
+ return zip_path, folder_path
+
+ @staticmethod
+ def list_folder(path):
+ zip_path, folder_path = ZipReader.split_zip_style_path(path)
+
+ zfile = ZipReader.get_zipfile(zip_path)
+ folder_list = []
+ for file_folder_name in zfile.namelist():
+ file_folder_name = str.strip(file_folder_name, '/')
+ if file_folder_name.startswith(folder_path) and \
+ len(os.path.splitext(file_folder_name)[-1]) == 0 and \
+ file_folder_name != folder_path:
+ if len(folder_path) == 0:
+ folder_list.append(file_folder_name)
+ else:
+ folder_list.append(file_folder_name[len(folder_path)+1:])
+
+ return folder_list
+
+ @staticmethod
+ def list_files(path, extension=None):
+ if extension is None:
+ extension = ['.*']
+ zip_path, folder_path = ZipReader.split_zip_style_path(path)
+
+ zfile = ZipReader.get_zipfile(zip_path)
+ file_lists = []
+ for file_folder_name in zfile.namelist():
+ file_folder_name = str.strip(file_folder_name, '/')
+ if file_folder_name.startswith(folder_path) and \
+ str.lower(os.path.splitext(file_folder_name)[-1]) in extension:
+ if len(folder_path) == 0:
+ file_lists.append(file_folder_name)
+ else:
+ file_lists.append(file_folder_name[len(folder_path)+1:])
+
+ return file_lists
+
+ @staticmethod
+ def read(path):
+ zip_path, path_img = ZipReader.split_zip_style_path(path)
+ zfile = ZipReader.get_zipfile(zip_path)
+ data = zfile.read(path_img)
+ return data
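+
+
+# A minimal usage sketch (the archive layout below is an assumption): images addressed as
+# '<archive>.zip@<member path>' can be read without unpacking the archive.
+#
+#   from io import BytesIO
+#   from PIL import Image
+#   data = ZipReader.read('/path/to/train.zip@train/some_image.JPEG')
+#   img = Image.open(BytesIO(data)).convert('RGB')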
diff --git a/contrast/lars.py b/contrast/lars.py
new file mode 100644
index 0000000..2267146
--- /dev/null
+++ b/contrast/lars.py
@@ -0,0 +1,161 @@
+# --------------------------------------------------------
+# SoCo
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Yue Gao
+# --------------------------------------------------------
+
+
+import torch
+from torch.optim.optimizer import Optimizer
+
+__all__ = ['LARS']
+
+
+def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
+ """Splits param group into weight_decay / non-weight decay.
+ Tweaked from https://bit.ly/3dzyqod
+ :param model: the torch.nn model
+ :param weight_decay: weight decay term
+ :param skip_list: extra modules (besides BN/bias) to skip
+ :returns: split param group into weight_decay/not-weight decay
+ :rtype: list(dict)
+ """
+ # if weight_decay == 0:
+ # return model.parameters()
+
+ decay, no_decay = [], []
+ for name, param in model.named_parameters():
+ if not param.requires_grad:
+ continue
+
+ if len(param.shape) == 1 or name in skip_list:
+ # print(name)
+ no_decay.append(param)
+ else:
+ decay.append(param)
+ return [
+ {'params': no_decay, 'weight_decay': 0, 'ignore': True},
+ {'params': decay, 'weight_decay': weight_decay, 'ignore': False}]
+
+
+class LARS(Optimizer):
+ """Implements 'LARS (Layer-wise Adaptive Rate Scaling)'__ as Optimizer a
+ :class:`~torch.optim.Optimizer` wrapper.
+
+ __ : https://arxiv.org/abs/1708.03888
+
+ Wraps an arbitrary optimizer like :class:`torch.optim.SGD` to use LARS. If
+ you want to the same performance obtained with small-batch training when
+ you use large-batch training, LARS will be helpful::
+
+ Args:
+ optimizer (Optimizer):
+ optimizer to wrap
+ eps (float, optional):
+ epsilon to help with numerical stability while calculating the
+ adaptive learning rate
+ trust_coef (float, optional):
+ trust coefficient for calculating the adaptive learning rate
+
+ Example::
+ base_optimizer = optim.SGD(model.parameters(), lr=0.1)
+ optimizer = LARS(optimizer=base_optimizer)
+
+ output = model(input)
+ loss = loss_fn(output, target)
+ loss.backward()
+
+ optimizer.step()
+
+ """
+
+ def __init__(self, optimizer, eps=1e-8, trust_coef=0.001):
+ if eps < 0.0:
+            raise ValueError('invalid epsilon value: %f' % eps)
+
+ if trust_coef < 0.0:
+ raise ValueError("invalid trust coefficient: %f" % trust_coef)
+
+ self.optim = optimizer
+ self.eps = eps
+ self.trust_coef = trust_coef
+
+ def __getstate__(self):
+ lars_dict = {}
+ lars_dict['eps'] = self.eps
+ lars_dict['trust_coef'] = self.trust_coef
+ return (self.optim, lars_dict)
+
+ def __setstate__(self, state):
+ self.optim, lars_dict = state
+ self.eps = lars_dict['eps']
+ self.trust_coef = lars_dict['trust_coef']
+
+ def __repr__(self):
+ return '%s(%r)' % (self.__class__.__name__, self.optim)
+
+ @property
+ def param_groups(self):
+ return self.optim.param_groups
+
+ @property
+ def state(self):
+ return self.optim.state
+
+ def state_dict(self):
+ return self.optim.state_dict()
+
+ def load_state_dict(self, state_dict):
+ self.optim.load_state_dict(state_dict)
+
+ def zero_grad(self):
+ self.optim.zero_grad()
+
+ def add_param_group(self, param_group):
+ self.optim.add_param_group(param_group)
+
+ def apply_adaptive_lrs(self):
+ with torch.no_grad():
+ for group in self.optim.param_groups:
+ weight_decay = group['weight_decay']
+ ignore = group.get('ignore', None) # NOTE: this is set by add_weight_decay
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+
+                    # Add weight decay before computing the adaptive LR;
+                    # this appears to matter in SimCLR-style models.
+ if weight_decay > 0:
+ p.grad = p.grad.add(p, alpha=weight_decay)
+
+ # Ignore bias / bn terms for LARS update
+ if ignore is not None and not ignore:
+ # compute the parameter and gradient norms
+ param_norm = p.norm()
+ grad_norm = p.grad.norm()
+
+ # compute our adaptive learning rate
+ adaptive_lr = 1.0
+ if param_norm > 0 and grad_norm > 0:
+ adaptive_lr = self.trust_coef * param_norm / (grad_norm + self.eps)
+
+ # print("applying {} lr scaling to param of shape {}".format(adaptive_lr, p.shape))
+ p.grad = p.grad.mul(adaptive_lr)
+
+ def step(self, *args, **kwargs):
+ self.apply_adaptive_lrs()
+
+ # Zero out weight decay as we do it in LARS
+ weight_decay_orig = [group['weight_decay'] for group in self.optim.param_groups]
+ for group in self.optim.param_groups:
+ group['weight_decay'] = 0
+
+ loss = self.optim.step(*args, **kwargs) # Normal optimizer
+
+ # Restore weight decay
+ for group, wo in zip(self.optim.param_groups, weight_decay_orig):
+ group['weight_decay'] = wo
+
+ return loss
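+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative, not part of the original training
+    # script): split parameters so biases / BN weights get neither weight
+    # decay nor LARS scaling, then wrap a plain SGD optimizer with LARS.
+    import torch.nn as nn
+    import torch.optim as optim
+
+    model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))
+    param_groups = add_weight_decay(model, weight_decay=1e-5)
+    optimizer = LARS(optim.SGD(param_groups, lr=0.1, momentum=0.9))
+
+    x, y = torch.randn(4, 8), torch.randint(0, 2, (4,))
+    loss = nn.functional.cross_entropy(model(x), y)
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()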
diff --git a/contrast/logger.py b/contrast/logger.py
new file mode 100644
index 0000000..b4c853a
--- /dev/null
+++ b/contrast/logger.py
@@ -0,0 +1,94 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import functools
+import logging
+import os
+import sys
+from termcolor import colored
+
+
+class _ColorfulFormatter(logging.Formatter):
+ def __init__(self, *args, **kwargs):
+ self._root_name = kwargs.pop("root_name") + "."
+ self._abbrev_name = kwargs.pop("abbrev_name", "")
+ if len(self._abbrev_name):
+ self._abbrev_name = self._abbrev_name + "."
+ super(_ColorfulFormatter, self).__init__(*args, **kwargs)
+
+ def formatMessage(self, record):
+ record.name = record.name.replace(self._root_name, self._abbrev_name)
+ log = super(_ColorfulFormatter, self).formatMessage(record)
+ if record.levelno == logging.WARNING:
+ prefix = colored("WARNING", "red", attrs=["blink"])
+ elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
+ prefix = colored("ERROR", "red", attrs=["blink", "underline"])
+ else:
+ return log
+ return prefix + " " + log
+
+
+# so that calling setup_logger multiple times won't add many handlers
+@functools.lru_cache()
+def setup_logger(
+ output=None, distributed_rank=0, *, color=True, name="contrast", abbrev_name=None
+):
+ """
+    Initialize the project logger (adapted from detectron2) and set its verbosity level to "DEBUG".
+
+ Args:
+ output (str): a file name or a directory to save log. If None, will not save log file.
+ If ends with ".txt" or ".log", assumed to be a file name.
+ Otherwise, logs will be saved to `output/log.txt`.
+ name (str): the root module name of this logger
+
+ Returns:
+ logging.Logger: a logger
+ """
+ logger = logging.getLogger(name)
+ logger.setLevel(logging.DEBUG)
+ logger.propagate = False
+
+ if abbrev_name is None:
+ abbrev_name = name
+
+ plain_formatter = logging.Formatter(
+ "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
+ )
+ # stdout logging: master only
+ if distributed_rank == 0:
+ ch = logging.StreamHandler(stream=sys.stdout)
+ ch.setLevel(logging.DEBUG)
+ if color:
+ formatter = _ColorfulFormatter(
+ colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
+ datefmt="%m/%d %H:%M:%S",
+ root_name=name,
+ abbrev_name=str(abbrev_name),
+ )
+ else:
+ formatter = plain_formatter
+ ch.setFormatter(formatter)
+ logger.addHandler(ch)
+
+ # file logging: all workers
+ if output is not None:
+ if output.endswith(".txt") or output.endswith(".log"):
+ filename = output
+ else:
+ filename = os.path.join(output, "log.txt")
+ if distributed_rank > 0:
+ filename = filename + f".rank{distributed_rank}"
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+
+ fh = logging.StreamHandler(_cached_log_stream(filename))
+ fh.setLevel(logging.DEBUG)
+ fh.setFormatter(plain_formatter)
+ logger.addHandler(fh)
+
+ return logger
+
+
+# cache the opened file object, so that different calls to `setup_logger`
+# with the same file name can safely write to the same file.
+@functools.lru_cache(maxsize=None)
+def _cached_log_stream(filename):
+ return open(filename, "a")
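+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative): rank 0 logs to stdout and to
+    # <output>/log.txt; other ranks only write per-rank files. The output
+    # directory below is hypothetical.
+    logger = setup_logger(output='./output/example_run', distributed_rank=0, name='contrast')
+    logger.info('logger initialised')
+    logger.warning('warnings are highlighted when color=True')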
diff --git a/contrast/lr_scheduler.py b/contrast/lr_scheduler.py
new file mode 100644
index 0000000..0b8b67d
--- /dev/null
+++ b/contrast/lr_scheduler.py
@@ -0,0 +1,93 @@
+# --------------------------------------------------------
+# SoCo
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Yue Gao
+# --------------------------------------------------------
+
+
+from torch.optim.lr_scheduler import (CosineAnnealingLR, MultiStepLR,
+ _LRScheduler)
+
+
+# noinspection PyAttributeOutsideInit
+class GradualWarmupScheduler(_LRScheduler):
+ """ Gradually warm-up(increasing) learning rate in optimizer.
+ Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ multiplier: init learning rate = base lr / multiplier
+ warmup_epoch: target learning rate is reached at warmup_epoch, gradually
+ after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
+ """
+
+ def __init__(self, optimizer, multiplier, warmup_epoch, after_scheduler, last_epoch=-1):
+ self.multiplier = multiplier
+ if self.multiplier <= 1.:
+ raise ValueError('multiplier should be greater than 1.')
+ self.warmup_epoch = warmup_epoch
+ self.after_scheduler = after_scheduler
+ self.finished = False
+ super().__init__(optimizer, last_epoch=last_epoch)
+
+ def get_lr(self):
+ if self.last_epoch > self.warmup_epoch:
+ return self.after_scheduler.get_lr()
+ else:
+ return [base_lr / self.multiplier * ((self.multiplier - 1.) * self.last_epoch / self.warmup_epoch + 1.)
+ for base_lr in self.base_lrs]
+
+ def step(self, epoch=None):
+ if epoch is None:
+ epoch = self.last_epoch + 1
+ self.last_epoch = epoch
+ if epoch > self.warmup_epoch:
+ self.after_scheduler.step(epoch - self.warmup_epoch)
+ else:
+ super(GradualWarmupScheduler, self).step(epoch)
+
+ def state_dict(self):
+ """Returns the state of the scheduler as a :class:`dict`.
+
+ It contains an entry for every variable in self.__dict__ which
+ is not the optimizer.
+ """
+
+ state = {key: value for key, value in self.__dict__.items() if key != 'optimizer' and key != 'after_scheduler'}
+ state['after_scheduler'] = self.after_scheduler.state_dict()
+ return state
+
+ def load_state_dict(self, state_dict):
+ """Loads the schedulers state.
+
+ Arguments:
+ state_dict (dict): scheduler state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+
+ after_scheduler_state = state_dict.pop('after_scheduler')
+ self.__dict__.update(state_dict)
+ self.after_scheduler.load_state_dict(after_scheduler_state)
+
+
+def get_scheduler(optimizer, n_iter_per_epoch, args):
+ if "cosine" in args.lr_scheduler:
+ scheduler = CosineAnnealingLR(
+ optimizer=optimizer,
+ eta_min=0.000001,
+ T_max=(args.epochs - args.warmup_epoch) * n_iter_per_epoch)
+ elif "step" in args.lr_scheduler:
+ scheduler = MultiStepLR(
+ optimizer=optimizer,
+ gamma=args.lr_decay_rate,
+ milestones=[(m - args.warmup_epoch)*n_iter_per_epoch for m in args.lr_decay_epochs])
+ else:
+ raise NotImplementedError(f"scheduler {args.lr_scheduler} not supported")
+
+ if args.warmup_epoch > 0:
+ scheduler = GradualWarmupScheduler(
+ optimizer,
+ multiplier=args.warmup_multiplier,
+ after_scheduler=scheduler,
+ warmup_epoch=args.warmup_epoch * n_iter_per_epoch)
+ return scheduler
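+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative): warm up for `warmup_epoch` epochs,
+    # then hand over to cosine annealing. The argparse namespace mimics the
+    # fields `get_scheduler` reads; the values are hypothetical.
+    import argparse
+
+    import torch
+
+    args = argparse.Namespace(lr_scheduler='cosine', epochs=100, warmup_epoch=5,
+                              warmup_multiplier=100, lr_decay_rate=0.1,
+                              lr_decay_epochs=[60, 80])
+    optimizer = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)
+    scheduler = get_scheduler(optimizer, n_iter_per_epoch=10, args=args)
+    for _ in range(10):  # one epoch of iterations
+        optimizer.step()
+        scheduler.step()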
diff --git a/contrast/models/SoCo_C4.py b/contrast/models/SoCo_C4.py
new file mode 100644
index 0000000..d4924c2
--- /dev/null
+++ b/contrast/models/SoCo_C4.py
@@ -0,0 +1,139 @@
+# --------------------------------------------------------
+# SoCo
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Yue Gao
+# --------------------------------------------------------
+
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.ops as tvops
+from torch.distributed import get_world_size
+
+from .base import BaseModel
+from .mlps import Pred_Head, Proj_Head
+
+
+class SoCo_C4(BaseModel):
+ """
+    SoCo pretraining model (C4 variant): an online (query) encoder and a
+    momentum-updated key encoder, adapted from MoCo: https://arxiv.org/abs/1911.05722
+ """
+
+ def __init__(self, base_encoder, args):
+ """
+ dim: feature dimension (default: 128)
+ K: queue size; number of negative keys (default: 65536)
+ m: moco momentum of updating key encoder (default: 0.999)
+ T: softmax temperature (default: 0.07)
+ """
+ super(SoCo_C4, self).__init__(base_encoder, args)
+
+ self.contrast_num_negative = args.contrast_num_negative
+ self.contrast_momentum = args.contrast_momentum
+ self.contrast_temperature = args.contrast_temperature
+ self.output_size = args.output_size
+ self.aligned = args.aligned
+
+ # create the encoder
+ self.encoder = base_encoder(low_dim=args.feature_dim, head_type='pass', use_roi_align_on_c4=True)
+ self.projector = Proj_Head()
+ self.predictor = Pred_Head()
+
+ # create the encoder_k
+ self.encoder_k = base_encoder(low_dim=args.feature_dim, head_type='pass', use_roi_align_on_c4=True)
+ self.projector_k = Proj_Head()
+
+ self.roi_avg_pool = nn.AvgPool2d(self.output_size, stride=1)
+
+ for param_q, param_k in zip(self.encoder.parameters(), self.encoder_k.parameters()):
+ param_k.data.copy_(param_q.data) # initialize
+ param_k.requires_grad = False # not update by gradient
+
+ for param_q, param_k in zip(self.projector.parameters(), self.projector_k.parameters()):
+ param_k.data.copy_(param_q.data)
+ param_k.requires_grad = False
+
+ nn.SyncBatchNorm.convert_sync_batchnorm(self.encoder)
+ nn.SyncBatchNorm.convert_sync_batchnorm(self.encoder_k)
+ nn.SyncBatchNorm.convert_sync_batchnorm(self.projector)
+ nn.SyncBatchNorm.convert_sync_batchnorm(self.projector_k)
+ nn.SyncBatchNorm.convert_sync_batchnorm(self.predictor)
+
+ # create the queue
+ self.register_buffer("queue", torch.randn(args.feature_dim, self.contrast_num_negative))
+ self.queue = F.normalize(self.queue, dim=0)
+
+ self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
+
+ self.K = int(args.num_instances * 1. / get_world_size() / args.batch_size * args.epochs)
+ self.k = int(args.num_instances * 1. / get_world_size() / args.batch_size * (args.start_epoch - 1))
+ # print('Initial', get_rank(), self.k, self.K)
+
+ @torch.no_grad()
+ def _momentum_update_key_encoder(self):
+ """
+ Momentum update of the key encoder
+ """
+ _contrast_momentum = 1. - (1. - self.contrast_momentum) * (np.cos(np.pi * self.k / self.K) + 1) / 2.
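+        # e.g. with contrast_momentum = 0.99 this cosine schedule starts at
+        # 0.99 (k = 0) and increases monotonically to 1.0 at k = K.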
+ self.k = self.k + 1
+ # print('Update', get_rank(), self.k, _contrast_momentum)
+
+ for param_q, param_k in zip(self.encoder.parameters(), self.encoder_k.parameters()):
+ param_k.data = param_k.data * _contrast_momentum + param_q.data * (1. - _contrast_momentum)
+
+ for param_q, param_k in zip(self.projector.parameters(), self.projector_k.parameters()):
+ param_k.data = param_k.data * _contrast_momentum + param_q.data * (1. - _contrast_momentum)
+
+ def regression_loss_bboxs(self, vectors_q, vectors_k, correspondence):
+ M, C = vectors_q.shape
+ N, L, P = correspondence.shape
+ assert L == P
+ assert N * L == M
+ vectors_q = vectors_q.view(N, L, C)
+ vectors_k = vectors_k.view(N, L, C)
+ # vectors_q: N, L, C
+ # vectors_k: N, L, C
+ vectors_k = torch.transpose(vectors_k, 1, 2)
+ sim = torch.bmm(vectors_q, vectors_k) # N, L, L
+ loss = (sim * correspondence).sum(-1).sum(-1) / (correspondence.sum(-1).sum(-1) + 1e-6)
+ return -2 * loss.mean()
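+
+    # Worked shape example (illustrative): with N = 4 images and L = 16 box
+    # proposals per image, vectors_q / vectors_k are (64, 256) and
+    # correspondence is (4, 16, 16); the loss is -2 times the correspondence-
+    # weighted mean cosine similarity, approaching -2 when matched proposals
+    # have identical features.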
+
+ def forward(self, im_1, im_2, bboxs1, bboxs2, corres):
+ """
+ Input:
+ im_q: a batch of query images
+ im_k: a batch of key images
+ Output:
+ logits, targets
+ """
+ # compute query features
+ feat_1 = self.encoder(im_1, bboxs=bboxs1) # queries: NxC
+ proj_1 = self.projector(feat_1)
+ pred_1 = self.predictor(proj_1)
+ pred_1 = F.normalize(pred_1, dim=1)
+
+ feat_2 = self.encoder(im_2, bboxs=bboxs2)
+ proj_2 = self.projector(feat_2)
+ pred_2 = self.predictor(proj_2)
+ pred_2 = F.normalize(pred_2, dim=1)
+
+ # compute key features
+ with torch.no_grad(): # no gradient to keys
+ self._momentum_update_key_encoder() # update the key encoder
+
+ feat_1_ng = self.encoder_k(im_1, bboxs=bboxs1) # keys: NxC
+ proj_1_ng = self.projector_k(feat_1_ng)
+ proj_1_ng = F.normalize(proj_1_ng, dim=1)
+
+ feat_2_ng = self.encoder_k(im_2, bboxs=bboxs2)
+ proj_2_ng = self.projector_k(feat_2_ng)
+ proj_2_ng = F.normalize(proj_2_ng, dim=1)
+
+ # compute loss
+ corres_2_1 = corres.transpose(1, 2) # transpose dim 1 dim 2, map bboxs2 to bboxs1
+ loss = self.regression_loss_bboxs(pred_1, proj_2_ng, corres) + self.regression_loss_bboxs(pred_2, proj_1_ng, corres_2_1)
+ return loss
diff --git a/contrast/models/SoCo_FPN.py b/contrast/models/SoCo_FPN.py
new file mode 100644
index 0000000..fe34a0c
--- /dev/null
+++ b/contrast/models/SoCo_FPN.py
@@ -0,0 +1,286 @@
+# --------------------------------------------------------
+# SoCo
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Yue Gao
+# --------------------------------------------------------
+
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.ops as tvops
+from torch.distributed import get_world_size
+
+from .base import BaseModel
+from .box_util import append_batch_index_to_bboxs_and_scale
+from .fast_rcnn_conv_fc_head import FastRCNNConvFCHead
+from .fpn import FPN
+from .mlps import Pred_Head, Proj_Head
+
+
+class SoCo_FPN(BaseModel):
+ """
+    SoCo pretraining model (FPN variant): an online encoder, FPN and R-CNN head
+    with momentum-updated key counterparts, adapted from MoCo:
+    https://arxiv.org/abs/1911.05722
+ """
+
+ def __init__(self, base_encoder, args):
+ """
+ dim: feature dimension (default: 128)
+ K: queue size; number of negative keys (default: 65536)
+ m: moco momentum of updating key encoder (default: 0.999)
+ T: softmax temperature (default: 0.07)
+ Just cleaned unused forward tensors
+ """
+ super(SoCo_FPN, self).__init__(base_encoder, args)
+
+ self.contrast_num_negative = args.contrast_num_negative
+ self.contrast_momentum = args.contrast_momentum
+ self.contrast_temperature = args.contrast_temperature
+ self.output_size = args.output_size
+ self.aligned = args.aligned
+
+ norm_cfg = dict(type='BN', requires_grad=True)
+
+ # align to detectron2 !
+ # create the encoder
+ self.encoder = base_encoder(low_dim=args.feature_dim, head_type='multi_layer')
+ self.neck = FPN(in_channels=args.in_channels, out_channels=args.out_channels, num_outs=args.num_outs, start_level=args.start_level,
+ end_level=args.end_level, add_extra_convs=args.add_extra_convs, extra_convs_on_inputs=args.extra_convs_on_inputs,
+ relu_before_extra_convs=args.relu_before_extra_convs, norm_cfg=norm_cfg)
+ self.head = FastRCNNConvFCHead()
+ self.projector = Proj_Head(in_dim=1024) # head channel
+ self.predictor = Pred_Head()
+
+ # create the encoder_k
+ self.encoder_k = base_encoder(low_dim=args.feature_dim, head_type='multi_layer')
+ self.neck_k = FPN(in_channels=args.in_channels, out_channels=args.out_channels, num_outs=args.num_outs, start_level=args.start_level,
+ end_level=args.end_level, add_extra_convs=args.add_extra_convs, extra_convs_on_inputs=args.extra_convs_on_inputs,
+ relu_before_extra_convs=args.relu_before_extra_convs, norm_cfg=norm_cfg)
+ self.head_k = FastRCNNConvFCHead()
+ self.projector_k = Proj_Head(in_dim=1024) # head channel
+
+ self.roi_avg_pool = nn.AvgPool2d(self.output_size, stride=1)
+
+ for param_q, param_k in zip(self.encoder.parameters(), self.encoder_k.parameters()):
+ param_k.data.copy_(param_q.data) # initialize
+ param_k.requires_grad = False # not update by gradient
+
+ for param_q, param_k in zip(self.neck.parameters(), self.neck_k.parameters()):
+ param_k.data.copy_(param_q.data) # initialize
+ param_k.requires_grad = False # not update by gradient
+
+ for param_q, param_k in zip(self.head.parameters(), self.head_k.parameters()):
+ param_k.data.copy_(param_q.data) # initialize
+ param_k.requires_grad = False # not update by gradient
+
+ for param_q, param_k in zip(self.projector.parameters(), self.projector_k.parameters()):
+ param_k.data.copy_(param_q.data)
+ param_k.requires_grad = False
+
+ nn.SyncBatchNorm.convert_sync_batchnorm(self.encoder)
+ nn.SyncBatchNorm.convert_sync_batchnorm(self.encoder_k)
+ nn.SyncBatchNorm.convert_sync_batchnorm(self.neck)
+ nn.SyncBatchNorm.convert_sync_batchnorm(self.neck_k)
+ nn.SyncBatchNorm.convert_sync_batchnorm(self.head)
+ nn.SyncBatchNorm.convert_sync_batchnorm(self.head_k)
+ nn.SyncBatchNorm.convert_sync_batchnorm(self.projector)
+ nn.SyncBatchNorm.convert_sync_batchnorm(self.projector_k)
+ nn.SyncBatchNorm.convert_sync_batchnorm(self.predictor)
+
+ self.K = int(args.num_instances * 1. / get_world_size() / args.batch_size * args.epochs)
+ self.k = int(args.num_instances * 1. / get_world_size() / args.batch_size * (args.start_epoch - 1))
+ # print('Initial', get_rank(), self.k, self.K)
+
+ @torch.no_grad()
+ def _momentum_update_key_encoder(self):
+ """
+ Momentum update of the key encoder
+ """
+ _contrast_momentum = 1. - (1. - self.contrast_momentum) * (np.cos(np.pi * self.k / self.K) + 1) / 2.
+ self.k = self.k + 1
+ # print('Update', get_rank(), self.k, _contrast_momentum)
+
+ for param_q, param_k in zip(self.encoder.parameters(), self.encoder_k.parameters()):
+ param_k.data = param_k.data * _contrast_momentum + param_q.data * (1. - _contrast_momentum)
+
+ for param_q, param_k in zip(self.neck.parameters(), self.neck_k.parameters()):
+ param_k.data = param_k.data * _contrast_momentum + param_q.data * (1. - _contrast_momentum)
+
+ for param_q, param_k in zip(self.head.parameters(), self.head_k.parameters()):
+ param_k.data = param_k.data * _contrast_momentum + param_q.data * (1. - _contrast_momentum)
+
+ for param_q, param_k in zip(self.projector.parameters(), self.projector_k.parameters()):
+ param_k.data = param_k.data * _contrast_momentum + param_q.data * (1. - _contrast_momentum)
+
+
+ def regression_loss_bboxs_aware(self, vectors_q, vectors_k, correspondence):
+ N, M, C = vectors_q.shape # N, P * L, C
+ N, M1, M2 = correspondence.shape # N, P * L, P * L
+ assert M == M1 == M2
+        # vectors_q: N, P * L, C
+        # vectors_k: N, P * L, C
+ vectors_k = torch.transpose(vectors_k, 1, 2) # N, C, P * L
+ sim = torch.bmm(vectors_q, vectors_k) # N, P * L, P * L
+ loss = (sim * correspondence).sum(-1).sum(-1) / (correspondence.sum(-1).sum(-1) + 1e-6)
+ return -2 * loss.mean()
+
+
+ def roi_align_feature_map(self, feature_map, bboxs):
+        feature_map = feature_map.type(dtype=bboxs.dtype)  # under amp the feature map may be half precision; cast it to the bbox dtype for roi_align
+ N, C, H, W = feature_map.shape
+ N, L, _ = bboxs.shape
+
+ output_size = (self.output_size, self.output_size)
+
+ bboxs_q_with_batch_index = append_batch_index_to_bboxs_and_scale(bboxs, H, W)
+ aligned_features = tvops.roi_align(input=feature_map, boxes=bboxs_q_with_batch_index, output_size=output_size, aligned=self.aligned)
+ # N*L, C, output_size, output_size
+ return aligned_features
+
+ def forward(self, im_1, im_2, im_3, bboxs1_12, bboxs1_13, bboxs2, bboxs3, corres_12, corres_13):
+ """
+ Input:
+ im_q: a batch of query images
+ im_k: a batch of key images
+ Output:
+ logits, targets
+ """
+ # compute query features
+ N, L, _ = bboxs1_12.shape
+ feats_1 = self.encoder(im_1)
+ fpn_feats_1 = self.neck(feats_1) # p2, p3, p4, p5, num_outs = 4
+ assert len(fpn_feats_1) == 4
+
+ preds_1_12 = [None] * len(fpn_feats_1)
+ for i, feat_1_12 in enumerate(fpn_feats_1):
+ feat_roi_1_12 = self.roi_align_feature_map(feat_1_12, bboxs1_12)
+ feat_vec_1_12 = self.head(feat_roi_1_12)
+ proj_1_12 = self.projector(feat_vec_1_12)
+ pred_1_12 = self.predictor(proj_1_12)
+ pred_1_12 = F.normalize(pred_1_12, dim=1) # N * L, C
+ pred_1_12 = pred_1_12.reshape((N, L, -1)) # N, L, C
+ preds_1_12[i] = pred_1_12
+
+ preds_1_12 = torch.cat(preds_1_12, dim=1) # N, P * L, C
+
+
+ preds_1_13 = [None] * len(fpn_feats_1)
+ for i, feat_1_13 in enumerate(fpn_feats_1):
+ feat_roi_1_13 = self.roi_align_feature_map(feat_1_13, bboxs1_13)
+ feat_vec_1_13 = self.head(feat_roi_1_13)
+ proj_1_13 = self.projector(feat_vec_1_13)
+ pred_1_13 = self.predictor(proj_1_13)
+ pred_1_13 = F.normalize(pred_1_13, dim=1) # N * L, C
+ pred_1_13 = pred_1_13.reshape((N, L, -1)) # N, L, C
+ preds_1_13[i] = pred_1_13
+
+ preds_1_13 = torch.cat(preds_1_13, dim=1) # N, P * L, C
+
+
+ feats_2 = self.encoder(im_2)
+ fpn_feats_2 = self.neck(feats_2) # p2, p3, p4, p5, num_outs = 4
+ assert len(fpn_feats_2) == 4
+
+ preds_2 = [None] * len(fpn_feats_2)
+ for i, feat_2 in enumerate(fpn_feats_2):
+ feat_roi_2 = self.roi_align_feature_map(feat_2, bboxs2)
+ feat_vec_2 = self.head(feat_roi_2)
+ proj_2 = self.projector(feat_vec_2)
+ pred_2 = self.predictor(proj_2)
+ pred_2 = F.normalize(pred_2, dim=1)
+ pred_2 = pred_2.reshape((N, L, -1)) # N, L, C
+ preds_2[i] = pred_2
+
+ preds_2 = torch.cat(preds_2, dim=1) # N, P * L, C
+
+
+ feats_3 = self.encoder(im_3)
+ fpn_feats_3 = self.neck(feats_3) # p2, p3, p4, p5, num_outs = 4
+ assert len(fpn_feats_3) == 4
+
+ preds_3 = [None] * len(fpn_feats_3)
+ for i, feat_3 in enumerate(fpn_feats_3):
+ feat_roi_3 = self.roi_align_feature_map(feat_3, bboxs3)
+ feat_vec_3 = self.head(feat_roi_3)
+ proj_3 = self.projector(feat_vec_3)
+ pred_3 = self.predictor(proj_3)
+ pred_3 = F.normalize(pred_3, dim=1)
+ pred_3 = pred_3.reshape((N, L, -1)) # N, L, C
+ preds_3[i] = pred_3
+
+ preds_3 = torch.cat(preds_3, dim=1) # N, P * L, C
+
+
+ # compute key features
+ with torch.no_grad(): # no gradient to keys
+ self._momentum_update_key_encoder() # update the key encoder
+
+ feats_1_ng = self.encoder_k(im_1)
+ fpn_feats_1_ng = self.neck_k(feats_1_ng) # p2, p3, p4, p5, num_outs = 4
+ assert len(fpn_feats_1_ng) == 4
+ projs_1_12_ng = [None] * len(fpn_feats_1_ng)
+ for i, feat_1_12_ng in enumerate(fpn_feats_1_ng):
+ feat_roi_1_12_ng = self.roi_align_feature_map(feat_1_12_ng, bboxs1_12)
+ feat_vec_1_12_ng = self.head_k(feat_roi_1_12_ng)
+ proj_1_12_ng = self.projector_k(feat_vec_1_12_ng)
+ proj_1_12_ng = F.normalize(proj_1_12_ng, dim=1)
+ proj_1_12_ng = proj_1_12_ng.reshape((N, L, -1))
+ projs_1_12_ng[i] = proj_1_12_ng
+
+ projs_1_12_ng = torch.cat(projs_1_12_ng, dim=1) # N, P * L, C
+
+
+ projs_1_13_ng = [None] * len(fpn_feats_1_ng)
+ for i, feat_1_13_ng in enumerate(fpn_feats_1_ng):
+ feat_roi_1_13_ng = self.roi_align_feature_map(feat_1_13_ng, bboxs1_13)
+ feat_vec_1_13_ng = self.head_k(feat_roi_1_13_ng)
+ proj_1_13_ng = self.projector_k(feat_vec_1_13_ng)
+ proj_1_13_ng = F.normalize(proj_1_13_ng, dim=1)
+ proj_1_13_ng = proj_1_13_ng.reshape((N, L, -1))
+ projs_1_13_ng[i] = proj_1_13_ng
+
+ projs_1_13_ng = torch.cat(projs_1_13_ng, dim=1) # N, P * L, C
+
+
+ feats_2_ng = self.encoder_k(im_2)
+ fpn_feats_2_ng = self.neck_k(feats_2_ng) # p2, p3, p4, p5, num_outs = 4
+ assert len(fpn_feats_2_ng) == 4
+ projs_2_ng = [None] * len(fpn_feats_2_ng)
+ for i, feat_2_ng in enumerate(fpn_feats_2_ng):
+ feat_roi_2_ng = self.roi_align_feature_map(feat_2_ng, bboxs2)
+ feat_vec_2_ng = self.head_k(feat_roi_2_ng)
+ proj_2_ng = self.projector_k(feat_vec_2_ng)
+ proj_2_ng = F.normalize(proj_2_ng, dim=1)
+ proj_2_ng = proj_2_ng.reshape((N, L, -1))
+ projs_2_ng[i] = proj_2_ng
+
+ projs_2_ng = torch.cat(projs_2_ng, dim=1) # N, P * L, C
+
+
+ feats_3_ng = self.encoder_k(im_3)
+ fpn_feats_3_ng = self.neck_k(feats_3_ng) # p2, p3, p4, p5, num_outs = 4
+ assert len(fpn_feats_3_ng) == 4
+ projs_3_ng = [None] * len(fpn_feats_3_ng)
+ for i, feat_3_ng in enumerate(fpn_feats_3_ng):
+ feat_roi_3_ng = self.roi_align_feature_map(feat_3_ng, bboxs3)
+ feat_vec_3_ng = self.head_k(feat_roi_3_ng)
+ proj_3_ng = self.projector_k(feat_vec_3_ng)
+ proj_3_ng = F.normalize(proj_3_ng, dim=1)
+ proj_3_ng = proj_3_ng.reshape((N, L, -1))
+ projs_3_ng[i] = proj_3_ng
+
+ projs_3_ng = torch.cat(projs_3_ng, dim=1) # N, P * L, C
+
+
+ # compute loss
+ corres_12_2to1 = corres_12.transpose(1, 2) # transpose dim 1 dim 2, map bboxs2 to bboxs1
+ corres_13_3to1 = corres_13.transpose(1, 2) # transpose dim 1 dim 2, map bboxs3 to bboxs1
+ loss_bbox_aware_12 = self.regression_loss_bboxs_aware(preds_1_12, projs_2_ng, corres_12) + self.regression_loss_bboxs_aware(preds_2, projs_1_12_ng, corres_12_2to1)
+ loss_bbox_aware_13 = self.regression_loss_bboxs_aware(preds_1_13, projs_3_ng, corres_13) + self.regression_loss_bboxs_aware(preds_3, projs_1_13_ng, corres_13_3to1)
+
+ loss = loss_bbox_aware_12 + loss_bbox_aware_13
+
+ return loss
diff --git a/contrast/models/SoCo_FPN_Star.py b/contrast/models/SoCo_FPN_Star.py
new file mode 100644
index 0000000..3fa9081
--- /dev/null
+++ b/contrast/models/SoCo_FPN_Star.py
@@ -0,0 +1,343 @@
+# --------------------------------------------------------
+# SoCo
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Yue Gao
+# --------------------------------------------------------
+
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.ops as tvops
+from torch.distributed import get_world_size
+
+from .base import BaseModel
+from .box_util import append_batch_index_to_bboxs_and_scale
+from .fast_rcnn_conv_fc_head import FastRCNNConvFCHead
+from .fpn import FPN
+from .mlps import Pred_Head, Proj_Head
+
+
+class SoCo_FPN_Star(BaseModel):
+ """
+    SoCo* pretraining model (FPN variant with an extra downsampled view): an
+    online encoder, FPN and R-CNN head with momentum-updated key counterparts,
+    adapted from MoCo: https://arxiv.org/abs/1911.05722
+ """
+
+ def __init__(self, base_encoder, args):
+ """
+ dim: feature dimension (default: 128)
+ K: queue size; number of negative keys (default: 65536)
+ m: moco momentum of updating key encoder (default: 0.999)
+ T: softmax temperature (default: 0.07)
+ Just cleaned unused forward tensors
+ """
+ super(SoCo_FPN_Star, self).__init__(base_encoder, args)
+
+ self.contrast_num_negative = args.contrast_num_negative
+ self.contrast_momentum = args.contrast_momentum
+ self.contrast_temperature = args.contrast_temperature
+ self.output_size = args.output_size
+ self.aligned = args.aligned
+
+ norm_cfg = dict(type='BN', requires_grad=True)
+
+ # align to detectron2 !
+ # create the encoder
+ self.encoder = base_encoder(low_dim=args.feature_dim, head_type='multi_layer')
+ self.neck = FPN(in_channels=args.in_channels, out_channels=args.out_channels, num_outs=args.num_outs, start_level=args.start_level,
+ end_level=args.end_level, add_extra_convs=args.add_extra_convs, extra_convs_on_inputs=args.extra_convs_on_inputs,
+ relu_before_extra_convs=args.relu_before_extra_convs, norm_cfg=norm_cfg)
+ self.head = FastRCNNConvFCHead()
+ self.projector = Proj_Head(in_dim=1024) # head channel
+ self.predictor = Pred_Head()
+
+ # create the encoder_k
+ self.encoder_k = base_encoder(low_dim=args.feature_dim, head_type='multi_layer')
+ self.neck_k = FPN(in_channels=args.in_channels, out_channels=args.out_channels, num_outs=args.num_outs, start_level=args.start_level,
+ end_level=args.end_level, add_extra_convs=args.add_extra_convs, extra_convs_on_inputs=args.extra_convs_on_inputs,
+ relu_before_extra_convs=args.relu_before_extra_convs, norm_cfg=norm_cfg)
+ self.head_k = FastRCNNConvFCHead()
+ self.projector_k = Proj_Head(in_dim=1024) # head channel
+
+ self.roi_avg_pool = nn.AvgPool2d(self.output_size, stride=1)
+
+ for param_q, param_k in zip(self.encoder.parameters(), self.encoder_k.parameters()):
+ param_k.data.copy_(param_q.data) # initialize
+ param_k.requires_grad = False # not update by gradient
+
+ for param_q, param_k in zip(self.neck.parameters(), self.neck_k.parameters()):
+ param_k.data.copy_(param_q.data) # initialize
+ param_k.requires_grad = False # not update by gradient
+
+ for param_q, param_k in zip(self.head.parameters(), self.head_k.parameters()):
+ param_k.data.copy_(param_q.data) # initialize
+ param_k.requires_grad = False # not update by gradient
+
+ for param_q, param_k in zip(self.projector.parameters(), self.projector_k.parameters()):
+ param_k.data.copy_(param_q.data)
+ param_k.requires_grad = False
+
+ nn.SyncBatchNorm.convert_sync_batchnorm(self.encoder)
+ nn.SyncBatchNorm.convert_sync_batchnorm(self.encoder_k)
+ nn.SyncBatchNorm.convert_sync_batchnorm(self.neck)
+ nn.SyncBatchNorm.convert_sync_batchnorm(self.neck_k)
+ nn.SyncBatchNorm.convert_sync_batchnorm(self.head)
+ nn.SyncBatchNorm.convert_sync_batchnorm(self.head_k)
+ nn.SyncBatchNorm.convert_sync_batchnorm(self.projector)
+ nn.SyncBatchNorm.convert_sync_batchnorm(self.projector_k)
+ nn.SyncBatchNorm.convert_sync_batchnorm(self.predictor)
+
+ self.K = int(args.num_instances * 1. / get_world_size() / args.batch_size * args.epochs)
+ self.k = int(args.num_instances * 1. / get_world_size() / args.batch_size * (args.start_epoch - 1))
+ # print('Initial', get_rank(), self.k, self.K)
+
+ @torch.no_grad()
+ def _momentum_update_key_encoder(self):
+ """
+ Momentum update of the key encoder
+ """
+ _contrast_momentum = 1. - (1. - self.contrast_momentum) * (np.cos(np.pi * self.k / self.K) + 1) / 2.
+ self.k = self.k + 1
+ # print('Update', get_rank(), self.k, _contrast_momentum)
+
+ for param_q, param_k in zip(self.encoder.parameters(), self.encoder_k.parameters()):
+ param_k.data = param_k.data * _contrast_momentum + param_q.data * (1. - _contrast_momentum)
+
+ for param_q, param_k in zip(self.neck.parameters(), self.neck_k.parameters()):
+ param_k.data = param_k.data * _contrast_momentum + param_q.data * (1. - _contrast_momentum)
+
+ for param_q, param_k in zip(self.head.parameters(), self.head_k.parameters()):
+ param_k.data = param_k.data * _contrast_momentum + param_q.data * (1. - _contrast_momentum)
+
+ for param_q, param_k in zip(self.projector.parameters(), self.projector_k.parameters()):
+ param_k.data = param_k.data * _contrast_momentum + param_q.data * (1. - _contrast_momentum)
+
+
+ def regression_loss_bboxs_aware(self, vectors_q, vectors_k, correspondence):
+ N, M, C = vectors_q.shape # N, P * L, C
+ N, M1, M2 = correspondence.shape # N, P * L, P * L
+ assert M == M1 == M2
+        # vectors_q: N, P * L, C
+        # vectors_k: N, P * L, C
+ vectors_k = torch.transpose(vectors_k, 1, 2) # N, C, P * L
+ sim = torch.bmm(vectors_q, vectors_k) # N, P * L, P * L
+ loss = (sim * correspondence).sum(-1).sum(-1) / (correspondence.sum(-1).sum(-1) + 1e-6)
+ return -2 * loss.mean()
+
+ def roi_align_feature_map(self, feature_map, bboxs):
+        feature_map = feature_map.type(dtype=bboxs.dtype)  # under amp the feature map may be half precision; cast it to the bbox dtype for roi_align
+ N, C, H, W = feature_map.shape
+ N, L, _ = bboxs.shape
+
+ output_size = (self.output_size, self.output_size)
+
+ bboxs_q_with_batch_index = append_batch_index_to_bboxs_and_scale(bboxs, H, W)
+ aligned_features = tvops.roi_align(input=feature_map, boxes=bboxs_q_with_batch_index, output_size=output_size, aligned=self.aligned)
+ # N*L, C, output_size, output_size
+ return aligned_features
+
+ def forward(self, im_1, im_2, im_3, im_4, bboxs1_12, bboxs1_13, bboxs1_14, bboxs2, bboxs3, bboxs4, corres_12, corres_13, corres_14):
+ """
+ Input:
+ im_q: a batch of query images
+ im_k: a batch of key images
+ Output:
+ logits, targets
+ """
+ # compute query features
+ N, L, _ = bboxs1_12.shape
+ feats_1 = self.encoder(im_1)
+ fpn_feats_1 = self.neck(feats_1) # p2, p3, p4, p5, num_outs = 4
+ assert len(fpn_feats_1) == 4
+
+ preds_1_12 = [None] * len(fpn_feats_1)
+ for i, feat_1_12 in enumerate(fpn_feats_1):
+ feat_roi_1_12 = self.roi_align_feature_map(feat_1_12, bboxs1_12)
+ feat_vec_1_12 = self.head(feat_roi_1_12)
+ proj_1_12 = self.projector(feat_vec_1_12)
+ pred_1_12 = self.predictor(proj_1_12)
+ pred_1_12 = F.normalize(pred_1_12, dim=1) # N * L, C
+ pred_1_12 = pred_1_12.reshape((N, L, -1)) # N, L, C
+ preds_1_12[i] = pred_1_12
+
+ preds_1_12 = torch.cat(preds_1_12, dim=1) # N, P * L, C
+
+
+ preds_1_13 = [None] * len(fpn_feats_1)
+ for i, feat_1_13 in enumerate(fpn_feats_1):
+ feat_roi_1_13 = self.roi_align_feature_map(feat_1_13, bboxs1_13)
+ feat_vec_1_13 = self.head(feat_roi_1_13)
+ proj_1_13 = self.projector(feat_vec_1_13)
+ pred_1_13 = self.predictor(proj_1_13)
+ pred_1_13 = F.normalize(pred_1_13, dim=1) # N * L, C
+ pred_1_13 = pred_1_13.reshape((N, L, -1)) # N, L, C
+ preds_1_13[i] = pred_1_13
+
+ preds_1_13 = torch.cat(preds_1_13, dim=1) # N, P * L, C
+
+
+ preds_1_14 = [None] * len(fpn_feats_1)
+ for i, feat_1_14 in enumerate(fpn_feats_1):
+ feat_roi_1_14 = self.roi_align_feature_map(feat_1_14, bboxs1_14)
+ feat_vec_1_14 = self.head(feat_roi_1_14)
+ proj_1_14 = self.projector(feat_vec_1_14)
+ pred_1_14 = self.predictor(proj_1_14)
+ pred_1_14 = F.normalize(pred_1_14, dim=1) # N * L, C
+ pred_1_14 = pred_1_14.reshape((N, L, -1)) # N, L, C
+ preds_1_14[i] = pred_1_14
+
+ preds_1_14 = torch.cat(preds_1_14, dim=1) # N, P * L, C
+
+
+ feats_2 = self.encoder(im_2)
+ fpn_feats_2 = self.neck(feats_2) # p2, p3, p4, p5, num_outs = 4
+ assert len(fpn_feats_2) == 4
+
+ preds_2 = [None] * len(fpn_feats_2)
+ for i, feat_2 in enumerate(fpn_feats_2):
+ feat_roi_2 = self.roi_align_feature_map(feat_2, bboxs2)
+ feat_vec_2 = self.head(feat_roi_2)
+ proj_2 = self.projector(feat_vec_2)
+ pred_2 = self.predictor(proj_2)
+ pred_2 = F.normalize(pred_2, dim=1)
+ pred_2 = pred_2.reshape((N, L, -1)) # N, L, C
+ preds_2[i] = pred_2
+
+ preds_2 = torch.cat(preds_2, dim=1) # N, P * L, C
+
+
+ feats_3 = self.encoder(im_3)
+ fpn_feats_3 = self.neck(feats_3) # p2, p3, p4, p5, num_outs = 4
+ assert len(fpn_feats_3) == 4
+
+ preds_3 = [None] * len(fpn_feats_3)
+ for i, feat_3 in enumerate(fpn_feats_3):
+ feat_roi_3 = self.roi_align_feature_map(feat_3, bboxs3)
+ feat_vec_3 = self.head(feat_roi_3)
+ proj_3 = self.projector(feat_vec_3)
+ pred_3 = self.predictor(proj_3)
+ pred_3 = F.normalize(pred_3, dim=1)
+ pred_3 = pred_3.reshape((N, L, -1)) # N, L, C
+ preds_3[i] = pred_3
+
+ preds_3 = torch.cat(preds_3, dim=1) # N, P * L, C
+
+
+ feats_4 = self.encoder(im_4)
+ fpn_feats_4 = self.neck(feats_4) # p2, p3, p4, p5, num_outs = 4
+ assert len(fpn_feats_4) == 4
+
+ preds_4 = [None] * len(fpn_feats_4)
+ for i, feat_4 in enumerate(fpn_feats_4):
+ feat_roi_4 = self.roi_align_feature_map(feat_4, bboxs4)
+ feat_vec_4 = self.head(feat_roi_4)
+ proj_4 = self.projector(feat_vec_4)
+ pred_4 = self.predictor(proj_4)
+ pred_4 = F.normalize(pred_4, dim=1)
+ pred_4 = pred_4.reshape((N, L, -1)) # N, L, C
+ preds_4[i] = pred_4
+
+ preds_4 = torch.cat(preds_4, dim=1) # N, P * L, C
+
+
+ # compute key features
+ with torch.no_grad(): # no gradient to keys
+ self._momentum_update_key_encoder() # update the key encoder
+
+ feats_1_ng = self.encoder_k(im_1)
+ fpn_feats_1_ng = self.neck_k(feats_1_ng) # p2, p3, p4, p5, num_outs = 4
+ assert len(fpn_feats_1_ng) == 4
+ projs_1_12_ng = [None] * len(fpn_feats_1_ng)
+ for i, feat_1_12_ng in enumerate(fpn_feats_1_ng):
+ feat_roi_1_12_ng = self.roi_align_feature_map(feat_1_12_ng, bboxs1_12)
+ feat_vec_1_12_ng = self.head_k(feat_roi_1_12_ng)
+ proj_1_12_ng = self.projector_k(feat_vec_1_12_ng)
+ proj_1_12_ng = F.normalize(proj_1_12_ng, dim=1)
+ proj_1_12_ng = proj_1_12_ng.reshape((N, L, -1))
+ projs_1_12_ng[i] = proj_1_12_ng
+
+ projs_1_12_ng = torch.cat(projs_1_12_ng, dim=1) # N, P * L, C
+
+
+ projs_1_13_ng = [None] * len(fpn_feats_1_ng)
+ for i, feat_1_13_ng in enumerate(fpn_feats_1_ng):
+ feat_roi_1_13_ng = self.roi_align_feature_map(feat_1_13_ng, bboxs1_13)
+ feat_vec_1_13_ng = self.head_k(feat_roi_1_13_ng)
+ proj_1_13_ng = self.projector_k(feat_vec_1_13_ng)
+ proj_1_13_ng = F.normalize(proj_1_13_ng, dim=1)
+ proj_1_13_ng = proj_1_13_ng.reshape((N, L, -1))
+ projs_1_13_ng[i] = proj_1_13_ng
+
+ projs_1_13_ng = torch.cat(projs_1_13_ng, dim=1) # N, P * L, C
+
+
+ projs_1_14_ng = [None] * len(fpn_feats_1_ng)
+ for i, feat_1_14_ng in enumerate(fpn_feats_1_ng):
+ feat_roi_1_14_ng = self.roi_align_feature_map(feat_1_14_ng, bboxs1_14)
+ feat_vec_1_14_ng = self.head_k(feat_roi_1_14_ng)
+ proj_1_14_ng = self.projector_k(feat_vec_1_14_ng)
+ proj_1_14_ng = F.normalize(proj_1_14_ng, dim=1)
+ proj_1_14_ng = proj_1_14_ng.reshape((N, L, -1))
+ projs_1_14_ng[i] = proj_1_14_ng
+
+ projs_1_14_ng = torch.cat(projs_1_14_ng, dim=1) # N, P * L, C
+
+
+ feats_2_ng = self.encoder_k(im_2)
+ fpn_feats_2_ng = self.neck_k(feats_2_ng) # p2, p3, p4, p5, num_outs = 4
+ assert len(fpn_feats_2_ng) == 4
+ projs_2_ng = [None] * len(fpn_feats_2_ng)
+ for i, feat_2_ng in enumerate(fpn_feats_2_ng):
+ feat_roi_2_ng = self.roi_align_feature_map(feat_2_ng, bboxs2)
+ feat_vec_2_ng = self.head_k(feat_roi_2_ng)
+ proj_2_ng = self.projector_k(feat_vec_2_ng)
+ proj_2_ng = F.normalize(proj_2_ng, dim=1)
+ proj_2_ng = proj_2_ng.reshape((N, L, -1))
+ projs_2_ng[i] = proj_2_ng
+
+ projs_2_ng = torch.cat(projs_2_ng, dim=1) # N, P * L, C
+
+
+ feats_3_ng = self.encoder_k(im_3)
+ fpn_feats_3_ng = self.neck_k(feats_3_ng) # p2, p3, p4, p5, num_outs = 4
+ assert len(fpn_feats_3_ng) == 4
+ projs_3_ng = [None] * len(fpn_feats_3_ng)
+ for i, feat_3_ng in enumerate(fpn_feats_3_ng):
+ feat_roi_3_ng = self.roi_align_feature_map(feat_3_ng, bboxs3)
+ feat_vec_3_ng = self.head_k(feat_roi_3_ng)
+ proj_3_ng = self.projector_k(feat_vec_3_ng)
+ proj_3_ng = F.normalize(proj_3_ng, dim=1)
+ proj_3_ng = proj_3_ng.reshape((N, L, -1))
+ projs_3_ng[i] = proj_3_ng
+
+ projs_3_ng = torch.cat(projs_3_ng, dim=1) # N, P * L, C
+
+
+ feats_4_ng = self.encoder_k(im_4)
+ fpn_feats_4_ng = self.neck_k(feats_4_ng) # p2, p3, p4, p5, num_outs = 4
+ assert len(fpn_feats_4_ng) == 4
+ projs_4_ng = [None] * len(fpn_feats_4_ng)
+ for i, feat_4_ng in enumerate(fpn_feats_4_ng):
+ feat_roi_4_ng = self.roi_align_feature_map(feat_4_ng, bboxs4)
+ feat_vec_4_ng = self.head_k(feat_roi_4_ng)
+ proj_4_ng = self.projector_k(feat_vec_4_ng)
+ proj_4_ng = F.normalize(proj_4_ng, dim=1)
+ proj_4_ng = proj_4_ng.reshape((N, L, -1))
+ projs_4_ng[i] = proj_4_ng
+
+ projs_4_ng = torch.cat(projs_4_ng, dim=1) # N, P * L, C
+
+ # compute loss
+ corres_12_2to1 = corres_12.transpose(1, 2) # transpose dim 1 dim 2, map bboxs2 to bboxs1
+ corres_13_3to1 = corres_13.transpose(1, 2) # transpose dim 1 dim 2, map bboxs3 to bboxs1
+ corres_14_4to1 = corres_14.transpose(1, 2) # transpose dim 1 dim 2, map bboxs4 to bboxs1
+ loss_bbox_aware_12 = self.regression_loss_bboxs_aware(preds_1_12, projs_2_ng, corres_12) + self.regression_loss_bboxs_aware(preds_2, projs_1_12_ng, corres_12_2to1)
+ loss_bbox_aware_13 = self.regression_loss_bboxs_aware(preds_1_13, projs_3_ng, corres_13) + self.regression_loss_bboxs_aware(preds_3, projs_1_13_ng, corres_13_3to1)
+ loss_bbox_aware_14 = self.regression_loss_bboxs_aware(preds_1_14, projs_4_ng, corres_14) + self.regression_loss_bboxs_aware(preds_4, projs_1_14_ng, corres_14_4to1)
+
+ loss = loss_bbox_aware_12 + loss_bbox_aware_13 + loss_bbox_aware_14
+
+ return loss
diff --git a/contrast/models/__init__.py b/contrast/models/__init__.py
new file mode 100644
index 0000000..ec53b68
--- /dev/null
+++ b/contrast/models/__init__.py
@@ -0,0 +1,13 @@
+# --------------------------------------------------------
+# SoCo
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Yue Gao
+# --------------------------------------------------------
+
+
+from .SoCo_C4 import SoCo_C4
+from .SoCo_FPN import SoCo_FPN
+from .SoCo_FPN_Star import SoCo_FPN_Star
+
+__all__ = []
diff --git a/contrast/models/base.py b/contrast/models/base.py
new file mode 100644
index 0000000..ade293d
--- /dev/null
+++ b/contrast/models/base.py
@@ -0,0 +1,29 @@
+# --------------------------------------------------------
+# SoCo
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Yue Gao
+# --------------------------------------------------------
+
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BaseModel(nn.Module):
+ """
+    Base model with an encoder.
+ """
+
+ def __init__(self, base_encoder, args):
+ super(BaseModel, self).__init__()
+
+ # create the encoders
+ self.encoder = base_encoder(low_dim=args.feature_dim, head_type=args.head_type)
+
+ def forward(self, x1, x2):
+ """
+ Input: x1, x2 or x, y_idx
+ Output: logits, labels
+ """
+ raise NotImplementedError
diff --git a/contrast/models/box_util.py b/contrast/models/box_util.py
new file mode 100644
index 0000000..0a418e3
--- /dev/null
+++ b/contrast/models/box_util.py
@@ -0,0 +1,25 @@
+# --------------------------------------------------------
+# SoCo
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Yue Gao
+# --------------------------------------------------------
+
+import torch
+
+
+def append_batch_index_to_bboxs_and_scale(bboxs, H, W):
+ N, L, _ = bboxs.size()
+ batch_index = torch.arange(N, device=bboxs.device)
+ batch_index = batch_index.reshape(N, 1)
+ batch_index = batch_index.repeat(1, L)
+ batch_index = batch_index.reshape(N, L, 1)
+
+ bboxs_coord = bboxs[:, :, :4].clone() # must clone
+ bboxs_coord[:, :, 0] = bboxs_coord[:, :, 0] * W # x1
+ bboxs_coord[:, :, 2] = bboxs_coord[:, :, 2] * W # x2
+ bboxs_coord[:, :, 1] = bboxs_coord[:, :, 1] * H # y1
+ bboxs_coord[:, :, 3] = bboxs_coord[:, :, 3] * H # y2
+ bboxs_with_index = torch.cat([batch_index, bboxs_coord], dim=2)
+ bboxs_with_index = bboxs_with_index.reshape(N*L, 5)
+ return bboxs_with_index
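+
+
+if __name__ == '__main__':
+    # Minimal shape check (illustrative): boxes arrive as normalised
+    # (x1, y1, x2, y2, ...) coordinates; they are scaled to the H x W feature
+    # map and prefixed with the batch index, the (N*L, 5) layout expected by
+    # torchvision.ops.roi_align.
+    boxes = torch.rand(2, 4, 5)  # N=2 images, L=4 boxes, 4 coords + 1 extra field
+    rois = append_batch_index_to_bboxs_and_scale(boxes, 56, 56)
+    print(rois.shape)  # torch.Size([8, 5])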
diff --git a/contrast/models/fast_rcnn_conv_fc_head.py b/contrast/models/fast_rcnn_conv_fc_head.py
new file mode 100644
index 0000000..b04bb98
--- /dev/null
+++ b/contrast/models/fast_rcnn_conv_fc_head.py
@@ -0,0 +1,56 @@
+# --------------------------------------------------------
+# SoCo
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Yue Gao
+# --------------------------------------------------------
+
+
+import torch.nn as nn
+
+
+class FastRCNNConvFCHead(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv1 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+ self.bn1 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+ self.ac1 = nn.ReLU()
+
+ self.conv2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+ self.bn2 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+ self.ac2 = nn.ReLU()
+
+ self.conv3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+ self.bn3 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+ self.ac3 = nn.ReLU()
+
+ self.conv4 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+ self.bn4 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+ self.ac4 = nn.ReLU()
+
+ self.flatten = nn.Flatten()
+ self.fc1 = nn.Linear(in_features=12544, out_features=1024, bias=True)
+ self.fc_relu1 = nn.ReLU()
+
+ def forward(self, roi_feature_map):
+ conv1_out = self.conv1(roi_feature_map)
+ bn1_out = self.bn1(conv1_out)
+ ac1_out = self.ac1(bn1_out)
+
+ conv2_out = self.conv2(ac1_out)
+ bn2_out = self.bn2(conv2_out)
+ ac2_out = self.ac2(bn2_out)
+
+ conv3_out = self.conv3(ac2_out)
+ bn3_out = self.bn3(conv3_out)
+ ac3_out = self.ac3(bn3_out)
+
+ conv4_out = self.conv4(ac3_out)
+ bn4_out = self.bn4(conv4_out)
+ ac4_out = self.ac4(bn4_out)
+
+ flat = self.flatten(ac4_out)
+ fc1_out = self.fc1(flat)
+ fc_relu1_out = self.fc_relu1(fc1_out)
+
+ return fc_relu1_out
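+
+
+if __name__ == '__main__':
+    # Minimal shape check (illustrative): the head expects 7x7 RoI features
+    # with 256 channels (256 * 7 * 7 = 12544, the fc1 input size) and maps
+    # each RoI to a 1024-d vector.
+    import torch
+
+    head = FastRCNNConvFCHead()
+    rois = torch.randn(8, 256, 7, 7)  # e.g. N * L = 8 RoIs
+    print(head(rois).shape)  # torch.Size([8, 1024])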
diff --git a/contrast/models/fpn.py b/contrast/models/fpn.py
new file mode 100644
index 0000000..28d31ef
--- /dev/null
+++ b/contrast/models/fpn.py
@@ -0,0 +1,172 @@
+# --------------------------------------------------------
+# SoCo
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Yue Gao
+# --------------------------------------------------------
+
+
+import warnings
+
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, xavier_init
+
+
+class FPN(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_outs,
+ start_level=0,
+ end_level=-1,
+ add_extra_convs=False,
+ extra_convs_on_inputs=True,
+ relu_before_extra_convs=False,
+ no_norm_on_lateral=False,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=None,
+ upsample_cfg=dict(mode='nearest')):
+
+ super(FPN, self).__init__()
+ assert isinstance(in_channels, list)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_ins = len(in_channels)
+ self.num_outs = num_outs
+ self.relu_before_extra_convs = relu_before_extra_convs
+ self.no_norm_on_lateral = no_norm_on_lateral
+ self.fp16_enabled = False
+ self.upsample_cfg = upsample_cfg.copy()
+
+ if end_level == -1:
+ self.backbone_end_level = self.num_ins
+ assert num_outs >= self.num_ins - start_level
+ else:
+ # if end_level < inputs, no extra level is allowed
+ self.backbone_end_level = end_level
+ assert end_level <= len(in_channels)
+ assert num_outs == end_level - start_level
+ self.start_level = start_level
+ self.end_level = end_level
+ self.add_extra_convs = add_extra_convs
+ assert isinstance(add_extra_convs, (str, bool))
+ if isinstance(add_extra_convs, str):
+ # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
+ assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
+ elif add_extra_convs: # True
+ if extra_convs_on_inputs:
+ # TODO: deprecate `extra_convs_on_inputs`
+ warnings.simplefilter('once')
+ warnings.warn(
+ '"extra_convs_on_inputs" will be deprecated in v2.9.0,'
+ 'Please use "add_extra_convs"', DeprecationWarning)
+ self.add_extra_convs = 'on_input'
+ else:
+ self.add_extra_convs = 'on_output'
+
+ self.lateral_convs = nn.ModuleList()
+ self.fpn_convs = nn.ModuleList()
+
+ for i in range(self.start_level, self.backbone_end_level):
+ l_conv = ConvModule(
+ in_channels[i],
+ out_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
+ act_cfg=act_cfg,
+ inplace=False)
+ fpn_conv = ConvModule(
+ out_channels,
+ out_channels,
+ 3,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+
+ self.lateral_convs.append(l_conv)
+ self.fpn_convs.append(fpn_conv)
+
+ # add extra conv layers (e.g., RetinaNet)
+ extra_levels = num_outs - self.backbone_end_level + self.start_level
+ if self.add_extra_convs and extra_levels >= 1:
+ for i in range(extra_levels):
+ if i == 0 and self.add_extra_convs == 'on_input':
+ in_channels = self.in_channels[self.backbone_end_level - 1]
+ else:
+ in_channels = out_channels
+ extra_fpn_conv = ConvModule(
+ in_channels,
+ out_channels,
+ 3,
+ stride=2,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+ self.fpn_convs.append(extra_fpn_conv)
+
+ # default init_weights for conv(msra) and norm in ConvModule
+ def init_weights(self):
+ """Initialize the weights of FPN module."""
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ xavier_init(m, distribution='uniform')
+
+ def forward(self, inputs):
+ """Forward function."""
+ assert len(inputs) == len(self.in_channels)
+
+ # build laterals
+ laterals = [
+ lateral_conv(inputs[i + self.start_level])
+ for i, lateral_conv in enumerate(self.lateral_convs)
+ ]
+
+ # build top-down path
+ used_backbone_levels = len(laterals)
+ for i in range(used_backbone_levels - 1, 0, -1):
+ # In some cases, fixing `scale factor` (e.g. 2) is preferred, but
+ # it cannot co-exist with `size` in `F.interpolate`.
+ if 'scale_factor' in self.upsample_cfg:
+ laterals[i - 1] += F.interpolate(laterals[i],
+ **self.upsample_cfg)
+ else:
+ prev_shape = laterals[i - 1].shape[2:]
+ laterals[i - 1] += F.interpolate(
+ laterals[i], size=prev_shape, **self.upsample_cfg)
+
+ # build outputs
+ # part 1: from original levels
+ outs = [
+ self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
+ ]
+ # part 2: add extra levels
+ if self.num_outs > len(outs):
+ # use max pool to get more levels on top of outputs
+ # (e.g., Faster R-CNN, Mask R-CNN)
+ if not self.add_extra_convs:
+ for i in range(self.num_outs - used_backbone_levels):
+ outs.append(F.max_pool2d(outs[-1], 1, stride=2))
+ # add conv layers on top of original feature maps (RetinaNet)
+ else:
+ if self.add_extra_convs == 'on_input':
+ extra_source = inputs[self.backbone_end_level - 1]
+ elif self.add_extra_convs == 'on_lateral':
+ extra_source = laterals[-1]
+ elif self.add_extra_convs == 'on_output':
+ extra_source = outs[-1]
+ else:
+ raise NotImplementedError
+ outs.append(self.fpn_convs[used_backbone_levels](extra_source))
+ for i in range(used_backbone_levels + 1, self.num_outs):
+ if self.relu_before_extra_convs:
+ outs.append(self.fpn_convs[i](F.relu(outs[-1])))
+ else:
+ outs.append(self.fpn_convs[i](outs[-1]))
+ return tuple(outs)
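+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative): feed ResNet-50 stage outputs
+    # (C2..C5 channel sizes) and obtain `num_outs` 256-channel pyramid levels.
+    # The batch size and spatial sizes below are hypothetical.
+    import torch
+
+    fpn = FPN(in_channels=[256, 512, 1024, 2048], out_channels=256, num_outs=4,
+              start_level=0, end_level=-1, add_extra_convs=False,
+              norm_cfg=dict(type='BN', requires_grad=True))
+    feats = [torch.randn(2, c, s, s)
+             for c, s in zip([256, 512, 1024, 2048], [56, 28, 14, 7])]
+    print([tuple(o.shape) for o in fpn(feats)])  # four 256-channel maps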
diff --git a/contrast/models/mlps.py b/contrast/models/mlps.py
new file mode 100644
index 0000000..d78aede
--- /dev/null
+++ b/contrast/models/mlps.py
@@ -0,0 +1,100 @@
+# --------------------------------------------------------
+# SoCo
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Yue Gao
+# --------------------------------------------------------
+
+
+import torch.nn as nn
+
+
+class MLP(nn.Module):
+ def __init__(self, in_dim, inner_dim=4096, out_dim=256):
+ super(MLP, self).__init__()
+
+ self.linear1 = nn.Linear(in_dim, inner_dim)
+ self.bn1 = nn.BatchNorm1d(inner_dim)
+ self.relu1 = nn.ReLU(inplace=True)
+
+ self.linear2 = nn.Linear(inner_dim, out_dim)
+
+ def forward(self, x):
+ x = self.linear1(x)
+ x = x.unsqueeze(-1)
+ x = self.bn1(x)
+ x = x.squeeze(-1)
+ x = self.relu1(x)
+
+ x = self.linear2(x)
+
+ return x
+
+
+def conv1x1(in_planes, out_planes):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=True)
+
+
+class MLP2d(nn.Module):
+ def __init__(self, in_dim, inner_dim=4096, out_dim=256):
+ super(MLP2d, self).__init__()
+
+ self.linear1 = conv1x1(in_dim, inner_dim)
+ self.bn1 = nn.BatchNorm2d(inner_dim)
+ self.relu1 = nn.ReLU(inplace=True)
+
+ self.linear2 = conv1x1(inner_dim, out_dim)
+
+ def forward(self, x):
+ x = self.linear1(x)
+ x = self.bn1(x)
+ x = self.relu1(x)
+
+ x = self.linear2(x)
+
+ return x
+
+
+class MLP2d_3Layer(nn.Module):
+ def __init__(self, in_dim, inner_dim=4096, out_dim=256):
+ super(MLP2d_3Layer, self).__init__()
+
+ self.linear1 = conv1x1(in_dim, inner_dim)
+ self.bn1 = nn.BatchNorm2d(inner_dim)
+ self.relu1 = nn.ReLU(inplace=True)
+
+ self.linear2 = conv1x1(inner_dim, inner_dim)
+ self.bn2 = nn.BatchNorm2d(inner_dim)
+ self.relu2 = nn.ReLU(inplace=True)
+
+ self.linear3 = conv1x1(inner_dim, out_dim)
+
+ def forward(self, x):
+ x = self.linear1(x)
+ x = self.bn1(x)
+ x = self.relu1(x)
+
+ x = self.linear2(x)
+ x = self.bn2(x)
+ x = self.relu2(x)
+
+ x = self.linear3(x)
+
+ return x
+
+
+def Proj_Head(in_dim=2048, inner_dim=4096, out_dim=256):
+ return MLP(in_dim, inner_dim, out_dim)
+
+
+def Pred_Head(in_dim=256, inner_dim=4096, out_dim=256):
+ return MLP(in_dim, inner_dim, out_dim)
+
+
+def Proj_Head2d(in_dim=2048, inner_dim=4096, out_dim=256):
+ return MLP2d(in_dim, inner_dim, out_dim)
+
+
+def Pred_Head2d(in_dim=256, inner_dim=4096, out_dim=256):
+ return MLP2d(in_dim, inner_dim, out_dim)
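+
+
+if __name__ == '__main__':
+    # Minimal shape check (illustrative): the projector maps 1024-d head
+    # features (or 2048-d backbone features) to 256-d, and the predictor
+    # maps 256-d to 256-d.
+    import torch
+
+    proj, pred = Proj_Head(in_dim=1024), Pred_Head()
+    feats = torch.randn(8, 1024)
+    print(pred(proj(feats)).shape)  # torch.Size([8, 256])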
diff --git a/contrast/option.py b/contrast/option.py
new file mode 100644
index 0000000..1eaeb69
--- /dev/null
+++ b/contrast/option.py
@@ -0,0 +1,161 @@
+# --------------------------------------------------------
+# SoCo
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Yue Gao
+# --------------------------------------------------------
+
+
+import argparse
+import os
+
+from contrast import resnet
+from contrast.util import MyHelpFormatter
+
+model_names = sorted(name for name in resnet.__all__
+ if name.islower() and callable(resnet.__dict__[name]))
+
+
+def parse_option(stage='pre_train'):
+ parser = argparse.ArgumentParser(f'contrast {stage} stage', formatter_class=MyHelpFormatter)
+ # develop
+ parser.add_argument('--debug', action='store_true', help='enable debug mode')
+
+ # dataset
+    parser.add_argument('--data_dir', type=str, default='./data', help='dataset directory')
+ parser.add_argument('--crop', type=float, default=0.2 if stage == 'pre_train' else 0.08, help='minimum crop')
+    parser.add_argument('--crop1', type=float, default=1.0, help='minimum crop for view 1 when using asymmetric crops')
+ parser.add_argument('--aug', type=str, default='NULL', choices=['NULL', 'ImageAsymBboxCutout', 'ImageAsymBboxAwareMultiJitter1',
+ 'ImageAsymBboxAwareMultiJitter1Cutout', 'ImageAsymBboxAwareMulti3ResizeExtraJitter1'],
+ help='which augmentation to use.')
+ parser.add_argument('--zip', action='store_true', help='use zipped dataset')
+ parser.add_argument('--split_map', type=str, default='map')
+ parser.add_argument('--cache_mode', type=str, default='part', choices=['no', 'full', 'part'],
+ help='cache mode: no for no cache, full for cache all data, part for only cache part of data')
+ parser.add_argument('--dataset', type=str, default='ImageNet', choices=['ImageNet', 'VOC', 'COCO'], help='dataset type')
+ parser.add_argument('--ann_file', type=str, default='', help='annotation file')
+ parser.add_argument('--image_size', type=int, default=224, help='image crop size')
+    parser.add_argument('--image3_size', type=int, default=112, help='crop size for the third view')
+    parser.add_argument('--image4_size', type=int, default=112, help='crop size for the fourth view')
+
+ parser.add_argument('--num_workers', type=int, default=4, help='num of workers per GPU to use')
+ # sliding window sampler
+    parser.add_argument('--window_size', type=int, default=131072, help='window size in sliding window sampler')
+ parser.add_argument('--window_stride', type=int, default=16384, help='window stride in sliding window sampler')
+ parser.add_argument('--use_sliding_window_sampler', action='store_true',
+ help='whether to use sliding window sampler')
+ parser.add_argument('--shuffle_per_epoch', action='store_true',
+ help='shuffle indices in sliding window sampler per epoch')
+ if stage == 'linear':
+ parser.add_argument('--total_batch_size', type=int, default=256, help='total train batch size for all GPU')
+ else:
+ parser.add_argument('--batch_size', type=int, default=64, help='batch_size for single gpu')
+
+ # model
+ parser.add_argument('--arch', type=str, default='resnet50', choices=model_names,
+ help="backbone architecture")
+ if stage == 'pre_train':
+ parser.add_argument('--model', type=str, required=True, help='which model to use')
+ parser.add_argument('--contrast_temperature', type=float, default=0.07, help='temperature in instance cls loss')
+ parser.add_argument('--contrast_momentum', type=float, default=0.999,
+ help='momentum parameter used in MoCo and InstDisc')
+ parser.add_argument('--contrast_num_negative', type=int, default=65536,
+ help='number of negative samples used in MoCo and InstDisc')
+ parser.add_argument('--feature_dim', type=int, default=128, help='feature dimension')
+ parser.add_argument('--head_type', type=str, default='mlp_head', help='choose head type')
+ parser.add_argument('--lambda_img', type=float, default=0., help='loss weight of image_to_image loss')
+ parser.add_argument('--lambda_cross', type=float, default=1., help='loss weight of image_to_point loss')
+
+ # FPN default args
+ parser.add_argument('--in_channels', type=int, nargs='+', default=[256, 512, 1024, 2048], help='FPN feature map input channels')
+ parser.add_argument('--out_channels', type=int, default=256, help='FPN feature map output channels')
+ parser.add_argument('--start_level', type=int, default=1, help='FPN start level')
+ parser.add_argument('--end_level', type=int, default=-1, help='FPN end level')
+ parser.add_argument('--add_extra_convs', type=int, default=1, help='FPN add extra convs')
+ parser.add_argument('--extra_convs_on_inputs', type=int, default=0, help='FPN extra_convs_on_inputs')
+ parser.add_argument('--relu_before_extra_convs', type=int, default=1, help='FPN relu_before_extra_convs')
+ parser.add_argument('--no_norm_on_lateral', type=int, default=0, help='FPN no_norm_on_lateral')
+ parser.add_argument('--num_outs', type=int, default=3, help='FPN num_outs, use p3~p5')
+
+ # Head default args
+ parser.add_argument('--head_in_channels', type=int, default=256, help='Head feature map input channels')
+ parser.add_argument('--head_feat_channels', type=int, default=256, help='Head feature map feat channels')
+ parser.add_argument('--head_stacked_convs', type=int, default=4, help='Head stacked convs')
+
+ # optimization
+ if stage == 'pre_train':
+ parser.add_argument('--base_learning_rate', '--base_lr', type=float, default=0.03,
+ help='base learning rate when batch size is 256; the final lr is scaled linearly with the total batch size')
+ else:
+ parser.add_argument('--learning_rate', type=float, default=30, help='learning rate')
+ parser.add_argument('--optimizer', type=str, choices=['sgd', 'lars'], default='sgd',
+ help='for optimizer choice.')
+ parser.add_argument('--lr_scheduler', type=str, default='cosine',
+ choices=["step", "cosine"], help="learning rate scheduler")
+ parser.add_argument('--warmup_epoch', type=int, default=5, help='warmup epoch')
+ parser.add_argument('--warmup_multiplier', type=int, default=100, help='warmup multiplier')
+ parser.add_argument('--lr_decay_epochs', type=int, default=[120, 160, 200], nargs='+',
+ help='for step scheduler. where to decay lr, can be a list')
+ parser.add_argument('--lr_decay_rate', type=float, default=0.1,
+ help='for step scheduler. decay rate for learning rate')
+ parser.add_argument('--weight_decay', type=float, default=1e-4 if stage == 'pre_train' else 0, help='weight decay')
+ parser.add_argument('--momentum', type=float, default=0.9, help='momentum for SGD')
+ parser.add_argument('--amp_opt_level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
+ help='mixed precision opt level, if O0, no amp is used')
+ parser.add_argument('--start_epoch', type=int, default=1, help='used for resume')
+ parser.add_argument('--epochs', type=int, default=100, help='number of training epochs')
+
+ # misc
+ parser.add_argument('--output_dir', type=str, default='./output', help='output directory')
+ parser.add_argument('--auto_resume', action='store_true', help='auto resume from current.pth')
+ parser.add_argument('--resume', default='', type=str, metavar='PATH',
+ help='path to latest checkpoint')
+ parser.add_argument('--print_freq', type=int, default=100, help='print frequency')
+ parser.add_argument('--save_freq', type=int, default=10, help='save frequency')
+ parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
+ if stage == 'linear':
+ parser.add_argument('--pretrained_model', type=str, required=True, help="pretrained model path")
+ parser.add_argument('-e', '--eval', action='store_true', help='only evaluate')
+ else:
+ parser.add_argument('--pretrained_model', type=str, default="", help="pretrained model path")
+
+ # selective search
+ parser.add_argument('--ss_props', action='store_true', help='use selective search proposals to calculate the weight map')
+ parser.add_argument('--filter_strategy', type=str, default='none', help='filter strategy')
+ parser.add_argument('--select_strategy', type=str, default='none', help='select strategy')
+ parser.add_argument('--select_k', type=int, default=0, help='select strategy k, required when select strategy is not none')
+ parser.add_argument('--weight_strategy', type=str, default='bbox', help='weight map strategy')
+ # we use same select_strategy for weight map and bbox to ins
+ parser.add_argument('--bbox_size_range', type=tuple, default=(32, 112, 224))
+ parser.add_argument('--iou_thres', type=float, default=0.5)
+ parser.add_argument('--output_size', type=int, default=7)
+ parser.add_argument('--aligned', action='store_true')
+ parser.add_argument('--jitter_prob', type=float, default=0.5)
+ parser.add_argument('--jitter_ratio', type=float, default=0.2)
+ parser.add_argument('--padding_k', type=int, default=32)
+ parser.add_argument('--max_tries', type=int, default=5)
+ parser.add_argument('--aware_range', type=int, nargs='+', default=[48, 96, 192, 224, 0])
+ parser.add_argument('--aware_start', type=int, default=0, help="index of the first FPN level (P-level) to use")
+ parser.add_argument('--aware_end', type=int, default=4, help="index one past the last FPN level (P-level) to use (exclusive)")
+
+ parser.add_argument('--cutout_prob', type=float, default=0.5)
+ parser.add_argument('--cutout_ratio', type=tuple)
+ parser.add_argument('--cutout_ratio_min', type=float, default=0.1)
+ parser.add_argument('--cutout_ratio_max', type=float, default=0.2)
+
+ parser.add_argument('--max_props', type=int, default=32)
+ parser.add_argument('--aspect_ratio', type=float, default=3)
+ parser.add_argument('--min_size_ratio', type=float, default=0.3)
+ parser.add_argument('--max_size_ratio', type=float, default=0.8)
+
+ args = parser.parse_args()
+
+ if stage == 'pre_train':
+ # argparse cannot take bool values directly from the command line, so we use int: 0 -> False, 1 -> True
+ args.add_extra_convs = bool(args.add_extra_convs)
+ args.extra_convs_on_inputs = bool(args.extra_convs_on_inputs)
+ args.relu_before_extra_convs = bool(args.relu_before_extra_convs)
+ args.no_norm_on_lateral = bool(args.no_norm_on_lateral)
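+ # e.g. passing `--add_extra_convs 1 --no_norm_on_lateral 0` on the command line
+ # yields args.add_extra_convs == True and args.no_norm_on_lateral == False here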
+ args.cutout_ratio = (args.cutout_ratio_min, args.cutout_ratio_max)
+
+ return args
diff --git a/contrast/resnet.py b/contrast/resnet.py
new file mode 100644
index 0000000..ab1140f
--- /dev/null
+++ b/contrast/resnet.py
@@ -0,0 +1,341 @@
+import math
+import os
+
+import torch
+import torch.nn as nn
+import torchvision.ops as tvops
+
+from .models.box_util import append_batch_index_to_bboxs_and_scale
+
+
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
+ 'resnet18_d', 'resnet34_d', 'resnet50_d', 'resnet101_d', 'resnet152_d',
+ 'resnet50_16s', 'resnet50_w2x', 'resnext101_32x8d', 'resnext152_32x8d']
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+
+def conv3x3_bn_relu(in_planes, out_planes, stride=1):
+ return nn.Sequential(
+ conv3x3(in_planes, out_planes, stride),
+ nn.BatchNorm2d(out_planes),
+ nn.ReLU()
+ )
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
+ groups=1, base_width=64, dilation=-1):
+ super(BasicBlock, self).__init__()
+ if groups != 1 or base_width != 64:
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
+ groups=1, base_width=64, dilation=1):
+ super(Bottleneck, self).__init__()
+ width = int(planes * (base_width / 64.)) * groups
+ self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(width)
+ self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, dilation=dilation,
+ padding=dilation, groups=groups, bias=False)
+ self.bn2 = nn.BatchNorm2d(width)
+ self.conv3 = nn.Conv2d(width, planes * self.expansion, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+
+ def __init__(self, block, layers, in_channel=3, width=1,
+ groups=1, width_per_group=64,
+ mid_dim=1024, low_dim=128,
+ avg_down=False, deep_stem=False,
+ head_type='mlp_head', layer4_dilation=1, use_roi_align_on_c4=False,
+ pretrained=None):
+ super(ResNet, self).__init__()
+ self.avg_down = avg_down
+ self.inplanes = 64 * width
+ self.base = int(64 * width)
+ self.groups = groups
+ self.base_width = width_per_group
+ self.use_roi_align_on_c4 = use_roi_align_on_c4
+
+ mid_dim = self.base * 8 * block.expansion
+
+ if deep_stem:
+ self.conv1 = nn.Sequential(
+ conv3x3_bn_relu(in_channel, 32, stride=2),
+ conv3x3_bn_relu(32, 32, stride=1),
+ conv3x3(32, 64, stride=1)
+ )
+ else:
+ self.conv1 = nn.Conv2d(in_channel, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
+
+ self.bn1 = nn.BatchNorm2d(self.inplanes)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, self.base, layers[0])
+ self.layer2 = self._make_layer(block, self.base * 2, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, self.base * 4, layers[2], stride=2)
+ if layer4_dilation == 1:
+ self.layer4 = self._make_layer(block, self.base * 8, layers[3], stride=2)
+ elif layer4_dilation == 2:
+ self.layer4 = self._make_layer(block, self.base * 8, layers[3], stride=1, dilation=2)
+ else:
+ raise NotImplementedError
+ self.avgpool = nn.AvgPool2d(7, stride=1)
+ self.small_avgpool = nn.AvgPool2d(4, stride=1)
+
+ if self.use_roi_align_on_c4:
+ self.output_size = 14
+ self.aligned = True
+
+ self.head_type = head_type
+ if head_type == 'mlp_head':
+ self.fc1 = nn.Linear(mid_dim, mid_dim)
+ self.relu2 = nn.ReLU(inplace=True)
+ self.fc2 = nn.Linear(mid_dim, low_dim)
+ elif head_type == 'reduce':
+ self.fc = nn.Linear(mid_dim, low_dim)
+ elif head_type == 'conv_head':
+ self.fc1 = nn.Conv2d(mid_dim, mid_dim, kernel_size=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(mid_dim)
+ self.relu2 = nn.ReLU(inplace=True)
+ self.fc2 = nn.Linear(mid_dim, low_dim)
+ elif head_type in ['pass', 'early_return', 'multi_layer']:
+ pass
+ else:
+ raise NotImplementedError
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+ # zero-init the last BN gamma of each residual block (cf. "Bag of Tricks for Image Classification")
+ if block is Bottleneck:
+ gamma_name = "bn3.weight"
+ elif block is BasicBlock:
+ gamma_name = "bn2.weight"
+ else:
+ raise RuntimeError(f"block {block} not supported")
+ for name, value in self.named_parameters():
+ if name.endswith(gamma_name):
+ value.data.zero_()
+
+ if pretrained:
+ self.pretrained = pretrained
+ self._init_weights_pretrained()
+
+ def _init_weights_pretrained(self):
+ print(f" => loading pretrained resnet: {self.pretrained}")
+ assert os.path.isfile(self.pretrained)
+ ckpt = torch.load(self.pretrained, map_location='cpu')
+ missing_keys, unexpected_keys = self.load_state_dict(ckpt['state_dict'], strict=False)
+ print(f"missing_keys: {missing_keys}")
+ print(f"unexpected_keys: {unexpected_keys}")
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ if self.avg_down:
+ downsample = nn.Sequential(
+ nn.AvgPool2d(kernel_size=stride, stride=stride),
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=1, bias=False),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+ else:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = [block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, dilation)]
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=dilation))
+
+ return nn.Sequential(*layers)
+
+ def roi_align_feature_map(self, feature_map, bboxs, small=False):
+ feature_map = feature_map.type(dtype=bboxs.dtype)  # amp may have cast the feature map to half precision, so convert it to the bbox dtype before roi_align
+ N, C, H, W = feature_map.shape
+ N, L, _ = bboxs.shape
+
+ if small:
+ output_size = (self.output_size // 2, self.output_size // 2)
+ else:
+ output_size = (self.output_size, self.output_size)
+
+ bboxs_q_with_batch_index = append_batch_index_to_bboxs_and_scale(bboxs, H, W)
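+ # torchvision's roi_align expects boxes as a (K, 5) tensor of
+ # (batch_index, x1, y1, x2, y2) in feature-map coordinates; the helper above is
+ # assumed to produce exactly that layout from the per-image (N, L, 4) boxes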
+ aligned_features = tvops.roi_align(input=feature_map, boxes=bboxs_q_with_batch_index, output_size=output_size, aligned=self.aligned)
+ # N*L, C, output_size, output_size
+ return aligned_features
+
+ def forward(self, x, bboxs=None, small=False):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ c2 = self.layer1(x)
+ c3 = self.layer2(c2)
+ c4 = self.layer3(c3)
+ if self.use_roi_align_on_c4 and bboxs is not None:
+ c4align = self.roi_align_feature_map(c4, bboxs, small=small)
+ c5 = self.layer4(c4align)
+ else:
+ c5 = self.layer4(c4)
+
+ if self.head_type == 'multi_layer':
+ return c2, c3, c4, c5
+
+ if self.head_type == 'early_return':
+ return c5
+
+ if self.head_type != 'conv_head':
+ if small:
+ c5 = self.small_avgpool(c5)
+ else:
+ c5 = self.avgpool(c5)
+ c5 = c5.view(c5.size(0), -1)
+
+ if self.head_type == 'mlp_head':
+ out = self.fc1(c5)
+ out = self.relu2(out)
+ out = self.fc2(out)
+ elif self.head_type == 'reduce':
+ out = self.fc(c5)
+ elif self.head_type == 'conv_head':
+ out = self.fc1(c5)
+ out = self.bn2(out)
+ out = self.relu2(out)
+ out = self.avgpool(out)
+ out = out.view(out.size(0), -1)
+ out = self.fc2(out)
+ elif self.head_type == 'pass':
+ return c5
+ else:
+ raise NotImplementedError
+
+ return out
+
+
+def resnet18(**kwargs):
+ return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
+
+
+def resnet18_d(**kwargs):
+ return ResNet(BasicBlock, [2, 2, 2, 2], deep_stem=True, avg_down=True, **kwargs)
+
+
+def resnet34(**kwargs):
+ return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
+
+
+def resnet34_d(**kwargs):
+ return ResNet(BasicBlock, [3, 4, 6, 3], deep_stem=True, avg_down=True, **kwargs)
+
+
+def resnet50(**kwargs):
+ return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
+
+
+def resnet50_w2x(**kwargs):
+ return ResNet(Bottleneck, [3, 4, 6, 3], width=2, **kwargs)
+
+
+def resnet50_16s(**kwargs):
+ return ResNet(Bottleneck, [3, 4, 6, 3], layer4_dilation=2, **kwargs)
+
+
+def resnet50_d(**kwargs):
+ return ResNet(Bottleneck, [3, 4, 6, 3], deep_stem=True, avg_down=True, **kwargs)
+
+
+def resnet101(**kwargs):
+ return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
+
+
+def resnet101_d(**kwargs):
+ return ResNet(Bottleneck, [3, 4, 23, 3], deep_stem=True, avg_down=True, **kwargs)
+
+
+def resnext101_32x8d(**kwargs):
+ return ResNet(Bottleneck, [3, 4, 23, 3], groups=32, width_per_group=8, **kwargs)
+
+
+def resnet152(**kwargs):
+ return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
+
+
+def resnet152_d(**kwargs):
+ return ResNet(Bottleneck, [3, 8, 36, 3], deep_stem=True, avg_down=True, **kwargs)
+
+
+def resnext152_32x8d(**kwargs):
+ return ResNet(Bottleneck, [3, 8, 36, 3], groups=32, width_per_group=8, **kwargs)
+
diff --git a/contrast/util.py b/contrast/util.py
new file mode 100644
index 0000000..4b39198
--- /dev/null
+++ b/contrast/util.py
@@ -0,0 +1,146 @@
+# --------------------------------------------------------
+# SoCo
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Yue Gao
+# --------------------------------------------------------
+
+
+import argparse
+
+import torch
+import torch.distributed as dist
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+
+def accuracy(output, target, topk=(1,)):
+ """Computes the accuracy over the k top predictions for the specified values of k"""
+ with torch.no_grad():
+ maxk = max(topk)
+ batch_size = target.size(0)
+
+ _, pred = output.topk(maxk, 1, True, True)
+ pred = pred.t()
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+ res = []
+ for k in topk:
+ correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+ res.append(correct_k.mul_(100.0 / batch_size))
+ return res
+
+
+def dist_collect(x):
+ """ collect all tensor from all GPUs
+ args:
+ x: shape (mini_batch, ...)
+ returns:
+ shape (mini_batch * num_gpu, ...)
+ """
+ x = x.contiguous()
+ out_list = [torch.zeros_like(x, device=x.device, dtype=x.dtype)
+ for _ in range(dist.get_world_size())]
+ dist.all_gather(out_list, x)
+ return torch.cat(out_list, dim=0)
+
+
+def reduce_tensor(tensor):
+ rt = tensor.clone()
+ dist.all_reduce(rt, op=dist.ReduceOp.SUM)
+ rt /= dist.get_world_size()
+ return rt
+
+
+class MyHelpFormatter(argparse.MetavarTypeHelpFormatter, argparse.ArgumentDefaultsHelpFormatter):
+ pass
+
+class DistributedShuffle:
+
+ @staticmethod
+ def forward_shuffle(x):
+ """
+ Batch shuffle, for making use of BatchNorm.
+ *** Only support DistributedDataParallel (DDP) model. ***
+ """
+ # gather from all gpus
+ batch_size_this = x.shape[0]
+ x_gather = dist_collect(x)
+ batch_size_all = x_gather.shape[0]
+
+ num_gpus = batch_size_all // batch_size_this
+
+ # random shuffle index
+ idx_shuffle = torch.randperm(batch_size_all).cuda()
+
+ # broadcast to all gpus
+ dist.broadcast(idx_shuffle, src=0)
+
+ # index for restoring
+ idx_unshuffle = torch.argsort(idx_shuffle)
+
+ # shuffled index for this gpu
+ gpu_idx = dist.get_rank()
+ idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]
+
+ return x_gather[idx_this], idx_unshuffle
+
+ @staticmethod
+ def backward_shuffle(x, idx_unshuffle, return_local=True):
+ """
+ Undo batch shuffle.
+ *** Only support DistributedDataParallel (DDP) model. ***
+ """
+ # gather from all gpus
+ batch_size_this = x.shape[0]
+ x_gather = dist_collect(x)
+ batch_size_all = x_gather.shape[0]
+
+ num_gpus = batch_size_all // batch_size_this
+
+ if return_local:
+ # restored index for this gpu
+ gpu_idx = dist.get_rank()
+ idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]
+ return x_gather[idx_unshuffle], x_gather[idx_this]
+ else:
+ return x_gather[idx_unshuffle]
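+
+ # usage sketch (hypothetical `key_encoder`, not part of this file):
+ #   x_shuffled, idx_unshuffle = DistributedShuffle.forward_shuffle(x)
+ #   k = key_encoder(x_shuffled)  # BN statistics computed on a cross-GPU shuffled batch
+ #   k_all, k_local = DistributedShuffle.backward_shuffle(k, idx_unshuffle)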
+
+ @staticmethod
+ def get_local_id(ids):
+ return ids.chunk(dist.get_world_size())[dist.get_rank()]
+
+ @staticmethod
+ def get_shuffle_ids(bsz, epoch):
+ """generate shuffle ids for ShuffleBN"""
+ torch.manual_seed(epoch)
+ # global forward shuffle id for all process
+ forward_inds = torch.randperm(bsz).long().cuda()
+
+ # global backward shuffle id
+ backward_inds = torch.zeros(forward_inds.shape[0]).long().cuda()
+ value = torch.arange(bsz).long().cuda()
+ backward_inds.index_copy_(0, forward_inds, value)
+
+ return forward_inds, backward_inds
diff --git a/converter_detectron2/convert_detectron2_C4.py b/converter_detectron2/convert_detectron2_C4.py
new file mode 100644
index 0000000..619ec4d
--- /dev/null
+++ b/converter_detectron2/convert_detectron2_C4.py
@@ -0,0 +1,63 @@
+# --------------------------------------------------------
+# SoCo
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Yue Gao
+# --------------------------------------------------------
+
+
+import pickle as pkl
+import torch
+import argparse
+
+
+def convert_detectron2_C4(input_file_name, output_file_name, ema=False):
+ ckpt = torch.load(input_file_name, map_location="cpu")
+ if ema:
+ state_dict = ckpt["model_ema"]
+ prefix = "encoder."
+ else:
+ state_dict = ckpt["model"]
+ prefix = "module.encoder."
+
+ new_state_dict = {}
+ for k, v in state_dict.items():
+ if not k.startswith(prefix):
+ continue
+ old_k = k
+ k = k.replace(prefix, "")
+ if "layer" not in k:
+ k = "stem." + k
+ # k = "backbone." + k
+ k = k.replace("layer1", "res2")
+ k = k.replace("layer2", "res3")
+ k = k.replace("layer3", "res4")
+ k = k.replace("layer4", "res5")
+ k = k.replace("bn1", "conv1.norm")
+ k = k.replace("bn2", "conv2.norm")
+ k = k.replace("bn3", "conv3.norm")
+ k = k.replace("downsample.0", "shortcut")
+ k = k.replace("downsample.1", "shortcut.norm")
+ print(old_k, "->", k)
+ new_state_dict[k] = v.numpy()
+
+ res = {"model": new_state_dict,
+ "__author__": "PixPro",
+ "matching_heuristics": True}
+
+ with open(output_file_name, "wb") as f:
+ pkl.dump(res, f)
+ print(f"Saved converted detectron2 C4 checkpoint to {output_file_name}")
+
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser(description='Convert Models')
+ parser.add_argument('input', metavar='I',
+ help='input model path')
+ parser.add_argument('output', metavar='O',
+ help='output path')
+ parser.add_argument('--ema', action='store_true',
+ help='using ema model')
+ args = parser.parse_args()
+ convert_detectron2_C4(args.input, args.output, args.ema)
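+ # example invocation (paths are placeholders):
+ #   python converter_detectron2/convert_detectron2_C4.py \
+ #       ./SoCo_output/SoCo_C4_100ep/current.pth current_detectron2_C4.pkl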
diff --git a/converter_detectron2/convert_detectron2_Head.py b/converter_detectron2/convert_detectron2_Head.py
new file mode 100644
index 0000000..6501d6f
--- /dev/null
+++ b/converter_detectron2/convert_detectron2_Head.py
@@ -0,0 +1,97 @@
+# --------------------------------------------------------
+# SoCo
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Yue Gao
+# --------------------------------------------------------
+
+
+import pickle as pkl
+import torch
+import argparse
+
+
+def convert_detectron2_Head(input_file_name, output_file_name, start, num_outs, ema=False):
+ ckpt = torch.load(input_file_name, map_location="cpu")
+ if ema:
+ state_dict = ckpt["model_ema"]
+ backbone_prefix = "encoder."
+ fpn_prefix = "neck."
+ head_prefix = "head."
+ else:
+ state_dict = ckpt["model"]
+ backbone_prefix = "module.encoder."
+ fpn_prefix = "module.neck."
+ head_prefix = "module.head."
+
+ new_state_dict = {}
+ for k, v in state_dict.items():
+ if k.startswith(backbone_prefix):
+ old_k = k
+ k = k.replace(backbone_prefix, "")
+ if "layer" not in k:
+ k = "stem." + k
+ k = k.replace("layer1", "res2")
+ k = k.replace("layer2", "res3")
+ k = k.replace("layer3", "res4")
+ k = k.replace("layer4", "res5")
+ k = k.replace("bn1", "conv1.norm")
+ k = k.replace("bn2", "conv2.norm")
+ k = k.replace("bn3", "conv3.norm")
+ k = k.replace("downsample.0", "shortcut")
+ k = k.replace("downsample.1", "shortcut.norm")
+ print(old_k, "->", k)
+ new_state_dict[k] = v.numpy()
+
+ elif k.startswith(fpn_prefix):
+ old_k = k
+ k = k.replace(fpn_prefix, "")
+
+ for i in range(num_outs):
+ k = k.replace(f"lateral_convs.{i}.conv", f"fpn_lateral{start+i}")
+ k = k.replace(f"lateral_convs.{i}.bn", f"fpn_lateral{start+i}.norm")
+
+ k = k.replace(f"fpn_convs.{i}.conv", f"fpn_output{start+i}")
+ k = k.replace(f"fpn_convs.{i}.bn", f"fpn_output{start+i}.norm")
+
+ print(old_k, "->", k)
+ new_state_dict[k] = v.numpy()
+
+ elif k.startswith(head_prefix):
+ old_k = k
+ k = k.replace(head_prefix, "box_head.")
+ k = k.replace("bn1", "conv1.norm")
+ k = k.replace("ac1", "conv1.activation")
+ k = k.replace("bn2", "conv2.norm")
+ k = k.replace("ac2", "conv2.activation")
+ k = k.replace("bn3", "conv3.norm")
+ k = k.replace("ac3", "conv3.activation")
+ k = k.replace("bn4", "conv4.norm")
+ k = k.replace("ac4", "conv4.activation")
+ print(old_k, "->", k)
+ new_state_dict[k] = v.numpy()
+
+ res = {"model": new_state_dict,
+ "__author__": "Yue",
+ "matching_heuristics": True}
+
+ with open(output_file_name, "wb") as f:
+ pkl.dump(res, f)
+ print(f"Saved converted detectron2 Head checkpoint to {output_file_name}")
+
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser(description='Convert Models')
+ parser.add_argument('input', metavar='I',
+ help='input model path')
+ parser.add_argument('output', metavar='O',
+ help='output path')
+ parser.add_argument('start', metavar='S', type=int,
+ help='FPN start')
+ parser.add_argument('num_outs', metavar='N', type=int,
+ help='FPN number of outputs')
+ parser.add_argument('--ema', action='store_true',
+ help='using ema model')
+ args = parser.parse_args()
+ convert_detectron2_Head(args.input, args.output, args.start, args.num_outs, args.ema)
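+ # example invocation (paths are placeholders; start=2 and num_outs=4 match the
+ # call in main_pretrain.py):
+ #   python converter_detectron2/convert_detectron2_Head.py \
+ #       ./SoCo_output/SoCo_FPN_100ep/current.pth current_detectron2_Head.pkl 2 4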
diff --git a/converter_mmdetection/convert_mmdetection_Backbone.py b/converter_mmdetection/convert_mmdetection_Backbone.py
new file mode 100644
index 0000000..f331562
--- /dev/null
+++ b/converter_mmdetection/convert_mmdetection_Backbone.py
@@ -0,0 +1,40 @@
+# disclaimer: inspired by MoCo official repo.
+import torch
+import argparse
+
+
+def convert_mmdetection_Backbone(input_file_name, output_file_name, ema=False):
+ ckpt = torch.load(input_file_name, map_location="cpu")
+ if ema:
+ state_dict = ckpt["model_ema"]
+ backbone_prefix = "encoder."
+ else:
+ state_dict = ckpt["model"]
+ backbone_prefix = "module.encoder."
+
+ new_state_dict = {}
+ for k, v in state_dict.items():
+ if k.startswith(backbone_prefix):
+ old_k = k
+ k = k.replace(backbone_prefix, "backbone.")
+ print(old_k, "->", k)
+ new_state_dict[k] = v
+
+ res = {"state_dict": new_state_dict,
+ "meta": {}}
+
+ torch.save(res, output_file_name)
+ print(f"Saved converted mmdetection load checkpoint to {output_file_name}")
+
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser(description='Convert Models')
+ parser.add_argument('input', metavar='I',
+ help='input model path')
+ parser.add_argument('output', metavar='O',
+ help='output path')
+ parser.add_argument('--ema', action='store_true',
+ help='using ema model')
+ args = parser.parse_args()
+ convert_mmdetection_Backbone(args.input, args.output, args.ema)
diff --git a/converter_mmdetection/convert_mmdetection_Head.py b/converter_mmdetection/convert_mmdetection_Head.py
new file mode 100644
index 0000000..a83c595
--- /dev/null
+++ b/converter_mmdetection/convert_mmdetection_Head.py
@@ -0,0 +1,73 @@
+# disclaimer: inspired by MoCo official repo.
+import torch
+import argparse
+
+
+def convert_mmdetection_Head(input_file_name, output_file_name, ema=False):
+ ckpt = torch.load(input_file_name, map_location="cpu")
+ if ema:
+ state_dict = ckpt["model_ema"]
+ backbone_prefix = "encoder."
+ fpn_prefix = "neck."
+ head_prefix = "head."
+ else:
+ state_dict = ckpt["model"]
+ backbone_prefix = "module.encoder."
+ fpn_prefix = "module.neck."
+ head_prefix = "module.head."
+
+ new_state_dict = {}
+ for k, v in state_dict.items():
+ if k.startswith(backbone_prefix):
+ old_k = k
+ k = k.replace(backbone_prefix, "backbone.")
+ print(old_k, "->", k)
+ new_state_dict[k] = v
+ elif k.startswith(fpn_prefix):
+ old_k = k
+ k = k.replace(fpn_prefix, "neck.")
+
+ print(old_k, "->", k)
+ new_state_dict[k] = v
+ elif k.startswith(head_prefix):
+ old_k = k
+ k = k.replace(head_prefix, "roi_head.bbox_head.")
+ k = k.replace("conv1", "shared_convs.0.conv")
+ k = k.replace("bn1", "shared_convs.0.bn")
+ k = k.replace("ac1", "shared_convs.0.activate")
+
+ k = k.replace("conv2", "shared_convs.1.conv")
+ k = k.replace("bn2", "shared_convs.1.bn")
+ k = k.replace("ac2", "shared_convs.1.activate")
+
+ k = k.replace("conv3", "shared_convs.2.conv")
+ k = k.replace("bn3", "shared_convs.2.bn")
+ k = k.replace("ac3", "shared_convs.2.activate")
+
+ k = k.replace("conv4", "shared_convs.3.conv")
+ k = k.replace("bn4", "shared_convs.3.bn")
+ k = k.replace("ac4", "shared_convs.3.activate")
+
+ k = k.replace("fc1", "shared_fcs.0")
+
+ print(old_k, "->", k)
+ new_state_dict[k] = v
+
+ res = {"state_dict": new_state_dict,
+ "meta": {}}
+
+ torch.save(res, output_file_name)
+ print(f"Saved converted mmdetection load checkpoint to {output_file_name}")
+
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser(description='Convert Models')
+ parser.add_argument('input', metavar='I',
+ help='input model path')
+ parser.add_argument('output', metavar='O',
+ help='output path')
+ parser.add_argument('--ema', action='store_true',
+ help='using ema model')
+ args = parser.parse_args()
+ convert_mmdetection_Head(args.input, args.output, args.ema)
diff --git a/detectron2_configs/R_50_C4_1x.yaml b/detectron2_configs/R_50_C4_1x.yaml
new file mode 100644
index 0000000..aaf5930
--- /dev/null
+++ b/detectron2_configs/R_50_C4_1x.yaml
@@ -0,0 +1,13 @@
+_BASE_: "Base-RCNN-C4-BN.yaml"
+MODEL:
+ MASK_ON: True
+ WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
+INPUT:
+ MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
+ MIN_SIZE_TEST: 800
+DATASETS:
+ TRAIN: ("coco_2017_train",)
+ TEST: ("coco_2017_val",)
+SOLVER:
+ STEPS: (60000, 80000)
+ MAX_ITER: 90000
diff --git a/detectron2_configs/R_50_FPN_1x.yaml b/detectron2_configs/R_50_FPN_1x.yaml
new file mode 100644
index 0000000..1c87ae8
--- /dev/null
+++ b/detectron2_configs/R_50_FPN_1x.yaml
@@ -0,0 +1,24 @@
+_BASE_: "Base-RCNN-FPN.yaml"
+MODEL:
+ MASK_ON: True
+ WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
+ BACKBONE:
+ FREEZE_AT: 0
+ RESNETS:
+ DEPTH: 50
+ NORM: "SyncBN"
+ FPN:
+ NORM: "SyncBN"
+ ROI_BOX_HEAD:
+ NAME: "FastRCNNConvFCHead"
+ NUM_CONV: 4
+ NUM_FC: 1
+ NORM: "SyncBN"
+ ROI_MASK_HEAD:
+ NORM: "SyncBN"
+TEST:
+ PRECISE_BN:
+ ENABLED: True
+SOLVER:
+ STEPS: (60000, 80000)
+ MAX_ITER: 90000
diff --git a/detectron2_configs/SoCo_C4_100ep.yaml b/detectron2_configs/SoCo_C4_100ep.yaml
new file mode 100644
index 0000000..550f46a
--- /dev/null
+++ b/detectron2_configs/SoCo_C4_100ep.yaml
@@ -0,0 +1,11 @@
+_BASE_: "R_50_C4_1x.yaml"
+MODEL:
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
+ PIXEL_STD: [58.395, 57.120, 57.375]
+ WEIGHTS: "./SoCo_output/SoCo_C4_100ep/current_detectron2_C4.pkl"
+ RESNETS:
+ STRIDE_IN_1X1: False
+INPUT:
+ FORMAT: "RGB"
+SOLVER:
+ WEIGHT_DECAY: 0.000025
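+# note: PIXEL_MEAN/PIXEL_STD are the standard ImageNet RGB statistics, used here
+# together with INPUT.FORMAT "RGB", presumably to match the SoCo pretraining normalization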
diff --git a/detectron2_configs/SoCo_C4_400ep.yaml b/detectron2_configs/SoCo_C4_400ep.yaml
new file mode 100644
index 0000000..2c4bd1b
--- /dev/null
+++ b/detectron2_configs/SoCo_C4_400ep.yaml
@@ -0,0 +1,11 @@
+_BASE_: "R_50_C4_1x.yaml"
+MODEL:
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
+ PIXEL_STD: [58.395, 57.120, 57.375]
+ WEIGHTS: "./SoCo_output/SoCo_C4_400ep/current_detectron2_C4.pkl"
+ RESNETS:
+ STRIDE_IN_1X1: False
+INPUT:
+ FORMAT: "RGB"
+SOLVER:
+ WEIGHT_DECAY: 0.000025
diff --git a/detectron2_configs/SoCo_FPN_100ep.yaml b/detectron2_configs/SoCo_FPN_100ep.yaml
new file mode 100644
index 0000000..bfcdb0e
--- /dev/null
+++ b/detectron2_configs/SoCo_FPN_100ep.yaml
@@ -0,0 +1,11 @@
+_BASE_: "R_50_FPN_1x.yaml"
+MODEL:
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
+ PIXEL_STD: [58.395, 57.120, 57.375]
+ WEIGHTS: "./SoCo_output/SoCo_FPN_100ep/current_detectron2_Head.pkl"
+ RESNETS:
+ STRIDE_IN_1X1: False
+INPUT:
+ FORMAT: "RGB"
+SOLVER:
+ WEIGHT_DECAY: 0.000025
diff --git a/detectron2_configs/SoCo_FPN_400ep.yaml b/detectron2_configs/SoCo_FPN_400ep.yaml
new file mode 100644
index 0000000..7b7bf32
--- /dev/null
+++ b/detectron2_configs/SoCo_FPN_400ep.yaml
@@ -0,0 +1,11 @@
+_BASE_: "R_50_FPN_1x.yaml"
+MODEL:
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
+ PIXEL_STD: [58.395, 57.120, 57.375]
+ WEIGHTS: "./SoCo_output/SoCo_FPN_400ep/current_detectron2_Head.pkl"
+ RESNETS:
+ STRIDE_IN_1X1: False
+INPUT:
+ FORMAT: "RGB"
+SOLVER:
+ WEIGHT_DECAY: 0.000025
diff --git a/detectron2_configs/SoCo_FPN_Star_400ep.yaml b/detectron2_configs/SoCo_FPN_Star_400ep.yaml
new file mode 100644
index 0000000..e376271
--- /dev/null
+++ b/detectron2_configs/SoCo_FPN_Star_400ep.yaml
@@ -0,0 +1,11 @@
+_BASE_: "R_50_FPN_1x.yaml"
+MODEL:
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
+ PIXEL_STD: [58.395, 57.120, 57.375]
+ WEIGHTS: "./SoCo_output/SoCo_FPN_Star_400ep/current_detectron2_Head.pkl"
+ RESNETS:
+ STRIDE_IN_1X1: False
+INPUT:
+ FORMAT: "RGB"
+SOLVER:
+ WEIGHT_DECAY: 0.000025
diff --git a/docker/Dockerfile b/docker/Dockerfile
new file mode 100644
index 0000000..de022d4
--- /dev/null
+++ b/docker/Dockerfile
@@ -0,0 +1,52 @@
+# --------------------------------------------------------
+# SoCo
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Yue Gao
+# --------------------------------------------------------
+
+
+FROM nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04
+
+# ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0+PTX"
+ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all"
+ENV CMAKE_PREFIX_PATH="$(dirname $(which conda))/../"
+
+RUN apt-get update && apt-get install -y vim curl wget ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libxrender-dev \
+ && apt-get install -y libblacs-mpi-dev \
+ && apt-get install -y software-properties-common \
+ && add-apt-repository -y ppa:deadsnakes/ppa \
+ && apt-get update \
+ && apt-get install -y python3.7 \
+ && apt-get clean \
+ && rm -rf /var/lib/apt/lists/*
+
+# Install python 3.7 to python and install pip
+RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.7 10
+RUN wget https://bootstrap.pypa.io/get-pip.py
+RUN python get-pip.py
+
+# Torch
+RUN pip install torch==1.6.0 torchvision==0.7.0
+
+# accimage
+RUN pip install --prefix=/opt/intel/ipp ipp-devel
+RUN pip install git+https://github.com/pytorch/accimage
+
+# apex
+# P100, P40, V100 and 2080Ti
+ENV TORCH_CUDA_ARCH_LIST="6.0;6.1;7.0;7.5"
+RUN git clone https://github.com/NVIDIA/apex \
+ && cd apex \
+ && pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
+
+# detectron2
+RUN python -m pip install 'git+https://github.com/hologerry/detectron2.git' --upgrade --force-reinstall
+
+# Install MMCV and MMdetection
+RUN pip uninstall -y pycocotools
+RUN pip install mmcv-full==1.2.4 -f https://download.openmmlab.com/mmcv/dist/cu102/torch1.6.0/index.html --upgrade --force-reinstall
+
+RUN git clone https://github.com/hologerry/mmdetection.git \
+ && cd mmdetection \
+ && pip install -r requirements/build.txt --upgrade --force-reinstall \
+ && pip install -v -e . --upgrade --force-reinstall
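+
+# Build sketch (image tag is arbitrary): from the repository root, run
+#   docker build -t soco:cuda10.2-torch1.6 -f docker/Dockerfile .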
diff --git a/figures/overview.png b/figures/overview.png
new file mode 100644
index 0000000..15e4d21
Binary files /dev/null and b/figures/overview.png differ
diff --git a/main_linear.py b/main_linear.py
new file mode 100644
index 0000000..1ce8018
--- /dev/null
+++ b/main_linear.py
@@ -0,0 +1,303 @@
+# --------------------------------------------------------
+# SoCo
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Yue Gao
+# --------------------------------------------------------
+
+
+import json
+import os
+import time
+
+import torch
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch.nn.parallel import DistributedDataParallel
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.tensorboard import SummaryWriter
+
+from contrast import resnet
+from contrast.data import get_loader
+from contrast.logger import setup_logger
+from contrast.lr_scheduler import get_scheduler
+from contrast.option import parse_option
+from contrast.util import AverageMeter, accuracy, reduce_tensor
+
+try:
+ from apex import amp # type: ignore
+except ImportError:
+ amp = None
+
+
+def build_model(args, num_class):
+ # create model
+ model = resnet.__dict__[args.arch](low_dim=num_class, head_type='reduce').cuda()
+
+ # set requires_grad of parameters except last fc layer to False
+ for name, p in model.named_parameters():
+ if 'fc' not in name:
+ p.requires_grad = False
+
+ optimizer = torch.optim.SGD(model.fc.parameters(),
+ lr=args.learning_rate,
+ momentum=args.momentum,
+ weight_decay=args.weight_decay)
+
+ if args.amp_opt_level != "O0":
+ model, optimizer = amp.initialize(model, optimizer, opt_level=args.amp_opt_level)
+
+ model = DistributedDataParallel(model, device_ids=[args.local_rank], broadcast_buffers=False)
+
+ return model, optimizer
+
+
+def load_pretrained(model, pretrained_model):
+ ckpt = torch.load(pretrained_model, map_location='cpu')
+ model_dict = model.state_dict()
+
+ base_fix = False
+ for key in ckpt['model'].keys():
+ if key.startswith('module.base.'):
+ base_fix = True
+ break
+
+ if base_fix:
+ state_dict = {k.replace("module.base.", "module."): v
+ for k, v in ckpt['model'].items()
+ if k.startswith('module.base.')}
+ logger.info(f"==> load checkpoint from Module.Base")
+ else:
+ state_dict = {k.replace("module.encoder.", "module."): v
+ for k, v in ckpt['model'].items()
+ if k.startswith('module.encoder.')}
+ logger.info(f"==> load checkpoint from Module.Encoder")
+
+ state_dict = {k: v for k, v in state_dict.items()
+ if k in model_dict and v.size() == model_dict[k].size()}
+
+ model_dict.update(state_dict)
+ model.load_state_dict(model_dict)
+ logger.info(f"==> loaded checkpoint '{pretrained_model}' (epoch {ckpt['epoch']})")
+
+
+def load_checkpoint(args, model, optimizer, scheduler):
+ logger.info("=> loading checkpoint '{args.resume'")
+
+ checkpoint = torch.load(args.resume, map_location='cpu')
+
+ global best_acc1
+ best_acc1 = checkpoint['best_acc1']
+ args.start_epoch = checkpoint['epoch'] + 1
+ model.load_state_dict(checkpoint['model'])
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ scheduler.load_state_dict(checkpoint['scheduler'])
+ if args.amp_opt_level != "O0" and checkpoint['args'].amp_opt_level != "O0":
+ amp.load_state_dict(checkpoint['amp'])
+
+ logger.info(f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})")
+
+
+def save_checkpoint(args, epoch, model, test_acc, optimizer, scheduler):
+ state = {
+ 'args': args,
+ 'epoch': epoch,
+ 'model': model.state_dict(),
+ 'best_acc1': test_acc,
+ 'optimizer': optimizer.state_dict(),
+ 'scheduler': scheduler.state_dict(),
+ }
+ if args.amp_opt_level != "O0":
+ state['amp'] = amp.state_dict()
+ torch.save(state, os.path.join(args.output_dir, f'ckpt_epoch_{epoch}.pth'))
+ torch.save(state, os.path.join(args.output_dir, f'current.pth'))
+
+
+def main(args):
+ global best_acc1
+
+ args.batch_size = args.total_batch_size // dist.get_world_size()
+ train_loader = get_loader(args.aug, args, prefix='train')
+ val_loader = get_loader('val', args, prefix='val')
+ logger.info(f"length of training dataset: {len(train_loader.dataset)}")
+
+ model, optimizer = build_model(args, num_class=len(train_loader.dataset.classes))
+ scheduler = get_scheduler(optimizer, len(train_loader), args)
+
+ # load pre-trained model
+ load_pretrained(model, args.pretrained_model)
+
+ # optionally resume from a checkpoint
+ if args.auto_resume:
+ resume_file = os.path.join(args.output_dir, "current.pth")
+ if os.path.exists(resume_file):
+ logger.info(f'auto resume from {resume_file}')
+ args.resume = resume_file
+ else:
+ logger.info(f'no checkpoint found in {args.output_dir}, ignoring auto resume')
+ if args.resume:
+ assert os.path.isfile(args.resume), f"no checkpoint found at '{args.resume}'"
+ load_checkpoint(args, model, optimizer, scheduler)
+
+ if args.eval:
+ logger.info("==> testing...")
+ validate(val_loader, model, args)
+ return
+
+ # tensorboard
+ if dist.get_rank() == 0:
+ summary_writer = SummaryWriter(log_dir=args.output_dir)
+ else:
+ summary_writer = None
+
+ # routine
+ for epoch in range(args.start_epoch, args.epochs + 1):
+ if isinstance(train_loader.sampler, DistributedSampler):
+ train_loader.sampler.set_epoch(epoch)
+
+ tic = time.time()
+ train(epoch, train_loader, model, optimizer, scheduler, args)
+ logger.info(f'epoch {epoch}, total time {time.time() - tic:.2f}')
+
+ logger.info("==> testing...")
+ test_acc, test_acc5, test_loss = validate(val_loader, model, args)
+ if summary_writer is not None:
+ summary_writer.add_scalar('test_acc', test_acc, epoch)
+ summary_writer.add_scalar('test_acc5', test_acc5, epoch)
+ summary_writer.add_scalar('test_loss', test_loss, epoch)
+
+ # save model
+ if dist.get_rank() == 0 and epoch % args.save_freq == 0:
+ logger.info('==> Saving...')
+ save_checkpoint(args, epoch, model, test_acc, optimizer, scheduler)
+
+
+def train(epoch, train_loader, model, optimizer, scheduler, args):
+ """
+ one epoch training
+ """
+
+ model.train()
+
+ batch_time = AverageMeter()
+ data_time = AverageMeter()
+ loss_meter = AverageMeter()
+ acc1_meter = AverageMeter()
+ acc5_meter = AverageMeter()
+
+ end = time.time()
+ for idx, (x, _, y) in enumerate(train_loader):
+ x = x.cuda(non_blocking=True)
+ y = y.cuda(non_blocking=True)
+
+ # measure data loading time
+ data_time.update(time.time() - end)
+
+ # forward
+ output = model(x)
+ loss = F.cross_entropy(output, y)
+
+ # backward
+ optimizer.zero_grad()
+ if args.amp_opt_level != "O0":
+ with amp.scale_loss(loss, optimizer) as scaled_loss:
+ scaled_loss.backward()
+ else:
+ loss.backward()
+ optimizer.step()
+ scheduler.step()
+
+ # update meters
+ acc1, acc5 = accuracy(output, y, topk=(1, 5))
+ loss_meter.update(loss.item(), x.size(0))
+ acc1_meter.update(acc1[0], x.size(0))
+ acc5_meter.update(acc5[0], x.size(0))
+ batch_time.update(time.time() - end)
+ end = time.time()
+
+ # print info
+ if idx % args.print_freq == 0:
+ logger.info(
+ f'Epoch: [{epoch}][{idx}/{len(train_loader)}]\t'
+ f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+ f'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
+ f'Lr {optimizer.param_groups[0]["lr"]:.3f} \t'
+ f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
+ f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
+ f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})')
+
+ return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
+
+
+def validate(val_loader, model, args):
+ batch_time = AverageMeter()
+ loss_meter = AverageMeter()
+ acc1_meter = AverageMeter()
+ acc5_meter = AverageMeter()
+
+ # switch to evaluate mode
+ model.eval()
+
+ with torch.no_grad():
+ end = time.time()
+ for idx, (x, _, y) in enumerate(val_loader):
+ x = x.cuda(non_blocking=True)
+ y = y.cuda(non_blocking=True)
+
+ # compute output
+ output = model(x)
+ loss = F.cross_entropy(output, y)
+
+ # measure accuracy and record loss
+ acc1, acc5 = accuracy(output, y, topk=(1, 5))
+
+ acc1 = reduce_tensor(acc1)
+ acc5 = reduce_tensor(acc5)
+ loss = reduce_tensor(loss)
+
+ loss_meter.update(loss.item(), x.size(0))
+ acc1_meter.update(acc1[0], x.size(0))
+ acc5_meter.update(acc5[0], x.size(0))
+
+ # measure elapsed time
+ batch_time.update(time.time() - end)
+ end = time.time()
+
+ if idx % args.print_freq == 0:
+ logger.info(
+ f'Test: [{idx}/{len(val_loader)}]\t'
+ f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+ f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
+ f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
+ f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})')
+
+ logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
+
+ return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
+
+
+if __name__ == '__main__':
+ opt = parse_option(stage='linear')
+
+ if opt.amp_opt_level != "O0":
+ assert amp is not None, "amp not installed!"
+
+ torch.cuda.set_device(opt.local_rank)
+ torch.distributed.init_process_group(backend='nccl', init_method='env://')
+ cudnn.benchmark = True
+ best_acc1 = 0
+
+ os.makedirs(opt.output_dir, exist_ok=True)
+ logger = setup_logger(output=opt.output_dir, distributed_rank=dist.get_rank(), name="contrast")
+ if dist.get_rank() == 0:
+ path = os.path.join(opt.output_dir, "config.json")
+ with open(path, "w") as f:
+ json.dump(vars(opt), f, indent=2)
+ logger.info("Full config saved to {}".format(path))
+
+ # print args
+ # TODO: check format
+ logger.info(vars(opt))
+
+ main(opt)
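+
+ # launch sketch (single node, 8 GPUs; paths and values are placeholders, not the
+ # official recipe):
+ #   python -m torch.distributed.launch --nproc_per_node=8 main_linear.py \
+ #       --data_dir ./data/imagenet --total_batch_size 256 \
+ #       --pretrained_model ./SoCo_output/SoCo_FPN_100ep/current.pth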
diff --git a/main_pretrain.py b/main_pretrain.py
new file mode 100644
index 0000000..3743856
--- /dev/null
+++ b/main_pretrain.py
@@ -0,0 +1,279 @@
+# --------------------------------------------------------
+# SoCo
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Yue Gao
+# --------------------------------------------------------
+
+
+import json
+import math
+import os
+import time
+from shutil import copyfile
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch.backends import cudnn
+from torch.nn.parallel import DistributedDataParallel
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.tensorboard import SummaryWriter
+
+from contrast import models, resnet
+from contrast.data import get_loader
+from contrast.lars import LARS, add_weight_decay
+from contrast.logger import setup_logger
+from contrast.lr_scheduler import get_scheduler
+from contrast.option import parse_option
+from contrast.util import AverageMeter
+
+from converter_detectron2.convert_detectron2_C4 import convert_detectron2_C4
+from converter_detectron2.convert_detectron2_Head import convert_detectron2_Head
+from converter_mmdetection.convert_mmdetection_Head import convert_mmdetection_Head
+
+try:
+ from apex import amp # type: ignore
+except ImportError:
+ amp = None
+
+
+def set_random_seed(seed, deterministic=False):
+ """Set random seed.
+ Args:
+ seed (int): Seed to be used.
+ deterministic (bool): Whether to set the deterministic option for
+ CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+ to True and `torch.backends.cudnn.benchmark` to False.
+ Default: False.
+ """
+ import random
+
+ import numpy as np
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ if deterministic:
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+def build_model(args):
+ encoder = resnet.__dict__[args.arch]
+ model = models.__dict__[args.model](encoder, args).cuda()
+
+ if args.optimizer == 'sgd':
+ optimizer = torch.optim.SGD(model.parameters(),
+ lr=args.batch_size * dist.get_world_size() / 256 * args.base_learning_rate,
+ momentum=args.momentum,
+ weight_decay=args.weight_decay)
+ elif args.optimizer == 'lars':
+ params = add_weight_decay(model, args.weight_decay)
+ optimizer = torch.optim.SGD(params,
+ lr=args.batch_size * dist.get_world_size() / 256 * args.base_learning_rate,
+ momentum=args.momentum)
+ optimizer = LARS(optimizer)
+ else:
+ raise NotImplementedError
+
+ if args.amp_opt_level != "O0":
+ model, optimizer = amp.initialize(model, optimizer, opt_level=args.amp_opt_level)
+
+ model = DistributedDataParallel(model, device_ids=[args.local_rank], broadcast_buffers=False, find_unused_parameters=True)
+ return model, optimizer
+
+
+def load_pretrained(model, pretrained_model):
+ ckpt = torch.load(pretrained_model, map_location='cpu')
+ state_dict = ckpt['model']
+ model_dict = model.state_dict()
+
+ model_dict.update(state_dict)
+ model.load_state_dict(model_dict)
+ logger.info(f"==> loaded checkpoint '{pretrained_model}' (epoch {ckpt['epoch']})")
+
+
+def load_checkpoint(args, model, optimizer, scheduler, sampler=None):
+ logger.info(f"=> loading checkpoint '{args.resume}'")
+
+ checkpoint = torch.load(args.resume, map_location='cpu')
+ args.start_epoch = checkpoint['epoch'] + 1
+ model.load_state_dict(checkpoint['model'])
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ scheduler.load_state_dict(checkpoint['scheduler'])
+ if args.amp_opt_level != "O0" and checkpoint['opt'].amp_opt_level != "O0":
+ amp.load_state_dict(checkpoint['amp'])
+ if args.use_sliding_window_sampler:
+ sampler.load_state_dict(checkpoint['sampler'])
+
+ logger.info(f"=> loaded successfully '{args.resume}' (epoch {checkpoint['epoch']})")
+
+ del checkpoint
+ torch.cuda.empty_cache()
+
+
+def save_checkpoint(args, epoch, model, optimizer, scheduler, sampler=None):
+ logger.info('==> Saving...')
+ state = {
+ 'opt': args,
+ 'model': model.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'scheduler': scheduler.state_dict(),
+ 'epoch': epoch,
+ }
+ if args.amp_opt_level != "O0":
+ state['amp'] = amp.state_dict()
+ if args.use_sliding_window_sampler:
+ state['sampler'] = sampler.state_dict()
+ file_name = os.path.join(args.output_dir, f'ckpt_epoch_{epoch}.pth')
+ torch.save(state, file_name)
+ copyfile(file_name, os.path.join(args.output_dir, 'current.pth'))
+
+
+def convert_checkpoint(args):
+ file_name = os.path.join(args.output_dir, 'current.pth')
+ output_file_name_C4 = os.path.join(args.output_dir, 'current_detectron2_C4.pkl')
+ output_file_name_Head = os.path.join(args.output_dir, 'current_detectron2_Head.pkl')
+ output_file_name_mmdet_Head = os.path.join(args.output_dir, 'current_mmdetection_Head.pth')
+
+ convert_detectron2_C4(file_name, output_file_name_C4)
+ convert_detectron2_Head(file_name, output_file_name_Head, start=2, num_outs=4)
+ convert_mmdetection_Head(file_name, output_file_name_mmdet_Head)
+
+
+def main(args):
+ train_prefix = 'train2017' if args.dataset == 'COCO' else 'train'
+ train_loader = get_loader(args.aug, args, prefix=train_prefix, return_coord=True)
+ args.num_instances = len(train_loader.dataset)
+ logger.info(f"length of training dataset: {args.num_instances}")
+
+ model, optimizer = build_model(args)
+ if dist.get_rank() == 0:
+ print(model)
+
+ scheduler = get_scheduler(optimizer, len(train_loader), args)
+
+ # optionally resume from a checkpoint
+ if args.pretrained_model:
+ assert os.path.isfile(args.pretrained_model)
+ load_pretrained(model, args.pretrained_model)
+ if args.auto_resume:
+ resume_file = os.path.join(args.output_dir, "current.pth")
+ if os.path.exists(resume_file):
+ logger.info(f'auto resume from {resume_file}')
+ args.resume = resume_file
+ else:
+ logger.info(f'no checkpoint found in {args.output_dir}, ignoring auto resume')
+ if args.resume:
+ assert os.path.isfile(args.resume)
+ load_checkpoint(args, model, optimizer, scheduler, sampler=train_loader.sampler)
+
+ # tensorboard
+ if dist.get_rank() == 0:
+ summary_writer = SummaryWriter(log_dir=args.output_dir)
+ else:
+ summary_writer = None
+
+ if args.use_sliding_window_sampler:
+ args.epochs = math.ceil(args.epochs * len(train_loader.dataset) / args.window_size)
+ for epoch in range(args.start_epoch, args.epochs + 1):
+ if isinstance(train_loader.sampler, DistributedSampler):
+ train_loader.sampler.set_epoch(epoch)
+
+ train(epoch, train_loader, model, optimizer, scheduler, args, summary_writer)
+
+ if dist.get_rank() == 0 and (epoch % args.save_freq == 0 or epoch == args.epochs):
+ save_checkpoint(args, epoch, model, optimizer, scheduler, sampler=train_loader.sampler)
+ if dist.get_rank() == 0 and epoch == args.epochs:
+ convert_checkpoint(args)
+
+
+def train(epoch, train_loader, model, optimizer, scheduler, args, summary_writer):
+ """
+ one epoch training
+ """
+ model.train()
+
+ batch_time = AverageMeter()
+ data_time = AverageMeter()
+ loss_meter = AverageMeter()
+
+ end = time.time()
+ for idx, data in enumerate(train_loader):
+ data = [item.cuda(non_blocking=True) for item in data]
+ data_time.update(time.time() - end)
+
+ if args.model in ['SoCo_C4']:
+ loss = model(data[0], data[1], data[2], data[3], data[4])
+ elif args.model in ['SoCo_FPN',]:
+ loss = model(data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7], data[8])
+ elif args.model in ['SoCo_FPN_Star']:
+ loss = model(data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7], data[8], data[9], data[10], data[11], data[12])
+ else:
+ logit, label = model(data[0], data[1])
+ loss = F.cross_entropy(logit, label)
+
+ # backward
+ optimizer.zero_grad()
+ if args.amp_opt_level != "O0":
+ with amp.scale_loss(loss, optimizer) as scaled_loss:
+ scaled_loss.backward()
+ else:
+ loss.backward()
+
+ optimizer.step()
+ scheduler.step()
+
+ # update meters and print info
+ loss_meter.update(loss.item(), data[0].size(0))
+ batch_time.update(time.time() - end)
+ end = time.time()
+
+ train_len = len(train_loader)
+ if args.use_sliding_window_sampler:
+ train_len = int(args.window_size / args.batch_size / dist.get_world_size())
+ if idx % args.print_freq == 0:
+ lr = optimizer.param_groups[0]['lr']
+ logger.info(
+ f'Train: [{epoch}/{args.epochs}][{idx}/{train_len}] '
+ f'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
+ f'Data Time {data_time.val:.3f} ({data_time.avg:.3f}) '
+ f'lr {lr:.3f} '
+ f'loss {loss_meter.val:.3f} ({loss_meter.avg:.3f})')
+ # tensorboard logger
+ if summary_writer is not None:
+ step = (epoch - 1) * len(train_loader) + idx
+ summary_writer.add_scalar('lr', lr, step)
+ summary_writer.add_scalar('loss', loss_meter.val, step)
+
+
+if __name__ == '__main__':
+ opt = parse_option(stage='pre_train')
+
+ if opt.amp_opt_level != "O0":
+ assert amp is not None, "amp not installed!"
+
+ torch.cuda.set_device(opt.local_rank)
+ torch.distributed.init_process_group(backend='nccl', init_method='env://')
+ cudnn.benchmark = True
+
+ # setup logger
+ os.makedirs(opt.output_dir, exist_ok=True)
+ logger = setup_logger(output=opt.output_dir, distributed_rank=dist.get_rank(), name="SoCo")
+ if dist.get_rank() == 0:
+ path = os.path.join(opt.output_dir, "config.json")
+ with open(path, 'w') as f:
+ json.dump(vars(opt), f, indent=2)
+ logger.info("Full config saved to {}".format(path))
+
+ # print args
+ logger.info(
+ "\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(opt)).items()))
+ )
+
+ if opt.debug:
+ logger.info('enable debug mode, set seed to 0')
+ set_random_seed(0)
+
+ main(opt)
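+
+ # launch sketch (single node, 8 GPUs; values are examples drawn from the option
+ # choices, not the official training recipe):
+ #   python -m torch.distributed.launch --nproc_per_node=8 main_pretrain.py \
+ #       --model SoCo_FPN --arch resnet50 --data_dir ./data/imagenet \
+ #       --aug ImageAsymBboxAwareMultiJitter1 --batch_size 64 --epochs 100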
diff --git a/selective_search/filter_ss_proposals_json.py b/selective_search/filter_ss_proposals_json.py
new file mode 100644
index 0000000..e0c241d
--- /dev/null
+++ b/selective_search/filter_ss_proposals_json.py
@@ -0,0 +1,73 @@
+import os
+import pickle
+import multiprocessing as mp
+import numpy as np
+import json
+
+from filters import filter_none, filter_ratio, filter_size, filter_ratio_size
+
+
+imagenet_root = './imagenet_root'
+imagenet_root_proposals = './imagenet_root_proposals_mp'
+filter_strategy = 'ratio3size0308' # 'ratio2', 'ratio3', 'ratio4', 'size01'
+print("filter_strategy", filter_strategy)
+filtered_proposals = './imagenet_filtered_proposals'
+
+split = 'train'
+
+
+filtered_proposals_dict = {}
+
+
+os.makedirs(filtered_proposals, exist_ok=True)
+json_path = os.path.join(filtered_proposals, f'{split}_{filter_strategy}.json')
+source_path = os.path.join(imagenet_root_proposals, split)
+
+class_names = sorted(os.listdir(os.path.join(imagenet_root_proposals, split)))
+
+no_props_images = []
+
+
+for ci, class_name in enumerate(class_names):
+ filenames = sorted(os.listdir(os.path.join(source_path, class_name)))
+ for fi, filename in enumerate(filenames):
+ base_filename = os.path.splitext(filename)[0]
+ cur_img_pro_path = os.path.join(source_path, class_name, filename)
+
+ with open(cur_img_pro_path, 'rb') as f:
+ cur_img_proposal = pickle.load(f)
+ if filter_strategy == 'none':
+ filtered_img_rects = filter_none(cur_img_proposal['regions'])
+ elif 'ratio' in filter_strategy and 'size' in filter_strategy:
+ ratio = float(filter_strategy[5])
+ min_size_ratio = float(filter_strategy[-4:-2]) / 10
+ max_size_ratio = float(filter_strategy[-2:]) / 10
+ filtered_img_rects = filter_ratio_size(cur_img_proposal['regions'], cur_img_proposal['label'].shape, ratio, min_size_ratio, max_size_ratio)
+ elif 'ratio' in filter_strategy:
+ ratio = float(filter_strategy[-1])
+ filtered_img_rects = filter_ratio(cur_img_proposal['regions'], r=ratio)
+ elif 'size' in filter_strategy:
+ min_size_ratio = float(filter_strategy[-2:]) / 10
+ filtered_img_rects = filter_size(cur_img_proposal['regions'], cur_img_proposal['label'].shape, min_size_ratio)
+ else:
+ raise NotImplementedError
+
+ filtered_proposals_dict[base_filename] = filtered_img_rects
+ if len(filtered_img_rects) == 0:
+ no_props_images.append(base_filename)
+ print(f"with strategy {filter_strategy}, image {base_filename} has no proposals")
+ if (fi + 1) % 100 == 0:
+ print(f"Processed [{ci}/{len(class_names)}] classes, [{fi+1}/{len(filenames)}] images")
+
+
+print(f"Finished filtering with strategy {filter_strategy}, there are {len(no_props_images)} images have no proposals.")
+
+
+with open(json_path, 'w') as f:
+ json.dump(filtered_proposals_dict, f)
+
+
+with open(json_path.replace('.json', 'no_props_images.txt'), 'w') as f:
+ for image_id in no_props_images:
+ f.write(image_id+'\n')
+
diff --git a/selective_search/filter_ss_proposals_json_post_no_prop.py b/selective_search/filter_ss_proposals_json_post_no_prop.py
new file mode 100644
index 0000000..1ed1ef3
--- /dev/null
+++ b/selective_search/filter_ss_proposals_json_post_no_prop.py
@@ -0,0 +1,36 @@
+import json
+import os
+import pickle
+
+from filters import filter_none, filter_ratio
+
+
+json_path = './imagenet_filtered_proposals/train_ratio3size0308.json'
+json_path_post = './imagenet_filtered_proposals/train_ratio3size0308post.json'
+with open('./imagenet_filtered_proposals/train_ratio3size0308no_props_images.txt') as f:
+ no_props_images = f.readlines()
+imagenet_root_proposals = './imagenet_root_proposals_mp'
+
+
+with open(json_path, 'r') as f:
+ json_dict = json.load(f)
+
+
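+# For images whose proposals were all removed by the ratio+size filter, progressively relax it:
+# first keep boxes passing only the aspect-ratio filter (r=3, sides >= 32 px), then fall back to
+# the minimum-size-only filter; if both come back empty, the entry keeps its empty list.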
+for no_props_image in no_props_images:
+ filename = no_props_image.strip()
+ class_name = filename.split('_')[0]
+ cur_img_pro_path = os.path.join(imagenet_root_proposals, 'train', class_name, filename+'.pkl')
+
+ with open(cur_img_pro_path, 'rb') as f:
+ cur_img_proposal = pickle.load(f)
+ props_size_ratio = filter_ratio(cur_img_proposal['regions'], r=3)
+ props_none = filter_none(cur_img_proposal['regions'])
+ print("props_size_ratio", len(props_size_ratio))
+ print("props_none", len(props_none))
+ if len(props_size_ratio) > 0:
+ json_dict[filename] = props_size_ratio
+ elif len(props_none) > 0:
+ json_dict[filename] = filter_none(cur_img_proposal['regions'])
+
+
+with open(json_path_post, 'w') as f:
+ json.dump(json_dict, f)
diff --git a/selective_search/filters.py b/selective_search/filters.py
new file mode 100644
index 0000000..59fe26c
--- /dev/null
+++ b/selective_search/filters.py
@@ -0,0 +1,66 @@
+# --------------------------------------------------------
+# SoCo
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Yue Gao
+# --------------------------------------------------------
+
+
+from math import sqrt
+
+'''
+regions : array of dict
+ [
+ {
+ 'rect': (left, top, width, height),
+ 'labels': [...],
+ 'size': component_size
+ },
+ ...
+ ]
+'''
+
+def filter_none(regions):
+ rects = []
+ for region in regions:
+ cur_rect = region['rect']
+ w, h = cur_rect[2], cur_rect[3]
+ if w >= 32 and h >= 32:
+ rects.append(cur_rect)
+ return rects
+
+
+def filter_ratio(regions, r=2):
+ rects = []
+ for region in regions:
+ cur_rect = region['rect']
+ w, h = cur_rect[2], cur_rect[3]
+ if w >= 32 and h >= 32 and (1.0 / r) <= (w / h) <= r:
+ rects.append(cur_rect)
+ return rects
+
+
+def filter_size(regions, image_size, min_ratio=0.1, max_ratio=1.0):
+ rects = []
+ ih, iw = image_size # get from ss_props label, which is h, w
+ img_sqrt_size = sqrt(iw * ih)
+ for region in regions:
+ cur_rect = region['rect']
+ w, h = cur_rect[2], cur_rect[3]
+ prop_sqrt_size = sqrt(w * h)
+ if w >= 32 and h >= 32 and min_ratio <= (prop_sqrt_size / img_sqrt_size) <= max_ratio:
+ rects.append(cur_rect)
+ return rects
+
+
+def filter_ratio_size(regions, image_size, r=2, min_ratio=0.1, max_ratio=1.0):
+ rects = []
+ ih, iw = image_size # get from ss_props label, which is h, w
+ img_sqrt_size = sqrt(iw * ih)
+ for region in regions:
+ cur_rect = region['rect']
+ w, h = cur_rect[2], cur_rect[3]
+ prop_sqrt_size = sqrt(w * h)
+ if w >= 32 and h >= 32 and (1.0 / r) <= (w / h) <= r and min_ratio <= (prop_sqrt_size / img_sqrt_size) <= max_ratio:
+ rects.append(cur_rect)
+ return rects
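+
+
+# Minimal usage sketch (illustrative numbers only): for a 375x500 image,
+#   regions = [{'rect': (10, 20, 200, 150), 'labels': [3], 'size': 21500}, ...]
+#   filter_ratio_size(regions, (375, 500), r=3, min_ratio=0.3, max_ratio=0.8)
+# keeps rects whose sides are >= 32 px, whose aspect ratio lies in [1/3, 3], and whose
+# sqrt-area is between 30% and 80% of the image's sqrt-area.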
diff --git a/selective_search/generate_imagenet_ss_proposals.py b/selective_search/generate_imagenet_ss_proposals.py
new file mode 100644
index 0000000..ef7814e
--- /dev/null
+++ b/selective_search/generate_imagenet_ss_proposals.py
@@ -0,0 +1,71 @@
+# --------------------------------------------------------
+# SoCo
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Yue Gao
+# --------------------------------------------------------
+
+
+import multiprocessing as mp
+import os
+import pickle
+
+import numpy as np
+import PIL.Image as Image
+
+from selective_search import selective_search
+
+imagenet_root = './imagenet_root'
+imagenet_root_proposals = './imagenet_root_proposals_mp'
+
+split = 'train'
+scale = 300
+min_size = 100
+
+processes_num = 48
+class_names = sorted(os.listdir(os.path.join(imagenet_root, split)))
+classes_num = len(class_names)
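+# Ceil-divide the classes across worker processes; the last workers may get fewer
+# (or zero) classes, which the bounds check inside process_one_class handles.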
+classes_per_process = classes_num // processes_num + 1
+
+source_path = os.path.join(imagenet_root, split)
+target_path = os.path.join(imagenet_root_proposals, split)
+
+
+def process_one_class(process_id, classes_per_process, class_names, source_path, target_path):
+ print(f"Process id: {process_id} started")
+ for i in range(process_id*classes_per_process, process_id*classes_per_process + classes_per_process):
+ if i >= len(class_names):
+ break
+ class_name = class_names[i]
+ filenames = sorted(os.listdir(os.path.join(source_path, class_name)))
+ os.makedirs(os.path.join(target_path, class_name), exist_ok=True)
+ for filename in filenames:
+ base_filename = os.path.splitext(filename)[0]
+ img_path = os.path.join(source_path, class_name, filename)
+ img = np.array(Image.open(img_path).convert('RGB'))
+
+ img_with_lbl, regions, bboxs = selective_search(img, scale=scale, sigma=0.9, min_size=min_size)
+
+ region_label = img_with_lbl[:, :, 3]
+ cur_img_proposal = {}
+ cur_img_proposal['label'] = region_label
+ cur_img_proposal['regions'] = regions
+
+ cur_img_pro_path = os.path.join(target_path, class_name, base_filename+'.pkl')
+
+ with open(cur_img_pro_path, 'wb') as f:
+ pickle.dump(cur_img_proposal, f)
+ print("Process ", process_id, "processed class:", class_name)
+
+
+processes = [mp.Process(target=process_one_class,
+ args=(process_id, classes_per_process, class_names, source_path, target_path))
+ for process_id in range(processes_num)]
+
+# Run processes
+for p in processes:
+ p.start()
+
+# Wait for all worker processes to finish
+for p in processes:
+ p.join()
diff --git a/selective_search/selective_search.py b/selective_search/selective_search.py
new file mode 100644
index 0000000..4c8f62e
--- /dev/null
+++ b/selective_search/selective_search.py
@@ -0,0 +1,319 @@
+# -*- coding: utf-8 -*-
+from __future__ import division
+
+import numpy
+import skimage.color
+import skimage.feature
+import skimage.io
+import skimage.segmentation
+import skimage.transform
+import skimage.util
+
+# "Selective Search for Object Recognition" by J.R.R. Uijlings et al.
+#
+# - Modified version with LBP extractor for texture vectorization
+
+
+def _generate_segments(im_orig, scale, sigma, min_size):
+ """
+ segment smallest regions by the algorithm of Felzenswalb and
+ Huttenlocher
+ """
+
+ # segment the image with the Felzenszwalb-Huttenlocher graph-based algorithm
+ im_mask = skimage.segmentation.felzenszwalb(
+ skimage.util.img_as_float(im_orig), scale=scale, sigma=sigma,
+ min_size=min_size)
+
+ # merge mask channel to the image as a 4th channel
+ im_orig = numpy.append(
+ im_orig, numpy.zeros(im_orig.shape[:2])[:, :, numpy.newaxis], axis=2)
+ im_orig[:, :, 3] = im_mask
+
+ return im_orig
+
+
+def _sim_colour(r1, r2):
+ """
+ calculate the sum of histogram intersection of colour
+ """
+ return sum([min(a, b) for a, b in zip(r1["hist_c"], r2["hist_c"])])
+
+
+def _sim_texture(r1, r2):
+ """
+ calculate the sum of histogram intersection of texture
+ """
+ return sum([min(a, b) for a, b in zip(r1["hist_t"], r2["hist_t"])])
+
+
+def _sim_size(r1, r2, imsize):
+ """
+ calculate the size similarity over the image
+ """
+ return 1.0 - (r1["size"] + r2["size"]) / imsize
+
+
+def _sim_fill(r1, r2, imsize):
+ """
+ calculate the fill similarity over the image
+ """
+ bbsize = (
+ (max(r1["max_x"], r2["max_x"]) - min(r1["min_x"], r2["min_x"]))
+ * (max(r1["max_y"], r2["max_y"]) - min(r1["min_y"], r2["min_y"]))
+ )
+ return 1.0 - (bbsize - r1["size"] - r2["size"]) / imsize
+
+
+def _calc_sim(r1, r2, imsize):
+ return (_sim_colour(r1, r2) + _sim_texture(r1, r2)
+ + _sim_size(r1, r2, imsize) + _sim_fill(r1, r2, imsize))
+
+
+def _calc_colour_hist(img):
+ """
+ calculate colour histogram for each region
+
+ the size of output histogram will be BINS * COLOUR_CHANNELS(3)
+
+ number of bins is 25 as same as [uijlings_ijcv2013_draft.pdf]
+
+ extract HSV
+ """
+
+ BINS = 25
+ hist = numpy.array([])
+
+ for colour_channel in (0, 1, 2):
+
+ # extracting one colour channel
+ c = img[:, colour_channel]
+
+ # calculate histogram for each colour and join to the result
+ hist = numpy.concatenate(
+ [hist] + [numpy.histogram(c, BINS, (0.0, 255.0))[0]])
+
+ # L1 normalize
+ hist = hist / len(img)
+
+ return hist
+
+
+def _calc_texture_gradient(img):
+ """
+ calculate texture gradient for entire image
+
+ The original SelectiveSearch algorithm proposed Gaussian derivative
+ for 8 orientations, but we use LBP instead.
+
+ output will be [height(*)][width(*)]
+ """
+ ret = numpy.zeros((img.shape[0], img.shape[1], img.shape[2]))
+
+ for colour_channel in (0, 1, 2):
+ ret[:, :, colour_channel] = skimage.feature.local_binary_pattern(
+ img[:, :, colour_channel], 8, 1.0)
+
+ return ret
+
+
+def _calc_texture_hist(img):
+ """
+ calculate texture histogram for each region
+
+ calculate the histogram of gradient for each colours
+ the size of output histogram will be
+ BINS * ORIENTATIONS * COLOUR_CHANNELS(3)
+ """
+ BINS = 10
+
+ hist = numpy.array([])
+
+ for colour_channel in (0, 1, 2):
+
+ # mask by the colour channel
+ fd = img[:, colour_channel]
+
+ # calculate histogram for each orientation and concatenate them all
+ # and join to the result
+ hist = numpy.concatenate(
+ [hist] + [numpy.histogram(fd, BINS, (0.0, 255.0))[0]]) # there is a bug https://github.com/AlpacaDB/selectivesearch/issues/30
+
+ # L1 Normalize
+ hist = hist / len(img)
+
+ return hist
+
+
+def _extract_regions(img):
+
+ R = {}
+
+ # get hsv image
+ hsv = skimage.color.rgb2hsv(img[:, :, :3])
+
+ # pass 1: count pixel positions
+ for y, i in enumerate(img):
+
+ for x, (r, g, b, l) in enumerate(i):
+
+ # initialize a new region
+ if l not in R:
+ R[l] = {
+ "min_x": 0xffff, "min_y": 0xffff,
+ "max_x": 0, "max_y": 0, "labels": [l]}
+
+ # bounding box
+ if R[l]["min_x"] > x:
+ R[l]["min_x"] = x
+ if R[l]["min_y"] > y:
+ R[l]["min_y"] = y
+ if R[l]["max_x"] < x:
+ R[l]["max_x"] = x
+ if R[l]["max_y"] < y:
+ R[l]["max_y"] = y
+
+ # pass 2: calculate texture gradient
+ tex_grad = _calc_texture_gradient(img)
+
+ # pass 3: calculate colour, texture histogram of each region
+ for k, v in list(R.items()):
+ # k is the region label
+ # colour histogram
+ masked_pixels = hsv[:, :, :][img[:, :, 3] == k]
+ R[k]["size"] = len(masked_pixels / 4) # number of pixels in the region, why /4 ?, in the repo, the question is not answered
+ R[k]["hist_c"] = _calc_colour_hist(masked_pixels) # calculated on HSV space
+
+ # texture histogram
+ R[k]["hist_t"] = _calc_texture_hist(tex_grad[:, :][img[:, :, 3] == k]) # calculated on RGB space
+
+ return R
+
+
+def _extract_neighbors(regions):
+
+ def intersect(a, b):
+ if (a["min_x"] < b["min_x"] < a["max_x"]
+ and a["min_y"] < b["min_y"] < a["max_y"]) or (
+ a["min_x"] < b["max_x"] < a["max_x"]
+ and a["min_y"] < b["max_y"] < a["max_y"]) or (
+ a["min_x"] < b["min_x"] < a["max_x"]
+ and a["min_y"] < b["max_y"] < a["max_y"]) or (
+ a["min_x"] < b["max_x"] < a["max_x"]
+ and a["min_y"] < b["min_y"] < a["max_y"]):
+ return True
+ return False
+
+ R = list(regions.items())
+ neighbors = []
+ for cur, a in enumerate(R[:-1]):
+ for b in R[cur + 1:]:
+ if intersect(a[1], b[1]):
+ neighbors.append((a, b))
+
+ return neighbors
+
+
+def _merge_regions(r1, r2):
+ new_size = r1["size"] + r2["size"]
+ rt = {
+ "min_x": min(r1["min_x"], r2["min_x"]),
+ "min_y": min(r1["min_y"], r2["min_y"]),
+ "max_x": max(r1["max_x"], r2["max_x"]),
+ "max_y": max(r1["max_y"], r2["max_y"]),
+ "size": new_size,
+ "hist_c": (
+ r1["hist_c"] * r1["size"] + r2["hist_c"] * r2["size"]) / new_size,
+ "hist_t": (
+ r1["hist_t"] * r1["size"] + r2["hist_t"] * r2["size"]) / new_size,
+ "labels": r1["labels"] + r2["labels"]
+ }
+ return rt
+
+
+def selective_search(im_orig, scale=1.0, sigma=0.8, min_size=50):
+ '''Selective Search
+
+ Parameters
+ ----------
+ im_orig : ndarray
+ Input image
+ scale : int
+ Free parameter. Higher means larger clusters in felzenszwalb segmentation.
+ sigma : float
+ Width of Gaussian kernel for felzenszwalb segmentation.
+ min_size : int
+ Minimum component size for felzenszwalb segmentation.
+ Returns
+ -------
+ img : ndarray
+ image with region label
+ region label is stored in the 4th value of each pixel [r,g,b,(region)]
+ regions : array of dict
+ [
+ {
+ 'rect': (left, top, width, height),
+ 'labels': [...],
+ 'size': component_size
+ },
+ ...
+ ]
+ boxes : list of [left, top, width, height]
+ integer bounding boxes, one per entry in regions
+ '''
+ assert im_orig.shape[2] == 3, "3ch image is expected"
+
+ # load image and get smallest regions
+ # region label is stored in the 4th value of each pixel [r,g,b,(region)]
+ img = _generate_segments(im_orig, scale, sigma, min_size) # img first 3 channel values are in [0, 255]
+
+ if img is None:
+ return None, [], []
+
+ imsize = img.shape[0] * img.shape[1]
+ R = _extract_regions(img)
+
+ # extract neighboring information
+ neighbors = _extract_neighbors(R)
+
+ # calculate initial similarities
+ S = {}
+ for (ai, ar), (bi, br) in neighbors:
+ S[(ai, bi)] = _calc_sim(ar, br, imsize)
+
+ # hierarchal search
+ while S != {}:
+
+ # get highest similarity
+ i, j = sorted(S.items(), key=lambda i: i[1])[-1][0]
+
+ # merge corresponding regions
+ t = max(R.keys()) + 1.0
+ R[t] = _merge_regions(R[i], R[j])
+
+ # mark similarities for regions to be removed
+ key_to_delete = []
+ for k, v in list(S.items()):
+ if (i in k) or (j in k):
+ key_to_delete.append(k)
+
+ # remove old similarities of related regions
+ for k in key_to_delete:
+ del S[k]
+
+ # calculate similarity set with the new region
+ for k in [a for a in key_to_delete if a != (i, j)]:
+ n = k[1] if k[0] in (i, j) else k[0]
+ S[(t, n)] = _calc_sim(R[t], R[n], imsize)
+
+ regions = []
+ boxes = []
+ for k, r in list(R.items()):
+ regions.append({
+ 'rect': (
+ r['min_x'], r['min_y'],
+ r['max_x'] - r['min_x'], r['max_y'] - r['min_y']),
+ 'size': r['size'],
+ 'labels': r['labels']
+ })
+ boxes.append([int(r['min_x']), int(r['min_y']), int(r['max_x'] - r['min_x']), int(r['max_y'] - r['min_y'])])
+
+ return img, regions, boxes
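+
+
+# Minimal usage sketch (scale/sigma/min_size as used by generate_imagenet_ss_proposals.py;
+# 'example.jpg' is an illustrative path):
+#   img = numpy.array(PIL.Image.open('example.jpg').convert('RGB'))
+#   img_lbl, regions, boxes = selective_search(img, scale=300, sigma=0.9, min_size=100)
+# img_lbl[:, :, 3] holds the per-pixel region label, regions is the list of dicts documented
+# above, and boxes holds the corresponding integer [left, top, width, height] lists.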
diff --git a/tools/SoCo_C4_100ep.sh b/tools/SoCo_C4_100ep.sh
new file mode 100755
index 0000000..8f72c87
--- /dev/null
+++ b/tools/SoCo_C4_100ep.sh
@@ -0,0 +1,46 @@
+#!/bin/bash
+
+set -e
+set -x
+
+warmup=${1:-10}
+epochs=${2:-100}
+
+data_dir="./data/ImageNet-Zip"
+output_dir="./SoCo_output/SoCo_C4_100ep"
+
+master_addr=${MASTER_IP}
+master_port=28652
+
+
+python -m torch.distributed.launch --nproc_per_node=8 \
+ --nnodes ${OMPI_COMM_WORLD_SIZE} --node_rank ${OMPI_COMM_WORLD_RANK} --master_addr ${master_addr} --master_port ${master_port} \
+ SoCo/main_pretrain.py \
+ --data_dir ${data_dir} \
+ --crop 0.6 \
+ --base_lr 1.0 \
+ --optimizer lars \
+ --weight_decay 1e-5 \
+ --amp_opt_level O1 \
+ --ss_props \
+ --auto_resume \
+ --aug ImageAsymBboxCutout \
+ --zip --cache_mode no \
+ --arch resnet50 \
+ --model SoCo_C4 \
+ --warmup_epoch ${warmup} \
+ --epochs ${epochs} \
+ --output_dir ${output_dir} \
+ --save_freq 10 \
+ --batch_size 128 \
+ --contrast_momentum 0.99 \
+ --filter_strategy ratio3size0308post \
+ --select_strategy random \
+ --select_k 4 \
+ --output_size 7 \
+ --aligned \
+ --jitter_ratio 0.1 \
+ --padding_k 4 \
+ --cutout_prob 0.5 \
+ --cutout_ratio_min 0.1 \
+ --cutout_ratio_max 0.3
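+
+# Note: --filter_strategy ratio3size0308post refers to the proposal json produced by
+# selective_search/filter_ss_proposals_json.py followed by the no-proposal post-processing
+# step (train_ratio3size0308post.json); --select_strategy/--select_k presumably control how
+# many of those proposals are sampled per image during pretraining.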
diff --git a/tools/SoCo_C4_100ep_linear.sh b/tools/SoCo_C4_100ep_linear.sh
new file mode 100755
index 0000000..ede931c
--- /dev/null
+++ b/tools/SoCo_C4_100ep_linear.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+set -e
+set -x
+
+warmup=${1:-10}
+epochs=${2:-100}
+
+data_dir="./data/ImageNet-Zip"
+output_dir="./SoCo_output/SoCo_C4_100ep"
+
+master_addr=${MASTER_IP}
+master_port=28652
+
+
+python -m torch.distributed.launch --master_port ${master_port} --nproc_per_node=8 \
+ SoCo/main_linear.py \
+ --data_dir ${data_dir} \
+ --zip --cache_mode part \
+ --arch resnet50 \
+ --output_dir ${output_dir}_linear_eval \
+ --pretrained_model ${output_dir}/current.pth \
+ --save_freq 10 \
+ --auto_resume
diff --git a/tools/SoCo_C4_400ep.sh b/tools/SoCo_C4_400ep.sh
new file mode 100755
index 0000000..17859a9
--- /dev/null
+++ b/tools/SoCo_C4_400ep.sh
@@ -0,0 +1,46 @@
+#!/bin/bash
+
+set -e
+set -x
+
+warmup=${1:-10}
+epochs=${2:-400}
+
+data_dir="./data/ImageNet-Zip"
+output_dir="./SoCo_output/SoCo_C4_400ep"
+
+master_addr=${MASTER_IP}
+master_port=28652
+
+python -m torch.distributed.launch --nproc_per_node=8 \
+ --nnodes ${OMPI_COMM_WORLD_SIZE} --node_rank ${OMPI_COMM_WORLD_RANK} --master_addr ${master_addr} --master_port ${master_port} \
+ SoCo/main_pretrain.py \
+ --data_dir ${data_dir} \
+ --crop 0.6 \
+ --base_lr 1.0 \
+ --optimizer lars \
+ --weight_decay 1e-5 \
+ --amp_opt_level O1 \
+ --ss_props \
+ --auto_resume \
+ --aug ImageAsymBboxCutout \
+ --zip --cache_mode no \
+ --arch resnet50 \
+ --model SoCo_C4 \
+ --warmup_epoch ${warmup} \
+ --epochs ${epochs} \
+ --output_dir ${output_dir} \
+ --save_freq 10 \
+ --batch_size 128 \
+ --contrast_momentum 0.99 \
+ --filter_strategy ratio3size0308post \
+ --select_strategy random \
+ --select_k 4 \
+ --output_size 7 \
+ --aligned \
+ --jitter_ratio 0.1 \
+ --padding_k 4 \
+ --cutout_prob 0.5 \
+ --cutout_ratio_min 0.1 \
+ --cutout_ratio_max 0.3
+
diff --git a/tools/SoCo_C4_400ep_linear.sh b/tools/SoCo_C4_400ep_linear.sh
new file mode 100755
index 0000000..9457962
--- /dev/null
+++ b/tools/SoCo_C4_400ep_linear.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+set -e
+set -x
+
+warmup=${1:-10}
+epochs=${2:-400}
+
+data_dir="./data/ImageNet-Zip"
+output_dir="./self_det_output/SoCo_C4_400ep"
+
+master_addr=${MASTER_IP}
+master_port=28652
+
+
+python -m torch.distributed.launch --master_port ${master_port} --nproc_per_node=8 \
+ SoCo/main_linear.py \
+ --data_dir ${data_dir} \
+ --zip --cache_mode part \
+ --arch resnet50 \
+ --output_dir ${output_dir}_linear_eval \
+ --pretrained_model ${output_dir}/current.pth \
+ --save_freq 10 \
+ --auto_resume
diff --git a/tools/SoCo_FPN_100ep.sh b/tools/SoCo_FPN_100ep.sh
new file mode 100755
index 0000000..8c2f22e
--- /dev/null
+++ b/tools/SoCo_FPN_100ep.sh
@@ -0,0 +1,53 @@
+#!/bin/bash
+
+set -e
+set -x
+
+warmup=${1:-10}
+epochs=${2:-100}
+
+data_dir="./data/ImageNet-Zip"
+output_dir="./SoCo_output/SoCo_FPN_100ep"
+
+master_addr=${MASTER_IP}
+master_port=28652
+
+python -m torch.distributed.launch --nproc_per_node=8 \
+ --nnodes ${OMPI_COMM_WORLD_SIZE} --node_rank ${OMPI_COMM_WORLD_RANK} --master_addr ${master_addr} --master_port ${master_port} \
+ SoCo/main_pretrain.py \
+ --data_dir ${data_dir} \
+ --crop 0.5 \
+ --base_lr 1.0 \
+ --optimizer lars \
+ --weight_decay 1e-5 \
+ --amp_opt_level O1 \
+ --ss_props \
+ --auto_resume \
+ --aug ImageAsymBboxAwareMultiJitter1Cutout \
+ --zip --cache_mode no \
+ --arch resnet50 \
+ --model SoCo_FPN \
+ --warmup_epoch ${warmup} \
+ --epochs ${epochs} \
+ --output_dir ${output_dir} \
+ --save_freq 1 \
+ --batch_size 128 \
+ --contrast_momentum 0.99 \
+ --filter_strategy ratio3size0308post \
+ --select_strategy random \
+ --select_k 4 \
+ --output_size 7 \
+ --aligned \
+ --jitter_prob 0.5 \
+ --jitter_ratio 0.1 \
+ --padding_k 4 \
+ --start_level 0 \
+ --num_outs 4 \
+ --add_extra_convs 0 \
+ --extra_convs_on_inputs 0 \
+ --relu_before_extra_convs 0 \
+ --aware_start 0 \
+ --aware_end 4 \
+ --cutout_prob 0.5 \
+ --cutout_ratio_min 0.1 \
+ --cutout_ratio_max 0.3
diff --git a/tools/SoCo_FPN_100ep_linear.sh b/tools/SoCo_FPN_100ep_linear.sh
new file mode 100755
index 0000000..de06bf2
--- /dev/null
+++ b/tools/SoCo_FPN_100ep_linear.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+set -e
+set -x
+
+warmup=${1:-10}
+epochs=${2:-100}
+
+data_dir="./data/ImageNet-Zip"
+output_dir="./SoCo_output/SoCo_FPN_100ep"
+
+master_addr=${MASTER_IP}
+master_port=28652
+
+
+python -m torch.distributed.launch --master_port ${master_port} --nproc_per_node=8 \
+ SoCo/main_linear.py \
+ --data_dir ${data_dir} \
+ --zip --cache_mode part \
+ --arch resnet50 \
+ --output_dir ${output_dir}_linear_eval \
+ --pretrained_model ${output_dir}/current.pth \
+ --save_freq 10 \
+ --auto_resume
diff --git a/tools/SoCo_FPN_400ep.sh b/tools/SoCo_FPN_400ep.sh
new file mode 100755
index 0000000..10121a6
--- /dev/null
+++ b/tools/SoCo_FPN_400ep.sh
@@ -0,0 +1,61 @@
+#!/bin/bash
+
+set -e
+set -x
+
+warmup=${1:-10}
+epochs=${2:-400}
+
+data_dir="./data/ImageNet-Zip"
+output_dir="./SoCo_output/SoCo_FPN_400ep"
+
+master_addr=${MASTER_IP}
+master_port=28652
+
+python -m torch.distributed.launch --nproc_per_node=8 \
+ --nnodes ${OMPI_COMM_WORLD_SIZE} --node_rank ${OMPI_COMM_WORLD_RANK} --master_addr ${master_addr} --master_port ${master_port} \
+ SoCo/main_pretrain.py \
+ --data_dir ${data_dir} \
+ --crop 0.5 \
+ --base_lr 1.0 \
+ --optimizer lars \
+ --weight_decay 1e-5 \
+ --amp_opt_level O1 \
+ --ss_props \
+ --auto_resume \
+ --aug ImageAsymBboxAwareMultiJitter1 \
+ --zip --cache_mode no \
+ --arch resnet50 \
+ --model SoCo_FPN \
+ --warmup_epoch ${warmup} \
+ --epochs ${epochs} \
+ --output_dir ${output_dir} \
+ --save_freq 1 \
+ --batch_size 128 \
+ --contrast_momentum 0.99 \
+ --filter_strategy ratio3size0308post \
+ --select_strategy random \
+ --select_k 4 \
+ --output_size 7 \
+ --aligned \
+ --jitter_prob 0.5 \
+ --jitter_ratio 0.1 \
+ --padding_k 4 \
+ --start_level 0 \
+ --num_outs 4 \
+ --add_extra_convs 0 \
+ --extra_convs_on_inputs 0 \
+ --relu_before_extra_convs 0 \
+ --aware_start 0 \
+ --aware_end 4 \
+ --cutout_prob 0.5 \
+ --cutout_ratio_min 0.1 \
+ --cutout_ratio_max 0.3
+
+# python -m torch.distributed.launch --master_port 12348 --nproc_per_node=4 \
+# main_linear.py \
+# --data_dir ${data_dir} \
+# --zip --cache_mode part \
+# --arch resnet50 \
+# --output_dir ${output_dir}/eval \
+# --pretrained_model ${output_dir}/current.pth \
diff --git a/tools/SoCo_FPN_400ep_linear.sh b/tools/SoCo_FPN_400ep_linear.sh
new file mode 100755
index 0000000..12111dd
--- /dev/null
+++ b/tools/SoCo_FPN_400ep_linear.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+set -e
+set -x
+
+warmup=${1:-10}
+epochs=${2:-400}
+
+data_dir="./data/ImageNet-Zip"
+output_dir="./SoCo_output/SoCo_FPN_400ep"
+
+master_addr=${MASTER_IP}
+master_port=28652
+
+
+python -m torch.distributed.launch --master_port ${master_port} --nproc_per_node=8 \
+ SoCo/main_linear.py \
+ --data_dir ${data_dir} \
+ --zip --cache_mode part \
+ --arch resnet50 \
+ --output_dir ${output_dir}_linear_eval \
+ --pretrained_model ${output_dir}/current.pth \
+ --save_freq 10 \
+ --auto_resume
diff --git a/tools/SoCo_FPN_Star_400ep.sh b/tools/SoCo_FPN_Star_400ep.sh
new file mode 100755
index 0000000..fea5df5
--- /dev/null
+++ b/tools/SoCo_FPN_Star_400ep.sh
@@ -0,0 +1,55 @@
+#!/bin/bash
+
+set -e
+set -x
+
+warmup=${1:-10}
+epochs=${2:-400}
+
+data_dir="./data/ImageNet-Zip"
+output_dir="./SoCo_output/SoCo_FPN_Star_400ep"
+
+master_addr=${MASTER_IP}
+master_port=28652
+
+python -m torch.distributed.launch --nproc_per_node=8 \
+ --nnodes ${OMPI_COMM_WORLD_SIZE} --node_rank ${OMPI_COMM_WORLD_RANK} --master_addr ${master_addr} --master_port ${master_port} \
+ SoCo/main_pretrain.py \
+ --data_dir ${data_dir} \
+ --crop 0.5 \
+ --base_lr 1.0 \
+ --optimizer lars \
+ --weight_decay 1e-5 \
+ --amp_opt_level O1 \
+ --ss_props \
+ --auto_resume \
+ --aug ImageAsymBboxAwareMulti3ResizeExtraJitter1 \
+ --zip --cache_mode no \
+ --arch resnet50 \
+ --model SoCo_FPN_Star \
+ --warmup_epoch ${warmup} \
+ --epochs ${epochs} \
+ --output_dir ${output_dir} \
+ --save_freq 1 \
+ --batch_size 128 \
+ --contrast_momentum 0.99 \
+ --filter_strategy ratio3size0308post \
+ --select_strategy random \
+ --select_k 4 \
+ --output_size 7 \
+ --aligned \
+ --jitter_prob 0.5 \
+ --jitter_ratio 0.1 \
+ --padding_k 4 \
+ --start_level 0 \
+ --num_outs 4 \
+ --add_extra_convs 0 \
+ --extra_convs_on_inputs 0 \
+ --relu_before_extra_convs 0 \
+ --aware_start 0 \
+ --aware_end 4 \
+ --cutout_prob 0.5 \
+ --cutout_ratio_min 0.1 \
+ --cutout_ratio_max 0.3 \
+ --image3_size 192 \
+ --image4_size 112
diff --git a/tools/SoCo_FPN_Star_400ep_linear.sh b/tools/SoCo_FPN_Star_400ep_linear.sh
new file mode 100755
index 0000000..e03b80e
--- /dev/null
+++ b/tools/SoCo_FPN_Star_400ep_linear.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+set -e
+set -x
+
+warmup=${1:-10}
+epochs=${2:-400}
+
+data_dir="./data/ImageNet-Zip"
+output_dir="./SoCo_output/SoCo_FPN_Star_400ep"
+
+master_addr=${MASTER_IP}
+master_port=28652
+
+
+python -m torch.distributed.launch --master_port ${master_port} --nproc_per_node=8 \
+ SoCo/main_linear.py \
+ --data_dir ${data_dir} \
+ --zip --cache_mode part \
+ --arch resnet50 \
+ --output_dir ${output_dir}_linear_eval \
+ --pretrained_model ${output_dir}/current.pth \
+ --save_freq 10 \
+ --auto_resume