diff --git a/.golangci.yml b/.golangci.yml index b1de5a69163f..f37d1fc53f78 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -71,6 +71,7 @@ linters-settings: - 'require\.ErrorContains$(# ErrorIs should be used instead)?' - 'require\.EqualValues$(# Equal should be used instead)?' - 'require\.NotEqualValues$(# NotEqual should be used instead)?' + - '^(t|b|tb|f)\.(Fatal|Fatalf|Error|Errorf)$(# the require library should be used instead)?' exclude_godoc_examples: false # https://golangci-lint.run/usage/linters#gosec gosec: diff --git a/cache/lru_cache_benchmark_test.go b/cache/lru_cache_benchmark_test.go index 73acf90b1056..d8e4f4185933 100644 --- a/cache/lru_cache_benchmark_test.go +++ b/cache/lru_cache_benchmark_test.go @@ -7,6 +7,8 @@ import ( "crypto/rand" "testing" + "github.com/stretchr/testify/require" + "github.com/ava-labs/avalanchego/ids" ) @@ -16,9 +18,8 @@ func BenchmarkLRUCachePutSmall(b *testing.B) { for n := 0; n < b.N; n++ { for i := 0; i < smallLen; i++ { var id ids.ID - if _, err := rand.Read(id[:]); err != nil { - b.Fatal(err) - } + _, err := rand.Read(id[:]) + require.NoError(b, err) cache.Put(id, n) } b.StopTimer() @@ -33,9 +34,8 @@ func BenchmarkLRUCachePutMedium(b *testing.B) { for n := 0; n < b.N; n++ { for i := 0; i < mediumLen; i++ { var id ids.ID - if _, err := rand.Read(id[:]); err != nil { - b.Fatal(err) - } + _, err := rand.Read(id[:]) + require.NoError(b, err) cache.Put(id, n) } b.StopTimer() @@ -50,9 +50,8 @@ func BenchmarkLRUCachePutLarge(b *testing.B) { for n := 0; n < b.N; n++ { for i := 0; i < largeLen; i++ { var id ids.ID - if _, err := rand.Read(id[:]); err != nil { - b.Fatal(err) - } + _, err := rand.Read(id[:]) + require.NoError(b, err) cache.Put(id, n) } b.StopTimer() diff --git a/cache/unique_cache_test.go b/cache/unique_cache_test.go index 0094a47a0a42..3f0d40f8dc0d 100644 --- a/cache/unique_cache_test.go +++ b/cache/unique_cache_test.go @@ -6,6 +6,8 @@ package cache import ( "testing" + "github.com/stretchr/testify/require" + "github.com/ava-labs/avalanchego/ids" ) @@ -23,50 +25,32 @@ func (e *evictable[_]) Evict() { } func TestEvictableLRU(t *testing.T) { + require := require.New(t) + cache := EvictableLRU[ids.ID, *evictable[ids.ID]]{} expectedValue1 := &evictable[ids.ID]{id: ids.ID{1}} - if returnedValue := cache.Deduplicate(expectedValue1); returnedValue != expectedValue1 { - t.Fatalf("Returned unknown value") - } else if expectedValue1.evicted != 0 { - t.Fatalf("Value was evicted unexpectedly") - } else if returnedValue := cache.Deduplicate(expectedValue1); returnedValue != expectedValue1 { - t.Fatalf("Returned unknown value") - } else if expectedValue1.evicted != 0 { - t.Fatalf("Value was evicted unexpectedly") - } + require.Equal(expectedValue1, cache.Deduplicate(expectedValue1)) + require.Zero(expectedValue1.evicted) + require.Equal(expectedValue1, cache.Deduplicate(expectedValue1)) + require.Zero(expectedValue1.evicted) expectedValue2 := &evictable[ids.ID]{id: ids.ID{2}} returnedValue := cache.Deduplicate(expectedValue2) - switch { - case returnedValue != expectedValue2: - t.Fatalf("Returned unknown value") - case expectedValue1.evicted != 1: - t.Fatalf("Value should have been evicted") - case expectedValue2.evicted != 0: - t.Fatalf("Value was evicted unexpectedly") - } + require.Equal(expectedValue2, returnedValue) + require.Equal(1, expectedValue1.evicted) + require.Zero(expectedValue2.evicted) cache.Size = 2 expectedValue3 := &evictable[ids.ID]{id: ids.ID{2}} returnedValue = cache.Deduplicate(expectedValue3) - switch { - case returnedValue != expectedValue2: - t.Fatalf("Returned unknown value") - case expectedValue1.evicted != 1: - t.Fatalf("Value should have been evicted") - case expectedValue2.evicted != 0: - t.Fatalf("Value was evicted unexpectedly") - } + require.Equal(expectedValue2, returnedValue) + require.Equal(1, expectedValue1.evicted) + require.Zero(expectedValue2.evicted) cache.Flush() - switch { - case expectedValue1.evicted != 1: - t.Fatalf("Value should have been evicted") - case expectedValue2.evicted != 1: - t.Fatalf("Value should have been evicted") - case expectedValue3.evicted != 0: - t.Fatalf("Value was evicted unexpectedly") - } + require.Equal(1, expectedValue1.evicted) + require.Equal(1, expectedValue2.evicted) + require.Zero(expectedValue3.evicted) } diff --git a/chains/atomic/gsharedmemory/shared_memory_test.go b/chains/atomic/gsharedmemory/shared_memory_test.go index f8ae05ad6202..0ce546c94f77 100644 --- a/chains/atomic/gsharedmemory/shared_memory_test.go +++ b/chains/atomic/gsharedmemory/shared_memory_test.go @@ -43,10 +43,10 @@ func TestInterface(t *testing.T) { } func wrapSharedMemory(t *testing.T, sm atomic.SharedMemory, db database.Database) (atomic.SharedMemory, io.Closer) { + require := require.New(t) + listener, err := grpcutils.NewListener() - if err != nil { - t.Fatalf("Failed to create listener: %s", err) - } + require.NoError(err) serverCloser := grpcutils.ServerCloser{} server := grpcutils.NewServer() @@ -56,9 +56,7 @@ func wrapSharedMemory(t *testing.T, sm atomic.SharedMemory, db database.Database go grpcutils.Serve(listener, server) conn, err := grpcutils.Dial(listener.Addr().String()) - if err != nil { - t.Fatalf("Failed to dial: %s", err) - } + require.NoError(err) rpcsm := NewClient(sharedmemorypb.NewSharedMemoryClient(conn)) return rpcsm, conn diff --git a/chains/atomic/memory_test.go b/chains/atomic/memory_test.go index faf461f8298b..5acdb5233af4 100644 --- a/chains/atomic/memory_test.go +++ b/chains/atomic/memory_test.go @@ -6,6 +6,8 @@ package atomic import ( "testing" + "github.com/stretchr/testify/require" + "github.com/ava-labs/avalanchego/database/memdb" "github.com/ava-labs/avalanchego/ids" ) @@ -19,32 +21,26 @@ func TestSharedID(t *testing.T) { sharedID0 := sharedID(blockchainID0, blockchainID1) sharedID1 := sharedID(blockchainID1, blockchainID0) - if sharedID0 != sharedID1 { - t.Fatalf("SharedMemory.sharedID should be communitive") - } + require.Equal(t, sharedID0, sharedID1) } func TestMemoryMakeReleaseLock(t *testing.T) { + require := require.New(t) + m := NewMemory(memdb.New()) sharedID := sharedID(blockchainID0, blockchainID1) lock0 := m.makeLock(sharedID) - if lock1 := m.makeLock(sharedID); lock0 != lock1 { - t.Fatalf("Memory.makeLock should have returned the same lock") - } + require.Equal(lock0, m.makeLock(sharedID)) m.releaseLock(sharedID) - if lock2 := m.makeLock(sharedID); lock0 != lock2 { - t.Fatalf("Memory.makeLock should have returned the same lock") - } + require.Equal(lock0, m.makeLock(sharedID)) m.releaseLock(sharedID) m.releaseLock(sharedID) - if lock3 := m.makeLock(sharedID); lock0 == lock3 { - t.Fatalf("Memory.releaseLock should have returned freed the lock") - } + require.Equal(lock0, m.makeLock(sharedID)) m.releaseLock(sharedID) } @@ -54,9 +50,7 @@ func TestMemoryUnknownFree(t *testing.T) { sharedID := sharedID(blockchainID0, blockchainID1) defer func() { - if recover() == nil { - t.Fatalf("Should have panicked due to an unknown free") - } + require.NotNil(t, recover()) }() m.releaseLock(sharedID) diff --git a/config/config_test.go b/config/config_test.go index d68488deef3a..075f046b7122 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -451,9 +451,10 @@ func TestGetSubnetConfigsFromFile(t *testing.T) { v := setupViper(configFilePath) subnetConfigs, err := getSubnetConfigs(v, []ids.ID{subnetID}) require.ErrorIs(err, test.expectedErr) - if test.expectedErr == nil { - test.testF(require, subnetConfigs) + if test.expectedErr != nil { + return } + test.testF(require, subnetConfigs) }) } } @@ -544,9 +545,10 @@ func TestGetSubnetConfigsFromFlags(t *testing.T) { subnetConfigs, err := getSubnetConfigs(v, []ids.ID{subnetID}) require.ErrorIs(err, test.expectedErr) - if test.expectedErr == nil { - test.testF(require, subnetConfigs) + if test.expectedErr != nil { + return } + test.testF(require, subnetConfigs) }) } } @@ -560,9 +562,11 @@ func setupConfigJSON(t *testing.T, rootPath string, value string) string { // setups file creates necessary path and writes value to it. func setupFile(t *testing.T, path string, fileName string, value string) { - require.NoError(t, os.MkdirAll(path, 0o700)) + require := require.New(t) + + require.NoError(os.MkdirAll(path, 0o700)) filePath := filepath.Join(path, fileName) - require.NoError(t, os.WriteFile(filePath, []byte(value), 0o600)) + require.NoError(os.WriteFile(filePath, []byte(value), 0o600)) } func setupViperFlags() *viper.Viper { diff --git a/genesis/genesis_test.go b/genesis/genesis_test.go index e64d52558e65..7a0b1aa2b0ff 100644 --- a/genesis/genesis_test.go +++ b/genesis/genesis_test.go @@ -164,10 +164,8 @@ func TestValidateConfig(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - require := require.New(t) - err := validateConfig(test.networkID, test.config, genesisStakingCfg) - require.ErrorIs(err, test.expectedErr) + require.ErrorIs(t, err, test.expectedErr) }) } } @@ -230,9 +228,9 @@ func TestGenesisFromFile(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - // test loading of genesis from file - require := require.New(t) + + // test loading of genesis from file var customFile string if len(test.customConfig) > 0 { customFile = filepath.Join(t.TempDir(), "config.json") @@ -304,9 +302,9 @@ func TestGenesisFromFlag(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - // test loading of genesis content from flag/env-var - require := require.New(t) + + // test loading of genesis content from flag/env-var var genBytes []byte if len(test.customConfig) == 0 { // try loading a default config diff --git a/ids/aliases_test.go b/ids/aliases_test.go index 2e50b992aaf2..b25177242ff2 100644 --- a/ids/aliases_test.go +++ b/ids/aliases_test.go @@ -30,6 +30,5 @@ func TestPrimaryAliasOrDefaultTest(t *testing.T) { require.Equal(res, id1.String()) expected := "Batman" - res = aliaser.PrimaryAliasOrDefault(id2) - require.Equal(expected, res) + require.Equal(expected, aliaser.PrimaryAliasOrDefault(id2)) } diff --git a/ids/bits_test.go b/ids/bits_test.go index 429da10046dd..4b5783dded46 100644 --- a/ids/bits_test.go +++ b/ids/bits_test.go @@ -10,6 +10,8 @@ import ( "strings" "testing" "time" + + "github.com/stretchr/testify/require" ) func flip(b uint8) uint8 { @@ -38,46 +40,40 @@ func Check(start, stop int, id1, id2 ID) bool { } func TestEqualSubsetEarlyStop(t *testing.T) { + require := require.New(t) + id1 := ID{0xf0, 0x0f} id2 := ID{0xf0, 0x1f} - if !EqualSubset(0, 12, id1, id2) { - t.Fatalf("Should have passed: %08b %08b == %08b %08b", id1[0], id1[1], id2[0], id2[1]) - } else if EqualSubset(0, 13, id1, id2) { - t.Fatalf("Should not have passed: %08b %08b == %08b %08b", id1[0], id1[1], id2[0], id2[1]) - } + require.True(EqualSubset(0, 12, id1, id2)) + require.False(EqualSubset(0, 13, id1, id2)) } func TestEqualSubsetLateStart(t *testing.T) { id1 := ID{0x1f, 0xf8} id2 := ID{0x10, 0x08} - if !EqualSubset(4, 12, id1, id2) { - t.Fatalf("Should have passed: %08b %08b == %08b %08b", id1[0], id1[1], id2[0], id2[1]) - } + require.True(t, EqualSubset(4, 12, id1, id2)) } func TestEqualSubsetSameByte(t *testing.T) { id1 := ID{0x18} id2 := ID{0xfc} - if !EqualSubset(3, 5, id1, id2) { - t.Fatalf("Should have passed: %08b == %08b", id1[0], id2[0]) - } + require.True(t, EqualSubset(3, 5, id1, id2)) } func TestEqualSubsetBadMiddle(t *testing.T) { id1 := ID{0x18, 0xe8, 0x55} id2 := ID{0x18, 0x8e, 0x55} - if EqualSubset(0, 8*3, id1, id2) { - t.Fatalf("Should not have passed: %08b == %08b", id1[1], id2[1]) - } + require.False(t, EqualSubset(0, 8*3, id1, id2)) } func TestEqualSubsetAll3Bytes(t *testing.T) { rand.Seed(time.Now().UnixNano()) seed := uint64(rand.Int63()) // #nosec G404 + t.Logf("seed: %d", seed) id1 := ID{}.Prefix(seed) for i := 0; i < BitsPerByte; i++ { @@ -87,12 +83,7 @@ func TestEqualSubsetAll3Bytes(t *testing.T) { for start := 0; start < BitsPerByte*3; start++ { for end := start; end <= BitsPerByte*3; end++ { - if EqualSubset(start, end, id1, id2) != Check(start, end, id1, id2) { - t.Fatalf("Subset failed on seed %d:\ns = %d\ne = %d\n%08b %08b %08b == %08b %08b %08b", - seed, start, end, - id1[0], id1[1], id1[2], - id2[0], id2[1], id2[2]) - } + require.Equal(t, Check(start, end, id1, id2), EqualSubset(start, end, id1, id2)) } } } @@ -104,77 +95,77 @@ func TestEqualSubsetOutOfBounds(t *testing.T) { id1 := ID{0x18, 0xe8, 0x55} id2 := ID{0x18, 0x8e, 0x55} - if EqualSubset(0, math.MaxInt32, id1, id2) { - t.Fatalf("Should not have passed") - } + require.False(t, EqualSubset(0, math.MaxInt32, id1, id2)) } func TestFirstDifferenceSubsetEarlyStop(t *testing.T) { + require := require.New(t) + id1 := ID{0xf0, 0x0f} id2 := ID{0xf0, 0x1f} - if _, found := FirstDifferenceSubset(0, 12, id1, id2); found { - t.Fatalf("Shouldn't have found a difference: %08b %08b == %08b %08b", id1[0], id1[1], id2[0], id2[1]) - } else if index, found := FirstDifferenceSubset(0, 13, id1, id2); !found { - t.Fatalf("Should have found a difference: %08b %08b == %08b %08b", id1[0], id1[1], id2[0], id2[1]) - } else if index != 12 { - t.Fatalf("Found a difference at index %d expected %d: %08b %08b == %08b %08b", index, 12, id1[0], id1[1], id2[0], id2[1]) - } + _, found := FirstDifferenceSubset(0, 12, id1, id2) + require.False(found) + + index, found := FirstDifferenceSubset(0, 13, id1, id2) + require.True(found) + require.Equal(12, index) } func TestFirstDifferenceEqualByte4(t *testing.T) { + require := require.New(t) + id1 := ID{0x10} id2 := ID{0x00} - if _, found := FirstDifferenceSubset(0, 4, id1, id2); found { - t.Fatalf("Shouldn't have found a difference: %08b == %08b", id1[0], id2[0]) - } else if index, found := FirstDifferenceSubset(0, 5, id1, id2); !found { - t.Fatalf("Should have found a difference: %08b == %08b", id1[0], id2[0]) - } else if index != 4 { - t.Fatalf("Found a difference at index %d expected %d: %08b == %08b", index, 4, id1[0], id2[0]) - } + _, found := FirstDifferenceSubset(0, 4, id1, id2) + require.False(found) + + index, found := FirstDifferenceSubset(0, 5, id1, id2) + require.True(found) + require.Equal(4, index) } func TestFirstDifferenceEqualByte5(t *testing.T) { + require := require.New(t) + id1 := ID{0x20} id2 := ID{0x00} - if _, found := FirstDifferenceSubset(0, 5, id1, id2); found { - t.Fatalf("Shouldn't have found a difference: %08b == %08b", id1[0], id2[0]) - } else if index, found := FirstDifferenceSubset(0, 6, id1, id2); !found { - t.Fatalf("Should have found a difference: %08b == %08b", id1[0], id2[0]) - } else if index != 5 { - t.Fatalf("Found a difference at index %d expected %d: %08b == %08b", index, 5, id1[0], id2[0]) - } + _, found := FirstDifferenceSubset(0, 5, id1, id2) + require.False(found) + + index, found := FirstDifferenceSubset(0, 6, id1, id2) + require.True(found) + require.Equal(5, index) } func TestFirstDifferenceSubsetMiddle(t *testing.T) { + require := require.New(t) + id1 := ID{0xf0, 0x0f, 0x11} id2 := ID{0xf0, 0x1f, 0xff} - if index, found := FirstDifferenceSubset(0, 24, id1, id2); !found { - t.Fatalf("Should have found a difference: %08b %08b %08b == %08b %08b %08b", id1[0], id1[1], id1[2], id2[0], id2[1], id2[2]) - } else if index != 12 { - t.Fatalf("Found a difference at index %d expected %d: %08b %08b %08b == %08b %08b %08b", index, 12, id1[0], id1[1], id1[2], id2[0], id2[1], id2[2]) - } + index, found := FirstDifferenceSubset(0, 24, id1, id2) + require.True(found) + require.Equal(12, index) } func TestFirstDifferenceStartMiddle(t *testing.T) { + require := require.New(t) + id1 := ID{0x1f, 0x0f, 0x11} id2 := ID{0x0f, 0x1f, 0xff} - if index, found := FirstDifferenceSubset(0, 24, id1, id2); !found { - t.Fatalf("Should have found a difference: %08b %08b %08b == %08b %08b %08b", id1[0], id1[1], id1[2], id2[0], id2[1], id2[2]) - } else if index != 4 { - t.Fatalf("Found a difference at index %d expected %d: %08b %08b %08b == %08b %08b %08b", index, 4, id1[0], id1[1], id1[2], id2[0], id2[1], id2[2]) - } + index, found := FirstDifferenceSubset(0, 24, id1, id2) + require.True(found) + require.Equal(4, index) } func TestFirstDifferenceVacuous(t *testing.T) { id1 := ID{0xf0, 0x0f, 0x11} id2 := ID{0xf0, 0x1f, 0xff} - if _, found := FirstDifferenceSubset(0, 0, id1, id2); found { - t.Fatalf("Shouldn't have found a difference") - } + _, found := FirstDifferenceSubset(0, 0, id1, id2) + require.False(t, found) } diff --git a/ids/galiasreader/alias_reader_test.go b/ids/galiasreader/alias_reader_test.go index 87a462f43ceb..f268d10fc268 100644 --- a/ids/galiasreader/alias_reader_test.go +++ b/ids/galiasreader/alias_reader_test.go @@ -19,9 +19,7 @@ func TestInterface(t *testing.T) { for _, test := range ids.AliasTests { listener, err := grpcutils.NewListener() - if err != nil { - t.Fatalf("Failed to create listener: %s", err) - } + require.NoError(err) serverCloser := grpcutils.ServerCloser{} w := ids.NewAliaser() diff --git a/ids/id_test.go b/ids/id_test.go index 60aeefd1d11d..00250aed621c 100644 --- a/ids/id_test.go +++ b/ids/id_test.go @@ -4,30 +4,29 @@ package ids import ( - "bytes" "encoding/json" - "reflect" "testing" "github.com/stretchr/testify/require" "github.com/ava-labs/avalanchego/utils" + "github.com/ava-labs/avalanchego/utils/cb58" ) func TestID(t *testing.T) { + require := require.New(t) + id := ID{24} idCopy := ID{24} prefixed := id.Prefix(0) - if id != idCopy { - t.Fatalf("ID.Prefix mutated the ID") - } - if nextPrefix := id.Prefix(0); prefixed != nextPrefix { - t.Fatalf("ID.Prefix not consistent") - } + require.Equal(idCopy, id) + require.Equal(prefixed, id.Prefix(0)) } func TestIDBit(t *testing.T) { + require := require.New(t) + id0 := ID{1 << 0} id1 := ID{1 << 1} id2 := ID{1 << 2} @@ -38,54 +37,49 @@ func TestIDBit(t *testing.T) { id7 := ID{1 << 7} id8 := ID{0, 1 << 0} - switch { - case id0.Bit(0) != 1: - t.Fatalf("Wrong bit") - case id1.Bit(1) != 1: - t.Fatalf("Wrong bit") - case id2.Bit(2) != 1: - t.Fatalf("Wrong bit") - case id3.Bit(3) != 1: - t.Fatalf("Wrong bit") - case id4.Bit(4) != 1: - t.Fatalf("Wrong bit") - case id5.Bit(5) != 1: - t.Fatalf("Wrong bit") - case id6.Bit(6) != 1: - t.Fatalf("Wrong bit") - case id7.Bit(7) != 1: - t.Fatalf("Wrong bit") - case id8.Bit(8) != 1: - t.Fatalf("Wrong bit") - } + require.Equal(1, id0.Bit(0)) + require.Equal(1, id1.Bit(1)) + require.Equal(1, id2.Bit(2)) + require.Equal(1, id3.Bit(3)) + require.Equal(1, id4.Bit(4)) + require.Equal(1, id5.Bit(5)) + require.Equal(1, id6.Bit(6)) + require.Equal(1, id7.Bit(7)) + require.Equal(1, id8.Bit(8)) } func TestFromString(t *testing.T) { + require := require.New(t) + id := ID{'a', 'v', 'a', ' ', 'l', 'a', 'b', 's'} idStr := id.String() id2, err := FromString(idStr) - if err != nil { - t.Fatal(err) - } - if id != id2 { - t.Fatal("Expected FromString to be inverse of String but it wasn't") - } + require.NoError(err) + require.Equal(id, id2) } func TestIDFromStringError(t *testing.T) { tests := []struct { - in string + in string + expectedErr error }{ - {""}, - {"foo"}, - {"foobar"}, + { + in: "", + expectedErr: cb58.ErrBase58Decoding, + }, + { + in: "foo", + expectedErr: cb58.ErrMissingChecksum, + }, + { + in: "foobar", + expectedErr: cb58.ErrBadChecksum, + }, } for _, tt := range tests { t.Run(tt.in, func(t *testing.T) { _, err := FromString(tt.in) - if err == nil { - t.Error("Unexpected success") - } + require.ErrorIs(t, err, tt.expectedErr) }) } } @@ -107,12 +101,11 @@ func TestIDMarshalJSON(t *testing.T) { } for _, tt := range tests { t.Run(tt.label, func(t *testing.T) { + require := require.New(t) + out, err := tt.in.MarshalJSON() - if err != tt.err { - t.Errorf("Expected err %s, got error %v", tt.err, err) - } else if !bytes.Equal(out, tt.out) { - t.Errorf("got %q, expected %q", out, tt.out) - } + require.ErrorIs(err, tt.err) + require.Equal(tt.out, out) }) } } @@ -134,13 +127,12 @@ func TestIDUnmarshalJSON(t *testing.T) { } for _, tt := range tests { t.Run(tt.label, func(t *testing.T) { + require := require.New(t) + foo := ID{} err := foo.UnmarshalJSON(tt.in) - if err != tt.err { - t.Errorf("Expected err %s, got error %v", tt.err, err) - } else if foo != tt.out { - t.Errorf("got %q, expected %q", foo, tt.out) - } + require.ErrorIs(err, tt.err) + require.Equal(tt.out, foo) }) } } @@ -148,10 +140,7 @@ func TestIDUnmarshalJSON(t *testing.T) { func TestIDHex(t *testing.T) { id := ID{'a', 'v', 'a', ' ', 'l', 'a', 'b', 's'} expected := "617661206c616273000000000000000000000000000000000000000000000000" - actual := id.Hex() - if actual != expected { - t.Fatalf("got %s, expected %s", actual, expected) - } + require.Equal(t, expected, id.Hex()) } func TestIDString(t *testing.T) { @@ -165,10 +154,7 @@ func TestIDString(t *testing.T) { } for _, tt := range tests { t.Run(tt.label, func(t *testing.T) { - result := tt.id.String() - if result != tt.expected { - t.Errorf("got %q, expected %q", result, tt.expected) - } + require.Equal(t, tt.expected, tt.id.String()) }) } } @@ -185,35 +171,23 @@ func TestSortIDs(t *testing.T) { {'a', 'v', 'a', ' ', 'l', 'a', 'b', 's'}, {'e', 'v', 'a', ' ', 'l', 'a', 'b', 's'}, } - if !reflect.DeepEqual(ids, expected) { - t.Fatal("[]ID was not sorted lexographically") - } + require.Equal(t, expected, ids) } func TestIDMapMarshalling(t *testing.T) { + require := require.New(t) + originalMap := map[ID]int{ {'e', 'v', 'a', ' ', 'l', 'a', 'b', 's'}: 1, {'a', 'v', 'a', ' ', 'l', 'a', 'b', 's'}: 2, } mapJSON, err := json.Marshal(originalMap) - if err != nil { - t.Fatal(err) - } + require.NoError(err) var unmarshalledMap map[ID]int - err = json.Unmarshal(mapJSON, &unmarshalledMap) - if err != nil { - t.Fatal(err) - } + require.NoError(json.Unmarshal(mapJSON, &unmarshalledMap)) - if len(originalMap) != len(unmarshalledMap) { - t.Fatalf("wrong map lengths") - } - for originalID, num := range originalMap { - if unmarshalledMap[originalID] != num { - t.Fatalf("map was incorrectly Unmarshalled") - } - } + require.Equal(originalMap, unmarshalledMap) } func TestIDLess(t *testing.T) { diff --git a/ids/node_id.go b/ids/node_id.go index 833015774f13..3a9ffbc43679 100644 --- a/ids/node_id.go +++ b/ids/node_id.go @@ -6,6 +6,7 @@ package ids import ( "bytes" "crypto/x509" + "errors" "fmt" "github.com/ava-labs/avalanchego/utils" @@ -17,6 +18,8 @@ const NodeIDPrefix = "NodeID-" var ( EmptyNodeID = NodeID{} + errShortNodeID = errors.New("insufficient NodeID length") + _ utils.Sortable[NodeID] = NodeID{} ) @@ -43,7 +46,7 @@ func (id *NodeID) UnmarshalJSON(b []byte) error { if str == nullStr { // If "null", do nothing return nil } else if len(str) <= 2+len(NodeIDPrefix) { - return fmt.Errorf("expected NodeID length to be > %d", 2+len(NodeIDPrefix)) + return fmt.Errorf("%w: expected to be > %d", errShortNodeID, 2+len(NodeIDPrefix)) } lastIndex := len(str) - 1 diff --git a/ids/node_id_test.go b/ids/node_id_test.go index 52c90c8ec656..b92fb6e19053 100644 --- a/ids/node_id_test.go +++ b/ids/node_id_test.go @@ -4,55 +4,58 @@ package ids import ( - "bytes" "encoding/json" "testing" "github.com/stretchr/testify/require" + + "github.com/ava-labs/avalanchego/utils/cb58" ) func TestNodeIDEquality(t *testing.T) { + require := require.New(t) + id := NodeID{24} idCopy := NodeID{24} - if id != idCopy { - t.Fatalf("ID.Prefix mutated the ID") - } + require.Equal(id, idCopy) id2 := NodeID{} - if id == id2 { - t.Fatal("expected Node IDs to be unequal") - } + require.NotEqual(id, id2) } func TestNodeIDFromString(t *testing.T) { + require := require.New(t) + id := NodeID{'a', 'v', 'a', ' ', 'l', 'a', 'b', 's'} idStr := id.String() id2, err := NodeIDFromString(idStr) - if err != nil { - t.Fatal(err) - } - if id != id2 { - t.Fatal("Expected FromString to be inverse of String but it wasn't") - } + require.NoError(err) + require.Equal(id, id2) expected := "NodeID-9tLMkeWFhWXd8QZc4rSiS5meuVXF5kRsz" - if idStr != expected { - t.Fatalf("expected %s but got %s", expected, idStr) - } + require.Equal(expected, idStr) } func TestNodeIDFromStringError(t *testing.T) { tests := []struct { - in string + in string + expectedErr error }{ - {""}, - {"foo"}, - {"foobar"}, + { + in: "", + expectedErr: cb58.ErrBase58Decoding, + }, + { + in: "foo", + expectedErr: cb58.ErrMissingChecksum, + }, + { + in: "foobar", + expectedErr: cb58.ErrBadChecksum, + }, } for _, tt := range tests { t.Run(tt.in, func(t *testing.T) { _, err := FromString(tt.in) - if err == nil { - t.Error("Unexpected success") - } + require.ErrorIs(t, err, tt.expectedErr) }) } } @@ -74,73 +77,68 @@ func TestNodeIDMarshalJSON(t *testing.T) { } for _, tt := range tests { t.Run(tt.label, func(t *testing.T) { + require := require.New(t) + out, err := tt.in.MarshalJSON() - if err != tt.err { - t.Errorf("Expected err %s, got error %v", tt.err, err) - } else if !bytes.Equal(out, tt.out) { - t.Errorf("got %q, expected %q", out, tt.out) - } + require.ErrorIs(err, tt.err) + require.Equal(tt.out, out) }) } } func TestNodeIDUnmarshalJSON(t *testing.T) { tests := []struct { - label string - in []byte - out NodeID - shouldErr bool + label string + in []byte + out NodeID + expectedErr error }{ - {"NodeID{}", []byte("null"), NodeID{}, false}, + {"NodeID{}", []byte("null"), NodeID{}, nil}, { "NodeID(\"ava labs\")", []byte("\"NodeID-9tLMkeWFhWXd8QZc4rSiS5meuVXF5kRsz\""), NodeID{'a', 'v', 'a', ' ', 'l', 'a', 'b', 's'}, - false, + nil, }, { "missing start quote", []byte("NodeID-9tLMkeWFhWXd8QZc4rSiS5meuVXF5kRsz\""), NodeID{}, - true, + errMissingQuotes, }, { "missing end quote", []byte("\"NodeID-9tLMkeWFhWXd8QZc4rSiS5meuVXF5kRsz"), NodeID{}, - true, + errMissingQuotes, }, { "NodeID-", []byte("\"NodeID-\""), NodeID{}, - true, + errShortNodeID, }, { "NodeID-1", []byte("\"NodeID-1\""), NodeID{}, - true, + cb58.ErrMissingChecksum, }, { "NodeID-9tLMkeWFhWXd8QZc4rSiS5meuVXF5kRsz1", []byte("\"NodeID-1\""), NodeID{}, - true, + cb58.ErrMissingChecksum, }, } for _, tt := range tests { t.Run(tt.label, func(t *testing.T) { + require := require.New(t) + foo := NodeID{} err := foo.UnmarshalJSON(tt.in) - switch { - case err == nil && tt.shouldErr: - t.Errorf("Expected no error but got error %v", err) - case err != nil && !tt.shouldErr: - t.Errorf("unxpected error: %v", err) - case foo != tt.out: - t.Errorf("got %q, expected %q", foo, tt.out) - } + require.ErrorIs(err, tt.expectedErr) + require.Equal(tt.out, foo) }) } } @@ -156,38 +154,24 @@ func TestNodeIDString(t *testing.T) { } for _, tt := range tests { t.Run(tt.label, func(t *testing.T) { - result := tt.id.String() - if result != tt.expected { - t.Errorf("got %q, expected %q", result, tt.expected) - } + require.Equal(t, tt.expected, tt.id.String()) }) } } func TestNodeIDMapMarshalling(t *testing.T) { + require := require.New(t) + originalMap := map[NodeID]int{ {'e', 'v', 'a', ' ', 'l', 'a', 'b', 's'}: 1, {'a', 'v', 'a', ' ', 'l', 'a', 'b', 's'}: 2, } mapJSON, err := json.Marshal(originalMap) - if err != nil { - t.Fatal(err) - } + require.NoError(err) var unmarshalledMap map[NodeID]int - err = json.Unmarshal(mapJSON, &unmarshalledMap) - if err != nil { - t.Fatal(err) - } - - if len(originalMap) != len(unmarshalledMap) { - t.Fatalf("wrong map lengths") - } - for originalID, num := range originalMap { - if unmarshalledMap[originalID] != num { - t.Fatalf("map was incorrectly Unmarshalled") - } - } + require.NoError(json.Unmarshal(mapJSON, &unmarshalledMap)) + require.Equal(originalMap, unmarshalledMap) } func TestNodeIDLess(t *testing.T) { diff --git a/ids/test_aliases.go b/ids/test_aliases.go index 5a0299720015..7e2b4fb9790e 100644 --- a/ids/test_aliases.go +++ b/ids/test_aliases.go @@ -40,8 +40,8 @@ func AliaserAliasesEmptyTest(require *require.Assertions, r AliaserReader, _ Ali func AliaserAliasesTest(require *require.Assertions, r AliaserReader, w AliaserWriter) { id := ID{'B', 'r', 'u', 'c', 'e', ' ', 'W', 'a', 'y', 'n', 'e'} - require.NoError(w.Alias(id, "Batman")) + require.NoError(w.Alias(id, "Batman")) require.NoError(w.Alias(id, "Dark Knight")) aliases, err := r.Aliases(id) @@ -54,8 +54,8 @@ func AliaserAliasesTest(require *require.Assertions, r AliaserReader, w AliaserW func AliaserPrimaryAliasTest(require *require.Assertions, r AliaserReader, w AliaserWriter) { id1 := ID{'J', 'a', 'm', 'e', 's', ' ', 'G', 'o', 'r', 'd', 'o', 'n'} id2 := ID{'B', 'r', 'u', 'c', 'e', ' ', 'W', 'a', 'y', 'n', 'e'} - require.NoError(w.Alias(id2, "Batman")) + require.NoError(w.Alias(id2, "Batman")) require.NoError(w.Alias(id2, "Dark Knight")) _, err := r.PrimaryAlias(id1) @@ -71,6 +71,7 @@ func AliaserPrimaryAliasTest(require *require.Assertions, r AliaserReader, w Ali func AliaserAliasClashTest(require *require.Assertions, _ AliaserReader, w AliaserWriter) { id1 := ID{'B', 'r', 'u', 'c', 'e', ' ', 'W', 'a', 'y', 'n', 'e'} id2 := ID{'D', 'i', 'c', 'k', ' ', 'G', 'r', 'a', 'y', 's', 'o', 'n'} + require.NoError(w.Alias(id1, "Batman")) err := w.Alias(id2, "Batman") @@ -81,8 +82,8 @@ func AliaserAliasClashTest(require *require.Assertions, _ AliaserReader, w Alias func AliaserRemoveAliasTest(require *require.Assertions, r AliaserReader, w AliaserWriter) { id1 := ID{'B', 'r', 'u', 'c', 'e', ' ', 'W', 'a', 'y', 'n', 'e'} id2 := ID{'J', 'a', 'm', 'e', 's', ' ', 'G', 'o', 'r', 'd', 'o', 'n'} - require.NoError(w.Alias(id1, "Batman")) + require.NoError(w.Alias(id1, "Batman")) require.NoError(w.Alias(id1, "Dark Knight")) w.RemoveAliases(id1) @@ -92,8 +93,6 @@ func AliaserRemoveAliasTest(require *require.Assertions, r AliaserReader, w Alia require.Error(err) //nolint:forbidigo // currently returns grpc errors too require.NoError(w.Alias(id2, "Batman")) - require.NoError(w.Alias(id2, "Dark Knight")) - require.NoError(w.Alias(id1, "Dark Night Rises")) } diff --git a/indexer/index_test.go b/indexer/index_test.go index 760700ab05f3..f31117c77b71 100644 --- a/indexer/index_test.go +++ b/indexer/index_test.go @@ -104,7 +104,7 @@ func TestIndex(t *testing.T) { require.Contains(containers, container.ID) require.Equal(containers[container.ID], container.Bytes) // Timestamps should be non-decreasing - require.True(container.Timestamp >= lastTimestamp) + require.GreaterOrEqual(container.Timestamp, lastTimestamp) lastTimestamp = container.Timestamp sawContainers.Add(container.ID) } diff --git a/ipcs/socket/socket_test.go b/ipcs/socket/socket_test.go index 3489fef2cc9f..4204d032285a 100644 --- a/ipcs/socket/socket_test.go +++ b/ipcs/socket/socket_test.go @@ -6,9 +6,13 @@ package socket import ( "net" "testing" + + "github.com/stretchr/testify/require" ) func TestSocketSendAndReceive(t *testing.T) { + require := require.New(t) + var ( connCh chan net.Conn socketName = "/tmp/pipe-test.sock" @@ -19,14 +23,10 @@ func TestSocketSendAndReceive(t *testing.T) { // Create socket and client; wait for client to connect socket := NewSocket(socketName, nil) socket.accept, connCh = newTestAcceptFn(t) - if err := socket.Listen(); err != nil { - t.Fatal("Failed to listen on socket:", err.Error()) - } + require.NoError(socket.Listen()) client, err := Dial(socketName) - if err != nil { - t.Fatal("Failed to dial socket:", err.Error()) - } + require.NoError(err) <-connCh // Start sending in the background @@ -38,22 +38,17 @@ func TestSocketSendAndReceive(t *testing.T) { // Receive message and compare it to what was sent receivedMsg, err := client.Recv() - if err != nil { - t.Fatal("Failed to receive from socket:", err.Error()) - } - if string(receivedMsg) != string(msg) { - t.Fatal("Received incorrect message:", string(msg)) - } + require.NoError(err) + require.Equal(msg, receivedMsg) // Test max message size client.SetMaxMessageSize(msgLen) - if _, err := client.Recv(); err != nil { - t.Fatal("Failed to receive from socket:", err.Error()) - } + _, err = client.Recv() + require.NoError(err) + client.SetMaxMessageSize(msgLen - 1) - if _, err := client.Recv(); err != ErrMessageTooLarge { - t.Fatal("Should have received message too large error, got:", err) - } + _, err = client.Recv() + require.ErrorIs(err, ErrMessageTooLarge) } // newTestAcceptFn creates a new acceptFn and a channel that receives all new @@ -63,9 +58,7 @@ func newTestAcceptFn(t *testing.T) (acceptFn, chan net.Conn) { return func(s *Socket, l net.Listener) { conn, err := l.Accept() - if err != nil { - t.Error(err) - } + require.NoError(t, err) s.connLock.Lock() s.conns[conn] = struct{}{} diff --git a/message/messages_test.go b/message/messages_test.go index e2ed952aac3b..a7cf74c95c48 100644 --- a/message/messages_test.go +++ b/message/messages_test.go @@ -25,15 +25,13 @@ import ( func TestMessage(t *testing.T) { t.Parallel() - require := require.New(t) - mb, err := newMsgBuilder( logging.NoLog{}, "test", prometheus.NewRegistry(), 5*time.Second, ) - require.NoError(err) + require.NoError(t, err) testID := ids.GenerateTestID() compressibleContainers := [][]byte{ @@ -43,10 +41,10 @@ func TestMessage(t *testing.T) { } testCertRaw, testKeyRaw, err := staking.NewCertAndKeyBytes() - require.NoError(err) + require.NoError(t, err) testTLSCert, err := staking.LoadTLSCertFromBytes(testKeyRaw, testCertRaw) - require.NoError(err) + require.NoError(t, err) nowUnix := time.Now().Unix() @@ -829,7 +827,9 @@ func TestMessage(t *testing.T) { } for _, tv := range tests { - require.True(t.Run(tv.desc, func(t2 *testing.T) { + t.Run(tv.desc, func(t *testing.T) { + require := require.New(t) + encodedMsg, err := mb.createOutbound(tv.msg, tv.compressionType, tv.bypassThrottling) require.NoError(err) @@ -842,7 +842,7 @@ func TestMessage(t *testing.T) { parsedMsg, err := mb.parseInbound(encodedMsg.Bytes(), ids.EmptyNodeID, func() {}) require.NoError(err) require.Equal(tv.op, parsedMsg.Op()) - })) + }) } } diff --git a/network/certs_test.go b/network/certs_test.go index 8405107d52fc..8ae2fc06a921 100644 --- a/network/certs_test.go +++ b/network/certs_test.go @@ -8,6 +8,8 @@ import ( "sync" "testing" + "github.com/stretchr/testify/require" + "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/network/peer" "github.com/ava-labs/avalanchego/staking" @@ -25,9 +27,7 @@ func getTLS(t *testing.T, index int) (ids.NodeID, *tls.Certificate, *tls.Config) for len(tlsCerts) <= index { cert, err := staking.NewTLSCert() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) tlsConfig := peer.TLSConfig(*cert, nil) tlsCerts = append(tlsCerts, cert) diff --git a/network/network_test.go b/network/network_test.go index ac4c98e1768e..06fdf135e6e8 100644 --- a/network/network_test.go +++ b/network/network_test.go @@ -318,13 +318,13 @@ func TestSend(t *testing.T) { t, []router.InboundHandler{ router.InboundHandlerFunc(func(context.Context, message.InboundMessage) { - t.Fatal("unexpected message received") + require.FailNow("unexpected message received") }), router.InboundHandlerFunc(func(_ context.Context, msg message.InboundMessage) { received <- msg }), router.InboundHandlerFunc(func(context.Context, message.InboundMessage) { - t.Fatal("unexpected message received") + require.FailNow("unexpected message received") }), }, ) @@ -357,13 +357,13 @@ func TestSendAndGossipWithFilter(t *testing.T) { t, []router.InboundHandler{ router.InboundHandlerFunc(func(context.Context, message.InboundMessage) { - t.Fatal("unexpected message received") + require.FailNow("unexpected message received") }), router.InboundHandlerFunc(func(_ context.Context, msg message.InboundMessage) { received <- msg }), router.InboundHandlerFunc(func(context.Context, message.InboundMessage) { - t.Fatal("unexpected message received") + require.FailNow("unexpected message received") }), }, ) diff --git a/network/peer/peer_test.go b/network/peer/peer_test.go index 89ac57ef42d8..87e701a357d0 100644 --- a/network/peer/peer_test.go +++ b/network/peer/peer_test.go @@ -191,12 +191,10 @@ func makeReadyTestPeers(t *testing.T, trackedSubnets set.Set[ids.ID]) (*testPeer peer0, peer1 := makeTestPeers(t, trackedSubnets) require.NoError(peer0.AwaitReady(context.Background())) - isReady := peer0.Ready() - require.True(isReady) + require.True(peer0.Ready()) require.NoError(peer1.AwaitReady(context.Background())) - isReady = peer1.Ready() - require.True(isReady) + require.True(peer1.Ready()) return peer0, peer1 } @@ -218,8 +216,7 @@ func TestReady(t *testing.T) { ), ) - isReady := peer0.Ready() - require.False(isReady) + require.False(peer0.Ready()) peer1 := Start( rawPeer1.config, @@ -235,12 +232,10 @@ func TestReady(t *testing.T) { ) require.NoError(peer0.AwaitReady(context.Background())) - isReady = peer0.Ready() - require.True(isReady) + require.True(peer0.Ready()) require.NoError(peer1.AwaitReady(context.Background())) - isReady = peer1.Ready() - require.True(isReady) + require.True(peer1.Ready()) peer0.StartClose() require.NoError(peer0.AwaitClosed(context.Background())) @@ -256,8 +251,7 @@ func TestSend(t *testing.T) { outboundGetMsg, err := mc.Get(ids.Empty, 1, time.Second, ids.Empty, p2p.EngineType_ENGINE_TYPE_SNOWMAN) require.NoError(err) - sent := peer0.Send(context.Background(), outboundGetMsg) - require.True(sent) + require.True(peer0.Send(context.Background(), outboundGetMsg)) inboundGetMsg := <-peer1.inboundMsgChan require.Equal(message.GetOp, inboundGetMsg.Op()) @@ -391,8 +385,7 @@ func sendAndFlush(t *testing.T, sender *testPeer, receiver *testPeer) { mc := newMessageCreator(t) outboundGetMsg, err := mc.Get(ids.Empty, 1, time.Second, ids.Empty, p2p.EngineType_ENGINE_TYPE_SNOWMAN) require.NoError(t, err) - sent := sender.Send(context.Background(), outboundGetMsg) - require.True(t, sent) + require.True(t, sender.Send(context.Background(), outboundGetMsg)) inboundGetMsg := <-receiver.inboundMsgChan require.Equal(t, message.GetOp, inboundGetMsg.Op()) } diff --git a/network/throttling/dial_throttler_test.go b/network/throttling/dial_throttler_test.go index f8e33846e0a5..db1776e8e24b 100644 --- a/network/throttling/dial_throttler_test.go +++ b/network/throttling/dial_throttler_test.go @@ -13,6 +13,8 @@ import ( // Test that the DialThrottler returned by NewDialThrottler works func TestDialThrottler(t *testing.T) { + require := require.New(t) + startTime := time.Now() // Allows 5 per second throttler := NewDialThrottler(5) @@ -21,12 +23,12 @@ func TestDialThrottler(t *testing.T) { acquiredChan := make(chan struct{}, 1) // Should return immediately because < 5 taken this second go func() { - require.NoError(t, throttler.Acquire(context.Background())) + require.NoError(throttler.Acquire(context.Background())) acquiredChan <- struct{}{} }() select { case <-time.After(10 * time.Millisecond): - t.Fatal("should have acquired immediately") + require.FailNow("should have acquired immediately") case <-acquiredChan: } close(acquiredChan) @@ -35,14 +37,14 @@ func TestDialThrottler(t *testing.T) { acquiredChan := make(chan struct{}, 1) go func() { // Should block because 5 already taken within last second - require.NoError(t, throttler.Acquire(context.Background())) + require.NoError(throttler.Acquire(context.Background())) acquiredChan <- struct{}{} }() select { case <-time.After(25 * time.Millisecond): case <-acquiredChan: - t.Fatal("should not have been able to acquire immediately") + require.FailNow("should not have been able to acquire immediately") } // Wait until the 6th Acquire() has returned. The time at which @@ -52,13 +54,13 @@ func TestDialThrottler(t *testing.T) { close(acquiredChan) // Use 1.05 seconds instead of 1 second to give some "wiggle room" // so test doesn't flake - if time.Since(startTime) > 1050*time.Millisecond { - t.Fatal("should not have blocked for so long") - } + require.LessOrEqual(time.Since(startTime), 1050*time.Millisecond) } // Test that Acquire honors its specification about its context being canceled func TestDialThrottlerCancel(t *testing.T) { + require := require.New(t) + // Allows 5 per second throttler := NewDialThrottler(5) // Use all 5 @@ -66,12 +68,12 @@ func TestDialThrottlerCancel(t *testing.T) { acquiredChan := make(chan struct{}, 1) // Should return immediately because < 5 taken this second go func() { - require.NoError(t, throttler.Acquire(context.Background())) + require.NoError(throttler.Acquire(context.Background())) acquiredChan <- struct{}{} }() select { case <-time.After(10 * time.Millisecond): - t.Fatal("should have acquired immediately") + require.FailNow("should have acquired immediately") case <-acquiredChan: } close(acquiredChan) @@ -83,7 +85,7 @@ func TestDialThrottlerCancel(t *testing.T) { // Should block because 5 already taken within last second err := throttler.Acquire(ctx) // Should error because we call cancel() below - require.ErrorIs(t, err, context.Canceled) + require.ErrorIs(err, context.Canceled) acquiredChan <- struct{}{} }() @@ -92,17 +94,19 @@ func TestDialThrottlerCancel(t *testing.T) { select { case <-acquiredChan: case <-time.After(10 * time.Millisecond): - t.Fatal("Acquire should have returned immediately upon context cancellation") + require.FailNow("Acquire should have returned immediately upon context cancellation") } close(acquiredChan) } // Test that the Throttler return by NewNoThrottler never blocks on Acquire() func TestNoDialThrottler(t *testing.T) { + require := require.New(t) + throttler := NewNoDialThrottler() for i := 0; i < 250; i++ { startTime := time.Now() - require.NoError(t, throttler.Acquire(context.Background())) // Should always immediately return - require.WithinDuration(t, time.Now(), startTime, 25*time.Millisecond) + require.NoError(throttler.Acquire(context.Background())) // Should always immediately return + require.WithinDuration(time.Now(), startTime, 25*time.Millisecond) } } diff --git a/network/throttling/inbound_conn_throttler_test.go b/network/throttling/inbound_conn_throttler_test.go index 5c28a45da20c..0b5d1ccd7fb8 100644 --- a/network/throttling/inbound_conn_throttler_test.go +++ b/network/throttling/inbound_conn_throttler_test.go @@ -22,7 +22,7 @@ type MockListener struct { func (ml *MockListener) Accept() (net.Conn, error) { if ml.OnAcceptF == nil { - ml.t.Fatal("unexpectedly called Accept") + require.FailNow(ml.t, "unexpectedly called Accept") return nil, nil } return ml.OnAcceptF() @@ -30,7 +30,7 @@ func (ml *MockListener) Accept() (net.Conn, error) { func (ml *MockListener) Close() error { if ml.OnCloseF == nil { - ml.t.Fatal("unexpectedly called Close") + require.FailNow(ml.t, "unexpectedly called Close") return nil } return ml.OnCloseF() @@ -38,7 +38,7 @@ func (ml *MockListener) Close() error { func (ml *MockListener) Addr() net.Addr { if ml.OnAddrF == nil { - ml.t.Fatal("unexpectedly called Addr") + require.FailNow(ml.t, "unexpectedly called Addr") return nil } return ml.OnAddrF() @@ -62,7 +62,7 @@ func TestInboundConnThrottlerClose(t *testing.T) { select { case <-wrappedL.(*throttledListener).ctx.Done(): default: - t.Fatal("should have closed context") + require.FailNow("should have closed context") } // Accept() should return an error because the context is cancelled @@ -85,6 +85,8 @@ func TestInboundConnThrottlerAddr(t *testing.T) { } func TestInboundConnThrottlerAccept(t *testing.T) { + require := require.New(t) + acceptCalled := false l := &MockListener{ t: t, @@ -95,6 +97,6 @@ func TestInboundConnThrottlerAccept(t *testing.T) { } wrappedL := NewThrottledListener(l, 1) _, err := wrappedL.Accept() - require.NoError(t, err) - require.True(t, acceptCalled) + require.NoError(err) + require.True(acceptCalled) } diff --git a/network/throttling/inbound_conn_upgrade_throttler_test.go b/network/throttling/inbound_conn_upgrade_throttler_test.go index 03fec7a8f287..d0e1fe93c84a 100644 --- a/network/throttling/inbound_conn_upgrade_throttler_test.go +++ b/network/throttling/inbound_conn_upgrade_throttler_test.go @@ -23,6 +23,8 @@ var ( ) func TestNoInboundConnUpgradeThrottler(t *testing.T) { + require := require.New(t) + { throttler := NewInboundConnUpgradeThrottler( logging.NoLog{}, @@ -33,8 +35,7 @@ func TestNoInboundConnUpgradeThrottler(t *testing.T) { ) // throttler should allow all for i := 0; i < 10; i++ { - allow := throttler.ShouldUpgrade(host1) - require.True(t, allow) + require.True(throttler.ShouldUpgrade(host1)) } } { @@ -47,8 +48,7 @@ func TestNoInboundConnUpgradeThrottler(t *testing.T) { ) // throttler should allow all for i := 0; i < 10; i++ { - allow := throttler.ShouldUpgrade(host1) - require.True(t, allow) + require.True(throttler.ShouldUpgrade(host1)) } } } @@ -91,7 +91,7 @@ func TestInboundConnUpgradeThrottler(t *testing.T) { throttler := throttlerIntf.(*inboundConnUpgradeThrottler) select { case <-throttler.done: - t.Fatal("shouldn't be done") + require.FailNow("shouldn't be done") default: } @@ -102,6 +102,6 @@ func TestInboundConnUpgradeThrottler(t *testing.T) { case _, chanOpen := <-throttler.done: require.False(chanOpen) default: - t.Fatal("should be done") + require.FailNow("should be done") } } diff --git a/network/throttling/inbound_msg_buffer_throttler_test.go b/network/throttling/inbound_msg_buffer_throttler_test.go index f7cf6d790d9d..76f399b6e94e 100644 --- a/network/throttling/inbound_msg_buffer_throttler_test.go +++ b/network/throttling/inbound_msg_buffer_throttler_test.go @@ -45,7 +45,7 @@ func TestMsgBufferThrottler(t *testing.T) { }() select { case <-done: - t.Fatal("should block on acquiring") + require.FailNow("should block on acquiring") case <-time.After(50 * time.Millisecond): } @@ -90,7 +90,7 @@ func TestMsgBufferThrottlerContextCancelled(t *testing.T) { }() select { case <-done: - t.Fatal("should block on acquiring") + require.FailNow("should block on acquiring") case <-time.After(50 * time.Millisecond): } @@ -102,7 +102,7 @@ func TestMsgBufferThrottlerContextCancelled(t *testing.T) { }() select { case <-done2: - t.Fatal("should block on acquiring") + require.FailNow("should block on acquiring") case <-time.After(50 * time.Millisecond): } @@ -111,11 +111,11 @@ func TestMsgBufferThrottlerContextCancelled(t *testing.T) { select { case <-done2: case <-time.After(50 * time.Millisecond): - t.Fatal("cancelling context should unblock Acquire") + require.FailNow("cancelling context should unblock Acquire") } select { case <-done: case <-time.After(50 * time.Millisecond): - t.Fatal("should be blocked") + require.FailNow("should be blocked") } } diff --git a/network/throttling/inbound_msg_byte_throttler_test.go b/network/throttling/inbound_msg_byte_throttler_test.go index 60869dc41916..fa21f7baf387 100644 --- a/network/throttling/inbound_msg_byte_throttler_test.go +++ b/network/throttling/inbound_msg_byte_throttler_test.go @@ -78,7 +78,7 @@ func TestInboundMsgByteThrottlerCancelContext(t *testing.T) { }() select { case <-vdr2Done: - t.Fatal("should block on acquiring any more bytes") + require.FailNow("should block on acquiring any more bytes") case <-time.After(50 * time.Millisecond): } @@ -97,7 +97,7 @@ func TestInboundMsgByteThrottlerCancelContext(t *testing.T) { select { case <-vdr2Done: case <-time.After(50 * time.Millisecond): - t.Fatal("channel should signal because ctx was cancelled") + require.FailNow("channel should signal because ctx was cancelled") } require.NotContains(throttler.nodeToWaitingMsgID, vdr2ID) @@ -189,7 +189,7 @@ func TestInboundMsgByteThrottler(t *testing.T) { }() select { case <-vdr1Done: - t.Fatal("should block on acquiring any more bytes") + require.FailNow("should block on acquiring any more bytes") case <-time.After(50 * time.Millisecond): } throttler.lock.Lock() @@ -207,7 +207,7 @@ func TestInboundMsgByteThrottler(t *testing.T) { }() select { case <-vdr2Done: - t.Fatal("should block on acquiring any more bytes") + require.FailNow("should block on acquiring any more bytes") case <-time.After(50 * time.Millisecond): } throttler.lock.Lock() @@ -227,7 +227,7 @@ func TestInboundMsgByteThrottler(t *testing.T) { }() select { case <-nonVdrDone: - t.Fatal("should block on acquiring any more bytes") + require.FailNow("should block on acquiring any more bytes") case <-time.After(50 * time.Millisecond): } throttler.lock.Lock() @@ -273,7 +273,7 @@ func TestInboundMsgByteThrottler(t *testing.T) { }() select { case <-nonVdrDone: - t.Fatal("should block on acquiring any more bytes") + require.FailNow("should block on acquiring any more bytes") case <-time.After(50 * time.Millisecond): } throttler.lock.Lock() @@ -350,7 +350,7 @@ func TestSybilMsgThrottlerMaxNonVdr(t *testing.T) { }() select { case <-nonVdrDone: - t.Fatal("should block on acquiring any more bytes") + require.FailNow("should block on acquiring any more bytes") case <-time.After(50 * time.Millisecond): } @@ -404,7 +404,7 @@ func TestMsgThrottlerNextMsg(t *testing.T) { }() select { case <-doneVdr: - t.Fatal("should block on acquiring any more bytes") + require.FailNow("should block on acquiring any more bytes") case <-time.After(50 * time.Millisecond): } @@ -416,7 +416,7 @@ func TestMsgThrottlerNextMsg(t *testing.T) { }() select { case <-done: - t.Fatal("should block on acquiring any more bytes") + require.FailNow("should block on acquiring any more bytes") case <-time.After(50 * time.Millisecond): } @@ -432,7 +432,7 @@ func TestMsgThrottlerNextMsg(t *testing.T) { select { case <-doneVdr: - t.Fatal("should still be blocking") + require.FailNow("should still be blocking") case <-time.After(50 * time.Millisecond): } diff --git a/pubsub/filter_test.go b/pubsub/filter_test.go index 088d0ecee2d0..edc88794fa34 100644 --- a/pubsub/filter_test.go +++ b/pubsub/filter_test.go @@ -38,56 +38,40 @@ func TestAddAddressesParseAddresses(t *testing.T) { } func TestFilterParamUpdateMulti(t *testing.T) { + require := require.New(t) + fp := NewFilterParam() addr1 := []byte("abc") addr2 := []byte("def") addr3 := []byte("xyz") - if err := fp.Add(addr1, addr2, addr3); err != nil { - t.Fatal(err) - } - if len(fp.set) != 3 { - t.Fatalf("update multi failed") - } - if _, exists := fp.set[string(addr1)]; !exists { - t.Fatalf("update multi failed") - } - if _, exists := fp.set[string(addr2)]; !exists { - t.Fatalf("update multi failed") - } - if _, exists := fp.set[string(addr3)]; !exists { - t.Fatalf("update multi failed") - } + require.NoError(fp.Add(addr1, addr2, addr3)) + require.Len(fp.set, 3) + require.Contains(fp.set, string(addr1)) + require.Contains(fp.set, string(addr2)) + require.Contains(fp.set, string(addr3)) } func TestFilterParam(t *testing.T) { + require := require.New(t) + mapFilter := bloom.NewMap() fp := NewFilterParam() fp.SetFilter(mapFilter) addr := ids.GenerateTestShortID() - if err := fp.Add(addr[:]); err != nil { - t.Fatal(err) - } - if !fp.Check(addr[:]) { - t.Fatalf("check address failed") - } + require.NoError(fp.Add(addr[:])) + require.True(fp.Check(addr[:])) delete(fp.set, string(addr[:])) mapFilter.Add(addr[:]) - if !fp.Check(addr[:]) { - t.Fatalf("check address failed") - } - if fp.Check([]byte("bye")) { - t.Fatalf("check address failed") - } + require.True(fp.Check(addr[:])) + require.False(fp.Check([]byte("bye"))) } func TestNewBloom(t *testing.T) { cm := &NewBloom{} - if cm.IsParamsValid() { - t.Fatalf("new filter check failed") - } + require.False(t, cm.IsParamsValid()) } diff --git a/snow/engine/snowman/transitive_test.go b/snow/engine/snowman/transitive_test.go index f0c24ae18d22..c80993ab2c96 100644 --- a/snow/engine/snowman/transitive_test.go +++ b/snow/engine/snowman/transitive_test.go @@ -669,7 +669,7 @@ func TestEngineBuildBlock(t *testing.T) { } sender.SendPullQueryF = func(_ context.Context, inVdrs set.Set[ids.NodeID], _ uint32, _ ids.ID) { - t.Fatalf("should not be sending pulls when we are the block producer") + require.FailNow("should not be sending pulls when we are the block producer") } pushSent := new(bool) diff --git a/utils/beacon/set_test.go b/utils/beacon/set_test.go index 2d4d3d7f240d..2dc240404988 100644 --- a/utils/beacon/set_test.go +++ b/utils/beacon/set_test.go @@ -39,65 +39,44 @@ func TestSet(t *testing.T) { s := NewSet() - idsArg := s.IDsArg() - require.Equal("", idsArg) - ipsArg := s.IPsArg() - require.Equal("", ipsArg) - len := s.Len() - require.Zero(len) + require.Equal("", s.IDsArg()) + require.Equal("", s.IPsArg()) + require.Zero(s.Len()) require.NoError(s.Add(b0)) - idsArg = s.IDsArg() - require.Equal("NodeID-111111111111111111116DBWJs", idsArg) - ipsArg = s.IPsArg() - require.Equal("0.0.0.0:0", ipsArg) - len = s.Len() - require.Equal(1, len) + require.Equal("NodeID-111111111111111111116DBWJs", s.IDsArg()) + require.Equal("0.0.0.0:0", s.IPsArg()) + require.Equal(1, s.Len()) err := s.Add(b0) require.ErrorIs(err, errDuplicateID) - idsArg = s.IDsArg() - require.Equal("NodeID-111111111111111111116DBWJs", idsArg) - ipsArg = s.IPsArg() - require.Equal("0.0.0.0:0", ipsArg) - len = s.Len() - require.Equal(1, len) + require.Equal("NodeID-111111111111111111116DBWJs", s.IDsArg()) + require.Equal("0.0.0.0:0", s.IPsArg()) + require.Equal(1, s.Len()) require.NoError(s.Add(b1)) - idsArg = s.IDsArg() - require.Equal("NodeID-111111111111111111116DBWJs,NodeID-6HgC8KRBEhXYbF4riJyJFLSHt37UNuRt", idsArg) - ipsArg = s.IPsArg() - require.Equal("0.0.0.0:0,0.0.0.0:1", ipsArg) - len = s.Len() - require.Equal(2, len) + require.Equal("NodeID-111111111111111111116DBWJs,NodeID-6HgC8KRBEhXYbF4riJyJFLSHt37UNuRt", s.IDsArg()) + require.Equal("0.0.0.0:0,0.0.0.0:1", s.IPsArg()) + require.Equal(2, s.Len()) require.NoError(s.Add(b2)) - idsArg = s.IDsArg() - require.Equal("NodeID-111111111111111111116DBWJs,NodeID-6HgC8KRBEhXYbF4riJyJFLSHt37UNuRt,NodeID-BaMPFdqMUQ46BV8iRcwbVfsam55kMqcp", idsArg) - ipsArg = s.IPsArg() - require.Equal("0.0.0.0:0,0.0.0.0:1,0.0.0.0:2", ipsArg) - len = s.Len() - require.Equal(3, len) + require.Equal("NodeID-111111111111111111116DBWJs,NodeID-6HgC8KRBEhXYbF4riJyJFLSHt37UNuRt,NodeID-BaMPFdqMUQ46BV8iRcwbVfsam55kMqcp", s.IDsArg()) + require.Equal("0.0.0.0:0,0.0.0.0:1,0.0.0.0:2", s.IPsArg()) + require.Equal(3, s.Len()) require.NoError(s.RemoveByID(b0.ID())) - idsArg = s.IDsArg() - require.Equal("NodeID-BaMPFdqMUQ46BV8iRcwbVfsam55kMqcp,NodeID-6HgC8KRBEhXYbF4riJyJFLSHt37UNuRt", idsArg) - ipsArg = s.IPsArg() - require.Equal("0.0.0.0:2,0.0.0.0:1", ipsArg) - len = s.Len() - require.Equal(2, len) + require.Equal("NodeID-BaMPFdqMUQ46BV8iRcwbVfsam55kMqcp,NodeID-6HgC8KRBEhXYbF4riJyJFLSHt37UNuRt", s.IDsArg()) + require.Equal("0.0.0.0:2,0.0.0.0:1", s.IPsArg()) + require.Equal(2, s.Len()) require.NoError(s.RemoveByIP(b1.IP())) - idsArg = s.IDsArg() - require.Equal("NodeID-BaMPFdqMUQ46BV8iRcwbVfsam55kMqcp", idsArg) - ipsArg = s.IPsArg() - require.Equal("0.0.0.0:2", ipsArg) - len = s.Len() - require.Equal(1, len) + require.Equal("NodeID-BaMPFdqMUQ46BV8iRcwbVfsam55kMqcp", s.IDsArg()) + require.Equal("0.0.0.0:2", s.IPsArg()) + require.Equal(1, s.Len()) } diff --git a/utils/cb58/cb58.go b/utils/cb58/cb58.go index 7fcc58cf249d..b4b7f3da7b17 100644 --- a/utils/cb58/cb58.go +++ b/utils/cb58/cb58.go @@ -19,8 +19,8 @@ const checksumLen = 4 var ( ErrBase58Decoding = errors.New("base58 decoding error") ErrMissingChecksum = errors.New("input string is smaller than the checksum size") + ErrBadChecksum = errors.New("invalid input checksum") errEncodingOverFlow = errors.New("encoding overflow") - errBadChecksum = errors.New("invalid input checksum") ) // Encode [bytes] to a string using cb58 format. @@ -50,7 +50,7 @@ func Decode(str string) ([]byte, error) { rawBytes := decodedBytes[:len(decodedBytes)-checksumLen] checksum := decodedBytes[len(decodedBytes)-checksumLen:] if !bytes.Equal(checksum, hashing.Checksum(rawBytes, checksumLen)) { - return nil, errBadChecksum + return nil, ErrBadChecksum } return rawBytes, nil } diff --git a/utils/cb58/cb58_test.go b/utils/cb58/cb58_test.go index 59710b6232b2..858c0b8783ba 100644 --- a/utils/cb58/cb58_test.go +++ b/utils/cb58/cb58_test.go @@ -12,6 +12,8 @@ import ( // Test encoding bytes to a string and decoding back to bytes func TestEncodeDecode(t *testing.T) { + require := require.New(t) + type test struct { bytes []byte str string @@ -44,20 +46,14 @@ func TestEncodeDecode(t *testing.T) { for _, test := range tests { // Encode the bytes strResult, err := Encode(test.bytes) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Make sure the string repr. is what we expected - require.Equal(t, test.str, strResult) + require.Equal(test.str, strResult) // Decode the string bytesResult, err := Decode(strResult) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Make sure we got the same bytes back - if !bytes.Equal(test.bytes, bytesResult) { - t.Fatal("bytes not symmetric") - } + require.True(bytes.Equal(test.bytes, bytesResult)) } } diff --git a/utils/compression/compressor_test.go b/utils/compression/compressor_test.go index 0f343874705c..0467c2c1b234 100644 --- a/utils/compression/compressor_test.go +++ b/utils/compression/compressor_test.go @@ -48,8 +48,10 @@ func TestDecompressZipBombs(t *testing.T) { newCompressorFunc := newCompressorFuncs[compressionType] t.Run(compressionType.String(), func(t *testing.T) { + require := require.New(t) + compressor, err := newCompressorFunc(maxMessageSize) - require.NoError(t, err) + require.NoError(err) var ( beforeDecompressionStats runtime.MemStats @@ -59,12 +61,12 @@ func TestDecompressZipBombs(t *testing.T) { _, err = compressor.Decompress(zipBomb) runtime.ReadMemStats(&afterDecompressionStats) - require.ErrorIs(t, err, ErrDecompressedMsgTooLarge) + require.ErrorIs(err, ErrDecompressedMsgTooLarge) // Make sure that we didn't allocate significantly more memory than // the max message size. bytesAllocatedDuringDecompression := afterDecompressionStats.TotalAlloc - beforeDecompressionStats.TotalAlloc - require.Less(t, bytesAllocatedDuringDecompression, uint64(10*maxMessageSize)) + require.Less(bytesAllocatedDuringDecompression, uint64(10*maxMessageSize)) }) } } @@ -146,9 +148,8 @@ func TestNewCompressorWithInvalidLimit(t *testing.T) { continue } t.Run(compressionType.String(), func(t *testing.T) { - require := require.New(t) _, err := compressorFunc(math.MaxInt64) - require.ErrorIs(err, ErrInvalidMaxSizeCompressor) + require.ErrorIs(t, err, ErrInvalidMaxSizeCompressor) }) } } @@ -174,7 +175,7 @@ func fuzzHelper(f *testing.F, compressionType Type) { compressor, err = NewZstdCompressor(maxMessageSize) require.NoError(f, err) default: - f.Fatal("Unknown compression type") + require.FailNow(f, "Unknown compression type") } f.Fuzz(func(t *testing.T, data []byte) { @@ -209,12 +210,14 @@ func BenchmarkCompress(b *testing.B) { } for _, size := range sizes { b.Run(fmt.Sprintf("%s_%d", compressionType, size), func(b *testing.B) { + require := require.New(b) + bytes := utils.RandomBytes(size) compressor, err := newCompressorFunc(maxMessageSize) - require.NoError(b, err) + require.NoError(err) for n := 0; n < b.N; n++ { _, err := compressor.Compress(bytes) - require.NoError(b, err) + require.NoError(err) } }) } @@ -235,16 +238,18 @@ func BenchmarkDecompress(b *testing.B) { } for _, size := range sizes { b.Run(fmt.Sprintf("%s_%d", compressionType, size), func(b *testing.B) { + require := require.New(b) + bytes := utils.RandomBytes(size) compressor, err := newCompressorFunc(maxMessageSize) - require.NoError(b, err) + require.NoError(err) compressedBytes, err := compressor.Compress(bytes) - require.NoError(b, err) + require.NoError(err) for n := 0; n < b.N; n++ { _, err := compressor.Decompress(compressedBytes) - require.NoError(b, err) + require.NoError(err) } }) } diff --git a/utils/compression/type_test.go b/utils/compression/type_test.go index 20313d053824..13d6b313aa48 100644 --- a/utils/compression/type_test.go +++ b/utils/compression/type_test.go @@ -50,9 +50,11 @@ func TestTypeMarshalJSON(t *testing.T) { for _, tt := range tests { t.Run(tt.Type.String(), func(t *testing.T) { + require := require.New(t) + b, err := tt.Type.MarshalJSON() - require.NoError(t, err) - require.Equal(t, tt.expected, string(b)) + require.NoError(err) + require.Equal(tt.expected, string(b)) }) } } diff --git a/utils/constants/network_ids.go b/utils/constants/network_ids.go index b64b6793d6a0..4d35b63cafa8 100644 --- a/utils/constants/network_ids.go +++ b/utils/constants/network_ids.go @@ -4,6 +4,7 @@ package constants import ( + "errors" "fmt" "strconv" "strings" @@ -87,6 +88,8 @@ var ( } ValidNetworkPrefix = "network-" + + ErrParseNetworkName = errors.New("failed to parse network name") ) // GetHRP returns the Human-Readable-Part of bech32 addresses for a networkID @@ -119,7 +122,7 @@ func NetworkID(networkName string) (uint32, error) { } id, err := strconv.ParseUint(idStr, 10, 32) if err != nil { - return 0, fmt.Errorf("failed to parse %q as a network name", networkName) + return 0, fmt.Errorf("%w: %q", ErrParseNetworkName, networkName) } return uint32(id), nil } diff --git a/utils/constants/network_ids_test.go b/utils/constants/network_ids_test.go index d29e0ce92b75..69557096efd8 100644 --- a/utils/constants/network_ids_test.go +++ b/utils/constants/network_ids_test.go @@ -3,7 +3,11 @@ package constants -import "testing" +import ( + "testing" + + "github.com/stretchr/testify/require" +) func TestGetHRP(t *testing.T) { tests := []struct { @@ -33,10 +37,7 @@ func TestGetHRP(t *testing.T) { } for _, test := range tests { t.Run(test.hrp, func(t *testing.T) { - if hrp := GetHRP(test.id); hrp != test.hrp { - t.Fatalf("GetHRP(%d) returned %q but expected %q", - test.id, hrp, test.hrp) - } + require.Equal(t, test.hrp, GetHRP(test.id)) }) } } @@ -69,19 +70,16 @@ func TestNetworkName(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - if name := NetworkName(test.id); name != test.name { - t.Fatalf("NetworkName(%d) returned %q but expected %q", - test.id, name, test.name) - } + require.Equal(t, test.name, NetworkName(test.id)) }) } } func TestNetworkID(t *testing.T) { tests := []struct { - name string - id uint32 - shouldErr bool + name string + id uint32 + expectedErr error }{ { name: MainnetName, @@ -112,30 +110,25 @@ func TestNetworkID(t *testing.T) { id: 4294967295, }, { - name: "networ-4294967295", - shouldErr: true, + name: "networ-4294967295", + expectedErr: ErrParseNetworkName, }, { - name: "network-4294967295123123", - shouldErr: true, + name: "network-4294967295123123", + expectedErr: ErrParseNetworkName, }, { - name: "4294967295123123", - shouldErr: true, + name: "4294967295123123", + expectedErr: ErrParseNetworkName, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { + require := require.New(t) + id, err := NetworkID(test.name) - if err == nil && test.shouldErr { - t.Fatalf("NetworkID(%q) returned %d but should have errored", test.name, test.id) - } - if err != nil && !test.shouldErr { - t.Fatalf("NetworkID(%q) unexpectedly errored with: %s", test.name, err) - } - if id != test.id { - t.Fatalf("NetworkID(%q) returned %d but expected %d", test.name, id, test.id) - } + require.ErrorIs(err, test.expectedErr) + require.Equal(test.id, id) }) } } diff --git a/utils/crypto/bls/bls_benchmark_test.go b/utils/crypto/bls/bls_benchmark_test.go index a4503260a821..a84cdadd80a5 100644 --- a/utils/crypto/bls/bls_benchmark_test.go +++ b/utils/crypto/bls/bls_benchmark_test.go @@ -28,10 +28,8 @@ var sizes = []int{ } func BenchmarkSign(b *testing.B) { - require := require.New(b) - privateKey, err := NewSecretKey() - require.NoError(err) + require.NoError(b, err) for _, messageSize := range sizes { b.Run(fmt.Sprintf("%d", messageSize), func(b *testing.B) { message := utils.RandomBytes(messageSize) @@ -46,10 +44,8 @@ func BenchmarkSign(b *testing.B) { } func BenchmarkVerify(b *testing.B) { - require := require.New(b) - privateKey, err := NewSecretKey() - require.NoError(err) + require.NoError(b, err) publicKey := PublicFromSecretKey(privateKey) for _, messageSize := range sizes { @@ -60,7 +56,7 @@ func BenchmarkVerify(b *testing.B) { b.ResetTimer() for n := 0; n < b.N; n++ { - require.True(Verify(publicKey, signature, message)) + require.True(b, Verify(publicKey, signature, message)) } }) } @@ -77,11 +73,9 @@ func BenchmarkAggregatePublicKeys(b *testing.B) { for _, size := range sizes { b.Run(fmt.Sprintf("%d", size), func(b *testing.B) { - require := require.New(b) - for n := 0; n < b.N; n++ { _, err := AggregatePublicKeys(keys[:size]) - require.NoError(err) + require.NoError(b, err) } }) } diff --git a/utils/filesystem/rename_test.go b/utils/filesystem/rename_test.go index 12de4a67cf80..305a65092727 100644 --- a/utils/filesystem/rename_test.go +++ b/utils/filesystem/rename_test.go @@ -11,29 +11,31 @@ import ( ) func TestRenameIfExists(t *testing.T) { + require := require.New(t) + t.Parallel() f, err := os.CreateTemp(os.TempDir(), "test-rename") - require.NoError(t, err) + require.NoError(err) a := f.Name() b := a + ".2" - require.NoError(t, f.Close()) + require.NoError(f.Close()) // rename "a" to "b" renamed, err := RenameIfExists(a, b) - require.True(t, renamed) - require.NoError(t, err) + require.NoError(err) + require.True(renamed) // rename "b" to "a" renamed, err = RenameIfExists(b, a) - require.True(t, renamed) - require.NoError(t, err) + require.NoError(err) + require.True(renamed) // remove "a", but rename "a"->"b" should NOT error - require.NoError(t, os.RemoveAll(a)) + require.NoError(os.RemoveAll(a)) renamed, err = RenameIfExists(a, b) - require.False(t, renamed) - require.NoError(t, err) + require.NoError(err) + require.False(renamed) } diff --git a/utils/formatting/encoding_benchmark_test.go b/utils/formatting/encoding_benchmark_test.go index 83d9d9c9d85d..598ed39310ca 100644 --- a/utils/formatting/encoding_benchmark_test.go +++ b/utils/formatting/encoding_benchmark_test.go @@ -8,6 +8,8 @@ import ( "math/rand" "testing" + "github.com/stretchr/testify/require" + "github.com/ava-labs/avalanchego/utils/units" ) @@ -58,9 +60,8 @@ func BenchmarkEncodings(b *testing.B) { _, _ = rand.Read(bytes) // #nosec G404 b.Run(fmt.Sprintf("%s-%d bytes", benchmark.encoding, benchmark.size), func(b *testing.B) { for n := 0; n < b.N; n++ { - if _, err := Encode(benchmark.encoding, bytes); err != nil { - b.Fatal(err) - } + _, err := Encode(benchmark.encoding, bytes) + require.NoError(b, err) } }) } diff --git a/utils/formatting/encoding_test.go b/utils/formatting/encoding_test.go index 72793477cfed..29f6c1d5df39 100644 --- a/utils/formatting/encoding_test.go +++ b/utils/formatting/encoding_test.go @@ -4,6 +4,7 @@ package formatting import ( + "encoding/hex" "encoding/json" "testing" @@ -11,35 +12,29 @@ import ( ) func TestEncodingMarshalJSON(t *testing.T) { + require := require.New(t) + enc := Hex jsonBytes, err := enc.MarshalJSON() - if err != nil { - t.Fatal(err) - } - if string(jsonBytes) != `"hex"` { - t.Fatal("should be 'hex'") - } + require.NoError(err) + require.Equal(`"hex"`, string(jsonBytes)) } func TestEncodingUnmarshalJSON(t *testing.T) { + require := require.New(t) + jsonBytes := []byte(`"hex"`) var enc Encoding - if err := json.Unmarshal(jsonBytes, &enc); err != nil { - t.Fatal(err) - } - if enc != Hex { - t.Fatal("should be hex") - } + require.NoError(json.Unmarshal(jsonBytes, &enc)) + require.Equal(Hex, enc) + var serr *json.SyntaxError jsonBytes = []byte("") - if err := json.Unmarshal(jsonBytes, &enc); err == nil { - t.Fatal("should have erred due to invalid encoding") - } + require.ErrorAs(json.Unmarshal(jsonBytes, &enc), &serr) jsonBytes = []byte(`""`) - if err := json.Unmarshal(jsonBytes, &enc); err == nil { - t.Fatal("should have erred due to invalid encoding") - } + err := json.Unmarshal(jsonBytes, &enc) + require.ErrorIs(err, errInvalidEncoding) } func TestEncodingString(t *testing.T) { @@ -49,6 +44,8 @@ func TestEncodingString(t *testing.T) { // Test encoding bytes to a string and decoding back to bytes func TestEncodeDecode(t *testing.T) { + require := require.New(t) + type test struct { encoding Encoding bytes []byte @@ -82,44 +79,63 @@ func TestEncodeDecode(t *testing.T) { for _, test := range tests { // Encode the bytes strResult, err := Encode(test.encoding, test.bytes) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Make sure the string repr. is what we expected - require.Equal(t, test.str, strResult) + require.Equal(test.str, strResult) // Decode the string bytesResult, err := Decode(test.encoding, strResult) - if err != nil { - t.Fatal(err) - } + require.NoError(err) // Make sure we got the same bytes back - require.Equal(t, test.bytes, bytesResult) + require.Equal(test.bytes, bytesResult) } } // Test that encoding nil bytes works func TestEncodeNil(t *testing.T) { + require := require.New(t) + str, err := Encode(Hex, nil) - if err != nil { - t.Fatal(err) - } - require.Equal(t, "0x7852b855", str) + require.NoError(err) + require.Equal("0x7852b855", str) } func TestDecodeHexInvalid(t *testing.T) { - invalidHex := []string{"0", "x", "0xg", "0x0017afa0Zd", "0xafafafafaf"} - for _, str := range invalidHex { - _, err := Decode(Hex, str) - if err == nil { - t.Fatalf("should have failed to decode invalid hex '%s'", str) - } + tests := []struct { + inputStr string + expectedErr error + }{ + { + inputStr: "0", + expectedErr: errMissingHexPrefix, + }, + { + inputStr: "x", + expectedErr: errMissingHexPrefix, + }, + { + inputStr: "0xg", + expectedErr: hex.InvalidByteError('g'), + }, + { + inputStr: "0x0017afa0Zd", + expectedErr: hex.InvalidByteError('Z'), + }, + { + inputStr: "0xafafafafaf", + expectedErr: errBadChecksum, + }, + } + for _, test := range tests { + _, err := Decode(Hex, test.inputStr) + require.ErrorIs(t, err, test.expectedErr) } } func TestDecodeNil(t *testing.T) { - if result, err := Decode(Hex, ""); err != nil || len(result) != 0 { - t.Fatal("decoding the empty string should return an empty byte slice") - } + require := require.New(t) + result, err := Decode(Hex, "") + require.NoError(err) + require.Empty(result) } func FuzzEncodeDecode(f *testing.F) { diff --git a/utils/formatting/int_format_test.go b/utils/formatting/int_format_test.go index bb730839d8d6..febf23bca4a2 100644 --- a/utils/formatting/int_format_test.go +++ b/utils/formatting/int_format_test.go @@ -3,40 +3,25 @@ package formatting -import "testing" +import ( + "testing" + + "github.com/stretchr/testify/require" +) func TestIntFormat(t *testing.T) { - if format := IntFormat(0); format != "%01d" { - t.Fatalf("Wrong int format: %s", format) - } - if format := IntFormat(9); format != "%01d" { - t.Fatalf("Wrong int format: %s", format) - } - if format := IntFormat(10); format != "%02d" { - t.Fatalf("Wrong int format: %s", format) - } - if format := IntFormat(99); format != "%02d" { - t.Fatalf("Wrong int format: %s", format) - } - if format := IntFormat(100); format != "%03d" { - t.Fatalf("Wrong int format: %s", format) - } - if format := IntFormat(999); format != "%03d" { - t.Fatalf("Wrong int format: %s", format) - } - if format := IntFormat(1000); format != "%04d" { - t.Fatalf("Wrong int format: %s", format) - } - if format := IntFormat(9999); format != "%04d" { - t.Fatalf("Wrong int format: %s", format) - } - if format := IntFormat(10000); format != "%05d" { - t.Fatalf("Wrong int format: %s", format) - } - if format := IntFormat(99999); format != "%05d" { - t.Fatalf("Wrong int format: %s", format) - } - if format := IntFormat(100000); format != "%06d" { - t.Fatalf("Wrong int format: %s", format) - } + require := require.New(t) + + require.Equal("%01d", IntFormat(0)) + require.Equal("%01d", IntFormat(9)) + require.Equal("%02d", IntFormat(10)) + require.Equal("%02d", IntFormat(99)) + require.Equal("%03d", IntFormat(100)) + require.Equal("%03d", IntFormat(999)) + require.Equal("%04d", IntFormat(1000)) + require.Equal("%04d", IntFormat(9999)) + require.Equal("%05d", IntFormat(10000)) + require.Equal("%05d", IntFormat(99999)) + require.Equal("%06d", IntFormat(100000)) + require.Equal("%06d", IntFormat(999999)) } diff --git a/utils/hashing/consistent/ring_test.go b/utils/hashing/consistent/ring_test.go index a72cbadea478..f4ba21b5d939 100644 --- a/utils/hashing/consistent/ring_test.go +++ b/utils/hashing/consistent/ring_test.go @@ -177,6 +177,7 @@ func TestGetMapsToClockwiseNode(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + require := require.New(t) ring, hasher, ctrl := setupTest(t, 1) defer ctrl.Finish() @@ -196,8 +197,8 @@ func TestGetMapsToClockwiseNode(t *testing.T) { } node, err := ring.Get(test.key) - require.Equal(t, test.expectedNode, node) - require.Nil(t, err) + require.NoError(err) + require.Equal(test.expectedNode, node) }) } } @@ -256,6 +257,8 @@ func TestRemoveExistingKeyReturnsTrue(t *testing.T) { // Tests that if we have a collision, the node is replaced. func TestAddCollisionReplacement(t *testing.T) { + require := require.New(t) + ring, hasher, ctrl := setupTest(t, 1) defer ctrl.Finish() @@ -280,13 +283,14 @@ func TestAddCollisionReplacement(t *testing.T) { ring.Add(node2) ringMember, err := ring.Get(foo) - - require.Equal(t, node2, ringMember) - require.Nil(t, err) + require.NoError(err) + require.Equal(node2, ringMember) } // Tests that virtual nodes are replicated on Add. func TestAddVirtualNodes(t *testing.T) { + require := require.New(t) + ring, hasher, ctrl := setupTest(t, 3) defer ctrl.Finish() @@ -328,29 +332,31 @@ func TestAddVirtualNodes(t *testing.T) { // Gets that should route to node-1 node, err := ring.Get(testKey{key: "foo1"}) - require.Equal(t, node1, node) - require.Nil(t, err) + require.NoError(err) + require.Equal(node1, node) node, err = ring.Get(testKey{key: "foo3"}) - require.Equal(t, node1, node) - require.Nil(t, err) + require.NoError(err) + require.Equal(node1, node) node, err = ring.Get(testKey{key: "foo5"}) - require.Equal(t, node1, node) - require.Nil(t, err) + require.NoError(err) + require.Equal(node1, node) // Gets that should route to node-2 node, err = ring.Get(testKey{key: "foo0"}) - require.Equal(t, node2, node) - require.Nil(t, err) + require.NoError(err) + require.Equal(node2, node) node, err = ring.Get(testKey{key: "foo2"}) - require.Equal(t, node2, node) - require.Nil(t, err) + require.NoError(err) + require.Equal(node2, node) node, err = ring.Get(testKey{key: "foo4"}) - require.Equal(t, node2, node) - require.Nil(t, err) + require.NoError(err) + require.Equal(node2, node) } // Tests that the node routed to changes if an Add results in a key shuffle. func TestGetShuffleOnAdd(t *testing.T) { + require := require.New(t) + ring, hasher, ctrl := setupTest(t, 1) defer ctrl.Finish() @@ -378,9 +384,8 @@ func TestGetShuffleOnAdd(t *testing.T) { // Ring: // ... -> node-1 -> foo -> ... node, err := ring.Get(foo) - - require.Equal(t, node1, node) - require.Nil(t, err) + require.NoError(err) + require.Equal(node1, node) // Add node-2, which results in foo being shuffled from node-1 to node-2. // @@ -393,13 +398,14 @@ func TestGetShuffleOnAdd(t *testing.T) { // Ring: // ... -> node-1 -> foo -> node-2 -> ... node, err = ring.Get(foo) - - require.Equal(t, node2, node) - require.Nil(t, err) + require.NoError(err) + require.Equal(node2, node) } // Tests that we can iterate around the ring. func TestIteration(t *testing.T) { + require := require.New(t) + ring, hasher, ctrl := setupTest(t, 1) defer ctrl.Finish() @@ -433,13 +439,13 @@ func TestIteration(t *testing.T) { // Ring: // ... -> foo -> node-1 -> node-2 -> ... node, err := ring.Get(foo) - require.Equal(t, node1, node) - require.Nil(t, err) + require.NoError(err) + require.Equal(node1, node) // iterate by re-using node-1 to get node-2 node, err = ring.Get(node) - require.Equal(t, node2, node) - require.Nil(t, err) + require.NoError(err) + require.Equal(node2, node) } func setupTest(t *testing.T, virtualNodes int) (Ring, *hashing.MockHasher, *gomock.Controller) { diff --git a/utils/ips/ip_test.go b/utils/ips/ip_test.go index d454d65acd5e..c3c569a8ae0a 100644 --- a/utils/ips/ip_test.go +++ b/utils/ips/ip_test.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "net" + "strconv" "testing" "github.com/stretchr/testify/require" @@ -88,37 +89,70 @@ func TestIPPortString(t *testing.T) { } for _, tt := range tests { t.Run(tt.result, func(t *testing.T) { - if result := tt.ipPort.String(); result != tt.result { - t.Errorf("Expected %q, got %q", tt.result, result) - } + require.Equal(t, tt.result, tt.ipPort.String()) }) } } func TestToIPPortError(t *testing.T) { tests := []struct { - in string - out IPPort + in string + out IPPort + expectedErr error }{ - {"", IPPort{}}, - {":", IPPort{}}, - {"abc:", IPPort{}}, - {":abc", IPPort{}}, - {"abc:abc", IPPort{}}, - {"127.0.0.1:", IPPort{}}, - {":1", IPPort{}}, - {"::1", IPPort{}}, - {"::1:42", IPPort{}}, + { + in: "", + out: IPPort{}, + expectedErr: errBadIP, + }, + { + in: ":", + out: IPPort{}, + expectedErr: strconv.ErrSyntax, + }, + { + in: "abc:", + out: IPPort{}, + expectedErr: strconv.ErrSyntax, + }, + { + in: ":abc", + out: IPPort{}, + expectedErr: strconv.ErrSyntax, + }, + { + in: "abc:abc", + out: IPPort{}, + expectedErr: strconv.ErrSyntax, + }, + { + in: "127.0.0.1:", + out: IPPort{}, + expectedErr: strconv.ErrSyntax, + }, + { + in: ":1", + out: IPPort{}, + expectedErr: errBadIP, + }, + { + in: "::1", + out: IPPort{}, + expectedErr: errBadIP, + }, + { + in: "::1:42", + out: IPPort{}, + expectedErr: errBadIP, + }, } for _, tt := range tests { t.Run(tt.in, func(t *testing.T) { + require := require.New(t) + result, err := ToIPPort(tt.in) - if err == nil { - t.Errorf("Unexpected success") - } - if !tt.out.Equal(result) { - t.Errorf("Expected %v, got %v", tt.out, result) - } + require.ErrorIs(err, tt.expectedErr) + require.Equal(tt.out, result) }) } } @@ -133,13 +167,11 @@ func TestToIPPort(t *testing.T) { } for _, tt := range tests { t.Run(tt.in, func(t *testing.T) { + require := require.New(t) + result, err := ToIPPort(tt.in) - if err != nil { - t.Errorf("Unexpected error %v", err) - } - if !tt.out.Equal(result) { - t.Errorf("Expected %#v, got %#v", tt.out, result) - } + require.NoError(err) + require.Equal(tt.out, result) }) } } diff --git a/utils/json/float32_test.go b/utils/json/float32_test.go index a3a4fdc47ca7..3d336927ced5 100644 --- a/utils/json/float32_test.go +++ b/utils/json/float32_test.go @@ -6,9 +6,13 @@ package json import ( "fmt" "testing" + + "github.com/stretchr/testify/require" ) func TestFloat32(t *testing.T) { + require := require.New(t) + type test struct { f Float32 expectedStr string @@ -45,17 +49,11 @@ func TestFloat32(t *testing.T) { for _, tt := range tests { jsonBytes, err := tt.f.MarshalJSON() - if err != nil { - t.Fatalf("couldn't marshal %f: %s", float32(tt.f), err) - } else if string(jsonBytes) != fmt.Sprintf("\"%s\"", tt.expectedStr) { - t.Fatalf("expected %f to marshal to %s but got %s", tt.f, tt.expectedStr, string(jsonBytes)) - } + require.NoError(err) + require.Equal(fmt.Sprintf(`"%s"`, tt.expectedStr), string(jsonBytes)) var f Float32 - if err := f.UnmarshalJSON(jsonBytes); err != nil { - t.Fatalf("couldn't unmarshal %s to Float32: %s", string(jsonBytes), err) - } else if float32(f) != tt.expectedUnmarshalled { - t.Fatalf("expected %s to unmarshal to %f but got %f", string(jsonBytes), tt.expectedUnmarshalled, f) - } + require.NoError(f.UnmarshalJSON(jsonBytes)) + require.Equal(tt.expectedUnmarshalled, float32(f)) } } diff --git a/utils/logging/log_test.go b/utils/logging/log_test.go index 4242ecabac31..c968747ba726 100644 --- a/utils/logging/log_test.go +++ b/utils/logging/log_test.go @@ -3,7 +3,11 @@ package logging -import "testing" +import ( + "testing" + + "github.com/stretchr/testify/require" +) func TestLog(t *testing.T) { log := NewLogger("", NewWrappedCore(Info, Discard, Plain.ConsoleEncoder())) @@ -17,7 +21,5 @@ func TestLog(t *testing.T) { } log.RecoverAndExit(panicFunc, exitFunc) - if !*recovered { - t.Fatalf("Exit function was never called") - } + require.True(t, *recovered) } diff --git a/utils/math/continuous_averager_test.go b/utils/math/continuous_averager_test.go index 7eb4f25825a5..16f0f6913b90 100644 --- a/utils/math/continuous_averager_test.go +++ b/utils/math/continuous_averager_test.go @@ -11,42 +11,44 @@ import ( ) func TestAverager(t *testing.T) { + require := require.New(t) + halflife := time.Second currentTime := time.Now() a := NewSyncAverager(NewAverager(0, halflife, currentTime)) - expectedValue := float64(0) - require.Equal(t, expectedValue, a.Read()) + require.Zero(a.Read()) currentTime = currentTime.Add(halflife) a.Observe(1, currentTime) - expectedValue = 1.0 / 1.5 - require.Equal(t, expectedValue, a.Read()) + require.Equal(1.0/1.5, a.Read()) } func TestAveragerTimeTravel(t *testing.T) { + require := require.New(t) + halflife := time.Second currentTime := time.Now() a := NewSyncAverager(NewAverager(1, halflife, currentTime)) - expectedValue := float64(1) - require.Equal(t, expectedValue, a.Read()) + require.Equal(float64(1), a.Read()) currentTime = currentTime.Add(-halflife) a.Observe(0, currentTime) - expectedValue = 1.0 / 1.5 - require.Equal(t, expectedValue, a.Read()) + require.Equal(1.0/1.5, a.Read()) } func TestUninitializedAverager(t *testing.T) { + require := require.New(t) + halfLife := time.Second currentTime := time.Now() firstObservation := float64(10) a := NewUninitializedAverager(halfLife) - require.Zero(t, a.Read()) + require.Zero(a.Read()) a.Observe(firstObservation, currentTime) - require.Equal(t, firstObservation, a.Read()) + require.Equal(firstObservation, a.Read()) } diff --git a/utils/math/meter/meter_test.go b/utils/math/meter/meter_test.go index 2bf29185d9e0..9bffb77d9d1e 100644 --- a/utils/math/meter/meter_test.go +++ b/utils/math/meter/meter_test.go @@ -5,7 +5,6 @@ package meter import ( "fmt" - "math" "testing" "time" @@ -54,116 +53,91 @@ func TestMeters(t *testing.T) { } func NewTest(t *testing.T, factory Factory) { - m := factory.New(halflife) - require.NotNil(t, m, "should have returned a valid interface") + require.NotNil(t, factory.New(halflife)) } func TimeTravelTest(t *testing.T, factory Factory) { + require := require.New(t) + m := factory.New(halflife) now := time.Date(1, 2, 3, 4, 5, 6, 7, time.UTC) m.Inc(now, 1) now = now.Add(halflife - 1) - epsilon := 0.0001 - if uptime := m.Read(now); math.Abs(uptime-.5) > epsilon { - t.Fatalf("Wrong uptime value. Expected %f got %f", .5, uptime) - } + delta := 0.0001 + require.InDelta(m.Read(now), .5, delta) m.Dec(now, 1) now = now.Add(-halflife) - if uptime := m.Read(now); math.Abs(uptime-.5) > epsilon { - t.Fatalf("Wrong uptime value. Expected %f got %f", .5, uptime) - } + require.InDelta(m.Read(now), .5, delta) m.Inc(now, 1) now = now.Add(halflife / 2) - if uptime := m.Read(now); math.Abs(uptime-.5) > epsilon { - t.Fatalf("Wrong uptime value. Expected %f got %f", .5, uptime) - } + require.InDelta(m.Read(now), .5, delta) } func StandardUsageTest(t *testing.T, factory Factory) { + require := require.New(t) + m := factory.New(halflife) now := time.Date(1, 2, 3, 4, 5, 6, 7, time.UTC) m.Inc(now, 1) now = now.Add(halflife - 1) - epsilon := 0.0001 - if uptime := m.Read(now); math.Abs(uptime-.5) > epsilon { - t.Fatalf("Wrong uptime value. Expected %f got %f", .5, uptime) - } + delta := 0.0001 + require.InDelta(m.Read(now), .5, delta) m.Inc(now, 1) - - if uptime := m.Read(now); math.Abs(uptime-.5) > epsilon { - t.Fatalf("Wrong uptime value. Expected %f got %f", .5, uptime) - } + require.InDelta(m.Read(now), .5, delta) m.Dec(now, 1) - - if uptime := m.Read(now); math.Abs(uptime-.5) > epsilon { - t.Fatalf("Wrong uptime value. Expected %f got %f", .5, uptime) - } + require.InDelta(m.Read(now), .5, delta) m.Dec(now, 1) - if uptime := m.Read(now); math.Abs(uptime-.5) > epsilon { - t.Fatalf("Wrong uptime value. Expected %f got %f", .5, uptime) - } + require.InDelta(m.Read(now), .5, delta) now = now.Add(halflife) - if uptime := m.Read(now); math.Abs(uptime-.25) > epsilon { - t.Fatalf("Wrong uptime value. Expected %f got %f", .25, uptime) - } + require.InDelta(m.Read(now), .25, delta) m.Inc(now, 1) now = now.Add(halflife) - if uptime := m.Read(now); math.Abs(uptime-.625) > epsilon { - t.Fatalf("Wrong uptime value. Expected %f got %f", .625, uptime) - } + require.InDelta(m.Read(now), .625, delta) now = now.Add(34 * halflife) - if uptime := m.Read(now); math.Abs(uptime-1) > epsilon { - t.Fatalf("Wrong uptime value. Expected %d got %f", 1, uptime) - } + require.InDelta(m.Read(now), 1, delta) m.Dec(now, 1) now = now.Add(34 * halflife) - if uptime := m.Read(now); math.Abs(uptime-0) > epsilon { - t.Fatalf("Wrong uptime value. Expected %d got %f", 0, uptime) - } + require.InDelta(m.Read(now), 0, delta) m.Inc(now, 1) now = now.Add(2 * halflife) - if uptime := m.Read(now); math.Abs(uptime-.75) > epsilon { - t.Fatalf("Wrong uptime value. Expected %f got %f", .75, uptime) - } + require.InDelta(m.Read(now), .75, delta) // Second start m.Inc(now, 1) now = now.Add(34 * halflife) - if uptime := m.Read(now); math.Abs(uptime-2) > epsilon { - t.Fatalf("Wrong uptime value. Expected %d got %f", 2, uptime) - } + require.InDelta(m.Read(now), 2, delta) // Stop the second CPU m.Dec(now, 1) now = now.Add(34 * halflife) - if uptime := m.Read(now); math.Abs(uptime-1) > epsilon { - t.Fatalf("Wrong uptime value. Expected %d got %f", 1, uptime) - } + require.InDelta(m.Read(now), 1, delta) } func TestTimeUntil(t *testing.T) { + require := require.New(t) + halflife := 5 * time.Second f := ContinuousFactory{} m := f.New(halflife) @@ -184,9 +158,9 @@ func TestTimeUntil(t *testing.T) { now = now.Add(timeUntilDesiredVal) actualVal := m.Read(now) // Make sure the actual/expected are close - require.InDelta(t, desiredVal, actualVal, .00001) + require.InDelta(desiredVal, actualVal, .00001) // Make sure TimeUntil returns the zero duration if // the value provided >= the current value - require.Zero(t, m.TimeUntil(now, actualVal)) - require.Zero(t, m.TimeUntil(now, actualVal+.1)) + require.Zero(m.TimeUntil(now, actualVal)) + require.Zero(m.TimeUntil(now, actualVal+.1)) } diff --git a/utils/sampler/uniform_test.go b/utils/sampler/uniform_test.go index 34502e313b99..e5b00af31c26 100644 --- a/utils/sampler/uniform_test.go +++ b/utils/sampler/uniform_test.go @@ -97,34 +97,35 @@ func UniformOutOfRangeTest(t *testing.T, s Uniform) { } func UniformEmptyTest(t *testing.T, s Uniform) { + require := require.New(t) + s.Initialize(1) val, err := s.Sample(0) - require.NoError(t, err) - require.Empty(t, val) + require.NoError(err) + require.Empty(val) } func UniformSingletonTest(t *testing.T, s Uniform) { + require := require.New(t) + s.Initialize(1) val, err := s.Sample(1) - require.NoError(t, err) - require.Equal(t, []uint64{0}, val, "should have selected the only element") + require.NoError(err) + require.Equal([]uint64{0}, val) } func UniformDistributionTest(t *testing.T, s Uniform) { + require := require.New(t) + s.Initialize(3) val, err := s.Sample(3) - require.NoError(t, err) + require.NoError(err) slices.Sort(val) - require.Equal( - t, - []uint64{0, 1, 2}, - val, - "should have selected the only element", - ) + require.Equal([]uint64{0, 1, 2}, val) } func UniformOverSampleTest(t *testing.T, s Uniform) { @@ -135,20 +136,22 @@ func UniformOverSampleTest(t *testing.T, s Uniform) { } func UniformLazilySample(t *testing.T, s Uniform) { + require := require.New(t) + s.Initialize(3) for j := 0; j < 2; j++ { sampled := map[uint64]bool{} for i := 0; i < 3; i++ { val, err := s.Next() - require.NoError(t, err) - require.False(t, sampled[val]) + require.NoError(err) + require.False(sampled[val]) sampled[val] = true } _, err := s.Next() - require.ErrorIs(t, err, ErrOutOfRange) + require.ErrorIs(err, ErrOutOfRange) s.Reset() } diff --git a/utils/sampler/weighted_test.go b/utils/sampler/weighted_test.go index aba782ea3f6b..f826e879f408 100644 --- a/utils/sampler/weighted_test.go +++ b/utils/sampler/weighted_test.go @@ -94,36 +94,44 @@ func WeightedInitializeOverflowTest(t *testing.T, s Weighted) { } func WeightedOutOfRangeTest(t *testing.T, s Weighted) { - require.NoError(t, s.Initialize([]uint64{1})) + require := require.New(t) + + require.NoError(s.Initialize([]uint64{1})) _, err := s.Sample(1) - require.ErrorIs(t, err, ErrOutOfRange) + require.ErrorIs(err, ErrOutOfRange) } func WeightedSingletonTest(t *testing.T, s Weighted) { - require.NoError(t, s.Initialize([]uint64{1})) + require := require.New(t) + + require.NoError(s.Initialize([]uint64{1})) index, err := s.Sample(0) - require.NoError(t, err) - require.Zero(t, index, "should have selected the first element") + require.NoError(err) + require.Zero(index) } func WeightedWithZeroTest(t *testing.T, s Weighted) { - require.NoError(t, s.Initialize([]uint64{0, 1})) + require := require.New(t) + + require.NoError(s.Initialize([]uint64{0, 1})) index, err := s.Sample(0) - require.NoError(t, err) - require.Equal(t, 1, index, "should have selected the second element") + require.NoError(err) + require.Equal(1, index) } func WeightedDistributionTest(t *testing.T, s Weighted) { - require.NoError(t, s.Initialize([]uint64{1, 1, 2, 3, 4})) + require := require.New(t) + + require.NoError(s.Initialize([]uint64{1, 1, 2, 3, 4})) counts := make([]int, 5) for i := uint64(0); i < 11; i++ { index, err := s.Sample(i) - require.NoError(t, err) + require.NoError(err) counts[index]++ } - require.Equal(t, []int{1, 1, 2, 3, 4}, counts, "wrong distribution returned") + require.Equal([]int{1, 1, 2, 3, 4}, counts) } diff --git a/utils/sampler/weighted_without_replacement_benchmark_test.go b/utils/sampler/weighted_without_replacement_benchmark_test.go index 3d9b0085ab27..03459a5e757b 100644 --- a/utils/sampler/weighted_without_replacement_benchmark_test.go +++ b/utils/sampler/weighted_without_replacement_benchmark_test.go @@ -6,6 +6,8 @@ package sampler import ( "fmt" "testing" + + "github.com/stretchr/testify/require" ) // BenchmarkAllWeightedWithoutReplacement @@ -40,13 +42,11 @@ func WeightedWithoutReplacementPowBenchmark( size int, count int, ) { + require := require.New(b) + _, weights, err := CalcWeightedPoW(exponent, size) - if err != nil { - b.Fatal(err) - } - if err := s.Initialize(weights); err != nil { - b.Fatal(err) - } + require.NoError(err) + require.NoError(s.Initialize(weights)) b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/utils/sampler/weighted_without_replacement_test.go b/utils/sampler/weighted_without_replacement_test.go index 3bba30870940..b4a6d5e4d3cc 100644 --- a/utils/sampler/weighted_without_replacement_test.go +++ b/utils/sampler/weighted_without_replacement_test.go @@ -95,75 +95,77 @@ func WeightedWithoutReplacementOutOfRangeTest( t *testing.T, s WeightedWithoutReplacement, ) { - require.NoError(t, s.Initialize([]uint64{1})) + require := require.New(t) + + require.NoError(s.Initialize([]uint64{1})) _, err := s.Sample(2) - require.ErrorIs(t, err, ErrOutOfRange) + require.ErrorIs(err, ErrOutOfRange) } func WeightedWithoutReplacementEmptyWithoutWeightTest( t *testing.T, s WeightedWithoutReplacement, ) { - require.NoError(t, s.Initialize(nil)) + require := require.New(t) + + require.NoError(s.Initialize(nil)) indices, err := s.Sample(0) - require.NoError(t, err) - require.Empty(t, indices, "shouldn't have selected any elements") + require.NoError(err) + require.Empty(indices) } func WeightedWithoutReplacementEmptyTest( t *testing.T, s WeightedWithoutReplacement, ) { - require.NoError(t, s.Initialize([]uint64{1})) + require := require.New(t) + + require.NoError(s.Initialize([]uint64{1})) indices, err := s.Sample(0) - require.NoError(t, err) - require.Empty(t, indices, "shouldn't have selected any elements") + require.NoError(err) + require.Empty(indices) } func WeightedWithoutReplacementSingletonTest( t *testing.T, s WeightedWithoutReplacement, ) { - require.NoError(t, s.Initialize([]uint64{1})) + require := require.New(t) + + require.NoError(s.Initialize([]uint64{1})) indices, err := s.Sample(1) - require.NoError(t, err) - require.Equal(t, []int{0}, indices, "should have selected the first element") + require.NoError(err) + require.Equal([]int{0}, indices) } func WeightedWithoutReplacementWithZeroTest( t *testing.T, s WeightedWithoutReplacement, ) { - require.NoError(t, s.Initialize([]uint64{0, 1})) + require := require.New(t) + + require.NoError(s.Initialize([]uint64{0, 1})) indices, err := s.Sample(1) - require.NoError(t, err) - require.Equal( - t, - []int{1}, - indices, - "should have selected the second element", - ) + require.NoError(err) + require.Equal([]int{1}, indices) } func WeightedWithoutReplacementDistributionTest( t *testing.T, s WeightedWithoutReplacement, ) { - require.NoError(t, s.Initialize([]uint64{1, 1, 2})) + require := require.New(t) + + require.NoError(s.Initialize([]uint64{1, 1, 2})) indices, err := s.Sample(4) - require.NoError(t, err) + require.NoError(err) slices.Sort(indices) - require.Equal( - t, - []int{0, 1, 2, 2}, - indices, - "should have selected all the elements", - ) + require.Equal([]int{0, 1, 2, 2}, indices) } diff --git a/utils/set/bits_64_test.go b/utils/set/bits_64_test.go index 3f62d1cca3bf..87374b4f297a 100644 --- a/utils/set/bits_64_test.go +++ b/utils/set/bits_64_test.go @@ -3,152 +3,109 @@ package set -import "testing" +import ( + "testing" + + "github.com/stretchr/testify/require" +) func TestBits64(t *testing.T) { - var bs1 Bits64 + require := require.New(t) - if bs1.Len() != 0 { - t.Fatalf("Empty set's len should be 0") - } + var bs1 Bits64 + require.Empty(bs1) bs1.Add(5) - if bs1.Len() != 1 { - t.Fatalf("Wrong set length") - } else if !bs1.Contains(5) { - t.Fatalf("Set should contain element") - } + require.Equal(1, bs1.Len()) + require.True(bs1.Contains(5)) bs1.Add(10) - switch { - case bs1.Len() != 2: - t.Fatalf("Wrong set length") - case !bs1.Contains(5): - t.Fatalf("Set should contain element") - case !bs1.Contains(10): - t.Fatalf("Set should contain element") - } + require.Equal(2, bs1.Len()) + require.True(bs1.Contains(5)) + require.True(bs1.Contains(10)) bs1.Add(10) - switch { - case bs1.Len() != 2: - t.Fatalf("Wrong set length") - case !bs1.Contains(5): - t.Fatalf("Set should contain element") - case !bs1.Contains(10): - t.Fatalf("Set should contain element") - } + require.Equal(2, bs1.Len()) + require.True(bs1.Contains(5)) + require.True(bs1.Contains(10)) var bs2 Bits64 + require.Empty(bs2) bs2.Add(0) - if bs2.Len() != 1 { - t.Fatalf("Wrong set length") - } else if !bs2.Contains(0) { - t.Fatalf("Set should contain element") - } + require.Equal(1, bs2.Len()) + require.True(bs2.Contains(0)) bs2.Union(bs1) - switch { - case bs1.Len() != 2: - t.Fatalf("Wrong set length") - case !bs1.Contains(5): - t.Fatalf("Set should contain element") - case !bs1.Contains(10): - t.Fatalf("Set should contain element") - case bs2.Len() != 3: - t.Fatalf("Wrong set length") - case !bs2.Contains(0): - t.Fatalf("Set should contain element") - case !bs2.Contains(5): - t.Fatalf("Set should contain element") - case !bs2.Contains(10): - t.Fatalf("Set should contain element") - } + require.Equal(2, bs1.Len()) + require.True(bs1.Contains(5)) + require.True(bs1.Contains(10)) + require.Equal(3, bs2.Len()) + require.True(bs2.Contains(0)) + require.True(bs2.Contains(5)) + require.True(bs2.Contains(10)) bs1.Clear() - switch { - case bs1.Len() != 0: - t.Fatalf("Wrong set length") - case bs2.Len() != 3: - t.Fatalf("Wrong set length") - case !bs2.Contains(0): - t.Fatalf("Set should contain element") - case !bs2.Contains(5): - t.Fatalf("Set should contain element") - case !bs2.Contains(10): - t.Fatalf("Set should contain element") - } + require.Empty(bs1) + require.Equal(3, bs2.Len()) + require.True(bs2.Contains(0)) + require.True(bs2.Contains(5)) + require.True(bs2.Contains(10)) bs1.Add(63) - if bs1.Len() != 1 { - t.Fatalf("Wrong set length") - } else if !bs1.Contains(63) { - t.Fatalf("Set should contain element") - } + require.Equal(1, bs1.Len()) + require.True(bs1.Contains(63)) bs1.Add(1) - switch { - case bs1.Len() != 2: - t.Fatalf("Wrong set length") - case !bs1.Contains(1): - t.Fatalf("Set should contain element") - case !bs1.Contains(63): - t.Fatalf("Set should contain element") - } + require.Equal(2, bs1.Len()) + require.True(bs1.Contains(1)) + require.True(bs1.Contains(63)) bs1.Remove(63) - if bs1.Len() != 1 { - t.Fatalf("Wrong set length") - } else if !bs1.Contains(1) { - t.Fatalf("Set should contain element") - } + require.Equal(1, bs1.Len()) + require.True(bs1.Contains(1)) var bs3 Bits64 + require.Empty(bs3) bs3.Add(0) bs3.Add(2) bs3.Add(5) var bs4 Bits64 + require.Empty(bs4) bs4.Add(2) bs4.Add(5) bs3.Intersection(bs4) - switch { - case bs3.Len() != 2: - t.Fatalf("Wrong set length") - case !bs3.Contains(2): - t.Fatalf("Set should contain element") - case !bs3.Contains(5): - t.Fatalf("Set should contain element") - case bs4.Len() != 2: - t.Fatalf("Wrong set length") - } + require.Equal(2, bs3.Len()) + require.True(bs3.Contains(2)) + require.True(bs3.Contains(5)) + require.Equal(2, bs4.Len()) + require.True(bs4.Contains(2)) + require.True(bs4.Contains(5)) var bs5 Bits64 + require.Empty(bs5) bs5.Add(7) bs5.Add(11) bs5.Add(9) var bs6 Bits64 + require.Empty(bs6) bs6.Add(9) bs6.Add(11) bs5.Difference(bs6) - - switch { - case bs5.Len() != 1: - t.Fatalf("Wrong set length") - case !bs5.Contains(7): - t.Fatalf("Set should contain element") - case bs6.Len() != 2: - t.Fatalf("Wrong set length") - } + require.Equal(1, bs5.Len()) + require.True(bs5.Contains(7)) + require.Equal(2, bs6.Len()) + require.True(bs6.Contains(9)) + require.True(bs6.Contains(11)) } func TestBits64String(t *testing.T) { @@ -156,9 +113,5 @@ func TestBits64String(t *testing.T) { bs.Add(17) - expected := "0000000000020000" - - if bsString := bs.String(); bsString != expected { - t.Fatalf("BitSet.String returned %s expected %s", bsString, expected) - } + require.Equal(t, "0000000000020000", bs.String()) } diff --git a/utils/set/set_test.go b/utils/set/set_test.go index 603c46b5a35e..bb9bac58f9d9 100644 --- a/utils/set/set_test.go +++ b/utils/set/set_test.go @@ -76,14 +76,16 @@ func TestSetCappedList(t *testing.T) { } func TestSetClear(t *testing.T) { + require := require.New(t) + set := Set[int]{} for i := 0; i < 25; i++ { set.Add(i) } set.Clear() - require.Empty(t, set) + require.Empty(set) set.Add(1337) - require.Len(t, set, 1) + require.Len(set, 1) } func TestSetPop(t *testing.T) { diff --git a/utils/timer/adaptive_timeout_manager.go b/utils/timer/adaptive_timeout_manager.go index 8bfab05734e3..0a0a299cd1da 100644 --- a/utils/timer/adaptive_timeout_manager.go +++ b/utils/timer/adaptive_timeout_manager.go @@ -19,7 +19,10 @@ import ( ) var ( - errNonPositiveHalflife = errors.New("timeout halflife must be positive") + errNonPositiveHalflife = errors.New("timeout halflife must be positive") + errInitialTimeoutAboveMaximum = errors.New("initial timeout cannot be greater than maximum timeout") + errInitialTimeoutBelowMinimum = errors.New("initial timeout cannot be less than minimum timeout") + errTooSmallTimeoutCoefficient = errors.New("timeout coefficient must be >= 1") _ heap.Interface = (*timeoutQueue)(nil) _ AdaptiveTimeoutManager = (*adaptiveTimeoutManager)(nil) @@ -129,11 +132,11 @@ func NewAdaptiveTimeoutManager( ) (AdaptiveTimeoutManager, error) { switch { case config.InitialTimeout > config.MaximumTimeout: - return nil, fmt.Errorf("initial timeout (%s) > maximum timeout (%s)", config.InitialTimeout, config.MaximumTimeout) + return nil, fmt.Errorf("%w: (%s) > (%s)", errInitialTimeoutAboveMaximum, config.InitialTimeout, config.MaximumTimeout) case config.InitialTimeout < config.MinimumTimeout: - return nil, fmt.Errorf("initial timeout (%s) < minimum timeout (%s)", config.InitialTimeout, config.MinimumTimeout) + return nil, fmt.Errorf("%w: (%s) < (%s)", errInitialTimeoutBelowMinimum, config.InitialTimeout, config.MinimumTimeout) case config.TimeoutCoefficient < 1: - return nil, fmt.Errorf("timeout coefficient must be >= 1 but got %f", config.TimeoutCoefficient) + return nil, fmt.Errorf("%w: %f", errTooSmallTimeoutCoefficient, config.TimeoutCoefficient) case config.TimeoutHalflife <= 0: return nil, errNonPositiveHalflife } diff --git a/utils/timer/adaptive_timeout_manager_test.go b/utils/timer/adaptive_timeout_manager_test.go index 72f1e0c207e7..660401eb357f 100644 --- a/utils/timer/adaptive_timeout_manager_test.go +++ b/utils/timer/adaptive_timeout_manager_test.go @@ -18,8 +18,8 @@ import ( // Test that Initialize works func TestAdaptiveTimeoutManagerInit(t *testing.T) { type test struct { - config AdaptiveTimeoutConfig - shouldErrWith string + config AdaptiveTimeoutConfig + expectedErr error } tests := []test{ @@ -31,7 +31,7 @@ func TestAdaptiveTimeoutManagerInit(t *testing.T) { TimeoutCoefficient: 2, TimeoutHalflife: 5 * time.Minute, }, - shouldErrWith: "initial timeout < minimum timeout", + expectedErr: errInitialTimeoutBelowMinimum, }, { config: AdaptiveTimeoutConfig{ @@ -41,7 +41,7 @@ func TestAdaptiveTimeoutManagerInit(t *testing.T) { TimeoutCoefficient: 2, TimeoutHalflife: 5 * time.Minute, }, - shouldErrWith: "initial timeout > maximum timeout", + expectedErr: errInitialTimeoutAboveMaximum, }, { config: AdaptiveTimeoutConfig{ @@ -51,7 +51,7 @@ func TestAdaptiveTimeoutManagerInit(t *testing.T) { TimeoutCoefficient: 0.9, TimeoutHalflife: 5 * time.Minute, }, - shouldErrWith: "timeout coefficient < 1", + expectedErr: errTooSmallTimeoutCoefficient, }, { config: AdaptiveTimeoutConfig{ @@ -60,7 +60,7 @@ func TestAdaptiveTimeoutManagerInit(t *testing.T) { MaximumTimeout: 3 * time.Second, TimeoutCoefficient: 1, }, - shouldErrWith: "timeout halflife is 0", + expectedErr: errNonPositiveHalflife, }, { config: AdaptiveTimeoutConfig{ @@ -70,7 +70,7 @@ func TestAdaptiveTimeoutManagerInit(t *testing.T) { TimeoutCoefficient: 1, TimeoutHalflife: -1 * time.Second, }, - shouldErrWith: "timeout halflife is negative", + expectedErr: errNonPositiveHalflife, }, { config: AdaptiveTimeoutConfig{ @@ -85,11 +85,7 @@ func TestAdaptiveTimeoutManagerInit(t *testing.T) { for _, test := range tests { _, err := NewAdaptiveTimeoutManager(&test.config, "", prometheus.NewRegistry()) - if err != nil && test.shouldErrWith == "" { - require.FailNow(t, "error from valid config", err) - } else if err == nil && test.shouldErrWith != "" { - require.FailNowf(t, "should have errored", test.shouldErrWith) - } + require.ErrorIs(t, err, test.expectedErr) } } @@ -105,9 +101,7 @@ func TestAdaptiveTimeoutManager(t *testing.T) { "", prometheus.NewRegistry(), ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) go tm.Dispatch() var lock sync.Mutex diff --git a/utils/timer/mockable/clock_test.go b/utils/timer/mockable/clock_test.go index b43da19cf8a1..eee8922e9e58 100644 --- a/utils/timer/mockable/clock_test.go +++ b/utils/timer/mockable/clock_test.go @@ -11,38 +11,34 @@ import ( ) func TestClockSet(t *testing.T) { + require := require.New(t) + clock := Clock{} - clock.Set(time.Unix(1000000, 0)) - if clock.faked == false { - t.Error("Fake time was set, but .faked flag was not set") - } - if !clock.Time().Equal(time.Unix(1000000, 0)) { - t.Error("Fake time was set, but not returned") - } + time := time.Unix(1000000, 0) + clock.Set(time) + require.True(clock.faked) + require.Equal(time, clock.Time()) } func TestClockSync(t *testing.T) { + require := require.New(t) + clock := Clock{true, time.Unix(0, 0)} clock.Sync() - if clock.faked == true { - t.Error("Clock was synced, but .faked flag was set") - } - if clock.Time().Equal(time.Unix(0, 0)) { - t.Error("Clock was synced, but returned a fake time") - } + require.False(clock.faked) + require.NotEqual(time.Unix(0, 0), clock.Time()) } func TestClockUnixTime(t *testing.T) { + require := require.New(t) + clock := Clock{true, time.Unix(123, 123)} - require.Zero(t, clock.UnixTime().Nanosecond()) - require.Equal(t, 123, clock.Time().Nanosecond()) + require.Zero(clock.UnixTime().Nanosecond()) + require.Equal(123, clock.Time().Nanosecond()) } func TestClockUnix(t *testing.T) { clock := Clock{true, time.Unix(-14159040, 0)} actual := clock.Unix() - if actual != 0 { - // We are Unix of 1970s, Moon landings are irrelevant - t.Errorf("Expected time prior to Unix epoch to be clamped to 0, got %d", actual) - } + require.Zero(t, actual) // time prior to Unix epoch should be clamped to 0 } diff --git a/utils/window/window_test.go b/utils/window/window_test.go index 9e36658850b8..9257dfebdfa2 100644 --- a/utils/window/window_test.go +++ b/utils/window/window_test.go @@ -42,6 +42,8 @@ func TestAdd(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + require := require.New(t) + window := New[int]( Config{ Clock: &mockable.Clock{}, @@ -55,10 +57,10 @@ func TestAdd(t *testing.T) { window.Add(test.newlyAdded) - require.Equal(t, len(test.window)+1, window.Length()) + require.Equal(len(test.window)+1, window.Length()) oldest, ok := window.Oldest() - require.Equal(t, test.expectedOldest, oldest) - require.True(t, ok) + require.True(ok) + require.Equal(test.expectedOldest, oldest) }) } } @@ -66,6 +68,8 @@ func TestAdd(t *testing.T) { // TestTTLAdd tests the case where an element is stale in the window // and needs to be evicted on Add. func TestTTLAdd(t *testing.T) { + require := require.New(t) + clock := mockable.Clock{} window := New[int]( Config{ @@ -83,10 +87,10 @@ func TestTTLAdd(t *testing.T) { window.Add(2) window.Add(3) - require.Equal(t, 3, window.Length()) + require.Equal(3, window.Length()) oldest, ok := window.Oldest() - require.Equal(t, 1, oldest) - require.True(t, ok) + require.True(ok) + require.Equal(1, oldest) // Now we're one second past the ttl of 10 seconds as defined in testTTL, // so all existing elements need to be evicted. clock.Set(epochStart.Add(11 * time.Second)) @@ -95,25 +99,27 @@ func TestTTLAdd(t *testing.T) { // [4] window.Add(4) - require.Equal(t, 1, window.Length()) + require.Equal(1, window.Length()) oldest, ok = window.Oldest() - require.Equal(t, 4, oldest) - require.True(t, ok) + require.True(ok) + require.Equal(4, oldest) // Now we're one second past the ttl of 10 seconds of when [4] was added, // so all existing elements should be evicted. clock.Set(epochStart.Add(22 * time.Second)) // Now the window should look like this: // [] - require.Zero(t, window.Length()) + require.Zero(window.Length()) oldest, ok = window.Oldest() - require.Zero(t, oldest) - require.False(t, ok) + require.False(ok) + require.Zero(oldest) } // TestTTLReadOnly tests that stale elements are still evicted on Length func TestTTLLength(t *testing.T) { + require := require.New(t) + clock := mockable.Clock{} window := New[int]( Config{ @@ -131,18 +137,20 @@ func TestTTLLength(t *testing.T) { window.Add(2) window.Add(3) - require.Equal(t, 3, window.Length()) + require.Equal(3, window.Length()) // Now we're one second past the ttl of 10 seconds as defined in testTTL, // so all existing elements need to be evicted. clock.Set(epochStart.Add(11 * time.Second)) // No more elements should be present in the window. - require.Zero(t, window.Length()) + require.Zero(window.Length()) } // TestTTLReadOnly tests that stale elements are still evicted on calling Oldest func TestTTLOldest(t *testing.T) { + require := require.New(t) + clock := mockable.Clock{} windowIntf := New[int]( Config{ @@ -151,7 +159,7 @@ func TestTTLOldest(t *testing.T) { TTL: testTTL, }, ) - require.IsType(t, &window[int]{}, windowIntf) + require.IsType(&window[int]{}, windowIntf) window := windowIntf.(*window[int]) epochStart := time.Unix(0, 0) clock.Set(epochStart) @@ -163,9 +171,9 @@ func TestTTLOldest(t *testing.T) { window.Add(3) oldest, ok := window.Oldest() - require.Equal(t, 1, oldest) - require.True(t, ok) - require.Equal(t, 3, window.elements.Len()) + require.True(ok) + require.Equal(1, oldest) + require.Equal(3, window.elements.Len()) // Now we're one second past the ttl of 10 seconds as defined in testTTL, // so all existing elements need to be evicted. @@ -173,13 +181,15 @@ func TestTTLOldest(t *testing.T) { // Now there shouldn't be any elements in the window oldest, ok = window.Oldest() - require.Zero(t, oldest) - require.False(t, ok) - require.Zero(t, window.elements.Len()) + require.False(ok) + require.Zero(oldest) + require.Zero(window.elements.Len()) } // Tests that we bound the amount of elements in the window func TestMaxCapacity(t *testing.T) { + require := require.New(t) + window := New[int]( Config{ Clock: &mockable.Clock{}, @@ -207,8 +217,8 @@ func TestMaxCapacity(t *testing.T) { // [4, 5, 6] window.Add(6) - require.Equal(t, 3, window.Length()) + require.Equal(3, window.Length()) oldest, ok := window.Oldest() - require.Equal(t, 4, oldest) - require.True(t, ok) + require.True(ok) + require.Equal(4, oldest) } diff --git a/version/application_test.go b/version/application_test.go index 95757f302781..0423e91918e5 100644 --- a/version/application_test.go +++ b/version/application_test.go @@ -11,15 +11,17 @@ import ( ) func TestNewDefaultApplication(t *testing.T) { + require := require.New(t) + v := &Application{ Major: 1, Minor: 2, Patch: 3, } - require.Equal(t, "avalanche/1.2.3", v.String()) - require.NoError(t, v.Compatible(v)) - require.False(t, v.Before(v)) + require.Equal("avalanche/1.2.3", v.String()) + require.NoError(v.Compatible(v)) + require.False(v.Before(v)) } func TestComparingVersions(t *testing.T) { @@ -130,19 +132,14 @@ func TestComparingVersions(t *testing.T) { } for _, test := range tests { t.Run(fmt.Sprintf("%s %s", test.myVersion, test.peerVersion), func(t *testing.T) { + require := require.New(t) err := test.myVersion.Compatible(test.peerVersion) - if test.compatible && err != nil { - t.Fatalf("Expected version to be compatible but returned: %s", - err) - } else if !test.compatible && err == nil { - t.Fatalf("Expected version to be incompatible but returned no error") - } - before := test.myVersion.Before(test.peerVersion) - if test.before && !before { - t.Fatalf("Expected version to be before the peer version but wasn't") - } else if !test.before && before { - t.Fatalf("Expected version not to be before the peer version but was") + if test.compatible { + require.NoError(err) + } else { + require.ErrorIs(err, errDifferentMajor) } + require.Equal(test.before, test.myVersion.Before(test.peerVersion)) }) } } diff --git a/version/compatibility_test.go b/version/compatibility_test.go index 014f28adf87c..a95a85af55dd 100644 --- a/version/compatibility_test.go +++ b/version/compatibility_test.go @@ -38,9 +38,9 @@ func TestCompatibility(t *testing.T) { require.Equal(t, v, compatibility.Version()) tests := []struct { - peer *Application - time time.Time - compatible bool + peer *Application + time time.Time + expectedErr error }{ { peer: &Application{ @@ -48,8 +48,7 @@ func TestCompatibility(t *testing.T) { Minor: 5, Patch: 0, }, - time: minCompatableTime, - compatible: true, + time: minCompatableTime, }, { peer: &Application{ @@ -57,8 +56,7 @@ func TestCompatibility(t *testing.T) { Minor: 3, Patch: 5, }, - time: time.Unix(8500, 0), - compatible: true, + time: time.Unix(8500, 0), }, { peer: &Application{ @@ -66,8 +64,8 @@ func TestCompatibility(t *testing.T) { Minor: 1, Patch: 0, }, - time: minCompatableTime, - compatible: false, + time: minCompatableTime, + expectedErr: errDifferentMajor, }, { peer: &Application{ @@ -75,8 +73,8 @@ func TestCompatibility(t *testing.T) { Minor: 3, Patch: 5, }, - time: minCompatableTime, - compatible: false, + time: minCompatableTime, + expectedErr: errIncompatible, }, { peer: &Application{ @@ -84,8 +82,8 @@ func TestCompatibility(t *testing.T) { Minor: 2, Patch: 5, }, - time: time.Unix(8500, 0), - compatible: false, + time: time.Unix(8500, 0), + expectedErr: errIncompatible, }, { peer: &Application{ @@ -93,19 +91,16 @@ func TestCompatibility(t *testing.T) { Minor: 1, Patch: 5, }, - time: time.Unix(7500, 0), - compatible: false, + time: time.Unix(7500, 0), + expectedErr: errIncompatible, }, } for _, test := range tests { peer := test.peer compatibility.clock.Set(test.time) t.Run(fmt.Sprintf("%s-%s", peer, test.time), func(t *testing.T) { - if err := compatibility.Compatible(peer); test.compatible && err != nil { - t.Fatalf("incorrectly marked %s as incompatible with %s", peer, err) - } else if !test.compatible && err == nil { - t.Fatalf("incorrectly marked %s as compatible", peer) - } + err := compatibility.Compatible(peer) + require.ErrorIs(t, err, test.expectedErr) }) } }