diff --git a/pkg/ccl/changefeedccl/cdctest/BUILD.bazel b/pkg/ccl/changefeedccl/cdctest/BUILD.bazel index f61ecd134167..fe6307fd9e57 100644 --- a/pkg/ccl/changefeedccl/cdctest/BUILD.bazel +++ b/pkg/ccl/changefeedccl/cdctest/BUILD.bazel @@ -12,6 +12,7 @@ go_library( deps = [ "//pkg/jobs", "//pkg/jobs/jobspb", + "//pkg/roachpb", "//pkg/sql", "//pkg/sql/parser", "//pkg/sql/sem/tree", diff --git a/pkg/ccl/changefeedccl/cdctest/validator.go b/pkg/ccl/changefeedccl/cdctest/validator.go index 8025e299f35d..77c7ed060a07 100644 --- a/pkg/ccl/changefeedccl/cdctest/validator.go +++ b/pkg/ccl/changefeedccl/cdctest/validator.go @@ -16,6 +16,7 @@ import ( "sort" "strings" + "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/sql" "github.com/cockroachdb/cockroach/pkg/util/hlc" "github.com/cockroachdb/errors" @@ -32,15 +33,83 @@ type Validator interface { Failures() []string } +// StreamClientValidatorWrapper wraps a Validator and exposes additional methods +// used by stream ingestion to check for correctness. +type StreamClientValidatorWrapper interface { + GetValuesForKeyBelowTimestamp(key string, timestamp hlc.Timestamp) ([]roachpb.KeyValue, error) + GetValidator() Validator +} + +type streamValidator struct { + Validator +} + +var _ StreamClientValidatorWrapper = &streamValidator{} + +// NewStreamClientValidatorWrapper returns a wrapped Validator, that can be used +// to validate the events emitted by the cluster to cluster streaming client. +// The wrapper currently only "wraps" an orderValidator, but can be built out +// to utilize other Validator's. +// The wrapper also allows querying the orderValidator to retrieve streamed +// events from an in-memory store. +func NewStreamClientValidatorWrapper() StreamClientValidatorWrapper { + ov := NewOrderValidator("unusedC2C") + return &streamValidator{ + ov, + } +} + +// GetValidator implements the StreamClientValidatorWrapper interface. +func (sv *streamValidator) GetValidator() Validator { + return sv.Validator +} + +// GetValuesForKeyBelowTimestamp implements the StreamClientValidatorWrapper +// interface. +// It returns the streamed KV updates for `key` with a ts less than equal to +// `timestamp`. +func (sv *streamValidator) GetValuesForKeyBelowTimestamp( + key string, timestamp hlc.Timestamp, +) ([]roachpb.KeyValue, error) { + orderValidator, ok := sv.GetValidator().(*orderValidator) + if !ok { + return nil, errors.Newf("unknown validator %T: ", sv.GetValidator()) + } + timestampValueTuples := orderValidator.keyTimestampAndValues[key] + timestampsIdx := sort.Search(len(timestampValueTuples), func(i int) bool { + return timestamp.Less(timestampValueTuples[i].ts) + }) + var kv []roachpb.KeyValue + for _, tsValue := range timestampValueTuples[:timestampsIdx] { + byteRep := []byte(key) + kv = append(kv, roachpb.KeyValue{ + Key: byteRep, + Value: roachpb.Value{ + RawBytes: []byte(tsValue.value), + Timestamp: tsValue.ts, + }, + }) + } + + return kv, nil +} + +type timestampValue struct { + ts hlc.Timestamp + value string +} + type orderValidator struct { - topic string - partitionForKey map[string]string - keyTimestamps map[string][]hlc.Timestamp - resolved map[string]hlc.Timestamp + topic string + partitionForKey map[string]string + keyTimestampAndValues map[string][]timestampValue + resolved map[string]hlc.Timestamp failures []string } +var _ Validator = &orderValidator{} + // NewOrderValidator returns a Validator that checks the row and resolved // timestamp ordering guarantees. 
It also asserts that keys have an affinity to // a single partition. @@ -52,17 +121,15 @@ type orderValidator struct { // lower update timestamp will be emitted on that partition. func NewOrderValidator(topic string) Validator { return &orderValidator{ - topic: topic, - partitionForKey: make(map[string]string), - keyTimestamps: make(map[string][]hlc.Timestamp), - resolved: make(map[string]hlc.Timestamp), + topic: topic, + partitionForKey: make(map[string]string), + keyTimestampAndValues: make(map[string][]timestampValue), + resolved: make(map[string]hlc.Timestamp), } } // NoteRow implements the Validator interface. -func (v *orderValidator) NoteRow( - partition string, key, ignoredValue string, updated hlc.Timestamp, -) error { +func (v *orderValidator) NoteRow(partition string, key, value string, updated hlc.Timestamp) error { if prev, ok := v.partitionForKey[key]; ok && prev != partition { v.failures = append(v.failures, fmt.Sprintf( `key [%s] received on two partitions: %s and %s`, key, prev, partition, @@ -71,17 +138,20 @@ func (v *orderValidator) NoteRow( } v.partitionForKey[key] = partition - timestamps := v.keyTimestamps[key] - timestampsIdx := sort.Search(len(timestamps), func(i int) bool { - return updated.LessEq(timestamps[i]) + timestampValueTuples := v.keyTimestampAndValues[key] + timestampsIdx := sort.Search(len(timestampValueTuples), func(i int) bool { + return updated.LessEq(timestampValueTuples[i].ts) }) - seen := timestampsIdx < len(timestamps) && timestamps[timestampsIdx] == updated + seen := timestampsIdx < len(timestampValueTuples) && + timestampValueTuples[timestampsIdx].ts == updated - if !seen && len(timestamps) > 0 && updated.Less(timestamps[len(timestamps)-1]) { + if !seen && len(timestampValueTuples) > 0 && + updated.Less(timestampValueTuples[len(timestampValueTuples)-1].ts) { v.failures = append(v.failures, fmt.Sprintf( `topic %s partition %s: saw new row timestamp %s after %s was seen`, v.topic, partition, - updated.AsOfSystemTime(), timestamps[len(timestamps)-1].AsOfSystemTime(), + updated.AsOfSystemTime(), + timestampValueTuples[len(timestampValueTuples)-1].ts.AsOfSystemTime(), )) } if !seen && updated.Less(v.resolved[partition]) { @@ -92,8 +162,12 @@ func (v *orderValidator) NoteRow( } if !seen { - v.keyTimestamps[key] = append( - append(timestamps[:timestampsIdx], updated), timestamps[timestampsIdx:]...) + v.keyTimestampAndValues[key] = append( + append(timestampValueTuples[:timestampsIdx], timestampValue{ + ts: updated, + value: value, + }), + timestampValueTuples[timestampsIdx:]...) } return nil } diff --git a/pkg/ccl/streamingccl/addresses.go b/pkg/ccl/streamingccl/addresses.go index 8153fd8d3f66..cb358eba22a0 100644 --- a/pkg/ccl/streamingccl/addresses.go +++ b/pkg/ccl/streamingccl/addresses.go @@ -25,6 +25,11 @@ func (sa StreamAddress) URL() (*url.URL, error) { // Each partition will emit events for a fixed span of keys. type PartitionAddress string +// URL parses the partition address as a URL. +func (pa PartitionAddress) URL() (*url.URL, error) { + return url.Parse(string(pa)) +} + // Topology is a configuration of stream partitions. These are particular to a // stream. It specifies the number and addresses of partitions of the stream. 
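The validator wrapper added above is what the stream ingestion tests later in this change use to record every streamed event and query it back per key. Below is a minimal sketch of how it is intended to compose with the per-partition interceptors introduced in the random stream client further down; the package placement, function names, and the ignored errors are illustrative only and are not part of the change itself.

```go
package streamingest_test

import (
	"github.com/cockroachdb/cockroach/pkg/ccl/changefeedccl/cdctest"
	"github.com/cockroachdb/cockroach/pkg/ccl/streamingccl"
	"github.com/cockroachdb/cockroach/pkg/roachpb"
	"github.com/cockroachdb/cockroach/pkg/util/hlc"
)

// validatingInterceptor records every streamed event in the wrapped validator,
// keyed by the partition it arrived on. It mirrors registerValidator in the
// processor test below; error handling is elided for brevity.
func validatingInterceptor(
	wrapper cdctest.StreamClientValidatorWrapper,
) func(streamingccl.Event, streamingccl.PartitionAddress) {
	v := wrapper.GetValidator()
	return func(event streamingccl.Event, pa streamingccl.PartitionAddress) {
		switch event.Type() {
		case streamingccl.KVEvent:
			kv := *event.GetKV()
			_ = v.NoteRow(string(pa), string(kv.Key), string(kv.Value.RawBytes),
				kv.Value.Timestamp)
		case streamingccl.CheckpointEvent:
			_ = v.NoteResolved(string(pa), *event.GetResolved())
		}
	}
}

// expectedVersions asks the wrapper which KV updates should have been visible
// for a key at or below a resolved timestamp, so a test can compare them
// against the MVCC versions actually ingested into the store.
func expectedVersions(
	wrapper cdctest.StreamClientValidatorWrapper, key string, ts hlc.Timestamp,
) ([]roachpb.KeyValue, error) {
	return wrapper.GetValuesForKeyBelowTimestamp(key, ts)
}
```

Passing the partition address into the interceptor keeps the ordering and resolved-timestamp checks scoped per partition, which matches how the order validator asserts key-to-partition affinity.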
// diff --git a/pkg/ccl/streamingccl/streamclient/BUILD.bazel b/pkg/ccl/streamingccl/streamclient/BUILD.bazel index 33b33eae133a..fbf260b75c24 100644 --- a/pkg/ccl/streamingccl/streamclient/BUILD.bazel +++ b/pkg/ccl/streamingccl/streamclient/BUILD.bazel @@ -14,12 +14,14 @@ go_library( "//pkg/keys", "//pkg/roachpb", "//pkg/sql", + "//pkg/sql/catalog/catalogkeys", "//pkg/sql/catalog/descpb", "//pkg/sql/catalog/systemschema", "//pkg/sql/catalog/tabledesc", "//pkg/sql/rowenc", "//pkg/sql/sem/tree", "//pkg/util/hlc", + "//pkg/util/randutil", "//pkg/util/syncutil", "//pkg/util/timeutil", ], diff --git a/pkg/ccl/streamingccl/streamclient/client_test.go b/pkg/ccl/streamingccl/streamclient/client_test.go index 4aea27ee9979..7acee6686846 100644 --- a/pkg/ccl/streamingccl/streamclient/client_test.go +++ b/pkg/ccl/streamingccl/streamclient/client_test.go @@ -38,7 +38,7 @@ func (sc testStreamClient) GetTopology( // ConsumePartition implements the Client interface. func (sc testStreamClient) ConsumePartition( - _ context.Context, _ streamingccl.PartitionAddress, _ time.Time, + _ context.Context, pa streamingccl.PartitionAddress, _ time.Time, ) (chan streamingccl.Event, error) { sampleKV := roachpb.KeyValue{ Key: []byte("key_1"), diff --git a/pkg/ccl/streamingccl/streamclient/random_stream_client.go b/pkg/ccl/streamingccl/streamclient/random_stream_client.go index 822ac6760617..41cc2a85917a 100644 --- a/pkg/ccl/streamingccl/streamclient/random_stream_client.go +++ b/pkg/ccl/streamingccl/streamclient/random_stream_client.go @@ -10,6 +10,7 @@ package streamclient import ( "context" + "fmt" "math/rand" "net/url" "strconv" @@ -19,20 +20,22 @@ import ( "github.com/cockroachdb/cockroach/pkg/keys" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/sql" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/catalogkeys" "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" "github.com/cockroachdb/cockroach/pkg/sql/catalog/systemschema" "github.com/cockroachdb/cockroach/pkg/sql/catalog/tabledesc" "github.com/cockroachdb/cockroach/pkg/sql/rowenc" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/util/hlc" + "github.com/cockroachdb/cockroach/pkg/util/randutil" "github.com/cockroachdb/cockroach/pkg/util/syncutil" "github.com/cockroachdb/cockroach/pkg/util/timeutil" ) const ( - // RandomStreamSchema is the schema of the KVs emitted by the random stream - // client. - RandomStreamSchema = "CREATE TABLE test (k INT PRIMARY KEY, v INT)" + // RandomStreamSchemaPlaceholder is the schema of the KVs emitted by the + // random stream client. + RandomStreamSchemaPlaceholder = "CREATE TABLE %s (k INT PRIMARY KEY, v INT)" // TestScheme is the URI scheme used to create a test load. TestScheme = "test" @@ -45,14 +48,37 @@ const ( // KVsPerCheckpoint controls approximately how many KV events should be emitted // between checkpoint events. KVsPerCheckpoint = "KVS_PER_CHECKPOINT" + // NumPartitions controls the number of partitions the client will stream data + // back on. Each partition will encompass a single table span. + NumPartitions = "NUM_PARTITIONS" + // DupProbability controls the probability with which we emit duplicate KV + // events. + DupProbability = "DUP_PROBABILITY" + // IngestionDatabaseID is the ID used in the generated table descriptor. + IngestionDatabaseID = 50 /* defaultDB */ + // IngestionTablePrefix is the prefix of the table name used in the generated + // table descriptor. 
+ IngestionTablePrefix = "foo" ) +type interceptFn func(event streamingccl.Event, pa streamingccl.PartitionAddress) + +// InterceptableStreamClient wraps a Client, and provides a method to register +// interceptor methods that are run on every streamed Event. +type InterceptableStreamClient interface { + Client + + RegisterInterception(fn interceptFn) +} + // randomStreamConfig specifies the variables that controls the rate and type of // events that the generated stream emits. type randomStreamConfig struct { valueRange int kvFrequency time.Duration kvsPerCheckpoint int + numPartitions int + dupProbability float64 } func parseRandomStreamConfig(streamURL *url.URL) (randomStreamConfig, error) { @@ -60,6 +86,8 @@ func parseRandomStreamConfig(streamURL *url.URL) (randomStreamConfig, error) { valueRange: 100, kvFrequency: 10 * time.Microsecond, kvsPerCheckpoint: 100, + numPartitions: 1, + dupProbability: 0.5, } var err error @@ -85,81 +113,153 @@ func parseRandomStreamConfig(streamURL *url.URL) (randomStreamConfig, error) { } } + if numPartitionsStr := streamURL.Query().Get(NumPartitions); numPartitionsStr != "" { + c.numPartitions, err = strconv.Atoi(numPartitionsStr) + if err != nil { + return c, err + } + } + + if dupProbStr := streamURL.Query().Get(DupProbability); dupProbStr != "" { + c.dupProbability, err = strconv.ParseFloat(dupProbStr, 32) + if err != nil { + return c, err + } + } return c, nil } // randomStreamClient is a temporary stream client implementation that generates // random events. // -// It expects a table with the schema `RandomStreamSchema` to already exist, -// with table ID `` to be used in the URI. Opening the stream client -// on the URI 'test://' will generate random events into this table. +// The client can be configured to return more than one partition via the stream +// URL. Each partition covers a single table span. // // TODO: Move this over to a _test file in the ingestion package when there is a // real stream client implementation. type randomStreamClient struct { - baseDesc *tabledesc.Mutable - config randomStreamConfig + config randomStreamConfig // interceptors can be registered to peek at every event generated by this // client. mu struct { syncutil.Mutex - interceptors []func(streamingccl.Event) + interceptors []func(streamingccl.Event, streamingccl.PartitionAddress) } } var _ Client = &randomStreamClient{} +var _ InterceptableStreamClient = &randomStreamClient{} // newRandomStreamClient returns a stream client that generates a random set of // events on a table with an integer key and integer value for the table with // the given ID. func newRandomStreamClient(streamURL *url.URL) (Client, error) { - tableID, err := strconv.Atoi(streamURL.Host) - if err != nil { - return nil, err - } - testTable, err := sql.CreateTestTableDescriptor( - context.Background(), - 50, /* defaultdb */ - descpb.ID(tableID), - RandomStreamSchema, - systemschema.JobsTable.GetPrivileges(), - ) - if err != nil { - return nil, err - } - streamConfig, err := parseRandomStreamConfig(streamURL) if err != nil { return nil, err } return &randomStreamClient{ - baseDesc: testTable, - config: streamConfig, + config: streamConfig, }, nil } +var testTableID = 52 + +func getNextTableID() int { + ret := testTableID + testTableID++ + return ret +} + // GetTopology implements the Client interface. 
func (m *randomStreamClient) GetTopology( _ streamingccl.StreamAddress, ) (streamingccl.Topology, error) { - panic("not yet implemented") + topology := streamingccl.Topology{Partitions: make([]streamingccl.PartitionAddress, + 0, m.config.numPartitions)} + + // Allocate table IDs and return one per partition address in the topology. + for i := 0; i < m.config.numPartitions; i++ { + tableID := descpb.ID(getNextTableID()) + partitionURI := url.URL{ + Scheme: TestScheme, + Host: strconv.Itoa(int(tableID)), + } + topology.Partitions = append(topology.Partitions, + streamingccl.PartitionAddress(partitionURI.String())) + } + + return topology, nil +} + +// getDescriptorAndNamespaceKVForTableID returns the namespace and descriptor +// KVs for the table with tableID. +func (m *randomStreamClient) getDescriptorAndNamespaceKVForTableID( + tableID descpb.ID, +) (*tabledesc.Mutable, []roachpb.KeyValue, error) { + tableName := fmt.Sprintf("%s%d", IngestionTablePrefix, tableID) + testTable, err := sql.CreateTestTableDescriptor( + context.Background(), + IngestionDatabaseID, + tableID, + fmt.Sprintf(RandomStreamSchemaPlaceholder, tableName), + systemschema.JobsTable.GetPrivileges(), + ) + if err != nil { + return nil, nil, err + } + + // Generate namespace entry. + key := catalogkeys.NewTableKey(50, keys.PublicSchemaID, testTable.Name) + var value roachpb.Value + value.SetInt(int64(testTable.GetID())) + namespaceKV := roachpb.KeyValue{ + Key: key.Key(keys.TODOSQLCodec), + Value: value, + } + + // Generate descriptor entry. + descKey := catalogkeys.MakeDescMetadataKey(keys.TODOSQLCodec, testTable.GetID()) + descDesc := testTable.DescriptorProto() + var descValue roachpb.Value + if err := descValue.SetProto(descDesc); err != nil { + panic(err) + } + descKV := roachpb.KeyValue{ + Key: descKey, + Value: descValue, + } + + return testTable, []roachpb.KeyValue{namespaceKV, descKV}, nil } // ConsumePartition implements the Client interface. func (m *randomStreamClient) ConsumePartition( - ctx context.Context, _ streamingccl.PartitionAddress, startTime time.Time, + ctx context.Context, partitionAddress streamingccl.PartitionAddress, startTime time.Time, ) (chan streamingccl.Event, error) { eventCh := make(chan streamingccl.Event) now := timeutil.Now() if startTime.After(now) { panic("cannot start random stream client event stream in the future") } - lastResolvedTime := startTime + partitionURL, err := partitionAddress.URL() + if err != nil { + return nil, err + } + var partitionTableID int + partitionTableID, err = strconv.Atoi(partitionURL.Host) + if err != nil { + return nil, err + } + + tableDesc, systemKVs, err := m.getDescriptorAndNamespaceKVForTableID(descpb.ID(partitionTableID)) + if err != nil { + return nil, err + } go func() { defer close(eventCh) @@ -176,24 +276,40 @@ func (m *randomStreamClient) ConsumePartition( resolvedTimer.Reset(0) defer resolvedTimer.Stop() + rng, _ := randutil.NewPseudoRand() + var dupKVEvent streamingccl.Event + for { var event streamingccl.Event select { case <-kvTimer.C: kvTimer.Read = true - event = streamingccl.MakeKVEvent(m.makeRandomKey(r, lastResolvedTime)) + // If there are system KVs to emit, prioritize those. + if len(systemKVs) > 0 { + systemKV := systemKVs[0] + systemKV.Value.Timestamp = hlc.Timestamp{WallTime: timeutil.Now().UnixNano()} + event = streamingccl.MakeKVEvent(systemKV) + systemKVs = systemKVs[1:] + } else { + // Generate a duplicate KVEvent, and update its timestamp to now(). 
+ if rng.Float64() < m.config.dupProbability && dupKVEvent != nil { + dupKV := dupKVEvent.GetKV() + event = streamingccl.MakeKVEvent(*dupKV) + } else { + event = streamingccl.MakeKVEvent(m.makeRandomKey(r, tableDesc)) + dupKVEvent = event + } + } kvTimer.Reset(kvInterval) case <-resolvedTimer.C: resolvedTimer.Read = true resolvedTime := timeutil.Now() hlcResolvedTime := hlc.Timestamp{WallTime: resolvedTime.UnixNano()} event = streamingccl.MakeCheckpointEvent(hlcResolvedTime) - lastResolvedTime = resolvedTime resolvedTimer.Reset(resolvedInterval) + dupKVEvent = nil } - // TODO: Consider keeping an in-memory copy so that tests can verify - // that the data we've ingested is correct. select { case eventCh <- event: case <-ctx.Done(): @@ -204,7 +320,7 @@ func (m *randomStreamClient) ConsumePartition( m.mu.Lock() for _, interceptor := range m.mu.interceptors { if interceptor != nil { - interceptor(event) + interceptor(event, partitionAddress) } } m.mu.Unlock() @@ -215,9 +331,9 @@ func (m *randomStreamClient) ConsumePartition( return eventCh, nil } -func (m *randomStreamClient) makeRandomKey(r *rand.Rand, minTs time.Time) roachpb.KeyValue { - tableDesc := m.baseDesc - +func (m *randomStreamClient) makeRandomKey( + r *rand.Rand, tableDesc *tabledesc.Mutable, +) roachpb.KeyValue { // Create a key holding a random integer. k, err := rowenc.TestingMakePrimaryIndexKey(tableDesc, r.Intn(m.config.valueRange)) if err != nil { @@ -237,10 +353,7 @@ func (m *randomStreamClient) makeRandomKey(r *rand.Rand, minTs time.Time) roachp v.ClearChecksum() v.InitChecksum(k) - // Generate a timestamp between minTs and now(). - randOffset := int(timeutil.Now().UnixNano()) - int(minTs.UnixNano()) - newTimestamp := rand.Intn(randOffset) + int(minTs.UnixNano()) - v.Timestamp = hlc.Timestamp{WallTime: int64(newTimestamp)} + v.Timestamp = hlc.Timestamp{WallTime: timeutil.Now().UnixNano()} return roachpb.KeyValue{ Key: k, @@ -249,8 +362,8 @@ func (m *randomStreamClient) makeRandomKey(r *rand.Rand, minTs time.Time) roachp } // RegisterInterception implements streamingest.interceptableStreamClient. 
-func (m *randomStreamClient) RegisterInterception(f func(event streamingccl.Event)) { +func (m *randomStreamClient) RegisterInterception(fn interceptFn) { m.mu.Lock() defer m.mu.Unlock() - m.mu.interceptors = append(m.mu.interceptors, f) + m.mu.interceptors = append(m.mu.interceptors, fn) } diff --git a/pkg/ccl/streamingccl/streamingest/BUILD.bazel b/pkg/ccl/streamingccl/streamingest/BUILD.bazel index 10e641f42fd3..ff47cc0255e9 100644 --- a/pkg/ccl/streamingccl/streamingest/BUILD.bazel +++ b/pkg/ccl/streamingccl/streamingest/BUILD.bazel @@ -56,6 +56,7 @@ go_test( embed = [":streamingest"], deps = [ "//pkg/base", + "//pkg/ccl/changefeedccl/cdctest", "//pkg/ccl/storageccl", "//pkg/ccl/streamingccl", "//pkg/ccl/streamingccl/streamclient", @@ -69,9 +70,11 @@ go_test( "//pkg/security/securitytest", "//pkg/server", "//pkg/settings/cluster", + "//pkg/sql", "//pkg/sql/execinfra", "//pkg/sql/execinfrapb", "//pkg/sql/sem/tree", + "//pkg/storage", "//pkg/testutils", "//pkg/testutils/distsqlutils", "//pkg/testutils/serverutils", diff --git a/pkg/ccl/streamingccl/streamingest/stream_ingestion_frontier_processor_test.go b/pkg/ccl/streamingccl/streamingest/stream_ingestion_frontier_processor_test.go index 7eda9458777d..0d83b7fbf6cb 100644 --- a/pkg/ccl/streamingccl/streamingest/stream_ingestion_frontier_processor_test.go +++ b/pkg/ccl/streamingccl/streamingest/stream_ingestion_frontier_processor_test.go @@ -63,8 +63,8 @@ func TestStreamIngestionFrontierProcessor(t *testing.T) { post := execinfrapb.PostProcessSpec{} var spec execinfrapb.StreamIngestionDataSpec - pa1 := streamingccl.PartitionAddress("s3://my_streams/stream/partition1") - pa2 := streamingccl.PartitionAddress("s3://my_streams/stream/partition2") + pa1 := streamingccl.PartitionAddress("partition1") + pa2 := streamingccl.PartitionAddress("partition2") v := roachpb.MakeValueFromString("value_1") v.Timestamp = hlc.Timestamp{WallTime: 1} diff --git a/pkg/ccl/streamingccl/streamingest/stream_ingestion_processor.go b/pkg/ccl/streamingccl/streamingest/stream_ingestion_processor.go index 05b5786100c9..e82fb62a21c1 100644 --- a/pkg/ccl/streamingccl/streamingest/stream_ingestion_processor.go +++ b/pkg/ccl/streamingccl/streamingest/stream_ingestion_processor.go @@ -19,6 +19,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/jobs/jobspb" "github.com/cockroachdb/cockroach/pkg/kv/bulk" "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/sql" "github.com/cockroachdb/cockroach/pkg/sql/execinfra" "github.com/cockroachdb/cockroach/pkg/sql/execinfrapb" "github.com/cockroachdb/cockroach/pkg/sql/rowenc" @@ -72,7 +73,6 @@ type streamIngestionProcessor struct { // partitionEvent augments a normal event with the partition it came from. type partitionEvent struct { streamingccl.Event - partition streamingccl.PartitionAddress } @@ -93,6 +93,20 @@ func newStreamIngestionDataProcessor( return nil, err } + // Check if there are any interceptor methods that need to be registered with + // the stream client. + // These methods are invoked on every emitted Event. + if knobs, ok := flowCtx.Cfg.TestingKnobs.StreamIngestionTestingKnobs.(*sql. 
+ StreamIngestionTestingKnobs); ok { + if knobs.Interceptors != nil { + if interceptable, ok := streamClient.(streamclient.InterceptableStreamClient); ok { + for _, interceptor := range knobs.Interceptors { + interceptable.RegisterInterception(interceptor) + } + } + } + } + sip := &streamIngestionProcessor{ flowCtx: flowCtx, spec: spec, diff --git a/pkg/ccl/streamingccl/streamingest/stream_ingestion_processor_test.go b/pkg/ccl/streamingccl/streamingest/stream_ingestion_processor_test.go index 314d0aaa942d..67b1c1451b21 100644 --- a/pkg/ccl/streamingccl/streamingest/stream_ingestion_processor_test.go +++ b/pkg/ccl/streamingccl/streamingest/stream_ingestion_processor_test.go @@ -16,15 +16,19 @@ import ( "time" "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/ccl/changefeedccl/cdctest" "github.com/cockroachdb/cockroach/pkg/ccl/streamingccl" "github.com/cockroachdb/cockroach/pkg/ccl/streamingccl/streamclient" "github.com/cockroachdb/cockroach/pkg/jobs/jobspb" + "github.com/cockroachdb/cockroach/pkg/keys" "github.com/cockroachdb/cockroach/pkg/kv" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/settings/cluster" + "github.com/cockroachdb/cockroach/pkg/sql" "github.com/cockroachdb/cockroach/pkg/sql/execinfra" "github.com/cockroachdb/cockroach/pkg/sql/execinfrapb" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/storage" "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/testutils/distsqlutils" "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" @@ -38,12 +42,6 @@ import ( "github.com/stretchr/testify/require" ) -type interceptableStreamClient interface { - streamclient.Client - - RegisterInterception(func(event streamingccl.Event)) -} - // mockStreamClient will return the slice of events associated to the stream // partition being consumed. Stream partitions are identified by unique // partition addresses. @@ -93,7 +91,6 @@ func TestStreamIngestionProcessor(t *testing.T) { defer tc.Stopper().Stop(ctx) kvDB := tc.Server(0).DB() - // Inject a mock client. v := roachpb.MakeValueFromString("value_1") v.Timestamp = hlc.Timestamp{WallTime: 1} sampleKV := roachpb.KeyValue{Key: roachpb.Key("key_1"), Value: v} @@ -112,8 +109,9 @@ func TestStreamIngestionProcessor(t *testing.T) { } startTime := hlc.Timestamp{WallTime: timeutil.Now().UnixNano()} - out, err := runStreamIngestionProcessor(ctx, t, kvDB, "some://stream", startTime, - nil /* interceptors */, mockClient) + partitionAddresses := []streamingccl.PartitionAddress{"partition1", "partition2"} + out, err := runStreamIngestionProcessor(ctx, t, kvDB, "some://stream", partitionAddresses, + startTime, nil /* interceptEvents */, mockClient) require.NoError(t, err) // Compare the set of results since the ordering is not guaranteed. @@ -142,6 +140,90 @@ func TestStreamIngestionProcessor(t *testing.T) { require.Equal(t, expectedRows, actualRows) } +func getPartitionSpanToTableID( + t *testing.T, partitionAddresses []streamingccl.PartitionAddress, +) map[string]int { + pSpanToTableID := make(map[string]int) + + // Aggregate the table IDs which should have been ingested. 
+ for _, pa := range partitionAddresses { + pKey := roachpb.Key(pa) + pSpan := roachpb.Span{Key: pKey, EndKey: pKey.Next()} + paURL, err := pa.URL() + require.NoError(t, err) + id, err := strconv.Atoi(paURL.Host) + require.NoError(t, err) + pSpanToTableID[pSpan.String()] = id + } + return pSpanToTableID +} + +// assertEqualKVs iterates over the store in `tc` and compares the MVCC KVs +// against the in-memory copy of events stored in the `streamValidator`. This +// ensures that the stream ingestion processor ingested at least as much data as +// was streamed up until partitionTimestamp. +func assertEqualKVs( + t *testing.T, + tc *testcluster.TestCluster, + streamValidator cdctest.StreamClientValidatorWrapper, + tableID int, + partitionTimestamp hlc.Timestamp, +) { + key := keys.TODOSQLCodec.TablePrefix(uint32(tableID)) + + // Iterate over the store. + store := tc.GetFirstStoreFromServer(t, 0) + it := store.Engine().NewMVCCIterator(storage.MVCCKeyAndIntentsIterKind, storage.IterOptions{ + LowerBound: key, + UpperBound: key.PrefixEnd(), + }) + defer it.Close() + var prevKey roachpb.Key + var valueTimestampTuples []roachpb.KeyValue + var err error + for it.SeekGE(storage.MVCCKey{}); ; it.Next() { + if ok, err := it.Valid(); !ok { + if err != nil { + t.Fatal(err) + } + break + } + + // We only want to process MVCC KVs with a ts less than or equal to the max + // resolved ts for this partition. + if partitionTimestamp.Less(it.Key().Timestamp) { + continue + } + + newKey := (prevKey != nil && !it.Key().Key.Equal(prevKey)) || prevKey == nil + prevKey = it.Key().Key + + if newKey { + // All value ts should have been drained at this point, otherwise there is + // a mismatch between the streamed and ingested data. + require.Equal(t, 0, len(valueTimestampTuples)) + valueTimestampTuples, err = streamValidator.GetValuesForKeyBelowTimestamp( + string(it.Key().Key), partitionTimestamp) + require.NoError(t, err) + } + + require.Greater(t, len(valueTimestampTuples), 0) + // Since the iterator goes from latest to older versions, we compare + // starting from the end of the slice that is sorted by timestamp. + latestVersionInChain := valueTimestampTuples[len(valueTimestampTuples)-1] + require.Equal(t, roachpb.KeyValue{ + Key: it.Key().Key, + Value: roachpb.Value{ + RawBytes: it.Value(), + Timestamp: it.Key().Timestamp, + }, + }, latestVersionInChain) + // Truncate the latest version which we just checked against in preparation + // for the next iteration. + valueTimestampTuples = valueTimestampTuples[0 : len(valueTimestampTuples)-1] + } +} + // TestRandomClientGeneration tests the ingestion processor against a random // stream workload. 
func TestRandomClientGeneration(t *testing.T) { @@ -151,13 +233,14 @@ func TestRandomClientGeneration(t *testing.T) { ctx := context.Background() makeTestStreamURI := func( - tableID string, - valueRange, kvsPerResolved int, - kvFrequency time.Duration, + valueRange, kvsPerResolved, numPartitions int, + kvFrequency time.Duration, dupProbability float64, ) string { - return "test://" + tableID + "?VALUE_RANGE=" + strconv.Itoa(valueRange) + + return "test:///" + "?VALUE_RANGE=" + strconv.Itoa(valueRange) + "&KV_FREQUENCY=" + strconv.Itoa(int(kvFrequency)) + - "&KVS_PER_RESOLVED=" + strconv.Itoa(kvsPerResolved) + "&KVS_PER_CHECKPOINT=" + strconv.Itoa(kvsPerResolved) + + "&NUM_PARTITIONS=" + strconv.Itoa(numPartitions) + + "&DUP_PROBABILITY=" + strconv.FormatFloat(dupProbability, 'f', -1, 32) } tc := testcluster.StartTestCluster(t, 3 /* nodes */, base.TestClusterArgs{}) @@ -166,30 +249,38 @@ func TestRandomClientGeneration(t *testing.T) { conn := tc.Conns[0] sqlDB := sqlutils.MakeSQLRunner(conn) - // Create the expected table for the random stream to ingest into. - sqlDB.Exec(t, streamclient.RandomStreamSchema) - tableID := sqlDB.QueryStr(t, `SELECT id FROM system.namespace WHERE name = 'test'`)[0][0] - // TODO: Consider testing variations on these parameters. valueRange := 100 kvsPerResolved := 1_000 kvFrequency := 50 * time.Nanosecond - streamAddr := makeTestStreamURI(tableID, valueRange, kvsPerResolved, kvFrequency) + numPartitions := 4 + dupProbability := 0.2 + streamAddr := makeTestStreamURI(valueRange, kvsPerResolved, numPartitions, kvFrequency, + dupProbability) + + // The random client returns system and table data partitions. + streamClient, err := streamclient.NewStreamClient(streamingccl.StreamAddress(streamAddr)) + require.NoError(t, err) + topo, err := streamClient.GetTopology(streamingccl.StreamAddress(streamAddr)) + require.NoError(t, err) + // One system and two table data partitions. + require.Equal(t, numPartitions, len(topo.Partitions)) startTime := hlc.Timestamp{WallTime: timeutil.Now().UnixNano()} ctx, cancel := context.WithCancel(ctx) // Cancel the flow after emitting 1000 checkpoint events from the client. 
- cancelAfterCheckpoints := makeCheckpointEventCounter(1_000, cancel) - out, err := runStreamIngestionProcessor(ctx, t, kvDB, streamAddr, startTime, - cancelAfterCheckpoints, nil /* mockClient */) + cancelAfterCheckpoints := makeCheckpointEventCounter(1000, cancel) + streamValidator := cdctest.NewStreamClientValidatorWrapper() + validator := registerValidator(streamValidator.GetValidator()) + out, err := runStreamIngestionProcessor(ctx, t, kvDB, streamAddr, topo.Partitions, + startTime, []func(streamingccl.Event, streamingccl.PartitionAddress){cancelAfterCheckpoints, + validator}, nil /* mockClient */) require.NoError(t, err) - p1Key := roachpb.Key("partition1") - p2Key := roachpb.Key("partition2") - p1Span := roachpb.Span{Key: p1Key, EndKey: p1Key.Next()} - p2Span := roachpb.Span{Key: p2Key, EndKey: p2Key.Next()} + partitionSpanToTableID := getPartitionSpanToTableID(t, topo.Partitions) numResolvedEvents := 0 + maxResolvedTimestampPerPartition := make(map[string]hlc.Timestamp) for { row, meta := out.Next() if meta != nil { @@ -209,20 +300,38 @@ func TestRandomClientGeneration(t *testing.T) { var resolvedSpan jobspb.ResolvedSpan require.NoError(t, protoutil.Unmarshal([]byte(*protoBytes), &resolvedSpan)) - if resolvedSpan.Span.String() != p1Span.String() && resolvedSpan.Span.String() != p2Span.String() { - t.Fatalf("expected resolved span %v to be either %v or %v", resolvedSpan.Span, p1Span, p2Span) + if _, ok := partitionSpanToTableID[resolvedSpan.Span.String()]; !ok { + t.Fatalf("expected resolved span %v to be either in one of the supplied partition"+ + " addresses %v", resolvedSpan.Span, topo.Partitions) } // All resolved timestamp events should be greater than the start time. require.Less(t, startTime.WallTime, resolvedSpan.Timestamp.WallTime) + + // Track the max resolved timestamp per partition. + if ts, ok := maxResolvedTimestampPerPartition[resolvedSpan.Span.String()]; !ok || + ts.Less(resolvedSpan.Timestamp) { + maxResolvedTimestampPerPartition[resolvedSpan.Span.String()] = resolvedSpan.Timestamp + } numResolvedEvents++ } - // Check that some rows have been ingested and that we've emitted some resolved events. - numRows, err := strconv.Atoi(sqlDB.QueryStr(t, `SELECT count(*) FROM defaultdb.test`)[0][0]) - require.NoError(t, err) - require.Greater(t, numRows, 0, "at least 1 row ingested expected") + // Ensure that no errors were reported to the validator. + for _, failure := range streamValidator.GetValidator().Failures() { + t.Error(failure) + } + + for pSpan, id := range partitionSpanToTableID { + numRows, err := strconv.Atoi(sqlDB.QueryStr(t, fmt.Sprintf( + `SELECT count(*) FROM defaultdb.%s%d`, streamclient.IngestionTablePrefix, id))[0][0]) + require.NoError(t, err) + require.Greater(t, numRows, 0, "at least 1 row ingested expected") + // Scan the store for KVs ingested by this partition, and compare the MVCC + // KVs against the KVEvents streamed up to the max ingested timestamp for + // the partition. 
+ assertEqualKVs(t, tc, streamValidator, id, maxResolvedTimestampPerPartition[pSpan]) + } require.Greater(t, numResolvedEvents, 0, "at least 1 resolved event expected") } @@ -231,8 +340,9 @@ func runStreamIngestionProcessor( t *testing.T, kvDB *kv.DB, streamAddr string, + partitionAddresses []streamingccl.PartitionAddress, startTime hlc.Timestamp, - interceptEvents func(streamingccl.Event), + interceptEvents []func(streamingccl.Event, streamingccl.PartitionAddress), mockClient streamclient.Client, ) (*distsqlutils.RowBuffer, error) { st := cluster.MakeTestingClusterSettings() @@ -249,6 +359,8 @@ func runStreamIngestionProcessor( }, EvalCtx: &evalCtx, } + flowCtx.Cfg.TestingKnobs.StreamIngestionTestingKnobs = &sql.StreamIngestionTestingKnobs{ + Interceptors: interceptEvents} out := &distsqlutils.RowBuffer{} post := execinfrapb.PostProcessSpec{} @@ -256,7 +368,7 @@ func runStreamIngestionProcessor( var spec execinfrapb.StreamIngestionDataSpec spec.StreamAddress = streamingccl.StreamAddress(streamAddr) - spec.PartitionAddresses = []streamingccl.PartitionAddress{"partition1", "partition2"} + spec.PartitionAddresses = partitionAddresses spec.StartTime = startTime processorID := int32(0) proc, err := newStreamIngestionDataProcessor(&flowCtx, processorID, spec, &post, out) @@ -270,15 +382,6 @@ func runStreamIngestionProcessor( sip.client = mockClient } - if interceptableClient, ok := sip.client.(interceptableStreamClient); ok { - interceptableClient.RegisterInterception(interceptEvents) - // TODO: Inject an interceptor here that keeps track of generated events so - // we can compare. - } else if interceptEvents != nil { - t.Fatalf("interceptor specified, but client %T does not implement interceptableStreamClient", - sip.client) - } - sip.Run(ctx) // Ensure that all the outputs are properly closed. @@ -288,11 +391,36 @@ func runStreamIngestionProcessor( return out, err } +func registerValidator( + validator cdctest.Validator, +) func(event streamingccl.Event, pa streamingccl.PartitionAddress) { + return func(event streamingccl.Event, pa streamingccl.PartitionAddress) { + switch event.Type() { + case streamingccl.CheckpointEvent: + resolvedTS := *event.GetResolved() + err := validator.NoteResolved(string(pa), resolvedTS) + if err != nil { + panic(err.Error()) + } + case streamingccl.KVEvent: + kv := *event.GetKV() + + err := validator.NoteRow(string(pa), string(kv.Key), string(kv.Value.RawBytes), + kv.Value.Timestamp) + if err != nil { + panic(err.Error()) + } + } + } +} + // makeCheckpointEventCounter runs f after seeing `threshold` number of // checkpoint events. -func makeCheckpointEventCounter(threshold int, f func()) func(streamingccl.Event) { +func makeCheckpointEventCounter( + threshold int, f func(), +) func(streamingccl.Event, streamingccl.PartitionAddress) { numCheckpointEventsGenerated := 0 - return func(event streamingccl.Event) { + return func(event streamingccl.Event, _ streamingccl.PartitionAddress) { switch event.Type() { case streamingccl.CheckpointEvent: numCheckpointEventsGenerated++ diff --git a/pkg/server/status_test.go b/pkg/server/status_test.go index 71bf2d570571..c1fd1e409e2e 100644 --- a/pkg/server/status_test.go +++ b/pkg/server/status_test.go @@ -1160,6 +1160,44 @@ func TestStatusVars(t *testing.T) { } } +// TestStatusVarsTxnMetrics verifies that the metrics from the /_status/vars +// endpoint for txns and the special cockroach_restart savepoint are correct. 
+func TestStatusVarsTxnMetrics(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer db.Close() + defer s.Stopper().Stop(context.Background()) + + if _, err := db.Exec("BEGIN;" + + "SAVEPOINT cockroach_restart;" + + "SELECT 1;" + + "RELEASE SAVEPOINT cockroach_restart;" + + "ROLLBACK;"); err != nil { + t.Fatal(err) + } + + body, err := getText(s, s.AdminURL()+statusPrefix+"vars") + if err != nil { + t.Fatal(err) + } + if !bytes.Contains(body, []byte("sql_txn_begin_count 1")) { + t.Errorf("expected `sql_txn_begin_count 1`, got: %s", body) + } + if !bytes.Contains(body, []byte("sql_restart_savepoint_count 1")) { + t.Errorf("expected `sql_restart_savepoint_count 1`, got: %s", body) + } + if !bytes.Contains(body, []byte("sql_restart_savepoint_release_count 1")) { + t.Errorf("expected `sql_restart_savepoint_release_count 1`, got: %s", body) + } + if !bytes.Contains(body, []byte("sql_txn_commit_count 1")) { + t.Errorf("expected `sql_txn_commit_count 1`, got: %s", body) + } + if !bytes.Contains(body, []byte("sql_txn_rollback_count 0")) { + t.Errorf("expected `sql_txn_rollback_count 0`, got: %s", body) + } +} + func TestSpanStatsResponse(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t) diff --git a/pkg/sql/BUILD.bazel b/pkg/sql/BUILD.bazel index abe914d5d6f3..6393275322eb 100644 --- a/pkg/sql/BUILD.bazel +++ b/pkg/sql/BUILD.bazel @@ -230,6 +230,7 @@ go_library( deps = [ "//pkg/base", "//pkg/build", + "//pkg/ccl/streamingccl", "//pkg/clusterversion", "//pkg/config", "//pkg/config/zonepb", diff --git a/pkg/sql/conn_executor.go b/pkg/sql/conn_executor.go index d338b7f23818..b8ae839749dc 100644 --- a/pkg/sql/conn_executor.go +++ b/pkg/sql/conn_executor.go @@ -2713,7 +2713,13 @@ func (sc *StatementCounters) incrementCount(ex *connExecutor, stmt tree.Statemen case *tree.CommitTransaction: sc.TxnCommitCount.Inc() case *tree.RollbackTransaction: - sc.TxnRollbackCount.Inc() + // The CommitWait state means that the transaction has already committed + // after a specially handled `RELEASE SAVEPOINT cockroach_restart` command. + if ex.getTransactionState() == CommitWaitStateStr { + sc.TxnCommitCount.Inc() + } else { + sc.TxnRollbackCount.Inc() + } case *tree.Savepoint: if ex.isCommitOnReleaseSavepoint(t.Name) { sc.RestartSavepointCount.Inc() diff --git a/pkg/sql/exec_util.go b/pkg/sql/exec_util.go index b0d605499b75..c787e3d9044e 100644 --- a/pkg/sql/exec_util.go +++ b/pkg/sql/exec_util.go @@ -24,6 +24,7 @@ import ( "github.com/cockroachdb/apd/v2" "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/ccl/streamingccl" "github.com/cockroachdb/cockroach/pkg/clusterversion" "github.com/cockroachdb/cockroach/pkg/config" "github.com/cockroachdb/cockroach/pkg/config/zonepb" @@ -982,6 +983,16 @@ var _ base.ModuleTestingKnobs = &BackupRestoreTestingKnobs{} // ModuleTestingKnobs implements the base.ModuleTestingKnobs interface. func (*BackupRestoreTestingKnobs) ModuleTestingKnobs() {} +// StreamIngestionTestingKnobs contains knobs for stream ingestion behavior. +type StreamIngestionTestingKnobs struct { + Interceptors []func(event streamingccl.Event, pa streamingccl.PartitionAddress) +} + +var _ base.ModuleTestingKnobs = &StreamIngestionTestingKnobs{} + +// ModuleTestingKnobs implements the base.ModuleTestingKnobs interface. 
+func (*StreamIngestionTestingKnobs) ModuleTestingKnobs() {} + func shouldDistributeGivenRecAndMode( rec distRecommendation, mode sessiondata.DistSQLExecMode, ) bool { diff --git a/pkg/sql/execinfra/server_config.go b/pkg/sql/execinfra/server_config.go index 3475952bd460..ea760c5885ed 100644 --- a/pkg/sql/execinfra/server_config.go +++ b/pkg/sql/execinfra/server_config.go @@ -232,6 +232,9 @@ type TestingKnobs struct { // BackupRestoreTestingKnobs are backup and restore specific testing knobs. BackupRestoreTestingKnobs base.ModuleTestingKnobs + + // BackupRestoreTestingKnobs are stream ingestion specific testing knobs. + StreamIngestionTestingKnobs base.ModuleTestingKnobs } // MetadataTestLevel represents the types of queries where metadata test diff --git a/pkg/sql/pgwire/conn_test.go b/pkg/sql/pgwire/conn_test.go index 7f7ecb6720cc..acb14f50ea9c 100644 --- a/pkg/sql/pgwire/conn_test.go +++ b/pkg/sql/pgwire/conn_test.go @@ -520,32 +520,41 @@ func client(ctx context.Context, serverAddr net.Addr, wg *sync.WaitGroup) error // waitForClientConn blocks until a client connects and performs the pgwire // handshake. This emulates what pgwire.Server does. func waitForClientConn(ln net.Listener) (*conn, error) { - conn, err := ln.Accept() + conn, _, err := getSessionArgs(ln, false /* trustRemoteAddr */) if err != nil { return nil, err } + metrics := makeServerMetrics(sql.MemoryMetrics{} /* sqlMemMetrics */, metric.TestSampleInterval) + pgwireConn := newConn(conn, sql.SessionArgs{ConnResultsBufferSize: 16 << 10}, &metrics, nil) + return pgwireConn, nil +} + +// getSessionArgs blocks until a client connects and returns the connection +// together with session arguments or an error. +func getSessionArgs(ln net.Listener, trustRemoteAddr bool) (net.Conn, sql.SessionArgs, error) { + conn, err := ln.Accept() + if err != nil { + return nil, sql.SessionArgs{}, err + } + buf := pgwirebase.MakeReadBuffer() _, err = buf.ReadUntypedMsg(conn) if err != nil { - return nil, err + return nil, sql.SessionArgs{}, err } version, err := buf.GetUint32() if err != nil { - return nil, err + return nil, sql.SessionArgs{}, err } if version != version30 { - return nil, errors.Errorf("unexpected protocol version: %d", version) - } - - // Consume the connection options. - if _, err := parseClientProvidedSessionParameters(context.Background(), nil, &buf, conn.RemoteAddr(), false /* trustRemoteAddr */); err != nil { - return nil, err + return nil, sql.SessionArgs{}, errors.Errorf("unexpected protocol version: %d", version) } - metrics := makeServerMetrics(sql.MemoryMetrics{} /* sqlMemMetrics */, metric.TestSampleInterval) - pgwireConn := newConn(conn, sql.SessionArgs{ConnResultsBufferSize: 16 << 10}, &metrics, nil) - return pgwireConn, nil + args, err := parseClientProvidedSessionParameters( + context.Background(), nil, &buf, conn.RemoteAddr(), trustRemoteAddr, + ) + return conn, args, err } func makeTestingConvCfg() (sessiondatapb.DataConversionConfig, *time.Location) { @@ -1252,3 +1261,263 @@ func TestConnCloseCancelsAuth(t *testing.T) { // Check that the auth process indeed noticed the cancelation. <-authBlocked } + +func TestParseClientProvidedSessionParameters(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + // The test server is used only incidentally by this test: this is not the + // server that the client will connect to; we just use it on the side to + // execute some metadata queries that pgx sends whenever it opens a + // connection. 
+ s, _, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: true, UseDatabase: "system"}) + defer s.Stopper().Stop(context.Background()) + + // Start a pgwire "server". + addr := util.TestAddr + ln, err := net.Listen(addr.Network(), addr.String()) + if err != nil { + t.Fatal(err) + } + serverAddr := ln.Addr() + log.Infof(context.Background(), "started listener on %s", serverAddr) + testCases := []struct { + desc string + query string + assert func(t *testing.T, args sql.SessionArgs, err error) + }{ + { + desc: "user is set from query", + query: "user=root", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.NoError(t, err) + require.Equal(t, "root", args.User.Normalized()) + }, + }, + { + desc: "user is ignored in options", + query: "user=root&options=-c%20user=test_user_from_options", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.NoError(t, err) + require.Equal(t, "root", args.User.Normalized()) + _, ok := args.SessionDefaults["user"] + require.False(t, ok) + }, + }, + { + desc: "results_buffer_size is not configurable from options", + query: "user=root&options=-c%20results_buffer_size=42", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.Error(t, err) + require.Regexp(t, "options: parameter \"results_buffer_size\" cannot be changed", err) + }, + }, + { + desc: "crdb:remote_addr is ignored in options", + query: "user=root&options=-c%20crdb%3Aremote_addr=2.3.4.5%3A5432", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.NoError(t, err) + require.NotEqual(t, "2.3.4.5:5432", args.RemoteAddr.String()) + }, + }, + { + desc: "more keys than values in options error", + query: "user=root&options=-c%20search_path==public,test,default", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.Error(t, err) + require.Regexp(t, "option \"search_path==public,test,default\" is invalid, check '='", err) + }, + }, + { + desc: "more values than keys in options error", + query: "user=root&options=-c%20search_path", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.Error(t, err) + require.Regexp(t, "option \"search_path\" is invalid, check '='", err) + }, + }, + { + desc: "success parsing encoded options", + query: "user=root&options=-c%20search_path%3ddefault%2Ctest", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.NoError(t, err) + require.Equal(t, "default,test", args.SessionDefaults["search_path"]) + }, + }, + { + desc: "success parsing options with no space after '-c'", + query: "user=root&options=-csearch_path=default,test -coptimizer_use_multicol_stats=true", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.NoError(t, err) + require.Equal(t, "default,test", args.SessionDefaults["search_path"]) + require.Equal(t, "true", args.SessionDefaults["optimizer_use_multicol_stats"]) + }, + }, + { + desc: "error when no leading '-c'", + query: "user=root&options=search_path=default", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.Error(t, err) + require.Regexp(t, "option \"search_path=default\" is invalid, must have prefix '-c' or '--'", err) + }, + }, + { + desc: "'-c' with no leading space belongs to prev value", + query: "user=root&options=-c search_path=default-c", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.NoError(t, err) + require.Equal(t, "default-c", args.SessionDefaults["search_path"]) + }, + }, + { + desc: "fail to parse '-c' with no leading 
space", + query: "user=root&options=-c search_path=default-c optimizer_use_multicol_stats=true", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.Error(t, err) + require.Regexp(t, "option \"optimizer_use_multicol_stats=true\" is invalid, must have prefix '-c' or '--'", err) + }, + }, + { + desc: "parse multiple options successfully", + query: "user=root&options=-c%20search_path=default,test%20-c%20optimizer_use_multicol_stats=true", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.NoError(t, err) + require.Equal(t, "default,test", args.SessionDefaults["search_path"]) + require.Equal(t, "true", args.SessionDefaults["optimizer_use_multicol_stats"]) + }, + }, + { + desc: "success parsing option with space in value", + query: "user=root&options=-c default_transaction_isolation=READ\\ UNCOMMITTED", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.NoError(t, err) + require.Equal(t, "READ UNCOMMITTED", args.SessionDefaults["default_transaction_isolation"]) + }, + }, + { + desc: "remote_addr missing port", + query: "user=root&crdb:remote_addr=5.4.3.2", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.Error(t, err) + require.Regexp(t, "invalid address format: address 5.4.3.2: missing port in address", err) + }, + }, + { + desc: "remote_addr port must be numeric", + query: "user=root&crdb:remote_addr=5.4.3.2:port", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.Error(t, err) + require.Regexp(t, "remote port is not numeric", err) + }, + }, + { + desc: "remote_addr host must be numeric", + query: "user=root&crdb:remote_addr=ip:5432", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.Error(t, err) + require.Regexp(t, "remote address is not numeric", err) + }, + }, + { + desc: "success setting remote address from query", + query: "user=root&crdb:remote_addr=2.3.4.5:5432", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.NoError(t, err) + require.Equal(t, "2.3.4.5:5432", args.RemoteAddr.String()) + }, + }, + } + + baseURL := fmt.Sprintf("postgres://%s/system?sslmode=disable", serverAddr) + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + + go func() { + url := fmt.Sprintf("%s&%s", baseURL, tc.query) + c, err := gosql.Open("postgres", url) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + // ignore the error because there is no answer from the server, we are + // interested in parsing session arguments only + _ = c.PingContext(ctx) + // closing connection immediately, since getSessionArgs is blocking + _ = c.Close() + }() + + // Wait for the client to connect and perform the handshake. 
+ _, args, err := getSessionArgs(ln, true /* trustRemoteAddr */) + tc.assert(t, args, err) + }) + } +} + +func TestSetSessionArguments(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + ctx := context.Background() + defer s.Stopper().Stop(ctx) + + pgURL, cleanupFunc := sqlutils.PGUrl( + t, s.ServingSQLAddr(), "testConnClose" /* prefix */, url.User(security.RootUser), + ) + defer cleanupFunc() + + q := pgURL.Query() + q.Add("options", " --user=test -c search_path=public,testsp %20 --default-transaction-isolation=read\\ uncommitted -capplication_name=test --datestyle=iso\\ ,\\ mdy\\ ") + pgURL.RawQuery = q.Encode() + noBufferDB, err := gosql.Open("postgres", pgURL.String()) + + if err != nil { + t.Fatal(err) + } + defer noBufferDB.Close() + + pgxConfig, err := pgx.ParseConnectionString(pgURL.String()) + if err != nil { + t.Fatal(err) + } + + conn, err := pgx.Connect(pgxConfig) + if err != nil { + t.Fatal(err) + } + + rows, err := conn.Query("show all") + if err != nil { + t.Fatal(err) + } + + expectedOptions := map[string]string{ + "search_path": "public,testsp", + // setting an isolation level is a noop: + // all transactions execute with serializable isolation. + "default_transaction_isolation": "serializable", + "application_name": "test", + "datestyle": "ISO, MDY", + } + expectedFoundOptions := len(expectedOptions) + + var foundOptions int + var variable, value string + for rows.Next() { + err = rows.Scan(&variable, &value) + if err != nil { + t.Fatal(err) + } + if v, ok := expectedOptions[variable]; ok { + foundOptions++ + if v != value { + t.Fatalf("option %q expected value %q, actual %q", variable, v, value) + } + } + } + require.Equal(t, expectedFoundOptions, foundOptions) + + if err := conn.Close(); err != nil { + t.Fatal(err) + } +} diff --git a/pkg/sql/pgwire/server.go b/pkg/sql/pgwire/server.go index e95d437c0aa8..8b702e50d4a0 100644 --- a/pkg/sql/pgwire/server.go +++ b/pkg/sql/pgwire/server.go @@ -15,6 +15,7 @@ import ( "crypto/tls" "io" "net" + "net/url" "strconv" "strings" "sync/atomic" @@ -755,23 +756,21 @@ func parseClientProvidedSessionParameters( } args.RemoteAddr = &net.TCPAddr{IP: ip, Port: port} - default: - exists, configurable := sql.IsSessionVariableConfigurable(key) - - switch { - case exists && configurable: - args.SessionDefaults[key] = value - - case !exists: - if _, ok := sql.UnsupportedVars[key]; ok { - counter := sqltelemetry.UnimplementedClientStatusParameterCounter(key) - telemetry.Inc(counter) + case "options": + opts, err := parseOptions(value) + if err != nil { + return sql.SessionArgs{}, err + } + for _, opt := range opts { + err = loadParameter(ctx, opt.key, opt.value, &args) + if err != nil { + return sql.SessionArgs{}, pgerror.Wrapf(err, pgerror.GetPGCode(err), "options") } - log.Warningf(ctx, "unknown configuration parameter: %q", key) - - case !configurable: - return sql.SessionArgs{}, pgerror.Newf(pgcode.CantChangeRuntimeParam, - "parameter %q cannot be changed", key) + } + default: + err = loadParameter(ctx, key, value, &args) + if err != nil { + return sql.SessionArgs{}, err } } } @@ -790,6 +789,133 @@ func parseClientProvidedSessionParameters( return args, nil } +func loadParameter(ctx context.Context, key, value string, args *sql.SessionArgs) error { + exists, configurable := sql.IsSessionVariableConfigurable(key) + + switch { + case exists && configurable: + args.SessionDefaults[key] = value + + case !exists: + if _, ok := sql.UnsupportedVars[key]; ok { 
+ counter := sqltelemetry.UnimplementedClientStatusParameterCounter(key) + telemetry.Inc(counter) + } + log.Warningf(ctx, "unknown configuration parameter: %q", key) + + case !configurable: + return pgerror.Newf(pgcode.CantChangeRuntimeParam, + "parameter %q cannot be changed", key) + } + return nil +} + +// option represents an option argument passed in the connection URL. +type option struct { + key string + value string +} + +// parseOptions parses the given string into the options. The options must be +// separated by space and have one of the following patterns: +// '-c key=value', '-ckey=value', '--key=value' +func parseOptions(optionsString string) ([]option, error) { + var res []option + optionsRaw, err := url.QueryUnescape(optionsString) + if err != nil { + return nil, pgerror.Newf(pgcode.ProtocolViolation, "failed to unescape options %q", optionsString) + } + + lastWasDashC := false + opts := splitOptions(optionsRaw) + + for i := 0; i < len(opts); i++ { + prefix := "" + if len(opts[i]) > 1 { + prefix = opts[i][:2] + } + + switch { + case opts[i] == "-c": + lastWasDashC = true + continue + case lastWasDashC: + lastWasDashC = false + // if the last option was '-c' parse current option with no regard to + // the prefix + prefix = "" + case prefix == "--" || prefix == "-c": + lastWasDashC = false + default: + return nil, pgerror.Newf(pgcode.ProtocolViolation, + "option %q is invalid, must have prefix '-c' or '--'", opts[i]) + } + + opt, err := splitOption(opts[i], prefix) + if err != nil { + return nil, err + } + res = append(res, opt) + } + return res, nil +} + +// splitOptions slices the given string into substrings separated by space +// unless the space is escaped using backslashes '\\'. It also skips multiple +// subsequent spaces. +func splitOptions(options string) []string { + var res []string + var sb strings.Builder + i := 0 + for i < len(options) { + sb.Reset() + // skip leading space + for i < len(options) && options[i] == ' ' { + i++ + } + if i == len(options) { + break + } + + lastWasEscape := false + + for i < len(options) { + if options[i] == ' ' && !lastWasEscape { + break + } + if !lastWasEscape && options[i] == '\\' { + lastWasEscape = true + } else { + lastWasEscape = false + sb.WriteByte(options[i]) + } + i++ + } + + res = append(res, sb.String()) + } + + return res +} + +// splitOption splits the given opt argument into substrings separated by '='. +// It returns an error if the given option does not comply with the pattern +// "key=value" and the number of elements in the result is not two. +// splitOption removes the prefix from the key and replaces '-' with '_' so +// "--option-name=value" becomes [option_name, value]. +func splitOption(opt, prefix string) (option, error) { + kv := strings.Split(opt, "=") + + if len(kv) != 2 { + return option{}, pgerror.Newf(pgcode.ProtocolViolation, + "option %q is invalid, check '='", opt) + } + + kv[0] = strings.TrimPrefix(kv[0], prefix) + + return option{key: strings.ReplaceAll(kv[0], "-", "_"), value: kv[1]}, nil +} + // Note: Usage of an env var here makes it possible to unconditionally // enable this feature when cluster settings do not work reliably, // e.g. in multi-tenant setups in v20.2. This override mechanism can