diff --git a/cmd/cloud_sql_proxy/cloud_sql_proxy.go b/cmd/cloud_sql_proxy/cloud_sql_proxy.go index ef35cc9c2..e82d736dc 100644 --- a/cmd/cloud_sql_proxy/cloud_sql_proxy.go +++ b/cmd/cloud_sql_proxy/cloud_sql_proxy.go @@ -357,7 +357,7 @@ func main() { } instList = append(instList, ins...) - cfgs, err := CreateInstanceConfigs(*dir, *useFuse, instList, *instanceSrc) + cfgs, err := CreateInstanceConfigs(*dir, *useFuse, instList, *instanceSrc, client) if err != nil { log.Fatal(err) } @@ -393,7 +393,7 @@ func main() { }() } - c, err := WatchInstances(*dir, cfgs, updates) + c, err := WatchInstances(*dir, cfgs, updates, client) if err != nil { log.Fatal(err) } diff --git a/cmd/cloud_sql_proxy/proxy.go b/cmd/cloud_sql_proxy/proxy.go index ba94361c9..afe79cc32 100644 --- a/cmd/cloud_sql_proxy/proxy.go +++ b/cmd/cloud_sql_proxy/proxy.go @@ -22,6 +22,7 @@ import ( "fmt" "log" "net" + "net/http" "os" "path/filepath" "runtime" @@ -35,7 +36,7 @@ import ( // local connections. Values received from the updates channel are // interpretted as a comma-separated list of instances. The set of sockets in // 'dir' is the union of 'instances' and the most recent list from 'updates'. -func WatchInstances(dir string, cfgs []instanceConfig, updates <-chan string) (<-chan proxy.Conn, error) { +func WatchInstances(dir string, cfgs []instanceConfig, updates <-chan string, cl *http.Client) (<-chan proxy.Conn, error) { ch := make(chan proxy.Conn, 1) // Instances specified statically (e.g. as flags to the binary) will always @@ -51,15 +52,15 @@ func WatchInstances(dir string, cfgs []instanceConfig, updates <-chan string) (< } if updates != nil { - go watchInstancesLoop(dir, ch, updates, staticInstances) + go watchInstancesLoop(dir, ch, updates, staticInstances, cl) } return ch, nil } -func watchInstancesLoop(dir string, dst chan<- proxy.Conn, updates <-chan string, static map[string]net.Listener) { +func watchInstancesLoop(dir string, dst chan<- proxy.Conn, updates <-chan string, static map[string]net.Listener, cl *http.Client) { dynamicInstances := make(map[string]net.Listener) for instances := range updates { - list, err := parseInstanceConfigs(dir, strings.Split(instances, ",")) + list, err := parseInstanceConfigs(dir, strings.Split(instances, ","), cl) if err != nil { log.Print(err) } @@ -200,7 +201,7 @@ var validNets = func() map[string]bool { return m }() -func parseInstanceConfig(dir, instance string) (instanceConfig, error) { +func parseInstanceConfig(dir, instance string, cl *http.Client) (instanceConfig, error) { var ret instanceConfig eq := strings.Index(instance, "=") if eq != -1 { @@ -228,7 +229,27 @@ func parseInstanceConfig(dir, instance string) (instanceConfig, error) { ret.Instance = instance // Default to unix socket. ret.Network = "unix" - ret.Address = filepath.Join(dir, instance) + spl := strings.SplitN(instance, ":", 3) + sql, err := sqladmin.New(cl) + if err != nil { + return instanceConfig{}, err + } + in, err := sql.Instances.Get(spl[0], spl[2]).Do() + if err != nil { + return instanceConfig{}, err + } + if strings.HasPrefix(strings.ToLower(in.DatabaseVersion), "postgres") { + path := filepath.Join(dir, instance) + if _, err := os.Stat(path); !os.IsNotExist(err) { + return instanceConfig{}, err + } + if err := os.MkdirAll(path, 0755); err != nil { + return instanceConfig{}, err + } + ret.Address = filepath.Join(path, ".s.PGSQL.5432") + } else { + ret.Address = filepath.Join(dir, instance) + } } if !validNets[ret.Network] { @@ -240,14 +261,14 @@ func parseInstanceConfig(dir, instance string) (instanceConfig, error) { // parseInstanceConfigs calls parseInstanceConfig for each instance in the // provided slice, collecting errors along the way. There may be valid // instanceConfigs returned even if there's an error. -func parseInstanceConfigs(dir string, instances []string) ([]instanceConfig, error) { +func parseInstanceConfigs(dir string, instances []string, cl *http.Client) ([]instanceConfig, error) { errs := new(bytes.Buffer) var cfg []instanceConfig for _, v := range instances { if v == "" { continue } - if c, err := parseInstanceConfig(dir, v); err != nil { + if c, err := parseInstanceConfig(dir, v, cl); err != nil { fmt.Fprintf(errs, "\n\t%v", err) } else { cfg = append(cfg, c) @@ -264,12 +285,12 @@ func parseInstanceConfigs(dir string, instances []string) ([]instanceConfig, err // CreateInstanceConfigs verifies that the parameters passed to it are valid // for the proxy for the platform and system and then returns a slice of valid // instanceConfig. -func CreateInstanceConfigs(dir string, useFuse bool, instances []string, instancesSrc string) ([]instanceConfig, error) { +func CreateInstanceConfigs(dir string, useFuse bool, instances []string, instancesSrc string, cl *http.Client) ([]instanceConfig, error) { if useFuse && !fuse.Supported() { return nil, errors.New("FUSE not supported on this system") } - cfgs, err := parseInstanceConfigs(dir, instances) + cfgs, err := parseInstanceConfigs(dir, instances, cl) if err != nil { return nil, err } diff --git a/cmd/cloud_sql_proxy/proxy_test.go b/cmd/cloud_sql_proxy/proxy_test.go index e6c2c2e65..3153daf73 100644 --- a/cmd/cloud_sql_proxy/proxy_test.go +++ b/cmd/cloud_sql_proxy/proxy_test.go @@ -15,9 +15,21 @@ package main import ( + "bytes" + "io/ioutil" + "net/http" "testing" ) +type mockTripper struct { +} + +func (m *mockTripper) RoundTrip(r *http.Request) (*http.Response, error) { + return &http.Response{StatusCode: 200, Body: ioutil.NopCloser(bytes.NewReader([]byte("{}")))}, nil +} + +var mockClient = &http.Client{Transport: &mockTripper{}} + func TestCreateInstanceConfigs(t *testing.T) { for _, v := range []struct { desc string @@ -69,7 +81,7 @@ func TestCreateInstanceConfigs(t *testing.T) { "", false, nil, "md", true, }, } { - _, err := CreateInstanceConfigs(v.dir, v.useFuse, v.instances, v.instancesSrc) + _, err := CreateInstanceConfigs(v.dir, v.useFuse, v.instances, v.instancesSrc, mockClient) if v.wantErr { if err == nil { t.Errorf("CreateInstanceConfigs passed when %s, wanted error", v.desc) @@ -120,7 +132,7 @@ func TestParseInstanceConfig(t *testing.T) { true, }, } { - got, err := parseInstanceConfig(v.dir, v.instance) + got, err := parseInstanceConfig(v.dir, v.instance, mockClient) if v.wantErr { if err == nil { t.Errorf("parseInstanceConfig(%s, %s) = %+v, wanted error", got)