diff --git a/main.go b/main.go index 710be69..2eba9e8 100644 --- a/main.go +++ b/main.go @@ -75,28 +75,13 @@ func _main(databases map[string]database, databasesArgs []string, query string, targetDatabases = append(targetDatabases, k) } - sqlTypes := map[sqlType]exists{} - var sqlType sqlType - for _, db := range targetDatabases { - database := databases[db] - typ, ok := validSQLTypes[database.SQLType] - if !ok { - usage("Unknown sql type %v for %v", database.SQLType, db) - } - sqlType = typ - sqlTypes[sqlType] = exists{} - if len(sqlTypes) > 1 { - usage("More than one sql types specified in target databases.") - } - } - quitContext, cancel := context.WithCancel(context.Background()) go awaitSignal(cancel) var wg sync.WaitGroup wg.Add(len(targetDatabases)) - sqlRunner := mustNewSQLRunner(quitContext, sqlType, println, query, len(targetDatabases) > 1) + sqlRunner := mustNewSQLRunner(quitContext, println, query, len(targetDatabases) > 1) returnCode := 0 for _, k := range targetDatabases { diff --git a/sql_runner.go b/sql_runner.go index ddeef73..c97cb38 100644 --- a/sql_runner.go +++ b/sql_runner.go @@ -17,6 +17,10 @@ const ( postgreSQL ) +func (t sqlType) String() string { + return [...]string{"MySQL", "PostgreSQL"}[t] +} + type exists struct{} type sqlOptions struct { @@ -54,16 +58,14 @@ var sqlTypeToOptions = map[sqlType]sqlOptions{ } type sqlRunner struct { - typ sqlType printer func(string) query string quitContext context.Context multi bool } -func mustNewSQLRunner(quitContext context.Context, typ sqlType, printer func(string), query string, multi bool) *sqlRunner { +func mustNewSQLRunner(quitContext context.Context, printer func(string), query string, multi bool) *sqlRunner { return &sqlRunner{ - typ, printer, query, quitContext, @@ -72,7 +74,12 @@ func mustNewSQLRunner(quitContext context.Context, typ sqlType, printer func(str } func (sr *sqlRunner) runSQL(db database, key string) bool { - sqlOptions := sqlTypeToOptions[sr.typ] + typ, ok := validSQLTypes[db.SQLType] + if !ok { + usage("Unknown sql type %v for %v", db.SQLType, db) + } + + sqlOptions := sqlTypeToOptions[typ] userOption := "" if db.User != "" { @@ -100,7 +107,7 @@ func (sr *sqlRunner) runSQL(db database, key string) bool { } options := "" - if sr.typ == postgreSQL { + if typ == postgreSQL { options = fmt.Sprintf("%v %v %v %v %v", sqlOptions.cmd, userOption, hostOption, dbOption, sqlOptions.flags) } else { options = fmt.Sprintf("%v %v %v %v %v %v", sqlOptions.cmd, dbOption, userOption, passOption, hostOption, sqlOptions.flags) @@ -109,7 +116,7 @@ func (sr *sqlRunner) runSQL(db database, key string) bool { var cmd *exec.Cmd if db.AppServer != "" { escapedQuery := fmt.Sprintf(`'%v'`, strings.Replace(sr.query, `'`, `'"'"'`, -1)) - if sr.typ == postgreSQL { + if typ == postgreSQL { escapedQuery += fmt.Sprintf("-F%s", "\t") } @@ -117,13 +124,13 @@ func (sr *sqlRunner) runSQL(db database, key string) bool { } else { args := append(trimEmpty(strings.Split(options, " ")), sr.query) - if sr.typ == postgreSQL { + if typ == postgreSQL { args = append(args, fmt.Sprintf("-F%s", "\t")) } cmd = exec.CommandContext(sr.quitContext, args[0], args[1:]...) } - if sr.typ == postgreSQL { + if typ == postgreSQL { cmd.Env = os.Environ() cmd.Env = append(cmd.Env, passOption) }