From e1cc4107f065a9aabab5ec78bfcada6a78d6d4f7 Mon Sep 17 00:00:00 2001 From: SRK Reddy Date: Wed, 20 Mar 2019 18:36:23 +0530 Subject: [PATCH] [FRONTEND][TENSORFLOW] bug fix for tensorflow official slim models. --- python/tvm/relay/frontend/tensorflow.py | 32 ++++++++++++------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 0efebe3cfec9..05b92f41e855 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -530,25 +530,23 @@ def _impl(inputs, attr, params): op_name="reshape", extras={'newshape':tuple(shape_arg.asnumpy())}, ignores=['Tshape'])(inputs, attr) - except KeyError: + except AttributeError: # Shape operator is already pruned, hence # try to infer shape by precompute prune if possible. - if all(in_node in params for in_node in inputs[1].list_input_names()): - func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1]) - with tvm.relay.build_config(opt_level=0): - graph, lib, params = tvm.relay.build(func, target="llvm", params=params) - ctx = tvm.context("llvm", 0) - from tvm.contrib import graph_runtime - m = graph_runtime.create(graph, lib, ctx) - m.set_input(**params) - m.run() - params_new = m.get_output(0) - inputs.pop(1) - return AttrCvt( - op_name="reshape", - extras={'newshape':tuple(params_new.asnumpy().flatten())}, - ignores=['Tshape'])(inputs, attr) - raise RuntimeError("Reshape with dynamic shape input not supported yet.") + func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1]) + with tvm.relay.build_config(opt_level=0): + graph, lib, params = tvm.relay.build(func, target="llvm", params=params) + ctx = tvm.context("llvm", 0) + from tvm.contrib import graph_runtime + m = graph_runtime.create(graph, lib, ctx) + m.set_input(**params) + m.run() + params_new = m.get_output(0) + inputs.pop(1) + return AttrCvt( + op_name="reshape", + extras={'newshape':tuple(params_new.asnumpy().astype('int32').flatten())}, + ignores=['Tshape'])(inputs, attr) return _impl def _bias_add():