From 5eafe2ac111dfdc883b9259c61c8f1471d0ade98 Mon Sep 17 00:00:00 2001
From: Jay Zhang <36183870+fatcat-z@users.noreply.github.com>
Date: Wed, 12 Apr 2023 01:24:24 +0800
Subject: [PATCH] Add ops(max_pool3d and max_pool3d_with_indices) | feat(atenlib) (#618)

1. Add max_pool3d and max_pool3d_with_indices ops.
2. Refactor the code so max_pool2d and max_pool3d could share the same helper function, and move them to the nn.py file.
3. Refactor the code so max_pool2d_with_indices and max_pool3d_with_indices could share the same helper function, and move them to the nn.py file.
4. Add tests.

---------

Co-authored-by: Justin Chu
---
 onnxscript/function_libs/torch_aten/ops/nn.py | 317 ++++++++++--------
 .../function_libs/torch_aten/extra_opinfo.py  |  40 ++-
 .../torch_aten/ops_correctness_test.py        |  33 +-
 3 files changed, 243 insertions(+), 147 deletions(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/nn.py b/onnxscript/function_libs/torch_aten/ops/nn.py
index 15a8944bf..d5753fe9f 100644
--- a/onnxscript/function_libs/torch_aten/ops/nn.py
+++ b/onnxscript/function_libs/torch_aten/ops/nn.py
@@ -668,66 +668,76 @@ def aten_max_pool1d_with_indices(
     raise NotImplementedError()
 
 
-@torch_op("aten::max_pool2d", trace_only=True)
-def aten_max_pool2d(
-    self: TFloatOrUInt8,
+def _adjust_attributes_of_max_pool(
+    expand_size: int,
     kernel_size: Sequence[int],
-    stride: Sequence[int] = (),
-    padding: Sequence[int] = (0, 0),
-    dilation: Sequence[int] = (1, 1),
-    ceil_mode: bool = False,
-) -> TFloatOrUInt8:
-    """max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor"""
-
-    # Torch prefer to use single number x for kerne,stride,pad,dilation on both side implicitly
-    # But ONNX needs pair number [x,y] to specify on each side explicitly
-    # For pool3d, this number should be 3
-    expand_size = 2
-
-    # The dilations should be [x, y]
-    if isinstance(dilation, int):  # x -> [x, x]
+    stride: Sequence[int],
+    padding: Sequence[int],
+    dilation: Sequence[int],
+) -> Tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int]]:
+    if isinstance(dilation, int):
         dilations = [dilation] * expand_size
-    else:  # already [x, y]
+    else:
         dilations = dilation
 
-    # The kernel_shape should be [x, y]
-    if isinstance(kernel_size, int):  # x -> [x, x]
+    if isinstance(kernel_size, int):
         kernel_shape = [kernel_size] * expand_size
-    else:  # assert(len(kernel_size)==2), already [x, y]
+    else:
         kernel_shape = kernel_size
 
-    # The pads should be [w, x, y, z]
-    if isinstance(padding, int):  # w -> [w, w, w, w]
+    if isinstance(padding, int):
         pads = [padding] * expand_size * 2
-    elif len(padding) == 1:  # [w] -> [w, w, w, w]
-        pads = padding * 4
-    elif len(padding) == 2:  # [w, x] -> [w, x, w, x]
-        pads = padding * 2
-    else:  # assert len(padding) == 4, already [w, x, y, z]
+    elif len(padding) == 1:
+        pads = padding * expand_size * 2
+    elif len(padding) == 2:
+        pads = padding * expand_size
+    else:
         pads = padding
 
-    # The strides should be [x, y]
-    if isinstance(stride, int):  # x -> [x, x]
+    if isinstance(stride, int):
         strides = [stride] * expand_size
     elif stride is None:
         strides = kernel_shape
     else:
         strides = stride
 
-    return _aten_max_pool2d_onnx(self, kernel_shape, strides, pads, dilations, ceil_mode)
+    return (kernel_shape, strides, pads, dilations)
+
+
+@torch_op("aten::max_pool2d", trace_only=True)
+def aten_max_pool2d(
+    self: TFloatOrUInt8,
+    kernel_size: Sequence[int],
+    stride: Sequence[int] = (),
+    padding: Sequence[int] = (0, 0),
+    dilation: Sequence[int] = (1, 1),
+    ceil_mode: bool = False,
+) -> TFloatOrUInt8:
+    """max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor"""
+
+    # Torch prefers to use a single number x for kernel, stride, pad and dilation on both sides implicitly.
+    # But ONNX needs to specify a pair of numbers [x, y] on each side explicitly.
+    expand_size = 2
+
+    kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool(
+        expand_size, kernel_size, stride, padding, dilation
+    )
+
+    return _aten_max_pool_onnx(self, kernel_shape, strides, pads, dilations, ceil_mode, 3)
 
 
-@torch_op("aten::max_pool2d", private=True)
-def _aten_max_pool2d_onnx(
+
+@torch_op("internal::max_pool", private=True)
+def _aten_max_pool_onnx(
     self: TFloatOrUInt8,
     kernel_shape: Sequence[int],
     strides: Sequence[int],
     pads: Sequence[int],
     dilations: Sequence[int],
     ceil_mode: bool,
+    unbatched_rank: int,
 ) -> TFloatOrUInt8:
     self_rank = op.Size(op.Shape(self))
-    if self_rank == 3:  # C,H,W -> N,C,H,W and N=1
+    if self_rank == unbatched_rank:  # C,H,W -> N,C,H,W and N=1
         self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
 
     pool_result, _ = op.MaxPool(
@@ -739,122 +749,65 @@ def _aten_max_pool2d_onnx(
         strides=strides,
     )
 
-    if self_rank == 3:
+    if self_rank == unbatched_rank:
         pool_result = op.Squeeze(pool_result, op.Constant(value_ints=[0]))
 
     return pool_result
 
 
-@torch_op("aten::max_pool2d_with_indices", trace_only=True)
-def aten_max_pool2d_with_indices(
+@torch_op("aten::max_pool3d", trace_only=True)
+def aten_max_pool3d(
     self: TFloatOrUInt8,
     kernel_size: Sequence[int],
     stride: Sequence[int] = (),
-    padding: Sequence[int] = (0, 0),
-    dilation: Sequence[int] = (1, 1),
+    padding: Sequence[int] = (0, 0, 0),
+    dilation: Sequence[int] = (1, 1, 1),
     ceil_mode: bool = False,
-) -> Tuple[TFloatOrUInt8, INT64]:
-    """max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)"""
-
-    # Torch prefer to use single number x for kerne,stride,pad,dilation on both side implicitly
-    # But ONNX needs pair number [x,y] to specify on each side explicitly
-    # For pool3d, this number should be 3
-    expand_size = 2
-
-    # The dilations should be [x, y]
-    if isinstance(dilation, int):  # x -> [x, x]
-        dilations = [dilation] * expand_size
-    else:  # already [x, y]
-        dilations = dilation
-
-    # The kernel_shape should be [x, y]
-    if isinstance(kernel_size, int):  # x -> [x, x]
-        kernel_shape = [kernel_size] * expand_size
-    else:  # assert(len(kernel_size)==2), already [x, y]
-        kernel_shape = kernel_size
+) -> TFloatOrUInt8:
+    """max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor"""
 
-    # The pads should be [w, x, y, z]
-    if isinstance(padding, int):  # w -> [w, w, w, w]
-        pads = [padding] * expand_size * 2
-    elif len(padding) == 1:  # [w] -> [w, w, w, w]
-        pads = padding * 4
-    elif len(padding) == 2:  # [w, x] -> [w, x, w, x]
-        pads = padding * 2
-    else:  # assert len(padding) == 4, already [w, x, y, z]
-        pads = padding
+    # Torch prefers to use a single number x for kernel, stride, pad and dilation on both sides implicitly.
+    # But ONNX needs to specify a tuple of three ints for all sides explicitly.
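+    # For example, with expand_size = 3 below, _adjust_attributes_of_max_pool
+    # expands kernel_size=2 to kernel_shape=[2, 2, 2] and padding=1 to
+    # pads=[1, 1, 1, 1, 1, 1] (one begin and one end pad for each of D, H and W).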
+    expand_size = 3
 
-    # The strides should be [x, y]
-    if isinstance(stride, int):  # x -> [x, x]
-        strides = [stride] * expand_size
-    elif stride is None:
-        strides = kernel_shape
-    else:
-        strides = stride
-
-    return _aten_max_pool2d_with_indices_onnx(
-        self, expand_size, kernel_shape, strides, pads, dilations, ceil_mode
+    kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool(
+        expand_size, kernel_size, stride, padding, dilation
     )
 
+    return _aten_max_pool_onnx(self, kernel_shape, strides, pads, dilations, ceil_mode, 4)
+
 
-@torch_op("aten::max_pool2d_with_indices", private=True)
-def _aten_max_pool2d_with_indices_onnx(
+@torch_op("aten::max_pool2d_with_indices", trace_only=True)
+def aten_max_pool2d_with_indices(
     self: TFloatOrUInt8,
-    expand_size: INT64,
-    kernel_shape: Sequence[int],
-    strides: Sequence[int],
-    pads: Sequence[int],
-    dilations: Sequence[int],
-    ceil_mode: bool,
+    kernel_size: Sequence[int],
+    stride: Sequence[int] = (),
+    padding: Sequence[int] = (0, 0),
+    dilation: Sequence[int] = (1, 1),
+    ceil_mode: bool = False,
 ) -> Tuple[TFloatOrUInt8, INT64]:
-    self_rank = op.Size(op.Shape(self))
-    if self_rank == 3:  # C,H,W -> N,C,H,W and N=1
-        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
-
-    pool_result, indices = op.MaxPool(
-        self,
-        ceil_mode=ceil_mode,
-        dilations=dilations,
-        kernel_shape=kernel_shape,
-        pads=pads,
-        strides=strides,
-    )
+    """max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)"""
 
-    if self_rank == 3:
-        pool_result = op.Squeeze(pool_result, op.Constant(value_ints=[0]))
+    # Torch prefers to use a single number x for kernel, stride, pad and dilation on both sides implicitly.
+    # But ONNX needs to specify a pair of numbers [x, y] on each side explicitly.
+    expand_size = 2
 
-    # Torch use relative position number for the second Channel data
-    # If align, need reduce size(Channel)
-    # e.g. [[8,3,10],[30,32,23]]-[0,18] -> [[8,3,10],[12,14,5]]
-    #      18 = H x W = 3 x 6
-    batches = op.Shape(self, start=0, end=1)
-    channels = op.Shape(self, start=1, end=2)
-    end = batches * channels
-    offset = op.Range(0, end, 1)
-    data_shape = op.Shape(self, start=2)
-    data_size = op.ReduceProd(data_shape)
-    offset = offset * data_size
-    new_shape = op.Expand(
-        op.Constant(value_ints=[1]), op.Reshape(expand_size, op.Constant(value_ints=[-1]))
+    kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool(
+        expand_size, kernel_size, stride, padding, dilation
     )
-    new_shape = op.Concat(batches, channels, new_shape, axis=0)
-    offset = op.Reshape(offset, new_shape)
-    indices = indices - offset
 
-    if self_rank == 3:
-        indices = op.Squeeze(indices, op.Constant(value_ints=[0]))
-
-    return pool_result, indices
-
-
-def aten_max_pool3d(
-    self: TensorType,
-    kernel_size: Sequence[int],
-    stride: Optional[Sequence[int]] = None,
-    padding: Sequence[int] = (0, 0, 0),
-    dilation: Sequence[int] = (1, 1, 1),
-    ceil_mode: bool = False,
-) -> TensorType:
-    """max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor"""
-
-    raise NotImplementedError()
+    return _aten_max_pool_with_indices_onnx(
+        self,
+        kernel_shape,
+        strides,
+        pads,
+        dilations,
+        ceil_mode,
+        3,
+        ([1] * expand_size),
+        ([0] * expand_size),
+        ([2 + i for i in range(expand_size)]),
+    )
 
 
 def aten_max_pool2d_with_indices_backward(
@@ -872,17 +825,113 @@
     raise NotImplementedError()
 
 
+@torch_op("aten::max_pool3d_with_indices", trace_only=True)
 def aten_max_pool3d_with_indices(
-    self: TensorType,
+    self: TFloatOrUInt8,
     kernel_size: Sequence[int],
-    stride: Optional[Sequence[int]] = None,
+    stride: Sequence[int] = (),
     padding: Sequence[int] = (0, 0, 0),
     dilation: Sequence[int] = (1, 1, 1),
     ceil_mode: bool = False,
-) -> tuple[TensorType, TensorType]:
+) -> Tuple[TFloatOrUInt8, INT64]:
     """max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)"""
 
-    raise NotImplementedError()
+    # Torch prefers to use a single number x for kernel, stride, pad and dilation on both sides implicitly.
+    # But ONNX needs to specify a tuple of three ints for all sides explicitly.
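+    # The trailing arguments passed to the helper below follow from expand_size = 3:
+    # unbatched_rank=4 (a C,D,H,W input), n_dims_one=[1, 1, 1], n_dims_zero=[0, 0, 0]
+    # and n_dims_axes=[2, 3, 4] (the D, H, W axes of an N,C,D,H,W input).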
+    expand_size = 3
+
+    kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool(
+        expand_size, kernel_size, stride, padding, dilation
+    )
+
+    return _aten_max_pool_with_indices_onnx(
+        self,
+        kernel_shape,
+        strides,
+        pads,
+        dilations,
+        ceil_mode,
+        4,
+        ([1] * expand_size),
+        ([0] * expand_size),
+        ([2 + i for i in range(expand_size)]),
+    )
+
+
+@torch_op("internal::max_pool_with_indices", private=True)
+def _aten_max_pool_with_indices_onnx(
+    self: TFloatOrUInt8,
+    kernel_size: Sequence[int],
+    stride: Sequence[int],
+    padding: Sequence[int],
+    dilation: Sequence[int],
+    ceil_mode: bool,
+    unbatched_rank: int,
+    n_dims_one: Sequence[int],
+    n_dims_zero: Sequence[int],
+    n_dims_axes: Sequence[int],
+) -> Tuple[TFloatOrUInt8, INT64]:
+    self_rank = op.Size(op.Shape(self))
+    if self_rank == unbatched_rank:
+        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
+
+    pool_result, indices = op.MaxPool(
+        self,
+        ceil_mode=ceil_mode,
+        dilations=dilation,
+        kernel_shape=kernel_size,
+        pads=padding,
+        strides=stride,
+    )
+
+    # Simple but hacky way to get flattened indices values
+    # to be used to convert the indices values to non-flattened.
+    # In ONNX the indices are computed as a flattened 1-D tensor,
+    # so the values in indices are in [0, N x C x D1 x ... x Dn).
+    # To convert the indices to the same format used by PyTorch,
+    # we first execute a maxpool with a kernel and stride of 1 on the same input.
+    # This will result in a tensor of indices in which each index will have its own value.
+    # Using this tensor as a reference, we extract the first index of each axis and subtract
+    # it from each index of this axis in the indices to convert.
+    # This step will result in a tensor where each dimension has values of indices within
+    # the dimension it is in.
+    # For Maxpool1d(kernel=1,stride=1,return_indices=True), with the input torch.ones(1,2,2).
+    # The computed indices are the following:
+    # output indices pytorch:
+    #    [[0,1],
+    #     [0,1]]
+    # output indices onnx:
+    #    [[0,1],
+    #     [2,3]]
+    # The purpose was to convert the indices from one format to the other to be able to match the results.
+    # So flatten_indices will have the value of each index and will be equal to:
+    #    [[0,1],
+    #     [2,3]]
+    # Then call Slice to get the first value of each line (so 0 and 2).
+    # And the subtraction executes:
+    #    [[0-0,1-0],
+    #     [2-2,3-2]]
+    # So indices results in the expected output, which is:
+    #    [[0,1],
+    #     [0,1]]
+    # For more information:
+    # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
+    _, flatten_indices = op.MaxPool(
+        self, dilations=dilation, kernel_shape=n_dims_one, strides=n_dims_one
+    )
+
+    ends = op.Constant(value_ints=n_dims_one)
+    starts = op.Constant(value_ints=n_dims_zero)
+    axes = op.Constant(value_ints=n_dims_axes)
+
+    delta = op.Slice(flatten_indices, axes=axes, starts=starts, ends=ends)
+    indices = op.Sub(indices, delta)
+
+    if self_rank == unbatched_rank:
+        pool_result = op.Squeeze(pool_result, op.Constant(value_ints=[0]))
+        indices = op.Squeeze(indices, op.Constant(value_ints=[0]))
+
+    return (pool_result, indices)
 
 
 def aten_max_pool3d_with_indices_backward(
diff --git a/onnxscript/tests/function_libs/torch_aten/extra_opinfo.py b/onnxscript/tests/function_libs/torch_aten/extra_opinfo.py
index a3bff1246..5aec4b48d 100644
--- a/onnxscript/tests/function_libs/torch_aten/extra_opinfo.py
+++ b/onnxscript/tests/function_libs/torch_aten/extra_opinfo.py
@@ -163,9 +163,9 @@ def sample_inputs_convolution(op_info, device, dtype, requires_grad, **kwargs):
     )
 
 
-def sample_inputs_layer_norm(
-    op_info, device, dtype, requires_grad, **kwargs  # pylint: disable=unused-argument
-):
+def sample_inputs_layer_norm(op_info, device, dtype, requires_grad, **kwargs):
+    del op_info  # unused
+    del kwargs
     make_arg = functools.partial(
         torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
     )
@@ -201,9 +201,8 @@ def sample_inputs_layer_norm(
     )
 
 
-def sample_inputs_max_pool2d_with_indices(
-    op_info, device, dtype, requires_grad, **kwargs  # pylint: disable=unused-argument
-):
+def sample_inputs_max_pool2d_with_indices(op_info, device, dtype, requires_grad, **kwargs):
+    del op_info
     make_arg = functools.partial(
         torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=False
     )
@@ -215,9 +214,21 @@ def sample_inputs_max_pool2d_with_indices(
         yield opinfo_core.SampleInput(arg, kwargs=kwargs)
 
 
-def sample_inputs_col2im(
-    op_info, device, dtype, requires_grad, **kwargs  # pylint: disable=unused-argument
-):
+def sample_inputs_max_pool3d_with_indices(op_info, device, dtype, requires_grad, **kwargs):
+    del op_info
+    make_arg = functools.partial(
+        torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=False
+    )
+    params_generator = (
+        common_methods_invocations._TestParamsMaxPool3d()  # pylint: disable=protected-access
+    )
+    for (shape, memory_format), kwargs in params_generator.gen_input_params():
+        arg = make_arg(shape).to(memory_format=memory_format).requires_grad_(requires_grad)
+        yield opinfo_core.SampleInput(arg, kwargs=kwargs)
+
+
+def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs):
+    del op_info
     # input_shape, output_size, kernal, dilation, padding, stride
     cases = (
         (
@@ -319,11 +330,20 @@
     ),
     opinfo_core.OpInfo(
         "nn.functional.max_pool2d_with_indices",
-        aten_name="max_pool2d",
+        aten_name="max_pool2d_with_indices",
         supports_forward_ad=True,
         supports_fwgrad_bwgrad=True,
         dtypes=common_dtype.floating_types_and(torch.bfloat16),
         skips=(),
         sample_inputs_func=sample_inputs_max_pool2d_with_indices,
     ),
+    opinfo_core.OpInfo(
+        "nn.functional.max_pool3d_with_indices",
+        aten_name="max_pool3d_with_indices",
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        dtypes=common_dtype.floating_types_and(torch.bfloat16),
+        skips=(),
+        sample_inputs_func=sample_inputs_max_pool3d_with_indices,
+    ),
 ]
diff --git a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
index 1ba55fc53..89bf3eef9 100644
--- a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
+++ b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
@@ -346,7 +346,7 @@ def _gather_input_wrangler(
     return args, kwargs
 
 
-def _max_pool2d_input_wrangler(
+def _max_pool_input_wrangler(
     args: list[Any], kwargs: dict[str, Any]
 ) -> tuple[list[Any], dict[str, Any]]:
     # Remove return_indices argument because this op doesn't accept it
@@ -693,10 +693,15 @@ def _where_input_wrangler(
     "nn.functional.conv3d": core_ops.aten_conv3d,
     "nn.functional.gelu": nn_ops.aten_gelu,
     "nn.functional.linear": nn_ops.aten_linear,
-    "nn.functional.max_pool2d": (nn_ops.aten_max_pool2d, _max_pool2d_input_wrangler),
+    "nn.functional.max_pool2d": (nn_ops.aten_max_pool2d, _max_pool_input_wrangler),
     "nn.functional.max_pool2d_with_indices": (
         nn_ops.aten_max_pool2d_with_indices,
-        _max_pool2d_input_wrangler,
+        _max_pool_input_wrangler,
+    ),
+    "nn.functional.max_pool3d": (nn_ops.aten_max_pool3d, _max_pool_input_wrangler),
+    "nn.functional.max_pool3d_with_indices": (
+        nn_ops.aten_max_pool3d_with_indices,
+        _max_pool_input_wrangler,
     ),
     "nn.functional.scaled_dot_product_attention": nn_ops.aten_scaled_dot_product_attention,
     "nn.functional.scaled_dot_product_attention_bool_mask": nn_ops.aten_scaled_dot_product_attention_bool_mask,
@@ -1039,6 +1044,28 @@
         matcher=lambda sample: sample.kwargs.get("return_indices") is True,
         reason="this aten overload assume return_indices=False",
     ),
+    skip(
+        "nn.functional.max_pool3d",
+        matcher=lambda sample: sample.kwargs.get("ceil_mode") is True
+        and sample.kwargs.get("padding") == 1,
+        reason="FIXME: After https://github.com/microsoft/onnxruntime/issues/15446 is fixed",
+    ),
+    skip(
+        "nn.functional.max_pool3d",
+        matcher=lambda sample: sample.kwargs.get("return_indices") is True,
+        reason="this aten overload assumes return_indices=False",
+    ),
+    skip(
+        "nn.functional.max_pool3d_with_indices",
+        matcher=lambda sample: sample.kwargs.get("ceil_mode") is True
+        and sample.kwargs.get("padding") == 1,
+        reason="FIXME: After https://github.com/microsoft/onnxruntime/issues/15446 is fixed",
+    ),
+    skip(
+        "nn.functional.max_pool3d_with_indices",
+        matcher=lambda sample: sample.kwargs.get("return_indices") is False,
+        reason="this aten overload assumes return_indices=True",
+    ),
     skip(
         "nn.functional.nll_loss",
         matcher=lambda sample: "weight" in sample.kwargs,