diff --git a/mpa/modules/models/classifiers/sam_classifier.py b/mpa/modules/models/classifiers/sam_classifier.py index 8a316564..af3d9356 100644 --- a/mpa/modules/models/classifiers/sam_classifier.py +++ b/mpa/modules/models/classifiers/sam_classifier.py @@ -106,9 +106,9 @@ def state_dict_hook(module, state_dict, *args, **kwargs): if backbone_type == 'OTEMobileNetV3': for k, v in state_dict.items(): if k.startswith('backbone'): - k = k.replace('backbone.', '') + k = k.replace('backbone.', '', 1) elif k.startswith('head'): - k = k.replace('head.', '') + k = k.replace('head.', '', 1) if '3' in k: # MPA uses "classifier.3", OTE uses "classifier.4". Convert for OTE compatibility. k = k.replace('3', '4') if module.multilabel and not module.is_export: @@ -118,9 +118,9 @@ def state_dict_hook(module, state_dict, *args, **kwargs): elif backbone_type == 'OTEEfficientNet': for k, v in state_dict.items(): if k.startswith('backbone'): - k = k.replace('backbone.', '') + k = k.replace('backbone.', '', 1) elif k.startswith('head'): - k = k.replace('head', 'output') + k = k.replace('head', 'output', 1) if not module.hierarchical and not module.is_export: k = k.replace('fc', 'asl') v = v.t() @@ -129,7 +129,7 @@ def state_dict_hook(module, state_dict, *args, **kwargs): elif backbone_type == 'OTEEfficientNetV2': for k, v in state_dict.items(): if k.startswith('backbone'): - k = k.replace('backbone.', '') + k = k.replace('backbone.', '', 1) elif k == 'head.fc.weight': k = k.replace('head.fc', 'model.classifier') if not module.hierarchical and not module.is_export: