Skip to content

Commit

Permalink
done
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Oct 21, 2024
1 parent 71a4b55 commit 5024f70
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 41 deletions.
8 changes: 6 additions & 2 deletions deepmd/backend/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class JAXBackend(Backend):
# | Backend.Feature.NEIGHBOR_STAT
)
"""The features of the backend."""
suffixes: ClassVar[list[str]] = [".saved_model", ".jax"]
suffixes: ClassVar[list[str]] = [".jax"]
"""The suffixes of the backend."""

def is_available(self) -> bool:
Expand Down Expand Up @@ -93,7 +93,11 @@ def serialize_hook(self) -> Callable[[str], dict]:
Callable[[str], dict]
The serialize hook of the backend.
"""
raise NotImplementedError
from deepmd.jax.utils.serialization import (
serialize_from_file,
)

return serialize_from_file

@property
def deserialize_hook(self) -> Callable[[str, dict], None]:
Expand Down
97 changes: 60 additions & 37 deletions deepmd/jax/utils/serialization.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import tensorflow as tf
from pathlib import (
Path,
)

import orbax.checkpoint as ocp

from deepmd.jax.env import (
jax2tf,
jax,
nnx,
)
from deepmd.jax.model.model import (
BaseModel,
get_model,
)
from deepmd.jax.utils.network import (
ArrayAPIParam,
Expand All @@ -23,41 +28,59 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
data : dict
The dictionary to be deserialized.
"""
if model_file.endswith(".saved_model"):
model = BaseModel.deserialize(data["model"])
model_def_script = data.get("model_def_script", "{}")
my_model = tf.Module()
my_model.f = tf.function(
jax2tf.convert(
model,
polymorphic_shapes=[
"(b, n, 3)",
"(b, n)",
"(b, 3, 3)",
"(b, f)",
"(b, a)",
"()",
],
),
autograph=False,
input_signature=[
tf.TensorSpec([None, None, 3], tf.float64),
tf.TensorSpec([None, None], tf.int64),
tf.TensorSpec([None, 3, 3], tf.float64),
tf.TensorSpec([None, None], tf.float64),
tf.TensorSpec([None, None], tf.float64),
tf.TensorSpec([], tf.bool),
],
)
my_model.model_def_script = model_def_script
tf.saved_model.save(
my_model,
model_file,
options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
elif model_file.endswith(".jax"):
if model_file.endswith(".jax"):
model = BaseModel.deserialize(data["model"])
model_def_script = data["model_def_script"]
state = nnx.state(model, ArrayAPIParam)
nnx.display(state)
with ocp.Checkpointer(
ocp.CompositeCheckpointHandler("state", "model_def_script")
) as checkpointer:
checkpointer.save(
Path(model_file).absolute(),
ocp.args.Composite(
state=ocp.args.StandardSave(state),
model_def_script=ocp.args.JsonSave(model_def_script),
),
)
else:
raise ValueError("JAX backend only supports converting .jax directory")


def serialize_from_file(model_file: str) -> dict:
    """Serialize the model file to a dictionary.

    Parameters
    ----------
    model_file : str
        The model file to be serialized. Its name must end with ``.jax``.

    Returns
    -------
    dict
        The serialized model data, with the keys ``backend``,
        ``jax_version``, ``model``, ``model_def_script`` and ``@variables``.

    Raises
    ------
    ValueError
        If ``model_file`` does not end with ``.jax``.
    """
    # Reject unsupported suffixes up front; the message matches the one
    # raised by deserialize_to_file for the same condition.
    if not model_file.endswith(".jax"):
        raise ValueError("JAX backend only supports converting .jax directory")

    with ocp.Checkpointer(
        ocp.CompositeCheckpointHandler("state", "model_def_script")
    ) as checkpointer:
        restored = checkpointer.restore(
            Path(model_file).absolute(),
            ocp.args.Composite(
                state=ocp.args.StandardRestore(),
                model_def_script=ocp.args.JsonRestore(),
            ),
        )

    # NOTE(review): the checkpoint's ``state`` item is restored above but is
    # never applied to the model built below, so ``model.serialize()`` only
    # reflects whatever ``get_model`` produces from the definition script.
    # Confirm whether the restored state should be merged into the model.
    model_def_script = restored.model_def_script
    model = get_model(model_def_script)
    return {
        "backend": "JAX",
        "jax_version": jax.__version__,
        "model": model.serialize(),
        "model_def_script": model_def_script,
        "@variables": {},
    }
8 changes: 6 additions & 2 deletions source/tests/consistent/io/test_io.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import shutil
import unittest
from pathlib import (
Path,
Expand Down Expand Up @@ -60,12 +61,14 @@ def save_data_to_model(self, model_file: str, data: dict) -> None:
def tearDown(self):
    """Remove the artifacts this test case left in the working directory.

    Every entry in the current directory named
    ``test_consistent_io_<lowercased class name>.*`` is deleted: regular
    files are unlinked and directories are removed recursively.
    """
    prefix = "test_consistent_io_" + self.__class__.__name__.lower()
    # Path.glob already yields Path objects, so no re-wrapping is needed.
    for path in Path(".").glob(prefix + ".*"):
        if path.is_file():
            path.unlink()
        elif path.is_dir():
            # unlink() cannot remove a directory; delete it recursively.
            shutil.rmtree(path)

def test_data_equal(self):
prefix = "test_consistent_io_" + self.__class__.__name__.lower()
for backend_name in ("tensorflow", "pytorch", "dpmodel"):
for backend_name in ("tensorflow", "pytorch", "dpmodel", "jax"):
with self.subTest(backend_name=backend_name):
backend = Backend.get_backend(backend_name)()
if not backend.is_available:
Expand All @@ -80,6 +83,7 @@ def test_data_equal(self):
"backend",
"tf_version",
"pt_version",
"jax_version",
"@variables",
# dpmodel only
"software",
Expand Down

0 comments on commit 5024f70

Please sign in to comment.