From 094a0aa85db083971b4dd9717b2c0736e0715214 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Tue, 7 May 2019 15:01:56 -0700 Subject: [PATCH 1/7] Stashing for later maybe. --- python/tvm/relay/quantize/_annotate.py | 66 ++++++++++++++++++++++++-- 1 file changed, 61 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index e52ce142e5c3..caee00048fbf 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -20,6 +20,7 @@ import warnings import topi +from tvm import relay from . import _quantize from .quantize import QAnnotateKind, current_qconfig from .quantize import _conv_counter, _set_conv_counter @@ -153,9 +154,22 @@ def conv2d_rewrite(ref_call, new_args, ctx): input field, and rhs of conv will be quantized to weight field. Output would be in activation field""" cnt = _conv_counter() + print(cnt) if cnt < current_qconfig().skip_k_conv: _set_conv_counter(cnt + 1) return None + print(cnt) + + boundary_node = False + if current_qconfig().skip_conv_layers is not None: + leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] + if cnt not in leave_alone_indices and (cnt + 1) in leave_alone_indices: + # If this node is quantized and the next node isn't, we need to handle the boundary. + boundary_node = True + if cnt in leave_alone_indices: + _set_conv_counter(cnt + 1) + return None + _set_conv_counter(cnt + 1) lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) @@ -168,6 +182,17 @@ def conv2d_rewrite(ref_call, new_args, ctx): rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT) expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) + + # If this is a boundary node, also need to attach dequantization. + #if boundary_node: + # first transpose extra channel back into main channel. + #expr = relay.transpose(expr, [0, 1, 4, 2, 3]) + # Next we reshape to fuse channel and subchannel. Then convert to float. + #expr = relay.reshape(expr, [0, -3, 0, 0]).astype('float32') + + # just see if this even works. + #expr = relay.add(expr, expr) + return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) @@ -178,6 +203,11 @@ def dense_rewrite(ref_call, new_args, ctx): cnt = _conv_counter() if cnt < current_qconfig().skip_k_conv: return None + if current_qconfig().skip_conv_layers is not None: + leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] + if cnt in leave_alone_indices or (cnt + 1) in leave_alone_indices: + return None + lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) @@ -194,8 +224,13 @@ def dense_rewrite(ref_call, new_args, ctx): @register_annotate_function("multiply") def multiply_rewrite(ref_call, new_args, ctx): """Rewrite function for multiply.""" - if _conv_counter() <= current_qconfig().skip_k_conv: + cnt = _conv_counter() + if cnt <= current_qconfig().skip_k_conv: return None + if current_qconfig().skip_conv_layers is not None: + leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] + if cnt in leave_alone_indices or (cnt + 1) in leave_alone_indices: + return None lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) @@ -216,8 +251,13 @@ def multiply_rewrite(ref_call, new_args, ctx): @register_annotate_function("add") def add_rewrite(ref_call, new_args, ctx): """Rewrite function for add.""" - if _conv_counter() <= current_qconfig().skip_k_conv: + cnt = _conv_counter() + if cnt <= current_qconfig().skip_k_conv: return None + if current_qconfig().skip_conv_layers is not None: + leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] + if cnt in leave_alone_indices or (cnt + 1) in leave_alone_indices: + return None lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) @@ -244,8 +284,13 @@ def add_rewrite(ref_call, new_args, ctx): def identity_rewrite(ref_call, new_args, ctx): """Simply forward the original operation""" - if _conv_counter() <= current_qconfig().skip_k_conv: + cnt = _conv_counter() + if cnt <= current_qconfig().skip_k_conv: return None + if current_qconfig().skip_conv_layers is not None: + leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] + if cnt in leave_alone_indices or (cnt + 1) in leave_alone_indices: + return None x_expr, x_kind = _get_expr_kind(new_args[0]) if x_kind is None: @@ -262,8 +307,14 @@ def identity_rewrite(ref_call, new_args, ctx): def pool2d_rewrite(ref_call, new_args, ctx): """Rewrite function for max pool2d""" - if _conv_counter() <= current_qconfig().skip_k_conv: + cnt = _conv_counter() + if cnt <= current_qconfig().skip_k_conv: return None + if current_qconfig().skip_conv_layers is not None: + leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] + if cnt in leave_alone_indices or (cnt + 1) in leave_alone_indices: + return None + expr, x_kind = _get_expr_kind(new_args[0]) if x_kind is None: @@ -280,8 +331,13 @@ def pool2d_rewrite(ref_call, new_args, ctx): @register_annotate_function("concatenate") def concatenate_rewrite(ref_call, new_args, ctx): """Rewrite function for concatenate""" - if _conv_counter() <= current_qconfig().skip_k_conv: + cnt = _conv_counter() + if cnt <= current_qconfig().skip_k_conv: return None + if current_qconfig().skip_conv_layers is not None: + leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] + if cnt in leave_alone_indices: + return None input_tuple = new_args[0] expr_list = [_get_expr_kind(x)[0] for x in input_tuple] From e963eba20e1d776fcc44dff1361a3dbb7bf93b37 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Thu, 9 May 2019 11:00:41 -0700 Subject: [PATCH 2/7] Added new option to leave specific layers unquantized. --- python/tvm/relay/quantize/_annotate.py | 28 ++++++-------------------- python/tvm/relay/quantize/quantize.py | 5 +++++ src/relay/pass/quantize.cc | 3 ++- src/relay/pass/quantize.h | 2 ++ topi/python/topi/cuda/conv2d.py | 6 ++++-- 5 files changed, 19 insertions(+), 25 deletions(-) diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index caee00048fbf..c001288fbb0b 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -154,18 +154,12 @@ def conv2d_rewrite(ref_call, new_args, ctx): input field, and rhs of conv will be quantized to weight field. Output would be in activation field""" cnt = _conv_counter() - print(cnt) if cnt < current_qconfig().skip_k_conv: _set_conv_counter(cnt + 1) return None - print(cnt) - boundary_node = False if current_qconfig().skip_conv_layers is not None: leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if cnt not in leave_alone_indices and (cnt + 1) in leave_alone_indices: - # If this node is quantized and the next node isn't, we need to handle the boundary. - boundary_node = True if cnt in leave_alone_indices: _set_conv_counter(cnt + 1) return None @@ -183,16 +177,6 @@ def conv2d_rewrite(ref_call, new_args, ctx): expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) - # If this is a boundary node, also need to attach dequantization. - #if boundary_node: - # first transpose extra channel back into main channel. - #expr = relay.transpose(expr, [0, 1, 4, 2, 3]) - # Next we reshape to fuse channel and subchannel. Then convert to float. - #expr = relay.reshape(expr, [0, -3, 0, 0]).astype('float32') - - # just see if this even works. - #expr = relay.add(expr, expr) - return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) @@ -205,7 +189,7 @@ def dense_rewrite(ref_call, new_args, ctx): return None if current_qconfig().skip_conv_layers is not None: leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if cnt in leave_alone_indices or (cnt + 1) in leave_alone_indices: + if cnt in leave_alone_indices or (cnt - 1) in leave_alone_indices: return None lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) @@ -229,7 +213,7 @@ def multiply_rewrite(ref_call, new_args, ctx): return None if current_qconfig().skip_conv_layers is not None: leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if cnt in leave_alone_indices or (cnt + 1) in leave_alone_indices: + if cnt in leave_alone_indices or (cnt - 1) in leave_alone_indices: return None lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) @@ -256,7 +240,7 @@ def add_rewrite(ref_call, new_args, ctx): return None if current_qconfig().skip_conv_layers is not None: leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if cnt in leave_alone_indices or (cnt + 1) in leave_alone_indices: + if cnt in leave_alone_indices or (cnt - 1) in leave_alone_indices: return None lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) @@ -289,7 +273,7 @@ def identity_rewrite(ref_call, new_args, ctx): return None if current_qconfig().skip_conv_layers is not None: leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if cnt in leave_alone_indices or (cnt + 1) in leave_alone_indices: + if cnt in leave_alone_indices or (cnt - 1) in leave_alone_indices: return None x_expr, x_kind = _get_expr_kind(new_args[0]) @@ -312,7 +296,7 @@ def pool2d_rewrite(ref_call, new_args, ctx): return None if current_qconfig().skip_conv_layers is not None: leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if cnt in leave_alone_indices or (cnt + 1) in leave_alone_indices: + if cnt in leave_alone_indices or (cnt - 1) in leave_alone_indices: return None expr, x_kind = _get_expr_kind(new_args[0]) @@ -336,7 +320,7 @@ def concatenate_rewrite(ref_call, new_args, ctx): return None if current_qconfig().skip_conv_layers is not None: leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if cnt in leave_alone_indices: + if cnt in leave_alone_indices or (cnt - 1) in leave_alone_indices: return None input_tuple = new_args[0] diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 607ee1821c86..237f1d23b2c5 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -71,6 +71,7 @@ class QConfig(NodeBase): "dtype_activation": "int32", "global_scale": 8.0, "skip_k_conv": 1, + "skip_conv_layers": None, "round_for_shift": True, "store_lowbit_output": True, "debug_enabled_ops": None, @@ -139,6 +140,10 @@ def qconfig(**kwargs): skip_k_conv: int The number of skipped conv2d. + skip_conv_layers: list + Different way of specifying which layers to avoid. Provide a list of indices + that indicate which conv2d layers to leave untouched. + round_for_shift: boolean Whether to add bias for rounding during shift. diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize.cc index 7fd27b46ad6a..30bf1ea7ecc9 100644 --- a/src/relay/pass/quantize.cc +++ b/src/relay/pass/quantize.cc @@ -439,7 +439,7 @@ Expr AddRealize(const Call& ref_call, Expr ret = ForwardOp(ref_call, ret_args); return QRealizeIntExprNode::make(ret, dom_scale, dtype); } - CHECK(!new_args[0]->derived_from() && !new_args[1]->derived_from()); + //CHECK(!new_args[0]->derived_from() && !new_args[1]->derived_from()); return Expr(nullptr); } @@ -596,6 +596,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << "nbit_activation=" << op->nbit_activation << ", "; p->stream << "global_scale=" << op->global_scale << ", "; p->stream << "skip_k_conv==" << op->skip_k_conv << ", "; + p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", "; p->stream << "round_for_shift==" << op->round_for_shift << ", "; p->stream << "store_lowbit_output==" << op->store_lowbit_output << ", "; p->stream << "debug_enabled_ops==" << op->debug_enabled_ops << ", "; diff --git a/src/relay/pass/quantize.h b/src/relay/pass/quantize.h index 4d26dd6be4a5..5864f51629ed 100644 --- a/src/relay/pass/quantize.h +++ b/src/relay/pass/quantize.h @@ -126,6 +126,7 @@ class QConfigNode : public Node { DataType dtype_activation = Int(32); double global_scale = 8.0; int skip_k_conv = 1; + Array skip_conv_layers = Array(NodePtr(nullptr)); bool round_for_shift = true; bool store_lowbit_output = true; Array debug_enabled_ops = Array(NodePtr(nullptr)); @@ -140,6 +141,7 @@ class QConfigNode : public Node { v->Visit("dtype_activation", &dtype_activation); v->Visit("global_scale", &global_scale); v->Visit("skip_k_conv", &skip_k_conv); + v->Visit("skip_conv_layers", &skip_conv_layers); v->Visit("round_for_shift", &round_for_shift); v->Visit("store_lowbit_output", &store_lowbit_output); v->Visit("debug_enabled_ops", &debug_enabled_ops); diff --git a/topi/python/topi/cuda/conv2d.py b/topi/python/topi/cuda/conv2d.py index 006e2fc5aaf8..297773cb1068 100644 --- a/topi/python/topi/cuda/conv2d.py +++ b/topi/python/topi/cuda/conv2d.py @@ -17,6 +17,7 @@ # pylint: disable=invalid-name """Compute definition for conv2d with cuda backend""" import tvm +import topi from tvm import autotvm from tvm.contrib import cudnn @@ -104,8 +105,9 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou if cfg.template_key == 'winograd': return winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, pre_computed=False) - if cfg.template_key == 'int8': - return conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out_dtype) + if cfg.template_key == 'int8' : + if (data.dtype=='int8' or data.dtype=='uint8'): + return conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out_dtype) if layout == 'NCHW': return nn.conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype) From bb3464ae528cf9993971be7ea08556c4d4f37594 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 10 May 2019 14:15:37 -0700 Subject: [PATCH 3/7] Better error checking. --- python/tvm/relay/quantize/_annotate.py | 12 ++++++------ src/relay/pass/quantize.cc | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index c001288fbb0b..f708f6df9248 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -189,7 +189,7 @@ def dense_rewrite(ref_call, new_args, ctx): return None if current_qconfig().skip_conv_layers is not None: leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if cnt in leave_alone_indices or (cnt - 1) in leave_alone_indices: + if (cnt - 1) in leave_alone_indices: return None lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) @@ -213,7 +213,7 @@ def multiply_rewrite(ref_call, new_args, ctx): return None if current_qconfig().skip_conv_layers is not None: leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if cnt in leave_alone_indices or (cnt - 1) in leave_alone_indices: + if (cnt - 1) in leave_alone_indices: return None lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) @@ -240,7 +240,7 @@ def add_rewrite(ref_call, new_args, ctx): return None if current_qconfig().skip_conv_layers is not None: leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if cnt in leave_alone_indices or (cnt - 1) in leave_alone_indices: + if (cnt - 1) in leave_alone_indices: return None lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) @@ -273,7 +273,7 @@ def identity_rewrite(ref_call, new_args, ctx): return None if current_qconfig().skip_conv_layers is not None: leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if cnt in leave_alone_indices or (cnt - 1) in leave_alone_indices: + if (cnt - 1) in leave_alone_indices: return None x_expr, x_kind = _get_expr_kind(new_args[0]) @@ -296,7 +296,7 @@ def pool2d_rewrite(ref_call, new_args, ctx): return None if current_qconfig().skip_conv_layers is not None: leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if cnt in leave_alone_indices or (cnt - 1) in leave_alone_indices: + if (cnt - 1) in leave_alone_indices: return None expr, x_kind = _get_expr_kind(new_args[0]) @@ -320,7 +320,7 @@ def concatenate_rewrite(ref_call, new_args, ctx): return None if current_qconfig().skip_conv_layers is not None: leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if cnt in leave_alone_indices or (cnt - 1) in leave_alone_indices: + if (cnt - 1) in leave_alone_indices: return None input_tuple = new_args[0] diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize.cc index 30bf1ea7ecc9..3a2e54c8ad39 100644 --- a/src/relay/pass/quantize.cc +++ b/src/relay/pass/quantize.cc @@ -439,7 +439,7 @@ Expr AddRealize(const Call& ref_call, Expr ret = ForwardOp(ref_call, ret_args); return QRealizeIntExprNode::make(ret, dom_scale, dtype); } - //CHECK(!new_args[0]->derived_from() && !new_args[1]->derived_from()); + CHECK(!new_args[0]->derived_from() && !new_args[1]->derived_from()); return Expr(nullptr); } From d0623915c91ff90aa24bd7a3460b472e839aa7ba Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 10 May 2019 14:16:51 -0700 Subject: [PATCH 4/7] remove unneeded import --- topi/python/topi/cuda/conv2d.py | 1 - 1 file changed, 1 deletion(-) diff --git a/topi/python/topi/cuda/conv2d.py b/topi/python/topi/cuda/conv2d.py index 297773cb1068..d763a3366545 100644 --- a/topi/python/topi/cuda/conv2d.py +++ b/topi/python/topi/cuda/conv2d.py @@ -17,7 +17,6 @@ # pylint: disable=invalid-name """Compute definition for conv2d with cuda backend""" import tvm -import topi from tvm import autotvm from tvm.contrib import cudnn From cd43e2fa3a031f1b2a6fd9b6d7bba93374a19403 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 10 May 2019 14:17:57 -0700 Subject: [PATCH 5/7] tab to spaces --- src/relay/pass/quantize.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/pass/quantize.h b/src/relay/pass/quantize.h index 5864f51629ed..2c70da177199 100644 --- a/src/relay/pass/quantize.h +++ b/src/relay/pass/quantize.h @@ -141,7 +141,7 @@ class QConfigNode : public Node { v->Visit("dtype_activation", &dtype_activation); v->Visit("global_scale", &global_scale); v->Visit("skip_k_conv", &skip_k_conv); - v->Visit("skip_conv_layers", &skip_conv_layers); + v->Visit("skip_conv_layers", &skip_conv_layers); v->Visit("round_for_shift", &round_for_shift); v->Visit("store_lowbit_output", &store_lowbit_output); v->Visit("debug_enabled_ops", &debug_enabled_ops); From 1ba47de56f7ab4d76897c0769872b5d8ea703c21 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 10 May 2019 14:30:41 -0700 Subject: [PATCH 6/7] pylint fixes --- python/tvm/relay/quantize/_annotate.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index f708f6df9248..9bf546fcdadf 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -20,7 +20,6 @@ import warnings import topi -from tvm import relay from . import _quantize from .quantize import QAnnotateKind, current_qconfig from .quantize import _conv_counter, _set_conv_counter @@ -189,7 +188,7 @@ def dense_rewrite(ref_call, new_args, ctx): return None if current_qconfig().skip_conv_layers is not None: leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if (cnt - 1) in leave_alone_indices: + if cnt - 1 in leave_alone_indices: return None lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) @@ -213,7 +212,7 @@ def multiply_rewrite(ref_call, new_args, ctx): return None if current_qconfig().skip_conv_layers is not None: leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if (cnt - 1) in leave_alone_indices: + if cnt - 1 in leave_alone_indices: return None lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) @@ -240,7 +239,7 @@ def add_rewrite(ref_call, new_args, ctx): return None if current_qconfig().skip_conv_layers is not None: leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if (cnt - 1) in leave_alone_indices: + if cnt - 1 in leave_alone_indices: return None lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) @@ -273,7 +272,7 @@ def identity_rewrite(ref_call, new_args, ctx): return None if current_qconfig().skip_conv_layers is not None: leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if (cnt - 1) in leave_alone_indices: + if cnt - 1 in leave_alone_indices: return None x_expr, x_kind = _get_expr_kind(new_args[0]) @@ -296,7 +295,7 @@ def pool2d_rewrite(ref_call, new_args, ctx): return None if current_qconfig().skip_conv_layers is not None: leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if (cnt - 1) in leave_alone_indices: + if cnt - 1 in leave_alone_indices: return None expr, x_kind = _get_expr_kind(new_args[0]) @@ -320,7 +319,7 @@ def concatenate_rewrite(ref_call, new_args, ctx): return None if current_qconfig().skip_conv_layers is not None: leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if (cnt - 1) in leave_alone_indices: + if cnt - 1 in leave_alone_indices: return None input_tuple = new_args[0] From bbf128a82363b812b73bab80eaf3393fee43a2bf Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 10 May 2019 14:34:10 -0700 Subject: [PATCH 7/7] more pylint fixes --- topi/python/topi/cuda/conv2d.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/topi/python/topi/cuda/conv2d.py b/topi/python/topi/cuda/conv2d.py index d763a3366545..4d764b02b99d 100644 --- a/topi/python/topi/cuda/conv2d.py +++ b/topi/python/topi/cuda/conv2d.py @@ -104,9 +104,10 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou if cfg.template_key == 'winograd': return winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, pre_computed=False) - if cfg.template_key == 'int8' : - if (data.dtype=='int8' or data.dtype=='uint8'): - return conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out_dtype) + if cfg.template_key == 'int8': + if (data.dtype == 'int8' or data.dtype == 'uint8'): + return conv2d_NCHWc_int8( + cfg, data, kernel, strides, padding, dilation, layout, out_dtype) if layout == 'NCHW': return nn.conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)