Skip to content

Commit

Permalink
util/stop: Introduce RunLimitedAsyncTask
Browse files Browse the repository at this point in the history
This method adds channel-based backpressure to RunAsyncTask, to avoid
spinning up an unbounded number of goroutines.
  • Loading branch information
bdarnell committed Mar 14, 2016
1 parent d1655b0 commit 441e602
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 0 deletions.
26 changes: 26 additions & 0 deletions util/stop/stopper.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,32 @@ func (s *Stopper) RunAsyncTask(f func()) bool {
return true
}

// RunLimitedAsyncTask runs function f in a goroutine, using the given
// channel as a semaphore to limit the number of tasks that are run
// concurrently to the channel's capacity. Blocks until the semaphore
// is available in order to push back on callers that may be trying to
// create many tasks. Returns false if the Stopper is draining and the
// function is not executed.
func (s *Stopper) RunLimitedAsyncTask(sem chan struct{}, f func()) bool {
file, line, _ := caller.Lookup(1)
key := taskKey{file, line}
select {
case sem <- struct{}{}:
case <-s.ShouldDrain():
return false
}
if !s.runPrelude(key) {
<-sem
return false
}
go func() {
defer s.runPostlude(key)
defer func() { <-sem }()
f()
}()
return true
}

func (s *Stopper) runPrelude(key taskKey) bool {
s.mu.Lock()
defer s.mu.Unlock()
Expand Down
42 changes: 42 additions & 0 deletions util/stop/stopper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package stop_test

import (
"fmt"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -374,6 +375,47 @@ func TestStopperShouldDrain(t *testing.T) {
<-cleanup
}

func TestStopperRunLimitedAsyncTask(t *testing.T) {
defer leaktest.AfterTest(t)()
s := stop.NewStopper()
defer s.Stop()

const maxConcurrency = 5
const duration = 10 * time.Millisecond
sem := make(chan struct{}, maxConcurrency)
var mu sync.Mutex
concurrency := 0
peakConcurrency := 0
var wg sync.WaitGroup

f := func() {
mu.Lock()
concurrency++
if concurrency > peakConcurrency {
peakConcurrency = concurrency
}
mu.Unlock()
time.Sleep(duration)
mu.Lock()
concurrency--
mu.Unlock()
wg.Done()
}

for i := 0; i < maxConcurrency*3; i++ {
wg.Add(1)
s.RunLimitedAsyncTask(sem, f)
}
wg.Wait()
if concurrency != 0 {
t.Fatalf("expected 0 concurrency at end of test but got %d", concurrency)
}
if peakConcurrency != maxConcurrency {
t.Fatalf("expected peak concurrency %d to equal max concurrency %d",
peakConcurrency, maxConcurrency)
}
}

func maybePrint() {
if testing.Verbose() { // This just needs to be complicated enough not to inline.
fmt.Println("blah")
Expand Down

0 comments on commit 441e602

Please sign in to comment.