Skip to content

Commit

Permalink
fix(common): don't match an Object pattern with more positional arg…
Browse files Browse the repository at this point in the history
…uments defined than `__match_args__` has
  • Loading branch information
kszucs committed Feb 6, 2024
1 parent e4b6b70 commit bbded29
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 4 deletions.
18 changes: 18 additions & 0 deletions ibis/backends/base/sqlglot/tests/test_compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from __future__ import annotations

import ibis
from ibis import _


def test_window_with_row_number_compiles():
# GH #8058: the add_order_by_to_empty_ranking_window_functions rule was
# matching on `RankBase` subclasses with a pattern expecting an `arg`
# attribute, which is not present on `RowNumber`
expr = (
ibis.memtable({"a": range(30)})
.mutate(id=ibis.row_number())
.sample(fraction=0.25, seed=0)
.mutate(is_test=_.id.isin(_.id))
.filter(~_.is_test)
)
assert ibis.to_sql(expr)
17 changes: 13 additions & 4 deletions ibis/common/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -1264,17 +1264,26 @@ def __create__(cls, type, *args, **kwargs):
return InstanceOf(type)
return super().__create__(type, *args, **kwargs)

def __init__(self, type, *args, **kwargs):
type = pattern(type)
def __init__(self, typ, *args, **kwargs):
if isinstance(typ, type) and len(typ.__match_args__) < len(args):
raise ValueError(
"The type to match has fewer `__match_args__` than the number "
"of positional arguments in the pattern"
)
typ = pattern(typ)
args = tuple(map(pattern, args))
kwargs = frozendict(toolz.valmap(pattern, kwargs))
super().__init__(type=type, args=args, kwargs=kwargs)
super().__init__(type=typ, args=args, kwargs=kwargs)

def match(self, value, context):
if self.type.match(value, context) is NoMatch:
return NoMatch

patterns = {**dict(zip(value.__match_args__, self.args)), **self.kwargs}
# the pattern requirest more positional arguments than the object has
if len(value.__match_args__) < len(self.args):
return NoMatch
patterns = dict(zip(value.__match_args__, self.args))
patterns.update(self.kwargs)

fields = {}
changed = False
Expand Down
23 changes: 23 additions & 0 deletions ibis/common/tests/test_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,8 @@ def test_object_pattern_complex_type():

def test_object_pattern_from_instance_of():
class MyType:
__match_args__ = ("a", "b")

def __init__(self, a, b):
self.a = a
self.b = b
Expand All @@ -593,6 +595,8 @@ def __init__(self, a, b):

def test_object_pattern_from_coerced_to():
class MyCoercibleType(Coercible):
__match_args__ = ("a", "b")

def __init__(self, a, b):
self.a = a
self.b = b
Expand Down Expand Up @@ -651,6 +655,25 @@ def test_object_pattern_matching_dictionary_field():
assert match(pattern, d) is d


def test_object_pattern_requires_its_arguments_to_match():
class Empty:
__match_args__ = ()

msg = "The type to match has fewer `__match_args__`"
with pytest.raises(ValueError, match=msg):
Object(Empty, 1)

# if the type matcher (first argument of Object) receives a generic pattern
# instead of an explicit type, the validation above cannot occur, so test
# the the pattern still doesn't match when it requires more positional
# arguments than the object `__match_args__` has
pattern = Object(InstanceOf(Empty), var("a"))
assert match(pattern, Empty()) is NoMatch

pattern = Object(InstanceOf(Empty), a=var("a"))
assert match(pattern, Empty()) is NoMatch


def test_callable_with():
def func(a, b):
return str(a) + b
Expand Down

0 comments on commit bbded29

Please sign in to comment.