Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mhlo: migrate conversion to stablehlo #1840

Merged
merged 3 commits into from
Feb 2, 2023
Merged

mhlo: migrate conversion to stablehlo #1840

merged 3 commits into from
Feb 2, 2023

Conversation

ashay
Copy link
Collaborator

@ashay ashay commented Feb 1, 2023

This patch replaces all MHLO operations with their StableHLO
counterparts and adds a validation pass to ensure that no MHLO operations
remain before translating all Stablehlo operations to the MHLO dialect
for further lowering to the Linalg dialect.

This patch also updates all lit tests so that they refer to the
convert-torch-to-stablehlo pass and so that they check for StableHLO
operations.

@burmako
Copy link

burmako commented Feb 1, 2023

@ashay Oh that's a good catch! When we forked StableHLO from MHLO, we did a survey of frontends and looked into which MHLO ops are used and which ops aren't. Only the former ops made the cut, and copy was in the latter category. Sounds like we have missed the usage of copy in Torch-MLIR!

mhlo.copy is an interesting op. It was added a while ago - I imagine it was done for the reasons of parity with kCopy in HLO - but it didn't model the fullness of HLO's semantics, and became kind of a no-op (MHLO semantics involves immutable tensors, so a copy is just a no-op). Then, at a later point, this was recognized, and a folder was added which unconditionally replaced mhlo.copy with its operand.

I see that at the moment mhlo.copy is only used in one place in Torch-MLIR. I wonder if it's possible to just use its operand instead of creating this op in the first place:

INSERT_UNARY_PATTERN(AtenCloneOp, mhlo::CopyOp);
.

lib/Conversion/TorchToMhlo/Basic.cpp Outdated Show resolved Hide resolved
lib/Conversion/TorchToMhlo/TorchToMhlo.cpp Outdated Show resolved Hide resolved
@silvasean
Copy link
Contributor

Nit: possibly in a separate PR, we can rename all the Mhlo in file paths / directories to Stablehlo (note h is lowercase officially in the dialect name).

Also when possible, updating the VerifyStablehloBackendContract pass to reject Mhlo would be good (in some sense that is the true milestone of completion of this workstream).

And if we can get sign-off from @burmako that it is indeed checking the correct set of ops: https://github.com/llvm/torch-mlir/pull/1840/files#diff-d8eb5bebe72fb7e5c0d49bd9c2234d6ac74e44f250d38e7f0d7af0fee52530d7 -- two questions:

  • Is CHLO part of the StableHLO set?
  • Is shape::ShapeOfOp part of StableHLO set?

@ashay
Copy link
Collaborator Author

ashay commented Feb 1, 2023

Nit: possibly in a separate PR, we can rename all the Mhlo in file paths / directories to Stablehlo (note h is lowercase officially in the dialect name).

Fixed in this PR.

Also when possible, updating the VerifyStablehloBackendContract pass to reject Mhlo would be good (in some sense that is the true milestone of completion of this workstream).

Since the TorchToMhlo pipeline ultimately lowers to Mhlo, I don't think we can get rid of the VerifyMhloBackendContract pass. So I added the VerifyStablehloBackendContract pass after the torch-to-stablehlo conversion while retaining the existing VerifyMhloBackendContract pass.,

ashay added 2 commits February 1, 2023 13:44
This patch replaces all MHLO operations with their StableHLO
counterparts and adds a validation pass to ensure that no MHLO operations
remain before translating all Stablehlo operations to the MHLO dialect
for further lowering to the Linalg dialect.

This patch also updates all lit tests so that they refer to the
`convert-torch-to-stablehlo` pass and so that they check for StableHLO
operations.
Since the dialect translation previous generated MHLO operations, the
names of many files, functions, and passes referred to the MHLO dialect.
This patch renames them to refer to Stablehlo.
@silvasean
Copy link
Contributor

silvasean commented Feb 1, 2023

Since the TorchToMhlo pipeline ultimately lowers to Mhlo, I don't think we can get rid of the VerifyMhloBackendContract pass. So I added the VerifyStablehloBackendContract pass after the torch-to-stablehlo conversion while retaining the existing VerifyMhloBackendContract pass.,

It should lower to StableHLO and not Mhlo. The later lowering to MHLO is really just used as part of e2e testing since only mhlo has a path to linalg. Other than that incidental aspect MHLO is irrelevant to Torch-MLIR. (StableHLO has a path to MHLO, but that is not our project's responsibility)

Concretely we would put lowering stablehlo->mhlo here:

"builtin.module(func.func(symbolic-shape-optimization),func.func(hlo-legalize-to-linalg),func.func(canonicalize))",

(and that whole directory structure also needs to be renamed to stablehlo, as well as the e2e test config)

@burmako
Copy link

burmako commented Feb 2, 2023

As far as VerifyStablehloBackendContract goes, how about we populate it with the contents of the current VerifyMhloBackendContract with the only change being mhlo::MhloDialect replaced with stablehlo::StablehloDialect, and then over time work on shrinking this contract?

Here's the rationale for this proposal: StableHLO is ready to serve as a drop-in replacement for MHLO in the role of an interface between ML frameworks and ML compilers. It supports everything that MHLO supports at the interface level, and then some, but some future work still remains, namely:

  1. We are planning to replace func.func and func.return in StableHLO programs with StableHLO-specific ops (Spec and implement Module/Func/Call/Return ops openxla/stablehlo#425).
  2. There is no standard way of writing dynamically-shaped programs at the moment. The spec doesn't take a position on this at the moment of writing, because it's ongoing work (Write a dynamism RFC openxla/stablehlo#8).
  3. CHLO is not part of the spec at the moment of writing, and I'm starting to think that maybe it shouldn't be (Create a plan for CHLO openxla/stablehlo#602).

Given that, the existing pieces of VerifyMhloBackendContract make sense:

  • target.addDynamicallyLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(opHasLegalTypes) corresponds to 1).
  • target.addDynamicallyLegalOp<shape::ShapeOfOp>(opHasLegalTypes), target.addLegalDialect<tensor::TensorDialect>(), target.addLegalDialect<arith::ArithDialect>() correspond to 2).
  • target.addLegalDialect<chlo::ChloDialect>() correspond to 3).

As we're working through these open questions, I expect that we'll be able to eventually shrink the backend contract. For example, one potential future could involve: a) keeping builtin.module, b) introducing stablehlo.func and using stablehlo.return for all returns, c) vendoring all shape computations into the StableHLO dialect, obviating the need for ops from arith, shape and tensor dialects in StableHLO programs, d) deconstructing CHLO to StableHLO. In that future, the contract would shrink to allowing just the ModuleOp and the StableHLO dialect.

What do you think?

@ashay ashay requested a review from silvasean February 2, 2023 06:51
@silvasean
Copy link
Contributor

silvasean commented Feb 2, 2023

Thanks Eugene, this makes sense to me! Looking forward to collaborating to tighten this up!

Copy link
Contributor

@silvasean silvasean left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, but let's wait for one of the owners of StableHLO support in Torch-MLIR approve as well.

Copy link
Collaborator

@tanyokwok tanyokwok left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@ashay ashay merged commit 711646d into llvm:main Feb 2, 2023
@ashay ashay deleted the ashay/migrate-to-stablehlo branch February 2, 2023 13:29
gpetters94 pushed a commit to gpetters94/mlir-npcomp that referenced this pull request May 10, 2023
This patch replaces all MHLO operations with their StableHLO
counterparts and adds a validation pass to ensure that no MHLO operations
remain before translating all Stablehlo operations to the MHLO dialect
for further lowering to the Linalg dialect.

This patch also updates all lit tests so that they refer to the
`convert-torch-to-stablehlo` pass and so that they check for StableHLO
operations.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants