diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 4a7cff5f3f33..d090e03c6af3 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -311,9 +311,15 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
                 wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8),
                 name="conv2d_NCHWc_int8.cuda",
             )
+        elif is_auto_scheduler_enabled():
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.nn.conv, need_data_layout=True, need_kernel_layout=True, has_groups=True),
+                naive_schedule,
+                name="conv2d_generic_layout",
+            )
         elif target.kind.name == "cuda" and "cudnn" not in target.libs:
             # No TVM native kernel applicable
-            raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
+            raise RuntimeError("Unsupported conv2d layout {} {} for CUDA".format(layout, kernel_layout))
 
         if (
             target.kind.name == "cuda"
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 2bb009dbc8f7..f66dce8d0f92 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -216,6 +216,7 @@ def schedule_bitpack(attrs, outs, target):
 def wrap_compute_conv2d(
     topi_compute,
     need_data_layout=False,
+    need_kernel_layout=False,
     need_out_layout=False,
     has_groups=False,
     need_auto_scheduler_layout=False,
@@ -227,6 +228,7 @@ def _compute_conv2d(attrs, inputs, out_type):
         strides = get_const_tuple(attrs.strides)
         dilation = get_const_tuple(attrs.dilation)
         data_layout = attrs.get_str("data_layout")
+        kernel_layout = attrs.get_str("kernel_layout")
         out_layout = attrs.get_str("out_layout")
         out_dtype = attrs.out_dtype
         out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
@@ -235,6 +237,8 @@ def _compute_conv2d(attrs, inputs, out_type):
             args.append(attrs.groups)
         if need_data_layout:
             args.append(data_layout)
+        if need_kernel_layout:
+            args.append(kernel_layout)
         if need_out_layout:
             args.append(out_layout)
         args.append(out_dtype)
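
A minimal, self-contained sketch (not part of the patch) of how the new `need_kernel_layout` flag changes the positional arguments that `_compute_conv2d` forwards to a TOPI compute function. `MockAttrs` and `build_conv2d_args` are hypothetical stand-ins introduced only for illustration; the real wrapper reads the same attributes from the relay op's attrs object.

```python
class MockAttrs:
    """Hypothetical stand-in for the relay conv2d attrs object."""

    def __init__(self, data_layout, kernel_layout, out_layout, groups=1):
        self._strs = {
            "data_layout": data_layout,
            "kernel_layout": kernel_layout,
            "out_layout": out_layout,
        }
        self.groups = groups

    def get_str(self, key):
        return self._strs[key]


def build_conv2d_args(attrs, data, kernel, strides, padding, dilation, out_dtype,
                      need_data_layout=False, need_kernel_layout=False,
                      need_out_layout=False, has_groups=False):
    """Mirror the argument-ordering logic of _compute_conv2d after this patch."""
    args = [data, kernel, strides, padding, dilation]
    if has_groups:
        args.append(attrs.groups)
    if need_data_layout:
        args.append(attrs.get_str("data_layout"))
    if need_kernel_layout:  # new flag added by this patch
        args.append(attrs.get_str("kernel_layout"))
    if need_out_layout:
        args.append(attrs.get_str("out_layout"))
    args.append(out_dtype)
    return args


attrs = MockAttrs("NHWC", "HWIO", "NHWC", groups=1)
print(build_conv2d_args(attrs, "data", "kernel", (1, 1), (1, 1, 1, 1), (1, 1),
                        "float32", need_data_layout=True, need_kernel_layout=True,
                        has_groups=True))
# -> ['data', 'kernel', (1, 1), (1, 1, 1, 1), (1, 1), 1, 'NHWC', 'HWIO', 'float32']
```

Note the ordering: the kernel layout is appended after the data layout and before the output layout, so any compute function registered with these flags (such as the generic `topi.nn.conv` used above) must accept its layout arguments in that order.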