From 65c30a386376b0f8a6c10a77722cba9c399b57b3 Mon Sep 17 00:00:00 2001
From: Darin Peshev
Date: Sat, 17 Apr 2021 14:46:30 -0700
Subject: [PATCH] ccl/sqlproxyccl: directory proto and test server

* Defines a new interface between a tenant directory client and server
* Moves the tenant directory over from the CC repo
* Modifies the tenant directory to use the new interface
* Modifies the tenant directory to use stop.Stopper
* Modifies the pod watcher to use a streaming gRPC call
* Renames the package from directory to tenant, tenantID to ID, etc.
* Renames references to k8s and pods to server and endpoints, etc.
* Prevents test tenant servers from starting for non-existing/inactive tenants
* Adds the ability to shut down individual tenant servers within a test cluster
* Adds a test directory server that can start/stop tenants
* Adds tests of the directory running against the test server
* Allows insecure connections from the test tenant to the KV server
* Fixes a race in kvtenantccl

Release note: None
---
 pkg/BUILD.bazel                               |    1 +
 pkg/base/test_server_args.go                  |    4 +
 pkg/ccl/backupccl/backup_test.go              |    2 +-
 pkg/ccl/kvccl/kvtenantccl/connector.go        |   24 +-
 pkg/ccl/serverccl/server_sql_test.go          |   43 +-
 pkg/ccl/sqlproxyccl/tenant/BUILD.bazel        |   75 +
 pkg/ccl/sqlproxyccl/tenant/directory.go       |  363 ++++
 pkg/ccl/sqlproxyccl/tenant/directory.pb.go    | 1815 +++++++++++++++++
 pkg/ccl/sqlproxyccl/tenant/directory.proto    |   92 +
 pkg/ccl/sqlproxyccl/tenant/directory_test.go  |  389 ++++
 pkg/ccl/sqlproxyccl/tenant/entry.go           |  298 +++
 pkg/ccl/sqlproxyccl/tenant/main_test.go       |   34 +
 pkg/ccl/sqlproxyccl/tenant/mocks_generated.go |  240 +++
 .../sqlproxyccl/tenant/test_directory_svr.go  |  352 ++++
 pkg/server/testserver.go                      |   17 +-
 pkg/sql/logictest/logic.go                    |    2 +-
 pkg/testutils/lint/lint_test.go               |    1 +
 pkg/testutils/serverutils/test_server_shim.go |    4 +-
 18 files changed, 3719 insertions(+), 37 deletions(-)
 create mode 100644 pkg/ccl/sqlproxyccl/tenant/BUILD.bazel
 create mode 100644 pkg/ccl/sqlproxyccl/tenant/directory.go
 create mode 100644 pkg/ccl/sqlproxyccl/tenant/directory.pb.go
 create mode 100644 pkg/ccl/sqlproxyccl/tenant/directory.proto
 create mode 100644 pkg/ccl/sqlproxyccl/tenant/directory_test.go
 create mode 100644 pkg/ccl/sqlproxyccl/tenant/entry.go
 create mode 100644 pkg/ccl/sqlproxyccl/tenant/main_test.go
 create mode 100644 pkg/ccl/sqlproxyccl/tenant/mocks_generated.go
 create mode 100644 pkg/ccl/sqlproxyccl/tenant/test_directory_svr.go

diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel
index adb0a1c60619..584214f22f53 100644
--- a/pkg/BUILD.bazel
+++ b/pkg/BUILD.bazel
@@ -24,6 +24,7 @@ ALL_TESTS = [
     "//pkg/ccl/partitionccl:partitionccl_test",
     "//pkg/ccl/serverccl:serverccl_test",
     "//pkg/ccl/sqlproxyccl/cache:cache_test",
+    "//pkg/ccl/sqlproxyccl/tenant:tenant_test",
     "//pkg/ccl/sqlproxyccl:sqlproxyccl_test",
     "//pkg/ccl/storageccl/engineccl:engineccl_test",
     "//pkg/ccl/storageccl:storageccl_test",

diff --git a/pkg/base/test_server_args.go b/pkg/base/test_server_args.go
index cb39038fb73d..d731317bdad3 100644
--- a/pkg/base/test_server_args.go
+++ b/pkg/base/test_server_args.go
@@ -239,4 +239,8 @@ type TestTenantArgs struct {

 	// TestingKnobs for the test server.
 	TestingKnobs TestingKnobs
+
+	// The test server starts in secure mode by default. When this is set to
+	// true, it will switch to insecure mode instead.
+	ForceInsecure bool
 }
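For illustration, a minimal sketch of what this knob enables in a test: starting a tenant SQL server that skips TLS when talking to the KV layer. The helper name and cluster setup are assumptions for the sketch; only StartTenant, TestTenantArgs, and ForceInsecure come from this patch:

    // startInsecureTenant starts a test tenant that connects to the KV
    // server without TLS. tc is an already-running test cluster.
    func startInsecureTenant(
        ctx context.Context, t *testing.T, tc serverutils.TestClusterInterface,
    ) serverutils.TestTenantInterface {
        tenant, err := tc.Server(0).StartTenant(ctx, base.TestTenantArgs{
            TenantID:      roachpb.MakeTenantID(10),
            ForceInsecure: true, // the new knob added above
        })
        require.NoError(t, err)
        return tenant
    }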
diff --git a/pkg/ccl/backupccl/backup_test.go b/pkg/ccl/backupccl/backup_test.go
index c7da325735e4..a28c044716f8 100644
--- a/pkg/ccl/backupccl/backup_test.go
+++ b/pkg/ccl/backupccl/backup_test.go
@@ -3594,7 +3594,7 @@ func TestBackupTenantsWithRevisionHistory(t *testing.T) {
 	ctx, tc, sqlDB, _, cleanupFn := BackupRestoreTestSetup(t, singleNode, numAccounts, InitManualReplication)
 	defer cleanupFn()

-	_, err := tc.Servers[0].StartTenant(base.TestTenantArgs{TenantID: roachpb.MakeTenantID(10)})
+	_, err := tc.Servers[0].StartTenant(ctx, base.TestTenantArgs{TenantID: roachpb.MakeTenantID(10)})
 	require.NoError(t, err)

 	const msg = "can not backup tenants with revision history"

diff --git a/pkg/ccl/kvccl/kvtenantccl/connector.go b/pkg/ccl/kvccl/kvtenantccl/connector.go
index db6876362fb7..51381ee7071e 100644
--- a/pkg/ccl/kvccl/kvtenantccl/connector.go
+++ b/pkg/ccl/kvccl/kvtenantccl/connector.go
@@ -365,12 +365,20 @@ func (c *Connector) getClient(ctx context.Context) (roachpb.InternalClient, erro
 		dialCtx := c.AnnotateCtx(context.Background())
 		dialCtx, cancel := c.rpcContext.Stopper.WithCancelOnQuiesce(dialCtx)
 		defer cancel()
-		err := c.rpcContext.Stopper.RunTaskWithErr(dialCtx, "kvtenant.Connector: dial", c.dialAddrs)
+		var client roachpb.InternalClient
+		err := c.rpcContext.Stopper.RunTaskWithErr(dialCtx, "kvtenant.Connector: dial",
+			func(ctx context.Context) error {
+				var err error
+				client, err = c.dialAddrs(ctx)
+				return err
+			})
 		if err != nil {
 			return nil, err
 		}
-		// NB: read lock not needed.
-		return c.mu.client, nil
+		c.mu.Lock()
+		defer c.mu.Unlock()
+		c.mu.client = client
+		return client, nil
 	})
 	c.mu.RUnlock()
@@ -387,7 +395,7 @@

 // dialAddrs attempts to dial each of the configured addresses in a retry loop.
 // The method will only return a non-nil error on context cancellation.
-func (c *Connector) dialAddrs(ctx context.Context) error {
+func (c *Connector) dialAddrs(ctx context.Context) (roachpb.InternalClient, error) {
 	for r := retry.StartWithCtx(ctx, c.rpcRetryOptions); r.Next(); {
 		// Try each address on each retry iteration.
 		randStart := rand.Intn(len(c.addrs))
@@ -398,14 +406,10 @@
 			log.Warningf(ctx, "error dialing tenant KV address %s: %v", addr, err)
 			continue
 		}
-		client := roachpb.NewInternalClient(conn)
-		c.mu.Lock()
-		c.mu.client = client
-		c.mu.Unlock()
-		return nil
+		return roachpb.NewInternalClient(conn), nil
 		}
 	}
-	return ctx.Err()
+	return nil, ctx.Err()
 }

 func (c *Connector) dialAddr(ctx context.Context, addr string) (conn *grpc.ClientConn, err error) {

diff --git a/pkg/ccl/serverccl/server_sql_test.go b/pkg/ccl/serverccl/server_sql_test.go
index f66df08503ac..9003b0f6c7a6 100644
--- a/pkg/ccl/serverccl/server_sql_test.go
+++ b/pkg/ccl/serverccl/server_sql_test.go
@@ -93,15 +93,16 @@ func TestTenantUnauthenticatedAccess(t *testing.T) {
 	tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{})
 	defer tc.Stopper().Stop(ctx)

-	_, err := tc.Server(0).StartTenant(base.TestTenantArgs{
-		TenantID: roachpb.MakeTenantID(security.EmbeddedTenantIDs()[0]),
-		TestingKnobs: base.TestingKnobs{
-			TenantTestingKnobs: &sql.TenantTestingKnobs{
-				// Configure the SQL server to access the wrong tenant keyspace.
- TenantIDCodecOverride: roachpb.MakeTenantID(security.EmbeddedTenantIDs()[1]), + _, err := tc.Server(0).StartTenant(ctx, + base.TestTenantArgs{ + TenantID: roachpb.MakeTenantID(security.EmbeddedTenantIDs()[0]), + TestingKnobs: base.TestingKnobs{ + TenantTestingKnobs: &sql.TenantTestingKnobs{ + // Configure the SQL server to access the wrong tenant keyspace. + TenantIDCodecOverride: roachpb.MakeTenantID(security.EmbeddedTenantIDs()[1]), + }, }, - }, - }) + }) require.Error(t, err) require.Regexp(t, `Unauthenticated desc = requested key .* not fully contained in tenant keyspace /Tenant/1{0-1}`, err) } @@ -115,9 +116,10 @@ func TestTenantHTTP(t *testing.T) { tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) defer tc.Stopper().Stop(ctx) - tenant, err := tc.Server(0).StartTenant(base.TestTenantArgs{ - TenantID: roachpb.MakeTenantID(security.EmbeddedTenantIDs()[0]), - }) + tenant, err := tc.Server(0).StartTenant(ctx, + base.TestTenantArgs{ + TenantID: roachpb.MakeTenantID(security.EmbeddedTenantIDs()[0]), + }) require.NoError(t, err) t.Run("prometheus", func(t *testing.T) { resp, err := httputil.Get(ctx, "http://"+tenant.HTTPAddr()+"/_status/vars") @@ -150,16 +152,17 @@ func TestIdleExit(t *testing.T) { warmupDuration := 500 * time.Millisecond countdownDuration := 4000 * time.Millisecond - tenant, err := tc.Server(0).StartTenant(base.TestTenantArgs{ - TenantID: roachpb.MakeTenantID(10), - IdleExitAfter: warmupDuration, - TestingKnobs: base.TestingKnobs{ - TenantTestingKnobs: &sql.TenantTestingKnobs{ - IdleExitCountdownDuration: countdownDuration, + tenant, err := tc.Server(0).StartTenant(ctx, + base.TestTenantArgs{ + TenantID: roachpb.MakeTenantID(10), + IdleExitAfter: warmupDuration, + TestingKnobs: base.TestingKnobs{ + TenantTestingKnobs: &sql.TenantTestingKnobs{ + IdleExitCountdownDuration: countdownDuration, + }, }, - }, - Stopper: tc.Stopper(), - }) + Stopper: tc.Stopper(), + }) require.NoError(t, err) diff --git a/pkg/ccl/sqlproxyccl/tenant/BUILD.bazel b/pkg/ccl/sqlproxyccl/tenant/BUILD.bazel new file mode 100644 index 000000000000..1b20e21f2350 --- /dev/null +++ b/pkg/ccl/sqlproxyccl/tenant/BUILD.bazel @@ -0,0 +1,75 @@ +load("@rules_proto//proto:defs.bzl", "proto_library") +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") + +proto_library( + name = "tenant_proto", + srcs = ["directory.proto"], + strip_import_prefix = "/pkg", + visibility = ["//visibility:public"], + deps = ["@com_github_gogo_protobuf//gogoproto:gogo_proto"], +) + +go_proto_library( + name = "tenant_go_proto", + compilers = ["//pkg/cmd/protoc-gen-gogoroach:protoc-gen-gogoroach_grpc_compiler"], + importpath = "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/tenant", + proto = ":tenant_proto", + visibility = ["//visibility:public"], + deps = ["@com_github_gogo_protobuf//gogoproto"], +) + +go_library( + name = "tenant", + srcs = [ + "directory.go", + "entry.go", + "mocks_generated.go", + "test_directory_svr.go", + ], + embed = [":tenant_go_proto"], + importpath = "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/tenant", + visibility = ["//visibility:public"], + deps = [ + "//pkg/roachpb", + "//pkg/util/grpcutil", + "//pkg/util/log", + "//pkg/util/stop", + "//pkg/util/syncutil", + "//pkg/util/timeutil", + "@com_github_cockroachdb_errors//:errors", + "@com_github_cockroachdb_logtags//:logtags", + "@com_github_golang_mock//gomock", + "@org_golang_google_grpc//:go_default_library", + "@org_golang_google_grpc//metadata", + 
"@org_golang_google_grpc//status", + ], +) + +go_test( + name = "tenant_test", + srcs = [ + "directory_test.go", + "main_test.go", + ], + embed = [":tenant"], + deps = [ + "//pkg/base", + "//pkg/ccl", + "//pkg/ccl/utilccl", + "//pkg/roachpb", + "//pkg/security", + "//pkg/security/securitytest", + "//pkg/server", + "//pkg/sql", + "//pkg/testutils/serverutils", + "//pkg/testutils/testcluster", + "//pkg/util/leaktest", + "//pkg/util/log", + "//pkg/util/randutil", + "//pkg/util/stop", + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", + "@org_golang_google_grpc//:go_default_library", + ], +) diff --git a/pkg/ccl/sqlproxyccl/tenant/directory.go b/pkg/ccl/sqlproxyccl/tenant/directory.go new file mode 100644 index 000000000000..efc46470a262 --- /dev/null +++ b/pkg/ccl/sqlproxyccl/tenant/directory.go @@ -0,0 +1,363 @@ +// Copyright 2021 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package tenant + +import ( + "context" + "io" + "sync" + "time" + + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/util/grpcutil" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/errors" + "google.golang.org/grpc/status" +) + +//go:generate mockgen -package=tenant -destination=mocks_generated.go . DirectoryClient,Directory_WatchEndpointsClient + +// dirOptions control the behavior of tenant.Directory. +type dirOptions struct { + deterministic bool + refreshDelay time.Duration +} + +// DirOption defines an option that can be passed to tenant.Directory in order +// to control its behavior. +type DirOption func(opts *dirOptions) + +// RefreshDelay specifies the minimum amount of time that must elapse between +// attempts to refresh endpoints for a given tenant after ReportFailure is called. +// This delay has the effect of throttling calls to directory server, in order to +// avoid overloading it. +// +// RefreshDelay defaults to 100ms. Use -1 to never throttle. +func RefreshDelay(delay time.Duration) func(opts *dirOptions) { + return func(opts *dirOptions) { + opts.refreshDelay = delay + } +} + +// Directory tracks the network locations of SQL tenant processes. It is used by the +// sqlproxy to route incoming traffic to the correct backend process. Process +// information is populated and kept relatively up-to-date using a streaming watcher. +// However, since watchers deliver slightly stale information, the directory will also make +// direct server calls to fetch the latest information about a process that is not yet +// in the cache, or when a process is suspected to have failed. When a new tenant is +// created, or is resumed from suspension, this capability allows the directory +// to immediately return the IP address for the new process. +// +// All methods in the directory are thread-safe. Methods are intended to be +// called concurrently by many threads at once, and so locking is carefully +// designed to minimize contention. 
+// While a lock shared across tenants is used to synchronize access to shared
+// in-memory data structures, each tenant also has its own locks that are used
+// to synchronize per-tenant operations such as making directory server calls
+// to fetch updated tenant information.
+type Directory struct {
+	// client is the directory client instance used to make directory server
+	// calls.
+	client DirectoryClient
+
+	// stopper is used for graceful shutdown of the endpoint watcher.
+	stopper *stop.Stopper
+
+	// options control how the directory operates.
+	options dirOptions
+
+	// mut synchronizes access to the in-memory tenant entry caches. Take care
+	// to never hold this lock during directory server calls - it should only
+	// be used while adding and removing tenant entries to/from the caches.
+	mut struct {
+		syncutil.Mutex
+
+		// tenants is a cache of tenant entries. Each entry tracks available
+		// IP addresses for SQL processes for a given tenant. Entries may not
+		// be fully initialized.
+		tenants map[roachpb.TenantID]*tenantEntry
+	}
+}
+
+// NewDirectory constructs a new Directory instance that tracks SQL tenant
+// processes managed by a given Directory server. The given context is used
+// for tracing endpoint watcher activity.
+//
+// NOTE: stopper.Stop must be called on the directory when it is no longer
+// needed.
+func NewDirectory(
+	ctx context.Context, stopper *stop.Stopper, client DirectoryClient, opts ...DirOption,
+) (*Directory, error) {
+	dir := &Directory{client: client, stopper: stopper}
+
+	dir.mut.tenants = make(map[roachpb.TenantID]*tenantEntry)
+	for _, opt := range opts {
+		opt(&dir.options)
+	}
+	if dir.options.refreshDelay == 0 {
+		// Default to a delay of 100ms between refresh attempts for a given
+		// tenant.
+		dir.options.refreshDelay = 100 * time.Millisecond
+	}
+
+	// Start the endpoint watcher before returning.
+	if err := dir.watchEndpoints(ctx, stopper); err != nil {
+		return nil, err
+	}
+
+	return dir, nil
+}
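To make the intended call pattern concrete, here is a minimal sketch of how a proxy might construct and use the directory. The dial target, error handling, and stopper wiring are assumptions for illustration; NewDirectoryClient, NewDirectory, RefreshDelay, and EnsureTenantIP are the APIs introduced by this patch:

    // Dial the directory server and build a tenant.Directory on top of it.
    conn, err := grpc.Dial("127.0.0.1:36257" /* assumed address */, grpc.WithInsecure())
    if err != nil {
        return err
    }
    client := tenant.NewDirectoryClient(conn)
    dir, err := tenant.NewDirectory(ctx, stopper, client,
        // Optional: throttle per-tenant refreshes; -1 disables throttling.
        tenant.RefreshDelay(250*time.Millisecond))
    if err != nil {
        return err
    }
    // Resolve (and, if suspended, resume) a SQL process for tenant 123.
    ip, err := dir.EnsureTenantIP(ctx, roachpb.MakeTenantID(123), "" /* clusterName */)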
+// EnsureTenantIP returns the IP address of one of the given tenant's SQL
+// processes. If the tenant was just created or is suspended, such that there
+// are no available processes, then EnsureTenantIP will trigger resumption of
+// a new instance and block until the process is ready. If there are multiple
+// processes for the tenant, then EnsureTenantIP will choose one of them (note
+// that currently there is always at most one SQL process per tenant).
+//
+// If clusterName is non-empty, then an error is returned if no endpoints
+// match the cluster name. This can be used to ensure that the incoming SQL
+// connection "knows" some additional information about the tenant, such as
+// the name of the cluster, before being allowed to connect.
+func (d *Directory) EnsureTenantIP(
+	ctx context.Context, tenantID roachpb.TenantID, clusterName string,
+) (string, error) {
+	// Ensure that a directory entry has been created for this tenant.
+	entry, err := d.getEntry(ctx, tenantID, true /* allowCreate */)
+	if err != nil {
+		return "", err
+	}
+
+	// Check that the cluster name matches, if specified.
+	if clusterName != "" && clusterName != entry.ClusterName {
+		log.Errorf(ctx, "cluster name %s doesn't match expected %s", clusterName, entry.ClusterName)
+		return "", errors.New("not found")
+	}
+	ctx, _ = d.stopper.WithCancelOnQuiesce(ctx)
+	ip, err := entry.ChooseEndpointIP(ctx, d.client, d.options.deterministic)
+	if err != nil {
+		if s, ok := status.FromError(err); ok && s.Message() == "not found" {
+			d.deleteEntry(tenantID)
+		}
+	}
+	return ip, err
+}
+
+// LookupTenantIPs returns the IP addresses for all available SQL processes
+// for the given tenant. It returns an error if the tenant has not yet been
+// created. If no processes are available for the tenant, LookupTenantIPs will
+// return the empty set (unlike EnsureTenantIP).
+func (d *Directory) LookupTenantIPs(
+	ctx context.Context, tenantID roachpb.TenantID,
+) ([]string, error) {
+	// Ensure that a directory entry has been created for this tenant.
+	entry, err := d.getEntry(ctx, tenantID, false /* allowCreate */)
+	if err != nil {
+		return nil, err
+	}
+
+	if entry == nil {
+		return nil, errors.New("not found")
+	}
+	return entry.getEndpointIPs(), nil
+}
+
+// ReportFailure should be called when attempts to connect to a particular SQL
+// tenant endpoint have failed. Since this could be due to a failed process,
+// ReportFailure will attempt to refresh the cache with the latest information
+// about available tenant processes.
+// TODO(andyk): In the future, the ip parameter will be used to mark a
+// particular endpoint as "unhealthy" so that it's less likely to be chosen.
+// However, today there can be at most one endpoint for a given tenant, so it
+// must always be chosen. Keep the parameter as a placeholder for the future.
+func (d *Directory) ReportFailure(ctx context.Context, tenantID roachpb.TenantID, ip string) error {
+	entry, err := d.getEntry(ctx, tenantID, false /* allowCreate */)
+	if err != nil {
+		return err
+	} else if entry == nil {
+		// If no tenant is in the cache, no-op.
+		return nil
+	}
+
+	// Refresh the entry in case there is a new endpoint IP address.
+	return entry.RefreshEndpoints(ctx, d.client)
+}
+
+// getEntry returns a directory entry for the given tenant. If the directory
+// does not contain such an entry, then getEntry will create one if
+// allowCreate is true. Otherwise, it returns nil. If an entry is returned,
+// then getEntry ensures that it is fully initialized with tenant metadata.
+// Obtaining this metadata requires making a separate directory server call;
+// getEntry will block until that's complete.
+func (d *Directory) getEntry(
+	ctx context.Context, tenantID roachpb.TenantID, allowCreate bool,
+) (*tenantEntry, error) {
+	entry := func() *tenantEntry {
+		// Acquire the directory lock just long enough to check the tenants
+		// map for the given tenant ID. Don't complete initialization while
+		// holding this lock, since that requires directory server calls.
+		d.mut.Lock()
+		defer d.mut.Unlock()
+
+		entry, ok := d.mut.tenants[tenantID]
+		if ok {
+			// Entry exists, so return it.
+			return entry
+		}
+
+		if !allowCreate {
+			// No entry, but not allowed to create one, so done.
+			return nil
+		}
+
+		// Create the tenant entry and enter it into the tenants map.
+		log.Infof(ctx, "creating directory entry for tenant %d", tenantID)
+		entry = &tenantEntry{TenantID: tenantID, RefreshDelay: d.options.refreshDelay}
+		d.mut.tenants[tenantID] = entry
+		return entry
+	}()
+
+	if entry == nil {
+		return nil, nil
+	}
+
+	// Initialize the entry now if that has not yet been done.
+	err := entry.Initialize(ctx, d.client)
+	if err != nil {
+		// Remove the entry from the tenants map, since initialization failed.
+		d.mut.Lock()
+		defer d.mut.Unlock()
+
+		// Threads can race to add/remove entries, so ensure that the right
+		// entry is removed.
+		existing, ok := d.mut.tenants[tenantID]
+		if ok && entry == existing {
+			delete(d.mut.tenants, tenantID)
+			log.Infof(ctx, "error initializing tenant %d: %v", tenantID, err)
+		}
+
+		return nil, err
+	}
+
+	return entry, nil
+}
+
+// deleteEntry removes the directory entry for the given tenant, if it exists.
+func (d *Directory) deleteEntry(tenantID roachpb.TenantID) {
+	d.mut.Lock()
+	defer d.mut.Unlock()
+	delete(d.mut.tenants, tenantID)
+}
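The expected interplay between EnsureTenantIP and ReportFailure is a resolve-dial-report loop in the proxy. A hedged sketch under that assumption (the retry limit, helper name, and net.Dial usage are illustrative, not part of this patch):

    // connectToTenant resolves a tenant's SQL address and dials it, reporting
    // failures back to the directory so that stale IPs get refreshed.
    func connectToTenant(
        ctx context.Context, dir *tenant.Directory, id roachpb.TenantID,
    ) (net.Conn, error) {
        for attempt := 0; attempt < 3; attempt++ {
            ip, err := dir.EnsureTenantIP(ctx, id, "" /* clusterName */)
            if err != nil {
                return nil, err
            }
            conn, err := net.Dial("tcp", ip) // endpoint IPs include the port
            if err == nil {
                return conn, nil
            }
            // The process behind this IP may have died; let the directory
            // refresh its cache before the next attempt.
            _ = dir.ReportFailure(ctx, id, ip)
        }
        return nil, errors.New("failed to connect to tenant")
    }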
+
+// watchEndpoints establishes a watcher that looks for changes to tenant
+// endpoint addresses. Whenever tenant processes start or terminate, the
+// watcher will get a notification and update the directory to reflect that
+// change.
+func (d *Directory) watchEndpoints(ctx context.Context, stopper *stop.Stopper) error {
+	req := WatchEndpointsRequest{}
+
+	// The loop that processes the event stream runs in a separate goroutine.
+	// Before we return, however, we want a guarantee that the goroutine has
+	// started processing events. This wait group gives us that guarantee.
+	// Without it, the following would be possible:
+	//
+	// 1. call watchEndpoints
+	// 2. call EnsureTenantIP
+	// 3. wait forever to receive a notification about the tenant that just
+	//    started.
+	//
+	// The notification might never arrive because the watchEndpoints
+	// goroutine can start listening after the server has already started the
+	// tenant and sent the notification.
+	var waitInit sync.WaitGroup
+	waitInit.Add(1)
+
+	err := stopper.RunAsyncTask(ctx, "watch-endpoints-client", func(ctx context.Context) {
+		var client Directory_WatchEndpointsClient
+		var err error
+		firstRun := true
+		ctx, _ = stopper.WithCancelOnQuiesce(ctx)
+
+		for {
+			if client == nil {
+				client, err = d.client.WatchEndpoints(ctx, &req)
+				if firstRun {
+					waitInit.Done()
+					firstRun = false
+				}
+				if err != nil {
+					if grpcutil.IsContextCanceled(err) {
+						break
+					}
+					log.Errorf(ctx, "err creating new watch endpoint client: %s", err)
+					sleepContext(ctx, time.Second)
+					continue
+				}
+			}
+
+			// Read the next watcher event.
+			resp, err := client.Recv()
+			if err != nil {
+				if grpcutil.IsContextCanceled(err) {
+					break
+				}
+				if err != io.EOF {
+					log.Errorf(ctx, "err receiving stream events: %s", err)
+					time.Sleep(time.Second)
+				}
+				// Loop around and try a new call to get a client stream.
+				client = nil
+				continue
+			}
+
+			endpointIP := resp.IP
+			if endpointIP == "" {
+				// Nothing needs to be done if there is no IP address specified.
+				continue
+			}
+
+			// Ensure that a directory entry exists for this tenant.
+			entry, err := d.getEntry(ctx, roachpb.MakeTenantID(resp.TenantID), false)
+			if err != nil {
+				if grpcutil.IsContextCanceled(err) {
+					break
+				}
+				// This should never happen.
+				log.Errorf(ctx, "ignoring error getting entry for tenant %d: %v", resp.TenantID, err)
+				continue
+			}
+
+			if entry != nil {
+				// For now, all we care about is the IP addresses of the
+				// tenant endpoint.
+ switch resp.Typ { + case ADDED, MODIFIED: + if entry.AddEndpointIP(endpointIP) { + log.Infof(ctx, "added IP address %s for tenant %d", endpointIP, resp.TenantID) + } + + case DELETED: + if entry.RemoveEndpointIP(endpointIP) { + log.Infof(ctx, "deleted IP address %s for tenant %d", endpointIP, resp.TenantID) + } + } + } + } + }) + if err != nil { + return err + } + // Block until the initial endpoint watcher client stream is constructed. + waitInit.Wait() + return err +} + +// sleepContext sleeps for the given duration or until the given context is +// canceled, whichever comes first. +func sleepContext(ctx context.Context, delay time.Duration) { + select { + case <-ctx.Done(): + case <-time.After(delay): + } +} diff --git a/pkg/ccl/sqlproxyccl/tenant/directory.pb.go b/pkg/ccl/sqlproxyccl/tenant/directory.pb.go new file mode 100644 index 000000000000..5490250fbafe --- /dev/null +++ b/pkg/ccl/sqlproxyccl/tenant/directory.pb.go @@ -0,0 +1,1815 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: ccl/sqlproxyccl/tenant/directory.proto + +package tenant + +import ( + context "context" + fmt "fmt" + _ "github.com/gogo/protobuf/gogoproto" + proto "github.com/gogo/protobuf/proto" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" + io "io" + math "math" + math_bits "math/bits" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package + +// EventType shows the event type of the notifications that the server streams to its clients. +type EventType int32 + +const ( + ADDED EventType = 0 + MODIFIED EventType = 1 + DELETED EventType = 2 +) + +var EventType_name = map[int32]string{ + 0: "ADDED", + 1: "MODIFIED", + 2: "DELETED", +} + +var EventType_value = map[string]int32{ + "ADDED": 0, + "MODIFIED": 1, + "DELETED": 2, +} + +func (x EventType) String() string { + return proto.EnumName(EventType_name, int32(x)) +} + +func (EventType) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_ec8b5028e8f2b222, []int{0} +} + +// WatchEndpointsRequest is empty as we want to get all notifications. 
+type WatchEndpointsRequest struct { +} + +func (m *WatchEndpointsRequest) Reset() { *m = WatchEndpointsRequest{} } +func (m *WatchEndpointsRequest) String() string { return proto.CompactTextString(m) } +func (*WatchEndpointsRequest) ProtoMessage() {} +func (*WatchEndpointsRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_ec8b5028e8f2b222, []int{0} +} +func (m *WatchEndpointsRequest) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *WatchEndpointsRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil +} +func (m *WatchEndpointsRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_WatchEndpointsRequest.Merge(m, src) +} +func (m *WatchEndpointsRequest) XXX_Size() int { + return m.Size() +} +func (m *WatchEndpointsRequest) XXX_DiscardUnknown() { + xxx_messageInfo_WatchEndpointsRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_WatchEndpointsRequest proto.InternalMessageInfo + +// WatchEndpointsResponse represents the notifications that the server sends to its clients when clients +// want to monitor the directory server activity. +type WatchEndpointsResponse struct { + // EventType is the type of the notifications - added, modified, deleted. + Typ EventType `protobuf:"varint,1,opt,name=typ,proto3,enum=cockroach.ccl.sqlproxyccl.tenant.EventType" json:"typ,omitempty"` + // IP is the endpoint that this notification applies to. + IP string `protobuf:"bytes,2,opt,name=ip,proto3" json:"ip,omitempty"` + // TenantID is the tenant that owns the endpoint. + TenantID uint64 `protobuf:"varint,3,opt,name=tenant_id,json=tenantId,proto3" json:"tenant_id,omitempty"` +} + +func (m *WatchEndpointsResponse) Reset() { *m = WatchEndpointsResponse{} } +func (m *WatchEndpointsResponse) String() string { return proto.CompactTextString(m) } +func (*WatchEndpointsResponse) ProtoMessage() {} +func (*WatchEndpointsResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_ec8b5028e8f2b222, []int{1} +} +func (m *WatchEndpointsResponse) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *WatchEndpointsResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil +} +func (m *WatchEndpointsResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_WatchEndpointsResponse.Merge(m, src) +} +func (m *WatchEndpointsResponse) XXX_Size() int { + return m.Size() +} +func (m *WatchEndpointsResponse) XXX_DiscardUnknown() { + xxx_messageInfo_WatchEndpointsResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_WatchEndpointsResponse proto.InternalMessageInfo + +// ListEndpointsRequest is used to query the server for the list of current endpoints of a given tenant. +type ListEndpointsRequest struct { + // TenantID identifies the tenant for which the client is requesting a list of the endpoints. 
+ TenantID uint64 `protobuf:"varint,1,opt,name=tenant_id,json=tenantId,proto3" json:"tenant_id,omitempty"` +} + +func (m *ListEndpointsRequest) Reset() { *m = ListEndpointsRequest{} } +func (m *ListEndpointsRequest) String() string { return proto.CompactTextString(m) } +func (*ListEndpointsRequest) ProtoMessage() {} +func (*ListEndpointsRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_ec8b5028e8f2b222, []int{2} +} +func (m *ListEndpointsRequest) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *ListEndpointsRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil +} +func (m *ListEndpointsRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_ListEndpointsRequest.Merge(m, src) +} +func (m *ListEndpointsRequest) XXX_Size() int { + return m.Size() +} +func (m *ListEndpointsRequest) XXX_DiscardUnknown() { + xxx_messageInfo_ListEndpointsRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_ListEndpointsRequest proto.InternalMessageInfo + +// EnsureEndpointRequest is used to ensure that a tenant's backend is active. If there is an active backend then the +// server doesn't have to do anything. If there isn't an active backend, then the server has to bring a new one up. +type EnsureEndpointRequest struct { + // TenantID is the id of the tenant for which an active backend is requested. + TenantID uint64 `protobuf:"varint,1,opt,name=tenant_id,json=tenantId,proto3" json:"tenant_id,omitempty"` +} + +func (m *EnsureEndpointRequest) Reset() { *m = EnsureEndpointRequest{} } +func (m *EnsureEndpointRequest) String() string { return proto.CompactTextString(m) } +func (*EnsureEndpointRequest) ProtoMessage() {} +func (*EnsureEndpointRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_ec8b5028e8f2b222, []int{3} +} +func (m *EnsureEndpointRequest) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *EnsureEndpointRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil +} +func (m *EnsureEndpointRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_EnsureEndpointRequest.Merge(m, src) +} +func (m *EnsureEndpointRequest) XXX_Size() int { + return m.Size() +} +func (m *EnsureEndpointRequest) XXX_DiscardUnknown() { + xxx_messageInfo_EnsureEndpointRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_EnsureEndpointRequest proto.InternalMessageInfo + +// EnsureEndpointResponse is empty and indicates that the server processed the request. 
+type EnsureEndpointResponse struct { +} + +func (m *EnsureEndpointResponse) Reset() { *m = EnsureEndpointResponse{} } +func (m *EnsureEndpointResponse) String() string { return proto.CompactTextString(m) } +func (*EnsureEndpointResponse) ProtoMessage() {} +func (*EnsureEndpointResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_ec8b5028e8f2b222, []int{4} +} +func (m *EnsureEndpointResponse) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *EnsureEndpointResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil +} +func (m *EnsureEndpointResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_EnsureEndpointResponse.Merge(m, src) +} +func (m *EnsureEndpointResponse) XXX_Size() int { + return m.Size() +} +func (m *EnsureEndpointResponse) XXX_DiscardUnknown() { + xxx_messageInfo_EnsureEndpointResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_EnsureEndpointResponse proto.InternalMessageInfo + +// Endpoint contains the information about a tenant endpoint. Most often it is a combination of an ip address and port. +// i.e. 132.130.1.11:34576 +type Endpoint struct { + // IP is the ip and port combo identifying the tenant endpoint. + IP string `protobuf:"bytes,1,opt,name=IP,proto3" json:"IP,omitempty"` +} + +func (m *Endpoint) Reset() { *m = Endpoint{} } +func (m *Endpoint) String() string { return proto.CompactTextString(m) } +func (*Endpoint) ProtoMessage() {} +func (*Endpoint) Descriptor() ([]byte, []int) { + return fileDescriptor_ec8b5028e8f2b222, []int{5} +} +func (m *Endpoint) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *Endpoint) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil +} +func (m *Endpoint) XXX_Merge(src proto.Message) { + xxx_messageInfo_Endpoint.Merge(m, src) +} +func (m *Endpoint) XXX_Size() int { + return m.Size() +} +func (m *Endpoint) XXX_DiscardUnknown() { + xxx_messageInfo_Endpoint.DiscardUnknown(m) +} + +var xxx_messageInfo_Endpoint proto.InternalMessageInfo + +// ListEndpointsResponse is sent back as a result of requesting the list of endpoints for a given tenant. +type ListEndpointsResponse struct { + // Endpoints is the list of endpoints currently active for the requested tenant. 
+ Endpoints []*Endpoint `protobuf:"bytes,1,rep,name=endpoints,proto3" json:"endpoints,omitempty"` +} + +func (m *ListEndpointsResponse) Reset() { *m = ListEndpointsResponse{} } +func (m *ListEndpointsResponse) String() string { return proto.CompactTextString(m) } +func (*ListEndpointsResponse) ProtoMessage() {} +func (*ListEndpointsResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_ec8b5028e8f2b222, []int{6} +} +func (m *ListEndpointsResponse) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *ListEndpointsResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil +} +func (m *ListEndpointsResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_ListEndpointsResponse.Merge(m, src) +} +func (m *ListEndpointsResponse) XXX_Size() int { + return m.Size() +} +func (m *ListEndpointsResponse) XXX_DiscardUnknown() { + xxx_messageInfo_ListEndpointsResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_ListEndpointsResponse proto.InternalMessageInfo + +// GetTenantRequest is used by a client to request from the sever metadata related to a given tenant. +type GetTenantRequest struct { + // TenantID identifies the tenant for which the metadata is being requested. + TenantID uint64 `protobuf:"varint,1,opt,name=tenant_id,json=tenantId,proto3" json:"tenant_id,omitempty"` +} + +func (m *GetTenantRequest) Reset() { *m = GetTenantRequest{} } +func (m *GetTenantRequest) String() string { return proto.CompactTextString(m) } +func (*GetTenantRequest) ProtoMessage() {} +func (*GetTenantRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_ec8b5028e8f2b222, []int{7} +} +func (m *GetTenantRequest) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *GetTenantRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil +} +func (m *GetTenantRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_GetTenantRequest.Merge(m, src) +} +func (m *GetTenantRequest) XXX_Size() int { + return m.Size() +} +func (m *GetTenantRequest) XXX_DiscardUnknown() { + xxx_messageInfo_GetTenantRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_GetTenantRequest proto.InternalMessageInfo + +// GetTenantResponse is sent back when a client requests metadata for a tenant. +type GetTenantResponse struct { + // ClusterName is the name of the tenant's cluster. 
+ ClusterName string `protobuf:"bytes,1,opt,name=cluster_name,json=clusterName,proto3" json:"cluster_name,omitempty"` +} + +func (m *GetTenantResponse) Reset() { *m = GetTenantResponse{} } +func (m *GetTenantResponse) String() string { return proto.CompactTextString(m) } +func (*GetTenantResponse) ProtoMessage() {} +func (*GetTenantResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_ec8b5028e8f2b222, []int{8} +} +func (m *GetTenantResponse) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *GetTenantResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil +} +func (m *GetTenantResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_GetTenantResponse.Merge(m, src) +} +func (m *GetTenantResponse) XXX_Size() int { + return m.Size() +} +func (m *GetTenantResponse) XXX_DiscardUnknown() { + xxx_messageInfo_GetTenantResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_GetTenantResponse proto.InternalMessageInfo + +func init() { + proto.RegisterEnum("cockroach.ccl.sqlproxyccl.tenant.EventType", EventType_name, EventType_value) + proto.RegisterType((*WatchEndpointsRequest)(nil), "cockroach.ccl.sqlproxyccl.tenant.WatchEndpointsRequest") + proto.RegisterType((*WatchEndpointsResponse)(nil), "cockroach.ccl.sqlproxyccl.tenant.WatchEndpointsResponse") + proto.RegisterType((*ListEndpointsRequest)(nil), "cockroach.ccl.sqlproxyccl.tenant.ListEndpointsRequest") + proto.RegisterType((*EnsureEndpointRequest)(nil), "cockroach.ccl.sqlproxyccl.tenant.EnsureEndpointRequest") + proto.RegisterType((*EnsureEndpointResponse)(nil), "cockroach.ccl.sqlproxyccl.tenant.EnsureEndpointResponse") + proto.RegisterType((*Endpoint)(nil), "cockroach.ccl.sqlproxyccl.tenant.Endpoint") + proto.RegisterType((*ListEndpointsResponse)(nil), "cockroach.ccl.sqlproxyccl.tenant.ListEndpointsResponse") + proto.RegisterType((*GetTenantRequest)(nil), "cockroach.ccl.sqlproxyccl.tenant.GetTenantRequest") + proto.RegisterType((*GetTenantResponse)(nil), "cockroach.ccl.sqlproxyccl.tenant.GetTenantResponse") +} + +func init() { + proto.RegisterFile("ccl/sqlproxyccl/tenant/directory.proto", fileDescriptor_ec8b5028e8f2b222) +} + +var fileDescriptor_ec8b5028e8f2b222 = []byte{ + // 523 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x9c, 0x94, 0xc1, 0x8b, 0xd3, 0x40, + 0x14, 0xc6, 0x33, 0x6d, 0xad, 0xcd, 0x6b, 0x2d, 0x75, 0xd8, 0xd6, 0x92, 0x43, 0x36, 0xe6, 0x20, + 0x71, 0x85, 0x54, 0xba, 0xb0, 0xeb, 0x65, 0x0f, 0x5b, 0x12, 0x35, 0xb0, 0x6a, 0x09, 0x05, 0xc1, + 0xcb, 0x12, 0xa7, 0xc3, 0x6e, 0xb0, 0x9b, 0xc9, 0x26, 0x53, 0xb1, 0x37, 0x41, 0x04, 0x8f, 0xe2, + 0xd1, 0xab, 0xff, 0xcc, 0x1e, 0xf7, 0xb8, 0xa7, 0x45, 0xd3, 0x7f, 0x44, 0xda, 0x69, 0xd6, 0x36, + 0x14, 0xdb, 0xee, 0x6d, 0x78, 0x93, 0xdf, 0x37, 0xdf, 0x7b, 0xf9, 0x66, 0xe0, 0x11, 0x21, 0x83, + 0x56, 0x7c, 0x3e, 0x08, 0x23, 0xf6, 0x69, 0x34, 0x59, 0x73, 0x1a, 0x78, 0x01, 0x6f, 0xf5, 0xfd, + 0x88, 0x12, 0xce, 0xa2, 0x91, 0x19, 0x46, 0x8c, 0x33, 0xac, 0x11, 0x46, 0x3e, 0x44, 0xcc, 0x23, + 0xa7, 0x26, 0x21, 0x03, 0x73, 0x8e, 0x30, 0x05, 0xa1, 0x6c, 0x9d, 0xb0, 0x13, 0x36, 0xfd, 0xb8, + 0x35, 0x59, 0x09, 0x4e, 0x7f, 0x00, 0xf5, 0xb7, 0x1e, 0x27, 0xa7, 0x76, 0xd0, 0x0f, 0x99, 0x1f, + 0xf0, 0xd8, 0xa5, 0xe7, 0x43, 0x1a, 0x73, 0xfd, 0x27, 0x82, 0x46, 0x76, 0x27, 0x0e, 0x59, 0x10, + 0x53, 0x7c, 0x00, 0x79, 0x3e, 0x0a, 0x9b, 0x48, 0x43, 0x46, 0xb5, 0xfd, 0xc4, 0x5c, 0x75, 0xb2, + 0x69, 0x7f, 
0xa4, 0x01, 0xef, 0x8d, 0x42, 0xea, 0x4e, 0x38, 0xdc, 0x80, 0x9c, 0x1f, 0x36, 0x73, + 0x1a, 0x32, 0xe4, 0x4e, 0x31, 0xb9, 0xde, 0xce, 0x39, 0x5d, 0x37, 0xe7, 0x87, 0xf8, 0x31, 0xc8, + 0x02, 0x38, 0xf6, 0xfb, 0xcd, 0xbc, 0x86, 0x8c, 0x42, 0xa7, 0x92, 0x5c, 0x6f, 0x97, 0x7a, 0xd3, + 0xa2, 0x63, 0xb9, 0x25, 0xb1, 0xed, 0xf4, 0xf5, 0x43, 0xd8, 0x3a, 0xf2, 0x63, 0x9e, 0x35, 0xbd, + 0x28, 0x81, 0xfe, 0x2b, 0xd1, 0x81, 0xba, 0x1d, 0xc4, 0xc3, 0x88, 0xa6, 0x22, 0xb7, 0xd0, 0x68, + 0x42, 0x23, 0xab, 0x21, 0x46, 0xa4, 0xeb, 0x50, 0x4a, 0x6b, 0x93, 0x7e, 0x9d, 0xee, 0x54, 0x69, + 0xae, 0x5f, 0xa7, 0xab, 0x7b, 0x50, 0xcf, 0x34, 0x31, 0x9b, 0xef, 0x4b, 0x90, 0x69, 0x5a, 0x6c, + 0x22, 0x2d, 0x6f, 0x94, 0xdb, 0x3b, 0x6b, 0x4c, 0x39, 0xf5, 0xf0, 0x0f, 0xd6, 0x0f, 0xa0, 0xf6, + 0x82, 0x72, 0xe1, 0xfc, 0x16, 0xfd, 0xed, 0xc1, 0xfd, 0x39, 0x7c, 0xe6, 0xee, 0x21, 0x54, 0xc8, + 0x60, 0x18, 0x73, 0x1a, 0x1d, 0x07, 0xde, 0x19, 0x15, 0x8d, 0xb9, 0xe5, 0x59, 0xed, 0xb5, 0x77, + 0x46, 0x77, 0xf6, 0x41, 0xbe, 0xf9, 0xe7, 0x58, 0x86, 0x3b, 0x87, 0x96, 0x65, 0x5b, 0x35, 0x09, + 0x57, 0xa0, 0xf4, 0xea, 0x8d, 0xe5, 0x3c, 0x77, 0x6c, 0xab, 0x86, 0x70, 0x19, 0xee, 0x5a, 0xf6, + 0x91, 0xdd, 0xb3, 0xad, 0x5a, 0x4e, 0x29, 0x7c, 0xfb, 0xa5, 0x4a, 0xed, 0x1f, 0x05, 0x90, 0xad, + 0x34, 0xd9, 0xf8, 0x33, 0x82, 0x7b, 0x0b, 0x13, 0xc2, 0x7b, 0xab, 0xc7, 0xb0, 0x2c, 0x17, 0xca, + 0xfe, 0xc6, 0xdc, 0xac, 0xd9, 0xaf, 0x08, 0xaa, 0x8b, 0xb7, 0x00, 0xaf, 0xa1, 0xb5, 0xf4, 0x46, + 0x29, 0xcf, 0x36, 0x07, 0x85, 0x8b, 0xa7, 0x08, 0x7f, 0x41, 0x50, 0x5d, 0x8c, 0xda, 0x3a, 0x3e, + 0x96, 0x06, 0x7c, 0x1d, 0x1f, 0xcb, 0x53, 0x8d, 0x39, 0xc8, 0x37, 0x79, 0xc0, 0xed, 0xd5, 0x32, + 0xd9, 0xec, 0x29, 0xbb, 0x1b, 0x31, 0xe2, 0xd4, 0x8e, 0x71, 0xf1, 0x47, 0x95, 0x2e, 0x12, 0x15, + 0x5d, 0x26, 0x2a, 0xba, 0x4a, 0x54, 0xf4, 0x3b, 0x51, 0xd1, 0xf7, 0xb1, 0x2a, 0x5d, 0x8e, 0x55, + 0xe9, 0x6a, 0xac, 0x4a, 0xef, 0x8a, 0x82, 0x7d, 0x5f, 0x9c, 0xbe, 0x69, 0xbb, 0x7f, 0x03, 0x00, + 0x00, 0xff, 0xff, 0xba, 0xed, 0x85, 0x11, 0x35, 0x05, 0x00, 0x00, +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// DirectoryClient is the client API for Directory service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream. +type DirectoryClient interface { + // ListEndpoints is used to query the server for the list of current endpoints of a given tenant. + ListEndpoints(ctx context.Context, in *ListEndpointsRequest, opts ...grpc.CallOption) (*ListEndpointsResponse, error) + // WatchEndpoints is used to get a stream, that is used to receive notifications about changes in tenant + // backend's state - added, modified and deleted. + WatchEndpoints(ctx context.Context, in *WatchEndpointsRequest, opts ...grpc.CallOption) (Directory_WatchEndpointsClient, error) + // EnsureEndpoint is used to ensure that a tenant's backend is active. If there is an active backend then the + // server doesn't have to do anything. If there isn't an active backend, then the server has to bring a new one up. + EnsureEndpoint(ctx context.Context, in *EnsureEndpointRequest, opts ...grpc.CallOption) (*EnsureEndpointResponse, error) + // GetTenant is used to fetch the metadata of a specific tenant. 
+ GetTenant(ctx context.Context, in *GetTenantRequest, opts ...grpc.CallOption) (*GetTenantResponse, error) +} + +type directoryClient struct { + cc *grpc.ClientConn +} + +func NewDirectoryClient(cc *grpc.ClientConn) DirectoryClient { + return &directoryClient{cc} +} + +func (c *directoryClient) ListEndpoints(ctx context.Context, in *ListEndpointsRequest, opts ...grpc.CallOption) (*ListEndpointsResponse, error) { + out := new(ListEndpointsResponse) + err := c.cc.Invoke(ctx, "/cockroach.ccl.sqlproxyccl.tenant.Directory/ListEndpoints", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *directoryClient) WatchEndpoints(ctx context.Context, in *WatchEndpointsRequest, opts ...grpc.CallOption) (Directory_WatchEndpointsClient, error) { + stream, err := c.cc.NewStream(ctx, &_Directory_serviceDesc.Streams[0], "/cockroach.ccl.sqlproxyccl.tenant.Directory/WatchEndpoints", opts...) + if err != nil { + return nil, err + } + x := &directoryWatchEndpointsClient{stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +type Directory_WatchEndpointsClient interface { + Recv() (*WatchEndpointsResponse, error) + grpc.ClientStream +} + +type directoryWatchEndpointsClient struct { + grpc.ClientStream +} + +func (x *directoryWatchEndpointsClient) Recv() (*WatchEndpointsResponse, error) { + m := new(WatchEndpointsResponse) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func (c *directoryClient) EnsureEndpoint(ctx context.Context, in *EnsureEndpointRequest, opts ...grpc.CallOption) (*EnsureEndpointResponse, error) { + out := new(EnsureEndpointResponse) + err := c.cc.Invoke(ctx, "/cockroach.ccl.sqlproxyccl.tenant.Directory/EnsureEndpoint", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *directoryClient) GetTenant(ctx context.Context, in *GetTenantRequest, opts ...grpc.CallOption) (*GetTenantResponse, error) { + out := new(GetTenantResponse) + err := c.cc.Invoke(ctx, "/cockroach.ccl.sqlproxyccl.tenant.Directory/GetTenant", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// DirectoryServer is the server API for Directory service. +type DirectoryServer interface { + // ListEndpoints is used to query the server for the list of current endpoints of a given tenant. + ListEndpoints(context.Context, *ListEndpointsRequest) (*ListEndpointsResponse, error) + // WatchEndpoints is used to get a stream, that is used to receive notifications about changes in tenant + // backend's state - added, modified and deleted. + WatchEndpoints(*WatchEndpointsRequest, Directory_WatchEndpointsServer) error + // EnsureEndpoint is used to ensure that a tenant's backend is active. If there is an active backend then the + // server doesn't have to do anything. If there isn't an active backend, then the server has to bring a new one up. + EnsureEndpoint(context.Context, *EnsureEndpointRequest) (*EnsureEndpointResponse, error) + // GetTenant is used to fetch the metadata of a specific tenant. + GetTenant(context.Context, *GetTenantRequest) (*GetTenantResponse, error) +} + +// UnimplementedDirectoryServer can be embedded to have forward compatible implementations. 
+type UnimplementedDirectoryServer struct { +} + +func (*UnimplementedDirectoryServer) ListEndpoints(ctx context.Context, req *ListEndpointsRequest) (*ListEndpointsResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method ListEndpoints not implemented") +} +func (*UnimplementedDirectoryServer) WatchEndpoints(req *WatchEndpointsRequest, srv Directory_WatchEndpointsServer) error { + return status.Errorf(codes.Unimplemented, "method WatchEndpoints not implemented") +} +func (*UnimplementedDirectoryServer) EnsureEndpoint(ctx context.Context, req *EnsureEndpointRequest) (*EnsureEndpointResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method EnsureEndpoint not implemented") +} +func (*UnimplementedDirectoryServer) GetTenant(ctx context.Context, req *GetTenantRequest) (*GetTenantResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetTenant not implemented") +} + +func RegisterDirectoryServer(s *grpc.Server, srv DirectoryServer) { + s.RegisterService(&_Directory_serviceDesc, srv) +} + +func _Directory_ListEndpoints_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ListEndpointsRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DirectoryServer).ListEndpoints(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/cockroach.ccl.sqlproxyccl.tenant.Directory/ListEndpoints", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DirectoryServer).ListEndpoints(ctx, req.(*ListEndpointsRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Directory_WatchEndpoints_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(WatchEndpointsRequest) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(DirectoryServer).WatchEndpoints(m, &directoryWatchEndpointsServer{stream}) +} + +type Directory_WatchEndpointsServer interface { + Send(*WatchEndpointsResponse) error + grpc.ServerStream +} + +type directoryWatchEndpointsServer struct { + grpc.ServerStream +} + +func (x *directoryWatchEndpointsServer) Send(m *WatchEndpointsResponse) error { + return x.ServerStream.SendMsg(m) +} + +func _Directory_EnsureEndpoint_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(EnsureEndpointRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DirectoryServer).EnsureEndpoint(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/cockroach.ccl.sqlproxyccl.tenant.Directory/EnsureEndpoint", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DirectoryServer).EnsureEndpoint(ctx, req.(*EnsureEndpointRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Directory_GetTenant_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetTenantRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DirectoryServer).GetTenant(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/cockroach.ccl.sqlproxyccl.tenant.Directory/GetTenant", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return 
srv.(DirectoryServer).GetTenant(ctx, req.(*GetTenantRequest)) + } + return interceptor(ctx, in, info, handler) +} + +var _Directory_serviceDesc = grpc.ServiceDesc{ + ServiceName: "cockroach.ccl.sqlproxyccl.tenant.Directory", + HandlerType: (*DirectoryServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "ListEndpoints", + Handler: _Directory_ListEndpoints_Handler, + }, + { + MethodName: "EnsureEndpoint", + Handler: _Directory_EnsureEndpoint_Handler, + }, + { + MethodName: "GetTenant", + Handler: _Directory_GetTenant_Handler, + }, + }, + Streams: []grpc.StreamDesc{ + { + StreamName: "WatchEndpoints", + Handler: _Directory_WatchEndpoints_Handler, + ServerStreams: true, + }, + }, + Metadata: "ccl/sqlproxyccl/tenant/directory.proto", +} + +func (m *WatchEndpointsRequest) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *WatchEndpointsRequest) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *WatchEndpointsRequest) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + return len(dAtA) - i, nil +} + +func (m *WatchEndpointsResponse) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *WatchEndpointsResponse) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *WatchEndpointsResponse) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.TenantID != 0 { + i = encodeVarintDirectory(dAtA, i, uint64(m.TenantID)) + i-- + dAtA[i] = 0x18 + } + if len(m.IP) > 0 { + i -= len(m.IP) + copy(dAtA[i:], m.IP) + i = encodeVarintDirectory(dAtA, i, uint64(len(m.IP))) + i-- + dAtA[i] = 0x12 + } + if m.Typ != 0 { + i = encodeVarintDirectory(dAtA, i, uint64(m.Typ)) + i-- + dAtA[i] = 0x8 + } + return len(dAtA) - i, nil +} + +func (m *ListEndpointsRequest) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *ListEndpointsRequest) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *ListEndpointsRequest) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.TenantID != 0 { + i = encodeVarintDirectory(dAtA, i, uint64(m.TenantID)) + i-- + dAtA[i] = 0x8 + } + return len(dAtA) - i, nil +} + +func (m *EnsureEndpointRequest) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *EnsureEndpointRequest) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *EnsureEndpointRequest) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.TenantID != 0 { + i = encodeVarintDirectory(dAtA, i, uint64(m.TenantID)) + i-- + dAtA[i] = 0x8 + } + return len(dAtA) - i, nil +} + +func (m *EnsureEndpointResponse) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = 
make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *EnsureEndpointResponse) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *EnsureEndpointResponse) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + return len(dAtA) - i, nil +} + +func (m *Endpoint) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Endpoint) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *Endpoint) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if len(m.IP) > 0 { + i -= len(m.IP) + copy(dAtA[i:], m.IP) + i = encodeVarintDirectory(dAtA, i, uint64(len(m.IP))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func (m *ListEndpointsResponse) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *ListEndpointsResponse) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *ListEndpointsResponse) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if len(m.Endpoints) > 0 { + for iNdEx := len(m.Endpoints) - 1; iNdEx >= 0; iNdEx-- { + { + size, err := m.Endpoints[iNdEx].MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintDirectory(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0xa + } + } + return len(dAtA) - i, nil +} + +func (m *GetTenantRequest) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *GetTenantRequest) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *GetTenantRequest) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.TenantID != 0 { + i = encodeVarintDirectory(dAtA, i, uint64(m.TenantID)) + i-- + dAtA[i] = 0x8 + } + return len(dAtA) - i, nil +} + +func (m *GetTenantResponse) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *GetTenantResponse) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *GetTenantResponse) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if len(m.ClusterName) > 0 { + i -= len(m.ClusterName) + copy(dAtA[i:], m.ClusterName) + i = encodeVarintDirectory(dAtA, i, uint64(len(m.ClusterName))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func encodeVarintDirectory(dAtA []byte, offset int, v uint64) int { + offset -= sovDirectory(v) + base := offset + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return base +} +func (m *WatchEndpointsRequest) Size() (n int) { + if m == nil { + return 0 + } + var l int + 
_ = l + return n +} + +func (m *WatchEndpointsResponse) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Typ != 0 { + n += 1 + sovDirectory(uint64(m.Typ)) + } + l = len(m.IP) + if l > 0 { + n += 1 + l + sovDirectory(uint64(l)) + } + if m.TenantID != 0 { + n += 1 + sovDirectory(uint64(m.TenantID)) + } + return n +} + +func (m *ListEndpointsRequest) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.TenantID != 0 { + n += 1 + sovDirectory(uint64(m.TenantID)) + } + return n +} + +func (m *EnsureEndpointRequest) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.TenantID != 0 { + n += 1 + sovDirectory(uint64(m.TenantID)) + } + return n +} + +func (m *EnsureEndpointResponse) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + return n +} + +func (m *Endpoint) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.IP) + if l > 0 { + n += 1 + l + sovDirectory(uint64(l)) + } + return n +} + +func (m *ListEndpointsResponse) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if len(m.Endpoints) > 0 { + for _, e := range m.Endpoints { + l = e.Size() + n += 1 + l + sovDirectory(uint64(l)) + } + } + return n +} + +func (m *GetTenantRequest) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.TenantID != 0 { + n += 1 + sovDirectory(uint64(m.TenantID)) + } + return n +} + +func (m *GetTenantResponse) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.ClusterName) + if l > 0 { + n += 1 + l + sovDirectory(uint64(l)) + } + return n +} + +func sovDirectory(x uint64) (n int) { + return (math_bits.Len64(x|1) + 6) / 7 +} +func sozDirectory(x uint64) (n int) { + return sovDirectory(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (m *WatchEndpointsRequest) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowDirectory + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: WatchEndpointsRequest: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: WatchEndpointsRequest: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + default: + iNdEx = preIndex + skippy, err := skipDirectory(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthDirectory + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *WatchEndpointsResponse) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowDirectory + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: WatchEndpointsResponse: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: WatchEndpointsResponse: illegal tag %d (wire type %d)", fieldNum, 
wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Typ", wireType) + } + m.Typ = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowDirectory + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Typ |= EventType(b&0x7F) << shift + if b < 0x80 { + break + } + } + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field IP", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowDirectory + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthDirectory + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthDirectory + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.IP = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 3: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field TenantID", wireType) + } + m.TenantID = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowDirectory + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.TenantID |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipDirectory(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthDirectory + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *ListEndpointsRequest) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowDirectory + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: ListEndpointsRequest: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: ListEndpointsRequest: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field TenantID", wireType) + } + m.TenantID = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowDirectory + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.TenantID |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipDirectory(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthDirectory + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *EnsureEndpointRequest) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowDirectory + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ 
+ wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: EnsureEndpointRequest: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: EnsureEndpointRequest: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field TenantID", wireType) + } + m.TenantID = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowDirectory + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.TenantID |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipDirectory(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthDirectory + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *EnsureEndpointResponse) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowDirectory + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: EnsureEndpointResponse: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: EnsureEndpointResponse: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + default: + iNdEx = preIndex + skippy, err := skipDirectory(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthDirectory + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *Endpoint) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowDirectory + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Endpoint: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Endpoint: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field IP", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowDirectory + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthDirectory + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthDirectory + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.IP = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, 
err := skipDirectory(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthDirectory + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *ListEndpointsResponse) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowDirectory + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: ListEndpointsResponse: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: ListEndpointsResponse: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Endpoints", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowDirectory + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthDirectory + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthDirectory + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Endpoints = append(m.Endpoints, &Endpoint{}) + if err := m.Endpoints[len(m.Endpoints)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipDirectory(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthDirectory + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *GetTenantRequest) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowDirectory + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: GetTenantRequest: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: GetTenantRequest: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field TenantID", wireType) + } + m.TenantID = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowDirectory + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.TenantID |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipDirectory(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthDirectory + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *GetTenantResponse) 
Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowDirectory + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: GetTenantResponse: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: GetTenantResponse: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field ClusterName", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowDirectory + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthDirectory + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthDirectory + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.ClusterName = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipDirectory(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthDirectory + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skipDirectory(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + depth := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowDirectory + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowDirectory + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + case 1: + iNdEx += 8 + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowDirectory + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if length < 0 { + return 0, ErrInvalidLengthDirectory + } + iNdEx += length + case 3: + depth++ + case 4: + if depth == 0 { + return 0, ErrUnexpectedEndOfGroupDirectory + } + depth-- + case 5: + iNdEx += 4 + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + if iNdEx < 0 { + return 0, ErrInvalidLengthDirectory + } + if depth == 0 { + return iNdEx, nil + } + } + return 0, io.ErrUnexpectedEOF +} + +var ( + ErrInvalidLengthDirectory = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowDirectory = fmt.Errorf("proto: integer overflow") + ErrUnexpectedEndOfGroupDirectory = fmt.Errorf("proto: unexpected end of group") +) diff --git a/pkg/ccl/sqlproxyccl/tenant/directory.proto b/pkg/ccl/sqlproxyccl/tenant/directory.proto new file mode 100644 index 000000000000..f62217774d43 --- /dev/null +++ 
b/pkg/ccl/sqlproxyccl/tenant/directory.proto @@ -0,0 +1,92 @@ +// Copyright 2021 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +syntax = "proto3"; +package cockroach.ccl.sqlproxyccl.tenant; +option go_package="tenant"; + +import "gogoproto/gogo.proto"; + +// WatchEndpointsRequest is empty as we want to get all notifications. +message WatchEndpointsRequest {} + +// EventType is the type of the notifications that the server streams to its clients. +enum EventType { + option (gogoproto.goproto_enum_prefix) = false; + + ADDED = 0; + MODIFIED = 1; + DELETED = 2; +} + +// WatchEndpointsResponse represents a notification that the server sends to clients +// that are monitoring directory server activity. +message WatchEndpointsResponse { + // Typ is the type of the notification - added, modified, or deleted. + EventType typ = 1; + // IP is the endpoint that this notification applies to. + string ip = 2 [(gogoproto.customname) = "IP"]; + // TenantID is the tenant that owns the endpoint. + uint64 tenant_id = 3[(gogoproto.customname) = "TenantID"]; +} + +// ListEndpointsRequest is used to query the server for the list of current endpoints of a given tenant. +message ListEndpointsRequest { + // TenantID identifies the tenant for which the client is requesting a list of the endpoints. + uint64 tenant_id = 1[(gogoproto.customname) = "TenantID"]; +} + +// EnsureEndpointRequest is used to ensure that a tenant's backend is active. If there is an active backend then the +// server doesn't have to do anything. If there isn't an active backend, then the server has to bring a new one up. +message EnsureEndpointRequest { + // TenantID is the ID of the tenant for which an active backend is requested. + uint64 tenant_id = 1[(gogoproto.customname) = "TenantID"]; +} + +// EnsureEndpointResponse is empty and indicates that the server processed the request. +message EnsureEndpointResponse { +} + +// Endpoint contains the information about a tenant endpoint. Most often it is a combination of an IP address and port, +// e.g. 132.130.1.11:34576. +message Endpoint { + // IP is the IP address and port combination identifying the tenant endpoint. + string IP = 1[(gogoproto.customname) = "IP"]; +} + +// ListEndpointsResponse is sent back as a result of requesting the list of endpoints for a given tenant. +message ListEndpointsResponse { + // Endpoints is the list of endpoints currently active for the requested tenant. + repeated Endpoint endpoints = 1; +} + +// GetTenantRequest is used by a client to request from the server metadata related to a given tenant. +message GetTenantRequest { + // TenantID identifies the tenant for which the metadata is being requested. + uint64 tenant_id = 1[(gogoproto.customname) = "TenantID"]; +} + +// GetTenantResponse is sent back when a client requests metadata for a tenant. +message GetTenantResponse { + // ClusterName is the name of the tenant's cluster. + string cluster_name = 1; // add more metadata if needed +}
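(Aside, not part of the patch: given the messages above and the generated Go client, a consumer of the watch stream is expected to drain it roughly as follows. This is a sketch; the helper name consumeEndpointEvents is invented, while DirectoryClient, WatchEndpointsRequest, and the event constants come from the generated code in this package.)

package tenant

import "context"

// consumeEndpointEvents drains the WatchEndpoints stream and reacts to each
// event type. A real caller would also re-establish the stream on error.
func consumeEndpointEvents(ctx context.Context, client DirectoryClient) error {
	stream, err := client.WatchEndpoints(ctx, &WatchEndpointsRequest{})
	if err != nil {
		return err
	}
	for {
		resp, err := stream.Recv()
		if err != nil {
			return err // io.EOF or a stream error ends the watch.
		}
		switch resp.Typ {
		case ADDED, MODIFIED:
			// Record resp.IP as a live endpoint for tenant resp.TenantID.
		case DELETED:
			// Drop resp.IP for tenant resp.TenantID.
		}
	}
}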
+ +// Directory specifies a service that keeps track of and manages tenant backends, their endpoints, and related metadata. +service Directory { + // ListEndpoints is used to query the server for the list of current endpoints of a given tenant. + rpc ListEndpoints(ListEndpointsRequest) returns (ListEndpointsResponse); + // WatchEndpoints is used to get a stream that receives notifications about changes in tenant + // backends' state - added, modified, and deleted. + rpc WatchEndpoints(WatchEndpointsRequest) returns (stream WatchEndpointsResponse); + // EnsureEndpoint is used to ensure that a tenant's backend is active. If there is an active backend then the + // server doesn't have to do anything. If there isn't an active backend, then the server has to bring a new one up. + rpc EnsureEndpoint(EnsureEndpointRequest) returns (EnsureEndpointResponse); + // GetTenant is used to fetch the metadata of a specific tenant. + rpc GetTenant(GetTenantRequest) returns (GetTenantResponse); +} diff --git a/pkg/ccl/sqlproxyccl/tenant/directory_test.go b/pkg/ccl/sqlproxyccl/tenant/directory_test.go new file mode 100644 index 000000000000..5027281400af --- /dev/null +++ b/pkg/ccl/sqlproxyccl/tenant/directory_test.go @@ -0,0 +1,389 @@ +// Copyright 2021 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package tenant + +import ( + "context" + "net" + "sync" + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/sql" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" +) + +func TestDirectoryErrors(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.ScopeWithoutShowLogs(t).Close(t) + + const tenantID = 10 + + ctx := context.Background() + + tc, dir, _ := newTestDirectory(t) + defer tc.Stopper().Stop(ctx) + + _, err := dir.LookupTenantIPs(ctx, roachpb.MakeTenantID(1000)) + assert.Regexp(t, "not found", err) + _, err = dir.LookupTenantIPs(ctx, roachpb.MakeTenantID(1001)) + assert.Regexp(t, "not found", err) + _, err = dir.LookupTenantIPs(ctx, roachpb.MakeTenantID(1002)) + assert.Regexp(t, "not found", err) + + // Fail to find tenant that does not exist. + _, err = dir.EnsureTenantIP(ctx, roachpb.MakeTenantID(1000), "") + assert.Regexp(t, "not found", err) + + // Fail to find tenant when cluster name doesn't match. + _, err = dir.EnsureTenantIP(ctx, roachpb.MakeTenantID(tenantID), "unknown") + assert.Regexp(t, "not found", err) + + // No-op when reporting failure for tenant that doesn't exist. + require.NoError(t, dir.ReportFailure(ctx, roachpb.MakeTenantID(1000), "")) +} + +func TestEndpointWatcher(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.ScopeWithoutShowLogs(t).Close(t) + + // Create the directory.
+ ctx := context.Background() + tc, dir, tds := newTestDirectory(t) + defer tc.Stopper().Stop(ctx) + + tenantID := roachpb.MakeTenantID(20) + require.NoError(t, createTenant(tc, tenantID)) + + // Call EnsureTenantIP to start a new tenant and create an entry. + ip, err := dir.EnsureTenantIP(ctx, tenantID, "") + require.NoError(t, err) + require.NotEmpty(t, ip) + + // Now shut it down. + processes := tds.Get(tenantID) + require.NotNil(t, processes) + require.Len(t, processes, 1) + // Stop the tenant and ensure its IP address is removed from the directory. + for _, process := range processes { + process.Stopper.Stop(ctx) + } + + require.Eventually(t, func() bool { + ips, _ := dir.LookupTenantIPs(ctx, tenantID) + return len(ips) == 0 + }, 10*time.Second, 100*time.Millisecond) + + // Resume the tenant again by a direct call to the directory server. + _, err = tds.EnsureEndpoint(ctx, &EnsureEndpointRequest{tenantID.ToUint64()}) + require.NoError(t, err) + + // Wait for the background watcher to populate the initial endpoint. + require.Eventually(t, func() bool { + ips, _ := dir.LookupTenantIPs(ctx, tenantID) + return len(ips) != 0 + }, 10*time.Second, 100*time.Millisecond) + + // Verify that EnsureTenantIP returns the endpoint's IP address. + ip, err = dir.EnsureTenantIP(ctx, tenantID, "") + require.NoError(t, err) + require.NotEmpty(t, ip) + + processes = tds.Get(tenantID) + require.NotNil(t, processes) + require.Len(t, processes, 1) + for dirIP := range processes { + require.Equal(t, ip, dirIP.String()) + } + + // Stop the tenant and ensure its IP address is removed from the directory. + for _, process := range processes { + process.Stopper.Stop(ctx) + } + + require.Eventually(t, func() bool { + ips, _ := dir.LookupTenantIPs(ctx, tenantID) + return len(ips) == 0 + }, 10*time.Second, 100*time.Millisecond) + + // Verify that a new call to EnsureTenantIP will resume the tenant again. + ip, err = dir.EnsureTenantIP(ctx, tenantID, "") + require.NoError(t, err) + require.NotEmpty(t, ip) +} + +func TestCancelLookups(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.ScopeWithoutShowLogs(t).Close(t) + + tenantID := roachpb.MakeTenantID(20) + const lookupCount = 1 + + // Create the directory. + ctx, cancel := context.WithCancel(context.Background()) + tc, dir, _ := newTestDirectory(t) + defer tc.Stopper().Stop(ctx) + require.NoError(t, createTenant(tc, tenantID)) + + backgroundErrors := make([]error, lookupCount) + var wait sync.WaitGroup + for i := 0; i < lookupCount; i++ { + wait.Add(1) + go func(i int) { + _, backgroundErrors[i] = dir.EnsureTenantIP(ctx, tenantID, "") + wait.Done() + }(i) + } + + // Cancel the lookups and verify the errors. + time.Sleep(10 * time.Millisecond) + cancel() + wait.Wait() + + for i := 0; i < lookupCount; i++ { + require.Error(t, backgroundErrors[i]) + require.Regexp(t, "context canceled", backgroundErrors[i].Error()) + } +} + +func TestResume(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.ScopeWithoutShowLogs(t).Close(t) + + tenantID := roachpb.MakeTenantID(40) + const lookupCount = 5 + + // Create the directory. + ctx := context.Background() + tc, dir, tds := newTestDirectory(t) + defer tc.Stopper().Stop(ctx) + require.NoError(t, createTenant(tc, tenantID)) + + // No tenant processes running.
+ require.Equal(t, 0, len(tds.Get(tenantID))) + + var ips [lookupCount]string + var wait sync.WaitGroup + for i := 0; i < lookupCount; i++ { + wait.Add(1) + go func(i int) { + var err error + ips[i], err = dir.EnsureTenantIP(ctx, tenantID, "") + require.NoError(t, err) + wait.Done() + }(i) + } + + var processes map[net.Addr]*Process + // Eventually the tenant process will be resumed. + require.Eventually(t, func() bool { + processes = tds.Get(tenantID) + return len(processes) == 1 + }, 10*time.Second, 100*time.Millisecond) + + // Wait until the background goroutines complete. + wait.Wait() + + for ip := range processes { + for i := 0; i < lookupCount; i++ { + require.Equal(t, ip.String(), ips[i]) + } + } +} + +func TestDeleteTenant(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.ScopeWithoutShowLogs(t).Close(t) + + // Create the directory. + ctx := context.Background() + // Disable throttling for this test. + tc, dir, tds := newTestDirectory(t, RefreshDelay(-1)) + defer tc.Stopper().Stop(ctx) + + tenantID := roachpb.MakeTenantID(50) + // Create the test tenant. + require.NoError(t, createTenant(tc, tenantID)) + + // Perform a lookup to create an entry in the cache. + ip, err := dir.EnsureTenantIP(ctx, tenantID, "") + require.NoError(t, err) + require.NotEmpty(t, ip) + + // Report failure even though the tenant is healthy - refresh should do nothing. + require.NoError(t, dir.ReportFailure(ctx, tenantID, ip)) + ip, err = dir.EnsureTenantIP(ctx, tenantID, "") + require.NoError(t, err) + require.NotEmpty(t, ip) + + // Stop the tenant. + for _, process := range tds.Get(tenantID) { + process.Stopper.Stop(ctx) + } + + // Report failure connecting to the endpoint to force a refresh of the IPs. + require.NoError(t, dir.ReportFailure(ctx, tenantID, ip)) + + // Ensure that the tenant has no valid IP addresses. + ips, err := dir.LookupTenantIPs(ctx, tenantID) + require.NoError(t, err) + require.Empty(t, ips) + + // Report failure again to ensure that it works when there is no IP address. + require.NoError(t, dir.ReportFailure(ctx, tenantID, ip)) + + // Now delete the tenant. + require.NoError(t, destroyTenant(tc, tenantID)) + + // Now EnsureTenantIP should return an error and the directory should no + // longer cache the tenant. + _, err = dir.EnsureTenantIP(ctx, tenantID, "") + require.Regexp(t, "not found", err) + ips, err = dir.LookupTenantIPs(ctx, tenantID) + require.Regexp(t, "not found", err) + require.Nil(t, ips) +} + +// TestRefreshThrottling checks that throttling works. +func TestRefreshThrottling(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.ScopeWithoutShowLogs(t).Close(t) + + // Create the directory, but with extreme rate limiting so that the directory + // will never refresh. + ctx := context.Background() + tc, dir, _ := newTestDirectory(t, RefreshDelay(60*time.Minute)) + defer tc.Stopper().Stop(ctx) + + // Create the test tenant. + tenantID := roachpb.MakeTenantID(60) + require.NoError(t, createTenant(tc, tenantID)) + + // Perform a lookup to create an entry in the cache. + ip, err := dir.EnsureTenantIP(ctx, tenantID, "") + require.NoError(t, err) + require.NotEmpty(t, ip) + + // Report a false failure and verify that the IP is still present in the cache. + require.NoError(t, dir.ReportFailure(ctx, tenantID, ip)) + ips, err := dir.LookupTenantIPs(ctx, tenantID) + require.NoError(t, err) + require.Equal(t, []string{ip}, ips)
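(Aside, not part of the patch: the RefreshDelay knob exercised by TestDeleteTenant and TestRefreshThrottling is an ordinary directory option. A sketch of how a caller would wire it, assuming an established DirectoryClient and stopper; the wrapper function name is invented:)

package tenant

import (
	"context"
	"time"

	"github.com/cockroachdb/cockroach/pkg/util/stop"
)

// newThrottledDirectory is illustrative only. RefreshDelay bounds how often
// ReportFailure may trigger an endpoint refresh; a negative value disables
// the throttling entirely, as TestDeleteTenant does with RefreshDelay(-1).
func newThrottledDirectory(
	ctx context.Context, stopper *stop.Stopper, client DirectoryClient,
) (*Directory, error) {
	return NewDirectory(ctx, stopper, client, RefreshDelay(60*time.Minute))
}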
+ + // Now destroy the tenant and call ReportFailure again. This should be a no-op + // due to refresh throttling. + require.NoError(t, destroyTenant(tc, tenantID)) + require.NoError(t, dir.ReportFailure(ctx, tenantID, ip)) + ips, err = dir.LookupTenantIPs(ctx, tenantID) + require.NoError(t, err) + require.Equal(t, []string{ip}, ips) +} + +func createTenant(tc serverutils.TestClusterInterface, id roachpb.TenantID) error { + srv := tc.Server(0) + conn := srv.InternalExecutor().(*sql.InternalExecutor) + if _, err := conn.Exec( + context.Background(), + "testserver-create-tenant", + nil, /* txn */ + "SELECT crdb_internal.create_tenant($1)", + id.ToUint64(), + ); err != nil { + return err + } + return nil +} + +func destroyTenant(tc serverutils.TestClusterInterface, id roachpb.TenantID) error { + srv := tc.Server(0) + conn := srv.InternalExecutor().(*sql.InternalExecutor) + if _, err := conn.Exec( + context.Background(), + "testserver-destroy-tenant", + nil, /* txn */ + "SELECT crdb_internal.destroy_tenant($1)", + id.ToUint64(), + ); err != nil { + return err + } + return nil +} + +func startTenant( + ctx context.Context, srv serverutils.TestServerInterface, id uint64, +) (*Process, error) { + log.TestingClearServerIdentifiers() + tenantStopper := NewSubStopper(srv.Stopper()) + tenant, err := srv.StartTenant( + ctx, + base.TestTenantArgs{ + Existing: true, + TenantID: roachpb.MakeTenantID(id), + ForceInsecure: true, + Stopper: tenantStopper, + }) + if err != nil { + return nil, err + } + sqlAddr, err := net.ResolveTCPAddr("tcp", tenant.SQLAddr()) + if err != nil { + return nil, err + } + return &Process{SQL: sqlAddr, Stopper: tenantStopper}, nil +} + +// newTestDirectory sets up a directory that uses a client connected to a test +// directory server that manages tenants connected to a backing KV server. +func newTestDirectory( + t *testing.T, opts ...DirOption, +) (tc serverutils.TestClusterInterface, directory *Directory, tds *TestDirectoryServer) { + tc = serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{ + // We need to start the cluster insecurely so that we don't have to + // care about TLS settings for the RPC client connection. + ServerArgs: base.TestServerArgs{ + Insecure: true, + }, + }) + clusterStopper := tc.Stopper() + tds = NewTestDirectoryServer(clusterStopper) + tds.TenantStarterFunc = func(ctx context.Context, tenantID uint64) (*Process, error) { + t.Logf("starting tenant %d", tenantID) + process, err := startTenant(ctx, tc.Server(0), tenantID) + if err != nil { + return nil, err + } + t.Logf("tenant %d started", tenantID) + return process, nil + } + + listenPort, err := net.Listen("tcp", ":0") + require.NoError(t, err) + clusterStopper.AddCloser(stop.CloserFn(func() { require.NoError(t, listenPort.Close()) })) + go func() { _ = tds.Serve(listenPort) }() + + // Set up the directory. + directorySrvAddr := listenPort.Addr() + conn, err := grpc.Dial(directorySrvAddr.String(), grpc.WithInsecure()) + require.NoError(t, err) + // nolint:grpcconnclose + clusterStopper.AddCloser(stop.CloserFn(func() { require.NoError(t, conn.Close() /* nolint:grpcconnclose */) })) + client := NewDirectoryClient(conn) + directory, err = NewDirectory(context.Background(), clusterStopper, client, opts...) + require.NoError(t, err) + + return +} diff --git a/pkg/ccl/sqlproxyccl/tenant/entry.go b/pkg/ccl/sqlproxyccl/tenant/entry.go new file mode 100644 index 000000000000..c8e7edb415df --- /dev/null +++ b/pkg/ccl/sqlproxyccl/tenant/entry.go @@ -0,0 +1,298 @@ +// Copyright 2021 The Cockroach Authors.
+// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package tenant + +import ( + "context" + "fmt" + "time" + + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/cockroachdb/errors" +) + +// tenantEntry is an entry in the tenant directory that records information +// about a single tenant, including its ID, cluster name, and the IP addresses for +// available endpoints. +type tenantEntry struct { + // These fields can be read by callers without synchronization, since + // they're written once during initialization, and are immutable thereafter. + + // TenantID is the identifier for this tenant, which is unique within a CRDB + // cluster. + TenantID roachpb.TenantID + + // Full name of the tenant's cluster, e.g. dim-dog. + ClusterName string + + // RefreshDelay is the minimum amount of time that must elapse between + // attempts to refresh endpoints for this tenant after ReportFailure is called. + RefreshDelay time.Duration + + // initialized is set to true once Initialize has been called. + initialized bool + + // initError is set to any error that occurs in Initialize (or nil if no + // error occurred). + initError error + + // endpoints synchronizes access to information about the tenant's SQL endpoints. + // These fields can be updated over time, so a lock must be obtained before + // accessing them. + endpoints struct { + syncutil.Mutex + ips []string + } + + // calls synchronizes calls to the directory server API for this tenant (e.g. calls to + // RefreshEndpoints). Synchronization is needed to ensure that only one thread at + // a time is calling on behalf of a tenant, and that calls are rate limited + // to prevent storms. + calls struct { + syncutil.Mutex + lastRefresh time.Time + } +} + +// Initialize fetches metadata about a tenant, such as its cluster name, and stores +// that in the entry. After this is called once, all future calls return the +// same result (and do nothing). +func (e *tenantEntry) Initialize(ctx context.Context, client DirectoryClient) error { + // Synchronize multiple threads trying to initialize. Only the first thread + // does the initialization. + e.calls.Lock() + defer e.calls.Unlock() + + // If Initialize has already been called, return any error that occurred. + if e.initialized { + return e.initError + } + + tenantResp, err := client.GetTenant(ctx, &GetTenantRequest{TenantID: e.TenantID.ToUint64()}) + if err != nil { + e.initialized = true + e.initError = err + return err + } + + e.ClusterName = tenantResp.ClusterName + + e.initialized = true + return nil +}
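(Aside, not part of the patch: the expected lifecycle of a tenantEntry is Initialize once, then ChooseEndpointIP per connection attempt, with RefreshEndpoints driven by failure reports. A sketch under those assumptions; the wrapper function name is invented, and the real call sites live in directory.go elsewhere in this patch:)

package tenant

import (
	"context"
	"time"

	"github.com/cockroachdb/cockroach/pkg/roachpb"
)

// lookupTenantIP is illustrative only: initialize the entry, then ask for an
// endpoint IP, blocking (and resuming the tenant) if none is available yet.
func lookupTenantIP(
	ctx context.Context, client DirectoryClient, id roachpb.TenantID,
) (string, error) {
	entry := &tenantEntry{TenantID: id, RefreshDelay: 100 * time.Millisecond}
	if err := entry.Initialize(ctx, client); err != nil {
		return "", err
	}
	// errorIfNoEndpoints=false: block until an endpoint exists rather than fail.
	return entry.ChooseEndpointIP(ctx, client, false /* errorIfNoEndpoints */)
}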
+ +// RefreshEndpoints makes a synchronous directory server call to fetch the latest information +// about the tenant's available endpoints, such as their IP addresses. +func (e *tenantEntry) RefreshEndpoints(ctx context.Context, client DirectoryClient) error { + if !e.initialized { + return errors.AssertionFailedf("entry for tenant %d is not initialized", e.TenantID) + } + + // Lock so that only one thread at a time will refresh, since there's no + // point in multiple threads doing it within a short span of time - it's + // likely nothing has changed. + e.calls.Lock() + defer e.calls.Unlock() + + // If refreshed recently, no-op. + if !e.canRefreshLocked() { + return nil + } + + log.Infof(ctx, "refreshing tenant %d endpoints", e.TenantID) + + _, err := e.fetchEndpointsLocked(ctx, client) + return err +} + +// ChooseEndpointIP returns the IP address of one of this tenant's available endpoints. +// If the tenant is suspended and no endpoints are available, then +// ChooseEndpointIP will trigger resumption of the tenant and return the IP address +// of the new endpoint. Note that resuming a tenant requires directory server calls, so +// ChooseEndpointIP can block for some time, until the resumption process is +// complete. However, if errorIfNoEndpoints is true, then ChooseEndpointIP returns an +// error if there are no endpoints available rather than blocking. +// +// TODO(andyk): Use better load-balancing algorithm once tenants can have more +// than one endpoint. +func (e *tenantEntry) ChooseEndpointIP( + ctx context.Context, client DirectoryClient, errorIfNoEndpoints bool, +) (string, error) { + if !e.initialized { + return "", errors.AssertionFailedf("entry for tenant %d is not initialized", e.TenantID) + } + + ips := e.getEndpointIPs() + if len(ips) == 0 { + // There are no known endpoint IP addresses, so fetch endpoint information from + // the directory server. Resume the tenant if it is suspended; that will + // always result in at least one endpoint IP address (or an error). + var err error + if ips, err = e.ensureTenantEndpoint(ctx, client, errorIfNoEndpoints); err != nil { + return "", err + } + } + return ips[0], nil +} + +// AddEndpointIP inserts the given IP address into the tenant's list of endpoint IPs. If +// it is already present, then AddEndpointIP returns false. +func (e *tenantEntry) AddEndpointIP(ip string) bool { + e.endpoints.Lock() + defer e.endpoints.Unlock() + + for _, existing := range e.endpoints.ips { + if existing == ip { + return false + } + } + + e.endpoints.ips = append(e.endpoints.ips, ip) + return true +} + +// RemoveEndpointIP removes the given IP address from the tenant's list of endpoint IPs. +// If it was not present, RemoveEndpointIP returns false. +func (e *tenantEntry) RemoveEndpointIP(ip string) bool { + e.endpoints.Lock() + defer e.endpoints.Unlock() + + for i, existing := range e.endpoints.ips { + if existing == ip { + copy(e.endpoints.ips[i:], e.endpoints.ips[i+1:]) + e.endpoints.ips = e.endpoints.ips[:len(e.endpoints.ips)-1] + return true + } + } + return false +} + +// getEndpointIPs returns the current list of endpoint IP addresses, holding the +// endpoints lock for the duration of the read. +func (e *tenantEntry) getEndpointIPs() []string { + e.endpoints.Lock() + defer e.endpoints.Unlock() + return e.endpoints.ips +} + +// ensureTenantEndpoint ensures that at least one SQL process exists for this +// tenant, and is ready for connection attempts to its IP address. If +// errorIfNoEndpoints is true, then ensureTenantEndpoint returns an error if there are no +// endpoints available rather than blocking. +func (e *tenantEntry) ensureTenantEndpoint( + ctx context.Context, client DirectoryClient, errorIfNoEndpoints bool, +) (ips []string, err error) { + const retryDelay = 100 * time.Millisecond + + e.calls.Lock() + defer e.calls.Unlock() + + // If an IP address is already available, nothing more to do. Check this
+ // immediately after obtaining the lock so that only the first thread does + // the work to get information about the tenant. + ips = e.getEndpointIPs() + if len(ips) != 0 { + return ips, nil + } + + // Get an up-to-date list of endpoints for the tenant from the directory server. + resp, err := client.ListEndpoints(ctx, &ListEndpointsRequest{TenantID: e.TenantID.ToUint64()}) + if err != nil { + return nil, err + } + + for { + // Check for context cancellation or timeout. + if err = ctx.Err(); err != nil { + return nil, err + } + + // Check if the tenant needs to be resumed. + if len(resp.Endpoints) == 0 { + log.Infof(ctx, "resuming tenant %d", e.TenantID) + + if _, err := client.EnsureEndpoint( + ctx, &EnsureEndpointRequest{e.TenantID.ToUint64()}, + ); err != nil { + return nil, err + } + } + + // Get endpoint information for the newly resumed tenant. Except in rare race + // conditions, this is expected to immediately find an IP address, since + // the above call started a tenant process that already has an IP address. + ips, err = e.fetchEndpointsLocked(ctx, client) + if err != nil { + return nil, err + } + if len(ips) != 0 { + break + } + + // In the rare case where no IP address is ready, wait for a bit before + // retrying. + if errorIfNoEndpoints { + return nil, fmt.Errorf("no endpoints available for tenant %s", e.TenantID) + } + sleepContext(ctx, retryDelay) + } + + return ips, nil +} + +// fetchEndpointsLocked makes a synchronous directory server call to get the latest +// information about the tenant's available endpoints, such as their IP addresses. +// +// NOTE: Caller must lock the "calls" mutex before calling fetchEndpointsLocked. +func (e *tenantEntry) fetchEndpointsLocked( + ctx context.Context, client DirectoryClient, +) (ips []string, err error) { + // List the endpoints for the given tenant. + list, err := client.ListEndpoints(ctx, &ListEndpointsRequest{e.TenantID.ToUint64()}) + if err != nil { + return nil, err + } + + // Get the updated list of running process endpoint IP addresses and save it to the entry. + ips = make([]string, 0, len(list.Endpoints)) + for i := range list.Endpoints { + endpoint := list.Endpoints[i] + ips = append(ips, endpoint.IP) + } + + // Need to lock in case another thread is reading the IP addresses (e.g. in + // ChooseEndpointIP). + e.endpoints.Lock() + defer e.endpoints.Unlock() + e.endpoints.ips = ips + + if len(ips) != 0 { + log.Infof(ctx, "fetched IP addresses for tenant %d: %v", e.TenantID, ips) + } + + return ips, nil +} + +// canRefreshLocked returns true if at least RefreshDelay has elapsed since the +// last time the tenant endpoint information was refreshed. This has the effect of +// rate limiting RefreshEndpoints calls. +// +// NOTE: Caller must lock the "calls" mutex before calling canRefreshLocked. +func (e *tenantEntry) canRefreshLocked() bool { + now := timeutil.Now() + if now.Sub(e.calls.lastRefresh) < e.RefreshDelay { + return false + } + e.calls.lastRefresh = now + return true +} diff --git a/pkg/ccl/sqlproxyccl/tenant/main_test.go b/pkg/ccl/sqlproxyccl/tenant/main_test.go new file mode 100644 index 000000000000..3a25676ae8dd --- /dev/null +++ b/pkg/ccl/sqlproxyccl/tenant/main_test.go @@ -0,0 +1,34 @@ +// Copyright 2018 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License.
You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package tenant + +import ( + "os" + "testing" + + _ "github.com/cockroachdb/cockroach/pkg/ccl" + "github.com/cockroachdb/cockroach/pkg/ccl/utilccl" + "github.com/cockroachdb/cockroach/pkg/security" + "github.com/cockroachdb/cockroach/pkg/security/securitytest" + "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/testcluster" + "github.com/cockroachdb/cockroach/pkg/util/randutil" +) + +func TestMain(m *testing.M) { + defer utilccl.TestingEnableEnterprise()() + security.SetAssetLoader(securitytest.EmbeddedAssets) + randutil.SeedForTests() + serverutils.InitTestServerFactory(server.TestServerFactory) + serverutils.InitTestClusterFactory(testcluster.TestClusterFactory) + os.Exit(m.Run()) +} + +//go:generate ../../../util/leaktest/add-leaktest.sh *_test.go diff --git a/pkg/ccl/sqlproxyccl/tenant/mocks_generated.go b/pkg/ccl/sqlproxyccl/tenant/mocks_generated.go new file mode 100644 index 000000000000..15c887ef9ea2 --- /dev/null +++ b/pkg/ccl/sqlproxyccl/tenant/mocks_generated.go @@ -0,0 +1,240 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/tenant (interfaces: DirectoryClient,Directory_WatchEndpointsClient) + +// Package tenant is a generated GoMock package. +package tenant + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + grpc "google.golang.org/grpc" + metadata "google.golang.org/grpc/metadata" +) + +// MockDirectoryClient is a mock of DirectoryClient interface. +type MockDirectoryClient struct { + ctrl *gomock.Controller + recorder *MockDirectoryClientMockRecorder +} + +// MockDirectoryClientMockRecorder is the mock recorder for MockDirectoryClient. +type MockDirectoryClientMockRecorder struct { + mock *MockDirectoryClient +} + +// NewMockDirectoryClient creates a new mock instance. +func NewMockDirectoryClient(ctrl *gomock.Controller) *MockDirectoryClient { + mock := &MockDirectoryClient{ctrl: ctrl} + mock.recorder = &MockDirectoryClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDirectoryClient) EXPECT() *MockDirectoryClientMockRecorder { + return m.recorder +} + +// EnsureEndpoint mocks base method. +func (m *MockDirectoryClient) EnsureEndpoint(arg0 context.Context, arg1 *EnsureEndpointRequest, arg2 ...grpc.CallOption) (*EnsureEndpointResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "EnsureEndpoint", varargs...) + ret0, _ := ret[0].(*EnsureEndpointResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// EnsureEndpoint indicates an expected call of EnsureEndpoint. +func (mr *MockDirectoryClientMockRecorder) EnsureEndpoint(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnsureEndpoint", reflect.TypeOf((*MockDirectoryClient)(nil).EnsureEndpoint), varargs...) +} + +// GetTenant mocks base method. 
+func (m *MockDirectoryClient) GetTenant(arg0 context.Context, arg1 *GetTenantRequest, arg2 ...grpc.CallOption) (*GetTenantResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetTenant", varargs...) + ret0, _ := ret[0].(*GetTenantResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTenant indicates an expected call of GetTenant. +func (mr *MockDirectoryClientMockRecorder) GetTenant(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTenant", reflect.TypeOf((*MockDirectoryClient)(nil).GetTenant), varargs...) +} + +// ListEndpoints mocks base method. +func (m *MockDirectoryClient) ListEndpoints(arg0 context.Context, arg1 *ListEndpointsRequest, arg2 ...grpc.CallOption) (*ListEndpointsResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "ListEndpoints", varargs...) + ret0, _ := ret[0].(*ListEndpointsResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListEndpoints indicates an expected call of ListEndpoints. +func (mr *MockDirectoryClientMockRecorder) ListEndpoints(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListEndpoints", reflect.TypeOf((*MockDirectoryClient)(nil).ListEndpoints), varargs...) +} + +// WatchEndpoints mocks base method. +func (m *MockDirectoryClient) WatchEndpoints(arg0 context.Context, arg1 *WatchEndpointsRequest, arg2 ...grpc.CallOption) (Directory_WatchEndpointsClient, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "WatchEndpoints", varargs...) + ret0, _ := ret[0].(Directory_WatchEndpointsClient) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// WatchEndpoints indicates an expected call of WatchEndpoints. +func (mr *MockDirectoryClientMockRecorder) WatchEndpoints(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WatchEndpoints", reflect.TypeOf((*MockDirectoryClient)(nil).WatchEndpoints), varargs...) +} + +// MockDirectory_WatchEndpointsClient is a mock of Directory_WatchEndpointsClient interface. +type MockDirectory_WatchEndpointsClient struct { + ctrl *gomock.Controller + recorder *MockDirectory_WatchEndpointsClientMockRecorder +} + +// MockDirectory_WatchEndpointsClientMockRecorder is the mock recorder for MockDirectory_WatchEndpointsClient. +type MockDirectory_WatchEndpointsClientMockRecorder struct { + mock *MockDirectory_WatchEndpointsClient +} + +// NewMockDirectory_WatchEndpointsClient creates a new mock instance. +func NewMockDirectory_WatchEndpointsClient(ctrl *gomock.Controller) *MockDirectory_WatchEndpointsClient { + mock := &MockDirectory_WatchEndpointsClient{ctrl: ctrl} + mock.recorder = &MockDirectory_WatchEndpointsClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. 
+func (m *MockDirectory_WatchEndpointsClient) EXPECT() *MockDirectory_WatchEndpointsClientMockRecorder { + return m.recorder +} + +// CloseSend mocks base method. +func (m *MockDirectory_WatchEndpointsClient) CloseSend() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseSend") + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseSend indicates an expected call of CloseSend. +func (mr *MockDirectory_WatchEndpointsClientMockRecorder) CloseSend() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseSend", reflect.TypeOf((*MockDirectory_WatchEndpointsClient)(nil).CloseSend)) +} + +// Context mocks base method. +func (m *MockDirectory_WatchEndpointsClient) Context() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Context indicates an expected call of Context. +func (mr *MockDirectory_WatchEndpointsClientMockRecorder) Context() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockDirectory_WatchEndpointsClient)(nil).Context)) +} + +// Header mocks base method. +func (m *MockDirectory_WatchEndpointsClient) Header() (metadata.MD, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Header") + ret0, _ := ret[0].(metadata.MD) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Header indicates an expected call of Header. +func (mr *MockDirectory_WatchEndpointsClientMockRecorder) Header() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Header", reflect.TypeOf((*MockDirectory_WatchEndpointsClient)(nil).Header)) +} + +// Recv mocks base method. +func (m *MockDirectory_WatchEndpointsClient) Recv() (*WatchEndpointsResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Recv") + ret0, _ := ret[0].(*WatchEndpointsResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Recv indicates an expected call of Recv. +func (mr *MockDirectory_WatchEndpointsClientMockRecorder) Recv() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Recv", reflect.TypeOf((*MockDirectory_WatchEndpointsClient)(nil).Recv)) +} + +// RecvMsg mocks base method. +func (m *MockDirectory_WatchEndpointsClient) RecvMsg(arg0 interface{}) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RecvMsg", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// RecvMsg indicates an expected call of RecvMsg. +func (mr *MockDirectory_WatchEndpointsClientMockRecorder) RecvMsg(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecvMsg", reflect.TypeOf((*MockDirectory_WatchEndpointsClient)(nil).RecvMsg), arg0) +} + +// SendMsg mocks base method. +func (m *MockDirectory_WatchEndpointsClient) SendMsg(arg0 interface{}) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendMsg", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendMsg indicates an expected call of SendMsg. +func (mr *MockDirectory_WatchEndpointsClientMockRecorder) SendMsg(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMsg", reflect.TypeOf((*MockDirectory_WatchEndpointsClient)(nil).SendMsg), arg0) +} + +// Trailer mocks base method. 
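(Aside, not part of the patch: the MockGen types above are consumed in unit tests through gomock expectations. A minimal sketch; the test name is invented:)

package tenant

import (
	"context"
	"testing"

	"github.com/golang/mock/gomock"
)

// TestGetTenantWithMock is illustrative only: program the mock to answer one
// GetTenant call, then exercise it as a DirectoryClient.
func TestGetTenantWithMock(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	client := NewMockDirectoryClient(ctrl)
	client.EXPECT().
		GetTenant(gomock.Any(), &GetTenantRequest{TenantID: 10}).
		Return(&GetTenantResponse{ClusterName: "tenant-cluster"}, nil)

	resp, err := client.GetTenant(context.Background(), &GetTenantRequest{TenantID: 10})
	if err != nil || resp.ClusterName != "tenant-cluster" {
		t.Fatalf("unexpected result: %v, %v", resp, err)
	}
}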
+func (m *MockDirectory_WatchEndpointsClient) Trailer() metadata.MD { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Trailer") + ret0, _ := ret[0].(metadata.MD) + return ret0 +} + +// Trailer indicates an expected call of Trailer. +func (mr *MockDirectory_WatchEndpointsClientMockRecorder) Trailer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Trailer", reflect.TypeOf((*MockDirectory_WatchEndpointsClient)(nil).Trailer)) +} diff --git a/pkg/ccl/sqlproxyccl/tenant/test_directory_svr.go b/pkg/ccl/sqlproxyccl/tenant/test_directory_svr.go new file mode 100644 index 000000000000..d2e4dd7ec624 --- /dev/null +++ b/pkg/ccl/sqlproxyccl/tenant/test_directory_svr.go @@ -0,0 +1,352 @@ +// Copyright 2021 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package tenant + +import ( + "bytes" + "container/list" + "context" + "fmt" + "net" + "os" + "os/exec" + "time" + + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/errors" + "github.com/cockroachdb/logtags" + "google.golang.org/grpc" +) + +// Making sure that TestDirectoryServer implements the DirectoryServer interface. +var _ DirectoryServer = (*TestDirectoryServer)(nil) + +// Process stores information about a running tenant process. +type Process struct { + Stopper *stop.Stopper + Cmd *exec.Cmd + SQL net.Addr +} + +// NewSubStopper creates a new stopper that will be stopped +// when either the parent is stopped or its own Stop is called. +// The code is slightly more complicated than simply calling +// NewStopper followed by AddCloser since there is a possibility that between +// the two calls, the parent stopper completes a stop and then the leak detection +// may find a leaked stopper. +func NewSubStopper(parentStopper *stop.Stopper) *stop.Stopper { + mu := &syncutil.Mutex{} + var subStopper *stop.Stopper + parentStopper.AddCloser(stop.CloserFn(func() { + mu.Lock() + defer mu.Unlock() + if subStopper == nil { + subStopper = stop.NewStopper() + } + subStopper.Stop(context.Background()) + })) + mu.Lock() + defer mu.Unlock() + if subStopper == nil { + subStopper = stop.NewStopper() + } + return subStopper +} + +// TestDirectoryServer is a directory server implementation that is used +// for testing. +type TestDirectoryServer struct { + stopper *stop.Stopper + grpcServer *grpc.Server + // TenantStarterFunc will be used to launch a new tenant process. + TenantStarterFunc func(ctx context.Context, tenantID uint64) (*Process, error) + + // When both mutexes need to be held, the locking should always be + // proc first and listen second. + proc struct { + syncutil.RWMutex + processByAddrByTenantID map[uint64]map[net.Addr]*Process + } + listen struct { + syncutil.RWMutex + eventListeners *list.List + } +} + +// Get returns a tenant's list of endpoints and the process information for each endpoint.
+// Get returns a tenant's list of endpoints and the process information for
+// each endpoint.
+func (s *TestDirectoryServer) Get(id roachpb.TenantID) (result map[net.Addr]*Process) {
+	result = make(map[net.Addr]*Process)
+	s.proc.RLock()
+	defer s.proc.RUnlock()
+	processes, ok := s.proc.processByAddrByTenantID[id.ToUint64()]
+	if ok {
+		for k, v := range processes {
+			result[k] = v
+		}
+	}
+	return
+}
+
+// GetTenant returns tenant metadata for a given ID. It is hard-coded to
+// return "tenant-cluster" as the cluster name for every tenant.
+func (s *TestDirectoryServer) GetTenant(
+	_ context.Context, _ *GetTenantRequest,
+) (*GetTenantResponse, error) {
+	return &GetTenantResponse{
+		ClusterName: "tenant-cluster",
+	}, nil
+}
+
+// ListEndpoints returns a list of tenant process endpoints as well as the
+// status of the processes.
+func (s *TestDirectoryServer) ListEndpoints(
+	ctx context.Context, req *ListEndpointsRequest,
+) (*ListEndpointsResponse, error) {
+	ctx = logtags.AddTag(ctx, "tenant", req.TenantID)
+	s.proc.RLock()
+	defer s.proc.RUnlock()
+	return s.listLocked(ctx, req)
+}
+
+// WatchEndpoints returns a new stream that can be used to monitor server
+// activity.
+func (s *TestDirectoryServer) WatchEndpoints(
+	_ *WatchEndpointsRequest, server Directory_WatchEndpointsServer,
+) error {
+	select {
+	case <-s.stopper.ShouldQuiesce():
+		return context.Canceled
+	default:
+	}
+	// Make the channel with a small buffer to allow for a burst of
+	// notifications and a slow receiver.
+	c := make(chan *WatchEndpointsResponse, 10)
+	s.listen.Lock()
+	elem := s.listen.eventListeners.PushBack(c)
+	s.listen.Unlock()
+	err := s.stopper.RunTask(context.Background(), "watch-endpoints-server",
+		func(ctx context.Context) {
+		out:
+			for {
+				select {
+				case e, ok := <-c:
+					if !ok {
+						break out
+					}
+					if err := server.Send(e); err != nil {
+						s.listen.Lock()
+						s.listen.eventListeners.Remove(elem)
+						close(c)
+						s.listen.Unlock()
+						break out
+					}
+				case <-s.stopper.ShouldQuiesce():
+					s.listen.Lock()
+					s.listen.eventListeners.Remove(elem)
+					close(c)
+					s.listen.Unlock()
+					break out
+				}
+			}
+		})
+	return err
+}
+
+// notifyEventListenersLocked sends an event to all registered listeners,
+// closing and dropping any listener that cannot keep up. The listen mutex
+// must be held.
+func (s *TestDirectoryServer) notifyEventListenersLocked(req *WatchEndpointsResponse) {
+	for e := s.listen.eventListeners.Front(); e != nil; {
+		select {
+		case e.Value.(chan *WatchEndpointsResponse) <- req:
+			e = e.Next()
+		default:
+			// The receiver is unable to consume fast enough. Close the channel
+			// and remove it from the list.
+			eToClose := e
+			e = e.Next()
+			close(eToClose.Value.(chan *WatchEndpointsResponse))
+			s.listen.eventListeners.Remove(eToClose)
+		}
+	}
+}
+
+// EnsureEndpoint ensures that there is an active tenant process, starting a
+// new one if none is running. It returns an error if a new tenant process
+// cannot be started.
+func (s *TestDirectoryServer) EnsureEndpoint(
+	ctx context.Context, req *EnsureEndpointRequest,
+) (*EnsureEndpointResponse, error) {
+	select {
+	case <-s.stopper.ShouldQuiesce():
+		return nil, context.Canceled
+	default:
+	}
+
+	ctx = logtags.AddTag(ctx, "tenant", req.TenantID)
+
+	s.proc.Lock()
+	defer s.proc.Unlock()
+
+	lst, err := s.listLocked(ctx, &ListEndpointsRequest{req.TenantID})
+	if err != nil {
+		return nil, err
+	}
+	if len(lst.Endpoints) == 0 {
+		process, err := s.TenantStarterFunc(ctx, req.TenantID)
+		if err != nil {
+			return nil, err
+		}
+		s.registerInstanceLocked(req.TenantID, process)
+		process.Stopper.AddCloser(stop.CloserFn(func() {
+			s.deregisterInstance(req.TenantID, process.SQL)
+		}))
+	}
+
+	return &EnsureEndpointResponse{}, nil
+}
+
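+// As a rough, illustrative sketch of consuming the WatchEndpoints stream
+// above (assuming a context.Context named ctx, an already-established
+// *grpc.ClientConn named conn, and the NewDirectoryClient constructor
+// generated from directory.proto):
+//
+//	client := NewDirectoryClient(conn)
+//	stream, err := client.WatchEndpoints(ctx, &WatchEndpointsRequest{})
+//	for err == nil {
+//		var resp *WatchEndpointsResponse
+//		if resp, err = stream.Recv(); err == nil {
+//			log.Infof(ctx, "tenant %d endpoint %s: %v", resp.TenantID, resp.IP, resp.Typ)
+//		}
+//	}
+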
+// Serve serves incoming requests on the given listener.
+func (s *TestDirectoryServer) Serve(listener net.Listener) error {
+	return s.grpcServer.Serve(listener)
+}
+
+// NewTestDirectoryServer creates a new test directory server.
+func NewTestDirectoryServer(stopper *stop.Stopper) *TestDirectoryServer {
+	dir := &TestDirectoryServer{
+		grpcServer: grpc.NewServer(),
+		stopper:    stopper,
+	}
+	dir.TenantStarterFunc = dir.startTenantLocked
+	dir.proc.processByAddrByTenantID = map[uint64]map[net.Addr]*Process{}
+	dir.listen.eventListeners = list.New()
+	RegisterDirectoryServer(dir.grpcServer, dir)
+	return dir
+}
+
+// listLocked returns the list of endpoints for the given tenant. The proc
+// mutex must be held.
+func (s *TestDirectoryServer) listLocked(
+	_ context.Context, req *ListEndpointsRequest,
+) (*ListEndpointsResponse, error) {
+	processByAddr, ok := s.proc.processByAddrByTenantID[req.TenantID]
+	if !ok {
+		return &ListEndpointsResponse{}, nil
+	}
+	resp := ListEndpointsResponse{}
+	for addr := range processByAddr {
+		resp.Endpoints = append(resp.Endpoints, &Endpoint{IP: addr.String()})
+	}
+	return &resp, nil
+}
+
+// registerInstanceLocked records a newly started process for the given
+// tenant and notifies any event listeners. The proc write mutex must be
+// held.
+func (s *TestDirectoryServer) registerInstanceLocked(tenantID uint64, process *Process) {
+	processByAddr, ok := s.proc.processByAddrByTenantID[tenantID]
+	if !ok {
+		processByAddr = map[net.Addr]*Process{}
+		s.proc.processByAddrByTenantID[tenantID] = processByAddr
+	}
+	processByAddr[process.SQL] = process
+
+	s.listen.RLock()
+	defer s.listen.RUnlock()
+	s.notifyEventListenersLocked(&WatchEndpointsResponse{
+		Typ:      ADDED,
+		IP:       process.SQL.String(),
+		TenantID: tenantID,
+	})
+}
+
+// deregisterInstance removes the tenant process at the given SQL address,
+// if present, and notifies any event listeners.
+func (s *TestDirectoryServer) deregisterInstance(tenantID uint64, sql net.Addr) {
+	s.proc.Lock()
+	defer s.proc.Unlock()
+	processByAddr, ok := s.proc.processByAddrByTenantID[tenantID]
+	if !ok {
+		return
+	}
+
+	if _, ok = processByAddr[sql]; ok {
+		delete(processByAddr, sql)
+
+		s.listen.RLock()
+		defer s.listen.RUnlock()
+		s.notifyEventListenersLocked(&WatchEndpointsResponse{
+			Typ:      DELETED,
+			IP:       sql.String(),
+			TenantID: tenantID,
+		})
+	}
+}
+
+// startTenantLocked is the default tenant process startup logic that runs
+// the cockroach executable out of process.
+func (s *TestDirectoryServer) startTenantLocked(
+	ctx context.Context, tenantID uint64,
+) (*Process, error) {
+	// A hackish way to have the sql tenant process listen on known ports:
+	// bind ephemeral ports to discover free addresses, then close the
+	// listeners (below) so that the tenant process can bind them itself.
+	sql, err := net.Listen("tcp", "")
+	if err != nil {
+		return nil, err
+	}
+	http, err := net.Listen("tcp", "")
+	if err != nil {
+		return nil, err
+	}
+	process := &Process{SQL: sql.Addr()}
+	args := []string{
+		"mt", "start-sql", "--kv-addrs=127.0.0.1:26257", "--idle-exit-after=30s",
+		fmt.Sprintf("--sql-addr=%s", sql.Addr().String()),
+		fmt.Sprintf("--http-addr=%s", http.Addr().String()),
+		fmt.Sprintf("--tenant-id=%d", tenantID),
+	}
+	if err = sql.Close(); err != nil {
+		return nil, err
+	}
+	if err = http.Close(); err != nil {
+		return nil, err
+	}
+
+	c := exec.Command("cockroach", args...)
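+	// The spawned process is the real cockroach binary in SQL-only mode
+	// ("mt start-sql"), pointed at a KV node assumed to listen on
+	// 127.0.0.1:26257. The env var set below should make the tenant trust
+	// the client remote address forwarded by the SQL proxy (inferred from
+	// the variable's name).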
+	process.Cmd = c
+	c.Env = append(os.Environ(), "COCKROACH_TRUST_CLIENT_PROVIDED_SQL_REMOTE_ADDR=true")
+
+	if c.Stdout != nil {
+		return nil, errors.New("exec: Stdout already set")
+	}
+	if c.Stderr != nil {
+		return nil, errors.New("exec: Stderr already set")
+	}
+	var b bytes.Buffer
+	c.Stdout = &b
+	c.Stderr = &b
+	err = c.Start()
+	if err != nil {
+		return nil, err
+	}
+	process.Stopper = NewSubStopper(s.stopper)
+	process.Stopper.AddCloser(stop.CloserFn(func() {
+		_ = c.Process.Kill()
+		s.deregisterInstance(tenantID, process.SQL)
+	}))
+	err = process.Stopper.RunAsyncTask(ctx, "cmd-wait", func(ctx context.Context) {
+		if err := c.Wait(); err != nil {
+			log.Infof(ctx, "finished %s with err %s", process.Cmd.Args, err)
+			log.Infof(ctx, "output %s", b.Bytes())
+			return
+		}
+		log.Infof(ctx, "finished %s with success", process.Cmd.Args)
+		process.Stopper.Stop(ctx)
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	// Wait here for the spawned cockroach tenant process to get ready.
+	// Ideally, we would check that it is up and connected to the KV host
+	// before returning.
+	time.Sleep(500 * time.Millisecond)
+	return process, nil
+}
diff --git a/pkg/server/testserver.go b/pkg/server/testserver.go
index 8aa0c8faade1..b5d22f85fca1 100644
--- a/pkg/server/testserver.go
+++ b/pkg/server/testserver.go
@@ -710,10 +710,8 @@ func SetupIdleMonitor(
 
 // StartTenant starts a SQL tenant communicating with this TestServer.
 func (ts *TestServer) StartTenant(
-	params base.TestTenantArgs,
+	ctx context.Context, params base.TestTenantArgs,
 ) (serverutils.TestTenantInterface, error) {
-	ctx := context.Background()
-
 	if !params.Existing {
 		if _, err := ts.InternalExecutor().(*sql.InternalExecutor).Exec(
 			ctx, "testserver-create-tenant", nil /* txn */, "SELECT crdb_internal.create_tenant($1)", params.TenantID.ToUint64(),
@@ -722,12 +720,25 @@ func (ts *TestServer) StartTenant(
 		}
 	}
 
+	rowCount, err := ts.InternalExecutor().(*sql.InternalExecutor).Exec(
+		ctx, "testserver-check-tenant-active", nil /* txn */,
+		"SELECT 1 FROM system.tenants WHERE id=$1 AND active=true",
+		params.TenantID.ToUint64(),
+	)
+	if err != nil {
+		return nil, err
+	}
+	if rowCount == 0 {
+		return nil, errors.New("tenant not found or not active")
+	}
+
 	st := cluster.MakeTestingClusterSettings()
 	sqlCfg := makeTestSQLConfig(st, params.TenantID)
 	sqlCfg.TenantKVAddrs = []string{ts.ServingRPCAddr()}
 	baseCfg := makeTestBaseConfig(st)
 	baseCfg.TestingKnobs = params.TestingKnobs
 	baseCfg.IdleExitAfter = params.IdleExitAfter
+	baseCfg.Insecure = params.ForceInsecure
 	if params.AllowSettingClusterSettings {
 		baseCfg.TestingKnobs.TenantTestingKnobs = &sql.TenantTestingKnobs{
 			ClusterSettingsUpdater: st.MakeUpdater(),
diff --git a/pkg/sql/logictest/logic.go b/pkg/sql/logictest/logic.go
index 959432a418ae..96772b72ac7e 100644
--- a/pkg/sql/logictest/logic.go
+++ b/pkg/sql/logictest/logic.go
@@ -1460,7 +1460,7 @@ func (t *logicTest) newCluster(serverArgs TestServerArgs) {
 	// Prevent a logging assertion that the server ID is initialized multiple times.
log.TestingClearServerIdentifiers() - tenant, err := t.cluster.Server(t.nodeIdx).StartTenant(tenantArgs) + tenant, err := t.cluster.Server(t.nodeIdx).StartTenant(context.Background(), tenantArgs) if err != nil { t.rootT.Fatalf("%+v", err) } diff --git a/pkg/testutils/lint/lint_test.go b/pkg/testutils/lint/lint_test.go index b7eecd099991..9533870b428f 100644 --- a/pkg/testutils/lint/lint_test.go +++ b/pkg/testutils/lint/lint_test.go @@ -818,6 +818,7 @@ func TestLint(t *testing.T) { ":!util/grpcutil/grpc_util_test.go", ":!server/testserver.go", ":!util/tracing/*_test.go", + ":!ccl/sqlproxyccl/tenant/test_directory_svr.go", ) if err != nil { t.Fatal(err) diff --git a/pkg/testutils/serverutils/test_server_shim.go b/pkg/testutils/serverutils/test_server_shim.go index 86f99d44798d..0e094c944625 100644 --- a/pkg/testutils/serverutils/test_server_shim.go +++ b/pkg/testutils/serverutils/test_server_shim.go @@ -218,7 +218,7 @@ type TestServerInterface interface { DiagnosticsReporter() interface{} // StartTenant spawns off tenant process connecting to this TestServer. - StartTenant(params base.TestTenantArgs) (TestTenantInterface, error) + StartTenant(ctx context.Context, params base.TestTenantArgs) (TestTenantInterface, error) // ScratchRange splits off a range suitable to be used as KV scratch space. // (it doesn't overlap system spans or SQL tables). @@ -340,7 +340,7 @@ func StartServerRaw(args base.TestServerArgs) (TestServerInterface, error) { func StartTenant( t testing.TB, ts TestServerInterface, params base.TestTenantArgs, ) (TestTenantInterface, *gosql.DB) { - tenant, err := ts.StartTenant(params) + tenant, err := ts.StartTenant(context.Background(), params) if err != nil { t.Fatal(err) }