Skip to content

Commit

Permalink
feat: adds folding of -1 dims for the torch.aten.broadcast_to operation. (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
ttjost authored May 24, 2023
1 parent a950b6a commit 827a091
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 0 deletions.
2 changes: 2 additions & 0 deletions e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@
"BroadcastToModule_basic",
"BroadcastToSameRankStaticModule_basic",
"BroadcastZeroRankInputStaticModule_basic",
"BroadcastListConstructWithMinusOneModule_basic",
"BucketizeTensorStaticFloatModule_basic",
"BucketizeTensorStaticModule_basic",
"CumsumStaticModule_basic",
Expand Down Expand Up @@ -970,6 +971,7 @@
"ReduceSumUnsignedIntModule_basic",
"BroadcastToSameRankStaticModule_basic",
"BroadcastZeroRankInputStaticModule_basic",
"BroadcastListConstructWithMinusOneModule_basic",
"SliceStaticModule_basic",
"ArangeStartStepIntModule_basic",
"ArangeDtypeFloatModule_basic",
Expand Down
7 changes: 7 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3274,6 +3274,13 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(

// Torch-compatible copy of the input operand's static shape (dynamic dims
// mapped to Torch's sentinel by makeShapeTorchCompatible).
SmallVector<int64_t> inputShape(
    makeShapeTorchCompatible(selfType.getShape()));
// Result dimension -1 means not changing the size of that dimension.
// Adjust it by assigning its inputShape.
// NOTE(review): inputShape was already converted above, so the second
// makeShapeTorchCompatible here looks redundant (idempotent but noisy) —
// confirm and drop the inner call.
// NOTE(review): this indexes resultShape with the input's dim index, which
// assumes inputShape and resultShape have equal rank; broadcast_to normally
// allows result rank > input rank (trailing-aligned). Verify an earlier
// check enforces equal ranks, otherwise -1 dims would be filled from the
// wrong input dimension.
for (auto shape : llvm::enumerate(makeShapeTorchCompatible(inputShape))) {
  auto index = shape.index();
  if (resultShape[index] == -1)
    resultShape[index] = shape.value();
}
// Check for identity case i.e, for ex: [a, b, c] -> [a, b, c]. If this is
// true then we can replace the op result with the input operand directly.
if (llvm::equal(inputShape, resultShape)) {
Expand Down
21 changes: 21 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1332,6 +1332,27 @@ def BroadcastZeroRankInputStaticModule_basic(module, tu: TestUtils):

# ==============================================================================

class BroadcastListConstructWithMinusOneModule(torch.nn.Module):
    """E2E test: broadcast_to with a target shape of all -1 entries.

    In aten.broadcast_to, a -1 in the target shape means "keep this
    dimension's size", so [-1, -1, -1] on a [3, 1, 8] tensor is an
    identity broadcast. The subsequent subtraction forces the broadcast
    result to actually be materialized/used by the backend.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 1, 8], torch.float32, True),
        ([3, 1, 8], torch.float32, True),
    ])
    def forward(self, x, y):
        # The literal list lowers to a ListConstruct op containing -1s,
        # which is the pattern this test exists to exercise.
        y = torch.broadcast_to(y, [-1, -1, -1])
        return torch.ops.aten.sub(x, y)


@register_test_case(module_factory=lambda: BroadcastListConstructWithMinusOneModule())
def BroadcastListConstructWithMinusOneModule_basic(module, tu: TestUtils):
    # Inputs match the static shapes declared in annotate_args ([3, 1, 8]).
    module.forward(tu.rand(3, 1, 8), tu.rand(3, 1, 8))

# ==============================================================================

class RollModule(torch.nn.Module):

Expand Down

0 comments on commit 827a091

Please sign in to comment.