Skip to content

Commit

Permalink
add llama_pp in blip2
Browse files Browse the repository at this point in the history
  • Loading branch information
wjm202 committed Aug 23, 2023
1 parent 52bb53c commit cb06c1d
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 35 deletions.
2 changes: 0 additions & 2 deletions paddlemix/examples/blip2/run_eval_caption.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddlenlp.trainer import PdArgumentParser, TrainingArguments, get_last_checkpoint
from paddlenlp.transformers import AutoTokenizer

from paddlemix.datasets import load_dataset
from paddlemix.examples.blip2.utils import BlipCollator, create_tokenizer
Expand Down Expand Up @@ -217,7 +216,6 @@ def setdistenv(args):
args.data_parallel_degree = args.dp_degree
logger.info("args.dp_degree:{}".format(args.dp_degree))
logger.info("args.sharding_parallel_degree):{}".format(args.sharding_parallel_degree))
# breakpoint()
strategy.hybrid_configs = {
"dp_degree": args.dp_degree,
"mp_degree": args.tensor_parallel_degree,
Expand Down
1 change: 0 additions & 1 deletion paddlemix/examples/blip2/run_eval_vqav2_zeroshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,6 @@ def setdistenv(args):
args.data_parallel_degree = args.dp_degree
logger.info("args.dp_degree:{}".format(args.dp_degree))
logger.info("args.sharding_parallel_degree):{}".format(args.sharding_parallel_degree))
# breakpoint()
strategy.hybrid_configs = {
"dp_degree": args.dp_degree,
"mp_degree": args.tensor_parallel_degree,
Expand Down
2 changes: 0 additions & 2 deletions paddlemix/examples/blip2/run_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,6 @@ def setdistenv(args):
args.data_parallel_degree = args.dp_degree
logger.info("args.dp_degree:{}".format(args.dp_degree))
logger.info("args.sharding_parallel_degree):{}".format(args.sharding_parallel_degree))
if args.sharding_parallel_degree > 1:
args.sharding = "stage1"
strategy.hybrid_configs = {
"dp_degree": args.dp_degree,
"mp_degree": args.tensor_parallel_degree,
Expand Down
2 changes: 0 additions & 2 deletions paddlemix/examples/blip2/run_pretrain_stage1.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,6 @@ def setdistenv(args):
args.data_parallel_degree = args.dp_degree
logger.info("args.dp_degree:{}".format(args.dp_degree))
logger.info("args.sharding_parallel_degree):{}".format(args.sharding_parallel_degree))
if args.sharding_parallel_degree > 1:
args.sharding = "stage1"
strategy.hybrid_configs = {
"dp_degree": args.dp_degree,
"mp_degree": args.tensor_parallel_degree,
Expand Down
2 changes: 0 additions & 2 deletions paddlemix/examples/blip2/run_pretrain_stage2.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,6 @@ def setdistenv(args):
args.data_parallel_degree = args.dp_degree
logger.info("args.dp_degree:{}".format(args.dp_degree))
logger.info("args.sharding_parallel_degree):{}".format(args.sharding_parallel_degree))
if args.sharding_parallel_degree > 1:
args.sharding = "stage1"
strategy.hybrid_configs = {
"dp_degree": args.dp_degree,
"mp_degree": args.tensor_parallel_degree,
Expand Down
28 changes: 7 additions & 21 deletions paddlemix/examples/blip2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import time

import paddle
from paddlenlp.transformers import AutoTokenizer, T5Tokenizer
from paddlenlp.transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer
from pycocoevalcap.eval import COCOEvalCap
from pycocotools.coco import COCO

Expand Down Expand Up @@ -50,6 +50,9 @@ def create_tokenizer(text_model_name_or_path):
tokenizer_class = AutoTokenizer.from_pretrained(text_model_name_or_path, use_fast=False)
elif "t5" in text_model_name_or_path:
tokenizer_class = T5Tokenizer.from_pretrained(text_model_name_or_path, use_fast=False)
elif "llama" in text_model_name_or_path:
tokenizer_class = LlamaTokenizer.from_pretrained(text_model_name_or_path)
tokenizer_class.pad_token = tokenizer_class.eos_token
else:
raise NotImplementedError
return tokenizer_class
Expand Down Expand Up @@ -152,27 +155,10 @@ def load_model(args, model, optimizer=None, ckpt_dir="", load_language_model=Tru

ckpt_dir = path
if ckpt_dir and os.path.isfile(ckpt_dir):
# breakpoint()
print("Try to load a whole checkpoint from %s " % ckpt_dir)
embedding_list = ["word_embeddings"]
collinear_list = [
"fc1",
"fc2",
"qkv",
"proj",
"query",
"key",
"value",
"qkv_proj",
"q_proj",
"k_proj",
"v_proj",
"linear1",
"linear2",
"project_in",
"project_out",
]
rowlinear_list = ["out_proj"]
embedding_list = []
collinear_list = ["fc1", "qkv"]
rowlinear_list = []
all_list = collinear_list + rowlinear_list + embedding_list
skip_list = ["visual_encoder.patch_embed.proj.weight", "visual_encoder.patch_embed.proj.bias"]

Expand Down
37 changes: 32 additions & 5 deletions paddlemix/models/blip2/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import paddle.distributed as dist
import paddle.nn as nn
from paddlenlp.transformers import AutoTokenizer
from paddlenlp.transformers.llama.modeling import LlamaForCausalLM
from paddlenlp.transformers.model_outputs import ModelOutput
from paddlenlp.transformers.model_utils import PretrainedModel
from paddlenlp.transformers.t5.modeling import T5ForConditionalGeneration
Expand Down Expand Up @@ -372,7 +373,9 @@ def __init__(
from paddlemix.models.blip2.eva_vit import VisionTransformer

self.visual_encoder = VisionTransformer.from_pretrained(
pretrained_model_name_or_path=config.vision_config, mp_degree=config.mp_degree
pretrained_model_name_or_path=config.vision_config,
mp_degree=config.mp_degree,
ignore_mismatched_sizes=True,
)
self.freeze_vit = config.freeze_vit
self.train_stage1 = False
Expand All @@ -392,6 +395,7 @@ def __init__(
train_in_satge1=True,
tokenizer_length=len(self.tokenizer),
mp_degree=config.mp_degree,
ignore_mismatched_sizes=True,
)

state_dict = self.Qformer.state_dict()
Expand All @@ -408,8 +412,31 @@ def __init__(
if config.use_decoder_only_language_model:
if "opt" in config.text_config:
language_model = OPTForCausalLM.from_pretrained(
config.text_config, load_state_as_np=True, mp_degree=config.mp_degree
config.text_config,
load_state_as_np=True,
mp_degree=config.mp_degree,
ignore_mismatched_sizes=True,
)
elif "llama" in config.text_config:
from paddlenlp.transformers.llama.configuration import LlamaConfig

if config.mp_degree > 1:
import paddle.distributed.fleet as fleet

hcg = fleet.get_hybrid_communicate_group()
language_model = LlamaForCausalLM.from_pretrained(
config.text_config,
tensor_parallel_degree=config.mp_degree,
tensor_parallel_rank=hcg.get_model_parallel_rank(),
tensor_parallel_output=False,
)
else:
language_model = LlamaForCausalLM.from_pretrained(
config.text_config,
tensor_parallel_output=False,
)
language_model.hidden_size = LlamaConfig.from_pretrained(config.text_config).hidden_size
language_model.pad_token_id = LlamaConfig.from_pretrained(config.text_config).pad_token_id
else:
raise NotImplementedError
else:
Expand All @@ -418,8 +445,8 @@ def __init__(
t5_config = T5Config(config.text_config)
for key, value in config.text_config.items():
t5_config[key] = config.text_config[key]
language_model = T5ForConditionalGeneration(t5_config)
language_model.hidden_size = config.text_config["d_model"]
language_model = T5ForConditionalGeneration.from_pretrained(config.text_config, load_state_as_np=True)
language_model.hidden_size = t5_config["d_model"]

self.language_model = language_model
for name, param in self.language_model.named_parameters():
Expand All @@ -428,11 +455,11 @@ def __init__(

self.Qformer = BertLMHeadModel.from_pretrained(
pretrained_model_name_or_path=config.qformer_config,
mp_degree=config.mp_degree,
encoder_width=self.visual_encoder.num_features,
train_in_satge1=False,
text_hidden_size=self.language_model.hidden_size,
ignore_mismatched_sizes=True,
mp_degree=config.mp_degree,
)
self.Qformer.cls = None
self.Qformer.bert.embeddings.word_embeddings = None
Expand Down

0 comments on commit cb06c1d

Please sign in to comment.