diff --git a/topi/python/topi/intel_graphics/conv2d.py b/topi/python/topi/intel_graphics/conv2d.py index 63cae4c7da8d..f8bb16f08c21 100644 --- a/topi/python/topi/intel_graphics/conv2d.py +++ b/topi/python/topi/intel_graphics/conv2d.py @@ -23,6 +23,8 @@ from tvm import autotvm from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity +from tvm.autotvm.task.topi_integration import deserialize_args +from tvm.autotvm.task import get_config from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, conv2d_infer_layout from ..nn.util import get_pad_tuple from ..nn.depthwise_conv2d import depthwise_conv2d_nchw @@ -153,6 +155,38 @@ def tile_and_bind3d(s, tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None s[tensor].bind(xi, thread_x) return xi, thread_z, thread_y, thread_x +# Define template function for autotvm task +# We define schedule template in this function instead of +# declaration function since actual input arguments need +# to be altered by the schedule selected. +@autotvm.task.register("topi_intel_graphics_conv2d_NCHWc") +def __topi_nn_conv2d_NCHWc(*args, **kwargs): + assert not kwargs, "Do not support kwargs in template function call" + data, kernel, strides, padding, dilation, layout, dtype = deserialize_args(args) + raw_data_shape = get_const_tuple(data.shape) + raw_kernel_shape = get_const_tuple(kernel.shape) + + # get config here + cfg = get_config() + _create_schedule_template(cfg, data, kernel, strides, padding, dilation, layout) + cfg.add_flop(1) + + # change shape with the value in config + ic_bn = cfg["tile_ic"].val if hasattr(cfg["tile_ic"], "val") else cfg["tile_ic"].size[-1] + oc_bn = cfg["tile_oc"].val if hasattr(cfg["tile_oc"], "val") else cfg["tile_oc"].size[-1] + + new_data_shape = (raw_data_shape[0], raw_data_shape[1] // ic_bn, + raw_data_shape[2], raw_data_shape[3], ic_bn) + new_kernel_shape = (raw_kernel_shape[0] // oc_bn, raw_kernel_shape[1] // ic_bn, + raw_kernel_shape[2], raw_kernel_shape[3], ic_bn, oc_bn) + new_data = tvm.placeholder(new_data_shape, data.dtype) + new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype) + + C = _decl_cl_spatialpack_NCHWc(cfg, new_data, new_kernel, strides, padding, dilation, dtype) + s = _schedule_conv2d_NCHWc(cfg, [C]) + + return s, [new_data, new_kernel, C] + @conv2d_alter_layout.register(["intel_graphics"]) def _alter_conv2d_layout(attrs, inputs, tinfo, F): import nnvm.symbol as sym