diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index ccb876a05ce1..c2a2c8cd3349 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -747,9 +747,33 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
                                                     indices);
     };
 
+    auto expandGroupsOut = [&](Value tensor) {
+      auto inType = tensor.getType().cast<RankedTensorType>();
+      auto inShape = inType.getShape();
+
+      SmallVector<int64_t> outShape(inShape.begin(), inShape.end());
+      outShape.insert(outShape.begin() + 1,
+                      (inShape[1] == kUnknownSize ? kUnknownSize
+                                                  : inShape[1] / groupSize));
+      outShape[2] = groupSize;
+
+      SmallVector<ReassociationIndices> indices;
+      for (auto i = 0; i <= (long)inShape.size(); i++) {
+        if (i == 1) {
+          indices.push_back({i, ++i});
+          continue;
+        }
+        indices.push_back({i});
+      }
+
+      auto retType = inType.clone(outShape);
+      return rewriter.create<tensor::ExpandShapeOp>(loc, retType, tensor,
+                                                    indices);
+    };
+
     Value paddedInputExpanded = expandGroups(paddedInput, 1);
     Value weightExpanded = expandGroups(weight, 0);
-    Value outputTensorExpanded = expandGroups(outputTensor, 1);
+    Value outputTensorExpanded = expandGroupsOut(outputTensor);
 
     // TODO: add 1D and 3D case
     conv = rewriter
diff --git a/python/torch_mlir_e2e_test/test_suite/conv.py b/python/torch_mlir_e2e_test/test_suite/conv.py
index 832a4e3314e8..cc48b8c9ef75 100644
--- a/python/torch_mlir_e2e_test/test_suite/conv.py
+++ b/python/torch_mlir_e2e_test/test_suite/conv.py
@@ -268,20 +268,20 @@ def __init__(self):
     @export
     @annotate_args([
         None,
-        ([-1, -1, -1, -1], torch.float32, True),
-        ([-1, -1, -1, -1], torch.float32, True),
+        ([1, 32, 4, 4], torch.float32, True),
+        ([32, 4, 1, 1], torch.float32, True),
     ])
     def forward(self, inputVec, weight):
         return torch.ops.aten.convolution(inputVec,
                                           weight,
                                           bias=None,
-                                          stride=[3, 3],
-                                          padding=[2, 2],
+                                          stride=[1, 1],
+                                          padding=[1, 1],
                                           dilation=[1, 1],
                                           transposed=False,
                                           output_padding=[0, 0],
-                                          groups=3)
+                                          groups=8)
 
 @register_test_case(module_factory=lambda: ConvolutionModule2DGroups())
 def ConvolutionModule2DGroups_basic(module, tu: TestUtils):
-    module.forward(torch.randn(3, 9, 10, 10), torch.randn(3, 3, 2, 2))
+    module.forward(torch.ones(1, 32, 4, 4), torch.randn(32, 4, 1, 1))
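
Note (not part of the patch): the new expandGroupsOut lambda splits the output-channel dimension of the convolution result so each group gets its own slice of output channels. Below is a minimal PyTorch sketch of the equivalent reshape, using the sizes from the updated test; the variable names and values are illustrative only, not code from the repository.

import torch

# Illustrative sizes taken from the updated test: N=1, C_out=32, groups=8.
N, C_out, G, H, W = 1, 32, 8, 4, 4

output = torch.zeros(N, C_out, H, W)

# Same split as the reassociation [[0], [1, 2], [3], [4]] built by
# expandGroupsOut: dim 1 (C_out) is expanded into (C_out // G, G).
expanded = output.reshape(N, C_out // G, G, H, W)

print(expanded.shape)  # torch.Size([1, 4, 8, 4, 4])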