Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ccl/sqlproxyccl: add postgres interceptors for message forwarding #76006

Merged
merged 1 commit into from
Feb 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkg/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ ALL_TESTS = [
"//pkg/ccl/spanconfigccl/spanconfigsqlwatcherccl:spanconfigsqlwatcherccl_test",
"//pkg/ccl/sqlproxyccl/denylist:denylist_test",
"//pkg/ccl/sqlproxyccl/idle:idle_test",
"//pkg/ccl/sqlproxyccl/interceptor:interceptor_test",
"//pkg/ccl/sqlproxyccl/tenant:tenant_test",
"//pkg/ccl/sqlproxyccl/throttler:throttler_test",
"//pkg/ccl/sqlproxyccl:sqlproxyccl_test",
Expand Down
38 changes: 38 additions & 0 deletions pkg/ccl/sqlproxyccl/interceptor/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")

go_library(
name = "interceptor",
srcs = [
"backend_interceptor.go",
"base.go",
"chunkreader.go",
"frontend_interceptor.go",
],
importpath = "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/interceptor",
visibility = ["//visibility:public"],
deps = [
"//pkg/sql/pgwire/pgwirebase",
"//pkg/util",
"@com_github_cockroachdb_errors//:errors",
"@com_github_jackc_pgproto3_v2//:pgproto3",
],
)

go_test(
name = "interceptor_test",
srcs = [
"backend_interceptor_test.go",
"base_test.go",
"chunkreader_test.go",
"frontend_interceptor_test.go",
"interceptor_test.go",
],
embed = [":interceptor"],
deps = [
"//pkg/sql/pgwire/pgwirebase",
"//pkg/util/leaktest",
"@com_github_cockroachdb_errors//:errors",
"@com_github_jackc_pgproto3_v2//:pgproto3",
"@com_github_stretchr_testify//require",
],
)
71 changes: 71 additions & 0 deletions pkg/ccl/sqlproxyccl/interceptor/backend_interceptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright 2022 The Cockroach Authors.
//
// Licensed as a CockroachDB Enterprise file under the Cockroach Community
// License (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt

package interceptor

import (
"io"

"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase"
"github.com/jackc/pgproto3/v2"
)

// BackendInterceptor is a server int/erceptor for the Postgres backend protocol.
type BackendInterceptor pgInterceptor

// NewBackendInterceptor creates a BackendInterceptor. bufSize must be at least
// the size of a pgwire message header.
func NewBackendInterceptor(src io.Reader, dst io.Writer, bufSize int) (*BackendInterceptor, error) {
pgi, err := newPgInterceptor(src, dst, bufSize)
if err != nil {
return nil, err
}
return (*BackendInterceptor)(pgi), nil
}

// PeekMsg returns the header of the current pgwire message without advancing
// the interceptor.
//
// See pgInterceptor.PeekMsg for more information.
func (bi *BackendInterceptor) PeekMsg() (typ pgwirebase.ClientMessageType, size int, err error) {
byteType, size, err := (*pgInterceptor)(bi).PeekMsg()
return pgwirebase.ClientMessageType(byteType), size, err
}

// WriteMsg writes the given bytes to the writer dst.
//
// See pgInterceptor.WriteMsg for more information.
func (bi *BackendInterceptor) WriteMsg(data pgproto3.FrontendMessage) (n int, err error) {
return (*pgInterceptor)(bi).WriteMsg(data.Encode(nil))
}

// ReadMsg decodes the current pgwire message and returns a FrontendMessage.
// This also advances the interceptor to the next message.
//
// See pgInterceptor.ReadMsg for more information.
func (bi *BackendInterceptor) ReadMsg() (msg pgproto3.FrontendMessage, err error) {
msgBytes, err := (*pgInterceptor)(bi).ReadMsg()
if err != nil {
return nil, err
}
// errPanicWriter is used here because Receive must not Write.
return pgproto3.NewBackend(newChunkReader(msgBytes), &errPanicWriter{}).Receive()
}

// ForwardMsg sends the current pgwire message to the destination without any
// decoding, and advances the interceptor to the next message.
//
// See pgInterceptor.ForwardMsg for more information.
func (bi *BackendInterceptor) ForwardMsg() (n int, err error) {
return (*pgInterceptor)(bi).ForwardMsg()
}

// Close closes the interceptor, and prevents further operations on it.
func (bi *BackendInterceptor) Close() {
(*pgInterceptor)(bi).Close()
}
117 changes: 117 additions & 0 deletions pkg/ccl/sqlproxyccl/interceptor/backend_interceptor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// Copyright 2022 The Cockroach Authors.
//
// Licensed as a CockroachDB Enterprise file under the Cockroach Community
// License (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt

package interceptor_test

import (
"bytes"
"testing"

"github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/interceptor"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase"
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
"github.com/jackc/pgproto3/v2"
"github.com/stretchr/testify/require"
)

// TestBackendInterceptor tests the BackendInterceptor. Note that the tests
// here are shallow. For detailed ones, see the tests for the internal
// interceptor in base_test.go.
func TestBackendInterceptor(t *testing.T) {
defer leaktest.AfterTest(t)()

q := (&pgproto3.Query{String: "SELECT 1"}).Encode(nil)

t.Run("bufSize too small", func(t *testing.T) {
bi, err := interceptor.NewBackendInterceptor(nil /* src */, nil /* dst */, 1)
require.Error(t, err)
require.Nil(t, bi)
})

t.Run("PeekMsg returns the right message type", func(t *testing.T) {
src := new(bytes.Buffer)
_, err := src.Write(q)
require.NoError(t, err)

bi, err := interceptor.NewBackendInterceptor(src, nil /* dst */, 16)
require.NoError(t, err)
require.NotNil(t, bi)

typ, size, err := bi.PeekMsg()
require.NoError(t, err)
require.Equal(t, pgwirebase.ClientMsgSimpleQuery, typ)
require.Equal(t, 9, size)

bi.Close()
typ, size, err = bi.PeekMsg()
require.EqualError(t, err, interceptor.ErrInterceptorClosed.Error())
require.Equal(t, pgwirebase.ClientMessageType(0), typ)
require.Equal(t, 0, size)
})

t.Run("WriteMsg writes data to dst", func(t *testing.T) {
dst := new(bytes.Buffer)
bi, err := interceptor.NewBackendInterceptor(nil /* src */, dst, 10)
require.NoError(t, err)
require.NotNil(t, bi)

// This is a backend interceptor, so writing goes to the server.
toSend := &pgproto3.Query{String: "SELECT 1"}
n, err := bi.WriteMsg(toSend)
require.NoError(t, err)
require.Equal(t, 14, n)
require.Equal(t, 14, dst.Len())

bi.Close()
n, err = bi.WriteMsg(toSend)
require.EqualError(t, err, interceptor.ErrInterceptorClosed.Error())
require.Equal(t, 0, n)
})

t.Run("ReadMsg decodes the message correctly", func(t *testing.T) {
src := new(bytes.Buffer)
_, err := src.Write(q)
require.NoError(t, err)

bi, err := interceptor.NewBackendInterceptor(src, nil /* dst */, 16)
require.NoError(t, err)
require.NotNil(t, bi)

msg, err := bi.ReadMsg()
require.NoError(t, err)
rmsg, ok := msg.(*pgproto3.Query)
require.True(t, ok)
require.Equal(t, "SELECT 1", rmsg.String)

bi.Close()
msg, err = bi.ReadMsg()
require.EqualError(t, err, interceptor.ErrInterceptorClosed.Error())
require.Nil(t, msg)
})

t.Run("ForwardMsg forwards data to dst", func(t *testing.T) {
src := new(bytes.Buffer)
_, err := src.Write(q)
require.NoError(t, err)
dst := new(bytes.Buffer)

bi, err := interceptor.NewBackendInterceptor(src, dst, 16)
require.NoError(t, err)
require.NotNil(t, bi)

n, err := bi.ForwardMsg()
require.NoError(t, err)
require.Equal(t, 14, n)
require.Equal(t, 14, dst.Len())

bi.Close()
n, err = bi.ForwardMsg()
require.EqualError(t, err, interceptor.ErrInterceptorClosed.Error())
require.Equal(t, 0, n)
})
}
Loading