Skip to content

Commit

Permalink
[V1] Add missing tokenizer options for Detokenizer (vllm-project#10288)
Browse files Browse the repository at this point in the history

Signed-off-by: Roger Wang <[email protected]>
  • Loading branch information
ywang96 authored Nov 13, 2024
1 parent d909acf commit bb7991a
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
11 changes: 9 additions & 2 deletions vllm/v1/engine/detokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,17 @@ def _get_next_output_text(self, finished: bool, delta: bool) -> str:

class Detokenizer:

def __init__(self, tokenizer_name: str):
def __init__(self,
tokenizer_name: str,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
revision: Optional[str] = None):
# TODO: once we support LoRA, we should pass the tokenizer
# here. We currently have two copies (this + in the LLMEngine).
self.tokenizer = get_tokenizer(tokenizer_name)
self.tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
tokenizer_mode=tokenizer_mode,
trust_remote_code=trust_remote_code,
revision=revision)

# Request id -> IncrementalDetokenizer
self.request_states: Dict[str, IncrementalDetokenizer] = {}
Expand Down
7 changes: 6 additions & 1 deletion vllm/v1/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,12 @@ def __init__(
input_registry)

# Detokenizer (converts EngineCoreOutputs --> RequestOutput)
self.detokenizer = Detokenizer(vllm_config.model_config.tokenizer)
self.detokenizer = Detokenizer(
tokenizer_name=vllm_config.model_config.tokenizer,
tokenizer_mode=vllm_config.model_config.tokenizer_mode,
trust_remote_code=vllm_config.model_config.trust_remote_code,
revision=vllm_config.model_config.tokenizer_revision,
)

# EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
self.engine_core = EngineCoreClient.make_client(
Expand Down

0 comments on commit bb7991a

Please sign in to comment.