honour context cancellation on acquire.
raulk committed Aug 5, 2021
1 parent 2ef26aa commit d41a6e0
Showing 6 changed files with 147 additions and 25 deletions.
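
At a high level, the changes below make the acquire path honour the caller's context: if the context is cancelled while the shard is being fetched, while its index is being retrieved, or while the accessor is being delivered, the DAG store now releases the shard reference it took for the acquire instead of leaking it. As a rough caller-side sketch (not part of the commit; it assumes the exported DAGStore, ShardResult, AcquireOpts and shard.Key types exercised in the tests below, plus the standard context import):

// acquireWithCancelledContext is an illustrative helper, not library code.
// It shows the caller behaviour this commit guards against: the caller gives
// up, the result is never read, and the shard's refcount must still be
// returned internally.
func acquireWithCancelledContext(dagst *DAGStore, k shard.Key) error {
	ctx, cancel := context.WithCancel(context.Background())
	cancel() // the caller abandons the acquire immediately

	ch := make(chan ShardResult, 1)
	if err := dagst.AcquireShard(ctx, k, ch, AcquireOpts{}); err != nil {
		return err // queueing the op can still fail synchronously
	}
	// nobody reads ch; with this commit the event loop notices the dead
	// context, enqueues an OpShardRelease and discards the result.
	return nil
}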
9 changes: 5 additions & 4 deletions dagstore.go
@@ -7,15 +7,16 @@ import (
"os"
"sync"

"github.com/filecoin-project/dagstore/index"
"github.com/filecoin-project/dagstore/mount"
"github.com/filecoin-project/dagstore/shard"
"github.com/filecoin-project/dagstore/throttle"
ds "github.com/ipfs/go-datastore"
"github.com/ipfs/go-datastore/namespace"
"github.com/ipfs/go-datastore/query"
dssync "github.com/ipfs/go-datastore/sync"
logging "github.com/ipfs/go-log/v2"

"github.com/filecoin-project/dagstore/index"
"github.com/filecoin-project/dagstore/mount"
"github.com/filecoin-project/dagstore/shard"
"github.com/filecoin-project/dagstore/throttle"
)

var (
45 changes: 42 additions & 3 deletions dagstore_async.go
@@ -3,9 +3,10 @@ package dagstore
import (
"context"

"github.com/filecoin-project/dagstore/mount"
"github.com/ipld/go-car/v2"
"github.com/ipld/go-car/v2/index"

"github.com/filecoin-project/dagstore/mount"
)

//
@@ -19,6 +20,19 @@ func (d *DAGStore) acquireAsync(ctx context.Context, w *waiter, s *Shard, mnt mo
k := s.key

reader, err := mnt.Fetch(ctx)

if err := ctx.Err(); err != nil {
log.Warnw("context cancelled while fetching shard; releasing", "shard", s.key, "error", err)

// release the shard to decrement the refcount that's incremented before `acquireAsync` is called.
_ = d.queueTask(&task{op: OpShardRelease, shard: s}, d.completionCh)

// send the shard error to the caller for correctness; since the context
// is cancelled, the result will be discarded.
d.dispatchResult(&ShardResult{Key: k, Error: err}, w)
return
}

if err != nil {
log.Warnw("acquire: failed to fetch from mount upgrader", "shard", s.key, "error", err)

@@ -35,7 +49,21 @@ func (d *DAGStore) acquireAsync(ctx context.Context, w *waiter, s *Shard, mnt mo

log.Debugw("acquire: successfully fetched from mount upgrader", "shard", s.key)

// acquire the index.
idx, err := d.indices.GetFullIndex(k)

if err := ctx.Err(); err != nil {
log.Warnw("context cancelled while indexing shard; releasing", "shard", s.key, "error", err)

// release the shard to decrement the refcount that's incremented before `acquireAsync` is called.
_ = d.queueTask(&task{op: OpShardRelease, shard: s}, d.completionCh)

// send the shard error to the caller for correctness; since the context
// is cancelled, the result will be discarded.
d.dispatchResult(&ShardResult{Key: k, Error: err}, w)
return
}

if err != nil {
log.Warnw("acquire: failed to get index for shard", "shard", s.key, "error", err)
if err := reader.Close(); err != nil {
@@ -54,9 +82,20 @@ func (d *DAGStore) acquireAsync(ctx context.Context, w *waiter, s *Shard, mnt mo
}

log.Debugw("acquire: successful; returning accessor", "shard", s.key)

// build the accessor.
sa, err := NewShardAccessor(reader, idx, s)

// send the shard accessor to the caller.
// send the shard accessor to the caller, adding a notifyDead function that
// will be called to release the shard if we were unable to deliver
// the accessor.
w.notifyDead = func() {
log.Warnw("context cancelled while delivering accessor; releasing", "shard", s.key)

// release the shard to decrement the refcount that's incremented before `acquireAsync` is called.
_ = d.queueTask(&task{op: OpShardRelease, shard: s}, d.completionCh)
}

d.dispatchResult(&ShardResult{Key: k, Accessor: sa, Error: err}, w)
}
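
The two checks added above share a shape: after each blocking step, consult the waiter's context, and if it is already dead, return the refcount taken before acquireAsync ran and hand the error back through the waiter. Condensed into a single hypothetical helper (the name and the extraction are illustrative; the commit inlines this logic rather than factoring it out):

// releaseOnCancel reports whether the acquire should stop because ctx is
// already dead. Illustrative only; assumes the unexported task/queueTask/
// dispatchResult plumbing visible in the hunks above.
func (d *DAGStore) releaseOnCancel(ctx context.Context, w *waiter, s *Shard) bool {
	err := ctx.Err()
	if err == nil {
		return false // context still live; carry on with the acquire
	}
	// the caller is gone: decrement the refcount taken before acquireAsync
	// was scheduled, and route the error through the waiter (where it will
	// be discarded, since nobody is listening).
	_ = d.queueTask(&task{op: OpShardRelease, shard: s}, d.completionCh)
	d.dispatchResult(&ShardResult{Key: s.key, Error: err}, w)
	return true
}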

@@ -70,9 +109,9 @@ func (d *DAGStore) initializeShard(ctx context.Context, s *Shard, mnt mount.Moun
_ = d.failShard(s, d.completionCh, "failed to acquire reader of mount on initialization: %w", err)
return
}
defer reader.Close()

log.Debugw("initialize: successfully fetched from mount upgrader", "shard", s.key)
defer reader.Close()

// works for both CARv1 and CARv2.
var idx index.Index
80 changes: 76 additions & 4 deletions dagstore_test.go
@@ -10,16 +10,17 @@ import (
"testing"
"time"

"github.com/filecoin-project/dagstore/index"
"github.com/filecoin-project/dagstore/mount"
"github.com/filecoin-project/dagstore/shard"
"github.com/filecoin-project/dagstore/testdata"
"github.com/ipfs/go-datastore"
dsq "github.com/ipfs/go-datastore/query"
dssync "github.com/ipfs/go-datastore/sync"
logging "github.com/ipfs/go-log/v2"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"

"github.com/filecoin-project/dagstore/index"
"github.com/filecoin-project/dagstore/mount"
"github.com/filecoin-project/dagstore/shard"
"github.com/filecoin-project/dagstore/testdata"
)

var (
@@ -1251,6 +1252,77 @@ func TestBlockCallback(t *testing.T) {
t.Skip("TODO")
}

func TestWaiterContextCancelled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
ch := make(chan ShardResult)
notifyCh := make(chan struct{})
w := &waiter{ctx: ctx, outCh: ch, notifyDead: func() { close(notifyCh) }}
cancel()
w.deliver(&ShardResult{})
_, open := <-notifyCh
require.False(t, open)
}

func TestAcquireContextCancelled(t *testing.T) {
r := testRegistry(t)
err := r.Register("block", newBlockingMount(&mount.FSMount{FS: testdata.FS}))
require.NoError(t, err)

dagst, err := NewDAGStore(Config{
MountRegistry: r,
TransientsDir: t.TempDir(),
})
require.NoError(t, err)

err = dagst.Start(context.Background())
require.NoError(t, err)

ch := make(chan ShardResult)
k := shard.KeyFromString("foo")
block := newBlockingMount(carv2mnt)
err = dagst.RegisterShard(context.Background(), k, block, ch, RegisterOpts{LazyInitialization: true})
require.NoError(t, err)
res := <-ch
require.NoError(t, res.Error)

ctx, cancel := context.WithCancel(context.Background())
cancel() // start with a cancelled context
err = dagst.AcquireShard(ctx, k, ch, AcquireOpts{})
require.NoError(t, err)

time.Sleep(1 * time.Second)

select {
case res := <-ch:
t.Fatalf("expected no ShardResult, got: %+v", res)
case <-time.After(1 * time.Second):
}

ctx, cancel = context.WithCancel(context.Background())
err = dagst.AcquireShard(ctx, k, ch, AcquireOpts{})
require.NoError(t, err)
block.UnblockNext(1)
cancel() // cancel immediately after unblocking.

time.Sleep(1 * time.Second)

select {
case res := <-ch:
t.Fatalf("expected no ShardResult, got: %+v", res)
case <-time.After(1 * time.Second):
}

// event loop continues to operate.
err = dagst.AcquireShard(context.Background(), k, ch, AcquireOpts{})
require.NoError(t, err)
res = <-ch
require.NoError(t, res.Error)
require.NotNil(t, res.Accessor)
err = res.Accessor.Close()
require.NoError(t, err)

}

// registerShards registers n shards concurrently, using the CARv2 mount.
func registerShards(t *testing.T, dagst *DAGStore, n int, mnt mount.Mount, opts RegisterOpts) (ret []shard.Key) {
grp, _ := errgroup.WithContext(context.Background())
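
TestAcquireContextCancelled leans on newBlockingMount and UnblockNext, which do not appear in this diff. A plausible reconstruction, assuming the mount.Mount and mount.Reader types from the mount package (hypothetical; the real helper may differ):

// blockingMount wraps an inner mount and parks every Fetch until the test
// releases it via UnblockNext. Hypothetical reconstruction for illustration.
type blockingMount struct {
	mount.Mount // Info, Stat, Serialize, etc. are promoted from the inner mount
	unblock chan struct{}
}

func newBlockingMount(inner mount.Mount) *blockingMount {
	return &blockingMount{Mount: inner, unblock: make(chan struct{})}
}

// UnblockNext allows the next n Fetch calls to proceed.
func (b *blockingMount) UnblockNext(n int) {
	for i := 0; i < n; i++ {
		b.unblock <- struct{}{}
	}
}

func (b *blockingMount) Fetch(ctx context.Context) (mount.Reader, error) {
	<-b.unblock // block here until the test unblocks us
	return b.Mount.Fetch(ctx)
}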
20 changes: 10 additions & 10 deletions mount/upgrader_test.go
@@ -222,18 +222,18 @@ func TestUpgraderDeduplicatesRemote(t *testing.T) {
}

func TestUpgraderFetchAndCopyThrottle(t *testing.T) {
nFixedThrottle := 3
nFixedThrottle := 3

tcs := map[string]struct{
ready bool
tcs := map[string]struct {
ready bool
expectedThrottledReads int
}{
"no throttling when mount is not ready":{
ready: false,
"no throttling when mount is not ready": {
ready: false,
expectedThrottledReads: 100,
},
"throttle when mount is ready":{
ready: true,
"throttle when mount is ready": {
ready: true,
expectedThrottledReads: nFixedThrottle,
},
}
@@ -247,7 +247,7 @@ func TestUpgraderFetchAndCopyThrottle(t *testing.T) {

underlyings := make([]*blockingReaderMount, 100)
for i := range upgraders {
underlyings[i] = &blockingReaderMount{isReady:tc.ready, br: &blockingReader{r: io.LimitReader(rand2.Reader, 1)}}
underlyings[i] = &blockingReaderMount{isReady: tc.ready, br: &blockingReader{r: io.LimitReader(rand2.Reader, 1)}}
u, err := Upgrade(underlyings[i], thrt, t.TempDir(), "foo", "")
require.NoError(t, err)
upgraders[i] = u
@@ -326,7 +326,7 @@ func (br *blockingReader) Read(b []byte) (n int, err error) {

type blockingReaderMount struct {
isReady bool
br *blockingReader
br *blockingReader
}

var _ Mount = (*blockingReaderMount)(nil)
@@ -350,7 +350,7 @@ func (b *blockingReaderMount) Stat(ctx context.Context) (Stat, error) {
return Stat{
Exists: true,
Size: 1024,
Ready: b.isReady,
Ready: b.isReady,
}, nil
}

8 changes: 6 additions & 2 deletions shard.go
@@ -11,8 +11,9 @@ import (
// waiter encapsulates a context passed by the user, and the channel they want
// the result returned to.
type waiter struct {
ctx context.Context // governs the op if it's external
outCh chan<- ShardResult // to send back the result
ctx context.Context // governs the op if it's external
outCh chan<- ShardResult // to send back the result
notifyDead func() // called when the context expired and we weren't able to deliver the result
}

func (w waiter) deliver(res *ShardResult) {
@@ -22,6 +23,9 @@ func (w waiter) deliver(res *ShardResult) {
select {
case w.outCh <- *res:
case <-w.ctx.Done():
if w.notifyDead != nil {
w.notifyDead()
}
}
}

10 changes: 8 additions & 2 deletions shard_state.go
@@ -34,12 +34,18 @@ const (
)

func (ss ShardState) String() string {
return [...]string{
strs := [...]string{
ShardStateNew: "ShardStateNew",
ShardStateInitializing: "ShardStateInitializing",
ShardStateAvailable: "ShardStateAvailable",
ShardStateServing: "ShardStateServing",
ShardStateRecovering: "ShardStateRecovering",
ShardStateErrored: "ShardStateErrored",
ShardStateUnknown: "ShardStateUnknown"}[ss]
ShardStateUnknown: "ShardStateUnknown",
}
if ss < 0 || int(ss) >= len(strs) {
// safety comes first.
return "__undefined__"
}
return strs[ss]
}
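
The stringer now bounds-checks the lookup instead of indexing blindly. In isolation, with a hypothetical int-backed enum (not the library's types), the pattern looks like this:

package main

import "fmt"

// Color is a stand-in enum used only to illustrate the bounds-checked lookup.
type Color int

const (
	Red Color = iota
	Green
	Blue
)

func (c Color) String() string {
	strs := [...]string{Red: "Red", Green: "Green", Blue: "Blue"}
	if c < 0 || int(c) >= len(strs) {
		return "__undefined__" // out-of-range values return a sentinel instead of panicking
	}
	return strs[c]
}

func main() {
	fmt.Println(Green)     // Green
	fmt.Println(Color(42)) // __undefined__
}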
