# compress.py
# Prune a YOLOv8 detection model, then fine-tune the pruned model.
import warnings
warnings.filterwarnings('ignore')
import argparse, yaml, copy
from ultralytics.models.yolo.detect.compress import DetectionCompressor, DetectionFinetune
def compress(param_dict):
    """Prune the model described by ``param_dict`` and return the pruned model path.

    The hyper-parameter file named by ``param_dict['sl_hyp']`` is loaded and
    merged into ``param_dict`` (its keys overwrite existing ones). The run name
    then gets a ``-prune`` suffix and ``patience`` is set to 0, after which the
    settings are handed to ``DetectionCompressor`` as overrides.

    Args:
        param_dict: Run settings. Mutated in place, so callers that reuse the
            dict afterwards should pass a copy (the ``__main__`` block passes a
            ``copy.deepcopy``).

    Returns:
        Whatever ``DetectionCompressor.compress()`` returns. The caller treats
        it as the pruned model's path and feeds it to ``finetune``.

    Raises:
        TypeError: if the hyp file's top level is not a mapping.
    """
    # Decode as UTF-8 (the YAML default) rather than the platform locale, which
    # on Windows is frequently not UTF-8; undecodable bytes are still dropped.
    with open(param_dict['sl_hyp'], encoding='utf-8', errors='ignore') as f:
        sl_hyp = yaml.safe_load(f)
    # safe_load returns None for an empty document: treat that as "no overrides"
    # instead of crashing inside dict.update(None).
    if sl_hyp is None:
        sl_hyp = {}
    elif not isinstance(sl_hyp, dict):
        raise TypeError(
            f"{param_dict['sl_hyp']}: expected a mapping of hyper-parameters, "
            f"got {type(sl_hyp).__name__}"
        )
    param_dict.update(sl_hyp)
    param_dict['name'] = f'{param_dict["name"]}-prune'
    # NOTE(review): upstream Ultralytics treats patience=0 as "never stop
    # early"; presumably the same holds for DetectionCompressor — confirm.
    param_dict['patience'] = 0
    compressor = DetectionCompressor(overrides=param_dict)
    prune_model_path = compressor.compress()
    return prune_model_path
def finetune(param_dict, prune_model_path):
    """Fine-tune the pruned model found at ``prune_model_path``.

    ``param_dict`` is modified in place: ``model`` is pointed at the pruned
    weights and ``-finetune`` is appended to the run ``name``. The updated dict
    is passed as overrides to ``DetectionFinetune``, whose ``train()`` is then
    run. Nothing is returned.
    """
    # Swap in the pruned weights first, then derive the new run name.
    param_dict['model'] = prune_model_path
    param_dict.update(name=f'{param_dict["name"]}-finetune')
    DetectionFinetune(overrides=param_dict).train()
if __name__ == '__main__':
    # Default run configuration. compress() and finetune() each receive their
    # own deep copy, so the keys compress() merges in and the '-prune' suffix it
    # appends to 'name' do not leak into the fine-tuning run.
    param_dict = {
        # base model / dataset / training settings, passed on as overrides
        'model': 'runs/train/yolov8s/weights/best.pt',
        'data': 'F:/YOLOv8/dataset_car/data_car/data_set.yaml',
        'imgsz': 640,
        'epochs': 200,
        'batch': 12,
        'workers': 8,
        'cache': True,
        'optimizer': 'SGD',
        'device': '1',
        'close_mosaic': 20,
        'project': 'runs/prune',
        'name': 'exp',
        # pruning settings; interpreted by DetectionCompressor (not shown here)
        'prune_method': 'group_taylor',
        'global_pruning': False,
        'speed_up': 2.0,
        'reg': 0.0005,
        'sl_epochs': 500,
        'sl_hyp': 'ultralytics/cfg/hyp.scratch.sl.yaml',  # loaded and merged by compress()
        'sl_model': None,
    }

    # Let the settings most likely to differ between machines be overridden on
    # the command line. Every default is the value above, so running the
    # script with no arguments behaves exactly as before.
    parser = argparse.ArgumentParser(
        description='Prune a YOLOv8 detection model, then fine-tune the pruned model.'
    )
    parser.add_argument('--model', default=param_dict['model'],
                        help='weights to prune (default: %(default)s)')
    parser.add_argument('--data', default=param_dict['data'],
                        help='dataset YAML (default: %(default)s)')
    parser.add_argument('--device', default=param_dict['device'],
                        help='device passed to the trainer (default: %(default)s)')
    parser.add_argument('--name', default=param_dict['name'],
                        help='base run name (default: %(default)s)')
    parser.add_argument('--speed-up', dest='speed_up', type=float,
                        default=param_dict['speed_up'],
                        help='speed_up pruning setting (default: %(default)s)')
    # All parsed keys already exist in param_dict, so key order is unchanged.
    param_dict.update(vars(parser.parse_args()))

    prune_model_path = compress(copy.deepcopy(param_dict))
    finetune(copy.deepcopy(param_dict), prune_model_path)