-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TOPI][Hexagon] Implement global_avg_pool2d for hexagon (#13614)
* [TOPI][Hexagon] Implement global_avg_pool2d for hexagon * Fix name * Fix lint issues * Use get_hexagon_target()
- Loading branch information
1 parent
ce97138
commit cdb4eea
Showing
7 changed files
with
341 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
""" | ||
Assumptions: | ||
1) The input is in NCHW layout. Squeezenet is the only model that calls | ||
nn.global_avg_pool2d and the only layout it uses is 'NCHW'. | ||
2) Both input and output dtype is uint8 and | ||
quantization parameter is provided to the op. | ||
3) Input is assumed to always be multiple of fixed chunk 32c8h8w. | ||
""" | ||
|
||
from tvm import te | ||
from tvm import tir | ||
from ..utils import get_layout_transform_fn, get_fixed_point_value, saturate | ||
|
||
|
||
def global_avg_pool2d_u8( | ||
data: te.Tensor, | ||
odtype: str, | ||
input_zero_point: int, | ||
input_scale: float, | ||
output_zero_point: int, | ||
output_scale: float, | ||
): | ||
"""global_avg_pool2d""" | ||
input_b, input_c, input_h, input_w = data.shape | ||
oshape = (input_b, input_c) + (1, 1) | ||
|
||
if input_h * input_w < 256: | ||
bits = "16" | ||
else: | ||
bits = "32" | ||
|
||
if odtype == "uint8": | ||
temp_dtype = "uint" + bits | ||
elif odtype == "int8": | ||
temp_dtype = "int" + bits | ||
else: | ||
raise RuntimeError(f"Unsupported output dtype, {odtype}'") | ||
|
||
pool_area = input_h * input_w | ||
rh_r = te.reduce_axis((0, input_h), name="rh_r") | ||
rw_r = te.reduce_axis((0, input_w), name="rw_r") | ||
|
||
scale_with_area = input_scale / (output_scale * int(pool_area)) | ||
scale_fixed_point, rsh = get_fixed_point_value(scale_with_area, "int16") | ||
corr = (output_zero_point << rsh) - input_zero_point * pool_area * scale_fixed_point | ||
|
||
sum_compute = te.compute( | ||
oshape, | ||
lambda n, c, h, w: te.sum( | ||
data[n, c, h + rh_r, w + rw_r].astype(temp_dtype), axis=[rh_r, rw_r] | ||
), | ||
name="sum", | ||
) | ||
|
||
avg_compute = te.compute( | ||
oshape, | ||
lambda n, c, h, w: saturate( | ||
((sum_compute[n, c, h, w] * scale_fixed_point) + corr) >> rsh, odtype | ||
).astype(odtype), | ||
name="global_avg_pool2d", | ||
) | ||
|
||
return avg_compute | ||
|
||
|
||
def stir_global_avg_pool2d_u8_schedule(outs: te.Tensor, ins: te.Tensor, input_layout: str): | ||
"""Schedule""" | ||
func = te.create_prim_func([ins, outs]) | ||
s = tir.Schedule(func) | ||
|
||
sum_block = s.get_block("sum") | ||
|
||
# Input is multiple of fixed chunk but output is NxCx1x1 | ||
# Hence transform_layout is only applied on input | ||
input_transformed_layout = get_layout_transform_fn(input_layout) | ||
s.transform_layout(sum_block, buffer=("read", 0), index_map=input_transformed_layout) | ||
|
||
return s |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
""" | ||
Assumptions: | ||
1) The input is in NCHW layout. Squeezenet is the only model that calls | ||
nn.global_avg_pool2d and the only layout it uses is 'NCHW'. | ||
2) The op takes input data as an argument. | ||
3) Both input and output dtype is float32 and | ||
4) Input is assumed to always be multiple of fixed chunk 32c8h4w. | ||
""" | ||
|
||
from tvm import te | ||
from tvm import tir | ||
from tvm import topi | ||
from ..utils import get_layout_transform_fn | ||
|
||
|
||
def global_avg_pool2d( | ||
data: te.Tensor, | ||
): | ||
"""global_avg_pool2d""" | ||
return topi.nn.global_pool(data, "avg", "NCHW") | ||
|
||
|
||
def stir_global_avg_pool2d_schedule(outs: te.Tensor, ins: te.Tensor, input_layout: str): | ||
"""Schedule""" | ||
func = te.create_prim_func([ins, outs]) | ||
s = tir.Schedule(func) | ||
|
||
sum_block = s.get_block("adaptive_pool_sum") | ||
|
||
# Input is multiple of fixed chunk but output is NxCx1x1 | ||
# Hence transform_layout is only applied on input | ||
input_transformed_layout = get_layout_transform_fn(input_layout) | ||
s.transform_layout(sum_block, buffer=("read", 0), index_map=input_transformed_layout) | ||
|
||
return s |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
167 changes: 167 additions & 0 deletions
167
tests/python/contrib/test_hexagon/topi/slice_op/test_global_avg_pool2d.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
"""Test code for float16 and uint8 global_avg_pool2d.""" | ||
|
||
import numpy as np | ||
|
||
import tvm | ||
from tvm import te | ||
from tvm.topi.testing import adaptive_pool | ||
import tvm.topi.hexagon.qnn as qn | ||
import tvm.topi.hexagon.slice_ops as sl | ||
from tvm.contrib.hexagon import allocate_hexagon_array | ||
from ...infrastructure import transform_numpy, quantize_np, get_hexagon_target | ||
|
||
|
||
SCALE_M_VAL = None | ||
ZERO_POINT_M_VAL = None | ||
SCALE_VAL = None | ||
ZERO_POINT_VAL = None | ||
|
||
|
||
class TestGlobalPool2D: | ||
(input_shape,) = tvm.testing.parameters( | ||
([1, 32, 8, 8],), | ||
([1, 1056, 16, 16],), | ||
) | ||
|
||
# Fixed chunk layout is set as nchw-32c8h8w-2d for uint8 and nchw-32c8h4w-2d for float16. | ||
# For optimization, it might get changed later. | ||
# Since output shape will be NxCx1x1 which is not a | ||
# multiple of fixed-chunk, output_layout is NCHW. | ||
input_layout, output_layout, pool_type, layout, dtype = tvm.testing.parameters( | ||
("nchw-32c8h8w-2d", "nchw", "avg", "NCHW", "uint8"), | ||
("nchw-32c8h4w-2d", "nchw", "avg", "NCHW", "float16"), | ||
) | ||
|
||
@tvm.testing.fixture | ||
def expected_output_np( | ||
self, | ||
input_np, | ||
pool_type, | ||
layout, | ||
): | ||
"""Generate expected output.""" | ||
ref_np = tvm.topi.testing.adaptive_pool( | ||
input_np, | ||
(1, 1), | ||
pool_type, | ||
layout, | ||
) | ||
return ref_np | ||
|
||
@tvm.testing.fixture | ||
def input_np(self, input_shape, dtype): | ||
if dtype in ("uint8", "int8"): | ||
dtype = "float32" | ||
return np.random.random(input_shape).astype(dtype) | ||
|
||
@tvm.testing.fixture | ||
def quantize_input_np(self, input_np, dtype): | ||
if dtype in ("uint8", "int8"): | ||
global ZERO_POINT_VAL, SCALE_VAL | ||
input_np_quantized, SCALE_VAL, ZERO_POINT_VAL = quantize_np(input_np, dtype) | ||
return input_np_quantized | ||
|
||
@tvm.testing.fixture | ||
def transformed_input_np(self, input_np, quantize_input_np, input_layout, layout, dtype): | ||
if dtype == "float16": | ||
return transform_numpy(input_np, layout.lower(), input_layout) | ||
if dtype in ("uint8", "int8"): | ||
return transform_numpy(quantize_input_np, layout.lower(), input_layout) | ||
|
||
raise RuntimeError(f"Unsupported data type '{dtype}'") | ||
|
||
@tvm.testing.fixture | ||
def quantize_expected_output_np(self, expected_output_np, dtype): | ||
if dtype in ("uint8", "int8"): | ||
global ZERO_POINT_M_VAL, SCALE_M_VAL | ||
out_ref_quantized, SCALE_M_VAL, ZERO_POINT_M_VAL = quantize_np( | ||
expected_output_np, dtype | ||
) | ||
|
||
# Since output_layout is nchw, no transformation is needed. | ||
return out_ref_quantized | ||
|
||
@tvm.testing.requires_hexagon | ||
def test_global_pool2d( | ||
self, | ||
dtype, | ||
input_shape, | ||
input_layout, | ||
transformed_input_np, | ||
expected_output_np, | ||
quantize_expected_output_np, | ||
hexagon_session, | ||
): | ||
a_tensor = te.placeholder(input_shape, name="a_tensor", dtype=dtype) | ||
|
||
if dtype == "float16": | ||
m_tensor = sl.global_avg_pool2d(a_tensor) | ||
tir_schedule = sl.stir_global_avg_pool2d_schedule(m_tensor, a_tensor, input_layout) | ||
elif dtype in ["uint8", "int8"]: | ||
m_tensor = qn.global_avg_pool2d_u8( | ||
a_tensor, | ||
dtype, | ||
ZERO_POINT_VAL, | ||
SCALE_VAL, | ||
ZERO_POINT_M_VAL, | ||
SCALE_M_VAL, | ||
) | ||
tir_schedule = qn.stir_global_avg_pool2d_u8_schedule(m_tensor, a_tensor, input_layout) | ||
|
||
sch = tir_schedule.mod | ||
|
||
with tvm.transform.PassContext(opt_level=3): | ||
func = tvm.build( | ||
sch, | ||
[a_tensor, m_tensor], | ||
get_hexagon_target("v69"), | ||
name="global_pool2d", | ||
) | ||
|
||
input_axis_separator = [4] | ||
|
||
a_data_nd = allocate_hexagon_array( | ||
hexagon_session.device, | ||
data=transformed_input_np, | ||
dtype=dtype, | ||
axis_separators=input_axis_separator, | ||
mem_scope="global.vtcm", | ||
) | ||
|
||
m_data_nd = allocate_hexagon_array( | ||
hexagon_session.device, | ||
expected_output_np.shape, | ||
dtype=dtype, | ||
) | ||
|
||
mod = hexagon_session.load_module(func) | ||
mod(a_data_nd, m_data_nd) | ||
|
||
# Convert nd to np | ||
m_data_np = m_data_nd.numpy() | ||
|
||
if dtype == "float16": | ||
np.testing.assert_allclose(expected_output_np, m_data_np, rtol=1e-3, atol=1e-3) | ||
elif dtype in ["int8", "uint8"]: | ||
np.testing.assert_allclose(quantize_expected_output_np, m_data_np, atol=1) | ||
|
||
|
||
if __name__ == "__main__": | ||
tvm.testing.main() |