forked from deepchem/deepchem
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Weighted Directed Graph Validator Setup (deepchem#4020)
* added a file * created test file * wdgdata string validator * fixed linting * fixing typo errors * adding documentation to validator for better clarity * added relevant documentation for validator * added citation and detail of research origin * made the commented changes and replied where required * added citation and references and finalized the changes
- Loading branch information
Showing
4 changed files
with
340 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,234 @@ | ||
from typing import List, Tuple | ||
import re | ||
|
||
|
||
class PolyWDGStringValidator():
    """
    Validator for the string format of weighted directed graph polymer data.

    The datapoint string stores weighted directed polymer data in a
    parsable form. The overall layout is:

    [monomer1].[monomer2]|[fraction_of_monomer1]|[fraction_of_monomer2]|<[polymer_rule1]<[polymer_rule2]

    Each polymer rule has its own layout:

    [[atom_index1]-[atom_index2]]:[fraction_of_bond_between_atom1_to_atom2]:[fraction_of_bond_between_atom2_to_atom1]

    This format is used as input for Weighted Directed Message Passing
    Neural Networks (wD-MPNN) [1]_. Monomers are written in a SMILES-like
    notation in which attachment points are numbered wildcards such as
    ``[1*]``.

    ``validate`` checks that:

    * the number of fraction fields equals the number of monomers,
    * the numbered wildcards in the monomer string form the contiguous
      sequence ``1..max``,
    * every polymer rule has three ``:`` separated fields, and its first
      field is ``<index>-<index>`` where both indexes are wildcard
      numbers that occur in the monomer string.

    The numeric values of the fractions and of the rule weights are not
    checked by this class.

    Example
    -------
    >>> from deepchem.utils import PolyWDGStringValidator
    >>> validator = PolyWDGStringValidator()
    >>> validator.validate("[1*]C.C[2*]|0.5|0.5|<1-2:0.5:0.5")
    True

    References
    ----------
    .. [1] Aldeghi, Matteo, and Connor W. Coley. "A graph representation of molecular
        ensembles for polymer property prediction." Chemical Science 13.35 (2022): 10486-10498.
    """

    @staticmethod
    def get_parsed_vals(datapoint: str) -> Tuple[str, list, str]:
        """
        Split a datapoint string on ``|`` into its three logical parts.

        Parameters
        ----------
        datapoint : str
            The datapoint string to parse

        Returns
        -------
        Tuple[str, list, str]
            ``(monomer_mols, fragments, polymer_rules)`` where
            ``monomer_mols`` is the first field, ``polymer_rules`` is the
            last field and ``fragments`` is the list of every field in
            between (empty when there are exactly three fields... minus
            the two outer ones, i.e. one entry per middle field).

        Raises
        ------
        ValueError
            If the string has fewer than 3 ``|`` separated fields
        """
        base_parsed = datapoint.split("|")
        if len(base_parsed) < 3:
            raise ValueError(
                f"Invalid datapoint format: At least 3 splits should be there but found {len(base_parsed)} no. of splits"
            )
        monomer_mols = base_parsed[0]
        polymer_rules = base_parsed[-1]
        fragments = base_parsed[1:-1]
        return monomer_mols, fragments, polymer_rules

    @staticmethod
    def get_polymer_rules(rules_str: str) -> List[str]:
        """
        Split the polymer rules string into the individual rule strings.

        Parameters
        ----------
        rules_str : str
            The polymer rules string to parse

        Returns
        -------
        List[str]
            The rule strings that follow each ``<`` separator. Any text
            before the first ``<`` is dropped.

        Raises
        ------
        ValueError
            If the string contains no ``<`` separator
        """
        rule_parts = rules_str.split("<")
        if len(rule_parts) == 1:
            raise ValueError(
                "Invalid rules string: The rule string must contain '<' as a separator for rules !"
            )
        return rule_parts[1:]

    def _validate_fragments(self, datapoint: str):
        """
        Check that there is one fragment (fraction field) per monomer.

        Monomers are counted by splitting the monomer field on ``.``.

        Parameters
        ----------
        datapoint : str
            The datapoint string to validate

        Raises
        ------
        ValueError
            If the number of fragments does not match the number of monomers
        """
        monomer_mols, fragments, _ = self.get_parsed_vals(datapoint)
        monomer_count = len(monomer_mols.split("."))
        if len(fragments) != monomer_count:
            raise ValueError(
                f"Number of fragments and number of monomers should match. Mismatch -> No. of Fragments {len(fragments)} , No. of Monomers{monomer_count}"
            )

    def _get_all_wildcards(self, text: str) -> List[str]:
        """
        Return the wildcard numbers found in ``text``.

        A wildcard number is a run of digits that is immediately followed
        by ``*`` (for example the ``1`` in ``[1*]``). The ``*`` itself is
        not part of the returned strings.

        Parameters
        ----------
        text : str
            The string to search

        Returns
        -------
        List[str]
            The digit strings of all wildcards, in order of appearance
        """
        matches = re.findall(r"\d+(?=\*)", text)
        return matches

    def _validate_wildcards(self, datapoint: str):
        """
        Check that the monomer wildcards are numbered ``1..max`` without gaps.

        Parameters
        ----------
        datapoint : str
            The datapoint string to validate

        Raises
        ------
        ValueError
            If the monomer string has no numbered wildcard at all, or if
            any number between 1 and the largest wildcard number is
            missing.
        """
        monomer_mols, _, _ = self.get_parsed_vals(datapoint)
        # Compare parsed numbers rather than substrings, so that e.g. "2*"
        # inside "[12*]" is not mistaken for wildcard 2.
        wildcards = {int(x) for x in self._get_all_wildcards(monomer_mols)}
        if not wildcards:
            raise ValueError(
                "Invalid wildcard format: No numbered wildcard (e.g. '[1*]') found in the monomer molecules string!"
            )
        max_wildcard = max(wildcards)
        for wildcard in range(1, max_wildcard + 1):
            if wildcard not in wildcards:
                raise ValueError(
                    f"Invalid wildcard format: The wildcard {wildcard} is not present in the monomer molecules string as per the sequence of the maximum {max_wildcard}!"
                )

    def _validate_polymer_rules(self, datapoint: str):
        """
        Check the structure of every polymer rule in the datapoint.

        For each rule this verifies that it has exactly three ``:``
        separated fields, that the first field consists of exactly two
        non-empty parts joined by ``-``, that both parts are digits, and
        that both refer to a wildcard number that occurs in the monomer
        string. The two weight fields are not inspected.

        Parameters
        ----------
        datapoint : str
            The datapoint string to validate

        Raises
        ------
        ValueError
            If any polymer rule is malformed; the message names the
            failing check.
        """
        monomer_mols, _, polymer_rules = self.get_parsed_vals(datapoint)
        polymer_rule_list = self.get_polymer_rules(polymer_rules)
        # Only real wildcard numbers are valid bond end points; other
        # digits in the monomer string (e.g. ring closures) are not.
        valid_indexes = {
            int(x) for x in self._get_all_wildcards(monomer_mols)
        }
        for rules in polymer_rule_list:
            splits = rules.split(":")
            if len(splits) != 3:
                raise ValueError(
                    f"Invalid polymer rule format: The rule must contain exactly 3 splits ! but found {len(splits)} splits"
                )
            if "-" not in splits[0]:
                raise ValueError(
                    f"Invalid polymer rule format: The bond string between two wildcard index must be seprated by '-', got invalid data {splits[0]}"
                )
            bond_indexes = splits[0].split("-")
            # Either condition alone makes the bond invalid: a wrong
            # number of parts ("1-2-3") or an empty part ("1-").
            if len(bond_indexes) != 2 or "" in bond_indexes:
                raise ValueError(
                    f"Invalid polymer rule format: The first split must contain exactly 2 splits to depict connection between atom indexes! but found {len(bond_indexes)} splits"
                )
            for wild_card_index in bond_indexes:
                # isdecimal() guarantees the int() conversion below works.
                if not wild_card_index.isdecimal():
                    raise ValueError(
                        f"Invalid polymer rule format: The first split must contain only digits! but found {wild_card_index}"
                    )
                if int(wild_card_index) not in valid_indexes:
                    raise ValueError(
                        f"Invalid polymer rule format: The first split must contain only valid wild card indexes! but found {wild_card_index} which is not in {monomer_mols}"
                    )

    def validate(self, datapoint: str):
        """
        Validate the string format of weighted directed graph data.

        The following checks run in order, and the first failure raises:

        1. The number of fragments and the number of monomer
           molecules should match.
        2. The wildcard numbers in the monomer molecules string
           should form the sequence 1..max without gaps.
        3. The polymer rules should be in the correct format, and
           the indexes they reference should be wildcard numbers
           present in the monomer molecules string.

        Parameters
        ----------
        datapoint : str
            The datapoint string to validate

        Returns
        -------
        bool
            True if the string format is valid

        Raises
        ------
        ValueError
            If any of the checks above fails
        """
        self._validate_fragments(datapoint)
        self._validate_wildcards(datapoint)
        self._validate_polymer_rules(datapoint)
        return True
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import unittest | ||
from deepchem.utils import PolyWDGStringValidator | ||
|
||
|
||
class PolyWDGStringValidateTest(unittest.TestCase):
    """
    Unit tests for the polymer weighted directed graph string validator.
    """

    def test_get_parsed_vals(self):
        # Only two '|' separated fields: parsing must fail.
        too_few_fields = "[1*]C.C[2*]|<1-2:0.5:0.5"
        with self.assertRaises(ValueError):
            PolyWDGStringValidator.get_parsed_vals(too_few_fields)

        parsed = PolyWDGStringValidator.get_parsed_vals(
            "[1*]C.C[2*]|0.5|0.5|<1-2:0.5:0.5")
        monomer_mols, fragments, polymer_rules = parsed
        assert monomer_mols == "[1*]C.C[2*]"
        assert fragments == ["0.5", "0.5"]
        assert polymer_rules == "<1-2:0.5:0.5"

    def test_get_polymer_rules(self):
        # A rules string without the '<' separator is rejected.
        missing_separator = "1-2:0.5:0.5"
        with self.assertRaises(ValueError):
            PolyWDGStringValidator.get_polymer_rules(missing_separator)

        parsed_rules = PolyWDGStringValidator.get_polymer_rules(
            "<1-2:0.5:0.5")
        assert parsed_rules == ["1-2:0.5:0.5"]

    def test_valid_validate_function(self):
        validator = PolyWDGStringValidator()
        assert validator.validate("[1*]C.C[2*]|0.5|0.5|<1-2:0.5:0.5")

    def test_invalid_validate_function(self):
        invalid_datapoints = (
            # one fraction given for two monomers
            "[1*]C.C[2*]|0.5|<1-2:0.5:0.5",
            # wildcard 2 is missing from the monomer string
            "[1*]C.C[3*]|0.5|0.5|<1-2:0.5:0.5",
            # rule refers to index 3, which the monomers do not contain
            "[1*]C.C[2*]|0.5|0.5|<1-3:0.5:0.5",
        )
        for datapoint in invalid_datapoints:
            with self.assertRaises(ValueError):
                PolyWDGStringValidator().validate(datapoint)

    def test_invalid_polymer_rules(self):
        # Each datapoint carries a malformed polymer rule.
        invalid_datapoints = (
            # rule has only two ':' separated fields
            "[1*]C.C[2*]|0.5|0.5|<1-3:0.5",
            # bond field has no '-' separator
            "[1*]C.C[2*]|0.5|0.5|<13:0.5:0.5",
            # bond field has an empty second index
            "[1*]C.C[2*]|0.5|0.5|<13-:0.5:0.5",
            # bond field has three indexes
            "[1*]C.C[2*]|0.5|0.5|<1-3-5:0.5:0.5",
            # bond field has a non-digit index
            "[1*]C.C[2*]|0.5|0.5|<1-q:0.5:0.5",
            # bond field refers to an index absent from the monomers
            "[1*]C.C[2*]|0.5|0.5|<1-69:0.5:0.5",
        )
        for datapoint in invalid_datapoints:
            with self.assertRaises(ValueError):
                PolyWDGStringValidator().validate(datapoint)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters