Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use PGX for Database Conections to Lookout #915

Merged
merged 5 commits into from
Apr 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ require (
github.com/grpc-ecosystem/grpc-gateway v1.16.0
github.com/hashicorp/golang-lru v0.5.3 // indirect
github.com/instrumenta/kubeval v0.0.0-20190918223246-8d013ec9fc56
github.com/jackc/pgx/v4 v4.15.0
github.com/jcmturner/gokrb5/v8 v8.4.2-0.20201112171129-78f56934d598
github.com/lib/pq v1.10.4
github.com/mitchellh/go-homedir v1.1.0
github.com/mitchellh/mapstructure v1.4.2
github.com/nats-io/jsm.go v0.0.26
Expand Down
89 changes: 89 additions & 0 deletions go.sum

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion internal/lookout/application.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (

"github.com/doug-martin/goqu/v9"
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
_ "github.com/lib/pq"
"github.com/nats-io/jsm.go"
"github.com/nats-io/stan.go"
log "github.com/sirupsen/logrus"
Expand Down
4 changes: 2 additions & 2 deletions internal/lookout/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ import (
"database/sql"
"strings"

_ "github.com/lib/pq"
_ "github.com/jackc/pgx/v4/stdlib"

"github.com/G-Research/armada/internal/lookout/configuration"
)

func Open(config configuration.PostgresConfig) (*sql.DB, error) {
db, err := sql.Open("postgres", createConnectionString(config.Connection))
db, err := sql.Open("pgx", createConnectionString(config.Connection))
if err != nil {
return nil, err
}
Expand Down
33 changes: 16 additions & 17 deletions internal/lookout/repository/job_sets.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"github.com/doug-martin/goqu/v9"
"github.com/doug-martin/goqu/v9/exp"
"github.com/gogo/protobuf/types"
"github.com/lib/pq"

"github.com/G-Research/armada/pkg/api/lookout"
)
Expand All @@ -20,21 +19,21 @@ type jobSetCountsRow struct {
Running sql.NullInt64 `db:"running"`
Succeeded sql.NullInt64 `db:"succeeded"`
Failed sql.NullInt64 `db:"failed"`
Submitted pq.NullTime `db:"submitted"`

RunningStatsMin pq.NullTime `db:"running_min"`
RunningStatsMax pq.NullTime `db:"running_max"`
RunningStatsAverage pq.NullTime `db:"running_average"`
RunningStatsMedian pq.NullTime `db:"running_median"`
RunningStatsQ1 pq.NullTime `db:"running_q1"`
RunningStatsQ3 pq.NullTime `db:"running_q3"`

QueuedStatsMin pq.NullTime `db:"queued_min"`
QueuedStatsMax pq.NullTime `db:"queued_max"`
QueuedStatsAverage pq.NullTime `db:"queued_average"`
QueuedStatsMedian pq.NullTime `db:"queued_median"`
QueuedStatsQ1 pq.NullTime `db:"queued_q1"`
QueuedStatsQ3 pq.NullTime `db:"queued_q3"`
Submitted sql.NullTime `db:"submitted"`

RunningStatsMin sql.NullTime `db:"running_min"`
RunningStatsMax sql.NullTime `db:"running_max"`
RunningStatsAverage sql.NullTime `db:"running_average"`
RunningStatsMedian sql.NullTime `db:"running_median"`
RunningStatsQ1 sql.NullTime `db:"running_q1"`
RunningStatsQ3 sql.NullTime `db:"running_q3"`

QueuedStatsMin sql.NullTime `db:"queued_min"`
QueuedStatsMax sql.NullTime `db:"queued_max"`
QueuedStatsAverage sql.NullTime `db:"queued_average"`
QueuedStatsMedian sql.NullTime `db:"queued_median"`
QueuedStatsQ1 sql.NullTime `db:"queued_q1"`
QueuedStatsQ3 sql.NullTime `db:"queued_q3"`
}

func (r *SQLJobRepository) GetJobSetInfos(ctx context.Context, opts *lookout.GetJobSetsRequest) ([]*lookout.JobSetInfo, error) {
Expand Down Expand Up @@ -216,7 +215,7 @@ func (r *SQLJobRepository) rowsToJobSets(rows []*jobSetCountsRow, queue string)
return jobSetInfos
}

func getProtoDuration(currentTime time.Time, maybeTime pq.NullTime) *types.Duration {
func getProtoDuration(currentTime time.Time, maybeTime sql.NullTime) *types.Duration {
var duration *types.Duration
if maybeTime.Valid {
duration = types.DurationProto(currentTime.Sub(maybeTime.Time))
Expand Down
92 changes: 58 additions & 34 deletions internal/lookout/repository/queues.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@ package repository
import (
"context"
"database/sql"
"fmt"
"sort"
"strings"
"time"

"github.com/doug-martin/goqu/v9"
Expand All @@ -21,32 +19,68 @@ type countsRow struct {
Running uint32 `db:"running"`
}

type rowsSql struct {
Counts string
OldestQueued string
LongestRunning string
}

func (r *SQLJobRepository) GetQueueInfos(ctx context.Context) ([]*lookout.QueueInfo, error) {
queries, err := r.getQueuesSql()
if err != nil {
return nil, err
}

rows, err := r.goquDb.Db.QueryContext(ctx, queries)
if err != nil {
return nil, err
}
var countRows *sql.Rows
var longestRunningRows *sql.Rows
var oldestQueuedRows *sql.Rows

defer func() {
err := rows.Close()
if err != nil {
logrus.Fatalf("Failed to close SQL connection: %v", err)
if countRows != nil {
err := countRows.Close()
if err != nil {
logrus.Errorf("Failed to close SQL connection: %v", err)
}
}

if longestRunningRows != nil {
err := longestRunningRows.Close()
if err != nil {
logrus.Errorf("Failed to close SQL connection: %v", err)
}
}

if oldestQueuedRows != nil {
err := oldestQueuedRows.Close()
if err != nil {
logrus.Errorf("Failed to close SQL connection: %v", err)
}
}
}()
countRows, err = r.goquDb.Db.QueryContext(ctx, queries.Counts)
if err != nil {
return nil, err
}

result, err := r.rowsToQueues(rows)
longestRunningRows, err = r.goquDb.Db.QueryContext(ctx, queries.LongestRunning)
if err != nil {
return nil, err
}

oldestQueuedRows, err = r.goquDb.Db.QueryContext(ctx, queries.OldestQueued)
if err != nil {
return nil, err
}

result, err := r.rowsToQueues(countRows, oldestQueuedRows, longestRunningRows)
if err != nil {
return nil, err
}

return result, nil
}

func (r *SQLJobRepository) getQueuesSql() (string, error) {
func (r *SQLJobRepository) getQueuesSql() (rowsSql, error) {
countsDs := r.goquDb.
From(jobTable).
Select(
Expand Down Expand Up @@ -111,53 +145,43 @@ func (r *SQLJobRepository) getQueuesSql() (string, error) {
jobRun_created,
jobRun_started,
jobRun_finished).
Order(jobRun_runId.Asc(), jobRun_started.Asc()).
As("longest_running")

countsSql, _, err := countsDs.ToSQL()
if err != nil {
return "", err
return rowsSql{}, err
}
oldestQueuedSql, _, err := oldestQueuedDs.ToSQL()
if err != nil {
return "", err
return rowsSql{}, err
}
longestRunningSql, _, err := longestRunningDs.ToSQL()
if err != nil {
return "", err
return rowsSql{}, err
}

// Execute three unprepared statements sequentially.
// There are no parameters and we don't care if updates happen between queries.
return strings.Join([]string{countsSql, oldestQueuedSql, longestRunningSql}, " ; "), nil
return rowsSql{countsSql, oldestQueuedSql, longestRunningSql}, nil
}

func (r *SQLJobRepository) rowsToQueues(rows *sql.Rows) ([]*lookout.QueueInfo, error) {
func (r *SQLJobRepository) rowsToQueues(counts *sql.Rows, oldestQueued *sql.Rows, longestRunning *sql.Rows) ([]*lookout.QueueInfo, error) {
queueInfoMap := make(map[string]*lookout.QueueInfo)

// Job counts
err := setJobCounts(rows, queueInfoMap)
err := setJobCounts(counts, queueInfoMap)
if err != nil {
return nil, err
}

// Oldest queued
if rows.NextResultSet() {
err = r.setOldestQueuedJob(rows, queueInfoMap)
if err != nil {
return nil, err
}
} else if rows.Err() != nil {
return nil, fmt.Errorf("expected result set for oldest queued job: %v", rows.Err())
err = r.setOldestQueuedJob(oldestQueued, queueInfoMap)
if err != nil {
return nil, err
}

// Longest Running
if rows.NextResultSet() {
err = r.setLongestRunningJob(rows, queueInfoMap)
if err != nil {
return nil, err
}
} else if rows.Err() != nil {
return nil, fmt.Errorf("expected result set for longest Running job: %v", rows.Err())
err = r.setLongestRunningJob(longestRunning, queueInfoMap)
if err != nil {
return nil, err
}

result := getSortedQueueInfos(queueInfoMap)
Expand Down
15 changes: 8 additions & 7 deletions internal/lookout/repository/schema/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,17 @@ type migration struct {
func UpdateDatabase(db *sql.DB) error {
log.Info("Updating database...")
version, err := readVersion(db)
log.Infof("Current version %v", version)

if err != nil {
return err
}

log.Infof("Current version %v", version)
migrations, err := getMigrations()
if err != nil {
return err
}

for _, m := range migrations {
if m.id > version {
log.Infof("Migration %v", m.name)

_, err := db.Exec(m.sql)
if err != nil {
return err
Expand All @@ -54,9 +50,14 @@ func UpdateDatabase(db *sql.DB) error {
}

func readVersion(db *sql.DB) (int, error) {
_, err := db.Exec(
`CREATE SEQUENCE IF NOT EXISTS database_version START WITH 0 MINVALUE 0;`)
if err != nil {
return 0, err
}

result, err := db.Query(
`CREATE SEQUENCE IF NOT EXISTS database_version START WITH 0 MINVALUE 0;
SELECT last_value FROM database_version`)
`SELECT last_value FROM database_version`)
if err != nil {
return 0, err
}
Expand Down
11 changes: 5 additions & 6 deletions internal/lookout/repository/sql_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (

"github.com/doug-martin/goqu/v9"
_ "github.com/doug-martin/goqu/v9/dialect/postgres"
"github.com/lib/pq"

"github.com/G-Research/armada/pkg/api/lookout"
)
Expand Down Expand Up @@ -80,17 +79,17 @@ type JobRow struct {
Owner sql.NullString `db:"owner"`
JobSet sql.NullString `db:"jobset"`
Priority sql.NullFloat64 `db:"priority"`
Submitted pq.NullTime `db:"submitted"`
Cancelled pq.NullTime `db:"cancelled"`
Submitted sql.NullTime `db:"submitted"`
Cancelled sql.NullTime `db:"cancelled"`
JobJson sql.NullString `db:"job"`
State sql.NullInt64 `db:"state"`
RunId sql.NullString `db:"run_id"`
PodNumber sql.NullInt64 `db:"pod_number"`
Cluster sql.NullString `db:"cluster"`
Node sql.NullString `db:"node"`
Created pq.NullTime `db:"created"`
Started pq.NullTime `db:"started"`
Finished pq.NullTime `db:"finished"`
Created sql.NullTime `db:"created"`
Started sql.NullTime `db:"started"`
Finished sql.NullTime `db:"finished"`
Succeeded sql.NullBool `db:"succeeded"`
Error sql.NullString `db:"error"`
}
Expand Down
1 change: 0 additions & 1 deletion internal/lookout/repository/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (

"github.com/doug-martin/goqu/v9"
"github.com/doug-martin/goqu/v9/exp"
_ "github.com/lib/pq"

"github.com/G-Research/armada/internal/common/util"
"github.com/G-Research/armada/pkg/api"
Expand Down
5 changes: 2 additions & 3 deletions internal/lookout/repository/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (

"github.com/doug-martin/goqu/v9"
"github.com/doug-martin/goqu/v9/exp"
"github.com/lib/pq"

"github.com/G-Research/armada/internal/common/util"
)
Expand Down Expand Up @@ -97,14 +96,14 @@ func ParseNullFloat(nullFloat sql.NullFloat64) float64 {
return nullFloat.Float64
}

func ParseNullTime(nullTime pq.NullTime) *time.Time {
func ParseNullTime(nullTime sql.NullTime) *time.Time {
if !nullTime.Valid {
return nil
}
return &nullTime.Time
}

func ParseNullTimeDefault(nullTime pq.NullTime) time.Time {
func ParseNullTimeDefault(nullTime sql.NullTime) time.Time {
if !nullTime.Valid {
return time.Time{}
}
Expand Down
5 changes: 3 additions & 2 deletions internal/lookout/testutil/db_testutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"database/sql"
"testing"

_ "github.com/jackc/pgx/v4/stdlib"
"github.com/stretchr/testify/assert"

"github.com/G-Research/armada/internal/common/util"
Expand All @@ -13,15 +14,15 @@ import (
func WithDatabase(t *testing.T, action func(db *sql.DB)) {
dbName := "test_" + util.NewULID()
connectionString := "host=localhost port=5432 user=postgres password=psw sslmode=disable"
db, err := sql.Open("postgres", connectionString)
db, err := sql.Open("pgx", connectionString)
defer db.Close()

assert.Nil(t, err)

_, err = db.Exec("CREATE DATABASE " + dbName)
assert.Nil(t, err)

testDb, err := sql.Open("postgres", connectionString+" dbname="+dbName)
testDb, err := sql.Open("pgx", connectionString+" dbname="+dbName)
assert.Nil(t, err)

defer func() {
Expand Down