[Relay] Option to select which convolution layers are quantized. #3173

Merged · 7 commits · May 16, 2019
49 changes: 44 additions & 5 deletions python/tvm/relay/quantize/_annotate.py
@@ -156,6 +156,13 @@ def conv2d_rewrite(ref_call, new_args, ctx):
if cnt < current_qconfig().skip_k_conv:
_set_conv_counter(cnt + 1)
return None

if current_qconfig().skip_conv_layers is not None:
leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
if cnt in leave_alone_indices:
_set_conv_counter(cnt + 1)
return None

_set_conv_counter(cnt + 1)

lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
@@ -168,6 +175,7 @@ def conv2d_rewrite(ref_call, new_args, ctx):
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)

expr = _forward_op(ref_call, [lhs_expr, rhs_expr])

return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)


@@ -178,6 +186,11 @@ def dense_rewrite(ref_call, new_args, ctx):
cnt = _conv_counter()
if cnt < current_qconfig().skip_k_conv:
return None
if current_qconfig().skip_conv_layers is not None:
leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
if cnt - 1 in leave_alone_indices:
return None

lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])

@@ -194,8 +207,13 @@ def dense_rewrite(ref_call, new_args, ctx):
@register_annotate_function("multiply")
def multiply_rewrite(ref_call, new_args, ctx):
"""Rewrite function for multiply."""
if _conv_counter() <= current_qconfig().skip_k_conv:
cnt = _conv_counter()
if cnt <= current_qconfig().skip_k_conv:
return None
if current_qconfig().skip_conv_layers is not None:
leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
if cnt - 1 in leave_alone_indices:
return None

lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
@@ -216,8 +234,13 @@ def multiply_rewrite(ref_call, new_args, ctx):
@register_annotate_function("add")
def add_rewrite(ref_call, new_args, ctx):
"""Rewrite function for add."""
if _conv_counter() <= current_qconfig().skip_k_conv:
cnt = _conv_counter()
if cnt <= current_qconfig().skip_k_conv:
return None
if current_qconfig().skip_conv_layers is not None:
leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
if cnt - 1 in leave_alone_indices:
return None

lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
@@ -244,8 +267,13 @@ def add_rewrite(ref_call, new_args, ctx):

def identity_rewrite(ref_call, new_args, ctx):
"""Simply forward the original operation"""
if _conv_counter() <= current_qconfig().skip_k_conv:
cnt = _conv_counter()
if cnt <= current_qconfig().skip_k_conv:
return None
if current_qconfig().skip_conv_layers is not None:
leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
if cnt - 1 in leave_alone_indices:
return None

x_expr, x_kind = _get_expr_kind(new_args[0])
if x_kind is None:
@@ -262,8 +290,14 @@ def identity_rewrite(ref_call, new_args, ctx):

def pool2d_rewrite(ref_call, new_args, ctx):
"""Rewrite function for max pool2d"""
if _conv_counter() <= current_qconfig().skip_k_conv:
cnt = _conv_counter()
if cnt <= current_qconfig().skip_k_conv:
return None
if current_qconfig().skip_conv_layers is not None:
leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
if cnt - 1 in leave_alone_indices:
return None

expr, x_kind = _get_expr_kind(new_args[0])

if x_kind is None:
@@ -280,8 +314,13 @@ def pool2d_rewrite(ref_call, new_args, ctx):
@register_annotate_function("concatenate")
def concatenate_rewrite(ref_call, new_args, ctx):
"""Rewrite function for concatenate"""
if _conv_counter() <= current_qconfig().skip_k_conv:
cnt = _conv_counter()
if cnt <= current_qconfig().skip_k_conv:
return None
if current_qconfig().skip_conv_layers is not None:
leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
if cnt - 1 in leave_alone_indices:
return None

input_tuple = new_args[0]
expr_list = [_get_expr_kind(x)[0] for x in input_tuple]
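Note that the same skip_conv_layers check is repeated in each rewrite function above. As a minimal sketch, that pattern could be factored into one helper; this is hypothetical and not part of the diff, and assumes the same module scope as _annotate.py, where current_qconfig is already available:

    def _layer_is_skipped(index):
        """Hypothetical helper: True if the conv2d layer at `index` should stay in float.

        conv2d_rewrite would pass the current counter value; the other rewrites
        pass cnt - 1, because the counter has already been advanced past their
        preceding conv2d.
        """
        cfg = current_qconfig()
        if cfg.skip_conv_layers is None:
            return False
        return index in [int(x) for x in cfg.skip_conv_layers]
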
5 changes: 5 additions & 0 deletions python/tvm/relay/quantize/quantize.py
@@ -71,6 +71,7 @@ class QConfig(NodeBase):
"dtype_activation": "int32",
"global_scale": 8.0,
"skip_k_conv": 1,
"skip_conv_layers": None,
"round_for_shift": True,
"store_lowbit_output": True,
"debug_enabled_ops": None,
@@ -139,6 +140,10 @@ def qconfig(**kwargs):
skip_k_conv: int
The number of skipped conv2d.

skip_conv_layers: list
Different way of specifying which layers to avoid. Provide a list of indices
that indicate which conv2d layers to leave untouched.

round_for_shift: boolean
Whether to add bias for rounding during shift.

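For illustration, a minimal usage sketch of the new option (not taken from this PR; it assumes a Relay function mod with bound parameters params obtained from an earlier frontend import, e.g. relay.frontend.from_mxnet):

    from tvm import relay

    # Quantize the model but leave conv2d layers 0 and 2 in float.
    with relay.quantize.qconfig(skip_conv_layers=[0, 2], global_scale=8.0):
        qfunc = relay.quantize.quantize(mod, params=params)

The indices refer to the annotator's conv2d counter, i.e. conv2d ops in the order the rewriter visits them.
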
1 change: 1 addition & 0 deletions src/relay/pass/quantize.cc
@@ -596,6 +596,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "nbit_activation=" << op->nbit_activation << ", ";
p->stream << "global_scale=" << op->global_scale << ", ";
p->stream << "skip_k_conv==" << op->skip_k_conv << ", ";
p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", ";
p->stream << "round_for_shift==" << op->round_for_shift << ", ";
p->stream << "store_lowbit_output==" << op->store_lowbit_output << ", ";
p->stream << "debug_enabled_ops==" << op->debug_enabled_ops << ", ";
2 changes: 2 additions & 0 deletions src/relay/pass/quantize.h
@@ -126,6 +126,7 @@ class QConfigNode : public Node {
DataType dtype_activation = Int(32);
double global_scale = 8.0;
int skip_k_conv = 1;
Array<Expr> skip_conv_layers = Array<Expr>(NodePtr<Node>(nullptr));
bool round_for_shift = true;
bool store_lowbit_output = true;
Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
@@ -140,6 +141,7 @@
v->Visit("dtype_activation", &dtype_activation);
v->Visit("global_scale", &global_scale);
v->Visit("skip_k_conv", &skip_k_conv);
v->Visit("skip_conv_layers", &skip_conv_layers);
v->Visit("round_for_shift", &round_for_shift);
v->Visit("store_lowbit_output", &store_lowbit_output);
v->Visit("debug_enabled_ops", &debug_enabled_ops);
4 changes: 3 additions & 1 deletion topi/python/topi/cuda/conv2d.py
@@ -105,7 +105,9 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou
return winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
pre_computed=False)
if cfg.template_key == 'int8':
return conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out_dtype)
if (data.dtype == 'int8' or data.dtype == 'uint8'):

Review comment (Member):
Does it work with uint8? conv2d compute and dp4a need some changes since they are hardcoded as int8 (although it should work with uint8)

Reply (Contributor, PR author):
You're right that it only supports int8 now but could be extended to uint8. I added this check because I was doing some quantization without autotuning, which caused the incorrect convolution algorithm to be chosen in some cases. Since autotuning is the intended workflow I think dropping this would be fine.

return conv2d_NCHWc_int8(
cfg, data, kernel, strides, padding, dilation, layout, out_dtype)

if layout == 'NCHW':
return nn.conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)
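Regarding the review exchange above, where autotuning is described as the intended workflow: a minimal sketch (not from this PR) of applying tuned records so the conv2d template choice comes from the tuning log rather than from dtype heuristics. It assumes a log file conv2d_tuning.log produced by AutoTVM and a quantized Relay function qfunc with params:

    from tvm import autotvm, relay

    # Apply previously tuned schedules; with a matching record in the log,
    # the dispatcher selects the tuned template (e.g. 'int8') per workload.
    with autotvm.apply_history_best("conv2d_tuning.log"):
        with relay.build_config(opt_level=3):
            graph, lib, params = relay.build(qfunc, target="cuda", params=params)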