Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasturcani committed Apr 26, 2024
1 parent a474932 commit de13097
Show file tree
Hide file tree
Showing 42 changed files with 210 additions and 533 deletions.
19 changes: 10 additions & 9 deletions examples/gulp_test_example.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# ruff: noqa: S101
import argparse
import logging
import sys
from pathlib import Path

import numpy as np
Expand All @@ -10,12 +10,7 @@

def main() -> None:
"""Run the example."""
first_line = f"Usage: {__file__}.py"
if len(sys.argv) != 2:
logging.info("%s gulp_path", first_line)
sys.exit()
else:
gulp_path = sys.argv[1]
args = _parse_args()

iron_atom = stk.BuildingBlock(
smiles="[Fe+2]",
Expand Down Expand Up @@ -72,7 +67,7 @@ def main() -> None:
# Use conjugate gradient method for a slower, but more stable
# optimisation.
gulp_opt = stko.GulpUFFOptimizer(
gulp_path=gulp_path,
gulp_path=args.gulp_path,
output_dir="gulp_test_output",
metal_FF={26: "Fe4+2"},
conjugate_gradient=True,
Expand All @@ -86,7 +81,7 @@ def main() -> None:

target_num_confs = 40
gulp_md = stko.GulpUFFMDOptimizer(
gulp_path=gulp_path,
gulp_path=args.gulp_path,
metal_FF={26: "Fe4+2"},
output_dir="gulp_test_output_MD",
temperature=300,
Expand All @@ -104,6 +99,12 @@ def main() -> None:
assert len(confs_gen) == target_num_confs


def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("gulp_path", type=str)
return parser.parse_args()


if __name__ == "__main__":
logging.basicConfig(
level=logging.INFO,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ select = ["ALL"]
ignore = [
"ANN101",
"ANN102",
"ANN401",
"COM812",
"ISC001",
"FBT001",
Expand Down
26 changes: 12 additions & 14 deletions src/stko/_internal/calculators/geometry_analysis/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class GeometryAnalyser:
def _get_metal_atom_ids(
self,
molecule: stk.Molecule,
metal_atom_nos: tuple[int],
metal_atom_nos: tuple[int, ...],
) -> list[int]:
return [
i.get_id()
Expand All @@ -39,7 +39,7 @@ def _get_metal_atom_ids(
def get_metal_distances(
self,
molecule: stk.Molecule,
metal_atom_nos: tuple[int],
metal_atom_nos: tuple[int, ...],
) -> dict[tuple[int, int], float]:
"""Get all metal atom pair distances.
Expand Down Expand Up @@ -71,7 +71,7 @@ def get_metal_distances(
def get_metal_centroid_metal_angle(
self,
molecule: stk.Molecule,
metal_atom_nos: tuple[int],
metal_atom_nos: tuple[int, ...],
) -> dict[tuple[int, int], float]:
"""Get all metal-centroid-metal angles.
Expand Down Expand Up @@ -206,7 +206,7 @@ def get_max_diameter(self, molecule: stk.Molecule) -> float:
def calculate_bonds(
self,
molecule: stk.Molecule,
) -> dict[tuple[str, ...], list[float]]:
) -> dict[tuple[str, str], list[float]]:
"""Calculate bond lengths for all `stk.Molecule.get_bonds()`.
Parameters:
Expand All @@ -218,19 +218,17 @@ def calculate_bonds(
"""
position_matrix = molecule.get_position_matrix()
lengths = defaultdict(list)
lengths: dict[tuple[str, str], list[float]] = defaultdict(list)
for bond in molecule.get_bonds():
a1id = bond.get_atom1().get_id()
a2id = bond.get_atom2().get_id()
length_type = tuple(
sorted(
(
bond.get_atom1().__class__.__name__,
bond.get_atom2().__class__.__name__,
)
a, b = sorted(
(
bond.get_atom1().__class__.__name__,
bond.get_atom2().__class__.__name__,
)
)
lengths[length_type].append(
lengths[(a, b)].append(
get_atom_distance(position_matrix, a1id, a2id)
)

Expand All @@ -239,7 +237,7 @@ def calculate_bonds(
def calculate_angles(
self,
molecule: stk.Molecule,
) -> dict[tuple[str, ...], list[float]]:
) -> dict[tuple[str, str, str], list[float]]:
"""Calculate angles for all angles defined by molecule bonding.
Parameters:
Expand All @@ -251,7 +249,7 @@ def calculate_angles(
"""
position_matrix = molecule.get_position_matrix()
angles: dict[tuple[str, ...], list[float]] = defaultdict(list)
angles: dict[tuple[str, str, str], list[float]] = defaultdict(list)
for a_ids in self._get_paths(molecule, 3):
atoms = list(molecule.get_atoms(atom_ids=a_ids))
atom1 = atoms[0]
Expand Down
3 changes: 2 additions & 1 deletion src/stko/_internal/molecular/conversion/md_analysis.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import Any

import stk

Expand Down Expand Up @@ -45,7 +46,7 @@ def __init__(self) -> None:
)
raise WrapperNotInstalledError(msg)

def get_universe(self, mol: stk.Molecule): # type: ignore[no-untyped-def]
def get_universe(self, mol: stk.Molecule) -> Any: # type: ignore[no-untyped-def]
"""Get an MDAnalysis object.
Parameters:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
from collections import abc
from itertools import combinations

import stk
Expand Down
42 changes: 15 additions & 27 deletions tests/calculators/geometry/case_data.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,19 @@
from dataclasses import dataclass

import stk


@dataclass(frozen=True, slots=True)
class CaseData:
def __init__(
self,
molecule: stk.Molecule,
metal_atom_distances: dict[tuple[int, int], float],
metal_centroid_angles: dict[tuple[int, int], float],
min_centoid_distance: float,
avg_centoid_distance: tuple[float, float],
radius_gyration: float,
min_atom_atom_distance: float,
max_diameter: float,
bonds: dict[tuple[str, str], list[float]],
angles: dict[tuple[str, str, str], list[float]],
torsions: dict[tuple[str, str, str, str], list[float]],
name: str,
) -> None:
self.molecule = molecule
self.metal_atom_distances = metal_atom_distances
self.metal_centroid_angles = metal_centroid_angles
self.min_centoid_distance = min_centoid_distance
self.avg_centoid_distance = avg_centoid_distance
self.radius_gyration = radius_gyration
self.min_atom_atom_distance = min_atom_atom_distance
self.max_diameter = max_diameter
self.bonds = bonds
self.angles = angles
self.torsions = torsions
self.name = name
molecule: stk.Molecule
metal_atom_distances: dict[tuple[int, int], float]
metal_centroid_angles: dict[tuple[int, int], float]
min_centoid_distance: float
avg_centoid_distance: tuple[float, float]
radius_gyration: float
min_atom_atom_distance: float
max_diameter: float
bonds: dict[tuple[str, str], list[float]]
angles: dict[tuple[str, str, str], list[float]]
torsions: dict[tuple[str, str, str, str], list[float]]
name: str
11 changes: 6 additions & 5 deletions tests/calculators/geometry/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import pytest
import stk

Expand All @@ -8,9 +9,9 @@
metals=stk.BuildingBlock(
smiles="[Fe+2]",
functional_groups=(
stk.SingleAtom(stk.Fe(0, charge=2)) for i in range(6)
stk.SingleAtom(stk.Fe(0, charge=2)) for _ in range(6)
),
position_matrix=[[0, 0, 0]],
position_matrix=np.array([[0, 0, 0]]),
),
ligands=stk.BuildingBlock(
smiles="C1=NC(C=NBr)=CC=C1",
Expand Down Expand Up @@ -2321,9 +2322,9 @@
smiles="[Pd+2]",
functional_groups=(
stk.SingleAtom(stk.Pd(0, charge=2))
for i in range(4)
for _ in range(4)
),
position_matrix=[[0, 0, 0]],
position_matrix=np.array([[0, 0, 0]]),
),
stk.BuildingBlock(
smiles=(
Expand Down Expand Up @@ -2926,7 +2927,7 @@
),
),
)
def case_data(request) -> CaseData:
def case_data(request: pytest.FixtureRequest) -> CaseData:
return request.param(
f"{request.fixturename}{request.param_index}",
)
11 changes: 2 additions & 9 deletions tests/calculators/geometry/test_calculate_angles.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
import numpy as np
import stko

from .case_data import CaseData

def test_calculate_angles(case_data):
"""Test :class:`.GeometryAnalyser.calculate_angles`.

Parameters
----------
case_data:
A test case.
"""
def test_calculate_angles(case_data: CaseData) -> None:
analyser = stko.molecule_analysis.GeometryAnalyser()

result = analyser.calculate_angles(case_data.molecule)
print(result)
for triple in result:
if triple == ("C", "C", "C"):
continue
Expand Down
11 changes: 2 additions & 9 deletions tests/calculators/geometry/test_calculate_bonds.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
import numpy as np
import stko

from .case_data import CaseData

def test_calculate_bonds(case_data):
"""Test :class:`.GeometryAnalyser.calculate_bonds`.

Parameters
----------
case_data:
A test case.
"""
def test_calculate_bonds(case_data: CaseData) -> None:
analyser = stko.molecule_analysis.GeometryAnalyser()

result = analyser.calculate_bonds(case_data.molecule)
print(result)
for pair in result:
if pair == ("C", "C"):
continue
Expand Down
11 changes: 2 additions & 9 deletions tests/calculators/geometry/test_calculate_torsions.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
import numpy as np
import stko

from .case_data import CaseData

def test_calculate_torsions(case_data):
"""Test :class:`.GeometryAnalyser.calculate_torsions`.

Parameters
----------
case_data:
A test case.
"""
def test_calculate_torsions(case_data: CaseData) -> None:
analyser = stko.molecule_analysis.GeometryAnalyser()

result = analyser.calculate_torsions(case_data.molecule)
print(result)
for four in result:
if four == ("C", "C", "C", "C"):
continue
Expand Down
11 changes: 2 additions & 9 deletions tests/calculators/geometry/test_get_avg_centroid_distance.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
import numpy as np
import stko

from .case_data import CaseData

def test_get_avg_centroid_distance(case_data):
"""Test :class:`.GeometryAnalyser.get_avg_centroid_distance`.

Parameters
----------
case_data:
A test case.
"""
def test_get_avg_centroid_distance(case_data: CaseData) -> None:
analyser = stko.molecule_analysis.GeometryAnalyser()

result = analyser.get_avg_centroid_distance(case_data.molecule)
print(result)
assert np.isclose(
result[0], case_data.avg_centoid_distance[0], atol=1e-3, rtol=0
)
Expand Down
11 changes: 2 additions & 9 deletions tests/calculators/geometry/test_get_max_diamter.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
import numpy as np
import stko

from .case_data import CaseData

def test_get_max_diameter(case_data):
"""Test :class:`.GeometryAnalyser.get_max_diameter`.

Parameters
----------
case_data:
A test case.
"""
def test_get_max_diameter(case_data: CaseData) -> None:
analyser = stko.molecule_analysis.GeometryAnalyser()

result = analyser.get_max_diameter(case_data.molecule)
print(result)
assert np.isclose(result, case_data.max_diameter, atol=1e-3, rtol=0)
12 changes: 2 additions & 10 deletions tests/calculators/geometry/test_get_metal_centroid_metal_angle.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,18 @@
import numpy as np
import stko

from .case_data import CaseData

def test_get_metal_centroid_metal_angle(case_data):
"""Test :class:`.GeometryAnalyser.get_metal_centroid_metal_angle`.

Parameters
----------
case_data:
A test case.
"""
def test_get_metal_centroid_metal_angle(case_data: CaseData) -> None:
analyser = stko.molecule_analysis.GeometryAnalyser()

result = analyser.get_metal_centroid_metal_angle(
case_data.molecule,
metal_atom_nos=(26, 46),
)
print(result)
assert len(result) == len(case_data.metal_centroid_angles)
for i in result:
print(i, result[i])
assert np.isclose(
result[i], case_data.metal_centroid_angles[i], atol=1e-3, rtol=0
)
Loading

0 comments on commit de13097

Please sign in to comment.