Legalize aten.repeat_interleave.Tensor for torch-mlir #44

Merged · 10 commits · May 31, 2023
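For context, `aten::repeat_interleave.Tensor` is the overload that takes only a 1-D integer `repeats` tensor (there is no `self` input): it returns the indices 0..n-1, each repeated `repeats[i]` times, and the optional `output_size` hint must equal `repeats.sum()`. A quick eager-mode illustration, not part of the diff:

import torch

repeats = torch.tensor([2, 3, 1])
# Index i appears repeats[i] times; the result length is repeats.sum() == 6.
print(torch.repeat_interleave(repeats))
# tensor([0, 0, 1, 1, 1, 2])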
2 changes: 2 additions & 0 deletions e2e_testing/xfail_sets.py
@@ -670,6 +670,7 @@
"ReduceSumSignedIntModule_basic",
"ReduceSumUnsignedIntModule_basic",
"RepeatModule_basic",
"RepeatInterleaveModule_basic",
"ReshapeAliasCollapseModule_basic",
"ReshapeAliasExpandModule_basic",
"ReshapeExpandModule_basic",
Expand Down Expand Up @@ -1008,6 +1009,7 @@
"FullModuleFloat2D_basic",
"ElementwiseAbsModule_basic",
"RepeatModule_basic",
"RepeatInterleaveModule_basic",
"TensorsSplitTensorModule_basic",
"TensorsSplitTensorNegativeDimModule_basic",
"TensorsSplitTensorLastSmallerModule_basic",
Expand Down
24 changes: 24 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7401,6 +7401,30 @@ def Torch_AtenRepeatOp : Torch_Op<"aten.repeat", [
  }];
}

def Torch_AtenRepeatInterleaveTensorOp : Torch_Op<"aten.repeat_interleave.Tensor", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::repeat_interleave.Tensor : (Tensor, int?) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$repeats,
    AnyTorchOptionalIntType:$output_size
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenRepeatInterleaveTensorOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 2, 1);
    }
    void AtenRepeatInterleaveTensorOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 2, 1);
    }
  }];
}
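The `2, 1` arguments to `parseDefaultTorchOp`/`printDefaultTorchOp` are the operand and result counts taken from the JIT schema named in the summary. If in doubt, the registered schema can be inspected from Python; a small check, not part of the diff:

import torch

# Prints the registered signature, e.g.:
# aten::repeat_interleave.Tensor(Tensor repeats, *, SymInt? output_size=None) -> Tensor
print(torch.ops.aten.repeat_interleave.Tensor._schema)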

def Torch_AtenReshapeOp : Torch_Op<"aten.reshape", [
    AllowsTypeRefinement,
    ReadOnly
@@ -431,6 +431,10 @@ def aten〇repeat〡shape(self: List[int], repeats: List[int]) -> List[int]:
    for i in range(tensor_dim):
        out.append(self[i] * repeats[i + leading_rank])
    return out

def aten〇repeat_interleave〇Tensor〡shape(repeats: List[int], output_size: Optional[int] = None) -> List[int]:
    assert output_size is not None
    return [output_size]
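Note that shape functions only see the shape of `repeats`, never its values, so the true output length `sum(repeats)` is not computable here; shape inference for this op therefore relies on the caller supplying `output_size`. A small eager-mode check of that contract, not part of the diff:

import torch

repeats = torch.tensor([3, 1, 2, 4])
# output_size must equal repeats.sum(); the result is 1-D of that length.
out = torch.repeat_interleave(repeats, output_size=10)
assert out.shape == (10,)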

def aten〇roll〡shape(self: List[int], shifts: List[int], dims: List[int] = ()) -> List[int]:
    return upstream_shape_functions.unary(self)
@@ -1695,6 +1699,10 @@ def aten〇repeat〡dtype(self_rank_dtype: Tuple[int, int], repeats: List[int])
    self_rank, self_dtype = self_rank_dtype
    return self_dtype

def aten〇repeat_interleave〇Tensor〡dtype(repeats_rank_dtype: Tuple[int, int], output_size: Optional[int] = None) -> int:
    repeats_rank, repeats_dtype = repeats_rank_dtype
    return repeats_dtype
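The dtype rule simply propagates the dtype of `repeats`, matching eager PyTorch, where this overload allocates its result with the options of `repeats`. An illustrative check, not part of the diff:

import torch

# The result dtype follows the repeats tensor (int32 in, int32 out).
assert torch.repeat_interleave(torch.tensor([1, 2], dtype=torch.int32)).dtype == torch.int32
assert torch.repeat_interleave(torch.tensor([1, 2], dtype=torch.int64)).dtype == torch.int64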

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], stride=[1]))
def aten〇_reshape_alias〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], stride: List[int]) -> int:
    self_rank, self_dtype = self_rank_dtype
@@ -502,6 +502,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)")
emit("aten::numel : (Tensor) -> (int)")
emit("aten::repeat : (Tensor, int[]) -> (Tensor)")
emit("aten::repeat_interleave.Tensor : (Tensor, int?) -> (Tensor)")
emit("aten::reshape : (Tensor, int[]) -> (Tensor)")
emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)")
emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)")
Expand Down
20 changes: 20 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/basic.py
@@ -1394,6 +1394,26 @@ def RepeatModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 1, 2))

# ==============================================================================
class RepeatInterleaveModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([4], torch.int, True),
    ])
    def forward(self, x):
        x = torch.ops.aten.repeat_interleave(x, output_size=10)
        return x


@register_test_case(module_factory=lambda: RepeatInterleaveModule())
def RepeatInterleaveModule_basic(module, tu: TestUtils):
    module.forward(torch.tensor([3, 1, 2, 4], dtype=torch.int))
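For reference, the test's `output_size=10` is exactly `3 + 1 + 2 + 4`, the required sum of the repeats values; in eager mode the module computes the following, not part of the diff:

import torch

x = torch.tensor([3, 1, 2, 4], dtype=torch.int)
# Index i of x is repeated x[i] times: 0 three times, 1 once, 2 twice, 3 four times.
print(torch.ops.aten.repeat_interleave(x, output_size=10))
# tensor([0, 0, 0, 1, 2, 2, 3, 3, 3, 3], dtype=torch.int32)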

# ==============================================================================


class ExpandModule(torch.nn.Module):
Expand Down