
Consider moving "Not in HLO" ops to CHLO #3

Open
burmako opened this issue Aug 17, 2022 · 3 comments

@burmako
Contributor

burmako commented Aug 17, 2022

In https://discourse.llvm.org/t/rfc-proposal-for-a-high-level-ml-dialect-in-mlir/64249/46, we did a survey of MHLO ops (almost all of which applies to StableHLO, since we bootstrapped it from MHLO) and identified the following ops as decomposable into other MHLO ops: broadcast, create_token, cross-replica-sum, dot, einsum, torch_index_select, unary_einsum. Let's move these ops to CHLO.

Note: In MHLO, broadcast and dot refer to less general versions of broadcast_in_dim and dot_general. In HLO, broadcast and dot refer to the more general versions, so the naming is a bit confusing, and this move will also help resolve that.
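To illustrate the broadcast case, here is a rough sketch (attribute names and printed syntax approximate, not taken from this thread) of how the less general op, which only prepends new leading dimensions, decomposes into the more general broadcast_in_dim:

```mlir
// mhlo.broadcast prepends the given sizes as new leading dimensions.
%0 = "mhlo.broadcast"(%arg0) {
  broadcast_sizes = dense<[2]> : tensor<1xi64>
} : (tensor<3xf32>) -> tensor<2x3xf32>

// The same computation via the more general broadcast_in_dim, which
// instead maps each operand dimension to a result dimension explicitly.
%1 = "mhlo.broadcast_in_dim"(%arg0) {
  broadcast_dimensions = dense<[1]> : tensor<1xi64>
} : (tensor<3xf32>) -> tensor<2x3xf32>
```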

@subhankarshah
Member

BroadcastOp can be lowered to stablehlo.broadcast_in_dim for statically shaped tensors and to stablehlo.dynamic_broadcast_in_dim for ranked dynamically shaped tensors; however, it is not possible to decompose BroadcastOp into any other StableHLO op for unranked tensors. We will wait for the release of the dynamism RFC before deciding on the migration of BroadcastOp.
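A sketch of the two lowerings described above (syntax and attribute names approximate; `%static`, `%dyn`, and `%out_shape` are hypothetical SSA values):

```mlir
// Static shapes: lower to stablehlo.broadcast_in_dim directly.
%0 = "stablehlo.broadcast_in_dim"(%static) {
  broadcast_dimensions = dense<[1]> : tensor<1xi64>
} : (tensor<3xf32>) -> tensor<2x3xf32>

// Ranked dynamic shapes: the output shape becomes an explicit operand of
// stablehlo.dynamic_broadcast_in_dim.
%1 = "stablehlo.dynamic_broadcast_in_dim"(%dyn, %out_shape) {
  broadcast_dimensions = dense<[1]> : tensor<1xi64>
} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>

// Unranked operands (e.g. tensor<*xf32>) have no such decomposition,
// since broadcast_in_dim needs per-dimension mappings.
```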

@burmako burmako assigned burmako and unassigned subhankarshah Sep 28, 2022
@burmako
Contributor Author

burmako commented Sep 30, 2022

In retrospect, this ticket was a bit ahead of its time:

  1. The fact that we haven't established a baseline specification prevented us from effectively planning the move. For example, as discussed above, it turned out that we cannot move broadcast to CHLO, because unranked broadcasts cannot be expressed via broadcast_in_dim.
  2. Our compatibility guarantees haven't yet been finalized (see StableHLO Compatibility RFC v2 #115), so it's unclear what exactly this move will entail. E.g. are we going to keep both chlo.broadcast and stablehlo.broadcast around? For how long? How do we migrate users?
  3. We don't yet have an RFC process to reason about changes to the StableHLO opset. Introduce StableHLO evolution process #196 was only introduced a few days ago and will still need some additional work.

As a result, we'll focus on completing the StableHLO v1.0 spec over the next few months, by which point we expect all the blockers mentioned above to be resolved. Until then, I'm unassigning this ticket.

@burmako burmako removed their assignment Sep 30, 2022
@burmako burmako added CHLO and removed Spec labels Nov 24, 2022
@burmako burmako changed the title Move decomposable StableHLO ops to CHLO Consider moving "Not in HLO" ops to CHLO Nov 25, 2022
@burmako
Contributor Author

burmako commented Nov 25, 2022

Refactored this ticket to focus on ops which don't have equivalents in the HLO opset (not counting dynamism ops which are handled separately in #8). BatchNorm ops which were previously part of this ticket are now discussed in #603.

@burmako burmako moved this to Todo in Frontend contract Apr 23, 2023
@GleasonK GleasonK added the Spec label Apr 8, 2024
@GleasonK GleasonK self-assigned this Apr 12, 2024
@GleasonK GleasonK moved this from Ready to In progress in StableHLO v1.0 Release May 6, 2024
@GleasonK GleasonK moved this from Todo to In Progress in Frontend contract May 6, 2024
GleasonK added a commit that referenced this issue May 13, 2024
A proposal to remove redundant operations from StableHLO before
long-term compatibility guarantees go into place.

High level summary:
- Remove `CreateTokenOp`, `TraceOp`, `BroadcastOp`, `DotOp`,
`UnaryEinsumOp`, `RealDynamicSliceOp`.
- Enhance `DynamicSliceOp`.
- Move `CrossReplicaSumOp` to CHLO.
- Hopefully remove/move to CHLO (need feedback) `MapOp`, `RngOp`,
`EinsumOp`, `TorchIndexSelectOp`, `GetTupleElementOp`, `tuple` and `tuple` type.

OpenXLA Discuss post:
https://groups.google.com/a/openxla.org/g/openxla-discuss/c/sBAkvnd2bcA

Related tickets: #2176, #3