Skip to content

Commit

Permalink
remove vgg model with bn.
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyaofo committed Apr 9, 2021
1 parent 3423956 commit 2f82b01
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 51 deletions.
8 changes: 0 additions & 8 deletions hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,12 @@
from pytorch_cifar_models import cifar100_resnet44
from pytorch_cifar_models import cifar100_resnet56

from pytorch_cifar_models import cifar10_vgg11
from pytorch_cifar_models import cifar10_vgg11_bn
from pytorch_cifar_models import cifar10_vgg13
from pytorch_cifar_models import cifar10_vgg13_bn
from pytorch_cifar_models import cifar10_vgg16
from pytorch_cifar_models import cifar10_vgg16_bn
from pytorch_cifar_models import cifar10_vgg19
from pytorch_cifar_models import cifar10_vgg19_bn

from pytorch_cifar_models import cifar100_vgg11
from pytorch_cifar_models import cifar100_vgg11_bn
from pytorch_cifar_models import cifar100_vgg13
from pytorch_cifar_models import cifar100_vgg13_bn
from pytorch_cifar_models import cifar100_vgg16
from pytorch_cifar_models import cifar100_vgg16_bn
from pytorch_cifar_models import cifar100_vgg19
from pytorch_cifar_models import cifar100_vgg19_bn
8 changes: 0 additions & 8 deletions pytorch_cifar_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,14 @@
from .resnet import cifar100_resnet44
from .resnet import cifar100_resnet56

from .vgg import cifar10_vgg11
from .vgg import cifar10_vgg11_bn
from .vgg import cifar10_vgg13
from .vgg import cifar10_vgg13_bn
from .vgg import cifar10_vgg16
from .vgg import cifar10_vgg16_bn
from .vgg import cifar10_vgg19
from .vgg import cifar10_vgg19_bn

from .vgg import cifar100_vgg11
from .vgg import cifar100_vgg11_bn
from .vgg import cifar100_vgg13
from .vgg import cifar100_vgg13_bn
from .vgg import cifar100_vgg16
from .vgg import cifar100_vgg16_bn
from .vgg import cifar100_vgg19
from .vgg import cifar100_vgg19_bn

__version__ = "0.0.1-alpha"
45 changes: 14 additions & 31 deletions pytorch_cifar_models/vgg.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,24 +43,16 @@
from typing import Union, List, Dict, Any, cast

cifar10_pretrained_weight_urls = {
'vgg11': '',
'vgg11_bn': '',
'vgg13': '',
'vgg13_bn': '',
'vgg16': '',
'vgg16_bn': '',
'vgg19': '',
'vgg19_bn': '',
}

cifar100_pretrained_weight_urls = {
'vgg11': '',
'vgg11_bn': '',
'vgg13': '',
'vgg13_bn': '',
'vgg16': '',
'vgg16_bn': '',
'vgg19': '',
'vgg19_bn': '',
}

Expand Down Expand Up @@ -145,40 +137,31 @@ def _vgg(arch: str, cfg: str, batch_norm: bool,
return model


def cifar10_vgg11(*args, **kwargs) -> VGG: pass
def cifar10_vgg11_bn(*args, **kwargs) -> VGG: pass
def cifar10_vgg13(*args, **kwargs) -> VGG: pass
def cifar10_vgg13_bn(*args, **kwargs) -> VGG: pass
def cifar10_vgg16(*args, **kwargs) -> VGG: pass
def cifar10_vgg16_bn(*args, **kwargs) -> VGG: pass
def cifar10_vgg19(*args, **kwargs) -> VGG: pass
def cifar10_vgg19_bn(*args, **kwargs) -> VGG: pass


def cifar100_vgg11(*args, **kwargs) -> VGG: pass
def cifar100_vgg11_bn(*args, **kwargs) -> VGG: pass
def cifar100_vgg13(*args, **kwargs) -> VGG: pass
def cifar100_vgg13_bn(*args, **kwargs) -> VGG: pass
def cifar100_vgg16(*args, **kwargs) -> VGG: pass
def cifar100_vgg16_bn(*args, **kwargs) -> VGG: pass
def cifar100_vgg19(*args, **kwargs) -> VGG: pass
def cifar100_vgg19_bn(*args, **kwargs) -> VGG: pass


thismodule = sys.modules[__name__]
for dataset in ["cifar10", "cifar100"]:
for cfg, model_name in zip(["A", "B", "D", "E"], ["vgg11", "vgg13", "vgg16", "vgg19"]):
for batch_norm in [False, True]:
method_name = f"{dataset}_{model_name}{'_bn' if batch_norm else ''}"
model_urls = cifar10_pretrained_weight_urls if dataset == "cifar10" else cifar100_pretrained_weight_urls
num_classes = 10 if dataset == "cifar10" else 100
setattr(
thismodule,
method_name,
partial(_vgg,
arch=model_name,
cfg=cfg,
batch_norm=batch_norm,
model_urls=model_urls,
num_classes=num_classes)
)
for cfg, model_name in zip(["A", "B", "D", "E"], ["vgg11_bn", "vgg13_bn", "vgg16_bn", "vgg19_bn"]):
method_name = f"{dataset}_{model_name}"
model_urls = cifar10_pretrained_weight_urls if dataset == "cifar10" else cifar100_pretrained_weight_urls
num_classes = 10 if dataset == "cifar10" else 100
setattr(
thismodule,
method_name,
partial(_vgg,
arch=model_name,
cfg=cfg,
batch_norm=True,
model_urls=model_urls,
num_classes=num_classes)
)
7 changes: 3 additions & 4 deletions tests/pytorch_cifar_models/test_vgg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@


@pytest.mark.parametrize("dataset", ["cifar10", "cifar100"])
@pytest.mark.parametrize("model_name", ["vgg11", "vgg13", "vgg16", "vgg19"])
@pytest.mark.parametrize("bn", ["", "_bn"])
def test_resnet(dataset, model_name, bn):
@pytest.mark.parametrize("model_name", ["vgg11_bn", "vgg13_bn", "vgg16_bn", "vgg19_bn"])
def test_resnet(dataset, model_name):
num_classes = 10 if dataset == "cifar10" else 100
model = getattr(vgg, f"{dataset}_{model_name}{bn}")()
model = getattr(vgg, f"{dataset}_{model_name}")()
x = torch.empty((1, 3, 32, 32))
y = model(x)
assert y.shape == (1, num_classes)

0 comments on commit 2f82b01

Please sign in to comment.