Skip to content

Commit

Permalink
opttester: support session settings in opt tests
Browse files Browse the repository at this point in the history
This commit adds the `set` opttest flag which can be used to set
session flags via "set=flagname=value".

Release note: none
  • Loading branch information
Mark Sirek committed Aug 10, 2022
1 parent 5290900 commit 854f484
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 2 deletions.
17 changes: 17 additions & 0 deletions pkg/sql/opt/testutils/opttester/opt_tester.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/roachpb"
"github.com/cockroachdb/cockroach/pkg/security/username"
"github.com/cockroachdb/cockroach/pkg/settings/cluster"
"github.com/cockroachdb/cockroach/pkg/sql"
"github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb"
"github.com/cockroachdb/cockroach/pkg/sql/catalog/schemaexpr"
"github.com/cockroachdb/cockroach/pkg/sql/opt"
Expand Down Expand Up @@ -265,6 +266,9 @@ type Flags struct {
// SkipRace indicates that a test should be skipped if the race detector is
// enabled.
SkipRace bool

// ot is a reference to the OptTester owning Flags.
ot *OptTester
}

// New constructs a new instance of the OptTester for the given SQL statement.
Expand All @@ -279,6 +283,7 @@ func New(catalog cat.Catalog, sql string) *OptTester {
semaCtx: tree.MakeSemaContext(),
evalCtx: eval.MakeTestingEvalContext(cluster.MakeTestingClusterSettings()),
}
ot.Flags.ot = ot
ot.semaCtx.SearchPath = tree.EmptySearchPath
ot.semaCtx.FunctionResolver = ot.catalog
// To allow opttester tests to use now(), we hardcode a preset transaction
Expand Down Expand Up @@ -897,6 +902,18 @@ func ruleNamesToRuleSet(args []string) (RuleSet, error) {
// See OptTester.RunCommand for supported flags.
func (f *Flags) Set(arg datadriven.CmdArg) error {
switch arg.Key {
case "set":
for _, val := range arg.Vals {
s := strings.Split(val, "=")
if len(s) != 2 {
return errors.Errorf("Expected both session variable name and value for set command")
}
err := sql.SetSessionVariable(f.ot.ctx, f.ot.evalCtx, s[0], s[1])
if err != nil {
return err
}
}

case "format":
if len(arg.Vals) == 0 {
return fmt.Errorf("format flag requires value(s)")
Expand Down
26 changes: 25 additions & 1 deletion pkg/sql/opt/xform/testdata/rules/join
Original file line number Diff line number Diff line change
Expand Up @@ -10230,8 +10230,32 @@ inner-join (merge)
│ └── columns: m:1 n:2
└── filters (true)

# The rule should be allowed to fire when the projection is from a join if the
# session flag disable_hoist_projection_in_join_limitation is true.
opt expect=HoistProjectFromInnerJoin set=disable_hoist_projection_in_join_limitation=true
SELECT * FROM (SELECT a, a+b FROM (SELECT tab1.* from abcd tab1, abcd tab2)) JOIN small ON a=m;
----
project
├── columns: a:1!null "?column?":13 m:14!null n:15
├── immutable
├── fd: (1)==(14), (14)==(1)
├── inner-join (cross)
│ ├── columns: tab1.a:1!null tab1.b:2 m:14!null n:15
│ ├── fd: (1)==(14), (14)==(1)
│ ├── scan abcd@abcd_a_b_idx [as=tab2]
│ ├── inner-join (lookup abcd@abcd_a_b_idx [as=tab1])
│ │ ├── columns: tab1.a:1!null tab1.b:2 m:14!null n:15
│ │ ├── key columns: [14] = [1]
│ │ ├── fd: (1)==(14), (14)==(1)
│ │ ├── scan small
│ │ │ └── columns: m:14 n:15
│ │ └── filters (true)
│ └── filters (true)
└── projections
└── tab1.a:1 + tab1.b:2 [as="?column?":13, outer=(1,2), immutable]

# The rule should not fire when the projection is from a join.
opt expect-not=HoistProjectFromInnerJoin
opt expect-not=HoistProjectFromInnerJoin set=disable_hoist_projection_in_join_limitation=false
SELECT * FROM (SELECT a, a+b FROM (SELECT tab1.* from abcd tab1, abcd tab2)) JOIN small ON a=m
----
inner-join (hash)
Expand Down
29 changes: 28 additions & 1 deletion pkg/sql/vars.go
Original file line number Diff line number Diff line change
Expand Up @@ -2119,6 +2119,7 @@ var varGen = map[string]sessionVar{
return formatFloatAsPostgresSetting(0)
},
},

// CockroachDB extension.
`disable_hoist_projection_in_join_limitation`: {
GetStringVal: makePostgresBoolGetStringValFn(`disable_hoist_projection_in_join_limitation`),
Expand All @@ -2130,7 +2131,7 @@ var varGen = map[string]sessionVar{
m.SetDisableHoistProjectionInJoinLimitation(b)
return nil
},
Get: func(evalCtx *extendedEvalContext) (string, error) {
Get: func(evalCtx *extendedEvalContext, _ *kv.Txn) (string, error) {
return formatBoolAsPostgresSetting(evalCtx.SessionData().DisableHoistProjectionInJoinLimitation), nil
},
GlobalDefault: globalFalse,
Expand Down Expand Up @@ -2184,6 +2185,32 @@ func init() {
}()
}

// SetSessionVariable sets a new value for session setting `varName` is the
// session settings owned by `evalCtx`, returning an error if not successful.
func SetSessionVariable(
ctx context.Context, evalCtx eval.Context, varName, varValue string,
) (err error) {
err = CheckSessionVariableValueValid(ctx, evalCtx.Settings, varName, varValue)
if err != nil {
return err
}
sessionDataMutatorBase := sessionDataMutatorBase{
defaults: make(map[string]string),
settings: evalCtx.Settings,
}
sessionDataMutator := sessionDataMutator{
data: evalCtx.SessionData(),
sessionDataMutatorBase: sessionDataMutatorBase,
sessionDataMutatorCallbacks: sessionDataMutatorCallbacks{},
}
_, sVar, err := getSessionVar(varName, false)
if err != nil {
return err
}

return sVar.Set(ctx, sessionDataMutator, varValue)
}

// makePostgresBoolGetStringValFn returns a function that evaluates and returns
// a string representation of the first argument value.
func makePostgresBoolGetStringValFn(varName string) getStringValFn {
Expand Down

0 comments on commit 854f484

Please sign in to comment.