-
Notifications
You must be signed in to change notification settings - Fork 3.8k
/
proxy.go
229 lines (201 loc) · 7.51 KB
/
proxy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
// Copyright 2020 The Cockroach Authors.
//
// Licensed as a CockroachDB Enterprise file under the Cockroach Community
// License (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt
package sqlproxyccl
import (
"crypto/tls"
"encoding/binary"
"io"
"net"
"github.com/jackc/pgproto3/v2"
)
// pgAcceptSSLRequest is the single-byte reply that accepts an SSLRequest.
// Proxy sends it to clients and expects it back from backends; any other
// reply from a backend is treated as a refusal to use TLS.
const pgAcceptSSLRequest = 'S'

// pgSSLRequest is the SSLRequest message, written as two big-endian int32s:
// the message length in bytes (8, counting itself) followed by the SSLRequest
// code, which by definition holds 1234 in its high 16 bits and 5679 in its
// low 16 bits (i.e. 80877103).
// See https://www.postgresql.org/docs/9.1/protocol-message-formats.html.
var pgSSLRequest = []int32{8, 1234<<16 | 5679}
// BackendConfig contains the configuration of a backend connection that is
// being proxied.
type BackendConfig struct {
	// OutgoingAddress is the address of the backend SQL server. Proxy dials
	// it over TCP.
	OutgoingAddress string

	// TLSConf is the TLS configuration for the proxy -> backend connection.
	// Proxy clones it for each connection. It should carry a ServerName that
	// is appropriate for the remote backend.
	TLSConf *tls.Config

	// RefuseConn, if true, makes Proxy reject the client (with
	// CodeProxyRefusedConnection) without dialing the backend.
	RefuseConn bool

	// OnConnectionSuccess, if non-nil, is called after the client's startup
	// message has been written to the backend and before traffic is relayed
	// in both directions. A non-nil error aborts the connection and its
	// message is sent to the client. Proxy calls it without waiting for any
	// reply from the backend, so it does not imply the client authenticated.
	OnConnectionSuccess func() error
}
// Options are the options to the Proxy method.
type Options struct {
	// IncomingTLSConfig is the TLS configuration used for the client -> proxy
	// connection. Proxy clones it for each connection.
	IncomingTLSConfig *tls.Config

	// BackendConfigFromSNI, if non-nil, is consulted during the client TLS
	// upgrade; serverName is meant to be the TLS server name (SNI) requested
	// by the client. A non-nil clientErr rejects the connection with
	// CodeSNIRoutingFailed.
	//
	// TODO(tbg): this is unimplemented and exists only to check which clients
	// allow use of SNI. Implementations should return a config with an empty
	// OutgoingAddress and a nil error; Proxy fails the connection if
	// OutgoingAddress is non-empty.
	BackendConfigFromSNI func(serverName string) (config *BackendConfig, clientErr error)

	// BackendConfigFromParams returns the config to use for the proxy ->
	// backend connection, given the parameters of the client's startup
	// message and the host part of the client's remote address. The TLS
	// config is in it and it must have an appropriate ServerName for the
	// remote backend. A non-nil clientErr is sent to the client and the
	// connection is rejected with CodeParamsRoutingFailed; otherwise the
	// returned config must be non-nil. This field must be set: Proxy calls it
	// unconditionally.
	BackendConfigFromParams func(
		params map[string]string, ipAddress string,
	) (config *BackendConfig, clientErr error)

	// ModifyRequestParams, if set, is consulted to modify the parameters set
	// by the frontend before forwarding them to the backend during startup.
	// Its return value is not used, so it must modify the map in place.
	ModifyRequestParams func(map[string]string)

	// OnSendErrToClient, if set, is consulted to decorate an error message to
	// be sent to the client; the returned string replaces msg. The message
	// passed to this method will contain no internal information.
	OnSendErrToClient func(code ErrorCode, msg string) string
}
// Proxy takes an incoming client connection and relays it to a backend SQL
// server.
//
// The client must open with an SSLRequest, which Proxy acknowledges before
// upgrading conn to TLS with a clone of Options.IncomingTLSConfig. The
// parameters of the client's post-TLS StartupMessage, together with the
// client's IP address, are handed to Options.BackendConfigFromParams to pick
// a backend. Proxy then dials that backend over TCP, performs its own
// SSLRequest exchange, upgrades to TLS with a clone of BackendConfig.TLSConf,
// and forwards the (optionally modified) StartupMessage. After that, bytes
// are copied in both directions until either copy finishes.
//
// A CancelRequest from the client is ignored and yields a nil error. Other
// failures are returned via newErrorf with an ErrorCode naming the failed
// stage; on some of these paths an ErrorResponse is also sent to the client
// first. Proxy does not close conn; it does close the backend connection it
// dials before returning.
func (s *Server) Proxy(conn net.Conn) error {
	// sendErrToClient writes a FATAL ErrorResponse to conn. The wire SQLSTATE
	// is always 08004; code is only passed to OnSendErrToClient.
	sendErrToClient := func(conn net.Conn, code ErrorCode, msg string) {
		if s.opts.OnSendErrToClient != nil {
			msg = s.opts.OnSendErrToClient(code, msg)
		}
		// Best effort: the connection is being rejected anyway, so there is
		// nowhere useful to report a failed write.
		_, _ = conn.Write((&pgproto3.ErrorResponse{
			Severity: "FATAL",
			Code:     "08004", // rejected connection
			Message:  msg,
		}).Encode(nil))
	}

	{
		m, err := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn).ReceiveStartupMessage()
		if err != nil {
			return newErrorf(CodeClientReadFailed, "while receiving startup message: %v", err)
		}
		switch m.(type) {
		case *pgproto3.SSLRequest:
		case *pgproto3.CancelRequest:
			// Ignore CancelRequest explicitly. We don't need to do this but it makes
			// testing easier by avoiding a call to sendErrToClient on this path
			// (which would confuse assertCtx).
			return nil
		default:
			code := CodeUnexpectedInsecureStartupMessage
			sendErrToClient(conn, code, "server requires encryption")
			return newErrorf(code, "unsupported startup message: %T", m)
		}

		if _, err := conn.Write([]byte{pgAcceptSSLRequest}); err != nil {
			return newErrorf(CodeClientWriteFailed, "acking SSLRequest: %v", err)
		}

		tlsConn := tls.Server(conn, s.opts.IncomingTLSConfig.Clone())
		conn = tlsConn

		if s.opts.BackendConfigFromSNI != nil {
			// tls.Server defers the handshake until the first Read, so the SNI
			// server name is unknown until we force the handshake here.
			// Reading it any earlier always yields "".
			if err := tlsConn.Handshake(); err != nil {
				return newErrorf(CodeClientReadFailed, "TLS handshake with client: %v", err)
			}
			sniCfg, clientErr := s.opts.BackendConfigFromSNI(tlsConn.ConnectionState().ServerName)
			if clientErr != nil {
				code := CodeSNIRoutingFailed
				// Sent over the established TLS session, so the client can
				// actually read it.
				sendErrToClient(conn, code, clientErr.Error())
				return newErrorf(code, "rejected by BackendConfigFromSNI")
			}
			// The hook is expected to return no routing information; a nil
			// config is valid and must not be dereferenced.
			if sniCfg != nil && sniCfg.OutgoingAddress != "" {
				return newErrorf(CodeSNIRoutingFailed, "BackendConfigFromSNI is unimplemented")
			}
		}
	}

	m, err := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn).ReceiveStartupMessage()
	if err != nil {
		return newErrorf(CodeClientReadFailed, "receiving post-TLS startup message: %v", err)
	}
	msg, ok := m.(*pgproto3.StartupMessage)
	if !ok {
		return newErrorf(CodeUnexpectedStartupMessage, "unsupported post-TLS startup message: %T", m)
	}

	var backendConfig *BackendConfig
	{
		ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
		if err != nil {
			return newErrorf(
				CodeParamsRoutingFailed, "could not parse address %s: %v",
				conn.RemoteAddr().String(), err)
		}
		var clientErr error
		backendConfig, clientErr = s.opts.BackendConfigFromParams(msg.Parameters, ip)
		if clientErr != nil {
			s.metrics.RoutingErrCount.Inc(1)
			code := CodeParamsRoutingFailed
			sendErrToClient(conn, code, clientErr.Error())
			return newErrorf(code, "rejected by BackendConfigFromParams: %v", clientErr)
		}
	}

	if backendConfig.RefuseConn {
		s.metrics.RefusedConnCount.Inc(1)
		code := CodeProxyRefusedConnection
		sendErrToClient(conn, code, "backend refused to admit")
		return newErrorf(code, "backend refused to admit")
	}

	crdbConn, err := net.Dial("tcp", backendConfig.OutgoingAddress)
	if err != nil {
		s.metrics.BackendDownCount.Inc(1)
		code := CodeBackendDown
		sendErrToClient(conn, code, "unable to reach backend SQL server")
		return newErrorf(code, "dialing backend server: %v", err)
	}
	// Close the backend connection on every return path below. The closure
	// reads crdbConn at return time, so it closes the TLS wrapper once one is
	// installed. Closing it also unblocks the backend -> client relay
	// goroutine if Proxy returns while that copy is still running.
	defer func() { _ = crdbConn.Close() }()

	// Ask the backend to upgrade to TLS, mirroring the client's SSLRequest.
	if err := binary.Write(crdbConn, binary.BigEndian, pgSSLRequest); err != nil {
		s.metrics.BackendDownCount.Inc(1)
		return newErrorf(CodeBackendDown, "sending SSLRequest to target server: %v", err)
	}
	response := make([]byte, 1)
	if _, err := io.ReadFull(crdbConn, response); err != nil {
		s.metrics.BackendDownCount.Inc(1)
		return newErrorf(CodeBackendDown, "reading response to SSLRequest: %v", err)
	}
	if response[0] != pgAcceptSSLRequest {
		s.metrics.BackendDownCount.Inc(1)
		return newErrorf(CodeBackendRefusedTLS, "target server refused TLS connection")
	}
	crdbConn = tls.Client(crdbConn, backendConfig.TLSConf.Clone())

	if s.opts.ModifyRequestParams != nil {
		s.opts.ModifyRequestParams(msg.Parameters)
	}
	// The first Write on the tls.Client conn also drives its TLS handshake.
	if _, err := crdbConn.Write(msg.Encode(nil)); err != nil {
		s.metrics.BackendDownCount.Inc(1)
		return newErrorf(CodeBackendDown, "relaying StartupMessage to target server %v: %v",
			backendConfig.OutgoingAddress, err)
	}

	if backendConfig.OnConnectionSuccess != nil {
		if err := backendConfig.OnConnectionSuccess(); err != nil {
			code := CodeBackendDown
			sendErrToClient(conn, code, err.Error())
			s.metrics.BackendDownCount.Inc(1)
			return newErrorf(code, "recording connection success: %v", err)
		}
	}

	// Relay bytes in both directions. These channels are buffered because
	// we'll only consume one of them; the other goroutine can still deliver
	// its result and exit. The client -> backend goroutine exits once conn is
	// closed, which Proxy leaves to its caller.
	errOutgoing := make(chan error, 1) // client -> target server
	errIncoming := make(chan error, 1) // target server -> client
	go func() {
		_, err := io.Copy(crdbConn, conn)
		errOutgoing <- err
	}()
	go func() {
		_, err := io.Copy(conn, crdbConn)
		errIncoming <- err
	}()

	select {
	// NB: when using pgx, we see a nil errIncoming first on clean connection
	// termination. Using psql I see a nil errOutgoing first. I think the PG
	// protocol stipulates sending a message to the server at which point
	// the server closes the connection (errIncoming), but presumably the
	// client gets to close the connection once it's sent that message,
	// meaning either case is possible.
	case err := <-errIncoming:
		// The target server -> client copy finished.
		if err != nil {
			s.metrics.BackendDisconnectCount.Inc(1)
			return newErrorf(CodeBackendDisconnected, "copying from target server to client: %v", err)
		}
		return nil
	case err := <-errOutgoing:
		// The client -> target server copy finished: the incoming connection
		// got closed.
		if err != nil {
			s.metrics.ClientDisconnectCount.Inc(1)
			return newErrorf(CodeClientDisconnected, "copying from client to target server: %v", err)
		}
		return nil
	}
}