diff --git a/storage/postgres/client.go b/storage/postgres/client.go index c84ced993..5c2362edc 100644 --- a/storage/postgres/client.go +++ b/storage/postgres/client.go @@ -237,21 +237,7 @@ func (c *Client) listNexusTables(ctx context.Context) ([]string, error) { return tables, nil } -// Wipe removes all contents of the database. -func (c *Client) Wipe(ctx context.Context) error { - tables, err := c.listNexusTables(ctx) - if err != nil { - return err - } - for _, table := range tables { - c.logger.Info("dropping table", "table", table) - if _, err = c.pool.Exec(ctx, fmt.Sprintf("DROP TABLE %s CASCADE;", table)); err != nil { - return err - } - } - - // List, then drop all custom types. - // Query from https://stackoverflow.com/questions/3660787/how-to-list-custom-types-using-postgres-information-schema +func (c *Client) listNexusTypes(ctx context.Context) ([]string, error) { rows, err := c.Query(ctx, ` SELECT n.nspname as schema, t.typname as type FROM pg_type t @@ -261,59 +247,109 @@ func (c *Client) Wipe(ctx context.Context) error { AND n.nspname != 'information_schema' AND n.nspname NOT LIKE 'pg_%'; `) if err != nil { - return fmt.Errorf("failed to list types: %w", err) + return nil, fmt.Errorf("list types: %w", err) } + + types := []string{} defer rows.Close() // Ensure rows is closed even if we return early. for rows.Next() { var schema, typ string if err = rows.Scan(&schema, &typ); err != nil { - return err - } - c.logger.Info("dropping type", "schema", schema, "type", typ) - if _, err = c.pool.Exec(ctx, fmt.Sprintf("DROP TYPE %s.%s CASCADE;", schema, typ)); err != nil { - return err + return nil, err } + types = append(types, fmt.Sprintf("%s.%s", schema, typ)) } + return types, nil +} - // List, then drop all custom functions. - rows, err = c.Query(ctx, ` +func (c *Client) listNexusFunctions(ctx context.Context) ([]string, error) { + rows, err := c.Query(ctx, ` SELECT n.nspname as schema, p.proname as function FROM pg_proc p LEFT JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace WHERE n.nspname NOT IN ('pg_catalog', 'information_schema'); `) if err != nil { - return fmt.Errorf("failed to list functions: %w", err) + return nil, fmt.Errorf("failed to list functions: %w", err) } + + functions := []string{} defer rows.Close() // Ensure rows is closed even if we return early. for rows.Next() { var schema, fn string if err = rows.Scan(&schema, &fn); err != nil { - return err - } - c.logger.Info("dropping function", "schema", schema, "function", fn) - if _, err = c.pool.Exec(ctx, fmt.Sprintf("DROP FUNCTION %s.%s CASCADE;", schema, fn)); err != nil { - return err + return nil, err } + functions = append(functions, fmt.Sprintf("%s.%s", schema, fn)) } + return functions, nil +} - // List, then drop all materialized views. - rows, err = c.Query(ctx, ` +func (c *Client) listNexusMaterializedViews(ctx context.Context) ([]string, error) { + rows, err := c.Query(ctx, ` SELECT schemaname, matviewname FROM pg_matviews WHERE schemaname != 'information_schema' AND schemaname NOT LIKE 'pg_%' `) if err != nil { - return fmt.Errorf("failed to list materialized views: %w", err) + return nil, fmt.Errorf("failed to list materialized views: %w", err) } + + materializedViews := []string{} defer rows.Close() // Ensure rows is closed even if we return early. for rows.Next() { var schema, view string if err = rows.Scan(&schema, &view); err != nil { + return nil, err + } + materializedViews = append(materializedViews, fmt.Sprintf("%s.%s", schema, view)) + } + return materializedViews, nil +} + +// Wipe removes all contents of the database. +func (c *Client) Wipe(ctx context.Context) error { + tables, err := c.listNexusTables(ctx) + if err != nil { + return err + } + for _, table := range tables { + c.logger.Info("dropping table", "table", table) + if _, err = c.pool.Exec(ctx, fmt.Sprintf("DROP TABLE %s CASCADE;", table)); err != nil { return err } - c.logger.Info("dropping materialized view", "schema", schema, "view", view) - if _, err = c.pool.Exec(ctx, fmt.Sprintf("DROP MATERIALIZED VIEW %s.%s CASCADE;", schema, view)); err != nil { + } + + // List, then drop all custom types. + // Query from https://stackoverflow.com/questions/3660787/how-to-list-custom-types-using-postgres-information-schema + types, err := c.listNexusTypes(ctx) + for _, typ := range types { + c.logger.Info("dropping type", "type", typ) + if _, err = c.pool.Exec(ctx, fmt.Sprintf("DROP TYPE %s CASCADE;", typ)); err != nil { + return err + } + } + + // List, then drop all custom functions. + functions, err := c.listNexusFunctions(ctx) + if err != nil { + return err + } + for _, fn := range functions { + c.logger.Info("dropping function", "function", fn) + if _, err = c.pool.Exec(ctx, fmt.Sprintf("DROP FUNCTION %s CASCADE;", fn)); err != nil { + return err + } + } + + // List, then drop all materialized views. + materializedViews, err := c.listNexusMaterializedViews(ctx) + if err != nil { + return err + } + for _, view := range materializedViews { + c.logger.Info("dropping materialized view", "view", view) + if _, err = c.pool.Exec(ctx, fmt.Sprintf("DROP MATERIALIZED VIEW %s CASCADE;", view)); err != nil { return err } }