Add qwen config and input config simplification (#1190)
* remove max input/output in favor of max_seq_len

* add qwen

* bump version

* revert bump

* update config in test

* fix tests

* bump version

* bump briton image

* remove unused config

* better client side error msg + validation

* fix test
joostinyi authored Oct 28, 2024
1 parent 9889e04 commit 467b7a1
Showing 9 changed files with 28 additions and 29 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "truss"
-version = "0.9.45rc009"
+version = "0.9.45rc018"
 description = "A seamless bridge from model development to model delivery"
 license = "MIT"
 readme = "README.md"
11 changes: 5 additions & 6 deletions truss/config/trt_llm.py
@@ -23,14 +23,15 @@ class TrussTRTLLMModel(str, Enum):
     MISTRAL = "mistral"
     DEEPSEEK = "deepseek"
     WHISPER = "whisper"
+    QWEN = "qwen"


 class TrussTRTLLMQuantizationType(str, Enum):
     NO_QUANT = "no_quant"
     WEIGHTS_ONLY_INT8 = "weights_int8"
     WEIGHTS_KV_INT8 = "weights_kv_int8"
     WEIGHTS_ONLY_INT4 = "weights_int4"
-    WEIGHTS_KV_INT4 = "weights_kv_int4"
+    WEIGHTS_INT4_KV_INT8 = "weights_int4_kv_int8"
     SMOOTH_QUANT = "smooth_quant"
     FP8 = "fp8"
     FP8_KV = "fp8_kv"
@@ -58,10 +59,9 @@ class CheckpointRepository(BaseModel):

 class TrussTRTLLMBuildConfiguration(BaseModel):
     base_model: TrussTRTLLMModel
-    max_input_len: int
-    max_output_len: int
-    max_batch_size: int
-    max_num_tokens: Optional[int] = None
+    max_seq_len: int
+    max_batch_size: Optional[int] = 256
+    max_num_tokens: Optional[int] = 8192
     max_beam_width: int = 1
     max_prompt_embedding_table_size: int = 0
     checkpoint_repository: CheckpointRepository
@@ -75,7 +75,6 @@ class TrussTRTLLMBuildConfiguration(BaseModel):
     plugin_configuration: TrussTRTLLMPluginConfiguration = (
         TrussTRTLLMPluginConfiguration()
     )
-    use_fused_mlp: bool = False
     kv_cache_free_gpu_mem_fraction: float = 0.9
     num_builder_gpus: Optional[int] = None
     enable_chunked_context: bool = False
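For orientation, here is a minimal sketch of a trt_llm.build section under the simplified schema, using the newly added qwen base model; the repository name and token counts are illustrative, not taken from this commit:

# Illustrative sketch only: the trt_llm.build section after this change.
# A single max_seq_len replaces max_input_len/max_output_len, while
# max_batch_size and max_num_tokens fall back to their new defaults (256 / 8192).
trt_llm_build = {
    "base_model": "qwen",  # newly supported base model
    "max_seq_len": 4096,   # bounds prompt and generated tokens together
    "checkpoint_repository": {
        "source": "HF",
        "repo": "Qwen/Qwen2-7B-Instruct",  # hypothetical repository, for illustration
    },
}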
2 changes: 1 addition & 1 deletion truss/constants.py
@@ -106,7 +106,7 @@

 REGISTRY_BUILD_SECRET_PREFIX = "DOCKER_REGISTRY_"

-TRTLLM_BASE_IMAGE = "baseten/briton-server:5fa9436e_v0.0.11"
+TRTLLM_BASE_IMAGE = "baseten/briton-server:v0.13.0"
 TRTLLM_PYTHON_EXECUTABLE = "/usr/bin/python3"
 BASE_TRTLLM_REQUIREMENTS = [
     "grpcio==1.62.3",
11 changes: 0 additions & 11 deletions truss/contexts/image_builder/serving_image_builder.py
@@ -412,17 +412,6 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
         DEFAULT_BUNDLED_PACKAGES_DIR,
     )

-    tensor_parallel_count = (
-        config.trt_llm.build.tensor_parallel_count  # type: ignore[union-attr]
-        if config.trt_llm.build is not None
-        else config.trt_llm.serve.tensor_parallel_count  # type: ignore[union-attr]
-    )
-
-    if tensor_parallel_count != config.resources.accelerator.count:
-        raise ValueError(
-            "Tensor parallelism and GPU count must be the same for TRT-LLM"
-        )
-
     config.runtime.predict_concurrency = TRTLLM_PREDICT_CONCURRENCY

     if not is_audio_model:
3 changes: 1 addition & 2 deletions truss/test_data/test_trt_llm_truss/config.yaml
@@ -4,10 +4,9 @@ resources:
   use_gpu: True
 trt_llm:
   build:
-    max_input_len: 1000
+    max_seq_len: 1000
     max_batch_size: 1
     max_beam_width: 1
-    max_output_len: 1000
     base_model: llama
     checkpoint_repository:
       repo: TinyLlama/TinyLlama-1.1B-Chat-v1.0
3 changes: 1 addition & 2 deletions truss/tests/conftest.py
@@ -389,8 +389,7 @@ def modify_handle(h: TrussHandle):
         content["trt_llm"] = {
             "build": {
                 "base_model": "llama",
-                "max_input_len": 1024,
-                "max_output_len": 1024,
+                "max_seq_len": 2048,
                 "max_batch_size": 512,
                 "checkpoint_repository": {
                     "source": "HF",
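The fixture change above also suggests how existing configs migrate: the single max_seq_len bounds the whole sequence, so the previous input and output caps fold into their sum. This is an assumption inferred from the 1024 + 1024 -> 2048 update here, not something stated in the commit:

# Assumed migration rule, inferred from the fixture update above:
max_seq_len = old_max_input_len + old_max_output_len  # 1024 + 1024 -> 2048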
9 changes: 7 additions & 2 deletions truss/tests/test_config.py
@@ -49,11 +49,16 @@ def default_config() -> Dict[str, Any]:
 @pytest.fixture
 def trtllm_config(default_config) -> Dict[str, Any]:
     trtllm_config = default_config
+    trtllm_config["resources"] = {
+        "accelerator": Accelerator.L4.value,
+        "cpu": "1",
+        "memory": "24Gi",
+        "use_gpu": True,
+    }
     trtllm_config["trt_llm"] = {
         "build": {
             "base_model": "llama",
-            "max_input_len": 1024,
-            "max_output_len": 1024,
+            "max_seq_len": 2048,
             "max_batch_size": 512,
             "checkpoint_repository": {
                 "source": "HF",
3 changes: 2 additions & 1 deletion truss/trt_llm/validation.py
@@ -31,7 +31,8 @@ def _verify_has_class_init_arg(source: str, class_name: str, arg_name: str):
         raise ValidationError(
             (
                 "Model class `__init__` method is required to have `trt_llm` as an argument. Please add that argument.\n "
-                "Or if you want to use the automatically generated model class then remove the `model.py` file."
+                "Or if you want to use the automatically generated model class then remove the `model.py` file.\n "
+                "Refer to https://docs.baseten.co/performance/engine-builder-customization for details on engine object usage."
             )
         )

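For context on the error message above, a minimal sketch of a model.py class that satisfies this check; everything beyond the `trt_llm` argument name is illustrative:

# Hypothetical model.py sketch: the validator only requires that __init__
# accepts a `trt_llm` argument (the engine object provided by the engine builder).
class Model:
    def __init__(self, trt_llm, **kwargs):
        # Keep a handle to the engine; see the linked docs for how to use it.
        self._engine = trt_llm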
13 changes: 10 additions & 3 deletions truss/truss_config.py
@@ -649,7 +649,7 @@ def to_dict(self, verbose: bool = True):
     def clone(self):
         return TrussConfig.from_dict(self.to_dict())

-    def _validate_quant_format_and_accelerator_for_trt_llm_builder(self) -> None:
+    def _validate_accelerator_for_trt_llm_builder(self) -> None:
         if self.trt_llm and self.trt_llm.build:
             if (
                 self.trt_llm.build.quantization_type
@@ -665,9 +665,16 @@ def _validate_quant_format_and_accelerator_for_trt_llm_builder(self) -> None:
             ] and self.resources.accelerator.accelerator not in [
                 Accelerator.H100,
                 Accelerator.H100_40GB,
+                Accelerator.L4,
             ]:
                 raise ValueError(
-                    "FP8 quantization is only supported on H100 accelerators"
+                    "FP8 quantization is only supported on L4 and H100 accelerators"
                 )
+            tensor_parallel_count = self.trt_llm.build.tensor_parallel_count
+
+            if tensor_parallel_count != self.resources.accelerator.count:
+                raise ValueError(
+                    "Tensor parallelism and GPU count must be the same for TRT-LLM"
+                )

     def validate(self):
@@ -692,7 +699,7 @@ def validate(self):
             raise ValueError(
                 "Please ensure that only one of `requirements` and `requirements_file` is specified"
             )
-        self._validate_quant_format_and_accelerator_for_trt_llm_builder()
+        self._validate_accelerator_for_trt_llm_builder()


 def _handle_env_vars(env_vars: Dict[str, Any]) -> Dict[str, str]:
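Taken together with the block removed from serving_image_builder.py above, the tensor-parallelism check now runs at config-validation time rather than during image building. A rough sketch of the values it compares, with illustrative numbers and an assumed accelerator spelling:

# Illustrative values only: tensor_parallel_count must equal the GPU count,
# and FP8 quantization is now accepted on L4 as well as H100-class accelerators.
resources = {"accelerator": "H100:2", "use_gpu": True}  # assumed "TYPE:count" spelling
trt_llm = {
    "build": {
        "base_model": "llama",
        "max_seq_len": 2048,
        "tensor_parallel_count": 2,  # must match the 2 GPUs requested above
        "quantization_type": "fp8",  # permitted on H100, H100_40GB, and now L4
        "checkpoint_repository": {
            "source": "HF",
            "repo": "meta-llama/Llama-3.1-8B-Instruct",  # hypothetical repo, for illustration
        },
    },
}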
