[Feature Request] Support Scalar in StableHLO #43
Comments
Hi @yaochengji! I've looked into what dialects JAX, ONNX-MLIR, Torch-MLIR, and TensorFlow are using alongside MHLO at the moment of writing: […]

Overall, in addition to […]. Given that it's not just […]
Hi @burmako, basically I'm looking for conceptual minimalism. In our compiler stack there are several backends, and some backends only use […]. For […]. In summary, we currently find that some models have […]
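To make the situation concrete, here is a minimal MLIR sketch (the function name and values are hypothetical, not taken from the issue) of the kind of mixed-dialect conversion output being described: tensor math stays in stablehlo while scalar math falls out into the arith dialect, so every backend has to handle both.

```mlir
// Hypothetical conversion output: tensor computation in stablehlo,
// scalar computation in arith -- two dialects for backends to support.
func.func @mixed(%t: tensor<4xf32>, %a: i32, %b: i32) -> (tensor<4xf32>, i32) {
  %0 = stablehlo.add %t, %t : tensor<4xf32>  // tensor math: stablehlo
  %1 = arith.addi %a, %b : i32               // scalar math: arith
  return %0, %1 : tensor<4xf32>, i32
}
```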
I think you might be interested in the latest revision of the compatibility RFC. More specifically, the RFC proposes the following: "For the purposes of compatibility guarantees, we define StableHLO programs as programs that only include ops, attributes, and types from the StableHLO opset. There will be several additions to the StableHLO opset to accomplish this goal of providing a self-contained opset that satisfies current use cases". Furthermore, it continues with: "As a stop-gap provision, we propose adding […]".

Can you take a look at the full list of the ops in the RFC and comment on whether we missed any? Overall, will this RFC address your feature request?
Based on the current production models we have, the additional scalar opset in the RFC is enough. And I'm curious: how was the opset chosen? There are more ops in the mhlo dialect that can support scalar operands: sin, cos, exp, etc. Why were they not chosen?
"And I'm curious about how the opset is chosen?". The proposed scalar ops mentioned above all come from the use case of representing dynamically-shaped programs in StableHLO. Basically, we want all the shape computation logic to be representable in StableHLO. "There're more ops in mhlo dialect that can support scalar operand: sin, cos, exp, etc. Why are they not chosen?". Can you elaborate on the use case for this? So far, as we've been discussing the opset with various stakeholders, shape computations is the only area where scalar computations have come up (by "scalar" here I mean values like Also @GleasonK for visibility, since Kevin is driving the compatibility RFC, and the idea of a self-contained StableHLO opset is central to it. |
I wrongly thought that the tf dialect could support scalar types for regular operations such as sin and exp. After a quick experiment just now, I found that scalar types are changed to tensor types when translating from the TF graph format to the tf dialect.
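For reference, a sketch of what that experiment would show (the function name is hypothetical): a "scalar" sin already arrives in the tf dialect as an op on a 0-d tensor.

```mlir
// In the tf dialect a scalar sin is already a 0-d tensor op; there is
// no bare-scalar (f32) form coming out of the TF graph importer.
func.func @scalar_sin(%x: tensor<f32>) -> tensor<f32> {
  %y = "tf.Sin"(%x) : (tensor<f32>) -> tensor<f32>
  return %y : tensor<f32>
}
```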
Yes, it's also the only situation we've ever encountered, typically when lowering a frontend dialect (like the tf dialect) to the mhlo dialect. The current additional opset should be enough for us. Thanks @burmako and @GleasonK.
Request description
There can be some scalar computation in deep learning models.
As far as I know, TensorFlow and ONNX currently create operations from the arith dialect for scalar computation when converting to the mhlo dialect.
It would be more systematic if StableHLO could support scalars. After that, only the stablehlo and shape dialects should remain after converting a frontend dialect to stablehlo.
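As a sketch of that desired end state (hypothetical function name, not from the RFC), the same scalar computation would be expressed with 0-d tensors directly in stablehlo, leaving no arith ops behind:

```mlir
// Desired end state: scalar computation stays in stablehlo as 0-d
// tensors, so the converted module needs no arith dialect at all.
func.func @scalar_add(%a: tensor<i32>, %b: tensor<i32>) -> tensor<i32> {
  %0 = stablehlo.add %a, %b : tensor<i32>
  return %0 : tensor<i32>
}
```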