diff --git a/consts.go b/consts.go index 19835f7..65ddfd8 100644 --- a/consts.go +++ b/consts.go @@ -3,9 +3,17 @@ package heartbeat import "log" const ( + // Server + serverStartMessage = "Starting heartbeat workers, prometheus exporter..." + serverStartErrorMessage = "Unable to start heartbeat. Server must be inactive" + serverStartExporterErrorMessage = "Errors during starting prometheus exporter:" + serverStopMessage = "Heartbeat workers, prometheus exporter were stopped gracefully" + serverStopErrorMessage = "Unable to stop heartbeat. Server must be active" + serverStopExporterErrorMessage = "Errors during stopping prometheus exporter:" + // Exporter exporterStartMessage = "Prometheus exporter started on " - exporterErrorMsg = "Failed to start prometheus exporter on port " + exporterErrorMessage = "Failed to start prometheus exporter on port " exporterShutdownMessage = "Prometheus exporter is in the shutdown mode and won't accept new connections" exporterStopMessage = "Prometheus exporter stopped gracefully" diff --git a/exporter.go b/exporter.go index f3cef11..f916bd6 100644 --- a/exporter.go +++ b/exporter.go @@ -6,19 +6,11 @@ import ( "net" "net/http" "strconv" - "sync" "time" "github.com/prometheus/client_golang/prometheus/promhttp" ) -// WaitGroup interface -type waitGroup interface { - Add(int) - Done() - Wait() -} - // serverPrometheusWrapper structure. Used for testing purposes type serverPrometheusWrapper struct { *http.Server @@ -68,7 +60,7 @@ func newExporter(port, shutdownTimeout int, route string, logger logger) *export // Exporter methods // Starts exporter, runs listen channel from the parent (heartbeat server) -func (exporter *exporter) start(parentContext context.Context, wg *sync.WaitGroup) error { +func (exporter *exporter) start(parentContext context.Context, wg waitGroup) error { exporter.ctx, exporter.wg = parentContext, wg exporter.listenShutdownSignal() exporter.logger.info(exporterStartMessage + exporter.server.Port() + exporter.route) @@ -105,7 +97,7 @@ func (exporter *exporter) isPortAvailable() (err error) { port := exporter.server.Port() listener, err := net.Listen("tcp", port) if err != nil { - return errors.New(exporterErrorMsg + port) + return errors.New(exporterErrorMessage + port) } listener.Close() diff --git a/exporter_test.go b/exporter_test.go index c7cc7e7..da4e622 100644 --- a/exporter_test.go +++ b/exporter_test.go @@ -122,7 +122,7 @@ func TestExporterIsPortAvailable(t *testing.T) { listener, _ := net.Listen("tcp", port) defer listener.Close() - assert.Error(t, exporter.isPortAvailable(), exporterErrorMsg+port) + assert.EqualError(t, exporter.isPortAvailable(), exporterErrorMessage+port) prometheusServer.AssertExpectations(t) }) } diff --git a/heartbeat_instance.go b/heartbeat_instance.go index 2f8d6b0..8521d41 100644 --- a/heartbeat_instance.go +++ b/heartbeat_instance.go @@ -3,7 +3,6 @@ package heartbeat import ( "context" "fmt" - "sync" "time" ) @@ -14,7 +13,7 @@ type heartbeatInstance struct { metric metric session session ctx context.Context - wg *sync.WaitGroup + wg waitGroup logger logger } diff --git a/server.go b/server.go new file mode 100644 index 0000000..95379b7 --- /dev/null +++ b/server.go @@ -0,0 +1,130 @@ +package heartbeat + +import ( + "context" + "errors" + "sync" +) + +// WaitGroup interface +type waitGroup interface { + Add(int) + Done() + Wait() +} + +// Server structure +type Server struct { + heartbeatInstances []*heartbeatInstance + logger logger + ctx context.Context + shutdown context.CancelFunc + wg waitGroup + started bool + sync.Mutex + + exporter *exporter +} + +// Server builder. Returns pointer to new server structure +func newServer(configuration *Configuration) *Server { + var heartbeatInstances []*heartbeatInstance + + for _, instanceAttributes := range configuration.InstancesAttributes { + heartbeatInstances = append(heartbeatInstances, newInstance(instanceAttributes)) + } + logger := newLogger(configuration.LogToStdout, configuration.LogActivity) + + return &Server{ + heartbeatInstances: heartbeatInstances, + logger: logger, + wg: new(sync.WaitGroup), + exporter: newExporter( + configuration.Port, + configuration.ShutdownTimeout, + configuration.MetricsRoute, + logger, + ), + } +} + +// Server methods + +// Starts server. Returns error if any +func (server *Server) Start() (err error) { + logger := server.logger + + if server.isStarted() { + err = errors.New(serverStartErrorMessage) + logger.error(err.Error()) + + return err + } else if err := server.exporter.isPortAvailable(); err != nil { + logger.error(err.Error()) + + return err + } + + logger.info(serverStartMessage) + server.ctx, server.shutdown = context.WithCancel(context.Background()) + + server.wg.Add(1) + go func() { + // We have checked port availability before, it's safe to start exporter + if err := server.exporter.start(server.ctx, server.wg); err != nil { + logger.warning(serverStartExporterErrorMessage, err.Error()) + } + }() + + for _, instance := range server.heartbeatInstances { + instance.ctx, instance.wg, instance.logger = server.ctx, server.wg, server.logger + server.wg.Add(1) + go instance.workerRunner() + } + server.start() + + return err +} + +// Stops server. Returns error if server is not started +func (server *Server) Stop() (err error) { + logger := server.logger + + if server.isStarted() { + server.shutdown() + server.wg.Wait() + server.stop() + logger.info(serverStopMessage) + if err = server.exporter.err; err != nil { + logger.warning(serverStopExporterErrorMessage, err.Error()) + } + + return err + } + + err = errors.New(serverStopErrorMessage) + logger.error(err.Error()) + + return err +} + +// Thread-safe getter to check if server has been started. Returns server.started +func (server *Server) isStarted() bool { + server.Lock() + defer server.Unlock() + return server.started +} + +// Thread-safe setter of started-flag to indicate server has been started +func (server *Server) start() { + server.Lock() + defer server.Unlock() + server.started = true +} + +// Thread-safe setter of started-flag to indicate server has been stopped +func (server *Server) stop() { + server.Lock() + defer server.Unlock() + server.started = false +} diff --git a/server_test.go b/server_test.go new file mode 100644 index 0000000..c167207 --- /dev/null +++ b/server_test.go @@ -0,0 +1,152 @@ +package heartbeat + +import ( + "errors" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewServer(t *testing.T) { + t.Run("returns new server", func(t *testing.T) { + shutdownTimeout, metricsRoute := 42, "/metrics" + server := newServer( + &Configuration{ + InstancesAttributes: []*InstanceAttributes{ + { + Connection: "some_connection", + URL: "some_url", + }, + }, + Port: 8080, + ShutdownTimeout: shutdownTimeout, + MetricsRoute: metricsRoute, + LogToStdout: true, + LogActivity: true, + }, + ) + + assert.NotNil(t, server) + assert.Equal(t, 1, len(server.heartbeatInstances)) + assert.Equal(t, time.Duration(shutdownTimeout), server.exporter.shutdownTimeout) + assert.Equal(t, metricsRoute, server.exporter.route) + }) +} + +func TestServerStart(t *testing.T) { + t.Run("when no errors happens during starting and running the server", func(t *testing.T) { + server := newServer( + &Configuration{ + InstancesAttributes: []*InstanceAttributes{ + { + Connection: "postgres", + URL: "postgres://localhost:5432/postgres", + }, + }, + Port: 8080, + MetricsRoute: "/metrics", + }, + ) + + assert.NoError(t, server.Start()) + assert.True(t, server.isStarted()) + + _ = server.Stop() + }) + + t.Run("when error happens during starting the server, server is already started", func(t *testing.T) { + server, logger := &Server{started: true}, new(loggerMock) + server.logger = logger + logger.On("error", []string{serverStartErrorMessage}).Once() + + serverStart := server.Start() + assert.Error(t, serverStart) + assert.EqualError(t, serverStart, serverStartErrorMessage) + logger.AssertExpectations(t) + }) + + t.Run("when error happens during starting the server, port is already in use", func(t *testing.T) { + port := ":8080" + listener, _ := net.Listen("tcp", port) + defer listener.Close() + server, logger, errMessage := newServer(createNewMinimalConfiguration()), new(loggerMock), exporterErrorMessage+port + server.logger = logger + logger.On("error", []string{errMessage}).Once() + + serverStart := server.Start() + assert.Error(t, serverStart) + assert.EqualError(t, serverStart, errMessage) + assert.False(t, server.isStarted()) + logger.AssertExpectations(t) + }) +} + +func TestServerStop(t *testing.T) { + t.Run("when server is started, no errors happen during stopping exporter", func(t *testing.T) { + server, wg, logger := newServer(createNewMinimalConfiguration()), new(waitGroupMock), new(loggerMock) + server.wg, server.logger, server.started, server.shutdown = wg, logger, true, func() {} + wg.On("Wait").Once() + logger.On("info", []string{serverStopMessage}).Once() + + assert.NoError(t, server.Stop()) + assert.False(t, server.isStarted()) + wg.AssertExpectations(t) + logger.AssertExpectations(t) + }) + + t.Run("when server is started, exporter returns error during stopping", func(t *testing.T) { + server, wg, logger, err := newServer(createNewMinimalConfiguration()), new(waitGroupMock), new(loggerMock), errors.New("some error") + server.wg, server.logger, server.started, server.shutdown, server.exporter.err = wg, logger, true, func() {}, err + wg.On("Wait").Once() + logger.On("info", []string{serverStopMessage}).Once() + logger.On("warning", []string{serverStopExporterErrorMessage, err.Error()}).Once() + + serverStop := server.Stop() + assert.Error(t, serverStop) + assert.EqualError(t, serverStop, err.Error()) + wg.AssertExpectations(t) + logger.AssertExpectations(t) + }) + + t.Run("when server is not started", func(t *testing.T) { + server, logger := new(Server), new(loggerMock) + server.logger = logger + logger.On("error", []string{serverStopErrorMessage}).Once() + + serverStop := server.Stop() + assert.Error(t, serverStop) + assert.EqualError(t, serverStop, serverStopErrorMessage) + logger.AssertExpectations(t) + }) +} + +func TestServerIsStarted(t *testing.T) { + t.Run("when server is started", func(t *testing.T) { + server := new(Server) + server.started = true + + assert.True(t, server.isStarted()) + }) + + t.Run("when server is not started", func(t *testing.T) { + server := new(Server) + + assert.False(t, server.isStarted()) + }) +} + +func TestServerStartPrivate(t *testing.T) { + server := new(Server) + server.start() + + assert.True(t, server.started) +} + +func TestServerStopPrivate(t *testing.T) { + server := &Server{started: true} + server.stop() + + assert.False(t, server.started) +} diff --git a/test_helpers_test.go b/test_helpers_test.go index 04a23e0..b0c401c 100644 --- a/test_helpers_test.go +++ b/test_helpers_test.go @@ -68,6 +68,14 @@ func createNewWaitGroup() *sync.WaitGroup { return new(sync.WaitGroup) } +// Creates new minimal configuration +func createNewMinimalConfiguration() *Configuration { + return &Configuration{ + Port: 8080, + MetricsRoute: "/metrics", + } +} + // Generates a unique instance name for testing purposes func generateUniqueInstanceName() string { return fmt.Sprintf("test_instance_%d", time.Now().UnixNano()) diff --git a/test_mocks_test.go b/test_mocks_test.go index 1b29883..728628d 100644 --- a/test_mocks_test.go +++ b/test_mocks_test.go @@ -118,19 +118,19 @@ func (mock *metricMock) setFailureTimeout() { mock.Called() } -// // WaitGroup mock -// type waitGroupMock struct { -// mock.Mock -// } - -// func (mock *waitGroupMock) Add(count int) { -// mock.Called(count) -// } - -// func (mock *waitGroupMock) Done() { -// mock.Called() -// } - -// func (mock *waitGroupMock) Wait() { -// mock.Called() -// } +// WaitGroup mock +type waitGroupMock struct { + mock.Mock +} + +func (mock *waitGroupMock) Add(count int) { + mock.Called(count) +} + +func (mock *waitGroupMock) Done() { + mock.Called() +} + +func (mock *waitGroupMock) Wait() { + mock.Called() +}