
[Core] Multiprocessing executor for single-node multi-GPU deployment #3466

Closed
wants to merge 8 commits

Conversation

Member

@njhill njhill commented Mar 18, 2024

Ray is a powerful platform for general-purpose distributed computing, but it is potentially overkill for the specific requirements of real-time, synchronized inference between GPUs on a single node.

We would prefer to have a "lightweight" option without the Ray dependency for non-Ray cluster environments. This also helps with production security compliance.

With the changes in this PR, Ray will continue to be used for parallel workers if it is installed; otherwise vanilla Python multiprocessing is used. The choice can also be overridden explicitly with --no-worker-use-ray.

Worker processes are shut down when the LLMEngine is garbage collected.

This PR was co-authored by @sahilsuneja1.


This reworks the original PR #2898 to plug into the new distributed executor abstraction.

I've introduced a MultiGPUExecutor abstract superclass shared between the Ray and vanilla multiprocessing implementations.
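
For reference, here is a minimal sketch of the default-selection behaviour described above, based on the importlib check that appears later in this diff (the standalone helper and its name are my own framing, not the PR's actual structure):

    import importlib.util

    def resolve_worker_use_ray(worker_use_ray, world_size: int) -> bool:
        # If the flag was not set explicitly (None), default to Ray only when
        # it is importable and more than one worker is needed.
        if worker_use_ray is None:
            ray_found = importlib.util.find_spec("ray") is not None
            worker_use_ray = ray_found and world_size > 1
        return worker_use_ray

    print(resolve_worker_use_ray(None, 2))   # True only if ray is installed
    print(resolve_worker_use_ray(False, 2))  # False: --no-worker-use-ray wins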

@njhill
Member Author

njhill commented Mar 21, 2024

@zhuohan123 @simon-mo this one should be ready to go! 🙏

@@ -322,19 +213,17 @@ def _run_workers(
method)(*driver_args, **driver_kwargs)

# Get the results of the ray workers.
if self.workers:
Member Author

This `if` was removed because workers should always be non-empty here (i.e. when TP > 1).

Member

I believe this `if` is indeed redundant. However, workers can be empty, since the Ray GPU executor can be run with TP=1.

@njhill njhill force-pushed the ray-optional2 branch 2 times, most recently from 7232261 to 1169213 on March 28, 2024 21:23
setup.py (outdated)
return prompt_token_ids


@pytest.mark.skip("Requires multiple GPUs")

Member Author

@simon-mo I've actually removed the new test now and just parameterized the existing distributed test to include Ray and non-Ray.
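
For illustration, a parameterization along these lines might look as follows (the test name and the way the flag is threaded through are assumptions, not the actual test in this PR):

    import pytest

    # Run the same distributed test once with Ray workers and once with
    # plain multiprocessing workers.
    @pytest.mark.parametrize("worker_use_ray", [True, False])
    def test_tensor_parallel_generation(worker_use_ray: bool):
        engine_args = ["--tensor-parallel-size", "2"]
        if not worker_use_ray:
            # The non-Ray path is selected with the new flag from this PR.
            engine_args.append("--no-worker-use-ray")
        assert ("--no-worker-use-ray" in engine_args) == (not worker_use_ray)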

vllm/engine/llm_engine.py (outdated)
Collaborator

Can we test some of the functionality from this file without LLM?

Member Author

OK, these tests are now added in local_worker_tests.py.

Collaborator

Deferring the review of the interface change to @zhuohan123.

@njhill njhill force-pushed the ray-optional2 branch 2 times, most recently from efd61b2 to fb15723 on March 29, 2024 17:22
@njhill
Member Author

njhill commented Mar 30, 2024

Thanks for the review @simon-mo! I've addressed all of your comments.

@njhill
Member Author

njhill commented Mar 31, 2024

@zhuohan123 @simon-mo @WoosukKwon I just tried some cursory performance comparisons and wasn't expecting the difference to be so significant. Surprisingly, Ray doesn't appear to give any latency benefit over single GPU for the config I tried.

Using 80GB A100s with llama-2-7b via the OpenAI completion API; a single request with 5 input tokens and 2000 generated tokens. I repeated each test request multiple times and the results were very consistent.

Configuration                                          Time (sec)   Difference
TP=1                                                   24.2         0
TP=2 using Ray (main or this PR)                       24.2         0
TP=2 without Ray (this PR with --no-worker-use-ray)    17.0         -30%

@ywang96
Member

ywang96 commented Mar 31, 2024

@njhill Wow! 30% is quite a bit (albeit serving llama-7b over two A100-80Gs probably doesn't make much sense in practice).

I will do some testing on this branch in parallel with the serving benchmark and report back as well.

@njhill
Member Author

njhill commented Mar 31, 2024

Thanks @ywang96, yes, I'm sure the difference will be smaller in relative terms for larger models. But not bad, given that performance improvement was not the purpose of this PR.

@ywang96
Member

ywang96 commented Apr 1, 2024

@njhill I did some preliminary testing on H100 TP2 with mistralai/Mistral-7B-Instruct-v0.1 and there's definitely some speedup (not as much as 30%, since this is running on ShareGPT).

Server launch command:

python -m vllm.entrypoints.openai.api_server \
        --model mistralai/Mistral-7B-Instruct-v0.1 \
        --swap-space 16 \
        --disable-log-requests \
        --tensor-parallel-size 2 \
        --no-worker-use-ray  # remove this flag to use Ray

Benchmark command:

python benchmarks/benchmark_serving.py \
        --backend vllm \
        --model mistralai/Mistral-7B-Instruct-v0.1 \
        --dataset-name sharegpt \
        --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
        --request-rate 1 \
        --num-prompts 100

With Ray workers:

---------------Time to First Token----------------
Mean TTFT (ms):                          25.42     
Median TTFT (ms):                        31.51     
P99 TTFT (ms):                           49.27     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          8.21      
Median TPOT (ms):                        8.20      
P99 TPOT (ms):                           10.59     

This PR:

---------------Time to First Token----------------
Mean TTFT (ms):                          22.31     
Median TTFT (ms):                        27.61     
P99 TTFT (ms):                           36.08     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          7.75      
Median TPOT (ms):                        7.79      
P99 TPOT (ms):                           9.59      

@njhill
Member Author

njhill commented Apr 1, 2024

Thanks @ywang96, that's great! 5-6% lower TPOT is still nice to have! I am doing some spot tests on TP=4 llama-70b too.

@njhill
Member Author

njhill commented Apr 1, 2024

For llama-2-70b with a single request (5 input / 1000 output tokens), the times I got are 32.3 s before and 30.8 s after, i.e. a 4-5% speedup.

@ywang96
Member

ywang96 commented Apr 1, 2024

For llama-2-70b with a single request (5 input / 1000 output tokens), the times I got are 32.3 s before and 30.8 s after, i.e. a 4-5% speedup.

I will test on A100-80G with Mixtral TP4 and TP8 just to see if 4-5% is likely the average speedup we get in general.

Collaborator

@rkooo567 rkooo567 left a comment

I see. If the mp-based approach is faster, it makes sense to change the default, I think. Out of curiosity, what's the delta between this and the Ray default (for a single-node case), aside from performance? I assume it supports the logging prefix, so maybe just the debugger and the Ray dashboard?

@@ -0,0 +1,242 @@
import asyncio
Collaborator

QQ: this file is probably very specific to the mp-based executor. Should we create executor/util and move it there?

Member Author

@rkooo567 yes, maybe the naming is bad, but that's what it's meant for; it's equivalent to ray_utils.py, which is in the same place.

Member

+1 Let's probably move this file to executor/? And also move ray_utils to executor/?

njhill and others added 8 commits April 15, 2024 13:16
Co-authored-by: Sahil Suneja <[email protected]>
Instead of adding equivalent new test
@mpjlu

mpjlu commented Apr 16, 2024

I tested Baichuan 13B TP=2 with/without CUDA graph on A100; the performance of the PR and the main branch is the same. Did you test 13B models?

@njhill
Member Author

njhill commented Apr 18, 2024

what's the delta between this and the Ray default (for a single-node case), aside from performance? I assume it supports the logging prefix, so maybe just the debugger and the Ray dashboard?

@rkooo567 yes I think that's all correct.

@youkaichao
Member

My suggestions:

  1. Try whether torch RPC works and whether it has any benefit. That modification should be small, and the usage is quite similar to Ray, while the underlying implementation is multiprocessing (a rough sketch of what this could look like follows below).
  2. Break the PR down into smaller pieces, so that we can digest it step by step.
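
For context, a rough sketch of the torch RPC suggestion for driver/worker communication (this is not part of the PR; the worker names and the torch.add call are placeholders):

    import os
    import torch
    import torch.distributed.rpc as rpc
    import torch.multiprocessing as mp

    def run(rank: int, world_size: int) -> None:
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
        if rank == 0:
            # The driver invokes a function on a remote worker and blocks on the result.
            result = rpc.rpc_sync("worker1", torch.add,
                                  args=(torch.ones(2), torch.ones(2)))
            print(result)
        rpc.shutdown()  # waits until all outstanding RPC work is done

    if __name__ == "__main__":
        mp.spawn(run, args=(2,), nprocs=2)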

Member

@zhuohan123 zhuohan123 left a comment

Sorry for the delay, Nick! Please find my comments below. I think in general this is a good feature that we should have. However, I don't quite understand what's going on in vllm/engine/local_worker_utils.py. Let's chat offline about this.

setup.py Outdated
Comment on lines 321 to 335
def get_ray_requirement() -> Optional[Dict[str, List[str]]]:
if _is_neuron():
return None
return {"ray": ["ray >= 2.9"]}
Member

We have different requirements.txt for different backends. Is this function necessary?

Member Author

This is to make Ray an optional extra, independent of the backend. Ray is currently supported for some backends and not others, and where it is supported the version requirement can differ, hence the conditionals here.

return prompt_token_ids


@pytest.mark.skip("Requires multiple GPUs")
Member

You can add a condition like

@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")

and put this test in the distributed tests?

Member Author

@zhuohan123 I think somehow you weren't reviewing the latest version of the PR, but the last time this branch was updated was a week ago.

These tests no longer require GPUs; they test the multiprocessing mechanics in isolation. The existing distributed test is additionally parameterized to test both Ray and non-Ray.

Comment on lines +487 to +519
if self.worker_use_ray is None:
ray_found = importlib.util.find_spec("ray") is not None
self.worker_use_ray = ray_found and self.world_size > 1

Member

Can we check whether ray is successfully imported in vllm/engine/ray_utils.py?
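
A sketch of what that check could look like inside vllm/engine/ray_utils.py (the helper name is an assumption, not the module's actual API):

    try:
        import ray
    except ImportError:
        ray = None

    def ray_is_available() -> bool:
        # True only if the optional ray dependency imported successfully.
        return ray is not None

The engine config could then consult ray_is_available() instead of probing importlib directly.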

Comment on lines +25 to +27
# Use dedicated multiprocess context for workers.
# Both spawn and fork work
mp_method = os.getenv("MULTIPROC_METHOD", "fork")
Member

Can we always use spawn here? I don't think there will be cases where fork will be better.

Member Author

Sure. I think fork might be faster; I can test to see whether it makes a non-negligible difference.
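
For reference, both start methods being weighed here can be obtained from a dedicated context. This standalone snippet reuses the MULTIPROC_METHOD environment variable from the diff above, but defaults to spawn per the suggestion; it is an illustration of the trade-off, not the PR's worker code:

    import multiprocessing as mp
    import os

    # "fork" is usually faster to start but inherits parent state (which can be
    # problematic once CUDA is initialized); "spawn" starts a clean interpreter
    # and is safer but slower.
    mp_method = os.getenv("MULTIPROC_METHOD", "spawn")
    ctx = mp.get_context(mp_method)

    if __name__ == "__main__":
        proc = ctx.Process(target=print, args=("hello from worker",))
        proc.start()
        proc.join()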

'automatically set when using more than 1 GPU')
parser.add_argument(
'--worker-use-ray',
action=argparse.BooleanOptionalAction,
Member

Why is this change needed?

Member Author

@zhuohan123 it allows a boolean arg to have a default of None, so that we can differentiate between not set and explicitly set to false.

It looks like this may have been introduced in Python 3.9 though, so I guess it may need to be changed anyhow.
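
For reference, a small standalone illustration of the tri-state flag (requires Python 3.9+; the print statements just show the resulting values):

    import argparse

    parser = argparse.ArgumentParser()
    # BooleanOptionalAction generates both --worker-use-ray and
    # --no-worker-use-ray; default=None distinguishes "not set" from
    # "explicitly disabled".
    parser.add_argument("--worker-use-ray",
                        action=argparse.BooleanOptionalAction,
                        default=None)

    print(parser.parse_args([]).worker_use_ray)                       # None -> auto-detect
    print(parser.parse_args(["--worker-use-ray"]).worker_use_ray)     # True
    print(parser.parse_args(["--no-worker-use-ray"]).worker_use_ray)  # False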


@@ -322,19 +213,17 @@ def _run_workers(
method)(*driver_args, **driver_kwargs)

# Get the results of the ray workers.
if self.workers:

Member

This refactor looks really good! Thanks for the work!

self.result_handler.close()


class LocalWorkerVllm(mp.Process):
Member

Why do we choose to inherit from mp.Process instead of just creating an mp.Process instance inside the class? I'm not very familiar, but is this standard practice?

Member Author

I actually didn't write the first iteration of this PR... I think it's fairly standard, but it's probably more common/clearer to wrap an instance instead; I'll change this.
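
A sketch of the composition alternative being discussed (class and method names are placeholders, not the PR's eventual implementation):

    import multiprocessing as mp

    class LocalWorker:
        # Wraps an mp.Process instance instead of subclassing mp.Process.
        def __init__(self, target, *args):
            self._process = mp.Process(target=target, args=args, daemon=True)

        def start(self) -> None:
            self._process.start()

        def join(self) -> None:
            self._process.join()

    def _worker_loop(name: str) -> None:
        print(f"{name} running in pid {mp.current_process().pid}")

    if __name__ == "__main__":
        worker = LocalWorker(_worker_loop, "worker-0")
        worker.start()
        worker.join()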

_add_prefix(sys.stdout, process_name, pid)
_add_prefix(sys.stderr, process_name, pid)

del self.tasks # Not used in forked process
Member

Why do we need to explicitly delete this field here? Can we just not include this field in __init__?

Member Author

__init__ runs in the main/parent process (since that's where this class gets constructed); the object state is then copied into the forked/spawned process when start() is called. The task map is only used in the main process's copy of this object, so it's deleted here to make that clearer.

This will probably need to change, though, if we switch to not subclassing multiprocessing.Process.
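
A tiny standalone illustration of the point above, independent of the PR code: attributes set in __init__ live in the parent, a copy travels to the child when start() is called, and deleting the copy in the child just documents that the child never uses it:

    import multiprocessing as mp

    class Worker(mp.Process):
        def __init__(self):
            super().__init__()
            self.tasks = {}  # used only by the parent process

        def run(self):  # executes in the child process
            del self.tasks  # drop the copied field to make its non-use explicit
            print("child still has task map:", hasattr(self, "tasks"))   # False

    if __name__ == "__main__":
        w = Worker()
        w.start()
        w.join()
        print("parent still has task map:", hasattr(w, "tasks"))         # True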

@nivibilla

nivibilla commented Apr 30, 2024

Just a note on this: I use Ray to do multi-node batch inference with vLLM (on an 8x8 A10 setup). With models that fit on a single GPU it worked perfectly, but trying to initialize tensor-parallel models with Ray from within a Ray instance doesn't work. I think this solution is the only way to do multi-node batch inference with Ray orchestrating the nodes and multiprocessing handling tensor-parallel 70B inference inside each worker node.

Thank you for this PR! Hope it gets merged soon.

@jacobthebanana
Contributor

My team at work has been looking for a way to do efficient autoregressive generation during LLM fine-tuning. We'd like to tap into the efficiency of vLLM, but so far haven't been able to run torch FSDP alongside vLLM on the same set of GPUs. The changes proposed in this pull request have resolved our issue. Thanks again, Nick, for the great work; I'd love to see this pull request merged very soon.

@njhill
Member Author

njhill commented May 8, 2024

@nivibilla @jacobthebanana I had some interruptions in the last few days, but will make sure this lands this week (not this PR itself, but the ones that replaced it).

@njhill
Member Author

njhill commented May 15, 2024

This was broken into smaller PRs, which have now all been merged; see #4539.

@njhill njhill closed this May 15, 2024
@njhill njhill deleted the ray-optional2 branch May 15, 2024 22:46