From 4d8d8daf0741bf8b13746f7b3bb0cb22d807bdb3 Mon Sep 17 00:00:00 2001 From: James Harris Date: Tue, 19 Mar 2024 09:58:55 +1000 Subject: [PATCH] Pass store directly to journal tests. --- driver/aws/dynamojournal/store_test.go | 8 +- driver/memory/memoryjournal/store_test.go | 8 +- driver/sql/postgres/pgjournal/store_test.go | 12 +- journal/benchmark.go | 40 ++--- journal/telemetry_test.go | 14 +- journal/test.go | 173 +++++++++----------- 6 files changed, 107 insertions(+), 148 deletions(-) diff --git a/driver/aws/dynamojournal/store_test.go b/driver/aws/dynamojournal/store_test.go index 7718f46..50c515e 100644 --- a/driver/aws/dynamojournal/store_test.go +++ b/driver/aws/dynamojournal/store_test.go @@ -15,9 +15,7 @@ func TestStore(t *testing.T) { client, table := setup(t) journal.RunTests( t, - func(t *testing.T) journal.BinaryStore { - return NewBinaryStore(client, table) - }, + NewBinaryStore(client, table), ) } @@ -25,9 +23,7 @@ func BenchmarkStore(b *testing.B) { client, table := setup(b) journal.RunBenchmarks( b, - func(b *testing.B) journal.BinaryStore { - return NewBinaryStore(client, table) - }, + NewBinaryStore(client, table), ) } diff --git a/driver/memory/memoryjournal/store_test.go b/driver/memory/memoryjournal/store_test.go index 0e2ddb8..4ab5ba3 100644 --- a/driver/memory/memoryjournal/store_test.go +++ b/driver/memory/memoryjournal/store_test.go @@ -10,17 +10,13 @@ import ( func TestStore(t *testing.T) { journal.RunTests( t, - func(t *testing.T) journal.BinaryStore { - return &BinaryStore{} - }, + &BinaryStore{}, ) } func BenchmarkStore(b *testing.B) { journal.RunBenchmarks( b, - func(b *testing.B) journal.BinaryStore { - return &BinaryStore{} - }, + &BinaryStore{}, ) } diff --git a/driver/sql/postgres/pgjournal/store_test.go b/driver/sql/postgres/pgjournal/store_test.go index ef17e93..a5afe89 100644 --- a/driver/sql/postgres/pgjournal/store_test.go +++ b/driver/sql/postgres/pgjournal/store_test.go @@ -15,10 +15,8 @@ func TestStore(t *testing.T) { db := setup(t) journal.RunTests( t, - func(t *testing.T) journal.BinaryStore { - return &BinaryStore{ - DB: db, - } + &BinaryStore{ + DB: db, }, ) } @@ -27,10 +25,8 @@ func BenchmarkStore(b *testing.B) { db := setup(b) journal.RunBenchmarks( b, - func(b *testing.B) journal.BinaryStore { - return &BinaryStore{ - DB: db, - } + &BinaryStore{ + DB: db, }, ) } diff --git a/journal/benchmark.go b/journal/benchmark.go index be68f85..993d36a 100644 --- a/journal/benchmark.go +++ b/journal/benchmark.go @@ -12,7 +12,7 @@ import ( // RunBenchmarks runs benchmarks against a [BinaryStore] implementation. func RunBenchmarks( b *testing.B, - newStore func(b *testing.B) BinaryStore, + store BinaryStore, ) { b.Run("Store", func(b *testing.B) { b.Run("Open", func(b *testing.B) { @@ -21,7 +21,7 @@ func RunBenchmarks( benchmarkStore( b, - newStore, + store, // SETUP func(ctx context.Context, s BinaryStore) error { name = uniqueName() @@ -51,7 +51,7 @@ func RunBenchmarks( benchmarkStore( b, - newStore, + store, // SETUP nil, // BEFORE EACH @@ -77,7 +77,7 @@ func RunBenchmarks( b.Run("empty journal", func(b *testing.B) { benchmarkJournal( b, - newStore, + store, // SETUP nil, // BEFORE EACH @@ -95,7 +95,7 @@ func RunBenchmarks( b.Run("non-empty journal", func(b *testing.B) { benchmarkJournal( b, - newStore, + store, // SETUP func(ctx context.Context, s BinaryStore, j BinaryJournal) error { for pos := Position(0); pos < 10000; pos++ { @@ -121,7 +121,7 @@ func RunBenchmarks( b.Run("truncated journal", func(b *testing.B) { benchmarkJournal( b, - newStore, + store, // SETUP func(ctx context.Context, s BinaryStore, j BinaryJournal) error { for pos := Position(0); pos < 10000; pos++ { @@ -152,7 +152,7 @@ func RunBenchmarks( benchmarkJournal( b, - newStore, + store, // SETUP nil, // BEFORE EACH @@ -178,7 +178,7 @@ func RunBenchmarks( benchmarkJournal( b, - newStore, + store, // SETUP func(ctx context.Context, _ BinaryStore, j BinaryJournal) error { for pos := Position(0); pos < 10000; pos++ { @@ -213,7 +213,7 @@ func RunBenchmarks( benchmarkJournal( b, - newStore, + store, // SETUP nil, // BEFORE EACH @@ -233,7 +233,7 @@ func RunBenchmarks( b.Run("Range (3k records)", func(b *testing.B) { benchmarkJournal( b, - newStore, + store, // SETUP func(ctx context.Context, _ BinaryStore, j BinaryJournal) error { rec := []byte("") @@ -266,7 +266,7 @@ func RunBenchmarks( benchmarkJournal( b, - newStore, + store, // SETUP func(ctx context.Context, _ BinaryStore, j BinaryJournal) error { rec := []byte("") @@ -293,22 +293,17 @@ func RunBenchmarks( func benchmarkStore[T any]( b *testing.B, - newStore func(b *testing.B) BinaryStore, + store BinaryStore, setup func(context.Context, BinaryStore) error, before func(context.Context, BinaryStore) error, fn func(context.Context, BinaryStore) (T, error), after func(T) error, ) { - var ( - store BinaryStore - result T - ) + var result T benchmark.Run( b, func(ctx context.Context) error { - store = newStore(b) - if setup != nil { return setup(ctx, store) } @@ -337,22 +332,17 @@ func benchmarkStore[T any]( func benchmarkJournal( b *testing.B, - newStore func(b *testing.B) BinaryStore, + store BinaryStore, setup func(context.Context, BinaryStore, BinaryJournal) error, before func(context.Context, BinaryJournal) error, fn func(context.Context, BinaryJournal) error, after func() error, ) { - var ( - store BinaryStore - journ BinaryJournal - ) + var journ BinaryJournal benchmark.Run( b, func(ctx context.Context) error { - store = newStore(b) - var err error journ, err = store.Open(ctx, uniqueName()) if err != nil { diff --git a/journal/telemetry_test.go b/journal/telemetry_test.go index 1743082..a1e4ba4 100644 --- a/journal/telemetry_test.go +++ b/journal/telemetry_test.go @@ -13,13 +13,11 @@ import ( func TestWithTelemetry(t *testing.T) { RunTests( t, - func(t *testing.T) BinaryStore { - return WithTelemetry( - &memoryjournal.BinaryStore{}, - nooptrace.NewTracerProvider(), - noopmetric.NewMeterProvider(), - spruce.NewLogger(t), - ) - }, + WithTelemetry( + &memoryjournal.BinaryStore{}, + nooptrace.NewTracerProvider(), + noopmetric.NewMeterProvider(), + spruce.NewLogger(t), + ), ) } diff --git a/journal/test.go b/journal/test.go index 177970a..08b2a31 100644 --- a/journal/test.go +++ b/journal/test.go @@ -16,27 +16,14 @@ import ( // RunTests runs tests that confirm a journal implementation behaves correctly. func RunTests( t *testing.T, - newStore func(t *testing.T) BinaryStore, + store BinaryStore, ) { - type dependencies struct { - Store BinaryStore - JournalName string - Journal BinaryJournal - } - - setup := func( - t *testing.T, - newStore func(t *testing.T) BinaryStore, - ) (context.Context, *dependencies) { + setup := func(t *testing.T) (context.Context, BinaryJournal) { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) t.Cleanup(cancel) - deps := &dependencies{ - Store: newStore(t), - JournalName: uniqueName(), - } - - j, err := deps.Store.Open(ctx, deps.JournalName) + name := uniqueName() + j, err := store.Open(ctx, name) if err != nil { t.Fatal(err) } @@ -47,9 +34,7 @@ func RunTests( } }) - deps.Journal = j - - return ctx, deps + return ctx, j } t.Run("Store", func(t *testing.T) { @@ -64,8 +49,6 @@ func RunTests( ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() - store := newStore(t) - j1, err := store.Open(ctx, "") if err != nil { t.Fatal(err) @@ -145,11 +128,11 @@ func RunTests( t.Run(c.Name, func(t *testing.T) { t.Parallel() - ctx, deps := setup(t, newStore) + ctx, j := setup(t) - c.Setup(ctx, t, deps.Journal) + c.Setup(ctx, t, j) - begin, end, err := deps.Journal.Bounds(ctx) + begin, end, err := j.Bounds(ctx) if err != nil { t.Fatal(err) } @@ -172,9 +155,9 @@ func RunTests( t.Run("it returns ErrNotFound if there is no record at the given position", func(t *testing.T) { t.Parallel() - ctx, deps := setup(t, newStore) + ctx, j := setup(t) - _, err := deps.Journal.Get(ctx, 1) + _, err := j.Get(ctx, 1) if !errors.Is(err, ErrNotFound) { t.Fatal(err) } @@ -183,14 +166,14 @@ func RunTests( t.Run("it returns the record if it exists", func(t *testing.T) { t.Parallel() - ctx, deps := setup(t, newStore) + ctx, j := setup(t) // Ensure we test with a position that becomes 2 digits long to // confirm that the implementation is not using a lexical sort. - records := appendRecords(ctx, t, deps.Journal, 15) + records := appendRecords(ctx, t, j, 15) for i, want := range records { - got, err := deps.Journal.Get(ctx, Position(i)) + got, err := j.Get(ctx, Position(i)) if err != nil { t.Fatal(err) } @@ -209,13 +192,13 @@ func RunTests( t.Run("it does not return truncated records", func(t *testing.T) { t.Parallel() - ctx, deps := setup(t, newStore) + ctx, j := setup(t) const recordCount = 5 const truncateBefore = 3 - records := appendRecords(ctx, t, deps.Journal, recordCount) + records := appendRecords(ctx, t, j, recordCount) - err := deps.Journal.Truncate(ctx, truncateBefore) + err := j.Truncate(ctx, truncateBefore) if err != nil { t.Fatal(err) } @@ -224,11 +207,11 @@ func RunTests( pos := Position(pos) if pos < truncateBefore { - if _, err := deps.Journal.Get(ctx, pos); err != ErrNotFound { + if _, err := j.Get(ctx, pos); err != ErrNotFound { t.Fatalf("unexpected error at position %d: got %q, want %q", pos, err, ErrNotFound) } } else { - got, err := deps.Journal.Get(ctx, pos) + got, err := j.Get(ctx, pos) if err != nil { t.Fatal(err) } @@ -248,18 +231,18 @@ func RunTests( t.Run("it does not return any records when all records are truncated", func(t *testing.T) { t.Parallel() - ctx, deps := setup(t, newStore) + ctx, j := setup(t) - records := appendRecords(ctx, t, deps.Journal, 5) + records := appendRecords(ctx, t, j, 5) - err := deps.Journal.Truncate(ctx, 5) + err := j.Truncate(ctx, 5) if err != nil { t.Fatal(err) } for i := range records { pos := Position(i) - if _, err := deps.Journal.Get(ctx, pos); err != ErrNotFound { + if _, err := j.Get(ctx, pos); err != ErrNotFound { t.Fatalf("unexpected error at position %d: got %q, want %q", i, err, ErrNotFound) } } @@ -268,18 +251,18 @@ func RunTests( t.Run("it does not return its internal byte slice", func(t *testing.T) { t.Parallel() - ctx, deps := setup(t, newStore) + ctx, j := setup(t) - appendRecords(ctx, t, deps.Journal, 1) + appendRecords(ctx, t, j, 1) - rec, err := deps.Journal.Get(ctx, 0) + rec, err := j.Get(ctx, 0) if err != nil { t.Fatal(err) } rec[0] = 'X' - got, err := deps.Journal.Get(ctx, 0) + got, err := j.Get(ctx, 0) if err != nil { t.Fatal(err) } @@ -297,9 +280,9 @@ func RunTests( t.Run("handles maximum position value", func(t *testing.T) { t.Parallel() - ctx, deps := setup(t, newStore) + ctx, j := setup(t) - _, err := deps.Journal.Get(ctx, math.MaxUint64) + _, err := j.Get(ctx, math.MaxUint64) if !errors.Is(err, ErrNotFound) { t.Fatal(err) } @@ -312,15 +295,15 @@ func RunTests( t.Run("calls the function for each record in the journal", func(t *testing.T) { t.Parallel() - ctx, deps := setup(t, newStore) + ctx, j := setup(t) - want := appendRecords(ctx, t, deps.Journal, 15) + want := appendRecords(ctx, t, j, 15) var got [][]byte wantPos := Position(10) want = want[wantPos:] - if err := deps.Journal.Range( + if err := j.Range( ctx, wantPos, func(ctx context.Context, gotPos Position, rec []byte) (bool, error) { @@ -345,12 +328,12 @@ func RunTests( t.Run("it stops iterating if the function returns false", func(t *testing.T) { t.Parallel() - ctx, deps := setup(t, newStore) + ctx, j := setup(t) - appendRecords(ctx, t, deps.Journal, 2) + appendRecords(ctx, t, j, 2) called := false - if err := deps.Journal.Range( + if err := j.Range( ctx, 0, func(ctx context.Context, pos Position, rec []byte) (bool, error) { @@ -369,9 +352,9 @@ func RunTests( t.Run("it returns ErrNotFound if journal is empty", func(t *testing.T) { t.Parallel() - ctx, deps := setup(t, newStore) + ctx, j := setup(t) - err := deps.Journal.Range( + err := j.Range( ctx, 0, func(ctx context.Context, pos Position, rec []byte) (bool, error) { @@ -387,13 +370,13 @@ func RunTests( t.Run("it does not range over truncated records", func(t *testing.T) { t.Parallel() - ctx, deps := setup(t, newStore) + ctx, j := setup(t) const recordCount = 5 const truncateBefore = 3 - records := appendRecords(ctx, t, deps.Journal, recordCount) + records := appendRecords(ctx, t, j, recordCount) - err := deps.Journal.Truncate(ctx, truncateBefore) + err := j.Truncate(ctx, truncateBefore) if err != nil { t.Fatal(err) } @@ -402,7 +385,7 @@ func RunTests( pos := Position(pos) if pos < truncateBefore { - if err := deps.Journal.Range( + if err := j.Range( ctx, pos, func(ctx context.Context, pos Position, rec []byte) (bool, error) { @@ -412,7 +395,7 @@ func RunTests( t.Fatalf("unexpected error: got %q, want %q", err, ErrNotFound) } } else { - if err := deps.Journal.Range( + if err := j.Range( ctx, pos, func(ctx context.Context, pos Position, got []byte) (bool, error) { @@ -436,11 +419,11 @@ func RunTests( t.Run("it does not range over truncated records when all records are truncated", func(t *testing.T) { t.Parallel() - ctx, deps := setup(t, newStore) + ctx, j := setup(t) - records := appendRecords(ctx, t, deps.Journal, 5) + records := appendRecords(ctx, t, j, 5) - err := deps.Journal.Truncate(ctx, 5) + err := j.Truncate(ctx, 5) if err != nil { t.Fatal(err) } @@ -448,7 +431,7 @@ func RunTests( for pos := range records { pos := Position(pos) - if err := deps.Journal.Range( + if err := j.Range( ctx, pos, func(ctx context.Context, pos Position, rec []byte) (bool, error) { @@ -464,15 +447,15 @@ func RunTests( t.Skip() // TODO t.Parallel() - ctx, deps := setup(t, newStore) + ctx, j := setup(t) - appendRecords(ctx, t, deps.Journal, 5) + appendRecords(ctx, t, j, 5) - err := deps.Journal.Range( + err := j.Range( ctx, 0, func(ctx context.Context, pos Position, rec []byte) (bool, error) { - return true, deps.Journal.Truncate(ctx, 5) + return true, j.Truncate(ctx, 5) }, ) @@ -484,11 +467,11 @@ func RunTests( t.Run("it does not invoke the function with its internal byte slice", func(t *testing.T) { t.Parallel() - ctx, deps := setup(t, newStore) + ctx, j := setup(t) - appendRecords(ctx, t, deps.Journal, 1) + appendRecords(ctx, t, j, 1) - if err := deps.Journal.Range( + if err := j.Range( ctx, 0, func(ctx context.Context, pos Position, rec []byte) (bool, error) { @@ -500,7 +483,7 @@ func RunTests( t.Fatal(err) } - got, err := deps.Journal.Get(ctx, 0) + got, err := j.Get(ctx, 0) if err != nil { t.Fatal(err) } @@ -518,9 +501,9 @@ func RunTests( t.Run("handles maximum position value", func(t *testing.T) { t.Parallel() - ctx, deps := setup(t, newStore) + ctx, j := setup(t) - err := deps.Journal.Range( + err := j.Range( ctx, math.MaxUint64, func(ctx context.Context, pos Position, rec []byte) (bool, error) { @@ -540,9 +523,9 @@ func RunTests( t.Run("it does not return an error if there is no record at the given position", func(t *testing.T) { t.Parallel() - ctx, deps := setup(t, newStore) + ctx, j := setup(t) - if err := deps.Journal.Append(ctx, 0, []byte("")); err != nil { + if err := j.Append(ctx, 0, []byte("")); err != nil { t.Fatal(err) } }) @@ -550,24 +533,24 @@ func RunTests( t.Run("it returns ErrConflict there is already a record at the given position", func(t *testing.T) { t.Parallel() - ctx, deps := setup(t, newStore) + ctx, j := setup(t) - if err := deps.Journal.Append(ctx, 0, []byte("")); err != nil { + if err := j.Append(ctx, 0, []byte("")); err != nil { t.Fatal(err) } want := []byte("") - if err := deps.Journal.Append(ctx, 1, want); err != nil { + if err := j.Append(ctx, 1, want); err != nil { t.Fatal(err) } - err := deps.Journal.Append(ctx, 1, []byte("")) + err := j.Append(ctx, 1, []byte("")) if !errors.Is(err, ErrConflict) { t.Fatalf("unexpected error: got %q, want %q", err, ErrConflict) } - got, err := deps.Journal.Get(ctx, 1) + got, err := j.Get(ctx, 1) if err != nil { t.Fatal(err) } @@ -584,17 +567,17 @@ func RunTests( t.Run("it does not keep a reference to the record slice", func(t *testing.T) { t.Parallel() - ctx, deps := setup(t, newStore) + ctx, j := setup(t) rec := []byte("") - if err := deps.Journal.Append(ctx, 0, rec); err != nil { + if err := j.Append(ctx, 0, rec); err != nil { t.Fatal(err) } rec[0] = 'X' - got, err := deps.Journal.Get(ctx, 0) + got, err := j.Get(ctx, 0) if err != nil { t.Fatal(err) } @@ -612,15 +595,15 @@ func RunTests( t.Run("it truncates the journal", func(t *testing.T) { t.Parallel() - ctx, deps := setup(t, newStore) + ctx, j := setup(t) - appendRecords(ctx, t, deps.Journal, 3) + appendRecords(ctx, t, j, 3) - if err := deps.Journal.Truncate(ctx, 1); err != nil { + if err := j.Truncate(ctx, 1); err != nil { t.Fatal(err) } - got, _, err := deps.Journal.Bounds(ctx) + got, _, err := j.Bounds(ctx) if err != nil { t.Fatal(err) } @@ -634,15 +617,15 @@ func RunTests( t.Run("it allows truncating all records", func(t *testing.T) { t.Parallel() - ctx, deps := setup(t, newStore) + ctx, j := setup(t) - appendRecords(ctx, t, deps.Journal, 3) + appendRecords(ctx, t, j, 3) - if err := deps.Journal.Truncate(ctx, 3); err != nil { + if err := j.Truncate(ctx, 3); err != nil { t.Fatal(err) } - begin, end, err := deps.Journal.Bounds(ctx) + begin, end, err := j.Bounds(ctx) if err != nil { t.Fatal(err) } @@ -661,19 +644,19 @@ func RunTests( t.Run("it truncates the journal when it has already been truncated", func(t *testing.T) { t.Parallel() - ctx, deps := setup(t, newStore) + ctx, j := setup(t) - appendRecords(ctx, t, deps.Journal, 3) + appendRecords(ctx, t, j, 3) - if err := deps.Journal.Truncate(ctx, 1); err != nil { + if err := j.Truncate(ctx, 1); err != nil { t.Fatal(err) } - if err := deps.Journal.Truncate(ctx, 2); err != nil { + if err := j.Truncate(ctx, 2); err != nil { t.Fatal(err) } - got, _, err := deps.Journal.Bounds(ctx) + got, _, err := j.Bounds(ctx) if err != nil { t.Fatal(err) }