
Workaround PyTorch IPC handle issue (opendatahub-io#161)
wenkaidu authored Aug 28, 2024
1 parent 07b6b14 commit 056baed
Showing 1 changed file with 14 additions and 5 deletions.
vllm/distributed/device_communicators/custom_all_reduce.py: 14 additions & 5 deletions
@@ -264,11 +264,20 @@ def capture(self):
         self.register_graph_buffers()
 
     def _get_ipc_meta(self, inp: torch.Tensor):
-        data = inp.untyped_storage()._share_cuda_()
-        shard_data = (
-            data[1],  # ipc handle to base ptr
-            data[3],  # offset of base ptr
-        )
+        if is_hip():
+            # _share_cuda_() doesn't accept meta buffer not allocated from
+            # PyTorch cache allocator, use direct HIP call to get IPC handle
+            handle = custom_ar.get_meta_buffer_ipc_handle(inp)
+            shard_data = (
+                bytes(handle),  # ipc handle to base ptr
+                0,  # offset of base ptr
+            )
+        else:
+            data = inp.untyped_storage()._share_cuda_()
+            shard_data = (
+                data[1],  # ipc handle to base ptr
+                data[3],  # offset of base ptr
+            )
         return self._gather_ipc_meta(shard_data)
 
     def _gather_ipc_meta(self, shard_data):
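For context, below is a minimal, self-contained sketch (not the vLLM implementation) of the two IPC-metadata paths this change distinguishes. The helper name get_ipc_shard_data, the custom_ar parameter, and the torch.version.hip check are illustrative assumptions; only the _share_cuda_() tuple layout and the get_meta_buffer_ipc_handle call mirror the diff above.

# Illustrative sketch only; get_ipc_shard_data is a hypothetical helper, not a vLLM API.
import torch


def get_ipc_shard_data(inp: torch.Tensor, custom_ar=None):
    """Return (ipc_handle, offset) locating a tensor within its allocation."""
    if torch.version.hip is not None and custom_ar is not None:
        # ROCm path from the diff: the meta buffer is not allocated by
        # PyTorch's caching allocator, so _share_cuda_() cannot export it.
        # A direct HIP call (exposed here by the custom all-reduce extension)
        # returns an IPC handle to the buffer's base pointer; the tensor
        # starts at offset 0 within that buffer.
        handle = custom_ar.get_meta_buffer_ipc_handle(inp)
        return bytes(handle), 0
    # CUDA path: PyTorch's storage sharing returns a tuple in which
    # element 1 is the IPC handle of the allocation's base pointer and
    # element 3 is the tensor's byte offset within that allocation.
    data = inp.untyped_storage()._share_cuda_()
    return data[1], data[3]

Each rank would then exchange its (handle, offset) pair with its peers (as _gather_ipc_meta does in the real code) so every process can open the others' buffers for the custom all-reduce.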
