Skip to content

Commit

Permalink
fix: pass dial options to FUSE mounts
Browse files Browse the repository at this point in the history
  • Loading branch information
enocom committed Apr 11, 2023
1 parent 83c8a64 commit fc11523
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 21 deletions.
28 changes: 11 additions & 17 deletions internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,19 +443,14 @@ type Client struct {
// all Cloud SQL instances.
connCount uint64

// maxConns is the maximum number of allowed connections tracked by
// connCount. If not set, there is no limit.
maxConns uint64
// conf is the configuration used to initialize the Client.
conf *Config

dialer cloudsql.Dialer

// mnts is a list of all mounted sockets for this client
mnts []*socketMount

// waitOnClose is the maximum duration to wait for open connections to close
// when shutting down.
waitOnClose time.Duration

logger cloudsql.Logger

fuseMount
Expand All @@ -478,10 +473,9 @@ func NewClient(ctx context.Context, d cloudsql.Dialer, l cloudsql.Logger, conf *
}

c := &Client{
logger: l,
dialer: d,
maxConns: conf.MaxConnections,
waitOnClose: conf.WaitOnClose,
logger: l,
dialer: d,
conf: conf,
}

if conf.FUSEDir != "" {
Expand Down Expand Up @@ -564,7 +558,7 @@ func (c *Client) CheckConnections(ctx context.Context) (int, error) {
// ConnCount returns the number of open connections and the maximum allowed
// connections. Returns 0 when the maximum allowed connections have not been set.
func (c *Client) ConnCount() (uint64, uint64) {
return atomic.LoadUint64(&c.connCount), c.maxConns
return atomic.LoadUint64(&c.connCount), c.conf.MaxConnections
}

// Serve starts proxying connections for all configured instances using the
Expand Down Expand Up @@ -643,13 +637,13 @@ func (c *Client) Close() error {
if cErr != nil {
mErr = append(mErr, cErr)
}
if c.waitOnClose == 0 {
if c.conf.WaitOnClose == 0 {
if len(mErr) > 0 {
return mErr
}
return nil
}
timeout := time.After(c.waitOnClose)
timeout := time.After(c.conf.WaitOnClose)
t := time.NewTicker(100 * time.Millisecond)
defer t.Stop()
for {
Expand All @@ -664,7 +658,7 @@ func (c *Client) Close() error {
}
open := atomic.LoadUint64(&c.connCount)
if open > 0 {
mErr = append(mErr, fmt.Errorf("%d connection(s) still open after waiting %v", open, c.waitOnClose))
mErr = append(mErr, fmt.Errorf("%d connection(s) still open after waiting %v", open, c.conf.WaitOnClose))
}
if len(mErr) > 0 {
return mErr
Expand Down Expand Up @@ -697,8 +691,8 @@ func (c *Client) serveSocketMount(_ context.Context, s *socketMount) error {
count := atomic.AddUint64(&c.connCount, 1)
defer atomic.AddUint64(&c.connCount, ^uint64(0))

if c.maxConns > 0 && count > c.maxConns {
c.logger.Infof("max connections (%v) exceeded, refusing new connection", c.maxConns)
if c.conf.MaxConnections > 0 && count > c.conf.MaxConnections {
c.logger.Infof("max connections (%v) exceeded, refusing new connection", c.conf.MaxConnections)
_ = cConn.Close()
return
}
Expand Down
7 changes: 6 additions & 1 deletion internal/proxy/proxy_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func (c *Client) Lookup(ctx context.Context, instance string, _ *fuse.EntryOut)
}

s, err := c.newSocketMount(
ctx, &Config{UnixSocket: c.fuseTempDir},
ctx, withUnixSocket(*c.conf, c.fuseTempDir),
nil, InstanceConnConfig{Name: instance},
)
if err != nil {
Expand Down Expand Up @@ -147,6 +147,11 @@ func (c *Client) Lookup(ctx context.Context, instance string, _ *fuse.EntryOut)
), fs.OK
}

func withUnixSocket(c Config, tmpDir string) *Config {
c.UnixSocket = tmpDir
return &c
}

func (c *Client) serveFuse(ctx context.Context, notify func()) error {
srv, err := fs.Mount(c.fuseDir, c, &fs.Options{
MountOptions: fuse.MountOptions{AllowOther: true},
Expand Down
12 changes: 9 additions & 3 deletions tests/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,7 @@ func keyfile(t *testing.T) string {
}
return string(creds)
}

// proxyConnTest is a test helper to verify the proxy works with a basic connectivity test.
func proxyConnTest(t *testing.T, args []string, driver, dsn string) {
func proxyConnTestWithReady(t *testing.T, args []string, driver, dsn string, ready func() error) {
ctx, cancel := context.WithTimeout(context.Background(), connTestTimeout)
defer cancel()
// Start the proxy
Expand All @@ -81,6 +79,9 @@ func proxyConnTest(t *testing.T, args []string, driver, dsn string) {
if err != nil {
t.Fatalf("unable to verify proxy was serving: %s \n %s", err, output)
}
if err := ready(); err != nil {
t.Fatalf("proxy was not ready: %v", err)
}

// Connect to the instance
db, err := sql.Open(driver, dsn)
Expand All @@ -94,6 +95,11 @@ func proxyConnTest(t *testing.T, args []string, driver, dsn string) {
}
}

// proxyConnTest is a test helper to verify the proxy works with a basic connectivity test.
func proxyConnTest(t *testing.T, args []string, driver, dsn string) {
proxyConnTestWithReady(t, args, driver, dsn, func() error { return nil })
}

// testHealthCheck verifies that when a proxy client serves the given instance,
// the readiness endpoint serves http.StatusOK.
func testHealthCheck(t *testing.T, connName string) {
Expand Down
82 changes: 82 additions & 0 deletions tests/fuse_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build !windows && !darwin

package tests

import (
"fmt"
"os"
"testing"
"time"

"github.com/GoogleCloudPlatform/cloud-sql-proxy/v2/internal/proxy"
)

func TestPostgresFUSEConnect(t *testing.T) {
if testing.Short() {
t.Skip("skipping Postgres integration tests")
}
tmpDir, cleanup := createTempDir(t)
defer cleanup()

host := proxy.UnixAddress(tmpDir, *postgresConnName)
dsn := fmt.Sprintf(
"host=%s user=%s password=%s database=%s sslmode=disable",
host, *postgresUser, *postgresPass, *postgresDB,
)
testFUSE(t, tmpDir, host, dsn)
}

func testFUSE(t *testing.T, tmpDir, host string, dsn string) {
tmpDir2, cleanup2 := createTempDir(t)
defer cleanup2()

waitForFUSE := func() error {
var err error
for i := 0; i < 10; i++ {
_, err = os.Stat(host)
if err == nil {
return nil
}
time.Sleep(500 * time.Millisecond)
}
return fmt.Errorf("failed to find FUSE mounted Unix socket: %v", err)
}

tcs := []struct {
desc string
dbUser string
args []string
}{
{
desc: "using default fuse",
args: []string{fmt.Sprintf("--fuse=%s", tmpDir), fmt.Sprintf("--fuse-tmp-dir=%s", tmpDir2)},
},
{
desc: "using fuse with auto-iam-authn",
args: []string{fmt.Sprintf("--fuse=%s", tmpDir), "--auto-iam-authn"},
},
}

for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
proxyConnTestWithReady(t, tc.args, "pgx", dsn, waitForFUSE)
// given the kernel some time to unmount the fuse
time.Sleep(100 * time.Millisecond)
})
}

}

0 comments on commit fc11523

Please sign in to comment.