diff --git a/go/vt/vtgate/vtgateconn/vtgateconn.go b/go/vt/vtgate/vtgateconn/vtgateconn.go index e1c49877d1..3057acc3fc 100644 --- a/go/vt/vtgate/vtgateconn/vtgateconn.go +++ b/go/vt/vtgate/vtgateconn/vtgateconn.go @@ -1,5 +1,5 @@ /* -Copyright 2019 The Vitess Authors. +Copyright 2024 The Vitess Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/go/vt/vtgateproxy/discovery.go b/go/vt/vtgateproxy/discovery.go index b282abc9ed..ec708e3a46 100644 --- a/go/vt/vtgateproxy/discovery.go +++ b/go/vt/vtgateproxy/discovery.go @@ -1,3 +1,18 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ package vtgateproxy import ( @@ -9,15 +24,18 @@ import ( "io" "math/rand" "os" - "strconv" "time" - "google.golang.org/grpc/attributes" "google.golang.org/grpc/resolver" + + "vitess.io/vitess/go/vt/log" ) var ( jsonDiscoveryConfig = flag.String("json_config", "", "json file describing the host list to use fot vitess://vtgate resolution") + addressField = flag.String("address_field", "address", "field name in the json file containing the address") + portField = flag.String("port_field", "port", "field name in the json file containing the port") + numConnections = flag.Int("num_connections", 4, "number of outbound GPRC connections to maintain") ) // File based discovery for vtgate grpc endpoints @@ -55,32 +73,32 @@ type JSONGateConfigDiscovery struct { } func (b *JSONGateConfigDiscovery) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) { - fmt.Printf("Start registration for target: %v\n", target.URL.String()) - queryOpts := target.URL.Query() - queryParamCount := queryOpts.Get("num_connections") - queryAZID := queryOpts.Get("az_id") - num_connections := 0 - - gateType := target.URL.Host - - if queryParamCount != "" { - num_connections, _ = strconv.Atoi(queryParamCount) + attrs := target.URL.Query() + + // If the config specifies a pool type attribute, then the caller must supply it in the connection + // attributes, otherwise reject the request. + poolType := "" + if *poolTypeAttr != "" { + poolType = attrs.Get(*poolTypeAttr) + if poolType == "" { + return nil, fmt.Errorf("pool type attribute %s not in target", *poolTypeAttr) + } } - filters := resolveFilters{ - gate_type: gateType, + // Affinity on the other hand is just an optimization + affinity := "" + if *affinityAttr != "" { + affinity = attrs.Get(*affinityAttr) } - if queryAZID != "" { - filters.az_id = queryAZID - } + log.V(100).Infof("Start discovery for target %v poolType %s affinity %s\n", target.URL.String(), poolType, affinity) - r := &resolveJSONGateConfig{ - target: target, - cc: cc, - jsonPath: b.JsonPath, - num_connections: num_connections, - filters: filters, + r := &JSONGateConfigResolver{ + target: target, + cc: cc, + jsonPath: b.JsonPath, + poolType: poolType, + affinity: affinity, } r.start() return r, nil @@ -88,89 +106,111 @@ func (b *JSONGateConfigDiscovery) Build(target resolver.Target, cc resolver.Clie func (*JSONGateConfigDiscovery) Scheme() string { return "vtgate" } func RegisterJsonDiscovery() { - fmt.Printf("Registering: %v\n", *jsonDiscoveryConfig) jsonDiscovery := &JSONGateConfigDiscovery{ JsonPath: *jsonDiscoveryConfig, } resolver.Register(jsonDiscovery) - fmt.Printf("Registered %v scheme\n", jsonDiscovery.Scheme()) + log.Infof("Registered JSON discovery scheme %v to watch: %v\n", jsonDiscovery.Scheme(), *jsonDiscoveryConfig) } -type resolveFilters struct { - gate_type string - az_id string +// Resolver(https://godoc.org/google.golang.org/grpc/resolver#Resolver). +type JSONGateConfigResolver struct { + target resolver.Target + cc resolver.ClientConn + jsonPath string + poolType string + affinity string + + ticker *time.Ticker + rand *rand.Rand // safe for concurrent use. } -// exampleResolver is a -// Resolver(https://godoc.org/google.golang.org/grpc/resolver#Resolver). -type resolveJSONGateConfig struct { - target resolver.Target - cc resolver.ClientConn - jsonPath string - ticker *time.Ticker - rand *rand.Rand // safe for concurrent use. - num_connections int - filters resolveFilters +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func jsonDump(data interface{}) string { + json, _ := json.Marshal(data) + return string(json) } -type discoverySlackAZ struct{} -type discoverySlackType struct{} +func (r *JSONGateConfigResolver) resolve() (*[]resolver.Address, []byte, error) { -func (r *resolveJSONGateConfig) loadConfig() (*[]resolver.Address, []byte, error) { - config := []DiscoveryHost{} - fmt.Printf("Loading config %v\n", r.jsonPath) + log.V(100).Infof("resolving target %s to %d connections\n", r.target.URL.String(), *numConnections) data, err := os.ReadFile(r.jsonPath) if err != nil { return nil, nil, err } - err = json.Unmarshal(data, &config) + hosts := []map[string]interface{}{} + err = json.Unmarshal(data, &hosts) if err != nil { - fmt.Printf("parse err: %v\n", err) + log.Errorf("error parsing JSON discovery file %s: %v\n", r.jsonPath, err) return nil, nil, err } - addrs := []resolver.Address{} - for _, s := range config { - az := attributes.New(discoverySlackAZ{}, s.AZId).WithValue(discoverySlackType{}, s.Type) + // optionally filter to only hosts that match the pool type + if r.poolType != "" { + candidates := []map[string]interface{}{} + for _, host := range hosts { + hostType, ok := host[*poolTypeAttr] + if ok && hostType == r.poolType { + candidates = append(candidates, host) + log.V(1000).Infof("matched host %s with type %s", jsonDump(host), hostType) + } else { + log.V(1000).Infof("skipping host %s with type %s", jsonDump(host), hostType) + } + } + hosts = candidates + } - // Filter hosts to this gate type - if r.filters.gate_type != "" { - if r.filters.gate_type != s.Type { - continue + // Shuffle to ensure every host has a different order to iterate through + r.rand.Shuffle(len(hosts), func(i, j int) { + hosts[i], hosts[j] = hosts[j], hosts[i] + }) + + // If affinity is specified, then shuffle those hosts to the front + if r.affinity != "" { + i := 0 + for j := 0; j < len(hosts); j++ { + hostAffinity, ok := hosts[j][*affinityAttr] + if ok && hostAffinity == r.affinity { + hosts[i], hosts[j] = hosts[j], hosts[i] + i++ } } + } - // Add matching hosts to registration list + // Grab the first N addresses, and voila! + var addrs []resolver.Address + hosts = hosts[:min(*numConnections, len(hosts))] + for _, host := range hosts { addrs = append(addrs, resolver.Address{ - Addr: fmt.Sprintf("%s:%s", s.NebulaAddress, s.Grpc), - BalancerAttributes: az, + Addr: fmt.Sprintf("%s:%s", host[*addressField], host[*portField]), }) } - fmt.Printf("Addrs: %v\n", addrs) - - // Shuffle to ensure every host has a different order to iterate through - r.rand.Shuffle(len(addrs), func(i, j int) { - addrs[i], addrs[j] = addrs[j], addrs[i] - }) - h := sha256.New() if _, err := io.Copy(h, bytes.NewReader(data)); err != nil { return nil, nil, err } + sum := h.Sum(nil) + + log.V(100).Infof("resolved %s to hosts %s addrs: 0x%x, %v\n", r.target.URL.String(), jsonDump(hosts), sum, addrs) - fmt.Printf("Returning discovery: %d hosts checksum %x\n", len(addrs), h.Sum(nil)) - return &addrs, h.Sum(nil), nil + return &addrs, sum, nil } -func (r *resolveJSONGateConfig) start() { - fmt.Print("Starting discovery checker\n") +func (r *JSONGateConfigResolver) start() { + log.V(100).Infof("Starting discovery checker\n") r.rand = rand.New(rand.NewSource(time.Now().UnixNano())) // Immediately load the initial config - addrs, hash, err := r.loadConfig() + addrs, hash, err := r.resolve() if err == nil { // if we parse ok, populate the local address store r.cc.UpdateState(resolver.State{Addresses: *addrs}) @@ -186,7 +226,7 @@ func (r *resolveJSONGateConfig) start() { for range r.ticker.C { checkFileStat, err := os.Stat(r.jsonPath) if err != nil { - fmt.Printf("Error stat'ing config %v\n", err) + log.Errorf("Error stat'ing config %v\n", err) continue } isUnchanged := checkFileStat.Size() == fileStat.Size() || checkFileStat.ModTime() == fileStat.ModTime() @@ -196,35 +236,33 @@ func (r *resolveJSONGateConfig) start() { } fileStat = checkFileStat - fmt.Printf("Detected config change\n") + log.V(100).Infof("Detected config change\n") - addrs, newHash, err := r.loadConfig() + addrs, newHash, err := r.resolve() if err != nil { // better luck next loop // TODO: log this - fmt.Print("Can't load config: %v\n", err) + log.Errorf("Error resolving config: %v\n", err) continue } // Make sure this wasn't a spurious change by checking the hash - if bytes.Compare(hash, newHash) == 0 && newHash != nil { - fmt.Printf("No content changed in discovery file... ignoring\n") + if bytes.Equal(hash, newHash) && newHash != nil { + log.V(100).Infof("No content changed in discovery file... ignoring\n") continue } hash = newHash - fmt.Printf("Loaded %d hosts\n", len(*addrs)) - fmt.Printf("Loaded %v", addrs) r.cc.UpdateState(resolver.State{Addresses: *addrs}) } }() - fmt.Printf("Loaded hosts, starting ticker\n") + log.V(100).Infof("Loaded hosts, starting ticker\n") } -func (r *resolveJSONGateConfig) ResolveNow(o resolver.ResolveNowOptions) {} -func (r *resolveJSONGateConfig) Close() { +func (r *JSONGateConfigResolver) ResolveNow(o resolver.ResolveNowOptions) {} +func (r *JSONGateConfigResolver) Close() { r.ticker.Stop() } diff --git a/go/vt/vtgateproxy/gate_balancer.go b/go/vt/vtgateproxy/gate_balancer.go deleted file mode 100644 index 77f8de98c1..0000000000 --- a/go/vt/vtgateproxy/gate_balancer.go +++ /dev/null @@ -1,119 +0,0 @@ -package vtgateproxy - -import ( - "context" - "errors" - "fmt" - "strconv" - "sync" - "sync/atomic" - - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/balancer/base" - "google.golang.org/grpc/grpclog" - "google.golang.org/grpc/metadata" -) - -// Name is the name of az affinity balancer. -const Name = "slack_affinity_balancer" -const MetadataAZKey = "grpc-slack-az-metadata" -const MetadataHostAffinityCount = "grpc-slack-num-connections-metadata" - -var logger = grpclog.Component("slack_affinity_balancer") - -func WithSlackAZAffinityContext(ctx context.Context, azID string, numConnections string) context.Context { - ctx = metadata.AppendToOutgoingContext(ctx, MetadataAZKey, azID, MetadataHostAffinityCount, numConnections) - return ctx -} - -func newBuilder() balancer.Builder { - return base.NewBalancerBuilder(Name, &slackAZAffinityBalancer{}, base.Config{HealthCheck: true}) -} - -func init() { - balancer.Register(newBuilder()) -} - -type slackAZAffinityBalancer struct{} - -func (*slackAZAffinityBalancer) Build(info base.PickerBuildInfo) balancer.Picker { - logger.Infof("slackAZAffinityBalancer: Build called with info: %v", info) - fmt.Printf("Rebuilding picker\n") - - if len(info.ReadySCs) == 0 { - return base.NewErrPicker(balancer.ErrNoSubConnAvailable) - } - allSubConns := []balancer.SubConn{} - subConnsByAZ := map[string][]balancer.SubConn{} - - for sc := range info.ReadySCs { - subConnInfo, _ := info.ReadySCs[sc] - az := subConnInfo.Address.BalancerAttributes.Value(discoverySlackAZ{}).(string) - - allSubConns = append(allSubConns, sc) - subConnsByAZ[az] = append(subConnsByAZ[az], sc) - } - return &slackAZAffinityPicker{ - allSubConns: allSubConns, - subConnsByAZ: subConnsByAZ, - } -} - -type slackAZAffinityPicker struct { - // allSubConns is all subconns that were in the ready state when the picker was created - allSubConns []balancer.SubConn - subConnsByAZ map[string][]balancer.SubConn - nextByAZ sync.Map - next uint32 -} - -// Pick the next in the list from the list of subconns (RR) -func (p *slackAZAffinityPicker) pickFromSubconns(scList []balancer.SubConn, nextIndex uint32) (balancer.PickResult, error) { - subConnsLen := uint32(len(scList)) - - if subConnsLen == 0 { - return balancer.PickResult{}, errors.New("No hosts in list") - } - - fmt.Printf("Select offset: %v %v %v\n", nextIndex, nextIndex%subConnsLen, len(scList)) - - sc := scList[nextIndex%subConnsLen] - return balancer.PickResult{SubConn: sc}, nil -} - -func (p *slackAZAffinityPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { - hdrs, _ := metadata.FromOutgoingContext(info.Ctx) - numConnections := 0 - keys := hdrs.Get(MetadataAZKey) - if len(keys) < 1 { - return p.pickFromSubconns(p.allSubConns, atomic.AddUint32(&p.next, 1)) - } - az := keys[0] - - if az == "" { - return p.pickFromSubconns(p.allSubConns, atomic.AddUint32(&p.next, 1)) - } - - keys = hdrs.Get(MetadataHostAffinityCount) - if len(keys) > 0 { - if i, err := strconv.Atoi(keys[0]); err != nil { - numConnections = i - } - } - - subConns := p.subConnsByAZ[az] - if len(subConns) == 0 { - fmt.Printf("No subconns in az and gate type, pick from anywhere\n") - return p.pickFromSubconns(p.allSubConns, atomic.AddUint32(&p.next, 1)) - } - val, _ := p.nextByAZ.LoadOrStore(az, new(uint32)) - ptr := val.(*uint32) - atomic.AddUint32(ptr, 1) - - if len(subConns) >= numConnections && numConnections > 0 { - fmt.Printf("Limiting to first %v\n", numConnections) - return p.pickFromSubconns(subConns[0:numConnections], *ptr) - } else { - return p.pickFromSubconns(subConns, *ptr) - } -} diff --git a/go/vt/vtgateproxy/mysql_server.go b/go/vt/vtgateproxy/mysql_server.go index 27ad82d187..e07b4ff109 100644 --- a/go/vt/vtgateproxy/mysql_server.go +++ b/go/vt/vtgateproxy/mysql_server.go @@ -1,5 +1,5 @@ /* -Copyright 2023 The Vitess Authors. +Copyright 2024 The Vitess Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/go/vt/vtgateproxy/vtgateproxy.go b/go/vt/vtgateproxy/vtgateproxy.go index 68869dbd32..5f4a50c8ea 100644 --- a/go/vt/vtgateproxy/vtgateproxy.go +++ b/go/vt/vtgateproxy/vtgateproxy.go @@ -1,5 +1,5 @@ /* -Copyright 2023 The Vitess Authors. +Copyright 2024 The Vitess Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,81 +21,68 @@ package vtgateproxy import ( "context" "flag" - "fmt" "io" "net/url" "strings" "sync" - "time" "google.golang.org/grpc" "vitess.io/vitess/go/sqlescape" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/grpcclient" + "vitess.io/vitess/go/vt/log" querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" - "vitess.io/vitess/go/vt/schema" "vitess.io/vitess/go/vt/vterrors" _ "vitess.io/vitess/go/vt/vtgate/grpcvtgateconn" "vitess.io/vitess/go/vt/vtgate/vtgateconn" ) var ( - dialTimeout = flag.Duration("dial_timeout", 5*time.Second, "dialer timeout for the GRPC connection") - - defaultDDLStrategy = flag.String("ddl_strategy", string(schema.DDLStrategyDirect), "Set default strategy for DDL statements. Override with @@ddl_strategy session variable") - sysVarSetEnabled = flag.Bool("enable_system_settings", true, "This will enable the system settings to be changed per session at the database connection level") + poolTypeAttr = flag.String("pool_type_attr", "", "Attribute (both mysql connection and JSON file) used to specify the target vtgate type and filter the hosts, e.g. 'type'") + affinityAttr = flag.String("affinity_attr", "", "Attribute (both mysql protocol connection and JSON file) used to specify the routing affinity , e.g. 'az_id'") vtGateProxy *VTGateProxy = &VTGateProxy{ targetConns: map[string]*vtgateconn.VTGateConn{}, - mu: sync.Mutex{}, + mu: sync.RWMutex{}, } ) type VTGateProxy struct { - targetConns map[string]*vtgateconn.VTGateConn - mu sync.Mutex - azID string - gateType string - numConnections string + targetConns map[string]*vtgateconn.VTGateConn + mu sync.RWMutex } func (proxy *VTGateProxy) getConnection(ctx context.Context, target string) (*vtgateconn.VTGateConn, error) { - targetURL, err := url.Parse(target) - if err != nil { - return nil, err - } - - proxy.azID = targetURL.Query().Get("az_id") - proxy.numConnections = targetURL.Query().Get("num_connections") - proxy.gateType = targetURL.Host - - fmt.Printf("Getting connection for %v in %v with %v connections\n", target, proxy.azID, proxy.numConnections) + log.V(100).Infof("Getting connection for %v\n", target) // If the connection exists, return it - proxy.mu.Lock() - existingConn, _ := proxy.targetConns[target] + proxy.mu.RLock() + existingConn := proxy.targetConns[target] if existingConn != nil { - proxy.mu.Unlock() + proxy.mu.RUnlock() + log.V(100).Infof("Reused connection for %v\n", target) return existingConn, nil } - proxy.mu.Unlock() + + // No luck, need to create a new one. Serialize new additions so we don't create multiple + // for a given target. + log.V(100).Infof("Need to create connection for %v\n", target) + proxy.mu.RUnlock() + proxy.mu.Lock() // Otherwise create a new connection after dropping the lock, allowing multiple requests to // race to create the conn for now. - // grpcclient.RegisterGRPCDialOptions(func(opts []grpc.DialOption) ([]grpc.DialOption, error) { - // return append(opts, grpc.WithBlock()), nil - // }) grpcclient.RegisterGRPCDialOptions(func(opts []grpc.DialOption) ([]grpc.DialOption, error) { - return append(opts, grpc.WithDefaultServiceConfig(`{"loadBalancingConfig": [{"slack_affinity_balancer":{}}]}`)), nil + return append(opts, grpc.WithDefaultServiceConfig(`{"loadBalancingConfig": [{"round_robin":{}}]}`)), nil }) - conn, err := vtgateconn.DialProtocol(WithSlackAZAffinityContext(ctx, proxy.azID, proxy.numConnections), "grpc", target) + conn, err := vtgateconn.DialProtocol(ctx, "grpc", target) if err != nil { return nil, err } - proxy.mu.Lock() + log.V(100).Infof("Created new connection for %v\n", target) proxy.targetConns[target] = conn proxy.mu.Unlock() @@ -103,12 +90,33 @@ func (proxy *VTGateProxy) getConnection(ctx context.Context, target string) (*vt } func (proxy *VTGateProxy) NewSession(ctx context.Context, options *querypb.ExecuteOptions, connectionAttributes map[string]string) (*vtgateconn.VTGateSession, error) { - target, ok := connectionAttributes["target"] - if !ok { - return nil, vterrors.Errorf(vtrpcpb.Code_UNAVAILABLE, "no target string supplied by client") + + targetUrl := url.URL{ + Scheme: "vtgate", + Host: "pool", + } + + values := url.Values{} + + if *poolTypeAttr != "" { + poolType, ok := connectionAttributes[*poolTypeAttr] + if ok { + values.Set(*poolTypeAttr, poolType) + } else { + return nil, vterrors.Errorf(vtrpcpb.Code_UNAVAILABLE, "pool type attribute %s not supplied by client", *poolTypeAttr) + } } - conn, err := proxy.getConnection(ctx, target) + if *affinityAttr != "" { + affinity, ok := connectionAttributes[*affinityAttr] + if ok { + values.Set(*affinityAttr, affinity) + } + } + + targetUrl.RawQuery = values.Encode() + + conn, err := proxy.getConnection(ctx, targetUrl.String()) if err != nil { return nil, err } @@ -120,7 +128,7 @@ func (proxy *VTGateProxy) NewSession(ctx context.Context, options *querypb.Execu // same effect as if a "rollback" statement was executed, but does not affect the query // statistics. func (proxy *VTGateProxy) CloseSession(ctx context.Context, session *vtgateconn.VTGateSession) error { - return session.CloseSession(WithSlackAZAffinityContext(ctx, proxy.azID, proxy.gateType)) + return session.CloseSession(ctx) } // ResolveTransaction resolves the specified 2PC transaction. @@ -142,11 +150,11 @@ func (proxy *VTGateProxy) Execute(ctx context.Context, session *vtgateconn.VTGat return &sqltypes.Result{}, nil } - return session.Execute(WithSlackAZAffinityContext(ctx, proxy.azID, proxy.gateType), sql, bindVariables) + return session.Execute(ctx, sql, bindVariables) } func (proxy *VTGateProxy) StreamExecute(ctx context.Context, session *vtgateconn.VTGateSession, sql string, bindVariables map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) error { - stream, err := session.StreamExecute(WithSlackAZAffinityContext(ctx, proxy.azID, proxy.gateType), sql, bindVariables) + stream, err := session.StreamExecute(ctx, sql, bindVariables) if err != nil { return err }