diff --git a/deepmd/backend/jax.py b/deepmd/backend/jax.py
index 6ffea0aacd..bb2fba5a7c 100644
--- a/deepmd/backend/jax.py
+++ b/deepmd/backend/jax.py
@@ -38,7 +38,7 @@ class JAXBackend(Backend):
         # | Backend.Feature.NEIGHBOR_STAT
     )
     """The features of the backend."""
-    suffixes: ClassVar[list[str]] = [".saved_model", ".jax"]
+    suffixes: ClassVar[list[str]] = [".jax"]
     """The suffixes of the backend."""
 
     def is_available(self) -> bool:
@@ -93,7 +93,11 @@ def serialize_hook(self) -> Callable[[str], dict]:
         Callable[[str], dict]
             The serialize hook of the backend.
         """
-        raise NotImplementedError
+        from deepmd.jax.utils.serialization import (
+            serialize_from_file,
+        )
+
+        return serialize_from_file
 
     @property
     def deserialize_hook(self) -> Callable[[str, dict], None]:
diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py
index ca915d61e3..aa41e35f69 100644
--- a/deepmd/jax/utils/serialization.py
+++ b/deepmd/jax/utils/serialization.py
@@ -1,12 +1,17 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
-import tensorflow as tf
+from pathlib import (
+    Path,
+)
+
+import orbax.checkpoint as ocp
 
 from deepmd.jax.env import (
-    jax2tf,
+    jax,
     nnx,
 )
 from deepmd.jax.model.model import (
     BaseModel,
+    get_model,
 )
 from deepmd.jax.utils.network import (
     ArrayAPIParam,
@@ -23,41 +28,75 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
     data : dict
         The dictionary to be deserialized.
     """
-    if model_file.endswith(".saved_model"):
-        model = BaseModel.deserialize(data["model"])
-        model_def_script = data.get("model_def_script", "{}")
-        my_model = tf.Module()
-        my_model.f = tf.function(
-            jax2tf.convert(
-                model,
-                polymorphic_shapes=[
-                    "(b, n, 3)",
-                    "(b, n)",
-                    "(b, 3, 3)",
-                    "(b, f)",
-                    "(b, a)",
-                    "()",
-                ],
-            ),
-            autograph=False,
-            input_signature=[
-                tf.TensorSpec([None, None, 3], tf.float64),
-                tf.TensorSpec([None, None], tf.int64),
-                tf.TensorSpec([None, 3, 3], tf.float64),
-                tf.TensorSpec([None, None], tf.float64),
-                tf.TensorSpec([None, None], tf.float64),
-                tf.TensorSpec([], tf.bool),
-            ],
-        )
-        my_model.model_def_script = model_def_script
-        tf.saved_model.save(
-            my_model,
-            model_file,
-            options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
-        )
-    elif model_file.endswith(".jax"):
+    if model_file.endswith(".jax"):
         model = BaseModel.deserialize(data["model"])
+        model_def_script = data["model_def_script"]
         state = nnx.state(model, ArrayAPIParam)
-        nnx.display(state)
+        with ocp.Checkpointer(
+            ocp.CompositeCheckpointHandler("state", "model_def_script")
+        ) as checkpointer:
+            checkpointer.save(
+                Path(model_file).absolute(),
+                ocp.args.Composite(
+                    state=ocp.args.StandardSave(state),
+                    model_def_script=ocp.args.JsonSave(model_def_script),
+                ),
+            )
+    else:
+        raise ValueError("JAX backend only supports converting .jax directory")
+
+
+def serialize_from_file(model_file: str) -> dict:
+    """Serialize the model file to a dictionary.
+
+    Parameters
+    ----------
+    model_file : str
+        The model file to be serialized.
+
+    Returns
+    -------
+    dict
+        The serialized model data.
+
+    Raises
+    ------
+    ValueError
+        If ``model_file`` does not end with ``.jax``.
+    """
+    if model_file.endswith(".jax"):
+        model_path = Path(model_file).absolute()
+        with ocp.Checkpointer(
+            ocp.CompositeCheckpointHandler("state", "model_def_script")
+        ) as checkpointer:
+            # Read the model definition first: it is needed to build the model
+            # whose parameter tree is the restore target for the saved state.
+            model_def_script = checkpointer.restore(
+                model_path,
+                ocp.args.Composite(
+                    model_def_script=ocp.args.JsonRestore(),
+                ),
+            ).model_def_script
+            model = get_model(model_def_script)
+            restored = checkpointer.restore(
+                model_path,
+                ocp.args.Composite(
+                    state=ocp.args.StandardRestore(
+                        nnx.state(model, ArrayAPIParam)
+                    ),
+                ),
+            )
+        # Load the checkpointed parameters into the freshly built model;
+        # otherwise the initial weights, not the saved ones, get serialized.
+        nnx.update(model, restored.state)
+        model_dict = model.serialize()
+        data = {
+            "backend": "JAX",
+            "jax_version": jax.__version__,
+            "model": model_dict,
+            "model_def_script": model_def_script,
+            "@variables": {},
+        }
+        return data
     else:
-        raise ValueError("JAX backend only supports converting .pth file")
+        raise ValueError("JAX backend only supports converting .jax directory")
diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py
index 71e4002128..0aaa0788ea 100644
--- a/source/tests/consistent/io/test_io.py
+++ b/source/tests/consistent/io/test_io.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 import copy
+import shutil
 import unittest
 from pathlib import (
     Path,
@@ -60,12 +61,15 @@ def save_data_to_model(self, model_file: str, data: dict) -> None:
     def tearDown(self):
         prefix = "test_consistent_io_" + self.__class__.__name__.lower()
         for ii in Path(".").glob(prefix + ".*"):
-            if Path(ii).exists():
+            if Path(ii).is_file():
                 Path(ii).unlink()
+            elif Path(ii).is_dir():
+                shutil.rmtree(ii)
 
     def test_data_equal(self):
         prefix = "test_consistent_io_" + self.__class__.__name__.lower()
-        for backend_name in ("tensorflow", "pytorch", "dpmodel"):
+        for backend_name in ("tensorflow", "pytorch", "dpmodel", "jax"):
             with self.subTest(backend_name=backend_name):
                 backend = Backend.get_backend(backend_name)()
-                if not backend.is_available:
+                # is_available is a method; the bare attribute is always truthy.
+                if not backend.is_available():
@@ -80,6 +84,7 @@ def test_data_equal(self):
                     "backend",
                     "tf_version",
                     "pt_version",
+                    "jax_version",
                     "@variables",
                     # dpmodel only
                     "software",