From b6b43918c8329618b93558e875c2c774bc76ef4a Mon Sep 17 00:00:00 2001 From: andrewtarzia Date: Wed, 24 Jul 2024 15:13:03 +0200 Subject: [PATCH] Add kabsch rmsd and dependancy. --- docs/source/calculators.rst | 1 + pyproject.toml | 2 + src/stko/__init__.py | 2 + .../_internal/calculators/rmsd_calculators.py | 59 +++++++++++++++++++ tests/calculators/rmsd/conftest.py | 26 ++++++++ .../calculators/rmsd/test_rmsd_calculators.py | 7 +++ 6 files changed, 97 insertions(+) diff --git a/docs/source/calculators.rst b/docs/source/calculators.rst index 9aea15d..e693bef 100644 --- a/docs/source/calculators.rst +++ b/docs/source/calculators.rst @@ -11,6 +11,7 @@ Calculators OrcaEnergy <_autosummary/stko.OrcaEnergy> RmsdCalculator <_autosummary/stko.RmsdCalculator> RmsdMappedCalculator <_autosummary/stko.RmsdMappedCalculator> + KabschRmsdCalculator <_autosummary/stko.KabschRmsdCalculator> ShapeCalculator <_autosummary/stko.ShapeCalculator> TorsionCalculator <_autosummary/stko.TorsionCalculator> ConstructedMoleculeTorsionCalculator <_autosummary/stko.ConstructedMoleculeTorsionCalculator> diff --git a/pyproject.toml b/pyproject.toml index c266df2..10d4268 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "rdkit==2023.9.5", # remove pin when type issues are resolved "stk", "networkx", + "rmsd", ] requires-python = ">=3.11" dynamic = ["version"] @@ -130,5 +131,6 @@ module = [ "networkx.*", "openbabel.*", "MDAnalysis.*", + "rmsd.*", ] ignore_missing_imports = true diff --git a/src/stko/__init__.py b/src/stko/__init__.py index c78d5e5..4d13b05 100644 --- a/src/stko/__init__.py +++ b/src/stko/__init__.py @@ -22,6 +22,7 @@ ) from stko._internal.calculators.results.xtb_results import XTBResults from stko._internal.calculators.rmsd_calculators import ( + KabschRmsdCalculator, RmsdCalculator, RmsdMappedCalculator, ) @@ -131,6 +132,7 @@ "XTBResults", "RmsdCalculator", "RmsdMappedCalculator", + "KabschRmsdCalculator", "ShapeCalculator", "OrcaResults", "PlanarityResults", diff --git a/src/stko/_internal/calculators/rmsd_calculators.py b/src/stko/_internal/calculators/rmsd_calculators.py index c96845e..ccd8105 100644 --- a/src/stko/_internal/calculators/rmsd_calculators.py +++ b/src/stko/_internal/calculators/rmsd_calculators.py @@ -2,6 +2,7 @@ import numpy as np import stk +from rmsd import kabsch_rmsd from scipy.spatial.distance import cdist from stko._internal.calculators.results.rmsd_results import RmsdResults @@ -221,3 +222,61 @@ def calculate(self, mol: stk.Molecule) -> float: ) mol = mol.with_centroid(np.array((0, 0, 0))) return self._calculate_rmsd(mol) + + +class KabschRmsdCalculator: + """Calculates the root mean square distance between molecules. + + This calculator uses the rmsd package with the default settings and no + reordering. + + See Also: + * rmsd https://github.com/charnley/rmsd + + This calculator will only work if the two molecules are the same + and have the same atom ordering. + + Parameters: + initial_molecule: + The :class:`stk.Molecule` to calculate RMSD from. + + Examples: + .. code-block:: python + + import stk + import stko + + bb1 = stk.BuildingBlock('C1CCCCC1') + calculator = stko.KabschRmsdCalculator(bb1) + results = calculator.get_results(stk.UFF().optimize(bb1)) + rmsd = results.get_rmsd() + + """ + + def __init__(self, initial_molecule: stk.Molecule) -> None: + self._initial_molecule = initial_molecule + + def _calculate_rmsd(self, mol: stk.Molecule) -> float: + p_coord = self._initial_molecule.get_position_matrix() + q_coord = mol.get_position_matrix() + return kabsch_rmsd(p_coord, q_coord) + + def calculate(self, mol: stk.Molecule) -> float: + self._initial_molecule = self._initial_molecule.with_centroid( + position=np.array((0, 0, 0)), + ) + mol = mol.with_centroid(np.array((0, 0, 0))) + return self._calculate_rmsd(mol) + + def get_results(self, mol: stk.Molecule) -> RmsdResults: + """Calculate the RMSD between `mol` and the initial molecule. + + Parameters: + mol: + The :class:`stk.Molecule` to calculate RMSD to. + + Returns: + The RMSD between the molecules. + + """ + return RmsdResults(self.calculate(mol)) diff --git a/tests/calculators/rmsd/conftest.py b/tests/calculators/rmsd/conftest.py index 309141a..033aa51 100644 --- a/tests/calculators/rmsd/conftest.py +++ b/tests/calculators/rmsd/conftest.py @@ -11,6 +11,7 @@ class CaseData: mol1: stk.Molecule mol2: stk.Molecule rmsd: float + kabsch_rmsd: float _optimizer = stko.UFF() @@ -43,6 +44,7 @@ class CaseData: mol1=_cc_molecule, mol2=_cc_molecule.with_centroid(np.array((4, 0, 0))), rmsd=0.0, + kabsch_rmsd=0.0, ), CaseData( mol1=_cc_molecule, @@ -52,6 +54,7 @@ class CaseData: ) ), rmsd=0.0, + kabsch_rmsd=0.0, ), CaseData( mol1=_cc_molecule, @@ -61,31 +64,37 @@ class CaseData: ) ), rmsd=1.0, + kabsch_rmsd=1.0, ), CaseData( mol1=stk.BuildingBlock("NCCN"), mol2=_optimizer.optimize(stk.BuildingBlock("NCCN")), rmsd=0.24492870054279647, + kabsch_rmsd=0.188295954166067, ), CaseData( mol1=stk.BuildingBlock("CCCCCC"), mol2=stk.BuildingBlock("CCCCCC"), rmsd=0.0, + kabsch_rmsd=0.0, ), CaseData( mol1=stk.BuildingBlock("CCCCCC"), mol2=_optimizer.optimize(stk.BuildingBlock("CCCCCC")), rmsd=0.35636491354918015, + kabsch_rmsd=0.35044001253075, ), CaseData( mol1=stk.BuildingBlock("c1ccccc1"), mol2=_optimizer.optimize(stk.BuildingBlock("c1ccccc1")), rmsd=0.02936762392637932, + kabsch_rmsd=0.02936762392637932, ), CaseData( mol1=_polymer, mol2=_optimizer.optimize(_polymer), rmsd=2.1485735050384, + kabsch_rmsd=1.786251608496134, ), ], ) @@ -101,26 +110,31 @@ def case_data(request: pytest.FixtureRequest) -> CaseData: mol1=stk.BuildingBlock("NCCN"), mol2=_optimizer.optimize(stk.BuildingBlock("NCCN")), rmsd=0.20811702035676308, + kabsch_rmsd=0.20811702035676308, ), CaseData( mol1=stk.BuildingBlock("CCCCCC"), mol2=_optimizer.optimize(stk.BuildingBlock("CCCCCC")), rmsd=0.22563756374632568, + kabsch_rmsd=0.22563756374632568, ), CaseData( mol1=stk.BuildingBlock("CCCCCC"), mol2=stk.BuildingBlock("CCCCCC"), rmsd=0.0, + kabsch_rmsd=0.0, ), CaseData( mol1=stk.BuildingBlock("c1ccccc1"), mol2=_optimizer.optimize(stk.BuildingBlock("c1ccccc1")), rmsd=0.029156836455717483, + kabsch_rmsd=0.029156836455717483, ), CaseData( mol1=_polymer, mol2=_optimizer.optimize(_polymer), rmsd=1.792856412415046, + kabsch_rmsd=1.792856412415046, ), ], ) @@ -136,11 +150,13 @@ def ignore_h_case_data(request: pytest.FixtureRequest) -> CaseData: mol1=stk.BuildingBlock("NCCN"), mol2=stk.BuildingBlock("CCCCCC"), rmsd=0.0, + kabsch_rmsd=0.0, ), CaseData( mol1=stk.BuildingBlock("CCCCCC"), mol2=stk.BuildingBlock("c1ccccc1"), rmsd=0.0, + kabsch_rmsd=0.0, ), ], ) @@ -156,6 +172,7 @@ def different_case_data(request: pytest.FixtureRequest) -> CaseData: mol1=_polymer, mol2=_polymer.with_canonical_atom_ordering(), rmsd=0.0, + kabsch_rmsd=0.0, ), CaseData( mol1=stk.BuildingBlock( @@ -171,6 +188,7 @@ def different_case_data(request: pytest.FixtureRequest) -> CaseData: ), ).with_canonical_atom_ordering(), rmsd=0.0, + kabsch_rmsd=0.0, ), ], ) @@ -186,6 +204,7 @@ def ordering_case_data(request: pytest.FixtureRequest) -> CaseData: mol1=_cc_molecule, mol2=_cc_molecule.with_centroid(np.array((4, 0, 0))), rmsd=0.0, + kabsch_rmsd=0.0, ), CaseData( mol1=_cc_molecule, @@ -195,6 +214,7 @@ def ordering_case_data(request: pytest.FixtureRequest) -> CaseData: ) ), rmsd=0.0, + kabsch_rmsd=0.0, ), CaseData( mol1=_cc_molecule, @@ -204,6 +224,7 @@ def ordering_case_data(request: pytest.FixtureRequest) -> CaseData: ) ), rmsd=1.0, + kabsch_rmsd=1.0, ), CaseData( mol1=stk.BuildingBlock("NCCN"), @@ -215,6 +236,7 @@ def ordering_case_data(request: pytest.FixtureRequest) -> CaseData: ) .with_displacement(np.array((2, 0, 1))), rmsd=1.1309858484314543, + kabsch_rmsd=1.1309858484314543, ), CaseData( mol1=stk.BuildingBlock("CCCCCC"), @@ -226,21 +248,25 @@ def ordering_case_data(request: pytest.FixtureRequest) -> CaseData: ) .with_displacement(np.array((0, 0, 1))), rmsd=0.5943193981905652, + kabsch_rmsd=0.5943193981905652, ), CaseData( mol1=stk.BuildingBlock("NCCN"), mol2=stk.BuildingBlock("NCCCN"), rmsd=0.8832914099448816, + kabsch_rmsd=0.8832914099448816, ), CaseData( mol1=stk.BuildingBlock("NCOCN"), mol2=stk.BuildingBlock("NCCN"), rmsd=1.2678595995702466, + kabsch_rmsd=1.2678595995702466, ), CaseData( mol1=stk.BuildingBlock("NCCN"), mol2=stk.BuildingBlock("NCOCN"), rmsd=1.3921770318522637, + kabsch_rmsd=1.3921770318522637, ), ], ) diff --git a/tests/calculators/rmsd/test_rmsd_calculators.py b/tests/calculators/rmsd/test_rmsd_calculators.py index 0a1ade6..04dec54 100644 --- a/tests/calculators/rmsd/test_rmsd_calculators.py +++ b/tests/calculators/rmsd/test_rmsd_calculators.py @@ -12,6 +12,13 @@ def test_rmsd(case_data: CaseData) -> None: assert np.isclose(test_rmsd, case_data.rmsd, atol=1e-4) +def test_kabsch_rmsd(case_data: CaseData) -> None: + calculator = stko.KabschRmsdCalculator(case_data.mol1) + results = calculator.get_results(case_data.mol2) + test_rmsd = results.get_rmsd() + assert np.isclose(test_rmsd, case_data.kabsch_rmsd, atol=1e-4) + + def test_rmsd_ignore_hydrogens(ignore_h_case_data: CaseData) -> None: calculator = stko.RmsdCalculator( ignore_h_case_data.mol1, ignore_hydrogens=True