[Core] Implement sharded state loader #4690

Merged (22 commits) on May 16, 2024
75 changes: 75 additions & 0 deletions examples/save_sharded_state.py
@@ -0,0 +1,75 @@
"""
Saves each worker's model state dict directly to a checkpoint, which enables a
fast load path for large tensor-parallel models where each worker only needs to
read its own shard rather than the entire checkpoint.

Example usage:

python save_sharded_state.py \
--model /path/to/load \
--quantization deepspeedfp \
--tensor-parallel-size 8 \
--output /path/to/save

Then, the model can be loaded with

llm = LLM(
model="/path/to/save",
load_format="sharded_state",
quantization="deepspeedfp",
tensor_parallel_size=8,
)
"""
import argparse
import dataclasses
import os
import shutil
from pathlib import Path

from vllm import LLM, EngineArgs

parser = argparse.ArgumentParser()
EngineArgs.add_cli_args(parser)
parser.add_argument("--output",
"-o",
required=True,
type=str,
help="path to output checkpoint")
parser.add_argument("--file-pattern",
type=str,
help="string pattern of saved filenames")
parser.add_argument("--max-file-size",
type=str,
default=5 * 1024**3,
help="max size (in bytes) of each safetensors file")


def main(args):
engine_args = EngineArgs.from_cli_args(args)
if engine_args.enable_lora:
raise ValueError("Saving with enable_lora=True is not supported!")
model_path = engine_args.model
if not Path(model_path).is_dir():
raise ValueError("model path must be a local directory")
# Create LLM instance from arguments
llm = LLM(**dataclasses.asdict(engine_args))
# Prepare output directory
Path(args.output).mkdir(exist_ok=True)
# Dump worker states to output directory
model_executor = llm.llm_engine.model_executor
model_executor.save_sharded_state(path=args.output,
pattern=args.file_pattern,
max_size=args.max_file_size)
# Copy metadata files to output directory
for file in os.listdir(model_path):
if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"):
if os.path.isdir(os.path.join(model_path, file)):
shutil.copytree(os.path.join(model_path, file),
os.path.join(args.output, file))
else:
shutil.copy(os.path.join(model_path, file), args.output)


if __name__ == "__main__":
args = parser.parse_args()
main(args)
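For reference, with the loader's default filename pattern (model-rank-{rank}-part-{part}.safetensors) and --tensor-parallel-size 8 as in the example above, the saved directory would look roughly like the sketch below; the number of parts per rank depends on --max-file-size, and the metadata entries are simply whichever non-weight files the script copies over from the source model directory:

    /path/to/save/
        model-rank-0-part-0.safetensors
        model-rank-0-part-1.safetensors
        ...
        model-rank-7-part-0.safetensors
        model-rank-7-part-1.safetensors
        config.json
        tokenizer.json
        ...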
90 changes: 90 additions & 0 deletions tests/test_sharded_state_loader.py
@@ -0,0 +1,90 @@
import os
import shutil
from tempfile import TemporaryDirectory

import pytest
import torch
from huggingface_hub import snapshot_download

from vllm import LLM, SamplingParams
from vllm.model_executor.model_loader.loader import ShardedStateLoader

prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

# Create a sampling params object.
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
seed=0,
max_tokens=256,
ignore_eos=True,
)


def test_filter_subtensors():
state_dict = {
"a": torch.empty(2),
"b": torch.empty((2, 4)),
"c": torch.empty((2, 4, 8)),
}
state_dict.update({
"x": state_dict["b"],
"y": state_dict["c"][1, 2, :],
"z": state_dict["c"][1, :, 4],
})
filtered_state_dict = ShardedStateLoader._filter_subtensors(state_dict)
assert tuple(filtered_state_dict.keys()) == ("a", "b", "c")
for key, tensor in filtered_state_dict.items():
assert tensor.equal(state_dict[key])


@pytest.mark.parametrize("enable_lora", [False, True])
def test_sharded_state_loader(enable_lora):
weights_patterns = (".bin", ".pt", ".safetensors")

with TemporaryDirectory() as cache_dir, TemporaryDirectory() as output_dir:
input_dir = snapshot_download("meta-llama/Llama-2-7b-hf",
cache_dir=cache_dir)

llm = LLM(
model=input_dir,
worker_use_ray=True,
gpu_memory_utilization=0.3,
)

# Dump worker states to output directory
model_executor = llm.llm_engine.model_executor
model_executor.save_sharded_state(path=output_dir)
# Copy metadata files to output directory
for file in os.listdir(input_dir):
if not any(file.endswith(ext) for ext in weights_patterns):
shutil.copy(f"{input_dir}/{file}", output_dir)
del llm.llm_engine.model_executor

llm_before = LLM(
model=input_dir,
worker_use_ray=True,
enable_lora=enable_lora,
gpu_memory_utilization=0.3,
)
gen_before = llm_before.generate(prompts, sampling_params)
out_before = [gen.outputs[0].__dict__ for gen in gen_before]
del llm_before.llm_engine.model_executor

llm_after = LLM(
model=output_dir,
worker_use_ray=True,
enable_lora=enable_lora,
gpu_memory_utilization=0.3,
load_format="sharded_state",
)
gen_after = llm_after.generate(prompts, sampling_params)
out_after = [gen.outputs[0].__dict__ for gen in gen_after]
del llm_after.llm_engine.model_executor

assert out_before == out_after
1 change: 1 addition & 0 deletions vllm/config.py
@@ -463,6 +463,7 @@ class LoadFormat(str, enum.Enum):
NPCACHE = "npcache"
DUMMY = "dummy"
TENSORIZER = "tensorizer"
SHARDED_STATE = "sharded_state"


@dataclass
11 changes: 11 additions & 0 deletions vllm/executor/distributed_gpu_executor.py
@@ -77,6 +77,17 @@ def remove_lora(self, lora_id: int) -> bool:
def list_loras(self) -> Set[int]:
return self._run_workers("list_loras")

def save_sharded_state(
self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
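# Run on all workers so that each one writes only its own tensor-parallel
# shard (the worker's rank is embedded in the saved filenames).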
self._run_workers("save_sharded_state",
path=path,
pattern=pattern,
max_size=max_size)

@abstractmethod
def _run_workers(
self,
148 changes: 148 additions & 0 deletions vllm/model_executor/model_loader/loader.py
@@ -1,4 +1,5 @@
# ruff: noqa: SIM117
import collections
import copy
import glob
import os
@@ -366,6 +367,150 @@ def load_model(self, *, model_config: ModelConfig,
cache_config)


class ShardedStateLoader(BaseModelLoader):
"""
Model loader that directly loads each worker's model state dict, which
enables a fast load path for large tensor-parallel models where each worker
only needs to read its own shard rather than the entire checkpoint. See
`examples/save_sharded_state.py` for creating a sharded checkpoint.
"""

DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"

def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
extra_config = ({} if load_config.model_loader_extra_config is None
else load_config.model_loader_extra_config.copy())
self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
if extra_config:
raise ValueError(f"Unexpected extra config keys for load format "
f"{load_config.load_format}: "
f"{load_config.model_loader_extra_config.keys()}")

@staticmethod
def _filter_subtensors(
tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Filter out all tensors that share the same memory or a subset of the
memory of another tensor.
"""
same_storage_groups = collections.defaultdict(list)
for key, tensor in tensors.items():
if tensor.numel():
ptr = tensor.untyped_storage().data_ptr()
same_storage_groups[tensor.device, ptr].append((key, tensor))

def get_end_ptr(tensor: torch.Tensor) -> int:
return tensor.view(-1)[-1].data_ptr() + tensor.element_size()

result = {}
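# Within each storage group, drop tensors whose memory is strictly
# contained in another contiguous tensor; among views spanning exactly
# the same memory, only the contiguous one with the smallest key is kept.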
for group in same_storage_groups.values():
for k, t in group:
a, b = t.data_ptr(), get_end_ptr(t)
for k2, t2 in group:
if not t2.is_contiguous():
continue
a2, b2 = t2.data_ptr(), get_end_ptr(t2)
if a < a2 or b2 < b:
continue
if a2 < a or b < b2 or not t.is_contiguous():
break # t2 covers strictly more memory than t.
if k2 < k:
# Same tensors, keep the one with the smaller key.
break
else:
result[k] = t
return result

def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
from safetensors.torch import safe_open

from vllm.distributed import get_tensor_model_parallel_rank
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config,
cache_config)
rank = get_tensor_model_parallel_rank()
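# Each rank reads only the shard files that were saved for its own
# tensor-parallel rank.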
pattern = os.path.join(
model_config.model,
self.pattern.format(rank=rank, part="*"),
)
filepaths = glob.glob(pattern)
if not filepaths:
# TODO: support un-sharded checkpoints too
raise ValueError(
f"Could not find checkpoint files '{pattern}', only "
f"pre-sharded checkpoints are currently supported!")
state_dict = self._filter_subtensors(model.state_dict())
for path in filepaths:
with safe_open(path, framework="pt") as f:
for key in f.keys(): # noqa: SIM118
tensor = f.get_tensor(key)
# If loading with LoRA enabled, additional padding may
# be added to certain parameters. We only load into a
# narrowed view of the parameter data.
param_data = state_dict[key].data
param_shape = state_dict[key].shape
for dim, size in enumerate(tensor.shape):
if size < param_shape[dim]:
param_data = param_data.narrow(dim, 0, size)
if tensor.shape != param_shape:
logger.warning(
"loading tensor of shape %s into "
"parameter '%s' of shape %s", tensor.shape,
key, param_shape)
param_data.copy_(tensor)
state_dict.pop(key)
if state_dict:
raise ValueError(
f"Missing keys {tuple(state_dict)} in loaded state!")
return model.eval()

@staticmethod
def save_model(
model: torch.nn.Module,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
from safetensors.torch import save_file

from vllm.distributed import get_tensor_model_parallel_rank
if pattern is None:
pattern = ShardedStateLoader.DEFAULT_PATTERN
rank = get_tensor_model_parallel_rank()
part_idx = 0
total_size = 0
state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
state_dict_part: Dict[str, torch.Tensor] = {}
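# Pack tensors into safetensors files greedily, starting a new part
# whenever adding the next tensor would push the current file past max_size.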
for key, tensor in state_dict.items():
param_size = tensor.nelement() * tensor.element_size()
if max_size is not None and total_size + param_size > max_size:
filename = pattern.format(rank=rank, part=part_idx)
save_file(
state_dict_part,
os.path.join(path, filename),
)
part_idx += 1
total_size = 0
state_dict_part = {}
state_dict_part[key] = tensor
total_size += param_size
if len(state_dict_part) > 0:
filename = pattern.format(rank=rank, part=part_idx)
save_file(
state_dict_part,
os.path.join(path, filename),
)


def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
"""Get a model loader based on the load format."""

@@ -378,4 +523,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
if load_config.load_format == LoadFormat.TENSORIZER:
return TensorizerLoader(load_config)

if load_config.load_format == LoadFormat.SHARDED_STATE:
return ShardedStateLoader(load_config)

return DefaultModelLoader(load_config)
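Since ShardedStateLoader.__init__ pops an optional "pattern" key from model_loader_extra_config, a checkpoint saved with a custom --file-pattern can be loaded back by supplying the same pattern at load time. A minimal sketch, assuming model_loader_extra_config is forwarded through EngineArgs/LLM the same way it is for the other loaders; the pattern string itself is illustrative and should keep the {rank} and {part} placeholders so files from different ranks and parts do not collide:

    from vllm import LLM

    # Hypothetical pattern chosen at save time via --file-pattern; it must
    # contain the {rank} and {part} placeholders used by ShardedStateLoader.
    custom_pattern = "custom-rank-{rank}-part-{part}.safetensors"

    llm = LLM(
        model="/path/to/save",
        load_format="sharded_state",
        tensor_parallel_size=8,
        model_loader_extra_config={"pattern": custom_pattern},
    )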
14 changes: 14 additions & 0 deletions vllm/worker/model_runner.py
@@ -212,6 +212,20 @@ def load_model(self) -> None:
"but the KV cache data type is not FP8. "
"KV cache scaling factors will not be used.")

def save_sharded_state(
self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
from vllm.model_executor.model_loader.loader import ShardedStateLoader
ShardedStateLoader.save_model(
self.model,
path,
pattern=pattern,
max_size=max_size,
)

def get_max_block_per_batch(self) -> int:
block_size = self.block_size
return (self.max_seq_len_to_capture + block_size - 1) // block_size
12 changes: 12 additions & 0 deletions vllm/worker/worker.py
@@ -119,6 +119,18 @@ def init_device(self) -> None:
def load_model(self):
self.model_runner.load_model()

def save_sharded_state(
self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
self.model_runner.save_sharded_state(
path,
pattern=pattern,
max_size=max_size,
)

@torch.inference_mode()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model to determine how many