A BiGRU-Attention DSSM implementation with tensorflow estimator.
对应博客:https://blog.csdn.net/cdj0311/article/details/107634795
之前使用Keras和paddlepaddle实现过DSSM文本表示模型,(https://github.com/cdj0311/keras_bert_classification/blob/master/bert_dssm.py, https://github.com/cdj0311/paddledssm) 由于Keras做分布式计算比较麻烦,而paddlepaddle早已弃用,现在用tensorflow的高级API tf.estimator重写一遍,其中表示层使用双向GRU+Attention,最终输出为64维的向量。
python == 3.6
tensorflow == 1.13.1
训练步骤如下:
-
将文本数据转换为tfrecord格式:
python convert_data.py
data目录的data.txt中包含了10000条训练数据,数据为某新闻网站上的标题和对应的内容,格式为:title\tcontent,train.tfrecord是转换完成的tfrecord数据。
-
模型训练:
sh train_local.sh
模型训练完后会分别导出query和doc的pb格式模型,可根据需要进行选择。
-
模型预测:
python predict.py
给定一个句子得到向量,并获取最相似的N个句子,例如:
输入: 赵丽颖冯绍峰在拍女儿国的时候真的超级甜了
输出:
0.801103 女神赵丽颖李沁都爱穿黄毛衣,但差距真的蛮大的 0.744942 街拍:喜欢第二位俏皮可爱的小姐姐,和她在一起不会觉得无聊! 0.722599 杜江霍思燕夫妇甜蜜现身 牵手依偎恩爱甜到发腻 0.719018 还在情侣穿搭烦恼,看街拍情侣都是怎么搭配的 0.707306 赵丽颖,应是绿肥红瘦,剧照 0.701783 她的闺蜜则穿了一件白色的蕾丝连衣裙,尽显女人味 0.70024 国民妖精十元女神可爱撩人瞬间合集!出色的不只是时尚穿搭 0.691073 图集:#杨幂#赵丽颖暗斗时尚穿同款婚纱谁更美 0.687201 赵丽颖 路人抓拍下的颖宝,这颜值可以说是完美的纯天然美女了~
输入: 祝考研的女士们先生们都顺利考进自己理想的学校
输出:
0.890815 祝考研的女士们先生们都顺利考进自己理想的学校!实在考不上就滚tm的,当代... 0.758741 硕士研究生招生考试22日开考 0.701588 加油高考!祝你们顺利考上心仪的大学! 0.660756 中考,你准备好了吗? 0.654576 这些考研复试面试小技巧收好,导师的心就抓住了! 0.63505 高考生作弊被抓飞踹监考老师:你知道我爸是谁? 0.626651 高考倒计时30天,祝所有今年参加高考的小伙伴们心想事成,高考必胜 0.590912 各位同学请注意,第一季期末考试现在开始~请认真阅读仔细答题 0.585147 航班延误艺考生妈妈痛哭 浙传:可提供证明安排考试 0.575564 当女儿带男同学回家写作业的时候,爸爸都在想什么
-
分布式训练
设置run_on_cluster=True, 提交到job中即可训练,由于每个公司的分布式训练提交命令不一样,这里就不贴出来了。
该项目是基于字符做Embedding,实际使用中我们一般会将字和词同时作为输入进行训练。