Skip to content

Commit

Permalink
fix: Make the process exit if there as an error accepting a fuse conn…
Browse files Browse the repository at this point in the history
…ection.
  • Loading branch information
hessjcg committed Jul 8, 2024
1 parent a74fbf7 commit 6edb8d4
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 14 deletions.
81 changes: 71 additions & 10 deletions internal/proxy/fuse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,29 @@ func randTmpDir(t interface {
// newTestClient is a convenience function for testing that creates a
// proxy.Client and starts it. The returned cleanup function is also a
// convenience. Callers may choose to ignore it and manually close the client.
func newTestClient(t *testing.T, d cloudsql.Dialer, fuseDir, fuseTempDir string) (*proxy.Client, func()) {
func newTestClient(t *testing.T, d cloudsql.Dialer, fuseDir, fuseTempDir string) (*proxy.Client, chan error, func()) {
conf := &proxy.Config{FUSEDir: fuseDir, FUSETempDir: fuseTempDir}
c, err := proxy.NewClient(context.Background(), d, testLogger, conf)
if err != nil {
t.Fatalf("want error = nil, got = %v", err)
}

ready := make(chan struct{})
go c.Serve(context.Background(), func() { close(ready) })
servErrCh := make(chan error)
go func() {
servErr := c.Serve(context.Background(), func() { close(ready) })
select {
case servErrCh <- servErr:
default:
// exit background thread
}
}()
select {
case <-ready:
case <-time.Tick(5 * time.Second):
t.Fatal("failed to Serve")
}
return c, func() {
return c, servErrCh, func() {
if cErr := c.Close(); cErr != nil {
t.Logf("failed to close client: %v", cErr)
}
Expand All @@ -70,7 +78,7 @@ func TestFUSEREADME(t *testing.T) {
}
dir := randTmpDir(t)
d := &fakeDialer{}
_, cleanup := newTestClient(t, d, dir, randTmpDir(t))
_, _, cleanup := newTestClient(t, d, dir, randTmpDir(t))

fi, err := os.Stat(dir)
if err != nil {
Expand Down Expand Up @@ -161,7 +169,7 @@ func TestFUSEDialInstance(t *testing.T) {
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
d := &fakeDialer{}
_, cleanup := newTestClient(t, d, fuseDir, tc.fuseTempDir)
_, _, cleanup := newTestClient(t, d, fuseDir, tc.fuseTempDir)
defer cleanup()

conn := tryDialUnix(t, tc.socketPath)
Expand All @@ -185,13 +193,66 @@ func TestFUSEDialInstance(t *testing.T) {
})
}
}
func TestFUSEAcceptErrorReturnedFromServe(t *testing.T) {
if testing.Short() {
t.Skip("skipping fuse tests in short mode.")
}

fuseDir := randTmpDir(t)
fuseTempDir := randTmpDir(t)
socketPath := filepath.Join(fuseDir, "proj:region:mysql")

// Create a new client
d := &fakeDialer{}
c, servErrCh, cleanup := newTestClient(t, d, fuseDir, fuseTempDir)
defer cleanup()

// Attempt a successful connection to the client
conn := tryDialUnix(t, socketPath)
defer conn.Close()

// Ensure that the client actually fully connected.
// This solves a race condition in the test that is only present on
// the Ubuntu-Latest platform.
var got []string
for i := 0; i < 10; i++ {
got = d.dialedInstances()
if len(got) == 1 {
break
}
time.Sleep(100 * time.Millisecond)
}
if len(got) != 1 {
t.Fatalf("dialed instances len: want = 1, got = %v", got)
}

// Explicitly close the dialer. This will close all the unix sockets, forcing
// the unix socket accept goroutine to exit with an error
c.Close()

// Check that Client.Serve() returned a non-nil error
for i := 0; i < 10; i++ {
select {
case servErr := <-servErrCh:
if servErr == nil {
t.Fatal("got nil, want non-nil error returned by Client.Serve()")
}
return
default:
time.Sleep(100 * time.Millisecond)
continue
}
}
t.Fatal("No error thrown by Client.Serve()")

}

func TestFUSEReadDir(t *testing.T) {
if testing.Short() {
t.Skip("skipping fuse tests in short mode.")
}
fuseDir := randTmpDir(t)
_, cleanup := newTestClient(t, &fakeDialer{}, fuseDir, randTmpDir(t))
_, _, cleanup := newTestClient(t, &fakeDialer{}, fuseDir, randTmpDir(t))
defer cleanup()

// Initiate a connection so the FUSE server will list it in the dir entries.
Expand Down Expand Up @@ -221,7 +282,7 @@ func TestFUSEErrors(t *testing.T) {
}
ctx := context.Background()
d := &fakeDialer{}
c, _ := newTestClient(t, d, randTmpDir(t), randTmpDir(t))
c, _, _ := newTestClient(t, d, randTmpDir(t), randTmpDir(t))

// Simulate FUSE file access by invoking Lookup directly to control
// how the socket cache is populated.
Expand Down Expand Up @@ -261,7 +322,7 @@ func TestFUSEWithBadInstanceName(t *testing.T) {
}
fuseDir := randTmpDir(t)
d := &fakeDialer{}
_, cleanup := newTestClient(t, d, fuseDir, randTmpDir(t))
_, _, cleanup := newTestClient(t, d, fuseDir, randTmpDir(t))
defer cleanup()

_, dialErr := net.Dial("unix", filepath.Join(fuseDir, "notvalid"))
Expand All @@ -280,7 +341,7 @@ func TestFUSECheckConnections(t *testing.T) {
}
fuseDir := randTmpDir(t)
d := &fakeDialer{}
c, cleanup := newTestClient(t, d, fuseDir, randTmpDir(t))
c, _, cleanup := newTestClient(t, d, fuseDir, randTmpDir(t))
defer cleanup()

// first establish a connection to "register" it with the proxy
Expand Down Expand Up @@ -315,7 +376,7 @@ func TestFUSEClose(t *testing.T) {
}
fuseDir := randTmpDir(t)
d := &fakeDialer{}
c, _ := newTestClient(t, d, fuseDir, randTmpDir(t))
c, _, _ := newTestClient(t, d, fuseDir, randTmpDir(t))

// first establish a connection to "register" it with the proxy
conn := tryDialUnix(t, filepath.Join(fuseDir, "proj:reg:mysql"))
Expand Down
41 changes: 37 additions & 4 deletions internal/proxy/proxy_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package proxy

import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
Expand Down Expand Up @@ -68,6 +69,7 @@ type fuseMount struct {
fuseServerMu *sync.Mutex
fuseServer *fuse.Server
fuseWg *sync.WaitGroup
fuseExitCh chan error

// Inode adds support for FUSE operations.
fs.Inode
Expand Down Expand Up @@ -131,10 +133,20 @@ func (c *Client) Lookup(ctx context.Context, instance string, _ *fuse.EntryOut)
defer c.fuseWg.Done()
sErr := c.serveSocketMount(ctx, s)
if sErr != nil {
c.fuseMu.Lock()
c.logger.Debugf("could not serve socket for instance %q: %v", instance, sErr)
c.fuseMu.Lock()
defer c.fuseMu.Unlock()
delete(c.fuseSockets, instance)
c.fuseMu.Unlock()
select {
// Best effort attempt to send error.
// If this send fails, it means the reading goroutine has
// already pulled a value out of the channel and is no longer
// reading any more values. In other words, we report only the
// first error.
case c.fuseExitCh <- sErr:
default:
return
}
}
}()

Expand Down Expand Up @@ -165,10 +177,31 @@ func (c *Client) serveFuse(ctx context.Context, notify func()) error {
}
c.fuseServerMu.Lock()
c.fuseServer = srv
c.fuseExitCh = make(chan error)

// When the context is canceled, put the context cancel error into the
// exit chanel
go func() {
<-ctx.Done()
cancelErr := ctx.Err()
if cancelErr == nil {
cancelErr = errors.New("Canceled")
}
select {
case c.fuseExitCh <- cancelErr:
default:
}
}()

c.fuseServerMu.Unlock()
notify()
<-ctx.Done()
return ctx.Err()
select {
case err = <-c.fuseExitCh:
return err
default:
}
return nil

}

func (c *Client) fuseMounts() []*socketMount {
Expand Down

0 comments on commit 6edb8d4

Please sign in to comment.