diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel
index 82f03dca50a3..3f789be6631e 100644
--- a/pkg/BUILD.bazel
+++ b/pkg/BUILD.bazel
@@ -189,6 +189,7 @@ ALL_TESTS = [
     "//pkg/kv/kvserver/intentresolver:intentresolver_test",
     "//pkg/kv/kvserver/liveness:liveness_test",
     "//pkg/kv/kvserver/loqrecovery:loqrecovery_test",
+    "//pkg/kv/kvserver/multiqueue:multiqueue_test",
     "//pkg/kv/kvserver/protectedts/ptcache:ptcache_test",
     "//pkg/kv/kvserver/protectedts/ptreconcile:ptreconcile_test",
     "//pkg/kv/kvserver/protectedts/ptstorage:ptstorage_test",
@@ -1111,6 +1112,8 @@ GO_TARGETS = [
     "//pkg/kv/kvserver/loqrecovery/loqrecoverypb:loqrecoverypb",
     "//pkg/kv/kvserver/loqrecovery:loqrecovery",
     "//pkg/kv/kvserver/loqrecovery:loqrecovery_test",
+    "//pkg/kv/kvserver/multiqueue:multiqueue",
+    "//pkg/kv/kvserver/multiqueue:multiqueue_test",
     "//pkg/kv/kvserver/protectedts/ptcache:ptcache",
     "//pkg/kv/kvserver/protectedts/ptcache:ptcache_test",
     "//pkg/kv/kvserver/protectedts/ptpb:ptpb",
@@ -2361,6 +2364,7 @@ GET_X_DATA_TARGETS = [
     "//pkg/kv/kvserver/liveness/livenesspb:get_x_data",
     "//pkg/kv/kvserver/loqrecovery:get_x_data",
     "//pkg/kv/kvserver/loqrecovery/loqrecoverypb:get_x_data",
+    "//pkg/kv/kvserver/multiqueue:get_x_data",
     "//pkg/kv/kvserver/protectedts:get_x_data",
     "//pkg/kv/kvserver/protectedts/ptcache:get_x_data",
     "//pkg/kv/kvserver/protectedts/ptpb:get_x_data",
diff --git a/pkg/kv/kvserver/multiqueue/BUILD.bazel b/pkg/kv/kvserver/multiqueue/BUILD.bazel
new file mode 100644
index 000000000000..874f6b68c0f8
--- /dev/null
+++ b/pkg/kv/kvserver/multiqueue/BUILD.bazel
@@ -0,0 +1,27 @@
+load("//build/bazelutil/unused_checker:unused.bzl", "get_x_data")
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+    name = "multiqueue",
+    srcs = ["multi_queue.go"],
+    importpath = "github.com/cockroachdb/cockroach/pkg/kv/kvserver/multiqueue",
+    visibility = ["//visibility:public"],
+    deps = [
+        "//pkg/util/syncutil",
+        "@com_github_cockroachdb_redact//:redact",
+    ],
+)
+
+go_test(
+    name = "multiqueue_test",
+    srcs = ["multi_queue_test.go"],
+    embed = [":multiqueue"],
+    deps = [
+        "//pkg/testutils",
+        "//pkg/util/leaktest",
+        "//pkg/util/log",
+        "@com_github_stretchr_testify//require",
+    ],
+)
+
+get_x_data(name = "get_x_data")
diff --git a/pkg/kv/kvserver/multiqueue/multi_queue.go b/pkg/kv/kvserver/multiqueue/multi_queue.go
new file mode 100644
index 000000000000..78cf6ef1c953
--- /dev/null
+++ b/pkg/kv/kvserver/multiqueue/multi_queue.go
@@ -0,0 +1,222 @@
+// Copyright 2022 The Cockroach Authors.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+package multiqueue
+
+import (
+	"container/heap"
+
+	"github.com/cockroachdb/cockroach/pkg/util/syncutil"
+	"github.com/cockroachdb/redact"
+)
+
+// Task represents a request for a Permit for a piece of work that needs to be
+// done. It is created by a call to MultiQueue.Add. After creation,
+// Task.GetWaitChan is called to get a permit, and after all work related to
+// this task is done, MultiQueue.Release must be called so future tasks can run.
+// Alternatively, if the caller decides it no longer wants to run the work,
+// MultiQueue.Cancel can be called to give up the Task's place in the queue
+// without waiting for a Permit.
+type Task struct {
+	priority  float64
+	queueName int
+	heapIdx   int
+	permitC   chan Permit
+}
+
+// GetWaitChan returns the channel on which a Permit is delivered once it
+// becomes available.
+func (t *Task) GetWaitChan() <-chan Permit {
+	return t.permitC
+}
+
+func (t *Task) String() string {
+	return redact.Sprintf("{Queue name : %d, Priority :%f}", t.queueName, t.priority).StripMarkers()
+}
+
+// notifyHeap is a standard Go container/heap over Tasks; higher-priority
+// Tasks sort first (a max-heap on priority).
+type notifyHeap []*Task
+
+func (h notifyHeap) Len() int {
+	return len(h)
+}
+
+func (h notifyHeap) Less(i, j int) bool {
+	return h[j].priority < h[i].priority
+}
+
+func (h notifyHeap) Swap(i, j int) {
+	h[i], h[j] = h[j], h[i]
+	h[i].heapIdx = i
+	h[j].heapIdx = j
+}
+
+func (h *notifyHeap) Push(x interface{}) {
+	t := x.(*Task)
+	// Set the index to the end; heap.Push will move it into place.
+	t.heapIdx = h.Len()
+	*h = append(*h, t)
+}
+
+func (h *notifyHeap) Pop() interface{} {
+	old := *h
+	n := len(old)
+	x := old[n-1]
+	old[n-1] = nil
+	*h = old[0 : n-1]
+	// No longer in the heap, so clear the index.
+	x.heapIdx = -1
+
+	return x
+}
+
+// tryRemove attempts to remove the task from this queue using the task's
+// tracked heap index. It returns true if the task was successfully removed.
+func (h *notifyHeap) tryRemove(task *Task) bool {
+	if task.heapIdx < 0 {
+		return false
+	}
+	heap.Remove(h, task.heapIdx)
+	return true
+}
+
+// MultiQueue is a type that round-robins through a set of named queues, each
+// independently prioritized. A MultiQueue is constructed with a
+// maxConcurrency, which is the number of tasks it allows to run concurrently.
+// Tasks are added using MultiQueue.Add, which returns a Task; the channel
+// returned by Task.GetWaitChan is signaled with a Permit when the task is
+// ready to run. Once the work is completed, MultiQueue.Release must be called
+// to return the Permit to the queue so that the next Task can be started.
+type MultiQueue struct {
+	mu             syncutil.Mutex
+	remainingRuns  int
+	mapping        map[int]int
+	lastQueueIndex int
+	outstanding    []notifyHeap
+}
+
+// NewMultiQueue creates a new MultiQueue that allows up to maxConcurrency
+// tasks to run concurrently.
+func NewMultiQueue(maxConcurrency int) *MultiQueue {
+	queue := MultiQueue{
+		remainingRuns: maxConcurrency,
+		mapping:       make(map[int]int),
+	}
+	queue.lastQueueIndex = -1
+
+	return &queue
+}
+
+// Permit is a token which is delivered on the channel returned by
+// Task.GetWaitChan once the Task is allowed to run.
+type Permit struct {
+	valid bool
+}
+
+// tryRunNext will run the next task, round-robin through the queues and in
+// priority order within a queue. It is a no-op if no permits remain or all
+// queues are empty. The MultiQueue.mu lock must be held before calling this
+// func.
+func (m *MultiQueue) tryRunNext() {
+	// If no permits are left, then we can't run anything.
+	if m.remainingRuns <= 0 {
+		return
+	}
+
+	for i := 0; i < len(m.outstanding); i++ {
+		// Start with the next queue in order and iterate past any empty queues.
+		// If all queues are empty then nothing is run.
+		index := (m.lastQueueIndex + i + 1) % len(m.outstanding)
+		if m.outstanding[index].Len() > 0 {
+			task := heap.Pop(&m.outstanding[index]).(*Task)
+			task.permitC <- Permit{valid: true}
+			m.remainingRuns--
+			m.lastQueueIndex = index
+			return
+		}
+	}
+}
+
+// Add adds a Task for the given queue name at the given priority and returns
+// it. Once the Task's Permit has been received from Task.GetWaitChan, it must
+// be returned with MultiQueue.Release; a Task that has not started yet can be
+// abandoned with MultiQueue.Cancel. The number of queue names is expected to
+// be relatively small and not change over time.
+func (m *MultiQueue) Add(name int, priority float64) *Task {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	// Look up the heap for this queue name, creating it the first time the
+	// name is seen.
+	pos, ok := m.mapping[name]
+	if !ok {
+		// Append a new entry to both mapping and outstanding each time there is
+		// a new queue name.
+		pos = len(m.outstanding)
+		m.mapping[name] = pos
+		m.outstanding = append(m.outstanding, notifyHeap{})
+	}
+	newTask := Task{
+		priority:  priority,
+		permitC:   make(chan Permit, 1),
+		heapIdx:   -1,
+		queueName: name,
+	}
+	heap.Push(&m.outstanding[pos], &newTask)
+
+	// Once the task is added, try to run the next task. If a permit is
+	// available, this hands it to the highest-priority task in round-robin
+	// order, which may be the task just added. We hold mu for the whole call,
+	// so the permit cannot be observed until after the lock is released.
+	m.tryRunNext()
+
+	return &newTask
+}
+
+// Cancel will cancel a Task that may not have started yet. This is useful if
+// it is determined that it is no longer required to run this Task.
+func (m *MultiQueue) Cancel(task *Task) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	// Find the right queue and try to remove the task from it. Queues
+	// monotonically grow, and a Task tracks its position within its queue.
+	queueIdx := m.mapping[task.queueName]
+	ok := m.outstanding[queueIdx].tryRemove(task)
+	// If the task is not in the queue, we are racing with the task being
+	// started. The concern is that the caller may also call MultiQueue.Release
+	// once the task has started. Either we drain the permit here or the caller
+	// receives it, so we guarantee only one release will be called.
+	if !ok {
+		select {
+		case p, ok := <-task.permitC:
+			// Only release if the channel is open, and we can get the permit.
+			if ok {
+				p.valid = false
+				m.remainingRuns++
+			}
+		default:
+			// If we are not able to get the permit, this means the permit has already
+			// been given to the caller, and they must call Release on it.
+		}
+	}
+	m.tryRunNext()
+}
+
+// Release needs to be called once the Task that was running has completed and
+// is no longer using system resources. This allows the MultiQueue to run the
+// next Task.
+func (m *MultiQueue) Release(permit *Permit) {
+	if !permit.valid {
+		panic("double release of permit")
+	}
+	permit.valid = false
+
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	// We released one, so we can run one more now.
+	m.remainingRuns++
+	m.tryRunNext()
+}
diff --git a/pkg/kv/kvserver/multiqueue/multi_queue_test.go b/pkg/kv/kvserver/multiqueue/multi_queue_test.go
new file mode 100644
index 000000000000..375b1841f3e6
--- /dev/null
+++ b/pkg/kv/kvserver/multiqueue/multi_queue_test.go
@@ -0,0 +1,218 @@
+// Copyright 2022 The Cockroach Authors.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+package multiqueue
+
+import (
+	"fmt"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/cockroachdb/cockroach/pkg/testutils"
+	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
+	"github.com/cockroachdb/cockroach/pkg/util/log"
+	"github.com/stretchr/testify/require"
+)
+
+// TestMultiQueueEmpty makes sure that an empty queue can be created without
+// doing any work or leaking goroutines.
+func TestMultiQueueEmpty(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	defer log.Scope(t).Close(t)
+	NewMultiQueue(1)
+}
+
+// TestMultiQueueAddTwiceSameQueue makes sure that for a single queue the
+// priority is respected.
+func TestMultiQueueAddTwiceSameQueue(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	defer log.Scope(t).Close(t)
+	queue := NewMultiQueue(1)
+	blocker := queue.Add(0, 0)
+
+	chan1 := queue.Add(7, 1.0)
+	chan2 := queue.Add(7, 2.0)
+
+	permit := <-blocker.GetWaitChan()
+	queue.Release(&permit)
+
+	// Verify chan2 is higher priority so runs first.
+	verifyOrder(t, queue, chan2, chan1)
+}
+
+// TestMultiQueueTwoQueues checks that if requests are added to two queue
+// names, they are run in round-robin order. It also verifies that priority is
+// respected within each queue.
+func TestMultiQueueTwoQueues(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	defer log.Scope(t).Close(t)
+	queue := NewMultiQueue(1)
+	blocker := queue.Add(0, 0)
+
+	a1 := queue.Add(5, 4.0)
+	a2 := queue.Add(5, 5.0)
+
+	b1 := queue.Add(6, 1.0)
+	b2 := queue.Add(6, 2.0)
+
+	permit := <-blocker.GetWaitChan()
+	queue.Release(&permit)
+	// The queue starts with the "second" item added.
+	verifyOrder(t, queue, a2, b2, a1, b1)
+}
+
+// TestMultiQueueComplex verifies that with multiple queues, some added before
+// and some after we start running, the final order is still as expected. The
+// expectation is that it round-robins through the queues (a, b, c, ...) and
+// runs higher-priority tasks before lower-priority ones within a queue.
+func TestMultiQueueComplex(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	defer log.Scope(t).Close(t)
+
+	queue := NewMultiQueue(1)
+	blocker := queue.Add(0, 0)
+
+	a2 := queue.Add(1, 4.0)
+	b1 := queue.Add(2, 1.1)
+	b2 := queue.Add(2, 2.1)
+	c2 := queue.Add(3, 1.2)
+	c3 := queue.Add(3, 2.2)
+	a3 := queue.Add(1, 5.0)
+	b3 := queue.Add(2, 6.1)
+
+	permit := <-blocker.GetWaitChan()
+	queue.Release(&permit)
+	verifyOrder(t, queue, a3, b3, c3, a2, b2, c2, b1)
+}
+
+func TestMultiQueueRemove(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	defer log.Scope(t).Close(t)
+
+	queue := NewMultiQueue(1)
+	blocker := queue.Add(0, 0)
+
+	a2 := queue.Add(1, 4.0)
+	b1 := queue.Add(2, 1.1)
+	b2 := queue.Add(2, 2.1)
+	c2 := queue.Add(3, 1.2)
+	c3 := queue.Add(3, 2.2)
+	a3 := queue.Add(1, 5.0)
+	b3 := queue.Add(2, 6.1)
+
+	fmt.Println("Beginning cancel")
+
+	queue.Cancel(b2)
+	queue.Cancel(b1)
+
+	fmt.Println("Finished cancel")
+
+	permit := <-blocker.GetWaitChan()
+	queue.Release(&permit)
+	verifyOrder(t, queue, a3, b3, c3, a2, c2)
+}
+
+// TestMultiQueueStress calls Add from multiple goroutines. It chooses
+// different names and different priorities for the requests. The goal is
+// simply to make sure that all the requests are serviced and nothing hangs or
+// fails.
+func TestMultiQueueStress(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	defer log.Scope(t).Close(t)
+
+	testutils.RunTrueAndFalse(t, "cancel", func(t *testing.T, alsoCancel bool) {
+		testutils.RunTrueAndFalse(t, "sleep", func(t *testing.T, alsoSleep bool) {
+
+			queue := NewMultiQueue(5)
+
+			numThreads := 10
+			numRequests := 500
+			var wg sync.WaitGroup
+			wg.Add(numThreads)
+			var ops int64
+			var timeCancels int64
+
+			for i := 0; i < numThreads; i++ {
+				go func(name int) {
+					for j := 0; j < numRequests; j++ {
+						curTask := queue.Add(name, float64(j))
+						if alsoCancel && j%99 == 0 {
+							queue.Cancel(curTask)
+						} else {
+							select {
+							case <-time.After(400 * time.Microsecond):
+								queue.Cancel(curTask)
+								atomic.AddInt64(&timeCancels, 1)
+							case p := <-curTask.GetWaitChan():
+								if alsoSleep && j%10 == 0 {
+									// Sleep on 10% of requests to simulate doing work.
+									time.Sleep(200 * time.Microsecond)
+								}
+								queue.Release(&p)
+							}
+						}
+						atomic.AddInt64(&ops, 1)
+					}
+					wg.Done()
+				}(i % 4)
+			}
+			wg.Wait()
+			fmt.Printf("Num time cancels %d / %d\n", timeCancels, ops)
+			require.Equal(t, int64(numThreads*numRequests), ops)
+		})
+	})
+}
+
+func TestMultiQueueReleaseTwice(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	defer log.Scope(t).Close(t)
+
+	queue := NewMultiQueue(1)
+
+	task := queue.Add(1, 1)
+	p := <-task.GetWaitChan()
+	queue.Release(&p)
+	require.Panics(t, func() { queue.Release(&p) })
+}
+
+func TestMultiQueueReleaseAfterCancel(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	defer log.Scope(t).Close(t)
+
+	queue := NewMultiQueue(1)
+
+	task := queue.Add(1, 1)
+	p := <-task.GetWaitChan()
+	queue.Cancel(task)
+	queue.Release(&p)
+}
+
+// verifyOrder makes sure that the tasks are signaled in the specified order.
+func verifyOrder(t *testing.T, queue *MultiQueue, tasks ...*Task) {
+	// Each time, verify that the only task with an available permit is the
+	// "next" one in order.
+	for i, task := range tasks {
+		var found Permit
+		for j, t2 := range tasks[i+1:] {
+			select {
+			case <-t2.GetWaitChan():
+				require.Fail(t, "Queue active when should not be ", "iter %d, chan %d, task %v", i, i+1+j, t2)
+			default:
+			}
+		}
+		select {
+		case p := <-task.GetWaitChan():
+			found = p
+		default:
+			require.Fail(t, "Queue not active when should be ", "Queue %d : task %v", i, task)
+		}
+		queue.Release(&found)
+	}
+}
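
As a usage sketch of the API introduced above, the following shows the intended Add / GetWaitChan / Release / Cancel flow from a caller's point of view. It is illustrative only and not part of the change: the wrapper runThrottled, its package name, and the context-based cancellation are assumptions; only the multiqueue calls follow the package as added in this diff.

// Hypothetical caller; not part of this change.
package kvserverexample

import (
	"context"

	"github.com/cockroachdb/cockroach/pkg/kv/kvserver/multiqueue"
)

// runThrottled queues work under the given queue name and priority, waits for
// a Permit, runs the work, and returns the Permit when done. If ctx is
// canceled before the Permit arrives, the Task is abandoned via Cancel.
func runThrottled(
	ctx context.Context,
	q *multiqueue.MultiQueue,
	queueName int,
	priority float64,
	work func() error,
) error {
	task := q.Add(queueName, priority)
	select {
	case <-ctx.Done():
		// The work never started; give up the Task's place in the queue. Cancel
		// also handles the race where a Permit was already sent but not yet
		// received, reclaiming the run slot.
		q.Cancel(task)
		return ctx.Err()
	case permit := <-task.GetWaitChan():
		// A Permit was granted; this Task counts against maxConcurrency until
		// Release is called.
		defer q.Release(&permit)
		return work()
	}
}

func exampleUsage() {
	q := multiqueue.NewMultiQueue(2) // at most two tasks run at once
	_ = runThrottled(context.Background(), q, 1 /* queueName */, 10.0 /* priority */, func() error {
		// Do the actual throttled work here.
		return nil
	})
}

Because exactly one of Cancel or Release accounts for each Task, this pattern keeps the permit bookkeeping correct even under cancellation; the stress test above exercises the same race through its time.After branch.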