Skip to content
This repository has been archived by the owner on Apr 17, 2023. It is now read-only.

Commit

Permalink
Replace only the first occurrence in the state_dict keys (#91)
Browse files Browse the repository at this point in the history
Doing otherwise might corrupt the model.

Fixes #90.
  • Loading branch information
arrufat authored Dec 12, 2022
1 parent 540ea47 commit 6de41a6
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions mpa/modules/models/classifiers/sam_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ def state_dict_hook(module, state_dict, *args, **kwargs):
if backbone_type == 'OTEMobileNetV3':
for k, v in state_dict.items():
if k.startswith('backbone'):
k = k.replace('backbone.', '')
k = k.replace('backbone.', '', 1)
elif k.startswith('head'):
k = k.replace('head.', '')
k = k.replace('head.', '', 1)
if '3' in k: # MPA uses "classifier.3", OTE uses "classifier.4". Convert for OTE compatibility.
k = k.replace('3', '4')
if module.multilabel and not module.is_export:
Expand All @@ -118,9 +118,9 @@ def state_dict_hook(module, state_dict, *args, **kwargs):
elif backbone_type == 'OTEEfficientNet':
for k, v in state_dict.items():
if k.startswith('backbone'):
k = k.replace('backbone.', '')
k = k.replace('backbone.', '', 1)
elif k.startswith('head'):
k = k.replace('head', 'output')
k = k.replace('head', 'output', 1)
if not module.hierarchical and not module.is_export:
k = k.replace('fc', 'asl')
v = v.t()
Expand All @@ -129,7 +129,7 @@ def state_dict_hook(module, state_dict, *args, **kwargs):
elif backbone_type == 'OTEEfficientNetV2':
for k, v in state_dict.items():
if k.startswith('backbone'):
k = k.replace('backbone.', '')
k = k.replace('backbone.', '', 1)
elif k == 'head.fc.weight':
k = k.replace('head.fc', 'model.classifier')
if not module.hierarchical and not module.is_export:
Expand Down

0 comments on commit 6de41a6

Please sign in to comment.