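"""Re-tokenize the FewRel validation split with StanfordCoreNLP, locate the
head/tail entity spans in the new tokenization, and attach POS tags plus a
dependency parse (per-token head indices and relation labels) to each example."""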
import json

from stanfordcorenlp import StanfordCoreNLP
from tqdm import tqdm


def restore_dep(dep):
    '''
    Re-insert the ROOT arcs returned by StanfordCoreNLP at their proper
    positions, so the dependency list ends up ordered by tail (token) index.
    Args:
        dep: list of (deprel, head, tail) triples from dependency_parse()
    Returns:
        the dependency list with every ROOT arc restored to its position
    '''
    def judge(dep_list, j):
        '''Scan dep_list backwards from j: False if a ROOT is met first,
        True if a sentence start (tail == 1) is met first or nothing is met.'''
        for index in reversed(dep_list[:(j + 1)]):
            if index[0] == 'ROOT':
                return False
            if index[-1] == 1:
                return True
        return True

    # Separate the ROOT arcs from the ordinary arcs, then put each ROOT
    # back where it belongs.
    root_list, dep_list = [], []
    for dp in dep:
        if dp[0] == "ROOT":
            root_list.append(dp)
        else:
            dep_list.append(dp)
    for root in root_list:
        root_tail = root[-1]
        for j in range(len(dep_list)):
            dp_tail = dep_list[j][-1]
            if j == 0:  # front of the list
                if dp_tail == 2 and root_tail == 1:
                    dep_list.insert(0, root)
                    break
                if dp_tail == 1 and root_tail == 2 and dep_list[j + 1][-1] == 3:
                    dep_list.insert(1, root)
                    break
            elif j == len(dep_list) - 1:  # end of the list
                if dp_tail + 1 == root_tail or root_tail == 1:
                    dep_list.insert(j + 1, root)
                    break
            else:  # middle of the list
                if dp_tail + 1 == root_tail and dep_list[j + 1][-1] - 1 == root_tail and judge(dep_list, j):
                    dep_list.insert(j + 1, root)
                    break
                if dp_tail + 1 == root_tail and dep_list[j + 1][-1] == 1 and judge(dep_list, j):
                    dep_list.insert(j + 1, root)
                    break
                if root_tail == 1 and dep_list[j + 1][-1] == 2 and dep_list[j][-1] != 1:
                    dep_list.insert(j + 1, root)
                    break
        else:
            # the inner loop never broke: no suitable position was found
            raise AssertionError("no position found for ROOT arc %r" % (root,))
    return dep_list
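
# A minimal sketch of the intended behaviour (the exact arc order coming out
# of the parser is an assumption here): for a two-sentence input,
#   restore_dep([('ROOT', 0, 1), ('ROOT', 0, 1), ('punct', 1, 2), ('punct', 1, 2)])
# yields
#   [('ROOT', 0, 1), ('punct', 1, 2), ('ROOT', 0, 1), ('punct', 1, 2)]
# i.e. each ROOT sits at the head of its own sentence again.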

path = './fewrel_dataset/val.json'
save_path = './fewrel_dataset/val_fewrel_stf.json'
with open(path, "r") as f:
    ori_data = json.load(f)
# NB: raw string, otherwise the backslashes in the Windows path are mangled
nlp = StanfordCoreNLP(r'D:\PycharmProjects\deft_corpus\stanford-corenlp-full-2018-10-05\stanford-corenlp-full-2018-10-05', lang='en')
new_data = {}
error = 0
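# ori_data maps each relation id (e.g. 'P156') to a list of examples; in the
# FewRel format an entity is stored as [mention, wikidata id, [[token indices]]].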
for r in ori_data:
    new_data[r] = []
    with tqdm(desc=r, ncols=100) as tq:
        for i in range(len(ori_data[r])):
            if r == 'P156' and i == 348:  # skip a known bad example
                continue
            cur_ex = ori_data[r][i]
            text = " ".join(cur_ex["tokens"])
            text = text.lower()
            tokens = nlp.word_tokenize(text)
            text = " ".join(tokens)
            if " . . " in text:
                text = text.replace(" . . ", " . ")
                tokens = nlp.word_tokenize(text)
            dep = nlp.dependency_parse(text)
            if len(dep) != len(tokens):
                print("tokenization and dependency parse do not match")
                continue
            # slice the head/tail entity mentions out of the original tokens
            sub = " ".join(cur_ex["tokens"][cur_ex["h"][2][0][0]:cur_ex["h"][2][0][-1] + 1]).lower()
            obj = " ".join(cur_ex["tokens"][cur_ex["t"][2][0][0]:cur_ex["t"][2][0][-1] + 1]).lower()
            sub_tokens = nlp.word_tokenize(sub)
            obj_tokens = nlp.word_tokenize(obj)

            def get_loc(token, sub_tokens):
                '''Find sub_tokens as a contiguous span of token; return a
                FewRel-style [mention, "", [[indices]]] entry, or -1 if absent.'''
                for i in range(len(token)):
                    if sub_tokens == token[i:len(sub_tokens) + i]:
                        return [" ".join(sub_tokens), "", [list(range(i, len(sub_tokens) + i))]]
                return -1

            h = get_loc(tokens, sub_tokens)
            t = get_loc(tokens, obj_tokens)
            if h == -1:
                tq.update(1)
                error += 1
                print(sub_tokens)
                continue
            if t == -1:
                tq.update(1)
                error += 1
                print(obj_tokens)
                continue
            dep = restore_dep(dep)
            stf_pos = nlp.pos_tag(text)
            stanford_deprel, stanford_head = [-1] * len(tokens), [-1] * len(tokens)
            stanford_pos = [p[1] for p in stf_pos]
            flag = 0
            for j in range(len(dep)):
                deprel, head, tail = dep[j]
                # tail indices restart at 1 for every sentence; `flag` holds
                # the offset of the current sentence inside `dep`
                if stanford_head[tail - 1 + flag] != -1:
                    flag = j
                assert stanford_head[tail - 1 + flag] == -1
                stanford_head[tail - 1 + flag] = head
                stanford_deprel[tail - 1 + flag] = deprel
            assert -1 not in stanford_deprel
            assert -1 not in stanford_head
            new_data[r].append({
                "tokens": tokens,
                "h": h,
                "t": t,
                "stanford_head": stanford_head,
                "stanford_deprel": stanford_deprel,
                "stanford_pos": stanford_pos
            })
            tq.update(1)

with open(save_path, "w") as f:
    json.dump(new_data, f)
print(error)