-
Notifications
You must be signed in to change notification settings - Fork 573
/
tagger.sh
70 lines (55 loc) · 2.17 KB
/
tagger.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#!/bin/bash
# tagger script by @bdsqlsz
# Train data path
train_data_dir="./input" # input images path | 图片输入路径
repo_id="SmilingWolf/wd-v1-4-swinv2-tagger-v2" # model repo id from huggingface |huggingface模型repoID
model_dir="" # model dir path | 本地模型文件夹路径
batch_size=12 # batch size in inference 批处理大小,越大越快
max_data_loader_n_workers=0 # enable image reading by DataLoader with this number of workers (faster) | 0最快
thresh=0.35 # concept thresh | 最小识别阈值
general_threshold=0.35 # general threshold | 总体识别阈值
character_threshold=0.1 # character threshold | 人物姓名识别阈值
remove_underscore=0 # remove_underscore | 下划线转空格,1为开,0为关
undesired_tags="" # no need tags | 排除标签
recursive=0 # search for images in subfolders recursively | 递归搜索下层文件夹,1为开,0为关
frequency_tags=0 # order by frequency tags | 从大到小按识别率排序标签,1为开,0为关
# ============= DO NOT MODIFY CONTENTS BELOW | 请勿修改下方内容 =====================
export HF_HOME="huggingface"
export TF_CPP_MIN_LOG_LEVEL=3
extArgs=()
if [ -n "$repo_id" ]; then
extArgs+=( "--repo_id=$repo_id" )
fi
if [ -n "$model_dir" ]; then
extArgs+=( "--model_dir=$model_dir" )
fi
if [[ $batch_size -ne 0 ]]; then
extArgs+=( "--batch_size=$batch_size" )
fi
if [ -n "$max_data_loader_n_workers" ]; then
extArgs+=( "--max_data_loader_n_workers=$max_data_loader_n_workers" )
fi
if [ -n "$general_threshold" ]; then
extArgs+=( "--general_threshold=$general_threshold" )
fi
if [ -n "$character_threshold" ]; then
extArgs+=( "--character_threshold=$character_threshold" )
fi
if [ "$remove_underscore" -eq 1 ]; then
extArgs+=( "--remove_underscore" )
fi
if [ -n "$undesired_tags" ]; then
extArgs+=( "--undesired_tags=$undesired_tags" )
fi
if [ "$recursive" -eq 1 ]; then
extArgs+=( "--recursive" )
fi
if [ "$frequency_tags" -eq 1 ]; then
extArgs+=( "--frequency_tags" )
fi
# run tagger
accelerate launch --num_cpu_threads_per_process=8 "./scripts/stable/finetune/tag_images_by_wd14_tagger.py" \
$train_data_dir \
--thresh=$thresh \
--caption_extension .txt \
${extArgs[@]}