Skip to content

Commit

Permalink
fix ci
Browse files Browse the repository at this point in the history
  • Loading branch information
Laurawly committed Sep 5, 2019
1 parent 9f4df30 commit 2789011
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions topi/python/topi/intel_graphics/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):

layout_name = 'layout' if F == sym else 'data_layout'
layout = attrs[layout_name]
kh, kw = attrs.get_int_tuple("kernel_size")

dtype = data.dtype
out_dtype = dtype if out_dtype in ("same", "") else out_dtype
Expand All @@ -191,6 +192,8 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
if is_depthwise else \
autotvm.task.args_to_workload(
[data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
if is_depthwise:
return None
cfg = dispatch_ctx.query(target, workload)
if cfg.is_fallback:
_get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise)
Expand All @@ -202,19 +205,17 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):

new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
dtype=data.dtype)
if is_depthwise:
raise RuntimeError("Intel graphics not supported depthwise schedule")
else:
out_channel, _, kh, kw = get_const_tuple(kernel.shape)
# (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)

# Store altered operator's config
new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, kh, kw, ic_bn, oc_bn),
dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
out_channel, _, kh, kw = get_const_tuple(kernel.shape)
# (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)

# Store altered operator's config
new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, kh, kw, ic_bn, oc_bn),
dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
new_attrs['out_layout'], out_dtype], conv2d_NCHWc)

dispatch_ctx.update(target, new_workload, cfg)
if F == sym:
Expand Down

0 comments on commit 2789011

Please sign in to comment.