Skip to content
This repository has been archived by the owner on Apr 17, 2023. It is now read-only.

Commit

Permalink
fixed multilabel configs (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
kprokofi authored Oct 10, 2022
1 parent 7efc90a commit bc04c98
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 8 deletions.
4 changes: 1 addition & 3 deletions models/classification/ote_efficientnet_b0_multilabel_al.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@ _base_: ./ote_efficientnet_b0_multilabel.yaml

model:
head:
type: CustomMultiLabelNonLinearClsHead
type: CustomMultiLabelLinearClsHead
normalized: True
scale: 7.0
act_cfg:
type: PReLU
loss:
type: AsymmetricAngularLossWithIgnore
gamma_pos: 0.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@ _base_: ./ote_efficientnet_v2_s_multilabel.yaml

model:
head:
type: CustomMultiLabelNonLinearClsHead
type: CustomMultiLabelLinearClsHead
normalized: True
scale: 7.0
act_cfg:
type: PReLU
loss:
type: AsymmetricAngularLossWithIgnore
gamma_pos: 0.0
Expand Down
4 changes: 2 additions & 2 deletions mpa/modules/models/classifiers/sam_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def load_state_dict_pre_hook(module, state_dict, *args, **kwargs):
k = k.replace('asl', 'fc')
v = v.t()
state_dict[k] = v

elif backbone_type == 'OTEEfficientNetV2':
for k in list(state_dict.keys()):
v = state_dict.pop(k)
Expand Down Expand Up @@ -222,7 +222,7 @@ def load_state_dict_mixing_hook(model, model_classes, chkpt_classes, chkpt_dict,
for model_name in param_names:
model_param = model_dict[model_name].clone()
if backbone_type == 'OTEMobileNetV3':
chkpt_name = 'head.'+model_name.replace('4', '3')
chkpt_name = 'head.' + model_name.replace('4', '3')
if model.multilabel:
model_param = model_param.t()
elif backbone_type in 'OTEEfficientNet':
Expand Down

0 comments on commit bc04c98

Please sign in to comment.