diff --git a/docs/source/models.rst b/docs/source/models.rst
index f1331d5baa9..4c65eac8135 100644
--- a/docs/source/models.rst
+++ b/docs/source/models.rst
@@ -41,6 +41,7 @@ architectures for image classification:
 - `EfficientNet`_
 - `RegNet`_
 - `VisionTransformer`_
+- `ConvNeXt`_
 
 You can construct a model with random weights by calling its constructor:
 
@@ -88,7 +89,7 @@ You can construct a model with random weights by calling its constructor:
    vit_b_32 = models.vit_b_32()
    vit_l_16 = models.vit_l_16()
    vit_l_32 = models.vit_l_32()
-   vit_h_14 = models.vit_h_14()
+   vit_h_14 = models.vit_h_14()
 
 We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
 These can be constructed by passing ``pretrained=True``:
@@ -248,6 +249,7 @@ vit_b_16 81.072 95.318
 vit_b_32 75.912 92.466
 vit_l_16 79.662 94.638
 vit_l_32 76.972 93.070
+convnext_tiny (prototype) 82.520 96.146
 ================================ ============= =============
 
@@ -266,6 +268,7 @@ vit_l_32 76.972 93.070
 .. _EfficientNet: https://arxiv.org/abs/1905.11946
 .. _RegNet: https://arxiv.org/abs/2003.13678
 .. _VisionTransformer: https://arxiv.org/abs/2010.11929
+.. _ConvNeXt: https://arxiv.org/abs/2201.03545
 
 .. currentmodule:: torchvision.models
diff --git a/references/classification/README.md b/references/classification/README.md
index 48b20a30242..0fb27eac7cc 100644
--- a/references/classification/README.md
+++ b/references/classification/README.md
@@ -197,6 +197,20 @@ Note that the above command corresponds to training on a single node with 8 GPUs
 For generating the pre-trained weights, we trained with 8 nodes, each with 8 GPUs (for a total of 64 GPUs),
 and `--batch_size 64`.
 
+
+### ConvNeXt
+```
+torchrun --nproc_per_node=8 train.py \
+--model convnext_tiny --batch-size 128 --opt adamw --lr 1e-3 --lr-scheduler cosineannealinglr \
+--lr-warmup-epochs 5 --lr-warmup-method linear --auto-augment ta_wide --epochs 600 --random-erase 0.1 \
+--label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 --weight-decay 0.05 --norm-weight-decay 0.0 \
+--train-crop-size 176 --model-ema --val-resize-size 236 --ra-sampler --ra-reps 4
+```
+
+Note that the above command corresponds to training on a single node with 8 GPUs.
+For generating the pre-trained weights, we trained with 2 nodes, each with 8 GPUs (for a total of 16 GPUs),
+and `--batch_size 64`.
+
 ## Mixed precision training
 
 Automatic Mixed Precision (AMP) training on GPU for PyTorch can be enabled with [torch.cuda.amp](https://pytorch.org/docs/stable/amp.html?highlight=amp#module-torch.cuda.amp).
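A quick usage sketch of the documented surface above (not part of the diff; it assumes this branch is installed and uses only names this diff adds):

```python
import torch
from torchvision.prototype import models

# Construct with random weights, as in the models.rst constructor list:
model = models.convnext_tiny()

# Construct with the pre-trained weights from the new weights enum;
# `pretrained=True` is also accepted via handle_legacy_interface:
weights = models.ConvNeXt_Tiny_Weights.ImageNet1K_V1
model = models.convnext_tiny(weights=weights).eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])
```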
diff --git a/test/expect/ModelTester.test_convnext_tiny_expect.pkl b/test/expect/ModelTester.test_convnext_tiny_expect.pkl
new file mode 100644
index 00000000000..c6fb873f12f
Binary files /dev/null and b/test/expect/ModelTester.test_convnext_tiny_expect.pkl differ
diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py
index 392517cb772..6fe16b0e757 100644
--- a/torchvision/ops/misc.py
+++ b/torchvision/ops/misc.py
@@ -131,7 +131,7 @@ def __init__(
         norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
         activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
         dilation: int = 1,
-        inplace: bool = True,
+        inplace: Optional[bool] = True,
         bias: Optional[bool] = None,
     ) -> None:
         if padding is None:
@@ -153,7 +153,8 @@ def __init__(
         if norm_layer is not None:
             layers.append(norm_layer(out_channels))
         if activation_layer is not None:
-            layers.append(activation_layer(inplace=inplace))
+            params = {} if inplace is None else {"inplace": inplace}
+            layers.append(activation_layer(**params))
         super().__init__(*layers)
         _log_api_usage_once(self)
         self.out_channels = out_channels
diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py
index bfa44ffa720..83e49908348 100644
--- a/torchvision/prototype/models/__init__.py
+++ b/torchvision/prototype/models/__init__.py
@@ -1,4 +1,5 @@
 from .alexnet import *
+from .convnext import *
 from .densenet import *
 from .efficientnet import *
 from .googlenet import *
diff --git a/torchvision/prototype/models/convnext.py b/torchvision/prototype/models/convnext.py
new file mode 100644
index 00000000000..788dcbc2cd1
--- /dev/null
+++ b/torchvision/prototype/models/convnext.py
@@ -0,0 +1,227 @@
+from functools import partial
+from typing import Any, Callable, List, Optional, Sequence
+
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+from torchvision.prototype.transforms import ImageNetEval
+from torchvision.transforms.functional import InterpolationMode
+
+from ...ops.misc import ConvNormActivation
+from ...ops.stochastic_depth import StochasticDepth
+from ...utils import _log_api_usage_once
+from ._api import WeightsEnum, Weights
+from ._meta import _IMAGENET_CATEGORIES
+from ._utils import handle_legacy_interface, _ovewrite_named_param
+
+
+__all__ = ["ConvNeXt", "ConvNeXt_Tiny_Weights", "convnext_tiny"]
+
+
+class LayerNorm2d(nn.LayerNorm):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        self.channels_last = kwargs.pop("channels_last", False)
+        super().__init__(*args, **kwargs)
+
+    def forward(self, x: Tensor) -> Tensor:
+        # TODO: Benchmark this against the approach described at https://github.com/pytorch/vision/pull/5197#discussion_r786251298
+        if not self.channels_last:
+            x = x.permute(0, 2, 3, 1)
+        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        if not self.channels_last:
+            x = x.permute(0, 3, 1, 2)
+        return x
+
+
+class CNBlock(nn.Module):
+    def __init__(
+        self, dim, layer_scale: float, stochastic_depth_prob: float, norm_layer: Callable[..., nn.Module]
+    ) -> None:
+        super().__init__()
+        self.block = nn.Sequential(
+            ConvNormActivation(
+                dim,
+                dim,
+                kernel_size=7,
+                groups=dim,
+                norm_layer=norm_layer,
+                activation_layer=None,
+                bias=True,
+            ),
+            ConvNormActivation(dim, 4 * dim, kernel_size=1, norm_layer=None, activation_layer=nn.GELU, inplace=None),
+            ConvNormActivation(
+                4 * dim,
+                dim,
+                kernel_size=1,
+                norm_layer=None,
+                activation_layer=None,
+            ),
+        )
+        self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
+        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
+
+    def forward(self, input: Tensor) -> Tensor:
+        result = self.layer_scale * self.block(input)
+        result = self.stochastic_depth(result)
+        result += input
+        return result
+
+
+class CNBlockConfig:
+    # Stores information listed at Section 3 of the ConvNeXt paper
+    def __init__(
+        self,
+        input_channels: int,
+        out_channels: Optional[int],
+        num_layers: int,
+    ) -> None:
+        self.input_channels = input_channels
+        self.out_channels = out_channels
+        self.num_layers = num_layers
+
+    def __repr__(self) -> str:
+        s = self.__class__.__name__ + "("
+        s += "input_channels={input_channels}"
+        s += ", out_channels={out_channels}"
+        s += ", num_layers={num_layers}"
+        s += ")"
+        return s.format(**self.__dict__)
+
+
+class ConvNeXt(nn.Module):
+    def __init__(
+        self,
+        block_setting: List[CNBlockConfig],
+        stochastic_depth_prob: float = 0.0,
+        layer_scale: float = 1e-6,
+        num_classes: int = 1000,
+        block: Optional[Callable[..., nn.Module]] = None,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__()
+        _log_api_usage_once(self)
+
+        if not block_setting:
+            raise ValueError("The block_setting should not be empty")
+        elif not (isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])):
+            raise TypeError("The block_setting should be List[CNBlockConfig]")
+
+        if block is None:
+            block = CNBlock
+
+        if norm_layer is None:
+            norm_layer = partial(LayerNorm2d, eps=1e-6)
+
+        layers: List[nn.Module] = []
+
+        # Stem
+        firstconv_output_channels = block_setting[0].input_channels
+        layers.append(
+            ConvNormActivation(
+                3,
+                firstconv_output_channels,
+                kernel_size=4,
+                stride=4,
+                padding=0,
+                norm_layer=norm_layer,
+                activation_layer=None,
+                bias=True,
+            )
+        )
+
+        total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
+        stage_block_id = 0
+        for cnf in block_setting:
+            # Bottlenecks
+            stage: List[nn.Module] = []
+            for _ in range(cnf.num_layers):
+                # adjust stochastic depth probability based on the depth of the stage block
+                sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
+                stage.append(block(cnf.input_channels, layer_scale, sd_prob, norm_layer))
+                stage_block_id += 1
+            layers.append(nn.Sequential(*stage))
+            if cnf.out_channels is not None:
+                # Downsampling
+                layers.append(
+                    nn.Sequential(
+                        norm_layer(cnf.input_channels),
+                        nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
+                    )
+                )
+
+        self.features = nn.Sequential(*layers)
+        self.avgpool = nn.AdaptiveAvgPool2d(1)
+
+        lastblock = block_setting[-1]
+        lastconv_output_channels = (
+            lastblock.out_channels if lastblock.out_channels is not None else lastblock.input_channels
+        )
+        self.classifier = nn.Sequential(
+            norm_layer(lastconv_output_channels), nn.Flatten(1), nn.Linear(lastconv_output_channels, num_classes)
+        )
+
+        for m in self.modules():
+            if isinstance(m, (nn.Conv2d, nn.Linear)):
+                nn.init.trunc_normal_(m.weight, std=0.02)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+
+    def _forward_impl(self, x: Tensor) -> Tensor:
+        x = self.features(x)
+        x = self.avgpool(x)
+        x = self.classifier(x)
+        return x
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self._forward_impl(x)
+
+
+class ConvNeXt_Tiny_Weights(WeightsEnum):
+    ImageNet1K_V1 = Weights(
+        url="https://download.pytorch.org/models/convnext_tiny-47b116bd.pth",
+        transforms=partial(ImageNetEval, crop_size=224, resize_size=236),
+        meta={
+            "task": "image_classification",
+            "architecture": "ConvNeXt",
+            "publication_year": 2022,
+            "num_params": 28589128,
+            "size": (224, 224),
+            "min_size": (32, 32),
+            "categories": _IMAGENET_CATEGORIES,
+            "interpolation": InterpolationMode.BILINEAR,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext",
+            "acc@1": 82.520,
+            "acc@5": 96.146,
+        },
+    )
+    default = ImageNet1K_V1
+
+
+@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.ImageNet1K_V1))
+def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
+    r"""ConvNeXt model architecture from the
+    `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.
+
+    Args:
+        weights (ConvNeXt_Tiny_Weights, optional): The pre-trained weights of the model
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    weights = ConvNeXt_Tiny_Weights.verify(weights)
+
+    if weights is not None:
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+
+    block_setting = [
+        CNBlockConfig(96, 192, 3),
+        CNBlockConfig(192, 384, 3),
+        CNBlockConfig(384, 768, 9),
+        CNBlockConfig(768, None, 3),
+    ]
+    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1)
+    model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress))
+
+    return model
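A note on the `torchvision/ops/misc.py` change above: `nn.GELU` takes no `inplace` argument, so the previous unconditional `activation_layer(inplace=inplace)` call would raise a `TypeError` for it. Making `inplace` `Optional` and passing `None` drops the kwarg entirely. A minimal sketch (not part of the diff) mirroring the second `ConvNormActivation` in `CNBlock`:

```python
import torch
from torch import nn
from torchvision.ops.misc import ConvNormActivation

# With inplace=None, ConvNormActivation builds nn.GELU() without
# forwarding an `inplace` kwarg that GELU does not accept:
layer = ConvNormActivation(96, 384, kernel_size=1, norm_layer=None, activation_layer=nn.GELU, inplace=None)
out = layer(torch.randn(1, 96, 56, 56))
print(out.shape)  # torch.Size([1, 384, 56, 56])
```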
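On `LayerNorm2d`: the permute round-trip normalizes the channel dimension of an NCHW tensor by applying a plain `LayerNorm` over the last dimension of the channels-last view. A small equivalence check (a sketch only; it imports the class straight from the prototype module, which does not re-export it):

```python
import torch
from torch.nn import functional as F
from torchvision.prototype.models.convnext import LayerNorm2d

x = torch.randn(2, 96, 8, 8)
# Reference: channels-last layer_norm over the final (channel) dim:
ref = F.layer_norm(x.permute(0, 2, 3, 1), (96,)).permute(0, 3, 1, 2)

# Freshly initialized affine params (weight=1, bias=0) leave the
# normalized values unchanged, so the outputs should match:
torch.testing.assert_close(LayerNorm2d(96)(x), ref)
```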
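Finally, a smoke test of the builder itself (again a sketch, not part of the diff): the tiny variant wires up the paper's (3, 3, 9, 3) depths and (96, 192, 384, 768) widths, and its parameter count should agree with the `num_params` recorded in the weights meta.

```python
import torch
from torchvision.prototype.models.convnext import convnext_tiny

# `stochastic_depth_prob` defaults to 0.1 in the builder and can be
# overridden through kwargs:
model = convnext_tiny(weights=None, stochastic_depth_prob=0.2)

num_params = sum(p.numel() for p in model.parameters())
print(num_params)  # expected: 28589128, per the meta above
```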