diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index ba47ae7bc4f1d..f2f607aeb95e2 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -179,9 +179,9 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
         elif layout == "NHWC":
             assert kernel_layout == "HWIO"
             strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
-                wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc),
-                name="conv2d_nhwc.cuda",
+                wrap_compute_conv2d(topi.gpu.conv2d_nhwc),
+                wrap_topi_schedule(topi.gpu.schedule_conv2d_nhwc),
+                name="conv2d_nhwc.gpu",
             )
             N, H, W, _ = get_const_tuple(data.shape)
diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py
index 64373dcdd7bf6..d1bb4999c6cbc 100644
--- a/python/tvm/relay/op/strategy/rocm.py
+++ b/python/tvm/relay/op/strategy/rocm.py
@@ -76,9 +76,9 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target):
         elif layout == "NHWC":
             assert kernel_layout == "HWIO"
             strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
-                wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc),
-                name="conv2d_nhwc.cuda",
+                wrap_compute_conv2d(topi.gpu.conv2d_nhwc),
+                wrap_topi_schedule(topi.gpu.schedule_conv2d_nhwc),
+                name="conv2d_nhwc.gpu",
             )
             N, H, W, _ = get_const_tuple(data.shape)
             KH, KW, CI, CO = get_const_tuple(kernel.shape)
diff --git a/python/tvm/topi/cuda/conv2d.py b/python/tvm/topi/cuda/conv2d.py
index a199534ccb513..6fc5265d7ae80 100644
--- a/python/tvm/topi/cuda/conv2d.py
+++ b/python/tvm/topi/cuda/conv2d.py
@@ -25,7 +25,6 @@
 from ..nn.utils import get_pad_tuple
 from ..utils import get_const_tuple, traverse_inline
 from .conv2d_direct import schedule_direct_cuda
-from .conv2d_nhwc import schedule_conv2d_nhwc_direct
 
 
 @autotvm.register_topi_compute("conv2d_nchw.cuda")
@@ -48,26 +47,6 @@ def _callback(op):
     return s
 
 
-@autotvm.register_topi_compute("conv2d_nhwc.cuda")
-def conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float32"):
-    """Compute conv2d with NHWC layout"""
-    return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)
-
-
-@autotvm.register_topi_schedule("conv2d_nhwc.cuda")
-def schedule_conv2d_nhwc(cfg, outs):
-    """Create the schedule for conv2d_nhwc"""
-    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
-    s = te.create_schedule([x.op for x in outs])
-
-    def _callback(op):
-        if op.tag == "conv2d_nhwc":
-            schedule_conv2d_nhwc_direct(cfg, s, op.output(0))
-
-    traverse_inline(s, outs[0].op, _callback)
-    return s
-
-
 @autotvm.register_topi_compute("conv2d_cudnn.cuda")
 def conv2d_cudnn(
     cfg, data, kernel, strides, padding, dilation, groups=1, layout="NCHW", out_dtype="float32"
diff --git a/python/tvm/topi/gpu/__init__.py b/python/tvm/topi/gpu/__init__.py
index 6d9fd39e16b8e..8ed9362a3cf2a 100644
--- a/python/tvm/topi/gpu/__init__.py
+++ b/python/tvm/topi/gpu/__init__.py
@@ -18,3 +18,4 @@
 # pylint: disable=redefined-builtin, wildcard-import
 """GPU specific declaration and schedules."""
 from .dense import *
+from .conv2d import *
diff --git a/python/tvm/topi/gpu/conv2d.py b/python/tvm/topi/gpu/conv2d.py
new file mode 100644
index 0000000000000..87c900e1d4d76
--- /dev/null
+++ b/python/tvm/topi/gpu/conv2d.py
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""Schedule for conv2d operator"""
+from tvm import te, autotvm
+
+from .. import nn
+from ..utils import traverse_inline
+from .conv2d_nhwc import schedule_conv2d_nhwc_direct
+
+
+@autotvm.register_topi_compute("conv2d_nhwc.gpu")
+def conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float32"):
+    """Compute conv2d with NHWC layout"""
+    return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)
+
+
+@autotvm.register_topi_schedule("conv2d_nhwc.gpu")
+def schedule_conv2d_nhwc(cfg, outs):
+    """Create the schedule for conv2d_nhwc"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if op.tag == "conv2d_nhwc":
+            schedule_conv2d_nhwc_direct(cfg, s, op.output(0))
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
diff --git a/python/tvm/topi/cuda/conv2d_nhwc.py b/python/tvm/topi/gpu/conv2d_nhwc.py
similarity index 98%
rename from python/tvm/topi/cuda/conv2d_nhwc.py
rename to python/tvm/topi/gpu/conv2d_nhwc.py
index c3e62362a7ff9..a3cf124747167 100644
--- a/python/tvm/topi/cuda/conv2d_nhwc.py
+++ b/python/tvm/topi/gpu/conv2d_nhwc.py
@@ -60,7 +60,7 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
     target = tvm.target.Target.current()
     if cfg.is_fallback:
         ref_log = autotvm.tophub.load_reference_log(
-            target.kind.name, target.model, "conv2d_nhwc.cuda"
+            target.kind.name, target.model, "conv2d_nhwc.gpu"
         )
         cfg.fallback_with_reference_log(ref_log)
 
diff --git a/tests/python/topi/python/test_topi_conv2d_nhwc.py b/tests/python/topi/python/test_topi_conv2d_nhwc.py
index eb4c5a343b583..ba3d75ac07857 100644
--- a/tests/python/topi/python/test_topi_conv2d_nhwc.py
+++ b/tests/python/topi/python/test_topi_conv2d_nhwc.py
@@ -28,7 +28,7 @@
 
 _conv2d_nhwc_implement = {
     "llvm": (topi.nn.conv2d_nhwc, topi.generic.schedule_conv2d_nhwc),
-    "cuda": (topi.cuda.conv2d_nhwc, topi.cuda.schedule_conv2d_nhwc),
+    "cuda": (topi.gpu.conv2d_nhwc, topi.gpu.schedule_conv2d_nhwc),
     "cpu": (topi.nn.conv2d_nhwc, topi.x86.schedule_conv2d_nhwc),
     "arm_cpu": (
         topi.arm_cpu.conv2d_nhwc_spatial_pack,
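
For reference: registering the implementation as "conv2d_nhwc.gpu" lets the CUDA and ROCm strategies share a single TOPI definition, which is why both strategy files now point at topi.gpu. The relocated pair can be exercised directly, mirroring the dispatch-table change in tests/python/topi/python/test_topi_conv2d_nhwc.py. A minimal sketch follows; the shapes are illustrative and a CUDA-enabled TVM build is assumed:

import tvm
from tvm import te, topi

# Illustrative NHWC input and HWIO kernel placeholders (float32 by default).
data = te.placeholder((1, 14, 14, 64), name="data")
kernel = te.placeholder((3, 3, 64, 64), name="kernel")

# A target scope is needed so the AutoTVM dispatcher can resolve the
# "conv2d_nhwc.gpu" workload; with no tuned log it falls back to the
# reference configs loaded in schedule_conv2d_nhwc_direct.
with tvm.target.Target("cuda"):
    # Positional args: strides=1, padding=1, dilation=1 (the AutoTVM
    # template wrapper does not accept keyword arguments).
    out = topi.gpu.conv2d_nhwc(data, kernel, 1, 1, 1)
    s = topi.gpu.schedule_conv2d_nhwc([out])

mod = tvm.build(s, [data, kernel, out], target="cuda")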