Support autoevolved anchors from ultralytics (#120)
* Pass anchor grids as a parameter in model creation. Example in the ultralytics notebook

* Updated cell outputs
Tomakko authored Jun 16, 2021
1 parent 68d955c commit 7b20f38
Showing 2 changed files with 78 additions and 89 deletions.
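In short, the notebook change reads the (possibly autoevolved) anchors back from the ultralytics Detect() head and passes them as anchor_grids when the yolort model is constructed, instead of relying on yolort's built-in defaults. A condensed sketch of the updated flow follows; it is not the notebook verbatim, and the import path of update_module_state_from_ultralytics plus the threshold/device boilerplate are assumptions, since the notebook's first cells are elided in this diff.

import torch
from yolort.models.yolo import yolov5_darknet_pan_s_r40 as yolov5s
# Assumed import path; the notebook's import cell is not shown in this diff.
from yolort.utils import update_module_state_from_ultralytics

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
conf, iou = 0.25, 0.45  # thresholds used in the notebook cells below

ultralytics_weights_path = "./yolov5s.pt"
yolort_weight_path = "./yolov5s_r40_updated.pt"

# Load the checkpoint the ultralytics way; autoanchor may have evolved the anchors,
# so read them back from the Detect() head rather than hard-coding the defaults.
model = torch.hub.load('ultralytics/yolov5', 'custom',
                       path=ultralytics_weights_path, autoshape=False, force_reload=True)
m = model.model[-1]  # Detect() layer
anchor_grids = m.anchor_grid.squeeze().view((3, 6)).tolist()

# Convert the ultralytics state dict to yolort's layout and save it.
converted = update_module_state_from_ultralytics(arch='yolov5s',
                                                 version='v4.0',
                                                 custom_path_or_model=ultralytics_weights_path,
                                                 set_fp16=False,
                                                 num_classes=80)
torch.save(converted.state_dict(), yolort_weight_path)

# Create the yolort model with the same anchors and load the converted weights.
model = yolov5s(score_thresh=conf, nms_thresh=iou,
                num_classes=80, anchor_grids=anchor_grids)
model.load_state_dict(torch.load(yolort_weight_path))
model = model.to(device).eval()

Passing anchor_grids explicitly matters when the checkpoint was trained with ultralytics' autoanchor, because the evolved anchors no longer match the defaults baked into yolort.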
116 changes: 48 additions & 68 deletions notebooks/how-to-align-with-ultralytics-yolov5.ipynb
@@ -3,6 +3,7 @@
{
"cell_type": "code",
"execution_count": 1,
"id": "2db8c2b2",
"metadata": {},
"outputs": [],
"source": [
@@ -18,6 +19,7 @@
{
"cell_type": "code",
"execution_count": 2,
"id": "e58aed40",
"metadata": {},
"outputs": [],
"source": [
@@ -34,6 +36,7 @@
},
{
"cell_type": "markdown",
"id": "ad808a2a",
"metadata": {},
"source": [
"## Prepare image and model weights to test"
@@ -42,6 +45,7 @@
{
"cell_type": "code",
"execution_count": 3,
"id": "d25d1d17",
"metadata": {},
"outputs": [],
"source": [
@@ -51,12 +55,13 @@
"img_path = \"https://gitee.com/zhiqwang/yolov5-rt-stack/raw/master/test/assets/bus.jpg\" # For user in China\n",
"img_raw = get_image_from_url(img_path)\n",
"\n",
"yolort_weight_path = 'yolov5s_r40_updated.pt'\n",
"ultralytics_weights_path = \"yolov5s.pt\""
"yolort_weight_path = './yolov5s_r40_updated.pt'\n",
"ultralytics_weights_path = \"./yolov5s.pt\""
]
},
{
"cell_type": "markdown",
"id": "2bdbc220",
"metadata": {},
"source": [
"You can download the weight with following methods\n",
@@ -73,6 +78,7 @@
{
"cell_type": "code",
"execution_count": 4,
"id": "56cc6a9f",
"metadata": {},
"outputs": [],
"source": [
@@ -84,40 +90,33 @@
},
{
"cell_type": "markdown",
"id": "1008f18d",
"metadata": {},
"source": [
"## Load model as ultralytics and inference"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"id": "3b3bbe08",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using cache found in /root/.cache/torch/hub/ultralytics_yolov5_master\n",
"Fusing layers... \n",
"Model Summary: 224 layers, 7266973 parameters, 0 gradients\n",
"Adding AutoShape... \n",
"YOLOv5 🚀 v5.0-189-gdaab682 torch 1.8.1+cu102 CUDA:0 (Tesla P100-SXM2-16GB, 16276.25MB)\n",
"\n"
]
}
],
"outputs": [],
"source": [
"conf = 0.25\n",
"iou = 0.45\n",
"\n",
"model = torch.hub.load('ultralytics/yolov5', 'custom', path=ultralytics_weights_path)\n",
"model = torch.hub.load('ultralytics/yolov5', 'custom', path=ultralytics_weights_path,autoshape=False, force_reload=True)\n",
"model = model.to(device)\n",
"model.conf = conf # confidence threshold (0-1)\n",
"model.iou = iou # NMS IoU threshold (0-1)\n",
"model.classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for persons, cats and dogs\n",
"model.eval()\n",
"\n",
"# Get actual anchors from ultralytics model\n",
"m = model.model[-1] # get Detect() layer\n",
"anchor_grids = m.anchor_grid.squeeze().view((3,6)).tolist() # get anchors\n",
"\n",
"with torch.no_grad():\n",
" ultralytics_dets = model(img[None])[0]\n",
" ultralytics_dets = non_max_suppression(ultralytics_dets, conf, iou, agnostic=True)[0]"
@@ -126,18 +125,19 @@
{
"cell_type": "code",
"execution_count": 6,
"id": "8882ef33",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Detection results with ultralytics:\n",
"tensor([[3.95028e+02, 2.28687e+02, 4.80437e+02, 5.25319e+02, 8.81427e-01, 0.00000e+00],\n",
" [1.31801e+02, 2.40823e+02, 2.05202e+02, 5.10122e+02, 8.74422e-01, 0.00000e+00],\n",
" [3.28450e+01, 2.39402e+02, 1.42193e+02, 5.31945e+02, 8.50409e-01, 0.00000e+00],\n",
"tensor([[3.95028e+02, 2.28687e+02, 4.80437e+02, 5.25319e+02, 8.81428e-01, 0.00000e+00],\n",
" [1.31801e+02, 2.40823e+02, 2.05202e+02, 5.10122e+02, 8.74423e-01, 0.00000e+00],\n",
" [3.28450e+01, 2.39402e+02, 1.42193e+02, 5.31945e+02, 8.50408e-01, 0.00000e+00],\n",
" [1.81174e+01, 1.36144e+02, 4.74266e+02, 4.48792e+02, 7.12929e-01, 5.00000e+00],\n",
" [1.97870e-01, 2.94924e+02, 4.41640e+01, 5.27107e+02, 4.00531e-01, 0.00000e+00]])\n"
" [1.97870e-01, 2.94923e+02, 4.41640e+01, 5.27107e+02, 4.00531e-01, 0.00000e+00]])\n"
]
}
],
@@ -147,79 +147,46 @@
},
{
"cell_type": "markdown",
"id": "3445954a",
"metadata": {},
"source": [
"## Updating model weights from ultralytics to yolort"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"id": "f0901ec9",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using cache found in /root/.cache/torch/hub/ultralytics_yolov5_v4.0\n",
"\n",
" from n params module arguments \n",
" 0 -1 1 3520 models.common.Focus [3, 32, 3] \n",
" 1 -1 1 18560 models.common.Conv [32, 64, 3, 2] \n",
" 2 -1 1 18816 models.common.C3 [64, 64, 1] \n",
" 3 -1 1 73984 models.common.Conv [64, 128, 3, 2] \n",
" 4 -1 1 156928 models.common.C3 [128, 128, 3] \n",
" 5 -1 1 295424 models.common.Conv [128, 256, 3, 2] \n",
" 6 -1 1 625152 models.common.C3 [256, 256, 3] \n",
" 7 -1 1 1180672 models.common.Conv [256, 512, 3, 2] \n",
" 8 -1 1 656896 models.common.SPP [512, 512, [5, 9, 13]] \n",
" 9 -1 1 1182720 models.common.C3 [512, 512, 1, False] \n",
" 10 -1 1 131584 models.common.Conv [512, 256, 1, 1] \n",
" 11 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n",
" 12 [-1, 6] 1 0 models.common.Concat [1] \n",
" 13 -1 1 361984 models.common.C3 [512, 256, 1, False] \n",
" 14 -1 1 33024 models.common.Conv [256, 128, 1, 1] \n",
" 15 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n",
" 16 [-1, 4] 1 0 models.common.Concat [1] \n",
" 17 -1 1 90880 models.common.C3 [256, 128, 1, False] \n",
" 18 -1 1 147712 models.common.Conv [128, 128, 3, 2] \n",
" 19 [-1, 14] 1 0 models.common.Concat [1] \n",
" 20 -1 1 296448 models.common.C3 [256, 256, 1, False] \n",
" 21 -1 1 590336 models.common.Conv [256, 256, 3, 2] \n",
" 22 [-1, 10] 1 0 models.common.Concat [1] \n",
" 23 -1 1 1182720 models.common.C3 [512, 512, 1, False] \n",
" 24 [17, 20, 23] 1 229245 models.yolo.Detect [80, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [128, 256, 512]]\n",
"Model Summary: 283 layers, 7276605 parameters, 7276605 gradients, 17.1 GFLOPs\n",
"\n",
"Adding AutoShape... \n"
]
}
],
"outputs": [],
"source": [
"model = update_module_state_from_ultralytics(arch='yolov5s',\n",
" version='v4.0',\n",
" custom_path_or_model=ultralytics_weights_path,\n",
" set_fp16=is_half)\n",
" set_fp16=is_half,\n",
" num_classes=80)\n",
"\n",
"torch.save(model.state_dict(), yolort_weight_path)"
]
},
{
"cell_type": "markdown",
"id": "1b117bd3",
"metadata": {},
"source": [
"## Load model as yolort and inference"
]
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"id": "05bd5867",
"metadata": {},
"outputs": [],
"source": [
"from yolort.models.yolo import yolov5_darknet_pan_s_r40 as yolov5s\n",
"\n",
"model = yolov5s(score_thresh=conf, nms_thresh=iou)\n",
"model = yolov5s(score_thresh=conf, nms_thresh=iou, num_classes=80,anchor_grids=anchor_grids)\n",
"model.load_state_dict(torch.load(yolort_weight_path))\n",
"\n",
"# Load model\n",
@@ -234,6 +201,7 @@
{
"cell_type": "code",
"execution_count": 9,
"id": "7a5be968",
"metadata": {},
"outputs": [
{
@@ -245,7 +213,7 @@
" [1.31801e+02, 2.40823e+02, 2.05202e+02, 5.10122e+02],\n",
" [3.28450e+01, 2.39402e+02, 1.42193e+02, 5.31945e+02],\n",
" [1.81174e+01, 1.36144e+02, 4.74266e+02, 4.48792e+02],\n",
" [1.97876e-01, 2.94923e+02, 4.41640e+01, 5.27107e+02]])\n"
" [1.97870e-01, 2.94923e+02, 4.41640e+01, 5.27107e+02]])\n"
]
}
],
@@ -256,6 +224,7 @@
{
"cell_type": "code",
"execution_count": 10,
"id": "ae9b4284",
"metadata": {},
"outputs": [
{
@@ -274,6 +243,7 @@
{
"cell_type": "code",
"execution_count": 11,
"id": "8f5eb19f",
"metadata": {},
"outputs": [
{
@@ -291,6 +261,7 @@
},
{
"cell_type": "markdown",
"id": "7e2afd7f",
"metadata": {},
"source": [
"## Varify the detection results between yolort and ultralytics"
@@ -299,6 +270,7 @@
{
"cell_type": "code",
"execution_count": 12,
"id": "b4313594",
"metadata": {},
"outputs": [
{
@@ -314,13 +286,21 @@
"\n",
"print(\"Exported model has been tested, and the result looks good!\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1bcbab38",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "venv",
"language": "python",
"name": "python3"
"name": "venv"
},
"language_info": {
"codemirror_mode": {

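The notebook's final section checks that both paths agree on the test image. A minimal sketch of that comparison, reusing ultralytics_dets and yolort_dets from the cells above; the dictionary keys and tolerances are assumptions, since the body of the verification cell is elided in this diff.

import torch

# ultralytics_dets: (N, 6) tensor [x1, y1, x2, y2, conf, cls] from non_max_suppression.
# yolort_dets: assumed to be a torchvision-style dict with 'boxes' and 'scores' keys.
torch.testing.assert_allclose(yolort_dets['boxes'], ultralytics_dets[:, :4], rtol=1e-4, atol=1e-2)
torch.testing.assert_allclose(yolort_dets['scores'], ultralytics_dets[:, 4], rtol=1e-4, atol=1e-2)

print("Exported model has been tested, and the result looks good!")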