Add ops(max_pool3d and max_pool3d_with_indices) | feat(atenlib) (#618)
1. Add max_pool3d and max_pool3d_with_indices ops.
2. Refactor the code so that max_pool2d and max_pool3d share the same
helper function, and move them to nn.py.
3. Refactor the code so that max_pool2d_with_indices and
max_pool3d_with_indices share the same helper function, and move them to
nn.py.
4. Add tests.

---------

Co-authored-by: Justin Chu <[email protected]>
fatcat-z and justinchuby authored Apr 11, 2023
1 parent b132176 commit 5eafe2a
Showing 3 changed files with 243 additions and 147 deletions.
317 changes: 183 additions & 134 deletions onnxscript/function_libs/torch_aten/ops/nn.py
@@ -668,66 +668,76 @@ def aten_max_pool1d_with_indices(
    raise NotImplementedError()


@torch_op("aten::max_pool2d", trace_only=True)
def aten_max_pool2d(
self: TFloatOrUInt8,
def _adjust_attributes_of_max_pool(
expand_size: int,
kernel_size: Sequence[int],
stride: Sequence[int] = (),
padding: Sequence[int] = (0, 0),
dilation: Sequence[int] = (1, 1),
ceil_mode: bool = False,
) -> TFloatOrUInt8:
"""max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor"""

# Torch prefer to use single number x for kerne,stride,pad,dilation on both side implicitly
# But ONNX needs pair number [x,y] to specify on each side explicitly
# For pool3d, this number should be 3
expand_size = 2

# The dilations should be [x, y]
if isinstance(dilation, int): # x -> [x, x]
stride: Sequence[int],
padding: Sequence[int],
dilation: Sequence[int],
) -> Tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int]]:
if isinstance(dilation, int):
dilations = [dilation] * expand_size
else: # already [x, y]
else:
dilations = dilation

# The kernel_shape should be [x, y]
if isinstance(kernel_size, int): # x -> [x, x]
if isinstance(kernel_size, int):
kernel_shape = [kernel_size] * expand_size
else: # assert(len(kernel_size)==2), already [x, y]
else:
kernel_shape = kernel_size

# The pads should be [w, x, y, z]
if isinstance(padding, int): # w -> [w, w, w, w]
if isinstance(padding, int):
pads = [padding] * expand_size * 2
elif len(padding) == 1: # [w] -> [w, w, w, w]
pads = padding * 4
elif len(padding) == 2: # [w, x] -> [w, x, w, x]
pads = padding * 2
else: # assert len(padding) == 4, already [w, x, y, z]
elif len(padding) == 1:
pads = padding * expand_size * 2
elif len(padding) == 2:
pads = padding * expand_size
else:
pads = padding

# The strides should be [x, y]
if isinstance(stride, int): # x -> [x, x]
if isinstance(stride, int):
strides = [stride] * expand_size
elif stride is None:
strides = kernel_shape
else:
strides = stride

return _aten_max_pool2d_onnx(self, kernel_shape, strides, pads, dilations, ceil_mode)
return (kernel_shape, strides, pads, dilations)


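For reference, a quick sketch of what the helper returns for typical torch-style arguments (editor's illustration, not part of the commit; the expected tuples follow from the expansion rules above, including the begin/end duplication of a 3-tuple padding):

# Illustration only: scalar attributes are broadcast per spatial dim,
# pads are duplicated into ONNX's begin/end-per-axis layout.
assert _adjust_attributes_of_max_pool(2, 3, 2, 1, 1) == ([3, 3], [2, 2], [1, 1, 1, 1], [1, 1])
assert _adjust_attributes_of_max_pool(3, 3, None, 0, 1) == ([3, 3, 3], [3, 3, 3], [0, 0, 0, 0, 0, 0], [1, 1, 1])
assert _adjust_attributes_of_max_pool(3, 2, 2, (1, 2, 3), 1) == ([2, 2, 2], [2, 2, 2], (1, 2, 3, 1, 2, 3), [1, 1, 1])
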
@torch_op("aten::max_pool2d", trace_only=True)
def aten_max_pool2d(
self: TFloatOrUInt8,
kernel_size: Sequence[int],
stride: Sequence[int] = (),
padding: Sequence[int] = (0, 0),
dilation: Sequence[int] = (1, 1),
ceil_mode: bool = False,
) -> TFloatOrUInt8:
"""max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor"""

# Torch prefers to use single number x for kernel, stride, pad and dilation on both sides implicitly.
# But ONNX needs to specify a pair of number [x,y] on each side explicitly.
expand_size = 2

kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool(
expand_size, kernel_size, stride, padding, dilation
)

return _aten_max_pool_onnx(self, kernel_shape, strides, pads, dilations, ceil_mode, 3)

@torch_op("aten::max_pool2d", private=True)
def _aten_max_pool2d_onnx(

@torch_op("internal::max_pool", private=True)
def _aten_max_pool_onnx(
self: TFloatOrUInt8,
kernel_shape: Sequence[int],
strides: Sequence[int],
pads: Sequence[int],
dilations: Sequence[int],
ceil_mode: bool,
unbatched_rank: int,
) -> TFloatOrUInt8:
self_rank = op.Size(op.Shape(self))
if self_rank == 3: # C,H,W -> N,C,H,W and N=1
if self_rank == unbatched_rank: # C,H,W -> N,C,H,W and N=1
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))

pool_result, _ = op.MaxPool(
Expand All @@ -739,122 +749,65 @@ def _aten_max_pool2d_onnx(
strides=strides,
)

if self_rank == 3:
if self_rank == unbatched_rank:
pool_result = op.Squeeze(pool_result, op.Constant(value_ints=[0]))

return pool_result

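The Unsqueeze/Squeeze pair above mirrors PyTorch's support for unbatched inputs. A hedged torch illustration of that contract (editor's example, not part of the commit):

import torch

x = torch.rand(3, 32, 32)  # unbatched C,H,W input
y = torch.nn.functional.max_pool2d(x, kernel_size=2)
print(y.shape)  # torch.Size([3, 16, 16]); the ONNX version gets there via Unsqueeze -> MaxPool -> Squeeze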

@torch_op("aten::max_pool2d_with_indices", trace_only=True)
def aten_max_pool2d_with_indices(
@torch_op("aten::max_pool3d", trace_only=True)
def aten_max_pool3d(
self: TFloatOrUInt8,
kernel_size: Sequence[int],
stride: Sequence[int] = (),
padding: Sequence[int] = (0, 0),
dilation: Sequence[int] = (1, 1),
ceil_mode: bool = False,
) -> Tuple[TFloatOrUInt8, INT64]:
"""max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)"""

# Torch prefer to use single number x for kerne,stride,pad,dilation on both side implicitly
# But ONNX needs pair number [x,y] to specify on each side explicitly
# For pool3d, this number should be 3
expand_size = 2

# The dilations should be [x, y]
if isinstance(dilation, int): # x -> [x, x]
dilations = [dilation] * expand_size
else: # already [x, y]
dilations = dilation

# The kernel_shape should be [x, y]
if isinstance(kernel_size, int): # x -> [x, x]
kernel_shape = [kernel_size] * expand_size
else: # assert(len(kernel_size)==2), already [x, y]
kernel_shape = kernel_size
) -> TFloatOrUInt8:
"""max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor"""

# The pads should be [w, x, y, z]
if isinstance(padding, int): # w -> [w, w, w, w]
pads = [padding] * expand_size * 2
elif len(padding) == 1: # [w] -> [w, w, w, w]
pads = padding * 4
elif len(padding) == 2: # [w, x] -> [w, x, w, x]
pads = padding * 2
else: # assert len(padding) == 4, already [w, x, y, z]
pads = padding
# Torch prefers to use single number x for kernel, stride, pad and dilation on both sides implicitly.
# But ONNX needs to specify a tuple of three ints for all sides explicitly.
expand_size = 3

# The strides should be [x, y]
if isinstance(stride, int): # x -> [x, x]
strides = [stride] * expand_size
elif stride is None:
strides = kernel_shape
else:
strides = stride

return _aten_max_pool2d_with_indices_onnx(
self, expand_size, kernel_shape, strides, pads, dilations, ceil_mode
kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool(
expand_size, kernel_size, stride, padding, dilation
)

return _aten_max_pool_onnx(self, kernel_shape, strides, pads, dilations, ceil_mode, 4)


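For context, the PyTorch call this op is exported from, shown with shapes (editor's torch example, not part of the commit):

import torch

x = torch.rand(1, 2, 4, 8, 8)  # N, C, D, H, W
y = torch.nn.functional.max_pool3d(x, kernel_size=2, stride=2)
print(y.shape)  # torch.Size([1, 2, 2, 4, 4])
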
@torch_op("aten::max_pool2d_with_indices", private=True)
def _aten_max_pool2d_with_indices_onnx(
@torch_op("aten::max_pool2d_with_indices", trace_only=True)
def aten_max_pool2d_with_indices(
self: TFloatOrUInt8,
expand_size: INT64,
kernel_shape: Sequence[int],
strides: Sequence[int],
pads: Sequence[int],
dilations: Sequence[int],
ceil_mode: bool,
kernel_size: Sequence[int],
stride: Sequence[int] = (),
padding: Sequence[int] = (0, 0),
dilation: Sequence[int] = (1, 1),
ceil_mode: bool = False,
) -> Tuple[TFloatOrUInt8, INT64]:
self_rank = op.Size(op.Shape(self))
if self_rank == 3: # C,H,W -> N,C,H,W and N=1
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))

pool_result, indices = op.MaxPool(
self,
ceil_mode=ceil_mode,
dilations=dilations,
kernel_shape=kernel_shape,
pads=pads,
strides=strides,
)
"""max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)"""

if self_rank == 3:
pool_result = op.Squeeze(pool_result, op.Constant(value_ints=[0]))
# Torch prefers to use single number x for kernel, stride, pad and dilation on both sides implicitly.
# But ONNX needs to specify a pair of number [x,y] on each side explicitly.
expand_size = 2

# Torch use relative position number for the second Channel data
# If align, need reduce size(Channel)
# e.g. [[8,3,10],[30,32,23]]-[0,18] -> [[8,3,10],[12,14,5]]
# 18 = H x W = 3 x 6
batches = op.Shape(self, start=0, end=1)
channels = op.Shape(self, start=1, end=2)
end = batches * channels
offset = op.Range(0, end, 1)
data_shape = op.Shape(self, start=2)
data_size = op.ReduceProd(data_shape)
offset = offset * data_size
new_shape = op.Expand(
op.Constant(value_ints=[1]), op.Reshape(expand_size, op.Constant(value_ints=[-1]))
kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool(
expand_size, kernel_size, stride, padding, dilation
)
new_shape = op.Concat(batches, channels, new_shape, axis=0)
offset = op.Reshape(offset, new_shape)
indices = indices - offset
if self_rank == 3:
indices = op.Squeeze(indices, op.Constant(value_ints=[0]))
return pool_result, indices


def aten_max_pool3d(
self: TensorType,
kernel_size: Sequence[int],
stride: Optional[Sequence[int]] = None,
padding: Sequence[int] = (0, 0, 0),
dilation: Sequence[int] = (1, 1, 1),
ceil_mode: bool = False,
) -> TensorType:
"""max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor"""

raise NotImplementedError()
return _aten_max_pool_with_indices_onnx(
self,
kernel_shape,
strides,
pads,
dilations,
ceil_mode,
3,
([1] * expand_size),
([0] * expand_size),
([2 + i for i in range(expand_size)]),
)


def aten_max_pool2d_with_indices_backward(
@@ -872,17 +825,113 @@ def aten_max_pool2d_with_indices_backward(
    raise NotImplementedError()


@torch_op("aten::max_pool3d_with_indices", trace_only=True)
def aten_max_pool3d_with_indices(
self: TensorType,
self: TFloatOrUInt8,
kernel_size: Sequence[int],
stride: Optional[Sequence[int]] = None,
padding: Sequence[int] = (0, 0, 0),
dilation: Sequence[int] = (1, 1, 1),
stride: Sequence[int] = (),
padding: Sequence[int] = (0, 0),
dilation: Sequence[int] = (1, 1),
ceil_mode: bool = False,
) -> tuple[TensorType, TensorType]:
) -> Tuple[TFloatOrUInt8, INT64]:
"""max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)"""

raise NotImplementedError()
# Torch prefers to use single number x for kernel, stride, pad and dilation on both sides implicitly.
# But ONNX needs to specify a tuple of three ints for all sides explicitly.
expand_size = 3

kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool(
expand_size, kernel_size, stride, padding, dilation
)

return _aten_max_pool_with_indices_onnx(
self,
kernel_shape,
strides,
pads,
dilations,
ceil_mode,
4,
([1] * expand_size),
([0] * expand_size),
([2 + i for i in range(expand_size)]),
)


@torch_op("internal::max_pool_with_indices", private=True)
def _aten_max_pool_with_indices_onnx(
self: TFloatOrUInt8,
kernel_size: Sequence[int],
stride: Sequence[int],
padding: Sequence[int],
dilation: Sequence[int],
ceil_mode: bool,
unbatched_rank: int,
n_dims_one: Sequence[int],
n_dims_zero: Sequence[int],
n_dims_axes: Sequence[int],
) -> Tuple[TFloatOrUInt8, INT64]:
self_rank = op.Size(op.Shape(self))
if self_rank == unbatched_rank:
self = op.Unsqueeze(self, axes=0)

pool_result, indices = op.MaxPool(
self,
ceil_mode=ceil_mode,
dilations=dilation,
kernel_shape=kernel_size,
pads=padding,
strides=stride,
)

# Simple but hacky way to get flattened indices values
# to be used to convert the indices values to non-flattened.
# In ONNX the indices are computed as a flatten 1-D tensor,
# so the values in indices are in [0, N x C x D1 x ... x Dn).
# To convert the indices to the same format used by PyTorch,
# we first execute a maxpool with a kernel and stride of 1 on the same input.
# This will result in a tensor of indices in which each index will have it's own value.
# Using this tensor as a reference, we extract the first index of each axis and subtract
# it from each index of this axis in the indices to convert.
# This step will result in a tensor where each dimension has values of indices within
# the dimension it is in.
# For Maxpool1d(kernel=1,stride=1,return_indices=True), with the input torch.ones(1,2,2).
# The computed indices are the following:
# output indices pytorch :
# [[0,1],
# [0,1]]
# output indices onnx:
# [[0,1],
# [2,3]]
# The purpose was to convert the indices from one format to the other to be able to match the results.
# So flattened_indices will have the value of each index and will be equal to :
# [[0,1],
# [2,3]]
# Then call Slice to get the first value of each line (so 0 and 2).
# And the subtraction executes :
# [[0-0,1-0],
# [2-2,3-2]]
# So indices results to the expected output which is :
# [[0,1],
# [0,1]]
# For more information :
# https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
_, flatten_indices = op.MaxPool(
self, dilations=dilation, kernel_shape=n_dims_one, strides=n_dims_one
)

ends = op.Constant(value_ints=n_dims_one)
starts = op.Constant(value_ints=n_dims_zero)
axes = op.Constant(value_ints=n_dims_axes)

delta = op.Slice(flatten_indices, axes=axes, starts=starts, ends=ends)
indices = op.Sub(indices, delta)

if self_rank == unbatched_rank:
pool_result = op.Squeeze(pool_result, op.Constant(value_ints=[0]))
indices = op.Squeeze(indices, op.Constant(value_ints=[0]))

return (pool_result, indices)
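
A minimal NumPy sketch of the index-offset trick described in the comment above, using its MaxPool1d example (editor's illustration, not part of the commit):

import numpy as np

# ONNX flat indices for input torch.ones(1, 2, 2) pooled with kernel=1, stride=1:
onnx_indices = np.array([[[0, 1], [2, 3]]])  # shape (N=1, C=2, L=2)
# Slice out the first index of each channel (the 'delta' above): [[[0], [2]]]
delta = onnx_indices[:, :, 0:1]
# Subtracting broadcasts per channel, yielding PyTorch-style local indices.
print(onnx_indices - delta)  # [[[0 1] [0 1]]]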


def aten_max_pool3d_with_indices_backward(