diff --git a/README.md b/README.md
index 524d027137aba..4db3548a29140 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,51 @@
+This is a fork of vLLM that supports the xFasterTransformer backend. This version is based on the official vLLM `v0.4.2`.
+## Notice
+🎉🎉🎉***Continuous batching is supported.*** 🎇🎇🎇
+- Distributed inference is not supported yet. (WIP)
+- Beam search is not supported yet. (WIP)
+- LoRA is not supported yet. (WIP)
+
+## Install
+### From PyPI
+`pip install vllm-xft`
+
+### From Source
+`python3 setup.py bdist_wheel --verbose`
+
+## Usage
+### Python offline
+```shell
+python examples/offline_inference_xfastertransformer.py
+```
+### Serving (OpenAI-Compatible Server)
+```shell
+python -m vllm.entrypoints.openai.api_server \
+    --model /data/llama-2-7b-chat-cpu \
+    --tokenizer /data/llama-2-7b-chat-hf \
+    --dtype fp16 \
+    --kv-cache-dtype fp16 \
+    --served-model-name xft \
+    --port 8000 \
+    --trust-remote-code
+```
+- `--max-num-batched-tokens`: maximum number of batched tokens per iteration; defaults to max(MAX_SEQ_LEN_OF_MODEL, 2048).
+- `--max-num-seqs`: maximum number of sequences per batch; defaults to 256.
+
+For more arguments, please refer to the [vLLM official docs](https://docs.vllm.ai/en/latest/models/engine_args.html).
+
+### Query example
+```shell
+curl http://localhost:8000/v1/completions \
+    -H "Content-Type: application/json" \
+    -d '{
+        "model": "xft",
+        "prompt": "San Francisco is a",
+        "max_tokens": 512,
+        "temperature": 0
+    }'
+```
+
+
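As a companion to the "Python offline" section above, here is a minimal offline-inference sketch. It is not the contents of `examples/offline_inference_xfastertransformer.py`; the model and tokenizer paths are the same placeholder paths used in the serving example, and `kv_cache_dtype` is passed through `LLM`'s engine kwargs.

```python
# Hypothetical offline-inference sketch for the xFasterTransformer backend.
# Paths are placeholders: `model` points at a converted xFT model directory,
# `tokenizer` at the original HuggingFace tokenizer directory (the config.py
# changes below require both paths to exist on disk).
from vllm import LLM, SamplingParams

llm = LLM(
    model="/data/llama-2-7b-chat-cpu",     # converted xFT model directory (placeholder)
    tokenizer="/data/llama-2-7b-chat-hf",  # HuggingFace tokenizer directory (placeholder)
    dtype="fp16",                          # one of the xFT dtypes (see DTYPE_LIST in arg_utils.py)
    kv_cache_dtype="fp16",
    trust_remote_code=True,
)

sampling_params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["San Francisco is a"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)
```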

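The OpenAI-compatible endpoint can also be queried from Python. Below is a sketch that assumes `openai>=1.0` is installed and the `api_server` command above is running on port 8000 with `--served-model-name xft`.

```python
# Sketch of querying the OpenAI-compatible server with the `openai` client
# (assumes `pip install openai` >= 1.0 and the server command shown above).
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="EMPTY",  # the server does not verify the key unless --api-key is set
)

completion = client.completions.create(
    model="xft",  # must match --served-model-name
    prompt="San Francisco is a",
    max_tokens=512,
    temperature=0,
)
print(completion.choices[0].text)
```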
diff --git a/requirements-cpu.txt b/requirements-cpu.txt index b739642d8d344..94fa15b30a792 100644 --- a/requirements-cpu.txt +++ b/requirements-cpu.txt @@ -2,5 +2,6 @@ -r requirements-common.txt # Dependencies for x86_64 CPUs -torch == 2.3.0+cpu -triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error. \ No newline at end of file +# torch == 2.3.0+cpu +# triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error. +xfastertransformer > 1.6.0 \ No newline at end of file diff --git a/setup.py b/setup.py index 11b0ef4646cf5..0180dac944d73 100644 --- a/setup.py +++ b/setup.py @@ -304,6 +304,7 @@ def find_version(filepath: str) -> str: def get_vllm_version() -> str: version = find_version(get_path("vllm", "__init__.py")) + return version if _is_cuda(): cuda_version = str(get_nvcc_cuda_version()) @@ -352,6 +353,7 @@ def _read_requirements(filename: str) -> List[str]: else: resolved_requirements.append(line) return resolved_requirements + return _read_requirements("requirements-cpu.txt") if _is_cuda(): requirements = _read_requirements("requirements-cuda.txt") @@ -420,10 +422,10 @@ def _read_requirements(filename: str) -> List[str]: "tests*")), python_requires=">=3.8", install_requires=get_requirements(), - ext_modules=ext_modules, + # ext_modules=ext_modules, extras_require={ "tensorizer": ["tensorizer==2.9.0"], }, - cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {}, + # cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {}, package_data=package_data, ) diff --git a/vllm/config.py b/vllm/config.py index 13bb294591725..4c1de8646abdc 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -30,7 +30,7 @@ class ModelConfig: """Configuration for the model. Args: - model: Name or path of the huggingface model to use. + model: Name or path of the xfastertransformer model to use. It is also used as the content for `model_name` tag in metrics output when `served_model_name` is not specified. tokenizer: Name or path of the huggingface tokenizer to use. @@ -38,9 +38,7 @@ class ModelConfig: available, and "slow" will always use the slow tokenizer. trust_remote_code: Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer. - dtype: Data type for model weights and activations. The "auto" option - will use FP16 precision for FP32 and FP16 models, and BF16 precision - for BF16 models. + dtype: Data type for model weights and activations. seed: Random seed for reproducibility. revision: The specific model version to use. It can be a branch name, a tag name, or a commit id. 
If unspecified, will use the default @@ -117,11 +115,18 @@ def __init__( or max_context_len_to_capture) self.max_logprobs = max_logprobs self.skip_tokenizer_init = skip_tokenizer_init - - self.hf_config = get_config(self.model, trust_remote_code, revision, + + import os + if not os.path.exists(model): + raise RuntimeError("Path of xFasterTransformer model doesn't exists.") + if not os.path.exists(tokenizer): + raise RuntimeError("Path of tokenizer doesn't exists.") + + self.hf_config = get_config(self.tokenizer, trust_remote_code, revision, code_revision) self.hf_text_config = get_hf_text_config(self.hf_config) - self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) + self.dtype = dtype + # self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.max_model_len = _get_and_verify_max_len(self.hf_text_config, max_model_len) self.served_model_name = get_served_model_name(model, @@ -347,8 +352,8 @@ def __init__( self.cache_dtype = cache_dtype self.sliding_window = sliding_window self.enable_prefix_caching = enable_prefix_caching - self._verify_args() - self._verify_cache_dtype() + # self._verify_args() + # self._verify_cache_dtype() # Will be set after profiling. self.num_gpu_blocks = None @@ -495,7 +500,7 @@ def __post_init__(self): if isinstance(model_loader_extra_config, str): self.model_loader_extra_config = json.loads( model_loader_extra_config) - self._verify_load_format() + # self._verify_load_format() def _verify_load_format(self) -> None: if not isinstance(self.load_format, str): @@ -662,6 +667,10 @@ def _verify_args(self) -> None: class DeviceConfig: def __init__(self, device: str = "auto") -> None: + self.device = torch.device("cpu") + self.device_type = "cpu" + return + if device == "auto": # Automated device type detection if is_neuron(): diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index a9e0b05b8db67..5801c1f7737c8 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -271,17 +271,19 @@ def __init__( self.scheduler_config.max_model_len, self.scheduler_config.max_num_batched_tokens) - BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( - version="v2" if self.scheduler_config. - use_v2_block_manager else "v1") - - # Create the block space manager. - self.block_manager = BlockSpaceManagerImpl( - block_size=self.cache_config.block_size, - num_gpu_blocks=self.cache_config.num_gpu_blocks, - num_cpu_blocks=self.cache_config.num_cpu_blocks, - sliding_window=self.cache_config.sliding_window, - enable_caching=self.cache_config.enable_prefix_caching) + # BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( + # version="v2" if self.scheduler_config. + # use_v2_block_manager else "v1") + + # # Create the block space manager. + # self.block_manager = BlockSpaceManagerImpl( + # block_size=self.cache_config.block_size, + # num_gpu_blocks=self.cache_config.num_gpu_blocks, + # num_cpu_blocks=self.cache_config.num_cpu_blocks, + # sliding_window=self.cache_config.sliding_window, + # enable_caching=self.cache_config.enable_prefix_caching) + + self.block_manager = None # Sequence groups in the WAITING state. # Contain new prefill or preempted requests. 
@@ -445,7 +447,7 @@ def _schedule_running( swapped_out.append(seq_group) break else: - self._append_slots(seq_group, blocks_to_copy) + # self._append_slots(seq_group, blocks_to_copy) is_prefill = seq_group.is_prefill() if is_prefill: prefill_seq_groups.append( @@ -523,7 +525,8 @@ def _schedule_swapped( seq_group = swapped_queue[0] # If the sequence group cannot be swapped in, stop. - alloc_status = self.block_manager.can_swap_in(seq_group) + # alloc_status = self.block_manager.can_swap_in(seq_group) + alloc_status = AllocStatus.OK if alloc_status == AllocStatus.LATER: break elif alloc_status == AllocStatus.NEVER: @@ -656,7 +659,8 @@ def _schedule_prefills( continue # If the sequence group cannot be allocated, stop. - can_allocate = self.block_manager.can_allocate(seq_group) + # can_allocate = self.block_manager.can_allocate(seq_group) + can_allocate = AllocStatus.OK if can_allocate == AllocStatus.LATER: break elif can_allocate == AllocStatus.NEVER: @@ -891,6 +895,7 @@ def _schedule_chunked_prefill(self): def _schedule(self) -> SchedulerOutputs: """Schedule queued requests.""" + return self._schedule_default() if self.scheduler_config.chunked_prefill_enabled: return self._schedule_chunked_prefill() else: @@ -900,6 +905,7 @@ def _can_append_slots(self, seq_group: SequenceGroup) -> bool: """Determine whether or not we have enough space in the KV cache to continue generation of the sequence group. """ + return True # It is True only for testing case to trigger artificial preemption. if (self.enable_artificial_preemption and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB @@ -938,26 +944,27 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): seq_id = seq.seq_id seq_data[seq_id] = seq.data - block_tables[seq_id] = self.block_manager.get_block_table(seq) - self.block_manager.access_all_blocks_in_seq(seq, now) + # block_tables[seq_id] = self.block_manager.get_block_table(seq) + # self.block_manager.access_all_blocks_in_seq(seq, now) - common_computed_block_nums = ( - self.block_manager.get_common_computed_block_ids( - seq_group.get_seqs(status=SequenceStatus.RUNNING))) + # common_computed_block_nums = ( + # self.block_manager.get_common_computed_block_ids( + # seq_group.get_seqs(status=SequenceStatus.RUNNING))) + common_computed_block_nums = 0 do_sample = True - if seq_group.is_prefill(): - seqs = seq_group.get_seqs() - # Prefill has only 1 sequence. - assert len(seqs) == 1 - # In the next iteration, all prompt tokens are not computed. - # It means the prefill is chunked, and we don't need sampling. - # NOTE: We use get_len instead of get_prompt_len because when - # a sequence is preempted, prefill includes previous generated - # output tokens. - if (token_chunk_size + seqs[0].data.get_num_computed_tokens() < - seqs[0].data.get_len()): - do_sample = False + # if seq_group.is_prefill(): + # seqs = seq_group.get_seqs() + # # Prefill has only 1 sequence. + # assert len(seqs) == 1 + # # In the next iteration, all prompt tokens are not computed. + # # It means the prefill is chunked, and we don't need sampling. + # # NOTE: We use get_len instead of get_prompt_len because when + # # a sequence is preempted, prefill includes previous generated + # # output tokens. + # if (token_chunk_size + seqs[0].data.get_num_computed_tokens() < + # seqs[0].data.get_len()): + # do_sample = False # It assumes the scheduled_seq_groups is ordered by # prefill < decoding. 
@@ -986,9 +993,9 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: # batch will have been computed before the next scheduling invocation. # This is because the engine assumes that a failure in model execution # will crash the vLLM instance / will not retry. - for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: - self.block_manager.mark_blocks_as_computed( - scheduled_seq_group.seq_group) + # for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: + # self.block_manager.mark_blocks_as_computed( + # scheduled_seq_group.seq_group) return seq_group_metadata_list, scheduler_outputs @@ -997,14 +1004,21 @@ def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: def free_seq(self, seq: Sequence) -> None: """Free a sequence from a block table.""" - self.block_manager.free(seq) + # self.block_manager.free(seq) + pass - def free_finished_seq_groups(self) -> None: + def free_finished_seq_groups(self) -> List[int]: + free_xft_seq_ids = [] + for seq_group in self.running: + if seq_group.is_finished(): + for seq in seq_group.seqs_dict.values(): + free_xft_seq_ids.append(seq.data.xft_ids) self.running = deque(seq_group for seq_group in self.running if not seq_group.is_finished()) + return free_xft_seq_ids def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: - self.block_manager.allocate(seq_group) + # self.block_manager.allocate(seq_group) for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): seq.status = SequenceStatus.RUNNING diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index bb8245eb307f7..2794ebbcb8803 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -16,6 +16,22 @@ def nullable_str(val: str): return None return val +DTYPE_LIST = [ + "fp16", + "bf16", + "int8", + "w8a8", + "int4", + "nf4", + "bf16_fp16", + "bf16_int8", + "bf16_w8a8", + "bf16_int4", + "bf16_nf4", + "w8a8_int8", + "w8a8_int4", + "w8a8_nf4", +] @dataclass class EngineArgs: @@ -27,9 +43,9 @@ class EngineArgs: tokenizer_mode: str = 'auto' trust_remote_code: bool = False download_dir: Optional[str] = None - load_format: str = 'auto' - dtype: str = 'auto' - kv_cache_dtype: str = 'auto' + load_format: str = 'xft' + dtype: str = 'bf16' + kv_cache_dtype: str = 'fp16' quantization_param_path: Optional[str] = None seed: int = 0 max_model_len: Optional[int] = None @@ -153,7 +169,7 @@ def add_cli_args( type=str, default=EngineArgs.load_format, choices=[ - 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer' + 'xft' ], help='The format of the model weights to load.\n\n' '* "auto" will try to load the weights in the safetensors format ' @@ -172,9 +188,7 @@ def add_cli_args( '--dtype', type=str, default=EngineArgs.dtype, - choices=[ - 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32' - ], + choices=DTYPE_LIST, help='Data type for model weights and activations.\n\n' '* "auto" will use FP16 precision for FP32 and FP16 models, and ' 'BF16 precision for BF16 models.\n' @@ -186,7 +200,7 @@ def add_cli_args( parser.add_argument( '--kv-cache-dtype', type=str, - choices=['auto', 'fp8'], + choices=['fp16', 'int8'], default=EngineArgs.kv_cache_dtype, help='Data type for kv cache storage. If "auto", will use model ' 'data type. 
FP8_E5M2 (without scaling) is only supported on cuda ' diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 9f72a0d11974f..b7de816ce454a 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -11,7 +11,7 @@ from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.llm_engine import LLMEngine -from vllm.executor.ray_utils import initialize_ray_cluster, ray +# from vllm.executor.ray_utils import initialize_ray_cluster, ray from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput @@ -345,7 +345,10 @@ def from_engine_args( # Create the engine configs. engine_config = engine_args.create_engine_config() - if engine_config.device_config.device_type == "neuron": + if True: + from vllm.executor.cpu_executor import CPUExecutorAsync + executor_class = CPUExecutorAsync + elif engine_config.device_config.device_type == "neuron": from vllm.executor.neuron_executor import NeuronExecutorAsync executor_class = NeuronExecutorAsync elif engine_config.device_config.device_type == "cpu": @@ -424,21 +427,21 @@ def start_background_loop(self) -> None: def _init_engine(self, *args, **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]: - if not self.engine_use_ray: - engine_class = self._engine_class - elif self.worker_use_ray: - engine_class = ray.remote(num_cpus=0)(self._engine_class).remote - else: - # FIXME(woosuk): This is a bit hacky. Be careful when changing the - # order of the arguments. - cache_config = kwargs["cache_config"] - parallel_config = kwargs["parallel_config"] - if parallel_config.tensor_parallel_size == 1: - num_gpus = cache_config.gpu_memory_utilization - else: - num_gpus = 1 - engine_class = ray.remote(num_gpus=num_gpus)( - self._engine_class).remote + engine_class = self._engine_class + # if not self.engine_use_ray: + # elif self.worker_use_ray: + # engine_class = ray.remote(num_cpus=0)(self._engine_class).remote + # else: + # # FIXME(woosuk): This is a bit hacky. Be careful when changing the + # # order of the arguments. + # cache_config = kwargs["cache_config"] + # parallel_config = kwargs["parallel_config"] + # if parallel_config.tensor_parallel_size == 1: + # num_gpus = cache_config.gpu_memory_utilization + # else: + # num_gpus = 1 + # engine_class = ray.remote(num_gpus=num_gpus)( + # self._engine_class).remote return engine_class(*args, **kwargs) async def engine_step(self) -> bool: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index b9938b045ba2b..3b1798097aafe 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -169,7 +169,7 @@ def __init__( load_config=load_config, ) - self._initialize_kv_caches() + # self._initialize_kv_caches() # If usage stat is enabled, collect relevant info. if is_usage_stats_enabled(): @@ -259,7 +259,7 @@ def _initialize_kv_caches(self) -> None: self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks - self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks) + # self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks) @classmethod def from_engine_args( @@ -272,7 +272,10 @@ def from_engine_args( engine_config = engine_args.create_engine_config() # Initialize the cluster and specify the executor class. 
- if engine_config.device_config.device_type == "neuron": + if True: + from vllm.executor.cpu_executor import CPUExecutor + executor_class = CPUExecutor + elif engine_config.device_config.device_type == "neuron": from vllm.executor.neuron_executor import NeuronExecutor executor_class = NeuronExecutor elif engine_config.device_config.device_type == "cpu": @@ -516,7 +519,9 @@ def _process_model_outputs( self.output_processor.process_outputs(seq_group, outputs) # Free the finished sequence groups. - self.scheduler.free_finished_seq_groups() + free_xft_seq_ids = self.scheduler.free_finished_seq_groups() + self.model_executor.free_xft_cache(free_xft_seq_ids) + # Create the outputs. request_outputs: List[RequestOutput] = [] @@ -636,16 +641,16 @@ def _get_stats( num_waiting_sys = len(self.scheduler.waiting) # KV Cache Usage in % - num_total_gpu = self.cache_config.num_gpu_blocks - num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks() - gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu) + # num_total_gpu = self.cache_config.num_gpu_blocks + # num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks() + gpu_cache_usage_sys = 0. - num_total_cpu = self.cache_config.num_cpu_blocks + # num_total_cpu = self.cache_config.num_cpu_blocks cpu_cache_usage_sys = 0. - if num_total_cpu > 0: - num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks( - ) - cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) + # if num_total_cpu > 0: + # num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks( + # ) + # cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) # Iteration stats num_prompt_tokens_iter = 0 diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 3ed660e183360..a6b4e841c7fa6 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -85,7 +85,7 @@ def __init__( skip_tokenizer_init: bool = False, trust_remote_code: bool = False, tensor_parallel_size: int = 1, - dtype: str = "auto", + dtype: str = "fp16", quantization: Optional[str] = None, revision: Optional[str] = None, tokenizer_revision: Optional[str] = None, diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index a2212459f034e..4cc14467b6af6 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -19,8 +19,9 @@ class CPUExecutor(ExecutorBase): def _init_executor(self) -> None: assert self.device_config.device_type == "cpu" assert self.lora_config is None, "cpu backend doesn't support LoRA" - self.model_config = _verify_and_get_model_config(self.model_config) - self.cache_config = _verify_and_get_cache_config(self.cache_config) + # self.model_config = _verify_and_get_model_config(self.model_config) + # self.cache_config = _verify_and_get_cache_config(self.cache_config) + self.cache_config.enable_prefix_caching = False self.scheduler_config = _verify_and_get_scheduler_config( self.scheduler_config) @@ -50,7 +51,7 @@ def _init_worker(self): kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=True, ) - self.driver_worker.init_device() + # self.driver_worker.init_device() self.driver_worker.load_model() def determine_num_available_blocks(self) -> Tuple[int, int]: @@ -91,6 +92,9 @@ def check_health(self) -> None: # CPUExecutor will always be healthy as long as # it's running. 
return + + def free_xft_cache(self, xft_seq_ids:List[int]) -> bool: + return self.driver_worker.free_xft_cache(xft_seq_ids) class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase): diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 08aa58999b1ec..78f45f21b5cd8 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -98,6 +98,10 @@ def shutdown(self) -> None: def __del__(self): self.shutdown() + + def free_xft_cache(self, xft_seq_ids:List[int]) -> bool: + return False + class ExecutorAsyncBase(ExecutorBase): diff --git a/vllm/sequence.py b/vllm/sequence.py index f2939eff7959b..10f7adb87734a 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -124,6 +124,7 @@ def __init__( # The number of tokens that are computed (that run against the model). self._num_computed_tokens = 0 self._stage: SequenceStage = SequenceStage.PREFILL + self.xft_ids = -1 def append_token_id(self, token_id: int, logprob: float) -> None: self.output_token_ids.append(token_id) diff --git a/vllm/utils.py b/vllm/utils.py index b06c8508757c5..63f4a35ccb225 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -122,11 +122,13 @@ def clear(self): def is_hip() -> bool: + return False return torch.version.hip is not None @lru_cache(maxsize=None) def is_cpu() -> bool: + return True from importlib.metadata import PackageNotFoundError, version try: return "cpu" in version("vllm") @@ -136,6 +138,7 @@ def is_cpu() -> bool: @lru_cache(maxsize=None) def is_neuron() -> bool: + return False try: import transformers_neuronx except ImportError: diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 193b021b7a11e..39bb472ec18c0 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -13,6 +13,9 @@ from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import make_tensor_with_pad +from vllm.model_executor.layers.sampler import Sampler +import xfastertransformer + logger = init_logger(__name__) _PAD_SLOT_ID = -1 @@ -29,7 +32,7 @@ def __init__( load_config: LoadConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], - kv_cache_dtype: Optional[str] = "auto", + kv_cache_dtype: Optional[str] = "fp16", is_driver_worker: bool = False, *args, **kwargs, @@ -58,26 +61,31 @@ def __init__( self.model_config.dtype if model_config is not None else None) # Lazy initialization. - self.model: nn.Module # Set after init_Model + self.model: xfastertransformer.AutoModel # Set after init_Model self.block_size: int # Set after initial profiling. 
def load_model(self) -> None: - self.model = get_model( - model_config=self.model_config, - load_config=self.load_config, - device_config=self.device_config, - vision_language_config=self.vision_language_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) + # self.model = get_model( + # model_config=self.model_config, + # load_config=self.load_config, + # device_config=self.device_config, + # vision_language_config=self.vision_language_config, + # lora_config=self.lora_config, + # parallel_config=self.parallel_config, + # scheduler_config=self.scheduler_config) + self.model = xfastertransformer.AutoModel.from_pretrained( + self.model_config.model, self.model_config.dtype, self.kv_cache_dtype + ) + # self.model = None + self.sampler = Sampler() def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], + ) -> Tuple[List[List[int]], torch.Tensor, AttentionMetadata, List[int], Optional[torch.Tensor]]: assert len(seq_group_metadata_list) > 0 - input_tokens: List[int] = [] + input_tokens = [] input_positions: List[int] = [] slot_mapping: List[int] = [] seq_lens: List[int] = [] @@ -95,7 +103,7 @@ def _prepare_prompt( seq_len = len(prompt_tokens) seq_lens.append(seq_len) # Prompt token num - input_tokens.extend(prompt_tokens) # Token ids + input_tokens.append(prompt_tokens) # Token ids # Token position ids # NOTE(woosuk): Here we assume that the first token in the prompt @@ -107,26 +115,26 @@ def _prepare_prompt( seq_group_metadata.multi_modal_data.data) # Compute the slot mapping. - block_table = seq_group_metadata.block_tables[seq_id] + # block_table = seq_group_metadata.block_tables[seq_id] # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, # where start_idx is max(0, seq_len - sliding_window). # For example, if the prompt len is 10, sliding window is 8, and # block size is 4, the first two tokens are masked and the slot # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. 
- start_idx = 0 - if self.sliding_window is not None: - start_idx = max(0, seq_len - self.sliding_window) + # start_idx = 0 + # if self.sliding_window is not None: + # start_idx = max(0, seq_len - self.sliding_window) - for i in range(computed_len, seq_len): - if i < start_idx: - slot_mapping.append(_PAD_SLOT_ID) - continue + # for i in range(computed_len, seq_len): + # if i < start_idx: + # slot_mapping.append(_PAD_SLOT_ID) + # continue - block_number = block_table[i // - self.block_size] # type: ignore - block_offset = i % self.block_size # type: ignore - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) + # block_number = block_table[i // + # self.block_size] # type: ignore + # block_offset = i % self.block_size # type: ignore + # slot = block_number * self.block_size + block_offset + # slot_mapping.append(slot) if multi_modal_input_list: assert self.vision_language_config, ( @@ -139,9 +147,9 @@ def _prepare_prompt( num_prompt_tokens = len(input_tokens) - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) # type: ignore + # input_tokens = torch.tensor(input_tokens, + # dtype=torch.long, + # device=self.device) # type: ignore input_positions = torch.tensor(input_positions, dtype=torch.long, device=self.device) # type: ignore @@ -169,13 +177,14 @@ def _prepare_prompt( def _prepare_decode( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]: + ) -> Tuple[torch.Tensor, torch.Tensor, List[int], AttentionMetadata]: assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] seq_lens: List[int] = [] block_tables: List[List[int]] = [] + xft_seq_ids: List[int] = [] for seq_group_metadata in seq_group_metadata_list: assert not seq_group_metadata.is_prompt @@ -187,6 +196,7 @@ def _prepare_decode( seq_data = seq_group_metadata.seq_data[seq_id] generation_token = seq_data.get_last_token_id() input_tokens.append(generation_token) + xft_seq_ids.append(seq_data.xft_ids) seq_len = seq_data.get_len() position = seq_len - 1 @@ -196,23 +206,23 @@ def _prepare_decode( seq_len, self.sliding_window) seq_lens.append(seq_len) - block_table = seq_group_metadata.block_tables[seq_id] - block_number = block_table[position // self.block_size] - block_offset = position % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) + # block_table = seq_group_metadata.block_tables[seq_id] + # block_number = block_table[position // self.block_size] + # block_offset = position % self.block_size + # slot = block_number * self.block_size + block_offset + # slot_mapping.append(slot) - if self.sliding_window is not None: - sliding_window_blocks = (self.sliding_window // - self.block_size) - block_table = block_table[-sliding_window_blocks:] - block_tables.append(block_table) + # if self.sliding_window is not None: + # sliding_window_blocks = (self.sliding_window // + # self.block_size) + # block_table = block_table[-sliding_window_blocks:] + # block_tables.append(block_table) max_seq_len = max(seq_lens) input_tokens = torch.tensor(input_tokens, dtype=torch.long, - device=self.device) + device=self.device).unsqueeze(1) input_positions = torch.tensor(input_positions, dtype=torch.long, device=self.device) @@ -223,15 +233,15 @@ def _prepare_decode( dtype=torch.int, device=self.device) - max_block_table_len = max( - len(block_table) for block_table in block_tables) - block_tables = 
make_tensor_with_pad( - block_tables, - max_len=max_block_table_len, - pad=0, - dtype=torch.int, - device=self.device, - ) + # max_block_table_len = max( + # len(block_table) for block_table in block_tables) + # block_tables = make_tensor_with_pad( + # block_tables, + # max_len=max_block_table_len, + # pad=0, + # dtype=torch.int, + # device=self.device, + # ) attn_metadata = self.attn_backend.make_metadata( is_prompt=False, @@ -244,21 +254,25 @@ def _prepare_decode( num_prefills=0, prefill_metadata=None, decode_metadata=None, - block_tables=block_tables, + block_tables=torch.tensor([]), kv_cache_dtype=self.kv_cache_dtype, ) return ( input_tokens, input_positions, + xft_seq_ids, attn_metadata, ) def prepare_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, + ) -> Tuple[List[List[int]], torch.Tensor, Optional[List[int]], Optional[List[int]], AttentionMetadata, SamplingMetadata, Optional[torch.Tensor]]: multi_modal_input = None + # xft_seq_ids is None for prompts and xft_max_lens is None for decodes + xft_seq_ids = None + xft_max_lens = None if self.is_driver_worker: # NOTE: We assume that all sequences in the group are all prompts or # all decodes. @@ -269,7 +283,7 @@ def prepare_input_tensors( multi_modal_input ) = self._prepare_prompt(seq_group_metadata_list) else: - (input_tokens, input_positions, + (input_tokens, input_positions, xft_seq_ids, attn_metadata) = self._prepare_decode(seq_group_metadata_list) seq_lens = [] sampling_metadata = SamplingMetadata.prepare( @@ -281,15 +295,19 @@ def prepare_input_tensors( seq_lens, self.device, pin_memory=False) + if is_prompt: + xft_max_lens = [] + for i in range(len(sampling_metadata.seq_groups)): + xft_max_lens.append(sampling_metadata.seq_groups[i].sampling_params.max_tokens + seq_lens[i]) # Broadcast the metadata. 
- metadata_dict = { - "input_tokens": input_tokens, - "input_positions": input_positions, - "selected_token_indices": - sampling_metadata.selected_token_indices, - } - metadata_dict.update(attn_metadata.asdict_zerocopy()) - broadcast_tensor_dict(metadata_dict, src=0) + # metadata_dict = { + # "input_tokens": input_tokens, + # "input_positions": input_positions, + # "selected_token_indices": + # sampling_metadata.selected_token_indices, + # } + # metadata_dict.update(attn_metadata.asdict_zerocopy()) + # broadcast_tensor_dict(metadata_dict, src=0) else: metadata_dict = broadcast_tensor_dict(src=0) input_tokens = metadata_dict.pop("input_tokens") @@ -306,41 +324,34 @@ def prepare_input_tensors( generators=None, ) - return (input_tokens, input_positions, attn_metadata, + return (input_tokens, input_positions, xft_seq_ids, xft_max_lens, attn_metadata, sampling_metadata, multi_modal_input) - @torch.inference_mode() + # @torch.inference_mode() def execute_model( self, seq_group_metadata_list: List[SequenceGroupMetadata], - kv_caches: List[torch.Tensor], + kv_caches: Optional[List[torch.Tensor]], ) -> Optional[SamplerOutput]: - (input_tokens, input_positions, attn_metadata, sampling_metadata, + (input_tokens, input_positions, xft_seq_ids, xft_max_lens, attn_metadata, sampling_metadata, multi_modal_input ) = self.prepare_input_tensors(seq_group_metadata_list) - model_executable = self.model - execute_model_kwargs = { - "input_ids": input_tokens, - "positions": input_positions, - "kv_caches": kv_caches, - "attn_metadata": attn_metadata, - } - if self.vision_language_config: - execute_model_kwargs.update({"image_input": multi_modal_input}) + xft_seq_ids = self.model.set_input_cb(input_tokens, xft_seq_ids, xft_max_lens).tolist() - hidden_states = model_executable(**execute_model_kwargs) + if seq_group_metadata_list[0].is_prompt: + for i in range(len(xft_seq_ids)): + seq_id = list(seq_group_metadata_list[i].seq_data.keys())[0] + seq_group_metadata_list[i].seq_data[seq_id].xft_ids = xft_seq_ids[i] # Compute the logits. - logits = self.model.compute_logits(hidden_states, sampling_metadata) - - # Only perform sampling in the driver worker. - if not self.is_driver_worker: - return None + logits = self.model.forward_cb() # Sample the next token. 
- output = self.model.sample( - logits=logits, - sampling_metadata=sampling_metadata, - ) + output = self.sampler(logits, sampling_metadata) return output + + def free_xft_cache(self, xft_seq_ids:List[int]) -> bool: + return self.model.free_seqs( + torch.tensor(xft_seq_ids, dtype=torch.long, device=self.device) + ) diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 4420d4cc9e12f..417d0e6920dea 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -264,31 +264,31 @@ def execute_model( else: seq_group_metadata_list = execute_model_req.seq_group_metadata_list - if self.is_driver_worker: + if True: assert seq_group_metadata_list is not None num_seq_groups: int = len(seq_group_metadata_list) assert execute_model_req is not None blocks_to_copy = execute_model_req.blocks_to_copy assert len(execute_model_req.blocks_to_swap_in) == 0 assert len(execute_model_req.blocks_to_swap_out) == 0 - data: Dict[str, Any] = { - "num_seq_groups": num_seq_groups, - "blocks_to_copy": execute_model_req.blocks_to_copy, - } - broadcast_tensor_dict(data, src=0) + # data: Dict[str, Any] = { + # "num_seq_groups": num_seq_groups, + # "blocks_to_copy": execute_model_req.blocks_to_copy, + # } + # broadcast_tensor_dict(data, src=0) else: data = broadcast_tensor_dict(src=0) num_seq_groups = data["num_seq_groups"] blocks_to_copy = data["blocks_to_copy"] - self.cache_copy(blocks_to_copy) + # self.cache_copy(blocks_to_copy) # If there is no input, we don't need to execute the model. if num_seq_groups == 0: return [] output = self.model_runner.execute_model(seq_group_metadata_list, - self.cpu_cache) + None) # CPU worker only supports single-step execution. return [output] @@ -319,3 +319,6 @@ def get_cache_block_size_bytes(self) -> int: return CPUCacheEngine.get_cache_block_size( self.cache_config.block_size, self.cache_config.cache_dtype, self.model_config, self.parallel_config) + + def free_xft_cache(self, xft_seq_ids:List[int]) -> bool: + return self.model_runner.free_xft_cache(xft_seq_ids) \ No newline at end of file
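Taken together, the `cpu_model_runner.py` and `cpu_worker.py` changes drive xFasterTransformer's continuous-batching API directly instead of vLLM's paged KV cache: prefill calls `set_input_cb` with token lists and per-sequence max lengths and records the returned xFT sequence ids, decode steps feed one new token per running sequence addressed by those ids, and the ids of finished sequences are released through `free_seqs`. The following is a rough standalone sketch of that lifecycle, inferred from the calls in this patch; the model path, token ids, and the greedy loop are illustrative only and not part of the patch.

```python
# Illustrative sketch of the xFT sequence lifecycle used by this patch; the
# model path and token ids are placeholders, and greedy argmax stands in for
# vLLM's Sampler. Mirrors the calls in vllm/worker/cpu_model_runner.py above.
import torch
import xfastertransformer

model = xfastertransformer.AutoModel.from_pretrained(
    "/data/llama-2-7b-chat-cpu",  # converted xFT model directory (placeholder)
    "bf16",                       # model dtype
    "fp16",                       # KV cache dtype
)

prompts = [[1, 15043, 3186], [1, 3087, 9038]]  # pre-tokenized prompts (placeholders)
max_lens = [len(p) + 32 for p in prompts]      # prompt length + max new tokens, per sequence

# Prefill: xFT assigns an internal sequence id (KV-cache slot) per prompt.
xft_seq_ids = model.set_input_cb(prompts, None, max_lens).tolist()
logits = model.forward_cb()

# Decode: feed one new token per running sequence, addressed by its xFT id.
for _ in range(32):
    next_tokens = torch.argmax(logits, dim=-1).unsqueeze(1)
    model.set_input_cb(next_tokens, xft_seq_ids, None)
    logits = model.forward_cb()

# Free the per-sequence KV cache once the sequences are finished
# (what Scheduler.free_finished_seq_groups -> free_xft_cache does).
model.free_seqs(torch.tensor(xft_seq_ids, dtype=torch.long))
```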