Skip to content

Commit

Permalink
Add load_from_yolov5 in YOLOModule
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Sep 21, 2021
1 parent 251226f commit 0d0b5f1
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion yolort/models/yolo_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
from torchvision.io import read_image

from pytorch_lightning import LightningModule
from typing import Any, Callable, List, Dict, Tuple, Optional, Union
from typing import Any, List, Dict, Tuple, Optional, Union, Callable

from yolort.data import COCOEvaluator, contains_any_tensor
from yolort.utils.update_module_state import ModuleStateUpdate
from yolort.v5 import load_yolov5_model

from . import yolo
from .transform import YOLOTransform
Expand Down Expand Up @@ -50,6 +52,7 @@ def __init__(
super().__init__()

self.lr = lr
self.arch = arch
self.num_classes = num_classes

self.model = yolo.__dict__[arch](
Expand Down Expand Up @@ -268,3 +271,16 @@ def add_model_specific_args(parent_parser):
parser.add_argument('--weight-decay', default=5e-4, type=float,
metavar='W', help='weight decay (default: 5e-4)')
return parser

def load_from_yolov5(self, checkpoint_path: str):
"""
Load model state from the checkpoint trained by YOLOv5.
Args:
checkpoint_path (str): Path of the YOLOv5 checkpoint model.
"""
checkpoint_yolov5 = load_yolov5_model(checkpoint_path)
module_state_updater = ModuleStateUpdate(arch=self.arch, num_classes=self.num_classes)
module_state_updater.updating(checkpoint_yolov5)
state_dict = module_state_updater.model.state_dict()
self.model.load_state_dict(state_dict)

0 comments on commit 0d0b5f1

Please sign in to comment.