diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel index b44808bc5222..ad1e228b3f5c 100644 --- a/pkg/BUILD.bazel +++ b/pkg/BUILD.bazel @@ -41,6 +41,7 @@ ALL_TESTS = [ "//pkg/ccl/spanconfigccl/spanconfigsqlwatcherccl:spanconfigsqlwatcherccl_test", "//pkg/ccl/sqlproxyccl/denylist:denylist_test", "//pkg/ccl/sqlproxyccl/idle:idle_test", + "//pkg/ccl/sqlproxyccl/interceptor:interceptor_test", "//pkg/ccl/sqlproxyccl/tenant:tenant_test", "//pkg/ccl/sqlproxyccl/throttler:throttler_test", "//pkg/ccl/sqlproxyccl:sqlproxyccl_test", diff --git a/pkg/ccl/sqlproxyccl/interceptor/BUILD.bazel b/pkg/ccl/sqlproxyccl/interceptor/BUILD.bazel new file mode 100644 index 000000000000..d272bf3ebe7d --- /dev/null +++ b/pkg/ccl/sqlproxyccl/interceptor/BUILD.bazel @@ -0,0 +1,26 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "interceptor", + srcs = ["interceptor.go"], + importpath = "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/interceptor", + visibility = ["//visibility:public"], + deps = [ + "//pkg/sql/pgwire/pgwirebase", + "@com_github_cockroachdb_errors//:errors", + "@com_github_jackc_pgproto3_v2//:pgproto3", + ], +) + +go_test( + name = "interceptor_test", + srcs = ["interceptor_test.go"], + embed = [":interceptor"], + deps = [ + "//pkg/sql/pgwire/pgwirebase", + "//pkg/util/leaktest", + "@com_github_cockroachdb_errors//:errors", + "@com_github_jackc_pgproto3_v2//:pgproto3", + "@com_github_stretchr_testify//require", + ], +) diff --git a/pkg/ccl/sqlproxyccl/interceptor/interceptor.go b/pkg/ccl/sqlproxyccl/interceptor/interceptor.go new file mode 100644 index 000000000000..ed7a48aeef55 --- /dev/null +++ b/pkg/ccl/sqlproxyccl/interceptor/interceptor.go @@ -0,0 +1,461 @@ +// Copyright 2022 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package interceptor + +import ( + "encoding/binary" + "io" + + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase" + "github.com/cockroachdb/errors" + "github.com/jackc/pgproto3/v2" +) + +// pgHeaderSizeBytes represents the number of bytes of a pgwire message header +// (i.e. one byte for type, and an int for the body size, inclusive of the +// length itself). +const pgHeaderSizeBytes = 5 + +// ErrSmallBuffer indicates that the requested buffer for the interceptor is +// too small. +var ErrSmallBuffer = errors.New("buffer is too small") + +// ErrInterceptorClosed is the returned error whenever the intercept is closed. +// When this happens, the caller should terminate both dst and src to guarantee +// correctness. +var ErrInterceptorClosed = errors.New("interceptor is closed") + +// ErrProtocolError indicates that the packets are malformed, and are not as +// expected. +var ErrProtocolError = errors.New("protocol error") + +// pgInterceptor provides a convenient way to read and forward Postgres +// messages, while minimizing IO reads and memory allocations. +// +// NOTE: Methods on the interceptor are not thread-safe. +type pgInterceptor struct { + src io.Reader + dst io.Writer + + // buf stores bytes which have been read, but have not been processed yet. + buf []byte + // readPos and writePos indicates the read and write pointers for bytes in + // the buffer buf. + readPos, writePos int + + // closed indicates that the interceptor is closed. This will be set to + // true whenever there's an error within one of the interceptor's operations + // leading to an ambiguity. Once an interceptor is closed, all subsequent + // method calls on the interceptor will return ErrInterceptorClosed. + closed bool +} + +// newPgInterceptor creates a new instance of the interceptor with an internal +// buffer of bufSize bytes. bufSize must be at least the size of a pgwire +// message header. +func newPgInterceptor(src io.Reader, dst io.Writer, bufSize int) (*pgInterceptor, error) { + // The internal buffer must be able to fit the header. + if bufSize < pgHeaderSizeBytes { + return nil, ErrSmallBuffer + } + return &pgInterceptor{ + src: src, + dst: dst, + buf: make([]byte, bufSize), + }, nil +} + +// ensureNextNBytes blocks on IO reads until the buffer has at least n bytes. +func (p *pgInterceptor) ensureNextNBytes(n int) error { + if n < 0 || n > len(p.buf) { + return errors.AssertionFailedf( + "invalid number of bytes %d for buffer size %d", n, len(p.buf)) + } + + // Buffer already has n bytes. + if p.ReadSize() >= n { + return nil + } + + // Not enough empty slots to fit the unread bytes, so re-align bytes. + minReadCount := n - p.ReadSize() + if p.WriteSize() < minReadCount { + p.writePos = copy(p.buf, p.buf[p.readPos:p.writePos]) + p.readPos = 0 + } + + c, err := io.ReadAtLeast(p.src, p.buf[p.writePos:], minReadCount) + p.writePos += c + return err +} + +// PeekMsg returns the header of the current pgwire message without advancing +// the interceptor. On return, err == nil if and only if the entire header can +// be read. Note that size corresponds to the body size, and does not account +// for the size field itself. This will return ErrProtocolError if the packets +// are malformed. +// +// If the interceptor is closed, PeekMsg returns ErrInterceptorClosed. +func (p *pgInterceptor) PeekMsg() (typ byte, size int, err error) { + if p.closed { + return 0, 0, ErrInterceptorClosed + } + + if err := p.ensureNextNBytes(pgHeaderSizeBytes); err != nil { + // Possibly due to a timeout or context cancellation. + return 0, 0, err + } + + typ = p.buf[p.readPos] + size = int(binary.BigEndian.Uint32(p.buf[p.readPos+1:])) - 4 + + // Size has to be at least itself based on pgwire's protocol. + if size < 0 { + return 0, 0, ErrProtocolError + } + + return typ, size, nil +} + +// WriteMsg writes the given bytes to the writer dst. If err != nil and a Write +// was attempted, the interceptor will be closed. +// +// If the interceptor is closed, WriteMsg returns ErrInterceptorClosed. +func (p *pgInterceptor) WriteMsg(data []byte) (n int, err error) { + if p.closed { + return 0, ErrInterceptorClosed + } + defer func() { + // Close the interceptor if there was an error. Theoretically, we only + // need to close here if n > 0, but for consistency with the other + // methods, we will do that here too. + if err != nil { + p.Close() + } + }() + return p.dst.Write(data) +} + +// ReadMsg returns the current pgwire message data in bytes and its type. It +// also advances the interceptor to the next message. On return, the type and +// body fields are valid if and only if err == nil. If err != nil and a Read +// was attempted because the buffer did not have enough bytes, the interceptor +// will be closed. +// +// The interceptor retains ownership of all the memory returned by ReadMsg; the +// caller is allowed to hold on to this memory *until* the next moment other +// methods on the interceptor are called. The data will only be valid until then +// as well. +// +// If the interceptor is closed, ReadMsg returns ErrInterceptorClosed. +func (p *pgInterceptor) ReadMsg() (typ byte, body []byte, err error) { + // Technically this is redundant since PeekMsg will do the same thing, but + // we do so here for clarity. + if p.closed { + return 0, nil, ErrInterceptorClosed + } + + // Read header of the current message for body size. + typ, size, err := p.PeekMsg() + if err != nil { + return 0, nil, err + } + p.readPos += pgHeaderSizeBytes + + // Can the entire message fit into the buffer? + if size <= len(p.buf) { + if err := p.ensureNextNBytes(size); err != nil { + // Possibly due to a timeout or context cancellation. + return 0, nil, err + } + + // Return a slice to the internal buffer to avoid an allocation here. + retBuf := p.buf[p.readPos : p.readPos+size] + p.readPos += size + return typ, retBuf, nil + } + + // Message cannot fit, so we will have to allocate. + body = make([]byte, size) + + // Copy bytes which have already been read. + toCopy := size + if p.ReadSize() < size { + toCopy = p.ReadSize() + } + n := copy(body, p.buf[p.readPos:p.readPos+toCopy]) + p.readPos += n // toCopy has to be the same as n here. + + defer func() { + // Close the interceptor because we read the data (both buffered and + // possibly newer ones) into body, which is larger than the buffer's + // size, and there's no easy way to recover. We could technically fix + // some of the situations, especially when no bytes were read, but at + // this point, it's likely that the one end of the interceptor is + // already gone, or the proxy is shutting down, so there's no point + // trying to save a disconnect. + if err != nil { + p.Close() + } + }() + + // Read more bytes. + if _, err := io.ReadFull(p.src, body[n:]); err != nil { + return 0, nil, err + } + + return typ, body, nil +} + +// ForwardMsg sends the current pgwire message to the destination, and advances +// the interceptor to the next message. On return, n == pgwire message size if +// and only if err == nil. If err != nil and a Write was attempted, the +// interceptor will be closed. +// +// If the interceptor is closed, ForwardMsg returns ErrInterceptorClosed. +func (p *pgInterceptor) ForwardMsg() (n int, err error) { + // Technically this is redundant since PeekMsg will do the same thing, but + // we do so here for clarity. + if p.closed { + return 0, ErrInterceptorClosed + } + + // Retrieve header of the current message for body size. + _, size, err := p.PeekMsg() + if err != nil { + return 0, err + } + + // Handle overflows as current message may not fit in the current buffer. + startPos := p.readPos + endPos := startPos + pgHeaderSizeBytes + size + remainingBytes := 0 + if endPos > p.writePos { + remainingBytes = endPos - p.writePos + endPos = p.writePos + } + p.readPos = endPos + + defer func() { + // State may be invalid depending on whether bytes have been written. + // To reduce complexity, we'll just close the interceptor, and the + // caller should just terminate both ends. + // + // If src has been closed, the dst state may be invalid. If dst has been + // closed, buffered bytes no longer represent the protocol correctly + // even if we slurped the remaining bytes for the current message. + if err != nil { + p.Close() + } + }() + + // Forward the message to the destination. + n, err = p.dst.Write(p.buf[startPos:endPos]) + if err != nil { + return n, err + } + // n shouldn't be larger than the size of the buffer unless the + // implementation of Write for dst is incorrect. This shouldn't be the case + // if we're using a TCP connection here. + if n < endPos-startPos { + return n, io.ErrShortWrite + } + + // Message was partially buffered, so copy the remaining. + if remainingBytes > 0 { + m, err := io.CopyN(p.dst, p.src, int64(remainingBytes)) + n += int(m) + if err != nil { + return n, err + } + // n shouldn't be larger than remainingBytes unless the internal Read + // and Write calls for either of src or dst are incorrect. This + // shouldn't be the case if we're using a TCP connection here. + if int(m) < remainingBytes { + return n, io.ErrShortWrite + } + } + return n, nil +} + +// ReadSize returns the number of bytes read by the interceptor. If the +// interceptor is closed, this will return 0. +func (p *pgInterceptor) ReadSize() int { + if p.closed { + return 0 + } + return p.writePos - p.readPos +} + +// WriteSize returns the remaining number of bytes that could fit into the +// internal buffer before needing to be re-aligned. If the interceptor is +// closed, this will return 0. +func (p *pgInterceptor) WriteSize() int { + if p.closed { + return 0 + } + return len(p.buf) - p.writePos +} + +// Close closes the interceptor, and prevents further operations on it. +func (p *pgInterceptor) Close() { + p.closed = true +} + +var errInvalidRead = errors.New("invalid read in chunkReader") + +var _ pgproto3.ChunkReader = &chunkReader{} + +// chunkReader is a wrapper on a single Postgres message, and is meant to be +// used with the Receive method on pgproto3.{NewFrontend, NewBackend}. +type chunkReader struct { + header [5]byte + body []byte + pos int +} + +func newChunkReader(typ byte, body []byte) pgproto3.ChunkReader { + cr := &chunkReader{body: body} + cr.header[0] = typ + binary.BigEndian.PutUint32(cr.header[1:], uint32(len(body)+4)) + return cr +} + +// Next implements the pgproto3.ChunkReader interface. This implements a tiny +// state machine where Next has to be called in the following order: n=5 to +// read headers, followed by n=len(body) to read the message body. An io.EOF +// will be returned once the entire message has been read. If the state machine +// protocol isn't obeyed, an errInvalidRead will be returned. +func (cr *chunkReader) Next(n int) (buf []byte, err error) { + switch cr.pos { + case 0: + // Only the headers can be requested. + if n != 5 { + return nil, errInvalidRead + } + cr.pos += n + return cr.header[:], nil + case 5: + // After header is read, only the body can be requested. + if n != len(cr.body) { + return nil, errInvalidRead + } + cr.pos += n + return cr.body, nil + default: + return nil, io.EOF + } +} + +// FrontendInterceptor is a client interceptor for the Postgres frontend +// protocol. +type FrontendInterceptor struct { + p *pgInterceptor +} + +// NewFrontendInterceptor creates a FrontendInterceptor. bufSize must be at +// least the size of a pgwire message header. +func NewFrontendInterceptor( + src io.Reader, dst io.Writer, bufSize int, +) (*FrontendInterceptor, error) { + pgi, err := newPgInterceptor(src, dst, bufSize) + if err != nil { + return nil, err + } + return &FrontendInterceptor{p: pgi}, nil +} + +// PeekMsg returns the header of the current pgwire message without advancing +// the interceptor. See pgInterceptor.PeekMsg for more information. +func (fi *FrontendInterceptor) PeekMsg() (typ pgwirebase.ServerMessageType, size int, err error) { + byteType, size, err := fi.p.PeekMsg() + return pgwirebase.ServerMessageType(byteType), size, err +} + +// WriteMsg writes the given bytes to the writer dst. See pgInterceptor.WriteMsg +// for more information. +func (fi *FrontendInterceptor) WriteMsg(data pgproto3.BackendMessage) (n int, err error) { + return fi.p.WriteMsg(data.Encode(nil)) +} + +// ReadMsg decodes the current pgwire message and returns a BackendMessage. +// This also advances the interceptor to the next message. See +// pgInterceptor.ReadMsg for more information. +func (fi *FrontendInterceptor) ReadMsg() (msg pgproto3.BackendMessage, err error) { + typ, body, err := fi.p.ReadMsg() + if err != nil { + return nil, err + } + // errPanicWriter is used here because Receive must not Write. + return pgproto3.NewFrontend(newChunkReader(typ, body), &errPanicWriter{}).Receive() +} + +// ForwardMsg sends the current pgwire message to the destination without any +// decoding, and advances the interceptor to the next message. See +// pgInterceptor.ForwardMsg for more information. +func (fi *FrontendInterceptor) ForwardMsg() (n int, err error) { + return fi.p.ForwardMsg() +} + +// BackendInterceptor is a server interceptor for the Postgres backend protocol. +type BackendInterceptor struct { + p *pgInterceptor +} + +// NewBackendInterceptor creates a BackendInterceptor. bufSize must be at least +// the size of a pgwire message header. +func NewBackendInterceptor(src io.Reader, dst io.Writer, bufSize int) (*BackendInterceptor, error) { + pgi, err := newPgInterceptor(src, dst, bufSize) + if err != nil { + return nil, err + } + return &BackendInterceptor{p: pgi}, nil +} + +// PeekMsg returns the header of the current pgwire message without advancing +// the interceptor. See pgInterceptor.PeekMsg for more information. +func (bi *BackendInterceptor) PeekMsg() (typ pgwirebase.ClientMessageType, size int, err error) { + byteType, size, err := bi.p.PeekMsg() + return pgwirebase.ClientMessageType(byteType), size, err +} + +// WriteMsg writes the given bytes to the writer dst. See pgInterceptor.WriteMsg +// for more information. +func (bi *BackendInterceptor) WriteMsg(data pgproto3.FrontendMessage) (n int, err error) { + return bi.p.WriteMsg(data.Encode(nil)) +} + +// ReadMsg decodes the current pgwire message and returns a FrontendMessage. +// This also advances the interceptor to the next message. See +// pgInterceptor.ReadMsg for more information. +func (bi *BackendInterceptor) ReadMsg() (msg pgproto3.FrontendMessage, err error) { + typ, body, err := bi.p.ReadMsg() + if err != nil { + return nil, err + } + // errPanicWriter is used here because Receive must not Write. + return pgproto3.NewBackend(newChunkReader(typ, body), &errPanicWriter{}).Receive() +} + +// ForwardMsg sends the current pgwire message to the destination without any +// decoding, and advances the interceptor to the next message. See +// pgInterceptor.ForwardMsg for more information. +func (bi *BackendInterceptor) ForwardMsg() (n int, err error) { + return bi.p.ForwardMsg() +} + +var _ io.Writer = &errPanicWriter{} + +// errPanicWriter is an io.Writer that panics whenever a Write call is made. +type errPanicWriter struct{} + +// Write implements the io.Writer interface. +func (w *errPanicWriter) Write(p []byte) (int, error) { + panic("unexpected Write call") +} diff --git a/pkg/ccl/sqlproxyccl/interceptor/interceptor_test.go b/pkg/ccl/sqlproxyccl/interceptor/interceptor_test.go new file mode 100644 index 000000000000..83a71d99f9b6 --- /dev/null +++ b/pkg/ccl/sqlproxyccl/interceptor/interceptor_test.go @@ -0,0 +1,847 @@ +// Copyright 2022 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package interceptor + +import ( + "bytes" + "io" + "testing" + "testing/iotest" + + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/errors" + "github.com/jackc/pgproto3/v2" + "github.com/stretchr/testify/require" +) + +func TestNewPgInterceptor(t *testing.T) { + defer leaktest.AfterTest(t)() + + reader, writer := io.Pipe() + + // Negative buffer size. + pgi, err := newPgInterceptor(reader, writer, -1) + require.EqualError(t, err, ErrSmallBuffer.Error()) + require.Nil(t, pgi) + + // Small buffer size. + pgi, err = newPgInterceptor(reader, writer, pgHeaderSizeBytes-1) + require.EqualError(t, err, ErrSmallBuffer.Error()) + require.Nil(t, pgi) + + // Buffer that fits the header exactly. + pgi, err = newPgInterceptor(reader, writer, pgHeaderSizeBytes) + require.NoError(t, err) + require.NotNil(t, pgi) + require.Len(t, pgi.buf, pgHeaderSizeBytes) + + // Normal buffer size. + pgi, err = newPgInterceptor(reader, writer, 1024) + require.NoError(t, err) + require.NotNil(t, pgi) + require.Len(t, pgi.buf, 1024) + require.Equal(t, reader, pgi.src) + require.Equal(t, writer, pgi.dst) +} + +func TestPGInterceptor_ensureNextNBytes(t *testing.T) { + defer leaktest.AfterTest(t)() + + t.Run("invalid n", func(t *testing.T) { + pgi, err := newPgInterceptor(nil /* src */, nil /* dst */, 8) + require.NoError(t, err) + + require.EqualError(t, pgi.ensureNextNBytes(-1), + "invalid number of bytes -1 for buffer size 8") + require.EqualError(t, pgi.ensureNextNBytes(9), + "invalid number of bytes 9 for buffer size 8") + }) + + t.Run("buffer already has n bytes", func(t *testing.T) { + buf := bytes.NewBufferString("foobarbaz") + + pgi, err := newPgInterceptor(iotest.OneByteReader(buf), nil /* dst */, 8) + require.NoError(t, err) + + // Read "foo" into buffer". + require.NoError(t, pgi.ensureNextNBytes(3)) + + // These should not read anything since we expect the buffer to + // have three bytes. + require.NoError(t, pgi.ensureNextNBytes(3)) + require.Equal(t, 6, buf.Len()) + require.NoError(t, pgi.ensureNextNBytes(0)) + require.Equal(t, 6, buf.Len()) + require.NoError(t, pgi.ensureNextNBytes(1)) + require.Equal(t, 6, buf.Len()) + + // Verify that buf actually has "foo". + require.Equal(t, "foo", string(pgi.buf[pgi.readPos:pgi.writePos])) + }) + + t.Run("bytes are realigned", func(t *testing.T) { + buf := bytes.NewBufferString("foobarbazcar") + + pgi, err := newPgInterceptor(iotest.OneByteReader(buf), nil /* dst */, 9) + require.NoError(t, err) + + // Read "foobarb" into buffer. + require.NoError(t, pgi.ensureNextNBytes(7)) + + // Assume "foobar" is read. + pgi.readPos += 6 + + // Now ensure that we have 6 bytes. + require.NoError(t, pgi.ensureNextNBytes(6)) + require.Equal(t, 0, buf.Len()) + + // Verify that buf has "bazcar". + require.Equal(t, "bazcar", string(pgi.buf[pgi.readPos:pgi.writePos])) + }) + + t.Run("bytes are read greedily", func(t *testing.T) { + // This tests that we read as much as we can into the internal buffer + // if there was a Read call. + buf := bytes.NewBufferString("foobarbaz") + + pgi, err := newPgInterceptor(buf, nil /* dst */, 10) + require.NoError(t, err) + + // Request for only 1 byte. + require.NoError(t, pgi.ensureNextNBytes(1)) + + // Verify that buf has "foobarbaz". + require.Equal(t, "foobarbaz", string(pgi.buf[pgi.readPos:pgi.writePos])) + + // Should be a no-op. + _, err = buf.WriteString("car") + require.NoError(t, err) + require.NoError(t, pgi.ensureNextNBytes(9)) + require.Equal(t, 3, buf.Len()) + require.Equal(t, "foobarbaz", string(pgi.buf[pgi.readPos:pgi.writePos])) + }) +} + +func TestPGInterceptor_PeekMsg(t *testing.T) { + defer leaktest.AfterTest(t)() + + t.Run("interceptor is closed", func(t *testing.T) { + pgi, err := newPgInterceptor(nil /* src */, nil /* dst */, 10) + require.NoError(t, err) + pgi.Close() + + typ, size, err := pgi.PeekMsg() + require.EqualError(t, err, ErrInterceptorClosed.Error()) + require.Equal(t, byte(0), typ) + require.Equal(t, 0, size) + }) + + t.Run("read error", func(t *testing.T) { + r := iotest.ErrReader(errors.New("read error")) + + pgi, err := newPgInterceptor(r, nil /* dst */, 10) + require.NoError(t, err) + + typ, size, err := pgi.PeekMsg() + require.EqualError(t, err, "read error") + require.Equal(t, byte(0), typ) + require.Equal(t, 0, size) + }) + + t.Run("protocol error", func(t *testing.T) { + data := make([]byte, 10) + buf := new(bytes.Buffer) + _, err := buf.Write(data) + require.NoError(t, err) + + pgi, err := newPgInterceptor(buf, nil /* dst */, 10) + require.NoError(t, err) + + typ, size, err := pgi.PeekMsg() + require.EqualError(t, err, ErrProtocolError.Error()) + require.Equal(t, byte(0), typ) + require.Equal(t, 0, size) + }) + + t.Run("successful", func(t *testing.T) { + buf := new(bytes.Buffer) + _, err := buf.Write((&pgproto3.Query{String: "SELECT 1"}).Encode(nil)) + require.NoError(t, err) + + pgi, err := newPgInterceptor(buf, nil /* dst */, 10) + require.NoError(t, err) + + typ, size, err := pgi.PeekMsg() + require.NoError(t, err) + require.Equal(t, byte(pgwirebase.ClientMsgSimpleQuery), typ) + require.Equal(t, 9, size) + require.Equal(t, 4, buf.Len()) + + // Invoking Peek should not advance the interceptor. + typ, size, err = pgi.PeekMsg() + require.NoError(t, err) + require.Equal(t, byte(pgwirebase.ClientMsgSimpleQuery), typ) + require.Equal(t, 9, size) + require.Equal(t, 4, buf.Len()) + }) +} + +func TestPGInterceptor_WriteMsg(t *testing.T) { + defer leaktest.AfterTest(t)() + + t.Run("interceptor is closed", func(t *testing.T) { + pgi, err := newPgInterceptor(nil /* src */, nil /* dst */, 10) + require.NoError(t, err) + pgi.Close() + + n, err := pgi.WriteMsg([]byte{}) + require.EqualError(t, err, ErrInterceptorClosed.Error()) + require.Equal(t, 0, n) + }) + + t.Run("write error", func(t *testing.T) { + pgi, err := newPgInterceptor(nil /* src */, &errReadWriter{w: io.Discard}, 10) + require.NoError(t, err) + + n, err := pgi.WriteMsg([]byte{}) + require.EqualError(t, err, io.ErrClosedPipe.Error()) + require.Equal(t, 0, n) + require.True(t, pgi.closed) + }) + + t.Run("successful", func(t *testing.T) { + buf := new(bytes.Buffer) + pgi, err := newPgInterceptor(nil /* src */, buf, 10) + require.NoError(t, err) + + n, err := pgi.WriteMsg([]byte("hello")) + require.NoError(t, err) + require.Equal(t, 5, n) + require.False(t, pgi.closed) + require.Equal(t, "hello", buf.String()) + }) +} + +func TestPGInterceptor_ReadMsg(t *testing.T) { + defer leaktest.AfterTest(t)() + + t.Run("interceptor is closed", func(t *testing.T) { + pgi, err := newPgInterceptor(nil /* src */, nil /* dst */, 10) + require.NoError(t, err) + pgi.Close() + + typ, body, err := pgi.ReadMsg() + require.EqualError(t, err, ErrInterceptorClosed.Error()) + require.Equal(t, byte(0), typ) + require.Nil(t, body) + }) + + q := (&pgproto3.Query{String: "SELECT 1"}).Encode(nil) + + buildSrc := func(t *testing.T, count int) *bytes.Buffer { + t.Helper() + src := new(bytes.Buffer) + for i := 0; i < count; i++ { + // Alternate between SELECT 1 and 2 to ensure correctness. + if i%2 == 0 { + q[12] = '1' + } else { + q[12] = '2' + } + _, err := src.Write(q) + require.NoError(t, err) + } + return src + } + + t.Run("message fits", func(t *testing.T) { + const count = 101 // Inclusive of warm-up run in AllocsPerRun. + + buf := buildSrc(t, count) + + // Set buffer's size to be a multiple of the message so that we'll + // always hit the case where the message fits. + pgi, err := newPgInterceptor(buf, nil /* dst */, len(q)*3) + require.NoError(t, err) + + c := 0 + n := testing.AllocsPerRun(count-1, func() { + typ, body, err := pgi.ReadMsg() + require.NoError(t, err) + require.Equal(t, byte(pgwirebase.ClientMsgSimpleQuery), typ) + + expectedStr := "SELECT 1\x00" + if c%2 == 1 { + expectedStr = "SELECT 2\x00" + } + + // Using require.Equal here will result in 2 allocs. + if string(body) != expectedStr { + t.Fatalf(`expected %q, got: %q`, expectedStr, string(body)) + } + c++ + }) + require.Equal(t, float64(0), n, "should not allocate") + require.Equal(t, 0, buf.Len()) + }) + + t.Run("message overflows", func(t *testing.T) { + const count = 101 // Inclusive of warm-up run in AllocsPerRun. + + buf := buildSrc(t, count) + + // Set the buffer to be large enough to fit more bytes than the header, + // but not the entire message. + pgi, err := newPgInterceptor(buf, nil /* dst */, 7) + require.NoError(t, err) + + c := 0 + n := testing.AllocsPerRun(count-1, func() { + typ, body, err := pgi.ReadMsg() + require.NoError(t, err) + require.Equal(t, byte(pgwirebase.ClientMsgSimpleQuery), typ) + + expectedStr := "SELECT 1\x00" + if c%2 == 1 { + expectedStr = "SELECT 2\x00" + } + + // Using require.Equal here will result in 2 allocs. + if string(body) != expectedStr { + t.Fatalf(`expected %q, got: %q`, expectedStr, string(body)) + } + c++ + }) + // Ensure that we only have 1 allocation. We could technically improve + // this by ensuring that one pool of memory is used to reduce the number + // of allocations, but ReadMsg is only called during a transfer session, + // so there's very little benefit to optimizing for that. + require.Equal(t, float64(1), n) + require.Equal(t, 0, buf.Len()) + }) + + t.Run("read error after allocate", func(t *testing.T) { + q := (&pgproto3.Query{String: "SELECT 1"}).Encode(nil) + buf := new(bytes.Buffer) + _, err := buf.Write(q) + require.NoError(t, err) + + src := &errReadWriter{r: buf, count: 2} + pgi, err := newPgInterceptor(src, nil /* dst */, 6) + require.NoError(t, err) + + typ, body, err := pgi.ReadMsg() + require.EqualError(t, err, io.ErrClosedPipe.Error()) + require.Equal(t, byte(0), typ) + require.Nil(t, body) + + // Ensure that interceptor is closed. + require.True(t, pgi.closed) + require.Equal(t, 8, buf.Len()) + }) +} + +func TestPGInterceptor_ForwardMsg(t *testing.T) { + defer leaktest.AfterTest(t)() + + t.Run("interceptor is closed", func(t *testing.T) { + pgi, err := newPgInterceptor(nil /* src */, nil /* dst */, 10) + require.NoError(t, err) + pgi.Close() + + n, err := pgi.ForwardMsg() + require.EqualError(t, err, ErrInterceptorClosed.Error()) + require.Equal(t, 0, n) + }) + + q := (&pgproto3.Query{String: "SELECT 1"}).Encode(nil) + + buildSrc := func(t *testing.T, count int) *bytes.Buffer { + t.Helper() + src := new(bytes.Buffer) + for i := 0; i < count; i++ { + // Alternate between SELECT 1 and 2 to ensure correctness. + if i%2 == 0 { + q[12] = '1' + } else { + q[12] = '2' + } + _, err := src.Write(q) + require.NoError(t, err) + } + return src + } + + validateDst := func(t *testing.T, dst io.Reader, count int) { + t.Helper() + backend := pgproto3.NewBackend(pgproto3.NewChunkReader(dst), nil /* w */) + for i := 0; i < count; i++ { + msg, err := backend.Receive() + require.NoError(t, err) + q := msg.(*pgproto3.Query) + + expectedStr := "SELECT 1" + if i%2 == 1 { + expectedStr = "SELECT 2" + } + require.Equal(t, expectedStr, q.String) + } + } + + t.Run("message fits", func(t *testing.T) { + const count = 101 // Inclusive of warm-up run in AllocsPerRun. + + src := buildSrc(t, count) + dst := new(bytes.Buffer) + + // Set buffer's size to be a multiple of the message so that we'll + // always hit the case where the message fits. + pgi, err := newPgInterceptor(src, dst, len(q)*3) + require.NoError(t, err) + + // Forward all the messages, and ensure 0 allocations. + n := testing.AllocsPerRun(count-1, func() { + n, err := pgi.ForwardMsg() + require.NoError(t, err) + require.Equal(t, 14, n) + }) + require.Equal(t, float64(0), n, "should not allocate") + require.Equal(t, 0, src.Len()) + + // Validate messages. + validateDst(t, dst, count) + require.Equal(t, 0, dst.Len()) + }) + + t.Run("message overflows", func(t *testing.T) { + const count = 151 // Inclusive of warm-up run in AllocsPerRun. + + src := buildSrc(t, count) + dst := new(bytes.Buffer) + + // Set the buffer to be large enough to fit more bytes than the header, + // but not the entire message. + pgi, err := newPgInterceptor(src, dst, 7) + require.NoError(t, err) + + n := testing.AllocsPerRun(count-1, func() { + n, err := pgi.ForwardMsg() + require.NoError(t, err) + require.Equal(t, 14, n) + }) + // NOTE: This allocation is benign, and is due to the fact that io.CopyN + // allocates an internal buffer in copyBuffer. This wouldn't happen if + // a TCP connection is used as the destination since there's a fast-path + // that prevents that. + // + // See: https://cs.opensource.google/go/go/+/refs/tags/go1.17.6:src/io/io.go;l=402-410;drc=refs%2Ftags%2Fgo1.17.6 + require.Equal(t, float64(1), n) + require.Equal(t, 0, src.Len()) + + // Validate messages. + validateDst(t, dst, count) + require.Equal(t, 0, dst.Len()) + }) + + t.Run("write error", func(t *testing.T) { + q := (&pgproto3.Query{String: "SELECT 1"}).Encode(nil) + src := new(bytes.Buffer) + _, err := src.Write(q) + require.NoError(t, err) + dst := new(bytes.Buffer) + + pgi, err := newPgInterceptor(src, &errReadWriter{w: dst, count: 2}, 6) + require.NoError(t, err) + + n, err := pgi.ForwardMsg() + require.EqualError(t, err, io.ErrClosedPipe.Error()) + require.Equal(t, 6, n) + + // Ensure that interceptor is closed. + require.True(t, pgi.closed) + require.Equal(t, 6, dst.Len()) + }) +} + +func TestPGInterceptor_ReadSize(t *testing.T) { + defer leaktest.AfterTest(t)() + + t.Run("interceptor is closed", func(t *testing.T) { + buf := bytes.NewBufferString("foobarbaz") + + pgi, err := newPgInterceptor(buf, nil /* dst */, 9) + require.NoError(t, err) + require.NoError(t, pgi.ensureNextNBytes(1)) + + require.Equal(t, 9, pgi.ReadSize()) + pgi.Close() + require.Equal(t, 0, pgi.ReadSize()) + }) + + t.Run("valid", func(t *testing.T) { + buf := bytes.NewBufferString("foobarbazz") + pgi, err := newPgInterceptor(iotest.OneByteReader(buf), nil /* dst */, 10) + require.NoError(t, err) + + // No reads to internal buffer. + require.Equal(t, 0, pgi.ReadSize()) + + // Attempt reads to buffer. + require.NoError(t, pgi.ensureNextNBytes(3)) + require.Equal(t, 3, pgi.ReadSize()) + + // Read until buffer is full. + require.NoError(t, pgi.ensureNextNBytes(10)) + require.Equal(t, 10, pgi.ReadSize()) + }) +} + +func TestPGInterceptor_WriteSize(t *testing.T) { + defer leaktest.AfterTest(t)() + + t.Run("interceptor is closed", func(t *testing.T) { + pgi, err := newPgInterceptor(nil /* src */, nil /* dst */, 9) + require.NoError(t, err) + + require.Equal(t, 9, pgi.WriteSize()) + pgi.Close() + require.Equal(t, 0, pgi.WriteSize()) + }) + + t.Run("valid", func(t *testing.T) { + buf := bytes.NewBufferString("foobarbazz") + pgi, err := newPgInterceptor(iotest.OneByteReader(buf), nil /* dst */, 10) + require.NoError(t, err) + + // No writes to internal buffer. + require.Equal(t, 10, pgi.WriteSize()) + + // Attempt writes to buffer. + require.NoError(t, pgi.ensureNextNBytes(3)) + require.Equal(t, 7, pgi.WriteSize()) + + // Attempt more writes to buffer until full. + require.NoError(t, pgi.ensureNextNBytes(10)) + require.Equal(t, 0, pgi.WriteSize()) + }) +} + +func TestPGInterceptor_Close(t *testing.T) { + defer leaktest.AfterTest(t)() + pgi, err := newPgInterceptor(nil /* src */, nil /* dst */, 10) + require.NoError(t, err) + require.False(t, pgi.closed) + pgi.Close() + require.True(t, pgi.closed) +} + +// TestFrontendInterceptor tests the FrontendInterceptor. Note that the tests +// here are shallow. For detailed ones, see the tests for the internal +// interceptor. +func TestFrontendInterceptor(t *testing.T) { + defer leaktest.AfterTest(t)() + + q := (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil) + + t.Run("bufSize too small", func(t *testing.T) { + fi, err := NewFrontendInterceptor(nil /* src */, nil /* dst */, 1) + require.Error(t, err) + require.Nil(t, fi) + }) + + t.Run("PeekMsg returns the right message type", func(t *testing.T) { + src := new(bytes.Buffer) + _, err := src.Write(q) + require.NoError(t, err) + + fi, err := NewFrontendInterceptor(src, nil /* dst */, 16) + require.NoError(t, err) + require.NotNil(t, fi) + + typ, size, err := fi.PeekMsg() + require.NoError(t, err) + require.Equal(t, pgwirebase.ServerMsgReady, typ) + require.Equal(t, 1, size) + }) + + t.Run("WriteMsg writes data to dst", func(t *testing.T) { + dst := new(bytes.Buffer) + fi, err := NewFrontendInterceptor(nil /* src */, dst, 10) + require.NoError(t, err) + require.NotNil(t, fi) + + // This is a frontend interceptor, so writing goes to the client. + n, err := fi.WriteMsg(&pgproto3.ReadyForQuery{TxStatus: 'I'}) + require.NoError(t, err) + require.Equal(t, 6, n) + require.Equal(t, 6, dst.Len()) + }) + + t.Run("ReadMsg decodes the message correctly", func(t *testing.T) { + src := new(bytes.Buffer) + _, err := src.Write(q) + require.NoError(t, err) + + fi, err := NewFrontendInterceptor(src, nil /* dst */, 16) + require.NoError(t, err) + require.NotNil(t, fi) + + msg, err := fi.ReadMsg() + require.NoError(t, err) + rmsg, ok := msg.(*pgproto3.ReadyForQuery) + require.True(t, ok) + require.Equal(t, byte('I'), rmsg.TxStatus) + }) + + t.Run("ForwardMsg forwards data to dst", func(t *testing.T) { + src := new(bytes.Buffer) + _, err := src.Write(q) + require.NoError(t, err) + dst := new(bytes.Buffer) + + fi, err := NewFrontendInterceptor(src, dst, 16) + require.NoError(t, err) + require.NotNil(t, fi) + + n, err := fi.ForwardMsg() + require.NoError(t, err) + require.Equal(t, 6, n) + require.Equal(t, 6, dst.Len()) + }) +} + +// TestBackendInterceptor tests the BackendInterceptor. Note that the tests +// here are shallow. For detailed ones, see the tests for the internal +// interceptor. +func TestBackendInterceptor(t *testing.T) { + defer leaktest.AfterTest(t)() + + q := (&pgproto3.Query{String: "SELECT 1"}).Encode(nil) + + t.Run("bufSize too small", func(t *testing.T) { + bi, err := NewBackendInterceptor(nil /* src */, nil /* dst */, 1) + require.Error(t, err) + require.Nil(t, bi) + }) + + t.Run("PeekMsg returns the right message type", func(t *testing.T) { + src := new(bytes.Buffer) + _, err := src.Write(q) + require.NoError(t, err) + + bi, err := NewBackendInterceptor(src, nil /* dst */, 16) + require.NoError(t, err) + require.NotNil(t, bi) + + typ, size, err := bi.PeekMsg() + require.NoError(t, err) + require.Equal(t, pgwirebase.ClientMsgSimpleQuery, typ) + require.Equal(t, 9, size) + }) + + t.Run("WriteMsg writes data to dst", func(t *testing.T) { + dst := new(bytes.Buffer) + bi, err := NewBackendInterceptor(nil /* src */, dst, 10) + require.NoError(t, err) + require.NotNil(t, bi) + + // This is a backend interceptor, so writing goes to the server. + n, err := bi.WriteMsg(&pgproto3.Query{String: "SELECT 1"}) + require.NoError(t, err) + require.Equal(t, 14, n) + require.Equal(t, 14, dst.Len()) + }) + + t.Run("ReadMsg decodes the message correctly", func(t *testing.T) { + src := new(bytes.Buffer) + _, err := src.Write(q) + require.NoError(t, err) + + bi, err := NewBackendInterceptor(src, nil /* dst */, 16) + require.NoError(t, err) + require.NotNil(t, bi) + + msg, err := bi.ReadMsg() + require.NoError(t, err) + rmsg, ok := msg.(*pgproto3.Query) + require.True(t, ok) + require.Equal(t, "SELECT 1", rmsg.String) + }) + + t.Run("ForwardMsg forwards data to dst", func(t *testing.T) { + src := new(bytes.Buffer) + _, err := src.Write(q) + require.NoError(t, err) + dst := new(bytes.Buffer) + + bi, err := NewBackendInterceptor(src, dst, 16) + require.NoError(t, err) + require.NotNil(t, bi) + + n, err := bi.ForwardMsg() + require.NoError(t, err) + require.Equal(t, 14, n) + require.Equal(t, 14, dst.Len()) + }) +} + +// TestSimpleProxy illustrates how the frontend and backend interceptors can be +// used as a proxy. +func TestSimpleProxy(t *testing.T) { + defer leaktest.AfterTest(t)() + + const bufferSize = 16 + + // These represents connections for client<->proxy and proxy<->server. + fromClient := new(bytes.Buffer) + toClient := new(bytes.Buffer) + fromServer := new(bytes.Buffer) + toServer := new(bytes.Buffer) + + // Create client and server interceptors. + clientInt, err := NewBackendInterceptor(fromClient, toServer, bufferSize) + require.NoError(t, err) + serverInt, err := NewFrontendInterceptor(fromServer, toClient, bufferSize) + require.NoError(t, err) + + t.Run("client to server", func(t *testing.T) { + // Client sends a list of SQL queries. + queries := []pgproto3.FrontendMessage{ + &pgproto3.Query{String: "SELECT 1"}, + &pgproto3.Query{String: "SELECT * FROM foo.bar"}, + &pgproto3.Query{String: "UPDATE foo SET x = 42"}, + &pgproto3.Sync{}, + &pgproto3.Terminate{}, + } + for _, msg := range queries { + _, err := fromClient.Write(msg.Encode(nil)) + require.NoError(t, err) + } + totalBytes := fromClient.Len() + + customQuery := &pgproto3.Query{ + String: "SELECT * FROM crdb_internal.serialize_session()"} + + for { + typ, _, err := clientInt.PeekMsg() + require.NoError(t, err) + + // Forward message to server. + _, err = clientInt.ForwardMsg() + require.NoError(t, err) + + if typ == pgwirebase.ClientMsgTerminate { + // Right before we terminate, we could also craft a custom + // message, and send it to the server. + _, err := clientInt.WriteMsg(customQuery) + require.NoError(t, err) + break + } + } + require.Equal(t, 0, fromClient.Len()) + require.Equal(t, totalBytes+len(customQuery.Encode(nil)), toServer.Len()) + }) + + t.Run("server to client", func(t *testing.T) { + // Server sends back responses. + queries := []pgproto3.BackendMessage{ + // Forward these back to the client. + &pgproto3.CommandComplete{CommandTag: []byte("averylongstring")}, + &pgproto3.BackendKeyData{ProcessID: 100, SecretKey: 42}, + // Do not forward back to the client. + &pgproto3.CommandComplete{CommandTag: []byte("short")}, + // Terminator. + &pgproto3.ReadyForQuery{}, + } + for _, msg := range queries { + _, err := fromServer.Write(msg.Encode(nil)) + require.NoError(t, err) + } + // Exclude bytes from second message. + totalBytes := fromServer.Len() - len(queries[2].Encode(nil)) + + for { + typ, size, err := serverInt.PeekMsg() + require.NoError(t, err) + + switch typ { + case pgwirebase.ServerMsgCommandComplete: + // Assuming that we're only interested in small messages, then + // we could skip all the large ones. + if size > 12 { + _, err := serverInt.ForwardMsg() + require.NoError(t, err) + continue + } + + // Decode message. + msg, err := serverInt.ReadMsg() + require.NoError(t, err) + + // Once we've decoded the message, we could store the message + // somewhere, and not forward it back to the client. + dmsg, ok := msg.(*pgproto3.CommandComplete) + require.True(t, ok) + require.Equal(t, "short", string(dmsg.CommandTag)) + case pgwirebase.ServerMsgBackendKeyData: + msg, err := serverInt.ReadMsg() + require.NoError(t, err) + + dmsg, ok := msg.(*pgproto3.BackendKeyData) + require.True(t, ok) + + // We could even rewrite the message before sending it back to + // the client. + dmsg.SecretKey = 100 + + _, err = serverInt.WriteMsg(dmsg) + require.NoError(t, err) + default: + // Forward message that we're not interested to the client. + _, err := serverInt.ForwardMsg() + require.NoError(t, err) + } + + if typ == pgwirebase.ServerMsgReady { + break + } + } + require.Equal(t, 0, fromServer.Len()) + require.Equal(t, totalBytes, toClient.Len()) + }) +} + +var _ io.Reader = &errReadWriter{} +var _ io.Writer = &errReadWriter{} + +// errReadWriter returns io.ErrClosedPipe after count reads or writes in total. +type errReadWriter struct { + r io.Reader + w io.Writer + count int +} + +// Read implements the io.Reader interface. +func (rw *errReadWriter) Read(p []byte) (int, error) { + rw.count-- + if rw.count <= 0 { + return 0, io.ErrClosedPipe + } + return rw.r.Read(p) +} + +// Write implements the io.Writer interface. +func (rw *errReadWriter) Write(p []byte) (int, error) { + rw.count-- + if rw.count <= 0 { + return 0, io.ErrClosedPipe + } + return rw.w.Write(p) +}