Skip to content

Commit

Permalink
Split verifiers and shape functions for ops with regions (#401)
Browse files Browse the repository at this point in the history
This is a pure refactor PR which fixes
#400 (deprecated) and
#623 (new)

As in the #623 , "we only
need this for ops with regions", a full list of ops with region extract
from
https://github.com/openxla/stablehlo/blob/main/stablehlo/dialect/StablehloOps.td
includes **11 ops** in total:

| Op | what's done |
| --- | --- |
| AllReduceOp | No change (already split) |
| CaseOp |  No change (region is indispensable for type inference)  |
| IfOp |  No change (region is indispensable for type inference)  |
| MapOp |  No change (region is indispensable for type inference)  |
| ReduceOp |  Split |
| ReduceScatterOp | No change (Type Inference implementation on hold see
#725) |
| ReduceWindowOp |  Split |
| ScatterOp | No change (already split) |
| SelectAndScatterOp |  No change (already split) |
| SortOp |  Split |
| WhileOp |  Split |

The ideal split is that verifiers contain as almost all verifications,
and the shape functions are simple as possible, but note:
1. `IfOp/CaseOp/MapOp`: We need info from region(s) to infer the return
type, so an init function without regions is always invalid and should
not exist. No change for them in this PR.
2. As both verifier & shape function need verification of the
inputs/attrs, so we need put them in a separate utils functions.
`ReduceOp`: introduce new util `verifyReduceOpInputsAndInferShape()`
`ReduceWindow`: introduce new util
`verifyReduceWindowOpInputsAndInferWindow()`
In each op, verifier does (1) call this new util function (2) verify
region
shape function: (1) call this new util function (2) generate inferred
type from the intermediate result from (1)
3. Besides, the verification logic needs further fix see
#394, but this is out of
scope of this PR.
  • Loading branch information
Xin Zhou authored Dec 10, 2022
1 parent 63813b3 commit 6ee83d2
Show file tree
Hide file tree
Showing 4 changed files with 343 additions and 202 deletions.
31 changes: 25 additions & 6 deletions stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2945,7 +2945,14 @@ LogicalResult ReduceWindowOp::inferReturnTypeComponents(
location, adaptor.getInputs(), adaptor.getInitValues(),
adaptor.getWindowDimensions(), adaptor.getWindowStrides(),
adaptor.getBaseDilations(), adaptor.getWindowDilations(),
adaptor.getPadding(), adaptor.getBody(), inferredReturnShapes);
adaptor.getPadding(), inferredReturnShapes);
}

LogicalResult ReduceWindowOp::verify() {
return hlo::verifyReduceWindowOp(getLoc(), getInputs(), getInitValues(),
getWindowDimensions(), getWindowStrides(),
getBaseDilations(), getWindowDilations(),
getPadding(), getBody());
}

// Get the operation used for reduction applied to `result_index`th result. Its
Expand Down Expand Up @@ -3305,7 +3312,12 @@ LogicalResult ReduceOp::inferReturnTypeComponents(
ReduceOp::Adaptor adaptor(operands, attributes, regions);
return hlo::inferReduceOp(location, adaptor.getInputs(),
adaptor.getInitValues(), adaptor.getDimensions(),
adaptor.getBody(), inferredReturnShapes);
inferredReturnShapes);
}

LogicalResult ReduceOp::verify() {
return hlo::verifyReduceOp(getLoc(), getInputs(), getInitValues(),
getDimensions(), getBody());
}

LogicalResult ReduceOp::reifyReturnTypeShapes(
Expand Down Expand Up @@ -3807,8 +3819,12 @@ LogicalResult SortOp::inferReturnTypeComponents(
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
SortOp::Adaptor adaptor(operands, attributes, regions);
return hlo::inferSortOp(location, adaptor.getInputs(), adaptor.getDimension(),
adaptor.getComparator(), inferredReturnShapes);
return hlo::inferSortOp(location, adaptor.getInputs(), inferredReturnShapes);
}

LogicalResult SortOp::verify() {
return hlo::verifySortOp(getLoc(), getInputs(), getDimension(),
getComparator());
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -4377,8 +4393,11 @@ LogicalResult WhileOp::inferReturnTypes(
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type>& inferredReturnTypes) {
WhileOp::Adaptor adaptor(operands, attributes, regions);
return hlo::inferWhileOp(location, adaptor.getOperand(), adaptor.getCond(),
adaptor.getBody(), inferredReturnTypes);
return hlo::inferWhileOp(location, adaptor.getOperand(), inferredReturnTypes);
}

LogicalResult WhileOp::verify() {
return hlo::verifyWhileOp(getLoc(), getOperand(), getCond(), getBody());
}

/// Print a `while` op.
Expand Down
8 changes: 8 additions & 0 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1270,6 +1270,8 @@ def StableHLO_WhileOp: StableHLO_Op<"while", [

let results = (outs Variadic<HLO_TensorOrToken>);

let hasVerifier = 1;

let extraClassDeclaration = [{
// Method of OpAsmOpInterface used during custom printing to name the block
// arguments in the nested regions. We name both the condition and the body
Expand Down Expand Up @@ -1452,6 +1454,8 @@ def StableHLO_ReduceOp: StableHLO_ShapedInterfaceOp<"reduce", [

let hasCustomAssemblyFormat = 1;

let hasVerifier = 1;

// TODO(hinsu): Verify that the attached body arguments and results are
// compatible with reduce op's operands.
let regions = (region SizedRegion<1>:$body);
Expand Down Expand Up @@ -2620,6 +2624,8 @@ def StableHLO_SortOp : StableHLO_Op<"sort",
let builders = [
OpBuilder<(ins "ValueRange":$inputs, CArg<"int64_t", "-1">:$dimension,
CArg<"bool", "false">:$is_stable)>];

let hasVerifier = 1;
}

def StableHLO_ReverseOp: StableHLO_Op<"reverse",
Expand Down Expand Up @@ -2810,6 +2816,8 @@ def StableHLO_ReduceWindowOp: StableHLO_Op<"reduce_window", [

let regions = (region SizedRegion<1>:$body);

let hasVerifier = 1;

// Builder for non-variadic version of the operation.
let builders = [
OpBuilder<(ins "Type":$result_type, "Value":$operand,
Expand Down
Loading

0 comments on commit 6ee83d2

Please sign in to comment.