From d71b48e373505c7bb2a4197cd1de22d9d67b2a51 Mon Sep 17 00:00:00 2001 From: Nancy Hong Date: Thu, 15 Aug 2024 09:03:53 -0700 Subject: [PATCH] feat: exit when FUSE errors on accept --- internal/proxy/fuse_test.go | 83 ++++++++++++++++++++++++++++++----- internal/proxy/proxy_other.go | 22 ++++++++-- 2 files changed, 92 insertions(+), 13 deletions(-) diff --git a/internal/proxy/fuse_test.go b/internal/proxy/fuse_test.go index 0842dd85..51386961 100644 --- a/internal/proxy/fuse_test.go +++ b/internal/proxy/fuse_test.go @@ -43,7 +43,7 @@ func randTmpDir(t interface { // newTestClient is a convenience function for testing that creates a // proxy.Client and starts it. The returned cleanup function is also a // convenience. Callers may choose to ignore it and manually close the client. -func newTestClient(t *testing.T, d alloydb.Dialer, fuseDir, fuseTempDir string) (*proxy.Client, func()) { +func newTestClient(t *testing.T, d alloydb.Dialer, fuseDir, fuseTempDir string) (*proxy.Client, chan error, func()) { conf := &proxy.Config{FUSEDir: fuseDir, FUSETempDir: fuseTempDir} c, err := proxy.NewClient(context.Background(), d, testLogger, conf) if err != nil { @@ -51,16 +51,26 @@ func newTestClient(t *testing.T, d alloydb.Dialer, fuseDir, fuseTempDir string) } ready := make(chan struct{}) - go c.Serve(context.Background(), func() { close(ready) }) + servErrCh := make(chan error) + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + servErr := c.Serve(ctx, func() { close(ready) }) + select { + case servErrCh <- servErr: + case <-ctx.Done(): + } + }() select { case <-ready: case <-time.Tick(5 * time.Second): t.Fatal("failed to Serve") } - return c, func() { + return c, servErrCh, func() { if cErr := c.Close(); cErr != nil { t.Logf("failed to close client: %v", cErr) } + cancel() } } @@ -70,7 +80,7 @@ func TestFUSEREADME(t *testing.T) { } dir := randTmpDir(t) d := &fakeDialer{} - _, cleanup := newTestClient(t, d, dir, randTmpDir(t)) + _, _, cleanup := newTestClient(t, d, dir, randTmpDir(t)) fi, err := os.Stat(dir) if err != nil { @@ -159,7 +169,7 @@ func TestFUSEDialInstance(t *testing.T) { for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { d := &fakeDialer{} - _, cleanup := newTestClient(t, d, fuseDir, tc.fuseTempDir) + _, _, cleanup := newTestClient(t, d, fuseDir, tc.fuseTempDir) defer cleanup() conn := tryDialUnix(t, tc.socketPath) @@ -184,12 +194,65 @@ func TestFUSEDialInstance(t *testing.T) { } } +func TestFUSEAcceptErrorReturnedFromServe(t *testing.T) { + if testing.Short() { + t.Skip("skipping fuse tests in short mode.") + } + + fuseDir := randTmpDir(t) + fuseTempDir := randTmpDir(t) + socketPath := postgresSocketPath(fuseDir, "proj.region.cluster.instance") + + // Create a new client + d := &fakeDialer{} + c, servErrCh, cleanup := newTestClient(t, d, fuseDir, fuseTempDir) + defer cleanup() + + // Attempt a successful connection to the client + conn := tryDialUnix(t, socketPath) + defer conn.Close() + + // Ensure that the client actually fully connected. + // This solves a race condition in the test that is only present on + // the Ubuntu-Latest platform. + var got []string + for i := 0; i < 10; i++ { + got = d.dialedInstances() + if len(got) == 1 { + break + } + time.Sleep(100 * time.Millisecond) + } + if len(got) != 1 { + t.Fatalf("dialed instances len: want = 1, got = %v", got) + } + + // Explicitly close the dialer. This will close all the unix sockets, forcing + // the unix socket accept goroutine to exit with an error + c.Close() + + // Check that Client.Serve() returned a non-nil error + for i := 0; i < 10; i++ { + select { + case servErr := <-servErrCh: + if servErr == nil { + t.Fatal("got nil, want non-nil error returned by Client.Serve()") + } + return + default: + time.Sleep(100 * time.Millisecond) + continue + } + } + t.Fatal("No error thrown by Client.Serve()") +} + func TestFUSEReadDir(t *testing.T) { if testing.Short() { t.Skip("skipping fuse tests in short mode.") } fuseDir := randTmpDir(t) - _, cleanup := newTestClient(t, &fakeDialer{}, fuseDir, randTmpDir(t)) + _, _, cleanup := newTestClient(t, &fakeDialer{}, fuseDir, randTmpDir(t)) defer cleanup() // Initiate a connection so the FUSE server will list it in the dir entries. @@ -219,7 +282,7 @@ func TestFUSEWithBadInstanceName(t *testing.T) { } fuseDir := randTmpDir(t) d := &fakeDialer{} - _, cleanup := newTestClient(t, d, fuseDir, randTmpDir(t)) + _, _, cleanup := newTestClient(t, d, fuseDir, randTmpDir(t)) defer cleanup() _, dialErr := net.Dial("unix", filepath.Join(fuseDir, "notvalid")) @@ -238,7 +301,7 @@ func TestFUSECheckConnections(t *testing.T) { } fuseDir := randTmpDir(t) d := &fakeDialer{} - c, cleanup := newTestClient(t, d, fuseDir, randTmpDir(t)) + c, _, cleanup := newTestClient(t, d, fuseDir, randTmpDir(t)) defer cleanup() // first establish a connection to "register" it with the proxy @@ -273,7 +336,7 @@ func TestFUSEClose(t *testing.T) { } fuseDir := randTmpDir(t) d := &fakeDialer{} - c, _ := newTestClient(t, d, fuseDir, randTmpDir(t)) + c, _, _ := newTestClient(t, d, fuseDir, randTmpDir(t)) // first establish a connection to "register" it with the proxy conn := tryDialUnix(t, postgresSocketPath(fuseDir, "proj.region.cluster.instance")) @@ -306,7 +369,7 @@ func TestLookupIgnoresContext(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() d := &fakeDialer{} - c, _ := newTestClient(t, d, randTmpDir(t), randTmpDir(t)) + c, _, _ := newTestClient(t, d, randTmpDir(t), randTmpDir(t)) // invoke Lookup with cancelled context, should ignore context and succeed _, err := c.Lookup(ctx, "proj.region.cluster.instance", nil) diff --git a/internal/proxy/proxy_other.go b/internal/proxy/proxy_other.go index 1851efcc..5a1c7176 100644 --- a/internal/proxy/proxy_other.go +++ b/internal/proxy/proxy_other.go @@ -68,6 +68,7 @@ type fuseMount struct { fuseServerMu *sync.Mutex fuseServer *fuse.Server fuseWg *sync.WaitGroup + fuseExitCh chan error // Inode adds support for FUSE operations. fs.Inode @@ -131,8 +132,18 @@ func (c *Client) Lookup(_ context.Context, instance string, _ *fuse.EntryOut) (* sErr := c.serveSocketMount(ctx, s) if sErr != nil { c.fuseMu.Lock() + defer c.fuseMu.Unlock() delete(c.fuseSockets, instance) - c.fuseMu.Unlock() + select { + // Best effort attempt to send error. + // If this send fails, it means the reading goroutine has + // already pulled a value out of the channel and is no longer + // reading any more values. In other words, we report only the + // first error. + case c.fuseExitCh <- sErr: + default: + return + } } }() @@ -163,10 +174,15 @@ func (c *Client) serveFuse(ctx context.Context, notify func()) error { } c.fuseServerMu.Lock() c.fuseServer = srv + c.fuseExitCh = make(chan error) c.fuseServerMu.Unlock() notify() - <-ctx.Done() - return ctx.Err() + select { + case err = <-c.fuseExitCh: + return err + case <-ctx.Done(): + return ctx.Err() + } } func (c *Client) fuseMounts() []*socketMount {