-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'origin/main' into tmqm_dataset
- Loading branch information
Showing
5 changed files
with
98 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
""" | ||
Script to modify a state dict to include only_unique_pairs dictionary key. | ||
This is only necessary for models trained prior to PR #299 in modelforge, that provides | ||
integration with OpenMM and some refactoring of the neighborlisting schemes. | ||
""" | ||
|
||
|
||
def modify_state_dict( | ||
state_dict_input_file_path: str, | ||
state_dict_output_file_path: str, | ||
only_unique_pairs: bool, | ||
): | ||
""" | ||
Modify a state dict to include the only_unique_pairs dictionary key. | ||
Parameters | ||
---------- | ||
state_dict_input_file_path: str | ||
Input file with path to the input state dict file | ||
state_dict_output_file_path: str | ||
Output file with path to the output state dict file | ||
only_unique_pairs: bool | ||
Boolean value to set the only_unique_pairs key for the neighborlist | ||
This value should be True for the ANI models, False for most other models. | ||
Returns | ||
------- | ||
""" | ||
import torch | ||
|
||
# Load the state dict | ||
state_dict = torch.load(state_dict_input_file_path) | ||
|
||
# Set the only_unique_pairs key | ||
state_dict["neighborlist.only_unique_pairs"] = torch.Tensor([only_unique_pairs]) | ||
|
||
# Save the modified state dict | ||
torch.save(state_dict, state_dict_output_file_path) |