diff --git a/.circleci/config.yml b/.circleci/config.yml new file mode 100644 index 0000000..4c23622 --- /dev/null +++ b/.circleci/config.yml @@ -0,0 +1,28 @@ +version: 2.1 + +jobs: + python_lint: + docker: + - image: circleci/python:3.7 + steps: + - checkout + - run: + command: | + pip install --user --progress-bar off flake8 typing + flake8 . + test: + docker: + - image: circleci/python:3.7 + steps: + - checkout + - run: + command: | + pip install --user --progress-bar off pytest + pip install --user --progress-bar off torch torchvision + pip install --user --progress-bar off timm==0.3.2 + pytest . + +workflows: + build: + jobs: + - python_lint diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..071fa33 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +*.swp +**/__pycache__/** +imnet_resnet50_scratch/timm_temp/ +.dumbo.json +checkpoints/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..83ce9a6 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Shoufa Chen + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..12b90a1 --- /dev/null +++ b/README.md @@ -0,0 +1,91 @@ +# A MLP-like Architecture for Dense Prediction + +[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) +![Python 3.8](https://img.shields.io/badge/python-3.8-green.svg) + + + +
+ + + +
+ +# Updates + +- (22/07/2021) Initial release. + + + +# Model Zoo + +We provide CycleMLP models pretrained on ImageNet 2012. + +| Model | Parameters | FLOPs | Top 1 Acc. | Download | +| :------------------- | :--------- | :------- | :--------- | :------- | +| CycleMLP-B1 | 15M | 2.1G | 78.9% | | +| CycleMLP-B2 | 27M | 3.9G | 81.6% | | +| CycleMLP-B3 | 38M | 6.9G | 82.4% | | +| CycleMLP-B4 | 52M | 10.1G | 83.0% | | +| CycleMLP-B5 | 76M | 12.3G | 83.2% | | + + +# Usage + + +## Install + +- PyTorch 1.7.0+ and torchvision 0.8.1+ +- [timm](https://github.com/rwightman/pytorch-image-models/tree/c2ba229d995c33aaaf20e00a5686b4dc857044be): +``` +pip install 'git+https://github.com/rwightman/pytorch-image-models@c2ba229d995c33aaaf20e00a5686b4dc857044be' + +or + +git clone https://github.com/rwightman/pytorch-image-models +cd pytorch-image-models +git checkout c2ba229d995c33aaaf20e00a5686b4dc857044be +pip install -e . +``` +- fvcore (optional, for FLOPs calculation) +- mmcv, mmdetection, mmsegmentation (optional) + +## Data preparation + +Download and extract ImageNet train and val images from http://image-net.org/. +The directory structure is: + +``` +│path/to/imagenet/ +├──train/ +│ ├── n01440764 +│ │ ├── n01440764_10026.JPEG +│ │ ├── n01440764_10027.JPEG +│ │ ├── ...... +│ ├── ...... +├──val/ +│ ├── n01440764 +│ │ ├── ILSVRC2012_val_00000293.JPEG +│ │ ├── ILSVRC2012_val_00002138.JPEG +│ │ ├── ...... +│ ├── ...... +``` + +## Evaluation +To evaluate a pre-trained CycleMLP-B5 on ImageNet val with a single GPU run: +``` +python main.py --eval --model CycleMLP_B5 --resume path/to/CycleMLP_B5.pth --data-path /path/to/imagenet +``` + + +## Training + +To train CycleMLP-B5 on ImageNet on a single node with 8 gpus for 300 epochs run: +``` +python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model CycleMLP_B5 --batch-size 128 --data-path /path/to/imagenet --output_dir /path/to/save +``` + + +# License + +CycleMLP is released under MIT License. 
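For quick sanity checks outside of `main.py`, the backbone definitions added below in `cycle_mlp.py` register themselves with timm, so a model can also be built directly with `create_model`. The sketch below is illustrative rather than part of this patch: the checkpoint path is a placeholder for a file from the Model Zoo above, and the `'model'` key mirrors what `main.py` saves, with a fallback in case a released checkpoint stores the raw state dict.

```
# Minimal usage sketch (not part of the patch): build a registered CycleMLP
# model through timm, as main.py does, and run a dummy forward pass.
import torch
from timm.models import create_model

import cycle_mlp  # noqa: F401 -- importing registers CycleMLP_B1..B5 with timm

model = create_model('CycleMLP_B5', pretrained=False, num_classes=1000)

# Hypothetical checkpoint path; point this at a downloaded Model Zoo file.
state = torch.load('path/to/CycleMLP_B5.pth', map_location='cpu')
model.load_state_dict(state.get('model', state), strict=False)

model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # ImageNet-sized input
print(logits.shape)  # torch.Size([1, 1000])
```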
diff --git a/cycle_mlp.py b/cycle_mlp.py new file mode 100644 index 0000000..3dd16c3 --- /dev/null +++ b/cycle_mlp.py @@ -0,0 +1,484 @@ +import os +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.layers import DropPath, trunc_normal_ +from timm.models.registry import register_model +from timm.models.layers.helpers import to_2tuple + +import math +from torch import Tensor +from torch.nn import init +from torch.nn.modules.utils import _pair +from torchvision.ops.deform_conv import deform_conv2d as deform_conv2d_tv + +try: + from mmseg.models.builder import BACKBONES as seg_BACKBONES + from mmseg.utils import get_root_logger + from semantic.custom_fun import load_checkpoint + has_mmseg = True +except ImportError: + print('Please Install mmsegmentation first for semantic segmentation.') + has_mmseg = False + +try: + from mmdet.models.builder import BACKBONES as det_BACKBONES + from mmdet.utils import get_root_logger + has_mmdet = True +except ImportError: + print('Please Install mmdetection first for object detection and instance segmentation.') + has_mmdet = False + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .96, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head', + **kwargs + } + +default_cfgs = { + 'cycle_S': _cfg(crop_pct=0.9), + 'cycle_M': _cfg(crop_pct=0.9), + 'cycle_L': _cfg(crop_pct=0.875), +} + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class CycleFC(nn.Module): + """ + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size, # re-defined kernel_size, represent the spatial area of staircase FC + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + ): + super(CycleFC, self).__init__() + + if in_channels % groups != 0: + raise ValueError('in_channels must be divisible by groups') + if out_channels % groups != 0: + raise ValueError('out_channels must be divisible by groups') + if stride != 1: + raise ValueError('stride must be 1') + if padding != 0: + raise ValueError('padding must be 0') + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + + self.weight = nn.Parameter(torch.empty(out_channels, in_channels // groups, 1, 1)) # kernel size == 1 + + if bias: + self.bias = nn.Parameter(torch.empty(out_channels)) + else: + self.register_parameter('bias', None) + self.register_buffer('offset', self.gen_offset()) + + self.reset_parameters() + + def reset_parameters(self) -> None: + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + if self.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + init.uniform_(self.bias, -bound, bound) + + def gen_offset(self): + """ + 
offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width, + out_height, out_width]): offsets to be applied for each position in the + convolution kernel. + """ + offset = torch.empty(1, self.in_channels*2, 1, 1) + start_idx = (self.kernel_size[0] * self.kernel_size[1]) // 2 + assert self.kernel_size[0] == 1 or self.kernel_size[1] == 1, self.kernel_size + for i in range(self.in_channels): + if self.kernel_size[0] == 1: + offset[0, 2 * i + 0, 0, 0] = 0 + offset[0, 2 * i + 1, 0, 0] = (i + start_idx) % self.kernel_size[1] - (self.kernel_size[1] // 2) + else: + offset[0, 2 * i + 0, 0, 0] = (i + start_idx) % self.kernel_size[0] - (self.kernel_size[0] // 2) + offset[0, 2 * i + 1, 0, 0] = 0 + return offset + + def forward(self, input: Tensor) -> Tensor: + """ + Args: + input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor + """ + B, C, H, W = input.size() + return deform_conv2d_tv(input, self.offset.expand(B, -1, H, W), self.weight, self.bias, stride=self.stride, + padding=self.padding, dilation=self.dilation) + + def extra_repr(self) -> str: + s = self.__class__.__name__ + '(' + s += '{in_channels}' + s += ', {out_channels}' + s += ', kernel_size={kernel_size}' + s += ', stride={stride}' + s += ', padding={padding}' if self.padding != (0, 0) else '' + s += ', dilation={dilation}' if self.dilation != (1, 1) else '' + s += ', groups={groups}' if self.groups != 1 else '' + s += ', bias=False' if self.bias is None else '' + s += ')' + return s.format(**self.__dict__) + + +class CycleMLP(nn.Module): + def __init__(self, dim, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.mlp_c = nn.Linear(dim, dim, bias=qkv_bias) + + self.sfc_h = CycleFC(dim, dim, (1, 3), 1, 0) + self.sfc_w = CycleFC(dim, dim, (3, 1), 1, 0) + + self.reweight = Mlp(dim, dim // 4, dim * 3) + + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, H, W, C = x.shape + h = self.sfc_h(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + w = self.sfc_w(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + c = self.mlp_c(x) + + a = (h + w + c).permute(0, 3, 1, 2).flatten(2).mean(2) + a = self.reweight(a).reshape(B, C, 3).permute(2, 0, 1).softmax(dim=0).unsqueeze(2).unsqueeze(2) + + x = h * a[0] + w * a[1] + c * a[2] + + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class CycleBlock(nn.Module): + + def __init__(self, dim, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip_lam=1.0, mlp_fn=CycleMLP): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = mlp_fn(dim, qkv_bias=qkv_bias, qk_scale=None, attn_drop=attn_drop) + + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) + self.skip_lam = skip_lam + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) / self.skip_lam + x = x + self.drop_path(self.mlp(self.norm2(x))) / self.skip_lam + return x + + +class PatchEmbedOverlapping(nn.Module): + """ 2D Image to Patch Embedding with overlapping + """ + def __init__(self, patch_size=16, stride=16, padding=0, in_chans=3, embed_dim=768, norm_layer=None, groups=1): + super().__init__() + patch_size = to_2tuple(patch_size) + stride = to_2tuple(stride) + padding = to_2tuple(padding) + self.patch_size = patch_size + # remove image_size in model init to support dynamic image size + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=padding, groups=groups) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.proj(x) + return x + + +class Downsample(nn.Module): + """ Downsample transition stage + """ + def __init__(self, in_embed_dim, out_embed_dim, patch_size): + super().__init__() + assert patch_size == 2, patch_size + self.proj = nn.Conv2d(in_embed_dim, out_embed_dim, kernel_size=(3, 3), stride=(2, 2), padding=1) + + def forward(self, x): + x = x.permute(0, 3, 1, 2) + x = self.proj(x) # B, C, H, W + x = x.permute(0, 2, 3, 1) + return x + + +def basic_blocks(dim, index, layers, mlp_ratio=3., qkv_bias=False, qk_scale=None, attn_drop=0., + drop_path_rate=0., skip_lam=1.0, mlp_fn=CycleMLP, **kwargs): + blocks = [] + + for block_idx in range(layers[index]): + block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1) + blocks.append(CycleBlock(dim, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, drop_path=block_dpr, skip_lam=skip_lam, mlp_fn=mlp_fn)) + blocks = nn.Sequential(*blocks) + + return blocks + + +class CycleNet(nn.Module): + """ CycleMLP Network """ + def __init__(self, layers, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dims=None, transitions=None, segment_dim=None, mlp_ratios=None, skip_lam=1.0, + qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., + norm_layer=nn.LayerNorm, mlp_fn=CycleMLP, fork_feat=False): + + super().__init__() + if not fork_feat: + self.num_classes = num_classes + self.fork_feat = fork_feat + + self.patch_embed = PatchEmbedOverlapping(patch_size=7, stride=4, padding=2, in_chans=3, embed_dim=embed_dims[0]) + + network = [] + for i in range(len(layers)): + stage = basic_blocks(embed_dims[i], i, layers, mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias, + qk_scale=qk_scale, attn_drop=attn_drop_rate, drop_path_rate=drop_path_rate, + norm_layer=norm_layer, skip_lam=skip_lam, mlp_fn=mlp_fn) + network.append(stage) + if i >= len(layers) - 1: + break + if transitions[i] or embed_dims[i] != embed_dims[i+1]: + patch_size = 2 if transitions[i] else 1 + network.append(Downsample(embed_dims[i], embed_dims[i+1], patch_size)) + + self.network = nn.ModuleList(network) + + if self.fork_feat: + # add a norm layer for each output + self.out_indices = [0, 2, 4, 6] + for i_emb, i_layer in enumerate(self.out_indices): + if i_emb == 0 and os.environ.get('FORK_LAST3', None): + # TODO: more elegant way + """For RetinaNet, `start_level=1`. The first norm layer will not used. 
+ cmd: `FORK_LAST3=1 python -m torch.distributed.launch ...` + """ + layer = nn.Identity() + else: + layer = norm_layer(embed_dims[i_emb]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + else: + # Classifier head + self.norm = norm_layer(embed_dims[-1]) + self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() + self.apply(self.cls_init_weights) + + def cls_init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, CycleFC): + trunc_normal_(m.weight, std=.02) + nn.init.constant_(m.bias, 0) + + def init_weights(self, pretrained=None): + """ mmseg or mmdet `init_weight` """ + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_embeddings(self, x): + x = self.patch_embed(x) + # B,C,H,W-> B,H,W,C + x = x.permute(0, 2, 3, 1) + return x + + def forward_tokens(self, x): + outs = [] + for idx, block in enumerate(self.network): + x = block(x) + if self.fork_feat and idx in self.out_indices: + norm_layer = getattr(self, f'norm{idx}') + x_out = norm_layer(x) + outs.append(x_out.permute(0, 3, 1, 2).contiguous()) + if self.fork_feat: + return outs + + B, H, W, C = x.shape + x = x.reshape(B, -1, C) + return x + + def forward(self, x): + x = self.forward_embeddings(x) + # B, H, W, C -> B, N, C + x = self.forward_tokens(x) + if self.fork_feat: + return x + + x = self.norm(x) + cls_out = self.head(x.mean(1)) + return cls_out + + +@register_model +def CycleMLP_B1(pretrained=False, **kwargs): + transitions = [True, True, True, True] + layers = [2, 2, 4, 2] + mlp_ratios = [4, 4, 4, 4] + embed_dims = [64, 128, 320, 512] + model = CycleNet(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions, + mlp_ratios=mlp_ratios, mlp_fn=CycleMLP, **kwargs) + model.default_cfg = default_cfgs['cycle_S'] + return model + + +@register_model +def CycleMLP_B2(pretrained=False, **kwargs): + transitions = [True, True, True, True] + layers = [2, 3, 10, 3] + mlp_ratios = [4, 4, 4, 4] + embed_dims = [64, 128, 320, 512] + model = CycleNet(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions, + mlp_ratios=mlp_ratios, mlp_fn=CycleMLP, **kwargs) + model.default_cfg = default_cfgs['cycle_S'] + return model + + +@register_model +def CycleMLP_B3(pretrained=False, **kwargs): + transitions = [True, True, True, True] + layers = [3, 4, 18, 3] + mlp_ratios = [8, 8, 4, 4] + embed_dims = [64, 128, 320, 512] + model = CycleNet(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions, + mlp_ratios=mlp_ratios, mlp_fn=CycleMLP, **kwargs) + model.default_cfg = default_cfgs['cycle_M'] + return model + + +@register_model +def CycleMLP_B4(pretrained=False, **kwargs): + transitions = [True, True, True, True] + layers = [3, 8, 27, 3] + mlp_ratios = [8, 8, 4, 4] + embed_dims = [64, 128, 320, 512] + model = CycleNet(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions, + mlp_ratios=mlp_ratios, mlp_fn=CycleMLP, **kwargs) + model.default_cfg = default_cfgs['cycle_L'] + return 
model + + +@register_model +def CycleMLP_B5(pretrained=False, **kwargs): + transitions = [True, True, True, True] + layers = [3, 4, 24, 3] + mlp_ratios = [4, 4, 4, 4] + embed_dims = [96, 192, 384, 768] + model = CycleNet(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions, + mlp_ratios=mlp_ratios, mlp_fn=CycleMLP, **kwargs) + model.default_cfg = default_cfgs['cycle_L'] + return model + + +if has_mmseg and has_mmdet: + # For dense prediction tasks only + @seg_BACKBONES.register_module() + @det_BACKBONES.register_module() + class CycleMLP_B1_feat(CycleNet): + def __init__(self, **kwargs): + transitions = [True, True, True, True] + layers = [2, 2, 4, 2] + mlp_ratios = [4, 4, 4, 4] + embed_dims = [64, 128, 320, 512] + super(CycleMLP_B1_feat, self).__init__(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions, + mlp_ratios=mlp_ratios, mlp_fn=CycleMLP, fork_feat=True) + + @seg_BACKBONES.register_module() + @det_BACKBONES.register_module() + class CycleMLP_B2_feat(CycleNet): + def __init__(self, **kwargs): + transitions = [True, True, True, True] + layers = [2, 3, 10, 3] + mlp_ratios = [4, 4, 4, 4] + embed_dims = [64, 128, 320, 512] + super(CycleMLP_B2_feat, self).__init__(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions, + mlp_ratios=mlp_ratios, mlp_fn=CycleMLP, fork_feat=True) + + + @seg_BACKBONES.register_module() + @det_BACKBONES.register_module() + class CycleMLP_B3_feat(CycleNet): + def __init__(self, **kwargs): + transitions = [True, True, True, True] + layers = [3, 4, 18, 3] + mlp_ratios = [8, 8, 4, 4] + embed_dims = [64, 128, 320, 512] + super(CycleMLP_B3_feat, self).__init__(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions, + mlp_ratios=mlp_ratios, mlp_fn=CycleMLP, fork_feat=True) + + @seg_BACKBONES.register_module() + @det_BACKBONES.register_module() + class CycleMLP_B4_feat(CycleNet): + def __init__(self, **kwargs): + transitions = [True, True, True, True] + layers = [3, 8, 27, 3] + mlp_ratios = [8, 8, 4, 4] + embed_dims = [64, 128, 320, 512] + super(CycleMLP_B4_feat, self).__init__(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions, + mlp_ratios=mlp_ratios, mlp_fn=CycleMLP, fork_feat=True) + + + @seg_BACKBONES.register_module() + @det_BACKBONES.register_module() + class CycleMLP_B5_feat(CycleNet): + def __init__(self, **kwargs): + transitions = [True, True, True, True] + layers = [3, 4, 24, 3] + mlp_ratios = [4, 4, 4, 4] + embed_dims = [96, 192, 384, 768] + super(CycleMLP_B5_feat, self).__init__(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions, + mlp_ratios=mlp_ratios, mlp_fn=CycleMLP, fork_feat=True) diff --git a/datasets.py b/datasets.py new file mode 100644 index 0000000..1fd8f3c --- /dev/null +++ b/datasets.py @@ -0,0 +1,114 @@ +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. 
+import os +import json + +from torchvision import datasets, transforms +from torchvision.datasets.folder import ImageFolder, default_loader + +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.data import create_transform + +from mcloader import ClassificationDataset + + +class INatDataset(ImageFolder): + def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, + category='name', loader=default_loader): + self.transform = transform + self.loader = loader + self.target_transform = target_transform + self.year = year + # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name'] + path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json') + with open(path_json) as json_file: + data = json.load(json_file) + + with open(os.path.join(root, 'categories.json')) as json_file: + data_catg = json.load(json_file) + + path_json_for_targeter = os.path.join(root, f"train{year}.json") + + with open(path_json_for_targeter) as json_file: + data_for_targeter = json.load(json_file) + + targeter = {} + indexer = 0 + for elem in data_for_targeter['annotations']: + king = [] + king.append(data_catg[int(elem['category_id'])][category]) + if king[0] not in targeter.keys(): + targeter[king[0]] = indexer + indexer += 1 + self.nb_classes = len(targeter) + + self.samples = [] + for elem in data['images']: + cut = elem['file_name'].split('/') + target_current = int(cut[2]) + path_current = os.path.join(root, cut[0], cut[2], cut[3]) + + categors = data_catg[target_current] + target_current_true = targeter[categors[category]] + self.samples.append((path_current, target_current_true)) + + # __getitem__ and __len__ inherited from ImageFolder + + +def build_dataset(is_train, args): + transform = build_transform(is_train, args) + + if args.data_set == 'CIFAR': + dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform) + nb_classes = 100 + elif args.data_set == 'IMNET': + if not args.mcloader: + root = os.path.join(args.data_path, 'train' if is_train else 'val') + dataset = datasets.ImageFolder(root, transform=transform) + else: + dataset = ClassificationDataset(args.data_path, 'train' if is_train else 'val', pipeline=transform) + nb_classes = 1000 + elif args.data_set == 'INAT': + dataset = INatDataset(args.data_path, train=is_train, year=2018, + category=args.inat_category, transform=transform) + nb_classes = dataset.nb_classes + elif args.data_set == 'INAT19': + dataset = INatDataset(args.data_path, train=is_train, year=2019, + category=args.inat_category, transform=transform) + nb_classes = dataset.nb_classes + + return dataset, nb_classes + + +def build_transform(is_train, args): + resize_im = args.input_size > 32 + if is_train: + # this should always dispatch to transforms_imagenet_train + transform = create_transform( + input_size=args.input_size, + is_training=True, + color_jitter=args.color_jitter, + auto_augment=args.aa, + interpolation=args.train_interpolation, + re_prob=args.reprob, + re_mode=args.remode, + re_count=args.recount, + ) + if not resize_im: + # replace RandomResizedCropAndInterpolation with + # RandomCrop + transform.transforms[0] = transforms.RandomCrop( + args.input_size, padding=4) + return transform + + t = [] + if resize_im: + size = int((256 / 224) * args.input_size) + t.append( + transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 
224 images + ) + t.append(transforms.CenterCrop(args.input_size)) + + t.append(transforms.ToTensor()) + t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) + return transforms.Compose(t) diff --git a/detection/README.md b/detection/README.md new file mode 100644 index 0000000..10aab2d --- /dev/null +++ b/detection/README.md @@ -0,0 +1,26 @@ +# CycleMLP for COCO Object Detection and Instance Segmentation + +1. Install mmcv +``` +# cuda10.1 pytorch1.7.1 +pip install mmcv-full==1.3.5 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.7.1/index.html + +or + +git clone https://github.com/open-mmlab/mmcv.git +git checkout v1.3.5 +MMCV_WITH_OPS=1 pip install -e . + +# for pytorch>=1.9: +# https://github.com/open-mmlab/mmcv/pull/1138 +``` + +2. Install mmdetection v2.11.0 + +``` +pip install 'git+https://github.com/open-mmlab/mmdetection@2894516bacf9ff82c3bc6d6970019d0890a993aa' +``` + + +Code and configs are coming soon. + diff --git a/engine.py b/engine.py new file mode 100644 index 0000000..aa4a987 --- /dev/null +++ b/engine.py @@ -0,0 +1,102 @@ +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. +""" +Train and eval functions used in main.py +""" +import math +import sys +from typing import Iterable, Optional + +import torch + +from timm.data import Mixup +from timm.utils import accuracy, ModelEma + +from losses import DistillationLoss +import utils + + +def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, + data_loader: Iterable, optimizer: torch.optim.Optimizer, + device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, + model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, + set_training_mode=True, amp_autocast=None): + model.train(set_training_mode) + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) + header = 'Epoch: [{}]'.format(epoch) + print_freq = 10 + + for samples, targets in metric_logger.log_every(data_loader, print_freq, header): + samples = samples.to(device, non_blocking=True) + targets = targets.to(device, non_blocking=True) + + if mixup_fn is not None: + samples, targets = mixup_fn(samples, targets) + + with amp_autocast(): + outputs = model(samples) + loss = criterion(samples, outputs, targets) + + loss_value = loss.item() + + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + sys.exit(1) + + optimizer.zero_grad() + + # this attribute is added by timm on one optimizer (adahessian) + is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order + if loss_scaler is not None: + loss_scaler(loss, optimizer, clip_grad=max_norm, + parameters=model.parameters(), create_graph=is_second_order) + else: + loss.backward(create_graph=is_second_order) + if max_norm is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) + optimizer.step() + + torch.cuda.synchronize() + if model_ema is not None: + model_ema.update(model) + + metric_logger.update(loss=loss_value) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +@torch.no_grad() +def evaluate(data_loader, model, device, amp_autocast=None): + criterion = torch.nn.CrossEntropyLoss() + + metric_logger = utils.MetricLogger(delimiter=" ") + header = 'Test:' + 
+ # switch to evaluation mode + model.eval() + + for images, target in metric_logger.log_every(data_loader, 10, header): + images = images.to(device, non_blocking=True) + target = target.to(device, non_blocking=True) + + # compute output + with amp_autocast(): + output = model(images) + loss = criterion(output, target) + + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + + batch_size = images.shape[0] + metric_logger.update(loss=loss.item()) + metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) + metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' + .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) + + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} diff --git a/figures/flops.png b/figures/flops.png new file mode 100644 index 0000000..8da3ac7 Binary files /dev/null and b/figures/flops.png differ diff --git a/figures/teaser.png b/figures/teaser.png new file mode 100644 index 0000000..6c9707c Binary files /dev/null and b/figures/teaser.png differ diff --git a/hubconf.py b/hubconf.py new file mode 100644 index 0000000..eed6283 --- /dev/null +++ b/hubconf.py @@ -0,0 +1,5 @@ +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. +from models import * + +dependencies = ["torch", "torchvision", "timm"] diff --git a/losses.py b/losses.py new file mode 100644 index 0000000..ff7748f --- /dev/null +++ b/losses.py @@ -0,0 +1,64 @@ +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. +""" +Implements the knowledge distillation loss +""" +import torch +from torch.nn import functional as F + + +class DistillationLoss(torch.nn.Module): + """ + This module wraps a standard criterion and adds an extra knowledge distillation loss by + taking a teacher model prediction and using it as additional supervision. + """ + def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module, + distillation_type: str, alpha: float, tau: float): + super().__init__() + self.base_criterion = base_criterion + self.teacher_model = teacher_model + assert distillation_type in ['none', 'soft', 'hard'] + self.distillation_type = distillation_type + self.alpha = alpha + self.tau = tau + + def forward(self, inputs, outputs, labels): + """ + Args: + inputs: The original inputs that are feed to the teacher model + outputs: the outputs of the model to be trained. 
It is expected to be + either a Tensor, or a Tuple[Tensor, Tensor], with the original output + in the first position and the distillation predictions as the second output + labels: the labels for the base criterion + """ + outputs_kd = None + if not isinstance(outputs, torch.Tensor): + # assume that the model outputs a tuple of [outputs, outputs_kd] + outputs, outputs_kd = outputs + base_loss = self.base_criterion(outputs, labels) + if self.distillation_type == 'none': + return base_loss + + if outputs_kd is None: + raise ValueError("When knowledge distillation is enabled, the model is " + "expected to return a Tuple[Tensor, Tensor] with the output of the " + "class_token and the dist_token") + # don't backprop throught the teacher + with torch.no_grad(): + teacher_outputs = self.teacher_model(inputs) + + if self.distillation_type == 'soft': + T = self.tau + # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 + # with slight modifications + distillation_loss = F.kl_div( + F.log_softmax(outputs_kd / T, dim=1), + F.log_softmax(teacher_outputs / T, dim=1), + reduction='sum', + log_target=True + ) * (T * T) / outputs_kd.numel() + elif self.distillation_type == 'hard': + distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1)) + + loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha + return loss diff --git a/main.py b/main.py new file mode 100644 index 0000000..e944daa --- /dev/null +++ b/main.py @@ -0,0 +1,507 @@ +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. +import argparse +import datetime +import numpy as np +import time +import torch +import torch.backends.cudnn as cudnn +import json +from contextlib import suppress + +from pathlib import Path + +from timm.data import Mixup +from timm.models import create_model +from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy +from timm.scheduler import create_scheduler +from timm.optim import create_optimizer +from timm.utils import NativeScaler, get_state_dict, ModelEma + +from datasets import build_dataset +from engine import train_one_epoch, evaluate +from losses import DistillationLoss +from samplers import RASampler +import cycle_mlp +import utils + +try: + from apex import amp + from apex.parallel import DistributedDataParallel as ApexDDP + from apex.parallel import convert_syncbn_model + from timm.utils import ApexScaler + has_apex = True +except ImportError: + has_apex = False + +has_native_amp = False +try: + if getattr(torch.cuda.amp, 'autocast') is not None: + has_native_amp = True +except AttributeError: + pass + +try: + from fvcore.nn import flop_count, parameter_count, FlopCountAnalysis, flop_count_table + from utils import sfc_flop_jit + has_fvcore = True +except ImportError: + has_fvcore = False + + +def get_args_parser(): + parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False) + parser.add_argument('--batch-size', default=64, type=int) + parser.add_argument('--epochs', default=300, type=int) + + # Model parameters + parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL', + help='Name of model to train') + parser.add_argument('--input-size', default=224, type=int, help='images input size') + + parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', + help='Dropout rate (default: 0.)') + parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', + help='Drop path rate (default: 0.1)') + + 
parser.add_argument('--model-ema', action='store_true') + parser.add_argument('--no-model-ema', action='store_false', dest='model_ema') + parser.set_defaults(model_ema=True) + parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='') + parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='') + + # Optimizer parameters + parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', + help='Optimizer (default: "adamw"') + parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', + help='Optimizer Epsilon (default: 1e-8)') + parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', + help='Optimizer Betas (default: None, use opt default)') + parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', + help='Clip gradient norm (default: None, no clipping)') + parser.add_argument('--momentum', type=float, default=0.9, metavar='M', + help='SGD momentum (default: 0.9)') + parser.add_argument('--weight-decay', type=float, default=0.05, + help='weight decay (default: 0.05)') + # Learning rate schedule parameters + parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', + help='LR scheduler (default: "cosine"') + parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', + help='learning rate (default: 5e-4)') + parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', + help='learning rate noise on/off epoch percentages') + parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', + help='learning rate noise limit percent (default: 0.67)') + parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', + help='learning rate noise std-dev (default: 1.0)') + parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', + help='warmup learning rate (default: 1e-6)') + parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', + help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') + + parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', + help='epoch interval to decay LR') + parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', + help='epochs to warmup LR, if scheduler supports') + parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', + help='epochs to cooldown LR at min_lr, after cyclic schedule ends') + parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', + help='patience epochs for Plateau LR scheduler (default: 10') + parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', + help='LR decay rate (default: 0.1)') + + # Augmentation parameters + parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', + help='Color jitter factor (default: 0.4)') + parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', + help='Use AutoAugment policy. "v0" or "original". 
" + \ + "(default: rand-m9-mstd0.5-inc1)'), + parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') + parser.add_argument('--train-interpolation', type=str, default='bicubic', + help='Training interpolation (random, bilinear, bicubic default: "bicubic")') + + parser.add_argument('--repeated-aug', action='store_true') + parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug') + parser.set_defaults(repeated_aug=True) + + # * Random Erase params + parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', + help='Random erase prob (default: 0.25)') + parser.add_argument('--remode', type=str, default='pixel', + help='Random erase mode (default: "pixel")') + parser.add_argument('--recount', type=int, default=1, + help='Random erase count (default: 1)') + parser.add_argument('--resplit', action='store_true', default=False, + help='Do not random erase first (clean) augmentation split') + + # * Mixup params + parser.add_argument('--mixup', type=float, default=0.8, + help='mixup alpha, mixup enabled if > 0. (default: 0.8)') + parser.add_argument('--cutmix', type=float, default=1.0, + help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') + parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, + help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') + parser.add_argument('--mixup-prob', type=float, default=1.0, + help='Probability of performing mixup or cutmix when either/both is enabled') + parser.add_argument('--mixup-switch-prob', type=float, default=0.5, + help='Probability of switching to cutmix when both mixup and cutmix enabled') + parser.add_argument('--mixup-mode', type=str, default='batch', + help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') + + # Distillation parameters + parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL', + help='Name of teacher model to train (default: "regnety_160"') + parser.add_argument('--teacher-path', type=str, default='') + parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="") + parser.add_argument('--distillation-alpha', default=0.5, type=float, help="") + parser.add_argument('--distillation-tau', default=1.0, type=float, help="") + + # * Finetuning params + parser.add_argument('--finetune', default='', help='finetune from checkpoint') + + # Dataset parameters + parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str, + help='dataset path') + parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'], + type=str, help='Image Net dataset path') + parser.add_argument('--mcloader', action='store_true', default=False, help='whether use mcloader') + parser.add_argument('--inat-category', default='name', + choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'], + type=str, help='semantic granularity') + + parser.add_argument('--output_dir', default='', + help='path where to save, empty for no saving') + parser.add_argument('--device', default='cuda', + help='device to use for training / testing') + parser.add_argument('--seed', default=0, type=int) + parser.add_argument('--resume', default='', help='resume from checkpoint') + parser.add_argument('--start_epoch', default=0, type=int, metavar='N', + help='start epoch') + parser.add_argument('--eval', action='store_true', help='Perform evaluation only') + parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation') + parser.add_argument('--num_workers', default=10, type=int) + parser.add_argument('--pin-mem', action='store_true', + help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') + parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem', + help='') + parser.set_defaults(pin_mem=True) + + # distributed training parameters + parser.add_argument('--world_size', default=1, type=int, + help='number of distributed processes') + parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') + + # custom parameters + parser.add_argument('--flops', action='store_true', help='whether calculate FLOPs of the model') + parser.add_argument('--no_amp', action='store_true', help='not using amp') + return parser + + +def main(args): + utils.init_distributed_mode(args) + + print(args) + + if args.distillation_type != 'none' and args.finetune and not args.eval: + raise NotImplementedError("Finetuning with distillation not yet supported") + + device = torch.device(args.device) + + # resolve AMP arguments based on PyTorch / Apex availability + use_amp = None + if not args.no_amp: # args.amp: Default use AMP + # `--amp` chooses native amp before apex (APEX ver not actively maintained) + if has_native_amp: + args.native_amp = True + args.apex_amp = False + elif has_apex: + args.native_amp = False + args.apex_amp = True + else: + raise ValueError("Warning: Neither APEX or native Torch AMP is available, using float32." 
+ "Install NVIDA apex or upgrade to PyTorch 1.6") + else: + args.apex_amp = False + args.native_amp = False + if args.apex_amp and has_apex: + use_amp = 'apex' + elif args.native_amp and has_native_amp: + use_amp = 'native' + elif args.apex_amp or args.native_amp: + print ("Warning: Neither APEX or native Torch AMP is available, using float32. " + "Install NVIDA apex or upgrade to PyTorch 1.6") + + # fix the seed for reproducibility + seed = args.seed + utils.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + # random.seed(seed) + + cudnn.benchmark = True + + dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) + dataset_val, _ = build_dataset(is_train=False, args=args) + + if True: # args.distributed: + num_tasks = utils.get_world_size() + global_rank = utils.get_rank() + if args.repeated_aug: + sampler_train = RASampler( + dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True + ) + else: + sampler_train = torch.utils.data.DistributedSampler( + dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True + ) + if args.dist_eval: + if len(dataset_val) % num_tasks != 0: + print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' + 'This will slightly alter validation results as extra duplicate entries are added to achieve ' + 'equal num of samples per-process.') + sampler_val = torch.utils.data.DistributedSampler( + dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) + else: + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + else: + sampler_train = torch.utils.data.RandomSampler(dataset_train) + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + + data_loader_train = torch.utils.data.DataLoader( + dataset_train, sampler=sampler_train, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=args.pin_mem, + drop_last=True, + ) + + data_loader_val = torch.utils.data.DataLoader( + dataset_val, sampler=sampler_val, + batch_size=int(1.5 * args.batch_size), + num_workers=args.num_workers, + pin_memory=args.pin_mem, + drop_last=False + ) + + mixup_fn = None + mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None + if mixup_active: + mixup_fn = Mixup( + mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, + prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, + label_smoothing=args.smoothing, num_classes=args.nb_classes) + + print(f"Creating model: {args.model}") + model = create_model( + args.model, + pretrained=False, + num_classes=args.nb_classes, + drop_rate=args.drop, + drop_path_rate=args.drop_path, + drop_block_rate=None, + ) + + if args.flops: + if not has_fvcore: + print("Please install fvcore first for FLOPs calculation.") + else: + # Set model to evaluation mode for analysis. 
+ model_mode = model.training + model.eval() + fake_input = torch.rand(1, 3, 224, 224) + flops_dict, *_ = flop_count(model, fake_input, + supported_ops={"torchvision::deform_conv2d": sfc_flop_jit}) + count = sum(flops_dict.values()) + model.train(model_mode) + print('=' * 30) + print("fvcore MAdds: {:.3f} G".format(count)) + + if args.finetune: + if args.finetune.startswith('https'): + checkpoint = torch.hub.load_state_dict_from_url( + args.finetune, map_location='cpu', check_hash=True) + else: + checkpoint = torch.load(args.finetune, map_location='cpu') + + checkpoint_model = checkpoint['model'] + state_dict = model.state_dict() + for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']: + if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: + print(f"Removing key {k} from pretrained checkpoint") + del checkpoint_model[k] + + # interpolate position embedding + pos_embed_checkpoint = checkpoint_model['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model['pos_embed'] = new_pos_embed + + model.load_state_dict(checkpoint_model, strict=False) + + model.to(device) + + linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0 + args.lr = linear_scaled_lr + optimizer = create_optimizer(args, model) + + # setup automatic mixed-precision (AMP) loss scaling and op casting + amp_autocast = suppress # do nothing + loss_scaler = None + if use_amp == 'apex': + model, optimizer = amp.initialize(model, optimizer, opt_level='O1') + loss_scaler = ApexScaler() + print('Using NVIDIA APEX AMP. Training in mixed precision.') + elif use_amp == 'native': + amp_autocast = torch.cuda.amp.autocast + loss_scaler = NativeScaler() + print('Using native Torch AMP. Training in mixed precision.') + else: + print('AMP not enabled. 
Training in float32.') + + model_ema = None + if args.model_ema: + # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper + model_ema = ModelEma( + model, + decay=args.model_ema_decay, + device='cpu' if args.model_ema_force_cpu else '', + resume='') + + model_without_ddp = model + if args.distributed: + if has_apex and use_amp != 'native': + # Apex DDP preferred unless native amp is activated + model = ApexDDP(model, delay_allreduce=True) + else: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + model_without_ddp = model.module + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + print('number of params:', n_parameters) + print('=' * 30) + + lr_scheduler, _ = create_scheduler(args, optimizer) + + criterion = LabelSmoothingCrossEntropy() + + if args.mixup > 0.: + # smoothing is handled with mixup label transform + criterion = SoftTargetCrossEntropy() + elif args.smoothing: + criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) + else: + criterion = torch.nn.CrossEntropyLoss() + + teacher_model = None + if args.distillation_type != 'none': + assert args.teacher_path, 'need to specify teacher-path when using distillation' + print(f"Creating teacher model: {args.teacher_model}") + teacher_model = create_model( + args.teacher_model, + pretrained=False, + num_classes=args.nb_classes, + global_pool='avg', + ) + if args.teacher_path.startswith('https'): + checkpoint = torch.hub.load_state_dict_from_url( + args.teacher_path, map_location='cpu', check_hash=True) + else: + checkpoint = torch.load(args.teacher_path, map_location='cpu') + teacher_model.load_state_dict(checkpoint['model']) + teacher_model.to(device) + teacher_model.eval() + + # wrap the criterion in our custom DistillationLoss, which + # just dispatches to the original criterion if args.distillation_type is 'none' + criterion = DistillationLoss( + criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau + ) + + output_dir = Path(args.output_dir) + if args.resume: + if args.resume.startswith('https'): + checkpoint = torch.hub.load_state_dict_from_url( + args.resume, map_location='cpu', check_hash=True) + else: + checkpoint = torch.load(args.resume, map_location='cpu') + model_without_ddp.load_state_dict(checkpoint['model']) + if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: + optimizer.load_state_dict(checkpoint['optimizer']) + lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + args.start_epoch = checkpoint['epoch'] + 1 + if args.model_ema: + utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) + if 'scaler' in checkpoint: + loss_scaler.load_state_dict(checkpoint['scaler']) + + if args.eval: + test_stats = evaluate(data_loader_val, model, device, amp_autocast=amp_autocast) + print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") + return + + print(f"Start training for {args.epochs} epochs") + start_time = time.time() + max_accuracy = 0.0 + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + data_loader_train.sampler.set_epoch(epoch) + + train_stats = train_one_epoch( + model, criterion, data_loader_train, + optimizer, device, epoch, loss_scaler, + args.clip_grad, model_ema, mixup_fn, + set_training_mode=args.finetune == '', # keep in eval mode during finetuning + amp_autocast=amp_autocast, + ) + + lr_scheduler.step(epoch) + if 
args.output_dir: + checkpoint_paths = [output_dir / 'checkpoint.pth'] + for checkpoint_path in checkpoint_paths: + utils.save_on_master({ + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'lr_scheduler': lr_scheduler.state_dict(), + 'epoch': epoch, + 'model_ema': get_state_dict(model_ema), + 'scaler': loss_scaler.state_dict() if loss_scaler is not None else None, + 'args': args, + }, checkpoint_path) + + test_stats = evaluate(data_loader_val, model, device, amp_autocast=amp_autocast) + print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") + max_accuracy = max(max_accuracy, test_stats["acc1"]) + print(f'Max accuracy: {max_accuracy:.2f}%') + + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + **{f'test_{k}': v for k, v in test_stats.items()}, + 'epoch': epoch, + 'n_parameters': n_parameters} + + if args.output_dir and utils.is_main_process(): + with (output_dir / "log.txt").open("a") as f: + f.write(json.dumps(log_stats) + "\n") + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()]) + args = parser.parse_args() + if args.output_dir: + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + main(args) diff --git a/mcloader/__init__.py b/mcloader/__init__.py new file mode 100644 index 0000000..a7f558a --- /dev/null +++ b/mcloader/__init__.py @@ -0,0 +1,2 @@ +from .classification import ClassificationDataset +from .data_prefetcher import DataPrefetcher diff --git a/mcloader/classification.py b/mcloader/classification.py new file mode 100644 index 0000000..41d7da3 --- /dev/null +++ b/mcloader/classification.py @@ -0,0 +1,25 @@ +import os +from torch.utils.data import Dataset +from .imagenet import ImageNet + + +class ClassificationDataset(Dataset): + """Dataset for classification. 
+ """ + + def __init__(self, data_root, split, pipeline=None): + self.data_source = ImageNet(root=os.path.join(data_root, split), + list_file=os.path.join(data_root, 'meta', '{}.txt'.format(split)), + memcached=True, + mclient_path='/mnt/lustre/share/memcached_client') + self.pipeline = pipeline + + def __len__(self): + return self.data_source.get_length() + + def __getitem__(self, idx): + img, target = self.data_source.get_sample(idx) + if self.pipeline is not None: + img = self.pipeline(img) + + return img, target diff --git a/mcloader/data_prefetcher.py b/mcloader/data_prefetcher.py new file mode 100644 index 0000000..b57c306 --- /dev/null +++ b/mcloader/data_prefetcher.py @@ -0,0 +1,28 @@ +import torch + + +class DataPrefetcher: + def __init__(self, loader): + self.loader = iter(loader) + self.stream = torch.cuda.Stream() + self.preload() + + def preload(self): + try: + self.next_input, self.next_target = next(self.loader) + except StopIteration: + self.next_input = None + self.next_target = None + return + + with torch.cuda.stream(self.stream): + self.next_input = self.next_input.cuda(non_blocking=True) + self.next_target = self.next_target.cuda(non_blocking=True) + + def next(self): + torch.cuda.current_stream().wait_stream(self.stream) + input = self.next_input + target = self.next_target + if input is not None: + self.preload() + return input, target diff --git a/mcloader/image_list.py b/mcloader/image_list.py new file mode 100644 index 0000000..8156e0d --- /dev/null +++ b/mcloader/image_list.py @@ -0,0 +1,44 @@ +import os +from PIL import Image + +from .mcloader import McLoader + + +class ImageList(object): + + def __init__(self, root, list_file, memcached=False, mclient_path=None): + with open(list_file, 'r') as f: + lines = f.readlines() + self.has_labels = len(lines[0].split()) == 2 + if self.has_labels: + self.fns, self.labels = zip(*[l.strip().split() for l in lines]) + self.labels = [int(l) for l in self.labels] + else: + self.fns = [l.strip() for l in lines] + self.fns = [os.path.join(root, fn) for fn in self.fns] + self.memcached = memcached + self.mclient_path = mclient_path + self.initialized = False + + def _init_memcached(self): + if not self.initialized: + assert self.mclient_path is not None + self.mc_loader = McLoader(self.mclient_path) + self.initialized = True + + def get_length(self): + return len(self.fns) + + def get_sample(self, idx): + if self.memcached: + self._init_memcached() + if self.memcached: + img = self.mc_loader(self.fns[idx]) + else: + img = Image.open(self.fns[idx]) + img = img.convert('RGB') + if self.has_labels: + target = self.labels[idx] + return img, target + else: + return img diff --git a/mcloader/imagenet.py b/mcloader/imagenet.py new file mode 100644 index 0000000..62b3bb6 --- /dev/null +++ b/mcloader/imagenet.py @@ -0,0 +1,8 @@ +from .image_list import ImageList + + +class ImageNet(ImageList): + + def __init__(self, root, list_file, memcached, mclient_path): + super(ImageNet, self).__init__( + root, list_file, memcached, mclient_path) diff --git a/mcloader/mcloader.py b/mcloader/mcloader.py new file mode 100644 index 0000000..c448a6f --- /dev/null +++ b/mcloader/mcloader.py @@ -0,0 +1,37 @@ +import io +from PIL import Image +try: + import mc +except ImportError as E: + print("Please install mc first\n", + "cp /mnt/lustre/share/pymc/py3/mc.so ~/anaconda3/envs/deit/lib/python3.8/site-packages/") + + +def pil_loader(img_str): + buff = io.BytesIO(img_str) + return Image.open(buff) + + +class McLoader(object): + + def __init__(self, 
mclient_path): + assert mclient_path is not None, \ + "Please specify 'data_mclient_path' in the config." + self.mclient_path = mclient_path + server_list_config_file = "{}/server_list.conf".format( + self.mclient_path) + client_config_file = "{}/client.conf".format(self.mclient_path) + self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file, + client_config_file) + + def __call__(self, fn): + try: + img_value = mc.pyvector() + self.mclient.Get(fn, img_value) + img_value_str = mc.ConvertBuffer(img_value) + img = pil_loader(img_value_str) + except: + print('Read image failed ({})'.format(fn)) + return None + else: + return img diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..b7bd9bd --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +torch==1.7.0 +torchvision==0.8.1 +timm==0.3.2 diff --git a/run_with_submitit.py b/run_with_submitit.py new file mode 100644 index 0000000..a0da744 --- /dev/null +++ b/run_with_submitit.py @@ -0,0 +1,126 @@ +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. +""" +A script to run multinode training with submitit. +""" +import argparse +import os +import uuid +from pathlib import Path + +import main as classification +import submitit + + +def parse_args(): + classification_parser = classification.get_args_parser() + parser = argparse.ArgumentParser("Submitit for DeiT", parents=[classification_parser]) + parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") + parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") + parser.add_argument("--timeout", default=2800, type=int, help="Duration of the job") + parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.") + + parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") + parser.add_argument("--use_volta32", action='store_true', help="Big models? Use this") + parser.add_argument('--comment', default="", type=str, + help='Comment to pass to scheduler, e.g. priority message') + return parser.parse_args() + + +def get_shared_folder() -> Path: + user = os.getenv("USER") + if Path("/checkpoint/").is_dir(): + p = Path(f"/checkpoint/{user}/experiments") + p.mkdir(exist_ok=True) + return p + raise RuntimeError("No shared folder available") + + +def get_init_file(): + # Init file must not exist, but it's parent dir must exist. 
+ os.makedirs(str(get_shared_folder()), exist_ok=True) + init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" + if init_file.exists(): + os.remove(str(init_file)) + return init_file + + +class Trainer(object): + def __init__(self, args): + self.args = args + + def __call__(self): + import main as classification + + self._setup_gpu_args() + classification.main(self.args) + + def checkpoint(self): + import os + import submitit + + self.args.dist_url = get_init_file().as_uri() + checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") + if os.path.exists(checkpoint_file): + self.args.resume = checkpoint_file + print("Requeuing ", self.args) + empty_trainer = type(self)(self.args) + return submitit.helpers.DelayedSubmission(empty_trainer) + + def _setup_gpu_args(self): + import submitit + from pathlib import Path + + job_env = submitit.JobEnvironment() + self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) + self.args.gpu = job_env.local_rank + self.args.rank = job_env.global_rank + self.args.world_size = job_env.num_tasks + print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") + + +def main(): + args = parse_args() + if args.job_dir == "": + args.job_dir = get_shared_folder() / "%j" + + # Note that the folder will depend on the job_id, to easily track experiments + executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) + + num_gpus_per_node = args.ngpus + nodes = args.nodes + timeout_min = args.timeout + + partition = args.partition + kwargs = {} + if args.use_volta32: + kwargs['slurm_constraint'] = 'volta32gb' + if args.comment: + kwargs['slurm_comment'] = args.comment + + executor.update_parameters( + mem_gb=40 * num_gpus_per_node, + gpus_per_node=num_gpus_per_node, + tasks_per_node=num_gpus_per_node, # one task per GPU + cpus_per_task=10, + nodes=nodes, + timeout_min=timeout_min, # max is 60 * 72 + # Below are cluster dependent parameters + slurm_partition=partition, + slurm_signal_delay_s=120, + **kwargs + ) + + executor.update_parameters(name="deit") + + args.dist_url = get_init_file().as_uri() + args.output_dir = args.job_dir + + trainer = Trainer(args) + job = executor.submit(trainer) + + print("Submitted job_id:", job.job_id) + + +if __name__ == "__main__": + main() diff --git a/samplers.py b/samplers.py new file mode 100644 index 0000000..e67dc6d --- /dev/null +++ b/samplers.py @@ -0,0 +1,59 @@ +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. +import torch +import torch.distributed as dist +import math + + +class RASampler(torch.utils.data.Sampler): + """Sampler that restricts data loading to a subset of the dataset for distributed, + with repeated augmentation. 
+    It ensures that each augmented version of a sample will be visible to a
+    different process (GPU).
+    Heavily based on torch.utils.data.DistributedSampler
+    """
+
+    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
+        if num_replicas is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = dist.get_world_size()
+        if rank is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = dist.get_rank()
+        self.dataset = dataset
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
+        self.total_size = self.num_samples * self.num_replicas
+        # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
+        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
+        self.shuffle = shuffle
+
+    def __iter__(self):
+        # deterministically shuffle based on epoch
+        g = torch.Generator()
+        g.manual_seed(self.epoch)
+        if self.shuffle:
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        else:
+            indices = list(range(len(self.dataset)))
+
+        # add extra samples to make it evenly divisible
+        indices = [ele for ele in indices for i in range(3)]
+        indices += indices[:(self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+
+        # subsample
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        return iter(indices[:self.num_selected_samples])
+
+    def __len__(self):
+        return self.num_selected_samples
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
diff --git a/semantic/README.md b/semantic/README.md
new file mode 100644
index 0000000..5d71198
--- /dev/null
+++ b/semantic/README.md
@@ -0,0 +1,28 @@
+# CycleMLP for Semantic Segmentation
+
+
+## Installation
+
+1. Install mmcv v1.3.5
+```
+# cuda10.1 pytorch1.7.1
+pip install mmcv-full==1.3.5 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.7.1/index.html
+
+or
+
+git clone https://github.com/open-mmlab/mmcv.git
+cd mmcv && git checkout v1.3.5
+MMCV_WITH_OPS=1 pip install -e .
+
+# for pytorch>=1.9:
+# https://github.com/open-mmlab/mmcv/pull/1138
+
+```
+
+2. Install mmsegmentation v0.13.0
+
+```
+pip install 'git+https://github.com/open-mmlab/mmsegmentation@f884489120c3b1af6506651e82458dd8a2bd10ec'
+```
+
+Code and configs are coming soon.
\ No newline at end of file
diff --git a/tox.ini b/tox.ini
new file mode 100644
index 0000000..5554a88
--- /dev/null
+++ b/tox.ini
@@ -0,0 +1,3 @@
+[flake8]
+max-line-length = 120
+ignore = F401,E402,F403,W503,W504
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..5ddafab
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,267 @@
+# Copyright (c) 2015-present, Facebook, Inc.
+# All rights reserved.
+"""
+Misc functions, including distributed helpers.
+
+Mostly copy-paste from torchvision references.
+"""
+import io
+import os
+import time
+from collections import defaultdict, deque
+import datetime
+
+import torch
+import torch.distributed as dist
+
+try:
+    from fvcore.nn.jit_handles import get_shape, conv_flop_count
+    from collections import Counter
+    import typing
+    from typing import Any, List
+except ImportError:
+    has_fvcore = False
+
+
+class SmoothedValue(object):
+    """Track a series of values and provide access to smoothed values over a
+    window or the global series average.
+ """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + log_msg = [ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ] + if torch.cuda.is_available(): + log_msg.append('max mem: {memory:.0f}') + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +def 
_load_checkpoint_for_ema(model_ema, checkpoint): + """ + Workaround for ModelEma._load_checkpoint to accept an already-loaded object + """ + mem_file = io.BytesIO() + torch.save(checkpoint, mem_file) + mem_file.seek(0) + model_ema._load_checkpoint(mem_file) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + for i in range(10): + try: + torch.save(*args, **kwargs) + break + except: + print("Saving model failed for {} times, will retry".format(i+1)) + time.sleep(30.0) + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}'.format( + args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +# copy-paste from +# https://github.com/facebookresearch/fvcore/blob/166a030e093013a934642ca3744592a2e3de5ea2/fvcore/nn/jit_handles.py#L143-L157 +# change input length assert +def sfc_flop_jit(inputs: List[Any], outputs: List[Any]) -> typing.Counter[str]: + """ + Count flops for cycle FC. + """ + x, w = inputs[:2] + x_shape, w_shape, out_shape = (get_shape(x), get_shape(w), get_shape(outputs[0])) + assert w_shape[-1] == 1 and w_shape[-2] == 1, w_shape + + # use a custom name instead of "_convolution" + return Counter({"conv": conv_flop_count(x_shape, w_shape, out_shape)})
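
Usage note (not part of the patch): `RASampler` and `DataPrefetcher` above are consumed by the training entry point. The sketch below shows one way they could be wired together, assuming `torch.distributed` is already initialised (see `init_distributed_mode`) and a CUDA device is available; the dataset path, transform, batch size, and worker count are placeholders.

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from samplers import RASampler
from mcloader.data_prefetcher import DataPrefetcher

# Placeholder dataset and transform; RASampler needs an initialised
# torch.distributed process group, DataPrefetcher needs a CUDA device.
dataset = datasets.ImageFolder('/path/to/imagenet/train',
                               transform=transforms.Compose([
                                   transforms.RandomResizedCrop(224),
                                   transforms.ToTensor(),
                               ]))
sampler = RASampler(dataset, shuffle=True)    # each sample repeated 3x per epoch
loader = DataLoader(dataset, batch_size=128, sampler=sampler,
                    num_workers=10, pin_memory=True, drop_last=True)

for epoch in range(300):
    sampler.set_epoch(epoch)                  # deterministic reshuffle per epoch
    prefetcher = DataPrefetcher(loader)       # overlaps host-to-GPU copies with compute
    images, targets = prefetcher.next()
    while images is not None:
        # forward / backward / optimizer step would go here
        images, targets = prefetcher.next()
```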
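A hedged sketch of how `sfc_flop_jit` could be registered with fvcore's `flop_count` (fvcore is the optional FLOPs dependency listed in the README). The operator key and the model import below are assumptions, not taken from this patch: the actual op name should be read off the traced graph of the model, and the module/model names should be adjusted to this repo's layout.

```python
import torch
from fvcore.nn import flop_count

from utils import sfc_flop_jit
# Assumed model entry point; adjust to wherever the CycleMLP models are defined.
from cycle_mlp import CycleMLP_B1

model = CycleMLP_B1().eval()
inputs = (torch.randn(1, 3, 224, 224),)

# "torchvision::deform_conv2d" is a guess at the op behind the cycle FC layer;
# replace it with the name that actually appears in the traced graph.
gflops, skipped = flop_count(model, inputs,
                             supported_ops={"torchvision::deform_conv2d": sfc_flop_jit})
print("GFLOPs:", sum(gflops.values()), "| skipped ops:", skipped)
```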