Make Inception V3 torch-scriptable (pytorch#2976)
* Making quantized inception torchscriptable.

* Adding a test.

* Fix mypy warning.
datumbox authored Nov 9, 2020
1 parent d6bc625 commit a21ed3a
Showing 3 changed files with 20 additions and 8 deletions.
12 changes: 12 additions & 0 deletions test/test_models.py
@@ -339,6 +339,18 @@ def get_gn(num_channels):
         self.assertFalse(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
         self.assertTrue(any(isinstance(x, nn.GroupNorm) for x in model.modules()))
 
+    def test_inceptionv3_eval(self):
+        # replacement for models.inception_v3(pretrained=True) that does not download weights
+        kwargs = {}
+        kwargs['transform_input'] = True
+        kwargs['aux_logits'] = True
+        kwargs['init_weights'] = False
+        model = models.Inception3(**kwargs)
+        model.aux_logits = False
+        model.AuxLogits = None
+        m = torch.jit.script(model.eval())
+        self.checkModule(m, "inception_v3", torch.rand(1, 3, 299, 299))
+
     def test_fasterrcnn_double(self):
         model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
         model.double()
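
For reference, the test's construction can be exercised directly. A minimal usage sketch, assuming a torchvision build that includes this patch (note that a scripted Inception3 always returns the InceptionOutputs (logits, aux_logits) named tuple, regardless of mode):

import torch
from torchvision import models

# Build the model without downloading pretrained weights, exactly as the test does.
model = models.Inception3(transform_input=True, aux_logits=True, init_weights=False)
model.aux_logits = False
model.AuxLogits = None

scripted = torch.jit.script(model.eval())  # scripts cleanly after this change
logits, aux = scripted(torch.rand(1, 3, 299, 299))
print(logits.shape)  # torch.Size([1, 1000]); aux is None in eval mode
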
12 changes: 6 additions & 6 deletions torchvision/models/inception.py
@@ -55,7 +55,7 @@ def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any)
         model.load_state_dict(state_dict)
         if not original_aux_logits:
             model.aux_logits = False
-            del model.AuxLogits
+            model.AuxLogits = None
         return model
 
     return Inception3(**kwargs)

@@ -108,6 +108,7 @@ def __init__(
         self.Mixed_6c = inception_c(768, channels_7x7=160)
         self.Mixed_6d = inception_c(768, channels_7x7=160)
         self.Mixed_6e = inception_c(768, channels_7x7=192)
+        self.AuxLogits: Optional[nn.Module] = None
         if aux_logits:
             self.AuxLogits = inception_aux(768, num_classes)
         self.Mixed_7a = inception_d(768)

@@ -170,11 +171,10 @@ def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
         # N x 768 x 17 x 17
         x = self.Mixed_6e(x)
         # N x 768 x 17 x 17
-        aux_defined = self.training and self.aux_logits
-        if aux_defined:
-            aux = self.AuxLogits(x)
-        else:
-            aux = None
+        aux = torch.jit.annotate(Optional[Tensor], None)
+        if self.AuxLogits is not None:
+            if self.training:
+                aux = self.AuxLogits(x)
         # N x 768 x 17 x 17
         x = self.Mixed_7a(x)
         # N x 1280 x 8 x 8
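
The change above is the standard TorchScript idiom for an optional submodule: TorchScript requires every attribute to keep a single static type, so instead of deleting AuxLogits the attribute is declared as Optional[nn.Module], assigned None when absent, and narrowed with an `is not None` check inside `_forward`. A self-contained sketch of the same idiom (the Net class and its names are illustrative, not part of torchvision):

from typing import Optional

import torch
from torch import nn

class Net(nn.Module):
    def __init__(self, use_aux: bool = True):
        super().__init__()
        self.fc = nn.Linear(8, 8)
        # Declare the attribute up front with a stable Optional type;
        # `del self.aux` later would remove it and break torch.jit.script.
        self.aux: Optional[nn.Module] = None
        if use_aux:
            self.aux = nn.Linear(8, 2)

    def forward(self, x: torch.Tensor):
        # Annotate aux so TorchScript knows it may stay None.
        aux = torch.jit.annotate(Optional[torch.Tensor], None)
        if self.aux is not None:  # narrows Optional[nn.Module] to nn.Module
            if self.training:
                aux = self.aux(x)
        return self.fc(x), aux

scripted = torch.jit.script(Net().eval())
out, aux = scripted(torch.rand(4, 8))  # aux is None outside training
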
4 changes: 2 additions & 2 deletions torchvision/models/quantization/inception.py
@@ -67,7 +67,7 @@ def inception_v3(pretrained=False, progress=True, quantize=False, **kwargs):
         if quantize:
             if not original_aux_logits:
                 model.aux_logits = False
-                del model.AuxLogits
+                model.AuxLogits = None
             model_url = quant_model_urls['inception_v3_google' + '_' + backend]
         else:
             model_url = inception_module.model_urls['inception_v3_google']

@@ -80,7 +80,7 @@ def inception_v3(pretrained=False, progress=True, quantize=False, **kwargs):
         if not quantize:
             if not original_aux_logits:
                 model.aux_logits = False
-                del model.AuxLogits
+                model.AuxLogits = None
     return model
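
The quantizable wrapper receives the same del-to-None substitution on both of its pretrained code paths. A short sketch of scripting the float variant, under the assumption that the quantizable model follows the same recipe as the plain one (pretrained=False avoids any download; init_weights is not passed because QuantizableInception3 does not appear to accept it):

import torch
from torchvision.models.quantization import inception_v3

# quantize=False keeps the float, quantization-ready model; no weights are fetched.
model = inception_v3(pretrained=False, quantize=False, aux_logits=True)
model.aux_logits = False
model.AuxLogits = None
torch.jit.script(model.eval())  # the quantizable variant should now script too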
