
Add GPT ut #3133

Merged
merged 11 commits into from Aug 29, 2022
Conversation

FrostML (Contributor) commented Aug 24, 2022

PR types

Others

PR changes

Others

Description

Add GPT ut.

self.bias = paddle.tril(
    paddle.ones(
        [1, 1, max_position_embeddings, max_position_embeddings],
        dtype="int64"))
Contributor:

Does this affect running in static graph mode?

Contributor Author:

It should be fine inside a program guard.
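
For reference, a minimal sketch (not part of the PR) of running this tril-based bias construction under a static-graph program guard; the max_position_embeddings value is illustrative:

import paddle
from paddle.static import Program, program_guard

paddle.enable_static()
max_position_embeddings = 1024  # illustrative value

main_prog, startup_prog = Program(), Program()
with program_guard(main_prog, startup_prog):
    # Same lower-triangular causal bias as in the diff above.
    bias = paddle.tril(
        paddle.ones(
            [1, 1, max_position_embeddings, max_position_embeddings],
            dtype="int64"))

exe = paddle.static.Executor()
exe.run(startup_prog)
out, = exe.run(main_prog, fetch_list=[bias])
print(out.shape)  # (1, 1, 1024, 1024)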

def get_pipeline_config(self):
    config = self.get_config()
    config.vocab_size = 300
    return config
Contributor:

Please check whether this can be removed for now.

Contributor Author:

Done.

model = GPTLMHeadModel.from_pretrained(
    "gpt2-en",
    reorder_and_upcast_attn=False,
    scale_attn_by_inverse_layer_idx=False,
Contributor:

These two arguments are not supported yet, so please comment them out for now. Also, does passing them currently not raise an error? If it doesn't, that is somewhat inconsistent with the earlier expectation.

Contributor Author:

Removed. Indeed, no error was raised.
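
For clarity, a minimal sketch of the call after dropping the two HF-only arguments (a sketch only, assuming no other arguments are needed for the test):

from paddlenlp.transformers import GPTLMHeadModel

# The unsupported `reorder_and_upcast_attn` / `scale_attn_by_inverse_layer_idx`
# arguments are simply removed from the call.
model = GPTLMHeadModel.from_pretrained("gpt2-en")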

self.tokenizer_class.pretrained_resource_files_map.values())
[0]), 1)

def test_offsets_mapping(self):
Contributor:

Please mark cases like this with a skip decorator.

Contributor Author:

Done.
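
A sketch of what the skip marker could look like; the test class name and reason string here are illustrative, not copied from the PR:

import unittest

class GPTTokenizationTest(unittest.TestCase):  # hypothetical class name

    @unittest.skip("offsets mapping is not supported by this tokenizer yet")
    def test_offsets_mapping(self):
        ...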

# shorter than `max_length` can be generated which could lead to flaky circle ci
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
config["eos_token_id"] = None
config["forced_eos_token_id"] = None
Contributor:

Why is this being removed?

Contributor Author:

Because the meaning of max_length differs between HF and us, HF has a hard dependency on this setting; we don't need to set it this way.

GPTForTokenClassification)
all_generative_model_classes = {GPTLMHeadModel: (GPTModel, "gpt")}
all_parallelizable_model_classes = (GPTLMHeadModel)
fx_compatible = True
Contributor:

Please also remove this torch-only attribute; for the other ones we don't have yet, just comment them out for now.

Contributor Author:

Done.
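
A rough sketch of how the adjusted test-class attributes might look after this change (the class name is assumed, not copied from the PR): the torch-only fx_compatible flag is removed and the not-yet-supported attribute is commented out.

import unittest

from paddlenlp.transformers import GPTLMHeadModel, GPTModel

class GPTModelTest(unittest.TestCase):  # hypothetical class name
    all_generative_model_classes = {GPTLMHeadModel: (GPTModel, "gpt")}
    # all_parallelizable_model_classes = (GPTLMHeadModel, )  # not supported yet
    # fx_compatible = True  # torch-specific, removed per review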

else:
    attention_mask = causal_mask
attention_mask = (1.0 - causal_mask) * -1e9
Contributor:

To stay consistent with the original code, keep using -1e4 here.

Contributor:

Please also update the documented attention_mask support in the API docs accordingly.

Contributor Author:

Done.
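
For reference, a minimal sketch of the additive mask after switching back to -1e4 (variable names follow the diff above; seq_len is illustrative):

import paddle

seq_len = 8  # illustrative
causal_mask = paddle.tril(
    paddle.ones([1, 1, seq_len, seq_len], dtype="float32"))
# Allowed positions become 0.0, masked positions a large negative value,
# so they contribute ~0 probability after softmax.
attention_mask = (1.0 - causal_mask) * -1e4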

guoshengCS previously approved these changes Aug 26, 2022

@guoshengCS (Contributor) left a comment:

LGTM

FrostML merged commit e8f8eca into PaddlePaddle:develop Aug 29, 2022