From 608a85a4b193277f6588c715516eab2a6814e877 Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Fri, 25 Mar 2022 17:17:54 -0300 Subject: [PATCH] fix(db): send initial heartbeat when there is no static dbs (#11160) --- integration/db_integration_test.go | 79 ++++++++++++++++++++++++++++++ lib/srv/db/access_test.go | 3 ++ lib/srv/db/server.go | 6 +++ lib/srv/db/server_test.go | 67 +++++++++++++++++++++++++ 4 files changed, 155 insertions(+) diff --git a/integration/db_integration_test.go b/integration/db_integration_test.go index 65c721c7f477e..1677b4c4cdb06 100644 --- a/integration/db_integration_test.go +++ b/integration/db_integration_test.go @@ -18,7 +18,9 @@ package integration import ( "context" + "fmt" "net" + "net/http" "testing" "time" @@ -31,6 +33,7 @@ import ( "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/service" + "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/db" "github.com/gravitational/teleport/lib/srv/db/common" "github.com/gravitational/teleport/lib/srv/db/mongodb" @@ -588,6 +591,49 @@ func TestDatabaseAccessMongoSeparateListener(t *testing.T) { require.NoError(t, err) } +func TestDatabaseAgentState(t *testing.T) { + tests := map[string]struct { + agentParams databaseAgentStartParams + }{ + "WithStaticDatabases": { + agentParams: databaseAgentStartParams{ + databases: []service.Database{ + {Name: "mysql", Protocol: defaults.ProtocolMySQL, URI: "localhost:3306"}, + {Name: "pg", Protocol: defaults.ProtocolPostgres, URI: "localhost:5432"}, + }, + }, + }, + "WithResourceMatchers": { + agentParams: databaseAgentStartParams{ + resourceMatchers: []services.ResourceMatcher{ + {Labels: types.Labels{"*": []string{"*"}}}, + }, + }, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + pack := setupDatabaseTest(t) + + // Start also ensures that the database agent has the “ready” state. + // If the agent can’t make it, this function will fail the test. + agent, _ := pack.startRootDatabaseAgent(t, test.agentParams) + + // In addition to the checks performed during the agent start, + // we’ll request the diagnostic server to ensure the readyz route + // is returning to the proper state. + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%v/readyz", agent.Config.DiagnosticAddr.Addr), nil) + require.NoError(t, err) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusOK, resp.StatusCode) + }) + } +} + func waitForAuditEventTypeWithBackoff(t *testing.T, cli *auth.Server, startTime time.Time, eventType string) []apievents.AuditEvent { max := time.Second timeout := time.After(max) @@ -1015,6 +1061,39 @@ func (p *databasePack) waitForLeaf(t *testing.T) { } } +// databaseAgentStartParams parameters used to configure a database agent. +type databaseAgentStartParams struct { + databases []service.Database + resourceMatchers []services.ResourceMatcher +} + +// startRootDatabaseAgent starts a database agent with the provided +// configuration on the root cluster. +func (p *databasePack) startRootDatabaseAgent(t *testing.T, params databaseAgentStartParams) (*service.TeleportProcess, *auth.Client) { + conf := service.MakeDefaultConfig() + conf.DataDir = t.TempDir() + conf.Token = "static-token-value" + conf.DiagnosticAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: net.JoinHostPort("localhost", ports.Pop())} + conf.AuthServers = []utils.NetAddr{ + { + AddrNetwork: "tcp", + Addr: net.JoinHostPort(Loopback, p.root.cluster.GetPortWeb()), + }, + } + conf.Clock = p.clock + conf.Databases.Enabled = true + conf.Databases.Databases = params.databases + conf.Databases.ResourceMatchers = params.resourceMatchers + + server, authClient, err := p.root.cluster.StartDatabase(conf) + require.NoError(t, err) + t.Cleanup(func() { + server.Close() + }) + + return server, authClient +} + func containsDB(servers []types.DatabaseServer, name string) bool { for _, server := range servers { if server.GetDatabase().GetName() == name { diff --git a/lib/srv/db/access_test.go b/lib/srv/db/access_test.go index 79710a02dfa4d..817dd03028774 100644 --- a/lib/srv/db/access_test.go +++ b/lib/srv/db/access_test.go @@ -1809,6 +1809,8 @@ type agentParams struct { NoStart bool // GCPSQL defines the GCP Cloud SQL mock to use for GCP API calls. GCPSQL *cloud.GCPSQLAdminClientMock + // OnHeartbeat defines a heartbeat function that generates heartbeat events. + OnHeartbeat func(error) } func (p *agentParams) setDefaults(c *testContext) { @@ -1874,6 +1876,7 @@ func (c *testContext) setupDatabaseServer(ctx context.Context, t *testing.T, p a Limiter: connLimiter, Auth: testAuth, Databases: p.Databases, + OnHeartbeat: p.OnHeartbeat, ResourceMatchers: p.ResourceMatchers, GetServerInfoFn: p.GetServerInfoFn, GetRotation: func(types.SystemRole) (*types.Rotation, error) { diff --git a/lib/srv/db/server.go b/lib/srv/db/server.go index bd35f2375547b..1e165c3db2837 100644 --- a/lib/srv/db/server.go +++ b/lib/srv/db/server.go @@ -573,6 +573,12 @@ func (s *Server) Start(ctx context.Context) (err error) { return trace.Wrap(err) } + // If the agent doesn’t have any static databases configured, send a + // heartbeat without error to make the component “ready”. + if len(s.cfg.Databases) == 0 && s.cfg.OnHeartbeat != nil { + s.cfg.OnHeartbeat(nil) + } + return nil } diff --git a/lib/srv/db/server_test.go b/lib/srv/db/server_test.go index 8c693b07e06a9..bd158385cbd0d 100644 --- a/lib/srv/db/server_test.go +++ b/lib/srv/db/server_test.go @@ -18,10 +18,13 @@ package db import ( "context" + "sync/atomic" "testing" + "time" apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/limiter" "github.com/jackc/pgconn" @@ -177,3 +180,67 @@ func TestDatabaseServerLimiting(t *testing.T) { require.FailNow(t, "we should exceed the connection limit by now") }) } + +func TestHeartbeatEvents(t *testing.T) { + ctx := context.Background() + + dbOne, err := types.NewDatabaseV3(types.Metadata{ + Name: "dbOne", + }, types.DatabaseSpecV3{ + Protocol: defaults.ProtocolPostgres, + URI: "localhost:5432", + }) + require.NoError(t, err) + + dbTwo, err := types.NewDatabaseV3(types.Metadata{ + Name: "dbOne", + }, types.DatabaseSpecV3{ + Protocol: defaults.ProtocolMySQL, + URI: "localhost:3306", + }) + require.NoError(t, err) + + tests := map[string]struct { + staticDatabases types.Databases + heartbeatCount int64 + }{ + "SingleStaticDatabase": { + staticDatabases: types.Databases{dbOne}, + heartbeatCount: 1, + }, + "MultipleStaticDatabases": { + staticDatabases: types.Databases{dbOne, dbTwo}, + heartbeatCount: 2, + }, + "EmptyStaticDatabases": { + staticDatabases: types.Databases{}, + heartbeatCount: 1, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + var heartbeatEvents int64 + heartbeatRecorder := func(err error) { + require.NoError(t, err) + atomic.AddInt64(&heartbeatEvents, 1) + } + + testCtx := setupTestContext(ctx, t) + server := testCtx.setupDatabaseServer(ctx, t, agentParams{ + NoStart: true, + OnHeartbeat: heartbeatRecorder, + Databases: test.staticDatabases, + }) + require.NoError(t, server.Start(ctx)) + t.Cleanup(func() { + server.Close() + }) + + require.NotNil(t, server) + require.Eventually(t, func() bool { + return atomic.LoadInt64(&heartbeatEvents) == test.heartbeatCount + }, 2*time.Second, 500*time.Millisecond) + }) + } +}