diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 275d0ce11ef2..bba7d3b1789a 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -29,7 +29,6 @@ from .. import op as _op from .. import qnn as _qnn from ... import nd as _nd -from .util import get_scalar_from_constant from .common import ExprTable from .common import infer_shape as _infer_shape @@ -2281,6 +2280,17 @@ def get_expr(self, input_tensor_idx): def has_expr(self, input_tensor_idx): return self.exp_tab.has_expr(get_tensor_name(self.subgraph, input_tensor_idx)) + +def get_scalar_from_constant(expr): + """ Returns scalar value from Relay constant scalar. """ + assert isinstance(expr, _expr.Constant) and not expr.data.shape, \ + "Expr is not a constant scalar." + value = expr.data.asnumpy() + assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), \ + "value must be float32/int32" + return np.asscalar(value) + + def build_str_map(obj): """Build string map of TFLite enum int value diff --git a/python/tvm/relay/frontend/util.py b/python/tvm/relay/frontend/util.py deleted file mode 100644 index a7f89a30b996..000000000000 --- a/python/tvm/relay/frontend/util.py +++ /dev/null @@ -1,33 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=wildcard-import, redefined-builtin, invalid-name -""" Utility functions that are used across many directories. """ -from __future__ import absolute_import -import numpy as np -from .. import expr as _expr - -def get_scalar_from_constant(expr): - """ Returns scalar value from Relay constant scalar. """ - assert isinstance(expr, _expr.Constant) and not expr.data.shape, \ - "Expr is not a constant scalar." - value = expr.data.asnumpy() - if value.dtype == np.dtype(np.int32): - return int(value) - if value.dtype == np.dtype(np.float32): - return float(value) - assert False, "Constant expr must be float32/int32" - return None # To suppress pylint diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index b1c19092b4c7..c96a730ee6ed 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -20,8 +20,8 @@ import tvm from tvm import relay +import numpy as np from .. import op as reg -from ...frontend.util import get_scalar_from_constant ################################################# # Register the functions for different operators. @@ -54,6 +54,15 @@ def qnn_dense_legalize(attrs, inputs, types): # Helper functions. ################### +def get_scalar_from_constant(expr): + """ Returns scalar value from Relay constant scalar. """ + assert isinstance(expr, relay.Constant) and not expr.data.shape, \ + "Expr is not a constant scalar." + value = expr.data.asnumpy() + assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), \ + "value must be float32/int32" + return np.asscalar(value) + # Helper function for lowering in the abscence of fast Int8 arithmetic units. def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op): """ Converts QNN operators into a sequence of Relay operators that are friendly to HW that do