diff --git a/p2p/protocol/identify/id.go b/p2p/protocol/identify/id.go index 5f9a9fe40f..17801f12c8 100644 --- a/p2p/protocol/identify/id.go +++ b/p2p/protocol/identify/id.go @@ -87,8 +87,6 @@ type IDService struct { ctx context.Context ctxCancel context.CancelFunc - // ensure we shutdown ONLY once - closeSync sync.Once // track resources that need to be shut down before we shut down refCount sync.WaitGroup @@ -126,25 +124,23 @@ func NewIDService(h host.Host, opts ...Option) (*IDService, error) { userAgent = cfg.userAgent } - hostCtx, cancel := context.WithCancel(context.Background()) s := &IDService{ Host: h, UserAgent: userAgent, - ctx: hostCtx, - ctxCancel: cancel, - conns: make(map[network.Conn]chan struct{}), + conns: make(map[network.Conn]chan struct{}), disableSignedPeerRecord: cfg.disableSignedPeerRecord, addPeerHandlerCh: make(chan addPeerHandlerReq), rmPeerHandlerCh: make(chan rmPeerHandlerReq), } + s.ctx, s.ctxCancel = context.WithCancel(context.Background()) // handle local protocol handler updates, and push deltas to peers. var err error - observedAddrs, err := NewObservedAddrManager(hostCtx, h) + observedAddrs, err := NewObservedAddrManager(h) if err != nil { return nil, fmt.Errorf("failed to create observed address manager: %s", err) } @@ -276,10 +272,8 @@ func (ids *IDService) loop() { // Close shuts down the IDService func (ids *IDService) Close() error { - ids.closeSync.Do(func() { - ids.ctxCancel() - ids.refCount.Wait() - }) + ids.ctxCancel() + ids.refCount.Wait() return nil } diff --git a/p2p/protocol/identify/obsaddr.go b/p2p/protocol/identify/obsaddr.go index ef6b118d23..5c241a68e8 100644 --- a/p2p/protocol/identify/obsaddr.go +++ b/p2p/protocol/identify/obsaddr.go @@ -98,13 +98,19 @@ type newObservation struct { type ObservedAddrManager struct { host host.Host + closeOnce sync.Once + refCount sync.WaitGroup + ctx context.Context // the context is canceled when Close is called + ctxCancel context.CancelFunc + // latest observation from active connections // we'll "re-observe" these when we gc activeConnsMu sync.Mutex // active connection -> most recent observation activeConns map[network.Conn]ma.Multiaddr - mu sync.RWMutex + mu sync.RWMutex + closed bool // local(internal) address -> list of observed(external) addresses addrs map[string][]*observedAddr ttl time.Duration @@ -123,7 +129,7 @@ type ObservedAddrManager struct { // NewObservedAddrManager returns a new address manager using // peerstore.OwnObservedAddressTTL as the TTL. -func NewObservedAddrManager(ctx context.Context, host host.Host) (*ObservedAddrManager, error) { +func NewObservedAddrManager(host host.Host) (*ObservedAddrManager, error) { oas := &ObservedAddrManager{ addrs: make(map[string][]*observedAddr), ttl: peerstore.OwnObservedAddrTTL, @@ -133,6 +139,7 @@ func NewObservedAddrManager(ctx context.Context, host host.Host) (*ObservedAddrM // refresh every ttl/2 so we don't forget observations from connected peers refreshTimer: time.NewTimer(peerstore.OwnObservedAddrTTL / 2), } + oas.ctx, oas.ctxCancel = context.WithCancel(context.Background()) reachabilitySub, err := host.EventBus().Subscribe(new(event.EvtLocalReachabilityChanged)) if err != nil { @@ -147,7 +154,8 @@ func NewObservedAddrManager(ctx context.Context, host host.Host) (*ObservedAddrM oas.emitNATDeviceTypeChanged = emitter oas.host.Network().Notify((*obsAddrNotifiee)(oas)) - go oas.worker(ctx) + oas.refCount.Add(1) + go oas.worker() return oas, nil } @@ -239,22 +247,12 @@ func (oas *ObservedAddrManager) Record(conn network.Conn, observed ma.Multiaddr) } } -func (oas *ObservedAddrManager) teardown() { - oas.host.Network().StopNotify((*obsAddrNotifiee)(oas)) - oas.reachabilitySub.Close() - - oas.mu.Lock() - oas.refreshTimer.Stop() - oas.mu.Unlock() -} - -func (oas *ObservedAddrManager) worker(ctx context.Context) { - defer oas.teardown() +func (oas *ObservedAddrManager) worker() { + defer oas.refCount.Done() ticker := time.NewTicker(GCInterval) defer ticker.Stop() - hostClosing := oas.host.Network().Process().Closing() subChan := oas.reachabilitySub.Out() for { select { @@ -265,17 +263,13 @@ func (oas *ObservedAddrManager) worker(ctx context.Context) { } ev := evt.(event.EvtLocalReachabilityChanged) oas.reachability = ev.Reachability - case obs := <-oas.wch: oas.maybeRecordObservation(obs.conn, obs.observed) - case <-ticker.C: oas.gc() case <-oas.refreshTimer.C: oas.refresh() - case <-hostClosing: - return - case <-ctx.Done(): + case <-oas.ctx.Done(): return } } @@ -534,6 +528,22 @@ func (oas *ObservedAddrManager) emitSpecificNATType(addrs []*observedAddr, proto return false, 0 } +func (oas *ObservedAddrManager) Close() error { + oas.closeOnce.Do(func() { + oas.ctxCancel() + + oas.mu.Lock() + oas.closed = true + oas.refreshTimer.Stop() + oas.mu.Unlock() + + oas.refCount.Wait() + oas.reachabilitySub.Close() + oas.host.Network().StopNotify((*obsAddrNotifiee)(oas)) + }) + return nil +} + // observerGroup is a function that determines what part of // a multiaddr counts as a different observer. for example, // two ipfs nodes at the same IP/TCP transport would get @@ -554,6 +564,9 @@ func observerGroup(m ma.Multiaddr) string { func (oas *ObservedAddrManager) SetTTL(ttl time.Duration) { oas.mu.Lock() defer oas.mu.Unlock() + if oas.closed { + return + } oas.ttl = ttl // refresh every ttl/2 so we don't forget observations from connected peers oas.refreshTimer.Reset(ttl / 2) diff --git a/p2p/protocol/identify/obsaddr_test.go b/p2p/protocol/identify/obsaddr_test.go index 4b3d0aca7f..2ce0400923 100644 --- a/p2p/protocol/identify/obsaddr_test.go +++ b/p2p/protocol/identify/obsaddr_test.go @@ -85,18 +85,11 @@ func (h *harness) observeInbound(observed ma.Multiaddr, observer peer.ID) networ func newHarness(ctx context.Context, t *testing.T) harness { mn := mocknet.New(ctx) sk, err := p2putil.RandTestBogusPrivateKey() - if err != nil { - t.Fatal(err) - } - + require.NoError(t, err) h, err := mn.AddPeer(sk, ma.StringCast("/ip4/127.0.0.1/tcp/10086")) - if err != nil { - t.Fatal(err) - } - - oas, err := identify.NewObservedAddrManager(ctx, h) require.NoError(t, err) - + oas, err := identify.NewObservedAddrManager(h) + require.NoError(t, err) return harness{ oas: oas, mocknet: mn, @@ -142,6 +135,7 @@ func TestObsAddrSet(t *testing.T) { defer cancel() harness := newHarness(ctx, t) + defer harness.oas.Close() if !addrsMatch(harness.oas.Addrs(), nil) { t.Error("addrs should be empty") @@ -243,6 +237,7 @@ func TestObservedAddrFiltering(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() harness := newHarness(ctx, t) + defer harness.oas.Close() require.Empty(t, harness.oas.Addrs()) // IP4/TCP @@ -344,6 +339,7 @@ func TestEmitNATDeviceTypeSymmetric(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() harness := newHarness(ctx, t) + defer harness.oas.Close() require.Empty(t, harness.oas.Addrs()) emitter, err := harness.host.EventBus().Emitter(new(event.EvtLocalReachabilityChanged), eventbus.Stateful) require.NoError(t, err) @@ -390,6 +386,7 @@ func TestEmitNATDeviceTypeCone(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() harness := newHarness(ctx, t) + defer harness.oas.Close() require.Empty(t, harness.oas.Addrs()) emitter, err := harness.host.EventBus().Emitter(new(event.EvtLocalReachabilityChanged), eventbus.Stateful) require.NoError(t, err)