diff --git a/paddlenlp/transformers/gpt/modeling.py b/paddlenlp/transformers/gpt/modeling.py index c2c31eedf4782c..927f3c816a13e1 100644 --- a/paddlenlp/transformers/gpt/modeling.py +++ b/paddlenlp/transformers/gpt/modeling.py @@ -781,14 +781,14 @@ def model(self, def forward(self, input_ids, end_id): output, cached_kvs = self.model(input_ids, use_cache=True, cache=None) src_ids = input_ids - nid = paddle.argmax(output[0, -1]).reshape([1, -1]) + nid = paddle.argmax(output[:, -1, :], axis=-1).reshape([-1, 1]) src_ids = paddle.concat([src_ids, nid], axis=1) cur_len = 0 while (cur_len < self.max_predict_len): output, cached_kvs = self.model( nid, use_cache=True, cache=cached_kvs) - nid = paddle.argmax(output[0, -1]).reshape([1, -1]) + nid = paddle.argmax(output[:, -1, :], axis=-1).reshape([-1, 1]) src_ids = paddle.concat([src_ids, nid], axis=1) cur_len += 1 if paddle.max(nid) == end_id: