From 5196152644b221f78747b80e226665d3c10a5f1f Mon Sep 17 00:00:00 2001 From: Raphael 'kena' Poss Date: Wed, 10 Nov 2021 14:08:15 +0100 Subject: [PATCH] base: ensure that a SQL instance ID is not set to conflicting value Prior to this patch, a SQL instance ID container could be assigned successively different values. This would make it possible for erroneous code to mistakenly initialize the instance ID twice, without any evidence that something was amiss. This patch fixes this by using the same prevention logic that we already use for node IDs. Release note: None --- pkg/base/node_id.go | 20 ++++++++++++++++++-- pkg/server/server_sql.go | 2 +- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/pkg/base/node_id.go b/pkg/base/node_id.go index 82ca9a04ac5a..c9dc9615051b 100644 --- a/pkg/base/node_id.go +++ b/pkg/base/node_id.go @@ -199,11 +199,27 @@ func NewSQLIDContainer(sqlInstanceID SQLInstanceID, nodeID *NodeIDContainer) *SQ // SetSQLInstanceID sets the SQL instance ID. It returns an error if // we attempt to set an instance ID when the nodeID has already been // initialized. -func (c *SQLIDContainer) SetSQLInstanceID(sqlInstanceID SQLInstanceID) error { +func (c *SQLIDContainer) SetSQLInstanceID(ctx context.Context, sqlInstanceID SQLInstanceID) error { if _, ok := c.OptionalNodeID(); ok { return errors.New("attempting to initialize instance ID when node ID is set") } - c.sqlInstanceID = sqlInstanceID + + // Use the same logic to set the instance ID as for the node ID. + // + // TODO(knz): All this could be advantageously simplified if we agreed + // to use the same type for NodeIDContainer and SQLIDContainer. + if sqlInstanceID <= 0 { + log.Fatalf(ctx, "trying to set invalid SQLInstanceID: %d", sqlInstanceID) + } + oldVal := atomic.SwapInt32((*int32)(&c.sqlInstanceID), int32(sqlInstanceID)) + if oldVal == 0 { + if log.V(2) { + log.Infof(ctx, "SQLInstanceID set to %d", sqlInstanceID) + } + } else if oldVal != int32(sqlInstanceID) { + log.Fatalf(ctx, "different SQLInstanceIDs set: %d, then %d", oldVal, sqlInstanceID) + } + c.str.Store(strconv.Itoa(int(sqlInstanceID))) return nil } diff --git a/pkg/server/server_sql.go b/pkg/server/server_sql.go index e30402b757e0..57cd142aafb3 100644 --- a/pkg/server/server_sql.go +++ b/pkg/server/server_sql.go @@ -974,7 +974,7 @@ func (s *SQLServer) initInstanceID(ctx context.Context) error { if err != nil { return err } - err = s.sqlIDContainer.SetSQLInstanceID(instanceID) + err = s.sqlIDContainer.SetSQLInstanceID(ctx, instanceID) if err != nil { return err }