Skip to content
This repository has been archived by the owner on Feb 1, 2024. It is now read-only.

Commit

Permalink
Database Schema Test Infrastructure also tests indexes on tables
Browse files Browse the repository at this point in the history
  • Loading branch information
nikhilsaraf committed Jul 10, 2020
1 parent fea5021 commit ad60765
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 1 deletion.
15 changes: 14 additions & 1 deletion cmd/trade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ func TestTradeUpgradeScripts(t *testing.T) {
DataType: "text",
CharacterMaximumLength: nil,
}, &columns[4])
// check indexes of db_version table
indexes := database.GetTableIndexes(db, "db_version")
assert.Equal(t, 1, len(indexes))
database.AssertIndex(t, "db_version", "db_version_pkey", "CREATE UNIQUE INDEX db_version_pkey ON public.db_version USING btree (version)", indexes)

// check schema of markets table
columns = database.GetTableSchema(db, "markets")
Expand Down Expand Up @@ -105,6 +109,10 @@ func TestTradeUpgradeScripts(t *testing.T) {
DataType: "text",
CharacterMaximumLength: nil,
}, &columns[3])
// check indexes of markets table
indexes = database.GetTableIndexes(db, "markets")
assert.Equal(t, 1, len(indexes))
database.AssertIndex(t, "markets", "markets_pkey", "CREATE UNIQUE INDEX markets_pkey ON public.markets USING btree (market_id)", indexes)

// check schema of trades table
columns = database.GetTableSchema(db, "trades")
Expand Down Expand Up @@ -181,6 +189,11 @@ func TestTradeUpgradeScripts(t *testing.T) {
DataType: "double precision",
CharacterMaximumLength: nil,
}, &columns[8])
// check indexes of trades table
indexes = database.GetTableIndexes(db, "trades")
assert.Equal(t, 2, len(indexes))
database.AssertIndex(t, "trades", "trades_pkey", "CREATE UNIQUE INDEX trades_pkey ON public.trades USING btree (market_id, txid)", indexes)
database.AssertIndex(t, "trades", "trades_mdd", "CREATE INDEX trades_mdd ON public.trades USING btree (market_id, date(date_utc), date_utc)", indexes)

// check entries of db_version table
var allRows [][]interface{}
Expand All @@ -189,7 +202,7 @@ func TestTradeUpgradeScripts(t *testing.T) {
// first three code_version_string is nil becuase the field was not supported at the time when the upgrade script was run, and only in version 4 of
// the database do we add the field. See upgradeScripts and RunUpgradeScripts() for more details
database.ValidateDBVersionRow(t, allRows[0], 1, time.Now(), 1, 10, nil)
database.ValidateDBVersionRow(t, allRows[1], 2, time.Now(), 3, 10, nil)
database.ValidateDBVersionRow(t, allRows[1], 2, time.Now(), 3, 15, nil)
database.ValidateDBVersionRow(t, allRows[2], 3, time.Now(), 2, 10, nil)
database.ValidateDBVersionRow(t, allRows[3], 4, time.Now(), 1, 10, &codeVersionString)

Expand Down
5 changes: 5 additions & 0 deletions support/database/upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ func TestUpgradeScripts(t *testing.T) {
CharacterMaximumLength: nil,
}, &columns[4])

// check indexes of db_version table
indexes := GetTableIndexes(db, "db_version")
assert.Equal(t, 1, len(indexes))
AssertIndex(t, "db_version", "db_version_pkey", "CREATE UNIQUE INDEX db_version_pkey ON public.db_version USING btree (version)", indexes)

// check entries of db_version table
allRows := QueryAllRows(db, "db_version")
assert.Equal(t, 2, len(allRows))
Expand Down
33 changes: 33 additions & 0 deletions support/database/upgrade_test_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,39 @@ func GetTableSchema(db *sql.DB, tableName string) []TableColumn {
return items
}

// IndexSearchResult captures the result from GetTableIndexes() and is used as input to AssertIndex()
type IndexSearchResult map[string]string

// GetTableIndexes is well-named
func GetTableIndexes(db *sql.DB, tableName string) IndexSearchResult {
indexQueryResult, e := db.Query(fmt.Sprintf("SELECT indexname, indexdef from pg_indexes where schemaname = 'public' AND tablename = '%s'", tableName))
if e != nil {
panic(e)
}
defer indexQueryResult.Close() // remembering to defer closing the query

m := map[string]string{}
for indexQueryResult.Next() { // remembering to call Next() before Scan()
var name, def string
e = indexQueryResult.Scan(&name, &def)
if e != nil {
panic(e)
}

m[name] = def
}

return m
}

// AssertIndex validates that the index exists
func AssertIndex(t *testing.T, tableName string, wantIndexName string, wantDefinition string, indexes IndexSearchResult) {
m := map[string]string(indexes)
if v, ok := m[wantIndexName]; assert.True(t, ok, fmt.Sprintf("index '%s' should exist in the table '%s'", wantIndexName, tableName)) {
assert.Equal(t, wantDefinition, v)
}
}

// QueryAllRows queries all the rows of a given table in a database
func QueryAllRows(db *sql.DB, tableName string) [][]interface{} {
queryResult, e := db.Query(fmt.Sprintf("SELECT * FROM %s", tableName))
Expand Down

0 comments on commit ad60765

Please sign in to comment.