Bootstrap StableHLO from CHLO/MHLO #1

Merged (4 commits), Aug 17, 2022
Conversation

@burmako (Contributor) commented Aug 9, 2022

Motivation
----------

[Recent discussions](https://discourse.llvm.org/t/coordinate-llvm-commits-for-different-project/63990)
highlight an acute need for stability of interchange dialects between
ML frameworks and ML compilers in the open source community.

In [a Discourse post](https://discourse.llvm.org/t/coordinate-llvm-commits-for-different-project/63990/7),
silvasean@ called out a potential solution - something called "shallow dialects"
that producers could vendor into their repositories and upgrade with
well-defined backward compatibility windows.

I think this presents a great opportunity for StableHLO: starting as a shallow
fork of MHLO bootstraps us from a well-understood baseline and lets us
provide a service to the community right away - backward compatibility
guarantees for MHLO (and its sister dialect CHLO as well).

This doesn't mean that the scope of StableHLO should be limited by the scope of
CHLO/MHLO. Bootstrapping from CHLO/MHLO is just the beginning of the journey,
and there are a lot of ideas to explore to evolve StableHLO beyond that.
(See "Future work" below for some of these ideas).

Summary
-------

All of the CHLO ops are forked into CHLO ops in this repository.

Most of the MHLO ops are forked into StableHLO ops in this repository
(except for the 8 ops that are private to XLA, i.e. that don't seem to be
created by existing producers and are only created inside XLA itself: `add_dependency`,
`bitcast`, `copy`, `domain`, `fusion`, `partition_id`, `tuple`,
`xla.rng_get_and_update_state`).
For further details, see the categorization of MHLO ops at:
https://discourse.llvm.org/t/rfc-proposal-for-a-high-level-ml-dialect-in-mlir/64249/46.

These forks are "shallow", i.e. they include only essential
functionality for producers. All folders and canonicalizers are dropped
(except for `ConstantOp` folders which are necessary to satisfy the
`ConstantLike` trait). No passes are forked, except for a pass which helps
test shape inference.
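
To make "shallow" concrete, here is a minimal sketch of the one folder that
survives the fork (accessor and signature assumed to match the MLIR fold hook
of the time; the exact body in the repository may differ):

```cpp
// StablehloOps.cpp (sketch). The ConstantLike trait requires ConstantOp to
// keep a working folder so that MLIR's m_Constant() matchers can extract the
// constant's value; all other folders and canonicalizers are dropped.
mlir::OpFoldResult mlir::stablehlo::ConstantOp::fold(
    llvm::ArrayRef<mlir::Attribute> operands) {
  assert(operands.empty() && "stablehlo.constant has no operands");
  // A constant always folds to its own value attribute.
  return getValue();
}
```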

This uses MLIR-HLO at revision eb0a4d2ae01e9a39dfa079648303d9f3576b096a:
tensorflow/mlir-hlo@eb0a4d2
If we decide to proceed with this pull request, we can discuss which MLIR-HLO
commit to fork from (e.g. we might want to use the MLIR-HLO HEAD at that time).

Details
-------

1) `hlo_ops_base.{h,td}` are forked into `Base.{h,td}` and include only the
functionality which is shared between CHLO and StableHLO.

2) `hlo_ops_base_attrs.td` is forked into `StablehloAttrs.td` because it's
not used in CHLO. `hlo_ops_base_enums.td` is forked into a combination of
`ChloEnums.td` and `StablehloEnums.td`. `ComparisonType` and
`ComparisonDirection` now exist in both CHLO and StableHLO, to avoid
a dependency from the former on the latter (see the first sketch after this list).

3) `hlo_ops_common.h` is merged into `StablehloOps.h`. The "common" in the
name refers to functionality shared between the MHLO and LMHLO dialects,
but LMHLO already depends on MHLO anyway, so there's no need for a separate
abstraction here.

4) All dependencies from CHLO to MHLO are severed (by introducing
some extra functions in `Base.h`, creating `chlo.constant`, and forking the
enums described above from MHLO to CHLO), so that CHLO and StableHLO are
now two independent dialects (see the second sketch after this list).

5) I also think it would be a good idea to move 10 decomposable ops -
`batch_norm_grad`, `batch_norm_inference`, `batch_norm_training`, `broadcast`,
`create_token`, `cross-replica-sum`, `dot`, `einsum`, `torch_index_select`,
`unary_einsum` - into CHLO, but that is left for future work.
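
To make item 2 concrete, here is a hedged illustration (include paths and
enum case names assumed from the layout described above): after the fork,
each dialect owns its own copy of the comparison enums, so neither depends
on the other.

```cpp
#include "stablehlo/dialect/ChloOps.h"       // assumed include path
#include "stablehlo/dialect/StablehloOps.h"  // assumed include path

// Both enums exist independently; mapping between them is a matter of
// translating identically-named cases, not of CHLO referencing StableHLO.
void comparisonEnumsDemo() {
  auto chloDirection = mlir::chlo::ComparisonDirection::EQ;
  auto stablehloDirection = mlir::stablehlo::ComparisonDirection::EQ;
  (void)chloDirection;
  (void)stablehloDirection;
}
```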
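
And a sketch for item 4 (builder signature assumed to mirror the MHLO
constant builder): with `chlo.constant` in place, CHLO code can materialize
constants without creating an `mhlo.constant`, which is what severs the
dependency.

```cpp
#include "mlir/IR/Builders.h"
#include "stablehlo/dialect/ChloOps.h"  // assumed include path

// Materializes a constant entirely within CHLO. Previously the only option
// was mhlo.constant, which forced CHLO to depend on MHLO.
mlir::Value makeChloConstant(mlir::OpBuilder &builder, mlir::Location loc,
                             mlir::ElementsAttr value) {
  return builder.create<mlir::chlo::ConstantOp>(loc, value);
}
```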

Future work
-----------

A) As the next step, I propose that we vendor this repository into
MLIR-HLO (with the plan to regularly integrate changes into MLIR-HLO)
and introduce a conversion from StableHLO to MHLO in order to connect
StableHLO producers and MHLO consumers.
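
As a rough sketch of what one such conversion pattern could look like
(pattern name, include paths, and accessor names are assumptions, not part
of this PR; since the two dialects start out structurally identical, most
ops would map one-to-one):

```cpp
#include "mhlo/IR/hlo_ops.h"                 // assumed include path
#include "mlir/IR/PatternMatch.h"
#include "stablehlo/dialect/StablehloOps.h"  // assumed include path

namespace {
// Rewrites stablehlo.add into the structurally identical mhlo.add. A real
// lowering would register one such pattern (or a generic one) per op pair.
struct StablehloAddToMhlo
    : public mlir::OpRewritePattern<mlir::stablehlo::AddOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::stablehlo::AddOp op,
                  mlir::PatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<mlir::mhlo::AddOp>(
        op, op.getType(), op.getLhs(), op.getRhs());
    return mlir::success();
  }
};
}  // namespace
```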

B) That would enable MHLO producers to experiment with StableHLO
by changing the library they depend on (from MHLO to StableHLO) and
changing `mhlo.` and `mhlo::` references to `stablehlo.` and `stablehlo::`.
We're talking with the JAX, TF/XLA, and Torch-MLIR teams about these experiments.
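
For a producer, the switch is mechanical; a hedged before/after sketch
(op and builder names illustrative, assuming a type-inferring builder as in
MHLO):

```cpp
#include "mlir/IR/Builders.h"
#include "stablehlo/dialect/StablehloOps.h"  // was: the MHLO equivalent

// Identical producer code before and after; only the dialect prefix changes.
mlir::Value buildAdd(mlir::OpBuilder &builder, mlir::Location loc,
                     mlir::Value lhs, mlir::Value rhs) {
  // Before: builder.create<mlir::mhlo::AddOp>(loc, lhs, rhs);
  return builder.create<mlir::stablehlo::AddOp>(loc, lhs, rhs);
}
```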

C) Speaking of compatibility guarantees, I propose that we initially provide
them at the boundary between StableHLO and MHLO (so that StableHLO provides
stability, and MHLO can keep evolving mostly freely). Potential MLIR upstream
functionality along the lines of [this patch](https://reviews.llvm.org/D117761)
would enable StableHLO to be versioned.

D) In the longer term, StableHLO work can be structured into three workstreams:
  * Workstream #1: Stable version of HLO/MHLO.
    Specification, test suite, reference implementation - ETA: H2 2022.
  * Workstream #2: Evolution beyond what's currently in HLO/MHLO.
    Ongoing work on dynamism, sparsity, quantization and extensibility -
    ETA: H2 2022.
  * Workstream #3: Adoption of StableHLO.
    Support for ML frameworks (TensorFlow, JAX, PyTorch) and ML compilers
    (XLA and IREE) - ETA: H2 2022.

E) This repository is just getting bootstrapped. In the future, work items will
be tracked via Issues / Projects and perhaps other documentation.
    def StableHLO_Dialect : Dialect {
      let name = "stablehlo";
      let cppNamespace = "::mlir::stablehlo";

@GleasonK (Member) commented Aug 10, 2022

Consider adding a `let description = [{ ... }];` in case we plan to generate documentation from the TD at some point.

Could use the repo description:

Suggested change:

    let description = [{
      StableHLO is an operation set that expresses ML computations. It has been
      originally bootstrapped from the MHLO dialect and enhances it with additional
      functionality, including serialization and versioning, to be used as a portability
      layer between ML frameworks and ML compilers.
    }];

@burmako (Contributor, Author) replied:

Thank you! That's a nice catch. We have a description for the StableHLO dialect in its current dummy state, but I accidentally deleted it when bootstrapping from MHLO. Will fix later today.

@burmako merged commit d6c918d into openxla:main on Aug 17, 2022
@burmako self-assigned this on Oct 2, 2022
@burmako added the RFC label on Oct 2, 2022
@burmako mentioned this pull request on Dec 19, 2022
burmako pushed a commit that referenced this pull request on Feb 8, 2023:
I recently noticed this code when reviewing #849, and I'm not sure why
we need it there.

This seems like a pretty strong statement about a fundamental role of
the Tensor dialect in the workings of the StableHLO dialect, and I don't
think we have established that yet.

It would seem that we've inherited this from MHLO when bootstrapping
StableHLO (#1), but I don't think I understand the reasoning on the MHLO
side either. This change was introduced as part of an LLVM integrate
(tensorflow/mlir-hlo@ba0346b),
and the commit description doesn't go into detail about motivation.

Given that, I propose to revert this in the StableHLO dialect and see
what happens. All tests in this repository are passing, but maybe we'll
learn more after downstream integrations.
atondwal pushed a commit to atondwal/stablehlo that referenced this pull request Mar 3, 2023
abhigunj referenced this pull request in abhigunj/stablehlo Feb 28, 2024