Skip to content

Commit

Permalink
Fixed bugs of generation api and unified_transformer model. (PaddlePa…
Browse files Browse the repository at this point in the history
…ddle#371)

Co-authored-by: Guo Sheng <[email protected]>
  • Loading branch information
xiemoyuan and guoshengCS authored May 12, 2021
1 parent ae4b039 commit 54de8b2
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
4 changes: 2 additions & 2 deletions paddlenlp/transformers/generation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,8 @@ def prepare_input_ids_for_generation(bos_token_id):
@staticmethod
def prepare_attention_mask_for_generation(input_ids, pad_token_id,
eos_token_id):
is_pad_token_in_inputs_ids = (pad_token_id is not None) and (
pad_token_id in input_ids)
is_pad_token_in_inputs_ids = (pad_token_id is not None) and paddle.any(
input_ids == pad_token_id).numpy().item()
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
(eos_token_id is not None) and (pad_token_id != eos_token_id))
if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id:
Expand Down
10 changes: 8 additions & 2 deletions paddlenlp/transformers/unified_transformer/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,5 +512,11 @@ def prepare_inputs_for_generation(self,
def __getattr__(self, name):
try:
return super().__getattr__(name)
except AttributeError:
return getattr(getattr(self, self.base_model_prefix), name)
except AttributeError as e:
try:
return getattr(getattr(self, self.base_model_prefix), name)
except AttributeError:
try:
return getattr(self, self.base_model_prefix).config[name]
except KeyError:
raise e

0 comments on commit 54de8b2

Please sign in to comment.