【Hackathon + GradientCache】 #1799

Merged · 66 commits · Sep 29, 2022
Changes from 4 commits
Commits (66)
784f345
gradient_cache
Zhiyuan-Fan Mar 18, 2022
5c9a957
gradient_cache
Zhiyuan-Fan Mar 18, 2022
01f1bc0
gradient_cache
Zhiyuan-Fan Mar 18, 2022
fa610bc
Merge branch 'PaddlePaddle:develop' into develop
Zhiyuan-Fan Mar 18, 2022
c430b8a
gradient_cache
Zhiyuan-Fan Mar 19, 2022
8f420d1
Merge branch 'PaddlePaddle:develop' into develop
Zhiyuan-Fan Mar 19, 2022
8f3a97c
Merge branch 'PaddlePaddle:develop' into develop
Zhiyuan-Fan Mar 21, 2022
a26786d
Merge branch 'PaddlePaddle:develop' into develop
Zhiyuan-Fan Mar 25, 2022
2227a26
data
Zhiyuan-Fan Mar 25, 2022
4510e1d
Merge branch 'develop' into develop
Zhiyuan-Fan Mar 25, 2022
f86eeb9
train_for_gradient_cache
Zhiyuan-Fan Mar 26, 2022
d129186
Merge branch 'develop' of github.com:Elvisambition/PaddleNLP into dev…
Zhiyuan-Fan Mar 26, 2022
483cb62
Merge branch 'develop' into develop
Zhiyuan-Fan Mar 26, 2022
5e91937
add
Zhiyuan-Fan Mar 26, 2022
bba0521
Merge branch 'PaddlePaddle:develop' into develop
Zhiyuan-Fan Mar 31, 2022
ae0125e
add
Zhiyuan-Fan Mar 31, 2022
ff3789c
Merge branch 'develop' of github.com:Elvisambition/PaddleNLP into dev…
Zhiyuan-Fan Mar 31, 2022
ff34e2a
Merge branch 'PaddlePaddle:develop' into develop
Zhiyuan-Fan Mar 31, 2022
ccbd5b1
add
Zhiyuan-Fan Mar 31, 2022
202664a
Merge branch 'develop' of github.com:Elvisambition/PaddleNLP into dev…
Zhiyuan-Fan Mar 31, 2022
b7a6db3
Merge branch 'PaddlePaddle:develop' into develop
Zhiyuan-Fan Apr 13, 2022
e675ea9
Merge branch 'PaddlePaddle:develop' into develop
Zhiyuan-Fan May 3, 2022
17be523
Merge branch 'PaddlePaddle:develop' into develop
Zhiyuan-Fan May 16, 2022
43acadb
update
Zhiyuan-Fan May 16, 2022
4563c2d
update
Zhiyuan-Fan May 16, 2022
c892976
Merge branch 'PaddlePaddle:develop' into develop
Zhiyuan-Fan Jun 10, 2022
c600939
Merge branch 'PaddlePaddle:develop' into develop
Zhiyuan-Fan Jun 13, 2022
87f029a
update
Zhiyuan-Fan Jun 13, 2022
25aa42c
update
Zhiyuan-Fan Jun 13, 2022
d5984a1
Merge branch 'PaddlePaddle:develop' into develop
Zhiyuan-Fan Jun 18, 2022
c7fdafd
Merge branch 'PaddlePaddle:develop' into develop
Zhiyuan-Fan Jun 18, 2022
8012929
update
Zhiyuan-Fan Jun 18, 2022
7650da6
update
Zhiyuan-Fan Jun 18, 2022
675efb6
Update README_gradient_cache.md
Zhiyuan-Fan Jun 18, 2022
d890d8e
Update README_gradient_cache.md
Zhiyuan-Fan Jun 18, 2022
88ba024
Merge branch 'PaddlePaddle:develop' into develop
Zhiyuan-Fan Jun 22, 2022
abff61d
Update README_gradient_cache.md
Zhiyuan-Fan Jun 22, 2022
1cd93be
Merge branch 'PaddlePaddle:develop' into develop
Zhiyuan-Fan Jul 19, 2022
162165c
Merge branch 'PaddlePaddle:develop' into develop
Zhiyuan-Fan Jul 31, 2022
9cbcd71
feat: modified the code
Zhiyuan-Fan Jul 31, 2022
e533e10
fix: delete useless code
Zhiyuan-Fan Jul 31, 2022
67bad62
feat: added requirements.txt
Zhiyuan-Fan Jul 31, 2022
d57380c
Merge branch 'PaddlePaddle:develop' into develop
Zhiyuan-Fan Sep 4, 2022
748b63f
feat: modify readme
Zhiyuan-Fan Sep 4, 2022
be889df
feat: modify some code
Zhiyuan-Fan Sep 5, 2022
2f0901d
Merge branch 'PaddlePaddle:develop' into develop
Zhiyuan-Fan Sep 5, 2022
f2a4397
feat: code style
Zhiyuan-Fan Sep 5, 2022
de9ba83
feat: add function
Zhiyuan-Fan Sep 5, 2022
db2ccf0
feat: add licence
Zhiyuan-Fan Sep 5, 2022
476aaa5
feat: add comments
Zhiyuan-Fan Sep 5, 2022
f6716fb
Update README_gradient_cache.md
Sep 5, 2022
6343cf7
Merge branch 'PaddlePaddle:develop' into develop
Zhiyuan-Fan Sep 5, 2022
25c0b2a
Merge branch 'PaddlePaddle:develop' into develop
Zhiyuan-Fan Sep 26, 2022
644438d
feat: modify readme
Zhiyuan-Fan Sep 26, 2022
7ccabad
Merge branch 'PaddlePaddle:develop' into develop
Zhiyuan-Fan Sep 28, 2022
865d50c
feat: modify readme
Zhiyuan-Fan Sep 28, 2022
f5a9606
fix: copyright
Zhiyuan-Fan Sep 28, 2022
0aa9739
fix: yapf
Zhiyuan-Fan Sep 28, 2022
3891997
feat: modify readme
Zhiyuan-Fan Sep 28, 2022
ed675ec
feat: modify readme
Zhiyuan-Fan Sep 28, 2022
152437f
feat: delete useless code
Zhiyuan-Fan Sep 28, 2022
2c57eb6
feat: add new explain
Zhiyuan-Fan Sep 28, 2022
2fbfde8
Merge branch 'develop' into develop
w5688414 Sep 28, 2022
fb38a58
Merge branch 'develop' into develop
w5688414 Sep 29, 2022
ab0f9d1
Merge branch 'develop' into develop
w5688414 Sep 29, 2022
8f335c1
Merge branch 'develop' into develop
w5688414 Sep 29, 2022
121 changes: 121 additions & 0 deletions examples/semantic_indexing/gradient_cache/model.py
@@ -0,0 +1,121 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from base_model import SemanticIndexBase


class SemanticIndexCacheNeg(SemanticIndexBase):
def __init__(self,
pretrained_model,
dropout=None,
margin=0.3,
scale=30,
output_emb_size=None):
super().__init__(pretrained_model, dropout, output_emb_size)

self.margin = margin
# Scale the cosine similarity to ease convergence.
self.scale = scale

def get_pooled_embedding_with_no_grad(self,
input_ids,
token_type_ids=None,
position_ids=None,
attention_mask=None):
if self.use_fp16:
if attention_mask is None:
attention_mask = paddle.unsqueeze(
(input_ids == self.ptm.pad_token_id
).astype(self.ptm.pooler.dense.weight.dtype) * -1e4,
axis=[1, 2])


with paddle.no_grad():
embedding_output = self.ptm.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids)

embedding_output = paddle.cast(embedding_output, 'float16')
attention_mask = paddle.cast(attention_mask, 'float16')

with paddle.no_grad():
encoder_outputs = self.ptm.encoder(embedding_output, attention_mask)



Contributor: Too many blank lines, please delete them.

if self.use_fp16:
encoder_outputs = paddle.cast(encoder_outputs, 'float32')
cls_embedding = self.ptm.pooler(encoder_outputs)
else:
_, cls_embedding = self.ptm(input_ids, token_type_ids, position_ids,
attention_mask)

if self.output_emb_size > 0:
cls_embedding = self.emb_reduce_linear(cls_embedding)
cls_embedding = self.dropout(cls_embedding)
cls_embedding = F.normalize(cls_embedding, p=2, axis=-1)

return cls_embedding





Contributor: Too many blank lines.

def forward(self,
query_input_ids,
title_input_ids,
query_token_type_ids=None,
query_position_ids=None,
query_attention_mask=None,
title_token_type_ids=None,
title_position_ids=None,
title_attention_mask=None):


query_cls_embedding = self.get_pooled_embedding(
query_input_ids, query_token_type_ids, query_position_ids,
query_attention_mask)

title_cls_embedding = self.get_pooled_embedding(
title_input_ids, title_token_type_ids, title_position_ids,
title_attention_mask)

cosine_sim = paddle.matmul(
query_cls_embedding, title_cls_embedding, transpose_y=True)

# Subtract the margin from the positive-pair similarities on the diagonal of cosine_sim.
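# For example, with margin=0.3 every diagonal entry drops by 0.3 before the cross-entropy in the training script.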
margin_diag = paddle.full(
shape=[query_cls_embedding.shape[0]],
fill_value=self.margin,
dtype=paddle.get_default_dtype())

cosine_sim = cosine_sim - paddle.diag(margin_diag)

# Scale the cosine similarities to ease training convergence.
cosine_sim *= self.scale

labels = paddle.arange(0, query_cls_embedding.shape[0], dtype='int64')
labels = paddle.reshape(labels, shape=[-1, 1])

#loss = F.cross_entropy(input=cosine_sim, label=labels)
Contributor: Delete the unused comment.


return cosine_sim, labels
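
Note: the forward pass above returns the scaled similarity matrix together with diagonal labels rather than a loss; the cross-entropy is applied later in the training script. As a minimal illustrative sketch (editorial, not part of this diff), the idea is that each query's positive title shares its index, so cross-entropy over the similarity matrix treats every other title in the batch as a negative:

import paddle
import paddle.nn.functional as F

# Toy shapes; in the real script the embeddings come from the model above.
batch_size, emb_size = 4, 8
query_emb = F.normalize(paddle.randn([batch_size, emb_size]), axis=-1)
title_emb = F.normalize(paddle.randn([batch_size, emb_size]), axis=-1)

cosine_sim = paddle.matmul(query_emb, title_emb, transpose_y=True)  # [B, B] similarity matrix
labels = paddle.arange(0, batch_size, dtype='int64').reshape([-1, 1])  # positive is the same index
loss = F.cross_entropy(input=cosine_sim, label=labels)  # in-batch negatives
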
270 changes: 270 additions & 0 deletions examples/semantic_indexing/train_gradient_cache.py
@@ -0,0 +1,270 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
import argparse
import os
import sys
import random
import time

import numpy as np
import paddle
import paddle.nn.functional as F

import paddlenlp as ppnlp
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.datasets import load_dataset
from paddlenlp.transformers import LinearDecayWithWarmup

from gradient_cache.model import SemanticIndexCacheNeg
from data import read_text_pair, convert_example, create_dataloader

# yapf: disable
parser = argparse.ArgumentParser()
parser.add_argument("--save_dir", default='./checkpoint', type=str,
help="The output directory where the model checkpoints will be written.")
parser.add_argument("--max_seq_length", default=128, type=int,
help="The maximum total input sequence length after tokenization. "
"Sequences longer than this will be truncated, sequences shorter will be padded.")
parser.add_argument("--batch_size", default=32, type=int, help="Batch size per GPU/CPU for training.")
parser.add_argument("--output_emb_size", default=None, type=int, help="output_embedding_size.")
parser.add_argument("--learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
parser.add_argument("--epochs", default=10, type=int, help="Total number of training epochs to perform.")
parser.add_argument("--warmup_proportion", default=0.0, type=float,
help="Linear warmup proption over the training process.")
parser.add_argument("--init_from_ckpt", type=str, default=None, help="The path of checkpoint to be loaded.")
parser.add_argument("--seed", type=int, default=1000, help="random seed for initialization.")
parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu",
help="Select which device to train model, defaults to gpu.")
parser.add_argument('--save_steps', type=int, default=10000, help="Interval steps to save checkpoint.")
parser.add_argument("--train_set_file", type=str, required=True, help="The full path of train_set_file.")
parser.add_argument("--margin", default=0.3, type=float, help="Margin beteween pos_sample and neg_samples.")
parser.add_argument("--scale", default=30, type=int, help="Scale for pair-wise margin_rank_loss")
parser.add_argument("--use_amp", action="store_true", help="Whether to use AMP.")
parser.add_argument("--amp_loss_scale", default=32768, type=float,
help="The value of scale_loss for fp16. This is only used for AMP training.")
parser.add_argument("--chunk_numbers",type=int,default=50,help="The number of the chunks for model")

args = parser.parse_args()


# yapf: enable



def set_seed(seed):
"""sets random seed"""
random.seed(seed)
np.random.seed(seed)
global_generator = paddle.seed(seed)


def do_train():
paddle.set_device(args.device)
rank = paddle.distributed.get_rank()
if paddle.distributed.get_world_size() > 1:
paddle.distributed.init_parallel_env()

random.seed(args.seed)
np.random.seed(args.seed)
paddle.seed(args.seed)


train_ds = load_dataset(
read_text_pair, data_path=args.train_set_file, lazy=False)

# If you want to use a BERT/RoBERTa pretrained model:
# pretrained_model = ppnlp.transformers.BertModel.from_pretrained('bert-base-chinese')
# pretrained_model = ppnlp.transformers.RobertaModel.from_pretrained('roberta-wwm-ext')
pretrained_model = ppnlp.transformers.ErnieModel.from_pretrained(
'ernie-1.0')

# If you want to use a BERT/RoBERTa pretrained model:
# tokenizer = ppnlp.transformers.BertTokenizer.from_pretrained('bert-base-chinese')
# tokenizer = ppnlp.transformers.RobertaTokenizer.from_pretrained('roberta-wwm-ext')
tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained('ernie-1.0')

trans_func = partial(
convert_example,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length)

batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.pad_token_id), # query_input
Contributor: Add an int64 dtype constraint, e.g. change Pad(axis=0, pad_val=tokenizer.pad_token_id) to Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64') (see the sketch after the lambda below).

Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # query_segment
Pad(axis=0, pad_val=tokenizer.pad_token_id), # title_input
Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # title_segment
): [data for data in fn(samples)]
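
A hypothetical variant of batchify_fn incorporating the reviewer's int64 suggestion above (illustrative only; the diff as shown has not applied it):

batchify_fn = lambda samples, fn=Tuple(
    Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'),  # query_input
    Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int64'),  # query_segment
    Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'),  # title_input
    Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int64'),  # title_segment
): [data for data in fn(samples)]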

train_data_loader = create_dataloader(
train_ds,
mode='train',
batch_size=args.batch_size,
batchify_fn=batchify_fn,
trans_fn=trans_func)

model = SemanticIndexCacheNeg(
pretrained_model,
margin=args.margin,
scale=args.scale,
output_emb_size=args.output_emb_size)

if args.init_from_ckpt and os.path.isfile(args.init_from_ckpt):
state_dict = paddle.load(args.init_from_ckpt)
model.set_dict(state_dict)
print("warmup from:{}".format(args.init_from_ckpt))

model = paddle.DataParallel(model)

num_training_steps = len(train_data_loader) * args.epochs

lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps,
args.warmup_proportion)

# Generate parameter names needed to perform weight decay.
# All bias and LayerNorm parameters are excluded.
decay_params = [
p.name for n, p in model.named_parameters()
if not any(nd in n for nd in ["bias", "norm"])
]
optimizer = paddle.optimizer.AdamW(
learning_rate=lr_scheduler,
parameters=model.parameters(),
weight_decay=args.weight_decay,
apply_decay_param_fun=lambda x: x in decay_params)

if args.use_amp:
scaler = paddle.amp.GradScaler(init_loss_scaling=args.amp_loss_scale)

if args.batch_size % args.chunk_numbers == 0:
    chunk_numbers = args.chunk_numbers
else:
    # Fall back to per-sample chunks when batch_size is not divisible by chunk_numbers
    # (this matches the fallback inside split() below).
    chunk_numbers = args.batch_size

def split(inputs, chunk_numbers, axis=0):
    """Split a tensor into chunk_numbers chunks; fall back to per-sample chunks if not divisible."""
    if inputs.shape[0] % chunk_numbers == 0:
        return paddle.split(inputs, chunk_numbers, axis=axis)
    else:
        return paddle.split(inputs, inputs.shape[0], axis=axis)

global_step = 0
tic_train = time.time()
for epoch in range(1, args.epochs + 1):
for step, batch in enumerate(train_data_loader, start=1):

chunked_x = [split(t, chunk_numbers, axis=0) for t in batch]

sub_batchs = [list(s) for s in zip(*chunked_x)]

all_reps = []
all_rnd_states = []
all_loss = []
all_grads = []
all_labels = []
all_CUDA_rnd_state = []
all_global_rnd_state = []

for sub_batch in sub_batchs:

all_reps = []
all_labels = []

sub_query_input_ids, sub_query_token_type_ids, sub_title_input_ids, sub_title_token_type_ids = sub_batch

with paddle.amp.auto_cast(
args.use_amp,
custom_white_list=["layer_norm", "softmax", "gelu"]):

with paddle.no_grad():

sub_CUDA_rnd_state = paddle.framework.random.get_cuda_rng_state()
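# The cached RNG state lets the second (gradient) forward pass replay the same dropout masks.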
#sub_global_rnd_state = paddle.framework.random.get_random_seed_generator(global_random_generator)
Contributor: Remove unused comments.


all_CUDA_rnd_state.append(sub_CUDA_rnd_state)
#all_global_rnd_state.append(sub_global_rnd_state)
Contributor: Remove unused comments.


sub_cosine_sim, sub_label = model(
query_input_ids=sub_query_input_ids,
title_input_ids=sub_title_input_ids,
query_token_type_ids=sub_query_token_type_ids,
title_token_type_ids=sub_title_token_type_ids)

all_reps.append(sub_cosine_sim)
all_labels.append(sub_label)

model_reps = paddle.concat(all_reps, axis=0)

model_reps.stop_gradient = False
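# Marking the concatenated representations as a leaf tensor lets loss.backward() leave d(loss)/d(model_reps) in model_reps.grad.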

#Model_Repos = [r.detach() for r in model_reps]
Contributor: Remove unused comments.


#Model_Repos.stop_gradient = False

model_label = paddle.concat(all_labels, axis=0)

loss = F.cross_entropy(input=model_reps, label=model_label)

loss.backward()

all_grads.append(model_reps.grad)

for sub_batch, CUDA_state, grad in zip(sub_batchs, all_CUDA_rnd_state, all_grads):

sub_query_input_ids, sub_query_token_type_ids, sub_title_input_ids, sub_title_token_type_ids = sub_batch

paddle.framework.random.set_cuda_rng_state(CUDA_state)
#paddle.framework.random.set_random_seed_generator(global_random_generator,global_rnd_state)
Contributor: Remove unused comments.


cosine_sim, _ = model(
query_input_ids=sub_query_input_ids,
title_input_ids=sub_title_input_ids,
query_token_type_ids=sub_query_token_type_ids,
title_token_type_ids=sub_title_token_type_ids)

surrogate = paddle.dot(cosine_sim, grad)
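# Backpropagating this surrogate pushes the cached d(loss)/d(cosine_sim) through this chunk's recomputed graph (chain rule).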

if args.use_amp:
scaled = scaler.scale(surrogate)
scaled.backward()
else:
surrogate.backward()

if args.use_amp:
scaler.minimize(optimizer, scaled)
else:
optimizer.step()

global_step += 1
if global_step % 10 == 0 and rank == 0:
print(
"global step %d, epoch: %d, batch: %d, loss: %.5f, speed: %.2f step/s"
% (global_step, epoch, step, loss,
10 / (time.time() - tic_train)))
tic_train = time.time()

lr_scheduler.step()
optimizer.clear_grad()

if global_step % args.save_steps == 0 and rank == 0:
save_dir = os.path.join(args.save_dir, "model_%d" % global_step)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
save_param_path = os.path.join(save_dir, 'model_state.pdparams')
paddle.save(model.state_dict(), save_param_path)
tokenizer.save_pretrained(save_dir)


if __name__ == "__main__":
do_train()
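
Editorial note: for readers new to the gradient-cache technique used in the training loop above, the following minimal sketch distills the two-pass scheme. The names encode_fn and loss_fn are placeholders rather than APIs from this PR, and the batch is assumed to divide evenly into chunks:

import paddle

def gradient_cache_step(encode_fn, loss_fn, optimizer, chunks):
    """One training step with cached representation gradients (illustrative sketch)."""
    reps, rng_states = [], []
    # Pass 1: compute representations without an autograd graph; remember the RNG
    # state so dropout can be replayed identically in the second pass.
    for chunk in chunks:
        rng_states.append(paddle.framework.random.get_cuda_rng_state())
        with paddle.no_grad():
            reps.append(encode_fn(chunk))
    all_reps = paddle.concat(reps, axis=0)
    all_reps.stop_gradient = False        # leaf tensor: backward() will fill .grad
    loss = loss_fn(all_reps)
    loss.backward()                       # gradient w.r.t. the cached representations
    rep_grads = paddle.split(all_reps.grad, len(chunks), axis=0)
    # Pass 2: recompute each chunk with autograd and chain the cached gradient through it.
    for chunk, state, grad in zip(chunks, rng_states, rep_grads):
        paddle.framework.random.set_cuda_rng_state(state)
        rep = encode_fn(chunk)
        surrogate = paddle.sum(rep * grad)  # same parameter gradients as the full-batch loss
        surrogate.backward()
    optimizer.step()
    optimizer.clear_grad()
    return loss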