-
Notifications
You must be signed in to change notification settings - Fork 526
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
support converting models generated in v1.3 to 2.0 compatibility (#725)
* add v1.3 compatibility * remove TestModelMajorCompatability as compatibility was added By the way: Compatability should be compatibility * Also remove TestModelMinorCompatability * Update test_deeppot_a.py * Revert "Update test_deeppot_a.py" This reverts commit a03b5ee. * Revert "Also remove TestModelMinorCompatability" This reverts commit 11fdd5c. * Revert "remove TestModelMajorCompatability as compatibility was added" This reverts commit 40dd807. * revert allowing 0.0 model * convert from model 1.3 to 2.0 * fix .gitignore
- Loading branch information
Showing
8 changed files
with
169 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from deepmd.utils.convert import convert_13_to_20 | ||
|
||
def convert( | ||
*, | ||
FROM: str, | ||
input_model: str, | ||
output_model: str, | ||
**kwargs, | ||
): | ||
if FROM == '1.3': | ||
convert_13_to_20(input_model, output_model) | ||
else: | ||
raise RuntimeError('unsupported model version ' + FROM) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import os | ||
from deepmd.env import tf | ||
from google.protobuf import text_format | ||
from tensorflow.python.platform import gfile | ||
|
||
def convert_13_to_20(input_model: str, output_model: str): | ||
convert_pb_to_pbtxt(input_model, 'frozen_model.pbtxt') | ||
convert_dp13_to_dp20('frozen_model.pbtxt') | ||
convert_pbtxt_to_pb('frozen_model.pbtxt', output_model) | ||
if os.path.isfile('frozen_model.pbtxt'): | ||
os.remove('frozen_model.pbtxt') | ||
print("the converted output model (2.0 support) is saved in %s" % output_model) | ||
|
||
def convert_pb_to_pbtxt(pbfile: str, pbtxtfile: str): | ||
with gfile.FastGFile(pbfile, 'rb') as f: | ||
graph_def = tf.GraphDef() | ||
graph_def.ParseFromString(f.read()) | ||
tf.import_graph_def(graph_def, name='') | ||
tf.train.write_graph(graph_def, './', pbtxtfile, as_text=True) | ||
|
||
def convert_pbtxt_to_pb(pbtxtfile: str, pbfile: str): | ||
with tf.gfile.FastGFile(pbtxtfile, 'r') as f: | ||
graph_def = tf.GraphDef() | ||
file_content = f.read() | ||
# Merges the human-readable string in `file_content` into `graph_def`. | ||
text_format.Merge(file_content, graph_def) | ||
tf.train.write_graph(graph_def, './', pbfile, as_text=False) | ||
|
||
def convert_dp13_to_dp20(fname: str): | ||
with open(fname) as fp: | ||
file_content = fp.read() | ||
file_content += """ | ||
node { | ||
name: "model_attr/model_version" | ||
op: "Const" | ||
attr { | ||
key: "dtype" | ||
value { | ||
type: DT_STRING | ||
} | ||
} | ||
attr { | ||
key: "value" | ||
value { | ||
tensor { | ||
dtype: DT_STRING | ||
tensor_shape { | ||
} | ||
string_val: "1.0" | ||
} | ||
} | ||
} | ||
} | ||
""" | ||
file_content = file_content\ | ||
.replace('DescrptSeA', 'ProdEnvMatA')\ | ||
.replace('DescrptSeR', 'ProdEnvMatR') | ||
with open(fname, 'w') as fp: | ||
fp.write(file_content) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters