feat: allow configuration of the max soft prompt length (opendatahub-io#33)

Instead of a hard-coded default of 256, the default soft prompt length
is now 50% of the max sequence length.
The env var MAX_PROMPT_PREFIX_LENGTH can be used to override this
default if desired.


Signed-off-by: Travis Johnson <[email protected]>
Co-authored-by: TRAVIS JOHNSON <[email protected]>
joerunde and tjohnson31415 authored Feb 21, 2024
1 parent ac1f655 commit 1f4cfbe
Showing 4 changed files with 24 additions and 8 deletions.
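
For orientation before the diffs, here is a minimal sketch of how the new default and the MAX_PROMPT_PREFIX_LENGTH override resolve. The standalone helper and its name resolve_max_prompt_prefix_length are illustrative only; the logic mirrors what this commit adds to Model.__init__ in model.py below.

import math
import os
from typing import Optional


def resolve_max_prompt_prefix_length(max_seq_length: Optional[int]) -> int:
    # Illustrative helper mirroring the logic this commit adds to Model.__init__.
    if max_seq_length is None:
        # model.py falls back to 2048 if no max sequence length was passed through
        max_seq_length = 2048

    # default: 50% of the max sequence length, rounded up
    max_prompt_prefix_length = math.ceil(max_seq_length * 0.5)

    if (env_value := os.getenv("MAX_PROMPT_PREFIX_LENGTH")):
        try:
            override = int(env_value)
        except ValueError as exc:
            raise ValueError("Invalid value for MAX_PROMPT_PREFIX_LENGTH") from exc
        if override > max_seq_length - 1:
            raise ValueError(
                f"Value for the MAX_PROMPT_PREFIX_LENGTH ({override}) cannot be "
                f"larger than the max sequence length - 1 ({max_seq_length - 1})"
            )
        max_prompt_prefix_length = override

    return max_prompt_prefix_length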
2 changes: 1 addition & 1 deletion server/text_generation_server/models/causal_lm.py
@@ -559,7 +559,7 @@ def __init__(
            model_path, AutoModelForCausalLM, dtype, quantize, model_config, max_sequence_length
        )

-        super(CausalLM, self).__init__(inference_engine, dtype)
+        super(CausalLM, self).__init__(inference_engine, dtype, max_sequence_length)

        if self.model.config.pad_token_id is not None:
            self.tokenizer.pad_token_id = self.model.config.pad_token_id
2 changes: 1 addition & 1 deletion server/text_generation_server/models/flash_causal_lm.py
@@ -385,7 +385,7 @@ def __init__(
            model_path, auto_model_class, dtype, quantize, model_config, max_sequence_length
        )

-        super(FlashCausalLM, self).__init__(inference_engine, dtype)
+        super(FlashCausalLM, self).__init__(inference_engine, dtype, max_sequence_length)
        self.use_position_ids = True

        if self.model.config.pad_token_id is not None:
26 changes: 21 additions & 5 deletions server/text_generation_server/models/model.py
@@ -1,4 +1,5 @@
import inspect
+import math
import os
import types

@@ -19,9 +20,6 @@

B = TypeVar("B", bound=Batch)

-# TODO make configurable, possibly based on configured max seq length
-MAX_PROMPT_PREFIX_LENGTH = 256
-
CUDA_PAD_TO_MULT_OF_8 = os.getenv("CUDA_PAD_TO_MULT_OF_8", "true").lower() != "false"
PT2_COMPILE = os.getenv("PT2_COMPILE", "false").lower() != "false"

@@ -33,7 +31,7 @@


class Model(ABC):
-    def __init__(self, engine: BaseInferenceEngine, dtype: torch.dtype):
+    def __init__(self, engine: BaseInferenceEngine, dtype: torch.dtype, max_seq_length: Optional[int] = None):
        self.engine = engine
        self.config, self.tokenizer, self.model = engine.get_components()
        self.device = engine.get_device()
@@ -50,6 +48,24 @@ def __init__(self, engine: BaseInferenceEngine, dtype: torch.dtype):

        if prompt_prefix_supported:
            # Set up prefix cache
+
+            if max_seq_length is None:
+                # shouldn't be None, but just in case since the parameter is passed through as Optional
+                max_seq_length = 2048
+
+            # default value to 50% of the max sequence length
+            max_prompt_prefix_length = math.ceil(max_seq_length * 0.5)
+            if (max_prompt_prefix_env_var := os.getenv("MAX_PROMPT_PREFIX_LENGTH")):
+                try:
+                    max_prompt_prefix_env_var = int(max_prompt_prefix_env_var)
+                except ValueError as exc:
+                    raise ValueError("Invalid value for MAX_PROMPT_PREFIX_LENGTH") from exc
+
+                if max_prompt_prefix_env_var > max_seq_length - 1:
+                    raise ValueError(f"Value for the MAX_PROMPT_PREFIX_LENGTH ({max_prompt_prefix_env_var}) cannot be larger than the max sequence length - 1 ({max_seq_length - 1})")
+
+                max_prompt_prefix_length = max_prompt_prefix_env_var
+
            decoder_start_token_id = self.model.config.decoder_start_token_id
            if decoder_start_token_id is None:
                decoder_start_token_id = self.tokenizer.bos_token_id
@@ -65,7 +81,7 @@ def __init__(self, engine: BaseInferenceEngine, dtype: torch.dtype):
            self.prefix_cache = PrefixCache(
                device=self.device,
                dtype=dtype,
-                max_length=MAX_PROMPT_PREFIX_LENGTH,
+                max_length=max_prompt_prefix_length,
                encoder_decoder=self.model.config.is_encoder_decoder,
                return_zero=return_zero,
                decoder_start_tok_embedding=self.word_embeddings(
2 changes: 1 addition & 1 deletion server/text_generation_server/models/seq2seq_lm.py
@@ -557,7 +557,7 @@ def __init__(
        inference_engine = get_inference_engine_class(deployment_framework)(
            model_path, AutoModelForSeq2SeqLM, dtype, quantize, model_config, max_sequence_length
        )
-        super(Seq2SeqLM, self).__init__(inference_engine, dtype)
+        super(Seq2SeqLM, self).__init__(inference_engine, dtype, max_sequence_length)

        bos_token_id = self.model.config.decoder_start_token_id
        if bos_token_id is None:
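
For a sense of scale, a few worked numbers under the new default; the sequence lengths here are example values, and the 256 comparison is the hard-coded constant removed from model.py above.

import math

# Default soft prompt length under the new scheme, versus the old hard-coded 256:
assert math.ceil(2048 * 0.5) == 1024  # max sequence length 2048
assert math.ceil(4096 * 0.5) == 2048  # max sequence length 4096

# An override such as MAX_PROMPT_PREFIX_LENGTH=300 is accepted as long as
# 300 <= max_sequence_length - 1; larger values make Model.__init__ raise ValueError.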
