From 9bb7eff780d7bc0d9ad5b563859435cb00f76d39 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Fri, 26 Apr 2024 04:13:50 +0900 Subject: [PATCH] [Mypy] Typing lora folder (#4337) --- .github/workflows/mypy.yaml | 7 ++-- format.sh | 2 +- vllm/lora/layers.py | 35 ++++++++++++-------- vllm/lora/lora.py | 28 +++++++++------- vllm/lora/models.py | 64 ++++++++++++++++++++----------------- vllm/lora/worker_manager.py | 21 ++++++------ vllm/worker/model_runner.py | 4 +-- 7 files changed, 91 insertions(+), 70 deletions(-) diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index 9f1855696e20a..089c7d18ad6f2 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -33,8 +33,6 @@ jobs: - name: Mypy run: | mypy vllm/attention --config-file pyproject.toml - # TODO(sang): Fix nested dir - mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/distributed --config-file pyproject.toml mypy vllm/entrypoints --config-file pyproject.toml mypy vllm/executor --config-file pyproject.toml @@ -44,8 +42,9 @@ jobs: mypy vllm/engine --config-file pyproject.toml mypy vllm/worker --config-file pyproject.toml mypy vllm/spec_decode --config-file pyproject.toml + mypy vllm/lora --config-file pyproject.toml + # TODO(sang): Fix nested dir mypy vllm/model_executor/*.py --config-file pyproject.toml - # TODO(sang): Fix nested dir - # mypy vllm/lora/*.py --config-file pyproject.toml + mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml diff --git a/format.sh b/format.sh index bd2e9e89e1806..4ac1842daef0a 100755 --- a/format.sh +++ b/format.sh @@ -106,7 +106,7 @@ mypy vllm/engine --config-file pyproject.toml mypy vllm/worker --config-file pyproject.toml mypy vllm/spec_decode --config-file pyproject.toml mypy vllm/model_executor/*.py --config-file pyproject.toml -# mypy vllm/lora/*.py --config-file pyproject.toml +mypy vllm/lora --config-file pyproject.toml CODESPELL_EXCLUDES=( diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index aac86351b15e1..98e74168002c4 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -176,6 +176,8 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): def __init__(self, base_layer: VocabParallelEmbedding) -> None: super().__init__() self.base_layer = base_layer + self.embeddings_slice: Optional[Tuple[int, int]] + self.embeddings_weights: Optional[torch.Tensor] def create_lora_weights( self, @@ -233,9 +235,10 @@ def create_lora_weights( self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], self.lora_a_stacked.shape[2], ) - self.indices: Optional[torch.Tensor] = None - self.indices_len: Optional[List[int]] = None - self.embeddings_indices = None + # Lazily initialized. + self.indices: torch.Tensor + self.indices_len: List[int] + self.embeddings_indices: torch.Tensor def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 @@ -267,6 +270,7 @@ def set_lora( self.embeddings_tensors.shape[1], self.embeddings_tensors.shape[2] )[self.embeddings_slice[0]:self.embeddings_slice[1]] + assert self.embeddings_weights is not None self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) def set_mapping( @@ -343,11 +347,12 @@ def create_lora_weights( dtype=lora_config.lora_dtype, device=self.device, ) - - self.indices: Optional[torch.Tensor] = None - self.indices_len: Optional[List[int]] = None self.output_dim = self.lora_b_stacked.shape[2] + # lazily initialized. + self.indices: torch.Tensor + self.indices_len: List[int] + def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 self.lora_b_stacked[index] = 0 @@ -475,8 +480,9 @@ def create_lora_weights( device=self.device, ) for _ in range(n_slices)) - self.indices: Optional[torch.Tensor] = None self.output_dim = self.lora_b_stacked[0].shape[2] + # Lazily initialized. + self.indices: torch.Tensor def reset_lora(self, index: int): self.lora_a_stacked[0][index] = 0 @@ -690,7 +696,8 @@ def create_lora_weights( self.kv_proj_shard_size) self.packed_indices: Optional[torch.Tensor] = None self.standard_indices: Optional[torch.Tensor] = None - self.indices_len: Optional[List[int]] = None + # lazily initialized. + self.indices_len: List[int] def reset_lora(self, index: int): self.lora_a_stacked[0][index] = 0 @@ -814,8 +821,9 @@ def create_lora_weights( dtype=lora_config.lora_dtype, device=self.device, ) - self.indices: Optional[torch.Tensor] = None - self.indices_len: Optional[List[int]] = None + # Lazily initialized + self.indices: torch.Tensor + self.indices_len: List[int] def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 @@ -991,9 +999,10 @@ def create_lora_weights( dtype=self.dtype, device=self.device, ) - self.indices = None - self.indices_padded = None - self.indices_len = None + # Lazily initialized. + self.indices: torch.Tensor + self.indices_len: List[int] + self.indices_padded: torch.Tensor def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py index fefad16700fe3..d7794aa7cd35c 100644 --- a/vllm/lora/lora.py +++ b/vllm/lora/lora.py @@ -97,9 +97,9 @@ def __init__( self, module_name: str, rank: int, - lora_alphas: List[int], - lora_a: List[torch.Tensor], - lora_b: List[torch.Tensor], + lora_alphas: List[Optional[int]], + lora_a: List[Optional[torch.Tensor]], + lora_b: List[Optional[torch.Tensor]], scaling: Optional[List[float]] = None, ) -> None: super().__init__( @@ -108,17 +108,20 @@ def __init__( lora_alpha=0, lora_a=lora_a, lora_b=lora_b, - scaling=scaling, + scaling=scaling, # type: ignore embeddings_tensor=None, ) self.lora_alphas = lora_alphas if scaling is None: - self.scaling = [ - lora_alpha / self.rank for lora_alpha in self.lora_alphas + self.scaling = [ # type: ignore + lora_alpha / self.rank # type: ignore # noqa + for lora_alpha in self.lora_alphas ] @classmethod - def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights": + def pack( + cls, loras: List[Optional["LoRALayerWeights"]] + ) -> "PackedLoRALayerWeights": """Pack a list of LoRAs into a single LoRA. If LoRA is None, it signifies that the submodule does not have a LoRA. @@ -136,16 +139,19 @@ def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights": [lora.lora_alpha if lora is not None else None for lora in loras], [lora.lora_a if lora is not None else None for lora in loras], [lora.lora_b if lora is not None else None for lora in loras], - scaling=[1 if lora is not None else None for lora in loras]) + scaling=[ + 1 if lora is not None else None # type: ignore + for lora in loras + ]) return obj def optimize(self) -> "PackedLoRALayerWeights": """Optimize the LoRA by merging the scaling into lora_b.""" for i in range(len(self.lora_b)): - if self.scaling[i] == 1 or self.lora_b[i] is None: + if self.scaling[i] == 1 or self.lora_b[i] is None: # type: ignore continue - self.lora_b[i] *= self.scaling[i] - self.scaling[i] = 1 + self.lora_b[i] *= self.scaling[i] # type: ignore + self.scaling[i] = 1 # type: ignore return self @property diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 6bb9fee27d535..c249497a4d893 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -3,7 +3,7 @@ import math import os import re -from typing import Callable, Dict, Hashable, List, Optional, Tuple, Type +from typing import Callable, Dict, List, Optional, Tuple, Type import safetensors.torch import torch @@ -53,44 +53,46 @@ def convert_mapping( embeddings. indices_len: List of lengths of the above tensors. """ - indices = list(mapping.index_mapping).copy() - embedding_indices = indices.copy() - lora_indices = indices.copy() - prompt_mapping = [ + index_mapping_indices: List[int] = list(mapping.index_mapping).copy() + embedding_indices = index_mapping_indices.copy() + lora_indices = index_mapping_indices.copy() + prompt_mapping: List[int] = [ lora_index_to_id.index(x) if x > 0 else -1 for x in mapping.prompt_mapping ] lora_idx = None - for i in range(len(indices)): + for i in range(len(index_mapping_indices)): # TODO index can be slow. optimize - lora_idx = (lora_index_to_id.index(indices[i]) - if indices[i] > 0 else -1) - embedding_indices[i] = lora_idx if indices[i] > 0 else 0 - indices[i] = i + lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) + if index_mapping_indices[i] > 0 else -1) + embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 + index_mapping_indices[i] = i lora_indices[i] = lora_idx - indices = torch.tensor([indices, lora_indices, embedding_indices], - dtype=torch.long, - device="cuda") - prompt_mapping = torch.tensor(prompt_mapping, - device="cuda", - dtype=torch.long) + indices = torch.tensor( + [index_mapping_indices, lora_indices, embedding_indices], + dtype=torch.long, + device="cuda") + prompt_mapping_tensor = torch.tensor(prompt_mapping, + device="cuda", + dtype=torch.long) embeddings_indices = torch.stack([ indices[2] * extra_vocab_size, indices[2] * (vocab_size + extra_vocab_size) ]) embeddings_indices[embeddings_indices == -1] = max_loras - 1 base_indices = indices[1] - sampler_indices = prompt_mapping + sampler_indices = prompt_mapping_tensor sampler_indices_padded = sampler_indices.clone() sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 sampler_indices_padded = ( torch.arange( 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + (sampler_indices_padded * len(sampler_indices_padded))) - indices_len = (base_indices.shape[-1], sampler_indices.shape[-1], - sampler_indices_padded.shape[-1], - embeddings_indices.shape[-1]) + indices_len = [ + base_indices.shape[-1], sampler_indices.shape[-1], + sampler_indices_padded.shape[-1], embeddings_indices.shape[-1] + ] return (base_indices, sampler_indices, sampler_indices_padded, embeddings_indices, indices_len) @@ -149,6 +151,7 @@ def from_lora_tensors( if module_name not in loras: lora_embeddings_tensor = None if embeddings: + assert embedding_modules is not None embeddings_module = next( (k for k in embedding_modules if k in module_name), None) @@ -171,6 +174,7 @@ def from_lora_tensors( else: loras[module_name].lora_b = tensor.to(device=device, dtype=dtype).t() + assert embedding_padding_modules is not None if any(name in module_name for name in embedding_padding_modules ) and target_embedding_padding is not None: @@ -295,11 +299,10 @@ def __init__( self.max_num_batched_tokens, dtype=torch.long, device="cuda") - self.offsets = [] # 4 is the number of indicies tensors defined above # base_indices, sampler_indices, sampler_indices_padded, # embeddings_indices - self.indices_len = [None] * 4 + self.indices_len: List[Optional[int]] = [None] * 4 self.model: nn.Module = model if hasattr(self.model, "supported_lora_modules"): @@ -312,7 +315,7 @@ def __init__( self._registered_loras: Dict[int, LoRAModel] = {} # Dict instead of a Set for compatibility with LRUCache. self._active_loras: Dict[int, None] = {} - self._last_mapping = None + self._last_mapping: Optional[LoRAMapping] = None self._create_lora_modules() self.model.lora_manager = self @@ -370,7 +373,7 @@ def deactivate_lora(self, lora_id: int) -> bool: return True return False - def _add_lora(self, lora: LoRAModel) -> bool: + def _add_lora(self, lora: LoRAModel): self._create_merged_loras_inplace(lora) self._registered_loras[lora.id] = lora @@ -418,7 +421,7 @@ def list_loras(self) -> Dict[int, LoRAModel]: def get_lora(self, lora_id: int) -> Optional[LoRAModel]: return self._registered_loras.get(lora_id, None) - def remove_all_loras(self) -> bool: + def remove_all_loras(self): """Remove all LoRAModels from the manager.""" self._registered_loras.clear() self.lora_index_to_id = [None] * self.lora_slots @@ -467,6 +470,7 @@ def create_dummy_lora( continue parts = module_name.split(".") if module_name not in self.packed_modules: + assert embedding_modules is not None if parts[-1] in embedding_modules: input_dim = (module.base_layer.org_vocab_size + self.lora_config.lora_extra_vocab_size if @@ -500,7 +504,7 @@ def create_dummy_lora( else: parts = module_name.split(".") replacements = self.packed_modules_mapping[parts[-1]] - subloras = [] + subloras: List[Optional["LoRALayerWeights"]] = [] for i, r in enumerate(replacements): lora = LoRALayerWeights.create_dummy_lora_weights( module_name + "." + r, @@ -538,7 +542,7 @@ def _register_packed_modules(self, module_full_name: str) -> None: def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: for module_name, new_module_names in self.packed_modules.items(): - replacement_loras = [] + replacement_loras: List[Optional[LoRALayerWeights]] = [] has_replacement = False for r in new_module_names: lora = lora_model.get_lora(r) @@ -557,12 +561,12 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: class LoRALRUCache(LRUCache[LoRAModel]): - def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable], - None]): + def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], + bool]): super().__init__(capacity) self.deactivate_lora_fn = deactivate_lora_fn - def _on_remove(self, key: Hashable, value: LoRAModel): + def _on_remove(self, key: int, value: LoRAModel): logger.debug(f"Removing LoRA. int id: {key}") self.deactivate_lora_fn(key) return super()._on_remove(key, value) diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 5356b79537b05..ec3c10c591a18 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod, abstractproperty -from typing import Any, Dict, List, Optional, Set, Type +from typing import Any, Dict, List, Set, Type import torch @@ -37,7 +37,7 @@ def create_lora_manager( ... @abstractmethod - def set_active_loras(self, lora_requests: List[LoRARequest], + def set_active_loras(self, lora_requests: Set[LoRARequest], lora_mapping: LoRAMapping) -> None: ... @@ -54,7 +54,7 @@ def remove_lora(self, lora_id: int) -> bool: ... @abstractmethod - def remove_all_loras(self) -> bool: + def remove_all_loras(self): ... @abstractmethod @@ -81,10 +81,11 @@ def __init__( embedding_padding_modules: List[str], lora_model_cls: Type[LoRAModel] = LoRAModel, ): - self._lora_manager: Optional[LoRAModelManager] = None self._lora_model_cls = lora_model_cls self.embedding_modules = embedding_modules self.embedding_padding_modules = embedding_padding_modules + # Lazily initialized by create_lora_manager. + self._lora_manager: LoRAModelManager super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size, lora_config, device) @@ -104,7 +105,7 @@ def create_lora_manager( lora_config=self.lora_config, lora_manager_cls=self._lora_manager_cls, ) - self._lora_manager: LoRAModelManager = lora_manager + self._lora_manager = lora_manager return lora_manager.model def set_active_loras(self, lora_requests: Set[LoRARequest], @@ -188,7 +189,7 @@ def add_lora(self, lora_request: LoRARequest) -> bool: def remove_lora(self, lora_id: int) -> bool: return self._lora_manager.remove_lora(lora_id) - def remove_all_loras(self) -> bool: + def remove_all_loras(self): self._lora_manager.remove_all_loras() def list_loras(self) -> Set[int]: @@ -217,10 +218,10 @@ def create_lora_manager( lora_config=self.lora_config, max_num_batched_tokens=self.max_num_batched_tokens, ) - self._lora_manager: LRUCacheLoRAModelManager = lora_manager + self._lora_manager = lora_manager return lora_manager.model - def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: + def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None: loras_map = { lora_request.lora_int_id: lora_request for lora_request in lora_requests if lora_request @@ -237,12 +238,14 @@ def add_lora(self, lora_request: LoRARequest) -> bool: if lora_request.lora_int_id not in self.list_loras(): # Remove before we load the new lora to save memory if len(self._lora_manager) + 1 > self._lora_manager.capacity: + assert isinstance(self._lora_manager, LRUCacheLoRAModelManager) self._lora_manager.remove_oldest_lora() lora = self._load_lora(lora_request) loaded = self._lora_manager.add_lora(lora) else: # If the lora is already loaded, just touch it to # update its position in the caches - loaded = self._lora_manager.get_lora(lora_request.lora_int_id) + loaded = self._lora_manager.get_lora( + lora_request.lora_int_id) is not None self._lora_manager.activate_lora(lora_request.lora_int_id) return loaded diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 13e20d8524a1b..65996f1710a8a 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -930,10 +930,10 @@ def profile_run(self) -> None: torch.cuda.synchronize() return - def remove_all_loras(self) -> bool: + def remove_all_loras(self): if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.remove_all_loras() + self.lora_manager.remove_all_loras() def set_active_loras(self, lora_requests: Set[LoRARequest], lora_mapping: LoRAMapping) -> None: