From 52b9b5d035c4546a3a0690653f90238818377d08 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Fri, 7 Jun 2024 15:42:30 -0800 Subject: [PATCH] fix: fix and improve shape inference in many ops Usually when using e.g. duckdb's REGEXP_REPLACE() function, the pattern and replacement are always scalar constants. But actually it can accept columnar values for all three of its arguments. These sorts of assumptions appeared to be all over the place in the code base. I found these fixes by grepping for "shape_like" and manually looking at all the instances. I didn't add any test cases; that felt like a monumental task that I wasn't willing to put the time in for. But IDK, I'm not sure I would even want to put in all the tests, it would be SO much boilerplate. To deal with Nodes sometimes having optional args, I modified shape_like() to be more flexible. IDK, that was sort of a lazy approach, we could be much more verbose and have each op do this filtering individually, but I didn't think that was worth it. --- ibis/expr/operations/arrays.py | 8 +++--- ibis/expr/operations/generic.py | 3 +-- ibis/expr/operations/strings.py | 42 ++++++++++++++------------------ ibis/expr/operations/temporal.py | 6 ++--- ibis/expr/rules.py | 1 + 5 files changed, 26 insertions(+), 34 deletions(-) diff --git a/ibis/expr/operations/arrays.py b/ibis/expr/operations/arrays.py index 6d68baab94c3..a297b9322067 100644 --- a/ibis/expr/operations/arrays.py +++ b/ibis/expr/operations/arrays.py @@ -48,7 +48,7 @@ class ArraySlice(Value): stop: Optional[Value[dt.Integer]] = None dtype = rlz.dtype_like("arg") - shape = rlz.shape_like("arg") + shape = rlz.shape_like("args") @public @@ -71,14 +71,12 @@ class ArrayConcat(Value): arg: VarTuple[Value[dt.Array]] + shape = rlz.shape_like("arg") + @attribute def dtype(self): return dt.Array(dt.highest_precedence(arg.dtype.value_type for arg in self.arg)) - @attribute - def shape(self): - return rlz.highest_precedence_shape(self.arg) - @public class ArrayRepeat(Value): diff --git 
a/ibis/expr/operations/generic.py b/ibis/expr/operations/generic.py index cfa3ece1b456..c843d42afeb4 100644 --- a/ibis/expr/operations/generic.py +++ b/ibis/expr/operations/generic.py @@ -321,8 +321,7 @@ def __init__(self, cases, results, default): @attribute def shape(self): - # TODO(kszucs): can be removed after making Sequence iterable - return rlz.highest_precedence_shape(self.cases) + return rlz.highest_precedence_shape((*self.cases, *self.results, self.default)) @attribute def dtype(self): diff --git a/ibis/expr/operations/strings.py b/ibis/expr/operations/strings.py index 45037fc218ee..dce4ec9d5599 100644 --- a/ibis/expr/operations/strings.py +++ b/ibis/expr/operations/strings.py @@ -66,7 +66,7 @@ class Substring(Value): length: Optional[Value[dt.Integer]] = None dtype = dt.string - shape = rlz.shape_like("arg") + shape = rlz.shape_like("args") @public @@ -78,7 +78,7 @@ class StringSlice(Value): end: Optional[Value[dt.Integer]] = None dtype = dt.string - shape = rlz.shape_like("arg") + shape = rlz.shape_like("args") @public @@ -88,7 +88,7 @@ class StrRight(Value): arg: Value[dt.String] nchars: Value[dt.Integer] - shape = rlz.shape_like("arg") + shape = rlz.shape_like("args") dtype = dt.string @@ -112,7 +112,7 @@ class StringFind(Value): start: Optional[Value[dt.Integer]] = None end: Optional[Value[dt.Integer]] = None - shape = rlz.shape_like("arg") + shape = rlz.shape_like("args") dtype = dt.int64 @@ -124,7 +124,7 @@ class Translate(Value): from_str: Value[dt.String] to_str: Value[dt.String] - shape = rlz.shape_like("arg") + shape = rlz.shape_like("args") dtype = dt.string @@ -136,7 +136,7 @@ class LPad(Value): length: Value[dt.Integer] pad: Optional[Value[dt.String]] = None - shape = rlz.shape_like("arg") + shape = rlz.shape_like("args") dtype = dt.string @@ -148,7 +148,7 @@ class RPad(Value): length: Value[dt.Integer] pad: Optional[Value[dt.String]] = None - shape = rlz.shape_like("arg") + shape = rlz.shape_like("args") dtype = dt.string @@ -174,7 +174,7 
@@ class StringJoin(Value): @attribute def shape(self): - return rlz.highest_precedence_shape(self.arg) + return rlz.highest_precedence_shape((self.sep, *self.arg)) @public @@ -196,7 +196,7 @@ class StartsWith(Value): start: Value[dt.String] dtype = dt.boolean - shape = rlz.shape_like("arg") + shape = rlz.shape_like("args") @public @@ -207,7 +207,7 @@ class EndsWith(Value): end: Value[dt.String] dtype = dt.boolean - shape = rlz.shape_like("arg") + shape = rlz.shape_like("args") @public @@ -216,7 +216,7 @@ class FuzzySearch(Value): pattern: Value[dt.String] dtype = dt.boolean - shape = rlz.shape_like("arg") + shape = rlz.shape_like("args") @public @@ -252,7 +252,7 @@ class RegexExtract(Value): pattern: Value[dt.String] index: Value[dt.Integer] - shape = rlz.shape_like("arg") + shape = rlz.shape_like("args") dtype = dt.string @@ -263,12 +263,9 @@ class RegexSplit(Value): arg: Value[dt.String] pattern: Value[dt.String] + shape = rlz.shape_like("args") dtype = dt.Array(dt.string) - @attribute - def shape(self): - return rlz.highest_precedence_shape((self.arg, self.pattern)) - @public class RegexReplace(Value): @@ -278,7 +275,7 @@ class RegexReplace(Value): pattern: Value[dt.String] replacement: Value[dt.String] - shape = rlz.shape_like("arg") + shape = rlz.shape_like("args") dtype = dt.string @@ -290,7 +287,7 @@ class StringReplace(Value): pattern: Value[dt.String] replacement: Value[dt.String] - shape = rlz.shape_like("arg") + shape = rlz.shape_like("args") dtype = dt.string @@ -301,7 +298,7 @@ class StringSplit(Value): arg: Value[dt.String] delimiter: Value[dt.String] - shape = rlz.shape_like("arg") + shape = rlz.shape_like("args") dtype = dt.Array(dt.string) @@ -316,11 +313,8 @@ class StringConcat(Value): @public -class ExtractURLField(Value): - arg: Value[dt.String] - - shape = rlz.shape_like("arg") - dtype = dt.string +class ExtractURLField(StringUnary): + pass @public diff --git a/ibis/expr/operations/temporal.py b/ibis/expr/operations/temporal.py index 
8bdfc34bda76..fa17d6f7c14f 100644 --- a/ibis/expr/operations/temporal.py +++ b/ibis/expr/operations/temporal.py @@ -68,7 +68,7 @@ class Strftime(Value): arg: Value[dt.Temporal] format_str: Value[dt.String] - shape = rlz.shape_like("arg") + shape = rlz.shape_like("args") dtype = dt.string @@ -79,7 +79,7 @@ class StringToTimestamp(Value): arg: Value[dt.String] format_str: Value[dt.String] - shape = rlz.shape_like("arg") + shape = rlz.shape_like("args") dtype = dt.Timestamp(timezone="UTC") @@ -90,7 +90,7 @@ class StringToDate(Value): arg: Value[dt.String] format_str: Value[dt.String] - shape = rlz.shape_like("arg") + shape = rlz.shape_like("args") dtype = dt.date diff --git a/ibis/expr/rules.py b/ibis/expr/rules.py index 0c865297889f..7abf235f1525 100644 --- a/ibis/expr/rules.py +++ b/ibis/expr/rules.py @@ -77,6 +77,7 @@ def shape_like(name): def shape(self): args = getattr(self, name) args = args if util.is_iterable(args) else [args] + args = [a for a in args if a is not None] return highest_precedence_shape(args) return shape