diff --git a/pkg/server/api_v2_sql.go b/pkg/server/api_v2_sql.go index 4e4e2783128d..efd9047dd65b 100644 --- a/pkg/server/api_v2_sql.go +++ b/pkg/server/api_v2_sql.go @@ -21,6 +21,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/kv" "github.com/cockroachdb/cockroach/pkg/sql/catalog/colinfo" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/descs" "github.com/cockroachdb/cockroach/pkg/sql/parser" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" @@ -360,13 +361,27 @@ func (a *apiV2Server) execSQL(w http.ResponseWriter, r *http.Request) { // runner is the function that will execute all the statements as a group. // If there's just one statement, we execute them with an implicit, // auto-commit transaction. - runner := func(ctx context.Context, fn func(context.Context, *kv.Txn) error) error { return fn(ctx, nil) } + + type ( + txnFunc = func(context.Context, *kv.Txn, sqlutil.InternalExecutor) error + runnerFunc = func(ctx context.Context, fn txnFunc) error + ) + var runner runnerFunc if len(requestPayload.Statements) > 1 { // We need a transaction to group the statements together. // We use TxnWithSteppingEnabled here even though we don't // use stepping below, because that buys us admission control. - runner = func(ctx context.Context, fn func(context.Context, *kv.Txn) error) error { - return a.admin.server.db.TxnWithSteppingEnabled(ctx, sessiondatapb.Normal, fn) + cf := a.admin.server.sqlServer.execCfg.CollectionFactory + runner = func(ctx context.Context, fn txnFunc) error { + return cf.TxnWithExecutor(ctx, a.admin.server.db, nil, func( + ctx context.Context, txn *kv.Txn, _ *descs.Collection, ie sqlutil.InternalExecutor, + ) error { + return fn(ctx, txn, ie) + }, descs.SteppingEnabled()) + } + } else { + runner = func(ctx context.Context, fn func(context.Context, *kv.Txn, sqlutil.InternalExecutor) error) error { + return fn(ctx, nil, a.admin.ie) } } @@ -376,7 +391,7 @@ func (a *apiV2Server) execSQL(w http.ResponseWriter, r *http.Request) { err = contextutil.RunWithTimeout(ctx, "run-sql-via-api", timeout, func(ctx context.Context) error { retryNum := 0 - return runner(ctx, func(ctx context.Context, txn *kv.Txn) error { + return runner(ctx, func(ctx context.Context, txn *kv.Txn, ie sqlutil.InternalExecutor) error { result.Execution.TxnResults = result.Execution.TxnResults[:0] result.Execution.Retries = retryNum retryNum++ @@ -413,7 +428,7 @@ func (a *apiV2Server) execSQL(w http.ResponseWriter, r *http.Request) { } }() - it, err := a.admin.ie.QueryIteratorEx(ctx, "run-query-via-api", txn, + it, err := ie.QueryIteratorEx(ctx, "run-query-via-api", txn, sessiondata.InternalExecutorOverride{ User: username, Database: requestPayload.Database, diff --git a/pkg/server/testdata/api_v2_sql b/pkg/server/testdata/api_v2_sql index e7fad39181ac..c1117e61117a 100644 --- a/pkg/server/testdata/api_v2_sql +++ b/pkg/server/testdata/api_v2_sql @@ -271,3 +271,186 @@ sql admin expect-error } ---- XXUUU|parsing statement 1: expected 1 placeholder(s), got 2 + +sql admin +{ + "database": "mydb", + "execute": true, + "statements": [{"sql": "CREATE TABLE foo (i INT PRIMARY KEY, j INT UNIQUE)"}] +} +---- +{ + "execution": { + "txn_results": [ + { + "columns": [ + { + "name": "rows_affected", + "oid": 20, + "type": "INT8" + } + ], + "end": "1970-01-01T00:00:00Z", + "rows_affected": 0, + "start": "1970-01-01T00:00:00Z", + "statement": 1, + "tag": "CREATE TABLE" + } + ] + }, + "num_statements": 1 +} + +sql admin +{ + "database": "mydb", + "execute": true, + "statements": [ + {"sql": "ALTER TABLE foo RENAME TO bar"}, + {"sql": "INSERT INTO bar (i) VALUES (1), (2)"}, + {"sql": "ALTER TABLE bar DROP COLUMN j"}, + {"sql": "ALTER TABLE bar ADD COLUMN k INT DEFAULT 42"} + ] +} +---- +{ + "execution": { + "txn_results": [ + { + "columns": [ + { + "name": "rows_affected", + "oid": 20, + "type": "INT8" + } + ], + "end": "1970-01-01T00:00:00Z", + "rows_affected": 0, + "start": "1970-01-01T00:00:00Z", + "statement": 1, + "tag": "ALTER TABLE" + }, + { + "columns": [ + { + "name": "rows_affected", + "oid": 20, + "type": "INT8" + } + ], + "end": "1970-01-01T00:00:00Z", + "rows_affected": 2, + "start": "1970-01-01T00:00:00Z", + "statement": 2, + "tag": "INSERT" + }, + { + "columns": [ + { + "name": "rows_affected", + "oid": 20, + "type": "INT8" + } + ], + "end": "1970-01-01T00:00:00Z", + "rows_affected": 0, + "start": "1970-01-01T00:00:00Z", + "statement": 3, + "tag": "ALTER TABLE" + }, + { + "columns": [ + { + "name": "rows_affected", + "oid": 20, + "type": "INT8" + } + ], + "end": "1970-01-01T00:00:00Z", + "rows_affected": 0, + "start": "1970-01-01T00:00:00Z", + "statement": 4, + "tag": "ALTER TABLE" + } + ] + }, + "num_statements": 4 +} + +sql admin +{ + "database": "mydb", + "execute": true, + "statements": [ + {"sql": "SELECT * FROM bar"} + ] +} +---- +{ + "execution": { + "txn_results": [ + { + "columns": [ + { + "name": "i", + "oid": 20, + "type": "INT8" + }, + { + "name": "k", + "oid": 20, + "type": "INT8" + } + ], + "end": "1970-01-01T00:00:00Z", + "rows": [ + { + "i": 1, + "k": 42 + }, + { + "i": 2, + "k": 42 + } + ], + "rows_affected": 0, + "start": "1970-01-01T00:00:00Z", + "statement": 1, + "tag": "SELECT" + } + ] + }, + "num_statements": 1 +} + + +sql admin +{ + "database": "mydb", + "execute": true, + "statements": [ + {"sql": "DROP TABLE bar"} + ] +} +---- +{ + "execution": { + "txn_results": [ + { + "columns": [ + { + "name": "rows_affected", + "oid": 20, + "type": "INT8" + } + ], + "end": "1970-01-01T00:00:00Z", + "rows_affected": 0, + "start": "1970-01-01T00:00:00Z", + "statement": 1, + "tag": "DROP TABLE" + } + ] + }, + "num_statements": 1 +} diff --git a/pkg/sql/catalog/descs/BUILD.bazel b/pkg/sql/catalog/descs/BUILD.bazel index a8f7c33d0446..5eb90c141119 100644 --- a/pkg/sql/catalog/descs/BUILD.bazel +++ b/pkg/sql/catalog/descs/BUILD.bazel @@ -60,6 +60,7 @@ go_library( "//pkg/sql/sem/catconstants", "//pkg/sql/sem/tree", "//pkg/sql/sessiondata", + "//pkg/sql/sessiondatapb", "//pkg/sql/sqlerrors", "//pkg/sql/sqlliveness", "//pkg/sql/sqlutil", @@ -85,6 +86,7 @@ go_test( "errors_test.go", "helpers_test.go", "main_test.go", + "txn_external_test.go", "txn_with_executor_datadriven_test.go", ], data = glob(["testdata/**"]), @@ -127,6 +129,7 @@ go_test( "//pkg/util/mon", "//pkg/util/randutil", "@com_github_cockroachdb_datadriven//:datadriven", + "@com_github_cockroachdb_errors//:errors", "@com_github_cockroachdb_redact//:redact", "@com_github_lib_pq//oid", "@com_github_stretchr_testify//assert", diff --git a/pkg/sql/catalog/descs/txn.go b/pkg/sql/catalog/descs/txn.go index 4d2de8394083..0b14da74bf07 100644 --- a/pkg/sql/catalog/descs/txn.go +++ b/pkg/sql/catalog/descs/txn.go @@ -22,6 +22,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" + "github.com/cockroachdb/cockroach/pkg/sql/sessiondatapb" "github.com/cockroachdb/cockroach/pkg/sql/sqlutil" "github.com/cockroachdb/cockroach/pkg/util/hlc" "github.com/cockroachdb/cockroach/pkg/util/retry" @@ -43,14 +44,51 @@ func (cf *CollectionFactory) Txn( ctx context.Context, db *kv.DB, f func(ctx context.Context, txn *kv.Txn, descriptors *Collection) error, + opts ...TxnOption, ) error { return cf.TxnWithExecutor(ctx, db, nil /* sessionData */, func( ctx context.Context, txn *kv.Txn, descriptors *Collection, _ sqlutil.InternalExecutor, ) error { return f(ctx, txn, descriptors) - }) + }, opts...) } +// TxnOption is used to configure a Txn or TxnWithExecutor. +type TxnOption interface { + apply(*txnConfig) +} + +type txnConfig struct { + steppingEnabled bool +} + +type txnOptionFn func(options *txnConfig) + +func (f txnOptionFn) apply(options *txnConfig) { f(options) } + +var steppingEnabled = txnOptionFn(func(o *txnConfig) { + o.steppingEnabled = true +}) + +// SteppingEnabled creates a TxnOption to determine whether the underlying +// transaction should have stepping enabled. If stepping is enabled, the +// transaction will implicitly use lower admission priority. However, the +// user will need to remember to Step the Txn to make writes visible. The +// InternalExecutor will automatically (for better or for worse) step the +// transaction when executing each statement. +func SteppingEnabled() TxnOption { + return steppingEnabled +} + +// TxnWithExecutorFunc is used to run a transaction in the context of a +// Collection and an InternalExecutor. +type TxnWithExecutorFunc = func( + ctx context.Context, + txn *kv.Txn, + descriptors *Collection, + ie sqlutil.InternalExecutor, +) error + // TxnWithExecutor enables callers to run transactions with a *Collection such that all // retrieved immutable descriptors are properly leased and all mutable // descriptors are handled. The function deals with verifying the two version @@ -66,8 +104,21 @@ func (cf *CollectionFactory) TxnWithExecutor( ctx context.Context, db *kv.DB, sd *sessiondata.SessionData, - f func(ctx context.Context, txn *kv.Txn, descriptors *Collection, ie sqlutil.InternalExecutor) error, + f TxnWithExecutorFunc, + opts ...TxnOption, ) error { + var config txnConfig + for _, opt := range opts { + opt.apply(&config) + } + run := db.Txn + if config.steppingEnabled { + type kvTxnFunc = func(context.Context, *kv.Txn) error + run = func(ctx context.Context, f kvTxnFunc) error { + return db.TxnWithSteppingEnabled(ctx, sessiondatapb.Normal, f) + } + } + // Waits for descriptors that were modified, skipping // over ones that had their descriptor wiped. waitForDescriptors := func(modifiedDescriptors []lease.IDVersion, deletedDescs catalog.DescriptorIDSet) error { @@ -97,7 +148,7 @@ func (cf *CollectionFactory) TxnWithExecutor( for { var modifiedDescriptors []lease.IDVersion var deletedDescs catalog.DescriptorIDSet - if err := db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) error { + if err := run(ctx, func(ctx context.Context, txn *kv.Txn) error { modifiedDescriptors, deletedDescs = nil, catalog.DescriptorIDSet{} descsCol := cf.NewCollection( ctx, nil, /* temporarySchemaProvider */ diff --git a/pkg/sql/catalog/descs/txn_external_test.go b/pkg/sql/catalog/descs/txn_external_test.go new file mode 100644 index 000000000000..941610269f51 --- /dev/null +++ b/pkg/sql/catalog/descs/txn_external_test.go @@ -0,0 +1,71 @@ +// Copyright 2022 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package descs_test + +import ( + "context" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/kv" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/descs" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/require" +) + +// TestTxnWithStepping tests that if the user opts into stepping, they +// get stepping. +func TestTxnWithStepping(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + s, _, kvDB := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(ctx) + + cf := s.CollectionFactory().(*descs.CollectionFactory) + scratchKey, err := s.ScratchRange() + require.NoError(t, err) + // Write a key, read in the transaction without stepping, ensure we + // do not see the value, step the transaction, then ensure that we do. + require.NoError(t, cf.Txn(ctx, kvDB, func( + ctx context.Context, txn *kv.Txn, descriptors *descs.Collection, + ) error { + if err := txn.Put(ctx, scratchKey, 1); err != nil { + return err + } + { + got, err := txn.Get(ctx, scratchKey) + if err != nil { + return err + } + if got.Exists() { + return errors.AssertionFailedf("expected no value, got %v", got) + } + } + if err := txn.Step(ctx); err != nil { + return err + } + { + got, err := txn.Get(ctx, scratchKey) + if err != nil { + return err + } + if got.ValueInt() != 1 { + return errors.AssertionFailedf("expected 1, got %v", got) + } + } + return nil + }, descs.SteppingEnabled())) +}