diff --git a/pkg/sql/distsqlrun/sample_aggregator.go b/pkg/sql/distsqlrun/sample_aggregator.go index bcc53de6bc23..fe0f685feae2 100644 --- a/pkg/sql/distsqlrun/sample_aggregator.go +++ b/pkg/sql/distsqlrun/sample_aggregator.go @@ -229,7 +229,7 @@ func (s *sampleAggregator) mainLoop(ctx context.Context) (earlyExit bool, err er return false, errors.NewAssertionErrorWithWrappedErrf(err, "decoding rank column") } // Retain the rows with the top ranks. - if err := s.sr.SampleRow(ctx, row[:s.rankCol], uint64(rank)); err != nil { + if err := s.sr.SampleRow(ctx, s.evalCtx, row[:s.rankCol], uint64(rank)); err != nil { return false, err } continue diff --git a/pkg/sql/distsqlrun/sampler.go b/pkg/sql/distsqlrun/sampler.go index 5dfe9636aa71..c7b725391773 100644 --- a/pkg/sql/distsqlrun/sampler.go +++ b/pkg/sql/distsqlrun/sampler.go @@ -310,7 +310,7 @@ func (s *samplerProcessor) mainLoop(ctx context.Context) (earlyExit bool, err er // Use Int63 so we don't have headaches converting to DInt. rank := uint64(rng.Int63()) - if err := s.sr.SampleRow(ctx, row, rank); err != nil { + if err := s.sr.SampleRow(ctx, s.evalCtx, row, rank); err != nil { return false, err } } diff --git a/pkg/sql/stats/row_sampling.go b/pkg/sql/stats/row_sampling.go index e47926db621a..986ad2ca42df 100644 --- a/pkg/sql/stats/row_sampling.go +++ b/pkg/sql/stats/row_sampling.go @@ -14,6 +14,7 @@ import ( "container/heap" "context" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sqlbase" "github.com/cockroachdb/cockroach/pkg/sql/types" "github.com/cockroachdb/cockroach/pkg/util/mon" @@ -81,7 +82,7 @@ func (sr *SampleReservoir) Pop() interface{} { panic("unimplemented") } // SampleRow looks at a row and either drops it or adds it to the reservoir. func (sr *SampleReservoir) SampleRow( - ctx context.Context, row sqlbase.EncDatumRow, rank uint64, + ctx context.Context, evalCtx *tree.EvalContext, row sqlbase.EncDatumRow, rank uint64, ) error { if len(sr.samples) < cap(sr.samples) { // We haven't accumulated enough rows yet, just append. @@ -94,7 +95,7 @@ func (sr *SampleReservoir) SampleRow( return err } } - if err := sr.copyRow(ctx, rowCopy, row); err != nil { + if err := sr.copyRow(ctx, evalCtx, rowCopy, row); err != nil { return err } sr.samples = append(sr.samples, SampledRow{Row: rowCopy, Rank: rank}) @@ -106,7 +107,7 @@ func (sr *SampleReservoir) SampleRow( } // Replace the max rank if ours is smaller. if len(sr.samples) > 0 && rank < sr.samples[0].Rank { - if err := sr.copyRow(ctx, sr.samples[0].Row, row); err != nil { + if err := sr.copyRow(ctx, evalCtx, sr.samples[0].Row, row); err != nil { return err } sr.samples[0].Rank = rank @@ -120,7 +121,9 @@ func (sr *SampleReservoir) Get() []SampledRow { return sr.samples } -func (sr *SampleReservoir) copyRow(ctx context.Context, dst, src sqlbase.EncDatumRow) error { +func (sr *SampleReservoir) copyRow( + ctx context.Context, evalCtx *tree.EvalContext, dst, src sqlbase.EncDatumRow, +) error { for i := range src { // Copy only the decoded datum to ensure that we remove any reference to // the encoded bytes. The encoded bytes would have been scanned in a batch @@ -131,8 +134,14 @@ func (sr *SampleReservoir) copyRow(ctx context.Context, dst, src sqlbase.EncDatu } beforeSize := dst[i].Size() dst[i] = sqlbase.DatumToEncDatum(&sr.colTypes[i], src[i].Datum) + afterSize := dst[i].Size() + if afterSize > uintptr(maxBytesPerSample) { + dst[i].Datum = truncateDatum(evalCtx, dst[i].Datum, maxBytesPerSample) + afterSize = dst[i].Size() + } + // Perform memory accounting. - if afterSize := dst[i].Size(); sr.memAcc != nil && afterSize > beforeSize { + if sr.memAcc != nil && afterSize > beforeSize { if err := sr.memAcc.Grow(ctx, int64(afterSize-beforeSize)); err != nil { return err } @@ -141,3 +150,63 @@ func (sr *SampleReservoir) copyRow(ctx context.Context, dst, src sqlbase.EncDatu } return nil } + +const maxBytesPerSample = 400 + +// truncateDatum truncates large datums to avoid using excessive memory or disk +// space. It performs a best-effort attempt to return a datum that is similar +// to d using at most maxBytes bytes. +// +// For example, if maxBytes=10, "Cockroach Labs" would be truncated to +// "Cockroach ". +func truncateDatum(evalCtx *tree.EvalContext, d tree.Datum, maxBytes int) tree.Datum { + switch t := d.(type) { + case *tree.DBitArray: + b := tree.DBitArray{BitArray: t.ToWidth(uint(maxBytes * 8))} + return &b + + case *tree.DBytes: + // Make a copy so the memory from the original byte string can be garbage + // collected. + b := make([]byte, maxBytes) + copy(b, *t) + return tree.NewDBytes(tree.DBytes(b)) + + case *tree.DString: + return tree.NewDString(truncateString(string(*t), maxBytes)) + + case *tree.DCollatedString: + contents := truncateString(t.Contents, maxBytes) + + // Note: this will end up being larger than maxBytes due to the key and + // locale, so this is just a best-effort attempt to limit the size. + return tree.NewDCollatedString(contents, t.Locale, &evalCtx.CollationEnv) + + default: + // It's not easy to truncate other types (e.g. Decimal). + // TODO(rytaft): If the total memory limit is exceeded then the histogram + // should not be constructed. + return d + } +} + +// truncateString truncates long strings to the longest valid substring that is +// less than maxBytes bytes. It is rune-aware so it does not cut unicode +// characters in half. +func truncateString(s string, maxBytes int) string { + last := 0 + // For strings, range skips from rune to rune and i is the byte index of + // the current rune. + for i := range s { + if i > maxBytes { + break + } + last = i + } + + // Copy the truncated string so that the memory from the longer string can + // be garbage collected. + b := make([]byte, last) + copy(b, s) + return string(b) +} diff --git a/pkg/sql/stats/row_sampling_test.go b/pkg/sql/stats/row_sampling_test.go index dc77a93034fe..41117793465a 100644 --- a/pkg/sql/stats/row_sampling_test.go +++ b/pkg/sql/stats/row_sampling_test.go @@ -17,6 +17,7 @@ import ( "sort" "testing" + "github.com/cockroachdb/cockroach/pkg/settings/cluster" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sqlbase" "github.com/cockroachdb/cockroach/pkg/sql/types" @@ -25,13 +26,13 @@ import ( // runSampleTest feeds rows with the given ranks through a reservoir // of a given size and verifies the results are correct. -func runSampleTest(t *testing.T, numSamples int, ranks []int) { +func runSampleTest(t *testing.T, evalCtx *tree.EvalContext, numSamples int, ranks []int) { ctx := context.Background() var sr SampleReservoir sr.Init(numSamples, []types.T{*types.Int}, nil /* memAcc */) for _, r := range ranks { d := sqlbase.DatumToEncDatum(types.Int, tree.NewDInt(tree.DInt(r))) - if err := sr.SampleRow(ctx, sqlbase.EncDatumRow{d}, uint64(r)); err != nil { + if err := sr.SampleRow(ctx, evalCtx, sqlbase.EncDatumRow{d}, uint64(r)); err != nil { t.Errorf("%v", err) } } @@ -62,6 +63,7 @@ func runSampleTest(t *testing.T, numSamples int, ranks []int) { } func TestSampleReservoir(t *testing.T) { + evalCtx := tree.MakeTestingEvalContext(cluster.MakeTestingClusterSettings()) for _, n := range []int{10, 100, 1000, 10000} { rng, _ := randutil.NewPseudoRand() ranks := make([]int, n) @@ -70,8 +72,44 @@ func TestSampleReservoir(t *testing.T) { } for _, k := range []int{1, 5, 10, 100} { t.Run(fmt.Sprintf("%d/%d", n, k), func(t *testing.T) { - runSampleTest(t, k, ranks) + runSampleTest(t, &evalCtx, k, ranks) }) } } } + +func TestTruncateDatum(t *testing.T) { + evalCtx := tree.MakeTestingEvalContext(cluster.MakeTestingClusterSettings()) + runTest := func(d, expected tree.Datum) { + actual := truncateDatum(&evalCtx, d, 10 /* maxBytes */) + if actual.Compare(&evalCtx, expected) != 0 { + t.Fatalf("expected %s but found %s", expected.String(), actual.String()) + } + } + + original1, err := tree.ParseDBitArray("0110110101111100001100110110101111100001100110110101111" + + "10000110011011010111110000110011011010111110000110011011010111110000110") + if err != nil { + t.Fatal(err) + } + expected1, err := tree.ParseDBitArray("0110110101111100001100110110101111100001100110110101111" + + "1000011001101101011111000") + if err != nil { + t.Fatal(err) + } + runTest(original1, expected1) + + original2 := tree.DBytes("deadbeef1234567890") + expected2 := tree.DBytes("deadbeef12") + runTest(&original2, &expected2) + + original3 := tree.DString("Hello 世界") + expected3 := tree.DString("Hello 世") + runTest(&original3, &expected3) + + original4 := tree.NewDCollatedString(`IT was lovely summer weather in the country, and the golden +corn, the green oats, and the haystacks piled up in the meadows looked beautiful`, + "en_US", &tree.CollationEnvironment{}) + expected4 := tree.NewDCollatedString("IT was lov", "en_US", &tree.CollationEnvironment{}) + runTest(original4, expected4) +}