Skip to content

Commit

Permalink
fixed MRO resolution for dispatch methods
Browse files Browse the repository at this point in the history
This will fix #721.

Basically I have copied the code from stdlib functools.

Credit is mentioned!

Also cleaned the code a bit.

Signed-off-by: Nick Papior <[email protected]>
  • Loading branch information
zerothi committed Apr 16, 2024
1 parent 8c6d9c8 commit 3475392
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 22 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ we hit release version 1.0.0.
- A new `AtomicMatrixPlot` to plot sparse matrices, #668

### Fixed
- dispatch methods now searches the mro for best matches, #721
- all `eps` arguments has changed to `atol`
- methods with `axis` arguments now accepts the str equivalent 0==a
- documentation links to external resources
Expand Down
1 change: 0 additions & 1 deletion src/sisl/_core/tests/test_lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,6 @@ def test_creation1(self, setup):
assert np.allclose(tmp1.cell, tmp3.cell)
assert np.allclose(tmp1.cell, tmp4.cell)

@pytest.mark.xfail(reason="dispatch does not resolve mro (see #721")
def test_creation_latticechild_dispatch(self, setup):
# full cell
class P(LatticeChild):
Expand Down
29 changes: 18 additions & 11 deletions src/sisl/_dispatch_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
from __future__ import annotations

""" Internal class used for subclassing.
import logging
from collections import namedtuple
from typing import Any, Optional, Sequence, Union

from sisl.utils._search_mro import find_implementation

from ._dispatcher import AbstractDispatcher, ClassDispatcher, TypeDispatcher

"""Internal class used for subclassing.
This class implements the
Expand All @@ -24,17 +32,11 @@ class A(_Dispatchs,
A.hello.register ..
"""

import logging
from collections import namedtuple
from typing import Any, Optional, Sequence, Union

from ._dispatcher import AbstractDispatcher, ClassDispatcher, TypeDispatcher

_log = logging.getLogger(__name__)


class _Dispatchs:
"""Subclassable for creating the new/to arguments"""
"""Subclassable for creating the dispatch arguments"""

def __init_subclass__(
cls,
Expand All @@ -47,18 +49,23 @@ def __init_subclass__(
super().__init_subclass__(**kwargs)

# Get the allowed actions for subclassing
prefix = "_tonew"
prefix = "_cls_dispatchs"
allowed_subclassing = ("keep", "new", "copy")

def find_base(cls, attr):
for base in cls.__bases__:
# The order of execution, since the implementation search
# is based on MRO, we should search in that order.
for base in cls.__mro__:
if hasattr(base, attr):
return base
return None

if dispatchs is None:
# Copy dispatch names when subclassing.
# I.e. we will search through all the previous ones
# and copy them.
dispatchs = []
for base in cls.__bases__:
for base in cls.__mro__:
if hasattr(base, f"{prefix}_dispatchs"):
dispatchs.extend(getattr(base, f"{prefix}_dispatchs"))

Expand Down
35 changes: 25 additions & 10 deletions src/sisl/_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from abc import ABCMeta, abstractmethod
from collections import ChainMap, namedtuple
from functools import update_wrapper
from typing import Any

from sisl.utils._search_mro import find_implementation

__all__ = [
"AbstractDispatch",
Expand Down Expand Up @@ -120,6 +123,20 @@ def __getattr__(self, key):
return attr


def _get_dispatch(dispatcher: AbstractDispatcher, key: Any):
"""Return the dispatch contained in `obj._dispatchs`"""
dispatchs = dispatcher._dispatchs
if isinstance(key, type) and key not in dispatchs:
dispatch = find_implementation(key, dispatchs)
# we will register for a faster look-up next time.
dispatcher.register(key, dispatch)
else:
dispatch = dispatchs.get(key)
if dispatch is None:
raise KeyError(f"{dispatcher.__class__.__name__} has no dispatch for {key}.")
return dispatch


class AbstractDispatcher(metaclass=ABCMeta):
"""A container for dispatchers
Expand Down Expand Up @@ -287,7 +304,8 @@ def __getitem__(self, key):
_log.debug(
f"__getitem__ {self.__class__.__name__},key={key}", extra={"obj": self}
)
return self._dispatchs[key](self._obj, **self._attrs).dispatch(self.__wrapped__)
dispatch = _get_dispatch(self, key)
return dispatch(self._obj, **self._attrs).dispatch(self.__wrapped__)

__getattr__ = __getitem__

Expand Down Expand Up @@ -429,10 +447,13 @@ def __getitem__(self, key):
_log.debug(
f"__getitem__ {self.__class__.__name__},key={key}", extra={"obj": self}
)
return self._dispatchs[key](self._obj, **self._attrs)
dispatch = _get_dispatch(self, key)
return dispatch(self._obj, **self._attrs)

def __getattr__(self, key):
"""Retrieve dispatched method by name, or if the name does not exist return a MethodDispatcher"""
# Attribute retrieval will never be a class, so this will be directly
# inferable in the dictionary.
if key in self._dispatchs:
_log.debug(
f"__getattr__ {self.__class__.__name__},dispatch={key}",
Expand Down Expand Up @@ -515,14 +536,8 @@ def __call__(self, obj, *args, **kwargs):

# if you want obj to be a type, then the dispatcher should control that
_log.debug(f"call {self.__class__.__name__}{args}", extra={"obj": self})
return self._dispatchs[typ](self._obj)(obj, *args, **kwargs)

def __getitem__(self, key):
r"""Retrieve dispatched dispatchs by hash (allows functions to be dispatched)"""
_log.debug(
f"__getitem__ {self.__class__.__name__},key={key}", extra={"obj": self}
)
return self._dispatchs[key](self._obj, **self._attrs)
dispatch = _get_dispatch(self, typ)
return dispatch(self._obj)(obj, *args, **kwargs)


class ClassDispatcher(AbstractDispatcher):
Expand Down
193 changes: 193 additions & 0 deletions src/sisl/utils/_search_mro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
from __future__ import annotations

from types import GenericAlias

#
# This code is directly taken from:
# Python standard lib functools (search for `singledispatch`
# and how it is implemented.
# The aim is to provide the same kind of registry effects
# as the `singledispatch` methods does.
# This is important when one does dispatches for subclassed
# classes.
# The reason for copying that code is that it correctly handles
# abc's and when trying to find the best candidate when having
# multiple classes in the registry. If there was always
# single elements in the registry, it would be simpler and this
# complication wouldn't be necessary.
# From here on the LICENSE of the code is that of the PSF License.


__all__ = ["find_implementation"]


def _c3_merge(sequences):
"""Merges MROs in *sequences* to a single MRO using the C3 algorithm.
Adapted from https://docs.python.org/3/howto/mro.html.
"""
result = []
while True:
sequences = [s for s in sequences if s] # purge empty sequences
if not sequences:
return result
for s1 in sequences: # find merge candidates among seq heads
candidate = s1[0]
for s2 in sequences:
if candidate in s2[1:]:
candidate = None
break # reject the current head, it appears later
else:
break
if candidate is None:
raise RuntimeError("Inconsistent hierarchy")
result.append(candidate)
# remove the chosen candidate
for seq in sequences:
if seq[0] == candidate:
del seq[0]


def _c3_mro(cls, abcs=None):
"""Computes the method resolution order using extended C3 linearization.
If no *abcs* are given, the algorithm works exactly like the built-in C3
linearization used for method resolution.
If given, *abcs* is a list of abstract base classes that should be inserted
into the resulting MRO. Unrelated ABCs are ignored and don't end up in the
result. The algorithm inserts ABCs where their functionality is introduced,
i.e. issubclass(cls, abc) returns True for the class itself but returns
False for all its direct base classes. Implicit ABCs for a given class
(either registered or inferred from the presence of a special method like
__len__) are inserted directly after the last ABC explicitly listed in the
MRO of said class. If two implicit ABCs end up next to each other in the
resulting MRO, their ordering depends on the order of types in *abcs*.
"""
for i, base in enumerate(reversed(cls.__bases__)):
if hasattr(base, "__abstractmethods__"):
boundary = len(cls.__bases__) - i
break # Bases up to the last explicit ABC are considered first.
else:
boundary = 0
abcs = list(abcs) if abcs else []
explicit_bases = list(cls.__bases__[:boundary])
abstract_bases = []
other_bases = list(cls.__bases__[boundary:])
for base in abcs:
if issubclass(cls, base) and not any(
issubclass(b, base) for b in cls.__bases__
):
# If *cls* is the class that introduces behaviour described by
# an ABC *base*, insert said ABC to its MRO.
abstract_bases.append(base)
for base in abstract_bases:
abcs.remove(base)
explicit_c3_mros = [_c3_mro(base, abcs=abcs) for base in explicit_bases]
abstract_c3_mros = [_c3_mro(base, abcs=abcs) for base in abstract_bases]
other_c3_mros = [_c3_mro(base, abcs=abcs) for base in other_bases]
return _c3_merge(
[[cls]]
+ explicit_c3_mros
+ abstract_c3_mros
+ other_c3_mros
+ [explicit_bases]
+ [abstract_bases]
+ [other_bases]
)


def _compose_mro(cls, types):
"""Calculates the method resolution order for a given class *cls*.
Includes relevant abstract base classes (with their respective bases) from
the *types* iterable. Uses a modified C3 linearization algorithm.
"""
bases = set(cls.__mro__)

# Remove entries which are already present in the __mro__ or unrelated.
def is_related(typ):
return (
typ not in bases
and hasattr(typ, "__mro__")
and not isinstance(typ, GenericAlias)
and issubclass(cls, typ)
)

types = [n for n in types if is_related(n)]

# Remove entries which are strict bases of other entries (they will end up
# in the MRO anyway.
def is_strict_base(typ):
for other in types:
if typ != other and typ in other.__mro__:
return True
return False

types = [n for n in types if not is_strict_base(n)]

# Subclasses of the ABCs in *types* which are also implemented by
# *cls* can be used to stabilize ABC ordering.
type_set = set(types)
mro = []
for typ in types:
found = []
for sub in typ.__subclasses__():
if sub not in bases and issubclass(cls, sub):
found.append([s for s in sub.__mro__ if s in type_set])
if not found:
mro.append(typ)
continue

# Favor subclasses with the biggest number of useful bases
found.sort(key=len, reverse=True)
for sub in found:
for subcls in sub:
if subcls not in mro:
mro.append(subcls)
return _c3_mro(cls, abcs=mro)


def find_implementation(cls: object, registry: dict):
"""Returns the best matching implementation from *registry* for type *cls*.
Where there is no registered implementation for a specific type, its method
resolution order is used to find a more generic implementation.
Note: if *registry* does not contain an implementation for the base
*object* type, this function may return None.
Parameters
----------
cls :
the class to search for in the registry. I.e. if `cls` is a subclass
of an entry in `registry`, this will return the value contained
in `registry` for the best match.
registry:
the registry dictionary where we will return the best matching
class
"""
keys = list(registry.keys())
mro = _compose_mro(cls, keys)
match = None
for t in mro:
if match is not None:
# If *match* is an implicit ABC but there is another unrelated,
# equally matching implicit ABC, refuse the temptation to guess.
if (
t in keys
and t not in cls.__mro__
and match not in cls.__mro__
and not issubclass(match, t)
):
raise RuntimeError("Ambiguous dispatch: {} or {}".format(match, t))
break
if t in registry:
match = t
return registry.get(match)

0 comments on commit 3475392

Please sign in to comment.