Skip to content

Commit

Permalink
util/stop: add DelayedStopper to enable RunDelayedAsyncTask
Browse files Browse the repository at this point in the history
Release note: None
  • Loading branch information
ajwerner committed Feb 15, 2019
1 parent 673deaf commit 4098122
Show file tree
Hide file tree
Showing 2 changed files with 238 additions and 0 deletions.
185 changes: 185 additions & 0 deletions pkg/util/stop/delayed_stopper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
package stop

import (
"container/heap"
"context"
"time"

"github.com/cockroachdb/cockroach/pkg/util/timeutil"
)

// DelayedStopper extends Stopper to manage queuing and cancellation of
// delayed tasks.
type DelayedStopper struct {
*Stopper
tasks taskQueue
taskChan chan *DelayedTask
cancelChan chan *DelayedTask
}

// NewDelayedStopper creates a new DelayedStopper from a Stopper.
func NewDelayedStopper(stopper *Stopper) *DelayedStopper {
ds := &DelayedStopper{
Stopper: stopper,

taskChan: make(chan *DelayedTask),
cancelChan: make(chan *DelayedTask),
}
_ = ds.RunAsyncTask(context.Background(), "delayed stopper", ds.run)
return ds
}

// DelayedTask is returned from RunDelayedAsyncTask.
// It exposes a Cancel method which can be used to cancel the task before it
// has been started.
type DelayedTask struct {
ds *DelayedStopper
ctx context.Context
taskName string
f func(context.Context)
startTime time.Time

idx int // idx is the index into a taskQueue
}

// Cancel attempts to prevent execution of the DelayedTask before it has been
// started. Calls to Cancel after the task has been run are a no-op.
func (t *DelayedTask) Cancel() {
if t.ds == nil {
return
}
select {
case t.ds.cancelChan <- t:
case <-t.ds.ShouldQuiesce():
}
}

// RunDelayedAsyncTask queues a task for executaion after the provided delay.
// An error is returned if the context is cancelled or the DelayedStopper is
// shutting down before the task can be queued. Otherwise the returned object
// can be used to cancel the task while its queued. The method task an optional
// DelayedTask pointer to allow clients to avoid an allocation by storing the
// DelayedTask struct as a member of another struct.
func (ds *DelayedStopper) RunDelayedAsyncTask(
ctx context.Context,
taskName string,
f func(context.Context),
delay time.Duration,
t *DelayedTask,
) (*DelayedTask, error) {
if t == nil {
t = &DelayedTask{}
}
initTask(ctx, t, ds, taskName, f, delay)
select {
case <-ds.ShouldQuiesce():
return nil, ErrUnavailable
case ds.taskChan <- t:
return t, nil
}
}

func initTask(
ctx context.Context,
t *DelayedTask,
ds *DelayedStopper,
taskName string,
f func(context.Context),
delay time.Duration,
) {
*t = DelayedTask{
ds: ds,
idx: -1,
ctx: ctx,
taskName: taskName,
f: f,
startTime: timeutil.Now().Add(delay),
}
}

func (ds *DelayedStopper) run(ctx context.Context) {
var (
startTime time.Time
timer = timeutil.NewTimer()
maybeSetTimer = func() {
var nextStartTime time.Time
if next := ds.tasks.peekFront(); next != nil {
nextStartTime = next.startTime
}
if !startTime.Equal(nextStartTime) {
startTime = nextStartTime
if !startTime.IsZero() {
timer.Reset(time.Until(startTime))
} else {
// Clear the current timer due to a sole batch already sent before
// the timer fired.
timer.Stop()
timer = timeutil.NewTimer()
}
}
}
)
for {
select {
case t := <-ds.taskChan:
heap.Push(&ds.tasks, t)
case t := <-ds.cancelChan:
if t.idx != -1 {
ds.tasks.remove(t)
}
case <-timer.C:
timer.Read = true
t := ds.tasks.popFront()
_ = ds.RunAsyncTask(t.ctx, t.taskName, t.f)
case <-ds.ShouldQuiesce():
return
}
maybeSetTimer()
}
}

type taskQueue []*DelayedTask

func (q *taskQueue) remove(t *DelayedTask) {
heap.Remove(q, t.idx)
}

func (q *taskQueue) peekFront() *DelayedTask {
if len(*q) == 0 {
return nil
}
return (*q)[0]
}

func (q *taskQueue) popFront() *DelayedTask {
if len(*q) == 0 {
return nil
}
return heap.Pop(q).(*DelayedTask)
}

func (q *taskQueue) Len() int {
return len(*q)
}

func (q *taskQueue) Less(i, j int) bool {
return (*q)[i].startTime.Before((*q)[j].startTime)
}

func (q *taskQueue) Swap(i, j int) {
(*q)[i], (*q)[j] = (*q)[j], (*q)[i]
(*q)[i].idx, (*q)[j].idx = i, j
}

func (q *taskQueue) Push(v interface{}) {
t := v.(*DelayedTask)
t.idx = len(*q)
*q = append(*q, t)
}

func (q *taskQueue) Pop() interface{} {
t := (*q)[len(*q)-1]
t.idx = -1
(*q) = (*q)[:len(*q)-1]
return t
}
53 changes: 53 additions & 0 deletions pkg/util/stop/delayed_stopper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright 2019 The Cockroach Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
// implied. See the License for the specific language governing
// permissions and limitations under the License.

package stop_test

import (
"context"
"testing"
"time"

"github.com/cockroachdb/cockroach/pkg/util/leaktest"
"github.com/cockroachdb/cockroach/pkg/util/stop"
"github.com/stretchr/testify/assert"
)

func TestDelayedTask(t *testing.T) {
defer leaktest.AfterTest(t)()
ds := stop.NewDelayedStopper(stop.NewStopper())
ctx := context.Background()
dt1, err := ds.RunDelayedAsyncTask(ctx, "foo", func(context.Context) {}, time.Minute, nil)
assert.Nil(t, err)
var dt stop.DelayedTask
dt.Cancel() // ensure that Cancel on a zero value is safe.
dt2, err := ds.RunDelayedAsyncTask(ctx, "foo", func(context.Context) {}, 2*time.Minute, &dt)
assert.Nil(t, err)
assert.Equal(t, dt2, &dt)
dt3, err := ds.RunDelayedAsyncTask(ctx, "foo", func(context.Context) {}, 2*time.Minute, nil)
assert.Nil(t, err)
c := make(chan struct{})
_, err = ds.RunDelayedAsyncTask(ctx, "foo", func(context.Context) {
close(c)
}, time.Microsecond, nil)
assert.Nil(t, err)
<-c
dt1.Cancel()
dt2.Cancel()
dt3.Cancel()
ds.Stop(ctx)
dt4, err := ds.RunDelayedAsyncTask(ctx, "foo", func(context.Context) {}, 2*time.Minute, nil)
assert.Equal(t, err, stop.ErrUnavailable)
assert.Nil(t, dt4)
}

0 comments on commit 4098122

Please sign in to comment.