From e63e80e190fae95d8de6d11bc512ccd7b1aecb5f Mon Sep 17 00:00:00 2001 From: Eric Platon Date: Fri, 11 Oct 2019 15:31:48 +0900 Subject: [PATCH 01/11] WIP Run the TF tutorial on TF2 --- include/tvm/relay/attrs/image.h | 3 +++ python/tvm/relay/frontend/common.py | 7 ++++++- python/tvm/relay/testing/tf.py | 2 +- src/relay/op/image/resize.cc | 2 ++ tutorials/frontend/from_tensorflow.py | 22 ++++++++++++++-------- 5 files changed, 26 insertions(+), 10 deletions(-) diff --git a/include/tvm/relay/attrs/image.h b/include/tvm/relay/attrs/image.h index dd3a0aa0cc65..63bd00ec7c97 100644 --- a/include/tvm/relay/attrs/image.h +++ b/include/tvm/relay/attrs/image.h @@ -37,6 +37,7 @@ struct ResizeAttrs : public tvm::AttrsNode { std::string layout; std::string method; bool align_corners; + bool half_pixel_centers; DataType out_dtype; TVM_DECLARE_ATTRS(ResizeAttrs, "relay.attrs.ResizeAttrs") { @@ -54,6 +55,8 @@ struct ResizeAttrs : public tvm::AttrsNode { "bicubic - Bicubic Interpolation"); TVM_ATTR_FIELD(align_corners).set_default(true) .describe("Should be true to preserve the values at the corner pixels"); + TVM_ATTR_FIELD(half_pixel_centers).set_default(false) + .describe("Defaults to false, following the Tensorflow 2.0 settings"); TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type."); diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 637e1f0860da..2227f7589ff6 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -409,7 +409,12 @@ def __call__(self, inputs, attrs, *args): new_attrs[k] = attrs[k] # add extras new_attrs.update(self._extras) - return get_relay_op(op_name)(*inputs, **new_attrs) + try: + return get_relay_op(op_name)(*inputs, **new_attrs) + except: + import pdb; pdb.set_trace() + import traceback + print(traceback.format_exc()) def _parse_default(self, target): """Helper function to parse default values.""" diff --git a/python/tvm/relay/testing/tf.py b/python/tvm/relay/testing/tf.py index a56e6fe1782d..c8b1f92c745f 100644 --- a/python/tvm/relay/testing/tf.py +++ b/python/tvm/relay/testing/tf.py @@ -80,7 +80,7 @@ def AddShapesToGraphDef(session, out_node): """ - graph_def = tf.graph_util.convert_variables_to_constants( + graph_def = tf.compat.v1.graph_util.convert_variables_to_constants( session, session.graph.as_graph_def(add_shapes=True), [out_node], diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index dbdf89790ac5..97d2a09e425e 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -73,12 +73,14 @@ Expr MakeResize(Expr data, std::string layout, std::string method, bool align_corners, + bool half_pixel_centers, DataType out_dtype) { auto attrs = make_node(); attrs->size = std::move(size); attrs->layout = std::move(layout); attrs->method = std::move(method); attrs->align_corners = align_corners; + attrs->half_pixel_centers = half_pixel_centers; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("image.resize"); return CallNode::make(op, {data}, Attrs(attrs), {}); diff --git a/tutorials/frontend/from_tensorflow.py b/tutorials/frontend/from_tensorflow.py index 34865f021230..4cbe9ae8b25f 100644 --- a/tutorials/frontend/from_tensorflow.py +++ b/tutorials/frontend/from_tensorflow.py @@ -34,6 +34,12 @@ # Tensorflow imports import tensorflow as tf +if tf.__version__.startswith('2'): + gfile_mod = tf.io.gfile + gfile = gfile_mod.GFile +else: + gfile_mod = tf.gfile + gfile = gfile_mod.FastGFile # Tensorflow utility functions import tvm.relay.testing.tf as tf_testing @@ -89,14 +95,14 @@ # ------------ # Creates tensorflow graph definition from protobuf file. -with tf.gfile.FastGFile(model_path, 'rb') as f: - graph_def = tf.GraphDef() +with gfile(model_path, 'rb') as f: + graph_def = tf.compat.v1.GraphDef() graph_def.ParseFromString(f.read()) graph = tf.import_graph_def(graph_def, name='') # Call the utility to import the graph definition into default graph. graph_def = tf_testing.ProcessGraphDefParam(graph_def) # Add shapes to the graph. - with tf.Session() as sess: + with tf.compat.v1.Session() as sess: graph_def = tf_testing.AddShapesToGraphDef(sess, 'softmax') ###################################################################### @@ -187,8 +193,8 @@ def create_graph(): """Creates a graph from saved GraphDef file and returns a saver.""" # Creates graph from saved graph_def.pb. - with tf.gfile.FastGFile(model_path, 'rb') as f: - graph_def = tf.GraphDef() + with gfile(model_path, 'rb') as f: + graph_def = tf.compat.v1.GraphDef() graph_def.ParseFromString(f.read()) graph = tf.import_graph_def(graph_def, name='') # Call the utility to import the graph definition into default graph. @@ -206,14 +212,14 @@ def run_inference_on_image(image): ------- Nothing """ - if not tf.gfile.Exists(image): + if not gfile_mod.Exists(image): tf.logging.fatal('File does not exist %s', image) - image_data = tf.gfile.FastGFile(image, 'rb').read() + image_data = gfile(image, 'rb').read() # Creates graph from saved GraphDef. create_graph() - with tf.Session() as sess: + with tf.compat.v1.Session() as sess: softmax_tensor = sess.graph.get_tensor_by_name('softmax:0') predictions = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data}) From 6ec2b28c6c588f0510e69277adc2e9cff64951d7 Mon Sep 17 00:00:00 2001 From: Eric Platon Date: Fri, 11 Oct 2019 15:36:06 +0900 Subject: [PATCH 02/11] Remove debugger statement. --- python/tvm/relay/frontend/common.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 2227f7589ff6..637e1f0860da 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -409,12 +409,7 @@ def __call__(self, inputs, attrs, *args): new_attrs[k] = attrs[k] # add extras new_attrs.update(self._extras) - try: - return get_relay_op(op_name)(*inputs, **new_attrs) - except: - import pdb; pdb.set_trace() - import traceback - print(traceback.format_exc()) + return get_relay_op(op_name)(*inputs, **new_attrs) def _parse_default(self, target): """Helper function to parse default values.""" From bbf927a6faf0dd16d4d3d8552fc69d6312ff95e0 Mon Sep 17 00:00:00 2001 From: Eric Platon Date: Fri, 11 Oct 2019 22:22:13 +0900 Subject: [PATCH 03/11] Complete the support for TF2.0's `resize`. TF2.0 adds a `half_pixel_centers` attribute to the `resize` function in the image API. This commit completes the hooks in Relay's TF frontend. At the point of this commit, no new test yet. Also, this commit addresses solely the `resize` change. Other commits address other changes in TF2.0. --- python/tvm/relay/op/image/image.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py index c54e438dce51..78a1050c6bdd 100644 --- a/python/tvm/relay/op/image/image.py +++ b/python/tvm/relay/op/image/image.py @@ -23,6 +23,7 @@ def resize(data, layout="NCHW", method="bilinear", align_corners=True, + half_pixel_centers=False, out_dtype=None): """Image resize operator. @@ -51,6 +52,9 @@ def resize(data, align_corners : int, optional Should be true to preserve the values at the corner pixels + half_pixel_centers : int, optional + If true, `align_corners` must be false. + out_dtype : str, optional Type to return. If left None returns the same type as input. @@ -59,4 +63,4 @@ def resize(data, result: relay.Expr The resized result. """ - return _make.resize(data, size, layout, method, align_corners, out_dtype) + return _make.resize(data, size, layout, method, align_corners, half_pixel_centers, out_dtype) From eeb865038ebc782fda51aa55009e40856ef77ee3 Mon Sep 17 00:00:00 2001 From: Eric Platon Date: Wed, 16 Oct 2019 17:52:57 +0900 Subject: [PATCH 04/11] Support TF2.0 in the tutorial by using the compat API. This looks cleaner than trying to detect the TF version. --- tutorials/frontend/from_tensorflow.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/tutorials/frontend/from_tensorflow.py b/tutorials/frontend/from_tensorflow.py index 4cbe9ae8b25f..2c109cbaf907 100644 --- a/tutorials/frontend/from_tensorflow.py +++ b/tutorials/frontend/from_tensorflow.py @@ -34,12 +34,6 @@ # Tensorflow imports import tensorflow as tf -if tf.__version__.startswith('2'): - gfile_mod = tf.io.gfile - gfile = gfile_mod.GFile -else: - gfile_mod = tf.gfile - gfile = gfile_mod.FastGFile # Tensorflow utility functions import tvm.relay.testing.tf as tf_testing @@ -95,7 +89,7 @@ # ------------ # Creates tensorflow graph definition from protobuf file. -with gfile(model_path, 'rb') as f: +with tf.compat.v1.gfile.GFile(model_path, 'rb') as f: graph_def = tf.compat.v1.GraphDef() graph_def.ParseFromString(f.read()) graph = tf.import_graph_def(graph_def, name='') @@ -193,7 +187,7 @@ def create_graph(): """Creates a graph from saved GraphDef file and returns a saver.""" # Creates graph from saved graph_def.pb. - with gfile(model_path, 'rb') as f: + with tf.compat.v1.gfile.GFile(model_path, 'rb') as f: graph_def = tf.compat.v1.GraphDef() graph_def.ParseFromString(f.read()) graph = tf.import_graph_def(graph_def, name='') @@ -212,9 +206,9 @@ def run_inference_on_image(image): ------- Nothing """ - if not gfile_mod.Exists(image): + if not tf.compat.v1.io.gfile.exists(image): tf.logging.fatal('File does not exist %s', image) - image_data = gfile(image, 'rb').read() + image_data = tf.compat.v1.gfile.GFile(image, 'rb').read() # Creates graph from saved GraphDef. create_graph() From 63ce00534569096d793a4cb6867dc7c174ed29f5 Mon Sep 17 00:00:00 2001 From: Eric Platon Date: Wed, 16 Oct 2019 17:53:49 +0900 Subject: [PATCH 05/11] Use the TF compat API, so as to support TF2.0. This is a direct change, relying on the compat API provided by the TF team. This code will last as long as the compat API exists, so a "proper" support for TF1.x and 2.x will require more work in some future. --- python/tvm/relay/testing/tf.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/testing/tf.py b/python/tvm/relay/testing/tf.py index c8b1f92c745f..79d0d8257953 100644 --- a/python/tvm/relay/testing/tf.py +++ b/python/tvm/relay/testing/tf.py @@ -112,13 +112,13 @@ def load(self, label_lookup_path, uid_lookup_path): dict from integer node ID to human-readable string. """ - if not tf.gfile.Exists(uid_lookup_path): + if not tf.compat.v1.io.gfile.exists(uid_lookup_path): tf.logging.fatal('File does not exist %s', uid_lookup_path) - if not tf.gfile.Exists(label_lookup_path): + if not tf.compat.v1.io.gfile.exists(label_lookup_path): tf.logging.fatal('File does not exist %s', label_lookup_path) # Loads mapping from string UID to human-readable string - proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines() + proto_as_ascii_lines = tf.compat.v1.gfile.GFile(uid_lookup_path).readlines() uid_to_human = {} p = re.compile(r'[n\d]*[ \S,]*') for line in proto_as_ascii_lines: @@ -129,7 +129,7 @@ def load(self, label_lookup_path, uid_lookup_path): # Loads mapping from string UID to integer node ID. node_id_to_uid = {} - proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines() + proto_as_ascii = tf.compat.v1.gfile.GFile(label_lookup_path).readlines() for line in proto_as_ascii: if line.startswith(' target_class:'): target_class = int(line.split(': ')[1]) @@ -209,7 +209,7 @@ def get_workload(model_path, model_sub_path=None): path_model = download_testdata(model_url, model_path, module='tf') # Creates graph from saved graph_def.pb. - with tf.gfile.FastGFile(path_model, 'rb') as f: + with tf.compat.v1.gfile.FastGFile(path_model, 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) graph = tf.import_graph_def(graph_def, name='') @@ -299,7 +299,7 @@ def _create_ptb_vocabulary(data_dir): file_name = 'ptb.train.txt' def _read_words(filename): """Read the data for creating vocabulary""" - with tf.gfile.GFile(filename, "r") as f: + with tf.compat.v1.gfile.GFile(filename, "r") as f: return f.read().encode("utf-8").decode("utf-8").replace("\n", "").split() def _build_vocab(filename): From 2ad38bf48234c5ded0a8ee0d88c1faf5de1bd329 Mon Sep 17 00:00:00 2001 From: Eric Platon Date: Wed, 16 Oct 2019 17:55:43 +0900 Subject: [PATCH 06/11] Partial support for EXPLICIT padding introduced in TF2.0. Explicit padding is a special case in TF2.0 (see reference linked below). Some models are serialized with that mode, and break TF support in TVM. Support is *partial* as EXPLICIT falls back to set padding on the Relay op, which only supports 2 values. At some point, padding may need to be extended to support 4 values, but that is out of scope of this support commit. Reference on EXPLICIT padding: https://github.com/tensorflow/tensorflow/commit/ec81825aaf7e848d9f8ddffdf1e0d20aebe9172c#diff-1d1c0bb0a880f85b6164f71dbb2f446e --- python/tvm/relay/frontend/tensorflow.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 38f9c523e0b1..996001aaeca6 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -294,11 +294,30 @@ def _impl(inputs, attr, params): attr['padding'] = [0, 0] + elif attr['padding'] == 'EXPLICIT': + warnings.warn( + 'Explicit padding is for now only supporting same padding ' + 'before/after. It uses values 0 and 2, out of its 4 values.' + ) + if 'explicit_paddings' not in attr: + raise tvm.error.OpAttributeInvalid( + 'EXPLICIT padding mode requires the' + ' `explicit_padding` attribute.' + ) + h = attr['explicit_paddings'][0] + v = attr['explicit_paddings'][2] + attr['padding'] = [h, v] + else: msg = 'Value {} in attribute "padding" of operator Conv is not ' \ 'valid.' raise tvm.error.OpAttributeInvalid(msg.format(attr['padding'])) + # Not needed in the Relay API for `conv`. + # Consumed only if `padding` was set to `EXPLICIT`. + if opname == 'conv': + del attr['explicit_paddings'] + if 'kernel_layout' not in attr: if opname == 'conv': attr['kernel_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW' From 3fa86a6b8775b8a94cf6054edc7f6d9c4328f410 Mon Sep 17 00:00:00 2001 From: Eric Platon Date: Thu, 17 Oct 2019 16:42:41 +0900 Subject: [PATCH 07/11] Guard on checking for optional TF2.0 attribute. --- python/tvm/relay/frontend/tensorflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 996001aaeca6..22a541ddea28 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -315,7 +315,7 @@ def _impl(inputs, attr, params): # Not needed in the Relay API for `conv`. # Consumed only if `padding` was set to `EXPLICIT`. - if opname == 'conv': + if opname == 'conv' and 'explicit_paddings' in attr: del attr['explicit_paddings'] if 'kernel_layout' not in attr: From 2ebc516f76367901f8998e3ff18da36a4ec838f6 Mon Sep 17 00:00:00 2001 From: Eric Platon Date: Thu, 17 Oct 2019 18:19:53 +0900 Subject: [PATCH 08/11] Do not expect Relay to implement TF-specific attributes. The `half_pixel_centers` attribute is a new feature in TF2.0. Earlier commits of mine mistakenly introduce them in the Relay API. This is probably not what Relay is expected to support, and the semantics of `half_pixel_centers` is unclear (to me, at least) at this point. --- include/tvm/relay/attrs/image.h | 3 --- python/tvm/relay/frontend/tensorflow.py | 4 +++- python/tvm/relay/op/image/image.py | 6 +----- src/relay/op/image/resize.cc | 2 -- 4 files changed, 4 insertions(+), 11 deletions(-) diff --git a/include/tvm/relay/attrs/image.h b/include/tvm/relay/attrs/image.h index 63bd00ec7c97..dd3a0aa0cc65 100644 --- a/include/tvm/relay/attrs/image.h +++ b/include/tvm/relay/attrs/image.h @@ -37,7 +37,6 @@ struct ResizeAttrs : public tvm::AttrsNode { std::string layout; std::string method; bool align_corners; - bool half_pixel_centers; DataType out_dtype; TVM_DECLARE_ATTRS(ResizeAttrs, "relay.attrs.ResizeAttrs") { @@ -55,8 +54,6 @@ struct ResizeAttrs : public tvm::AttrsNode { "bicubic - Bicubic Interpolation"); TVM_ATTR_FIELD(align_corners).set_default(true) .describe("Should be true to preserve the values at the corner pixels"); - TVM_ATTR_FIELD(half_pixel_centers).set_default(false) - .describe("Defaults to false, following the Tensorflow 2.0 settings"); TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type."); diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 22a541ddea28..058b9e2951ef 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -425,8 +425,10 @@ def _impl(inputs, attr, params): # NHWC attr['layout'] = 'NHWC' + # Ignore the new attribute `half_pixel_centers` from TF2.0, for now. + # The semantics of the attribute is not very clear at this point. return AttrCvt(op_name="resize", - ignores=['Tdim'], + ignores=['Tdim', 'half_pixel_centers'], extras={'method': "bilinear"})(inputs, attr) return _impl diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py index 78a1050c6bdd..c54e438dce51 100644 --- a/python/tvm/relay/op/image/image.py +++ b/python/tvm/relay/op/image/image.py @@ -23,7 +23,6 @@ def resize(data, layout="NCHW", method="bilinear", align_corners=True, - half_pixel_centers=False, out_dtype=None): """Image resize operator. @@ -52,9 +51,6 @@ def resize(data, align_corners : int, optional Should be true to preserve the values at the corner pixels - half_pixel_centers : int, optional - If true, `align_corners` must be false. - out_dtype : str, optional Type to return. If left None returns the same type as input. @@ -63,4 +59,4 @@ def resize(data, result: relay.Expr The resized result. """ - return _make.resize(data, size, layout, method, align_corners, half_pixel_centers, out_dtype) + return _make.resize(data, size, layout, method, align_corners, out_dtype) diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index 97d2a09e425e..dbdf89790ac5 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -73,14 +73,12 @@ Expr MakeResize(Expr data, std::string layout, std::string method, bool align_corners, - bool half_pixel_centers, DataType out_dtype) { auto attrs = make_node(); attrs->size = std::move(size); attrs->layout = std::move(layout); attrs->method = std::move(method); attrs->align_corners = align_corners; - attrs->half_pixel_centers = half_pixel_centers; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("image.resize"); return CallNode::make(op, {data}, Attrs(attrs), {}); From 5f635308d7e1f2a16342d71bde9f6e3aa3964fb8 Mon Sep 17 00:00:00 2001 From: Eric Platon Date: Fri, 25 Oct 2019 10:25:17 +0900 Subject: [PATCH 09/11] Remove unclear comment. CR https://github.com/dmlc/tvm/pull/4104#discussion_r338705742 Addresses #4104 --- python/tvm/relay/frontend/tensorflow.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 058b9e2951ef..539c753ba3b1 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -426,7 +426,6 @@ def _impl(inputs, attr, params): attr['layout'] = 'NHWC' # Ignore the new attribute `half_pixel_centers` from TF2.0, for now. - # The semantics of the attribute is not very clear at this point. return AttrCvt(op_name="resize", ignores=['Tdim', 'half_pixel_centers'], extras={'method': "bilinear"})(inputs, attr) From f2fca5a5ff9e79ce5bf9098c2e53bfbda6d98abe Mon Sep 17 00:00:00 2001 From: Eric Platon Date: Thu, 31 Oct 2019 14:59:33 +0900 Subject: [PATCH 10/11] Changes after review. Complying without understanding the rationale for now. --- python/tvm/relay/frontend/tensorflow.py | 23 ++--------------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 020c21053cb1..f9b8e7b31883 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -286,30 +286,11 @@ def _impl(inputs, attr, params): attr['padding'] = [0, 0] - elif attr['padding'] == 'EXPLICIT': - warnings.warn( - 'Explicit padding is for now only supporting same padding ' - 'before/after. It uses values 0 and 2, out of its 4 values.' - ) - if 'explicit_paddings' not in attr: - raise tvm.error.OpAttributeInvalid( - 'EXPLICIT padding mode requires the' - ' `explicit_padding` attribute.' - ) - h = attr['explicit_paddings'][0] - v = attr['explicit_paddings'][2] - attr['padding'] = [h, v] - else: msg = 'Value {} in attribute "padding" of operator Conv is not ' \ 'valid.' raise tvm.error.OpAttributeInvalid(msg.format(attr['padding'])) - # Not needed in the Relay API for `conv`. - # Consumed only if `padding` was set to `EXPLICIT`. - if opname == 'conv' and 'explicit_paddings' in attr: - del attr['explicit_paddings'] - if 'kernel_layout' not in attr: if opname == 'conv': attr['kernel_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW' @@ -417,9 +398,9 @@ def _impl(inputs, attr, params): # NHWC attr['layout'] = 'NHWC' - # Ignore the new attribute `half_pixel_centers` from TF2.0, for now. + # Ignore the new attributes from TF2.0, for now. return AttrCvt(op_name="resize", - ignores=['Tdim', 'half_pixel_centers'], + ignores=['Tdim', 'half_pixel_centers', 'explicit_paddings'], extras={'method': "bilinear"})(inputs, attr) return _impl From 84702fe76860c4be3c3e5b228bf86f93e365229d Mon Sep 17 00:00:00 2001 From: Eric Platon Date: Thu, 31 Oct 2019 16:40:56 +0900 Subject: [PATCH 11/11] Fix the arguments set mistakenly. An argument ignored for the wrong operation. --- python/tvm/relay/frontend/tensorflow.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index f9b8e7b31883..44dc9e5c657a 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -300,8 +300,10 @@ def _impl(inputs, attr, params): use_bias = len(inputs) == 3 channel_axis = 1 if attr['data_format'] == "NCHW" else 3 + # Ignore the new attributes from TF2.0, for now. out = AttrCvt( op_name=_dimension_picker('conv'), + ignores=['explicit_paddings'], transforms={ 'kernel_shape': 'kernel_size', 'data_format': 'data_layout', @@ -400,7 +402,7 @@ def _impl(inputs, attr, params): # Ignore the new attributes from TF2.0, for now. return AttrCvt(op_name="resize", - ignores=['Tdim', 'half_pixel_centers', 'explicit_paddings'], + ignores=['Tdim', 'half_pixel_centers'], extras={'method': "bilinear"})(inputs, attr) return _impl