diff --git a/pkg/async/notifications/factory.go b/pkg/async/notifications/factory.go
index ad0b4b311..83ee8c663 100644
--- a/pkg/async/notifications/factory.go
+++ b/pkg/async/notifications/factory.go
@@ -2,6 +2,9 @@ package notifications
 
 import (
 	"context"
+	"time"
+
+	"github.com/lyft/flyteadmin/pkg/async"
 
 	"github.com/lyft/flyteadmin/pkg/async/notifications/implementations"
 	"github.com/lyft/flyteadmin/pkg/async/notifications/interfaces"
@@ -60,6 +63,8 @@ func GetEmailer(config runtimeInterfaces.NotificationsConfig, scope promutils.Sc
 }
 
 func NewNotificationsProcessor(config runtimeInterfaces.NotificationsConfig, scope promutils.Scope) interfaces.Processor {
+	reconnectAttempts := config.ReconnectAttempts
+	reconnectDelay := time.Duration(config.ReconnectDelaySeconds) * time.Second
 	var sub pubsub.Subscriber
 	var emailer interfaces.Emailer
 	switch config.Type {
@@ -73,11 +78,15 @@ func NewNotificationsProcessor(config runtimeInterfaces.NotificationsConfig, sco
 			ConsumeBase64: &enable64decoding,
 		}
 		sqsConfig.Region = config.Region
-		process, err := gizmoAWS.NewSubscriber(sqsConfig)
+		var err error
+		err = async.Retry(reconnectAttempts, reconnectDelay, func() error {
+			sub, err = gizmoAWS.NewSubscriber(sqsConfig)
+			return err
+		})
+
 		if err != nil {
 			panic(err)
 		}
-		sub = process
 		emailer = GetEmailer(config, scope)
 	case common.Local:
 		fallthrough
@@ -90,6 +99,8 @@ func NewNotificationsProcessor(config runtimeInterfaces.NotificationsConfig, sco
 }
 
 func NewNotificationsPublisher(config runtimeInterfaces.NotificationsConfig, scope promutils.Scope) interfaces.Publisher {
+	reconnectAttempts := config.ReconnectAttempts
+	reconnectDelay := time.Duration(config.ReconnectDelaySeconds) * time.Second
 	switch config.Type {
 	case common.AWS:
 		snsConfig := gizmoAWS.SNSConfig{
@@ -100,8 +111,15 @@ func NewNotificationsPublisher(config runtimeInterfaces.NotificationsConfig, sco
 		} else {
 			snsConfig.Region = config.Region
 		}
-		publisher, err := gizmoAWS.NewPublisher(snsConfig)
-		// Any errors initiating Publisher with Amazon configurations results in a failed start up.
+
+		var publisher pubsub.Publisher
+		var err error
+		err = async.Retry(reconnectAttempts, reconnectDelay, func() error {
+			publisher, err = gizmoAWS.NewPublisher(snsConfig)
+			return err
+		})
+
+		// Any persistent errors initiating Publisher with Amazon configurations result in a failed start up.
 		if err != nil {
 			panic(err)
 		}
@@ -111,7 +129,13 @@ func NewNotificationsPublisher(config runtimeInterfaces.NotificationsConfig, sco
 			Topic: config.NotificationsPublisherConfig.TopicName,
 		}
 		pubsubConfig.ProjectID = config.GCPConfig.ProjectID
-		publisher, err := gizmoGCP.NewPublisher(context.TODO(), pubsubConfig)
+		var publisher pubsub.MultiPublisher
+		var err error
+		err = async.Retry(reconnectAttempts, reconnectDelay, func() error {
+			publisher, err = gizmoGCP.NewPublisher(context.TODO(), pubsubConfig)
+			return err
+		})
+
 		if err != nil {
 			panic(err)
 		}
diff --git a/pkg/async/notifications/implementations/noop_notifications.go b/pkg/async/notifications/implementations/noop_notifications.go
index f9a0e1595..3dde2b4cb 100644
--- a/pkg/async/notifications/implementations/noop_notifications.go
+++ b/pkg/async/notifications/implementations/noop_notifications.go
@@ -39,9 +39,8 @@ func NewNoopPublish() interfaces.Publisher {
 
 type NoopProcess struct{}
 
-func (n *NoopProcess) StartProcessing() error {
+func (n *NoopProcess) StartProcessing() {
 	logger.Debug(context.Background(), "call to noop start processing.")
-	return nil
 }
 
 func (n *NoopProcess) StopProcessing() error {
diff --git a/pkg/async/notifications/implementations/processor.go b/pkg/async/notifications/implementations/processor.go
index a3d5ff9c8..fb23c2a4e 100644
--- a/pkg/async/notifications/implementations/processor.go
+++ b/pkg/async/notifications/implementations/processor.go
@@ -2,6 +2,9 @@ package implementations
 
 import (
 	"context"
+	"time"
+
+	"github.com/lyft/flyteadmin/pkg/async"
 
 	"github.com/lyft/flyteadmin/pkg/async/notifications/interfaces"
 
@@ -38,7 +41,16 @@ type Processor struct {
 // Currently only email is the supported notification because slack and pagerduty both use
 // email client to trigger those notifications.
 // When Pagerduty and other notifications are supported, a publisher per type should be created.
-func (p *Processor) StartProcessing() error {
+func (p *Processor) StartProcessing() {
+	for {
+		logger.Warningf(context.Background(), "Starting notifications processor")
+		err := p.run()
+		logger.Errorf(context.Background(), "error with running processor err: [%v] ", err)
+		time.Sleep(async.RetryDelay)
+	}
+}
+
+func (p *Processor) run() error {
 	var emailMessage admin.EmailMessage
 	var err error
 	for msg := range p.sub.Start() {
diff --git a/pkg/async/notifications/implementations/processor_test.go b/pkg/async/notifications/implementations/processor_test.go
index c714b5480..d7a752d31 100644
--- a/pkg/async/notifications/implementations/processor_test.go
+++ b/pkg/async/notifications/implementations/processor_test.go
@@ -41,13 +41,13 @@ func TestProcessor_StartProcessing(t *testing.T) {
 	mockEmailer.SetSendEmailFunc(sendEmailValidationFunc)
 	// TODO Add test for metric inc for number of messages processed.
 	// Assert 1 message processed and 1 total.
-	assert.Nil(t, testProcessor.StartProcessing())
+	assert.Nil(t, testProcessor.(*Processor).run())
 }
 
 func TestProcessor_StartProcessingNoMessages(t *testing.T) {
 	initializeProcessor()
 	// Expect no errors are returned.
-	assert.Nil(t, testProcessor.StartProcessing())
+	assert.Nil(t, testProcessor.(*Processor).run())
 	// TODO add test for metric inc() for number of messages processed.
 	// Assert 0 messages processed and 0 total.
 }
@@ -59,7 +59,7 @@ func TestProcessor_StartProcessingNoNotificationMessage(t *testing.T) {
 	}
 	initializeProcessor()
 	testSubscriber.JSONMessages = append(testSubscriber.JSONMessages, testMessage)
-	assert.Nil(t, testProcessor.StartProcessing())
+	assert.Nil(t, testProcessor.(*Processor).run())
 	// TODO add test for metric inc() for number of messages processed.
 	// Assert 1 messages error and 1 total.
 }
@@ -72,7 +72,7 @@ func TestProcessor_StartProcessingMessageWrongDataType(t *testing.T) {
 	}
 	initializeProcessor()
 	testSubscriber.JSONMessages = append(testSubscriber.JSONMessages, testMessage)
-	assert.Nil(t, testProcessor.StartProcessing())
+	assert.Nil(t, testProcessor.(*Processor).run())
 	// TODO add test for metric inc() for number of messages processed.
 	// Assert 1 messages error and 1 total.
 }
@@ -85,7 +85,7 @@ func TestProcessor_StartProcessingBase64DecodeError(t *testing.T) {
 	}
 	initializeProcessor()
 	testSubscriber.JSONMessages = append(testSubscriber.JSONMessages, testMessage)
-	assert.Nil(t, testProcessor.StartProcessing())
+	assert.Nil(t, testProcessor.(*Processor).run())
 	// TODO add test for metric inc() for number of messages processed.
 	// Assert 1 messages error and 1 total.
 }
@@ -99,7 +99,7 @@ func TestProcessor_StartProcessingProtoMarshallError(t *testing.T) {
 	}
 	initializeProcessor()
 	testSubscriber.JSONMessages = append(testSubscriber.JSONMessages, testMessage)
-	assert.Nil(t, testProcessor.StartProcessing())
+	assert.Nil(t, testProcessor.(*Processor).run())
 	// TODO add test for metric inc() for number of messages processed.
 	// Assert 1 messages error and 1 total.
 }
@@ -110,7 +110,7 @@ func TestProcessor_StartProcessingError(t *testing.T) {
 	// The error set by GivenErrError is returned by Err().
 	// Err() is checked before Run() returning.
 	testSubscriber.GivenErrError = ret
-	assert.Equal(t, ret, testProcessor.StartProcessing())
+	assert.Equal(t, ret, testProcessor.(*Processor).run())
 }
 
 func TestProcessor_StartProcessingEmailError(t *testing.T) {
@@ -124,7 +124,7 @@ func TestProcessor_StartProcessingEmailError(t *testing.T) {
 
 	// Even if there is an error in sending an email StartProcessing will return no errors.
 	// TODO: Once stats have been added check for an email error stat.
-	assert.Nil(t, testProcessor.StartProcessing())
+	assert.Nil(t, testProcessor.(*Processor).run())
 }
 
 func TestProcessor_StopProcessing(t *testing.T) {
diff --git a/pkg/async/notifications/interfaces/processor.go b/pkg/async/notifications/interfaces/processor.go
index 0a8534d57..b7ce30d56 100644
--- a/pkg/async/notifications/interfaces/processor.go
+++ b/pkg/async/notifications/interfaces/processor.go
@@ -8,7 +8,7 @@ type Processor interface {
 	// If the channel closes gracefully, no error will be returned.
 	// If the underlying channel experiences errors,
 	// an error is returned and the channel is closed.
-	StartProcessing() error
+	StartProcessing()
 
 	// This should be invoked when the application is shutting down.
 	// If StartProcessing() returned an error, StopProcessing() will return an error because
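With this interface change, `StartProcessing()` no longer reports errors to its caller: the SQS-backed implementation above supervises itself, restarting its internal `run()` loop after `async.RetryDelay` whenever the subscriber channel closes. A minimal sketch of the resulting calling contract, using the no-op implementation from this diff purely as a stand-in (the `main` wrapper is illustrative; the real processor is built via `notifications.NewNotificationsProcessor` and launched in a goroutine, as `pkg/rpc/adminservice/base.go` does further down):

```go
package main

import (
	"github.com/lyft/flyteadmin/pkg/async/notifications/implementations"
)

func main() {
	// Stand-in for the processor normally built by notifications.NewNotificationsProcessor.
	processor := &implementations.NoopProcess{}

	// StartProcessing has no return value any more; the SQS-backed implementation
	// blocks forever and restarts run() after async.RetryDelay on failure, so callers
	// launch it in a goroutine instead of checking an error.
	go processor.StartProcessing()

	// Shutdown still goes through StopProcessing, which keeps its error return.
	if err := processor.StopProcessing(); err != nil {
		panic(err)
	}
}
```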
diff --git a/pkg/async/schedule/aws/workflow_executor.go b/pkg/async/schedule/aws/workflow_executor.go
index 9934a432e..9eb253120 100644
--- a/pkg/async/schedule/aws/workflow_executor.go
+++ b/pkg/async/schedule/aws/workflow_executor.go
@@ -6,6 +6,9 @@ import (
 	"fmt"
 	"time"
 
+	"github.com/lyft/flyteadmin/pkg/async"
+	runtimeInterfaces "github.com/lyft/flyteadmin/pkg/runtime/interfaces"
+
 	"github.com/lyft/flytestdlib/contextutils"
 
 	"github.com/golang/protobuf/ptypes/timestamp"
@@ -42,6 +45,7 @@ type workflowExecutorMetrics struct {
 	MessageReceivedDelay          labeled.StopWatch
 	ScheduledEventProcessingDelay labeled.StopWatch
 	CreateExecutionDuration       labeled.StopWatch
+	ChannelClosedError            prometheus.Counter
 }
 
 type workflowExecutor struct {
@@ -168,6 +172,15 @@ func (e *workflowExecutor) formulateExecutionCreateRequest(
 }
 
 func (e *workflowExecutor) Run() {
+	for {
+		logger.Warningf(context.Background(), "Starting workflow executor")
+		err := e.run()
+		logger.Errorf(context.Background(), "error with workflow executor err: [%v] ", err)
+		time.Sleep(async.RetryDelay)
+	}
+}
+
+func (e *workflowExecutor) run() error {
 	for message := range e.subscriber.Start() {
 		scheduledWorkflowExecutionRequest, err := DeserializeScheduleWorkflowPayload(message.Message())
 		ctx := context.Background()
@@ -243,7 +256,11 @@ func (e *workflowExecutor) Run() {
 			observedMessageTriggeredTime)
 	}
 	err := e.subscriber.Err()
-	logger.Errorf(context.TODO(), "Gizmo subscriber closed channel with err: [%+v]", err)
+	if err != nil {
+		logger.Errorf(context.TODO(), "Gizmo subscriber closed channel with err: [%+v]", err)
+		e.metrics.ChannelClosedError.Inc()
+	}
+	return err
 }
 
 func (e *workflowExecutor) Stop() error {
@@ -286,18 +303,28 @@ func newWorkflowExecutorMetrics(scope promutils.Scope) workflowExecutorMetrics {
 		CreateExecutionDuration: labeled.NewStopWatch("create_execution_duration",
 			"time spent waiting on the call to CreateExecution to return", time.Second, scope,
 			labeled.EmitUnlabeledMetric),
+		ChannelClosedError: scope.MustNewCounter("channel_closed_error", "count of channel closing errors"),
 	}
 }
 
 func NewWorkflowExecutor(
-	config aws.SQSConfig, executionManager interfaces.ExecutionInterface,
+	config aws.SQSConfig, schedulerConfig runtimeInterfaces.SchedulerConfig, executionManager interfaces.ExecutionInterface,
 	launchPlanManager interfaces.LaunchPlanInterface, scope promutils.Scope) scheduleInterfaces.WorkflowExecutor {
 
 	config.TimeoutSeconds = &timeout
 	// By default gizmo tries to base64 decode messages. Since we don't use the gizmo publisher interface to publish
 	// messages these are not encoded in base64 by default. Disable this behavior.
 	config.ConsumeBase64 = &doNotconsumeBase64
-	subscriber, err := aws.NewSubscriber(config)
+
+	maxReconnectAttempts := schedulerConfig.ReconnectAttempts
+	reconnectDelay := time.Duration(schedulerConfig.ReconnectDelaySeconds) * time.Second
+	var subscriber pubsub.Subscriber
+	var err error
+	err = async.Retry(maxReconnectAttempts, reconnectDelay, func() error {
+		subscriber, err = aws.NewSubscriber(config)
+		return err
+	})
+
 	if err != nil {
 		scope.MustNewCounter(
 			"initialize_executor_failed", "failures initializing scheduled workflow executor").Inc()
diff --git a/pkg/async/schedule/aws/workflow_executor_test.go b/pkg/async/schedule/aws/workflow_executor_test.go
index e430ccaa7..a36da9929 100644
--- a/pkg/async/schedule/aws/workflow_executor_test.go
+++ b/pkg/async/schedule/aws/workflow_executor_test.go
@@ -283,8 +283,9 @@ func TestRun(t *testing.T) {
 		}, nil
 	})
 	testExecutor := newWorkflowExecutorForTest(&testSubscriber, &testExecutionManager, launchPlanManager)
-	testExecutor.Run()
+	err := testExecutor.run()
 	assert.Len(t, messages, messagesSeen)
+	assert.Nil(t, err)
 }
 
 func TestStop(t *testing.T) {
diff --git a/pkg/async/schedule/factory.go b/pkg/async/schedule/factory.go
index 0ef6c98e7..14e3b2bf8 100644
--- a/pkg/async/schedule/factory.go
+++ b/pkg/async/schedule/factory.go
@@ -2,6 +2,9 @@ package schedule
 
 import (
 	"context"
+	"time"
+
+	"github.com/lyft/flyteadmin/pkg/async"
 
 	gizmoConfig "github.com/NYTimes/gizmo/pubsub/aws"
 	"github.com/aws/aws-sdk-go/aws"
@@ -17,10 +20,9 @@ import (
 )
 
 type WorkflowSchedulerConfig struct {
-	Retries                int
-	EventSchedulerConfig   runtimeInterfaces.EventSchedulerConfig
-	WorkflowExecutorConfig runtimeInterfaces.WorkflowExecutorConfig
-	Scope                  promutils.Scope
+	Retries         int
+	SchedulerConfig runtimeInterfaces.SchedulerConfig
+	Scope           promutils.Scope
 }
 
 type WorkflowScheduler interface {
@@ -44,12 +46,12 @@ func (w *workflowScheduler) GetWorkflowExecutor(
 	launchPlanManager managerInterfaces.LaunchPlanInterface) interfaces.WorkflowExecutor {
 	if w.workflowExecutor == nil {
 		sqsConfig := gizmoConfig.SQSConfig{
-			QueueName:           w.cfg.WorkflowExecutorConfig.ScheduleQueueName,
-			QueueOwnerAccountID: w.cfg.WorkflowExecutorConfig.AccountID,
+			QueueName:           w.cfg.SchedulerConfig.WorkflowExecutorConfig.ScheduleQueueName,
+			QueueOwnerAccountID: w.cfg.SchedulerConfig.WorkflowExecutorConfig.AccountID,
 		}
-		sqsConfig.Region = w.cfg.WorkflowExecutorConfig.Region
+		sqsConfig.Region = w.cfg.SchedulerConfig.WorkflowExecutorConfig.Region
 		w.workflowExecutor = awsSchedule.NewWorkflowExecutor(
-			sqsConfig, executionManager, launchPlanManager, w.cfg.Scope.NewSubScope("workflow_executor"))
+			sqsConfig, w.cfg.SchedulerConfig, executionManager, launchPlanManager, w.cfg.Scope.NewSubScope("workflow_executor"))
 	}
 	return w.workflowExecutor
 }
@@ -58,26 +60,33 @@ func NewWorkflowScheduler(cfg WorkflowSchedulerConfig) WorkflowScheduler {
 	var eventScheduler interfaces.EventScheduler
 	var workflowExecutor interfaces.WorkflowExecutor
 
-	switch cfg.EventSchedulerConfig.Scheme {
+	switch cfg.SchedulerConfig.EventSchedulerConfig.Scheme {
 	case common.AWS:
-		awsConfig := aws.NewConfig().WithRegion(cfg.WorkflowExecutorConfig.Region).WithMaxRetries(cfg.Retries)
-		sess, err := session.NewSession(awsConfig)
+		awsConfig := aws.NewConfig().WithRegion(cfg.SchedulerConfig.WorkflowExecutorConfig.Region).WithMaxRetries(cfg.Retries)
+		var sess *session.Session
+		var err error
+		err = async.Retry(cfg.SchedulerConfig.ReconnectAttempts,
+			time.Duration(cfg.SchedulerConfig.ReconnectDelaySeconds)*time.Second, func() error {
+				sess, err = session.NewSession(awsConfig)
+				return err
+			})
+
 		if err != nil {
 			panic(err)
 		}
 		eventScheduler = awsSchedule.NewCloudWatchScheduler(
-			cfg.EventSchedulerConfig.ScheduleRole, cfg.EventSchedulerConfig.TargetName, sess, awsConfig,
+			cfg.SchedulerConfig.EventSchedulerConfig.ScheduleRole, cfg.SchedulerConfig.EventSchedulerConfig.TargetName, sess, awsConfig,
 			cfg.Scope.NewSubScope("cloudwatch_scheduler"))
 	case common.Local:
 		fallthrough
 	default:
 		logger.Infof(context.Background(),
 			"Using default noop event scheduler implementation for cloud provider type [%s]",
-			cfg.EventSchedulerConfig.Scheme)
+			cfg.SchedulerConfig.EventSchedulerConfig.Scheme)
 		eventScheduler = noop.NewNoopEventScheduler()
 	}
 
-	switch cfg.WorkflowExecutorConfig.Scheme {
+	switch cfg.SchedulerConfig.WorkflowExecutorConfig.Scheme {
 	case common.AWS:
 		// Do nothing, this special case depends on the execution manager and launch plan manager having been
 		// initialized and is handled in GetWorkflowExecutor.
@@ -87,7 +96,7 @@ func NewWorkflowScheduler(cfg WorkflowSchedulerConfig) WorkflowScheduler {
 	default:
 		logger.Infof(context.Background(),
 			"Using default noop workflow executor implementation for cloud provider type [%s]",
-			cfg.EventSchedulerConfig.Scheme)
+			cfg.SchedulerConfig.EventSchedulerConfig.Scheme)
 		workflowExecutor = noop.NewWorkflowExecutor()
 	}
 	return &workflowScheduler{
diff --git a/pkg/async/shared.go b/pkg/async/shared.go
new file mode 100644
index 000000000..bdfed2c4a
--- /dev/null
+++ b/pkg/async/shared.go
@@ -0,0 +1,25 @@
+package async
+
+import (
+	"context"
+	"time"
+
+	"github.com/lyft/flytestdlib/logger"
+)
+
+// RetryDelay indicates how long to wait between restarting a subscriber connection in the case of network failures.
+var RetryDelay = 30 * time.Second
+
+func Retry(attempts int, delay time.Duration, f func() error) error {
+	var err error
+	for attempt := 0; attempt <= attempts; attempt++ {
+		err = f()
+		if err == nil {
+			return nil
+		}
+		logger.Warningf(context.Background(),
+			"Failed [%v] on attempt %d of %d", err, attempt, attempts)
+		time.Sleep(delay)
+	}
+	return err
+}
diff --git a/pkg/async/shared_test.go b/pkg/async/shared_test.go
new file mode 100644
index 000000000..4b6a89551
--- /dev/null
+++ b/pkg/async/shared_test.go
@@ -0,0 +1,33 @@
+package async
+
+import (
+	"errors"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestRetry(t *testing.T) {
+	attemptsRecorded := 0
+	err := Retry(3, time.Millisecond, func() error {
+		if attemptsRecorded == 3 {
+			return nil
+		}
+		attemptsRecorded++
+		return errors.New("foo")
+	})
+	assert.Nil(t, err)
+}
+
+func TestRetry_RetriesExhausted(t *testing.T) {
+	attemptsRecorded := 0
+	err := Retry(2, time.Millisecond, func() error {
+		if attemptsRecorded == 3 {
+			return nil
+		}
+		attemptsRecorded++
+		return errors.New("foo")
+	})
+	assert.EqualError(t, err, "foo")
+}
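`async.Retry` above is the helper that the factories in this diff wrap around subscriber, publisher, and AWS session construction. A small usage sketch with illustrative failure counts; note that `attempts` bounds the retries, so the callback may run up to `attempts+1` times, and the delay is also slept after the final failure before the error is returned:

```go
package main

import (
	"errors"
	"fmt"
	"time"

	"github.com/lyft/flyteadmin/pkg/async"
)

func main() {
	calls := 0
	// With attempts=2 the callback may run up to three times (attempt 0 through 2),
	// sleeping 100ms after each failed attempt.
	err := async.Retry(2, 100*time.Millisecond, func() error {
		calls++
		if calls < 3 {
			return errors.New("transient connection failure")
		}
		return nil
	})
	fmt.Println(calls, err) // 3 <nil>: succeeded on the final allowed attempt
}
```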
diff --git a/pkg/rpc/adminservice/base.go b/pkg/rpc/adminservice/base.go
index 8757727f0..bedf4bd88 100644
--- a/pkg/rpc/adminservice/base.go
+++ b/pkg/rpc/adminservice/base.go
@@ -4,7 +4,6 @@ import (
 	"context"
 	"fmt"
 	"runtime/debug"
-	"time"
 
 	"github.com/lyft/flyteadmin/pkg/manager/impl/resources"
 
@@ -98,21 +97,16 @@ func NewAdminServer(kubeConfig, master string) *AdminService {
 	publisher := notifications.NewNotificationsPublisher(*configuration.ApplicationConfiguration().GetNotificationsConfig(), adminScope)
 	processor := notifications.NewNotificationsProcessor(*configuration.ApplicationConfiguration().GetNotificationsConfig(), adminScope)
 	go func() {
-		err = processor.StartProcessing()
-		if err != nil {
-			logger.Errorf(context.Background(), "error with starting processor err: [%v] ", err)
-		} else {
-			logger.Info(context.Background(), "Successfully started processing notifications.")
-		}
+		logger.Info(context.Background(), "Started processing notifications.")
+		processor.StartProcessing()
 	}()
 
 	// Configure workflow scheduler async processes.
 	schedulerConfig := configuration.ApplicationConfiguration().GetSchedulerConfig()
 	workflowScheduler := schedule.NewWorkflowScheduler(schedule.WorkflowSchedulerConfig{
-		Retries:                defaultRetries,
-		EventSchedulerConfig:   schedulerConfig.EventSchedulerConfig,
-		WorkflowExecutorConfig: schedulerConfig.WorkflowExecutorConfig,
-		Scope:                  adminScope,
+		Retries:         defaultRetries,
+		SchedulerConfig: *schedulerConfig,
+		Scope:           adminScope,
 	})
 
 	eventScheduler := workflowScheduler.GetEventScheduler()
@@ -143,17 +137,6 @@ func NewAdminServer(kubeConfig, master string) *AdminService {
 	go func() {
 		logger.Info(context.Background(), "Starting the scheduled workflow executor")
 		scheduledWorkflowExecutor.Run()
-
-		maxReconnectAttempts := configuration.ApplicationConfiguration().GetSchedulerConfig().
-			WorkflowExecutorConfig.ReconnectAttempts
-		reconnectDelay := time.Duration(configuration.ApplicationConfiguration().GetSchedulerConfig().
-			WorkflowExecutorConfig.ReconnectDelaySeconds) * time.Second
-		for reconnectAttempt := 0; reconnectAttempt < maxReconnectAttempts; reconnectAttempt++ {
-			time.Sleep(reconnectDelay)
-			logger.Warningf(context.Background(),
-				"Restarting scheduled workflow executor, attempt %d of %d", reconnectAttempt, maxReconnectAttempts)
-			scheduledWorkflowExecutor.Run()
-		}
 	}()
 
 	// Serve profiling endpoints.
diff --git a/pkg/runtime/interfaces/application_configuration.go b/pkg/runtime/interfaces/application_configuration.go
index 52306ed34..21c4da1f6 100644
--- a/pkg/runtime/interfaces/application_configuration.go
+++ b/pkg/runtime/interfaces/application_configuration.go
@@ -84,16 +84,16 @@ type WorkflowExecutorConfig struct {
 	// The account id (according to whichever cloud provider scheme is used) that has permission to read from the above
 	// queue.
 	AccountID string `json:"accountId"`
-	// Specifies the number of times to attempt recreating a workflow executor client should there be any disruptions.
-	ReconnectAttempts int `json:"reconnectAttempts"`
-	// Specifies the time interval to wait before attempting to reconnect the workflow executor client.
-	ReconnectDelaySeconds int `json:"reconnectDelaySeconds"`
 }
 
 // This configuration is the base configuration for all scheduler-related set-up.
 type SchedulerConfig struct {
 	EventSchedulerConfig   EventSchedulerConfig   `json:"eventScheduler"`
 	WorkflowExecutorConfig WorkflowExecutorConfig `json:"workflowExecutor"`
+	// Specifies the number of times to attempt recreating a workflow executor client should there be any disruptions.
+	ReconnectAttempts int `json:"reconnectAttempts"`
+	// Specifies the time interval to wait before attempting to reconnect the workflow executor client.
+	ReconnectDelaySeconds int `json:"reconnectDelaySeconds"`
 }
 
 // Configuration specific to setting up signed urls.
@@ -149,6 +149,10 @@ type NotificationsConfig struct {
 	NotificationsPublisherConfig NotificationsPublisherConfig `json:"publisher"`
 	NotificationsProcessorConfig NotificationsProcessorConfig `json:"processor"`
 	NotificationsEmailerConfig   NotificationsEmailerConfig   `json:"emailer"`
+	// Number of times to attempt recreating a notifications processor client should there be any disruptions.
+	ReconnectAttempts int `json:"reconnectAttempts"`
+	// Specifies the time interval to wait before attempting to reconnect the notifications processor client.
+	ReconnectDelaySeconds int `json:"reconnectDelaySeconds"`
 }
 
 // Domains are always globally set in the application config, whereas individual projects can be individually registered.
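For reference, the reconnect knobs removed from `WorkflowExecutorConfig` now live on `SchedulerConfig` (consumed by `pkg/async/schedule/factory.go` and the AWS workflow executor) and, separately, on `NotificationsConfig` (consumed by `pkg/async/notifications/factory.go`). A minimal sketch with illustrative values showing where they are set and how both call sites derive the retry delay:

```go
package main

import (
	"fmt"
	"time"

	runtimeInterfaces "github.com/lyft/flyteadmin/pkg/runtime/interfaces"
)

func main() {
	// Values are illustrative; in flyteadmin they come from the application configuration.
	schedulerCfg := runtimeInterfaces.SchedulerConfig{
		ReconnectAttempts:     3,  // moved up from WorkflowExecutorConfig
		ReconnectDelaySeconds: 30, // moved up from WorkflowExecutorConfig
	}
	notificationsCfg := runtimeInterfaces.NotificationsConfig{
		ReconnectAttempts:     3, // new: retries for the notifications pubsub clients
		ReconnectDelaySeconds: 30,
	}

	// Both factories convert the configured seconds into a time.Duration before
	// handing it to async.Retry, as the diffs above do.
	fmt.Println(time.Duration(schedulerCfg.ReconnectDelaySeconds) * time.Second)
	fmt.Println(time.Duration(notificationsCfg.ReconnectDelaySeconds) * time.Second)
}
```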