Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Fix Fuyu tensor parallel inference #8986

Merged
merged 5 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion tests/distributed/test_pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@
(1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "mp"),
(1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "mp"),
(1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "mp"),
(1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp")
(1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp"),
# TP only models
(2, 1, 1, 0, 0, "adept/fuyu-8b", "mp"),
],
)
@fork_new_process_for_each_test
Expand Down
3 changes: 2 additions & 1 deletion vllm/model_executor/models/fuyu.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,9 @@ def __init__(self,
self.image_feature_size,
config.hidden_size,
quant_config=quant_config,
gather_output=True,
)
self.language_model = PersimmonForCausalLM(config,
self.language_model = PersimmonForCausalLM(config.text_config,
cache_config=cache_config,
quant_config=quant_config)

Expand Down
20 changes: 10 additions & 10 deletions vllm/model_executor/models/persimmon.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
import torch
from torch import nn
from transformers import PersimmonConfig
from transformers.activations import ReLUSquaredActivation

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
Expand Down Expand Up @@ -57,7 +57,7 @@ def __init__(self,
self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
config.hidden_size,
quant_config=quant_config)
self.act = ReLUSquaredActivation()
self.act = get_act_fn(config.hidden_act, quant_config)

def forward(self, hidden_states) -> torch.Tensor:
hidden_states, _ = self.dense_h_to_4h(hidden_states)
Expand Down Expand Up @@ -96,7 +96,7 @@ def __init__(self,
quant_config=quant_config,
)
self.dense = RowParallelLinear(
self.num_heads * self.head_dim,
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=True,
quant_config=quant_config,
Expand Down Expand Up @@ -213,10 +213,10 @@ def __init__(self,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.vocab_size = config.text_config.vocab_size
self.vocab_size = config.vocab_size

self.embed_tokens = VocabParallelEmbedding(
config.text_config.vocab_size, config.hidden_size)
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
PersimmonDecoderLayer(config,
cache_config=cache_config,
Expand Down Expand Up @@ -252,19 +252,19 @@ def forward(
class PersimmonForCausalLM(nn.Module):

def __init__(self,
config,
config: PersimmonConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.vocab_size = config.text_config.vocab_size
self.vocab_size = config.vocab_size
self.model = PersimmonModel(config,
cache_config=cache_config,
quant_config=quant_config)
self.lm_head = ParallelLMHead(config.text_config.vocab_size,
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
bias=False)
self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()

def forward(
Expand Down
Loading