
[Feature Request] Support Scalar in StableHLO #43

Closed
yaochengji opened this issue Aug 25, 2022 · 6 comments

@yaochengji
Contributor

Request description

There can be scalar computation in deep learning models.

As far as I know, TensorFlow and ONNX currently create operations from the arith dialect for scalar computation when converting to the mhlo dialect.

It would be more systematic if StableHLO could support scalars. With that, converting a frontend dialect to StableHLO would leave only the stablehlo and shape dialects.
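For example, a converted dynamically-shaped program can end up looking roughly like this (a minimal sketch; the function and shapes are made up, and the op spellings are assumed from today's upstream/MHLO conventions). The tensor math stays in mhlo, while the scalar size computation comes out of the converter as arith and tensor ops:

```mlir
// Hypothetical converter output: tensor math in mhlo, scalar math in arith/tensor.
func.func @flatten_pairs(%arg0: tensor<?x2xf32>) -> tensor<?xf32> {
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<?x2xf32>        // dynamic row count
  %n = arith.muli %d0, %c2 : index                      // scalar arithmetic -> arith today
  %shape = tensor.from_elements %n : tensor<1xindex>
  %0 = "mhlo.dynamic_reshape"(%arg0, %shape)
      : (tensor<?x2xf32>, tensor<1xindex>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}
```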

Additional context

No response

@burmako
Contributor

burmako commented Sep 26, 2022

Hi @yaochengji! I've looked into which dialects JAX, ONNX-MLIR, Torch-MLIR and TensorFlow are using alongside MHLO at the time of writing:

  • JAX: chlo, func, sparse_tensor (ml_program is imported but is not currently used).
  • ONNX-MLIR: func, shape.
  • TensorFlow: arith, chlo, func, shape.
  • Torch-MLIR: arith, func, chlo, tensor, torch_c.

Overall, in addition to func and shape, the arith, chlo, sparse_tensor and tensor dialects are used alongside mhlo. (Depending on how you count, we may decide to ignore chlo, because it can be lowered to mhlo and shape.)

Given that it's not just arith but also a few other dialects, can you elaborate a bit more on the motivation for your feature request? Are you looking for conceptual minimalism (i.e. being able to say "our frontend/backend contract is just X and Y dialects" and feel good about that)? Or for comprehensive documentation for the entire surface of the contract? Or maybe compatibility guarantees for the entirety of the contract? Something else?

@yaochengji
Contributor Author

Hi @burmako ,

Basically, I'm looking for conceptual minimalism. Our compiler stack has several backends. Some backends take only mhlo as their input IR, which is then translated directly to a non-MLIR format, so we need to implement several translations from mhlo (possibly along with other dialects) to the backends' IRs.

For the shape dialect: our main use case today is static-shape models, so in practice there is no shape dialect.
For the tensor dialect: among the production models (more than 1000) we've encountered so far, none of them needs tensor.
For the func dialect: it's fine, because it only has a few ops.

In summary, we currently find that some models contain mhlo along with only arith. It would be much more concise if mhlo could also represent scalar computation, so that we wouldn't need arith at this level.

@burmako
Contributor

burmako commented Nov 15, 2022

I think that you might be interested in the latest revision of the compatibility RFC.

More specifically, the RFC proposes the following: "For the purposes of compatibility guarantees, we define StableHLO programs as programs that only include ops, attributes, and types from the StableHLO opset. There will be several additions to the StableHLO opset to accomplish this goal of providing a self-contained opset that satisfies current use cases".

Furthermore, it continues with: "As a stop-gap provision, we propose adding AddIOp, CmpIOp, ConstantIOp, DivSIOp, ExtSIOp, IndexCastOp, MaxSIOp, MinSIOp, MulIOp, SelectOp, SubIOp, TruncIOp, CastOp, DimOp, FromElementsOp to the StableHLO opset to accommodate how MHLO programs look like today. These ops may be deprecated by the Dynamism RFC".

Can you take a look at the full list of the ops in the RFC and comment on whether we missed any? Overall, will this RFC address your feature request?
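For reference, the kind of shape-computation glue those stop-gap ops are meant to cover looks roughly like this today (a sketch; the function is made up and the exact op spellings are assumed). The arith.constant, tensor.dim and tensor.from_elements ops below correspond roughly to the proposed ConstantIOp, DimOp and FromElementsOp:

```mlir
// Sketch of dynamic-shape glue that the stop-gap ops would make expressible
// in StableHLO itself; today it needs arith and tensor alongside mhlo.
func.func @broadcast_like(%x: tensor<?xf32>, %y: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index                             // ~ ConstantIOp
  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %y, %c0 : tensor<?x?xf32>                 // ~ DimOp
  %d1 = tensor.dim %y, %c1 : tensor<?x?xf32>
  %shape = tensor.from_elements %d0, %d1 : tensor<2xindex>   // ~ FromElementsOp
  %0 = "mhlo.dynamic_broadcast_in_dim"(%x, %shape)
      {broadcast_dimensions = dense<1> : tensor<1xi64>}
      : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
```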

@yaochengji
Contributor Author

Based on the production models we currently have, the additional scalar opset in the RFC is enough.

Also, I'm curious how the opset was chosen. There are more ops in the mhlo dialect that could support scalar operands (sin, cos, exp, etc.); why were they not chosen?

@burmako
Contributor

burmako commented Nov 15, 2022

"And I'm curious about how the opset is chosen?". The proposed scalar ops mentioned above all come from the use case of representing dynamically-shaped programs in StableHLO. Basically, we want all the shape computation logic to be representable in StableHLO.

"There're more ops in mhlo dialect that can support scalar operand: sin, cos, exp, etc. Why are they not chosen?". Can you elaborate on the use case for this? So far, as we've been discussing the opset with various stakeholders, shape computations is the only area where scalar computations have come up (by "scalar" here I mean values like f32 not tensor<f32> - the latter is already supported in StableHLO). If we missed something, we'd love to learn more!

Also @GleasonK for visibility, since Kevin is driving the compatibility RFC, and the idea of a self-contained StableHLO opset is central to it.

@yaochengji
Contributor Author

yaochengji commented Nov 15, 2022

I wrongly thought that the tf dialect could support scalar types for regular operations such as sin and exp. After a quick experiment just now, I found that scalar types are converted to tensor types when translating from the TF graph format to the tf dialect.
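Roughly, the imported IR looks like the following (a sketch with a made-up function name): the graph-level scalar sin comes out as a tf-dialect op on a 0-d tensor rather than on a bare f32.

```mlir
// Sketch of the import result: the graph-level scalar becomes a 0-d tensor.
func.func @scalar_sin(%arg0: tensor<f32>) -> tensor<f32> {
  %0 = "tf.Sin"(%arg0) : (tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}
```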

"shape computation is the only area where scalar computations have come up"

Yes, it's also the only situation we've encountered, typically while lowering a frontend dialect (like the tf dialect) to the mhlo dialect. The proposed additional opset should be enough for us. Thanks @burmako and @GleasonK.
