Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sql: don't no-op SET SESSION AUTHORIZATION DEFAULT #86485

Merged
merged 3 commits into from
Aug 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion pkg/sql/conn_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,17 @@ func (s *Server) SetupConn(
sdMutIterator := s.makeSessionDataMutatorIterator(sds, args.SessionDefaults)
sdMutIterator.onDefaultIntSizeChange = onDefaultIntSizeChange
if err := sdMutIterator.applyOnEachMutatorError(func(m sessionDataMutator) error {
return resetSessionVars(ctx, m)
for varName, v := range varGen {
if v.Set != nil {
hasDefault, defVal := getSessionVarDefaultString(varName, v, m.sessionDataMutatorBase)
if hasDefault {
if err := v.Set(ctx, m, defVal); err != nil {
return err
}
}
}
}
return nil
}); err != nil {
log.Errorf(ctx, "error setting up client session: %s", err)
return ConnectionHandler{}, err
Expand Down
4 changes: 1 addition & 3 deletions pkg/sql/copyshim.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,7 @@ func RunCopyFrom(
)
// TODO(cucaroach): I believe newInternalPlanner should do this but doing it
// there causes lots of session diffs and test failures and is risky.
if err := p.sessionDataMutatorIterator.applyOnEachMutatorError(func(m sessionDataMutator) error {
return resetSessionVars(ctx, m)
}); err != nil {
if err := p.resetAllSessionVars(ctx); err != nil {
return -1, err
}
defer cleanup()
Expand Down
56 changes: 23 additions & 33 deletions pkg/sql/discard.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,48 +22,38 @@ import (
// Discard implements the DISCARD statement.
// See https://www.postgresql.org/docs/9.6/static/sql-discard.html for details.
func (p *planner) Discard(ctx context.Context, s *tree.Discard) (planNode, error) {
switch s.Mode {
return &discardNode{mode: s.Mode}, nil
}

type discardNode struct {
mode tree.DiscardMode
}

func (n *discardNode) Next(_ runParams) (bool, error) { return false, nil }
func (n *discardNode) Values() tree.Datums { return nil }
func (n *discardNode) Close(_ context.Context) {}
func (n *discardNode) startExec(params runParams) error {
switch n.mode {
case tree.DiscardModeAll:
if !p.autoCommit {
return nil, pgerror.New(pgcode.ActiveSQLTransaction,
if !params.p.autoCommit {
return pgerror.New(pgcode.ActiveSQLTransaction,
"DISCARD ALL cannot run inside a transaction block")
}

// RESET ALL
if err := p.sessionDataMutatorIterator.applyOnEachMutatorError(
func(m sessionDataMutator) error {
return resetSessionVars(ctx, m)
},
); err != nil {
return nil, err
// SET SESSION AUTHORIZATION DEFAULT
if err := params.p.setRole(params.ctx, false /* local */, params.p.SessionData().SessionUser()); err != nil {
return err
}

// DEALLOCATE ALL
p.preparedStatements.DeleteAll(ctx)
default:
return nil, errors.AssertionFailedf("unknown mode for DISCARD: %d", s.Mode)
}
return newZeroNode(nil /* columns */), nil
}

func resetSessionVars(ctx context.Context, m sessionDataMutator) error {
for _, varName := range varNames {
if err := resetSessionVar(ctx, m, varName); err != nil {
// RESET ALL
if err := params.p.resetAllSessionVars(params.ctx); err != nil {
return err
}
}
return nil
}

func resetSessionVar(ctx context.Context, m sessionDataMutator, varName string) error {
v := varGen[varName]
if v.Set != nil {
hasDefault, defVal := getSessionVarDefaultString(varName, v, m.sessionDataMutatorBase)
if hasDefault {
if err := v.Set(ctx, m, defVal); err != nil {
return err
}
}
// DEALLOCATE ALL
params.p.preparedStatements.DeleteAll(params.ctx)
default:
return errors.AssertionFailedf("unknown mode for DISCARD: %d", n.mode)
}
return nil
}
8 changes: 7 additions & 1 deletion pkg/sql/logictest/testdata/logic_test/set_role
Original file line number Diff line number Diff line change
Expand Up @@ -374,5 +374,11 @@ WHERE active_queries LIKE 'SELECT user_name%'
----
root

# Verify that SET SESSION AUTHORIZATION *does* reset the role.
statement ok
RESET ROLE
SET SESSION AUTHORIZATION DEFAULT

query TTTT
SELECT current_user(), current_user, session_user(), session_user
----
root root root root
8 changes: 1 addition & 7 deletions pkg/sql/pgwire/testdata/pgtest/pgjdbc
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,7 @@ send
Query {"String": "DISCARD ALL"}
----

until crdb_only ignore=ParameterStatus
ReadyForQuery
----
{"Type":"CommandComplete","CommandTag":"DISCARD"}
{"Type":"ReadyForQuery","TxStatus":"I"}

until noncrdb_only ignore=ParameterStatus
until ignore=ParameterStatus ignore=NoticeResponse
ReadyForQuery
----
{"Type":"CommandComplete","CommandTag":"DISCARD ALL"}
Expand Down
8 changes: 7 additions & 1 deletion pkg/sql/sem/tree/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -853,7 +853,13 @@ func (*Discard) StatementReturnType() StatementReturnType { return Ack }
func (*Discard) StatementType() StatementType { return TypeTCL }

// StatementTag returns a short string identifying the type of statement.
func (*Discard) StatementTag() string { return "DISCARD" }
func (d *Discard) StatementTag() string {
switch d.Mode {
case DiscardModeAll:
return "DISCARD ALL"
}
return "DISCARD"
}

// StatementReturnType implements the Statement interface.
func (n *DeclareCursor) StatementReturnType() StatementReturnType { return Ack }
Expand Down
2 changes: 2 additions & 0 deletions pkg/sql/sessiondatapb/local_only_session_data.proto
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ message LocalOnlySessionData {
// established the connection before SET ROLE was first performed.
// This is only populated when SET ROLE is used, otherwise the session_user
// is the same as the UserProto in SessionData.
// Postgres allows the SessionUser to be changed with SET SESSION AUTHORIZATION
// but CockroachDB doesn't allow that at the time of this writing.
string session_user_proto = 46 [(gogoproto.casttype) = "github.com/cockroachdb/cockroach/pkg/security/username.SQLUsernameProto"];
// TxnRowsWrittenLog is the threshold for the number of rows written by a SQL
// transaction which - once exceeded - will trigger a logging event to SQL_PERF
Expand Down
24 changes: 21 additions & 3 deletions pkg/sql/set_session_authorization.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,26 @@

package sql

import (
"context"

"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
)

func (p *planner) SetSessionAuthorizationDefault() (planNode, error) {
// This is currently a no-op - we don't support changing the session
// authorization, and the parser only accepts DEFAULT.
return newZeroNode(nil /* columns */), nil
return &setSessionAuthorizationDefaultNode{}, nil
}

type setSessionAuthorizationDefaultNode struct{}

func (n *setSessionAuthorizationDefaultNode) Next(_ runParams) (bool, error) { return false, nil }
func (n *setSessionAuthorizationDefaultNode) Values() tree.Datums { return nil }
func (n *setSessionAuthorizationDefaultNode) Close(_ context.Context) {}
func (n *setSessionAuthorizationDefaultNode) startExec(params runParams) error {
// This is currently the same as `SET ROLE = DEFAULT`, which means that it
// only changes the "current user." In Postgres, `SET SESSION AUTHORIZATION`
// also changes the "session user," but since the session user cannot be
// modified in CockroachDB (at the time of writing), we just need to change
// the current user here.
return params.p.setRole(params.ctx, false /* local */, params.p.SessionData().SessionUser())
}
16 changes: 10 additions & 6 deletions pkg/sql/set_var.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ func (n *setVarNode) Next(_ runParams) (bool, error) { return false, nil }
func (n *setVarNode) Values() tree.Datums { return nil }
func (n *setVarNode) Close(_ context.Context) {}

func (n *resetAllNode) startExec(params runParams) error {
func (p *planner) resetAllSessionVars(ctx context.Context) error {
for varName, v := range varGen {
if v.Set == nil && v.RuntimeSet == nil && v.SetWithPlanner == nil {
continue
Expand All @@ -212,32 +212,36 @@ func (n *resetAllNode) startExec(params runParams) error {
hasDefault, defVal := getSessionVarDefaultString(
varName,
v,
params.p.sessionDataMutatorIterator.sessionDataMutatorBase,
p.sessionDataMutatorIterator.sessionDataMutatorBase,
)
if !hasDefault {
continue
}
if err := params.p.SetSessionVar(params.ctx, varName, defVal, false /* isLocal */); err != nil {
if err := p.SetSessionVar(ctx, varName, defVal, false /* isLocal */); err != nil {
return err
}
}
for varName := range params.SessionData().CustomOptions {
for varName := range p.SessionData().CustomOptions {
_, v, err := getSessionVar(varName, false /* missingOK */)
if err != nil {
return err
}
_, defVal := getSessionVarDefaultString(
varName,
v,
params.p.sessionDataMutatorIterator.sessionDataMutatorBase,
p.sessionDataMutatorIterator.sessionDataMutatorBase,
)
if err := params.p.SetSessionVar(params.ctx, varName, defVal, false /* isLocal */); err != nil {
if err := p.SetSessionVar(ctx, varName, defVal, false /* isLocal */); err != nil {
return err
}
}
return nil
}

func (n *resetAllNode) startExec(params runParams) error {
return params.p.resetAllSessionVars(params.ctx)
}

func (n *resetAllNode) Next(_ runParams) (bool, error) { return false, nil }
func (n *resetAllNode) Values() tree.Datums { return nil }
func (n *resetAllNode) Close(_ context.Context) {}
Expand Down
2 changes: 2 additions & 0 deletions pkg/sql/walk.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ var planNodeNames = map[reflect.Type]string{
reflect.TypeOf(&delayedNode{}): "virtual table",
reflect.TypeOf(&deleteNode{}): "delete",
reflect.TypeOf(&deleteRangeNode{}): "delete range",
reflect.TypeOf(&discardNode{}): "discard",
reflect.TypeOf(&distinctNode{}): "distinct",
reflect.TypeOf(&dropDatabaseNode{}): "drop database",
reflect.TypeOf(&dropExternalConnectionNode{}): "drop external connection",
Expand Down Expand Up @@ -449,6 +450,7 @@ var planNodeNames = map[reflect.Type]string{
reflect.TypeOf(&sequenceSelectNode{}): "sequence select",
reflect.TypeOf(&serializeNode{}): "run",
reflect.TypeOf(&setClusterSettingNode{}): "set cluster setting",
reflect.TypeOf(&setSessionAuthorizationDefaultNode{}): "set session authorization",
reflect.TypeOf(&setVarNode{}): "set",
reflect.TypeOf(&setZoneConfigNode{}): "configure zone",
reflect.TypeOf(&showFingerprintsNode{}): "show fingerprints",
Expand Down