From f32ce3b2ef2c72476476f66380fa5e174816a454 Mon Sep 17 00:00:00 2001 From: Tiago Trevisan Jost Date: Wed, 7 Jun 2023 17:48:49 +0000 Subject: [PATCH] [TOSA] Fix type for non-FP32 bias. --- e2e_testing/xfail_sets.py | 2 ++ lib/Conversion/TorchToTosa/TorchToTosa.cpp | 8 ++++++ python/torch_mlir_e2e_test/test_suite/conv.py | 28 +++++++++++++++++++ 3 files changed, 38 insertions(+) diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 2fcd790fd150..6fe68bdd5d02 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -607,6 +607,7 @@ "Conv1dNoPaddingModule_basic", "Conv2dWithPaddingDilationStrideStaticModule_basic", "Convolution2DStaticModule_basic", + "Convolution2DBFloat16_basic", "ConvolutionModule2DTransposeStridedStatic_basic", "ElementwiseCloneContiguousModule_basic", "ElementwiseCloneChannelsLastMemoryFormatModule_basic", @@ -1004,6 +1005,7 @@ "ElementwiseNeIntScalarModule_basic", "ElementwiseNeFloatTensorModule_basic", "Convolution2DStaticModule_basic", + "Convolution2DBFloat16_basic", "ElementwiseNegModule_basic", "TestMultipleTensorReturn_basic", "TypeAsSameModule_basic", diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index aff9b637273c..e3c4b3d85d1f 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -1927,6 +1927,14 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite( bias = tosa::getConstTensor<float>(rewriter, op, zeroVec, {static_cast<int32_t>(weightShape[0])}) .value(); + + // If bias ElementType is different from inputElemTy, create a cast to it.
+ auto biasTy = bias.getType().cast<RankedTensorType>(); + if (biasTy.getElementType() != inputElemTy) { + bias = rewriter.create<tosa::CastOp>( + bias.getLoc(), + RankedTensorType::get(biasTy.getShape(), inputElemTy), bias); + } } } else { if (!bias.getType().cast<RankedTensorType>()) diff --git a/python/torch_mlir_e2e_test/test_suite/conv.py b/python/torch_mlir_e2e_test/test_suite/conv.py index c20916c14982..8be6dd5f3c09 100644 --- a/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/python/torch_mlir_e2e_test/test_suite/conv.py @@ -99,6 +99,34 @@ def Conv2dBiasNoPaddingModule_basic(module, tu: TestUtils): t = tu.rand(5, 2, 10, 20) module.forward(t) +# ============================================================================== + +class Convolution2DBFloat16(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 32, 4, 4], torch.bfloat16, True), + ([32, 32, 3, 3], torch.bfloat16, True), + ]) + def forward(self, x, weight): + return torch.ops.aten.convolution(x, + weight, + bias=None, + stride=[3, 3], + padding=[2, 2], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=1) + +@register_test_case(module_factory=lambda: Convolution2DBFloat16()) +def Convolution2DBFloat16_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 32, 4, 4).to(torch.bfloat16), tu.rand(32, 32, 3, 3).to(torch.bfloat16)) + +# ============================================================================== class Conv2dWithPaddingModule(torch.nn.Module):