From e3fad7336566b79dc727f6e80d3f3300931e4860 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Tue, 2 Jan 2024 17:18:09 +0100 Subject: [PATCH] refactor(common): support union types as well as forward references in the dispatch utilities --- ibis/common/dispatch.py | 143 +++++++++++++++++++---------- ibis/common/patterns.py | 34 ++++--- ibis/common/tests/test_dispatch.py | 40 +++++++- ibis/common/typing.py | 19 +++- ibis/expr/api.py | 59 +++++------- ibis/util.py | 14 +++ 6 files changed, 208 insertions(+), 101 deletions(-) diff --git a/ibis/common/dispatch.py b/ibis/common/dispatch.py index 28999b452e37..2b5c4bd5ee0e 100644 --- a/ibis/common/dispatch.py +++ b/ibis/common/dispatch.py @@ -6,7 +6,14 @@ import re from collections import defaultdict -from ibis.util import import_object +from ibis.common.typing import ( + Union, + UnionType, + evaluate_annotations, + get_args, + get_origin, +) +from ibis.util import import_object, unalias_package def normalize(r: str | re.Pattern): @@ -26,71 +33,113 @@ def normalize(r: str | re.Pattern): return re.compile("^" + r.lstrip("^").rstrip("$") + "$") -def lazy_singledispatch(func): - """A `singledispatch` implementation that supports lazily registering implementations.""" - - lookup = {object: func} - abc_lookup = {} - lazy_lookup = defaultdict(dict) +class SingleDispatch: + def __init__(self, func, typ=None): + self.lookup = {} + self.abc_lookup = {} + self.lazy_lookup = defaultdict(dict) + self.func = func + self.add(func, typ) + + def add(self, func, typ=None): + if typ is None: + annots = getattr(func, "__annotations__", {}) + typehints = evaluate_annotations(annots, func.__module__, best_effort=True) + if typehints: + typ, *_ = typehints.values() + if get_origin(typ) in (Union, UnionType): + for t in get_args(typ): + self.add(func, t) + else: + self.add(func, typ) + else: + self.add(func, object) + elif isinstance(typ, tuple): + for t in typ: + self.add(func, t) + elif isinstance(typ, abc.ABCMeta): + if typ in self.abc_lookup: + raise TypeError(f"{typ} is already registered") + self.abc_lookup[typ] = func + elif isinstance(typ, str): + package, rest = typ.split(".", 1) + package = unalias_package(package) + typ = f"{package}.{rest}" + if typ in self.lazy_lookup[package]: + raise TypeError(f"{typ} is already registered") + self.lazy_lookup[package][typ] = func + else: + if typ in self.lookup: + raise TypeError(f"{typ} is already registered") + self.lookup[typ] = func + return func - def register(cls, func=None): - """Registers a new implementation for arguments of type `cls`.""" + def register(self, typ, func=None): + """Register a new implementation for arguments of type `cls`.""" def inner(func): - if isinstance(cls, tuple): - for t in cls: - register(t, func) - elif isinstance(cls, abc.ABCMeta): - abc_lookup[cls] = func - elif isinstance(cls, str): - module = cls.split(".", 1)[0] - lazy_lookup[module][cls] = func - else: - lookup[cls] = func + self.add(func, typ) return func return inner if func is None else inner(func) - def dispatch(cls): + def dispatch(self, typ): """Return the implementation for the given `cls`.""" - for cls2 in cls.__mro__: + for klass in typ.__mro__: # 1. Check for a concrete implementation try: - impl = lookup[cls2] + impl = self.lookup[klass] except KeyError: pass else: - if cls is not cls2: + if typ is not klass: # Cache implementation - lookup[cls] = impl + self.lookup[typ] = impl return impl # 2. 
Check lazy implementations - module = cls2.__module__.split(".", 1)[0] - if lazy := lazy_lookup.get(module): + package = klass.__module__.split(".", 1)[0] + if lazy := self.lazy_lookup.get(package): # Import all lazy implementations first before registering # (which should never fail), to ensure an error anywhere # doesn't result in a half-registered state. new = {import_object(name): func for name, func in lazy.items()} - lookup.update(new) + self.lookup.update(new) # drop lazy implementations, idempotent for thread safety - lazy_lookup.pop(module, None) - return dispatch(cls) + self.lazy_lookup.pop(package, None) + return self.dispatch(typ) # 3. Check for abcs - for abc_cls, impl in abc_lookup.items(): - if issubclass(cls, abc_cls): - lookup[cls] = impl + for abc_class, impl in self.abc_lookup.items(): + if issubclass(typ, abc_class): + self.lookup[typ] = impl return impl - # Can never get here, since a base `object` implementation is - # always registered - raise AssertionError("should never get here") # pragma: no cover + raise TypeError(f"Could not find implementation for {typ}") + + def __call__(self, arg, *args, **kwargs): + impl = self.dispatch(type(arg)) + return impl(arg, *args, **kwargs) + + def __get__(self, obj, cls=None): + def _method(*args, **kwargs): + method = self.dispatch(type(args[0])) + method = method.__get__(obj, cls) + return method(*args, **kwargs) + + functools.update_wrapper(_method, self.func) + return _method + + +def lazy_singledispatch(func): + """A `singledispatch` implementation that supports lazily registering implementations.""" + + dispatcher = SingleDispatch(func, object) @functools.wraps(func) def call(arg, *args, **kwargs): - return dispatch(type(arg))(arg, *args, **kwargs) - - call.dispatch = dispatch - call.register = register + impl = dispatcher.dispatch(type(arg)) + return impl(arg, *args, **kwargs) + call.dispatch = dispatcher.dispatch + call.register = dispatcher.register return call @@ -117,21 +166,21 @@ def __new__(cls, name, bases, dct): # multiple functions are defined with the same name, so create # a dispatcher function first, *rest = value - func = functools.singledispatchmethod(first) + func = SingleDispatch(first) for impl in rest: - func.register(impl) + func.add(impl) namespace[key] = func elif all(isinstance(v, classmethod) for v in value): first, *rest = value - func = functools.singledispatchmethod(first.__func__) - for v in rest: - func.register(v.__func__) + func = SingleDispatch(first.__func__) + for impl in rest: + func.add(impl.__func__) namespace[key] = classmethod(func) elif all(isinstance(v, staticmethod) for v in value): first, *rest = value - func = functools.singledispatch(first.__func__) - for v in rest: - func.register(v.__func__) + func = SingleDispatch(first.__func__) + for impl in rest: + func.add(impl.__func__) namespace[key] = staticmethod(func) else: raise TypeError(f"Multiple attributes are defined with name {key}") diff --git a/ibis/common/patterns.py b/ibis/common/patterns.py index 66d041163547..5ec2a01c3e19 100644 --- a/ibis/common/patterns.py +++ b/ibis/common/patterns.py @@ -30,7 +30,6 @@ _, # noqa: F401 resolver, ) -from ibis.common.dispatch import lazy_singledispatch from ibis.common.typing import ( Coercible, CoercionError, @@ -41,7 +40,7 @@ get_bound_typevars, get_type_params, ) -from ibis.util import is_iterable, promote_tuple +from ibis.util import import_object, is_iterable, unalias_package T_co = TypeVar("T_co", covariant=True) @@ -711,21 +710,28 @@ class LazyInstanceOf(Slotted, Pattern): The types 
to check against. """ - __slots__ = ("types", "check") - types: tuple[type, ...] - check: Callable + __fields__ = ("qualname", "package") + __slots__ = ("qualname", "package", "loaded") + qualname: str + package: str + loaded: type - def __init__(self, types): - types = promote_tuple(types) - check = lazy_singledispatch(lambda x: False) - check.register(types, lambda x: True) - super().__init__(types=types, check=check) + def __init__(self, qualname): + package = unalias_package(qualname.split(".", 1)[0]) + super().__init__(qualname=qualname, package=package) def match(self, value, context): - if self.check(value): - return value - else: - return NoMatch + if hasattr(self, "loaded"): + return value if isinstance(value, self.loaded) else NoMatch + + for klass in type(value).__mro__: + package = klass.__module__.split(".", 1)[0] + if package == self.package: + typ = import_object(self.qualname) + object.__setattr__(self, "loaded", typ) + return value if isinstance(value, typ) else NoMatch + + return NoMatch class CoercedTo(Slotted, Pattern, Generic[T_co]): diff --git a/ibis/common/tests/test_dispatch.py b/ibis/common/tests/test_dispatch.py index 5f34c533851d..5cce61447fc1 100644 --- a/ibis/common/tests/test_dispatch.py +++ b/ibis/common/tests/test_dispatch.py @@ -2,10 +2,14 @@ import collections import decimal +from typing import TYPE_CHECKING, Union from ibis.common.dispatch import Dispatched, lazy_singledispatch # ruff: noqa: F811 +if TYPE_CHECKING: + import pandas as pd + import pyarrow as pa def test_lazy_singledispatch(): @@ -122,6 +126,14 @@ def _(a): assert foo(sum) == "callable" +class A: + pass + + +class B: + pass + + class Visitor(Dispatched): def a(self): return "a" @@ -132,6 +144,9 @@ def b(self, x: int): def b(self, x: str): return "b_str" + def b(self, x: Union[A, B]): + return "b_union" + @classmethod def c(cls, x: int, **kwargs): return "c_int" @@ -154,6 +169,15 @@ def e(x: int): def e(x: str): return "e_str" + def f(self, df: dict): + return "f_dict" + + def f(self, df: pd.DataFrame): + return "f_pandas" + + def f(self, df: pa.Table): + return "f_pyarrow" + class Subvisitor(Visitor): def b(self, x): @@ -173,9 +197,11 @@ def c(cls, s: float): def test_dispatched(): v = Visitor() - assert v.a == v.a + assert v.a() == "a" assert v.b(1) == "b_int" assert v.b("1") == "b_str" + assert v.b(A()) == "b_union" + assert v.b(B()) == "b_union" assert v.d(1) == "d_int" assert v.d("1") == "d_str" @@ -193,3 +219,15 @@ def test_dispatched(): assert Subvisitor.c(1.1) == "c_float" assert Subvisitor.e(1) == "e_int" + + +def test_dispatched_lazy(): + import pyarrow as pa + + empty_pyarrow_table = pa.Table.from_arrays([]) + empty_pandas_table = empty_pyarrow_table.to_pandas() + + v = Visitor() + assert v.f({}) == "f_dict" + assert v.f(empty_pyarrow_table) == "f_pyarrow" + assert v.f(empty_pandas_table) == "f_pandas" diff --git a/ibis/common/typing.py b/ibis/common/typing.py index 170ca2bd2b40..0ae48a2fc7c9 100644 --- a/ibis/common/typing.py +++ b/ibis/common/typing.py @@ -172,6 +172,7 @@ def evaluate_annotations( annots: dict[str, str], module_name: str, class_name: Optional[str] = None, + best_effort: bool = False, ) -> dict[str, Any]: """Evaluate type annotations that are strings. @@ -185,6 +186,8 @@ def evaluate_annotations( class_name The name of the class that the annotations are defined in, hence providing Self type. + best_effort + Whether to ignore errors when evaluating type annotations. 
Returns ------- @@ -202,10 +205,18 @@ def evaluate_annotations( localns = None else: localns = dict(Self=f"{module_name}.{class_name}") - return { - k: eval(v, globalns, localns) if isinstance(v, str) else v # noqa: PGH001 - for k, v in annots.items() - } + + result = {} + for k, v in annots.items(): + if isinstance(v, str): + try: + v = eval(v, globalns, localns) # noqa: PGH001 + except NameError: + if not best_effort: + raise + result[k] = v + + return result def format_typehint(typ: Any) -> str: diff --git a/ibis/expr/api.py b/ibis/expr/api.py index 975db12c69e9..381e0977089d 100644 --- a/ibis/expr/api.py +++ b/ibis/expr/api.py @@ -346,17 +346,6 @@ def table( return ops.UnboundTable(name=name, schema=schema).to_expr() -@lazy_singledispatch -def _memtable( - data, - *, - columns: Iterable[str] | None = None, - schema: SupportsSchema | None = None, - name: str | None = None, -): - raise NotImplementedError(type(data)) - - def memtable( data, *, @@ -444,33 +433,13 @@ def memtable( return _memtable(data, name=name, schema=schema, columns=columns) -@_memtable.register("pyarrow.Table") -def _memtable_from_pyarrow_table( - data: pa.Table, - *, - name: str | None = None, - schema: SupportsSchema | None = None, - columns: Iterable[str] | None = None, -): - from ibis.formats.pyarrow import PyArrowTableProxy - - if columns is not None: - assert schema is None, "if `columns` is not `None` then `schema` must be `None`" - schema = sch.Schema(dict(zip(columns, sch.infer(data).values()))) - return ops.InMemoryTable( - name=name if name is not None else util.gen_name("pyarrow_memtable"), - schema=sch.infer(data) if schema is None else schema, - data=PyArrowTableProxy(data), - ).to_expr() - - -@_memtable.register(object) -def _memtable_from_dataframe( +@lazy_singledispatch +def _memtable( data: pd.DataFrame | Any, *, - name: str | None = None, - schema: SupportsSchema | None = None, columns: Iterable[str] | None = None, + schema: SupportsSchema | None = None, + name: str | None = None, ) -> Table: import pandas as pd @@ -516,6 +485,26 @@ def _memtable_from_dataframe( return op.to_expr() +@_memtable.register("pyarrow.Table") +def _memtable_from_pyarrow_table( + data: pa.Table, + *, + name: str | None = None, + schema: SupportsSchema | None = None, + columns: Iterable[str] | None = None, +): + from ibis.formats.pyarrow import PyArrowTableProxy + + if columns is not None: + assert schema is None, "if `columns` is not `None` then `schema` must be `None`" + schema = sch.Schema(dict(zip(columns, sch.infer(data).values()))) + return ops.InMemoryTable( + name=name if name is not None else util.gen_name("pyarrow_memtable"), + schema=sch.infer(data) if schema is None else schema, + data=PyArrowTableProxy(data), + ).to_expr() + + def _deferred_method_call(expr, method_name): method = operator.methodcaller(method_name) if isinstance(expr, str): diff --git a/ibis/util.py b/ibis/util.py index d9e0bc1544e6..4bec77bb19c8 100644 --- a/ibis/util.py +++ b/ibis/util.py @@ -492,6 +492,20 @@ def backend_entry_points() -> list[importlib.metadata.EntryPoint]: return sorted(eps) +_common_package_aliases = { + "pa": "pyarrow", + "pd": "pandas", + "np": "numpy", + "sk": "sklearn", + "sp": "scipy", + "tf": "tensorflow", +} + + +def unalias_package(name: str) -> str: + return _common_package_aliases.get(name, name) + + def import_object(qualname: str) -> Any: """Attempt to import an object given its full qualname.
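
For reviewers, a short usage sketch of what this refactor enables, distilled from the new tests in ibis/common/tests/test_dispatch.py above. The names here (Red, Blue, Painter, describe) are illustrative only and are not part of the patch; the pyarrow branch is described in comments rather than exercised, to avoid a hard dependency in the sketch.

    from __future__ import annotations

    from typing import TYPE_CHECKING, Union

    from ibis.common.dispatch import Dispatched, lazy_singledispatch

    if TYPE_CHECKING:
        # Only imported for typing; the annotation below stays a string and is
        # resolved lazily at dispatch time.
        import pyarrow as pa


    class Red: ...

    class Blue: ...


    class Painter(Dispatched):
        # A Union annotation registers this implementation for every member type.
        def paint(self, x: Union[Red, Blue]):
            return "colored"

        def paint(self, x: int):
            return "numbered"

        # A forward reference stays a string: "pa" is unaliased to "pyarrow" and
        # the type is imported only when a pyarrow.Table is actually dispatched.
        def paint(self, x: pa.Table):
            return "tabular"


    @lazy_singledispatch
    def describe(obj):
        return "something"


    @describe.register("pandas.DataFrame")  # lazy: pandas is not imported here
    def _(obj):
        return "a dataframe"


    painter = Painter()
    assert painter.paint(Red()) == "colored"
    assert painter.paint(Blue()) == "colored"
    assert painter.paint(1) == "numbered"
    assert describe(object()) == "something"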