From 694a32e557f04406b6f0a0b28c6fd8b6589c3226 Mon Sep 17 00:00:00 2001 From: Henry Tu Date: Tue, 12 Jul 2022 16:55:29 -0400 Subject: [PATCH 1/3] Blacklist _convolution op in LTC --- build_tools/autogen_ltc_backend.yaml | 3 ++- e2e_testing/torchscript/xfail_sets.py | 5 +++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/build_tools/autogen_ltc_backend.yaml b/build_tools/autogen_ltc_backend.yaml index 37115fa98e4a..b8a8905e06c8 100644 --- a/build_tools/autogen_ltc_backend.yaml +++ b/build_tools/autogen_ltc_backend.yaml @@ -1,13 +1,14 @@ blacklist: # List of unsupported ops in LTC autogen because of some error +- _index_put_impl_ # Error: TODO not sure if there are other valid types to handle here - empty_like # Error: TODO add support for type BaseType(name=) - index.Tensor # Error: TODO not sure if there are other valid types to handle here - index_put # Error: TODO not sure if there are other valid types to handle here - index_put_ # Error: TODO not sure if there are other valid types to handle here -- _index_put_impl_ # Error: TODO not sure if there are other valid types to handle here - stack # Error: TODO not sure if there are other valid types to handle here # Additional ops which autogen is supported for but don't compile yet +- _convolution - detach - item - size diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py index 3d5d44aaf43a..513145b42637 100644 --- a/e2e_testing/torchscript/xfail_sets.py +++ b/e2e_testing/torchscript/xfail_sets.py @@ -175,6 +175,11 @@ } LTC_XFAIL_SET = { + "_Convolution2DAllFalseModule_basic", + "_Convolution2DBenchmarkModule_basic", + "_Convolution2DCudnnModule_basic", + "_Convolution2DDeterministicModule_basic", + "_Convolution2DTF32Module_basic", "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic", "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", "AddIntModule_basic", From 540449ad53133e2ef931845d3a564e3985bf61a5 Mon Sep 17 00:00:00 2001 From: Henry Tu Date: Wed, 13 Jul 2022 09:03:40 -0400 Subject: [PATCH 2/3] Removed duplicate Torch_AtenSelectScatterOp instance from autogen .td --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 26 ------------------- .../jit_ir/build_tools/torch_ods_gen.py | 1 - 2 files changed, 27 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index f9230e36085c..484a4b75501b 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -5293,32 +5293,6 @@ def Torch_AtenSelectIntOp : Torch_Op<"aten.select.int", [ }]; } -def Torch_AtenSelectScatterOp : Torch_Op<"aten.select_scatter", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$src, - Torch_IntType:$dim, - Torch_IntType:$index - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenSelectScatterOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 1); - } - void AtenSelectScatterOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); - } - }]; -} - def Torch_AtenSizeIntOp : Torch_Op<"aten.size.int", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index c3722bb355e1..9052e6eecb67 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -437,7 +437,6 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)") emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)") emit("aten::select.int : (Tensor, int, int) -> (Tensor)") - emit("aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)") emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True) emit("aten::stack : (Tensor[], int) -> (Tensor)") emit("aten::sum : (Tensor, int?) -> (Tensor)") From 9062130cfc0a4eb82e7eb41b1d805b07a6c26009 Mon Sep 17 00:00:00 2001 From: Henry Tu Date: Wed, 13 Jul 2022 10:12:39 -0400 Subject: [PATCH 3/3] Removed duplicate Torch_AtenSliceScatterOp instance from autogen .td --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 28 ------------------- .../jit_ir/build_tools/torch_ods_gen.py | 1 - 2 files changed, 29 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 484a4b75501b..a9bba698ea2f 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -5748,34 +5748,6 @@ def Torch_AtenSliceTensorOp : Torch_Op<"aten.slice.Tensor", [ }]; } -def Torch_AtenSliceScatterOp : Torch_Op<"aten.slice_scatter", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$src, - Torch_IntType:$dim, - AnyTorchOptionalIntType:$start, - AnyTorchOptionalIntType:$end, - Torch_IntType:$step - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenSliceScatterOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 6, 1); - } - void AtenSliceScatterOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 6, 1); - } - }]; -} - def Torch_AtenLenTensorOp : Torch_Op<"aten.len.Tensor", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 9052e6eecb67..8ccba234e29d 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -455,7 +455,6 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::where.ScalarOther : (Tensor, Tensor, Scalar) -> (Tensor)") emit("aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)") emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)") - emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)") emit("aten::len.Tensor : (Tensor) -> (int)") emit("aten::cpu : (Tensor) -> (Tensor)") emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)")