copy: avoid recopying buffer #106088

Merged (2 commits) on Jul 7, 2023
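The change replaces copyMachine's bytes.Buffer with a plain []byte. Every buf.ReadBytes(lineDelim) call on the old bytes.Buffer allocated and copied the line out of the buffer, and an incomplete row then had to be copied back in with buf.Write; with a slice, lines are located with bytes.IndexByte and consumed by reslicing, and a partial row simply stays in place. A minimal sketch of the pattern (the nextLine helper is illustrative, not code from this PR):

package main

import (
    "bytes"
    "fmt"
)

// nextLine returns the next newline-terminated line without copying it: both
// results alias buf's backing array. ok is false when buf does not yet
// contain a complete line.
func nextLine(buf []byte) (line, rest []byte, ok bool) {
    idx := bytes.IndexByte(buf, '\n')
    if idx == -1 {
        return nil, buf, false // incomplete row stays buffered as-is
    }
    return buf[:idx], buf[idx+1:], true
}

func main() {
    buf := []byte("1,\"a\"\n2,\"b\"\n3,\"par")
    for {
        line, rest, ok := nextLine(buf)
        if !ok {
            break
        }
        buf = rest
        fmt.Printf("row: %q\n", line)
    }
    fmt.Printf("buffered partial row: %q\n", buf)
}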
55 changes: 55 additions & 0 deletions pkg/sql/copy/copy_test.go
@@ -48,6 +48,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/cockroach/pkg/util/randutil"
"github.com/cockroachdb/cockroach/pkg/util/stop"
"github.com/cockroachdb/datadriven"
"github.com/cockroachdb/errors"
"github.com/jackc/pgconn"
@@ -770,3 +771,57 @@ func min(a, b int) int {
}
return b
}

+ func BenchmarkCopyCSVEndToEnd(b *testing.B) {
+ defer leaktest.AfterTest(b)()
+ defer log.Scope(b).Close(b)
+
+ ctx := context.Background()
+ s, db, _ := serverutils.StartServer(b, base.TestServerArgs{
+ DefaultTestTenant: base.TestTenantDisabled,
+ })
+ defer s.Stopper().Stop(ctx)
+
+ pgURL, cleanup, err := sqlutils.PGUrlE(
+ s.ServingSQLAddr(),
+ "BenchmarkCopyEndToEnd", /* prefix */
+ url.User(username.RootUser),
+ )
+ require.NoError(b, err)
+ s.Stopper().AddCloser(stop.CloserFn(cleanup))
+
+ _, err = db.Exec("CREATE TABLE t (i INT PRIMARY KEY, s STRING)")
+ require.NoError(b, err)
+
+ conn, err := pgx.Connect(ctx, pgURL.String())
+ require.NoError(b, err)
+
+ rng, _ := randutil.NewTestRand()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ b.StopTimer()
+ // Create an input of 1_000_000 rows.
+ buf := &bytes.Buffer{}
+ for j := 0; j < 1_000_000; j++ {
+ buf.WriteString(strconv.Itoa(j))
+ buf.WriteString(",")
+ str := randutil.RandString(rng, rng.Intn(50), "abc123\n")
+ buf.WriteString("\"")
+ buf.WriteString(str)
+ buf.WriteString("\"\n")
+ }
+ b.StartTimer()
+
+ // Run the COPY.
+ _, err = conn.PgConn().CopyFrom(ctx, buf, "COPY t FROM STDIN CSV")
+ require.NoError(b, err)
+
+ // Verify that the data was inserted.
+ b.StopTimer()
+ var count int
+ err = db.QueryRow("SELECT count(*) FROM t").Scan(&count)
+ require.NoError(b, err)
+ require.Equal(b, 1_000_000, count)
+ b.StartTimer()
+ }
+ }
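Note that the benchmark times only the COPY itself: input generation and the row-count verification are excluded via b.StopTimer/b.StartTimer. With a plain Go toolchain an invocation along the lines of "go test ./pkg/sql/copy -bench BenchmarkCopyCSVEndToEnd -benchtime 1x" should run it once (the exact command is not part of this PR, and CockroachDB's own build tooling may differ).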
162 changes: 86 additions & 76 deletions pkg/sql/copy_from.go
@@ -218,7 +218,7 @@ type copyMachine struct {
csvReader *csv.Reader
// buf is used to parse input data into rows. It also accumulates a partial
// row between protocol messages.
- buf bytes.Buffer
+ buf []byte
// rows accumulates a batch of rows to be eventually inserted.
rows rowcontainer.RowContainer
// insertedRows keeps track of the total number of rows inserted by the
@@ -616,20 +616,20 @@ const (
func (c *copyMachine) processCopyData(ctx context.Context, data string, final bool) (retErr error) {
// At the end, adjust the mem accounting to reflect what's left in the buffer.
defer func() {
- if err := c.bufMemAcc.ResizeTo(ctx, int64(c.buf.Cap())); err != nil && retErr == nil {
+ if err := c.bufMemAcc.ResizeTo(ctx, int64(cap(c.buf))); err != nil && retErr == nil {
retErr = err
}
}()

- if len(data) > (c.buf.Cap() - c.buf.Len()) {
+ if len(data) > (cap(c.buf) - len(c.buf)) {
// If it looks like the buffer will need to allocate to accommodate data,
// account for the memory here. This is not particularly accurate - we don't
// know how much the buffer will actually grow by.
- if err := c.bufMemAcc.ResizeTo(ctx, int64(len(data))); err != nil {
+ if err := c.bufMemAcc.Grow(ctx, int64(len(data))); err != nil {
return err
}
}
- c.buf.WriteString(data)
+ c.buf = append(c.buf, data...)
var readFn func(ctx context.Context, final bool) (brk bool, err error)
switch c.format {
case tree.CopyFormatText:
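The accounting call above switches from ResizeTo to Grow because incoming data is now appended to whatever partial row is already sitting in c.buf; reserving len(data) as an absolute size would under-count the combined buffer. A toy sketch of the semantics assumed here (a simplified stand-in, not the real mon.BoundAccount API):

// account is a minimal memory account with the two operations used above.
type account struct{ used int64 }

// ResizeTo sets the reservation to an absolute size, discarding history.
func (a *account) ResizeTo(newSize int64) { a.used = newSize }

// Grow adds to the existing reservation, which is the right shape for
// appending to a buffer that may already hold buffered bytes.
func (a *account) Grow(delta int64) { a.used += delta }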
@@ -641,7 +641,7 @@ func (c *copyMachine) processCopyData(ctx context.Context, data string, final bo
default:
panic("unknown copy format")
}
- for c.buf.Len() > 0 {
+ for len(c.buf) > 0 {
brk, err := readFn(ctx, final)
if err != nil {
return err
Expand All @@ -664,11 +664,11 @@ func (c *copyMachine) processCopyData(ctx context.Context, data string, final bo
// If we have a full batch of rows or we have exceeded maxRowMem process
// them. Only set finalBatch to true if this is the last
// CopyData segment AND we have no more data in the buffer.
- if len := c.currentBatchSize(); len > 0 && (c.rowsMemAcc.Used() > c.maxRowMem || len >= c.copyBatchRowSize || batchDone) {
- if len != c.copyBatchRowSize {
- log.VEventf(ctx, 2, "copy batch of %d rows flushing due to memory usage %d > %d", len, c.rowsMemAcc.Used(), c.maxRowMem)
+ if length := c.currentBatchSize(); length > 0 && (c.rowsMemAcc.Used() > c.maxRowMem || length >= c.copyBatchRowSize || batchDone) {
+ if length != c.copyBatchRowSize {
+ log.VEventf(ctx, 2, "copy batch of %d rows flushing due to memory usage %d > %d", length, c.rowsMemAcc.Used(), c.maxRowMem)
}
- if err := c.processRows(ctx, final && c.buf.Len() == 0); err != nil {
+ if err := c.processRows(ctx, final && len(c.buf) == 0); err != nil {
return err
}
}
@@ -692,24 +692,26 @@ func (c *copyMachine) currentBatchSize() int {
}

func (c *copyMachine) readTextData(ctx context.Context, final bool) (brk bool, err error) {
- line, err := c.buf.ReadBytes(lineDelim)
- if err != nil {
- if err != io.EOF {
- return false, err
- } else if !final {
- // Put the incomplete row back in the buffer, to be processed next time.
- c.buf.Write(line)
+ idx := bytes.IndexByte(c.buf, lineDelim)
+ var line []byte
+ if idx == -1 {
+ if !final {
+ // Leave the incomplete row in the buffer, to be processed next time.
return true, nil
}
+ // If this is the final batch, use the whole buffer.
+ line = c.buf[:len(c.buf)]
+ c.buf = c.buf[len(c.buf):]
} else {
// Remove lineDelim from end.
- line = line[:len(line)-1]
+ line = c.buf[:idx]
+ c.buf = c.buf[idx+1:]
// Remove a single '\r' at EOL, if present.
if len(line) > 0 && line[len(line)-1] == '\r' {
line = line[:len(line)-1]
}
}
- if c.buf.Len() == 0 && bytes.Equal(line, []byte(`\.`)) {
+ if len(c.buf) == 0 && bytes.Equal(line, []byte(`\.`)) {
return true, nil
}
err = c.readTextTuple(ctx, line)
@@ -719,29 +721,29 @@ func (c *copyMachine) readCSVData(ctx context.Context, final bool) (brk bool, e
func (c *copyMachine) readCSVData(ctx context.Context, final bool) (brk bool, err error) {
var fullLine []byte
quoteCharsSeen := 0
+ offset := 0
// Keep reading lines until we encounter a newline that is not inside a
// quoted field, and therefore signifies the end of a CSV record.
for {
- line, err := c.buf.ReadBytes(lineDelim)
- fullLine = append(fullLine, line...)
- if err != nil {
- if err == io.EOF {
- if final {
- // If we reached EOF and this is the final chunk of input data, then
- // try to process it.
- break
- } else {
- // If there's more CopyData, put the incomplete row back in the
- // buffer, to be processed next time.
- c.buf.Write(fullLine)
- return true, nil
- }
+ idx := bytes.IndexByte(c.buf[offset:], lineDelim)
+ if idx == -1 {
+ if final {
+ // If we reached EOF and this is the final chunk of input data, then
+ // try to process it.
+ fullLine = append(fullLine, c.buf[offset:]...)
+ c.buf = c.buf[len(c.buf):]
+ break
} else {
- return false, err
+ // If there's more CopyData, keep the incomplete row in the
+ // buffer, to be processed next time.
+ return true, nil
}
}

- // Now we need to calculate if we are have reached the end of the quote.
+ // Include the delimiter in the line.
+ line := c.buf[offset : offset+idx+1]
+ offset += idx + 1
+ fullLine = append(fullLine, line...)
+ // Now we need to calculate if we have reached the end of the quote.
// If so, break out.
if c.csvEscape == 0 {
// CSV escape is not specified and hence defaults to '"'.
@@ -777,6 +779,7 @@ func (c *copyMachine) readCSVData(ctx context.Context, final bool) (brk bool, er
}
}
if quoteCharsSeen%2 == 0 {
+ c.buf = c.buf[offset:]
break
}
}
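A CSV record may contain newlines inside quoted fields, so readCSVData cannot stop at the first lineDelim: it keeps scanning until it reaches a newline preceded by an even number of quote characters. A simplified sketch of that parity check (default '"' quoting only; the real loop above also handles custom escape characters and doubled quotes):

// csvRecordEnd returns the index just past the newline that ends the first
// CSV record in buf, or -1 if the record is still incomplete.
func csvRecordEnd(buf []byte) int {
    quotes := 0
    for i, b := range buf {
        switch b {
        case '"':
            quotes++
        case '\n':
            if quotes%2 == 0 { // newline outside any quoted field
                return i + 1
            }
        }
    }
    return -1
}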
@@ -792,7 +795,7 @@ func (c *copyMachine) readCSVData(ctx context.Context, final bool) (brk bool, er
record, err := c.csvReader.Read()
// Look for end of data before checking for errors, since a field count
// error will still return record data.
- if len(record) == 1 && !record[0].Quoted && record[0].Val == endOfData && c.buf.Len() == 0 {
+ if len(record) == 1 && !record[0].Quoted && record[0].Val == endOfData && len(c.buf) == 0 {
return true, nil
}
if err != nil {
@@ -868,25 +871,31 @@ func (c *copyMachine) readBinaryData(ctx context.Context, final bool) (brk bool,
}
switch c.binaryState {
case binaryStateNeedSignature:
- if readSoFar, err := c.readBinarySignature(); err != nil {
+ n, err := c.readBinarySignature()
+ if err != nil {
// If this isn't the last message and we saw incomplete data, then
- // put it back in the buffer to process more next time.
- if !final && (err == io.EOF || err == io.ErrUnexpectedEOF) {
- c.buf.Write(readSoFar)
+ // leave it in the buffer to process more next time.
+ if !final && err == io.ErrUnexpectedEOF {
return true, nil
}
+ c.buf = c.buf[n:]
return false, err
}
+ c.buf = c.buf[n:]
+ return false, nil
case binaryStateRead:
- if readSoFar, err := c.readBinaryTuple(ctx); err != nil {
+ n, err := c.readBinaryTuple(ctx)
+ if err != nil {
// If this isn't the last message and we saw incomplete data, then
- // put it back in the buffer to process more next time.
- if !final && (err == io.EOF || err == io.ErrUnexpectedEOF) {
- c.buf.Write(readSoFar)
+ // leave it in the buffer to process more next time.
+ if !final && err == io.ErrUnexpectedEOF {
return true, nil
}
+ c.buf = c.buf[n:]
return false, errors.Wrapf(err, "read binary tuple")
}
+ c.buf = c.buf[n:]
+ return false, nil
case binaryStateFoundTrailer:
if !final {
return false, pgerror.New(pgcode.BadCopyFileFormat,
@@ -896,45 +905,44 @@ func (c *copyMachine) readBinaryData(ctx context.Context, final bool) (brk bool,
default:
panic("unknown binary state")
}
- return false, nil
}

- func (c *copyMachine) readBinaryTuple(ctx context.Context) (readSoFar []byte, err error) {
+ func (c *copyMachine) readBinaryTuple(ctx context.Context) (bytesRead int, err error) {
var fieldCount int16
var fieldCountBytes [2]byte
- n, err := io.ReadFull(&c.buf, fieldCountBytes[:])
- readSoFar = append(readSoFar, fieldCountBytes[:n]...)
- if err != nil {
- return readSoFar, err
+ n := copy(fieldCountBytes[:], c.buf[bytesRead:])
+ bytesRead += n
+ if n < len(fieldCountBytes) {
+ return bytesRead, io.ErrUnexpectedEOF
}
fieldCount = int16(binary.BigEndian.Uint16(fieldCountBytes[:]))
if fieldCount == -1 {
c.binaryState = binaryStateFoundTrailer
- return nil, nil
+ return bytesRead, nil
}
if fieldCount < 1 {
- return nil, pgerror.Newf(pgcode.BadCopyFileFormat,
+ return bytesRead, pgerror.Newf(pgcode.BadCopyFileFormat,
"unexpected field count: %d", fieldCount)
}
datums := make(tree.Datums, fieldCount)
var byteCount int32
var byteCountBytes [4]byte
for i := range datums {
- n, err := io.ReadFull(&c.buf, byteCountBytes[:])
- readSoFar = append(readSoFar, byteCountBytes[:n]...)
- if err != nil {
- return readSoFar, err
+ n := copy(byteCountBytes[:], c.buf[bytesRead:])
+ bytesRead += n
+ if n < len(byteCountBytes) {
+ return bytesRead, io.ErrUnexpectedEOF
}
byteCount = int32(binary.BigEndian.Uint32(byteCountBytes[:]))
if byteCount == -1 {
datums[i] = tree.DNull
continue
}
data := make([]byte, byteCount)
- n, err = io.ReadFull(&c.buf, data)
- readSoFar = append(readSoFar, data[:n]...)
- if err != nil {
- return readSoFar, err
+ n = copy(data, c.buf[bytesRead:])
+ bytesRead += n
+ if n < len(data) {
+ return bytesRead, io.ErrUnexpectedEOF
}
d, err := pgwirebase.DecodeDatum(
ctx,
Expand All @@ -944,33 +952,35 @@ func (c *copyMachine) readBinaryTuple(ctx context.Context) (readSoFar []byte, er
data,
)
if err != nil {
- return nil, pgerror.Wrapf(err, pgcode.BadCopyFileFormat,
+ return bytesRead, pgerror.Wrapf(err, pgcode.BadCopyFileFormat,
"decode datum as %s: %s", c.resultColumns[i].Typ.SQLString(), data)
}
datums[i] = d
}
_, err = c.rows.AddRow(ctx, datums)
if err != nil {
- return nil, err
+ return bytesRead, err
}
- return nil, nil
+ return bytesRead, nil
}

- func (c *copyMachine) readBinarySignature() ([]byte, error) {
- // This is the standard 11-byte binary signature with the flags and
- // header 32-bit integers appended since we only support the zero value
- // of them.
- const binarySignature = "PGCOPY\n\377\r\n\000" + "\x00\x00\x00\x00" + "\x00\x00\x00\x00"
- var sig [11 + 8]byte
- if n, err := io.ReadFull(&c.buf, sig[:]); err != nil {
- return sig[:n], err
- }
- if !bytes.Equal(sig[:], []byte(binarySignature)) {
- return sig[:], pgerror.New(pgcode.BadCopyFileFormat,
+ // This is the standard 11-byte binary signature with the flags and
+ // header 32-bit integers appended since we only support the zero value
+ // of them.
+ var copyBinarySignature = [19]byte{'P', 'G', 'C', 'O', 'P', 'Y', '\n', '\377', '\r', '\n', '\000', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00'}
+
+ func (c *copyMachine) readBinarySignature() (int, error) {
+ var sig [19]byte
+ n := copy(sig[:], c.buf)
+ if n < len(sig) {
+ return n, io.ErrUnexpectedEOF
+ }
+ if sig != copyBinarySignature {
+ return n, pgerror.New(pgcode.BadCopyFileFormat,
"unrecognized binary copy signature")
}
c.binaryState = binaryStateRead
- return sig[:], nil
+ return n, nil
}

// preparePlannerForCopy resets the planner so that it can be used during
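In the binary path, io.ReadFull against the bytes.Buffer is replaced by copy from the slice: a short copy means the message is incomplete, which is reported as io.ErrUnexpectedEOF, and if more CopyData is coming the unconsumed bytes simply remain in c.buf. (The signature check also switches from bytes.Equal to comparing fixed-size arrays with !=, which Go defines for arrays of comparable element types.) A reduced sketch of one length-prefixed read in this style, assuming imports encoding/binary and io; readField is a hypothetical helper, and unlike the real readBinaryTuple it does not special-case a -1 length as NULL:

// readField consumes a 4-byte big-endian length and then that many payload
// bytes from buf, returning how many bytes it consumed so the caller can
// decide whether to reslice. io.ErrUnexpectedEOF means "incomplete message,
// wait for more data".
func readField(buf []byte) (payload []byte, n int, err error) {
    var lenBytes [4]byte
    m := copy(lenBytes[:], buf)
    n += m
    if m < len(lenBytes) {
        return nil, n, io.ErrUnexpectedEOF
    }
    size := int(binary.BigEndian.Uint32(lenBytes[:]))
    payload = make([]byte, size)
    m = copy(payload, buf[n:])
    n += m
    if m < size {
        return nil, n, io.ErrUnexpectedEOF
    }
    return payload, n, nil
}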