diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel index abd59dbef414..f8c03c44e94d 100644 --- a/pkg/BUILD.bazel +++ b/pkg/BUILD.bazel @@ -275,6 +275,7 @@ ALL_TESTS = [ "//pkg/roachprod/prometheus:prometheus_test", "//pkg/roachprod/ssh:ssh_test", "//pkg/roachprod/vm/gce:gce_test", + "//pkg/roachprod/vm/local:local_test", "//pkg/roachprod/vm:vm_test", "//pkg/rpc/nodedialer:nodedialer_test", "//pkg/rpc:rpc_test", @@ -1475,6 +1476,7 @@ GO_TARGETS = [ "//pkg/roachprod/errors:errors", "//pkg/roachprod/install:install", "//pkg/roachprod/install:install_test", + "//pkg/roachprod/lock:lock", "//pkg/roachprod/logger:logger", "//pkg/roachprod/prometheus:prometheus", "//pkg/roachprod/prometheus:prometheus_test", @@ -1489,6 +1491,7 @@ GO_TARGETS = [ "//pkg/roachprod/vm/gce:gce", "//pkg/roachprod/vm/gce:gce_test", "//pkg/roachprod/vm/local:local", + "//pkg/roachprod/vm/local:local_test", "//pkg/roachprod/vm:vm", "//pkg/roachprod/vm:vm_test", "//pkg/roachprod:roachprod", diff --git a/pkg/cmd/roachprod/flags.go b/pkg/cmd/roachprod/flags.go index d3536cae5907..98b587c33004 100644 --- a/pkg/cmd/roachprod/flags.go +++ b/pkg/cmd/roachprod/flags.go @@ -354,7 +354,7 @@ Default is "RECURRING '*/15 * * * *' FULL BACKUP '@hourly' WITH SCHEDULE OPTIONS cmd.Flags().BoolVar(&secure, "secure", false, "use a secure cluster") } - for _, cmd := range []*cobra.Command{pgurlCmd, sqlCmd} { + for _, cmd := range []*cobra.Command{pgurlCmd, sqlCmd, adminurlCmd} { cmd.Flags().StringVar(&tenantName, "tenant-name", "", "specific tenant to connect to") } diff --git a/pkg/cmd/roachprod/main.go b/pkg/cmd/roachprod/main.go index 23b83af174d9..beabc539fa76 100644 --- a/pkg/cmd/roachprod/main.go +++ b/pkg/cmd/roachprod/main.go @@ -945,7 +945,7 @@ var adminurlCmd = &cobra.Command{ `, Args: cobra.ExactArgs(1), Run: wrap(func(cmd *cobra.Command, args []string) error { - urls, err := roachprod.AdminURL(config.Logger, args[0], adminurlPath, adminurlIPs, adminurlOpen, secure) + urls, err := 
roachprod.AdminURL(config.Logger, args[0], tenantName, adminurlPath, adminurlIPs, adminurlOpen, secure) if err != nil { return err } diff --git a/pkg/cmd/roachtest/cluster.go b/pkg/cmd/roachtest/cluster.go index a791d5818683..82fb36aa147f 100644 --- a/pkg/cmd/roachtest/cluster.go +++ b/pkg/cmd/roachtest/cluster.go @@ -2341,8 +2341,7 @@ func addrToAdminUIAddr(addr string) (string, error) { if err != nil { return "", err } - // Roachprod makes Admin UI's port to be node's port + 1. - return fmt.Sprintf("%s:%d", host, webPort+1), nil + return fmt.Sprintf("%s:%d", host, webPort), nil } func urlToAddr(pgURL string) (string, error) { @@ -2379,12 +2378,17 @@ func (c *clusterImpl) InternalAdminUIAddr( ctx context.Context, l *logger.Logger, node option.NodeListOption, ) ([]string, error) { var addrs []string - urls, err := c.InternalAddr(ctx, l, node) + internalAddrs, err := roachprod.AdminURL(l, c.MakeNodes(node), "", "", + false, false, false) if err != nil { return nil, err } - for _, u := range urls { - adminUIAddr, err := addrToAdminUIAddr(u) + for _, u := range internalAddrs { + addr, err := urlToAddr(u) + if err != nil { + return nil, err + } + adminUIAddr, err := addrToAdminUIAddr(addr) if err != nil { return nil, err } @@ -2396,15 +2400,20 @@ func (c *clusterImpl) InternalAdminUIAddr( // ExternalAdminUIAddr returns the external Admin UI address in the form host:port // for the specified node. 
func (c *clusterImpl) ExternalAdminUIAddr( - ctx context.Context, l *logger.Logger, node option.NodeListOption, + _ context.Context, l *logger.Logger, node option.NodeListOption, ) ([]string, error) { var addrs []string - externalAddrs, err := c.ExternalAddr(ctx, l, node) + externalAddrs, err := roachprod.AdminURL(l, c.MakeNodes(node), "", "", + true, false, false) if err != nil { return nil, err } for _, u := range externalAddrs { - adminUIAddr, err := addrToAdminUIAddr(u) + addr, err := urlToAddr(u) + if err != nil { + return nil, err + } + adminUIAddr, err := addrToAdminUIAddr(addr) if err != nil { return nil, err } diff --git a/pkg/cmd/roachtest/tests/cluster_init.go b/pkg/cmd/roachtest/tests/cluster_init.go index cf1576003e86..f97bdfe00580 100644 --- a/pkg/cmd/roachtest/tests/cluster_init.go +++ b/pkg/cmd/roachtest/tests/cluster_init.go @@ -33,37 +33,26 @@ import ( func runClusterInit(ctx context.Context, t test.Test, c cluster.Cluster) { c.Put(ctx, t.Cockroach(), "./cockroach") - t.L().Printf("retrieving VM addresses") - addrs, err := c.InternalAddr(ctx, t.L(), c.All()) - if err != nil { - t.Fatal(err) - } - - // TODO(tbg): this should never happen, but I saw it locally. The result - // is the test hanging forever, because all nodes will create their own - // single node cluster and waitForFullReplication never returns. - if addrs[0] == "" { - t.Fatal("no address for first node") - } - // We start all nodes with the same join flags and then issue an "init" // command to one of the nodes. We do this twice, since roachtest has some // special casing for the first node in a cluster (the join flags of all nodes // default to just the first node) and we want to make sure that we're not // relying on it. + startOpts := option.DefaultStartOpts() + + // We don't want roachprod to auto-init this cluster. + startOpts.RoachprodOpts.SkipInit = true + + // We need to point all nodes at all other nodes. 
By default, + // roachprod will point all nodes at the first node, but this + // won't allow init'ing any but the first node - we require + // that all nodes can discover the init'ed node (transitively) + // via the join targets. + startOpts.RoachprodOpts.JoinTargets = c.All() + for _, initNode := range []int{2, 1} { c.Wipe(ctx, false /* preserveCerts */) t.L().Printf("starting test with init node %d", initNode) - startOpts := option.DefaultStartOpts() - - // We don't want roachprod to auto-init this cluster. - startOpts.RoachprodOpts.SkipInit = true - // We need to point all nodes at all other nodes. By default - // roachprod will point all nodes at the first node, but this - // won't allow init'ing any but the first node - we require - // that all nodes can discover the init'ed node (transitively) - // via their join flags. - startOpts.RoachprodOpts.ExtraArgs = append(startOpts.RoachprodOpts.ExtraArgs, "--join="+strings.Join(addrs, ",")) c.Start(ctx, t.L(), startOpts, install.MakeClusterSettings()) urlMap := make(map[int]string) diff --git a/pkg/cmd/roachtest/tests/decommission.go b/pkg/cmd/roachtest/tests/decommission.go index 42de73b56aa0..9efa43434363 100644 --- a/pkg/cmd/roachtest/tests/decommission.go +++ b/pkg/cmd/roachtest/tests/decommission.go @@ -349,13 +349,9 @@ func runDecommission( db := c.Conn(ctx, t.L(), pinnedNode) defer db.Close() - internalAddrs, err := c.InternalAddr(ctx, t.L(), c.Node(pinnedNode)) - if err != nil { - return err - } startOpts := option.DefaultStartSingleNodeOpts() + startOpts.RoachprodOpts.JoinTargets = []int{pinnedNode} extraArgs := []string{ - "--join", internalAddrs[0], fmt.Sprintf("--attrs=node%d", node), } startOpts.RoachprodOpts.ExtraArgs = append(startOpts.RoachprodOpts.ExtraArgs, extraArgs...) 
diff --git a/pkg/roachprod/BUILD.bazel b/pkg/roachprod/BUILD.bazel index 131ccbb4f06e..f55435506d0b 100644 --- a/pkg/roachprod/BUILD.bazel +++ b/pkg/roachprod/BUILD.bazel @@ -15,6 +15,7 @@ go_library( "//pkg/roachprod/cloud", "//pkg/roachprod/config", "//pkg/roachprod/install", + "//pkg/roachprod/lock", "//pkg/roachprod/logger", "//pkg/roachprod/prometheus", "//pkg/roachprod/vm", @@ -30,6 +31,5 @@ go_library( "//pkg/util/timeutil", "@com_github_cockroachdb_errors//:errors", "@com_github_cockroachdb_errors//oserror", - "@org_golang_x_sys//unix", ], ) diff --git a/pkg/roachprod/cloud/cluster_cloud.go b/pkg/roachprod/cloud/cluster_cloud.go index 5b8e46699c59..34f2ee60e68e 100644 --- a/pkg/roachprod/cloud/cluster_cloud.go +++ b/pkg/roachprod/cloud/cluster_cloud.go @@ -267,13 +267,19 @@ func CreateCluster( // DestroyCluster TODO(peter): document func DestroyCluster(l *logger.Logger, c *Cluster) error { - return vm.FanOut(c.VMs, func(p vm.Provider, vms vm.List) error { + err := vm.FanOut(c.VMs, func(p vm.Provider, vms vm.List) error { // Enable a fast-path for providers that can destroy a cluster in one shot. if x, ok := p.(vm.DeleteCluster); ok { return x.DeleteCluster(l, c.Name) } return p.Delete(l, vms) }) + if err != nil { + return err + } + return vm.FanOutDNS(c.VMs, func(p vm.DNSProvider, vms vm.List) error { + return p.DeleteRecordsBySubdomain(c.Name) + }) } // ExtendCluster TODO(peter): document diff --git a/pkg/roachprod/clusters_cache.go b/pkg/roachprod/clusters_cache.go index 99aa3e3d6dad..a2b3f91c67b4 100644 --- a/pkg/roachprod/clusters_cache.go +++ b/pkg/roachprod/clusters_cache.go @@ -57,11 +57,13 @@ func readSyncedClusters(key string) (*cloud.Cluster, bool) { // InitDirs initializes the directories for storing cluster metadata and debug // logs. 
func InitDirs() error { - cd := os.ExpandEnv(config.ClustersDir) - if err := os.MkdirAll(cd, 0755); err != nil { - return err + dirs := []string{config.ClustersDir, config.DefaultDebugDir, config.DNSDir} + for _, dir := range dirs { + if err := os.MkdirAll(os.ExpandEnv(dir), 0755); err != nil { + return err + } } - return os.MkdirAll(os.ExpandEnv(config.DefaultDebugDir), 0755) + return nil } // saveCluster creates (or overwrites) the file in config.ClusterDir storing the diff --git a/pkg/roachprod/config/config.go b/pkg/roachprod/config/config.go index 888cb4529b84..960852189855 100644 --- a/pkg/roachprod/config/config.go +++ b/pkg/roachprod/config/config.go @@ -71,6 +71,13 @@ const ( // ClustersDir is the directory where we cache information about clusters. ClustersDir = "${HOME}/.roachprod/clusters" + // DefaultLockPath is the path to the lock file used to synchronize access to + // shared roachprod resources. + DefaultLockPath = "$HOME/.roachprod/LOCK" + + // DNSDir is the directory where we cache local cluster DNS information. + DNSDir = "${HOME}/.roachprod/dns" + // SharedUser is the linux username for shared use on all vms. SharedUser = "ubuntu" @@ -86,6 +93,9 @@ const ( // listening for HTTP connections for the Admin UI. DefaultAdminUIPort = 26258 + // DefaultOpenPortStart is the default starting range used to find open ports. + DefaultOpenPortStart = 29000 + // DefaultNumFilesLimit is the default limit on the number of files that can // be opened by the process. 
DefaultNumFilesLimit = 65 << 13 diff --git a/pkg/roachprod/install/BUILD.bazel b/pkg/roachprod/install/BUILD.bazel index 93cbe6b86e1f..55d7891c036a 100644 --- a/pkg/roachprod/install/BUILD.bazel +++ b/pkg/roachprod/install/BUILD.bazel @@ -11,12 +11,14 @@ go_library( "install.go", "iterm2.go", "nodes.go", + "services.go", "session.go", "staging.go", ], embedsrcs = [ "scripts/download.sh", "scripts/start.sh", + "scripts/open_ports.sh", ], importpath = "github.com/cockroachdb/cockroach/pkg/roachprod/install", visibility = ["//visibility:public"], @@ -49,6 +51,7 @@ go_test( name = "install_test", srcs = [ "cluster_synced_test.go", + "services_test.go", "staging_test.go", "start_template_test.go", ], @@ -56,8 +59,10 @@ go_test( data = glob(["testdata/**"]), embed = [":install"], deps = [ + "//pkg/roachprod/cloud", "//pkg/roachprod/logger", "//pkg/roachprod/vm", + "//pkg/roachprod/vm/local", "//pkg/testutils/datapathutils", "//pkg/util/retry", "@com_github_cockroachdb_datadriven//:datadriven", diff --git a/pkg/roachprod/install/cluster_synced.go b/pkg/roachprod/install/cluster_synced.go index 94690cae67a9..db66b2851f2c 100644 --- a/pkg/roachprod/install/cluster_synced.go +++ b/pkg/roachprod/install/cluster_synced.go @@ -662,7 +662,12 @@ func (c *SyncedCluster) Monitor( defer wg.Done() node := nodes[i] - + port, err := c.NodePort(node) + if err != nil { + err := errors.Wrap(err, "failed to get node port") + sendEvent(NodeMonitorInfo{Node: node, Event: MonitorError{err}}) + return + } // On each monitored node, we loop looking for a cockroach process. 
data := struct { OneShot bool @@ -678,7 +683,7 @@ func (c *SyncedCluster) Monitor( OneShot: opts.OneShot, IgnoreEmpty: opts.IgnoreEmptyNodes, Store: c.NodeDir(node, 1 /* storeIndex */), - Port: c.NodePort(node), + Port: port, Local: c.IsLocal(), Separator: separator, SkippedMsg: skippedMsg, @@ -2345,7 +2350,15 @@ func (c *SyncedCluster) pgurls( } m := make(map[Node]string, len(hosts)) for node, host := range hosts { - m[node] = c.NodeURL(host, c.NodePort(node), tenantName) + desc, err := c.DiscoverService(node, tenantName, ServiceTypeSQL) + if err != nil { + return nil, err + } + sharedTenantName := "" + if desc.ServiceMode == ServiceModeShared { + sharedTenantName = tenantName + } + m[node] = c.NodeURL(host, desc.Port, sharedTenantName) } return m, nil } diff --git a/pkg/roachprod/install/cockroach.go b/pkg/roachprod/install/cockroach.go index c4ed56740f98..5573d1a477e5 100644 --- a/pkg/roachprod/install/cockroach.go +++ b/pkg/roachprod/install/cockroach.go @@ -28,6 +28,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/roachprod/ssh" "github.com/cockroachdb/cockroach/pkg/roachprod/vm/gce" "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/errors" ) @@ -105,6 +106,18 @@ type StartOpts struct { // that will be used when constructing join arguments. InitTarget int + // JoinTargets is the list of nodes that will be used when constructing join + // arguments. If empty, the default is to use the InitTarget. + JoinTargets []int + + // SQLPort is the port on which the cockroach process is listening for SQL + // connections. Default (0) will find an open port. + SQLPort int + + // AdminUIPort is the port on which the cockroach process is listening for + // HTTP traffic for the Admin UI. Default (0) will find an open port. 
+ AdminUIPort int + // -- Options that apply only to StartDefault target -- SkipInit bool @@ -112,9 +125,10 @@ type StartOpts struct { EncryptedStores bool // -- Options that apply only to StartTenantSQL target -- - TenantID int - KVAddrs string - KVCluster *SyncedCluster + TenantName string + TenantID int + KVAddrs string + KVCluster *SyncedCluster } // startSQLTimeout identifies the COCKROACH_CONNECT_TIMEOUT to use (in seconds) @@ -152,6 +166,96 @@ func (so StartOpts) GetInitTarget() Node { return Node(so.InitTarget) } +// GetJoinTargets returns the list of Nodes that should be used for +// join operations. If no join targets are specified, the init target +// is used. +func (so StartOpts) GetJoinTargets() []Node { + nodes := make([]Node, len(so.JoinTargets)) + for i, n := range so.JoinTargets { + nodes[i] = Node(n) + } + if len(nodes) == 0 { + nodes = []Node{so.GetInitTarget()} + } + return nodes +} + +// maybeRegisterServices registers the SQL and Admin UI DNS services for the +// cluster if no previous services for the tenant or host cluster are found. Any +// ports specified in the startOpts are used for the services. If no ports are +// specified, a search for open ports will be performed and selected for use. 
+func (c *SyncedCluster) maybeRegisterServices(
+	ctx context.Context, l *logger.Logger, startOpts StartOpts,
+) error {
+	serviceMap, err := c.MapServices(startOpts.TenantName)
+	if err != nil {
+		return err
+	}
+	tenantName := SystemTenantName
+	serviceMode := ServiceModeShared
+	if startOpts.Target == StartTenantSQL {
+		tenantName = startOpts.TenantName
+		serviceMode = ServiceModeExternal
+	}
+
+	mu := syncutil.Mutex{} // guards servicesToRegister, which every per-node closure appends to
+	servicesToRegister := make(ServiceDescriptors, 0)
+	err = c.Parallel(ctx, l, c.Nodes, func(ctx context.Context, node Node) (*RunResultDetails, error) {
+		services := make(ServiceDescriptors, 0)
+		res := &RunResultDetails{Node: node}
+		if _, ok := serviceMap[node][ServiceTypeSQL]; !ok {
+			services = append(services, ServiceDesc{
+				TenantName:  tenantName,
+				ServiceType: ServiceTypeSQL,
+				ServiceMode: serviceMode,
+				Node:        node,
+				Port:        startOpts.SQLPort,
+			})
+		}
+		if _, ok := serviceMap[node][ServiceTypeUI]; !ok {
+			services = append(services, ServiceDesc{
+				TenantName:  tenantName,
+				ServiceType: ServiceTypeUI,
+				ServiceMode: serviceMode,
+				Node:        node,
+				Port:        startOpts.AdminUIPort,
+			})
+		}
+		requiredPorts := 0 // number of services above that still have no port (Port == 0)
+		for _, service := range services {
+			if service.Port == 0 {
+				requiredPorts++
+			}
+		}
+		if requiredPorts > 0 {
+			openPorts, err := c.FindOpenPorts(ctx, l, node, config.DefaultOpenPortStart, requiredPorts)
+			if err != nil {
+				res.Err = err
+				return res, errors.Wrapf(err, "failed to find %d open ports", requiredPorts)
+			}
+			if len(openPorts) < requiredPorts { // a short result would panic on openPorts[0] below
+				return res, errors.Newf("found %d open ports, need %d", len(openPorts), requiredPorts)
+			}
+			for idx := range services {
+				if services[idx].Port != 0 {
+					continue
+				}
+				services[idx].Port = openPorts[0]
+				openPorts = openPorts[1:]
+			}
+		}
+
+		mu.Lock()
+		defer mu.Unlock()
+		servicesToRegister = append(servicesToRegister, services...)
+		return res, nil
+	})
+	if err != nil {
+		return err
+	}
+	return c.RegisterServices(servicesToRegister)
+}
+
 // Start the cockroach process on the cluster.
// // Starting the first node is special-cased quite a bit, it's used to distribute @@ -164,6 +268,24 @@ func (c *SyncedCluster) Start(ctx context.Context, l *logger.Logger, startOpts S if startOpts.Target == StartTenantProxy { return fmt.Errorf("start tenant proxy not implemented") } + // Local clusters do not support specifying ports. An error is returned if we + // detect that they were set. + if c.IsLocal() && (startOpts.SQLPort != 0 || startOpts.AdminUIPort != 0) { + // We don't need to return an error if the ports are the default values + // specified in DefaultStartOps, as these have not been specified explicitly + // by the user. + if startOpts.SQLPort != config.DefaultSQLPort || startOpts.AdminUIPort != config.DefaultAdminUIPort { + return fmt.Errorf("local clusters do not support specifying ports") + } + startOpts.SQLPort = 0 + startOpts.AdminUIPort = 0 + } + + err := c.maybeRegisterServices(ctx, l, startOpts) + if err != nil { + return err + } + switch startOpts.Target { case StartDefault: if err := c.distributeCerts(ctx, l); err != nil { @@ -251,8 +373,10 @@ func (c *SyncedCluster) CertsDir(node Node) string { return "certs" } -// NodeURL constructs a postgres URL. -func (c *SyncedCluster) NodeURL(host string, port int, tenantName string) string { +// NodeURL constructs a postgres URL. If sharedTenantName is not empty, it will +// be used as the virtual cluster name in the URL. This is used to connect to a +// shared process hosting multiple tenants. 
+func (c *SyncedCluster) NodeURL(host string, port int, sharedTenantName string) string { var u url.URL u.User = url.User("root") u.Scheme = "postgres" @@ -266,21 +390,29 @@ func (c *SyncedCluster) NodeURL(host string, port int, tenantName string) string } else { v.Add("sslmode", "disable") } - if tenantName != "" { - v.Add("options", fmt.Sprintf("-ccluster=%s", tenantName)) + if sharedTenantName != "" { + v.Add("options", fmt.Sprintf("-ccluster=%s", sharedTenantName)) } u.RawQuery = v.Encode() return "'" + u.String() + "'" } -// NodePort returns the SQL port for the given node. -func (c *SyncedCluster) NodePort(node Node) int { - return c.VMs[node-1].SQLPort +// NodePort returns the system tenant's SQL port for the given node. +func (c *SyncedCluster) NodePort(node Node) (int, error) { + desc, err := c.DiscoverService(node, SystemTenantName, ServiceTypeSQL) + if err != nil { + return 0, err + } + return desc.Port, nil } -// NodeUIPort returns the AdminUI port for the given node. -func (c *SyncedCluster) NodeUIPort(node Node) int { - return c.VMs[node-1].AdminUIPort +// NodeUIPort returns the system tenant's AdminUI port for the given node. 
+func (c *SyncedCluster) NodeUIPort(node Node) (int, error) { + desc, err := c.DiscoverService(node, SystemTenantName, ServiceTypeUI) + if err != nil { + return 0, err + } + return desc.Port, nil } // ExecOrInteractiveSQL ssh's onto a single node and executes `./ cockroach sql` @@ -295,7 +427,18 @@ func (c *SyncedCluster) ExecOrInteractiveSQL( if len(c.Nodes) != 1 { return fmt.Errorf("invalid number of nodes for interactive sql: %d", len(c.Nodes)) } - url := c.NodeURL("localhost", c.NodePort(c.Nodes[0]), tenantName) + desc, err := c.DiscoverService(c.Nodes[0], tenantName, ServiceTypeSQL) + if err != nil { + return err + } + if tenantName == "" { + tenantName = SystemTenantName + } + sharedTenantName := "" + if desc.ServiceMode == ServiceModeShared { + sharedTenantName = tenantName + } + url := c.NodeURL("localhost", desc.Port, sharedTenantName) binary := cockroachNodeBinary(c, c.Nodes[0]) allArgs := []string{binary, "sql", "--url", url} allArgs = append(allArgs, ssh.Escape(args)) @@ -305,16 +448,24 @@ func (c *SyncedCluster) ExecOrInteractiveSQL( // ExecSQL runs a `cockroach sql` . // It is assumed that the args include the -e flag. 
func (c *SyncedCluster) ExecSQL( - ctx context.Context, l *logger.Logger, tenantName string, args []string, + ctx context.Context, l *logger.Logger, nodes Nodes, tenantName string, args []string, ) error { display := fmt.Sprintf("%s: executing sql", c.Name) - results, _, err := c.ParallelE(ctx, l, c.Nodes, func(ctx context.Context, node Node) (*RunResultDetails, error) { + results, _, err := c.ParallelE(ctx, l, nodes, func(ctx context.Context, node Node) (*RunResultDetails, error) { + desc, err := c.DiscoverService(node, tenantName, ServiceTypeSQL) + if err != nil { + return nil, err + } + sharedTenantName := "" + if desc.ServiceMode == ServiceModeShared { + sharedTenantName = tenantName + } var cmd string if c.IsLocal() { cmd = fmt.Sprintf(`cd %s ; `, c.localVMDir(node)) } cmd += cockroachNodeBinary(c, node) + " sql --url " + - c.NodeURL("localhost", c.NodePort(node), tenantName) + " " + + c.NodeURL("localhost", desc.Port, sharedTenantName) + " " + ssh.Escape(args) return c.runCmdOnSingleNode(ctx, l, node, cmd, defaultCmdOpts("run-sql")) @@ -466,12 +617,29 @@ func (c *SyncedCluster) generateStartArgs( listenHost = "127.0.0.1" } + tenantName := startOpts.TenantName + var sqlPort int if startOpts.Target == StartTenantSQL { - args = append(args, fmt.Sprintf("--sql-addr=%s:%d", listenHost, c.NodePort(node))) + desc, err := c.DiscoverService(node, tenantName, ServiceTypeSQL) + if err != nil { + return nil, err + } + sqlPort = desc.Port + args = append(args, fmt.Sprintf("--sql-addr=%s:%d", listenHost, sqlPort)) } else { - args = append(args, fmt.Sprintf("--listen-addr=%s:%d", listenHost, c.NodePort(node))) + tenantName = SystemTenantName + desc, err := c.DiscoverService(node, tenantName, ServiceTypeSQL) + if err != nil { + return nil, err + } + sqlPort = desc.Port + args = append(args, fmt.Sprintf("--listen-addr=%s:%d", listenHost, sqlPort)) + } + desc, err := c.DiscoverService(node, tenantName, ServiceTypeUI) + if err != nil { + return nil, err } - args = append(args, 
fmt.Sprintf("--http-addr=%s:%d", listenHost, c.NodeUIPort(node))) + args = append(args, fmt.Sprintf("--http-addr=%s:%d", listenHost, desc.Port)) if !c.IsLocal() { advertiseHost := "" @@ -481,14 +649,22 @@ func (c *SyncedCluster) generateStartArgs( advertiseHost = c.VMs[node-1].PrivateIP } args = append(args, - fmt.Sprintf("--advertise-addr=%s:%d", advertiseHost, c.NodePort(node)), + fmt.Sprintf("--advertise-addr=%s:%d", advertiseHost, sqlPort), ) } // --join flags are unsupported/unnecessary in `cockroach start-single-node`. if startOpts.Target == StartDefault && !c.useStartSingleNode() { - initTarget := startOpts.GetInitTarget() - args = append(args, fmt.Sprintf("--join=%s:%d", c.Host(initTarget), c.NodePort(initTarget))) + joinTargets := startOpts.GetJoinTargets() + addresses := make([]string, len(joinTargets)) + for i, joinNode := range startOpts.GetJoinTargets() { + desc, err := c.DiscoverService(joinNode, SystemTenantName, ServiceTypeSQL) + if err != nil { + return nil, err + } + addresses[i] = fmt.Sprintf("%s:%d", c.Host(joinNode), desc.Port) + } + args = append(args, fmt.Sprintf("--join=%s", strings.Join(addresses, ","))) } if startOpts.Target == StartTenantSQL { args = append(args, fmt.Sprintf("--kv-addrs=%s", startOpts.KVAddrs)) @@ -593,7 +769,10 @@ func (c *SyncedCluster) initializeCluster( ctx context.Context, l *logger.Logger, node Node, ) (*RunResultDetails, error) { l.Printf("%s: initializing cluster\n", c.Name) - cmd := c.generateInitCmd(node) + cmd, err := c.generateInitCmd(node) + if err != nil { + return nil, err + } res, err := c.runCmdOnSingleNode(ctx, l, node, cmd, defaultCmdOpts("init-cluster")) if res != nil { @@ -609,7 +788,10 @@ func (c *SyncedCluster) setClusterSettings( ctx context.Context, l *logger.Logger, node Node, ) (*RunResultDetails, error) { l.Printf("%s: setting cluster settings", c.Name) - cmd := c.generateClusterSettingCmd(l, node) + cmd, err := c.generateClusterSettingCmd(l, node) + if err != nil { + return nil, err + } res, 
err := c.runCmdOnSingleNode(ctx, l, node, cmd, defaultCmdOpts("set-cluster-settings")) if res != nil { @@ -621,7 +803,7 @@ func (c *SyncedCluster) setClusterSettings( return res, err } -func (c *SyncedCluster) generateClusterSettingCmd(l *logger.Logger, node Node) string { +func (c *SyncedCluster) generateClusterSettingCmd(l *logger.Logger, node Node) (string, error) { if config.CockroachDevLicense == "" { l.Printf("%s: COCKROACH_DEV_LICENSE unset: enterprise features will be unavailable\n", c.Name) @@ -646,29 +828,37 @@ func (c *SyncedCluster) generateClusterSettingCmd(l *logger.Logger, node Node) s binary := cockroachNodeBinary(c, node) path := fmt.Sprintf("%s/%s", c.NodeDir(node, 1 /* storeIndex */), "settings-initialized") - url := c.NodeURL("localhost", c.NodePort(node), "" /* tenantName */) + port, err := c.NodePort(node) + if err != nil { + return "", err + } + url := c.NodeURL("localhost", port, SystemTenantName /* tenantName */) clusterSettingsCmd += fmt.Sprintf(` if ! test -e %s ; then COCKROACH_CONNECT_TIMEOUT=%d %s sql --url %s -e "%s" && touch %s fi`, path, startSQLTimeout, binary, url, clusterSettingsString, path) - return clusterSettingsCmd + return clusterSettingsCmd, nil } -func (c *SyncedCluster) generateInitCmd(node Node) string { +func (c *SyncedCluster) generateInitCmd(node Node) (string, error) { var initCmd string if c.IsLocal() { initCmd = fmt.Sprintf(`cd %s ; `, c.localVMDir(node)) } path := fmt.Sprintf("%s/%s", c.NodeDir(node, 1 /* storeIndex */), "cluster-bootstrapped") - url := c.NodeURL("localhost", c.NodePort(node), "" /* tenantName */) + port, err := c.NodePort(node) + if err != nil { + return "", err + } + url := c.NodeURL("localhost", port, SystemTenantName /* tenantName */) binary := cockroachNodeBinary(c, node) initCmd += fmt.Sprintf(` if ! 
test -e %[1]s ; then COCKROACH_CONNECT_TIMEOUT=%[4]d %[2]s init --url %[3]s && touch %[1]s fi`, path, binary, url, startSQLTimeout) - return initCmd + return initCmd, nil } func (c *SyncedCluster) generateKeyCmd( @@ -784,7 +974,11 @@ func (c *SyncedCluster) createFixedBackupSchedule( node := c.Nodes[0] binary := cockroachNodeBinary(c, node) - url := c.NodeURL("localhost", c.NodePort(node), "" /* tenantName */) + port, err := c.NodePort(node) + if err != nil { + return err + } + url := c.NodeURL("localhost", port, SystemTenantName /* tenantName */) fullCmd := fmt.Sprintf(`COCKROACH_CONNECT_TIMEOUT=%d %s sql --url %s -e %q`, startSQLTimeout, binary, url, createScheduleCmd) // Instead of using `c.ExecSQL()`, use `c.runCmdOnSingleNode()`, which allows us to diff --git a/pkg/roachprod/install/expander.go b/pkg/roachprod/install/expander.go index 06faa7209d35..2912f674d676 100644 --- a/pkg/roachprod/install/expander.go +++ b/pkg/roachprod/install/expander.go @@ -23,8 +23,8 @@ import ( var parameterRe = regexp.MustCompile(`{[^{}]*}`) var pgURLRe = regexp.MustCompile(`{pgurl(:[-,0-9]+)?(:[a-z0-9\-]+)?}`) var pgHostRe = regexp.MustCompile(`{pghost(:[-,0-9]+)?}`) -var pgPortRe = regexp.MustCompile(`{pgport(:[-,0-9]+)?}`) -var uiPortRe = regexp.MustCompile(`{uiport(:[-,0-9]+)?}`) +var pgPortRe = regexp.MustCompile(`{pgport(:[-,0-9]+)?(:[a-z0-9\-]+)?}`) +var uiPortRe = regexp.MustCompile(`{uiport(:[-,0-9]+)}`) var storeDirRe = regexp.MustCompile(`{store-dir}`) var logDirRe = regexp.MustCompile(`{log-dir}`) var certsDirRe = regexp.MustCompile(`{certs-dir}`) @@ -121,7 +121,7 @@ func (e *expander) maybeExpandPgURL( if e.pgURLs == nil { e.pgURLs = make(map[string]map[Node]string) } - tenant := "system" + tenant := SystemTenantName if m[2] != "" { // Trim off the leading ':' in the capture group. 
tenant = m[2][1:] @@ -167,11 +167,20 @@ func (e *expander) maybeExpandPgPort( if m == nil { return s, false, nil } + tenant := SystemTenantName + if m[2] != "" { + // Trim off the leading ':' in the capture group. + tenant = m[2][1:] + } if e.pgPorts == nil { e.pgPorts = make(map[Node]string, len(c.VMs)) for _, node := range allNodes(len(c.VMs)) { - e.pgPorts[node] = fmt.Sprint(c.NodePort(node)) + desc, err := c.DiscoverService(node, tenant, ServiceTypeSQL) + if err != nil { + return s, false, err + } + e.pgPorts[node] = fmt.Sprint(desc.Port) } } @@ -191,6 +200,7 @@ func (e *expander) maybeExpandUIPort( if e.uiPorts == nil { e.uiPorts = make(map[Node]string, len(c.VMs)) for _, node := range allNodes(len(c.VMs)) { + // TODO(herko): Add support for external tenants. e.uiPorts[node] = fmt.Sprint(c.NodeUIPort(node)) } } diff --git a/pkg/roachprod/install/nodes.go b/pkg/roachprod/install/nodes.go index 9642a2e09e26..5cf55f22d218 100644 --- a/pkg/roachprod/install/nodes.go +++ b/pkg/roachprod/install/nodes.go @@ -99,3 +99,12 @@ func allNodes(numNodesInCluster int) Nodes { } return r } + +func (n Nodes) Contains(node Node) bool { + for _, v := range n { + if v == node { + return true + } + } + return false +} diff --git a/pkg/roachprod/install/scripts/open_ports.sh b/pkg/roachprod/install/scripts/open_ports.sh new file mode 100644 index 000000000000..daba385b072c --- /dev/null +++ b/pkg/roachprod/install/scripts/open_ports.sh @@ -0,0 +1,38 @@ +#!/usr/bin/env bash +# +# Copyright 2023 The Cockroach Authors. +# +# Use of this software is governed by the Business Source License +# included in the file licenses/BSL.txt. +# +# As of the Change Date specified in that file, in accordance with +# the Business Source License, use of this software will be governed +# by the Apache License, Version 2.0, included in the file +# licenses/APL.txt. 
+set -euo pipefail + +start_port=#{.StartPort#} +port_count=#{.PortCount#} + +open_ports=() +ports_found=0 + +set +e +for ((port = start_port; port < 32768; port++)); do + if ! lsof -i :"$port" >/dev/null 2>&1; then + open_ports+=("$port") + ((ports_found++)) + + if ((ports_found >= port_count)); then + break + fi + fi +done + +set -e +if ((ports_found > 0)); then + echo "${open_ports[@]}" +else + echo "no open ports found" >&2 + exit 1 +fi diff --git a/pkg/roachprod/install/services.go b/pkg/roachprod/install/services.go new file mode 100644 index 000000000000..1841636f2987 --- /dev/null +++ b/pkg/roachprod/install/services.go @@ -0,0 +1,400 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package install + +import ( + "context" + _ "embed" + "fmt" + "net" + "strconv" + "strings" + "text/template" + + "github.com/alessio/shellescape" + "github.com/cockroachdb/cockroach/pkg/roachprod/config" + "github.com/cockroachdb/cockroach/pkg/roachprod/logger" + "github.com/cockroachdb/cockroach/pkg/roachprod/vm" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/errors" +) + +//go:embed scripts/open_ports.sh +var openPortsScript string + +type ServiceType string + +const ( + // ServiceTypeSQL is the service type for SQL services on a node. + ServiceTypeSQL ServiceType = "sql" + // ServiceTypeUI is the service type for UI services on a node. + ServiceTypeUI ServiceType = "ui" +) + +// SystemTenantName is default system tenant name. +const SystemTenantName = "system" + +type ServiceMode string + +const ( + // ServiceModeShared is the service mode for services that are shared on a host process. 
+ ServiceModeShared ServiceMode = "shared" + // ServiceModeExternal is the service mode for services that are run in a separate process. + ServiceModeExternal ServiceMode = "external" +) + +// SharedPriorityClass is the priority class used to indicate when a service is shared. +const SharedPriorityClass = 1000 + +// ServiceDesc describes a service running on a node. +type ServiceDesc struct { + // TenantName is the name of the tenant that owns the service. + TenantName string + // ServiceType is the type of service. + ServiceType ServiceType + // ServiceMode is the mode of the service. + ServiceMode ServiceMode + // Node is the node the service is running on. + Node Node + // Port is the port the service is running on. + Port int +} + +// NodeServiceMap is a convenience type for mapping services by service type for each node. +type NodeServiceMap map[Node]map[ServiceType]*ServiceDesc + +// ServiceDescriptors is a convenience type for a slice of service descriptors. +type ServiceDescriptors []ServiceDesc + +// localClusterPortCache is a workaround for local clusters to prevent multiple +// nodes from using the same port when searching for open ports. +var localClusterPortCache struct { + mu syncutil.Mutex + startPort int +} + +// serviceDNSName returns the DNS name for a service in the standard SRV form. +func serviceDNSName( + dnsProvider vm.DNSProvider, tenantName string, serviceType ServiceType, clusterName string, +) string { + // An SRV record name must adhere to the standard form: + // _service._proto.name. + return fmt.Sprintf("_%s-%s._tcp.%s.%s", tenantName, serviceType, clusterName, dnsProvider.Domain()) +} + +// serviceNameComponents returns the tenant name and service type from a DNS +// name in the standard SRV form. 
+func serviceNameComponents(name string) (string, ServiceType, error) { + nameParts := strings.Split(name, ".") + if len(nameParts) < 2 { + return "", "", errors.Newf("invalid DNS SRV name: %s", name) + } + + serviceName := strings.TrimPrefix(nameParts[0], "_") + splitIndex := strings.LastIndex(serviceName, "-") + if splitIndex == -1 { + return "", "", errors.Newf("invalid service name: %s", serviceName) + } + + serviceTypeStr := serviceName[splitIndex+1:] + var serviceType ServiceType + switch { + case serviceTypeStr == string(ServiceTypeSQL): + serviceType = ServiceTypeSQL + case serviceTypeStr == string(ServiceTypeUI): + serviceType = ServiceTypeUI + default: + return "", "", errors.Newf("invalid service type: %s", serviceTypeStr) + } + return serviceName[:splitIndex], serviceType, nil +} + +// DiscoverServices discovers services running on the given nodes. Services +// matching the tenant name and service type are returned. It's possible that +// more than one service can be returned for the given parameters if additional +// services of the same type are running for the same tenant. +func (c *SyncedCluster) DiscoverServices( + nodes Nodes, tenantName string, serviceType ServiceType, +) (ServiceDescriptors, error) { + // If no tenant name is specified, use the system tenant. + if tenantName == "" { + tenantName = SystemTenantName + } + mu := syncutil.Mutex{} + records := make([]vm.DNSRecord, 0) + err := vm.FanOutDNS(c.VMs, func(dnsProvider vm.DNSProvider, _ vm.List) error { + service := fmt.Sprintf("%s-%s", tenantName, string(serviceType)) + r, lookupErr := dnsProvider.LookupSRVRecords(service, "tcp", c.Name) + if lookupErr != nil { + return lookupErr + } + mu.Lock() + defer mu.Unlock() + records = append(records, r...) 
+ return nil + }) + if err != nil { + return nil, err + } + descriptors, err := c.dnsRecordsToServiceDescriptors(records) + if err != nil { + return nil, err + } + return descriptors.Filter(nodes), nil +} + +// DiscoverService is a convenience method for discovering a single service. It +// returns the highest priority service returned by DiscoverServices. If no +// services are found, it returns a service descriptor with the default port for +// the service type. +func (c *SyncedCluster) DiscoverService( + node Node, tenantName string, serviceType ServiceType, +) (ServiceDesc, error) { + services, err := c.DiscoverServices([]Node{node}, tenantName, serviceType) + if err != nil { + return ServiceDesc{}, err + } + // If no services are found, attempt to discover a service for the system + // tenant, and assume the service is shared. + if len(services) == 0 { + services, err = c.DiscoverServices([]Node{node}, SystemTenantName, serviceType) + if err != nil { + return ServiceDesc{}, err + } + } + // Finally, fall back to the default ports if no services are found. This is + // useful for backwards compatibility with clusters that were created before + // the introduction of service discovery, or without a DNS provider. + // TODO(Herko): Remove this once DNS support is fully functional. + if len(services) == 0 { + var port int + switch serviceType { + case ServiceTypeSQL: + port = config.DefaultSQLPort + case ServiceTypeUI: + port = config.DefaultAdminUIPort + default: + return ServiceDesc{}, errors.Newf("invalid service type: %s", serviceType) + } + return ServiceDesc{ + ServiceType: serviceType, + ServiceMode: ServiceModeShared, + TenantName: tenantName, + Node: node, + Port: port, + }, nil + } + + // If there are multiple services available select the first one. + return services[0], err +} + +// MapServices discovers all service types for a given tenant and maps it by +// node and service type. 
+func (c *SyncedCluster) MapServices(tenantName string) (NodeServiceMap, error) { + sqlServices, err := c.DiscoverServices(c.Nodes, tenantName, ServiceTypeSQL) + if err != nil { + return nil, err + } + uiServices, err := c.DiscoverServices(c.Nodes, tenantName, ServiceTypeUI) + if err != nil { + return nil, err + } + serviceMap := make(NodeServiceMap) + for _, node := range c.Nodes { + serviceMap[node] = make(map[ServiceType]*ServiceDesc) + } + services := append(sqlServices, uiServices...) + for _, service := range services { + serviceMap[service.Node][service.ServiceType] = &service + } + return serviceMap, nil +} + +// RegisterServices registers services with the DNS provider. This function is +// lenient and will not return an error if no DNS provider is available to +// register the service. +func (c *SyncedCluster) RegisterServices(services ServiceDescriptors) error { + servicesByDNSProvider := make(map[string]ServiceDescriptors) + for _, desc := range services { + dnsProvider := c.VMs[desc.Node-1].DNSProvider + if dnsProvider == "" { + continue + } + servicesByDNSProvider[dnsProvider] = append(servicesByDNSProvider[dnsProvider], desc) + } + for dnsProviderName := range servicesByDNSProvider { + return vm.ForDNSProvider(dnsProviderName, func(dnsProvider vm.DNSProvider) error { + records := make([]vm.DNSRecord, 0) + for _, desc := range servicesByDNSProvider[dnsProviderName] { + name := serviceDNSName(dnsProvider, desc.TenantName, desc.ServiceType, c.Name) + priority := 0 + if desc.ServiceMode == ServiceModeShared { + priority = SharedPriorityClass + } + srvData := net.SRV{ + Target: c.TargetDNSName(desc.Node), + Port: uint16(desc.Port), + Priority: uint16(priority), + Weight: 0, + } + records = append(records, vm.CreateSRVRecord(name, srvData)) + } + err := dnsProvider.CreateRecords(records...) 
+ if err != nil { + return err + } + return nil + }) + } + return nil +} + +// Filter returns ServiceDescriptors with only the descriptors that match +// the given nodes. +func (d ServiceDescriptors) Filter(nodes Nodes) ServiceDescriptors { + filteredDescriptors := make(ServiceDescriptors, 0) + for _, descriptor := range d { + if !nodes.Contains(descriptor.Node) { + continue + } + filteredDescriptors = append(filteredDescriptors, descriptor) + } + return filteredDescriptors +} + +// FindOpenPorts finds the requested number of open ports on the provided node. +func (c *SyncedCluster) FindOpenPorts( + ctx context.Context, l *logger.Logger, node Node, startPort, count int, +) ([]int, error) { + tpl, err := template.New("open_ports"). + Funcs(template.FuncMap{"shesc": func(i interface{}) string { + return shellescape.Quote(fmt.Sprint(i)) + }}). + Delims("#{", "#}"). + Parse(openPortsScript) + if err != nil { + return nil, err + } + + var ports []int + if c.IsLocal() { + // For local clusters, we need to keep track of the ports we've already used + // so that we don't use them again, when this function is called in + // parallel. This does not protect against the case where concurrent calls + // are made to roachprod to create local clusters. 
// stringToIntegers parses a whitespace-separated list of integers. Leading,
// trailing and repeated whitespace (including newlines) is ignored. The first
// token that is not a valid integer aborts the conversion with its parse
// error.
func stringToIntegers(str string) ([]int, error) {
	tokens := strings.Fields(str)
	values := make([]int, 0, len(tokens))
	for _, token := range tokens {
		value, err := strconv.Atoi(token)
		if err != nil {
			return nil, err
		}
		values = append(values, value)
	}
	return values, nil
}
+ ports := make(ServiceDescriptors, 0) + for _, record := range records { + if record.Type != vm.SRV { + continue + } + data, err := record.ParseSRVRecord() + if err != nil { + return nil, err + } + if _, ok := dnsNameToNode[data.Target]; !ok { + continue + } + serviceMode := ServiceModeExternal + if data.Priority >= SharedPriorityClass { + serviceMode = ServiceModeShared + } + tenantName, serviceType, err := serviceNameComponents(record.Name) + if err != nil { + return nil, err + } + ports = append(ports, ServiceDesc{ + TenantName: tenantName, + ServiceType: serviceType, + ServiceMode: serviceMode, + Port: int(data.Port), + Node: dnsNameToNode[data.Target], + }) + } + return ports, nil +} + +func (c *SyncedCluster) TargetDNSName(node Node) string { + cVM := c.VMs[node-1] + postfix := "" + if c.IsLocal() { + // For local clusters the Public DNS is the same for all nodes, so we + // need to add a postfix to make them unique. + postfix = fmt.Sprintf("%d.", int(node)) + } + // Targets always end with a period as per SRV record convention. + return fmt.Sprintf("%s.%s", cVM.PublicDNS, postfix) +} diff --git a/pkg/roachprod/install/services_test.go b/pkg/roachprod/install/services_test.go new file mode 100644 index 000000000000..eedbc1f82677 --- /dev/null +++ b/pkg/roachprod/install/services_test.go @@ -0,0 +1,94 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package install + +import ( + "net" + "sort" + "testing" + + "github.com/cockroachdb/cockroach/pkg/roachprod/cloud" + "github.com/cockroachdb/cockroach/pkg/roachprod/vm" + "github.com/cockroachdb/cockroach/pkg/roachprod/vm/local" + "github.com/stretchr/testify/require" +) + +type testProvider struct { + vm.Provider + vm.DNSProvider +} + +func TestServicePorts(t *testing.T) { + clusterName := "tc" + z1NS := local.NewDNSProvider(t.TempDir(), "z1") + vm.Providers["p1"] = &testProvider{DNSProvider: z1NS} + z2NS := local.NewDNSProvider(t.TempDir(), "z2") + vm.Providers["p2"] = &testProvider{DNSProvider: z2NS} + + err := z1NS.CreateRecords( + vm.CreateSRVRecord(serviceDNSName(z1NS, "t1", ServiceTypeSQL, clusterName), net.SRV{ + Target: "host1.rp.", + Port: 12345, + }), + ) + require.NoError(t, err) + + err = z2NS.CreateRecords( + vm.CreateSRVRecord(serviceDNSName(z2NS, "t1", ServiceTypeSQL, clusterName), net.SRV{ + Target: "host1.rp.", + Port: 12346, + }), + ) + require.NoError(t, err) + + c := &SyncedCluster{ + Cluster: cloud.Cluster{ + Name: clusterName, + VMs: vm.List{ + vm.VM{ + Provider: "p1", + DNSProvider: "p1", + PublicDNS: "host1.rp", + }, + vm.VM{ + Provider: "p2", + DNSProvider: "p2", + PublicDNS: "host2.rp", + }, + }, + }, + Nodes: allNodes(2), + } + + descriptors, err := c.DiscoverServices(c.Nodes, "t1", ServiceTypeSQL) + sort.Slice(descriptors, func(i, j int) bool { + return descriptors[i].Port < descriptors[j].Port + }) + require.NoError(t, err) + require.Len(t, descriptors, 2) + require.Equal(t, 12345, descriptors[0].Port) + require.Equal(t, 12346, descriptors[1].Port) +} + +func TestStringToIntegers(t *testing.T) { + integers, err := stringToIntegers(" 20 333 4 5\n 89\n\n") + require.NoError(t, err) + require.Equal(t, []int{20, 333, 4, 5, 89}, integers) +} + +func TestServiceNameComponents(t *testing.T) { + z := local.NewDNSProvider(t.TempDir(), "z1") + dnsName := serviceDNSName(z, "tenant-100", ServiceTypeSQL, "test-cluster") + tenantName, 
serviceType, err := serviceNameComponents(dnsName) + require.NoError(t, err) + require.Equal(t, "tenant-100", tenantName) + require.Equal(t, ServiceTypeSQL, serviceType) +} diff --git a/pkg/roachprod/lock/BUILD.bazel b/pkg/roachprod/lock/BUILD.bazel new file mode 100644 index 000000000000..023f0440b08e --- /dev/null +++ b/pkg/roachprod/lock/BUILD.bazel @@ -0,0 +1,12 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "lock", + srcs = ["fs.go"], + importpath = "github.com/cockroachdb/cockroach/pkg/roachprod/lock", + visibility = ["//visibility:public"], + deps = [ + "@com_github_cockroachdb_errors//:errors", + "@org_golang_x_sys//unix", + ], +) diff --git a/pkg/roachprod/lock/fs.go b/pkg/roachprod/lock/fs.go new file mode 100644 index 000000000000..1efa86c41381 --- /dev/null +++ b/pkg/roachprod/lock/fs.go @@ -0,0 +1,36 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package lock + +import ( + "os" + + "github.com/cockroachdb/errors" + "golang.org/x/sys/unix" +) + +// AcquireFilesystemLock acquires a filesystem lock in order that concurrent +// operations or roachprod processes that access shared system resources do not +// conflict. Different locks can be specified by passing different paths. 
+func AcquireFilesystemLock(path string) (unlockFn func(), _ error) { + lockFile := os.ExpandEnv(path) + f, err := os.Create(lockFile) + if err != nil { + return nil, errors.Wrapf(err, "creating lock file %q", lockFile) + } + if err := unix.Flock(int(f.Fd()), unix.LOCK_EX); err != nil { + f.Close() + return nil, errors.Wrap(err, "acquiring lock on %q") + } + return func() { + f.Close() + }, nil +} diff --git a/pkg/roachprod/multitenant.go b/pkg/roachprod/multitenant.go index 8a3341578121..36fce48989c6 100644 --- a/pkg/roachprod/multitenant.go +++ b/pkg/roachprod/multitenant.go @@ -60,23 +60,26 @@ func StartTenant( if startOpts.TenantID < 2 { return errors.Errorf("invalid tenant ID %d (must be 2 or higher)", startOpts.TenantID) } + // TODO(herko): Allow users to pass in a tenant name. + startOpts.TenantName = fmt.Sprintf("tenant-%d", startOpts.TenantID) - // Create tenant, if necessary. We need to run this SQL against a single host, - // so temporarily restrict the target nodes to 1. - saveNodes := hc.Nodes - hc.Nodes = hc.Nodes[:1] + // Create tenant, if necessary. We need to run this SQL against a single host. 
l.Printf("Creating tenant metadata") - if err := hc.ExecSQL(ctx, l, "", []string{ + if err := hc.ExecSQL(ctx, l, hc.Nodes[:1], "", []string{ `-e`, fmt.Sprintf(createTenantIfNotExistsQuery, startOpts.TenantID), }); err != nil { return err } - hc.Nodes = saveNodes + l.Printf("Starting tenant nodes") var kvAddrs []string for _, node := range hc.Nodes { - kvAddrs = append(kvAddrs, fmt.Sprintf("%s:%d", hc.Host(node), hc.NodePort(node))) + port, err := hc.NodePort(node) + if err != nil { + return err + } + kvAddrs = append(kvAddrs, fmt.Sprintf("%s:%d", hc.Host(node), port)) } startOpts.KVAddrs = strings.Join(kvAddrs, ",") startOpts.KVCluster = hc diff --git a/pkg/roachprod/roachprod.go b/pkg/roachprod/roachprod.go index 3a01b172eeee..6f34582b261a 100644 --- a/pkg/roachprod/roachprod.go +++ b/pkg/roachprod/roachprod.go @@ -35,6 +35,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/roachprod/cloud" "github.com/cockroachdb/cockroach/pkg/roachprod/config" "github.com/cockroachdb/cockroach/pkg/roachprod/install" + "github.com/cockroachdb/cockroach/pkg/roachprod/lock" "github.com/cockroachdb/cockroach/pkg/roachprod/logger" "github.com/cockroachdb/cockroach/pkg/roachprod/prometheus" "github.com/cockroachdb/cockroach/pkg/roachprod/vm" @@ -50,7 +51,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/errors" "github.com/cockroachdb/errors/oserror" - "golang.org/x/sys/unix" ) // verifyClusterName ensures that the given name conforms to @@ -224,24 +224,6 @@ func CachedClusters(l *logger.Logger, fn func(clusterName string, numVMs int)) { } } -// acquireFilesystemLock acquires a filesystem lock in order that concurrent -// operations or roachprod processes that access shared system resources do -// not conflict. 
-func acquireFilesystemLock() (unlockFn func(), _ error) { - lockFile := os.ExpandEnv("$HOME/.roachprod/LOCK") - f, err := os.Create(lockFile) - if err != nil { - return nil, errors.Wrapf(err, "creating lock file %q", lockFile) - } - if err := unix.Flock(int(f.Fd()), unix.LOCK_EX); err != nil { - f.Close() - return nil, errors.Wrap(err, "acquiring lock on %q") - } - return func() { - f.Close() - }, nil -} - // Sync grabs an exclusive lock on the roachprod state and then proceeds to // read the current state from the cloud and write it out to disk. The locking // protects both the reading and the writing in order to prevent the hazard @@ -251,7 +233,7 @@ func Sync(l *logger.Logger, options vm.ListOptions) (*cloud.Cloud, error) { if !config.Quiet { l.Printf("Syncing...") } - unlock, err := acquireFilesystemLock() + unlock, err := lock.AcquireFilesystemLock(config.DefaultLockPath) if err != nil { return nil, err } @@ -449,7 +431,7 @@ func SQL( if len(c.Nodes) == 1 { return c.ExecOrInteractiveSQL(ctx, l, tenantName, cmdArray) } - return c.ExecSQL(ctx, l, tenantName, cmdArray) + return c.ExecSQL(ctx, l, c.Nodes, tenantName, cmdArray) } // IP gets the ip addresses of the nodes in a cluster. @@ -587,7 +569,7 @@ func SetupSSH(ctx context.Context, l *logger.Logger, clusterName string) error { // Configure SSH for machines in the zones we operate on. 
if err := vm.ProvidersSequential(providers, func(p vm.Provider) error { - unlock, lockErr := acquireFilesystemLock() + unlock, lockErr := lock.AcquireFilesystemLock(config.DefaultLockPath) if lockErr != nil { return lockErr } @@ -683,6 +665,8 @@ func DefaultStartOpts() install.StartOpts { ScheduleBackups: false, ScheduleBackupArgs: "", InitTarget: 1, + SQLPort: config.DefaultSQLPort, + AdminUIPort: config.DefaultAdminUIPort, } } @@ -941,10 +925,14 @@ func PgURL( var urls []string for i, ip := range ips { + desc, err := c.DiscoverService(nodes[i], opts.TenantName, install.ServiceTypeSQL) + if err != nil { + return nil, err + } if ip == "" { return nil, errors.Errorf("empty ip: %v", ips) } - urls = append(urls, c.NodeURL(ip, c.NodePort(nodes[i]), opts.TenantName)) + urls = append(urls, c.NodeURL(ip, desc.Port, opts.TenantName)) } if len(urls) != len(nodes) { return nil, errors.Errorf("have nodes %v, but urls %v from ips %v", nodes, urls, ips) @@ -958,6 +946,7 @@ type urlConfig struct { openInBrowser bool secure bool port int + tenantName string } func urlGenerator( @@ -978,8 +967,13 @@ func urlGenerator( if uConfig.usePublicIP { host = c.VMs[node-1].PublicIP } - if uConfig.port == 0 { - uConfig.port = c.NodeUIPort(node) + port := uConfig.port + if port == 0 { + desc, err := c.DiscoverService(node, uConfig.tenantName, install.ServiceTypeUI) + if err != nil { + return nil, err + } + port = desc.Port } scheme := "http" if c.Secure { @@ -988,7 +982,7 @@ func urlGenerator( if !strings.HasPrefix(uConfig.path, "/") { uConfig.path = "/" + uConfig.path } - url := fmt.Sprintf("%s://%s:%d%s", scheme, host, uConfig.port, uConfig.path) + url := fmt.Sprintf("%s://%s:%d%s", scheme, host, port, uConfig.path) urls = append(urls, url) if uConfig.openInBrowser { cmd := browserCmd(url) @@ -1019,7 +1013,7 @@ func browserCmd(url string) *exec.Cmd { // AdminURL generates admin UI URLs for the nodes in a cluster. 
func AdminURL( - l *logger.Logger, clusterName, path string, usePublicIP, openInBrowser, secure bool, + l *logger.Logger, clusterName, tenantName, path string, usePublicIP, openInBrowser, secure bool, ) ([]string, error) { if err := LoadClusters(); err != nil { return nil, err @@ -1033,6 +1027,7 @@ func AdminURL( usePublicIP: usePublicIP, openInBrowser: openInBrowser, secure: secure, + tenantName: tenantName, } return urlGenerator(c, l, c.TargetNodes(), uConfig) } @@ -1080,7 +1075,10 @@ func Pprof(ctx context.Context, l *logger.Logger, clusterName string, opts Pprof err = c.Parallel(ctx, l, c.TargetNodes(), func(ctx context.Context, node install.Node) (*install.RunResultDetails, error) { res := &install.RunResultDetails{Node: node} host := c.Host(node) - port := c.NodeUIPort(node) + port, err := c.NodeUIPort(node) + if err != nil { + return nil, err + } scheme := "http" if c.Secure { scheme = "https" @@ -1335,7 +1333,7 @@ func Create( if isLocal { // To ensure that multiple processes don't create local clusters at // the same time (causing port collisions), acquire the lock file. 
- unlockFn, err := acquireFilesystemLock() + unlockFn, err := lock.AcquireFilesystemLock(config.DefaultLockPath) if err != nil { return err } @@ -1959,9 +1957,12 @@ func sendCaptureCommand( httpClient := httputil.NewClientWithTimeout(0 /* timeout: None */) _, _, err := c.ParallelE(ctx, l, nodes, func(ctx context.Context, node install.Node) (*install.RunResultDetails, error) { + port, err := c.NodeUIPort(node) + if err != nil { + return nil, err + } res := &install.RunResultDetails{Node: node} host := c.Host(node) - port := c.NodeUIPort(node) scheme := "http" if c.Secure { scheme = "https" diff --git a/pkg/roachprod/vm/BUILD.bazel b/pkg/roachprod/vm/BUILD.bazel index ce1888b19330..1dbcbd888539 100644 --- a/pkg/roachprod/vm/BUILD.bazel +++ b/pkg/roachprod/vm/BUILD.bazel @@ -2,7 +2,10 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "vm", - srcs = ["vm.go"], + srcs = [ + "dns.go", + "vm.go", + ], importpath = "github.com/cockroachdb/cockroach/pkg/roachprod/vm", visibility = ["//visibility:public"], deps = [ diff --git a/pkg/roachprod/vm/aws/BUILD.bazel b/pkg/roachprod/vm/aws/BUILD.bazel index 9b2d2772711a..fb248dabff68 100644 --- a/pkg/roachprod/vm/aws/BUILD.bazel +++ b/pkg/roachprod/vm/aws/BUILD.bazel @@ -13,7 +13,6 @@ go_library( importpath = "github.com/cockroachdb/cockroach/pkg/roachprod/vm/aws", visibility = ["//visibility:public"], deps = [ - "//pkg/roachprod/config", "//pkg/roachprod/logger", "//pkg/roachprod/vm", "//pkg/roachprod/vm/flagstub", diff --git a/pkg/roachprod/vm/aws/aws.go b/pkg/roachprod/vm/aws/aws.go index 20bfb1aec427..04434cd91e56 100644 --- a/pkg/roachprod/vm/aws/aws.go +++ b/pkg/roachprod/vm/aws/aws.go @@ -21,7 +21,6 @@ import ( "strings" "time" - "github.com/cockroachdb/cockroach/pkg/roachprod/config" "github.com/cockroachdb/cockroach/pkg/roachprod/logger" "github.com/cockroachdb/cockroach/pkg/roachprod/vm" "github.com/cockroachdb/cockroach/pkg/roachprod/vm/flagstub" @@ -984,8 +983,6 @@ func (p *Provider) 
listRegion( VPC: in.VpcID, MachineType: in.InstanceType, Zone: in.Placement.AvailabilityZone, - SQLPort: config.DefaultSQLPort, - AdminUIPort: config.DefaultAdminUIPort, NonBootAttachedVolumes: nonBootableVolumes, } ret = append(ret, m) diff --git a/pkg/roachprod/vm/azure/BUILD.bazel b/pkg/roachprod/vm/azure/BUILD.bazel index 04ecf03c3054..7092dd5da83a 100644 --- a/pkg/roachprod/vm/azure/BUILD.bazel +++ b/pkg/roachprod/vm/azure/BUILD.bazel @@ -13,7 +13,6 @@ go_library( importpath = "github.com/cockroachdb/cockroach/pkg/roachprod/vm/azure", visibility = ["//visibility:public"], deps = [ - "//pkg/roachprod/config", "//pkg/roachprod/logger", "//pkg/roachprod/vm", "//pkg/roachprod/vm/flagstub", diff --git a/pkg/roachprod/vm/azure/azure.go b/pkg/roachprod/vm/azure/azure.go index 339405443295..b4e0d1b56f0e 100644 --- a/pkg/roachprod/vm/azure/azure.go +++ b/pkg/roachprod/vm/azure/azure.go @@ -27,7 +27,6 @@ import ( "github.com/Azure/azure-sdk-for-go/profiles/latest/resources/mgmt/subscriptions" "github.com/Azure/go-autorest/autorest" "github.com/Azure/go-autorest/autorest/to" - "github.com/cockroachdb/cockroach/pkg/roachprod/config" "github.com/cockroachdb/cockroach/pkg/roachprod/logger" "github.com/cockroachdb/cockroach/pkg/roachprod/vm" "github.com/cockroachdb/cockroach/pkg/roachprod/vm/flagstub" @@ -494,9 +493,7 @@ func (p *Provider) List(l *logger.Logger, opts vm.ListOptions) (vm.List, error) MachineType: string(found.HardwareProfile.VMSize), // We add a fake availability-zone suffix since other roachprod // code assumes particular formats. For example, "eastus2z". 
// DNSProvider is an optional capability for a Provider that provides DNS
// management services.
type DNSProvider interface {
	// CreateRecords creates the given DNS records.
	CreateRecords(records ...DNSRecord) error
	// LookupSRVRecords returns the SRV records registered for the given
	// service and protocol under the given subdomain of the provider's
	// domain.
	LookupSRVRecords(service, proto, subdomain string) ([]DNSRecord, error)
	// DeleteRecordsBySubdomain deletes the records registered under the given
	// subdomain of the provider's domain.
	DeleteRecordsBySubdomain(subdomain string) error
	// Domain returns the domain under which the provider's records are
	// registered.
	Domain() string
}
This function is lenient and skips VMs that do not +// have a DNS provider or if the provider is not a DNSProvider. +func FanOutDNS(list List, action func(DNSProvider, List) error) error { + var m = map[string]List{} + for _, vm := range list { + // We allow DNSProvider to be empty, in which case we don't do anything. + if vm.DNSProvider == "" { + continue + } + m[vm.DNSProvider] = append(m[vm.DNSProvider], vm) + } + + var g errgroup.Group + for name, vms := range m { + // capture loop variables + n := name + v := vms + g.Go(func() error { + p, ok := Providers[n] + if !ok { + return errors.Errorf("unknown provider name: %s", n) + } + dnsProvider, ok := p.(DNSProvider) + if !ok { + return errors.Errorf("provider %s is not a DNS provider", n) + } + return action(dnsProvider, v) + }) + } + + return g.Wait() +} + +// ForDNSProvider resolves the DNSProvider with the given name and executes the +// action. +func ForDNSProvider(named string, action func(DNSProvider) error) error { + if named == "" { + return errors.New("no DNS provider specified") + } + p, ok := Providers[named] + if !ok { + return errors.Errorf("unknown vm provider: %s", named) + } + dnsProvider, ok := p.(DNSProvider) + if !ok { + return errors.Errorf("provider %s is not a DNS provider", named) + } + if err := action(dnsProvider); err != nil { + return errors.Wrapf(err, "in provider: %s", named) + } + return nil +} + +// CreateDNSRecord creates a new DNS record. +func CreateDNSRecord(name string, dnsType DNSType, data string, ttl int) DNSRecord { + return DNSRecord{ + Name: name, + Type: dnsType, + Data: data, + TTL: ttl, + } +} + +// CreateSRVRecord creates a new SRV DNS record. +func CreateSRVRecord(name string, data net.SRV) DNSRecord { + dataStr := fmt.Sprintf("%d %d %d %s", data.Priority, data.Weight, data.Port, data.Target) + return CreateDNSRecord(name, SRV, dataStr, DNSRecordTTL) +} + +// ParseSRVRecord parses the data field in a DNS record. 
An SRV data struct is +// returned if the DNS record is an SRV record, otherwise an error is returned. +func (record DNSRecord) ParseSRVRecord() (*net.SRV, error) { + if record.Type != SRV { + return nil, fmt.Errorf("record is not an SRV record") + } + matches := srvRe.FindStringSubmatch(record.Data) + if len(matches) != 5 { + return nil, fmt.Errorf("invalid SRV record: %s", record.Data) + } + data := &net.SRV{} + data.Target = matches[4] + for i, field := range []*uint16{&data.Priority, &data.Weight, &data.Port} { + v, err := strconv.Atoi(matches[i+1]) + *field = uint16(v) + if err != nil { + return nil, err + } + } + return data, nil +} diff --git a/pkg/roachprod/vm/gce/gcloud.go b/pkg/roachprod/vm/gce/gcloud.go index ea2dea4a1c2f..570530a23113 100644 --- a/pkg/roachprod/vm/gce/gcloud.go +++ b/pkg/roachprod/vm/gce/gcloud.go @@ -238,8 +238,6 @@ func (jsonVM *jsonVM) toVM( MachineType: machineType, Zone: zone, Project: project, - SQLPort: config.DefaultSQLPort, - AdminUIPort: config.DefaultAdminUIPort, NonBootAttachedVolumes: volumes, LocalDisks: localDisks, } diff --git a/pkg/roachprod/vm/local/BUILD.bazel b/pkg/roachprod/vm/local/BUILD.bazel index ebe7e0b3e9cf..f875aae17348 100644 --- a/pkg/roachprod/vm/local/BUILD.bazel +++ b/pkg/roachprod/vm/local/BUILD.bazel @@ -1,18 +1,34 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "local", - srcs = ["local.go"], + srcs = [ + "dns.go", + "local.go", + ], importpath = "github.com/cockroachdb/cockroach/pkg/roachprod/vm/local", visibility = ["//visibility:public"], deps = [ "//pkg/roachprod/cloud", "//pkg/roachprod/config", + "//pkg/roachprod/lock", "//pkg/roachprod/logger", "//pkg/roachprod/vm", - "//pkg/util/intsets", "//pkg/util/timeutil", "@com_github_cockroachdb_errors//:errors", + "@com_github_cockroachdb_errors//oserror", "@com_github_spf13_pflag//:pflag", + "@org_golang_x_exp//maps", + ], +) + +go_test( + name = "local_test", 
+    srcs = ["dns_test.go"],
+    args = ["-test.timeout=295s"],
+    embed = [":local"],
+    deps = [
+        "//pkg/roachprod/vm",
+        "@com_github_stretchr_testify//require",
     ],
 )
diff --git a/pkg/roachprod/vm/local/dns.go b/pkg/roachprod/vm/local/dns.go
new file mode 100644
index 000000000000..0065e63e53ca
--- /dev/null
+++ b/pkg/roachprod/vm/local/dns.go
@@ -0,0 +1,182 @@
+// Copyright 2023 The Cockroach Authors.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+package local
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"os"
+	"path"
+	"regexp"
+
+	"github.com/cockroachdb/cockroach/pkg/roachprod/lock"
+	"github.com/cockroachdb/cockroach/pkg/roachprod/vm"
+	"github.com/cockroachdb/errors"
+	"github.com/cockroachdb/errors/oserror"
+	"golang.org/x/exp/maps"
+)
+
+var _ vm.DNSProvider = &dnsProvider{}
+
+// dnsProvider implements the vm.DNSProvider interface. Records are stored as
+// JSON in a file named <zone>.json inside configDir; writers hold the
+// filesystem lock at lockFilePath while they read, modify and save that file.
+type dnsProvider struct {
+	// configDir is the directory holding the records file; environment
+	// variables in it are expanded when the file paths are built.
+	configDir string
+	// lockFilePath is the lock file acquired by CreateRecords and
+	// DeleteRecordsBySubdomain.
+	lockFilePath string
+	// zone is the DNS domain served by this provider (see Domain).
+	zone string
+}
+
+// NewDNSProvider returns a vm.DNSProvider that stores the records of the given
+// zone in configDir. The lock file lives in the same (environment-expanded)
+// directory as the records file.
+func NewDNSProvider(configDir, zone string) vm.DNSProvider {
+	return &dnsProvider{
+		configDir:    configDir,
+		lockFilePath: path.Join(os.ExpandEnv(configDir), "DNS_LOCK"),
+		zone:         zone,
+	}
+}
+
+// Domain is part of the vm.DNSProvider interface.
+func (n *dnsProvider) Domain() string {
+	return n.zone
+}
+
+// CreateRecords is part of the vm.DNSProvider interface.
+func (n *dnsProvider) CreateRecords(records ...vm.DNSRecord) error {
+	unlock, err := lock.AcquireFilesystemLock(n.lockFilePath)
+	if err != nil {
+		return err
+	}
+	defer unlock()
+
+	entries, err := n.loadRecords()
+	if err != nil {
+		return err
+	}
+	for _, record := range records {
+		key := dnsKey(record)
+		entries[key] = record
+	}
+	return n.saveRecords(entries)
+}
+
+// LookupSRVRecords is part of the vm.DNSProvider interface.
+func (n *dnsProvider) LookupSRVRecords(service, proto, subdomain string) ([]vm.DNSRecord, error) {
+	records, err := n.loadRecords()
+	if err != nil {
+		return nil, err
+	}
+	name := fmt.Sprintf("_%s._%s.%s.%s", service, proto, subdomain, n.Domain())
+	var matchingRecords []vm.DNSRecord
+	for _, record := range records {
+		if record.Name == name && record.Type == vm.SRV {
+			matchingRecords = append(matchingRecords, record)
+		}
+	}
+	return matchingRecords, nil
+}
+
+// DeleteRecordsBySubdomain is part of the vm.DNSProvider interface.
+func (n *dnsProvider) DeleteRecordsBySubdomain(subdomain string) error {
+	unlock, err := lock.AcquireFilesystemLock(n.lockFilePath)
+	if err != nil {
+		return err
+	}
+	defer unlock()
+
+	// Quote the subdomain and domain so that they are matched literally: the
+	// dots in a domain must not act as wildcards, and other metacharacters must
+	// not change (or break) the pattern.
+	re := regexp.MustCompile(fmt.Sprintf(`.*\.%s\.%s$`,
+		regexp.QuoteMeta(subdomain), regexp.QuoteMeta(n.Domain())))
+	entries, err := n.loadRecords()
+	if err != nil {
+		return err
+	}
+	for key, record := range entries {
+		if re.MatchString(record.Name) {
+			delete(entries, key)
+		}
+	}
+	return n.saveRecords(entries)
+}
+
+// saveRecords saves the given records to a local DNS cache file.
+func (n *dnsProvider) saveRecords(recordEntries map[string]vm.DNSRecord) error {
+	var b bytes.Buffer
+	enc := json.NewEncoder(&b)
+	enc.SetIndent("", " ")
+	records := maps.Values(recordEntries)
+	if err := enc.Encode(&records); err != nil {
+		return err
+	}
+
+	// Other roachprod processes might be accessing the cache files at the same
+	// time, so we need to write the file atomically by writing to a temporary
+	// file and renaming. We store the temporary file in the same directory so
+	// that it can always be renamed.
+	tmpFile, err := os.CreateTemp(os.ExpandEnv(n.configDir), n.zone+".tmp")
+	if err != nil {
+		return err
+	}
+
+	_, err = tmpFile.Write(b.Bytes())
+	err = errors.CombineErrors(err, tmpFile.Sync())
+	err = errors.CombineErrors(err, tmpFile.Close())
+	if err == nil {
+		err = os.Rename(tmpFile.Name(), n.dnsFileName())
+	}
+	if err != nil {
+		_ = os.Remove(tmpFile.Name())
+		return err
+	}
+	return nil
+}
+
+// loadRecords loads the DNS records from the local DNS cache file.
+func (n *dnsProvider) loadRecords() (map[string]vm.DNSRecord, error) {
+	data, err := os.ReadFile(n.dnsFileName())
+	recordEntries := make(map[string]vm.DNSRecord, 0)
+	if err != nil {
+		// It is expected that the file might not exist yet if no records have been
+		// created before. In this case, return an empty map.
+		if oserror.IsNotExist(err) {
+			return recordEntries, nil
+		}
+		return nil, err
+	}
+	records := make([]vm.DNSRecord, 0)
+	if err := json.Unmarshal(data, &records); err != nil {
+		return nil, err
+	}
+	for _, record := range records {
+		recordEntries[dnsKey(record)] = record
+	}
+	return recordEntries, nil
+}
+
+// dnsFileName returns the name of the local file storing DNS records.
+func (n *dnsProvider) dnsFileName() string {
+	return path.Join(os.ExpandEnv(n.configDir), n.zone+".json")
+}
+
+// dnsKey returns a unique key for the given DNS record.
+func dnsKey(record vm.DNSRecord) string {
+	return fmt.Sprintf("%s:%s:%s", record.Name, record.Type, record.Data)
+}
diff --git a/pkg/roachprod/vm/local/dns_test.go b/pkg/roachprod/vm/local/dns_test.go
new file mode 100644
index 000000000000..9ef59a4c1617
--- /dev/null
+++ b/pkg/roachprod/vm/local/dns_test.go
@@ -0,0 +1,81 @@
+// Copyright 2023 The Cockroach Authors.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+package local
+
+import (
+	"strings"
+	"testing"
+
+	"github.com/cockroachdb/cockroach/pkg/roachprod/vm"
+	"github.com/stretchr/testify/require"
+)
+
+// dnsTestRec is a {record name, SRV record data} pair describing a test record.
+type dnsTestRec [2]string
+
+// createTestDNSRecords converts the given test records into SRV DNS records
+// with a TTL of 60.
+func createTestDNSRecords(testRecords ...dnsTestRec) []vm.DNSRecord {
+	var dnsRecords []vm.DNSRecord
+	for _, v := range testRecords {
+		dnsRecords = append(dnsRecords, vm.DNSRecord{
+			Name: v[0],
+			Type: vm.SRV,
+			Data: v[1],
+			TTL:  60,
+		})
+	}
+	return dnsRecords
+}
+
+// createTestDNSProvider returns a DNS provider that stores its records in a
+// fresh temporary directory and is pre-populated with the given test records.
+func createTestDNSProvider(t *testing.T, testRecords ...dnsTestRec) vm.DNSProvider {
+	p := NewDNSProvider(t.TempDir(), "local-zone")
+	err := p.CreateRecords(createTestDNSRecords(testRecords...)...)
+	require.NoError(t, err)
+	return p
+}
+
+// TestLookupRecords checks that SRV records are looked up by service name and
+// that the data of a looked-up record parses into its priority, weight and port.
+func TestLookupRecords(t *testing.T) {
+	p := createTestDNSProvider(t, []dnsTestRec{
+		{"_system-sql._tcp.local.local-zone", "0 1000 29001 local-0001.local-zone"},
+		{"_system-sql._tcp.local.local-zone", "0 1000 29002 local-0002.local-zone"},
+		{"_system-sql._tcp.local.local-zone", "0 1000 29003 local-0003.local-zone"},
+		{"_tenant-1-sql._tcp.local.local-zone", "5 50 29004 local-0001.local-zone"},
+		{"_tenant-2-sql._tcp.local.local-zone", "5 50 29005 local-0002.local-zone"},
+		{"_tenant-3-sql._tcp.local.local-zone", "5 50 29006 local-0003.local-zone"},
+	}...)
+
+	t.Run("lookup system", func(t *testing.T) {
+		records, err := p.LookupSRVRecords("system-sql", "tcp", "local")
+		require.NoError(t, err)
+		require.Equal(t, 3, len(records))
+		for _, r := range records {
+			require.True(t, strings.HasPrefix(r.Name, "_system-sql"))
+			require.Equal(t, vm.SRV, r.Type)
+		}
+	})
+
+	t.Run("parse SRV data", func(t *testing.T) {
+		records, err := p.LookupSRVRecords("tenant-1-sql", "tcp", "local")
+		require.NoError(t, err)
+		require.Equal(t, 1, len(records))
+		data, err := records[0].ParseSRVRecord()
+		require.NoError(t, err)
+		require.Equal(t, uint16(5), data.Priority)
+		require.Equal(t, uint16(50), data.Weight)
+		require.Equal(t, uint16(29004), data.Port)
+	})
+
+}
diff --git a/pkg/roachprod/vm/local/local.go b/pkg/roachprod/vm/local/local.go
index a006a0c5fb03..78959a8f0e38 100644
--- a/pkg/roachprod/vm/local/local.go
+++ b/pkg/roachprod/vm/local/local.go
@@ -21,7 +21,6 @@ import (
 	"github.com/cockroachdb/cockroach/pkg/roachprod/config"
 	"github.com/cockroachdb/cockroach/pkg/roachprod/logger"
 	"github.com/cockroachdb/cockroach/pkg/roachprod/vm"
-	"github.com/cockroachdb/cockroach/pkg/util/intsets"
 	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
 	"github.com/cockroachdb/errors"
 	"github.com/spf13/pflag"
@@ -58,8 +57,9 @@ func VMDir(clusterName string, nodeIdx int) string {
 // Init initializes the Local provider and registers it into vm.Providers.
 func Init(storage VMStorage) error {
 	vm.Providers[ProviderName] = &Provider{
-		clusters: make(cloud.Clusters),
-		storage:  storage,
+		clusters:    make(cloud.Clusters),
+		storage:     storage,
+		DNSProvider: NewDNSProvider(config.DNSDir, "local-zone"),
 	}
 	return nil
 }
@@ -93,7 +93,8 @@ func DeleteCluster(l *logger.Logger, name string) error {
 	}
 
 	delete(p.clusters, name)
-	return nil
+
+	return p.DeleteRecordsBySubdomain(c.Name)
 }
 
 // Clusters returns a list of all known local clusters.
@@ -116,8 +117,8 @@ type VMStorage interface {
 // A Provider is used to create stub VM objects.
type Provider struct {
 	clusters cloud.Clusters
-
-	storage VMStorage
+	storage  VMStorage
+	vm.DNSProvider // embedded so that Provider itself offers the DNS provider methods
 }
 
 func (p *Provider) CreateVolumeSnapshot(
@@ -197,30 +198,6 @@ func (p *Provider) Create(
 		return errors.Errorf("'%s' is not a valid local cluster name", c.Name)
 	}
 
-	// We will need to assign ports to the nodes, and they must not conflict with
-	// any other local clusters.
-	var portsTaken intsets.Fast
-	for _, c := range p.clusters {
-		for i := range c.VMs {
-			portsTaken.Add(c.VMs[i].SQLPort)
-			portsTaken.Add(c.VMs[i].AdminUIPort)
-		}
-	}
-	sqlPort := config.DefaultSQLPort
-	adminUIPort := config.DefaultAdminUIPort
-
-	// getPort returns the first available port (starting at *port), and modifies
-	// (*port) to be the following value.
-	getPort := func(port *int) int {
-		for portsTaken.Contains(*port) {
-			(*port)++
-		}
-		result := *port
-		portsTaken.Add(result)
-		(*port)++
-		return result
-	}
-
 	for i := range names {
 		c.VMs[i] = vm.VM{
 			Name:             "localhost",
@@ -228,14 +205,14 @@ func (p *Provider) Create(
 			Lifetime:         time.Hour,
 			PrivateIP:        "127.0.0.1",
 			Provider:         ProviderName,
+			DNSProvider:      ProviderName,
 			ProviderID:       ProviderName,
 			PublicIP:         "127.0.0.1",
+			PublicDNS:        "localhost",
 			RemoteUser:       config.OSUser.Username,
 			VPC:              ProviderName,
 			MachineType:      ProviderName,
 			Zone:             ProviderName,
-			SQLPort:          getPort(&sqlPort),
-			AdminUIPort:      getPort(&adminUIPort),
 			LocalClusterName: c.Name,
 		}
 		path := VMDir(c.Name, i+1)
diff --git a/pkg/roachprod/vm/vm.go b/pkg/roachprod/vm/vm.go
index be04d5ebed2e..e2940048e362 100644
--- a/pkg/roachprod/vm/vm.go
+++ b/pkg/roachprod/vm/vm.go
@@ -79,6 +79,12 @@ type VM struct {
 	Labels map[string]string `json:"labels"`
 	// The provider-internal DNS name for the VM instance
 	DNS string `json:"dns"`
+
+	// PublicDNS is the public DNS name that can be used to connect to the VM.
+	PublicDNS string `json:"public_dns"`
+	// DNSProvider is the name of the DNS provider used for DNS operations on this VM.
+	DNSProvider string `json:"dns_provider"`
+
 	// The name of the cloud provider that hosts the VM instance
 	Provider string `json:"provider"`
 	// The provider-specific id for the instance. This may or may not be the same as Name, depending
@@ -99,16 +105,6 @@ type VM struct {
 	// cloud that supports project (i.e. GCE). Empty otherwise.
 	Project string `json:"project"`
 
-	// SQLPort is the port on which the cockroach process is listening for SQL
-	// connections.
-	// Usually config.DefaultSQLPort, except for local clusters.
-	SQLPort int `json:"sql_port"`
-
-	// AdminUIPort is the port on which the cockroach process is listening for
-	// HTTP traffic for the Admin UI.
-	// Usually config.DefaultAdminUIPort, except for local clusters.
-	AdminUIPort int `json:"adminui_port"`
-
 	// LocalClusterName is only set for VMs in a local cluster.
 	LocalClusterName string `json:"local_cluster_name,omitempty"`