Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#1126 from joey12300/fix_crf_decode_le…
Browse files Browse the repository at this point in the history
…n_equal_1_bug

fix viterbi len=1 bug
  • Loading branch information
wawltor authored Oct 10, 2021
2 parents 04de795 + 74a7f10 commit bb413b0
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion paddlenlp/layers/crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,8 @@ def forward(self, inputs, lengths):

# last_ids: batch_size
scores, last_ids = alpha.max(1), alpha.argmax(1)
if max_seq_len == 1:
return scores, last_ids.unsqueeze(1)
# Trace back the best path
# historys: seq_len, batch_size, n_labels
historys = paddle.stack(historys)
Expand All @@ -438,10 +440,14 @@ def forward(self, inputs, lengths):
# hist: batch_size, n_labels
left_length = left_length + 1
gather_idx = batch_offset + last_ids
tag_mask = paddle.cast((left_length >= 0), 'int64')
tag_mask = paddle.cast((left_length > 0), 'int64')
last_ids_update = paddle.gather(hist.flatten(),
gather_idx) * tag_mask
zero_len_mask = paddle.cast((left_length == 0), 'int64')
last_ids_update = last_ids_update * (1 - zero_len_mask
) + last_ids * zero_len_mask
batch_path.append(last_ids_update)
tag_mask = paddle.cast((left_length >= 0), 'int64')
last_ids = last_ids_update + last_ids * (1 - tag_mask)
batch_path = paddle.reverse(paddle.stack(batch_path, 1), [1])
return scores, batch_path
Expand Down

0 comments on commit bb413b0

Please sign in to comment.