diff --git a/topi/python/topi/intel_graphics/conv2d.py b/topi/python/topi/intel_graphics/conv2d.py
index 4ef15f8e735c..156f00021e71 100644
--- a/topi/python/topi/intel_graphics/conv2d.py
+++ b/topi/python/topi/intel_graphics/conv2d.py
@@ -25,6 +25,7 @@
 from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
 from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, conv2d_infer_layout
 from ..nn.util import get_pad_tuple
+from ..nn.depthwise_conv2d import depthwise_conv2d_nchw
 from ..nn import pad
 from .. import tag
 from .. import generic
@@ -162,21 +163,77 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
     if F.__name__ == 'tvm.relay.op':
         # Derive channels for frontends (e.g ONNX) that miss "channel" field.
         new_attrs["channels"] = inputs[1].checked_type.shape[attrs['kernel_layout'].index('O')]
-
     # Remove attached compilation target because conv2d_NCHWc needs to create
     # a conv2d_nchwc op and target is not one of conv2d's parameters.
     if "target" in new_attrs:
         del new_attrs["target"]
 
-    if F.__name__ == 'nnvm.symbol':
-        out = F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
-    else:
-        out = F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
+    data, kernel = tinfo[0], tinfo[1]
+    batch_size, in_channel, height, width = get_const_tuple(data.shape)
+
+    groups = attrs.get_int("groups")
+    out_channel = attrs.get_int("channels")
+    padding = attrs.get_int_tuple("padding")
+    strides = attrs.get_int_tuple("strides")
+    dilation = attrs.get_int_tuple("dilation")
+    out_dtype = attrs["out_dtype"]
+
+    layout_name = 'layout' if F == sym else 'data_layout'
+    layout = attrs[layout_name]
+    kh, kw = attrs.get_int_tuple("kernel_size")
+
+    dtype = data.dtype
+    out_dtype = dtype if out_dtype in ("same", "") else out_dtype
+    is_depthwise = groups == in_channel and groups == out_channel
+
+    # only optimize for NCHW
+    if layout != 'NCHW':
+        return None
+    if groups != 1 and not is_depthwise:
+        return None
+
+    dispatch_ctx = autotvm.task.DispatchContext.current
+    target = tvm.target.current_target()
+
+    # query schedule and fallback if necessary
+    workload = autotvm.task.args_to_workload(
+        [data, kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw) \
+        if is_depthwise else \
+        autotvm.task.args_to_workload(
+            [data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
+    if is_depthwise:
+        return None
+    cfg = dispatch_ctx.query(target, workload)
+    if cfg.is_fallback:
+        _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise)
+
+    ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
+
+    new_attrs[layout_name] = 'NCHW%dc' % ic_bn
+    new_attrs['out_layout'] = 'NCHW%dc' % oc_bn
+
+    new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
+                               dtype=data.dtype)
+
+    out_channel, _, kh, kw = get_const_tuple(kernel.shape)
+    # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
+    new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
+
+    # Store altered operator's config
+    new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, kh, kw, ic_bn, oc_bn),
+                                 dtype=kernel.dtype)
+    new_workload = autotvm.task.args_to_workload(
+        [new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
+         new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
 
-    return out
+    dispatch_ctx.update(target, new_workload, cfg)
+    if F == sym:
+        return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
+    return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
 
-@conv2d_NCHWc.register(["intel_graphics"])
-def _decl_conv2d(data, kernel, stride, padding, dilation, layout, out_layout, out_dtype='float32'):
+@autotvm.register_topi_compute(conv2d_NCHWc, 'intel_graphics', 'direct')
+def _decl_conv2d(cfg, data, kernel, strides, padding, dilation,
+                 layout, out_layout, out_dtype='float32'):
     """Conv2D operator for Intel Graphics backend.
 
     Parameters