Skip to content
This repository has been archived by the owner on Feb 1, 2023. It is now read-only.

fix: send cancel when GetBlocks() is cancelled #383

Merged
merged 5 commits into from
May 1, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions bitswap.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,19 @@ func New(parent context.Context, network bsnet.BitSwapNetwork,
pm := bspm.New(ctx, peerQueueFactory, network.Self())
pqm := bspqm.New(ctx, network)

sessionFactory := func(sessctx context.Context, onShutdown bssession.OnShutdown, id uint64, spm bssession.SessionPeerManager,
sessionFactory := func(
sessctx context.Context,
sessmgr bssession.SessionManager,
id uint64,
spm bssession.SessionPeerManager,
sim *bssim.SessionInterestManager,
pm bssession.PeerManager,
bpm *bsbpm.BlockPresenceManager,
notif notifications.PubSub,
provSearchDelay time.Duration,
rebroadcastDelay delay.D,
self peer.ID) bssm.Session {
return bssession.New(ctx, sessctx, onShutdown, id, spm, pqm, sim, pm, bpm, notif, provSearchDelay, rebroadcastDelay, self)
return bssession.New(sessctx, sessmgr, id, spm, pqm, sim, pm, bpm, notif, provSearchDelay, rebroadcastDelay, self)
}
sessionPeerManagerFactory := func(ctx context.Context, id uint64) bssession.SessionPeerManager {
return bsspm.New(id, network.ConnectionManager())
Expand Down
10 changes: 10 additions & 0 deletions internal/blockpresencemanager/blockpresencemanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,13 @@ func (bpm *BlockPresenceManager) RemoveKeys(ks []cid.Cid) {
delete(bpm.presence, c)
}
}

// HasKey indicates whether the BlockPresenceManager is tracking the given key
// (used by the tests)
func (bpm *BlockPresenceManager) HasKey(c cid.Cid) bool {
bpm.Lock()
defer bpm.Unlock()

_, ok := bpm.presence[c]
return ok
}
44 changes: 16 additions & 28 deletions internal/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ type PeerManager interface {
SendCancels(context.Context, []cid.Cid)
}

// SessionManager manages all the sessions
type SessionManager interface {
// Remove a session (called when the session shuts down)
RemoveSession(sesid uint64)
// Cancel wants (called when a call to GetBlocks() is cancelled)
CancelSessionWants(sid uint64, wants []cid.Cid)
}

// SessionPeerManager keeps track of peers in the session
type SessionPeerManager interface {
// PeersDiscovered indicates if any peers have been discovered yet
Expand Down Expand Up @@ -86,19 +94,15 @@ type op struct {
keys []cid.Cid
}

type OnShutdown func(uint64)

// Session holds state for an individual bitswap transfer operation.
// This allows bitswap to make smarter decisions about who to send wantlist
// info to, and who to request blocks from.
type Session struct {
// dependencies
bsctx context.Context // context for bitswap
ctx context.Context // context for session
ctx context.Context
shutdown func()
onShutdown OnShutdown
sm SessionManager
pm PeerManager
bpm *bsbpm.BlockPresenceManager
sprm SessionPeerManager
providerFinder ProviderFinder
sim *bssim.SessionInterestManager
Expand Down Expand Up @@ -130,9 +134,8 @@ type Session struct {
// New creates a new bitswap session whose lifetime is bounded by the
// given context.
func New(
bsctx context.Context, // context for bitswap
ctx context.Context, // context for this session
onShutdown OnShutdown,
ctx context.Context,
sm SessionManager,
id uint64,
sprm SessionPeerManager,
providerFinder ProviderFinder,
Expand All @@ -148,12 +151,10 @@ func New(
s := &Session{
sw: newSessionWants(broadcastLiveWantsLimit),
tickDelayReqs: make(chan time.Duration),
bsctx: bsctx,
ctx: ctx,
shutdown: cancel,
onShutdown: onShutdown,
sm: sm,
pm: pm,
bpm: bpm,
sprm: sprm,
providerFinder: providerFinder,
sim: sim,
Expand All @@ -167,7 +168,7 @@ func New(
periodicSearchDelay: periodicSearchDelay,
self: self,
}
s.sws = newSessionWantSender(id, pm, sprm, bpm, s.onWantsSent, s.onPeersExhausted)
s.sws = newSessionWantSender(id, pm, sprm, sm, bpm, s.onWantsSent, s.onPeersExhausted)

go s.run(ctx)

Expand Down Expand Up @@ -308,6 +309,7 @@ func (s *Session) run(ctx context.Context) {
case opCancel:
// Wants were cancelled
s.sw.CancelPending(oper.keys)
s.sws.Cancel(oper.keys)
case opWantsSent:
// Wants were sent to a peer
s.sw.WantsSent(oper.keys)
Expand Down Expand Up @@ -402,23 +404,9 @@ func (s *Session) handleShutdown() {
// Shut down the sessionWantSender (blocks until sessionWantSender stops
// sending)
s.sws.Shutdown()

// Remove session's interest in the given blocks.
cancelKs := s.sim.RemoveSessionInterest(s.id)

// Free up block presence tracking for keys that no session is interested
// in anymore
s.bpm.RemoveKeys(cancelKs)

// Send CANCEL to all peers for blocks that no session is interested in
// anymore.
// Note: use bitswap context because session context has already been
// cancelled.
s.pm.SendCancels(s.bsctx, cancelKs)

// Signal to the SessionManager that the session has been shutdown
// and can be cleaned up
s.onShutdown(s.id)
s.sm.RemoveSession(s.id)
}

// handleReceive is called when the session receives blocks from a peer
Expand Down
112 changes: 57 additions & 55 deletions internal/session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,40 @@ import (
peer "github.com/libp2p/go-libp2p-core/peer"
)

type mockSessionMgr struct {
lk sync.Mutex
removeSession bool
cancels []cid.Cid
}

func newMockSessionMgr() *mockSessionMgr {
return &mockSessionMgr{}
}

func (msm *mockSessionMgr) removeSessionCalled() bool {
msm.lk.Lock()
defer msm.lk.Unlock()
return msm.removeSession
}

func (msm *mockSessionMgr) cancelled() []cid.Cid {
msm.lk.Lock()
defer msm.lk.Unlock()
return msm.cancels
}

func (msm *mockSessionMgr) RemoveSession(sesid uint64) {
msm.lk.Lock()
defer msm.lk.Unlock()
msm.removeSession = true
}

func (msm *mockSessionMgr) CancelSessionWants(sid uint64, wants []cid.Cid) {
msm.lk.Lock()
defer msm.lk.Unlock()
msm.cancels = append(msm.cancels, wants...)
}

func newFakeSessionPeerManager() *bsspm.SessionPeerManager {
return bsspm.New(1, newFakePeerTagger())
}
Expand Down Expand Up @@ -61,8 +95,6 @@ type wantReq struct {

type fakePeerManager struct {
wantReqs chan wantReq
lk sync.Mutex
cancels []cid.Cid
}

func newFakePeerManager() *fakePeerManager {
Expand All @@ -82,35 +114,7 @@ func (pm *fakePeerManager) BroadcastWantHaves(ctx context.Context, cids []cid.Ci
case <-ctx.Done():
}
}
func (pm *fakePeerManager) SendCancels(ctx context.Context, cancels []cid.Cid) {
pm.lk.Lock()
defer pm.lk.Unlock()
pm.cancels = append(pm.cancels, cancels...)
}
func (pm *fakePeerManager) allCancels() []cid.Cid {
pm.lk.Lock()
defer pm.lk.Unlock()
return append([]cid.Cid{}, pm.cancels...)
}

type onShutdownMonitor struct {
lk sync.Mutex
shutdown bool
}

func (sm *onShutdownMonitor) onShutdown(uint64) {
sm.lk.Lock()
defer sm.lk.Unlock()

sm.shutdown = true
}

func (sm *onShutdownMonitor) shutdownCalled() bool {
sm.lk.Lock()
defer sm.lk.Unlock()

return sm.shutdown
}
func (pm *fakePeerManager) SendCancels(ctx context.Context, cancels []cid.Cid) {}

func TestSessionGetBlocks(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
Expand All @@ -122,8 +126,8 @@ func TestSessionGetBlocks(t *testing.T) {
notif := notifications.New()
defer notif.Shutdown()
id := testutil.GenerateSessionID()
onShutdown := func(uint64) {}
session := New(ctx, ctx, onShutdown, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "")
sm := newMockSessionMgr()
session := New(ctx, sm, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "")
blockGenerator := blocksutil.NewBlockGenerator()
blks := blockGenerator.Blocks(broadcastLiveWantsLimit * 2)
var cids []cid.Cid
Expand Down Expand Up @@ -201,9 +205,9 @@ func TestSessionGetBlocks(t *testing.T) {

time.Sleep(10 * time.Millisecond)

// Verify wants were cancelled
if len(fpm.allCancels()) != len(blks) {
t.Fatal("expected cancels to be sent for all wants")
// Verify session was removed
if !sm.removeSessionCalled() {
t.Fatal("expected session to be removed")
}
}

Expand All @@ -218,8 +222,8 @@ func TestSessionFindMorePeers(t *testing.T) {
notif := notifications.New()
defer notif.Shutdown()
id := testutil.GenerateSessionID()
onShutdown := func(uint64) {}
session := New(ctx, ctx, onShutdown, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "")
sm := newMockSessionMgr()
session := New(ctx, sm, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "")
session.SetBaseTickDelay(200 * time.Microsecond)
blockGenerator := blocksutil.NewBlockGenerator()
blks := blockGenerator.Blocks(broadcastLiveWantsLimit * 2)
Expand Down Expand Up @@ -293,8 +297,8 @@ func TestSessionOnPeersExhausted(t *testing.T) {
notif := notifications.New()
defer notif.Shutdown()
id := testutil.GenerateSessionID()
onShutdown := func(uint64) {}
session := New(ctx, ctx, onShutdown, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "")
sm := newMockSessionMgr()
session := New(ctx, sm, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "")
blockGenerator := blocksutil.NewBlockGenerator()
blks := blockGenerator.Blocks(broadcastLiveWantsLimit + 5)
var cids []cid.Cid
Expand Down Expand Up @@ -338,8 +342,8 @@ func TestSessionFailingToGetFirstBlock(t *testing.T) {
notif := notifications.New()
defer notif.Shutdown()
id := testutil.GenerateSessionID()
onShutdown := func(uint64) {}
session := New(ctx, ctx, onShutdown, id, fspm, fpf, sim, fpm, bpm, notif, 10*time.Millisecond, delay.Fixed(100*time.Millisecond), "")
sm := newMockSessionMgr()
session := New(ctx, sm, id, fspm, fpf, sim, fpm, bpm, notif, 10*time.Millisecond, delay.Fixed(100*time.Millisecond), "")
blockGenerator := blocksutil.NewBlockGenerator()
blks := blockGenerator.Blocks(4)
var cids []cid.Cid
Expand Down Expand Up @@ -451,12 +455,11 @@ func TestSessionCtxCancelClosesGetBlocksChannel(t *testing.T) {
notif := notifications.New()
defer notif.Shutdown()
id := testutil.GenerateSessionID()

osm := &onShutdownMonitor{}
sm := newMockSessionMgr()

// Create a new session with its own context
sessctx, sesscancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
session := New(context.Background(), sessctx, osm.onShutdown, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "")
session := New(sessctx, sm, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "")

timerCtx, timerCancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer timerCancel()
Expand Down Expand Up @@ -487,8 +490,8 @@ func TestSessionCtxCancelClosesGetBlocksChannel(t *testing.T) {

time.Sleep(10 * time.Millisecond)

// Expect onShutdown to be called
if !osm.shutdownCalled() {
// Expect RemoveSession to be called
if !sm.removeSessionCalled() {
t.Fatal("expected onShutdown to be called")
}
}
Expand All @@ -502,27 +505,26 @@ func TestSessionOnShutdownCalled(t *testing.T) {
notif := notifications.New()
defer notif.Shutdown()
id := testutil.GenerateSessionID()

osm := &onShutdownMonitor{}
sm := newMockSessionMgr()

// Create a new session with its own context
sessctx, sesscancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer sesscancel()
session := New(context.Background(), sessctx, osm.onShutdown, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "")
session := New(sessctx, sm, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "")

// Shutdown the session
session.Shutdown()

time.Sleep(10 * time.Millisecond)

// Expect onShutdown to be called
if !osm.shutdownCalled() {
// Expect RemoveSession to be called
if !sm.removeSessionCalled() {
t.Fatal("expected onShutdown to be called")
}
}

func TestSessionReceiveMessageAfterCtxCancel(t *testing.T) {
ctx, cancelCtx := context.WithTimeout(context.Background(), 10*time.Millisecond)
ctx, cancelCtx := context.WithTimeout(context.Background(), 20*time.Millisecond)
fpm := newFakePeerManager()
fspm := newFakeSessionPeerManager()
fpf := newFakeProviderFinder()
Expand All @@ -532,8 +534,8 @@ func TestSessionReceiveMessageAfterCtxCancel(t *testing.T) {
notif := notifications.New()
defer notif.Shutdown()
id := testutil.GenerateSessionID()
onShutdown := func(uint64) {}
session := New(ctx, ctx, onShutdown, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "")
sm := newMockSessionMgr()
session := New(ctx, sm, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "")
blockGenerator := blocksutil.NewBlockGenerator()
blks := blockGenerator.Blocks(2)
cids := []cid.Cid{blks[0].Cid(), blks[1].Cid()}
Expand Down
Loading