Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ccl/sqlproxyccl: frontend admitter #57849

Merged
merged 1 commit into from
Dec 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ require (
github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645
github.com/ianlancetaylor/cgosymbolizer v0.0.0-20201002210021-dda951febc36 // indirect
github.com/jackc/fake v0.0.0-20150926172116-812a484cc733 // indirect
github.com/jackc/pgconn v1.6.1 // indirect
github.com/jackc/pgconn v1.6.1
github.com/jackc/pgproto3/v2 v2.0.4
github.com/jackc/pgx v3.6.2+incompatible
github.com/jackc/pgx/v4 v4.6.0
Expand Down
9 changes: 7 additions & 2 deletions pkg/ccl/cliccl/mtproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,13 @@ Uuwb2FVdh76ZK0AVd3Jh3KJs4+hr2u9syHaa7UPKXTcZsFWlGwZuu6X5A+0SO0S2
InsecureSkipVerify: true,
}
server := sqlproxyccl.NewServer(sqlproxyccl.Options{
IncomingTLSConfig: &tls.Config{
Certificates: []tls.Certificate{cer},
FrontendAdmitter: func(incoming net.Conn) (net.Conn, *pgproto3.StartupMessage, error) {
return sqlproxyccl.FrontendAdmit(
incoming,
&tls.Config{
Certificates: []tls.Certificate{cer},
},
)
},
BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) {
params := msg.Parameters
Expand Down
3 changes: 3 additions & 0 deletions pkg/ccl/sqlproxyccl/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ go_library(
"backend_dialer.go",
"error.go",
"errorcode_string.go",
"frontend_admitter.go",
"idle_disconnect_connection.go",
"metrics.go",
"proxy.go",
Expand All @@ -28,6 +29,7 @@ go_library(
go_test(
name = "sqlproxyccl_test",
srcs = [
"frontend_admitter_test.go",
"idle_disconnect_connection_test.go",
"main_test.go",
"proxy_test.go",
Expand All @@ -46,6 +48,7 @@ go_test(
"//pkg/util/randutil",
"//pkg/util/timeutil",
"//vendor/github.com/cockroachdb/errors",
"//vendor/github.com/jackc/pgconn",
"//vendor/github.com/jackc/pgproto3/v2:pgproto3",
"//vendor/github.com/jackc/pgx/v4:pgx",
"//vendor/github.com/stretchr/testify/require",
Expand Down
76 changes: 76 additions & 0 deletions pkg/ccl/sqlproxyccl/frontend_admitter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright 2020 The Cockroach Authors.
//
// Licensed as a CockroachDB Enterprise file under the Cockroach Community
// License (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt

package sqlproxyccl

import (
"crypto/tls"
"net"

"github.com/jackc/pgproto3/v2"
)

// FrontendAdmit is the default implementation of a frontend admitter. It can
// upgrade to an optional SSL connection, and will handle and verify
// the startup message received from the PG SQL client.
func FrontendAdmit(
conn net.Conn, incomingTLSConfig *tls.Config,
) (net.Conn, *pgproto3.StartupMessage, error) {
// `conn` could be replaced by `conn` embedded in a `tls.Conn` connection,
// hence it's important to close `conn` rather than `proxyConn` since closing
// the latter will not call `Close` method of `tls.Conn`.
var sniServerName string
// If we have an incoming TLS Config, require that the client initiates
// with a TLS connection.
if incomingTLSConfig != nil {
m, err := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn).ReceiveStartupMessage()
if err != nil {
return nil, nil, NewErrorf(CodeClientReadFailed, "while receiving startup message")
}
switch m.(type) {
case *pgproto3.SSLRequest:
case *pgproto3.CancelRequest:
// Ignore CancelRequest explicitly. We don't need to do this but it makes
// testing easier by avoiding a call to sendErrToClient on this path
// (which would confuse assertCtx).
return nil, nil, nil
default:
code := CodeUnexpectedInsecureStartupMessage
return nil, nil, NewErrorf(code, "unsupported startup message: %T", m)
}

_, err = conn.Write([]byte{pgAcceptSSLRequest})
if err != nil {
return nil, nil, NewErrorf(CodeClientWriteFailed, "acking SSLRequest: %v", err)
}

cfg := incomingTLSConfig.Clone()

cfg.GetConfigForClient = func(h *tls.ClientHelloInfo) (*tls.Config, error) {
sniServerName = h.ServerName
return nil, nil
}
conn = tls.Server(conn, cfg)
}

m, err := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn).ReceiveStartupMessage()
if err != nil {
return nil, nil, NewErrorf(CodeClientReadFailed, "receiving post-TLS startup message: %v", err)
}
msg, ok := m.(*pgproto3.StartupMessage)
if !ok {
return nil, nil, NewErrorf(CodeUnexpectedStartupMessage, "unsupported post-TLS startup message: %T", m)
}

// Add the sniServerName (if used) as parameter
if sniServerName != "" {
msg.Parameters["sni-server"] = sniServerName
}

return conn, msg, nil
}
149 changes: 149 additions & 0 deletions pkg/ccl/sqlproxyccl/frontend_admitter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
// Copyright 2020 The Cockroach Authors.
//
// Licensed as a CockroachDB Enterprise file under the Cockroach Community
// License (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt

package sqlproxyccl

import (
"context"
"crypto/tls"
"fmt"
"net"
"testing"

"github.com/cockroachdb/cockroach/pkg/util/leaktest"
"github.com/cockroachdb/cockroach/pkg/util/timeutil"
"github.com/jackc/pgconn"
"github.com/jackc/pgproto3/v2"
"github.com/stretchr/testify/require"
)

func tlsConfig() (*tls.Config, error) {
cer, err := tls.LoadX509KeyPair("testserver.crt", "testserver.key")
if err != nil {
return nil, err
}
return &tls.Config{
Certificates: []tls.Certificate{cer},
ServerName: "localhost",
}, nil
}

func TestFrontendAdmitWithClientSSLDisableAndCustomParam(t *testing.T) {
defer leaktest.AfterTest(t)()

cli, srv := net.Pipe()
require.NoError(t, srv.SetReadDeadline(timeutil.Now().Add(3e9)))
require.NoError(t, cli.SetReadDeadline(timeutil.Now().Add(3e9)))

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

go func() {
cfg, err := pgconn.ParseConfig(
"postgres://localhost?sslmode=disable&p1=a",
)
require.NoError(t, err)
require.NotNil(t, cfg)
cfg.DialFunc = func(
ctx context.Context, network, addr string,
) (net.Conn, error) {
return cli, nil
}
_, _ = pgconn.ConnectConfig(ctx, cfg)
fmt.Printf("Done\n")
}()

frontendCon, msg, err := FrontendAdmit(srv, nil)
require.NoError(t, err)
require.Equal(t, srv, frontendCon)
require.NotNil(t, msg)
require.Contains(t, msg.Parameters, "p1")
require.Equal(t, msg.Parameters["p1"], "a")
}

func TestFrontendAdmitWithClientSSLRequire(t *testing.T) {
defer leaktest.AfterTest(t)()

cli, srv := net.Pipe()
require.NoError(t, srv.SetReadDeadline(timeutil.Now().Add(3e9)))
require.NoError(t, cli.SetReadDeadline(timeutil.Now().Add(3e9)))

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

go func() {
cfg, err := pgconn.ParseConfig("postgres://localhost?sslmode=require")
require.NoError(t, err)
require.NotNil(t, cfg)
cfg.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) {
return cli, nil
}
_, _ = pgconn.ConnectConfig(ctx, cfg)
}()

tlsConfig, err := tlsConfig()
require.NoError(t, err)
frontendCon, msg, err := FrontendAdmit(srv, tlsConfig)
require.NoError(t, err)
require.NotEqual(t, srv, frontendCon) // The connection was replaced by SSL
require.NotNil(t, msg)
}

func TestFrontendAdmitWithCancel(t *testing.T) {
defer leaktest.AfterTest(t)()

cli, srv := net.Pipe()
require.NoError(t, srv.SetReadDeadline(timeutil.Now().Add(3e9)))
require.NoError(t, cli.SetReadDeadline(timeutil.Now().Add(3e9)))

go func() {
cancelRequest := pgproto3.CancelRequest{ProcessID: 1, SecretKey: 2}
_, err := cli.Write(cancelRequest.Encode([]byte{}))
require.NoError(t, err)
}()

frontendCon, msg, err := FrontendAdmit(srv, nil)
require.EqualError(t, err,
"CodeUnexpectedStartupMessage: "+
"unsupported post-TLS startup message: *pgproto3.CancelRequest",
)
require.Nil(t, frontendCon)
require.Nil(t, msg)
}

func TestFrontendAdmitWithSSLAndCancel(t *testing.T) {
defer leaktest.AfterTest(t)()

cli, srv := net.Pipe()
require.NoError(t, srv.SetReadDeadline(timeutil.Now().Add(3e9)))
require.NoError(t, cli.SetReadDeadline(timeutil.Now().Add(3e9)))

go func() {
sslRequest := pgproto3.SSLRequest{}
_, err := cli.Write(sslRequest.Encode([]byte{}))
require.NoError(t, err)
b := []byte{0}
n, err := cli.Read(b)
require.Equal(t, n, 1)
require.NoError(t, err)
cli = tls.Client(cli, &tls.Config{InsecureSkipVerify: true})
cancelRequest := pgproto3.CancelRequest{ProcessID: 1, SecretKey: 2}
_, err = cli.Write(cancelRequest.Encode([]byte{}))
require.NoError(t, err)
}()

tlsConfig, err := tlsConfig()
require.NoError(t, err)
frontendCon, msg, err := FrontendAdmit(srv, tlsConfig)
require.EqualError(t, err,
"CodeUnexpectedStartupMessage: "+
"unsupported post-TLS startup message: *pgproto3.CancelRequest",
)
require.Nil(t, frontendCon)
require.Nil(t, msg)
}
75 changes: 29 additions & 46 deletions pkg/ccl/sqlproxyccl/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ type BackendConfig struct {

// Options are the options to the Proxy method.
type Options struct {
// Deprecated: construct FrontendAdmitter, passing this information in case
// that SSL is desired.
IncomingTLSConfig *tls.Config // config used for client -> proxy connection

// BackendFromParams returns the config to use for the proxy -> backend
Expand All @@ -70,6 +72,11 @@ type Options struct {
// The error passed to this method will contain no internal information.
OnSendErrToClient func(code ErrorCode, msg string) string

// If set, will be called immediately after a new incoming connection
// is accepted. It can optionally negotiate SSL, provide admittance control or
// other types of frontend connection filtering.
FrontendAdmitter func(incoming net.Conn) (net.Conn, *pgproto3.StartupMessage, error)

// If set, will be used to establish and return connection to the backend.
// If not set, the old logic will be used.
// The argument is the startup message received from the frontend. It
Expand Down Expand Up @@ -98,59 +105,35 @@ func (s *Server) Proxy(proxyConn *Conn) error {
}).Encode(nil))
}

var conn net.Conn = proxyConn
// `conn` could be replaced by `conn` embedded in a `tls.Conn` connection,
// hence it's important to close `conn` rather than `proxyConn` since closing
// the latter will not call `Close` method of `tls.Conn`.
defer func() { _ = conn.Close() }()
var sniServerName string
// If we have an incoming TLS Config, require that the client initiates
// with a TLS connection.
if s.opts.IncomingTLSConfig != nil {
m, err := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn).ReceiveStartupMessage()
if err != nil {
return NewErrorf(CodeClientReadFailed, "while receiving startup message")
}
switch m.(type) {
case *pgproto3.SSLRequest:
case *pgproto3.CancelRequest:
// Ignore CancelRequest explicitly. We don't need to do this but it makes
// testing easier by avoiding a call to sendErrToClient on this path
// (which would confuse assertCtx).
return nil
default:
code := CodeUnexpectedInsecureStartupMessage
sendErrToClient(conn, code, "server requires encryption")
return NewErrorf(code, "unsupported startup message: %T", m)
frontendAdmitter := s.opts.FrontendAdmitter
if frontendAdmitter == nil {
// Keep this until all clients are switched to provide FrontendAdmitter
// at what point we can also drop IncomingTLSConfig
frontendAdmitter = func(incoming net.Conn) (net.Conn, *pgproto3.StartupMessage, error) {
return FrontendAdmit(incoming, s.opts.IncomingTLSConfig)
}

_, err = conn.Write([]byte{pgAcceptSSLRequest})
if err != nil {
return NewErrorf(CodeClientWriteFailed, "acking SSLRequest: %v", err)
}

cfg := s.opts.IncomingTLSConfig.Clone()

cfg.GetConfigForClient = func(h *tls.ClientHelloInfo) (*tls.Config, error) {
sniServerName = h.ServerName
return nil, nil
}
conn = tls.Server(conn, cfg)
}

m, err := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn).ReceiveStartupMessage()
conn, msg, err := frontendAdmitter(proxyConn)
if err != nil {
return NewErrorf(CodeClientReadFailed, "receiving post-TLS startup message: %v", err)
}
msg, ok := m.(*pgproto3.StartupMessage)
if !ok {
return NewErrorf(CodeUnexpectedStartupMessage, "unsupported post-TLS startup message: %T", m)
var codeErr *CodeError
if errors.As(err, &codeErr) && codeErr.code == CodeUnexpectedInsecureStartupMessage {
sendErrToClient(
proxyConn, // Do this on the TCP connection as it means denying SSL
CodeUnexpectedInsecureStartupMessage,
"server requires encryption",
)
}
return err
}

// Add the sniServerName (if used) as parameter
if sniServerName != "" {
msg.Parameters["sni-server"] = sniServerName
// This currently only happens for CancelRequest type of startup messages
// that we don't support
if conn == nil {
return nil

}
defer func() { _ = conn.Close() }()

backendDialer := s.opts.BackendDialer
if backendDialer == nil {
Expand Down
11 changes: 8 additions & 3 deletions pkg/ccl/sqlproxyccl/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,14 @@ openssl req -new -x509 -sha256 -key testserver.key -out testserver.crt \
`
cer, err := tls.LoadX509KeyPair("testserver.crt", "testserver.key")
require.NoError(t, err)
opts.IncomingTLSConfig = &tls.Config{
Certificates: []tls.Certificate{cer},
ServerName: "localhost",
opts.FrontendAdmitter = func(incoming net.Conn) (net.Conn, *pgproto3.StartupMessage, error) {
return FrontendAdmit(
incoming,
&tls.Config{
Certificates: []tls.Certificate{cer},
ServerName: "localhost",
},
)
}

const listenAddress = "127.0.0.1:0"
Expand Down