[Bug]: Phi-3-mini does not work when using Ray #6607
cc @rkooo567 @richardliaw for the ray error.
I feel like I have seen this before, and it may have been fixed in the latest version. Have you tried 0.5.2?
Oh dang, I didn't even realize it came out. Let me upgrade and report back. Edit:
Yes. I didn't explicitly set it in the original notebook because, as I understand it,
I assume it is the same issue as #4286. @baughmann is there a repro I can try?
Maybe we need a more fundamental solution for this case.
@rkooo567 Thanks for looking into this. I put the content of the Jupyter notebook that I used to produce the error in the OP, but I'll also put it here for your convenience :) The model used is just the official one.

```python
from vllm import AsyncEngineArgs, AsyncLLMEngine

# model source: https://huggingface.co/microsoft/Phi-3-mini-128k-instruct

# this config works
mp_args = AsyncEngineArgs(
    model="../../models/microsoft/Phi-3-mini-128k-instruct",
    trust_remote_code=True,
    distributed_executor_backend="mp",
    max_model_len=8000,  # limit mem utilization for this example
    disable_sliding_window=True,  # needed in order to use flash-attn
)

# this config does not work. It just sits at
# "INFO worker.py:1779 -- Started a local Ray instance. View the dashboard at..."
# The actor dies with a `ray.exceptions.RaySystemError: System error: No module named 'transformers_modules'`
ray_args = AsyncEngineArgs(
    model="../../models/microsoft/Phi-3-mini-128k-instruct",
    trust_remote_code=True,
    max_model_len=8000,
    engine_use_ray=True,
    distributed_executor_backend="ray",
)

# engine = AsyncLLMEngine.from_engine_args(mp_args)
engine = AsyncLLMEngine.from_engine_args(ray_args)
```
Hello. I've been investigating the same error, but in the context of multi-node inference with Ray. I created #6751 which fixes the issue for me. Perhaps my fix will help in this scenario as well. I attempted to reproduce the error raised here using the code examples in this issue, but was unable to (using the latest vLLM code).
@tjohnson31415 In my case I'm running only on a single node. And oh, wow, the notebook didn't give you problems using the ray args? How strange |
Yeah, that makes it interesting. In my understanding, #4286 should be the most recent fix for the single node case, but that fix looks like it was included in release v0.4.1... Some other thoughts/questions:
Let me also take a look at it quickly. I am a little busy with another high-priority task on our end.
@tjohnson31415 What I'll do is make a minimal conda env with minimal requirements and then perform the troubleshooting on that. I will upload and post that repo when I'm able, but it may not be until later today.
@tjohnson31415 @rkooo567 Here is a repo with conda for you all. Also created a basic readme for your convenience. @tjohnson31415 Here's the update regarding your suggestions:

1. Yes, it does. As I understand it, with
2. Yes, I'm afraid so. That's not surprising though, as I run Jupyter notebooks in the exact same development environment I'm writing my application in. For me, it's just a quick and easy way to experiment with specific parts of my application.
3. Yes, it does. I added a note about this in the readme of the repo I posted. I see a
@baughmann Ah, thanks for creating the repro-repo! I didn't realize that
To make the workers executing the model use Ray, `distributed_executor_backend="ray"` is sufficient; `engine_use_ray=True` additionally runs the engine's background loop in its own Ray actor. The error occurs because the Ray worker spawned for the engine loop with `engine_use_ray=True` does not initialize the cached Hugging Face modules that `trust_remote_code=True` depends on, so importing `transformers_modules` fails inside that actor. The current way that the (non-engine) Ray workers handle this is that they initialize the Hugging Face dynamic-module cache when they start up. But the quickest fix is to remove `engine_use_ray=True` from your engine args.
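A minimal sketch of that quick fix, assuming the same local model path as the notebook above (only `engine_use_ray` is dropped; Ray still runs the model workers):

```python
from vllm import AsyncEngineArgs, AsyncLLMEngine

# Same as the failing config, minus engine_use_ray=True: the engine loop
# now runs in the driver process, where the code downloaded for
# trust_remote_code (the transformers_modules package) is importable.
ray_args = AsyncEngineArgs(
    model="../../models/microsoft/Phi-3-mini-128k-instruct",
    trust_remote_code=True,
    max_model_len=8000,
    distributed_executor_backend="ray",
)

engine = AsyncLLMEngine.from_engine_args(ray_args)
```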
@tjohnson31415 That most certainly did it! Thank you for the detailed explanation, that makes a lot of sense! However, I would still expect feature parity among the supported models. Should we leave this ticket open, even though there is that workaround? |
I am having this issue as well, and the workaround works. I am also curious about when this will be implemented in the engine. If there is an open branch or fork, can someone link it here? EDIT: NVM, I found it! Thank you all!
Just to note it here, there is a new RFC to remove `engine_use_ray`. If the RFC is accepted, a fix for this issue may not be relevant for very long.
It would appear the mp executor is now also affected in 0.6.1.post2.
A workaround for me was to remove '.' from the model name (path) before instantiating the engine, e.g. "weights/Phi3.5mini-instruct" -> "weights/Phi35mini-instruct" (on top of disabling Ray).
@nightflight-dk It makes sense. The name might be used for import, and a `.` is not valid inside a Python module name.
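A minimal sketch of that rename workaround, with illustrative paths (the `weights/...` directories are assumptions, not from the original repro): the folder name becomes part of the dynamic `transformers_modules` import path when `trust_remote_code=True`, so it must be a valid module name.

```python
import shutil
from pathlib import Path

# Illustrative paths; adjust to your own layout.
src = Path("weights/Phi3.5mini-instruct")
dst = Path("weights/Phi35mini-instruct")  # dot-free name, safe to import

# Copy (or rename) the weights so the directory name contains no '.'
if not dst.exists():
    shutil.copytree(src, dst)

# Then point the engine at the dot-free path, e.g.:
# AsyncEngineArgs(model=str(dst), trust_remote_code=True, ...)
```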
Your current environment
🐛 Describe the bug
Feel free to use this gist with a minimal Jupyter notebook.
When attempting to load any Phi-3 mini/small model using the `AsyncLLMEngine` and specifying `ray` as the distributed backend, Ray throws a `ray.exceptions.RaySystemError: System error: No module named 'transformers_modules'`. A `pip list` in my main project shows my installed package versions, although it sounds like this is likely not a bug with my project.
I highly encourage you to look at the Jupyter notebook, but for completeness, here's how I'm trying to load the model:
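(This is the same configuration reproduced in the comment above.)

```python
from vllm import AsyncEngineArgs, AsyncLLMEngine

# The failing configuration: Ray as the distributed backend, with the
# engine loop also running in a Ray actor via engine_use_ray=True.
ray_args = AsyncEngineArgs(
    model="../../models/microsoft/Phi-3-mini-128k-instruct",
    trust_remote_code=True,
    max_model_len=8000,
    engine_use_ray=True,
    distributed_executor_backend="ray",
)

engine = AsyncLLMEngine.from_engine_args(ray_args)
```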
Additionally, here's the full System log from the dead actor:
Also, thank you guys for such a great library. It's very easy and fun to use, and bugs like this are few and far between 😄
Edit: I've also tried this with 0.5.2 and the 0.5.3 prerelease per @rkooo567's question.