diff --git a/docs/inference.rst b/docs/inference.rst index 3bd308e1..67675718 100644 --- a/docs/inference.rst +++ b/docs/inference.rst @@ -44,6 +44,33 @@ To use the trained model for inference, the checkpoint file generated during tra :maxdepth: 2 :caption: Contents: +Note, prior to merging PR #299, checkpoint files and state_dicts did not save the `only_unique_pairs` bool parameter, needed to properly generate neighbor information. As such, if you are using a checkpoint file generated prior to this PR, you will need to set this parameter manually. This can be done by passing the `only_unique_pairs` parameter to the `load_inference_model_from_checkpoint` function. For example, for ANI2x models, where this should be True (other currently implemented potentials require False): + +.. code-block:: python + + from modelforge.potential.models import load_inference_model_from_checkpoint + + inference_model = load_inference_model_from_checkpoint(checkpoint_file, only_unique_pairs=True) + + +To modify state dictionary files, this can be done easily via the `modify_state_dict` function in the file `modify_state_dict.py` in the scripts directory. This will generate a new copy of the state dictionary file with the appropriate `only_unique_pairs` parameter set. + +Loading a checkpoint from weights and biases +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Checkpoint files can be loaded directly from wandb using the `load_from_wandb` function as part of the `NeuralNetworkPotentialFactory`. This can be done by passing the wandb run id and appropriate version number. Note this will require authentication with wandb for users part of the project. The following code snippet demonstrates how to load a model from wandb. + +.. code-block:: python + + from modelforge.potential.potential import NeuralNetworkPotentialFactory + + nn_potential = NeuralNetworkPotentialFactory().load_from_wandb( + run_path="modelforge_nnps/test_ANI2x_on_dataset/model-qloqn6gk", + version="v0", + local_cache_dir=f"{prep_temp_dir}/test_wandb", + ) + + Using a model for inference in OpenMM ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/modelforge/potential/potential.py b/modelforge/potential/potential.py index eb0e979c..e002dec7 100644 --- a/modelforge/potential/potential.py +++ b/modelforge/potential/potential.py @@ -689,7 +689,11 @@ def generate_potential( @staticmethod def load_from_wandb( - *, run_path: str, version: str, local_cache_dir: str = "./" + *, + run_path: str, + version: str, + local_cache_dir: str = "./", + only_unique_pairs: Optional[bool] = None, ) -> Union[Potential, JAXModel]: """ Load a neural network potential from a Weights & Biases run. @@ -701,7 +705,10 @@ def load_from_wandb( version : str The version of the run to load. local_cache_dir : str, optional - The local cache directory for downloading the model (default is "./"). + The local cache directory for downloading the model (default is "./"), + only_unique_pairs : Optional[bool], optional + For models trained prior to PR #299 in modelforge, this parameter is required to be able to read the model. + This value should be True for the ANI models, False for most other models. Returns ------- @@ -715,7 +722,9 @@ def load_from_wandb( artifact = run.use_artifact(artifact_path) artifact_dir = artifact.download(root=local_cache_dir) checkpoint_file = f"{artifact_dir}/model.ckpt" - potential = load_inference_model_from_checkpoint(checkpoint_file) + potential = load_inference_model_from_checkpoint( + checkpoint_file, only_unique_pairs + ) return potential diff --git a/modelforge/tests/test_remote.py b/modelforge/tests/test_remote.py index 363c2b9c..e60b3bab 100644 --- a/modelforge/tests/test_remote.py +++ b/modelforge/tests/test_remote.py @@ -128,6 +128,7 @@ def test_load_from_wandb(prep_temp_dir): run_path="modelforge_nnps/test_ANI2x_on_dataset/model-qloqn6gk", version="v0", local_cache_dir=f"{prep_temp_dir}/test_wandb", + only_unique_pairs=True, ) assert os.path.isfile(f"{prep_temp_dir}/test_wandb/model.ckpt")