forked from supercoderhawk/DNN_CWS
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathseg_base.py
99 lines (84 loc) · 3.14 KB
/
seg_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# -*- coding: UTF-8 -*-
import numpy as np
import math
class SegBase:
    """Shared helpers for a character-tagging word segmenter.

    Offers Viterbi decoding over a tag-score matrix, conversion of characters
    to dictionary indices and context windows, reconstruction of words from a
    tag sequence, and a path-score based sentence loss.

    Tag ids as interpreted by ``tags2words``: 0 = single-character word,
    1 = first character of a word, 2 = inside a word, 3 = last character.
    """

    def __init__(self):
        self.TAGS = np.arange(4)
        # One row per tag; presumably the tags allowed to follow that tag
        # (not used inside this class) -- TODO confirm against callers.
        self.TAG_MAPS = np.array([[0, 1], [2, 3], [2, 3], [0, 1]], dtype=np.int32)
        self.tags_count = len(self.TAG_MAPS)
        # Maps a character to its integer index; unknown characters map to 0.
        self.dictionary = {}
        # Number of context positions taken to the left / right of a character.
        self.skip_window_left = 0
        self.skip_window_right = 1

    def viterbi(self, emission, A, init_A, return_score=False):
        """Find the highest-scoring tag path with the Viterbi algorithm.

        All inputs and outputs are numpy arrays.

        :param emission: emission (score) matrix, tags_count * length
        :param A: transition score matrix, tags_count * tags_count,
            indexed as ``A[prev_tag][tag]``
        :param init_A: initial transition scores, one per tag
        :param return_score: also return the score of the best path
        :return: the best path (int32 array of tag ids); when
            ``return_score`` is true, a ``(path, score)`` tuple.  An emission
            matrix with zero columns yields an empty path and a score of 0.0.
        """
        length = emission.shape[1]
        corr_path = np.zeros([length], dtype=np.int32)
        if length == 0:
            # Nothing to decode; the empty path has an empty-sum score.
            return (corr_path, 0.0) if return_score else corr_path

        # path[t][pos] is the best predecessor tag of tag t at position pos.
        path = np.full([self.tags_count, length], -1, dtype=np.int32)
        # Half of the most negative float32 acts as "minus infinity" while
        # leaving headroom so that adding scores to it cannot overflow.
        path_score = np.full(
            [self.tags_count, length], np.finfo('f').min / 2, dtype=np.float64
        )
        path_score[:, 0] = init_A + emission[:, 0]

        for pos in range(1, length):
            for t in range(self.tags_count):
                for prev in range(self.tags_count):
                    temp = path_score[prev][pos - 1] + A[prev][t] + emission[t][pos]
                    # ">=" keeps the original tie-breaking: the highest
                    # predecessor index wins among equal scores.
                    if temp >= path_score[t][pos]:
                        path[t][pos] = prev
                        path_score[t][pos] = temp

        last_tag = np.argmax(path_score[:, -1])
        # Capture the score before backtracking: the walk below reuses the
        # tag variable, so reading the score afterwards would index the
        # final column by the path's *first* tag instead of its last.
        best_score = path_score[last_tag, -1]

        corr_path[length - 1] = last_tag
        tag = last_tag
        for i in range(length - 1, 0, -1):
            tag = path[tag][i]
            corr_path[i - 1] = tag

        if return_score:
            return corr_path, best_score
        return corr_path

    def sentence2index(self, sentence):
        """Map every character of *sentence* to its dictionary index.

        Characters missing from ``self.dictionary`` are mapped to 0.
        """
        return [self.dictionary.get(char, 0) for char in sentence]

    def index2seq(self, indices):
        """Build one context window per index.

        The sequence is padded with ``skip_window_left`` copies of 1 on the
        left and ``skip_window_right`` copies of 2 on the right; each window
        covers ``skip_window_left`` items before and ``skip_window_right``
        items after the current position.

        :param indices: sequence of indices (list, tuple or 1-D array)
        :return: list of windows, one list per input index
        """
        left = self.skip_window_left
        right = self.skip_window_right
        # Copy into a list first: "+" on a numpy array would add 2 to every
        # element instead of appending the padding.
        padded = [1] * left + list(indices) + [2] * right
        return [
            padded[centre - left: centre + right + 1]
            for centre in range(left, len(padded) - right)
        ]

    def tags2words(self, sentence, tags):
        """Split *sentence* into words according to per-character *tags*.

        Tags: 0 = single-character word, 1 = word begin, 2 = word inside,
        3 = word end.  Tag sequences that are not well formed (for example a
        begin tag followed directly by a single tag) never drop or reorder
        characters: a pending partial word is emitted before a new word
        starts.

        :param sentence: the characters to segment
        :param tags: one tag id per character of *sentence*
        :return: list of words in sentence order
        """
        words = []
        word = ''
        for tag_index, tag in enumerate(tags):
            char = sentence[tag_index]
            if tag == 0 or tag == 1:
                # A new word starts here; flush any unfinished one first so
                # its characters are neither lost nor emitted out of order.
                if word != '':
                    words.append(word)
                    word = ''
                if tag == 0:
                    words.append(char)
                else:
                    word = char
            elif tag == 2:
                word += char
            else:
                words.append(word + char)
                word = ''
        # Handle the case where the last tag is "inside" (or "begin"): the
        # unfinished word is still emitted.
        if word != '':
            words.append(word)
        return words

    def cal_sentence_loss(self, tags, sentence_scores, A, init_A):
        """Return |score of the given tag path - score of the Viterbi path|.

        :param tags: the reference tag id for every position
        :param sentence_scores: score matrix, tags_count * length
        :param A: transition score matrix, indexed as ``A[prev_tag, tag]``
        :param init_A: initial transition scores, one per tag
        :return: absolute difference between the two path scores
        """
        _, best_score = self.viterbi(sentence_scores, A, init_A, True)
        path_total = 0.0
        previous = 0
        for position, (corr_tag, scores) in enumerate(zip(tags, sentence_scores.T)):
            if position == 0:
                path_total += scores[corr_tag] + init_A[corr_tag]
            else:
                path_total += scores[corr_tag] + A[previous, corr_tag]
            previous = corr_tag
        return math.fabs(path_total - best_score)