Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add kabsch rmsd and dependancy. #180

Merged
merged 1 commit into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/calculators.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Calculators
OrcaEnergy <_autosummary/stko.OrcaEnergy>
RmsdCalculator <_autosummary/stko.RmsdCalculator>
RmsdMappedCalculator <_autosummary/stko.RmsdMappedCalculator>
KabschRmsdCalculator <_autosummary/stko.KabschRmsdCalculator>
ShapeCalculator <_autosummary/stko.ShapeCalculator>
TorsionCalculator <_autosummary/stko.TorsionCalculator>
ConstructedMoleculeTorsionCalculator <_autosummary/stko.ConstructedMoleculeTorsionCalculator>
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ dependencies = [
"rdkit==2023.9.5", # remove pin when type issues are resolved
"stk",
"networkx",
"rmsd",
]
requires-python = ">=3.11"
dynamic = ["version"]
Expand Down Expand Up @@ -130,5 +131,6 @@ module = [
"networkx.*",
"openbabel.*",
"MDAnalysis.*",
"rmsd.*",
]
ignore_missing_imports = true
2 changes: 2 additions & 0 deletions src/stko/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
from stko._internal.calculators.results.xtb_results import XTBResults
from stko._internal.calculators.rmsd_calculators import (
KabschRmsdCalculator,
RmsdCalculator,
RmsdMappedCalculator,
)
Expand Down Expand Up @@ -131,6 +132,7 @@
"XTBResults",
"RmsdCalculator",
"RmsdMappedCalculator",
"KabschRmsdCalculator",
"ShapeCalculator",
"OrcaResults",
"PlanarityResults",
Expand Down
59 changes: 59 additions & 0 deletions src/stko/_internal/calculators/rmsd_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import stk
from rmsd import kabsch_rmsd
from scipy.spatial.distance import cdist

from stko._internal.calculators.results.rmsd_results import RmsdResults
Expand Down Expand Up @@ -221,3 +222,61 @@ def calculate(self, mol: stk.Molecule) -> float:
)
mol = mol.with_centroid(np.array((0, 0, 0)))
return self._calculate_rmsd(mol)


class KabschRmsdCalculator:
"""Calculates the root mean square distance between molecules.

This calculator uses the rmsd package with the default settings and no
reordering.

See Also:
* rmsd https://github.com/charnley/rmsd

This calculator will only work if the two molecules are the same
and have the same atom ordering.

Parameters:
initial_molecule:
The :class:`stk.Molecule` to calculate RMSD from.

Examples:
.. code-block:: python

import stk
import stko

bb1 = stk.BuildingBlock('C1CCCCC1')
calculator = stko.KabschRmsdCalculator(bb1)
results = calculator.get_results(stk.UFF().optimize(bb1))
rmsd = results.get_rmsd()

"""

def __init__(self, initial_molecule: stk.Molecule) -> None:
self._initial_molecule = initial_molecule

def _calculate_rmsd(self, mol: stk.Molecule) -> float:
p_coord = self._initial_molecule.get_position_matrix()
q_coord = mol.get_position_matrix()
return kabsch_rmsd(p_coord, q_coord)

def calculate(self, mol: stk.Molecule) -> float:
self._initial_molecule = self._initial_molecule.with_centroid(
position=np.array((0, 0, 0)),
)
mol = mol.with_centroid(np.array((0, 0, 0)))
return self._calculate_rmsd(mol)

def get_results(self, mol: stk.Molecule) -> RmsdResults:
"""Calculate the RMSD between `mol` and the initial molecule.

Parameters:
mol:
The :class:`stk.Molecule` to calculate RMSD to.

Returns:
The RMSD between the molecules.

"""
return RmsdResults(self.calculate(mol))
26 changes: 26 additions & 0 deletions tests/calculators/rmsd/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class CaseData:
mol1: stk.Molecule
mol2: stk.Molecule
rmsd: float
kabsch_rmsd: float


_optimizer = stko.UFF()
Expand Down Expand Up @@ -43,6 +44,7 @@ class CaseData:
mol1=_cc_molecule,
mol2=_cc_molecule.with_centroid(np.array((4, 0, 0))),
rmsd=0.0,
kabsch_rmsd=0.0,
),
CaseData(
mol1=_cc_molecule,
Expand All @@ -52,6 +54,7 @@ class CaseData:
)
),
rmsd=0.0,
kabsch_rmsd=0.0,
),
CaseData(
mol1=_cc_molecule,
Expand All @@ -61,31 +64,37 @@ class CaseData:
)
),
rmsd=1.0,
kabsch_rmsd=1.0,
),
CaseData(
mol1=stk.BuildingBlock("NCCN"),
mol2=_optimizer.optimize(stk.BuildingBlock("NCCN")),
rmsd=0.24492870054279647,
kabsch_rmsd=0.188295954166067,
),
CaseData(
mol1=stk.BuildingBlock("CCCCCC"),
mol2=stk.BuildingBlock("CCCCCC"),
rmsd=0.0,
kabsch_rmsd=0.0,
),
CaseData(
mol1=stk.BuildingBlock("CCCCCC"),
mol2=_optimizer.optimize(stk.BuildingBlock("CCCCCC")),
rmsd=0.35636491354918015,
kabsch_rmsd=0.35044001253075,
),
CaseData(
mol1=stk.BuildingBlock("c1ccccc1"),
mol2=_optimizer.optimize(stk.BuildingBlock("c1ccccc1")),
rmsd=0.02936762392637932,
kabsch_rmsd=0.02936762392637932,
),
CaseData(
mol1=_polymer,
mol2=_optimizer.optimize(_polymer),
rmsd=2.1485735050384,
kabsch_rmsd=1.786251608496134,
),
],
)
Expand All @@ -101,26 +110,31 @@ def case_data(request: pytest.FixtureRequest) -> CaseData:
mol1=stk.BuildingBlock("NCCN"),
mol2=_optimizer.optimize(stk.BuildingBlock("NCCN")),
rmsd=0.20811702035676308,
kabsch_rmsd=0.20811702035676308,
),
CaseData(
mol1=stk.BuildingBlock("CCCCCC"),
mol2=_optimizer.optimize(stk.BuildingBlock("CCCCCC")),
rmsd=0.22563756374632568,
kabsch_rmsd=0.22563756374632568,
),
CaseData(
mol1=stk.BuildingBlock("CCCCCC"),
mol2=stk.BuildingBlock("CCCCCC"),
rmsd=0.0,
kabsch_rmsd=0.0,
),
CaseData(
mol1=stk.BuildingBlock("c1ccccc1"),
mol2=_optimizer.optimize(stk.BuildingBlock("c1ccccc1")),
rmsd=0.029156836455717483,
kabsch_rmsd=0.029156836455717483,
),
CaseData(
mol1=_polymer,
mol2=_optimizer.optimize(_polymer),
rmsd=1.792856412415046,
kabsch_rmsd=1.792856412415046,
),
],
)
Expand All @@ -136,11 +150,13 @@ def ignore_h_case_data(request: pytest.FixtureRequest) -> CaseData:
mol1=stk.BuildingBlock("NCCN"),
mol2=stk.BuildingBlock("CCCCCC"),
rmsd=0.0,
kabsch_rmsd=0.0,
),
CaseData(
mol1=stk.BuildingBlock("CCCCCC"),
mol2=stk.BuildingBlock("c1ccccc1"),
rmsd=0.0,
kabsch_rmsd=0.0,
),
],
)
Expand All @@ -156,6 +172,7 @@ def different_case_data(request: pytest.FixtureRequest) -> CaseData:
mol1=_polymer,
mol2=_polymer.with_canonical_atom_ordering(),
rmsd=0.0,
kabsch_rmsd=0.0,
),
CaseData(
mol1=stk.BuildingBlock(
Expand All @@ -171,6 +188,7 @@ def different_case_data(request: pytest.FixtureRequest) -> CaseData:
),
).with_canonical_atom_ordering(),
rmsd=0.0,
kabsch_rmsd=0.0,
),
],
)
Expand All @@ -186,6 +204,7 @@ def ordering_case_data(request: pytest.FixtureRequest) -> CaseData:
mol1=_cc_molecule,
mol2=_cc_molecule.with_centroid(np.array((4, 0, 0))),
rmsd=0.0,
kabsch_rmsd=0.0,
),
CaseData(
mol1=_cc_molecule,
Expand All @@ -195,6 +214,7 @@ def ordering_case_data(request: pytest.FixtureRequest) -> CaseData:
)
),
rmsd=0.0,
kabsch_rmsd=0.0,
),
CaseData(
mol1=_cc_molecule,
Expand All @@ -204,6 +224,7 @@ def ordering_case_data(request: pytest.FixtureRequest) -> CaseData:
)
),
rmsd=1.0,
kabsch_rmsd=1.0,
),
CaseData(
mol1=stk.BuildingBlock("NCCN"),
Expand All @@ -215,6 +236,7 @@ def ordering_case_data(request: pytest.FixtureRequest) -> CaseData:
)
.with_displacement(np.array((2, 0, 1))),
rmsd=1.1309858484314543,
kabsch_rmsd=1.1309858484314543,
),
CaseData(
mol1=stk.BuildingBlock("CCCCCC"),
Expand All @@ -226,21 +248,25 @@ def ordering_case_data(request: pytest.FixtureRequest) -> CaseData:
)
.with_displacement(np.array((0, 0, 1))),
rmsd=0.5943193981905652,
kabsch_rmsd=0.5943193981905652,
),
CaseData(
mol1=stk.BuildingBlock("NCCN"),
mol2=stk.BuildingBlock("NCCCN"),
rmsd=0.8832914099448816,
kabsch_rmsd=0.8832914099448816,
),
CaseData(
mol1=stk.BuildingBlock("NCOCN"),
mol2=stk.BuildingBlock("NCCN"),
rmsd=1.2678595995702466,
kabsch_rmsd=1.2678595995702466,
),
CaseData(
mol1=stk.BuildingBlock("NCCN"),
mol2=stk.BuildingBlock("NCOCN"),
rmsd=1.3921770318522637,
kabsch_rmsd=1.3921770318522637,
),
],
)
Expand Down
7 changes: 7 additions & 0 deletions tests/calculators/rmsd/test_rmsd_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ def test_rmsd(case_data: CaseData) -> None:
assert np.isclose(test_rmsd, case_data.rmsd, atol=1e-4)


def test_kabsch_rmsd(case_data: CaseData) -> None:
calculator = stko.KabschRmsdCalculator(case_data.mol1)
results = calculator.get_results(case_data.mol2)
test_rmsd = results.get_rmsd()
assert np.isclose(test_rmsd, case_data.kabsch_rmsd, atol=1e-4)


def test_rmsd_ignore_hydrogens(ignore_h_case_data: CaseData) -> None:
calculator = stko.RmsdCalculator(
ignore_h_case_data.mol1, ignore_hydrogens=True
Expand Down
Loading