Skip to content

Commit

Permalink
feat: callback notify function when connection is refused (#2308)
Browse files Browse the repository at this point in the history
Add option for adding a callback function for connection
refused events due to max connection limit.

---------

Co-authored-by: Che Lin <[email protected]>
Co-authored-by: Jack Wotherspoon <[email protected]>
  • Loading branch information
3 people authored Oct 23, 2024
1 parent 9be642b commit 9309b84
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 27 deletions.
8 changes: 8 additions & 0 deletions cmd/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,11 @@ func WithLazyRefresh() Option {
c.conf.LazyRefresh = true
}
}

// WithConnRefuseNotify configures the Proxy to call the provided function when
// a connection is refused. The notification function is run in a goroutine.
func WithConnRefuseNotify(n func()) Option {
return func(c *Command) {
c.connRefuseNotify = n
}
}
11 changes: 6 additions & 5 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,11 @@ func Execute() {
// Command represents an invocation of the Cloud SQL Auth Proxy.
type Command struct {
*cobra.Command
conf *proxy.Config
logger cloudsql.Logger
dialer cloudsql.Dialer
cleanup func() error
conf *proxy.Config
logger cloudsql.Logger
dialer cloudsql.Dialer
cleanup func() error
connRefuseNotify func()
}

var longHelp = `
Expand Down Expand Up @@ -1025,7 +1026,7 @@ func runSignalWrapper(cmd *Command) (err error) {
startCh := make(chan *proxy.Client)
go func() {
defer close(startCh)
p, err := proxy.NewClient(ctx, cmd.dialer, cmd.logger, cmd.conf)
p, err := proxy.NewClient(ctx, cmd.dialer, cmd.logger, cmd.conf, cmd.connRefuseNotify)
if err != nil {
cmd.logger.Debugf("Error starting proxy: %v", err)
shutdownCh <- fmt.Errorf("unable to start: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion internal/healthcheck/healthcheck_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func newProxyWithParams(t *testing.T, maxConns uint64, dialer cloudsql.Dialer, i
Instances: instances,
MaxConnections: maxConns,
}
p, err := proxy.NewClient(context.Background(), dialer, logger, c)
p, err := proxy.NewClient(context.Background(), dialer, logger, c, nil)
if err != nil {
t.Fatalf("proxy.NewClient: %v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions internal/proxy/fuse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func newTestClient(t *testing.T, d cloudsql.Dialer, fuseDir, fuseTempDir string)
conf := &proxy.Config{FUSEDir: fuseDir, FUSETempDir: fuseTempDir}

// This context is only used to call the Cloud SQL API
c, err := proxy.NewClient(context.Background(), d, testLogger, conf)
c, err := proxy.NewClient(context.Background(), d, testLogger, conf, nil)
if err != nil {
t.Fatalf("want error = nil, got = %v", err)
}
Expand Down Expand Up @@ -424,7 +424,7 @@ func TestFUSEWithBadDir(t *testing.T) {
t.Skip("skipping fuse tests in short mode.")
}
conf := &proxy.Config{FUSEDir: "/not/a/dir", FUSETempDir: randTmpDir(t)}
_, err := proxy.NewClient(context.Background(), &fakeDialer{}, testLogger, conf)
_, err := proxy.NewClient(context.Background(), &fakeDialer{}, testLogger, conf, nil)
if err == nil {
t.Fatal("proxy client should fail with bad dir")
}
Expand Down
14 changes: 10 additions & 4 deletions internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -503,12 +503,14 @@ type Client struct {

logger cloudsql.Logger

connRefuseNotify func()

fuseMount
}

// NewClient completes the initial setup required to get the proxy to a "steady"
// state.
func NewClient(ctx context.Context, d cloudsql.Dialer, l cloudsql.Logger, conf *Config) (*Client, error) {
func NewClient(ctx context.Context, d cloudsql.Dialer, l cloudsql.Logger, conf *Config, connRefuseNotify func()) (*Client, error) {
// Check if the caller has configured a dialer.
// Otherwise, initialize a new one.
if d == nil {
Expand All @@ -523,9 +525,10 @@ func NewClient(ctx context.Context, d cloudsql.Dialer, l cloudsql.Logger, conf *
}

c := &Client{
logger: l,
dialer: d,
conf: conf,
logger: l,
dialer: d,
connRefuseNotify: connRefuseNotify,
conf: conf,
}

if conf.FUSEDir != "" {
Expand Down Expand Up @@ -753,6 +756,9 @@ func (c *Client) serveSocketMount(_ context.Context, s *socketMount) error {

if c.conf.MaxConnections > 0 && count > c.conf.MaxConnections {
c.logger.Infof("max connections (%v) exceeded, refusing new connection", c.conf.MaxConnections)
if c.connRefuseNotify != nil {
go c.connRefuseNotify()
}
_ = cConn.Close()
return
}
Expand Down
3 changes: 2 additions & 1 deletion internal/proxy/proxy_other_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ func TestFuseClosesGracefully(t *testing.T) {
FUSEDir: t.TempDir(),
FUSETempDir: t.TempDir(),
Token: "mytoken",
})
},
nil)
if err != nil {
t.Fatal(err)
}
Expand Down
38 changes: 24 additions & 14 deletions internal/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ func TestClientInitialization(t *testing.T) {

for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
c, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, tc.in)
c, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, tc.in, nil)
if err != nil {
t.Fatalf("want error = nil, got = %v", err)
}
Expand Down Expand Up @@ -370,7 +370,13 @@ func TestClientLimitsMaxConnections(t *testing.T) {
},
MaxConnections: 1,
}
c, err := proxy.NewClient(context.Background(), d, testLogger, in)
callbackGot := 0
connRefuseNotify := func() {
d.mu.Lock()
defer d.mu.Unlock()
callbackGot++
}
c, err := proxy.NewClient(context.Background(), d, testLogger, in, connRefuseNotify)
if err != nil {
t.Fatalf("proxy.NewClient error: %v", err)
}
Expand Down Expand Up @@ -410,6 +416,10 @@ func TestClientLimitsMaxConnections(t *testing.T) {
if got := d.dialAttempts(); got != want {
t.Fatalf("dial attempts did not match expected, want = %v, got = %v", want, got)
}

if callbackGot == 0 {
t.Fatal("connRefuseNotifyCallback is not called")
}
}

func tryTCPDial(t *testing.T, addr string) net.Conn {
Expand Down Expand Up @@ -442,7 +452,7 @@ func TestClientCloseWaitsForActiveConnections(t *testing.T) {
},
WaitOnClose: 1 * time.Second,
}
c, err := proxy.NewClient(context.Background(), &fakeDialer{}, testLogger, in)
c, err := proxy.NewClient(context.Background(), &fakeDialer{}, testLogger, in, nil)
if err != nil {
t.Fatalf("proxy.NewClient error: %v", err)
}
Expand All @@ -464,7 +474,7 @@ func TestClientClosesCleanly(t *testing.T) {
{Name: "proj:reg:inst"},
},
}
c, err := proxy.NewClient(context.Background(), &fakeDialer{}, testLogger, in)
c, err := proxy.NewClient(context.Background(), &fakeDialer{}, testLogger, in, nil)
if err != nil {
t.Fatalf("proxy.NewClient error want = nil, got = %v", err)
}
Expand All @@ -486,7 +496,7 @@ func TestClosesWithError(t *testing.T) {
{Name: "proj:reg:inst"},
},
}
c, err := proxy.NewClient(context.Background(), &errorDialer{}, testLogger, in)
c, err := proxy.NewClient(context.Background(), &errorDialer{}, testLogger, in, nil)
if err != nil {
t.Fatalf("proxy.NewClient error want = nil, got = %v", err)
}
Expand Down Expand Up @@ -542,13 +552,13 @@ func TestClientInitializationWorksRepeatedly(t *testing.T) {
},
}

c, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, in)
c, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, in, nil)
if err != nil {
t.Fatalf("want error = nil, got = %v", err)
}
c.Close()

c, err = proxy.NewClient(ctx, &fakeDialer{}, testLogger, in)
c, err = proxy.NewClient(ctx, &fakeDialer{}, testLogger, in, nil)
if err != nil {
t.Fatalf("want error = nil, got = %v", err)
}
Expand All @@ -562,7 +572,7 @@ func TestClientNotifiesCallerOnServe(t *testing.T) {
{Name: "proj:region:pg"},
},
}
c, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, in)
c, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, in, nil)
if err != nil {
t.Fatalf("want error = nil, got = %v", err)
}
Expand Down Expand Up @@ -595,7 +605,7 @@ func TestClientConnCount(t *testing.T) {
MaxConnections: 10,
}

c, err := proxy.NewClient(context.Background(), &fakeDialer{}, testLogger, in)
c, err := proxy.NewClient(context.Background(), &fakeDialer{}, testLogger, in, nil)
if err != nil {
t.Fatalf("proxy.NewClient error: %v", err)
}
Expand Down Expand Up @@ -636,7 +646,7 @@ func TestCheckConnections(t *testing.T) {
},
}
d := &fakeDialer{}
c, err := proxy.NewClient(context.Background(), d, testLogger, in)
c, err := proxy.NewClient(context.Background(), d, testLogger, in, nil)
if err != nil {
t.Fatalf("proxy.NewClient error: %v", err)
}
Expand Down Expand Up @@ -664,7 +674,7 @@ func TestCheckConnections(t *testing.T) {
},
}
ed := &errorDialer{}
c, err = proxy.NewClient(context.Background(), ed, testLogger, in)
c, err = proxy.NewClient(context.Background(), ed, testLogger, in, nil)
if err != nil {
t.Fatalf("proxy.NewClient error: %v", err)
}
Expand All @@ -690,7 +700,7 @@ func TestRunConnectionCheck(t *testing.T) {
RunConnectionTest: true,
}
d := &fakeDialer{}
c, err := proxy.NewClient(context.Background(), d, testLogger, in)
c, err := proxy.NewClient(context.Background(), d, testLogger, in, nil)
if err != nil {
t.Fatalf("proxy.NewClient error: %v", err)
}
Expand Down Expand Up @@ -757,7 +767,7 @@ func TestProxyInitializationWithFailedUnixSocket(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, tc.in)
_, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, tc.in, nil)
if err == nil {
t.Fatalf("want non nil error, got = %v", err)
}
Expand Down Expand Up @@ -801,7 +811,7 @@ func TestProxyMultiInstances(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, tc.in)
_, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, tc.in, nil)
if tc.wantSuccess != (err == nil) {
t.Fatalf("want return = %v, got = %v", tc.wantSuccess, err == nil)
}
Expand Down

0 comments on commit 9309b84

Please sign in to comment.