diff --git a/pkg/sql/copy/copy_test.go b/pkg/sql/copy/copy_test.go index be864cd43393..33edb995dcac 100644 --- a/pkg/sql/copy/copy_test.go +++ b/pkg/sql/copy/copy_test.go @@ -48,6 +48,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/randutil" + "github.com/cockroachdb/cockroach/pkg/util/stop" "github.com/cockroachdb/datadriven" "github.com/cockroachdb/errors" "github.com/jackc/pgconn" @@ -770,3 +771,57 @@ func min(a, b int) int { } return b } + +func BenchmarkCopyCSVEndToEnd(b *testing.B) { + defer leaktest.AfterTest(b)() + defer log.Scope(b).Close(b) + + ctx := context.Background() + s, db, _ := serverutils.StartServer(b, base.TestServerArgs{ + DefaultTestTenant: base.TestTenantDisabled, + }) + defer s.Stopper().Stop(ctx) + + pgURL, cleanup, err := sqlutils.PGUrlE( + s.ServingSQLAddr(), + "BenchmarkCopyEndToEnd", /* prefix */ + url.User(username.RootUser), + ) + require.NoError(b, err) + s.Stopper().AddCloser(stop.CloserFn(cleanup)) + + _, err = db.Exec("CREATE TABLE t (i INT PRIMARY KEY, s STRING)") + require.NoError(b, err) + + conn, err := pgx.Connect(ctx, pgURL.String()) + require.NoError(b, err) + + rng, _ := randutil.NewTestRand() + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + // Create an input of 1_000_000 rows. + buf := &bytes.Buffer{} + for j := 0; j < 1_000_000; j++ { + buf.WriteString(strconv.Itoa(j)) + buf.WriteString(",") + str := randutil.RandString(rng, rng.Intn(50), "abc123\n") + buf.WriteString("\"") + buf.WriteString(str) + buf.WriteString("\"\n") + } + b.StartTimer() + + // Run the COPY. + _, err = conn.PgConn().CopyFrom(ctx, buf, "COPY t FROM STDIN CSV") + require.NoError(b, err) + + // Verify that the data was inserted. + b.StopTimer() + var count int + err = db.QueryRow("SELECT count(*) FROM t").Scan(&count) + require.NoError(b, err) + require.Equal(b, 1_000_000, count) + b.StartTimer() + } +} diff --git a/pkg/sql/copy_from.go b/pkg/sql/copy_from.go index f955b4df07a9..13a6702c3d60 100644 --- a/pkg/sql/copy_from.go +++ b/pkg/sql/copy_from.go @@ -218,7 +218,7 @@ type copyMachine struct { csvReader *csv.Reader // buf is used to parse input data into rows. It also accumulates a partial // row between protocol messages. - buf bytes.Buffer + buf []byte // rows accumulates a batch of rows to be eventually inserted. rows rowcontainer.RowContainer // insertedRows keeps track of the total number of rows inserted by the @@ -616,20 +616,20 @@ const ( func (c *copyMachine) processCopyData(ctx context.Context, data string, final bool) (retErr error) { // At the end, adjust the mem accounting to reflect what's left in the buffer. defer func() { - if err := c.bufMemAcc.ResizeTo(ctx, int64(c.buf.Cap())); err != nil && retErr == nil { + if err := c.bufMemAcc.ResizeTo(ctx, int64(cap(c.buf))); err != nil && retErr == nil { retErr = err } }() - if len(data) > (c.buf.Cap() - c.buf.Len()) { + if len(data) > (cap(c.buf) - len(c.buf)) { // If it looks like the buffer will need to allocate to accommodate data, // account for the memory here. This is not particularly accurate - we don't // know how much the buffer will actually grow by. - if err := c.bufMemAcc.ResizeTo(ctx, int64(len(data))); err != nil { + if err := c.bufMemAcc.Grow(ctx, int64(len(data))); err != nil { return err } } - c.buf.WriteString(data) + c.buf = append(c.buf, data...) var readFn func(ctx context.Context, final bool) (brk bool, err error) switch c.format { case tree.CopyFormatText: @@ -641,7 +641,7 @@ func (c *copyMachine) processCopyData(ctx context.Context, data string, final bo default: panic("unknown copy format") } - for c.buf.Len() > 0 { + for len(c.buf) > 0 { brk, err := readFn(ctx, final) if err != nil { return err @@ -664,11 +664,11 @@ func (c *copyMachine) processCopyData(ctx context.Context, data string, final bo // If we have a full batch of rows or we have exceeded maxRowMem process // them. Only set finalBatch to true if this is the last // CopyData segment AND we have no more data in the buffer. - if len := c.currentBatchSize(); len > 0 && (c.rowsMemAcc.Used() > c.maxRowMem || len >= c.copyBatchRowSize || batchDone) { - if len != c.copyBatchRowSize { - log.VEventf(ctx, 2, "copy batch of %d rows flushing due to memory usage %d > %d", len, c.rowsMemAcc.Used(), c.maxRowMem) + if length := c.currentBatchSize(); length > 0 && (c.rowsMemAcc.Used() > c.maxRowMem || length >= c.copyBatchRowSize || batchDone) { + if length != c.copyBatchRowSize { + log.VEventf(ctx, 2, "copy batch of %d rows flushing due to memory usage %d > %d", length, c.rowsMemAcc.Used(), c.maxRowMem) } - if err := c.processRows(ctx, final && c.buf.Len() == 0); err != nil { + if err := c.processRows(ctx, final && len(c.buf) == 0); err != nil { return err } } @@ -692,24 +692,26 @@ func (c *copyMachine) currentBatchSize() int { } func (c *copyMachine) readTextData(ctx context.Context, final bool) (brk bool, err error) { - line, err := c.buf.ReadBytes(lineDelim) - if err != nil { - if err != io.EOF { - return false, err - } else if !final { - // Put the incomplete row back in the buffer, to be processed next time. - c.buf.Write(line) + idx := bytes.IndexByte(c.buf, lineDelim) + var line []byte + if idx == -1 { + if !final { + // Leave the incomplete row in the buffer, to be processed next time. return true, nil } + // If this is the final batch, use the whole buffer. + line = c.buf[:len(c.buf)] + c.buf = c.buf[len(c.buf):] } else { // Remove lineDelim from end. - line = line[:len(line)-1] + line = c.buf[:idx] + c.buf = c.buf[idx+1:] // Remove a single '\r' at EOL, if present. if len(line) > 0 && line[len(line)-1] == '\r' { line = line[:len(line)-1] } } - if c.buf.Len() == 0 && bytes.Equal(line, []byte(`\.`)) { + if len(c.buf) == 0 && bytes.Equal(line, []byte(`\.`)) { return true, nil } err = c.readTextTuple(ctx, line) @@ -719,29 +721,29 @@ func (c *copyMachine) readTextData(ctx context.Context, final bool) (brk bool, e func (c *copyMachine) readCSVData(ctx context.Context, final bool) (brk bool, err error) { var fullLine []byte quoteCharsSeen := 0 + offset := 0 // Keep reading lines until we encounter a newline that is not inside a // quoted field, and therefore signifies the end of a CSV record. for { - line, err := c.buf.ReadBytes(lineDelim) - fullLine = append(fullLine, line...) - if err != nil { - if err == io.EOF { - if final { - // If we reached EOF and this is the final chunk of input data, then - // try to process it. - break - } else { - // If there's more CopyData, put the incomplete row back in the - // buffer, to be processed next time. - c.buf.Write(fullLine) - return true, nil - } + idx := bytes.IndexByte(c.buf[offset:], lineDelim) + if idx == -1 { + if final { + // If we reached EOF and this is the final chunk of input data, then + // try to process it. + fullLine = append(fullLine, c.buf[offset:]...) + c.buf = c.buf[len(c.buf):] + break } else { - return false, err + // If there's more CopyData, keep the incomplete row in the + // buffer, to be processed next time. + return true, nil } } - - // Now we need to calculate if we are have reached the end of the quote. + // Include the delimiter in the line. + line := c.buf[offset : offset+idx+1] + offset += idx + 1 + fullLine = append(fullLine, line...) + // Now we need to calculate if we have reached the end of the quote. // If so, break out. if c.csvEscape == 0 { // CSV escape is not specified and hence defaults to '"'.¥ @@ -777,6 +779,7 @@ func (c *copyMachine) readCSVData(ctx context.Context, final bool) (brk bool, er } } if quoteCharsSeen%2 == 0 { + c.buf = c.buf[offset:] break } } @@ -792,7 +795,7 @@ func (c *copyMachine) readCSVData(ctx context.Context, final bool) (brk bool, er record, err := c.csvReader.Read() // Look for end of data before checking for errors, since a field count // error will still return record data. - if len(record) == 1 && !record[0].Quoted && record[0].Val == endOfData && c.buf.Len() == 0 { + if len(record) == 1 && !record[0].Quoted && record[0].Val == endOfData && len(c.buf) == 0 { return true, nil } if err != nil { @@ -868,25 +871,31 @@ func (c *copyMachine) readBinaryData(ctx context.Context, final bool) (brk bool, } switch c.binaryState { case binaryStateNeedSignature: - if readSoFar, err := c.readBinarySignature(); err != nil { + n, err := c.readBinarySignature() + if err != nil { // If this isn't the last message and we saw incomplete data, then - // put it back in the buffer to process more next time. - if !final && (err == io.EOF || err == io.ErrUnexpectedEOF) { - c.buf.Write(readSoFar) + // leave it in the buffer to process more next time. + if !final && err == io.ErrUnexpectedEOF { return true, nil } + c.buf = c.buf[n:] return false, err } + c.buf = c.buf[n:] + return false, nil case binaryStateRead: - if readSoFar, err := c.readBinaryTuple(ctx); err != nil { + n, err := c.readBinaryTuple(ctx) + if err != nil { // If this isn't the last message and we saw incomplete data, then - // put it back in the buffer to process more next time. - if !final && (err == io.EOF || err == io.ErrUnexpectedEOF) { - c.buf.Write(readSoFar) + // leave it in the buffer to process more next time. + if !final && err == io.ErrUnexpectedEOF { return true, nil } + c.buf = c.buf[n:] return false, errors.Wrapf(err, "read binary tuple") } + c.buf = c.buf[n:] + return false, nil case binaryStateFoundTrailer: if !final { return false, pgerror.New(pgcode.BadCopyFileFormat, @@ -896,34 +905,33 @@ func (c *copyMachine) readBinaryData(ctx context.Context, final bool) (brk bool, default: panic("unknown binary state") } - return false, nil } -func (c *copyMachine) readBinaryTuple(ctx context.Context) (readSoFar []byte, err error) { +func (c *copyMachine) readBinaryTuple(ctx context.Context) (bytesRead int, err error) { var fieldCount int16 var fieldCountBytes [2]byte - n, err := io.ReadFull(&c.buf, fieldCountBytes[:]) - readSoFar = append(readSoFar, fieldCountBytes[:n]...) - if err != nil { - return readSoFar, err + n := copy(fieldCountBytes[:], c.buf[bytesRead:]) + bytesRead += n + if n < len(fieldCountBytes) { + return bytesRead, io.ErrUnexpectedEOF } fieldCount = int16(binary.BigEndian.Uint16(fieldCountBytes[:])) if fieldCount == -1 { c.binaryState = binaryStateFoundTrailer - return nil, nil + return bytesRead, nil } if fieldCount < 1 { - return nil, pgerror.Newf(pgcode.BadCopyFileFormat, + return bytesRead, pgerror.Newf(pgcode.BadCopyFileFormat, "unexpected field count: %d", fieldCount) } datums := make(tree.Datums, fieldCount) var byteCount int32 var byteCountBytes [4]byte for i := range datums { - n, err := io.ReadFull(&c.buf, byteCountBytes[:]) - readSoFar = append(readSoFar, byteCountBytes[:n]...) - if err != nil { - return readSoFar, err + n := copy(byteCountBytes[:], c.buf[bytesRead:]) + bytesRead += n + if n < len(byteCountBytes) { + return bytesRead, io.ErrUnexpectedEOF } byteCount = int32(binary.BigEndian.Uint32(byteCountBytes[:])) if byteCount == -1 { @@ -931,10 +939,10 @@ func (c *copyMachine) readBinaryTuple(ctx context.Context) (readSoFar []byte, er continue } data := make([]byte, byteCount) - n, err = io.ReadFull(&c.buf, data) - readSoFar = append(readSoFar, data[:n]...) - if err != nil { - return readSoFar, err + n = copy(data, c.buf[bytesRead:]) + bytesRead += n + if n < len(data) { + return bytesRead, io.ErrUnexpectedEOF } d, err := pgwirebase.DecodeDatum( ctx, @@ -944,33 +952,35 @@ func (c *copyMachine) readBinaryTuple(ctx context.Context) (readSoFar []byte, er data, ) if err != nil { - return nil, pgerror.Wrapf(err, pgcode.BadCopyFileFormat, + return bytesRead, pgerror.Wrapf(err, pgcode.BadCopyFileFormat, "decode datum as %s: %s", c.resultColumns[i].Typ.SQLString(), data) } datums[i] = d } _, err = c.rows.AddRow(ctx, datums) if err != nil { - return nil, err + return bytesRead, err } - return nil, nil + return bytesRead, nil } -func (c *copyMachine) readBinarySignature() ([]byte, error) { - // This is the standard 11-byte binary signature with the flags and - // header 32-bit integers appended since we only support the zero value - // of them. - const binarySignature = "PGCOPY\n\377\r\n\000" + "\x00\x00\x00\x00" + "\x00\x00\x00\x00" - var sig [11 + 8]byte - if n, err := io.ReadFull(&c.buf, sig[:]); err != nil { - return sig[:n], err - } - if !bytes.Equal(sig[:], []byte(binarySignature)) { - return sig[:], pgerror.New(pgcode.BadCopyFileFormat, +// This is the standard 11-byte binary signature with the flags and +// header 32-bit integers appended since we only support the zero value +// of them. +var copyBinarySignature = [19]byte{'P', 'G', 'C', 'O', 'P', 'Y', '\n', '\377', '\r', '\n', '\000', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00'} + +func (c *copyMachine) readBinarySignature() (int, error) { + var sig [19]byte + n := copy(sig[:], c.buf) + if n < len(sig) { + return n, io.ErrUnexpectedEOF + } + if sig != copyBinarySignature { + return n, pgerror.New(pgcode.BadCopyFileFormat, "unrecognized binary copy signature") } c.binaryState = binaryStateRead - return sig[:], nil + return n, nil } // preparePlannerForCopy resets the planner so that it can be used during