From 70ba445b93e51df894613c4fc160010b651d893c Mon Sep 17 00:00:00 2001 From: Jacky Zhen Date: Tue, 2 Oct 2018 21:53:20 +1300 Subject: [PATCH] Support postgres via database config --- Dockerfile | 2 + config.go | 1 + main.go | 95 ++---------- main_test.go | 151 +++++++++++++++---- test_schemas.sql => mysql_test_schemas.sql | 0 postgres_test_schemas.sql | 47 ++++++ sql_runner.go | 166 +++++++++++++++++++++ test-docker-compose.yml | 10 +- 8 files changed, 367 insertions(+), 105 deletions(-) rename test_schemas.sql => mysql_test_schemas.sql (100%) create mode 100644 postgres_test_schemas.sql create mode 100644 sql_runner.go diff --git a/Dockerfile b/Dockerfile index e3e33d1..8d1c23f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,4 +2,6 @@ FROM golang:1.11 RUN apt-get update && apt-get install -y --no-install-recommends mysql-client && rm -rf /var/lib/apt/lists/* +RUN apt-get update && apt-get install -y postgresql-client + ENTRYPOINT [ "go", "test", "-v", "." ] diff --git a/config.go b/config.go index a546af8..7613082 100644 --- a/config.go +++ b/config.go @@ -15,6 +15,7 @@ type database struct { DbName string User string Pass string + SQLType sqlType } func mustReadDatabasesConfigFile() map[string]database { diff --git a/main.go b/main.go index d79373d..53ff2ac 100644 --- a/main.go +++ b/main.go @@ -1,14 +1,11 @@ package main import ( - "bufio" "context" "flag" "fmt" "log" "os" - "os/exec" - "strings" "sync" ) @@ -78,17 +75,32 @@ func _main(databases map[string]database, databasesArgs []string, query string, targetDatabases = append(targetDatabases, k) } + sqlTypes := map[sqlType]struct{}{} + var sqlType sqlType + for _, db := range targetDatabases { + sqlType = databases[db].SQLType + if _, ok := validSQLTypes[sqlType]; !ok { + usage("Unknown sql type %v", sqlType) + } + sqlTypes[sqlType] = struct{}{} + if len(sqlTypes) > 1 { + usage("More than one sql types specified in target databases.") + } + } + quitContext, cancel := context.WithCancel(context.Background()) go awaitSignal(cancel) var wg sync.WaitGroup wg.Add(len(targetDatabases)) + sqlRunner := mustNewSQLRunner(quitContext, sqlType, println, query, len(targetDatabases) > 1) + returnCode := 0 for _, k := range targetDatabases { go func(db database, k string) { defer wg.Done() - if r := runSQL(quitContext, db, query, k, len(targetDatabases) > 1, println); !r { + if r := sqlRunner.runSQL(db, k); !r { returnCode = 1 } }(databases[k], k) @@ -97,78 +109,3 @@ func _main(databases map[string]database, databasesArgs []string, query string, wg.Wait() return returnCode } - -func runSQL(quitContext context.Context, db database, query string, key string, prependKey bool, println func(string)) bool { - userOption := "" - if db.User != "" { - userOption = fmt.Sprintf("-u %v ", db.User) - } - - passOption := "" - if db.Pass != "" { - passOption = fmt.Sprintf("-p%v ", db.Pass) - } - - hostOption := "" - if db.DbServer != "" { - hostOption = fmt.Sprintf("-h %v ", db.DbServer) - } - - prepend := "" - if prependKey { - prepend = key + "\t" - } - - mysql := "mysql" - options := fmt.Sprintf(" -Nsr %v%v%v%v -e ", userOption, passOption, hostOption, db.DbName) - - var cmd *exec.Cmd - if db.AppServer != "" { - escapedQuery := fmt.Sprintf(`'%v'`, strings.Replace(query, `'`, `'"'"'`, -1)) - cmd = exec.CommandContext(quitContext, "ssh", db.AppServer, mysql+options+escapedQuery) - } else { - args := append(trimEmpty(strings.Split(options, " ")), query) - cmd = exec.CommandContext(quitContext, mysql, args...) - } - - stdout, err := cmd.StdoutPipe() - if err != nil { - log.Printf("Cannot create pipe for STDOUT of running command on %v; not running. err=%v\n", key, err) - return false - } - - stderr, err := cmd.StderrPipe() - if err != nil { - log.Printf("Cannot create pipe for STDERR of running command on %v; not running. err=%v\n", key, err) - return false - } - - if err := cmd.Start(); err != nil { - log.Printf("Cannot start command on %v; not running. err=%v\n", key, err) - return false - } - - scanner := bufio.NewScanner(stdout) - for scanner.Scan() { - println(prepend + scanner.Text()) - } - - stderrLines := []string{} - scanner = bufio.NewScanner(stderr) - for scanner.Scan() { - stderrLines = append(stderrLines, scanner.Text()) - } - - cmd.Wait() - - result := true - if len(stderrLines) > 0 { - result = false - log.Println(key + " had errors:") - for _, v := range stderrLines { - log.Println(key + " [ERROR] " + v) - } - } - - return result -} diff --git a/main_test.go b/main_test.go index 0a25732..3dc89fc 100644 --- a/main_test.go +++ b/main_test.go @@ -11,7 +11,7 @@ import ( "time" ) -func TestSQL(t *testing.T) { +func Test_MySQL(t *testing.T) { var err error for i := 1; i <= 30; i++ { // Try up to 30 times, because MySQL takes a while to become online var c = exec.Command("mysql", "-h", "test-mysql", "-u", "root", "-e", "SELECT * FROM db1.table1") @@ -55,12 +55,12 @@ func TestSQL(t *testing.T) { query: "SELECT id FROM table1", expected: []string{ "", - "db1 1", - "db1 2", - "db1 3", - "db2 1", - "db2 2", - "db2 3", + "db1\t1", + "db1\t2", + "db1\t3", + "db2\t1", + "db2\t2", + "db2\t3", }, }, { @@ -69,15 +69,15 @@ func TestSQL(t *testing.T) { query: "SELECT id FROM table1", expected: []string{ "", - "db1 1", - "db1 2", - "db1 3", - "db2 1", - "db2 2", - "db2 3", - "db3 1", - "db3 2", - "db3 3", + "db1\t1", + "db1\t2", + "db1\t3", + "db2\t1", + "db2\t2", + "db2\t3", + "db3\t1", + "db3\t2", + "db3\t3", }, }, { @@ -86,15 +86,116 @@ func TestSQL(t *testing.T) { query: "SELECT id, name FROM table1", expected: []string{ "", - "db1 1 John", - "db1 2 George", - "db1 3 Richard", - "db2 1 Rob", - "db2 2 Ken", - "db2 3 Robert", - "db3 1 Athos", - "db3 2 Porthos", - "db3 3 Aramis", + "db1\t1\tJohn", + "db1\t2\tGeorge", + "db1\t3\tRichard", + "db2\t1\tRob", + "db2\t2\tKen", + "db2\t3\tRobert", + "db3\t1\tAthos", + "db3\t2\tPorthos", + "db3\t3\tAramis", + }, + }, + } + ) + for _, tc := range ts { + t.Run(tc.name, func(t *testing.T) { + var buf = bytes.Buffer{} + _main(testConfig, tc.targetDBs, tc.query, newThreadSafePrintliner(&buf).println) + var actual = strings.Split(buf.String(), "\n") + sort.Strings(actual) + if !reflect.DeepEqual(tc.expected, actual) { + t.Errorf("Expected %v but got %v", tc.expected, actual) + } + }) + } +} + +func Test_PostgreSQL(t *testing.T) { + var err error + for i := 1; i <= 30; i++ { + var c = exec.Command("psql", "-h", "test-postgres", "-U", "root", "-d", "db1", "-c", "SELECT * FROM table1") + if err = c.Run(); err == nil { + break + } + log.Printf("Retrying (%v/30) in 1 sec because PostgreSQL is not yet ready", i) + time.Sleep(1 * time.Second) + } + for err != nil { + t.Errorf("bailing because couldn't connect to PostgreSQL after 30 tries: %v", err) + t.FailNow() + } + + var ( + testConfig = map[string]database{ + "db1": database{DbServer: "test-postgres", DbName: "db1", User: "root", Pass: "", SQLType: postgreSQL}, + "db2": database{DbServer: "test-postgres", DbName: "db2", User: "root", Pass: "", SQLType: postgreSQL}, + "db3": database{DbServer: "test-postgres", DbName: "db3", User: "root", Pass: "", SQLType: postgreSQL}, + } + ts = []struct { + name string + targetDBs []string + query string + expected []string + }{ + { + name: "reads from one database", + targetDBs: []string{"db1"}, + query: "SELECT id FROM table1", + expected: []string{ + "", + "1", + "2", + "3", + }, + }, + { + name: "reads from two databases", + targetDBs: []string{"db1", "db2"}, + query: "SELECT id FROM table1", + expected: []string{ + "", + "db1\t1", + "db1\t2", + "db1\t3", + "db2\t1", + "db2\t2", + "db2\t3", + }, + }, + { + name: "reads from all databases with the all keyword", + targetDBs: []string{"all"}, + query: "SELECT id FROM table1", + expected: []string{ + "", + "db1\t1", + "db1\t2", + "db1\t3", + "db2\t1", + "db2\t2", + "db2\t3", + "db3\t1", + "db3\t2", + "db3\t3", + }, + }, + { + name: "reads two fields from all databases", + targetDBs: []string{"all"}, + query: "SELECT id, name FROM table1", + expected: []string{ + "", + "db1\t1 | John", + "db1\t2 | George", + "db1\t3 | Richard", + "db2\t1 | Rob", + "db2\t2 | Ken", + "db2\t3 | Robert", + "db3\t1 | Athos", + "db3\t2 | Porthos", + "db3\t3 | Aramis", }, }, } diff --git a/test_schemas.sql b/mysql_test_schemas.sql similarity index 100% rename from test_schemas.sql rename to mysql_test_schemas.sql diff --git a/postgres_test_schemas.sql b/postgres_test_schemas.sql new file mode 100644 index 0000000..f6ba1d6 --- /dev/null +++ b/postgres_test_schemas.sql @@ -0,0 +1,47 @@ +DROP DATABASE IF EXISTS db1; +CREATE DATABASE db1; + +DROP DATABASE IF EXISTS db2; +CREATE DATABASE db2; + +DROP DATABASE IF EXISTS db3; +CREATE DATABASE db3; + +\c db1 + +CREATE TABLE table1( + id integer NOT NULL, + name varchar(255) NOT NULL +); + +INSERT INTO table1 (id, name) +VALUES +(1, 'John'), +(2, 'George'), +(3, 'Richard'); + +\c db2 + +CREATE TABLE table1( + id integer NOT NULL, + name varchar(255) NOT NULL +); + +INSERT INTO table1 (id, name) +VALUES +(1, 'Rob'), +(2, 'Ken'), +(3, 'Robert'); + +\c db3 + +CREATE TABLE table1( + id integer NOT NULL, + name varchar(255) NOT NULL +); + +INSERT INTO table1 (id, name) +VALUES +(1, 'Athos'), +(2, 'Porthos'), +(3, 'Aramis'); diff --git a/sql_runner.go b/sql_runner.go new file mode 100644 index 0000000..3946d76 --- /dev/null +++ b/sql_runner.go @@ -0,0 +1,166 @@ +package main + +import ( + "bufio" + "context" + "fmt" + "log" + "os" + "os/exec" + "strings" +) + +type sqlType int + +const ( + mySQL sqlType = iota + postgreSQL +) + +type exists struct{} + +type sqlOptions struct { + cmd string + user string + host string + pass string + db string + flags string +} + +var validSQLTypes = map[sqlType]exists{ + mySQL: exists{}, + postgreSQL: exists{}, +} + +var sqlTypeToOptions = map[sqlType]sqlOptions{ + mySQL: { + "mysql", + "-u%v", + "-h%v", + "-p%v", + "%v", + "-Nsre", + }, + postgreSQL: { + "psql", + "-U %v", + "-h%v", + "PGPASSWORD=%v", + "-d %v", + "-tc", + }, +} + +type sqlRunner struct { + typ sqlType + printer func(string) + query string + quitContext context.Context + multi bool +} + +func mustNewSQLRunner(quitContext context.Context, typ sqlType, printer func(string), query string, multi bool) *sqlRunner { + return &sqlRunner{ + typ, + printer, + query, + quitContext, + multi, + } +} + +func (sr *sqlRunner) runSQL(db database, key string) bool { + sqlOptions := sqlTypeToOptions[sr.typ] + + userOption := "" + if db.User != "" { + userOption = fmt.Sprintf(sqlOptions.user, db.User) + } + + passOption := "" + if db.Pass != "" { + passOption = fmt.Sprintf(sqlOptions.pass, db.Pass) + } + + hostOption := "" + if db.DbServer != "" { + hostOption = fmt.Sprintf(sqlOptions.host, db.DbServer) + } + + dbOption := "" + if db.DbName != "" { + dbOption = fmt.Sprintf(sqlOptions.db, db.DbName) + } + + prepend := "" + if sr.multi { + prepend = key + "\t" + } + + options := "" + if sr.typ == postgreSQL { + options = fmt.Sprintf("%v %v %v %v %v", sqlOptions.cmd, userOption, hostOption, dbOption, sqlOptions.flags) + } else { + options = fmt.Sprintf("%v %v %v %v %v %v", sqlOptions.cmd, dbOption, userOption, passOption, hostOption, sqlOptions.flags) + } + + var cmd *exec.Cmd + if db.AppServer != "" { + escapedQuery := fmt.Sprintf(`'%v'`, strings.Replace(sr.query, `'`, `'"'"'`, -1)) + cmd = exec.CommandContext(sr.quitContext, "ssh", db.AppServer, options+escapedQuery) + + } else { + args := append(trimEmpty(strings.Split(options, " ")), sr.query) + cmd = exec.CommandContext(sr.quitContext, args[0], args[1:]...) + } + + if sr.typ == postgreSQL { + cmd.Env = os.Environ() + cmd.Env = append(cmd.Env, passOption) + } + + stdout, err := cmd.StdoutPipe() + if err != nil { + log.Printf("Cannot create pipe for STDOUT of running command on %v; not running. err=%v\n", key, err) + return false + } + + stderr, err := cmd.StderrPipe() + if err != nil { + log.Printf("Cannot create pipe for STDERR of running command on %v; not running. err=%v\n", key, err) + return false + } + + if err := cmd.Start(); err != nil { + log.Printf("Cannot start command on %v; not running. err=%v\n", key, err) + return false + } + + scanner := bufio.NewScanner(stdout) + for scanner.Scan() { + line := scanner.Text() + if line != "" { + sr.printer(prepend + strings.TrimSpace(line)) + } + } + + stderrLines := []string{} + scanner = bufio.NewScanner(stderr) + for scanner.Scan() { + stderrLines = append(stderrLines, scanner.Text()) + } + + cmd.Wait() + + result := true + if len(stderrLines) > 0 { + result = false + log.Println(key + " had errors:") + for _, v := range stderrLines { + log.Println(key + " [ERROR] " + v) + } + } + + return result +} diff --git a/test-docker-compose.yml b/test-docker-compose.yml index 1677829..f6b871a 100644 --- a/test-docker-compose.yml +++ b/test-docker-compose.yml @@ -7,10 +7,18 @@ services: working_dir: /go/src/sql depends_on: - test-mysql + - test-postgres test-mysql: image: mysql:5 volumes: - - ./test_schemas.sql:/docker-entrypoint-initdb.d/test_schemas.sql + - ./mysql_test_schemas.sql:/docker-entrypoint-initdb.d/mysql_test_schemas.sql environment: MYSQL_ROOT_PASSWORD: MYSQL_ALLOW_EMPTY_PASSWORD: "yes" + test-postgres: + image: postgres:9-alpine + volumes: + - ./postgres_test_schemas.sql:/docker-entrypoint-initdb.d/postgres_test_schemas.sql + environment: + POSTGRES_USER: root + POSTGRES_PASSWORD: