Skip to content

Commit

Permalink
[Mypy] Typing lora folder (#4337)
Browse files Browse the repository at this point in the history
  • Loading branch information
rkooo567 authored Apr 25, 2024
1 parent f4bc4de commit b5b4a39
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 70 deletions.
7 changes: 3 additions & 4 deletions .github/workflows/mypy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ jobs:
- name: Mypy
run: |
mypy vllm/attention --config-file pyproject.toml
# TODO(sang): Fix nested dir
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml
Expand All @@ -44,8 +42,9 @@ jobs:
mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
# TODO(sang): Fix nested dir
mypy vllm/model_executor/*.py --config-file pyproject.toml
# TODO(sang): Fix nested dir
# mypy vllm/lora/*.py --config-file pyproject.toml
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
2 changes: 1 addition & 1 deletion format.sh
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/model_executor/*.py --config-file pyproject.toml
# mypy vllm/lora/*.py --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml


CODESPELL_EXCLUDES=(
Expand Down
35 changes: 22 additions & 13 deletions vllm/lora/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: VocabParallelEmbedding) -> None:
super().__init__()
self.base_layer = base_layer
self.embeddings_slice: Optional[Tuple[int, int]]
self.embeddings_weights: Optional[torch.Tensor]

def create_lora_weights(
self,
Expand Down Expand Up @@ -233,9 +235,10 @@ def create_lora_weights(
self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
self.lora_a_stacked.shape[2],
)
self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
self.embeddings_indices = None
# Lazily initialized.
self.indices: torch.Tensor
self.indices_len: List[int]
self.embeddings_indices: torch.Tensor

def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
Expand Down Expand Up @@ -267,6 +270,7 @@ def set_lora(
self.embeddings_tensors.shape[1],
self.embeddings_tensors.shape[2]
)[self.embeddings_slice[0]:self.embeddings_slice[1]]
assert self.embeddings_weights is not None
self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

def set_mapping(
Expand Down Expand Up @@ -343,11 +347,12 @@ def create_lora_weights(
dtype=lora_config.lora_dtype,
device=self.device,
)

self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
self.output_dim = self.lora_b_stacked.shape[2]

# lazily initialized.
self.indices: torch.Tensor
self.indices_len: List[int]

def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
Expand Down Expand Up @@ -475,8 +480,9 @@ def create_lora_weights(
device=self.device,
) for _ in range(n_slices))

self.indices: Optional[torch.Tensor] = None
self.output_dim = self.lora_b_stacked[0].shape[2]
# Lazily initialized.
self.indices: torch.Tensor

def reset_lora(self, index: int):
self.lora_a_stacked[0][index] = 0
Expand Down Expand Up @@ -690,7 +696,8 @@ def create_lora_weights(
self.kv_proj_shard_size)
self.packed_indices: Optional[torch.Tensor] = None
self.standard_indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
# lazily initialized.
self.indices_len: List[int]

def reset_lora(self, index: int):
self.lora_a_stacked[0][index] = 0
Expand Down Expand Up @@ -814,8 +821,9 @@ def create_lora_weights(
dtype=lora_config.lora_dtype,
device=self.device,
)
self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
# Lazily initialized
self.indices: torch.Tensor
self.indices_len: List[int]

def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
Expand Down Expand Up @@ -991,9 +999,10 @@ def create_lora_weights(
dtype=self.dtype,
device=self.device,
)
self.indices = None
self.indices_padded = None
self.indices_len = None
# Lazily initialized.
self.indices: torch.Tensor
self.indices_len: List[int]
self.indices_padded: torch.Tensor

def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
Expand Down
28 changes: 17 additions & 11 deletions vllm/lora/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ def __init__(
self,
module_name: str,
rank: int,
lora_alphas: List[int],
lora_a: List[torch.Tensor],
lora_b: List[torch.Tensor],
lora_alphas: List[Optional[int]],
lora_a: List[Optional[torch.Tensor]],
lora_b: List[Optional[torch.Tensor]],
scaling: Optional[List[float]] = None,
) -> None:
super().__init__(
Expand All @@ -108,17 +108,20 @@ def __init__(
lora_alpha=0,
lora_a=lora_a,
lora_b=lora_b,
scaling=scaling,
scaling=scaling, # type: ignore
embeddings_tensor=None,
)
self.lora_alphas = lora_alphas
if scaling is None:
self.scaling = [
lora_alpha / self.rank for lora_alpha in self.lora_alphas
self.scaling = [ # type: ignore
lora_alpha / self.rank # type: ignore # noqa
for lora_alpha in self.lora_alphas
]

@classmethod
def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights":
def pack(
cls, loras: List[Optional["LoRALayerWeights"]]
) -> "PackedLoRALayerWeights":
"""Pack a list of LoRAs into a single LoRA.
If LoRA is None, it signifies that the submodule does not have a LoRA.
Expand All @@ -136,16 +139,19 @@ def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights":
[lora.lora_alpha if lora is not None else None for lora in loras],
[lora.lora_a if lora is not None else None for lora in loras],
[lora.lora_b if lora is not None else None for lora in loras],
scaling=[1 if lora is not None else None for lora in loras])
scaling=[
1 if lora is not None else None # type: ignore
for lora in loras
])
return obj

def optimize(self) -> "PackedLoRALayerWeights":
"""Optimize the LoRA by merging the scaling into lora_b."""
for i in range(len(self.lora_b)):
if self.scaling[i] == 1 or self.lora_b[i] is None:
if self.scaling[i] == 1 or self.lora_b[i] is None: # type: ignore
continue
self.lora_b[i] *= self.scaling[i]
self.scaling[i] = 1
self.lora_b[i] *= self.scaling[i] # type: ignore
self.scaling[i] = 1 # type: ignore
return self

@property
Expand Down
64 changes: 34 additions & 30 deletions vllm/lora/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import math
import os
import re
from typing import Callable, Dict, Hashable, List, Optional, Tuple, Type
from typing import Callable, Dict, List, Optional, Tuple, Type

import safetensors.torch
import torch
Expand Down Expand Up @@ -53,44 +53,46 @@ def convert_mapping(
embeddings.
indices_len: List of lengths of the above tensors.
"""
indices = list(mapping.index_mapping).copy()
embedding_indices = indices.copy()
lora_indices = indices.copy()
prompt_mapping = [
index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
embedding_indices = index_mapping_indices.copy()
lora_indices = index_mapping_indices.copy()
prompt_mapping: List[int] = [
lora_index_to_id.index(x) if x > 0 else -1
for x in mapping.prompt_mapping
]
lora_idx = None
for i in range(len(indices)):
for i in range(len(index_mapping_indices)):
# TODO index can be slow. optimize
lora_idx = (lora_index_to_id.index(indices[i])
if indices[i] > 0 else -1)
embedding_indices[i] = lora_idx if indices[i] > 0 else 0
indices[i] = i
lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
if index_mapping_indices[i] > 0 else -1)
embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
index_mapping_indices[i] = i
lora_indices[i] = lora_idx

indices = torch.tensor([indices, lora_indices, embedding_indices],
dtype=torch.long,
device="cuda")
prompt_mapping = torch.tensor(prompt_mapping,
device="cuda",
dtype=torch.long)
indices = torch.tensor(
[index_mapping_indices, lora_indices, embedding_indices],
dtype=torch.long,
device="cuda")
prompt_mapping_tensor = torch.tensor(prompt_mapping,
device="cuda",
dtype=torch.long)
embeddings_indices = torch.stack([
indices[2] * extra_vocab_size,
indices[2] * (vocab_size + extra_vocab_size)
])
embeddings_indices[embeddings_indices == -1] = max_loras - 1
base_indices = indices[1]
sampler_indices = prompt_mapping
sampler_indices = prompt_mapping_tensor
sampler_indices_padded = sampler_indices.clone()
sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
sampler_indices_padded = (
torch.arange(
0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
(sampler_indices_padded * len(sampler_indices_padded)))
indices_len = (base_indices.shape[-1], sampler_indices.shape[-1],
sampler_indices_padded.shape[-1],
embeddings_indices.shape[-1])
indices_len = [
base_indices.shape[-1], sampler_indices.shape[-1],
sampler_indices_padded.shape[-1], embeddings_indices.shape[-1]
]

return (base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, indices_len)
Expand Down Expand Up @@ -149,6 +151,7 @@ def from_lora_tensors(
if module_name not in loras:
lora_embeddings_tensor = None
if embeddings:
assert embedding_modules is not None
embeddings_module = next(
(k for k in embedding_modules if k in module_name),
None)
Expand All @@ -171,6 +174,7 @@ def from_lora_tensors(
else:
loras[module_name].lora_b = tensor.to(device=device,
dtype=dtype).t()
assert embedding_padding_modules is not None
if any(name in module_name
for name in embedding_padding_modules
) and target_embedding_padding is not None:
Expand Down Expand Up @@ -295,11 +299,10 @@ def __init__(
self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.offsets = []
# 4 is the number of indicies tensors defined above
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices
self.indices_len = [None] * 4
self.indices_len: List[Optional[int]] = [None] * 4

self.model: nn.Module = model
if hasattr(self.model, "supported_lora_modules"):
Expand All @@ -312,7 +315,7 @@ def __init__(
self._registered_loras: Dict[int, LoRAModel] = {}
# Dict instead of a Set for compatibility with LRUCache.
self._active_loras: Dict[int, None] = {}
self._last_mapping = None
self._last_mapping: Optional[LoRAMapping] = None
self._create_lora_modules()
self.model.lora_manager = self

Expand Down Expand Up @@ -370,7 +373,7 @@ def deactivate_lora(self, lora_id: int) -> bool:
return True
return False

def _add_lora(self, lora: LoRAModel) -> bool:
def _add_lora(self, lora: LoRAModel):
self._create_merged_loras_inplace(lora)
self._registered_loras[lora.id] = lora

Expand Down Expand Up @@ -418,7 +421,7 @@ def list_loras(self) -> Dict[int, LoRAModel]:
def get_lora(self, lora_id: int) -> Optional[LoRAModel]:
return self._registered_loras.get(lora_id, None)

def remove_all_loras(self) -> bool:
def remove_all_loras(self):
"""Remove all LoRAModels from the manager."""
self._registered_loras.clear()
self.lora_index_to_id = [None] * self.lora_slots
Expand Down Expand Up @@ -467,6 +470,7 @@ def create_dummy_lora(
continue
parts = module_name.split(".")
if module_name not in self.packed_modules:
assert embedding_modules is not None
if parts[-1] in embedding_modules:
input_dim = (module.base_layer.org_vocab_size +
self.lora_config.lora_extra_vocab_size if
Expand Down Expand Up @@ -500,7 +504,7 @@ def create_dummy_lora(
else:
parts = module_name.split(".")
replacements = self.packed_modules_mapping[parts[-1]]
subloras = []
subloras: List[Optional["LoRALayerWeights"]] = []
for i, r in enumerate(replacements):
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name + "." + r,
Expand Down Expand Up @@ -538,7 +542,7 @@ def _register_packed_modules(self, module_full_name: str) -> None:

def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
for module_name, new_module_names in self.packed_modules.items():
replacement_loras = []
replacement_loras: List[Optional[LoRALayerWeights]] = []
has_replacement = False
for r in new_module_names:
lora = lora_model.get_lora(r)
Expand All @@ -557,12 +561,12 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:

class LoRALRUCache(LRUCache[LoRAModel]):

def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable],
None]):
def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
bool]):
super().__init__(capacity)
self.deactivate_lora_fn = deactivate_lora_fn

def _on_remove(self, key: Hashable, value: LoRAModel):
def _on_remove(self, key: int, value: LoRAModel):
logger.debug(f"Removing LoRA. int id: {key}")
self.deactivate_lora_fn(key)
return super()._on_remove(key, value)
Expand Down
Loading

0 comments on commit b5b4a39

Please sign in to comment.