diff --git a/topi/python/topi/nn/conv3d.py b/topi/python/topi/nn/conv3d.py index 21d893fd5656a..83c16dae7ac44 100644 --- a/topi/python/topi/nn/conv3d.py +++ b/topi/python/topi/nn/conv3d.py @@ -186,15 +186,15 @@ def conv3d_ndhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'): pad_before = [0, pad_front, pad_top, pad_left, 0] pad_after = [0, pad_back, pad_down, pad_right, 0] PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput") + rd = tvm.reduce_axis((0, kernel_d), name='rd') + rh = tvm.reduce_axis((0, kernel_h), name='rh') + rw = tvm.reduce_axis((0, kernel_w), name='rw') rc = tvm.reduce_axis((0, in_channel), name='rc') - rz = tvm.reduce_axis((0, kernel_d), name='rz') - ry = tvm.reduce_axis((0, kernel_h), name='ry') - rx = tvm.reduce_axis((0, kernel_w), name='rx') Output = tvm.compute( (batch, out_depth, out_height, out_width, out_channel), - lambda nn, zz, yy, xx, ff: tvm.sum( - PaddedInput[nn, zz * stride_d + rz * dilation_d, yy * stride_h + ry * dilation_h, - xx * stride_w + rx * dilation_w, rc].astype(out_dtype) * - Filter[rz, ry, rx, rc, ff].astype(out_dtype), axis=[rz, ry, rx, rc]), + lambda nn, dd, hh, ww, cc: tvm.sum( + PaddedInput[nn, dd * stride_d + rd * dilation_d, hh * stride_h + rh * dilation_h, + ww * stride_w + rw * dilation_w, rc].astype(out_dtype) * + Filter[rd, rh, rw, rc, cc].astype(out_dtype), axis=[rd, rh, rw, rc]), name="Conv3dOutput", tag="conv3d_ndhwc") return Output diff --git a/topi/python/topi/x86/__init__.py b/topi/python/topi/x86/__init__.py index af7f97415242c..d1c728d7b75ce 100644 --- a/topi/python/topi/x86/__init__.py +++ b/topi/python/topi/x86/__init__.py @@ -21,6 +21,7 @@ from .conv1d import schedule_conv1d_nwc from .conv2d import schedule_conv2d, schedule_conv2d_nhwc +from .conv3d import schedule_conv3d_ndhwc from .binarize_pack import schedule_binarize_pack from .binary_dense import schedule_binary_dense from .nn import * diff --git a/topi/python/topi/x86/conv3d.py b/topi/python/topi/x86/conv3d.py new file mode 100644 index 0000000000000..0fe5d8d8d582b --- /dev/null +++ b/topi/python/topi/x86/conv3d.py @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-variable, too-many-locals +# pylint: disable=unused-argument, redefined-builtin, no-else-return +"""Conv3D operators""" +import tvm +from .. import generic, tag + +@generic.schedule_conv3d_ndhwc.register("cpu") +def schedule_conv3d_ndhwc(outs): + """TOPI schedule callback for conv3d + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of conv3d + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for conv3d. + """ + s = tvm.create_schedule([x.op for x in outs]) + output_op = outs[0].op + scheduled_ops = [] + + def traverse(op): + """Traverse operators from computation graph""" + # inline all one-to-one-mapping operators except the last stage (output) + if tag.is_broadcast(op.tag): + if op not in s.outputs: + s[op].compute_inline() + else: # inject custom schedule + if len(op.axis) == 5: + # schedule bias + bn + activation + n, d, h, w, c = op.axis + fused = s[op].fuse(n, d, h, w) + s[op].parallel(fused) + s[op].vectorize(c) + for tensor in op.input_tensors: + if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops: + traverse(tensor.op) + + if 'conv3d_ndhwc' in op.tag: + conv = op.output(0) + kernel = op.input_tensors[1] + # dilation stage + if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag: + s[kernel].compute_inline() + + # padding stage + data = op.input_tensors[0] + data_pad = None + if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: + # fuse pad h and w + data_pad = data + data = data_pad.op.input_tensors[0] + _, _, h_pad, w_pad, _ = data_pad.op.axis + pad_fused = s[data_pad].fuse(h_pad, w_pad) + s[data_pad].parallel(pad_fused) + + # compute conv + C = conv + n, d, h, w, c = s[C].op.axis + s[C].vectorize(c) + if op != output_op: # fuse bias + bn + activation + _, _, _, _, c_out = output_op.axis + s[C].compute_at(s[output_op], c_out) + else: + # fuse batch, depth, height axes + fused = s[C].fuse(n, d, h) + s[C].parallel(fused) + + scheduled_ops.append(op) + + traverse(output_op) + return s