Skip to content

Commit

Permalink
copy: new copy IO implementation
Browse files Browse the repository at this point in the history
Refactor COPY so that all the buffer reading takes place in a separate
implementation of the io.Reader interface. This does two things, it
enables the COPY implementation to efficiently handle small CopyData
frames by eliminating extra buffering and exposes the COPY bytes as
a pure stream of bytes, which makes retry easier. It also cleans up the
COPY code that handles CopyData segments straddling line boundaries, now
we can just let the CSV reader do its thing.

The old implementation would read from a pgwire BufferedReader (copy 1)
into a pgwire "ReadBuffer" (copy 2) and then push those segments into a
bytes.Buffer "buf" (copy 3). The text and binary readers would read
right from buf; the CSV reader has its own buffer, so we would read
lines from buf and write them into the CSV reader's buffer (copy 4).

The new approach does away with all this and the text and binary formats
read directly from a bufio.Reader (copy 1) stacked on the copy.Reader
(no buffering) stacked on the pgwire BufferedReader (copy 2). For CSV
the CSVReader reads directly from the copy.Reader since it has its
own buffer so again only two copies off the wire.

This doesn't seem to affect performance much but it gives the GC a nice
break and sets up a clean solution for cockroachdb#99327.

When encountering a memory usage error we used to try to let the encoder
finish the row, but with the more efficient buffering this started
succeeding where it had always failed before. Now we skip the hail
mary entirely: if we hit the limit we bail and return immediately, which
is more OOM-safe and simpler.

Fixes: cockroachdb#93156
Informs: cockroachdb#99327
Release note: none
Epic: CRDB-25321
  • Loading branch information
cucaroach committed Jun 2, 2023
1 parent db8257a commit 4af728a
Show file tree
Hide file tree
Showing 11 changed files with 323 additions and 277 deletions.
1 change: 1 addition & 0 deletions pkg/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -1690,6 +1690,7 @@ GO_TARGETS = [
"//pkg/sql/contention:contention_test",
"//pkg/sql/contentionpb:contentionpb",
"//pkg/sql/contentionpb:contentionpb_test",
"//pkg/sql/copy:copy",
"//pkg/sql/copy:copy_test",
"//pkg/sql/covering:covering",
"//pkg/sql/covering:covering_test",
Expand Down
1 change: 1 addition & 0 deletions pkg/sql/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ go_library(
"//pkg/sql/contention",
"//pkg/sql/contention/txnidcache",
"//pkg/sql/contentionpb",
"//pkg/sql/copy",
"//pkg/sql/covering",
"//pkg/sql/decodeusername",
"//pkg/sql/delegate",
Expand Down
16 changes: 15 additions & 1 deletion pkg/sql/copy/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
load("//build/bazelutil/unused_checker:unused.bzl", "get_x_data")
load("@io_bazel_rules_go//go:def.bzl", "go_test")
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")

go_test(
name = "copy_test",
Expand Down Expand Up @@ -56,4 +56,18 @@ go_test(
],
)

go_library(
name = "copy",
srcs = ["reader.go"],
importpath = "github.com/cockroachdb/cockroach/pkg/sql/copy",
visibility = ["//visibility:public"],
deps = [
"//pkg/settings",
"//pkg/sql/pgwire/pgcode",
"//pkg/sql/pgwire/pgerror",
"//pkg/sql/pgwire/pgwirebase",
"@com_github_cockroachdb_errors//:errors",
],
)

get_x_data(name = "get_x_data")
2 changes: 1 addition & 1 deletion pkg/sql/copy/copy_in_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package copy
package copy_test

import (
"context"
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/copy/copy_out_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package copy
package copy_test

import (
"bytes"
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/copy/copy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package copy
package copy_test

import (
"bytes"
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/copy/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package copy
package copy_test

import (
"os"
Expand Down
161 changes: 161 additions & 0 deletions pkg/sql/copy/reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
// Copyright 2023 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package copy

import (
"io"

"github.com/cockroachdb/cockroach/pkg/settings"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase"
"github.com/cockroachdb/errors"
)

// Reader is an io.Reader that reads the COPY protocol from the underlying pgwire
// buffer but removes the protocol bits and just returns the raw text. There's
// no buffering, there's a buffer below it on top of the raw connection and
// any buffering required above it (for line reading etc) is either
// implemented in COPY or in the CSV reader.
type Reader struct {
	// We don't use this buffer but we use some of its helper methods
	// (ReadUntypedMsgSize, SlurpBytes, ReadTypedMsg).
	pgwirebase.ReadBuffer
	// pgr is the underlying pgwire reader we're reading from
	pgr pgwirebase.BufferedReader
	// Scratch space for Drain to discard unread bytes into.
	scratch [128]byte
	// remainder is the number of payload bytes of the current CopyData
	// segment that have not yet been consumed off the wire.
	remainder int
	// done is set once CopyDone is seen; all subsequent Reads return io.EOF.
	done bool
}

// Compile-time check that Reader satisfies io.Reader.
var _ io.Reader = &Reader{}

// Init prepares the Reader for use. pgr is the pgwire reader to consume
// from; sv provides access to cluster settings, which are required to
// enforce the maximum message size when decoding frames.
func (c *Reader) Init(pgr pgwirebase.BufferedReader, sv *settings.Values) {
	c.SetOption(pgwirebase.ReadBufferOptionWithClusterSettings(sv))
	c.pgr = pgr
}

// Read implements the io.Reader interface. It returns raw COPY payload
// bytes, transparently consuming the pgwire CopyData framing. A single call
// never crosses a CopyData segment boundary, so we never over-read from the
// wire.
func (c *Reader) Read(p []byte) (int, error) {
	// The CSV reader can eat an EOF and will come back for another read in
	// some cases. Keep giving it EOFs in that case.
	if c.done {
		return 0, io.EOF
	}
	// If we had a short read, finish the current segment first.
	if c.remainder > 0 {
		// We never want to overread from the wire, we might read past COPY data
		// segments so limit p to remainder bytes.
		if c.remainder < len(p) {
			p = p[:c.remainder]
		}
		n, err := c.pgr.Read(p)
		// Per the io.Reader contract, process the n > 0 bytes before
		// considering err: account for them so we stay in sync with the
		// wire instead of silently dropping data already read.
		c.remainder -= n
		if err != nil {
			return n, err
		}
		return n, nil
	}
	// Go to pgwire to get next segment.
	size, err := c.readTypedMessage()
	if err != nil {
		return 0, err
	}
	// We never want to overread from the wire, we might read past COPY data
	// segments so limit p to size bytes.
	if size < len(p) {
		p = p[:size]
	}
	n, err := c.pgr.Read(p)
	// Record any unread tail of this segment (zero when fully consumed),
	// even on error, so a subsequent Read or Drain picks up where we left off.
	c.remainder = size - n
	if err != nil {
		return n, err
	}
	return n, nil
}

// readTypedMessage reads the next pgwire message header and dispatches on
// its type. For CopyData it returns the payload size, leaving the payload
// itself unread on the wire for Read to consume. CopyDone marks the reader
// done and yields io.EOF; CopyFail yields a query-canceled error carrying
// the client's message; Flush/Sync are ignored per the protocol spec; any
// other type is drained off the wire and rejected.
func (c *Reader) readTypedMessage() (size int, err error) {
	b, err := c.pgr.ReadByte()
	if err != nil {
		return 0, err
	}
	typ := pgwirebase.ClientMessageType(b)
	_, size, err = c.ReadUntypedMsgSize(c.pgr)
	if err != nil {
		if pgwirebase.IsMessageTooBigError(err) && typ == pgwirebase.ClientMsgCopyData {
			// Slurp the remaining bytes.
			_, slurpErr := c.SlurpBytes(c.pgr, pgwirebase.GetMessageTooBigSize(err))
			if slurpErr != nil {
				return 0, errors.CombineErrors(err, errors.Wrapf(slurpErr, "error slurping remaining bytes in COPY"))
			}

			// As per the pgwire spec, we must continue reading until we encounter
			// CopyDone or CopyFail. We don't support COPY in the extended
			// protocol, so we don't need to look for Sync messages. See
			// https://www.postgresql.org/docs/13/protocol-flow.html#PROTOCOL-COPY
			for {
				typ, _, slurpErr = c.ReadTypedMsg(c.pgr)
				if typ == pgwirebase.ClientMsgCopyDone || typ == pgwirebase.ClientMsgCopyFail {
					break
				}
				if slurpErr != nil && !pgwirebase.IsMessageTooBigError(slurpErr) {
					return 0, errors.CombineErrors(err, errors.Wrapf(slurpErr, "error slurping remaining bytes in COPY"))
				}

				// NOTE(review): when ReadTypedMsg succeeds (slurpErr == nil)
				// this still calls GetMessageTooBigSize on a nil error —
				// presumably that returns a harmless/zero size; confirm
				// against pgwirebase.
				_, slurpErr = c.SlurpBytes(c.pgr, pgwirebase.GetMessageTooBigSize(slurpErr))
				if slurpErr != nil {
					return 0, errors.CombineErrors(err, errors.Wrapf(slurpErr, "error slurping remaining bytes in COPY"))
				}
			}
		}
		// Surface the original size/read error in all cases.
		return 0, err
	}
	switch typ {
	case pgwirebase.ClientMsgCopyData:
		// Just return size; the payload stays on the wire for Read.
	case pgwirebase.ClientMsgCopyDone:
		c.done = true
		return 0, io.EOF
	case pgwirebase.ClientMsgCopyFail:
		// The CopyFail payload is the client-supplied error message.
		msg := make([]byte, size)
		if _, err := io.ReadFull(c.pgr, msg); err != nil {
			return 0, err
		}
		return 0, pgerror.Newf(pgcode.QueryCanceled, "COPY from stdin failed: %s", string(msg))
	case pgwirebase.ClientMsgFlush, pgwirebase.ClientMsgSync:
		// Spec says to "ignore Flush and Sync messages received during copy-in mode".
	default:
		// In order to gracefully handle bogus protocol, ie back to back copies, we have to
		// slurp these bytes.
		msg := make([]byte, size)
		if _, err := io.ReadFull(c.pgr, msg); err != nil {
			return 0, err
		}
		return 0, pgwirebase.NewUnrecognizedMsgTypeErr(typ)
	}
	return size, nil
}

// Drain discards any unread payload bytes of the current CopyData segment,
// leaving the connection positioned at the next protocol message.
func (c *Reader) Drain() error {
	for c.remainder > 0 {
		scratch := c.scratch[:]
		// Never read past the end of the current segment: bytes beyond
		// c.remainder belong to the next protocol message and must stay on
		// the wire. Reading the full 128-byte scratch could consume them
		// from a buffered reader and desynchronize the protocol stream.
		if c.remainder < len(scratch) {
			scratch = scratch[:c.remainder]
		}
		n, err := c.pgr.Read(scratch)
		if err != nil {
			return err
		}
		c.remainder -= n
	}
	return nil
}
Loading

0 comments on commit 4af728a

Please sign in to comment.