Skip to content

GuangchenJ/digit_recognizer

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

14 Commits
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

digit_recognizer

Practice deep learning with MNIST dataset

以 Kaggle上面 MNIST 数据集的比赛Digit Recognizer作为练习写的两个小程序,分别使用Python PyTorchC++ TensorRT实现。作为熟悉这两个框架的小练习,难度较低,非常值得一看。

其中Python的Pyorch实现的使用方法就直接运行main.py文件就好,需要注意需要自己更改项目中的参数,例如train还是test模式,epoch还是其他的一些参数。

TensorRT的程序需要您首先训练出来一个模型参数,然后使用gen_wts.py文件生成对应的网络参数cnn.wts(注意可能需要更改程序中的路径问题),若不更改,则默认将cnn.wts文件存放于./cpp_tensorRT/文件夹下面。

cp {file_dir}/cnn.wts {digit_recognizer}/cpp_tensorRT/cnn.wts

然后您需要cd到TensorRT程序的文件夹下,然后构建模型,之后运行:

cd cpp_tensorRT/
mkdir build
cd build/
cmake ..
make
./main -s # 构建模型
./main -d # 进行推理

最后会生成一个submission.csv文件,可以将其上Kaggle上面提交哦。

Inspired by wang-xinyu

About

Practice deep learning with MNIST dataset

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published