Skip to content

Commit

Permalink
fix batch error for llava-hd (#98)
Browse files Browse the repository at this point in the history
  • Loading branch information
caoshiyi authored Jan 25, 2024
1 parent 2395005 commit 0147f94
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions python/sglang/srt/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,24 +112,28 @@ def forward(
need_vision = need_vision & has_pixel

if need_vision.any():
pixel_values = torch.tensor(
np.array([pixel_values[i] for i in range(bs) if need_vision[i]]),
device=self.vision_tower.device,
)
pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]

########## Encode Image ########

if pixel_values.ndim == 5:
if pixel_values[0].ndim == 4:
# llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
concat_images = torch.cat(
[image for image in pixel_values], dim=0
) # ndim=4
np.concatenate(pixel_values, axis=0)
# ndim=4
concat_images = torch.tensor(
np.concatenate(pixel_values, axis=0),
device=self.vision_tower.device,
)
image_features = self.encode_images(concat_images)
split_sizes = [image.shape[0] for image in pixel_values]
image_features = torch.split(image_features, split_sizes, dim=0)
# hd image_features: BS, num_patch, 576, 4096
else:
# normal pixel: BS, C=3, H=336, W=336
pixel_values = torch.tensor(
np.array(pixel_values), device=self.vision_tower.device
)
image_features = self.encode_images(pixel_values)
# image_features: BS, 576, 4096

Expand Down

0 comments on commit 0147f94

Please sign in to comment.