Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

update beam search API in machine translation book example #559

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions 08.machine_translation/README.cn.md
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,18 @@ def decode(context, is_sparse):
current_score = pd.fc(input=current_state_with_lod,
size=target_dict_dim,
act='softmax')
topk_scores, topk_indices = pd.topk(current_score, k=topk_size)
topk_scores, topk_indices = pd.topk(current_score, k=beam_size)
# calculate accumulated scores after topk to reduce computation cost
accu_scores = pd.elementwise_add(
x=pd.log(topk_scores), y=pd.reshape(pre_score, shape=[-1]), axis=0)
selected_ids, selected_scores = pd.beam_search(
pre_ids, topk_indices, topk_scores, beam_size, end_id=10, level=0)
pre_ids,
pre_score,
topk_indices,
accu_scores,
beam_size,
end_id=10,
level=0)

pd.increment(x=counter, value=1, in_place=True)

Expand All @@ -264,10 +273,14 @@ def decode(context, is_sparse):
pd.array_write(selected_ids, array=ids_array, i=counter)
pd.array_write(selected_scores, array=scores_array, i=counter)

pd.less_than(x=counter, y=array_len, cond=cond)
# update the break condition: up to the max length or all candidates of
# source sentences have ended.
length_cond = pd.less_than(x=counter, y=array_len)
finish_cond = pd.logical_not(pd.is_empty(x=selected_ids))
pd.logical_and(x=length_cond, y=finish_cond, out=cond)

translation_ids, translation_scores = pd.beam_search_decode(
ids=ids_array, scores=scores_array)
ids=ids_array, scores=scores_array, beam_size=beam_size, end_id=10)

return translation_ids, translation_scores
```
Expand Down
21 changes: 17 additions & 4 deletions 08.machine_translation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,18 @@ def decode(context, is_sparse):
current_score = pd.fc(input=current_state_with_lod,
size=target_dict_dim,
act='softmax')
topk_scores, topk_indices = pd.topk(current_score, k=topk_size)
topk_scores, topk_indices = pd.topk(current_score, k=beam_size)
# calculate accumulated scores after topk to reduce computation cost
accu_scores = pd.elementwise_add(
x=pd.log(topk_scores), y=pd.reshape(pre_score, shape=[-1]), axis=0)
selected_ids, selected_scores = pd.beam_search(
pre_ids, topk_indices, topk_scores, beam_size, end_id=10, level=0)
pre_ids,
pre_score,
topk_indices,
accu_scores,
beam_size,
end_id=10,
level=0)

pd.increment(x=counter, value=1, in_place=True)

Expand All @@ -301,10 +310,14 @@ def decode(context, is_sparse):
pd.array_write(selected_ids, array=ids_array, i=counter)
pd.array_write(selected_scores, array=scores_array, i=counter)

pd.less_than(x=counter, y=array_len, cond=cond)
# update the break condition: up to the max length or all candidates of
# source sentences have ended.
length_cond = pd.less_than(x=counter, y=array_len)
finish_cond = pd.logical_not(pd.is_empty(x=selected_ids))
pd.logical_and(x=length_cond, y=finish_cond, out=cond)

translation_ids, translation_scores = pd.beam_search_decode(
ids=ids_array, scores=scores_array)
ids=ids_array, scores=scores_array, beam_size=beam_size, end_id=10)

return translation_ids, translation_scores
```
Expand Down
23 changes: 17 additions & 6 deletions 08.machine_translation/index.cn.html
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,6 @@

```python
def encoder(is_sparse):
# encoder
src_word_id = pd.data(
name="src_word_id", shape=[1], dtype='int64', lod_level=1)
src_embedding = pd.embedding(
Expand All @@ -221,7 +220,6 @@

```python
def train_decoder(context, is_sparse):
# decoder
trg_language_word = pd.data(
name="target_language_word", shape=[1], dtype='int64', lod_level=1)
trg_embedding = pd.embedding(
Expand Down Expand Up @@ -297,9 +295,18 @@
current_score = pd.fc(input=current_state_with_lod,
size=target_dict_dim,
act='softmax')
topk_scores, topk_indices = pd.topk(current_score, k=topk_size)
topk_scores, topk_indices = pd.topk(current_score, k=beam_size)
# calculate accumulated scores after topk to reduce computation cost
accu_scores = pd.elementwise_add(
x=pd.log(topk_scores), y=pd.reshape(pre_score, shape=[-1]), axis=0)
selected_ids, selected_scores = pd.beam_search(
pre_ids, topk_indices, topk_scores, beam_size, end_id=10, level=0)
pre_ids,
pre_score,
topk_indices,
accu_scores,
beam_size,
end_id=10,
level=0)

pd.increment(x=counter, value=1, in_place=True)

Expand All @@ -308,10 +315,14 @@
pd.array_write(selected_ids, array=ids_array, i=counter)
pd.array_write(selected_scores, array=scores_array, i=counter)

pd.less_than(x=counter, y=array_len, cond=cond)
# update the break condition: up to the max length or all candidates of
# source sentences have ended.
length_cond = pd.less_than(x=counter, y=array_len)
finish_cond = pd.logical_not(pd.is_empty(x=selected_ids))
pd.logical_and(x=length_cond, y=finish_cond, out=cond)

translation_ids, translation_scores = pd.beam_search_decode(
ids=ids_array, scores=scores_array)
ids=ids_array, scores=scores_array, beam_size=beam_size, end_id=10)

return translation_ids, translation_scores
```
Expand Down
23 changes: 17 additions & 6 deletions 08.machine_translation/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,6 @@

```python
def encoder(is_sparse):
# encoder
src_word_id = pd.data(
name="src_word_id", shape=[1], dtype='int64', lod_level=1)
src_embedding = pd.embedding(
Expand All @@ -258,7 +257,6 @@

```python
def train_decoder(context, is_sparse):
# decoder
trg_language_word = pd.data(
name="target_language_word", shape=[1], dtype='int64', lod_level=1)
trg_embedding = pd.embedding(
Expand Down Expand Up @@ -334,9 +332,18 @@
current_score = pd.fc(input=current_state_with_lod,
size=target_dict_dim,
act='softmax')
topk_scores, topk_indices = pd.topk(current_score, k=topk_size)
topk_scores, topk_indices = pd.topk(current_score, k=beam_size)
# calculate accumulated scores after topk to reduce computation cost
accu_scores = pd.elementwise_add(
x=pd.log(topk_scores), y=pd.reshape(pre_score, shape=[-1]), axis=0)
selected_ids, selected_scores = pd.beam_search(
pre_ids, topk_indices, topk_scores, beam_size, end_id=10, level=0)
pre_ids,
pre_score,
topk_indices,
accu_scores,
beam_size,
end_id=10,
level=0)

pd.increment(x=counter, value=1, in_place=True)

Expand All @@ -345,10 +352,14 @@
pd.array_write(selected_ids, array=ids_array, i=counter)
pd.array_write(selected_scores, array=scores_array, i=counter)

pd.less_than(x=counter, y=array_len, cond=cond)
# update the break condition: up to the max length or all candidates of
# source sentences have ended.
length_cond = pd.less_than(x=counter, y=array_len)
finish_cond = pd.logical_not(pd.is_empty(x=selected_ids))
pd.logical_and(x=length_cond, y=finish_cond, out=cond)

translation_ids, translation_scores = pd.beam_search_decode(
ids=ids_array, scores=scores_array)
ids=ids_array, scores=scores_array, beam_size=beam_size, end_id=10)

return translation_ids, translation_scores
```
Expand Down
21 changes: 17 additions & 4 deletions 08.machine_translation/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,18 @@ def decode(context):
# use score to do beam search
current_score = pd.fc(
input=current_state_with_lod, size=target_dict_dim, act='softmax')
topk_scores, topk_indices = pd.topk(current_score, k=topk_size)
topk_scores, topk_indices = pd.topk(current_score, k=beam_size)
# calculate accumulated scores after topk to reduce computation cost
accu_scores = pd.elementwise_add(
x=pd.log(topk_scores), y=pd.reshape(pre_score, shape=[-1]), axis=0)
selected_ids, selected_scores = pd.beam_search(
pre_ids, topk_indices, topk_scores, beam_size, end_id=10, level=0)
pre_ids,
pre_score,
topk_indices,
accu_scores,
beam_size,
end_id=10,
level=0)

with pd.Switch() as switch:
with switch.case(pd.is_empty(selected_ids)):
Expand All @@ -113,10 +122,14 @@ def decode(context):
pd.array_write(selected_ids, array=ids_array, i=counter)
pd.array_write(selected_scores, array=scores_array, i=counter)

pd.less_than(x=counter, y=array_len, cond=cond)
# update the break condition: up to the max length or all candidates of
# source sentences have ended.
length_cond = pd.less_than(x=counter, y=array_len)
finish_cond = pd.logical_not(pd.is_empty(x=selected_ids))
pd.logical_and(x=length_cond, y=finish_cond, out=cond)

translation_ids, translation_scores = pd.beam_search_decode(
ids=ids_array, scores=scores_array)
ids=ids_array, scores=scores_array, beam_size=beam_size, end_id=10)

return translation_ids, translation_scores

Expand Down