Skip to content

Commit

Permalink
partially enable type checking for .models (#2668)
Browse files Browse the repository at this point in the history
* partially enable mypy for .models

* fix existing errors

* ignore error instead of using Union
  • Loading branch information
pmeier authored Sep 22, 2020
1 parent be0b6d9 commit 8dfcff7
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 10 deletions.
16 changes: 12 additions & 4 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,23 @@ files = torchvision
show_error_codes = True
pretty = True

;[mypy-torchvision.datasets.*]
[mypy-torchvision.io._video_opt.*]

;ignore_errors = True
ignore_errors = True

[mypy-torchvision.io._video_opt.*]
[mypy-torchvision.io.*]

ignore_errors = True

[mypy-torchvision.models.detection.*]

ignore_errors = True

[mypy-torchvision.models.densenet.*]

ignore_errors = True

[mypy-torchvision.models.*]
[mypy-torchvision.models.quantization.*]

ignore_errors = True

Expand Down
5 changes: 2 additions & 3 deletions torchvision/models/googlenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,11 @@ def _forward(self, x):
return x, aux2, aux1

@torch.jit.unused
def eager_outputs(self, x, aux2, aux1):
# type: (Tensor, Optional[Tensor], Optional[Tensor]) -> GoogLeNetOutputs
def eager_outputs(self, x: Tensor, aux2: Tensor, aux1: Optional[Tensor]) -> GoogLeNetOutputs:
if self.training and self.aux_logits:
return _GoogLeNetOutputs(x, aux2, aux1)
else:
return x
return x # type: ignore[return-value]

def forward(self, x):
# type: (Tensor) -> GoogLeNetOutputs
Expand Down
5 changes: 2 additions & 3 deletions torchvision/models/inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,11 @@ def _forward(self, x):
return x, aux

@torch.jit.unused
def eager_outputs(self, x, aux):
# type: (Tensor, Optional[Tensor]) -> InceptionOutputs
def eager_outputs(self, x: torch.Tensor, aux: Optional[Tensor]) -> InceptionOutputs:
if self.training and self.aux_logits:
return InceptionOutputs(x, aux)
else:
return x
return x # type: ignore[return-value]

def forward(self, x):
x = self._transform_input(x)
Expand Down

0 comments on commit 8dfcff7

Please sign in to comment.