diff --git a/pkg/roachprod/install/cluster_synced.go b/pkg/roachprod/install/cluster_synced.go index 08b2840ebb3a..c58252fd99b2 100644 --- a/pkg/roachprod/install/cluster_synced.go +++ b/pkg/roachprod/install/cluster_synced.go @@ -2654,6 +2654,37 @@ func (c *SyncedCluster) pghosts( return m, nil } +// resolveLoadBalancerURL resolves the load balancer postgres URL for the given +// virtual cluster and SQL instance. Returns an empty string if a load balancer +// is not found. +func (c *SyncedCluster) loadBalancerURL( + ctx context.Context, + l *logger.Logger, + virtualClusterName string, + sqlInstance int, + auth PGAuthMode, +) (string, error) { + services, err := c.DiscoverServices(ctx, virtualClusterName, ServiceTypeSQL) + if err != nil { + return "", err + } + var port int + serviceMode := ServiceModeShared + for _, service := range services { + if service.VirtualClusterName == virtualClusterName && service.Instance == sqlInstance { + serviceMode = service.ServiceMode + port = service.Port + break + } + } + address, err := c.FindLoadBalancer(l, port) + if err != nil { + return "", err + } + loadBalancerURL := c.NodeURL(address.IP, address.Port, virtualClusterName, serviceMode, auth) + return loadBalancerURL, nil +} + // SSH creates an interactive shell connecting the caller to the first // node on the cluster (or to all nodes in an iterm2 split screen if // supported). diff --git a/pkg/roachprod/install/expander.go b/pkg/roachprod/install/expander.go index b6e123808daf..b0ea40b3a5a0 100644 --- a/pkg/roachprod/install/expander.go +++ b/pkg/roachprod/install/expander.go @@ -22,8 +22,8 @@ import ( ) var parameterRe = regexp.MustCompile(`{[^{}]*}`) -var pgURLRe = regexp.MustCompile(`{pgurl(:[-,0-9]+)?(:[a-z0-9\-]+)?(:[0-9]+)?}`) -var pgHostRe = regexp.MustCompile(`{pghost(:[-,0-9]+)?}`) +var pgURLRe = regexp.MustCompile(`{pgurl(:[-,0-9]+|:L)?(:[a-z0-9\-]+)?(:[0-9]+)?}`) +var pgHostRe = regexp.MustCompile(`{pghost(:[-,0-9]+|:L)?(:[a-z0-9\-]+)?(:[0-9]+)?}`) var pgPortRe = regexp.MustCompile(`{pgport(:[-,0-9]+)?(:[a-z0-9\-]+)?(:[0-9]+)?}`) var uiPortRe = regexp.MustCompile(`{uiport(:[-,0-9]+)}`) var storeDirRe = regexp.MustCompile(`{store-dir(:[0-9]+)?}`) @@ -157,14 +157,20 @@ func (e *expander) maybeExpandPgURL( if err != nil { return "", false, err } - if e.pgURLs[virtualClusterName] == nil { - e.pgURLs[virtualClusterName], err = c.pgurls(ctx, l, allNodes(len(c.VMs)), virtualClusterName, sqlInstance) - if err != nil { - return "", false, err + switch m[1] { + case ":L": + url, err := c.loadBalancerURL(ctx, l, virtualClusterName, sqlInstance, AuthUserCert) + return url, url != "", err + default: + if e.pgURLs[virtualClusterName] == nil { + e.pgURLs[virtualClusterName], err = c.pgurls(ctx, l, allNodes(len(c.VMs)), virtualClusterName, sqlInstance) + if err != nil { + return "", false, err + } } + s, err = e.maybeExpandMap(c, e.pgURLs[virtualClusterName], m[1]) + return s, err == nil, err } - s, err = e.maybeExpandMap(c, e.pgURLs[virtualClusterName], m[1]) - return s, err == nil, err } // maybeExpandPgHost is an expanderFunc for {pghost:} @@ -175,17 +181,38 @@ func (e *expander) maybeExpandPgHost( if m == nil { return s, false, nil } + virtualClusterName, sqlInstance, err := extractVirtualClusterInfo(m[2:]) + if err != nil { + return "", false, err + } - if e.pgHosts == nil { - var err error - e.pgHosts, err = c.pghosts(ctx, l, allNodes(len(c.VMs))) + switch m[1] { + case ":L": + services, err := c.DiscoverServices(ctx, virtualClusterName, ServiceTypeSQL, ServiceInstancePredicate(sqlInstance)) if err != nil { return "", false, err } + for _, svc := range services { + if svc.VirtualClusterName == virtualClusterName && svc.Instance == sqlInstance { + addr, err := c.FindLoadBalancer(l, svc.Port) + if err != nil { + return "", false, err + } + return addr.IP, true, nil + } + } + return "", false, err + default: + if e.pgHosts == nil { + var err error + e.pgHosts, err = c.pghosts(ctx, l, allNodes(len(c.VMs))) + if err != nil { + return "", false, err + } + } + s, err := e.maybeExpandMap(c, e.pgHosts, m[1]) + return s, err == nil, err } - - s, err := e.maybeExpandMap(c, e.pgHosts, m[1]) - return s, err == nil, err } // maybeExpandPgURL is an expanderFunc for {pgport:} diff --git a/pkg/roachprod/install/services.go b/pkg/roachprod/install/services.go index 014622c67033..daaf56724e60 100644 --- a/pkg/roachprod/install/services.go +++ b/pkg/roachprod/install/services.go @@ -507,3 +507,19 @@ func (c *SyncedCluster) TargetDNSName(node Node) string { // Targets always end with a period as per SRV record convention. return fmt.Sprintf("%s.%s", cVM.PublicDNS, postfix) } + +// FindLoadBalancer returns the first load balancer address that matches the +// given port. If no load balancer is found, an error is returned. +func (c *SyncedCluster) FindLoadBalancer(l *logger.Logger, port int) (*vm.ServiceAddress, error) { + addresses, err := c.ListLoadBalancers(l) + if err != nil { + return nil, err + } + // Find the load balancer with the matching port. + for _, a := range addresses { + if a.Port == port { + return &a, nil + } + } + return nil, errors.Newf("no load balancer found for port %d", port) +}