From 03a48ebe1c79ed0437dcdbc8bf3cd58e901a999d Mon Sep 17 00:00:00 2001 From: Dmitry Vyukov Date: Sun, 8 Feb 2015 18:11:35 +0300 Subject: [PATCH] sync: simplify WaitGroup A comment in waitgroup.go describes the following scenario as the reason to have dynamically created semaphores: // G1: Add(1) // G1: go G2() // G1: Wait() // Context switch after Unlock() and before Semacquire(). // G2: Done() // Release semaphore: sema == 1, waiters == 0. G1 doesn't run yet. // G3: Wait() // Finds counter == 0, waiters == 0, doesn't block. // G3: Add(1) // Makes counter == 1, waiters == 0. // G3: go G4() // G3: Wait() // G1 still hasn't run, G3 finds sema == 1, unblocked! Bug. However, the scenario is incorrect: G3: Add(1) happens concurrently with G1: Wait(), and so there is no reasonable behavior of the program (G1: Wait() may or may not wait for G3: Add(1) which can't be the intended behavior). With this conclusion we can: 1. Remove dynamic allocation of semaphores. 2. Remove the mutex entirely and instead pack counter and waiters into single uint64. This makes the logic significantly simpler, both Add and Wait do only a single atomic RMW to update the state. benchmark old ns/op new ns/op delta BenchmarkWaitGroupUncontended 30.6 32.7 +6.86% BenchmarkWaitGroupActuallyWait 722 595 -17.59% BenchmarkWaitGroupActuallyWait-2 396 319 -19.44% BenchmarkWaitGroupActuallyWait-4 224 183 -18.30% BenchmarkWaitGroupActuallyWait-8 134 106 -20.90% benchmark old allocs new allocs delta BenchmarkWaitGroupActuallyWait 2 1 -50.00% benchmark old bytes new bytes delta BenchmarkWaitGroupActuallyWait 48 16 -66.67% Change-Id: I28911f3243aa16544e99ac8f1f5af31944c7ea3a Reviewed-on: https://go-review.googlesource.com/4117 Run-TryBot: Dmitry Vyukov TryBot-Result: Gobot Gobot Reviewed-by: Ian Lance Taylor Reviewed-by: Russ Cox --- src/sync/waitgroup.go | 134 ++++++++++++++++++------------------- src/sync/waitgroup_test.go | 112 +++++++++++++++++++++++++++++++ 2 files changed, 179 insertions(+), 67 deletions(-) diff --git a/src/sync/waitgroup.go b/src/sync/waitgroup.go index 92cc57d2cc87ea..de399e64ebbfc8 100644 --- a/src/sync/waitgroup.go +++ b/src/sync/waitgroup.go @@ -15,23 +15,21 @@ import ( // runs and calls Done when finished. At the same time, // Wait can be used to block until all goroutines have finished. type WaitGroup struct { - m Mutex - counter int32 - waiters int32 - sema *uint32 + // 64-bit value: high 32 bits are counter, low 32 bits are waiter count. + // 64-bit atomic operations require 64-bit alignment, but 32-bit + // compilers do not ensure it. So we allocate 12 bytes and then use + // the aligned 8 bytes in them as state. + state1 [12]byte + sema uint32 } -// WaitGroup creates a new semaphore each time the old semaphore -// is released. This is to avoid the following race: -// -// G1: Add(1) -// G1: go G2() -// G1: Wait() // Context switch after Unlock() and before Semacquire(). -// G2: Done() // Release semaphore: sema == 1, waiters == 0. G1 doesn't run yet. -// G3: Wait() // Finds counter == 0, waiters == 0, doesn't block. -// G3: Add(1) // Makes counter == 1, waiters == 0. -// G3: go G4() -// G3: Wait() // G1 still hasn't run, G3 finds sema == 1, unblocked! Bug. +func (wg *WaitGroup) state() *uint64 { + if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 { + return (*uint64)(unsafe.Pointer(&wg.state1)) + } else { + return (*uint64)(unsafe.Pointer(&wg.state1[4])) + } +} // Add adds delta, which may be negative, to the WaitGroup counter. // If the counter becomes zero, all goroutines blocked on Wait are released. @@ -43,10 +41,13 @@ type WaitGroup struct { // at any time. // Typically this means the calls to Add should execute before the statement // creating the goroutine or other event to be waited for. +// If a WaitGroup is reused to wait for several independent sets of events, +// new Add calls must happen after all previous Wait calls have returned. // See the WaitGroup example. func (wg *WaitGroup) Add(delta int) { + statep := wg.state() if raceenabled { - _ = wg.m.state // trigger nil deref early + _ = *statep // trigger nil deref early if delta < 0 { // Synchronize decrements with Wait. raceReleaseMerge(unsafe.Pointer(wg)) @@ -54,7 +55,9 @@ func (wg *WaitGroup) Add(delta int) { raceDisable() defer raceEnable() } - v := atomic.AddInt32(&wg.counter, int32(delta)) + state := atomic.AddUint64(statep, uint64(delta)<<32) + v := int32(state >> 32) + w := uint32(state) if raceenabled { if delta > 0 && v == int32(delta) { // The first increment must be synchronized with Wait. @@ -66,18 +69,25 @@ func (wg *WaitGroup) Add(delta int) { if v < 0 { panic("sync: negative WaitGroup counter") } - if v > 0 || atomic.LoadInt32(&wg.waiters) == 0 { + if w != 0 && delta > 0 && v == int32(delta) { + panic("sync: WaitGroup misuse: Add called concurrently with Wait") + } + if v > 0 || w == 0 { return } - wg.m.Lock() - if atomic.LoadInt32(&wg.counter) == 0 { - for i := int32(0); i < wg.waiters; i++ { - runtime_Semrelease(wg.sema) - } - wg.waiters = 0 - wg.sema = nil + // This goroutine has set counter to 0 when waiters > 0. + // Now there can't be concurrent mutations of state: + // - Adds must not happen concurrently with Wait, + // - Wait does not increment waiters if it sees counter == 0. + // Still do a cheap sanity check to detect WaitGroup misuse. + if *statep != state { + panic("sync: WaitGroup misuse: Add called concurrently with Wait") + } + // Reset waiters count to 0. + *statep = 0 + for ; w != 0; w-- { + runtime_Semrelease(&wg.sema) } - wg.m.Unlock() } // Done decrements the WaitGroup counter. @@ -87,51 +97,41 @@ func (wg *WaitGroup) Done() { // Wait blocks until the WaitGroup counter is zero. func (wg *WaitGroup) Wait() { + statep := wg.state() if raceenabled { - _ = wg.m.state // trigger nil deref early + _ = *statep // trigger nil deref early raceDisable() } - if atomic.LoadInt32(&wg.counter) == 0 { - if raceenabled { - raceEnable() - raceAcquire(unsafe.Pointer(wg)) - } - return - } - wg.m.Lock() - w := atomic.AddInt32(&wg.waiters, 1) - // This code is racing with the unlocked path in Add above. - // The code above modifies counter and then reads waiters. - // We must modify waiters and then read counter (the opposite order) - // to avoid missing an Add. - if atomic.LoadInt32(&wg.counter) == 0 { - atomic.AddInt32(&wg.waiters, -1) - if raceenabled { - raceEnable() - raceAcquire(unsafe.Pointer(wg)) - raceDisable() + for { + state := atomic.LoadUint64(statep) + v := int32(state >> 32) + w := uint32(state) + if v == 0 { + // Counter is 0, no need to wait. + if raceenabled { + raceEnable() + raceAcquire(unsafe.Pointer(wg)) + } + return } - wg.m.Unlock() - if raceenabled { - raceEnable() + // Increment waiters count. + if atomic.CompareAndSwapUint64(statep, state, state+1) { + if raceenabled && w == 0 { + // Wait must be synchronized with the first Add. + // Need to model this is as a write to race with the read in Add. + // As a consequence, can do the write only for the first waiter, + // otherwise concurrent Waits will race with each other. + raceWrite(unsafe.Pointer(&wg.sema)) + } + runtime_Semacquire(&wg.sema) + if *statep != 0 { + panic("sync: WaitGroup is reused before previous Wait has returned") + } + if raceenabled { + raceEnable() + raceAcquire(unsafe.Pointer(wg)) + } + return } - return - } - if raceenabled && w == 1 { - // Wait must be synchronized with the first Add. - // Need to model this is as a write to race with the read in Add. - // As a consequence, can do the write only for the first waiter, - // otherwise concurrent Waits will race with each other. - raceWrite(unsafe.Pointer(&wg.sema)) - } - if wg.sema == nil { - wg.sema = new(uint32) - } - s := wg.sema - wg.m.Unlock() - runtime_Semacquire(s) - if raceenabled { - raceEnable() - raceAcquire(unsafe.Pointer(wg)) } } diff --git a/src/sync/waitgroup_test.go b/src/sync/waitgroup_test.go index 4c0a043c01ee7d..06a77798d088d3 100644 --- a/src/sync/waitgroup_test.go +++ b/src/sync/waitgroup_test.go @@ -5,6 +5,7 @@ package sync_test import ( + "runtime" . "sync" "sync/atomic" "testing" @@ -60,6 +61,90 @@ func TestWaitGroupMisuse(t *testing.T) { t.Fatal("Should panic") } +func TestWaitGroupMisuse2(t *testing.T) { + if runtime.NumCPU() <= 2 { + t.Skip("NumCPU<=2, skipping: this test requires parallelism") + } + defer func() { + err := recover() + if err != "sync: negative WaitGroup counter" && + err != "sync: WaitGroup misuse: Add called concurrently with Wait" && + err != "sync: WaitGroup is reused before previous Wait has returned" { + t.Fatalf("Unexpected panic: %#v", err) + } + }() + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(4)) + done := make(chan interface{}, 2) + // The detection is opportunistically, so we want it to panic + // at least in one run out of a million. + for i := 0; i < 1e6; i++ { + var wg WaitGroup + wg.Add(1) + go func() { + defer func() { + done <- recover() + }() + wg.Wait() + }() + go func() { + defer func() { + done <- recover() + }() + wg.Add(1) // This is the bad guy. + wg.Done() + }() + wg.Done() + for j := 0; j < 2; j++ { + if err := <-done; err != nil { + panic(err) + } + } + } + t.Fatal("Should panic") +} + +func TestWaitGroupMisuse3(t *testing.T) { + if runtime.NumCPU() <= 1 { + t.Skip("NumCPU==1, skipping: this test requires parallelism") + } + defer func() { + err := recover() + if err != "sync: negative WaitGroup counter" && + err != "sync: WaitGroup misuse: Add called concurrently with Wait" && + err != "sync: WaitGroup is reused before previous Wait has returned" { + t.Fatalf("Unexpected panic: %#v", err) + } + }() + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(4)) + done := make(chan interface{}, 1) + // The detection is opportunistically, so we want it to panic + // at least in one run out of a million. + for i := 0; i < 1e6; i++ { + var wg WaitGroup + wg.Add(1) + go func() { + wg.Done() + }() + go func() { + defer func() { + done <- recover() + }() + wg.Wait() + // Start reusing the wg before waiting for the Wait below to return. + wg.Add(1) + go func() { + wg.Done() + }() + wg.Wait() + }() + wg.Wait() + if err := <-done; err != nil { + panic(err) + } + } + t.Fatal("Should panic") +} + func TestWaitGroupRace(t *testing.T) { // Run this test for about 1ms. for i := 0; i < 1000; i++ { @@ -85,6 +170,19 @@ func TestWaitGroupRace(t *testing.T) { } } +func TestWaitGroupAlign(t *testing.T) { + type X struct { + x byte + wg WaitGroup + } + var x X + x.wg.Add(1) + go func(x *X) { + x.wg.Done() + }(&x) + x.wg.Wait() +} + func BenchmarkWaitGroupUncontended(b *testing.B) { type PaddedWaitGroup struct { WaitGroup @@ -146,3 +244,17 @@ func BenchmarkWaitGroupWait(b *testing.B) { func BenchmarkWaitGroupWaitWork(b *testing.B) { benchmarkWaitGroupWait(b, 100) } + +func BenchmarkWaitGroupActuallyWait(b *testing.B) { + b.ReportAllocs() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + var wg WaitGroup + wg.Add(1) + go func() { + wg.Done() + }() + wg.Wait() + } + }) +}