本项目参考:
中文闲聊对话:
https://drive.google.com/file/d/1nEuew_KNpTMbyy7BO4c8bXMXN351RCPp/view
情感分类:
下载数据集后存放在 data 文件夹下,路径分别为data/train.txt
和data/ChnSentiCorp_htl_all.csv
模型参数量90M
python train.py config/train90M.yaml
模型参数量300M
python train.py config/train300M.yaml
默认配置训练需要约16GB显存,你可以根据实际的硬件条件修改batch size
你可以在hugging face🤗上下载预训练模型
mkdir -p checkpoints/pretrained
cd checkpoints/pretrained
wget https://huggingface.co/cjl196/small-gpt/resolve/main/cpt90M.pth?download=true -O cpt90M.pth
wget https://huggingface.co/cjl196/small-gpt/resolve/main/cpt300M.pth?download=true -O cpt300M.pth
如果你希望使用自己训练的模型,在对话前,请修改配置文件中resume_from
的值为模型的路径
使用下面的指令,和预训练的300M模型对话
python chat.py config/chat300M.yaml
基于预训练的300M模型,训练情感分类器
情感分类提供多个配置文件config/sentimental*.yaml
,主要区别是是否mask、是否冻结参数,可用于消融实验
python sentimentalTrain.py config/sentimental.yaml
消融实验效果:
准确度 | 训练时间 | |
---|---|---|
无mask&无冻结参数 | 91.3% | 1hr |
有mask&无冻结参数 | 91.2% | 1hr |
有mask&有冻结参数 | 87.8% | 26.63min |