[TOPI][Hexagon] Implement global_avg_pool2d for hexagon (#13614)
* [TOPI][Hexagon] Implement global_avg_pool2d for hexagon

* Fix name

* Fix lint issues

* Use get_hexagon_target()
trahman-quic authored Dec 15, 2022
1 parent ce97138 commit cdb4eea
Showing 7 changed files with 341 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/tvm/topi/hexagon/qnn/__init__.py
@@ -28,3 +28,4 @@
from .nn import *
from .qdepthwise_conv2d_slice import qdepthwise_conv2d_compute, qdepthwise_conv2d_schedule
from .adaptive_avg_pool1d import *
from .global_avg_pool2d import *
95 changes: 95 additions & 0 deletions python/tvm/topi/hexagon/qnn/global_avg_pool2d.py
@@ -0,0 +1,95 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Assumptions:
1) The input is in NCHW layout. Squeezenet is the only model that calls
   nn.global_avg_pool2d, and the only layout it uses is 'NCHW'.
2) Both the input and output dtypes are uint8, and the quantization
   parameters are provided to the op.
3) The input is assumed to always be a multiple of the fixed chunk 32c8h8w.
"""

from tvm import te
from tvm import tir
from ..utils import get_layout_transform_fn, get_fixed_point_value, saturate


def global_avg_pool2d_u8(
data: te.Tensor,
odtype: str,
input_zero_point: int,
input_scale: float,
output_zero_point: int,
output_scale: float,
):
    """Compute uint8 global_avg_pool2d with fixed-point requantization."""
input_b, input_c, input_h, input_w = data.shape
oshape = (input_b, input_c) + (1, 1)

if input_h * input_w < 256:
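        # A uint8 sum of fewer than 256 elements fits in 16 bits
        # (255 * 255 < 2**16); larger pools need a 32-bit accumulator.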
bits = "16"
else:
bits = "32"

if odtype == "uint8":
temp_dtype = "uint" + bits
elif odtype == "int8":
temp_dtype = "int" + bits
else:
        raise RuntimeError(f"Unsupported output dtype '{odtype}'")

pool_area = input_h * input_w
rh_r = te.reduce_axis((0, input_h), name="rh_r")
rw_r = te.reduce_axis((0, input_w), name="rw_r")

scale_with_area = input_scale / (output_scale * int(pool_area))
scale_fixed_point, rsh = get_fixed_point_value(scale_with_area, "int16")
corr = (output_zero_point << rsh) - input_zero_point * pool_area * scale_fixed_point
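    # scale_fixed_point / 2**rsh approximates scale_with_area, so
    #   (sum * scale_fixed_point + corr) >> rsh
    #     ~= (sum - input_zero_point * pool_area) * scale_with_area + output_zero_point,
    # i.e. the dequantize -> average -> requantize chain folded into one
    # multiply, add, and shift.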

sum_compute = te.compute(
oshape,
lambda n, c, h, w: te.sum(
data[n, c, h + rh_r, w + rw_r].astype(temp_dtype), axis=[rh_r, rw_r]
),
name="sum",
)

avg_compute = te.compute(
oshape,
lambda n, c, h, w: saturate(
((sum_compute[n, c, h, w] * scale_fixed_point) + corr) >> rsh, odtype
).astype(odtype),
name="global_avg_pool2d",
)

return avg_compute


def stir_global_avg_pool2d_u8_schedule(outs: te.Tensor, ins: te.Tensor, input_layout: str):
    """STIR schedule for the uint8 global_avg_pool2d compute."""
func = te.create_prim_func([ins, outs])
s = tir.Schedule(func)

sum_block = s.get_block("sum")

    # The input is a multiple of the fixed chunk, but the output is NxCx1x1,
    # so transform_layout is applied only to the input.
input_transformed_layout = get_layout_transform_fn(input_layout)
s.transform_layout(sum_block, buffer=("read", 0), index_map=input_transformed_layout)

return s
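For reference, a minimal NumPy model of the requantization above; the exact
normalization inside get_fixed_point_value may differ, so this sketch assumes
an int16 mantissa with a derived right shift, and all names are illustrative:

import numpy as np

def ref_global_avg_pool2d_u8(data_u8, in_zp, in_scale, out_zp, out_scale):
    """NumPy reference for the fixed-point average above (illustrative only)."""
    _, _, h, w = data_u8.shape
    area = h * w
    scale = in_scale / (out_scale * area)
    # Assumed normalization: mantissa in [2**14, 2**15), which fits an int16.
    rsh = 14 - int(np.floor(np.log2(scale)))
    fp = int(scale * (1 << rsh))
    corr = (out_zp << rsh) - in_zp * area * fp
    acc = data_u8.astype(np.int64).sum(axis=(2, 3), keepdims=True)
    return np.clip((acc * fp + corr) >> rsh, 0, 255).astype(np.uint8)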
1 change: 1 addition & 0 deletions python/tvm/topi/hexagon/slice_ops/__init__.py
@@ -36,3 +36,4 @@
from .tanh import tanh_te_compute, tanhf16_schedule
from .dwconv2d import *
from .depth_to_space import d2s_compute, d2s_schedule
from .global_avg_pool2d import *
52 changes: 52 additions & 0 deletions python/tvm/topi/hexagon/slice_ops/global_avg_pool2d.py
@@ -0,0 +1,52 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Assumptions:
1) The input is in NCHW layout. Squeezenet is the only model that calls
   nn.global_avg_pool2d, and the only layout it uses is 'NCHW'.
2) The op takes the input data as an argument.
3) Both the input and output dtypes are float16.
4) The input is assumed to always be a multiple of the fixed chunk 32c8h4w.
"""

from tvm import te
from tvm import tir
from tvm import topi
from ..utils import get_layout_transform_fn


def global_avg_pool2d(
data: te.Tensor,
):
    """Global average pooling over H and W for an NCHW input."""
return topi.nn.global_pool(data, "avg", "NCHW")


def stir_global_avg_pool2d_schedule(outs: te.Tensor, ins: te.Tensor, input_layout: str):
    """STIR schedule for the float16 global_avg_pool2d compute."""
func = te.create_prim_func([ins, outs])
s = tir.Schedule(func)

sum_block = s.get_block("adaptive_pool_sum")

    # The input is a multiple of the fixed chunk, but the output is NxCx1x1,
    # so transform_layout is applied only to the input.
input_transformed_layout = get_layout_transform_fn(input_layout)
s.transform_layout(sum_block, buffer=("read", 0), index_map=input_transformed_layout)

return s
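For reference, the two entry points above compose roughly as follows,
mirroring the float16 path of the test file below:

from tvm import te
import tvm.topi.hexagon.slice_ops as sl

a_tensor = te.placeholder((1, 32, 8, 8), name="a_tensor", dtype="float16")
m_tensor = sl.global_avg_pool2d(a_tensor)
schedule = sl.stir_global_avg_pool2d_schedule(m_tensor, a_tensor, "nchw-32c8h4w-2d")
# schedule.mod is an IRModule that can then be compiled with tvm.build(...)
# for a Hexagon target.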
12 changes: 12 additions & 0 deletions python/tvm/topi/hexagon/utils.py
@@ -136,6 +136,14 @@ def ncw_32c64w_2d(n, c, w):
return [n, c // 32, w // 64, te.AXIS_SEPARATOR, c % 32, w % 64]


def nchw_32c8h8w_2d(n, c, h, w):
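    # Example: (n, c, h, w) = (0, 40, 9, 10)
    #   -> [0, 1, 1, 1, te.AXIS_SEPARATOR, 8, 1, 2]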
return [n, c // 32, h // 8, w // 8, te.AXIS_SEPARATOR, c % 32, h % 8, w % 8]


def nchw_32c8h4w_2d(n, c, h, w):
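    # Example: (n, c, h, w) = (0, 40, 9, 10)
    #   -> [0, 1, 1, 2, te.AXIS_SEPARATOR, 8, 1, 2]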
return [n, c // 32, h // 8, w // 4, te.AXIS_SEPARATOR, c % 32, h % 8, w % 4]


def get_layout_transform_fn(layout):
"""Return index map function as per the layout string"""
if layout == "nhwc-8h2w32c2w-2d":
@@ -180,6 +188,10 @@ def get_layout_transform_fn(layout):
return ohwi32o_1d
if layout == "ncw-32c64w-2d":
return ncw_32c64w_2d
if layout == "nchw-32c8h8w-2d":
return nchw_32c8h8w_2d
if layout == "nchw-32c8h4w-2d":
return nchw_32c8h4w_2d
raise RuntimeError(f"Unexpected layout '{layout}'")


13 changes: 13 additions & 0 deletions tests/python/contrib/test_hexagon/infrastructure.py
@@ -277,6 +277,19 @@ def transform_numpy(arr_np, current_layout: str, new_layout: str):

raise RuntimeError(f"Unexpected new_layout '{new_layout}'")

if current_layout == "nchw":
if new_layout in ["nchw-32c8h8w-2d", "nchw-32c8h8w-1d"]:
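            # e.g. (1, 64, 16, 16) -> (1, 2, 2, 2, 32, 8, 8)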
n, c, h, w = arr_np.shape
return arr_np.reshape([n, c // 32, 32, h // 8, 8, w // 8, 8]).transpose(
0, 1, 3, 5, 2, 4, 6
)
if new_layout in ["nchw-32c8h4w-2d", "nchw-32c8h4w-1d"]:
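            # e.g. (1, 64, 16, 16) -> (1, 2, 2, 4, 32, 8, 4)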
n, c, h, w = arr_np.shape
return arr_np.reshape([n, c // 32, 32, h // 8, 8, w // 4, 4]).transpose(
0, 1, 3, 5, 2, 4, 6
)
raise RuntimeError(f"Unexpected new_layout '{new_layout}'")

raise RuntimeError(f"Unexpected current_layout '{current_layout}'")


@@ -0,0 +1,167 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Test code for float16 and uint8 global_avg_pool2d."""

import numpy as np

import tvm
from tvm import te
from tvm.topi.testing import adaptive_pool
import tvm.topi.hexagon.qnn as qn
import tvm.topi.hexagon.slice_ops as sl
from tvm.contrib.hexagon import allocate_hexagon_array
from ...infrastructure import transform_numpy, quantize_np, get_hexagon_target


SCALE_M_VAL = None
ZERO_POINT_M_VAL = None
SCALE_VAL = None
ZERO_POINT_VAL = None
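# These module-level values are filled in by the quantization fixtures below
# and consumed inside test_global_pool2d.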


class TestGlobalPool2D:
(input_shape,) = tvm.testing.parameters(
([1, 32, 8, 8],),
([1, 1056, 16, 16],),
)

    # The fixed-chunk input layout is nchw-32c8h8w-2d for uint8 and
    # nchw-32c8h4w-2d for float16; it might change later for optimization.
    # Since the output shape is NxCx1x1, which is not a multiple of the
    # fixed chunk, the output_layout is plain NCHW.
input_layout, output_layout, pool_type, layout, dtype = tvm.testing.parameters(
("nchw-32c8h8w-2d", "nchw", "avg", "NCHW", "uint8"),
("nchw-32c8h4w-2d", "nchw", "avg", "NCHW", "float16"),
)

@tvm.testing.fixture
def expected_output_np(
self,
input_np,
pool_type,
layout,
):
"""Generate expected output."""
        ref_np = adaptive_pool(
input_np,
(1, 1),
pool_type,
layout,
)
return ref_np

@tvm.testing.fixture
def input_np(self, input_shape, dtype):
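        # For uint8/int8, the raw input is generated in float32 here and
        # quantized by the quantize_input_np fixture below.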
if dtype in ("uint8", "int8"):
dtype = "float32"
return np.random.random(input_shape).astype(dtype)

@tvm.testing.fixture
def quantize_input_np(self, input_np, dtype):
if dtype in ("uint8", "int8"):
global ZERO_POINT_VAL, SCALE_VAL
input_np_quantized, SCALE_VAL, ZERO_POINT_VAL = quantize_np(input_np, dtype)
return input_np_quantized

@tvm.testing.fixture
def transformed_input_np(self, input_np, quantize_input_np, input_layout, layout, dtype):
if dtype == "float16":
return transform_numpy(input_np, layout.lower(), input_layout)
if dtype in ("uint8", "int8"):
return transform_numpy(quantize_input_np, layout.lower(), input_layout)

raise RuntimeError(f"Unsupported data type '{dtype}'")

@tvm.testing.fixture
def quantize_expected_output_np(self, expected_output_np, dtype):
if dtype in ("uint8", "int8"):
global ZERO_POINT_M_VAL, SCALE_M_VAL
out_ref_quantized, SCALE_M_VAL, ZERO_POINT_M_VAL = quantize_np(
expected_output_np, dtype
)

# Since output_layout is nchw, no transformation is needed.
return out_ref_quantized

@tvm.testing.requires_hexagon
def test_global_pool2d(
self,
dtype,
input_shape,
input_layout,
transformed_input_np,
expected_output_np,
quantize_expected_output_np,
hexagon_session,
):
a_tensor = te.placeholder(input_shape, name="a_tensor", dtype=dtype)

if dtype == "float16":
m_tensor = sl.global_avg_pool2d(a_tensor)
tir_schedule = sl.stir_global_avg_pool2d_schedule(m_tensor, a_tensor, input_layout)
elif dtype in ["uint8", "int8"]:
m_tensor = qn.global_avg_pool2d_u8(
a_tensor,
dtype,
ZERO_POINT_VAL,
SCALE_VAL,
ZERO_POINT_M_VAL,
SCALE_M_VAL,
)
tir_schedule = qn.stir_global_avg_pool2d_u8_schedule(m_tensor, a_tensor, input_layout)

sch = tir_schedule.mod

with tvm.transform.PassContext(opt_level=3):
func = tvm.build(
sch,
[a_tensor, m_tensor],
get_hexagon_target("v69"),
name="global_pool2d",
)

input_axis_separator = [4]

a_data_nd = allocate_hexagon_array(
hexagon_session.device,
data=transformed_input_np,
dtype=dtype,
axis_separators=input_axis_separator,
mem_scope="global.vtcm",
)

m_data_nd = allocate_hexagon_array(
hexagon_session.device,
expected_output_np.shape,
dtype=dtype,
)

mod = hexagon_session.load_module(func)
mod(a_data_nd, m_data_nd)

        # Copy the device output back to a NumPy array
m_data_np = m_data_nd.numpy()

if dtype == "float16":
np.testing.assert_allclose(expected_output_np, m_data_np, rtol=1e-3, atol=1e-3)
elif dtype in ["int8", "uint8"]:
np.testing.assert_allclose(quantize_expected_output_np, m_data_np, atol=1)


if __name__ == "__main__":
tvm.testing.main()
