Skip to content

Commit

Permalink
refactor(common): support union types as well as forward references i…
Browse files Browse the repository at this point in the history
…n the dispatch utilities
  • Loading branch information
kszucs committed Jan 3, 2024
1 parent 7755e6a commit e3fad73
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 101 deletions.
143 changes: 96 additions & 47 deletions ibis/common/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@
import re
from collections import defaultdict

from ibis.util import import_object
from ibis.common.typing import (
Union,
UnionType,
evaluate_annotations,
get_args,
get_origin,
)
from ibis.util import import_object, unalias_package


def normalize(r: str | re.Pattern):
Expand All @@ -26,71 +33,113 @@ def normalize(r: str | re.Pattern):
return re.compile("^" + r.lstrip("^").rstrip("$") + "$")


def lazy_singledispatch(func):
"""A `singledispatch` implementation that supports lazily registering implementations."""

lookup = {object: func}
abc_lookup = {}
lazy_lookup = defaultdict(dict)
class SingleDispatch:
def __init__(self, func, typ=None):
self.lookup = {}
self.abc_lookup = {}
self.lazy_lookup = defaultdict(dict)
self.func = func
self.add(func, typ)

def add(self, func, typ=None):
if typ is None:
annots = getattr(func, "__annotations__", {})
typehints = evaluate_annotations(annots, func.__module__, best_effort=True)
if typehints:
typ, *_ = typehints.values()
if get_origin(typ) in (Union, UnionType):
for t in get_args(typ):
self.add(func, t)
else:
self.add(func, typ)
else:
self.add(func, object)
elif isinstance(typ, tuple):
for t in typ:
self.add(func, t)
elif isinstance(typ, abc.ABCMeta):
if typ in self.abc_lookup:
raise TypeError(f"{typ} is already registered")
self.abc_lookup[typ] = func
elif isinstance(typ, str):
package, rest = typ.split(".", 1)
package = unalias_package(package)
typ = f"{package}.{rest}"
if typ in self.lazy_lookup[package]:
raise TypeError(f"{typ} is already registered")
self.lazy_lookup[package][typ] = func
else:
if typ in self.lookup:
raise TypeError(f"{typ} is already registered")
self.lookup[typ] = func
return func

def register(cls, func=None):
"""Registers a new implementation for arguments of type `cls`."""
def register(self, typ, func=None):
"""Register a new implementation for arguments of type `cls`."""

def inner(func):
if isinstance(cls, tuple):
for t in cls:
register(t, func)
elif isinstance(cls, abc.ABCMeta):
abc_lookup[cls] = func
elif isinstance(cls, str):
module = cls.split(".", 1)[0]
lazy_lookup[module][cls] = func
else:
lookup[cls] = func
self.add(func, typ)
return func

return inner if func is None else inner(func)

def dispatch(cls):
def dispatch(self, typ):
"""Return the implementation for the given `cls`."""
for cls2 in cls.__mro__:
for klass in typ.__mro__:
# 1. Check for a concrete implementation
try:
impl = lookup[cls2]
impl = self.lookup[klass]
except KeyError:
pass
else:
if cls is not cls2:
if typ is not klass:
# Cache implementation
lookup[cls] = impl
self.lookup[typ] = impl
return impl
# 2. Check lazy implementations
module = cls2.__module__.split(".", 1)[0]
if lazy := lazy_lookup.get(module):
package = klass.__module__.split(".", 1)[0]
if lazy := self.lazy_lookup.get(package):
# Import all lazy implementations first before registering
# (which should never fail), to ensure an error anywhere
# doesn't result in a half-registered state.
new = {import_object(name): func for name, func in lazy.items()}
lookup.update(new)
self.lookup.update(new)
# drop lazy implementations, idempotent for thread safety
lazy_lookup.pop(module, None)
return dispatch(cls)
self.lazy_lookup.pop(package, None)
return self.dispatch(typ)
# 3. Check for abcs
for abc_cls, impl in abc_lookup.items():
if issubclass(cls, abc_cls):
lookup[cls] = impl
for abc_class, impl in self.abc_lookup.items():
if issubclass(typ, abc_class):
self.lookup[typ] = impl
return impl
# Can never get here, since a base `object` implementation is
# always registered
raise AssertionError("should never get here") # pragma: no cover
raise TypeError(f"Could not find implementation for {typ}")

def __call__(self, arg, *args, **kwargs):
impl = self.dispatch(type(arg))
return impl(arg, *args, **kwargs)

def __get__(self, obj, cls=None):
def _method(*args, **kwargs):
method = self.dispatch(type(args[0]))
method = method.__get__(obj, cls)
return method(*args, **kwargs)

functools.update_wrapper(_method, self.func)
return _method


def lazy_singledispatch(func):
"""A `singledispatch` implementation that supports lazily registering implementations."""

dispatcher = SingleDispatch(func, object)

@functools.wraps(func)
def call(arg, *args, **kwargs):
return dispatch(type(arg))(arg, *args, **kwargs)

call.dispatch = dispatch
call.register = register
impl = dispatcher.dispatch(type(arg))
return impl(arg, *args, **kwargs)

call.dispatch = dispatcher.dispatch
call.register = dispatcher.register
return call


Expand All @@ -117,21 +166,21 @@ def __new__(cls, name, bases, dct):
# multiple functions are defined with the same name, so create
# a dispatcher function
first, *rest = value
func = functools.singledispatchmethod(first)
func = SingleDispatch(first)
for impl in rest:
func.register(impl)
func.add(impl)
namespace[key] = func
elif all(isinstance(v, classmethod) for v in value):
first, *rest = value
func = functools.singledispatchmethod(first.__func__)
for v in rest:
func.register(v.__func__)
func = SingleDispatch(first.__func__)
for impl in rest:
func.add(impl.__func__)
namespace[key] = classmethod(func)
elif all(isinstance(v, staticmethod) for v in value):
first, *rest = value
func = functools.singledispatch(first.__func__)
for v in rest:
func.register(v.__func__)
func = SingleDispatch(first.__func__)
for impl in rest:
func.add(impl.__func__)
namespace[key] = staticmethod(func)
else:
raise TypeError(f"Multiple attributes are defined with name {key}")
Expand Down
34 changes: 20 additions & 14 deletions ibis/common/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
_, # noqa: F401
resolver,
)
from ibis.common.dispatch import lazy_singledispatch
from ibis.common.typing import (
Coercible,
CoercionError,
Expand All @@ -41,7 +40,7 @@
get_bound_typevars,
get_type_params,
)
from ibis.util import is_iterable, promote_tuple
from ibis.util import import_object, is_iterable, unalias_package

T_co = TypeVar("T_co", covariant=True)

Expand Down Expand Up @@ -711,21 +710,28 @@ class LazyInstanceOf(Slotted, Pattern):
The types to check against.
"""

__slots__ = ("types", "check")
types: tuple[type, ...]
check: Callable
__fields__ = ("qualname", "package")
__slots__ = ("qualname", "package", "loaded")
qualname: str
package: str
loaded: type

def __init__(self, types):
types = promote_tuple(types)
check = lazy_singledispatch(lambda x: False)
check.register(types, lambda x: True)
super().__init__(types=types, check=check)
def __init__(self, qualname):
package = unalias_package(qualname.split(".", 1)[0])
super().__init__(qualname=qualname, package=package)

def match(self, value, context):
if self.check(value):
return value
else:
return NoMatch
if hasattr(self, "loaded"):
return value if isinstance(value, self.loaded) else NoMatch

for klass in type(value).__mro__:
package = klass.__module__.split(".", 1)[0]
if package == self.package:
typ = import_object(self.qualname)
object.__setattr__(self, "loaded", typ)
return value if isinstance(value, typ) else NoMatch

return NoMatch


class CoercedTo(Slotted, Pattern, Generic[T_co]):
Expand Down
40 changes: 39 additions & 1 deletion ibis/common/tests/test_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@

import collections
import decimal
from typing import TYPE_CHECKING, Union

from ibis.common.dispatch import Dispatched, lazy_singledispatch

# ruff: noqa: F811
if TYPE_CHECKING:
import pandas as pd
import pyarrow as pa


def test_lazy_singledispatch():
Expand Down Expand Up @@ -122,6 +126,14 @@ def _(a):
assert foo(sum) == "callable"


class A:
pass


class B:
pass


class Visitor(Dispatched):
def a(self):
return "a"
Expand All @@ -132,6 +144,9 @@ def b(self, x: int):
def b(self, x: str):
return "b_str"

def b(self, x: Union[A, B]):
return "b_union"

@classmethod
def c(cls, x: int, **kwargs):
return "c_int"
Expand All @@ -154,6 +169,15 @@ def e(x: int):
def e(x: str):
return "e_str"

def f(self, df: dict):
return "f_dict"

def f(self, df: pd.DataFrame):
return "f_pandas"

def f(self, df: pa.Table):
return "f_pyarrow"


class Subvisitor(Visitor):
def b(self, x):
Expand All @@ -173,9 +197,11 @@ def c(cls, s: float):

def test_dispatched():
v = Visitor()
assert v.a == v.a
assert v.a() == "a"
assert v.b(1) == "b_int"
assert v.b("1") == "b_str"
assert v.b(A()) == "b_union"
assert v.b(B()) == "b_union"
assert v.d(1) == "d_int"
assert v.d("1") == "d_str"

Expand All @@ -193,3 +219,15 @@ def test_dispatched():
assert Subvisitor.c(1.1) == "c_float"

assert Subvisitor.e(1) == "e_int"


def test_dispatched_lazy():
import pyarrow as pa

empty_pyarrow_table = pa.Table.from_arrays([])
empty_pandas_table = empty_pyarrow_table.to_pandas()

v = Visitor()
assert v.f({}) == "f_dict"
assert v.f(empty_pyarrow_table) == "f_pyarrow"
assert v.f(empty_pandas_table) == "f_pandas"
19 changes: 15 additions & 4 deletions ibis/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def evaluate_annotations(
annots: dict[str, str],
module_name: str,
class_name: Optional[str] = None,
best_effort: bool = False,
) -> dict[str, Any]:
"""Evaluate type annotations that are strings.
Expand All @@ -185,6 +186,8 @@ def evaluate_annotations(
class_name
The name of the class that the annotations are defined in, hence
providing Self type.
best_effort
Whether to ignore errors when evaluating type annotations.
Returns
-------
Expand All @@ -202,10 +205,18 @@ def evaluate_annotations(
localns = None
else:
localns = dict(Self=f"{module_name}.{class_name}")
return {
k: eval(v, globalns, localns) if isinstance(v, str) else v # noqa: PGH001
for k, v in annots.items()
}

result = {}
for k, v in annots.items():
if isinstance(v, str):
try:
v = eval(v, globalns, localns) # noqa: PGH001
except NameError:
if not best_effort:
raise
result[k] = v

return result


def format_typehint(typ: Any) -> str:
Expand Down
Loading

0 comments on commit e3fad73

Please sign in to comment.