Skip to content

Commit

Permalink
[Core] Implement sharded state loader (vllm-project#4690)
Browse files Browse the repository at this point in the history
Co-authored-by: Woosuk Kwon <[email protected]>
  • Loading branch information
aurickq and WoosukKwon authored May 16, 2024
1 parent 6643d72 commit 5d94657
Show file tree
Hide file tree
Showing 7 changed files with 351 additions and 0 deletions.
75 changes: 75 additions & 0 deletions examples/save_sharded_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
Saves each worker's model state dict directly to a checkpoint, which enables a
fast load path for large tensor-parallel models where each worker only needs to
read its own shard rather than the entire checkpoint.
Example usage:
python save_sharded_state.py \
--model /path/to/load \
--quantization deepspeedfp \
--tensor-parallel-size 8 \
--output /path/to/save
Then, the model can be loaded with
llm = LLM(
model="/path/to/save",
load_format="sharded_state",
quantization="deepspeedfp",
tensor_parallel_size=8,
)
"""
import argparse
import dataclasses
import os
import shutil
from pathlib import Path

from vllm import LLM, EngineArgs

parser = argparse.ArgumentParser()
EngineArgs.add_cli_args(parser)
parser.add_argument("--output",
"-o",
required=True,
type=str,
help="path to output checkpoint")
parser.add_argument("--file-pattern",
type=str,
help="string pattern of saved filenames")
parser.add_argument("--max-file-size",
type=str,
default=5 * 1024**3,
help="max size (in bytes) of each safetensors file")


def main(args):
engine_args = EngineArgs.from_cli_args(args)
if engine_args.enable_lora:
raise ValueError("Saving with enable_lora=True is not supported!")
model_path = engine_args.model
if not Path(model_path).is_dir():
raise ValueError("model path must be a local directory")
# Create LLM instance from arguments
llm = LLM(**dataclasses.asdict(engine_args))
# Prepare output directory
Path(args.output).mkdir(exist_ok=True)
# Dump worker states to output directory
model_executor = llm.llm_engine.model_executor
model_executor.save_sharded_state(path=args.output,
pattern=args.file_pattern,
max_size=args.max_file_size)
# Copy metadata files to output directory
for file in os.listdir(model_path):
if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"):
if os.path.isdir(os.path.join(model_path, file)):
shutil.copytree(os.path.join(model_path, file),
os.path.join(args.output, file))
else:
shutil.copy(os.path.join(model_path, file), args.output)


if __name__ == "__main__":
args = parser.parse_args()
main(args)
90 changes: 90 additions & 0 deletions tests/test_sharded_state_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import os
import shutil
from tempfile import TemporaryDirectory

import pytest
import torch
from huggingface_hub import snapshot_download

from vllm import LLM, SamplingParams
from vllm.model_executor.model_loader.loader import ShardedStateLoader

prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

# Create a sampling params object.
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
seed=0,
max_tokens=256,
ignore_eos=True,
)


def test_filter_subtensors():
state_dict = {
"a": torch.empty(2),
"b": torch.empty((2, 4)),
"c": torch.empty((2, 4, 8)),
}
state_dict.update({
"x": state_dict["b"],
"y": state_dict["c"][1, 2, :],
"z": state_dict["c"][1, :, 4],
})
filtered_state_dict = ShardedStateLoader._filter_subtensors(state_dict)
assert tuple(filtered_state_dict.keys()) == ("a", "b", "c")
for key, tensor in filtered_state_dict.items():
assert tensor.equal(state_dict[key])


@pytest.mark.parametrize("enable_lora", [False, True])
def test_sharded_state_loader(enable_lora):
weights_patterns = ("*.bin", "*.pt", "*.safetensors")

with TemporaryDirectory() as cache_dir, TemporaryDirectory() as output_dir:
input_dir = snapshot_download("meta-llama/Llama-2-7b-hf",
cache_dir=cache_dir)

llm = LLM(
model=input_dir,
worker_use_ray=True,
gpu_memory_utilization=0.3,
)

# Dump worker states to output directory
model_executor = llm.llm_engine.model_executor
model_executor.save_sharded_state(path=output_dir)
# Copy metadata files to output directory
for file in os.listdir(input_dir):
if not any(file.endswith(ext) for ext in weights_patterns):
shutil.copy(f"{input_dir}/{file}", output_dir)
del llm.llm_engine.model_executor

llm_before = LLM(
model=input_dir,
worker_use_ray=True,
enable_lora=enable_lora,
gpu_memory_utilization=0.3,
)
gen_before = llm_before.generate(prompts, sampling_params)
out_before = [gen.outputs[0].__dict__ for gen in gen_before]
del llm_before.llm_engine.model_executor

llm_after = LLM(
model=output_dir,
worker_use_ray=True,
enable_lora=enable_lora,
gpu_memory_utilization=0.3,
load_format="sharded_state",
)
gen_after = llm_after.generate(prompts, sampling_params)
out_after = [gen.outputs[0].__dict__ for gen in gen_after]
del llm_after.llm_engine.model_executor

assert out_before == out_after
1 change: 1 addition & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ class LoadFormat(str, enum.Enum):
NPCACHE = "npcache"
DUMMY = "dummy"
TENSORIZER = "tensorizer"
SHARDED_STATE = "sharded_state"


@dataclass
Expand Down
11 changes: 11 additions & 0 deletions vllm/executor/distributed_gpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,17 @@ def remove_lora(self, lora_id: int) -> bool:
def list_loras(self) -> Set[int]:
return self._run_workers("list_loras")

def save_sharded_state(
self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
self._run_workers("save_sharded_state",
path=path,
pattern=pattern,
max_size=max_size)

@abstractmethod
def _run_workers(
self,
Expand Down
148 changes: 148 additions & 0 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# ruff: noqa: SIM117
import collections
import copy
import glob
import os
Expand Down Expand Up @@ -366,6 +367,150 @@ def load_model(self, *, model_config: ModelConfig,
cache_config)


class ShardedStateLoader(BaseModelLoader):
"""
Model loader that directly loads each worker's model state dict, which
enables a fast load path for large tensor-parallel models where each worker
only needs to read its own shard rather than the entire checkpoint. See
`examples/save_sharded_states.py` for creating a sharded checkpoint.
"""

DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"

def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
extra_config = ({} if load_config.model_loader_extra_config is None
else load_config.model_loader_extra_config.copy())
self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
if extra_config:
raise ValueError(f"Unexpected extra config keys for load format "
f"{load_config.load_format}: "
f"{load_config.model_loader_extra_config.keys()}")

@staticmethod
def _filter_subtensors(
tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Filter out all tensors that share the same memory or a subset of the
memory of another tensor.
"""
same_storage_groups = collections.defaultdict(list)
for key, tensor in tensors.items():
if tensor.numel():
ptr = tensor.untyped_storage().data_ptr()
same_storage_groups[tensor.device, ptr].append((key, tensor))

def get_end_ptr(tensor: torch.Tensor) -> int:
return tensor.view(-1)[-1].data_ptr() + tensor.element_size()

result = {}
for group in same_storage_groups.values():
for k, t in group:
a, b = t.data_ptr(), get_end_ptr(t)
for k2, t2 in group:
if not t2.is_contiguous():
continue
a2, b2 = t2.data_ptr(), get_end_ptr(t2)
if a < a2 or b2 < b:
continue
if a2 < a or b < b2 or not t.is_contiguous():
break # t2 covers strictly more memory than t.
if k2 < k:
# Same tensors, keep the one with the smaller key.
break
else:
result[k] = t
return result

def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
from safetensors.torch import safe_open

from vllm.distributed import get_tensor_model_parallel_rank
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config,
cache_config)
rank = get_tensor_model_parallel_rank()
pattern = os.path.join(
model_config.model,
self.pattern.format(rank=rank, part="*"),
)
filepaths = glob.glob(pattern)
if not filepaths:
# TODO: support un-sharded checkpoints too
raise ValueError(
f"Could not find checkpoint files '{pattern}', only "
f"pre-sharded checkpoints are currently supported!")
state_dict = self._filter_subtensors(model.state_dict())
for path in filepaths:
with safe_open(path, framework="pt") as f:
for key in f.keys(): # noqa: SIM118
tensor = f.get_tensor(key)
# If loading with LoRA enabled, additional padding may
# be added to certain parameters. We only load into a
# narrowed view of the parameter data.
param_data = state_dict[key].data
param_shape = state_dict[key].shape
for dim, size in enumerate(tensor.shape):
if size < param_shape[dim]:
param_data = param_data.narrow(dim, 0, size)
if tensor.shape != param_shape:
logger.warning(
"loading tensor of shape %s into "
"parameter '%s' of shape %s", tensor.shape,
key, param_shape)
param_data.copy_(tensor)
state_dict.pop(key)
if state_dict:
raise ValueError(
f"Missing keys {tuple(state_dict)} in loaded state!")
return model.eval()

@staticmethod
def save_model(
model: torch.nn.Module,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
from safetensors.torch import save_file

from vllm.distributed import get_tensor_model_parallel_rank
if pattern is None:
pattern = ShardedStateLoader.DEFAULT_PATTERN
rank = get_tensor_model_parallel_rank()
part_idx = 0
total_size = 0
state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
state_dict_part: Dict[str, torch.Tensor] = {}
for key, tensor in state_dict.items():
param_size = tensor.nelement() * tensor.element_size()
if max_size is not None and total_size + param_size > max_size:
filename = pattern.format(rank=rank, part=part_idx)
save_file(
state_dict_part,
os.path.join(path, filename),
)
part_idx += 1
total_size = 0
state_dict_part = {}
state_dict_part[key] = tensor
total_size += param_size
if len(state_dict_part) > 0:
filename = pattern.format(rank=rank, part=part_idx)
save_file(
state_dict_part,
os.path.join(path, filename),
)


def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
"""Get a model loader based on the load format."""

Expand All @@ -378,4 +523,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
if load_config.load_format == LoadFormat.TENSORIZER:
return TensorizerLoader(load_config)

if load_config.load_format == LoadFormat.SHARDED_STATE:
return ShardedStateLoader(load_config)

return DefaultModelLoader(load_config)
14 changes: 14 additions & 0 deletions vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,20 @@ def load_model(self) -> None:
"but the KV cache data type is not FP8. "
"KV cache scaling factors will not be used.")

def save_sharded_state(
self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
from vllm.model_executor.model_loader.loader import ShardedStateLoader
ShardedStateLoader.save_model(
self.model,
path,
pattern=pattern,
max_size=max_size,
)

def get_max_block_per_batch(self) -> int:
block_size = self.block_size
return (self.max_seq_len_to_capture + block_size - 1) // block_size
Expand Down
12 changes: 12 additions & 0 deletions vllm/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,18 @@ def init_device(self) -> None:
def load_model(self):
self.model_runner.load_model()

def save_sharded_state(
self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
self.model_runner.save_sharded_state(
path,
pattern=pattern,
max_size=max_size,
)

@torch.inference_mode()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model to determine how many
Expand Down

0 comments on commit 5d94657

Please sign in to comment.