Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TENSORFLOW] Convert scalar Const into tvm.relay.const #3885

Merged
merged 2 commits into from
Sep 4, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 21 additions & 19 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
@@ -84,10 +84,12 @@ def _dim_check(attrs):
return _dim_check, "Only 2d kernel supported."

def _get_param(params, input_node):
if isinstance(input_node, _expr.Constant):
return np.atleast_1d(input_node.data.asnumpy())
return params.pop(input_node.name_hint).asnumpy()

def _get_num_param(params, input_node):
return _get_param(params, input_node)[0]
return _get_param(params, input_node).item()

def _get_list_param(params, input_node):
return _get_param(params, input_node).tolist()
@@ -335,9 +337,9 @@ def _impl(inputs, attr, params):
# input image is a 4-D tensor of shape [batch, image_height, image_width, depth]
# boxes is a 2-D tensor of shape [num_boxes, 4], 4 is for [y1, x1, y2, x2]
try:
boxes = params.pop(inputs[1].name_hint).asnumpy().tolist()
box_ind = params.pop(inputs[2].name_hint).asnumpy().tolist()
crop_size = params.pop(inputs[3].name_hint).asnumpy().tolist()
boxes = _get_list_param(params, inputs[1])
box_ind = _get_list_param(params, inputs[2])
crop_size = _get_list_param(params, inputs[3])
except (IndexError, KeyError):
boxes = _infer_value(inputs[1], params).asnumpy().tolist()
box_ind = _infer_value(inputs[2], params).asnumpy().tolist()
@@ -505,7 +507,7 @@ def _impl(inputs, attr, params):

def _tile():
def _impl(inputs, attr, params):
reps = params[inputs.pop().name_hint].asnumpy()
reps = _get_list_param(params, inputs.pop())
new_input = []
new_input.append(inputs.pop(0))

@@ -752,7 +754,7 @@ def _impl(inputs, attr, params):

def _reduce(op):
def _impl(inputs, attr, params):
axis = params.pop(inputs[1].name_hint).asnumpy()
axis = _get_list_param(params, inputs[1])
axis = tuple(axis)
return AttrCvt(
op_name=op,
@@ -937,8 +939,8 @@ def _impl(inputs, attr, params):

def _clip_by_value():
def _impl(inputs, attr, params):
a_min = params.pop(inputs[1].name_hint).asnumpy()[0]
a_max = params.pop(inputs[2].name_hint).asnumpy()[0]
a_min = _get_num_param(params, inputs[1])
a_max = _get_num_param(params, inputs[2])
return _op.clip(inputs[0], a_min=a_min, a_max=a_max)
return _impl

@@ -965,10 +967,11 @@ def _impl(inputs, attr, params):

def _range():
def _impl(inputs, attr, params):
start = params.pop(inputs[0].name_hint).asnumpy()[0]
limit = params.pop(inputs[1].name_hint).asnumpy()[0] \
if hasattr(inputs[1], "name_hint") else params.pop('Rank').asnumpy()[0]
delta = params.pop(inputs[2].name_hint).asnumpy()[0]
start = _get_param(params, inputs[0])[0]
limit = _get_param(params, inputs[1])[0] \
if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant) \
else params.pop('Rank').asnumpy()[0]
delta = _get_param(params, inputs[2])[0]
dtype = attr['dtype'].name if 'dtype' in attr else "int32"
return AttrCvt(
op_name="arange",
@@ -1084,7 +1087,7 @@ def _impl(inputs, attr, params):

def _topk():
def _impl(inputs, attr, params):
k = int(params.pop(inputs.pop(1).name_hint).asnumpy())
k = int(_get_num_param(params, inputs.pop(1)))
if k < 1:
raise tvm.error.OpAttributeInvalid(
'Attribute k must be positive in operator TopKV2')
@@ -1196,7 +1199,7 @@ def _impl(inputs, attr, params):

def _prod():
def _impl(inputs, attr, params):
axis = params.pop(inputs[1].name_hint).asnumpy()[0]
axis = _get_num_param(params, inputs[1])
keepdims = attr['keep_dims']
return _op.prod(inputs[0], int(axis), keepdims=keepdims)
return _impl
@@ -2104,13 +2107,12 @@ def _parse_param(self, key, value, name, shape):
if array_ndim == 0:
new_array = np.empty([1], dtype=np_array.dtype)
new_array[0] = np_array
self._params[name] = tvm.nd.array(new_array)
self._nodes[name] = [tvm.relay.const(new_array)]
else:
self._params[name] = tvm.nd.array(np_array)

self._nodes[name] = [_expr.var(name,
shape=self._params[name].shape,
dtype=self._params[name].dtype)]
self._nodes[name] = [_expr.var(name,
shape=self._params[name].shape,
dtype=self._params[name].dtype)]
else:
if key not in ('dtype', '_output_shapes', '_class'):
raise NotImplementedError \