Skip to content

Commit

Permalink
added typing to AtomCategory
Browse files Browse the repository at this point in the history
Also added test to reduce search space
when categorizing.

Signed-off-by: Nick Papior <[email protected]>
  • Loading branch information
zerothi committed Feb 20, 2024
1 parent e78ef89 commit fdface3
Show file tree
Hide file tree
Showing 8 changed files with 98 additions and 55 deletions.
33 changes: 21 additions & 12 deletions src/sisl/_category.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ def name(self):
r"""Name of category"""
return self._name

def set_name(self, name):
@name.setter
def name(self, name):
r"""Override the name of the categorization"""
self._name = name

Expand Down Expand Up @@ -288,6 +289,8 @@ class GenericCategory(Category):
a specific object in which they act.
"""

__slots__ = ()

@classmethod
def is_class(cls, name):
# never allow one to match a generic class
Expand All @@ -299,10 +302,7 @@ def is_class(cls, name):
class NullCategory(GenericCategory):
r"""Special Null class which always represents a classification not being *anything*"""

__slots__ = tuple()

def __init__(self):
pass
__slots__ = ()

def categorize(self, *args, **kwargs):
return self
Expand All @@ -321,6 +321,12 @@ def _(self, other):
def name(self):
return "∅"

@name.setter
def name(self, name):
raise ValueError(
f"One cannot overwrite the name of a {self.__class__.__name__}"
)


@set_module("sisl.category")
class NotCategory(GenericCategory):
Expand All @@ -331,9 +337,9 @@ class NotCategory(GenericCategory):
def __init__(self, cat):
super().__init__()
if isinstance(cat, CompositeCategory):
self.set_name(f"~({cat})")
self.name = f"~({cat})"
else:
self.set_name(f"~{cat}")
self.name = f"~{cat}"
self._cat = cat

def categorize(self, *args, **kwargs):
Expand Down Expand Up @@ -364,7 +370,7 @@ def _(self, other):


def _composite_name(sep):
def name(self):
def getter(self):
if not self._name is None:
return self._name

Expand All @@ -380,7 +386,10 @@ def name(self):

return f"{nameA} {sep} {nameB}"

return property(name)
def setter(self, name):
self._name = name

return property(getter, setter)


@set_module("sisl.category")
Expand Down Expand Up @@ -426,7 +435,7 @@ class OrCategory(CompositeCategory, composite_name="|"):
the right hand side of the set operation
"""

__slots__ = tuple()
__slots__ = ()

def categorize(self, *args, **kwargs):
r"""Base method for queriyng whether an object is a certain category"""
Expand Down Expand Up @@ -465,7 +474,7 @@ class AndCategory(CompositeCategory, composite_name="&"):
the right hand side of the set operation
"""

__slots__ = tuple()
__slots__ = ()

def categorize(self, *args, **kwargs):
r"""Base method for queriyng whether an object is a certain category"""
Expand Down Expand Up @@ -507,7 +516,7 @@ class XOrCategory(CompositeCategory, composite_name="⊕"):
the right hand side of the set operation
"""

__slots__ = tuple()
__slots__ = ()

def categorize(self, *args, **kwargs):
r"""Base method for queriyng whether an object is a certain category"""
Expand Down
14 changes: 9 additions & 5 deletions src/sisl/_core/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
# Note how we are overwriting the module
@set_module("sisl.geom")
class AtomCategory(Category):
__slots__ = tuple()
__slots__ = ()

@classmethod
def is_class(cls, name, case=True) -> bool:
Expand Down Expand Up @@ -397,9 +397,13 @@ def _(self, atoms: Atom) -> ndarray:

@_sanitize_atoms.register(AtomCategory)
@_sanitize_atoms.register(GenericCategory)
def _(self, atoms: Union[AtomCategory, GenericCategory]) -> ndarray:
def _(
self,
atoms_: Union[AtomCategory, GenericCategory],
atoms: Optional[AtomsArgument] = None,
) -> ndarray:
# First do categorization
cat = atoms.categorize(self)
cat = atoms_.categorize(self, atoms)

def m(cat):
for ia, c in enumerate(cat):
Expand All @@ -412,9 +416,9 @@ def m(cat):
return _a.fromiterl(m(cat))

@_sanitize_atoms.register
def _(self, atoms: dict) -> ndarray:
def _(self, atoms_: dict, atoms: Optional[AtomsArgument] = None) -> ndarray:
# First do categorization
return self._sanitize_atoms(AtomCategory.kw(**atoms))
return self._sanitize_atoms(AtomCategory.kw(**atoms_), atoms)

@_sanitize_atoms.register
def _(self, atoms: Shape) -> ndarray:
Expand Down
6 changes: 5 additions & 1 deletion src/sisl/geom/_category/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.

from .base import *

# isort: split

from ._coord import *
from ._kind import *
from ._neighbours import *
from .base import *
20 changes: 12 additions & 8 deletions src/sisl/geom/_category/_coord.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
from __future__ import annotations

import operator
from functools import wraps
from numbers import Integral
from typing import Optional

import numpy as np

import sisl._array as _a
from sisl._category import CategoryMeta
from sisl._core import Geometry, Lattice, LatticeChild
from sisl._core._lattice import cell_invert
from sisl._core.lattice import Lattice, LatticeChild
from sisl._internal import set_module
from sisl.messages import deprecate_argument
from sisl.shape import *
from sisl.shape import Shape
from sisl.typing import AtomsArgument
from sisl.utils.misc import direction

from .base import AtomCategory, NullCategory
Expand Down Expand Up @@ -55,9 +59,7 @@ class AtomFracSite(AtomCategory):
... assert c == B_site
"""

__slots__ = (
f"_{a}" for a in ("cell", "icell", "length", "atol", "offset", "foffset")
)
__slots__ = ("_cell", "_icell", "_length", "_atol", "_offset", "_foffset")

@deprecate_argument(
"sc", "lattice", "use lattice= instead of sc=", from_version="0.15"
Expand Down Expand Up @@ -87,11 +89,12 @@ def __init__(
f"fracsite(atol={self._atol}, offset={self._offset}, foffset={self._foffset})"
)

def categorize(self, geometry, atoms=None):
def categorize(self, geometry: Geometry, atoms: Optional[AtomsArgument] = None):
# _sanitize_loop will ensure that atoms will always be an integer
if atoms is None:
fxyz = np.dot(geometry.xyz + self._offset, self._icell.T) + self._foffset
else:
atoms = geometry._sanitize_atoms(atoms)
fxyz = (
np.dot(geometry.xyz[atoms].reshape(-1, 3) + self._offset, self._icell.T)
+ self._foffset
Expand Down Expand Up @@ -254,11 +257,12 @@ def func(a, b):
self._coord_check = coord_ops
super().__init__("coord")

def categorize(self, geometry, atoms=None):
def categorize(self, geometry: Geometry, atoms: Optional[AtomsArgument] = None):
if atoms is None:
xyz = geometry.xyz
fxyz = geometry.fxyz
else:
atoms = geometry._sanitize_atoms(atoms)
xyz = geometry.xyz[atoms]
fxyz = geometry.fxyz[atoms]

Expand Down Expand Up @@ -360,7 +364,7 @@ def _apply_key(k, v):
# Create the class for this direction
# note this is lower case since AtomZ should not interfere with Atomz
new_cls = AtomXYZMeta(
f"Atom{name}", (AtomCategory,), {"__new__": _new_factory(key)}
f"Atom{name}", (AtomCategory,), {"__new__": _new_factory(key), "__slots__": ()}
)

new_cls = set_module("sisl.geom")(new_cls)
Expand Down
45 changes: 27 additions & 18 deletions src/sisl/geom/_category/_kind.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
from __future__ import annotations

import operator as op
import re
from functools import reduce, wraps
from numbers import Integral
from typing import Optional, Union

import numpy as np

from sisl._core import Geometry
from sisl._help import isiterable
from sisl._internal import set_module
from sisl.typing import AtomsArgument
from sisl.utils import lstranges, strmap

from .base import AtomCategory, NullCategory, _sanitize_loop
Expand All @@ -29,7 +34,7 @@ class AtomZ(AtomCategory):

__slots__ = ("_Z",)

def __init__(self, Z):
def __init__(self, Z: Union[int, Sequence[int]]):
if isiterable(Z):
self._Z = set(Z)
else:
Expand All @@ -38,7 +43,7 @@ def __init__(self, Z):
super().__init__(f"Z={self._Z}")

@_sanitize_loop
def categorize(self, geometry, atoms=None):
def categorize(self, geometry: Geometry, atoms: Optional[AtomsArgument] = None):
# _sanitize_loop will ensure that atoms will always be an integer
if geometry.atoms.Z[atoms] in self._Z:
return self
Expand Down Expand Up @@ -66,13 +71,13 @@ class AtomTag(AtomCategory):

__slots__ = ("_compiled_re", "_re")

def __init__(self, tag):
def __init__(self, tag: str):
self._re = tag
self._compiled_re = re.compile(self._re)
super().__init__(f"tag={self._re}")

@_sanitize_loop
def categorize(self, geometry, atoms=None):
def categorize(self, geometry: Geometry, atoms: Optional[AtomsArgument] = None):
# _sanitize_loop will ensure that atoms will always be an integer
if self._compiled_re.match(geometry.atoms[atoms].tag):
return self
Expand Down Expand Up @@ -125,8 +130,9 @@ def __init__(self, *args, **kwargs):
idx.add(arg)
else:
idx.update(arg)
for key_a in ["eq", "in", "contains"]:
for key in [key_a, f"__{key_a}__"]:

for key_a in ("eq", "in", "contains"):
for key in (key_a, f"__{key_a}__"):
arg = kwargs.pop(key, set())
if isinstance(arg, Integral):
idx.add(arg)
Expand All @@ -140,6 +146,7 @@ def make_partial(a, b):
"""Wrapper to make partial useful"""
if isinstance(b, Integral):
return op.truth(func(a, b))

is_true = True
for ib in b:
is_true = is_true and func(a, ib)
Expand Down Expand Up @@ -177,7 +184,7 @@ def func_wrap(a, b):
)

@_sanitize_loop
def categorize(self, geometry, atoms=None):
def categorize(self, geometry: Geometry, atoms: Optional[AtomsArgument] = None):
# _sanitize_loop will ensure that atoms will always be an integer
if reduce(op.and_, (f(atoms, b) for f, b in self._op_val), True):
return self
Expand Down Expand Up @@ -224,9 +231,11 @@ class AtomSeq(AtomIndex):
the functions used to parse the sequence string into indices.
"""

__slots__ = ("_seq",)

def __init__(self, seq):
self._seq = seq
self.set_name(self._seq)
self._name = seq

@staticmethod
def _sanitize_negs(indices_map, end):
Expand Down Expand Up @@ -255,15 +264,15 @@ def _sanitize(item):

return [_sanitize(item) for item in indices_map]

def categorize(self, geometry, *args, **kwargs):
def categorize(self, geometry: Geometry, *args, **kwargs):
# Now that we have the geometry, we know what is the end index
# and we can finally safely convert the sequence to indices.
indices_map = strmap(int, self._seq, start=0, end=geometry.na - 1)
indices = lstranges(self._sanitize_negs(indices_map, end=geometry.na - 1))

# Initialize the machinery of AtomIndex
super().__init__(indices)
self.set_name(self._seq)
self.name = self._seq
# Finally categorize
return super().categorize(geometry, *args, **kwargs)

Expand All @@ -276,13 +285,13 @@ def __eq__(self, other):
class AtomEven(AtomCategory):
r"""Classify atoms based on indices (even in this case)"""

__slots__ = []
__slots__ = ()

def __init__(self):
super().__init__("even")
def __init__(self, name="even"):
super().__init__(name)

@_sanitize_loop
def categorize(self, geometry, atoms):
def categorize(self, geometry: Geometry, atoms: Optional[AtomsArgument] = None):
# _sanitize_loop will ensure that atoms will always be an integer
if atoms % 2 == 0:
return self
Expand All @@ -296,13 +305,13 @@ def __eq__(self, other):
class AtomOdd(AtomCategory):
r"""Classify atoms based on indices (odd in this case)"""

__slots__ = []
__slots__ = ()

def __init__(self):
super().__init__("odd")
def __init__(self, name="odd"):
super().__init__(name)

@_sanitize_loop
def categorize(self, geometry, atoms):
def categorize(self, geometry: Geometry, atoms: Optional[AtomsArgument] = None):
# _sanitize_loop will ensure that atoms will always be an integer
if atoms % 2 == 1:
return self
Expand Down
Loading

0 comments on commit fdface3

Please sign in to comment.