diff --git a/yolort/utils/update_module_state.py b/yolort/utils/update_module_state.py index 95c02b78..aefbbef6 100644 --- a/yolort/utils/update_module_state.py +++ b/yolort/utils/update_module_state.py @@ -1,6 +1,6 @@ # Copyright (c) 2020, Zhiqiang Wang. All Rights Reserved. from functools import reduce -from typing import Dict, Optional +from typing import List, Dict, Optional from torch import nn @@ -29,6 +29,9 @@ def load_from_ultralytics(checkpoint_path: str, version: str = "r6.0"): num_classes = checkpoint_yolov5.yaml["nc"] strides = checkpoint_yolov5.stride anchor_grids = checkpoint_yolov5.yaml["anchors"] + if isinstance(anchor_grids, int): + anchor_grids = [list(range(anchor_grids * 2))] * len(strides) + depth_multiple = checkpoint_yolov5.yaml["depth_multiple"] width_multiple = checkpoint_yolov5.yaml["width_multiple"] @@ -70,6 +73,8 @@ def load_from_ultralytics(checkpoint_path: str, version: str = "r6.0"): inner_block_maps=inner_block_maps, layer_block_maps=layer_block_maps, p6_block_maps=p6_block_maps, + strides=strides, + anchor_grids=anchor_grids, head_ind=head_ind, head_name=head_name, num_classes=num_classes, @@ -105,6 +110,8 @@ def __init__( inner_block_maps: Optional[Dict[str, str]] = None, layer_block_maps: Optional[Dict[str, str]] = None, p6_block_maps: Optional[Dict[str, str]] = None, + strides: Optional[List[int]] = None, + anchor_grids: Optional[List[List[float]]] = None, head_ind: int = 24, head_name: str = "m", num_classes: int = 80, @@ -144,6 +151,8 @@ def __init__( version, num_classes=num_classes, use_p6=use_p6, + strides=strides, + anchor_grids=anchor_grids, ) def updating(self, state_dict):