Skip to content

Commit

Permalink
Ensure that checkpoint_num_classes is propagated from YAML to model (
Browse files Browse the repository at this point in the history
…#1533)

* Ensure that `checkpoint_num_classes` is propagated from YAML to models.get()

* Red checkpoint_num_classes via get_params
  • Loading branch information
BloodAxe authored Oct 17, 2023
1 parent 502313e commit 8923bbc
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,9 @@ strict_load: # key matching strictness for loading checkpoint's weights
_target_: super_gradients.training.sg_trainer.StrictLoad
value: no_key_matching
pretrained_weights: # a string describing the dataset of the pretrained weights (for example "imagenent").

# num_classes of checkpoint_path/ pretrained_weights, when checkpoint_path is not None.
# Used when num_classes != checkpoint_num_class.
# In this case, the module will be initialized with checkpoint_num_class, then weights will be loaded.
# Finally model.replace_head(new_num_classes=num_classes) is called to replace the head with new_num_classes.
checkpoint_num_classes: # number of classes in the checkpoint
2 changes: 2 additions & 0 deletions src/super_gradients/training/kd_trainer/kd_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> None:
pretrained_weights=cfg.student_checkpoint_params.pretrained_weights,
checkpoint_path=cfg.student_checkpoint_params.checkpoint_path,
load_backbone=cfg.student_checkpoint_params.load_backbone,
checkpoint_num_classes=get_param(cfg.student_checkpoint_params, "checkpoint_num_classes"),
)

teacher = models.get(
Expand All @@ -85,6 +86,7 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> None:
pretrained_weights=cfg.teacher_checkpoint_params.pretrained_weights,
checkpoint_path=cfg.teacher_checkpoint_params.checkpoint_path,
load_backbone=cfg.teacher_checkpoint_params.load_backbone,
checkpoint_num_classes=get_param(cfg.teacher_checkpoint_params, "checkpoint_num_classes"),
)

recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}
Expand Down
2 changes: 2 additions & 0 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ def evaluate_from_recipe(cls, cfg: DictConfig) -> Tuple[nn.Module, Tuple]:
pretrained_weights=cfg.checkpoint_params.pretrained_weights,
checkpoint_path=cfg.checkpoint_params.checkpoint_path,
load_backbone=cfg.checkpoint_params.load_backbone,
checkpoint_num_classes=get_param(cfg.checkpoint_params, "checkpoint_num_classes"),
)

# TEST
Expand Down Expand Up @@ -2340,6 +2341,7 @@ def quantize_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module,
pretrained_weights=cfg.checkpoint_params.pretrained_weights,
checkpoint_path=cfg.checkpoint_params.checkpoint_path,
load_backbone=False,
checkpoint_num_classes=get_param(cfg.checkpoint_params, "checkpoint_num_classes"),
)

recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}
Expand Down

0 comments on commit 8923bbc

Please sign in to comment.