Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Standardize errors for when pre-trained weights are not available #5572

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions test/test_prototype_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,8 @@ def _get_model_weights(model_fn):
def _build_model(fn, **kwargs):
try:
model = fn(**kwargs)
except ValueError as e:
msg = str(e)
if "No checkpoint is available" in msg:
pytest.skip(msg)
raise e
except NotImplementedError as e:
pytest.skip(str(e))
return model.eval()


Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def _convnext(
model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)
if pretrained:
if arch not in _MODELS_URLS:
raise ValueError(f"No checkpoint is available for model type {arch}")
raise NotImplementedError(f"No checkpoint is available for model type {arch}")
datumbox marked this conversation as resolved.
Show resolved Hide resolved
state_dict = load_state_dict_from_url(_MODELS_URLS[arch], progress=progress)
model.load_state_dict(state_dict)
return model
Expand Down
8 changes: 4 additions & 4 deletions torchvision/models/detection/faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def fasterrcnn_resnet50_fpn(


def _fasterrcnn_mobilenet_v3_large_fpn(
weights_name,
arch,
jdsgomes marked this conversation as resolved.
Show resolved Hide resolved
pretrained=False,
progress=True,
num_classes=91,
Expand Down Expand Up @@ -435,9 +435,9 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs
)
if pretrained:
if model_urls.get(weights_name, None) is None:
raise ValueError(f"No checkpoint is available for model {weights_name}")
state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
if model_urls.get(arch, None) is None:
raise NotImplementedError(f"No checkpoint is available for model type {arch}")
state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
model.load_state_dict(state_dict)
return model

Expand Down
8 changes: 4 additions & 4 deletions torchvision/models/detection/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,9 +622,9 @@ def ssd300_vgg16(
kwargs = {**defaults, **kwargs}
model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs)
if pretrained:
weights_name = "ssd300_vgg16_coco"
if model_urls.get(weights_name, None) is None:
raise ValueError(f"No checkpoint is available for model {weights_name}")
state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
arch = "ssd300_vgg16_coco"
if model_urls.get(arch, None) is None:
raise NotImplementedError(f"No checkpoint is available for model type {arch}")
state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
model.load_state_dict(state_dict)
return model
8 changes: 4 additions & 4 deletions torchvision/models/detection/ssdlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,9 +267,9 @@ def ssdlite320_mobilenet_v3_large(
)

if pretrained:
weights_name = "ssdlite320_mobilenet_v3_large_coco"
if model_urls.get(weights_name, None) is None:
raise ValueError(f"No checkpoint is available for model {weights_name}")
state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
arch = "ssdlite320_mobilenet_v3_large_coco"
if model_urls.get(arch, None) is None:
raise NotImplementedError(f"No checkpoint is available for model type {arch}")
state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
model.load_state_dict(state_dict)
return model
2 changes: 1 addition & 1 deletion torchvision/models/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def _efficientnet(
model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs)
if pretrained:
if model_urls.get(arch, None) is None:
raise ValueError(f"No checkpoint is available for model type {arch}")
raise NotImplementedError(f"No checkpoint is available for model type {arch}")
state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
model.load_state_dict(state_dict)
return model
Expand Down
8 changes: 4 additions & 4 deletions torchvision/models/mnasnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,10 @@ def _load_from_state_dict(
)


def _load_pretrained(model_name: str, model: nn.Module, progress: bool) -> None:
if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None:
raise ValueError(f"No checkpoint is available for model type {model_name}")
checkpoint_url = _MODEL_URLS[model_name]
def _load_pretrained(arch: str, model: nn.Module, progress: bool) -> None:
if arch not in _MODEL_URLS or _MODEL_URLS[arch] is None:
raise NotImplementedError(f"No checkpoint is available for model type {arch}")
checkpoint_url = _MODEL_URLS[arch]
model.load_state_dict(load_state_dict_from_url(checkpoint_url, progress=progress))


Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/mobilenetv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def _mobilenet_v3(
model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
if pretrained:
if model_urls.get(arch, None) is None:
raise ValueError(f"No checkpoint is available for model type {arch}")
raise NotImplementedError(f"No checkpoint is available for model type {arch}")
state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
model.load_state_dict(state_dict)
return model
Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/quantization/mobilenetv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None:

def _load_weights(arch: str, model: QuantizableMobileNetV3, model_url: Optional[str], progress: bool) -> None:
if model_url is None:
raise ValueError(f"No checkpoint is available for {arch}")
raise NotImplementedError(f"No checkpoint is available for model type {arch}")
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)

Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/regnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def _regnet(arch: str, block_params: BlockParams, pretrained: bool, progress: bo
model = RegNet(block_params, norm_layer=norm_layer, **kwargs)
if pretrained:
if arch not in model_urls:
raise ValueError(f"No checkpoint is available for model type {arch}")
raise NotImplementedError(f"No checkpoint is available for model type {arch}")
state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
model.load_state_dict(state_dict)
return model
Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/segmentation/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,6 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:

def _load_weights(arch: str, model: nn.Module, model_url: Optional[str], progress: bool) -> None:
if model_url is None:
raise ValueError(f"No checkpoint is available for {arch}")
raise NotImplementedError(f"No checkpoint is available for model type {arch}")
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
2 changes: 1 addition & 1 deletion torchvision/models/shufflenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def _shufflenetv2(arch: str, pretrained: bool, progress: bool, *args: Any, **kwa
if pretrained:
model_url = model_urls[arch]
if model_url is None:
raise ValueError(f"No checkpoint is available for model type {arch}")
raise NotImplementedError(f"No checkpoint is available for model type {arch}")
else:
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/vision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def _vision_transformer(

if pretrained:
if arch not in model_urls:
raise ValueError(f"No checkpoint is available for model type '{arch}'!")
raise NotImplementedError(f"No checkpoint is available for model type {arch}")
state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
model.load_state_dict(state_dict)

Expand Down