visualglm update image_feature
LokeZhou committed Oct 16, 2023
1 parent fad8c3b commit 72da4f5
Showing 2 changed files with 60 additions and 59 deletions.
paddlemix/examples/visualglm/run_predict.py (2 changes: 1 addition, 1 deletion)
@@ -51,7 +51,7 @@ def predict(args):
# Epoch 1
query = "写诗描述一下这个场景"
history = []
inputs = processor(image, query, max_length=1024)
inputs = processor(image, query)

generate_ids, _ = model.generate(**inputs, **generate_kwargs)
responses = processor.get_responses(generate_ids)
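
Note on the run_predict.py change above: the explicit max_length=1024 argument is dropped from the processor call, so the processor's own default length handling is used. A minimal sketch of the updated prediction step, assuming `model`, `processor`, `image`, and `generate_kwargs` are prepared earlier in the script as in the original file:

```python
# Sketch of the updated prediction step (assumes model, processor, image,
# and generate_kwargs are already set up earlier in run_predict.py).
query = "写诗描述一下这个场景"  # "Write a poem describing this scene"
history = []

# max_length is no longer passed explicitly; the processor defaults apply.
inputs = processor(image, query)

generate_ids, _ = model.generate(**inputs, **generate_kwargs)
responses = processor.get_responses(generate_ids)
```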
paddlemix/models/visualglm/modeling.py (117 changes: 59 additions, 58 deletions)
@@ -1445,6 +1445,20 @@ def __init__(self, config: ChatGLMConfig):
super(ChatGLMForConditionalGenerationWithImage, self).__init__(config)
self.config = config

def generate_inputs_position_ids(self, input_ids):

MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
use_gmasks = []
mask_positions = []
for seq in input_ids:
mask_token = gMASK if gMASK in seq else MASK
use_gmask = mask_token == gMASK
use_gmasks.append(use_gmask)
mask_positions.append(paddle.where(seq == mask_token)[0][0])

position_ids = self.get_position_ids(input_ids, mask_positions=mask_positions, use_gmasks=use_gmasks)
return position_ids
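
Note: the new generate_inputs_position_ids helper reproduces the gMASK/MASK lookup that previously lived inside prepare_inputs_for_generation, so position ids can be computed once before generation starts. A minimal usage sketch; `model` and `input_ids` here are assumptions for illustration, not values from this diff:

```python
# Usage sketch: `model` is assumed to be a ChatGLMForConditionalGenerationWithImage
# instance and `input_ids` a prompt batch that contains a gMASK (or MASK) token.
position_ids = model.generate_inputs_position_ids(input_ids)

# The helper locates the mask token in each sequence and delegates to
# get_position_ids; with 2D position encoding this yields both the token
# positions and the block positions that ChatGLM expects.
```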

def get_masks(self, input_ids):

batch_size, seq_length = input_ids.shape
@@ -1459,70 +1473,58 @@ def get_masks(self, input_ids):
return attention_mask

def prepare_inputs_for_generation(
self, input_ids, position_ids=None, attention_mask=None, past_key_values=None, cache=None, **kwargs
self,
input_ids,
position_ids=None,
attention_mask=None,
past_key_values=None,
cache=None,
inputs_embeds=None,
**kwargs
):
batch_size, seq_length = input_ids.shape
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
use_gmasks = []
mask_positions = []
for seq in input_ids:
mask_token = gMASK if gMASK in seq else MASK
use_gmask = mask_token == gMASK
use_gmasks.append(use_gmask)
mask_positions.append(paddle.where(seq == mask_token)[0][0])

if cache is None and inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}

if cache is not None or past_key_values is not None:
last_token = input_ids[:, -1].unsqueeze(-1)

attention_mask = attention_mask[:, :, -1:]

if position_ids is not None:
position_ids = position_ids[..., -1:]
else:
if self.position_encoding_2d:
context_lengths = []
for seq in input_ids:
context_lengths.append(paddle.where(seq == self.config.bos_token_id)[0][0])

context_lengths = paddle.to_tensor(context_lengths, dtype="int64")
block_position_ids = seq_length - context_lengths
position_ids = paddle.concat(
[paddle.to_tensor(mask_positions, dtype="int64"), block_position_ids], axis=1
).unsqueeze(-1)
else:
position_ids = paddle.to_tensor(mask_positions, dtype="int64").unsqueeze(-1)
position_ids = position_ids[..., -1:]

if cache is None:
cache = past_key_values

return {
"input_ids": last_token,
"cache": cache[-1],
"position_ids": position_ids,
"use_cache": True,
"attention_mask": attention_mask,
**kwargs,
}
model_inputs.update(
{
"input_ids": last_token,
"cache": cache[-1],
"position_ids": position_ids,
"use_cache": True,
"attention_mask": attention_mask,
}
)

return model_inputs
else:
if position_ids is None:
position_ids = self.get_position_ids(input_ids, mask_positions=mask_positions, use_gmasks=use_gmasks)

return {
"input_ids": input_ids,
"cache": cache,
"position_ids": position_ids,
"use_cache": True,
"attention_mask": attention_mask,
**kwargs,
}

model_inputs.update(
{
"cache": cache,
"position_ids": position_ids,
"use_cache": True,
"attention_mask": attention_mask,
}
)

return model_inputs
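
Note: prepare_inputs_for_generation now distinguishes two cases. On the first decoding step it forwards inputs_embeds when they are provided (this is how the image features spliced into the embedding sequence reach the model); on cached steps it feeds only the last token id together with the trimmed attention mask and position ids. A condensed sketch of the input selection, not the exact implementation:

```python
# Condensed view of the input selection logic (illustrative sketch only).
def select_model_inputs(input_ids, inputs_embeds, cache, past_key_values):
    if cache is None and past_key_values is None and inputs_embeds is not None:
        # First step: use the precomputed text + image embeddings.
        return {"inputs_embeds": inputs_embeds}
    # Cached steps (or no embeddings provided): fall back to token ids.
    return {"input_ids": input_ids}
```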

def forward(
self,
image_features: paddle.Tensor,
input_ids: paddle.Tensor,
input_ids: Optional[paddle.Tensor] = None,
position_ids: Optional[paddle.Tensor] = None,
attention_mask: Optional[paddle.Tensor] = None,
pre_image_length: Optional[int] = None,
cache: Optional[Tuple[paddle.Tensor]] = None,
inputs_embeds: Optional[paddle.Tensor] = None,
labels: Optional[paddle.Tensor] = None,
@@ -1531,12 +1533,6 @@ def forward(
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if inputs_embeds is None and cache is None and image_features is not None:
pre_ids, pad_ids, post_ids = paddle.split(input_ids, num_or_sections=[pre_image_length, 32, -1], axis=1)
pre_txt_emb = self.chatglm.transformer.word_embeddings(pre_ids)
post_txt_emb = self.chatglm.transformer.word_embeddings(post_ids)
inputs_embeds = paddle.concat([pre_txt_emb, image_features, post_txt_emb], axis=1)

outputs = super().forward(
input_ids=input_ids,
position_ids=position_ids,
@@ -1638,11 +1634,16 @@ def generate(

image_features = self.encode_images(pixel_values)
attention_mask = self.language_model.get_masks(input_ids)
position_ids = self.language_model.generate_inputs_position_ids(input_ids)
if image_features is not None:
pre_ids, pad_ids, post_ids = paddle.split(input_ids, num_or_sections=[pre_image_length, 32, -1], axis=1)
pre_txt_emb = self.language_model.chatglm.transformer.word_embeddings(pre_ids)
post_txt_emb = self.language_model.chatglm.transformer.word_embeddings(post_ids)
inputs_embeds = paddle.concat([pre_txt_emb, image_features, post_txt_emb], axis=1)

outputs = self.language_model.generate(
input_ids=input_ids,
image_features=image_features,
pre_image_length=pre_image_length,
inputs_embeds=inputs_embeds,
position_ids=position_ids,
attention_mask=attention_mask,
**generate_kwargs,
)
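
Note on the generate change above: the image splice now happens before the language model is called. Position ids are precomputed with generate_inputs_position_ids, the prompt is split around the 32 image placeholder tokens, and the projected image features are concatenated between the pre- and post-prompt word embeddings, so language_model.generate receives inputs_embeds instead of raw image_features. Because the 32 placeholder embeddings are replaced by image features spanning the same 32 positions (as the split above assumes), the sequence length is unchanged, which is why the attention mask and position ids built from input_ids still line up with inputs_embeds. A small self-contained shape check with dummy tensors (hidden size and prompt lengths are made up for illustration):

```python
import paddle

# Dummy shapes; real values come from the model config and the processor.
batch_size, pre_len, post_len, hidden = 1, 5, 7, 4096
input_ids = paddle.zeros([batch_size, pre_len + 32 + post_len], dtype="int64")
image_features = paddle.zeros([batch_size, 32, hidden])
pre_txt_emb = paddle.zeros([batch_size, pre_len, hidden])
post_txt_emb = paddle.zeros([batch_size, post_len, hidden])

# Same concatenation pattern as in the diff above.
inputs_embeds = paddle.concat([pre_txt_emb, image_features, post_txt_emb], axis=1)
assert inputs_embeds.shape[1] == input_ids.shape[1]  # lengths stay aligned
```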
