Skip to content

Commit

Permalink
feat(patterns): support building sequences in replacement patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Sep 26, 2023
1 parent 47822c6 commit f320c2e
Show file tree
Hide file tree
Showing 2 changed files with 203 additions and 115 deletions.
243 changes: 131 additions & 112 deletions ibis/common/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def __eq__(self, other):
...

@abstractmethod
def make(self, context: dict):
def build(self, context: dict):
"""Construct a new object from the context.
Parameters
Expand All @@ -337,36 +337,6 @@ def make(self, context: dict):
"""


def builder(obj):
"""Convert an object to a builder.
It encapsulates:
- callable objects into a `Factory` builder
- non-callable objects into a `Just` builder
Parameters
----------
obj
The object to convert to a builder.
Returns
-------
The builder instance.
"""
# TODO(kszucs): the replacer object must be handled differently from patterns
# basically a replacer is just a lazy way to construct objects from the context
# we should have a separate base class for replacers like Variable, Function,
# Just, Apply and Call. Something like Replacer with a specific method e.g.
# apply() could work
if isinstance(obj, Builder):
return obj
elif callable(obj):
# not function but something else
return Factory(obj)
else:
return Just(obj)


class Variable(Slotted, Builder):
"""Retrieve a value from the context.
Expand All @@ -382,7 +352,7 @@ class Variable(Slotted, Builder):
def __init__(self, name):
super().__init__(name=name)

def make(self, context):
def build(self, context):
return context[self]

def __getattr__(self, name):
Expand All @@ -404,11 +374,17 @@ class Just(Slotted, Builder):
__slots__ = ("value",)
value: AnyType

@classmethod
def __create__(cls, value):
if isinstance(value, Just):
return value
return super().__create__(value)

def __init__(self, value):
assert not isinstance(value, (Pattern, Builder))
super().__init__(value=value)

def make(self, context):
def build(self, context):
return self.value


Expand All @@ -434,7 +410,7 @@ def __init__(self, func):
assert callable(func)
super().__init__(func=func)

def make(self, context):
def build(self, context):
value = context[_]
return self.func(value, context)

Expand Down Expand Up @@ -465,9 +441,9 @@ def __init__(self, func, *args, **kwargs):
kwargs = frozendict({k: builder(v) for k, v in kwargs.items()})
super().__init__(func=func, args=args, kwargs=kwargs)

def make(self, context):
args = tuple(arg.make(context) for arg in self.args)
kwargs = {k: v.make(context) for k, v in self.kwargs.items()}
def build(self, context):
args = tuple(arg.build(context) for arg in self.args)
kwargs = {k: v.build(context) for k, v in self.kwargs.items()}
return self.func(*args, **kwargs)

def __call__(self, *args, **kwargs):
Expand All @@ -494,7 +470,7 @@ def namespace(cls, module) -> Namespace:
>>> pattern = c.Negate(x)
>>> pattern
Call(func=<class 'ibis.expr.operations.numeric.Negate'>, args=(Variable(name='x'),), kwargs=FrozenDict({}))
>>> pattern.make({x: 5})
>>> pattern.build({x: 5})
<ibis.expr.operations.numeric.Negate object at 0x...>
"""
return Namespace(cls, module)
Expand Down Expand Up @@ -591,7 +567,16 @@ def match(self, value, context):
# use the `_` reserved variable to record the value being replaced
# in the context, so that it can be used in the replacer pattern
context[_] = value
return self.builder.make(context)
return self.builder.build(context)


def replace(matcher):
"""More convenient syntax for replacing a value with the output of a function."""

def decorator(replacer):
return Replace(matcher, replacer)

return decorator


class Check(Slotted, Pattern):
Expand Down Expand Up @@ -1175,6 +1160,31 @@ def match(self, value, context):
return value


class Between(Slotted, Pattern):
"""Match a value between two bounds.
Parameters
----------
lower
The lower bound.
upper
The upper bound.
"""

__slots__ = ("lower", "upper")
lower: float
upper: float

def __init__(self, lower: float = -math.inf, upper: float = math.inf):
super().__init__(lower=lower, upper=upper)

def match(self, value, context):
if self.lower <= value <= self.upper:
return value
else:
return NoMatch


class Contains(Slotted, Pattern):
"""Pattern that matches if a value contains a given value.
Expand Down Expand Up @@ -1247,7 +1257,8 @@ class SequenceOf(Slotted, Pattern):
item: Pattern
type: type

def __new__(
@classmethod
def __create__(
cls,
item,
type: type = tuple,
Expand All @@ -1264,8 +1275,7 @@ def __new__(
return GenericSequenceOf(
item, type=type, exactly=exactly, at_least=at_least, at_most=at_most
)
else:
return super().__new__(cls)
return super().__create__(item, type=type)

def __init__(self, item, type=tuple):
super().__init__(item=pattern(item), type=type)
Expand Down Expand Up @@ -1311,7 +1321,8 @@ class GenericSequenceOf(Slotted, Pattern):
type: Pattern
length: Length

def __new__(
@classmethod
def __create__(
cls,
item: Pattern,
type: type = tuple,
Expand All @@ -1327,7 +1338,7 @@ def __new__(
):
return SequenceOf(item, type=type)
else:
return super().__new__(cls)
return super().__create__(item, type, exactly, at_least, at_most)

def __init__(
self,
Expand Down Expand Up @@ -1372,11 +1383,11 @@ class TupleOf(Slotted, Pattern):
__slots__ = ("fields",)
fields: tuple[Pattern, ...]

def __new__(cls, fields):
if isinstance(fields, tuple):
return super().__new__(cls)
else:
@classmethod
def __create__(cls, fields):
if not isinstance(fields, tuple):
return SequenceOf(fields, tuple)
return super().__create__(fields)

def __init__(self, fields):
fields = tuple(map(pattern, fields))
Expand Down Expand Up @@ -1490,11 +1501,11 @@ class Object(Slotted, Pattern):
args: tuple[Pattern, ...]
kwargs: FrozenDict[str, Pattern]

def __new__(cls, type, *args, **kwargs):
@classmethod
def __create__(cls, type, *args, **kwargs):
if not args and not kwargs:
return InstanceOf(type)
else:
return super().__new__(cls)
return super().__create__(type, *args, **kwargs)

def __init__(self, type, *args, **kwargs):
type = pattern(type)
Expand Down Expand Up @@ -1704,29 +1715,47 @@ def match(self, value, context):
return dict(zip(keys, values))


class Between(Slotted, Pattern):
"""Match a value between two bounds.
class Topmost(Slotted, Pattern):
"""Traverse the value tree topmost first and match the first value that matches."""

Parameters
----------
lower
The lower bound.
upper
The upper bound.
"""
__slots__ = ("pattern", "filter")
pattern: Pattern
filter: AnyType

__slots__ = ("lower", "upper")
lower: float
upper: float
def __init__(self, searcher, filter=None):
super().__init__(pattern=pattern(searcher), filter=filter)

def __init__(self, lower: float = -math.inf, upper: float = math.inf):
super().__init__(lower=lower, upper=upper)
def match(self, value, context):
result = self.pattern.match(value, context)
if result is not NoMatch:
return result

for child in value.__children__(self.filter):
result = self.match(child, context)
if result is not NoMatch:
return result

return NoMatch


class Innermost(Slotted, Pattern):
# matches items in the innermost layer first, but all matches belong to the same layer
"""Traverse the value tree innermost first and match the first value that matches."""

__slots__ = ("pattern", "filter")
pattern: Pattern
filter: AnyType

def __init__(self, searcher, filter=None):
super().__init__(pattern=pattern(searcher), filter=filter)

def match(self, value, context):
if self.lower <= value <= self.upper:
return value
else:
return NoMatch
for child in value.__children__(self.filter):
result = self.match(child, context)
if result is not NoMatch:
return result

return self.pattern.match(value, context)


def NoneOf(*args) -> Pattern:
Expand All @@ -1749,6 +1778,39 @@ def FrozenDictOf(key_pattern, value_pattern):
return MappingOf(key_pattern, value_pattern, type=frozendict)


def builder(obj):
"""Convert an object to a builder.
It encapsulates:
- callable objects into a `Factory` builder
- non-callable objects into a `Just` builder
Parameters
----------
obj
The object to convert to a builder.
Returns
-------
The builder instance.
"""
if isinstance(obj, Builder):
# already a builder, no need to convert
return obj
elif callable(obj):
# the callable builds the substitution
return Factory(obj)
elif isinstance(obj, Sequence):
# allow nesting builder patterns in tuples/lists
return Call(lambda *args: type(obj)(args), *obj)
elif isinstance(obj, Mapping):
# allow nesting builder patterns in dicts
return Call(type(obj), **obj)
else:
# the object is used as a constant value
return Just(obj)


def pattern(obj: AnyType) -> Pattern:
"""Create a pattern from various types.
Expand Down Expand Up @@ -1837,49 +1899,6 @@ def match(
return NoMatch if result is NoMatch else result


class Topmost(Slotted, Pattern):
"""Traverse the value tree topmost first and match the first value that matches."""

__slots__ = ("pattern", "filter")
pattern: Pattern
filter: AnyType

def __init__(self, searcher, filter=None):
super().__init__(pattern=pattern(searcher), filter=filter)

def match(self, value, context):
result = self.pattern.match(value, context)
if result is not NoMatch:
return result

for child in value.__children__(self.filter):
result = self.match(child, context)
if result is not NoMatch:
return result

return NoMatch


class Innermost(Slotted, Pattern):
# matches items in the innermost layer first, but all matches belong to the same layer
"""Traverse the value tree innermost first and match the first value that matches."""

__slots__ = ("pattern", "filter")
pattern: Pattern
filter: AnyType

def __init__(self, searcher, filter=None):
super().__init__(pattern=pattern(searcher), filter=filter)

def match(self, value, context):
for child in value.__children__(self.filter):
result = self.match(child, context)
if result is not NoMatch:
return result

return self.pattern.match(value, context)


IsTruish = Check(lambda x: bool(x))
IsNumber = InstanceOf(numbers.Number) & ~InstanceOf(bool)
IsString = InstanceOf(str)
Loading

0 comments on commit f320c2e

Please sign in to comment.