Skip to content

Commit

Permalink
Merge pull request #311 from chrisiacovella/fetch_from_wandb
Browse files Browse the repository at this point in the history
Load a checkpoint file from wandb directly into a potential. Enable ability to load "legacy" checkpoint files.
  • Loading branch information
chrisiacovella authored Nov 14, 2024
2 parents dbae94c + ee7e6cf commit 80176df
Show file tree
Hide file tree
Showing 9 changed files with 168 additions and 16 deletions.
1 change: 1 addition & 0 deletions .github/workflows/CI.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ jobs:
python -m pip install . --no-deps
micromamba list
- name: Run tests
# conda setup requires this special shell
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/black.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ jobs:
- uses: actions/checkout@v4
- uses: psf/black@stable
with:
options: "--check --verbose --line-length 88"
options: "--check --diff --verbose --line-length 88"
src: "./modelforge"
4 changes: 3 additions & 1 deletion devtools/conda-envs/test_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ dependencies:
- pydantic>=2
- ray-all
- graphviz
- wandb>=0.18.5
- pytorch

# Testing
- pytest>=2.1
Expand All @@ -39,6 +41,6 @@ dependencies:
- pytorch2jax
- git+https://github.com/ArnNag/sake.git@nanometer
- flax
- torch
# - torch
- pytest-xdist

2 changes: 1 addition & 1 deletion devtools/conda-envs/test_env_mac.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ dependencies:
- flax
- pydantic>=2.0
- graphviz
-
- wandb>=0.18.5

# Testing
- pytest>=2.1
Expand Down
27 changes: 27 additions & 0 deletions docs/inference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,33 @@ To use the trained model for inference, the checkpoint file generated during tra
:maxdepth: 2
:caption: Contents:

Note, prior to merging PR #299, checkpoint files and state_dicts did not save the `only_unique_pairs` bool parameter, needed to properly generate neighbor information. As such, if you are using a checkpoint file generated prior to this PR, you will need to set this parameter manually. This can be done by passing the `only_unique_pairs` parameter to the `load_inference_model_from_checkpoint` function. For example, for ANI2x models, where this should be True (other currently implemented potentials require False):

.. code-block:: python
from modelforge.potential.models import load_inference_model_from_checkpoint
inference_model = load_inference_model_from_checkpoint(checkpoint_file, only_unique_pairs=True)
To modify state dictionary files, this can be done easily via the `modify_state_dict` function in the file `modify_state_dict.py` in the scripts directory. This will generate a new copy of the state dictionary file with the appropriate `only_unique_pairs` parameter set.

Loading a checkpoint from weights and biases
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Checkpoint files can be loaded directly from wandb using the `load_from_wandb` function as part of the `NeuralNetworkPotentialFactory`. This can be done by passing the wandb run id and appropriate version number. Note this will require authentication with wandb for users part of the project. The following code snippet demonstrates how to load a model from wandb.

.. code-block:: python
from modelforge.potential.potential import NeuralNetworkPotentialFactory
nn_potential = NeuralNetworkPotentialFactory().load_from_wandb(
run_path="modelforge_nnps/test_ANI2x_on_dataset/model-qloqn6gk",
version="v0",
local_cache_dir=f"{prep_temp_dir}/test_wandb",
)
Using a model for inference in OpenMM
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
76 changes: 67 additions & 9 deletions modelforge/potential/potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,20 @@
This module contains the base classes for the neural network potentials.
"""

from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Tuple, TypeVar, Union
from typing import (
Any,
Dict,
Callable,
Literal,
TYPE_CHECKING,
List,
Mapping,
NamedTuple,
Optional,
Tuple,
TypeVar,
Union,
)

import lightning as pl
import torch
Expand Down Expand Up @@ -36,8 +49,6 @@
)


from typing import Callable, Literal, Optional, Union, TYPE_CHECKING

if TYPE_CHECKING:
from modelforge.train.training import PotentialTrainer

Expand Down Expand Up @@ -98,7 +109,6 @@ def __repr__(self):


class PostProcessing(torch.nn.Module):

_SUPPORTED_PROPERTIES = [
"per_atom_energy",
"per_atom_charge",
Expand Down Expand Up @@ -159,14 +169,14 @@ def __init__(
]
== "coulomb"
):

self.registered_chained_operations["electrostatic_potential"] = (
CoulombPotential(
postprocessing_parameter["electrostatic_potential"][
"maximum_interaction_radius"
],
)
)

self._registered_properties.append("electrostatic_potential")
assert all(
prop in PostProcessing._SUPPORTED_PROPERTIES
Expand Down Expand Up @@ -194,7 +204,6 @@ def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
processed_data: Dict[str, torch.Tensor] = {}
# Iterate over items in ModuleDict
for name, module in self.registered_chained_operations.items():

module_output = module.forward(data)
processed_data.update(module_output)

Expand Down Expand Up @@ -455,7 +464,8 @@ def load_state_dict(
assign : bool, optional
Whether to assign the state dictionary to the model directly
(default is False).
legacy : bool, optional
Earlier version of the potential model did not include only_unique_pairs in the
Notes
-----
This function can remove a specific prefix from the keys in the state
Expand Down Expand Up @@ -584,7 +594,6 @@ def setup_potential(


class NeuralNetworkPotentialFactory:

@staticmethod
def generate_potential(
*,
Expand Down Expand Up @@ -678,6 +687,47 @@ def generate_potential(
else:
return potential

@staticmethod
def load_from_wandb(
*,
run_path: str,
version: str,
local_cache_dir: str = "./",
only_unique_pairs: Optional[bool] = None,
) -> Union[Potential, JAXModel]:
"""
Load a neural network potential from a Weights & Biases run.
Parameters
----------
run_path : str
The path to the Weights & Biases run.
version : str
The version of the run to load.
local_cache_dir : str, optional
The local cache directory for downloading the model (default is "./"),
only_unique_pairs : Optional[bool], optional
For models trained prior to PR #299 in modelforge, this parameter is required to be able to read the model.
This value should be True for the ANI models, False for most other models.
Returns
-------
Union[Potential, JAXModel]
An instantiated neural network potential for training or inference.
"""
import wandb

run = wandb.init()
artifact_path = f"{run_path}:{version}"
artifact = run.use_artifact(artifact_path)
artifact_dir = artifact.download(root=local_cache_dir)
checkpoint_file = f"{artifact_dir}/model.ckpt"
potential = load_inference_model_from_checkpoint(
checkpoint_file, only_unique_pairs
)

return potential

@staticmethod
def generate_trainer(
*,
Expand Down Expand Up @@ -852,6 +902,7 @@ def apply_bwd(res, grads):

def load_inference_model_from_checkpoint(
checkpoint_path: str,
only_unique_pairs: Optional[bool] = None,
) -> Union[Potential, JAXModel]:
"""
Creates an inference model from a checkpoint file.
Expand All @@ -861,8 +912,11 @@ def load_inference_model_from_checkpoint(
----------
checkpoint_path : str
The path to the checkpoint file.
only_unique_pairs : Optional[bool], optional
If defined, this will set the only_unique_pairs key in the neighborlist module. This is only needed
for models trained prior to PR #299 in modelforge. (default is None).
In the case of ANI models, this should be set to True. Typically False for other mdoels
"""
import torch

# Load the checkpoint
checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
Expand All @@ -879,6 +933,10 @@ def load_inference_model_from_checkpoint(
dataset_statistic=dataset_statistic,
potential_seed=potential_seed,
)
if only_unique_pairs is not None:
checkpoint["state_dict"]["neighborlist.only_unique_pairs"] = torch.Tensor(
[only_unique_pairs]
)

# Load the state dict into the model
potential.load_state_dict(checkpoint["state_dict"])
Expand Down
9 changes: 5 additions & 4 deletions modelforge/tests/test_potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,9 +348,9 @@ def test_state_dict_saving_and_loading(potential_name, prep_temp_dir):
potential.load_state_dict(torch.load(file_path))


@pytest.mark.xfail(
reason="checkpoint file needs to be updated now that non_unique_pairs is registered in nlist"
)
# @pytest.mark.xfail(
# reason="checkpoint file needs to be updated now that non_unique_pairs is registered in nlist"
# )
def test_loading_from_checkpoint_file():
from importlib import resources
from modelforge.tests import data
Expand All @@ -361,7 +361,8 @@ def test_loading_from_checkpoint_file():

from modelforge.potential.potential import load_inference_model_from_checkpoint

potential = load_inference_model_from_checkpoint(ckpt_file)
# note this is a legacy file, and thus we need to manually define only_unique_pairs
potential = load_inference_model_from_checkpoint(ckpt_file, only_unique_pairs=False)
assert potential is not None


Expand Down
21 changes: 21 additions & 0 deletions modelforge/tests/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from modelforge.utils.remote import *

IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"


@pytest.fixture(scope="session")
def prep_temp_dir(tmp_path_factory):
Expand Down Expand Up @@ -113,3 +115,22 @@ def test_md5_calculation(prep_temp_dir):
output_filename=name,
force_download=True,
)


@pytest.mark.skipif(
IN_GITHUB_ACTIONS,
reason="Skipping; requires authentication which cannot be done via PR from fork ",
)
def test_load_from_wandb(prep_temp_dir):
from modelforge.potential.potential import NeuralNetworkPotentialFactory

nn_potential = NeuralNetworkPotentialFactory().load_from_wandb(
run_path="modelforge_nnps/test_ANI2x_on_dataset/model-qloqn6gk",
version="v0",
local_cache_dir=f"{prep_temp_dir}/test_wandb",
only_unique_pairs=True,
)

assert os.path.isfile(f"{prep_temp_dir}/test_wandb/model.ckpt")

assert nn_potential is not None
42 changes: 42 additions & 0 deletions scripts/modify_state_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""
Script to modify a state dict to include only_unique_pairs dictionary key.
This is only necessary for models trained prior to PR #299 in modelforge, that provides
integration with OpenMM and some refactoring of the neighborlisting schemes.
"""


def modify_state_dict(
state_dict_input_file_path: str,
state_dict_output_file_path: str,
only_unique_pairs: bool,
):
"""
Modify a state dict to include the only_unique_pairs dictionary key.
Parameters
----------
state_dict_input_file_path: str
Input file with path to the input state dict file
state_dict_output_file_path: str
Output file with path to the output state dict file
only_unique_pairs: bool
Boolean value to set the only_unique_pairs key for the neighborlist
This value should be True for the ANI models, False for most other models.
Returns
-------
"""
import torch

# Load the state dict
state_dict = torch.load(state_dict_input_file_path)

# Set the only_unique_pairs key
state_dict["neighborlist.only_unique_pairs"] = torch.Tensor([only_unique_pairs])

# Save the modified state dict
torch.save(state_dict, state_dict_output_file_path)

0 comments on commit 80176df

Please sign in to comment.