From d4ca6d1f23f609bae7fbb888f45f06f01801adc1 Mon Sep 17 00:00:00 2001 From: dongrixinyu Date: Tue, 3 May 2022 17:18:34 +0800 Subject: [PATCH] update pos bugs --- jiojio/__init__.py | 2 +- jiojio/cws/predict_text.py | 12 +++++------- jiojio/pos/feature_extractor.py | 6 ++++-- jiojio/pos/predict_text.py | 27 +++++++++------------------ 4 files changed, 19 insertions(+), 28 deletions(-) diff --git a/jiojio/__init__.py b/jiojio/__init__.py index 28c3b30..cf8a8d5 100644 --- a/jiojio/__init__.py +++ b/jiojio/__init__.py @@ -122,8 +122,8 @@ def cut(text): """ 对文本进行分词和词性标注 """ if jiojio_pos_flag: words, norm_words, word_pos_map = jiojio_cws_obj.cut_with_pos(text) - tags = jiojio_pos_obj.cut(norm_words, word_pos_map=word_pos_map) + return [words, tags] else: diff --git a/jiojio/cws/predict_text.py b/jiojio/cws/predict_text.py index 82f1409..5ad2567 100644 --- a/jiojio/cws/predict_text.py +++ b/jiojio/cws/predict_text.py @@ -85,6 +85,9 @@ def __init__(self, model_dir=None, user_dict=None, with_viterbi=False, def _cut(self, text): length = len(text) all_features = list() + + # 每个节点的得分 + # Y = np.empty((length, 2), dtype=np.float16) for idx in range(length): if self.get_node_features_c is None: @@ -95,7 +98,6 @@ def _cut(self, text): node_features = self.get_node_features_c( idx, text, length, self.feature_extractor.unigram, self.feature_extractor.bigram) - # pdb.set_trace() # 测试: # if node_features != self.feature_extractor.get_node_features(idx, text): @@ -114,14 +116,11 @@ def _cut(self, text): node_feature_idx.append(0) else: - # print(len(node_features), node_features) node_feature_idx = self.cws_feature2idx_c( node_features, self.feature_extractor.feature_to_idx) - # print(len(node_feature_idx), node_feature_idx) - # print(len(all_features), idx) - # pdb.set_trace() - # all_features.append(node_feature_idx) + all_features.append(node_feature_idx) + # Y[idx] = np.sum(node_weight[node_feature_idx], axis=0) Y = get_log_Y_YY(all_features, self.model.node_weight, dtype=np.float16) @@ -136,7 +135,6 @@ def _cut(self, text): else: tags_idx = Y.argmax(axis=1).astype(np.int8) - # print(text) # print(tags_idx) # pdb.set_trace() return tags_idx diff --git a/jiojio/pos/feature_extractor.py b/jiojio/pos/feature_extractor.py index 535b17f..ac5f89a 100644 --- a/jiojio/pos/feature_extractor.py +++ b/jiojio/pos/feature_extractor.py @@ -759,7 +759,7 @@ def save(self, model_dir=None): data['char'] = sorted(list(self.char)) data['part'] = sorted(list(self.part)) data['single_pos_word'] = sorted(list(self.single_pos_word)) - data['feature_to_idx'] = self.feature_to_idx + data['feature_to_idx'] = list(self.feature_to_idx.keys()) data['tag_to_idx'] = self.tag_to_idx feature_path = os.path.join(model_dir, 'features.json') @@ -790,7 +790,9 @@ def load(cls, config, model_dir=None): extractor.char = set(data['char']) extractor.part = set(data['part']) extractor.single_pos_word = set(data['single_pos_word']) - extractor.feature_to_idx = data['feature_to_idx'] + feature_list = data['feature_to_idx'] + extractor.feature_to_idx = dict( + [(feature, idx) for idx, feature in enumerate(feature_list)]) extractor.tag_to_idx = data['tag_to_idx'] # extractor.idx_to_tag = extractor._reverse_dict(extractor.tag_to_idx) diff --git a/jiojio/pos/predict_text.py b/jiojio/pos/predict_text.py index d41f459..41f0c47 100644 --- a/jiojio/pos/predict_text.py +++ b/jiojio/pos/predict_text.py @@ -78,16 +78,6 @@ def __init__(self, model_dir=None, user_dict=None, with_viterbi=True, normalize_num_letter=pos_config.normalize_num_letter, convert_exception=pos_config.convert_exception) - ''' - tmp = dict([(val, k) for k, val in self.feature_extractor.feature_to_idx.items()]) - for i in range(len(tmp)): - if i not in tmp: - print(i) - pdb.set_trace() - print(len(tmp)) - print(self.model.node_weight[0]) - pdb.set_trace() - ''' # C 方式调用 self.get_node_features_c = pos_get_node_features_c @@ -98,14 +88,14 @@ def _cut(self, words): all_node_features = list() for idx in range(length): - if self.get_node_features_c is None: - # 以 python 方式计算,效率较低 - node_features = self.feature_extractor.get_node_features(idx, words) - else: - # 以 C 方式计算,效率高 - node_features = self.get_node_features_c( - idx, words, len(words), self.feature_extractor.unigram, - self.feature_extractor.bigram) + # if self.get_node_features_c is None: + # # 以 python 方式计算,效率较低 + node_features = self.feature_extractor.get_node_features(idx, words) + # else: + # # 以 C 方式计算,效率高 + # node_features = self.get_node_features_c( + # idx, words, len(words), self.feature_extractor.unigram, + # self.feature_extractor.bigram) # if node_features != self.feature_extractor.get_node_features(idx, words): # print(node_features) @@ -123,6 +113,7 @@ def _cut(self, words): all_features.append(node_feature_idx) Y = get_log_Y_YY(all_features, self.model.node_weight, dtype=self.dtype) + # pdb.set_trace() ''' for idx in range(length): print(words[idx])