Skip to content

Commit

Permalink
Fix a failing case
Browse files Browse the repository at this point in the history
  • Loading branch information
gpetters94 committed Jun 29, 2022
1 parent b6b9805 commit 1f030db
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 7 deletions.
26 changes: 25 additions & 1 deletion lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -747,9 +747,33 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
indices);
};

auto expandGroupsOut = [&](Value tensor) {
auto inType = tensor.getType().cast<RankedTensorType>();
auto inShape = inType.getShape();

SmallVector<int64_t> outShape(inShape.begin(), inShape.end());
outShape.insert(outShape.begin() + 1,
(inShape[1] == kUnknownSize ? kUnknownSize
: inShape[1] / groupSize));
outShape[2] = groupSize;

SmallVector<ReassociationIndices> indices;
for (auto i = 0; i <= (long)inShape.size(); i++) {
if (i == 1) {
indices.push_back({i, ++i});
continue;
}
indices.push_back({i});
}

auto retType = inType.clone(outShape);
return rewriter.create<tensor::ExpandShapeOp>(loc, retType, tensor,
indices);
};

Value paddedInputExpanded = expandGroups(paddedInput, 1);
Value weightExpanded = expandGroups(weight, 0);
Value outputTensorExpanded = expandGroups(outputTensor, 1);
Value outputTensorExpanded = expandGroupsOut(outputTensor);

// TODO: add 1D and 3D case
conv = rewriter
Expand Down
12 changes: 6 additions & 6 deletions python/torch_mlir_e2e_test/test_suite/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,20 +268,20 @@ def __init__(self):
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1], torch.float32, True),
([1, 32, 4, 4], torch.float32, True),
([32, 4, 1, 1], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten.convolution(inputVec,
weight,
bias=None,
stride=[3, 3],
padding=[2, 2],
stride=[1, 1],
padding=[1, 1],
dilation=[1, 1],
transposed=False,
output_padding=[0, 0],
groups=3)
groups=8)

@register_test_case(module_factory=lambda: ConvolutionModule2DGroups())
def ConvolutionModule2DGroups_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 9, 10, 10), torch.randn(3, 3, 2, 2))
module.forward(torch.ones(1, 32, 4, 4), torch.randn(32, 4, 1, 1))

0 comments on commit 1f030db

Please sign in to comment.