Skip to content

Commit

Permalink
Fix model initialization in ModuleStateUpdate (#201)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang authored Oct 17, 2021
1 parent 42021e2 commit fab48d7
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion yolort/utils/update_module_state.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) 2020, Zhiqiang Wang. All Rights Reserved.
from functools import reduce
from typing import Dict, Optional
from typing import List, Dict, Optional

from torch import nn

Expand Down Expand Up @@ -29,6 +29,9 @@ def load_from_ultralytics(checkpoint_path: str, version: str = "r6.0"):
num_classes = checkpoint_yolov5.yaml["nc"]
strides = checkpoint_yolov5.stride
anchor_grids = checkpoint_yolov5.yaml["anchors"]
if isinstance(anchor_grids, int):
anchor_grids = [list(range(anchor_grids * 2))] * len(strides)

depth_multiple = checkpoint_yolov5.yaml["depth_multiple"]
width_multiple = checkpoint_yolov5.yaml["width_multiple"]

Expand Down Expand Up @@ -70,6 +73,8 @@ def load_from_ultralytics(checkpoint_path: str, version: str = "r6.0"):
inner_block_maps=inner_block_maps,
layer_block_maps=layer_block_maps,
p6_block_maps=p6_block_maps,
strides=strides,
anchor_grids=anchor_grids,
head_ind=head_ind,
head_name=head_name,
num_classes=num_classes,
Expand Down Expand Up @@ -105,6 +110,8 @@ def __init__(
inner_block_maps: Optional[Dict[str, str]] = None,
layer_block_maps: Optional[Dict[str, str]] = None,
p6_block_maps: Optional[Dict[str, str]] = None,
strides: Optional[List[int]] = None,
anchor_grids: Optional[List[List[float]]] = None,
head_ind: int = 24,
head_name: str = "m",
num_classes: int = 80,
Expand Down Expand Up @@ -144,6 +151,8 @@ def __init__(
version,
num_classes=num_classes,
use_p6=use_p6,
strides=strides,
anchor_grids=anchor_grids,
)

def updating(self, state_dict):
Expand Down

0 comments on commit fab48d7

Please sign in to comment.