Skip to content

Commit

Permalink
Merge pull request #64 from ankushagarwal/master
Browse files Browse the repository at this point in the history
Add support for Unix Socket connections to Postgres databases
  • Loading branch information
Carrotman42 authored Mar 7, 2017
2 parents 264edf7 + 80fe850 commit 85f824e
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 14 deletions.
4 changes: 2 additions & 2 deletions cmd/cloud_sql_proxy/cloud_sql_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ func main() {
}
instList = append(instList, ins...)

cfgs, err := CreateInstanceConfigs(*dir, *useFuse, instList, *instanceSrc)
cfgs, err := CreateInstanceConfigs(*dir, *useFuse, instList, *instanceSrc, client)
if err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -393,7 +393,7 @@ func main() {
}()
}

c, err := WatchInstances(*dir, cfgs, updates)
c, err := WatchInstances(*dir, cfgs, updates, client)
if err != nil {
log.Fatal(err)
}
Expand Down
41 changes: 31 additions & 10 deletions cmd/cloud_sql_proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"
"log"
"net"
"net/http"
"os"
"path/filepath"
"runtime"
Expand All @@ -35,7 +36,7 @@ import (
// local connections. Values received from the updates channel are
// interpretted as a comma-separated list of instances. The set of sockets in
// 'dir' is the union of 'instances' and the most recent list from 'updates'.
func WatchInstances(dir string, cfgs []instanceConfig, updates <-chan string) (<-chan proxy.Conn, error) {
func WatchInstances(dir string, cfgs []instanceConfig, updates <-chan string, cl *http.Client) (<-chan proxy.Conn, error) {
ch := make(chan proxy.Conn, 1)

// Instances specified statically (e.g. as flags to the binary) will always
Expand All @@ -51,15 +52,15 @@ func WatchInstances(dir string, cfgs []instanceConfig, updates <-chan string) (<
}

if updates != nil {
go watchInstancesLoop(dir, ch, updates, staticInstances)
go watchInstancesLoop(dir, ch, updates, staticInstances, cl)
}
return ch, nil
}

func watchInstancesLoop(dir string, dst chan<- proxy.Conn, updates <-chan string, static map[string]net.Listener) {
func watchInstancesLoop(dir string, dst chan<- proxy.Conn, updates <-chan string, static map[string]net.Listener, cl *http.Client) {
dynamicInstances := make(map[string]net.Listener)
for instances := range updates {
list, err := parseInstanceConfigs(dir, strings.Split(instances, ","))
list, err := parseInstanceConfigs(dir, strings.Split(instances, ","), cl)
if err != nil {
log.Print(err)
}
Expand Down Expand Up @@ -200,7 +201,7 @@ var validNets = func() map[string]bool {
return m
}()

func parseInstanceConfig(dir, instance string) (instanceConfig, error) {
func parseInstanceConfig(dir, instance string, cl *http.Client) (instanceConfig, error) {
var ret instanceConfig
eq := strings.Index(instance, "=")
if eq != -1 {
Expand Down Expand Up @@ -228,7 +229,27 @@ func parseInstanceConfig(dir, instance string) (instanceConfig, error) {
ret.Instance = instance
// Default to unix socket.
ret.Network = "unix"
ret.Address = filepath.Join(dir, instance)
spl := strings.SplitN(instance, ":", 3)
sql, err := sqladmin.New(cl)
if err != nil {
return instanceConfig{}, err
}
in, err := sql.Instances.Get(spl[0], spl[2]).Do()
if err != nil {
return instanceConfig{}, err
}
if strings.HasPrefix(strings.ToLower(in.DatabaseVersion), "postgres") {
path := filepath.Join(dir, instance)
if _, err := os.Stat(path); !os.IsNotExist(err) {
return instanceConfig{}, err
}
if err := os.MkdirAll(path, 0755); err != nil {
return instanceConfig{}, err
}
ret.Address = filepath.Join(path, ".s.PGSQL.5432")
} else {
ret.Address = filepath.Join(dir, instance)
}
}

if !validNets[ret.Network] {
Expand All @@ -240,14 +261,14 @@ func parseInstanceConfig(dir, instance string) (instanceConfig, error) {
// parseInstanceConfigs calls parseInstanceConfig for each instance in the
// provided slice, collecting errors along the way. There may be valid
// instanceConfigs returned even if there's an error.
func parseInstanceConfigs(dir string, instances []string) ([]instanceConfig, error) {
func parseInstanceConfigs(dir string, instances []string, cl *http.Client) ([]instanceConfig, error) {
errs := new(bytes.Buffer)
var cfg []instanceConfig
for _, v := range instances {
if v == "" {
continue
}
if c, err := parseInstanceConfig(dir, v); err != nil {
if c, err := parseInstanceConfig(dir, v, cl); err != nil {
fmt.Fprintf(errs, "\n\t%v", err)
} else {
cfg = append(cfg, c)
Expand All @@ -264,12 +285,12 @@ func parseInstanceConfigs(dir string, instances []string) ([]instanceConfig, err
// CreateInstanceConfigs verifies that the parameters passed to it are valid
// for the proxy for the platform and system and then returns a slice of valid
// instanceConfig.
func CreateInstanceConfigs(dir string, useFuse bool, instances []string, instancesSrc string) ([]instanceConfig, error) {
func CreateInstanceConfigs(dir string, useFuse bool, instances []string, instancesSrc string, cl *http.Client) ([]instanceConfig, error) {
if useFuse && !fuse.Supported() {
return nil, errors.New("FUSE not supported on this system")
}

cfgs, err := parseInstanceConfigs(dir, instances)
cfgs, err := parseInstanceConfigs(dir, instances, cl)
if err != nil {
return nil, err
}
Expand Down
16 changes: 14 additions & 2 deletions cmd/cloud_sql_proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,21 @@
package main

import (
"bytes"
"io/ioutil"
"net/http"
"testing"
)

type mockTripper struct {
}

func (m *mockTripper) RoundTrip(r *http.Request) (*http.Response, error) {
return &http.Response{StatusCode: 200, Body: ioutil.NopCloser(bytes.NewReader([]byte("{}")))}, nil
}

var mockClient = &http.Client{Transport: &mockTripper{}}

func TestCreateInstanceConfigs(t *testing.T) {
for _, v := range []struct {
desc string
Expand Down Expand Up @@ -69,7 +81,7 @@ func TestCreateInstanceConfigs(t *testing.T) {
"", false, nil, "md", true,
},
} {
_, err := CreateInstanceConfigs(v.dir, v.useFuse, v.instances, v.instancesSrc)
_, err := CreateInstanceConfigs(v.dir, v.useFuse, v.instances, v.instancesSrc, mockClient)
if v.wantErr {
if err == nil {
t.Errorf("CreateInstanceConfigs passed when %s, wanted error", v.desc)
Expand Down Expand Up @@ -120,7 +132,7 @@ func TestParseInstanceConfig(t *testing.T) {
true,
},
} {
got, err := parseInstanceConfig(v.dir, v.instance)
got, err := parseInstanceConfig(v.dir, v.instance, mockClient)
if v.wantErr {
if err == nil {
t.Errorf("parseInstanceConfig(%s, %s) = %+v, wanted error", got)
Expand Down

0 comments on commit 85f824e

Please sign in to comment.