Skip to content

Commit

Permalink
Gossip Test structs (#2514)
Browse files Browse the repository at this point in the history
Signed-off-by: Joshua Kim <[email protected]>
  • Loading branch information
joshua-kim authored Dec 19, 2023
1 parent fc3ffb3 commit 618f02c
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 64 deletions.
31 changes: 31 additions & 0 deletions network/p2p/gossip/gossip.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@ var (
_ Gossiper = (*ValidatorGossiper)(nil)
_ Gossiper = (*PullGossiper[testTx, *testTx])(nil)
_ Gossiper = (*NoOpGossiper)(nil)
_ Gossiper = (*TestGossiper)(nil)

_ Accumulator[*testTx] = (*PushGossiper[*testTx])(nil)
_ Accumulator[*testTx] = (*NoOpAccumulator[*testTx])(nil)
_ Accumulator[*testTx] = (*TestAccumulator[*testTx])(nil)

metricLabels = []string{typeLabel}
)
Expand Down Expand Up @@ -359,3 +361,32 @@ func (NoOpAccumulator[_]) Gossip(context.Context) error {
}

func (NoOpAccumulator[T]) Add(...T) {}

type TestGossiper struct {
GossipF func(ctx context.Context) error
}

func (t *TestGossiper) Gossip(ctx context.Context) error {
return t.GossipF(ctx)
}

type TestAccumulator[T Gossipable] struct {
GossipF func(ctx context.Context) error
AddF func(...T)
}

func (t TestAccumulator[T]) Gossip(ctx context.Context) error {
if t.GossipF == nil {
return nil
}

return t.GossipF(ctx)
}

func (t TestAccumulator[T]) Add(gossipables ...T) {
if t.AddF == nil {
return
}

t.AddF(gossipables...)
}
21 changes: 4 additions & 17 deletions network/p2p/gossip/gossip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,6 @@ import (
"github.com/ava-labs/avalanchego/utils/units"
)

var (
_ p2p.ValidatorSet = (*testValidatorSet)(nil)
_ Gossiper = (*testGossiper)(nil)
)

func TestGossiperShutdown(*testing.T) {
gossiper := NewPullGossiper[testTx](
logging.NoLog{},
Expand Down Expand Up @@ -190,8 +185,8 @@ func TestGossiperGossip(t *testing.T) {
func TestEvery(*testing.T) {
ctx, cancel := context.WithCancel(context.Background())
calls := 0
gossiper := &testGossiper{
gossipF: func(context.Context) error {
gossiper := &TestGossiper{
GossipF: func(context.Context) error {
if calls >= 10 {
cancel()
return nil
Expand All @@ -217,8 +212,8 @@ func TestValidatorGossiper(t *testing.T) {

calls := 0
gossiper := ValidatorGossiper{
Gossiper: &testGossiper{
gossipF: func(context.Context) error {
Gossiper: &TestGossiper{
GossipF: func(context.Context) error {
calls++
return nil
},
Expand Down Expand Up @@ -439,14 +434,6 @@ func TestPushGossipE2E(t *testing.T) {
require.Equal(want, gotForwarded)
}

type testGossiper struct {
gossipF func(ctx context.Context) error
}

func (t *testGossiper) Gossip(ctx context.Context) error {
return t.gossipF(ctx)
}

type testValidatorSet struct {
validators set.Set[ids.NodeID]
}
Expand Down
31 changes: 31 additions & 0 deletions network/p2p/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ var (
ErrNotValidator = errors.New("not a validator")

_ Handler = (*NoOpHandler)(nil)
_ Handler = (*TestHandler)(nil)
_ Handler = (*ValidatorHandler)(nil)
)

Expand Down Expand Up @@ -151,3 +152,33 @@ func (r *responder) CrossChainAppRequest(ctx context.Context, chainID ids.ID, re

return r.sender.SendCrossChainAppResponse(ctx, chainID, requestID, appResponse)
}

type TestHandler struct {
AppGossipF func(ctx context.Context, nodeID ids.NodeID, gossipBytes []byte)
AppRequestF func(ctx context.Context, nodeID ids.NodeID, deadline time.Time, requestBytes []byte) ([]byte, error)
CrossChainAppRequestF func(ctx context.Context, chainID ids.ID, deadline time.Time, requestBytes []byte) ([]byte, error)
}

func (t TestHandler) AppGossip(ctx context.Context, nodeID ids.NodeID, gossipBytes []byte) {
if t.AppGossipF == nil {
return
}

t.AppGossipF(ctx, nodeID, gossipBytes)
}

func (t TestHandler) AppRequest(ctx context.Context, nodeID ids.NodeID, deadline time.Time, requestBytes []byte) ([]byte, error) {
if t.AppRequestF == nil {
return nil, nil
}

return t.AppRequestF(ctx, nodeID, deadline, requestBytes)
}

func (t TestHandler) CrossChainAppRequest(ctx context.Context, chainID ids.ID, deadline time.Time, requestBytes []byte) ([]byte, error) {
if t.CrossChainAppRequestF == nil {
return nil, nil
}

return t.CrossChainAppRequestF(ctx, chainID, deadline, requestBytes)
}
4 changes: 2 additions & 2 deletions network/p2p/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ func TestValidatorHandlerAppGossip(t *testing.T) {

called := false
handler := NewValidatorHandler(
&testHandler{
appGossipF: func(context.Context, ids.NodeID, []byte) {
&TestHandler{
AppGossipF: func(context.Context, ids.NodeID, []byte) {
called = true
},
},
Expand Down
24 changes: 12 additions & 12 deletions network/p2p/network_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,19 @@ func TestMessageRouting(t *testing.T) {
wantMsg := []byte("message")

var appGossipCalled, appRequestCalled, crossChainAppRequestCalled bool
testHandler := &testHandler{
appGossipF: func(_ context.Context, nodeID ids.NodeID, msg []byte) {
testHandler := &TestHandler{
AppGossipF: func(_ context.Context, nodeID ids.NodeID, msg []byte) {
appGossipCalled = true
require.Equal(wantNodeID, nodeID)
require.Equal(wantMsg, msg)
},
appRequestF: func(_ context.Context, nodeID ids.NodeID, _ time.Time, msg []byte) ([]byte, error) {
AppRequestF: func(_ context.Context, nodeID ids.NodeID, _ time.Time, msg []byte) ([]byte, error) {
appRequestCalled = true
require.Equal(wantNodeID, nodeID)
require.Equal(wantMsg, msg)
return nil, nil
},
crossChainAppRequestF: func(_ context.Context, chainID ids.ID, _ time.Time, msg []byte) ([]byte, error) {
CrossChainAppRequestF: func(_ context.Context, chainID ids.ID, _ time.Time, msg []byte) ([]byte, error) {
crossChainAppRequestCalled = true
require.Equal(wantChainID, chainID)
require.Equal(wantMsg, msg)
Expand Down Expand Up @@ -290,15 +290,15 @@ func TestMessageForUnregisteredHandler(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
require := require.New(t)
ctx := context.Background()
handler := &testHandler{
appGossipF: func(context.Context, ids.NodeID, []byte) {
handler := &TestHandler{
AppGossipF: func(context.Context, ids.NodeID, []byte) {
require.Fail("should not be called")
},
appRequestF: func(context.Context, ids.NodeID, time.Time, []byte) ([]byte, error) {
AppRequestF: func(context.Context, ids.NodeID, time.Time, []byte) ([]byte, error) {
require.Fail("should not be called")
return nil, nil
},
crossChainAppRequestF: func(context.Context, ids.ID, time.Time, []byte) ([]byte, error) {
CrossChainAppRequestF: func(context.Context, ids.ID, time.Time, []byte) ([]byte, error) {
require.Fail("should not be called")
return nil, nil
},
Expand Down Expand Up @@ -338,15 +338,15 @@ func TestResponseForUnrequestedRequest(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
require := require.New(t)
ctx := context.Background()
handler := &testHandler{
appGossipF: func(context.Context, ids.NodeID, []byte) {
handler := &TestHandler{
AppGossipF: func(context.Context, ids.NodeID, []byte) {
require.Fail("should not be called")
},
appRequestF: func(context.Context, ids.NodeID, time.Time, []byte) ([]byte, error) {
AppRequestF: func(context.Context, ids.NodeID, time.Time, []byte) ([]byte, error) {
require.Fail("should not be called")
return nil, nil
},
crossChainAppRequestF: func(context.Context, ids.ID, time.Time, []byte) ([]byte, error) {
CrossChainAppRequestF: func(context.Context, ids.ID, time.Time, []byte) ([]byte, error) {
require.Fail("should not be called")
return nil, nil
},
Expand Down
36 changes: 3 additions & 33 deletions network/p2p/throttler_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"github.com/ava-labs/avalanchego/utils/logging"
)

var _ Handler = (*testHandler)(nil)
var _ Handler = (*TestHandler)(nil)

func TestThrottlerHandlerAppGossip(t *testing.T) {
tests := []struct {
Expand All @@ -38,8 +38,8 @@ func TestThrottlerHandlerAppGossip(t *testing.T) {

called := false
handler := NewThrottlerHandler(
testHandler{
appGossipF: func(context.Context, ids.NodeID, []byte) {
TestHandler{
AppGossipF: func(context.Context, ids.NodeID, []byte) {
called = true
},
},
Expand Down Expand Up @@ -83,33 +83,3 @@ func TestThrottlerHandlerAppRequest(t *testing.T) {
})
}
}

type testHandler struct {
appGossipF func(ctx context.Context, nodeID ids.NodeID, gossipBytes []byte)
appRequestF func(ctx context.Context, nodeID ids.NodeID, deadline time.Time, requestBytes []byte) ([]byte, error)
crossChainAppRequestF func(ctx context.Context, chainID ids.ID, deadline time.Time, requestBytes []byte) ([]byte, error)
}

func (t testHandler) AppGossip(ctx context.Context, nodeID ids.NodeID, gossipBytes []byte) {
if t.appGossipF == nil {
return
}

t.appGossipF(ctx, nodeID, gossipBytes)
}

func (t testHandler) AppRequest(ctx context.Context, nodeID ids.NodeID, deadline time.Time, requestBytes []byte) ([]byte, error) {
if t.appRequestF == nil {
return nil, nil
}

return t.appRequestF(ctx, nodeID, deadline, requestBytes)
}

func (t testHandler) CrossChainAppRequest(ctx context.Context, chainID ids.ID, deadline time.Time, requestBytes []byte) ([]byte, error) {
if t.crossChainAppRequestF == nil {
return nil, nil
}

return t.crossChainAppRequestF(ctx, chainID, deadline, requestBytes)
}

0 comments on commit 618f02c

Please sign in to comment.