How to convert custom Ultralytics YOLOv5 weights to yolo-rt weights and export as torchscript? #265

Closed
mattpopovich opened this issue Jan 7, 2022 · 2 comments · Fixed by #267
Labels
bug / fix Something isn't working

Comments

@mattpopovich
Contributor

mattpopovich commented Jan 7, 2022

🐛 Describe the bug

I believe the recommended way to do this is to first convert your Ultralytics weights to yolort via how-to-align-with-ultralytics-yolov5.ipynb and to then make the model scriptable via inference-pytorch-export-libtorch.ipynb. I have tried this with both YOLOv5 releases 4.0 and 6.0, but with no success.

Following the instructions on how-to-align-with-ultralytics-yolov5.ipynb, I convert the model weights from Ultralytics to yolort via (Gist):

from yolort.models.yolo import YOLO
model = YOLO.load_from_yolov5(
    path_ultralytics_weights,
    score_thresh=score_thresh,
    nms_thresh=iou,
    version=rversion,
)

This works great and passes the assertion test laid out in the remainder of the code. The issue arises when I attempt to make the model scriptable (Gist):

model_script = torch.jit.script(model)

This fails with the error (same for both YOLOv5 v4.0 and v6.0):

Starting TorchScript export with torch 1.9.0a0+gitd69c22d...
Traceback (most recent call last):
  File "convert_weights_issue.py", line 145, in <module>
    model_script = torch.jit.script(model)  # THIS FAILS
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_script.py", line 1096, in script
    return torch.jit._recursive.create_script_module(
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_recursive.py", line 412, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_recursive.py", line 474, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_script.py", line 497, in _construct
    init_fn(script_module)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_recursive.py", line 437, in init_fn
    cpp_module.setattr(name, orig_value)
RuntimeError: Could not cast attribute 'strides' to type List[int]: Unable to cast Python instance to C++ type (compile in debug mode for details)
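
The last frame of the traceback shows TorchScript failing to store the model's `strides` attribute as `List[int]`. One plausible cause, which is an assumption here and not confirmed in this thread, is that the values are held as floats or tensor elements instead of plain Python integers. A minimal sketch of a normalisation helper follows; `to_int_list` is a hypothetical name and not part of yolort or PyTorch.

```python
from typing import Iterable, List


def to_int_list(values: Iterable) -> List[int]:
    """Convert stride-like values (ints, integral floats, or tensor
    elements exposing .item()) into plain Python ints.

    Hypothetical helper for illustration; not part of yolort or PyTorch.
    """
    result: List[int] = []
    for value in values:
        # Tensor scalars expose .item(), which returns the underlying
        # Python number.
        if hasattr(value, "item"):
            value = value.item()
        as_int = int(value)
        # Refuse to silently truncate something like 8.5.
        if as_int != value:
            raise ValueError(f"stride {value!r} is not an integer value")
        result.append(as_int)
    return result


print(to_int_list([8.0, 16.0, 32.0]))  # prints [8, 16, 32]
```

Whether this matches the actual root cause depends on how the attribute is built during weight conversion, so treat it as a debugging aid only.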

Additionally, I found your comment here that suggests loading the model differently (Gist):

from yolort.models import yolov5s
model = yolov5s(upstream_version=rversion, score_thresh=score_thresh)
model.load_from_yolov5(checkpoint_path=path_ultralytics_weights, version=rversion)
model.eval()
results = model.predict(img_name)
print("Results from loading model via model.load_from_yolov5:")
print(results)

While this model is successfully scriptable via torch.jit.script(model), neither v4.0 nor v6.0 returns any detections:

Results from loading model via model.load_from_yolov5:
[{'scores': tensor([]), 'labels': tensor([], dtype=torch.int64), 'boxes': tensor([], size=(0, 4))}]
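
An empty result like this one is easy to miss in an automated conversion script. Below is a small sketch of a sanity check that counts detections per image before moving on to export. `count_detections` is a hypothetical helper, and it only assumes the list-of-dicts result layout printed above.

```python
from typing import Dict, List, Sequence


def count_detections(results: Sequence[Dict[str, Sequence]]) -> List[int]:
    """Return the number of detections for each image in `results`.

    Assumes each entry is a dict with a 'scores' sequence, as in the
    output printed above. It only relies on len(), so it works with
    1-D tensors as well as plain lists.
    """
    return [len(det["scores"]) for det in results]


# Plain lists stand in for tensors so the example runs without torch.
empty = [{"scores": [], "labels": [], "boxes": []}]
counts = count_detections(empty)
if not any(counts):
    print("No detections: check the weights, preprocessing, and thresholds.")
```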

Not sure what I'm doing wrong, but hopefully it's something simple like last time (#142)! Thank you for the help in advance!

Below is a script I wrote that you can use to verify/test everything I mentioned above using models from Ultralytics. Also available via this Gist.

convert_ultralytics_to_rt-stack.py:

# Author: Matt Popovich (mattpopovich.com)
# Date: January 6, 2022
# yolort Release: 0.5.2

# Except for a bit at the end, this is all copied from:
#   https://github.com/zhiqwang/yolov5-rt-stack/blob/main/notebooks/how-to-align-with-ultralytics-yolov5.ipynb

import os
import cv2
import torch

import sys
# sys.path.insert(0, "/home/mpopovich/git/yolov5-rt-stack")

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')

from yolort.models.yolo import YOLO
from yolort.utils import (
    cv2_imshow,
    get_image_from_url,
    read_image_to_tensor,
)
from yolort.utils.image_utils import plot_one_box, color_list
from yolort.v5 import load_yolov5_model, letterbox, non_max_suppression, scale_coords, attempt_download

# Set LABELS and COLORS
import requests
label_path = "https://gitee.com/zhiqwang/yolov5-rt-stack/raw/master/notebooks/assets/coco.names"
response = requests.get(label_path)
names = response.text
LABELS = names.strip().split('\n')  # one class name per line
COLORS = color_list()

# Get image 
img_name = 'bus.jpg'
img_url = 'https://raw.githubusercontent.com/zhiqwang/yolov5-rt-stack/main/test/assets/' + img_name
if os.path.isfile(img_name):
    print(img_name + " already downloaded!")
else:
    attempt_download(img_url)
    print("Downloaded " + img_name + " successfully!")

img_raw = cv2.imread(img_name)

# Preprocess
img = letterbox(img_raw, new_shape=(640, 640))[0]
img = read_image_to_tensor(img)
img = img.to(device)

version = "6.0"     # "4.0" or "6.0"
rversion = "r" + version
model_url = "https://github.com/ultralytics/yolov5/releases/download/v" + version + "/yolov5s.pt"
full_model_name = "yolov5s-v" + version + ".pt"

# Download model from Ultralytics GitHub
if os.path.isfile(full_model_name):
    print(full_model_name + " already downloaded!")
else:
    checkpoint_path = attempt_download(model_url)
    os.rename("yolov5s.pt", full_model_name)
    print("Downloaded " + full_model_name + " successfully!")

# Load Ultralytics model 
score_thresh = 0.30
iou = 0.45
model = load_yolov5_model(full_model_name, autoshape=False, verbose=True)
model = model.to(device)
model.conf = score_thresh  # confidence threshold (0-1)
model.iou = iou  # NMS IoU threshold (0-1)
model.classes = None  # (optional list) filter by class, i.e. = [0, 15, 16] for persons, cats and dogs
model = model.eval()

# Perform inference
with torch.no_grad():
    ultralytics_dets = model(img[None])[0]
    ultralytics_dets = non_max_suppression(ultralytics_dets, score_thresh, iou, agnostic=True)[0]
    scaled_ultralytics_dets = ultralytics_dets.clone()
print("Ultralytics detections:")
print(ultralytics_dets)

# Save Ultralytics inference image
boxes = scale_coords(img.shape[1:], scaled_ultralytics_dets[:,:4], img_raw.shape[:-1])
labels = scaled_ultralytics_dets[:,5:]
for box, label in zip(boxes.tolist(), labels.tolist()):
    img_raw = plot_one_box(box, img_raw, color=COLORS[int(label[0]) % len(COLORS)], label=LABELS[int(label[0])])
cv2.imwrite(os.path.splitext(img_name)[0] + '-ultralytics-inference.jpg', img_raw)

# # Loading the trained checkpoint as instructed in:
# #     https://github.com/zhiqwang/yolov5-rt-stack/issues/141#issuecomment-924221401
# # This model is able to be scriptable: torch.jit.script(model) 
# #     But the inference results are empty
# from yolort.models import yolov5s
# model = yolov5s(upstream_version=rversion, score_thresh=score_thresh)
# model.load_from_yolov5(checkpoint_path=full_model_name, version=rversion)
# model.eval()
# results = model.predict(img_name)
# print("Results from loading model via model.load_from_yolov5:")
# print(results)

# Update model weights from Ultralytics to yolort 
# According to [8] in how-to-align-with-ultralytics-yolov5.ipynb
model = YOLO.load_from_yolov5(
    full_model_name,
    score_thresh=score_thresh,
    nms_thresh=iou,
    version=rversion,
)
model.eval()
with torch.no_grad():
    yolort_dets = model(img[None])
print(f"Detection boxes with yolort:\n{yolort_dets[0]['boxes']}")
print(f"Detection scores with yolort:\n{yolort_dets[0]['scores']}")
print(f"Detection labels with yolort:\n{yolort_dets[0]['labels']}")

# Verify the detection results between yolort and Ultralytics 
# Testing boxes
torch.testing.assert_allclose(
    yolort_dets[0]['boxes'], ultralytics_dets[:, :4], rtol=1e-05, atol=1e-07)
# Testing scores
torch.testing.assert_allclose(
    yolort_dets[0]['scores'], ultralytics_dets[:, 4], rtol=1e-05, atol=1e-07)
# Testing labels
torch.testing.assert_allclose(
    yolort_dets[0]['labels'], ultralytics_dets[:, 5].to(dtype=torch.int64), rtol=1e-05, atol=1e-07)
print("Exported model has been tested, and the result looks good!")

# Save yolort inference image 
boxes = scale_coords(img.shape[1:], yolort_dets[0]['boxes'], img_raw.shape[:-1])
labels = yolort_dets[0]['labels']
for box, label in zip(boxes.tolist(), labels.tolist()):
    img_raw = plot_one_box(box, img_raw, color=COLORS[label % len(COLORS)], label=LABELS[label])
cv2.imwrite(os.path.splitext(img_name)[0] + '-yolort-inference.jpg', img_raw)

# Scripting YOLOv5, basically a copy of:
# https://github.com/zhiqwang/yolov5-rt-stack/blob/main/notebooks/inference-pytorch-export-libtorch.ipynb
# TorchScript export
print(f'Starting TorchScript export with torch {torch.__version__}...')
export_script_name = os.path.splitext(full_model_name)[0] + '-RT-v0.5.2.torchscript.pt'
model_script = torch.jit.script(model)  # THIS FAILS
model_script.eval()
# Save the scripted model file for subsequent use (Optional)
model_script.save(export_script_name)
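
The verification step in the script above uses `torch.testing.assert_allclose`, which essentially accepts two values when |actual - expected| <= atol + rtol * |expected|. Here is a pure-Python sketch of that elementwise check with the same tolerances as the script. `all_close` is a hypothetical helper, and the numbers are made up for illustration, not taken from a real run.

```python
from typing import Sequence


def all_close(actual: Sequence[float], expected: Sequence[float],
              rtol: float = 1e-05, atol: float = 1e-07) -> bool:
    """Elementwise closeness test: |a - e| <= atol + rtol * |e|."""
    if len(actual) != len(expected):
        return False
    return all(abs(a - e) <= atol + rtol * abs(e)
               for a, e in zip(actual, expected))


# Illustrative values only (not real detection scores).
expected_scores = [0.88, 0.75, 0.61]
close_scores = [0.880001, 0.75, 0.61]   # differs by about 1e-6
far_scores = [0.89, 0.75, 0.61]         # differs by 1e-2

print(all_close(close_scores, expected_scores))  # prints True
print(all_close(far_scores, expected_scores))    # prints False
```

With rtol=1e-05 and a score near 0.88, the allowed difference is roughly 8.9e-06, so the check is strict enough to catch preprocessing mismatches while tolerating floating-point noise.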

Versions


# python3 -m torch.utils.collect_env
Collecting environment information...
PyTorch version: 1.9.0a0+gitd69c22d
Is debug build: False
CUDA used to build PyTorch: 11.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.2 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.21.1
Libc version: glibc-2.31

Python version: 3.8 (64-bit runtime)
Python platform: Linux-5.4.0-92-generic-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: 11.2.152
GPU models and configuration: 
GPU 0: GeForce GTX 1080
GPU 1: GeForce GTX 1080
GPU 2: GeForce GTX 1080

Nvidia driver version: 460.91.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.1.0
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.21.4
[pip3] pytorch-lightning==1.5.8
[pip3] torch==1.9.0a0+gitd69c22d
[pip3] torchmetrics==0.6.2
[pip3] torchvision==0.10.0a0+300a8a4
[conda] Could not collect

zhiqwang added the bug / fix label on Jan 7, 2022
@zhiqwang
Owner

zhiqwang commented Jan 7, 2022

Hi @mattpopovich, thanks for the detailed information about this problem. I can reproduce the issue here, and it looks like a bug. We aim to fix it by the end of this week.

@zhiqwang
Owner

zhiqwang commented Jan 9, 2022

Hi @mattpopovich,

We now recommend using YOLOv5.load_from_yolov5() or YOLO.load_from_yolov5() to load checkpoints trained with ultralytics/yolov5, which reduces the conversion needed between yolov5 and yolort. I fixed the bug in exporting TorchScript for custom checkpoints trained with ultralytics/yolov5 in #267; see the example in #267 (comment) for more details.

BTW, YOLO.load_from_yolov5() behaves the same as ultralytics/yolov5. YOLOv5.load_from_yolov5(), in contrast, uses a different pre-processing method, so its results differ slightly from yolov5, as we discussed last time in #142 (comment). (We plan to make the pre-processing behave the same as yolov5.)
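
Putting the recommendation together, the conversion and export could look like the sketch below. It assumes a yolort build that includes the fix from #267 and reuses the `YOLO.load_from_yolov5` arguments from the original report. `upstream_tag`, `export_name`, and `convert_and_script` are hypothetical helper names, and the heavy imports are deferred into the function so the file can be imported without torch or yolort installed.

```python
import os


def upstream_tag(version: str) -> str:
    """Map a YOLOv5 release such as '6.0' to the 'r6.0' tag used in the report."""
    return "r" + version


def export_name(weights_path: str) -> str:
    """Derive the TorchScript file name from the checkpoint file name."""
    stem = os.path.splitext(weights_path)[0]
    return stem + ".torchscript.pt"


def convert_and_script(weights_path: str, version: str = "6.0",
                       score_thresh: float = 0.30,
                       nms_thresh: float = 0.45) -> str:
    """Load Ultralytics weights into yolort, script the model, and save it.

    Assumes a yolort version containing the fix from #267.
    """
    # Deferred imports: only needed when the conversion actually runs.
    import torch
    from yolort.models.yolo import YOLO

    model = YOLO.load_from_yolov5(
        weights_path,
        score_thresh=score_thresh,
        nms_thresh=nms_thresh,
        version=upstream_tag(version),
    )
    model.eval()
    scripted = torch.jit.script(model)
    out_path = export_name(weights_path)
    scripted.save(out_path)
    return out_path


print(export_name("yolov5s-v6.0.pt"))  # prints yolov5s-v6.0.torchscript.pt
```

Calling `convert_and_script("yolov5s-v6.0.pt")` would then write the scripted model next to the original checkpoint, provided the weights file exists locally.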

We also provide a CLI tool to translate checkpoints from yolov5 to yolort. I think this tool can also solve the problem if you load the translated checkpoint in yolort.models.yolov5s().

I believe this resolves the issue, so I'm closing it. Feel free to reopen it or create another ticket if you have more questions.
