-
Notifications
You must be signed in to change notification settings - Fork 3.8k
/
proxy.go
199 lines (173 loc) · 6.65 KB
/
proxy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
// Copyright 2020 The Cockroach Authors.
//
// Licensed as a CockroachDB Enterprise file under the Cockroach Community
// License (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt
package sqlproxyccl
import (
"crypto/tls"
"encoding/binary"
"io"
"net"
"github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/admitter"
"github.com/cockroachdb/cockroach/pkg/util/timeutil"
"github.com/jackc/pgproto3/v2"
)
// pgAcceptSSLRequest is the single byte a PostgreSQL server replies with to
// agree to an SSLRequest. The proxy sends it to clients and expects it back
// from backends.
const pgAcceptSSLRequest = 'S'

// pgSSLRequest is the SSLRequest message, written as two big-endian int32s.
// The first is the message length: 8 bytes, including the length field itself.
// The second is the SSLRequest code, which carries 1234 in its high 16 bits
// and 5679 in its low 16 bits.
// See https://www.postgresql.org/docs/9.1/protocol-message-formats.html.
var pgSSLRequest = []int32{8, 1234<<16 | 5679}
// Options are the options to the Proxy method.
type Options struct {
// IncomingTLSConfig is cloned by Proxy and used to terminate TLS from the
// client after the client's SSLRequest has been accepted.
IncomingTLSConfig *tls.Config // config used for client -> proxy connection
// OutgoingTLSConfig is cloned by Proxy and used to start TLS towards the
// backend. Proxy overrides ServerName on the clone.
OutgoingTLSConfig *tls.Config // config used for proxy -> backend connection
// OutgoingAddrFromSNI, if set, is intended to receive the SNI server name
// offered by the client. A non-nil clientErr is sent to the client and the
// connection is rejected. A non-empty addr is currently rejected as
// unimplemented.
//
// TODO(tbg): this is unimplemented and exists only to check which clients
// allow use of SNI. Should always return ("", nil).
OutgoingAddrFromSNI func(serverName string) (addr string, clientErr error)
// OutgoingAddrFromParams maps the parameters of the client's post-TLS
// StartupMessage to the backend address that Proxy dials over TCP. Proxy
// calls it unconditionally, so it must be set. A non-nil clientErr is sent to
// the client and the connection is rejected.
OutgoingAddrFromParams func(map[string]string) (addr string, clientErr error)
// If set, consulted to decorate an error message to be sent to the client.
// The error passed to this method will contain no internal information.
OnSendErrToClient func(code ErrorCode, msg string) string
// admitter rate-limits connection attempts.
// NOTE(review): Proxy consults s.admitter on the Server rather than this
// field; presumably the Server is initialized from it — confirm.
admitter admitter.Service
}
// Proxy takes an incoming client connection and relays it to a backend SQL
// server.
//
// The exchange proceeds as follows:
//  1. If an admitter is configured, it may refuse the attempt based on the
//     client's remote address.
//  2. The client's first startup message must be an SSLRequest; a
//     CancelRequest is ignored and any other message is rejected. The proxy
//     acknowledges the SSLRequest with 'S' and completes the TLS handshake
//     using IncomingTLSConfig.
//  3. The post-TLS StartupMessage is read, and OutgoingAddrFromParams picks
//     the backend address from its parameters.
//  4. The backend is dialed and asked to upgrade to TLS using
//     OutgoingTLSConfig. The StartupMessage is forwarded, and bytes are then
//     copied in both directions until either copy stops.
//
// Returned errors are built with newErrorf and carry an ErrorCode naming the
// failed step. On several failure paths a FATAL ErrorResponse is also written
// to the client. The backend connection is always closed before Proxy
// returns; conn itself is not closed by Proxy.
func (s *Server) Proxy(conn net.Conn) error {
	// sendErrToClient writes a FATAL pgwire ErrorResponse to conn. If set,
	// OnSendErrToClient may rewrite the message first. Write errors are
	// ignored because every call site returns immediately afterwards.
	sendErrToClient := func(conn net.Conn, code ErrorCode, msg string) {
		if s.opts.OnSendErrToClient != nil {
			msg = s.opts.OnSendErrToClient(code, msg)
		}
		_, _ = conn.Write((&pgproto3.ErrorResponse{
			Severity: "FATAL",
			Code:     "08004", // rejected connection
			Message:  msg,
		}).Encode(nil))
	}

	if s.admitter != nil {
		// TODO(spaskob): check for previous successful connection from the same IP
		// in which case allow connection.
		if err := s.admitter.AllowRequest(conn.RemoteAddr().String(), timeutil.Now()); err != nil {
			s.metrics.RefusedConnCount.Inc(1)
			return newErrorf(CodeProxyRefusedConnection, "too many connection attempts")
		}
	}

	{
		m, err := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn).ReceiveStartupMessage()
		if err != nil {
			return newErrorf(CodeClientReadFailed, "while receiving startup message: %v", err)
		}
		switch m.(type) {
		case *pgproto3.SSLRequest:
		case *pgproto3.CancelRequest:
			// Ignore CancelRequest explicitly. We don't need to do this but it makes
			// testing easier by avoiding a call to sendErrToClient on this path
			// (which would confuse assertCtx).
			return nil
		default:
			code := CodeUnexpectedInsecureStartupMessage
			sendErrToClient(conn, code, "server requires encryption")
			return newErrorf(code, "unsupported startup message: %T", m)
		}

		if _, err := conn.Write([]byte{pgAcceptSSLRequest}); err != nil {
			return newErrorf(CodeClientWriteFailed, "acking SSLRequest: %v", err)
		}

		cfg := s.opts.IncomingTLSConfig.Clone()
		// Record the SNI server name offered in the ClientHello. Returning
		// (nil, nil) tells crypto/tls to keep using cfg unchanged.
		var sniServerName string
		cfg.GetConfigForClient = func(h *tls.ClientHelloInfo) (*tls.Config, error) {
			sniServerName = h.ServerName
			return nil, nil
		}
		tlsConn := tls.Server(conn, cfg)
		// Run the handshake eagerly. GetConfigForClient only fires during the
		// handshake, which crypto/tls otherwise defers until the first Read or
		// Write. Without this, sniServerName would still be empty below.
		if err := tlsConn.Handshake(); err != nil {
			return newErrorf(CodeClientReadFailed, "performing TLS handshake with client: %v", err)
		}
		conn = tlsConn

		if s.opts.OutgoingAddrFromSNI != nil {
			addr, clientErr := s.opts.OutgoingAddrFromSNI(sniServerName)
			if clientErr != nil {
				code := CodeSNIRoutingFailed
				// Sent over the established TLS session.
				sendErrToClient(conn, code, clientErr.Error())
				return newErrorf(code, "rejected by OutgoingAddrFromSNI")
			}
			if addr != "" {
				return newErrorf(CodeSNIRoutingFailed, "OutgoingAddrFromSNI is unimplemented")
			}
		}
	}

	m, err := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn).ReceiveStartupMessage()
	if err != nil {
		return newErrorf(CodeClientReadFailed, "receiving post-TLS startup message: %v", err)
	}
	msg, ok := m.(*pgproto3.StartupMessage)
	if !ok {
		return newErrorf(CodeUnexpectedStartupMessage, "unsupported post-TLS startup message: %T", m)
	}

	outgoingAddr, clientErr := s.opts.OutgoingAddrFromParams(msg.Parameters)
	if clientErr != nil {
		s.metrics.RoutingErrCount.Inc(1)
		code := CodeParamsRoutingFailed
		sendErrToClient(conn, code, clientErr.Error())
		return newErrorf(code, "rejected by OutgoingAddrFromParams: %v", clientErr)
	}

	crdbConn, err := net.Dial("tcp", outgoingAddr)
	if err != nil {
		s.metrics.BackendDownCount.Inc(1)
		code := CodeBackendDown
		sendErrToClient(conn, code, "unable to reach backend SQL server")
		return newErrorf(code, "dialing backend server: %v", err)
	}
	// Close the backend connection on every exit path: the early returns
	// below and the end of the relay. Otherwise the socket, and the goroutine
	// still copying from it, would linger until the backend hung up. The
	// closure reads crdbConn at return time, so it closes the TLS wrapper
	// once crdbConn has been replaced by it.
	defer func() { _ = crdbConn.Close() }()

	// Ask the backend to upgrade to TLS.
	if err := binary.Write(crdbConn, binary.BigEndian, pgSSLRequest); err != nil {
		s.metrics.BackendDownCount.Inc(1)
		return newErrorf(CodeBackendDown, "sending SSLRequest to target server: %v", err)
	}

	if s.admitter != nil {
		s.admitter.RequestSuccess(conn.RemoteAddr().String())
	}

	response := make([]byte, 1)
	if _, err := io.ReadFull(crdbConn, response); err != nil {
		s.metrics.BackendDownCount.Inc(1)
		return newErrorf(CodeBackendDown, "reading response to SSLRequest: %v", err)
	}
	if response[0] != pgAcceptSSLRequest {
		s.metrics.BackendDownCount.Inc(1)
		return newErrorf(CodeBackendRefusedTLS, "target server refused TLS connection")
	}

	outCfg := s.opts.OutgoingTLSConfig.Clone()
	// tls.Config.ServerName must be a bare host name or IP. A "host:port"
	// string is neither verified against the certificate correctly nor valid
	// as SNI. Fall back to the raw address if it cannot be split.
	outCfg.ServerName = outgoingAddr
	if host, _, err := net.SplitHostPort(outgoingAddr); err == nil {
		outCfg.ServerName = host
	}
	crdbConn = tls.Client(crdbConn, outCfg)

	if _, err := crdbConn.Write(msg.Encode(nil)); err != nil {
		s.metrics.BackendDownCount.Inc(1)
		return newErrorf(CodeBackendDown, "relaying StartupMessage to target server %v: %v", outgoingAddr, err)
	}

	// These channels are buffered because we'll only consume one of them.
	errOutgoing := make(chan error, 1) // client -> backend copy result
	errIncoming := make(chan error, 1) // backend -> client copy result
	go func() {
		_, err := io.Copy(crdbConn, conn)
		errOutgoing <- err
	}()
	go func() {
		_, err := io.Copy(conn, crdbConn)
		errIncoming <- err
	}()

	select {
	// NB: when using pgx, we see a nil errIncoming first on clean connection
	// termination. Using psql I see a nil errOutgoing first. I think the PG
	// protocol stipulates sending a message to the server at which point
	// the server closes the connection (errIncoming), but presumably the
	// client gets to close the connection once it's sent that message,
	// meaning either case is possible.
	case err := <-errIncoming:
		if err != nil {
			s.metrics.BackendDisconnectCount.Inc(1)
			return newErrorf(CodeBackendDisconnected, "copying from target server to client: %v", err)
		}
		return nil
	case err := <-errOutgoing:
		// The incoming connection got closed.
		if err != nil {
			s.metrics.ClientDisconnectCount.Inc(1)
			return newErrorf(CodeClientDisconnected, "copying from client to target server: %v", err)
		}
		return nil
	}
}