Skip to content

Commit

Permalink
fix: thread safety for Manual functions (#1)
Browse files Browse the repository at this point in the history
Signed-off-by: Keith Zantow <[email protected]>
  • Loading branch information
kzantow authored Mar 1, 2023
1 parent 4b1c25a commit 21920a4
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 32 deletions.
68 changes: 52 additions & 16 deletions manual.go
Original file line number Diff line number Diff line change
@@ -1,35 +1,71 @@
package progress

import (
"sync"
"sync/atomic"
)

type Manual struct {
N int64
Total int64
Err error
n int64
total int64
err error
errMutex sync.Mutex
}

func NewManual(size int64) *Manual {
return &Manual{
total: size,
}
}

func (p Manual) Current() int64 {
return int64(p.N)
func (p *Manual) Current() int64 {
return atomic.LoadInt64(&p.n)
}

func (p Manual) Size() int64 {
return int64(p.Total)
func (p *Manual) Size() int64 {
return atomic.LoadInt64(&p.total)
}

func (p Manual) Error() error {
return p.Err
func (p *Manual) Error() error {
p.errMutex.Lock()
defer p.errMutex.Unlock()
return p.err
}

func (p Manual) Progress() Progress {
func (p *Manual) SetError(err error) {
p.errMutex.Lock()
defer p.errMutex.Unlock()
p.err = err
}

func (p *Manual) Progress() Progress {
return Progress{
current: p.N,
size: p.Total,
err: p.Err,
current: p.Current(),
size: p.Size(),
err: p.Error(),
}
}

func (p *Manual) Add(n int64) {
atomic.AddInt64(&p.n, n)
}

func (p *Manual) Increment() {
atomic.AddInt64(&p.n, 1)
}

func (p *Manual) Set(n int64) {
atomic.StoreInt64(&p.n, n)
}

func (p *Manual) SetTotal(total int64) {
atomic.StoreInt64(&p.total, total)
}

func (p *Manual) SetCompleted() {
p.Err = ErrCompleted
if p.N > 0 && p.Total <= 0 {
p.Total = p.N
p.SetError(ErrCompleted)
if p.Current() > 0 && p.Size() <= 0 {
p.SetTotal(p.Current())
return
}
}
28 changes: 12 additions & 16 deletions reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,27 @@ import (

// Reader should wrap another reader (acts as a bytes pass through)
type Reader struct {
reader io.Reader
reader io.Reader
monitor *Manual
}

func NewSizedReader(reader io.Reader, size int64) *Reader {
return &Reader{
reader: reader,
monitor: &Manual{
Total: size,
},
reader: reader,
monitor: NewManual(size),
}
}

func NewReader(reader io.Reader) *Reader {
return &Reader{
reader: reader,
monitor: &Manual{
Total: -1,
},
reader: reader,
monitor: NewManual(-1),
}
}

func NewProxyReader(reader io.Reader, monitor *Manual) *Reader {
return &Reader{
reader: reader,
reader: reader,
monitor: monitor,
}
}
Expand All @@ -42,26 +38,26 @@ func (r *Reader) SetReader(reader io.Reader) {
}

func (r *Reader) SetCompleted() {
r.monitor.Err = multierror.Append(r.monitor.Err, ErrCompleted)
r.monitor.SetError(multierror.Append(r.monitor.Error(), ErrCompleted))
}

func (r *Reader) Read(p []byte) (n int, err error) {
bytes, err := r.reader.Read(p)
r.monitor.N += int64(bytes)
r.monitor.Add(int64(bytes))
if err != nil {
r.monitor.Err = multierror.Append(r.monitor.Err, err)
r.monitor.SetError(multierror.Append(r.monitor.Error(), err))
}
return bytes, err
}

func (r *Reader) Current() int64 {
return r.monitor.N
return r.monitor.Current()
}

func (r *Reader) Size() int64 {
return r.monitor.Total
return r.monitor.Size()
}

func (r *Reader) Error() error {
return r.monitor.Err
return r.monitor.Error()
}

0 comments on commit 21920a4

Please sign in to comment.