Consider moving "Not in HLO" ops to CHLO #3
BroadcastOp can be lowered to …
In retrospect, this ticket was a bit ahead of its time:
As a result, we'll be focusing on completing the StableHLO v1.0 spec in the next few months, and by then we expect that all the blockers mentioned above will be resolved. Until then, I'll be unassigning this ticket.
A proposal to remove redundant operations from StableHLO before long-term compatibility guarantees take effect. High-level summary:
- Remove `CreateTokenOp`, `TraceOp`, `BroadcastOp`, `DotOp`, `UnaryEinsumOp`, `RealDynamicSliceOp`.
- Enhance `DynamicSliceOp`.
- Move `CrossReplicaSumOp` to CHLO.
- Hopefully remove or move to CHLO (feedback needed): `MapOp`, `RngOp`, `EinsumOp`, `TorchIndexSelectOp`, `GetTupleElementOp`, `tuple`, and the `tuple` type.
OpenXLA Discuss post: https://groups.google.com/a/openxla.org/g/openxla-discuss/c/sBAkvnd2bcA
Related tickets: #2176, #3
In https://discourse.llvm.org/t/rfc-proposal-for-a-high-level-ml-dialect-in-mlir/64249/46, we did a survey of MHLO ops (almost all of it applies to StableHLO, since we bootstrapped StableHLO from MHLO) and identified the following ops that are decomposable into other MHLO ops: `broadcast`, `create_token`, `cross-replica-sum`, `dot`, `einsum`, `torch_index_select`, `unary_einsum`. Let's move these ops to CHLO.

Note: In MHLO, `broadcast` and `dot` refer to less general versions of `broadcast_in_dim` and `dot_general`. In HLO, `broadcast` and `dot` refer to the more general versions, so the naming is confusing, and this move will also help resolve that.
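As a rough illustration of the note above, the Python sketch below spells out how the less general MHLO `broadcast` and `dot` can be expressed as `broadcast_in_dim` and `dot_general`. It assumes that MHLO's `broadcast` prepends `broadcast_sizes` to the operand shape, and that `dot` takes rank-1 or rank-2 operands and contracts the last dimension of the lhs with the first dimension of the rhs. The function names and the dict-of-indices tensor model are hypothetical, made up for this sketch; they are not an MLIR or StableHLO API.

```python
from itertools import product

# Hypothetical toy model: a tensor is a dict mapping index tuples to values,
# with its shape passed separately. Not an MLIR/StableHLO API.


def broadcast_in_dim(operand, operand_shape, result_shape, broadcast_dimensions):
    """Reference semantics: operand dim i maps to result dim
    broadcast_dimensions[i]; operand dims of size 1 are expanded."""
    result = {}
    for idx in product(*(range(n) for n in result_shape)):
        src = tuple(
            0 if operand_shape[i] == 1 else idx[d]
            for i, d in enumerate(broadcast_dimensions)
        )
        result[idx] = operand[src]
    return result


def broadcast_as_broadcast_in_dim(operand_shape, broadcast_sizes):
    """Assumed MHLO `broadcast` semantics: prepend broadcast_sizes to the
    operand shape. Returns the equivalent broadcast_in_dim attributes."""
    result_shape = tuple(broadcast_sizes) + tuple(operand_shape)
    offset = len(broadcast_sizes)
    broadcast_dimensions = tuple(offset + i for i in range(len(operand_shape)))
    return result_shape, broadcast_dimensions


def dot_as_dot_general(lhs_shape, rhs_shape):
    """Assumed MHLO `dot` semantics (rank-1 or rank-2 operands): contract the
    lhs's last dim with the rhs's first dim, with no batch dims."""
    assert len(lhs_shape) in (1, 2) and len(rhs_shape) in (1, 2)
    dimension_numbers = {
        "lhs_batching_dimensions": (),
        "rhs_batching_dimensions": (),
        "lhs_contracting_dimensions": (len(lhs_shape) - 1,),
        "rhs_contracting_dimensions": (0,),
    }
    # With no batch dims, the result is the lhs free dims then the rhs free dims.
    result_shape = tuple(lhs_shape[:-1]) + tuple(rhs_shape[1:])
    return dimension_numbers, result_shape


if __name__ == "__main__":
    vec = {(0,): 1, (1,): 2, (2,): 3}
    shape, dims = broadcast_as_broadcast_in_dim((3,), (2,))
    out = broadcast_in_dim(vec, (3,), shape, dims)
    print(shape, dims, out[(1, 2)])
    print(dot_as_dot_general((4, 3), (3,)))
```

In both cases the rewrite only fills in attributes of the more general op, which is why the narrower ops are candidates for decomposition rather than separate primitives.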