Skip to content

Commit

Permalink
[Hexagon] Add schedule and test for conv2d_transpose_nchw (#11175)
Browse files Browse the repository at this point in the history
* Add test for registered scheduales - depthwise_conv2d

* added more test to depthwise_conv2

* adding new line at the end of the file

* reformatted the file

* resolve comments

* add schedule and tests for conv2d_transpose_nchw

* registering conv2d_transpose strategy and clean up test
  • Loading branch information
farshidsp authored May 3, 2022
1 parent 5c204c6 commit eb3ce91
Show file tree
Hide file tree
Showing 3 changed files with 203 additions and 0 deletions.
20 changes: 20 additions & 0 deletions python/tvm/relay/op/strategy/hexagon.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,26 @@ def softmax_strategy_hexagon(attrs, inputs, out_type, target):
return strategy


@conv2d_transpose_strategy.register("hexagon")
def conv2d_transpose_strategy_hexagon(attrs, inputs, out_type, target):
"""conv2d_transpose hexagon strategy"""
layout = attrs.data_layout
dilation = get_const_tuple(attrs.dilation)
groups = attrs.groups
assert layout == "NCHW", "only support nchw for now"
assert dilation == (1, 1), "not support dilate now"
strategy = _op.OpStrategy()
if groups == 1:
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.nn.conv2d_transpose_nchw),
wrap_topi_schedule(topi.hexagon.schedule_conv2d_transpose_nchw),
name="conv2d_transpose_nchw.generic",
)
else:
raise RuntimeError("Unsupported conv2d_transpose layout {}".format(layout))
return strategy


# --- Op schedule registration


Expand Down
26 changes: 26 additions & 0 deletions python/tvm/topi/hexagon/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""Schedule for conv2d"""

import tvm
from ..utils import traverse_inline


def schedule_conv2d_nhwc(outs):
Expand Down Expand Up @@ -60,3 +61,28 @@ def schedule_depthwise_conv2d_nchw(outs):

def schedule_depthwise_conv2d_nhwc(out):
return schedule_conv2d_nhwc(out)


def schedule_conv2d_transpose_nchw(outs):
"""Create schedule for tensors"""
outs = [outs] if isinstance(outs, tvm.te.tensor.Tensor) else outs
s = schedule_conv2d_nchw(outs)

def _callback(op):
if "unpack_nchwc" in op.tag:
conv_out = op.input_tensors[0]
# retrieve data
data_vec = conv_out.op.input_tensors[0]
if isinstance(data_vec, tvm.te.ComputeOp):
data_pad = data_vec.op.input_tensors[0]
data_dilate = data_pad.op.input_tensors[0]
s[data_dilate].compute_inline()
s[data_pad].compute_inline()
# retrieve kernel
kernel_vec = conv_out.op.input_tensors[1]
if isinstance(kernel_vec, tvm.te.ComputeOp):
kernel_transform = kernel_vec.op.input_tensors[0]
s[kernel_transform].compute_inline()

traverse_inline(s, outs[0].op, _callback)
return s
157 changes: 157 additions & 0 deletions tests/python/contrib/test_hexagon/topi/test_conv2d_transpose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test code for transposed convolution."""
import numpy as np
import tvm
import tvm.testing
from tvm import te
from tvm import topi
import tvm.topi.testing
from tvm.contrib.pickle_memoize import memoize
from tvm.topi.utils import get_const_tuple
from ..conftest import requires_hexagon_toolchain


# TODO Should add kernal to tvm.testing.fixture

random_seed = tvm.testing.parameter(0)


@tvm.testing.fixture
def shift_shape(batch):
return batch


@tvm.testing.fixture
def shift_shape(in_channel):
return in_channel


@tvm.testing.fixture
def shift_shape(in_size):
return in_size


@tvm.testing.fixture
def shift_shape(num_filter):
return num_filter


@tvm.testing.fixture
def shift_shape(stride):
return stride


@tvm.testing.fixture
def shift_shape(padding):
return padding


@tvm.testing.fixture
def shift_shape(output_padding):
return output_padding


class BaseConv2DTransposeTests:
@requires_hexagon_toolchain
def test_conv2d(
self,
hexagon_session,
batch,
in_channel,
in_size,
num_filter,
stride,
padding,
output_padding,
random_seed,
):

target_hexagon = tvm.target.hexagon("v68")

in_height, in_width = in_size
kernel_height, kernel_width = (1, 1)
stride_height, stride_width = stride
pad_top, pad_left, pad_bottom, pad_right = padding

A = te.placeholder((batch, in_channel, in_height, in_width), name="A")
W = te.placeholder((in_channel, num_filter, kernel_height, kernel_width), name="W")

a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
dtype = A.dtype

def get_ref_data():

np.random.seed(random_seed)
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
b_np = tvm.topi.testing.conv2d_transpose_nchw_python(
a_np, w_np, stride, padding, output_padding
)
c_np = np.maximum(b_np, 0)
return a_np, w_np, b_np, c_np

a_np, w_np, b_np, c_np = get_ref_data()

fcompute_args = (
A,
W,
[stride_height, stride_width],
[pad_top, pad_left, pad_bottom, pad_right],
A.dtype,
output_padding,
)

with tvm.target.Target(target_hexagon):
fcompute = topi.nn.conv2d_transpose_nchw
fschedule = topi.hexagon.schedule_conv2d_transpose_nchw
B = fcompute(*fcompute_args)
C = topi.nn.relu(B)
s1 = fschedule([B])
s2 = fschedule([C])

dev = hexagon_session.device

a = tvm.nd.array(a_np, dev)
w = tvm.nd.array(w_np, dev)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)

func1 = tvm.build(s1, [A, W, B], tvm.target.Target(target_hexagon, host=target_hexagon))
func2 = tvm.build(s2, [A, W, C], tvm.target.Target(target_hexagon, host=target_hexagon))

mod1 = hexagon_session.load_module(func1)
mod2 = hexagon_session.load_module(func2)

mod1(a, w, b)
mod2(a, w, c)
tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)


class TestConv2DTranspose(BaseConv2DTransposeTests):

(batch, in_channel, in_size, num_filter, stride) = tvm.testing.parameters(
(1, 3, (224, 224), 1, (1, 1)),
(1, 8, (224, 224), 1, (1, 1)),
(1, 512, (8, 1), 128, (31, 1)),
(1, 32, (8192, 1), 1, (1, 1)),
)

padding = tvm.testing.parameter((0, 0, 0, 0))
output_padding = tvm.testing.parameter((0, 0))

0 comments on commit eb3ce91

Please sign in to comment.