[New Model] Add t5-encoder-model #3168

Merged Sep 21, 2022 · 19 commits · Changes from 1 commit
78 changes: 75 additions & 3 deletions paddlenlp/transformers/t5/modeling.py
@@ -15,9 +15,11 @@
# limitations under the License.

import math
from typing import Optional, Tuple

import numpy as np
import paddle
from paddle import Tensor

import paddle.nn as nn
import paddle.nn.functional as F
@@ -26,9 +28,8 @@
from ..nezha.modeling import ACT2FN

 __all__ = [
-    'T5Model',
-    "T5PretrainedModel",
-    'T5ForConditionalGeneration',
+    'T5Model', "T5PretrainedModel", 'T5ForConditionalGeneration',
+    'T5EncoderModel'
 ]


@@ -1659,3 +1660,74 @@ def __getattr__(self, name):
            return getattr(self, self.base_model_prefix).config[name]
        except KeyError:
            raise e


@register_base_model
Contributor Author commented:

Since T5EncoderModel has no base model of its own, the check `cls == cls.base_model_class` in `from_pretrained` fails, because at that point `cls.base_model_class` is `T5Model`. Adding this decorator is a temporary workaround: it re-registers `base_model_class`.
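For context, a minimal sketch of what a register_base_model-style decorator does; this illustrates the mechanism only and is not the exact PaddleNLP implementation:

def register_base_model(cls):
    # Point the parent pretrained class's `base_model_class` at the
    # decorated class, so that `cls == cls.base_model_class` holds when
    # `from_pretrained` runs on that class.
    base_cls = cls.__bases__[0]
    base_cls.base_model_class = cls
    return cls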

Member @JunnYu commented on Sep 2, 2022:

This should be changed the way we discussed the other day, and a unit test should be added for it as well:

class T5EncoderPretrainedModel(T5PretrainedModel):
    pass

class T5EncoderModel(T5EncoderPretrainedModel):

Contributor Author replied:

A unit test does indeed need to be added; that is how we can verify the model runs end to end.

Contributor Author replied:

Yes, a unit test definitely needs to be added.
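For reference, a minimal sketch of the kind of unit test being discussed. The tiny hyperparameter values are illustrative, and it assumes the encoder returns a tuple whose first element is the final hidden states:

import unittest

import paddle

from paddlenlp.transformers import T5EncoderModel


class T5EncoderModelTest(unittest.TestCase):

    def test_forward_shape(self):
        # Tiny config keeps the test fast; the values are illustrative.
        model = T5EncoderModel(vocab_size=99,
                               d_model=32,
                               d_kv=8,
                               d_ff=64,
                               num_layers=2,
                               num_heads=4)
        model.eval()
        input_ids = paddle.randint(0, 99, shape=[2, 7])
        with paddle.no_grad():
            outputs = model(input_ids=input_ids)
        # Assumes the first element of the returned tuple is the final
        # hidden states with shape [batch_size, seq_len, d_model].
        self.assertEqual(outputs[0].shape, [2, 7, 32])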

Contributor Author replied:

This has to wait until #3115 is merged; only then can I build on top of that work.

class T5EncoderModel(T5PretrainedModel):

    def __init__(self,
                 vocab_size=32128,
                 d_model=768,
                 d_kv=64,
                 d_ff=3072,
                 num_layers=12,
                 num_heads=12,
                 relative_attention_num_buckets=32,
                 dropout_rate=0.1,
                 layer_norm_epsilon=1e-06,
                 feed_forward_proj="relu",
                 is_decoder: bool = False,
                 **kwargs):
        super().__init__()
        # Token embedding table, shared with the encoder stack.
        self.shared = nn.Embedding(vocab_size, d_model)

        self.use_cache = False
        self.is_encoder_decoder = False
        self.encoder = T5Stack(d_model,
                               num_layers,
                               layer_norm_epsilon,
                               dropout_rate,
                               relative_attention_num_buckets,
                               d_kv,
                               num_heads,
                               feed_forward_proj,
                               d_ff,
                               embed_tokens=self.shared,
                               is_decoder=is_decoder)

        # Initialize weights and apply final processing
        self.init_weights()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)

    def get_encoder(self):
        return self.encoder

    def forward(
        self,
        input_ids: Tensor = None,
        attention_mask: Optional[Tensor] = None,
        encoder_hidden_states: Optional[Tuple[Tensor]] = None,
        encoder_attention_mask: Optional[Tensor] = None,
        cache=None,
Contributor commented:

Could a type also be specified for `cache`?

Contributor replied:

Because the type of `cache` varies across the toolkit and is not fixed (it can be a `MultiHeadAttention.Cache` or a `List`), no constraint is applied here.
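For reference, a permissive annotation could still be written down if one were ever wanted. A hedged sketch, not part of this PR, using paddle.nn.MultiHeadAttention.Cache:

from typing import List, Optional, Union

from paddle.nn import MultiHeadAttention

# A deliberately loose alias covering the cache shapes used across the
# toolkit; the PR leaves the parameter unannotated instead.
CacheType = Optional[Union[MultiHeadAttention.Cache, List]]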

        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
    ):
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            cache=cache,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        return encoder_outputs
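Once the class lands, usage would look roughly like the sketch below. It assumes a "t5-small" checkpoint is resolvable by from_pretrained and that the first element of the returned tuple is the encoder's last hidden state:

import paddle

from paddlenlp.transformers import T5EncoderModel, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5EncoderModel.from_pretrained("t5-small")
model.eval()

# PaddleNLP tokenizers return plain Python lists by default.
encoded = tokenizer("PaddleNLP is a great library.")
input_ids = paddle.to_tensor([encoded["input_ids"]])

with paddle.no_grad():
    outputs = model(input_ids=input_ids)

# Assumed layout: outputs[0] is the last hidden state with shape
# [batch_size, seq_len, d_model].
print(outputs[0].shape)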