diff --git a/pkg/sql/conn_executor.go b/pkg/sql/conn_executor.go index e58d459bf902..9f5267f07cd5 100644 --- a/pkg/sql/conn_executor.go +++ b/pkg/sql/conn_executor.go @@ -709,7 +709,17 @@ func (s *Server) SetupConn( sdMutIterator := s.makeSessionDataMutatorIterator(sds, args.SessionDefaults) sdMutIterator.onDefaultIntSizeChange = onDefaultIntSizeChange if err := sdMutIterator.applyOnEachMutatorError(func(m sessionDataMutator) error { - return resetSessionVars(ctx, m) + for varName, v := range varGen { + if v.Set != nil { + hasDefault, defVal := getSessionVarDefaultString(varName, v, m.sessionDataMutatorBase) + if hasDefault { + if err := v.Set(ctx, m, defVal); err != nil { + return err + } + } + } + } + return nil }); err != nil { log.Errorf(ctx, "error setting up client session: %s", err) return ConnectionHandler{}, err diff --git a/pkg/sql/copyshim.go b/pkg/sql/copyshim.go index 1e090d4a8c0b..7c9d0a1e79eb 100644 --- a/pkg/sql/copyshim.go +++ b/pkg/sql/copyshim.go @@ -110,9 +110,7 @@ func RunCopyFrom( ) // TODO(cucaroach): I believe newInternalPlanner should do this but doing it // there causes lots of session diffs and test failures and is risky. - if err := p.sessionDataMutatorIterator.applyOnEachMutatorError(func(m sessionDataMutator) error { - return resetSessionVars(ctx, m) - }); err != nil { + if err := p.resetAllSessionVars(ctx); err != nil { return -1, err } defer cleanup() diff --git a/pkg/sql/discard.go b/pkg/sql/discard.go index 3ea8e35382f0..b1ddf6504330 100644 --- a/pkg/sql/discard.go +++ b/pkg/sql/discard.go @@ -22,48 +22,38 @@ import ( // Discard implements the DISCARD statement. // See https://www.postgresql.org/docs/9.6/static/sql-discard.html for details. func (p *planner) Discard(ctx context.Context, s *tree.Discard) (planNode, error) { - switch s.Mode { + return &discardNode{mode: s.Mode}, nil +} + +type discardNode struct { + mode tree.DiscardMode +} + +func (n *discardNode) Next(_ runParams) (bool, error) { return false, nil } +func (n *discardNode) Values() tree.Datums { return nil } +func (n *discardNode) Close(_ context.Context) {} +func (n *discardNode) startExec(params runParams) error { + switch n.mode { case tree.DiscardModeAll: - if !p.autoCommit { - return nil, pgerror.New(pgcode.ActiveSQLTransaction, + if !params.p.autoCommit { + return pgerror.New(pgcode.ActiveSQLTransaction, "DISCARD ALL cannot run inside a transaction block") } - // RESET ALL - if err := p.sessionDataMutatorIterator.applyOnEachMutatorError( - func(m sessionDataMutator) error { - return resetSessionVars(ctx, m) - }, - ); err != nil { - return nil, err + // SET SESSION AUTHORIZATION DEFAULT + if err := params.p.setRole(params.ctx, false /* local */, params.p.SessionData().SessionUser()); err != nil { + return err } - // DEALLOCATE ALL - p.preparedStatements.DeleteAll(ctx) - default: - return nil, errors.AssertionFailedf("unknown mode for DISCARD: %d", s.Mode) - } - return newZeroNode(nil /* columns */), nil -} - -func resetSessionVars(ctx context.Context, m sessionDataMutator) error { - for _, varName := range varNames { - if err := resetSessionVar(ctx, m, varName); err != nil { + // RESET ALL + if err := params.p.resetAllSessionVars(params.ctx); err != nil { return err } - } - return nil -} -func resetSessionVar(ctx context.Context, m sessionDataMutator, varName string) error { - v := varGen[varName] - if v.Set != nil { - hasDefault, defVal := getSessionVarDefaultString(varName, v, m.sessionDataMutatorBase) - if hasDefault { - if err := v.Set(ctx, m, defVal); err != nil { - return err - } - } + // DEALLOCATE ALL + params.p.preparedStatements.DeleteAll(params.ctx) + default: + return errors.AssertionFailedf("unknown mode for DISCARD: %d", n.mode) } return nil } diff --git a/pkg/sql/logictest/testdata/logic_test/set_role b/pkg/sql/logictest/testdata/logic_test/set_role index 9efa32fb2848..28c60c2d859e 100644 --- a/pkg/sql/logictest/testdata/logic_test/set_role +++ b/pkg/sql/logictest/testdata/logic_test/set_role @@ -374,5 +374,11 @@ WHERE active_queries LIKE 'SELECT user_name%' ---- root +# Verify that SET SESSION AUTHORIZATION *does* reset the role. statement ok -RESET ROLE +SET SESSION AUTHORIZATION DEFAULT + +query TTTT +SELECT current_user(), current_user, session_user(), session_user +---- +root root root root diff --git a/pkg/sql/pgwire/testdata/pgtest/pgjdbc b/pkg/sql/pgwire/testdata/pgtest/pgjdbc index d7d2c4329197..2494eb3530b3 100644 --- a/pkg/sql/pgwire/testdata/pgtest/pgjdbc +++ b/pkg/sql/pgwire/testdata/pgtest/pgjdbc @@ -42,13 +42,7 @@ send Query {"String": "DISCARD ALL"} ---- -until crdb_only ignore=ParameterStatus -ReadyForQuery ----- -{"Type":"CommandComplete","CommandTag":"DISCARD"} -{"Type":"ReadyForQuery","TxStatus":"I"} - -until noncrdb_only ignore=ParameterStatus +until ignore=ParameterStatus ignore=NoticeResponse ReadyForQuery ---- {"Type":"CommandComplete","CommandTag":"DISCARD ALL"} diff --git a/pkg/sql/sem/tree/stmt.go b/pkg/sql/sem/tree/stmt.go index bb0fa47c2d76..d09f7e72f7fb 100644 --- a/pkg/sql/sem/tree/stmt.go +++ b/pkg/sql/sem/tree/stmt.go @@ -853,7 +853,13 @@ func (*Discard) StatementReturnType() StatementReturnType { return Ack } func (*Discard) StatementType() StatementType { return TypeTCL } // StatementTag returns a short string identifying the type of statement. -func (*Discard) StatementTag() string { return "DISCARD" } +func (d *Discard) StatementTag() string { + switch d.Mode { + case DiscardModeAll: + return "DISCARD ALL" + } + return "DISCARD" +} // StatementReturnType implements the Statement interface. func (n *DeclareCursor) StatementReturnType() StatementReturnType { return Ack } diff --git a/pkg/sql/sessiondatapb/local_only_session_data.proto b/pkg/sql/sessiondatapb/local_only_session_data.proto index 89c8234ed670..f4c9a3677a35 100644 --- a/pkg/sql/sessiondatapb/local_only_session_data.proto +++ b/pkg/sql/sessiondatapb/local_only_session_data.proto @@ -169,6 +169,8 @@ message LocalOnlySessionData { // established the connection before SET ROLE was first performed. // This is only populated when SET ROLE is used, otherwise the session_user // is the same as the UserProto in SessionData. + // Postgres allows the SessionUser to be changed with SET SESSION AUTHORIZATION + // but CockroachDB doesn't allow that at the time of this writing. string session_user_proto = 46 [(gogoproto.casttype) = "github.com/cockroachdb/cockroach/pkg/security/username.SQLUsernameProto"]; // TxnRowsWrittenLog is the threshold for the number of rows written by a SQL // transaction which - once exceeded - will trigger a logging event to SQL_PERF diff --git a/pkg/sql/set_session_authorization.go b/pkg/sql/set_session_authorization.go index ed61801eabae..b4d704e97d46 100644 --- a/pkg/sql/set_session_authorization.go +++ b/pkg/sql/set_session_authorization.go @@ -10,8 +10,26 @@ package sql +import ( + "context" + + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" +) + func (p *planner) SetSessionAuthorizationDefault() (planNode, error) { - // This is currently a no-op - we don't support changing the session - // authorization, and the parser only accepts DEFAULT. - return newZeroNode(nil /* columns */), nil + return &setSessionAuthorizationDefaultNode{}, nil +} + +type setSessionAuthorizationDefaultNode struct{} + +func (n *setSessionAuthorizationDefaultNode) Next(_ runParams) (bool, error) { return false, nil } +func (n *setSessionAuthorizationDefaultNode) Values() tree.Datums { return nil } +func (n *setSessionAuthorizationDefaultNode) Close(_ context.Context) {} +func (n *setSessionAuthorizationDefaultNode) startExec(params runParams) error { + // This is currently the same as `SET ROLE = DEFAULT`, which means that it + // only changes the "current user." In Postgres, `SET SESSION AUTHORIZATION` + // also changes the "session user," but since the session user cannot be + // modified in CockroachDB (at the time of writing), we just need to change + // the current user here. + return params.p.setRole(params.ctx, false /* local */, params.p.SessionData().SessionUser()) } diff --git a/pkg/sql/set_var.go b/pkg/sql/set_var.go index 7053eef4ae8b..8a1de41ca01b 100644 --- a/pkg/sql/set_var.go +++ b/pkg/sql/set_var.go @@ -200,7 +200,7 @@ func (n *setVarNode) Next(_ runParams) (bool, error) { return false, nil } func (n *setVarNode) Values() tree.Datums { return nil } func (n *setVarNode) Close(_ context.Context) {} -func (n *resetAllNode) startExec(params runParams) error { +func (p *planner) resetAllSessionVars(ctx context.Context) error { for varName, v := range varGen { if v.Set == nil && v.RuntimeSet == nil && v.SetWithPlanner == nil { continue @@ -212,16 +212,16 @@ func (n *resetAllNode) startExec(params runParams) error { hasDefault, defVal := getSessionVarDefaultString( varName, v, - params.p.sessionDataMutatorIterator.sessionDataMutatorBase, + p.sessionDataMutatorIterator.sessionDataMutatorBase, ) if !hasDefault { continue } - if err := params.p.SetSessionVar(params.ctx, varName, defVal, false /* isLocal */); err != nil { + if err := p.SetSessionVar(ctx, varName, defVal, false /* isLocal */); err != nil { return err } } - for varName := range params.SessionData().CustomOptions { + for varName := range p.SessionData().CustomOptions { _, v, err := getSessionVar(varName, false /* missingOK */) if err != nil { return err @@ -229,15 +229,19 @@ func (n *resetAllNode) startExec(params runParams) error { _, defVal := getSessionVarDefaultString( varName, v, - params.p.sessionDataMutatorIterator.sessionDataMutatorBase, + p.sessionDataMutatorIterator.sessionDataMutatorBase, ) - if err := params.p.SetSessionVar(params.ctx, varName, defVal, false /* isLocal */); err != nil { + if err := p.SetSessionVar(ctx, varName, defVal, false /* isLocal */); err != nil { return err } } return nil } +func (n *resetAllNode) startExec(params runParams) error { + return params.p.resetAllSessionVars(params.ctx) +} + func (n *resetAllNode) Next(_ runParams) (bool, error) { return false, nil } func (n *resetAllNode) Values() tree.Datums { return nil } func (n *resetAllNode) Close(_ context.Context) {} diff --git a/pkg/sql/walk.go b/pkg/sql/walk.go index 4135902f0a58..2a57fe543325 100644 --- a/pkg/sql/walk.go +++ b/pkg/sql/walk.go @@ -393,6 +393,7 @@ var planNodeNames = map[reflect.Type]string{ reflect.TypeOf(&delayedNode{}): "virtual table", reflect.TypeOf(&deleteNode{}): "delete", reflect.TypeOf(&deleteRangeNode{}): "delete range", + reflect.TypeOf(&discardNode{}): "discard", reflect.TypeOf(&distinctNode{}): "distinct", reflect.TypeOf(&dropDatabaseNode{}): "drop database", reflect.TypeOf(&dropExternalConnectionNode{}): "drop external connection", @@ -449,6 +450,7 @@ var planNodeNames = map[reflect.Type]string{ reflect.TypeOf(&sequenceSelectNode{}): "sequence select", reflect.TypeOf(&serializeNode{}): "run", reflect.TypeOf(&setClusterSettingNode{}): "set cluster setting", + reflect.TypeOf(&setSessionAuthorizationDefaultNode{}): "set session authorization", reflect.TypeOf(&setVarNode{}): "set", reflect.TypeOf(&setZoneConfigNode{}): "configure zone", reflect.TypeOf(&showFingerprintsNode{}): "show fingerprints",