From f62164474106bb096b2f694bd1d89ce93df3e744 Mon Sep 17 00:00:00 2001 From: Andreas Kluth Date: Mon, 29 Nov 2021 15:50:04 +0100 Subject: [PATCH] Add WithConnection to Postgres similar to MySQL. --- database/postgres/postgres.go | 46 +++++++++++++++++++++--------- database/postgres/postgres_test.go | 38 ++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 14 deletions(-) diff --git a/database/postgres/postgres.go b/database/postgres/postgres.go index 82919171d..59cb569de 100644 --- a/database/postgres/postgres.go +++ b/database/postgres/postgres.go @@ -7,7 +7,6 @@ import ( "context" "database/sql" "fmt" - "go.uber.org/atomic" "io" "io/ioutil" nurl "net/url" @@ -16,10 +15,12 @@ import ( "strings" "time" + "go.uber.org/atomic" + "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database" "github.com/golang-migrate/migrate/v4/database/multistmt" - multierror "github.com/hashicorp/go-multierror" + "github.com/hashicorp/go-multierror" "github.com/lib/pq" ) @@ -65,19 +66,19 @@ type Postgres struct { config *Config } -func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { +func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Postgres, error) { if config == nil { return nil, ErrNilConfig } - if err := instance.Ping(); err != nil { + if err := conn.PingContext(ctx); err != nil { return nil, err } if config.DatabaseName == "" { query := `SELECT CURRENT_DATABASE()` var databaseName string - if err := instance.QueryRow(query).Scan(&databaseName); err != nil { + if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil { return nil, &database.Error{OrigErr: err, Query: []byte(query)} } @@ -91,7 +92,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config.SchemaName == "" { query := `SELECT CURRENT_SCHEMA()` var schemaName string - if err := instance.QueryRow(query).Scan(&schemaName); err != nil { + if err := conn.QueryRowContext(ctx, query).Scan(&schemaName); err != nil { return nil, &database.Error{OrigErr: err, Query: []byte(query)} } @@ -119,15 +120,8 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { } } - conn, err := instance.Conn(context.Background()) - - if err != nil { - return nil, err - } - px := &Postgres{ conn: conn, - db: instance, config: config, } @@ -138,6 +132,26 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { return px, nil } +func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { + ctx := context.Background() + + if err := instance.Ping(); err != nil { + return nil, err + } + + conn, err := instance.Conn(ctx) + if err != nil { + return nil, err + } + + px, err := WithConnection(ctx, conn, config) + if err != nil { + return nil, err + } + px.db = instance + return px, nil +} + func (p *Postgres) Open(url string) (database.Driver, error) { purl, err := nurl.Parse(url) if err != nil { @@ -207,7 +221,11 @@ func (p *Postgres) Open(url string) (database.Driver, error) { func (p *Postgres) Close() error { connErr := p.conn.Close() - dbErr := p.db.Close() + var dbErr error + if p.db != nil { + dbErr = p.db.Close() + } + if connErr != nil || dbErr != nil { return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) } diff --git a/database/postgres/postgres_test.go b/database/postgres/postgres_test.go index c4ed3560f..65395cc7e 100644 --- a/database/postgres/postgres_test.go +++ b/database/postgres/postgres_test.go @@ -684,6 +684,44 @@ func TestWithInstance_Concurrent(t *testing.T) { } }) } + +func TestWithConnection(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + db, err := sql.Open("postgres", pgConnectionString(ip, port)) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := db.Close(); err != nil { + t.Error(err) + } + }() + + ctx := context.Background() + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + + p, err := WithConnection(ctx, conn, &Config{}) + if err != nil { + t.Fatal(err) + } + + defer func() { + if err := p.Close(); err != nil { + t.Error(err) + } + }() + dt.Test(t, p, []byte("SELECT 1")) + }) +} + func Test_computeLineFromPos(t *testing.T) { testcases := []struct { pos int