diff --git a/src/sync/waitgroup.go b/src/sync/waitgroup.go
index 92cc57d2cc87ea..de399e64ebbfc8 100644
--- a/src/sync/waitgroup.go
+++ b/src/sync/waitgroup.go
@@ -15,23 +15,21 @@ import (
 // runs and calls Done when finished. At the same time,
 // Wait can be used to block until all goroutines have finished.
 type WaitGroup struct {
-	m       Mutex
-	counter int32
-	waiters int32
-	sema    *uint32
+	// 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
+	// 64-bit atomic operations require 64-bit alignment, but 32-bit
+	// compilers do not ensure it. So we allocate 12 bytes and then use
+	// the aligned 8 bytes in them as state.
+	state1 [12]byte
+	sema   uint32
 }
 
-// WaitGroup creates a new semaphore each time the old semaphore
-// is released. This is to avoid the following race:
-//
-// G1: Add(1)
-// G1: go G2()
-// G1: Wait() // Context switch after Unlock() and before Semacquire().
-// G2: Done() // Release semaphore: sema == 1, waiters == 0. G1 doesn't run yet.
-// G3: Wait() // Finds counter == 0, waiters == 0, doesn't block.
-// G3: Add(1) // Makes counter == 1, waiters == 0.
-// G3: go G4()
-// G3: Wait() // G1 still hasn't run, G3 finds sema == 1, unblocked! Bug.
+func (wg *WaitGroup) state() *uint64 {
+	if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
+		return (*uint64)(unsafe.Pointer(&wg.state1))
+	} else {
+		return (*uint64)(unsafe.Pointer(&wg.state1[4]))
+	}
+}
 
 // Add adds delta, which may be negative, to the WaitGroup counter.
 // If the counter becomes zero, all goroutines blocked on Wait are released.
@@ -43,10 +41,13 @@ type WaitGroup struct {
 // at any time.
 // Typically this means the calls to Add should execute before the statement
 // creating the goroutine or other event to be waited for.
+// If a WaitGroup is reused to wait for several independent sets of events,
+// new Add calls must happen after all previous Wait calls have returned.
 // See the WaitGroup example.
 func (wg *WaitGroup) Add(delta int) {
+	statep := wg.state()
 	if raceenabled {
-		_ = wg.m.state // trigger nil deref early
+		_ = *statep // trigger nil deref early
 		if delta < 0 {
 			// Synchronize decrements with Wait.
 			raceReleaseMerge(unsafe.Pointer(wg))
@@ -54,7 +55,9 @@ func (wg *WaitGroup) Add(delta int) {
 		raceDisable()
 		defer raceEnable()
 	}
-	v := atomic.AddInt32(&wg.counter, int32(delta))
+	state := atomic.AddUint64(statep, uint64(delta)<<32)
+	v := int32(state >> 32)
+	w := uint32(state)
 	if raceenabled {
 		if delta > 0 && v == int32(delta) {
 			// The first increment must be synchronized with Wait.
@@ -66,18 +69,25 @@ func (wg *WaitGroup) Add(delta int) {
 	if v < 0 {
 		panic("sync: negative WaitGroup counter")
 	}
-	if v > 0 || atomic.LoadInt32(&wg.waiters) == 0 {
+	if w != 0 && delta > 0 && v == int32(delta) {
+		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
+	}
+	if v > 0 || w == 0 {
 		return
 	}
-	wg.m.Lock()
-	if atomic.LoadInt32(&wg.counter) == 0 {
-		for i := int32(0); i < wg.waiters; i++ {
-			runtime_Semrelease(wg.sema)
-		}
-		wg.waiters = 0
-		wg.sema = nil
+	// This goroutine has set counter to 0 when waiters > 0.
+	// Now there can't be concurrent mutations of state:
+	// - Adds must not happen concurrently with Wait,
+	// - Wait does not increment waiters if it sees counter == 0.
+	// Still do a cheap sanity check to detect WaitGroup misuse.
+	if *statep != state {
+		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
+	}
+	// Reset waiters count to 0.
+	*statep = 0
+	for ; w != 0; w-- {
+		runtime_Semrelease(&wg.sema)
 	}
-	wg.m.Unlock()
 }
 
 // Done decrements the WaitGroup counter.
@@ -87,51 +97,41 @@ func (wg *WaitGroup) Done() {
 
 // Wait blocks until the WaitGroup counter is zero.
 func (wg *WaitGroup) Wait() {
+	statep := wg.state()
 	if raceenabled {
-		_ = wg.m.state // trigger nil deref early
+		_ = *statep // trigger nil deref early
 		raceDisable()
 	}
-	if atomic.LoadInt32(&wg.counter) == 0 {
-		if raceenabled {
-			raceEnable()
-			raceAcquire(unsafe.Pointer(wg))
-		}
-		return
-	}
-	wg.m.Lock()
-	w := atomic.AddInt32(&wg.waiters, 1)
-	// This code is racing with the unlocked path in Add above.
-	// The code above modifies counter and then reads waiters.
-	// We must modify waiters and then read counter (the opposite order)
-	// to avoid missing an Add.
-	if atomic.LoadInt32(&wg.counter) == 0 {
-		atomic.AddInt32(&wg.waiters, -1)
-		if raceenabled {
-			raceEnable()
-			raceAcquire(unsafe.Pointer(wg))
-			raceDisable()
+	for {
+		state := atomic.LoadUint64(statep)
+		v := int32(state >> 32)
+		w := uint32(state)
+		if v == 0 {
+			// Counter is 0, no need to wait.
+			if raceenabled {
+				raceEnable()
+				raceAcquire(unsafe.Pointer(wg))
+			}
+			return
 		}
-		wg.m.Unlock()
-		if raceenabled {
-			raceEnable()
+		// Increment waiters count.
+		if atomic.CompareAndSwapUint64(statep, state, state+1) {
+			if raceenabled && w == 0 {
+				// Wait must be synchronized with the first Add.
+				// Need to model this as a write to race with the read in Add.
+				// As a consequence, can do the write only for the first waiter,
+				// otherwise concurrent Waits will race with each other.
+				raceWrite(unsafe.Pointer(&wg.sema))
+			}
+			runtime_Semacquire(&wg.sema)
+			if *statep != 0 {
+				panic("sync: WaitGroup is reused before previous Wait has returned")
+			}
+			if raceenabled {
+				raceEnable()
+				raceAcquire(unsafe.Pointer(wg))
+			}
+			return
 		}
-		return
-	}
-	if raceenabled && w == 1 {
-		// Wait must be synchronized with the first Add.
-		// Need to model this is as a write to race with the read in Add.
-		// As a consequence, can do the write only for the first waiter,
-		// otherwise concurrent Waits will race with each other.
-		raceWrite(unsafe.Pointer(&wg.sema))
-	}
-	if wg.sema == nil {
-		wg.sema = new(uint32)
-	}
-	s := wg.sema
-	wg.m.Unlock()
-	runtime_Semacquire(s)
-	if raceenabled {
-		raceEnable()
-		raceAcquire(unsafe.Pointer(wg))
 	}
 }
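For reference (not part of the patch): the packing that Add and Wait rely on can be sketched in isolation. The counter lives in the high 32 bits of the state word and the waiter count in the low 32 bits, so Add adjusts the counter with a single AddUint64 of delta<<32 and Wait registers itself by adding 1. The helper names below (counter, waiters) are illustrative only, not part of the sync package.

package main

import "fmt"

// counter and waiters decode a packed state word the same way Add and Wait do.
func counter(state uint64) int32  { return int32(state >> 32) }
func waiters(state uint64) uint32 { return uint32(state) }

func main() {
	var state uint64
	state += uint64(2) << 32 // Add(2): counter goes to 2, waiters untouched
	state += 1               // Wait(): one waiter registered in the low half
	fmt.Println(counter(state), waiters(state)) // 2 1
	delta := -2 // two Done() calls, i.e. Add(-2); the wraparound leaves waiters intact
	state += uint64(delta) << 32
	fmt.Println(counter(state), waiters(state)) // 0 1
}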
diff --git a/src/sync/waitgroup_test.go b/src/sync/waitgroup_test.go
index 4c0a043c01ee7d..06a77798d088d3 100644
--- a/src/sync/waitgroup_test.go
+++ b/src/sync/waitgroup_test.go
@@ -5,6 +5,7 @@
 package sync_test
 
 import (
+	"runtime"
 	. "sync"
 	"sync/atomic"
 	"testing"
@@ -60,6 +61,90 @@ func TestWaitGroupMisuse(t *testing.T) {
 	t.Fatal("Should panic")
 }
 
+func TestWaitGroupMisuse2(t *testing.T) {
+	if runtime.NumCPU() <= 2 {
+		t.Skip("NumCPU<=2, skipping: this test requires parallelism")
+	}
+	defer func() {
+		err := recover()
+		if err != "sync: negative WaitGroup counter" &&
+			err != "sync: WaitGroup misuse: Add called concurrently with Wait" &&
+			err != "sync: WaitGroup is reused before previous Wait has returned" {
+			t.Fatalf("Unexpected panic: %#v", err)
+		}
+	}()
+	defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(4))
+	done := make(chan interface{}, 2)
+	// The detection is opportunistic, so we want it to panic
+	// at least in one run out of a million.
+	for i := 0; i < 1e6; i++ {
+		var wg WaitGroup
+		wg.Add(1)
+		go func() {
+			defer func() {
+				done <- recover()
+			}()
+			wg.Wait()
+		}()
+		go func() {
+			defer func() {
+				done <- recover()
+			}()
+			wg.Add(1) // This is the bad guy.
+			wg.Done()
+		}()
+		wg.Done()
+		for j := 0; j < 2; j++ {
+			if err := <-done; err != nil {
+				panic(err)
+			}
+		}
+	}
+	t.Fatal("Should panic")
+}
+
+func TestWaitGroupMisuse3(t *testing.T) {
+	if runtime.NumCPU() <= 1 {
+		t.Skip("NumCPU==1, skipping: this test requires parallelism")
+	}
+	defer func() {
+		err := recover()
+		if err != "sync: negative WaitGroup counter" &&
+			err != "sync: WaitGroup misuse: Add called concurrently with Wait" &&
+			err != "sync: WaitGroup is reused before previous Wait has returned" {
+			t.Fatalf("Unexpected panic: %#v", err)
+		}
+	}()
+	defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(4))
+	done := make(chan interface{}, 1)
+	// The detection is opportunistic, so we want it to panic
+	// at least in one run out of a million.
+	for i := 0; i < 1e6; i++ {
+		var wg WaitGroup
+		wg.Add(1)
+		go func() {
+			wg.Done()
+		}()
+		go func() {
+			defer func() {
+				done <- recover()
+			}()
+			wg.Wait()
+			// Start reusing the wg before waiting for the Wait below to return.
+			wg.Add(1)
+			go func() {
+				wg.Done()
+			}()
+			wg.Wait()
+		}()
+		wg.Wait()
+		if err := <-done; err != nil {
+			panic(err)
+		}
+	}
+	t.Fatal("Should panic")
+}
+
 func TestWaitGroupRace(t *testing.T) {
 	// Run this test for about 1ms.
 	for i := 0; i < 1000; i++ {
@@ -85,6 +170,19 @@ func TestWaitGroupRace(t *testing.T) {
 	}
 }
 
+func TestWaitGroupAlign(t *testing.T) {
+	type X struct {
+		x  byte
+		wg WaitGroup
+	}
+	var x X
+	x.wg.Add(1)
+	go func(x *X) {
+		x.wg.Done()
+	}(&x)
+	x.wg.Wait()
+}
+
 func BenchmarkWaitGroupUncontended(b *testing.B) {
 	type PaddedWaitGroup struct {
 		WaitGroup
@@ -146,3 +244,17 @@ func BenchmarkWaitGroupWait(b *testing.B) {
 func BenchmarkWaitGroupWaitWork(b *testing.B) {
 	benchmarkWaitGroupWait(b, 100)
 }
+
+func BenchmarkWaitGroupActuallyWait(b *testing.B) {
+	b.ReportAllocs()
+	b.RunParallel(func(pb *testing.PB) {
+		for pb.Next() {
+			var wg WaitGroup
+			wg.Add(1)
+			go func() {
+				wg.Done()
+			}()
+			wg.Wait()
+		}
+	})
+}
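As a usage note (not part of the patch), the reuse rule added to Add's documentation can be illustrated with a small, hypothetical sketch: a single WaitGroup may be reused across rounds as long as each round's Add calls begin only after the previous round's Wait has returned.

package main

import (
	"fmt"
	"sync"
)

func main() {
	var wg sync.WaitGroup
	for round := 0; round < 3; round++ {
		// All Adds for this round happen before Wait, and only after the
		// previous round's Wait has returned, as the new doc comment requires.
		for i := 0; i < 4; i++ {
			wg.Add(1)
			go func(r, id int) {
				defer wg.Done()
				fmt.Println("round", r, "worker", id)
			}(round, i)
		}
		wg.Wait() // Returns before the next round calls Add again.
	}
}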