pytorch版NEZHA,适配transformers
论文下载地址: NEZHA: Neural Contextualized Representation for Chinese Language Understanding
如果需要运行该案例脚本,需要安装以下模块:
官方提供的Tensorflow版本权重下载地址:huawei-noah
已经转化为PyTorch版本权重下载地址如下:
-
nezha-cn-base 百度网盘链接 提取码: hckq
-
nezha-large-zh 百度网盘链接 提取码: qks2
-
nezha-base-wwm 百度网盘链接 提取码: ysg3
-
nezha-large-wwm 百度网盘链接 提取码: 8dig
说明:若加载的模型权重是从下列百度网盘下载的PyTorch模型权重,则需要保证torch版本>=1.6.0
执行命令:
sh scripts/run_task_text_classification_chnsenti.sh
长文本可以通过设置config.max_position_embeddings
参数实现,默认值为512,如:
config.max_position_embeddings=args.train_max_seq_length
NEZHA(base-wwm) | chnsenti |
---|---|
tensorflow | 94.75 |
pytorch | 94.92 |