Skip to content

Commit

Permalink
[TOPI] Add valid auto tvm for Intel Graphics (#4078)
Browse files Browse the repository at this point in the history
* add valid autotune

* fix pylint
  • Loading branch information
Laurawly authored and kevinthesun committed Oct 9, 2019
1 parent f2abd9f commit 4d875d1
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions topi/python/topi/intel_graphics/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from tvm.autotvm.task.topi_integration import deserialize_args
from tvm.autotvm.task import get_config
from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, conv2d_infer_layout
from ..nn.util import get_pad_tuple
from ..nn.depthwise_conv2d import depthwise_conv2d_nchw
Expand Down Expand Up @@ -153,6 +155,38 @@ def tile_and_bind3d(s, tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None
s[tensor].bind(xi, thread_x)
return xi, thread_z, thread_y, thread_x

# Define template function for autotvm task
# We define schedule template in this function instead of
# declaration function since actual input arguments need
# to be altered by the schedule selected.
@autotvm.task.register("topi_intel_graphics_conv2d_NCHWc")
def __topi_nn_conv2d_NCHWc(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
data, kernel, strides, padding, dilation, layout, dtype = deserialize_args(args)
raw_data_shape = get_const_tuple(data.shape)
raw_kernel_shape = get_const_tuple(kernel.shape)

# get config here
cfg = get_config()
_create_schedule_template(cfg, data, kernel, strides, padding, dilation, layout)
cfg.add_flop(1)

# change shape with the value in config
ic_bn = cfg["tile_ic"].val if hasattr(cfg["tile_ic"], "val") else cfg["tile_ic"].size[-1]
oc_bn = cfg["tile_oc"].val if hasattr(cfg["tile_oc"], "val") else cfg["tile_oc"].size[-1]

new_data_shape = (raw_data_shape[0], raw_data_shape[1] // ic_bn,
raw_data_shape[2], raw_data_shape[3], ic_bn)
new_kernel_shape = (raw_kernel_shape[0] // oc_bn, raw_kernel_shape[1] // ic_bn,
raw_kernel_shape[2], raw_kernel_shape[3], ic_bn, oc_bn)
new_data = tvm.placeholder(new_data_shape, data.dtype)
new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype)

C = _decl_cl_spatialpack_NCHWc(cfg, new_data, new_kernel, strides, padding, dilation, dtype)
s = _schedule_conv2d_NCHWc(cfg, [C])

return s, [new_data, new_kernel, C]

@conv2d_alter_layout.register(["intel_graphics"])
def _alter_conv2d_layout(attrs, inputs, tinfo, F):
import nnvm.symbol as sym
Expand Down

0 comments on commit 4d875d1

Please sign in to comment.