diff --git a/pkg/sql/BUILD.bazel b/pkg/sql/BUILD.bazel index c39fb9b4c05d..7833f3c90dfd 100644 --- a/pkg/sql/BUILD.bazel +++ b/pkg/sql/BUILD.bazel @@ -795,6 +795,7 @@ go_test( "//pkg/sql/sessiondata", "//pkg/sql/sessiondatapb", "//pkg/sql/sessionphase", + "//pkg/sql/sqlinstance", "//pkg/sql/sqlliveness", "//pkg/sql/sqlstats", "//pkg/sql/sqltestutils", diff --git a/pkg/sql/distsql_physical_planner.go b/pkg/sql/distsql_physical_planner.go index 91ff46439f1c..9c023ad9887a 100644 --- a/pkg/sql/distsql_physical_planner.go +++ b/pkg/sql/distsql_physical_planner.go @@ -1359,9 +1359,6 @@ func (dsp *DistSQLPlanner) makeSQLInstanceIDForKVNodeIDTenantResolver( hasLocalitySet bool, _ error, ) { - if dsp.sqlAddressResolver == nil { - return nil, nil, false, errors.AssertionFailedf("sql instance provider not available in multi-tenant environment") - } // GetAllInstances only returns healthy instances. // TODO(yuzefovich): confirm that all instances are of compatible version. instances, err := dsp.sqlAddressResolver.GetAllInstances(ctx) @@ -1372,75 +1369,70 @@ func (dsp *DistSQLPlanner) makeSQLInstanceIDForKVNodeIDTenantResolver( return nil, nil, false, errors.New("no healthy sql instances available for planning") } - // Populate a map from the region string to all healthy SQL instances in - // that region. - regionToSQLInstanceIDs := make(map[string][]base.SQLInstanceID) - for _, instance := range instances { - region, ok := instance.Locality.Find("region") - if !ok { - // If we can't determine the region of this instance, don't use it - // for planning. - log.Eventf(ctx, "could not find region for SQL instance %s", instance) - continue + rng, _ := randutil.NewPseudoRand() + + for i := range instances { + if instances[i].Locality.NonEmpty() { + hasLocalitySet = true + break } - instancesInRegion := regionToSQLInstanceIDs[region] - instancesInRegion = append(instancesInRegion, instance.InstanceID) - regionToSQLInstanceIDs[region] = instancesInRegion } - rng, _ := randutil.NewPseudoRand() - if len(regionToSQLInstanceIDs) > 0 { - // If we were able to determine the region information at least for some - // instances, use the region-aware resolver. - hasLocalitySet = true + // If we were able to determine the locality information for at least some + // instances, use the region-aware resolver. + if hasLocalitySet { resolver = func(nodeID roachpb.NodeID) base.SQLInstanceID { + // Lookup the node localities to compare to the instance localities. nodeDesc, err := dsp.nodeDescs.GetNodeDescriptor(nodeID) if err != nil { log.Eventf(ctx, "unable to get node descriptor for KV node %s", nodeID) return dsp.gatewaySQLInstanceID } - region, ok := nodeDesc.Locality.Find("region") - if !ok { - log.Eventf(ctx, "could not find region for KV node %s", nodeDesc) - return dsp.gatewaySQLInstanceID + // TODO(dt): Pre-compute / cache this result, e.g. in the instance reader. + if closest := closestInstances(instances, nodeDesc.Locality); len(closest) > 0 { + return closest[rng.Intn(len(closest))] } - instancesInRegion, ok := regionToSQLInstanceIDs[region] - if !ok { - // There are no instances in this region, so just use the - // gateway. - // TODO(yuzefovich): we should instead pick the closest instance - // in a different region. - return dsp.gatewaySQLInstanceID - } - // Pick a random instance in this region in order to spread the - // load. - // TODO(yuzefovich): consider using a different probability - // distribution for the "local" region (i.e. where the gateway is) - // where the gateway instance is favored. Also, if we had the - // information about latencies between different instances, we could - // favor those that are closer to the gateway. However, we need to - // be careful since non-query code paths (like CDC and BulkIO) do - // benefit from the even spread of the spans. - return instancesInRegion[rng.Intn(len(instancesInRegion))] + // No instances had any locality tiers in common with the node locality so + // just return the gateway. + return dsp.gatewaySQLInstanceID } - } else { - // If it just so happens that we couldn't determine the region for all - // SQL instances, we'll use the naive round-robin strategy that is - // completely locality-ignorant. - hasLocalitySet = false - // Randomize the order in which we choose instances so that work is - // allocated fairly across queries. - rng.Shuffle(len(instances), func(i, j int) { - instances[i], instances[j] = instances[j], instances[i] - }) - var i int - resolver = func(roachpb.NodeID) base.SQLInstanceID { - id := instances[i%len(instances)].InstanceID - i++ - return id - } - } - return resolver, instances, hasLocalitySet, nil + return resolver, instances, hasLocalitySet, nil + } + + // If no sql instances have locality information, fallback to a naive + // round-robin strategy that is completely locality-ignorant. Randomize the + // order in which we choose instances so that work is allocated fairly across + // queries. + rng.Shuffle(len(instances), func(i, j int) { + instances[i], instances[j] = instances[j], instances[i] + }) + var i int + resolver = func(roachpb.NodeID) base.SQLInstanceID { + id := instances[i%len(instances)].InstanceID + i++ + return id + } + return resolver, instances, false, nil +} + +// closestInstances returns the subset of instances which are closest to the +// passed locality, i.e. those which jointly have the longest shared prefix of +// at least length 1. Returns nil, rather than the entire input, if no instances +// have *any* shared locality prefix. +func closestInstances( + instances []sqlinstance.InstanceInfo, loc roachpb.Locality, +) []base.SQLInstanceID { + best := 1 + var res []base.SQLInstanceID + for _, i := range instances { + if l := i.Locality.SharedPrefix(loc); l > best { + best = l + res = append(res[:0], i.InstanceID) + } else if l == best { + res = append(res, i.InstanceID) + } + } + return res } // maybeReassignToGatewaySQLInstance checks whether the span partitioning is diff --git a/pkg/sql/distsql_physical_planner_test.go b/pkg/sql/distsql_physical_planner_test.go index 29ae77741be0..6f0c4f22a2e9 100644 --- a/pkg/sql/distsql_physical_planner_test.go +++ b/pkg/sql/distsql_physical_planner_test.go @@ -44,6 +44,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/randgen" "github.com/cockroachdb/cockroach/pkg/sql/sem/eval" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/sql/sqlinstance" "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/testutils/skip" @@ -1390,3 +1391,47 @@ func TestCheckScanParallelizationIfLocal(t *testing.T) { require.Equal(t, tc.hasScanNodeToParallelize, hasScanNodeToParallize) } } + +func TestClosestInstances(t *testing.T) { + defer leaktest.AfterTest(t)() + type instances map[int]string + type picked []int + + for _, tc := range []struct { + instances instances + loc string + expected []int + }{ + {instances{1: "a=x", 2: "a=y", 3: "a=z"}, "z=z", picked{}}, + {instances{1: "a=x", 2: "a=y", 3: "a=z"}, "", picked{}}, + + {instances{1: "a=x", 2: "a=y", 3: "a=z"}, "a=x", picked{1}}, + {instances{1: "a=x", 2: "a=y", 3: "a=z"}, "a=z", picked{3}}, + {instances{1: "a=x", 2: "a=x", 3: "a=z", 4: "a=z"}, "a=x", picked{1, 2}}, + {instances{1: "a=x", 2: "a=x", 3: "a=z", 4: "a=z"}, "a=z", picked{3, 4}}, + + {instances{1: "a=x,b=1", 2: "a=x,b=2", 3: "a=x,b=3", 4: "a=y,b=1", 5: "a=z,b=1"}, "a=x", picked{1, 2, 3}}, + {instances{1: "a=x,b=1", 2: "a=x,b=2", 3: "a=x,b=3", 4: "a=y,b=1", 5: "a=z,b=1"}, "a=x,b=2", picked{2}}, + {instances{1: "a=x,b=1", 2: "a=x,b=2", 3: "a=x,b=3", 4: "a=y,b=1", 5: "a=z,b=1"}, "a=z", picked{5}}, + } { + t.Run("", func(t *testing.T) { + var l roachpb.Locality + if tc.loc != "" { + require.NoError(t, l.Set(tc.loc)) + } + var infos []sqlinstance.InstanceInfo + for id, l := range tc.instances { + info := sqlinstance.InstanceInfo{InstanceID: base.SQLInstanceID(id)} + if l != "" { + require.NoError(t, info.Locality.Set(l)) + } + infos = append(infos, info) + } + var got picked + for _, i := range closestInstances(infos, l) { + got = append(got, int(i)) + } + require.ElementsMatch(t, tc.expected, got) + }) + } +}