diff --git a/pkg/sql/opt/norm/rules/scalar.opt b/pkg/sql/opt/norm/rules/scalar.opt index e7f23afa0821..00ff20a6cd01 100644 --- a/pkg/sql/opt/norm/rules/scalar.opt +++ b/pkg/sql/opt/norm/rules/scalar.opt @@ -121,6 +121,60 @@ $input => (Exists $input $subqueryPrivate) +# InlineExistsSelectTuple splits a tuple equality filter into multiple +# (per-column) equalities, in the case where the tuple on one side is being +# projected. +# +# We are specifically handling the case when this is under Exists because we +# don't have to keep the same output columns for the Select. This case is +# important because it is produced for an IN subquery: +# +# SELECT * FROM ab WHERE (a, b) IN (SELECT c, d FROM cd) +# +# Without this rule, we would not be able to produce a lookup join plan for such +# a query. +# +[InlineExistsSelectTuple, Normalize] +(Exists + (Select + (Project + $input:* + [ + ... + (ProjectionsItem $tuple:(Tuple) $tupleCol:*) + ... + ] + ) + $filters:[ + ... + $item:(FiltersItem + (Eq + # CommuteVar ensures that the variable is on the left. + (Variable + $varCol:* & + (EqualsColumn $varCol $tupleCol) + ) + $rhs:(Tuple) & + (TuplesHaveSameLength $tuple $rhs) + ) + ) + ... + ] + ) + $subqueryPrivate:* +) +=> +(Exists + (Select + $input + (ConcatFilters + (RemoveFiltersItem $filters $item) + (SplitTupleEq $tuple $rhs) + ) + ) + $subqueryPrivate +) + # IntroduceExistsLimit inserts a LIMIT 1 "under" Exists so as to save resources # to make the EXISTS determination. # diff --git a/pkg/sql/opt/norm/scalar_funcs.go b/pkg/sql/opt/norm/scalar_funcs.go index 17e3cb07850c..2a85b6c7bbbd 100644 --- a/pkg/sql/opt/norm/scalar_funcs.go +++ b/pkg/sql/opt/norm/scalar_funcs.go @@ -308,3 +308,29 @@ func (c *CustomFuncs) VarsAreSame(left, right opt.ScalarExpr) bool { rv := right.(*memo.VariableExpr) return lv.Col == rv.Col } + +// EqualsColumn returns true if the two column IDs are the same. +func (c *CustomFuncs) EqualsColumn(left, right opt.ColumnID) bool { + return left == right +} + +// TuplesHaveSameLength returns true if two tuples have the same number of +// elements. +func (c *CustomFuncs) TuplesHaveSameLength(a, b opt.ScalarExpr) bool { + return len(a.(*memo.TupleExpr).Elems) == len(b.(*memo.TupleExpr).Elems) +} + +// SplitTupleEq splits an equality condition between two tuples into multiple +// equalities, one for each tuple column. +func (c *CustomFuncs) SplitTupleEq(lhsExpr, rhsExpr opt.ScalarExpr) memo.FiltersExpr { + lhs := lhsExpr.(*memo.TupleExpr) + rhs := rhsExpr.(*memo.TupleExpr) + if len(lhs.Elems) != len(rhs.Elems) { + panic(errors.AssertionFailedf("unequal tuple lengths")) + } + res := make(memo.FiltersExpr, len(lhs.Elems)) + for i := range res { + res[i] = c.f.ConstructFiltersItem(c.f.ConstructEq(lhs.Elems[i], rhs.Elems[i])) + } + return res +} diff --git a/pkg/sql/opt/norm/testdata/rules/assign_placeholders b/pkg/sql/opt/norm/testdata/rules/assign_placeholders new file mode 100644 index 000000000000..8d27ff305c64 --- /dev/null +++ b/pkg/sql/opt/norm/testdata/rules/assign_placeholders @@ -0,0 +1,171 @@ +exec-ddl +CREATE TABLE kv (k INT PRIMARY KEY, v INT) +---- + +exec-ddl +CREATE TABLE abcd (a INT, b INT, c INT, d INT, PRIMARY KEY (a,b,c)) +---- + +assign-placeholders-norm query-args=(1) +SELECT v FROM kv WHERE k = $1 +---- +project + ├── columns: v:2 + ├── cardinality: [0 - 1] + ├── key: () + ├── fd: ()-->(2) + └── select + ├── columns: k:1!null v:2 + ├── cardinality: [0 - 1] + ├── key: () + ├── fd: ()-->(1,2) + ├── scan kv + │ ├── columns: k:1!null v:2 + │ ├── key: (1) + │ └── fd: (1)-->(2) + └── filters + └── k:1 = 1 [outer=(1), constraints=(/1: [/1 - /1]; tight), fd=()-->(1)] + +assign-placeholders-opt query-args=(1) +SELECT v FROM kv WHERE k = $1 +---- +project + ├── columns: v:2 + ├── cardinality: [0 - 1] + ├── key: () + ├── fd: ()-->(2) + └── scan kv + ├── columns: k:1!null v:2 + ├── constraint: /1: [/1 - /1] + ├── cardinality: [0 - 1] + ├── key: () + └── fd: ()-->(1,2) + +# This is what we ideally want to obtain after assigning placeholders in the +# test below. +norm +SELECT * FROM abcd WHERE (a, b) IN ( + SELECT unnest('{1}'::INT[]), + unnest('{2}'::INT[]) +) +---- +select + ├── columns: a:1!null b:2!null c:3!null d:4 + ├── key: (3) + ├── fd: ()-->(1,2), (3)-->(4) + ├── scan abcd + │ ├── columns: a:1!null b:2!null c:3!null d:4 + │ ├── key: (1-3) + │ └── fd: (1-3)-->(4) + └── filters + └── (a:1, b:2) IN ((1, 2),) [outer=(1,2), constraints=(/1/2: [/1/2 - /1/2]; /2: [/2 - /2]; tight), fd=()-->(1,2)] + +# The normalized expression above can be explored into a constrained scan. +opt +SELECT * FROM abcd WHERE (a, b) IN ( + SELECT unnest('{1}'::INT[]), + unnest('{2}'::INT[]) +) +---- +scan abcd + ├── columns: a:1!null b:2!null c:3!null d:4 + ├── constraint: /1/2/3: [/1/2 - /1/2] + ├── key: (3) + └── fd: ()-->(1,2), (3)-->(4) + +assign-placeholders-norm query-args=('{1}','{2}') +SELECT * FROM abcd WHERE (a, b) IN ( + SELECT unnest($1:::STRING::INT[]), + unnest($2:::STRING::INT[]) +) +---- +select + ├── columns: a:1!null b:2!null c:3!null d:4 + ├── key: (3) + ├── fd: ()-->(1,2), (3)-->(4) + ├── scan abcd + │ ├── columns: a:1!null b:2!null c:3!null d:4 + │ ├── key: (1-3) + │ └── fd: (1-3)-->(4) + └── filters + ├── a:1 = 1 [outer=(1), constraints=(/1: [/1 - /1]; tight), fd=()-->(1)] + └── b:2 = 2 [outer=(2), constraints=(/2: [/2 - /2]; tight), fd=()-->(2)] + +# We want this query to be optimized into a constrained scan, just like the +# no-placeholders variant above. +assign-placeholders-opt query-args=('{1}','{2}') +SELECT * FROM abcd WHERE (a, b) IN ( + SELECT unnest($1:::STRING::INT[]), + unnest($2:::STRING::INT[]) +) +---- +scan abcd + ├── columns: a:1!null b:2!null c:3!null d:4 + ├── constraint: /1/2/3: [/1/2 - /1/2] + ├── key: (3) + └── fd: ()-->(1,2), (3)-->(4) + +# Note: \x2c is a comma; we can't use a comma directly because of the +# datadriven parser. +assign-placeholders-norm query-args=('{1\x2c 2}','{3\x2c 4}') +SELECT * FROM abcd WHERE (a, b) IN ( + SELECT unnest($1:::STRING::INT[]), + unnest($2:::STRING::INT[]) +) +---- +semi-join (hash) + ├── columns: a:1!null b:2!null c:3!null d:4 + ├── stable + ├── key: (1-3) + ├── fd: (1-3)-->(4) + ├── scan abcd + │ ├── columns: a:1!null b:2!null c:3!null d:4 + │ ├── key: (1-3) + │ └── fd: (1-3)-->(4) + ├── project-set + │ ├── columns: unnest:6 unnest:7 + │ ├── stable + │ ├── values + │ │ ├── cardinality: [1 - 1] + │ │ ├── key: () + │ │ └── () + │ └── zip + │ ├── unnest(e'{1\\x2c 2}'::INT8[]) [stable] + │ └── unnest(e'{3\\x2c 4}'::INT8[]) [stable] + └── filters + ├── unnest:6 = a:1 [outer=(1,6), constraints=(/1: (/NULL - ]; /6: (/NULL - ]), fd=(1)==(6), (6)==(1)] + └── unnest:7 = b:2 [outer=(2,7), constraints=(/2: (/NULL - ]; /7: (/NULL - ]), fd=(2)==(7), (7)==(2)] + +assign-placeholders-opt query-args=('{1\x2c 2}','{3\x2c 4}') +SELECT * FROM abcd WHERE (a, b) IN ( + SELECT unnest($1:::STRING::INT[]), + unnest($2:::STRING::INT[]) +) +---- +project + ├── columns: a:1!null b:2!null c:3!null d:4 + ├── stable + ├── key: (1-3) + ├── fd: (1-3)-->(4) + └── inner-join (lookup abcd) + ├── columns: a:1!null b:2!null c:3!null d:4 unnest:6!null unnest:7!null + ├── key columns: [6 7] = [1 2] + ├── stable + ├── key: (3,6,7) + ├── fd: (1-3)-->(4), (1)==(6), (6)==(1), (2)==(7), (7)==(2) + ├── distinct-on + │ ├── columns: unnest:6 unnest:7 + │ ├── grouping columns: unnest:6 unnest:7 + │ ├── stable + │ ├── key: (6,7) + │ └── project-set + │ ├── columns: unnest:6 unnest:7 + │ ├── stable + │ ├── values + │ │ ├── cardinality: [1 - 1] + │ │ ├── key: () + │ │ └── () + │ └── zip + │ ├── unnest(e'{1\\x2c 2}'::INT8[]) [stable] + │ └── unnest(e'{3\\x2c 4}'::INT8[]) [stable] + └── filters (true) diff --git a/pkg/sql/opt/norm/testdata/rules/scalar b/pkg/sql/opt/norm/testdata/rules/scalar index 16af6aa80d90..3ab2a74e955d 100644 --- a/pkg/sql/opt/norm/testdata/rules/scalar +++ b/pkg/sql/opt/norm/testdata/rules/scalar @@ -6,6 +6,10 @@ exec-ddl CREATE TABLE xy (x INT PRIMARY KEY, y INT) ---- +exec-ddl +CREATE TABLE abcd (a INT, b INT, c INT, d INT) +---- + # -------------------------------------------------- # CommuteVar # -------------------------------------------------- @@ -458,6 +462,126 @@ select │ └── limit hint: 1.00 └── 1 +# -------------------------------------------------- +# InlineExistsSelectTuple +# -------------------------------------------------- +norm expect=InlineExistsSelectTuple +SELECT * FROM a WHERE (k, i) IN (SELECT x, y FROM xy) +---- +semi-join (hash) + ├── columns: k:1!null i:2 f:3 s:4 arr:5 + ├── key: (1) + ├── fd: (1)-->(2-5) + ├── scan a + │ ├── columns: k:1!null i:2 f:3 s:4 arr:5 + │ ├── key: (1) + │ └── fd: (1)-->(2-5) + ├── scan xy + │ ├── columns: x:7!null y:8 + │ ├── key: (7) + │ └── fd: (7)-->(8) + └── filters + ├── x:7 = k:1 [outer=(1,7), constraints=(/1: (/NULL - ]; /7: (/NULL - ]), fd=(1)==(7), (7)==(1)] + └── y:8 = i:2 [outer=(2,8), constraints=(/2: (/NULL - ]; /8: (/NULL - ]), fd=(2)==(8), (8)==(2)] + +norm expect=InlineExistsSelectTuple +SELECT * FROM a WHERE (k, i) IN (SELECT x, 2 FROM xy) +---- +semi-join (hash) + ├── columns: k:1!null i:2!null f:3 s:4 arr:5 + ├── key: (1) + ├── fd: ()-->(2), (1)-->(3-5) + ├── select + │ ├── columns: k:1!null i:2!null f:3 s:4 arr:5 + │ ├── key: (1) + │ ├── fd: ()-->(2), (1)-->(3-5) + │ ├── scan a + │ │ ├── columns: k:1!null i:2 f:3 s:4 arr:5 + │ │ ├── key: (1) + │ │ └── fd: (1)-->(2-5) + │ └── filters + │ └── i:2 = 2 [outer=(2), constraints=(/2: [/2 - /2]; tight), fd=()-->(2)] + ├── scan xy + │ ├── columns: x:7!null + │ └── key: (7) + └── filters + └── x:7 = k:1 [outer=(1,7), constraints=(/1: (/NULL - ]; /7: (/NULL - ]), fd=(1)==(7), (7)==(1)] + +norm expect=InlineExistsSelectTuple +SELECT * FROM a WHERE f>1 AND (k, i) IN (SELECT x, 2 FROM xy) AND s = 'foo' +---- +semi-join (hash) + ├── columns: k:1!null i:2!null f:3!null s:4!null arr:5 + ├── key: (1) + ├── fd: ()-->(2,4), (1)-->(3,5) + ├── select + │ ├── columns: k:1!null i:2!null f:3!null s:4!null arr:5 + │ ├── key: (1) + │ ├── fd: ()-->(2,4), (1)-->(3,5) + │ ├── scan a + │ │ ├── columns: k:1!null i:2 f:3 s:4 arr:5 + │ │ ├── key: (1) + │ │ └── fd: (1)-->(2-5) + │ └── filters + │ ├── i:2 = 2 [outer=(2), constraints=(/2: [/2 - /2]; tight), fd=()-->(2)] + │ ├── f:3 > 1.0 [outer=(3), constraints=(/3: [/1.0000000000000002 - ]; tight)] + │ └── s:4 = 'foo' [outer=(4), constraints=(/4: [/'foo' - /'foo']; tight), fd=()-->(4)] + ├── scan xy + │ ├── columns: x:7!null + │ └── key: (7) + └── filters + └── x:7 = k:1 [outer=(1,7), constraints=(/1: (/NULL - ]; /7: (/NULL - ]), fd=(1)==(7), (7)==(1)] + +# Verify that we handle multiple tuples. +norm expect=InlineExistsSelectTuple +SELECT * FROM abcd WHERE (a, b) IN (SELECT x, y FROM xy) AND (c, d) IN (SELECT k, i FROM a) +---- +semi-join (hash) + ├── columns: a:1 b:2 c:3 d:4 + ├── semi-join (hash) + │ ├── columns: a:1 b:2 c:3 d:4 + │ ├── scan abcd + │ │ └── columns: a:1 b:2 c:3 d:4 + │ ├── scan a + │ │ ├── columns: k:10!null i:11 + │ │ ├── key: (10) + │ │ └── fd: (10)-->(11) + │ └── filters + │ ├── k:10 = c:3 [outer=(3,10), constraints=(/3: (/NULL - ]; /10: (/NULL - ]), fd=(3)==(10), (10)==(3)] + │ └── i:11 = d:4 [outer=(4,11), constraints=(/4: (/NULL - ]; /11: (/NULL - ]), fd=(4)==(11), (11)==(4)] + ├── scan xy + │ ├── columns: x:7!null y:8 + │ ├── key: (7) + │ └── fd: (7)-->(8) + └── filters + ├── x:7 = a:1 [outer=(1,7), constraints=(/1: (/NULL - ]; /7: (/NULL - ]), fd=(1)==(7), (7)==(1)] + └── y:8 = b:2 [outer=(2,8), constraints=(/2: (/NULL - ]; /8: (/NULL - ]), fd=(2)==(8), (8)==(2)] + +# Make sure we check that the left-hand side is the correct tuple; the result +# would be bad if we didn't check that the variable is for the tuple in the +# projection. +norm expect=InlineExistsSelectTuple +SELECT * FROM abcd WHERE EXISTS(SELECT * FROM (SELECT (x, y), (x+1,y+1) FROM xy) AS v(tup1,tup2) WHERE tup2 = (a, b)) +---- +semi-join (hash) + ├── columns: a:1 b:2 c:3 d:4 + ├── immutable + ├── scan abcd + │ └── columns: a:1 b:2 c:3 d:4 + ├── project + │ ├── columns: column13:13 column12:12!null + │ ├── immutable + │ ├── scan xy + │ │ ├── columns: x:7!null y:8 + │ │ ├── key: (7) + │ │ └── fd: (7)-->(8) + │ └── projections + │ ├── y:8 + 1 [as=column13:13, outer=(8), immutable] + │ └── x:7 + 1 [as=column12:12, outer=(7), immutable] + └── filters + ├── a:1 = column12:12 [outer=(1,12), constraints=(/1: (/NULL - ]; /12: (/NULL - ]), fd=(1)==(12), (12)==(1)] + └── b:2 = column13:13 [outer=(2,13), constraints=(/2: (/NULL - ]; /13: (/NULL - ]), fd=(2)==(13), (13)==(2)] + # -------------------------------------------------- # IntroduceExistsLimit # -------------------------------------------------- @@ -1148,36 +1272,24 @@ SELECT k FROM a WHERE (k, i) IN (SELECT b, a FROM (VALUES (1, 1), (2, 2), (3, 3) ---- project ├── columns: k:1!null - ├── immutable ├── key: (1) └── semi-join (hash) - ├── columns: k:1!null column10:10 - ├── immutable + ├── columns: k:1!null i:2 ├── key: (1) - ├── fd: (1)-->(10) - ├── project - │ ├── columns: column10:10 k:1!null + ├── fd: (1)-->(2) + ├── scan a + │ ├── columns: k:1!null i:2 │ ├── key: (1) - │ ├── fd: (1)-->(10) - │ ├── scan a - │ │ ├── columns: k:1!null i:2 - │ │ ├── key: (1) - │ │ └── fd: (1)-->(2) - │ └── projections - │ └── (k:1, i:2) [as=column10:10, outer=(1,2)] - ├── project - │ ├── columns: column9:9!null + │ └── fd: (1)-->(2) + ├── values + │ ├── columns: column1:7!null column2:8!null │ ├── cardinality: [3 - 3] - │ ├── values - │ │ ├── columns: column1:7!null column2:8!null - │ │ ├── cardinality: [3 - 3] - │ │ ├── (1, 1) - │ │ ├── (2, 2) - │ │ └── (3, 3) - │ └── projections - │ └── (column2:8, column1:7) [as=column9:9, outer=(7,8)] + │ ├── (1, 1) + │ ├── (2, 2) + │ └── (3, 3) └── filters - └── column10:10 = column9:9 [outer=(9,10), immutable, constraints=(/9: (/NULL - ]; /10: (/NULL - ]), fd=(9)==(10), (10)==(9)] + ├── column2:8 = k:1 [outer=(1,8), constraints=(/1: (/NULL - ]; /8: (/NULL - ]), fd=(1)==(8), (8)==(1)] + └── column1:7 = i:2 [outer=(2,7), constraints=(/2: (/NULL - ]; /7: (/NULL - ]), fd=(2)==(7), (7)==(2)] # -------------------------------------------------- # SimplifyEqualsAnyTuple diff --git a/pkg/sql/opt/testutils/opttester/memo_groups.go b/pkg/sql/opt/testutils/opttester/memo_groups.go index 47ed5f578a48..8efe553fec0f 100644 --- a/pkg/sql/opt/testutils/opttester/memo_groups.go +++ b/pkg/sql/opt/testutils/opttester/memo_groups.go @@ -109,9 +109,9 @@ func (g *memoGroups) depthFirstSearch( return nil } - // There are various scalar leaf singletons that won't be registered as - // groups; ignore them. - if scalar, ok := start.(opt.ScalarExpr); ok && scalar.ChildCount() == 0 { + // There are various scalars that won't be registered as groups (e.g. + // singletons). Ignore them (rather than panicking in firstInGroup). + if scalar, ok := start.(opt.ScalarExpr); ok { if _, found := g.exprMap[scalar]; !found { return nil } diff --git a/pkg/sql/opt/testutils/opttester/opt_tester.go b/pkg/sql/opt/testutils/opttester/opt_tester.go index 88018c075cff..f6ac511a9599 100644 --- a/pkg/sql/opt/testutils/opttester/opt_tester.go +++ b/pkg/sql/opt/testutils/opttester/opt_tester.go @@ -42,6 +42,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/parser" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" + "github.com/cockroachdb/cockroach/pkg/sql/schemaexpr" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/stats" "github.com/cockroachdb/cockroach/pkg/util" @@ -199,6 +200,9 @@ type Flags struct { // NoStableFolds controls whether constant folding for normalization includes // stable operators. NoStableFolds bool + + // QueryArgs are values for placeholders, used for assign-placeholders-*. + QueryArgs []string } // New constructs a new instance of the OptTester for the given SQL statement. @@ -255,6 +259,17 @@ func New(catalog cat.Catalog, sql string) *OptTester { // Builds an expression tree from a SQL query, fully optimizes it using the // memo, and then outputs the lowest cost tree. // +// - assign-placeholders-norm query-args=(...) +// +// Builds a query that has placeholders (with normalization enabled), then +// assigns placeholders to the given query arguments. Normalization rules are +// enabled when assigning placeholders. +// +// - assign-placeholders-opt query-args=(...) +// +// Builds a query that has placeholders (with normalization enabled), then +// assigns placeholders to the given query arguments and fully optimizes it. +// // - build-cascades [flags] // // Builds a query and then recursively builds cascading queries. Outputs all @@ -395,14 +410,17 @@ func (ot *OptTester) RunCommand(tb testing.TB, d *datadriven.TestData) string { d.Fatalf(tb, "%+v", err) } } + ot.Flags.Verbose = datadriven.Verbose() + + ot.semaCtx.Placeholders = tree.PlaceholderInfo{} ot.evalCtx.SessionData.ReorderJoinsLimit = ot.Flags.JoinLimit ot.evalCtx.SessionData.PreferLookupJoinsForFKs = ot.Flags.PreferLookupJoinsForFKs - ot.Flags.Verbose = datadriven.Verbose() ot.evalCtx.TestingKnobs.OptimizerCostPerturbation = ot.Flags.PerturbCost ot.evalCtx.Locality = ot.Flags.Locality ot.evalCtx.SessionData.SaveTablesPrefix = ot.Flags.SaveTablesPrefix + ot.evalCtx.Placeholders = nil switch d.Cmd { case "exec-ddl": @@ -458,6 +476,15 @@ func (ot *OptTester) RunCommand(tb testing.TB, d *datadriven.TestData) string { ot.postProcess(tb, d, e) return ot.FormatExpr(e) + case "assign-placeholders-norm", "assign-placeholders-opt": + explore := d.Cmd == "assign-placeholders-opt" + e, err := ot.AssignPlaceholders(ot.Flags.QueryArgs, explore) + if err != nil { + d.Fatalf(tb, "%+v", err) + } + ot.postProcess(tb, d, e) + return ot.FormatExpr(e) + case "build-cascades": o := ot.makeOptimizer() o.DisableOptimizations() @@ -834,6 +861,9 @@ func (f *Flags) Set(arg datadriven.CmdArg) error { } f.CascadeLevels = int(levels) + case "query-args": + f.QueryArgs = arg.Vals + default: return fmt.Errorf("unknown argument: %s", arg.Key) } @@ -886,6 +916,70 @@ func (ot *OptTester) Optimize() (opt.Expr, error) { return ot.optimizeExpr(o) } +// AssignPlaceholders builds the given query with placeholders, then assigns the +// placeholders to the given argument values, and optionally runs exploration. +// +// The arguments are parsed as SQL expressions. +func (ot *OptTester) AssignPlaceholders(queryArgs []string, explore bool) (opt.Expr, error) { + o := ot.makeOptimizer() + + // Build the prepared memo. Note that placeholders don't have values yet, so + // they won't be replaced. + err := ot.buildExpr(o.Factory()) + if err != nil { + return nil, err + } + prepMemo := o.DetachMemo() + + // Construct placeholder values. + if exp := len(ot.semaCtx.Placeholders.Types); len(queryArgs) != exp { + return nil, errors.Errorf("expected %d arguments, got %d", exp, len(queryArgs)) + } + ot.semaCtx.Placeholders.Values = make(tree.QueryArguments, len(queryArgs)) + for i, arg := range queryArgs { + var parg tree.Expr + parg, err := parser.ParseExpr(fmt.Sprintf("%v", arg)) + if err != nil { + return nil, err + } + + id := tree.PlaceholderIdx(i) + typ, _ := ot.semaCtx.Placeholders.ValueType(id) + texpr, err := schemaexpr.SanitizeVarFreeExpr( + context.Background(), + parg, + typ, + "", /* context */ + &ot.semaCtx, + tree.VolatilityVolatile, + ) + if err != nil { + return nil, err + } + + ot.semaCtx.Placeholders.Values[i] = texpr + } + ot.evalCtx.Placeholders = &ot.semaCtx.Placeholders + + // Now assign placeholders. + o = ot.makeOptimizer() + o.NotifyOnMatchedRule(func(ruleName opt.RuleName) bool { + if !explore && !ruleName.IsNormalize() { + return false + } + if ot.Flags.DisableRules.Contains(int(ruleName)) { + return false + } + return true + }) + + o.Factory().FoldingControl().AllowStableFolds() + if err := o.Factory().AssignPlaceholders(prepMemo); err != nil { + return nil, err + } + return o.Optimize() +} + // Memo returns a string that shows the memo data structure that is constructed // by the optimizer. func (ot *OptTester) Memo() (string, error) { diff --git a/pkg/sql/opt/testutils/opttester/reorder_joins.go b/pkg/sql/opt/testutils/opttester/reorder_joins.go index ca67b6dc8abc..a091510ab309 100644 --- a/pkg/sql/opt/testutils/opttester/reorder_joins.go +++ b/pkg/sql/opt/testutils/opttester/reorder_joins.go @@ -197,7 +197,7 @@ func outputRels(baseRels []memo.RelExpr, names map[opt.ColumnID]string) string { // of names generated so far. func getRelationName(nameCount int) string { const lenAlphabet = 26 - name := string(int('A') + (nameCount % lenAlphabet)) + name := string(rune(int('A') + (nameCount % lenAlphabet))) number := nameCount / lenAlphabet if number > 0 { // Names will follow the pattern: A, B, ..., Z, A1, B1, etc. diff --git a/pkg/sql/opt/xform/testdata/rules/join b/pkg/sql/opt/xform/testdata/rules/join index 3c6ce0f570b8..93399e12180e 100644 --- a/pkg/sql/opt/xform/testdata/rules/join +++ b/pkg/sql/opt/xform/testdata/rules/join @@ -6191,3 +6191,23 @@ inner-join-apply │ │ └── a:11 * stu.s:1 [as="?column?":16, outer=(1,11), immutable] │ └── filters (true) └── filters (true) + +# A multi-column IN query must be able to become a lookup join. +opt +SELECT * FROM stu WHERE (s, t) IN (SELECT m, n FROM small) +---- +project + ├── columns: s:1!null t:2!null u:3!null + ├── key: (1-3) + └── inner-join (lookup stu) + ├── columns: s:1!null t:2!null u:3!null m:5!null n:6!null + ├── key columns: [5 6] = [1 2] + ├── key: (3,5,6) + ├── fd: (1)==(5), (5)==(1), (2)==(6), (6)==(2) + ├── distinct-on + │ ├── columns: m:5 n:6 + │ ├── grouping columns: m:5 n:6 + │ ├── key: (5,6) + │ └── scan small + │ └── columns: m:5 n:6 + └── filters (true)