diff --git a/CHANGELOG.md b/CHANGELOG.md index 36a7c1c2a..dd67c0d3d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,6 +43,7 @@ we hit release version 1.0.0. - A new `AtomicMatrixPlot` to plot sparse matrices, #668 ### Fixed +- dispatch methods now searches the mro for best matches, #721 - all `eps` arguments has changed to `atol` - methods with `axis` arguments now accepts the str equivalent 0==a - documentation links to external resources diff --git a/src/sisl/_core/tests/test_lattice.py b/src/sisl/_core/tests/test_lattice.py index cbb3ea1d2..8d6299d7a 100644 --- a/src/sisl/_core/tests/test_lattice.py +++ b/src/sisl/_core/tests/test_lattice.py @@ -243,7 +243,6 @@ def test_creation1(self, setup): assert np.allclose(tmp1.cell, tmp3.cell) assert np.allclose(tmp1.cell, tmp4.cell) - @pytest.mark.xfail(reason="dispatch does not resolve mro (see #721") def test_creation_latticechild_dispatch(self, setup): # full cell class P(LatticeChild): diff --git a/src/sisl/_dispatch_class.py b/src/sisl/_dispatch_class.py index 17d1e76ae..9ec4889ac 100644 --- a/src/sisl/_dispatch_class.py +++ b/src/sisl/_dispatch_class.py @@ -3,7 +3,15 @@ # file, You can obtain one at https://mozilla.org/MPL/2.0/. from __future__ import annotations -""" Internal class used for subclassing. +import logging +from collections import namedtuple +from typing import Any, Optional, Sequence, Union + +from sisl.utils._search_mro import find_implementation + +from ._dispatcher import AbstractDispatcher, ClassDispatcher, TypeDispatcher + +"""Internal class used for subclassing. This class implements the @@ -24,17 +32,11 @@ class A(_Dispatchs, A.hello.register .. """ -import logging -from collections import namedtuple -from typing import Any, Optional, Sequence, Union - -from ._dispatcher import AbstractDispatcher, ClassDispatcher, TypeDispatcher - _log = logging.getLogger(__name__) class _Dispatchs: - """Subclassable for creating the new/to arguments""" + """Subclassable for creating the dispatch arguments""" def __init_subclass__( cls, @@ -47,18 +49,23 @@ def __init_subclass__( super().__init_subclass__(**kwargs) # Get the allowed actions for subclassing - prefix = "_tonew" + prefix = "_cls_dispatchs" allowed_subclassing = ("keep", "new", "copy") def find_base(cls, attr): - for base in cls.__bases__: + # The order of execution, since the implementation search + # is based on MRO, we should search in that order. + for base in cls.__mro__: if hasattr(base, attr): return base return None if dispatchs is None: + # Copy dispatch names when subclassing. + # I.e. we will search through all the previous ones + # and copy them. dispatchs = [] - for base in cls.__bases__: + for base in cls.__mro__: if hasattr(base, f"{prefix}_dispatchs"): dispatchs.extend(getattr(base, f"{prefix}_dispatchs")) diff --git a/src/sisl/_dispatcher.py b/src/sisl/_dispatcher.py index 33c7eec67..bd2f5e2b5 100644 --- a/src/sisl/_dispatcher.py +++ b/src/sisl/_dispatcher.py @@ -14,6 +14,9 @@ from abc import ABCMeta, abstractmethod from collections import ChainMap, namedtuple from functools import update_wrapper +from typing import Any + +from sisl.utils._search_mro import find_implementation __all__ = [ "AbstractDispatch", @@ -120,6 +123,20 @@ def __getattr__(self, key): return attr +def _get_dispatch(dispatcher: AbstractDispatcher, key: Any): + """Return the dispatch contained in `obj._dispatchs`""" + dispatchs = dispatcher._dispatchs + if isinstance(key, type) and key not in dispatchs: + dispatch = find_implementation(key, dispatchs) + # we will register for a faster look-up next time. + dispatcher.register(key, dispatch) + else: + dispatch = dispatchs.get(key) + if dispatch is None: + raise KeyError(f"{dispatcher.__class__.__name__} has no dispatch for {key}.") + return dispatch + + class AbstractDispatcher(metaclass=ABCMeta): """A container for dispatchers @@ -287,7 +304,8 @@ def __getitem__(self, key): _log.debug( f"__getitem__ {self.__class__.__name__},key={key}", extra={"obj": self} ) - return self._dispatchs[key](self._obj, **self._attrs).dispatch(self.__wrapped__) + dispatch = _get_dispatch(self, key) + return dispatch(self._obj, **self._attrs).dispatch(self.__wrapped__) __getattr__ = __getitem__ @@ -429,10 +447,13 @@ def __getitem__(self, key): _log.debug( f"__getitem__ {self.__class__.__name__},key={key}", extra={"obj": self} ) - return self._dispatchs[key](self._obj, **self._attrs) + dispatch = _get_dispatch(self, key) + return dispatch(self._obj, **self._attrs) def __getattr__(self, key): """Retrieve dispatched method by name, or if the name does not exist return a MethodDispatcher""" + # Attribute retrieval will never be a class, so this will be directly + # inferable in the dictionary. if key in self._dispatchs: _log.debug( f"__getattr__ {self.__class__.__name__},dispatch={key}", @@ -515,14 +536,8 @@ def __call__(self, obj, *args, **kwargs): # if you want obj to be a type, then the dispatcher should control that _log.debug(f"call {self.__class__.__name__}{args}", extra={"obj": self}) - return self._dispatchs[typ](self._obj)(obj, *args, **kwargs) - - def __getitem__(self, key): - r"""Retrieve dispatched dispatchs by hash (allows functions to be dispatched)""" - _log.debug( - f"__getitem__ {self.__class__.__name__},key={key}", extra={"obj": self} - ) - return self._dispatchs[key](self._obj, **self._attrs) + dispatch = _get_dispatch(self, typ) + return dispatch(self._obj)(obj, *args, **kwargs) class ClassDispatcher(AbstractDispatcher): diff --git a/src/sisl/utils/_search_mro.py b/src/sisl/utils/_search_mro.py new file mode 100644 index 000000000..da5345099 --- /dev/null +++ b/src/sisl/utils/_search_mro.py @@ -0,0 +1,193 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at https://mozilla.org/MPL/2.0/. +from __future__ import annotations + +from types import GenericAlias + +# +# This code is directly taken from: +# Python standard lib functools (search for `singledispatch` +# and how it is implemented. +# The aim is to provide the same kind of registry effects +# as the `singledispatch` methods does. +# This is important when one does dispatches for subclassed +# classes. +# The reason for copying that code is that it correctly handles +# abc's and when trying to find the best candidate when having +# multiple classes in the registry. If there was always +# single elements in the registry, it would be simpler and this +# complication wouldn't be necessary. +# From here on the LICENSE of the code is that of the PSF License. + + +__all__ = ["find_implementation"] + + +def _c3_merge(sequences): + """Merges MROs in *sequences* to a single MRO using the C3 algorithm. + + Adapted from https://docs.python.org/3/howto/mro.html. + + """ + result = [] + while True: + sequences = [s for s in sequences if s] # purge empty sequences + if not sequences: + return result + for s1 in sequences: # find merge candidates among seq heads + candidate = s1[0] + for s2 in sequences: + if candidate in s2[1:]: + candidate = None + break # reject the current head, it appears later + else: + break + if candidate is None: + raise RuntimeError("Inconsistent hierarchy") + result.append(candidate) + # remove the chosen candidate + for seq in sequences: + if seq[0] == candidate: + del seq[0] + + +def _c3_mro(cls, abcs=None): + """Computes the method resolution order using extended C3 linearization. + + If no *abcs* are given, the algorithm works exactly like the built-in C3 + linearization used for method resolution. + + If given, *abcs* is a list of abstract base classes that should be inserted + into the resulting MRO. Unrelated ABCs are ignored and don't end up in the + result. The algorithm inserts ABCs where their functionality is introduced, + i.e. issubclass(cls, abc) returns True for the class itself but returns + False for all its direct base classes. Implicit ABCs for a given class + (either registered or inferred from the presence of a special method like + __len__) are inserted directly after the last ABC explicitly listed in the + MRO of said class. If two implicit ABCs end up next to each other in the + resulting MRO, their ordering depends on the order of types in *abcs*. + + """ + for i, base in enumerate(reversed(cls.__bases__)): + if hasattr(base, "__abstractmethods__"): + boundary = len(cls.__bases__) - i + break # Bases up to the last explicit ABC are considered first. + else: + boundary = 0 + abcs = list(abcs) if abcs else [] + explicit_bases = list(cls.__bases__[:boundary]) + abstract_bases = [] + other_bases = list(cls.__bases__[boundary:]) + for base in abcs: + if issubclass(cls, base) and not any( + issubclass(b, base) for b in cls.__bases__ + ): + # If *cls* is the class that introduces behaviour described by + # an ABC *base*, insert said ABC to its MRO. + abstract_bases.append(base) + for base in abstract_bases: + abcs.remove(base) + explicit_c3_mros = [_c3_mro(base, abcs=abcs) for base in explicit_bases] + abstract_c3_mros = [_c3_mro(base, abcs=abcs) for base in abstract_bases] + other_c3_mros = [_c3_mro(base, abcs=abcs) for base in other_bases] + return _c3_merge( + [[cls]] + + explicit_c3_mros + + abstract_c3_mros + + other_c3_mros + + [explicit_bases] + + [abstract_bases] + + [other_bases] + ) + + +def _compose_mro(cls, types): + """Calculates the method resolution order for a given class *cls*. + + Includes relevant abstract base classes (with their respective bases) from + the *types* iterable. Uses a modified C3 linearization algorithm. + + """ + bases = set(cls.__mro__) + + # Remove entries which are already present in the __mro__ or unrelated. + def is_related(typ): + return ( + typ not in bases + and hasattr(typ, "__mro__") + and not isinstance(typ, GenericAlias) + and issubclass(cls, typ) + ) + + types = [n for n in types if is_related(n)] + + # Remove entries which are strict bases of other entries (they will end up + # in the MRO anyway. + def is_strict_base(typ): + for other in types: + if typ != other and typ in other.__mro__: + return True + return False + + types = [n for n in types if not is_strict_base(n)] + + # Subclasses of the ABCs in *types* which are also implemented by + # *cls* can be used to stabilize ABC ordering. + type_set = set(types) + mro = [] + for typ in types: + found = [] + for sub in typ.__subclasses__(): + if sub not in bases and issubclass(cls, sub): + found.append([s for s in sub.__mro__ if s in type_set]) + if not found: + mro.append(typ) + continue + + # Favor subclasses with the biggest number of useful bases + found.sort(key=len, reverse=True) + for sub in found: + for subcls in sub: + if subcls not in mro: + mro.append(subcls) + return _c3_mro(cls, abcs=mro) + + +def find_implementation(cls: object, registry: dict): + """Returns the best matching implementation from *registry* for type *cls*. + + Where there is no registered implementation for a specific type, its method + resolution order is used to find a more generic implementation. + + Note: if *registry* does not contain an implementation for the base + *object* type, this function may return None. + + Parameters + ---------- + cls : + the class to search for in the registry. I.e. if `cls` is a subclass + of an entry in `registry`, this will return the value contained + in `registry` for the best match. + registry: + the registry dictionary where we will return the best matching + class + """ + keys = list(registry.keys()) + mro = _compose_mro(cls, keys) + match = None + for t in mro: + if match is not None: + # If *match* is an implicit ABC but there is another unrelated, + # equally matching implicit ABC, refuse the temptation to guess. + if ( + t in keys + and t not in cls.__mro__ + and match not in cls.__mro__ + and not issubclass(match, t) + ): + raise RuntimeError("Ambiguous dispatch: {} or {}".format(match, t)) + break + if t in registry: + match = t + return registry.get(match)