diff --git a/api/next/51972.txt b/api/next/51972.txt
new file mode 100644
index 0000000000000..cab7b3a8a908c
--- /dev/null
+++ b/api/next/51972.txt
@@ -0,0 +1,3 @@
+pkg sync, method (*Map) Swap(interface{}, interface{}) (interface{}, bool) #51972
+pkg sync, method (*Map) CompareAndSwap(interface{}, interface{}, interface{}) bool #51972
+pkg sync, method (*Map) CompareAndDelete(interface{}, interface{}) bool #51972
diff --git a/src/sync/map.go b/src/sync/map.go
index fa1cf7cee2bfb..658cef65cfcbc 100644
--- a/src/sync/map.go
+++ b/src/sync/map.go
@@ -150,47 +150,33 @@ func (e *entry) load() (value any, ok bool) {
 // Store sets the value for a key.
 func (m *Map) Store(key, value any) {
-	read := m.loadReadOnly()
-	if e, ok := read.m[key]; ok && e.tryStore(&value) {
-		return
-	}
-
-	m.mu.Lock()
-	read = m.loadReadOnly()
-	if e, ok := read.m[key]; ok {
-		if e.unexpungeLocked() {
-			// The entry was previously expunged, which implies that there is a
-			// non-nil dirty map and this entry is not in it.
-			m.dirty[key] = e
-		}
-		e.storeLocked(&value)
-	} else if e, ok := m.dirty[key]; ok {
-		e.storeLocked(&value)
-	} else {
-		if !read.amended {
-			// We're adding the first new key to the dirty map.
-			// Make sure it is allocated and mark the read-only map as incomplete.
-			m.dirtyLocked()
-			m.read.Store(&readOnly{m: read.m, amended: true})
-		}
-		m.dirty[key] = newEntry(value)
-	}
-	m.mu.Unlock()
+	_, _ = m.Swap(key, value)
 }
 
-// tryStore stores a value if the entry has not been expunged.
+// tryCompareAndSwap compares the entry with the given old value and swaps
+// it with a new value if the entry is equal to the old value, and the entry
+// has not been expunged.
 //
-// If the entry is expunged, tryStore returns false and leaves the entry
-// unchanged.
-func (e *entry) tryStore(i *any) bool {
+// If the entry is expunged, tryCompareAndSwap returns false and leaves
+// the entry unchanged.
+func (e *entry) tryCompareAndSwap(old, new any) bool {
+	p := e.p.Load()
+	if p == nil || p == expunged || *p != old {
+		return false
+	}
+
+	// Copy the interface after the first load to make this method more amenable
+	// to escape analysis: if the comparison fails from the start, we shouldn't
+	// bother heap-allocating an interface value to store.
+	nc := new
 	for {
-		p := e.p.Load()
-		if p == expunged {
-			return false
-		}
-		if e.p.CompareAndSwap(p, i) {
+		if e.p.CompareAndSwap(p, &nc) {
 			return true
 		}
+		p = e.p.Load()
+		if p == nil || p == expunged || *p != old {
+			return false
+		}
 	}
 }
 
@@ -202,11 +188,11 @@ func (e *entry) unexpungeLocked() (wasExpunged bool) {
 	return e.p.CompareAndSwap(expunged, nil)
 }
 
-// storeLocked unconditionally stores a value to the entry.
+// swapLocked unconditionally swaps a value into the entry.
 //
 // The entry must be known not to be expunged.
-func (e *entry) storeLocked(i *any) {
-	e.p.Store(i)
+func (e *entry) swapLocked(i *any) *any {
+	return e.p.Swap(i)
 }
 
 // LoadOrStore returns the existing value for the key if present.
@@ -321,6 +307,132 @@ func (e *entry) delete() (value any, ok bool) {
 	}
 }
 
+// trySwap swaps a value if the entry has not been expunged.
+//
+// If the entry is expunged, trySwap returns false and leaves the entry
+// unchanged.
+func (e *entry) trySwap(i *any) (*any, bool) {
+	for {
+		p := e.p.Load()
+		if p == expunged {
+			return nil, false
+		}
+		if e.p.CompareAndSwap(p, i) {
+			return p, true
+		}
+	}
+}
+
+// Swap swaps the value for a key and returns the previous value if any.
+// The loaded result reports whether the key was present.
+func (m *Map) Swap(key, value any) (previous any, loaded bool) {
+	read := m.loadReadOnly()
+	if e, ok := read.m[key]; ok {
+		if v, ok := e.trySwap(&value); ok {
+			if v == nil {
+				return nil, false
+			}
+			return *v, true
+		}
+	}
+
+	m.mu.Lock()
+	read = m.loadReadOnly()
+	if e, ok := read.m[key]; ok {
+		if e.unexpungeLocked() {
+			// The entry was previously expunged, which implies that there is a
+			// non-nil dirty map and this entry is not in it.
+			m.dirty[key] = e
+		}
+		if v := e.swapLocked(&value); v != nil {
+			loaded = true
+			previous = *v
+		}
+	} else if e, ok := m.dirty[key]; ok {
+		if v := e.swapLocked(&value); v != nil {
+			loaded = true
+			previous = *v
+		}
+	} else {
+		if !read.amended {
+			// We're adding the first new key to the dirty map.
+			// Make sure it is allocated and mark the read-only map as incomplete.
+			m.dirtyLocked()
+			m.read.Store(&readOnly{m: read.m, amended: true})
+		}
+		m.dirty[key] = newEntry(value)
+	}
+	m.mu.Unlock()
+	return previous, loaded
+}
+
+// CompareAndSwap swaps the old and new values for key
+// if the value stored in the map is equal to old.
+// The old value must be of a comparable type.
+func (m *Map) CompareAndSwap(key, old, new any) bool {
+	read := m.loadReadOnly()
+	if e, ok := read.m[key]; ok {
+		return e.tryCompareAndSwap(old, new)
+	} else if !read.amended {
+		return false // No existing value for key.
+	}
+
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	read = m.loadReadOnly()
+	swapped := false
+	if e, ok := read.m[key]; ok {
+		swapped = e.tryCompareAndSwap(old, new)
+	} else if e, ok := m.dirty[key]; ok {
+		swapped = e.tryCompareAndSwap(old, new)
+		// We needed to lock mu in order to load the entry for key,
+		// and the operation didn't change the set of keys in the map
+		// (so it would be made more efficient by promoting the dirty
+		// map to read-only).
+		// Count it as a miss so that we will eventually switch to the
+		// more efficient steady state.
+		m.missLocked()
+	}
+	return swapped
+}
+
+// CompareAndDelete deletes the entry for key if its value is equal to old.
+// The old value must be of a comparable type.
+//
+// If there is no current value for key in the map, CompareAndDelete
+// returns false (even if the old value is the nil interface value).
+func (m *Map) CompareAndDelete(key, old any) (deleted bool) {
+	read := m.loadReadOnly()
+	e, ok := read.m[key]
+	if !ok && read.amended {
+		m.mu.Lock()
+		read = m.loadReadOnly()
+		e, ok = read.m[key]
+		if !ok && read.amended {
+			e, ok = m.dirty[key]
+			// Don't delete key from m.dirty: we still need to do the “compare” part
+			// of the operation. The entry will eventually be expunged when the
+			// dirty map is promoted to the read map.
+			//
+			// Regardless of whether the entry was present, record a miss: this key
+			// will take the slow path until the dirty map is promoted to the read
+			// map.
+			m.missLocked()
+		}
+		m.mu.Unlock()
+	}
+	for ok {
+		p := e.p.Load()
+		if p == nil || p == expunged || *p != old {
+			return false
+		}
+		if e.p.CompareAndSwap(p, nil) {
+			return true
+		}
+	}
+	return false
+}
+
 // Range calls f sequentially for each key and value present in the map.
 // If f returns false, range stops the iteration.
 //
diff --git a/src/sync/map_bench_test.go b/src/sync/map_bench_test.go
index e7b0e6039c44d..4815f57349ec1 100644
--- a/src/sync/map_bench_test.go
+++ b/src/sync/map_bench_test.go
@@ -198,7 +198,9 @@ func BenchmarkLoadAndDeleteCollision(b *testing.B) {
 
 		perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) {
 			for ; pb.Next(); i++ {
-				m.LoadAndDelete(0)
+				if _, loaded := m.LoadAndDelete(0); loaded {
+					m.Store(0, 0)
+				}
 			}
 		},
 	})
@@ -287,3 +289,248 @@ func BenchmarkDeleteCollision(b *testing.B) {
 		},
 	})
 }
+
+func BenchmarkSwapCollision(b *testing.B) {
+	benchMap(b, bench{
+		setup: func(_ *testing.B, m mapInterface) {
+			m.LoadOrStore(0, 0)
+		},
+
+		perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) {
+			for ; pb.Next(); i++ {
+				m.Swap(0, 0)
+			}
+		},
+	})
+}
+
+func BenchmarkSwapMostlyHits(b *testing.B) {
+	const hits, misses = 1023, 1
+
+	benchMap(b, bench{
+		setup: func(_ *testing.B, m mapInterface) {
+			for i := 0; i < hits; i++ {
+				m.LoadOrStore(i, i)
+			}
+			// Prime the map to get it into a steady state.
+			for i := 0; i < hits*2; i++ {
+				m.Load(i % hits)
+			}
+		},
+
+		perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) {
+			for ; pb.Next(); i++ {
+				if i%(hits+misses) < hits {
+					v := i % (hits + misses)
+					m.Swap(v, v)
+				} else {
+					m.Swap(i, i)
+					m.Delete(i)
+				}
+			}
+		},
+	})
+}
+
+func BenchmarkSwapMostlyMisses(b *testing.B) {
+	const hits, misses = 1, 1023
+
+	benchMap(b, bench{
+		setup: func(_ *testing.B, m mapInterface) {
+			for i := 0; i < hits; i++ {
+				m.LoadOrStore(i, i)
+			}
+			// Prime the map to get it into a steady state.
+			for i := 0; i < hits*2; i++ {
+				m.Load(i % hits)
+			}
+		},
+
+		perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) {
+			for ; pb.Next(); i++ {
+				if i%(hits+misses) < hits {
+					v := i % (hits + misses)
+					m.Swap(v, v)
+				} else {
+					m.Swap(i, i)
+					m.Delete(i)
+				}
+			}
+		},
+	})
+}
+
+func BenchmarkCompareAndSwapCollision(b *testing.B) {
+	benchMap(b, bench{
+		setup: func(_ *testing.B, m mapInterface) {
+			m.LoadOrStore(0, 0)
+		},
+
+		perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) {
+			for pb.Next() {
+				if m.CompareAndSwap(0, 0, 42) {
+					m.CompareAndSwap(0, 42, 0)
+				}
+			}
+		},
+	})
+}
+
+func BenchmarkCompareAndSwapNoExistingKey(b *testing.B) {
+	benchMap(b, bench{
+		perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) {
+			for ; pb.Next(); i++ {
+				if m.CompareAndSwap(i, 0, 0) {
+					m.Delete(i)
+				}
+			}
+		},
+	})
+}
+
+func BenchmarkCompareAndSwapValueNotEqual(b *testing.B) {
+	const n = 1 << 10
+	benchMap(b, bench{
+		setup: func(_ *testing.B, m mapInterface) {
+			m.Store(0, 0)
+		},
+
+		perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) {
+			for ; pb.Next(); i++ {
+				m.CompareAndSwap(0, 1, 2)
+			}
+		},
+	})
+}
+
+func BenchmarkCompareAndSwapMostlyHits(b *testing.B) {
+	const hits, misses = 1023, 1
+
+	benchMap(b, bench{
+		setup: func(b *testing.B, m mapInterface) {
+			if _, ok := m.(*DeepCopyMap); ok {
+				b.Skip("DeepCopyMap has quadratic running time.")
+			}
+
+			for i := 0; i < hits; i++ {
+				m.LoadOrStore(i, i)
+			}
+			// Prime the map to get it into a steady state.
+			for i := 0; i < hits*2; i++ {
+				m.Load(i % hits)
+			}
+		},
+
+		perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) {
+			for ; pb.Next(); i++ {
+				v := i
+				if i%(hits+misses) < hits {
+					v = i % (hits + misses)
+				}
+				m.CompareAndSwap(v, v, v)
+			}
+		},
+	})
+}
+
+func BenchmarkCompareAndSwapMostlyMisses(b *testing.B) {
+	const hits, misses = 1, 1023
+
+	benchMap(b, bench{
+		setup: func(_ *testing.B, m mapInterface) {
+			for i := 0; i < hits; i++ {
+				m.LoadOrStore(i, i)
+			}
+			// Prime the map to get it into a steady state.
+			for i := 0; i < hits*2; i++ {
+				m.Load(i % hits)
+			}
+		},
+
+		perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) {
+			for ; pb.Next(); i++ {
+				v := i
+				if i%(hits+misses) < hits {
+					v = i % (hits + misses)
+				}
+				m.CompareAndSwap(v, v, v)
+			}
+		},
+	})
+}
+
+func BenchmarkCompareAndDeleteCollision(b *testing.B) {
+	benchMap(b, bench{
+		setup: func(_ *testing.B, m mapInterface) {
+			m.LoadOrStore(0, 0)
+		},
+
+		perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) {
+			for ; pb.Next(); i++ {
+				if m.CompareAndDelete(0, 0) {
+					m.Store(0, 0)
+				}
+			}
+		},
+	})
+}
+
+func BenchmarkCompareAndDeleteMostlyHits(b *testing.B) {
+	const hits, misses = 1023, 1
+
+	benchMap(b, bench{
+		setup: func(b *testing.B, m mapInterface) {
+			if _, ok := m.(*DeepCopyMap); ok {
+				b.Skip("DeepCopyMap has quadratic running time.")
+			}
+
+			for i := 0; i < hits; i++ {
+				m.LoadOrStore(i, i)
+			}
+			// Prime the map to get it into a steady state.
+			for i := 0; i < hits*2; i++ {
+				m.Load(i % hits)
+			}
+		},
+
+		perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) {
+			for ; pb.Next(); i++ {
+				v := i
+				if i%(hits+misses) < hits {
+					v = i % (hits + misses)
+				}
+				if m.CompareAndDelete(v, v) {
+					m.Store(v, v)
+				}
+			}
+		},
+	})
+}
+
+func BenchmarkCompareAndDeleteMostlyMisses(b *testing.B) {
+	const hits, misses = 1, 1023
+
+	benchMap(b, bench{
+		setup: func(_ *testing.B, m mapInterface) {
+			for i := 0; i < hits; i++ {
+				m.LoadOrStore(i, i)
+			}
+			// Prime the map to get it into a steady state.
+			for i := 0; i < hits*2; i++ {
+				m.Load(i % hits)
+			}
+		},
+
+		perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) {
+			for ; pb.Next(); i++ {
+				v := i
+				if i%(hits+misses) < hits {
+					v = i % (hits + misses)
+				}
+				if m.CompareAndDelete(v, v) {
+					m.Store(v, v)
+				}
+			}
+		},
+	})
+}
diff --git a/src/sync/map_reference_test.go b/src/sync/map_reference_test.go
index 1122b40b9b835..aa5ebf352f98f 100644
--- a/src/sync/map_reference_test.go
+++ b/src/sync/map_reference_test.go
@@ -18,9 +18,17 @@ type mapInterface interface {
 	LoadOrStore(key, value any) (actual any, loaded bool)
 	LoadAndDelete(key any) (value any, loaded bool)
 	Delete(any)
+	Swap(key, value any) (previous any, loaded bool)
+	CompareAndSwap(key, old, new any) (swapped bool)
+	CompareAndDelete(key, old any) (deleted bool)
 	Range(func(key, value any) (shouldContinue bool))
 }
 
+var (
+	_ mapInterface = &RWMutexMap{}
+	_ mapInterface = &DeepCopyMap{}
+)
+
 // RWMutexMap is an implementation of mapInterface using a sync.RWMutex.
 type RWMutexMap struct {
 	mu    sync.RWMutex
@@ -57,6 +65,18 @@ func (m *RWMutexMap) LoadOrStore(key, value any) (actual any, loaded bool) {
 	return actual, loaded
 }
 
+func (m *RWMutexMap) Swap(key, value any) (previous any, loaded bool) {
+	m.mu.Lock()
+	if m.dirty == nil {
+		m.dirty = make(map[any]any)
+	}
+
+	previous, loaded = m.dirty[key]
+	m.dirty[key] = value
+	m.mu.Unlock()
+	return
+}
+
 func (m *RWMutexMap) LoadAndDelete(key any) (value any, loaded bool) {
 	m.mu.Lock()
 	value, loaded = m.dirty[key]
@@ -75,6 +95,36 @@ func (m *RWMutexMap) Delete(key any) {
 	m.mu.Unlock()
 }
 
+func (m *RWMutexMap) CompareAndSwap(key, old, new any) (swapped bool) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	if m.dirty == nil {
+		return false
+	}
+
+	value, loaded := m.dirty[key]
+	if loaded && value == old {
+		m.dirty[key] = new
+		return true
+	}
+	return false
+}
+
+func (m *RWMutexMap) CompareAndDelete(key, old any) (deleted bool) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	if m.dirty == nil {
+		return false
+	}
+
+	value, loaded := m.dirty[key]
+	if loaded && value == old {
+		delete(m.dirty, key)
+		return true
+	}
+	return false
+}
+
 func (m *RWMutexMap) Range(f func(key, value any) (shouldContinue bool)) {
 	m.mu.RLock()
 	keys := make([]any, 0, len(m.dirty))
@@ -137,6 +187,16 @@ func (m *DeepCopyMap) LoadOrStore(key, value any) (actual any, loaded bool) {
 	return actual, loaded
 }
 
+func (m *DeepCopyMap) Swap(key, value any) (previous any, loaded bool) {
+	m.mu.Lock()
+	dirty := m.dirty()
+	previous, loaded = dirty[key]
+	dirty[key] = value
+	m.clean.Store(dirty)
+	m.mu.Unlock()
+	return
+}
+
 func (m *DeepCopyMap) LoadAndDelete(key any) (value any, loaded bool) {
 	m.mu.Lock()
 	dirty := m.dirty()
@@ -155,6 +215,43 @@ func (m *DeepCopyMap) Delete(key any) {
 	m.mu.Unlock()
 }
 
+func (m *DeepCopyMap) CompareAndSwap(key, old, new any) (swapped bool) {
+	clean, _ := m.clean.Load().(map[any]any)
+	if previous, ok := clean[key]; !ok || previous != old {
+		return false
+	}
+
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	dirty := m.dirty()
+	value, loaded := dirty[key]
+	if loaded && value == old {
+		dirty[key] = new
+		m.clean.Store(dirty)
+		return true
+	}
+	return false
+}
+
+func (m *DeepCopyMap) CompareAndDelete(key, old any) (deleted bool) {
+	clean, _ := m.clean.Load().(map[any]any)
+	if previous, ok := clean[key]; !ok || previous != old {
+		return false
+	}
+
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	dirty := m.dirty()
+	value, loaded := dirty[key]
+	if loaded && value == old {
+		delete(dirty, key)
+		m.clean.Store(dirty)
+		return true
+	}
+	return false
+}
+
 func (m *DeepCopyMap) Range(f func(key, value any) (shouldContinue bool)) {
 	clean, _ := m.clean.Load().(map[any]any)
 	for k, v := range clean {
diff --git a/src/sync/map_test.go b/src/sync/map_test.go
index 8352471104381..1eb3fc68a5330 100644
--- a/src/sync/map_test.go
+++ b/src/sync/map_test.go
@@ -17,14 +17,26 @@ import (
 type mapOp string
 
 const (
-	opLoad          = mapOp("Load")
-	opStore         = mapOp("Store")
-	opLoadOrStore   = mapOp("LoadOrStore")
-	opLoadAndDelete = mapOp("LoadAndDelete")
-	opDelete        = mapOp("Delete")
+	opLoad             = mapOp("Load")
+	opStore            = mapOp("Store")
+	opLoadOrStore      = mapOp("LoadOrStore")
+	opLoadAndDelete    = mapOp("LoadAndDelete")
+	opDelete           = mapOp("Delete")
+	opSwap             = mapOp("Swap")
+	opCompareAndSwap   = mapOp("CompareAndSwap")
+	opCompareAndDelete = mapOp("CompareAndDelete")
 )
 
-var mapOps = [...]mapOp{opLoad, opStore, opLoadOrStore, opLoadAndDelete, opDelete}
+var mapOps = [...]mapOp{
+	opLoad,
+	opStore,
+	opLoadOrStore,
+	opLoadAndDelete,
+	opDelete,
+	opSwap,
+	opCompareAndSwap,
+	opCompareAndDelete,
+}
 
 // mapCall is a quick.Generator for calls on mapInterface.
 type mapCall struct {
@@ -46,6 +58,21 @@ func (c mapCall) apply(m mapInterface) (any, bool) {
 	case opDelete:
 		m.Delete(c.k)
 		return nil, false
+	case opSwap:
+		return m.Swap(c.k, c.v)
+	case opCompareAndSwap:
+		if m.CompareAndSwap(c.k, c.v, rand.Int()) {
+			m.Delete(c.k)
+			return c.v, true
+		}
+		return nil, false
+	case opCompareAndDelete:
+		if m.CompareAndDelete(c.k, c.v) {
+			if _, ok := m.Load(c.k); !ok {
+				return nil, true
+			}
+		}
+		return nil, false
 	default:
 		panic("invalid mapOp")
 	}
@@ -245,3 +272,11 @@ func TestMapRangeNestedCall(t *testing.T) { // Issue 46399
 		t.Fatalf("Unexpected sync.Map size, got %v want %v", length, 0)
 	}
 }
+
+func TestCompareAndSwap_NonExistingKey(t *testing.T) {
+	m := &sync.Map{}
+	if m.CompareAndSwap(m, nil, 42) {
+		// See https://go.dev/issue/51972#issuecomment-1126408637.
+		t.Fatalf("CompareAndSwap on a non-existing key succeeded")
+	}
+}
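
Reviewer note (not part of the patch): below is a minimal, self-contained usage sketch of the three methods this change adds, assuming a toolchain built with this CL so that sync.Map already exposes Swap, CompareAndSwap, and CompareAndDelete. The commented results follow from the doc comments above: Swap reports the previous value and whether the key was present, while CompareAndSwap and CompareAndDelete act only when the stored value equals old and report false for a missing key.

package main

import (
	"fmt"
	"sync"
)

func main() {
	var m sync.Map
	m.Store("k", 1)

	// Swap replaces the value and reports the previous one.
	prev, loaded := m.Swap("k", 2)
	fmt.Println(prev, loaded) // 1 true

	// CompareAndSwap succeeds only when the current value equals old.
	fmt.Println(m.CompareAndSwap("k", 2, 3)) // true
	fmt.Println(m.CompareAndSwap("k", 2, 4)) // false: value is now 3

	// A key that is not in the map reports false.
	fmt.Println(m.CompareAndSwap("missing", nil, 42)) // false

	// CompareAndDelete removes the key only if the value matches.
	fmt.Println(m.CompareAndDelete("k", 999)) // false
	fmt.Println(m.CompareAndDelete("k", 3))   // true
	_, ok := m.Load("k")
	fmt.Println(ok) // false
}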