Skip to content

Commit

Permalink
Make initialization of tokenizer and detokenizer optional (vllm-proje…
Browse files Browse the repository at this point in the history
…ct#3748)

Co-authored-by: Yun Ding <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
  • Loading branch information
3 people authored and jimpang committed Apr 25, 2024
1 parent 9565682 commit 552b2d2
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 12 deletions.
23 changes: 23 additions & 0 deletions tests/engine/test_skip_tokenizer_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pytest

from vllm.entrypoints.llm import LLM
from vllm.sampling_params import SamplingParams


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
def test_skip_tokenizer_initialization(model: str):
# This test checks if the flag skip_tokenizer_init skips the initialization
# of tokenizer and detokenizer. The generated output is expected to contain
# token ids.
llm = LLM(model=model, skip_tokenizer_init=True)
sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
with pytest.raises(ValueError) as err:
llm.generate("abc", sampling_params)
assert "prompts must be None if" in str(err.value)
outputs = llm.generate(prompt_token_ids=[[1, 2, 3]],
sampling_params=sampling_params)
assert len(outputs) > 0
completions = outputs[0].outputs
assert len(completions) > 0
assert completions[0].text == ""
assert completions[0].token_ids
7 changes: 6 additions & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ class ModelConfig:
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
skip_tokenizer_init: If true, skip initialization of tokenizer and
detokenizer.
"""

def __init__(
Expand All @@ -85,6 +87,7 @@ def __init__(
enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None,
max_logprobs: int = 5,
skip_tokenizer_init: bool = False,
) -> None:
self.model = model
self.tokenizer = tokenizer
Expand All @@ -99,14 +102,16 @@ def __init__(
self.enforce_eager = enforce_eager
self.max_context_len_to_capture = max_context_len_to_capture
self.max_logprobs = max_logprobs
self.skip_tokenizer_init = skip_tokenizer_init

self.hf_config = get_config(self.model, trust_remote_code, revision,
code_revision)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
max_model_len)
self._verify_tokenizer_mode()
if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()
self._verify_quantization()
self._verify_cuda_graph()

Expand Down
7 changes: 6 additions & 1 deletion vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class EngineArgs:
"""Arguments for vLLM engine."""
model: str
tokenizer: Optional[str] = None
skip_tokenizer_init: bool = False
tokenizer_mode: str = 'auto'
trust_remote_code: bool = False
download_dir: Optional[str] = None
Expand Down Expand Up @@ -93,6 +94,10 @@ def add_cli_args(
type=str,
default=EngineArgs.tokenizer,
help='Name or path of the huggingface tokenizer to use.')
parser.add_argument(
'--skip-tokenizer-init',
action='store_true',
help='Skip initialization of tokenizer and detokenizer')
parser.add_argument(
'--revision',
type=str,
Expand Down Expand Up @@ -453,7 +458,7 @@ def create_engine_config(self, ) -> EngineConfig:
self.code_revision, self.tokenizer_revision, self.max_model_len,
self.quantization, self.quantization_param_path,
self.enforce_eager, self.max_context_len_to_capture,
self.max_logprobs)
self.max_logprobs, self.skip_tokenizer_init)
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype,
Expand Down
29 changes: 21 additions & 8 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def __init__(
f"model={model_config.model!r}, "
f"speculative_config={speculative_config!r}, "
f"tokenizer={model_config.tokenizer!r}, "
f"skip_tokenizer_init={model_config.skip_tokenizer_init}, "
f"tokenizer_mode={model_config.tokenizer_mode}, "
f"revision={model_config.revision}, "
f"tokenizer_revision={model_config.tokenizer_revision}, "
Expand Down Expand Up @@ -132,8 +133,14 @@ def __init__(
self.decoding_config = decoding_config or DecodingConfig()
self.log_stats = log_stats

self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer)
if not self.model_config.skip_tokenizer_init:
self.tokenizer: BaseTokenizerGroup
self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer)
else:
self.detokenizer = None
self.tokenizer = None

self.seq_counter = Counter()
self.generation_config_fields = _load_generation_config_dict(
model_config)
Expand Down Expand Up @@ -187,9 +194,10 @@ def __init__(
parallel_config.disable_custom_all_reduce,
})

# Ping the tokenizer to ensure liveness if it runs in a
# different process.
self.tokenizer.ping()
if self.tokenizer:
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
self.tokenizer.ping()

# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
Expand Down Expand Up @@ -296,7 +304,7 @@ def _init_tokenizer(self, **tokenizer_init_kwargs):
trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision)
init_kwargs.update(tokenizer_init_kwargs)
self.tokenizer: BaseTokenizerGroup = get_tokenizer_group(
self.tokenizer = get_tokenizer_group(
self.parallel_config.tokenizer_pool_config, **init_kwargs)

def _verify_args(self) -> None:
Expand Down Expand Up @@ -393,8 +401,13 @@ def add_request(
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
eos_token_id = self.tokenizer.get_lora_tokenizer(
lora_request).eos_token_id
eos_token_id = None
if self.tokenizer:
eos_token_id = self.tokenizer.get_lora_tokenizer(
lora_request).eos_token_id
else:
logger.warning("Use None for EOS token id because tokenizer is "
"not initialized")
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
eos_token_id, lora_request)

Expand Down
5 changes: 3 additions & 2 deletions vllm/engine/output_processor/single_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,

# Process prompt logprobs
prompt_logprobs = outputs.prompt_logprobs
if prompt_logprobs is not None and seq_group.sampling_params.detokenize:
if prompt_logprobs is not None and \
seq_group.sampling_params.detokenize and self.detokenizer:
self.detokenizer.decode_prompt_logprobs_inplace(
seq_group, prompt_logprobs)
seq_group.prompt_logprobs = prompt_logprobs
Expand Down Expand Up @@ -105,7 +106,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
child_seqs.append((parent, parent))

for seq, _ in child_seqs:
if seq_group.sampling_params.detokenize:
if seq_group.sampling_params.detokenize and self.detokenizer:
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, seq_group.sampling_params)
else:
Expand Down
9 changes: 9 additions & 0 deletions vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ class LLM:
tokenizer: The name or path of a HuggingFace Transformers tokenizer.
tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
if available, and "slow" will always use the slow tokenizer.
skip_tokenizer_init: If true, skip initialization of tokenizer and
detokenizer. Expect valid prompt_token_ids and None for prompt
from the input.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
tensor_parallel_size: The number of GPUs to use for distributed
Expand Down Expand Up @@ -76,6 +79,7 @@ def __init__(
model: str,
tokenizer: Optional[str] = None,
tokenizer_mode: str = "auto",
skip_tokenizer_init: bool = False,
trust_remote_code: bool = False,
tensor_parallel_size: int = 1,
dtype: str = "auto",
Expand All @@ -96,6 +100,7 @@ def __init__(
model=model,
tokenizer=tokenizer,
tokenizer_mode=tokenizer_mode,
skip_tokenizer_init=skip_tokenizer_init,
trust_remote_code=trust_remote_code,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
Expand Down Expand Up @@ -160,6 +165,10 @@ def generate(
if prompts is None and prompt_token_ids is None:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")
if self.llm_engine.model_config.skip_tokenizer_init \
and prompts is not None:
raise ValueError("prompts must be None if skip_tokenizer_init "
"is True")
if isinstance(prompts, str):
# Convert a single prompt to a list.
prompts = [prompts]
Expand Down

0 comments on commit 552b2d2

Please sign in to comment.