diff --git a/topi/python/topi/cuda/conv2d.py b/topi/python/topi/cuda/conv2d.py index d763a3366545f..4d764b02b99dc 100644 --- a/topi/python/topi/cuda/conv2d.py +++ b/topi/python/topi/cuda/conv2d.py @@ -104,9 +104,10 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou if cfg.template_key == 'winograd': return winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, pre_computed=False) - if cfg.template_key == 'int8' : - if (data.dtype=='int8' or data.dtype=='uint8'): - return conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out_dtype) + if cfg.template_key == 'int8': + if (data.dtype == 'int8' or data.dtype == 'uint8'): + return conv2d_NCHWc_int8( + cfg, data, kernel, strides, padding, dilation, layout, out_dtype) if layout == 'NCHW': return nn.conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)