From edf725ef42b9bc7bc1dada691a3988b3c0038e33 Mon Sep 17 00:00:00 2001
From: yyp0
Date: Thu, 12 Sep 2024 19:07:11 +0800
Subject: [PATCH] [Torch] add AtenAsStridedOp in torch dialect (#3706)

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td      | 25 +++++++++++++++++++
 .../Transforms/AbstractInterpLibrary.cpp       |  7 ++++++
 lib/Dialect/Torch/Utils/Utils.cpp              | 11 ++++++-----
 .../build_tools/abstract_interp_lib_gen.py     |  7 ++++++
 .../build_tools/torch_ods_gen.py               |  1 +
 5 files changed, 46 insertions(+), 5 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 7591493f86ab..12907c9a649e 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -13195,6 +13195,31 @@ def Torch_AtenAsStridedCopyOp : Torch_Op<"aten.as_strided_copy", [
   }];
 }
 
+def Torch_AtenAsStridedOp : Torch_Op<"aten.as_strided", [
+    AllowsTypeRefinement,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTorchIntType:$size,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchOptionalIntType:$storage_offset
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenAsStridedOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenAsStridedOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_AtenDiagonalOp : Torch_Op<"aten.diagonal", [
     AllowsTypeRefinement,
     ReadOnly
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 545fdee26836..27a2f1e2c7af 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -10002,6 +10002,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.as_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
+"    return %arg1 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.sort\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
 "    %0 = torch.prim.TupleConstruct %arg0, %arg0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
 "    return %0 : !torch.tuple<list<int>, list<int>>\n"
@@ -12297,6 +12300,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.as_strided\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten._softmax_backward_data\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
 "    return %arg3 : !torch.int\n"
 "  }\n"
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index eb8b37502efc..988df760d4cb 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -247,11 +247,12 @@ bool Torch::isViewLikeOp(Operation *op) {
   // correct. We could potentially be more precise and identify the cases
   // that it does not return a view and treat those as having value
   // semantics.
-  return isa<AtenContiguousOp, ...>(op);
+  return isa<AtenAsStridedOp, AtenContiguousOp, ...>(op);
 }
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ ... @@
 def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
     return upstream_shape_functions.slice(self, dim, start, end, step)
 
+def aten〇as_strided〡shape(self: List[int], size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> List[int]:
+    return size
+
 def aten〇sort〡shape(self: List[int], dim: int = -1, descending: bool = False) -> Tuple[List[int], List[int]]:
     return self, self
 
@@ -3377,6 +3380,10 @@ def aten〇slice〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], dim: int = 0
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+def aten〇as_strided〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(
     _check_tensors_with_the_same_dtype(num_of_tensors=2, dim=0, input_dtype=torch.float32) +
     _check_tensors_with_the_same_dtype(num_of_tensors=2, dim=0, input_dtype=torch.float64) +
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 9318ab6f2db0..2421fda24161 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -968,6 +968,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::alias_copy : (Tensor) -> (Tensor)")
     emit("aten::alias : (Tensor) -> (Tensor)", has_folder=True)
     emit("aten::as_strided_copy : (Tensor, int[], int[], int?) -> (Tensor)")
+    emit("aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)")
     emit("aten::diagonal : (Tensor, int, int, int) -> (Tensor)")
     emit("aten::diagonal_copy : (Tensor, int, int, int) -> (Tensor)")
     emit("aten::expand_copy : (Tensor, int[], bool) -> (Tensor)")
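
Note (not part of the patch): the new abstract-interp functions encode that aten.as_strided produces a tensor whose shape is exactly the requested size and whose dtype matches the input, and the isViewLikeOp change records that the result aliases the input's storage. Below is a minimal plain-Python sketch of those semantics checked against torch.as_strided; the helper names as_strided_shape/as_strided_dtype are illustrative mirrors of the generated functions, not part of this change.

from typing import List, Optional, Tuple

import torch

def as_strided_shape(self_shape: List[int], size: List[int], stride: List[int],
                     storage_offset: Optional[int] = None) -> List[int]:
    # Mirrors aten〇as_strided〡shape: the result shape is the requested `size`.
    return size

def as_strided_dtype(self_rank_dtype: Tuple[int, torch.dtype], size: List[int],
                     stride: List[int],
                     storage_offset: Optional[int] = None) -> torch.dtype:
    # Mirrors aten〇as_strided〡dtype: the result dtype is the input dtype.
    _rank, dtype = self_rank_dtype
    return dtype

t = torch.arange(16, dtype=torch.float32).reshape(4, 4)
# A 2x2 window into t's storage, starting one element in.
v = torch.as_strided(t, size=[2, 2], stride=[4, 1], storage_offset=1)

assert list(v.shape) == as_strided_shape(list(t.shape), [2, 2], [4, 1], 1)
assert v.dtype == as_strided_dtype((t.dim(), t.dtype), [2, 2], [4, 1], 1)

# as_strided returns a view, not a copy: writes through t are visible in v,
# which is why the patch also adds AtenAsStridedOp to Torch::isViewLikeOp.
t[0, 1] = 100.0
assert v[0, 0].item() == 100.0  # storage index 1 == t[0, 1]

The shape function deliberately ignores stride and storage_offset: they determine which storage elements the view sees, not the logical shape, so returning size unchanged is sufficient for shape inference.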