From 9dcd40d61f90a131d4b0c13b9babc5b2dd2b6d2b Mon Sep 17 00:00:00 2001 From: Ilya Gozman <92577591+ilyag-grovety@users.noreply.github.com> Date: Wed, 5 Apr 2023 12:32:24 +0400 Subject: [PATCH] [microNPU] Add support for ResizeNearestNeighbor with half_pixel_centers=True (#14401) This PR involves supporting the legalization case of RESIZE_NEAREST_NEIGHBOR where coordinate transformation mode is set to half_pixel. --- python/tvm/relay/op/contrib/ethosu.py | 15 ++++++++++++--- .../python/contrib/test_ethosu/test_codegen.py | 18 ++++++++++++++---- .../contrib/test_ethosu/test_legalize.py | 18 ++++++++++++------ 3 files changed, 38 insertions(+), 13 deletions(-) diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index a6d959c98b01..d74140da5db2 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -1671,7 +1671,18 @@ def check_compatible_size(mode, method, upscale_size, ifm_size): return False if self.method not in ("nearest_neighbor", "linear"): return False - if self.coordinate_transformation_mode not in ("asymmetric", "align_corners"): + if self.coordinate_transformation_mode not in ( + "asymmetric", + "align_corners", + "half_pixel", + ): + return False + if ( + self.coordinate_transformation_mode == "half_pixel" + and self.rounding_method != "round_prefer_ceil" + or self.coordinate_transformation_mode != "half_pixel" + and self.rounding_method != "" + ): return False if not check_compatible_size( self.coordinate_transformation_mode, @@ -1680,8 +1691,6 @@ def check_compatible_size(mode, method, upscale_size, ifm_size): self.ifm.shape[1:3], ): return False - if self.rounding_method != "": - return False if self.out_dtype and self.out_dtype != "int8": return False return True diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 6eb382d8f588..c68dde1288b8 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -1085,17 +1085,27 @@ def squeeze_func(x): @pytest.mark.parametrize("accel_type", ACCEL_TYPES) @pytest.mark.parametrize( - "ifm_shape,size", - [[(1, 2, 2, 1), (4, 4)], [(1, 4, 7, 3), (8, 14)], [(1, 3, 5, 3), (3, 5)]], + "ifm_shape,size,half_pixel", + [ + [(1, 2, 2, 1), (4, 4), False], + [(1, 2, 2, 1), (4, 4), True], + [(1, 4, 7, 3), (8, 14), False], + [(1, 3, 5, 3), (3, 5), False], + [(1, 6, 6, 96), (12, 12), False], + [(1, 6, 6, 96), (12, 12), True], + ], ) -def test_tflite_resize2d_nearest_neighbor(accel_type, ifm_shape, size): +def test_tflite_resize2d_nearest_neighbor(accel_type, ifm_shape, size, half_pixel): np.random.seed(0) align_corners = False @tf.function def resize_model(x): return tf.compat.v1.image.resize_nearest_neighbor( - x, size, align_corners=align_corners, half_pixel_centers=False + x, + size, + align_corners=align_corners, + half_pixel_centers=half_pixel, ) infra.compare_tvm_with_tflite( diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 0bd9c1ac3bf4..594f4a0e2aef 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -2342,14 +2342,17 @@ def verify(ext_func): @pytest.mark.parametrize( - "ifm_shape,size", + "ifm_shape,size,half_pixel", [ - [(1, 2, 2, 1), (4, 4)], - [(1, 4, 7, 3), (8, 14)], - [(1, 3, 5, 3), (3, 5)], + [(1, 2, 2, 1), (4, 4), False], + [(1, 2, 2, 1), (4, 4), True], + [(1, 4, 7, 3), (8, 14), False], + [(1, 3, 5, 3), (3, 5), False], + [(1, 6, 6, 96), (12, 12), False], + [(1, 6, 6, 96), (12, 12), True], ], ) -def test_tflite_resize2d_nearest_neighbor(ifm_shape, size): +def test_tflite_resize2d_nearest_neighbor(ifm_shape, size, half_pixel): align_corners = False dtype = "int8" @@ -2357,7 +2360,10 @@ def create_tflite_graph(): @tf.function def resize_model(x): return tf.compat.v1.image.resize_nearest_neighbor( - x, size, align_corners=align_corners, half_pixel_centers=False + x, + size, + align_corners=align_corners, + half_pixel_centers=half_pixel, ) concrete_func = resize_model.get_concrete_function(