Skip to content

Commit

Permalink
Torch: Fold RuntimeAssertOp when condition is true (llvm#2198)
Browse files Browse the repository at this point in the history
  • Loading branch information
mgehre-amd authored and gpetters94 committed Jul 7, 2023
1 parent 76e1a90 commit f82a7d4
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 0 deletions.
2 changes: 2 additions & 0 deletions e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,7 @@
"ElementwiseNegModule_basic",
"TestMultipleTensorReturn_basic",
"TypeAsSameModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"BaddbmmDynamicModule_basic",
"BaddbmmStaticModule_basic",
Expand Down Expand Up @@ -1065,6 +1066,7 @@
"ElementwiseRemainderScalarModule_Float_basic",
"ElementwiseRemainderScalarModule_Int_Float_basic",
"ElementwiseRemainderScalarModule_Int_basic",
"PrimsSqueezeModule_basic",
"PrimsSqueezeEmptyDimensionsModule_basic",
"MoveDimIntModule_basic",
"MoveDimIntNegativeIndexModule_basic",
Expand Down
1 change: 1 addition & 0 deletions include/torch-mlir/Dialect/Torch/IR/TorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1163,6 +1163,7 @@ def Torch_RuntimeAssertOp: Torch_Op<"runtime.assert", [
let results = (outs
);
let assemblyFormat = "$condition `,` $message attr-dict";
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
21 changes: 21 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,27 @@ void PrimIfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
});
}

//===----------------------------------------------------------------------===//
// RuntimeAssertOp
//===----------------------------------------------------------------------===//

void RuntimeAssertOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](RuntimeAssertOp op, PatternRewriter &rewriter) {
bool value;
if (!matchPattern(op.getCondition(), m_TorchConstantBool(&value)))
return failure();

if (value) {
rewriter.eraseOp(op);
return success();
}
// Even if the condition is statically false, the assert might never be
// executed.
return failure();
});
}

//===----------------------------------------------------------------------===//
// DerefineOp
//===----------------------------------------------------------------------===//
Expand Down
8 changes: 8 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ func.func @torch.aten.__range_length$fold() -> (!torch.int, !torch.int, !torch.i
return %0, %1, %2, %3 : !torch.int, !torch.int, !torch.int, !torch.int
}

// CHECK-LABEL: func.func @torch.runtime.assert
// CHECK-NEXT: return
func.func @torch.runtime.assert() {
%true = torch.constant.bool true
torch.runtime.assert %true, "msg"
return
}

// CHECK-LABEL: func.func @torch.aten.is_floating_point$fold_true
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: return %[[TRUE]] : !torch.bool
Expand Down

0 comments on commit f82a7d4

Please sign in to comment.