Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
Merge pull request #559 from nickyfantasy/updateBeamSearch
Browse files Browse the repository at this point in the history
update beam search API in machine translation book example
  • Loading branch information
nickyfantasy authored Jun 29, 2018
2 parents 3f0b8ec + 801e400 commit ef10dd9
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 24 deletions.
21 changes: 17 additions & 4 deletions 08.machine_translation/README.cn.md
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,18 @@ def decode(context, is_sparse):
current_score = pd.fc(input=current_state_with_lod,
size=target_dict_dim,
act='softmax')
topk_scores, topk_indices = pd.topk(current_score, k=topk_size)
topk_scores, topk_indices = pd.topk(current_score, k=beam_size)
# calculate accumulated scores after topk to reduce computation cost
accu_scores = pd.elementwise_add(
x=pd.log(topk_scores), y=pd.reshape(pre_score, shape=[-1]), axis=0)
selected_ids, selected_scores = pd.beam_search(
pre_ids, topk_indices, topk_scores, beam_size, end_id=10, level=0)
pre_ids,
pre_score,
topk_indices,
accu_scores,
beam_size,
end_id=10,
level=0)

pd.increment(x=counter, value=1, in_place=True)

Expand All @@ -264,10 +273,14 @@ def decode(context, is_sparse):
pd.array_write(selected_ids, array=ids_array, i=counter)
pd.array_write(selected_scores, array=scores_array, i=counter)

pd.less_than(x=counter, y=array_len, cond=cond)
# update the loop condition: keep decoding only while the max length has not
# been reached and selected_ids is non-empty, i.e. not all candidates of the
# source sentences have ended.
length_cond = pd.less_than(x=counter, y=array_len)
finish_cond = pd.logical_not(pd.is_empty(x=selected_ids))
pd.logical_and(x=length_cond, y=finish_cond, out=cond)

translation_ids, translation_scores = pd.beam_search_decode(
ids=ids_array, scores=scores_array)
ids=ids_array, scores=scores_array, beam_size=beam_size, end_id=10)

return translation_ids, translation_scores
```
Expand Down
21 changes: 17 additions & 4 deletions 08.machine_translation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,18 @@ def decode(context, is_sparse):
current_score = pd.fc(input=current_state_with_lod,
size=target_dict_dim,
act='softmax')
topk_scores, topk_indices = pd.topk(current_score, k=topk_size)
topk_scores, topk_indices = pd.topk(current_score, k=beam_size)
# calculate accumulated scores after topk to reduce computation cost
accu_scores = pd.elementwise_add(
x=pd.log(topk_scores), y=pd.reshape(pre_score, shape=[-1]), axis=0)
selected_ids, selected_scores = pd.beam_search(
pre_ids, topk_indices, topk_scores, beam_size, end_id=10, level=0)
pre_ids,
pre_score,
topk_indices,
accu_scores,
beam_size,
end_id=10,
level=0)

pd.increment(x=counter, value=1, in_place=True)

Expand All @@ -301,10 +310,14 @@ def decode(context, is_sparse):
pd.array_write(selected_ids, array=ids_array, i=counter)
pd.array_write(selected_scores, array=scores_array, i=counter)

pd.less_than(x=counter, y=array_len, cond=cond)
# update the loop condition: keep decoding only while the max length has not
# been reached and selected_ids is non-empty, i.e. not all candidates of the
# source sentences have ended.
length_cond = pd.less_than(x=counter, y=array_len)
finish_cond = pd.logical_not(pd.is_empty(x=selected_ids))
pd.logical_and(x=length_cond, y=finish_cond, out=cond)

translation_ids, translation_scores = pd.beam_search_decode(
ids=ids_array, scores=scores_array)
ids=ids_array, scores=scores_array, beam_size=beam_size, end_id=10)

return translation_ids, translation_scores
```
Expand Down
23 changes: 17 additions & 6 deletions 08.machine_translation/index.cn.html
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,6 @@

```python
def encoder(is_sparse):
# encoder
src_word_id = pd.data(
name="src_word_id", shape=[1], dtype='int64', lod_level=1)
src_embedding = pd.embedding(
Expand All @@ -221,7 +220,6 @@

```python
def train_decoder(context, is_sparse):
# decoder
trg_language_word = pd.data(
name="target_language_word", shape=[1], dtype='int64', lod_level=1)
trg_embedding = pd.embedding(
Expand Down Expand Up @@ -297,9 +295,18 @@
current_score = pd.fc(input=current_state_with_lod,
size=target_dict_dim,
act='softmax')
topk_scores, topk_indices = pd.topk(current_score, k=topk_size)
topk_scores, topk_indices = pd.topk(current_score, k=beam_size)
# calculate accumulated scores after topk to reduce computation cost
accu_scores = pd.elementwise_add(
x=pd.log(topk_scores), y=pd.reshape(pre_score, shape=[-1]), axis=0)
selected_ids, selected_scores = pd.beam_search(
pre_ids, topk_indices, topk_scores, beam_size, end_id=10, level=0)
pre_ids,
pre_score,
topk_indices,
accu_scores,
beam_size,
end_id=10,
level=0)

pd.increment(x=counter, value=1, in_place=True)

Expand All @@ -308,10 +315,14 @@
pd.array_write(selected_ids, array=ids_array, i=counter)
pd.array_write(selected_scores, array=scores_array, i=counter)

pd.less_than(x=counter, y=array_len, cond=cond)
# update the loop condition: keep decoding only while the max length has not
# been reached and selected_ids is non-empty, i.e. not all candidates of the
# source sentences have ended.
length_cond = pd.less_than(x=counter, y=array_len)
finish_cond = pd.logical_not(pd.is_empty(x=selected_ids))
pd.logical_and(x=length_cond, y=finish_cond, out=cond)

translation_ids, translation_scores = pd.beam_search_decode(
ids=ids_array, scores=scores_array)
ids=ids_array, scores=scores_array, beam_size=beam_size, end_id=10)

return translation_ids, translation_scores
```
Expand Down
23 changes: 17 additions & 6 deletions 08.machine_translation/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,6 @@

```python
def encoder(is_sparse):
# encoder
src_word_id = pd.data(
name="src_word_id", shape=[1], dtype='int64', lod_level=1)
src_embedding = pd.embedding(
Expand All @@ -258,7 +257,6 @@

```python
def train_decoder(context, is_sparse):
# decoder
trg_language_word = pd.data(
name="target_language_word", shape=[1], dtype='int64', lod_level=1)
trg_embedding = pd.embedding(
Expand Down Expand Up @@ -334,9 +332,18 @@
current_score = pd.fc(input=current_state_with_lod,
size=target_dict_dim,
act='softmax')
topk_scores, topk_indices = pd.topk(current_score, k=topk_size)
topk_scores, topk_indices = pd.topk(current_score, k=beam_size)
# calculate accumulated scores after topk to reduce computation cost
accu_scores = pd.elementwise_add(
x=pd.log(topk_scores), y=pd.reshape(pre_score, shape=[-1]), axis=0)
selected_ids, selected_scores = pd.beam_search(
pre_ids, topk_indices, topk_scores, beam_size, end_id=10, level=0)
pre_ids,
pre_score,
topk_indices,
accu_scores,
beam_size,
end_id=10,
level=0)

pd.increment(x=counter, value=1, in_place=True)

Expand All @@ -345,10 +352,14 @@
pd.array_write(selected_ids, array=ids_array, i=counter)
pd.array_write(selected_scores, array=scores_array, i=counter)

pd.less_than(x=counter, y=array_len, cond=cond)
# update the loop condition: keep decoding only while the max length has not
# been reached and selected_ids is non-empty, i.e. not all candidates of the
# source sentences have ended.
length_cond = pd.less_than(x=counter, y=array_len)
finish_cond = pd.logical_not(pd.is_empty(x=selected_ids))
pd.logical_and(x=length_cond, y=finish_cond, out=cond)

translation_ids, translation_scores = pd.beam_search_decode(
ids=ids_array, scores=scores_array)
ids=ids_array, scores=scores_array, beam_size=beam_size, end_id=10)

return translation_ids, translation_scores
```
Expand Down
21 changes: 17 additions & 4 deletions 08.machine_translation/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,18 @@ def decode(context):
# use score to do beam search
current_score = pd.fc(
input=current_state_with_lod, size=target_dict_dim, act='softmax')
topk_scores, topk_indices = pd.topk(current_score, k=topk_size)
topk_scores, topk_indices = pd.topk(current_score, k=beam_size)
# calculate accumulated scores after topk to reduce computation cost
accu_scores = pd.elementwise_add(
x=pd.log(topk_scores), y=pd.reshape(pre_score, shape=[-1]), axis=0)
selected_ids, selected_scores = pd.beam_search(
pre_ids, topk_indices, topk_scores, beam_size, end_id=10, level=0)
pre_ids,
pre_score,
topk_indices,
accu_scores,
beam_size,
end_id=10,
level=0)

with pd.Switch() as switch:
with switch.case(pd.is_empty(selected_ids)):
Expand All @@ -113,10 +122,14 @@ def decode(context):
pd.array_write(selected_ids, array=ids_array, i=counter)
pd.array_write(selected_scores, array=scores_array, i=counter)

pd.less_than(x=counter, y=array_len, cond=cond)
# update the loop condition: keep decoding only while the max length has not
# been reached and selected_ids is non-empty, i.e. not all candidates of the
# source sentences have ended.
length_cond = pd.less_than(x=counter, y=array_len)
finish_cond = pd.logical_not(pd.is_empty(x=selected_ids))
pd.logical_and(x=length_cond, y=finish_cond, out=cond)

translation_ids, translation_scores = pd.beam_search_decode(
ids=ids_array, scores=scores_array)
ids=ids_array, scores=scores_array, beam_size=beam_size, end_id=10)

return translation_ids, translation_scores

Expand Down

0 comments on commit ef10dd9

Please sign in to comment.