From 70f7f5bc0992b34ab9aed67ad043ab9f5d1f345b Mon Sep 17 00:00:00 2001 From: mbaret <55580676+mbaret@users.noreply.github.com> Date: Sat, 25 Apr 2020 17:18:30 +0100 Subject: [PATCH] [RELAY] Move frontend utils (#5345) * [RELAY] Move frontend utils The util file currently under frontend is used from outside of frontend (in qnn/op/legalizations). This suggests that the file should be pushed up to a higher level. The benefit from this change is that importing qnn no longer also imports all the frontends. * Inline get_scalar_from_constant Change-Id: I1cc64e9ecb0eadb6ac0f7b62e6ea174644af4ad4 * Remove util.py from Relay Change-Id: If9cd7cf3fc0bd1861a3a9b5604f338e084d8db96 * Shorten functions Change-Id: Ieb537d82e6ee52421ff05a90cd00a03679ffebf2 * Line length Change-Id: I1d216b7e73a060c4f118f5da50ce58b18eba907f --- python/tvm/relay/frontend/tflite.py | 12 ++++++++- python/tvm/relay/frontend/util.py | 33 ------------------------ python/tvm/relay/qnn/op/legalizations.py | 11 +++++++- 3 files changed, 21 insertions(+), 35 deletions(-) delete mode 100644 python/tvm/relay/frontend/util.py diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 275d0ce11ef2..bba7d3b1789a 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -29,7 +29,6 @@ from .. import op as _op from .. import qnn as _qnn from ... import nd as _nd -from .util import get_scalar_from_constant from .common import ExprTable from .common import infer_shape as _infer_shape @@ -2281,6 +2280,17 @@ def get_expr(self, input_tensor_idx): def has_expr(self, input_tensor_idx): return self.exp_tab.has_expr(get_tensor_name(self.subgraph, input_tensor_idx)) + +def get_scalar_from_constant(expr): + """ Returns scalar value from Relay constant scalar. """ + assert isinstance(expr, _expr.Constant) and not expr.data.shape, \ + "Expr is not a constant scalar." + value = expr.data.asnumpy() + assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), \ + "value must be float32/int32" + return np.asscalar(value) + + def build_str_map(obj): """Build string map of TFLite enum int value diff --git a/python/tvm/relay/frontend/util.py b/python/tvm/relay/frontend/util.py deleted file mode 100644 index a7f89a30b996..000000000000 --- a/python/tvm/relay/frontend/util.py +++ /dev/null @@ -1,33 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=wildcard-import, redefined-builtin, invalid-name -""" Utility functions that are used across many directories. """ -from __future__ import absolute_import -import numpy as np -from .. import expr as _expr - -def get_scalar_from_constant(expr): - """ Returns scalar value from Relay constant scalar. """ - assert isinstance(expr, _expr.Constant) and not expr.data.shape, \ - "Expr is not a constant scalar." - value = expr.data.asnumpy() - if value.dtype == np.dtype(np.int32): - return int(value) - if value.dtype == np.dtype(np.float32): - return float(value) - assert False, "Constant expr must be float32/int32" - return None # To suppress pylint diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index b1c19092b4c7..c96a730ee6ed 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -20,8 +20,8 @@ import tvm from tvm import relay +import numpy as np from .. import op as reg -from ...frontend.util import get_scalar_from_constant ################################################# # Register the functions for different operators. @@ -54,6 +54,15 @@ def qnn_dense_legalize(attrs, inputs, types): # Helper functions. ################### +def get_scalar_from_constant(expr): + """ Returns scalar value from Relay constant scalar. """ + assert isinstance(expr, relay.Constant) and not expr.data.shape, \ + "Expr is not a constant scalar." + value = expr.data.asnumpy() + assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), \ + "value must be float32/int32" + return np.asscalar(value) + # Helper function for lowering in the abscence of fast Int8 arithmetic units. def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op): """ Converts QNN operators into a sequence of Relay operators that are friendly to HW that do