Refine the lowering of onnx.concat #1778

Merged 7 commits on Oct 13, 2022
Changes from 4 commits
24 changes: 19 additions & 5 deletions src/Conversion/ONNXToKrnl/Tensor/Concat.cpp
@@ -40,6 +40,7 @@ struct ONNXConcatOpLowering : public ConversionPattern {
assert(succeeded(shapecomputed) && "Could not compute output shape");

auto axis = concatOp.axis();
assert(axis >= 0 && "negative axis is supposed to have been normalized");
unsigned int inputNum = operands.size();

// Convert the output type to MemRefType.
@@ -57,16 +58,28 @@
MultiDialectBuilder<KrnlBuilder> create(rewriter, loc);

// Creates loops, one for each input.
// Since each input should have the same size for each dimension (except
// axis), we try to make the loop upper bounds the same for further
// optimization. Differences may come from constant vs. dynamic dims, or
// from dynamic dims of different inputs.
KrnlBuilder createKrnl(rewriter, loc);
SmallVector<IndexExpr, 4> commonUB(shapeHelper.dimsForOutput());
// IndexExprScope IEScope(&rewriter, loc);
IndexExpr accumulatedOffset = LiteralIndexExpr(0);
for (unsigned int i = 0; i < inputNum; ++i) {
// Since accumulatedOffsetValue will be used in a nested IndexExprScope,
// we get the Value of this IndexExpr and pass it as a symbol.
Value accumulatedOffsetValue = accumulatedOffset.getValue();
OpBuilder::InsertionGuard insertGuard(rewriter);
// Create loop.
ValueRange loopDef = createKrnl.defineLoops(rank);
SmallVector<IndexExpr, 4> lbs(rank, LiteralIndexExpr(0));
MemRefBoundsIndexCapture bounds(operands[i]);
SmallVector<IndexExpr, 4> ubs;
bounds.getDimList(ubs);
createKrnl.iterateIE(loopDef, loopDef, lbs, ubs,
// For each input, only the dimension 'axis' is different
commonUB[axis] = ubs[axis];
createKrnl.iterateIE(loopDef, loopDef, lbs, commonUB,
[&](KrnlBuilder &createKrnl, ValueRange loopInd) {
// Indices for the read and write.
SmallVector<Value, 4> readIndices, writeIndices;
@@ -76,17 +89,18 @@
else {
IndexExprScope IEScope(&rewriter, loc);
IndexExpr writeOffset = DimIndexExpr(loopInd[r]);
for (unsigned int j = 0; j < i; j++) {
MemRefBoundsIndexCapture operandJBounds(operands[j]);
writeOffset = writeOffset + operandJBounds.getDim(r);
}
IndexExpr accumulatedOffsetIE =
SymbolIndexExpr(accumulatedOffsetValue);
writeOffset = writeOffset + accumulatedOffsetIE;
writeIndices.emplace_back(writeOffset.getValue());
}
}
// Insert copy.
Value loadData = createKrnl.load(operands[i], loopInd);
createKrnl.store(loadData, alloc, writeIndices);
});
MemRefBoundsIndexCapture operandJBounds(operands[i]);
accumulatedOffset = accumulatedOffset + operandJBounds.getDim(axis);
}
rewriter.replaceOp(op, alloc);
return success();
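The main change in this lowering is that the write offset along the concatenation axis is now carried in a single accumulatedOffset IndexExpr that is advanced once per input, instead of being recomputed inside the element loop by summing the axis dims of all preceding inputs; the loops also share a common upper bound (commonUB) to ease later optimization. A minimal standalone C++ sketch of the same indexing scheme, using plain row-major arrays rather than onnx-mlir builders (Matrix and concatAxis0 are illustrative names, not part of the project):

// Conceptual sketch only: concatenate 2-D row-major matrices along axis 0,
// carrying a running write offset per input instead of re-summing the sizes
// of all preceding inputs for every element. Assumes at least one input and
// that all inputs share the same number of columns.
#include <cstdint>
#include <vector>

struct Matrix {
  int64_t rows, cols;      // all inputs share 'cols' (the non-axis dim)
  std::vector<float> data; // row-major storage
};

Matrix concatAxis0(const std::vector<Matrix> &inputs) {
  // Shape inference: the axis dim accumulates, the other dim is shared.
  int64_t totalRows = 0;
  for (const Matrix &m : inputs)
    totalRows += m.rows;
  Matrix out{totalRows, inputs[0].cols,
             std::vector<float>(totalRows * inputs[0].cols)};

  int64_t accumulatedOffset = 0; // advanced once per input, not per element
  for (const Matrix &m : inputs) {
    for (int64_t r = 0; r < m.rows; ++r)
      for (int64_t c = 0; c < m.cols; ++c)
        out.data[(accumulatedOffset + r) * out.cols + c] =
            m.data[r * m.cols + c];
    accumulatedOffset += m.rows; // skip past this input's axis extent
  }
  return out;
}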
33 changes: 25 additions & 8 deletions src/Dialect/ONNX/ShapeInference/Concat.cpp
@@ -37,19 +37,36 @@ LogicalResult ONNXConcatOpShapeHelper::computeShape(
if (axisIndex < 0)
axisIndex += commonRank;

// For the Concat op, the size of each dimension of the inputs should be the
// same, except for the concatenated dimension. To simplify the result, a
// constant size is used if there is one. Otherwise, the dimension of the
// last input tensor (implementation dependent) is used for the output tensor.
DimsExpr outputDims(commonRank);
IndexExpr cumulativeAxisSize = LiteralIndexExpr(0);
SmallVector<bool, 4> isConstant(commonRank, false);
for (unsigned i = 0; i < numInputs; ++i) {
Value currentInput = operandAdaptor.inputs()[i];
MemRefBoundsIndexCapture currInputBounds(currentInput);
DimIndexExpr currentSize(currInputBounds.getDim(axisIndex));
cumulativeAxisSize = cumulativeAxisSize + currentSize;
for (unsigned dim = 0; dim < commonRank; dim++) {
if (dim == axisIndex) {
DimIndexExpr currentSize(currInputBounds.getDim(axisIndex));
cumulativeAxisSize = cumulativeAxisSize + currentSize;
} else {
if (!isConstant[dim]) {
if (currInputBounds.getDim(dim).isLiteral()) {
// The size of the current dimension of the current input is a constant.
outputDims[dim] = currInputBounds.getDim(dim);
isConstant[dim] = true;
} else if (i == numInputs - 1) {
// If no constant dimension found for all the inputs, use the
// dynamic size of the last input.
outputDims[dim] = currInputBounds.getDim(dim);
}
}
}
}
Collaborator:

Is it easier to initialize outputDims with the first input's dims? Then accumulate the dim along the axis and replace the other dims if there is a constant. That way, we don't need isConstant.

Collaborator (Author):

Updated.

}

DimsExpr outputDims(commonRank);
MemRefBoundsIndexCapture firstInputBounds(firstInput);
for (unsigned i = 0; i < commonRank; i++)
outputDims[i] =
(i == axisIndex) ? cumulativeAxisSize : firstInputBounds.getDim(i);
outputDims[axisIndex] = cumulativeAxisSize;

setOutputDims(outputDims);
return success();
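The suggestion above was adopted in a later commit that falls outside this 4-commit view; it amounts to seeding outputDims with the first input's dims, accumulating the size along the axis, and overwriting a non-axis dim only when some input supplies a constant size, which makes the isConstant vector unnecessary. A hedged sketch of that rule in plain C++, with -1 standing in for a dynamic dimension (concatOutputShape and the -1 convention are illustrative assumptions, not onnx-mlir APIs):

// Sketch of the suggested shape rule, not the actual onnx-mlir code.
// -1 marks a dynamic dimension; real IndexExpr-based code would keep a
// symbolic expression for the axis size instead of collapsing it to -1.
#include <cstdint>
#include <vector>

std::vector<int64_t> concatOutputShape(
    const std::vector<std::vector<int64_t>> &inputShapes, size_t axis) {
  std::vector<int64_t> outputDims = inputShapes[0]; // seed with first input
  for (size_t i = 1; i < inputShapes.size(); ++i) {
    for (size_t d = 0; d < outputDims.size(); ++d) {
      if (d == axis) {
        // Axis dim accumulates; any dynamic operand makes it dynamic.
        if (outputDims[d] == -1 || inputShapes[i][d] == -1)
          outputDims[d] = -1;
        else
          outputDims[d] += inputShapes[i][d];
      } else if (outputDims[d] == -1 && inputShapes[i][d] != -1) {
        // Non-axis dim: prefer a constant size when any input provides one.
        outputDims[d] = inputShapes[i][d];
      }
    }
  }
  return outputDims;
}

For example, concatOutputShape({{2, -1, 4}, {3, 5, 4}}, 0) returns {5, 5, 4}: the axis dim accumulates 2 + 3, and the dynamic middle dim is filled by the constant 5 taken from the second input.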