-
Notifications
You must be signed in to change notification settings - Fork 519
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(jax): export call_lower to SavedModel via jax2tf (#4254)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes - **New Features** - Added support for the TensorFlow SavedModel format, allowing users to handle additional model file types. - Introduced a new TensorFlow model wrapper class for enhanced integration with JAX functionalities. - **Bug Fixes** - Improved error handling for unsupported file formats during model deserialization. - **Documentation** - Updated backend documentation to reflect new file extensions and clarify backend capabilities. - **Tests** - Enhanced test structure for better clarity and maintainability regarding backend handling. - Added a new job for testing TensorFlow 2 in eager mode within the testing workflow. - Introduced a conditional skip for tests based on TensorFlow 2 compatibility. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <[email protected]>
- Loading branch information
Showing
11 changed files
with
568 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
import tensorflow as tf | ||
|
||
if not tf.executing_eagerly(): | ||
# TF disallow temporary eager execution | ||
raise RuntimeError( | ||
"Unfortunatly, jax2tf (requires eager execution) cannot be used with the " | ||
"TensorFlow backend (disables eager execution). " | ||
"If you are converting a model between different backends, " | ||
"considering converting to the `.dp` format first." | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
import json | ||
|
||
import tensorflow as tf | ||
from jax.experimental import ( | ||
jax2tf, | ||
) | ||
|
||
from deepmd.jax.model.base_model import ( | ||
BaseModel, | ||
) | ||
|
||
|
||
def deserialize_to_file(model_file: str, data: dict) -> None: | ||
"""Deserialize the dictionary to a model file. | ||
Parameters | ||
---------- | ||
model_file : str | ||
The model file to be saved. | ||
data : dict | ||
The dictionary to be deserialized. | ||
""" | ||
if model_file.endswith(".savedmodel"): | ||
model = BaseModel.deserialize(data["model"]) | ||
model_def_script = data["model_def_script"] | ||
call_lower = model.call_lower | ||
|
||
tf_model = tf.Module() | ||
|
||
def exported_whether_do_atomic_virial(do_atomic_virial): | ||
def call_lower_with_fixed_do_atomic_virial( | ||
coord, atype, nlist, mapping, fparam, aparam | ||
): | ||
return call_lower( | ||
coord, | ||
atype, | ||
nlist, | ||
mapping, | ||
fparam, | ||
aparam, | ||
do_atomic_virial=do_atomic_virial, | ||
) | ||
|
||
return jax2tf.convert( | ||
call_lower_with_fixed_do_atomic_virial, | ||
polymorphic_shapes=[ | ||
"(nf, nloc + nghost, 3)", | ||
"(nf, nloc + nghost)", | ||
f"(nf, nloc, {model.get_nnei()})", | ||
"(nf, nloc + nghost)", | ||
f"(nf, {model.get_dim_fparam()})", | ||
f"(nf, nloc, {model.get_dim_aparam()})", | ||
], | ||
with_gradient=True, | ||
) | ||
|
||
# Save a function that can take scalar inputs. | ||
# We need to explicit set the function name, so C++ can find it. | ||
@tf.function( | ||
autograph=False, | ||
input_signature=[ | ||
tf.TensorSpec([None, None, 3], tf.float64), | ||
tf.TensorSpec([None, None], tf.int32), | ||
tf.TensorSpec([None, None, model.get_nnei()], tf.int64), | ||
tf.TensorSpec([None, None], tf.int64), | ||
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), | ||
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), | ||
], | ||
) | ||
def call_lower_without_atomic_virial( | ||
coord, atype, nlist, mapping, fparam, aparam | ||
): | ||
return exported_whether_do_atomic_virial(do_atomic_virial=False)( | ||
coord, atype, nlist, mapping, fparam, aparam | ||
) | ||
|
||
tf_model.call_lower = call_lower_without_atomic_virial | ||
|
||
@tf.function( | ||
autograph=False, | ||
input_signature=[ | ||
tf.TensorSpec([None, None, 3], tf.float64), | ||
tf.TensorSpec([None, None], tf.int32), | ||
tf.TensorSpec([None, None, model.get_nnei()], tf.int64), | ||
tf.TensorSpec([None, None], tf.int64), | ||
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), | ||
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), | ||
], | ||
) | ||
def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam): | ||
return exported_whether_do_atomic_virial(do_atomic_virial=True)( | ||
coord, atype, nlist, mapping, fparam, aparam | ||
) | ||
|
||
tf_model.call_lower_atomic_virial = call_lower_with_atomic_virial | ||
|
||
# set functions to export other attributes | ||
@tf.function | ||
def get_type_map(): | ||
return tf.constant(model.get_type_map(), dtype=tf.string) | ||
|
||
tf_model.get_type_map = get_type_map | ||
|
||
@tf.function | ||
def get_rcut(): | ||
return tf.constant(model.get_rcut(), dtype=tf.double) | ||
|
||
tf_model.get_rcut = get_rcut | ||
|
||
@tf.function | ||
def get_dim_fparam(): | ||
return tf.constant(model.get_dim_fparam(), dtype=tf.int64) | ||
|
||
tf_model.get_dim_fparam = get_dim_fparam | ||
|
||
@tf.function | ||
def get_dim_aparam(): | ||
return tf.constant(model.get_dim_aparam(), dtype=tf.int64) | ||
|
||
tf_model.get_dim_aparam = get_dim_aparam | ||
|
||
@tf.function | ||
def get_sel_type(): | ||
return tf.constant(model.get_sel_type(), dtype=tf.int64) | ||
|
||
tf_model.get_sel_type = get_sel_type | ||
|
||
@tf.function | ||
def is_aparam_nall(): | ||
return tf.constant(model.is_aparam_nall(), dtype=tf.bool) | ||
|
||
tf_model.is_aparam_nall = is_aparam_nall | ||
|
||
@tf.function | ||
def model_output_type(): | ||
return tf.constant(model.model_output_type(), dtype=tf.string) | ||
|
||
tf_model.model_output_type = model_output_type | ||
|
||
@tf.function | ||
def mixed_types(): | ||
return tf.constant(model.mixed_types(), dtype=tf.bool) | ||
|
||
tf_model.mixed_types = mixed_types | ||
|
||
if model.get_min_nbor_dist() is not None: | ||
|
||
@tf.function | ||
def get_min_nbor_dist(): | ||
return tf.constant(model.get_min_nbor_dist(), dtype=tf.double) | ||
|
||
tf_model.get_min_nbor_dist = get_min_nbor_dist | ||
|
||
@tf.function | ||
def get_sel(): | ||
return tf.constant(model.get_sel(), dtype=tf.int64) | ||
|
||
tf_model.get_sel = get_sel | ||
|
||
@tf.function | ||
def get_model_def_script(): | ||
return tf.constant( | ||
json.dumps(model_def_script, separators=(",", ":")), dtype=tf.string | ||
) | ||
|
||
tf_model.get_model_def_script = get_model_def_script | ||
tf.saved_model.save( | ||
tf_model, | ||
model_file, | ||
options=tf.saved_model.SaveOptions(experimental_custom_gradients=True), | ||
) |
Oops, something went wrong.