This repository has been archived by the owner on Feb 22, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 231
/
qaData.py
131 lines (115 loc) · 4.24 KB
/
qaData.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import re
from collections import defaultdict
import jieba
import numpy as np
def loadEmbedding(filename):
    """
    Load a word-embedding text file (one "word v1 v2 ... vn" entry per line).

    :param filename: path to the embedding file (UTF-8, space-separated)
    :return: (embeddings, word2idx) where embeddings[word2idx[w]] is the
             vector for word w
    """
    embeddings = []
    # NOTE(review): defaultdict(list) kept for interface compatibility,
    # though it is only ever used as a plain word -> int mapping.
    word2idx = defaultdict(list)
    with open(filename, mode="r", encoding="utf-8") as rf:
        for line in rf:
            arr = line.rstrip().split(" ")
            if len(arr) < 2:
                continue  # skip blank or malformed lines
            # BUG FIX: the previous arr[1:-1] silently dropped the last
            # dimension of every vector (the final token only carried the
            # trailing newline); rstrip() + arr[1:] keeps all values.
            embedding = [float(val) for val in arr[1:]]
            word2idx[arr[0]] = len(word2idx)
            embeddings.append(embedding)
    return embeddings, word2idx
def sentenceToIndex(sentence, word2idx, maxLen):
    """
    Tokenize a sentence with jieba and map each token to its embedding index.

    :param sentence: raw sentence string
    :param word2idx: word -> embedding-index mapping
    :param maxLen: fixed output length; extra tokens are truncated and
                   unfilled positions stay at the UNKNOWN index
    :return: list of exactly maxLen embedding indices
    """
    unknown = word2idx.get("UNKNOWN", 0)
    # NOTE(review): falls back to len(word2idx) when "NUM" is missing, which
    # would be out of range for the embedding matrix — confirm "NUM" exists.
    num = word2idx.get("NUM", len(word2idx))
    index = [unknown] * maxLen
    for i, word in enumerate(jieba.cut(sentence)):
        if i >= maxLen:
            break  # truncate to maxLen tokens
        if word in word2idx:
            index[i] = word2idx[word]
        elif re.match(r"\d+", word):
            # BUG FIX: raw string — "\d" in a plain literal is an invalid
            # escape sequence (SyntaxWarning on modern Python).
            index[i] = num
        # else: position keeps the pre-filled `unknown` value
    return index
def loadData(filename, word2idx, maxLen, training=False):
    """
    Load a training or testing file of tab-separated "question\tanswer[\tlabel]"
    lines. Consecutive lines sharing a question get the same 1-based question id.

    :param filename: path to the data file (UTF-8)
    :param word2idx: word -> embedding-index mapping
    :param maxLen: fixed sentence length passed to sentenceToIndex
    :param training: when True, a third column is parsed as an int label
    :return: (questions, answers, labels, questionIds); labels is empty
             unless training is True
    """
    question = ""
    questionId = 0
    questions, answers, labels, questionIds = [], [], [], []
    with open(filename, mode="r", encoding="utf-8") as rf:
        # iterate the file lazily instead of materializing readlines()
        for line in rf:
            arr = line.split("\t")
            # ROBUSTNESS: skip blank/short lines (e.g. trailing newline at
            # EOF) that would otherwise raise IndexError on arr[1]/arr[2]
            if len(arr) < (3 if training else 2):
                continue
            if question != arr[0]:
                question = arr[0]
                questionId += 1  # ids start at 1 for the first question
            questionIdx = sentenceToIndex(arr[0].strip(), word2idx, maxLen)
            answerIdx = sentenceToIndex(arr[1].strip(), word2idx, maxLen)
            if training:
                label = int(arr[2])
                labels.append(label)
            questions.append(questionIdx)
            answers.append(answerIdx)
            questionIds.append(questionId)
    return questions, answers, labels, questionIds
def trainingBatchIter(questions, answers, labels, questionIds, batchSize):
    """
    Yield training batches as (question, positive answer, negative answer)
    triples: every wrong answer of a question is paired with that question's
    right answer, giving three equal-length arrays per batch.

    :param questions: question index lists, one per data line
    :param answers: answer index lists, one per data line
    :param labels: 0/1 labels, one per data line (0 = wrong answer)
    :param questionIds: 1-based, non-decreasing question id per data line
    :param batchSize: number of questions per batch
    :yield: (questions, trueAnswers, falseAnswers) numpy arrays
    """
    trueAnswer = ""
    totalLines = len(questionIds)
    dataLen = questionIds[-1]  # ids run 1..dataLen (assigned by loadData)
    batchNum = int(dataLen / batchSize) + 1
    line = 0
    for batch in range(batchNum):
        # one batch of questions
        resultQuestions, trueAnswers, falseAnswers = [], [], []
        # BUG FIX: question ids are 1-based, but the previous 0-based range
        # wasted its first iteration on id 0 and silently dropped the LAST
        # question of the dataset.
        start = batch * batchSize + 1
        end = min((batch + 1) * batchSize, dataLen) + 1
        for questionId in range(start, end):
            # scan this question's contiguous lines
            falseCount = 0
            # guard on totalLines: now that the last question is consumed,
            # the scan can legitimately reach the end of the data
            while line < totalLines and questionIds[line] == questionId:
                if labels[line] == 0:
                    resultQuestions.append(questions[line])
                    falseAnswers.append(answers[line])
                    falseCount += 1
                else:
                    trueAnswer = answers[line]
                line += 1
            # one copy of the positive answer per negative answer
            # (replaces the O(n) questionIds.count() call per question)
            trueAnswers.extend([trueAnswer] * falseCount)
        yield np.array(resultQuestions), np.array(trueAnswers), np.array(falseAnswers)
def testingBatchIter(questions, answers, batchSize):
    """
    Yield successive (questions, answers) chunks for evaluation.

    :param questions: question index lists
    :param answers: answer index lists
    :param batchSize: base batch size; each yielded chunk holds up to
                      batchSize * 20 rows
    """
    total = len(questions)
    chunkSize = batchSize * 20
    chunkCount = int(total / chunkSize) + 1
    questionArr = np.array(questions)
    answerArr = np.array(answers)
    for chunkIdx in range(chunkCount):
        lo = chunkIdx * chunkSize
        hi = min(lo + chunkSize, total)
        yield questionArr[lo:hi], answerArr[lo:hi]