diff --git a/pkg/ccl/sqlproxyccl/interceptor/backend_interceptor_test.go b/pkg/ccl/sqlproxyccl/interceptor/backend_interceptor_test.go index 5258760ebc31..0e442b9f1e88 100644 --- a/pkg/ccl/sqlproxyccl/interceptor/backend_interceptor_test.go +++ b/pkg/ccl/sqlproxyccl/interceptor/backend_interceptor_test.go @@ -45,7 +45,7 @@ func TestBackendInterceptor(t *testing.T) { typ, size, err := bi.PeekMsg() require.NoError(t, err) require.Equal(t, pgwirebase.ClientMsgSimpleQuery, typ) - require.Equal(t, 9, size) + require.Equal(t, 14, size) bi.Close() typ, size, err = bi.PeekMsg() diff --git a/pkg/ccl/sqlproxyccl/interceptor/base.go b/pkg/ccl/sqlproxyccl/interceptor/base.go index f54daf579b58..1925840dbf67 100644 --- a/pkg/ccl/sqlproxyccl/interceptor/base.go +++ b/pkg/ccl/sqlproxyccl/interceptor/base.go @@ -85,9 +85,9 @@ func newPgInterceptor(src io.Reader, dst io.Writer, bufSize int) (*pgInterceptor // PeekMsg returns the header of the current pgwire message without advancing // the interceptor. On return, err == nil if and only if the entire header can -// be read. Note that size corresponds to the body size, and does not account -// for the size field itself. This will return ErrProtocolError if the packets -// are malformed. +// be read. The returned size corresponds to the entire message size, which +// includes the header type and body length. This will return ErrProtocolError +// if the packets are malformed. // // If the interceptor is closed, PeekMsg returns ErrInterceptorClosed. func (p *pgInterceptor) PeekMsg() (typ byte, size int, err error) { @@ -108,7 +108,9 @@ func (p *pgInterceptor) PeekMsg() (typ byte, size int, err error) { return 0, 0, ErrProtocolError } - return typ, size - 4, nil + // Add 1 to size to account for type. We don't need to add 4 (int length) to + // it because size is already inclusive of that. + return typ, size + 1, nil } // WriteMsg writes the given bytes to the writer dst. If err != nil and a Write @@ -148,28 +150,27 @@ func (p *pgInterceptor) ReadMsg() (msg []byte, err error) { return nil, ErrInterceptorClosed } - // Peek header of the current message for body size. + // Peek header of the current message for message size. _, size, err := p.PeekMsg() if err != nil { return nil, err } - msgSizeBytes := pgHeaderSizeBytes + size // Can the entire message fit into the buffer? - if msgSizeBytes <= len(p.buf) { - if err := p.ensureNextNBytes(msgSizeBytes); err != nil { + if size <= len(p.buf) { + if err := p.ensureNextNBytes(size); err != nil { // Possibly due to a timeout or context cancellation. return nil, err } // Return a slice to the internal buffer to avoid an allocation here. - retBuf := p.buf[p.readPos : p.readPos+msgSizeBytes] - p.readPos += msgSizeBytes + retBuf := p.buf[p.readPos : p.readPos+size] + p.readPos += size return retBuf, nil } // Message cannot fit, so we will have to allocate. - msg = make([]byte, msgSizeBytes) + msg = make([]byte, size) // Copy bytes which have already been read. n := copy(msg, p.buf[p.readPos:p.writePos]) @@ -209,7 +210,7 @@ func (p *pgInterceptor) ForwardMsg() (n int, err error) { return 0, ErrInterceptorClosed } - // Retrieve header of the current message for body size. + // Retrieve header of the current message for message size. _, size, err := p.PeekMsg() if err != nil { return 0, err @@ -217,7 +218,7 @@ func (p *pgInterceptor) ForwardMsg() (n int, err error) { // Handle overflows as current message may not fit in the current buffer. startPos := p.readPos - endPos := startPos + pgHeaderSizeBytes + size + endPos := startPos + size remainingBytes := 0 if endPos > p.writePos { remainingBytes = endPos - p.writePos diff --git a/pkg/ccl/sqlproxyccl/interceptor/base_test.go b/pkg/ccl/sqlproxyccl/interceptor/base_test.go index 7fc55673b66d..1654b9b34ecd 100644 --- a/pkg/ccl/sqlproxyccl/interceptor/base_test.go +++ b/pkg/ccl/sqlproxyccl/interceptor/base_test.go @@ -95,7 +95,8 @@ func TestPGInterceptor_PeekMsg(t *testing.T) { t.Run("successful", func(t *testing.T) { buf := new(bytes.Buffer) - _, err := buf.Write((&pgproto3.Query{String: "SELECT 1"}).Encode(nil)) + msgBytes := (&pgproto3.Query{String: "SELECT 1"}).Encode(nil) + _, err := buf.Write(msgBytes) require.NoError(t, err) pgi, err := newPgInterceptor(buf, nil /* dst */, 10) @@ -104,14 +105,14 @@ func TestPGInterceptor_PeekMsg(t *testing.T) { typ, size, err := pgi.PeekMsg() require.NoError(t, err) require.Equal(t, byte(pgwirebase.ClientMsgSimpleQuery), typ) - require.Equal(t, 9, size) + require.Equal(t, len(msgBytes), size) require.Equal(t, 4, buf.Len()) // Invoking Peek should not advance the interceptor. typ, size, err = pgi.PeekMsg() require.NoError(t, err) require.Equal(t, byte(pgwirebase.ClientMsgSimpleQuery), typ) - require.Equal(t, 9, size) + require.Equal(t, len(msgBytes), size) require.Equal(t, 4, buf.Len()) }) } diff --git a/pkg/ccl/sqlproxyccl/interceptor/frontend_interceptor_test.go b/pkg/ccl/sqlproxyccl/interceptor/frontend_interceptor_test.go index 45704877ce28..55f168b21015 100644 --- a/pkg/ccl/sqlproxyccl/interceptor/frontend_interceptor_test.go +++ b/pkg/ccl/sqlproxyccl/interceptor/frontend_interceptor_test.go @@ -45,7 +45,7 @@ func TestFrontendInterceptor(t *testing.T) { typ, size, err := fi.PeekMsg() require.NoError(t, err) require.Equal(t, pgwirebase.ServerMsgReady, typ) - require.Equal(t, 1, size) + require.Equal(t, 6, size) fi.Close() typ, size, err = fi.PeekMsg()