diff --git a/paddlenlp/transformers/t5/modeling.py b/paddlenlp/transformers/t5/modeling.py
index e054426a0001..db228d4cedd8 100644
--- a/paddlenlp/transformers/t5/modeling.py
+++ b/paddlenlp/transformers/t5/modeling.py
@@ -12,11 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
 
 import math
+from typing import Optional, Tuple, Union, List
 
 import numpy as np
 import paddle
+from paddle import Tensor
 import paddle.nn as nn
 import paddle.nn.functional as F
 
@@ -25,9 +28,8 @@
 from ..nezha.modeling import ACT2FN
 
 __all__ = [
-    'T5Model',
-    "T5PretrainedModel",
-    'T5ForConditionalGeneration',
+    'T5Model', "T5PretrainedModel", 'T5ForConditionalGeneration',
+    'T5EncoderModel'
 ]
 
 T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
@@ -1730,3 +1732,104 @@ def __getattr__(self, name):
             return getattr(self, self.base_model_prefix).config[name]
         except KeyError:
             raise e
+
+
+class T5EncoderModel(T5PretrainedModel):
+    base_model_class = None
+
+    def __init__(self,
+                 vocab_size=32128,
+                 d_model=768,
+                 d_kv=64,
+                 d_ff=3072,
+                 num_layers=12,
+                 num_heads=12,
+                 relative_attention_num_buckets=32,
+                 dropout_rate=0.1,
+                 layer_norm_epsilon=1e-06,
+                 feed_forward_proj="relu",
+                 is_decoder: bool = False,
+                 **kwargs):
+        super().__init__()
+        self.config = {
+            "vocab_size": vocab_size,
+            "d_model": d_model,
+            "d_kv": d_kv,
+            "d_ff": d_ff,
+            "num_layers": num_layers,
+            "num_heads": num_heads,
+            "relative_attention_num_buckets": relative_attention_num_buckets,
+            "dropout_rate": dropout_rate,
+            "layer_norm_epsilon": layer_norm_epsilon,
+            "feed_forward_proj": feed_forward_proj,
+            "is_decoder": is_decoder,
+        }
+        self.config.update(kwargs)
+        self.shared = nn.Embedding(vocab_size, d_model)
+
+        self.use_cache = False
+        self.is_encoder_decoder = False
+        self.encoder = T5Stack(d_model,
+                               num_layers,
+                               layer_norm_epsilon,
+                               dropout_rate,
+                               relative_attention_num_buckets,
+                               d_kv,
+                               num_heads,
+                               feed_forward_proj,
+                               d_ff,
+                               embed_tokens=self.shared,
+                               is_decoder=is_decoder)
+
+        # Initialize weights and apply final processing
+        self.init_weights()
+
+    def _post_init(self, *args, **kwargs):
+        """
+        **Prevent the `config` attribute from being overwritten.**
+
+        The base hook runs after `__init__` and would attach a dict of the
+        `__init__` arguments as an attribute named `config`; it is a no-op here.
+        """
+        pass
+
+    @property
+    def t5(self):
+        return self
+
+    def get_input_embeddings(self):
+        return self.shared
+
+    def set_input_embeddings(self, new_embeddings):
+        self.shared = new_embeddings
+        self.encoder.set_input_embeddings(new_embeddings)
+
+    def get_encoder(self):
+        return self.encoder
+
+    def forward(
+        self,
+        input_ids: Tensor = None,
+        attention_mask: Optional[Tensor] = None,
+        encoder_hidden_states: Optional[Tuple[Tensor]] = None,
+        encoder_attention_mask: Optional[Tensor] = None,
+        cache=None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+    ):
+        encoder_outputs = self.encoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            cache=cache,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+        )
+
+        return encoder_outputs
+
+
+T5EncoderModel.base_model_class = T5EncoderModel
diff --git a/tests/transformers/t5/test_modeling.py b/tests/transformers/t5/test_modeling.py
index 8ca7c882e29e..d76e1705dbb0 100644
--- a/tests/transformers/t5/test_modeling.py
+++ b/tests/transformers/t5/test_modeling.py
@@ -25,7 +25,7 @@
 from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
 
 import paddle
-from paddlenlp.transformers import T5ForConditionalGeneration, T5Model, T5Tokenizer
+from paddlenlp.transformers import T5ForConditionalGeneration, T5Model, T5Tokenizer, T5EncoderModel
 from paddlenlp.transformers.t5.modeling import T5_PRETRAINED_MODEL_ARCHIVE_LIST
 
 
@@ -500,9 +500,10 @@ def prepare_config_and_inputs_for_common(self):
 class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     base_model_class = T5Model
 
-    all_model_classes = (T5Model, T5ForConditionalGeneration)
+    all_model_classes = (T5Model, T5ForConditionalGeneration, T5EncoderModel)
     all_generative_model_classes = {T5ForConditionalGeneration: (T5Model, "t5")}
-    all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration)
+    all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration,
+                                        T5EncoderModel)
     fx_compatible = True
     test_pruning = False
    test_resize_embeddings = True
diff --git a/tests/transformers/test_generation_utils.py b/tests/transformers/test_generation_utils.py
index c6031f641971..06cb23b3646e 100644
--- a/tests/transformers/test_generation_utils.py
+++ b/tests/transformers/test_generation_utils.py
@@ -498,6 +498,7 @@ def test_sample_generate(self):
                              output_generate[0].tolist())
 
     def test_beam_search_generate(self):
+        paddle.seed(100)
        for model_class in self.all_generative_model_classes.keys():
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(
             )