Split verifiers and shape functions for ops with regions #623

Closed
burmako opened this issue Nov 25, 2022 · 1 comment · Fixed by #401

Comments

@burmako
Contributor

burmako commented Nov 25, 2022

In Q3, when we started to systematically implement type inference for StableHLO/MHLO ops, we decided to merge verifiers and shape functions because that significantly simplified the resulting code. Verifiers and shape functions often need to compute exactly the same intermediate values (including the inferred type, which verifiers compare against the actual type), and sharing the code reduced duplication considerably.

Things were going well until we ported some of the new shape functions from StableHLO to MHLO and some of the MHLO tests started failing. Here's the code from one of the passes that was crashing, mhlo_canonicalize_reduction.cc:

```cpp
  // The builder is called without the region; the body is moved in afterwards.
  auto newOp =
      b.create<ReduceOp>(loc, newOperands, op.getInitValues(), attr);
  newOp.getBody().takeBody(op.getBody());
```

Previously, this code used a hand-written builder for ReduceOp that didn't take result types. When porting shape functions to MHLO, we removed this builder, so the code started using the autogenerated builder. (Incidentally, TableGen used to not generate result-type-less builders for ops with regions, and that changed in https://reviews.llvm.org/D136232 right when we were making our change, which made things hilariously confusing.) Here is the generated builder:

```cpp
void ReduceOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange inputs, ::mlir::ValueRange init_values, ::mlir::DenseIntElementsAttr dimensions) {
  odsState.addOperands(inputs);
  odsState.addOperands(init_values);
  odsState.addAttribute(getDimensionsAttrName(odsState.name), dimensions);
  // Adds an empty region; the caller is expected to fill it in later.
  (void)odsState.addRegion();

  // Result types are inferred immediately, while the region is still empty.
  ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
  if (::mlir::succeeded(ReduceOp::inferReturnTypes(odsBuilder.getContext(),
                odsState.location, odsState.operands,
                odsState.attributes.getDictionary(odsState.getContext()),
                odsState.regions, inferredReturnTypes)))
    odsState.addTypes(inferredReturnTypes);
  else
    ::llvm::report_fatal_error("Failed to infer result type(s).");
}
```

When the producer code and the generated builder are put next to each other, the issue becomes clear (although it wasn't evident at all until we encountered it!). The original call to the builder doesn't pass the region, so the odsState.regions passed to the shape function contains only the empty region created by odsState.addRegion(). Because the shape function also runs the verification logic, which expects a populated region, it crashes.

Fixing this crash so that the verifier fails gracefully on empty regions is easy. However, that won't help the builder at all: it will still fail, whereas we want it to succeed, i.e. we want to be able to create ReduceOps with empty bodies without explicitly providing the result type. To enable that, we need to split the verifier and the shape function for ReduceOp (#400).

Originally, we thought that we would need to split all verifiers and shape functions, but then I realized that we only need this for ops with regions. Other ops are created with all their operands and attributes provided right away, so running the verifier at that point shouldn't be a problem. (It might have been a problem if we had ops that need multiple steps of initialization, but ops with regions are the only such ops we have.)

@zhouxin913
Contributor

zhouxin913 commented Dec 7, 2022

As stated in the description, "we only need this for ops with regions". The full list of ops with regions, extracted from https://github.com/openxla/stablehlo/blob/main/stablehlo/dialect/StablehloOps.td, includes 11 ops in total:

AllReduceOp 
CaseOp
IfOp
MapOp
ReduceOp
ReduceScatterOp
ReduceWindowOp
ScatterOp
SelectAndScatterOp
SortOp
WhileOp

zhouxin913 linked a pull request Dec 8, 2022 that will close this issue
zhouxin913 pushed a commit that referenced this issue Dec 10, 2022
This is a pure refactor PR which fixes #400 (deprecated) and #623 (new).

As stated in #623, "we only need this for ops with regions". The full list of ops with regions, extracted from https://github.com/openxla/stablehlo/blob/main/stablehlo/dialect/StablehloOps.td, includes **11 ops** in total:

| Op | What's done |
| --- | --- |
| AllReduceOp | No change (already split) |
| CaseOp | No change (region is indispensable for type inference) |
| IfOp | No change (region is indispensable for type inference) |
| MapOp | No change (region is indispensable for type inference) |
| ReduceOp | Split |
| ReduceScatterOp | No change (type inference implementation on hold, see #725) |
| ReduceWindowOp | Split |
| ScatterOp | No change (already split) |
| SelectAndScatterOp | No change (already split) |
| SortOp | Split |
| WhileOp | Split |

Ideally, verifiers contain almost all of the verification logic and shape functions are as simple as possible, but note:
1. `IfOp/CaseOp/MapOp`: we need information from the region(s) to infer the return type, so a builder that initializes these ops without their regions would always be invalid and should not exist. No change for them in this PR.
2. Since both the verifier and the shape function need to verify the inputs/attributes, we moved that logic into separate util functions:
   - `ReduceOp`: new util `verifyReduceOpInputsAndInferShape()`
   - `ReduceWindow`: new util `verifyReduceWindowOpInputsAndInferWindow()`

   In each op, the verifier (1) calls the new util function and (2) verifies the region, while the shape function (1) calls the new util function and (2) generates the inferred type from the intermediate result of (1).
3. Separately, the verification logic needs further fixes (see #394), but that is out of scope for this PR.
GleasonK pushed a commit to GleasonK/stablehlo that referenced this issue Dec 13, 2022