diff --git a/src/sisl/_category.py b/src/sisl/_category.py index b123268ce5..d17af801c3 100644 --- a/src/sisl/_category.py +++ b/src/sisl/_category.py @@ -117,7 +117,8 @@ def name(self): r"""Name of category""" return self._name - def set_name(self, name): + @name.setter + def name(self, name): r"""Override the name of the categorization""" self._name = name @@ -288,6 +289,8 @@ class GenericCategory(Category): a specific object in which they act. """ + __slots__ = () + @classmethod def is_class(cls, name): # never allow one to match a generic class @@ -299,10 +302,7 @@ def is_class(cls, name): class NullCategory(GenericCategory): r"""Special Null class which always represents a classification not being *anything*""" - __slots__ = tuple() - - def __init__(self): - pass + __slots__ = () def categorize(self, *args, **kwargs): return self @@ -321,6 +321,12 @@ def _(self, other): def name(self): return "∅" + @name.setter + def name(self, name): + raise ValueError( + f"One cannot overwrite the name of a {self.__class__.__name__}" + ) + @set_module("sisl.category") class NotCategory(GenericCategory): @@ -331,9 +337,9 @@ class NotCategory(GenericCategory): def __init__(self, cat): super().__init__() if isinstance(cat, CompositeCategory): - self.set_name(f"~({cat})") + self.name = f"~({cat})" else: - self.set_name(f"~{cat}") + self.name = f"~{cat}" self._cat = cat def categorize(self, *args, **kwargs): @@ -364,7 +370,7 @@ def _(self, other): def _composite_name(sep): - def name(self): + def getter(self): if not self._name is None: return self._name @@ -380,7 +386,10 @@ def name(self): return f"{nameA} {sep} {nameB}" - return property(name) + def setter(self, name): + self._name = name + + return property(getter, setter) @set_module("sisl.category") @@ -426,7 +435,7 @@ class OrCategory(CompositeCategory, composite_name="|"): the right hand side of the set operation """ - __slots__ = tuple() + __slots__ = () def categorize(self, *args, **kwargs): r"""Base method for queriyng whether an object is a certain category""" @@ -465,7 +474,7 @@ class AndCategory(CompositeCategory, composite_name="&"): the right hand side of the set operation """ - __slots__ = tuple() + __slots__ = () def categorize(self, *args, **kwargs): r"""Base method for queriyng whether an object is a certain category""" @@ -507,7 +516,7 @@ class XOrCategory(CompositeCategory, composite_name="⊕"): the right hand side of the set operation """ - __slots__ = tuple() + __slots__ = () def categorize(self, *args, **kwargs): r"""Base method for queriyng whether an object is a certain category""" diff --git a/src/sisl/_core/geometry.py b/src/sisl/_core/geometry.py index d26bc62856..7df4f215a2 100644 --- a/src/sisl/_core/geometry.py +++ b/src/sisl/_core/geometry.py @@ -79,7 +79,7 @@ # Note how we are overwriting the module @set_module("sisl.geom") class AtomCategory(Category): - __slots__ = tuple() + __slots__ = () @classmethod def is_class(cls, name, case=True) -> bool: @@ -397,9 +397,13 @@ def _(self, atoms: Atom) -> ndarray: @_sanitize_atoms.register(AtomCategory) @_sanitize_atoms.register(GenericCategory) - def _(self, atoms: Union[AtomCategory, GenericCategory]) -> ndarray: + def _( + self, + atoms_: Union[AtomCategory, GenericCategory], + atoms: Optional[AtomsArgument] = None, + ) -> ndarray: # First do categorization - cat = atoms.categorize(self) + cat = atoms_.categorize(self, atoms) def m(cat): for ia, c in enumerate(cat): @@ -412,9 +416,9 @@ def m(cat): return _a.fromiterl(m(cat)) @_sanitize_atoms.register - def _(self, atoms: dict) -> ndarray: + def _(self, atoms_: dict, atoms: Optional[AtomsArgument] = None) -> ndarray: # First do categorization - return self._sanitize_atoms(AtomCategory.kw(**atoms)) + return self._sanitize_atoms(AtomCategory.kw(**atoms_), atoms) @_sanitize_atoms.register def _(self, atoms: Shape) -> ndarray: diff --git a/src/sisl/geom/_category/__init__.py b/src/sisl/geom/_category/__init__.py index 91051f1383..711246a88b 100644 --- a/src/sisl/geom/_category/__init__.py +++ b/src/sisl/geom/_category/__init__.py @@ -1,7 +1,11 @@ # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at https://mozilla.org/MPL/2.0/. + +from .base import * + +# isort: split + from ._coord import * from ._kind import * from ._neighbours import * -from .base import * diff --git a/src/sisl/geom/_category/_coord.py b/src/sisl/geom/_category/_coord.py index 551b42cf42..e1523ff14b 100644 --- a/src/sisl/geom/_category/_coord.py +++ b/src/sisl/geom/_category/_coord.py @@ -1,19 +1,23 @@ # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at https://mozilla.org/MPL/2.0/. +from __future__ import annotations + import operator from functools import wraps from numbers import Integral +from typing import Optional import numpy as np import sisl._array as _a from sisl._category import CategoryMeta +from sisl._core import Geometry, Lattice, LatticeChild from sisl._core._lattice import cell_invert -from sisl._core.lattice import Lattice, LatticeChild from sisl._internal import set_module from sisl.messages import deprecate_argument -from sisl.shape import * +from sisl.shape import Shape +from sisl.typing import AtomsArgument from sisl.utils.misc import direction from .base import AtomCategory, NullCategory @@ -55,9 +59,7 @@ class AtomFracSite(AtomCategory): ... assert c == B_site """ - __slots__ = ( - f"_{a}" for a in ("cell", "icell", "length", "atol", "offset", "foffset") - ) + __slots__ = ("_cell", "_icell", "_length", "_atol", "_offset", "_foffset") @deprecate_argument( "sc", "lattice", "use lattice= instead of sc=", from_version="0.15" @@ -87,11 +89,12 @@ def __init__( f"fracsite(atol={self._atol}, offset={self._offset}, foffset={self._foffset})" ) - def categorize(self, geometry, atoms=None): + def categorize(self, geometry: Geometry, atoms: Optional[AtomsArgument] = None): # _sanitize_loop will ensure that atoms will always be an integer if atoms is None: fxyz = np.dot(geometry.xyz + self._offset, self._icell.T) + self._foffset else: + atoms = geometry._sanitize_atoms(atoms) fxyz = ( np.dot(geometry.xyz[atoms].reshape(-1, 3) + self._offset, self._icell.T) + self._foffset @@ -254,11 +257,12 @@ def func(a, b): self._coord_check = coord_ops super().__init__("coord") - def categorize(self, geometry, atoms=None): + def categorize(self, geometry: Geometry, atoms: Optional[AtomsArgument] = None): if atoms is None: xyz = geometry.xyz fxyz = geometry.fxyz else: + atoms = geometry._sanitize_atoms(atoms) xyz = geometry.xyz[atoms] fxyz = geometry.fxyz[atoms] @@ -360,7 +364,7 @@ def _apply_key(k, v): # Create the class for this direction # note this is lower case since AtomZ should not interfere with Atomz new_cls = AtomXYZMeta( - f"Atom{name}", (AtomCategory,), {"__new__": _new_factory(key)} + f"Atom{name}", (AtomCategory,), {"__new__": _new_factory(key), "__slots__": ()} ) new_cls = set_module("sisl.geom")(new_cls) diff --git a/src/sisl/geom/_category/_kind.py b/src/sisl/geom/_category/_kind.py index b6f1f1054f..9959b1edb3 100644 --- a/src/sisl/geom/_category/_kind.py +++ b/src/sisl/geom/_category/_kind.py @@ -1,15 +1,20 @@ # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at https://mozilla.org/MPL/2.0/. +from __future__ import annotations + import operator as op import re from functools import reduce, wraps from numbers import Integral +from typing import Optional, Union import numpy as np +from sisl._core import Geometry from sisl._help import isiterable from sisl._internal import set_module +from sisl.typing import AtomsArgument from sisl.utils import lstranges, strmap from .base import AtomCategory, NullCategory, _sanitize_loop @@ -29,7 +34,7 @@ class AtomZ(AtomCategory): __slots__ = ("_Z",) - def __init__(self, Z): + def __init__(self, Z: Union[int, Sequence[int]]): if isiterable(Z): self._Z = set(Z) else: @@ -38,7 +43,7 @@ def __init__(self, Z): super().__init__(f"Z={self._Z}") @_sanitize_loop - def categorize(self, geometry, atoms=None): + def categorize(self, geometry: Geometry, atoms: Optional[AtomsArgument] = None): # _sanitize_loop will ensure that atoms will always be an integer if geometry.atoms.Z[atoms] in self._Z: return self @@ -66,13 +71,13 @@ class AtomTag(AtomCategory): __slots__ = ("_compiled_re", "_re") - def __init__(self, tag): + def __init__(self, tag: str): self._re = tag self._compiled_re = re.compile(self._re) super().__init__(f"tag={self._re}") @_sanitize_loop - def categorize(self, geometry, atoms=None): + def categorize(self, geometry: Geometry, atoms: Optional[AtomsArgument] = None): # _sanitize_loop will ensure that atoms will always be an integer if self._compiled_re.match(geometry.atoms[atoms].tag): return self @@ -125,8 +130,9 @@ def __init__(self, *args, **kwargs): idx.add(arg) else: idx.update(arg) - for key_a in ["eq", "in", "contains"]: - for key in [key_a, f"__{key_a}__"]: + + for key_a in ("eq", "in", "contains"): + for key in (key_a, f"__{key_a}__"): arg = kwargs.pop(key, set()) if isinstance(arg, Integral): idx.add(arg) @@ -140,6 +146,7 @@ def make_partial(a, b): """Wrapper to make partial useful""" if isinstance(b, Integral): return op.truth(func(a, b)) + is_true = True for ib in b: is_true = is_true and func(a, ib) @@ -177,7 +184,7 @@ def func_wrap(a, b): ) @_sanitize_loop - def categorize(self, geometry, atoms=None): + def categorize(self, geometry: Geometry, atoms: Optional[AtomsArgument] = None): # _sanitize_loop will ensure that atoms will always be an integer if reduce(op.and_, (f(atoms, b) for f, b in self._op_val), True): return self @@ -224,9 +231,11 @@ class AtomSeq(AtomIndex): the functions used to parse the sequence string into indices. """ + __slots__ = ("_seq",) + def __init__(self, seq): self._seq = seq - self.set_name(self._seq) + self._name = seq @staticmethod def _sanitize_negs(indices_map, end): @@ -255,7 +264,7 @@ def _sanitize(item): return [_sanitize(item) for item in indices_map] - def categorize(self, geometry, *args, **kwargs): + def categorize(self, geometry: Geometry, *args, **kwargs): # Now that we have the geometry, we know what is the end index # and we can finally safely convert the sequence to indices. indices_map = strmap(int, self._seq, start=0, end=geometry.na - 1) @@ -263,7 +272,7 @@ def categorize(self, geometry, *args, **kwargs): # Initialize the machinery of AtomIndex super().__init__(indices) - self.set_name(self._seq) + self.name = self._seq # Finally categorize return super().categorize(geometry, *args, **kwargs) @@ -276,13 +285,13 @@ def __eq__(self, other): class AtomEven(AtomCategory): r"""Classify atoms based on indices (even in this case)""" - __slots__ = [] + __slots__ = () - def __init__(self): - super().__init__("even") + def __init__(self, name="even"): + super().__init__(name) @_sanitize_loop - def categorize(self, geometry, atoms): + def categorize(self, geometry: Geometry, atoms: Optional[AtomsArgument] = None): # _sanitize_loop will ensure that atoms will always be an integer if atoms % 2 == 0: return self @@ -296,13 +305,13 @@ def __eq__(self, other): class AtomOdd(AtomCategory): r"""Classify atoms based on indices (odd in this case)""" - __slots__ = [] + __slots__ = () - def __init__(self): - super().__init__("odd") + def __init__(self, name="odd"): + super().__init__(name) @_sanitize_loop - def categorize(self, geometry, atoms): + def categorize(self, geometry: Geometry, atoms: Optional[AtomsArgument] = None): # _sanitize_loop will ensure that atoms will always be an integer if atoms % 2 == 1: return self diff --git a/src/sisl/geom/_category/_neighbours.py b/src/sisl/geom/_category/_neighbours.py index 7a801563b7..ddecc14502 100644 --- a/src/sisl/geom/_category/_neighbours.py +++ b/src/sisl/geom/_category/_neighbours.py @@ -1,9 +1,15 @@ # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at https://mozilla.org/MPL/2.0/. +from __future__ import annotations + +from typing import Optional + from numpy import ndarray +from sisl._core import Geometry from sisl._internal import set_module +from sisl.typing import AtomsArgument from .base import AtomCategory, NullCategory, _sanitize_loop @@ -82,10 +88,12 @@ def __init__(self, *args, **kwargs): # Determine name. If there are requirements for the neighbours # then the name changes + if self._in is None: - self.set_name(f"neighbours{name}") + name = f"neighbours{name}" else: - self.set_name(f"neighbours({self._in}){name}") + name = f"neighbours({self._in}){name}" + super().__init__(name) def R(self, atom): if self._R is None: @@ -97,9 +105,9 @@ def R(self, atom): return (0.01, self._R) @_sanitize_loop - def categorize(self, geometry, atoms=None): + def categorize(self, geometry: Geometry, atoms: Optional[AtomsArgument] = None): """Check if geometry and atoms matches the neighbour criteria""" - idx = geometry.close(atoms, R=self.R(geometry.atoms[atoms]))[1] + idx = geometry.close(atoms, R=self.R(geometry.atoms[atoms]))[-1] # quick escape the lower bound, in case we have more than max, they could # be limited by the self._in type n = len(idx) @@ -112,6 +120,7 @@ def categorize(self, geometry, atoms=None): cat = self._in.categorize(geometry, geometry.asc2uc(idx)) idx = [i for i, c in zip(idx, cat) if not isinstance(c, NullCategory)] n = len(idx) + if self._min <= n <= self._max: return self return NullCategory() diff --git a/src/sisl/geom/_category/base.py b/src/sisl/geom/_category/base.py index b5806a972f..51cf52b546 100644 --- a/src/sisl/geom/_category/base.py +++ b/src/sisl/geom/_category/base.py @@ -1,19 +1,24 @@ # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at https://mozilla.org/MPL/2.0/. +from __future__ import annotations + from functools import wraps +from typing import Optional from sisl._category import Category, NullCategory -from sisl._core.geometry import AtomCategory +from sisl._core import AtomCategory, Geometry +from sisl.typing import AtomsArgument __all__ = ["NullCategory", "AtomCategory"] def _sanitize_loop(func): @wraps(func) - def loop_func(self, geometry, atoms=None): + def loop_func(self, geometry: Geometry, atoms: Optional[AtomsArgument] = None): if atoms is None: return [func(self, geometry, ia) for ia in geometry] + # extract based on atoms selection atoms = geometry._sanitize_atoms(atoms) if atoms.ndim == 0: @@ -21,8 +26,3 @@ def loop_func(self, geometry, atoms=None): return [func(self, geometry, ia) for ia in atoms] return loop_func - - -# class AtomCategory(Category) -# is defined in sisl._core.geometry.py since it is required in -# that instance. diff --git a/src/sisl/geom/_category/tests/test_geom_category.py b/src/sisl/geom/_category/tests/test_geom_category.py index 4c9214a2c2..401c39167e 100644 --- a/src/sisl/geom/_category/tests/test_geom_category.py +++ b/src/sisl/geom/_category/tests/test_geom_category.py @@ -38,6 +38,10 @@ def test_geom_category(): category = (B & B2) ^ (N & N2) ^ (B & B3) ^ (N & N3) ^ n2 cat = category.categorize(hBN) + cat1 = category.categorize(hBN, atoms=[2, 3]) + assert len(cat1) == 2 + assert cat[2] == cat1[0] + assert cat[3] == cat1[1] def test_geom_category_no_r():