From 711a000255eac3e034f0b73aa5cc62b45201a571 Mon Sep 17 00:00:00 2001 From: Sanger Steel <sangersteel@gmail.com> Date: Sat, 13 Apr 2024 20:13:01 -0400 Subject: [PATCH] [Frontend] [Core] feat: Add model loading using `tensorizer` (#3476) --- .buildkite/test-pipeline.yaml | 3 + docs/source/conf.py | 1 + docs/source/models/engine_args.rst | 3 +- examples/tensorize_vllm_model.py | 254 ++++++++++++++ requirements-cpu.txt | 2 +- requirements-dev.txt | 1 + setup.py | 3 + tests/tensorizer/__init__.py | 0 .../tensorize_vllm_model_for_testing.py | 245 ++++++++++++++ tests/tensorizer/test_tensorizer.py | 302 +++++++++++++++++ vllm/config.py | 74 +++- vllm/engine/arg_utils.py | 45 ++- vllm/engine/llm_engine.py | 8 +- vllm/executor/gpu_executor.py | 23 +- vllm/executor/ray_gpu_executor.py | 6 +- vllm/model_executor/model_loader.py | 61 +++- vllm/model_executor/tensorizer_loader.py | 319 ++++++++++++++++++ vllm/model_executor/weight_utils.py | 34 +- vllm/worker/model_runner.py | 9 +- vllm/worker/worker.py | 9 +- 20 files changed, 1351 insertions(+), 51 deletions(-) create mode 100644 examples/tensorize_vllm_model.py create mode 100644 tests/tensorizer/__init__.py create mode 100644 tests/tensorizer/tensorize_vllm_model_for_testing.py create mode 100644 tests/tensorizer/test_tensorizer.py create mode 100644 vllm/model_executor/tensorizer_loader.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 8d7d6304cf12e..aa4582bbda0c7 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -91,6 +91,9 @@ steps: command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 4 +- label: Tensorizer Test + command: apt-get install curl libsodium23 && pytest -v -s tensorizer + - label: Metrics Test command: pytest -v -s metrics diff --git a/docs/source/conf.py b/docs/source/conf.py index 7a8c365ffb3bb..19cc8557a7541 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -83,6 +83,7 @@ "vllm._C", "numpy", "tqdm", + "tensorizer", ] for mock_target in autodoc_mock_imports: diff --git a/docs/source/models/engine_args.rst b/docs/source/models/engine_args.rst index d8a7ac72e0175..886a806934c04 100644 --- a/docs/source/models/engine_args.rst +++ b/docs/source/models/engine_args.rst @@ -36,7 +36,7 @@ Below, you can find an explanation of every engine argument for vLLM: Directory to download and load the weights, default to the default cache dir of huggingface. -.. option:: --load-format {auto,pt,safetensors,npcache,dummy} +.. option:: --load-format {auto,pt,safetensors,npcache,dummy,tensorizer} The format of the model weights to load. @@ -45,6 +45,7 @@ Below, you can find an explanation of every engine argument for vLLM: * "safetensors" will load the weights in the safetensors format. * "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading. * "dummy" will initialize the weights with random values, mainly for profiling. + * "tensorizer" will load serialized weights using `CoreWeave's Tensorizer model deserializer. <https://github.com/coreweave/tensorizer>`_. See `tensorized_vllm_model.py` in the examples folder to serialize a vLLM model, and for more information. Tensorizer support for vLLM can be installed with `pip install vllm[tensorizer]`. .. 
option:: --dtype {auto,half,float16,bfloat16,float,float32} diff --git a/examples/tensorize_vllm_model.py b/examples/tensorize_vllm_model.py new file mode 100644 index 0000000000000..3c20a38c7f726 --- /dev/null +++ b/examples/tensorize_vllm_model.py @@ -0,0 +1,254 @@ +import argparse +import dataclasses +import os +import time +import uuid +from functools import partial +from typing import Type + +import torch +import torch.nn as nn +from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer, + TensorSerializer, stream_io) +from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor +from transformers import AutoConfig, PretrainedConfig + +from vllm.distributed import initialize_model_parallel +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.llm_engine import LLMEngine +from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.tensorizer_loader import TensorizerArgs + +# yapf conflicts with isort for this docstring +# yapf: disable +""" +tensorize_vllm_model.py is a script that can be used to serialize and +deserialize vLLM models. These models can be loaded using tensorizer directly +to the GPU extremely quickly. Tensor encryption and decryption is also +supported, although libsodium must be installed to use it. Install +vllm with tensorizer support using `pip install vllm[tensorizer]`. + +To serialize a model, you can run something like this: + +python tensorize_vllm_model.py \ + --model EleutherAI/gpt-j-6B \ + --dtype float16 \ + serialize \ + --serialized-directory s3://my-bucket/ \ + --suffix vllm + +Which downloads the model from HuggingFace, loads it into vLLM, serializes it, +and saves it to your S3 bucket. A local directory can also be used. + +You can also encrypt the model weights with a randomly-generated key by +providing a `--keyfile` argument. + +To deserialize a model, you can run something like this: + +python tensorize_vllm_model.py \ + --model EleutherAI/gpt-j-6B \ + --dtype float16 \ + deserialize \ + --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors + +Which downloads the model tensors from your S3 bucket and deserializes them. +To provide S3 credentials, you can provide `--s3-access-key-id` and +`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this script, +the OpenAI entrypoint, as arguments for LLM(), or as environment variables +in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`. + + +You can also provide a `--keyfile` argument to decrypt the model weights if +they were serialized with encryption. + +For more information on the available arguments, run +`python tensorize_vllm_model.py --help`. +""" + + +def parse_args(): + parser = argparse.ArgumentParser( + description="An example script that can be used to serialize and " + "deserialize vLLM models. These models " + "can be loaded using tensorizer directly to the GPU " + "extremely quickly. Tensor encryption and decryption is " + "also supported, although libsodium must be installed to " + "use it.") + parser = EngineArgs.add_cli_args(parser) + subparsers = parser.add_subparsers(dest='command') + + serialize_parser = subparsers.add_parser( + 'serialize', help="Serialize a model to `--serialized-directory`") + + serialize_parser.add_argument( + "--suffix", + type=str, + required=False, + help=( + "The suffix to append to the serialized model directory, which is " + "used to construct the location of the serialized model tensors, " + "e.g. 
if `--serialized-directory` is `s3://my-bucket/` and " + "`--suffix` is `v1`, the serialized model tensors will be " + "saved to " + "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. " + "If none is provided, a random UUID will be used.")) + serialize_parser.add_argument( + "--serialized-directory", + type=str, + required=True, + help="The directory to serialize the model to. " + "This can be a local directory or S3 URI. The path to where the " + "tensors are saved is a combination of the supplied `dir` and model " + "reference ID. For instance, if `dir` is the serialized directory, " + "and the model HuggingFace ID is `EleutherAI/gpt-j-6B`, tensors will " + "be saved to `dir/vllm/EleutherAI/gpt-j-6B/suffix/model.tensors`, " + "where `suffix` is given by `--suffix` or a random UUID if not " + "provided.") + + serialize_parser.add_argument( + "--keyfile", + type=str, + required=False, + help=("Encrypt the model weights with a randomly-generated binary key," + " and save the key at this path")) + + deserialize_parser = subparsers.add_parser( + 'deserialize', + help=("Deserialize a model from `--path-to-tensors`" + " to verify it can be loaded and used.")) + + deserialize_parser.add_argument( + "--path-to-tensors", + type=str, + required=True, + help="The local path or S3 URI to the model tensors to deserialize. ") + + deserialize_parser.add_argument( + "--keyfile", + type=str, + required=False, + help=("Path to a binary key to use to decrypt the model weights," + " if the model was serialized with encryption")) + + return parser.parse_args() + + +def make_model_contiguous(model): + # Ensure tensors are saved in memory contiguously + for param in model.parameters(): + param.data = param.data.contiguous() + + +def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: + architectures = getattr(config, "architectures", []) + for arch in architectures: + model_cls = ModelRegistry.load_model_cls(arch) + if model_cls is not None: + return model_cls + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}") + + +def serialize(): + + eng_args_dict = {f.name: getattr(args, f.name) for f in + dataclasses.fields(EngineArgs)} + engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict)) + engine = LLMEngine.from_engine_args(engine_args) + + model = (engine.model_executor.driver_worker. + model_runner.model) + + encryption_params = EncryptionParams.random() if keyfile else None + if keyfile: + with _write_stream(keyfile) as stream: + stream.write(encryption_params.key) + + with _write_stream(model_path) as stream: + serializer = TensorSerializer(stream, encryption=encryption_params) + serializer.write_module(model) + serializer.close() + + print("Serialization complete. 
Model tensors saved to", model_path) + if keyfile: + print("Key saved to", keyfile) + + +def deserialize(): + config = AutoConfig.from_pretrained(model_ref) + + with no_init_or_tensor(): + model_class = _get_vllm_model_architecture(config) + model = model_class(config) + + before_mem = get_mem_usage() + start = time.time() + + if keyfile: + with _read_stream(keyfile) as stream: + key = stream.read() + decryption_params = DecryptionParams.from_key(key) + tensorizer_args.deserializer_params['encryption'] = \ + decryption_params + + with (_read_stream(model_path)) as stream, TensorDeserializer( + stream, **tensorizer_args.deserializer_params) as deserializer: + deserializer.load_into_module(model) + end = time.time() + + # Brag about how fast we are. + total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) + duration = end - start + per_second = convert_bytes(deserializer.total_tensor_bytes / duration) + after_mem = get_mem_usage() + print( + f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s" + ) + print(f"Memory usage before: {before_mem}") + print(f"Memory usage after: {after_mem}") + + return model + + +args = parse_args() + +s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID") + or None) +s3_secret_access_key = (args.s3_secret_access_key + or os.environ.get("S3_SECRET_ACCESS_KEY") or None) + +s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None) + +_read_stream, _write_stream = (partial( + stream_io.open_stream, + mode=mode, + s3_access_key_id=s3_access_key_id, + s3_secret_access_key=s3_secret_access_key, + s3_endpoint=s3_endpoint, +) for mode in ("rb", "wb+")) + +model_ref = args.model + +model_name = model_ref.split("/")[1] + +os.environ["MASTER_ADDR"] = "127.0.0.1" +os.environ["MASTER_PORT"] = "8080" + +torch.distributed.init_process_group(world_size=1, rank=0) +initialize_model_parallel() + +keyfile = args.keyfile if args.keyfile else None + +if args.command == "serialize": + input_dir = args.serialized_directory.rstrip('/') + suffix = args.suffix if args.suffix else uuid.uuid4().hex + base_path = f"{input_dir}/vllm/{model_ref}/{suffix}" + model_path = f"{base_path}/model.tensors" + serialize() +elif args.command == "deserialize": + tensorizer_args = TensorizerArgs.from_cli_args(args) + model_path = args.path_to_tensors + deserialize() +else: + raise ValueError("Either serialize or deserialize must be specified.") diff --git a/requirements-cpu.txt b/requirements-cpu.txt index 36d20bc9473ea..5779b38b24e69 100644 --- a/requirements-cpu.txt +++ b/requirements-cpu.txt @@ -3,4 +3,4 @@ # Dependencies for x86_64 CPUs torch == 2.2.1+cpu -triton >= 2.1.0 # FIXME(woosuk): This is a hack to avoid import error. +triton >= 2.1.0 # FIXME(woosuk): This is a hack to avoid import error. 
\ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index 96dfda6faf00f..1317e51b2dd11 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -14,6 +14,7 @@ types-setuptools # testing pytest +tensorizer==2.9.0a0 pytest-forked pytest-asyncio pytest-rerunfailures diff --git a/setup.py b/setup.py index 9f0814e9f3bff..813321efe796d 100644 --- a/setup.py +++ b/setup.py @@ -405,6 +405,9 @@ def _read_requirements(filename: str) -> List[str]: python_requires=">=3.8", install_requires=get_requirements(), ext_modules=ext_modules, + extras_require={ + "optional": ["tensorizer==2.9.0a1"], + }, cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {}, package_data=package_data, ) diff --git a/tests/tensorizer/__init__.py b/tests/tensorizer/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/tensorizer/tensorize_vllm_model_for_testing.py b/tests/tensorizer/tensorize_vllm_model_for_testing.py new file mode 100644 index 0000000000000..d0be08329fd64 --- /dev/null +++ b/tests/tensorizer/tensorize_vllm_model_for_testing.py @@ -0,0 +1,245 @@ +import argparse +import dataclasses +import os +import time +import uuid +from functools import partial +from typing import Type + +import torch +import torch.nn as nn +from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer, + TensorSerializer, stream_io) +from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor +from transformers import AutoConfig, PretrainedConfig + +from vllm.distributed import initialize_model_parallel +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.llm_engine import LLMEngine +from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.tensorizer_loader import TensorizerArgs + +# yapf conflicts with isort for this docstring +# yapf: disable +""" +tensorize_vllm_model.py is a script that can be used to serialize and +deserialize vLLM models. These models can be loaded using tensorizer directly +to the GPU extremely quickly. Tensor encryption and decryption is also +supported, although libsodium must be installed to use it. Install +vllm with tensorizer support using `pip install vllm[tensorizer]`. + +To serialize a model, you can run something like this: + +python tensorize_vllm_model.py \ + --model EleutherAI/gpt-j-6B \ + --dtype float16 \ + serialize \ + --serialized-directory s3://my-bucket/ \ + --suffix vllm + +Which downloads the model from HuggingFace, loads it into vLLM, serializes it, +and saves it to your S3 bucket. A local directory can also be used. + +You can also encrypt the model weights with a randomly-generated key by +providing a `--keyfile` argument. + +To deserialize a model, you can run something like this: + +python tensorize_vllm_model.py \ + --model EleutherAI/gpt-j-6B \ + --dtype float16 \ + deserialize \ + --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors + +Which downloads the model tensors from your S3 bucket and deserializes them. +To provide S3 credentials, you can provide `--s3-access-key-id` and +`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this script, +the OpenAI entrypoint, as arguments for LLM(), or as environment variables +in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`. + + +You can also provide a `--keyfile` argument to decrypt the model weights if +they were serialized with encryption. 
+ +For more information on the available arguments, run +`python tensorize_vllm_model.py --help`. +""" + + +def parse_args(): + parser = argparse.ArgumentParser( + description="An example script that can be used to serialize and " + "deserialize vLLM models. These models " + "can be loaded using tensorizer directly to the GPU " + "extremely quickly. Tensor encryption and decryption is " + "also supported, although libsodium must be installed to " + "use it.") + parser = EngineArgs.add_cli_args(parser) + subparsers = parser.add_subparsers(dest='command') + + serialize_parser = subparsers.add_parser( + 'serialize', help="Serialize a model to `--serialized-directory`") + + serialize_parser.add_argument( + "--suffix", + type=str, + required=False, + help=( + "The suffix to append to the serialized model directory, which is " + "used to construct the location of the serialized model tensors, " + "e.g. if `--serialized-directory` is `s3://my-bucket/` and " + "`--suffix` is `v1`, the serialized model tensors will be " + "saved to " + "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. " + "If none is provided, a random UUID will be used.")) + serialize_parser.add_argument( + "--serialized-directory", + type=str, + required=True) + + serialize_parser.add_argument( + "--keyfile", + type=str, + required=False, + help=("Encrypt the model weights with a randomly-generated binary key," + " and save the key at this path")) + + deserialize_parser = subparsers.add_parser( + 'deserialize', + help=("Deserialize a model from `--path-to-tensors`" + " to verify it can be loaded and used.")) + + deserialize_parser.add_argument( + "--path-to-tensors", + type=str, + required=True, + help="The local path or S3 URI to the model tensors to deserialize. ") + + deserialize_parser.add_argument( + "--keyfile", + type=str, + required=False, + help=("Path to a binary key to use to decrypt the model weights," + " if the model was serialized with encryption")) + + return parser.parse_args() + + +def make_model_contiguous(model): + # Ensure tensors are saved in memory contiguously + for param in model.parameters(): + param.data = param.data.contiguous() + + +def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: + architectures = getattr(config, "architectures", []) + for arch in architectures: + model_cls = ModelRegistry.load_model_cls(arch) + if model_cls is not None: + return model_cls + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}") + + +def serialize(): + eng_args_dict = {f.name: getattr(args, f.name) for f in + dataclasses.fields(EngineArgs)} + engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict)) + engine = LLMEngine.from_engine_args(engine_args) + + model = (engine.model_executor.driver_worker. + model_runner.model) + + encryption_params = EncryptionParams.random() if keyfile else None + if keyfile: + with _write_stream(keyfile) as stream: + stream.write(encryption_params.key) + + with _write_stream(model_path) as stream: + serializer = TensorSerializer(stream, encryption=encryption_params) + serializer.write_module(model) + serializer.close() + + print("Serialization complete. 
Model tensors saved to", model_path) + if keyfile: + print("Key saved to", keyfile) + + +def deserialize(): + config = AutoConfig.from_pretrained(model_ref) + + with no_init_or_tensor(): + model_class = _get_vllm_model_architecture(config) + model = model_class(config) + + before_mem = get_mem_usage() + start = time.time() + + if keyfile: + with _read_stream(keyfile) as stream: + key = stream.read() + decryption_params = DecryptionParams.from_key(key) + tensorizer_args.deserializer_params['encryption'] = \ + decryption_params + + with (_read_stream(model_path)) as stream, TensorDeserializer( + stream, **tensorizer_args.deserializer_params) as deserializer: + deserializer.load_into_module(model) + end = time.time() + + # Brag about how fast we are. + total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) + duration = end - start + per_second = convert_bytes(deserializer.total_tensor_bytes / duration) + after_mem = get_mem_usage() + print( + f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s" + ) + print(f"Memory usage before: {before_mem}") + print(f"Memory usage after: {after_mem}") + + return model + + +args = parse_args() + +s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID") + or None) +s3_secret_access_key = (args.s3_secret_access_key + or os.environ.get("S3_SECRET_ACCESS_KEY") or None) + +s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None) + +_read_stream, _write_stream = (partial( + stream_io.open_stream, + mode=mode, + s3_access_key_id=s3_access_key_id, + s3_secret_access_key=s3_secret_access_key, + s3_endpoint=s3_endpoint, +) for mode in ("rb", "wb+")) + +model_ref = args.model + +model_name = model_ref.split("/")[1] + +os.environ["MASTER_ADDR"] = "127.0.0.1" +os.environ["MASTER_PORT"] = "8080" + +torch.distributed.init_process_group(world_size=1, rank=0) +initialize_model_parallel() + +keyfile = args.keyfile if args.keyfile else None + +if args.command == "serialize": + input_dir = args.serialized_directory.rstrip('/') + suffix = args.suffix if args.suffix else uuid.uuid4().hex + base_path = f"{input_dir}/vllm/{model_ref}/{suffix}" + model_path = f"{base_path}/model.tensors" + serialize() +elif args.command == "deserialize": + tensorizer_args = TensorizerArgs.from_cli_args(args) + model_path = args.path_to_tensors + deserialize() +else: + raise ValueError("Either serialize or deserialize must be specified.") diff --git a/tests/tensorizer/test_tensorizer.py b/tests/tensorizer/test_tensorizer.py new file mode 100644 index 0000000000000..2ab893e95da9c --- /dev/null +++ b/tests/tensorizer/test_tensorizer.py @@ -0,0 +1,302 @@ +import gc +import subprocess +from unittest.mock import MagicMock, patch + +import pytest +import torch + +from tests.entrypoints.test_openai_server import ServerRunner +from vllm import SamplingParams +from vllm.config import TensorizerConfig +from vllm.model_executor.tensorizer_loader import ( + EncryptionParams, TensorSerializer, is_vllm_serialized_tensorizer, + load_with_tensorizer, open_stream) + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. 
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0) + +model_ref = "facebook/opt-125m" + + +def is_curl_installed(): + try: + subprocess.check_call(['curl', '--version']) + return True + except (subprocess.CalledProcessError, FileNotFoundError): + return False + + +@pytest.fixture(autouse=True) +def tensorizer_config(): + config = TensorizerConfig(tensorizer_uri="vllm", vllm_tensorized=True) + return config + + +@patch('vllm.model_executor.tensorizer_loader.TensorizerAgent') +def test_load_with_tensorizer(mock_agent, tensorizer_config): + mock_linear_method = MagicMock() + mock_agent_instance = mock_agent.return_value + mock_agent_instance.deserialize.return_value = MagicMock() + + result = load_with_tensorizer(tensorizer_config, + linear_method=mock_linear_method) + + mock_agent.assert_called_once_with(tensorizer_config, + linear_method=mock_linear_method) + mock_agent_instance.deserialize.assert_called_once() + assert result == mock_agent_instance.deserialize.return_value + + +def test_is_vllm_model_with_vllm_in_uri(tensorizer_config): + tensorizer_config.vllm_tensorized = True + + result = is_vllm_serialized_tensorizer(tensorizer_config) + + assert result is True + + +def test_is_vllm_model_without_vllm_in_uri(tensorizer_config): + tensorizer_config.vllm_tensorized = False + + result = is_vllm_serialized_tensorizer(tensorizer_config) + + assert result is False + + +def test_deserialized_vllm_model_has_same_outputs(vllm_runner, tmp_path): + vllm_model = vllm_runner(model_ref) + model_path = tmp_path / (model_ref + ".tensors") + outputs = vllm_model.generate(prompts, sampling_params) + model = (vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) + with open_stream(model_path, "wb+") as stream: + serializer = TensorSerializer(stream) + serializer.write_module(model) + del vllm_model, model + gc.collect() + torch.cuda.empty_cache() + loaded_vllm_model = vllm_runner(model_ref, + load_format="tensorizer", + tensorizer_uri=model_path, + num_readers=1, + vllm_tensorized=True) + deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) + + # Assumes SamplingParams being seeded ensures the outputs are deterministic + assert outputs == deserialized_outputs + + +@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") +def test_can_deserialize_s3(vllm_runner): + model_ref = "EleutherAI/pythia-1.4b" + tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors" + + loaded_hf_model = vllm_runner( + model_ref, + tensorizer_uri=tensorized_path, + load_format="tensorizer", + num_readers=1, + vllm_tensorized=False, + s3_endpoint="object.ord1.coreweave.com", + ) + + deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params) + + assert deserialized_outputs + + +@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") +def test_deserialized_encrypted_vllm_model_has_same_outputs( + vllm_runner, tmp_path): + vllm_model = vllm_runner(model_ref) + model_path = tmp_path / (model_ref + ".tensors") + key_path = tmp_path / (model_ref + ".key") + outputs = vllm_model.generate(prompts, sampling_params) + model = (vllm_model.model.llm_engine.model_executor.driver_worker. 
+ model_runner.model) + + encryption_params = EncryptionParams.random() + with open_stream(model_path, "wb+") as stream: + serializer = TensorSerializer(stream, encryption=encryption_params) + serializer.write_module(model) + with open_stream(key_path, "wb+") as stream: + stream.write(encryption_params.key) + del vllm_model, model + gc.collect() + torch.cuda.empty_cache() + loaded_vllm_model = vllm_runner(model_ref, + tensorizer_uri=model_path, + load_format="tensorizer", + encryption_keyfile=key_path, + num_readers=1, + vllm_tensorized=True) + + deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) + + # Assumes SamplingParams being seeded ensures the outputs are deterministic + assert outputs == deserialized_outputs + + +def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, + tmp_path): + hf_model = hf_runner(model_ref) + model_path = tmp_path / (model_ref + ".tensors") + max_tokens = 50 + outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens) + with open_stream(model_path, "wb+") as stream: + serializer = TensorSerializer(stream) + serializer.write_module(hf_model.model) + del hf_model + gc.collect() + torch.cuda.empty_cache() + loaded_hf_model = vllm_runner(model_ref, + tensorizer_uri=model_path, + load_format="tensorizer", + num_readers=1, + vllm_tensorized=False) + + deserialized_outputs = loaded_hf_model.generate_greedy( + prompts, max_tokens=max_tokens) + + assert outputs == deserialized_outputs + + +def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): + from huggingface_hub import snapshot_download + + from examples.multilora_inference import (create_test_prompts, + process_requests) + + model_ref = "meta-llama/Llama-2-7b-hf" + lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") + test_prompts = create_test_prompts(lora_path) + + # Serialize model before deserializing and binding LoRA adapters + vllm_model = vllm_runner(model_ref, ) + model_path = tmp_path / (model_ref + ".tensors") + model = (vllm_model.model.llm_engine.model_executor.driver_worker. 
+ model_runner.model) + with open_stream(model_path, "wb+") as stream: + serializer = TensorSerializer(stream) + serializer.write_module(model) + del vllm_model, model + gc.collect() + torch.cuda.empty_cache() + loaded_vllm_model = vllm_runner( + model_ref, + tensorizer_uri=model_path, + load_format="tensorizer", + num_readers=1, + vllm_tensorized=True, + enable_lora=True, + max_loras=1, + max_lora_rank=8, + max_cpu_loras=2, + max_num_seqs=50, + max_model_len=1000, + ) + process_requests(loaded_vllm_model.model.llm_engine, test_prompts) + + assert loaded_vllm_model + + +def test_load_without_tensorizer_load_format(vllm_runner): + with pytest.raises(ValueError): + vllm_runner(model_ref, tensorizer_uri="test") + + +@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") +def test_tensorize_vllm_model(tmp_path): + # Test serialize command + serialize_args = [ + "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model", + model_ref, "--dtype", "float16", "serialize", "--serialized-directory", + tmp_path, "--suffix", "tests" + ] + result = subprocess.run(serialize_args, capture_output=True, text=True) + print(result.stdout) # Print the output of the serialize command + + assert result.returncode == 0, (f"Serialize command failed with output:" + f"\n{result.stdout}\n{result.stderr}") + + path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors" + + # Test deserialize command + deserialize_args = [ + "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model", + model_ref, "--dtype", "float16", "deserialize", "--path-to-tensors", + path_to_tensors + ] + result = subprocess.run(deserialize_args, capture_output=True, text=True) + assert result.returncode == 0, (f"Deserialize command failed with output:" + f"\n{result.stdout}\n{result.stderr}") + + +@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") +def test_openai_apiserver_with_tensorizer(tmp_path): + ## Serialize model + serialize_args = [ + "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model", + model_ref, "--dtype", "float16", "serialize", "--serialized-directory", + tmp_path, "--suffix", "tests" + ] + result = subprocess.run(serialize_args, capture_output=True, text=True) + print(result.stdout) # Print the output of the serialize command + + assert result.returncode == 0, (f"Serialize command failed with output:" + f"\n{result.stdout}\n{result.stderr}") + + path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors" + + ## Start OpenAI API server + openai_args = [ + "--model", model_ref, "--dtype", "float16", "--load-format", + "tensorizer", "--tensorizer-uri", path_to_tensors, "--vllm-tensorized", + "--port", "8000" + ] + + server = ServerRunner.remote(openai_args) + + print("Server ready.") + assert server.ready.remote() + + +def test_raise_value_error_on_invalid_load_format(vllm_runner): + with pytest.raises(ValueError): + vllm_runner(model_ref, + load_format="safetensors", + tensorizer_uri="test") + + +def test_tensorizer_with_tp(vllm_runner): + with pytest.raises(ValueError): + model_ref = "EleutherAI/pythia-1.4b" + tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors" + + vllm_runner( + model_ref, + tensorizer_uri=tensorized_path, + load_format="tensorizer", + num_readers=1, + vllm_tensorized=False, + s3_endpoint="object.ord1.coreweave.com", + tensor_parallel_size=2, + ) + + +@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") +def test_tensorizer_warn_quant(tmp_path): + model_ref = 
"LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit" + serialize_args = [ + "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model", + model_ref, "--quantization", "gptq", "--tensorizer-uri", "test", + "serialize", "--serialized-directory", tmp_path, "--suffix", "tests" + ] + result = subprocess.run(serialize_args, capture_output=True, text=True) + assert 'PerformanceWarning' in result.stderr diff --git a/vllm/config.py b/vllm/config.py index bbda4ecf3cc56..dce2944b2ee8a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,6 +1,8 @@ import enum +import io import json import os +import typing from dataclasses import dataclass, fields from typing import TYPE_CHECKING, ClassVar, List, Optional, Union @@ -16,6 +18,8 @@ if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup + from vllm.model_executor.tensorizer_loader import TensorizerArgs + logger = init_logger(__name__) _GB = 1 << 30 @@ -139,13 +143,14 @@ def __init__( def _verify_load_format(self) -> None: load_format = self.load_format.lower() supported_load_format = [ - "auto", "pt", "safetensors", "npcache", "dummy" + "auto", "pt", "safetensors", "npcache", "dummy", "tensorizer" ] rocm_not_supported_load_format: List[str] = [] if load_format not in supported_load_format: raise ValueError( f"Unknown load format: {self.load_format}. Must be one of " - "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.") + "'auto', 'pt', 'safetensors', 'npcache', 'tensorizer', or " + "'dummy'.") if is_hip() and load_format in rocm_not_supported_load_format: rocm_supported_load_format = [ f for f in supported_load_format @@ -882,6 +887,65 @@ def get_image_input_enum_type( f"{[x.name for x in cls.ImageInputType]}.") from e +@dataclass +class TensorizerConfig: + tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO, + str, bytes, os.PathLike, int] + vllm_tensorized: bool + verify_hash: Optional[bool] = False + num_readers: Optional[int] = 1 + encryption_keyfile: Optional[str] = None + s3_access_key_id: Optional[str] = None + s3_secret_access_key: Optional[str] = None + s3_endpoint: Optional[str] = None + model_class: Optional[torch.nn.Module] = None + hf_config: Optional[PretrainedConfig] = None + dtype: Union[str, torch.dtype] = None + + def _construct_tensorizer_args(self) -> "TensorizerArgs": + from vllm.model_executor.tensorizer_loader import TensorizerArgs + tensorizer_args = { + "tensorizer_uri": self.tensorizer_uri, + "vllm_tensorized": self.vllm_tensorized, + "verify_hash": self.verify_hash, + "num_readers": self.num_readers, + "encryption_keyfile": self.encryption_keyfile, + "s3_access_key_id": self.s3_access_key_id, + "s3_secret_access_key": self.s3_secret_access_key, + "s3_endpoint": self.s3_endpoint, + } + return TensorizerArgs(**tensorizer_args) + + def verify_with_parallel_config( + self, + parallel_config: "ParallelConfig", + ) -> None: + if (parallel_config.tensor_parallel_size > 1 + and self.tensorizer_uri is not None): + raise ValueError( + "Loading to multiple GPUs is not currently supported with " + "vLLM-serialized models. Please set tensor_parallel_size=1." 
+ " or use a non-vLLM-serialized model, such as a " + "serialized Hugging Face `PretrainedModel`.") + + def verify_with_model_config(self, model_config) -> None: + if (model_config.quantization is not None + and self.tensorizer_uri is not None): + from vllm.model_executor.tensorizer_loader import ( + tensorizer_warning) + tensorizer_warning( + "Loading a model using Tensorizer with quantization on vLLM" + " is unstable and may lead to errors.") + + if (model_config.load_format != "tensorizer" + and self.tensorizer_uri is not None): + raise ValueError( + "A tensorizer uri was passed for tensorizer loading, but the " + f"load format was set to {model_config.load_format}. " + "Please set the load format to 'tensorizer' to use " + f"tensorizer args.") + + _STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.float16, "float16": torch.float16, @@ -1029,6 +1093,7 @@ class EngineConfig: lora_config: Optional[LoRAConfig] vision_language_config: Optional[VisionLanguageConfig] speculative_config: Optional[SpeculativeConfig] + tensorizer_config: Optional[TensorizerConfig] def __post_init__(self): """Verify configs are valid & consistent with each other. @@ -1036,6 +1101,11 @@ def __post_init__(self): self.model_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config) + if self.tensorizer_config: + self.tensorizer_config.verify_with_parallel_config( + self.parallel_config) + self.tensorizer_config.verify_with_model_config(self.model_config) + if self.lora_config: self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_scheduler_config( diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index daefddc01b431..831a03be65f61 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1,12 +1,15 @@ import argparse import dataclasses +import io +import os from dataclasses import dataclass -from typing import Optional +from typing import BinaryIO, Optional, Union from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, TokenizerPoolConfig, - VisionLanguageConfig) + SpeculativeConfig, TensorizerConfig, + TokenizerPoolConfig, VisionLanguageConfig) +from vllm.model_executor.tensorizer_loader import TensorizerArgs from vllm.utils import str_to_int_tuple @@ -58,12 +61,22 @@ class EngineArgs: num_gpu_blocks_override: Optional[int] = None num_lookahead_slots: int = 0 + # Tensorizer configuration parameters + tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str, + bytes, os.PathLike, int] = None + vllm_tensorized: bool = False + verify_hash: Optional[bool] = False + num_readers: Optional[int] = 1 + encryption_keyfile: Optional[str] = None + s3_access_key_id: Optional[str] = None + s3_secret_access_key: Optional[str] = None + s3_endpoint: Optional[str] = None + # Related to Vision-language models such as llava image_input_type: Optional[str] = None image_token_id: Optional[int] = None image_input_shape: Optional[str] = None image_feature_size: Optional[int] = None - scheduler_delay_factor: float = 0.0 enable_chunked_prefill: bool = False @@ -135,7 +148,9 @@ def add_cli_args( '--load-format', type=str, default=EngineArgs.load_format, - choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'], + choices=[ + 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer' + ], help='The format of the model weights to load. 
' '"auto" will try to load the weights in the safetensors format ' 'and fall back to the pytorch bin format if safetensors format ' @@ -145,7 +160,10 @@ def add_cli_args( '"npcache" will load the weights in pytorch format and store ' 'a numpy cache to speed up the loading. ' '"dummy" will initialize the weights with random values, ' - 'which is mainly for profiling.') + 'which is mainly for profiling.' + '"tensorizer" will load the weights using tensorizer from CoreWeave' + 'which assumes tensorizer_uri is set to the location of the ' + 'serialized weights.') parser.add_argument( '--dtype', type=str, @@ -403,6 +421,7 @@ def add_cli_args( default=None, help='The number of speculative tokens to sample from ' 'the draft model in speculative decoding') + parser = TensorizerArgs.add_cli_args(parser) return parser @classmethod @@ -465,6 +484,17 @@ def create_engine_config(self, ) -> EngineConfig: max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else None) if self.enable_lora else None + tensorizer_config = TensorizerConfig( + tensorizer_uri=self.tensorizer_uri, + vllm_tensorized=self.vllm_tensorized, + verify_hash=self.verify_hash, + num_readers=self.num_readers, + encryption_keyfile=self.encryption_keyfile, + s3_access_key_id=self.s3_access_key_id, + s3_secret_access_key=self.s3_secret_access_key, + s3_endpoint=self.s3_endpoint, + ) + if self.image_input_type: if (not self.image_token_id or not self.image_input_shape or not self.image_feature_size): @@ -488,7 +518,8 @@ def create_engine_config(self, ) -> EngineConfig: device_config=device_config, lora_config=lora_config, vision_language_config=vision_language_config, - speculative_config=speculative_config) + speculative_config=speculative_config, + tensorizer_config=tensorizer_config) @dataclass diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a91629a630591..8c37c5a9d6ee9 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -6,7 +6,7 @@ import vllm from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) + TensorizerConfig, VisionLanguageConfig) from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics import StatLogger, Stats @@ -74,6 +74,7 @@ def __init__( lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], speculative_config: Optional[SpeculativeConfig], + tensorizer_config: Optional[TensorizerConfig], executor_class: Type[ExecutorBase], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, @@ -110,6 +111,7 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.speculative_config = speculative_config + self.tensorizer_config = tensorizer_config self.log_stats = log_stats self._init_tokenizer() @@ -125,6 +127,7 @@ def __init__( lora_config=lora_config, vision_language_config=vision_language_config, speculative_config=speculative_config, + tensorizer_config=tensorizer_config, ) self._initialize_kv_caches() @@ -264,6 +267,9 @@ def _init_tokenizer(self, **tokenizer_init_kwargs): def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config) + if self.tensorizer_config: + self.tensorizer_config.verify_with_parallel_config( + self.parallel_config) if self.lora_config: 
self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_scheduler_config( diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index f20221a0b941a..30577ecf62faa 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -2,7 +2,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) + TensorizerConfig, VisionLanguageConfig) from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -15,17 +15,14 @@ class GPUExecutor(ExecutorBase): - def __init__( - self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig], - speculative_config: Optional[SpeculativeConfig], - ) -> None: + def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + speculative_config: Optional[SpeculativeConfig], + tensorizer_config: Optional[TensorizerConfig]) -> None: self.model_config = model_config self.cache_config = cache_config self.lora_config = lora_config @@ -33,6 +30,7 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.vision_language_config = vision_language_config + self.tensorizer_config = tensorizer_config assert (not speculative_config ), "Speculative decoding not yet supported for GPU backend" @@ -61,6 +59,7 @@ def _init_worker(self): distributed_init_method=distributed_init_method, lora_config=self.lora_config, vision_language_config=self.vision_language_config, + tensorizer_config=self.tensorizer_config, is_driver_worker=True, ) self.driver_worker.init_device() diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index b937693c92257..28dc3e0db312a 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -7,7 +7,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) + TensorizerConfig, VisionLanguageConfig) from vllm.engine.ray_utils import RayWorkerVllm, ray from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger @@ -42,6 +42,7 @@ def __init__( lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], speculative_config: Optional[SpeculativeConfig], + tensorizer_config: Optional[TensorizerConfig], ) -> None: self.model_config = model_config self.cache_config = cache_config @@ -50,6 +51,7 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.vision_language_config = vision_language_config + self.tensorizer_config = tensorizer_config assert (not speculative_config ), "Speculative decoding not yet supported for RayGPU backend." 
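As a rough usage sketch (illustrative only, and not part of the patch itself): once the engine arguments added above are threaded through to TensorizerConfig, a model serialized with examples/tensorize_vllm_model.py could be loaded programmatically along the following lines. The model name and S3 URI are placeholders, and the keyword values mirror what tests/tensorizer/test_tensorizer.py passes via vllm_runner.

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine

# Placeholder model reference and URI; in practice the URI points at tensors
# produced by examples/tensorize_vllm_model.py (see its docstring above).
engine_args = EngineArgs(
    model="facebook/opt-125m",
    load_format="tensorizer",
    tensorizer_uri="s3://my-bucket/vllm/facebook/opt-125m/v1/model.tensors",
    vllm_tensorized=True,  # the tensors were serialized from a vLLM model
    num_readers=1,
)
engine = LLMEngine.from_engine_args(engine_args)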
@@ -171,6 +173,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", distributed_init_method=distributed_init_method, lora_config=lora_config, vision_language_config=vision_language_config, + tensorizer_config=self.tensorizer_config, )) # Initialize the driver worker with the Worker class. @@ -187,6 +190,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", distributed_init_method=distributed_init_method, lora_config=self.lora_config, vision_language_config=self.vision_language_config, + tensorizer_config=self.tensorizer_config, is_driver_worker=True, ) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 2745dbd89ab0f..c70ca48bca70a 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -3,11 +3,14 @@ from typing import Tuple, Type import torch -import torch.nn as nn +from torch import nn from vllm.config import DeviceConfig, ModelConfig from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models.llava import LlavaForConditionalGeneration +from vllm.model_executor.tensorizer_loader import ( + ParameterizedLoadFormat, is_vllm_serialized_tensorizer, + load_with_tensorizer) from vllm.model_executor.weight_utils import (get_quant_config, initialize_dummy_weights) @@ -51,6 +54,7 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig, **kwargs) -> nn.Module: lora_config = kwargs.get("lora_config", None) vision_language_config = kwargs.get("vision_language_config", None) + tensorizer_config = kwargs.get("tensorizer_config", None) model_class = _get_model_architecture(model_config)[0] # Get the (maybe quantized) linear method. @@ -71,33 +75,54 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig, f"{model_config.dtype} is not supported for quantization " f"method {model_config.quantization}. Supported dtypes: " f"{supported_dtypes}") + linear_method = quant_config.get_linear_method() with _set_default_torch_dtype(model_config.dtype): # Create a model instance. # The weights will be initialized as empty tensors. + extra_kwargs = {} + if hasattr(model_class, "supported_lora_modules"): + extra_kwargs["lora_config"] = lora_config + elif lora_config: + raise ValueError( + f"Model {model_class.__name__} does not support LoRA, " + "but LoRA is enabled. Support for this model may " + "be added in the future. If this is important to you, " + "please open an issue on github.") + elif model_class in _VISION_MODEL_CLASSES: + extra_kwargs["vision_language_config"] = vision_language_config + with torch.device(device_config.device): - if hasattr(model_class, "supported_lora_modules"): - model = model_class(model_config.hf_config, linear_method, - lora_config) - elif lora_config: - raise ValueError( - f"Model {model_class.__name__} does not support LoRA, " - "but LoRA is enabled. Support for this model may " - "be added in the future. 
If this is important to you, " - "please open an issue on github.") - else: - if model_class not in _VISION_MODEL_CLASSES: - model = model_class(model_config.hf_config, linear_method) - else: - model = model_class(model_config.hf_config, - vision_language_config, linear_method) + if (model_config.load_format == "tensorizer" + and is_vllm_serialized_tensorizer(tensorizer_config)): + extra_kwargs["linear_method"] = linear_method + tensorizer_config.model_class = model_class + tensorizer_config.hf_config = model_config.hf_config + tensorizer_config.dtype = model_config.dtype + model = load_with_tensorizer(tensorizer_config, **extra_kwargs) + return model.eval() + model = model_class(config=model_config.hf_config, + linear_method=linear_method, + **extra_kwargs) if model_config.load_format == "dummy": # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. initialize_dummy_weights(model) else: # Load the weights from the cached or downloaded files. - model.load_weights(model_config.model, model_config.download_dir, - model_config.load_format, model_config.revision) + if model_config.load_format == "tensorizer": + # Provide a dynamic load format for `model.load_weights` + # to retain tensorizer args from CLI. + model_config.load_format = ParameterizedLoadFormat( + model_config.load_format) + model_config.load_format.params = ( + tensorizer_config._construct_tensorizer_args()) + + model.load_weights( + model_config.model, + model_config.download_dir, + model_config.load_format, + model_config.revision, + ) return model.eval() diff --git a/vllm/model_executor/tensorizer_loader.py b/vllm/model_executor/tensorizer_loader.py new file mode 100644 index 0000000000000..ed3ad9e2ffa15 --- /dev/null +++ b/vllm/model_executor/tensorizer_loader.py @@ -0,0 +1,319 @@ +import argparse +import dataclasses +import io +import os +import time +import typing +import warnings +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn + +from vllm.config import TensorizerConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) + +tensorizer_load_fail = False + +try: + from tensorizer import (DecryptionParams, EncryptionParams, + TensorDeserializer, TensorSerializer) + from tensorizer.stream_io import open_stream + from tensorizer.utils import (convert_bytes, get_mem_usage, + no_init_or_tensor) +except ImportError: + tensorizer_load_fail = True + +__all__ = [ + 'EncryptionParams', 'DecryptionParams', 'TensorDeserializer', + 'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage', + 'no_init_or_tensor' +] + +logger = init_logger(__name__) + + +def load_with_tensorizer(tensorizer_config: TensorizerConfig, + **extra_kwargs) -> nn.Module: + tensorizer = TensorizerAgent(tensorizer_config, **extra_kwargs) + return tensorizer.deserialize() + + +def tensorizer_warning(message: str): + return warnings.warn(message, category=PerformanceWarning, stacklevel=2) + + +def is_vllm_serialized_tensorizer(tensorizer_config: TensorizerConfig) -> bool: + if tensorizer_config is None: + return False + return tensorizer_config.vllm_tensorized + + +class ParameterizedLoadFormat(str): + __slots__ = "params" + + +class PerformanceWarning(UserWarning): + + def __str__(self): + return (f"{super().__str__()}" + " (set the VLLM_SILENCE_PERFORMANCE_WARNINGS" + " environment variable to hide this)") + + +if 
(os.getenv("VLLM_SILENCE_PERFORMANCE_WARNINGS", "").lower() + not in ("", "0", "n", "no", "off", "disable")): + warnings.simplefilter("ignore", category=PerformanceWarning) + + +@dataclass +class TensorizerArgs: + tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO, + str, bytes, os.PathLike, int] + vllm_tensorized: bool + verify_hash: Optional[bool] = False + num_readers: Optional[int] = 1 + encryption_keyfile: Optional[str] = None + s3_access_key_id: Optional[str] = None + s3_secret_access_key: Optional[str] = None + s3_endpoint: Optional[str] = None + """ + Args for the TensorizerAgent class. These are used to configure the behavior + of the TensorDeserializer when loading tensors from a serialized model. + + Args: + tensorizer_uri: Path to serialized model tensors. Can be a local file + path or a S3 URI. + vllm_tensorized: If True, indicates that the serialized model is a + vLLM model. This is used to determine the behavior of the + TensorDeserializer when loading tensors from a serialized model. + It is far faster to deserialize a vLLM model as it utilizes + tensorizer's optimized GPU loading. + verify_hash: If True, the hashes of each tensor will be verified against + the hashes stored in the metadata. A `HashMismatchError` will be + raised if any of the hashes do not match. + num_readers: Controls how many threads are allowed to read concurrently + from the source file. Default is 1. This greatly increases + performance. + encryption_keyfile: File path to a binary file containing a + binary key to use for decryption. `None` (the default) means + no decryption. See the example script in + examples/tensorize_vllm_model.py. + s3_access_key_id: The access key for the S3 bucket. Can also be set via + the S3_ACCESS_KEY_ID environment variable. + s3_secret_access_key: The secret access key for the S3 bucket. Can also + be set via the S3_SECRET_ACCESS_KEY environment variable. + s3_endpoint: The endpoint for the S3 bucket. Can also be set via the + S3_ENDPOINT_URL environment variable. + """ + + def __post_init__(self): + self.file_obj = self.tensorizer_uri + self.s3_access_key_id = (self.s3_access_key_id + or os.environ.get("S3_ACCESS_KEY_ID")) or None + self.s3_secret_access_key = ( + self.s3_secret_access_key + or os.environ.get("S3_SECRET_ACCESS_KEY")) or None + self.s3_endpoint = (self.s3_endpoint + or os.environ.get("S3_ENDPOINT_URL")) or None + self.stream_params = { + "s3_access_key_id": self.s3_access_key_id, + "s3_secret_access_key": self.s3_secret_access_key, + "s3_endpoint": self.s3_endpoint, + } + + # Omitting self.dtype and self.device as this behaves weirdly + self.deserializer_params = { + "verify_hash": self.verify_hash, + "encryption": self.encryption_keyfile, + "num_readers": self.num_readers + } + if self.encryption_keyfile: + with open_stream( + self.encryption_keyfile, + **self.stream_params, + ) as stream: + key = stream.read() + decryption_params = DecryptionParams.from_key(key) + self.deserializer_params['encryption'] = decryption_params + + def add_cli_args( + parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Tensorizer CLI arguments""" + + # Create the argument group + group = parser.add_argument_group( + 'tensorizer options', + description=('Options for configuring the behavior of the' + ' tensorizer deserializer when ' + '--load-format=tensorizer')) + + group.add_argument( + "--tensorizer-uri", + help="Path to serialized model tensors. 
Can be a local file path," + " or an HTTP(S) or S3 URI.", + ) + group.add_argument( + "--verify-hash", + action="store_true", + help="If enabled, the hashes of each tensor will be verified" + " against the hashes stored in the file metadata. An exception" + " will be raised if any of the hashes do not match.", + ) + group.add_argument( + "--encryption-keyfile", + default=None, + help="The file path to a binary file containing a binary key to " + "use for decryption. Can be a file path or S3 network URI.") + group.add_argument( + "--num-readers", + default=1, + type=int, + help="Controls how many threads are allowed to read concurrently " + "from the source file.") + group.add_argument( + "--s3-access-key-id", + default=None, + help="The access key for the S3 bucket. Can also be set via the " + "S3_ACCESS_KEY_ID environment variable.", + ) + group.add_argument( + "--s3-secret-access-key", + default=None, + help="The secret access key for the S3 bucket. Can also be set via " + "the S3_SECRET_ACCESS_KEY environment variable.", + ) + group.add_argument( + "--s3-endpoint", + default=None, + help="The endpoint for the S3 bucket. Can also be set via the " + "S3_ENDPOINT_URL environment variable.", + ) + group.add_argument( + "--vllm-tensorized", + action="store_true", + help="If enabled, indicates that the serialized model is a vLLM " + "model. This is used to determine the behavior of the " + "TensorDeserializer when loading tensors from a " + "serialized model.") + + return parser + + @classmethod + def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs": + # Get the list of attributes of this dataclass. + attrs = [attr.name for attr in dataclasses.fields(cls)] + # Set the attributes from the parsed arguments. + tensorizer_args = cls(**{ + attr: getattr(args, attr) + for attr in attrs if hasattr(args, attr) + }) + return tensorizer_args + + +class TensorizerAgent: + """ + A class for performing tensorizer deserializations specifically for + vLLM models using plaid_mode. Uses TensorizerArgs to configure the + behavior of the TensorDeserializer when loading tensors from a serialized + model. For deserializations of HuggingFace models, TensorDeserializer is + instead used as an iterator directly in the func hf_model_weights_iterator + in vllm/model_executor/weight_utils.py + """ + + def __init__(self, tensorizer_config: TensorizerConfig, + linear_method: LinearMethodBase, **extra_kwargs): + self.tensorizer_config = tensorizer_config + self.tensorizer_args = ( + self.tensorizer_config._construct_tensorizer_args()) + self.extra_kwargs = extra_kwargs + if extra_kwargs.get("linear_method", None) is not None: + self.linear_method = extra_kwargs["linear_method"] + else: + self.linear_method = linear_method + self.model = self._init_model() + + if tensorizer_load_fail: + raise ImportError( + "Tensorizer is not installed. 
Please install tensorizer " + "to use this feature with `pip install vllm[tensorizer]`.") + + def _init_model(self): + model_args = self.tensorizer_config.hf_config + model_args.torch_dtype = self.tensorizer_config.dtype + with no_init_or_tensor(): + return self.tensorizer_config.model_class( + config=model_args, + linear_method=self.linear_method, + **self.extra_kwargs) + + def _resize_lora_embeddings(self): + """Modify LoRA embedding layers to use bigger tensors + to allow for adapter added tokens.""" + for child in self.model.modules(): + if (isinstance(child, VocabParallelEmbedding) + and child.weight.shape[0] < + child.num_embeddings_per_partition): + new_weight = torch.empty(child.num_embeddings_per_partition, + child.embedding_dim, + dtype=child.weight.dtype, + device=child.weight.device) + new_weight[:child.weight.shape[0]].copy_(child.weight.data) + new_weight[child.weight.shape[0]:].fill_(0) + child.weight.data = new_weight + + def _check_tensors_on_meta_device(self): + for tensor in self.model.state_dict().values(): + if tensor.device.type == 'meta': + raise ValueError( + "The serialized model contains tensors on the meta device," + " indicating that some tensors were not loaded properly." + " Please check that the parameters of the model being" + " specified match that of the serialized model, such as" + " its quantization.") + + def deserialize(self): + """ + Deserialize the model using the TensorDeserializer. This method is + specifically for vLLM models using tensorizer's plaid_mode. + + The deserializer makes use of tensorizer_args.stream_params + to configure the behavior of the stream when loading tensors from a + serialized model. The deserializer_params are used to configure the + behavior of the TensorDeserializer when loading tensors themselves. + Documentation on these params can be found in TensorizerArgs + + Returns: + nn.Module: The deserialized model. + """ + before_mem = get_mem_usage() + # Lazy load the tensors from S3 into the model. + start = time.perf_counter() + with open_stream( + self.tensorizer_args.tensorizer_uri, + mode="rb", + **self.tensorizer_args.stream_params, + ) as stream, TensorDeserializer( + stream, + dtype=self.tensorizer_config.dtype, + **self.tensorizer_args.deserializer_params) as deserializer: + deserializer.load_into_module(self.model) + end = time.perf_counter() + + total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) + duration = end - start + per_second = convert_bytes(deserializer.total_tensor_bytes / duration) + after_mem = get_mem_usage() + deserializer.close() + logger.info(f"Deserialized {total_bytes_str} in " + f"{end - start:0.2f}s, {per_second}/s") + logger.info(f"Memory usage before: {before_mem}") + logger.info(f"Memory usage after: {after_mem}") + + self._check_tensors_on_meta_device() + self._resize_lora_embeddings() + return self.model.eval() diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index 0961478930d74..08425604f0511 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -5,7 +5,7 @@ import json import os from collections import defaultdict -from typing import Any, Iterable, Iterator, List, Optional, Tuple +from typing import Any, Iterable, Iterator, List, Optional, Tuple, Union import filelock import huggingface_hub.constants @@ -161,7 +161,8 @@ def prepare_hf_model_weights( revision: Optional[str] = None, ) -> Tuple[str, List[str], bool]: # Download model weights from huggingface. 
- is_local = os.path.isdir(model_name_or_path) + is_local = os.path.isdir(model_name_or_path) \ + and load_format != "tensorizer" use_safetensors = False # Some quantized models use .pt files for storing the weights. if load_format == "auto": @@ -173,13 +174,15 @@ def prepare_hf_model_weights( allow_patterns = ["*.pt"] elif load_format == "npcache": allow_patterns = ["*.bin"] + elif load_format == "tensorizer": + allow_patterns = ["*.tensors"] else: raise ValueError(f"Unknown load_format: {load_format}") if fall_back_to_pt: allow_patterns += ["*.pt"] - if not is_local: + if not is_local and load_format != "tensorizer": # Before we download we look at that is available: fs = HfFileSystem() file_list = fs.ls(model_name_or_path, detail=False, revision=revision) @@ -224,6 +227,9 @@ def prepare_hf_model_weights( if not any(f.endswith(x) for x in blacklist) ] + if load_format == "tensorizer": + return hf_folder, hf_weights_files, use_safetensors + if len(hf_weights_files) == 0: raise RuntimeError( f"Cannot find any model weights with `{model_name_or_path}`") @@ -234,7 +240,7 @@ def prepare_hf_model_weights( def hf_model_weights_iterator( model_name_or_path: str, cache_dir: Optional[str] = None, - load_format: str = "auto", + load_format: Union[Tuple, str] = "auto", revision: Optional[str] = None, fall_back_to_pt: Optional[bool] = True, ) -> Iterator[Tuple[str, torch.Tensor]]: @@ -277,6 +283,26 @@ def hf_model_weights_iterator( with open(param_path, "rb") as f: param = np.load(f) yield name, torch.from_numpy(param) + elif load_format == "tensorizer": + from vllm.model_executor.tensorizer_loader import (TensorDeserializer, + open_stream, + tensorizer_warning) + tensorizer_args = load_format.params + tensorizer_warning( + "Deserializing HuggingFace models is not optimized for " + "loading on vLLM, as tensorizer is forced to load to CPU. " + "Consider deserializing a vLLM model instead for faster " + "load times. 
See the examples/tensorize_vllm_model.py example " + "script for serializing vLLM models.") + + deserializer_args = tensorizer_args.deserializer_params + stream_params = tensorizer_args.stream_params + stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params) + with TensorDeserializer(stream, **deserializer_args, + device="cpu") as state: + for name, param in state.items(): + yield name, param + del state elif use_safetensors: for st_file in hf_weights_files: with safe_open(st_file, framework="pt") as f: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 47ad8f0c9b78b..7dbe14ead0976 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -10,7 +10,8 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, get_attn_backend) from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, - SchedulerConfig, VisionLanguageConfig) + SchedulerConfig, TensorizerConfig, + VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce from vllm.distributed.device_communicators import (custom_all_reduce, pynccl_utils) @@ -111,11 +112,13 @@ def __init__( kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, vision_language_config: Optional[VisionLanguageConfig] = None, + tensorizer_config: Optional[TensorizerConfig] = None, ): self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.lora_config = lora_config + self.tensorizer_config = tensorizer_config self.is_driver_worker = is_driver_worker # model_config can be None in tests/samplers/test_sampler.py. @@ -158,7 +161,9 @@ def load_model(self) -> None: lora_config=self.lora_config, vision_language_config=self.vision_language_config, parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) + scheduler_config=self.scheduler_config, + tensorizer_config=self.tensorizer_config, + ) self.model_memory_usage = m.consumed_memory logger.info(f"Loading model weights took " diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 3f0b2fd83f3e5..82491c6df6616 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -7,7 +7,8 @@ import torch.distributed from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VisionLanguageConfig) + ParallelConfig, SchedulerConfig, TensorizerConfig, + VisionLanguageConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, init_distributed_environment) @@ -42,6 +43,7 @@ def __init__( distributed_init_method: str, lora_config: Optional[LoRAConfig] = None, vision_language_config: Optional[VisionLanguageConfig] = None, + tensorizer_config: Optional[TensorizerConfig] = None, is_driver_worker: bool = False, ) -> None: self.model_config = model_config @@ -53,6 +55,7 @@ def __init__( self.rank = rank self.distributed_init_method = distributed_init_method self.lora_config = lora_config + self.tensorizer_config = tensorizer_config self.is_driver_worker = is_driver_worker if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." @@ -70,7 +73,9 @@ def __init__( lora_config=self.lora_config, kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=is_driver_worker, - vision_language_config=vision_language_config) + vision_language_config=vision_language_config, + tensorizer_config=tensorizer_config, + ) # Uninitialized cache engine. 
Will be initialized by
        # initialize_cache.
        self.cache_engine = None
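For reference, the TensorizerArgs dataclass added in vllm/model_executor/tensorizer_loader.py can also be constructed programmatically rather than through the CLI. A minimal sketch, assuming `vllm[tensorizer]` is installed; the bucket URI, endpoint, and reader count below are illustrative placeholders, not values taken from this patch:

from vllm.model_executor.tensorizer_loader import TensorizerArgs

# Point tensorizer at a serialized vLLM model (local path or S3/HTTP(S) URI).
args = TensorizerArgs(
    tensorizer_uri="s3://my-bucket/model.tensors",  # placeholder URI
    vllm_tensorized=True,   # the tensors were serialized from a vLLM model
    num_readers=4,          # raising this above the default of 1 speeds up loading
    s3_endpoint="https://object.example.com",  # or set S3_ENDPOINT_URL instead
)

# __post_init__ resolves any missing S3 credentials from the environment and
# builds args.stream_params / args.deserializer_params, which TensorizerAgent
# and hf_model_weights_iterator pass to open_stream / TensorDeserializer.

The same options are exposed on the command line by add_cli_args() as --tensorizer-uri, --num-readers, --s3-endpoint, and --vllm-tensorized whenever --load-format=tensorizer is selected.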