Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: combine graph by prefixing with unique name #4334

Merged
merged 3 commits into from
Nov 19, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion tensorboard/plugins/graph/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ py_library(
srcs = ["graph_util.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:private"],
deps = [
"//tensorboard/compat/proto:protos_all_py_pb2",
],
)

py_test(
Expand All @@ -136,7 +139,6 @@ py_test(
"//tensorboard:expect_tensorflow_installed",
"//tensorboard/compat/proto:protos_all_py_pb2",
"@com_google_protobuf//:protobuf_python",
"@org_pythonhosted_six",
],
)

Expand Down
212 changes: 85 additions & 127 deletions tensorboard/plugins/graph/graph_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,152 +14,110 @@
# ==============================================================================
"""Utilities for graph plugin."""

from tensorboard.compat.proto import graph_pb2

class _ProtoListDuplicateKeyError(Exception):
pass

def _prefixed_op_name(prefix, op_name):
return "%s/%s" % (prefix, op_name)

class _SameKeyDiffContentError(Exception):
pass

def _prefixed_func_name(prefix, func_name):
# TODO(stephanwlee): add business logic to strip "__inference_".
return "%s_%s" % (prefix, func_name)

def _safe_copy_proto_list_values(dst_proto_list, src_proto_list, get_key):
"""Safely merge values from `src_proto_list` into `dst_proto_list`.

Each element in `dst_proto_list` must be mapped by `get_key` to a key
value that is unique within that list; likewise for `src_proto_list`.
If an element of `src_proto_list` has the same key as an existing
element in `dst_proto_list`, then the elements must also be equal.
def _prepend_names(prefix, orig_graph_def):
mut_graph_def = graph_pb2.GraphDef()
for node in orig_graph_def.node:
new_node = mut_graph_def.node.add()
new_node.CopyFrom(node)
new_node.name = _prefixed_op_name(prefix, node.name)
new_node.input[:] = [
_prefixed_op_name(prefix, input_name) for input_name in node.input
]

Args:
dst_proto_list: A `RepeatedCompositeContainer` or
`RepeatedScalarContainer` into which values should be copied.
src_proto_list: A container holding the same kind of values as in
`dst_proto_list` from which values should be copied.
get_key: A function that takes an element of `dst_proto_list` or
`src_proto_list` and returns a key, such that if two elements have
the same key then it is required that they be deep-equal. For
instance, if `dst_proto_list` is a list of nodes, then `get_key`
might be `lambda node: node.name` to indicate that if two nodes
have the same name then they must be the same node. All keys must
be hashable.
# Remap tf.function method name in the PartitionedCall. 'f' is short for
# function.
if new_node.op == "PartitionedCall" and new_node.attr["f"]:

new_node.attr["f"].func.name = _prefixed_func_name(
prefix, new_node.attr["f"].func.name,
)

for func in orig_graph_def.library.function:
new_func = mut_graph_def.library.function.add()
new_func.CopyFrom(func)
# Not creating a structure out of factored out function. They already
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this comment refer to the fact that we don't use a / in the function name prefix, just an underscore? If so, the comment might make more sense as a comment on _prefixed_func_name() itself.

# create an awkward hierarchy and one for each graph.
new_func.signature.name = _prefixed_func_name(
prefix, new_func.signature.name
)

for gradient in orig_graph_def.library.gradient:
new_gradient = mut_graph_def.library.gradient.add()
new_gradient.CopyFrom(gradient)
new_gradient.function_name = _prefixed_func_name(
prefix, new_gradient.function_name,
)
new_gradient.gradient_func = _prefixed_func_name(
prefix, new_gradient.gradient_func,
)

return mut_graph_def

Raises:
_ProtoListDuplicateKeyError: A proto_list contains items with duplicate
keys.
_SameKeyDiffContentError: An item with the same key has different contents.
"""

def _assert_proto_container_unique_keys(proto_list, get_key):
"""Asserts proto_list to only contains unique keys.

Args:
proto_list: A `RepeatedCompositeContainer` or `RepeatedScalarContainer`.
get_key: A function that takes an element of `proto_list` and returns a
hashable key.

Raises:
_ProtoListDuplicateKeyError: A proto_list contains items with duplicate
keys.
"""
keys = set()
for item in proto_list:
key = get_key(item)
if key in keys:
raise _ProtoListDuplicateKeyError(key)
keys.add(key)

_assert_proto_container_unique_keys(dst_proto_list, get_key)
_assert_proto_container_unique_keys(src_proto_list, get_key)

key_to_proto = {}
for proto in dst_proto_list:
key = get_key(proto)
key_to_proto[key] = proto

for proto in src_proto_list:
key = get_key(proto)
if key in key_to_proto:
if proto != key_to_proto.get(key):
raise _SameKeyDiffContentError(key)
else:
dst_proto_list.add().CopyFrom(proto)


def combine_graph_defs(to_proto, from_proto):
"""Combines two GraphDefs by adding nodes from from_proto into to_proto.
def merge_graph_defs(graph_defs):
"""Merges GraphDefs by adding unique prefix, `graph_{ind}`, to names.

All GraphDefs are expected to be of TensorBoard's.
It assumes node names are unique across GraphDefs if contents differ. The
names can be the same if the NodeDef content are exactly the same.

When collecting graphs using the `tf.summary.trace` API, node names are not
guaranteed to be unique. When non-unique names are not considered, it can
lead to graph visualization showing them as one, which creates an inaccurate
depiction of the flow of the graph (e.g., if there are A -> B -> C and D ->
B -> E, you may see {A, D} -> B -> E). To prevent such a graph, we checked
for uniqueness while merging but it resulted in
https://github.com/tensorflow/tensorboard/issues/1929.

To remedy these issues, we simply "apply name scope" on each graph by
prefixing it with unique name (with a chance of collision) to create
unconnected group of graphs.

In case there is only one graph def passed, it returns the original
graph_def. In case no graph defs are passed, it returns an empty GraphDef.

Args:
to_proto: A destination TensorBoard GraphDef.
from_proto: A TensorBoard GraphDef to copy contents from.
graph_defs: TensorBoard GraphDefs to merge.

Returns:
to_proto
TensorBoard GraphDef that merges all graph_defs with unique prefixes.

Raises:
ValueError in case any assumption about GraphDef is violated: A
GraphDef should have unique node, function, and gradient function
names. Also, when merging GraphDefs, they should not have nodes,
functions, or gradient function mappings that share a name but whose
details do not match.
ValueError in case GraphDef versions mismatch.
"""
if from_proto.version != to_proto.version:
raise ValueError("Cannot combine GraphDefs of different versions.")
if len(graph_defs) == 1:
return graph_defs[0]
elif len(graph_defs) == 0:
return graph_pb2.GraphDef()

try:
_safe_copy_proto_list_values(
to_proto.node, from_proto.node, lambda n: n.name
)
except _ProtoListDuplicateKeyError as exc:
raise ValueError("A GraphDef contains non-unique node names: %s" % exc)
except _SameKeyDiffContentError as exc:
raise ValueError(
(
"Cannot combine GraphDefs because nodes share a name "
"but contents are different: %s"
)
% exc
)
try:
_safe_copy_proto_list_values(
to_proto.library.function,
from_proto.library.function,
lambda n: n.signature.name,
)
except _ProtoListDuplicateKeyError as exc:
raise ValueError(
"A GraphDef contains non-unique function names: %s" % exc
)
except _SameKeyDiffContentError as exc:
raise ValueError(
(
"Cannot combine GraphDefs because functions share a name "
"but are different: %s"
)
% exc
)
dst_graph_def = graph_pb2.GraphDef()

try:
_safe_copy_proto_list_values(
to_proto.library.gradient,
from_proto.library.gradient,
lambda g: g.gradient_func,
)
except _ProtoListDuplicateKeyError as exc:
raise ValueError(
"A GraphDef contains non-unique gradient function names: %s" % exc
)
except _SameKeyDiffContentError as exc:
raise ValueError(
(
"Cannot combine GraphDefs because gradients share a gradient_func name "
"but map to different functions: %s"
if graph_defs[0].versions.producer:
dst_graph_def.versions.CopyFrom(graph_defs[0].versions)

for index, graph_def in enumerate(graph_defs):
if dst_graph_def.versions.producer != graph_def.versions.producer:
raise ValueError("Cannot combine GraphDefs of different versions.")

mapped_graph_def = _prepend_names("graph_%d" % (index + 1), graph_def)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might consider a function signature more like

def _add_with_prepended_names(prefix, graph_to_add, destination_graph)

That way we can avoid the overhead of constructing the temporary graphdef that's returned here, only to then copy its nodes and functions into the destination graph def. Since proto python objects aren't immutable, the python code has to do a deep copy each time we move the nodes around, so it would be nice not to do that twice when we could do it just once.

dst_graph_def.node.extend(mapped_graph_def.node)
if mapped_graph_def.library.function:
dst_graph_def.library.function.extend(
mapped_graph_def.library.function
)
if mapped_graph_def.library.gradient:
dst_graph_def.library.gradient.extend(
mapped_graph_def.library.gradient
)
% exc
)

return to_proto
return dst_graph_def
Loading