Skip to content

Commit

Permalink
Support postgres via database config
Browse files Browse the repository at this point in the history
  • Loading branch information
jackyzhen committed Oct 2, 2018
1 parent 99d526e commit 70ba445
Show file tree
Hide file tree
Showing 8 changed files with 367 additions and 105 deletions.
2 changes: 2 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@ FROM golang:1.11

RUN apt-get update && apt-get install -y --no-install-recommends mysql-client && rm -rf /var/lib/apt/lists/*

RUN apt-get update && apt-get install -y postgresql-client

ENTRYPOINT [ "go", "test", "-v", "." ]
1 change: 1 addition & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type database struct {
DbName string
User string
Pass string
SQLType sqlType
}

func mustReadDatabasesConfigFile() map[string]database {
Expand Down
95 changes: 16 additions & 79 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
package main

import (
"bufio"
"context"
"flag"
"fmt"
"log"
"os"
"os/exec"
"strings"
"sync"
)

Expand Down Expand Up @@ -78,17 +75,32 @@ func _main(databases map[string]database, databasesArgs []string, query string,
targetDatabases = append(targetDatabases, k)
}

sqlTypes := map[sqlType]struct{}{}
var sqlType sqlType
for _, db := range targetDatabases {
sqlType = databases[db].SQLType
if _, ok := validSQLTypes[sqlType]; !ok {
usage("Unknown sql type %v", sqlType)
}
sqlTypes[sqlType] = struct{}{}
if len(sqlTypes) > 1 {
usage("More than one sql types specified in target databases.")
}
}

quitContext, cancel := context.WithCancel(context.Background())
go awaitSignal(cancel)

var wg sync.WaitGroup
wg.Add(len(targetDatabases))

sqlRunner := mustNewSQLRunner(quitContext, sqlType, println, query, len(targetDatabases) > 1)

returnCode := 0
for _, k := range targetDatabases {
go func(db database, k string) {
defer wg.Done()
if r := runSQL(quitContext, db, query, k, len(targetDatabases) > 1, println); !r {
if r := sqlRunner.runSQL(db, k); !r {
returnCode = 1
}
}(databases[k], k)
Expand All @@ -97,78 +109,3 @@ func _main(databases map[string]database, databasesArgs []string, query string,
wg.Wait()
return returnCode
}

func runSQL(quitContext context.Context, db database, query string, key string, prependKey bool, println func(string)) bool {
userOption := ""
if db.User != "" {
userOption = fmt.Sprintf("-u %v ", db.User)
}

passOption := ""
if db.Pass != "" {
passOption = fmt.Sprintf("-p%v ", db.Pass)
}

hostOption := ""
if db.DbServer != "" {
hostOption = fmt.Sprintf("-h %v ", db.DbServer)
}

prepend := ""
if prependKey {
prepend = key + "\t"
}

mysql := "mysql"
options := fmt.Sprintf(" -Nsr %v%v%v%v -e ", userOption, passOption, hostOption, db.DbName)

var cmd *exec.Cmd
if db.AppServer != "" {
escapedQuery := fmt.Sprintf(`'%v'`, strings.Replace(query, `'`, `'"'"'`, -1))
cmd = exec.CommandContext(quitContext, "ssh", db.AppServer, mysql+options+escapedQuery)
} else {
args := append(trimEmpty(strings.Split(options, " ")), query)
cmd = exec.CommandContext(quitContext, mysql, args...)
}

stdout, err := cmd.StdoutPipe()
if err != nil {
log.Printf("Cannot create pipe for STDOUT of running command on %v; not running. err=%v\n", key, err)
return false
}

stderr, err := cmd.StderrPipe()
if err != nil {
log.Printf("Cannot create pipe for STDERR of running command on %v; not running. err=%v\n", key, err)
return false
}

if err := cmd.Start(); err != nil {
log.Printf("Cannot start command on %v; not running. err=%v\n", key, err)
return false
}

scanner := bufio.NewScanner(stdout)
for scanner.Scan() {
println(prepend + scanner.Text())
}

stderrLines := []string{}
scanner = bufio.NewScanner(stderr)
for scanner.Scan() {
stderrLines = append(stderrLines, scanner.Text())
}

cmd.Wait()

result := true
if len(stderrLines) > 0 {
result = false
log.Println(key + " had errors:")
for _, v := range stderrLines {
log.Println(key + " [ERROR] " + v)
}
}

return result
}
151 changes: 126 additions & 25 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"time"
)

func TestSQL(t *testing.T) {
func Test_MySQL(t *testing.T) {
var err error
for i := 1; i <= 30; i++ { // Try up to 30 times, because MySQL takes a while to become online
var c = exec.Command("mysql", "-h", "test-mysql", "-u", "root", "-e", "SELECT * FROM db1.table1")
Expand Down Expand Up @@ -55,12 +55,12 @@ func TestSQL(t *testing.T) {
query: "SELECT id FROM table1",
expected: []string{
"",
"db1 1",
"db1 2",
"db1 3",
"db2 1",
"db2 2",
"db2 3",
"db1\t1",
"db1\t2",
"db1\t3",
"db2\t1",
"db2\t2",
"db2\t3",
},
},
{
Expand All @@ -69,15 +69,15 @@ func TestSQL(t *testing.T) {
query: "SELECT id FROM table1",
expected: []string{
"",
"db1 1",
"db1 2",
"db1 3",
"db2 1",
"db2 2",
"db2 3",
"db3 1",
"db3 2",
"db3 3",
"db1\t1",
"db1\t2",
"db1\t3",
"db2\t1",
"db2\t2",
"db2\t3",
"db3\t1",
"db3\t2",
"db3\t3",
},
},
{
Expand All @@ -86,15 +86,116 @@ func TestSQL(t *testing.T) {
query: "SELECT id, name FROM table1",
expected: []string{
"",
"db1 1 John",
"db1 2 George",
"db1 3 Richard",
"db2 1 Rob",
"db2 2 Ken",
"db2 3 Robert",
"db3 1 Athos",
"db3 2 Porthos",
"db3 3 Aramis",
"db1\t1\tJohn",
"db1\t2\tGeorge",
"db1\t3\tRichard",
"db2\t1\tRob",
"db2\t2\tKen",
"db2\t3\tRobert",
"db3\t1\tAthos",
"db3\t2\tPorthos",
"db3\t3\tAramis",
},
},
}
)
for _, tc := range ts {
t.Run(tc.name, func(t *testing.T) {
var buf = bytes.Buffer{}
_main(testConfig, tc.targetDBs, tc.query, newThreadSafePrintliner(&buf).println)
var actual = strings.Split(buf.String(), "\n")
sort.Strings(actual)
if !reflect.DeepEqual(tc.expected, actual) {
t.Errorf("Expected %v but got %v", tc.expected, actual)
}
})
}
}

func Test_PostgreSQL(t *testing.T) {
var err error
for i := 1; i <= 30; i++ {
var c = exec.Command("psql", "-h", "test-postgres", "-U", "root", "-d", "db1", "-c", "SELECT * FROM table1")
if err = c.Run(); err == nil {
break
}
log.Printf("Retrying (%v/30) in 1 sec because PostgreSQL is not yet ready", i)
time.Sleep(1 * time.Second)
}
for err != nil {
t.Errorf("bailing because couldn't connect to PostgreSQL after 30 tries: %v", err)
t.FailNow()
}

var (
testConfig = map[string]database{
"db1": database{DbServer: "test-postgres", DbName: "db1", User: "root", Pass: "", SQLType: postgreSQL},
"db2": database{DbServer: "test-postgres", DbName: "db2", User: "root", Pass: "", SQLType: postgreSQL},
"db3": database{DbServer: "test-postgres", DbName: "db3", User: "root", Pass: "", SQLType: postgreSQL},
}
ts = []struct {
name string
targetDBs []string
query string
expected []string
}{
{
name: "reads from one database",
targetDBs: []string{"db1"},
query: "SELECT id FROM table1",
expected: []string{
"",
"1",
"2",
"3",
},
},
{
name: "reads from two databases",
targetDBs: []string{"db1", "db2"},
query: "SELECT id FROM table1",
expected: []string{
"",
"db1\t1",
"db1\t2",
"db1\t3",
"db2\t1",
"db2\t2",
"db2\t3",
},
},
{
name: "reads from all databases with the all keyword",
targetDBs: []string{"all"},
query: "SELECT id FROM table1",
expected: []string{
"",
"db1\t1",
"db1\t2",
"db1\t3",
"db2\t1",
"db2\t2",
"db2\t3",
"db3\t1",
"db3\t2",
"db3\t3",
},
},
{
name: "reads two fields from all databases",
targetDBs: []string{"all"},
query: "SELECT id, name FROM table1",
expected: []string{
"",
"db1\t1 | John",
"db1\t2 | George",
"db1\t3 | Richard",
"db2\t1 | Rob",
"db2\t2 | Ken",
"db2\t3 | Robert",
"db3\t1 | Athos",
"db3\t2 | Porthos",
"db3\t3 | Aramis",
},
},
}
Expand Down
File renamed without changes.
47 changes: 47 additions & 0 deletions postgres_test_schemas.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
DROP DATABASE IF EXISTS db1;
CREATE DATABASE db1;

DROP DATABASE IF EXISTS db2;
CREATE DATABASE db2;

DROP DATABASE IF EXISTS db3;
CREATE DATABASE db3;

\c db1

CREATE TABLE table1(
id integer NOT NULL,
name varchar(255) NOT NULL
);

INSERT INTO table1 (id, name)
VALUES
(1, 'John'),
(2, 'George'),
(3, 'Richard');

\c db2

CREATE TABLE table1(
id integer NOT NULL,
name varchar(255) NOT NULL
);

INSERT INTO table1 (id, name)
VALUES
(1, 'Rob'),
(2, 'Ken'),
(3, 'Robert');

\c db3

CREATE TABLE table1(
id integer NOT NULL,
name varchar(255) NOT NULL
);

INSERT INTO table1 (id, name)
VALUES
(1, 'Athos'),
(2, 'Porthos'),
(3, 'Aramis');
Loading

0 comments on commit 70ba445

Please sign in to comment.