Skip to content

Commit

Permalink
[core][distributed] fix custom allreduce in pytorch 2.5 (vllm-project…
Browse files Browse the repository at this point in the history
…#9815)

Signed-off-by: youkaichao <[email protected]>
Signed-off-by: NickLucche <[email protected]>
  • Loading branch information
youkaichao authored and NickLucche committed Oct 31, 2024
1 parent e4817da commit 543400f
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion vllm/distributed/device_communicators/custom_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,20 @@ def capture(self):

def _get_ipc_meta(self, inp: torch.Tensor):
data = inp.untyped_storage()._share_cuda_()
handle = data[1]
# https://github.com/pytorch/pytorch/pull/130890 changes
# the binary format of the ipc handle
# it starts from pytorch 2.5
if len(handle) > 64:
assert len(handle) == 66
# only support SHAREABLE_HANDLE_VERSION = 1
assert int(handle[0]) == 1
# only support SHAREABLE_CUDA_MALLOC = 'c'
assert handle[1] == ord("c")
handle = handle[2:]
# TODO: support expandable segment
shard_data = (
data[1], # ipc handle to base ptr
handle, # ipc handle to base ptr
data[3], # offset of base ptr
)
return self._gather_ipc_meta(shard_data)
Expand Down

0 comments on commit 543400f

Please sign in to comment.