Verification of ReduceOp and ReduceWindowOp is not strict enough. #394

Open
sdasgup3 opened this issue Oct 26, 2022 · 4 comments

Comments

@sdasgup3
Member

Request description

```mlir
// %arg1 (?x4) and %arg2 (?x5) are incompatible with each other, yet each is
// compatible with %arg0 (?x?), so the current verifier accepts this op.
func.func @variadic_reduce_add(%arg0: tensor<?x?xf32>, %arg1: tensor<?x4xi32>,
    %arg2: tensor<?x5xi32>, %arg3: tensor<f32>,
    %arg4: tensor<i32>, %arg5: tensor<i32>
    ) -> (tensor<f32>, tensor<i32>, tensor<i32>) {
  %0:3 = mhlo.reduce(%arg0 init: %arg3),
                    (%arg1 init: %arg4),
                    (%arg2 init: %arg5) across dimensions = [0, 1] {someattr}
    : (tensor<?x?xf32>, tensor<?x4xi32>, tensor<?x5xi32>,
       tensor<f32>, tensor<i32>, tensor<i32>
      ) -> (tensor<f32>, tensor<i32>, tensor<i32>)
    reducer(%arg6: tensor<f32>, %arg9: tensor<f32>)
           (%arg7: tensor<i32>, %arg10: tensor<i32>)
           (%arg8: tensor<i32>, %arg11: tensor<i32>) {
      %1 = mhlo.add %arg6, %arg9 : tensor<f32>
      %2 = mhlo.add %arg7, %arg10 : tensor<i32>
      %3 = mhlo.add %arg8, %arg11 : tensor<i32>
      mhlo.return %1, %2, %3 : tensor<f32>, tensor<i32>, tensor<i32>
    }
  return %0#0, %0#1, %0#2 : tensor<f32>, tensor<i32>, tensor<i32>
}
```

https://source.corp.google.com/piper///depot/google3/third_party/stablehlo/stablehlo/dialect/TypeInference.cpp;rcl=482658831;l=914

The input shapes `tensor<?x4xi32>` and `tensor<?x5xi32>` are incompatible, but the verifier accepts them because it only checks each input for compatibility with the first argument, whose type is `tensor<?x?xf32>`.
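To see why checking every input against the first one is not enough, here is a minimal, self-contained sketch. It assumes a hypothetical shape model (a `std::vector<int64_t>` with `kDynamic` standing in for `?`) rather than MLIR's `ShapedType`, and mirrors only the rule that two dimensions are compatible when either is dynamic or both are equal. Under that rule compatibility is not transitive: `?x?` is compatible with both `?x4` and `?x5`, but `?x4` and `?x5` are not compatible with each other. The function names are illustrative, not StableHLO APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical shape model (not MLIR's API): kDynamic stands in for `?`.
constexpr int64_t kDynamic = -1;
using Shape = std::vector<int64_t>;

// Two dims are compatible if either is dynamic or both are equal.
bool compatibleDims(int64_t a, int64_t b) {
  return a == kDynamic || b == kDynamic || a == b;
}

// Two shapes are compatible if ranks match and every dim pair is compatible.
bool compatibleShapes(const Shape &a, const Shape &b) {
  if (a.size() != b.size()) return false;
  for (size_t i = 0; i < a.size(); ++i)
    if (!compatibleDims(a[i], b[i])) return false;
  return true;
}

// Mirrors the current check: every input is compared only with input 0.
bool checkAgainstFirst(const std::vector<Shape> &inputs) {
  for (const Shape &s : inputs)
    if (!compatibleShapes(inputs[0], s)) return false;
  return true;
}

// Stricter check: every pair of inputs must be compatible.
bool checkAllPairs(const std::vector<Shape> &inputs) {
  for (size_t i = 0; i < inputs.size(); ++i)
    for (size_t j = i + 1; j < inputs.size(); ++j)
      if (!compatibleShapes(inputs[i], inputs[j])) return false;
  return true;
}
```

With the three input shapes from the example, `{{kDynamic, kDynamic}, {kDynamic, 4}, {kDynamic, 5}}`, `checkAgainstFirst` returns `true` while `checkAllPairs` returns `false`.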

I think the verifier should first infer the "most static" shape across all inputs and then check whether every argument is compatible with it. It would be nice to enforce arguments to have the same shape by inserting casts to this inferred type.
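The "most static shape" approach could look roughly like the following sketch, again using a hypothetical shape model rather than MLIR's API. It folds all inputs into one shape dimension by dimension, adopting any static size it encounters. The fold fails when the ranks differ or two static sizes disagree, so a successful result is by construction compatible with every input, and it would be the natural target type if casts were inserted. `inferMostStaticShape` is an illustrative name, not an existing StableHLO function.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical shape model (not MLIR's API): kDynamic stands in for `?`.
constexpr int64_t kDynamic = -1;
using Shape = std::vector<int64_t>;

// Infers the "most static" shape across all inputs, dimension by dimension.
// Returns std::nullopt if the ranks differ or two static sizes disagree,
// i.e. if no single shape is compatible with every input.
std::optional<Shape> inferMostStaticShape(const std::vector<Shape> &inputs) {
  if (inputs.empty()) return std::nullopt;
  Shape result = inputs[0];
  for (const Shape &s : inputs) {
    if (s.size() != result.size()) return std::nullopt;
    for (size_t d = 0; d < s.size(); ++d) {
      if (s[d] == kDynamic) continue;                  // `?` adds no info
      if (result[d] == kDynamic) result[d] = s[d];     // adopt static size
      else if (result[d] != s[d]) return std::nullopt; // e.g. 4 vs 5
    }
  }
  return result;
}
```

For the example's inputs `?x?`, `?x4`, `?x5` this returns `std::nullopt` (4 conflicts with 5), while `?x?`, `?x4`, `3x?` yield `3x4`.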

@sdasgup3
Member Author

sdasgup3 commented Oct 26, 2022

> It would be nice to enforce arguments to have the same shape by inserting casts to this inferred type.

Rethinking that: the StableHLO framework itself will not insert any extra casts to enforce those constraints. Still, inserting casts seems like a nice alternative to adding verification checks.

@ghpvnist
Member

There is a similar issue with `ReduceWindowOp`:

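```cpp
// P2.
if (!allInputsUnranked) {
  for (uint64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) {
    // Each input is compared only with the single reference input at
    // rankedInputIdx, so two inputs that are each compatible with it can
    // still be incompatible with each other (e.g. ?x4 and ?x5 against ?x?).
    if (failed(mlir::verifyCompatibleShape(inputArgTypes[rankedInputIdx],
                                           inputArgTypes[inputIdx]))) {
      return emitOptionalError(
          location, "expects all inputs to have compatible shapes. Shape at",
          " input-index ", inputIdx,
          " is not compatible with shape at input-index ", rankedInputIdx);
    }
  }
}
```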

@ghpvnist changed the title from "Verification of ReduceOp is not strict enough." to "Verification of ReduceOp and ReduceWindowOp is not strict enough." on Oct 26, 2022
@burmako added the "Migrate to MHLO" label (PR that needs to be migrated to MLIR-HLO) on Oct 28, 2022
@burmako added the "Dynamism" label and removed the "Migrate to MHLO" label on Nov 8, 2022
zhouxin913 pushed a commit that referenced this issue Dec 10, 2022
This is a pure refactor PR which fixes #400 (deprecated) and #623 (new).

As noted in #623, "we only need this for ops with regions". The full list of ops with regions, extracted from https://github.com/openxla/stablehlo/blob/main/stablehlo/dialect/StablehloOps.td, includes **11 ops** in total:

| Op | What's done |
| --- | --- |
| AllReduceOp | No change (already split) |
| CaseOp | No change (region is indispensable for type inference) |
| IfOp | No change (region is indispensable for type inference) |
| MapOp | No change (region is indispensable for type inference) |
| ReduceOp | Split |
| ReduceScatterOp | No change (type inference implementation on hold, see #725) |
| ReduceWindowOp | Split |
| ScatterOp | No change (already split) |
| SelectAndScatterOp | No change (already split) |
| SortOp | Split |
| WhileOp | Split |

Ideally, verifiers contain almost all of the verification logic and shape functions are as simple as possible, but note:
1. `IfOp`/`CaseOp`/`MapOp`: we need information from the region(s) to infer the return type, so an inference function without regions is always invalid and should not exist. No change for these ops in this PR.
2. Since both the verifier and the shape function need to verify the inputs/attributes, that logic is moved into separate util functions:
   - `ReduceOp`: new util `verifyReduceOpInputsAndInferShape()`
   - `ReduceWindow`: new util `verifyReduceWindowOpInputsAndInferWindow()`

   In each op, the verifier (1) calls the new util function and (2) verifies the region, while the shape function (1) calls the new util function and (2) generates the inferred type from the intermediate result of (1).
3. In addition, the verification logic itself needs a further fix (see #394), but this is out of scope for this PR.
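The verifier/shape-function split described above can be sketched as follows. This is a simplified, hypothetical model and not the StableHLO code: shapes are plain `std::vector<int64_t>`, errors are strings instead of MLIR diagnostics, the region check is reduced to a boolean flag, and `verifyInputsAndInferShape()` is an illustrative stand-in for utilities such as `verifyReduceOpInputsAndInferShape()` that only validates reduction dimensions of a single input.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical shared util: validates the reduction dimensions and returns
// the intermediate result (the shape left after removing those dims).
std::optional<Shape> verifyInputsAndInferShape(const Shape &input,
                                               const std::vector<int64_t> &dims,
                                               std::string &error) {
  std::vector<bool> reduced(input.size(), false);
  for (int64_t d : dims) {
    if (d < 0 || d >= static_cast<int64_t>(input.size())) {
      error = "reduction dimension out of range";
      return std::nullopt;
    }
    if (reduced[d]) {
      error = "duplicate reduction dimension";
      return std::nullopt;
    }
    reduced[d] = true;
  }
  Shape result;
  for (size_t i = 0; i < input.size(); ++i)
    if (!reduced[i]) result.push_back(input[i]);
  return result;
}

// Shape function: (1) call the util, (2) build the inferred type from it.
std::optional<Shape> inferReturnShape(const Shape &input,
                                      const std::vector<int64_t> &dims) {
  std::string error;
  return verifyInputsAndInferShape(input, dims, error);
}

// Verifier: (1) call the util, (2) verify the region (a flag in this sketch).
bool verifyOp(const Shape &input, const std::vector<int64_t> &dims,
              bool regionIsValid, std::string &error) {
  if (!verifyInputsAndInferShape(input, dims, error)) return false;
  if (!regionIsValid) {
    error = "invalid reducer region";
    return false;
  }
  return true;
}
```

Keeping the input checks in one util means the verifier and the shape function cannot drift apart; only the region check stays verifier-specific.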
@zhouxin913
Contributor

This problem becomes a little more complicated now that we have added bounds support in #737: we also need to verify that the bounds are compatible with each other, e.g. `a.bounds >= b.dim`, final bound = `min(every bound)`, etc.
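One way to read these bound rules, as a hedged sketch: model each dimension as either static or dynamic with an optional upper bound. The `Dim` struct and `mergeDims` are hypothetical and do not reflect the StableHLO bounds encoding. A static size is compatible with a bounded dynamic dimension only if it does not exceed the bound, and merging two bounded dynamic dimensions keeps the smaller bound.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical per-dimension model: static when `size` is set, otherwise
// dynamic (`?`) with an optional upper `bound`.
struct Dim {
  std::optional<int64_t> size;
  std::optional<int64_t> bound;
};

// Merges two dimensions; returns std::nullopt if they are incompatible.
std::optional<Dim> mergeDims(const Dim &a, const Dim &b) {
  // Both static: the sizes must be equal.
  if (a.size && b.size) {
    if (*a.size != *b.size) return std::nullopt;
    return a;
  }
  // One static, one dynamic: the static size must not exceed the bound.
  if (a.size || b.size) {
    const Dim &fixed = a.size ? a : b;
    const Dim &dynamicDim = a.size ? b : a;
    if (dynamicDim.bound && *dynamicDim.bound < *fixed.size)
      return std::nullopt;
    return Dim{fixed.size, std::nullopt};
  }
  // Both dynamic: keep the tighter (smaller) bound, if any.
  if (a.bound && b.bound)
    return Dim{std::nullopt, std::min(*a.bound, *b.bound)};
  return Dim{std::nullopt, a.bound ? a.bound : b.bound};
}
```

Folding `mergeDims` over all inputs for each dimension yields both the verification result and the inferred dimension in one pass.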

GleasonK pushed a commit to GleasonK/stablehlo that referenced this issue Dec 13, 2022
@sdasgup3
Member Author

sdasgup3 commented Dec 22, 2022

Adding more ops that involve matching shapes and need to deal with the same issue. As previously discussed, it would make sense to collect the size and bound for each dimension individually and use them for verification and type inference.

  • ConcatenateOp: the shapes of the operands must match except along the concatenation dimension.
  • IfOp/CaseOp: the output shapes of the regions should match.
  • MapOp: all inputs must have the same shape.
  • ScatterOp: all inputs must have the same shape.
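Collecting information per dimension might look like the following sketch for `ConcatenateOp`. It assumes a hypothetical shape model (`kDynamic` for `?`, bounds omitted for brevity) and an illustrative function name, not StableHLO's implementation. Every operand contributes to every result dimension, instead of one operand serving as the sole reference: the concatenation dimension sums the sizes (becoming `?` if any operand is dynamic there), and every other dimension is merged across all operands.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical shape model (not MLIR's API): kDynamic stands in for `?`.
constexpr int64_t kDynamic = -1;
using Shape = std::vector<int64_t>;

// Infers the result shape of a concatenate along `axis`, dimension by
// dimension. Returns std::nullopt on a rank mismatch or a static conflict.
std::optional<Shape> inferConcatenateShape(const std::vector<Shape> &operands,
                                           size_t axis) {
  if (operands.empty() || axis >= operands[0].size()) return std::nullopt;
  const size_t rank = operands[0].size();
  Shape result(rank, kDynamic);
  result[axis] = 0;
  for (const Shape &s : operands) {
    if (s.size() != rank) return std::nullopt;
    for (size_t d = 0; d < rank; ++d) {
      if (d == axis) {
        // Concatenation dimension: sizes add up; any `?` makes the sum `?`.
        if (s[d] == kDynamic || result[d] == kDynamic) result[d] = kDynamic;
        else result[d] += s[d];
      } else if (s[d] != kDynamic) {
        // Other dimensions: merge across all operands.
        if (result[d] == kDynamic) result[d] = s[d];
        else if (result[d] != s[d]) return std::nullopt;
      }
    }
  }
  return result;
}
```

For example, concatenating `2x?`, `3x4`, and `?x4` along dimension 0 yields `?x4`, while `2x3` and `3x4` are rejected because dimension 1 conflicts.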
