Commit 4ad521d

[Core] Add generic typing to LRUCache (#3511)

njhill authored Mar 20, 2024
1 parent 9474e89

Showing 4 changed files with 29 additions and 22 deletions.
6 changes: 3 additions & 3 deletions vllm/lora/models.py
@@ -4,7 +4,7 @@
 import math
 import os
 import re
-from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type)
+from typing import (Callable, Dict, Hashable, List, Optional, Tuple, Type)
 
 import safetensors.torch
 import torch
@@ -535,14 +535,14 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
                                  replacement_loras)
 
 
-class LoRALRUCache(LRUCache):
+class LoRALRUCache(LRUCache[LoRAModel]):
 
     def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable],
                                                                    None]):
         super().__init__(capacity)
         self.deactivate_lora_fn = deactivate_lora_fn
 
-    def _on_remove(self, key: Hashable, value: Any):
+    def _on_remove(self, key: Hashable, value: LoRAModel):
         logger.debug(f"Removing LoRA. int id: {key}")
         self.deactivate_lora_fn(key)
         return super()._on_remove(key, value)
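Note on the hunk above: parameterizing the base class is what lets the `Any` import go away, since the subclass pins the value type that `_on_remove` receives. A minimal, self-contained sketch of the pattern (illustrative names, not vLLM code):

    from typing import Generic, Hashable, TypeVar

    T = TypeVar("T")

    class Cache(Generic[T]):
        def _on_remove(self, key: Hashable, value: T) -> None:
            pass

    class Model:
        pass

    class ModelCache(Cache[Model]):
        # Because the base is Cache[Model], type checkers know `value`
        # is a Model here; no Any annotation is needed.
        def _on_remove(self, key: Hashable, value: Model) -> None:
            print(f"evicting model for key {key}")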
19 changes: 13 additions & 6 deletions vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
@@ -22,27 +22,34 @@ def get_max_input_len(self,
         pass
 
     @abstractmethod
-    def encode(self, prompt: str, request_id: Optional[str],
-               lora_request: Optional[LoRARequest]) -> List[int]:
+    def encode(self,
+               prompt: str,
+               request_id: Optional[str] = None,
+               lora_request: Optional[LoRARequest] = None) -> List[int]:
         """Encode a prompt using the tokenizer group."""
         pass
 
     @abstractmethod
-    async def encode_async(self, prompt: str, request_id: Optional[str],
-                           lora_request: Optional[LoRARequest]) -> List[int]:
+    async def encode_async(
+            self,
+            prompt: str,
+            request_id: Optional[str] = None,
+            lora_request: Optional[LoRARequest] = None) -> List[int]:
         """Encode a prompt using the tokenizer group."""
         pass
 
     @abstractmethod
     def get_lora_tokenizer(
             self,
-            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+            lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
         """Get a tokenizer for a LoRA request."""
         pass
 
     @abstractmethod
     async def get_lora_tokenizer_async(
             self,
-            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+            lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
         """Get a tokenizer for a LoRA request."""
        pass
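With the new `= None` defaults, callers may omit `request_id` and `lora_request` entirely. A hedged sketch of what a conforming implementation's call site looks like (toy class, not from the repository):

    from typing import Any, List, Optional

    class WhitespaceTokenizerGroup:
        """Toy stand-in for a BaseTokenizerGroup subclass (illustrative)."""

        def encode(self,
                   prompt: str,
                   request_id: Optional[str] = None,
                   lora_request: Optional[Any] = None) -> List[int]:
            # Toy scheme: one fake token id (the word length) per word.
            return [len(word) for word in prompt.split()]

    group = WhitespaceTokenizerGroup()
    print(group.encode("hello world"))  # both optional arguments omitted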
6 changes: 2 additions & 4 deletions vllm/transformers_utils/tokenizer_group/tokenizer_group.py
@@ -21,10 +21,8 @@ def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
         self.enable_lora = enable_lora
         self.max_input_length = max_input_length
         self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
-        if enable_lora:
-            self.lora_tokenizers = LRUCache(capacity=max_num_seqs)
-        else:
-            self.lora_tokenizers = None
+        self.lora_tokenizers = LRUCache[PreTrainedTokenizer](
+            capacity=max_num_seqs) if enable_lora else None
 
     def ping(self) -> bool:
         """Check if the tokenizer group is alive."""
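The rewritten constructor parameterizes the cache at the point of construction, so later lookups on `lora_tokenizers` are typed as returning an optional tokenizer rather than `Any`. A sketch of the same pattern with a stand-in type (illustrative; assumes the generic `LRUCache` from vllm/utils.py shown below):

    from typing import Optional

    from vllm.utils import LRUCache

    class FakeTokenizer:
        """Illustrative stand-in for transformers.PreTrainedTokenizer."""

    enable_lora = True
    max_num_seqs = 8
    lora_tokenizers: Optional[LRUCache[FakeTokenizer]] = (
        LRUCache[FakeTokenizer](capacity=max_num_seqs)
        if enable_lora else None)
    if lora_tokenizers is not None:
        lora_tokenizers.put(0, FakeTokenizer())
        tok = lora_tokenizers.get(0)  # inferred as Optional[FakeTokenizer]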
20 changes: 11 additions & 9 deletions vllm/utils.py
@@ -5,7 +5,7 @@
 import uuid
 import gc
 from platform import uname
-from typing import List, Tuple, Union
+from typing import List, Tuple, Union, Generic
 from packaging.version import parse, Version
 
 import psutil
@@ -53,10 +53,10 @@ def reset(self) -> None:
         self.counter = 0
 
 
-class LRUCache:
+class LRUCache(Generic[T]):
 
     def __init__(self, capacity: int):
-        self.cache = OrderedDict()
+        self.cache = OrderedDict[Hashable, T]()
         self.capacity = capacity
 
     def __contains__(self, key: Hashable) -> bool:
@@ -65,10 +65,10 @@ def __contains__(self, key: Hashable) -> bool:
     def __len__(self) -> int:
         return len(self.cache)
 
-    def __getitem__(self, key: Hashable) -> Any:
+    def __getitem__(self, key: Hashable) -> T:
         return self.get(key)
 
-    def __setitem__(self, key: Hashable, value: Any) -> None:
+    def __setitem__(self, key: Hashable, value: T) -> None:
         self.put(key, value)
 
     def __delitem__(self, key: Hashable) -> None:
@@ -77,20 +77,22 @@ def __delitem__(self, key: Hashable) -> None:
     def touch(self, key: Hashable) -> None:
         self.cache.move_to_end(key)
 
-    def get(self, key: Hashable, default_value: Optional[Any] = None) -> int:
+    def get(self,
+            key: Hashable,
+            default_value: Optional[T] = None) -> Optional[T]:
         if key in self.cache:
             value = self.cache[key]
             self.cache.move_to_end(key)
         else:
             value = default_value
         return value
 
-    def put(self, key: Hashable, value: Any) -> None:
+    def put(self, key: Hashable, value: T) -> None:
         self.cache[key] = value
         self.cache.move_to_end(key)
         self._remove_old_if_needed()
 
-    def _on_remove(self, key: Hashable, value: Any):
+    def _on_remove(self, key: Hashable, value: T):
         pass
 
     def remove_oldest(self):
@@ -103,7 +105,7 @@ def _remove_old_if_needed(self) -> None:
         while len(self.cache) > self.capacity:
             self.remove_oldest()
 
-    def pop(self, key: int, default_value: Optional[Any] = None) -> Any:
+    def pop(self, key: Hashable, default_value: Optional[Any] = None) -> T:
         run_on_remove = key in self.cache
         value = self.cache.pop(key, default_value)
         if run_on_remove:
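To close the loop, a short usage sketch of the now-generic cache (assumes the class above is importable as vllm.utils.LRUCache; eviction follows least-recently-used order):

    from vllm.utils import LRUCache

    cache = LRUCache[str](capacity=2)
    cache.put("a", "alpha")
    cache.put("b", "beta")
    cache.get("a")           # touches "a"; "b" is now least recently used
    cache.put("c", "gamma")  # over capacity: remove_oldest() evicts "b"
    assert cache.get("b") is None      # get() is now typed Optional[str]
    assert cache.pop("c") == "gamma"   # pop() returns T and runs _on_remove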
