[Frontend] [Core] perf: Automatically detect vLLM-tensorized model, update `tensorizer` to version 2.9.0 (#4208)
sangstar authored May 13, 2024
1 parent 0fca3cd commit 8bc68e1
Showing 10 changed files with 259 additions and 523 deletions.
4 changes: 3 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -60,11 +60,13 @@ steps:
mirror_hardwares: [amd]
commands:
# install aws cli for llava_example.py
- pip install awscli
# install tensorizer for tensorize_vllm_model.py
- pip install awscli tensorizer
- python3 offline_inference.py
- python3 offline_inference_with_prefix.py
- python3 llm_engine_example.py
- python3 llava_example.py
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors

- label: Kernels Test %N
command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
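The deserialize path in the pipeline command above follows the directory layout the example script builds from its CLI arguments. A small sketch of that construction (values mirror the CI invocation; see the base_path assembly in examples/tensorize_vllm_model.py below):

# Sketch of how the example script derives the tensors path consumed by the
# deserialize step in the pipeline command above.
serialized_directory = "/tmp"            # --serialized-directory, trailing '/' stripped
model_ref = "facebook/opt-125m"          # --model
suffix = "v1"                            # --suffix
model_path = f"{serialized_directory}/vllm/{model_ref}/{suffix}/model.tensors"
assert model_path == "/tmp/vllm/facebook/opt-125m/v1/model.tensors"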
200 changes: 81 additions & 119 deletions examples/tensorize_vllm_model.py
@@ -1,23 +1,20 @@
import argparse
import dataclasses
import json
import os
import time
import uuid
from functools import partial
from typing import Type

import torch
import torch.nn as nn
from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
TensorSerializer, stream_io)
from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
from transformers import AutoConfig, PretrainedConfig
from tensorizer import stream_io

from vllm.distributed import initialize_model_parallel
from vllm import LLM
from vllm.distributed import (init_distributed_environment,
initialize_model_parallel)
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.model_executor.model_loader.tensorizer import TensorizerArgs
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.model_loader.tensorizer import (TensorizerArgs,
TensorizerConfig,
serialize_vllm_model)

# yapf conflicts with isort for this docstring
# yapf: disable
@@ -27,25 +24,25 @@
to the GPU extremely quickly over an HTTP/HTTPS endpoint, an S3 endpoint,
or locally. Tensor encryption and decryption is also supported, although
libsodium must be installed to use it. Install vllm with tensorizer support
using `pip install vllm[tensorizer]`.
using `pip install vllm[tensorizer]`. To learn more about tensorizer, visit
https://github.com/coreweave/tensorizer
To serialize a model, install vLLM from source, then run something
like this from the root level of this repository:
python -m examples.tensorize_vllm_model \
--model EleutherAI/gpt-j-6B \
--dtype float16 \
--model facebook/opt-125m \
serialize \
--serialized-directory s3://my-bucket/ \
--suffix vllm
--serialized-directory s3://my-bucket \
--suffix v1
Which downloads the model from HuggingFace, loads it into vLLM, serializes it,
and saves it to your S3 bucket. A local directory can also be used. This
assumes your S3 credentials are specified as environment variables
in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`.
To provide S3 credentials directly, you can provide `--s3-access-key-id` and
`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this
script.
in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and
`S3_ENDPOINT_URL`. To provide S3 credentials directly, you can provide
`--s3-access-key-id` and `--s3-secret-access-key`, as well as `--s3-endpoint`
as CLI args to this script.
You can also encrypt the model weights with a randomly-generated key by
providing a `--keyfile` argument.
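As a hedged illustration of the credential handling described above, the environment can be prepared like this before invoking the script (the variable names match those read in the script's __main__ block; the values are placeholders, not real credentials):

# Sketch only: export the S3 credentials the example script falls back to
# when no CLI credential flags are given. All values below are placeholders.
import os

os.environ["S3_ACCESS_KEY_ID"] = "<access-key-id>"
os.environ["S3_SECRET_ACCESS_KEY"] = "<secret-access-key>"
os.environ["S3_ENDPOINT_URL"] = "https://object-storage.example.com"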
@@ -57,7 +54,7 @@
--model EleutherAI/gpt-j-6B \
--dtype float16 \
deserialize \
--path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors
--path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors
Which downloads the model tensors from your S3 bucket and deserializes them.
@@ -71,26 +68,30 @@
`python -m examples.tensorize_vllm_model deserialize --help`.
Once a model is serialized, it can be used to load the model when running the
OpenAI inference client at `vllm/entrypoints/openai/api_server.py` by providing
the `--tensorizer-uri` CLI argument that is functionally the same as the
`--path-to-tensors` argument in this script, along with `--vllm-tensorized`, to
signify that the model to be deserialized is a vLLM model, rather than a
HuggingFace `PreTrainedModel`, which can also be deserialized using tensorizer
in the same inference server, albeit without the speed optimizations. To
deserialize an encrypted file, the `--encryption-keyfile` argument can be used
to provide the path to the keyfile used to encrypt the model weights. For
information on all the arguments that can be used to configure tensorizer's
deserialization, check out the tensorizer options argument group in the
`vllm/entrypoints/openai/api_server.py` script with `--help`.
Tensorizer can also be invoked with the `LLM` class directly to load models:
Once a model is serialized, tensorizer can be invoked with the `LLM` class
directly to load models:
llm = LLM(model="facebook/opt-125m",
load_format="tensorizer",
tensorizer_uri=path_to_opt_tensors,
num_readers=3,
vllm_tensorized=True)
model_loader_extra_config=TensorizerConfig(
tensorizer_uri = path_to_tensors,
num_readers=3,
)
)
A serialized model can be used during model loading for the vLLM OpenAI
inference server. `model_loader_extra_config` is exposed as the CLI arg
`--model-loader-extra-config`, and accepts a JSON string literal of the
TensorizerConfig arguments desired.
In order to see all of the available arguments usable to configure
loading with tensorizer that are given to `TensorizerConfig`, run:
`python -m examples.tensorize_vllm_model deserialize --help`
under the `tensorizer options` section. These can also be used for
deserialization in this example script, although `--tensorizer-uri` and
`--path-to-tensors` are functionally the same in this case.
"""


@@ -158,95 +159,35 @@ def parse_args():
help=("Path to a binary key to use to decrypt the model weights,"
" if the model was serialized with encryption"))

return parser.parse_args()


def make_model_contiguous(model):
# Ensure tensors are saved in memory contiguously
for param in model.parameters():
param.data = param.data.contiguous()


def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
architectures = getattr(config, "architectures", [])
for arch in architectures:
model_cls = ModelRegistry.load_model_cls(arch)
if model_cls is not None:
return model_cls
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")


def serialize():

eng_args_dict = {f.name: getattr(args, f.name) for f in
dataclasses.fields(EngineArgs)}
engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
engine = LLMEngine.from_engine_args(engine_args)
TensorizerArgs.add_cli_args(deserialize_parser)

model = (engine.model_executor.driver_worker.
model_runner.model)

encryption_params = EncryptionParams.random() if keyfile else None
if keyfile:
with _write_stream(keyfile) as stream:
stream.write(encryption_params.key)

with _write_stream(model_path) as stream:
serializer = TensorSerializer(stream, encryption=encryption_params)
serializer.write_module(model)
serializer.close()
return parser.parse_args()

print("Serialization complete. Model tensors saved to", model_path)
if keyfile:
print("Key saved to", keyfile)


def deserialize():
config = AutoConfig.from_pretrained(model_ref)

with no_init_or_tensor():
model_class = _get_vllm_model_architecture(config)
model = model_class(config)

before_mem = get_mem_usage()
start = time.time()

if keyfile:
with _read_stream(keyfile) as stream:
key = stream.read()
decryption_params = DecryptionParams.from_key(key)
tensorizer_args.deserializer_params['encryption'] = \
decryption_params

with (_read_stream(model_path)) as stream, TensorDeserializer(
stream, **tensorizer_args.deserializer_params) as deserializer:
deserializer.load_into_module(model)
end = time.time()

# Brag about how fast we are.
total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
duration = end - start
per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
after_mem = get_mem_usage()
print(
f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s"
llm = LLM(model=args.model,
load_format="tensorizer",
model_loader_extra_config=tensorizer_config
)
print(f"Memory usage before: {before_mem}")
print(f"Memory usage after: {after_mem}")
return llm

return model


args = parse_args()

s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID")
or None)
s3_secret_access_key = (args.s3_secret_access_key
or os.environ.get("S3_SECRET_ACCESS_KEY") or None)
s3_access_key_id = (getattr(args, 's3_access_key_id', None)
or os.environ.get("S3_ACCESS_KEY_ID", None))
s3_secret_access_key = (getattr(args, 's3_secret_access_key', None)
or os.environ.get("S3_SECRET_ACCESS_KEY", None))
s3_endpoint = (getattr(args, 's3_endpoint', None)
or os.environ.get("S3_ENDPOINT_URL", None))

s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None)
credentials = {
"s3_access_key_id": s3_access_key_id,
"s3_secret_access_key": s3_secret_access_key,
"s3_endpoint": s3_endpoint
}

_read_stream, _write_stream = (partial(
stream_io.open_stream,
@@ -263,20 +204,41 @@ def deserialize():
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "8080"

torch.distributed.init_process_group(world_size=1, rank=0)
init_distributed_environment(world_size=1, rank=0, local_rank=0)
initialize_model_parallel()

keyfile = args.keyfile if args.keyfile else None


if args.model_loader_extra_config:
config = json.loads(args.model_loader_extra_config)
tensorizer_args = TensorizerConfig(**config)._construct_tensorizer_args()
tensorizer_args.tensorizer_uri = args.path_to_tensors
else:
tensorizer_args = None

if args.command == "serialize":
eng_args_dict = {f.name: getattr(args, f.name) for f in
dataclasses.fields(EngineArgs)}

engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
engine = LLMEngine.from_engine_args(engine_args)

input_dir = args.serialized_directory.rstrip('/')
suffix = args.suffix if args.suffix else uuid.uuid4().hex
base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
model_path = f"{base_path}/model.tensors"
serialize()
tensorizer_config = TensorizerConfig(
tensorizer_uri=model_path,
**credentials)
serialize_vllm_model(engine, tensorizer_config, keyfile)
elif args.command == "deserialize":
tensorizer_args = TensorizerArgs.from_cli_args(args)
model_path = args.path_to_tensors
if not tensorizer_args:
tensorizer_config = TensorizerConfig(
tensorizer_uri=args.path_to_tensors,
encryption_keyfile = keyfile,
**credentials
)
deserialize()
else:
raise ValueError("Either serialize or deserialize must be specified.")
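
For reference, a minimal end-to-end sketch of the new API exercised by this example: it mirrors the `serialize_vllm_model`/`TensorizerConfig` calls and the `LLM(..., load_format="tensorizer")` usage shown above. The model name, /tmp output path, and prompt are placeholders, and the real example script runs serialization and loading as separate invocations:

# Sketch only, not the committed example: serialize a vLLM model with the new
# helpers introduced in this change, then load it back through the LLM API.
from vllm import LLM
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
                                                         serialize_vllm_model)

model_ref = "facebook/opt-125m"                         # placeholder model
model_path = f"/tmp/vllm/{model_ref}/v1/model.tensors"  # mirrors the layout above

# Build an engine the same way the example does, then serialize its weights.
# The third argument is the encryption keyfile; None means no encryption.
engine = LLMEngine.from_engine_args(EngineArgs(model=model_ref))
serialize_vllm_model(engine, TensorizerConfig(tensorizer_uri=model_path), None)

# Load the serialized weights back through the high-level LLM entrypoint.
# (In practice, as in the CI pipeline, this runs as a second process.)
llm = LLM(model=model_ref,
          load_format="tensorizer",
          model_loader_extra_config=TensorizerConfig(tensorizer_uri=model_path))
print(llm.generate("Hello, my name is")[0].outputs[0].text)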
2 changes: 1 addition & 1 deletion requirements-dev.txt
@@ -14,7 +14,7 @@ types-setuptools

# testing
pytest
tensorizer==2.9.0
tensorizer>=2.9.0
pytest-forked
pytest-asyncio
pytest-rerunfailures
2 changes: 1 addition & 1 deletion setup.py
@@ -426,7 +426,7 @@ def _read_requirements(filename: str) -> List[str]:
install_requires=get_requirements(),
ext_modules=ext_modules,
extras_require={
"tensorizer": ["tensorizer==2.9.0"],
"tensorizer": ["tensorizer>=2.9.0"],
},
cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {},
package_data=package_data,