diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 238da256b7cdc..6fe07b69b3203 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -4,7 +4,7 @@
 import math
 import os
 import re
-from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type)
+from typing import (Callable, Dict, Hashable, List, Optional, Tuple, Type)
 
 import safetensors.torch
 import torch
@@ -535,14 +535,14 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
                 replacement_loras)
 
 
-class LoRALRUCache(LRUCache):
+class LoRALRUCache(LRUCache[LoRAModel]):
 
     def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable],
                                                                    None]):
         super().__init__(capacity)
         self.deactivate_lora_fn = deactivate_lora_fn
 
-    def _on_remove(self, key: Hashable, value: Any):
+    def _on_remove(self, key: Hashable, value: LoRAModel):
         logger.debug(f"Removing LoRA. int id: {key}")
         self.deactivate_lora_fn(key)
         return super()._on_remove(key, value)
diff --git a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
index 99518a606fabe..3cce96e06d1a0 100644
--- a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
+++ b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
@@ -22,27 +22,34 @@ def get_max_input_len(self,
         pass
 
     @abstractmethod
-    def encode(self, prompt: str, request_id: Optional[str],
-               lora_request: Optional[LoRARequest]) -> List[int]:
+    def encode(self,
+               prompt: str,
+               request_id: Optional[str] = None,
+               lora_request: Optional[LoRARequest] = None) -> List[int]:
         """Encode a prompt using the tokenizer group."""
         pass
 
     @abstractmethod
-    async def encode_async(self, prompt: str, request_id: Optional[str],
-                           lora_request: Optional[LoRARequest]) -> List[int]:
+    async def encode_async(
+            self,
+            prompt: str,
+            request_id: Optional[str] = None,
+            lora_request: Optional[LoRARequest] = None) -> List[int]:
         """Encode a prompt using the tokenizer group."""
         pass
 
     @abstractmethod
     def get_lora_tokenizer(
             self,
-            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+            lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
         """Get a tokenizer for a LoRA request."""
         pass
 
     @abstractmethod
     async def get_lora_tokenizer_async(
             self,
-            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+            lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
         """Get a tokenizer for a LoRA request."""
         pass
diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
index 3af1334cb5ede..ec20d0fb713a4 100644
--- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
+++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
@@ -21,10 +21,8 @@ def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
         self.enable_lora = enable_lora
         self.max_input_length = max_input_length
         self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
-        if enable_lora:
-            self.lora_tokenizers = LRUCache(capacity=max_num_seqs)
-        else:
-            self.lora_tokenizers = None
+        self.lora_tokenizers = LRUCache[PreTrainedTokenizer](
+            capacity=max_num_seqs) if enable_lora else None
 
     def ping(self) -> bool:
         """Check if the tokenizer group is alive."""
diff --git a/vllm/utils.py b/vllm/utils.py
index 7c73062e809f3..8fa372b5f7f09 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -5,7 +5,7 @@
 import uuid
 import gc
 from platform import uname
-from typing import List, Tuple, Union
+from typing import List, Tuple, Union, Generic
 
 from packaging.version import parse, Version
 import psutil
@@ -53,10 +53,10 @@ def reset(self) -> None:
         self.counter = 0
 
 
-class LRUCache:
+class LRUCache(Generic[T]):
 
     def __init__(self, capacity: int):
-        self.cache = OrderedDict()
+        self.cache = OrderedDict[Hashable, T]()
         self.capacity = capacity
 
     def __contains__(self, key: Hashable) -> bool:
@@ -65,10 +65,10 @@ def __contains__(self, key: Hashable) -> bool:
     def __len__(self) -> int:
         return len(self.cache)
 
-    def __getitem__(self, key: Hashable) -> Any:
+    def __getitem__(self, key: Hashable) -> T:
        return self.get(key)
 
-    def __setitem__(self, key: Hashable, value: Any) -> None:
+    def __setitem__(self, key: Hashable, value: T) -> None:
         self.put(key, value)
 
     def __delitem__(self, key: Hashable) -> None:
@@ -77,7 +77,9 @@ def __delitem__(self, key: Hashable) -> None:
     def touch(self, key: Hashable) -> None:
         self.cache.move_to_end(key)
 
-    def get(self, key: Hashable, default_value: Optional[Any] = None) -> int:
+    def get(self,
+            key: Hashable,
+            default_value: Optional[T] = None) -> Optional[T]:
         if key in self.cache:
             value = self.cache[key]
             self.cache.move_to_end(key)
@@ -85,12 +87,12 @@ def get(self, key: Hashable, default_value: Optional[Any] = None) -> int:
             value = default_value
         return value
 
-    def put(self, key: Hashable, value: Any) -> None:
+    def put(self, key: Hashable, value: T) -> None:
         self.cache[key] = value
         self.cache.move_to_end(key)
         self._remove_old_if_needed()
 
-    def _on_remove(self, key: Hashable, value: Any):
+    def _on_remove(self, key: Hashable, value: T):
         pass
 
     def remove_oldest(self):
@@ -103,7 +105,7 @@ def _remove_old_if_needed(self) -> None:
         while len(self.cache) > self.capacity:
             self.remove_oldest()
 
-    def pop(self, key: int, default_value: Optional[Any] = None) -> Any:
+    def pop(self, key: Hashable, default_value: Optional[Any] = None) -> T:
         run_on_remove = key in self.cache
         value = self.cache.pop(key, default_value)
         if run_on_remove:
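
Minimal usage sketch (not part of the patch) showing how the now-generic LRUCache is parameterized and how the _on_remove hook fires on eviction, mirroring LoRALRUCache above. It assumes a vllm checkout that includes this change; EvictionLoggingCache and the string values are hypothetical, chosen purely for illustration.

# Sketch only, not part of the diff. Assumes vllm.utils.LRUCache as patched above;
# EvictionLoggingCache is a hypothetical subclass used for illustration.
from typing import Hashable

from vllm.utils import LRUCache


class EvictionLoggingCache(LRUCache[str]):
    """Mirrors LoRALRUCache: override _on_remove to react to removals."""

    def _on_remove(self, key: Hashable, value: str):
        # Hook invoked by the cache when an entry is removed
        # (see pop() / remove_oldest()).
        print(f"evicted {key!r} -> {value!r}")
        return super()._on_remove(key, value)


cache = EvictionLoggingCache(capacity=2)
cache.put(1, "a")
cache[2] = "b"        # __setitem__ delegates to put()
_ = cache.get(1)      # get() marks key 1 as most recently used
cache.put(3, "c")     # exceeds capacity, so key 2 is evicted and logged
assert 1 in cache and 3 in cache and 2 not in cache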