[Mypy] Typing lora folder #4337

Merged 5 commits on Apr 25, 2024
7 changes: 3 additions & 4 deletions .github/workflows/mypy.yaml
@@ -33,8 +33,6 @@ jobs:
- name: Mypy
run: |
mypy vllm/attention --config-file pyproject.toml
# TODO(sang): Fix nested dir
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml
@@ -44,8 +42,9 @@ jobs:
mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml

# TODO(sang): Fix nested dir
mypy vllm/model_executor/*.py --config-file pyproject.toml
# TODO(sang): Fix nested dir
# mypy vllm/lora/*.py --config-file pyproject.toml
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml

2 changes: 1 addition & 1 deletion format.sh
@@ -106,7 +106,7 @@ mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/model_executor/*.py --config-file pyproject.toml
# mypy vllm/lora/*.py --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml


CODESPELL_EXCLUDES=(
35 changes: 22 additions & 13 deletions vllm/lora/layers.py
@@ -176,6 +176,8 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: VocabParallelEmbedding) -> None:
super().__init__()
self.base_layer = base_layer
self.embeddings_slice: Optional[Tuple[int, int]]
self.embeddings_weights: Optional[torch.Tensor]

def create_lora_weights(
self,
@@ -233,9 +235,10 @@ def create_lora_weights(
self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
self.lora_a_stacked.shape[2],
)
self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
self.embeddings_indices = None
# Lazily initialized.
self.indices: torch.Tensor
self.indices_len: List[int]
self.embeddings_indices: torch.Tensor

def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
@@ -267,6 +270,7 @@ def set_lora(
self.embeddings_tensors.shape[1],
self.embeddings_tensors.shape[2]
)[self.embeddings_slice[0]:self.embeddings_slice[1]]
assert self.embeddings_weights is not None
self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

def set_mapping(
@@ -343,11 +347,12 @@ def create_lora_weights(
dtype=lora_config.lora_dtype,
device=self.device,
)

self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
self.output_dim = self.lora_b_stacked.shape[2]

# lazily initialized.
self.indices: torch.Tensor
self.indices_len: List[int]

def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
@@ -475,8 +480,9 @@ def create_lora_weights(
device=self.device,
) for _ in range(n_slices))

self.indices: Optional[torch.Tensor] = None
self.output_dim = self.lora_b_stacked[0].shape[2]
# Lazily initialized.
self.indices: torch.Tensor

def reset_lora(self, index: int):
self.lora_a_stacked[0][index] = 0
@@ -690,7 +696,8 @@ def create_lora_weights(
self.kv_proj_shard_size)
self.packed_indices: Optional[torch.Tensor] = None
self.standard_indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
# lazily initialized.
self.indices_len: List[int]

def reset_lora(self, index: int):
self.lora_a_stacked[0][index] = 0
@@ -814,8 +821,9 @@ def create_lora_weights(
dtype=lora_config.lora_dtype,
device=self.device,
)
self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
# Lazily initialized
self.indices: torch.Tensor
self.indices_len: List[int]

def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
@@ -991,9 +999,10 @@ def create_lora_weights(
dtype=self.dtype,
device=self.device,
)
self.indices = None
self.indices_padded = None
self.indices_len = None
# Lazily initialized.
self.indices: torch.Tensor
self.indices_len: List[int]
self.indices_padded: torch.Tensor

def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
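The recurring change in layers.py swaps `Optional[...] = None` attribute assignments for bare annotations with a "lazily initialized" comment. Below is a minimal sketch of that pattern with illustrative class names rather than vLLM's real ones: the bare annotation tells mypy the attribute's eventual type without forcing every later read through an `Optional` check.

```python
from typing import List, Optional

import torch


class EagerOptionalStyle:
    """The old pattern: assigning None makes the attribute Optional."""

    def __init__(self) -> None:
        # mypy infers Optional[torch.Tensor]; every read then needs an
        # assert or an "is not None" guard to pass type checking.
        self.indices: Optional[torch.Tensor] = None


class LazyAnnotationStyle:
    """The new pattern: declare the type now, assign later."""

    def __init__(self) -> None:
        # Lazily initialized: these are annotations only, so mypy keeps
        # the non-Optional types. A caller must assign them (here via
        # set_mapping) before any read, or an AttributeError is raised.
        self.indices: torch.Tensor
        self.indices_len: List[int]

    def set_mapping(self, base_indices: torch.Tensor) -> None:
        self.indices = base_indices
        self.indices_len = [base_indices.shape[-1]]
```

The trade-off is that reading such an attribute before assignment fails at runtime with an AttributeError instead of being flagged by mypy as an Optional misuse.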
28 changes: 17 additions & 11 deletions vllm/lora/lora.py
@@ -97,9 +97,9 @@ def __init__(
self,
module_name: str,
rank: int,
lora_alphas: List[int],
lora_a: List[torch.Tensor],
lora_b: List[torch.Tensor],
lora_alphas: List[Optional[int]],
lora_a: List[Optional[torch.Tensor]],
lora_b: List[Optional[torch.Tensor]],
scaling: Optional[List[float]] = None,
) -> None:
super().__init__(
@@ -108,17 +108,20 @@ def __init__(
lora_alpha=0,
lora_a=lora_a,
lora_b=lora_b,
scaling=scaling,
scaling=scaling, # type: ignore
Review comment (Collaborator, PR author):
This file's typing is pretty screwed up, and I don't know how to fix it...

Review comment (Collaborator):
Looks like scaling is Optional[List[Optional[float]]]?

Review comment (@rkooo567, Collaborator, PR author, Apr 25, 2024):
The thing is, the parent class's scaling is just Optional[float].
I talked with Antoni, and he suggested just removing the superclass init call here; let me try that out.

Review comment (Collaborator, PR author):
Hmm, I tried removing super().__init__() here, but mypy still cannot handle it because self.scaling has different types.
I will probably just ignore it for now; the logic itself is technically correct.
embeddings_tensor=None,
)
self.lora_alphas = lora_alphas
if scaling is None:
self.scaling = [
lora_alpha / self.rank for lora_alpha in self.lora_alphas
self.scaling = [ # type: ignore
lora_alpha / self.rank # type: ignore # noqa
for lora_alpha in self.lora_alphas
]

@classmethod
def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights":
def pack(
cls, loras: List[Optional["LoRALayerWeights"]]
) -> "PackedLoRALayerWeights":
"""Pack a list of LoRAs into a single LoRA.

If LoRA is None, it signifies that the submodule does not have a LoRA.
@@ -136,16 +139,19 @@ def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights":
[lora.lora_alpha if lora is not None else None for lora in loras],
[lora.lora_a if lora is not None else None for lora in loras],
[lora.lora_b if lora is not None else None for lora in loras],
scaling=[1 if lora is not None else None for lora in loras])
scaling=[
1 if lora is not None else None # type: ignore
for lora in loras
])
return obj

def optimize(self) -> "PackedLoRALayerWeights":
"""Optimize the LoRA by merging the scaling into lora_b."""
for i in range(len(self.lora_b)):
if self.scaling[i] == 1 or self.lora_b[i] is None:
if self.scaling[i] == 1 or self.lora_b[i] is None: # type: ignore
continue
self.lora_b[i] *= self.scaling[i]
self.scaling[i] = 1
self.lora_b[i] *= self.scaling[i] # type: ignore
self.scaling[i] = 1 # type: ignore
return self

@property
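The `# type: ignore` comments above come from the conflict discussed in the review thread: the parent class declares `scaling` as a single optional float, while the packed subclass stores a list. A minimal, self-contained sketch of that conflict, using simplified stand-in classes rather than vLLM's real ones:

```python
from typing import List, Optional


class LayerWeights:
    def __init__(self, scaling: Optional[float]) -> None:
        # The parent pins self.scaling to Optional[float].
        self.scaling = scaling


class PackedLayerWeights(LayerWeights):
    def __init__(self, scalings: Optional[List[Optional[float]]]) -> None:
        super().__init__(scaling=None)
        if scalings is None:
            scalings = [1.0]
        # mypy rejects this assignment: List[Optional[float]] is not
        # compatible with the inherited Optional[float], which is why the
        # diff silences it with "# type: ignore".
        self.scaling = scalings  # type: ignore
```

As the thread notes, dropping the super().__init__ call does not resolve it, because mypy still derives the attribute's type from the base class definition.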
64 changes: 34 additions & 30 deletions vllm/lora/models.py
@@ -3,7 +3,7 @@
import math
import os
import re
from typing import Callable, Dict, Hashable, List, Optional, Tuple, Type
from typing import Callable, Dict, List, Optional, Tuple, Type

import safetensors.torch
import torch
@@ -53,44 +53,46 @@ def convert_mapping(
embeddings.
indices_len: List of lengths of the above tensors.
"""
indices = list(mapping.index_mapping).copy()
embedding_indices = indices.copy()
lora_indices = indices.copy()
prompt_mapping = [
index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
embedding_indices = index_mapping_indices.copy()
lora_indices = index_mapping_indices.copy()
prompt_mapping: List[int] = [
lora_index_to_id.index(x) if x > 0 else -1
for x in mapping.prompt_mapping
]
lora_idx = None
for i in range(len(indices)):
for i in range(len(index_mapping_indices)):
# TODO index can be slow. optimize
lora_idx = (lora_index_to_id.index(indices[i])
if indices[i] > 0 else -1)
embedding_indices[i] = lora_idx if indices[i] > 0 else 0
indices[i] = i
lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
if index_mapping_indices[i] > 0 else -1)
embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
index_mapping_indices[i] = i
lora_indices[i] = lora_idx

indices = torch.tensor([indices, lora_indices, embedding_indices],
dtype=torch.long,
device="cuda")
prompt_mapping = torch.tensor(prompt_mapping,
device="cuda",
dtype=torch.long)
indices = torch.tensor(
[index_mapping_indices, lora_indices, embedding_indices],
dtype=torch.long,
device="cuda")
prompt_mapping_tensor = torch.tensor(prompt_mapping,
device="cuda",
dtype=torch.long)
embeddings_indices = torch.stack([
indices[2] * extra_vocab_size,
indices[2] * (vocab_size + extra_vocab_size)
])
embeddings_indices[embeddings_indices == -1] = max_loras - 1
base_indices = indices[1]
sampler_indices = prompt_mapping
sampler_indices = prompt_mapping_tensor
sampler_indices_padded = sampler_indices.clone()
sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
sampler_indices_padded = (
torch.arange(
0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
(sampler_indices_padded * len(sampler_indices_padded)))
indices_len = (base_indices.shape[-1], sampler_indices.shape[-1],
sampler_indices_padded.shape[-1],
embeddings_indices.shape[-1])
indices_len = [
base_indices.shape[-1], sampler_indices.shape[-1],
sampler_indices_padded.shape[-1], embeddings_indices.shape[-1]
]

return (base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, indices_len)
@@ -149,6 +151,7 @@ def from_lora_tensors(
if module_name not in loras:
lora_embeddings_tensor = None
if embeddings:
assert embedding_modules is not None
embeddings_module = next(
(k for k in embedding_modules if k in module_name),
None)
@@ -171,6 +174,7 @@ def from_lora_tensors(
else:
loras[module_name].lora_b = tensor.to(device=device,
dtype=dtype).t()
assert embedding_padding_modules is not None
if any(name in module_name
for name in embedding_padding_modules
) and target_embedding_padding is not None:
@@ -295,11 +299,10 @@ def __init__(
self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.offsets = []
# 4 is the number of indicies tensors defined above
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices
self.indices_len = [None] * 4
self.indices_len: List[Optional[int]] = [None] * 4

self.model: nn.Module = model
if hasattr(self.model, "supported_lora_modules"):
@@ -312,7 +315,7 @@ def __init__(
self._registered_loras: Dict[int, LoRAModel] = {}
# Dict instead of a Set for compatibility with LRUCache.
self._active_loras: Dict[int, None] = {}
self._last_mapping = None
self._last_mapping: Optional[LoRAMapping] = None
self._create_lora_modules()
self.model.lora_manager = self

@@ -370,7 +373,7 @@ def deactivate_lora(self, lora_id: int) -> bool:
return True
return False

def _add_lora(self, lora: LoRAModel) -> bool:
def _add_lora(self, lora: LoRAModel):
self._create_merged_loras_inplace(lora)
self._registered_loras[lora.id] = lora

@@ -418,7 +421,7 @@ def list_loras(self) -> Dict[int, LoRAModel]:
def get_lora(self, lora_id: int) -> Optional[LoRAModel]:
return self._registered_loras.get(lora_id, None)

def remove_all_loras(self) -> bool:
def remove_all_loras(self):
"""Remove all LoRAModels from the manager."""
self._registered_loras.clear()
self.lora_index_to_id = [None] * self.lora_slots
@@ -467,6 +470,7 @@ def create_dummy_lora(
continue
parts = module_name.split(".")
if module_name not in self.packed_modules:
assert embedding_modules is not None
if parts[-1] in embedding_modules:
input_dim = (module.base_layer.org_vocab_size +
self.lora_config.lora_extra_vocab_size if
@@ -500,7 +504,7 @@ def create_dummy_lora(
else:
parts = module_name.split(".")
replacements = self.packed_modules_mapping[parts[-1]]
subloras = []
subloras: List[Optional["LoRALayerWeights"]] = []
for i, r in enumerate(replacements):
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name + "." + r,
@@ -538,7 +542,7 @@ def _register_packed_modules(self, module_full_name: str) -> None:

def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
for module_name, new_module_names in self.packed_modules.items():
replacement_loras = []
replacement_loras: List[Optional[LoRALayerWeights]] = []
has_replacement = False
for r in new_module_names:
lora = lora_model.get_lora(r)
Expand All @@ -557,12 +561,12 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:

class LoRALRUCache(LRUCache[LoRAModel]):

def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable],
None]):
def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
bool]):
super().__init__(capacity)
self.deactivate_lora_fn = deactivate_lora_fn

def _on_remove(self, key: Hashable, value: LoRAModel):
def _on_remove(self, key: int, value: LoRAModel):
logger.debug(f"Removing LoRA. int id: {key}")
self.deactivate_lora_fn(key)
return super()._on_remove(key, value)
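Much of the convert_mapping diff simply stops re-binding one name to values of different types (a Python list first, a torch.Tensor later), which mypy reports as an incompatible assignment. A hedged before/after sketch of that issue, with illustrative function names rather than the real signature:

```python
from typing import List

import torch


def before(index_mapping: List[int]) -> torch.Tensor:
    indices = list(index_mapping)  # indices: List[int]
    # mypy error: incompatible types in assignment
    # (expression has type "Tensor", variable has type "List[int]")
    indices = torch.tensor(indices, dtype=torch.long)
    return indices


def after(index_mapping: List[int]) -> torch.Tensor:
    # Give the tensor its own name so each variable keeps a single type.
    index_mapping_indices: List[int] = list(index_mapping)
    indices_tensor = torch.tensor(index_mapping_indices, dtype=torch.long)
    return indices_tensor
```

The same reasoning is behind `prompt_mapping_tensor` replacing the re-bound `prompt_mapping` in the diff above.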