
[Usage]: How to use Multi-instance in vLLM? (Model replication on multiple GPUs) #6155

Open
KimMinSang96 opened this issue Jul 5, 2024 · 12 comments
Labels
usage How to use vllm

Comments

@KimMinSang96

KimMinSang96 commented Jul 5, 2024

I would like to use something like the Multi-instance Support provided by the tensorrt-llm backend. In its documentation, I can see that multiple models are served using modes like Leader mode and Orchestrator mode. Does vLLM support this functionality itself, or should I implement it similarly to the tensorrt-llm backend?

Here is a reference URL: https://github.com/triton-inference-server/tensorrtllm_backend?tab=readme-ov-file#leader-mode
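
To make the question concrete, below is the kind of setup I have in mind: several replicas of the same model, each pinned to its own GPU and serving on its own port, with some router/load balancer in front. The model name, GPU indices, and ports are just placeholders.

# replica 1 on GPU 0
CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server \
    --model meta-llama/Meta-Llama-3-8B-Instruct --port 8000

# replica 2 on GPU 1
CUDA_VISIBLE_DEVICES=1 python -m vllm.entrypoints.openai.api_server \
    --model meta-llama/Meta-Llama-3-8B-Instruct --port 8001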

KimMinSang96 added the usage label on Jul 5, 2024
@stas00
Contributor

stas00 commented Aug 1, 2024

It works fine in online mode - you just create multiple servers (even reusing the same GPUs!) - but indeed it doesn't work in offline mode. Here is an offline example on an 8x H100 node:

from vllm import LLM, SamplingParams

import multiprocessing

def main():

    # Both models are loaded onto the same 8 GPUs; gpu_memory_utilization
    # partitions the memory between them (0.65 + 0.25 leaves some headroom).
    llm1 = LLM(
        model="meta-llama/Meta-Llama-3-8B-Instruct",
        tensor_parallel_size=8,
        gpu_memory_utilization=0.65,
    )

    llm2 = LLM(
        model="microsoft/phi-1_5",
        tensor_parallel_size=8,
        gpu_memory_utilization=0.25,
    )


    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
    ]

    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    outputs = llm1.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    outputs = llm2.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == '__main__':
    multiprocessing.freeze_support()
    main()

and then:

VLLM_WORKER_MULTIPROC_METHOD=spawn python offline-2models-2.py

and it hangs while initializing the 2nd model:

INFO 08-01 01:10:37 llm_engine.py:176] Initializing an LLM engine (v0.5.3.post1) with config: model='microsoft/phi-1_5', speculative_config=None, tokenizer='microsoft/phi-1_5', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=8, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None), seed=0, served_model_name=microsoft/phi-1_5, use_v2_block_manager=False, enable_prefix_caching=False)
(VllmWorkerProcess pid=3445429) INFO 08-01 01:10:45 multiproc_worker_utils.py:215] Worker ready; awaiting tasks
(VllmWorkerProcess pid=3445426) INFO 08-01 01:10:45 multiproc_worker_utils.py:215] Worker ready; awaiting tasks
(VllmWorkerProcess pid=3445431) INFO 08-01 01:10:45 multiproc_worker_utils.py:215] Worker ready; awaiting tasks
(VllmWorkerProcess pid=3445427) INFO 08-01 01:10:45 multiproc_worker_utils.py:215] Worker ready; awaiting tasks
(VllmWorkerProcess pid=3445432) INFO 08-01 01:10:45 multiproc_worker_utils.py:215] Worker ready; awaiting tasks
(VllmWorkerProcess pid=3445428) INFO 08-01 01:10:45 multiproc_worker_utils.py:215] Worker ready; awaiting tasks
(VllmWorkerProcess pid=3445430) INFO 08-01 01:10:45 multiproc_worker_utils.py:215] Worker ready; awaiting tasks

@stas00
Contributor

stas00 commented Aug 1, 2024

The problem seems to be in some internal state that is not being isolated, even if I do:

    llm1 = LLM(
        model="meta-llama/Meta-Llama-3-8B-Instruct",
        tensor_parallel_size=8,
        gpu_memory_utilization=0.65,
    )

    del llm1

    llm2 = LLM(
        model="microsoft/phi-1_5",
        tensor_parallel_size=8,
        gpu_memory_utilization=0.25,
    )

it still hangs in the init of the 2nd model. While this del would be impractical for what we are trying to do, it demonstrates that vLLM can't handle multiple models in offline mode, which is a pity.

@njhill
Member

njhill commented Aug 1, 2024

@stas00 I have been debugging at least the latter case and will open a fix today. I can check whether it also works with concurrent LLMs, but I expect there may be additional isolation changes needed for that.

@stas00
Contributor

stas00 commented Aug 1, 2024

Thanks a lot for working on that, @njhill - that will help with disaggregation-style offline use of vLLM.

@mces89

mces89 commented Sep 4, 2024

@stas00 I wonder if it's possible to create multiple servers on the same GPU if GPU memory is not an issue?

@stas00
Contributor

stas00 commented Sep 4, 2024

With the online setup, yes, it would work (rough sketch below) - but this issue is about an offline recipe.

please read #6155 (comment)
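
For example, something along these lines should work for two servers sharing a single GPU. The models, ports, and memory fractions below are arbitrary - just make sure the gpu_memory_utilization values leave enough headroom on the card.

CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server \
    --model meta-llama/Meta-Llama-3-8B-Instruct --port 8000 --gpu-memory-utilization 0.6

CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server \
    --model microsoft/phi-1_5 --port 8001 --gpu-memory-utilization 0.3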

@russellb
Collaborator

@stas00 did the patch from njhill fix the issue you raised?

@stas00
Contributor

stas00 commented Oct 16, 2024

@njhill hasn't posted an update since his last note in #6155 (comment), so I wasn't able to validate it. Do you have the PR link? I'd be happy to re-test.

@njhill
Member

njhill commented Oct 17, 2024

@russellb @stas00 sorry, yes - the fixes I was referring to above were done in #8492. Hopefully your example above works cleanly now.

@stas00
Contributor

stas00 commented Oct 17, 2024

Thank you for the update, @njhill - there has been progress: it no longer hangs while initializing the 2nd model, but now it hangs here:

INFO 10-17 22:54:59 llm_engine.py:237] Initializing an LLM engine (v0.6.3.post1) with config: model='microsoft/phi-1_5', speculative_config=None, tokenizer='microsoft/phi-1_5', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=2, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=microsoft/phi-1_5, num_scheduler_steps=1, chunked_prefill_enabled=False multi_step_stream_outputs=True, enable_prefix_caching=False, use_async_output_proc=True, use_cached_outputs=False, mm_processor_kwargs=None)
WARNING 10-17 22:55:02 cuda.py:22] You are using a deprecated `pynvml` package. Please install `nvidia-ml-py` instead, and make sure to uninstall `pynvml`. When both of them are installed, `pynvml` will take precedence and cause errors. See https://pypi.org/project/pynvml for more information.
(VllmWorkerProcess pid=2977383) INFO 10-17 22:55:06 multiproc_worker_utils.py:215] Worker ready; awaiting tasks

Traceback:

Thread 2977044 (idle): "MainThread"
    wait (threading.py:320)
    wait (threading.py:607)
    get (vllm/executor/multiproc_worker_utils.py:51)
    <listcomp> (vllm/executor/multiproc_gpu_executor.py:196)
    _run_workers (vllm/executor/multiproc_gpu_executor.py:196)
    _init_executor (vllm/executor/multiproc_gpu_executor.py:110)
    __init__ (vllm/executor/executor_base.py:47)
    __init__ (vllm/executor/distributed_gpu_executor.py:26)
    __init__ (vllm/engine/llm_engine.py:334)
    from_engine_args (vllm/engine/llm_engine.py:573)
    __init__ (vllm/entrypoints/llm.py:177)
    main (offline-2models-2.py:30)
    <module> (offline-2models-2.py:46)
Thread 2977298 (idle): "Thread-3"
    wait (threading.py:324)
    wait (threading.py:607)
    run (tqdm/_monitor.py:60)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 2977356 (idle): "Thread-4"
    wait (threading.py:324)
    wait (threading.py:607)
    run (tqdm/_monitor.py:60)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 2977384 (idle): "Thread-5"
    _recv (multiprocessing/connection.py:379)
    _recv_bytes (multiprocessing/connection.py:414)
    recv_bytes (multiprocessing/connection.py:216)
    get (multiprocessing/queues.py:103)
    run (vllm/executor/multiproc_worker_utils.py:80)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 2977385 (idle): "Thread-6"
    select (selectors.py:416)
    wait (multiprocessing/connection.py:931)
    run (vllm/executor/multiproc_worker_utils.py:106)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 2977386 (idle): "QueueFeederThread"
    wait (threading.py:320)
    _feed (multiprocessing/queues.py:231)
    run (threading.py:953)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)

I'm running this example #6155 (comment) (same problem with or without del llm1 before creating llm2).

This is on v0.6.3.post1.
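
(For anyone trying to reproduce this: a stack dump like the one above can be captured with py-spy, e.g. py-spy dump --pid <pid of the main vllm process>, assuming py-spy is installed.)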

@njhill
Member

njhill commented Oct 17, 2024

😢 thanks @stas00, we can keep this open and I'll try to get to it soon.

@stas00
Contributor

stas00 commented Oct 17, 2024

Thank you, @njhill - my example can serve as a repro.

And of course we want more than one LLM object to co-exist rather than only run them sequentially (i.e. without del llm1 before instantiating llm2) - this opens up some interesting workflows where multiple models are used at once from the same program.

For example, one small model could be used to augment the prompt, and then the larger model could do the normal generation using the extended prompt.
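
Here is a minimal sketch of that workflow, assuming the offline multi-model hang above gets fixed (the model names, prompts, and sampling settings are just placeholders):

from vllm import LLM, SamplingParams

# Small helper model and large main model sharing the same GPUs,
# partitioned via gpu_memory_utilization.
small = LLM(
    model="microsoft/phi-1_5",
    tensor_parallel_size=8,
    gpu_memory_utilization=0.25,
)
large = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    tensor_parallel_size=8,
    gpu_memory_utilization=0.65,
)

user_prompt = "Explain how KV-cache sharing works in vLLM"

# Step 1: the small model augments/extends the prompt.
augmented = small.generate(
    [f"Rewrite this request, adding useful detail and context: {user_prompt}"],
    SamplingParams(temperature=0.7, max_tokens=128),
)[0].outputs[0].text

# Step 2: the large model generates the actual answer from the extended prompt.
answer = large.generate(
    [augmented],
    SamplingParams(temperature=0.8, top_p=0.95),
)[0].outputs[0].text
print(answer)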
