From 3d33d73c009077c5bf30ae4b03802904bfb5d5b2 Mon Sep 17 00:00:00 2001 From: arekkas Date: Tue, 8 May 2018 17:34:59 +0200 Subject: [PATCH] cmd: Retries SQL connection on migrate commands This patch also introduces a fatal error if migrations fail --- authentication/oauth2_introspection_test.go | 4 +- cmd/server/migrate.go | 23 ++++--- cmd/server/sql.go | 72 +++++++++++++-------- 3 files changed, 60 insertions(+), 39 deletions(-) diff --git a/authentication/oauth2_introspection_test.go b/authentication/oauth2_introspection_test.go index 1a261ee79..e282023a1 100644 --- a/authentication/oauth2_introspection_test.go +++ b/authentication/oauth2_introspection_test.go @@ -110,7 +110,7 @@ func TestOAuth2Introspection(t *testing.T) { IssuedAt: now.Unix(), NotBefore: now.Unix(), Username: "username", - Audience: "audience", + Audience: []string{"audience"}, Issuer: "issuer", } }, @@ -126,7 +126,7 @@ func TestOAuth2Introspection(t *testing.T) { IssuedAt: now, NotBefore: now, Username: "username", - Audience: "audience", + Audience: []string{"audience"}, Issuer: "issuer", }, }, diff --git a/cmd/server/migrate.go b/cmd/server/migrate.go index 14227c4f8..5f66dd0b2 100644 --- a/cmd/server/migrate.go +++ b/cmd/server/migrate.go @@ -61,22 +61,25 @@ func getMigrationSql(cmd *cobra.Command, args []string, logger *logrus.Logger) ( func RunMigrateSQL(logger *logrus.Logger) func(cmd *cobra.Command, args []string) { return func(cmd *cobra.Command, args []string) { - dbUrl, u := getMigrationSql(cmd, args, logger) + db, dbu := getMigrationSql(cmd, args, logger) + if dbu.Scheme != "postgres" && dbu.Scheme != "mysql" { + logger.WithField("database_url", dbu.Scheme+"://*:*@"+dbu.Host+dbu.Path+"?"+dbu.RawQuery).Fatal("Migrations can only be run against PostgreSQL or MySQL databases") + } - db, err := connectToSql(dbUrl, u.Scheme) + managers, err := newManagers(db, logger) if err != nil { - logger.WithError(err).WithField("database_url", u.Scheme+"://*:*@"+u.Host+u.Path+"?"+u.RawQuery).Fatal("Unable to parse DATABASE_URL, make sure it has the right format") + logger.WithError(err).WithField("database_url", dbu.Scheme+"://*:*@"+dbu.Host+dbu.Path+"?"+dbu.RawQuery).Fatal("Unable to parse DATABASE_URL, make sure it has the right format") } logger.Info("Applying SQL migrations...") - if n, err := role.NewSQLManager(db).CreateSchemas(); err != nil { - logger.WithError(err).WithField("migrations", n).WithField("table", "policies").Print("An error occurred while trying to apply SQL migrations") + if n, err := managers.roleManager.(*role.SQLManager).CreateSchemas(); err != nil { + logger.WithError(err).WithField("migrations", n).WithField("table", "policies").Fatal("An error occurred while trying to apply SQL migrations") } else { logger.WithField("migrations", n).WithField("table", "role").Print("Successfully applied SQL migrations") } - if n, err := sql.NewSQLManager(db, nil).CreateSchemas("", "keto_policy_migrations"); err != nil { - logger.WithError(err).WithField("migrations", n).WithField("table", "policies").Print("An error occurred while trying to apply SQL migrations") + if n, err := managers.policyManager.(*sql.SQLManager).CreateSchemas("", "keto_policy_migrations"); err != nil { + logger.WithError(err).WithField("migrations", n).WithField("table", "policies").Fatal("An error occurred while trying to apply SQL migrations") } else { logger.WithField("migrations", n).WithField("table", "policies").Print("Successfully applied SQL migrations") } @@ -89,15 +92,15 @@ func RunMigrateHydra(logger *logrus.Logger) func(cmd *cobra.Command, args []stri return func(cmd *cobra.Command, args []string) { dbUrl, u := getMigrationSql(cmd, args, logger) - db, err := connectToSql(dbUrl, u.Scheme) + db, err := connectToSQL(dbUrl, logger) if err != nil { logger.WithError(err).WithField("database_url", u.Scheme+"://*:*@"+u.Host+u.Path+"?"+u.RawQuery).Fatal("Unable to parse DATABASE_URL, make sure it has the right format") } migrate.SetTable("keto_legacy_hydra_migrations") - n, err := migrate.Exec(db.DB, db.DriverName(), legacy.HydraLegacyMigrations[db.DriverName()], migrate.Up) + n, err := migrate.Exec(db.GetDatabase().DB, db.GetDatabase().DriverName(), legacy.HydraLegacyMigrations[db.GetDatabase().DriverName()], migrate.Up) if err != nil { - logger.WithError(err).WithField("migrations", n).Print("An error occurred while trying to apply SQL migrations") + logger.WithError(err).WithField("migrations", n).Fatal("An error occurred while trying to apply SQL migrations") } logger.WithField("migrations", n).Print("Successfully applied SQL migrations") logger.Info("Done applying SQL migrations") diff --git a/cmd/server/sql.go b/cmd/server/sql.go index be2ca6787..d0cb6a33e 100644 --- a/cmd/server/sql.go +++ b/cmd/server/sql.go @@ -23,11 +23,9 @@ package server import ( "net/url" - "runtime" "time" _ "github.com/go-sql-driver/mysql" - "github.com/jmoiron/sqlx" _ "github.com/lib/pq" "github.com/ory/keto/role" "github.com/ory/ladon" @@ -38,30 +36,6 @@ import ( "github.com/sirupsen/logrus" ) -func connectToSql(url string, dbt string) (*sqlx.DB, error) { - db, err := sqlx.Open(dbt, url) - if err != nil { - return nil, errors.WithStack(err) - } - - maxConns := maxParallelism() * 2 - maxConnLifetime := time.Duration(0) - maxIdleConns := maxParallelism() - db.SetMaxOpenConns(maxConns) - db.SetMaxIdleConns(maxIdleConns) - db.SetConnMaxLifetime(maxConnLifetime) - return db, nil -} - -func maxParallelism() int { - maxProcs := runtime.GOMAXPROCS(0) - numCPU := runtime.NumCPU() - if maxProcs < numCPU { - return maxProcs - } - return numCPU -} - type managers struct { roleManager role.Manager policyManager ladon.Manager @@ -86,7 +60,7 @@ func newManagers(db string, logger logrus.FieldLogger) (*managers, error) { case "postgres": fallthrough case "mysql": - sdb, err := sqlcon.NewSQLConnection(db, logger) + sdb, err := connectToSQL(db, logger) if err != nil { return nil, errors.WithStack(err) } @@ -99,3 +73,47 @@ func newManagers(db string, logger logrus.FieldLogger) (*managers, error) { return nil, errors.Errorf("The provided database URL %s can not be handled", db) } + +func retry(logger logrus.FieldLogger, maxWait time.Duration, failAfter time.Duration, f func() error) (err error) { + var lastStart time.Time + err = errors.New("Did not connect.") + loopWait := time.Millisecond * 100 + retryStart := time.Now().UTC() + for retryStart.Add(failAfter).After(time.Now().UTC()) { + lastStart = time.Now().UTC() + if err = f(); err == nil { + return nil + } + + if lastStart.Add(maxWait * 2).Before(time.Now().UTC()) { + retryStart = time.Now().UTC() + } + + logger.WithError(err).Infof("Retrying in %f seconds...", loopWait.Seconds()) + time.Sleep(loopWait) + loopWait = loopWait * time.Duration(int64(2)) + if loopWait > maxWait { + loopWait = maxWait + } + } + return err +} + +func connectToSQL(db string, logger logrus.FieldLogger) (sdb *sqlcon.SQLConnection, err error) { + if err := retry(logger, time.Minute, time.Minute*15, func() error { + var err error + sdb, err = sqlcon.NewSQLConnection(db, logger) + if err != nil { + return errors.WithStack(err) + } + + if err := sdb.GetDatabase().Ping(); err != nil { + return errors.WithStack(err) + } + return nil + }); err != nil { + return nil, errors.WithStack(err) + } + + return sdb, nil +}