diff --git a/cmd/root.go b/cmd/root.go index 8243a0eb..74c5b6a4 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -24,6 +24,7 @@ import ( "net/url" "os" "os/signal" + "path/filepath" "strconv" "strings" "syscall" @@ -174,6 +175,11 @@ down when the number of open connections reaches 0 or when the maximum time has passed. Defaults to 0s.`) cmd.PersistentFlags().StringVar(&c.conf.APIEndpointURL, "alloydbadmin-api-endpoint", "https://alloydb.googleapis.com/v1beta", "When set, the proxy uses this host as the base API path.") + cmd.PersistentFlags().StringVar(&c.conf.FUSEDir, "fuse", "", + "Mount a directory at the path using FUSE to access Cloud SQL instances.") + cmd.PersistentFlags().StringVar(&c.conf.FUSETempDir, "fuse-tmp-dir", + filepath.Join(os.TempDir(), "csql-tmp"), + "Temp dir for Unix sockets created with FUSE") cmd.PersistentFlags().StringVar(&c.telemetryProject, "telemetry-project", "", "Enable Cloud Monitoring and Cloud Trace integration with the provided project ID.") @@ -208,11 +214,24 @@ only. Uses the port specified by the http-port flag.`) } func parseConfig(cmd *Command, conf *proxy.Config, args []string) error { - // If no instance connection names were provided, error. - if len(args) == 0 { + // If no instance connection names were provided AND FUSE isn't enabled, + // error. + if len(args) == 0 && conf.FUSEDir == "" { return newBadCommandError("missing instance uri (e.g., projects/$PROJECTS/locations/$LOCTION/clusters/$CLUSTER/instances/$INSTANCES)") } + if conf.FUSEDir != "" { + if err := proxy.SupportsFUSE(); err != nil { + return newBadCommandError( + fmt.Sprintf("--fuse is not supported: %v", err), + ) + } + } + + if len(args) == 0 && conf.FUSEDir == "" && conf.FUSETempDir != "" { + return newBadCommandError("cannot specify --fuse-tmp-dir without --fuse") + } + userHasSet := func(f string) bool { return cmd.PersistentFlags().Lookup(f).Changed } diff --git a/cmd/root_linux_test.go b/cmd/root_linux_test.go new file mode 100644 index 00000000..0a2e8e49 --- /dev/null +++ b/cmd/root_linux_test.go @@ -0,0 +1,73 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cmd + +import ( + "os" + "path/filepath" + "testing" + + "github.com/spf13/cobra" +) + +func TestNewCommandArgumentsOnLinux(t *testing.T) { + defaultTmp := filepath.Join(os.TempDir(), "csql-tmp") + tcs := []struct { + desc string + args []string + wantDir string + wantTempDir string + }{ + { + desc: "using the fuse flag", + args: []string{"--fuse", "/cloudsql"}, + wantDir: "/cloudsql", + wantTempDir: defaultTmp, + }, + { + desc: "using the fuse temporary directory flag", + args: []string{"--fuse", "/cloudsql", "--fuse-tmp-dir", "/mycooldir"}, + wantDir: "/cloudsql", + wantTempDir: "/mycooldir", + }, + } + + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + c := NewCommand() + // Keep the test output quiet + c.SilenceUsage = true + c.SilenceErrors = true + // Disable execute behavior + c.RunE = func(*cobra.Command, []string) error { + return nil + } + c.SetArgs(tc.args) + + err := c.Execute() + if err != nil { + t.Fatalf("want error = nil, got = %v", err) + } + + if got, want := c.conf.FUSEDir, tc.wantDir; got != want { + t.Fatalf("FUSEDir: want = %v, got = %v", want, got) + } + + if got, want := c.conf.FUSETempDir, tc.wantTempDir; got != want { + t.Fatalf("FUSEDir: want = %v, got = %v", want, got) + } + }) + } +} diff --git a/cmd/root_test.go b/cmd/root_test.go index 3b0c35dc..310b3e37 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -20,6 +20,8 @@ import ( "fmt" "net" "net/http" + "os" + "path/filepath" "sync" "testing" "time" @@ -41,11 +43,16 @@ func TestNewCommandArguments(t *testing.T) { if c.Port == 0 { c.Port = 5432 } - if c.Instances == nil { - c.Instances = []proxy.InstanceConnConfig{{}} + if c.FUSEDir == "" { + if c.Instances == nil { + c.Instances = []proxy.InstanceConnConfig{{}} + } + if i := &c.Instances[0]; i.Name == "" { + i.Name = "projects/proj/locations/region/clusters/clust/instances/inst" + } } - if i := &c.Instances[0]; i.Name == "" { - i.Name = "projects/proj/locations/region/clusters/clust/instances/inst" + if c.FUSETempDir == "" { + c.FUSETempDir = filepath.Join(os.TempDir(), "csql-tmp") } if c.APIEndpointURL == "" { c.APIEndpointURL = "https://alloydb.googleapis.com/v1beta" @@ -354,6 +361,10 @@ func TestNewCommandWithErrors(t *testing.T) { desc: "using an invalid url for host flag", args: []string{"--host", "https://invalid:url[/]", "proj:region:inst"}, }, + { + desc: "using fuse-tmp-dir without fuse", + args: []string{"--fuse-tmp-dir", "/mydir"}, + }, } for _, tc := range tcs { diff --git a/cmd/root_windows_test.go b/cmd/root_windows_test.go new file mode 100644 index 00000000..78b17674 --- /dev/null +++ b/cmd/root_windows_test.go @@ -0,0 +1,36 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cmd + +import ( + "testing" + + "github.com/spf13/cobra" +) + +func TestWindowsDoesNotSupportFUSE(t *testing.T) { + c := NewCommand() + // Keep the test output quiet + c.SilenceUsage = true + c.SilenceErrors = true + // Disable execute behavior + c.RunE = func(*cobra.Command, []string) error { return nil } + c.SetArgs([]string{"--fuse"}) + + err := c.Execute() + if err == nil { + t.Fatal("want error != nil, got = nil") + } +} diff --git a/go.mod b/go.mod index 6be5b863..8c562448 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,8 @@ require ( contrib.go.opencensus.io/exporter/prometheus v0.4.2 contrib.go.opencensus.io/exporter/stackdriver v0.13.13 github.com/google/go-cmp v0.5.8 + github.com/hanwen/go-fuse v1.0.0 + github.com/hanwen/go-fuse/v2 v2.1.0 github.com/spf13/cobra v1.5.0 go.opencensus.io v0.23.0 go.uber.org/zap v1.23.0 diff --git a/go.sum b/go.sum index 25c94186..627f8b9f 100644 --- a/go.sum +++ b/go.sum @@ -633,6 +633,10 @@ github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t github.com/grpc-ecosystem/grpc-gateway v1.9.5/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0/go.mod h1:hgWBS7lorOAVIJEQMi4ZsPv9hVvWI6+ch50m39Pf2Ks= +github.com/hanwen/go-fuse v1.0.0 h1:GxS9Zrn6c35/BnfiVsZVWmsG803xwE7eVRDvcf/BEVc= +github.com/hanwen/go-fuse v1.0.0/go.mod h1:unqXarDXqzAk0rt98O2tVndEPIpUgLD9+rwFisZH3Ok= +github.com/hanwen/go-fuse/v2 v2.1.0 h1:+32ffteETaLYClUj0a3aHjZ1hOPxxaNEHiZiujuDaek= +github.com/hanwen/go-fuse/v2 v2.1.0/go.mod h1:oRyA5eK+pvJyv5otpO/DgccS8y/RvYMaO00GgRLGryc= github.com/hashicorp/consul/api v1.1.0/go.mod h1:VmuI/Lkw1nC05EYQWNKwWGbkg+FbDBtguAZLlVdkD9Q= github.com/hashicorp/consul/api v1.12.0/go.mod h1:6pVBMo0ebnYdt2S3H87XhekM/HHrUoTD2XXb/VrZVy0= github.com/hashicorp/consul/sdk v0.1.1/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8= @@ -784,6 +788,8 @@ github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348/go.mod h1:B69LEHPfb2qLo0BaaOLcbitczOKLWTsrBG9LczfCD4k= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= diff --git a/internal/proxy/fuse.go b/internal/proxy/fuse.go new file mode 100644 index 00000000..36e41132 --- /dev/null +++ b/internal/proxy/fuse.go @@ -0,0 +1,85 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "context" + "syscall" + + "github.com/hanwen/go-fuse/v2/fs" + "github.com/hanwen/go-fuse/v2/fuse" + "github.com/hanwen/go-fuse/v2/fuse/nodefs" +) + +// symlink implements a symbolic link, returning the underlying path when +// Readlink is called. +type symlink struct { + fs.Inode + path string +} + +// Readlink implements fs.NodeReadlinker and returns the symlink's path. +func (s *symlink) Readlink(ctx context.Context) ([]byte, syscall.Errno) { + return []byte(s.path), fs.OK +} + +// readme represents a static read-only text file. +type readme struct { + fs.Inode +} + +const readmeText = ` +When applications attempt to open files in this directory, a remote connection +to the AlloyDB instance of the same name will be established. + +For example, when you run one of the following commands, the proxy will initiate +a connection to the corresponding Cloud SQL instance, given you have the correct +IAM permissions. + + psql "host=/somedir/project.region.cluster.instance dbname=mydb user=myuser" + +The proxy will create a directory with the instance short name, and create a +socket inside that directory with the special Postgres name: .s.PGSQL.5432. + +Listing the contents of this directory will show all instances with active +connections. +` + +// Getattr implements fs.NodeGetattrer and indicates that this file is a regular +// file. +func (*readme) Getattr(ctx context.Context, f fs.FileHandle, out *fuse.AttrOut) syscall.Errno { + *out = fuse.AttrOut{Attr: fuse.Attr{ + Mode: 0444 | syscall.S_IFREG, + Size: uint64(len(readmeText)), + }} + return fs.OK +} + +// Read implements fs.NodeReader and supports incremental reads. +func (*readme) Read(ctx context.Context, f fs.FileHandle, dest []byte, off int64) (fuse.ReadResult, syscall.Errno) { + end := int(off) + len(dest) + if end > len(readmeText) { + end = len(readmeText) + } + return fuse.ReadResultData([]byte(readmeText[off:end])), fs.OK +} + +// Open implements fs.NodeOpener and supports opening the README as a read-only +// file. +func (*readme) Open(ctx context.Context, mode uint32) (fs.FileHandle, uint32, syscall.Errno) { + df := nodefs.NewDataFile([]byte(readmeText)) + rf := nodefs.NewReadOnlyFile(df) + return rf, 0, fs.OK +} diff --git a/internal/proxy/fuse_darwin.go b/internal/proxy/fuse_darwin.go new file mode 100644 index 00000000..ceb5db26 --- /dev/null +++ b/internal/proxy/fuse_darwin.go @@ -0,0 +1,41 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "errors" + "os" +) + +const ( + macfusePath = "/Library/Filesystems/macfuse.fs/Contents/Resources/mount_macfuse" + osxfusePath = "/Library/Filesystems/osxfuse.fs/Contents/Resources/mount_osxfuse" +) + +// SupportsFUSE checks if macfuse or osxfuse are installed on the host by +// looking for both in their known installation location. +func SupportsFUSE() error { + // This code follows the same strategy as hanwen/go-fuse. + // See https://github.com/hanwen/go-fuse/blob/0f728ba15b38579efefc3dc47821882ca18ffea7/fuse/mount_darwin.go#L121-L124. + + // check for macfuse first (newer version of osxfuse) + if _, err := os.Stat(macfusePath); err != nil { + // if that fails, check for osxfuse next + if _, err := os.Stat(osxfusePath); err != nil { + return errors.New("failed to find osxfuse or macfuse: verify FUSE installation and try again (see https://osxfuse.github.io).") + } + } + return nil +} diff --git a/internal/proxy/fuse_linux.go b/internal/proxy/fuse_linux.go new file mode 100644 index 00000000..264e2acf --- /dev/null +++ b/internal/proxy/fuse_linux.go @@ -0,0 +1,33 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "errors" + "os/exec" +) + +// SupportsFUSE checks if the fusermount binary is present in the PATH or a well +// known location. +func SupportsFUSE() error { + // This code follows the same strategy found in hanwen/go-fuse. + // See https://github.com/hanwen/go-fuse/blob/0f728ba15b38579efefc3dc47821882ca18ffea7/fuse/mount_linux.go#L184-L198. + if _, err := exec.LookPath("fusermount"); err != nil { + if _, err := exec.LookPath("/bin/fusermount"); err != nil { + return errors.New("fusermount binary not found in PATH or /bin") + } + } + return nil +} diff --git a/internal/proxy/fuse_linux_test.go b/internal/proxy/fuse_linux_test.go new file mode 100644 index 00000000..6e3dd494 --- /dev/null +++ b/internal/proxy/fuse_linux_test.go @@ -0,0 +1,43 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy_test + +import ( + "os" + "testing" + + "github.com/GoogleCloudPlatform/alloydb-auth-proxy/internal/proxy" +) + +func TestFUSESupport(t *testing.T) { + if testing.Short() { + t.Skip("skipping fuse tests in short mode.") + } + + removePath := func() func() { + original := os.Getenv("PATH") + os.Unsetenv("PATH") + return func() { os.Setenv("PATH", original) } + } + if err := proxy.SupportsFUSE(); err != nil { + t.Fatalf("expected FUSE to be support (PATH set): %v", err) + } + cleanup := removePath() + defer cleanup() + + if err := proxy.SupportsFUSE(); err != nil { + t.Fatalf("expected FUSE to be supported (PATH unset): %v", err) + } +} diff --git a/internal/proxy/fuse_test.go b/internal/proxy/fuse_test.go new file mode 100644 index 00000000..840f2ca3 --- /dev/null +++ b/internal/proxy/fuse_test.go @@ -0,0 +1,272 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !windows && !darwin +// +build !windows,!darwin + +package proxy_test + +import ( + "context" + "io/ioutil" + "net" + "os" + "path/filepath" + "testing" + "time" + + "github.com/GoogleCloudPlatform/alloydb-auth-proxy/alloydb" + "github.com/GoogleCloudPlatform/alloydb-auth-proxy/internal/proxy" +) + +func randTmpDir(t interface { + Fatalf(format string, args ...interface{}) +}) string { + name, err := ioutil.TempDir("", "*") + if err != nil { + t.Fatalf("failed to create tmp dir: %v", err) + } + return name +} + +// newTestClient is a convenience function for testing that creates a +// proxy.Client and starts it. The returned cleanup function is also a +// convenience. Callers may choose to ignore it and manually close the client. +func newTestClient(t *testing.T, d alloydb.Dialer, fuseDir, fuseTempDir string) (*proxy.Client, func()) { + conf := &proxy.Config{FUSEDir: fuseDir, FUSETempDir: fuseTempDir} + c, err := proxy.NewClient(context.Background(), d, testLogger, conf) + if err != nil { + t.Fatalf("want error = nil, got = %v", err) + } + + ready := make(chan struct{}) + go c.Serve(context.Background(), func() { close(ready) }) + select { + case <-ready: + case <-time.Tick(5 * time.Second): + t.Fatal("failed to Serve") + } + return c, func() { + if cErr := c.Close(); cErr != nil { + t.Logf("failed to close client: %v", cErr) + } + } +} + +func TestFUSEREADME(t *testing.T) { + if testing.Short() { + t.Skip("skipping fuse tests in short mode.") + } + dir := randTmpDir(t) + d := &fakeDialer{} + _, cleanup := newTestClient(t, d, dir, randTmpDir(t)) + + fi, err := os.Stat(dir) + if err != nil { + t.Fatalf("os.Stat: %v", err) + } + if !fi.IsDir() { + t.Fatalf("fuse mount mode: want = dir, got = %v", fi.Mode()) + } + + entries, err := os.ReadDir(dir) + if err != nil { + t.Fatalf("os.ReadDir: %v", err) + } + if len(entries) != 1 { + t.Fatalf("dir entries: want = 1, got = %v", len(entries)) + } + e := entries[0] + if want, got := "README", e.Name(); want != got { + t.Fatalf("want = %v, got = %v", want, got) + } + + data, err := ioutil.ReadFile(filepath.Join(dir, "README")) + if err != nil { + t.Fatal(err) + } + if len(data) == 0 { + t.Fatalf("expected README data, got no data (dir = %v)", dir) + } + + cleanup() // close the client + + // verify that the FUSE server is no longer mounted + _, err = ioutil.ReadFile(filepath.Join(dir, "README")) + if err == nil { + t.Fatal("expected ioutil.Readfile to fail, but it succeeded") + } +} + +func tryDialUnix(t *testing.T, addr string) net.Conn { + var ( + conn net.Conn + dialErr error + ) + for i := 0; i < 10; i++ { + conn, dialErr = net.Dial("unix", addr) + if conn != nil { + break + } + time.Sleep(100 * time.Millisecond) + } + if dialErr != nil { + t.Fatalf("net.Dial(): %v", dialErr) + } + return conn +} + +func postgresSocketPath(dir, inst string) string { + return filepath.Join(dir, inst, ".s.PGSQL.5432") +} + +func TestFUSEDialInstance(t *testing.T) { + fuseDir := randTmpDir(t) + fuseTempDir := randTmpDir(t) + tcs := []struct { + desc string + wantInstance string + socketPath string + fuseTempDir string + }{ + { + desc: "connections create a directory with a special file", + wantInstance: "projects/proj/locations/region/clusters/cluster/instances/instance", + socketPath: postgresSocketPath(fuseDir, "proj.region.cluster.instance"), + fuseTempDir: fuseTempDir, + }, + { + desc: "connecting creates intermediate temp directories", + wantInstance: "projects/proj/locations/region/clusters/cluster/instances/instance", + socketPath: postgresSocketPath(fuseDir, "proj.region.cluster.instance"), + fuseTempDir: filepath.Join(fuseTempDir, "doesntexist"), + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + d := &fakeDialer{} + _, cleanup := newTestClient(t, d, fuseDir, tc.fuseTempDir) + defer cleanup() + + conn := tryDialUnix(t, tc.socketPath) + defer conn.Close() + + var got []string + for i := 0; i < 10; i++ { + got = d.dialedInstances() + if len(got) == 1 { + break + } + time.Sleep(100 * time.Millisecond) + } + if len(got) != 1 { + t.Fatalf("dialed instances len: want = 1, got = %v", got) + } + if want, inst := tc.wantInstance, got[0]; want != inst { + t.Fatalf("instance: want = %v, got = %v", want, inst) + } + + }) + } +} + +func TestFUSEReadDir(t *testing.T) { + fuseDir := randTmpDir(t) + _, cleanup := newTestClient(t, &fakeDialer{}, fuseDir, randTmpDir(t)) + defer cleanup() + + // Initiate a connection so the FUSE server will list it in the dir entries. + conn := tryDialUnix(t, postgresSocketPath(fuseDir, "proj.region.cluster.instance")) + defer conn.Close() + + entries, err := os.ReadDir(fuseDir) + if err != nil { + t.Fatalf("os.ReadDir(): %v", err) + } + // len should be README plus the proj:reg:mysql socket + if got, want := len(entries), 2; got != want { + t.Fatalf("want = %v, got = %v", want, got) + } + var names []string + for _, e := range entries { + names = append(names, e.Name()) + } + if names[0] != "README" || names[1] != "proj.region.cluster.instance" { + t.Fatalf("want = %v, got = %v", []string{"README", "proj.region.cluster.instance"}, names) + } +} + +func TestFUSEWithBadInstanceName(t *testing.T) { + fuseDir := randTmpDir(t) + d := &fakeDialer{} + _, cleanup := newTestClient(t, d, fuseDir, randTmpDir(t)) + defer cleanup() + + _, dialErr := net.Dial("unix", filepath.Join(fuseDir, "notvalid")) + if dialErr == nil { + t.Fatalf("net.Dial() should fail") + } + + if got := d.dialAttempts(); got > 0 { + t.Fatalf("dial calls: want = 0, got = %v", got) + } +} + +func TestFUSECheckConnections(t *testing.T) { + fuseDir := randTmpDir(t) + d := &fakeDialer{} + c, cleanup := newTestClient(t, d, fuseDir, randTmpDir(t)) + defer cleanup() + + // first establish a connection to "register" it with the proxy + conn := tryDialUnix(t, postgresSocketPath(fuseDir, "proj.region.cluster.instance")) + defer conn.Close() + + if err := c.CheckConnections(context.Background()); err != nil { + t.Fatalf("c.CheckConnections(): %v", err) + } + + // verify the dialer was invoked twice, once for connect, once for check + // connection + var attempts int + wantAttempts := 2 + for i := 0; i < 10; i++ { + attempts = d.dialAttempts() + if attempts == wantAttempts { + return + } + time.Sleep(100 * time.Millisecond) + } + t.Fatalf("dial attempts: want = %v, got = %v", wantAttempts, attempts) +} + +func TestFUSEClose(t *testing.T) { + fuseDir := randTmpDir(t) + d := &fakeDialer{} + c, _ := newTestClient(t, d, fuseDir, randTmpDir(t)) + + // first establish a connection to "register" it with the proxy + conn := tryDialUnix(t, postgresSocketPath(fuseDir, "proj.region.cluster.instance")) + defer conn.Close() + + // Close the proxy which should close all listeners + if err := c.Close(); err != nil { + t.Fatalf("c.Close(): %v", err) + } + + _, err := net.Dial("unix", postgresSocketPath(fuseDir, "proj.region.cluster.instance")) + if err == nil { + t.Fatal("net.Dial() should fail") + } +} diff --git a/internal/proxy/fuse_windows.go b/internal/proxy/fuse_windows.go new file mode 100644 index 00000000..6e5289cf --- /dev/null +++ b/internal/proxy/fuse_windows.go @@ -0,0 +1,24 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "errors" +) + +// SupportsFUSE is false on Windows. +func SupportsFUSE() error { + return errors.New("fuse is not supported on Windows") +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 8bc17c94..52ce61a4 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -25,18 +25,21 @@ import ( "strings" "sync" "sync/atomic" + "syscall" "time" "cloud.google.com/go/alloydbconn" "github.com/GoogleCloudPlatform/alloydb-auth-proxy/alloydb" "github.com/GoogleCloudPlatform/alloydb-auth-proxy/internal/gcloud" + "github.com/hanwen/go-fuse/v2/fs" + "github.com/hanwen/go-fuse/v2/fuse" "golang.org/x/oauth2" ) // InstanceConnConfig holds the configuration for an individual instance // connection. type InstanceConnConfig struct { - // Name is the instance connection name. + // Name is the instance URI. Name string // Addr is the address on which to bind a listener for the instance. Addr string @@ -75,6 +78,15 @@ type Config struct { // connected to any Instances. If set, takes precedence over Addr and Port. UnixSocket string + // FUSEDir enables a file system in user space at the provided path that + // connects to the requested instance only when a client requests it. + FUSEDir string + + // FUSETempDir sets the temporary directory where the FUSE mount will place + // Unix domain sockets connected to Cloud SQL instances. The temp directory + // is not accessed directly. + FUSETempDir string + // APIEndpointURL is the URL of the AlloyDB Admin API. APIEndpointURL string @@ -146,15 +158,23 @@ func (c *portConfig) nextPort() int { return p } +type socketSymlink struct { + socket *socketMount + symlink *symlink +} + var ( // Instance URI is in the format: // 'projects//locations//clusters//instances/' // Additionally, we have to support legacy "domain-scoped" projects (e.g. "google.com:PROJECT") instURIRegex = regexp.MustCompile("projects/([^:]+(:[^:]+)?)/locations/([^:]+)/clusters/([^:]+)/instances/([^:]+)") + // unixRegex is the expected format for a Unix socket + // e.g. project.region.cluster.instance + unixRegex = regexp.MustCompile(`([^:]+)\.([^:]+)\.([^:]+)\.([^:]+)`) ) -// UnixSocketDir returns a shorted instance connection name to prevent exceeding -// the Unix socket length. +// UnixSocketDir returns a shorted instance connection name to prevent +// exceeding the Unix socket length, e.g., project.region.cluster.instance func UnixSocketDir(dir, inst string) (string, error) { m := instURIRegex.FindSubmatch([]byte(inst)) if m == nil { @@ -168,6 +188,23 @@ func UnixSocketDir(dir, inst string) (string, error) { return filepath.Join(dir, shortName), nil } +// toFullURI converts a shortened Unix socket name (e.g., +// project.region.cluster.instance) into a full instance URI. +func toFullURI(short string) (string, error) { + m := unixRegex.FindSubmatch([]byte(short)) + if m == nil { + return "", fmt.Errorf("invalid short name: %v", short) + } + project := string(m[1]) + region := string(m[2]) + cluster := string(m[3]) + name := string(m[4]) + return fmt.Sprintf( + "projects/%v/locations/%v/clusters/%v/instances/%v", + project, region, cluster, name, + ), nil +} + // Client proxies connections from a local client to the remote server side // proxy for multiple AlloyDB instances. type Client struct { @@ -189,6 +226,23 @@ type Client struct { waitOnClose time.Duration logger alloydb.Logger + + // fuseDir specifies the directory where a FUSE server is mounted. The value + // is empty if FUSE is not enabled. The directory holds symlinks to Unix + // domain sockets in the fuseTmpDir. + fuseDir string + fuseTempDir string + // fuseMu protects access to fuseSockets. + fuseMu sync.Mutex + // fuseSockets is a map of instance connection name to socketMount and + // symlink. + fuseSockets map[string]socketSymlink + fuseServerMu sync.Mutex + fuseServer *fuse.Server + fuseWg sync.WaitGroup + + // Inode adds support for FUSE operations. + fs.Inode } // NewClient completes the initial setup required to get the proxy to a "steady" state. @@ -206,6 +260,23 @@ func NewClient(ctx context.Context, d alloydb.Dialer, l alloydb.Logger, conf *Co } } + c := &Client{ + logger: l, + dialer: d, + maxConns: conf.MaxConnections, + waitOnClose: conf.WaitOnClose, + } + + if conf.FUSEDir != "" { + if err := os.MkdirAll(conf.FUSETempDir, 0777); err != nil { + return nil, err + } + c.fuseDir = conf.FUSEDir + c.fuseTempDir = conf.FUSETempDir + c.fuseSockets = map[string]socketSymlink{} + return c, nil + } + var mnts []*socketMount pc := newPortConfig(conf.Port) for _, inst := range conf.Instances { @@ -224,24 +295,104 @@ func NewClient(ctx context.Context, d alloydb.Dialer, l alloydb.Logger, conf *Co mnts = append(mnts, m) } - c := &Client{ - mnts: mnts, - logger: l, - dialer: d, - maxConns: conf.MaxConnections, - waitOnClose: conf.WaitOnClose, - } + c.mnts = mnts + return c, nil } +// Readdir returns a list of all active Unix sockets in addition to the README. +func (c *Client) Readdir(ctx context.Context) (fs.DirStream, syscall.Errno) { + entries := []fuse.DirEntry{ + {Name: "README", Mode: 0555 | fuse.S_IFREG}, + } + var active []string + c.fuseMu.Lock() + for k := range c.fuseSockets { + active = append(active, k) + } + c.fuseMu.Unlock() + + for _, a := range active { + entries = append(entries, fuse.DirEntry{ + Name: a, + Mode: 0777 | syscall.S_IFSOCK, + }) + } + return fs.NewListDirStream(entries), fs.OK +} + +// Lookup implements the fs.NodeLookuper interface and returns an index node +// (inode) for a symlink that points to a Unix domain socket. The Unix domain +// socket is connected to the requested Cloud SQL instance. Lookup returns a +// symlink (instead of the socket itself) so that multiple callers all use the +// same Unix socket. +func (c *Client) Lookup(ctx context.Context, instance string, out *fuse.EntryOut) (*fs.Inode, syscall.Errno) { + if instance == "README" { + return c.NewInode(ctx, &readme{}, fs.StableAttr{}), fs.OK + } + + instanceURI, err := toFullURI(instance) + if err != nil { + return nil, syscall.ENOENT + } + + c.fuseMu.Lock() + defer c.fuseMu.Unlock() + if l, ok := c.fuseSockets[instance]; ok { + return l.symlink.EmbeddedInode(), fs.OK + } + + s, err := newSocketMount( + ctx, &Config{UnixSocket: c.fuseTempDir}, + nil, InstanceConnConfig{Name: instanceURI}, + ) + if err != nil { + c.logger.Errorf("could not create socket for %q: %v", instance, err) + return nil, syscall.ENOENT + } + + c.fuseWg.Add(1) + go func() { + defer c.fuseWg.Done() + sErr := c.serveSocketMount(ctx, s) + if sErr != nil { + c.fuseMu.Lock() + delete(c.fuseSockets, instance) + c.fuseMu.Unlock() + } + }() + + // Return a symlink that points to the actual Unix socket within the + // temporary directory. For Postgres, return a symlink that points to the + // directory which holds the ".s.PGSQL.5432" Unix socket. + sl := &symlink{path: filepath.Join(c.fuseTempDir, instance)} + c.fuseSockets[instance] = socketSymlink{ + socket: s, + symlink: sl, + } + return c.NewInode(ctx, sl, fs.StableAttr{ + Mode: 0777 | fuse.S_IFLNK}, + ), fs.OK +} + // CheckConnections dials each registered instance and reports any errors that // may have occurred. func (c *Client) CheckConnections(ctx context.Context) error { var ( wg sync.WaitGroup errCh = make(chan error, len(c.mnts)) + mnts = c.mnts ) - for _, m := range c.mnts { + + if c.fuseDir != "" { + mnts = []*socketMount{} + c.fuseMu.Lock() + for _, m := range c.fuseSockets { + mnts = append(mnts, m.socket) + } + c.fuseMu.Unlock() + } + for _, m := range mnts { wg.Add(1) go func(inst string) { defer wg.Done() @@ -284,6 +435,22 @@ func (c *Client) ConnCount() (uint64, uint64) { func (c *Client) Serve(ctx context.Context, notify func()) error { ctx, cancel := context.WithCancel(ctx) defer cancel() + + if c.fuseDir != "" { + srv, err := fs.Mount(c.fuseDir, c, &fs.Options{ + MountOptions: fuse.MountOptions{AllowOther: true}, + }) + if err != nil { + return fmt.Errorf("FUSE mount failed: %q: %v", c.fuseDir, err) + } + c.fuseServerMu.Lock() + c.fuseServer = srv + c.fuseServerMu.Unlock() + notify() + <-ctx.Done() + return ctx.Err() + } + exitCh := make(chan error) for _, m := range c.mnts { go func(mnt *socketMount) { @@ -323,14 +490,35 @@ func (m MultiErr) Error() string { } func (c *Client) Close() error { + mnts := c.mnts + + c.fuseServerMu.Lock() + hasFuseServer := c.fuseServer != nil + c.fuseServerMu.Unlock() + var mErr MultiErr + if hasFuseServer { + if err := c.fuseServer.Unmount(); err != nil { + mErr = append(mErr, err) + } + mnts = []*socketMount{} + c.fuseMu.Lock() + for _, m := range c.fuseSockets { + mnts = append(mnts, m.socket) + } + c.fuseMu.Unlock() + } + // First, close all open socket listeners to prevent additional connections. - for _, m := range c.mnts { + for _, m := range mnts { err := m.Close() if err != nil { mErr = append(mErr, err) } } + if hasFuseServer { + c.fuseWg.Wait() + } // Next, close the dialer to prevent any additional refreshes. cErr := c.dialer.Close() if cErr != nil { diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index f53d1acf..af5b75e9 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -33,6 +33,8 @@ import ( "github.com/GoogleCloudPlatform/alloydb-auth-proxy/internal/proxy" ) +var testLogger = log.NewStdLogger(os.Stdout, os.Stdout) + type testCase struct { desc string in *proxy.Config @@ -43,12 +45,14 @@ type testCase struct { type fakeDialer struct { mu sync.Mutex dialCount int + instances []string } func (f *fakeDialer) Dial(ctx context.Context, inst string, opts ...alloydbconn.DialOption) (net.Conn, error) { f.mu.Lock() defer f.mu.Unlock() f.dialCount++ + f.instances = append(f.instances, inst) c1, _ := net.Pipe() return c1, nil } @@ -59,6 +63,12 @@ func (f *fakeDialer) dialAttempts() int { return f.dialCount } +func (f *fakeDialer) dialedInstances() []string { + f.mu.Lock() + defer f.mu.Unlock() + return append([]string{}, f.instances...) +} + func (*fakeDialer) Close() error { return nil } @@ -216,8 +226,7 @@ func TestClientInitialization(t *testing.T) { for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { - logger := log.NewStdLogger(os.Stdout, os.Stdout) - c, err := proxy.NewClient(ctx, &fakeDialer{}, logger, tc.in) + c, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, tc.in) if err != nil { t.Fatalf("want error = nil, got = %v", err) } @@ -258,8 +267,7 @@ func TestClientLimitsMaxConnections(t *testing.T) { }, MaxConnections: 1, } - logger := log.NewStdLogger(os.Stdout, os.Stdout) - c, err := proxy.NewClient(context.Background(), d, logger, in) + c, err := proxy.NewClient(context.Background(), d, testLogger, in) if err != nil { t.Fatalf("proxy.NewClient error: %v", err) } @@ -318,7 +326,6 @@ func tryTCPDial(t *testing.T, addr string) net.Conn { } func TestClientCloseWaitsForActiveConnections(t *testing.T) { - logger := log.NewStdLogger(os.Stdout, os.Stdout) in := &proxy.Config{ Addr: "127.0.0.1", Port: 5000, @@ -327,7 +334,7 @@ func TestClientCloseWaitsForActiveConnections(t *testing.T) { }, } - c, err := proxy.NewClient(context.Background(), &fakeDialer{}, logger, in) + c, err := proxy.NewClient(context.Background(), &fakeDialer{}, testLogger, in) if err != nil { t.Fatalf("proxy.NewClient error: %v", err) } @@ -342,7 +349,7 @@ func TestClientCloseWaitsForActiveConnections(t *testing.T) { in.WaitOnClose = time.Second in.Port = 5001 - c, err = proxy.NewClient(context.Background(), &fakeDialer{}, logger, in) + c, err = proxy.NewClient(context.Background(), &fakeDialer{}, testLogger, in) if err != nil { t.Fatalf("proxy.NewClient error: %v", err) } @@ -372,8 +379,7 @@ func TestClientClosesCleanly(t *testing.T) { {Name: "proj:reg:inst"}, }, } - logger := log.NewStdLogger(os.Stdout, os.Stdout) - c, err := proxy.NewClient(context.Background(), &fakeDialer{}, logger, in) + c, err := proxy.NewClient(context.Background(), &fakeDialer{}, testLogger, in) if err != nil { t.Fatalf("proxy.NewClient error want = nil, got = %v", err) } @@ -395,8 +401,7 @@ func TestClosesWithError(t *testing.T) { {Name: "proj:reg:inst"}, }, } - logger := log.NewStdLogger(os.Stdout, os.Stdout) - c, err := proxy.NewClient(context.Background(), &errorDialer{}, logger, in) + c, err := proxy.NewClient(context.Background(), &errorDialer{}, testLogger, in) if err != nil { t.Fatalf("proxy.NewClient error want = nil, got = %v", err) } @@ -451,14 +456,13 @@ func TestClientInitializationWorksRepeatedly(t *testing.T) { {Name: "projects/proj/locations/region/clusters/clust/instances/inst1"}, }, } - logger := log.NewStdLogger(os.Stdout, os.Stdout) - c, err := proxy.NewClient(ctx, &fakeDialer{}, logger, in) + c, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, in) if err != nil { t.Fatalf("want error = nil, got = %v", err) } c.Close() - c, err = proxy.NewClient(ctx, &fakeDialer{}, logger, in) + c, err = proxy.NewClient(ctx, &fakeDialer{}, testLogger, in) if err != nil { t.Fatalf("want error = nil, got = %v", err) } @@ -497,8 +501,7 @@ func TestClientInitializationWithCustomHost(t *testing.T) { APIEndpointURL: s.URL, Port: 7000, } - logger := log.NewStdLogger(os.Stdout, os.Stdout) - c, err := proxy.NewClient(context.Background(), nil, logger, in) + c, err := proxy.NewClient(context.Background(), nil, testLogger, in) if err != nil { t.Fatalf("want error = nil, got = %v", err) } @@ -536,8 +539,7 @@ func TestClientNotifiesCallerOnServe(t *testing.T) { {Name: "proj:region:pg"}, }, } - logger := log.NewStdLogger(os.Stdout, os.Stdout) - c, err := proxy.NewClient(ctx, &fakeDialer{}, logger, in) + c, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, in) if err != nil { t.Fatalf("want error = nil, got = %v", err) } @@ -561,7 +563,6 @@ func TestClientNotifiesCallerOnServe(t *testing.T) { } func TestClientConnCount(t *testing.T) { - logger := log.NewStdLogger(os.Stdout, os.Stdout) in := &proxy.Config{ Addr: "127.0.0.1", Port: 5000, @@ -571,7 +572,7 @@ func TestClientConnCount(t *testing.T) { MaxConnections: 10, } - c, err := proxy.NewClient(context.Background(), &fakeDialer{}, logger, in) + c, err := proxy.NewClient(context.Background(), &fakeDialer{}, testLogger, in) if err != nil { t.Fatalf("proxy.NewClient error: %v", err) } @@ -604,7 +605,6 @@ func TestClientConnCount(t *testing.T) { } func TestCheckConnections(t *testing.T) { - logger := log.NewStdLogger(os.Stdout, os.Stdout) in := &proxy.Config{ Addr: "127.0.0.1", Port: 5000, @@ -613,7 +613,7 @@ func TestCheckConnections(t *testing.T) { }, } d := &fakeDialer{} - c, err := proxy.NewClient(context.Background(), d, logger, in) + c, err := proxy.NewClient(context.Background(), d, testLogger, in) if err != nil { t.Fatalf("proxy.NewClient error: %v", err) } @@ -637,7 +637,7 @@ func TestCheckConnections(t *testing.T) { }, } ed := &errorDialer{} - c, err = proxy.NewClient(context.Background(), ed, logger, in) + c, err = proxy.NewClient(context.Background(), ed, testLogger, in) if err != nil { t.Fatalf("proxy.NewClient error: %v", err) }