diff --git a/ivy_models/__init__.py b/ivy_models/__init__.py index 4f1c1d2..acce73e 100644 --- a/ivy_models/__init__.py +++ b/ivy_models/__init__.py @@ -24,3 +24,6 @@ from . import bert from .bert import * from .vit import * + +from . import googlenet +from .googlenet import * diff --git a/ivy_models/googlenet/__init__.py b/ivy_models/googlenet/__init__.py new file mode 100644 index 0000000..1b42224 --- /dev/null +++ b/ivy_models/googlenet/__init__.py @@ -0,0 +1,2 @@ +from . import googlenet +from .googlenet import * diff --git a/ivy_models/googlenet/googlenet.py b/ivy_models/googlenet/googlenet.py new file mode 100644 index 0000000..b18d7c5 --- /dev/null +++ b/ivy_models/googlenet/googlenet.py @@ -0,0 +1,162 @@ +# global +import ivy +import ivy_models +from ivy_models.base import BaseSpec, BaseModel +from ivy_models.googlenet.layers import ( + InceptionConvBlock, + InceptionBlock, + InceptionAuxiliaryBlock, +) + + +class GoogLeNetSpec(BaseSpec): + def __init__( + self, + training=False, + num_classes=1000, + dropout=0.4, + aux_dropout=0.7, + data_format="NCHW", + ): + if not training: + dropout = 0 + aux_dropout = 0 + super(GoogLeNetSpec, self).__init__( + training=training, + num_classes=num_classes, + dropout=dropout, + aux_dropout=aux_dropout, + data_format=data_format, + ) + + +class GoogLeNet(BaseModel): + def __init__( + self, + training=False, + num_classes=1000, + dropout=0.4, + aux_dropout=0.7, + data_format="NCHW", + spec=None, + v=None, + ): + self.spec = ( + spec + if spec and isinstance(spec, GoogLeNetSpec) + else GoogLeNetSpec( + training=training, + num_classes=num_classes, + dropout=dropout, + aux_dropout=aux_dropout, + data_format=data_format, + ) + ) + super(GoogLeNet, self).__init__(v=v) + + def _build(self, *args, **kwargs): + self.conv1 = InceptionConvBlock(3, 64, [7, 7], 2, padding=3) + + self.conv2 = InceptionConvBlock(64, 64, [1, 1], 1, padding=0) + self.conv3 = InceptionConvBlock(64, 192, [3, 3], 1, padding=1) + + self.inception3A = InceptionBlock(192, 64, 96, 128, 16, 32, 32) + self.inception3B = InceptionBlock(256, 128, 128, 192, 32, 96, 64) + + self.inception4A = InceptionBlock(480, 192, 96, 208, 16, 48, 64) + + self.aux4A = InceptionAuxiliaryBlock( + 512, self.spec.num_classes, self.spec.aux_dropout + ) + + self.inception4B = InceptionBlock(512, 160, 112, 224, 24, 64, 64) + self.inception4C = InceptionBlock(512, 128, 128, 256, 24, 64, 64) + self.inception4D = InceptionBlock(512, 112, 144, 288, 32, 64, 64) + + self.aux4D = InceptionAuxiliaryBlock( + 528, self.spec.num_classes, self.spec.aux_dropout + ) + + self.inception4E = InceptionBlock(528, 256, 160, 320, 32, 128, 128) + + self.inception5A = InceptionBlock(832, 256, 160, 320, 32, 128, 128) + self.inception5B = InceptionBlock(832, 384, 192, 384, 48, 128, 128) + self.pool6 = ivy.AdaptiveAvgPool2d([1, 1]) + + self.dropout = ivy.Dropout(self.spec.dropout) + self.fc = ivy.Linear(1024, self.spec.num_classes, with_bias=True) + + @classmethod + def get_spec_class(self): + return GoogLeNetSpec + + def _forward(self, x, data_format=None): + data_format = data_format if data_format else self.spec.data_format + if data_format == "NHWC": + x = ivy.permute_dims(x, (0, 3, 1, 2)) + + out = self.conv1(x) + out = ivy.max_pool2d(out, [3, 3], 2, 0, ceil_mode=True, data_format="NCHW") + out = self.conv2(out) + out = self.conv3(out) + out = ivy.max_pool2d(out, [3, 3], 2, 0, ceil_mode=True, data_format="NCHW") + out = self.inception3A(out) + out = self.inception3B(out) + out = ivy.max_pool2d(out, [3, 3], 2, 0, ceil_mode=True, data_format="NCHW") + out = self.inception4A(out) + + aux1 = None + if self.spec.training: + aux1 = self.aux4A(out) + + out = self.inception4B(out) + out = self.inception4C(out) + out = self.inception4D(out) + + aux2 = None + if self.spec.training: + aux2 = self.aux4D(out) + + out = self.inception4E(out) + out = ivy.max_pool2d(out, [2, 2], 2, 0, ceil_mode=True, data_format="NCHW") + out = self.inception5A(out) + out = self.inception5B(out) + out = self.pool6(out) + out = ivy.flatten(out, start_dim=1) + out = self.dropout(out) + out = self.fc(out) + return out, aux1, aux2 + + +def _inceptionNet_torch_weights_mapping(old_key, new_key): + if "conv/weight" in old_key: + return {"key_chain": new_key, "pattern": "b c h w -> h w c b"} + return new_key + + +def inceptionNet_v1( + pretrained=True, + training=False, + num_classes=1000, + dropout=0.4, + aux_dropout=0.7, + data_format="NCHW", +): + """InceptionNet-V1 model""" + model = GoogLeNet( + training=training, + num_classes=num_classes, + dropout=dropout, + aux_dropout=aux_dropout, + data_format=data_format, + ) + if pretrained: + url = "https://download.pytorch.org/models/googlenet-1378be20.pth" + w_clean = ivy_models.helpers.load_torch_weights( + url, + model, + raw_keys_to_prune=["num_batches_tracked"], + custom_mapping=_inceptionNet_torch_weights_mapping, + ) + model.v = w_clean + return model diff --git a/ivy_models/googlenet/layers.py b/ivy_models/googlenet/layers.py new file mode 100644 index 0000000..5dd2d25 --- /dev/null +++ b/ivy_models/googlenet/layers.py @@ -0,0 +1,127 @@ +import ivy + + +class InceptionConvBlock(ivy.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride, padding): + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + super(InceptionConvBlock, self).__init__() + + def _build(self, *args, **kwargs): + self.conv = ivy.Conv2D( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + with_bias=False, + data_format="NCHW", + ) + self.bn = ivy.BatchNorm2D( + self.out_channels, eps=0.001, data_format="NCS", training=False + ) + + def _forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = ivy.relu(x) + return x + + +class InceptionBlock(ivy.Module): + def __init__( + self, + in_channels, + num1x1, + num3x3_reduce, + num3x3, + num5x5_reduce, + num5x5, + pool_proj, + ): + self.in_channels = in_channels + self.num1x1 = num1x1 + self.num3x3_reduce = num3x3_reduce + self.num3x3 = num3x3 + self.num5x5_reduce = num5x5_reduce + self.num5x5 = num5x5 + self.pool_proj = pool_proj + super(InceptionBlock, self).__init__() + + def _build(self, *args, **kwargs): + self.conv_1x1 = InceptionConvBlock( + self.in_channels, self.num1x1, kernel_size=[1, 1], stride=1, padding=0 + ) + + self.conv_3x3 = InceptionConvBlock( + self.in_channels, + self.num3x3_reduce, + kernel_size=[1, 1], + stride=1, + padding=0, + ) + self.conv_3x3_red = InceptionConvBlock( + self.num3x3_reduce, self.num3x3, kernel_size=[3, 3], stride=1, padding=1 + ) + + self.conv_5x5 = InceptionConvBlock( + self.in_channels, + self.num5x5_reduce, + kernel_size=[1, 1], + stride=1, + padding=0, + ) + self.conv_5x5_red = InceptionConvBlock( + self.num5x5_reduce, self.num5x5, kernel_size=[3, 3], stride=1, padding=1 + ) + + self.pool_proj_conv = InceptionConvBlock( + self.in_channels, self.pool_proj, kernel_size=[1, 1], stride=1, padding=0 + ) + + def _forward(self, x): + # 1x1 + conv_1x1 = self.conv_1x1(x) + + # 3x3 + conv_3x3 = self.conv_3x3(x) + conv_3x3_red = self.conv_3x3_red(conv_3x3) + + # 5x5 + conv_5x5 = self.conv_5x5(x) + conv_5x5_red = self.conv_5x5_red(conv_5x5) + + # pool_proj + pool_proj = ivy.max_pool2d(x, [3, 3], 1, 1, ceil_mode=True, data_format="NCHW") + pool_proj = self.pool_proj_conv(pool_proj) + + ret = ivy.concat([conv_1x1, conv_3x3_red, conv_5x5_red, pool_proj], axis=1) + return ret + + +class InceptionAuxiliaryBlock(ivy.Module): + def __init__(self, in_channels, num_classes, aux_dropout=0.7): + self.in_channels = in_channels + self.num_classes = num_classes + self.aux_dropout = aux_dropout + super(InceptionAuxiliaryBlock, self).__init__() + + def _build(self, *args, **kwargs): + self.conv = InceptionConvBlock(self.in_channels, 128, [1, 1], 1, 0) + self.fc1 = ivy.Linear(2048, 1024, with_bias=True) + self.dropout = ivy.Dropout(self.aux_dropout) + self.fc2 = ivy.Linear(1024, self.num_classes, with_bias=True) + self.softmax = ivy.Softmax() + + def _forward(self, x): + out = ivy.adaptive_avg_pool2d(x, [4, 4]) + out = self.conv(out) + out = ivy.flatten(out, start_dim=1) + out = self.fc1(out) + out = ivy.relu(out) + out = self.dropout(out) + out = self.fc2(out) + return out diff --git a/ivy_models_tests/googlenet/test_googlenet.py b/ivy_models_tests/googlenet/test_googlenet.py new file mode 100644 index 0000000..b7bb3c2 --- /dev/null +++ b/ivy_models_tests/googlenet/test_googlenet.py @@ -0,0 +1,47 @@ +import os +import random +import ivy +import pytest +import numpy as np + +from ivy_models.googlenet import inceptionNet_v1 +from ivy_models_tests import helpers + + +load_weights = random.choice([False, True]) +model = inceptionNet_v1(pretrained=load_weights) +v = ivy.to_numpy(model.v) + + +@pytest.mark.parametrize("data_format", ["NHWC", "NCHW"]) +def test_GoogleNet_tiny_img_classification(device, fw, data_format): + """Test GoogleNet image classification.""" + num_classes = 1000 + batch_shape = [1] + this_dir = os.path.dirname(os.path.realpath(__file__)) + + # Load image + img = helpers.load_and_preprocess_img( + os.path.join(this_dir, "..", "..", "images", "cat.jpg"), + 256, + 224, + data_format=data_format, + to_ivy=True, + ) + + model.v = ivy.asarray(v) + logits, _, _ = model(img, data_format=data_format) + + # Cardinality test + assert logits.shape == tuple([ivy.to_scalar(batch_shape), num_classes]) + + # Value test + if load_weights: + np_out = ivy.to_numpy(logits[0]) + true_indices = np.array([282, 281, 285]) + calc_indices = np.argsort(np_out)[-3:][::-1] + assert np.array_equal(true_indices, calc_indices) + + # true_logits = np.array([0.2539, 0.2391, 0.1189]) + # calc_logits = np.take(np_out, calc_indices) + # assert np.allclose(true_logits, calc_logits, rtol=1e-1)