diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 56cd652d739b..295b37a7e59b 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -84,10 +84,12 @@ def _dim_check(attrs): return _dim_check, "Only 2d kernel supported." def _get_param(params, input_node): + if isinstance(input_node, _expr.Constant): + return np.atleast_1d(input_node.data.asnumpy()) return params.pop(input_node.name_hint).asnumpy() def _get_num_param(params, input_node): - return _get_param(params, input_node)[0] + return _get_param(params, input_node).item() def _get_list_param(params, input_node): return _get_param(params, input_node).tolist() @@ -335,9 +337,9 @@ def _impl(inputs, attr, params): # input image is a 4-D tensor of shape [batch, image_height, image_width, depth] # boxes is a 2-D tensor of shape [num_boxes, 4], 4 is for [y1, x1, y2, x2] try: - boxes = params.pop(inputs[1].name_hint).asnumpy().tolist() - box_ind = params.pop(inputs[2].name_hint).asnumpy().tolist() - crop_size = params.pop(inputs[3].name_hint).asnumpy().tolist() + boxes = _get_list_param(params, inputs[1]) + box_ind = _get_list_param(params, inputs[2]) + crop_size = _get_list_param(params, inputs[3]) except (IndexError, KeyError): boxes = _infer_value(inputs[1], params).asnumpy().tolist() box_ind = _infer_value(inputs[2], params).asnumpy().tolist() @@ -505,7 +507,7 @@ def _impl(inputs, attr, params): def _tile(): def _impl(inputs, attr, params): - reps = params[inputs.pop().name_hint].asnumpy() + reps = _get_list_param(params, inputs.pop()) new_input = [] new_input.append(inputs.pop(0)) @@ -752,7 +754,7 @@ def _impl(inputs, attr, params): def _reduce(op): def _impl(inputs, attr, params): - axis = params.pop(inputs[1].name_hint).asnumpy() + axis = _get_list_param(params, inputs[1]) axis = tuple(axis) return AttrCvt( op_name=op, @@ -937,8 +939,8 @@ def _impl(inputs, attr, params): def _clip_by_value(): def _impl(inputs, attr, params): - a_min = params.pop(inputs[1].name_hint).asnumpy()[0] - a_max = params.pop(inputs[2].name_hint).asnumpy()[0] + a_min = _get_num_param(params, inputs[1]) + a_max = _get_num_param(params, inputs[2]) return _op.clip(inputs[0], a_min=a_min, a_max=a_max) return _impl @@ -965,10 +967,11 @@ def _impl(inputs, attr, params): def _range(): def _impl(inputs, attr, params): - start = params.pop(inputs[0].name_hint).asnumpy()[0] - limit = params.pop(inputs[1].name_hint).asnumpy()[0] \ - if hasattr(inputs[1], "name_hint") else params.pop('Rank').asnumpy()[0] - delta = params.pop(inputs[2].name_hint).asnumpy()[0] + start = _get_param(params, inputs[0])[0] + limit = _get_param(params, inputs[1])[0] \ + if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant) \ + else params.pop('Rank').asnumpy()[0] + delta = _get_param(params, inputs[2])[0] dtype = attr['dtype'].name if 'dtype' in attr else "int32" return AttrCvt( op_name="arange", @@ -1084,7 +1087,7 @@ def _impl(inputs, attr, params): def _topk(): def _impl(inputs, attr, params): - k = int(params.pop(inputs.pop(1).name_hint).asnumpy()) + k = int(_get_num_param(params, inputs.pop(1))) if k < 1: raise tvm.error.OpAttributeInvalid( 'Attribute k must be positive in operator TopKV2') @@ -1196,7 +1199,7 @@ def _impl(inputs, attr, params): def _prod(): def _impl(inputs, attr, params): - axis = params.pop(inputs[1].name_hint).asnumpy()[0] + axis = _get_num_param(params, inputs[1]) keepdims = attr['keep_dims'] return _op.prod(inputs[0], int(axis), keepdims=keepdims) return _impl @@ -2104,13 +2107,12 @@ def _parse_param(self, key, value, name, shape): if array_ndim == 0: new_array = np.empty([1], dtype=np_array.dtype) new_array[0] = np_array - self._params[name] = tvm.nd.array(new_array) + self._nodes[name] = [tvm.relay.const(new_array)] else: self._params[name] = tvm.nd.array(np_array) - - self._nodes[name] = [_expr.var(name, - shape=self._params[name].shape, - dtype=self._params[name].dtype)] + self._nodes[name] = [_expr.var(name, + shape=self._params[name].shape, + dtype=self._params[name].dtype)] else: if key not in ('dtype', '_output_shapes', '_class'): raise NotImplementedError \