From ad607653c13eb76ba3be5b17a5203a08b2ea11af Mon Sep 17 00:00:00 2001 From: Nikhil Saraf <1028334+nikhilsaraf@users.noreply.github.com> Date: Sat, 11 Jul 2020 00:31:57 +0530 Subject: [PATCH] Database Schema Test Infrastructure also tests indexes on tables --- cmd/trade_test.go | 15 ++++++++++- support/database/upgrade_test.go | 5 ++++ support/database/upgrade_test_helper.go | 33 +++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 1 deletion(-) diff --git a/cmd/trade_test.go b/cmd/trade_test.go index e6ada40a1..cad613c89 100644 --- a/cmd/trade_test.go +++ b/cmd/trade_test.go @@ -69,6 +69,10 @@ func TestTradeUpgradeScripts(t *testing.T) { DataType: "text", CharacterMaximumLength: nil, }, &columns[4]) + // check indexes of db_version table + indexes := database.GetTableIndexes(db, "db_version") + assert.Equal(t, 1, len(indexes)) + database.AssertIndex(t, "db_version", "db_version_pkey", "CREATE UNIQUE INDEX db_version_pkey ON public.db_version USING btree (version)", indexes) // check schema of markets table columns = database.GetTableSchema(db, "markets") @@ -105,6 +109,10 @@ func TestTradeUpgradeScripts(t *testing.T) { DataType: "text", CharacterMaximumLength: nil, }, &columns[3]) + // check indexes of markets table + indexes = database.GetTableIndexes(db, "markets") + assert.Equal(t, 1, len(indexes)) + database.AssertIndex(t, "markets", "markets_pkey", "CREATE UNIQUE INDEX markets_pkey ON public.markets USING btree (market_id)", indexes) // check schema of trades table columns = database.GetTableSchema(db, "trades") @@ -181,6 +189,11 @@ func TestTradeUpgradeScripts(t *testing.T) { DataType: "double precision", CharacterMaximumLength: nil, }, &columns[8]) + // check indexes of trades table + indexes = database.GetTableIndexes(db, "trades") + assert.Equal(t, 2, len(indexes)) + database.AssertIndex(t, "trades", "trades_pkey", "CREATE UNIQUE INDEX trades_pkey ON public.trades USING btree (market_id, txid)", indexes) + database.AssertIndex(t, "trades", "trades_mdd", "CREATE INDEX trades_mdd ON public.trades USING btree (market_id, date(date_utc), date_utc)", indexes) // check entries of db_version table var allRows [][]interface{} @@ -189,7 +202,7 @@ func TestTradeUpgradeScripts(t *testing.T) { // first three code_version_string is nil becuase the field was not supported at the time when the upgrade script was run, and only in version 4 of // the database do we add the field. See upgradeScripts and RunUpgradeScripts() for more details database.ValidateDBVersionRow(t, allRows[0], 1, time.Now(), 1, 10, nil) - database.ValidateDBVersionRow(t, allRows[1], 2, time.Now(), 3, 10, nil) + database.ValidateDBVersionRow(t, allRows[1], 2, time.Now(), 3, 15, nil) database.ValidateDBVersionRow(t, allRows[2], 3, time.Now(), 2, 10, nil) database.ValidateDBVersionRow(t, allRows[3], 4, time.Now(), 1, 10, &codeVersionString) diff --git a/support/database/upgrade_test.go b/support/database/upgrade_test.go index b028a688d..7b715ec03 100644 --- a/support/database/upgrade_test.go +++ b/support/database/upgrade_test.go @@ -81,6 +81,11 @@ func TestUpgradeScripts(t *testing.T) { CharacterMaximumLength: nil, }, &columns[4]) + // check indexes of db_version table + indexes := GetTableIndexes(db, "db_version") + assert.Equal(t, 1, len(indexes)) + AssertIndex(t, "db_version", "db_version_pkey", "CREATE UNIQUE INDEX db_version_pkey ON public.db_version USING btree (version)", indexes) + // check entries of db_version table allRows := QueryAllRows(db, "db_version") assert.Equal(t, 2, len(allRows)) diff --git a/support/database/upgrade_test_helper.go b/support/database/upgrade_test_helper.go index 2ddb9f321..852c3b592 100644 --- a/support/database/upgrade_test_helper.go +++ b/support/database/upgrade_test_helper.go @@ -133,6 +133,39 @@ func GetTableSchema(db *sql.DB, tableName string) []TableColumn { return items } +// IndexSearchResult captures the result from GetTableIndexes() and is used as input to AssertIndex() +type IndexSearchResult map[string]string + +// GetTableIndexes is well-named +func GetTableIndexes(db *sql.DB, tableName string) IndexSearchResult { + indexQueryResult, e := db.Query(fmt.Sprintf("SELECT indexname, indexdef from pg_indexes where schemaname = 'public' AND tablename = '%s'", tableName)) + if e != nil { + panic(e) + } + defer indexQueryResult.Close() // remembering to defer closing the query + + m := map[string]string{} + for indexQueryResult.Next() { // remembering to call Next() before Scan() + var name, def string + e = indexQueryResult.Scan(&name, &def) + if e != nil { + panic(e) + } + + m[name] = def + } + + return m +} + +// AssertIndex validates that the index exists +func AssertIndex(t *testing.T, tableName string, wantIndexName string, wantDefinition string, indexes IndexSearchResult) { + m := map[string]string(indexes) + if v, ok := m[wantIndexName]; assert.True(t, ok, fmt.Sprintf("index '%s' should exist in the table '%s'", wantIndexName, tableName)) { + assert.Equal(t, wantDefinition, v) + } +} + // QueryAllRows queries all the rows of a given table in a database func QueryAllRows(db *sql.DB, tableName string) [][]interface{} { queryResult, e := db.Query(fmt.Sprintf("SELECT * FROM %s", tableName))