diff --git a/pkg/sql/logictest/testdata/logic_test/udf b/pkg/sql/logictest/testdata/logic_test/udf index 7b3ef111928f..8c3f701d2fb6 100644 --- a/pkg/sql/logictest/testdata/logic_test/udf +++ b/pkg/sql/logictest/testdata/logic_test/udf @@ -2812,6 +2812,130 @@ SELECT a, sub_two() FROM sub_all 5 2 6 2 +subtest any_subquery + +statement ok +CREATE TABLE any_tab ( + a INT, + b INT +) + +statement ok +CREATE FUNCTION any_fn(i INT) RETURNS BOOL LANGUAGE SQL AS $$ + SELECT i = ANY(SELECT a FROM any_tab) +$$ + +statement ok +CREATE FUNCTION any_fn_lt(i INT) RETURNS BOOL LANGUAGE SQL AS $$ + SELECT i < ANY(SELECT a FROM any_tab) +$$ + +statement ok +CREATE FUNCTION any_fn_tuple(i INT, j INT) RETURNS BOOL LANGUAGE SQL AS $$ + SELECT (i, j) = ANY(SELECT a, b FROM any_tab) +$$ + +# If the subquery returns no rows, the result should always be false. +query BBB +SELECT any_fn(1), any_fn(4), any_fn(NULL::INT) +---- +false false false + +query BBB +SELECT any_fn_lt(1), any_fn_lt(4), any_fn_lt(NULL::INT) +---- +false false false + +query BBB +SELECT any_fn_tuple(1, 10), any_fn_tuple(1, 20), any_fn_tuple(NULL::INT, NULL::INT) +---- +false false false + +statement ok +INSERT INTO any_tab VALUES (1, 10), (3, 30) + +query BBB +SELECT any_fn(1), any_fn(4), any_fn(NULL::INT) +---- +true false NULL + +query BBB +SELECT any_fn_lt(1), any_fn_lt(4), any_fn_lt(NULL::INT) +---- +true false NULL + +query BBB +SELECT any_fn_tuple(1, 10), any_fn_tuple(1, 20), any_fn_tuple(NULL::INT, NULL::INT) +---- +true false NULL + +statement ok +INSERT INTO any_tab VALUES (NULL, NULL) + +query BBB +SELECT any_fn(1), any_fn(4), any_fn(NULL::INT) +---- +true NULL NULL + +query BBB +SELECT any_fn_lt(1), any_fn_lt(4), any_fn_lt(NULL::INT) +---- +true NULL NULL + +query BBB +SELECT any_fn_tuple(1, 10), any_fn_tuple(1, 20), any_fn_tuple(NULL::INT, NULL::INT) +---- +true NULL NULL + +statement ok +CREATE FUNCTION any_fn2(i INT) RETURNS SETOF INT LANGUAGE SQL AS $$ + SELECT b FROM (VALUES (1), (2), (3), (NULL)) v(b) + WHERE b = ANY (SELECT a FROM any_tab WHERE a <= i) +$$ + +query I +SELECT any_fn2(2) +---- +1 + +query I +SELECT any_fn2(3) +---- +1 +3 + +subtest all_subquery + +statement ok +CREATE TABLE all_tab (a INT) + +statement ok +CREATE FUNCTION all_fn(i INT) RETURNS BOOL LANGUAGE SQL AS $$ + SELECT i = ALL(SELECT a FROM all_tab) +$$ + +# If the subquery returns no rows, the result should always be true. +query BBB +SELECT all_fn(1), all_fn(2), all_fn(NULL::INT) +---- +true true true + +statement ok +INSERT INTO all_tab VALUES (1), (1); + +query BBB +SELECT all_fn(1), all_fn(2), all_fn(NULL::INT) +---- +true false NULL + +statement ok +INSERT INTO all_tab VALUES (NULL); + +query BBB +SELECT all_fn(1), all_fn(2), all_fn(NULL::INT) +---- +NULL false NULL + subtest variadic @@ -2922,10 +3046,10 @@ SELECT oid, proname, pronamespace, proowner, prolang, proleakproof, proisstrict, FROM pg_catalog.pg_proc WHERE proname IN ('f_93314', 'f_93314_alias', 'f_93314_comp', 'f_93314_comp_t') ORDER BY oid; ---- -100264 f_93314 105 1546506610 14 false false false v 0 100263 · {} NULL SELECT i, e FROM test.public.t_93314 ORDER BY i LIMIT 1; -100266 f_93314_alias 105 1546506610 14 false false false v 0 100265 · {} NULL SELECT i, e FROM test.public.t_93314_alias ORDER BY i LIMIT 1; -100270 f_93314_comp 105 1546506610 14 false false false v 0 100267 · {} NULL SELECT (1, 2); -100271 f_93314_comp_t 105 1546506610 14 false false false v 0 100269 · {} NULL SELECT a, c FROM test.public.t_93314_comp LIMIT 1; +100271 f_93314 105 1546506610 14 false false false v 0 100270 · {} NULL SELECT i, e FROM test.public.t_93314 ORDER BY i LIMIT 1; +100273 f_93314_alias 105 1546506610 14 false false false v 0 100272 · {} NULL SELECT i, e FROM test.public.t_93314_alias ORDER BY i LIMIT 1; +100277 f_93314_comp 105 1546506610 14 false false false v 0 100274 · {} NULL SELECT (1, 2); +100278 f_93314_comp_t 105 1546506610 14 false false false v 0 100276 · {} NULL SELECT a, c FROM test.public.t_93314_comp LIMIT 1; # Regression test for #95240. Strict UDFs that are inlined should result in NULL # when presented with NULL arguments. diff --git a/pkg/sql/opt/optbuilder/builder.go b/pkg/sql/opt/optbuilder/builder.go index 794ab1cf96f2..246b1165e4a3 100644 --- a/pkg/sql/opt/optbuilder/builder.go +++ b/pkg/sql/opt/optbuilder/builder.go @@ -125,6 +125,13 @@ type Builder struct { // are disabled and only statements whitelisted are allowed. insideFuncDef bool + // udfDepth tracks the depth of UDFs within which the current expressions + // are being built. It is incremented before building statements in a UDF + // and decremented after all the statements in a UDF have been built. If + // udfDepth is greater than zero, then the builder is currently building + // expressions within one or more UDFs. + udfDepth int + // If set, we are collecting view dependencies in schemaDeps. This can only // happen inside view/function definitions. // diff --git a/pkg/sql/opt/optbuilder/scalar.go b/pkg/sql/opt/optbuilder/scalar.go index c11d5fd41caa..02ca725af773 100644 --- a/pkg/sql/opt/optbuilder/scalar.go +++ b/pkg/sql/opt/optbuilder/scalar.go @@ -687,6 +687,7 @@ func (b *Builder) buildUDF( // Build an expression for each statement in the function body. rels := make(memo.RelListExpr, len(stmts)) isSetReturning := o.Class == tree.GeneratorClass + b.udfDepth++ for i := range stmts { stmtScope := b.buildStmt(stmts[i].AST, nil /* desiredTypes */, bodyScope) expr := stmtScope.expr @@ -768,6 +769,7 @@ func (b *Builder) buildUDF( PhysProps: physProps, } } + b.udfDepth-- out = b.factory.ConstructUDF( args, diff --git a/pkg/sql/opt/optbuilder/subquery.go b/pkg/sql/opt/optbuilder/subquery.go index 01e484415498..34f8608e4218 100644 --- a/pkg/sql/opt/optbuilder/subquery.go +++ b/pkg/sql/opt/optbuilder/subquery.go @@ -372,11 +372,18 @@ func (b *Builder) buildMultiRowSubquery( )) } - // Construct the outer Any(...) operator. - out = b.factory.ConstructAny(input, scalar, &memo.SubqueryPrivate{ - Cmp: cmp, - OriginalExpr: s.Subquery, - }) + if b.udfDepth > 0 { + // Any expressions are cannot be built by the optimizer within a UDF, so + // building them as regular subqueries instead. + out = b.buildAnyAsSubquery(scalar, cmp, input, s.Subquery) + } else { + // Construct the outer Any(...) operator. + out = b.factory.ConstructAny(input, scalar, &memo.SubqueryPrivate{ + Cmp: cmp, + OriginalExpr: s.Subquery, + }) + } + switch c.Operator.Symbol { case treecmp.NotIn, treecmp.All: // NOT Any(...) @@ -386,6 +393,115 @@ func (b *Builder) buildMultiRowSubquery( return out, outScope } +// buildAnyAsSubquery builds an Any expression as a SubqueryExpr. An Any +// expression such as the one below has peculiar behavior. +// +// i ANY () +// +// The logic for evaluating this comparison is: +// +// 1. If the subquery results in zero rows, then the expression evaluates to +// false, even if i is NULL. +// 2. Otherwise, if the comparison between i and any value returned by the +// subquery is true, then the expression evaluates to true. +// 3. Otherwise, if any values returned by the subquery are NULL, then the +// expression evaluates to NULL. +// 4. Otherwise, if i is NULL, then the expression evaluates to NULL. +// 5. Otherwise, the expression evaluates to false. +// +// We use the following transformation to express this logic: +// +// i = ANY (SELECT a FROM t) +// => +// SELECT count > 0 AND (bool_or OR (null_count > 0 AND NULL)) +// FROM ( +// SELECT +// count(*) AS count, +// bool_or(cmp) AS bool_or, +// count(*) FILTER (is_null) AS null_count +// FROM ( +// SELECT a = i AS cmp, a IS NULL AS is_null +// FROM ( +// SELECT a FROM t +// ) +// ) +// ) +func (b *Builder) buildAnyAsSubquery( + left opt.ScalarExpr, op opt.Operator, sub memo.RelExpr, origExpr *tree.Subquery, +) opt.ScalarExpr { + f := b.factory + md := f.Metadata() + subCol := sub.Relational().OutputCols.SingleColumn() + + // Create projections of: + // left subCol + // left IS NULL + cmpCol := md.AddColumn("cmp", types.Bool) + isNullCol := md.AddColumn("is_null", types.Bool) + var isNull opt.ScalarExpr + if md.ColumnMeta(subCol).Type.Family() == types.TupleFamily { + // If the subquery results in a tuple, we must use an IsTupleNullExpr. + isNull = f.ConstructIsTupleNull(f.ConstructVariable(subCol)) + } else { + isNull = f.ConstructIs(f.ConstructVariable(subCol), memo.NullSingleton) + } + projections := memo.ProjectionsExpr{ + f.ConstructProjectionsItem( + b.constructComparisonWithOp(op, left, f.ConstructVariable(subCol)), + cmpCol, + ), + f.ConstructProjectionsItem(isNull, isNullCol), + } + out := f.ConstructProject(sub, projections, opt.ColSet{} /* passthrough */) + + // Create aggregations for: + // count(*) + // bool_or(cmpCol) + // count(*) FILTER (isNullCol) + countCol := md.AddColumn("count", types.Int) + boolOrCol := md.AddColumn("bool_or", types.Bool) + nullCountCol := md.AddColumn("null_count", types.Int) + aggs := memo.AggregationsExpr{ + f.ConstructAggregationsItem(f.ConstructCountRows(), countCol), + f.ConstructAggregationsItem(f.ConstructBoolOr(f.ConstructVariable(cmpCol)), boolOrCol), + f.ConstructAggregationsItem( + f.ConstructAggFilter( + f.ConstructCountRows(), + f.ConstructVariable(isNullCol), + ), + nullCountCol, + ), + } + out = f.ConstructScalarGroupBy(out, aggs, &memo.GroupingPrivate{}) + + // Create a projection of: + // countCol > 0 AND (boolOrCol OR (nullCountCol > 0 AND NULL)) + resCol := md.AddColumn("any", types.Bool) + resultProj := memo.ProjectionsExpr{ + f.ConstructProjectionsItem( + f.ConstructAnd( + f.ConstructGt( + f.ConstructVariable(countCol), + f.ConstructConstVal(tree.DZero, types.Int), + ), + f.ConstructOr( + f.ConstructVariable(boolOrCol), + f.ConstructAnd( + f.ConstructGt( + f.ConstructVariable(nullCountCol), + f.ConstructConstVal(tree.DZero, types.Int), + ), + memo.NullSingleton, + ), + ), + ), + resCol, + ), + } + out = b.factory.ConstructProject(out, resultProj, opt.ColSet{}) + return b.factory.ConstructSubquery(out, &memo.SubqueryPrivate{OriginalExpr: origExpr}) +} + var _ tree.Expr = &subquery{} var _ tree.TypedExpr = &subquery{} diff --git a/pkg/sql/opt/optbuilder/testdata/udf b/pkg/sql/opt/optbuilder/testdata/udf index f7428e7afa17..f6820c3912e7 100644 --- a/pkg/sql/opt/optbuilder/testdata/udf +++ b/pkg/sql/opt/optbuilder/testdata/udf @@ -1329,3 +1329,212 @@ project └── gt ├── variable: a:7 └── variable: i:6 + + +# -------------------------------------------------- +# UDFs with ANY/ALL expressions. +# -------------------------------------------------- + +exec-ddl +CREATE FUNCTION any_fn(i INT) RETURNS BOOL LANGUAGE SQL AS 'SELECT i = ANY(SELECT a FROM abc)' +---- + +build format=show-scalars +SELECT any_fn(10) +---- +project + ├── columns: any_fn:14 + ├── values + │ └── tuple + └── projections + └── udf: any_fn [as=any_fn:14] + ├── params: i:1 + ├── args + │ └── const: 10 + └── body + └── limit + ├── columns: "?column?":13 + ├── project + │ ├── columns: "?column?":13 + │ ├── values + │ │ └── tuple + │ └── projections + │ └── subquery [as="?column?":13] + │ └── project + │ ├── columns: any:12 + │ ├── scalar-group-by + │ │ ├── columns: count:9!null bool_or:10 null_count:11!null + │ │ ├── project + │ │ │ ├── columns: cmp:7 is_null:8!null + │ │ │ ├── project + │ │ │ │ ├── columns: a:2!null + │ │ │ │ └── scan abc + │ │ │ │ └── columns: a:2!null b:3 c:4 crdb_internal_mvcc_timestamp:5 tableoid:6 + │ │ │ └── projections + │ │ │ ├── eq [as=cmp:7] + │ │ │ │ ├── variable: i:1 + │ │ │ │ └── variable: a:2 + │ │ │ └── is [as=is_null:8] + │ │ │ ├── variable: a:2 + │ │ │ └── null + │ │ └── aggregations + │ │ ├── count-rows [as=count:9] + │ │ ├── bool-or [as=bool_or:10] + │ │ │ └── variable: cmp:7 + │ │ └── agg-filter [as=null_count:11] + │ │ ├── count-rows + │ │ └── variable: is_null:8 + │ └── projections + │ └── and [as=any:12] + │ ├── gt + │ │ ├── variable: count:9 + │ │ └── const: 0 + │ └── or + │ ├── variable: bool_or:10 + │ └── and + │ ├── gt + │ │ ├── variable: null_count:11 + │ │ └── const: 0 + │ └── null + └── const: 1 + +exec-ddl +CREATE FUNCTION all_fn(i INT) RETURNS BOOL LANGUAGE SQL AS 'SELECT i < ALL(SELECT a FROM abc WHERE b > i)' +---- + +build format=show-scalars +SELECT all_fn(10) +---- +project + ├── columns: all_fn:14 + ├── values + │ └── tuple + └── projections + └── udf: all_fn [as=all_fn:14] + ├── params: i:1 + ├── args + │ └── const: 10 + └── body + └── limit + ├── columns: "?column?":13 + ├── project + │ ├── columns: "?column?":13 + │ ├── values + │ │ └── tuple + │ └── projections + │ └── not [as="?column?":13] + │ └── subquery + │ └── project + │ ├── columns: any:12 + │ ├── scalar-group-by + │ │ ├── columns: count:9!null bool_or:10 null_count:11!null + │ │ ├── project + │ │ │ ├── columns: cmp:7 is_null:8!null + │ │ │ ├── project + │ │ │ │ ├── columns: a:2!null + │ │ │ │ └── select + │ │ │ │ ├── columns: a:2!null b:3!null c:4 crdb_internal_mvcc_timestamp:5 tableoid:6 + │ │ │ │ ├── scan abc + │ │ │ │ │ └── columns: a:2!null b:3 c:4 crdb_internal_mvcc_timestamp:5 tableoid:6 + │ │ │ │ └── filters + │ │ │ │ └── gt + │ │ │ │ ├── variable: b:3 + │ │ │ │ └── variable: i:1 + │ │ │ └── projections + │ │ │ ├── ge [as=cmp:7] + │ │ │ │ ├── variable: i:1 + │ │ │ │ └── variable: a:2 + │ │ │ └── is [as=is_null:8] + │ │ │ ├── variable: a:2 + │ │ │ └── null + │ │ └── aggregations + │ │ ├── count-rows [as=count:9] + │ │ ├── bool-or [as=bool_or:10] + │ │ │ └── variable: cmp:7 + │ │ └── agg-filter [as=null_count:11] + │ │ ├── count-rows + │ │ └── variable: is_null:8 + │ └── projections + │ └── and [as=any:12] + │ ├── gt + │ │ ├── variable: count:9 + │ │ └── const: 0 + │ └── or + │ ├── variable: bool_or:10 + │ └── and + │ ├── gt + │ │ ├── variable: null_count:11 + │ │ └── const: 0 + │ └── null + └── const: 1 + +exec-ddl +CREATE FUNCTION any_fn_tuple(i INT, j INT) RETURNS BOOL LANGUAGE SQL AS 'SELECT (i, j) = ANY(SELECT a, b FROM abc)' +---- + +build format=show-scalars +SELECT any_fn_tuple(10, 20) +---- +project + ├── columns: any_fn_tuple:16 + ├── values + │ └── tuple + └── projections + └── udf: any_fn_tuple [as=any_fn_tuple:16] + ├── params: i:1 j:2 + ├── args + │ ├── const: 10 + │ └── const: 20 + └── body + └── limit + ├── columns: "?column?":15 + ├── project + │ ├── columns: "?column?":15 + │ ├── values + │ │ └── tuple + │ └── projections + │ └── subquery [as="?column?":15] + │ └── project + │ ├── columns: any:14 + │ ├── scalar-group-by + │ │ ├── columns: count:11!null bool_or:12 null_count:13!null + │ │ ├── project + │ │ │ ├── columns: cmp:9 is_null:10!null + │ │ │ ├── project + │ │ │ │ ├── columns: column8:8 + │ │ │ │ ├── project + │ │ │ │ │ ├── columns: a:3!null b:4 + │ │ │ │ │ └── scan abc + │ │ │ │ │ └── columns: a:3!null b:4 c:5 crdb_internal_mvcc_timestamp:6 tableoid:7 + │ │ │ │ └── projections + │ │ │ │ └── tuple [as=column8:8] + │ │ │ │ ├── variable: a:3 + │ │ │ │ └── variable: b:4 + │ │ │ └── projections + │ │ │ ├── eq [as=cmp:9] + │ │ │ │ ├── tuple + │ │ │ │ │ ├── variable: i:1 + │ │ │ │ │ └── variable: j:2 + │ │ │ │ └── variable: column8:8 + │ │ │ └── is-tuple-null [as=is_null:10] + │ │ │ └── variable: column8:8 + │ │ └── aggregations + │ │ ├── count-rows [as=count:11] + │ │ ├── bool-or [as=bool_or:12] + │ │ │ └── variable: cmp:9 + │ │ └── agg-filter [as=null_count:13] + │ │ ├── count-rows + │ │ └── variable: is_null:10 + │ └── projections + │ └── and [as=any:14] + │ ├── gt + │ │ ├── variable: count:11 + │ │ └── const: 0 + │ └── or + │ ├── variable: bool_or:12 + │ └── and + │ ├── gt + │ │ ├── variable: null_count:13 + │ │ └── const: 0 + │ └── null + └── const: 1