Add ops(max_pool3d and max_pool3d_with_indices) | feat(atenlib) #618

Merged · 9 commits · Apr 11, 2023

onnxscript/function_libs/torch_aten/ops/nn.py (317 changes: 183 additions & 134 deletions)

@@ -668,66 +668,76 @@ def aten_max_pool1d_with_indices(
     raise NotImplementedError()
 
 
-@torch_op("aten::max_pool2d", trace_only=True)
-def aten_max_pool2d(
-    self: TFloatOrUInt8,
+def _adjust_attributes_of_max_pool(
+    expand_size: int,
     kernel_size: Sequence[int],
-    stride: Sequence[int] = (),
-    padding: Sequence[int] = (0, 0),
-    dilation: Sequence[int] = (1, 1),
-    ceil_mode: bool = False,
-) -> TFloatOrUInt8:
-    """max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor"""
-
-    # Torch prefer to use single number x for kerne,stride,pad,dilation on both side implicitly
-    # But ONNX needs pair number [x,y] to specify on each side explicitly
-    # For pool3d, this number should be 3
-    expand_size = 2
-
-    # The dilations should be [x, y]
-    if isinstance(dilation, int):  # x -> [x, x]
+    stride: Sequence[int],
+    padding: Sequence[int],
+    dilation: Sequence[int],
+) -> Tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int]]:
+    if isinstance(dilation, int):
         dilations = [dilation] * expand_size
-    else:  # already [x, y]
+    else:
         dilations = dilation
 
-    # The kernel_shape should be [x, y]
-    if isinstance(kernel_size, int):  # x -> [x, x]
+    if isinstance(kernel_size, int):
         kernel_shape = [kernel_size] * expand_size
-    else:  # assert(len(kernel_size)==2), already [x, y]
+    else:
         kernel_shape = kernel_size
 
-    # The pads should be [w, x, y, z]
-    if isinstance(padding, int):  # w -> [w, w, w, w]
+    if isinstance(padding, int):
         pads = [padding] * expand_size * 2
-    elif len(padding) == 1:  # [w] -> [w, w, w, w]
-        pads = padding * 4
-    elif len(padding) == 2:  # [w, x] -> [w, x, w, x]
-        pads = padding * 2
-    else:  # assert len(padding) == 4, already [w, x, y, z]
+    elif len(padding) == 1:
+        pads = padding * expand_size * 2
+    elif len(padding) == 2:
+        pads = padding * expand_size
+    else:
        pads = padding
 
-    # The strides should be [x, y]
-    if isinstance(stride, int):  # x -> [x, x]
+    if isinstance(stride, int):
         strides = [stride] * expand_size
     elif stride is None:
         strides = kernel_shape
     else:
         strides = stride
 
-    return _aten_max_pool2d_onnx(self, kernel_shape, strides, pads, dilations, ceil_mode)
+    return (kernel_shape, strides, pads, dilations)
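
As a sanity check on the expansion rules above (a standalone sketch, assuming `_adjust_attributes_of_max_pool` is importable from `onnxscript.function_libs.torch_aten.ops.nn`; the expected values follow directly from the branches in the diff):

```python
# Scalar attributes are broadcast to one value per spatial dimension.
kernel, strides, pads, dilations = _adjust_attributes_of_max_pool(
    2, kernel_size=3, stride=2, padding=1, dilation=1
)
assert kernel == [3, 3]
assert strides == [2, 2]
assert pads == [1, 1, 1, 1]  # ONNX pads are [begin_h, begin_w, end_h, end_w]
assert dilations == [1, 1]

# A 2-element padding is tiled into ONNX's begin/end layout.
_, _, pads, _ = _adjust_attributes_of_max_pool(2, (3, 3), (1, 1), (1, 2), (1, 1))
assert tuple(pads) == (1, 2, 1, 2)
```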


@torch_op("aten::max_pool2d", trace_only=True)
def aten_max_pool2d(
self: TFloatOrUInt8,
kernel_size: Sequence[int],
stride: Sequence[int] = (),
padding: Sequence[int] = (0, 0),
dilation: Sequence[int] = (1, 1),
ceil_mode: bool = False,
) -> TFloatOrUInt8:
"""max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor"""

# Torch prefers to use single number x for kernel, stride, pad and dilation on both sides implicitly.
# But ONNX needs to specify a pair of number [x,y] on each side explicitly.
expand_size = 2

kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool(
expand_size, kernel_size, stride, padding, dilation
)

return _aten_max_pool_onnx(self, kernel_shape, strides, pads, dilations, ceil_mode, 3)

@torch_op("aten::max_pool2d", private=True)
def _aten_max_pool2d_onnx(

@torch_op("internal::max_pool", private=True)
def _aten_max_pool_onnx(
self: TFloatOrUInt8,
kernel_shape: Sequence[int],
strides: Sequence[int],
pads: Sequence[int],
dilations: Sequence[int],
ceil_mode: bool,
unbatched_rank: int,
) -> TFloatOrUInt8:
self_rank = op.Size(op.Shape(self))
if self_rank == 3: # C,H,W -> N,C,H,W and N=1
if self_rank == unbatched_rank: # C,H,W -> N,C,H,W and N=1
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))

pool_result, _ = op.MaxPool(
Expand All @@ -739,122 +749,65 @@ def _aten_max_pool2d_onnx(
strides=strides,
)

if self_rank == 3:
if self_rank == unbatched_rank:
pool_result = op.Squeeze(pool_result, op.Constant(value_ints=[0]))

return pool_result
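
The new `unbatched_rank` argument generalizes the old hard-coded rank-3 check so the same wrapper serves 2-D (rank 3 unbatched) and 3-D (rank 4 unbatched) pooling. For intuition, the Unsqueeze/Squeeze round-trip matches how PyTorch itself treats unbatched inputs (a standalone PyTorch sketch, not part of the diff):

```python
import torch
import torch.nn.functional as F

x = torch.randn(3, 8, 8)  # unbatched (C, H, W): rank 3 for the 2-D case
# Equivalent of the Unsqueeze -> MaxPool -> Squeeze sequence above:
y = F.max_pool2d(x.unsqueeze(0), kernel_size=2).squeeze(0)
assert torch.equal(y, F.max_pool2d(x, kernel_size=2))  # torch accepts both forms
```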


@torch_op("aten::max_pool2d_with_indices", trace_only=True)
def aten_max_pool2d_with_indices(
@torch_op("aten::max_pool3d", trace_only=True)
def aten_max_pool3d(
self: TFloatOrUInt8,
kernel_size: Sequence[int],
stride: Sequence[int] = (),
padding: Sequence[int] = (0, 0),
dilation: Sequence[int] = (1, 1),
ceil_mode: bool = False,
) -> Tuple[TFloatOrUInt8, INT64]:
"""max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)"""

# Torch prefer to use single number x for kerne,stride,pad,dilation on both side implicitly
# But ONNX needs pair number [x,y] to specify on each side explicitly
# For pool3d, this number should be 3
expand_size = 2

# The dilations should be [x, y]
if isinstance(dilation, int): # x -> [x, x]
dilations = [dilation] * expand_size
else: # already [x, y]
dilations = dilation

# The kernel_shape should be [x, y]
if isinstance(kernel_size, int): # x -> [x, x]
kernel_shape = [kernel_size] * expand_size
else: # assert(len(kernel_size)==2), already [x, y]
kernel_shape = kernel_size
) -> TFloatOrUInt8:
"""max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor"""

# The pads should be [w, x, y, z]
if isinstance(padding, int): # w -> [w, w, w, w]
pads = [padding] * expand_size * 2
elif len(padding) == 1: # [w] -> [w, w, w, w]
pads = padding * 4
elif len(padding) == 2: # [w, x] -> [w, x, w, x]
pads = padding * 2
else: # assert len(padding) == 4, already [w, x, y, z]
pads = padding
# Torch prefers to use single number x for kernel, stride, pad and dilation on both sides implicitly.
# But ONNX needs to specify a tuple of three ints for all sides explicitly.
expand_size = 3

# The strides should be [x, y]
if isinstance(stride, int): # x -> [x, x]
strides = [stride] * expand_size
elif stride is None:
strides = kernel_shape
else:
strides = stride

return _aten_max_pool2d_with_indices_onnx(
self, expand_size, kernel_shape, strides, pads, dilations, ceil_mode
kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool(
expand_size, kernel_size, stride, padding, dilation
)

return _aten_max_pool_onnx(self, kernel_shape, strides, pads, dilations, ceil_mode, 4)
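
The 3-D variant reuses the same helper with `expand_size = 3` and `unbatched_rank = 4`; for example, a scalar padding now expands to six ONNX pad values (a quick check under the same import assumption as the earlier sketch):

```python
kernel, strides, pads, dilations = _adjust_attributes_of_max_pool(
    3, kernel_size=2, stride=None, padding=1, dilation=1
)
assert kernel == [2, 2, 2]
assert strides == [2, 2, 2]        # stride=None falls back to kernel_shape
assert pads == [1, 1, 1, 1, 1, 1]  # begin/end pads for D, H, W
assert dilations == [1, 1, 1]
```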


@torch_op("aten::max_pool2d_with_indices", private=True)
def _aten_max_pool2d_with_indices_onnx(
@torch_op("aten::max_pool2d_with_indices", trace_only=True)
def aten_max_pool2d_with_indices(
self: TFloatOrUInt8,
expand_size: INT64,
kernel_shape: Sequence[int],
strides: Sequence[int],
pads: Sequence[int],
dilations: Sequence[int],
ceil_mode: bool,
kernel_size: Sequence[int],
stride: Sequence[int] = (),
padding: Sequence[int] = (0, 0),
dilation: Sequence[int] = (1, 1),
ceil_mode: bool = False,
) -> Tuple[TFloatOrUInt8, INT64]:
self_rank = op.Size(op.Shape(self))
if self_rank == 3: # C,H,W -> N,C,H,W and N=1
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))

pool_result, indices = op.MaxPool(
self,
ceil_mode=ceil_mode,
dilations=dilations,
kernel_shape=kernel_shape,
pads=pads,
strides=strides,
)
"""max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)"""

if self_rank == 3:
pool_result = op.Squeeze(pool_result, op.Constant(value_ints=[0]))
# Torch prefers to use single number x for kernel, stride, pad and dilation on both sides implicitly.
# But ONNX needs to specify a pair of number [x,y] on each side explicitly.
expand_size = 2

# Torch use relative position number for the second Channel data
# If align, need reduce size(Channel)
# e.g. [[8,3,10],[30,32,23]]-[0,18] -> [[8,3,10],[12,14,5]]
# 18 = H x W = 3 x 6
batches = op.Shape(self, start=0, end=1)
channels = op.Shape(self, start=1, end=2)
end = batches * channels
offset = op.Range(0, end, 1)
data_shape = op.Shape(self, start=2)
data_size = op.ReduceProd(data_shape)
offset = offset * data_size
new_shape = op.Expand(
op.Constant(value_ints=[1]), op.Reshape(expand_size, op.Constant(value_ints=[-1]))
kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool(
expand_size, kernel_size, stride, padding, dilation
)
new_shape = op.Concat(batches, channels, new_shape, axis=0)
offset = op.Reshape(offset, new_shape)
indices = indices - offset
if self_rank == 3:
indices = op.Squeeze(indices, op.Constant(value_ints=[0]))
return pool_result, indices


def aten_max_pool3d(
self: TensorType,
kernel_size: Sequence[int],
stride: Optional[Sequence[int]] = None,
padding: Sequence[int] = (0, 0, 0),
dilation: Sequence[int] = (1, 1, 1),
ceil_mode: bool = False,
) -> TensorType:
"""max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor"""

raise NotImplementedError()
return _aten_max_pool_with_indices_onnx(
self,
kernel_shape,
strides,
pads,
dilations,
ceil_mode,
3,
([1] * expand_size),
([0] * expand_size),
([2 + i for i in range(expand_size)]),
)
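
For `expand_size = 2` the three trailing list arguments evaluate to `[1, 1]`, `[0, 0]` and `[2, 3]`: the kernel/stride for the auxiliary kernel-size-1 MaxPool, and the starts and spatial axes for the Slice inside `_aten_max_pool_with_indices_onnx` (defined later in this diff). A trivial check:

```python
expand_size = 2
assert [1] * expand_size == [1, 1]                    # helper MaxPool kernel and strides
assert [0] * expand_size == [0, 0]                    # Slice starts
assert [2 + i for i in range(expand_size)] == [2, 3]  # spatial axes of (N, C, H, W)
```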


 def aten_max_pool2d_with_indices_backward(
@@ -872,17 +825,113 @@ def aten_max_pool2d_with_indices_backward(
     raise NotImplementedError()


@torch_op("aten::max_pool3d_with_indices", trace_only=True)
def aten_max_pool3d_with_indices(
self: TensorType,
self: TFloatOrUInt8,
kernel_size: Sequence[int],
stride: Optional[Sequence[int]] = None,
padding: Sequence[int] = (0, 0, 0),
dilation: Sequence[int] = (1, 1, 1),
stride: Sequence[int] = (),
padding: Sequence[int] = (0, 0),
dilation: Sequence[int] = (1, 1),
ceil_mode: bool = False,
) -> tuple[TensorType, TensorType]:
) -> Tuple[TFloatOrUInt8, INT64]:
"""max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)"""

raise NotImplementedError()
# Torch prefers to use single number x for kernel, stride, pad and dilation on both sides implicitly.
# But ONNX needs to specify a tuple of three ints for all sides explicitly.
expand_size = 3

kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool(
expand_size, kernel_size, stride, padding, dilation
)

return _aten_max_pool_with_indices_onnx(
self,
kernel_shape,
strides,
pads,
dilations,
ceil_mode,
4,
([1] * expand_size),
([0] * expand_size),
([2 + i for i in range(expand_size)]),
)


@torch_op("internal::max_pool_with_indices", private=True)
def _aten_max_pool_with_indices_onnx(
self: TFloatOrUInt8,
kernel_size: Sequence[int],
stride: Sequence[int],
padding: Sequence[int],
dilation: Sequence[int],
ceil_mode: bool,
unbatched_rank: int,
n_dims_one: Sequence[int],
n_dims_zero: Sequence[int],
n_dims_axes: Sequence[int],
) -> Tuple[TFloatOrUInt8, INT64]:
self_rank = op.Size(op.Shape(self))
if self_rank == unbatched_rank:
self = op.Unsqueeze(self, axes=0)

pool_result, indices = op.MaxPool(
self,
ceil_mode=ceil_mode,
dilations=dilation,
kernel_shape=kernel_size,
pads=padding,
strides=stride,
)

# Simple but hacky way to get flattened indices values
# to be used to convert the indices values to non-flattened.
# In ONNX the indices are computed as a flatten 1-D tensor,
# so the values in indices are in [0, N x C x D1 x ... x Dn).
# To convert the indices to the same format used by PyTorch,
# we first execute a maxpool with a kernel and stride of 1 on the same input.
# This will result in a tensor of indices in which each index will have it's own value.
# Using this tensor as a reference, we extract the first index of each axis and subtract
# it from each index of this axis in the indices to convert.
# This step will result in a tensor where each dimension has values of indices within
# the dimension it is in.
# For Maxpool1d(kernel=1,stride=1,return_indices=True), with the input torch.ones(1,2,2).
# The computed indices are the following:
# output indices pytorch :
# [[0,1],
# [0,1]]
# output indices onnx:
# [[0,1],
# [2,3]]
# The purpose was to convert the indices from one format to the other to be able to match the results.
# So flattened_indices will have the value of each index and will be equal to :
# [[0,1],
# [2,3]]
# Then call Slice to get the first value of each line (so 0 and 2).
# And the subtraction executes :
# [[0-0,1-0],
# [2-2,3-2]]
# So indices results to the expected output which is :
# [[0,1],
# [0,1]]
# For more information :
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the explanation! This makes a lot of sense

# https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
_, flatten_indices = op.MaxPool(
self, dilations=dilation, kernel_shape=n_dims_one, strides=n_dims_one
)

ends = op.Constant(value_ints=n_dims_one)
starts = op.Constant(value_ints=n_dims_zero)
axes = op.Constant(value_ints=n_dims_axes)

delta = op.Slice(flatten_indices, axes=axes, starts=starts, ends=ends)
indices = op.Sub(indices, delta)

if self_rank == unbatched_rank:
pool_result = op.Squeeze(pool_result, op.Constant(value_ints=[0]))
indices = op.Squeeze(indices, op.Constant(value_ints=[0]))

return (pool_result, indices)
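
The worked example from the comment block can be reproduced against PyTorch directly (a standalone sketch; the ONNX-side flat indices are written out by hand here rather than computed by onnxscript):

```python
import torch
import torch.nn.functional as F

x = torch.ones(1, 2, 2)  # (N, C, L)
_, idx = F.max_pool1d(x, kernel_size=1, stride=1, return_indices=True)
assert idx.tolist() == [[[0, 1], [0, 1]]]  # PyTorch: indices local to each channel

# ONNX MaxPool would return flat offsets into the whole tensor: [[[0, 1], [2, 3]]].
# Slicing out the first flat index of each (N, C) row (0 and 2) and subtracting
# recovers the PyTorch layout, mirroring the Slice/Sub sequence above.
onnx_flat = torch.tensor([[[0, 1], [2, 3]]])
delta = onnx_flat[..., :1]  # first index of each row
assert torch.equal(onnx_flat - delta, idx)
```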


 def aten_max_pool3d_with_indices_backward(