Merge #106088

106088: copy: avoid recopying buffer r=rafiss a=rafiss

Use a []byte instead of a bytes.Buffer, so that we only advance the cursor into the buffer once we know we want to consume the data. With a bytes.Buffer, an incomplete row had to be copied back into the buffer after every read; with a plain slice it simply stays where it is.

This provides a 25% reduction in allocations in BenchmarkCopyCSVEndToEnd.
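
In sketch form, the difference between the two approaches looks like this (a toy illustration, not the CockroachDB code itself; the helper names are made up):

```go
package main

import (
	"bytes"
	"fmt"
)

const lineDelim = '\n'

// nextLineBuffer reads from a bytes.Buffer. ReadBytes always consumes what
// it reads, so an incomplete line has to be copied back into the buffer.
func nextLineBuffer(buf *bytes.Buffer) ([]byte, bool) {
	line, err := buf.ReadBytes(lineDelim)
	if err != nil { // no delimiter yet: undo the read by recopying
		buf.Write(line)
		return nil, false
	}
	return line[:len(line)-1], true
}

// nextLineSlice peeks with IndexByte and only re-slices (no copy) once a
// complete line is known to be present.
func nextLineSlice(buf *[]byte) ([]byte, bool) {
	idx := bytes.IndexByte(*buf, lineDelim)
	if idx == -1 { // incomplete line simply stays in the buffer
		return nil, false
	}
	line := (*buf)[:idx]
	*buf = (*buf)[idx+1:]
	return line, true
}

func main() {
	b := []byte("complete\npartial")
	for {
		line, ok := nextLineSlice(&b)
		if !ok {
			break
		}
		fmt.Printf("%q\n", line) // "complete"
	}
	fmt.Printf("left over: %q\n", b) // "partial" is never recopied
}
```

The re-slice in nextLineSlice only moves the slice header, while the bytes.Buffer version recopies every incomplete row back into the buffer, which is where the saved allocations come from.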

```
goos: darwin
goarch: arm64
                   │ bench.a51e8a7806c (old) │ bench.0c9d3d5efb0 (new) │
                   │         sec/op          │    sec/op       vs base
CopyCSVEndToEnd-10 │       3.991 ± 3%        │ 3.883 ± 2%   -2.71% (p=0.019 n=10)

                   │ bench.a51e8a7806c (old) │ bench.0c9d3d5efb0 (new) │
                   │          B/op           │     B/op        vs base
CopyCSVEndToEnd-10 │      8.322Gi ± 2%       │ 8.304Gi ± 1%      ~ (p=0.971 n=10)

                   │ bench.a51e8a7806c (old) │ bench.0c9d3d5efb0 (new) │
                   │        allocs/op        │   allocs/op     vs base
CopyCSVEndToEnd-10 │       18.30M ± 0%       │ 13.79M ± 0%  -24.62% (p=0.000 n=10)
```


Epic: None
Release note: None

Co-authored-by: Rafi Shamim <[email protected]>
craig[bot] and rafiss committed Jul 7, 2023
2 parents 21904db + b72784a commit 81b17ca
Showing 2 changed files with 141 additions and 76 deletions.
55 changes: 55 additions & 0 deletions pkg/sql/copy/copy_test.go
@@ -48,6 +48,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/cockroach/pkg/util/randutil"
"github.com/cockroachdb/cockroach/pkg/util/stop"
"github.com/cockroachdb/datadriven"
"github.com/cockroachdb/errors"
"github.com/jackc/pgconn"
@@ -770,3 +771,57 @@ func min(a, b int) int {
}
return b
}

func BenchmarkCopyCSVEndToEnd(b *testing.B) {
defer leaktest.AfterTest(b)()
defer log.Scope(b).Close(b)

ctx := context.Background()
s, db, _ := serverutils.StartServer(b, base.TestServerArgs{
DefaultTestTenant: base.TestTenantDisabled,
})
defer s.Stopper().Stop(ctx)

pgURL, cleanup, err := sqlutils.PGUrlE(
s.ServingSQLAddr(),
"BenchmarkCopyEndToEnd", /* prefix */
url.User(username.RootUser),
)
require.NoError(b, err)
s.Stopper().AddCloser(stop.CloserFn(cleanup))

_, err = db.Exec("CREATE TABLE t (i INT PRIMARY KEY, s STRING)")
require.NoError(b, err)

conn, err := pgx.Connect(ctx, pgURL.String())
require.NoError(b, err)

rng, _ := randutil.NewTestRand()
b.ResetTimer()
for i := 0; i < b.N; i++ {
b.StopTimer()
// Create an input of 1_000_000 rows.
buf := &bytes.Buffer{}
for j := 0; j < 1_000_000; j++ {
buf.WriteString(strconv.Itoa(j))
buf.WriteString(",")
str := randutil.RandString(rng, rng.Intn(50), "abc123\n")
buf.WriteString("\"")
buf.WriteString(str)
buf.WriteString("\"\n")
}
b.StartTimer()

// Run the COPY.
_, err = conn.PgConn().CopyFrom(ctx, buf, "COPY t FROM STDIN CSV")
require.NoError(b, err)

// Verify that the data was inserted.
b.StopTimer()
var count int
err = db.QueryRow("SELECT count(*) FROM t").Scan(&count)
require.NoError(b, err)
require.Equal(b, 1_000_000, count)
b.StartTimer()
}
}
162 changes: 86 additions & 76 deletions pkg/sql/copy_from.go
@@ -218,7 +218,7 @@ type copyMachine struct {
csvReader *csv.Reader
// buf is used to parse input data into rows. It also accumulates a partial
// row between protocol messages.
buf bytes.Buffer
buf []byte
// rows accumulates a batch of rows to be eventually inserted.
rows rowcontainer.RowContainer
// insertedRows keeps track of the total number of rows inserted by the
@@ -616,20 +616,20 @@ const (
func (c *copyMachine) processCopyData(ctx context.Context, data string, final bool) (retErr error) {
// At the end, adjust the mem accounting to reflect what's left in the buffer.
defer func() {
if err := c.bufMemAcc.ResizeTo(ctx, int64(c.buf.Cap())); err != nil && retErr == nil {
if err := c.bufMemAcc.ResizeTo(ctx, int64(cap(c.buf))); err != nil && retErr == nil {
retErr = err
}
}()

if len(data) > (c.buf.Cap() - c.buf.Len()) {
if len(data) > (cap(c.buf) - len(c.buf)) {
// If it looks like the buffer will need to allocate to accommodate data,
// account for the memory here. This is not particularly accurate - we don't
// know how much the buffer will actually grow by.
if err := c.bufMemAcc.ResizeTo(ctx, int64(len(data))); err != nil {
if err := c.bufMemAcc.Grow(ctx, int64(len(data))); err != nil {
return err
}
}
c.buf.WriteString(data)
c.buf = append(c.buf, data...)
var readFn func(ctx context.Context, final bool) (brk bool, err error)
switch c.format {
case tree.CopyFormatText:
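
As an aside on the accounting above: append grows capacity amortizedly, so growing the account by len(data) is only an estimate, as the comment says. A tiny standalone illustration (not CockroachDB code):

```go
package main

import "fmt"

func main() {
	// append may over-allocate, which is why accounting for exactly
	// len(data) can only approximate the buffer's real footprint.
	buf := make([]byte, 0, 8)
	buf = append(buf, make([]byte, 9)...) // needs 9 more bytes, gets more
	fmt.Println(len(buf), cap(buf))       // 9 16 on a typical Go runtime
}
```

The deferred ResizeTo(ctx, int64(cap(c.buf))) above settles the account against the real capacity once the data has been processed.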
@@ -641,7 +641,7 @@ func (c *copyMachine) processCopyData(ctx context.Context, data string, final bo
default:
panic("unknown copy format")
}
for c.buf.Len() > 0 {
for len(c.buf) > 0 {
brk, err := readFn(ctx, final)
if err != nil {
return err
@@ -664,11 +664,11 @@ func (c *copyMachine) processCopyData(ctx context.Context, data string, final bo
// If we have a full batch of rows or we have exceeded maxRowMem process
// them. Only set finalBatch to true if this is the last
// CopyData segment AND we have no more data in the buffer.
if len := c.currentBatchSize(); len > 0 && (c.rowsMemAcc.Used() > c.maxRowMem || len >= c.copyBatchRowSize || batchDone) {
if len != c.copyBatchRowSize {
log.VEventf(ctx, 2, "copy batch of %d rows flushing due to memory usage %d > %d", len, c.rowsMemAcc.Used(), c.maxRowMem)
if length := c.currentBatchSize(); length > 0 && (c.rowsMemAcc.Used() > c.maxRowMem || length >= c.copyBatchRowSize || batchDone) {
if length != c.copyBatchRowSize {
log.VEventf(ctx, 2, "copy batch of %d rows flushing due to memory usage %d > %d", length, c.rowsMemAcc.Used(), c.maxRowMem)
}
if err := c.processRows(ctx, final && c.buf.Len() == 0); err != nil {
if err := c.processRows(ctx, final && len(c.buf) == 0); err != nil {
return err
}
}
@@ -692,24 +692,26 @@ func (c *copyMachine) currentBatchSize() int {
}

func (c *copyMachine) readTextData(ctx context.Context, final bool) (brk bool, err error) {
line, err := c.buf.ReadBytes(lineDelim)
if err != nil {
if err != io.EOF {
return false, err
} else if !final {
// Put the incomplete row back in the buffer, to be processed next time.
c.buf.Write(line)
idx := bytes.IndexByte(c.buf, lineDelim)
var line []byte
if idx == -1 {
if !final {
// Leave the incomplete row in the buffer, to be processed next time.
return true, nil
}
// If this is the final batch, use the whole buffer.
line = c.buf[:len(c.buf)]
c.buf = c.buf[len(c.buf):]
} else {
// Remove lineDelim from end.
line = line[:len(line)-1]
line = c.buf[:idx]
c.buf = c.buf[idx+1:]
// Remove a single '\r' at EOL, if present.
if len(line) > 0 && line[len(line)-1] == '\r' {
line = line[:len(line)-1]
}
}
if c.buf.Len() == 0 && bytes.Equal(line, []byte(`\.`)) {
if len(c.buf) == 0 && bytes.Equal(line, []byte(`\.`)) {
return true, nil
}
err = c.readTextTuple(ctx, line)
@@ -719,29 +721,29 @@ func (c *copyMachine) readCSVData(ctx context.Context, final bool) (brk bool, e
func (c *copyMachine) readCSVData(ctx context.Context, final bool) (brk bool, err error) {
var fullLine []byte
quoteCharsSeen := 0
offset := 0
// Keep reading lines until we encounter a newline that is not inside a
// quoted field, and therefore signifies the end of a CSV record.
for {
line, err := c.buf.ReadBytes(lineDelim)
fullLine = append(fullLine, line...)
if err != nil {
if err == io.EOF {
if final {
// If we reached EOF and this is the final chunk of input data, then
// try to process it.
break
} else {
// If there's more CopyData, put the incomplete row back in the
// buffer, to be processed next time.
c.buf.Write(fullLine)
return true, nil
}
idx := bytes.IndexByte(c.buf[offset:], lineDelim)
if idx == -1 {
if final {
// If we reached EOF and this is the final chunk of input data, then
// try to process it.
fullLine = append(fullLine, c.buf[offset:]...)
c.buf = c.buf[len(c.buf):]
break
} else {
return false, err
// If there's more CopyData, keep the incomplete row in the
// buffer, to be processed next time.
return true, nil
}
}

// Now we need to calculate if we are have reached the end of the quote.
// Include the delimiter in the line.
line := c.buf[offset : offset+idx+1]
offset += idx + 1
fullLine = append(fullLine, line...)
// Now we need to calculate if we have reached the end of the quote.
// If so, break out.
if c.csvEscape == 0 {
// CSV escape is not specified and hence defaults to '"'.
@@ -777,6 +779,7 @@ func (c *copyMachine) readCSVData(ctx context.Context, final bool) (brk bool, er
}
}
if quoteCharsSeen%2 == 0 {
c.buf = c.buf[offset:]
break
}
}
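
The scan above relies on quote parity: a newline only ends a CSV record once the quote characters seen so far balance out to an even count; otherwise the newline is inside a quoted field and the loop keeps consuming lines. A simplified standalone sketch of that idea (it ignores the custom-escape handling in the real code):

```go
package main

import (
	"bytes"
	"fmt"
)

// csvRecordEnd returns the index just past the newline that ends the first
// CSV record in buf, or -1 if the record is still incomplete. A newline
// seen after an odd number of quote characters is inside a quoted field,
// not a record terminator.
func csvRecordEnd(buf []byte) int {
	quotes := 0
	offset := 0
	for {
		idx := bytes.IndexByte(buf[offset:], '\n')
		if idx == -1 {
			return -1 // no record boundary yet; wait for more data
		}
		line := buf[offset : offset+idx+1]
		offset += idx + 1
		quotes += bytes.Count(line, []byte(`"`))
		if quotes%2 == 0 {
			return offset
		}
	}
}

func main() {
	data := []byte("1,\"a\nb\"\n2,c\n")
	end := csvRecordEnd(data)
	fmt.Printf("first record: %q\n", data[:end]) // "1,\"a\nb\"\n"
}
```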
@@ -792,7 +795,7 @@ func (c *copyMachine) readCSVData(ctx context.Context, final bool) (brk bool, er
record, err := c.csvReader.Read()
// Look for end of data before checking for errors, since a field count
// error will still return record data.
if len(record) == 1 && !record[0].Quoted && record[0].Val == endOfData && c.buf.Len() == 0 {
if len(record) == 1 && !record[0].Quoted && record[0].Val == endOfData && len(c.buf) == 0 {
return true, nil
}
if err != nil {
@@ -868,25 +871,31 @@ func (c *copyMachine) readBinaryData(ctx context.Context, final bool) (brk bool,
}
switch c.binaryState {
case binaryStateNeedSignature:
if readSoFar, err := c.readBinarySignature(); err != nil {
n, err := c.readBinarySignature()
if err != nil {
// If this isn't the last message and we saw incomplete data, then
// put it back in the buffer to process more next time.
if !final && (err == io.EOF || err == io.ErrUnexpectedEOF) {
c.buf.Write(readSoFar)
// leave it in the buffer to process more next time.
if !final && err == io.ErrUnexpectedEOF {
return true, nil
}
c.buf = c.buf[n:]
return false, err
}
c.buf = c.buf[n:]
return false, nil
case binaryStateRead:
if readSoFar, err := c.readBinaryTuple(ctx); err != nil {
n, err := c.readBinaryTuple(ctx)
if err != nil {
// If this isn't the last message and we saw incomplete data, then
// put it back in the buffer to process more next time.
if !final && (err == io.EOF || err == io.ErrUnexpectedEOF) {
c.buf.Write(readSoFar)
// leave it in the buffer to process more next time.
if !final && err == io.ErrUnexpectedEOF {
return true, nil
}
c.buf = c.buf[n:]
return false, errors.Wrapf(err, "read binary tuple")
}
c.buf = c.buf[n:]
return false, nil
case binaryStateFoundTrailer:
if !final {
return false, pgerror.New(pgcode.BadCopyFileFormat,
@@ -896,45 +905,44 @@ func (c *copyMachine) readBinaryData(ctx context.Context, final bool) (brk bool,
default:
panic("unknown binary state")
}
return false, nil
}

func (c *copyMachine) readBinaryTuple(ctx context.Context) (readSoFar []byte, err error) {
func (c *copyMachine) readBinaryTuple(ctx context.Context) (bytesRead int, err error) {
var fieldCount int16
var fieldCountBytes [2]byte
n, err := io.ReadFull(&c.buf, fieldCountBytes[:])
readSoFar = append(readSoFar, fieldCountBytes[:n]...)
if err != nil {
return readSoFar, err
n := copy(fieldCountBytes[:], c.buf[bytesRead:])
bytesRead += n
if n < len(fieldCountBytes) {
return bytesRead, io.ErrUnexpectedEOF
}
fieldCount = int16(binary.BigEndian.Uint16(fieldCountBytes[:]))
if fieldCount == -1 {
c.binaryState = binaryStateFoundTrailer
return nil, nil
return bytesRead, nil
}
if fieldCount < 1 {
return nil, pgerror.Newf(pgcode.BadCopyFileFormat,
return bytesRead, pgerror.Newf(pgcode.BadCopyFileFormat,
"unexpected field count: %d", fieldCount)
}
datums := make(tree.Datums, fieldCount)
var byteCount int32
var byteCountBytes [4]byte
for i := range datums {
n, err := io.ReadFull(&c.buf, byteCountBytes[:])
readSoFar = append(readSoFar, byteCountBytes[:n]...)
if err != nil {
return readSoFar, err
n := copy(byteCountBytes[:], c.buf[bytesRead:])
bytesRead += n
if n < len(byteCountBytes) {
return bytesRead, io.ErrUnexpectedEOF
}
byteCount = int32(binary.BigEndian.Uint32(byteCountBytes[:]))
if byteCount == -1 {
datums[i] = tree.DNull
continue
}
data := make([]byte, byteCount)
n, err = io.ReadFull(&c.buf, data)
readSoFar = append(readSoFar, data[:n]...)
if err != nil {
return readSoFar, err
n = copy(data, c.buf[bytesRead:])
bytesRead += n
if n < len(data) {
return bytesRead, io.ErrUnexpectedEOF
}
d, err := pgwirebase.DecodeDatum(
ctx,
@@ -944,33 +952,35 @@ func (c *copyMachine) readBinaryTuple(ctx context.Context) (readSoFar []byte, er
data,
)
if err != nil {
return nil, pgerror.Wrapf(err, pgcode.BadCopyFileFormat,
return bytesRead, pgerror.Wrapf(err, pgcode.BadCopyFileFormat,
"decode datum as %s: %s", c.resultColumns[i].Typ.SQLString(), data)
}
datums[i] = d
}
_, err = c.rows.AddRow(ctx, datums)
if err != nil {
return nil, err
return bytesRead, err
}
return nil, nil
return bytesRead, nil
}

func (c *copyMachine) readBinarySignature() ([]byte, error) {
// This is the standard 11-byte binary signature with the flags and
// header 32-bit integers appended since we only support the zero value
// of them.
const binarySignature = "PGCOPY\n\377\r\n\000" + "\x00\x00\x00\x00" + "\x00\x00\x00\x00"
var sig [11 + 8]byte
if n, err := io.ReadFull(&c.buf, sig[:]); err != nil {
return sig[:n], err
}
if !bytes.Equal(sig[:], []byte(binarySignature)) {
return sig[:], pgerror.New(pgcode.BadCopyFileFormat,
// This is the standard 11-byte binary signature with the flags and
// header 32-bit integers appended since we only support the zero value
// of them.
var copyBinarySignature = [19]byte{'P', 'G', 'C', 'O', 'P', 'Y', '\n', '\377', '\r', '\n', '\000', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00'}

func (c *copyMachine) readBinarySignature() (int, error) {
var sig [19]byte
n := copy(sig[:], c.buf)
if n < len(sig) {
return n, io.ErrUnexpectedEOF
}
if sig != copyBinarySignature {
return n, pgerror.New(pgcode.BadCopyFileFormat,
"unrecognized binary copy signature")
}
c.binaryState = binaryStateRead
return sig[:], nil
return n, nil
}

// preparePlannerForCopy resets the planner so that it can be used during
