Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

opt: build ANY expressions as regular subqueries within UDFs #98375

Merged
merged 2 commits into from
Mar 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 128 additions & 4 deletions pkg/sql/logictest/testdata/logic_test/udf
Original file line number Diff line number Diff line change
Expand Up @@ -2812,6 +2812,130 @@ SELECT a, sub_two() FROM sub_all
5 2
6 2

subtest any_subquery

statement ok
CREATE TABLE any_tab (
a INT,
b INT
)

statement ok
CREATE FUNCTION any_fn(i INT) RETURNS BOOL LANGUAGE SQL AS $$
SELECT i = ANY(SELECT a FROM any_tab)
$$

statement ok
CREATE FUNCTION any_fn_lt(i INT) RETURNS BOOL LANGUAGE SQL AS $$
SELECT i < ANY(SELECT a FROM any_tab)
$$

statement ok
CREATE FUNCTION any_fn_tuple(i INT, j INT) RETURNS BOOL LANGUAGE SQL AS $$
SELECT (i, j) = ANY(SELECT a, b FROM any_tab)
$$

# If the subquery returns no rows, the result should always be false.
query BBB
SELECT any_fn(1), any_fn(4), any_fn(NULL::INT)
----
false false false

query BBB
SELECT any_fn_lt(1), any_fn_lt(4), any_fn_lt(NULL::INT)
----
false false false

query BBB
SELECT any_fn_tuple(1, 10), any_fn_tuple(1, 20), any_fn_tuple(NULL::INT, NULL::INT)
----
false false false

statement ok
INSERT INTO any_tab VALUES (1, 10), (3, 30)

query BBB
SELECT any_fn(1), any_fn(4), any_fn(NULL::INT)
----
true false NULL

query BBB
SELECT any_fn_lt(1), any_fn_lt(4), any_fn_lt(NULL::INT)
----
true false NULL

query BBB
SELECT any_fn_tuple(1, 10), any_fn_tuple(1, 20), any_fn_tuple(NULL::INT, NULL::INT)
----
true false NULL

statement ok
INSERT INTO any_tab VALUES (NULL, NULL)

query BBB
SELECT any_fn(1), any_fn(4), any_fn(NULL::INT)
----
true NULL NULL

query BBB
SELECT any_fn_lt(1), any_fn_lt(4), any_fn_lt(NULL::INT)
----
true NULL NULL

query BBB
SELECT any_fn_tuple(1, 10), any_fn_tuple(1, 20), any_fn_tuple(NULL::INT, NULL::INT)
----
true NULL NULL

statement ok
CREATE FUNCTION any_fn2(i INT) RETURNS SETOF INT LANGUAGE SQL AS $$
SELECT b FROM (VALUES (1), (2), (3), (NULL)) v(b)
WHERE b = ANY (SELECT a FROM any_tab WHERE a <= i)
$$

query I
SELECT any_fn2(2)
----
1

query I rowsort
SELECT any_fn2(3)
----
1
3

subtest all_subquery

statement ok
CREATE TABLE all_tab (a INT)

statement ok
CREATE FUNCTION all_fn(i INT) RETURNS BOOL LANGUAGE SQL AS $$
SELECT i = ALL(SELECT a FROM all_tab)
$$

# If the subquery returns no rows, the result should always be true.
query BBB
SELECT all_fn(1), all_fn(2), all_fn(NULL::INT)
----
true true true

statement ok
INSERT INTO all_tab VALUES (1), (1);

query BBB
SELECT all_fn(1), all_fn(2), all_fn(NULL::INT)
----
true false NULL

statement ok
INSERT INTO all_tab VALUES (NULL);

query BBB
SELECT all_fn(1), all_fn(2), all_fn(NULL::INT)
----
NULL false NULL


subtest variadic

Expand Down Expand Up @@ -2922,10 +3046,10 @@ SELECT oid, proname, pronamespace, proowner, prolang, proleakproof, proisstrict,
FROM pg_catalog.pg_proc WHERE proname IN ('f_93314', 'f_93314_alias', 'f_93314_comp', 'f_93314_comp_t')
ORDER BY oid;
----
100264 f_93314 105 1546506610 14 false false false v 0 100263 · {} NULL SELECT i, e FROM test.public.t_93314 ORDER BY i LIMIT 1;
100266 f_93314_alias 105 1546506610 14 false false false v 0 100265 · {} NULL SELECT i, e FROM test.public.t_93314_alias ORDER BY i LIMIT 1;
100270 f_93314_comp 105 1546506610 14 false false false v 0 100267 · {} NULL SELECT (1, 2);
100271 f_93314_comp_t 105 1546506610 14 false false false v 0 100269 · {} NULL SELECT a, c FROM test.public.t_93314_comp LIMIT 1;
100271 f_93314 105 1546506610 14 false false false v 0 100270 · {} NULL SELECT i, e FROM test.public.t_93314 ORDER BY i LIMIT 1;
100273 f_93314_alias 105 1546506610 14 false false false v 0 100272 · {} NULL SELECT i, e FROM test.public.t_93314_alias ORDER BY i LIMIT 1;
100277 f_93314_comp 105 1546506610 14 false false false v 0 100274 · {} NULL SELECT (1, 2);
100278 f_93314_comp_t 105 1546506610 14 false false false v 0 100276 · {} NULL SELECT a, c FROM test.public.t_93314_comp LIMIT 1;

# Regression test for #95240. Strict UDFs that are inlined should result in NULL
# when presented with NULL arguments.
Expand Down
92 changes: 50 additions & 42 deletions pkg/sql/opt/norm/decorrelate_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,25 @@ func (r *subqueryHoister) constructGroupByExists(subquery memo.RelExpr) memo.Rel
)
}

// constructGroupByAny transforms a scalar Any expression like this:
// constructGroupByAny transforms a scalar Any expression into a scalar GroupBy
// expression that returns a one row, one column relation. See
// CustomFuncs.ConstructGroupByAny for more details.
func (r *subqueryHoister) constructGroupByAny(
scalar opt.ScalarExpr, cmp opt.Operator, input memo.RelExpr,
) memo.RelExpr {
// When the scalar value is not a simple variable or constant expression,
// then cache its value using a projection, since it will be referenced
// multiple times.
if scalar.Op() != opt.VariableOp && !opt.IsConstValueOp(scalar) {
typ := scalar.DataType()
scalarColID := r.f.Metadata().AddColumn("scalar", typ)
r.hoisted = r.c.ProjectExtraCol(r.hoisted, scalar, scalarColID)
scalar = r.f.ConstructVariable(scalarColID)
}
return r.c.ConstructGroupByAny(scalar, cmp, input)
}

// ConstructGroupByAny transforms a scalar Any expression like this:
//
// z = ANY(SELECT x FROM xy)
//
Expand Down Expand Up @@ -990,82 +1008,72 @@ func (r *subqueryHoister) constructGroupByExists(subquery memo.RelExpr) memo.Rel
// TryDecorrelateScalarGroupBy rule, which will push a left join into the
// GroupBy. Null values produced by the left join will simply be ignored by
// BOOL_OR, and so cannot be used for any other purpose.
func (r *subqueryHoister) constructGroupByAny(
func (c *CustomFuncs) ConstructGroupByAny(
scalar opt.ScalarExpr, cmp opt.Operator, input memo.RelExpr,
) memo.RelExpr {
// When the scalar value is not a simple variable or constant expression,
// then cache its value using a projection, since it will be referenced
// multiple times.
if scalar.Op() != opt.VariableOp && !opt.IsConstValueOp(scalar) {
typ := scalar.DataType()
scalarColID := r.f.Metadata().AddColumn("scalar", typ)
r.hoisted = r.c.ProjectExtraCol(r.hoisted, scalar, scalarColID)
scalar = r.f.ConstructVariable(scalarColID)
}

inputVar := r.f.funcs.referenceSingleColumn(input)
notNullColID := r.f.Metadata().AddColumn("notnull", types.Bool)
aggColID := r.f.Metadata().AddColumn("bool_or", types.Bool)
aggVar := r.f.ConstructVariable(aggColID)
caseColID := r.f.Metadata().AddColumn("case", types.Bool)
inputVar := c.f.funcs.referenceSingleColumn(input)
notNullColID := c.f.Metadata().AddColumn("notnull", types.Bool)
aggColID := c.f.Metadata().AddColumn("bool_or", types.Bool)
aggVar := c.f.ConstructVariable(aggColID)
caseColID := c.f.Metadata().AddColumn("case", types.Bool)

var scalarNotNull opt.ScalarExpr
if scalar.DataType().Family() == types.TupleFamily {
scalarNotNull = r.f.ConstructIsTupleNotNull(scalar)
scalarNotNull = c.f.ConstructIsTupleNotNull(scalar)
} else {
scalarNotNull = r.f.ConstructIsNot(scalar, memo.NullSingleton)
scalarNotNull = c.f.ConstructIsNot(scalar, memo.NullSingleton)
}

var inputNotNull opt.ScalarExpr
if inputVar.DataType().Family() == types.TupleFamily {
inputNotNull = r.f.ConstructIsTupleNotNull(inputVar)
inputNotNull = c.f.ConstructIsTupleNotNull(inputVar)
} else {
inputNotNull = r.f.ConstructIsNot(inputVar, memo.NullSingleton)
inputNotNull = c.f.ConstructIsNot(inputVar, memo.NullSingleton)
}

return r.f.ConstructProject(
r.f.ConstructScalarGroupBy(
r.f.ConstructProject(
r.f.ConstructSelect(
return c.f.ConstructProject(
c.f.ConstructScalarGroupBy(
c.f.ConstructProject(
c.f.ConstructSelect(
input,
memo.FiltersExpr{r.f.ConstructFiltersItem(
r.f.ConstructIsNot(
r.f.funcs.ConstructBinary(cmp, scalar, inputVar),
memo.FiltersExpr{c.f.ConstructFiltersItem(
c.f.ConstructIsNot(
c.f.funcs.ConstructBinary(cmp, scalar, inputVar),
memo.FalseSingleton,
),
)},
),
memo.ProjectionsExpr{r.f.ConstructProjectionsItem(
memo.ProjectionsExpr{c.f.ConstructProjectionsItem(
inputNotNull,
notNullColID,
)},
opt.ColSet{},
),
memo.AggregationsExpr{r.f.ConstructAggregationsItem(
r.f.ConstructBoolOr(
r.f.ConstructVariable(notNullColID),
memo.AggregationsExpr{c.f.ConstructAggregationsItem(
c.f.ConstructBoolOr(
c.f.ConstructVariable(notNullColID),
),
aggColID,
)},
memo.EmptyGroupingPrivate,
),
memo.ProjectionsExpr{r.f.ConstructProjectionsItem(
r.f.ConstructCase(
r.f.ConstructTrue(),
memo.ProjectionsExpr{c.f.ConstructProjectionsItem(
c.f.ConstructCase(
c.f.ConstructTrue(),
memo.ScalarListExpr{
r.f.ConstructWhen(
r.f.ConstructAnd(
c.f.ConstructWhen(
c.f.ConstructAnd(
aggVar,
scalarNotNull,
),
r.f.ConstructTrue(),
c.f.ConstructTrue(),
),
r.f.ConstructWhen(
r.f.ConstructIs(aggVar, memo.NullSingleton),
r.f.ConstructFalse(),
c.f.ConstructWhen(
c.f.ConstructIs(aggVar, memo.NullSingleton),
c.f.ConstructFalse(),
),
},
r.f.ConstructNull(types.Bool),
c.f.ConstructNull(types.Bool),
),
caseColID,
)},
Expand Down
8 changes: 8 additions & 0 deletions pkg/sql/opt/optbuilder/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,14 @@ type Builder struct {
// are disabled and only statements whitelisted are allowed.
insideFuncDef bool

// insideUDF is true when the current expressions are being built within a
// UDF.
// TODO(mgartner): Once other UDFs can be referenced from within a UDF, a
// boolean will not be sufficient to track whether or not we are in a UDF.
// We'll need to track the depth of the UDFs we are building expressions
// within.
insideUDF bool

// If set, we are collecting view dependencies in schemaDeps. This can only
// happen inside view/function definitions.
//
Expand Down
6 changes: 6 additions & 0 deletions pkg/sql/opt/optbuilder/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,11 @@ func (b *Builder) buildUDF(
// Build an expression for each statement in the function body.
rels := make(memo.RelListExpr, len(stmts))
isSetReturning := o.Class == tree.GeneratorClass
// TODO(mgartner): Once other UDFs can be referenced from within a UDF, a
// boolean will not be sufficient to track whether or not we are in a UDF.
// We'll need to track the depth of the UDFs we are building expressions
// within.
b.insideUDF = true
for i := range stmts {
stmtScope := b.buildStmt(stmts[i].AST, nil /* desiredTypes */, bodyScope)
expr := stmtScope.expr
Expand Down Expand Up @@ -768,6 +773,7 @@ func (b *Builder) buildUDF(
PhysProps: physProps,
}
}
b.insideUDF = false

// For set-returning functions, we handle STRICT behavior in the routine
// execution logic. For scalar UDFs this is handled by a CASE statement - see
Expand Down
18 changes: 13 additions & 5 deletions pkg/sql/opt/optbuilder/subquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,11 +372,19 @@ func (b *Builder) buildMultiRowSubquery(
))
}

// Construct the outer Any(...) operator.
out = b.factory.ConstructAny(input, scalar, &memo.SubqueryPrivate{
Cmp: cmp,
OriginalExpr: s.Subquery,
})
if b.insideUDF {
// Any expressions cannot be built by the optimizer within a UDF, so
// build them as subqueries with ScalarGroupBy expressions instead.
sub := b.factory.CustomFuncs().ConstructGroupByAny(scalar, cmp, input)
out = b.factory.ConstructSubquery(sub, &memo.SubqueryPrivate{OriginalExpr: s.Subquery})
} else {
// Construct the outer Any(...) operator.
out = b.factory.ConstructAny(input, scalar, &memo.SubqueryPrivate{
Cmp: cmp,
OriginalExpr: s.Subquery,
})
}

switch c.Operator.Symbol {
case treecmp.NotIn, treecmp.All:
// NOT Any(...)
Expand Down
Loading