Add keep_extended_vocab_only option (PaddlePaddle#34)
add keep_extended_vocab_only option
ZeyuChen authored Feb 24, 2021
2 parents d466171 + 261c14c commit 1b203ca
Showing 2 changed files with 18 additions and 26 deletions.
14 changes: 1 addition & 13 deletions examples/word_embedding/README.md
@@ -35,20 +35,8 @@ wget https://paddlenlp.bj.bcebos.com/data/dict.txt

We use ChnSentiCorp, a public Chinese sentiment classification dataset, as the example dataset. Run the commands below to train the model on the training set (train.tsv) and evaluate it on the validation set (dev.tsv).

Run on CPU:

```
# Use paddlenlp.embeddings.TokenEmbedding
python train.py --vocab_path='./dict.txt' --use_gpu=False --lr=5e-4 --batch_size=64 --epochs=20 --use_token_embedding=True --vdl_dir='./vdl_dir'
# Use paddle.nn.Embedding
python train.py --vocab_path='./dict.txt' --use_gpu=False --lr=1e-4 --batch_size=64 --epochs=20 --use_token_embedding=False --vdl_dir='./vdl_dir'
```

Run on GPU:
Start training:
```
export CUDA_VISIBLE_DEVICES=0
# Use paddlenlp.embeddings.TokenEmbedding
python train.py --vocab_path='./dict.txt' --use_gpu=True --lr=5e-4 --batch_size=64 --epochs=20 --use_token_embedding=True --vdl_dir='./vdl_dir'
30 changes: 17 additions & 13 deletions paddlenlp/embeddings/token_embedding.py
@@ -58,14 +58,17 @@ class TokenEmbedding(nn.Embedding):
The file path of the extended vocabulary.
trainable (object: `bool`, optional, default to True):
Whether the embedding weight can be trained.
keep_extended_vocab_only (object: `bool`, optional, default to False):
Whether to keep only the extended vocabulary. Takes effect only when extended_vocab_path is provided.
"""

def __init__(self,
embedding_name=EMBEDDING_NAME_LIST[0],
unknown_token=UNK_TOKEN,
unknown_token_vector=None,
extended_vocab_path=None,
trainable=True):
trainable=True,
keep_extended_vocab_only=False):
vector_path = osp.join(EMBEDDING_HOME, embedding_name + ".npz")
if not osp.exists(vector_path):
# download
@@ -87,7 +90,8 @@ def __init__(self,
[0] * self.embedding_dim).astype(paddle.get_default_dtype())
if extended_vocab_path is not None:
embedding_table = self._extend_vocab(extended_vocab_path, vector_np,
pad_vector, unk_vector)
pad_vector, unk_vector,
keep_extended_vocab_only)
trainable = True
else:
embedding_table = self._init_without_extend_vocab(
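For readers less familiar with the API, here is a minimal usage sketch of the new option (not part of this diff; the vocab path is illustrative):

```
# Hypothetical usage of keep_extended_vocab_only; dict.txt is any
# one-token-per-line vocabulary file.
from paddlenlp.embeddings import TokenEmbedding

# With keep_extended_vocab_only=True, pretrained words that do not appear
# in dict.txt are dropped, so the final vocabulary tracks the extended
# vocab file (plus special tokens such as the unknown token).
embedding = TokenEmbedding(
    extended_vocab_path="./dict.txt",
    keep_extended_vocab_only=True)
```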
@@ -138,7 +142,7 @@ def _read_vocab_list_from_file(self, extended_vocab_path):
return vocab_list

def _extend_vocab(self, extended_vocab_path, vector_np, pad_vector,
unk_vector):
unk_vector, keep_extended_vocab_only):
"""
Construct the index-to-word list, word-to-index dict, and embedding
weight matrix using the extended vocab.
@@ -182,16 +186,16 @@ def _extend_vocab(self, extended_vocab_path, vector_np, pad_vector,
embedding_table[
extend_vocab_intersect_index] = pretrained_embedding_table[
pretrained_vocab_intersect_index]

for idx in pretrained_vocab_subtract_index:
word = pretrained_idx_to_word[idx]
self._idx_to_word.append(word)
self._word_to_idx[word] = len(self._idx_to_word) - 1

embedding_table = np.append(
embedding_table,
pretrained_embedding_table[pretrained_vocab_subtract_index],
axis=0)
if not keep_extended_vocab_only:
for idx in pretrained_vocab_subtract_index:
word = pretrained_idx_to_word[idx]
self._idx_to_word.append(word)
self._word_to_idx[word] = len(self._idx_to_word) - 1

embedding_table = np.append(
embedding_table,
pretrained_embedding_table[pretrained_vocab_subtract_index],
axis=0)

if self.unknown_token not in extend_vocab_set:
self._idx_to_word.append(self.unknown_token)
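The hunk above is the behavioral core of the change: pretrained words missing from the extended vocab used to be appended to the table unconditionally, and are now appended only when keep_extended_vocab_only is False. A toy sketch of the resulting table shapes, with illustrative names and random data:

```
# Standalone illustration of the merge logic; names and data are local
# to this sketch, not the actual implementation.
import numpy as np

pretrained_vocab = ["apple", "banana", "cherry"]
pretrained_table = np.random.rand(3, 4).astype("float32")
extended_vocab = ["banana", "durian"]
keep_extended_vocab_only = True

# One row per extended word: reuse the pretrained vector when the word is
# known, otherwise draw a random vector (the real code initializes
# out-of-vocabulary rows with its own scheme).
table = np.stack([
    pretrained_table[pretrained_vocab.index(w)]
    if w in pretrained_vocab
    else np.random.rand(4).astype("float32")
    for w in extended_vocab
])

if not keep_extended_vocab_only:
    # Append pretrained words that the extended vocab does not cover.
    extra = [pretrained_vocab.index(w)
             for w in pretrained_vocab if w not in extended_vocab]
    table = np.append(table, pretrained_table[extra], axis=0)

# keep_extended_vocab_only=True  -> table.shape == (2, 4)
# keep_extended_vocab_only=False -> table.shape == (4, 4)
```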
