forked from mesolitica/vllm-whisper
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Core] Implement sharded state loader (vllm-project#4690)
Co-authored-by: Woosuk Kwon <[email protected]>
- Loading branch information
1 parent
6643d72
commit 5d94657
Showing
7 changed files
with
351 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
""" | ||
Saves each worker's model state dict directly to a checkpoint, which enables a | ||
fast load path for large tensor-parallel models where each worker only needs to | ||
read its own shard rather than the entire checkpoint. | ||
Example usage: | ||
python save_sharded_state.py \ | ||
--model /path/to/load \ | ||
--quantization deepspeedfp \ | ||
--tensor-parallel-size 8 \ | ||
--output /path/to/save | ||
Then, the model can be loaded with | ||
llm = LLM( | ||
model="/path/to/save", | ||
load_format="sharded_state", | ||
quantization="deepspeedfp", | ||
tensor_parallel_size=8, | ||
) | ||
""" | ||
import argparse | ||
import dataclasses | ||
import os | ||
import shutil | ||
from pathlib import Path | ||
|
||
from vllm import LLM, EngineArgs | ||
|
||
parser = argparse.ArgumentParser() | ||
EngineArgs.add_cli_args(parser) | ||
parser.add_argument("--output", | ||
"-o", | ||
required=True, | ||
type=str, | ||
help="path to output checkpoint") | ||
parser.add_argument("--file-pattern", | ||
type=str, | ||
help="string pattern of saved filenames") | ||
parser.add_argument("--max-file-size", | ||
type=str, | ||
default=5 * 1024**3, | ||
help="max size (in bytes) of each safetensors file") | ||
|
||
|
||
def main(args): | ||
engine_args = EngineArgs.from_cli_args(args) | ||
if engine_args.enable_lora: | ||
raise ValueError("Saving with enable_lora=True is not supported!") | ||
model_path = engine_args.model | ||
if not Path(model_path).is_dir(): | ||
raise ValueError("model path must be a local directory") | ||
# Create LLM instance from arguments | ||
llm = LLM(**dataclasses.asdict(engine_args)) | ||
# Prepare output directory | ||
Path(args.output).mkdir(exist_ok=True) | ||
# Dump worker states to output directory | ||
model_executor = llm.llm_engine.model_executor | ||
model_executor.save_sharded_state(path=args.output, | ||
pattern=args.file_pattern, | ||
max_size=args.max_file_size) | ||
# Copy metadata files to output directory | ||
for file in os.listdir(model_path): | ||
if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"): | ||
if os.path.isdir(os.path.join(model_path, file)): | ||
shutil.copytree(os.path.join(model_path, file), | ||
os.path.join(args.output, file)) | ||
else: | ||
shutil.copy(os.path.join(model_path, file), args.output) | ||
|
||
|
||
if __name__ == "__main__": | ||
args = parser.parse_args() | ||
main(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import os | ||
import shutil | ||
from tempfile import TemporaryDirectory | ||
|
||
import pytest | ||
import torch | ||
from huggingface_hub import snapshot_download | ||
|
||
from vllm import LLM, SamplingParams | ||
from vllm.model_executor.model_loader.loader import ShardedStateLoader | ||
|
||
prompts = [ | ||
"Hello, my name is", | ||
"The president of the United States is", | ||
"The capital of France is", | ||
"The future of AI is", | ||
] | ||
|
||
# Create a sampling params object. | ||
sampling_params = SamplingParams( | ||
temperature=0.8, | ||
top_p=0.95, | ||
seed=0, | ||
max_tokens=256, | ||
ignore_eos=True, | ||
) | ||
|
||
|
||
def test_filter_subtensors(): | ||
state_dict = { | ||
"a": torch.empty(2), | ||
"b": torch.empty((2, 4)), | ||
"c": torch.empty((2, 4, 8)), | ||
} | ||
state_dict.update({ | ||
"x": state_dict["b"], | ||
"y": state_dict["c"][1, 2, :], | ||
"z": state_dict["c"][1, :, 4], | ||
}) | ||
filtered_state_dict = ShardedStateLoader._filter_subtensors(state_dict) | ||
assert tuple(filtered_state_dict.keys()) == ("a", "b", "c") | ||
for key, tensor in filtered_state_dict.items(): | ||
assert tensor.equal(state_dict[key]) | ||
|
||
|
||
@pytest.mark.parametrize("enable_lora", [False, True]) | ||
def test_sharded_state_loader(enable_lora): | ||
weights_patterns = ("*.bin", "*.pt", "*.safetensors") | ||
|
||
with TemporaryDirectory() as cache_dir, TemporaryDirectory() as output_dir: | ||
input_dir = snapshot_download("meta-llama/Llama-2-7b-hf", | ||
cache_dir=cache_dir) | ||
|
||
llm = LLM( | ||
model=input_dir, | ||
worker_use_ray=True, | ||
gpu_memory_utilization=0.3, | ||
) | ||
|
||
# Dump worker states to output directory | ||
model_executor = llm.llm_engine.model_executor | ||
model_executor.save_sharded_state(path=output_dir) | ||
# Copy metadata files to output directory | ||
for file in os.listdir(input_dir): | ||
if not any(file.endswith(ext) for ext in weights_patterns): | ||
shutil.copy(f"{input_dir}/{file}", output_dir) | ||
del llm.llm_engine.model_executor | ||
|
||
llm_before = LLM( | ||
model=input_dir, | ||
worker_use_ray=True, | ||
enable_lora=enable_lora, | ||
gpu_memory_utilization=0.3, | ||
) | ||
gen_before = llm_before.generate(prompts, sampling_params) | ||
out_before = [gen.outputs[0].__dict__ for gen in gen_before] | ||
del llm_before.llm_engine.model_executor | ||
|
||
llm_after = LLM( | ||
model=output_dir, | ||
worker_use_ray=True, | ||
enable_lora=enable_lora, | ||
gpu_memory_utilization=0.3, | ||
load_format="sharded_state", | ||
) | ||
gen_after = llm_after.generate(prompts, sampling_params) | ||
out_after = [gen.outputs[0].__dict__ for gen in gen_after] | ||
del llm_after.llm_engine.model_executor | ||
|
||
assert out_before == out_after |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters