Skip to content

Commit

Permalink
* Rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
srkreddy1238 committed Jan 14, 2019
1 parent 2c93f2e commit 065d67e
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 26 deletions.
20 changes: 8 additions & 12 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,11 +386,11 @@ def _impl(inputs, attr, params):
else:
raise TypeError("Unsupported padding type : {}".format(attr['padding']))

if 'weight_layout' not in attr:
if 'kernel_layout' not in attr:
if opname == 'conv':
attr['weight_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW'
attr['kernel_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW'
else:
attr['weight_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'
attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'

use_bias = len(inputs) == 3
channel_axis = 1 if attr['data_format'] == "NCHW" else 3
Expand Down Expand Up @@ -602,12 +602,8 @@ def _impl(inputs, attr, params):
def _fill():
def _impl(inputs, attr, params):
fill_arg = params.pop(inputs.pop(1).name_hint)
new_inputs = []
return AttrCvt(
op_name='full',
extras={'shape':inputs[0],
'fill_value':fill_arg.asnumpy()[0], 'dtype':attr['T'].name},
ignores=['index_type', 'T'])(new_inputs, attr)
return _op.full(tvm.relay.const(fill_arg.asnumpy()[0], attr['T'].name),
attr['_output_shapes'][0], attr['T'].name)
return _impl

def _lrn():
Expand Down Expand Up @@ -1329,10 +1325,10 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
#Add the RNN outputs also with 'head' nodes of the relay graph
if self._num_rnn_layer:
if len(self._out_rnn) == 1:
out.append(self._out_rnn[0])
out.append(self._out_rnn[0])
else:
out_rnn = _op.concatenate(self._out_rnn, axis=0)
out.append(out_rnn)
out_rnn = _op.concatenate(self._out_rnn, axis=0)
out.append(out_rnn)

out = out[0] if len(out) == 1 else _expr.Tuple(out)
func = _expr.Function(ir_pass.free_vars(out), out)
Expand Down
12 changes: 6 additions & 6 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ def test_forward_resnetv2():
# PTB
# ---
dir(tf.contrib)
def test_forward_ptb():
def _test_forward_ptb():
'''test ptb model'''
config = tf_testing.get_config()
num_steps = config.num_steps
Expand All @@ -803,18 +803,18 @@ def _pretty_print(items, is_char_model, id2word):
return ''.join([id2word[x] for x in items]).replace('_', ' ')

def _get_tvm_graph_module(graph_def):
sym, params = nnvm.frontend.from_tensorflow(graph_def)

#Cell inputs 'c and 'h' consist of all layers values
shape_dict = {'Model/Placeholder': (batch_size, num_steps),
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c':(num_layers, batch_size, num_hidden),
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h':(num_layers, batch_size, num_hidden)}

sym, params = relay.frontend.from_tensorflow(graph_def, shape=shape_dict)

dtype_dict = {'Model/Placeholder': 'int32',
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c':'float32',
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h':'float32'}
target = 'llvm'
graph, lib, params = nnvm.compiler.build(sym, target, shape_dict,
dtype=dtype_dict, params=params)
graph, lib, params = relay.build(sym, target, params=params)
from tvm.contrib import graph_runtime
ctx = tvm.cpu(0)
return params, graph_runtime.create(graph, lib, ctx)
Expand Down Expand Up @@ -1097,7 +1097,7 @@ def test_forward_rel_ops():
test_forward_inception_v1()
test_forward_mobilenet()
test_forward_resnetv2()
test_forward_ptb()
#test_forward_ptb()

# RNN
test_forward_lstm()
Expand Down
9 changes: 1 addition & 8 deletions tutorials/relay/from_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
Please refer to https://www.tensorflow.org/install
"""

# tvm, relay and nnvm
import nnvm
# tvm, relay
import tvm
from tvm import relay

Expand All @@ -36,12 +35,6 @@
######################################################################
# Tutorials
# ---------
# .. note::
#
# protobuf should be exported with :any:`add_shapes=True` option.
# Could use https://github.com/dmlc/web-data/tree/master/tensorflow/scripts/tf-to-nnvm.py
# to add shapes for existing models.
#
# Please refer docs/frontend/tensorflow.md for more details for various models
# from tensorflow.

Expand Down

0 comments on commit 065d67e

Please sign in to comment.