Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support converting models generated in v1.3 to 2.0 compatibility #725

Merged
merged 11 commits into from
Jun 22, 2021
4 changes: 3 additions & 1 deletion deepmd/entrypoints/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .train import train
from .transfer import transfer
from ..infer.model_devi import make_model_devi
from .convert import convert

__all__ = [
"config",
Expand All @@ -18,5 +19,6 @@
"transfer",
"compress",
"doc_train_input",
"make_model_devi"
"make_model_devi",
"convert",
]
13 changes: 13 additions & 0 deletions deepmd/entrypoints/convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from deepmd.utils.convert import convert_13_to_20

def convert(
    *,
    FROM: str,
    input_model: str,
    output_model: str,
    **kwargs,
):
    """Convert a model produced by an older release to the supported format.

    Parameters
    ----------
    FROM : str
        Version the input model is compatible with. Only ``'1.3'`` is handled.
    input_model : str
        Path of the model file to read.
    output_model : str
        Path the converted model is written to.
    **kwargs
        Any additional keyword arguments are accepted and ignored.

    Raises
    ------
    RuntimeError
        If ``FROM`` is not a supported version.
    """
    # Reject unknown versions up front; the supported path then reads linearly.
    if FROM != '1.3':
        raise RuntimeError('unsupported model version ' + FROM)
    convert_13_to_20(input_model, output_model)
33 changes: 32 additions & 1 deletion deepmd/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import argparse
import logging
from pathlib import Path
from typing import List, Optional
from typing import Dict, List, Optional

from deepmd.entrypoints import (
compress,
Expand All @@ -14,6 +14,7 @@
train,
transfer,
make_model_devi,
convert,
)
from deepmd.loggers import set_log_handles

Expand Down Expand Up @@ -359,6 +360,34 @@ def parse_args(args: Optional[List[str]] = None):
help="The trajectory frequency of the system"
)

# * convert models
# supported: 1.3->2.0
parser_transform = subparsers.add_parser(
'convert-from',
parents=[parser_log],
help='convert lower model version to supported version',
)
parser_transform.add_argument(
'FROM',
type = str,
choices = ['1.3'],
help="The original model compatibility",
)
parser_transform.add_argument(
'-i',
"--input-model",
default = "frozen_model.pb",
type=str,
help = "the input model",
)
parser_transform.add_argument(
"-o",
"--output-model",
default = "convert_out.pb",
type=str,
help='the output model',
)

parsed_args = parser.parse_args(args=args)
if parsed_args.command is None:
parser.print_help()
Expand Down Expand Up @@ -402,6 +431,8 @@ def main():
doc_train_input()
elif args.command == "model-devi":
make_model_devi(**dict_args)
elif args.command == "convert-from":
convert(**dict_args)
elif args.command is None:
pass
else:
Expand Down
59 changes: 59 additions & 0 deletions deepmd/utils/convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import os
from deepmd.env import tf
from google.protobuf import text_format
from tensorflow.python.platform import gfile

def convert_13_to_20(input_model: str, output_model: str):
    """Convert a frozen model from 1.3 compatibility to 2.0 compatibility.

    The binary graph is dumped to a text-format graph, patched in place by
    ``convert_dp13_to_dp20`` and serialized back to a binary graph.

    Parameters
    ----------
    input_model : str
        Path of the 1.3 frozen model to read.
    output_model : str
        Path the converted model is written to.
    """
    import tempfile  # local import: only needed for the scratch directory

    # Keep the intermediate text graph in a private scratch directory rather
    # than a fixed name in the working directory: a user's own
    # ``frozen_model.pbtxt`` is never overwritten or deleted, and the scratch
    # file is removed even when one of the steps raises.
    with tempfile.TemporaryDirectory() as workdir:
        # Absolute path: the helpers join the name onto './', and joining an
        # absolute path onto a directory yields the absolute path unchanged.
        pbtxt = os.path.join(workdir, 'frozen_model.pbtxt')
        convert_pb_to_pbtxt(input_model, pbtxt)
        convert_dp13_to_dp20(pbtxt)
        convert_pbtxt_to_pb(pbtxt, output_model)
    print("the converted output model (2.0 support) is saved in %s" % output_model)

def convert_pb_to_pbtxt(pbfile: str, pbtxtfile: str):
    """Dump a binary frozen graph as a text-format graph file.

    Parameters
    ----------
    pbfile : str
        Path of the binary graph file to read.
    pbtxtfile : str
        Name of the text-format file to write; it is joined onto ``'./'``
        by ``tf.train.write_graph``.
    """
    with gfile.FastGFile(pbfile, 'rb') as pb_stream:
        serialized = pb_stream.read()
        graph = tf.GraphDef()
        graph.ParseFromString(serialized)
        # NOTE(review): the graph is imported before being written out,
        # presumably so that unregistered ops are reported here -- confirm.
        tf.import_graph_def(graph, name='')
        tf.train.write_graph(graph, './', pbtxtfile, as_text=True)

def convert_pbtxt_to_pb(pbtxtfile: str, pbfile: str):
    """Serialize a text-format graph file into a binary graph file.

    Parameters
    ----------
    pbtxtfile : str
        Path of the text-format graph file to read.
    pbfile : str
        Name of the binary file to write; it is joined onto ``'./'`` by
        ``tf.train.write_graph``.
    """
    # Open through the module-level ``gfile`` import, the same way
    # convert_pb_to_pbtxt does, instead of reaching through ``tf.gfile``.
    with gfile.FastGFile(pbtxtfile, 'r') as f:
        file_content = f.read()
    # The input handle is closed before the graph is parsed and written.
    graph_def = tf.GraphDef()
    # Merges the human-readable string in `file_content` into `graph_def`.
    text_format.Merge(file_content, graph_def)
    tf.train.write_graph(graph_def, './', pbfile, as_text=False)

def convert_dp13_to_dp20(fname: str):
    """Patch a text-format 1.3 graph file in place for 2.0 compatibility.

    Two edits are applied to the file's text:

    * a constant string node ``model_attr/model_version`` with value
      ``"1.0"`` is appended, unless a node of that name is already present;
    * every occurrence of ``DescrptSeA`` / ``DescrptSeR`` is replaced by
      ``ProdEnvMatA`` / ``ProdEnvMatR``.

    Parameters
    ----------
    fname : str
        Path of the text-format graph file; it is read and then overwritten.
    """
    version_node = """
node {
  name: "model_attr/model_version"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_STRING
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_STRING
        tensor_shape {
        }
        string_val: "1.0"
      }
    }
  }
}
"""
    with open(fname) as fp:
        file_content = fp.read()
    # Appending unconditionally would create two nodes with the same name if
    # the file already carries a version node (e.g. the function is run twice
    # on the same file), so only add it when it is missing.
    if 'name: "model_attr/model_version"' not in file_content:
        file_content += version_node
    # Plain text replacement: op types and any node names or inputs that embed
    # the old op name are renamed together.
    file_content = (
        file_content
        .replace('DescrptSeA', 'ProdEnvMatA')
        .replace('DescrptSeR', 'ProdEnvMatR')
    )
    with open(fname, 'w') as fp:
        fp.write(file_content)
4 changes: 4 additions & 0 deletions source/api_cc/include/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ void
get_env_nthreads(int & num_intra_nthreads,
int & num_inter_nthreads);

/**
* @brief Exception type thrown by check_status when a TensorFlow status is not OK.
* It carries no message of its own; check_status prints the status text
* before throwing.
**/
struct
tf_exception: public std::exception {
};

/**
* @brief Check TensorFlow status. Exit if not OK.
* @param[in] status TensorFlow status.
Expand Down
5 changes: 5 additions & 0 deletions source/api_cc/src/DeepPot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,12 @@ init (const std::string & model, const int & gpu_rank, const std::string & file_
if (dfparam < 0) dfparam = 0;
if (daparam < 0) daparam = 0;
model_type = get_scalar<STRINGTYPE>("model_attr/model_type");
try{
model_version = get_scalar<STRINGTYPE>("model_attr/model_version");
} catch (deepmd::tf_exception& e){
// no model version defined in old models
model_version = "0.0";
}
if(! model_compatable(model_version)){
throw std::runtime_error(
"incompatable model: version " + model_version
Expand Down
2 changes: 1 addition & 1 deletion source/api_cc/src/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ deepmd::
check_status(const tensorflow::Status& status) {
if (!status.ok()) {
std::cout << status.ToString() << std::endl;
exit(1);
throw deepmd::tf_exception();
}
}

Expand Down
53 changes: 52 additions & 1 deletion source/op/prod_env_mat_multi_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,26 @@ REGISTER_OP("ProdEnvMatA")
.Output("nlist: int32");
// only sel_a and rcut_r used.

// an alias of ProdEnvMatA -- Compatible with v1.3
REGISTER_OP("DescrptSeA")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to register DescrptSeA? the op name has changed to ProdEnvMatA by convert

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to load the old graph first when executing convert.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see

.Attr("T: {float, double}")
.Input("coord: T")
.Input("type: int32")
.Input("natoms: int32")
.Input("box : T")
.Input("mesh : int32")
.Input("davg: T")
.Input("dstd: T")
.Attr("rcut_a: float")
.Attr("rcut_r: float")
.Attr("rcut_r_smth: float")
.Attr("sel_a: list(int)")
.Attr("sel_r: list(int)")
.Output("descrpt: T")
.Output("descrpt_deriv: T")
.Output("rij: T")
.Output("nlist: int32");

REGISTER_OP("ProdEnvMatR")
.Attr("T: {float, double}")
.Input("coord: T")
Expand All @@ -42,6 +62,23 @@ REGISTER_OP("ProdEnvMatR")
.Output("rij: T")
.Output("nlist: int32");

// An alias of ProdEnvMatR -- compatible with v1.3.
// The op is registered under its old name so that graphs that still reference
// "DescrptSeR" can be loaded; the kernels bound to this name in this file are
// the ProdEnvMatROp ones.
REGISTER_OP("DescrptSeR")
.Attr("T: {float, double}")
.Input("coord: T")
.Input("type: int32")
.Input("natoms: int32")
.Input("box: T")
.Input("mesh: int32")
.Input("davg: T")
.Input("dstd: T")
.Attr("rcut: float")
.Attr("rcut_smth: float")
.Attr("sel: list(int)")
.Output("descrpt: T")
.Output("descrpt_deriv: T")
.Output("rij: T")
.Output("nlist: int32");

template<typename FPTYPE>
static int
Expand Down Expand Up @@ -1364,24 +1401,38 @@ _prepare_coord_nlist_gpu_rocm(


// Register the CPU kernels.
// Compatible with v1.3
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("ProdEnvMatA").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
ProdEnvMatAOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("ProdEnvMatR").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
ProdEnvMatROp<CPUDevice, T>);
ProdEnvMatROp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("DescrptSeA").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
ProdEnvMatAOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("DescrptSeR").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
ProdEnvMatROp<CPUDevice, T>);
REGISTER_CPU(float);
REGISTER_CPU(double);

// Register the GPU kernels.
// Compatible with v1.3
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("ProdEnvMatA").Device(DEVICE_GPU).TypeConstraint<T>("T").HostMemory("natoms").HostMemory("box"), \
ProdEnvMatAOp<GPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("ProdEnvMatR").Device(DEVICE_GPU).TypeConstraint<T>("T").HostMemory("natoms").HostMemory("box"), \
ProdEnvMatROp<GPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("DescrptSeA").Device(DEVICE_GPU).TypeConstraint<T>("T").HostMemory("natoms").HostMemory("box"), \
ProdEnvMatAOp<GPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("DescrptSeR").Device(DEVICE_GPU).TypeConstraint<T>("T").HostMemory("natoms").HostMemory("box"), \
ProdEnvMatROp<GPUDevice, T>);
REGISTER_GPU(float);
REGISTER_GPU(double);
Expand Down