Fuse op when the shape is not static (onnx#2577)
* transformation

Signed-off-by: chentong319 <[email protected]>

* condition

Signed-off-by: chentong319 <[email protected]>

* test

Signed-off-by: chentong319 <[email protected]>

* format

Signed-off-by: chentong319 <[email protected]>

* more test

Signed-off-by: chentong319 <[email protected]>

* fix

Signed-off-by: chentong319 <[email protected]>

---------

Signed-off-by: chentong319 <[email protected]>
chentong319 authored Oct 30, 2023
1 parent 0f021bc commit faa42c8
Showing 3 changed files with 169 additions and 62 deletions.
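The heart of this change: fusion no longer insists on static shapes. When a DimAnalysis is available, equality of (possibly dynamic) output shapes is proven symbolically, and the old static-shape test remains only as a fallback. Below is a minimal sketch of that decision, assuming the onnx-mlir helpers hasStaticShape/getShape and DimAnalysis::sameShape that the patch itself uses; the wrapper name shapesMatchForFusion is hypothetical, and the real logic lives in areInputsValidForFusion in the diff.

// Hypothetical condensation of the shape condition introduced by this commit.
bool shapesMatchForFusion(
    mlir::Value defResult, mlir::Value useResult, DimAnalysis *dimAnalysis) {
  if (dimAnalysis)
    // Symbolic dimension analysis can prove dynamic dimensions equal.
    return dimAnalysis->sameShape(defResult, useResult);
  // Fallback: the simplest case, a static shape identical to the producer's.
  return hasStaticShape(useResult.getType()) &&
         getShape(defResult.getType()) == getShape(useResult.getType());
}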
139 changes: 77 additions & 62 deletions src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
@@ -1580,16 +1580,18 @@ typedef mlir::Value (*EmitScalarFunc)(mlir::ConversionPatternRewriter &rewriter,
class OpFusionHelper {
public:
// Constructor
OpFusionHelper(
mlir::ConversionPatternRewriter &rewriter, mlir::Operation *rootOp)
: rootOp(rootOp), rewriter(rewriter), fusibleOps(), fuseEmitFunctions() {}
OpFusionHelper(mlir::ConversionPatternRewriter &rewriter,
mlir::Operation *rootOp, DimAnalysis *dimAnalysis)
: rootOp(rootOp), rewriter(rewriter), dimAnalysis(dimAnalysis),
fusibleOps(), fuseEmitFunctions() {}

// Fusion should not break any control dependence
static bool isControlFlowValidForFusion(Operation *useOp, Operation *defOp);

// Check whether the inputs of the useOp are valid for useOp to be fused
// with the defOp. The defOp defines one of useOp's inputs.
static bool areInputsValidForFusion(Operation *useOp, Operation *defOp);
static bool areInputsValidForFusion(
Operation *useOp, Operation *defOp, DimAnalysis *dimAnalysis);

// Check whether the op is fusible along the use-def chain from the defOp.
// If true, record the op and its scalar op.
@@ -1607,13 +1609,18 @@ class OpFusionHelper {
// For example, comparison ops and the cast op.
MemRefType getOutputType(MemRefType outputType);

Value emitFuseOps(Value producerResult, ValueRange loopInd = {});
// Generate the code for the ops to be fused.
// producerResult is the scalar value from the producer.
// alloc is used to get the tensor for the producer, which is required
// by the shape helper.
Value emitFuseOps(Value producerResult, Value alloc, ValueRange loopInd = {});

void replaceOrEraseONNXOps(Value alloc);

private:
mlir::Operation *rootOp;
mlir::ConversionPatternRewriter &rewriter;
DimAnalysis *dimAnalysis;
llvm::SmallVector<mlir::Operation *, 2> fusibleOps;
llvm::SmallVector<EmitScalarFunc, 2> fuseEmitFunctions;
}; // End of OpFusionHelper Declaration
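For orientation, the lowering patterns later in this file drive the helper in a fixed sequence. The following is a condensed sketch of that call protocol, assuming the usual onnx-mlir lowering context; the replaceOrEraseONNXOps call site is outside the visible hunks and is inferred from its declaration above, so this is not verbatim code.

// Sketch of how the Elementwise*OpLowering patterns use OpFusionHelper.
OpFusionHelper opFusionHelper(rewriter, op, dimAnalysis);
opFusionHelper.findFusibleOps();                   // walk the use-def chain
outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);
// ... allocate 'alloc', build the loop nest, compute the root op's scalar ...
Value fused = opFusionHelper.emitFuseOps(loweredOpResult, alloc, loopInd);
createKrnl.store(fused, alloc, loopInd);           // single store per element
// After the loop nest, retire the ONNX ops that were fused away:
opFusionHelper.replaceOrEraseONNXOps(alloc);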
@@ -1623,10 +1630,11 @@ class OpFusionHelper {
template <typename T>
bool enqueueFusibleOpImpl(Operation *useOp, Operation *defOp,
SmallVector<Operation *, 2> &fusibleOps,
SmallVector<EmitScalarFunc, 2> &fuseEmitFunctions) {
SmallVector<EmitScalarFunc, 2> &fuseEmitFunctions,
DimAnalysis *dimAnalysis) {
if (isa<T>(useOp)) {
if (OpFusionHelper::isControlFlowValidForFusion(useOp, defOp) &&
OpFusionHelper::areInputsValidForFusion(useOp, defOp)) {
OpFusionHelper::areInputsValidForFusion(useOp, defOp, dimAnalysis)) {
fusibleOps.emplace_back(useOp);
fuseEmitFunctions.emplace_back(emitScalarOpFor<T>);
return true;
@@ -1639,23 +1647,26 @@ bool enqueueFusibleOpImpl(Operation *useOp, Operation *defOp,
template <typename T = void, class... Ts>
bool enqueueFusibleOp(Operation *useOp, Operation *defOp,
SmallVector<Operation *, 2> &fusibleOps,
SmallVector<EmitScalarFunc, 2> &fuseEmitFunctions);
SmallVector<EmitScalarFunc, 2> &fuseEmitFunctions,
DimAnalysis *dimAnalysis);

template <typename T, class... Ts>
bool enqueueFusibleOp(Operation *useOp, Operation *defOp,
SmallVector<Operation *, 2> &fusibleOps,
SmallVector<EmitScalarFunc, 2> &fuseEmitFunctions) {
if (enqueueFusibleOpImpl<T>(useOp, defOp, fusibleOps, fuseEmitFunctions)) {
SmallVector<EmitScalarFunc, 2> &fuseEmitFunctions,
DimAnalysis *dimAnalysis) {
if (enqueueFusibleOpImpl<T>(
useOp, defOp, fusibleOps, fuseEmitFunctions, dimAnalysis))
return true;
} else {
return enqueueFusibleOp<Ts...>(useOp, defOp, fusibleOps, fuseEmitFunctions);
}
return enqueueFusibleOp<Ts...>(
useOp, defOp, fusibleOps, fuseEmitFunctions, dimAnalysis);
}

template <>
bool enqueueFusibleOp(Operation *useOp, Operation *defOp,
SmallVector<Operation *, 2> &fusibleOps,
SmallVector<EmitScalarFunc, 2> &fuseEmitFunctions) {
SmallVector<EmitScalarFunc, 2> &fuseEmitFunctions,
DimAnalysis *dimAnalysis) {
return false;
}
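The enqueueFusibleOp machinery above is a compile-time dispatch idiom: a defaulted first template parameter plus an empty-pack full specialization terminates the recursion over the candidate op list. Here is a self-contained analogue of the same idiom, with hypothetical names unrelated to onnx-mlir:

#include <iostream>
#include <string>
#include <typeinfo>

// Try one candidate type T against a runtime tag (stand-in for isa<T>).
template <typename T>
bool tryOneImpl(const std::string &tag) {
  if (tag == typeid(T).name()) {
    std::cout << "matched\n";
    return true;
  }
  return false;
}

// Declaration with a defaulted parameter so the empty-pack specialization
// below is well-formed: the same trick as enqueueFusibleOp.
template <typename T = void, typename... Ts>
bool tryEach(const std::string &tag);

// Recursive case: test T, then recurse over the remaining pack Ts...
template <typename T, typename... Ts>
bool tryEach(const std::string &tag) {
  if (tryOneImpl<T>(tag))
    return true;
  return tryEach<Ts...>(tag);
}

// Terminating specialization: the candidate list is exhausted, no match.
template <>
bool tryEach(const std::string &tag) {
  return false;
}

int main() {
  // Probes candidates in order, like enqueueFusibleOp<ONNXAddOp, ...>.
  return tryEach<int, double, char>(typeid(double).name()) ? 0 : 1;
}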

@@ -1689,7 +1700,7 @@ bool OpFusionHelper::checkFusibleOp(Operation *useOp, Operation *defOp,
mlir::ONNXAddOp, mlir::ONNXAndOp, mlir::ONNXDivOp, mlir::ONNXMaxOp,
mlir::ONNXMeanOp, mlir::ONNXMinOp, mlir::ONNXMulOp, mlir::ONNXOrOp,
mlir::ONNXSubOp, mlir::ONNXSumOp, mlir::ONNXXorOp>(
useOp, defOp, fusibleOps, fuseEmitFunctions);
useOp, defOp, fusibleOps, fuseEmitFunctions, dimAnalysis);
}

// Only operations in the same block are allowed to fuse.
@@ -1735,29 +1746,20 @@ bool OpFusionHelper::isControlFlowValidForFusion(
// assumed that canonicalization has hoisted all constants to the beginning
// of the function via the fold function.
bool OpFusionHelper::areInputsValidForFusion(
Operation *useOp, Operation *defOp) {
Operation *useOp, Operation *defOp, DimAnalysis *dimAnalysis) {
// Elementwise unary operation is always fusible
if (useOp->getOperands().size() == 1)
return true;

// To fuse an elementwise op that has more than one operand with the
// producer, the output of the user op has to have the same shape as the
// output of the producer op. Here, dimension expansion with size 1 is
// allowed; refer to the hasNoBroadcast() definition.
// ToFix: This PR simply checks for static shapes and does not use symbolic
// shape analysis or BroadcastShapeHelper.
// Some discussion can be found at
// https://github.com/onnx/onnx-mlir/issues/2199

if (!hasStaticShape(defOp->getResults()[0].getType()))
return false;

ArrayRef<int64_t> defShape = getShape(defOp->getResults()[0].getType());
ArrayRef<int64_t> useShape = getShape(useOp->getResults()[0].getType());
Type defOutputType = defOp->getResultTypes()[0];
Type useOutputType = useOp->getResultTypes()[0];
ArrayRef<int64_t> defShape = getShape(defOutputType);
ArrayRef<int64_t> useShape = getShape(useOutputType);
if (defShape != useShape) {
return false;
}

// Check the inputs in the useOp
for (size_t i = 0; i < useOp->getOperands().size(); i++) {
// Only inputs from block arguments and constants are allowed
// if the input does not come from the defining op.
@@ -1769,12 +1771,29 @@ bool OpFusionHelper::areInputsValidForFusion(
return false;
}
}
}

// Check whether the output shape of the defOp is the same as the output
// shape of the useOp. If true, the iteration space of the defOp is
// sufficient for the element-wise operation of the useOp, even if
// MDBroadcast occurs in the useOp.
// Otherwise, the loop nest would have to be defined according to the
// tensor with the larger iteration space.

// First check the rank
if (getRank(defOutputType) != getRank(useOutputType))
return false;

// ToFix: This restriction can be relaxed if the ShapeHelper utility is
// used to generate loads in the future.
if (!hasStaticShape(useOp->getOperand(i).getType()))
if (dimAnalysis) {
if (!dimAnalysis->sameShape(defOp->getResult(0), useOp->getResult(0)))
return false;
ArrayRef<int64_t> inputShape = getShape(useOp->getOperand(i).getType());
} else {
// If there is no dimAnalysis, check the simplest case:
// a static shape identical to the producer's.
if (!hasStaticShape(useOutputType))
return false;

ArrayRef<int64_t> inputShape = getShape(useOutputType);
if (inputShape != defShape)
return false;
}
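To make the fallback branch concrete, here is a toy restatement of the no-DimAnalysis path on plain shape vectors; kDynamic and the helper names are illustrative stand-ins, not onnx-mlir API:

#include <cassert>
#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = -1; // stands in for a '?' (dynamic) dimension
using Shape = std::vector<int64_t>;

bool isStatic(const Shape &s) {
  for (int64_t d : s)
    if (d == kDynamic)
      return false;
  return true;
}

// Without DimAnalysis, fusion requires equal rank and a static consumer
// output shape identical to the producer's.
bool fallbackShapesValid(const Shape &defShape, const Shape &useShape) {
  if (defShape.size() != useShape.size()) // rank check
    return false;
  if (!isStatic(useShape))
    return false;
  return defShape == useShape;
}

int main() {
  assert(fallbackShapesValid({2, 5}, {2, 5})); // static, identical: fusible
  assert(!fallbackShapesValid({2, 5}, {5}));   // rank mismatch: rejected
  assert(!fallbackShapesValid({kDynamic, 5}, {kDynamic, 5})); // needs DimAnalysis
  return 0;
}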
@@ -1819,7 +1838,8 @@ MemRefType OpFusionHelper::getOutputType(MemRefType outputType) {
}

// Emit fusion Ops
Value OpFusionHelper::emitFuseOps(Value defOpResult, ValueRange loopInd) {
Value OpFusionHelper::emitFuseOps(
Value defOpResult, Value alloc, ValueRange loopInd) {
if (isFusibleListEmpty())
return defOpResult;

@@ -1835,33 +1855,30 @@ Value OpFusionHelper::emitFuseOps(Value defOpResult, ValueRange loopInd) {
MDBuilder create(rewriter, loc);
Type currentElementType = getElementType(useOp->getResults()[0].getType());

// Prepare Values for EmitScalarOpFor<T>
SmallVector<Value, 2> inputValues;
// ToFix: expected to use a new utility for this purpose.
// There is an issue to fix: getRemappedValue cannot be called on the Value
// currently being handled, i.e. the result of defOp. Otherwise the runtime
// error "null operand found" is triggered merely by calling the function,
// even without using its result!
#if 0
// useOperands is used for ShapeHelper and load op.
// getRemappedValue is needed for load op.
SmallVector<Value, 4> useOperands;
for (auto oper : useOp->getOperands()) {
if (oper.getDefiningOp() != defOp)
useOperands.emplace_back(rewriter.getRemappedValue(oper));
else
// A load will not be needed because of useOpResult.
// This value is only needed by shape helper.
useOperands.emplace_back(alloc);
}
LogicalResult res =
rewriter.getRemappedValues(useOp->getOperands(), useOperands);
assert(succeeded(res) && "Could not remap value for rewriter");
// Use the shape helper to generate load indices
ONNXBroadcastOpShapeHelper shapeHelper(
useOp, useOperands, &create.krnlIE, nullptr, false);
#endif
shapeHelper.computeShapeAndAssertOnFailure();

// Prepare Values for EmitScalarOpFor<T>
SmallVector<Value, 2> inputValues;
for (size_t i = 0; i < useOp->getOperands().size(); i++) {
Value inputValue = useOp->getOperand(i);
Operation *inputOp = inputValue.getDefiningOp();
if (inputOp == defOp) {
inputValues.emplace_back(defOpResult);
} else {
// ToFix: expected to use a new utility to handle all broadcast cases
#if 0
IndexExprScope innerScope(create.krnl, shapeHelper.getScope());
SmallVector<IndexExpr, 4> outputAccessExprs;
getIndexExprList<DimIndexExpr>(loopInd, outputAccessExprs);
@@ -1870,16 +1887,13 @@ Value OpFusionHelper::emitFuseOps(Value defOpResult, ValueRange loopInd) {
inputValue, i, outputAccessExprs, loadAccessExprs, true);
assert(succeeded(res) && "Could not compute access indices");
Value load = create.krnl.loadIE(useOperands[i], loadAccessExprs);
#endif
// The shape is guaranteed to be the same.
Value load =
create.krnl.load(rewriter.getRemappedValue(inputValue), loopInd);
inputValues.emplace_back(load);
}
}
defOpResult =
emitScalar(rewriter, loc, useOp, currentElementType, inputValues);
defOp = useOp;
alloc = defOp->getResult(0);
}
return defOpResult;
}
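For intuition about what the chain above buys, here is a plain C++ analogue of the single fused loop nest that results when a hypothetical producer Add is fused with a consumer Mul of the same shape; the intermediate tensor is never materialized, and only one store per element is issued:

#include <vector>

// Plain C++ analogue of the fused Krnl loop for t = a + b; out = t * c
// (hypothetical tensors, all of identical shape). 't' is only a scalar.
std::vector<float> fusedAddMul(const std::vector<float> &a,
    const std::vector<float> &b, const std::vector<float> &c) {
  std::vector<float> out(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    float producer = a[i] + b[i];     // root scalar (emitScalarOpFor<AddOp>)
    float consumer = producer * c[i]; // chained scalar from emitFuseOps
    out[i] = consumer;                // single store into alloc
  }
  return out;
}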
@@ -1997,7 +2011,7 @@ struct ONNXElementwiseUnaryOpLowering
LLVM_DEBUG(llvm::dbgs() << " scalar execution\n");

// Try to fuse the unary elementwise consumers
OpFusionHelper opFusionHelper(rewriter, op);
OpFusionHelper opFusionHelper(rewriter, op, dimAnalysis);
opFusionHelper.findFusibleOps();
outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);

@@ -2034,7 +2048,7 @@ struct ONNXElementwiseUnaryOpLowering
auto loweredOpResult = emitScalarOpFor<ElementwiseUnaryOp>(
rewriter, loc, op, elementType, args);
loweredOpResult =
opFusionHelper.emitFuseOps(loweredOpResult, loopInd);
opFusionHelper.emitFuseOps(loweredOpResult, alloc, loopInd);
// Store result in the resulting array.
createKrnl.store(loweredOpResult, alloc, loopInd);
});
@@ -2055,7 +2069,7 @@ struct ONNXElementwiseUnaryOpLowering
}
auto loweredOpResult = emitScalarOpFor<ElementwiseUnaryOp>(
rewriter, loc, op, elementType, args);
loweredOpResult = opFusionHelper.emitFuseOps(loweredOpResult);
loweredOpResult = opFusionHelper.emitFuseOps(loweredOpResult, alloc);
// Store result in the resulting array.
create.krnl.store(loweredOpResult, alloc);
}
@@ -2165,7 +2179,7 @@ struct ONNXElementwiseBinaryOpLowering
LLVM_DEBUG(llvm::dbgs() << " scalar execution\n");

// Try to fuse the unary elementwise consumers
OpFusionHelper opFusionHelper(rewriter, op);
OpFusionHelper opFusionHelper(rewriter, op, dimAnalysis);
opFusionHelper.findFusibleOps();
outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);

Expand Down Expand Up @@ -2209,7 +2223,7 @@ struct ONNXElementwiseBinaryOpLowering
Value result = emitScalarOpFor<ElementwiseBinaryOp>(
rewriter, loc, op, outputElementType, {lhs, rhs});

result = opFusionHelper.emitFuseOps(result, loopInd);
result = opFusionHelper.emitFuseOps(result, alloc, loopInd);
// Store result in the resulting array.
createKrnl.store(result, alloc, loopInd);
});
@@ -2221,7 +2235,7 @@
Value result = emitScalarOpFor<ElementwiseBinaryOp>(
rewriter, loc, op, outputElementType, {lhs, rhs});

result = opFusionHelper.emitFuseOps(result);
result = opFusionHelper.emitFuseOps(result, alloc);
// Store result in the resulting array.
create.krnl.store(result, alloc);
}
@@ -2328,7 +2342,7 @@ struct ONNXElementwiseVariadicOpLowering
LLVM_DEBUG(llvm::dbgs() << " scalar execution\n");

// Try to fuse the unary elementwise consumers
OpFusionHelper opFusionHelper(rewriter, op);
OpFusionHelper opFusionHelper(rewriter, op, dimAnalysis);
opFusionHelper.findFusibleOps();
outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);

@@ -2378,7 +2392,8 @@

Value finalResult = emitPostProcessingFor<ElementwiseVariadicOp>(
rewriter, loc, op, outputElementType, accumulated);
finalResult = opFusionHelper.emitFuseOps(finalResult, loopInd);
finalResult =
opFusionHelper.emitFuseOps(finalResult, alloc, loopInd);
// Store result in the resulting array.
createKrnl.storeIE(finalResult, alloc, outputAccessExprs);
});
@@ -2395,7 +2410,7 @@
}
Value finalResult = emitPostProcessingFor<ElementwiseVariadicOp>(
rewriter, loc, op, outputElementType, accumulated);
finalResult = opFusionHelper.emitFuseOps(finalResult);
finalResult = opFusionHelper.emitFuseOps(finalResult, alloc);
// Store result in the resulting array.
create.krnl.store(finalResult, alloc);
}
