Skip to content

Commit

Permalink
update t5encoder & test modeling
Browse files Browse the repository at this point in the history
  • Loading branch information
wj-Mcat committed Sep 15, 2022
1 parent 6ebc289 commit ea3f876
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 6 deletions.
33 changes: 32 additions & 1 deletion paddlenlp/transformers/t5/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import math
from typing import Optional, Tuple
Expand Down Expand Up @@ -1733,8 +1734,8 @@ def __getattr__(self, name):
raise e


@register_base_model
class T5EncoderModel(T5PretrainedModel):
base_model_class = None

def __init__(self,
vocab_size=32128,
Expand All @@ -1750,6 +1751,20 @@ def __init__(self,
is_decoder: bool = False,
**kwargs):
super().__init__()
self.config = {
"vocab_size": vocab_size,
"d_model": d_model,
"d_kv": d_kv,
"d_ff": d_ff,
"num_layers": num_layers,
"num_heads": num_heads,
"relative_attention_num_buckets": relative_attention_num_buckets,
"dropout_rate": dropout_rate,
"layer_norm_epsilon": layer_norm_epsilon,
"feed_forward_proj": feed_forward_proj,
"is_decoder": is_decoder,
}
self.config.update(kwargs)
self.shared = nn.Embedding(vocab_size, d_model)

self.use_cache = False
Expand All @@ -1769,6 +1784,19 @@ def __init__(self,
# Initialize weights and apply final processing
self.init_weights()

def _post_init(self, *args, **kwargs):
"""
**prevent the `config` property to be assigned**
It would be hooked after `__init__` to add a dict including arguments of
`__init__` as a attribute named `config` of the pretrained model instance.
"""
pass

@property
def t5(self):
return self

def get_input_embeddings(self):
return self.shared

Expand Down Expand Up @@ -1802,3 +1830,6 @@ def forward(
)

return encoder_outputs


T5EncoderModel.base_model_class = T5EncoderModel
14 changes: 9 additions & 5 deletions tests/transformers/t5/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@

from tests.testing_utils import slow

from ..test_generation_utils import GenerationTesterMixin
from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
# from ..test_generation_utils import GenerationTesterMixin
# from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor

from tests.transformers.test_generation_utils import GenerationTesterMixin
from tests.transformers.test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor

import paddle
from paddlenlp.transformers import T5ForConditionalGeneration, T5Model, T5Tokenizer
from paddlenlp.transformers import T5ForConditionalGeneration, T5Model, T5Tokenizer, T5EncoderModel
from paddlenlp.transformers.t5.modeling import T5_PRETRAINED_MODEL_ARCHIVE_LIST


Expand Down Expand Up @@ -500,9 +503,10 @@ def prepare_config_and_inputs_for_common(self):
class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
base_model_class = T5Model

all_model_classes = (T5Model, T5ForConditionalGeneration)
all_model_classes = (T5Model, T5ForConditionalGeneration, T5EncoderModel)
all_generative_model_classes = {T5ForConditionalGeneration: (T5Model, "t5")}
all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration)
all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration,
T5EncoderModel)
fx_compatible = True
test_pruning = False
test_resize_embeddings = True
Expand Down
1 change: 1 addition & 0 deletions tests/transformers/test_generation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,7 @@ def test_sample_generate(self):
output_generate[0].tolist())

def test_beam_search_generate(self):
paddle.seed(100)
for model_class in self.all_generative_model_classes.keys():
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(
)
Expand Down

0 comments on commit ea3f876

Please sign in to comment.