[ModelingOutput]update roformer unittest #3159

Merged
merged 30 commits into develop from add-roformer-unittest on Sep 7, 2022
Changes from 26 commits
Commits
30 commits
81f35ba
add roformer unittest
wj-Mcat Aug 6, 2022
543a349
add roformer unittest
wj-Mcat Aug 6, 2022
8f05abc
Merge branch 'add-roformer-unittest' of github.com:wj-Mcat/PaddleNLP …
wj-Mcat Aug 8, 2022
aa52e3c
update test_modeling
wj-Mcat Aug 9, 2022
bb8e874
use relative import
wj-Mcat Aug 11, 2022
282cf42
reduce model config to accelerate testing
wj-Mcat Aug 11, 2022
1be830a
remove input_embedding from pretrained model
wj-Mcat Aug 11, 2022
2cc4243
revert slow tag
wj-Mcat Aug 12, 2022
68f3203
Merge branch 'develop' of github.com:wj-Mcat/PaddleNLP into add-rofor…
wj-Mcat Aug 16, 2022
c458b9a
update local branch
wj-Mcat Aug 16, 2022
edb99ea
update get_vocab method
wj-Mcat Aug 17, 2022
27d4e2f
update get_vocab method
wj-Mcat Aug 17, 2022
9c85b1d
update test_chinese method
wj-Mcat Aug 18, 2022
487e436
Merge branch 'add-roformer-unittest' of github.com:wj-Mcat/PaddleNLP …
wj-Mcat Aug 18, 2022
d579c70
change absolute import
wj-Mcat Aug 18, 2022
68c7447
Merge branch 'develop' into add-roformer-unittest
wj-Mcat Aug 19, 2022
6c982b3
Merge branch 'develop' into add-roformer-unittest
wj-Mcat Aug 21, 2022
e7c455c
Merge branch 'develop' into add-roformer-unittest
wj-Mcat Aug 22, 2022
54d5a3a
update unittest
wj-Mcat Aug 22, 2022
3e42ca4
update chinese test case
wj-Mcat Aug 22, 2022
0cc54fd
Merge branch 'develop' into add-roformer-unittest
guoshengCS Aug 22, 2022
e0d2446
Merge branch 'develop' into add-roformer-unittest
guoshengCS Aug 22, 2022
9add27b
Merge branch 'develop' of github.com:wj-Mcat/PaddleNLP into add-rofor…
wj-Mcat Aug 30, 2022
1b4ca3a
add roformer more output testing
wj-Mcat Aug 30, 2022
a1e490d
Merge branch 'add-roformer-unittest' of github.com:wj-Mcat/PaddleNLP …
wj-Mcat Aug 30, 2022
535d1a5
Merge branch 'develop' into add-roformer-unittest
wj-Mcat Aug 30, 2022
ce0b22d
Merge branch 'develop' into add-roformer-unittest
wj-Mcat Sep 6, 2022
8386bce
Merge branch 'develop' into add-roformer-unittest
wj-Mcat Sep 6, 2022
cf9e22b
Merge branch 'develop' into add-roformer-unittest
FrostML Sep 7, 2022
8e441f2
Merge branch 'develop' into add-roformer-unittest
guoshengCS Sep 7, 2022
6 changes: 6 additions & 0 deletions paddlenlp/transformers/roformer/modeling.py
@@ -709,6 +709,12 @@ def get_input_embeddings(self) -> nn.Embedding:
def set_input_embeddings(self, embedding: nn.Embedding):
self.embeddings.word_embeddings = embedding

def get_input_embeddings(self) -> nn.Embedding:
return self.embeddings.word_embeddings

def set_input_embeddings(self, embedding: nn.Embedding):
self.embeddings.word_embeddings = embedding


class RoFormerForQuestionAnswering(RoFormerPretrainedModel):
r"""
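For context, the two accessor methods added above follow a recurring pattern in PaddleNLP model classes: expose the backbone's word-embedding table so that shared utilities (the common model tests, embedding resizing, weight tying) can read and replace it without knowing how a given architecture nests it. Below is a minimal sketch of that pattern; `ToyEmbeddings` and `ToyModel` are hypothetical stand-ins, not the actual RoFormer classes.

```python
import paddle.nn as nn


class ToyEmbeddings(nn.Layer):
    """Holds the word-embedding table, analogous to `self.embeddings.word_embeddings` above."""

    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)


class ToyModel(nn.Layer):
    """Hypothetical backbone exposing its input embeddings through get/set accessors."""

    def __init__(self, vocab_size: int = 100, hidden_size: int = 8):
        super().__init__()
        self.embeddings = ToyEmbeddings(vocab_size, hidden_size)

    def get_input_embeddings(self) -> nn.Embedding:
        # Return the embedding layer so callers can inspect, tie, or resize it.
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, embedding: nn.Embedding):
        # Swap in a replacement table, e.g. after growing the vocabulary.
        self.embeddings.word_embeddings = embedding


# Round trip: replace the table with a larger one and read it back.
model = ToyModel()
model.set_input_embeddings(nn.Embedding(120, 8))
assert model.get_input_embeddings().weight.shape == [120, 8]
```

Keeping both accessors on the model is what lets generic test mixins round-trip the embedding layer for any architecture.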
205 changes: 127 additions & 78 deletions tests/transformers/roformer/test_modeling.py
@@ -16,14 +16,17 @@
import unittest
from typing import Optional, Tuple
from dataclasses import dataclass, fields, Field
from parameterized import parameterized_class

import paddle
from paddle import Tensor

from paddlenlp.transformers import (
RoFormerModel, RoFormerPretrainedModel, RoFormerForPretraining,
RoFormerForSequenceClassification, RoFormerForTokenClassification,
RoFormerForQuestionAnswering, RoFormerForMultipleChoice,
RoFormerForMaskedLM)
from paddlenlp.transformers import (RoFormerModel, RoFormerPretrainedModel,
RoFormerForSequenceClassification,
RoFormerForTokenClassification,
RoFormerForQuestionAnswering,
RoFormerForMultipleChoice,
RoFormerForMaskedLM)

from ..test_modeling_common import ids_tensor, floats_tensor, random_attention_mask, ModelTesterMixin
from ...testing_utils import slow
@@ -67,6 +70,7 @@ class RoFormerModelTestConfig(RoFormerModelTestModelConfig):
is_training: bool = False
use_input_mask: bool = False
use_token_type_ids: bool = True
type_sequence_label_size = 3

# used for sequence classification
num_classes: int = 3
@@ -102,41 +106,56 @@ def prepare_config_and_inputs(self):
if self.config.use_token_type_ids:
token_type_ids = ids_tensor([config.batch_size, config.seq_length],
config.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None

if self.parent.use_labels:
sequence_labels = ids_tensor([self.batch_size],
self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length],
self.num_classes)
choice_labels = ids_tensor([self.batch_size], self.num_choices)

config = self.get_config()
return config, input_ids, token_type_ids, input_mask
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

def get_config(self) -> dict:
return self.config.model_kwargs

def create_and_check_model(
self,
config,
input_ids,
token_type_ids,
input_mask,
):
def __getattr__(self, key: str):
if not hasattr(self.config, key):
raise AttributeError(f'attribute <{key}> not exist')
return getattr(self.config, key)

def create_and_check_model(self, config, input_ids: Tensor,
token_type_ids: Tensor, input_mask: Tensor,
sequence_labels: Tensor, token_labels: Tensor,
choice_labels: Tensor):
model = RoFormerModel(**config)
model.eval()
result = model(input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids)
result = model(input_ids, token_type_ids=token_type_ids)
result = model(input_ids)
token_type_ids=token_type_ids,
return_dict=self.parent.return_dict)
result = model(input_ids,
token_type_ids=token_type_ids,
return_dict=self.parent.return_dict)
result = model(input_ids, return_dict=self.parent.return_dict)

self.parent.assertEqual(result[0].shape, [
self.config.batch_size, self.config.seq_length,
self.config.hidden_size
])
self.parent.assertEqual(
result[1].shape, [self.config.batch_size, self.config.hidden_size])

def create_and_check_for_multiple_choice(
self,
config,
input_ids,
token_type_ids,
input_mask,
):
def create_and_check_for_multiple_choice(self, config, input_ids: Tensor,
token_type_ids: Tensor,
input_mask: Tensor,
sequence_labels: Tensor,
token_labels: Tensor,
choice_labels: Tensor):
model = RoFormerForMultipleChoice(RoFormerModel(**config),
num_choices=self.config.num_choices)
model.eval()
@@ -151,89 +170,113 @@ def create_and_check_for_multiple_choice(
input_mask = input_mask.unsqueeze(1).expand(
[-1, self.config.num_choices, -1])

result = model(
multiple_choice_inputs_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
)
self.parent.assertEqual(
result.shape, [self.config.batch_size, self.config.num_choices])
result = model(multiple_choice_inputs_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
labels=choice_labels,
return_dict=self.parent.return_dict)

if paddle.is_tensor(result):
result = [result]
elif choice_labels is not None:
result = result[1:]

def create_and_check_for_question_answering(self, config, input_ids,
token_type_ids, input_mask):
self.parent.assertEqual(
result[0].shape, [self.config.batch_size, self.config.num_choices])

def create_and_check_for_question_answering(self, config, input_ids: Tensor,
token_type_ids: Tensor,
input_mask: Tensor,
sequence_labels: Tensor,
token_labels: Tensor,
choice_labels: Tensor):
model = RoFormerForQuestionAnswering(RoFormerModel(**config))
model.eval()
result = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
)
result = model(input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
start_positions=sequence_labels,
end_positions=sequence_labels,
return_dict=self.parent.return_dict)

if paddle.is_tensor(result):
result = [result]
elif choice_labels is not None:
result = result[1:]

self.parent.assertEqual(
result[0].shape, [self.config.batch_size, self.config.seq_length])
self.parent.assertEqual(
result[1].shape, [self.config.batch_size, self.config.seq_length])

def create_and_check_for_token_classification(
self,
config,
input_ids,
token_type_ids,
input_mask,
):
self, config, input_ids: Tensor, token_type_ids: Tensor,
input_mask: Tensor, sequence_labels: Tensor, token_labels: Tensor,
choice_labels: Tensor):
model = RoFormerForTokenClassification(RoFormerModel(**config),
num_classes=self.num_classes)
model.eval()
result = model(input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids)
self.parent.assertEqual(result.shape, [
token_type_ids=token_type_ids,
labels=token_labels,
return_dict=self.parent.return_dict)
if paddle.is_tensor(result):
result = [result]
elif choice_labels is not None:
result = result[1:]

self.parent.assertEqual(result[0].shape, [
self.config.batch_size, self.config.seq_length,
self.config.num_classes
])

def create_and_check_for_masked_lm(
self,
config,
input_ids,
token_type_ids,
input_mask,
):
def create_and_check_for_masked_lm(self, config, input_ids: Tensor,
token_type_ids: Tensor,
input_mask: Tensor,
sequence_labels: Tensor,
token_labels: Tensor,
choice_labels: Tensor):
model = RoFormerForMaskedLM(RoFormerModel(**config))
model.eval()
result = model(input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids)
self.parent.assertEqual(result.shape, [
token_type_ids=token_type_ids,
labels=token_labels,
return_dict=self.parent.return_dict)
if paddle.is_tensor(result):
result = [result]
elif choice_labels is not None:
result = result[1:]

self.parent.assertEqual(result[0].shape, [
self.config.batch_size, self.config.seq_length,
self.config.vocab_size
])

def create_and_check_for_sequence_classification(
self,
config,
input_ids,
token_type_ids,
input_mask,
):
self, config, input_ids: Tensor, token_type_ids: Tensor,
input_mask: Tensor, sequence_labels: Tensor, token_labels: Tensor,
choice_labels: Tensor):
model = RoFormerForSequenceClassification(
RoFormerModel(**config), num_classes=self.config.num_classes)
model.eval()
result = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
)
result = model(input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
labels=sequence_labels,
return_dict=self.parent.return_dict)
if paddle.is_tensor(result):
result = [result]
elif choice_labels is not None:
result = result[1:]
self.parent.assertEqual(
result.shape, [self.config.batch_size, self.config.num_classes])
result[0].shape, [self.config.batch_size, self.config.num_classes])

def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
token_type_ids,
input_mask,
) = config_and_inputs
(config, input_ids, token_type_ids, input_mask, _, _,
_) = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"token_type_ids": token_type_ids,
Expand All @@ -242,15 +285,21 @@ def prepare_config_and_inputs_for_common(self):
return config, inputs_dict


@parameterized_class(("return_dict", "use_labels"), [
[False, False],
[False, True],
[True, False],
[True, True],
])
class RoFormerModelTest(ModelTesterMixin, unittest.TestCase):
base_model_class = RoFormerModel
use_labels = False
return_dict = False

all_model_classes = (
RoFormerModel,
RoFormerForMultipleChoice,
RoFormerForPretraining,
RoFormerForSequenceClassification,
)
all_model_classes = (RoFormerModel, RoFormerForSequenceClassification,
RoFormerForTokenClassification,
RoFormerForQuestionAnswering,
RoFormerForMultipleChoice, RoFormerForMaskedLM)

def setUp(self):
self.model_tester = RoFormerModelTester(self)
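A note on the test changes above: `parameterized_class` comes from the third-party `parameterized` package and generates one copy of `RoFormerModelTest` per row of the value list, setting `return_dict` and `use_labels` as class attributes, so every test method runs under all four output/label combinations. Inside each `create_and_check_*` method the result is then normalized before shapes are asserted: a lone tensor is wrapped in a list, and when labels were passed the leading loss entry is skipped. The sketch below reproduces that mechanism with a hypothetical `fake_forward` function standing in for a real RoFormer model.

```python
import unittest

from parameterized import parameterized_class


def fake_forward(input_ids, return_dict=False, labels=None):
    """Stand-in forward pass that mimics the tuple/dict output modes being tested."""
    logits = [[0.0, 0.0, 0.0] for _ in input_ids]
    loss = 0.0 if labels is not None else None
    if return_dict:
        return {"loss": loss, "logits": logits}
    return (logits,) if loss is None else (loss, logits)


@parameterized_class(("return_dict", "use_labels"), [
    [False, False],
    [False, True],
    [True, False],
    [True, True],
])
class FakeForwardTest(unittest.TestCase):
    # Both attributes are overridden per generated class by the decorator.
    return_dict = False
    use_labels = False

    def test_logits_shape(self):
        labels = [1, 2] if self.use_labels else None
        result = fake_forward([7, 8], return_dict=self.return_dict, labels=labels)
        # Normalize both output formats, mirroring the result handling in the PR.
        if self.return_dict:
            logits = result["logits"]
        else:
            logits = result[1] if self.use_labels else result[0]
        self.assertEqual(len(logits), 2)


if __name__ == "__main__":
    unittest.main()
```

Running all four combinations from one test class keeps the `return_dict` and `labels` branches of the forward methods covered without duplicating test bodies.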