Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pgxstore: Implement CtxStore interface #204

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 34 additions & 13 deletions pgxstore/pgxstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pgxstore

import (
"context"
"errors"
"log"
"time"

Expand Down Expand Up @@ -33,42 +34,42 @@ func NewWithCleanupInterval(pool *pgxpool.Pool, cleanupInterval time.Duration) *
return p
}

// Find returns the data for a given session token from the PostgresStore instance.
// FindCtx returns the data for a given session token from the PostgresStore instance.
// If the session token is not found or is expired, the returned exists flag will
// be set to false.
func (p *PostgresStore) Find(token string) (b []byte, exists bool, err error) {
row := p.pool.QueryRow(context.Background(), "SELECT data FROM sessions WHERE token = $1 AND current_timestamp < expiry", token)
func (p *PostgresStore) FindCtx(ctx context.Context, token string) (b []byte, found bool, err error) {
row := p.pool.QueryRow(ctx, "SELECT data FROM sessions WHERE token = $1 AND current_timestamp < expiry", token)
err = row.Scan(&b)
if err == pgx.ErrNoRows {
if errors.Is(err, pgx.ErrNoRows) {
return nil, false, nil
} else if err != nil {
return nil, false, err
}
return b, true, nil
}

// Commit adds a session token and data to the PostgresStore instance with the
// CommitCtx adds a session token and data to the PostgresStore instance with the
// given expiry time. If the session token already exists, then the data and expiry
// time are updated.
func (p *PostgresStore) Commit(token string, b []byte, expiry time.Time) error {
_, err := p.pool.Exec(context.Background(), "INSERT INTO sessions (token, data, expiry) VALUES ($1, $2, $3) ON CONFLICT (token) DO UPDATE SET data = EXCLUDED.data, expiry = EXCLUDED.expiry", token, b, expiry)
func (p *PostgresStore) CommitCtx(ctx context.Context, token string, b []byte, expiry time.Time) (err error) {
_, err = p.pool.Exec(ctx, "INSERT INTO sessions (token, data, expiry) VALUES ($1, $2, $3) ON CONFLICT (token) DO UPDATE SET data = EXCLUDED.data, expiry = EXCLUDED.expiry", token, b, expiry)
if err != nil {
return err
}
return nil
}

// Delete removes a session token and corresponding data from the PostgresStore
// DeleteCtx removes a session token and corresponding data from the PostgresStore
// instance.
func (p *PostgresStore) Delete(token string) error {
_, err := p.pool.Exec(context.Background(), "DELETE FROM sessions WHERE token = $1", token)
func (p *PostgresStore) DeleteCtx(ctx context.Context, token string) (err error) {
_, err = p.pool.Exec(ctx, "DELETE FROM sessions WHERE token = $1", token)
return err
}

// All returns a map containing the token and data for all active (i.e.
// AllCtx returns a map containing the token and data for all active (i.e.
// not expired) sessions in the PostgresStore instance.
func (p *PostgresStore) All() (map[string][]byte, error) {
rows, err := p.pool.Query(context.Background(), "SELECT token, data FROM sessions WHERE current_timestamp < expiry")
func (p *PostgresStore) AllCtx(ctx context.Context) (map[string][]byte, error) {
rows, err := p.pool.Query(ctx, "SELECT token, data FROM sessions WHERE current_timestamp < expiry")
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -135,3 +136,23 @@ func (p *PostgresStore) deleteExpired() error {
_, err := p.pool.Exec(context.Background(), "DELETE FROM sessions WHERE expiry < current_timestamp")
return err
}

// We have to add the plain Store methods here to be recognized a Store
// by the go compiler. Not using a separate type makes any errors caught
// only at runtime instead of compile time.

func (p *PostgresStore) Find(token string) (b []byte, exists bool, err error) {
panic("missing context arg")
}

func (p *PostgresStore) Commit(token string, b []byte, expiry time.Time) error {
panic("missing context arg")
}

func (p *PostgresStore) Delete(token string) error {
panic("missing context arg")
}

func (p *PostgresStore) All() (map[string][]byte, error) {
panic("missing context arg")
}
79 changes: 47 additions & 32 deletions pgxstore/pgxstore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,27 @@ import (
)

func TestFind(t *testing.T) {
ctx := context.Background()

dsn := os.Getenv("SCS_POSTGRES_TEST_DSN")
pool, err := pgxpool.New(context.Background(), dsn)
pool, err := pgxpool.New(ctx, dsn)
if err != nil {
t.Fatal(err)
}
defer pool.Close()

_, err = pool.Exec(context.Background(), "TRUNCATE TABLE sessions")
_, err = pool.Exec(ctx, "TRUNCATE TABLE sessions")
if err != nil {
t.Fatal(err)
}
_, err = pool.Exec(context.Background(), "INSERT INTO sessions VALUES('session_token', 'encoded_data', current_timestamp + interval '1 minute')")
_, err = pool.Exec(ctx, "INSERT INTO sessions VALUES('session_token', 'encoded_data', current_timestamp + interval '1 minute')")
if err != nil {
t.Fatal(err)
}

p := NewWithCleanupInterval(pool, 0)

b, found, err := p.Find("session_token")
b, found, err := p.FindCtx(ctx, "session_token")
if err != nil {
t.Fatal(err)
}
Expand All @@ -43,21 +45,23 @@ func TestFind(t *testing.T) {
}

func TestFindMissing(t *testing.T) {
ctx := context.Background()

dsn := os.Getenv("SCS_POSTGRES_TEST_DSN")
pool, err := pgxpool.New(context.Background(), dsn)
pool, err := pgxpool.New(ctx, dsn)
if err != nil {
t.Fatal(err)
}
defer pool.Close()

_, err = pool.Exec(context.Background(), "TRUNCATE TABLE sessions")
_, err = pool.Exec(ctx, "TRUNCATE TABLE sessions")
if err != nil {
t.Fatal(err)
}

p := NewWithCleanupInterval(pool, 0)

_, found, err := p.Find("missing_session_token")
_, found, err := p.FindCtx(ctx, "missing_session_token")
if err != nil {
t.Fatalf("got %v: expected %v", err, nil)
}
Expand All @@ -67,26 +71,28 @@ func TestFindMissing(t *testing.T) {
}

func TestSaveNew(t *testing.T) {
ctx := context.Background()

dsn := os.Getenv("SCS_POSTGRES_TEST_DSN")
pool, err := pgxpool.New(context.Background(), dsn)
pool, err := pgxpool.New(ctx, dsn)
if err != nil {
t.Fatal(err)
}
defer pool.Close()

_, err = pool.Exec(context.Background(), "TRUNCATE TABLE sessions")
_, err = pool.Exec(ctx, "TRUNCATE TABLE sessions")
if err != nil {
t.Fatal(err)
}

p := NewWithCleanupInterval(pool, 0)

err = p.Commit("session_token", []byte("encoded_data"), time.Now().Add(time.Minute))
err = p.CommitCtx(ctx, "session_token", []byte("encoded_data"), time.Now().Add(time.Minute))
if err != nil {
t.Fatal(err)
}

row := pool.QueryRow(context.Background(), "SELECT data FROM sessions WHERE token = 'session_token'")
row := pool.QueryRow(ctx, "SELECT data FROM sessions WHERE token = 'session_token'")
var data []byte
err = row.Scan(&data)
if err != nil {
Expand All @@ -98,30 +104,32 @@ func TestSaveNew(t *testing.T) {
}

func TestSaveUpdated(t *testing.T) {
ctx := context.Background()

dsn := os.Getenv("SCS_POSTGRES_TEST_DSN")
pool, err := pgxpool.New(context.Background(), dsn)
pool, err := pgxpool.New(ctx, dsn)
if err != nil {
t.Fatal(err)
}
defer pool.Close()

_, err = pool.Exec(context.Background(), "TRUNCATE TABLE sessions")
_, err = pool.Exec(ctx, "TRUNCATE TABLE sessions")
if err != nil {
t.Fatal(err)
}
_, err = pool.Exec(context.Background(), "INSERT INTO sessions VALUES('session_token', 'encoded_data', current_timestamp + interval '1 minute')")
_, err = pool.Exec(ctx, "INSERT INTO sessions VALUES('session_token', 'encoded_data', current_timestamp + interval '1 minute')")
if err != nil {
t.Fatal(err)
}

p := NewWithCleanupInterval(pool, 0)

err = p.Commit("session_token", []byte("new_encoded_data"), time.Now().Add(time.Minute))
err = p.CommitCtx(ctx, "session_token", []byte("new_encoded_data"), time.Now().Add(time.Minute))
if err != nil {
t.Fatal(err)
}

row := pool.QueryRow(context.Background(), "SELECT data FROM sessions WHERE token = 'session_token'")
row := pool.QueryRow(ctx, "SELECT data FROM sessions WHERE token = 'session_token'")
var data []byte
err = row.Scan(&data)
if err != nil {
Expand All @@ -133,62 +141,67 @@ func TestSaveUpdated(t *testing.T) {
}

func TestExpiry(t *testing.T) {
ctx := context.Background()

dsn := os.Getenv("SCS_POSTGRES_TEST_DSN")
pool, err := pgxpool.New(context.Background(), dsn)
pool, err := pgxpool.New(ctx, dsn)
if err != nil {
t.Fatal(err)
}
defer pool.Close()

_, err = pool.Exec(context.Background(), "TRUNCATE TABLE sessions")
_, err = pool.Exec(ctx, "TRUNCATE TABLE sessions")
if err != nil {
t.Fatal(err)
}

p := NewWithCleanupInterval(pool, 0)
p := NewWithCleanupInterval(pool, 10*time.Millisecond)
defer p.StopCleanup()

err = p.Commit("session_token", []byte("encoded_data"), time.Now().Add(100*time.Millisecond))
err = p.CommitCtx(ctx, "session_token", []byte("encoded_data"), time.Now().Add(100*time.Millisecond))
if err != nil {
t.Fatal(err)
}

_, found, _ := p.Find("session_token")
_, found, _ := p.FindCtx(ctx, "session_token")
if found != true {
t.Fatalf("got %v: expected %v", found, true)
}

time.Sleep(100 * time.Millisecond)
_, found, _ = p.Find("session_token")
_, found, _ = p.FindCtx(ctx, "session_token")
if found != false {
t.Fatalf("got %v: expected %v", found, false)
}
}

func TestDelete(t *testing.T) {
ctx := context.Background()

dsn := os.Getenv("SCS_POSTGRES_TEST_DSN")
pool, err := pgxpool.New(context.Background(), dsn)
pool, err := pgxpool.New(ctx, dsn)
if err != nil {
t.Fatal(err)
}
defer pool.Close()

_, err = pool.Exec(context.Background(), "TRUNCATE TABLE sessions")
_, err = pool.Exec(ctx, "TRUNCATE TABLE sessions")
if err != nil {
t.Fatal(err)
}
_, err = pool.Exec(context.Background(), "INSERT INTO sessions VALUES('session_token', 'encoded_data', current_timestamp + interval '1 minute')")
_, err = pool.Exec(ctx, "INSERT INTO sessions VALUES('session_token', 'encoded_data', current_timestamp + interval '1 minute')")
if err != nil {
t.Fatal(err)
}

p := NewWithCleanupInterval(pool, 0)

err = p.Delete("session_token")
err = p.DeleteCtx(ctx, "session_token")
if err != nil {
t.Fatal(err)
}

row := pool.QueryRow(context.Background(), "SELECT COUNT(*) FROM sessions WHERE token = 'session_token'")
row := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE token = 'session_token'")
var count int
err = row.Scan(&count)
if err != nil {
Expand All @@ -200,27 +213,29 @@ func TestDelete(t *testing.T) {
}

func TestCleanup(t *testing.T) {
ctx := context.Background()

dsn := os.Getenv("SCS_POSTGRES_TEST_DSN")
pool, err := pgxpool.New(context.Background(), dsn)
pool, err := pgxpool.New(ctx, dsn)
if err != nil {
t.Fatal(err)
}
defer pool.Close()

_, err = pool.Exec(context.Background(), "TRUNCATE TABLE sessions")
_, err = pool.Exec(ctx, "TRUNCATE TABLE sessions")
if err != nil {
t.Fatal(err)
}

p := NewWithCleanupInterval(pool, 200*time.Millisecond)
defer p.StopCleanup()

err = p.Commit("session_token", []byte("encoded_data"), time.Now().Add(100*time.Millisecond))
err = p.CommitCtx(ctx, "session_token", []byte("encoded_data"), time.Now().Add(100*time.Millisecond))
if err != nil {
t.Fatal(err)
}

row := pool.QueryRow(context.Background(), "SELECT COUNT(*) FROM sessions WHERE token = 'session_token'")
row := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE token = 'session_token'")
var count int
err = row.Scan(&count)
if err != nil {
Expand All @@ -231,7 +246,7 @@ func TestCleanup(t *testing.T) {
}

time.Sleep(300 * time.Millisecond)
row = pool.QueryRow(context.Background(), "SELECT COUNT(*) FROM sessions WHERE token = 'session_token'")
row = pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE token = 'session_token'")
err = row.Scan(&count)
if err != nil {
t.Fatal(err)
Expand Down