diff --git a/inflight/inflight.go b/inflight/inflight.go index 48988642..7f8b74e1 100644 --- a/inflight/inflight.go +++ b/inflight/inflight.go @@ -16,10 +16,13 @@ package inflight import ( "sync" - - "github.com/scylladb/go-set/u64set" ) +// We track inflights in the map, maps in golang are not shrinking +// Therefore we track how many inflights were deleted and when it reaches the limit +// we forcefully recreate the map to shrink it +const shrinkInflightsLimit = 1000000 + type InFlight interface { AddIfNotPresent(uint64) bool Delete(uint64) @@ -28,13 +31,15 @@ type InFlight interface { // New creates a instance of a simple InFlight set. // It's internal data is protected by a simple sync.RWMutex. func New() InFlight { - return newSyncU64set() + return newSyncU64set(shrinkInflightsLimit) } -func newSyncU64set() *syncU64set { +func newSyncU64set(limit uint64) *syncU64set { return &syncU64set{ - pks: u64set.New(), - mu: &sync.RWMutex{}, + values: make(map[uint64]bool), + limit: limit, + deleted: 0, + lock: sync.RWMutex{}, } } @@ -48,7 +53,7 @@ func NewConcurrent() InFlight { func newShardedSyncU64set() *shardedSyncU64set { s := &shardedSyncU64set{} for i := range s.shards { - s.shards[i] = newSyncU64set() + s.shards[i] = newSyncU64set(shrinkInflightsLimit) } return s } @@ -69,35 +74,71 @@ func (s *shardedSyncU64set) AddIfNotPresent(v uint64) bool { return ss.AddIfNotPresent(v) } -// syncU64set is an InFlight implementation protected by a sync.RWLock type syncU64set struct { - pks *u64set.Set - mu *sync.RWMutex + values map[uint64]bool + deleted uint64 + limit uint64 + lock sync.RWMutex } -func (s *syncU64set) Delete(v uint64) { - s.mu.Lock() - defer s.mu.Unlock() - s.pks.Remove(v) +func (s *syncU64set) AddIfNotPresent(u uint64) bool { + s.lock.RLock() + _, ok := s.values[u] + if ok { + s.lock.RUnlock() + return false + } + s.lock.RUnlock() + s.lock.Lock() + defer s.lock.Unlock() + _, ok = s.values[u] + if ok { + return false + } + s.values[u] = true + return true } -func (s *syncU64set) AddIfNotPresent(v uint64) bool { - s.mu.RLock() - if s.pks.Has(v) { - s.mu.RUnlock() - return false +func (s *syncU64set) Has(u uint64) bool { + s.lock.RLock() + defer s.lock.RUnlock() + _, ok := s.values[u] + return ok +} + +func (s *syncU64set) Delete(u uint64) { + s.lock.Lock() + defer s.lock.Unlock() + _, ok := s.values[u] + if !ok { + return } - s.mu.RUnlock() - return s.addIfNotPresent(v) + delete(s.values, u) + s.addDeleted(1) } -func (s *syncU64set) addIfNotPresent(v uint64) bool { - s.mu.Lock() - defer s.mu.Unlock() - if s.pks.Has(v) { - // double check - return false +func (s *syncU64set) addDeleted(n uint64) { + s.deleted += n + if s.limit != 0 && s.deleted > s.limit { + go s.shrink() } - s.pks.Add(v) - return true +} + +func (s *syncU64set) shrink() { + s.lock.Lock() + defer s.lock.Unlock() + var newValues map[uint64]bool + if uint64(len(s.values)) >= s.deleted { + newValues = make(map[uint64]bool, uint64(len(s.values))-s.deleted) + } else { + newValues = make(map[uint64]bool, 0) + } + + for key, val := range s.values { + if val == true { + newValues[key] = val + } + } + s.values = newValues + s.deleted = 0 } diff --git a/inflight/inflight_test.go b/inflight/inflight_test.go index f73415b4..ce117414 100644 --- a/inflight/inflight_test.go +++ b/inflight/inflight_test.go @@ -23,7 +23,7 @@ import ( func TestAddIfNotPresent(t *testing.T) { t.Parallel() - flight := newSyncU64set() + flight := newSyncU64set(shrinkInflightsLimit) if !flight.AddIfNotPresent(10) { t.Error("could not add the first value") } @@ -34,11 +34,11 @@ func TestAddIfNotPresent(t *testing.T) { func TestDelete(t *testing.T) { t.Parallel() - flight := newSyncU64set() + flight := newSyncU64set(shrinkInflightsLimit) flight.AddIfNotPresent(10) flight.Delete(10) - if flight.pks.Has(10) { + if flight.Has(10) { t.Error("did not delete the value") } } @@ -60,20 +60,20 @@ func TestDeleteSharded(t *testing.T) { flight.AddIfNotPresent(10) flight.Delete(10) - if flight.shards[10%256].pks.Has(10) { + if flight.shards[10%256].Has(10) { t.Error("did not delete the value") } } func TestInflight(t *testing.T) { t.Parallel() - flight := newSyncU64set() + flight := newSyncU64set(shrinkInflightsLimit) f := func(v uint64) interface{} { return flight.AddIfNotPresent(v) } g := func(v uint64) interface{} { flight.Delete(v) - return !flight.pks.Has(v) + return !flight.Has(v) } cfg := createQuickConfig() @@ -90,7 +90,7 @@ func TestInflightSharded(t *testing.T) { } g := func(v uint64) interface{} { flight.Delete(v) - return !flight.shards[v%256].pks.Has(v) + return !flight.shards[v%256].Has(v) } cfg := createQuickConfig()