This works great and passes the assertion test laid out in the remainder of the code. The issue arises when I attempt to make the model scriptable (Gist):
model_script=torch.jit.script(model)
This fails with the error (same for both YOLOv5 v4.0 and v6.0):
Starting TorchScript export with torch 1.9.0a0+gitd69c22d...
Traceback (most recent call last):
  File "convert_weights_issue.py", line 145, in <module>
    model_script = torch.jit.script(model)  # THIS FAILS
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_script.py", line 1096, in script
    return torch.jit._recursive.create_script_module(
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_recursive.py", line 412, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_recursive.py", line 474, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_script.py", line 497, in _construct
    init_fn(script_module)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_recursive.py", line 437, in init_fn
    cpp_module.setattr(name, orig_value)
RuntimeError: Could not cast attribute 'strides' to type List[int]: Unable to cast Python instance to C++ type (compile in debug mode for details)
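The error says that, while building the scripted module, TorchScript expected the attribute strides to be a List[int] and could not convert the Python value it found. One plausible cause, which is an assumption rather than something confirmed in this thread, is that the strides are floats or tensor elements, for example values read from a float stride tensor in an Ultralytics checkpoint. A minimal sketch of a defensive coercion step, using a hypothetical helper to_int_strides (not part of yolort) that would be applied before the module is constructed:

```python
from typing import Iterable, List


def to_int_strides(strides: Iterable) -> List[int]:
    """Coerce stride values to plain Python ints, as a List[int] attribute expects.

    Accepts ints, integral floats such as 8.0, and scalar objects exposing
    .item() (NumPy scalars, 0-dim tensors). Non-integral values are rejected.
    """
    result: List[int] = []
    for s in strides:
        value = s.item() if hasattr(s, "item") else s  # unwrap scalar containers
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise TypeError(f"unsupported stride value: {s!r}")
        if value != int(value):
            raise ValueError(f"stride must be integral, got {value!r}")
        result.append(int(value))
    return result


if __name__ == "__main__":
    # Floats such as those in tensor([8., 16., 32.]).tolist() become plain ints.
    print(to_int_strides([8.0, 16.0, 32.0]))  # [8, 16, 32]
    # Iterating a 1-D torch tensor yields 0-dim tensors, which .item() unwraps,
    # so passing a stride tensor directly would work the same way.
```

Whether the fix referenced later in this thread does something equivalent is not established here; the sketch only illustrates that a list of genuine Python ints is what a List[int] attribute can hold.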
Additionally, I found your comment here instructing me to load the model differently (Gist):
from yolort.models import yolov5s
model = yolov5s(upstream_version=rversion, score_thresh=score_thresh)
model.load_from_yolov5(checkpoint_path=path_ultralytics_weights, version=rversion)
model.eval()
results = model.predict(img_name)
print("Results from loading model via model.load_from_yolov5:")
print(results)
While this model is successfully scriptable via torch.jit.script(model), neither v4.0 nor v6.0 returns any detections:
Results from loading model via model.load_from_yolov5:
[{'scores': tensor([]), 'labels': tensor([], dtype=torch.int64), 'boxes': tensor([], size=(0, 4))}]
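The empty result does not by itself show where things go wrong, but it helps to recall how YOLOv5-style post-processing turns raw predictions into detections: each candidate's score is its objectness times its best class probability, candidates at or below score_thresh are dropped, and overlapping survivors are pruned by NMS. If every candidate falls at or below the threshold (0.30 in the accompanying script), the output is exactly the empty scores/labels/boxes shown above. The following is a simplified, pure-Python approximation of Ultralytics-style non_max_suppression (single label per box, class-agnostic as with agnostic=True in the script). It is not yolort's code, and postprocess, box_iou, and the row layout are hypothetical.

```python
from typing import Dict, List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)


def xywh_to_xyxy(cx: float, cy: float, w: float, h: float) -> Box:
    """Convert a center/size box to corner coordinates."""
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)


def box_iou(a: Box, b: Box) -> float:
    """Intersection-over-union of two corner-format boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def postprocess(rows: Sequence[Sequence[float]], score_thresh: float,
                iou_thresh: float) -> Dict[str, List]:
    """Score thresholding plus class-agnostic greedy NMS over raw rows.

    Each row is [cx, cy, w, h, objectness, class_0_prob, class_1_prob, ...].
    """
    candidates = []
    for row in rows:
        obj = row[4]
        if obj <= score_thresh:  # cheap objectness pre-filter
            continue
        class_scores = [obj * p for p in row[5:]]  # score = objectness * class prob
        label = max(range(len(class_scores)), key=class_scores.__getitem__)
        score = class_scores[label]
        if score > score_thresh:
            candidates.append((score, label, xywh_to_xyxy(*row[:4])))

    # Visit boxes by descending score; keep one only if it overlaps
    # no already-kept box by more than iou_thresh (class is ignored).
    candidates.sort(key=lambda c: c[0], reverse=True)
    kept = []
    for cand in candidates:
        if all(box_iou(cand[2], k[2]) <= iou_thresh for k in kept):
            kept.append(cand)

    return {
        "scores": [c[0] for c in kept],
        "labels": [c[1] for c in kept],
        "boxes": [list(c[2]) for c in kept],
    }


if __name__ == "__main__":
    rows = [
        [100, 100, 50, 50, 0.9, 0.8, 0.1],  # confident class-0 box
        [102, 101, 50, 50, 0.8, 0.7, 0.2],  # near-duplicate, removed by NMS
        [300, 300, 40, 40, 0.2, 0.9, 0.1],  # objectness below the threshold
    ]
    print(postprocess(rows, score_thresh=0.30, iou_thresh=0.45))
    # A threshold above every score gives empty lists, like the result reported above.
    print(postprocess(rows, score_thresh=0.95, iou_thresh=0.45))
```

Comparing raw candidate scores against score_thresh on both loading paths is one way to tell a thresholding problem apart from a weight-loading or pre-processing problem.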
Not sure what I'm doing wrong, but hopefully it's something simple like last time (#142)! Thank you for the help in advance!
Below is a script I wrote that you can use to verify/test everything I mentioned above using models from Ultralytics. Also available via this Gist.
convert_ultralytics_to_rt-stack.py:
# Author: Matt Popovich (mattpopovich.com)
# Date: January 6, 2022
# yolort Release: 0.5.2
# Except for a bit at the end, this is all copied from:
# https://github.com/zhiqwang/yolov5-rt-stack/blob/main/notebooks/how-to-align-with-ultralytics-yolov5.ipynb

import os
import cv2
import torch
import sys

# sys.path.insert(0, "/home/mpopovich/git/yolov5-rt-stack")
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')

from yolort.models.yolo import YOLO
from yolort.utils import (
    cv2_imshow,
    get_image_from_url,
    read_image_to_tensor,
)
from yolort.utils.image_utils import plot_one_box, color_list
from yolort.v5 import load_yolov5_model, letterbox, non_max_suppression, scale_coords, attempt_download

# Set LABELS and COLORS
import requests

label_path = "https://gitee.com/zhiqwang/yolov5-rt-stack/raw/master/notebooks/assets/coco.names"
response = requests.get(label_path)
names = response.text
LABELS = []
for label in names.strip().split('\n'):
    LABELS.append(label)
COLORS = color_list()

# Get image
img_name = 'bus.jpg'
img_url = 'https://raw.githubusercontent.com/zhiqwang/yolov5-rt-stack/main/test/assets/' + img_name
if os.path.isfile(img_name):
    print(img_name + " already downloaded!")
else:
    attempt_download(img_url)
    print("Downloaded " + img_name + " successfully!")
img_raw = cv2.imread(img_name)

# Preprocess
img = letterbox(img_raw, new_shape=(640, 640))[0]
img = read_image_to_tensor(img)
img = img.to(device)

version = "6.0"  # "4.0" or "6.0"
rversion = "r" + version
model_url = "https://github.com/ultralytics/yolov5/releases/download/v" + version + "/yolov5s.pt"
full_model_name = "yolov5s-v" + version + ".pt"

# Download model from Ultralytics GitHub
if os.path.isfile(full_model_name):
    print(full_model_name + " already downloaded!")
else:
    checkpoint_path = attempt_download(model_url)
    os.rename("yolov5s.pt", full_model_name)
    print("Downloaded " + full_model_name + " successfully!")
# Load Ultralytics model
score_thresh = 0.30
iou = 0.45
model = load_yolov5_model(full_model_name, autoshape=False, verbose=True)
model = model.to(device)
model.conf = score_thresh  # confidence threshold (0-1)
model.iou = iou  # NMS IoU threshold (0-1)
model.classes = None  # (optional list) filter by class, e.g. [0, 15, 16] for persons, cats and dogs
model = model.eval()

# Perform inference
with torch.no_grad():
    ultralytics_dets = model(img[None])[0]
    ultralytics_dets = non_max_suppression(ultralytics_dets, score_thresh, iou, agnostic=True)[0]
scaled_ultralytics_dets = ultralytics_dets.clone()
print("Ultralytics detections:")
print(ultralytics_dets)

# Save Ultralytics inference image
boxes = scale_coords(img.shape[1:], scaled_ultralytics_dets[:, :4], img_raw.shape[:-1])
labels = scaled_ultralytics_dets[:, 5:]
for box, label in zip(boxes.tolist(), labels.tolist()):
    img_raw = plot_one_box(box, img_raw, color=COLORS[int(label[0]) % len(COLORS)], label=LABELS[int(label[0])])
cv2.imwrite(os.path.splitext(img_name)[0] + '-ultralytics-inference.jpg', img_raw)
# # Loading the trained checkpoint as instructed in:
# # https://github.com/zhiqwang/yolov5-rt-stack/issues/141#issuecomment-924221401
# # This model is scriptable: torch.jit.script(model)
# # But the inference results are empty
# from yolort.models import yolov5s
# model = yolov5s(upstream_version=rversion, score_thresh=score_thresh)
# model.load_from_yolov5(checkpoint_path=full_model_name, version=rversion)
# model.eval()
# results = model.predict(img_name)
# print("Results from loading model via model.load_from_yolov5:")
# print(results)

# Update model weights from Ultralytics to yolort
# According to [8] in how-to-align-with-ultralytics-yolov5.ipynb
model = YOLO.load_from_yolov5(
    full_model_name,
    score_thresh=score_thresh,
    nms_thresh=iou,
    version=rversion,
)
model.eval()

with torch.no_grad():
    yolort_dets = model(img[None])

print(f"Detection boxes with yolort:\n{yolort_dets[0]['boxes']}")
print(f"Detection scores with yolort:\n{yolort_dets[0]['scores']}")
print(f"Detection labels with yolort:\n{yolort_dets[0]['labels']}")

# Verify the detection results between yolort and Ultralytics
# Testing boxes
torch.testing.assert_allclose(
    yolort_dets[0]['boxes'], ultralytics_dets[:, :4], rtol=1e-05, atol=1e-07)
# Testing scores
torch.testing.assert_allclose(
    yolort_dets[0]['scores'], ultralytics_dets[:, 4], rtol=1e-05, atol=1e-07)
# Testing labels
torch.testing.assert_allclose(
    yolort_dets[0]['labels'], ultralytics_dets[:, 5].to(dtype=torch.int64), rtol=1e-05, atol=1e-07)
print("Exported model has been tested, and the result looks good!")
# Save yolort inference image
img_raw = cv2.imread(img_name)  # reload a clean copy so the Ultralytics boxes drawn above are not included
boxes = scale_coords(img.shape[1:], yolort_dets[0]['boxes'], img_raw.shape[:-1])
labels = yolort_dets[0]['labels']
for box, label in zip(boxes.tolist(), labels.tolist()):
    img_raw = plot_one_box(box, img_raw, color=COLORS[label % len(COLORS)], label=LABELS[label])
cv2.imwrite(os.path.splitext(img_name)[0] + '-yolort-inference.jpg', img_raw)
# Scripting YOLOv5, basically a copy of:
# https://github.com/zhiqwang/yolov5-rt-stack/blob/main/notebooks/inference-pytorch-export-libtorch.ipynb
# TorchScript export
print(f'Starting TorchScript export with torch {torch.__version__}...')
export_script_name = os.path.splitext(full_model_name)[0] + '-RT-v0.5.2.torchscript.pt'
model_script = torch.jit.script(model)  # THIS FAILS
model_script.eval()

# Save the scripted model file for subsequent use (Optional)
model_script.save(export_script_name)
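The verification step in the script compares the two outputs element by element with rtol=1e-05 and atol=1e-07; torch.testing.assert_allclose accepts a value when |actual - expected| <= atol + rtol * |expected|, the same rule torch.allclose uses. A pure-Python restatement of that check on plain lists may help when reasoning about why a converted model does or does not match. matches_ultralytics is a hypothetical helper, not part of yolort, and it assumes the Ultralytics rows use the [x1, y1, x2, y2, score, class] layout that the script slices with [:, :4], [:, 4], and [:, 5].

```python
from typing import Dict, Sequence


def close(actual: float, expected: float, rtol: float = 1e-05, atol: float = 1e-07) -> bool:
    """Element-wise criterion used by torch.allclose: |a - e| <= atol + rtol * |e|."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


def matches_ultralytics(yolort_det: Dict[str, Sequence], ultra_rows: Sequence[Sequence[float]],
                        rtol: float = 1e-05, atol: float = 1e-07) -> bool:
    """Compare one image's yolort output with Ultralytics' post-NMS rows.

    yolort_det has 'boxes' (x1, y1, x2, y2), 'scores' and 'labels';
    each Ultralytics row is [x1, y1, x2, y2, score, class].
    Detections are compared in order, as the script does.
    """
    boxes, scores, labels = yolort_det["boxes"], yolort_det["scores"], yolort_det["labels"]
    if not (len(boxes) == len(scores) == len(labels) == len(ultra_rows)):
        return False  # a different number of detections can never match
    for box, score, label, row in zip(boxes, scores, labels, ultra_rows):
        if len(box) != 4 or not all(close(a, e, rtol, atol) for a, e in zip(box, row[:4])):
            return False
        if not close(score, row[4], rtol, atol):
            return False
        if int(label) != int(row[5]):  # the script casts Ultralytics classes to int64
            return False
    return True


if __name__ == "__main__":
    ultra = [[10.0, 20.0, 110.0, 220.0, 0.9, 5.0]]
    yolort = {"boxes": [[10.00001, 20.0, 110.0, 220.0]], "scores": [0.9], "labels": [5]}
    print(matches_ultralytics(yolort, ultra))  # True: the box offset is within tolerance
```

Applied to the empty result reported earlier for the yolov5s() path, this check would fail on the detection count alone, before any tolerance comes into play.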
Versions
# python3 -m torch.utils.collect_env
Collecting environment information...
PyTorch version: 1.9.0a0+gitd69c22d
Is debug build: False
CUDA used to build PyTorch: 11.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.2 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.21.1
Libc version: glibc-2.31

Python version: 3.8 (64-bit runtime)
Python platform: Linux-5.4.0-92-generic-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: 11.2.152
GPU models and configuration:
GPU 0: GeForce GTX 1080
GPU 1: GeForce GTX 1080
GPU 2: GeForce GTX 1080

Nvidia driver version: 460.91.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.1.0
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.21.4
[pip3] pytorch-lightning==1.5.8
[pip3] torch==1.9.0a0+gitd69c22d
[pip3] torchmetrics==0.6.2
[pip3] torchvision==0.10.0a0+300a8a4
[conda] Could not collect
Hi @mattpopovich, thanks for the detailed information about this problem. I can reproduce the issue here; it should be a bug, and we aim to fix it by the end of this week.
We now recommend using YOLOv5.load_from_yolov5() or YOLO.load_from_yolov5() to load checkpoints trained with ultralytics/yolov5, which reduces the conversion needed between yolov5 and yolort. I fixed the bug in exporting TorchScript for custom checkpoints trained with ultralytics/yolov5 in #267; see the example in #267 (comment) for more details.
BTW, YOLO.load_from_yolov5() behaves the same as ultralytics/yolov5, whereas YOLOv5.load_from_yolov5() implements a different pre-processing method, which leads to minor differences from yolov5, as we discussed last time in #142 (comment). (We plan to make the pre-processing behave the same as yolov5.)
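The pre-processing difference mentioned here comes down to how an image is resized and padded before inference and how boxes are mapped back afterwards (the script uses letterbox for the first step and scale_coords for the second). As a rough illustration, here is a pure-Python sketch of that geometry, modeled on the Ultralytics approach: scale by the limiting side, pad the remainder, and with auto=True pad only up to a multiple of the stride. The helper names and defaults (640x640 target, stride 32) are assumptions; the sketch returns fractional per-side padding rather than the integer pixel counts an actual image resize would use, and it is not the code used by yolort or Ultralytics.

```python
from typing import Tuple

HW = Tuple[int, int]  # (height, width), the order of img.shape[1:] in the script


def letterbox_geometry(img_hw: HW, new_hw: HW = (640, 640), stride: int = 32,
                       auto: bool = True) -> Tuple[float, Tuple[int, int], Tuple[float, float]]:
    """Return (gain, resized (w, h), padding (pad_w, pad_h) added on each side)."""
    h, w = img_hw
    gain = min(new_hw[0] / h, new_hw[1] / w)  # scale so the limiting side fits
    new_w, new_h = int(round(w * gain)), int(round(h * gain))
    dw, dh = new_hw[1] - new_w, new_hw[0] - new_h  # total padding needed
    if auto:  # minimum rectangle: only pad up to a multiple of `stride`
        dw, dh = dw % stride, dh % stride
    return gain, (new_w, new_h), (dw / 2, dh / 2)


def box_to_original(box: Tuple[float, float, float, float], letterboxed_hw: HW,
                    original_hw: HW) -> Tuple[float, float, float, float]:
    """Map one (x1, y1, x2, y2) box from the letterboxed image back to the original."""
    gain = min(letterboxed_hw[0] / original_hw[0], letterboxed_hw[1] / original_hw[1])
    pad_w = (letterboxed_hw[1] - original_hw[1] * gain) / 2
    pad_h = (letterboxed_hw[0] - original_hw[0] * gain) / 2
    x1, y1, x2, y2 = box
    h, w = original_hw

    def clip(v: float, hi: int) -> float:
        # Keep coordinates inside the original image bounds.
        return min(max(v, 0.0), float(hi))

    return (clip((x1 - pad_w) / gain, w), clip((y1 - pad_h) / gain, h),
            clip((x2 - pad_w) / gain, w), clip((y2 - pad_h) / gain, h))


if __name__ == "__main__":
    # A 500x640 (h x w) image: gain 1.0, 12 px of vertical padding split 6/6 with auto=True.
    print(letterbox_geometry((500, 640)))
    print(box_to_original((10, 16, 110, 116), (512, 640), (500, 640)))
```

Small differences in this arithmetic, such as padding to a full 640x640 square instead of the minimum rectangle, change the network input and can shift scores and boxes slightly, which is the kind of minor difference described above.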
We also provide a CLI tool to convert checkpoints from yolov5 to yolort; I think this tool can also solve this problem if you load the converted checkpoint with yolort.models.yolov5s().
I believe this issue has been resolved, so I'm closing it. Feel free to reopen it or create another ticket if you have more questions.
🐛 Describe the bug
I believe the recommended way to do this is to first convert the Ultralytics weights to yolort using how-to-align-with-ultralytics-yolov5.ipynb and then make the model scriptable using inference-pytorch-export-libtorch.ipynb. I have tried this with both YOLOv5 releases 4.0 and 6.0, but without success.
Following the instructions in how-to-align-with-ultralytics-yolov5.ipynb, I convert the model weights from Ultralytics to yolort (Gist):