Skip to content

Commit

Permalink
[TOPI,CUDA] Don't enable cudnn conv2d kernel if is not supported (apa…
Browse files Browse the repository at this point in the history
…che#10021)

* [TOPI,CUDA] Don't enable cudnn conv2d kernel if is not supported

Specifically, check that layout is not NCHW if datatype is int8.

* remove all conv2d_cudnn int8 support
  • Loading branch information
Tristan Konolige authored and ylc committed Feb 16, 2022
1 parent d26530e commit c1637c8
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
and layout in ["NCHW", "NHWC"]
and padding[0] == padding[2]
and padding[1] == padding[3]
and not (data.dtype in ["uint8", "int8"] or kernel.dtype in ["uint8", "int8"])
):
# add cudnn implementation
if layout == "NHWC":
Expand Down Expand Up @@ -347,7 +348,12 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
# add cudnn implementation, if any
cudnn_impl = False
if target.kind.name == "cuda" and "cudnn" in target.libs:
if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and padding[1] == padding[3]:
if (
layout in ["NCHW", "NHWC"]
and padding[0] == padding[2]
and padding[1] == padding[3]
and not (data.dtype in ["uint8", "int8"] or kernel.dtype in ["uint8", "int8"])
):
strategy.add_implementation(
wrap_compute_conv2d(
topi.cuda.conv2d_cudnn, need_data_layout=True, has_groups=True
Expand Down

0 comments on commit c1637c8

Please sign in to comment.