From 9e969a03ce753707f2c6045113297e01447ad033 Mon Sep 17 00:00:00 2001 From: testtest959 <68299987+testtest959@users.noreply.github.com> Date: Fri, 19 Mar 2021 00:03:57 -0400 Subject: [PATCH] Postgres: Add a check to determine if table already exists to elide CREATE query (#526) * Squash commits * Format * Minor refactoring * Address PR feedback; Add mustRun * Fix a test assert --- database/pgx/pgx.go | 19 +++- database/pgx/pgx_test.go | 148 +++++++++++++++++++++++++++- database/postgres/postgres.go | 19 +++- database/postgres/postgres_test.go | 151 +++++++++++++++++++++++++++-- 4 files changed, 325 insertions(+), 12 deletions(-) diff --git a/database/pgx/pgx.go b/database/pgx/pgx.go index 8e49638ed..dda70a77c 100644 --- a/database/pgx/pgx.go +++ b/database/pgx/pgx.go @@ -429,7 +429,24 @@ func (p *Postgres) ensureVersionTable() (err error) { } }() - query := `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)` + // This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres + // users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the + // `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission. + // Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258 + var count int + query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` + row := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable) + + err = row.Scan(&count) + if err != nil { + return &database.Error{OrigErr: err, Query: []byte(query)} + } + + if count == 1 { + return nil + } + + query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)` if _, err = p.conn.ExecContext(context.Background(), query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } diff --git a/database/pgx/pgx_test.go b/database/pgx/pgx_test.go index ec58c4aae..a79f9611e 100644 --- a/database/pgx/pgx_test.go +++ b/database/pgx/pgx_test.go @@ -6,6 +6,7 @@ import ( "context" "database/sql" sqldriver "database/sql/driver" + "errors" "fmt" "log" @@ -76,6 +77,14 @@ func isReady(ctx context.Context, c dktest.ContainerInfo) bool { return true } +func mustRun(t *testing.T, d database.Driver, statements []string) { + for _, statement := range statements { + if err := d.Run(strings.NewReader(statement)); err != nil { + t.Fatal(err) + } + } +} + func Test(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { ip, port, err := c.FirstPort() @@ -309,6 +318,141 @@ func TestWithSchema(t *testing.T) { }) } +func TestFailToCreateTableWithoutPermissions(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + addr := pgConnectionString(ip, port) + + // Check that opening the postgres connection returns NilVersion + p := &Postgres{} + + d, err := p.Open(addr) + + if err != nil { + t.Fatal(err) + } + + defer func() { + if err := d.Close(); err != nil { + t.Error(err) + } + }() + + // create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine + // since this is a test environment and we're not expecting to the pgPassword to be malicious + mustRun(t, d, []string{ + "CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'", + "CREATE SCHEMA barfoo AUTHORIZATION postgres", + "GRANT USAGE ON SCHEMA barfoo TO not_owner", + "REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC", + "REVOKE CREATE ON SCHEMA barfoo FROM not_owner", + }) + + // re-connect using that schema + d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo", + pgPassword, ip, port)) + + defer func() { + if d2 == nil { + return + } + if err := d2.Close(); err != nil { + t.Fatal(err) + } + }() + + var e *database.Error + if !errors.As(err, &e) || err == nil { + t.Fatal("Unexpected error, want permission denied error. Got: ", err) + } + + if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") { + t.Fatal(e) + } + }) +} + +func TestCheckBeforeCreateTable(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + addr := pgConnectionString(ip, port) + + // Check that opening the postgres connection returns NilVersion + p := &Postgres{} + + d, err := p.Open(addr) + + if err != nil { + t.Fatal(err) + } + + defer func() { + if err := d.Close(); err != nil { + t.Error(err) + } + }() + + // create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine + // since this is a test environment and we're not expecting to the pgPassword to be malicious + mustRun(t, d, []string{ + "CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'", + "CREATE SCHEMA barfoo AUTHORIZATION postgres", + "GRANT USAGE ON SCHEMA barfoo TO not_owner", + "GRANT CREATE ON SCHEMA barfoo TO not_owner", + }) + + // re-connect using that schema + d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo", + pgPassword, ip, port)) + + if err != nil { + t.Fatal(err) + } + + if err := d2.Close(); err != nil { + t.Fatal(err) + } + + // revoke privileges + mustRun(t, d, []string{ + "REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC", + "REVOKE CREATE ON SCHEMA barfoo FROM not_owner", + }) + + // re-connect using that schema + d3, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo", + pgPassword, ip, port)) + + if err != nil { + t.Fatal(err) + } + + version, _, err := d3.Version() + + if err != nil { + t.Fatal(err) + } + + if version != database.NilVersion { + t.Fatal("Unexpected version, want database.NilVersion. Got: ", version) + } + + defer func() { + if err := d3.Close(); err != nil { + t.Fatal(err) + } + }() + }) +} + func TestParallelSchema(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { ip, port, err := c.FirstPort() @@ -375,10 +519,6 @@ func TestParallelSchema(t *testing.T) { }) } -func TestWithInstance(t *testing.T) { - -} - func TestPostgres_Lock(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { ip, port, err := c.FirstPort() diff --git a/database/postgres/postgres.go b/database/postgres/postgres.go index 0432db223..4926e1d84 100644 --- a/database/postgres/postgres.go +++ b/database/postgres/postgres.go @@ -426,7 +426,24 @@ func (p *Postgres) ensureVersionTable() (err error) { } }() - query := `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)` + // This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres + // users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the + // `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission. + // Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258 + var count int + query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` + row := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable) + + err = row.Scan(&count) + if err != nil { + return &database.Error{OrigErr: err, Query: []byte(query)} + } + + if count == 1 { + return nil + } + + query = `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)` if _, err = p.conn.ExecContext(context.Background(), query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } diff --git a/database/postgres/postgres_test.go b/database/postgres/postgres_test.go index 3d2661f26..c6cc9e7f8 100644 --- a/database/postgres/postgres_test.go +++ b/database/postgres/postgres_test.go @@ -6,11 +6,11 @@ import ( "context" "database/sql" sqldriver "database/sql/driver" + "errors" "fmt" - "log" - "github.com/golang-migrate/migrate/v4" "io" + "log" "strconv" "strings" "sync" @@ -75,6 +75,14 @@ func isReady(ctx context.Context, c dktest.ContainerInfo) bool { return true } +func mustRun(t *testing.T, d database.Driver, statements []string) { + for _, statement := range statements { + if err := d.Run(strings.NewReader(statement)); err != nil { + t.Fatal(err) + } + } +} + func Test(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { ip, port, err := c.FirstPort() @@ -310,6 +318,141 @@ func TestWithSchema(t *testing.T) { }) } +func TestFailToCreateTableWithoutPermissions(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + addr := pgConnectionString(ip, port) + + // Check that opening the postgres connection returns NilVersion + p := &Postgres{} + + d, err := p.Open(addr) + + if err != nil { + t.Fatal(err) + } + + defer func() { + if err := d.Close(); err != nil { + t.Error(err) + } + }() + + // create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine + // since this is a test environment and we're not expecting to the pgPassword to be malicious + mustRun(t, d, []string{ + "CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'", + "CREATE SCHEMA barfoo AUTHORIZATION postgres", + "GRANT USAGE ON SCHEMA barfoo TO not_owner", + "REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC", + "REVOKE CREATE ON SCHEMA barfoo FROM not_owner", + }) + + // re-connect using that schema + d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo", + pgPassword, ip, port)) + + defer func() { + if d2 == nil { + return + } + if err := d2.Close(); err != nil { + t.Fatal(err) + } + }() + + var e *database.Error + if !errors.As(err, &e) || err == nil { + t.Fatal("Unexpected error, want permission denied error. Got: ", err) + } + + if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") { + t.Fatal(e) + } + }) +} + +func TestCheckBeforeCreateTable(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + addr := pgConnectionString(ip, port) + + // Check that opening the postgres connection returns NilVersion + p := &Postgres{} + + d, err := p.Open(addr) + + if err != nil { + t.Fatal(err) + } + + defer func() { + if err := d.Close(); err != nil { + t.Error(err) + } + }() + + // create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine + // since this is a test environment and we're not expecting to the pgPassword to be malicious + mustRun(t, d, []string{ + "CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'", + "CREATE SCHEMA barfoo AUTHORIZATION postgres", + "GRANT USAGE ON SCHEMA barfoo TO not_owner", + "GRANT CREATE ON SCHEMA barfoo TO not_owner", + }) + + // re-connect using that schema + d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo", + pgPassword, ip, port)) + + if err != nil { + t.Fatal(err) + } + + if err := d2.Close(); err != nil { + t.Fatal(err) + } + + // revoke privileges + mustRun(t, d, []string{ + "REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC", + "REVOKE CREATE ON SCHEMA barfoo FROM not_owner", + }) + + // re-connect using that schema + d3, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo", + pgPassword, ip, port)) + + if err != nil { + t.Fatal(err) + } + + version, _, err := d3.Version() + + if err != nil { + t.Fatal(err) + } + + if version != database.NilVersion { + t.Fatal("Unexpected version, want database.NilVersion. Got: ", version) + } + + defer func() { + if err := d3.Close(); err != nil { + t.Fatal(err) + } + }() + }) +} + func TestParallelSchema(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { ip, port, err := c.FirstPort() @@ -378,10 +521,6 @@ func TestParallelSchema(t *testing.T) { }) } -func TestWithInstance(t *testing.T) { - -} - func TestPostgres_Lock(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { ip, port, err := c.FirstPort()