Skip to content

Commit

Permalink
Updates based on new semantics
Browse files Browse the repository at this point in the history
  • Loading branch information
chaserileyroberts committed Nov 1, 2023
1 parent 87521c0 commit 059d932
Showing 1 changed file with 23 additions and 30 deletions.
53 changes: 23 additions & 30 deletions rfcs/20231017-collective-broadcast.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

Status: Review<br/>
Initial version: 10/17/20223<br/>
Last updated: 10/17/2023<br/>
Last updated: 11/1/2023<br/>
Discussion thread: [GitHub](https://github.com/openxla/stablehlo/pull/1809)

## Motivation
Expand Down Expand Up @@ -30,49 +30,44 @@ Unfortunately, the op name `broadcast` is already taken by [an op in XLA proper]

Within each process group in the StableHLO process grid, send the value of the
`operand` tensor from the source process to the target processes and produce a
`result` tensor. The first process in each of the `replica_groups` is treated as the
source process. For any process `n` not in any replica group, that process will just return
zeros.
`result` tensor. For any process `n` not in any replica group, that process will just
return zeros.

The operation splits the StableHLO process grid into `process_groups` which is
defined as follows:

* `cross_replica(replica_groups)`
if `channel_id <= 0 and use_global_device_ids = false`.
* `cross_replica_and_partition(replica_groups)`
if `channel_id > 0 and use_global_device_ids = false`.
* `flattened_ids(replica_groups)`
if `channel_id > 0 and use_global_device_ids = true`.
* `cross_replica(replica_groups)` if `channel_id <= 0`.
* `cross_partition(replica_groups)` if `channel_id > 0`.

Afterwards, within each `process_group`:
Afterwards, `result@process` is given by:

* `operand@process_groups[i, 0]` if there exists an `i` such that
the process is in `process_groups[i]`.
* `broadcast_in_dim(constant(0, element_type(result)), [], type(result))`
if process is never found in `process_groups`.

* `result@process[result_index] = operand@process[0]`

#### Inputs

| Label | Name | Type | Constraints |
|-------|-------------------------|------------------------------------------------------------------|-------------|
| (I1) | `operand` | tensor | (C5) |
| (I2) | `replica_groups` | variadic number of 1-dimensional tensor constants of type `si64` | (C1-C3) |
| (I3) | `channel_id` | constant of type `si64` | (C4) |
| (I4) | `use_global_device_ids` | constant of type `i1` | (C4) |
| (I1) | `operand` | tensor | (C3) |
| (I2) | `replica_groups` | variadic number of 1-dimensional tensor constants of type `si64` | (C1-C2) |
| (I3) | `channel_id` | constant of type `si64` | |

#### Outputs

| Name | Type | Constraints |
|----------|--------|-------------|
| `result` | tensor | (C5) |
| `result` | tensor | (C3) |

#### Constraints

* (C1) `is_unique(replica_groups)`.
* (C2) `size(replica_groups)` is defined as:
* (C1) is_unique(replica_groups).
* (C2) `0 <= replica_groups < N` where `N` is defined as:
* `num_replicas` if `cross_replica` is used.
* `num_replicas` if `cross_replica_and_partition` is used.
* `num_processes` if `flattened_ids` is used.
* (C3) `0 <= replica_groups < size(replica_groups)`.
* (C4) If `use_global_device_ids = true`, then `channel_id > 0`.
* (C5) `type(result) = type(operand)`.
* `num_partitions` if `cross_partition` is used.
* (C3) `type(result) = type(operand)`.

#### Examples

Expand All @@ -85,9 +80,8 @@ Afterwards, within each `process_group`:
// %operand@(0, 3): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[2, 0], [1, 3]]> : tensor<2x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
// channel_id = 1
channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[5, 6]]
// %result@(0, 1): [[3, 4]]
Expand All @@ -105,9 +99,8 @@ Example with missing processes
// %operand@(0, 3): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
// channel_id = 1
channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(0, 1): [[5, 6]]
Expand Down

0 comments on commit 059d932

Please sign in to comment.