[TOSA] Fix type for non-FP32 bias.
ttjost committed Jun 7, 2023
1 parent 71aee8f commit f32ce3b
Showing 3 changed files with 38 additions and 0 deletions.
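The scenario behind the fix, sketched below, is a convolution without an explicit bias running on non-f32 (e.g. bfloat16) inputs: the TOSA lowering synthesizes a zero bias as f32, which then no longer matches the input element type. The snippet is illustrative only and not part of the change; it assumes the torch_mlir.compile entry point and "tosa" output type available around the time of this commit.

import torch
import torch_mlir  # assumed API: torch_mlir.compile(..., output_type="tosa")

class BiaslessConv(torch.nn.Module):
    def forward(self, x, weight):
        # No bias operand: the TOSA lowering must materialize one itself.
        return torch.ops.aten.convolution(
            x, weight, bias=None, stride=[1, 1], padding=[0, 0],
            dilation=[1, 1], transposed=False, output_padding=[0, 0], groups=1)

x = torch.rand(1, 32, 4, 4).to(torch.bfloat16)
w = torch.rand(32, 32, 3, 3).to(torch.bfloat16)
# Before this commit the synthesized zero bias stayed f32, so tosa.conv2d saw
# mismatched element types for bfloat16 inputs; with the fix it is cast to bf16.
compiled = torch_mlir.compile(BiaslessConv(), [x, w], output_type="tosa")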
2 changes: 2 additions & 0 deletions e2e_testing/xfail_sets.py
@@ -607,6 +607,7 @@
"Conv1dNoPaddingModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_basic",
"Convolution2DStaticModule_basic",
"Convolution2DBFloat16_basic",
"ConvolutionModule2DTransposeStridedStatic_basic",
"ElementwiseCloneContiguousModule_basic",
"ElementwiseCloneChannelsLastMemoryFormatModule_basic",
@@ -1004,6 +1005,7 @@
"ElementwiseNeIntScalarModule_basic",
"ElementwiseNeFloatTensorModule_basic",
"Convolution2DStaticModule_basic",
"Convolution2DBFloat16_basic",
"ElementwiseNegModule_basic",
"TestMultipleTensorReturn_basic",
"TypeAsSameModule_basic",
8 changes: 8 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -1927,6 +1927,14 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
bias = tosa::getConstTensor<float>(rewriter, op, zeroVec,
{static_cast<int32_t>(weightShape[0])})
.value();

// If bias ElementType is different from inputElemTy, create a cast to it.
auto biasTy = bias.getType().cast<RankedTensorType>();
if (biasTy.getElementType() != inputElemTy) {
bias = rewriter.create<tosa::CastOp>(
bias.getLoc(),
RankedTensorType::get(biasTy.getShape(), inputElemTy), bias);
}
}
} else {
if (!bias.getType().cast<RankedTensorType>())
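For readers skimming the hunk above: when the incoming op carries no bias, the lowering builds a zero bias constant as f32 and, with this change, casts it to the input element type before it reaches tosa.conv2d. A rough Python analogue of that decision follows; it is purely illustrative, and synth_bias is a made-up helper, not code from this repository.

import torch

def synth_bias(weight: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
    # Zero bias, one entry per output channel, created as f32 like the lowering does.
    bias = torch.zeros(weight.shape[0], dtype=torch.float32)
    # Mirrors the new tosa::CastOp: align the bias element type with the input's.
    if bias.dtype != input_dtype:
        bias = bias.to(input_dtype)
    return bias

synth_bias(torch.rand(32, 32, 3, 3), torch.bfloat16)  # 32 zeros in bf16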
28 changes: 28 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/conv.py
@@ -99,6 +99,34 @@ def Conv2dBiasNoPaddingModule_basic(module, tu: TestUtils):
t = tu.rand(5, 2, 10, 20)
module.forward(t)

# ==============================================================================

class Convolution2DBFloat16(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([1, 32, 4, 4], torch.bfloat16, True),
        ([32, 32, 3, 3], torch.bfloat16, True),
    ])
    def forward(self, x, weight):
        return torch.ops.aten.convolution(x,
                                          weight,
                                          bias=None,
                                          stride=[3, 3],
                                          padding=[2, 2],
                                          dilation=[1, 1],
                                          transposed=False,
                                          output_padding=[0, 0],
                                          groups=1)

@register_test_case(module_factory=lambda: Convolution2DBFloat16())
def Convolution2DBFloat16_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 32, 4, 4).to(torch.bfloat16),
                   tu.rand(32, 32, 3, 3).to(torch.bfloat16))
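To exercise only this case against the TOSA backend, the e2e runner can be filtered to it, e.g. something like: python -m e2e_testing.main --config=tosa --filter Convolution2DBFloat16. The exact flag names are quoted from memory and should be checked against e2e_testing/main.py.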

# ==============================================================================

class Conv2dWithPaddingModule(torch.nn.Module):

