From 5fefedcbb62eb2d359e146789f1db40f59c36b68 Mon Sep 17 00:00:00 2001
From: Leyuan Wang
Date: Fri, 22 Jun 2018 11:46:47 -0700
Subject: [PATCH] [TOPI] Conv2d Added and Optimized for Intel HD Graphics (#1290)

---
 include/tvm/build_module.h                  |   6 +-
 python/tvm/target.py                        |   8 +-
 src/codegen/build_module.cc                 |   6 +-
 src/codegen/codegen_opencl.cc               |   2 +-
 src/runtime/opencl/opencl_device_api.cc     |   2 +-
 tests/python/integration/test_ewise.py      |   2 +-
 tests/python/unittest/test_lang_target.py   |   2 +-
 topi/python/topi/__init__.py                |   1 +
 topi/python/topi/cuda/pooling.py            |   5 +-
 topi/python/topi/intel_graphics/__init__.py |   5 +
 topi/python/topi/intel_graphics/conv2d.py   | 559 ++++++++++++++++++++
 11 files changed, 581 insertions(+), 17 deletions(-)
 create mode 100644 topi/python/topi/intel_graphics/__init__.py
 create mode 100644 topi/python/topi/intel_graphics/conv2d.py

diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h
index d81462982965..96b876fe92f0 100644
--- a/include/tvm/build_module.h
+++ b/include/tvm/build_module.h
@@ -157,9 +157,9 @@ EXPORT Target rasp(const std::vector<std::string>& options =
 EXPORT Target mali(const std::vector<std::string>& options =
                    std::vector<std::string>());
 
-/*! \return A target for Intel GPU */
-EXPORT Target intel_gpu(const std::vector<std::string>& options =
-                        std::vector<std::string>());
+/*! \return A target for Intel Graphics */
+EXPORT Target intel_graphics(const std::vector<std::string>& options =
+                             std::vector<std::string>());
 
 /*! \return A target for stackvm */
 EXPORT Target stackvm(const std::vector<std::string>& options =
diff --git a/python/tvm/target.py b/python/tvm/target.py
index d62a7be7aac8..7706fb312f9d 100644
--- a/python/tvm/target.py
+++ b/python/tvm/target.py
@@ -76,7 +76,7 @@ class Target(NodeBase):
     - :any:`tvm.target.cuda` create CUDA target
     - :any:`tvm.target.rocm` create ROCM target
     - :any:`tvm.target.mali` create Mali target
-    - :any:`tvm.target.intel_gpu` create Intel GPU target
+    - :any:`tvm.target.intel_graphics` create Intel Graphics target
     """
     def __init__(self, handle):
         super(Target, self).__init__(handle)
@@ -402,15 +402,15 @@ def mali(options=None):
     return _api_internal._TargetCreate("opencl", *opts)
 
 
-def intel_gpu(options=None):
-    """Returns an Intel GPU target.
+def intel_graphics(options=None):
+    """Returns an Intel Graphics target.
 
     Parameters
     ----------
     options : str or list of str
         Additional options
     """
-    opts = ["-device=intel_gpu"]
+    opts = ["-device=intel_graphics"]
     opts = _merge_opts(opts, options)
     return _api_internal._TargetCreate("opencl", *opts)
 
diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc
index 318e231fc67c..ea94f05d4779 100644
--- a/src/codegen/build_module.cc
+++ b/src/codegen/build_module.cc
@@ -76,7 +76,7 @@ Target CreateTarget(const std::string& target_name,
     t->keys_array.push_back(ir::StringImm::make("rocm"));
     t->keys_array.push_back(ir::StringImm::make("gpu"));
     t->max_num_threads = 256;
-    if (t->device_name == "intel_gpu") {
+    if (t->device_name == "intel_graphics") {
       t->thread_warp_size = 16;
     }
   } else if (target_name == "metal" || target_name == "vulkan") {
@@ -274,9 +274,9 @@ Target mali(const std::vector<std::string>& options) {
   }));
 }
 
-Target intel_gpu(const std::vector<std::string>& options) {
+Target intel_graphics(const std::vector<std::string>& options) {
   return CreateTarget("opencl", MergeOptions(options, {
-    "-device=intel_gpu"
+    "-device=intel_graphics"
   }));
 }
 
diff --git a/src/codegen/codegen_opencl.cc b/src/codegen/codegen_opencl.cc
index 64e10ffbe290..20053c829be6 100644
--- a/src/codegen/codegen_opencl.cc
+++ b/src/codegen/codegen_opencl.cc
@@ -159,7 +159,7 @@ void CodeGenOpenCL::PrintStorageSync(const Call* op) {
   const std::string& sync = op->args[0].as<StringImm>()->value;
   if (sync == "warp") {
     this->PrintIndent();
-    this->stream << "sub_group_barrier(CLK_LOCAL_MEM_FENCE);\n";
+    this->stream << "barrier(CLK_LOCAL_MEM_FENCE);\n";
   } else if (sync == "shared") {
     this->PrintIndent();
     this->stream << "barrier(CLK_LOCAL_MEM_FENCE);\n";
diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc
index 759f074b2fc1..b30b058d3acc 100644
--- a/src/runtime/opencl/opencl_device_api.cc
+++ b/src/runtime/opencl/opencl_device_api.cc
@@ -40,7 +40,7 @@ void OpenCLWorkspace::GetAttr(
     }
     case kWarpSize: {
       /* TODO: the warp size of OpenCL device is not always 1
-         e.g. Intel GPU has a sub group concept which contains 8 - 32 work items,
+         e.g. Intel Graphics has a sub group concept which contains 8 - 32 work items,
          corresponding to the number of SIMD entries the hardware configures.
          We need to figure out a way to query this information from the hardware.
       */
diff --git a/tests/python/integration/test_ewise.py b/tests/python/integration/test_ewise.py
index e9c69bb57d9f..64d164c5e644 100644
--- a/tests/python/integration/test_ewise.py
+++ b/tests/python/integration/test_ewise.py
@@ -34,7 +34,7 @@ def check_device(device, host="stackvm"):
         np.testing.assert_allclose(
             b.asnumpy(), np.exp(a.asnumpy()), rtol=1e-5)
 
-    check_device("opencl -device=intel_gpu")
+    check_device("opencl -device=intel_graphics")
     check_device("cuda", "llvm")
     check_device("vulkan")
 
diff --git a/tests/python/unittest/test_lang_target.py b/tests/python/unittest/test_lang_target.py
index 71b4f79a58a2..5f77010fa751 100644
--- a/tests/python/unittest/test_lang_target.py
+++ b/tests/python/unittest/test_lang_target.py
@@ -47,7 +47,7 @@ def test_target_string_parse():
 
     assert str(target) == str(tvm.target.cuda("-libs=cublas,cudnn"))
 
-    assert tvm.target.intel_gpu().device_name == "intel_gpu"
+    assert tvm.target.intel_graphics().device_name == "intel_graphics"
 
 if __name__ == "__main__":
     test_target_dispatch()
diff --git a/topi/python/topi/__init__.py b/topi/python/topi/__init__.py
index 59549f89905b..d58b37b04518 100644
--- a/topi/python/topi/__init__.py
+++ b/topi/python/topi/__init__.py
@@ -26,6 +26,7 @@
 from . import cuda
 from . import rasp
 from . import mali
+from . import intel_graphics
 from . import opengl
 from . import util
 from . import rocm
diff --git a/topi/python/topi/cuda/pooling.py b/topi/python/topi/cuda/pooling.py
index 017408f14b89..ed4419491428 100644
--- a/topi/python/topi/cuda/pooling.py
+++ b/topi/python/topi/cuda/pooling.py
@@ -33,9 +33,8 @@ def _schedule(Pool):
         else:
             Out = outs[0].op.output(0)
             s[Pool].set_scope("local")
-        i, c, h, w = s[Out].op.axis
-        by, ty = s[Out].split(i, factor=num_thread)
-        bx, tx = s[Out].split(c, factor=num_thread)
+        by, ty = s[Out].split(s[Out].op.axis[0], factor=num_thread)
+        bx, tx = s[Out].split(s[Out].op.axis[1], factor=num_thread)
         s[Out].reorder(by, bx, ty, tx)
         s[Out].bind(ty, thread_y)
         s[Out].bind(tx, thread_x)
diff --git a/topi/python/topi/intel_graphics/__init__.py b/topi/python/topi/intel_graphics/__init__.py
new file mode 100644
index 000000000000..336b1508f977
--- /dev/null
+++ b/topi/python/topi/intel_graphics/__init__.py
@@ -0,0 +1,5 @@
+# pylint: disable=redefined-builtin, wildcard-import
+"""Intel Gen9 GPU specific declaration and schedules."""
+from __future__ import absolute_import as _abs
+
+from .conv2d import *
diff --git a/topi/python/topi/intel_graphics/conv2d.py b/topi/python/topi/intel_graphics/conv2d.py
new file mode 100644
index 000000000000..fa16e0585d51
--- /dev/null
+++ b/topi/python/topi/intel_graphics/conv2d.py
@@ -0,0 +1,559 @@
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-else-return
+"""conv2d schedule on Intel Graphics"""
+
+from __future__ import absolute_import as _abs
+
+import tvm
+
+from .. import generic
+from .. import util
+from .. import tag
+from ..nn import pad
+from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, _get_workload
+from ..nn.util import get_pad_tuple
+from ..util import simplify
+
+
+##### SCHEDULE UTILITIES #####
+def tile_and_bind3d(s, tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None):
+    """ tile and bind 3d """
+    y_factor = y_factor or z_factor
+    x_factor = x_factor or y_factor
+    zo, zi = s[tensor].split(z, z_factor)
+    yo, yi = s[tensor].split(y, y_factor)
+    xo, xi = s[tensor].split(x, x_factor)
+    s[tensor].reorder(zo, yo, xo, zi, yi, xi)
+
+    thread_z = tvm.thread_axis((0, z_factor), "threadIdx.z")
+    thread_y = tvm.thread_axis((0, y_factor), "threadIdx.y")
+    thread_x = tvm.thread_axis((0, x_factor), "threadIdx.x")
+    s[tensor].bind(zo, tvm.thread_axis("blockIdx.z"))
+    s[tensor].bind(zi, thread_z)
+    s[tensor].bind(yo, tvm.thread_axis("blockIdx.y"))
+    s[tensor].bind(yi, thread_y)
+    s[tensor].bind(xo, tvm.thread_axis("blockIdx.x"))
+    s[tensor].bind(xi, thread_x)
+    return xi, thread_z, thread_y, thread_x
+
+@conv2d_alter_layout.register(["intel_graphics"])
+def _alter_conv2d_layout(attrs, inputs, tinfos):
+    import nnvm.symbol as sym
+    copy_inputs = [s for s in inputs]
+
+    data = tinfos[0]
+    kernel = tinfos[1]
+
+    import ast
+    padding = ast.literal_eval(attrs['padding'])
+    stride = ast.literal_eval(attrs['strides'])
+
+    wkl = _get_workload(data, kernel, stride, padding, data.dtype)
+    oc_bn = 16
+
+    new_attrs = {k: attrs[k] for k in attrs.keys()}
+    new_attrs['kernel_layout'] = 'OIHW%do' % (oc_bn)
+
+    return sym.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
+
+@conv2d_NCHWc.register(["intel_graphics"])
+def _decl_conv2d(data, kernel, num_filter, kernel_size, stride, padding, out_dtype='float32'):
+    """Conv2D operator for Intel Graphics backend.
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    kernel : tvm.Tensor
+        5-D with shape [num_filter, in_channel, filter_height, filter_width, num_filter_vec]
+
+    num_filter : int
+        number of output channels
+
+    kernel_size : tuple of two ints
+        [filter_height, filter_width]
+
+    stride : int or a list/tuple of two ints
+        stride size, or [stride_height, stride_width]
+
+    padding : int or a list/tuple of two ints
+        padding size, or [pad_height, pad_width]
+
+    Returns
+    -------
+    output : tvm.Tensor
+        4-D with shape [batch, out_channel, out_height, out_width]
+    """
+    assert data.shape[0].value == 1, "only support batch size=1 convolution on intel gpu"
+    assert data.dtype == kernel.dtype, "Do not support inputs with different data types now."
+
+    out_dtype = data.dtype
+    HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
+    kernel_shape = util.get_const_tuple(kernel.shape)
+    if isinstance(stride, (tuple, list)):
+        HSTR, WSTR = stride
+    else:
+        HSTR, WSTR = stride, stride
+
+    return _decl_cl_spatialpack_NCHWc(data, kernel, stride, padding, out_dtype)
+
+@generic.schedule_conv2d_NCHWc.register(["intel_graphics"])
+def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding, outs):
+    """Schedule for conv2d_nchw for Intel Graphics
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of conv2d_nchw
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for conv2d_nchw.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def traverse(op):
+        """inline all one-to-one-mapping operators except the last stage (output)"""
+        if tag.is_broadcast(op.tag):
+            if op not in s.outputs:
+                s[op].compute_inline()
+            for tensor in op.input_tensors:
+                if tensor.op.input_tensors:
+                    traverse(tensor.op)
+        if "4_5" in op.tag or "4_4" in op.tag or "2_7" in op.tag or "2_14" in op.tag \
+                or "1_16" in op.tag:
+            _schedule_cl_spatialpack_NCHWc(s, op)
+
+    traverse(outs[0].op)
+
+    return s
+
+def _decl_cl_spatialpack_NCHWc(data, kernel, stride, padding, out_dtype='float16'):
+    batch, in_channel, in_height, in_width = [util.get_const_int(x) for x in data.shape]
+    num_filter, channel, kernel_h, kernel_w, nv = [util.get_const_int(x) for x in kernel.shape]
+    num_filter = num_filter * nv
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, kernel)
+
+    if isinstance(stride, (tuple, list)):
+        stride_h, stride_w = stride
+    else:
+        stride_h, stride_w = stride, stride
+
+    out_channel = num_filter
+    out_height = simplify((in_height - kernel_h + pad_top + pad_down) // stride_h + 1)
+    out_width = simplify((in_width - kernel_w + pad_left + pad_right) // stride_w + 1)
+    oshape = (batch, out_channel, out_height, out_width)
+    pad_before = [0, 0, pad_top, pad_left]
+    pad_after = [0, 0, pad_down, pad_right]
+    temp = pad(data, pad_before, pad_after, name="pad_temp")
+
+    rc = tvm.reduce_axis((0, in_channel), name='rc')
+    ry = tvm.reduce_axis((0, kernel_h), name='ry')
+    rx = tvm.reduce_axis((0, kernel_w), name='rx')
+
+    block_w = 0
+    block_h = 0
+    if stride_h == 2:
+        if num_filter + kernel_h == 515:
+            conv_tag = "4_4"
+            block_h = 4
+            block_w = 4
+        else:
+            conv_tag = "4_5"
+            block_h = 4
+            block_w = 5
+    elif kernel_h == 3:
+        if num_filter == 512:
+            conv_tag = "2_7"
+            block_h = 2
+            block_w = 7
+        else:
+            conv_tag = "2_14"
+            block_h = 2
+            block_w = 14
+    else:
+        conv_tag = "1_16"
+        block_h = 1
+        block_w = 16
+
+    c_h = out_height
+    c_w = out_width
+
+    if not out_height % block_h == 0:
+        c_h = (out_height // block_h + 1) * block_h
+
+    if not out_width % block_w == 0:
+        c_w = (out_width // block_w + 1) * block_w
+
+    cshape = (batch, out_channel // nv, c_h, c_w, nv)
+
+    conv = tvm.compute(
+        cshape,
+        lambda nn, ff, yy, xx, vc:\
+        tvm.sum(
+            temp[nn, rc, yy * stride_h + ry, xx * stride_w + rx].astype(out_dtype) *
+            kernel[ff, rc, ry, rx, vc].astype(out_dtype),
+            axis=[rc, ry, rx]), tag=conv_tag, name='conv')
+
+    output = tvm.compute(
+        oshape,
+        lambda nn, ff, yy, xx:
+        conv[nn][ff//nv][yy][xx][ff%nv],
+        name='output_unpack', tag=conv_tag)
+
+    return output
+
+def _schedule_cl_spatialpack_NCHWc(s, op):
+    output = op.output(0)
+    _, _, out_height, out_width = [util.get_const_int(x) for x in output.shape]
+
+    conv = op.input_tensors[0]
+    temp = s[conv].op.input_tensors[0]
+    kernel = s[conv].op.input_tensors[1]
+    temp_W = s.cache_read(temp, "warp", [conv])
+    conv_L = s.cache_write(conv, "local")
+
+    kernel_L = s.cache_read(kernel, "local", [conv_L])
+    _, in_channel, temp_h, temp_w = [util.get_const_int(x) for x in temp.shape]
+    if "1_16" in s[conv].op.tag:
+        OUTPUT_BLOCK_HEIGHT = 1
+        OUTPUT_BLOCK_WIDTH = 16
+    elif "2_14" in s[conv].op.tag:
+        OUTPUT_BLOCK_HEIGHT = 2
+        OUTPUT_BLOCK_WIDTH = 14
+    elif "2_7" in s[conv].op.tag:
+        OUTPUT_BLOCK_HEIGHT = 2
+        OUTPUT_BLOCK_WIDTH = 7
+    elif "4_5" in s[conv].op.tag:
+        OUTPUT_BLOCK_HEIGHT = 4
+        OUTPUT_BLOCK_WIDTH = 5
+    elif "4_4" in s[conv].op.tag:
+        OUTPUT_BLOCK_HEIGHT = 4
+        OUTPUT_BLOCK_WIDTH = 4
+
+    # schedule conv
+    z_factor = 1
+    y_factor = 1
+    x_factor = 16
+    thread_z = tvm.thread_axis((0, z_factor), "threadIdx.z")
+    thread_y = tvm.thread_axis((0, y_factor), "threadIdx.y")
+    thread_x = tvm.thread_axis((0, x_factor), "threadIdx.x")
+    _, co, oh, ow, vc = s[conv].op.axis
+    ooh, ioh = s[conv].split(oh, factor=OUTPUT_BLOCK_HEIGHT)
+    oow, iow = s[conv].split(ow, factor=OUTPUT_BLOCK_WIDTH)
+    s[conv].reorder(_, co, ooh, oow, vc, ioh, iow)
+    coo, coi = s[conv].split(co, nparts=1)
+    ooho, oohi = s[conv].split(ooh, factor=z_factor)
+    oowo, oowi = s[conv].split(oow, factor=y_factor)
+    vco, vci = s[conv].split(vc, factor=x_factor)
+    s[conv].reorder(_, coo, vco, ooho, oowo, coi, oohi, oowi, vci, ioh, iow)
+    s[conv].bind(oohi, thread_z)
+    s[conv].bind(oowi, thread_y)
+    s[conv].bind(vci, thread_x)
+    s[conv].bind(ooho, tvm.thread_axis("blockIdx.z"))
+    s[conv].bind(oowo, tvm.thread_axis("blockIdx.y"))
+    s[conv].bind(coi, tvm.thread_axis("blockIdx.x"))
+
+    # schedule conv_L
+    s[conv_L].compute_at(s[conv], vci)
+    i, oc, h, w, vc = s[conv_L].op.axis
+    rc, ry, rx = s[conv_L].op.reduce_axis
+    if in_channel == 2048:
+        rco, rci = s[conv_L].split(rc, nparts=128)
+        s[conv_L].unroll(rci)
+        s[conv_L].reorder(i, oc, rco, rci, ry, rx, vc, h, w)
+        s[temp_W].compute_at(s[conv_L], rco)
+    else:
+        s[conv_L].reorder(i, oc, rc, ry, rx, vc, h, w)
+        s[temp_W].compute_at(s[conv_L], rc)
+    if kernel.shape[3].value != 7:
+        s[conv_L].unroll(ry)
+        s[conv_L].unroll(rx)
+
+    # schedule temp
+    _, ci, h, w = s[temp].op.axis
+    tile_and_bind3d(s, temp, ci, h, w, 1, 16, 16)
+
+    # schedule temp_W
+    _, ci, h, w = s[temp_W].op.axis
+    zo, zi = s[temp_W].split(ci, 1)
+    yo, yi = s[temp_W].split(h, 1)
+    xo, xi = s[temp_W].split(w, 16)
+    s[temp_W].reorder(zo, yo, xo, zi, yi, xi)
+    s[temp_W].bind(zi, thread_z)
+    s[temp_W].bind(yi, thread_y)
+    s[temp_W].bind(xi, thread_x)
+    s[temp_W].storage_align(s[temp_W].op.axis[2], 16, 0)
+
+    # schedule kernel
+
+    # schedule kernel_L
+    if "2_14" in s[conv].op.tag:
+        s[kernel_L].compute_at(s[conv_L], ry)
+    else:
+        s[kernel_L].compute_at(s[conv_L], rx)
+
+    # schedule output
+    if output.op in s.outputs:
+        out = output
+    else:
+        s[output].compute_inline()
+        out = s.outputs[0]
+
+    _, co, h, w = s[out].op.axis
+    tile_and_bind3d(s, out, w, h, co, 4, 8, 8)
+
+
+@conv2d.register(["intel_graphics"])
+def decl_conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'):
+    """Conv2D operator for Intel Graphics backend.
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        4-D with shape [batch, in_channel, in_height, in_width]
+    kernel : tvm.Tensor
+        4-D with shape [num_filter, in_channel, filter_height, filter_width]
+    stride : int or a list/tuple of two ints
+        stride size, or [stride_height, stride_width]
+    padding : int or a list/tuple of two ints
+        padding size, or [pad_height, pad_width]
+    layout : str
+        layout of data
+    Returns
+    -------
+    output : tvm.Tensor
+        4-D with shape [batch, out_channel, out_height, out_width]
+    """
+    assert layout == 'NCHW', "only support NCHW convolution on intel gpu"
+    assert data.shape[0].value == 1, "only support batch size=1 convolution on intel gpu"
+    assert data.dtype == kernel.dtype, "Do not support inputs with different data types now."
+
+    out_dtype = data.dtype
+    HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
+    kernel_shape = util.get_const_tuple(kernel.shape)
+    if isinstance(stride, (tuple, list)):
+        HSTR, WSTR = stride
+    else:
+        HSTR, WSTR = stride, stride
+
+    return _decl_cl_spatialpack(data, kernel, stride, padding, layout, out_dtype)
+
+@generic.schedule_conv2d_nchw.register(["intel_graphics"])
+def schedule_conv2d_nchw(outs):
+    """Schedule for conv2d_nchw for Intel Graphics
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of conv2d_nchw
+        in the format of an array of tensors.
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for conv2d_nchw.
+ """ + outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs + s = tvm.create_schedule([x.op for x in outs]) + + def traverse(op): + """inline all one-to-one-mapping operators except the last stage (output)""" + if tag.is_broadcast(op.tag): + if op not in s.outputs: + s[op].compute_inline() + for tensor in op.input_tensors: + if tensor.op.input_tensors: + traverse(tensor.op) + if "4_5" in op.tag or "4_4" in op.tag or "2_7" in op.tag or "2_14" in op.tag \ + or "1_16" in op.tag: + _schedule_cl_spatialpack(s, op) + + traverse(outs[0].op) + return s + +def _decl_cl_spatialpack(data, kernel, stride, padding, layout, out_dtype='float16'): + batch, in_channel, in_height, in_width = [util.get_const_int(x) for x in data.shape] + num_filter, channel, kernel_h, kernel_w = [util.get_const_int(x) for x in kernel.shape] + pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, kernel) + + if isinstance(stride, (tuple, list)): + stride_h, stride_w = stride + else: + stride_h, stride_w = stride, stride + + out_channel = num_filter + out_height = simplify((in_height - kernel_h + pad_top + pad_down) // stride_h + 1) + out_width = simplify((in_width - kernel_w + pad_left + pad_right) // stride_w + 1) + oshape = (batch, out_channel, out_height, out_width) + pad_before = [0, 0, pad_top, pad_left] + pad_after = [0, 0, pad_down, pad_right] + temp = pad(data, pad_before, pad_after, name="pad_temp") + + rc = tvm.reduce_axis((0, in_channel), name='rc') + ry = tvm.reduce_axis((0, kernel_h), name='ry') + rx = tvm.reduce_axis((0, kernel_w), name='rx') + + block_w = 0 + block_h = 0 + if stride_h == 2: + if num_filter + kernel_h == 515: + conv_tag = "4_4" + block_h = 4 + block_w = 4 + else: + conv_tag = "4_5" + block_h = 4 + block_w = 5 + elif kernel_h == 3: + if num_filter == 512: + conv_tag = "2_7" + block_h = 2 + block_w = 7 + else: + conv_tag = "2_14" + block_h = 2 + block_w = 14 + else: + conv_tag = "1_16" + block_h = 1 + block_w = 16 + + c_h = out_height + c_w = out_width + + if not out_height % block_h == 0: + c_h = (out_height // block_h + 1) * block_h + + if not out_width % block_w == 0: + c_w = (out_width // block_w + 1) * block_w + + nv = 16 + cshape = (batch, out_channel // nv, c_h, c_w, nv) + kvshape = (num_filter // nv, channel, kernel_h, kernel_w, nv) + + kernel_vec = tvm.compute( + kvshape, + lambda co, ci, kh, kw, vc: + kernel[co*nv + vc][ci][kh][kw], name='kernel_vec') + + conv = tvm.compute( + cshape, + lambda nn, ff, yy, xx, vc:\ + tvm.sum( + temp[nn, rc, yy * stride_h + ry, xx * stride_w + rx].astype(out_dtype) * + kernel_vec[ff, rc, ry, rx, vc].astype(out_dtype), + axis=[rc, ry, rx]), tag=conv_tag, name='conv') + + output = tvm.compute( + oshape, + lambda nn, ff, yy, xx: + conv[nn][ff//nv][yy][xx][ff%nv], + name='output_unpack', tag=conv_tag) + + return output + +def _schedule_cl_spatialpack(s, op): + output = op.output(0) + _, _, out_height, out_width = [util.get_const_int(x) for x in output.shape] + + conv = op.input_tensors[0] + temp = s[conv].op.input_tensors[0] + kernel_vec = s[conv].op.input_tensors[1] + kernel = s[kernel_vec].op.input_tensors[0] + temp_W = s.cache_read(temp, "warp", [conv]) + conv_L = s.cache_write(conv, "local") + + kernel_L = s.cache_read(kernel_vec, "local", [conv_L]) + _, in_channel, temp_h, temp_w = [util.get_const_int(x) for x in temp.shape] + + if "1_16" in s[conv].op.tag: + OUTPUT_BLOCK_HEIGHT = 1 + OUTPUT_BLOCK_WIDTH = 16 + elif "2_14" in s[conv].op.tag: + OUTPUT_BLOCK_HEIGHT = 2 + OUTPUT_BLOCK_WIDTH = 14 + elif "2_7" in s[conv].op.tag: + 
+        OUTPUT_BLOCK_HEIGHT = 2
+        OUTPUT_BLOCK_WIDTH = 7
+    elif "4_5" in s[conv].op.tag:
+        OUTPUT_BLOCK_HEIGHT = 4
+        OUTPUT_BLOCK_WIDTH = 5
+    elif "4_4" in s[conv].op.tag:
+        OUTPUT_BLOCK_HEIGHT = 4
+        OUTPUT_BLOCK_WIDTH = 4
+
+    # schedule conv
+    z_factor = 1
+    y_factor = 1
+    x_factor = 16
+    thread_z = tvm.thread_axis((0, z_factor), "threadIdx.z")
+    thread_y = tvm.thread_axis((0, y_factor), "threadIdx.y")
+    thread_x = tvm.thread_axis((0, x_factor), "threadIdx.x")
+    _, co, oh, ow, vc = s[conv].op.axis
+    ooh, ioh = s[conv].split(oh, factor=OUTPUT_BLOCK_HEIGHT)
+    oow, iow = s[conv].split(ow, factor=OUTPUT_BLOCK_WIDTH)
+    s[conv].reorder(_, co, ooh, oow, vc, ioh, iow)
+    coo, coi = s[conv].split(co, nparts=1)
+    ooho, oohi = s[conv].split(ooh, factor=z_factor)
+    oowo, oowi = s[conv].split(oow, factor=y_factor)
+    vco, vci = s[conv].split(vc, factor=x_factor)
+    s[conv].reorder(_, coo, vco, ooho, oowo, coi, oohi, oowi, vci, ioh, iow)
+    s[conv].bind(oohi, thread_z)
+    s[conv].bind(oowi, thread_y)
+    s[conv].bind(vci, thread_x)
+    s[conv].bind(ooho, tvm.thread_axis("blockIdx.z"))
+    s[conv].bind(oowo, tvm.thread_axis("blockIdx.y"))
+    s[conv].bind(coi, tvm.thread_axis("blockIdx.x"))
+
+    # schedule conv_L
+    s[conv_L].compute_at(s[conv], vci)
+    i, oc, h, w, vc = s[conv_L].op.axis
+    rc, ry, rx = s[conv_L].op.reduce_axis
+    if in_channel == 2048:
+        rco, rci = s[conv_L].split(rc, nparts=128)
+        s[conv_L].unroll(rci)
+        s[conv_L].reorder(i, oc, rco, rci, ry, rx, vc, h, w)
+        s[temp_W].compute_at(s[conv_L], rco)
+    else:
+        s[conv_L].reorder(i, oc, rc, ry, rx, vc, h, w)
+        s[temp_W].compute_at(s[conv_L], rc)
+    if kernel.shape[3].value != 7:
+        s[conv_L].unroll(ry)
+        s[conv_L].unroll(rx)
+
+    # schedule temp
+    _, ci, h, w = s[temp].op.axis
+    tile_and_bind3d(s, temp, ci, h, w, 1, 16, 16)
+
+    # schedule temp_W
+    _, ci, h, w = s[temp_W].op.axis
+    zo, zi = s[temp_W].split(ci, 1)
+    yo, yi = s[temp_W].split(h, 1)
+    xo, xi = s[temp_W].split(w, 16)
+    s[temp_W].reorder(zo, yo, xo, zi, yi, xi)
+    s[temp_W].bind(zi, thread_z)
+    s[temp_W].bind(yi, thread_y)
+    s[temp_W].bind(xi, thread_x)
+    s[temp_W].storage_align(s[temp_W].op.axis[2], 16, 0)
+
+    s[kernel_vec].compute_inline()
+
+    # schedule kernel_L
+    if "2_14" in s[conv].op.tag:
+        s[kernel_L].compute_at(s[conv_L], ry)
+    else:
+        s[kernel_L].compute_at(s[conv_L], rx)
+
+    # schedule output
+    if output.op in s.outputs:
+        out = output
+    else:
+        s[output].compute_inline()
+        out = s.outputs[0]
+
+    _, co, h, w = s[out].op.axis
+    tile_and_bind3d(s, out, w, h, co, 4, 8, 8)
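
A note for reviewers on the renamed target helper. The sketch below (not part of the patch) shows the two equivalent ways to obtain the new target, matching what test_lang_target.py and test_ewise.py exercise:

```python
import tvm

# Helper form added by this patch; device_name is asserted in test_lang_target.py.
tgt = tvm.target.intel_graphics()
assert tgt.device_name == "intel_graphics"

# Raw option-string form, as used in test_ewise.py. CreateTarget() maps
# -device=intel_graphics onto an OpenCL target and sets thread_warp_size = 16
# to match the Gen9 sub-group width.
tgt2 = tvm.target.create("opencl -device=intel_graphics")
assert tgt2.device_name == "intel_graphics"
```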
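To exercise the new conv2d declaration and schedule end to end, a minimal sketch along these lines should work (shapes are illustrative, not from the patch; batch must be 1 per the asserts above, and the output-channel count must be a multiple of the 16-wide vector block used by the schedule):

```python
import numpy as np
import tvm
import topi

# Illustrative NCHW workload: batch 1, 64 -> 64 channels, 3x3 kernel, pad 1.
data = tvm.placeholder((1, 64, 56, 56), name='data')
kernel = tvm.placeholder((64, 64, 3, 3), name='kernel')

with tvm.target.intel_graphics():
    # Dispatches to decl_conv2d / schedule_conv2d_nchw registered in this patch.
    conv = topi.nn.conv2d(data, kernel, 1, 1, 'NCHW')
    s = topi.generic.schedule_conv2d_nchw([conv])

func = tvm.build(s, [data, kernel, conv], "opencl -device=intel_graphics")
ctx = tvm.context("opencl", 0)
a = tvm.nd.array(np.random.uniform(size=(1, 64, 56, 56)).astype('float32'), ctx)
w = tvm.nd.array(np.random.uniform(size=(64, 64, 3, 3)).astype('float32'), ctx)
c = tvm.nd.array(np.zeros((1, 64, 56, 56), dtype='float32'), ctx)
func(a, w, c)
```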
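The 'OIHW%do' kernel_layout requested by _alter_conv2d_layout corresponds to the 5-D kernel that _decl_cl_spatialpack builds internally as kernel_vec: the output-channel axis is split by oc_bn = 16 and the 16-wide sub-axis moves innermost. A NumPy sketch of the transform (shapes hypothetical):

```python
import numpy as np

O, I, KH, KW, oc_bn = 64, 64, 3, 3, 16
kernel_oihw = np.random.rand(O, I, KH, KW).astype('float32')

# OIHW -> OIHW16o, mirroring
# kernel_vec[co, ci, kh, kw, vc] = kernel[co*16 + vc, ci, kh, kw]
kernel_oihw16o = (kernel_oihw
                  .reshape(O // oc_bn, oc_bn, I, KH, KW)
                  .transpose(0, 2, 3, 4, 1))
assert kernel_oihw16o.shape == (O // oc_bn, I, KH, KW, oc_bn)
assert kernel_oihw16o[1, 2, 0, 1, 5] == kernel_oihw[1 * oc_bn + 5, 2, 0, 1]
```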
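Finally, on the c_h/c_w rounding in both _decl_cl_spatialpack variants: the conv stage computes into a workspace whose spatial extent is rounded up to a whole number of output blocks, and output_unpack then reads back only the exact output shape. A worked example with the "2_14" block (illustrative numbers):

```python
# 7x7 output with a 2x14 block: height pads 7 -> 8, width pads 7 -> 14.
out_height, out_width = 7, 7
block_h, block_w = 2, 14
c_h = out_height if out_height % block_h == 0 else (out_height // block_h + 1) * block_h
c_w = out_width if out_width % block_w == 0 else (out_width // block_w + 1) * block_w
assert (c_h, c_w) == (8, 14)  # conv writes 8x14; output_unpack returns the 7x7 slice
```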