Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into tmqm_dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisiacovella committed Nov 14, 2024
2 parents 08a3ccb + 80176df commit 2a15624
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 8 deletions.
27 changes: 27 additions & 0 deletions docs/inference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,33 @@ To use the trained model for inference, the checkpoint file generated during tra
:maxdepth: 2
:caption: Contents:

Note that checkpoint files and state_dicts created before PR #299 was merged do not contain the `only_unique_pairs` boolean parameter, which is needed to properly generate neighbor information. If you are using a checkpoint file generated prior to this PR, you will therefore need to set this parameter manually by passing `only_unique_pairs` to the `load_inference_model_from_checkpoint` function. For example, for ANI2x models this should be True (the other currently implemented potentials require False):

.. code-block:: python

    from modelforge.potential.potential import load_inference_model_from_checkpoint

    inference_model = load_inference_model_from_checkpoint(checkpoint_file, only_unique_pairs=True)

State dictionary files can be updated with the `modify_state_dict` function in the file `modify_state_dict.py` in the scripts directory. This generates a new copy of the state dictionary file with the appropriate `only_unique_pairs` parameter set.

Loading a checkpoint from weights and biases
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Checkpoint files can be loaded directly from wandb using the `load_from_wandb` method of the `NeuralNetworkPotentialFactory`, by passing the wandb run path and the appropriate version number. Note that this requires authenticating with wandb as a user who is part of the project. The following code snippet demonstrates how to load a model from wandb.

.. code-block:: python

    from modelforge.potential.potential import NeuralNetworkPotentialFactory

    nn_potential = NeuralNetworkPotentialFactory().load_from_wandb(
        run_path="modelforge_nnps/test_ANI2x_on_dataset/model-qloqn6gk",
        version="v0",
        local_cache_dir=f"{prep_temp_dir}/test_wandb",
    )

Using a model for inference in OpenMM
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
27 changes: 23 additions & 4 deletions modelforge/potential/potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,8 @@ def load_state_dict(
assign : bool, optional
Whether to assign the state dictionary to the model directly
(default is False).
legacy : bool, optional
Earlier version of the potential model did not include only_unique_pairs in the
Notes
-----
This function can remove a specific prefix from the keys in the state
Expand Down Expand Up @@ -688,7 +689,11 @@ def generate_potential(

@staticmethod
def load_from_wandb(
*, run_path: str, version: str, local_cache_dir: str = "./"
*,
run_path: str,
version: str,
local_cache_dir: str = "./",
only_unique_pairs: Optional[bool] = None,
) -> Union[Potential, JAXModel]:
"""
Load a neural network potential from a Weights & Biases run.
Expand All @@ -700,7 +705,10 @@ def load_from_wandb(
version : str
The version of the run to load.
local_cache_dir : str, optional
The local cache directory for downloading the model (default is "./").
The local cache directory for downloading the model (default is "./").
only_unique_pairs : Optional[bool], optional
For models trained prior to PR #299 in modelforge, this parameter is required to be able to read the model.
This value should be True for the ANI models, False for most other models.
Returns
-------
Expand All @@ -714,7 +722,9 @@ def load_from_wandb(
artifact = run.use_artifact(artifact_path)
artifact_dir = artifact.download(root=local_cache_dir)
checkpoint_file = f"{artifact_dir}/model.ckpt"
potential = load_inference_model_from_checkpoint(checkpoint_file)
potential = load_inference_model_from_checkpoint(
checkpoint_file, only_unique_pairs
)

return potential

Expand Down Expand Up @@ -892,6 +902,7 @@ def apply_bwd(res, grads):

def load_inference_model_from_checkpoint(
checkpoint_path: str,
only_unique_pairs: Optional[bool] = None,
) -> Union[Potential, JAXModel]:
"""
Creates an inference model from a checkpoint file.
Expand All @@ -901,6 +912,10 @@ def load_inference_model_from_checkpoint(
----------
checkpoint_path : str
The path to the checkpoint file.
only_unique_pairs : Optional[bool], optional
If defined, this will set the only_unique_pairs key in the neighborlist module. This is only needed
for models trained prior to PR #299 in modelforge. (default is None).
In the case of ANI models, this should be set to True. Typically False for other models.
"""

# Load the checkpoint
Expand All @@ -918,6 +933,10 @@ def load_inference_model_from_checkpoint(
dataset_statistic=dataset_statistic,
potential_seed=potential_seed,
)
if only_unique_pairs is not None:
checkpoint["state_dict"]["neighborlist.only_unique_pairs"] = torch.Tensor(
[only_unique_pairs]
)

# Load the state dict into the model
potential.load_state_dict(checkpoint["state_dict"])
Expand Down
9 changes: 5 additions & 4 deletions modelforge/tests/test_potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,9 +348,9 @@ def test_state_dict_saving_and_loading(potential_name, prep_temp_dir):
potential.load_state_dict(torch.load(file_path))


@pytest.mark.xfail(
reason="checkpoint file needs to be updated now that non_unique_pairs is registered in nlist"
)
# @pytest.mark.xfail(
# reason="checkpoint file needs to be updated now that non_unique_pairs is registered in nlist"
# )
def test_loading_from_checkpoint_file():
from importlib import resources
from modelforge.tests import data
Expand All @@ -361,7 +361,8 @@ def test_loading_from_checkpoint_file():

from modelforge.potential.potential import load_inference_model_from_checkpoint

potential = load_inference_model_from_checkpoint(ckpt_file)
# note this is a legacy file, and thus we need to manually define only_unique_pairs
potential = load_inference_model_from_checkpoint(ckpt_file, only_unique_pairs=False)
assert potential is not None


Expand Down
1 change: 1 addition & 0 deletions modelforge/tests/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def test_load_from_wandb(prep_temp_dir):
run_path="modelforge_nnps/test_ANI2x_on_dataset/model-qloqn6gk",
version="v0",
local_cache_dir=f"{prep_temp_dir}/test_wandb",
only_unique_pairs=True,
)

assert os.path.isfile(f"{prep_temp_dir}/test_wandb/model.ckpt")
Expand Down
42 changes: 42 additions & 0 deletions scripts/modify_state_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""
Script to modify a state dict to include the only_unique_pairs dictionary key.
This is only necessary for models trained prior to PR #299 in modelforge, which provides
integration with OpenMM and some refactoring of the neighborlisting schemes.
"""


def modify_state_dict(
    state_dict_input_file_path: str,
    state_dict_output_file_path: str,
    only_unique_pairs: bool,
    map_location=None,
):
    """
    Modify a state dict to include the only_unique_pairs dictionary key.

    The state dict is read from ``state_dict_input_file_path``, the
    ``neighborlist.only_unique_pairs`` entry is set (overwriting any existing
    entry with that key), and the result is written to
    ``state_dict_output_file_path``. The input file is only rewritten if both
    paths are the same.

    Parameters
    ----------
    state_dict_input_file_path: str
        Path to the input state dict file.
    state_dict_output_file_path: str
        Path the modified state dict file is written to.
    only_unique_pairs: bool
        Boolean value to set the only_unique_pairs key for the neighborlist.
        This value should be True for the ANI models, False for most other models.
    map_location: optional
        Forwarded unchanged to ``torch.load``. The default (None) keeps
        ``torch.load``'s standard device handling; pass e.g. ``"cpu"`` to
        convert a state dict that was saved on a GPU on a machine without one.

    Returns
    -------
    None
    """
    import torch

    # Load the state dict; with map_location=None this is the plain torch.load call.
    state_dict = torch.load(state_dict_input_file_path, map_location=map_location)

    # Set the only_unique_pairs key as a one-element tensor.
    state_dict["neighborlist.only_unique_pairs"] = torch.Tensor([only_unique_pairs])

    # Save the modified state dict
    torch.save(state_dict, state_dict_output_file_path)

0 comments on commit 2a15624

Please sign in to comment.