Skip to content

Commit

Permalink
update pos bugs
Browse files (browse the repository at this point in the history)
  • Loading branch information
dongrixinyu committed May 3, 2022
1 parent f3ed160 commit d4ca6d1
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 28 deletions.
2 changes: 1 addition & 1 deletion jiojio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ def cut(text):
""" 对文本进行分词和词性标注 """
if jiojio_pos_flag:
words, norm_words, word_pos_map = jiojio_cws_obj.cut_with_pos(text)

tags = jiojio_pos_obj.cut(norm_words, word_pos_map=word_pos_map)

return [words, tags]

else:
Expand Down
12 changes: 5 additions & 7 deletions jiojio/cws/predict_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ def __init__(self, model_dir=None, user_dict=None, with_viterbi=False,
def _cut(self, text):
length = len(text)
all_features = list()

# 每个节点的得分
# Y = np.empty((length, 2), dtype=np.float16)
for idx in range(length):

if self.get_node_features_c is None:
Expand All @@ -95,7 +98,6 @@ def _cut(self, text):
node_features = self.get_node_features_c(
idx, text, length, self.feature_extractor.unigram,
self.feature_extractor.bigram)
# pdb.set_trace()

# 测试:
# if node_features != self.feature_extractor.get_node_features(idx, text):
Expand All @@ -114,14 +116,11 @@ def _cut(self, text):
node_feature_idx.append(0)

else:
# print(len(node_features), node_features)
node_feature_idx = self.cws_feature2idx_c(
node_features, self.feature_extractor.feature_to_idx)
# print(len(node_feature_idx), node_feature_idx)
# print(len(all_features), idx)
# pdb.set_trace()

# all_features.append(node_feature_idx)
all_features.append(node_feature_idx)
# Y[idx] = np.sum(node_weight[node_feature_idx], axis=0)

Y = get_log_Y_YY(all_features, self.model.node_weight, dtype=np.float16)

Expand All @@ -136,7 +135,6 @@ def _cut(self, text):
else:
tags_idx = Y.argmax(axis=1).astype(np.int8)

# print(text)
# print(tags_idx)
# pdb.set_trace()
return tags_idx
Expand Down
6 changes: 4 additions & 2 deletions jiojio/pos/feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,7 @@ def save(self, model_dir=None):
data['char'] = sorted(list(self.char))
data['part'] = sorted(list(self.part))
data['single_pos_word'] = sorted(list(self.single_pos_word))
data['feature_to_idx'] = self.feature_to_idx
data['feature_to_idx'] = list(self.feature_to_idx.keys())
data['tag_to_idx'] = self.tag_to_idx

feature_path = os.path.join(model_dir, 'features.json')
Expand Down Expand Up @@ -790,7 +790,9 @@ def load(cls, config, model_dir=None):
extractor.char = set(data['char'])
extractor.part = set(data['part'])
extractor.single_pos_word = set(data['single_pos_word'])
extractor.feature_to_idx = data['feature_to_idx']
feature_list = data['feature_to_idx']
extractor.feature_to_idx = dict(
[(feature, idx) for idx, feature in enumerate(feature_list)])
extractor.tag_to_idx = data['tag_to_idx']

# extractor.idx_to_tag = extractor._reverse_dict(extractor.tag_to_idx)
Expand Down
27 changes: 9 additions & 18 deletions jiojio/pos/predict_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,6 @@ def __init__(self, model_dir=None, user_dict=None, with_viterbi=True,
normalize_num_letter=pos_config.normalize_num_letter,
convert_exception=pos_config.convert_exception)

'''
tmp = dict([(val, k) for k, val in self.feature_extractor.feature_to_idx.items()])
for i in range(len(tmp)):
if i not in tmp:
print(i)
pdb.set_trace()
print(len(tmp))
print(self.model.node_weight[0])
pdb.set_trace()
'''
# C 方式调用
self.get_node_features_c = pos_get_node_features_c

Expand All @@ -98,14 +88,14 @@ def _cut(self, words):
all_node_features = list()
for idx in range(length):

if self.get_node_features_c is None:
# 以 python 方式计算,效率较低
node_features = self.feature_extractor.get_node_features(idx, words)
else:
# 以 C 方式计算,效率高
node_features = self.get_node_features_c(
idx, words, len(words), self.feature_extractor.unigram,
self.feature_extractor.bigram)
# if self.get_node_features_c is None:
# # 以 python 方式计算,效率较低
node_features = self.feature_extractor.get_node_features(idx, words)
# else:
# # 以 C 方式计算,效率高
# node_features = self.get_node_features_c(
# idx, words, len(words), self.feature_extractor.unigram,
# self.feature_extractor.bigram)

# if node_features != self.feature_extractor.get_node_features(idx, words):
# print(node_features)
Expand All @@ -123,6 +113,7 @@ def _cut(self, words):
all_features.append(node_feature_idx)

Y = get_log_Y_YY(all_features, self.model.node_weight, dtype=self.dtype)
# pdb.set_trace()
'''
for idx in range(length):
print(words[idx])
Expand Down

0 comments on commit d4ca6d1

Please sign in to comment.