Skip to content

Commit

Permalink
Switching to tf.io. for file access and loading savedmodels in eager …
Browse files Browse the repository at this point in the history
…model. (#3126)

* Switching to tf.io. for file access and loading savedmodels in eager
mode.
  • Loading branch information
davidzats-eng authored Apr 21, 2020
1 parent e8c2253 commit 0337b4e
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 23 deletions.
9 changes: 6 additions & 3 deletions tfjs-converter/python/tensorflowjs/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ py_library(
name = "tensorflowjs",
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":quantization",
":version",
"//tensorflowjs/converters:converter"
"//tensorflowjs/converters:converter",
],
visibility = ["//visibility:public"],
)

py_library(
Expand Down Expand Up @@ -72,6 +72,7 @@ py_library(
":quantization",
":read_weights",
"//tensorflowjs:expect_numpy_installed",
"//tensorflowjs:expect_tensorflow_installed",
],
)

Expand Down Expand Up @@ -115,6 +116,7 @@ py_test(
deps = [
":write_weights",
"//tensorflowjs:expect_numpy_installed",
"//tensorflowjs:expect_tensorflow_installed",
],
)

Expand All @@ -126,14 +128,15 @@ py_test(
":read_weights",
":write_weights",
"//tensorflowjs:expect_numpy_installed",
"//tensorflowjs:expect_tensorflow_installed",
],
)

py_test(
name = "resource_loader_test",
srcs = ["resource_loader_test.py"],
srcs_version = "PY2AND3",
data = [":op_list_jsons"],
srcs_version = "PY2AND3",
deps = [
":resource_loader",
],
Expand Down
28 changes: 22 additions & 6 deletions tfjs-converter/python/tensorflowjs/converters/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":tf_saved_model_conversion_v2",
"//tensorflowjs:expect_tensorflow_installed",
"//tensorflowjs:expect_tensorflow_hub_installed",
"//tensorflowjs:expect_tensorflow_installed",
],
)

Expand All @@ -154,11 +154,11 @@ py_library(
deps = [
":common",
":fold_batch_norms",
":fuse_prelu",
":fuse_depthwise_conv2d",
":fuse_prelu",
"//tensorflowjs:expect_numpy_installed",
"//tensorflowjs:expect_tensorflow_installed",
"//tensorflowjs:expect_tensorflow_hub_installed",
"//tensorflowjs:expect_tensorflow_installed",
"//tensorflowjs:resource_loader",
"//tensorflowjs:version",
"//tensorflowjs:write_weights",
Expand Down Expand Up @@ -187,10 +187,27 @@ py_binary(
],
)

py_library(
name = "converter_lib",
srcs = ["converter.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":common",
":keras_h5_conversion",
":keras_tfjs_loader",
":tf_saved_model_conversion_v2",
"//third_party/py/h5py",
"//third_party/py/tensorflow",
"//third_party/py/tensorflowjs:version",
],
)

py_binary(
name = "converter",
srcs = ["converter.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":common",
":keras_h5_conversion",
Expand All @@ -200,17 +217,16 @@ py_binary(
"//tensorflowjs:expect_tensorflow_installed",
"//tensorflowjs:version",
],
visibility = ["//visibility:public"],
)

py_binary(
name = "generate_test_model",
srcs = ["generate_test_model.py"],
testonly = True,
srcs = ["generate_test_model.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflowjs:expect_tensorflow_installed",
]
],
)

py_test(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import os
import shutil
import tempfile
import unittest

import numpy as np
import tensorflow.compat.v2 as tf
Expand Down Expand Up @@ -446,4 +445,4 @@ def testLoadFunctionalTfKerasModel(self):


if __name__ == '__main__':
unittest.main()
tf.test.main()
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import device_properties_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import convert_to_constants
from tensorflow.python.grappler import cluster as gcluster
from tensorflow.python.grappler import tf_optimizer
Expand Down Expand Up @@ -285,7 +286,7 @@ def write_artifacts(topology,
assert isinstance(weights_manifest, list)
model_json[common.ARTIFACT_WEIGHTS_MANIFEST_KEY] = weights_manifest

with open(output_graph, 'wt') as f:
with tf.io.gfile.GFile(output_graph, 'w') as f:
json.dump(model_json, f)

def _remove_unused_control_flow_inputs(input_graph_def):
Expand Down Expand Up @@ -426,14 +427,17 @@ def convert_tf_saved_model(saved_model_dir,
if signature_def is None:
signature_def = 'serving_default'

if not os.path.exists(output_dir):
os.makedirs(output_dir)
if not tf.io.gfile.exists(output_dir):
tf.io.gfile.makedirs(output_dir)
output_graph = os.path.join(
output_dir, common.ARTIFACT_MODEL_JSON_FILE_NAME)

if saved_model_tags:
saved_model_tags = saved_model_tags.split(',')
model = load(saved_model_dir, saved_model_tags)
model = None
# Ensure any graphs created in eager mode are able to run.
with context.eager_mode():
model = load(saved_model_dir, saved_model_tags)

_check_signature_in_model(model, signature_def)

Expand Down
6 changes: 3 additions & 3 deletions tfjs-converter/python/tensorflowjs/read_weights_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@

import os
import shutil
import unittest

import tempfile

import numpy as np
import tensorflow as tf

from tensorflowjs import read_weights
from tensorflowjs import write_weights


class ReadWeightsTest(unittest.TestCase):
class ReadWeightsTest(tf.test.TestCase):
def setUp(self):
self._tmp_dir = tempfile.mkdtemp()
super(ReadWeightsTest, self).setUp()
Expand Down Expand Up @@ -342,4 +342,4 @@ def testReadQuantizedWeights(self):


if __name__ == '__main__':
unittest.main()
tf.test.main()
5 changes: 3 additions & 2 deletions tfjs-converter/python/tensorflowjs/write_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import os

import numpy as np
import tensorflow as tf

from tensorflowjs import quantization
from tensorflowjs import read_weights
Expand Down Expand Up @@ -133,7 +134,7 @@ def write_weights(

if write_manifest:
manifest_path = os.path.join(write_dir, 'weights_manifest.json')
with open(manifest_path, 'wb') as f:
with tf.io.gfile.GFile(manifest_path, 'wb') as f:
f.write(json.dumps(manifest).encode())

return manifest
Expand Down Expand Up @@ -291,7 +292,7 @@ def _shard_group_bytes_to_disk(
filepath = os.path.join(write_dir, filename)

# Write the shard to disk.
with open(filepath, 'wb') as f:
with tf.io.gfile.GFile(filepath, 'wb') as f:
f.write(shard)

return filenames
Expand Down
6 changes: 3 additions & 3 deletions tfjs-converter/python/tensorflowjs/write_weights_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@

import os
import shutil
import unittest

import numpy as np
import tensorflow as tf

from tensorflowjs import write_weights

TMP_DIR = '/tmp/write_weights_test/'


class TestWriteWeights(unittest.TestCase):
class TestWriteWeights(tf.test.TestCase):
def setUp(self):
if not os.path.isdir(TMP_DIR):
os.makedirs(TMP_DIR)
Expand Down Expand Up @@ -751,4 +751,4 @@ def test_quantize_group(self):


if __name__ == '__main__':
unittest.main()
tf.test.main()

0 comments on commit 0337b4e

Please sign in to comment.