Skip to content

Commit

Permalink
Format black
Browse files Browse the repository at this point in the history
  • Loading branch information
pyiron-runner committed Mar 19, 2024
1 parent b8b6a4e commit 2788c81
Show file tree
Hide file tree
Showing 11 changed files with 166 additions and 44 deletions.
8 changes: 7 additions & 1 deletion structuretoolkit/analyse/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
import numpy as np


def get_distances_array(structure: Atoms, p1: Optional[np.ndarray] = None, p2: Optional[np.ndarray] = None, mic: bool = True, vectors: bool = False):
def get_distances_array(
structure: Atoms,
p1: Optional[np.ndarray] = None,
p2: Optional[np.ndarray] = None,
mic: bool = True,
vectors: bool = False,
):
"""
Return distance matrix of every position in p1 with every position in
p2. If p2 is not set, it is assumed that distances between all
Expand Down
2 changes: 1 addition & 1 deletion structuretoolkit/analyse/dscribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def soap_descriptor_per_atom(
sigma: Optional[float] = 1.0,
rbf: str = "gto",
weighting: Optional[np.ndarray] = None,
average: str ="off",
average: str = "off",
compression: dict = {"mode": "off", "species_weighting": None},
species: Optional[list] = None,
periodic: bool = True,
Expand Down
51 changes: 41 additions & 10 deletions structuretoolkit/analyse/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,12 @@ def copy(self):
new_neigh._positions = self._positions.copy()
return new_neigh

def _reshape(self, value: np.ndarray, key: Optional[str] = None, ref_vector: Optional[np.ndarray] = None):
def _reshape(
self,
value: np.ndarray,
key: Optional[str] = None,
ref_vector: Optional[np.ndarray] = None,
):
if value is None:
raise ValueError("Neighbors not initialized yet")
if key is None:
Expand Down Expand Up @@ -227,7 +232,9 @@ def _get_wrapped_indices(self) -> np.ndarray:
return np.arange(len(self._ref_structure.positions))
return self._wrapped_indices

def _get_wrapped_positions(self, positions: np.ndarray, distance_buffer: float = 1.0e-12):
def _get_wrapped_positions(
self, positions: np.ndarray, distance_buffer: float = 1.0e-12
):
if not self.wrap_positions:
return np.asarray(positions)
x = np.array(positions).copy()
Expand Down Expand Up @@ -319,7 +326,10 @@ def _get_vectors(
return vectors

def _estimate_num_neighbors(
self, num_neighbors: Optional[int] = None, cutoff_radius: float = np.inf, width_buffer: float = 1.2
self,
num_neighbors: Optional[int] = None,
cutoff_radius: float = np.inf,
width_buffer: float = 1.2,
):
"""
Expand Down Expand Up @@ -356,7 +366,10 @@ def _estimate_num_neighbors(
return num_neighbors

def _estimate_width(
self, num_neighbors: Optional[int] = None, cutoff_radius: float = np.inf, width_buffer: float = 1.2
self,
num_neighbors: Optional[int] = None,
cutoff_radius: float = np.inf,
width_buffer: float = 1.2,
):
"""
Expand Down Expand Up @@ -456,7 +469,13 @@ def _check_width(self, width: float, pbc: list[bool, bool, bool]):
return True
return False

def get_spherical_harmonics(self, l: np.ndarray, m: np.ndarray, cutoff_radius: float = np.inf, rotation: Optional[np.ndarray] = None) -> np.ndarray:
def get_spherical_harmonics(
self,
l: np.ndarray,
m: np.ndarray,
cutoff_radius: float = np.inf,
rotation: Optional[np.ndarray] = None,
) -> np.ndarray:
"""
Args:
l (int/numpy.array): Degree of the harmonic (int); must have ``l >= 0``.
Expand Down Expand Up @@ -497,7 +516,9 @@ def get_spherical_harmonics(self, l: np.ndarray, m: np.ndarray, cutoff_radius: f
within_cutoff, axis=-1
)

def get_steinhardt_parameter(self, l: np.ndarray, cutoff_radius: float = np.inf) -> np.ndarray:
def get_steinhardt_parameter(
self, l: np.ndarray, cutoff_radius: float = np.inf
) -> np.ndarray:
"""
Args:
l (int/numpy.array): Order of Steinhardt parameter
Expand Down Expand Up @@ -790,7 +811,10 @@ def get_global_shells(
return self._reshape(shells, key=mode)

def get_shell_matrix(
self, chemical_pair: Optional[list] = None, cluster_by_distances: bool = False, cluster_by_vecs: bool = False
self,
chemical_pair: Optional[list] = None,
cluster_by_distances: bool = False,
cluster_by_vecs: bool = False,
):
"""
Shell matrices for pairwise interaction. Note: The matrices are always symmetric, meaning if you
Expand Down Expand Up @@ -851,7 +875,9 @@ def get_shell_matrix(
)
return shell_matrix

def find_neighbors_by_vector(self, vector: np.ndarray, return_deviation: bool = False) -> np.ndarray:
def find_neighbors_by_vector(
self, vector: np.ndarray, return_deviation: bool = False
) -> np.ndarray:
"""
Args:
vector (list/np.ndarray): vector by which positions are translated (and neighbors are searched)
Expand Down Expand Up @@ -1052,7 +1078,12 @@ def __probe_cluster(self, c_count: int, neighbors: list, id_list: list):
self.__probe_cluster(c_count, nbrs, id_list)

# TODO: combine with corresponding routine in plot3d
def get_bonds(self, radius: float = np.inf, max_shells: Optional[int] = None, prec: float = 0.1):
def get_bonds(
self,
radius: float = np.inf,
max_shells: Optional[int] = None,
prec: float = 0.1,
):
"""
Args:
Expand Down Expand Up @@ -1110,7 +1141,7 @@ def get_volume_of_n_sphere_in_p_norm(n: int = 3, p: int = 2) -> float:

def get_neighbors(
structure: Atoms,
num_neighbors: int =12,
num_neighbors: int = 12,
tolerance: int = 2,
id_list: Optional[list] = None,
cutoff_radius: float = np.inf,
Expand Down
4 changes: 3 additions & 1 deletion structuretoolkit/analyse/phonopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
__date__ = "Sep 1, 2018"


def get_equivalent_atoms(structure: Atoms, symprec: float = 1e-5, angle_tolerance: float = -1.0):
def get_equivalent_atoms(
structure: Atoms, symprec: float = 1e-5, angle_tolerance: float = -1.0
):
"""
Args: (read phonopy.structure.spglib for more details)
symprec:
Expand Down
10 changes: 7 additions & 3 deletions structuretoolkit/analyse/pyscal.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def get_steinhardt_parameters(
return sysq


def get_centro_symmetry_descriptors(structure: Atoms, num_neighbors: int = 12) -> np.ndarray:
def get_centro_symmetry_descriptors(
structure: Atoms, num_neighbors: int = 12
) -> np.ndarray:
"""
Analyse centrosymmetry parameter
Expand Down Expand Up @@ -183,7 +185,9 @@ def get_diamond_structure_descriptors(
)


def get_adaptive_cna_descriptors(structure: Atoms, mode: str = "total", ovito_compatibility: bool = False) -> np.ndarray:
def get_adaptive_cna_descriptors(
structure: Atoms, mode: str = "total", ovito_compatibility: bool = False
) -> np.ndarray:
"""
Use common neighbor analysis
Expand Down Expand Up @@ -253,7 +257,7 @@ def get_voronoi_volumes(structure: Atoms) -> np.ndarray:

def find_solids(
structure: Atoms,
neighbor_method: str ="cutoff",
neighbor_method: str = "cutoff",
cutoff: float = 0.0,
bonds: float = 0.5,
threshold: float = 0.5,
Expand Down
40 changes: 32 additions & 8 deletions structuretoolkit/analyse/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
__date__ = "Sep 1, 2017"


def get_mean_positions(positions: np.ndarray, cell: np.ndarray, pbc: np.ndarray, labels: np.ndarray) -> np.ndarray:
def get_mean_positions(
positions: np.ndarray, cell: np.ndarray, pbc: np.ndarray, labels: np.ndarray
) -> np.ndarray:
"""
This function calculates the average position(-s) across periodic boundary conditions according
to the labels
Expand Down Expand Up @@ -57,7 +59,9 @@ def get_mean_positions(positions: np.ndarray, cell: np.ndarray, pbc: np.ndarray,
return mean_positions


def create_gridpoints(structure: Atoms, n_gridpoints_per_angstrom: int = 5) -> np.ndarray:
def create_gridpoints(
structure: Atoms, n_gridpoints_per_angstrom: int = 5
) -> np.ndarray:
cell = get_vertical_length(structure=structure)
n_points = (n_gridpoints_per_angstrom * cell).astype(int)
positions = np.meshgrid(
Expand All @@ -67,12 +71,16 @@ def create_gridpoints(structure: Atoms, n_gridpoints_per_angstrom: int = 5) -> n
return np.einsum("ji,nj->ni", structure.cell, positions)


def remove_too_close(positions: np.ndarray, structure: Atoms, min_distance: float = 1) -> np.ndarray:
def remove_too_close(
positions: np.ndarray, structure: Atoms, min_distance: float = 1
) -> np.ndarray:
neigh = get_neighborhood(structure=structure, positions=positions, num_neighbors=1)
return positions[neigh.distances.flatten() > min_distance]


def set_to_high_symmetry_points(positions: np.ndarray, structure: Atoms, neigh, decimals: int = 4) -> np.ndarray:
def set_to_high_symmetry_points(
positions: np.ndarray, structure: Atoms, neigh, decimals: int = 4
) -> np.ndarray:
for _ in range(10):
neigh = neigh.get_neighborhood(positions)
dx = np.mean(neigh.vecs, axis=-2)
Expand All @@ -87,7 +95,14 @@ def set_to_high_symmetry_points(positions: np.ndarray, structure: Atoms, neigh,
raise ValueError("High symmetry points could not be detected")


def cluster_by_steinhardt(positions: np.ndarray, neigh, l_values: List[int], q_eps: float, var_ratio: float, min_samples: int) -> np.ndarray:
def cluster_by_steinhardt(
positions: np.ndarray,
neigh,
l_values: List[int],
q_eps: float,
var_ratio: float,
min_samples: int,
) -> np.ndarray:
"""
Clusters candidate positions via Steinhardt parameters and the variance in distances to host atoms.
Expand Down Expand Up @@ -240,7 +255,9 @@ def __init__(
self._positions = None
self.structure = structure

def run_workflow(self, positions: Optional[np.ndarray] = None, steps: int = -1) -> np.ndarray:
def run_workflow(
self, positions: Optional[np.ndarray] = None, steps: int = -1
) -> np.ndarray:
if positions is None:
positions = self.initial_positions.copy()
for ii, ww in enumerate(self.workflow):
Expand Down Expand Up @@ -470,7 +487,10 @@ def get_layers(


def get_voronoi_vertices(
structure: Atoms, epsilon: float = 2.5e-4, distance_threshold: float = 0, width_buffer: float = 10.0
structure: Atoms,
epsilon: float = 2.5e-4,
distance_threshold: float = 0,
width_buffer: float = 10.0,
) -> np.ndarray:
"""
Get voronoi vertices of the box.
Expand Down Expand Up @@ -578,7 +598,11 @@ def get_delaunay_neighbors(structure: Atoms, width_buffer: float = 10.0) -> np.n


def get_cluster_positions(
structure: Atoms, positions: Optional[np.ndarray] = None, eps: float = 1.0, buffer_width: Optional[float] =None, return_labels: bool = False
structure: Atoms,
positions: Optional[np.ndarray] = None,
eps: float = 1.0,
buffer_width: Optional[float] = None,
return_labels: bool = False,
) -> np.ndarray:
"""
Cluster positions according to the distances. Clustering algorithm uses DBSCAN:
Expand Down
27 changes: 21 additions & 6 deletions structuretoolkit/analyse/strain.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ class Strain:
"""

def __init__(
self, structure: Atoms, ref_structure: Atoms, num_neighbors: Optional[int] = None, only_bulk_type: bool = False
self,
structure: Atoms,
ref_structure: Atoms,
num_neighbors: Optional[int] = None,
only_bulk_type: bool = False,
):
"""
Expand Down Expand Up @@ -71,7 +75,9 @@ def _nullify_non_bulk(self) -> np.ndarray:
self.structure.analyse.pyscal_cna_adaptive(mode="str") != self.crystal_phase
)

def _get_perpendicular_unit_vectors(self, vec: np.ndarray, vec_axis: Optional[np.ndarray] = None) -> np.ndarray:
def _get_perpendicular_unit_vectors(
self, vec: np.ndarray, vec_axis: Optional[np.ndarray] = None
) -> np.ndarray:
if vec_axis is not None:
vec_axis = self._get_safe_unit_vectors(vec_axis)
vec = np.array(
Expand All @@ -80,7 +86,9 @@ def _get_perpendicular_unit_vectors(self, vec: np.ndarray, vec_axis: Optional[np
return self._get_safe_unit_vectors(vec)

@staticmethod
def _get_safe_unit_vectors(vectors: np.ndarray, minimum_value: float = 1.0e-8) -> np.ndarray:
def _get_safe_unit_vectors(
vectors: np.ndarray, minimum_value: float = 1.0e-8
) -> np.ndarray:
v = np.linalg.norm(vectors, axis=-1)
v += (v < minimum_value) * minimum_value
return np.einsum("...i,...->...i", vectors, 1 / v)
Expand All @@ -94,7 +102,12 @@ def _get_angle(self, v: np.ndarray, w: np.ndarray) -> np.ndarray:
prod[np.absolute(prod) > 1] = np.sign(prod)[np.absolute(prod) > 1]
return np.arccos(prod)

def _get_rotation_from_vectors(self, vec_before: np.ndarray, vec_after: np.ndarray, vec_axis: Optional[np.ndarray] = None) -> np.ndarray:
def _get_rotation_from_vectors(
self,
vec_before: np.ndarray,
vec_after: np.ndarray,
vec_axis: Optional[np.ndarray] = None,
) -> np.ndarray:
v = self._get_perpendicular_unit_vectors(vec_before, vec_axis)
w = self._get_perpendicular_unit_vectors(vec_after, vec_axis)
if vec_axis is None:
Expand Down Expand Up @@ -132,7 +145,9 @@ def rotations(self) -> np.ndarray:
return self._rotations

@staticmethod
def _get_best_match_indices(coords: np.ndarray, ref_coord: np.ndarray) -> np.ndarray:
def _get_best_match_indices(
coords: np.ndarray, ref_coord: np.ndarray
) -> np.ndarray:
distances = np.linalg.norm(
coords[:, :, None, :] - ref_coord[None, None, :, :], axis=-1
)
Expand Down Expand Up @@ -195,7 +210,7 @@ def strain(self) -> np.ndarray:
def get_strain(
structure: Atoms,
ref_structure: Atoms,
num_neighbors: Optional[int] =None,
num_neighbors: Optional[int] = None,
only_bulk_type: bool = False,
return_object: bool = False,
):
Expand Down
9 changes: 7 additions & 2 deletions structuretoolkit/analyse/symmetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,9 @@ def symmetrize_vectors(
np.einsum("ijk->jki", v_reshaped)[self.permutations],
).reshape(np.shape(vectors)) / len(self["rotations"])

def _get_spglib_cell(self, use_elements: Optional[bool] = None, use_magmoms: Optional[bool] = None) -> tuple:
def _get_spglib_cell(
self, use_elements: Optional[bool] = None, use_magmoms: Optional[bool] = None
) -> tuple:
lattice = np.array(self._structure.get_cell(), dtype="double", order="C")
positions = np.array(
self._structure.get_scaled_positions(wrap=False), dtype="double", order="C"
Expand Down Expand Up @@ -326,7 +328,10 @@ def spacegroup(self) -> dict:
}

def get_primitive_cell(
self, standardize: bool = False, use_elements: Optional[bool] = None, use_magmoms: Optional[bool] = None
self,
standardize: bool = False,
use_elements: Optional[bool] = None,
use_magmoms: Optional[bool] = None,
) -> Atoms:
"""
Get primitive cell of a given structure.
Expand Down
Loading

0 comments on commit 2788c81

Please sign in to comment.