
[Bugfix] Correct adapter usage for cohere and jamba (vllm-project#8292)
vladislavkruglikov authored Sep 9, 2024
1 parent 58fcc85 commit f9b4a2d
Showing 2 changed files with 6 additions and 3 deletions.
vllm/model_executor/models/commandr.py (5 changes: 3 additions & 2 deletions)
@@ -47,6 +47,8 @@
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors
 
+from .interfaces import SupportsLoRA
+
 
 @torch.compile
 def layer_norm_func(hidden_states, weight, variance_epsilon):
@@ -292,8 +294,7 @@ def forward(
         return hidden_states
 
 
-class CohereForCausalLM(nn.Module):
-
+class CohereForCausalLM(nn.Module, SupportsLoRA):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
vllm/model_executor/models/jamba.py (4 changes: 3 additions & 1 deletion)
@@ -38,6 +38,8 @@
 from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
                                       _get_graph_batch_size)
 
+from .interfaces import SupportsLoRA
+
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
@@ -539,7 +541,7 @@ def forward(
         return hidden_states
 
 
-class JambaForCausalLM(nn.Module, HasInnerState):
+class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
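
With SupportsLoRA declared on these classes, vLLM's LoRA machinery recognizes the Cohere (Command R) and Jamba models as adapter-capable. The sketch below is not part of this commit; it only illustrates, assuming a compatible LoRA adapter is available at a placeholder path, how such a model might be run with an adapter through vLLM's offline API. The model name, adapter name, and adapter path are placeholders.

# Illustrative only (not part of this commit). Model name and adapter path
# are placeholders; enable_lora=True activates the SupportsLoRA code path.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="CohereForAI/c4ai-command-r-v01", enable_lora=True)

outputs = llm.generate(
    ["Summarize the benefits of LoRA adapters in one sentence."],
    SamplingParams(temperature=0.0, max_tokens=64),
    # LoRARequest(adapter name, unique integer id, path to adapter weights)
    lora_request=LoRARequest("example_adapter", 1, "/path/to/lora_adapter"),
)
print(outputs[0].outputs[0].text)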
