From 711a000255eac3e034f0b73aa5cc62b45201a571 Mon Sep 17 00:00:00 2001
From: Sanger Steel <sangersteel@gmail.com>
Date: Sat, 13 Apr 2024 20:13:01 -0400
Subject: [PATCH] [Frontend] [Core] feat: Add model loading using `tensorizer`
 (#3476)

---
 .buildkite/test-pipeline.yaml                 |   3 +
 docs/source/conf.py                           |   1 +
 docs/source/models/engine_args.rst            |   3 +-
 examples/tensorize_vllm_model.py              | 254 ++++++++++++++
 requirements-cpu.txt                          |   2 +-
 requirements-dev.txt                          |   1 +
 setup.py                                      |   3 +
 tests/tensorizer/__init__.py                  |   0
 .../tensorize_vllm_model_for_testing.py       | 245 ++++++++++++++
 tests/tensorizer/test_tensorizer.py           | 302 +++++++++++++++++
 vllm/config.py                                |  74 +++-
 vllm/engine/arg_utils.py                      |  45 ++-
 vllm/engine/llm_engine.py                     |   8 +-
 vllm/executor/gpu_executor.py                 |  23 +-
 vllm/executor/ray_gpu_executor.py             |   6 +-
 vllm/model_executor/model_loader.py           |  61 +++-
 vllm/model_executor/tensorizer_loader.py      | 319 ++++++++++++++++++
 vllm/model_executor/weight_utils.py           |  34 +-
 vllm/worker/model_runner.py                   |   9 +-
 vllm/worker/worker.py                         |   9 +-
 20 files changed, 1351 insertions(+), 51 deletions(-)
 create mode 100644 examples/tensorize_vllm_model.py
 create mode 100644 tests/tensorizer/__init__.py
 create mode 100644 tests/tensorizer/tensorize_vllm_model_for_testing.py
 create mode 100644 tests/tensorizer/test_tensorizer.py
 create mode 100644 vllm/model_executor/tensorizer_loader.py

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 8d7d6304cf12e..aa4582bbda0c7 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -91,6 +91,9 @@ steps:
   command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
   parallelism: 4
 
+- label: Tensorizer Test
+  command: apt-get install curl libsodium23 && pytest -v -s tensorizer
+
 - label: Metrics Test
   command: pytest -v -s metrics
 
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 7a8c365ffb3bb..19cc8557a7541 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -83,6 +83,7 @@
     "vllm._C",
     "numpy",
     "tqdm",
+    "tensorizer",
 ]
 
 for mock_target in autodoc_mock_imports:
diff --git a/docs/source/models/engine_args.rst b/docs/source/models/engine_args.rst
index d8a7ac72e0175..886a806934c04 100644
--- a/docs/source/models/engine_args.rst
+++ b/docs/source/models/engine_args.rst
@@ -36,7 +36,7 @@ Below, you can find an explanation of every engine argument for vLLM:
 
     Directory to download and load the weights, default to the default cache dir of huggingface.
 
-.. option:: --load-format {auto,pt,safetensors,npcache,dummy}
+.. option:: --load-format {auto,pt,safetensors,npcache,dummy,tensorizer}
 
     The format of the model weights to load.
 
@@ -45,6 +45,7 @@ Below, you can find an explanation of every engine argument for vLLM:
     * "safetensors" will load the weights in the safetensors format.
     * "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading.
     * "dummy" will initialize the weights with random values, mainly for profiling.
+    * "tensorizer" will load serialized weights using `CoreWeave's Tensorizer model deserializer. <https://github.com/coreweave/tensorizer>`_. See `tensorized_vllm_model.py` in the examples folder to serialize a vLLM model, and for more information. Tensorizer support for vLLM can be installed with `pip install vllm[tensorizer]`.
 
 .. option:: --dtype {auto,half,float16,bfloat16,float,float32}
 
diff --git a/examples/tensorize_vllm_model.py b/examples/tensorize_vllm_model.py
new file mode 100644
index 0000000000000..3c20a38c7f726
--- /dev/null
+++ b/examples/tensorize_vllm_model.py
@@ -0,0 +1,254 @@
+import argparse
+import dataclasses
+import os
+import time
+import uuid
+from functools import partial
+from typing import Type
+
+import torch
+import torch.nn as nn
+from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
+                        TensorSerializer, stream_io)
+from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
+from transformers import AutoConfig, PretrainedConfig
+
+from vllm.distributed import initialize_model_parallel
+from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.llm_engine import LLMEngine
+from vllm.model_executor.models import ModelRegistry
+from vllm.model_executor.tensorizer_loader import TensorizerArgs
+
+# yapf conflicts with isort for this docstring
+# yapf: disable
+"""
+tensorize_vllm_model.py is a script that can be used to serialize and 
+deserialize vLLM models. These models can be loaded using tensorizer directly 
+to the GPU extremely quickly. Tensor encryption and decryption is also 
+supported, although libsodium must be installed to use it. Install
+vllm with tensorizer support using `pip install vllm[tensorizer]`.
+
+To serialize a model, you can run something like this:
+
+python tensorize_vllm_model.py \
+   --model EleutherAI/gpt-j-6B \
+   --dtype float16 \
+   serialize \
+   --serialized-directory s3://my-bucket/ \
+   --suffix vllm
+   
+Which downloads the model from HuggingFace, loads it into vLLM, serializes it,
+and saves it to your S3 bucket. A local directory can also be used.
+
+You can also encrypt the model weights with a randomly-generated key by 
+providing a `--keyfile` argument.
+
+To deserialize a model, you can run something like this:
+
+python tensorize_vllm_model.py \
+   --model EleutherAI/gpt-j-6B \
+   --dtype float16 \
+   deserialize \
+   --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors
+
+Which downloads the model tensors from your S3 bucket and deserializes them.
+To provide S3 credentials, you can provide `--s3-access-key-id` and 
+`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this script,
+the OpenAI entrypoint, as arguments for LLM(), or as environment variables
+in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`.
+
+
+You can also provide a `--keyfile` argument to decrypt the model weights if 
+they were serialized with encryption.
+
+For more information on the available arguments, run 
+`python tensorize_vllm_model.py --help`.
+"""
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="An example script that can be used to serialize and "
+        "deserialize vLLM models. These models "
+        "can be loaded using tensorizer directly to the GPU "
+        "extremely quickly. Tensor encryption and decryption is "
+        "also supported, although libsodium must be installed to "
+        "use it.")
+    parser = EngineArgs.add_cli_args(parser)
+    subparsers = parser.add_subparsers(dest='command')
+
+    serialize_parser = subparsers.add_parser(
+        'serialize', help="Serialize a model to `--serialized-directory`")
+
+    serialize_parser.add_argument(
+        "--suffix",
+        type=str,
+        required=False,
+        help=(
+            "The suffix to append to the serialized model directory, which is "
+            "used to construct the location of the serialized model tensors, "
+            "e.g. if `--serialized-directory` is `s3://my-bucket/` and "
+            "`--suffix` is `v1`, the serialized model tensors will be "
+            "saved to "
+            "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. "
+            "If none is provided, a random UUID will be used."))
+    serialize_parser.add_argument(
+        "--serialized-directory",
+        type=str,
+        required=True,
+        help="The directory to serialize the model to. "
+        "This can be a local directory or S3 URI. The path to where the "
+        "tensors are saved is a combination of the supplied `dir` and model "
+        "reference ID. For instance, if `dir` is the serialized directory, "
+        "and the model HuggingFace ID is `EleutherAI/gpt-j-6B`, tensors will "
+        "be saved to `dir/vllm/EleutherAI/gpt-j-6B/suffix/model.tensors`, "
+        "where `suffix` is given by `--suffix` or a random UUID if not "
+        "provided.")
+
+    serialize_parser.add_argument(
+        "--keyfile",
+        type=str,
+        required=False,
+        help=("Encrypt the model weights with a randomly-generated binary key,"
+              " and save the key at this path"))
+
+    deserialize_parser = subparsers.add_parser(
+        'deserialize',
+        help=("Deserialize a model from `--path-to-tensors`"
+              " to verify it can be loaded and used."))
+
+    deserialize_parser.add_argument(
+        "--path-to-tensors",
+        type=str,
+        required=True,
+        help="The local path or S3 URI to the model tensors to deserialize. ")
+
+    deserialize_parser.add_argument(
+        "--keyfile",
+        type=str,
+        required=False,
+        help=("Path to a binary key to use to decrypt the model weights,"
+              " if the model was serialized with encryption"))
+
+    return parser.parse_args()
+
+
+def make_model_contiguous(model):
+    # Ensure tensors are saved in memory contiguously
+    for param in model.parameters():
+        param.data = param.data.contiguous()
+
+
+def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
+    architectures = getattr(config, "architectures", [])
+    for arch in architectures:
+        model_cls = ModelRegistry.load_model_cls(arch)
+        if model_cls is not None:
+            return model_cls
+    raise ValueError(
+        f"Model architectures {architectures} are not supported for now. "
+        f"Supported architectures: {ModelRegistry.get_supported_archs()}")
+
+
+def serialize():
+
+    eng_args_dict = {f.name: getattr(args, f.name) for f in
+                     dataclasses.fields(EngineArgs)}
+    engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
+    engine = LLMEngine.from_engine_args(engine_args)
+
+    model = (engine.model_executor.driver_worker.
+             model_runner.model)
+
+    encryption_params = EncryptionParams.random() if keyfile else None
+    if keyfile:
+        with _write_stream(keyfile) as stream:
+            stream.write(encryption_params.key)
+
+    with _write_stream(model_path) as stream:
+        serializer = TensorSerializer(stream, encryption=encryption_params)
+        serializer.write_module(model)
+        serializer.close()
+
+    print("Serialization complete. Model tensors saved to", model_path)
+    if keyfile:
+        print("Key saved to", keyfile)
+
+
+def deserialize():
+    config = AutoConfig.from_pretrained(model_ref)
+
+    with no_init_or_tensor():
+        model_class = _get_vllm_model_architecture(config)
+        model = model_class(config)
+
+    before_mem = get_mem_usage()
+    start = time.time()
+
+    if keyfile:
+        with _read_stream(keyfile) as stream:
+            key = stream.read()
+            decryption_params = DecryptionParams.from_key(key)
+            tensorizer_args.deserializer_params['encryption'] = \
+                decryption_params
+
+    with (_read_stream(model_path)) as stream, TensorDeserializer(
+            stream, **tensorizer_args.deserializer_params) as deserializer:
+        deserializer.load_into_module(model)
+        end = time.time()
+
+    # Brag about how fast we are.
+    total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
+    duration = end - start
+    per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
+    after_mem = get_mem_usage()
+    print(
+        f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s"
+    )
+    print(f"Memory usage before: {before_mem}")
+    print(f"Memory usage after: {after_mem}")
+
+    return model
+
+
+args = parse_args()
+
+s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID")
+                    or None)
+s3_secret_access_key = (args.s3_secret_access_key
+                        or os.environ.get("S3_SECRET_ACCESS_KEY") or None)
+
+s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None)
+
+_read_stream, _write_stream = (partial(
+    stream_io.open_stream,
+    mode=mode,
+    s3_access_key_id=s3_access_key_id,
+    s3_secret_access_key=s3_secret_access_key,
+    s3_endpoint=s3_endpoint,
+) for mode in ("rb", "wb+"))
+
+model_ref = args.model
+
+model_name = model_ref.split("/")[1]
+
+os.environ["MASTER_ADDR"] = "127.0.0.1"
+os.environ["MASTER_PORT"] = "8080"
+
+torch.distributed.init_process_group(world_size=1, rank=0)
+initialize_model_parallel()
+
+keyfile = args.keyfile if args.keyfile else None
+
+if args.command == "serialize":
+    input_dir = args.serialized_directory.rstrip('/')
+    suffix = args.suffix if args.suffix else uuid.uuid4().hex
+    base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
+    model_path = f"{base_path}/model.tensors"
+    serialize()
+elif args.command == "deserialize":
+    tensorizer_args = TensorizerArgs.from_cli_args(args)
+    model_path = args.path_to_tensors
+    deserialize()
+else:
+    raise ValueError("Either serialize or deserialize must be specified.")
diff --git a/requirements-cpu.txt b/requirements-cpu.txt
index 36d20bc9473ea..5779b38b24e69 100644
--- a/requirements-cpu.txt
+++ b/requirements-cpu.txt
@@ -3,4 +3,4 @@
 
 # Dependencies for x86_64 CPUs
 torch == 2.2.1+cpu
-triton >= 2.1.0  # FIXME(woosuk): This is a hack to avoid import error.
+triton >= 2.1.0  # FIXME(woosuk): This is a hack to avoid import error.
\ No newline at end of file
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 96dfda6faf00f..1317e51b2dd11 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -14,6 +14,7 @@ types-setuptools
 
 # testing
 pytest
+tensorizer==2.9.0a0
 pytest-forked
 pytest-asyncio
 pytest-rerunfailures
diff --git a/setup.py b/setup.py
index 9f0814e9f3bff..813321efe796d 100644
--- a/setup.py
+++ b/setup.py
@@ -405,6 +405,9 @@ def _read_requirements(filename: str) -> List[str]:
     python_requires=">=3.8",
     install_requires=get_requirements(),
     ext_modules=ext_modules,
+    extras_require={
+        "optional": ["tensorizer==2.9.0a1"],
+    },
     cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {},
     package_data=package_data,
 )
diff --git a/tests/tensorizer/__init__.py b/tests/tensorizer/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/tensorizer/tensorize_vllm_model_for_testing.py b/tests/tensorizer/tensorize_vllm_model_for_testing.py
new file mode 100644
index 0000000000000..d0be08329fd64
--- /dev/null
+++ b/tests/tensorizer/tensorize_vllm_model_for_testing.py
@@ -0,0 +1,245 @@
+import argparse
+import dataclasses
+import os
+import time
+import uuid
+from functools import partial
+from typing import Type
+
+import torch
+import torch.nn as nn
+from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
+                        TensorSerializer, stream_io)
+from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
+from transformers import AutoConfig, PretrainedConfig
+
+from vllm.distributed import initialize_model_parallel
+from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.llm_engine import LLMEngine
+from vllm.model_executor.models import ModelRegistry
+from vllm.model_executor.tensorizer_loader import TensorizerArgs
+
+# yapf conflicts with isort for this docstring
+# yapf: disable
+"""
+tensorize_vllm_model.py is a script that can be used to serialize and 
+deserialize vLLM models. These models can be loaded using tensorizer directly 
+to the GPU extremely quickly. Tensor encryption and decryption is also 
+supported, although libsodium must be installed to use it. Install
+vllm with tensorizer support using `pip install vllm[tensorizer]`.
+
+To serialize a model, you can run something like this:
+
+python tensorize_vllm_model.py \
+   --model EleutherAI/gpt-j-6B \
+   --dtype float16 \
+   serialize \
+   --serialized-directory s3://my-bucket/ \
+   --suffix vllm
+
+Which downloads the model from HuggingFace, loads it into vLLM, serializes it,
+and saves it to your S3 bucket. A local directory can also be used.
+
+You can also encrypt the model weights with a randomly-generated key by 
+providing a `--keyfile` argument.
+
+To deserialize a model, you can run something like this:
+
+python tensorize_vllm_model.py \
+   --model EleutherAI/gpt-j-6B \
+   --dtype float16 \
+   deserialize \
+   --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors
+
+Which downloads the model tensors from your S3 bucket and deserializes them.
+To provide S3 credentials, you can provide `--s3-access-key-id` and 
+`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this script,
+the OpenAI entrypoint, as arguments for LLM(), or as environment variables
+in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`.
+
+
+You can also provide a `--keyfile` argument to decrypt the model weights if 
+they were serialized with encryption.
+
+For more information on the available arguments, run 
+`python tensorize_vllm_model.py --help`.
+"""
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="An example script that can be used to serialize and "
+                    "deserialize vLLM models. These models "
+                    "can be loaded using tensorizer directly to the GPU "
+                    "extremely quickly. Tensor encryption and decryption is "
+                    "also supported, although libsodium must be installed to "
+                    "use it.")
+    parser = EngineArgs.add_cli_args(parser)
+    subparsers = parser.add_subparsers(dest='command')
+
+    serialize_parser = subparsers.add_parser(
+        'serialize', help="Serialize a model to `--serialized-directory`")
+
+    serialize_parser.add_argument(
+        "--suffix",
+        type=str,
+        required=False,
+        help=(
+            "The suffix to append to the serialized model directory, which is "
+            "used to construct the location of the serialized model tensors, "
+            "e.g. if `--serialized-directory` is `s3://my-bucket/` and "
+            "`--suffix` is `v1`, the serialized model tensors will be "
+            "saved to "
+            "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. "
+            "If none is provided, a random UUID will be used."))
+    serialize_parser.add_argument(
+        "--serialized-directory",
+        type=str,
+        required=True)
+
+    serialize_parser.add_argument(
+        "--keyfile",
+        type=str,
+        required=False,
+        help=("Encrypt the model weights with a randomly-generated binary key,"
+              " and save the key at this path"))
+
+    deserialize_parser = subparsers.add_parser(
+        'deserialize',
+        help=("Deserialize a model from `--path-to-tensors`"
+              " to verify it can be loaded and used."))
+
+    deserialize_parser.add_argument(
+        "--path-to-tensors",
+        type=str,
+        required=True,
+        help="The local path or S3 URI to the model tensors to deserialize. ")
+
+    deserialize_parser.add_argument(
+        "--keyfile",
+        type=str,
+        required=False,
+        help=("Path to a binary key to use to decrypt the model weights,"
+              " if the model was serialized with encryption"))
+
+    return parser.parse_args()
+
+
+def make_model_contiguous(model):
+    # Ensure tensors are saved in memory contiguously
+    for param in model.parameters():
+        param.data = param.data.contiguous()
+
+
+def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
+    architectures = getattr(config, "architectures", [])
+    for arch in architectures:
+        model_cls = ModelRegistry.load_model_cls(arch)
+        if model_cls is not None:
+            return model_cls
+    raise ValueError(
+        f"Model architectures {architectures} are not supported for now. "
+        f"Supported architectures: {ModelRegistry.get_supported_archs()}")
+
+
+def serialize():
+    eng_args_dict = {f.name: getattr(args, f.name) for f in
+                     dataclasses.fields(EngineArgs)}
+    engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
+    engine = LLMEngine.from_engine_args(engine_args)
+
+    model = (engine.model_executor.driver_worker.
+             model_runner.model)
+
+    encryption_params = EncryptionParams.random() if keyfile else None
+    if keyfile:
+        with _write_stream(keyfile) as stream:
+            stream.write(encryption_params.key)
+
+    with _write_stream(model_path) as stream:
+        serializer = TensorSerializer(stream, encryption=encryption_params)
+        serializer.write_module(model)
+        serializer.close()
+
+    print("Serialization complete. Model tensors saved to", model_path)
+    if keyfile:
+        print("Key saved to", keyfile)
+
+
+def deserialize():
+    config = AutoConfig.from_pretrained(model_ref)
+
+    with no_init_or_tensor():
+        model_class = _get_vllm_model_architecture(config)
+        model = model_class(config)
+
+    before_mem = get_mem_usage()
+    start = time.time()
+
+    if keyfile:
+        with _read_stream(keyfile) as stream:
+            key = stream.read()
+            decryption_params = DecryptionParams.from_key(key)
+            tensorizer_args.deserializer_params['encryption'] = \
+                decryption_params
+
+    with (_read_stream(model_path)) as stream, TensorDeserializer(
+            stream, **tensorizer_args.deserializer_params) as deserializer:
+        deserializer.load_into_module(model)
+        end = time.time()
+
+    # Brag about how fast we are.
+    total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
+    duration = end - start
+    per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
+    after_mem = get_mem_usage()
+    print(
+        f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s"
+    )
+    print(f"Memory usage before: {before_mem}")
+    print(f"Memory usage after: {after_mem}")
+
+    return model
+
+
+args = parse_args()
+
+s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID")
+                    or None)
+s3_secret_access_key = (args.s3_secret_access_key
+                        or os.environ.get("S3_SECRET_ACCESS_KEY") or None)
+
+s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None)
+
+_read_stream, _write_stream = (partial(
+    stream_io.open_stream,
+    mode=mode,
+    s3_access_key_id=s3_access_key_id,
+    s3_secret_access_key=s3_secret_access_key,
+    s3_endpoint=s3_endpoint,
+) for mode in ("rb", "wb+"))
+
+model_ref = args.model
+
+model_name = model_ref.split("/")[1]
+
+os.environ["MASTER_ADDR"] = "127.0.0.1"
+os.environ["MASTER_PORT"] = "8080"
+
+torch.distributed.init_process_group(world_size=1, rank=0)
+initialize_model_parallel()
+
+keyfile = args.keyfile if args.keyfile else None
+
+if args.command == "serialize":
+    input_dir = args.serialized_directory.rstrip('/')
+    suffix = args.suffix if args.suffix else uuid.uuid4().hex
+    base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
+    model_path = f"{base_path}/model.tensors"
+    serialize()
+elif args.command == "deserialize":
+    tensorizer_args = TensorizerArgs.from_cli_args(args)
+    model_path = args.path_to_tensors
+    deserialize()
+else:
+    raise ValueError("Either serialize or deserialize must be specified.")
diff --git a/tests/tensorizer/test_tensorizer.py b/tests/tensorizer/test_tensorizer.py
new file mode 100644
index 0000000000000..2ab893e95da9c
--- /dev/null
+++ b/tests/tensorizer/test_tensorizer.py
@@ -0,0 +1,302 @@
+import gc
+import subprocess
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+
+from tests.entrypoints.test_openai_server import ServerRunner
+from vllm import SamplingParams
+from vllm.config import TensorizerConfig
+from vllm.model_executor.tensorizer_loader import (
+    EncryptionParams, TensorSerializer, is_vllm_serialized_tensorizer,
+    load_with_tensorizer, open_stream)
+
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+# Create a sampling params object.
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
+
+model_ref = "facebook/opt-125m"
+
+
+def is_curl_installed():
+    try:
+        subprocess.check_call(['curl', '--version'])
+        return True
+    except (subprocess.CalledProcessError, FileNotFoundError):
+        return False
+
+
+@pytest.fixture(autouse=True)
+def tensorizer_config():
+    config = TensorizerConfig(tensorizer_uri="vllm", vllm_tensorized=True)
+    return config
+
+
+@patch('vllm.model_executor.tensorizer_loader.TensorizerAgent')
+def test_load_with_tensorizer(mock_agent, tensorizer_config):
+    mock_linear_method = MagicMock()
+    mock_agent_instance = mock_agent.return_value
+    mock_agent_instance.deserialize.return_value = MagicMock()
+
+    result = load_with_tensorizer(tensorizer_config,
+                                  linear_method=mock_linear_method)
+
+    mock_agent.assert_called_once_with(tensorizer_config,
+                                       linear_method=mock_linear_method)
+    mock_agent_instance.deserialize.assert_called_once()
+    assert result == mock_agent_instance.deserialize.return_value
+
+
+def test_is_vllm_model_with_vllm_in_uri(tensorizer_config):
+    tensorizer_config.vllm_tensorized = True
+
+    result = is_vllm_serialized_tensorizer(tensorizer_config)
+
+    assert result is True
+
+
+def test_is_vllm_model_without_vllm_in_uri(tensorizer_config):
+    tensorizer_config.vllm_tensorized = False
+
+    result = is_vllm_serialized_tensorizer(tensorizer_config)
+
+    assert result is False
+
+
+def test_deserialized_vllm_model_has_same_outputs(vllm_runner, tmp_path):
+    vllm_model = vllm_runner(model_ref)
+    model_path = tmp_path / (model_ref + ".tensors")
+    outputs = vllm_model.generate(prompts, sampling_params)
+    model = (vllm_model.model.llm_engine.model_executor.driver_worker.
+             model_runner.model)
+    with open_stream(model_path, "wb+") as stream:
+        serializer = TensorSerializer(stream)
+        serializer.write_module(model)
+    del vllm_model, model
+    gc.collect()
+    torch.cuda.empty_cache()
+    loaded_vllm_model = vllm_runner(model_ref,
+                                    load_format="tensorizer",
+                                    tensorizer_uri=model_path,
+                                    num_readers=1,
+                                    vllm_tensorized=True)
+    deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
+
+    # Assumes SamplingParams being seeded ensures the outputs are deterministic
+    assert outputs == deserialized_outputs
+
+
+@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
+def test_can_deserialize_s3(vllm_runner):
+    model_ref = "EleutherAI/pythia-1.4b"
+    tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
+
+    loaded_hf_model = vllm_runner(
+        model_ref,
+        tensorizer_uri=tensorized_path,
+        load_format="tensorizer",
+        num_readers=1,
+        vllm_tensorized=False,
+        s3_endpoint="object.ord1.coreweave.com",
+    )
+
+    deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params)
+
+    assert deserialized_outputs
+
+
+@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
+def test_deserialized_encrypted_vllm_model_has_same_outputs(
+        vllm_runner, tmp_path):
+    vllm_model = vllm_runner(model_ref)
+    model_path = tmp_path / (model_ref + ".tensors")
+    key_path = tmp_path / (model_ref + ".key")
+    outputs = vllm_model.generate(prompts, sampling_params)
+    model = (vllm_model.model.llm_engine.model_executor.driver_worker.
+             model_runner.model)
+
+    encryption_params = EncryptionParams.random()
+    with open_stream(model_path, "wb+") as stream:
+        serializer = TensorSerializer(stream, encryption=encryption_params)
+        serializer.write_module(model)
+    with open_stream(key_path, "wb+") as stream:
+        stream.write(encryption_params.key)
+    del vllm_model, model
+    gc.collect()
+    torch.cuda.empty_cache()
+    loaded_vllm_model = vllm_runner(model_ref,
+                                    tensorizer_uri=model_path,
+                                    load_format="tensorizer",
+                                    encryption_keyfile=key_path,
+                                    num_readers=1,
+                                    vllm_tensorized=True)
+
+    deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
+
+    # Assumes SamplingParams being seeded ensures the outputs are deterministic
+    assert outputs == deserialized_outputs
+
+
+def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
+                                                tmp_path):
+    hf_model = hf_runner(model_ref)
+    model_path = tmp_path / (model_ref + ".tensors")
+    max_tokens = 50
+    outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens)
+    with open_stream(model_path, "wb+") as stream:
+        serializer = TensorSerializer(stream)
+        serializer.write_module(hf_model.model)
+    del hf_model
+    gc.collect()
+    torch.cuda.empty_cache()
+    loaded_hf_model = vllm_runner(model_ref,
+                                  tensorizer_uri=model_path,
+                                  load_format="tensorizer",
+                                  num_readers=1,
+                                  vllm_tensorized=False)
+
+    deserialized_outputs = loaded_hf_model.generate_greedy(
+        prompts, max_tokens=max_tokens)
+
+    assert outputs == deserialized_outputs
+
+
+def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
+    from huggingface_hub import snapshot_download
+
+    from examples.multilora_inference import (create_test_prompts,
+                                              process_requests)
+
+    model_ref = "meta-llama/Llama-2-7b-hf"
+    lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
+    test_prompts = create_test_prompts(lora_path)
+
+    # Serialize model before deserializing and binding LoRA adapters
+    vllm_model = vllm_runner(model_ref, )
+    model_path = tmp_path / (model_ref + ".tensors")
+    model = (vllm_model.model.llm_engine.model_executor.driver_worker.
+             model_runner.model)
+    with open_stream(model_path, "wb+") as stream:
+        serializer = TensorSerializer(stream)
+        serializer.write_module(model)
+    del vllm_model, model
+    gc.collect()
+    torch.cuda.empty_cache()
+    loaded_vllm_model = vllm_runner(
+        model_ref,
+        tensorizer_uri=model_path,
+        load_format="tensorizer",
+        num_readers=1,
+        vllm_tensorized=True,
+        enable_lora=True,
+        max_loras=1,
+        max_lora_rank=8,
+        max_cpu_loras=2,
+        max_num_seqs=50,
+        max_model_len=1000,
+    )
+    process_requests(loaded_vllm_model.model.llm_engine, test_prompts)
+
+    assert loaded_vllm_model
+
+
+def test_load_without_tensorizer_load_format(vllm_runner):
+    with pytest.raises(ValueError):
+        vllm_runner(model_ref, tensorizer_uri="test")
+
+
+@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
+def test_tensorize_vllm_model(tmp_path):
+    # Test serialize command
+    serialize_args = [
+        "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
+        model_ref, "--dtype", "float16", "serialize", "--serialized-directory",
+        tmp_path, "--suffix", "tests"
+    ]
+    result = subprocess.run(serialize_args, capture_output=True, text=True)
+    print(result.stdout)  # Print the output of the serialize command
+
+    assert result.returncode == 0, (f"Serialize command failed with output:"
+                                    f"\n{result.stdout}\n{result.stderr}")
+
+    path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors"
+
+    # Test deserialize command
+    deserialize_args = [
+        "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
+        model_ref, "--dtype", "float16", "deserialize", "--path-to-tensors",
+        path_to_tensors
+    ]
+    result = subprocess.run(deserialize_args, capture_output=True, text=True)
+    assert result.returncode == 0, (f"Deserialize command failed with output:"
+                                    f"\n{result.stdout}\n{result.stderr}")
+
+
+@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
+def test_openai_apiserver_with_tensorizer(tmp_path):
+    ## Serialize model
+    serialize_args = [
+        "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
+        model_ref, "--dtype", "float16", "serialize", "--serialized-directory",
+        tmp_path, "--suffix", "tests"
+    ]
+    result = subprocess.run(serialize_args, capture_output=True, text=True)
+    print(result.stdout)  # Print the output of the serialize command
+
+    assert result.returncode == 0, (f"Serialize command failed with output:"
+                                    f"\n{result.stdout}\n{result.stderr}")
+
+    path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors"
+
+    ## Start OpenAI API server
+    openai_args = [
+        "--model", model_ref, "--dtype", "float16", "--load-format",
+        "tensorizer", "--tensorizer-uri", path_to_tensors, "--vllm-tensorized",
+        "--port", "8000"
+    ]
+
+    server = ServerRunner.remote(openai_args)
+
+    print("Server ready.")
+    assert server.ready.remote()
+
+
+def test_raise_value_error_on_invalid_load_format(vllm_runner):
+    with pytest.raises(ValueError):
+        vllm_runner(model_ref,
+                    load_format="safetensors",
+                    tensorizer_uri="test")
+
+
+def test_tensorizer_with_tp(vllm_runner):
+    with pytest.raises(ValueError):
+        model_ref = "EleutherAI/pythia-1.4b"
+        tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
+
+        vllm_runner(
+            model_ref,
+            tensorizer_uri=tensorized_path,
+            load_format="tensorizer",
+            num_readers=1,
+            vllm_tensorized=False,
+            s3_endpoint="object.ord1.coreweave.com",
+            tensor_parallel_size=2,
+        )
+
+
+@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
+def test_tensorizer_warn_quant(tmp_path):
+    model_ref = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
+    serialize_args = [
+        "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
+        model_ref, "--quantization", "gptq", "--tensorizer-uri", "test",
+        "serialize", "--serialized-directory", tmp_path, "--suffix", "tests"
+    ]
+    result = subprocess.run(serialize_args, capture_output=True, text=True)
+    assert 'PerformanceWarning' in result.stderr
diff --git a/vllm/config.py b/vllm/config.py
index bbda4ecf3cc56..dce2944b2ee8a 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1,6 +1,8 @@
 import enum
+import io
 import json
 import os
+import typing
 from dataclasses import dataclass, fields
 from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
 
@@ -16,6 +18,8 @@
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
 
+    from vllm.model_executor.tensorizer_loader import TensorizerArgs
+
 logger = init_logger(__name__)
 
 _GB = 1 << 30
@@ -139,13 +143,14 @@ def __init__(
     def _verify_load_format(self) -> None:
         load_format = self.load_format.lower()
         supported_load_format = [
-            "auto", "pt", "safetensors", "npcache", "dummy"
+            "auto", "pt", "safetensors", "npcache", "dummy", "tensorizer"
         ]
         rocm_not_supported_load_format: List[str] = []
         if load_format not in supported_load_format:
             raise ValueError(
                 f"Unknown load format: {self.load_format}. Must be one of "
-                "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
+                "'auto', 'pt', 'safetensors', 'npcache', 'tensorizer', or "
+                "'dummy'.")
         if is_hip() and load_format in rocm_not_supported_load_format:
             rocm_supported_load_format = [
                 f for f in supported_load_format
@@ -882,6 +887,65 @@ def get_image_input_enum_type(
                              f"{[x.name for x in cls.ImageInputType]}.") from e
 
 
+@dataclass
+class TensorizerConfig:
+    tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
+                          str, bytes, os.PathLike, int]
+    vllm_tensorized: bool
+    verify_hash: Optional[bool] = False
+    num_readers: Optional[int] = 1
+    encryption_keyfile: Optional[str] = None
+    s3_access_key_id: Optional[str] = None
+    s3_secret_access_key: Optional[str] = None
+    s3_endpoint: Optional[str] = None
+    model_class: Optional[torch.nn.Module] = None
+    hf_config: Optional[PretrainedConfig] = None
+    dtype: Union[str, torch.dtype] = None
+
+    def _construct_tensorizer_args(self) -> "TensorizerArgs":
+        from vllm.model_executor.tensorizer_loader import TensorizerArgs
+        tensorizer_args = {
+            "tensorizer_uri": self.tensorizer_uri,
+            "vllm_tensorized": self.vllm_tensorized,
+            "verify_hash": self.verify_hash,
+            "num_readers": self.num_readers,
+            "encryption_keyfile": self.encryption_keyfile,
+            "s3_access_key_id": self.s3_access_key_id,
+            "s3_secret_access_key": self.s3_secret_access_key,
+            "s3_endpoint": self.s3_endpoint,
+        }
+        return TensorizerArgs(**tensorizer_args)
+
+    def verify_with_parallel_config(
+        self,
+        parallel_config: "ParallelConfig",
+    ) -> None:
+        if (parallel_config.tensor_parallel_size > 1
+                and self.tensorizer_uri is not None):
+            raise ValueError(
+                "Loading to multiple GPUs is not currently supported with "
+                "vLLM-serialized models. Please set tensor_parallel_size=1."
+                " or use a non-vLLM-serialized model, such as a "
+                "serialized Hugging Face `PretrainedModel`.")
+
+    def verify_with_model_config(self, model_config) -> None:
+        if (model_config.quantization is not None
+                and self.tensorizer_uri is not None):
+            from vllm.model_executor.tensorizer_loader import (
+                tensorizer_warning)
+            tensorizer_warning(
+                "Loading a model using Tensorizer with quantization on vLLM"
+                " is unstable and may lead to errors.")
+
+        if (model_config.load_format != "tensorizer"
+                and self.tensorizer_uri is not None):
+            raise ValueError(
+                "A tensorizer uri was passed for tensorizer loading, but the "
+                f"load format was set to {model_config.load_format}. "
+                "Please set the load format to 'tensorizer' to use "
+                f"tensorizer args.")
+
+
 _STR_DTYPE_TO_TORCH_DTYPE = {
     "half": torch.float16,
     "float16": torch.float16,
@@ -1029,6 +1093,7 @@ class EngineConfig:
     lora_config: Optional[LoRAConfig]
     vision_language_config: Optional[VisionLanguageConfig]
     speculative_config: Optional[SpeculativeConfig]
+    tensorizer_config: Optional[TensorizerConfig]
 
     def __post_init__(self):
         """Verify configs are valid & consistent with each other.
@@ -1036,6 +1101,11 @@ def __post_init__(self):
         self.model_config.verify_with_parallel_config(self.parallel_config)
         self.cache_config.verify_with_parallel_config(self.parallel_config)
 
+        if self.tensorizer_config:
+            self.tensorizer_config.verify_with_parallel_config(
+                self.parallel_config)
+            self.tensorizer_config.verify_with_model_config(self.model_config)
+
         if self.lora_config:
             self.lora_config.verify_with_model_config(self.model_config)
             self.lora_config.verify_with_scheduler_config(
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index daefddc01b431..831a03be65f61 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1,12 +1,15 @@
 import argparse
 import dataclasses
+import io
+import os
 from dataclasses import dataclass
-from typing import Optional
+from typing import BinaryIO, Optional, Union
 
 from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
-                         SpeculativeConfig, TokenizerPoolConfig,
-                         VisionLanguageConfig)
+                         SpeculativeConfig, TensorizerConfig,
+                         TokenizerPoolConfig, VisionLanguageConfig)
+from vllm.model_executor.tensorizer_loader import TensorizerArgs
 from vllm.utils import str_to_int_tuple
 
 
@@ -58,12 +61,22 @@ class EngineArgs:
     num_gpu_blocks_override: Optional[int] = None
     num_lookahead_slots: int = 0
 
+    # Tensorizer configuration parameters
+    tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str,
+                          bytes, os.PathLike, int] = None
+    vllm_tensorized: bool = False
+    verify_hash: Optional[bool] = False
+    num_readers: Optional[int] = 1
+    encryption_keyfile: Optional[str] = None
+    s3_access_key_id: Optional[str] = None
+    s3_secret_access_key: Optional[str] = None
+    s3_endpoint: Optional[str] = None
+
     # Related to Vision-language models such as llava
     image_input_type: Optional[str] = None
     image_token_id: Optional[int] = None
     image_input_shape: Optional[str] = None
     image_feature_size: Optional[int] = None
-
     scheduler_delay_factor: float = 0.0
     enable_chunked_prefill: bool = False
 
@@ -135,7 +148,9 @@ def add_cli_args(
             '--load-format',
             type=str,
             default=EngineArgs.load_format,
-            choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
+            choices=[
+                'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer'
+            ],
             help='The format of the model weights to load. '
             '"auto" will try to load the weights in the safetensors format '
             'and fall back to the pytorch bin format if safetensors format '
@@ -145,7 +160,10 @@ def add_cli_args(
             '"npcache" will load the weights in pytorch format and store '
             'a numpy cache to speed up the loading. '
             '"dummy" will initialize the weights with random values, '
-            'which is mainly for profiling.')
+            'which is mainly for profiling.'
+            '"tensorizer" will load the weights using tensorizer from CoreWeave'
+            'which assumes tensorizer_uri is set to the location of the '
+            'serialized weights.')
         parser.add_argument(
             '--dtype',
             type=str,
@@ -403,6 +421,7 @@ def add_cli_args(
             default=None,
             help='The number of speculative tokens to sample from '
             'the draft model in speculative decoding')
+        parser = TensorizerArgs.add_cli_args(parser)
         return parser
 
     @classmethod
@@ -465,6 +484,17 @@ def create_engine_config(self, ) -> EngineConfig:
             max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
             and self.max_cpu_loras > 0 else None) if self.enable_lora else None
 
+        tensorizer_config = TensorizerConfig(
+            tensorizer_uri=self.tensorizer_uri,
+            vllm_tensorized=self.vllm_tensorized,
+            verify_hash=self.verify_hash,
+            num_readers=self.num_readers,
+            encryption_keyfile=self.encryption_keyfile,
+            s3_access_key_id=self.s3_access_key_id,
+            s3_secret_access_key=self.s3_secret_access_key,
+            s3_endpoint=self.s3_endpoint,
+        )
+
         if self.image_input_type:
             if (not self.image_token_id or not self.image_input_shape
                     or not self.image_feature_size):
@@ -488,7 +518,8 @@ def create_engine_config(self, ) -> EngineConfig:
                             device_config=device_config,
                             lora_config=lora_config,
                             vision_language_config=vision_language_config,
-                            speculative_config=speculative_config)
+                            speculative_config=speculative_config,
+                            tensorizer_config=tensorizer_config)
 
 
 @dataclass
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index a91629a630591..8c37c5a9d6ee9 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -6,7 +6,7 @@
 import vllm
 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
-                         VisionLanguageConfig)
+                         TensorizerConfig, VisionLanguageConfig)
 from vllm.core.scheduler import Scheduler, SchedulerOutputs
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.metrics import StatLogger, Stats
@@ -74,6 +74,7 @@ def __init__(
         lora_config: Optional[LoRAConfig],
         vision_language_config: Optional[VisionLanguageConfig],
         speculative_config: Optional[SpeculativeConfig],
+        tensorizer_config: Optional[TensorizerConfig],
         executor_class: Type[ExecutorBase],
         log_stats: bool,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
@@ -110,6 +111,7 @@ def __init__(
         self.scheduler_config = scheduler_config
         self.device_config = device_config
         self.speculative_config = speculative_config
+        self.tensorizer_config = tensorizer_config
         self.log_stats = log_stats
 
         self._init_tokenizer()
@@ -125,6 +127,7 @@ def __init__(
             lora_config=lora_config,
             vision_language_config=vision_language_config,
             speculative_config=speculative_config,
+            tensorizer_config=tensorizer_config,
         )
 
         self._initialize_kv_caches()
@@ -264,6 +267,9 @@ def _init_tokenizer(self, **tokenizer_init_kwargs):
     def _verify_args(self) -> None:
         self.model_config.verify_with_parallel_config(self.parallel_config)
         self.cache_config.verify_with_parallel_config(self.parallel_config)
+        if self.tensorizer_config:
+            self.tensorizer_config.verify_with_parallel_config(
+                self.parallel_config)
         if self.lora_config:
             self.lora_config.verify_with_model_config(self.model_config)
             self.lora_config.verify_with_scheduler_config(
diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py
index f20221a0b941a..30577ecf62faa 100644
--- a/vllm/executor/gpu_executor.py
+++ b/vllm/executor/gpu_executor.py
@@ -2,7 +2,7 @@
 
 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
-                         VisionLanguageConfig)
+                         TensorizerConfig, VisionLanguageConfig)
 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -15,17 +15,14 @@
 
 class GPUExecutor(ExecutorBase):
 
-    def __init__(
-        self,
-        model_config: ModelConfig,
-        cache_config: CacheConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        device_config: DeviceConfig,
-        lora_config: Optional[LoRAConfig],
-        vision_language_config: Optional[VisionLanguageConfig],
-        speculative_config: Optional[SpeculativeConfig],
-    ) -> None:
+    def __init__(self, model_config: ModelConfig, cache_config: CacheConfig,
+                 parallel_config: ParallelConfig,
+                 scheduler_config: SchedulerConfig,
+                 device_config: DeviceConfig,
+                 lora_config: Optional[LoRAConfig],
+                 vision_language_config: Optional[VisionLanguageConfig],
+                 speculative_config: Optional[SpeculativeConfig],
+                 tensorizer_config: Optional[TensorizerConfig]) -> None:
         self.model_config = model_config
         self.cache_config = cache_config
         self.lora_config = lora_config
@@ -33,6 +30,7 @@ def __init__(
         self.scheduler_config = scheduler_config
         self.device_config = device_config
         self.vision_language_config = vision_language_config
+        self.tensorizer_config = tensorizer_config
 
         assert (not speculative_config
                 ), "Speculative decoding not yet supported for GPU backend"
@@ -61,6 +59,7 @@ def _init_worker(self):
             distributed_init_method=distributed_init_method,
             lora_config=self.lora_config,
             vision_language_config=self.vision_language_config,
+            tensorizer_config=self.tensorizer_config,
             is_driver_worker=True,
         )
         self.driver_worker.init_device()
diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index b937693c92257..28dc3e0db312a 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -7,7 +7,7 @@
 
 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
-                         VisionLanguageConfig)
+                         TensorizerConfig, VisionLanguageConfig)
 from vllm.engine.ray_utils import RayWorkerVllm, ray
 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
@@ -42,6 +42,7 @@ def __init__(
         lora_config: Optional[LoRAConfig],
         vision_language_config: Optional[VisionLanguageConfig],
         speculative_config: Optional[SpeculativeConfig],
+        tensorizer_config: Optional[TensorizerConfig],
     ) -> None:
         self.model_config = model_config
         self.cache_config = cache_config
@@ -50,6 +51,7 @@ def __init__(
         self.scheduler_config = scheduler_config
         self.device_config = device_config
         self.vision_language_config = vision_language_config
+        self.tensorizer_config = tensorizer_config
         assert (not speculative_config
                 ), "Speculative decoding not yet supported for RayGPU backend."
 
@@ -171,6 +173,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                     distributed_init_method=distributed_init_method,
                     lora_config=lora_config,
                     vision_language_config=vision_language_config,
+                    tensorizer_config=self.tensorizer_config,
                 ))
 
         # Initialize the driver worker with the Worker class.
@@ -187,6 +190,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
             distributed_init_method=distributed_init_method,
             lora_config=self.lora_config,
             vision_language_config=self.vision_language_config,
+            tensorizer_config=self.tensorizer_config,
             is_driver_worker=True,
         )
 
diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py
index 2745dbd89ab0f..c70ca48bca70a 100644
--- a/vllm/model_executor/model_loader.py
+++ b/vllm/model_executor/model_loader.py
@@ -3,11 +3,14 @@
 from typing import Tuple, Type
 
 import torch
-import torch.nn as nn
+from torch import nn
 
 from vllm.config import DeviceConfig, ModelConfig
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.models.llava import LlavaForConditionalGeneration
+from vllm.model_executor.tensorizer_loader import (
+    ParameterizedLoadFormat, is_vllm_serialized_tensorizer,
+    load_with_tensorizer)
 from vllm.model_executor.weight_utils import (get_quant_config,
                                               initialize_dummy_weights)
 
@@ -51,6 +54,7 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig,
               **kwargs) -> nn.Module:
     lora_config = kwargs.get("lora_config", None)
     vision_language_config = kwargs.get("vision_language_config", None)
+    tensorizer_config = kwargs.get("tensorizer_config", None)
     model_class = _get_model_architecture(model_config)[0]
 
     # Get the (maybe quantized) linear method.
@@ -71,33 +75,54 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig,
                 f"{model_config.dtype} is not supported for quantization "
                 f"method {model_config.quantization}. Supported dtypes: "
                 f"{supported_dtypes}")
+
         linear_method = quant_config.get_linear_method()
 
     with _set_default_torch_dtype(model_config.dtype):
         # Create a model instance.
         # The weights will be initialized as empty tensors.
+        extra_kwargs = {}
+        if hasattr(model_class, "supported_lora_modules"):
+            extra_kwargs["lora_config"] = lora_config
+        elif lora_config:
+            raise ValueError(
+                f"Model {model_class.__name__} does not support LoRA, "
+                "but LoRA is enabled. Support for this model may "
+                "be added in the future. If this is important to you, "
+                "please open an issue on github.")
+        elif model_class in _VISION_MODEL_CLASSES:
+            extra_kwargs["vision_language_config"] = vision_language_config
+
         with torch.device(device_config.device):
-            if hasattr(model_class, "supported_lora_modules"):
-                model = model_class(model_config.hf_config, linear_method,
-                                    lora_config)
-            elif lora_config:
-                raise ValueError(
-                    f"Model {model_class.__name__} does not support LoRA, "
-                    "but LoRA is enabled. Support for this model may "
-                    "be added in the future. If this is important to you, "
-                    "please open an issue on github.")
-            else:
-                if model_class not in _VISION_MODEL_CLASSES:
-                    model = model_class(model_config.hf_config, linear_method)
-                else:
-                    model = model_class(model_config.hf_config,
-                                        vision_language_config, linear_method)
+            if (model_config.load_format == "tensorizer"
+                    and is_vllm_serialized_tensorizer(tensorizer_config)):
+                extra_kwargs["linear_method"] = linear_method
+                tensorizer_config.model_class = model_class
+                tensorizer_config.hf_config = model_config.hf_config
+                tensorizer_config.dtype = model_config.dtype
+                model = load_with_tensorizer(tensorizer_config, **extra_kwargs)
+                return model.eval()
+            model = model_class(config=model_config.hf_config,
+                                linear_method=linear_method,
+                                **extra_kwargs)
         if model_config.load_format == "dummy":
             # NOTE(woosuk): For accurate performance evaluation, we assign
             # random values to the weights.
             initialize_dummy_weights(model)
         else:
             # Load the weights from the cached or downloaded files.
-            model.load_weights(model_config.model, model_config.download_dir,
-                               model_config.load_format, model_config.revision)
+            if model_config.load_format == "tensorizer":
+                # Provide a dynamic load format for `model.load_weights`
+                # to retain tensorizer args from CLI.
+                model_config.load_format = ParameterizedLoadFormat(
+                    model_config.load_format)
+                model_config.load_format.params = (
+                    tensorizer_config._construct_tensorizer_args())
+
+            model.load_weights(
+                model_config.model,
+                model_config.download_dir,
+                model_config.load_format,
+                model_config.revision,
+            )
     return model.eval()
diff --git a/vllm/model_executor/tensorizer_loader.py b/vllm/model_executor/tensorizer_loader.py
new file mode 100644
index 0000000000000..ed3ad9e2ffa15
--- /dev/null
+++ b/vllm/model_executor/tensorizer_loader.py
@@ -0,0 +1,319 @@
+import argparse
+import dataclasses
+import io
+import os
+import time
+import typing
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from vllm.config import TensorizerConfig
+from vllm.logger import init_logger
+from vllm.model_executor.layers.linear import LinearMethodBase
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding)
+
+tensorizer_load_fail = False
+
+try:
+    from tensorizer import (DecryptionParams, EncryptionParams,
+                            TensorDeserializer, TensorSerializer)
+    from tensorizer.stream_io import open_stream
+    from tensorizer.utils import (convert_bytes, get_mem_usage,
+                                  no_init_or_tensor)
+except ImportError:
+    tensorizer_load_fail = True
+
+__all__ = [
+    'EncryptionParams', 'DecryptionParams', 'TensorDeserializer',
+    'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage',
+    'no_init_or_tensor'
+]
+
+logger = init_logger(__name__)
+
+
+def load_with_tensorizer(tensorizer_config: TensorizerConfig,
+                         **extra_kwargs) -> nn.Module:
+    tensorizer = TensorizerAgent(tensorizer_config, **extra_kwargs)
+    return tensorizer.deserialize()
+
+
+def tensorizer_warning(message: str):
+    return warnings.warn(message, category=PerformanceWarning, stacklevel=2)
+
+
+def is_vllm_serialized_tensorizer(tensorizer_config: TensorizerConfig) -> bool:
+    if tensorizer_config is None:
+        return False
+    return tensorizer_config.vllm_tensorized
+
+
+class ParameterizedLoadFormat(str):
+    __slots__ = "params"
+
+
+class PerformanceWarning(UserWarning):
+
+    def __str__(self):
+        return (f"{super().__str__()}"
+                " (set the VLLM_SILENCE_PERFORMANCE_WARNINGS"
+                " environment variable to hide this)")
+
+
+if (os.getenv("VLLM_SILENCE_PERFORMANCE_WARNINGS", "").lower()
+        not in ("", "0", "n", "no", "off", "disable")):
+    warnings.simplefilter("ignore", category=PerformanceWarning)
+
+
+@dataclass
+class TensorizerArgs:
+    tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
+                          str, bytes, os.PathLike, int]
+    vllm_tensorized: bool
+    verify_hash: Optional[bool] = False
+    num_readers: Optional[int] = 1
+    encryption_keyfile: Optional[str] = None
+    s3_access_key_id: Optional[str] = None
+    s3_secret_access_key: Optional[str] = None
+    s3_endpoint: Optional[str] = None
+    """
+  Args for the TensorizerAgent class. These are used to configure the behavior 
+  of the TensorDeserializer when loading tensors from a serialized model.
+  
+  Args:
+      tensorizer_uri: Path to serialized model tensors. Can be a local file 
+          path or a S3 URI.
+      vllm_tensorized: If True, indicates that the serialized model is a 
+          vLLM model. This is used to determine the behavior of the 
+          TensorDeserializer when loading tensors from a serialized model.
+          It is far faster to deserialize a vLLM model as it utilizes
+          tensorizer's optimized GPU loading.
+      verify_hash: If True, the hashes of each tensor will be verified against 
+          the hashes stored in the metadata. A `HashMismatchError` will be 
+          raised if any of the hashes do not match.
+      num_readers: Controls how many threads are allowed to read concurrently
+          from the source file. Default is 1. This greatly increases
+          performance.
+      encryption_keyfile: File path to a binary file containing a  
+          binary key to use for decryption. `None` (the default) means 
+          no decryption. See the example script in 
+          examples/tensorize_vllm_model.py. 
+      s3_access_key_id: The access key for the S3 bucket. Can also be set via
+          the S3_ACCESS_KEY_ID environment variable.
+      s3_secret_access_key: The secret access key for the S3 bucket. Can also
+          be set via the S3_SECRET_ACCESS_KEY environment variable.
+      s3_endpoint: The endpoint for the S3 bucket. Can also be set via the
+          S3_ENDPOINT_URL environment variable.
+  """
+
+    def __post_init__(self):
+        self.file_obj = self.tensorizer_uri
+        self.s3_access_key_id = (self.s3_access_key_id
+                                 or os.environ.get("S3_ACCESS_KEY_ID")) or None
+        self.s3_secret_access_key = (
+            self.s3_secret_access_key
+            or os.environ.get("S3_SECRET_ACCESS_KEY")) or None
+        self.s3_endpoint = (self.s3_endpoint
+                            or os.environ.get("S3_ENDPOINT_URL")) or None
+        self.stream_params = {
+            "s3_access_key_id": self.s3_access_key_id,
+            "s3_secret_access_key": self.s3_secret_access_key,
+            "s3_endpoint": self.s3_endpoint,
+        }
+
+        # Omitting self.dtype and self.device as this behaves weirdly
+        self.deserializer_params = {
+            "verify_hash": self.verify_hash,
+            "encryption": self.encryption_keyfile,
+            "num_readers": self.num_readers
+        }
+        if self.encryption_keyfile:
+            with open_stream(
+                    self.encryption_keyfile,
+                    **self.stream_params,
+            ) as stream:
+                key = stream.read()
+                decryption_params = DecryptionParams.from_key(key)
+                self.deserializer_params['encryption'] = decryption_params
+
+    def add_cli_args(
+            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
+        """Tensorizer CLI arguments"""
+
+        # Create the argument group
+        group = parser.add_argument_group(
+            'tensorizer options',
+            description=('Options for configuring the behavior of the'
+                         ' tensorizer deserializer when '
+                         '--load-format=tensorizer'))
+
+        group.add_argument(
+            "--tensorizer-uri",
+            help="Path to serialized model tensors. Can be a local file path,"
+            " or an HTTP(S) or S3 URI.",
+        )
+        group.add_argument(
+            "--verify-hash",
+            action="store_true",
+            help="If enabled, the hashes of each tensor will be verified"
+            " against the hashes stored in the file metadata. An exception"
+            " will be raised if any of the hashes do not match.",
+        )
+        group.add_argument(
+            "--encryption-keyfile",
+            default=None,
+            help="The file path to a binary file containing a binary key to "
+            "use for decryption. Can be a file path or S3 network URI.")
+        group.add_argument(
+            "--num-readers",
+            default=1,
+            type=int,
+            help="Controls how many threads are allowed to read concurrently "
+            "from the source file.")
+        group.add_argument(
+            "--s3-access-key-id",
+            default=None,
+            help="The access key for the S3 bucket. Can also be set via the "
+            "S3_ACCESS_KEY_ID environment variable.",
+        )
+        group.add_argument(
+            "--s3-secret-access-key",
+            default=None,
+            help="The secret access key for the S3 bucket. Can also be set via "
+            "the S3_SECRET_ACCESS_KEY environment variable.",
+        )
+        group.add_argument(
+            "--s3-endpoint",
+            default=None,
+            help="The endpoint for the S3 bucket. Can also be set via the "
+            "S3_ENDPOINT_URL environment variable.",
+        )
+        group.add_argument(
+            "--vllm-tensorized",
+            action="store_true",
+            help="If enabled, indicates that the serialized model is a vLLM "
+            "model. This is used to determine the behavior of the "
+            "TensorDeserializer when loading tensors from a "
+            "serialized model.")
+
+        return parser
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs":
+        # Get the list of attributes of this dataclass.
+        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        # Set the attributes from the parsed arguments.
+        tensorizer_args = cls(**{
+            attr: getattr(args, attr)
+            for attr in attrs if hasattr(args, attr)
+        })
+        return tensorizer_args
+
+
+class TensorizerAgent:
+    """
+    A class for performing tensorizer deserializations specifically for
+    vLLM models using plaid_mode. Uses TensorizerArgs to configure the
+    behavior of the TensorDeserializer when loading tensors from a serialized
+    model. For deserializations of HuggingFace models, TensorDeserializer is
+    instead used as an iterator directly in the func hf_model_weights_iterator
+    in vllm/model_executor/weight_utils.py
+    """
+
+    def __init__(self, tensorizer_config: TensorizerConfig,
+                 linear_method: LinearMethodBase, **extra_kwargs):
+        self.tensorizer_config = tensorizer_config
+        self.tensorizer_args = (
+            self.tensorizer_config._construct_tensorizer_args())
+        self.extra_kwargs = extra_kwargs
+        if extra_kwargs.get("linear_method", None) is not None:
+            self.linear_method = extra_kwargs["linear_method"]
+        else:
+            self.linear_method = linear_method
+        self.model = self._init_model()
+
+        if tensorizer_load_fail:
+            raise ImportError(
+                "Tensorizer is not installed. Please install tensorizer "
+                "to use this feature with `pip install vllm[tensorizer]`.")
+
+    def _init_model(self):
+        model_args = self.tensorizer_config.hf_config
+        model_args.torch_dtype = self.tensorizer_config.dtype
+        with no_init_or_tensor():
+            return self.tensorizer_config.model_class(
+                config=model_args,
+                linear_method=self.linear_method,
+                **self.extra_kwargs)
+
+    def _resize_lora_embeddings(self):
+        """Modify LoRA embedding layers to use bigger tensors
+        to allow for adapter added tokens."""
+        for child in self.model.modules():
+            if (isinstance(child, VocabParallelEmbedding)
+                    and child.weight.shape[0] <
+                    child.num_embeddings_per_partition):
+                new_weight = torch.empty(child.num_embeddings_per_partition,
+                                         child.embedding_dim,
+                                         dtype=child.weight.dtype,
+                                         device=child.weight.device)
+                new_weight[:child.weight.shape[0]].copy_(child.weight.data)
+                new_weight[child.weight.shape[0]:].fill_(0)
+                child.weight.data = new_weight
+
+    def _check_tensors_on_meta_device(self):
+        for tensor in self.model.state_dict().values():
+            if tensor.device.type == 'meta':
+                raise ValueError(
+                    "The serialized model contains tensors on the meta device,"
+                    " indicating that some tensors were not loaded properly."
+                    " Please check that the parameters of the model being"
+                    " specified match that of the serialized model, such as"
+                    " its quantization.")
+
+    def deserialize(self):
+        """
+        Deserialize the model using the TensorDeserializer. This method is
+        specifically for vLLM models using tensorizer's plaid_mode.
+
+        The deserializer makes use of tensorizer_args.stream_params
+        to configure the behavior of the stream when loading tensors from a
+        serialized model. The deserializer_params are used to configure the
+        behavior of the TensorDeserializer when loading tensors themselves.
+        Documentation on these params can be found in TensorizerArgs
+
+        Returns:
+            nn.Module: The deserialized model.
+        """
+        before_mem = get_mem_usage()
+        # Lazy load the tensors from S3 into the model.
+        start = time.perf_counter()
+        with open_stream(
+                self.tensorizer_args.tensorizer_uri,
+                mode="rb",
+                **self.tensorizer_args.stream_params,
+        ) as stream, TensorDeserializer(
+                stream,
+                dtype=self.tensorizer_config.dtype,
+                **self.tensorizer_args.deserializer_params) as deserializer:
+            deserializer.load_into_module(self.model)
+            end = time.perf_counter()
+
+        total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
+        duration = end - start
+        per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
+        after_mem = get_mem_usage()
+        deserializer.close()
+        logger.info(f"Deserialized {total_bytes_str} in "
+                    f"{end - start:0.2f}s, {per_second}/s")
+        logger.info(f"Memory usage before: {before_mem}")
+        logger.info(f"Memory usage after: {after_mem}")
+
+        self._check_tensors_on_meta_device()
+        self._resize_lora_embeddings()
+        return self.model.eval()
diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py
index 0961478930d74..08425604f0511 100644
--- a/vllm/model_executor/weight_utils.py
+++ b/vllm/model_executor/weight_utils.py
@@ -5,7 +5,7 @@
 import json
 import os
 from collections import defaultdict
-from typing import Any, Iterable, Iterator, List, Optional, Tuple
+from typing import Any, Iterable, Iterator, List, Optional, Tuple, Union
 
 import filelock
 import huggingface_hub.constants
@@ -161,7 +161,8 @@ def prepare_hf_model_weights(
     revision: Optional[str] = None,
 ) -> Tuple[str, List[str], bool]:
     # Download model weights from huggingface.
-    is_local = os.path.isdir(model_name_or_path)
+    is_local = os.path.isdir(model_name_or_path) \
+               and load_format != "tensorizer"
     use_safetensors = False
     # Some quantized models use .pt files for storing the weights.
     if load_format == "auto":
@@ -173,13 +174,15 @@ def prepare_hf_model_weights(
         allow_patterns = ["*.pt"]
     elif load_format == "npcache":
         allow_patterns = ["*.bin"]
+    elif load_format == "tensorizer":
+        allow_patterns = ["*.tensors"]
     else:
         raise ValueError(f"Unknown load_format: {load_format}")
 
     if fall_back_to_pt:
         allow_patterns += ["*.pt"]
 
-    if not is_local:
+    if not is_local and load_format != "tensorizer":
         # Before we download we look at that is available:
         fs = HfFileSystem()
         file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
@@ -224,6 +227,9 @@ def prepare_hf_model_weights(
             if not any(f.endswith(x) for x in blacklist)
         ]
 
+    if load_format == "tensorizer":
+        return hf_folder, hf_weights_files, use_safetensors
+
     if len(hf_weights_files) == 0:
         raise RuntimeError(
             f"Cannot find any model weights with `{model_name_or_path}`")
@@ -234,7 +240,7 @@ def prepare_hf_model_weights(
 def hf_model_weights_iterator(
     model_name_or_path: str,
     cache_dir: Optional[str] = None,
-    load_format: str = "auto",
+    load_format: Union[Tuple, str] = "auto",
     revision: Optional[str] = None,
     fall_back_to_pt: Optional[bool] = True,
 ) -> Iterator[Tuple[str, torch.Tensor]]:
@@ -277,6 +283,26 @@ def hf_model_weights_iterator(
             with open(param_path, "rb") as f:
                 param = np.load(f)
             yield name, torch.from_numpy(param)
+    elif load_format == "tensorizer":
+        from vllm.model_executor.tensorizer_loader import (TensorDeserializer,
+                                                           open_stream,
+                                                           tensorizer_warning)
+        tensorizer_args = load_format.params
+        tensorizer_warning(
+            "Deserializing HuggingFace models is not optimized for "
+            "loading on vLLM, as tensorizer is forced to load to CPU. "
+            "Consider deserializing a vLLM model instead for faster "
+            "load times. See the examples/tensorize_vllm_model.py example "
+            "script for serializing vLLM models.")
+
+        deserializer_args = tensorizer_args.deserializer_params
+        stream_params = tensorizer_args.stream_params
+        stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
+        with TensorDeserializer(stream, **deserializer_args,
+                                device="cpu") as state:
+            for name, param in state.items():
+                yield name, param
+        del state
     elif use_safetensors:
         for st_file in hf_weights_files:
             with safe_open(st_file, framework="pt") as f:
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 47ad8f0c9b78b..7dbe14ead0976 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -10,7 +10,8 @@
 from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
                             get_attn_backend)
 from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
-                         SchedulerConfig, VisionLanguageConfig)
+                         SchedulerConfig, TensorizerConfig,
+                         VisionLanguageConfig)
 from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
 from vllm.distributed.device_communicators import (custom_all_reduce,
                                                    pynccl_utils)
@@ -111,11 +112,13 @@ def __init__(
         kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
         vision_language_config: Optional[VisionLanguageConfig] = None,
+        tensorizer_config: Optional[TensorizerConfig] = None,
     ):
         self.model_config = model_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.lora_config = lora_config
+        self.tensorizer_config = tensorizer_config
         self.is_driver_worker = is_driver_worker
 
         # model_config can be None in tests/samplers/test_sampler.py.
@@ -158,7 +161,9 @@ def load_model(self) -> None:
                 lora_config=self.lora_config,
                 vision_language_config=self.vision_language_config,
                 parallel_config=self.parallel_config,
-                scheduler_config=self.scheduler_config)
+                scheduler_config=self.scheduler_config,
+                tensorizer_config=self.tensorizer_config,
+            )
 
         self.model_memory_usage = m.consumed_memory
         logger.info(f"Loading model weights took "
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 3f0b2fd83f3e5..82491c6df6616 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -7,7 +7,8 @@
 import torch.distributed
 
 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
+                         ParallelConfig, SchedulerConfig, TensorizerConfig,
+                         VisionLanguageConfig)
 from vllm.distributed import (broadcast_tensor_dict,
                               ensure_model_parallel_initialized,
                               init_distributed_environment)
@@ -42,6 +43,7 @@ def __init__(
         distributed_init_method: str,
         lora_config: Optional[LoRAConfig] = None,
         vision_language_config: Optional[VisionLanguageConfig] = None,
+        tensorizer_config: Optional[TensorizerConfig] = None,
         is_driver_worker: bool = False,
     ) -> None:
         self.model_config = model_config
@@ -53,6 +55,7 @@ def __init__(
         self.rank = rank
         self.distributed_init_method = distributed_init_method
         self.lora_config = lora_config
+        self.tensorizer_config = tensorizer_config
         self.is_driver_worker = is_driver_worker
         if self.is_driver_worker:
             assert self.rank == 0, "The driver worker must have rank 0."
@@ -70,7 +73,9 @@ def __init__(
             lora_config=self.lora_config,
             kv_cache_dtype=self.cache_config.cache_dtype,
             is_driver_worker=is_driver_worker,
-            vision_language_config=vision_language_config)
+            vision_language_config=vision_language_config,
+            tensorizer_config=tensorizer_config,
+        )
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
         self.cache_engine = None