Skip to content

Commit

Permalink
update docs. Add option for "only unique pairs" to loading from wandb…
Browse files Browse the repository at this point in the history
… (now that openmm integration is merged).
  • Loading branch information
chrisiacovella committed Nov 14, 2024
1 parent 9f7b8e3 commit ee7e6cf
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 3 deletions.
27 changes: 27 additions & 0 deletions docs/inference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,33 @@ To use the trained model for inference, the checkpoint file generated during tra
:maxdepth: 2
:caption: Contents:

Note, prior to merging PR #299, checkpoint files and state_dicts did not save the `only_unique_pairs` bool parameter, needed to properly generate neighbor information. As such, if you are using a checkpoint file generated prior to this PR, you will need to set this parameter manually. This can be done by passing the `only_unique_pairs` parameter to the `load_inference_model_from_checkpoint` function. For example, for ANI2x models, where this should be True (other currently implemented potentials require False):

.. code-block:: python
from modelforge.potential.models import load_inference_model_from_checkpoint
inference_model = load_inference_model_from_checkpoint(checkpoint_file, only_unique_pairs=True)
To modify state dictionary files, this can be done easily via the `modify_state_dict` function in the file `modify_state_dict.py` in the scripts directory. This will generate a new copy of the state dictionary file with the appropriate `only_unique_pairs` parameter set.

Loading a checkpoint from weights and biases
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Checkpoint files can be loaded directly from wandb using the `load_from_wandb` function as part of the `NeuralNetworkPotentialFactory`. This can be done by passing the wandb run id and appropriate version number. Note this will require authentication with wandb for users part of the project. The following code snippet demonstrates how to load a model from wandb.

.. code-block:: python
from modelforge.potential.potential import NeuralNetworkPotentialFactory
nn_potential = NeuralNetworkPotentialFactory().load_from_wandb(
run_path="modelforge_nnps/test_ANI2x_on_dataset/model-qloqn6gk",
version="v0",
local_cache_dir=f"{prep_temp_dir}/test_wandb",
)
Using a model for inference in OpenMM
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
15 changes: 12 additions & 3 deletions modelforge/potential/potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,11 @@ def generate_potential(

@staticmethod
def load_from_wandb(
*, run_path: str, version: str, local_cache_dir: str = "./"
*,
run_path: str,
version: str,
local_cache_dir: str = "./",
only_unique_pairs: Optional[bool] = None,
) -> Union[Potential, JAXModel]:
"""
Load a neural network potential from a Weights & Biases run.
Expand All @@ -701,7 +705,10 @@ def load_from_wandb(
version : str
The version of the run to load.
local_cache_dir : str, optional
The local cache directory for downloading the model (default is "./").
The local cache directory for downloading the model (default is "./"),
only_unique_pairs : Optional[bool], optional
For models trained prior to PR #299 in modelforge, this parameter is required to be able to read the model.
This value should be True for the ANI models, False for most other models.
Returns
-------
Expand All @@ -715,7 +722,9 @@ def load_from_wandb(
artifact = run.use_artifact(artifact_path)
artifact_dir = artifact.download(root=local_cache_dir)
checkpoint_file = f"{artifact_dir}/model.ckpt"
potential = load_inference_model_from_checkpoint(checkpoint_file)
potential = load_inference_model_from_checkpoint(
checkpoint_file, only_unique_pairs
)

return potential

Expand Down
1 change: 1 addition & 0 deletions modelforge/tests/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def test_load_from_wandb(prep_temp_dir):
run_path="modelforge_nnps/test_ANI2x_on_dataset/model-qloqn6gk",
version="v0",
local_cache_dir=f"{prep_temp_dir}/test_wandb",
only_unique_pairs=True,
)

assert os.path.isfile(f"{prep_temp_dir}/test_wandb/model.ckpt")
Expand Down

0 comments on commit ee7e6cf

Please sign in to comment.