diff --git a/rfcs/20231017-collective-broadcast.md b/rfcs/20231017-collective-broadcast.md index 559d6426496..d9682936e2f 100644 --- a/rfcs/20231017-collective-broadcast.md +++ b/rfcs/20231017-collective-broadcast.md @@ -2,7 +2,7 @@ Status: Review
Initial version: 10/17/20223
-Last updated: 10/17/2023
+Last updated: 11/1/2023
Discussion thread: [GitHub](https://github.com/openxla/stablehlo/pull/1809) ## Motivation @@ -30,49 +30,44 @@ Unfortunately, the op name `broadcast` is already taken by [an op in XLA proper] Within each process group in the StableHLO process grid, send the value of the `operand` tensor from the source process to the target processes and produce a -`result` tensor. The first process in each of the `replica_groups` is treated as the -source process. For any process `n` not in any replica group, that process will just return -zeros. +`result` tensor. For any process `n` not in any replica group, that process will just +return zeros. The operation splits the StableHLO process grid into `process_groups` which is defined as follows: -* `cross_replica(replica_groups)` - if `channel_id <= 0 and use_global_device_ids = false`. -* `cross_replica_and_partition(replica_groups)` - if `channel_id > 0 and use_global_device_ids = false`. -* `flattened_ids(replica_groups)` - if `channel_id > 0 and use_global_device_ids = true`. +* `cross_replica(replica_groups)` if `channel_id <= 0`. +* `cross_partition(replica_groups)` if `channel_id > 0`. -Afterwards, within each `process_group`: +Afterwards, `result@process` is given by: + +* `operand@process_groups[i, 0]` if there exists an `i` such that + the process is in `process_groups[i]`. +* `broadcast_in_dim(constant(0, element_type(result)), [], type(result))` + if process is never found in `process_groups`. -* `result@process[result_index] = operand@process[0]` #### Inputs | Label | Name | Type | Constraints | |-------|-------------------------|------------------------------------------------------------------|-------------| -| (I1) | `operand` | tensor | (C5) | -| (I2) | `replica_groups` | variadic number of 1-dimensional tensor constants of type `si64` | (C1-C3) | -| (I3) | `channel_id` | constant of type `si64` | (C4) | -| (I4) | `use_global_device_ids` | constant of type `i1` | (C4) | +| (I1) | `operand` | tensor | (C3) | +| (I2) | `replica_groups` | variadic number of 1-dimensional tensor constants of type `si64` | (C1-C2) | +| (I3) | `channel_id` | constant of type `si64` | | #### Outputs | Name | Type | Constraints | |----------|--------|-------------| -| `result` | tensor | (C5) | +| `result` | tensor | (C3) | #### Constraints -* (C1) `is_unique(replica_groups)`. -* (C2) `size(replica_groups)` is defined as: +* (C1) is_unique(replica_groups). +* (C2) `0 <= replica_groups < N` where `N` is defined as: * `num_replicas` if `cross_replica` is used. - * `num_replicas` if `cross_replica_and_partition` is used. - * `num_processes` if `flattened_ids` is used. -* (C3) `0 <= replica_groups < size(replica_groups)`. -* (C4) If `use_global_device_ids = true`, then `channel_id > 0`. -* (C5) `type(result) = type(operand)`. + * `num_partitions` if `cross_partition` is used. +* (C3) `type(result) = type(operand)`. #### Examples @@ -85,9 +80,8 @@ Afterwards, within each `process_group`: // %operand@(0, 3): [[7, 8]] %result = "stablehlo.collective_broadcast"(%operand) { replica_groups = dense<[[2, 0], [1, 3]]> : tensor<2x2xi64>, - // channel_id = 0 - channel_handle = #stablehlo.channel_handle - // use_global_device_ids = false + // channel_id = 1 + channel_handle = #stablehlo.channel_handle } : (tensor1x2xi64>) -> tensor<1x2xi64> // %result@(0, 0): [[5, 6]] // %result@(0, 1): [[3, 4]] @@ -105,9 +99,8 @@ Example with missing processes // %operand@(0, 3): [[7, 8]] %result = "stablehlo.collective_broadcast"(%operand) { replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>, - // channel_id = 0 - channel_handle = #stablehlo.channel_handle - // use_global_device_ids = false + // channel_id = 1 + channel_handle = #stablehlo.channel_handle } : (tensor1x2xi64>) -> tensor<1x2xi64> // %result@(0, 0): [[0, 0]] // %result@(0, 1): [[5, 6]]