Add some docstrings for remote DMAs and semaphore barriers.
PiperOrigin-RevId: 627037991
jax authors committed Apr 22, 2024
1 parent b79f3b7 commit 667a0c1
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions jax/_src/pallas/mosaic/primitives.py
@@ -543,6 +543,24 @@ def async_copy(src_ref, dst_ref, sem):

def make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id,
                           device_id_type: DeviceIdType = DeviceIdType.MESH):
  """Creates a description of a remote copy operation.

  Copies data from src_ref on the current device to dst_ref on the device
  specified by device_id. Both semaphores should be waited on using the
  descriptor on both source and target devices.

  Note that device_id can also refer to the current device.

  Args:
    src_ref: The source Reference.
    dst_ref: The destination Reference.
    send_sem: The semaphore on the source device.
    recv_sem: The semaphore on the destination device.
    device_id: The device id of the destination device.
    device_id_type: The type of the device id.
  Returns:
    An AsyncCopyDescriptor.
  """
  src_ref, src_indexers = _get_ref_and_indexers(src_ref)
  send_sem, send_sem_indexers = _get_ref_and_indexers(send_sem)
  dst_ref, dst_indexers = _get_ref_and_indexers(dst_ref)
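
The docstring above says that both semaphores should be waited on, on both the source and the target device. The following pure-Python toy may help picture that bookkeeping. It is only a sketch and does not use the Pallas API: `ToySemaphore`, `ToyDevice`, `toy_remote_copy_start`, `toy_remote_copy_wait` and `ring_permute` are hypothetical names invented for illustration, the copy is performed synchronously, and a wait that would block on real hardware raises an error here instead.

```python
class ToySemaphore:
    """Toy counting semaphore (hypothetical; not a Pallas semaphore)."""

    def __init__(self):
        self.value = 0

    def signal(self, amount=1):
        self.value += amount

    def wait(self, amount=1):
        # A real wait would block until enough signals arrive; the toy raises.
        if self.value < amount:
            raise RuntimeError("semaphore would block")
        self.value -= amount


class ToyDevice:
    """Toy device holding a source buffer, a destination buffer and two semaphores."""

    def __init__(self, data):
        self.src = list(data)
        self.dst = [None] * len(data)
        self.send_sem = ToySemaphore()
        self.recv_sem = ToySemaphore()


def toy_remote_copy_start(devices, my_id, device_id):
    """Copies src on device `my_id` into dst on device `device_id`."""
    source, target = devices[my_id], devices[device_id]
    target.dst[:] = source.src
    source.send_sem.signal()  # the sender learns its data has left
    target.recv_sem.signal()  # the receiver learns its data has arrived


def toy_remote_copy_wait(devices, my_id):
    """Waits on both local semaphores, as every participating device should."""
    devices[my_id].send_sem.wait()
    devices[my_id].recv_sem.wait()


def ring_permute(payloads):
    """Every device sends its payload to its right neighbour in a ring."""
    devices = [ToyDevice(p) for p in payloads]
    n = len(devices)
    for i in range(n):
        toy_remote_copy_start(devices, i, (i + 1) % n)
    for i in range(n):
        toy_remote_copy_wait(devices, i)
    return devices


devices = ring_permute([[1, 2], [3, 4], [5, 6]])
print([d.dst for d in devices])
print([(d.send_sem.value, d.recv_sem.value) for d in devices])
```

In this toy every semaphore ends at 0 because each device is both a sender and a receiver and waits on both of its semaphores. With a single device the ring degenerates into a copy to the current device, which the docstring notes is also allowed.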
@@ -576,4 +594,24 @@ def _get_barrier_semaphore_abstract_eval():
)

def get_barrier_semaphore():
  """Returns a barrier semaphore.

  This function returns a barrier semaphore based on the collective_id of the
  current pallas kernel.

  It's very important that the semaphore is wait-ed back down to 0, or else the
  semaphores will become corrupted.

  It's also very important that the collective_id is different for each pallas
  kernel with communication. E.g. if you have two pallas kernels, one that syncs
  across the X axis of the device mesh and the second that syncs across the Y
  axis, they must have different collective_ids.
  However it is legal for two kernels that perform the same synchronization
  pattern (e.g. only communicating with neighbours on the same mesh axis)
  to share a collective_id. However, if in doubt, prefer not sharing
  collective_ids, as doing so incorrectly can lead to silent data corruption or
  crashes.
  Note that re-using the same collective_id doesn't guarantee that the same
  semaphore is provided by XLA.
  """
  return get_barrier_semaphore_p.bind()
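
The requirement that the barrier semaphore be waited back down to 0 can be illustrated with a small pure-Python toy. This is a sketch under stated assumptions, not the Pallas API: `ToyBarrierSemaphore`, `try_wait` and `run_barrier` are hypothetical names, the barrier pattern (each device signals both ring neighbours, then waits for two signals) is just one plausible pattern chosen for the example, and a non-blocking `try_wait` stands in for a wait that would block on real hardware.

```python
class ToyBarrierSemaphore:
    """Toy counting semaphore (hypothetical; not the semaphore XLA provides)."""

    def __init__(self):
        self.value = 0

    def signal(self, amount=1):
        self.value += amount

    def try_wait(self, amount=1):
        # Returns True and decrements if enough signals have arrived.
        # A real wait would block instead of returning False.
        if self.value < amount:
            return False
        self.value -= amount
        return True


def run_barrier(sems, skip_wait_on=None):
    """Runs one ring barrier over `sems`, one semaphore per device.

    Each device signals its left and right neighbours, then waits for the
    two signals it expects. `skip_wait_on` names a device that (incorrectly)
    forgets to wait. Returns the final semaphore values.
    """
    n = len(sems)
    for i in range(n):
        sems[(i - 1) % n].signal()
        sems[(i + 1) % n].signal()
    for i in range(n):
        if i == skip_wait_on:
            continue  # bug being illustrated: this device never waits
        if not sems[i].try_wait(2):
            raise RuntimeError(f"device {i} would block")
    return [s.value for s in sems]


clean = [ToyBarrierSemaphore() for _ in range(4)]
leaky = [ToyBarrierSemaphore() for _ in range(4)]
print(run_barrier(clean))                  # every device waited
print(run_barrier(leaky, skip_wait_on=0))  # device 0 left signals behind
```

In the toy, the device that skipped its wait keeps a non-zero count. If the same semaphore were handed to a later kernel, a wait on that device could succeed before any neighbour had signalled, which is one way to read the docstring's warning about corrupted semaphores.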
