diff --git a/Dockerfile.neuron b/Dockerfile.neuron
index 47e40e015239a..2143315d2a078 100644
--- a/Dockerfile.neuron
+++ b/Dockerfile.neuron
@@ -31,7 +31,7 @@ RUN --mount=type=bind,source=.git,target=.git \
     if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi
 
 RUN python3 -m pip install -U \
-        'cmake>=3.26,<=3.30' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
+        'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
         -r requirements-neuron.txt
 
 ENV VLLM_TARGET_DEVICE neuron
diff --git a/Dockerfile.ppc64le b/Dockerfile.ppc64le
index c2a40000aab4b..b19c6ddec7948 100644
--- a/Dockerfile.ppc64le
+++ b/Dockerfile.ppc64le
@@ -21,7 +21,7 @@ RUN --mount=type=bind,source=.git,target=.git \
 # These packages will be in rocketce eventually
 RUN --mount=type=cache,target=/root/.cache/pip \
     pip install -v --prefer-binary --extra-index-url https://repo.fury.io/mgiessing \
-    'cmake>=3.26,<=3.30' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
+    'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
     torch==2.3.1 \
     -r requirements-cpu.txt \
     xformers uvloop==0.20.0
diff --git a/docs/source/_static/custom.js b/docs/source/_static/custom.js
index dac40ca2cfe75..18b502c786e1d 100644
--- a/docs/source/_static/custom.js
+++ b/docs/source/_static/custom.js
@@ -8,7 +8,9 @@ document.addEventListener("DOMContentLoaded", function () {
   script.setAttribute("version", "stable");
   script.setAttribute("runllm-keyboard-shortcut", "Mod+j"); // cmd-j or ctrl-j to open the widget.
   script.setAttribute("runllm-name", "vLLM");
-  script.setAttribute("runllm-position", "TOP_RIGHT");
+  script.setAttribute("runllm-position", "BOTTOM_RIGHT");
+  script.setAttribute("runllm-position-y", "20%");
+  script.setAttribute("runllm-position-x", "3%");
   script.setAttribute("runllm-assistant-id", "207");
 
   script.async = true;
diff --git a/docs/source/getting_started/cpu-installation.rst b/docs/source/getting_started/cpu-installation.rst
index 6bf170b164fb8..69530fd778c55 100644
--- a/docs/source/getting_started/cpu-installation.rst
+++ b/docs/source/getting_started/cpu-installation.rst
@@ -62,7 +62,7 @@ Build from source
 .. code-block:: console
 
     $ pip install --upgrade pip
-    $ pip install cmake>=3.26,<=3.30 wheel packaging ninja "setuptools-scm>=8" numpy
+    $ pip install cmake>=3.26 wheel packaging ninja "setuptools-scm>=8" numpy
     $ pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
 
 - Finally, build and install vLLM CPU backend:
diff --git a/docs/source/getting_started/debugging.rst b/docs/source/getting_started/debugging.rst
index 060599680be25..77bf550601346 100644
--- a/docs/source/getting_started/debugging.rst
+++ b/docs/source/getting_started/debugging.rst
@@ -20,6 +20,10 @@ Hangs loading a model from disk
 
 If the model is large, it can take a long time to load it from disk. Pay attention to where you store the model. Some clusters have shared filesystems across nodes, e.g. a distributed filesystem or a network filesystem, which can be slow. It'd be better to store the model in a local disk. Additionally, have a look at the CPU memory usage, when the model is too large it might take a lot of CPU memory, slowing down the operating system because it needs to frequently swap between disk and memory.
 
+.. note::
+
+    To isolate the model downloading and loading issue, you can use the ``--load-format dummy`` argument to skip loading the model weights. This way, you can check if the model downloading and loading is the bottleneck.
+
 Model is too large
 ----------------------------------------
 If the model is too large to fit in a single GPU, you might want to `consider tensor parallelism `_ to split the model across multiple GPUs. In that case, every process will read the whole model and split it into chunks, which makes the disk reading time even longer (proportional to the size of tensor parallelism). You can convert the model checkpoint to a sharded checkpoint using `this example `_ . The conversion process might take some time, but later you can load the sharded checkpoint much faster. The model loading time should remain constant regardless of the size of tensor parallelism.
diff --git a/pyproject.toml b/pyproject.toml
index 3be401daa44c7..3c8c46cc8621e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
 [build-system]
 # Should be mirrored in requirements-build.txt
 requires = [
-    "cmake>=3.26,<=3.30",
+    "cmake>=3.26",
     "ninja",
     "packaging",
     "setuptools>=61",
diff --git a/requirements-build.txt b/requirements-build.txt
index 64b92861df25d..fec01caaf25ef 100644
--- a/requirements-build.txt
+++ b/requirements-build.txt
@@ -1,5 +1,5 @@
 # Should be mirrored in pyproject.toml
-cmake>=3.26,<=3.30
+cmake>=3.26
 ninja
 packaging
 setuptools>=61
diff --git a/requirements-tpu.txt b/requirements-tpu.txt
index 94a3225dcf479..f9a0770804e55 100644
--- a/requirements-tpu.txt
+++ b/requirements-tpu.txt
@@ -2,7 +2,7 @@
 -r requirements-common.txt
 
 # Dependencies for TPU
-cmake>=3.26,<=3.30
+cmake>=3.26
 ninja
 packaging
 setuptools-scm>=8
diff --git a/requirements-xpu.txt b/requirements-xpu.txt
index 479cb4bb18484..e41295792283f 100644
--- a/requirements-xpu.txt
+++ b/requirements-xpu.txt
@@ -2,7 +2,7 @@
 -r requirements-common.txt
 
 ray >= 2.9
-cmake>=3.26,<=3.30
+cmake>=3.26
 ninja
 packaging
 setuptools-scm>=8
diff --git a/tests/distributed/test_utils.py b/tests/distributed/test_utils.py
index 5d77d8abb4718..50444d3abfaf2 100644
--- a/tests/distributed/test_utils.py
+++ b/tests/distributed/test_utils.py
@@ -43,12 +43,15 @@ def test_cuda_device_count_stateless():
 
 
 def cpu_worker(rank, WORLD_SIZE, port1, port2):
-    pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
+    pg1 = StatelessProcessGroup.create(host="127.0.0.1",
+                                       port=port1,
                                        rank=rank,
                                        world_size=WORLD_SIZE)
     if rank <= 2:
-        pg2 = StatelessProcessGroup.create(
-            init_method=f"tcp://127.0.0.1:{port2}", rank=rank, world_size=3)
+        pg2 = StatelessProcessGroup.create(host="127.0.0.1",
+                                           port=port2,
+                                           rank=rank,
+                                           world_size=3)
     data = torch.tensor([rank])
     data = pg1.broadcast_obj(data, src=2)
     assert data.item() == 2
@@ -62,14 +65,17 @@ def cpu_worker(rank, WORLD_SIZE, port1, port2):
 
 def gpu_worker(rank, WORLD_SIZE, port1, port2):
     torch.cuda.set_device(rank)
-    pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
+    pg1 = StatelessProcessGroup.create(host="127.0.0.1",
+                                       port=port1,
                                        rank=rank,
                                        world_size=WORLD_SIZE)
     pynccl1 = PyNcclCommunicator(pg1, device=rank)
     pynccl1.disabled = False
     if rank <= 2:
-        pg2 = StatelessProcessGroup.create(
-            init_method=f"tcp://127.0.0.1:{port2}", rank=rank, world_size=3)
+        pg2 = StatelessProcessGroup.create(host="127.0.0.1",
+                                           port=port2,
+                                           rank=rank,
+                                           world_size=3)
         pynccl2 = PyNcclCommunicator(pg2, device=rank)
         pynccl2.disabled = False
         data = torch.tensor([rank]).cuda()
@@ -89,7 +95,8 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2):
 
 
 def broadcast_worker(rank, WORLD_SIZE, port1, port2):
-    pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
+    pg1 = StatelessProcessGroup.create(host="127.0.0.1",
+                                       port=port1,
                                        rank=rank,
                                        world_size=WORLD_SIZE)
     if rank == 2:
@@ -101,7 +108,8 @@ def broadcast_worker(rank, WORLD_SIZE, port1, port2):
 
 
 def allgather_worker(rank, WORLD_SIZE, port1, port2):
-    pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
+    pg1 = StatelessProcessGroup.create(host="127.0.0.1",
+                                       port=port1,
                                        rank=rank,
                                        world_size=WORLD_SIZE)
     data = pg1.all_gather_obj(rank)
@@ -109,8 +117,6 @@ def allgather_worker(rank, WORLD_SIZE, port1, port2):
     pg1.barrier()
 
 
-# TODO: investigate why this test is flaky. It hangs during initialization.
-@pytest.mark.skip("Skip the test because it is flaky.")
 @multi_gpu_test(num_gpus=4)
 @pytest.mark.parametrize(
     "worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker])
diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py
index 8451aac33acc4..b3692b594326a 100644
--- a/tests/v1/engine/test_engine_core.py
+++ b/tests/v1/engine/test_engine_core.py
@@ -27,6 +27,9 @@ def make_request() -> EngineCoreRequest:
         request_id=uuid.uuid4(),
         prompt=PROMPT,
         prompt_token_ids=PROMPT_TOKENS,
+        mm_data=None,
+        mm_placeholders=None,
+        mm_processor_kwargs=None,
         sampling_params=SamplingParams(),
         eos_token_id=None,
         arrival_time=time.time(),
diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py
index d582101a1164f..7b241bf836a0e 100644
--- a/tests/v1/engine/test_engine_core_client.py
+++ b/tests/v1/engine/test_engine_core_client.py
@@ -29,6 +29,9 @@ def make_request(params: SamplingParams) -> EngineCoreRequest:
         request_id=str(uuid.uuid4()),
         prompt=PROMPT,
         prompt_token_ids=PROMPT_TOKENS,
+        mm_data=None,
+        mm_placeholders=None,
+        mm_processor_kwargs=None,
         sampling_params=params,
         eos_token_id=None,
         arrival_time=time.time(),
diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py
index a77b41322f376..dcfcb848cbe06 100644
--- a/vllm/distributed/utils.py
+++ b/vllm/distributed/utils.py
@@ -9,7 +9,7 @@
 from typing import Any, Deque, Dict, Optional, Sequence, Tuple
 
 import torch
-from torch.distributed.rendezvous import rendezvous
+from torch.distributed import TCPStore
 
 import vllm.envs as envs
 from vllm.logger import init_logger
@@ -97,7 +97,6 @@ class StatelessProcessGroup:
     group. Only use it to communicate metadata between processes.
     For data-plane communication, create NCCL-related objects.
     """
-    prefix: str
     rank: int
     world_size: int
     store: torch._C._distributed_c10d.Store
@@ -127,7 +126,7 @@ def __post_init__(self):
     def send_obj(self, obj: Any, dst: int):
         """Send an object to a destination rank."""
         self.expire_data()
-        key = f"{self.prefix}/send_to/{dst}/{self.send_dst_counter[dst]}"
+        key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
         self.store.set(key, pickle.dumps(obj))
         self.send_dst_counter[dst] += 1
         self.entries.append((key, time.time()))
@@ -147,8 +146,7 @@ def recv_obj(self, src: int) -> Any:
         """Receive an object from a source rank."""
         obj = pickle.loads(
             self.store.get(
-                f"{self.prefix}/send_to/{self.rank}/{self.recv_src_counter[src]}"
-            ))
+                f"send_to/{self.rank}/{self.recv_src_counter[src]}"))
         self.recv_src_counter[src] += 1
         return obj
 
@@ -159,14 +157,14 @@ def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:
         """
         if self.rank == src:
             self.expire_data()
-            key = (f"{self.prefix}/broadcast_from/{src}/"
+            key = (f"broadcast_from/{src}/"
                    f"{self.broadcast_send_counter}")
             self.store.set(key, pickle.dumps(obj))
             self.broadcast_send_counter += 1
             self.entries.append((key, time.time()))
             return obj
         else:
-            key = (f"{self.prefix}/broadcast_from/{src}/"
+            key = (f"broadcast_from/{src}/"
                    f"{self.broadcast_recv_src_counter[src]}")
             recv_obj = pickle.loads(self.store.get(key))
             self.broadcast_recv_src_counter[src] += 1
@@ -194,7 +192,8 @@ def barrier(self):
 
     @staticmethod
     def create(
-        init_method: str,
+        host: str,
+        port: int,
         rank: int,
         world_size: int,
        data_expiration_seconds: int = 3600,
@@ -214,15 +213,14 @@ def create(
         can call `StatelessProcessGroup.create` to form a group, and then
         process A, B, C, and D can call `StatelessProcessGroup.create` to form
         another group.
         """ # noqa
-        from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT
-        timeout = _DEFAULT_PG_TIMEOUT
-
-        store, rank, world_size = next(
-            rendezvous(init_method, rank, world_size, timeout=timeout))
-        store.set_timeout(timeout)
+        store = TCPStore(
+            host_name=host,
+            port=port,
+            world_size=world_size,
+            is_master=(rank == 0),
+        )
         return StatelessProcessGroup(
-            prefix=init_method,
             rank=rank,
             world_size=world_size,
             store=store,
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 808c3936b6c35..428483bdb29cb 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -317,7 +317,7 @@ def process_input_socket(self, input_path: str):
 
         # Msgpack serialization decoding.
         decoder_add_req = PickleEncoder()
-        decoder_abort_req = msgpack.Decoder(list[str])
+        decoder_abort_req = PickleEncoder()
 
         with self.make_socket(input_path, zmq.constants.PULL) as socket:
             while True:
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 2c40853742ac9..db676e2819bf4 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -404,14 +404,17 @@ def execute_model(
 
     def load_model(self) -> None:
         if self.use_cuda_graph:
-            # FIXME(woosuk): Currently, we do not use inductor to reduce the
-            # compilation time and any potential issues with the inductor.
-            os.environ["VLLM_CUSTOM_OPS"] = "all"
+            # NOTE(woosuk): Currently, we use inductor because the piecewise
+            # CUDA graphs do not work properly with the custom CUDA kernels.
+            # FIXME(woosuk): Disable inductor to reduce the compilation time
+            # and avoid any potential issues with the inductor.
+            os.environ["VLLM_CUSTOM_OPS"] = "none"
             set_compilation_config(
                 CompilationConfig(
                     use_cudagraph=True,
                     non_cudagraph_ops=["vllm.unified_v1_flash_attention"],
-                    use_inductor=False,
+                    use_inductor=True,
+                    enable_fusion=False,
                 ))
 
         logger.info("Starting to load model %s...", self.model_config.model)
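
Note (not part of the patch): the vllm/distributed/utils.py hunks above replace the `init_method` URL rendezvous with an explicit `host`/`port` pair backed by a plain `TCPStore`, and the test hunks show how callers change. The sketch below is a rough, illustrative usage example of the reworked `StatelessProcessGroup.create` API under these assumptions; the port number and the `multiprocessing` driver are arbitrary choices for illustration, not something this diff adds.

# Illustrative sketch only: exercising StatelessProcessGroup.create(host=..., port=...)
# as changed in this diff. Port 29510 and the multiprocessing driver are assumptions.
import multiprocessing

from vllm.distributed.utils import StatelessProcessGroup


def worker(rank: int, world_size: int, port: int) -> None:
    # Every rank passes the same host/port; rank 0 hosts the underlying TCPStore.
    pg = StatelessProcessGroup.create(host="127.0.0.1",
                                      port=port,
                                      rank=rank,
                                      world_size=world_size)
    # Control-plane traffic only: small picklable objects, not tensor data.
    meta = pg.broadcast_obj({"ready": True} if rank == 0 else None, src=0)
    assert meta == {"ready": True}
    assert pg.all_gather_obj(rank) == list(range(world_size))
    pg.barrier()


if __name__ == "__main__":
    world_size, port = 2, 29510  # port chosen arbitrarily for this example
    procs = [
        multiprocessing.Process(target=worker, args=(r, world_size, port))
        for r in range(world_size)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()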