diff --git a/pkg/ccl/auditloggingccl/audit_logging_test.go b/pkg/ccl/auditloggingccl/audit_logging_test.go index aff0d84ed329..575157eb5d98 100644 --- a/pkg/ccl/auditloggingccl/audit_logging_test.go +++ b/pkg/ccl/auditloggingccl/audit_logging_test.go @@ -156,6 +156,22 @@ func TestSingleRoleAuditLogging(t *testing.T) { `GRANT SELECT ON TABLE u TO root`, // DML statement `SELECT * FROM u`, + // The following statements are all executed specially by the conn_executor. + `SET application_name = 'test'`, + `SET CLUSTER SETTING sql.defaults.vectorize = 'on'`, + `BEGIN`, + `SHOW application_name`, + `SAVEPOINT s`, + `RELEASE SAVEPOINT s`, + `SAVEPOINT t`, + `ROLLBACK TO SAVEPOINT t`, + `COMMIT`, + `SHOW COMMIT TIMESTAMP`, + `BEGIN TRANSACTION PRIORITY LOW`, + `ROLLBACK`, + `PREPARE q AS SELECT 1`, + `EXECUTE q`, + `DEALLOCATE q`, } testData := []struct { name string @@ -167,7 +183,7 @@ func TestSingleRoleAuditLogging(t *testing.T) { name: "test-all-stmt-types", role: allStmtTypesRole, queries: testQueries, - expectedNumLogs: 3, + expectedNumLogs: len(testQueries), }, { name: "test-no-stmt-types", @@ -181,7 +197,7 @@ func TestSingleRoleAuditLogging(t *testing.T) { role: "testuser", queries: testQueries, // One for each test query. - expectedNumLogs: 3, + expectedNumLogs: len(testQueries), }, } diff --git a/pkg/sql/conn_executor_exec.go b/pkg/sql/conn_executor_exec.go index 97afdc251c4c..8b8b2eeaf11d 100644 --- a/pkg/sql/conn_executor_exec.go +++ b/pkg/sql/conn_executor_exec.go @@ -133,7 +133,7 @@ func (ex *connExecutor) execStmt( // Note: when not using explicit transactions, we go through this transition // for every statement. It is important to minimize the amount of work and // allocations performed up to this point. - ev, payload = ex.execStmtInNoTxnState(ctx, ast, res) + ev, payload = ex.execStmtInNoTxnState(ctx, parserStmt, res) case stateOpen: var preparedStmt *PreparedStatement @@ -753,6 +753,72 @@ func (ex *connExecutor) execStmtInOpenState( } }(ctx) + // If adminAuditLogging is enabled, we want to check for HasAdminRole + // before maybeLogStatement. + // We must check prior to execution in the case the txn is aborted due to + // an error. HasAdminRole can only be checked in a valid txn. + if adminAuditLog := adminAuditLogEnabled.Get( + &ex.planner.execCfg.Settings.SV, + ); adminAuditLog { + if !ex.extraTxnState.hasAdminRoleCache.IsSet { + hasAdminRole, err := ex.planner.HasAdminRole(ctx) + if err != nil { + return makeErrEvent(err) + } + ex.extraTxnState.hasAdminRoleCache.HasAdminRole = hasAdminRole + ex.extraTxnState.hasAdminRoleCache.IsSet = true + } + } + + p.stmt = stmt + p.semaCtx.Annotations = tree.MakeAnnotations(stmt.NumAnnotations) + p.extendedEvalCtx.Annotations = &p.semaCtx.Annotations + if err := p.semaCtx.Placeholders.Assign(pinfo, stmt.NumPlaceholders); err != nil { + return makeErrEvent(err) + } + p.extendedEvalCtx.Placeholders = &p.semaCtx.Placeholders + + shouldLogToExecAndAudit := true + defer func() { + if !shouldLogToExecAndAudit { + // We don't want to log this statement, since another layer of the + // conn_executor will handle the logging for this statement. + return + } + + p.curPlan.init(&p.stmt, &p.instrumentation) + var execErr error + if p, ok := retPayload.(payloadWithError); ok { + execErr = p.errorCause() + } + f := tree.NewFmtCtx(tree.FmtHideConstants) + f.FormatNode(ast) + stmtFingerprintID := appstatspb.ConstructStatementFingerprintID( + f.CloseAndGetString(), + execErr != nil, + ex.implicitTxn(), + p.CurrentDatabase(), + ) + + p.maybeLogStatement( + ctx, + ex.executorType, + false, /* isCopy */ + int(ex.state.mu.autoRetryCounter), + ex.extraTxnState.txnCounter, + 0, /* rowsAffected */ + ex.state.mu.stmtCount, + 0, /* bulkJobId */ + execErr, + ex.statsCollector.PhaseTimes().GetSessionPhaseTime(sessionphase.SessionQueryReceived), + &ex.extraTxnState.hasAdminRoleCache, + ex.server.TelemetryLoggingMetrics, + stmtFingerprintID, + &topLevelQueryStats{}, + ex.statsCollector, + ) + }() + switch s := ast.(type) { case *tree.BeginTransaction: // BEGIN is only allowed if we are in an implicit txn. @@ -833,15 +899,27 @@ func (ex *connExecutor) execStmtInOpenState( ex.server.cfg.GenerateID(), ) var rawTypeHints []oid.Oid + + // Placeholders should be part of the statement being prepared, not the + // PREPARE statement itself. + oldPlaceholders := p.extendedEvalCtx.Placeholders + p.extendedEvalCtx.Placeholders = nil if _, err := ex.addPreparedStmt( ctx, name, prepStmt, typeHints, rawTypeHints, PreparedStatementOriginSQL, ); err != nil { return makeErrEvent(err) } + // The call to addPreparedStmt changed the planner stmt to the statement + // being prepared. Set it back to the PREPARE statement, so that it's + // logged correctly. + p.stmt = stmt + p.extendedEvalCtx.Placeholders = oldPlaceholders return nil, nil, nil } - p.semaCtx.Annotations = tree.MakeAnnotations(stmt.NumAnnotations) + // Don't write to the exec/audit logs here; it will be handled in + // dispatchToExecutionEngine. + shouldLogToExecAndAudit = false // For regular statements (the ones that get to this point), we // don't return any event unless an error happens. @@ -894,12 +972,6 @@ func (ex *connExecutor) execStmtInOpenState( return makeErrEvent(err) } - if err := p.semaCtx.Placeholders.Assign(pinfo, stmt.NumPlaceholders); err != nil { - return makeErrEvent(err) - } - p.extendedEvalCtx.Placeholders = &p.semaCtx.Placeholders - p.extendedEvalCtx.Annotations = &p.semaCtx.Annotations - p.stmt = stmt if isPausablePortal() { p.pausablePortal = portal } @@ -1453,23 +1525,6 @@ func (ex *connExecutor) dispatchToExecutionEngine( } } - // If adminAuditLogging is enabled, we want to check for HasAdminRole - // before the deferred maybeLogStatement. - // We must check prior to execution in the case the txn is aborted due to - // an error. HasAdminRole can only be checked in a valid txn. - if adminAuditLog := adminAuditLogEnabled.Get( - &ex.planner.execCfg.Settings.SV, - ); adminAuditLog { - if !ex.extraTxnState.hasAdminRoleCache.IsSet { - hasAdminRole, err := ex.planner.HasAdminRole(ctx) - if err != nil { - return err - } - ex.extraTxnState.hasAdminRoleCache.HasAdminRole = hasAdminRole - ex.extraTxnState.hasAdminRoleCache.IsSet = true - } - } - var err error if ppInfo := getPausablePortalInfo(); ppInfo != nil { if !ppInfo.dispatchToExecutionEngine.cleanup.isComplete { @@ -2190,8 +2245,57 @@ var eventStartExplicitTxn fsm.Event = eventTxnStart{ImplicitTxn: fsm.False} // the cursor is not advanced. This means that the statement will run again in // stateOpen, at each point its results will also be flushed. func (ex *connExecutor) execStmtInNoTxnState( - ctx context.Context, ast tree.Statement, res RestrictedCommandResult, + ctx context.Context, parserStmt statements.Statement[tree.Statement], res RestrictedCommandResult, ) (_ fsm.Event, payload fsm.EventPayload) { + shouldLogToExecAndAudit := true + defer func() { + if !shouldLogToExecAndAudit { + // We don't want to log this statement, since another layer of the + // conn_executor will handle the logging for this statement. + return + } + + p := &ex.planner + stmt := makeStatement(parserStmt, ex.server.cfg.GenerateID()) + p.stmt = stmt + p.semaCtx.Annotations = tree.MakeAnnotations(stmt.NumAnnotations) + p.extendedEvalCtx.Annotations = &p.semaCtx.Annotations + p.extendedEvalCtx.Placeholders = &tree.PlaceholderInfo{} + p.curPlan.init(&p.stmt, &p.instrumentation) + var execErr error + if p, ok := payload.(payloadWithError); ok { + execErr = p.errorCause() + } + + f := tree.NewFmtCtx(tree.FmtHideConstants) + f.FormatNode(stmt.AST) + stmtFingerprintID := appstatspb.ConstructStatementFingerprintID( + f.CloseAndGetString(), + execErr != nil, + ex.implicitTxn(), + p.CurrentDatabase(), + ) + + p.maybeLogStatement( + ctx, + ex.executorType, + false, /* isCopy */ + int(ex.state.mu.autoRetryCounter), + ex.extraTxnState.txnCounter, + 0, /* rowsAffected */ + 0, /* stmtCount */ + 0, /* bulkJobId */ + execErr, + ex.statsCollector.PhaseTimes().GetSessionPhaseTime(sessionphase.SessionQueryReceived), + &ex.extraTxnState.hasAdminRoleCache, + ex.server.TelemetryLoggingMetrics, + stmtFingerprintID, + &topLevelQueryStats{}, + ex.statsCollector, + ) + }() + + ast := parserStmt.AST switch s := ast.(type) { case *tree.BeginTransaction: ex.incrementStartedStmtCounter(ast) @@ -2225,6 +2329,7 @@ func (ex *connExecutor) execStmtInNoTxnState( // historical timestamp even though the statement itself might contain // an AOST clause. In these cases the clause is evaluated and applied // execStmtInOpenState. + shouldLogToExecAndAudit = false noBeginStmt := (*tree.BeginTransaction)(nil) mode, sqlTs, historicalTs, err := ex.beginTransactionTimestampsAndReadMode(ctx, noBeginStmt) if err != nil { diff --git a/pkg/sql/telemetry_logging_test.go b/pkg/sql/telemetry_logging_test.go index a51e4ea76a7a..981249f6427a 100644 --- a/pkg/sql/telemetry_logging_test.go +++ b/pkg/sql/telemetry_logging_test.go @@ -1505,15 +1505,17 @@ func TestTelemetryLoggingStmtPosInTxn(t *testing.T) { st.SetTime(timeutil.FromUnixMicros(int64(1e6))) db.Exec(t, `BEGIN;`) - db.Exec(t, `SELECT 1`) st.SetTime(timeutil.FromUnixMicros(int64(2 * 1e6))) - db.Exec(t, `SELECT 2`) + db.Exec(t, `SELECT 1`) st.SetTime(timeutil.FromUnixMicros(int64(3 * 1e6))) + db.Exec(t, `SELECT 2`) + st.SetTime(timeutil.FromUnixMicros(int64(4 * 1e6))) db.Exec(t, `SELECT 3`) + st.SetTime(timeutil.FromUnixMicros(int64(5 * 1e6))) db.Exec(t, `COMMIT;`) expectedQueries := []string{ - `SELECT ‹1›`, `SELECT ‹2›`, `SELECT ‹3›`, + `BEGIN`, `SELECT ‹1›`, `SELECT ‹2›`, `SELECT ‹3›`, `COMMIT`, } log.Flush() @@ -1539,7 +1541,7 @@ func TestTelemetryLoggingStmtPosInTxn(t *testing.T) { if strings.Contains(e.Message, expected) { var sq eventpb.SampledQuery require.NoError(t, json.Unmarshal([]byte(e.Message), &sq)) - require.Equal(t, uint32(i+1), sq.StmtPosInTxn, "%s", entries) + require.Equalf(t, uint32(i), sq.StmtPosInTxn, "stmt=%s entries: %s", expected, entries) found = true break }