[Core] Eliminate parallel worker per-step task scheduling overhead #4894
Conversation
This is failing the distributed spec decoding test. TP spec decoding was added after the original PR, and I need to look more closely at its flow and how best to integrate these changes with it.
We found that NCCL broadcasting overhead is now very high for high TP, and this PR is very important for reducing the gap. I will take a look at the PR tomorrow!
Thanks @rkooo567! This PR doesn't actually reduce any NCCL broadcast overhead; it just eliminates most of the non-torch.distributed RPCs, which turn out to be much more significant (see the measurements in the original PR). We can/should additionally reduce the amount of NCCL broadcasting done; #4844 is a first step, and I'm working on the next steps now. I'll try to address the above-mentioned spec decoding TP issue with this PR today.
Hmm, I wonder if there's a cleaner way to stop the execution loop (it seems like too many implementation details are leaked). Unfortunately, I couldn't come up with a very good idea, so for now it's probably best to document things more verbosely.
```python
        return await self._driver_execute_model_async(execute_model_req)

    async def stop_remote_worker_execution_loop_async(self) -> None:
        if self.parallel_worker_tasks is None:
```
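For context on the snippet above, here's a minimal sketch of that driver-side stop path with the more verbose documentation suggested (illustrative only: `_driver_execute_model_async` stands in for the real driver call, and the exact signalling in vLLM may differ):

```python
import asyncio
from typing import Any, List, Optional


class _ExecutorSketch:
    """Illustrative only; not the actual vLLM executor."""

    # None means the non-driver workers are not currently in a busy loop.
    parallel_worker_tasks: Optional[List[asyncio.Task]] = None

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Ask the non-driver workers to exit their execution loop.

        The workers block on a broadcast from the driver; broadcasting
        "no work" (None) is the signal that ends the loop.
        """
        if self.parallel_worker_tasks is None:
            # The loop was never started (no request has been executed
            # yet), so there is nothing to stop.
            return
        await self._driver_execute_model_async(None)  # broadcast "stop"
        tasks, self.parallel_worker_tasks = self.parallel_worker_tasks, None
        await asyncio.gather(*tasks)  # wait for every worker to exit

    async def _driver_execute_model_async(self, req: Optional[Any]) -> Any:
        ...  # placeholder: runs the driver worker / broadcasts inputs
```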
Maybe assert instead? If this is None, doesn't that mean the state is kind of screwed up?
Btw, I also talked with @cadedaniel. I think this optimization makes spec decoding pretty complicated (and I feel like PP as well). @njhill, I'm curious whether there's a clean way to disable this in the short term when spec decoding is enabled?
+1 for PP @rkooo567
Thanks for the great review comments @rkooo567!
@rkooo567 @andoorve this would definitely be an option, but it would also come at the cost of at least some additional complexity. I haven't dug in enough yet, but I'm hopeful about adapting it to work with spec decoding too. That could be misplaced optimism, and I'm less sure about PP. I'll look at those first, and if the complexity is too great then restricting this to the non-spec-decoding / non-PP case could be the backup plan. I was actually hoping this could be a stepping stone to eliminating the "control plane" RPCs altogether: the workers could remain in a permanent torch.distributed broadcast loop, with keepalives when idle if timeouts are an issue.
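As a rough illustration of that "permanent broadcast loop" idea (a sketch under assumed sentinels, not the code in this PR), the worker side could look roughly like this:

```python
import torch.distributed as dist

# Hypothetical sentinels for illustration; the real protocol differs.
STOP = "stop"
KEEPALIVE = "keepalive"


def worker_busy_loop(worker, group=None) -> None:
    """Sketch of a worker-side loop: block on a broadcast from the driver
    (rank 0) instead of waiting for a per-step control-plane RPC."""
    while True:
        payload = [None]
        # All ranks participate in the broadcast; rank 0 supplies the object.
        dist.broadcast_object_list(payload, src=0, group=group)
        msg = payload[0]
        if msg == STOP:
            break  # the driver is tearing the loop down
        if msg == KEEPALIVE:
            continue  # idle heartbeat so the collective doesn't time out
        worker.execute_model(msg)  # real per-step work
```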
@njhill This is a great point. I also committed a change to PyTorch in anticipation of these kinds of optimizations for PP in the future (replace control-plane RPC with torch.distributed): pytorch/pytorch@b96b1e8. It's a real pain to get multiple sends/recvs on a single rank though, so this is definitely more of a future item once we get the basic functionality of PP working.
@njhill btw, this was the major bottleneck for us, so let me know if there's any way I can help accelerate the PR merge!!
@rkooo567 about to update now, will hopefully push within the next couple of hours!
@njhill let me know when it is ready to review again!
Thanks @rkooo567, I made some updates but had to deal with other things for the remainder of the day. If I don't get a chance to finish debugging tonight, I'll do it first thing in the morning (Pacific time).
Approving spec decode changes
LGTM if tests pass!
```python
        raise NotImplementedError

    @abstractmethod
    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
```
QQ: this can hang forever if the loop doesn't finish properly. Should we allow a timeout here and kill the workers with an exception if it is reached? (Theoretically, if it takes > 30s, I think something is wrong.)
@rkooo567 I will check, but I'm not sure that this is necessary. At least in the async case there's already an overall per-step timeout that I think would cover this.
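For reference, a timeout around that wait could look roughly like this in the async path (purely a sketch; the 30s value and the helper name are assumptions, not what the PR implements):

```python
import asyncio
from typing import Awaitable

# Assumed guard value, per the "> 30s means something is wrong" comment above.
STOP_TIMEOUT_S = 30.0


async def wait_for_tasks_completion_with_timeout(
        tasks_future: Awaitable[None]) -> None:
    """Fail loudly instead of hanging forever if the workers' loop never
    drains (e.g. a worker died mid-step)."""
    try:
        await asyncio.wait_for(tasks_future, timeout=STOP_TIMEOUT_S)
    except asyncio.TimeoutError:
        # Surface the hang so the caller can tear down / restart the workers.
        raise RuntimeError(
            f"Parallel worker tasks did not finish within {STOP_TIMEOUT_S}s")
```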
Hey @njhill, did you get a chance to look at this? I feel like this would cause quite a few changes to PP in its current form.
@rkooo567 OK, the spec decoding thing was a tiny fix, and I have added some more detail to the code comments per your latest review. Thanks @cadedaniel for reviewing too; I was going to ping you to ask for review of that part once the tests were passing :) @andoorve I can help look at how we can adapt it to work with PP, but it sounds like there is some urgency to get this PR merged first (for the upcoming release?).
Hi @njhill, got it. I don't want to delay this PR, but I wanted to see if there was any way to loosen the assumption that only the rank 0 worker is sending over RPC and receiving results (some mechanism or fallback to per-step task scheduling). The main idea of having the other (non-driver) workers in a task loop is thankfully compatible with PP, but the assumption that there's only one parallel group could cause quite a bit of friction. Happy to chat offline as well.
@njhill thank you so much for the quick fix!
@andoorve I will look more closely at the PP PR tomorrow (and we could chat too if you're free then). Based on what you said and a high-level understanding/assumption about how that works, I'm fairly confident we can solve it without too much effort.
This PR replaces #3763.
Common logic is handled in the DistributedGPUExecutor superclass and used by both the Ray and Multiprocessing executors.
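To make the pattern concrete, here is a simplified sketch of the shared driver-side flow (illustrative only; method names such as `_run_workers_async` are stand-ins, and the real class has more responsibilities):

```python
from abc import ABC, abstractmethod
from typing import Any, Optional


class DistributedGPUExecutorSketch(ABC):
    """Sketch: the first call puts the non-driver workers into a long-running
    execution loop; after that, each step only invokes the driver worker, and
    the remaining workers pick up their inputs via torch.distributed broadcast
    rather than a per-step RPC."""

    def __init__(self) -> None:
        self.parallel_worker_tasks: Optional[Any] = None

    def execute_model(self, execute_model_req: Any) -> Any:
        if self.parallel_worker_tasks is None:
            # One-time setup: start the busy loop on every non-driver worker.
            self.parallel_worker_tasks = self._run_workers_async(
                "start_worker_execution_loop")
        # Per-step path: no task-scheduling fan-out to the other workers.
        return self._driver_execute_model(execute_model_req)

    # Backend-specific hooks provided by the Ray and multiprocessing
    # executor subclasses.
    @abstractmethod
    def _driver_execute_model(self, execute_model_req: Any) -> Any:
        ...

    @abstractmethod
    def _run_workers_async(self, method: str) -> Any:
        ...
```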