Skip to content

Commit

Permalink
Add non-unit groups to aten::convolution
Browse files Browse the repository at this point in the history
  • Loading branch information
gpetters94 committed Aug 2, 2022
1 parent 76c9766 commit 8215edc
Show file tree
Hide file tree
Showing 3 changed files with 197 additions and 38 deletions.
204 changes: 170 additions & 34 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -663,33 +663,44 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
return rewriter.notifyMatchFailure(op,
"only support constant int dilations");

Value N = getDimOp(rewriter, loc, input, 0);
Value inBatch = getDimOp(rewriter, loc, input, 0);
Value inChannels = getDimOp(rewriter, loc, input, 1);
SmallVector<Value> inDims;
for (size_t i = 2; i < inRank; i++)
inDims.push_back(getDimOp(rewriter, loc, input, i));
Value F = getDimOp(rewriter, loc, weight, 0);
Value weightBatch = getDimOp(rewriter, loc, weight, 0);
Value weightChannels = getDimOp(rewriter, loc, weight, 1);
SmallVector<Value> weightDims;
for (size_t i = 2; i < inRank; i++)
weightDims.push_back(getDimOp(rewriter, loc, weight, i));

// Guard unused values (transposed, groups)
int64_t group_size;
if (!matchPattern(op.groups(), m_TorchConstantInt(&group_size)) ||
group_size != 1)
return rewriter.notifyMatchFailure(
op, "unimplemented: only group size of 1 supported");
// Guard unused values (transposed)
bool transposed = true;
if (!matchPattern(op.transposed(), m_TorchConstantBool(&transposed)) ||
transposed)
return rewriter.notifyMatchFailure(
op, "unimplemented: only non-transposed convolution supported");

// Pad the input tensor according to padding.
SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
paddingInts.end());
Value paddedInput = torch_to_linalg::getZeroPaddedTensor(
op, rewriter, input, paddingIncludingNC);
// Checks for valid group size
int64_t groupSize;
if (!matchPattern(op.groups(), m_TorchConstantInt(&groupSize)))
return rewriter.notifyMatchFailure(op,
"only constant group size supported.");
Value groups = castIntToIndex(rewriter, loc, adaptor.groups());

auto validate = [&](Value toValidate, std::string err) {
Value c0 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
Value inputValid = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, c0,
rewriter.create<arith::RemSIOp>(loc, toValidate, groups));
rewriter.create<cf::AssertOp>(loc, inputValid,
rewriter.getStringAttr(err));
};
validate(inChannels,
"invalid: groups must divide input channel size evenly.");
validate(weightBatch,
"invalid: groups must divide weight batch size evenly.");

SmallVector<Value> paddingIntValues =
getAsConstantIntValues(rewriter, loc, paddingInts);
Expand All @@ -698,7 +709,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
SmallVector<Value> strideIntValues =
getAsConstantIntValues(rewriter, loc, strideInts);

SmallVector<Value> outDims{N, F};
SmallVector<Value> outDims{inBatch, weightBatch};
for (size_t i = 0; i < inRank - 2; i++)
outDims.push_back(torch_to_linalg::getOutputDimForConvOps(
rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i],
Expand All @@ -708,12 +719,12 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
rewriter.create<linalg::InitTensorOp>(loc, outDims, elementType);

Value bias = adaptor.bias();
Value biasInitTensor;
Value outputTensor;
if (bias.getType().isa<Torch::NoneType>()) {
Value c0float = rewriter.create<arith::ConstantOp>(
loc, FloatAttr::get(elementType, 0.0));
biasInitTensor = rewriter.create<linalg::FillOp>(loc, c0float, initTensor)
.getResult(0);
outputTensor = rewriter.create<linalg::FillOp>(loc, c0float, initTensor)
.getResult(0);
} else {
auto biasType = bias.getType().cast<RankedTensorType>();
if (biasType.getRank() != 1)
Expand All @@ -727,27 +738,152 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0,
rewriter.getAffineDimExpr(1), context),
rewriter.getMultiDimIdentityMap(resultRank)};
SmallVector<StringRef> iteratorTypes(resultRank, "parallel");
biasInitTensor = rewriter
.create<linalg::GenericOp>(
loc, initTensor.getType(), bias, initTensor,
indexingMaps, iteratorTypes,
[](OpBuilder &b, Location loc, ValueRange args) {
b.create<linalg::YieldOp>(loc, args[0]);
})
.getResult(0);
SmallVector<StringRef> iteratorTypes(resultRank,
getParallelIteratorTypeName());
outputTensor = rewriter
.create<linalg::GenericOp>(
loc, initTensor.getType(), bias, initTensor,
indexingMaps, iteratorTypes,
[](OpBuilder &b, Location loc, ValueRange args) {
b.create<linalg::YieldOp>(loc, args[0]);
})
.getResult(0);
}

auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);

// TODO: add 1D and 3D case
Value conv =
rewriter
.create<linalg::Conv2DNchwFchwOp>(
loc, biasInitTensor.getType(), ValueRange{paddedInput, weight},
biasInitTensor, stridesAttr, dilationAttr)
.getResult(0);
Value inputStride =
rewriter.create<arith::FloorDivSIOp>(loc, inChannels, groups);
Value weightStride =
rewriter.create<arith::FloorDivSIOp>(loc, weightBatch, groups);

SmallVector<Value> zeroOffsets(inRank, rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(0)));
SmallVector<Value> unitStrides(inRank, rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(1)));
SmallVector<Value> outDimSlice(outDims);
outDimSlice[1] = weightStride;
SmallVector<Value> inputSliceSizes{inBatch, inputStride};
inputSliceSizes.append(inDims);
SmallVector<Value> weightSliceSizes{weightStride, weightChannels};
weightSliceSizes.append(weightDims);

// Pad the input tensor according to padding.
SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
paddingIncludingNC.append(paddingInts);

// Pad inputSlice
Value paddedInput = torch_to_linalg::getZeroPaddedTensor(
op, rewriter, input, paddingIncludingNC);

Value conv;
if (groupSize == 1) {
// TODO: add 1D and 3D case
conv =
rewriter
.create<linalg::Conv2DNchwFchwOp>(
loc, outputTensor.getType(), ValueRange{paddedInput, weight},
outputTensor, stridesAttr, dilationAttr)
.getResult(0);
} else {
// Special depthwise case
auto inShape = input.getType().cast<RankedTensorType>().getShape();
auto weightShape = weight.getType().cast<RankedTensorType>().getShape();
if (weightShape[0] != kUnknownSize && inShape[1] == groupSize &&
weightShape[0] % inShape[1] == 0 && weightShape[1] == 1) {
// Collapse weight shape
SmallVector<ReassociationIndices, 4> collapsedDims = {{0, 1}, {2}, {3}};
SmallVector<int64_t> collapsedShape{
(weightShape[0] == kUnknownSize ? kUnknownSize
: weightShape[0] * weightShape[1]),
weightShape[2], weightShape[3]};
Type collapsedType = RankedTensorType::get(collapsedShape, elementType);
Value collapsedWeight = rewriter.create<tensor::CollapseShapeOp>(
loc, collapsedType, weight, collapsedDims);

conv = rewriter
.create<linalg::DepthwiseConv2DNchwChwOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, collapsedWeight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);

Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
return success();
}

// Grouped case, use the grouped conv linalg op

// Splits dimension `dim` of `tensor` into two adjacent dimensions
// (groupSize, size / groupSize) with a tensor.expand_shape op and returns
// the expanded value. A dimension whose static size is kUnknownSize stays
// kUnknownSize after the split. All other dimensions are passed through.
//
// The result shape and the reassociation are built in the same pass so that
// the inserted group dimension always sits at the position named by `dim`
// (previously the shape hard-coded position 1 while the reassociation used
// `dim`, which only agreed for dim == 1).
auto expandGroups = [&](Value tensor, size_t dim) {
  auto inType = tensor.getType().cast<RankedTensorType>();
  auto inShape = inType.getShape();
  const int64_t srcRank = static_cast<int64_t>(inShape.size());
  const int64_t splitDim = static_cast<int64_t>(dim);

  SmallVector<int64_t> outShape;
  SmallVector<ReassociationIndices> indices;
  // `outIdx` tracks the next dimension index in the expanded shape.
  int64_t outIdx = 0;
  for (int64_t i = 0; i < srcRank; ++i) {
    if (i == splitDim) {
      // Source dim `i` maps onto the pair (groups, per-group extent).
      outShape.push_back(groupSize);
      outShape.push_back(inShape[i] == kUnknownSize
                             ? kUnknownSize
                             : inShape[i] / groupSize);
      indices.push_back({outIdx, outIdx + 1});
      outIdx += 2;
    } else {
      outShape.push_back(inShape[i]);
      indices.push_back({outIdx});
      ++outIdx;
    }
  }

  auto retType = inType.clone(outShape);
  return rewriter.create<tensor::ExpandShapeOp>(loc, retType, tensor,
                                                indices);
};

auto expandWeight = [&](Value tensor) {
auto inType = tensor.getType().cast<RankedTensorType>();
auto inShape = inType.getShape();

SmallVector<int64_t> outShape{
groupSize, (inShape[0] == kUnknownSize ? kUnknownSize
: inShape[0] / groupSize)};
outShape.append(inShape.begin() + 1, inShape.end());

SmallVector<ReassociationIndices> indices{{0, 1}};
for (auto i = 2; i <= (long)inShape.size(); i++)
indices.push_back({i});

auto retType = inType.clone(outShape);
return rewriter.create<tensor::ExpandShapeOp>(loc, retType, tensor,
indices);
};

Value paddedInputExpanded = expandGroups(paddedInput, 1);
Value weightExpanded = expandWeight(weight);
Value outputTensorExpanded = expandGroups(outputTensor, 1);

// TODO: add 1D and 3D case
conv = rewriter
.create<linalg::Conv2DNgchwFgchwOp>(
loc, outputTensorExpanded.getType(),
ValueRange{paddedInputExpanded, weightExpanded},
outputTensorExpanded, stridesAttr, dilationAttr)
.getResult(0);

SmallVector<ReassociationIndices> indices{{0}, {1, 2}};
for (auto dim = 3; dim <= (int64_t)inRank; dim++)
indices.push_back({dim});
conv = rewriter.create<tensor::CollapseShapeOp>(
loc, outputTensor.getType(), conv, indices);
}

Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
Expand Down
6 changes: 2 additions & 4 deletions python/torch_mlir_e2e_test/test_suite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
# Set of e2e test names. Judging by the identifier, these are presumably tests
# expected to fail in the lowering shared by all backends -- TODO confirm
# against the code that consumes this set.
COMMON_TORCH_MLIR_LOWERING_XFAILS = {
"QuantizedMLP_basic",
"TableBatchEmbeddingModule_basic",
"MobilenetV2Module_basic",
"MobilenetV3Module_basic",
"Convolution3DModule_basic",
"Convolution1DModule_basic",
"ConvolutionModule3D_basic",
"ConvolutionModule1D_basic",
"MaxPool2dWith3dInputModule_basic",
"MaxPool2dWithIndicesWith3dInputModule_basic",
}
Expand Down
25 changes: 25 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,3 +405,28 @@ def forward(self, inputVec, weight):
@register_test_case(module_factory=lambda: _Convolution2DTF32Module())
def _Convolution2DTF32Module_basic(module, tu: TestUtils):
    """Run `module.forward` on a random (3, 3, 10, 10) input and (3, 3, 2, 2) weight."""
    input_tensor = torch.randn(3, 3, 10, 10)
    kernel = torch.randn(3, 3, 2, 2)
    module.forward(input_tensor, kernel)

class ConvolutionModule2DGroups(torch.nn.Module):
    """Test module that calls `aten.convolution` with `groups=4` and no bias."""

    def __init__(self):
        super().__init__()

    # Both tensor arguments are annotated as rank-4 float32 with every
    # dimension given as -1 (presumably "dynamic size" -- confirm against
    # annotate_args).
    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1], torch.float32, True),
        ([-1, -1, -1, -1], torch.float32, True),
    ])
    def forward(self, inputVec, weight):
        """Apply a non-transposed grouped convolution of `inputVec` with `weight`.

        Uses stride 3, padding 2 and dilation 1 in both spatial dimensions,
        zero output padding, no bias and 4 groups.
        """
        return torch.ops.aten.convolution(
            inputVec,
            weight,
            bias=None,
            stride=[3, 3],
            padding=[2, 2],
            dilation=[1, 1],
            transposed=False,
            output_padding=[0, 0],
            groups=4,
        )

@register_test_case(module_factory=lambda: ConvolutionModule2DGroups())
def ConvolutionModule2DGroups_basic(module, tu: TestUtils):
    """Run the grouped-convolution module on random data.

    The input is (1, 32, 4, 4) and the weight is (32, 8, 3, 3): 32 input
    channels split over 4 groups gives the 8 channels per group that the
    weight's second dimension carries.
    """
    input_tensor = torch.randn(1, 32, 4, 4)
    kernel = torch.randn(32, 8, 3, 3)
    module.forward(input_tensor, kernel)

0 comments on commit 8215edc

Please sign in to comment.