diff --git a/dht.go b/dht.go index d304c3a1b02..a541cacc6d4 100644 --- a/dht.go +++ b/dht.go @@ -34,9 +34,6 @@ import ( var log = logging.Logger("dht") -var ProtocolDHT protocol.ID = "/ipfs/kad/1.0.0" -var ProtocolDHTOld protocol.ID = "/ipfs/dht" - // NumBootstrapQueries defines the number of random dht queries to do to // collect members of the routing table. const NumBootstrapQueries = 5 @@ -64,6 +61,8 @@ type IpfsDHT struct { smlk sync.Mutex plk sync.Mutex + + protocols []protocol.ID // DHT protocols } // New creates a new DHT with the specified host and options. @@ -72,7 +71,7 @@ func New(ctx context.Context, h host.Host, options ...opts.Option) (*IpfsDHT, er if err := cfg.Apply(append([]opts.Option{opts.Defaults}, options...)...); err != nil { return nil, err } - dht := makeDHT(ctx, h, cfg.Datastore) + dht := makeDHT(ctx, h, cfg.Datastore, cfg.Protocols) // register for network notifs. dht.host.Network().Notify((*netNotifiee)(dht)) @@ -87,8 +86,9 @@ func New(ctx context.Context, h host.Host, options ...opts.Option) (*IpfsDHT, er dht.Validator = cfg.Validator if !cfg.Client { - h.SetStreamHandler(ProtocolDHT, dht.handleNewStream) - h.SetStreamHandler(ProtocolDHTOld, dht.handleNewStream) + for _, p := range cfg.Protocols { + h.SetStreamHandler(p, dht.handleNewStream) + } } return dht, nil } @@ -116,7 +116,7 @@ func NewDHTClient(ctx context.Context, h host.Host, dstore ds.Batching) *IpfsDHT return dht } -func makeDHT(ctx context.Context, h host.Host, dstore ds.Batching) *IpfsDHT { +func makeDHT(ctx context.Context, h host.Host, dstore ds.Batching, protocols []protocol.ID) *IpfsDHT { rt := kb.NewRoutingTable(KValue, kb.ConvertPeerID(h.ID()), time.Minute, h.Peerstore()) cmgr := h.ConnManager() @@ -137,6 +137,7 @@ func makeDHT(ctx context.Context, h host.Host, dstore ds.Batching) *IpfsDHT { providers: providers.NewProviderManager(ctx, h.ID(), dstore), birth: time.Now(), routingTable: rt, + protocols: protocols, } } @@ -389,6 +390,15 @@ func (dht *IpfsDHT) Close() error { return dht.proc.Close() } +func (dht *IpfsDHT) protocolStrs() []string { + pstrs := make([]string, len(dht.protocols)) + for idx, proto := range dht.protocols { + pstrs[idx] = string(proto) + } + + return pstrs +} + func mkDsKey(s string) ds.Key { return ds.NewKey(base32.RawStdEncoding.EncodeToString([]byte(s))) } diff --git a/dht_net.go b/dht_net.go index 8513db3cfde..e596c548145 100644 --- a/dht_net.go +++ b/dht_net.go @@ -190,7 +190,7 @@ func (ms *messageSender) prep() error { return nil } - nstr, err := ms.dht.host.NewStream(ms.dht.ctx, ms.p, ProtocolDHT, ProtocolDHTOld) + nstr, err := ms.dht.host.NewStream(ms.dht.ctx, ms.p, ms.dht.protocols...) if err != nil { return err } diff --git a/dht_test.go b/dht_test.go index 3d22920fd0e..ac845494582 100644 --- a/dht_test.go +++ b/dht_test.go @@ -7,6 +7,7 @@ import ( "fmt" "math/rand" "sort" + "strings" "sync" "testing" "time" @@ -1075,3 +1076,83 @@ func TestFindClosestPeers(t *testing.T) { t.Fatalf("got wrong number of peers (got %d, expected %d)", len(out), KValue) } } + +func TestGetSetPluggedProtocol(t *testing.T) { + t.Run("PutValue/GetValue - same protocol", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + os := []opts.Option{ + opts.Protocols("/esh/dht"), + opts.Client(false), + opts.NamespacedValidator("v", blankValidator{}), + } + + dhtA, err := New(ctx, bhost.New(netutil.GenSwarmNetwork(t, ctx)), os...) + if err != nil { + t.Fatal(err) + } + + dhtB, err := New(ctx, bhost.New(netutil.GenSwarmNetwork(t, ctx)), os...) + if err != nil { + t.Fatal(err) + } + + connect(t, ctx, dhtA, dhtB) + + ctxT, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + if err := dhtA.PutValue(ctxT, "/v/cat", []byte("meow")); err != nil { + t.Fatal(err) + } + + value, err := dhtB.GetValue(ctxT, "/v/cat") + if err != nil { + t.Fatal(err) + } + + if string(value) != "meow" { + t.Fatalf("Expected 'meow' got '%s'", string(value)) + } + }) + + t.Run("DHT routing table for peer A won't contain B if A and B don't use same protocol", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + dhtA, err := New(ctx, bhost.New(netutil.GenSwarmNetwork(t, ctx)), []opts.Option{ + opts.Protocols("/esh/dht"), + opts.Client(false), + opts.NamespacedValidator("v", blankValidator{}), + }...) + if err != nil { + t.Fatal(err) + } + + dhtB, err := New(ctx, bhost.New(netutil.GenSwarmNetwork(t, ctx)), []opts.Option{ + opts.Protocols("/lsr/dht"), + opts.Client(false), + opts.NamespacedValidator("v", blankValidator{}), + }...) + if err != nil { + t.Fatal(err) + } + + connectNoSync(t, ctx, dhtA, dhtB) + + // We don't expect connection notifications for A to reach B (or vice-versa), given + // that they've been configured with different protocols - but we'll give them a + // chance, anyhow. + time.Sleep(time.Second * 2) + + err = dhtA.PutValue(ctx, "/v/cat", []byte("meow")) + if err == nil || !strings.Contains(err.Error(), "failed to find any peer in table") { + t.Fatal("should not have been able to find any peers in routing table") + } + + _, err = dhtB.GetValue(ctx, "/v/cat") + if err == nil || !strings.Contains(err.Error(), "failed to find any peer in table") { + t.Fatal("should not have been able to find any peers in routing table") + } + }) +} diff --git a/ext_test.go b/ext_test.go index 90e48187a95..33f23d6afd9 100644 --- a/ext_test.go +++ b/ext_test.go @@ -36,7 +36,7 @@ func TestGetFailures(t *testing.T) { d.Update(ctx, hosts[1].ID()) // Reply with failures to every message - hosts[1].SetStreamHandler(ProtocolDHT, func(s inet.Stream) { + hosts[1].SetStreamHandler(d.protocols[0], func(s inet.Stream) { s.Close() }) @@ -58,7 +58,7 @@ func TestGetFailures(t *testing.T) { t.Log("Timeout test passed.") // Reply with failures to every message - hosts[1].SetStreamHandler(ProtocolDHT, func(s inet.Stream) { + hosts[1].SetStreamHandler(d.protocols[0], func(s inet.Stream) { defer s.Close() pbr := ggio.NewDelimitedReader(s, inet.MessageSizeMax) @@ -110,7 +110,7 @@ func TestGetFailures(t *testing.T) { Record: rec, } - s, err := hosts[1].NewStream(context.Background(), hosts[0].ID(), ProtocolDHT) + s, err := hosts[1].NewStream(context.Background(), hosts[0].ID(), d.protocols[0]) if err != nil { t.Fatal(err) } @@ -160,7 +160,7 @@ func TestNotFound(t *testing.T) { // Reply with random peers to every message for _, host := range hosts { host := host // shadow loop var - host.SetStreamHandler(ProtocolDHT, func(s inet.Stream) { + host.SetStreamHandler(d.protocols[0], func(s inet.Stream) { defer s.Close() pbr := ggio.NewDelimitedReader(s, inet.MessageSizeMax) @@ -239,7 +239,7 @@ func TestLessThanKResponses(t *testing.T) { // Reply with random peers to every message for _, host := range hosts { host := host // shadow loop var - host.SetStreamHandler(ProtocolDHT, func(s inet.Stream) { + host.SetStreamHandler(d.protocols[0], func(s inet.Stream) { defer s.Close() pbr := ggio.NewDelimitedReader(s, inet.MessageSizeMax) @@ -305,7 +305,7 @@ func TestMultipleQueries(t *testing.T) { // It would be nice to be able to just get a value and succeed but then // we'd need to deal with selectors and validators... - hosts[1].SetStreamHandler(ProtocolDHT, func(s inet.Stream) { + hosts[1].SetStreamHandler(d.protocols[0], func(s inet.Stream) { defer s.Close() pbr := ggio.NewDelimitedReader(s, inet.MessageSizeMax) diff --git a/notif.go b/notif.go index f7fe5af1167..6dd1a82d10b 100644 --- a/notif.go +++ b/notif.go @@ -9,8 +9,6 @@ import ( // netNotifiee defines methods to be used with the IpfsDHT type netNotifiee IpfsDHT -var dhtProtocols = []string{string(ProtocolDHT), string(ProtocolDHTOld)} - func (nn *netNotifiee) DHT() *IpfsDHT { return (*IpfsDHT)(nn) } @@ -24,7 +22,7 @@ func (nn *netNotifiee) Connected(n inet.Network, v inet.Conn) { } p := v.RemotePeer() - protos, err := dht.peerstore.SupportsProtocols(p, dhtProtocols...) + protos, err := dht.peerstore.SupportsProtocols(p, dht.protocolStrs()...) if err == nil && len(protos) != 0 { // We lock here for consistency with the lock in testConnection. // This probably isn't necessary because (dis)connect @@ -57,7 +55,7 @@ func (nn *netNotifiee) testConnection(v inet.Conn) { } defer s.Close() - selected, err := mstream.SelectOneOf(dhtProtocols, s) + selected, err := mstream.SelectOneOf(dht.protocolStrs(), s) if err != nil { // Doesn't support the protocol return diff --git a/opts/options.go b/opts/options.go index aec6a8c8213..164f3932995 100644 --- a/opts/options.go +++ b/opts/options.go @@ -5,14 +5,20 @@ import ( ds "github.com/ipfs/go-datastore" dssync "github.com/ipfs/go-datastore/sync" + "github.com/libp2p/go-libp2p-protocol" record "github.com/libp2p/go-libp2p-record" ) +var ProtocolDHT protocol.ID = "/ipfs/kad/1.0.0" +var ProtocolDHTOld protocol.ID = "/ipfs/dht" +var DefaultProtocols = []protocol.ID{ProtocolDHT, ProtocolDHTOld} + // Options is a structure containing all the options that can be used when constructing a DHT. type Options struct { Datastore ds.Batching Validator record.Validator Client bool + Protocols []protocol.ID } // Apply applies the given options to this Option @@ -35,6 +41,7 @@ var Defaults = func(o *Options) error { "pk": record.PublicKeyValidator{}, } o.Datastore = dssync.MutexWrap(ds.NewMapDatastore()) + o.Protocols = DefaultProtocols return nil } @@ -85,3 +92,13 @@ func NamespacedValidator(ns string, v record.Validator) Option { return nil } } + +// Protocols sets the protocols for the DHT +// +// Defaults to dht.DefaultProtocols +func Protocols(protocols ...protocol.ID) Option { + return func(o *Options) error { + o.Protocols = protocols + return nil + } +}