Skip to content

Commit

Permalink
Imporve according to comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mzr1996 committed Jan 27, 2022
1 parent 2a4565c commit b78ff95
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 6 deletions.
31 changes: 25 additions & 6 deletions mmcls/models/backbones/convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,25 @@


class LayerNorm2d(nn.LayerNorm):
"""LayerNorm that supports two data formats: channels_last (default) or
channels_first.
"""LayerNorm on channels for 2d images.
The ordering of the dimensions in the inputs. channels_last corresponds to
inputs with shape (batch_size, height, width, channels) while
channels_first corresponds to inputs with shape (batch_size, channels,
height, width).
Args:
num_channels (int): The number of channels of the input tensor.
dim (int): The dimension of channel in the input tensor. Defaults to 1.
eps (float): a value added to the denominator for numerical stability.
Defaults to 1e-6.
elementwise_affine (bool): a boolean value that when set to ``True``,
this module has learnable per-element affine parameters initialized
to ones (for weights) and zeros (for biases). Defaults to True.
Note:
Comparing with the original implementation, this implementation uses
``dim`` instead of ``data_format`` to specify the dimension of channel.
1. For inputs with shape (batch_size, height, width, channels), use
``dim=-1`` or ``dim=3``.
2. For inputs with shape (batch_size, channels, height, width), use
``dim=1``.
"""

def __init__(self,
Expand Down Expand Up @@ -194,8 +206,15 @@ def __init__(self,
assert 'depths' in arch and 'channels' in arch, \
f'The arch dict must have "depths" and "channels", ' \
f'but got {list(arch.keys())}.'

self.depths = arch['depths']
self.channels = arch['channels']
assert (isinstance(self.depths, Sequence)
and isinstance(self.channels, Sequence)
and len(self.depths) == len(self.channels)), \
f'The "depths" ({self.depths}) and "channels" ({self.channels}) ' \
'should be both sequence with the same length.'

self.num_stages = len(self.depths)

if isinstance(out_indices, int):
Expand Down
4 changes: 4 additions & 0 deletions tests/test_models/test_backbones/test_convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ def test_assertion():
# ConvNeXt arch dict should include 'embed_dims',
ConvNeXt(arch=dict(channels=[2, 3, 4, 5]))

with pytest.raises(AssertionError):
# ConvNeXt arch dict should include 'embed_dims',
ConvNeXt(arch=dict(depths=[2, 3, 4], channels=[2, 3, 4, 5]))


def test_convnext():

Expand Down

0 comments on commit b78ff95

Please sign in to comment.