diff --git a/pkg/cli/BUILD.bazel b/pkg/cli/BUILD.bazel
index 7e14780c180f..4ba294554af5 100644
--- a/pkg/cli/BUILD.bazel
+++ b/pkg/cli/BUILD.bazel
@@ -406,6 +406,7 @@ go_test(
         "//pkg/util/log/logconfig",
         "//pkg/util/log/logpb",
         "//pkg/util/protoutil",
+        "//pkg/util/randutil",
         "//pkg/util/stop",
         "//pkg/util/timeutil",
         "//pkg/util/tracing",
diff --git a/pkg/cli/nodelocal.go b/pkg/cli/nodelocal.go
index 612f838dd805..b785efa3b7a0 100644
--- a/pkg/cli/nodelocal.go
+++ b/pkg/cli/nodelocal.go
@@ -11,7 +11,6 @@
 package cli
 
 import (
-	"bytes"
 	"context"
 	"fmt"
 	"io"
@@ -27,7 +26,9 @@ import (
 	"github.com/spf13/cobra"
 )
 
-const chunkSize = 4 * 1024
+// chunkSize is the number of bytes escapingReader requests from its
+// underlying reader per read. It is a variable so that tests can override it.
+var chunkSize = 4 * 1024
 
 var nodeLocalUploadCmd = &cobra.Command{
 	Use:   "upload ",
@@ -72,11 +73,84 @@ func openSourceFile(source string) (io.ReadCloser, error) {
 	return f, nil
 }
 
+// escapingReader is an io.Reader that escapes characters from the
+// underlying reader for processing by the pgwire COPY protocol.
+//
+// Escaped output is produced one chunk at a time so that the input does
+// not have to be held in memory in full.
+//
+// TODO(ssd): Can we replace this with something that uses the binary
+// COPY format?
+type escapingReader struct {
+	r io.Reader
+
+	// readChunk is the scratch space handed to r.Read. It is allocated
+	// lazily on the first call to Read.
+	readChunk []byte
+	// buf holds escaped bytes; buf[off:] has not yet been returned to
+	// the caller.
+	buf []byte
+	off int
+	// err is the first error (including io.EOF) returned by r. Once
+	// set, r is never read again and err is returned as soon as the
+	// pending bytes have been drained.
+	err error
+}
+
+var _ io.Reader = (*escapingReader)(nil)
+
+// copyBufferTo moves as many pending bytes as fit into b. Once the
+// pending bytes are exhausted the buffer is rewound so that its backing
+// array is reused by the next fill instead of being reallocated.
+func (er *escapingReader) copyBufferTo(b []byte) (int, error) {
+	n := copy(b, er.buf[er.off:])
+	er.off += n
+	if er.off == len(er.buf) {
+		er.buf, er.off = er.buf[:0], 0
+	}
+	return n, nil
+}
+
+// Read implements the io.Reader interface.
+func (er *escapingReader) Read(b []byte) (int, error) {
+	if len(b) == 0 {
+		return 0, nil
+	}
+	// If we have anything left in the buffer from last time,
+	// return it now.
+	if er.off < len(er.buf) {
+		return er.copyBufferTo(b)
+	}
+	if er.readChunk == nil {
+		er.readChunk = make([]byte, chunkSize)
+	}
+	for er.err == nil {
+		n, err := er.r.Read(er.readChunk)
+		if n > 0 {
+			er.buf = appendEscapedText(er.buf, er.readChunk[:n])
+		}
+		if err != nil {
+			// A reader may return data together with an error. Remember
+			// the error and deliver the data first.
+			er.err = err
+			break
+		}
+		// Stop once the buffered bytes plus another raw chunk would exceed len(b).
+		if n > 0 && len(er.buf)+chunkSize > len(b) {
+			break
+		}
+	}
+	if er.off < len(er.buf) {
+		return er.copyBufferTo(b)
+	}
+	return 0, er.err
+}
+
 // appendEscapedText escapes the input text for processing by the pgwire COPY
 // protocol. The result is appended to the []byte given by buf.
 // This implementation is copied from lib/pq.
 // https://github.com/lib/pq/blob/8c6de565f76fb5cd40a5c1b8ce583fbc3ba1bd0e/encode.go#L138
-func appendEscapedText(buf []byte, text string) []byte {
+func appendEscapedText(buf []byte, text []byte) []byte {
 	escapeNeeded := false
 	startPos := 0
 	var c byte
@@ -130,21 +204,7 @@ func uploadFile(
 		Path:   destination,
 	}
 	stmt := sql.CopyInFileStmt(nodelocalURL.String(), sql.CrdbInternalName, sql.NodelocalFileUploadTable)
-
-	send := make([]byte, 0)
-	tmp := make([]byte, chunkSize)
-	for {
-		n, err := reader.Read(tmp)
-		if n > 0 {
-			send = appendEscapedText(send, string(tmp[:n]))
-		} else if err == io.EOF {
-			break
-		} else if err != nil {
-			return err
-		}
-	}
-
-	if _, err := ex.CopyFrom(ctx, bytes.NewReader(send), stmt); err != nil {
+	if _, err := ex.CopyFrom(ctx, &escapingReader{r: reader}, stmt); err != nil {
 		return err
 	}
 
diff --git a/pkg/cli/nodelocal_test.go b/pkg/cli/nodelocal_test.go
index 274c9123b77f..b597a2c9c3b1 100644
--- a/pkg/cli/nodelocal_test.go
+++ b/pkg/cli/nodelocal_test.go
@@ -13,12 +13,14 @@ package cli
 import (
 	"bytes"
 	"fmt"
+	"io"
 	"os"
 	"path/filepath"
 	"testing"
 
 	"github.com/cockroachdb/cockroach/pkg/testutils"
 	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
+	"github.com/stretchr/testify/require"
 )
 
 func Example_nodelocal() {
@@ -137,3 +139,86 @@ func createTestFile(name, content string) (string, func()) {
 		_ = os.RemoveAll(tmpDir)
 	}
 }
+
+func TestEscapingReader(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+
+	t.Run("escapes newlines", func(t *testing.T) {
+		er := escapingReader{r: bytes.NewReader([]byte("1\n34"))}
+		buf := make([]byte, 5)
+		n, err := er.Read(buf)
+		require.NoError(t, err)
+		require.Equal(t, 5, n)
+		require.Equal(t, []byte{'1', '\\', 'n', '3', '4'}, buf[:n])
+	})
+	t.Run("escapes carriage returns", func(t *testing.T) {
+		er := escapingReader{r: bytes.NewReader([]byte("1\r34"))}
+		buf := make([]byte, 5)
+		n, err := er.Read(buf)
+		require.NoError(t, err)
+		require.Equal(t, 5, n)
+		require.Equal(t, []byte{'1', '\\', 'r', '3', '4'}, buf[:n])
+	})
+	t.Run("escapes tabs", func(t *testing.T) {
+		er := escapingReader{r: bytes.NewReader([]byte("1\t34"))}
+		buf := make([]byte, 5)
+		n, err := er.Read(buf)
+		require.NoError(t, err)
+		require.Equal(t, 5, n)
+		require.Equal(t, []byte{'1', '\\', 't', '3', '4'}, buf[:n])
+	})
+	t.Run("escapes backslashes", func(t *testing.T) {
+		er := escapingReader{r: bytes.NewReader([]byte("1\\34"))}
+		buf := make([]byte, 5)
+		n, err := er.Read(buf)
+		require.NoError(t, err)
+		require.Equal(t, 5, n)
+		require.Equal(t, []byte{'1', '\\', '\\', '3', '4'}, buf[:n])
+	})
+	t.Run("correctly returns escaped characters that overflow buffer", func(t *testing.T) {
+		er := escapingReader{r: bytes.NewReader([]byte("1\n3"))}
+		buf := make([]byte, 2)
+		n, err := er.Read(buf)
+		require.NoError(t, err)
+		require.Equal(t, 2, n)
+		require.Equal(t, []byte{'1', '\\'}, buf[:n])
+		n, err = er.Read(buf)
+		require.NoError(t, err)
+		require.Equal(t, 2, n)
+		require.Equal(t, []byte{'n', '3'}, buf[:n])
+	})
+	t.Run("correctly returns remainder when buffer is larger", func(t *testing.T) {
+		er := escapingReader{r: bytes.NewReader([]byte("1\n3"))}
+		buf := make([]byte, 2)
+		n, err := er.Read(buf)
+		require.NoError(t, err)
+		require.Equal(t, 2, n)
+		require.Equal(t, []byte{'1', '\\'}, buf[:n])
+		buf = make([]byte, 8)
+		n, err = er.Read(buf)
+		require.NoError(t, err)
+		require.Equal(t, 2, n)
+		require.Equal(t, []byte{'n', '3'}, buf[:n])
+	})
+	t.Run("correctly returns remainder when chunksize is small", func(t *testing.T) {
+		oldChunkSize := chunkSize
+		defer func() { chunkSize = oldChunkSize }()
+		chunkSize = 1
+		er := escapingReader{r: bytes.NewReader([]byte("1\n34"))}
+		buf := make([]byte, 5)
+		n, err := er.Read(buf)
+		require.NoError(t, err)
+		require.Equal(t, 5, n)
+		require.Equal(t, []byte{'1', '\\', 'n', '3', '4'}, buf[:n])
+	})
+	t.Run("correctly returns EOF when underlying reader returns it", func(t *testing.T) {
+		er := escapingReader{r: bytes.NewReader([]byte("1\n34"))}
+		buf := make([]byte, 5)
+		n, err := er.Read(buf)
+		require.NoError(t, err)
+		require.Equal(t, 5, n)
+		require.Equal(t, []byte{'1', '\\', 'n', '3', '4'}, buf[:n])
+		_, err = er.Read(buf)
+		require.ErrorIs(t, err, io.EOF)
+	})
+}
diff --git a/pkg/cli/testutils.go b/pkg/cli/testutils.go
index 8b1d52c3649e..ced60ae33e35 100644
--- a/pkg/cli/testutils.go
+++ b/pkg/cli/testutils.go
@@ -58,7 +58,7 @@ type TestCLI struct {
 
 	// t is the testing.T instance used for this test.
 	// Example_xxx tests may have this set to nil.
-	t *testing.T
+	t testing.TB
 	// logScope binds the lifetime of the log files to this test, when t
 	// is not nil.
 	logScope *log.TestLogScope
@@ -72,7 +72,7 @@ type TestCLI struct {
 
 // TestCLIParams contains parameters used by TestCLI.
 type TestCLIParams struct {
-	T        *testing.T
+	T        testing.TB
 	Insecure bool
 	// NoServer, if true, starts the test without a DB server.
 	NoServer bool
diff --git a/pkg/cli/userfile.go b/pkg/cli/userfile.go
index 153449e5b0d4..170b016def1f 100644
--- a/pkg/cli/userfile.go
+++ b/pkg/cli/userfile.go
@@ -11,7 +11,6 @@
 package cli
 
 import (
-	"bytes"
 	"context"
 	"fmt"
 	"io"
@@ -609,20 +608,7 @@ func uploadUserFile(
 	}
 
 	stmt := sql.CopyInFileStmt(unescapedUserfileURL, sql.CrdbInternalName, sql.UserFileUploadTable)
-	send := make([]byte, 0)
-	tmp := make([]byte, chunkSize)
-	for {
-		n, err := reader.Read(tmp)
-		if n > 0 {
-			send = appendEscapedText(send, string(tmp[:n]))
-		} else if err == io.EOF {
-			break
-		} else if err != nil {
-			return "", err
-		}
-	}
-
-	if _, err := ex.CopyFrom(ctx, bytes.NewReader(send), stmt); err != nil {
+	if _, err := ex.CopyFrom(ctx, &escapingReader{r: reader}, stmt); err != nil {
 		return "", err
 	}
 
diff --git a/pkg/cli/userfiletable_test.go b/pkg/cli/userfiletable_test.go
index 1da9a5b5009e..fe0a3a78eab8 100644
--- a/pkg/cli/userfiletable_test.go
+++ b/pkg/cli/userfiletable_test.go
@@ -28,6 +28,7 @@ import (
 	"github.com/cockroachdb/cockroach/pkg/testutils/sqlutils"
 	"github.com/cockroachdb/cockroach/pkg/util/ioctx"
 	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
+	"github.com/cockroachdb/cockroach/pkg/util/randutil"
 	"github.com/cockroachdb/errors"
 	"github.com/stretchr/testify/require"
 )
@@ -357,13 +358,13 @@ func Example_userfile_upload_recursive() {
 
 func checkUserFileContent(
 	ctx context.Context,
-	t *testing.T,
-	execcCfg interface{},
+	t testing.TB,
+	execCfg interface{},
 	user username.SQLUsername,
 	userfileURI string,
 	expectedContent []byte,
 ) {
-	store, err := execcCfg.(sql.ExecutorConfig).DistSQLSrv.ExternalStorageFromURI(ctx,
+	store, err := execCfg.(sql.ExecutorConfig).DistSQLSrv.ExternalStorageFromURI(ctx,
 		userfileURI, user)
 	require.NoError(t, err)
 	reader, _, err := store.ReadFile(ctx, "", cloud.ReadOptions{NoFileSize: true})
@@ -373,6 +374,30 @@ func checkUserFileContent(
 	require.True(t, bytes.Equal(got, expectedContent))
 }
 
+func BenchmarkUserfileUpload(b *testing.B) {
+	c := NewCLITest(TestCLIParams{T: b})
+	defer c.Cleanup()
+
+	dir, cleanFn := testutils.TempDir(b)
+	defer cleanFn()
+
+	dataSize := 64 << 20
+	rnd, _ := randutil.NewTestRand()
+	content := randutil.RandBytes(rnd, dataSize)
+
+	filePath := filepath.Join(dir, "testfile")
+	err := os.WriteFile(filePath, content, 0666)
+	if err != nil {
+		b.Fatal(err)
+	}
+	b.ResetTimer()
+	b.SetBytes(int64(dataSize))
+	for n := 0; n < b.N; n++ {
+		_, err = c.RunWithCapture(fmt.Sprintf("userfile upload %s %s-%d", filePath, filePath, n))
+		require.NoError(b, err)
+	}
+}
+
 func TestUserFileUploadRecursive(t *testing.T) {
 	defer leaktest.AfterTest(t)()
 