diff --git a/topi/python/topi/nn/conv3d.py b/topi/python/topi/nn/conv3d.py
index 21d893fd5656a..83c16dae7ac44 100644
--- a/topi/python/topi/nn/conv3d.py
+++ b/topi/python/topi/nn/conv3d.py
@@ -186,15 +186,15 @@ def conv3d_ndhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'):
     pad_before = [0, pad_front, pad_top, pad_left, 0]
     pad_after = [0, pad_back, pad_down, pad_right, 0]
     PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
+    rd = tvm.reduce_axis((0, kernel_d), name='rd')
+    rh = tvm.reduce_axis((0, kernel_h), name='rh')
+    rw = tvm.reduce_axis((0, kernel_w), name='rw')
     rc = tvm.reduce_axis((0, in_channel), name='rc')
-    rz = tvm.reduce_axis((0, kernel_d), name='rz')
-    ry = tvm.reduce_axis((0, kernel_h), name='ry')
-    rx = tvm.reduce_axis((0, kernel_w), name='rx')
     Output = tvm.compute(
         (batch, out_depth, out_height, out_width, out_channel),
-        lambda nn, zz, yy, xx, ff: tvm.sum(
-            PaddedInput[nn, zz * stride_d + rz * dilation_d, yy * stride_h + ry * dilation_h,
-                        xx * stride_w + rx * dilation_w, rc].astype(out_dtype) *
-            Filter[rz, ry, rx, rc, ff].astype(out_dtype), axis=[rz, ry, rx, rc]),
+        lambda nn, dd, hh, ww, cc: tvm.sum(
+            PaddedInput[nn, dd * stride_d + rd * dilation_d, hh * stride_h + rh * dilation_h,
+                        ww * stride_w + rw * dilation_w, rc].astype(out_dtype) *
+            Filter[rd, rh, rw, rc, cc].astype(out_dtype), axis=[rd, rh, rw, rc]),
         name="Conv3dOutput", tag="conv3d_ndhwc")
     return Output
diff --git a/topi/python/topi/x86/__init__.py b/topi/python/topi/x86/__init__.py
index af7f97415242c..d1c728d7b75ce 100644
--- a/topi/python/topi/x86/__init__.py
+++ b/topi/python/topi/x86/__init__.py
@@ -21,6 +21,7 @@
 from .conv1d import schedule_conv1d_nwc
 from .conv2d import schedule_conv2d, schedule_conv2d_nhwc
+from .conv3d import schedule_conv3d_ndhwc
 from .binarize_pack import schedule_binarize_pack
 from .binary_dense import schedule_binary_dense
 from .nn import *
diff --git a/topi/python/topi/x86/conv3d.py b/topi/python/topi/x86/conv3d.py
new file mode 100644
index 0000000000000..279410a37e71b
--- /dev/null
+++ b/topi/python/topi/x86/conv3d.py
@@ -0,0 +1,93 @@
+import tvm
+from tvm import autotvm
+from .. import generic, tag
+from ..nn.conv3d import conv3d, conv3d_ndhwc, conv3d_ncdhw
+from ..generic.nn import schedule_conv3d_ndhwc
+
+@autotvm.register_topi_compute(conv3d, 'cpu', ['direct'])
+def conv3d_x86(cfg, input, filter, strides, padding, dilation, layout='NCDHW', out_dtype=None):
+    if layout == 'NCDHW':
+        return conv3d_ncdhw(input, filter, strides, padding, dilation, out_dtype)
+    elif layout == 'NDHWC':
+        return conv3d_ndhwc(input, filter, strides, padding, dilation, out_dtype)
+
+@autotvm.register_topi_schedule(schedule_conv3d_ndhwc, 'cpu', ['direct'])
+def schedule_conv3d_ndhwc_x86(cfg, outs):
+    """TOPI schedule callback for conv3d
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    outs: Array of Tensor
+        The computation graph description of conv3d
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for conv3d.
+ """ + s = tvm.create_schedule([x.op for x in outs]) + output_op = outs[0].op + scheduled_ops = [] + + def traverse(op): + """Traverse operators from computation graph""" + # inline all one-to-one-mapping operators except the last stage (output) + if tag.is_broadcast(op.tag): + if op not in s.outputs: + s[op].compute_inline() + else: # inject custom schedule + if len(op.axis) == 5: # schedule bias + bn + relu + n, d, h, w, c = op.axis + fused = s[op].fuse(n, d, h, w) + s[op].parallel(fused) + s[op].vectorize(c) + for tensor in op.input_tensors: + if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops: + traverse(tensor.op) + + if 'conv3d_ndhwc' in op.tag: + conv = op.output(0) + kernel = op.input_tensors[1] + if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag: + s[kernel].compute_inline() + + data = op.input_tensors[0] + data_pad = None + if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: + data_pad = data + data = data_pad.op.input_tensors[0] + n_pad, d_pad, h_pad, w_pad, c_pad = data_pad.op.axis + pad_fused = s[data_pad].fuse(h_pad, w_pad) + s[data_pad].parallel(pad_fused) + + C = conv + # data axes + n, d, h, w, c = s[C].op.axis + + if True: + # tile data h and w + ho, wo, hi, wi = s[C].tile(h, w, 2, 2) + # kernel axes + kd, ky, kx, kc = s[C].op.reduce_axis + kxi, kxo = s[C].split(kx, factor=2) + kci, kco = s[C].split(kc, factor=2) + # + s[C].reorder(n, d, ho, wo, hi, wi, c, kxo, kco, kxi, kci) + s[C].unroll(kci) + + s[C].vectorize(c) + if op != output_op: + _, _, _, _, c_out = output_op.axis + s[C].compute_at(s[output_op], c_out) + else: + fused = s[C].fuse(n, d) + s[C].parallel(fused) + + scheduled_ops.append(op) + + traverse(output_op) + return s