Skip to content
This repository has been archived by the owner on Feb 1, 2023. It is now read-only.

fix: send cancel when GetBlocks() is cancelled #383

Merged
merged 5 commits into from
May 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions bitswap.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,19 @@ func New(parent context.Context, network bsnet.BitSwapNetwork,
pm := bspm.New(ctx, peerQueueFactory, network.Self())
pqm := bspqm.New(ctx, network)

sessionFactory := func(sessctx context.Context, onShutdown bssession.OnShutdown, id uint64, spm bssession.SessionPeerManager,
sessionFactory := func(
sessctx context.Context,
sessmgr bssession.SessionManager,
id uint64,
spm bssession.SessionPeerManager,
sim *bssim.SessionInterestManager,
pm bssession.PeerManager,
bpm *bsbpm.BlockPresenceManager,
notif notifications.PubSub,
provSearchDelay time.Duration,
rebroadcastDelay delay.D,
self peer.ID) bssm.Session {
return bssession.New(ctx, sessctx, onShutdown, id, spm, pqm, sim, pm, bpm, notif, provSearchDelay, rebroadcastDelay, self)
return bssession.New(sessctx, sessmgr, id, spm, pqm, sim, pm, bpm, notif, provSearchDelay, rebroadcastDelay, self)
}
sessionPeerManagerFactory := func(ctx context.Context, id uint64) bssession.SessionPeerManager {
return bsspm.New(id, network.ConnectionManager())
Expand Down
10 changes: 10 additions & 0 deletions internal/blockpresencemanager/blockpresencemanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,13 @@ func (bpm *BlockPresenceManager) RemoveKeys(ks []cid.Cid) {
delete(bpm.presence, c)
}
}

// HasKey indicates whether the BlockPresenceManager is tracking the given key
// (used by the tests)
func (bpm *BlockPresenceManager) HasKey(c cid.Cid) bool {
bpm.Lock()
defer bpm.Unlock()

_, ok := bpm.presence[c]
return ok
}
44 changes: 16 additions & 28 deletions internal/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ type PeerManager interface {
SendCancels(context.Context, []cid.Cid)
}

// SessionManager manages all the sessions
type SessionManager interface {
// Remove a session (called when the session shuts down)
RemoveSession(sesid uint64)
// Cancel wants (called when a call to GetBlocks() is cancelled)
CancelSessionWants(sid uint64, wants []cid.Cid)
}

// SessionPeerManager keeps track of peers in the session
type SessionPeerManager interface {
// PeersDiscovered indicates if any peers have been discovered yet
Expand Down Expand Up @@ -86,19 +94,15 @@ type op struct {
keys []cid.Cid
}

type OnShutdown func(uint64)

// Session holds state for an individual bitswap transfer operation.
// This allows bitswap to make smarter decisions about who to send wantlist
// info to, and who to request blocks from.
type Session struct {
// dependencies
bsctx context.Context // context for bitswap
ctx context.Context // context for session
ctx context.Context
shutdown func()
onShutdown OnShutdown
sm SessionManager
pm PeerManager
bpm *bsbpm.BlockPresenceManager
sprm SessionPeerManager
providerFinder ProviderFinder
sim *bssim.SessionInterestManager
Expand Down Expand Up @@ -130,9 +134,8 @@ type Session struct {
// New creates a new bitswap session whose lifetime is bounded by the
// given context.
func New(
bsctx context.Context, // context for bitswap
ctx context.Context, // context for this session
onShutdown OnShutdown,
ctx context.Context,
sm SessionManager,
id uint64,
sprm SessionPeerManager,
providerFinder ProviderFinder,
Expand All @@ -148,12 +151,10 @@ func New(
s := &Session{
sw: newSessionWants(broadcastLiveWantsLimit),
tickDelayReqs: make(chan time.Duration),
bsctx: bsctx,
ctx: ctx,
shutdown: cancel,
onShutdown: onShutdown,
sm: sm,
pm: pm,
bpm: bpm,
sprm: sprm,
providerFinder: providerFinder,
sim: sim,
Expand All @@ -167,7 +168,7 @@ func New(
periodicSearchDelay: periodicSearchDelay,
self: self,
}
s.sws = newSessionWantSender(id, pm, sprm, bpm, s.onWantsSent, s.onPeersExhausted)
s.sws = newSessionWantSender(id, pm, sprm, sm, bpm, s.onWantsSent, s.onPeersExhausted)

go s.run(ctx)

Expand Down Expand Up @@ -308,6 +309,7 @@ func (s *Session) run(ctx context.Context) {
case opCancel:
// Wants were cancelled
s.sw.CancelPending(oper.keys)
s.sws.Cancel(oper.keys)
case opWantsSent:
// Wants were sent to a peer
s.sw.WantsSent(oper.keys)
Expand Down Expand Up @@ -402,23 +404,9 @@ func (s *Session) handleShutdown() {
// Shut down the sessionWantSender (blocks until sessionWantSender stops
// sending)
s.sws.Shutdown()

// Remove session's interest in the given blocks.
cancelKs := s.sim.RemoveSessionInterest(s.id)

// Free up block presence tracking for keys that no session is interested
// in anymore
s.bpm.RemoveKeys(cancelKs)

// Send CANCEL to all peers for blocks that no session is interested in
// anymore.
// Note: use bitswap context because session context has already been
// cancelled.
s.pm.SendCancels(s.bsctx, cancelKs)

// Signal to the SessionManager that the session has been shutdown
// and can be cleaned up
s.onShutdown(s.id)
s.sm.RemoveSession(s.id)
}

// handleReceive is called when the session receives blocks from a peer
Expand Down
112 changes: 57 additions & 55 deletions internal/session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,40 @@ import (
peer "github.com/libp2p/go-libp2p-core/peer"
)

type mockSessionMgr struct {
lk sync.Mutex
removeSession bool
cancels []cid.Cid
}

func newMockSessionMgr() *mockSessionMgr {
return &mockSessionMgr{}
}

func (msm *mockSessionMgr) removeSessionCalled() bool {
msm.lk.Lock()
defer msm.lk.Unlock()
return msm.removeSession
}

func (msm *mockSessionMgr) cancelled() []cid.Cid {
msm.lk.Lock()
defer msm.lk.Unlock()
return msm.cancels
}

func (msm *mockSessionMgr) RemoveSession(sesid uint64) {
msm.lk.Lock()
defer msm.lk.Unlock()
msm.removeSession = true
}

func (msm *mockSessionMgr) CancelSessionWants(sid uint64, wants []cid.Cid) {
msm.lk.Lock()
defer msm.lk.Unlock()
msm.cancels = append(msm.cancels, wants...)
}

func newFakeSessionPeerManager() *bsspm.SessionPeerManager {
return bsspm.New(1, newFakePeerTagger())
}
Expand Down Expand Up @@ -61,8 +95,6 @@ type wantReq struct {

type fakePeerManager struct {
wantReqs chan wantReq
lk sync.Mutex
cancels []cid.Cid
}

func newFakePeerManager() *fakePeerManager {
Expand All @@ -82,35 +114,7 @@ func (pm *fakePeerManager) BroadcastWantHaves(ctx context.Context, cids []cid.Ci
case <-ctx.Done():
}
}
func (pm *fakePeerManager) SendCancels(ctx context.Context, cancels []cid.Cid) {
pm.lk.Lock()
defer pm.lk.Unlock()
pm.cancels = append(pm.cancels, cancels...)
}
func (pm *fakePeerManager) allCancels() []cid.Cid {
pm.lk.Lock()
defer pm.lk.Unlock()
return append([]cid.Cid{}, pm.cancels...)
}

type onShutdownMonitor struct {
lk sync.Mutex
shutdown bool
}

func (sm *onShutdownMonitor) onShutdown(uint64) {
sm.lk.Lock()
defer sm.lk.Unlock()

sm.shutdown = true
}

func (sm *onShutdownMonitor) shutdownCalled() bool {
sm.lk.Lock()
defer sm.lk.Unlock()

return sm.shutdown
}
func (pm *fakePeerManager) SendCancels(ctx context.Context, cancels []cid.Cid) {}

func TestSessionGetBlocks(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
Expand All @@ -122,8 +126,8 @@ func TestSessionGetBlocks(t *testing.T) {
notif := notifications.New()
defer notif.Shutdown()
id := testutil.GenerateSessionID()
onShutdown := func(uint64) {}
session := New(ctx, ctx, onShutdown, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "")
sm := newMockSessionMgr()
session := New(ctx, sm, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "")
blockGenerator := blocksutil.NewBlockGenerator()
blks := blockGenerator.Blocks(broadcastLiveWantsLimit * 2)
var cids []cid.Cid
Expand Down Expand Up @@ -201,9 +205,9 @@ func TestSessionGetBlocks(t *testing.T) {

time.Sleep(10 * time.Millisecond)

// Verify wants were cancelled
if len(fpm.allCancels()) != len(blks) {
t.Fatal("expected cancels to be sent for all wants")
// Verify session was removed
if !sm.removeSessionCalled() {
t.Fatal("expected session to be removed")
}
}

Expand All @@ -218,8 +222,8 @@ func TestSessionFindMorePeers(t *testing.T) {
notif := notifications.New()
defer notif.Shutdown()
id := testutil.GenerateSessionID()
onShutdown := func(uint64) {}
session := New(ctx, ctx, onShutdown, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "")
sm := newMockSessionMgr()
session := New(ctx, sm, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "")
session.SetBaseTickDelay(200 * time.Microsecond)
blockGenerator := blocksutil.NewBlockGenerator()
blks := blockGenerator.Blocks(broadcastLiveWantsLimit * 2)
Expand Down Expand Up @@ -293,8 +297,8 @@ func TestSessionOnPeersExhausted(t *testing.T) {
notif := notifications.New()
defer notif.Shutdown()
id := testutil.GenerateSessionID()
onShutdown := func(uint64) {}
session := New(ctx, ctx, onShutdown, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "")
sm := newMockSessionMgr()
session := New(ctx, sm, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "")
blockGenerator := blocksutil.NewBlockGenerator()
blks := blockGenerator.Blocks(broadcastLiveWantsLimit + 5)
var cids []cid.Cid
Expand Down Expand Up @@ -338,8 +342,8 @@ func TestSessionFailingToGetFirstBlock(t *testing.T) {
notif := notifications.New()
defer notif.Shutdown()
id := testutil.GenerateSessionID()
onShutdown := func(uint64) {}
session := New(ctx, ctx, onShutdown, id, fspm, fpf, sim, fpm, bpm, notif, 10*time.Millisecond, delay.Fixed(100*time.Millisecond), "")
sm := newMockSessionMgr()
session := New(ctx, sm, id, fspm, fpf, sim, fpm, bpm, notif, 10*time.Millisecond, delay.Fixed(100*time.Millisecond), "")
blockGenerator := blocksutil.NewBlockGenerator()
blks := blockGenerator.Blocks(4)
var cids []cid.Cid
Expand Down Expand Up @@ -451,12 +455,11 @@ func TestSessionCtxCancelClosesGetBlocksChannel(t *testing.T) {
notif := notifications.New()
defer notif.Shutdown()
id := testutil.GenerateSessionID()

osm := &onShutdownMonitor{}
sm := newMockSessionMgr()

// Create a new session with its own context
sessctx, sesscancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
session := New(context.Background(), sessctx, osm.onShutdown, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "")
session := New(sessctx, sm, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "")

timerCtx, timerCancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer timerCancel()
Expand Down Expand Up @@ -487,8 +490,8 @@ func TestSessionCtxCancelClosesGetBlocksChannel(t *testing.T) {

time.Sleep(10 * time.Millisecond)

// Expect onShutdown to be called
if !osm.shutdownCalled() {
// Expect RemoveSession to be called
if !sm.removeSessionCalled() {
t.Fatal("expected onShutdown to be called")
}
}
Expand All @@ -502,27 +505,26 @@ func TestSessionOnShutdownCalled(t *testing.T) {
notif := notifications.New()
defer notif.Shutdown()
id := testutil.GenerateSessionID()

osm := &onShutdownMonitor{}
sm := newMockSessionMgr()

// Create a new session with its own context
sessctx, sesscancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer sesscancel()
session := New(context.Background(), sessctx, osm.onShutdown, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "")
session := New(sessctx, sm, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "")

// Shutdown the session
session.Shutdown()

time.Sleep(10 * time.Millisecond)

// Expect onShutdown to be called
if !osm.shutdownCalled() {
// Expect RemoveSession to be called
if !sm.removeSessionCalled() {
t.Fatal("expected onShutdown to be called")
}
}

func TestSessionReceiveMessageAfterCtxCancel(t *testing.T) {
ctx, cancelCtx := context.WithTimeout(context.Background(), 10*time.Millisecond)
ctx, cancelCtx := context.WithTimeout(context.Background(), 20*time.Millisecond)
fpm := newFakePeerManager()
fspm := newFakeSessionPeerManager()
fpf := newFakeProviderFinder()
Expand All @@ -532,8 +534,8 @@ func TestSessionReceiveMessageAfterCtxCancel(t *testing.T) {
notif := notifications.New()
defer notif.Shutdown()
id := testutil.GenerateSessionID()
onShutdown := func(uint64) {}
session := New(ctx, ctx, onShutdown, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "")
sm := newMockSessionMgr()
session := New(ctx, sm, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "")
blockGenerator := blocksutil.NewBlockGenerator()
blks := blockGenerator.Blocks(2)
cids := []cid.Cid{blks[0].Cid(), blks[1].Cid()}
Expand Down
Loading