From 628353661b1292a682ddd898da1594f9795e7670 Mon Sep 17 00:00:00 2001 From: Simon Zhu Date: Mon, 20 Sep 2021 20:17:33 -0700 Subject: [PATCH] Create peer filter option --- floodsub_test.go | 9 +++++++++ gossipsub.go | 15 ++++++++++++++- gossipsub_test.go | 41 +++++++++++++++++++++++++++++++++++++++++ pubsub.go | 27 +++++++++++++++++++++++++++ 4 files changed, 91 insertions(+), 1 deletion(-) diff --git a/floodsub_test.go b/floodsub_test.go index b1bf3f37..a250044d 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -126,6 +126,15 @@ func assertReceive(t *testing.T, ch *Subscription, exp []byte) { } } +func assertNeverReceives(t *testing.T, ch *Subscription, timeout time.Duration) { + select { + case msg := <-ch.ch: + t.Logf("%#v\n", ch) + t.Fatal("got unexpected message: ", string(msg.GetData())) + case <-time.After(timeout): + } +} + func TestBasicFloodsub(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/gossipsub.go b/gossipsub.go index 98203b34..7c2da306 100644 --- a/gossipsub.go +++ b/gossipsub.go @@ -640,6 +640,10 @@ func (gs *GossipSubRouter) handleIHave(p peer.ID, ctl *pb.ControlMessage) []*pb. continue } + if !gs.p.peerFilter(p, topic) { + continue + } + for _, mid := range ihave.GetMessageIDs() { if gs.p.seenMessage(mid) { continue @@ -692,6 +696,10 @@ func (gs *GossipSubRouter) handleIWant(p peer.ID, ctl *pb.ControlMessage) []*pb. continue } + if !gs.p.peerFilter(p, msg.GetTopic()) { + continue + } + if count > gs.params.GossipRetransmission { log.Debugf("IWANT: Peer %s has asked for message %s too many times; ignoring request", p, mid) continue @@ -724,6 +732,11 @@ func (gs *GossipSubRouter) handleGraft(p peer.ID, ctl *pb.ControlMessage) []*pb. for _, graft := range ctl.GetGraft() { topic := graft.GetTopicID() + + if !gs.p.peerFilter(p, topic) { + continue + } + peers, ok := gs.mesh[topic] if !ok { // don't do PX when there is an unknown topic to avoid leaking our peers @@ -1857,7 +1870,7 @@ func (gs *GossipSubRouter) getPeers(topic string, count int, filter func(peer.ID peers := make([]peer.ID, 0, len(tmap)) for p := range tmap { - if gs.feature(GossipSubFeatureMesh, gs.peers[p]) && filter(p) { + if gs.feature(GossipSubFeatureMesh, gs.peers[p]) && filter(p) && gs.p.peerFilter(p, topic) { peers = append(peers, p) } } diff --git a/gossipsub_test.go b/gossipsub_test.go index 49830b02..0dfd6ded 100644 --- a/gossipsub_test.go +++ b/gossipsub_test.go @@ -1183,6 +1183,47 @@ func TestGossipsubDirectPeers(t *testing.T) { } } +func TestGossipSubPeerFilter(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + h := getNetHosts(t, ctx, 3) + psubs := []*PubSub{ + getGossipsub(ctx, h[0], WithPeerFilter(func(pid peer.ID, topic string) bool { + return pid == h[1].ID() + })), + getGossipsub(ctx, h[1], WithPeerFilter(func(pid peer.ID, topic string) bool { + return pid == h[0].ID() + })), + getGossipsub(ctx, h[2]), + } + + connect(t, h[0], h[1]) + connect(t, h[0], h[2]) + + // Join all peers + var subs []*Subscription + for _, ps := range psubs { + sub, err := ps.Subscribe("test") + if err != nil { + t.Fatal(err) + } + subs = append(subs, sub) + } + + time.Sleep(time.Second) + + msg := []byte("message") + + psubs[0].Publish("test", msg) + assertReceive(t, subs[1], msg) + assertNeverReceives(t, subs[2], time.Second) + + psubs[1].Publish("test", msg) + assertReceive(t, subs[0], msg) + assertNeverReceives(t, subs[2], time.Second) +} + func TestGossipsubDirectPeersFanout(t *testing.T) { // regression test for #371 ctx, cancel := context.WithCancel(context.Background()) diff --git a/pubsub.go b/pubsub.go index f0296262..05ead71d 100644 --- a/pubsub.go +++ b/pubsub.go @@ -57,6 +57,8 @@ type PubSub struct { tracer *pubsubTracer + peerFilter PeerFilter + // maxMessageSize is the maximum message size; it applies globally to all // topics. maxMessageSize int @@ -235,6 +237,7 @@ func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option ctx: ctx, rt: rt, val: newValidation(), + peerFilter: DefaultPeerFilter, disc: &discover{}, maxMessageSize: DefaultMaxMessageSize, peerOutboundQueueSize: 32, @@ -332,6 +335,21 @@ func WithMessageIdFn(fn MsgIdFunction) Option { } } +// PeerFilter is used to filter pubsub peers. It should return true for peers that are accepted for +// a given topic. PubSub can be customized to use any implementation of this function by configuring +// it with the Option from WithPeerFilter. +type PeerFilter func(pid peer.ID, topic string) bool + +// WithPeerFilter is an option to set a filter for pubsub peers. +// The default peer filter is DefaultPeerFilter (which always returns true), but it can be customized +// to any custom implementation. +func WithPeerFilter(filter PeerFilter) Option { + return func(p *PubSub) error { + p.peerFilter = filter + return nil + } +} + // WithPeerOutboundQueueSize is an option to set the buffer size for outbound messages to a peer // We start dropping messages to a peer if the outbound queue if full func WithPeerOutboundQueueSize(size int) Option { @@ -983,6 +1001,10 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) { for _, subopt := range subs { t := subopt.GetTopicid() + if !p.peerFilter(rpc.from, t) { + continue + } + if subopt.GetSubscribe() { tmap, ok := p.topics[t] if !ok { @@ -1042,6 +1064,11 @@ func DefaultMsgIdFn(pmsg *pb.Message) string { return string(pmsg.GetFrom()) + string(pmsg.GetSeqno()) } +// DefaultPeerFilter accepts all peers on all topics +func DefaultPeerFilter(pid peer.ID, topic string) bool { + return true +} + // pushMsg pushes a message performing validation as necessary func (p *PubSub) pushMsg(msg *Message) { src := msg.ReceivedFrom