diff --git a/pkg/sql/opt/idxconstraint/index_constraints_test.go b/pkg/sql/opt/idxconstraint/index_constraints_test.go index cf7042e4118e..868a53e17e00 100644 --- a/pkg/sql/opt/idxconstraint/index_constraints_test.go +++ b/pkg/sql/opt/idxconstraint/index_constraints_test.go @@ -120,10 +120,10 @@ func TestIndexConstraints(t *testing.T) { computedCols = make(map[opt.ColumnID]opt.ScalarExpr) for col, expr := range sv.ComputedCols() { b := optbuilder.NewScalar(context.Background(), &semaCtx, &evalCtx, &f) - if err := b.Build(expr); err != nil { + computedColExpr, err := b.Build(expr) + if err != nil { d.Fatalf(t, "error building computed column expression: %v", err) } - computedColExpr := f.Memo().RootExpr().(opt.ScalarExpr) computedCols[col] = computedColExpr var sharedProps props.Shared memo.BuildSharedProps(computedColExpr, &sharedProps, &evalCtx) @@ -314,10 +314,10 @@ func buildFilters( return memo.FiltersExpr{}, err } b := optbuilder.NewScalar(context.Background(), semaCtx, evalCtx, f) - if err := b.Build(expr); err != nil { + root, err := b.Build(expr) + if err != nil { return memo.FiltersExpr{}, err } - root := f.Memo().RootExpr().(opt.ScalarExpr) if _, ok := root.(*memo.TrueExpr); ok { return memo.TrueFilter, nil } diff --git a/pkg/sql/opt/lookupjoin/constraint_builder_test.go b/pkg/sql/opt/lookupjoin/constraint_builder_test.go index 527a38c92c52..365e9b0e8a1b 100644 --- a/pkg/sql/opt/lookupjoin/constraint_builder_test.go +++ b/pkg/sql/opt/lookupjoin/constraint_builder_test.go @@ -103,10 +103,10 @@ func TestLookupConstraints(t *testing.T) { return 0, opt.ColSet{}, err } b := optbuilder.NewScalar(context.Background(), &semaCtx, &evalCtx, &f) - if err := b.Build(expr); err != nil { + compExpr, err := b.Build(expr) + if err != nil { return 0, opt.ColSet{}, err } - compExpr := f.Memo().RootExpr().(opt.ScalarExpr) var sharedProps props.Shared memo.BuildSharedProps(compExpr, &sharedProps, &evalCtx) md.TableMeta(tableID).AddComputedCol(colID, compExpr, sharedProps.OuterCols) @@ -312,12 +312,11 @@ func makeFiltersExpr( } b := optbuilder.NewScalar(context.Background(), semaCtx, evalCtx, f) - if err := b.Build(expr); err != nil { + root, err := b.Build(expr) + if err != nil { return nil, err } - root := f.Memo().RootExpr().(opt.ScalarExpr) - return memo.FiltersExpr{f.ConstructFiltersItem(root)}, nil } diff --git a/pkg/sql/opt/memo/expr_test.go b/pkg/sql/opt/memo/expr_test.go index 8ee969272e18..2f91e2f130f5 100644 --- a/pkg/sql/opt/memo/expr_test.go +++ b/pkg/sql/opt/memo/expr_test.go @@ -17,7 +17,6 @@ import ( "testing" "github.com/cockroachdb/cockroach/pkg/settings/cluster" - "github.com/cockroachdb/cockroach/pkg/sql/opt" "github.com/cockroachdb/cockroach/pkg/sql/opt/memo" "github.com/cockroachdb/cockroach/pkg/sql/opt/optbuilder" "github.com/cockroachdb/cockroach/pkg/sql/opt/testutils" @@ -93,11 +92,11 @@ func TestExprIsNeverNull(t *testing.T) { } b := optbuilder.NewScalar(ctx, &semaCtx, &evalCtx, o.Factory()) - err = b.Build(expr) + scalar, err := b.Build(expr) if err != nil { return fmt.Sprintf("error: %s\n", strings.TrimSpace(err.Error())) } - result := memo.ExprIsNeverNull(o.Memo().RootExpr().(opt.ScalarExpr), sv.NotNullCols()) + result := memo.ExprIsNeverNull(scalar, sv.NotNullCols()) return fmt.Sprintf("%t\n", result) default: diff --git a/pkg/sql/opt/memo/memo.go b/pkg/sql/opt/memo/memo.go index c571c6d18cd1..566a4e804e25 100644 --- a/pkg/sql/opt/memo/memo.go +++ b/pkg/sql/opt/memo/memo.go @@ -315,12 +315,6 @@ func (m *Memo) SetRoot(e RelExpr, phys *physical.Required) { } } -// SetScalarRoot stores the root memo expression when it is a scalar expression. -// Used only for testing. -func (m *Memo) SetScalarRoot(scalar opt.ScalarExpr) { - m.rootExpr = scalar -} - // HasPlaceholders returns true if the memo contains at least one placeholder // operator. func (m *Memo) HasPlaceholders() bool { diff --git a/pkg/sql/opt/memo/memo_test.go b/pkg/sql/opt/memo/memo_test.go index 45300216d53a..82fe097d87ac 100644 --- a/pkg/sql/opt/memo/memo_test.go +++ b/pkg/sql/opt/memo/memo_test.go @@ -97,10 +97,11 @@ func TestCompositeSensitive(t *testing.T) { } b := optbuilder.NewScalar(context.Background(), &semaCtx, &evalCtx, &f) - if err := b.Build(expr); err != nil { + scalar, err := b.Build(expr) + if err != nil { d.Fatalf(t, "error building: %v", err) } - return fmt.Sprintf("%v", memo.CanBeCompositeSensitive(md, f.Memo().RootExpr())) + return fmt.Sprintf("%v", memo.CanBeCompositeSensitive(md, scalar)) }) } diff --git a/pkg/sql/opt/optbuilder/builder_test.go b/pkg/sql/opt/optbuilder/builder_test.go index ba1306f648a0..13f2ef9a7326 100644 --- a/pkg/sql/opt/optbuilder/builder_test.go +++ b/pkg/sql/opt/optbuilder/builder_test.go @@ -112,14 +112,14 @@ func TestBuilder(t *testing.T) { // of the build process. o.DisableOptimizations() b := optbuilder.NewScalar(ctx, &semaCtx, &evalCtx, o.Factory()) - err = b.Build(expr) + scalar, err := b.Build(expr) if err != nil { return fmt.Sprintf("error: %s\n", strings.TrimSpace(err.Error())) } f := memo.MakeExprFmtCtx( ctx, tester.Flags.ExprFormat, false /* redactableValues */, o.Memo(), catalog, ) - f.FormatExpr(o.Memo().RootExpr()) + f.FormatExpr(scalar) return f.Buffer.String() default: diff --git a/pkg/sql/opt/optbuilder/scalar.go b/pkg/sql/opt/optbuilder/scalar.go index 40ed863989d7..8be3852f9d30 100644 --- a/pkg/sql/opt/optbuilder/scalar.go +++ b/pkg/sql/opt/optbuilder/scalar.go @@ -1177,7 +1177,7 @@ func NewScalar( // Build a memo structure from a TypedExpr: the root group represents a scalar // expression equivalent to expr. -func (sb *ScalarBuilder) Build(expr tree.Expr) (err error) { +func (sb *ScalarBuilder) Build(expr tree.Expr) (_ opt.ScalarExpr, err error) { defer func() { if r := recover(); r != nil { // This code allows us to propagate errors without adding lots of checks @@ -1194,8 +1194,7 @@ func (sb *ScalarBuilder) Build(expr tree.Expr) (err error) { typedExpr := sb.scope.resolveType(expr, types.Any) scalar := sb.buildScalar(typedExpr, &sb.scope, nil, nil, nil) - sb.factory.Memo().SetScalarRoot(scalar) - return nil + return scalar, nil } // reType is similar to tree.ReType, except that it panics with an internal diff --git a/pkg/sql/opt/partialidx/implicator_test.go b/pkg/sql/opt/partialidx/implicator_test.go index be9d9eeb01d4..9545f8dacb3a 100644 --- a/pkg/sql/opt/partialidx/implicator_test.go +++ b/pkg/sql/opt/partialidx/implicator_test.go @@ -370,11 +370,10 @@ func makeFiltersExpr( } b := optbuilder.NewScalar(context.Background(), semaCtx, evalCtx, f) - if err := b.Build(expr); err != nil { + root, err := b.Build(expr) + if err != nil { return nil, err } - root := f.Memo().RootExpr().(opt.ScalarExpr) - return memo.FiltersExpr{f.ConstructFiltersItem(root)}, nil } diff --git a/pkg/sql/opt/testutils/build.go b/pkg/sql/opt/testutils/build.go index c9ac7f15c100..a844c205b77f 100644 --- a/pkg/sql/opt/testutils/build.go +++ b/pkg/sql/opt/testutils/build.go @@ -57,11 +57,12 @@ func BuildScalar( } b := optbuilder.NewScalar(context.Background(), semaCtx, evalCtx, f) - if err := b.Build(expr); err != nil { + root, err := b.Build(expr) + if err != nil { t.Fatal(err) } - return f.Memo().RootExpr().(opt.ScalarExpr) + return root } // BuildFilters builds the given input string as a FiltersExpr and returns it. diff --git a/pkg/sql/sem/eval/eval_test.go b/pkg/sql/sem/eval/eval_test.go index 18d06c16bfa4..7edabdaa8da9 100644 --- a/pkg/sql/sem/eval/eval_test.go +++ b/pkg/sql/sem/eval/eval_test.go @@ -100,12 +100,13 @@ func optBuildScalar(evalCtx *eval.Context, e tree.Expr) (tree.TypedExpr, error) o.Init(ctx, evalCtx, nil /* catalog */) semaCtx := tree.MakeSemaContext() b := optbuilder.NewScalar(ctx, &semaCtx, evalCtx, o.Factory()) - if err := b.Build(e); err != nil { + scalar, err := b.Build(e) + if err != nil { return nil, err } bld := execbuilder.New( - ctx, nil /* factory */, &o, o.Memo(), nil /* catalog */, o.Memo().RootExpr(), + ctx, nil /* factory */, &o, o.Memo(), nil /* catalog */, scalar, evalCtx, false, /* allowAutoCommit */ false, /* isANSIDML */ )