Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][TORCH] Add support for dim=None to Aten[Var|Std]DimOp #1159

Merged
merged 1 commit into from
Aug 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions e2e_testing/torchscript/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,8 @@
"StdDimBiasedModule_basic",
"StdDimKeepDimFalseModule_basic",
"StdDimKeepDimTrueModule_basic",
"StdDimEmptyDimModule_basic",
"StdDimNoneDimModule_basic",
"StdUnbiasedModule_basic",
"SubFloatModule_basic",
"SubIntModule_basic",
Expand Down
8 changes: 4 additions & 4 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -4049,10 +4049,10 @@ def Torch_AtenStdDimOp : Torch_Op<"aten.std.dim", [
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::std.dim : (Tensor, int[], bool, bool) -> (Tensor)`";
let summary = "Generated op for `aten::std.dim : (Tensor, int[]?, bool, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$dim,
AnyTorchOptionalListOfTorchIntType:$dim,
Torch_BoolType:$unbiased,
Torch_BoolType:$keepdim
);
Expand Down Expand Up @@ -4099,10 +4099,10 @@ def Torch_AtenVarDimOp : Torch_Op<"aten.var.dim", [
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::var.dim : (Tensor, int[], bool, bool) -> (Tensor)`";
let summary = "Generated op for `aten::var.dim : (Tensor, int[]?, bool, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$dim,
AnyTorchOptionalListOfTorchIntType:$dim,
Torch_BoolType:$unbiased,
Torch_BoolType:$keepdim
);
Expand Down
50 changes: 41 additions & 9 deletions lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5564,12 +5564,19 @@ module {
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.var.dim"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {
%none = torch.constant.none
func.func @"__torch_mlir_shape_fn.aten.var.dim"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {
%true = torch.constant.bool true
%none = torch.constant.none
%int0 = torch.constant.int 0
%0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
%0 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%1 = torch.prim.If %0 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
%6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
%7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %7 : !torch.bool
}
%2 = torch.prim.If %1 -> (!torch.list<int>) {
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
Expand All @@ -5580,7 +5587,8 @@ module {
} : (!torch.int, !torch.bool) -> ()
torch.prim.If.yield %6 : !torch.list<int>
} else {
torch.prim.If.yield %arg1 : !torch.list<int>
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %5 : !torch.list<int>
}
%3 = torch.derefine %none : !torch.none to !torch.any
%4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
Expand Down Expand Up @@ -5620,11 +5628,35 @@ module {
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.std.dim"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {
func.func @"__torch_mlir_shape_fn.aten.std.dim"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {
%true = torch.constant.bool true
%none = torch.constant.none
%0 = torch.derefine %none : !torch.none to !torch.any
%1 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
return %1 : !torch.list<int>
%int0 = torch.constant.int 0
%0 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%1 = torch.prim.If %0 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
%6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
%7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %7 : !torch.bool
}
%2 = torch.prim.If %1 -> (!torch.list<int>) {
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %5, %true, init() {
^bb0(%arg4: !torch.int):
%7 = torch.aten.append.t %6, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.prim.If.yield %6 : !torch.list<int>
} else {
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %5 : !torch.list<int>
}
%3 = torch.derefine %none : !torch.none to !torch.any
%4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
return %4 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.argmax"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.list<int> {
%none = torch.constant.none
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -489,8 +489,8 @@ def aten〇mean(self: List[int], dtype: Optional[int] = None) -> List[int]:
def aten〇var(self: List[int], unbiased: bool = True) -> List[int]:
return []

def aten〇var〇dim(self: List[int], dim: List[int], unbiased: bool = True, keepdim: bool = False) -> List[int]:
if len(dim)==0:
def aten〇var〇dim(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]:
if dim is None or len(dim)==0:
dim = list(range(len(self)))
return upstream_shape_functions.mean_dim(self, dim, keepdim, None)

Expand All @@ -502,7 +502,9 @@ def aten〇var〇correction(self: List[int], dim: Optional[List[int]], correctio
def aten〇std(self: List[int], unbiased: bool = True) -> List[int]:
return []

def aten〇std〇dim(self: List[int], dim: List[int], unbiased: bool = True, keepdim: bool = False) -> List[int]:
def aten〇std〇dim(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]:
if dim is None or len(dim)==0:
dim = list(range(len(self)))
return upstream_shape_functions.mean_dim(self, dim, keepdim, None)

def _reduce_along_dim(self: List[int], dim: int, keepdim: bool):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,9 +382,9 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)")
emit("aten::mean : (Tensor, int?) -> (Tensor)")
emit("aten::std : (Tensor, bool) -> (Tensor)")
emit("aten::std.dim : (Tensor, int[], bool, bool) -> (Tensor)")
emit("aten::std.dim : (Tensor, int[]?, bool, bool) -> (Tensor)")
emit("aten::var : (Tensor, bool) -> (Tensor)")
emit("aten::var.dim : (Tensor, int[], bool, bool) -> (Tensor)")
emit("aten::var.dim : (Tensor, int[]?, bool, bool) -> (Tensor)")
emit("aten::var.correction : (Tensor, int[]?, int?, bool) -> (Tensor)")
emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")
Expand Down
70 changes: 68 additions & 2 deletions python/torch_mlir_e2e_test/test_suite/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,50 @@ def StdDimBiasedModule_basic(module, tu: TestUtils):
# ==============================================================================


class StdDimEmptyDimModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.std(x, dim=[], keepdim=False)


@register_test_case(module_factory=lambda: StdDimEmptyDimModule())
def StdDimEmptyDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))


# ==============================================================================


class StdDimNoneDimModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.std(x, dim=None, keepdim=False)


@register_test_case(module_factory=lambda: StdDimNoneDimModule())
def StdDimNoneDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))


# ==============================================================================


class VarDimModule(torch.nn.Module):

def __init__(self):
Expand Down Expand Up @@ -416,7 +460,7 @@ def __init__(self):
([-1, -1, -1], torch.float64, True),
])
def forward(self, x):
return torch.ops.aten.var(x, dim=0, unbiased=False, keepdim=True)
return torch.ops.aten.var(x, dim=(0,1), unbiased=False, keepdim=True)


@register_test_case(module_factory=lambda: VarDimBiasedModule())
Expand All @@ -438,7 +482,7 @@ def __init__(self):
([-1, -1, -1], torch.float64, True),
])
def forward(self, x):
return torch.ops.aten.var(x, dim=0, keepdim=True)
return torch.ops.aten.var(x, dim=(0,), keepdim=True)


@register_test_case(module_factory=lambda: VarDimSingleDimModule())
Expand Down Expand Up @@ -537,6 +581,28 @@ def VarDimEmptyDimModule_basic(module, tu: TestUtils):
# ==============================================================================


class VarDimNoneDimModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var(x, dim=None, keepdim=False)


@register_test_case(module_factory=lambda: VarDimNoneDimModule())
def VarDimNoneDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))


# ==============================================================================


class VarCorrectionModule(torch.nn.Module):

def __init__(self):
Expand Down