diff --git a/pkg/ccl/sqlproxyccl/interceptor/backend_interceptor.go b/pkg/ccl/sqlproxyccl/interceptor/backend_interceptor.go index 74866d3b68de..0611b267a43e 100644 --- a/pkg/ccl/sqlproxyccl/interceptor/backend_interceptor.go +++ b/pkg/ccl/sqlproxyccl/interceptor/backend_interceptor.go @@ -18,14 +18,10 @@ import ( // BackendInterceptor is a server int/erceptor for the Postgres backend protocol. type BackendInterceptor pgInterceptor -// NewBackendInterceptor creates a BackendInterceptor. bufSize must be at least -// the size of a pgwire message header. -func NewBackendInterceptor(src io.Reader, dst io.Writer, bufSize int) (*BackendInterceptor, error) { - pgi, err := newPgInterceptor(src, dst, bufSize) - if err != nil { - return nil, err - } - return (*BackendInterceptor)(pgi), nil +// NewBackendInterceptor creates a BackendInterceptor using the default buffer +// size of 8K bytes. +func NewBackendInterceptor(src io.Reader) *BackendInterceptor { + return (*BackendInterceptor)(newPgInterceptor(src, defaultBufferSize)) } // PeekMsg returns the header of the current pgwire message without advancing @@ -37,13 +33,6 @@ func (bi *BackendInterceptor) PeekMsg() (typ pgwirebase.ClientMessageType, size return pgwirebase.ClientMessageType(byteType), size, err } -// WriteMsg writes the given bytes to the writer dst. -// -// See pgInterceptor.WriteMsg for more information. -func (bi *BackendInterceptor) WriteMsg(data pgproto3.FrontendMessage) (n int, err error) { - return (*pgInterceptor)(bi).WriteMsg(data.Encode(nil)) -} - // ReadMsg decodes the current pgwire message and returns a FrontendMessage. // This also advances the interceptor to the next message. // @@ -53,19 +42,14 @@ func (bi *BackendInterceptor) ReadMsg() (msg pgproto3.FrontendMessage, err error if err != nil { return nil, err } - // errPanicWriter is used here because Receive must not Write. - return pgproto3.NewBackend(newChunkReader(msgBytes), &errPanicWriter{}).Receive() + // errWriter is used here because Receive must not Write. + return pgproto3.NewBackend(newChunkReader(msgBytes), &errWriter{}).Receive() } // ForwardMsg sends the current pgwire message to the destination without any // decoding, and advances the interceptor to the next message. // // See pgInterceptor.ForwardMsg for more information. -func (bi *BackendInterceptor) ForwardMsg() (n int, err error) { - return (*pgInterceptor)(bi).ForwardMsg() -} - -// Close closes the interceptor, and prevents further operations on it. -func (bi *BackendInterceptor) Close() { - (*pgInterceptor)(bi).Close() +func (bi *BackendInterceptor) ForwardMsg(dst io.Writer) (n int, err error) { + return (*pgInterceptor)(bi).ForwardMsg(dst) } diff --git a/pkg/ccl/sqlproxyccl/interceptor/backend_interceptor_test.go b/pkg/ccl/sqlproxyccl/interceptor/backend_interceptor_test.go index 0e442b9f1e88..50fba567d0a0 100644 --- a/pkg/ccl/sqlproxyccl/interceptor/backend_interceptor_test.go +++ b/pkg/ccl/sqlproxyccl/interceptor/backend_interceptor_test.go @@ -27,59 +27,30 @@ func TestBackendInterceptor(t *testing.T) { q := (&pgproto3.Query{String: "SELECT 1"}).Encode(nil) - t.Run("bufSize too small", func(t *testing.T) { - bi, err := interceptor.NewBackendInterceptor(nil /* src */, nil /* dst */, 1) - require.Error(t, err) - require.Nil(t, bi) - }) - - t.Run("PeekMsg returns the right message type", func(t *testing.T) { + buildSrc := func(t *testing.T, count int) *bytes.Buffer { + t.Helper() src := new(bytes.Buffer) _, err := src.Write(q) require.NoError(t, err) + return src + } - bi, err := interceptor.NewBackendInterceptor(src, nil /* dst */, 16) - require.NoError(t, err) + t.Run("PeekMsg returns the right message type", func(t *testing.T) { + src := buildSrc(t, 1) + + bi := interceptor.NewBackendInterceptor(src) require.NotNil(t, bi) typ, size, err := bi.PeekMsg() require.NoError(t, err) require.Equal(t, pgwirebase.ClientMsgSimpleQuery, typ) require.Equal(t, 14, size) - - bi.Close() - typ, size, err = bi.PeekMsg() - require.EqualError(t, err, interceptor.ErrInterceptorClosed.Error()) - require.Equal(t, pgwirebase.ClientMessageType(0), typ) - require.Equal(t, 0, size) - }) - - t.Run("WriteMsg writes data to dst", func(t *testing.T) { - dst := new(bytes.Buffer) - bi, err := interceptor.NewBackendInterceptor(nil /* src */, dst, 10) - require.NoError(t, err) - require.NotNil(t, bi) - - // This is a backend interceptor, so writing goes to the server. - toSend := &pgproto3.Query{String: "SELECT 1"} - n, err := bi.WriteMsg(toSend) - require.NoError(t, err) - require.Equal(t, 14, n) - require.Equal(t, 14, dst.Len()) - - bi.Close() - n, err = bi.WriteMsg(toSend) - require.EqualError(t, err, interceptor.ErrInterceptorClosed.Error()) - require.Equal(t, 0, n) }) t.Run("ReadMsg decodes the message correctly", func(t *testing.T) { - src := new(bytes.Buffer) - _, err := src.Write(q) - require.NoError(t, err) + src := buildSrc(t, 1) - bi, err := interceptor.NewBackendInterceptor(src, nil /* dst */, 16) - require.NoError(t, err) + bi := interceptor.NewBackendInterceptor(src) require.NotNil(t, bi) msg, err := bi.ReadMsg() @@ -87,31 +58,18 @@ func TestBackendInterceptor(t *testing.T) { rmsg, ok := msg.(*pgproto3.Query) require.True(t, ok) require.Equal(t, "SELECT 1", rmsg.String) - - bi.Close() - msg, err = bi.ReadMsg() - require.EqualError(t, err, interceptor.ErrInterceptorClosed.Error()) - require.Nil(t, msg) }) t.Run("ForwardMsg forwards data to dst", func(t *testing.T) { - src := new(bytes.Buffer) - _, err := src.Write(q) - require.NoError(t, err) + src := buildSrc(t, 1) dst := new(bytes.Buffer) - bi, err := interceptor.NewBackendInterceptor(src, dst, 16) - require.NoError(t, err) + bi := interceptor.NewBackendInterceptor(src) require.NotNil(t, bi) - n, err := bi.ForwardMsg() + n, err := bi.ForwardMsg(dst) require.NoError(t, err) require.Equal(t, 14, n) require.Equal(t, 14, dst.Len()) - - bi.Close() - n, err = bi.ForwardMsg() - require.EqualError(t, err, interceptor.ErrInterceptorClosed.Error()) - require.Equal(t, 0, n) }) } diff --git a/pkg/ccl/sqlproxyccl/interceptor/base.go b/pkg/ccl/sqlproxyccl/interceptor/base.go index c933b270ed2e..c59945cb0b3c 100644 --- a/pkg/ccl/sqlproxyccl/interceptor/base.go +++ b/pkg/ccl/sqlproxyccl/interceptor/base.go @@ -22,14 +22,11 @@ import ( // length itself). const pgHeaderSizeBytes = 5 -// ErrSmallBuffer indicates that the requested buffer for the interceptor is -// too small. -var ErrSmallBuffer = errors.New("buffer is too small") - -// ErrInterceptorClosed is the returned error whenever the intercept is closed. -// When this happens, the caller should terminate both dst and src to guarantee -// correctness. -var ErrInterceptorClosed = errors.New("interceptor is closed") +// defaultBufferSize is the default buffer size for the interceptor. 8K was +// chosen to match Postgres' send and receive buffer sizes. +// +// See: https://github.com/postgres/postgres/blob/249d64999615802752940e017ee5166e726bc7cd/src/backend/libpq/pqcomm.c#L134-L135. +const defaultBufferSize = 2 << 13 // 8K // ErrProtocolError indicates that the packets are malformed, and are not as // expected. @@ -43,7 +40,6 @@ type pgInterceptor struct { _ util.NoCopy src io.Reader - dst io.Writer // buf stores buffered bytes from src. This may contain one or more pgwire // messages, and messages may be partially buffered. @@ -61,27 +57,21 @@ type pgInterceptor struct { // was partially buffered, the interceptor will handle that case before // resetting readPos and writePos to 0. readPos, writePos int - - // closed indicates that the interceptor is closed. This will be set to - // true whenever there's an error within one of the interceptor's operations - // leading to an ambiguity. Once an interceptor is closed, all subsequent - // method calls on the interceptor will return ErrInterceptorClosed. - closed bool } // newPgInterceptor creates a new instance of the interceptor with an internal -// buffer of bufSize bytes. bufSize must be at least the size of a pgwire -// message header. -func newPgInterceptor(src io.Reader, dst io.Writer, bufSize int) (*pgInterceptor, error) { - // The internal buffer must be able to fit the header. +// buffer of bufSize bytes. If bufSize is smaller than 5 bytes, the interceptor +// will default to an 8K buffer size. +func newPgInterceptor(src io.Reader, bufSize int) *pgInterceptor { + // The internal buffer must be able to fit the header. If bufSize is smaller + // than 5 bytes, just default to 8K, or else the interceptor is unusable. if bufSize < pgHeaderSizeBytes { - return nil, ErrSmallBuffer + bufSize = defaultBufferSize } return &pgInterceptor{ src: src, - dst: dst, buf: make([]byte, bufSize), - }, nil + } } // PeekMsg returns the header of the current pgwire message without advancing @@ -90,12 +80,11 @@ func newPgInterceptor(src io.Reader, dst io.Writer, bufSize int) (*pgInterceptor // includes the header type and body length. This will return ErrProtocolError // if the packets are malformed. // -// If the interceptor is closed, PeekMsg returns ErrInterceptorClosed. +// If err != nil, we are safe to reuse the interceptor. In the case of +// ErrProtocolError, the interceptor is still usable, though calls to ReadMsg +// and ForwardMsg will return an error. The bytes are still in the buffer, so +// the only way is to abandon the interceptor. func (p *pgInterceptor) PeekMsg() (typ byte, size int, err error) { - if p.closed { - return 0, 0, ErrInterceptorClosed - } - if err := p.ensureNextNBytes(pgHeaderSizeBytes); err != nil { // Possibly due to a timeout or context cancellation. return 0, 0, err @@ -116,43 +105,23 @@ func (p *pgInterceptor) PeekMsg() (typ byte, size int, err error) { return typ, size + 1, nil } -// WriteMsg writes the given bytes to the writer dst. If err != nil and a Write -// was attempted, the interceptor will be closed. -// -// If the interceptor is closed, WriteMsg returns ErrInterceptorClosed. -func (p *pgInterceptor) WriteMsg(data []byte) (n int, err error) { - if p.closed { - return 0, ErrInterceptorClosed - } - defer func() { - // Close the interceptor if there was an error. Theoretically, we only - // need to close here if n > 0, but for consistency with the other - // methods, we will do that here too. - if err != nil { - p.Close() - } - }() - return p.dst.Write(data) -} - // ReadMsg returns the current pgwire message in bytes. It also advances the // interceptor to the next message. On return, the msg field is valid if and -// only if err == nil. If err != nil and a Read was attempted because the buffer -// did not have enough bytes, the interceptor will be closed. +// only if err == nil. // // The interceptor retains ownership of all the memory returned by ReadMsg; the // caller is allowed to hold on to this memory *until* the next moment other // methods on the interceptor are called. The data will only be valid until then -// as well. +// as well. This may allocate if the message does not fit into the internal +// buffer, so use with care. If we are using this with the intention of sending +// it to another connection, we should use ForwardMsg, which does not allocate. // -// If the interceptor is closed, ReadMsg returns ErrInterceptorClosed. +// WARNING: If err != nil, the caller should abandon the interceptor, as we may +// be in a corrupted state. This invokes PeekMsg under the hood to know the +// message length. One optimization that could be done is to invoke PeekMsg +// manually first before calling this to ensure that we do not return errors +// when peeking during ReadMsg. func (p *pgInterceptor) ReadMsg() (msg []byte, err error) { - // Technically this is redundant since PeekMsg will do the same thing, but - // we do so here for clarity. - if p.closed { - return nil, ErrInterceptorClosed - } - // Peek header of the current message for message size. _, size, err := p.PeekMsg() if err != nil { @@ -179,19 +148,6 @@ func (p *pgInterceptor) ReadMsg() (msg []byte, err error) { n := copy(msg, p.buf[p.readPos:p.writePos]) p.readPos += n - defer func() { - // Close the interceptor because we read the data (both buffered and - // possibly newer ones) into msg, which is larger than the buffer's - // size, and there's no easy way to recover. We could technically fix - // some of the situations, especially when no bytes were read, but at - // this point, it's likely that the one end of the interceptor is - // already gone, or the proxy is shutting down, so there's no point - // trying to save a disconnect. - if err != nil { - p.Close() - } - }() - // Read more bytes. if _, err := io.ReadFull(p.src, msg[n:]); err != nil { return nil, err @@ -200,20 +156,17 @@ func (p *pgInterceptor) ReadMsg() (msg []byte, err error) { return msg, nil } -// ForwardMsg sends the current pgwire message to the destination, and advances -// the interceptor to the next message. On return, n == pgwire message size if -// and only if err == nil. If err != nil and a Write was attempted, the -// interceptor will be closed. +// ForwardMsg sends the current pgwire message to dst, and advances the +// interceptor to the next message. On return, n == pgwire message size if +// and only if err == nil. // -// If the interceptor is closed, ForwardMsg returns ErrInterceptorClosed. -func (p *pgInterceptor) ForwardMsg() (n int, err error) { - // Technically this is redundant since PeekMsg will do the same thing, but - // we do so here for clarity. - if p.closed { - return 0, ErrInterceptorClosed - } - - // Retrieve header of the current message for message size. +// WARNING: If err != nil, the caller should abandon the interceptor or dst, as +// we may be in a corrupted state. This invokes PeekMsg under the hood to know +// the message length. One optimization that could be done is to invoke PeekMsg +// manually first before calling this to ensure that we do not return errors +// when peeking during ForwardMsg. +func (p *pgInterceptor) ForwardMsg(dst io.Writer) (n int, err error) { + // Retrieve header of the current message for body size. _, size, err := p.PeekMsg() if err != nil { return 0, err @@ -229,21 +182,8 @@ func (p *pgInterceptor) ForwardMsg() (n int, err error) { } p.readPos = endPos - defer func() { - // State may be invalid depending on whether bytes have been written. - // To reduce complexity, we'll just close the interceptor, and the - // caller should just terminate both ends. - // - // If src has been closed, the dst state may be invalid. If dst has been - // closed, buffered bytes no longer represent the protocol correctly - // even if we slurped the remaining bytes for the current message. - if err != nil { - p.Close() - } - }() - // Forward the message to the destination. - n, err = p.dst.Write(p.buf[startPos:endPos]) + n, err = dst.Write(p.buf[startPos:endPos]) if err != nil { return n, err } @@ -256,7 +196,7 @@ func (p *pgInterceptor) ForwardMsg() (n int, err error) { // Message was partially buffered, so copy the remaining. if remainingBytes > 0 { - m, err := io.CopyN(p.dst, p.src, int64(remainingBytes)) + m, err := io.CopyN(dst, p.src, int64(remainingBytes)) n += int(m) if err != nil { return n, err @@ -271,27 +211,14 @@ func (p *pgInterceptor) ForwardMsg() (n int, err error) { return n, nil } -// Close closes the interceptor, and prevents further operations on it. -func (p *pgInterceptor) Close() { - p.closed = true -} - -// readSize returns the number of bytes read by the interceptor. If the -// interceptor is closed, this will return 0. +// readSize returns the number of bytes read by the interceptor. func (p *pgInterceptor) readSize() int { - if p.closed { - return 0 - } return p.writePos - p.readPos } // writeSize returns the remaining number of bytes that could fit into the -// internal buffer before needing to be re-aligned. If the interceptor is -// closed, this will return 0. +// internal buffer before needing to be re-aligned. func (p *pgInterceptor) writeSize() int { - if p.closed { - return 0 - } return len(p.buf) - p.writePos } @@ -319,12 +246,14 @@ func (p *pgInterceptor) ensureNextNBytes(n int) error { return err } -var _ io.Writer = &errPanicWriter{} +var _ io.Writer = &errWriter{} -// errPanicWriter is an io.Writer that panics whenever a Write call is made. -type errPanicWriter struct{} +// errWriter is an io.Writer that fails whenever a Write call is made. This is +// used within ReadMsg for both BackendInterceptor and FrontendInterceptor. +// Since it's just a Read, Write calls should not be made. +type errWriter struct{} // Write implements the io.Writer interface. -func (w *errPanicWriter) Write(p []byte) (int, error) { - panic("unexpected Write call") +func (w *errWriter) Write(p []byte) (int, error) { + return 0, errors.AssertionFailedf("unexpected Write call") } diff --git a/pkg/ccl/sqlproxyccl/interceptor/base_test.go b/pkg/ccl/sqlproxyccl/interceptor/base_test.go index 295d5b87861c..aca49e46f97b 100644 --- a/pkg/ccl/sqlproxyccl/interceptor/base_test.go +++ b/pkg/ccl/sqlproxyccl/interceptor/base_test.go @@ -27,52 +27,31 @@ import ( func TestNewPgInterceptor(t *testing.T) { defer leaktest.AfterTest(t)() - reader, writer := io.Pipe() - - // Negative buffer size. - pgi, err := newPgInterceptor(reader, writer, -1) - require.EqualError(t, err, ErrSmallBuffer.Error()) - require.Nil(t, pgi) - - // Small buffer size. - pgi, err = newPgInterceptor(reader, writer, pgHeaderSizeBytes-1) - require.EqualError(t, err, ErrSmallBuffer.Error()) - require.Nil(t, pgi) - - // Buffer that fits the header exactly. - pgi, err = newPgInterceptor(reader, writer, pgHeaderSizeBytes) - require.NoError(t, err) - require.NotNil(t, pgi) - require.Len(t, pgi.buf, pgHeaderSizeBytes) - - // Normal buffer size. - pgi, err = newPgInterceptor(reader, writer, 1024) - require.NoError(t, err) - require.NotNil(t, pgi) - require.Len(t, pgi.buf, 1024) - require.Equal(t, reader, pgi.src) - require.Equal(t, writer, pgi.dst) + reader, _ := io.Pipe() + + for _, tc := range []struct { + bufSize int + normalizedBufSize int + }{ + {-1, defaultBufferSize}, + {pgHeaderSizeBytes - 1, defaultBufferSize}, + {pgHeaderSizeBytes, pgHeaderSizeBytes}, + {1024, 1024}, + } { + pgi := newPgInterceptor(reader, tc.bufSize) + require.NotNil(t, pgi) + require.Len(t, pgi.buf, tc.normalizedBufSize) + require.Equal(t, reader, pgi.src) + } } func TestPGInterceptor_PeekMsg(t *testing.T) { defer leaktest.AfterTest(t)() - t.Run("interceptor is closed", func(t *testing.T) { - pgi, err := newPgInterceptor(nil /* src */, nil /* dst */, 10) - require.NoError(t, err) - pgi.Close() - - typ, size, err := pgi.PeekMsg() - require.EqualError(t, err, ErrInterceptorClosed.Error()) - require.Equal(t, byte(0), typ) - require.Equal(t, 0, size) - }) - - t.Run("read error", func(t *testing.T) { + t.Run("read_error", func(t *testing.T) { r := iotest.ErrReader(errors.New("read error")) - pgi, err := newPgInterceptor(r, nil /* dst */, 10) - require.NoError(t, err) + pgi := newPgInterceptor(r, 10 /* bufSize */) typ, size, err := pgi.PeekMsg() require.EqualError(t, err, "read error") @@ -80,15 +59,14 @@ func TestPGInterceptor_PeekMsg(t *testing.T) { require.Equal(t, 0, size) }) - t.Run("protocol error/size=0", func(t *testing.T) { + t.Run("protocol_error/length=0", func(t *testing.T) { var data [10]byte buf := new(bytes.Buffer) _, err := buf.Write(data[:]) require.NoError(t, err) - pgi, err := newPgInterceptor(buf, nil /* dst */, 10) - require.NoError(t, err) + pgi := newPgInterceptor(buf, 10 /* bufSize */) typ, size, err := pgi.PeekMsg() require.EqualError(t, err, ErrProtocolError.Error()) @@ -96,7 +74,7 @@ func TestPGInterceptor_PeekMsg(t *testing.T) { require.Equal(t, 0, size) }) - t.Run("protocol error/size=3", func(t *testing.T) { + t.Run("protocol_error/length=3", func(t *testing.T) { var data [5]byte binary.BigEndian.PutUint32(data[1:5], uint32(3)) @@ -104,8 +82,7 @@ func TestPGInterceptor_PeekMsg(t *testing.T) { _, err := buf.Write(data[:]) require.NoError(t, err) - pgi, err := newPgInterceptor(buf, nil /* dst */, 10) - require.NoError(t, err) + pgi := newPgInterceptor(buf, 10 /* bufSize */) typ, size, err := pgi.PeekMsg() require.EqualError(t, err, ErrProtocolError.Error()) @@ -113,7 +90,7 @@ func TestPGInterceptor_PeekMsg(t *testing.T) { require.Equal(t, 0, size) }) - t.Run("protocol error/size=math.MaxInt32", func(t *testing.T) { + t.Run("protocol_error/length=math.MaxInt32", func(t *testing.T) { var data [5]byte binary.BigEndian.PutUint32(data[1:5], uint32(math.MaxInt32)) @@ -121,8 +98,7 @@ func TestPGInterceptor_PeekMsg(t *testing.T) { _, err := buf.Write(data[:]) require.NoError(t, err) - pgi, err := newPgInterceptor(buf, nil /* dst */, 10) - require.NoError(t, err) + pgi := newPgInterceptor(buf, 10 /* bufSize */) typ, size, err := pgi.PeekMsg() require.EqualError(t, err, ErrProtocolError.Error()) @@ -130,8 +106,8 @@ func TestPGInterceptor_PeekMsg(t *testing.T) { require.Equal(t, 0, size) }) - t.Run("successful without body", func(t *testing.T) { - // Use 4 bytes to indicate no body. + t.Run("successful/length=4", func(t *testing.T) { + // Only write 5 bytes (without body) var data [5]byte data[0] = 'A' binary.BigEndian.PutUint32(data[1:5], uint32(4)) @@ -140,8 +116,7 @@ func TestPGInterceptor_PeekMsg(t *testing.T) { _, err := buf.Write(data[:]) require.NoError(t, err) - pgi, err := newPgInterceptor(buf, nil /* dst */, 10) - require.NoError(t, err) + pgi := newPgInterceptor(buf, 5 /* bufSize */) typ, size, err := pgi.PeekMsg() require.NoError(t, err) @@ -151,105 +126,76 @@ func TestPGInterceptor_PeekMsg(t *testing.T) { }) t.Run("successful", func(t *testing.T) { - buf := new(bytes.Buffer) - msgBytes := (&pgproto3.Query{String: "SELECT 1"}).Encode(nil) - _, err := buf.Write(msgBytes) - require.NoError(t, err) + buf := buildSrc(t, 1) - pgi, err := newPgInterceptor(buf, nil /* dst */, 10) - require.NoError(t, err) + pgi := newPgInterceptor(buf, 10 /* bufSize */) typ, size, err := pgi.PeekMsg() require.NoError(t, err) require.Equal(t, byte(pgwirebase.ClientMsgSimpleQuery), typ) - require.Equal(t, len(msgBytes), size) + require.Equal(t, len(testSelect1Bytes), size) require.Equal(t, 4, buf.Len()) // Invoking Peek should not advance the interceptor. typ, size, err = pgi.PeekMsg() require.NoError(t, err) require.Equal(t, byte(pgwirebase.ClientMsgSimpleQuery), typ) - require.Equal(t, len(msgBytes), size) + require.Equal(t, len(testSelect1Bytes), size) require.Equal(t, 4, buf.Len()) }) } -func TestPGInterceptor_WriteMsg(t *testing.T) { +func TestPGInterceptor_ReadMsg(t *testing.T) { defer leaktest.AfterTest(t)() - t.Run("interceptor is closed", func(t *testing.T) { - pgi, err := newPgInterceptor(nil /* src */, nil /* dst */, 10) - require.NoError(t, err) - pgi.Close() + t.Run("read_error/msg_fits", func(t *testing.T) { + buf := buildSrc(t, 1) - n, err := pgi.WriteMsg([]byte{}) - require.EqualError(t, err, ErrInterceptorClosed.Error()) - require.Equal(t, 0, n) - }) + // Use a LimitReader to allow PeekMsg to read 5 bytes, then update src + // back to the original version. + src := &errReadWriter{r: buf, count: 2} + pgi := newPgInterceptor(io.LimitReader(src, 5), 32 /* bufSize */) - t.Run("write error", func(t *testing.T) { - pgi, err := newPgInterceptor(nil /* src */, &errReadWriter{w: io.Discard}, 10) + // Call PeekMsg here to populate internal buffer with header. + typ, size, err := pgi.PeekMsg() require.NoError(t, err) + require.Equal(t, byte(pgwirebase.ClientMsgSimpleQuery), typ) + require.Equal(t, len(testSelect1Bytes), size) + require.Equal(t, 9, buf.Len()) + + // Update src back to the non LimitReader version. + pgi.src = src - n, err := pgi.WriteMsg([]byte{}) + // Now call ReadMsg. + msg, err := pgi.ReadMsg() require.EqualError(t, err, io.ErrClosedPipe.Error()) - require.Equal(t, 0, n) - require.True(t, pgi.closed) + require.Nil(t, msg) + require.Equal(t, 9, buf.Len()) }) - t.Run("successful", func(t *testing.T) { - buf := new(bytes.Buffer) - pgi, err := newPgInterceptor(nil /* src */, buf, 10) - require.NoError(t, err) - - n, err := pgi.WriteMsg([]byte("hello")) - require.NoError(t, err) - require.Equal(t, 5, n) - require.False(t, pgi.closed) - require.Equal(t, "hello", buf.String()) - }) -} + // When we overflow, ReadMsg will allocate. + t.Run("read_error/msg_overflows", func(t *testing.T) { + buf := buildSrc(t, 1) -func TestPGInterceptor_ReadMsg(t *testing.T) { - defer leaktest.AfterTest(t)() - - t.Run("interceptor is closed", func(t *testing.T) { - pgi, err := newPgInterceptor(nil /* src */, nil /* dst */, 10) - require.NoError(t, err) - pgi.Close() + // testSelect1Bytes has 14 bytes, but only 6 bytes within internal + // buffer, so overflow. + src := &errReadWriter{r: buf, count: 2} + pgi := newPgInterceptor(src, 6 /* bufSize */) msg, err := pgi.ReadMsg() - require.EqualError(t, err, ErrInterceptorClosed.Error()) + require.EqualError(t, err, io.ErrClosedPipe.Error()) require.Nil(t, msg) + require.Equal(t, 8, buf.Len()) }) - q := (&pgproto3.Query{String: "SELECT 1"}).Encode(nil) - - buildSrc := func(t *testing.T, count int) *bytes.Buffer { - t.Helper() - src := new(bytes.Buffer) - for i := 0; i < count; i++ { - // Alternate between SELECT 1 and 2 to ensure correctness. - if i%2 == 0 { - q[12] = '1' - } else { - q[12] = '2' - } - _, err := src.Write(q) - require.NoError(t, err) - } - return src - } - - t.Run("message fits", func(t *testing.T) { + t.Run("successful/msg_fits", func(t *testing.T) { const count = 101 // Inclusive of warm-up run in AllocsPerRun. buf := buildSrc(t, count) // Set buffer's size to be a multiple of the message so that we'll // always hit the case where the message fits. - pgi, err := newPgInterceptor(buf, nil /* dst */, len(q)*3) - require.NoError(t, err) + pgi := newPgInterceptor(buf, len(testSelect1Bytes)*3) c := 0 n := testing.AllocsPerRun(count-1, func() { @@ -275,15 +221,15 @@ func TestPGInterceptor_ReadMsg(t *testing.T) { require.Equal(t, 0, buf.Len()) }) - t.Run("message overflows", func(t *testing.T) { + // When we overflow, ReadMsg will allocate. + t.Run("successful/msg_overflows", func(t *testing.T) { const count = 101 // Inclusive of warm-up run in AllocsPerRun. buf := buildSrc(t, count) // Set the buffer to be large enough to fit more bytes than the header, // but not the entire message. - pgi, err := newPgInterceptor(buf, nil /* dst */, 7) - require.NoError(t, err) + pgi := newPgInterceptor(buf, 7 /* bufSize */) c := 0 n := testing.AllocsPerRun(count-1, func() { @@ -312,75 +258,45 @@ func TestPGInterceptor_ReadMsg(t *testing.T) { require.Equal(t, float64(1), n) require.Equal(t, 0, buf.Len()) }) - - t.Run("read error after allocate", func(t *testing.T) { - q := (&pgproto3.Query{String: "SELECT 1"}).Encode(nil) - buf := new(bytes.Buffer) - _, err := buf.Write(q) - require.NoError(t, err) - - src := &errReadWriter{r: buf, count: 2} - pgi, err := newPgInterceptor(src, nil /* dst */, 6) - require.NoError(t, err) - - msg, err := pgi.ReadMsg() - require.EqualError(t, err, io.ErrClosedPipe.Error()) - require.Nil(t, msg) - - // Ensure that interceptor is closed. - require.True(t, pgi.closed) - require.Equal(t, 8, buf.Len()) - }) } func TestPGInterceptor_ForwardMsg(t *testing.T) { defer leaktest.AfterTest(t)() - t.Run("interceptor is closed", func(t *testing.T) { - pgi, err := newPgInterceptor(nil /* src */, nil /* dst */, 10) - require.NoError(t, err) - pgi.Close() + t.Run("write_error/fully_buffered", func(t *testing.T) { + src := buildSrc(t, 1) + dst := new(bytes.Buffer) + dstWriter := &errReadWriter{w: dst, count: 1} - n, err := pgi.ForwardMsg() - require.EqualError(t, err, ErrInterceptorClosed.Error()) + pgi := newPgInterceptor(src, 32 /* bufSize */) + + n, err := pgi.ForwardMsg(dstWriter) + require.EqualError(t, err, io.ErrClosedPipe.Error()) require.Equal(t, 0, n) + + // Managed to read everything, but could not write to dst. + require.Equal(t, 0, src.Len()) + require.Equal(t, 0, dst.Len()) + require.Equal(t, 0, pgi.readSize()) }) - q := (&pgproto3.Query{String: "SELECT 1"}).Encode(nil) - - buildSrc := func(t *testing.T, count int) *bytes.Buffer { - t.Helper() - src := new(bytes.Buffer) - for i := 0; i < count; i++ { - // Alternate between SELECT 1 and 2 to ensure correctness. - if i%2 == 0 { - q[12] = '1' - } else { - q[12] = '2' - } - _, err := src.Write(q) - require.NoError(t, err) - } - return src - } + t.Run("write_error/partially_buffered", func(t *testing.T) { + src := buildSrc(t, 1) + dst := new(bytes.Buffer) + dstWriter := &errReadWriter{w: dst, count: 2} - validateDst := func(t *testing.T, dst io.Reader, count int) { - t.Helper() - backend := pgproto3.NewBackend(pgproto3.NewChunkReader(dst), nil /* w */) - for i := 0; i < count; i++ { - msg, err := backend.Receive() - require.NoError(t, err) - q := msg.(*pgproto3.Query) + // testSelect1Bytes has 14 bytes, but only 6 bytes within internal + // buffer, so partially buffered. + pgi := newPgInterceptor(src, 6 /* bufSize */) - expectedStr := "SELECT 1" - if i%2 == 1 { - expectedStr = "SELECT 2" - } - require.Equal(t, expectedStr, q.String) - } - } + n, err := pgi.ForwardMsg(dstWriter) + require.EqualError(t, err, io.ErrClosedPipe.Error()) + require.Equal(t, 6, n) - t.Run("message fits", func(t *testing.T) { + require.Equal(t, 6, dst.Len()) + }) + + t.Run("successful/fully_buffered", func(t *testing.T) { const count = 101 // Inclusive of warm-up run in AllocsPerRun. src := buildSrc(t, count) @@ -388,12 +304,11 @@ func TestPGInterceptor_ForwardMsg(t *testing.T) { // Set buffer's size to be a multiple of the message so that we'll // always hit the case where the message fits. - pgi, err := newPgInterceptor(src, dst, len(q)*3) - require.NoError(t, err) + pgi := newPgInterceptor(src, len(testSelect1Bytes)*3) // Forward all the messages, and ensure 0 allocations. n := testing.AllocsPerRun(count-1, func() { - n, err := pgi.ForwardMsg() + n, err := pgi.ForwardMsg(dst) require.NoError(t, err) require.Equal(t, 14, n) }) @@ -405,7 +320,7 @@ func TestPGInterceptor_ForwardMsg(t *testing.T) { require.Equal(t, 0, dst.Len()) }) - t.Run("message overflows", func(t *testing.T) { + t.Run("successful/partially_buffered", func(t *testing.T) { const count = 151 // Inclusive of warm-up run in AllocsPerRun. src := buildSrc(t, count) @@ -413,11 +328,10 @@ func TestPGInterceptor_ForwardMsg(t *testing.T) { // Set the buffer to be large enough to fit more bytes than the header, // but not the entire message. - pgi, err := newPgInterceptor(src, dst, 7) - require.NoError(t, err) + pgi := newPgInterceptor(src, 7 /* bufSize */) n := testing.AllocsPerRun(count-1, func() { - n, err := pgi.ForwardMsg() + n, err := pgi.ForwardMsg(dst) require.NoError(t, err) require.Equal(t, 14, n) }) @@ -434,105 +348,49 @@ func TestPGInterceptor_ForwardMsg(t *testing.T) { validateDst(t, dst, count) require.Equal(t, 0, dst.Len()) }) - - t.Run("write error", func(t *testing.T) { - q := (&pgproto3.Query{String: "SELECT 1"}).Encode(nil) - src := new(bytes.Buffer) - _, err := src.Write(q) - require.NoError(t, err) - dst := new(bytes.Buffer) - - pgi, err := newPgInterceptor(src, &errReadWriter{w: dst, count: 2}, 6) - require.NoError(t, err) - - n, err := pgi.ForwardMsg() - require.EqualError(t, err, io.ErrClosedPipe.Error()) - require.Equal(t, 6, n) - - // Ensure that interceptor is closed. - require.True(t, pgi.closed) - require.Equal(t, 6, dst.Len()) - }) } -func TestPGInterceptor_Close(t *testing.T) { +func TestPGInterceptor_readSize(t *testing.T) { defer leaktest.AfterTest(t)() - pgi, err := newPgInterceptor(nil /* src */, nil /* dst */, 10) - require.NoError(t, err) - require.False(t, pgi.closed) - pgi.Close() - require.True(t, pgi.closed) -} - -func TestPGInterceptor_ReadSize(t *testing.T) { - defer leaktest.AfterTest(t)() - - t.Run("interceptor is closed", func(t *testing.T) { - buf := bytes.NewBufferString("foobarbaz") - pgi, err := newPgInterceptor(buf, nil /* dst */, 9) - require.NoError(t, err) - require.NoError(t, pgi.ensureNextNBytes(1)) + buf := bytes.NewBufferString("foobarbazz") + pgi := newPgInterceptor(iotest.OneByteReader(buf), 10 /* bufSize */) - require.Equal(t, 9, pgi.readSize()) - pgi.Close() - require.Equal(t, 0, pgi.readSize()) - }) + // No reads to internal buffer. + require.Equal(t, 0, pgi.readSize()) - t.Run("valid", func(t *testing.T) { - buf := bytes.NewBufferString("foobarbazz") - pgi, err := newPgInterceptor(iotest.OneByteReader(buf), nil /* dst */, 10) - require.NoError(t, err) - - // No reads to internal buffer. - require.Equal(t, 0, pgi.readSize()) + // Attempt reads to buffer. + require.NoError(t, pgi.ensureNextNBytes(3)) + require.Equal(t, 3, pgi.readSize()) - // Attempt reads to buffer. - require.NoError(t, pgi.ensureNextNBytes(3)) - require.Equal(t, 3, pgi.readSize()) - - // Read until buffer is full. - require.NoError(t, pgi.ensureNextNBytes(10)) - require.Equal(t, 10, pgi.readSize()) - }) + // Read until buffer is full. + require.NoError(t, pgi.ensureNextNBytes(10)) + require.Equal(t, 10, pgi.readSize()) } -func TestPGInterceptor_WriteSize(t *testing.T) { +func TestPGInterceptor_writeSize(t *testing.T) { defer leaktest.AfterTest(t)() - t.Run("interceptor is closed", func(t *testing.T) { - pgi, err := newPgInterceptor(nil /* src */, nil /* dst */, 9) - require.NoError(t, err) - - require.Equal(t, 9, pgi.writeSize()) - pgi.Close() - require.Equal(t, 0, pgi.writeSize()) - }) - - t.Run("valid", func(t *testing.T) { - buf := bytes.NewBufferString("foobarbazz") - pgi, err := newPgInterceptor(iotest.OneByteReader(buf), nil /* dst */, 10) - require.NoError(t, err) + buf := bytes.NewBufferString("foobarbazz") + pgi := newPgInterceptor(iotest.OneByteReader(buf), 10 /* bufSize */) - // No writes to internal buffer. - require.Equal(t, 10, pgi.writeSize()) + // No writes to internal buffer. + require.Equal(t, 10, pgi.writeSize()) - // Attempt writes to buffer. - require.NoError(t, pgi.ensureNextNBytes(3)) - require.Equal(t, 7, pgi.writeSize()) + // Attempt writes to buffer. + require.NoError(t, pgi.ensureNextNBytes(3)) + require.Equal(t, 7, pgi.writeSize()) - // Attempt more writes to buffer until full. - require.NoError(t, pgi.ensureNextNBytes(10)) - require.Equal(t, 0, pgi.writeSize()) - }) + // Attempt more writes to buffer until full. + require.NoError(t, pgi.ensureNextNBytes(10)) + require.Equal(t, 0, pgi.writeSize()) } func TestPGInterceptor_ensureNextNBytes(t *testing.T) { defer leaktest.AfterTest(t)() t.Run("invalid n", func(t *testing.T) { - pgi, err := newPgInterceptor(nil /* src */, nil /* dst */, 8) - require.NoError(t, err) + pgi := newPgInterceptor(nil /* src */, 8 /* bufSize */) require.EqualError(t, pgi.ensureNextNBytes(-1), "invalid number of bytes -1 for buffer size 8") @@ -543,8 +401,7 @@ func TestPGInterceptor_ensureNextNBytes(t *testing.T) { t.Run("buffer already has n bytes", func(t *testing.T) { buf := bytes.NewBufferString("foobarbaz") - pgi, err := newPgInterceptor(iotest.OneByteReader(buf), nil /* dst */, 8) - require.NoError(t, err) + pgi := newPgInterceptor(iotest.OneByteReader(buf), 8 /* bufSize */) // Read "foo" into buffer". require.NoError(t, pgi.ensureNextNBytes(3)) @@ -565,8 +422,7 @@ func TestPGInterceptor_ensureNextNBytes(t *testing.T) { t.Run("bytes are realigned", func(t *testing.T) { buf := bytes.NewBufferString("foobarbazcar") - pgi, err := newPgInterceptor(iotest.OneByteReader(buf), nil /* dst */, 9) - require.NoError(t, err) + pgi := newPgInterceptor(iotest.OneByteReader(buf), 9 /* bufSize */) // Read "foobarb" into buffer. require.NoError(t, pgi.ensureNextNBytes(7)) @@ -587,8 +443,7 @@ func TestPGInterceptor_ensureNextNBytes(t *testing.T) { // if there was a Read call. buf := bytes.NewBufferString("foobarbaz") - pgi, err := newPgInterceptor(buf, nil /* dst */, 10) - require.NoError(t, err) + pgi := newPgInterceptor(buf, 10 /* bufSize */) // Request for only 1 byte. require.NoError(t, pgi.ensureNextNBytes(1)) @@ -597,7 +452,7 @@ func TestPGInterceptor_ensureNextNBytes(t *testing.T) { require.Equal(t, "foobarbaz", string(pgi.buf[pgi.readPos:pgi.writePos])) // Should be a no-op. - _, err = buf.WriteString("car") + _, err := buf.WriteString("car") require.NoError(t, err) require.NoError(t, pgi.ensureNextNBytes(9)) require.Equal(t, 3, buf.Len()) @@ -632,3 +487,51 @@ func (rw *errReadWriter) Write(p []byte) (int, error) { } return rw.w.Write(p) } + +// testSelect1Bytes represents the bytes for a SELECT 1 query. This will always +// be 14 bytes (5 (header) + 8 (query) + 1 (null terminator)). +var testSelect1Bytes = (&pgproto3.Query{String: "SELECT 1"}).Encode(nil) + +// buildSrc generates a buffer with count test queries which alternates between +// SELECT 1 and SELECT 2. +func buildSrc(t *testing.T, count int) *bytes.Buffer { + t.Helper() + + // Reset bytes back to SELECT 1. + defer func() { + testSelect1Bytes[12] = '1' + }() + + // Generate buffer. + src := new(bytes.Buffer) + for i := 0; i < count; i++ { + // Alternate between SELECT 1 and 2 to ensure correctness. + if i%2 == 0 { + testSelect1Bytes[12] = '1' + } else { + testSelect1Bytes[12] = '2' + } + _, err := src.Write(testSelect1Bytes) + require.NoError(t, err) + } + return src +} + +// validateDst ensures that we have the right sequence of test queries in dst. +// There should be count queries that alternate between SELECT 1 and SELECT 2. +// Use buildSrc to generate the sender's buffer. +func validateDst(t *testing.T, dst io.Reader, count int) { + t.Helper() + backend := pgproto3.NewBackend(pgproto3.NewChunkReader(dst), nil /* w */) + for i := 0; i < count; i++ { + msg, err := backend.Receive() + require.NoError(t, err) + q := msg.(*pgproto3.Query) + + expectedStr := "SELECT 1" + if i%2 == 1 { + expectedStr = "SELECT 2" + } + require.Equal(t, expectedStr, q.String) + } +} diff --git a/pkg/ccl/sqlproxyccl/interceptor/frontend_interceptor.go b/pkg/ccl/sqlproxyccl/interceptor/frontend_interceptor.go index cea2ddf1f550..cd4a5ca9caab 100644 --- a/pkg/ccl/sqlproxyccl/interceptor/frontend_interceptor.go +++ b/pkg/ccl/sqlproxyccl/interceptor/frontend_interceptor.go @@ -15,20 +15,13 @@ import ( "github.com/jackc/pgproto3/v2" ) -// FrontendInterceptor is a client interceptor for the Postgres frontend -// protocol. +// FrontendInterceptor is a client interceptor for the Postgres frontend protocol. type FrontendInterceptor pgInterceptor -// NewFrontendInterceptor creates a FrontendInterceptor. bufSize must be at -// least the size of a pgwire message header. -func NewFrontendInterceptor( - src io.Reader, dst io.Writer, bufSize int, -) (*FrontendInterceptor, error) { - pgi, err := newPgInterceptor(src, dst, bufSize) - if err != nil { - return nil, err - } - return (*FrontendInterceptor)(pgi), nil +// NewFrontendInterceptor creates a FrontendInterceptor using the default buffer +// size of 8K bytes. +func NewFrontendInterceptor(src io.Reader) *FrontendInterceptor { + return (*FrontendInterceptor)(newPgInterceptor(src, defaultBufferSize)) } // PeekMsg returns the header of the current pgwire message without advancing @@ -40,13 +33,6 @@ func (fi *FrontendInterceptor) PeekMsg() (typ pgwirebase.ServerMessageType, size return pgwirebase.ServerMessageType(byteType), size, err } -// WriteMsg writes the given bytes to the writer dst. -// -// See pgInterceptor.WriteMsg for more information. -func (fi *FrontendInterceptor) WriteMsg(data pgproto3.BackendMessage) (n int, err error) { - return (*pgInterceptor)(fi).WriteMsg(data.Encode(nil)) -} - // ReadMsg decodes the current pgwire message and returns a BackendMessage. // This also advances the interceptor to the next message. // @@ -56,19 +42,14 @@ func (fi *FrontendInterceptor) ReadMsg() (msg pgproto3.BackendMessage, err error if err != nil { return nil, err } - // errPanicWriter is used here because Receive must not Write. - return pgproto3.NewFrontend(newChunkReader(msgBytes), &errPanicWriter{}).Receive() + // errWriter is used here because Receive must not Write. + return pgproto3.NewFrontend(newChunkReader(msgBytes), &errWriter{}).Receive() } // ForwardMsg sends the current pgwire message to the destination without any // decoding, and advances the interceptor to the next message. // // See pgInterceptor.ForwardMsg for more information. -func (fi *FrontendInterceptor) ForwardMsg() (n int, err error) { - return (*pgInterceptor)(fi).ForwardMsg() -} - -// Close closes the interceptor, and prevents further operations on it. -func (fi *FrontendInterceptor) Close() { - (*pgInterceptor)(fi).Close() +func (fi *FrontendInterceptor) ForwardMsg(dst io.Writer) (n int, err error) { + return (*pgInterceptor)(fi).ForwardMsg(dst) } diff --git a/pkg/ccl/sqlproxyccl/interceptor/frontend_interceptor_test.go b/pkg/ccl/sqlproxyccl/interceptor/frontend_interceptor_test.go index 55f168b21015..7e1b03f50d05 100644 --- a/pkg/ccl/sqlproxyccl/interceptor/frontend_interceptor_test.go +++ b/pkg/ccl/sqlproxyccl/interceptor/frontend_interceptor_test.go @@ -27,59 +27,30 @@ func TestFrontendInterceptor(t *testing.T) { q := (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil) - t.Run("bufSize too small", func(t *testing.T) { - fi, err := interceptor.NewFrontendInterceptor(nil /* src */, nil /* dst */, 1) - require.Error(t, err) - require.Nil(t, fi) - }) - - t.Run("PeekMsg returns the right message type", func(t *testing.T) { + buildSrc := func(t *testing.T, count int) *bytes.Buffer { + t.Helper() src := new(bytes.Buffer) _, err := src.Write(q) require.NoError(t, err) + return src + } - fi, err := interceptor.NewFrontendInterceptor(src, nil /* dst */, 16) - require.NoError(t, err) + t.Run("PeekMsg returns the right message type", func(t *testing.T) { + src := buildSrc(t, 1) + + fi := interceptor.NewFrontendInterceptor(src) require.NotNil(t, fi) typ, size, err := fi.PeekMsg() require.NoError(t, err) require.Equal(t, pgwirebase.ServerMsgReady, typ) require.Equal(t, 6, size) - - fi.Close() - typ, size, err = fi.PeekMsg() - require.EqualError(t, err, interceptor.ErrInterceptorClosed.Error()) - require.Equal(t, pgwirebase.ServerMessageType(0), typ) - require.Equal(t, 0, size) - }) - - t.Run("WriteMsg writes data to dst", func(t *testing.T) { - dst := new(bytes.Buffer) - fi, err := interceptor.NewFrontendInterceptor(nil /* src */, dst, 10) - require.NoError(t, err) - require.NotNil(t, fi) - - // This is a frontend interceptor, so writing goes to the client. - toSend := &pgproto3.ReadyForQuery{TxStatus: 'I'} - n, err := fi.WriteMsg(toSend) - require.NoError(t, err) - require.Equal(t, 6, n) - require.Equal(t, 6, dst.Len()) - - fi.Close() - n, err = fi.WriteMsg(toSend) - require.EqualError(t, err, interceptor.ErrInterceptorClosed.Error()) - require.Equal(t, 0, n) }) t.Run("ReadMsg decodes the message correctly", func(t *testing.T) { - src := new(bytes.Buffer) - _, err := src.Write(q) - require.NoError(t, err) + src := buildSrc(t, 1) - fi, err := interceptor.NewFrontendInterceptor(src, nil /* dst */, 16) - require.NoError(t, err) + fi := interceptor.NewFrontendInterceptor(src) require.NotNil(t, fi) msg, err := fi.ReadMsg() @@ -87,31 +58,18 @@ func TestFrontendInterceptor(t *testing.T) { rmsg, ok := msg.(*pgproto3.ReadyForQuery) require.True(t, ok) require.Equal(t, byte('I'), rmsg.TxStatus) - - fi.Close() - msg, err = fi.ReadMsg() - require.EqualError(t, err, interceptor.ErrInterceptorClosed.Error()) - require.Nil(t, msg) }) t.Run("ForwardMsg forwards data to dst", func(t *testing.T) { - src := new(bytes.Buffer) - _, err := src.Write(q) - require.NoError(t, err) + src := buildSrc(t, 1) dst := new(bytes.Buffer) - fi, err := interceptor.NewFrontendInterceptor(src, dst, 16) - require.NoError(t, err) + fi := interceptor.NewFrontendInterceptor(src) require.NotNil(t, fi) - n, err := fi.ForwardMsg() + n, err := fi.ForwardMsg(dst) require.NoError(t, err) require.Equal(t, 6, n) require.Equal(t, 6, dst.Len()) - - fi.Close() - n, err = fi.ForwardMsg() - require.EqualError(t, err, interceptor.ErrInterceptorClosed.Error()) - require.Equal(t, 0, n) }) } diff --git a/pkg/ccl/sqlproxyccl/interceptor/interceptor_test.go b/pkg/ccl/sqlproxyccl/interceptor/interceptor_test.go index 4f00d8e7216d..bb91118dc9cf 100644 --- a/pkg/ccl/sqlproxyccl/interceptor/interceptor_test.go +++ b/pkg/ccl/sqlproxyccl/interceptor/interceptor_test.go @@ -24,8 +24,6 @@ import ( func TestSimpleProxy(t *testing.T) { defer leaktest.AfterTest(t)() - const bufferSize = 16 - // These represents connections for client<->proxy and proxy<->server. fromClient := new(bytes.Buffer) toClient := new(bytes.Buffer) @@ -33,10 +31,8 @@ func TestSimpleProxy(t *testing.T) { toServer := new(bytes.Buffer) // Create client and server interceptors. - clientInt, err := interceptor.NewBackendInterceptor(fromClient, toServer, bufferSize) - require.NoError(t, err) - serverInt, err := interceptor.NewFrontendInterceptor(fromServer, toClient, bufferSize) - require.NoError(t, err) + clientInt := interceptor.NewBackendInterceptor(fromClient) + serverInt := interceptor.NewFrontendInterceptor(fromServer) t.Run("client to server", func(t *testing.T) { // Client sends a list of SQL queries. @@ -61,13 +57,13 @@ func TestSimpleProxy(t *testing.T) { require.NoError(t, err) // Forward message to server. - _, err = clientInt.ForwardMsg() + _, err = clientInt.ForwardMsg(toServer) require.NoError(t, err) if typ == pgwirebase.ClientMsgTerminate { // Right before we terminate, we could also craft a custom // message, and send it to the server. - _, err := clientInt.WriteMsg(customQuery) + _, err := toServer.Write(customQuery.Encode(nil)) require.NoError(t, err) break } @@ -103,7 +99,7 @@ func TestSimpleProxy(t *testing.T) { // Assuming that we're only interested in small messages, then // we could skip all the large ones. if size > 12 { - _, err := serverInt.ForwardMsg() + _, err := serverInt.ForwardMsg(toClient) require.NoError(t, err) continue } @@ -128,11 +124,11 @@ func TestSimpleProxy(t *testing.T) { // the client. dmsg.SecretKey = 100 - _, err = serverInt.WriteMsg(dmsg) + _, err = toClient.Write(dmsg.Encode(nil)) require.NoError(t, err) default: // Forward message that we're not interested to the client. - _, err := serverInt.ForwardMsg() + _, err := serverInt.ForwardMsg(toClient) require.NoError(t, err) }