-
Notifications
You must be signed in to change notification settings - Fork 3.8k
/
frontend_admitter.go
74 lines (64 loc) · 2.42 KB
/
frontend_admitter.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
// Copyright 2020 The Cockroach Authors.
//
// Licensed as a CockroachDB Enterprise file under the Cockroach Community
// License (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt
package sqlproxyccl
import (
"crypto/tls"
"net"
"github.com/jackc/pgproto3/v2"
)
// FrontendAdmit is an example frontend admitter. It reads the client's
// startup sequence from conn and returns the connection to use for the rest
// of the session together with the client's StartupMessage.
//
// If incomingTLSConfig is non-nil, the client must open with an SSLRequest:
// the request is acknowledged, conn is wrapped in a TLS server connection
// built from a clone of incomingTLSConfig, and the StartupMessage is then read
// over TLS. If the client supplied an SNI server name during the handshake, it
// is added to the returned message's parameters under the "sni-server" key.
//
// If, with TLS required, the client opens with a CancelRequest instead,
// FrontendAdmit returns (nil, nil, nil); callers must handle a nil connection
// accompanied by a nil error.
//
// When TLS was negotiated, the returned net.Conn is a *tls.Conn wrapping the
// original conn. Callers should close the returned connection rather than the
// original conn, since closing the original does not invoke tls.Conn's Close.
func FrontendAdmit(
	conn net.Conn, incomingTLSConfig *tls.Config,
) (net.Conn, *pgproto3.StartupMessage, error) {
	// sniServerName is set by the GetConfigForClient callback below while the
	// TLS handshake runs. crypto/tls performs the handshake lazily on the first
	// Read, i.e. inside the post-TLS ReceiveStartupMessage call.
	var sniServerName string

	// If we have an incoming TLS config, require that the client initiates
	// with a TLS connection (an SSLRequest).
	if incomingTLSConfig != nil {
		m, err := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn).ReceiveStartupMessage()
		if err != nil {
			return nil, nil, NewErrorf(CodeClientReadFailed, "while receiving startup message: %v", err)
		}
		switch m.(type) {
		case *pgproto3.SSLRequest:
			// Expected: proceed below to acknowledge and upgrade to TLS.
		case *pgproto3.CancelRequest:
			// Ignore CancelRequest explicitly. We don't need to do this but it makes
			// testing easier by avoiding a call to sendErrToClient on this path
			// (which would confuse assertCtx).
			return nil, nil, nil
		default:
			code := CodeUnexpectedInsecureStartupMessage
			return nil, nil, NewErrorf(code, "unsupported startup message: %T", m)
		}

		if _, err := conn.Write([]byte{pgAcceptSSLRequest}); err != nil {
			return nil, nil, NewErrorf(CodeClientWriteFailed, "acking SSLRequest: %v", err)
		}

		// Clone so that installing the SNI-capturing callback does not mutate
		// the caller's config.
		cfg := incomingTLSConfig.Clone()
		cfg.GetConfigForClient = func(h *tls.ClientHelloInfo) (*tls.Config, error) {
			sniServerName = h.ServerName
			// A nil config tells crypto/tls to keep using cfg unchanged.
			return nil, nil
		}
		conn = tls.Server(conn, cfg)
	}

	m, err := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn).ReceiveStartupMessage()
	if err != nil {
		return nil, nil, NewErrorf(CodeClientReadFailed, "receiving post-TLS startup message: %v", err)
	}
	msg, ok := m.(*pgproto3.StartupMessage)
	if !ok {
		return nil, nil, NewErrorf(CodeUnexpectedStartupMessage, "unsupported post-TLS startup message: %T", m)
	}

	// Add the SNI server name (if the client sent one) as a parameter. Make
	// sure the map exists first so the write cannot panic on a nil map.
	if sniServerName != "" {
		if msg.Parameters == nil {
			msg.Parameters = make(map[string]string)
		}
		msg.Parameters["sni-server"] = sniServerName
	}
	return conn, msg, nil
}