Skip to content

Commit

Permalink
feat(jax): export call_lower to SavedModel via jax2tf (#4254)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Release Notes

- **New Features**
- Added support for the TensorFlow SavedModel format, allowing users to
handle additional model file types.
- Introduced a new TensorFlow model wrapper class for enhanced
integration with JAX functionalities.
  
- **Bug Fixes**
- Improved error handling for unsupported file formats during model
deserialization.

- **Documentation**
- Updated backend documentation to reflect new file extensions and
clarify backend capabilities.

- **Tests**
- Enhanced test structure for better clarity and maintainability
regarding backend handling.
- Added a new job for testing TensorFlow 2 in eager mode within the
testing workflow.
- Introduced a conditional skip for tests based on TensorFlow 2
compatibility.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Nov 4, 2024
1 parent 7aaf284 commit 38815b3
Show file tree
Hide file tree
Showing 11 changed files with 568 additions and 22 deletions.
18 changes: 14 additions & 4 deletions .github/workflows/test_python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,23 @@ jobs:
python-version: ${{ matrix.python }}
- run: python -m pip install -U uv
- run: |
source/install/uv_with_retry.sh pip install --system mpich
source/install/uv_with_retry.sh pip install --system openmpi tensorflow-cpu
source/install/uv_with_retry.sh pip install --system torch -i https://download.pytorch.org/whl/cpu
export TENSORFLOW_ROOT=$(python -c 'import tensorflow;print(tensorflow.__path__[0])')
export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
source/install/uv_with_retry.sh pip install --system --only-binary=horovod -e .[cpu,test,jax] horovod[tensorflow-cpu] mpi4py
source/install/uv_with_retry.sh pip install --system -e .[test,jax] mpi4py
source/install/uv_with_retry.sh pip install --system horovod --no-build-isolation
env:
# Please note that uv has some issues with finding
# existing TensorFlow package. Currently, it uses
# TensorFlow in the build dependency, but if it
# changes, setting `TENSORFLOW_ROOT`.
TENSORFLOW_VERSION: 2.16.1
DP_ENABLE_PYTORCH: 1
DP_BUILD_TESTING: 1
UV_EXTRA_INDEX_URL: "https://pypi.anaconda.org/njzjz/simple https://pypi.anaconda.org/mpi4py/simple"
UV_EXTRA_INDEX_URL: "https://pypi.anaconda.org/mpi4py/simple"
HOROVOD_WITH_TENSORFLOW: 1
HOROVOD_WITHOUT_PYTORCH: 1
HOROVOD_WITH_MPI: 1
- run: dp --version
- name: Get durations from cache
uses: actions/cache@v4
Expand All @@ -53,6 +57,12 @@ jobs:
- run: pytest --cov=deepmd source/tests --durations=0 --splits 6 --group ${{ matrix.group }} --store-durations --durations-path=.test_durations --splitting-algorithm least_duration
env:
NUM_WORKERS: 0
- name: Test TF2 eager mode
run: pytest --cov=deepmd source/tests/consistent/io/test_io.py --durations=0
env:
NUM_WORKERS: 0
DP_TEST_TF2_ONLY: 1
if: matrix.group == 1
- run: mv .test_durations .test_durations_${{ matrix.group }}
- name: Upload partial durations
uses: actions/upload-artifact@v4
Expand Down
2 changes: 1 addition & 1 deletion deepmd/backend/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class JAXBackend(Backend):
| Backend.Feature.NEIGHBOR_STAT
)
"""The features of the backend."""
suffixes: ClassVar[list[str]] = [".hlo", ".jax"]
suffixes: ClassVar[list[str]] = [".hlo", ".jax", ".savedmodel"]
"""The suffixes of the backend."""

def is_available(self) -> bool:
Expand Down
27 changes: 18 additions & 9 deletions deepmd/jax/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,24 @@ def __init__(
self.output_def = output_def
self.model_path = model_file

model_data = load_dp_model(model_file)
self.dp = HLO(
stablehlo=model_data["@variables"]["stablehlo"].tobytes(),
stablehlo_atomic_virial=model_data["@variables"][
"stablehlo_atomic_virial"
].tobytes(),
model_def_script=model_data["model_def_script"],
**model_data["constants"],
)
if model_file.endswith(".hlo"):
model_data = load_dp_model(model_file)
self.dp = HLO(
stablehlo=model_data["@variables"]["stablehlo"].tobytes(),
stablehlo_atomic_virial=model_data["@variables"][
"stablehlo_atomic_virial"
].tobytes(),
model_def_script=model_data["model_def_script"],
**model_data["constants"],
)
elif model_file.endswith(".savedmodel"):
from deepmd.jax.jax2tf.tfmodel import (
TFModelWrapper,
)

self.dp = TFModelWrapper(model_file)
else:
raise ValueError("Unsupported file extension")
self.rcut = self.dp.get_rcut()
self.type_map = self.dp.get_type_map()
if isinstance(auto_batch_size, bool):
Expand Down
11 changes: 11 additions & 0 deletions deepmd/jax/jax2tf/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import tensorflow as tf

if not tf.executing_eagerly():
# TF disallow temporary eager execution
raise RuntimeError(
"Unfortunatly, jax2tf (requires eager execution) cannot be used with the "
"TensorFlow backend (disables eager execution). "
"If you are converting a model between different backends, "
"considering converting to the `.dp` format first."
)
172 changes: 172 additions & 0 deletions deepmd/jax/jax2tf/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json

import tensorflow as tf
from jax.experimental import (
jax2tf,
)

from deepmd.jax.model.base_model import (
BaseModel,
)


def deserialize_to_file(model_file: str, data: dict) -> None:
"""Deserialize the dictionary to a model file.
Parameters
----------
model_file : str
The model file to be saved.
data : dict
The dictionary to be deserialized.
"""
if model_file.endswith(".savedmodel"):
model = BaseModel.deserialize(data["model"])
model_def_script = data["model_def_script"]
call_lower = model.call_lower

tf_model = tf.Module()

def exported_whether_do_atomic_virial(do_atomic_virial):
def call_lower_with_fixed_do_atomic_virial(
coord, atype, nlist, mapping, fparam, aparam
):
return call_lower(
coord,
atype,
nlist,
mapping,
fparam,
aparam,
do_atomic_virial=do_atomic_virial,
)

return jax2tf.convert(
call_lower_with_fixed_do_atomic_virial,
polymorphic_shapes=[
"(nf, nloc + nghost, 3)",
"(nf, nloc + nghost)",
f"(nf, nloc, {model.get_nnei()})",
"(nf, nloc + nghost)",
f"(nf, {model.get_dim_fparam()})",
f"(nf, nloc, {model.get_dim_aparam()})",
],
with_gradient=True,
)

# Save a function that can take scalar inputs.
# We need to explicit set the function name, so C++ can find it.
@tf.function(
autograph=False,
input_signature=[
tf.TensorSpec([None, None, 3], tf.float64),
tf.TensorSpec([None, None], tf.int32),
tf.TensorSpec([None, None, model.get_nnei()], tf.int64),
tf.TensorSpec([None, None], tf.int64),
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64),
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
],
)
def call_lower_without_atomic_virial(
coord, atype, nlist, mapping, fparam, aparam
):
return exported_whether_do_atomic_virial(do_atomic_virial=False)(
coord, atype, nlist, mapping, fparam, aparam
)

tf_model.call_lower = call_lower_without_atomic_virial

@tf.function(
autograph=False,
input_signature=[
tf.TensorSpec([None, None, 3], tf.float64),
tf.TensorSpec([None, None], tf.int32),
tf.TensorSpec([None, None, model.get_nnei()], tf.int64),
tf.TensorSpec([None, None], tf.int64),
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64),
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
],
)
def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam):
return exported_whether_do_atomic_virial(do_atomic_virial=True)(
coord, atype, nlist, mapping, fparam, aparam
)

tf_model.call_lower_atomic_virial = call_lower_with_atomic_virial

# set functions to export other attributes
@tf.function
def get_type_map():
return tf.constant(model.get_type_map(), dtype=tf.string)

tf_model.get_type_map = get_type_map

@tf.function
def get_rcut():
return tf.constant(model.get_rcut(), dtype=tf.double)

tf_model.get_rcut = get_rcut

@tf.function
def get_dim_fparam():
return tf.constant(model.get_dim_fparam(), dtype=tf.int64)

tf_model.get_dim_fparam = get_dim_fparam

@tf.function
def get_dim_aparam():
return tf.constant(model.get_dim_aparam(), dtype=tf.int64)

tf_model.get_dim_aparam = get_dim_aparam

@tf.function
def get_sel_type():
return tf.constant(model.get_sel_type(), dtype=tf.int64)

tf_model.get_sel_type = get_sel_type

@tf.function
def is_aparam_nall():
return tf.constant(model.is_aparam_nall(), dtype=tf.bool)

tf_model.is_aparam_nall = is_aparam_nall

@tf.function
def model_output_type():
return tf.constant(model.model_output_type(), dtype=tf.string)

tf_model.model_output_type = model_output_type

@tf.function
def mixed_types():
return tf.constant(model.mixed_types(), dtype=tf.bool)

tf_model.mixed_types = mixed_types

if model.get_min_nbor_dist() is not None:

@tf.function
def get_min_nbor_dist():
return tf.constant(model.get_min_nbor_dist(), dtype=tf.double)

tf_model.get_min_nbor_dist = get_min_nbor_dist

@tf.function
def get_sel():
return tf.constant(model.get_sel(), dtype=tf.int64)

tf_model.get_sel = get_sel

@tf.function
def get_model_def_script():
return tf.constant(
json.dumps(model_def_script, separators=(",", ":")), dtype=tf.string
)

tf_model.get_model_def_script = get_model_def_script
tf.saved_model.save(
tf_model,
model_file,
options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
Loading

0 comments on commit 38815b3

Please sign in to comment.