-
Notifications
You must be signed in to change notification settings - Fork 3.8k
/
mt_proxy.go
186 lines (165 loc) · 5.56 KB
/
mt_proxy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
// Copyright 2022 The Cockroach Authors.
//
// Licensed as a CockroachDB Enterprise file under the Cockroach Community
// License (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt
package cliccl
import (
"context"
"fmt"
"net"
"os"
"os/signal"
"time"
"github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl"
"github.com/cockroachdb/cockroach/pkg/cli"
"github.com/cockroachdb/cockroach/pkg/cli/clierrorplus"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/cockroach/pkg/util/log/severity"
"github.com/cockroachdb/cockroach/pkg/util/stop"
"github.com/cockroachdb/errors"
"github.com/cockroachdb/redact"
"github.com/spf13/cobra"
)
const (
	// shutdownConnectionTimeout bounds how long a graceful shutdown waits for
	// client connections to close on their own. Once it elapses the server is
	// stopped anyway, which forcefully closes whatever connections remain.
	shutdownConnectionTimeout = 59 * time.Minute
)
// mtStartSQLProxyCmd defines the `start-proxy` subcommand. It accepts no
// positional arguments; its work is done by runStartSQLProxy, whose errors
// are passed through clierrorplus.MaybeDecorateError.
var mtStartSQLProxyCmd = &cobra.Command{
	Use:   "start-proxy",
	Short: "start a sql proxy",
	Args:  cobra.NoArgs,
	RunE:  clierrorplus.MaybeDecorateError(runStartSQLProxy),
	Long: `Starts a SQL proxy.
This proxy accepts incoming connections and relays them to a backend server
determined by the arguments used.
`,
}
// runStartSQLProxy implements the `start-proxy` command.
//
// It performs these steps:
//   - Initializes logging and a stopper.
//   - Opens the proxy listener (proxyContext.ListenAddr) and the metrics
//     listener (proxyContext.MetricsAddress).
//   - Builds the sqlproxyccl server.
//   - Serves both listeners on stopper-managed async tasks.
//   - Blocks in waitForSignals until a serving error, a stop request, a drain
//     signal or a fatal log event.
//
// args is unused; the command is declared with cobra.NoArgs.
func runStartSQLProxy(cmd *cobra.Command, args []string) (returnErr error) {
	// Initialize logging, stopper and context that can be canceled.
	ctx, stopper, err := initLogging(cmd)
	if err != nil {
		return err
	}
	defer stopper.Stop(ctx)

	log.Infof(ctx, "New proxy with opts: %+v", proxyContext)

	proxyLn, err := net.Listen("tcp", proxyContext.ListenAddr)
	if err != nil {
		return err
	}
	// Register the proxy listener with the stopper immediately, mirroring the
	// metrics listener below, so it is closed on every exit path. That
	// includes the early returns if the metrics listener or the server cannot
	// be created. waitForSignals may also close it during a graceful shutdown;
	// closing it twice only yields an error, which is ignored here.
	stopper.AddCloser(stop.CloserFn(func() { _ = proxyLn.Close() }))

	metricsLn, err := net.Listen("tcp", proxyContext.MetricsAddress)
	if err != nil {
		return err
	}
	stopper.AddCloser(stop.CloserFn(func() { _ = metricsLn.Close() }))

	server, err := sqlproxyccl.NewServer(ctx, stopper, proxyContext)
	if err != nil {
		return err
	}

	// One buffer slot per serving task. waitForSignals reads at most one value
	// and may not read any at all (e.g. on a signal), so with a single slot a
	// second failing task could block forever on its send while holding a
	// stopper task open.
	errChan := make(chan error, 2)

	if err := stopper.RunAsyncTask(ctx, "serve-http", func(ctx context.Context) {
		log.Infof(ctx, "HTTP metrics server listening at %s", metricsLn.Addr())
		if err := server.ServeHTTP(ctx, metricsLn); err != nil {
			errChan <- err
		}
	}); err != nil {
		return err
	}

	if err := stopper.RunAsyncTask(ctx, "serve-proxy", func(ctx context.Context) {
		log.Infof(ctx, "proxy server listening at %s", proxyLn.Addr())
		if err := server.Serve(ctx, proxyLn); err != nil {
			errChan <- err
		}
	}); err != nil {
		return err
	}

	return waitForSignals(ctx, server, stopper, proxyLn, errChan)
}
// initLogging sets up logging for the multi-tenant proxy command and
// returns a context tied to the returned stopper.
//
// Setup is delegated to cli.ClearStoresAndSetupLoggingForMTCommands. That
// helper removes the default store so it is not used to configure logging,
// and logging defaults to stderr unless --log-dir is specified. This suits
// the standalone SQL server, which at the time of writing is stateless and
// may not be provisioned with suitable storage.
//
// On failure it returns a background context, a nil stopper and the setup
// error. On success the context is derived via stopper.WithCancelOnQuiesce
// (presumably canceled once the stopper quiesces — confirm against stop
// package).
func initLogging(cmd *cobra.Command) (ctx context.Context, stopper *stop.Stopper, err error) {
	base := context.Background()

	s, setupErr := cli.ClearStoresAndSetupLoggingForMTCommands(cmd, base)
	if setupErr != nil {
		return base, nil, setupErr
	}

	// The cancel func is deliberately discarded; cancellation is left to the
	// stopper rather than to this function's callers.
	quiesceCtx, _ := s.WithCancelOnQuiesce(base)
	return quiesceCtx, s, nil
}
// waitForSignals blocks until the proxy should exit and returns the error the
// command should report.
//
// The first of these events decides how shutdown proceeds:
//   - An error from a serving task on errChan is returned immediately.
//   - A stop requested through the stopper: wait until it has fully stopped.
//   - A drain signal (INT or TERM, per cli.DrainSignals): close proxyLn so no
//     new connections arrive. Then, in the background, wait until the server
//     reports no connections or shutdownConnectionTimeout elapses, and stop
//     the stopper. An interrupt makes the function return an "interrupted"
//     error; other drain signals return nil.
//   - A fatal log event: stop the stopper and block forever, leaving process
//     exit to the logging machinery.
//
// Except for the errChan and fatal cases, it then waits for the stopper to
// finish before returning. One further drain signal during that wait is
// logged and otherwise ignored; the next one triggers a hard shutdown via
// panic.
func waitForSignals(
	ctx context.Context,
	server *sqlproxyccl.Server,
	stopper *stop.Stopper,
	proxyLn net.Listener,
	errChan chan error,
) (returnErr error) {
	// Need to alias the signals if this has to run on non-unix OSes too.
	signalCh := make(chan os.Signal, 1)
	signal.Notify(signalCh, cli.DrainSignals...)
	// Unregister on return. Otherwise signals arriving afterwards, such as
	// during the caller's deferred stopper.Stop, would be silently captured by
	// a channel nobody reads instead of getting Go's default handling.
	defer signal.Stop(signalCh)

	select {
	case err := <-errChan:
		log.StartAlwaysFlush()
		return err
	case <-stopper.ShouldQuiesce():
		// Stop has been requested through the stopper's Stop.
		<-stopper.IsStopped()
		// StartAlwaysFlush both flushes and ensures that subsequent log
		// writes are flushed too.
		log.StartAlwaysFlush()
	case sig := <-signalCh: // INT or TERM
		log.StartAlwaysFlush() // In case the caller follows up with KILL.
		log.Ops.Infof(ctx, "received signal '%s'", sig)
		if sig == os.Interrupt {
			returnErr = errors.New("interrupted")
		}
		go func() {
			// Begin shutdown by:
			// 1. Stopping the TCP listener so no new connections can be established.
			// 2. Waiting for all connections to close "naturally" or for
			//    shutdownConnectionTimeout to elapse, after which open TCP
			//    connections are forcefully closed so the server can stop.
			log.Infof(ctx, "stopping tcp listener")
			_ = proxyLn.Close()
			select {
			case <-server.AwaitNoConnections(ctx):
			case <-time.After(shutdownConnectionTimeout):
			}
			log.Infof(ctx, "server stopping")
			stopper.Stop(ctx)
		}()
	case <-log.FatalChan():
		stopper.Stop(ctx)
		select {} // Block and wait for logging go routine to shut down the process.
	}

	// K8s will send two SIGTERM signals (one in preStop hook and one afterwards)
	// and we do not want to force shutdown until the third signal.
	// TODO(pjtatlow): remove this once we can do graceful restarts with externalNetworkPolicy=local
	// https://github.com/kubernetes/enhancements/issues/1669
	numInterrupts := 0
	for {
		select {
		case sig := <-signalCh:
			if numInterrupts == 0 {
				numInterrupts++
				log.Ops.Infof(ctx, "received additional signal '%s'; continuing graceful shutdown. Next signal will force shutdown.", sig)
				continue
			}
			log.Ops.Shoutf(ctx, severity.ERROR,
				"received signal '%s' during shutdown, initiating hard shutdown", redact.Safe(sig))
			panic("terminate")
		case <-stopper.IsStopped():
			const msgDone = "server shutdown completed"
			log.Ops.Infof(ctx, msgDone)
			fmt.Fprintln(os.Stdout, msgDone)
			return returnErr
		}
	}
}