
Legalize aten.repeat_interleave.Tensor for torch-mlir #44

Merged: 10 commits, May 31, 2023
3 changes: 3 additions & 0 deletions e2e_testing/xfail_sets.py
@@ -262,6 +262,8 @@
"ScatterValueFloatModule_basic",
# ERROR: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
"ScatterValueIntModule_basic",
# ERROR: Unsupported: dynamic shape operator: aten.repeat_interleave.Tensor
"RepeatInterleaveModule_basic",
}

TORCHDYNAMO_CRASHING_SET = {
@@ -1284,4 +1286,5 @@
"ChunkListUnpackUnevenDynamic_Module_basic",
"ScatterValueFloatModule_basic",
"ScatterValueIntModule_basic",
"RepeatInterleaveModule_basic",
}
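
The TorchDynamo xfail above is expected: for aten.repeat_interleave.Tensor, the length of the result depends on the values in repeats, not just on its shape, making it a genuinely dynamic-shape operator. A minimal eager-mode sketch of that data dependence (plain PyTorch, not part of this diff):

import torch

repeats = torch.tensor([3, 1, 2, 4])
out = torch.repeat_interleave(repeats)
# out is tensor([0, 0, 0, 1, 2, 2, 3, 3, 3, 3]); its length equals
# int(repeats.sum()) == 10 and cannot be inferred from shapes alone.
assert out.numel() == int(repeats.sum())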
24 changes: 24 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7570,6 +7570,30 @@ def Torch_AtenRepeatOp : Torch_Op<"aten.repeat", [
}];
}

def Torch_AtenRepeatInterleaveTensorOp : Torch_Op<"aten.repeat_interleave.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::repeat_interleave.Tensor : (Tensor, int?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$repeats,
AnyTorchOptionalIntType:$output_size
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRepeatInterleaveTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenRepeatInterleaveTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
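
The generated ODS above mirrors the registered overload aten::repeat_interleave.Tensor : (Tensor, int?) -> (Tensor): repeats is a 1-D integer tensor, and the optional output_size pre-declares the result length. A quick eager-mode illustration (assuming recent PyTorch; not part of this diff):

import torch

repeats = torch.tensor([2, 0, 3])
# Element i of the output appears repeats[i] times.
print(torch.ops.aten.repeat_interleave(repeats))                 # tensor([0, 0, 2, 2, 2])
# output_size, when supplied, should equal repeats.sum(); it lets the
# compiler treat the result length as statically known.
print(torch.ops.aten.repeat_interleave(repeats, output_size=5))  # same result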

def Torch_AtenReshapeOp : Torch_Op<"aten.reshape", [
AllowsTypeRefinement,
ReadOnly
19 changes: 19 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6669,6 +6669,21 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %6 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.repeat_interleave.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>) -> !torch.list<int> {\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %none = torch.constant.none\n"
" %0 = torch.prim.Uninitialized : !torch.int\n"
" %1 = torch.aten.__isnot__ %arg1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.int) {\n"
" %4 = torch.prim.unchecked_cast %arg1 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %4 : !torch.int\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield %0 : !torch.int\n"
" }\n"
" %3 = torch.prim.ListConstruct %2 : (!torch.int) -> !torch.list<int>\n"
" return %3 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.roll\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -8396,6 +8411,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.repeat_interleave.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._reshape_alias\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
@@ -437,6 +437,10 @@ def aten〇repeat〡shape(self: List[int], repeats: List[int]) -> List[int]:
for i in range(tensor_dim):
out.append(self[i] * repeats[i + leading_rank])
return out

def aten〇repeat_interleave〇Tensor〡shape(repeats: List[int], output_size: Optional[int] = None) -> List[int]:
assert output_size is not None
return [output_size]
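
The shape function sees only the shape of repeats, so output_size is the sole source for the result length; calling it without output_size trips the assert, which is exactly what the torch.prim.RaiseException branch in the generated MLIR above compiles to. An illustrative check (a hypothetical direct call within this module):

# Illustrative only: the declared result shape is just [output_size].
assert aten〇repeat_interleave〇Tensor〡shape([4], output_size=10) == [10]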

def aten〇roll〡shape(self: List[int], shifts: List[int], dims: List[int] = ()) -> List[int]:
return upstream_shape_functions.unary(self)
@@ -1717,6 +1721,10 @@ def aten〇repeat〡dtype(self_rank_dtype: Tuple[int, int], repeats: List[int])
self_rank, self_dtype = self_rank_dtype
return self_dtype

def aten〇repeat_interleave〇Tensor〡dtype(repeats_rank_dtype: Tuple[int, int], output_size: Optional[int] = None) -> int:
repeats_rank, repeats_dtype = repeats_rank_dtype
return repeats_dtype
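
The dtype rule simply propagates the repeats dtype. That appears to match eager PyTorch, where the output index tensor inherits its dtype from repeats rather than being fixed to int64 (a sketch, assuming current ATen behavior):

import torch

repeats = torch.tensor([1, 2], dtype=torch.int32)
out = torch.repeat_interleave(repeats)
assert out.dtype == torch.int32  # result dtype follows repeats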

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], stride=[1]))
def aten〇_reshape_alias〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], stride: List[int]) -> int:
self_rank, self_dtype = self_rank_dtype
@@ -504,6 +504,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)")
emit("aten::numel : (Tensor) -> (int)")
emit("aten::repeat : (Tensor, int[]) -> (Tensor)")
emit("aten::repeat_interleave.Tensor : (Tensor, int?) -> (Tensor)")
emit("aten::reshape : (Tensor, int[]) -> (Tensor)")
emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)")
emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)")
1 change: 1 addition & 0 deletions python/torch_mlir_e2e_test/test_suite/__init__.py
@@ -11,6 +11,7 @@
"NativeGroupNormBackwardModule_basic",
"QuantizedMLP_basic",
"ReduceMaxAlongDimUnsignedInt_basic",
"RepeatInterleaveModule_basic",
}

# TODO: Delete once torch 2.1.0 is released
21 changes: 21 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/basic.py
@@ -1416,6 +1416,27 @@ def RepeatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1, 2))

# ==============================================================================
class RepeatInterleaveModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([4], torch.int, True),
])
def forward(self, x):
z = torch.ops.aten.repeat_interleave(x, output_size=10)
y = torch.ops.aten.repeat_interleave(x)
return z, y


@register_test_case(module_factory=lambda: RepeatInterleaveModule())
def RepeatInterleaveModule_basic(module, tu: TestUtils):
module.forward(torch.tensor([3, 1, 2, 4], dtype=torch.int))
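
Both calls in this test should yield the same flat index tensor of length 3 + 1 + 2 + 4 = 10. An eager-mode spot check (not part of the suite):

import torch

x = torch.tensor([3, 1, 2, 4], dtype=torch.int)
assert torch.ops.aten.repeat_interleave(x, output_size=10).tolist() == [0, 0, 0, 1, 2, 2, 3, 3, 3, 3]
assert torch.ops.aten.repeat_interleave(x).tolist() == [0, 0, 0, 1, 2, 2, 3, 3, 3, 3]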

# ==============================================================================


class ExpandModule(torch.nn.Module):