Skip to content

Commit

Permalink
sql: consolidate logic for RESET ALL
Browse files Browse the repository at this point in the history
Release note: None

Release justification: low risk refactor
  • Loading branch information
rafiss committed Aug 21, 2022
1 parent 04bf2f3 commit 3887560
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 37 deletions.
12 changes: 11 additions & 1 deletion pkg/sql/conn_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,17 @@ func (s *Server) SetupConn(
sdMutIterator := s.makeSessionDataMutatorIterator(sds, args.SessionDefaults)
sdMutIterator.onDefaultIntSizeChange = onDefaultIntSizeChange
if err := sdMutIterator.applyOnEachMutatorError(func(m sessionDataMutator) error {
return resetSessionVars(ctx, m)
for varName, v := range varGen {
if v.Set != nil {
hasDefault, defVal := getSessionVarDefaultString(varName, v, m.sessionDataMutatorBase)
if hasDefault {
if err := v.Set(ctx, m, defVal); err != nil {
return err
}
}
}
}
return nil
}); err != nil {
log.Errorf(ctx, "error setting up client session: %s", err)
return ConnectionHandler{}, err
Expand Down
4 changes: 1 addition & 3 deletions pkg/sql/copyshim.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,7 @@ func RunCopyFrom(
)
// TODO(cucaroach): I believe newInternalPlanner should do this but doing it
// there causes lots of session diffs and test failures and is risky.
if err := p.sessionDataMutatorIterator.applyOnEachMutatorError(func(m sessionDataMutator) error {
return resetSessionVars(ctx, m)
}); err != nil {
if err := p.resetAllSessionVars(ctx); err != nil {
return -1, err
}
defer cleanup()
Expand Down
33 changes: 6 additions & 27 deletions pkg/sql/discard.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,13 @@ func (n *discardNode) startExec(params runParams) error {
"DISCARD ALL cannot run inside a transaction block")
}

// SET SESSION AUTHORIZATION DEFAULT
if err := params.p.setRole(params.ctx, false /* local */, params.p.SessionData().SessionUser()); err != nil {
return err
}

// RESET ALL
if err := params.p.sessionDataMutatorIterator.applyOnEachMutatorError(
func(m sessionDataMutator) error {
return resetSessionVars(params.ctx, m)
},
); err != nil {
if err := params.p.resetAllSessionVars(params.ctx); err != nil {
return err
}

Expand All @@ -56,25 +57,3 @@ func (n *discardNode) startExec(params runParams) error {
}
return nil
}

func resetSessionVars(ctx context.Context, m sessionDataMutator) error {
for _, varName := range varNames {
if err := resetSessionVar(ctx, m, varName); err != nil {
return err
}
}
return nil
}

func resetSessionVar(ctx context.Context, m sessionDataMutator, varName string) error {
v := varGen[varName]
if v.Set != nil {
hasDefault, defVal := getSessionVarDefaultString(varName, v, m.sessionDataMutatorBase)
if hasDefault {
if err := v.Set(ctx, m, defVal); err != nil {
return err
}
}
}
return nil
}
16 changes: 10 additions & 6 deletions pkg/sql/set_var.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ func (n *setVarNode) Next(_ runParams) (bool, error) { return false, nil }
func (n *setVarNode) Values() tree.Datums { return nil }
func (n *setVarNode) Close(_ context.Context) {}

func (n *resetAllNode) startExec(params runParams) error {
func (p *planner) resetAllSessionVars(ctx context.Context) error {
for varName, v := range varGen {
if v.Set == nil && v.RuntimeSet == nil && v.SetWithPlanner == nil {
continue
Expand All @@ -212,32 +212,36 @@ func (n *resetAllNode) startExec(params runParams) error {
hasDefault, defVal := getSessionVarDefaultString(
varName,
v,
params.p.sessionDataMutatorIterator.sessionDataMutatorBase,
p.sessionDataMutatorIterator.sessionDataMutatorBase,
)
if !hasDefault {
continue
}
if err := params.p.SetSessionVar(params.ctx, varName, defVal, false /* isLocal */); err != nil {
if err := p.SetSessionVar(ctx, varName, defVal, false /* isLocal */); err != nil {
return err
}
}
for varName := range params.SessionData().CustomOptions {
for varName := range p.SessionData().CustomOptions {
_, v, err := getSessionVar(varName, false /* missingOK */)
if err != nil {
return err
}
_, defVal := getSessionVarDefaultString(
varName,
v,
params.p.sessionDataMutatorIterator.sessionDataMutatorBase,
p.sessionDataMutatorIterator.sessionDataMutatorBase,
)
if err := params.p.SetSessionVar(params.ctx, varName, defVal, false /* isLocal */); err != nil {
if err := p.SetSessionVar(ctx, varName, defVal, false /* isLocal */); err != nil {
return err
}
}
return nil
}

func (n *resetAllNode) startExec(params runParams) error {
return params.p.resetAllSessionVars(params.ctx)
}

func (n *resetAllNode) Next(_ runParams) (bool, error) { return false, nil }
func (n *resetAllNode) Values() tree.Datums { return nil }
func (n *resetAllNode) Close(_ context.Context) {}
Expand Down

0 comments on commit 3887560

Please sign in to comment.