[xFasterTransformer] Add xfastertransformer support.
Duyi-Wang committed May 22, 2024
1 parent 11ce842 commit 513598e
Showing 15 changed files with 306 additions and 184 deletions.
48 changes: 48 additions & 0 deletions README.md
@@ -1,3 +1,51 @@
This is a fork of vLLM that supports the xFasterTransformer backend. This version is based on official vLLM `v0.4.2`.
## Notice
🎉🎉🎉***Continuous batching is supported.*** 🎇🎇🎇
- Distributed inference is not supported yet. (WIP)
- Beam search is not supported yet. (WIP)
- LoRA is not supported yet. (WIP)

## Install
### From PyPI
`pip install vllm-xft`

### From Source
`python3 setup.py bdist_wheel --verbose`

## Usage
### Python offline
```shell
python examples/offline_inference_xfastertransformer.py
```
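
The example script itself is not included in this diff; below is a minimal sketch of offline usage, assuming this fork keeps vLLM's standard `LLM`/`SamplingParams` Python API (the model and tokenizer paths and dtype values are illustrative, mirroring the serving example below).

```python
# Minimal offline-inference sketch (assumption: this fork keeps vLLM's standard
# Python API; the model/tokenizer paths and dtype values are illustrative).
from vllm import LLM, SamplingParams

llm = LLM(
    model="/data/llama-2-7b-chat-cpu",     # converted xFasterTransformer model directory
    tokenizer="/data/llama-2-7b-chat-hf",  # original HuggingFace tokenizer directory
    dtype="bf16",                          # one of the xFT dtypes (see DTYPE_LIST below)
    kv_cache_dtype="fp16",
    trust_remote_code=True,
)

prompts = ["San Francisco is a"]
sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

for output in llm.generate(prompts, sampling_params):
    print(output.outputs[0].text)
```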
### Serving (OpenAI-Compatible Server)
```shell
python -m vllm.entrypoints.openai.api_server \
    --model /data/llama-2-7b-chat-cpu \
    --tokenizer /data/llama-2-7b-chat-hf \
    --dtype fp16 \
    --kv-cache-dtype fp16 \
    --served-model-name xft \
    --port 8000 \
    --trust-remote-code
```
- `--max-num-batched-tokens`: maximum number of batched tokens per iteration; defaults to max(MAX_SEQ_LEN_OF_MODEL, 2048).
- `--max-num-seqs`: maximum number of sequences per batch; defaults to 256.

For more arguments, please refer to the [vLLM official docs](https://docs.vllm.ai/en/latest/models/engine_args.html).

### Query example
```shell
curl http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "xft",
"prompt": "San Francisco is a",
"max_tokens": 512,
"temperature": 0
}'
```
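
Since the endpoint is OpenAI-compatible, the `openai` Python client can be used instead of `curl`; a minimal sketch (the base URL, model name, and sampling parameters mirror the examples above; the API key is a dummy placeholder, as the server only enforces one when started with an API key).

```python
# Query the OpenAI-compatible server started above (assumption: openai>=1.0 client;
# the API key is a dummy value since the server does not require one by default).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="xft",                 # matches --served-model-name
    prompt="San Francisco is a",
    max_tokens=512,
    temperature=0,
)
print(completion.choices[0].text)
```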


<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/logos/vllm-logo-text-dark.png">
5 changes: 3 additions & 2 deletions requirements-cpu.txt
@@ -2,5 +2,6 @@
-r requirements-common.txt

# Dependencies for x86_64 CPUs
torch == 2.3.0+cpu
triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.
# torch == 2.3.0+cpu
# triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.
xfastertransformer > 1.6.0
6 changes: 4 additions & 2 deletions setup.py
@@ -304,6 +304,7 @@ def find_version(filepath: str) -> str:

def get_vllm_version() -> str:
version = find_version(get_path("vllm", "__init__.py"))
return version

if _is_cuda():
cuda_version = str(get_nvcc_cuda_version())
@@ -352,6 +353,7 @@ def _read_requirements(filename: str) -> List[str]:
else:
resolved_requirements.append(line)
return resolved_requirements
return _read_requirements("requirements-cpu.txt")

if _is_cuda():
requirements = _read_requirements("requirements-cuda.txt")
@@ -420,10 +422,10 @@ def _read_requirements(filename: str) -> List[str]:
"tests*")),
python_requires=">=3.8",
install_requires=get_requirements(),
ext_modules=ext_modules,
# ext_modules=ext_modules,
extras_require={
"tensorizer": ["tensorizer==2.9.0"],
},
cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {},
# cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {},
package_data=package_data,
)
29 changes: 19 additions & 10 deletions vllm/config.py
@@ -30,17 +30,15 @@ class ModelConfig:
"""Configuration for the model.
Args:
model: Name or path of the huggingface model to use.
model: Name or path of the xfastertransformer model to use.
It is also used as the content for `model_name` tag in metrics
output when `served_model_name` is not specified.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, and "slow" will always use the slow tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
dtype: Data type for model weights and activations. The "auto" option
will use FP16 precision for FP32 and FP16 models, and BF16 precision
for BF16 models.
dtype: Data type for model weights and activations.
seed: Random seed for reproducibility.
revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id. If unspecified, will use the default
@@ -117,11 +115,18 @@ def __init__(
or max_context_len_to_capture)
self.max_logprobs = max_logprobs
self.skip_tokenizer_init = skip_tokenizer_init

self.hf_config = get_config(self.model, trust_remote_code, revision,

import os
if not os.path.exists(model):
    raise RuntimeError("Path of xFasterTransformer model doesn't exist.")
if not os.path.exists(tokenizer):
    raise RuntimeError("Path of tokenizer doesn't exist.")

self.hf_config = get_config(self.tokenizer, trust_remote_code, revision,
code_revision)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.dtype = dtype
# self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
max_model_len)
self.served_model_name = get_served_model_name(model,
@@ -347,8 +352,8 @@ def __init__(
self.cache_dtype = cache_dtype
self.sliding_window = sliding_window
self.enable_prefix_caching = enable_prefix_caching
self._verify_args()
self._verify_cache_dtype()
# self._verify_args()
# self._verify_cache_dtype()

# Will be set after profiling.
self.num_gpu_blocks = None
@@ -495,7 +500,7 @@ def __post_init__(self):
if isinstance(model_loader_extra_config, str):
self.model_loader_extra_config = json.loads(
model_loader_extra_config)
self._verify_load_format()
# self._verify_load_format()

def _verify_load_format(self) -> None:
if not isinstance(self.load_format, str):
@@ -662,6 +667,10 @@ def _verify_args(self) -> None:
class DeviceConfig:

def __init__(self, device: str = "auto") -> None:
self.device = torch.device("cpu")
self.device_type = "cpu"
return

if device == "auto":
# Automated device type detection
if is_neuron():
88 changes: 51 additions & 37 deletions vllm/core/scheduler.py
@@ -271,17 +271,19 @@ def __init__(
self.scheduler_config.max_model_len,
self.scheduler_config.max_num_batched_tokens)

BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
version="v2" if self.scheduler_config.
use_v2_block_manager else "v1")

# Create the block space manager.
self.block_manager = BlockSpaceManagerImpl(
block_size=self.cache_config.block_size,
num_gpu_blocks=self.cache_config.num_gpu_blocks,
num_cpu_blocks=self.cache_config.num_cpu_blocks,
sliding_window=self.cache_config.sliding_window,
enable_caching=self.cache_config.enable_prefix_caching)
# BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
# version="v2" if self.scheduler_config.
# use_v2_block_manager else "v1")

# # Create the block space manager.
# self.block_manager = BlockSpaceManagerImpl(
# block_size=self.cache_config.block_size,
# num_gpu_blocks=self.cache_config.num_gpu_blocks,
# num_cpu_blocks=self.cache_config.num_cpu_blocks,
# sliding_window=self.cache_config.sliding_window,
# enable_caching=self.cache_config.enable_prefix_caching)

self.block_manager = None

# Sequence groups in the WAITING state.
# Contain new prefill or preempted requests.
@@ -445,7 +447,7 @@ def _schedule_running(
swapped_out.append(seq_group)
break
else:
self._append_slots(seq_group, blocks_to_copy)
# self._append_slots(seq_group, blocks_to_copy)
is_prefill = seq_group.is_prefill()
if is_prefill:
prefill_seq_groups.append(
@@ -523,7 +525,8 @@ def _schedule_swapped(
seq_group = swapped_queue[0]

# If the sequence group cannot be swapped in, stop.
alloc_status = self.block_manager.can_swap_in(seq_group)
# alloc_status = self.block_manager.can_swap_in(seq_group)
alloc_status = AllocStatus.OK
if alloc_status == AllocStatus.LATER:
break
elif alloc_status == AllocStatus.NEVER:
@@ -656,7 +659,8 @@ def _schedule_prefills(
continue

# If the sequence group cannot be allocated, stop.
can_allocate = self.block_manager.can_allocate(seq_group)
# can_allocate = self.block_manager.can_allocate(seq_group)
can_allocate = AllocStatus.OK
if can_allocate == AllocStatus.LATER:
break
elif can_allocate == AllocStatus.NEVER:
@@ -891,6 +895,7 @@ def _schedule_chunked_prefill(self):

def _schedule(self) -> SchedulerOutputs:
"""Schedule queued requests."""
return self._schedule_default()
if self.scheduler_config.chunked_prefill_enabled:
return self._schedule_chunked_prefill()
else:
@@ -900,6 +905,7 @@ def _can_append_slots(self, seq_group: SequenceGroup) -> bool:
"""Determine whether or not we have enough space in the KV cache to
continue generation of the sequence group.
"""
return True
# It is True only for testing case to trigger artificial preemption.
if (self.enable_artificial_preemption
and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB
@@ -938,26 +944,27 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq_id = seq.seq_id
seq_data[seq_id] = seq.data
block_tables[seq_id] = self.block_manager.get_block_table(seq)
self.block_manager.access_all_blocks_in_seq(seq, now)
# block_tables[seq_id] = self.block_manager.get_block_table(seq)
# self.block_manager.access_all_blocks_in_seq(seq, now)

common_computed_block_nums = (
self.block_manager.get_common_computed_block_ids(
seq_group.get_seqs(status=SequenceStatus.RUNNING)))
# common_computed_block_nums = (
# self.block_manager.get_common_computed_block_ids(
# seq_group.get_seqs(status=SequenceStatus.RUNNING)))
common_computed_block_nums = 0

do_sample = True
if seq_group.is_prefill():
seqs = seq_group.get_seqs()
# Prefill has only 1 sequence.
assert len(seqs) == 1
# In the next iteration, all prompt tokens are not computed.
# It means the prefill is chunked, and we don't need sampling.
# NOTE: We use get_len instead of get_prompt_len because when
# a sequence is preempted, prefill includes previous generated
# output tokens.
if (token_chunk_size + seqs[0].data.get_num_computed_tokens() <
seqs[0].data.get_len()):
do_sample = False
# if seq_group.is_prefill():
# seqs = seq_group.get_seqs()
# # Prefill has only 1 sequence.
# assert len(seqs) == 1
# # In the next iteration, all prompt tokens are not computed.
# # It means the prefill is chunked, and we don't need sampling.
# # NOTE: We use get_len instead of get_prompt_len because when
# # a sequence is preempted, prefill includes previous generated
# # output tokens.
# if (token_chunk_size + seqs[0].data.get_num_computed_tokens() <
# seqs[0].data.get_len()):
# do_sample = False

# It assumes the scheduled_seq_groups is ordered by
# prefill < decoding.
@@ -986,9 +993,9 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
# batch will have been computed before the next scheduling invocation.
# This is because the engine assumes that a failure in model execution
# will crash the vLLM instance / will not retry.
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
self.block_manager.mark_blocks_as_computed(
scheduled_seq_group.seq_group)
# for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
# self.block_manager.mark_blocks_as_computed(
# scheduled_seq_group.seq_group)

return seq_group_metadata_list, scheduler_outputs

@@ -997,14 +1004,21 @@ def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:

def free_seq(self, seq: Sequence) -> None:
"""Free a sequence from a block table."""
self.block_manager.free(seq)
# self.block_manager.free(seq)
pass

def free_finished_seq_groups(self) -> None:
def free_finished_seq_groups(self) -> List[int]:
free_xft_seq_ids = []
for seq_group in self.running:
if seq_group.is_finished():
for seq in seq_group.seqs_dict.values():
free_xft_seq_ids.append(seq.data.xft_ids)
self.running = deque(seq_group for seq_group in self.running
if not seq_group.is_finished())
return free_xft_seq_ids

def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
self.block_manager.allocate(seq_group)
# self.block_manager.allocate(seq_group)
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
seq.status = SequenceStatus.RUNNING

30 changes: 22 additions & 8 deletions vllm/engine/arg_utils.py
@@ -16,6 +16,22 @@ def nullable_str(val: str):
return None
return val

DTYPE_LIST = [
"fp16",
"bf16",
"int8",
"w8a8",
"int4",
"nf4",
"bf16_fp16",
"bf16_int8",
"bf16_w8a8",
"bf16_int4",
"bf16_nf4",
"w8a8_int8",
"w8a8_int4",
"w8a8_nf4",
]

@dataclass
class EngineArgs:
@@ -27,9 +43,9 @@ class EngineArgs:
tokenizer_mode: str = 'auto'
trust_remote_code: bool = False
download_dir: Optional[str] = None
load_format: str = 'auto'
dtype: str = 'auto'
kv_cache_dtype: str = 'auto'
load_format: str = 'xft'
dtype: str = 'bf16'
kv_cache_dtype: str = 'fp16'
quantization_param_path: Optional[str] = None
seed: int = 0
max_model_len: Optional[int] = None
@@ -153,7 +169,7 @@ def add_cli_args(
type=str,
default=EngineArgs.load_format,
choices=[
'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer'
'xft'
],
help='The format of the model weights to load.\n\n'
'* "auto" will try to load the weights in the safetensors format '
@@ -172,9 +188,7 @@
'--dtype',
type=str,
default=EngineArgs.dtype,
choices=[
'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
],
choices=DTYPE_LIST,
help='Data type for model weights and activations.\n\n'
'* "auto" will use FP16 precision for FP32 and FP16 models, and '
'BF16 precision for BF16 models.\n'
@@ -186,7 +200,7 @@
parser.add_argument(
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8'],
choices=['fp16', 'int8'],
default=EngineArgs.kv_cache_dtype,
help='Data type for kv cache storage. If "auto", will use model '
'data type. FP8_E5M2 (without scaling) is only supported on cuda '