From e02a13d516efe18a8a1f52bd2bea43a4e4f8f9a7 Mon Sep 17 00:00:00 2001 From: Jay Date: Fri, 4 Feb 2022 00:37:57 -0500 Subject: [PATCH] ccl/sqlproxyccl: add postgres interceptors for message forwarding Informs #76000. This commit implements postgres interceptors, namely FrontendInterceptor and BackendInterceptor, as described in the sqlproxy connection migration RFC. These interceptors will be used as building blocks for the forwarder component that we will be adding in a later PR. Since the forwarder component has not been added, a simple proxy test (i.e. TestSimpleProxy) has been added to illustrate how the frontend and backend interceptors can be used within the proxy. Release note: None --- pkg/BUILD.bazel | 1 + pkg/ccl/sqlproxyccl/interceptor/BUILD.bazel | 26 + .../sqlproxyccl/interceptor/interceptor.go | 461 ++++++++++ .../interceptor/interceptor_test.go | 847 ++++++++++++++++++ 4 files changed, 1335 insertions(+) create mode 100644 pkg/ccl/sqlproxyccl/interceptor/BUILD.bazel create mode 100644 pkg/ccl/sqlproxyccl/interceptor/interceptor.go create mode 100644 pkg/ccl/sqlproxyccl/interceptor/interceptor_test.go diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel index b44808bc5222..ad1e228b3f5c 100644 --- a/pkg/BUILD.bazel +++ b/pkg/BUILD.bazel @@ -41,6 +41,7 @@ ALL_TESTS = [ "//pkg/ccl/spanconfigccl/spanconfigsqlwatcherccl:spanconfigsqlwatcherccl_test", "//pkg/ccl/sqlproxyccl/denylist:denylist_test", "//pkg/ccl/sqlproxyccl/idle:idle_test", + "//pkg/ccl/sqlproxyccl/interceptor:interceptor_test", "//pkg/ccl/sqlproxyccl/tenant:tenant_test", "//pkg/ccl/sqlproxyccl/throttler:throttler_test", "//pkg/ccl/sqlproxyccl:sqlproxyccl_test", diff --git a/pkg/ccl/sqlproxyccl/interceptor/BUILD.bazel b/pkg/ccl/sqlproxyccl/interceptor/BUILD.bazel new file mode 100644 index 000000000000..d272bf3ebe7d --- /dev/null +++ b/pkg/ccl/sqlproxyccl/interceptor/BUILD.bazel @@ -0,0 +1,26 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "interceptor", + srcs = ["interceptor.go"], + importpath = "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/interceptor", + visibility = ["//visibility:public"], + deps = [ + "//pkg/sql/pgwire/pgwirebase", + "@com_github_cockroachdb_errors//:errors", + "@com_github_jackc_pgproto3_v2//:pgproto3", + ], +) + +go_test( + name = "interceptor_test", + srcs = ["interceptor_test.go"], + embed = [":interceptor"], + deps = [ + "//pkg/sql/pgwire/pgwirebase", + "//pkg/util/leaktest", + "@com_github_cockroachdb_errors//:errors", + "@com_github_jackc_pgproto3_v2//:pgproto3", + "@com_github_stretchr_testify//require", + ], +) diff --git a/pkg/ccl/sqlproxyccl/interceptor/interceptor.go b/pkg/ccl/sqlproxyccl/interceptor/interceptor.go new file mode 100644 index 000000000000..ed7a48aeef55 --- /dev/null +++ b/pkg/ccl/sqlproxyccl/interceptor/interceptor.go @@ -0,0 +1,461 @@ +// Copyright 2022 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package interceptor + +import ( + "encoding/binary" + "io" + + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase" + "github.com/cockroachdb/errors" + "github.com/jackc/pgproto3/v2" +) + +// pgHeaderSizeBytes represents the number of bytes of a pgwire message header +// (i.e. one byte for type, and an int for the body size, inclusive of the +// length itself). +const pgHeaderSizeBytes = 5 + +// ErrSmallBuffer indicates that the requested buffer for the interceptor is +// too small. +var ErrSmallBuffer = errors.New("buffer is too small") + +// ErrInterceptorClosed is the returned error whenever the intercept is closed. +// When this happens, the caller should terminate both dst and src to guarantee +// correctness. +var ErrInterceptorClosed = errors.New("interceptor is closed") + +// ErrProtocolError indicates that the packets are malformed, and are not as +// expected. +var ErrProtocolError = errors.New("protocol error") + +// pgInterceptor provides a convenient way to read and forward Postgres +// messages, while minimizing IO reads and memory allocations. +// +// NOTE: Methods on the interceptor are not thread-safe. +type pgInterceptor struct { + src io.Reader + dst io.Writer + + // buf stores bytes which have been read, but have not been processed yet. + buf []byte + // readPos and writePos indicates the read and write pointers for bytes in + // the buffer buf. + readPos, writePos int + + // closed indicates that the interceptor is closed. This will be set to + // true whenever there's an error within one of the interceptor's operations + // leading to an ambiguity. Once an interceptor is closed, all subsequent + // method calls on the interceptor will return ErrInterceptorClosed. + closed bool +} + +// newPgInterceptor creates a new instance of the interceptor with an internal +// buffer of bufSize bytes. bufSize must be at least the size of a pgwire +// message header. +func newPgInterceptor(src io.Reader, dst io.Writer, bufSize int) (*pgInterceptor, error) { + // The internal buffer must be able to fit the header. + if bufSize < pgHeaderSizeBytes { + return nil, ErrSmallBuffer + } + return &pgInterceptor{ + src: src, + dst: dst, + buf: make([]byte, bufSize), + }, nil +} + +// ensureNextNBytes blocks on IO reads until the buffer has at least n bytes. +func (p *pgInterceptor) ensureNextNBytes(n int) error { + if n < 0 || n > len(p.buf) { + return errors.AssertionFailedf( + "invalid number of bytes %d for buffer size %d", n, len(p.buf)) + } + + // Buffer already has n bytes. + if p.ReadSize() >= n { + return nil + } + + // Not enough empty slots to fit the unread bytes, so re-align bytes. + minReadCount := n - p.ReadSize() + if p.WriteSize() < minReadCount { + p.writePos = copy(p.buf, p.buf[p.readPos:p.writePos]) + p.readPos = 0 + } + + c, err := io.ReadAtLeast(p.src, p.buf[p.writePos:], minReadCount) + p.writePos += c + return err +} + +// PeekMsg returns the header of the current pgwire message without advancing +// the interceptor. On return, err == nil if and only if the entire header can +// be read. Note that size corresponds to the body size, and does not account +// for the size field itself. This will return ErrProtocolError if the packets +// are malformed. +// +// If the interceptor is closed, PeekMsg returns ErrInterceptorClosed. +func (p *pgInterceptor) PeekMsg() (typ byte, size int, err error) { + if p.closed { + return 0, 0, ErrInterceptorClosed + } + + if err := p.ensureNextNBytes(pgHeaderSizeBytes); err != nil { + // Possibly due to a timeout or context cancellation. + return 0, 0, err + } + + typ = p.buf[p.readPos] + size = int(binary.BigEndian.Uint32(p.buf[p.readPos+1:])) - 4 + + // Size has to be at least itself based on pgwire's protocol. + if size < 0 { + return 0, 0, ErrProtocolError + } + + return typ, size, nil +} + +// WriteMsg writes the given bytes to the writer dst. If err != nil and a Write +// was attempted, the interceptor will be closed. +// +// If the interceptor is closed, WriteMsg returns ErrInterceptorClosed. +func (p *pgInterceptor) WriteMsg(data []byte) (n int, err error) { + if p.closed { + return 0, ErrInterceptorClosed + } + defer func() { + // Close the interceptor if there was an error. Theoretically, we only + // need to close here if n > 0, but for consistency with the other + // methods, we will do that here too. + if err != nil { + p.Close() + } + }() + return p.dst.Write(data) +} + +// ReadMsg returns the current pgwire message data in bytes and its type. It +// also advances the interceptor to the next message. On return, the type and +// body fields are valid if and only if err == nil. If err != nil and a Read +// was attempted because the buffer did not have enough bytes, the interceptor +// will be closed. +// +// The interceptor retains ownership of all the memory returned by ReadMsg; the +// caller is allowed to hold on to this memory *until* the next moment other +// methods on the interceptor are called. The data will only be valid until then +// as well. +// +// If the interceptor is closed, ReadMsg returns ErrInterceptorClosed. +func (p *pgInterceptor) ReadMsg() (typ byte, body []byte, err error) { + // Technically this is redundant since PeekMsg will do the same thing, but + // we do so here for clarity. + if p.closed { + return 0, nil, ErrInterceptorClosed + } + + // Read header of the current message for body size. + typ, size, err := p.PeekMsg() + if err != nil { + return 0, nil, err + } + p.readPos += pgHeaderSizeBytes + + // Can the entire message fit into the buffer? + if size <= len(p.buf) { + if err := p.ensureNextNBytes(size); err != nil { + // Possibly due to a timeout or context cancellation. + return 0, nil, err + } + + // Return a slice to the internal buffer to avoid an allocation here. + retBuf := p.buf[p.readPos : p.readPos+size] + p.readPos += size + return typ, retBuf, nil + } + + // Message cannot fit, so we will have to allocate. + body = make([]byte, size) + + // Copy bytes which have already been read. + toCopy := size + if p.ReadSize() < size { + toCopy = p.ReadSize() + } + n := copy(body, p.buf[p.readPos:p.readPos+toCopy]) + p.readPos += n // toCopy has to be the same as n here. + + defer func() { + // Close the interceptor because we read the data (both buffered and + // possibly newer ones) into body, which is larger than the buffer's + // size, and there's no easy way to recover. We could technically fix + // some of the situations, especially when no bytes were read, but at + // this point, it's likely that the one end of the interceptor is + // already gone, or the proxy is shutting down, so there's no point + // trying to save a disconnect. + if err != nil { + p.Close() + } + }() + + // Read more bytes. + if _, err := io.ReadFull(p.src, body[n:]); err != nil { + return 0, nil, err + } + + return typ, body, nil +} + +// ForwardMsg sends the current pgwire message to the destination, and advances +// the interceptor to the next message. On return, n == pgwire message size if +// and only if err == nil. If err != nil and a Write was attempted, the +// interceptor will be closed. +// +// If the interceptor is closed, ForwardMsg returns ErrInterceptorClosed. +func (p *pgInterceptor) ForwardMsg() (n int, err error) { + // Technically this is redundant since PeekMsg will do the same thing, but + // we do so here for clarity. + if p.closed { + return 0, ErrInterceptorClosed + } + + // Retrieve header of the current message for body size. + _, size, err := p.PeekMsg() + if err != nil { + return 0, err + } + + // Handle overflows as current message may not fit in the current buffer. + startPos := p.readPos + endPos := startPos + pgHeaderSizeBytes + size + remainingBytes := 0 + if endPos > p.writePos { + remainingBytes = endPos - p.writePos + endPos = p.writePos + } + p.readPos = endPos + + defer func() { + // State may be invalid depending on whether bytes have been written. + // To reduce complexity, we'll just close the interceptor, and the + // caller should just terminate both ends. + // + // If src has been closed, the dst state may be invalid. If dst has been + // closed, buffered bytes no longer represent the protocol correctly + // even if we slurped the remaining bytes for the current message. + if err != nil { + p.Close() + } + }() + + // Forward the message to the destination. + n, err = p.dst.Write(p.buf[startPos:endPos]) + if err != nil { + return n, err + } + // n shouldn't be larger than the size of the buffer unless the + // implementation of Write for dst is incorrect. This shouldn't be the case + // if we're using a TCP connection here. + if n < endPos-startPos { + return n, io.ErrShortWrite + } + + // Message was partially buffered, so copy the remaining. + if remainingBytes > 0 { + m, err := io.CopyN(p.dst, p.src, int64(remainingBytes)) + n += int(m) + if err != nil { + return n, err + } + // n shouldn't be larger than remainingBytes unless the internal Read + // and Write calls for either of src or dst are incorrect. This + // shouldn't be the case if we're using a TCP connection here. + if int(m) < remainingBytes { + return n, io.ErrShortWrite + } + } + return n, nil +} + +// ReadSize returns the number of bytes read by the interceptor. If the +// interceptor is closed, this will return 0. +func (p *pgInterceptor) ReadSize() int { + if p.closed { + return 0 + } + return p.writePos - p.readPos +} + +// WriteSize returns the remaining number of bytes that could fit into the +// internal buffer before needing to be re-aligned. If the interceptor is +// closed, this will return 0. +func (p *pgInterceptor) WriteSize() int { + if p.closed { + return 0 + } + return len(p.buf) - p.writePos +} + +// Close closes the interceptor, and prevents further operations on it. +func (p *pgInterceptor) Close() { + p.closed = true +} + +var errInvalidRead = errors.New("invalid read in chunkReader") + +var _ pgproto3.ChunkReader = &chunkReader{} + +// chunkReader is a wrapper on a single Postgres message, and is meant to be +// used with the Receive method on pgproto3.{NewFrontend, NewBackend}. +type chunkReader struct { + header [5]byte + body []byte + pos int +} + +func newChunkReader(typ byte, body []byte) pgproto3.ChunkReader { + cr := &chunkReader{body: body} + cr.header[0] = typ + binary.BigEndian.PutUint32(cr.header[1:], uint32(len(body)+4)) + return cr +} + +// Next implements the pgproto3.ChunkReader interface. This implements a tiny +// state machine where Next has to be called in the following order: n=5 to +// read headers, followed by n=len(body) to read the message body. An io.EOF +// will be returned once the entire message has been read. If the state machine +// protocol isn't obeyed, an errInvalidRead will be returned. +func (cr *chunkReader) Next(n int) (buf []byte, err error) { + switch cr.pos { + case 0: + // Only the headers can be requested. + if n != 5 { + return nil, errInvalidRead + } + cr.pos += n + return cr.header[:], nil + case 5: + // After header is read, only the body can be requested. + if n != len(cr.body) { + return nil, errInvalidRead + } + cr.pos += n + return cr.body, nil + default: + return nil, io.EOF + } +} + +// FrontendInterceptor is a client interceptor for the Postgres frontend +// protocol. +type FrontendInterceptor struct { + p *pgInterceptor +} + +// NewFrontendInterceptor creates a FrontendInterceptor. bufSize must be at +// least the size of a pgwire message header. +func NewFrontendInterceptor( + src io.Reader, dst io.Writer, bufSize int, +) (*FrontendInterceptor, error) { + pgi, err := newPgInterceptor(src, dst, bufSize) + if err != nil { + return nil, err + } + return &FrontendInterceptor{p: pgi}, nil +} + +// PeekMsg returns the header of the current pgwire message without advancing +// the interceptor. See pgInterceptor.PeekMsg for more information. +func (fi *FrontendInterceptor) PeekMsg() (typ pgwirebase.ServerMessageType, size int, err error) { + byteType, size, err := fi.p.PeekMsg() + return pgwirebase.ServerMessageType(byteType), size, err +} + +// WriteMsg writes the given bytes to the writer dst. See pgInterceptor.WriteMsg +// for more information. +func (fi *FrontendInterceptor) WriteMsg(data pgproto3.BackendMessage) (n int, err error) { + return fi.p.WriteMsg(data.Encode(nil)) +} + +// ReadMsg decodes the current pgwire message and returns a BackendMessage. +// This also advances the interceptor to the next message. See +// pgInterceptor.ReadMsg for more information. +func (fi *FrontendInterceptor) ReadMsg() (msg pgproto3.BackendMessage, err error) { + typ, body, err := fi.p.ReadMsg() + if err != nil { + return nil, err + } + // errPanicWriter is used here because Receive must not Write. + return pgproto3.NewFrontend(newChunkReader(typ, body), &errPanicWriter{}).Receive() +} + +// ForwardMsg sends the current pgwire message to the destination without any +// decoding, and advances the interceptor to the next message. See +// pgInterceptor.ForwardMsg for more information. +func (fi *FrontendInterceptor) ForwardMsg() (n int, err error) { + return fi.p.ForwardMsg() +} + +// BackendInterceptor is a server interceptor for the Postgres backend protocol. +type BackendInterceptor struct { + p *pgInterceptor +} + +// NewBackendInterceptor creates a BackendInterceptor. bufSize must be at least +// the size of a pgwire message header. +func NewBackendInterceptor(src io.Reader, dst io.Writer, bufSize int) (*BackendInterceptor, error) { + pgi, err := newPgInterceptor(src, dst, bufSize) + if err != nil { + return nil, err + } + return &BackendInterceptor{p: pgi}, nil +} + +// PeekMsg returns the header of the current pgwire message without advancing +// the interceptor. See pgInterceptor.PeekMsg for more information. +func (bi *BackendInterceptor) PeekMsg() (typ pgwirebase.ClientMessageType, size int, err error) { + byteType, size, err := bi.p.PeekMsg() + return pgwirebase.ClientMessageType(byteType), size, err +} + +// WriteMsg writes the given bytes to the writer dst. See pgInterceptor.WriteMsg +// for more information. +func (bi *BackendInterceptor) WriteMsg(data pgproto3.FrontendMessage) (n int, err error) { + return bi.p.WriteMsg(data.Encode(nil)) +} + +// ReadMsg decodes the current pgwire message and returns a FrontendMessage. +// This also advances the interceptor to the next message. See +// pgInterceptor.ReadMsg for more information. +func (bi *BackendInterceptor) ReadMsg() (msg pgproto3.FrontendMessage, err error) { + typ, body, err := bi.p.ReadMsg() + if err != nil { + return nil, err + } + // errPanicWriter is used here because Receive must not Write. + return pgproto3.NewBackend(newChunkReader(typ, body), &errPanicWriter{}).Receive() +} + +// ForwardMsg sends the current pgwire message to the destination without any +// decoding, and advances the interceptor to the next message. See +// pgInterceptor.ForwardMsg for more information. +func (bi *BackendInterceptor) ForwardMsg() (n int, err error) { + return bi.p.ForwardMsg() +} + +var _ io.Writer = &errPanicWriter{} + +// errPanicWriter is an io.Writer that panics whenever a Write call is made. +type errPanicWriter struct{} + +// Write implements the io.Writer interface. +func (w *errPanicWriter) Write(p []byte) (int, error) { + panic("unexpected Write call") +} diff --git a/pkg/ccl/sqlproxyccl/interceptor/interceptor_test.go b/pkg/ccl/sqlproxyccl/interceptor/interceptor_test.go new file mode 100644 index 000000000000..83a71d99f9b6 --- /dev/null +++ b/pkg/ccl/sqlproxyccl/interceptor/interceptor_test.go @@ -0,0 +1,847 @@ +// Copyright 2022 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package interceptor + +import ( + "bytes" + "io" + "testing" + "testing/iotest" + + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/errors" + "github.com/jackc/pgproto3/v2" + "github.com/stretchr/testify/require" +) + +func TestNewPgInterceptor(t *testing.T) { + defer leaktest.AfterTest(t)() + + reader, writer := io.Pipe() + + // Negative buffer size. + pgi, err := newPgInterceptor(reader, writer, -1) + require.EqualError(t, err, ErrSmallBuffer.Error()) + require.Nil(t, pgi) + + // Small buffer size. + pgi, err = newPgInterceptor(reader, writer, pgHeaderSizeBytes-1) + require.EqualError(t, err, ErrSmallBuffer.Error()) + require.Nil(t, pgi) + + // Buffer that fits the header exactly. + pgi, err = newPgInterceptor(reader, writer, pgHeaderSizeBytes) + require.NoError(t, err) + require.NotNil(t, pgi) + require.Len(t, pgi.buf, pgHeaderSizeBytes) + + // Normal buffer size. + pgi, err = newPgInterceptor(reader, writer, 1024) + require.NoError(t, err) + require.NotNil(t, pgi) + require.Len(t, pgi.buf, 1024) + require.Equal(t, reader, pgi.src) + require.Equal(t, writer, pgi.dst) +} + +func TestPGInterceptor_ensureNextNBytes(t *testing.T) { + defer leaktest.AfterTest(t)() + + t.Run("invalid n", func(t *testing.T) { + pgi, err := newPgInterceptor(nil /* src */, nil /* dst */, 8) + require.NoError(t, err) + + require.EqualError(t, pgi.ensureNextNBytes(-1), + "invalid number of bytes -1 for buffer size 8") + require.EqualError(t, pgi.ensureNextNBytes(9), + "invalid number of bytes 9 for buffer size 8") + }) + + t.Run("buffer already has n bytes", func(t *testing.T) { + buf := bytes.NewBufferString("foobarbaz") + + pgi, err := newPgInterceptor(iotest.OneByteReader(buf), nil /* dst */, 8) + require.NoError(t, err) + + // Read "foo" into buffer". + require.NoError(t, pgi.ensureNextNBytes(3)) + + // These should not read anything since we expect the buffer to + // have three bytes. + require.NoError(t, pgi.ensureNextNBytes(3)) + require.Equal(t, 6, buf.Len()) + require.NoError(t, pgi.ensureNextNBytes(0)) + require.Equal(t, 6, buf.Len()) + require.NoError(t, pgi.ensureNextNBytes(1)) + require.Equal(t, 6, buf.Len()) + + // Verify that buf actually has "foo". + require.Equal(t, "foo", string(pgi.buf[pgi.readPos:pgi.writePos])) + }) + + t.Run("bytes are realigned", func(t *testing.T) { + buf := bytes.NewBufferString("foobarbazcar") + + pgi, err := newPgInterceptor(iotest.OneByteReader(buf), nil /* dst */, 9) + require.NoError(t, err) + + // Read "foobarb" into buffer. + require.NoError(t, pgi.ensureNextNBytes(7)) + + // Assume "foobar" is read. + pgi.readPos += 6 + + // Now ensure that we have 6 bytes. + require.NoError(t, pgi.ensureNextNBytes(6)) + require.Equal(t, 0, buf.Len()) + + // Verify that buf has "bazcar". + require.Equal(t, "bazcar", string(pgi.buf[pgi.readPos:pgi.writePos])) + }) + + t.Run("bytes are read greedily", func(t *testing.T) { + // This tests that we read as much as we can into the internal buffer + // if there was a Read call. + buf := bytes.NewBufferString("foobarbaz") + + pgi, err := newPgInterceptor(buf, nil /* dst */, 10) + require.NoError(t, err) + + // Request for only 1 byte. + require.NoError(t, pgi.ensureNextNBytes(1)) + + // Verify that buf has "foobarbaz". + require.Equal(t, "foobarbaz", string(pgi.buf[pgi.readPos:pgi.writePos])) + + // Should be a no-op. + _, err = buf.WriteString("car") + require.NoError(t, err) + require.NoError(t, pgi.ensureNextNBytes(9)) + require.Equal(t, 3, buf.Len()) + require.Equal(t, "foobarbaz", string(pgi.buf[pgi.readPos:pgi.writePos])) + }) +} + +func TestPGInterceptor_PeekMsg(t *testing.T) { + defer leaktest.AfterTest(t)() + + t.Run("interceptor is closed", func(t *testing.T) { + pgi, err := newPgInterceptor(nil /* src */, nil /* dst */, 10) + require.NoError(t, err) + pgi.Close() + + typ, size, err := pgi.PeekMsg() + require.EqualError(t, err, ErrInterceptorClosed.Error()) + require.Equal(t, byte(0), typ) + require.Equal(t, 0, size) + }) + + t.Run("read error", func(t *testing.T) { + r := iotest.ErrReader(errors.New("read error")) + + pgi, err := newPgInterceptor(r, nil /* dst */, 10) + require.NoError(t, err) + + typ, size, err := pgi.PeekMsg() + require.EqualError(t, err, "read error") + require.Equal(t, byte(0), typ) + require.Equal(t, 0, size) + }) + + t.Run("protocol error", func(t *testing.T) { + data := make([]byte, 10) + buf := new(bytes.Buffer) + _, err := buf.Write(data) + require.NoError(t, err) + + pgi, err := newPgInterceptor(buf, nil /* dst */, 10) + require.NoError(t, err) + + typ, size, err := pgi.PeekMsg() + require.EqualError(t, err, ErrProtocolError.Error()) + require.Equal(t, byte(0), typ) + require.Equal(t, 0, size) + }) + + t.Run("successful", func(t *testing.T) { + buf := new(bytes.Buffer) + _, err := buf.Write((&pgproto3.Query{String: "SELECT 1"}).Encode(nil)) + require.NoError(t, err) + + pgi, err := newPgInterceptor(buf, nil /* dst */, 10) + require.NoError(t, err) + + typ, size, err := pgi.PeekMsg() + require.NoError(t, err) + require.Equal(t, byte(pgwirebase.ClientMsgSimpleQuery), typ) + require.Equal(t, 9, size) + require.Equal(t, 4, buf.Len()) + + // Invoking Peek should not advance the interceptor. + typ, size, err = pgi.PeekMsg() + require.NoError(t, err) + require.Equal(t, byte(pgwirebase.ClientMsgSimpleQuery), typ) + require.Equal(t, 9, size) + require.Equal(t, 4, buf.Len()) + }) +} + +func TestPGInterceptor_WriteMsg(t *testing.T) { + defer leaktest.AfterTest(t)() + + t.Run("interceptor is closed", func(t *testing.T) { + pgi, err := newPgInterceptor(nil /* src */, nil /* dst */, 10) + require.NoError(t, err) + pgi.Close() + + n, err := pgi.WriteMsg([]byte{}) + require.EqualError(t, err, ErrInterceptorClosed.Error()) + require.Equal(t, 0, n) + }) + + t.Run("write error", func(t *testing.T) { + pgi, err := newPgInterceptor(nil /* src */, &errReadWriter{w: io.Discard}, 10) + require.NoError(t, err) + + n, err := pgi.WriteMsg([]byte{}) + require.EqualError(t, err, io.ErrClosedPipe.Error()) + require.Equal(t, 0, n) + require.True(t, pgi.closed) + }) + + t.Run("successful", func(t *testing.T) { + buf := new(bytes.Buffer) + pgi, err := newPgInterceptor(nil /* src */, buf, 10) + require.NoError(t, err) + + n, err := pgi.WriteMsg([]byte("hello")) + require.NoError(t, err) + require.Equal(t, 5, n) + require.False(t, pgi.closed) + require.Equal(t, "hello", buf.String()) + }) +} + +func TestPGInterceptor_ReadMsg(t *testing.T) { + defer leaktest.AfterTest(t)() + + t.Run("interceptor is closed", func(t *testing.T) { + pgi, err := newPgInterceptor(nil /* src */, nil /* dst */, 10) + require.NoError(t, err) + pgi.Close() + + typ, body, err := pgi.ReadMsg() + require.EqualError(t, err, ErrInterceptorClosed.Error()) + require.Equal(t, byte(0), typ) + require.Nil(t, body) + }) + + q := (&pgproto3.Query{String: "SELECT 1"}).Encode(nil) + + buildSrc := func(t *testing.T, count int) *bytes.Buffer { + t.Helper() + src := new(bytes.Buffer) + for i := 0; i < count; i++ { + // Alternate between SELECT 1 and 2 to ensure correctness. + if i%2 == 0 { + q[12] = '1' + } else { + q[12] = '2' + } + _, err := src.Write(q) + require.NoError(t, err) + } + return src + } + + t.Run("message fits", func(t *testing.T) { + const count = 101 // Inclusive of warm-up run in AllocsPerRun. + + buf := buildSrc(t, count) + + // Set buffer's size to be a multiple of the message so that we'll + // always hit the case where the message fits. + pgi, err := newPgInterceptor(buf, nil /* dst */, len(q)*3) + require.NoError(t, err) + + c := 0 + n := testing.AllocsPerRun(count-1, func() { + typ, body, err := pgi.ReadMsg() + require.NoError(t, err) + require.Equal(t, byte(pgwirebase.ClientMsgSimpleQuery), typ) + + expectedStr := "SELECT 1\x00" + if c%2 == 1 { + expectedStr = "SELECT 2\x00" + } + + // Using require.Equal here will result in 2 allocs. + if string(body) != expectedStr { + t.Fatalf(`expected %q, got: %q`, expectedStr, string(body)) + } + c++ + }) + require.Equal(t, float64(0), n, "should not allocate") + require.Equal(t, 0, buf.Len()) + }) + + t.Run("message overflows", func(t *testing.T) { + const count = 101 // Inclusive of warm-up run in AllocsPerRun. + + buf := buildSrc(t, count) + + // Set the buffer to be large enough to fit more bytes than the header, + // but not the entire message. + pgi, err := newPgInterceptor(buf, nil /* dst */, 7) + require.NoError(t, err) + + c := 0 + n := testing.AllocsPerRun(count-1, func() { + typ, body, err := pgi.ReadMsg() + require.NoError(t, err) + require.Equal(t, byte(pgwirebase.ClientMsgSimpleQuery), typ) + + expectedStr := "SELECT 1\x00" + if c%2 == 1 { + expectedStr = "SELECT 2\x00" + } + + // Using require.Equal here will result in 2 allocs. + if string(body) != expectedStr { + t.Fatalf(`expected %q, got: %q`, expectedStr, string(body)) + } + c++ + }) + // Ensure that we only have 1 allocation. We could technically improve + // this by ensuring that one pool of memory is used to reduce the number + // of allocations, but ReadMsg is only called during a transfer session, + // so there's very little benefit to optimizing for that. + require.Equal(t, float64(1), n) + require.Equal(t, 0, buf.Len()) + }) + + t.Run("read error after allocate", func(t *testing.T) { + q := (&pgproto3.Query{String: "SELECT 1"}).Encode(nil) + buf := new(bytes.Buffer) + _, err := buf.Write(q) + require.NoError(t, err) + + src := &errReadWriter{r: buf, count: 2} + pgi, err := newPgInterceptor(src, nil /* dst */, 6) + require.NoError(t, err) + + typ, body, err := pgi.ReadMsg() + require.EqualError(t, err, io.ErrClosedPipe.Error()) + require.Equal(t, byte(0), typ) + require.Nil(t, body) + + // Ensure that interceptor is closed. + require.True(t, pgi.closed) + require.Equal(t, 8, buf.Len()) + }) +} + +func TestPGInterceptor_ForwardMsg(t *testing.T) { + defer leaktest.AfterTest(t)() + + t.Run("interceptor is closed", func(t *testing.T) { + pgi, err := newPgInterceptor(nil /* src */, nil /* dst */, 10) + require.NoError(t, err) + pgi.Close() + + n, err := pgi.ForwardMsg() + require.EqualError(t, err, ErrInterceptorClosed.Error()) + require.Equal(t, 0, n) + }) + + q := (&pgproto3.Query{String: "SELECT 1"}).Encode(nil) + + buildSrc := func(t *testing.T, count int) *bytes.Buffer { + t.Helper() + src := new(bytes.Buffer) + for i := 0; i < count; i++ { + // Alternate between SELECT 1 and 2 to ensure correctness. + if i%2 == 0 { + q[12] = '1' + } else { + q[12] = '2' + } + _, err := src.Write(q) + require.NoError(t, err) + } + return src + } + + validateDst := func(t *testing.T, dst io.Reader, count int) { + t.Helper() + backend := pgproto3.NewBackend(pgproto3.NewChunkReader(dst), nil /* w */) + for i := 0; i < count; i++ { + msg, err := backend.Receive() + require.NoError(t, err) + q := msg.(*pgproto3.Query) + + expectedStr := "SELECT 1" + if i%2 == 1 { + expectedStr = "SELECT 2" + } + require.Equal(t, expectedStr, q.String) + } + } + + t.Run("message fits", func(t *testing.T) { + const count = 101 // Inclusive of warm-up run in AllocsPerRun. + + src := buildSrc(t, count) + dst := new(bytes.Buffer) + + // Set buffer's size to be a multiple of the message so that we'll + // always hit the case where the message fits. + pgi, err := newPgInterceptor(src, dst, len(q)*3) + require.NoError(t, err) + + // Forward all the messages, and ensure 0 allocations. + n := testing.AllocsPerRun(count-1, func() { + n, err := pgi.ForwardMsg() + require.NoError(t, err) + require.Equal(t, 14, n) + }) + require.Equal(t, float64(0), n, "should not allocate") + require.Equal(t, 0, src.Len()) + + // Validate messages. + validateDst(t, dst, count) + require.Equal(t, 0, dst.Len()) + }) + + t.Run("message overflows", func(t *testing.T) { + const count = 151 // Inclusive of warm-up run in AllocsPerRun. + + src := buildSrc(t, count) + dst := new(bytes.Buffer) + + // Set the buffer to be large enough to fit more bytes than the header, + // but not the entire message. + pgi, err := newPgInterceptor(src, dst, 7) + require.NoError(t, err) + + n := testing.AllocsPerRun(count-1, func() { + n, err := pgi.ForwardMsg() + require.NoError(t, err) + require.Equal(t, 14, n) + }) + // NOTE: This allocation is benign, and is due to the fact that io.CopyN + // allocates an internal buffer in copyBuffer. This wouldn't happen if + // a TCP connection is used as the destination since there's a fast-path + // that prevents that. + // + // See: https://cs.opensource.google/go/go/+/refs/tags/go1.17.6:src/io/io.go;l=402-410;drc=refs%2Ftags%2Fgo1.17.6 + require.Equal(t, float64(1), n) + require.Equal(t, 0, src.Len()) + + // Validate messages. + validateDst(t, dst, count) + require.Equal(t, 0, dst.Len()) + }) + + t.Run("write error", func(t *testing.T) { + q := (&pgproto3.Query{String: "SELECT 1"}).Encode(nil) + src := new(bytes.Buffer) + _, err := src.Write(q) + require.NoError(t, err) + dst := new(bytes.Buffer) + + pgi, err := newPgInterceptor(src, &errReadWriter{w: dst, count: 2}, 6) + require.NoError(t, err) + + n, err := pgi.ForwardMsg() + require.EqualError(t, err, io.ErrClosedPipe.Error()) + require.Equal(t, 6, n) + + // Ensure that interceptor is closed. + require.True(t, pgi.closed) + require.Equal(t, 6, dst.Len()) + }) +} + +func TestPGInterceptor_ReadSize(t *testing.T) { + defer leaktest.AfterTest(t)() + + t.Run("interceptor is closed", func(t *testing.T) { + buf := bytes.NewBufferString("foobarbaz") + + pgi, err := newPgInterceptor(buf, nil /* dst */, 9) + require.NoError(t, err) + require.NoError(t, pgi.ensureNextNBytes(1)) + + require.Equal(t, 9, pgi.ReadSize()) + pgi.Close() + require.Equal(t, 0, pgi.ReadSize()) + }) + + t.Run("valid", func(t *testing.T) { + buf := bytes.NewBufferString("foobarbazz") + pgi, err := newPgInterceptor(iotest.OneByteReader(buf), nil /* dst */, 10) + require.NoError(t, err) + + // No reads to internal buffer. + require.Equal(t, 0, pgi.ReadSize()) + + // Attempt reads to buffer. + require.NoError(t, pgi.ensureNextNBytes(3)) + require.Equal(t, 3, pgi.ReadSize()) + + // Read until buffer is full. + require.NoError(t, pgi.ensureNextNBytes(10)) + require.Equal(t, 10, pgi.ReadSize()) + }) +} + +func TestPGInterceptor_WriteSize(t *testing.T) { + defer leaktest.AfterTest(t)() + + t.Run("interceptor is closed", func(t *testing.T) { + pgi, err := newPgInterceptor(nil /* src */, nil /* dst */, 9) + require.NoError(t, err) + + require.Equal(t, 9, pgi.WriteSize()) + pgi.Close() + require.Equal(t, 0, pgi.WriteSize()) + }) + + t.Run("valid", func(t *testing.T) { + buf := bytes.NewBufferString("foobarbazz") + pgi, err := newPgInterceptor(iotest.OneByteReader(buf), nil /* dst */, 10) + require.NoError(t, err) + + // No writes to internal buffer. + require.Equal(t, 10, pgi.WriteSize()) + + // Attempt writes to buffer. + require.NoError(t, pgi.ensureNextNBytes(3)) + require.Equal(t, 7, pgi.WriteSize()) + + // Attempt more writes to buffer until full. + require.NoError(t, pgi.ensureNextNBytes(10)) + require.Equal(t, 0, pgi.WriteSize()) + }) +} + +func TestPGInterceptor_Close(t *testing.T) { + defer leaktest.AfterTest(t)() + pgi, err := newPgInterceptor(nil /* src */, nil /* dst */, 10) + require.NoError(t, err) + require.False(t, pgi.closed) + pgi.Close() + require.True(t, pgi.closed) +} + +// TestFrontendInterceptor tests the FrontendInterceptor. Note that the tests +// here are shallow. For detailed ones, see the tests for the internal +// interceptor. +func TestFrontendInterceptor(t *testing.T) { + defer leaktest.AfterTest(t)() + + q := (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil) + + t.Run("bufSize too small", func(t *testing.T) { + fi, err := NewFrontendInterceptor(nil /* src */, nil /* dst */, 1) + require.Error(t, err) + require.Nil(t, fi) + }) + + t.Run("PeekMsg returns the right message type", func(t *testing.T) { + src := new(bytes.Buffer) + _, err := src.Write(q) + require.NoError(t, err) + + fi, err := NewFrontendInterceptor(src, nil /* dst */, 16) + require.NoError(t, err) + require.NotNil(t, fi) + + typ, size, err := fi.PeekMsg() + require.NoError(t, err) + require.Equal(t, pgwirebase.ServerMsgReady, typ) + require.Equal(t, 1, size) + }) + + t.Run("WriteMsg writes data to dst", func(t *testing.T) { + dst := new(bytes.Buffer) + fi, err := NewFrontendInterceptor(nil /* src */, dst, 10) + require.NoError(t, err) + require.NotNil(t, fi) + + // This is a frontend interceptor, so writing goes to the client. + n, err := fi.WriteMsg(&pgproto3.ReadyForQuery{TxStatus: 'I'}) + require.NoError(t, err) + require.Equal(t, 6, n) + require.Equal(t, 6, dst.Len()) + }) + + t.Run("ReadMsg decodes the message correctly", func(t *testing.T) { + src := new(bytes.Buffer) + _, err := src.Write(q) + require.NoError(t, err) + + fi, err := NewFrontendInterceptor(src, nil /* dst */, 16) + require.NoError(t, err) + require.NotNil(t, fi) + + msg, err := fi.ReadMsg() + require.NoError(t, err) + rmsg, ok := msg.(*pgproto3.ReadyForQuery) + require.True(t, ok) + require.Equal(t, byte('I'), rmsg.TxStatus) + }) + + t.Run("ForwardMsg forwards data to dst", func(t *testing.T) { + src := new(bytes.Buffer) + _, err := src.Write(q) + require.NoError(t, err) + dst := new(bytes.Buffer) + + fi, err := NewFrontendInterceptor(src, dst, 16) + require.NoError(t, err) + require.NotNil(t, fi) + + n, err := fi.ForwardMsg() + require.NoError(t, err) + require.Equal(t, 6, n) + require.Equal(t, 6, dst.Len()) + }) +} + +// TestBackendInterceptor tests the BackendInterceptor. Note that the tests +// here are shallow. For detailed ones, see the tests for the internal +// interceptor. +func TestBackendInterceptor(t *testing.T) { + defer leaktest.AfterTest(t)() + + q := (&pgproto3.Query{String: "SELECT 1"}).Encode(nil) + + t.Run("bufSize too small", func(t *testing.T) { + bi, err := NewBackendInterceptor(nil /* src */, nil /* dst */, 1) + require.Error(t, err) + require.Nil(t, bi) + }) + + t.Run("PeekMsg returns the right message type", func(t *testing.T) { + src := new(bytes.Buffer) + _, err := src.Write(q) + require.NoError(t, err) + + bi, err := NewBackendInterceptor(src, nil /* dst */, 16) + require.NoError(t, err) + require.NotNil(t, bi) + + typ, size, err := bi.PeekMsg() + require.NoError(t, err) + require.Equal(t, pgwirebase.ClientMsgSimpleQuery, typ) + require.Equal(t, 9, size) + }) + + t.Run("WriteMsg writes data to dst", func(t *testing.T) { + dst := new(bytes.Buffer) + bi, err := NewBackendInterceptor(nil /* src */, dst, 10) + require.NoError(t, err) + require.NotNil(t, bi) + + // This is a backend interceptor, so writing goes to the server. + n, err := bi.WriteMsg(&pgproto3.Query{String: "SELECT 1"}) + require.NoError(t, err) + require.Equal(t, 14, n) + require.Equal(t, 14, dst.Len()) + }) + + t.Run("ReadMsg decodes the message correctly", func(t *testing.T) { + src := new(bytes.Buffer) + _, err := src.Write(q) + require.NoError(t, err) + + bi, err := NewBackendInterceptor(src, nil /* dst */, 16) + require.NoError(t, err) + require.NotNil(t, bi) + + msg, err := bi.ReadMsg() + require.NoError(t, err) + rmsg, ok := msg.(*pgproto3.Query) + require.True(t, ok) + require.Equal(t, "SELECT 1", rmsg.String) + }) + + t.Run("ForwardMsg forwards data to dst", func(t *testing.T) { + src := new(bytes.Buffer) + _, err := src.Write(q) + require.NoError(t, err) + dst := new(bytes.Buffer) + + bi, err := NewBackendInterceptor(src, dst, 16) + require.NoError(t, err) + require.NotNil(t, bi) + + n, err := bi.ForwardMsg() + require.NoError(t, err) + require.Equal(t, 14, n) + require.Equal(t, 14, dst.Len()) + }) +} + +// TestSimpleProxy illustrates how the frontend and backend interceptors can be +// used as a proxy. +func TestSimpleProxy(t *testing.T) { + defer leaktest.AfterTest(t)() + + const bufferSize = 16 + + // These represents connections for client<->proxy and proxy<->server. + fromClient := new(bytes.Buffer) + toClient := new(bytes.Buffer) + fromServer := new(bytes.Buffer) + toServer := new(bytes.Buffer) + + // Create client and server interceptors. + clientInt, err := NewBackendInterceptor(fromClient, toServer, bufferSize) + require.NoError(t, err) + serverInt, err := NewFrontendInterceptor(fromServer, toClient, bufferSize) + require.NoError(t, err) + + t.Run("client to server", func(t *testing.T) { + // Client sends a list of SQL queries. + queries := []pgproto3.FrontendMessage{ + &pgproto3.Query{String: "SELECT 1"}, + &pgproto3.Query{String: "SELECT * FROM foo.bar"}, + &pgproto3.Query{String: "UPDATE foo SET x = 42"}, + &pgproto3.Sync{}, + &pgproto3.Terminate{}, + } + for _, msg := range queries { + _, err := fromClient.Write(msg.Encode(nil)) + require.NoError(t, err) + } + totalBytes := fromClient.Len() + + customQuery := &pgproto3.Query{ + String: "SELECT * FROM crdb_internal.serialize_session()"} + + for { + typ, _, err := clientInt.PeekMsg() + require.NoError(t, err) + + // Forward message to server. + _, err = clientInt.ForwardMsg() + require.NoError(t, err) + + if typ == pgwirebase.ClientMsgTerminate { + // Right before we terminate, we could also craft a custom + // message, and send it to the server. + _, err := clientInt.WriteMsg(customQuery) + require.NoError(t, err) + break + } + } + require.Equal(t, 0, fromClient.Len()) + require.Equal(t, totalBytes+len(customQuery.Encode(nil)), toServer.Len()) + }) + + t.Run("server to client", func(t *testing.T) { + // Server sends back responses. + queries := []pgproto3.BackendMessage{ + // Forward these back to the client. + &pgproto3.CommandComplete{CommandTag: []byte("averylongstring")}, + &pgproto3.BackendKeyData{ProcessID: 100, SecretKey: 42}, + // Do not forward back to the client. + &pgproto3.CommandComplete{CommandTag: []byte("short")}, + // Terminator. + &pgproto3.ReadyForQuery{}, + } + for _, msg := range queries { + _, err := fromServer.Write(msg.Encode(nil)) + require.NoError(t, err) + } + // Exclude bytes from second message. + totalBytes := fromServer.Len() - len(queries[2].Encode(nil)) + + for { + typ, size, err := serverInt.PeekMsg() + require.NoError(t, err) + + switch typ { + case pgwirebase.ServerMsgCommandComplete: + // Assuming that we're only interested in small messages, then + // we could skip all the large ones. + if size > 12 { + _, err := serverInt.ForwardMsg() + require.NoError(t, err) + continue + } + + // Decode message. + msg, err := serverInt.ReadMsg() + require.NoError(t, err) + + // Once we've decoded the message, we could store the message + // somewhere, and not forward it back to the client. + dmsg, ok := msg.(*pgproto3.CommandComplete) + require.True(t, ok) + require.Equal(t, "short", string(dmsg.CommandTag)) + case pgwirebase.ServerMsgBackendKeyData: + msg, err := serverInt.ReadMsg() + require.NoError(t, err) + + dmsg, ok := msg.(*pgproto3.BackendKeyData) + require.True(t, ok) + + // We could even rewrite the message before sending it back to + // the client. + dmsg.SecretKey = 100 + + _, err = serverInt.WriteMsg(dmsg) + require.NoError(t, err) + default: + // Forward message that we're not interested to the client. + _, err := serverInt.ForwardMsg() + require.NoError(t, err) + } + + if typ == pgwirebase.ServerMsgReady { + break + } + } + require.Equal(t, 0, fromServer.Len()) + require.Equal(t, totalBytes, toClient.Len()) + }) +} + +var _ io.Reader = &errReadWriter{} +var _ io.Writer = &errReadWriter{} + +// errReadWriter returns io.ErrClosedPipe after count reads or writes in total. +type errReadWriter struct { + r io.Reader + w io.Writer + count int +} + +// Read implements the io.Reader interface. +func (rw *errReadWriter) Read(p []byte) (int, error) { + rw.count-- + if rw.count <= 0 { + return 0, io.ErrClosedPipe + } + return rw.r.Read(p) +} + +// Write implements the io.Writer interface. +func (rw *errReadWriter) Write(p []byte) (int, error) { + rw.count-- + if rw.count <= 0 { + return 0, io.ErrClosedPipe + } + return rw.w.Write(p) +}