仓库路径应该组织成如下结构:
sense_classification/
|->examples
|->models
|->prepare_data
|->data
| |->rssrai_sense_cls
| | |->train
| | |->val
| | |->test
| |->tf_records
| |->train_list
|->ckpt
|->tools
- tensorflow-gpu==1.12.0 (I only test on tensorflow 1.12.0)
- python==3.4.3
- numpy
- easydict
- opencv==3.4.1
- 有些包可能没列出来,根据错误提示安装
- 下载代码
git clone https://github.com/vicwer/sense_classification.git
data目录结构:
data/
|->rssrai_sense_cls
| |->train
| |->val
| |->test
| |->ClsName2id.txt
|->train_list/train.txt
|->tf_records
-
下载数据集并解压: train.zip, val.zip, test.zip, ClsName2id.txt
-
生成 tf_records:
cd tools
python3 img_encode.py
${sense_classification_ROOT}目录提供了config.py, 可设置超参数
例如
cd ${sense_classification_ROOT}
vim config.py
cfg.train.num_gpus = {your gpu nums}
etc.
cd ${sense_classification_ROOT}/examples/
python3 multi_gpus_train.py
cd ${sense_classification_ROOT}/examples/
python3 accuracy.py
cd ${sense_classification_ROOT}/examples/
python3 submit.py
验证集: 0.908+ 测试集:0.90509