From 491314f906c05faf59283f9e7904927a113d2140 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Sun, 28 Jan 2024 15:40:20 +0100 Subject: [PATCH] fix(common): don't match an `Object` pattern with more positional arguments defined than `__match_args__` has --- .../base/sqlglot/tests/test_compiler.py | 18 +++++++++++++++ ibis/common/patterns.py | 17 ++++++++++---- ibis/common/tests/test_patterns.py | 23 +++++++++++++++++++ 3 files changed, 54 insertions(+), 4 deletions(-) create mode 100644 ibis/backends/base/sqlglot/tests/test_compiler.py diff --git a/ibis/backends/base/sqlglot/tests/test_compiler.py b/ibis/backends/base/sqlglot/tests/test_compiler.py new file mode 100644 index 0000000000000..95db51e76de36 --- /dev/null +++ b/ibis/backends/base/sqlglot/tests/test_compiler.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +import ibis +from ibis import _ + + +def test_window_with_row_number_compiles(): + # GH #8058: the add_order_by_to_empty_ranking_window_functions rule was + # matching on `RankBase` subclasses with a pattern expecting an `arg` + # attribute, which is not present on `RowNumber` + expr = ( + ibis.memtable({"a": range(30)}) + .mutate(id=ibis.row_number()) + .sample(fraction=0.25, seed=0) + .mutate(is_test=_.id.isin(_.id)) + .filter(~_.is_test) + ) + assert ibis.to_sql(expr) diff --git a/ibis/common/patterns.py b/ibis/common/patterns.py index f383ae0043301..6fd0602eade34 100644 --- a/ibis/common/patterns.py +++ b/ibis/common/patterns.py @@ -1264,17 +1264,26 @@ def __create__(cls, type, *args, **kwargs): return InstanceOf(type) return super().__create__(type, *args, **kwargs) - def __init__(self, type, *args, **kwargs): - type = pattern(type) + def __init__(self, typ, *args, **kwargs): + if isinstance(typ, type) and len(typ.__match_args__) < len(args): + raise ValueError( + "The type to match has fewer `__match_args__` than the number " + "of positional arguments in the pattern" + ) + typ = pattern(typ) args = tuple(map(pattern, args)) kwargs = frozendict(toolz.valmap(pattern, kwargs)) - super().__init__(type=type, args=args, kwargs=kwargs) + super().__init__(type=typ, args=args, kwargs=kwargs) def match(self, value, context): if self.type.match(value, context) is NoMatch: return NoMatch - patterns = {**dict(zip(value.__match_args__, self.args)), **self.kwargs} + # the pattern requirest more positional arguments than the object has + if len(value.__match_args__) < len(self.args): + return NoMatch + patterns = dict(zip(value.__match_args__, self.args)) + patterns.update(self.kwargs) fields = {} changed = False diff --git a/ibis/common/tests/test_patterns.py b/ibis/common/tests/test_patterns.py index 29bfe0430efe6..2e66b7d92b255 100644 --- a/ibis/common/tests/test_patterns.py +++ b/ibis/common/tests/test_patterns.py @@ -580,6 +580,8 @@ def test_object_pattern_complex_type(): def test_object_pattern_from_instance_of(): class MyType: + __match_args__ = ("a", "b") + def __init__(self, a, b): self.a = a self.b = b @@ -593,6 +595,8 @@ def __init__(self, a, b): def test_object_pattern_from_coerced_to(): class MyCoercibleType(Coercible): + __match_args__ = ("a", "b") + def __init__(self, a, b): self.a = a self.b = b @@ -651,6 +655,25 @@ def test_object_pattern_matching_dictionary_field(): assert match(pattern, d) is d +def test_object_pattern_requires_its_arguments_to_match(): + class Empty: + __match_args__ = () + + msg = "The type to match has fewer `__match_args__`" + with pytest.raises(ValueError, match=msg): + Object(Empty, 1) + + # if the type matcher (first argument of Object) receives a generic pattern + # instead of an explicit type, the validation above cannot occur, so test + # the the pattern still doesn't match when it requires more positional + # arguments than the object `__match_args__` has + pattern = Object(InstanceOf(Empty), var("a")) + assert match(pattern, Empty()) is NoMatch + + pattern = Object(InstanceOf(Empty), a=var("a")) + assert match(pattern, Empty()) is NoMatch + + def test_callable_with(): def func(a, b): return str(a) + b