本项目基于tanluren/yolov3-channel-and-layer-pruning实现,将项目扩展到yolov5上。
项目的基本流程是,使用ultralytics/yolov5训练自己的数据集,在模型性能达到要求但速度未达到要求时,对模型进行剪枝。首先是稀疏化训练,稀疏化训练很重要,如果模型稀疏度不够,剪枝比例过大会导致剪枝后的模型map接近0。剪枝完成后对模型进行微调回复精度。
本项目使用的yolov5为第四版本。 yolov5第三版本参考yolov5-v3-prune yolov5第二版本参考yolov5-v2-prune
TODO: 增加m,l,x的模型剪枝,如果有时间的话。>-<
PS:在开源数据集和不能开源的数据集上模型均剪枝成功。
数据集下载dataset
附件:训练记录
附件:稀疏训练记录
附件:剪枝后模型
附件:微调训练记录
附件:微调蒸馏训练记录
yolov5
示例代码
python train.py --img 640 --batch 8 --epochs 50 --weights weights/yolov5s_v4.pt --data data/coco_hand.yaml --cfg models/yolov5s.yaml --name s_hand
--prune 0 适用于通道剪枝策略一,--prune 1 适用于其他剪枝策略。
yolov5
示例代码
python train_sparsity.py --img 640 --batch 8 --epochs 50 --data data/coco_hand.yaml --cfg models/yolov5s.yaml --weights runs/train/s_hand/weights/last.pt --name s_hand_sparsity -sr --s 0.001 --prune 1
不对shortcut直连的层进行剪枝,避免维度处理。
python prune_yolov5s.py --cfg cfg/yolov5s.cfg --data data/fangweisui.data --weights weights/yolov5s_prune0.pt --percent 0.8
对shortcut层也进行了剪枝,剪枝采用每组shortcut中第一个卷积层的mask。
python shortcut_prune_yolov5s.py --cfg cfg/yolov5s.cfg --data data/fangweisui.data --weights weights/yolov5s_prune1.pt --percent 0.3
先以全局阈值找出各卷积层的mask,然后对于每组shortcut,它将相连的各卷积层的剪枝mask取并集,用merge后的mask进行剪枝。
python slim_prune_yolov5s.py --cfg cfg/yolov5s.cfg --data data/fangweisui.data --weights weights/yolov5s_prune1.pt --global_percent 0.8 --layer_keep 0.01
在硬件部署上发现,模型剪枝率相同时,通道数为8的倍数速度最快。(采坑:需要将硬件性能开启到最大)
示例代码
python slim_prune_yolov5s_8x.py --cfg cfg/yolov5s_v4_hand.cfg --data data/oxfordhand.data --weights weights/last_v4s.pt --global_percent 0.5 --layer_keep 0.01 --img_size 640
yolov5
示例代码
python prune_finetune.py --img 640 --batch 8 --epochs 50 --data data/coco_hand.yaml --cfg ./cfg/prune_0.5_keep_0.01_8x_yolov5s_v4_hand.cfg --weights ./weights/prune_0.5_keep_0.01_8x_last_v4s.pt --name s_hand_finetune
yolov5
示例代码
python prune_finetune.py --img 640 --batch 8 --epochs 50 --data data/coco_hand.yaml --cfg ./cfg/prune_0.5_keep_0.01_8x_yolov5s_v4_hand.cfg --weights ./weights/prune_0.5_keep_0.01_8x_last_v4s.pt --name s_hand_finetune_distill --distill
yolov5
示例代码
python prune_detect.py --weights weights/last_s_hand_finetune.pt --img 640 --conf 0.7 --save-txt --source inference/images