Skip to content

Commit

Permalink
Use is to compare type of objects (#4605)
Browse files Browse the repository at this point in the history
Co-authored-by: deepsource-autofix[bot] <62050782+deepsource-autofix[bot]@users.noreply.github.com>

Co-authored-by: deepsource-autofix[bot] <62050782+deepsource-autofix[bot]@users.noreply.github.com>
Co-authored-by: Vasilis Vryniotis <[email protected]>
  • Loading branch information
3 people authored Oct 13, 2021
1 parent 3855901 commit 321f39e
Show file tree
Hide file tree
Showing 7 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion torchvision/models/quantization/googlenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,5 +167,5 @@ def fuse_model(self) -> None:
"""

for m in self.modules():
if type(m) == QuantizableBasicConv2d:
if type(m) is QuantizableBasicConv2d:
m.fuse_model()
2 changes: 1 addition & 1 deletion torchvision/models/quantization/inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,5 +247,5 @@ def fuse_model(self) -> None:
"""

for m in self.modules():
if type(m) == QuantizableBasicConv2d:
if type(m) is QuantizableBasicConv2d:
m.fuse_model()
6 changes: 3 additions & 3 deletions torchvision/models/quantization/mobilenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def forward(self, x: Tensor) -> Tensor:

def fuse_model(self) -> None:
for idx in range(len(self.conv)):
if type(self.conv[idx]) == nn.Conv2d:
if type(self.conv[idx]) is nn.Conv2d:
fuse_modules(self.conv, [str(idx), str(idx + 1)], inplace=True)


Expand All @@ -54,9 +54,9 @@ def forward(self, x: Tensor) -> Tensor:

def fuse_model(self) -> None:
for m in self.modules():
if type(m) == ConvNormActivation:
if type(m) is ConvNormActivation:
fuse_modules(m, ["0", "1", "2"], inplace=True)
if type(m) == QuantizableInvertedResidual:
if type(m) is QuantizableInvertedResidual:
m.fuse_model()


Expand Down
6 changes: 3 additions & 3 deletions torchvision/models/quantization/mobilenetv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,12 @@ def forward(self, x: Tensor) -> Tensor:

def fuse_model(self) -> None:
for m in self.modules():
if type(m) == ConvNormActivation:
if type(m) is ConvNormActivation:
modules_to_fuse = ["0", "1"]
if len(m) == 3 and type(m[2]) == nn.ReLU:
if len(m) == 3 and type(m[2]) is nn.ReLU:
modules_to_fuse.append("2")
fuse_modules(m, modules_to_fuse, inplace=True)
elif type(m) == QuantizableSqueezeExcitation:
elif type(m) is QuantizableSqueezeExcitation:
m.fuse_model()


Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/quantization/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def fuse_model(self) -> None:

fuse_modules(self, ["conv1", "bn1", "relu"], inplace=True)
for m in self.modules():
if type(m) == QuantizableBottleneck or type(m) == QuantizableBasicBlock:
if type(m) is QuantizableBottleneck or type(m) is QuantizableBasicBlock:
m.fuse_model()


Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/quantization/shufflenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def fuse_model(self) -> None:
if name in ["conv1", "conv5"]:
torch.quantization.fuse_modules(m, [["0", "1", "2"]], inplace=True)
for m in self.modules():
if type(m) == QuantizableInvertedResidual:
if type(m) is QuantizableInvertedResidual:
if len(m.branch1._modules.items()) > 0:
torch.quantization.fuse_modules(m.branch1, [["0", "1"], ["2", "3", "4"]], inplace=True)
torch.quantization.fuse_modules(
Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/quantization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def _replace_relu(module: nn.Module) -> None:
# Checking for explicit type instead of instance
# as we only want to replace modules of the exact type
# not inherited classes
if type(mod) == nn.ReLU or type(mod) == nn.ReLU6:
if type(mod) is nn.ReLU or type(mod) is nn.ReLU6:
reassign[name] = nn.ReLU(inplace=False)

for key, value in reassign.items():
Expand Down

0 comments on commit 321f39e

Please sign in to comment.