Compression API supports ELECTRA #3324

Merged
2 changes: 1 addition & 1 deletion docs/compression.md
@@ -117,7 +117,7 @@ compression_args = parser.parse_args_into_dataclasses()

#### Trainer instantiation parameters

- **--model** The model to be compressed. Currently supports ERNIE, BERT, RoBERTa, ERNIE-M, ERNIE-Gram, PP-MiniLM, TinyBERT, and other structurally similar models. It should be a model fine-tuned on the downstream task; when the pretrained model is ERNIE, it needs to inherit from `ErniePretrainedModel`. Taking classification as an example, the model can be obtained via `AutoModelForSequenceClassification.from_pretrained(model_name_or_path)`, in which case the `model_name_or_path` directory must contain the model_config.json and model_state.pdparams files;
- **--model** The model to be compressed. Currently supports ERNIE, BERT, RoBERTa, ERNIE-M, ELECTRA, ERNIE-Gram, PP-MiniLM, TinyBERT, and other structurally similar models. It should be a model fine-tuned on the downstream task; when the pretrained model is ERNIE, it needs to inherit from `ErniePretrainedModel`. Taking classification as an example, the model can be obtained via `AutoModelForSequenceClassification.from_pretrained(model_name_or_path)`, in which case the `model_name_or_path` directory must contain the model_config.json and model_state.pdparams files;
- **--data_collator** All three task types can use PaddleNLP's predefined [DataCollator classes](../../paddlenlp/data/data_collator.py); `data_collator` can apply operations such as `Pad` to the data. See the [example code](../model_zoo/ernie-3.0/compress_seq_cls.py) for usage;
- **--train_dataset** The training set used for pruning, which is task-specific data. For loading custom datasets, refer to the [documentation](https://huggingface.co/docs/datasets/loading). May be None when pruning is not enabled;
- **--eval_dataset** The evaluation set used during pruning training, which also serves as the calibration data for quantization; it is likewise task-specific data. For loading custom datasets, refer to the [documentation](https://huggingface.co/docs/datasets/loading). It is a required parameter of the Trainer;
29 changes: 27 additions & 2 deletions paddlenlp/trainer/trainer_compress.py
@@ -268,6 +268,24 @@ def _dynabert_init(self, model, eval_dataloader):
return ofa_model, teacher_model


def check_dynabert_config(net_config, width_mult):
    '''
    Corrects net_config for OFA model if necessary.
    '''
    if 'electra.embeddings_project' in net_config:
        net_config["electra.embeddings_project"]['expand_ratio'] = 1.0
    for key in net_config:
        # Makes sure to expand the size of the last dim to `width_mult` for
        # these Linear weights.
        if 'q_proj' in key or 'k_proj' in key or 'v_proj' in key or 'linear1' in key:
            net_config[key]['expand_ratio'] = width_mult
        # Keeps the size of the last dim of these Linear weights the same
        # as before.
        elif 'out_proj' in key or 'linear2' in key:
            net_config[key]['expand_ratio'] = 1.0
    return net_config

Review thread on the `electra.embeddings_project` branch:

Collaborator: Please add a line of comment here explaining the reason.

Contributor (Author): Comment added, thanks for pointing this out. OFA only compresses the last dimension of a Linear layer, and `expand_ratio` here is the compression ratio applied to that last dimension. At computation time, the weight's shape[0] is derived from the input's shape, while shape[1] is the product of the original shape[1] and `expand_ratio`. Since one dimension is compressed and the other stays fixed from the start, keeping this compression ratio across all Linear layers means the second dimension must not change again at out_proj and linear2; otherwise the width would keep shrinking and the model's overall pruning ratio could not be maintained.
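To illustrate the correction above, a minimal sketch exercises a standalone copy of `check_dynabert_config` on a toy `net_config` (the dictionary keys below are hypothetical, merely mirroring ELECTRA-style layer naming):

```python
# Standalone copy of check_dynabert_config, exercised on a toy config.
# The layer keys are hypothetical examples, not taken from a real model.

def check_dynabert_config(net_config, width_mult):
    '''Corrects net_config for the OFA model if necessary.'''
    if 'electra.embeddings_project' in net_config:
        net_config['electra.embeddings_project']['expand_ratio'] = 1.0
    for key in net_config:
        # Expand the last dim of these Linear weights by `width_mult`.
        if 'q_proj' in key or 'k_proj' in key or 'v_proj' in key or 'linear1' in key:
            net_config[key]['expand_ratio'] = width_mult
        # Keep the last dim of these Linear weights unchanged, so the
        # overall pruning ratio is preserved through the block.
        elif 'out_proj' in key or 'linear2' in key:
            net_config[key]['expand_ratio'] = 1.0
    return net_config

toy_config = {
    'electra.embeddings_project': {'expand_ratio': 0.75},
    'encoder.layer0.q_proj': {'expand_ratio': 0.75},
    'encoder.layer0.out_proj': {'expand_ratio': 0.75},
    'encoder.layer0.linear1': {'expand_ratio': 0.75},
    'encoder.layer0.linear2': {'expand_ratio': 0.75},
}
fixed = check_dynabert_config(toy_config, width_mult=0.75)
```

After the call, only the q/k/v projections and `linear1` keep the `width_mult` ratio; `embeddings_project`, `out_proj`, and `linear2` are pinned back to 1.0, matching the reasoning in the review thread.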


def _dynabert_training(self, ofa_model, model, teacher_model, train_dataloader,
eval_dataloader, num_train_epochs):

@@ -388,6 +406,7 @@ def evaluate_token_cls(model, data_loader):
# Step8: Broadcast supernet config from width_mult,
# and use this config in supernet training.
net_config = utils.dynabert_config(ofa_model, width_mult)
net_config = check_dynabert_config(net_config, width_mult)
ofa_model.set_net_config(net_config)
if "token_type_ids" in batch:
logits, teacher_logits = ofa_model(
@@ -424,6 +443,7 @@ def evaluate_token_cls(model, data_loader):
if global_step % self.args.save_steps == 0:
for idx, width_mult in enumerate(self.args.width_mult_list):
net_config = utils.dynabert_config(ofa_model, width_mult)
net_config = check_dynabert_config(net_config, width_mult)
ofa_model.set_net_config(net_config)
tic_eval = time.time()
logger.info("width_mult %s:" % round(width_mult, 2))
@@ -453,7 +473,7 @@ def evaluate_token_cls(model, data_loader):
model_to_save = model._layers if isinstance(
model, paddle.DataParallel) else model
model_to_save.save_pretrained(output_dir_width)
logger.info("Best acc of width_mult %.2f: %.4f" %
logger.info("Best result of width_mult %.2f: %.4f" %
(width_mult, best_acc[idx]))
return ofa_model

@@ -479,6 +499,7 @@ def _dynabert_export(self, ofa_model):
origin_model = self.model.__class__.from_pretrained(model_dir)
ofa_model.model.set_state_dict(state_dict)
best_config = utils.dynabert_config(ofa_model, width_mult)
best_config = check_dynabert_config(best_config, width_mult)
origin_model_new = ofa_model.export(best_config,
input_shapes=[[1, 1], [1, 1]],
input_dtypes=['int64', 'int64'],
@@ -561,7 +582,9 @@ def _batch_generator_func():
optimize_model=False)
post_training_quantization.quantize()
post_training_quantization.save_quantized_model(
save_model_path=os.path.join(model_dir, algo + str(batch_size)),
save_model_path=os.path.join(
model_dir, algo +
"_".join([str(batch_size), str(batch_nums)])),
model_filename=args.output_filename_prefix + ".pdmodel",
params_filename=args.output_filename_prefix + ".pdiparams")
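The changed save path now encodes both `batch_size` and `batch_nums` in the directory name. A quick sketch with hypothetical values shows the name the expression produces:

```python
import os

# Hypothetical values illustrating the directory name built by
# algo + "_".join([str(batch_size), str(batch_nums)]) in the diff above.
model_dir = "output"
algo = "avg"
batch_size, batch_nums = 4, 10

save_model_path = os.path.join(
    model_dir, algo + "_".join([str(batch_size), str(batch_nums)]))
# The final path component is "avg4_10": the algorithm name fused with
# batch_size, then an underscore, then batch_nums.
```

This avoids two quantization runs with different `batch_nums` overwriting each other, which the old `algo + str(batch_size)` name allowed.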

@@ -632,6 +655,8 @@ def auto_model_forward(self,
embedding_kwargs["input_ids"] = input_ids

embedding_output = self.embeddings(**embedding_kwargs)
if hasattr(self, "embeddings_project"):
embedding_output = self.embeddings_project(embedding_output)

self.encoder._use_cache = use_cache # To be consistent with HF

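The `hasattr` guard added above lets one shared forward path serve models with and without an embedding projection (ELECTRA projects its embeddings when the embedding size differs from the hidden size). A minimal sketch of the pattern, using plain Python stand-ins rather than the real paddlenlp layers:

```python
# Minimal sketch of the optional-projection pattern. Embeddings and
# EmbeddingsProject are hypothetical stand-ins, not paddlenlp classes.

class Embeddings:
    def __call__(self, input_ids):
        # Pretend each token id maps to a width-4 "embedding" vector.
        return [[float(i)] * 4 for i in input_ids]

class EmbeddingsProject:
    def __call__(self, embedding_output):
        # Pretend projection from embedding size 4 down to hidden size 2.
        return [row[:2] for row in embedding_output]

class Model:
    def __init__(self, with_projection):
        self.embeddings = Embeddings()
        if with_projection:
            # Only ELECTRA-style models define this attribute.
            self.embeddings_project = EmbeddingsProject()

    def forward(self, input_ids):
        embedding_output = self.embeddings(input_ids)
        # Same guard as in auto_model_forward: apply the projection
        # only when the model actually defines one.
        if hasattr(self, "embeddings_project"):
            embedding_output = self.embeddings_project(embedding_output)
        return embedding_output

electra_like = Model(with_projection=True)
bert_like = Model(with_projection=False)
```

Models without the attribute pass their embeddings through untouched, so BERT-style models are unaffected by the new branch.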