From ab2a28604414dd9fb0f9077ff76a4b5e64f92459 Mon Sep 17 00:00:00 2001 From: Jeff Date: Thu, 23 Mar 2023 18:04:36 +0000 Subject: [PATCH] server: clean up sql session and instance on shutdown Remove the SQL server's system.sql_instances and system.sqlliveness rows on SQL server shutdown. The original motivation for this change is that it improves the behavior of DistSQL for serverless tenants: SQL servers would attempt to schedule DistSQL flows on sql_instances rows that belonged to shut-down servers. Cleaning up the session and instance has a handful of small benefits. - The sql_instances rows pre-allocated in a region are less likely to run out and trigger the slow cold start path. - Resources leased by the session, like jobs, may be reclaimed more quickly after the server shuts down. Fixes: CC-9095 Release note: none --- pkg/server/BUILD.bazel | 1 + pkg/server/drain.go | 12 +- pkg/server/drain_test.go | 46 ++++ .../instancestorage/instancereader_test.go | 7 +- .../instancestorage/instancestorage.go | 86 ++++--- .../instancestorage/instancestorage_test.go | 218 ++++++++---------- pkg/sql/sqlliveness/slinstance/slinstance.go | 66 +++++- .../sqlliveness/slinstance/slinstance_test.go | 27 ++- pkg/sql/sqlliveness/slstorage/BUILD.bazel | 1 + pkg/sql/sqlliveness/slstorage/slstorage.go | 30 +++ .../sqlliveness/slstorage/slstorage_test.go | 5 +- pkg/sql/sqlliveness/sqlliveness.go | 8 +- 12 files changed, 325 insertions(+), 182 deletions(-) diff --git a/pkg/server/BUILD.bazel b/pkg/server/BUILD.bazel index 2c61e55fed0d..252322f37995 100644 --- a/pkg/server/BUILD.bazel +++ b/pkg/server/BUILD.bazel @@ -445,6 +445,7 @@ go_test( "//pkg/base", "//pkg/base/serverident", "//pkg/build", + "//pkg/ccl/kvccl/kvtenantccl", "//pkg/cli/exit", "//pkg/clusterversion", "//pkg/config", diff --git a/pkg/server/drain.go b/pkg/server/drain.go index adb351a79f18..7d17504414ac 100644 --- a/pkg/server/drain.go +++ b/pkg/server/drain.go @@ -401,8 +401,16 @@ func (s *drainServer) drainClients( // errors/warnings, if any. log.Infof(ctx, "SQL server drained successfully; SQL queries cannot execute any more") - // FIXME(Jeff): Add code here to remove the sql_instances row or - // something similar. + session, err := s.sqlServer.sqlLivenessProvider.Release(ctx) + if err != nil { + return err + } + + instanceID := s.sqlServer.sqlIDContainer.SQLInstanceID() + err = s.sqlServer.sqlInstanceStorage.ReleaseInstance(ctx, session, instanceID) + if err != nil { + return err + } // Mark the node as fully drained.
s.sqlServer.gracefulDrainComplete.Set(true) diff --git a/pkg/server/drain_test.go b/pkg/server/drain_test.go index c91833638e9a..912d938c90c6 100644 --- a/pkg/server/drain_test.go +++ b/pkg/server/drain_test.go @@ -17,6 +17,7 @@ import ( "time" "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/ccl/kvccl/kvtenantccl" "github.com/cockroachdb/cockroach/pkg/server" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/sql" @@ -307,3 +308,48 @@ func getAdminClientForServer( _ = conn.Close() // nolint:grpcconnclose }, nil } + +func TestServerShutdownReleasesSession(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ + DisableDefaultTestTenant: true, + }) + defer s.Stopper().Stop(ctx) + + tenantArgs := base.TestTenantArgs{ + TenantID: serverutils.TestTenantID(), + } + + // Ensure the tenant connector is linked. + var _ = kvtenantccl.Connector{} + tenant, tenantSQLRaw := serverutils.StartTenant(t, s, tenantArgs) + defer tenant.Stopper().Stop(ctx) + tenantSQL := sqlutils.MakeSQLRunner(tenantSQLRaw) + + queryOwner := func(id base.SQLInstanceID) (owner *string) { + tenantSQL.QueryRow(t, "SELECT session_id FROM system.sql_instances WHERE id = $1", id).Scan(&owner) + return owner + } + + sessionExists := func(session string) bool { + rows := tenantSQL.QueryStr(t, "SELECT session_id FROM system.sqlliveness WHERE session_id = $1", session) + return 0 < len(rows) + } + + tmpTenant, err := s.StartTenant(ctx, tenantArgs) + require.NoError(t, err) + + tmpSQLInstance := tmpTenant.SQLInstanceID() + session := queryOwner(tmpSQLInstance) + require.NotNil(t, session) + require.True(t, sessionExists(*session)) + + require.NoError(t, tmpTenant.DrainClients(context.Background())) + + require.False(t, sessionExists(*session), "expected session %s to be deleted from the sqlliveness table, but it still exists", *session) + require.Nil(t, queryOwner(tmpSQLInstance), "expected sql_instance %d to have no owning session_id", tmpSQLInstance) +} diff --git a/pkg/sql/sqlinstance/instancestorage/instancereader_test.go b/pkg/sql/sqlinstance/instancestorage/instancereader_test.go index 0dbfab6b08ba..4167100c7e9d 100644 --- a/pkg/sql/sqlinstance/instancestorage/instancereader_test.go +++ b/pkg/sql/sqlinstance/instancestorage/instancereader_test.go @@ -22,7 +22,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/server/settingswatcher" "github.com/cockroachdb/cockroach/pkg/sql/catalog/desctestutils" - "github.com/cockroachdb/cockroach/pkg/sql/enum" "github.com/cockroachdb/cockroach/pkg/sql/sqlinstance" "github.com/cockroachdb/cockroach/pkg/sql/sqlinstance/instancestorage" "github.com/cockroachdb/cockroach/pkg/sql/sqlliveness" @@ -145,7 +144,6 @@ func TestReader(t *testing.T) { require.NoError(t, reader.WaitForStarted(ctx)) // Set up expected test data. - region := enum.One instanceIDs := []base.SQLInstanceID{1, 2, 3} rpcAddresses := []string{"addr1", "addr2", "addr3"} sqlAddresses := []string{"addr4", "addr5", "addr6"} @@ -259,7 +257,7 @@ func TestReader(t *testing.T) { // Release an instance and verify only active instances are returned.
{ - err := storage.ReleaseInstanceID(ctx, region, instanceIDs[0]) + err := storage.ReleaseInstance(ctx, sessionIDs[0], instanceIDs[0]) if err != nil { t.Fatal(err) } @@ -314,7 +312,6 @@ func TestReader(t *testing.T) { reader.Start(ctx, sqlinstance.InstanceInfo{}) require.NoError(t, reader.WaitForStarted(ctx)) // Create three instances and release one. - region := enum.One instanceIDs := [...]base.SQLInstanceID{1, 2, 3} rpcAddresses := [...]string{"addr1", "addr2", "addr3"} sqlAddresses := [...]string{"addr4", "addr5", "addr6"} @@ -367,7 +364,7 @@ func TestReader(t *testing.T) { // Verify request for released instance data results in an error. { - err := storage.ReleaseInstanceID(ctx, region, instanceIDs[0]) + err := storage.ReleaseInstance(ctx, sessionIDs[0], instanceIDs[0]) if err != nil { t.Fatal(err) } diff --git a/pkg/sql/sqlinstance/instancestorage/instancestorage.go b/pkg/sql/sqlinstance/instancestorage/instancestorage.go index a02378bcf1ff..52ae8a879c0e 100644 --- a/pkg/sql/sqlinstance/instancestorage/instancestorage.go +++ b/pkg/sql/sqlinstance/instancestorage/instancestorage.go @@ -188,6 +188,62 @@ func (s *Storage) CreateInstance( return s.createInstanceRow(ctx, sessionID, sessionExpiration, rpcAddr, sqlAddr, locality, binaryVersion, noNodeID) } +// ReleaseInstance deallocates the instance id iff it is currently owned by the +// provided sessionID. +func (s *Storage) ReleaseInstance( + ctx context.Context, sessionID sqlliveness.SessionID, instanceID base.SQLInstanceID, +) error { + return s.db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) error { + version, err := s.versionGuard(ctx, txn) + if err != nil { + return err + } + + region, _, err := slstorage.UnsafeDecodeSessionID(sessionID) + if err != nil { + return errors.Wrap(err, "unable to determine region for sql_instance") + } + + readCodec := s.getReadCodec(&version) + + key := readCodec.encodeKey(region, instanceID) + kv, err := txn.Get(ctx, key) + if err != nil { + return err + } + + instance, err := readCodec.decodeRow(kv.Key, kv.Value) + if err != nil { + return err + } + + if instance.sessionID != sessionID { + // Great! The session was already released or released and + // claimed by another server. + return nil + } + + batch := txn.NewBatch() + + value, err := readCodec.encodeAvailableValue() + if err != nil { + return err + } + batch.Put(key, value) + + if dualCodec := s.getDualWriteCodec(&version); dualCodec != nil { + dualKey := dualCodec.encodeKey(region, instanceID) + dualValue, err := dualCodec.encodeAvailableValue() + if err != nil { + return err + } + batch.Put(dualKey, dualValue) + } + + return txn.CommitInBatch(ctx, batch) + }) +} + func (s *Storage) createInstanceRow( ctx context.Context, sessionID sqlliveness.SessionID, @@ -507,36 +563,6 @@ func (s *Storage) getInstanceRows( return instances, nil } -// ReleaseInstanceID deletes an instance ID record. The instance ID becomes -// available to be reused by another SQL pod of the same tenant. -// TODO(jeffswenson): delete this, it is unused. -func (s *Storage) ReleaseInstanceID( - ctx context.Context, region []byte, id base.SQLInstanceID, -) error { - // TODO(andrei): Ensure that we do not delete an instance ID that we no longer - // own, instead of deleting blindly. 
- ctx = multitenant.WithTenantCostControlExemption(ctx) - err := s.db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) error { - version, err := s.versionGuard(ctx, txn) - if err != nil { - return err - } - - b := txn.NewBatch() - - readCodec := s.getReadCodec(&version) - b.Del(readCodec.encodeKey(region, id)) - if dualCodec := s.getDualWriteCodec(&version); dualCodec != nil { - b.Del(dualCodec.encodeKey(region, id)) - } - return txn.CommitInBatch(ctx, b) - }) - if err != nil { - return errors.Wrapf(err, "could not delete instance %d", id) - } - return nil -} - // RunInstanceIDReclaimLoop runs a background task that allocates available // instance IDs and reclaim expired ones within the sql_instances table. func (s *Storage) RunInstanceIDReclaimLoop( diff --git a/pkg/sql/sqlinstance/instancestorage/instancestorage_test.go b/pkg/sql/sqlinstance/instancestorage/instancestorage_test.go index 70c08663aca2..3d5513b5163f 100644 --- a/pkg/sql/sqlinstance/instancestorage/instancestorage_test.go +++ b/pkg/sql/sqlinstance/instancestorage/instancestorage_test.go @@ -107,28 +107,56 @@ func TestStorage(t *testing.T) { stopper, storage, slStorage, clock := setup(t) defer stopper.Stop(ctx) - // Create three instances and release one. - region := enum.One - instanceIDs := [...]base.SQLInstanceID{1, 2, 3, 4, 5} - rpcAddresses := [...]string{"addr1", "addr2", "addr3", "addr4", "addr5"} - sqlAddresses := [...]string{"addr6", "addr7", "addr8", "addr9", "addr10"} - sessionIDs := [...]sqlliveness.SessionID{makeSession(), makeSession(), makeSession(), makeSession(), makeSession()} - localities := [...]roachpb.Locality{ - {Tiers: []roachpb.Tier{{Key: "region", Value: "region1"}}}, - {Tiers: []roachpb.Tier{{Key: "region", Value: "region2"}}}, - {Tiers: []roachpb.Tier{{Key: "region", Value: "region3"}}}, - {Tiers: []roachpb.Tier{{Key: "region", Value: "region4"}}}, - {Tiers: []roachpb.Tier{{Key: "region", Value: "region5"}}}, - } - binaryVersions := []roachpb.Version{ - {Major: 22, Minor: 2}, {Major: 23, Minor: 1}, {Major: 23, Minor: 2}, {Major: 23, Minor: 3}, {Major: 24, Minor: 1}, - } sessionExpiry := clock.Now().Add(expiration.Nanoseconds(), 0) - for _, index := range []int{0, 1, 2} { - instance, err := storage.CreateInstance(ctx, sessionIDs[index], sessionExpiry, rpcAddresses[index], sqlAddresses[index], localities[index], binaryVersions[index]) + + makeInstance := func(id int) sqlinstance.InstanceInfo { + return sqlinstance.InstanceInfo{ + Region: enum.One, + InstanceID: base.SQLInstanceID(id), + InstanceSQLAddr: fmt.Sprintf("sql-addr-%d", id), + InstanceRPCAddr: fmt.Sprintf("rpc-addr-%d", id), + SessionID: makeSession(), + Locality: roachpb.Locality{Tiers: []roachpb.Tier{{Key: "region", Value: fmt.Sprintf("region-%d", id)}}}, + BinaryVersion: roachpb.Version{Major: 22, Minor: int32(id)}, + } + } + + createInstance := func(t *testing.T, instance sqlinstance.InstanceInfo) { + t.Helper() + + alive, err := slStorage.IsAlive(ctx, instance.SessionID) + require.NoError(t, err) + if !alive { + require.NoError(t, slStorage.Insert(ctx, instance.SessionID, sessionExpiry)) + } + + created, err := storage.CreateInstance(ctx, instance.SessionID, sessionExpiry, instance.InstanceRPCAddr, instance.InstanceSQLAddr, instance.Locality, instance.BinaryVersion) require.NoError(t, err) - require.NoError(t, slStorage.Insert(ctx, sessionIDs[index], sessionExpiry)) - require.Equal(t, instanceIDs[index], instance.InstanceID) + + require.Equal(t, instance, created) + } + + equalInstance := func(t *testing.T, expect 
sqlinstance.InstanceInfo, actual sqlinstance.InstanceInfo) { + require.Equal(t, expect.InstanceID, actual.InstanceID) + require.Equal(t, expect.SessionID, actual.SessionID) + require.Equal(t, expect.InstanceRPCAddr, actual.InstanceRPCAddr) + require.Equal(t, expect.InstanceSQLAddr, actual.InstanceSQLAddr) + require.Equal(t, expect.Locality, actual.Locality) + require.Equal(t, expect.BinaryVersion, actual.BinaryVersion) + } + + isAvailable := func(t *testing.T, instance sqlinstance.InstanceInfo, id base.SQLInstanceID) { + require.Equal(t, sqlinstance.InstanceInfo{InstanceID: id}, instance) + } + + var initialInstances []sqlinstance.InstanceInfo + for i := 1; i <= 5; i++ { + initialInstances = append(initialInstances, makeInstance(i)) + } + + // Create three instances and release one. + for _, instance := range initialInstances[:3] { + createInstance(t, instance) } // Verify all instances are returned by GetAllInstancesDataForTest. @@ -138,29 +166,16 @@ func TestStorage(t *testing.T) { require.NoError(t, err) require.Equal(t, preallocatedCount, len(instances)) for _, i := range []int{0, 1, 2} { - require.Equal(t, instanceIDs[i], instances[i].InstanceID) - require.Equal(t, sessionIDs[i], instances[i].SessionID) - require.Equal(t, rpcAddresses[i], instances[i].InstanceRPCAddr) - require.Equal(t, sqlAddresses[i], instances[i].InstanceSQLAddr) - require.Equal(t, localities[i], instances[i].Locality) - require.Equal(t, binaryVersions[i], instances[i].BinaryVersion) + equalInstance(t, initialInstances[i], instances[i]) } for _, i := range []int{3, 4} { - require.Equal(t, base.SQLInstanceID(i+1), instances[i].InstanceID) - require.Empty(t, instances[i].SessionID) - require.Empty(t, instances[i].InstanceRPCAddr) - require.Empty(t, instances[i].InstanceSQLAddr) - require.Empty(t, instances[i].Locality) - require.Empty(t, instances[i].BinaryVersion) + isAvailable(t, instances[i], initialInstances[i].InstanceID) } } // Create two more instances. - for _, index := range []int{3, 4} { - instance, err := storage.CreateInstance(ctx, sessionIDs[index], sessionExpiry, rpcAddresses[index], sqlAddresses[index], localities[index], binaryVersions[index]) - require.NoError(t, err) - require.NoError(t, slStorage.Insert(ctx, sessionIDs[index], sessionExpiry)) - require.Equal(t, instanceIDs[index], instance.InstanceID) + for _, instance := range initialInstances[3:] { + createInstance(t, instance) } // Verify all instances are returned by GetAllInstancesDataForTest. @@ -170,113 +185,59 @@ func TestStorage(t *testing.T) { require.NoError(t, err) require.Equal(t, preallocatedCount, len(instances)) for i := range instances { - require.Equal(t, instanceIDs[i], instances[i].InstanceID) - require.Equal(t, sessionIDs[i], instances[i].SessionID) - require.Equal(t, rpcAddresses[i], instances[i].InstanceRPCAddr) - require.Equal(t, sqlAddresses[i], instances[i].InstanceSQLAddr) - require.Equal(t, localities[i], instances[i].Locality) - require.Equal(t, binaryVersions[i], instances[i].BinaryVersion) + equalInstance(t, initialInstances[i], instances[i]) } } - // Release an instance and verify all instances are returned.
+ // Release an instance and verify the instance is available. { - require.NoError(t, storage.ReleaseInstanceID(ctx, region, instanceIDs[0])) + toRelease := initialInstances[0] + + // Call twice to ensure it is idempotent. + require.NoError(t, storage.ReleaseInstance(ctx, toRelease.SessionID, toRelease.InstanceID)) + require.NoError(t, storage.ReleaseInstance(ctx, toRelease.SessionID, toRelease.InstanceID)) + instances, err := storage.GetAllInstancesDataForTest(ctx) require.NoError(t, err) - require.Equal(t, preallocatedCount-1, len(instances)) + require.Equal(t, preallocatedCount, len(instances)) sortInstances(instances) - for i := 0; i < len(instances)-1; i++ { - require.Equal(t, instanceIDs[i+1], instances[i].InstanceID) - require.Equal(t, sessionIDs[i+1], instances[i].SessionID) - require.Equal(t, rpcAddresses[i+1], instances[i].InstanceRPCAddr) - require.Equal(t, sqlAddresses[i+1], instances[i].InstanceSQLAddr) - require.Equal(t, localities[i+1], instances[i].Locality) + + for i, instance := range instances { + if i == 0 { + isAvailable(t, instance, toRelease.InstanceID) + } else { + equalInstance(t, initialInstances[i], instance) + } + } + + // Re-allocate the instance. + createInstance(t, initialInstances[0]) } // Verify instance ID associated with an expired session gets reused. - sessionID6 := makeSession() - rpcAddr6 := "rpcAddr6" - sqlAddr6 := "sqlAddr6" - locality6 := roachpb.Locality{Tiers: []roachpb.Tier{{Key: "region", Value: "region6"}}} - binaryVersion6 := roachpb.Version{Major: 24, Minor: 3} + newInstance4 := makeInstance(1337) + newInstance4.InstanceID = initialInstances[4].InstanceID { - require.NoError(t, slStorage.Delete(ctx, sessionIDs[4])) - require.NoError(t, slStorage.Insert(ctx, sessionID6, sessionExpiry)) - instance, err := storage.CreateInstance(ctx, sessionID6, sessionExpiry, rpcAddr6, sqlAddr6, locality6, binaryVersion6) - require.NoError(t, err) - require.Equal(t, instanceIDs[4], instance.InstanceID) + require.NoError(t, slStorage.Delete(ctx, initialInstances[4].SessionID)) + + createInstance(t, newInstance4) + instances, err := storage.GetAllInstancesDataForTest(ctx) - sortInstances(instances) require.NoError(t, err) - require.Equal(t, preallocatedCount-1, len(instances)) - var foundIDs []base.SQLInstanceID + sortInstances(instances) + + require.Equal(t, len(initialInstances), len(instances)) for index, instance := range instances { - foundIDs = append(foundIDs, instance.InstanceID) + expect := initialInstances[index] - if index == 3 { + if index == 4 { - require.Equal(t, sessionID6, instance.SessionID) - require.Equal(t, rpcAddr6, instance.InstanceRPCAddr) - require.Equal(t, sqlAddr6, instance.InstanceSQLAddr) - require.Equal(t, locality6, instance.Locality) - require.Equal(t, binaryVersion6, instance.BinaryVersion) + expect = newInstance4 - continue } - require.Equal(t, sessionIDs[index+1], instance.SessionID) - require.Equal(t, rpcAddresses[index+1], instance.InstanceRPCAddr) - require.Equal(t, sqlAddresses[index+1], instance.InstanceSQLAddr) - require.Equal(t, localities[index+1], instance.Locality) - require.Equal(t, binaryVersions[index+1], instance.BinaryVersion) + equalInstance(t, expect, instance) } - require.Equal(t, []base.SQLInstanceID{2, 3, 4, 5}, foundIDs) } - // Verify released instance ID gets reused.
- { - newSessionID := makeSession() - newRPCAddr := "rpcAddr7" - newSQLAddr := "sqlAddr7" - newBinaryVersion := roachpb.Version{Major: 29, Minor: 4} - newLocality := roachpb.Locality{Tiers: []roachpb.Tier{{Key: "region", Value: "region7"}}} - instance, err := storage.CreateInstance(ctx, newSessionID, sessionExpiry, newRPCAddr, newSQLAddr, newLocality, newBinaryVersion) - require.NoError(t, err) - require.Equal(t, instanceIDs[0], instance.InstanceID) - instances, err := storage.GetAllInstancesDataForTest(ctx) - sortInstances(instances) - require.NoError(t, err) - require.Equal(t, preallocatedCount+4, len(instances)) - for index := 0; index < len(instanceIDs); index++ { - require.Equal(t, instanceIDs[index], instances[index].InstanceID) - switch index { - case 0: - require.Equal(t, newSessionID, instances[index].SessionID) - require.Equal(t, newRPCAddr, instances[index].InstanceRPCAddr) - require.Equal(t, newSQLAddr, instances[index].InstanceSQLAddr) - require.Equal(t, newLocality, instances[index].Locality) - require.Equal(t, newBinaryVersion, instances[index].BinaryVersion) - case 4: - require.Equal(t, sessionID6, instances[index].SessionID) - require.Equal(t, rpcAddr6, instances[index].InstanceRPCAddr) - require.Equal(t, sqlAddr6, instances[index].InstanceSQLAddr) - require.Equal(t, locality6, instances[index].Locality) - require.Equal(t, binaryVersion6, instances[index].BinaryVersion) - default: - require.Equal(t, sessionIDs[index], instances[index].SessionID) - require.Equal(t, rpcAddresses[index], instances[index].InstanceRPCAddr) - require.Equal(t, sqlAddresses[index], instances[index].InstanceSQLAddr) - require.Equal(t, localities[index], instances[index].Locality) - require.Equal(t, binaryVersions[index], instances[index].BinaryVersion) - } - } - for index := len(instanceIDs); index < len(instances); index++ { - require.Equal(t, base.SQLInstanceID(index+1), instances[index].InstanceID) - require.Empty(t, instances[index].SessionID) - require.Empty(t, instances[index].InstanceRPCAddr) - require.Empty(t, instances[index].InstanceSQLAddr) - require.Empty(t, instances[index].Locality) - require.Empty(t, instances[index].BinaryVersion) - } - } + // TODO(jeffswenson): verify release is idempotent }) } @@ -487,7 +448,7 @@ func TestConcurrentCreateAndRelease(t *testing.T) { if i == -1 { return } - require.NoError(t, storage.ReleaseInstanceID(ctx, region, i)) + require.NoError(t, storage.ReleaseInstance(ctx, sessionID, i)) state.freeInstances[i] = struct{}{} delete(state.liveInstances, i) } @@ -516,11 +477,14 @@ func TestConcurrentCreateAndRelease(t *testing.T) { state.RLock() defer state.RUnlock() instanceInfo, err := storage.GetInstanceDataForTest(ctx, region, i) + require.NoError(t, err) if _, free := state.freeInstances[i]; free { - require.Error(t, err) - require.ErrorIs(t, err, sqlinstance.NonExistentInstanceError) + require.Empty(t, instanceInfo.InstanceRPCAddr) + require.Empty(t, instanceInfo.InstanceSQLAddr) + require.Empty(t, instanceInfo.SessionID) + require.Empty(t, instanceInfo.Locality) + require.Empty(t, instanceInfo.BinaryVersion) } else { - require.NoError(t, err) require.Equal(t, rpcAddr, instanceInfo.InstanceRPCAddr) require.Equal(t, sqlAddr, instanceInfo.InstanceSQLAddr) require.Equal(t, sessionID, instanceInfo.SessionID) diff --git a/pkg/sql/sqlliveness/slinstance/slinstance.go b/pkg/sql/sqlliveness/slinstance/slinstance.go index 7bfb6dcb9972..9c47223fe430 100644 --- a/pkg/sql/sqlliveness/slinstance/slinstance.go +++ b/pkg/sql/sqlliveness/slinstance/slinstance.go @@ 
-16,6 +16,7 @@ package slinstance import ( "context" + "sync" "time" "github.com/cockroachdb/cockroach/pkg/settings" @@ -62,6 +63,8 @@ type Writer interface { // the storage and if found replaces it with the input returning true. // Otherwise it returns false to indicate that the session does not exist. Update(ctx context.Context, id sqlliveness.SessionID, expiration hlc.Timestamp) (bool, error) + // Delete removes the session from the sqlliveness table. + Delete(ctx context.Context, id sqlliveness.SessionID) error } type session struct { @@ -126,7 +129,16 @@ type Instance struct { ttl func() time.Duration hb func() time.Duration testKnobs sqlliveness.TestingKnobs - mu struct { + + // drainOnce and drain are used to signal the shutdown of the heartbeat + // loop. + drainOnce sync.Once + drain chan struct{} + + // wait tracks the life of the heartbeat goroutine. + wait sync.WaitGroup + + mu struct { syncutil.Mutex // stopErr, if set, indicates that the heartbeat loop has stopped because of // this error. Calls to Session() will return this error. @@ -295,17 +307,23 @@ func (l *Instance) heartbeatLoop(ctx context.Context) { log.Fatal(ctx, "expected heartbeat to always terminate with an error") } - log.Warning(ctx, "exiting heartbeat loop") - // Keep track of the fact that this Instance is not usable anymore. Further // Session() calls will return errors. l.mu.Lock() defer l.mu.Unlock() - if l.mu.s != nil { - _ = l.clearSessionLocked(ctx) - } + l.mu.stopErr = err - close(l.mu.blockCh) + + select { + case <-l.drain: + log.Infof(ctx, "draining heartbeat loop") + default: + log.Warningf(ctx, "exiting heartbeat loop with error: %v", l.mu.stopErr) + if l.mu.s != nil { + _ = l.clearSessionLocked(ctx) + } + close(l.mu.blockCh) + } } func (l *Instance) heartbeatLoopInner(ctx context.Context) error { @@ -322,6 +340,8 @@ func (l *Instance) heartbeatLoopInner(ctx context.Context) error { t.Reset(0) for { select { + case <-l.drain: + return stop.ErrUnavailable case <-ctx.Done(): return stop.ErrUnavailable case <-t.C: @@ -395,6 +415,7 @@ func NewSQLInstance( hb: func() time.Duration { return DefaultHeartBeat.Get(&settings.SV) }, + drain: make(chan struct{}), } if testKnobs != nil { l.testKnobs = *testKnobs @@ -411,7 +432,36 @@ func (l *Instance) Start(ctx context.Context, regionPhysicalRep []byte) { // Detach from ctx's cancelation. taskCtx := l.AnnotateCtx(context.Background()) taskCtx = logtags.WithTags(taskCtx, logtags.FromContext(ctx)) - _ = l.stopper.RunAsyncTask(taskCtx, "slinstance", l.heartbeatLoop) + + l.wait.Add(1) + err := l.stopper.RunAsyncTask(taskCtx, "slinstance", func(ctx context.Context) { + l.heartbeatLoop(ctx) + l.wait.Add(-1) + }) + if err != nil { + l.wait.Add(-1) + } +} + +func (l *Instance) Release(ctx context.Context) (sqlliveness.SessionID, error) { + l.drainOnce.Do(func() { close(l.drain) }) + l.wait.Wait() + + session := func() *session { + l.mu.Lock() + defer l.mu.Unlock() + return l.mu.s + }() + + if session == nil { + return sqlliveness.SessionID(""), errors.New("no session to release") + } + + if err := l.storage.Delete(ctx, session.ID()); err != nil { + return sqlliveness.SessionID(""), err + } + + return session.ID(), nil } // Session returns a live session id. 
For each Sqlliveness instance the diff --git a/pkg/sql/sqlliveness/slinstance/slinstance_test.go b/pkg/sql/sqlliveness/slinstance/slinstance_test.go index 84997e2743c5..08962013dbff 100644 --- a/pkg/sql/sqlliveness/slinstance/slinstance_test.go +++ b/pkg/sql/sqlliveness/slinstance/slinstance_test.go @@ -99,7 +99,7 @@ func TestSQLInstance(t *testing.T) { require.Error(t, err) } -func TestSQLInstanceWithRegion(t *testing.T) { +func TestSQLInstanceRelease(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t) @@ -117,16 +117,27 @@ func TestSQLInstanceWithRegion(t *testing.T) { fakeStorage := slstorage.NewFakeStorage() var ambientCtx log.AmbientContext sqlInstance := slinstance.NewSQLInstance(ambientCtx, stopper, clock, fakeStorage, settings, nil, nil) - sqlInstance.Start(ctx, []byte{42}) + sqlInstance.Start(ctx, enum.One) - s1, err := sqlInstance.Session(ctx) + activeSession, err := sqlInstance.Session(ctx) require.NoError(t, err) - a, err := fakeStorage.IsAlive(ctx, s1.ID()) + activeSessionID := activeSession.ID() + + a, err := fakeStorage.IsAlive(ctx, activeSessionID) require.NoError(t, err) require.True(t, a) - region, id, err := slstorage.UnsafeDecodeSessionID(s1.ID()) - require.NoError(t, err) - require.Equal(t, []byte{42}, region) - require.NotNil(t, id) + require.NotEqual(t, 0, stopper.NumTasks()) + + // Make sure release is idempotent + for i := 0; i < 5; i++ { + finalSession, err := sqlInstance.Release(ctx) + require.NoError(t, err) + + // Release should always return the last active session + require.Equal(t, activeSessionID, finalSession) + + // Release should tear down the background heartbeat + require.Equal(t, 0, stopper.NumTasks()) + } } diff --git a/pkg/sql/sqlliveness/slstorage/BUILD.bazel b/pkg/sql/sqlliveness/slstorage/BUILD.bazel index b1794a9bab83..3d399ee26430 100644 --- a/pkg/sql/sqlliveness/slstorage/BUILD.bazel +++ b/pkg/sql/sqlliveness/slstorage/BUILD.bazel @@ -58,6 +58,7 @@ go_test( embed = [":slstorage"], deps = [ "//pkg/base", + "//pkg/ccl/kvccl/kvtenantccl", "//pkg/clusterversion", "//pkg/keys", "//pkg/kv/kvpb", diff --git a/pkg/sql/sqlliveness/slstorage/slstorage.go b/pkg/sql/sqlliveness/slstorage/slstorage.go index 4fdc98751e5f..d2d370afb0d1 100644 --- a/pkg/sql/sqlliveness/slstorage/slstorage.go +++ b/pkg/sql/sqlliveness/slstorage/slstorage.go @@ -590,6 +590,36 @@ func (s *Storage) Update( return sessionExists, nil } +// Delete removes the session from the sqlliveness table without checking the +// expiration. This is only safe to call during the shutdown process after all +// tasks using the session have stopped. +func (s *Storage) Delete(ctx context.Context, session sqlliveness.SessionID) error { + return s.txn(ctx, func(ctx context.Context, txn *kv.Txn) error { + version, err := s.versionGuard(ctx, txn) + if err != nil { + return err + } + + batch := txn.NewBatch() + + readCodec := s.getReadCodec(&version) + key, err := readCodec.encode(session) + if err != nil { + return err + } + batch.Del(key) + + if dualCodec := s.getDualWriteCodec(&version); dualCodec != nil { + dualKey, err := dualCodec.encode(session) + if err != nil { + return err + } + batch.Del(dualKey) + } + return txn.CommitInBatch(ctx, batch) + }) +} + // CachedReader returns an implementation of sqlliveness.Reader which does // not synchronously read from the store. 
Calls to IsAlive will return the // currently known state of the session, but will trigger an asynchronous diff --git a/pkg/sql/sqlliveness/slstorage/slstorage_test.go b/pkg/sql/sqlliveness/slstorage/slstorage_test.go index e72bbef59615..79431c92432a 100644 --- a/pkg/sql/sqlliveness/slstorage/slstorage_test.go +++ b/pkg/sql/sqlliveness/slstorage/slstorage_test.go @@ -720,13 +720,16 @@ func TestDeleteMidUpdateFails(t *testing.T) { unblock := <-getChan // Delete the session being updated. - tdb.Exec(t, `DELETE FROM "`+t.Name()+`".sqlliveness WHERE true`) + require.NoError(t, storage.Delete(ctx, ID)) // Unblock the update and ensure that it saw that its session was deleted. close(unblock) res := <-resCh require.False(t, res.exists) require.NoError(t, res.err) + + // Ensure delete is idempotent. + require.NoError(t, storage.Delete(ctx, ID)) } func newSqllivenessTable( diff --git a/pkg/sql/sqlliveness/sqlliveness.go b/pkg/sql/sqlliveness/sqlliveness.go index 6c70e02aebf4..0942aa4272cf 100644 --- a/pkg/sql/sqlliveness/sqlliveness.go +++ b/pkg/sql/sqlliveness/sqlliveness.go @@ -39,9 +39,15 @@ type Provider interface { // Start starts the sqlliveness subsystem. regionPhysicalRep should // represent the physical representation of the current process region - // stored in the multi-region enum type associated with the system database. + // stored in the multi-region enum type associated with the system + // database. Start(ctx context.Context, regionPhysicalRep []byte) + // Release deletes the sqlliveness session managed by the provider. This + // should be called near the end of the drain process, after the server has + // no running tasks that depend on the session. + Release(ctx context.Context) (SessionID, error) + // Metrics returns a metric.Struct which holds metrics for the provider. Metrics() metric.Struct }
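For readers following the new drain path, the sketch below condenses how the pieces introduced by this patch fit together at shutdown. It is illustrative only and not part of the patch: the wrapper function name is hypothetical, while the Release/ReleaseInstance calls and the sqlServer field accesses mirror the drain.go hunk above.

// Illustrative sketch (not part of the patch). Ordering matters: the
// sqlliveness session is released only after SQL clients have drained, and
// the sql_instances row is released using the session that owned it, so a
// row already reclaimed by another server is left untouched.
func releaseSQLResourcesOnDrain(ctx context.Context, s *drainServer) error {
	// Stop the heartbeat loop and delete this server's system.sqlliveness
	// row, returning the session that was released.
	session, err := s.sqlServer.sqlLivenessProvider.Release(ctx)
	if err != nil {
		return err
	}
	// Mark this server's system.sql_instances row as available, but only if
	// it is still owned by the session released above.
	instanceID := s.sqlServer.sqlIDContainer.SQLInstanceID()
	return s.sqlServer.sqlInstanceStorage.ReleaseInstance(ctx, session, instanceID)
}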