Skip to content

Commit

Permalink
Merge pull request #105 from pdecat/master
Browse files Browse the repository at this point in the history
Add -max_connections flag to put a hard limit on the number of connections
  • Loading branch information
AthenaShi authored Sep 29, 2017
2 parents 473b68c + 92b6fc6 commit b4646d0
Show file tree
Hide file tree
Showing 9 changed files with 559 additions and 370 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
tests/cloud_sql_proxy
6 changes: 5 additions & 1 deletion cmd/cloud_sql_proxy/cloud_sql_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ directory at 'dir' must be empty before this program is started.`)
fuseTmp = flag.String("fuse_tmp", defaultTmp, `Used as a temporary directory if -fuse is set. Note that files in this directory
can be removed automatically by this program.`)

// Settings for limits
maxConnections = flag.Uint64("max_connections", 0, `If provided, the maximum number of connections to establish before refusing new connections. Defaults to 0 (no limit)`)

// Settings for authentication.
token = flag.String("token", "", "When set, the proxy uses this Bearer token for authorization.")
tokenFile = flag.String("credential_file", "", `If provided, this json file will be used to retrieve Service Account credentials.
Expand Down Expand Up @@ -461,7 +464,8 @@ func main() {
logging.Infof("Ready for new connections")

(&proxy.Client{
Port: port,
Port: port,
MaxConnections: *maxConnections,
Certs: certs.NewCertSourceOpts(client, certs.RemoteOpts{
APIBasePath: host,
IgnoreRegion: !*checkRegion,
Expand Down
23 changes: 23 additions & 0 deletions proxy/proxy/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"net"
"sync"
"sync/atomic"
"time"

"github.com/GoogleCloudPlatform/cloudsql-proxy/logging"
Expand Down Expand Up @@ -80,6 +81,13 @@ type Client struct {
// protected by cfgL.
cfgCache map[string]cacheEntry
cfgL sync.RWMutex

// MaxConnections is the maximum number of connections to establish
// before refusing new connections. 0 means no limit.
MaxConnections uint64

// ConnectionsCounter is used to enforce the optional maxConnections limit
ConnectionsCounter uint64
}

type cacheEntry struct {
Expand All @@ -103,6 +111,20 @@ func (c *Client) Run(connSrc <-chan Conn) {
}

func (c *Client) handleConn(conn Conn) {
// Track connections count only if a maximum connections limit is set to avoid useless overhead
if c.MaxConnections > 0 {
active := atomic.AddUint64(&c.ConnectionsCounter, 1)

// Deferred decrement of ConnectionsCounter upon connection closing
defer atomic.AddUint64(&c.ConnectionsCounter, ^uint64(0))

if active > c.MaxConnections {
logging.Errorf("too many open connections (max %d)", c.MaxConnections)
conn.Conn.Close()
return
}
}

server, err := c.Dial(conn.Instance)
if err != nil {
logging.Errorf("couldn't connect to %q: %v", conn.Instance, err)
Expand All @@ -118,6 +140,7 @@ func (c *Client) handleConn(conn Conn) {

c.Conns.Add(conn.Instance, conn.Conn)
copyThenClose(server, conn.Conn, conn.Instance, "local connection on "+conn.Conn.LocalAddr().String())

if err := c.Conns.Remove(conn.Instance, conn.Conn); err != nil {
logging.Errorf("%s", err)
}
Expand Down
61 changes: 61 additions & 0 deletions proxy/proxy/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"net"
"sync"
"sync/atomic"
"testing"
"time"
)
Expand Down Expand Up @@ -122,3 +123,63 @@ func TestConcurrentRefresh(t *testing.T) {
}
b.Unlock()
}

func TestMaximumConnectionsCount(t *testing.T) {
const maxConnections = 10
const numConnections = maxConnections + 1
var dials uint64 = 0

b := &fakeCerts{}
certSource := blockingCertSource{
map[string]*fakeCerts{}}
firstDialExited := make(chan struct{})
c := &Client{
Certs: &certSource,
Dialer: func(string, string) (net.Conn, error) {
atomic.AddUint64(&dials, 1)

// Wait until the first dial fails to ensure the max connections count is reached by a concurrent dialer
<-firstDialExited

return nil, errFakeDial
},
MaxConnections: maxConnections,
}

// Build certSource.values before creating goroutines to avoid concurrent map read and map write
instanceNames := make([]string, numConnections)
for i := 0; i < numConnections; i++ {
// Vary instance name to bypass config cache and avoid second call to Client.tryConnect() in Client.Dial()
instanceName := fmt.Sprintf("%s-%d", instance, i)
certSource.values[instanceName] = b
instanceNames[i] = instanceName
}

var wg sync.WaitGroup
var firstDialOnce sync.Once
for _, instanceName := range instanceNames {
wg.Add(1)
go func(instanceName string) {
defer wg.Done()

conn := Conn{
Instance: instanceName,
Conn: &dummyConn{},
}
c.handleConn(conn)

firstDialOnce.Do(func() { close(firstDialExited) })
}(instanceName)
}

wg.Wait()

switch {
case dials > maxConnections:
t.Errorf("client should have refused to dial new connection on %dth attempt when the maximum of %d connections was reached (%d dials)", numConnections, maxConnections, dials)
case dials == maxConnections:
t.Logf("client has correctly refused to dial new connection on %dth attempt when the maximum of %d connections was reached (%d dials)\n", numConnections, maxConnections, dials)
case dials < maxConnections:
t.Errorf("client should have dialed exactly the maximum of %d connections (%d connections, %d dials)", maxConnections, numConnections, dials)
}
}
4 changes: 4 additions & 0 deletions proxy/proxy/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ var c1, c2, c3 = &dummyConn{}, &dummyConn{}, &dummyConn{}

type dummyConn struct{ net.Conn }

func (c dummyConn) Close() error {
return nil
}

func TestConnSetAdd(t *testing.T) {
s := NewConnSet()

Expand Down
Loading

0 comments on commit b4646d0

Please sign in to comment.