diff --git a/pkg/ccl/sqlproxyccl/balancer/balancer.go b/pkg/ccl/sqlproxyccl/balancer/balancer.go
index 6bb62888095b..2eaad4a51696 100644
--- a/pkg/ccl/sqlproxyccl/balancer/balancer.go
+++ b/pkg/ccl/sqlproxyccl/balancer/balancer.go
@@ -37,6 +37,28 @@ const (
 	// DRAINING state before the proxy starts moving connections away from it.
 	minDrainPeriod = 1 * time.Minute
 
+	// rebalancePercentDeviation defines the percentage threshold that the
+	// current number of assignments can deviate away from the mean. Having a
+	// 15% "deadzone" reduces frequent transfers, especially when load is
+	// fluctuating.
+	//
+	// For example, if the percent deviation is 0.15, and the mean is 10, the
+	// number of assignments for every pod has to be between [8, 12] to be
+	// considered balanced.
+	//
+	// NOTE: This must be between 0 and 1 inclusive.
+	rebalancePercentDeviation = 0.15
+
+	// rebalanceRate defines the rate of rebalancing assignments across SQL
+	// pods. This rate applies to both RUNNING and DRAINING pods. For example,
+	// consider the case where the rate is 0.50; if we have decided that we need
+	// to move 15 assignments away from a particular pod, only 8 of those
+	// assignments (the count is rounded up) will be moved at a time.
+	//
+	// NOTE: This must be between 0 and 1 inclusive. 0 means no rebalancing
+	// will occur.
+	rebalanceRate = 0.50
+
 	// defaultMaxConcurrentRebalances represents the maximum number of
 	// concurrent rebalance requests that are being processed. This effectively
 	// limits the number of concurrent transfers per proxy.
@@ -296,9 +318,9 @@ func (b *Balancer) rebalance(ctx context.Context) {
 			continue
 		}
 
-		// Build a podMap so we could easily retrieve the pod by address.
+		// Construct a map so that we can easily retrieve a pod by its address.
 		podMap := make(map[string]*tenant.Pod)
-		hasRunningPod := false
+		var hasRunningPod bool
 		for _, pod := range tenantPods {
 			podMap[pod.Addr] = pod
 
@@ -316,45 +338,185 @@ func (b *Balancer) rebalance(ctx context.Context) {
 			continue
 		}
 
-		connMap := b.connTracker.GetConnsMap(tenantID)
-		for addr, podConns := range connMap {
-			pod, ok := podMap[addr]
-			if !ok {
-				// We have a connection to the pod, but the pod is not in the
-				// directory cache. This race case happens if the connection
-				// was transferred by a different goroutine to this new pod
-				// right after we fetch the list of pods from the directory
-				// cache above. Ignore here, and this connection will be handled
-				// on the next rebalance loop.
-				continue
-			}
+		activeList, idleList := b.connTracker.listAssignments(tenantID)
+		b.rebalancePartition(podMap, activeList)
+		b.rebalancePartition(podMap, idleList)
+	}
+}
 
-			// Transfer all connections in DRAINING pods.
-			//
-			// TODO(jaylim-crl): Consider extracting this logic for the DRAINING
-			// case into a separate function once we add the rebalancing logic.
-			if pod.State != tenant.DRAINING {
-				continue
-			}
+// rebalancePartition rebalances the given assignments partition.
+func (b *Balancer) rebalancePartition(
+	pods map[string]*tenant.Pod, assignments []*ServerAssignment,
+) {
+	// Nothing to do here.
+	if len(pods) == 0 || len(assignments) == 0 {
+		return
+	}
 
-			// Only move connections for pods which have been draining for
-			// at least 1 minute. When load is fluctuating, the pod may
-			// transition back and forth between the DRAINING and RUNNING
-			// states. This check prevents us from moving connections around
-			// when that happens.
-			drainingFor := b.timeSource.Now().Sub(pod.StateTimestamp)
-			if drainingFor < minDrainPeriod {
-				continue
-			}
+	// Transfer assignments away if the partition is in an imbalanced state.
+	toMove := collectRunningPodAssignments(pods, assignments, rebalancePercentDeviation)
+	b.enqueueRebalanceRequests(toMove)
 
-			for _, c := range podConns {
-				b.queue.enqueue(&rebalanceRequest{
-					createdAt: b.timeSource.Now(),
-					conn:      c,
-				})
-			}
+	// Move all assignments away from DRAINING pods if and only if the pods have
+	// been draining for at least minDrainPeriod.
+	toMove = collectDrainingPodAssignments(pods, assignments, b.timeSource)
+	b.enqueueRebalanceRequests(toMove)
+}
+
+// enqueueRebalanceRequests enqueues the first N server assignments for a
+// transfer operation based on the defined rebalance rate. For example, if
+// there are 10 server assignments in the input list, and the rebalance rate is
+// 0.4, only the first four server assignments will be enqueued for a transfer.
+func (b *Balancer) enqueueRebalanceRequests(list []*ServerAssignment) {
+	toMoveCount := int(math.Ceil(float64(len(list)) * float64(rebalanceRate)))
+	for i := 0; i < toMoveCount; i++ {
+		b.queue.enqueue(&rebalanceRequest{
+			createdAt: b.timeSource.Now(),
+			conn:      list[i].Owner(),
+		})
+	}
+}
+
+// collectRunningPodAssignments returns a set of ServerAssignments that have to
+// be moved because the partition is in an imbalanced state. Only assignments
+// to RUNNING pods will be accounted for.
+//
+// NOTE: pods should not be nil, and percentDeviation must be between [0, 1].
+func collectRunningPodAssignments(
+	pods map[string]*tenant.Pod, partition []*ServerAssignment, percentDeviation float64,
+) []*ServerAssignment {
+	// Construct a distribution map of server assignments.
+	numAssignments := 0
+	distribution := make(map[string][]*ServerAssignment)
+	for _, a := range partition {
+		pod, ok := pods[a.Addr()]
+		if !ok || pod.State != tenant.RUNNING {
+			// We have a connection to the pod, but the pod is not in the
+			// directory cache. This race case happens if the connection was
+			// transferred by a different goroutine to this new pod right after
+			// we fetch the list of pods from the directory cache. Ignore here,
+			// and this connection will be handled on the next rebalance loop.
+			//
+			// Pods that are not in the RUNNING state (e.g. DRAINING pods) are
+			// skipped here as well; DRAINING pods are handled separately by
+			// collectDrainingPodAssignments.
+			continue
+		}
+		distribution[a.Addr()] = append(distribution[a.Addr()], a)
+		numAssignments++
+	}
+
+	// Ensure that all RUNNING pods have an entry in distribution. Doing that
+	// allows us to account for new or underutilized pods.
+	for _, pod := range pods {
+		if pod.State != tenant.RUNNING {
+			continue
+		}
+		if _, ok := distribution[pod.Addr]; !ok {
+			distribution[pod.Addr] = []*ServerAssignment{}
+		}
+	}
+
+	// No pods or assignments to work with.
+	if len(distribution) == 0 || numAssignments == 0 {
+		return nil
+	}
+
+	// Calculate the average number of assignments, and the lower/upper bounds
+	// based on the rebalance percent deviation. We want to ensure that the
+	// number of assignments on each pod is within [lowerBound, upperBound]. If
+	// all of the pods are within that interval, the partition is considered to
+	// be balanced.
+	//
+	// Note that lowerBound cannot be 0, or else the addition of a new pod with
+	// no connections may still result in a balanced state.
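+	//
+	// To make the arithmetic concrete (an illustrative example, not taken from
+	// the change itself): with 10 assignments spread across 3 RUNNING pods and
+	// a 15% deviation, the mean is 3.33, so lowerBound = max(1, floor(2.83)) = 2
+	// and upperBound = ceil(3.83) = 4; any pod holding fewer than 2 or more
+	// than 4 assignments marks the partition as imbalanced.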
+	avgAssignments := float64(numAssignments) / float64(len(distribution))
+	lowerBound := int(math.Max(1, math.Floor(avgAssignments*(1-percentDeviation))))
+	upperBound := int(math.Ceil(avgAssignments * (1 + percentDeviation)))
+
+	// Construct the set of assignments that we want to move. The algorithm is
+	// as follows:
+	//   1. Compute the number of assignments that we need to move. This would
+	//      be X = MAX(n, m), where:
+	//        n = total number of assignments that exceed the upper bound
+	//        m = total number of assignments that fall short of the lower bound
+	//
+	//   2. First pass on distribution: collect assignments that exceed the
+	//      upper bound. Update distribution and X to reflect the remaining
+	//      assignments accordingly.
+	//
+	//   3. Second pass on distribution: greedily collect as many assignments
+	//      as possible, up to X, without going below the average. We could
+	//      theoretically minimize the deviation from the mean by collecting
+	//      from pods starting with the ones with the largest number of
+	//      assignments, but this would require a sort.
+	//
+	// The implementation below is an optimization of the algorithm described
+	// above, where steps 1 and 2 are combined. We will also start simple by
+	// omitting the sort in (3).
+
+	// Steps 1 and 2.
+	missingCount := 0
+	var toMove []*ServerAssignment
+	for addr, d := range distribution {
+		missingCount += int(math.Max(float64(lowerBound-len(d)), 0.0))
+
+		// Move everything that exceeds the upper bound.
+		excess := len(d) - upperBound
+		if excess > 0 {
+			toMove = append(toMove, d[:excess]...)
+			distribution[addr] = d[excess:]
+			missingCount -= excess
+		}
+	}
+
+	// Step 3.
+	for addr, d := range distribution {
+		if missingCount <= 0 {
+			break
+		}
+		extra := len(d) - int(avgAssignments)
+		if extra <= 0 || len(d) <= 1 {
+			// Check the length in the second condition to ensure that we don't
+			// remove connections to the point where a pod ends up with 0
+			// assignments.
+			continue
+		}
+		excess := int(math.Min(float64(extra), float64(missingCount)))
+		missingCount -= excess
+		toMove = append(toMove, d[:excess]...)
+		distribution[addr] = d[excess:]
+	}
+
+	return toMove
+}
+
+// collectDrainingPodAssignments returns a set of ServerAssignments that have
+// to be moved because the pods that they are on have been draining for at
+// least minDrainPeriod.
+//
+// NOTE: pods and timeSource should not be nil.
+func collectDrainingPodAssignments(
+	pods map[string]*tenant.Pod, partition []*ServerAssignment, timeSource timeutil.TimeSource,
+) []*ServerAssignment {
+	var collected []*ServerAssignment
+	for _, a := range partition {
+		pod, ok := pods[a.Addr()]
+		if !ok || pod.State != tenant.DRAINING {
+			// We have a connection to the pod, but the pod is not in the
+			// directory cache. This race case happens if the connection was
+			// transferred by a different goroutine to this new pod right after
+			// we fetch the list of pods from the directory cache. Ignore here,
+			// and this connection will be handled on the next rebalance loop.
+			// Pods that are not DRAINING are skipped as well.
+			continue
+		}
+
+		// Only move connections for pods which have been draining for at least
+		// 1 minute. When load is fluctuating, the pod may transition back and
+		// forth between the DRAINING and RUNNING states. This check prevents us
+		// from moving connections around when that happens.
+		drainingFor := timeSource.Now().Sub(pod.StateTimestamp)
+		if drainingFor < minDrainPeriod {
+			continue
+		}
+		collected = append(collected, a)
+	}
+	return collected
+}
 
 // rebalanceRequest corresponds to a rebalance request.
@@ -397,6 +559,11 @@ func (q *rebalancerQueue) enqueue(req *rebalanceRequest) {
 	q.mu.Lock()
 	defer q.mu.Unlock()
 
+	// Test environments may create rebalanceRequests with nil owners.
+	if req.conn == nil {
+		return
+	}
+
 	e, ok := q.elements[req.conn]
 	if ok {
 		// Use the newer request of the two.
diff --git a/pkg/ccl/sqlproxyccl/balancer/balancer_test.go b/pkg/ccl/sqlproxyccl/balancer/balancer_test.go
index d6ab2eb8114f..165e70570350 100644
--- a/pkg/ccl/sqlproxyccl/balancer/balancer_test.go
+++ b/pkg/ccl/sqlproxyccl/balancer/balancer_test.go
@@ -10,6 +10,7 @@ package balancer
 
 import (
 	"context"
+	"fmt"
 	"reflect"
 	"sync"
 	"sync/atomic"
@@ -20,6 +21,7 @@ import (
 	"github.com/cockroachdb/cockroach/pkg/roachpb"
 	"github.com/cockroachdb/cockroach/pkg/testutils"
 	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
+	"github.com/cockroachdb/cockroach/pkg/util/log"
 	"github.com/cockroachdb/cockroach/pkg/util/stop"
 	"github.com/cockroachdb/cockroach/pkg/util/syncutil"
 	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
@@ -262,42 +264,87 @@ func TestRebalancer_rebalanceLoop(t *testing.T) {
 	)
 	require.NoError(t, err)
 
-	pods := []*tenant.Pod{
-		{TenantID: 30, Addr: "127.0.0.30:80", State: tenant.DRAINING},
-		{TenantID: 30, Addr: "127.0.0.30:81", State: tenant.RUNNING},
+	tenantID := roachpb.MakeTenantID(30)
+	drainingPod := &tenant.Pod{TenantID: tenantID.ToUint64(), Addr: "127.0.0.30:80", State: tenant.DRAINING}
+	require.True(t, directoryCache.upsertPod(drainingPod))
+	runningPods := []*tenant.Pod{
+		{TenantID: tenantID.ToUint64(), Addr: "127.0.0.30:81", State: tenant.RUNNING},
+		{TenantID: tenantID.ToUint64(), Addr: "127.0.0.30:82", State: tenant.RUNNING},
 	}
-	for _, pod := range pods {
+	for _, pod := range runningPods {
 		require.True(t, directoryCache.upsertPod(pod))
 	}
 
-	// Manually assign a pod to the tracker in the balancer.
-	h := &testConnHandle{
-		onTransferConnection: func() error {
-			return nil
-		},
+	// Create new server assignments:
+	// - 1 to drainingPod
+	// - 9 to runningPods[1]
+	//
+	// We expect the connection to drainingPod to be moved away because the pod
+	// is draining. At the same time, roughly 4 connections should be moved
+	// away from runningPods[1] so that each running pod ends up with 4 to 6
+	// assignments (a mean of 5 with the 15% deviation applied).
+	var mu syncutil.Mutex
+	var assignments []*ServerAssignment
+	var makeTestConnHandle func(idx int) *testConnHandle
+	makeTestConnHandle = func(idx int) *testConnHandle {
+		var handle *testConnHandle
+		handle = &testConnHandle{
+			onTransferConnection: func() error {
+				mu.Lock()
+				defer mu.Unlock()
+
+				// Already moved earlier.
+				if assignments[idx].Owner() != handle {
+					return nil
+				}
+				assignments[idx].Close()
+
+				pod := selectTenantPod(runningPods, b.connTracker.getEntry(tenantID, false))
+				require.NotNil(t, pod)
+				assignments[idx] = NewServerAssignment(
+					tenantID, b.connTracker, makeTestConnHandle(idx), pod.Addr,
+				)
+				return nil
+			},
+		}
+		return handle
+	}
+	assignments = append(
+		assignments,
+		NewServerAssignment(tenantID, b.connTracker, makeTestConnHandle(0), drainingPod.Addr),
+	)
+	for i := 1; i < 10; i++ {
+		assignments = append(
+			assignments,
+			NewServerAssignment(tenantID, b.connTracker, makeTestConnHandle(i), runningPods[1].Addr),
+		)
 	}
-	b.connTracker.registerAssignment(roachpb.MakeTenantID(30), &ServerAssignment{
-		addr:  pods[0].Addr,
-		owner: h,
-	})
 
-	// Wait until rebalance queue gets processed.
-	runs := 0
+	// Wait until the rebalance queue gets processed.
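+	// Each retry of the check below advances the manual time source by one
+	// rebalanceInterval so that the balancer's rebalance loop gets a chance to
+	// run between attempts.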
 	testutils.SucceedsSoon(t, func() error {
-		runs++
 		timeSource.Advance(rebalanceInterval)
-		count := h.transferConnectionCount()
-		if count >= 3 && runs >= count {
-			return nil
+		activeList, _ := b.connTracker.listAssignments(tenantID)
+		distribution := make(map[string]int)
+		total := 0
+		for _, sa := range activeList {
+			distribution[sa.Addr()]++
+			total++
+		}
+		for _, val := range distribution {
+			if val > 6 || val < 4 {
+				return errors.Newf("expected count to be between [4, 6], but got %d", val)
+			}
+		}
+		if total != 10 {
+			return errors.Newf("should have 10 assignments, but got %d", total)
 		}
-		return errors.Newf("insufficient runs, expected >= 3, but got %d", count)
+		return nil
 	})
 }
 
 func TestRebalancer_rebalance(t *testing.T) {
 	defer leaktest.AfterTest(t)()
+	defer log.Scope(t).Close(t)
 
 	ctx := context.Background()
 	stopper := stop.NewStopper()
@@ -326,6 +373,7 @@ func TestRebalancer_rebalance(t *testing.T) {
 	// - tenant-30: two draining pods (one with < 1m), one running pod
 	// - tenant-40: one draining pod, one running pod
 	// - tenant-50: one running pod
+	// - tenant-60: three running pods
 	recentlyDrainedPod := &tenant.Pod{
 		TenantID: 30,
 		Addr:     "127.0.0.30:81",
@@ -339,6 +387,9 @@ func TestRebalancer_rebalance(t *testing.T) {
 		{TenantID: 40, Addr: "127.0.0.40:80", State: tenant.DRAINING},
 		{TenantID: 40, Addr: "127.0.0.40:81", State: tenant.RUNNING},
 		{TenantID: 50, Addr: "127.0.0.50:80", State: tenant.RUNNING},
+		{TenantID: 60, Addr: "127.0.0.60:80", State: tenant.RUNNING},
+		{TenantID: 60, Addr: "127.0.0.60:81", State: tenant.RUNNING},
+		{TenantID: 60, Addr: "127.0.0.60:82", State: tenant.RUNNING},
 	}
 
 	// reset recreates the directory cache.
@@ -356,20 +407,12 @@ func TestRebalancer_rebalance(t *testing.T) {
 		}
 	}
 
-	// makeHandle returns a handle that doesn't panic when TransferConnection
-	// is called.
-	makeHandle := func() *testConnHandle {
-		return &testConnHandle{
-			onTransferConnection: func() error {
-				return nil
-			},
-		}
-	}
-
 	for _, tc := range []struct {
-		name           string
-		handlesFn      func(t *testing.T) []ConnectionHandle
-		expectedCounts []int
+		name                    string
+		handlesFn               func(t *testing.T) []ConnectionHandle
+		preRebalanceFn          func(t *testing.T)
+		expectedCounts          []int
+		expectedCountsMatcherFn func(handles []ConnectionHandle) error
 	}{
 		{
 			// This case should not occur unless there's a bug in the directory
@@ -381,7 +424,7 @@ func TestRebalancer_rebalance(t *testing.T) {
 				// Use a random IP since tenant-10 doesn't have a pod, and it
 				// does not matter.
-				handle := makeHandle()
+				handle := makeTestHandle()
 				sa := NewServerAssignment(tenant10, b.connTracker, handle, "foobarip")
 				handle.onClose = sa.Close
 				return []ConnectionHandle{handle}
@@ -395,68 +438,70 @@ func TestRebalancer_rebalance(t *testing.T) {
 			handlesFn: func(t *testing.T) []ConnectionHandle {
 				tenant20 := roachpb.MakeTenantID(20)
 
-				handle := makeHandle()
+				handle := makeTestHandle()
 				sa := NewServerAssignment(tenant20, b.connTracker, handle, pods[0].Addr)
 				handle.onClose = sa.Close
 				return []ConnectionHandle{handle}
 			},
 			expectedCounts: []int{0},
 		},
-		{
-			// If the connection has been closed, we shouldn't bother initiating
-			// a transfer. Use tenant-30's DRAINING pod here.
- name: "connection closed", - handlesFn: func(t *testing.T) []ConnectionHandle { - tenant30 := roachpb.MakeTenantID(30) - cancelledCtx, cancel := context.WithCancel(context.Background()) - cancel() - - handle := makeHandle() - handle.ctx = cancelledCtx - sa := NewServerAssignment(tenant30, b.connTracker, handle, pods[1].Addr) - handle.onClose = sa.Close - return []ConnectionHandle{handle} - }, - expectedCounts: []int{0}, - }, { // Use tenant-30's recently drained pod. We shouldn't transfer // because minDrainPeriod hasn't elapsed. - name: "recently drained pod", + name: "draining/recently drained pod", handlesFn: func(t *testing.T) []ConnectionHandle { tenant30 := roachpb.MakeTenantID(30) - handle := makeHandle() - sa := NewServerAssignment(tenant30, b.connTracker, handle, recentlyDrainedPod.Addr) - handle.onClose = sa.Close - return []ConnectionHandle{handle} + activeHandle := makeTestHandle() + sa := NewServerAssignment(tenant30, b.connTracker, activeHandle, recentlyDrainedPod.Addr) + activeHandle.onClose = sa.Close + + idleHandle := makeTestHandle() + sa = NewServerAssignment(tenant30, b.connTracker, idleHandle, recentlyDrainedPod.Addr) + idleHandle.onClose = sa.Close + idleHandle.setIdle(true) + + // Refresh partitions, and validate idle connection. + e30 := b.connTracker.getEntry(tenant30, false) + e30.refreshPartitions() + _, idleList30 := e30.listAssignments() + require.Len(t, idleList30, 1) + + return []ConnectionHandle{activeHandle, idleHandle} }, - expectedCounts: []int{0}, + expectedCounts: []int{0, 0}, }, { - name: "multiple connections", + name: "draining/multiple connections", handlesFn: func(t *testing.T) []ConnectionHandle { conns := []*tenant.Pod{ + // Active connections + // ------------------ // Connection on tenant with single draining pod. Should // not transfer because nothing to transfer to. pods[0], - // Connections to draining pod (>= 1m). - pods[1], + // Connection to draining pod (>= 1m). pods[1], // Connections to recently drained pod. recentlyDrainedPod, recentlyDrainedPod, // Connection to running pod. Nothing happens. pods[3], - // Connections to draining pod (>= 1m). + // Connection to draining pod (>= 1m). pods[4], // Connections to running pods. Nothing happens. pods[5], pods[6], + // Idle connections + // ---------------- + // Connection to draining pod (>= 1m). + pods[1], + // Connection to draining pod (>= 1m). + pods[4], } var handles []ConnectionHandle for _, c := range conns { - handle := makeHandle() + handle := makeTestHandle() sa := NewServerAssignment( roachpb.MakeTenantID(c.TenantID), b.connTracker, @@ -466,9 +511,174 @@ func TestRebalancer_rebalance(t *testing.T) { handle.onClose = sa.Close handles = append(handles, handle) } + // The last two are idle connections. + handles[len(handles)-1].(*testConnHandle).setIdle(true) + handles[len(handles)-2].(*testConnHandle).setIdle(true) + + // Refresh partitions, and validate idle connections. 
+				e30 := b.connTracker.getEntry(roachpb.MakeTenantID(30), false)
+				e40 := b.connTracker.getEntry(roachpb.MakeTenantID(40), false)
+				e30.refreshPartitions()
+				e40.refreshPartitions()
+				_, idleList30 := e30.listAssignments()
+				require.Len(t, idleList30, 1)
+				_, idleList40 := e40.listAssignments()
+				require.Len(t, idleList40, 1)
+
 				return handles
 			},
-			expectedCounts: []int{0, 1, 1, 0, 0, 0, 1, 0, 0},
+			expectedCounts: []int{0, 1, 0, 0, 0, 1, 0, 0, 1, 1},
+		},
+		{
+			name: "running/multiple connections",
+			handlesFn: func(t *testing.T) []ConnectionHandle {
+				// Create 100 active connections: 53->pod[7], 0->pod[8], 47->pod[9].
+				var conns []*tenant.Pod
+				for i := 0; i < 53; i++ {
+					conns = append(conns, pods[7])
+				}
+				for i := 0; i < 47; i++ {
+					conns = append(conns, pods[9])
+				}
+				// Add another 30 idle connections: 20->pod[7], 1->pod[8], 9->pod[9].
+				for i := 0; i < 20; i++ {
+					conns = append(conns, pods[7])
+				}
+				conns = append(conns, pods[8])
+				for i := 0; i < 9; i++ {
+					conns = append(conns, pods[9])
+				}
+				var handles []ConnectionHandle
+				for _, c := range conns {
+					handle := makeTestHandle()
+					sa := NewServerAssignment(
+						roachpb.MakeTenantID(c.TenantID),
+						b.connTracker,
+						handle,
+						c.Addr,
+					)
+					handle.onClose = sa.Close
+					handles = append(handles, handle)
+				}
+				for i := 0; i < 30; i++ {
+					handles[len(handles)-i-1].(*testConnHandle).setIdle(true)
+				}
+
+				// Refresh partitions, and validate idle connections.
+				e60 := b.connTracker.getEntry(roachpb.MakeTenantID(60), false)
+				e60.refreshPartitions()
+				_, idleList60 := e60.listAssignments()
+				require.Len(t, idleList60, 30)
+
+				return handles
+			},
+			expectedCountsMatcherFn: func(handles []ConnectionHandle) error {
+				// Active connections
+				// ------------------
+				// Average = 33.33, and the allowed interval with a 15% deadzone
+				// is [28, 39]. Expect 53-39=14 from pod[7] and 47-39=8 from
+				// pod[9]. Since missingCount = 28 > (14+8), 28 connections
+				// should be moved. Taking the 50% rebalance rate into account,
+				// we have 14 in total. Since we cannot guarantee ordering, we
+				// will just count here. The actual logic is already unit tested
+				// in TestCollectRunningPodAssignments.
+				count := 0
+				for i := 0; i < 100; i++ {
+					count += handles[i].(*testConnHandle).transferConnectionCount()
+				}
+				if count != 14 {
+					return errors.Newf("require 14, but got %v", count)
+				}
+				// Idle connections
+				// ----------------
+				// Average = 10, and the allowed interval with a 15% deadzone is
+				// [8, 12]. pod[7] exceeds by 8 and pod[8] is short by 7. Taking
+				// the greater of the two, we will transfer 8 connections in
+				// total. Half that to account for the rebalance rate.
+				count = 0
+				for i := 0; i < 30; i++ {
+					count += handles[i+100].(*testConnHandle).transferConnectionCount()
+				}
+				if count != 4 {
+					return errors.Newf("require 4, but got %v", count)
+				}
+				return nil
+			},
+		},
+		{
+			name: "both active and idle connections",
+			handlesFn: func(t *testing.T) []ConnectionHandle {
+				conns := []*tenant.Pod{
+					// Active connections
+					// ------------------
+					// Connection to draining pod (>= 1m).
+					pods[1],
+					// Connections to running pods. Move 2 away. With a rebalance
+					// rate of 50%, move 1.
+					pods[7],
+					pods[7],
+					pods[7],
+					// Idle connections
+					// ----------------
+					// Connection to draining pod (>= 1m).
+					pods[4],
+					// Connections to running pods. Move 1 away; the 50%
+					// rebalance rate still rounds up to 1.
+					pods[8],
+					pods[8],
+				}
+				var handles []ConnectionHandle
+				for _, c := range conns {
+					handle := makeTestHandle()
+					sa := NewServerAssignment(
+						roachpb.MakeTenantID(c.TenantID),
+						b.connTracker,
+						handle,
+						c.Addr,
+					)
+					handle.onClose = sa.Close
+					handles = append(handles, handle)
+				}
+				// The last three are idle connections.
+				handles[len(handles)-1].(*testConnHandle).setIdle(true)
+				handles[len(handles)-2].(*testConnHandle).setIdle(true)
+				handles[len(handles)-3].(*testConnHandle).setIdle(true)
+
+				// Refresh partitions, and validate idle connections.
+				e40 := b.connTracker.getEntry(roachpb.MakeTenantID(40), false)
+				e60 := b.connTracker.getEntry(roachpb.MakeTenantID(60), false)
+				e40.refreshPartitions()
+				e60.refreshPartitions()
+				_, idleList40 := e40.listAssignments()
+				require.Len(t, idleList40, 1)
+				_, idleList60 := e60.listAssignments()
+				require.Len(t, idleList60, 2)
+
+				return handles
+			},
+			expectedCountsMatcherFn: func(handles []ConnectionHandle) error {
+				// Active connections
+				// ------------------
+				count := 0
+				for i := 0; i < 4; i++ {
+					count += handles[i].(*testConnHandle).transferConnectionCount()
+				}
+				// 1 from draining, 1 from running.
+				if count != 2 {
+					return errors.Newf("require 2, but got %v", count)
+				}
+				// Idle connections
+				// ----------------
+				count = 0
+				for i := 0; i < 3; i++ {
+					count += handles[i+4].(*testConnHandle).transferConnectionCount()
+				}
+				// 1 from draining, 1 from running.
+				if count != 2 {
+					return errors.Newf("require 2, but got %v", count)
+				}
+				return nil
+			},
+		},
 	} {
 		t.Run(tc.name, func(t *testing.T) {
@@ -480,12 +690,18 @@ func TestRebalancer_rebalance(t *testing.T) {
 
 			// Wait until rebalance queue gets processed.
 			testutils.SucceedsSoon(t, func() error {
-				var counts []int
-				for _, h := range handles {
-					counts = append(counts, h.(*testConnHandle).transferConnectionCount())
-				}
-				if !reflect.DeepEqual(tc.expectedCounts, counts) {
-					return errors.Newf("require %v, but got %v", tc.expectedCounts, counts)
+				if tc.expectedCountsMatcherFn != nil {
+					if err := tc.expectedCountsMatcherFn(handles); err != nil {
+						return err
+					}
+				} else {
+					var counts []int
+					for _, h := range handles {
+						counts = append(counts, h.(*testConnHandle).transferConnectionCount())
+					}
+					if !reflect.DeepEqual(tc.expectedCounts, counts) {
+						return errors.Newf("require %v, but got %v", tc.expectedCounts, counts)
+					}
 				}
 				return nil
 			})
@@ -498,6 +714,248 @@ func TestRebalancer_rebalance(t *testing.T) {
 	}
 }
 
+func TestEnqueueRebalanceRequests(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+
+	baseCtx := context.Background()
+	stopper := stop.NewStopper()
+	defer stopper.Stop(baseCtx)
+
+	// Use a custom time source for testing.
+	t0 := time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC)
+	timeSource := timeutil.NewManualTime(t0)
+
+	b, err := NewBalancer(
+		baseCtx,
+		stopper,
+		NewMetrics(),
+		nil, /* directoryCache */
+		NoRebalanceLoop(),
+		TimeSource(timeSource),
+	)
+	require.NoError(t, err)
+
+	var list []*ServerAssignment
+	for i := 0; i < 15; i++ {
+		list = append(list, &ServerAssignment{owner: makeTestHandle()})
+	}
+	b.enqueueRebalanceRequests(list)
+
+	// Since rebalanceRate is 0.5, only the first ceil(7.5) = 8 will be
+	// transferred.
+	testutils.SucceedsSoon(t, func() error {
+		for i := 0; i < 8; i++ {
+			count := list[i].Owner().(*testConnHandle).transferConnectionCount()
+			if count != 1 {
+				return errors.Newf("expected count 1, but got %d", count)
+			}
+		}
+		for i := 0; i < 7; i++ {
+			count := list[i+8].Owner().(*testConnHandle).transferConnectionCount()
+			if count != 0 {
+				return errors.Newf("expected count 0, but got %d", count)
+			}
+		}
+		return nil
+	})
+}
+
+func TestCollectRunningPodAssignments(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+
+	t.Run("no pods", func(t *testing.T) {
+		require.Nil(t, collectRunningPodAssignments(
+			map[string]*tenant.Pod{},
+			[]*ServerAssignment{{addr: "1"}, {addr: "2"}},
+			0,
+		))
+	})
+
+	t.Run("no assignments", func(t *testing.T) {
+		require.Nil(t, collectRunningPodAssignments(
+			map[string]*tenant.Pod{"1": {State: tenant.RUNNING}},
+			nil,
+			0,
+		))
+	})
+
+	for _, tc := range []struct {
+		name                       string
+		percentDeviation           float64
+		pods                       map[string]*tenant.Pod
+		partitionDistribution      []int
+		expectedSetDistributionAny []map[string]int
+	}{
+		{
+			name: "balanced partition",
+			pods: map[string]*tenant.Pod{
+				"1": {State: tenant.RUNNING},
+				"2": {State: tenant.DRAINING},
+				"3": {State: tenant.RUNNING},
+			},
+			// [1, 2] are bounds. Draining pod isn't included, even if it has
+			// many assignments. Partition is already balanced.
+			partitionDistribution:      []int{2, 4, 1},
+			expectedSetDistributionAny: []map[string]int{{}},
+		},
+		{
+			name: "multiple new pods",
+			pods: map[string]*tenant.Pod{
+				"1": {State: tenant.RUNNING},
+				"2": {State: tenant.RUNNING},
+				"3": {State: tenant.RUNNING},
+			},
+			// [1, 2] are bounds. New pods have no assignments (underutilized).
+			partitionDistribution:      []int{3, 0, 0},
+			expectedSetDistributionAny: []map[string]int{{"1": 2}},
+		},
+		{
+			name: "single new pod",
+			pods: map[string]*tenant.Pod{
+				"1": {State: tenant.RUNNING},
+				"2": {State: tenant.RUNNING},
+				"3": {State: tenant.RUNNING},
+				"4": {State: tenant.RUNNING},
+				"5": {State: tenant.RUNNING},
+				"6": {State: tenant.RUNNING},
+			},
+			// [1, 2] are bounds. New pod has no assignments (underutilized).
+			partitionDistribution:      []int{1, 1, 2, 2, 2, 0},
+			expectedSetDistributionAny: []map[string]int{{"3": 1}, {"4": 1}, {"5": 1}},
+		},
+		{
+			name: "more overloaded pods", // Compared to underloaded ones.
+			pods: map[string]*tenant.Pod{
+				"1": {State: tenant.RUNNING},
+				"2": {State: tenant.RUNNING},
+				"3": {State: tenant.RUNNING},
+				"4": {State: tenant.RUNNING},
+				"5": {State: tenant.RUNNING},
+				"6": {State: tenant.RUNNING},
+			},
+			// [1, 3] are bounds. Two overloaded pods.
+			partitionDistribution:      []int{1, 4, 1, 3, 1, 4},
+			expectedSetDistributionAny: []map[string]int{{"2": 1, "6": 1}},
+		},
+		{
+			name: "more underloaded pods", // Compared to overloaded ones.
+			pods: map[string]*tenant.Pod{
+				"1": {State: tenant.RUNNING},
+				"2": {State: tenant.RUNNING},
+				"3": {State: tenant.RUNNING},
+				"4": {State: tenant.RUNNING},
+				"5": {State: tenant.RUNNING},
+			},
+			percentDeviation: 0.8,
+			// [2, 27] are bounds. Exceed by 0, but short by 2+2=4. Greedily
+			// pick the remaining 4 from whichever loaded pod is visited first.
+			partitionDistribution:      []int{0, 25, 25, 25, 0},
+			expectedSetDistributionAny: []map[string]int{{"2": 4}, {"3": 4}, {"4": 4}},
+		},
+		{
+			name: "equally imbalanced",
+			pods: map[string]*tenant.Pod{
+				"1": {State: tenant.RUNNING},
+				"2": {State: tenant.RUNNING},
+				"3": {State: tenant.RUNNING},
+				"4": {State: tenant.RUNNING},
+				"5": {State: tenant.RUNNING},
+			},
+			// [6, 9] are bounds. Both exceed=short=18.
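+			// (36 assignments over 5 pods gives a mean of 7.2, so the bounds
+			// are [floor(6.12), ceil(8.28)] = [6, 9]; pods 2 and 4 exceed the
+			// upper bound by 6 and 12 respectively, which matches the expected
+			// set below.)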
+			partitionDistribution:      []int{0, 15, 0, 21, 0},
+			expectedSetDistributionAny: []map[string]int{{"2": 6, "4": 12}},
+		},
+	} {
+		t.Run(tc.name, func(t *testing.T) {
+			if tc.percentDeviation == 0 {
+				tc.percentDeviation = 0.15
+			}
+
+			// Ensure that every pod has an address.
+			for addr, pod := range tc.pods {
+				pod.Addr = addr
+			}
+
+			// Construct partition based on partition distribution.
+			var partition []*ServerAssignment
+			for i, count := range tc.partitionDistribution {
+				for j := 0; j < count; j++ {
+					partition = append(
+						partition,
+						&ServerAssignment{addr: fmt.Sprintf("%d", i+1)},
+					)
+				}
+			}
+
+			set := collectRunningPodAssignments(tc.pods, partition, tc.percentDeviation)
+			setDistribution := make(map[string]int)
+			for _, a := range set {
+				setDistribution[a.Addr()]++
+			}
+
+			// Match one of the set distributions. There are multiple choices
+			// here because map iteration is non-deterministic.
+			matched := false
+			for _, expected := range tc.expectedSetDistributionAny {
+				if reflect.DeepEqual(expected, setDistribution) {
+					matched = true
+					break
+				}
+			}
+			require.True(t, matched, "could not match expected set distribution")
+		})
+	}
+}
+
+func TestCollectDrainingPodAssignments(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+
+	t.Run("no pods", func(t *testing.T) {
+		set := collectDrainingPodAssignments(
+			map[string]*tenant.Pod{},
+			[]*ServerAssignment{{addr: "1"}, {addr: "2"}},
+			nil,
+		)
+		require.Nil(t, set)
+	})
+
+	t.Run("with pods", func(t *testing.T) {
+		// Use a custom time source for testing.
+		t0 := time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC)
+		timeSource := timeutil.NewManualTime(t0)
+
+		// Pod 3 just transitioned into the DRAINING state.
+		pods := map[string]*tenant.Pod{
+			"1": {State: tenant.RUNNING},
+			"2": {State: tenant.DRAINING, StateTimestamp: timeSource.Now().Add(-minDrainPeriod)}, // 1m
+			"3": {State: tenant.DRAINING, StateTimestamp: timeSource.Now()},                      // 0s
+			"4": {State: tenant.RUNNING},
+			"5": {State: tenant.DRAINING, StateTimestamp: timeSource.Now().Add(-minDrainPeriod).Add(-1 * time.Second)}, // 1m1s
+			"6": {State: tenant.DRAINING, StateTimestamp: timeSource.Now().Add(-minDrainPeriod).Add(1 * time.Second)},  // 59s
+		}
+
+		// Create 3 assignments per pod, in addition to a non-existent pod 7.
+		var partition []*ServerAssignment
+		for i := 1; i <= 7; i++ {
+			for count := 0; count < 3; count++ {
+				partition = append(partition, &ServerAssignment{addr: fmt.Sprintf("%d", i)})
+			}
+		}
+
+		// Empty partition.
+		set := collectDrainingPodAssignments(pods, nil, timeSource)
+		require.Nil(t, set)
+
+		// Actual partition.
+		set = collectDrainingPodAssignments(pods, partition, timeSource)
+		distribution := make(map[string]int)
+		for _, a := range set {
+			distribution[a.Addr()]++
+		}
+		require.Equal(t, map[string]int{"2": 3, "5": 3}, distribution)
+	})
+}
+
 func TestRebalancerQueue(t *testing.T) {
 	defer leaktest.AfterTest(t)()
 
@@ -666,3 +1124,11 @@ func (r *testDirectoryCache) upsertPod(pod *tenant.Pod) bool {
 	r.mu.pods[tenantID] = append(r.mu.pods[tenantID], pod)
 	return true
 }
+
+// makeTestHandle returns a test handle that doesn't panic when
+// TransferConnection is called.
+func makeTestHandle() *testConnHandle {
+	return &testConnHandle{
+		onTransferConnection: func() error { return nil },
+	}
+}
diff --git a/pkg/ccl/sqlproxyccl/balancer/conn_tracker.go b/pkg/ccl/sqlproxyccl/balancer/conn_tracker.go
index 4092a6af8136..31d7a855522c 100644
--- a/pkg/ccl/sqlproxyccl/balancer/conn_tracker.go
+++ b/pkg/ccl/sqlproxyccl/balancer/conn_tracker.go
@@ -110,6 +110,18 @@ func (t *ConnTracker) getTenantIDs() []roachpb.TenantID {
 	return tenants
 }
 
+// listAssignments returns a snapshot of both the active and idle partitions
+// that contain ServerAssignment instances for the given tenant.
+func (t *ConnTracker) listAssignments(
+	tenantID roachpb.TenantID,
+) (activeList, idleList []*ServerAssignment) {
+	e := t.getEntry(tenantID, false /* allowCreate */)
+	if e == nil {
+		return nil, nil
+	}
+	return e.listAssignments()
+}
+
 // getEntry retrieves the tenantEntry instance for the given tenant. If
 // allowCreate is set to false, getEntry returns nil if the entry does not
 // exist for the given tenant. On the other hand, if allowCreate is set to
diff --git a/pkg/ccl/sqlproxyccl/balancer/conn_tracker_test.go b/pkg/ccl/sqlproxyccl/balancer/conn_tracker_test.go
index db175e7e1fe6..310d0af078b1 100644
--- a/pkg/ccl/sqlproxyccl/balancer/conn_tracker_test.go
+++ b/pkg/ccl/sqlproxyccl/balancer/conn_tracker_test.go
@@ -37,14 +37,17 @@ func TestConnTracker(t *testing.T) {
 	tracker, err := NewConnTracker(ctx, stopper, nil /* timeSource */)
 	require.NoError(t, err)
 
-	tenantID := roachpb.MakeTenantID(20)
+	tenant20 := roachpb.MakeTenantID(20)
 	sa := &ServerAssignment{addr: "127.0.0.10:8090", owner: &testConnHandle{}}
 
 	// Run twice for idempotency.
-	tracker.registerAssignment(tenantID, sa)
-	tracker.registerAssignment(tenantID, sa)
+	tracker.registerAssignment(tenant20, sa)
+	tracker.registerAssignment(tenant20, sa)
+	activeList, idleList := tracker.listAssignments(tenant20)
+	require.Equal(t, []*ServerAssignment{sa}, activeList)
+	require.Empty(t, idleList)
 
-	connsMap := tracker.GetConnsMap(tenantID)
+	connsMap := tracker.GetConnsMap(tenant20)
 	require.Len(t, connsMap, 1)
 	h, ok := connsMap[sa.Addr()]
 	require.True(t, ok)
@@ -52,15 +55,21 @@ func TestConnTracker(t *testing.T) {
 
 	tenantIDs := tracker.getTenantIDs()
 	require.Len(t, tenantIDs, 1)
-	require.Equal(t, tenantID, tenantIDs[0])
+	require.Equal(t, tenant20, tenantIDs[0])
 
 	// Non-existent.
 	connsMap = tracker.GetConnsMap(roachpb.MakeTenantID(42))
 	require.Empty(t, connsMap)
+	activeList, idleList = tracker.listAssignments(roachpb.MakeTenantID(42))
+	require.Empty(t, activeList)
+	require.Empty(t, idleList)
 
 	// Run twice for idempotency.
-	tracker.unregisterAssignment(tenantID, sa)
-	tracker.unregisterAssignment(tenantID, sa)
+	tracker.unregisterAssignment(tenant20, sa)
+	tracker.unregisterAssignment(tenant20, sa)
+	activeList, idleList = tracker.listAssignments(tenant20)
+	require.Empty(t, activeList)
+	require.Empty(t, idleList)
 
 	// Once the assignment gets unregistered, we shouldn't return that tenant
 	// since there are no active connections.
@@ -94,6 +103,9 @@ func TestConnTracker(t *testing.T) {
 		require.Empty(t, entry.assignments.active)
 		require.Empty(t, entry.assignments.idle)
 	}
+	activeList, idleList = tracker.listAssignments(tenant20)
+	require.Empty(t, activeList)
+	require.Empty(t, idleList)
 }
 
 func TestConnTracker_GetConnsMap(t *testing.T) {
diff --git a/pkg/ccl/sqlproxyccl/proxy_handler.go b/pkg/ccl/sqlproxyccl/proxy_handler.go
index 63d928a04cd9..76d8c3682a22 100644
--- a/pkg/ccl/sqlproxyccl/proxy_handler.go
+++ b/pkg/ccl/sqlproxyccl/proxy_handler.go
@@ -441,6 +441,8 @@ func (handler *proxyHandler) startPodWatcher(ctx context.Context, podWatcher cha
 		case <-ctx.Done():
 			return
 		case pod := <-podWatcher:
+			// TODO(jaylim-crl): Invoke rebalance logic here whenever we see
+			// a new SQL pod.
 			if pod.State == tenant.DRAINING {
 				handler.idleMonitor.SetIdleChecks(pod.Addr)
 			} else {