diff --git a/djherbis-stream.iml b/djherbis-stream.iml new file mode 100644 index 0000000..49df094 --- /dev/null +++ b/djherbis-stream.iml @@ -0,0 +1,10 @@ + + + + + + + + + + \ No newline at end of file diff --git a/reader.go b/reader.go index 348030b..632f3e5 100644 --- a/reader.go +++ b/reader.go @@ -61,10 +61,12 @@ func (r *Reader) read(p []byte, off *int64) (n int, err error) { } func (r *Reader) checkErr(err error) error { - switch err { - case ErrCanceled: + switch { + case err == ErrCanceled, + err != nil && err == r.s.b.err: r.Close() } + return err } diff --git a/stream.go b/stream.go index 224b93c..2758286 100644 --- a/stream.go +++ b/stream.go @@ -125,6 +125,13 @@ func (s *Stream) Cancel() error { return s.Close() // all writes are stopped } +// CancelWithErr works like Stream.Cancel, but permits a custom error +// to be returned. +func (s *Stream) CancelWithErr(err error) error { + s.b.CancelWithErr(err) // all existing reads are canceled, no new reads will occur, all readers closed + return s.Close() // all writes are stopped +} + // NextReader will return a concurrent-safe Reader for this stream. Each Reader will // see a complete and independent view of the stream, and can Read while the stream // is written to. diff --git a/stream_test.go b/stream_test.go index aeb1306..b81a5a4 100644 --- a/stream_test.go +++ b/stream_test.go @@ -321,6 +321,67 @@ func testCancelBeforeClose(t *testing.T, fs FileSystem) { cleanup(f, t) } +func TestCancelWithErrBeforeClose(t *testing.T) { + for _, fs := range GetFilesystems() { + testCancelWithErrBeforeClose(t, fs) + } +} + +func testCancelWithErrBeforeClose(t *testing.T, fs FileSystem) { + wantErr := errors.New("oh dear") + f, err := NewStream(t.Name()+".txt", fs) + if err != nil { + t.Error(err) + t.FailNow() + } + f.Write([]byte("Hello")) + r, err := f.NextReader() // blocking reader + if err != nil { + t.Error("error creating new reader: ", err) + } + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + _, err := ioutil.ReadAll(r) + if err != wantErr { + t.Error("Read after cancel should return an error") + } + wg.Done() + }() + <-time.After(50 * time.Millisecond) // give Reader time to block, this tests it unblocks + + // When canceling writer, reader is closed, so writer unblocks and test passes + f.CancelWithErr(wantErr) + // Double cancel should not affect the outcome + f.CancelWithErr(wantErr) + // Close after cancel should not affect the outcome + f.Close() + + // ReadAt after cancel + _, err = ioutil.ReadAll(io.NewSectionReader(r, 0, 1)) + if err != wantErr { + t.Error("ReadAt after cancel should return an error") + } + + // NextReader should fail as well + _, err = f.NextReader() + if err != wantErr { + t.Error("NextReader should be canceled, but got: ", err) + } + + n, err := f.Write([]byte("world")) + // Writer is closed as well + if err == nil { + t.Error("expected write after canceling to fail") + } + if n != 0 { + t.Error("expected write after canceling to not write anything") + } + wg.Wait() + cleanup(f, t) +} + func TestCancelAfterClose(t *testing.T) { for _, fs := range GetFilesystems() { testCancelAfterClose(t, fs) @@ -369,6 +430,55 @@ func testCancelAfterClose(t *testing.T, fs FileSystem) { cleanup(f, t) } +func TestCancelWithErrAfterClose(t *testing.T) { + for _, fs := range GetFilesystems() { + testCancelWithErrAfterClose(t, fs) + } +} + +func testCancelWithErrAfterClose(t *testing.T, fs FileSystem) { + wantErr := errors.New("oh dear") + f, err := NewStream(t.Name()+".txt", fs) + if err != nil { + t.Error(err) + t.FailNow() + } + + r, _ := f.NextReader() + + wg := sync.WaitGroup{} + + wg.Add(2) + + f.Write([]byte("Hello")) + f.Close() + + // This unblocks and cancels any future reads + f.CancelWithErr(wantErr) + + go func() { + time.Sleep(50 * time.Millisecond) + _, err := f.NextReader() + if err != wantErr { + t.Error("Opening new reader after canceling should fail") + } + wg.Done() + }() + + go func() { + time.Sleep(50 * time.Millisecond) + _, err := ioutil.ReadAll(r) + if err != wantErr { + t.Error("If canceling after closing, already opened readers should finish") + } + wg.Done() + }() + + wg.Wait() + + cleanup(f, t) +} + func TestShutdownAfterClose(t *testing.T) { for _, fs := range GetFilesystems() { testShutdownAfterClose(t, fs) diff --git a/sync.go b/sync.go index ef0c8d4..f480e90 100644 --- a/sync.go +++ b/sync.go @@ -22,13 +22,13 @@ const ( ) type broadcaster struct { - mu sync.RWMutex - cond *sync.Cond - state streamState - size int64 - newHandleErr error - rs *readerSet - fileInUse sync.WaitGroup + mu sync.RWMutex + cond *sync.Cond + state streamState + size int64 + err error + rs *readerSet + fileInUse sync.WaitGroup } func newBroadcaster() *broadcaster { @@ -50,6 +50,9 @@ func (b *broadcaster) Wait(r *Reader, off int64) error { switch b.state { case canceledState: + if b.err != nil { + return b.err + } return ErrCanceled case closedState: @@ -97,6 +100,20 @@ func (b *broadcaster) Cancel() (err error) { return nil } +func (b *broadcaster) CancelWithErr(cancelErr error) (err error) { + b.mu.Lock() + b.setState(canceledState) + b.preventNewHandles(cancelErr) + readersToClose := b.rs.dropAll() + b.mu.Unlock() + + for _, r := range readersToClose { + r.Close() + } + + return nil +} + func (b *broadcaster) PreventNewHandles(err error) { b.mu.Lock() b.preventNewHandles(err) @@ -104,8 +121,8 @@ func (b *broadcaster) PreventNewHandles(err error) { } func (b *broadcaster) preventNewHandles(err error) { - if b.newHandleErr == nil { - b.newHandleErr = err + if b.err == nil { + b.err = err } } @@ -118,7 +135,11 @@ func (b *broadcaster) UseHandle(do func() (int, error)) (int, error) { switch b.state { case canceledState: b.mu.RUnlock() - return 0, ErrCanceled + err := b.err + if err == nil { + err = ErrCanceled + } + return 0, err } b.mu.RUnlock() @@ -149,8 +170,8 @@ func (b *broadcaster) Size() (size int64, isClosed bool) { func (b *broadcaster) addHandle() error { b.mu.RLock() defer b.mu.RUnlock() - if b.newHandleErr != nil { - return b.newHandleErr + if b.err != nil { + return b.err } b.fileInUse.Add(1)