From 618f02c27e2e95b9aec88a54095d6fddda7c07bb Mon Sep 17 00:00:00 2001 From: Joshua Kim <20001595+joshua-kim@users.noreply.github.com> Date: Tue, 19 Dec 2023 16:59:33 -0500 Subject: [PATCH] Gossip Test structs (#2514) Signed-off-by: Joshua Kim <20001595+joshua-kim@users.noreply.github.com> --- network/p2p/gossip/gossip.go | 31 +++++++++++++++++++++++ network/p2p/gossip/gossip_test.go | 21 +++------------- network/p2p/handler.go | 31 +++++++++++++++++++++++ network/p2p/handler_test.go | 4 +-- network/p2p/network_test.go | 24 +++++++++--------- network/p2p/throttler_handler_test.go | 36 +++------------------------ 6 files changed, 83 insertions(+), 64 deletions(-) diff --git a/network/p2p/gossip/gossip.go b/network/p2p/gossip/gossip.go index f8c39e258fcc..4f5f5e5ebfc3 100644 --- a/network/p2p/gossip/gossip.go +++ b/network/p2p/gossip/gossip.go @@ -33,9 +33,11 @@ var ( _ Gossiper = (*ValidatorGossiper)(nil) _ Gossiper = (*PullGossiper[testTx, *testTx])(nil) _ Gossiper = (*NoOpGossiper)(nil) + _ Gossiper = (*TestGossiper)(nil) _ Accumulator[*testTx] = (*PushGossiper[*testTx])(nil) _ Accumulator[*testTx] = (*NoOpAccumulator[*testTx])(nil) + _ Accumulator[*testTx] = (*TestAccumulator[*testTx])(nil) metricLabels = []string{typeLabel} ) @@ -359,3 +361,32 @@ func (NoOpAccumulator[_]) Gossip(context.Context) error { } func (NoOpAccumulator[T]) Add(...T) {} + +type TestGossiper struct { + GossipF func(ctx context.Context) error +} + +func (t *TestGossiper) Gossip(ctx context.Context) error { + return t.GossipF(ctx) +} + +type TestAccumulator[T Gossipable] struct { + GossipF func(ctx context.Context) error + AddF func(...T) +} + +func (t TestAccumulator[T]) Gossip(ctx context.Context) error { + if t.GossipF == nil { + return nil + } + + return t.GossipF(ctx) +} + +func (t TestAccumulator[T]) Add(gossipables ...T) { + if t.AddF == nil { + return + } + + t.AddF(gossipables...) +} diff --git a/network/p2p/gossip/gossip_test.go b/network/p2p/gossip/gossip_test.go index 5ccebba9756c..a25a6dce07aa 100644 --- a/network/p2p/gossip/gossip_test.go +++ b/network/p2p/gossip/gossip_test.go @@ -26,11 +26,6 @@ import ( "github.com/ava-labs/avalanchego/utils/units" ) -var ( - _ p2p.ValidatorSet = (*testValidatorSet)(nil) - _ Gossiper = (*testGossiper)(nil) -) - func TestGossiperShutdown(*testing.T) { gossiper := NewPullGossiper[testTx]( logging.NoLog{}, @@ -190,8 +185,8 @@ func TestGossiperGossip(t *testing.T) { func TestEvery(*testing.T) { ctx, cancel := context.WithCancel(context.Background()) calls := 0 - gossiper := &testGossiper{ - gossipF: func(context.Context) error { + gossiper := &TestGossiper{ + GossipF: func(context.Context) error { if calls >= 10 { cancel() return nil @@ -217,8 +212,8 @@ func TestValidatorGossiper(t *testing.T) { calls := 0 gossiper := ValidatorGossiper{ - Gossiper: &testGossiper{ - gossipF: func(context.Context) error { + Gossiper: &TestGossiper{ + GossipF: func(context.Context) error { calls++ return nil }, @@ -439,14 +434,6 @@ func TestPushGossipE2E(t *testing.T) { require.Equal(want, gotForwarded) } -type testGossiper struct { - gossipF func(ctx context.Context) error -} - -func (t *testGossiper) Gossip(ctx context.Context) error { - return t.gossipF(ctx) -} - type testValidatorSet struct { validators set.Set[ids.NodeID] } diff --git a/network/p2p/handler.go b/network/p2p/handler.go index 27c220565a04..6829f0b94213 100644 --- a/network/p2p/handler.go +++ b/network/p2p/handler.go @@ -20,6 +20,7 @@ var ( ErrNotValidator = errors.New("not a validator") _ Handler = (*NoOpHandler)(nil) + _ Handler = (*TestHandler)(nil) _ Handler = (*ValidatorHandler)(nil) ) @@ -151,3 +152,33 @@ func (r *responder) CrossChainAppRequest(ctx context.Context, chainID ids.ID, re return r.sender.SendCrossChainAppResponse(ctx, chainID, requestID, appResponse) } + +type TestHandler struct { + AppGossipF func(ctx context.Context, nodeID ids.NodeID, gossipBytes []byte) + AppRequestF func(ctx context.Context, nodeID ids.NodeID, deadline time.Time, requestBytes []byte) ([]byte, error) + CrossChainAppRequestF func(ctx context.Context, chainID ids.ID, deadline time.Time, requestBytes []byte) ([]byte, error) +} + +func (t TestHandler) AppGossip(ctx context.Context, nodeID ids.NodeID, gossipBytes []byte) { + if t.AppGossipF == nil { + return + } + + t.AppGossipF(ctx, nodeID, gossipBytes) +} + +func (t TestHandler) AppRequest(ctx context.Context, nodeID ids.NodeID, deadline time.Time, requestBytes []byte) ([]byte, error) { + if t.AppRequestF == nil { + return nil, nil + } + + return t.AppRequestF(ctx, nodeID, deadline, requestBytes) +} + +func (t TestHandler) CrossChainAppRequest(ctx context.Context, chainID ids.ID, deadline time.Time, requestBytes []byte) ([]byte, error) { + if t.CrossChainAppRequestF == nil { + return nil, nil + } + + return t.CrossChainAppRequestF(ctx, chainID, deadline, requestBytes) +} diff --git a/network/p2p/handler_test.go b/network/p2p/handler_test.go index 3bbb6bc46711..3ed82cb06cbd 100644 --- a/network/p2p/handler_test.go +++ b/network/p2p/handler_test.go @@ -56,8 +56,8 @@ func TestValidatorHandlerAppGossip(t *testing.T) { called := false handler := NewValidatorHandler( - &testHandler{ - appGossipF: func(context.Context, ids.NodeID, []byte) { + &TestHandler{ + AppGossipF: func(context.Context, ids.NodeID, []byte) { called = true }, }, diff --git a/network/p2p/network_test.go b/network/p2p/network_test.go index ef42830db6ac..f86399809885 100644 --- a/network/p2p/network_test.go +++ b/network/p2p/network_test.go @@ -38,19 +38,19 @@ func TestMessageRouting(t *testing.T) { wantMsg := []byte("message") var appGossipCalled, appRequestCalled, crossChainAppRequestCalled bool - testHandler := &testHandler{ - appGossipF: func(_ context.Context, nodeID ids.NodeID, msg []byte) { + testHandler := &TestHandler{ + AppGossipF: func(_ context.Context, nodeID ids.NodeID, msg []byte) { appGossipCalled = true require.Equal(wantNodeID, nodeID) require.Equal(wantMsg, msg) }, - appRequestF: func(_ context.Context, nodeID ids.NodeID, _ time.Time, msg []byte) ([]byte, error) { + AppRequestF: func(_ context.Context, nodeID ids.NodeID, _ time.Time, msg []byte) ([]byte, error) { appRequestCalled = true require.Equal(wantNodeID, nodeID) require.Equal(wantMsg, msg) return nil, nil }, - crossChainAppRequestF: func(_ context.Context, chainID ids.ID, _ time.Time, msg []byte) ([]byte, error) { + CrossChainAppRequestF: func(_ context.Context, chainID ids.ID, _ time.Time, msg []byte) ([]byte, error) { crossChainAppRequestCalled = true require.Equal(wantChainID, chainID) require.Equal(wantMsg, msg) @@ -290,15 +290,15 @@ func TestMessageForUnregisteredHandler(t *testing.T) { t.Run(tt.name, func(t *testing.T) { require := require.New(t) ctx := context.Background() - handler := &testHandler{ - appGossipF: func(context.Context, ids.NodeID, []byte) { + handler := &TestHandler{ + AppGossipF: func(context.Context, ids.NodeID, []byte) { require.Fail("should not be called") }, - appRequestF: func(context.Context, ids.NodeID, time.Time, []byte) ([]byte, error) { + AppRequestF: func(context.Context, ids.NodeID, time.Time, []byte) ([]byte, error) { require.Fail("should not be called") return nil, nil }, - crossChainAppRequestF: func(context.Context, ids.ID, time.Time, []byte) ([]byte, error) { + CrossChainAppRequestF: func(context.Context, ids.ID, time.Time, []byte) ([]byte, error) { require.Fail("should not be called") return nil, nil }, @@ -338,15 +338,15 @@ func TestResponseForUnrequestedRequest(t *testing.T) { t.Run(tt.name, func(t *testing.T) { require := require.New(t) ctx := context.Background() - handler := &testHandler{ - appGossipF: func(context.Context, ids.NodeID, []byte) { + handler := &TestHandler{ + AppGossipF: func(context.Context, ids.NodeID, []byte) { require.Fail("should not be called") }, - appRequestF: func(context.Context, ids.NodeID, time.Time, []byte) ([]byte, error) { + AppRequestF: func(context.Context, ids.NodeID, time.Time, []byte) ([]byte, error) { require.Fail("should not be called") return nil, nil }, - crossChainAppRequestF: func(context.Context, ids.ID, time.Time, []byte) ([]byte, error) { + CrossChainAppRequestF: func(context.Context, ids.ID, time.Time, []byte) ([]byte, error) { require.Fail("should not be called") return nil, nil }, diff --git a/network/p2p/throttler_handler_test.go b/network/p2p/throttler_handler_test.go index 79b1dc88665d..8c18a10dc308 100644 --- a/network/p2p/throttler_handler_test.go +++ b/network/p2p/throttler_handler_test.go @@ -14,7 +14,7 @@ import ( "github.com/ava-labs/avalanchego/utils/logging" ) -var _ Handler = (*testHandler)(nil) +var _ Handler = (*TestHandler)(nil) func TestThrottlerHandlerAppGossip(t *testing.T) { tests := []struct { @@ -38,8 +38,8 @@ func TestThrottlerHandlerAppGossip(t *testing.T) { called := false handler := NewThrottlerHandler( - testHandler{ - appGossipF: func(context.Context, ids.NodeID, []byte) { + TestHandler{ + AppGossipF: func(context.Context, ids.NodeID, []byte) { called = true }, }, @@ -83,33 +83,3 @@ func TestThrottlerHandlerAppRequest(t *testing.T) { }) } } - -type testHandler struct { - appGossipF func(ctx context.Context, nodeID ids.NodeID, gossipBytes []byte) - appRequestF func(ctx context.Context, nodeID ids.NodeID, deadline time.Time, requestBytes []byte) ([]byte, error) - crossChainAppRequestF func(ctx context.Context, chainID ids.ID, deadline time.Time, requestBytes []byte) ([]byte, error) -} - -func (t testHandler) AppGossip(ctx context.Context, nodeID ids.NodeID, gossipBytes []byte) { - if t.appGossipF == nil { - return - } - - t.appGossipF(ctx, nodeID, gossipBytes) -} - -func (t testHandler) AppRequest(ctx context.Context, nodeID ids.NodeID, deadline time.Time, requestBytes []byte) ([]byte, error) { - if t.appRequestF == nil { - return nil, nil - } - - return t.appRequestF(ctx, nodeID, deadline, requestBytes) -} - -func (t testHandler) CrossChainAppRequest(ctx context.Context, chainID ids.ID, deadline time.Time, requestBytes []byte) ([]byte, error) { - if t.crossChainAppRequestF == nil { - return nil, nil - } - - return t.crossChainAppRequestF(ctx, chainID, deadline, requestBytes) -}