Skip to content

Commit

Permalink
support converting models generated in v1.3 to 2.0 compatibility (#725)
Browse files Browse the repository at this point in the history
* add v1.3 compatibility

* remove TestModelMajorCompatability as compatibility was added

By the way: Compatability should be compatibility

* Also remove TestModelMinorCompatability

* Update test_deeppot_a.py

* Revert "Update test_deeppot_a.py"

This reverts commit a03b5ee.

* Revert "Also remove TestModelMinorCompatability"

This reverts commit 11fdd5c.

* Revert "remove TestModelMajorCompatability as compatibility was added"

This reverts commit 40dd807.

* revert allowing 0.0 model

* convert from model 1.3 to 2.0

* fix .gitignore
  • Loading branch information
njzjz authored Jun 22, 2021
1 parent b15944d commit 4555034
Show file tree
Hide file tree
Showing 8 changed files with 169 additions and 4 deletions.
4 changes: 3 additions & 1 deletion deepmd/entrypoints/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .train import train
from .transfer import transfer
from ..infer.model_devi import make_model_devi
from .convert import convert

__all__ = [
"config",
Expand All @@ -18,5 +19,6 @@
"transfer",
"compress",
"doc_train_input",
"make_model_devi"
"make_model_devi",
"convert",
]
13 changes: 13 additions & 0 deletions deepmd/entrypoints/convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from deepmd.utils.convert import convert_13_to_20

def convert(
*,
FROM: str,
input_model: str,
output_model: str,
**kwargs,
):
if FROM == '1.3':
convert_13_to_20(input_model, output_model)
else:
raise RuntimeError('unsupported model version ' + FROM)
33 changes: 32 additions & 1 deletion deepmd/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import argparse
import logging
from pathlib import Path
from typing import List, Optional
from typing import Dict, List, Optional

from deepmd.entrypoints import (
compress,
Expand All @@ -14,6 +14,7 @@
train,
transfer,
make_model_devi,
convert,
)
from deepmd.loggers import set_log_handles

Expand Down Expand Up @@ -359,6 +360,34 @@ def parse_args(args: Optional[List[str]] = None):
help="The trajectory frequency of the system"
)

# * convert models
# supported: 1.3->2.0
parser_transform = subparsers.add_parser(
'convert-from',
parents=[parser_log],
help='convert lower model version to supported version',
)
parser_transform.add_argument(
'FROM',
type = str,
choices = ['1.3'],
help="The original model compatibility",
)
parser_transform.add_argument(
'-i',
"--input-model",
default = "frozen_model.pb",
type=str,
help = "the input model",
)
parser_transform.add_argument(
"-o",
"--output-model",
default = "convert_out.pb",
type=str,
help='the output model',
)

parsed_args = parser.parse_args(args=args)
if parsed_args.command is None:
parser.print_help()
Expand Down Expand Up @@ -402,6 +431,8 @@ def main():
doc_train_input()
elif args.command == "model-devi":
make_model_devi(**dict_args)
elif args.command == "convert-from":
convert(**dict_args)
elif args.command is None:
pass
else:
Expand Down
59 changes: 59 additions & 0 deletions deepmd/utils/convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import os
from deepmd.env import tf
from google.protobuf import text_format
from tensorflow.python.platform import gfile

def convert_13_to_20(input_model: str, output_model: str):
convert_pb_to_pbtxt(input_model, 'frozen_model.pbtxt')
convert_dp13_to_dp20('frozen_model.pbtxt')
convert_pbtxt_to_pb('frozen_model.pbtxt', output_model)
if os.path.isfile('frozen_model.pbtxt'):
os.remove('frozen_model.pbtxt')
print("the converted output model (2.0 support) is saved in %s" % output_model)

def convert_pb_to_pbtxt(pbfile: str, pbtxtfile: str):
with gfile.FastGFile(pbfile, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
tf.train.write_graph(graph_def, './', pbtxtfile, as_text=True)

def convert_pbtxt_to_pb(pbtxtfile: str, pbfile: str):
with tf.gfile.FastGFile(pbtxtfile, 'r') as f:
graph_def = tf.GraphDef()
file_content = f.read()
# Merges the human-readable string in `file_content` into `graph_def`.
text_format.Merge(file_content, graph_def)
tf.train.write_graph(graph_def, './', pbfile, as_text=False)

def convert_dp13_to_dp20(fname: str):
with open(fname) as fp:
file_content = fp.read()
file_content += """
node {
name: "model_attr/model_version"
op: "Const"
attr {
key: "dtype"
value {
type: DT_STRING
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_STRING
tensor_shape {
}
string_val: "1.0"
}
}
}
}
"""
file_content = file_content\
.replace('DescrptSeA', 'ProdEnvMatA')\
.replace('DescrptSeR', 'ProdEnvMatR')
with open(fname, 'w') as fp:
fp.write(file_content)
4 changes: 4 additions & 0 deletions source/api_cc/include/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ void
get_env_nthreads(int & num_intra_nthreads,
int & num_inter_nthreads);

struct
tf_exception: public std::exception {
};

/**
* @brief Check TensorFlow status. Exit if not OK.
* @param[in] status TensorFlow status.
Expand Down
5 changes: 5 additions & 0 deletions source/api_cc/src/DeepPot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,12 @@ init (const std::string & model, const int & gpu_rank, const std::string & file_
if (dfparam < 0) dfparam = 0;
if (daparam < 0) daparam = 0;
model_type = get_scalar<STRINGTYPE>("model_attr/model_type");
try{
model_version = get_scalar<STRINGTYPE>("model_attr/model_version");
} catch (deepmd::tf_exception& e){
// no model version defined in old models
model_version = "0.0";
}
if(! model_compatable(model_version)){
throw std::runtime_error(
"incompatable model: version " + model_version
Expand Down
2 changes: 1 addition & 1 deletion source/api_cc/src/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ deepmd::
check_status(const tensorflow::Status& status) {
if (!status.ok()) {
std::cout << status.ToString() << std::endl;
exit(1);
throw deepmd::tf_exception();
}
}

Expand Down
53 changes: 52 additions & 1 deletion source/op/prod_env_mat_multi_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,26 @@ REGISTER_OP("ProdEnvMatA")
.Output("nlist: int32");
// only sel_a and rcut_r uesd.

// an alias of ProdEnvMatA -- Compatible with v1.3
REGISTER_OP("DescrptSeA")
.Attr("T: {float, double}")
.Input("coord: T")
.Input("type: int32")
.Input("natoms: int32")
.Input("box : T")
.Input("mesh : int32")
.Input("davg: T")
.Input("dstd: T")
.Attr("rcut_a: float")
.Attr("rcut_r: float")
.Attr("rcut_r_smth: float")
.Attr("sel_a: list(int)")
.Attr("sel_r: list(int)")
.Output("descrpt: T")
.Output("descrpt_deriv: T")
.Output("rij: T")
.Output("nlist: int32");

REGISTER_OP("ProdEnvMatR")
.Attr("T: {float, double}")
.Input("coord: T")
Expand All @@ -42,6 +62,23 @@ REGISTER_OP("ProdEnvMatR")
.Output("rij: T")
.Output("nlist: int32");

// an alias of ProdEnvMatR -- Compatible with v1.3
REGISTER_OP("DescrptSeR")
.Attr("T: {float, double}")
.Input("coord: T")
.Input("type: int32")
.Input("natoms: int32")
.Input("box: T")
.Input("mesh: int32")
.Input("davg: T")
.Input("dstd: T")
.Attr("rcut: float")
.Attr("rcut_smth: float")
.Attr("sel: list(int)")
.Output("descrpt: T")
.Output("descrpt_deriv: T")
.Output("rij: T")
.Output("nlist: int32");

template<typename FPTYPE>
static int
Expand Down Expand Up @@ -1364,24 +1401,38 @@ _prepare_coord_nlist_gpu_rocm(


// Register the CPU kernels.
// Compatible with v1.3
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("ProdEnvMatA").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
ProdEnvMatAOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("ProdEnvMatR").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
ProdEnvMatROp<CPUDevice, T>);
ProdEnvMatROp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("DescrptSeA").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
ProdEnvMatAOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("DescrptSeR").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
ProdEnvMatROp<CPUDevice, T>);
REGISTER_CPU(float);
REGISTER_CPU(double);

// Register the GPU kernels.
// Compatible with v1.3
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("ProdEnvMatA").Device(DEVICE_GPU).TypeConstraint<T>("T").HostMemory("natoms").HostMemory("box"), \
ProdEnvMatAOp<GPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("ProdEnvMatR").Device(DEVICE_GPU).TypeConstraint<T>("T").HostMemory("natoms").HostMemory("box"), \
ProdEnvMatROp<GPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("DescrptSeA").Device(DEVICE_GPU).TypeConstraint<T>("T").HostMemory("natoms").HostMemory("box"), \
ProdEnvMatAOp<GPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("DescrptSeR").Device(DEVICE_GPU).TypeConstraint<T>("T").HostMemory("natoms").HostMemory("box"), \
ProdEnvMatROp<GPUDevice, T>);
REGISTER_GPU(float);
REGISTER_GPU(double);
Expand Down

0 comments on commit 4555034

Please sign in to comment.