forked from Venafi/vcert
-
Notifications
You must be signed in to change notification settings - Fork 0
/
listener.go
141 lines (129 loc) · 3.35 KB
/
listener.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
package vcert
import (
"crypto/tls"
"crypto/x509/pkix"
"fmt"
"log"
"net"
"time"
"github.com/Venafi/vcert/v5/pkg/certificate"
"github.com/Venafi/vcert/v5/pkg/endpoint"
"github.com/Venafi/vcert/v5/pkg/util"
)
// NewListener returns a net.Listener that serves TLS using certificates
// enrolled via Venafi for each of the provided domains.
//
// Each entry in domains is either a bare host name ("example.com") or a
// host:port pair ("example.com:8443"). The listener binds, on all
// interfaces, to the port found in the domains list, or to the default
// port 443 when no entry carries a port. Entries naming different ports
// are rejected with a "ports conflict" error.
//
// It enables one-line HTTPS servers:
//
//	log.Fatal(http.Serve(cfg.NewListener("example.com"), handler))
//
// NewListener never returns nil. Any failure (no domains given, client
// creation, conflicting ports, certificate enrollment, binding the port)
// is stored in the returned listener and reported by its Accept and Close
// methods.
//
// The listener's *tls.Config leaves NextProtos unset, so it does not
// advertise HTTP/2 (or any other protocol) via ALPN.
//
// The returned listener also enables TCP keep-alives on the accepted
// connections. The *tls.Conn values are returned before their TLS
// handshake has completed.
func (cfg *Config) NewListener(domains ...string) net.Listener {
	l := listener{}

	// Fail fast: without at least one domain there is no certificate to
	// serve and every TLS handshake on the listener would fail.
	if len(domains) == 0 {
		l.e = fmt.Errorf("no domains specified")
		return &l
	}

	conn, err := cfg.NewClient()
	if err != nil {
		l.e = err
		return &l
	}

	certs := make([]tls.Certificate, len(domains))
	certsMap := make(map[string]*tls.Certificate, len(domains))
	port := ""
	for i, d := range domains {
		// Entries that do not parse as host:port are used verbatim as the
		// domain name, so the SplitHostPort error is deliberately ignored.
		if host, p, err := net.SplitHostPort(d); err == nil {
			if port != "" && p != port {
				l.e = fmt.Errorf("ports conflict: %v and %v", p, port)
				return &l
			}
			port, d = p, host
		}

		log.Println("Retrieving certificate for domain", d)
		cert, err := getSimpleCertificate(conn, d)
		if err != nil {
			l.e = err
			return &l
		}
		certs[i] = cert
		// Point into the slice so the map and Certificates share storage.
		certsMap[d] = &certs[i]
	}
	if port == "" {
		port = "443"
	}

	/* #nosec */
	l.conf = &tls.Config{
		Certificates:      certs,
		NameToCertificate: certsMap,
	}

	l.Listener, l.e = net.Listen("tcp", ":"+port)
	// Only announce the server once the port is actually bound.
	if l.e == nil {
		log.Println("Starting server on port", port)
	}
	return &l
}
// getSimpleCertificate enrolls a certificate for cn through conn and
// returns it as a tls.Certificate.
//
// The request uses cn both as the subject common name and as its only DNS
// name, with CsrOrigin set to certificate.LocalGeneratedCSR. After the
// request is submitted, its PickupID is set to the returned request ID and
// its Timeout to one minute before the certificate is retrieved. The
// request's private key is attached to the retrieved collection and passed
// through util.DecryptPkcs8PrivateKey with an empty password before the
// collection is converted.
//
// On any failure the zero tls.Certificate and the underlying error are
// returned.
func getSimpleCertificate(conn endpoint.Connector, cn string) (tls.Certificate, error) {
	var none tls.Certificate

	request := certificate.Request{
		Subject:   pkix.Name{CommonName: cn},
		DNSNames:  []string{cn},
		CsrOrigin: certificate.LocalGeneratedCSR,
	}

	zoneConfig, err := conn.ReadZoneConfiguration()
	if err != nil {
		return none, err
	}
	if err = conn.GenerateRequest(zoneConfig, &request); err != nil {
		return none, err
	}

	pickupID, err := conn.RequestCertificate(&request)
	if err != nil {
		return none, err
	}
	request.PickupID = pickupID
	request.Timeout = time.Minute

	collection, err := conn.RetrieveCertificate(&request)
	if err != nil {
		return none, err
	}
	if err = collection.AddPrivateKey(request.PrivateKey, nil); err != nil {
		return none, err
	}

	decryptedKey, err := util.DecryptPkcs8PrivateKey(collection.PrivateKey, "")
	if err != nil {
		return none, err
	}
	collection.PrivateKey = decryptedKey

	return collection.ToTLSCertificate(), nil
}
// listener is the net.Listener implementation returned by
// Config.NewListener. It wraps every connection accepted from the embedded
// Listener in a TLS server connection configured with conf.
type listener struct {
	net.Listener
	// conf is the TLS configuration passed to tls.Server for each
	// accepted connection.
	conf *tls.Config
	// e holds a setup error. When it is non-nil, Accept and Close return
	// it and the embedded Listener is never touched.
	e error
}

// Accept waits for the next connection on the embedded Listener and returns
// it wrapped by tls.Server; the TLS handshake has not been performed yet.
//
// If the listener carries a setup error, that error is returned instead.
// For *net.TCPConn connections keep-alive is enabled with a three-minute
// period; if that fails the connection is closed and the error returned.
// Connections of any other type are wrapped unchanged.
func (ln *listener) Accept() (net.Conn, error) {
	if ln.e != nil {
		return nil, ln.e
	}
	conn, err := ln.Listener.Accept()
	if err != nil {
		return nil, err
	}

	// Only TCP connections support keep-alive tuning; use the comma-ok form
	// so a non-TCP inner listener does not cause a panic.
	if tcpConn, ok := conn.(*net.TCPConn); ok {
		err = tcpConn.SetKeepAlive(true)
		if err == nil {
			err = tcpConn.SetKeepAlivePeriod(3 * time.Minute)
		}
		if err != nil {
			// The caller never receives conn, so close it here to avoid
			// leaking it. The Close error is ignored because the
			// keep-alive failure is the one worth reporting.
			_ = conn.Close()
			return nil, err
		}
	}
	return tls.Server(conn, ln.conf), nil
}

// Close returns the stored setup error if there is one; otherwise it closes
// the embedded Listener and returns its result.
func (ln *listener) Close() error {
	if ln.e != nil {
		return ln.e
	}
	return ln.Listener.Close()
}