diff --git a/rfcs/20231017-collective-broadcast.md b/rfcs/20231017-collective-broadcast.md
index 559d6426496..d9682936e2f 100644
--- a/rfcs/20231017-collective-broadcast.md
+++ b/rfcs/20231017-collective-broadcast.md
@@ -2,7 +2,7 @@
Status: Review
Initial version: 10/17/20223
-Last updated: 10/17/2023
+Last updated: 11/1/2023
Discussion thread: [GitHub](https://github.com/openxla/stablehlo/pull/1809)
## Motivation
@@ -30,49 +30,44 @@ Unfortunately, the op name `broadcast` is already taken by [an op in XLA proper]
Within each process group in the StableHLO process grid, send the value of the
`operand` tensor from the source process to the target processes and produce a
-`result` tensor. The first process in each of the `replica_groups` is treated as the
-source process. For any process `n` not in any replica group, that process will just return
-zeros.
+`result` tensor. For any process `n` not in any replica group, that process will just
+return zeros.
The operation splits the StableHLO process grid into `process_groups` which is
defined as follows:
-* `cross_replica(replica_groups)`
- if `channel_id <= 0 and use_global_device_ids = false`.
-* `cross_replica_and_partition(replica_groups)`
- if `channel_id > 0 and use_global_device_ids = false`.
-* `flattened_ids(replica_groups)`
- if `channel_id > 0 and use_global_device_ids = true`.
+* `cross_replica(replica_groups)` if `channel_id <= 0`.
+* `cross_partition(replica_groups)` if `channel_id > 0`.
-Afterwards, within each `process_group`:
+Afterwards, `result@process` is given by:
+
+* `operand@process_groups[i, 0]` if there exists an `i` such that
+ the process is in `process_groups[i]`.
+* `broadcast_in_dim(constant(0, element_type(result)), [], type(result))`
+ if process is never found in `process_groups`.
-* `result@process[result_index] = operand@process[0]`
#### Inputs
| Label | Name | Type | Constraints |
|-------|-------------------------|------------------------------------------------------------------|-------------|
-| (I1) | `operand` | tensor | (C5) |
-| (I2) | `replica_groups` | variadic number of 1-dimensional tensor constants of type `si64` | (C1-C3) |
-| (I3) | `channel_id` | constant of type `si64` | (C4) |
-| (I4) | `use_global_device_ids` | constant of type `i1` | (C4) |
+| (I1) | `operand` | tensor | (C3) |
+| (I2) | `replica_groups` | variadic number of 1-dimensional tensor constants of type `si64` | (C1-C2) |
+| (I3) | `channel_id` | constant of type `si64` | |
#### Outputs
| Name | Type | Constraints |
|----------|--------|-------------|
-| `result` | tensor | (C5) |
+| `result` | tensor | (C3) |
#### Constraints
-* (C1) `is_unique(replica_groups)`.
-* (C2) `size(replica_groups)` is defined as:
+* (C1) is_unique(replica_groups).
+* (C2) `0 <= replica_groups < N` where `N` is defined as:
* `num_replicas` if `cross_replica` is used.
- * `num_replicas` if `cross_replica_and_partition` is used.
- * `num_processes` if `flattened_ids` is used.
-* (C3) `0 <= replica_groups < size(replica_groups)`.
-* (C4) If `use_global_device_ids = true`, then `channel_id > 0`.
-* (C5) `type(result) = type(operand)`.
+ * `num_partitions` if `cross_partition` is used.
+* (C3) `type(result) = type(operand)`.
#### Examples
@@ -85,9 +80,8 @@ Afterwards, within each `process_group`:
// %operand@(0, 3): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[2, 0], [1, 3]]> : tensor<2x2xi64>,
- // channel_id = 0
- channel_handle = #stablehlo.channel_handle
- // use_global_device_ids = false
+ // channel_id = 1
+ channel_handle = #stablehlo.channel_handle
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[5, 6]]
// %result@(0, 1): [[3, 4]]
@@ -105,9 +99,8 @@ Example with missing processes
// %operand@(0, 3): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
- // channel_id = 0
- channel_handle = #stablehlo.channel_handle
- // use_global_device_ids = false
+ // channel_id = 1
+ channel_handle = #stablehlo.channel_handle
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(0, 1): [[5, 6]]