diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel index 82f03dca50a3..3f789be6631e 100644 --- a/pkg/BUILD.bazel +++ b/pkg/BUILD.bazel @@ -189,6 +189,7 @@ ALL_TESTS = [ "//pkg/kv/kvserver/intentresolver:intentresolver_test", "//pkg/kv/kvserver/liveness:liveness_test", "//pkg/kv/kvserver/loqrecovery:loqrecovery_test", + "//pkg/kv/kvserver/multiqueue:multiqueue_test", "//pkg/kv/kvserver/protectedts/ptcache:ptcache_test", "//pkg/kv/kvserver/protectedts/ptreconcile:ptreconcile_test", "//pkg/kv/kvserver/protectedts/ptstorage:ptstorage_test", @@ -1111,6 +1112,8 @@ GO_TARGETS = [ "//pkg/kv/kvserver/loqrecovery/loqrecoverypb:loqrecoverypb", "//pkg/kv/kvserver/loqrecovery:loqrecovery", "//pkg/kv/kvserver/loqrecovery:loqrecovery_test", + "//pkg/kv/kvserver/multiqueue:multiqueue", + "//pkg/kv/kvserver/multiqueue:multiqueue_test", "//pkg/kv/kvserver/protectedts/ptcache:ptcache", "//pkg/kv/kvserver/protectedts/ptcache:ptcache_test", "//pkg/kv/kvserver/protectedts/ptpb:ptpb", @@ -2361,6 +2364,7 @@ GET_X_DATA_TARGETS = [ "//pkg/kv/kvserver/liveness/livenesspb:get_x_data", "//pkg/kv/kvserver/loqrecovery:get_x_data", "//pkg/kv/kvserver/loqrecovery/loqrecoverypb:get_x_data", + "//pkg/kv/kvserver/multiqueue:get_x_data", "//pkg/kv/kvserver/protectedts:get_x_data", "//pkg/kv/kvserver/protectedts/ptcache:get_x_data", "//pkg/kv/kvserver/protectedts/ptpb:get_x_data", diff --git a/pkg/kv/kvserver/multiqueue/BUILD.bazel b/pkg/kv/kvserver/multiqueue/BUILD.bazel new file mode 100644 index 000000000000..0df500a1f1bc --- /dev/null +++ b/pkg/kv/kvserver/multiqueue/BUILD.bazel @@ -0,0 +1,30 @@ +load("//build/bazelutil/unused_checker:unused.bzl", "get_x_data") +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "multiqueue", + srcs = ["multi_queue.go"], + importpath = "github.com/cockroachdb/cockroach/pkg/kv/kvserver/multiqueue", + visibility = ["//visibility:public"], + deps = [ + "//pkg/util/stop", + "//pkg/util/syncutil", + "@com_github_cockroachdb_redact//:redact", + ], +) + +go_test( + name = "multiqueue_test", + srcs = ["multi_queue_test.go"], + embed = [":multiqueue"], + deps = [ + "//pkg/testutils", + "//pkg/util/leaktest", + "//pkg/util/log", + "//pkg/util/stop", + "@com_github_cockroachdb_errors//:errors", + "@com_github_stretchr_testify//require", + ], +) + +get_x_data(name = "get_x_data") diff --git a/pkg/kv/kvserver/multiqueue/multi_queue.go b/pkg/kv/kvserver/multiqueue/multi_queue.go new file mode 100644 index 000000000000..fcdeb19d2d7b --- /dev/null +++ b/pkg/kv/kvserver/multiqueue/multi_queue.go @@ -0,0 +1,246 @@ +// Copyright 2022 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package multiqueue + +import ( + "container/heap" + "context" + "sync" + + "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/redact" +) + +// Task represents a request for a Permit for a piece of work that needs to be +// done. It is created by a call to MultiQueue.Add. After creation, +// Task.GetWaitChan is called to get a permit, and after all work related to +// this task is done, MultiQueue.Release must be called so future tasks can run. +// Alternatively, if the user decides they no longer want to run their work, +// MultiQueue.Cancel can be called to release the permit without waiting for the +// permit. +type Task struct { + permitC chan Permit + priority float64 + queueName string + heapIdx int +} + +// GetWaitChan returns a permit channel which is used to wait for the permit to +// become available. +func (t *Task) GetWaitChan() chan Permit { + return t.permitC +} + +func (t *Task) String() string { + return redact.Sprintf("{Queue name : %s, Priority :%f}", t.queueName, t.priority).StripMarkers() +} + +// notifyHeap is a standard go heap over tasks. +type notifyHeap []*Task + +func (h notifyHeap) Len() int { + return len(h) +} + +func (h notifyHeap) Less(i, j int) bool { + return h[j].priority < h[i].priority +} + +func (h notifyHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] + h[i].heapIdx = i + h[j].heapIdx = j +} + +func (h *notifyHeap) Push(x interface{}) { + t := x.(*Task) + // Set the index to the end, it will be moved later + t.heapIdx = h.Len() + *h = append(*h, t) +} + +func (h *notifyHeap) Pop() interface{} { + old := *h + n := len(old) + x := old[n-1] + old[n-1] = nil + *h = old[0 : n-1] + // No longer in the heap so clear the index + x.heapIdx = -1 + + return x +} + +// tryRemove attempts to remove the task from this queue by iterating through +// the queue. Will returns true if the task was successfully removed. +func (h *notifyHeap) tryRemove(task *Task) bool { + if task.heapIdx < 0 { + return false + } + heap.Remove(h, task.heapIdx) + return true +} + +// MultiQueue is a type that round-robins through a set of named queues, each +// independently prioritized. A MultiQueue is constructed with a concurrencySem +// which is the number of concurrent jobs this queue will allow to run. Tasks +// are added to the queue using MultiQueue.Add. That will return a channel that +// should be received from. It will be notified when the waiting job is ready to +// be run. Once the job is completed, MultiQueue.TaskDone must be called to +// return the Permit to the queue so that the next Task can be started. +type MultiQueue struct { + mu syncutil.Mutex + wakeUp *sync.Cond + concurrencySem chan Permit + nameMapping map[string]int + lastQueueIndex int + outstanding []notifyHeap + stopping bool + name string +} + +// NewMultiQueue creates a new queue. The queue is not started, and start needs +// to be called on it first. +func NewMultiQueue(name string, maxConcurrency int) *MultiQueue { + queue := MultiQueue{ + concurrencySem: make(chan Permit, maxConcurrency), + nameMapping: make(map[string]int), + name: name, + } + queue.wakeUp = sync.NewCond(&queue.mu) + queue.lastQueueIndex = -1 + + // Fill all the permits in the queue. + for i := 0; i < maxConcurrency; i++ { + queue.concurrencySem <- Permit{} + } + + return &queue +} + +// Permit is a token which is returned from a Task.GetWaitChan call. +type Permit struct{} + +// Start begins the main loop of this MultiQueue which will continue until Stop +// is called. A MultiQueue.Start should not be started more than once, or after +// Stop has been called. +func (m *MultiQueue) Start(startCtx context.Context, stopper *stop.Stopper) { + _ = stopper.RunAsyncTask(startCtx, m.name+"-multi-queue-quiesce", func(ctx context.Context) { + // Wait for the quiesce signal. Once we are signaled we need to do three things: + // * Close the concurrencySem so no more permits are available. + // * Set the stopping flag in case we are waiting for new tasks. + // * Signal the + <-stopper.ShouldQuiesce() + m.mu.Lock() + close(m.concurrencySem) + m.stopping = true + m.wakeUp.Signal() + m.mu.Unlock() + }) + _ = stopper.RunAsyncTask(startCtx, m.name+"-multi-queue", func(ctx context.Context) { + // Run until the concurrencySem is closed after Stopper.ShouldQuiesce is called. + for p := range m.concurrencySem { + // Hold the lock once we get the permit until we are able to run, or we + // are waiting in wakeUp.Wait. + m.mu.Lock() + for { + // If stopping is set, then we are shutting down, so release lock and return. + if m.stopping { + break + } + // If we gave a permit, then we are done. + if m.tryRunNext(p) { + break + } + // If there are no tasks on any queues, wait until one gets added. + m.wakeUp.Wait() + } + m.mu.Unlock() + } + }) +} + +// tryRunNext will run the next task in order round-robin through the queues and in +// priority order within a queue. It will return true if it ran a task. The +// MultiQueue.mu lock must be held before calling this func. +func (m *MultiQueue) tryRunNext(permit Permit) bool { + for i := 0; i < len(m.outstanding); i++ { + // Start with the next queue in order and iterate through all empty queues. + // If all queues are empty then return false signaling that nothing was run. + index := (m.lastQueueIndex + i + 1) % len(m.outstanding) + if m.outstanding[index].Len() > 0 { + task := heap.Pop(&m.outstanding[index]).(*Task) + task.permitC <- permit + m.lastQueueIndex = index + return true + } + } + return false +} + +// Add returns a Task that must be closed (calling Task.Close) to +// release the Permit. The number of names is expected to +// be relatively small and not be changing over time. +func (m *MultiQueue) Add(name string, priority float64) *Task { + m.mu.Lock() + defer m.mu.Unlock() + + // The mutex starts locked, unlock it when we are ready to run. + pos, ok := m.nameMapping[name] + if !ok { + // Append a new entry to both nameMapping and outstanding each time there is + // a new queue name. + pos = len(m.outstanding) + m.nameMapping[name] = pos + m.outstanding = append(m.outstanding, notifyHeap{}) + } + newTask := Task{ + priority: priority, + permitC: make(chan Permit, 1), + heapIdx: -1, + queueName: name, + } + heap.Push(&m.outstanding[pos], &newTask) + + // Once we are done adding a task, signal the main loop in case it finished + // all its work and was waiting for more work. We are holding the mu lock when + // signaling, so we guarantee that it will not be able to respond to the + // signal until after we release the lock. + m.wakeUp.Signal() + + return &newTask +} + +// Cancel will cancel a Task that may not have started yet. This is useful if it +// is determined that it is no longer required to run this Task. +func (m *MultiQueue) Cancel(task *Task) { + m.mu.Lock() + defer m.mu.Unlock() + // Find the right queue and try to remove it. Queues monotonically grow, and a + // Task will track its position within the queue. + queueIdx := m.nameMapping[task.queueName] + ok := m.outstanding[queueIdx].tryRemove(task) + // If we are not able to remove it from the queue, then it is either running + // or completed. + if !ok { + if p, ok := <-task.permitC; ok { + m.Release(p) + } + } +} + +// Release needs to be called once the Task that was running has completed and +// is no longer using system resources. This allows the MultiQueue to call the +// next Task. +func (m *MultiQueue) Release(permit Permit) { + m.concurrencySem <- permit +} diff --git a/pkg/kv/kvserver/multiqueue/multi_queue_test.go b/pkg/kv/kvserver/multiqueue/multi_queue_test.go new file mode 100644 index 000000000000..4f337874ffc9 --- /dev/null +++ b/pkg/kv/kvserver/multiqueue/multi_queue_test.go @@ -0,0 +1,212 @@ +// Copyright 2022 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package multiqueue + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/require" +) + +// TestMultiQueueEmpty makes sure that an empty queue can be created, started +// and stopped. +func TestMultiQueueEmpty(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + ctx := context.Background() + stopper := stop.NewStopper() + + queue := NewMultiQueue("test", 1) + queue.Start(ctx, stopper) + stopper.Stop(ctx) +} + +// TestMultiQueueAddTwiceSameQueue makes sure that for a single queue the +// priority is respected. +func TestMultiQueueAddTwiceSameQueue(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + ctx := context.Background() + stopper := stop.NewStopper() + queue := NewMultiQueue("test", 1) + + chan1 := queue.Add("a", 1.0) + chan2 := queue.Add("a", 2.0) + + queue.Start(ctx, stopper) + + // Verify chan2 is higher priority so runs first. + verifyOrder(t, queue, chan2, chan1) + stopper.Stop(ctx) +} + +// TestMultiQueueTwoQueues checks that if requests are added to two queue names, +// they are called in a round-robin order. It also verifies that the priority is +// respected for each. +func TestMultiQueueTwoQueues(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + ctx := context.Background() + stopper := stop.NewStopper() + queue := NewMultiQueue("test", 1) + + a1 := queue.Add("a", 4.0) + a2 := queue.Add("a", 5.0) + + b1 := queue.Add("b", 1.0) + b2 := queue.Add("b", 2.0) + + // The queue starts with the "second" item added. + queue.Start(ctx, stopper) + verifyOrder(t, queue, a2, b2, a1, b1) + stopper.Stop(ctx) +} + +// TestMultiQueueComplex verifies that with multiple queues, some added before +// and some after we start running, that the final order is still as expected. +// The expectation is that it round robins through the queues (a, b, c, ...) and +// runs higher priority tasks before lower priority within a queue. +func TestMultiQueueComplex(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + ctx := context.Background() + stopper := stop.NewStopper() + + queue := NewMultiQueue("test", 1) + + a2 := queue.Add("a", 4.0) + b1 := queue.Add("b", 1.1) + b2 := queue.Add("b", 2.1) + c2 := queue.Add("c", 1.2) + c3 := queue.Add("c", 2.2) + a3 := queue.Add("a", 5.0) + b3 := queue.Add("b", 6.1) + + queue.Start(ctx, stopper) + + verifyOrder(t, queue, a3, b3, c3, a2, b2, c2, b1) + stopper.Stop(ctx) +} + +func TestMultiQueueRemove(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + ctx := context.Background() + stopper := stop.NewStopper() + + queue := NewMultiQueue("test", 1) + + a2 := queue.Add("a", 4.0) + b1 := queue.Add("b", 1.1) + b2 := queue.Add("b", 2.1) + c2 := queue.Add("c", 1.2) + c3 := queue.Add("c", 2.2) + a3 := queue.Add("a", 5.0) + b3 := queue.Add("b", 6.1) + + queue.Cancel(b2) + queue.Cancel(b1) + + queue.Start(ctx, stopper) + + verifyOrder(t, queue, a3, b3, c3, a2, c2) + stopper.Stop(ctx) +} + +// TestMultiQueueStress calls Add from multiple threads. It chooses different +// names and different priorities for the requests. The goal is simply to make +// sure that all the requests are serviced and nothing hangs or fails. +func TestMultiQueueStress(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + testutils.RunTrueAndFalse(t, "cancel", func(t *testing.T, alsoCancel bool) { + testutils.RunTrueAndFalse(t, "sleep", func(t *testing.T, alsoSleep bool) { + + ctx := context.Background() + stopper := stop.NewStopper() + + queue := NewMultiQueue("test", 5) + queue.Start(ctx, stopper) + + numThreads := 10 + numRequests := 1000 + var wg sync.WaitGroup + wg.Add(numThreads) + var ops int64 + var timeCancels int64 + + for i := 0; i < numThreads; i++ { + go func(name string) { + for j := 0; j < numRequests; j++ { + curTask := queue.Add(name, float64(j)) + if alsoCancel && j%99 == 0 { + queue.Cancel(curTask) + } else { + select { + case <-time.After(400 * time.Microsecond): + queue.Cancel(curTask) + atomic.AddInt64(&timeCancels, 1) + case p := <-curTask.GetWaitChan(): + if alsoSleep && j%10 == 0 { + // Sleep on 10% of requests to simulate doing work. + time.Sleep(200 * time.Microsecond) + } + queue.Release(p) + } + } + atomic.AddInt64(&ops, 1) + } + wg.Done() + }("queue" + fmt.Sprint(i%4)) + } + wg.Wait() + fmt.Printf("Num time cancels %d / %d\n", timeCancels, ops) + require.Equal(t, int64(numThreads*numRequests), ops) + stopper.Stop(ctx) + }) + }) +} + +// verifyOrder makes sure that the chans are called in the specified order. +func verifyOrder(t *testing.T, queue *MultiQueue, tasks ...*Task) { + // each time, verify that the only available channel is the "next" one in order + for i, task := range tasks { + var found Permit + testutils.SucceedsWithin(t, func() error { + for j, t2 := range tasks[i+1:] { + select { + case <-t2.GetWaitChan(): + return errors.Newf("Queue active when should not be iter %d, chan %d, task %v", i, i+1+j, t2) + default: + } + } + select { + case p := <-task.GetWaitChan(): + found = p + default: + return errors.Newf("Queue not active when should be Queue %d : task %v", i, task) + } + return nil + }, 2*time.Second) + queue.Release(found) + } +}