diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel index 335c1540c75a..3b0d5fc60377 100644 --- a/pkg/BUILD.bazel +++ b/pkg/BUILD.bazel @@ -170,6 +170,7 @@ ALL_TESTS = [ "//pkg/sql/catalog/typedesc:typedesc_test", "//pkg/sql/catalog:catalog_test", "//pkg/sql/colcontainer:colcontainer_test", + "//pkg/sql/colconv:colconv_test", "//pkg/sql/colencoding:colencoding_test", "//pkg/sql/colexec/colbuilder:colbuilder_test", "//pkg/sql/colexec/colexecagg:colexecagg_test", diff --git a/pkg/ccl/changefeedccl/changefeeddist/distflow.go b/pkg/ccl/changefeedccl/changefeeddist/distflow.go index 4b13401a2e43..d72e48c8d249 100644 --- a/pkg/ccl/changefeedccl/changefeeddist/distflow.go +++ b/pkg/ccl/changefeedccl/changefeeddist/distflow.go @@ -136,7 +136,7 @@ func StartDistChangefeed( return resultRows.Err() } -// changefeedResultWriter implements the `rowexec.resultWriter` that sends +// changefeedResultWriter implements the `sql.rowResultWriter` that sends // the received rows back over the given channel. type changefeedResultWriter struct { rowsCh chan<- tree.Datums diff --git a/pkg/ccl/importccl/import_processor_test.go b/pkg/ccl/importccl/import_processor_test.go index c46ecca3cbf8..78acfc2f88a8 100644 --- a/pkg/ccl/importccl/import_processor_test.go +++ b/pkg/ccl/importccl/import_processor_test.go @@ -37,7 +37,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/rowenc" "github.com/cockroachdb/cockroach/pkg/sql/rowexec" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" - "github.com/cockroachdb/cockroach/pkg/sql/types" "github.com/cockroachdb/cockroach/pkg/storage/cloud" "github.com/cockroachdb/cockroach/pkg/storage/cloud/nodelocal" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" @@ -190,9 +189,6 @@ func (r *errorReportingRowReceiver) Push( } func (r *errorReportingRowReceiver) ProducerDone() {} -func (r *errorReportingRowReceiver) Types() []*types.T { - return nil -} // A do nothing bulk adder implementation. type doNothingKeyAdder struct { diff --git a/pkg/col/coldata/BUILD.bazel b/pkg/col/coldata/BUILD.bazel index c0f09ab8bc2a..d4d820373a58 100644 --- a/pkg/col/coldata/BUILD.bazel +++ b/pkg/col/coldata/BUILD.bazel @@ -43,6 +43,7 @@ go_test( embed = [":coldata"], deps = [ "//pkg/col/coldatatestutils", + "//pkg/sql/colconv", "//pkg/sql/types", "//pkg/testutils/buildutil", "//pkg/util/leaktest", diff --git a/pkg/col/coldata/batch.go b/pkg/col/coldata/batch.go index d898eb1ca7f2..1e1377fbef24 100644 --- a/pkg/col/coldata/batch.go +++ b/pkg/col/coldata/batch.go @@ -333,16 +333,17 @@ func (m *MemBatch) String() string { if m.Length() == 0 { return "[zero-length batch]" } - var builder strings.Builder - strs := make([]string, len(m.ColVecs())) - for i := 0; i < m.Length(); i++ { - builder.WriteString("\n[") - for colIdx, v := range m.ColVecs() { - strs[colIdx] = fmt.Sprintf("%v", GetValueAt(v, i)) - } - builder.WriteString(strings.Join(strs, ", ")) - builder.WriteString("]") + if VecsToStringWithRowPrefix == nil { + panic("need to inject the implementation from sql/colconv package") } - builder.WriteString("\n") - return builder.String() + return strings.Join(VecsToStringWithRowPrefix(m.ColVecs(), m.Length(), m.Selection(), "" /* prefix */), "\n") } + +// VecsToStringWithRowPrefix returns a pretty representation of the vectors. +// This method will convert all vectors to datums in order to print everything +// in the same manner as the tree.Datum representation does. Each row is printed +// in a separate string. 
+//
+// The implementation lives in the colconv package and is injected during
+// initialization.
+var VecsToStringWithRowPrefix func(vecs []Vec, length int, sel []int, prefix string) []string
diff --git a/pkg/col/coldata/batch_test.go b/pkg/col/coldata/batch_test.go
index c5dddcbe3e7a..cf57451e795f 100644
--- a/pkg/col/coldata/batch_test.go
+++ b/pkg/col/coldata/batch_test.go
@@ -16,6 +16,7 @@ import (
 "unsafe"
 "github.com/cockroachdb/cockroach/pkg/col/coldata"
+ "github.com/cockroachdb/cockroach/pkg/sql/colconv"
 "github.com/cockroachdb/cockroach/pkg/sql/types"
 "github.com/cockroachdb/cockroach/pkg/util/leaktest"
 "github.com/stretchr/testify/assert"
 )
@@ -142,3 +143,43 @@ func TestBatchWithBytesAndNulls(t *testing.T) {
 assert.True(t, len(vec.Get(idx)) == 0)
 }
 }
+
+// Import the colconv package in order to inject the implementation of
+// coldata.VecsToStringWithRowPrefix.
+var _ colconv.VecToDatumConverter
+
+func TestBatchString(t *testing.T) {
+ defer leaktest.AfterTest(t)()
+
+ b := coldata.NewMemBatch([]*types.T{types.String}, coldata.StandardColumnFactory)
+ input := []string{"one", "two", "three"}
+ for i := range input {
+ b.ColVec(0).Bytes().Set(i, []byte(input[i]))
+ }
+ getExpected := func(length int, sel []int) string {
+ var result string
+ for i := 0; i < length; i++ {
+ if i > 0 {
+ result += "\n"
+ }
+ rowIdx := i
+ if sel != nil {
+ rowIdx = sel[i]
+ }
+ result += "['" + input[rowIdx] + "']"
+ }
+ return result
+ }
+ for _, tc := range []struct {
+ length int
+ sel []int
+ }{
+ {length: 3},
+ {length: 2, sel: []int{0, 2}},
+ } {
+ b.SetSelection(tc.sel != nil)
+ copy(b.Selection(), tc.sel)
+ b.SetLength(tc.length)
+ assert.Equal(t, getExpected(tc.length, tc.sel), b.String())
+ }
+}
diff --git a/pkg/sql/colconv/BUILD.bazel b/pkg/sql/colconv/BUILD.bazel
index 44d4aaacf31b..7c0b7af019d3 100644
--- a/pkg/sql/colconv/BUILD.bazel
+++ b/pkg/sql/colconv/BUILD.bazel
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
 # TODO(irfansharif): The dependency tree for *.eg.go needs
 # sorting out. It depends on execgen+templates from elsewhere. Look towards
@@ -9,6 +9,7 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library")
 go_library(
 name = "colconv",
 srcs = [
+ "batch.go", # keep
 "datum_to_vec.eg.go",
 "vec_to_datum.eg.go",
 ],
@@ -32,3 +33,15 @@ go_library(
 "@com_github_lib_pq//oid",
 ],
 )
+
+go_test(
+ name = "colconv_test",
+ srcs = ["batch_test.go"],
+ embed = [":colconv"],
+ deps = [
+ "//pkg/col/coldata",
+ "//pkg/sql/types",
+ "//pkg/util/leaktest",
+ "@com_github_stretchr_testify//require",
+ ],
+)
diff --git a/pkg/sql/colconv/batch.go b/pkg/sql/colconv/batch.go
new file mode 100644
index 000000000000..7d095b1dc246
--- /dev/null
+++ b/pkg/sql/colconv/batch.go
@@ -0,0 +1,47 @@
+// Copyright 2021 The Cockroach Authors.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+package colconv
+
+import (
+ "strings"
+
+ "github.com/cockroachdb/cockroach/pkg/col/coldata"
+)
+
+func init() {
+ coldata.VecsToStringWithRowPrefix = vecsToStringWithRowPrefix
+}
+
+// vecsToStringWithRowPrefix returns a pretty representation of the vectors with
+// each row being in a separate string.
+func vecsToStringWithRowPrefix(vecs []coldata.Vec, length int, sel []int, prefix string) []string { + var builder strings.Builder + converter := NewAllVecToDatumConverter(len(vecs)) + defer converter.Release() + converter.ConvertVecs(vecs, length, sel) + result := make([]string, length) + strs := make([]string, len(vecs)) + for i := 0; i < length; i++ { + builder.Reset() + rowIdx := i + if sel != nil { + rowIdx = sel[i] + } + builder.WriteString(prefix + "[") + for colIdx := 0; colIdx < len(vecs); colIdx++ { + strs[colIdx] = converter.GetDatumColumn(colIdx)[rowIdx].String() + } + builder.WriteString(strings.Join(strs, " ")) + builder.WriteString("]") + result[i] = builder.String() + } + return result +} diff --git a/pkg/sql/colconv/batch_test.go b/pkg/sql/colconv/batch_test.go new file mode 100644 index 000000000000..bd15ef8701cd --- /dev/null +++ b/pkg/sql/colconv/batch_test.go @@ -0,0 +1,53 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package colconv + +import ( + "testing" + + "github.com/cockroachdb/cockroach/pkg/col/coldata" + "github.com/cockroachdb/cockroach/pkg/sql/types" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/stretchr/testify/require" +) + +func TestVecsToStringWithRowPrefix(t *testing.T) { + defer leaktest.AfterTest(t)() + + vec := coldata.NewMemColumn(types.String, coldata.BatchSize(), coldata.StandardColumnFactory) + input := []string{"one", "two", "three"} + for i := range input { + vec.Bytes().Set(i, []byte(input[i])) + } + getExpected := func(length int, sel []int, prefix string) []string { + result := make([]string, length) + for i := 0; i < length; i++ { + rowIdx := i + if sel != nil { + rowIdx = sel[i] + } + result[i] = prefix + "['" + input[rowIdx] + "']" + } + return result + } + for _, tc := range []struct { + length int + sel []int + prefix string + }{ + {length: 3}, + {length: 2, sel: []int{0, 2}}, + {length: 3, prefix: "row: "}, + {length: 2, sel: []int{0, 2}, prefix: "row: "}, + } { + require.Equal(t, getExpected(tc.length, tc.sel, tc.prefix), vecsToStringWithRowPrefix([]coldata.Vec{vec}, tc.length, tc.sel, tc.prefix)) + } +} diff --git a/pkg/sql/colexec/colexecbase/simple_project.go b/pkg/sql/colexec/colexecbase/simple_project.go index a737fed2b5f7..6dbe70726c19 100644 --- a/pkg/sql/colexec/colexecbase/simple_project.go +++ b/pkg/sql/colexec/colexecbase/simple_project.go @@ -12,6 +12,7 @@ package colexecbase import ( "context" + "strings" "github.com/cockroachdb/cockroach/pkg/col/coldata" "github.com/cockroachdb/cockroach/pkg/sql/colexecop" @@ -86,6 +87,10 @@ func (b *projectingBatch) ReplaceCol(col coldata.Vec, idx int) { b.Batch.ReplaceCol(col, int(b.projection[idx])) } +func (b *projectingBatch) String() string { + return strings.Join(coldata.VecsToStringWithRowPrefix(b.ColVecs(), b.Length(), b.Selection(), "" /* prefix */), "\n") +} + // NewSimpleProjectOp returns a new simpleProjectOp that applies a simple // projection on the columns in its input batch, returning a new batch with // only the columns in the projection slice, in order. 
In a degenerate case diff --git a/pkg/sql/colflow/flow_coordinator.go b/pkg/sql/colflow/flow_coordinator.go index 1b2adb06426b..952d088bf484 100644 --- a/pkg/sql/colflow/flow_coordinator.go +++ b/pkg/sql/colflow/flow_coordinator.go @@ -36,10 +36,8 @@ type FlowCoordinator struct { row rowenc.EncDatumRow meta *execinfrapb.ProducerMetadata - // cancelFlow will return a function to cancel the context of the flow. It - // is a function in order to be lazily evaluated, since the context - // cancellation function is only available after the flow is Start()'ed. - cancelFlow func() context.CancelFunc + // cancelFlow cancels the context of the flow. + cancelFlow context.CancelFunc } var flowCoordinatorPool = sync.Pool{ @@ -57,7 +55,7 @@ func NewFlowCoordinator( processorID int32, input execinfra.RowSource, output execinfra.RowReceiver, - cancelFlow func() context.CancelFunc, + cancelFlow context.CancelFunc, ) *FlowCoordinator { f := flowCoordinatorPool.Get().(*FlowCoordinator) f.input = input @@ -155,7 +153,7 @@ func (f *FlowCoordinator) Next() (rowenc.EncDatumRow, *execinfrapb.ProducerMetad func (f *FlowCoordinator) close() { if f.InternalClose() { - f.cancelFlow()() + f.cancelFlow() } } diff --git a/pkg/sql/colflow/vectorized_flow.go b/pkg/sql/colflow/vectorized_flow.go index 03b1cb1ceff1..fdd31af59da6 100644 --- a/pkg/sql/colflow/vectorized_flow.go +++ b/pkg/sql/colflow/vectorized_flow.go @@ -1009,7 +1009,7 @@ func (s *vectorizedFlowCreator) setupOutput( pspec.ProcessorID, input, s.syncFlowConsumer, - s.getCancelFlowFn, + s.getCancelFlowFn(), ) // The flow coordinator is a root of its operator chain. s.opChains = append(s.opChains, f) diff --git a/pkg/sql/colflow/vectorized_flow_shutdown_test.go b/pkg/sql/colflow/vectorized_flow_shutdown_test.go index 755121574b81..6f9931360e9d 100644 --- a/pkg/sql/colflow/vectorized_flow_shutdown_test.go +++ b/pkg/sql/colflow/vectorized_flow_shutdown_test.go @@ -66,9 +66,9 @@ func (c callbackCloser) Close() error { return c.closeCb() } -// TestVectorizedFlowShutdown tests that closing the materializer correctly +// TestVectorizedFlowShutdown tests that closing the FlowCoordinator correctly // closes all the infrastructure corresponding to the flow ending in that -// materializer. Namely: +// FlowCoordinator. Namely: // - on a remote node, it creates a colflow.HashRouter with 3 outputs (with a // corresponding to each colrpc.Outbox) as well as 3 standalone Outboxes; // - on a local node, it creates 6 colrpc.Inboxes that feed into an unordered @@ -366,7 +366,7 @@ func TestVectorizedFlowShutdown(t *testing.T) { 1, /* processorID */ materializer, nil, /* output */ - func() context.CancelFunc { return cancelLocal }, + cancelLocal, ) coordinator.Start(ctxLocal) diff --git a/pkg/sql/conn_io.go b/pkg/sql/conn_io.go index 157ccaba116a..96abaa1f4c1c 100644 --- a/pkg/sql/conn_io.go +++ b/pkg/sql/conn_io.go @@ -737,8 +737,8 @@ type DescribeResult interface { // SetInferredTypes tells the client about the inferred placeholder types. SetInferredTypes([]oid.Oid) - // SetNoDataDescription is used to tell the client that the prepared statement - // or portal produces no rows. + // SetNoDataRowDescription is used to tell the client that the prepared + // statement or portal produces no rows. SetNoDataRowDescription() // SetPrepStmtOutput tells the client about the results schema of a prepared // statement. 
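Reviewer note on the coldata/colconv wiring earlier in this diff: coldata cannot import colconv without creating an import cycle (colconv already depends on coldata), so coldata declares VecsToStringWithRowPrefix as a package-level function variable and colconv fills it in from an init(). The `var _ colconv.VecToDatumConverter` in batch_test.go exists purely to force that init to run. A minimal, runnable Go sketch of the pattern, with every name invented for illustration (none of this is CockroachDB code):

package main

import (
	"fmt"
	"strings"
)

// stringifyRows plays the role of coldata.VecsToStringWithRowPrefix: a hook
// that the low-level package declares but cannot implement itself.
var stringifyRows func(rows []string, prefix string) []string

// batchString mirrors the guard in MemBatch.String above: fail loudly if no
// higher-level package has injected an implementation yet.
func batchString(rows []string) string {
	if stringifyRows == nil {
		panic("need to inject the implementation")
	}
	return strings.Join(stringifyRows(rows, "" /* prefix */), "\n")
}

// In the real code this init lives in the higher-level colconv package and
// runs whenever anything links colconv in.
func init() {
	stringifyRows = func(rows []string, prefix string) []string {
		out := make([]string, len(rows))
		for i, r := range rows {
			out[i] = prefix + "[" + r + "]"
		}
		return out
	}
}

func main() {
	fmt.Println(batchString([]string{"one", "two"}))
}

The trade-off is a runtime nil check instead of a compile-time dependency, which is why MemBatch.String panics with an explicit message if the injection never happened.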
diff --git a/pkg/sql/distsql_running.go b/pkg/sql/distsql_running.go index 1883c93088e4..5ba273ddaac6 100644 --- a/pkg/sql/distsql_running.go +++ b/pkg/sql/distsql_running.go @@ -611,91 +611,133 @@ func (r *DistSQLReceiver) SetError(err error) { } } -// Push is part of the RowReceiver interface. -func (r *DistSQLReceiver) Push( - row rowenc.EncDatumRow, meta *execinfrapb.ProducerMetadata, -) execinfra.ConsumerStatus { - if r.testingKnobs.pushCallback != nil { - r.testingKnobs.pushCallback(row, meta) +// pushMeta takes in non-empty metadata object and pushes it to the result +// writer. Possibly updated status is returned. +func (r *DistSQLReceiver) pushMeta(meta *execinfrapb.ProducerMetadata) execinfra.ConsumerStatus { + if metaWriter, ok := r.resultWriter.(MetadataResultWriter); ok { + metaWriter.AddMeta(r.ctx, meta) } - if meta != nil { - if metaWriter, ok := r.resultWriter.(MetadataResultWriter); ok { - metaWriter.AddMeta(r.ctx, meta) - } - if meta.LeafTxnFinalState != nil { - if r.txn != nil { - if r.txn.ID() == meta.LeafTxnFinalState.Txn.ID { - if err := r.txn.UpdateRootWithLeafFinalState(r.ctx, meta.LeafTxnFinalState); err != nil { - r.SetError(err) - } + if meta.LeafTxnFinalState != nil { + if r.txn != nil { + if r.txn.ID() == meta.LeafTxnFinalState.Txn.ID { + if err := r.txn.UpdateRootWithLeafFinalState(r.ctx, meta.LeafTxnFinalState); err != nil { + r.SetError(err) } - } else { - r.SetError( - errors.Errorf("received a leaf final state (%s); but have no root", meta.LeafTxnFinalState)) } + } else { + r.SetError( + errors.Errorf("received a leaf final state (%s); but have no root", meta.LeafTxnFinalState)) } - if meta.Err != nil { - // Check if the error we just received should take precedence over a - // previous error (if any). - if roachpb.ErrPriority(meta.Err) > roachpb.ErrPriority(r.resultWriter.Err()) { - if r.txn != nil { - if retryErr := (*roachpb.UnhandledRetryableError)(nil); errors.As(meta.Err, &retryErr) { - // Update the txn in response to remote errors. In the non-DistSQL - // world, the TxnCoordSender handles "unhandled" retryable errors, - // but this one is coming from a distributed SQL node, which has - // left the handling up to the root transaction. - meta.Err = r.txn.UpdateStateOnRemoteRetryableErr(r.ctx, &retryErr.PErr) - // Update the clock with information from the error. On non-DistSQL - // code paths, the DistSender does this. - // TODO(andrei): We don't propagate clock signals on success cases - // through DistSQL; we should. We also don't propagate them through - // non-retryable errors; we also should. - if r.clockUpdater != nil { - r.clockUpdater.Update(retryErr.PErr.Now) - } + } + if meta.Err != nil { + // Check if the error we just received should take precedence over a + // previous error (if any). + if roachpb.ErrPriority(meta.Err) > roachpb.ErrPriority(r.resultWriter.Err()) { + if r.txn != nil { + if retryErr := (*roachpb.UnhandledRetryableError)(nil); errors.As(meta.Err, &retryErr) { + // Update the txn in response to remote errors. In the non-DistSQL + // world, the TxnCoordSender handles "unhandled" retryable errors, + // but this one is coming from a distributed SQL node, which has + // left the handling up to the root transaction. + meta.Err = r.txn.UpdateStateOnRemoteRetryableErr(r.ctx, &retryErr.PErr) + // Update the clock with information from the error. On non-DistSQL + // code paths, the DistSender does this. + // TODO(andrei): We don't propagate clock signals on success cases + // through DistSQL; we should. 
We also don't propagate them through + // non-retryable errors; we also should. + if r.clockUpdater != nil { + r.clockUpdater.Update(retryErr.PErr.Now) } } - r.SetError(meta.Err) } + r.SetError(meta.Err) } - if len(meta.Ranges) > 0 { - r.rangeCache.Insert(r.ctx, meta.Ranges...) + } + if len(meta.Ranges) > 0 { + r.rangeCache.Insert(r.ctx, meta.Ranges...) + } + if len(meta.TraceData) > 0 { + if span := tracing.SpanFromContext(r.ctx); span != nil { + span.ImportRemoteSpans(meta.TraceData) } - if len(meta.TraceData) > 0 { - if span := tracing.SpanFromContext(r.ctx); span != nil { - span.ImportRemoteSpans(meta.TraceData) - } - var ev roachpb.ContentionEvent - for i := range meta.TraceData { - meta.TraceData[i].Structured(func(any *pbtypes.Any) { - if !pbtypes.Is(any, &ev) { - return - } - if err := pbtypes.UnmarshalAny(any, &ev); err != nil { - return - } - if r.contendedQueryMetric != nil { - // Increment the contended query metric at most once - // if the query sees at least one contention event. - r.contendedQueryMetric.Inc(1) - r.contendedQueryMetric = nil - } - r.contentionRegistry.AddContentionEvent(ev) - }) - } + var ev roachpb.ContentionEvent + for i := range meta.TraceData { + meta.TraceData[i].Structured(func(any *pbtypes.Any) { + if !pbtypes.Is(any, &ev) { + return + } + if err := pbtypes.UnmarshalAny(any, &ev); err != nil { + return + } + if r.contendedQueryMetric != nil { + // Increment the contended query metric at most once + // if the query sees at least one contention event. + r.contendedQueryMetric.Inc(1) + r.contendedQueryMetric = nil + } + r.contentionRegistry.AddContentionEvent(ev) + }) } - if meta.Metrics != nil { - r.stats.bytesRead += meta.Metrics.BytesRead - r.stats.rowsRead += meta.Metrics.RowsRead - if r.progressAtomic != nil && r.expectedRowsRead != 0 { - progress := float64(r.stats.rowsRead) / float64(r.expectedRowsRead) - atomic.StoreUint64(r.progressAtomic, math.Float64bits(progress)) - } - meta.Metrics.Release() + } + if meta.Metrics != nil { + r.stats.bytesRead += meta.Metrics.BytesRead + r.stats.rowsRead += meta.Metrics.RowsRead + if r.progressAtomic != nil && r.expectedRowsRead != 0 { + progress := float64(r.stats.rowsRead) / float64(r.expectedRowsRead) + atomic.StoreUint64(r.progressAtomic, math.Float64bits(progress)) } - // Release the meta object. It is unsafe for use after this call. - meta.Release() - return r.status + meta.Metrics.Release() + } + // Release the meta object. It is unsafe for use after this call. + meta.Release() + return r.status +} + +// handleCommErr handles the communication error (the one returned when +// attempting to add data to the result writer). +func (r *DistSQLReceiver) handleCommErr(commErr error) { + // ErrLimitedResultClosed and errIEResultChannelClosed are not real + // errors, it is a signal to stop distsql and return success to the + // client (that's why we don't set the error on the resultWriter). + if errors.Is(commErr, ErrLimitedResultClosed) { + log.VEvent(r.ctx, 1, "encountered ErrLimitedResultClosed (transitioning to draining)") + r.status = execinfra.DrainRequested + } else if errors.Is(commErr, errIEResultChannelClosed) { + log.VEvent(r.ctx, 1, "encountered errIEResultChannelClosed (transitioning to draining)") + r.status = execinfra.DrainRequested + } else { + // Set the error on the resultWriter to notify the consumer about + // it. Most clients don't care to differentiate between + // communication errors and query execution errors, so they can + // simply inspect resultWriter.Err(). 
+ r.SetError(commErr) + + // The only client that needs to know that a communication error and + // not a query execution error has occurred is + // connExecutor.execWithDistSQLEngine which will inspect r.commErr + // on its own and will shut down the connection. + // + // We don't need to shut down the connection if there's a + // portal-related error. This is definitely a layering violation, + // but is part of some accepted technical debt (see comments on + // sql/pgwire.limitedCommandResult.moreResultsNeeded). Instead of + // changing the signature of AddRow, we have a sentinel error that + // is handled specially here. + if !errors.Is(commErr, ErrLimitedResultNotSupported) { + r.commErr = commErr + } + } +} + +// Push is part of the execinfra.RowReceiver interface. +func (r *DistSQLReceiver) Push( + row rowenc.EncDatumRow, meta *execinfrapb.ProducerMetadata, +) execinfra.ConsumerStatus { + if r.testingKnobs.pushCallback != nil { + r.testingKnobs.pushCallback(row, meta) + } + if meta != nil { + return r.pushMeta(meta) } if r.resultWriter.Err() == nil && r.ctx.Err() != nil { r.SetError(r.ctx.Err()) @@ -738,37 +780,7 @@ func (r *DistSQLReceiver) Push( } r.tracing.TraceExecRowsResult(r.ctx, r.row) if commErr := r.resultWriter.AddRow(r.ctx, r.row); commErr != nil { - // ErrLimitedResultClosed and errIEResultChannelClosed are not real - // errors, it is a signal to stop distsql and return success to the - // client (that's why we don't set the error on the resultWriter). - if errors.Is(commErr, ErrLimitedResultClosed) { - log.VEvent(r.ctx, 1, "encountered ErrLimitedResultClosed (transitioning to draining)") - r.status = execinfra.DrainRequested - } else if errors.Is(commErr, errIEResultChannelClosed) { - log.VEvent(r.ctx, 1, "encountered errIEResultChannelClosed (transitioning to draining)") - r.status = execinfra.DrainRequested - } else { - // Set the error on the resultWriter to notify the consumer about - // it. Most clients don't care to differentiate between - // communication errors and query execution errors, so they can - // simply inspect resultWriter.Err(). - r.SetError(commErr) - - // The only client that needs to know that a communication error and - // not a query execution error has occurred is - // connExecutor.execWithDistSQLEngine which will inspect r.commErr - // on its own and will shut down the connection. - // - // We don't need to shut down the connection if there's a - // portal-related error. This is definitely a layering violation, - // but is part of some accepted technical debt (see comments on - // sql/pgwire.limitedCommandResult.moreResultsNeeded). Instead of - // changing the signature of AddRow, we have a sentinel error that - // is handled specially here. - if !errors.Is(commErr, ErrLimitedResultNotSupported) { - r.commErr = commErr - } - } + r.handleCommErr(commErr) } return r.status } @@ -782,7 +794,7 @@ var ( ErrLimitedResultClosed = errors.New("row count limit closed") ) -// ProducerDone is part of the RowReceiver interface. +// ProducerDone is part of the execinfra.RowReceiver interface. func (r *DistSQLReceiver) ProducerDone() { if r.closed { panic("double close") @@ -790,11 +802,6 @@ func (r *DistSQLReceiver) ProducerDone() { r.closed = true } -// Types is part of the RowReceiver interface. -func (r *DistSQLReceiver) Types() []*types.T { - return r.outputTypes -} - // PlanAndRunSubqueries returns false if an error was encountered and sets that // error in the provided receiver. 
Note that if false is returned, then this // function will have closed all the subquery plans because it assumes that the diff --git a/pkg/sql/distsql_running_test.go b/pkg/sql/distsql_running_test.go index 98913ef53b8a..092f13820b47 100644 --- a/pkg/sql/distsql_running_test.go +++ b/pkg/sql/distsql_running_test.go @@ -205,8 +205,7 @@ func TestDistSQLReceiverErrorRanking(t *testing.T) { txn := kv.NewTxn(ctx, db, s.NodeID()) - // We're going to use a rowResultWriter to which only errors will be passed. - rw := newCallbackResultWriter(nil /* fn */) + rw := &errOnlyResultWriter{} recv := MakeDistSQLReceiver( ctx, rw, diff --git a/pkg/sql/execinfra/base.go b/pkg/sql/execinfra/base.go index 48b8df07b638..38a88ef7468d 100644 --- a/pkg/sql/execinfra/base.go +++ b/pkg/sql/execinfra/base.go @@ -75,10 +75,6 @@ type RowReceiver interface { // Implementations of Push() must be thread-safe. Push(row rowenc.EncDatumRow, meta *execinfrapb.ProducerMetadata) ConsumerStatus - // Types returns the types of the EncDatumRow that this RowReceiver expects - // to be pushed. - Types() []*types.T - // ProducerDone is called when the producer has pushed all the rows and // metadata; it causes the RowReceiver to process all rows and clean up. // @@ -524,11 +520,6 @@ func (rc *RowChannel) ConsumerClosed() { } } -// Types is part of the RowReceiver interface. -func (rc *RowChannel) Types() []*types.T { - return rc.types -} - // DoesNotUseTxn implements the DoesNotUseTxn interface. Since the RowChannel's // input is run in a different goroutine, the flow will check the RowChannel's // input separately. diff --git a/pkg/sql/flowinfra/outbox.go b/pkg/sql/flowinfra/outbox.go index 5585f079bcdb..b27668b3fd9a 100644 --- a/pkg/sql/flowinfra/outbox.go +++ b/pkg/sql/flowinfra/outbox.go @@ -438,7 +438,7 @@ func (m *Outbox) run(ctx context.Context, wg *sync.WaitGroup) { // Start starts the outbox. func (m *Outbox) Start(ctx context.Context, wg *sync.WaitGroup, flowCtxCancel context.CancelFunc) { - if m.Types() == nil { + if m.OutputTypes() == nil { panic("outbox not initialized") } if wg != nil { diff --git a/pkg/sql/pgwire/BUILD.bazel b/pkg/sql/pgwire/BUILD.bazel index e4edace217b1..45d9f18acbc0 100644 --- a/pkg/sql/pgwire/BUILD.bazel +++ b/pkg/sql/pgwire/BUILD.bazel @@ -42,6 +42,7 @@ go_library( "//pkg/util/errorutil/unimplemented", "//pkg/util/humanizeutil", "//pkg/util/ipaddr", + "//pkg/util/json", "//pkg/util/log", "//pkg/util/log/eventpb", "//pkg/util/metric", @@ -52,6 +53,8 @@ go_library( "//pkg/util/timeofday", "//pkg/util/timetz", "//pkg/util/timeutil", + "//pkg/util/timeutil/pgdate", + "//pkg/util/uuid", "@com_github_cockroachdb_apd_v2//:apd", "@com_github_cockroachdb_errors//:errors", "@com_github_cockroachdb_logtags//:logtags", diff --git a/pkg/sql/pgwire/command_result.go b/pkg/sql/pgwire/command_result.go index e887be3064f6..9faaf47cef9a 100644 --- a/pkg/sql/pgwire/command_result.go +++ b/pkg/sql/pgwire/command_result.go @@ -107,7 +107,7 @@ type paramStatusUpdate struct { var _ sql.CommandResult = &commandResult{} -// Close is part of the CommandResult interface. +// Close is part of the sql.RestrictedCommandResult interface. func (r *commandResult) Close(ctx context.Context, t sql.TransactionStatusIndicator) { r.assertNotReleased() defer r.release() @@ -167,19 +167,19 @@ func (r *commandResult) Close(ctx context.Context, t sql.TransactionStatusIndica } } -// Discard is part of the CommandResult interface. +// Discard is part of the sql.RestrictedCommandResult interface. 
func (r *commandResult) Discard() { r.assertNotReleased() defer r.release() } -// Err is part of the CommandResult interface. +// Err is part of the sql.RestrictedCommandResult interface. func (r *commandResult) Err() error { r.assertNotReleased() return r.err } -// SetError is part of the CommandResult interface. +// SetError is part of the sql.RestrictedCommandResult interface. // // We're not going to write any bytes to the buffer in order to support future // SetError() calls. The error will only be serialized at Close() time. @@ -188,7 +188,7 @@ func (r *commandResult) SetError(err error) { r.err = err } -// AddRow is part of the CommandResult interface. +// AddRow is part of the sql.RestrictedCommandResult interface. func (r *commandResult) AddRow(ctx context.Context, row tree.Datums) error { r.assertNotReleased() if r.err != nil { @@ -214,13 +214,13 @@ func (r *commandResult) AddRow(ctx context.Context, row tree.Datums) error { return err } -// DisableBuffering is part of the CommandResult interface. +// DisableBuffering is part of the sql.RestrictedCommandResult interface. func (r *commandResult) DisableBuffering() { r.assertNotReleased() r.bufferingDisabled = true } -// BufferParamStatusUpdate is part of the CommandResult interface. +// BufferParamStatusUpdate is part of the sql.RestrictedCommandResult interface. func (r *commandResult) BufferParamStatusUpdate(param string, val string) { r.buffer.paramStatusUpdates = append( r.buffer.paramStatusUpdates, @@ -228,12 +228,12 @@ func (r *commandResult) BufferParamStatusUpdate(param string, val string) { ) } -// BufferNotice is part of the CommandResult interface. +// BufferNotice is part of the sql.RestrictedCommandResult interface. func (r *commandResult) BufferNotice(notice pgnotice.Notice) { r.buffer.notices = append(r.buffer.notices, notice) } -// SetColumns is part of the CommandResult interface. +// SetColumns is part of the sql.RestrictedCommandResult interface. func (r *commandResult) SetColumns(ctx context.Context, cols colinfo.ResultColumns) { r.assertNotReleased() r.conn.writerState.fi.registerCmd(r.pos) @@ -246,28 +246,28 @@ func (r *commandResult) SetColumns(ctx context.Context, cols colinfo.ResultColum } } -// SetInferredTypes is part of the DescribeResult interface. +// SetInferredTypes is part of the sql.DescribeResult interface. func (r *commandResult) SetInferredTypes(types []oid.Oid) { r.assertNotReleased() r.conn.writerState.fi.registerCmd(r.pos) r.conn.bufferParamDesc(types) } -// SetNoDataRowDescription is part of the DescribeResult interface. +// SetNoDataRowDescription is part of the sql.DescribeResult interface. func (r *commandResult) SetNoDataRowDescription() { r.assertNotReleased() r.conn.writerState.fi.registerCmd(r.pos) r.conn.bufferNoDataMsg() } -// SetPrepStmtOutput is part of the DescribeResult interface. +// SetPrepStmtOutput is part of the sql.DescribeResult interface. func (r *commandResult) SetPrepStmtOutput(ctx context.Context, cols colinfo.ResultColumns) { r.assertNotReleased() r.conn.writerState.fi.registerCmd(r.pos) _ /* err */ = r.conn.writeRowDescription(ctx, cols, nil /* formatCodes */, &r.conn.writerState.buf) } -// SetPortalOutput is part of the DescribeResult interface. +// SetPortalOutput is part of the sql.DescribeResult interface. 
func (r *commandResult) SetPortalOutput( ctx context.Context, cols colinfo.ResultColumns, formatCodes []pgwirebase.FormatCode, ) { @@ -276,19 +276,19 @@ func (r *commandResult) SetPortalOutput( _ /* err */ = r.conn.writeRowDescription(ctx, cols, formatCodes, &r.conn.writerState.buf) } -// IncrementRowsAffected is part of the CommandResult interface. +// IncrementRowsAffected is part of the sql.RestrictedCommandResult interface. func (r *commandResult) IncrementRowsAffected(ctx context.Context, n int) { r.assertNotReleased() r.rowsAffected += n } -// RowsAffected is part of the CommandResult interface. +// RowsAffected is part of the sql.RestrictedCommandResult interface. func (r *commandResult) RowsAffected() int { r.assertNotReleased() return r.rowsAffected } -// ResetStmtType is part of the CommandResult interface. +// ResetStmtType is part of the sql.RestrictedCommandResult interface. func (r *commandResult) ResetStmtType(stmt tree.Statement) { r.assertNotReleased() r.stmtType = stmt.StatementReturnType() @@ -398,7 +398,7 @@ type limitedCommandResult struct { limit int } -// AddRow is part of the CommandResult interface. +// AddRow is part of the sql.RestrictedCommandResult interface. func (r *limitedCommandResult) AddRow(ctx context.Context, row tree.Datums) error { if err := r.commandResult.AddRow(ctx, row); err != nil { return err diff --git a/pkg/sql/pgwire/types.go b/pkg/sql/pgwire/types.go index f1c67f685cda..f98378682de2 100644 --- a/pkg/sql/pgwire/types.go +++ b/pkg/sql/pgwire/types.go @@ -31,9 +31,12 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/duration" "github.com/cockroachdb/cockroach/pkg/util/errorutil/unimplemented" "github.com/cockroachdb/cockroach/pkg/util/ipaddr" + "github.com/cockroachdb/cockroach/pkg/util/json" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/timeofday" "github.com/cockroachdb/cockroach/pkg/util/timetz" + "github.com/cockroachdb/cockroach/pkg/util/timeutil/pgdate" + "github.com/cockroachdb/cockroach/pkg/util/uuid" "github.com/cockroachdb/errors" "github.com/lib/pq/oid" ) @@ -64,6 +67,67 @@ func pgTypeForParserType(t *types.T) pgType { } } +func writeTextBool(b *writeBuffer, v bool) { + b.putInt32(1) + b.writeByte(tree.PgwireFormatBool(v)) +} + +func writeTextInt64(b *writeBuffer, v int64) { + // Start at offset 4 because `putInt32` clobbers the first 4 bytes. + s := strconv.AppendInt(b.putbuf[4:4], v, 10) + b.putInt32(int32(len(s))) + b.write(s) +} + +func writeTextFloat64(b *writeBuffer, fl float64, conv sessiondatapb.DataConversionConfig) { + var s []byte + // PostgreSQL supports 'Inf' as a valid literal for the floating point + // special value Infinity, therefore handling the special cases for them. + // (https://github.com/cockroachdb/cockroach/issues/62601) + if math.IsInf(fl, 1) { + s = []byte("Infinity") + } else if math.IsInf(fl, -1) { + s = []byte("-Infinity") + } else { + // Start at offset 4 because `putInt32` clobbers the first 4 bytes. + s = strconv.AppendFloat(b.putbuf[4:4], fl, 'g', conv.GetFloatPrec(), 64) + } + b.putInt32(int32(len(s))) + b.write(s) +} + +func writeTextBytes(b *writeBuffer, v string, conv sessiondatapb.DataConversionConfig) { + result := lex.EncodeByteArrayToRawBytes(v, conv.BytesEncodeFormat, false /* skipHexPrefix */) + b.putInt32(int32(len(result))) + b.write([]byte(result)) +} + +func writeTextUUID(b *writeBuffer, v uuid.UUID) { + // Start at offset 4 because `putInt32` clobbers the first 4 bytes. 
+ s := b.putbuf[4 : 4+36] + v.StringBytes(s) + b.putInt32(int32(len(s))) + b.write(s) +} + +func writeTextString(b *writeBuffer, v string, t *types.T) { + b.writeLengthPrefixedString(tree.ResolveBlankPaddedChar(v, t)) +} + +func writeTextTimestamp(b *writeBuffer, v time.Time) { + // Start at offset 4 because `putInt32` clobbers the first 4 bytes. + s := formatTs(v, nil, b.putbuf[4:4]) + b.putInt32(int32(len(s))) + b.write(s) +} + +func writeTextTimestampTZ(b *writeBuffer, v time.Time, sessionLoc *time.Location) { + // Start at offset 4 because `putInt32` clobbers the first 4 bytes. + s := formatTs(v, sessionLoc, b.putbuf[4:4]) + b.putInt32(int32(len(s))) + b.write(s) +} + // writeTextDatum writes d to the buffer. Type t must be specified for types // that have various width encodings and therefore need padding (chars). // It is ignored (and can be nil) for types which do not need padding. @@ -82,59 +146,49 @@ func (b *writeBuffer) writeTextDatum( b.putInt32(-1) return } + writeTextDatumNotNull(b, d, conv, sessionLoc, t) +} + +// writeTextDatumNotNull writes d to the buffer when d is not null. Type t must +// be specified for types that have various width encodings and therefore need +// padding (chars). It is ignored (and can be nil) for types which do not need +// padding. +func writeTextDatumNotNull( + b *writeBuffer, + d tree.Datum, + conv sessiondatapb.DataConversionConfig, + sessionLoc *time.Location, + t *types.T, +) { switch v := tree.UnwrapDatum(nil, d).(type) { case *tree.DBitArray: b.textFormatter.FormatNode(v) b.writeFromFmtCtx(b.textFormatter) case *tree.DBool: - b.textFormatter.FormatNode(v) - b.writeFromFmtCtx(b.textFormatter) + writeTextBool(b, bool(*v)) case *tree.DInt: - // Start at offset 4 because `putInt32` clobbers the first 4 bytes. - s := strconv.AppendInt(b.putbuf[4:4], int64(*v), 10) - b.putInt32(int32(len(s))) - b.write(s) + writeTextInt64(b, int64(*v)) case *tree.DFloat: fl := float64(*v) - var s []byte - // PostgreSQL supports 'Inf' as a valid literal for the floating point - // special value Infinity, therefore handling the special cases for them. - // (https://github.com/cockroachdb/cockroach/issues/62601) - if math.IsInf(fl, 1) { - s = []byte("Infinity") - } else if math.IsInf(fl, -1) { - s = []byte("-Infinity") - } else { - // Start at offset 4 because `putInt32` clobbers the first 4 bytes. - s = strconv.AppendFloat(b.putbuf[4:4], fl, 'g', conv.GetFloatPrec(), 64) - } - b.putInt32(int32(len(s))) - b.write(s) + writeTextFloat64(b, fl, conv) case *tree.DDecimal: b.writeLengthPrefixedDatum(v) case *tree.DBytes: - result := lex.EncodeByteArrayToRawBytes( - string(*v), conv.BytesEncodeFormat, false /* skipHexPrefix */) - b.putInt32(int32(len(result))) - b.write([]byte(result)) + writeTextBytes(b, string(*v), conv) case *tree.DUuid: - // Start at offset 4 because `putInt32` clobbers the first 4 bytes. - s := b.putbuf[4 : 4+36] - v.UUID.StringBytes(s) - b.putInt32(int32(len(s))) - b.write(s) + writeTextUUID(b, v.UUID) case *tree.DIPAddr: b.writeLengthPrefixedString(v.IPAddr.String()) case *tree.DString: - b.writeLengthPrefixedString(tree.ResolveBlankPaddedChar(string(*v), t)) + writeTextString(b, string(*v), t) case *tree.DCollatedString: b.writeLengthPrefixedString(tree.ResolveBlankPaddedChar(v.Contents, t)) @@ -171,16 +225,10 @@ func (b *writeBuffer) writeTextDatum( b.write([]byte(s)) case *tree.DTimestamp: - // Start at offset 4 because `putInt32` clobbers the first 4 bytes. 
- s := formatTs(v.Time, nil, b.putbuf[4:4]) - b.putInt32(int32(len(s))) - b.write(s) + writeTextTimestamp(b, v.Time) case *tree.DTimestampTZ: - // Start at offset 4 because `putInt32` clobbers the first 4 bytes. - s := formatTs(v.Time, sessionLoc, b.putbuf[4:4]) - b.putInt32(int32(len(s))) - b.write(s) + writeTextTimestampTZ(b, v.Time, sessionLoc) case *tree.DInterval: b.textFormatter.FormatNode(v) @@ -210,6 +258,173 @@ func (b *writeBuffer) writeTextDatum( } } +func writeBinaryBool(b *writeBuffer, v bool) { + b.putInt32(1) + if v { + b.writeByte(1) + } else { + b.writeByte(0) + } +} + +func writeBinaryInt(b *writeBuffer, v int64, t *types.T) { + switch t.Oid() { + case oid.T_int2: + b.putInt32(2) + b.putInt16(int16(v)) + case oid.T_int4: + b.putInt32(4) + b.putInt32(int32(v)) + case oid.T_int8: + b.putInt32(8) + b.putInt64(v) + default: + b.setError(errors.Errorf("unsupported int oid: %v", t.Oid())) + } +} + +func writeBinaryFloat(b *writeBuffer, v float64, t *types.T) { + switch t.Oid() { + case oid.T_float4: + b.putInt32(4) + b.putInt32(int32(math.Float32bits(float32(v)))) + case oid.T_float8: + b.putInt32(8) + b.putInt64(int64(math.Float64bits(v))) + default: + b.setError(errors.Errorf("unsupported float oid: %v", t.Oid())) + } +} + +func writeBinaryDecimal(b *writeBuffer, v *apd.Decimal) { + if v.Form != apd.Finite { + b.putInt32(8) + // 0 digits. + b.putInt32(0) + // https://github.com/postgres/postgres/blob/ffa4cbd623dd69f9fa99e5e92426928a5782cf1a/src/backend/utils/adt/numeric.c#L169 + b.write([]byte{0xc0, 0, 0, 0}) + + if v.Form == apd.Infinite { + // TODO(mjibson): #32489 + // The above encoding is not correct for Infinity, but since that encoding + // doesn't exist in postgres, it's unclear what to do. For now use the NaN + // encoding and count it to see if anyone even needs this. + telemetry.Inc(sqltelemetry.BinaryDecimalInfinityCounter) + } + + return + } + + alloc := struct { + pgNum pgwirebase.PGNumeric + + bigI big.Int + }{ + pgNum: pgwirebase.PGNumeric{ + // Since we use 2000 as the exponent limits in tree.DecimalCtx, this + // conversion should not overflow. + Dscale: int16(-v.Exponent), + }, + } + + if v.Sign() >= 0 { + alloc.pgNum.Sign = pgwirebase.PGNumericPos + } else { + alloc.pgNum.Sign = pgwirebase.PGNumericNeg + } + + isZero := func(r rune) bool { + return r == '0' + } + + // Mostly cribbed from libpqtypes' str2num. + digits := strings.TrimLeftFunc(alloc.bigI.Abs(&v.Coeff).String(), isZero) + dweight := len(digits) - int(alloc.pgNum.Dscale) - 1 + digits = strings.TrimRightFunc(digits, isZero) + + if dweight >= 0 { + alloc.pgNum.Weight = int16((dweight+1+pgwirebase.PGDecDigits-1)/pgwirebase.PGDecDigits - 1) + } else { + alloc.pgNum.Weight = int16(-((-dweight-1)/pgwirebase.PGDecDigits + 1)) + } + offset := (int(alloc.pgNum.Weight)+1)*pgwirebase.PGDecDigits - (dweight + 1) + alloc.pgNum.Ndigits = int16((len(digits) + offset + pgwirebase.PGDecDigits - 1) / pgwirebase.PGDecDigits) + + if len(digits) == 0 { + offset = 0 + alloc.pgNum.Ndigits = 0 + alloc.pgNum.Weight = 0 + } + + digitIdx := -offset + + nextDigit := func() int16 { + var ndigit int16 + for nextDigitIdx := digitIdx + pgwirebase.PGDecDigits; digitIdx < nextDigitIdx; digitIdx++ { + ndigit *= 10 + if digitIdx >= 0 && digitIdx < len(digits) { + ndigit += int16(digits[digitIdx] - '0') + } + } + return ndigit + } + + // The dscale is defined as number of digits (in base 10) visible + // after the decimal separator, so it can't be negative. 
+ if alloc.pgNum.Dscale < 0 { + alloc.pgNum.Dscale = 0 + } + + b.putInt32(int32(2 * (4 + alloc.pgNum.Ndigits))) + b.putInt16(alloc.pgNum.Ndigits) + b.putInt16(alloc.pgNum.Weight) + b.putInt16(int16(alloc.pgNum.Sign)) + b.putInt16(alloc.pgNum.Dscale) + + for digitIdx < len(digits) { + b.putInt16(nextDigit()) + } +} + +func writeBinaryBytes(b *writeBuffer, v []byte) { + b.putInt32(int32(len(v))) + b.write(v) +} + +func writeBinaryString(b *writeBuffer, v string, t *types.T) { + b.writeLengthPrefixedString(tree.ResolveBlankPaddedChar(v, t)) +} + +func writeBinaryTimestamp(b *writeBuffer, v time.Time) { + b.putInt32(8) + b.putInt64(timeToPgBinary(v, nil)) +} + +func writeBinaryTimestampTZ(b *writeBuffer, v time.Time, sessionLoc *time.Location) { + b.putInt32(8) + b.putInt64(timeToPgBinary(v, sessionLoc)) +} + +func writeBinaryDate(b *writeBuffer, v pgdate.Date) { + b.putInt32(4) + b.putInt32(v.PGEpochDays()) +} + +func writeBinaryInterval(b *writeBuffer, v duration.Duration) { + b.putInt32(16) + b.putInt64(v.Nanos() / int64(time.Microsecond/time.Nanosecond)) + b.putInt32(int32(v.Days)) + b.putInt32(int32(v.Months)) +} + +func writeBinaryJSON(b *writeBuffer, v json.JSON) { + s := v.String() + b.putInt32(int32(len(s) + 1)) + // Postgres version number, as of writing, `1` is the only valid value. + b.writeByte(1) + b.writeString(s) +} + // writeBinaryDatum writes d to the buffer. Type t must be specified for types // that have various width encodings (floats, ints, chars). It is ignored // (and can be nil) for types with a 1:1 datum:type mapping. @@ -224,6 +439,16 @@ func (b *writeBuffer) writeBinaryDatum( b.putInt32(-1) return } + writeBinaryDatumNotNull(ctx, b, d, sessionLoc, t) +} + +// writeBinaryDatumNotNull writes d to the buffer when d is not null. Type t +// must be specified for types that have various width encodings (floats, ints, +// chars). It is ignored (and can be nil) for types with a 1:1 datum:type +// mapping. +func writeBinaryDatumNotNull( + ctx context.Context, b *writeBuffer, d tree.Datum, sessionLoc *time.Location, t *types.T, +) { switch v := tree.UnwrapDatum(nil, d).(type) { case *tree.DBitArray: words, lastBitsUsed := v.EncodingParts() @@ -257,136 +482,22 @@ func (b *writeBuffer) writeBinaryDatum( } case *tree.DBool: - b.putInt32(1) - if *v { - b.writeByte(1) - } else { - b.writeByte(0) - } + writeBinaryBool(b, bool(*v)) case *tree.DInt: - switch t.Oid() { - case oid.T_int2: - b.putInt32(2) - b.putInt16(int16(*v)) - case oid.T_int4: - b.putInt32(4) - b.putInt32(int32(*v)) - case oid.T_int8: - b.putInt32(8) - b.putInt64(int64(*v)) - default: - b.setError(errors.Errorf("unsupported int oid: %v", t.Oid())) - } + writeBinaryInt(b, int64(*v), t) case *tree.DFloat: - switch t.Oid() { - case oid.T_float4: - b.putInt32(4) - b.putInt32(int32(math.Float32bits(float32(*v)))) - case oid.T_float8: - b.putInt32(8) - b.putInt64(int64(math.Float64bits(float64(*v)))) - default: - b.setError(errors.Errorf("unsupported float oid: %v", t.Oid())) - } + writeBinaryFloat(b, float64(*v), t) case *tree.DDecimal: - if v.Form != apd.Finite { - b.putInt32(8) - // 0 digits. - b.putInt32(0) - // https://github.com/postgres/postgres/blob/ffa4cbd623dd69f9fa99e5e92426928a5782cf1a/src/backend/utils/adt/numeric.c#L169 - b.write([]byte{0xc0, 0, 0, 0}) - - if v.Form == apd.Infinite { - // TODO(mjibson): #32489 - // The above encoding is not correct for Infinity, but since that encoding - // doesn't exist in postgres, it's unclear what to do. 
For now use the NaN - // encoding and count it to see if anyone even needs this. - telemetry.Inc(sqltelemetry.BinaryDecimalInfinityCounter) - } - - return - } - - alloc := struct { - pgNum pgwirebase.PGNumeric - - bigI big.Int - }{ - pgNum: pgwirebase.PGNumeric{ - // Since we use 2000 as the exponent limits in tree.DecimalCtx, this - // conversion should not overflow. - Dscale: int16(-v.Exponent), - }, - } - - if v.Sign() >= 0 { - alloc.pgNum.Sign = pgwirebase.PGNumericPos - } else { - alloc.pgNum.Sign = pgwirebase.PGNumericNeg - } - - isZero := func(r rune) bool { - return r == '0' - } - - // Mostly cribbed from libpqtypes' str2num. - digits := strings.TrimLeftFunc(alloc.bigI.Abs(&v.Coeff).String(), isZero) - dweight := len(digits) - int(alloc.pgNum.Dscale) - 1 - digits = strings.TrimRightFunc(digits, isZero) - - if dweight >= 0 { - alloc.pgNum.Weight = int16((dweight+1+pgwirebase.PGDecDigits-1)/pgwirebase.PGDecDigits - 1) - } else { - alloc.pgNum.Weight = int16(-((-dweight-1)/pgwirebase.PGDecDigits + 1)) - } - offset := (int(alloc.pgNum.Weight)+1)*pgwirebase.PGDecDigits - (dweight + 1) - alloc.pgNum.Ndigits = int16((len(digits) + offset + pgwirebase.PGDecDigits - 1) / pgwirebase.PGDecDigits) - - if len(digits) == 0 { - offset = 0 - alloc.pgNum.Ndigits = 0 - alloc.pgNum.Weight = 0 - } - - digitIdx := -offset - - nextDigit := func() int16 { - var ndigit int16 - for nextDigitIdx := digitIdx + pgwirebase.PGDecDigits; digitIdx < nextDigitIdx; digitIdx++ { - ndigit *= 10 - if digitIdx >= 0 && digitIdx < len(digits) { - ndigit += int16(digits[digitIdx] - '0') - } - } - return ndigit - } - - // The dscale is defined as number of digits (in base 10) visible - // after the decimal separator, so it can't be negative. - if alloc.pgNum.Dscale < 0 { - alloc.pgNum.Dscale = 0 - } - - b.putInt32(int32(2 * (4 + alloc.pgNum.Ndigits))) - b.putInt16(alloc.pgNum.Ndigits) - b.putInt16(alloc.pgNum.Weight) - b.putInt16(int16(alloc.pgNum.Sign)) - b.putInt16(alloc.pgNum.Dscale) - - for digitIdx < len(digits) { - b.putInt16(nextDigit()) - } + writeBinaryDecimal(b, &v.Decimal) case *tree.DBytes: - b.putInt32(int32(len(*v))) - b.write([]byte(*v)) + writeBinaryBytes(b, []byte(*v)) case *tree.DUuid: - b.putInt32(16) - b.write(v.GetBytes()) + writeBinaryBytes(b, v.GetBytes()) case *tree.DIPAddr: // We calculate the Postgres binary format for an IPAddr. For the spec see, @@ -427,22 +538,19 @@ func (b *writeBuffer) writeBinaryDatum( b.writeLengthPrefixedString(v.LogicalRep) case *tree.DString: - b.writeLengthPrefixedString(tree.ResolveBlankPaddedChar(string(*v), t)) + writeBinaryString(b, string(*v), t) case *tree.DCollatedString: b.writeLengthPrefixedString(tree.ResolveBlankPaddedChar(v.Contents, t)) case *tree.DTimestamp: - b.putInt32(8) - b.putInt64(timeToPgBinary(v.Time, nil)) + writeBinaryTimestamp(b, v.Time) case *tree.DTimestampTZ: - b.putInt32(8) - b.putInt64(timeToPgBinary(v.Time, sessionLoc)) + writeBinaryTimestampTZ(b, v.Time, sessionLoc) case *tree.DDate: - b.putInt32(4) - b.putInt32(v.PGEpochDays()) + writeBinaryDate(b, v.Date) case *tree.DTime: b.putInt32(8) @@ -454,10 +562,7 @@ func (b *writeBuffer) writeBinaryDatum( b.putInt32(v.OffsetSecs) case *tree.DInterval: - b.putInt32(16) - b.putInt64(v.Nanos() / int64(time.Microsecond/time.Nanosecond)) - b.putInt32(int32(v.Days)) - b.putInt32(int32(v.Months)) + writeBinaryInterval(b, v.Duration) case *tree.DTuple: // TODO(andrei): We shouldn't be allocating a new buffer for every array. 
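Reviewer note on the binary encoders split out above: after the int32 length word (always 16), writeBinaryInterval lays the payload out as microseconds (int64), then days (int32), then months (int32), all big-endian, matching the Postgres binary format for INTERVAL. A standalone sketch of just that layout using only the standard library (the helper name is invented; this is an illustration, not the code under review):

package main

import (
	"encoding/binary"
	"fmt"
)

// encodeIntervalPayload builds the 16-byte INTERVAL payload: int64
// microseconds, int32 days, int32 months, in network (big-endian) order.
func encodeIntervalPayload(micros int64, days, months int32) []byte {
	buf := make([]byte, 16)
	binary.BigEndian.PutUint64(buf[0:8], uint64(micros))
	binary.BigEndian.PutUint32(buf[8:12], uint32(days))
	binary.BigEndian.PutUint32(buf[12:16], uint32(months))
	return buf
}

func main() {
	// 1 month, 2 days, 3.5 seconds.
	fmt.Printf("% x\n", encodeIntervalPayload(3500000, 2, 1))
}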
@@ -517,12 +622,10 @@ func (b *writeBuffer) writeBinaryDatum( } } b.writeLengthPrefixedBuffer(&subWriter.wrapped) + case *tree.DJSON: - s := v.JSON.String() - b.putInt32(int32(len(s) + 1)) - // Postgres version number, as of writing, `1` is the only valid value. - b.writeByte(1) - b.writeString(s) + writeBinaryJSON(b, v.JSON) + case *tree.DOid: b.putInt32(4) b.putInt32(int32(v.DInt)) diff --git a/pkg/sql/rowexec/utils_test.go b/pkg/sql/rowexec/utils_test.go index 77a74878830b..aa6d0b1288bb 100644 --- a/pkg/sql/rowexec/utils_test.go +++ b/pkg/sql/rowexec/utils_test.go @@ -170,11 +170,6 @@ func (r *rowDisposer) Push( // ProducerDone is part of the execinfra.RowReceiver interface. func (r *rowDisposer) ProducerDone() {} -// Types is part of the execinfra.RowReceiver interface. -func (r *rowDisposer) Types() []*types.T { - return nil -} - func (r *rowDisposer) ResetNumRowsDisposed() { r.numRowsDisposed = 0 } diff --git a/pkg/sql/rowflow/routers.go b/pkg/sql/rowflow/routers.go index b75d53adc32b..a165d763b6f2 100644 --- a/pkg/sql/rowflow/routers.go +++ b/pkg/sql/rowflow/routers.go @@ -414,10 +414,6 @@ func (rb *routerBase) ProducerDone() { } } -func (rb *routerBase) Types() []*types.T { - return rb.types -} - // updateStreamState updates the status of one stream and, if this was the last // open stream, it also updates rb.aggregatedStatus. func (rb *routerBase) updateStreamState( diff --git a/pkg/sql/sem/tree/datum.go b/pkg/sql/sem/tree/datum.go index 09fb8c307b33..9f67c3388d91 100644 --- a/pkg/sql/sem/tree/datum.go +++ b/pkg/sql/sem/tree/datum.go @@ -455,14 +455,19 @@ func (d *DBool) Max(_ *EvalContext) (Datum, bool) { // AmbiguousFormat implements the Datum interface. func (*DBool) AmbiguousFormat() bool { return false } +// PgwireFormatBool returns a single byte representing a boolean according to +// pgwire encoding. +func PgwireFormatBool(d bool) byte { + if d { + return 't' + } + return 'f' +} + // Format implements the NodeFormatter interface. func (d *DBool) Format(ctx *FmtCtx) { if ctx.HasFlags(fmtPgwireFormat) { - if bool(*d) { - ctx.WriteByte('t') - } else { - ctx.WriteByte('f') - } + ctx.WriteByte(PgwireFormatBool(bool(*d))) return } ctx.WriteString(strconv.FormatBool(bool(*d))) @@ -1887,19 +1892,24 @@ func (d *DDate) Min(_ *EvalContext) (Datum, bool) { // AmbiguousFormat implements the Datum interface. func (*DDate) AmbiguousFormat() bool { return true } -// Format implements the NodeFormatter interface. -func (d *DDate) Format(ctx *FmtCtx) { +// FormatDate writes d into ctx according to the format flags. +func FormatDate(d pgdate.Date, ctx *FmtCtx) { f := ctx.flags bareStrings := f.HasFlags(FmtFlags(lexbase.EncBareStrings)) if !bareStrings { ctx.WriteByte('\'') } - d.Date.Format(&ctx.Buffer) + d.Format(&ctx.Buffer) if !bareStrings { ctx.WriteByte('\'') } } +// Format implements the NodeFormatter interface. +func (d *DDate) Format(ctx *FmtCtx) { + FormatDate(d.Date, ctx) +} + // Size implements the Datum interface. func (d *DDate) Size() uintptr { return unsafe.Sizeof(*d) @@ -2771,19 +2781,24 @@ func (d *DInterval) ValueAsString() string { // AmbiguousFormat implements the Datum interface. func (*DInterval) AmbiguousFormat() bool { return true } -// Format implements the NodeFormatter interface. -func (d *DInterval) Format(ctx *FmtCtx) { +// FormatDuration writes d into ctx according to the format flags. 
+func FormatDuration(d duration.Duration, ctx *FmtCtx) { f := ctx.flags bareStrings := f.HasFlags(FmtFlags(lexbase.EncBareStrings)) if !bareStrings { ctx.WriteByte('\'') } - d.Duration.Format(&ctx.Buffer) + d.Format(&ctx.Buffer) if !bareStrings { ctx.WriteByte('\'') } } +// Format implements the NodeFormatter interface. +func (d *DInterval) Format(ctx *FmtCtx) { + FormatDuration(d.Duration, ctx) +} + // Size implements the Datum interface. func (d *DInterval) Size() uintptr { return unsafe.Sizeof(*d) diff --git a/pkg/testutils/distsqlutils/row_buffer.go b/pkg/testutils/distsqlutils/row_buffer.go index 76a6d12ddce9..cf16809dd6b5 100644 --- a/pkg/testutils/distsqlutils/row_buffer.go +++ b/pkg/testutils/distsqlutils/row_buffer.go @@ -155,11 +155,6 @@ func (rb *RowBuffer) ProducerDone() { rb.Mu.producerClosed = true } -// Types is part of the RowReceiver interface. -func (rb *RowBuffer) Types() []*types.T { - return rb.types -} - // OutputTypes is part of the RowSource interface. func (rb *RowBuffer) OutputTypes() []*types.T { if rb.types == nil {
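Closing reviewer note: the Types() deletions that run through this diff (execinfra.RowReceiver, RowChannel, routerBase, DistSQLReceiver, RowBuffer, rowDisposer, errorReportingRowReceiver) are all the same mechanical refactor: no caller consulted the consumer for types anymore, since the producer side already exposes them (see the Outbox change above switching to m.OutputTypes()). A small, hypothetical Go sketch, with simplified interfaces rather than the real ones, of the compile-time assertion idiom that keeps such interface-narrowing honest:

package main

import "fmt"

// RowReceiver is a pared-down stand-in for execinfra.RowReceiver after this
// change: the consumer no longer advertises types.
type RowReceiver interface {
	Push(row []string) bool
	ProducerDone()
}

// rowDisposer discards every row it is pushed, like the test helper above.
type rowDisposer struct{ disposed int }

func (r *rowDisposer) Push(row []string) bool { r.disposed++; return true }
func (r *rowDisposer) ProducerDone()          {}

// This assertion stops compiling the moment rowDisposer and RowReceiver
// drift apart, so leftover or missing methods surface immediately.
var _ RowReceiver = (*rowDisposer)(nil)

func main() {
	var rcv RowReceiver = &rowDisposer{}
	rcv.Push([]string{"a", "b"})
	rcv.ProducerDone()
	fmt.Printf("disposed %d row(s)\n", rcv.(*rowDisposer).disposed)
}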