[Frontend][TFLite] Add support for relu6, leaky_relu, relu_n1_to_1, log_softmax

* add implementation in parser
* add qnn tests for each operator
inadob committed Jun 4, 2020
1 parent 34c95a8 commit 1c2435d
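
The four converters added in this commit differ only in the float operation they emit (clip for RELU6 and RELU_N1_TO_1, nn.leaky_relu, and nn.log_softmax); around that op they share one quantization pattern: dequantize a quantized input, compute in float, then requantize to the output tensor's parameters. A minimal sketch of that shared pattern as a hypothetical free-standing helper (not code from this commit; the converter methods it calls are the parser's own):

def convert_via_float(converter, op, float_fn):
    """Hypothetical helper: dequantize -> float op -> requantize,
    mirroring the four new converters."""
    input_tensor = converter.get_input_tensors(op)[0]
    output_tensor = converter.get_output_tensors(op)[0]
    in_expr = converter.get_expr(input_tensor.tensor_idx)

    if input_tensor.qnn_params:        # quantized input: move to float domain
        in_expr = converter.dequantize(in_expr, input_tensor)
    out = float_fn(in_expr)            # e.g. lambda x: _op.clip(x, a_min=0, a_max=6)
    if output_tensor.qnn_params:       # requantize to the output's params
        out = converter.quantize(out, output_tensor)
    return out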
Showing 3 changed files with 222 additions and 1 deletion.
2 changes: 1 addition & 1 deletion 3rdparty/dmlc-core
Submodule dmlc-core updated 57 files
+0 −38 .github/workflows/githubci.yml
+1 −0 .gitignore
+82 −0 .travis.yml
+111 −122 CMakeLists.txt
+13 −201 LICENSE
+1 −1 README.md
+6 −19 appveyor.yml
+0 −13 cmake/Modules/FindASan.cmake
+0 −13 cmake/Modules/FindLSan.cmake
+0 −13 cmake/Modules/FindTSan.cmake
+0 −13 cmake/Modules/FindUBSan.cmake
+0 −63 cmake/Sanitizer.cmake
+1 −4 cmake/build_config.h.in
+1 −1 cmake/gtest_cmake.in
+263 −175 doc/Doxyfile
+1 −29 include/dmlc/base.h
+1 −4 include/dmlc/build_config_default.h
+0 −4 include/dmlc/concurrency.h
+18 −18 include/dmlc/concurrentqueue.h
+1 −1 include/dmlc/data.h
+2 −3 include/dmlc/json.h
+22 −53 include/dmlc/logging.h
+1 −1 include/dmlc/omp.h
+0 −10 include/dmlc/optional.h
+23 −106 include/dmlc/parameter.h
+1 −1 include/dmlc/registry.h
+3 −1 include/dmlc/thread_group.h
+2 −4 include/dmlc/thread_local.h
+46 −74 include/dmlc/threadediter.h
+2 −0 make/dmlc.mk
+2 −2 scripts/lint.py
+19 −12 scripts/packages.mk
+32 −0 scripts/setup_nvcc.sh
+0 −66 scripts/test_script.sh
+0 −0 scripts/travis/s390x/Dockerfile
+0 −0 scripts/travis/s390x/build_via_cmake.sh
+1 −1 scripts/travis/s390x/ci_build.sh
+0 −0 scripts/travis/s390x/entrypoint.sh
+3 −0 scripts/travis/travis_before_cache.sh
+9 −0 scripts/travis/travis_osx_install.sh
+57 −0 scripts/travis/travis_script.sh
+40 −0 scripts/travis/travis_setup_env.sh
+16 −0 src/build_config.cc
+3 −7 src/data/csv_parser.h
+1 −1 test/logging_test.cc
+0 −4 test/unittest/CMakeLists.txt
+1 −2 test/unittest/unittest_env.cc
+0 −11 test/unittest/unittest_logging_throw.cc
+0 −30 test/unittest/unittest_param.cc
+56 −80 test/unittest/unittest_parser.cc
+1 −0 test/unittest/unittest_thread_group.cc
+2 −2 test/unittest/unittest_threaditer.cc
+15 −19 test/unittest/unittest_threaditer_exc_handling.cc
+0 −4 tracker/dmlc_tracker/launcher.py
+0 −7 tracker/dmlc_tracker/ssh.py
+0 −13 tracker/dmlc_tracker/util.py
+2 −4 tracker/dmlc_tracker/yarn.py
112 changes: 112 additions & 0 deletions python/tvm/relay/frontend/tflite.py
@@ -94,10 +94,12 @@ def __init__(self, model, subgraph, exp_tab):
'HARD_SWISH': self.convert_hard_swish,
'L2_NORMALIZATION': self.convert_l2_normalization,
'L2_POOL_2D': self.convert_l2_pool2d,
'LEAKY_RELU': self.convert_leaky_relu,
'LESS_EQUAL': self.convert_less_equal,
'LESS': self.convert_less,
'LOCAL_RESPONSE_NORMALIZATION': self.convert_lrn,
'LOG': self.convert_log,
'LOG_SOFTMAX': self.convert_log_softmax,
'LOGICAL_AND': self.convert_logical_and,
'LOGICAL_NOT': self.convert_logical_not,
'LOGICAL_OR': self.convert_logical_or,
@@ -121,6 +123,8 @@ def __init__(self, model, subgraph, exp_tab):
'REDUCE_MIN': self.convert_reduce_min,
'REDUCE_PROD': self.convert_reduce_prod,
'RELU': self.convert_relu,
'RELU6': self.convert_relu6,
'RELU_N1_TO_1': self.convert_relu_n1_to_1,
'RESHAPE': self.convert_reshape,
'RESIZE_BILINEAR': self.convert_resize_bilinear,
'RESIZE_NEAREST_NEIGHBOR': self.convert_resize_nearest_neighbor,
@@ -685,6 +689,114 @@ def _hard_swish(data):

return out

def convert_relu6(self, op):
    """Convert TFLite ReLU6"""
    try:
        from tflite.Operator import Operator
    except ImportError:
        raise ImportError("The tflite package must be installed")

    assert isinstance(op, Operator)
    input_tensors = self.get_input_tensors(op)
    assert len(input_tensors) == 1, "input tensors length should be 1"
    input_tensor = input_tensors[0]
    in_expr = self.get_expr(input_tensor.tensor_idx)

    output_tensors = self.get_output_tensors(op)
    assert len(output_tensors) == 1, "output tensors length should be 1"
    output_tensor = output_tensors[0]

    if input_tensor.qnn_params:
        in_expr = self.dequantize(in_expr, input_tensor)
    out = _op.clip(in_expr, a_min=0, a_max=6)
    if output_tensor.qnn_params:
        out = self.quantize(out, output_tensor)

    return out

def convert_leaky_relu(self, op):
    """Convert TFLite LEAKY_RELU"""
    try:
        from tflite.Operator import Operator
        from tflite.BuiltinOptions import BuiltinOptions
        from tflite.LeakyReluOptions import LeakyReluOptions
    except ImportError:
        raise ImportError("The tflite package must be installed")

    assert isinstance(op, Operator)
    input_tensors = self.get_input_tensors(op)
    assert len(input_tensors) == 1, "input tensors length should be 1"
    input_tensor = input_tensors[0]
    in_expr = self.get_expr(input_tensor.tensor_idx)

    assert op.BuiltinOptionsType() == BuiltinOptions.LeakyReluOptions
    op_options = op.BuiltinOptions()
    leaky_relu_options = LeakyReluOptions()
    leaky_relu_options.Init(op_options.Bytes, op_options.Pos)
    alpha_tensor = leaky_relu_options.Alpha()

    output_tensors = self.get_output_tensors(op)
    assert len(output_tensors) == 1, "output tensors length should be 1"
    output_tensor = output_tensors[0]

    if input_tensor.qnn_params:
        in_expr = self.dequantize(in_expr, input_tensor)
    out = _op.nn.leaky_relu(in_expr, alpha_tensor)
    if output_tensor.qnn_params:
        out = self.quantize(out, output_tensor)

    return out

def convert_relu_n1_to_1(self, op):
    """Convert TFLite RELU_N1_TO_1"""
    try:
        from tflite.Operator import Operator
    except ImportError:
        raise ImportError("The tflite package must be installed")

    assert isinstance(op, Operator)
    input_tensors = self.get_input_tensors(op)
    assert len(input_tensors) == 1, "input tensors length should be 1"
    input_tensor = input_tensors[0]
    in_expr = self.get_expr(input_tensor.tensor_idx)

    output_tensors = self.get_output_tensors(op)
    assert len(output_tensors) == 1, "output tensors length should be 1"
    output_tensor = output_tensors[0]

    if input_tensor.qnn_params:
        in_expr = self.dequantize(in_expr, input_tensor)
    out = _op.clip(in_expr, a_min=-1, a_max=1)
    if output_tensor.qnn_params:
        out = self.quantize(out, output_tensor)

    return out

def convert_log_softmax(self, op):
    """Convert TFLite LOG_SOFTMAX"""
    try:
        from tflite.Operator import Operator
    except ImportError:
        raise ImportError("The tflite package must be installed")

    assert isinstance(op, Operator)
    input_tensors = self.get_input_tensors(op)
    assert len(input_tensors) == 1, "input tensors length should be 1"
    input_tensor = input_tensors[0]
    in_expr = self.get_expr(input_tensor.tensor_idx)

    output_tensors = self.get_output_tensors(op)
    assert len(output_tensors) == 1, "output tensors length should be 1"
    output_tensor = output_tensors[0]

    if input_tensor.qnn_params:
        in_expr = self.dequantize(in_expr, input_tensor)
    out = _op.nn.log_softmax(in_expr)
    if output_tensor.qnn_params:
        out = self.quantize(out, output_tensor)

    return out

def convert_concatenation(self, op):
    """Convert TFLite concatenation"""
    try:
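For context, the new operators are reached through the regular TFLite import path. A minimal usage sketch (the model file name, input name, and shape below are assumptions for illustration; relay.frontend.from_tflite and the tflite flatbuffer API are real):

from tflite.Model import Model
from tvm import relay

with open("model_with_relu6.tflite", "rb") as f:      # hypothetical model file
    tflite_model = Model.GetRootAsModel(bytearray(f.read()), 0)

mod, params = relay.frontend.from_tflite(
    tflite_model,
    shape_dict={"input": (1, 6)},                     # assumed input name and shape
    dtype_dict={"input": "float32"},
)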
109 changes: 109 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
@@ -1900,6 +1900,32 @@ def test_forward_softmax():
""" Softmax """
_test_softmax(np.arange(6.0, dtype=np.float32).reshape((1, 6)))

######################################################################
# Log_softmax
# -----------

def _test_log_softmax(data, quantized=False):
    """ One iteration of log_softmax """
    with tf.Graph().as_default():
        in_data = array_ops.placeholder(shape=data.shape, dtype='float32', name='in_0')

        if quantized:
            inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=-10, max=10, name="inq_0")
            input_range = {'inq_0': (-10, 10)}
            # TFLite's log_softmax supports only the default (last) axis
            out = nn_ops.log_softmax(inq_data)
            out = tf.quantization.fake_quant_with_min_max_args(out, min=-20, max=0, name="out")
            compare_tflite_with_tvm(data, 'inq_0:0', [inq_data], [out], quantized=True, input_range=input_range)
        else:
            out = nn_ops.log_softmax(in_data)
            compare_tflite_with_tvm(data, 'in_0:0', [in_data], [out])


def test_forward_log_softmax():
    """ Log_softmax """
    _test_log_softmax(np.random.uniform(-10, 10, size=(3, 6)).astype(np.float32))
    _test_log_softmax(np.random.uniform(0, 255, (3, 6)).astype(np.uint8), quantized=True)
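
A note on the ranges above: log_softmax only produces non-positive values, which is why the output is fake-quantized into [-20, 0]. A standalone NumPy check of that property (illustrative, not part of the test file):

import numpy as np

x = np.random.uniform(-10, 10, size=(3, 6)).astype(np.float32)
# log_softmax(x) = x - logsumexp(x); every entry is <= 0
log_sm = x - np.log(np.sum(np.exp(x), axis=-1, keepdims=True))
assert np.all(log_sm <= 0)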

#######################################################################
# Tanh
# ----
@@ -1930,6 +1956,85 @@ def test_forward_relu():
""" ReLU """
_test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6)))

#######################################################################
# ReLU6
# -----

def _test_relu6(data, quantized=False):
    """ One iteration of ReLU6 """
    with tf.Graph().as_default():
        in_data = array_ops.placeholder(shape=data.shape, dtype='float32', name='in_0')

        if quantized:
            inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=-10, max=10, name="inq_0")
            input_range = {'inq_0': (-10, 10)}
            out = nn_ops.relu6(inq_data)
            out = tf.quantization.fake_quant_with_min_max_args(out, min=0, max=6, name="out")
            compare_tflite_with_tvm(data, 'inq_0:0', [inq_data], [out], quantized=True, input_range=input_range)
        else:
            out = nn_ops.relu6(in_data)
            compare_tflite_with_tvm(data, 'in_0:0', [in_data], [out])

def test_forward_relu6():
    """ ReLU6 """
    _test_relu6(np.random.uniform(-10, 10, size=(3, 6)).astype(np.float32))
    _test_relu6(np.random.uniform(0, 255, (3, 6)).astype(np.uint8), quantized=True)

#######################################################################
# Leaky_ReLU
# ----------

def _test_leaky_relu(data, alpha, quantized=False):
    """ One iteration of Leaky_ReLU """
    with tf.Graph().as_default():
        in_data = array_ops.placeholder(shape=data.shape, dtype='float32', name='in_0')

        if quantized:
            inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=-3, max=2, name="inq_0")
            input_range = {'inq_0': (-3, 2)}
            out = nn_ops.leaky_relu(inq_data, alpha)
            out = tf.quantization.fake_quant_with_min_max_args(out, min=-3, max=2, name="out")
            compare_tflite_with_tvm(data, 'inq_0:0', [inq_data], [out], quantized=True, input_range=input_range)
        else:
            out = nn_ops.leaky_relu(in_data, alpha)
            compare_tflite_with_tvm(data, 'in_0:0', [in_data], [out])

def test_forward_leaky_relu():
    """ Leaky_ReLU """
    _test_leaky_relu(np.random.uniform(-5, 5, (1, 6)).astype(np.float32), alpha=0.2)
    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
        _test_leaky_relu(np.random.uniform(0, 255, (2, 3)).astype(np.uint8), alpha=0.3, quantized=True)
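
For reference, the semantics the converter must reproduce: leaky_relu passes positive values through unchanged and scales negatives by alpha. A NumPy sketch (illustrative, not part of the test file):

import numpy as np

x = np.random.uniform(-5, 5, (1, 6)).astype(np.float32)
alpha = 0.2
ref = np.where(x > 0, x, alpha * x)   # matches _op.nn.leaky_relu(x, alpha)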

#######################################################################
# ReLU_n1_to_1
# ------------

def _test_relu_n1_to_1(data, quantized=False):
    """ One iteration of ReLU_n1_to_1 """
    with tf.Graph().as_default():
        in_data = array_ops.placeholder(shape=data.shape, dtype='float32', name='in_0')

        if quantized:
            inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=-3, max=3, name="inq_0")
            input_range = {'inq_0': (-3, 3)}
            # There is no equivalent TF operation; the TFLite converter fuses
            # this max/min pattern into RELU_N1_TO_1
            out = math_ops.maximum(-1.0, math_ops.minimum(inq_data, 1.0))
            out = tf.quantization.fake_quant_with_min_max_args(out, min=-1, max=1, name="out")
            compare_tflite_with_tvm(data, 'inq_0:0', [inq_data], [out], quantized=True, input_range=input_range)
        else:
            out = math_ops.maximum(-1.0, math_ops.minimum(in_data, 1.0))
            compare_tflite_with_tvm(data, 'in_0:0', [in_data], [out])

def test_forward_relu_n1_to_1():
    """ ReLU_n1_to_1 """
    _test_relu_n1_to_1(np.random.uniform(-3, 3, (1, 6)).astype(np.float32))
    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
        _test_relu_n1_to_1(np.random.uniform(0, 255, (3, 6)).astype(np.uint8), quantized=True)
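
The max/min composition above is equivalent to clipping into [-1, 1], which is exactly what the parser emits (_op.clip(in_expr, a_min=-1, a_max=1)). A standalone NumPy equivalence check (illustrative only):

import numpy as np

x = np.random.uniform(-3, 3, (1, 6)).astype(np.float32)
pattern = np.maximum(-1.0, np.minimum(x, 1.0))       # what the TF graph builds
assert np.allclose(pattern, np.clip(x, -1.0, 1.0))   # what the parser emits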

#######################################################################
# PReLU
# -----

def _test_prelu(data, alpha):
    """ One iteration of PReLU """
    with tf.Graph().as_default():
@@ -2511,6 +2616,10 @@ def test_forward_mediapipe_hand_landmark():
test_forward_softmax()
test_forward_tanh()
test_forward_relu()
test_forward_relu6()
test_forward_leaky_relu()
test_forward_relu_n1_to_1()
test_forward_log_softmax()
test_forward_prelu()
test_forward_fully_connected()
test_forward_l2_normalization()
