diff --git a/flyteadmin/go.mod b/flyteadmin/go.mod index b9eba5b83a..2eec0f8cf3 100644 --- a/flyteadmin/go.mod +++ b/flyteadmin/go.mod @@ -51,6 +51,7 @@ require ( github.com/wolfeidau/humanhash v1.1.0 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.47.0 go.opentelemetry.io/otel v1.24.0 + golang.org/x/net v0.27.0 golang.org/x/oauth2 v0.16.0 golang.org/x/time v0.5.0 google.golang.org/api v0.155.0 @@ -189,7 +190,6 @@ require ( go.opentelemetry.io/proto/otlp v1.1.0 // indirect golang.org/x/crypto v0.25.0 // indirect golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect - golang.org/x/net v0.27.0 // indirect golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.22.0 // indirect golang.org/x/term v0.22.0 // indirect diff --git a/flyteadmin/pkg/async/notifications/factory.go b/flyteadmin/pkg/async/notifications/factory.go index f94129a1d5..483978238e 100644 --- a/flyteadmin/pkg/async/notifications/factory.go +++ b/flyteadmin/pkg/async/notifications/factory.go @@ -18,6 +18,7 @@ import ( "github.com/flyteorg/flyte/flyteadmin/pkg/async/notifications/interfaces" "github.com/flyteorg/flyte/flyteadmin/pkg/common" runtimeInterfaces "github.com/flyteorg/flyte/flyteadmin/pkg/runtime/interfaces" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyte/flytestdlib/logger" "github.com/flyteorg/flyte/flytestdlib/promutils" ) @@ -27,6 +28,7 @@ const maxRetries = 3 var enable64decoding = false var msgChan chan []byte + var once sync.Once type PublisherConfig struct { @@ -35,220 +37,404 @@ type PublisherConfig struct { type ProcessorConfig struct { QueueName string + AccountID string } type EmailerConfig struct { SenderEmail string - BaseURL string + + BaseURL string } // For sandbox only + func CreateMsgChan() { + once.Do(func() { + msgChan = make(chan []byte) + }) + } -func GetEmailer(config runtimeInterfaces.NotificationsConfig, scope promutils.Scope) interfaces.Emailer { +func GetEmailer(config runtimeInterfaces.NotificationsConfig, scope promutils.Scope, sm core.SecretManager) interfaces.Emailer { + // If an external email service is specified use that instead. + // TODO: Handling of this is messy, see https://github.com/flyteorg/flyte/issues/1063 + if config.NotificationsEmailerConfig.EmailerConfig.ServiceName != "" { + switch config.NotificationsEmailerConfig.EmailerConfig.ServiceName { + case implementations.Sendgrid: + return implementations.NewSendGridEmailer(config, scope) + + case implementations.SMTP: + + return implementations.NewSMTPEmailer(context.Background(), config, scope, sm) + default: + panic(fmt.Errorf("No matching email implementation for %s", config.NotificationsEmailerConfig.EmailerConfig.ServiceName)) + } + } switch config.Type { + case common.AWS: + region := config.AWSConfig.Region + if region == "" { + region = config.Region + } + awsConfig := aws.NewConfig().WithRegion(region).WithMaxRetries(maxRetries) + awsSession, err := session.NewSession(awsConfig) + if err != nil { + panic(err) + } + sesClient := ses.New(awsSession) + return implementations.NewAwsEmailer( + config, + scope, + sesClient, ) + case common.Local: + fallthrough + default: + logger.Infof(context.Background(), "Using default noop emailer implementation for config type [%s]", config.Type) + return implementations.NewNoopEmail() + } + } -func NewNotificationsProcessor(config runtimeInterfaces.NotificationsConfig, scope promutils.Scope) interfaces.Processor { +func NewNotificationsProcessor(config runtimeInterfaces.NotificationsConfig, scope promutils.Scope, sm core.SecretManager) interfaces.Processor { + reconnectAttempts := config.ReconnectAttempts + reconnectDelay := time.Duration(config.ReconnectDelaySeconds) * time.Second + var sub pubsub.Subscriber + var emailer interfaces.Emailer + switch config.Type { + case common.AWS: + sqsConfig := gizmoAWS.SQSConfig{ - QueueName: config.NotificationsProcessorConfig.QueueName, + + QueueName: config.NotificationsProcessorConfig.QueueName, + QueueOwnerAccountID: config.NotificationsProcessorConfig.AccountID, + // The AWS configuration type uses SNS to SQS for notifications. + // Gizmo by default will decode the SQS message using Base64 decoding. + // However, the message body of SQS is the SNS message format which isn't Base64 encoded. + ConsumeBase64: &enable64decoding, } + if config.AWSConfig.Region != "" { + sqsConfig.Region = config.AWSConfig.Region + } else { + sqsConfig.Region = config.Region + } + var err error + err = async.Retry(reconnectAttempts, reconnectDelay, func() error { + sub, err = gizmoAWS.NewSubscriber(sqsConfig) + if err != nil { + logger.Warnf(context.TODO(), "Failed to initialize new gizmo aws subscriber with config [%+v] and err: %v", sqsConfig, err) + } + return err + }) if err != nil { + panic(err) + } - emailer = GetEmailer(config, scope) + + emailer = GetEmailer(config, scope, sm) + return implementations.NewProcessor(sub, emailer, scope) + case common.GCP: + projectID := config.GCPConfig.ProjectID + subscription := config.NotificationsProcessorConfig.QueueName + var err error + err = async.Retry(reconnectAttempts, reconnectDelay, func() error { + sub, err = gizmoGCP.NewSubscriber(context.TODO(), projectID, subscription) + if err != nil { + logger.Warnf(context.TODO(), "Failed to initialize new gizmo gcp subscriber with config [ProjectID: %s, Subscription: %s] and err: %v", projectID, subscription, err) + } + return err + }) + if err != nil { + panic(err) + } - emailer = GetEmailer(config, scope) + + emailer = GetEmailer(config, scope, sm) + return implementations.NewGcpProcessor(sub, emailer, scope) + case common.Sandbox: - emailer = GetEmailer(config, scope) + + emailer = GetEmailer(config, scope, sm) + return implementations.NewSandboxProcessor(msgChan, emailer) + case common.Local: + fallthrough + default: + logger.Infof(context.Background(), + "Using default noop notifications processor implementation for config type [%s]", config.Type) + return implementations.NewNoopProcess() + } + } func NewNotificationsPublisher(config runtimeInterfaces.NotificationsConfig, scope promutils.Scope) interfaces.Publisher { + reconnectAttempts := config.ReconnectAttempts + reconnectDelay := time.Duration(config.ReconnectDelaySeconds) * time.Second + switch config.Type { + case common.AWS: + snsConfig := gizmoAWS.SNSConfig{ + Topic: config.NotificationsPublisherConfig.TopicName, } + if config.AWSConfig.Region != "" { + snsConfig.Region = config.AWSConfig.Region + } else { + snsConfig.Region = config.Region + } var publisher pubsub.Publisher + var err error + err = async.Retry(reconnectAttempts, reconnectDelay, func() error { + publisher, err = gizmoAWS.NewPublisher(snsConfig) + return err + }) // Any persistent errors initiating Publisher with Amazon configurations results in a failed start up. + if err != nil { + panic(err) + } + return implementations.NewPublisher(publisher, scope) + case common.GCP: + pubsubConfig := gizmoGCP.Config{ + Topic: config.NotificationsPublisherConfig.TopicName, } + pubsubConfig.ProjectID = config.GCPConfig.ProjectID + var publisher pubsub.MultiPublisher + var err error + err = async.Retry(reconnectAttempts, reconnectDelay, func() error { + publisher, err = gizmoGCP.NewPublisher(context.TODO(), pubsubConfig) + return err + }) if err != nil { + panic(err) + } + return implementations.NewPublisher(publisher, scope) + case common.Sandbox: + CreateMsgChan() + return implementations.NewSandboxPublisher(msgChan) + case common.Local: + fallthrough + default: + logger.Infof(context.Background(), + "Using default noop notifications publisher implementation for config type [%s]", config.Type) + return implementations.NewNoopPublish() + } + } func NewEventsPublisher(config runtimeInterfaces.ExternalEventsConfig, scope promutils.Scope) interfaces.Publisher { + if !config.Enable { + return implementations.NewNoopPublish() + } + reconnectAttempts := config.ReconnectAttempts + reconnectDelay := time.Duration(config.ReconnectDelaySeconds) * time.Second + switch config.Type { + case common.AWS: + snsConfig := gizmoAWS.SNSConfig{ + Topic: config.EventsPublisherConfig.TopicName, } + snsConfig.Region = config.AWSConfig.Region var publisher pubsub.Publisher + var err error + err = async.Retry(reconnectAttempts, reconnectDelay, func() error { + publisher, err = gizmoAWS.NewPublisher(snsConfig) + return err + }) // Any persistent errors initiating Publisher with Amazon configurations results in a failed start up. + if err != nil { + panic(err) + } + return implementations.NewEventsPublisher(publisher, scope, config.EventsPublisherConfig.EventTypes) + case common.GCP: + pubsubConfig := gizmoGCP.Config{ + Topic: config.EventsPublisherConfig.TopicName, } + pubsubConfig.ProjectID = config.GCPConfig.ProjectID + var publisher pubsub.MultiPublisher + var err error + err = async.Retry(reconnectAttempts, reconnectDelay, func() error { + publisher, err = gizmoGCP.NewPublisher(context.TODO(), pubsubConfig) + return err + }) if err != nil { + panic(err) + } + return implementations.NewEventsPublisher(publisher, scope, config.EventsPublisherConfig.EventTypes) + case common.Local: + fallthrough + default: + logger.Infof(context.Background(), + "Using default noop events publisher implementation for config type [%s]", config.Type) + return implementations.NewNoopPublish() + } + } diff --git a/flyteadmin/pkg/async/notifications/factory_test.go b/flyteadmin/pkg/async/notifications/factory_test.go index 1bfd1f4596..43602525a5 100644 --- a/flyteadmin/pkg/async/notifications/factory_test.go +++ b/flyteadmin/pkg/async/notifications/factory_test.go @@ -9,52 +9,76 @@ import ( "github.com/flyteorg/flyte/flyteadmin/pkg/async/notifications/implementations" runtimeInterfaces "github.com/flyteorg/flyte/flyteadmin/pkg/runtime/interfaces" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" "github.com/flyteorg/flyte/flytestdlib/promutils" ) var ( - scope = promutils.NewScope("test_sandbox_processor") + scope = promutils.NewScope("test_sandbox_processor") + notificationsConfig = runtimeInterfaces.NotificationsConfig{ + Type: "sandbox", } + testEmail = admin.EmailMessage{ + RecipientsEmail: []string{ + "a@example.com", + "b@example.com", }, + SenderEmail: "no-reply@example.com", + SubjectLine: "Test email", - Body: "This is a sample email.", + + Body: "This is a sample email.", } ) func TestGetEmailer(t *testing.T) { + defer func() { r := recover(); assert.NotNil(t, r) }() + cfg := runtimeInterfaces.NotificationsConfig{ + NotificationsEmailerConfig: runtimeInterfaces.NotificationsEmailerConfig{ + EmailerConfig: runtimeInterfaces.EmailServerConfig{ + ServiceName: "unsupported", }, }, } - GetEmailer(cfg, promutils.NewTestScope()) + GetEmailer(cfg, promutils.NewTestScope(), &mocks.SecretManager{}) // shouldn't reach here + t.Errorf("did not panic") + } func TestNewNotificationPublisherAndProcessor(t *testing.T) { + testSandboxPublisher := NewNotificationsPublisher(notificationsConfig, scope) + assert.IsType(t, testSandboxPublisher, &implementations.SandboxPublisher{}) - testSandboxProcessor := NewNotificationsProcessor(notificationsConfig, scope) + + testSandboxProcessor := NewNotificationsProcessor(notificationsConfig, scope, &mocks.SecretManager{}) + assert.IsType(t, testSandboxProcessor, &implementations.SandboxProcessor{}) go func() { + testSandboxProcessor.StartProcessing() + }() assert.Nil(t, testSandboxPublisher.Publish(context.Background(), "TEST_NOTIFICATION", &testEmail)) assert.Nil(t, testSandboxProcessor.StopProcessing()) + } diff --git a/flyteadmin/pkg/async/notifications/implementations/emailers.go b/flyteadmin/pkg/async/notifications/implementations/emailers.go index e630b5a4ea..0da3fbf600 100644 --- a/flyteadmin/pkg/async/notifications/implementations/emailers.go +++ b/flyteadmin/pkg/async/notifications/implementations/emailers.go @@ -4,4 +4,5 @@ type ExternalEmailer = string const ( Sendgrid ExternalEmailer = "sendgrid" + SMTP ExternalEmailer = "smtp" ) diff --git a/flyteadmin/pkg/async/notifications/implementations/smtp_emailer.go b/flyteadmin/pkg/async/notifications/implementations/smtp_emailer.go new file mode 100644 index 0000000000..5a705bc0c1 --- /dev/null +++ b/flyteadmin/pkg/async/notifications/implementations/smtp_emailer.go @@ -0,0 +1,158 @@ +package implementations + +import ( + "crypto/tls" + "fmt" + "net/smtp" + "strings" + + "golang.org/x/net/context" + "google.golang.org/grpc/codes" + + "github.com/flyteorg/flyte/flyteadmin/pkg/async/notifications/interfaces" + "github.com/flyteorg/flyte/flyteadmin/pkg/errors" + runtimeInterfaces "github.com/flyteorg/flyte/flyteadmin/pkg/runtime/interfaces" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyte/flytestdlib/logger" + "github.com/flyteorg/flyte/flytestdlib/promutils" +) + +type SMTPEmailer struct { + config *runtimeInterfaces.NotificationsEmailerConfig + systemMetrics emailMetrics + tlsConf *tls.Config + auth *smtp.Auth + smtpClient interfaces.SMTPClient + CreateSMTPClientFunc func(connectString string) (interfaces.SMTPClient, error) +} + +func (s *SMTPEmailer) createClient(ctx context.Context) (interfaces.SMTPClient, error) { + newClient, err := s.CreateSMTPClientFunc(s.config.EmailerConfig.SMTPServer + ":" + s.config.EmailerConfig.SMTPPort) + + if err != nil { + return nil, s.emailError(ctx, fmt.Sprintf("Error creating email client: %s", err)) + } + + if err = newClient.Hello("localhost"); err != nil { + return nil, s.emailError(ctx, fmt.Sprintf("Error initiating connection to SMTP server: %s", err)) + } + + if ok, _ := newClient.Extension("STARTTLS"); ok { + if err = newClient.StartTLS(s.tlsConf); err != nil { + return nil, s.emailError(ctx, fmt.Sprintf("Error initiating connection to SMTP server: %s", err)) + } + } + + if ok, _ := newClient.Extension("AUTH"); ok { + if err = newClient.Auth(*s.auth); err != nil { + return nil, s.emailError(ctx, fmt.Sprintf("Error authenticating email client: %s", err)) + } + } + + return newClient, nil +} + +func (s *SMTPEmailer) SendEmail(ctx context.Context, email *admin.EmailMessage) error { + + if s.smtpClient == nil || s.smtpClient.Noop() != nil { + + if s.smtpClient != nil { + err := s.smtpClient.Close() + if err != nil { + logger.Info(ctx, err) + } + } + smtpClient, err := s.createClient(ctx) + + if err != nil { + return s.emailError(ctx, fmt.Sprintf("Error creating SMTP email client: %s", err)) + } + + s.smtpClient = smtpClient + } + + if err := s.smtpClient.Mail(email.SenderEmail); err != nil { + return s.emailError(ctx, fmt.Sprintf("Error creating email instance: %s", err)) + } + + for _, recipient := range email.RecipientsEmail { + if err := s.smtpClient.Rcpt(recipient); err != nil { + return s.emailError(ctx, fmt.Sprintf("Error adding email recipient: %s", err)) + } + } + + writer, err := s.smtpClient.Data() + + if err != nil { + return s.emailError(ctx, fmt.Sprintf("Error adding email recipient: %s", err)) + } + + _, err = writer.Write([]byte(createMailBody(s.config.Sender, email))) + + if err != nil { + return s.emailError(ctx, fmt.Sprintf("Error writing mail body: %s", err)) + } + + err = writer.Close() + + if err != nil { + return s.emailError(ctx, fmt.Sprintf("Error closing mail body: %s", err)) + } + + s.systemMetrics.SendSuccess.Inc() + return nil +} + +func (s *SMTPEmailer) emailError(ctx context.Context, error string) error { + s.systemMetrics.SendError.Inc() + logger.Error(ctx, error) + return errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails") +} + +func createMailBody(emailSender string, email *admin.EmailMessage) string { + headerMap := make(map[string]string) + headerMap["From"] = emailSender + headerMap["To"] = strings.Join(email.RecipientsEmail, ",") + headerMap["Subject"] = email.SubjectLine + headerMap["Content-Type"] = "text/html; charset=\"UTF-8\"" + + mailMessage := "" + + for k, v := range headerMap { + mailMessage += fmt.Sprintf("%s: %s\r\n", k, v) + } + + mailMessage += "\r\n" + email.Body + + return mailMessage +} + +func NewSMTPEmailer(ctx context.Context, config runtimeInterfaces.NotificationsConfig, scope promutils.Scope, sm core.SecretManager) interfaces.Emailer { + var tlsConfiguration *tls.Config + emailConf := config.NotificationsEmailerConfig.EmailerConfig + + smtpPassword, err := sm.Get(ctx, emailConf.SMTPPasswordSecretName) + if err != nil { + logger.Debug(ctx, "No SMTP password found.") + smtpPassword = "" + } + + auth := smtp.PlainAuth("", emailConf.SMTPUsername, smtpPassword, emailConf.SMTPServer) + + // #nosec G402 + tlsConfiguration = &tls.Config{ + InsecureSkipVerify: emailConf.SMTPSkipTLSVerify, + ServerName: emailConf.SMTPServer, + } + + return &SMTPEmailer{ + config: &config.NotificationsEmailerConfig, + systemMetrics: newEmailMetrics(scope.NewSubScope("smtp")), + tlsConf: tlsConfiguration, + auth: &auth, + CreateSMTPClientFunc: func(connectString string) (interfaces.SMTPClient, error) { + return smtp.Dial(connectString) + }, + } +} diff --git a/flyteadmin/pkg/async/notifications/implementations/smtp_emailer_test.go b/flyteadmin/pkg/async/notifications/implementations/smtp_emailer_test.go new file mode 100644 index 0000000000..558a5d6408 --- /dev/null +++ b/flyteadmin/pkg/async/notifications/implementations/smtp_emailer_test.go @@ -0,0 +1,498 @@ +package implementations + +import ( + "context" + "crypto/tls" + "errors" + "net/smtp" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/grpc/codes" + + notification_interfaces "github.com/flyteorg/flyte/flyteadmin/pkg/async/notifications/interfaces" + notification_mocks "github.com/flyteorg/flyte/flyteadmin/pkg/async/notifications/mocks" + flyte_errors "github.com/flyteorg/flyte/flyteadmin/pkg/errors" + "github.com/flyteorg/flyte/flyteadmin/pkg/runtime/interfaces" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "github.com/flyteorg/flyte/flytestdlib/promutils" +) + +type StringWriter struct { + buffer string + writeErr error + closeErr error +} + +func (s *StringWriter) Write(p []byte) (n int, err error) { + s.buffer = s.buffer + string(p) + return len(p), s.writeErr +} + +func (s *StringWriter) Close() error { + return s.closeErr +} + +func getNotificationsEmailerConfig() interfaces.NotificationsConfig { + return interfaces.NotificationsConfig{ + Type: "", + Region: "", + AWSConfig: interfaces.AWSConfig{}, + GCPConfig: interfaces.GCPConfig{}, + NotificationsPublisherConfig: interfaces.NotificationsPublisherConfig{}, + NotificationsProcessorConfig: interfaces.NotificationsProcessorConfig{}, + NotificationsEmailerConfig: interfaces.NotificationsEmailerConfig{ + EmailerConfig: interfaces.EmailServerConfig{ + ServiceName: SMTP, + SMTPServer: "smtpServer", + SMTPPort: "smtpPort", + SMTPUsername: "smtpUsername", + SMTPPasswordSecretName: "smtp_password", + }, + Subject: "subject", + Sender: "sender", + Body: "body"}, + ReconnectAttempts: 1, + ReconnectDelaySeconds: 2} +} + +func TestEmailCreation(t *testing.T) { + email := admin.EmailMessage{ + RecipientsEmail: []string{"john@doe.com", "teresa@tester.com"}, + SubjectLine: "subject", + Body: "Email Body", + SenderEmail: "sender@sender.com", + } + + body := createMailBody("sender@sender.com", &email) + assert.Contains(t, body, "From: sender@sender.com\r\n") + assert.Contains(t, body, "To: john@doe.com,teresa@tester.com") + assert.Contains(t, body, "Subject: subject\r\n") + assert.Contains(t, body, "Content-Type: text/html; charset=\"UTF-8\"\r\n") + assert.Contains(t, body, "Email Body") +} + +func TestNewSmtpEmailer(t *testing.T) { + secretManagerMock := mocks.SecretManager{} + secretManagerMock.On("Get", mock.Anything, "smtp_password").Return("password", nil) + + notificationsConfig := getNotificationsEmailerConfig() + + smtpEmailer := NewSMTPEmailer(context.Background(), notificationsConfig, promutils.NewTestScope(), &secretManagerMock) + + assert.NotNil(t, smtpEmailer) +} + +func TestCreateClient(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := ¬ification_mocks.SMTPClient{} + smtpClient.On("Hello", "localhost").Return(nil) + smtpClient.On("Extension", "STARTTLS").Return(true, "") + smtpClient.On("StartTLS", &tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil) + smtpClient.On("Extension", "AUTH").Return(true, "") + smtpClient.On("Auth", auth).Return(nil) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + client, err := smtpEmailer.createClient(context.Background()) + + assert.Nil(t, err) + assert.NotNil(t, client) + +} + +func TestCreateClientErrorCreatingClient(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := ¬ification_mocks.SMTPClient{} + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, errors.New("error creating client")) + + client, err := smtpEmailer.createClient(context.Background()) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + assert.Nil(t, client) + +} + +func TestCreateClientErrorHello(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := ¬ification_mocks.SMTPClient{} + smtpClient.On("Hello", "localhost").Return(errors.New("Error with hello")) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + client, err := smtpEmailer.createClient(context.Background()) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + assert.Nil(t, client) + +} + +func TestCreateClientErrorStartTLS(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := ¬ification_mocks.SMTPClient{} + smtpClient.On("Hello", "localhost").Return(nil).Times(1) + smtpClient.On("Extension", "STARTTLS").Return(true, "").Times(1) + smtpClient.On("StartTLS", &tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(errors.New("Error with startls")).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + client, err := smtpEmailer.createClient(context.Background()) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + assert.Nil(t, client) + +} + +func TestCreateClientErrorAuth(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := ¬ification_mocks.SMTPClient{} + smtpClient.On("Hello", "localhost").Return(nil).Times(1) + smtpClient.On("Extension", "STARTTLS").Return(true, "").Times(1) + smtpClient.On("StartTLS", &tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.On("Extension", "AUTH").Return(true, "").Times(1) + smtpClient.On("Auth", auth).Return(errors.New("Error with hello")).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + client, err := smtpEmailer.createClient(context.Background()) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + assert.Nil(t, client) + +} + +func TestSendMail(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + stringWriter := StringWriter{buffer: ""} + + smtpClient := ¬ification_mocks.SMTPClient{} + smtpClient.On("Noop").Return(errors.New("no connection")).Times(1) + smtpClient.On("Close").Return(nil).Times(1) + smtpClient.On("Hello", "localhost").Return(nil).Times(1) + smtpClient.On("Extension", "STARTTLS").Return(true, "").Times(1) + smtpClient.On("StartTLS", &tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.On("Extension", "AUTH").Return(true, "").Times(1) + smtpClient.On("Auth", auth).Return(nil).Times(1) + smtpClient.On("Mail", "flyte@flyte.org").Return(nil).Times(1) + smtpClient.On("Rcpt", "alice@flyte.org").Return(nil).Times(1) + smtpClient.On("Rcpt", "bob@flyte.org").Return(nil).Times(1) + smtpClient.On("Data").Return(&stringWriter, nil).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + err := smtpEmailer.SendEmail(context.Background(), &admin.EmailMessage{ + SubjectLine: "subject", + SenderEmail: "flyte@flyte.org", + RecipientsEmail: []string{"alice@flyte.org", "bob@flyte.org"}, + Body: "This is an email.", + }) + + assert.True(t, strings.Contains(stringWriter.buffer, "From: sender")) + assert.True(t, strings.Contains(stringWriter.buffer, "To: alice@flyte.org,bob@flyte.org")) + assert.True(t, strings.Contains(stringWriter.buffer, "Subject: subject")) + assert.True(t, strings.Contains(stringWriter.buffer, "This is an email.")) + assert.Nil(t, err) + +} + +func TestSendMailCreateClientError(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := ¬ification_mocks.SMTPClient{} + smtpClient.On("Noop").Return(errors.New("no connection")).Times(1) + smtpClient.On("Close").Return(nil).Times(1) + smtpClient.On("Hello", "localhost").Return(errors.New("error hello")).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + err := smtpEmailer.SendEmail(context.Background(), &admin.EmailMessage{ + SubjectLine: "subject", + SenderEmail: "flyte@flyte.org", + RecipientsEmail: []string{"alice@flyte.org", "bob@flyte.org"}, + Body: "This is an email.", + }) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + +} + +func TestSendMailErrorMail(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := ¬ification_mocks.SMTPClient{} + smtpClient.On("Noop").Return(errors.New("no connection")).Times(1) + smtpClient.On("Close").Return(nil).Times(1) + smtpClient.On("Hello", "localhost").Return(nil).Times(1) + smtpClient.On("Extension", "STARTTLS").Return(true, "").Times(1) + smtpClient.On("StartTLS", &tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.On("Extension", "AUTH").Return(true, "").Times(1) + smtpClient.On("Auth", auth).Return(nil).Times(1) + smtpClient.On("Mail", "flyte@flyte.org").Return(errors.New("error sending mail")).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + err := smtpEmailer.SendEmail(context.Background(), &admin.EmailMessage{ + SubjectLine: "subject", + SenderEmail: "flyte@flyte.org", + RecipientsEmail: []string{"alice@flyte.org", "bob@flyte.org"}, + Body: "This is an email.", + }) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + +} + +func TestSendMailErrorRecipient(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := ¬ification_mocks.SMTPClient{} + smtpClient.On("Noop").Return(errors.New("no connection")).Times(1) + smtpClient.On("Close").Return(nil).Times(1) + smtpClient.On("Hello", "localhost").Return(nil).Times(1) + smtpClient.On("Extension", "STARTTLS").Return(true, "").Times(1) + smtpClient.On("StartTLS", &tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.On("Extension", "AUTH").Return(true, "").Times(1) + smtpClient.On("Auth", auth).Return(nil).Times(1) + smtpClient.On("Mail", "flyte@flyte.org").Return(nil).Times(1) + smtpClient.On("Rcpt", "alice@flyte.org").Return(errors.New("error adding recipient")).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + err := smtpEmailer.SendEmail(context.Background(), &admin.EmailMessage{ + SubjectLine: "subject", + SenderEmail: "flyte@flyte.org", + RecipientsEmail: []string{"alice@flyte.org", "bob@flyte.org"}, + Body: "This is an email.", + }) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + +} + +func TestSendMailErrorData(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := ¬ification_mocks.SMTPClient{} + smtpClient.On("Noop").Return(errors.New("no connection")).Times(1) + smtpClient.On("Close").Return(nil).Times(1) + smtpClient.On("Hello", "localhost").Return(nil).Times(1) + smtpClient.On("Extension", "STARTTLS").Return(true, "").Times(1) + smtpClient.On("StartTLS", &tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.On("Extension", "AUTH").Return(true, "").Times(1) + smtpClient.On("Auth", auth).Return(nil).Times(1) + smtpClient.On("Mail", "flyte@flyte.org").Return(nil).Times(1) + smtpClient.On("Rcpt", "alice@flyte.org").Return(nil).Times(1) + smtpClient.On("Rcpt", "bob@flyte.org").Return(nil).Times(1) + smtpClient.On("Data").Return(nil, errors.New("error creating data writer")).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + err := smtpEmailer.SendEmail(context.Background(), &admin.EmailMessage{ + SubjectLine: "subject", + SenderEmail: "flyte@flyte.org", + RecipientsEmail: []string{"alice@flyte.org", "bob@flyte.org"}, + Body: "This is an email.", + }) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + +} + +func TestSendMailErrorWriting(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + stringWriter := StringWriter{buffer: "", writeErr: errors.New("error writing"), closeErr: nil} + + smtpClient := ¬ification_mocks.SMTPClient{} + smtpClient.On("Noop").Return(errors.New("no connection")).Times(1) + smtpClient.On("Close").Return(nil).Times(1) + smtpClient.On("Hello", "localhost").Return(nil).Times(1) + smtpClient.On("Extension", "STARTTLS").Return(true, "").Times(1) + smtpClient.On("StartTLS", &tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.On("Extension", "AUTH").Return(true, "").Times(1) + smtpClient.On("Auth", auth).Return(nil).Times(1) + smtpClient.On("Mail", "flyte@flyte.org").Return(nil).Times(1) + smtpClient.On("Rcpt", "alice@flyte.org").Return(nil).Times(1) + smtpClient.On("Rcpt", "bob@flyte.org").Return(nil).Times(1) + smtpClient.On("Data").Return(&stringWriter, nil).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + err := smtpEmailer.SendEmail(context.Background(), &admin.EmailMessage{ + SubjectLine: "subject", + SenderEmail: "flyte@flyte.org", + RecipientsEmail: []string{"alice@flyte.org", "bob@flyte.org"}, + Body: "This is an email.", + }) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + +} + +func TestSendMailErrorClose(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + stringWriter := StringWriter{buffer: "", writeErr: nil, closeErr: errors.New("error writing")} + + smtpClient := ¬ification_mocks.SMTPClient{} + smtpClient.On("Noop").Return(errors.New("no connection")).Times(1) + smtpClient.On("Close").Return(nil).Times(1) + smtpClient.On("Hello", "localhost").Return(nil).Times(1) + smtpClient.On("Extension", "STARTTLS").Return(true, "").Times(1) + smtpClient.On("StartTLS", &tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.On("Extension", "AUTH").Return(true, "").Times(1) + smtpClient.On("Auth", auth).Return(nil).Times(1) + smtpClient.On("Mail", "flyte@flyte.org").Return(nil).Times(1) + smtpClient.On("Rcpt", "alice@flyte.org").Return(nil).Times(1) + smtpClient.On("Rcpt", "bob@flyte.org").Return(nil).Times(1) + smtpClient.On("Data").Return(&stringWriter, nil).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + err := smtpEmailer.SendEmail(context.Background(), &admin.EmailMessage{ + SubjectLine: "subject", + SenderEmail: "flyte@flyte.org", + RecipientsEmail: []string{"alice@flyte.org", "bob@flyte.org"}, + Body: "This is an email.", + }) + + assert.True(t, strings.Contains(stringWriter.buffer, "From: sender")) + assert.True(t, strings.Contains(stringWriter.buffer, "To: alice@flyte.org,bob@flyte.org")) + assert.True(t, strings.Contains(stringWriter.buffer, "Subject: subject")) + assert.True(t, strings.Contains(stringWriter.buffer, "This is an email.")) + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + +} + +func createSMTPEmailer(smtpClient notification_interfaces.SMTPClient, tlsConf *tls.Config, auth *smtp.Auth, creationErr error) *SMTPEmailer { + secretManagerMock := mocks.SecretManager{} + secretManagerMock.On("Get", mock.Anything, "smtp_password").Return("password", nil) + + notificationsConfig := getNotificationsEmailerConfig() + + return &SMTPEmailer{ + config: ¬ificationsConfig.NotificationsEmailerConfig, + systemMetrics: newEmailMetrics(promutils.NewTestScope()), + tlsConf: tlsConf, + auth: auth, + CreateSMTPClientFunc: func(connectString string) (notification_interfaces.SMTPClient, error) { + return smtpClient, creationErr + }, + smtpClient: smtpClient, + } +} diff --git a/flyteadmin/pkg/async/notifications/interfaces/smtp_client.go b/flyteadmin/pkg/async/notifications/interfaces/smtp_client.go new file mode 100644 index 0000000000..bdc6171f46 --- /dev/null +++ b/flyteadmin/pkg/async/notifications/interfaces/smtp_client.go @@ -0,0 +1,22 @@ +package interfaces + +import ( + "crypto/tls" + "io" + "net/smtp" +) + +// This interface is introduced to allow for mocking of the smtp.Client object. + +//go:generate mockery -name=SMTPClient -output=../mocks -case=underscore +type SMTPClient interface { + Hello(localName string) error + Extension(ext string) (bool, string) + Auth(a smtp.Auth) error + StartTLS(config *tls.Config) error + Noop() error + Close() error + Mail(from string) error + Rcpt(to string) error + Data() (io.WriteCloser, error) +} diff --git a/flyteadmin/pkg/async/notifications/mocks/smtp_client.go b/flyteadmin/pkg/async/notifications/mocks/smtp_client.go new file mode 100644 index 0000000000..11dafefc9c --- /dev/null +++ b/flyteadmin/pkg/async/notifications/mocks/smtp_client.go @@ -0,0 +1,321 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + io "io" + smtp "net/smtp" + + mock "github.com/stretchr/testify/mock" + + tls "crypto/tls" +) + +// SMTPClient is an autogenerated mock type for the SMTPClient type +type SMTPClient struct { + mock.Mock +} + +type SMTPClient_Auth struct { + *mock.Call +} + +func (_m SMTPClient_Auth) Return(_a0 error) *SMTPClient_Auth { + return &SMTPClient_Auth{Call: _m.Call.Return(_a0)} +} + +func (_m *SMTPClient) OnAuth(a smtp.Auth) *SMTPClient_Auth { + c_call := _m.On("Auth", a) + return &SMTPClient_Auth{Call: c_call} +} + +func (_m *SMTPClient) OnAuthMatch(matchers ...interface{}) *SMTPClient_Auth { + c_call := _m.On("Auth", matchers...) + return &SMTPClient_Auth{Call: c_call} +} + +// Auth provides a mock function with given fields: a +func (_m *SMTPClient) Auth(a smtp.Auth) error { + ret := _m.Called(a) + + var r0 error + if rf, ok := ret.Get(0).(func(smtp.Auth) error); ok { + r0 = rf(a) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type SMTPClient_Close struct { + *mock.Call +} + +func (_m SMTPClient_Close) Return(_a0 error) *SMTPClient_Close { + return &SMTPClient_Close{Call: _m.Call.Return(_a0)} +} + +func (_m *SMTPClient) OnClose() *SMTPClient_Close { + c_call := _m.On("Close") + return &SMTPClient_Close{Call: c_call} +} + +func (_m *SMTPClient) OnCloseMatch(matchers ...interface{}) *SMTPClient_Close { + c_call := _m.On("Close", matchers...) + return &SMTPClient_Close{Call: c_call} +} + +// Close provides a mock function with given fields: +func (_m *SMTPClient) Close() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type SMTPClient_Data struct { + *mock.Call +} + +func (_m SMTPClient_Data) Return(_a0 io.WriteCloser, _a1 error) *SMTPClient_Data { + return &SMTPClient_Data{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *SMTPClient) OnData() *SMTPClient_Data { + c_call := _m.On("Data") + return &SMTPClient_Data{Call: c_call} +} + +func (_m *SMTPClient) OnDataMatch(matchers ...interface{}) *SMTPClient_Data { + c_call := _m.On("Data", matchers...) + return &SMTPClient_Data{Call: c_call} +} + +// Data provides a mock function with given fields: +func (_m *SMTPClient) Data() (io.WriteCloser, error) { + ret := _m.Called() + + var r0 io.WriteCloser + if rf, ok := ret.Get(0).(func() io.WriteCloser); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(io.WriteCloser) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type SMTPClient_Extension struct { + *mock.Call +} + +func (_m SMTPClient_Extension) Return(_a0 bool, _a1 string) *SMTPClient_Extension { + return &SMTPClient_Extension{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *SMTPClient) OnExtension(ext string) *SMTPClient_Extension { + c_call := _m.On("Extension", ext) + return &SMTPClient_Extension{Call: c_call} +} + +func (_m *SMTPClient) OnExtensionMatch(matchers ...interface{}) *SMTPClient_Extension { + c_call := _m.On("Extension", matchers...) + return &SMTPClient_Extension{Call: c_call} +} + +// Extension provides a mock function with given fields: ext +func (_m *SMTPClient) Extension(ext string) (bool, string) { + ret := _m.Called(ext) + + var r0 bool + if rf, ok := ret.Get(0).(func(string) bool); ok { + r0 = rf(ext) + } else { + r0 = ret.Get(0).(bool) + } + + var r1 string + if rf, ok := ret.Get(1).(func(string) string); ok { + r1 = rf(ext) + } else { + r1 = ret.Get(1).(string) + } + + return r0, r1 +} + +type SMTPClient_Hello struct { + *mock.Call +} + +func (_m SMTPClient_Hello) Return(_a0 error) *SMTPClient_Hello { + return &SMTPClient_Hello{Call: _m.Call.Return(_a0)} +} + +func (_m *SMTPClient) OnHello(localName string) *SMTPClient_Hello { + c_call := _m.On("Hello", localName) + return &SMTPClient_Hello{Call: c_call} +} + +func (_m *SMTPClient) OnHelloMatch(matchers ...interface{}) *SMTPClient_Hello { + c_call := _m.On("Hello", matchers...) + return &SMTPClient_Hello{Call: c_call} +} + +// Hello provides a mock function with given fields: localName +func (_m *SMTPClient) Hello(localName string) error { + ret := _m.Called(localName) + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(localName) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type SMTPClient_Mail struct { + *mock.Call +} + +func (_m SMTPClient_Mail) Return(_a0 error) *SMTPClient_Mail { + return &SMTPClient_Mail{Call: _m.Call.Return(_a0)} +} + +func (_m *SMTPClient) OnMail(from string) *SMTPClient_Mail { + c_call := _m.On("Mail", from) + return &SMTPClient_Mail{Call: c_call} +} + +func (_m *SMTPClient) OnMailMatch(matchers ...interface{}) *SMTPClient_Mail { + c_call := _m.On("Mail", matchers...) + return &SMTPClient_Mail{Call: c_call} +} + +// Mail provides a mock function with given fields: from +func (_m *SMTPClient) Mail(from string) error { + ret := _m.Called(from) + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(from) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type SMTPClient_Noop struct { + *mock.Call +} + +func (_m SMTPClient_Noop) Return(_a0 error) *SMTPClient_Noop { + return &SMTPClient_Noop{Call: _m.Call.Return(_a0)} +} + +func (_m *SMTPClient) OnNoop() *SMTPClient_Noop { + c_call := _m.On("Noop") + return &SMTPClient_Noop{Call: c_call} +} + +func (_m *SMTPClient) OnNoopMatch(matchers ...interface{}) *SMTPClient_Noop { + c_call := _m.On("Noop", matchers...) + return &SMTPClient_Noop{Call: c_call} +} + +// Noop provides a mock function with given fields: +func (_m *SMTPClient) Noop() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type SMTPClient_Rcpt struct { + *mock.Call +} + +func (_m SMTPClient_Rcpt) Return(_a0 error) *SMTPClient_Rcpt { + return &SMTPClient_Rcpt{Call: _m.Call.Return(_a0)} +} + +func (_m *SMTPClient) OnRcpt(to string) *SMTPClient_Rcpt { + c_call := _m.On("Rcpt", to) + return &SMTPClient_Rcpt{Call: c_call} +} + +func (_m *SMTPClient) OnRcptMatch(matchers ...interface{}) *SMTPClient_Rcpt { + c_call := _m.On("Rcpt", matchers...) + return &SMTPClient_Rcpt{Call: c_call} +} + +// Rcpt provides a mock function with given fields: to +func (_m *SMTPClient) Rcpt(to string) error { + ret := _m.Called(to) + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(to) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type SMTPClient_StartTLS struct { + *mock.Call +} + +func (_m SMTPClient_StartTLS) Return(_a0 error) *SMTPClient_StartTLS { + return &SMTPClient_StartTLS{Call: _m.Call.Return(_a0)} +} + +func (_m *SMTPClient) OnStartTLS(config *tls.Config) *SMTPClient_StartTLS { + c_call := _m.On("StartTLS", config) + return &SMTPClient_StartTLS{Call: c_call} +} + +func (_m *SMTPClient) OnStartTLSMatch(matchers ...interface{}) *SMTPClient_StartTLS { + c_call := _m.On("StartTLS", matchers...) + return &SMTPClient_StartTLS{Call: c_call} +} + +// StartTLS provides a mock function with given fields: config +func (_m *SMTPClient) StartTLS(config *tls.Config) error { + ret := _m.Called(config) + + var r0 error + if rf, ok := ret.Get(0).(func(*tls.Config) error); ok { + r0 = rf(config) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/flyteadmin/pkg/rpc/adminservice/base.go b/flyteadmin/pkg/rpc/adminservice/base.go index 8df2c595c7..491a24a1f0 100644 --- a/flyteadmin/pkg/rpc/adminservice/base.go +++ b/flyteadmin/pkg/rpc/adminservice/base.go @@ -20,6 +20,7 @@ import ( workflowengineImpl "github.com/flyteorg/flyte/flyteadmin/pkg/workflowengine/impl" "github.com/flyteorg/flyte/flyteadmin/plugins" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyte/flytestdlib/logger" "github.com/flyteorg/flyte/flytestdlib/promutils" "github.com/flyteorg/flyte/flytestdlib/storage" @@ -45,7 +46,7 @@ type AdminService struct { const defaultRetries = 3 func NewAdminServer(ctx context.Context, pluginRegistry *plugins.Registry, configuration runtimeIfaces.Configuration, - kubeConfig, master string, dataStorageClient *storage.DataStore, adminScope promutils.Scope) *AdminService { + kubeConfig, master string, dataStorageClient *storage.DataStore, adminScope promutils.Scope, sm core.SecretManager) *AdminService { applicationConfiguration := configuration.ApplicationConfiguration().GetTopLevelConfig() panicCounter := adminScope.MustNewCounter("initialization_panic", @@ -81,7 +82,7 @@ func NewAdminServer(ctx context.Context, pluginRegistry *plugins.Registry, confi pluginRegistry.RegisterDefault(plugins.PluginIDWorkflowExecutor, workflowExecutor) publisher := notifications.NewNotificationsPublisher(*configuration.ApplicationConfiguration().GetNotificationsConfig(), adminScope) - processor := notifications.NewNotificationsProcessor(*configuration.ApplicationConfiguration().GetNotificationsConfig(), adminScope) + processor := notifications.NewNotificationsProcessor(*configuration.ApplicationConfiguration().GetNotificationsConfig(), adminScope, sm) eventPublisher := notifications.NewEventsPublisher(*configuration.ApplicationConfiguration().GetExternalEventsConfig(), adminScope) go func() { logger.Info(ctx, "Started processing notifications.") diff --git a/flyteadmin/pkg/runtime/interfaces/application_configuration.go b/flyteadmin/pkg/runtime/interfaces/application_configuration.go index 3505150919..e3453db0f7 100644 --- a/flyteadmin/pkg/runtime/interfaces/application_configuration.go +++ b/flyteadmin/pkg/runtime/interfaces/application_configuration.go @@ -492,8 +492,13 @@ type NotificationsProcessorConfig struct { type EmailServerConfig struct { ServiceName string `json:"serviceName"` // Only one of these should be set. - APIKeyEnvVar string `json:"apiKeyEnvVar"` - APIKeyFilePath string `json:"apiKeyFilePath"` + APIKeyEnvVar string `json:"apiKeyEnvVar"` + APIKeyFilePath string `json:"apiKeyFilePath"` + SMTPServer string `json:"smtpServer"` + SMTPPort string `json:"smtpPort"` + SMTPSkipTLSVerify bool `json:"smtpSkipTLSVerify"` + SMTPUsername string `json:"smtpUsername"` + SMTPPasswordSecretName string `json:"smtpPasswordSecretName"` } // This section handles the configuration of notifications emails. diff --git a/flyteadmin/pkg/server/service.go b/flyteadmin/pkg/server/service.go index 587ea86e3b..840d0d9f17 100644 --- a/flyteadmin/pkg/server/service.go +++ b/flyteadmin/pkg/server/service.go @@ -43,6 +43,7 @@ import ( "github.com/flyteorg/flyte/flyteidl/clients/go/assets" grpcService "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/gateway/flyteidl/service" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/task/secretmanager" "github.com/flyteorg/flyte/flytestdlib/contextutils" "github.com/flyteorg/flyte/flytestdlib/logger" @@ -82,7 +83,7 @@ func SetMetricKeys(appConfig *runtimeIfaces.ApplicationConfig) { // Creates a new gRPC Server with all the configuration func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *config.ServerConfig, storageCfg *storage.Config, authCtx interfaces.AuthenticationContext, - scope promutils.Scope, opts ...grpc.ServerOption) (*grpc.Server, error) { + scope promutils.Scope, sm core.SecretManager, opts ...grpc.ServerOption) (*grpc.Server, error) { logger.Infof(ctx, "Registering default middleware with blanket auth validation") pluginRegistry.RegisterDefault(plugins.PluginIDUnaryServiceMiddleware, grpcmiddleware.ChainUnaryServer( @@ -152,7 +153,7 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c } configuration := runtime2.NewConfigurationProvider() - adminServer := adminservice.NewAdminServer(ctx, pluginRegistry, configuration, cfg.KubeConfig, cfg.Master, dataStorageClient, adminScope) + adminServer := adminservice.NewAdminServer(ctx, pluginRegistry, configuration, cfg.KubeConfig, cfg.Master, dataStorageClient, adminScope, sm) grpcService.RegisterAdminServiceServer(grpcServer, adminServer) if cfg.Security.UseAuth { grpcService.RegisterAuthMetadataServiceServer(grpcServer, authCtx.AuthMetadataService()) @@ -339,12 +340,15 @@ func serveGatewayInsecure(ctx context.Context, pluginRegistry *plugins.Registry, // This will parse configuration and create the necessary objects for dealing with auth var authCtx interfaces.AuthenticationContext var err error + + sm := secretmanager.NewFileEnvSecretManager(secretmanager.GetConfig()) + // This code is here to support authentication without SSL. This setup supports a network topology where // Envoy does the SSL termination. The final hop is made over localhost only on a trusted machine. // Warning: Running authentication without SSL in any other topology is a severe security flaw. // See the auth.Config object for additional settings as well. if cfg.Security.UseAuth { - sm := secretmanager.NewFileEnvSecretManager(secretmanager.GetConfig()) + var oauth2Provider interfaces.OAuth2Provider var oauth2ResourceServer interfaces.OAuth2ResourceServer if authCfg.AppAuth.AuthServerType == authConfig.AuthorizationServerTypeSelf { @@ -373,7 +377,7 @@ func serveGatewayInsecure(ctx context.Context, pluginRegistry *plugins.Registry, } } - grpcServer, err := newGRPCServer(ctx, pluginRegistry, cfg, storageConfig, authCtx, scope) + grpcServer, err := newGRPCServer(ctx, pluginRegistry, cfg, storageConfig, authCtx, scope, sm) if err != nil { return fmt.Errorf("failed to create a newGRPCServer. Error: %w", err) } @@ -448,13 +452,14 @@ func serveGatewaySecure(ctx context.Context, pluginRegistry *plugins.Registry, c additionalHandlers map[string]func(http.ResponseWriter, *http.Request), scope promutils.Scope) error { certPool, cert, err := GetSslCredentials(ctx, cfg.Security.Ssl.CertificateFile, cfg.Security.Ssl.KeyFile) + sm := secretmanager.NewFileEnvSecretManager(secretmanager.GetConfig()) + if err != nil { return err } // This will parse configuration and create the necessary objects for dealing with auth var authCtx interfaces.AuthenticationContext if cfg.Security.UseAuth { - sm := secretmanager.NewFileEnvSecretManager(secretmanager.GetConfig()) var oauth2Provider interfaces.OAuth2Provider var oauth2ResourceServer interfaces.OAuth2ResourceServer if authCfg.AppAuth.AuthServerType == authConfig.AuthorizationServerTypeSelf { @@ -483,7 +488,7 @@ func serveGatewaySecure(ctx context.Context, pluginRegistry *plugins.Registry, c } } - grpcServer, err := newGRPCServer(ctx, pluginRegistry, cfg, storageCfg, authCtx, scope, grpc.Creds(credentials.NewServerTLSFromCert(cert))) + grpcServer, err := newGRPCServer(ctx, pluginRegistry, cfg, storageCfg, authCtx, scope, sm, grpc.Creds(credentials.NewServerTLSFromCert(cert))) if err != nil { return fmt.Errorf("failed to create a newGRPCServer. Error: %w", err) }