Add non-unit groups to aten::convolution #858

Merged: 1 commit, Aug 4, 2022
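
For context: with groups=g, aten::convolution splits the C input channels into g independent groups, so the weight has shape (F, C/g, kH, kW) and both C and F must be divisible by g. A minimal PyTorch sketch of the kind of call this PR lowers, reusing the shapes from the new e2e test below (illustrative only, not part of the diff):

    import torch

    x = torch.randn(1, 32, 4, 4)   # N, C, H, W
    w = torch.randn(32, 8, 3, 3)   # F, C/groups, kH, kW
    out = torch.ops.aten.convolution(
        x, w, bias=None,
        stride=[3, 3], padding=[2, 2], dilation=[1, 1],
        transposed=False, output_padding=[0, 0], groups=4)
    print(out.shape)  # torch.Size([1, 32, 2, 2])
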
204 changes: 170 additions & 34 deletions lib/Conversion/TorchToLinalg/Linear.cpp
@@ -663,33 +663,44 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
return rewriter.notifyMatchFailure(op,
"only support constant int dilations");

- Value N = getDimOp(rewriter, loc, input, 0);
Value inBatch = getDimOp(rewriter, loc, input, 0);
Value inChannels = getDimOp(rewriter, loc, input, 1);
SmallVector<Value> inDims;
for (size_t i = 2; i < inRank; i++)
inDims.push_back(getDimOp(rewriter, loc, input, i));
- Value F = getDimOp(rewriter, loc, weight, 0);
Value weightBatch = getDimOp(rewriter, loc, weight, 0);
Value weightChannels = getDimOp(rewriter, loc, weight, 1);
SmallVector<Value> weightDims;
for (size_t i = 2; i < inRank; i++)
weightDims.push_back(getDimOp(rewriter, loc, weight, i));

- // Guard unused values (transposed, groups)
- int64_t group_size;
- if (!matchPattern(op.groups(), m_TorchConstantInt(&group_size)) ||
- group_size != 1)
- return rewriter.notifyMatchFailure(
- op, "unimplemented: only group size of 1 supported");
// Guard unused values (transposed)
bool transposed = true;
if (!matchPattern(op.transposed(), m_TorchConstantBool(&transposed)) ||
transposed)
return rewriter.notifyMatchFailure(
op, "unimplemented: only non-transposed convolution supported");

- // Pad the input tensor according to padding.
- SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
- paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
- paddingInts.end());
- Value paddedInput = torch_to_linalg::getZeroPaddedTensor(
- op, rewriter, input, paddingIncludingNC);
// Checks for valid group size
int64_t groupSize;
if (!matchPattern(op.groups(), m_TorchConstantInt(&groupSize)))
return rewriter.notifyMatchFailure(op,
"only constant group size supported.");
Value groups = castIntToIndex(rewriter, loc, adaptor.groups());

auto validate = [&](Value toValidate, std::string err) {
Value c0 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
Value inputValid = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, c0,
rewriter.create<arith::RemSIOp>(loc, toValidate, groups));
rewriter.create<cf::AssertOp>(loc, inputValid,
rewriter.getStringAttr(err));
};
validate(inChannels,
"invalid: groups must divide input channel size evenly.");
validate(weightBatch,
"invalid: groups must divide weight batch size evenly.");

SmallVector<Value> paddingIntValues =
getAsConstantIntValues(rewriter, loc, paddingInts);
@@ -698,7 +709,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
SmallVector<Value> strideIntValues =
getAsConstantIntValues(rewriter, loc, strideInts);

- SmallVector<Value> outDims{N, F};
SmallVector<Value> outDims{inBatch, weightBatch};
for (size_t i = 0; i < inRank - 2; i++)
outDims.push_back(torch_to_linalg::getOutputDimForConvOps(
rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i],
@@ -708,12 +719,12 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
rewriter.create<linalg::InitTensorOp>(loc, outDims, elementType);

Value bias = adaptor.bias();
- Value biasInitTensor;
Value outputTensor;
if (bias.getType().isa<Torch::NoneType>()) {
Value c0float = rewriter.create<arith::ConstantOp>(
loc, FloatAttr::get(elementType, 0.0));
- biasInitTensor = rewriter.create<linalg::FillOp>(loc, c0float, initTensor)
- .getResult(0);
outputTensor = rewriter.create<linalg::FillOp>(loc, c0float, initTensor)
.getResult(0);
} else {
auto biasType = bias.getType().cast<RankedTensorType>();
if (biasType.getRank() != 1)
@@ -727,27 +738,152 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0,
rewriter.getAffineDimExpr(1), context),
rewriter.getMultiDimIdentityMap(resultRank)};
SmallVector<StringRef> iteratorTypes(resultRank, "parallel");
biasInitTensor = rewriter
.create<linalg::GenericOp>(
loc, initTensor.getType(), bias, initTensor,
indexingMaps, iteratorTypes,
[](OpBuilder &b, Location loc, ValueRange args) {
b.create<linalg::YieldOp>(loc, args[0]);
})
.getResult(0);
SmallVector<StringRef> iteratorTypes(resultRank,
getParallelIteratorTypeName());
outputTensor = rewriter
.create<linalg::GenericOp>(
loc, initTensor.getType(), bias, initTensor,
indexingMaps, iteratorTypes,
[](OpBuilder &b, Location loc, ValueRange args) {
b.create<linalg::YieldOp>(loc, args[0]);
})
.getResult(0);
}

auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);

- // TODO: add 1D and 3D case
- Value conv =
- rewriter
- .create<linalg::Conv2DNchwFchwOp>(
- loc, biasInitTensor.getType(), ValueRange{paddedInput, weight},
- biasInitTensor, stridesAttr, dilationAttr)
- .getResult(0);
Value inputStride =
rewriter.create<arith::FloorDivSIOp>(loc, inChannels, groups);
Value weightStride =
rewriter.create<arith::FloorDivSIOp>(loc, weightBatch, groups);

SmallVector<Value> zeroOffsets(inRank, rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(0)));
SmallVector<Value> unitStrides(inRank, rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(1)));
SmallVector<Value> outDimSlice(outDims);
outDimSlice[1] = weightStride;
SmallVector<Value> inputSliceSizes{inBatch, inputStride};
inputSliceSizes.append(inDims);
SmallVector<Value> weightSliceSizes{weightStride, weightChannels};
weightSliceSizes.append(weightDims);

// Pad the input tensor according to padding.
SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
paddingIncludingNC.append(paddingInts);

// Pad inputSlice
Value paddedInput = torch_to_linalg::getZeroPaddedTensor(
op, rewriter, input, paddingIncludingNC);

Value conv;
if (groupSize == 1) {
// TODO: add 1D and 3D case
conv =
rewriter
.create<linalg::Conv2DNchwFchwOp>(
loc, outputTensor.getType(), ValueRange{paddedInput, weight},
outputTensor, stridesAttr, dilationAttr)
.getResult(0);
} else {
// Special depthwise case
auto inShape = input.getType().cast<RankedTensorType>().getShape();
auto weightShape = weight.getType().cast<RankedTensorType>().getShape();
if (weightShape[0] != kUnknownSize && inShape[1] == groupSize &&
weightShape[0] % inShape[1] == 0 && weightShape[1] == 1) {
// Collapse weight shape
SmallVector<ReassociationIndices, 4> collapsedDims = {{0, 1}, {2}, {3}};
SmallVector<int64_t> collapsedShape{
(weightShape[0] == kUnknownSize ? kUnknownSize
: weightShape[0] * weightShape[1]),
weightShape[2], weightShape[3]};
Type collapsedType = RankedTensorType::get(collapsedShape, elementType);
Value collapsedWeight = rewriter.create<tensor::CollapseShapeOp>(
loc, collapsedType, weight, collapsedDims);

conv = rewriter
.create<linalg::DepthwiseConv2DNchwChwOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, collapsedWeight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);

Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
return success();
}

// Grouped case, use the grouped conv linalg op

auto expandGroups = [&](Value tensor, size_t dim) {
auto inType = tensor.getType().cast<RankedTensorType>();
auto inShape = inType.getShape();

SmallVector<int64_t> outShape;
for (auto i = 0; i < (long)inShape.size(); i++) {
if (i == 1) {
outShape.push_back(groupSize);
}
if (i == (long)dim) {
outShape.push_back(inShape[i] == kUnknownSize
? kUnknownSize
: inShape[i] / groupSize);
} else {
outShape.push_back(inShape[i]);
}
}

SmallVector<ReassociationIndices> indices;
for (auto i = 0; i <= (long)inShape.size(); i++) {
if (i == (long)dim) {
indices.push_back({i, ++i});
continue;
}
indices.push_back({i});
}

auto retType = inType.clone(outShape);
return rewriter.create<tensor::ExpandShapeOp>(loc, retType, tensor,
indices);
};

auto expandWeight = [&](Value tensor) {
auto inType = tensor.getType().cast<RankedTensorType>();
auto inShape = inType.getShape();

SmallVector<int64_t> outShape{
groupSize, (inShape[0] == kUnknownSize ? kUnknownSize
: inShape[0] / groupSize)};
outShape.append(inShape.begin() + 1, inShape.end());

SmallVector<ReassociationIndices> indices{{0, 1}};
for (auto i = 2; i <= (long)inShape.size(); i++)
indices.push_back({i});

auto retType = inType.clone(outShape);
return rewriter.create<tensor::ExpandShapeOp>(loc, retType, tensor,
indices);
};

Value paddedInputExpanded = expandGroups(paddedInput, 1);
Value weightExpanded = expandWeight(weight);
Value outputTensorExpanded = expandGroups(outputTensor, 1);

// TODO: add 1D and 3D case
conv = rewriter
.create<linalg::Conv2DNgchwFgchwOp>(
loc, outputTensorExpanded.getType(),
ValueRange{paddedInputExpanded, weightExpanded},
outputTensorExpanded, stridesAttr, dilationAttr)
.getResult(0);

SmallVector<ReassociationIndices> indices{{0}, {1, 2}};
for (auto dim = 3; dim <= (int64_t)inRank; dim++)
indices.push_back({dim});
conv = rewriter.create<tensor::CollapseShapeOp>(
loc, outputTensor.getType(), conv, indices);
}

Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
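
The grouped path above makes the group dimension explicit before handing off to linalg: the padded input and the output/init tensor are expanded from NCHW to (N, G, C/G, H, W), the weight from (F, C/G, kH, kW) to (G, F/G, C/G, kH, kW), linalg.conv_2d_ngchw_fgchw runs on the expanded operands, and the result is collapsed back to NCHW. The identity this relies on is that a grouped convolution is just G independent convolutions over matching channel slices; a short PyTorch sketch of that equivalence (illustrative shapes, not part of the diff):

    import torch
    import torch.nn.functional as F

    g = 4
    x = torch.randn(1, 32, 8, 8)         # N, C, H, W
    w = torch.randn(16, 32 // g, 3, 3)   # F, C/g, kH, kW

    grouped = F.conv2d(x, w, groups=g)

    # Same result: one ordinary convolution per group over matching slices.
    xs = x.chunk(g, dim=1)               # each (1, C/g, H, W)
    ws = w.chunk(g, dim=0)               # each (F/g, C/g, kH, kW)
    per_group = torch.cat([F.conv2d(xi, wi) for xi, wi in zip(xs, ws)], dim=1)

    assert torch.allclose(grouped, per_group, atol=1e-5)
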
2 changes: 0 additions & 2 deletions python/torch_mlir_e2e_test/test_suite/__init__.py
@@ -9,8 +9,6 @@
COMMON_TORCH_MLIR_LOWERING_XFAILS = {
"QuantizedMLP_basic",
"TableBatchEmbeddingModule_basic",
"MobilenetV2Module_basic",
"MobilenetV3Module_basic",
"Convolution3DModule_basic",
"Convolution1DModule_basic",
"MaxPool2dWith3dInputModule_basic",
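
MobilenetV2Module_basic and MobilenetV3Module_basic come off the common XFAIL list because those models rely heavily on depthwise convolutions (groups == in_channels), which fall under the depthwise special case above and lower to linalg.depthwise_conv_2d_nchw_chw. A minimal sketch of such a call, assuming a typical MobileNet-style block shape (illustrative only):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 32, 8, 8)
    w = torch.randn(32, 1, 3, 3)   # (F, 1, kH, kW): one filter per input channel
    y = F.conv2d(x, w, padding=1, groups=32)
    print(y.shape)  # torch.Size([1, 32, 8, 8])
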
25 changes: 25 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/conv.py
@@ -405,3 +405,28 @@ def forward(self, inputVec, weight):
@register_test_case(module_factory=lambda: _Convolution2DTF32Module())
def _Convolution2DTF32Module_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))

class ConvolutionModule2DGroups(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten.convolution(inputVec,
weight,
bias=None,
stride=[3, 3],
padding=[2, 2],
dilation=[1, 1],
transposed=False,
output_padding=[0, 0],
groups=4)

@register_test_case(module_factory=lambda: ConvolutionModule2DGroups())
def ConvolutionModule2DGroups_basic(module, tu: TestUtils):
module.forward(torch.randn(1, 32, 4, 4), torch.randn(32, 8, 3, 3))
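
As a sanity check on this test's shapes: with input 4x4, kernel 3x3, stride 3, padding 2, and dilation 1, the standard convolution output-size formula (mirroring what getOutputDimForConvOps computes on the linalg side) gives a 2x2 spatial output, so forward() returns a (1, 32, 2, 2) tensor. A tiny sketch of that arithmetic (illustrative only):

    def conv_out_size(in_size, pad, dilation, kernel, stride):
        # floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1
        return (in_size + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1

    print(conv_out_size(4, 2, 1, 3, 3))  # 2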