keep_extend_vocab_only->keep_extended_vocab_only
joey12300 committed Feb 24, 2021
1 parent 0c8824d commit 261c14c
Showing 1 changed file with 6 additions and 6 deletions.
paddlenlp/embeddings/token_embedding.py (6 additions, 6 deletions)
@@ -58,8 +58,8 @@ class TokenEmbedding(nn.Embedding):
             The file path of extended vocabulary.
         trainable (object: `bool`, optional, default to True):
             Whether the weight of embedding can be trained.
-        keep_extend_vocab_only (object: `bool`, optional, default to True):
-            Whether keep the extend vocabulary only, only effective if provides extended_vocab_path
+        keep_extended_vocab_only (object: `bool`, optional, default to False):
+            Whether to keep only the extended vocabulary. Takes effect only if extended_vocab_path is provided.
     """
 
     def __init__(self,
@@ -68,7 +68,7 @@ def __init__(self,
                  unknown_token_vector=None,
                  extended_vocab_path=None,
                  trainable=True,
-                 keep_extend_vocab_only=False):
+                 keep_extended_vocab_only=False):
         vector_path = osp.join(EMBEDDING_HOME, embedding_name + ".npz")
         if not osp.exists(vector_path):
             # download
@@ -91,7 +91,7 @@ def __init__(self,
         if extended_vocab_path is not None:
             embedding_table = self._extend_vocab(extended_vocab_path, vector_np,
                                                  pad_vector, unk_vector,
-                                                 keep_extend_vocab_only)
+                                                 keep_extended_vocab_only)
             trainable = True
         else:
             embedding_table = self._init_without_extend_vocab(
@@ -142,7 +142,7 @@ def _read_vocab_list_from_file(self, extended_vocab_path):
         return vocab_list
 
     def _extend_vocab(self, extended_vocab_path, vector_np, pad_vector,
-                      unk_vector, keep_extend_vocab_only):
+                      unk_vector, keep_extended_vocab_only):
         """
         Construct index to word list, word to index dict and embedding weight using
         extended vocab.
@@ -186,7 +186,7 @@ def _extend_vocab(self, extended_vocab_path, vector_np, pad_vector,
         embedding_table[
             extend_vocab_intersect_index] = pretrained_embedding_table[
                 pretrained_vocab_intersect_index]
-        if not keep_extend_vocab_only:
+        if not keep_extended_vocab_only:
            for idx in pretrained_vocab_subtract_index:
                word = pretrained_idx_to_word[idx]
                self._idx_to_word.append(word)
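For context, a minimal usage sketch of the renamed argument. This is an illustration under stated assumptions, not part of the commit: the embedding name is one of PaddleNLP's pretrained-embedding identifiers, and ./custom_vocab.txt is a hypothetical one-token-per-line vocabulary file.

    from paddlenlp.embeddings import TokenEmbedding

    # Extend a pretrained embedding with a custom vocabulary file
    # (hypothetical path; one token per line).
    embedding = TokenEmbedding(
        embedding_name="w2v.baidu_encyclopedia.target.word-word.dim300",
        extended_vocab_path="./custom_vocab.txt",
        # Renamed in this commit from keep_extend_vocab_only.
        # True: keep only tokens listed in the extended vocab file;
        # False (default): also append pretrained tokens outside that file.
        keep_extended_vocab_only=True)

    # Look up vectors for tokens via TokenEmbedding's search helper.
    vectors = embedding.search(["hello", "world"])

As the last hunk shows, the flag only matters when extended_vocab_path is set: unless keep_extended_vocab_only is true, pretrained tokens missing from the extended vocab are appended back into the index.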
