Mask based BGMV implementation (#223)
Refactors the BGMV implementation from gather-based to mask-based to
optimize performance and reduce device memory usage.
vivekgoe authored Sep 5, 2024
2 parents d0eb7d7 + 538c8f1 commit 05acb89
Showing 4 changed files with 153 additions and 55 deletions.
vllm/config.py (2 changes: 1 addition & 1 deletion)
@@ -1326,7 +1326,7 @@ def verify_with_model_config(self, model_config: ModelConfig):
                            model_config.quantization)

     def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
-        if scheduler_config.max_num_batched_tokens > 65528:
+        if not is_hpu() and scheduler_config.max_num_batched_tokens > 65528:
             raise ValueError(
                 "Due to limitations of the custom LoRA CUDA kernel, "
                 "max_num_batched_tokens must be <= 65528 when "
vllm/hpu/ops.py (47 changes: 27 additions & 20 deletions)
@@ -192,6 +192,18 @@ def prompt_attention(
     return attn_weights


+class LoraMask:
+    lora_mask = None
+
+    @staticmethod
+    def setLoraMask(mask):
+        LoraMask.lora_mask = mask
+
+    @staticmethod
+    def getLoraMask():
+        return LoraMask.lora_mask
+
+
 def dispatch_bgmv_linear(
     y: torch.Tensor,
     x: torch.Tensor,
@@ -205,29 +217,24 @@ def dispatch_bgmv_linear(
     `wa_t_all` and `wb_t_all` contains all LoRA A and LoRA B weight matrices
     stacked into single tensors, assuming same rank. HPU handles no-LoRA
     requests using zero valued A and B tensors. These zero valued tensors are
-    appended at the end of `wa_t_all` and `wb_t_all` during initialization. For
-    custom BGMV, the corresponding `wa` and `wb` for each batch is created
-    based on the lora_index of each sample.
-    For example:
-        `wa_t_all` is tensor of shape (num_loras, num_layers, lora_rank,
-        hidden_dim), where `wa_t_all[-1]` is zero valued tensor which handles
-        no-LoRA case. The `wa` tensor for a batch of size batch_Size will have
-        a shape of (batch_size, num_layers, hidden_dim, lora_rank)
-    This method avoids for-loop as well as graph breaks.
+    appended at the end of `wa_t_all` and `wb_t_all` during initialization.
+    We reshape w_a_t_all to [hidden_dim, num_layers * lora_rank]
+    and w_b_t_all to [num_layers * lora_rank, hidden_dim]. We also
+    have a loraMask of shape [batch_size, num_layers * lora_rank]
     """
-    assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}'
-    max_loras = wa_t_all.size(0)
-    # Wrap-around for negative indices
-    indices = indices % max_loras
-    wa = torch.index_select(wa_t_all, 0, indices)[:, 0, :, :].transpose(-1, -2)
-    wb = torch.index_select(wb_t_all, 0, indices)[:, 0, :, :].transpose(-1, -2)
-
-    x = x.unsqueeze(1)
+    assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}'
+    mask = LoraMask.getLoraMask()
+    wa = wa_t_all[:, 0, :, :]
+    wb = wb_t_all[:, 0, :, :].transpose(1, 2)
+    wa_shape = wa.shape
+    wb_shape = wb.shape
+    wa = wa.reshape(wa_shape[0] * wa_shape[1], wa_shape[2]).transpose(0, 1)
+    wb = wb.reshape(wb_shape[0] * wb_shape[1], wb_shape[2])
     out = x @ wa
+    assert (out.shape == mask.shape)
+    out = out * mask
     out = out @ wb
-    out = out.squeeze(1)
     y += out * scale


@@ -264,4 +271,4 @@ def dispatch_bgmv_embedding(
     x = x.unsqueeze(1)
     out = x @ wa
     out = out.squeeze(1)
-    y += out * scale
+    y += out * scale
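For readers tracing the new data flow, here is a minimal, self-contained sketch of the mask-based BGMV math. The mask layout (one lora_rank-wide block of ones per token, selecting that token's adapter) and all sizes below are assumptions inferred from the diff above, not code from this commit; in the fork itself the mask is built elsewhere (the fourth changed file is not shown in this excerpt) and handed to ops.py through LoraMask.setLoraMask() before dispatch_bgmv_linear reads it.

import torch

# Illustrative sizes only; real values come from the LoRA config.
max_loras, lora_rank, hidden_dim, num_tokens = 4, 8, 64, 16

# Stacked LoRA weights laid out as in the diff (single num_layers slice):
# wa_t_all: (max_loras, 1, lora_rank, hidden_dim),
# wb_t_all: (max_loras, 1, hidden_dim, lora_rank).
# The last adapter slot is assumed to be the zero-valued "no-LoRA" entry.
wa_t_all = torch.randn(max_loras, 1, lora_rank, hidden_dim)
wb_t_all = torch.randn(max_loras, 1, hidden_dim, lora_rank)
wa_t_all[-1].zero_()
wb_t_all[-1].zero_()

# Hypothetical per-token adapter ids (max_loras - 1 selects the no-LoRA slot).
lora_ids = torch.randint(0, max_loras, (num_tokens,))

# Mask of shape (num_tokens, max_loras * lora_rank): a rank-wide block of ones
# for each token's adapter, zeros everywhere else.
mask = torch.zeros(num_tokens, max_loras * lora_rank)
for row, idx in enumerate(lora_ids.tolist()):
    mask[row, idx * lora_rank:(idx + 1) * lora_rank] = 1.0

x = torch.randn(num_tokens, hidden_dim)
y = torch.zeros(num_tokens, hidden_dim)
scale = 1.0

# Same math as the new dispatch_bgmv_linear: every token is multiplied against
# all adapters at once, and the mask zeroes out contributions from adapters the
# token does not use, so no gather/index_select over the batch is needed.
wa = wa_t_all[:, 0, :, :].reshape(max_loras * lora_rank, hidden_dim).transpose(0, 1)
wb = wb_t_all[:, 0, :, :].transpose(1, 2).reshape(max_loras * lora_rank, hidden_dim)
out = (x @ wa) * mask          # (num_tokens, max_loras * lora_rank)
y += (out @ wb) * scale        # (num_tokens, hidden_dim)

Because every token sees the full stacked weight matrices, the per-sample index_select (and the associated gathered copies of wa/wb) from the old implementation disappears, which is where the memory saving comes from.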
vllm/lora/layers.py (11 changes: 11 additions & 0 deletions)
@@ -327,6 +327,17 @@ def set_mapping(
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         added_tokens_mask = x > self.base_layer.org_vocab_size - 1
         embedding_len = self.indices_len[3]
+        # NOTE(vgoel): These asserts can be skipped when upstreaming.
+        # Can be removed from vllm-fork also once lora functionality
+        # on Gaudi stabilizes.
+        if is_hpu():
+            emb_len = embedding_len
+            x_shape = x.shape
+            ind_shape = self.embeddings_indices[1].shape
+            assert embedding_len == x.shape[0] * x.shape[1], \
+                f"Extra Info: {emb_len}, {x_shape}, {ind_shape}"
+            assert embedding_len <= self.embeddings_indices[1].shape[0], \
+                f"Extra Info: {emb_len}, {x.shape}, {ind_shape}"
         indices = self.embeddings_indices[1][:embedding_len].view_as(x)
         full_lora_a_embeddings = F.embedding(
             x + indices,
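A small, hypothetical illustration of the invariant the new HPU asserts encode (names and sizes below are illustrative, not from the commit): with the batch padded to a fixed shape, the LoRA embedding-indices length recorded in indices_len[3] must equal batch_size * seq_len and must fit within the embeddings_indices buffer before it is sliced and reshaped onto the input.

import torch

# Hypothetical padded-batch sizes, for illustration only.
batch_size, seq_len, vocab_size = 2, 8, 32000

x = torch.randint(0, vocab_size, (batch_size, seq_len))            # token ids
embeddings_indices = torch.zeros(2, batch_size * seq_len, dtype=torch.long)
embedding_len = batch_size * seq_len                               # stands in for indices_len[3]

# The two conditions asserted in the diff above.
assert embedding_len == x.shape[0] * x.shape[1]
assert embedding_len <= embeddings_indices[1].shape[0]

# With the invariant holding, the slice reshapes cleanly onto the input.
indices = embeddings_indices[1][:embedding_len].view_as(x)
assert indices.shape == x.shape                                    # (2, 8)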
(The diff for the fourth changed file is not shown in this excerpt.)
