diff --git a/ibis/common/patterns.py b/ibis/common/patterns.py index f5db3b47e1e5..7cccc3ff9204 100644 --- a/ibis/common/patterns.py +++ b/ibis/common/patterns.py @@ -6,7 +6,6 @@ from collections.abc import Callable, Mapping, Sequence from enum import Enum from inspect import Parameter -from itertools import chain from typing import ( Annotated, ForwardRef, @@ -161,20 +160,18 @@ def from_typehint(cls, annot: type, allow_coercion: bool = True) -> Pattern: # variadic tuples differently, e.g. tuple[int, ...] is a variadic # tuple of integers, while tuple[int] is a tuple with a single int first, *rest = args - # TODO(kszucs): consider to support the same SequenceOf path if args - # has a single element, e.g. tuple[int] since annotation a single - # element tuple is not common OR use typing.Sequence for annotating - # instead of tuple[T, ...] OR have a VarTupleOf pattern if rest == [Ellipsis]: - inners = cls.from_typehint(first) + return TupleOf(cls.from_typehint(first)) else: - inners = tuple(map(cls.from_typehint, args)) - return TupleOf(inners) + return PatternList(map(cls.from_typehint, args), type=origin) elif issubclass(origin, Sequence): # construct a validator for the sequence elements where all elements # must be of the same type, e.g. Sequence[int] is a sequence of ints (value_inner,) = map(cls.from_typehint, args) - return SequenceOf(value_inner, type=origin) + if allow_coercion and issubclass(origin, Coercible): + return GenericSequenceOf(value_inner, type=origin) + else: + return SequenceOf(value_inner, type=origin) elif issubclass(origin, Mapping): # construct a validator for the mapping keys and values, e.g. # Mapping[str, int] is a mapping with string keys and int values @@ -277,6 +274,9 @@ def __rmatmul__(self, name: str) -> Capture: """ return Capture(name, self) + def __iter__(self) -> SomeOf: + yield SomeOf(self) + class Is(Slotted, Pattern): """Pattern that matches a value against a reference value. @@ -419,9 +419,6 @@ def match(self, value, context): return NoMatch -If = Check - - class DeferredCheck(Slotted, Pattern): __slots__ = ("resolver",) resolver: Resolver @@ -495,9 +492,6 @@ def describe(self, plural=False): return repr(self.value) -Eq = EqualTo - - class DeferredEqualTo(Slotted, Pattern): """Pattern that checks a value equals to the given value. @@ -777,9 +771,6 @@ def __call__(self, *args, **kwargs): return Object(self.type, *args, **kwargs) -As = CoercedTo - - class GenericCoercedTo(Slotted, Pattern): """Force a value to have a particular generic Python type. @@ -1069,9 +1060,6 @@ def match(self, value, context): return NoMatch -In = IsIn - - class SequenceOf(Slotted, Pattern): """Pattern that matches if all of the items in a sequence match a given pattern. @@ -1091,27 +1079,7 @@ class SequenceOf(Slotted, Pattern): item: Pattern type: type - @classmethod - def __create__( - cls, - item, - type: type = tuple, - exactly: Optional[int] = None, - at_least: Optional[int] = None, - at_most: Optional[int] = None, - ): - if ( - exactly is not None - or at_least is not None - or at_most is not None - or issubclass(type, Coercible) - ): - return GenericSequenceOf( - item, type=type, exactly=exactly, at_least=at_least, at_most=at_most - ) - return super().__create__(item, type=type) - - def __init__(self, item, type=tuple): + def __init__(self, item, type=list): super().__init__(item=pattern(item), type=type) def describe(self, plural=False): @@ -1123,12 +1091,16 @@ def match(self, values, context): if not is_iterable(values): return NoMatch - result = [] - for item in values: - item = self.item.match(item, context) - if item is NoMatch: - return NoMatch - result.append(item) + if self.item == _any: + # optimization to avoid unnecessary iteration + result = values + else: + result = [] + for item in values: + item = self.item.match(item, context) + if item is NoMatch: + return NoMatch + result.append(item) return self.type(result) @@ -1141,7 +1113,7 @@ class GenericSequenceOf(Slotted, Pattern): item The pattern to match against each item in the sequence. type - The type to coerce the sequence to. Defaults to tuple. + The type to coerce the sequence to. Defaults to list. exactly The exact length of the sequence. at_least @@ -1155,48 +1127,33 @@ class GenericSequenceOf(Slotted, Pattern): type: Pattern length: Length - @classmethod - def __create__( - cls, - item: Pattern, - type: type = tuple, - exactly: Optional[int] = None, - at_least: Optional[int] = None, - at_most: Optional[int] = None, - ): - if ( - exactly is None - and at_least is None - and at_most is None - and not issubclass(type, Coercible) - ): - return SequenceOf(item, type=type) - else: - return super().__create__(item, type, exactly, at_least, at_most) - def __init__( self, item: Pattern, - type: type = tuple, + type: type = list, exactly: Optional[int] = None, at_least: Optional[int] = None, at_most: Optional[int] = None, ): item = pattern(item) type = CoercedTo(type) - length = Length(at_least=at_least, at_most=at_most) + length = Length(exactly=exactly, at_least=at_least, at_most=at_most) super().__init__(item=item, type=type, length=length) def match(self, values, context): if not is_iterable(values): return NoMatch - result = [] - for value in values: - value = self.item.match(value, context) - if value is NoMatch: - return NoMatch - result.append(value) + if self.item == _any: + # optimization to avoid unnecessary iteration + result = values + else: + result = [] + for value in values: + value = self.item.match(value, context) + if value is NoMatch: + return NoMatch + result.append(value) result = self.type.match(result, context) if result is NoMatch: @@ -1205,52 +1162,6 @@ def match(self, values, context): return self.length.match(result, context) -class TupleOf(Slotted, Pattern): - """Pattern that matches if the respective items in a tuple match the given patterns. - - Parameters - ---------- - fields - The patterns to match the respective items in the tuple. - """ - - __slots__ = ("fields",) - fields: tuple[Pattern, ...] - - @classmethod - def __create__(cls, fields): - if not isinstance(fields, tuple): - return SequenceOf(fields, tuple) - return super().__create__(fields) - - def __init__(self, fields): - fields = tuple(map(pattern, fields)) - super().__init__(fields=fields) - - def describe(self, plural=False): - fields = ", ".join(f.describe(plural=False) for f in self.fields) - if plural: - return f"tuples of ({fields})" - else: - return f"a tuple of ({fields})" - - def match(self, values, context): - if not is_iterable(values): - return NoMatch - - if len(values) != len(self.fields): - return NoMatch - - result = [] - for pattern, value in zip(self.fields, values): - value = pattern.match(value, context) - if value is NoMatch: - return NoMatch - result.append(value) - - return tuple(result) - - class GenericMappingOf(Slotted, Pattern): """Pattern that matches if all of the keys and values match the given patterns. @@ -1351,7 +1262,7 @@ def match(self, value, context): if self.type.match(value, context) is NoMatch: return NoMatch - patterns = {**self.kwargs, **dict(zip(value.__match_args__, self.args))} + patterns = {**dict(zip(value.__match_args__, self.args)), **self.kwargs} fields = {} changed = False @@ -1443,47 +1354,141 @@ def match(self, value, context): return fn -class PatternSequence(Slotted, Pattern): - # TODO(kszucs): add a length optimization to not even try to match if the - # length of the sequence is lower than the length of the pattern sequence +class SomeOf(Slotted, Pattern): + __slots__ = ("pattern", "delimiter") - __slots__ = ("pattern_window",) - pattern_window: tuple[tuple[Pattern, Pattern], ...] + @classmethod + def __create__(cls, *args, **kwargs): + if len(args) == 1: + return super().__create__(*args, **kwargs) + else: + return SomeChunksOf(*args, **kwargs) - def __init__(self, patterns): - current_patterns = [ - SequenceOf(_any) if p is Ellipsis else pattern(p) for p in patterns - ] - following_patterns = chain(current_patterns[1:], [Not(_any)]) - pattern_window = tuple(zip(current_patterns, following_patterns)) - super().__init__(pattern_window=pattern_window) + def __init__(self, item, **kwargs): + pattern = GenericSequenceOf(item, **kwargs) + delimiter = pattern.item + super().__init__(pattern=pattern, delimiter=delimiter) - def match(self, value, context): - it = RewindableIterator(value) - result = [] + def match(self, values, context): + return self.pattern.match(values, context) - if not self.pattern_window: - try: - next(it) - except StopIteration: - return result + +class SomeChunksOf(Slotted, Pattern): + """Pattern that unpacks a value into its elements. + + Designed to be used inside a `PatternList` pattern with the `*` syntax. + """ + + __slots__ = ("pattern", "delimiter") + + def __init__(self, *args, **kwargs): + pattern = GenericSequenceOf(PatternList(args), **kwargs) + delimiter = pattern.item.patterns[0] + super().__init__(pattern=pattern, delimiter=delimiter) + + def chunk(self, values, context): + chunk = [] + for item in values: + if self.delimiter.match(item, context) is NoMatch: + chunk.append(item) else: + if chunk: # only yield if there are items in the chunk + yield chunk + chunk = [item] # start a new chunk with the delimiter + if chunk: + yield chunk + + def match(self, values, context): + chunks = self.chunk(values, context) + result = self.pattern.match(chunks, context) + if result is NoMatch: + return NoMatch + else: + return sum(result, []) + + +def _maybe_unwrap_capture(obj): + return obj.pattern if isinstance(obj, Capture) else obj + + +class PatternList(Slotted, Pattern): + """Pattern that matches if the respective items in a tuple match the given patterns. + + Parameters + ---------- + fields + The patterns to match the respective items in the tuple. + """ + + __slots__ = ("patterns", "type") + patterns: tuple[Pattern, ...] + type: type + + @classmethod + def __create__(cls, patterns, type=list): + if patterns == (): + return EqualTo(patterns) + + patterns = tuple(map(pattern, patterns)) + for pat in patterns: + pat = _maybe_unwrap_capture(pat) + if isinstance(pat, (SomeOf, SomeChunksOf)): + return VariadicPatternList(patterns, type) + + return super().__create__(patterns, type) + + def __init__(self, patterns, type): + super().__init__(patterns=patterns, type=type) + + def describe(self, plural=False): + patterns = ", ".join(f.describe(plural=False) for f in self.patterns) + if plural: + return f"tuples of ({patterns})" + else: + return f"a tuple of ({patterns})" + + def match(self, values, context): + if not is_iterable(values): + return NoMatch + + if len(values) != len(self.patterns): + return NoMatch + + result = [] + for pattern, value in zip(self.patterns, values): + value = pattern.match(value, context) + if value is NoMatch: return NoMatch + result.append(value) - for current, following in self.pattern_window: - original = current + return self.type(result) + + +class VariadicPatternList(Slotted, Pattern): + __slots__ = ("patterns", "type") + patterns: tuple[Pattern, ...] + type: type + + def __init__(self, patterns, type=list): + patterns = tuple(map(pattern, patterns)) + super().__init__(patterns=patterns, type=type) + + def match(self, value, context): + if not self.patterns: + return NoMatch if value else [] + + it = RewindableIterator(value) + result = [] - if isinstance(current, Capture): - current = current.pattern - if isinstance(following, Capture): - following = following.pattern + following_patterns = self.patterns[1:] + (Nothing(),) + for current, following in zip(self.patterns, following_patterns): + original = current + current = _maybe_unwrap_capture(current) + following = _maybe_unwrap_capture(following) - if isinstance(current, (SequenceOf, GenericSequenceOf, PatternSequence)): - if isinstance(following, (SequenceOf, GenericSequenceOf)): - following = following.item - elif isinstance(following, PatternSequence): - # first pattern to match from the pattern window - following = following.pattern_window[0][0] + if isinstance(current, (SomeOf, SomeChunksOf)): + if isinstance(following, (SomeOf, SomeChunksOf)): + following = following.delimiter matches = [] while True: @@ -1517,32 +1522,7 @@ def match(self, value, context): else: result.append(res) - return result - - -class PatternMapping(Slotted, Pattern): - __slots__ = ("keys", "values") - keys: PatternSequence - values: PatternSequence - - def __init__(self, patterns): - keys = PatternSequence(list(map(pattern, patterns.keys()))) - values = PatternSequence(list(map(pattern, patterns.values()))) - super().__init__(keys=keys, values=values) - - def match(self, value, context): - if not isinstance(value, Mapping): - return NoMatch - - keys = value.keys() - if (keys := self.keys.match(keys, context)) is NoMatch: - return NoMatch - - values = value.values() - if (values := self.values.match(values, context)) is NoMatch: - return NoMatch - - return dict(zip(keys, values)) + return self.type(result) def NoneOf(*args) -> Pattern: @@ -1555,6 +1535,11 @@ def ListOf(pattern): return SequenceOf(pattern, type=list) +def TupleOf(pattern): + """Match a variable-length tuple of items matching the given pattern.""" + return SequenceOf(pattern, type=tuple) + + def DictOf(key_pattern, value_pattern): """Match a dictionary with keys and values matching the given patterns.""" return MappingOf(key_pattern, value_pattern, type=dict) @@ -1601,13 +1586,16 @@ def pattern(obj: AnyType) -> Pattern: elif isinstance(obj, (Deferred, Resolver)): return Capture(obj) elif isinstance(obj, Mapping): - return PatternMapping(obj) + raise TypeError("Cannot create a pattern from a mapping") + elif isinstance(obj, Sequence): + if isinstance(obj, (str, bytes)): + return EqualTo(obj) + else: + return PatternList(obj, type=type(obj)) elif isinstance(obj, type): return InstanceOf(obj) elif get_origin(obj): return Pattern.from_typehint(obj, allow_coercion=False) - elif is_iterable(obj): - return PatternSequence(obj) elif callable(obj): return Custom(obj) else: @@ -1657,3 +1645,9 @@ def match( IsTruish = Check(lambda x: bool(x)) IsNumber = InstanceOf(numbers.Number) & ~InstanceOf(bool) IsString = InstanceOf(str) + +As = CoercedTo +Eq = EqualTo +In = IsIn +If = Check +Some = SomeOf