Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ccl/sqlproxyccl: cleanup PG interceptor APIs #76613

Merged
merged 3 commits into from
Feb 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 8 additions & 24 deletions pkg/ccl/sqlproxyccl/interceptor/backend_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,10 @@ import (
// BackendInterceptor is a server int/erceptor for the Postgres backend protocol.
type BackendInterceptor pgInterceptor

// NewBackendInterceptor creates a BackendInterceptor. bufSize must be at least
// the size of a pgwire message header.
func NewBackendInterceptor(src io.Reader, dst io.Writer, bufSize int) (*BackendInterceptor, error) {
pgi, err := newPgInterceptor(src, dst, bufSize)
if err != nil {
return nil, err
}
return (*BackendInterceptor)(pgi), nil
// NewBackendInterceptor creates a BackendInterceptor using the default buffer
// size of 8K bytes.
func NewBackendInterceptor(src io.Reader) *BackendInterceptor {
return (*BackendInterceptor)(newPgInterceptor(src, defaultBufferSize))
}

// PeekMsg returns the header of the current pgwire message without advancing
Expand All @@ -37,13 +33,6 @@ func (bi *BackendInterceptor) PeekMsg() (typ pgwirebase.ClientMessageType, size
return pgwirebase.ClientMessageType(byteType), size, err
}

// WriteMsg writes the given bytes to the writer dst.
//
// See pgInterceptor.WriteMsg for more information.
func (bi *BackendInterceptor) WriteMsg(data pgproto3.FrontendMessage) (n int, err error) {
return (*pgInterceptor)(bi).WriteMsg(data.Encode(nil))
}

// ReadMsg decodes the current pgwire message and returns a FrontendMessage.
// This also advances the interceptor to the next message.
//
Expand All @@ -53,19 +42,14 @@ func (bi *BackendInterceptor) ReadMsg() (msg pgproto3.FrontendMessage, err error
if err != nil {
return nil, err
}
// errPanicWriter is used here because Receive must not Write.
return pgproto3.NewBackend(newChunkReader(msgBytes), &errPanicWriter{}).Receive()
// errWriter is used here because Receive must not Write.
return pgproto3.NewBackend(newChunkReader(msgBytes), &errWriter{}).Receive()
}

// ForwardMsg sends the current pgwire message to the destination without any
// decoding, and advances the interceptor to the next message.
//
// See pgInterceptor.ForwardMsg for more information.
func (bi *BackendInterceptor) ForwardMsg() (n int, err error) {
return (*pgInterceptor)(bi).ForwardMsg()
}

// Close closes the interceptor, and prevents further operations on it.
func (bi *BackendInterceptor) Close() {
(*pgInterceptor)(bi).Close()
func (bi *BackendInterceptor) ForwardMsg(dst io.Writer) (n int, err error) {
return (*pgInterceptor)(bi).ForwardMsg(dst)
}
68 changes: 13 additions & 55 deletions pkg/ccl/sqlproxyccl/interceptor/backend_interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,91 +27,49 @@ func TestBackendInterceptor(t *testing.T) {

q := (&pgproto3.Query{String: "SELECT 1"}).Encode(nil)

t.Run("bufSize too small", func(t *testing.T) {
bi, err := interceptor.NewBackendInterceptor(nil /* src */, nil /* dst */, 1)
require.Error(t, err)
require.Nil(t, bi)
})

t.Run("PeekMsg returns the right message type", func(t *testing.T) {
buildSrc := func(t *testing.T, count int) *bytes.Buffer {
t.Helper()
src := new(bytes.Buffer)
_, err := src.Write(q)
require.NoError(t, err)
return src
}

bi, err := interceptor.NewBackendInterceptor(src, nil /* dst */, 16)
require.NoError(t, err)
t.Run("PeekMsg returns the right message type", func(t *testing.T) {
src := buildSrc(t, 1)

bi := interceptor.NewBackendInterceptor(src)
require.NotNil(t, bi)

typ, size, err := bi.PeekMsg()
require.NoError(t, err)
require.Equal(t, pgwirebase.ClientMsgSimpleQuery, typ)
require.Equal(t, 14, size)

bi.Close()
typ, size, err = bi.PeekMsg()
require.EqualError(t, err, interceptor.ErrInterceptorClosed.Error())
require.Equal(t, pgwirebase.ClientMessageType(0), typ)
require.Equal(t, 0, size)
})

t.Run("WriteMsg writes data to dst", func(t *testing.T) {
dst := new(bytes.Buffer)
bi, err := interceptor.NewBackendInterceptor(nil /* src */, dst, 10)
require.NoError(t, err)
require.NotNil(t, bi)

// This is a backend interceptor, so writing goes to the server.
toSend := &pgproto3.Query{String: "SELECT 1"}
n, err := bi.WriteMsg(toSend)
require.NoError(t, err)
require.Equal(t, 14, n)
require.Equal(t, 14, dst.Len())

bi.Close()
n, err = bi.WriteMsg(toSend)
require.EqualError(t, err, interceptor.ErrInterceptorClosed.Error())
require.Equal(t, 0, n)
})

t.Run("ReadMsg decodes the message correctly", func(t *testing.T) {
src := new(bytes.Buffer)
_, err := src.Write(q)
require.NoError(t, err)
src := buildSrc(t, 1)

bi, err := interceptor.NewBackendInterceptor(src, nil /* dst */, 16)
require.NoError(t, err)
bi := interceptor.NewBackendInterceptor(src)
require.NotNil(t, bi)

msg, err := bi.ReadMsg()
require.NoError(t, err)
rmsg, ok := msg.(*pgproto3.Query)
require.True(t, ok)
require.Equal(t, "SELECT 1", rmsg.String)

bi.Close()
msg, err = bi.ReadMsg()
require.EqualError(t, err, interceptor.ErrInterceptorClosed.Error())
require.Nil(t, msg)
})

t.Run("ForwardMsg forwards data to dst", func(t *testing.T) {
src := new(bytes.Buffer)
_, err := src.Write(q)
require.NoError(t, err)
src := buildSrc(t, 1)
dst := new(bytes.Buffer)

bi, err := interceptor.NewBackendInterceptor(src, dst, 16)
require.NoError(t, err)
bi := interceptor.NewBackendInterceptor(src)
require.NotNil(t, bi)

n, err := bi.ForwardMsg()
n, err := bi.ForwardMsg(dst)
require.NoError(t, err)
require.Equal(t, 14, n)
require.Equal(t, 14, dst.Len())

bi.Close()
n, err = bi.ForwardMsg()
require.EqualError(t, err, interceptor.ErrInterceptorClosed.Error())
require.Equal(t, 0, n)
})
}
Loading