[New Model] add t5-encoder-model #3168
Merged
Changes from 1 commit

Commits (19)
7f2b9a0  add t5-encoder-model  (wj-Mcat)
d41c816  update t5model  (wj-Mcat)
f5c2632  Merge branch 'develop' into add-t5-encoder  (wj-Mcat)
6ebc289  Merge branch 'develop' into add-t5-encoder  (wj-Mcat)
ea3f876  update t5encoder & test modeling  (wj-Mcat)
e8e42a5  Merge branch 'develop' into add-t5-encoder  (wj-Mcat)
77198ee  Merge branch 'develop' into add-t5-encoder  (wj-Mcat)
87eef45  Merge branch 'develop' into add-t5-encoder  (wj-Mcat)
a891db2  Merge branch 'develop' into add-t5-encoder  (wj-Mcat)
459f055  Merge branch 'develop' into add-t5-encoder  (wj-Mcat)
a097681  update t5  (wj-Mcat)
f6af703  Merge branch 'add-t5-encoder' of github.com:wj-Mcat/PaddleNLP into ad…  (wj-Mcat)
998fbc3  update type hinting  (wj-Mcat)
7bdaa2a  Merge branch 'develop' into add-t5-encoder  (wj-Mcat)
2760b16  update cache type annotation  (wj-Mcat)
800bdf0  Merge branch 'add-t5-encoder' of github.com:wj-Mcat/PaddleNLP into ad…  (wj-Mcat)
06459ca  Merge branch 'develop' into add-t5-encoder  (wj-Mcat)
5b7efe3  Merge branch 'develop' into add-t5-encoder  (wj-Mcat)
1df0007  Merge branch 'develop' into add-t5-encoder  (wj-Mcat)
```diff
@@ -15,9 +15,11 @@
 # limitations under the License.
 
 import math
+from typing import Optional, Tuple
 
 import numpy as np
 import paddle
+from paddle.tensor.tensor import Tensor
 
 import paddle.nn as nn
 import paddle.nn.functional as F
@@ -26,9 +28,8 @@
 from ..nezha.modeling import ACT2FN
 
 __all__ = [
-    'T5Model',
-    "T5PretrainedModel",
-    'T5ForConditionalGeneration',
+    'T5Model', "T5PretrainedModel", 'T5ForConditionalGeneration',
+    'T5EncoderModel'
 ]
 
 
@@ -1659,3 +1660,74 @@ def __getattr__(self, name):
             return getattr(self, self.base_model_prefix).config[name]
         except KeyError:
             raise e
+
+
+@register_base_model
+class T5EncoderModel(T5PretrainedModel):
+
+    def __init__(self,
+                 vocab_size=32128,
+                 d_model=768,
+                 d_kv=64,
+                 d_ff=3072,
+                 num_layers=12,
+                 num_heads=12,
+                 relative_attention_num_buckets=32,
+                 dropout_rate=0.1,
+                 layer_norm_epsilon=1e-06,
+                 feed_forward_proj="relu",
+                 is_decoder: bool = False,
+                 **kwargs):
+        super().__init__()
+        self.shared = nn.Embedding(vocab_size, d_model)
+
+        self.use_cache = False
+        self.is_encoder_decoder = False
+        self.encoder = T5Stack(d_model,
+                               num_layers,
+                               layer_norm_epsilon,
+                               dropout_rate,
+                               relative_attention_num_buckets,
+                               d_kv,
+                               num_heads,
+                               feed_forward_proj,
+                               d_ff,
+                               embed_tokens=self.shared,
+                               is_decoder=is_decoder)
+
+        # Initialize weights and apply final processing
+        self.init_weights()
+
+    def get_input_embeddings(self):
+        return self.shared
+
+    def set_input_embeddings(self, new_embeddings):
+        self.shared = new_embeddings
+        self.encoder.set_input_embeddings(new_embeddings)
+
+    def get_encoder(self):
+        return self.encoder
+
+    def forward(
+        self,
+        input_ids: Tensor = None,
+        attention_mask: Optional[Tensor] = None,
+        encoder_hidden_states: Optional[Tuple[Tensor]] = None,
+        encoder_attention_mask: Optional[Tensor] = None,
+        cache=None,
```
Inline review comment on `cache=None`: the types used for cache across the library vary rather than being fixed, so it can be …
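The list of types the reviewer had in mind is cut off above. Purely as an illustration of what a deliberately loose annotation could look like, the snippet below uses assumed types (tuples or lists of tensors); it is not what the PR settled on.

```python
from typing import List, Optional, Tuple, Union

from paddle import Tensor

# Illustrative only: a permissive alias for a cache argument whose concrete
# type varies between call sites (assumed here to be tuples or lists of tensors).
CacheLike = Optional[Union[Tuple[Tensor, ...], List[Tensor]]]
```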
```diff
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+    ):
+        encoder_outputs = self.encoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            cache=cache,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+        )
+
+        return encoder_outputs
```
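For reviewers trying the branch out, a minimal usage sketch might look roughly like the following. It assumes the class is importable from `paddlenlp.transformers` (it is added to `__all__` in this diff), builds a deliberately small model rather than loading a checkpoint (`from_pretrained` support is still under discussion below), and has not been run against the PR branch.

```python
import paddle
from paddlenlp.transformers import T5EncoderModel  # assumed export path, per __all__ in this diff

# Build a small encoder directly from the constructor shown above;
# the argument names mirror the __init__ signature in the diff.
model = T5EncoderModel(vocab_size=32128,
                       d_model=64,
                       d_kv=8,
                       d_ff=128,
                       num_layers=2,
                       num_heads=4)
model.eval()

# Dummy batch of token ids; a real run would use T5Tokenizer output instead.
input_ids = paddle.randint(low=0, high=32128, shape=[2, 16], dtype="int64")

with paddle.no_grad():
    encoder_outputs = model(input_ids=input_ids)

# forward() simply returns whatever the underlying T5Stack returns,
# typically the final hidden states of shape [batch, seq_len, d_model].
print(type(encoder_outputs))
```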
Since `T5EncoderModel` has no base_model, the `cls == cls.base_model_class` check in `from_pretrained` will fail, because at that point `cls.base_model_class` is `T5Model`. Adding a decorator here is a temporary workaround: it re-sets `base_model_class`.
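To make the issue concrete, here is a small, self-contained sketch of the mechanism the comment describes. The class and function names are stand-ins (suffixed with `Sketch`), and the body of `from_pretrained` is only illustrative; this is not PaddleNLP's actual loading logic.

```python
# Simplified sketch of the cls == cls.base_model_class check described above.

class T5PretrainedModelSketch:
    base_model_class = None  # one shared value for the whole T5 family

    @classmethod
    def from_pretrained(cls, name):
        if cls is cls.base_model_class:
            return f"load '{name}' weights directly into {cls.__name__}"
        return f"load '{name}' into the nested {cls.base_model_class.__name__} inside {cls.__name__}"


def register_base_model_sketch(cls):
    # Register cls as the family's base model by writing the shared
    # attribute onto the common parent class.
    cls.__bases__[0].base_model_class = cls
    return cls


@register_base_model_sketch
class T5ModelSketch(T5PretrainedModelSketch):
    pass


class T5EncoderModelSketch(T5PretrainedModelSketch):
    pass


# Without its own registration, T5EncoderModelSketch still sees
# base_model_class == T5ModelSketch, so the "is this the base model?"
# branch never matches for it.
print(T5EncoderModelSketch.from_pretrained("t5-small"))

# Decorating it, as the diff does, re-points the shared attribute -- which is
# also why it is only a temporary fix: the value changes for the whole family.
register_base_model_sketch(T5EncoderModelSketch)
print(T5EncoderModelSketch.from_pretrained("t5-small"))
```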
This needs to be changed the way we discussed the other day, and it should also get a unit test.
Agreed, a unit test does need to be added so we can confirm the model runs end to end.
Yes, this definitely needs a unit test.
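Purely as a sketch of the kind of test being asked for here (not the test eventually added in the PR, which would use PaddleNLP's shared model-test utilities, and again assuming `T5EncoderModel` is importable from `paddlenlp.transformers`):

```python
import unittest

import paddle
from paddlenlp.transformers import T5EncoderModel  # assumed export path


class T5EncoderModelShapeTest(unittest.TestCase):
    """Rough shape check for the new encoder-only model."""

    def test_forward_hidden_state_shape(self):
        d_model = 32
        model = T5EncoderModel(vocab_size=100,
                               d_model=d_model,
                               d_kv=4,
                               d_ff=64,
                               num_layers=2,
                               num_heads=2)
        model.eval()

        input_ids = paddle.randint(low=0, high=100, shape=[2, 8], dtype="int64")
        with paddle.no_grad():
            outputs = model(input_ids=input_ids)

        # The encoder is expected to return the final hidden states,
        # possibly wrapped in a tuple/list by T5Stack.
        hidden_states = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
        self.assertEqual(list(hidden_states.shape), [2, 8, d_model])


if __name__ == "__main__":
    unittest.main()
```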
This will have to wait until #3115 is merged; only then can I build on top of it.