Skip to content

Commit

Permalink
add compatibility to dp freeze (#1055)
Browse files Browse the repository at this point in the history
Fix #1053.
  • Loading branch information
njzjz authored Aug 29, 2021
1 parent 5f8989b commit 1dd6e97
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions deepmd/entrypoints/freeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc
"""

import logging
from deepmd.env import tf
from deepmd.env import op_module
from deepmd.utils.sess import run_sess
Expand All @@ -18,6 +19,8 @@

__all__ = ["freeze"]

log = logging.getLogger(__name__)


def _make_node_names(model_type: str, modifier_type: Optional[str] = None) -> List[str]:
"""Get node names based on model type.
Expand Down Expand Up @@ -175,9 +178,18 @@ def freeze(
modifier_type = None
if node_names is None:
output_node_list = _make_node_names(model_type, modifier_type)
different_set = set(output_node_list) - set(nodes)
if different_set:
log.warning(
"The following nodes are not in the graph: %s. "
"Skip freezeing these nodes. You may be freezing "
"a checkpoint generated by an old version." % different_set
)
# use intersection as output list
output_node_list = list(set(output_node_list) & set(nodes))
else:
output_node_list = node_names.split(",")
print(f"The following nodes will be frozen: {output_node_list}")
log.info(f"The following nodes will be frozen: {output_node_list}")

# We use a built-in TF helper to export variables to constants
output_graph_def = tf.graph_util.convert_variables_to_constants(
Expand All @@ -189,4 +201,4 @@ def freeze(
# Finally we serialize and dump the output graph to the filesystem
with tf.gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())
print(f"{len(output_graph_def.node):d} ops in the final graph.")
log.info(f"{len(output_graph_def.node):d} ops in the final graph.")

0 comments on commit 1dd6e97

Please sign in to comment.