[QNN] Add hardswish int8 impl using table lookup (#11700)
* v1

* [QNN] Add hardswish int8 impl using table lookup

* format

* format

* fix

* fix utest

* fix ci error

* jostle ci

* trigger ci

* remove nn

* jostle ci

* fix
zhaoyang-star authored Jun 28, 2022
1 parent 688b082 commit 97b3076
Showing 8 changed files with 91 additions and 21 deletions.
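
For context: qnn.hardswish follows the usual dequantize -> fp32 hardswish -> requantize recipe, and because an 8-bit tensor can only take 256 distinct values the whole mapping can be precomputed once and applied as a table lookup, which is what this patch wires up. A rough numpy sketch of the per-element reference math (the scales and zero points below are made up for illustration, not values from this patch):

import numpy as np

def qnn_hardswish_ref(q_x, input_scale, input_zp, output_scale, output_zp):
    x = (q_x.astype(np.float32) - input_zp) * input_scale  # dequantize
    y = x * np.clip(x + 3.0, 0.0, 6.0) / 6.0  # fp32 hardswish
    q_y = np.round(y / output_scale) + output_zp  # requantize
    return np.clip(q_y, 0, 255).astype(np.uint8)  # uint8 output

q_x = np.array([0, 64, 128, 255], dtype=np.uint8)
print(qnn_hardswish_ref(q_x, 0.047, 128, 0.024, 0))
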
29 changes: 20 additions & 9 deletions python/tvm/relay/frontend/qnn_torch.py
@@ -981,14 +981,10 @@ def _impl(inputs, _):
     return _impl


-def _hswish():
-    # refer to src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
-    # They fallback to fp32
-    def _impl(inputs, _):
-        assert len(inputs) == 5, "Input quant params not found in op inputs"
-        # TODO(masahi): Replace this with integer only compute.
-        # We do not have to strictly follow how PyTorch does it.
-
+def _hswish(fp32_piggy_back=False):
+    def _impl_fp32(inputs):
+        # refer to src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
+        # They fallback to fp32
         def relu6(x):
             return _op.tensor.clip(x, 0.0, 6.0)

@@ -1007,6 +1003,21 @@ def hardsigmoid(x):
             dequantized_hswish, output_scale, output_zero_point, out_dtype="uint8"
         )

+    def _impl_int8(inputs):
+        output_scale = _expr.const(inputs[1])
+        output_zero_point = _expr.const(inputs[2])
+        input_scale = _expr.const(inputs[3])
+        input_zero_point = _expr.const(inputs[4])
+        return relay.qnn.op.hardswish(
+            inputs[0], input_scale, input_zero_point, output_scale, output_zero_point
+        )
+
+    def _impl(inputs, _):
+        assert len(inputs) == 5, "Input quant params not found in op inputs"
+        if fp32_piggy_back:
+            return _impl_fp32(inputs)
+        return _impl_int8(inputs)
+
     return _impl

@@ -1153,6 +1164,6 @@ def _impl(inputs, _):
     "quantized::relu6": _relu6(),
     "quantized::leaky_relu": _leaky_relu(),
     "quantized::linear_dynamic": _linear_dynamic(),
-    "quantized::hardswish": _hswish(),
+    "quantized::hardswish": _hswish(fp32_piggy_back=False),
     "quantized::conv_transpose2d": _quantized_conv_transpose2d(),
 }
7 changes: 7 additions & 0 deletions python/tvm/relay/qnn/op/legalizations.py
@@ -67,11 +67,18 @@ def legalize_qnn_unary_op(attrs, inputs, types):
     return reg.register_qnn_legalize(op_name, legalize_qnn_unary_op)


+def hardswish_func(x):
+    x2 = x + 3.0
+    x2 = np.clip(x2, 0.0, 6.0)
+    return x * x2 / 6.0
+
+
 register_qnn_unary_op_legalize("qnn.sqrt", np.sqrt)
 register_qnn_unary_op_legalize("qnn.rsqrt", lambda arr: 1 / np.sqrt(arr))
 register_qnn_unary_op_legalize("qnn.exp", np.exp)
 register_qnn_unary_op_legalize("qnn.erf", special.erf)
 register_qnn_unary_op_legalize("qnn.sigmoid", lambda arr: 1 / (1 + np.exp(-arr)))
+register_qnn_unary_op_legalize("qnn.hardswish", hardswish_func)
 register_qnn_unary_op_legalize("qnn.tanh", np.tanh)
 register_qnn_unary_op_legalize("qnn.log", np.log)

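A minimal numpy sketch of the table-lookup idea behind the legalization registered above, assuming uint8 quantization; build_hardswish_lut is an illustrative helper (the actual rewrite happens inside legalize_qnn_unary_op), and the inlined hardswish below is equivalent to hardswish_func:

import numpy as np

def hardswish_func(x):  # same math as the legalization above
    return x * np.clip(x + 3.0, 0.0, 6.0) / 6.0

def build_hardswish_lut(input_scale, input_zp, output_scale, output_zp):
    qx = np.arange(0, 256, dtype=np.float32)  # every possible uint8 input value
    fx = hardswish_func((qx - input_zp) * input_scale)  # dequantize + fp32 hardswish
    qy = np.round(fx / output_scale) + output_zp  # requantize
    return np.clip(qy, 0, 255).astype(np.uint8)  # 256-entry table

lut = build_hardswish_lut(0.047, 128, 0.024, 0)
x = np.array([0, 64, 128, 255], dtype=np.uint8)
y = np.take(lut, x)  # at runtime the op reduces to a single gather
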
35 changes: 35 additions & 0 deletions python/tvm/relay/qnn/op/qnn.py
@@ -998,6 +998,41 @@ def sigmoid(x, scale, zero_point, output_scale, output_zero_point):
     )


+def hardswish(x, scale, zero_point, output_scale, output_zero_point):
+    """Quantized hardswish.
+
+    Parameters
+    ----------
+    x : relay.Expr
+        The quantized input tensor.
+
+    scale: relay.Expr
+        The scale of the quantized expr.
+
+    zero_point: relay.Expr
+        The zero point of quantized expr.
+
+    output_scale: relay.Expr
+        The scale of the output quantized expr.
+
+    output_zero_point: relay.Expr
+        The zero point of output quantized expr.
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+
+    """
+    return _make.hardswish(
+        x,
+        scale,
+        zero_point,
+        output_scale,
+        output_zero_point,
+    )
+
+
 def log(x, scale, zero_point, output_scale, output_zero_point):
     """Quantized log.
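A short usage sketch for the new relay.qnn.op.hardswish API above; the shape and quantization parameters are invented for illustration:

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 8), dtype="uint8")
y = relay.qnn.op.hardswish(
    x,
    relay.const(0.047, "float32"),  # input scale
    relay.const(128, "int32"),  # input zero point
    relay.const(0.024, "float32"),  # output scale
    relay.const(0, "int32"),  # output zero point
)
mod = tvm.IRModule.from_expr(relay.Function([x], y))
# The QNN Legalize/Canonicalize passes that run during compilation rewrite this
# call into integer arithmetic: the table lookup registered in legalizations.py
# or the Hardswish canonicalization added in unary_elementwise_op.cc.
print(mod)
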
1 change: 1 addition & 0 deletions python/tvm/relay/transform/fake_quantization_to_integer.py
@@ -541,5 +541,6 @@ def unary(expr, type_map):
 register_unary_qnn("exp", relay.qnn.op.exp)
 register_unary_qnn("erf", relay.qnn.op.erf)
 register_unary_qnn("sigmoid", relay.qnn.op.sigmoid)
+register_unary_qnn("hardswish", relay.qnn.op.hardswish)
 register_unary_qnn("tanh", relay.qnn.op.tanh)
 register_unary_qnn("log", relay.qnn.op.log)
4 changes: 4 additions & 0 deletions src/relay/qnn/op/unary_elementwise_op.cc
@@ -46,6 +46,10 @@ QNN_CREATE_UNARY_ELEMENTWISE_OP("erf").set_attr<FTVMLegalize>(
 QNN_CREATE_UNARY_ELEMENTWISE_OP("sigmoid").set_attr<FTVMLegalize>(
     "FTVMQnnCanonicalize", QNN_UNARY_OP_DEFAULT_CANONICALIZATION(Sigmoid));

+QNN_CREATE_UNARY_ELEMENTWISE_OP("hardswish")
+    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize",
+                            QNN_UNARY_OP_DEFAULT_CANONICALIZATION(Hardswish));
+
 QNN_CREATE_UNARY_ELEMENTWISE_OP("log").set_attr<FTVMLegalize>(
     "FTVMQnnCanonicalize", QNN_UNARY_OP_DEFAULT_CANONICALIZATION(Log));

10 changes: 10 additions & 0 deletions src/relay/transforms/pattern_utils.h
@@ -787,6 +787,16 @@ static inline Expr BroadCastTo(Expr data, Array<IndexExpr> shape) {
   return MakeBroadCastTo(data, CheckConstantShapeArrayInteger(shape));
 }

+inline Expr Hardswish(Expr x) {
+  auto three = MakeConstantScalar(DataType::Float(32), 3.0);
+  auto six = MakeConstantScalar(DataType::Float(32), 6.0);
+  auto x2 = Add(x, three);
+  x2 = Clip(x2, 0.0, 6.0);
+  x2 = Multiply(x, x2);
+  x2 = Divide(x2, six);
+  return x2;
+}
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_TRANSFORMS_PATTERN_UTILS_H_
17 changes: 5 additions & 12 deletions tests/python/frontend/pytorch/qnn_test.py
@@ -184,18 +184,10 @@ def fuse_model(self):
 class Hswish(nn.Module):
     def __init__(self, add_stub=False):
         super().__init__()
-        self.quant = QuantStub()
-        self.dequant = DeQuantStub()
-        self.add_stub = add_stub
-        self.hswish = nn.Hardswish()
+        self.hswish = QuantWrapper(nn.Hardswish())

     def forward(self, x):
-        if self.add_stub:
-            x = self.quant(x)
-        x = self.hswish(x)
-        if self.add_stub:
-            x = self.dequant(x)
-        return x
+        return self.hswish(x)

     def fuse_model(self):
         pass
@@ -310,7 +302,7 @@ def test_quantized_modules():
         ("linear_relu" + postfix, (16, 16), Linear(with_relu=True), per_channel),
         ("conv_transpose", imagenet_ishape, ConvTranspose(), False),
         ("hsigmoid", imagenet_ishape, Hsigmoid(add_stub=True), False),
-        ("hswish", imagenet_ishape, Hswish(add_stub=True), False),
+        ("hswish", imagenet_ishape, Hswish(), False),
         ("semodule", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), False),
         ("semodule, per_channel", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), True),
         ("mul_scalar negative", imagenet_ishape, MulScalarNegative(), False),
@@ -372,7 +364,8 @@ def test_quantized_modules():
     linear, per_channel 0.0 0.0 1.0
     linear_relu, per_channel 0.0 0.0 1.0
     hsigmoid 0.002614379 0.00020525524 0.9214896896258503
-    hswish 0.0052286386 0.00063522335 0.7587359162414966
+    hswish 0.0026143193 1.7367661e-08 0.9999933567176871
+    hswish, per_channel 0.0 0.0 1.0
     semodule, per_channel 0.0039885044 0.0008620687 0.7838592529296875
     mul_scalar negative 0.0011764616 7.815566e-09 0.9999933567176871
     """
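For reference, a hedged sketch of how a module like the rewritten Hswish wrapper above is typically post-training quantized and imported through the PyTorch frontend; the input name and calibration tensor are placeholders, and the test file's own helpers may differ:

import torch
from torch import nn
from torch.quantization import QuantWrapper, convert, get_default_qconfig, prepare
from tvm import relay

model = QuantWrapper(nn.Hardswish()).eval()
model.qconfig = get_default_qconfig("fbgemm")
prepare(model, inplace=True)
model(torch.randn(1, 3, 224, 224))  # one calibration pass to collect ranges
convert(model, inplace=True)  # now dispatches to quantized::hardswish

script = torch.jit.trace(model, torch.randn(1, 3, 224, 224)).eval()
mod, params = relay.frontend.from_pytorch(script, [("input", (1, 3, 224, 224))])
# quantized::hardswish is handled by _hswish() in qnn_torch.py; with
# fp32_piggy_back=False it maps directly to relay.qnn.op.hardswish.
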
9 changes: 9 additions & 0 deletions tests/python/relay/test_op_qnn_unary_elementwise.py
@@ -23,6 +23,7 @@
 import tvm
 import tvm.testing
 from tvm import relay
+from tvm.relay.qnn.op.legalizations import hardswish_func


 def dequantize(data, scale, zp):
@@ -209,5 +210,13 @@ def test_all_numbers_int8(self):
         generic_test(relay.qnn.op.sigmoid, lambda x: 1 / (1 + np.exp(-x)), input_dtype="int8")


+class TestHardswish:
+    def test_all_numbers_uint8(self):
+        generic_test(relay.qnn.op.hardswish, hardswish_func, input_dtype="uint8")
+
+    def test_all_numbers_int8(self):
+        generic_test(relay.qnn.op.hardswish, hardswish_func, input_dtype="int8")
+
+
 if __name__ == "__main__":
     tvm.testing.main()
