Skip to content

Commit

Permalink
[pt2] add metas for avg_pool3d and avg_pool3d_backward (pytorch#1…
Browse files Browse the repository at this point in the history
…03392)

Pull Request resolved: pytorch#103392
Approved by: https://github.com/ezyang
  • Loading branch information
nkaretnikov authored and pytorchmergebot committed Jun 13, 2023
1 parent 8dc6001 commit 4a76fb4
Show file tree
Hide file tree
Showing 3 changed files with 313 additions and 4 deletions.
3 changes: 0 additions & 3 deletions test/functorch/test_aotdispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2809,7 +2809,6 @@ def forward(self, x):
xfail('nn.functional.adaptive_max_pool1d', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.adaptive_max_pool2d', ''), # aten.adaptive_max_pool2d.default - couldn't find symbo...
xfail('nn.functional.adaptive_max_pool3d', ''), # argument 'output_size' (position 2...
xfail('nn.functional.avg_pool3d', ''), # aten.avg_pool3d.default - couldn't find symbolic meta function/...
skip('nn.functional.batch_norm', ''), # '0 is not tracked with proxy for <torch.fx.experimental.proxy_te..
xfail('nn.functional.bilinear', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.binary_cross_entropy', ''), # aten.fill_.Scalar - couldn't find symbolic meta funct...
Expand Down Expand Up @@ -3054,10 +3053,8 @@ def test_aot_autograd_symbolic_exhaustive(self, device, dtype, op):
torch.nn.AdaptiveMaxPool3d, # Cannot call sizes() on tensor with symbolic sizes/strides
torch.nn.GroupNorm, # in native_group_norm_backward cpg, _rem = divmod(C, group)
# TypeError: unsupported operand type(s) for divmod(): 'SymInt' and 'int'
torch.nn.LocalResponseNorm, # Cannot call sizes() on tensor with symbolic sizes/strides
torch.nn.FractionalMaxPool2d, # int() argument must be a string, a bytes-like object or a number, not 'SymFloat'
torch.nn.FractionalMaxPool3d, # int() argument must be a string, a bytes-like object or a number, not 'SymFloat'
torch.nn.AvgPool3d, # Cannot call sizes() on tensor with symbolic sizes/strides
torch.nn.MaxPool1d, # Cannot call sizes() on tensor with symbolic sizes/strides
torch.nn.MaxPool3d, # torch._subclasses.fake_tensor.UnsupportedOperatorException:
# aten.max_pool3d_with_indices.default
Expand Down
1 change: 0 additions & 1 deletion test/test_proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1490,7 +1490,6 @@ def f(a, b, c, d, e):
xfail('nn.functional.adaptive_max_pool1d', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.adaptive_max_pool2d', ''), # aten.adaptive_max_pool2d.default - couldn't find symbolic meta funct...
xfail('nn.functional.adaptive_max_pool3d', ''), # argument 'output_size' (position 2) must be tupl...
xfail('nn.functional.avg_pool3d', ''), # aten.avg_pool3d.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.bilinear', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.binary_cross_entropy', ''), # aten.new_empty.default - couldn't find symbolic meta function/decom...
xfail('nn.functional.cosine_similarity', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
Expand Down
313 changes: 313 additions & 0 deletions torch/_meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1586,6 +1586,172 @@ def meta_avg_pool2d_backward(
)


@register_meta(aten.avg_pool3d)
@out_wrapper()
def meta_avg_pool3d(
input,
kernel_size,
stride=(),
padding=(0,),
ceil_mode=False,
count_include_pad=True,
divisor_override=None,
):
check(
len(kernel_size) in (1, 3),
lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints",
)
kT = kernel_size[0]
kH = kT if len(kernel_size) == 1 else kernel_size[1]
kW = kT if len(kernel_size) == 1 else kernel_size[2]

check(
not stride or len(stride) in (1, 3),
lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints",
)
dT = kT if not stride else stride[0]
dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
dW = kW if not stride else (dT if len(stride) == 1 else stride[2])

check(
len(padding) in (1, 3),
lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints",
)
padT = padding[0]
padH = padT if len(padding) == 1 else padding[1]
padW = padT if len(padding) == 1 else padding[2]

check(
input.ndim in (4, 5),
lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
)

check(
not divisor_override or divisor_override != 0,
lambda: "divisor must be not zero",
)

nbatch = input.size(0)
nslices = input.size(-4)
itime = input.size(-3)
iheight = input.size(-2)
iwidth = input.size(-1)

otime = pooling_output_shape(itime, kT, padT, dT, 1, ceil_mode)
oheight = pooling_output_shape(iheight, kH, padH, dH, 1, ceil_mode)
owidth = pooling_output_shape(iwidth, kW, padW, dW, 1, ceil_mode)

pool3d_shape_check(
input,
nslices,
kT,
kH,
kW,
dT,
dH,
dW,
padT,
padH,
padW,
1,
1,
1,
itime,
iheight,
iwidth,
otime,
oheight,
owidth,
"avg_pool3d()",
check_input_size=True,
)

if input.ndim == 4:
return input.new_empty((nslices, otime, oheight, owidth))
else:
return input.new_empty((nbatch, nslices, otime, oheight, owidth))


@register_meta(aten.avg_pool3d_backward)
@out_wrapper()
def meta_avg_pool3d_backward(
grad_output,
input,
kernel_size,
stride,
padding,
ceil_mode,
count_include_pad,
divisor_override,
):
check(
len(kernel_size) in (1, 3),
lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints",
)
kT = kernel_size[0]
kH = kT if len(kernel_size) == 1 else kernel_size[1]
kW = kT if len(kernel_size) == 1 else kernel_size[2]

check(
not stride or len(stride) in (1, 3),
lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints",
)
dT = kT if not stride else stride[0]
dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
dW = kW if not stride else (dT if len(stride) == 1 else stride[2])

check(
len(padding) in (1, 3),
lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints",
)
padT = padding[0]
padH = padT if len(padding) == 1 else padding[1]
padW = padT if len(padding) == 1 else padding[2]

check(
input.ndim in (4, 5),
lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
)

check(
not divisor_override or divisor_override != 0,
lambda: "divisor must be not zero",
)

nslices = input.size(-4)
itime = input.size(-3)
iheight = input.size(-2)
iwidth = input.size(-1)

otime_for_shape_check = pooling_output_shape(itime, kT, padT, dT, 1, ceil_mode)
oheight_for_shape_check = pooling_output_shape(iheight, kH, padH, dH, 1, ceil_mode)
owidth_for_shape_check = pooling_output_shape(iwidth, kW, padW, dW, 1, ceil_mode)

avg_pool3d_backward_shape_check(
input,
grad_output,
nslices,
kT,
kH,
kW,
dT,
dH,
dW,
padT,
padH,
padW,
itime,
iheight,
iwidth,
otime_for_shape_check,
oheight_for_shape_check,
owidth_for_shape_check,
"avg_pool3d_backward()",
)

return input.new_empty(input.shape)


@register_meta(aten._adaptive_avg_pool2d.default)
def meta_adaptive_avg_pool2d(self, output_size):
check(
Expand Down Expand Up @@ -2546,6 +2712,153 @@ def pool2d_shape_check(
)


def pool3d_shape_check(
input: Tensor,
nslices: int,
kT: int,
kH: int,
kW: int,
dT: int,
dH: int,
dW: int,
pT: int,
pH: int,
pW: int,
dilationT: int,
dilationH: int,
dilationW: int,
itime: int,
iheight: int,
iwidth: int,
otime: int,
oheight: int,
owidth: int,
fn_name: str,
check_input_size: bool = False,
):
ndim = input.ndim

check(
kT > 0 and kW > 0 and kH > 0,
lambda: (
f"kernel size should be greater than zero, but got "
f"kT: {kT}, kH: {kH}, kW: {kW}"
),
)
check(
dT > 0 and dW > 0 and dH > 0,
lambda: (
f"stride should be greater than zero, but got "
f"dT: {dT}, dH: {dH}, dW: {dW}"
),
)
check(
dilationT > 0 and dilationW > 0 and dilationH > 0,
lambda: (
f"dilation should be greater than zero, but got "
f"dilationT: {dilationT}, dilationH: {dilationH}, dilationW: {dilationW}"
),
)

check(
ndim in (4, 5),
lambda: f"{fn_name}: Expected 4D or 5D tensor for input, but got: {input.shape}",
)

for i in range(ndim):
if ndim == 5 and i == 0:
# size of batch-dim can be 0.
continue
check(
input.size(i) > 0,
lambda: (
f"{fn_name}: Expected input's non-batch dimensions to have positive length,"
f" but input has a shape of {input.shape}"
f" and non-batch dimension {input.size(i)} has length zero!"
),
)

if check_input_size: # AveragePool3d
check(
itime >= kT and iheight >= kH and iwidth >= kW,
lambda: (
f"input image (T: {itime} H: {iheight} W: {iwidth}) smaller than "
f"kernel size (kT: {kT} kH: {kH} kW: {kW})"
),
)

check(
kT / 2 >= pT and kW / 2 >= pW and kH / 2 >= pH,
lambda: (
f"pad should be smaller than or equal to half of kernel size, but got "
f"kT: {kT} kW: {kW} kH: {kH} padT: {pT} padW: {pW} padH: {pH}"
),
)

check(
otime >= 1 and owidth >= 1 and oheight >= 1,
lambda: (
f"Given input size: ({nslices}x{itime}x{iheight}x{iwidth}). "
f"Calculated output size: ({nslices}x{otime}x{oheight}x{owidth}). "
f"Output size is too small"
),
)


def avg_pool3d_backward_shape_check(
input: Tensor,
grad_output: Tensor,
nslices: int,
kT: int,
kH: int,
kW: int,
dT: int,
dH: int,
dW: int,
pT: int,
pH: int,
pW: int,
itime: int,
iheight: int,
iwidth: int,
otime: int,
oheight: int,
owidth: int,
fn_name: str,
):
ndim = input.ndim

pool3d_shape_check(
input,
nslices,
kT,
kH,
kW,
dT,
dH,
dW,
pT,
pH,
pW,
1,
1,
1,
itime,
iheight,
iwidth,
otime,
oheight,
owidth,
fn_name,
True,
)

check_dim_size(grad_output, ndim, ndim - 4, nslices)
check_dim_size(grad_output, ndim, ndim - 3, otime)
check_dim_size(grad_output, ndim, ndim - 2, oheight)
check_dim_size(grad_output, ndim, ndim - 1, owidth)


def max_pool2d_checks_and_compute_shape(
input, kernel_size, stride, padding, dilation, ceil_mode
):
Expand Down

0 comments on commit 4a76fb4

Please sign in to comment.