diff --git a/emailsender/cmd/app/main.go b/emailsender/cmd/app/main.go index 11a5378a23..d590972d32 100644 --- a/emailsender/cmd/app/main.go +++ b/emailsender/cmd/app/main.go @@ -5,6 +5,7 @@ import ( "context" "errors" "flag" + "github.com/stackrox/acs-fleet-manager/emailsender/pkg/db" "net/http" "os" "os/signal" @@ -36,10 +37,19 @@ func main() { os.Exit(1) } + dbCfg := cfg.DatabaseConfig.GetDbConfig() + if err = dbCfg.ReadFiles(); err != nil { + glog.Warningf("Failed to read DB configuration from files: %v", err) + glog.Warning("Use DB configuration from plain environment variables") + } + ctx := context.Background() // initialize components - sesClient, err := email.NewSES(ctx) + dbConnection := db.NewDatabaseConnection(dbCfg) + // TODO(ROX-23260): connect Rate Limiter to Email Sender + _ = email.NewRateLimiterService(dbConnection) + sesClient, err := email.NewSES(ctx, cfg.SesMaxBackoffDelay, cfg.SesMaxAttempts) if err != nil { glog.Errorf("Failed to initialise SES Client: %v", err) os.Exit(1) diff --git a/emailsender/config/config.go b/emailsender/config/config.go index 6e887c4043..4079314667 100644 --- a/emailsender/config/config.go +++ b/emailsender/config/config.go @@ -9,11 +9,13 @@ import ( "os" "path" "strings" + "time" "github.com/caarlos0/env/v6" "gopkg.in/yaml.v2" "github.com/pkg/errors" + commonDbConfig "github.com/stackrox/acs-fleet-manager/pkg/db" "github.com/stackrox/acs-fleet-manager/pkg/shared" "github.com/stackrox/rox/pkg/errorhelpers" "github.com/stackrox/rox/pkg/utils" @@ -28,16 +30,54 @@ const ( // Config contains this application's runtime configuration. type Config struct { - ClusterID string `env:"CLUSTER_ID"` - ServerAddress string `env:"SERVER_ADDRESS" envDefault:":8080"` - EnableHTTPS bool `env:"ENABLE_HTTPS" envDefault:"false"` - HTTPSCertFile string `env:"HTTPS_CERT_FILE" envDefault:""` - HTTPSKeyFile string `env:"HTTPS_KEY_FILE" envDefault:""` - MetricsAddress string `env:"METRICS_ADDRESS" envDefault:":9090"` - AuthConfigFile string `env:"AUTH_CONFIG_FILE" envDefault:"config/emailsender-authz.yaml"` - AuthConfigFromKubernetes bool `env:"AUTH_CONFIG_FROM_KUBERNETES" envDefault:"false"` - SenderAddress string `env:"SENDER_ADDRESS" envDefault:"noreply@mail.rhacs-dev.com"` + ClusterID string `env:"CLUSTER_ID"` + ServerAddress string `env:"SERVER_ADDRESS" envDefault:":8080"` + EnableHTTPS bool `env:"ENABLE_HTTPS" envDefault:"false"` + HTTPSCertFile string `env:"HTTPS_CERT_FILE" envDefault:""` + HTTPSKeyFile string `env:"HTTPS_KEY_FILE" envDefault:""` + MetricsAddress string `env:"METRICS_ADDRESS" envDefault:":9090"` + AuthConfigFile string `env:"AUTH_CONFIG_FILE" envDefault:"config/emailsender-authz.yaml"` + AuthConfigFromKubernetes bool `env:"AUTH_CONFIG_FROM_KUBERNETES" envDefault:"false"` + SenderAddress string `env:"SENDER_ADDRESS" envDefault:"noreply@mail.rhacs-dev.com"` + LimitEmailPerTenant int `env:"LIMIT_EMAIL_PER_TENANT" envDefault:"250"` + SesMaxBackoffDelay time.Duration `env:"SES_MAX_BACKOFF_DELAY" envDefault:"5s"` + SesMaxAttempts int `env:"SES_MAX_ATTEMPTS" envDefault:"3"` AuthConfig AuthConfig + DatabaseConfig DbConfig +} + +type DbConfig struct { + HostFile string `env:"DATABASE_HOST_FILE" envDefault:"secrets/db.host"` + PortFile string `env:"DATABASE_PORT_FILE" envDefault:"secrets/db.port"` + NameFile string `env:"DATABASE_NAME_FILE" envDefault:"secrets/db.name"` + UserFile string `env:"DATABASE_USER_FILE" envDefault:"secrets/db.user"` + PasswordFile string `env:"DATABASE_PASSWORD_FILE" envDefault:"secrets/db.password"` + CaCertFile string `env:"DATABASE_CA_CERT_FILE" envDefault:"secrets/db.ca_cert"` + Host string `env:"DATABASE_HOST" envDefault:"localhost"` + Port int `env:"DATABASE_PORT" envDefault:"5432"` + Name string `env:"DATABASE_NAME" envDefault:"postgres"` + User string `env:"DATABASE_USER" envDefault:"postgres"` + Password string `env:"DATABASE_PASSWORD" envDefault:"postgres"` + SSLMode string `env:"DATABASE_SSL_MODE" envDefault:"disable"` + MaxOpenConnections int `env:"DATABASE_MAX_CONNECTIONS" envDefault:"50"` +} + +func (d *DbConfig) GetDbConfig() *commonDbConfig.DatabaseConfig { + cfg := commonDbConfig.NewDatabaseConfig() + cfg.SSLMode = d.SSLMode + cfg.MaxOpenConnections = d.MaxOpenConnections + cfg.HostFile = d.HostFile + cfg.PortFile = d.PortFile + cfg.NameFile = d.NameFile + cfg.UsernameFile = d.UserFile + cfg.PasswordFile = d.PasswordFile // pragma: allowlist secret + cfg.DatabaseCaCertFile = d.CaCertFile + cfg.Host = d.Host + cfg.Port = d.Port + cfg.Name = d.Name + cfg.Username = d.User + cfg.Password = d.Password // pragma: allowlist secret + return cfg } // GetConfig retrieves the current runtime configuration from the environment and returns it. diff --git a/emailsender/pkg/db/connect.go b/emailsender/pkg/db/connect.go new file mode 100644 index 0000000000..60cd9329e9 --- /dev/null +++ b/emailsender/pkg/db/connect.go @@ -0,0 +1,52 @@ +package db + +import ( + "fmt" + "gorm.io/gorm" + "time" + + commonDB "github.com/stackrox/acs-fleet-manager/pkg/db" +) + +// DatabaseClient defines methods for fetching or updating models in DB +type DatabaseClient interface { + InsertEmailSentByTenant(tenantID string) error + CountEmailSentByTenantSince(tenantID string, since time.Time) (int64, error) +} + +// DatabaseConnection contains dependency for communicating with DB +type DatabaseConnection struct { + DB *gorm.DB +} + +// NewDatabaseConnection creates a new DB connection +func NewDatabaseConnection(dbConfig *commonDB.DatabaseConfig) *DatabaseConnection { + connection, _ := commonDB.NewConnectionFactory(dbConfig) + return &DatabaseConnection{DB: connection.DB} +} + +// Migrate automatically migrates listed models in the database +// Documentation: https://gorm.io/docs/migration.html#Auto-Migration +func (d *DatabaseConnection) Migrate() error { + return d.DB.AutoMigrate(&EmailSentByTenant{}) +} + +// InsertEmailSentByTenant returns an instance of EmailSentByTenant representing how many emails tenant sent for provided date +func (d *DatabaseConnection) InsertEmailSentByTenant(tenantID string) error { + + if result := d.DB.Create(&EmailSentByTenant{TenantID: tenantID}); result.Error != nil { + return fmt.Errorf("failed inserting into email_sent_by_tenant table: %v", result.Error) + } + return nil +} + +// CountEmailSentByTenantSince counts how many emails tenant sent since provided timestamp +func (d *DatabaseConnection) CountEmailSentByTenantSince(tenantID string, since time.Time) (int64, error) { + var count int64 + if result := d.DB.Model(&EmailSentByTenant{}). + Where("tenant_id = ? AND created_at > ?", tenantID, since). + Count(&count); result.Error != nil { + return count, fmt.Errorf("failed count items in email_sent_by_tenant: %v", result.Error) + } + return count, nil +} diff --git a/emailsender/pkg/db/models.go b/emailsender/pkg/db/models.go new file mode 100644 index 0000000000..3bdc89dc03 --- /dev/null +++ b/emailsender/pkg/db/models.go @@ -0,0 +1,9 @@ +package db + +import "time" + +// EmailSentByTenant represents how many emails sent by tenant +type EmailSentByTenant struct { + TenantID string `gorm:"index"` + CreatedAt time.Time `gorm:"index"` // gorm automatically set to current unix seconds on create +} diff --git a/emailsender/pkg/email/ratelimiter.go b/emailsender/pkg/email/ratelimiter.go new file mode 100644 index 0000000000..e007cf3281 --- /dev/null +++ b/emailsender/pkg/email/ratelimiter.go @@ -0,0 +1,58 @@ +package email + +import ( + "fmt" + "github.com/golang/glog" + "github.com/stackrox/acs-fleet-manager/emailsender/pkg/db" + "time" +) + +const ( + windowSizeHours = 24 +) + +// RateLimiter defines an exact methods for rate limiter +type RateLimiter interface { + IsAllowed(tenantID string) bool + PersistEmailSendEvent(tenantID string) error +} + +// RateLimiterService contains configuration and dependency for rate limiter +type RateLimiterService struct { + limitPerTenant int + dbConnection db.DatabaseClient +} + +// NewRateLimiterService creates a new instance of RateLimiterService +func NewRateLimiterService(dbConnection *db.DatabaseConnection) *RateLimiterService { + return &RateLimiterService{ + dbConnection: dbConnection, + } +} + +// IsAllowed checks whether specified tenant can send an email for current timestamp +func (r *RateLimiterService) IsAllowed(tenantID string) bool { + now := time.Now() + dayAgo := now.Add(time.Duration(-windowSizeHours) * time.Hour) + sentDuringWindow, err := r.dbConnection.CountEmailSentByTenantSince(tenantID, dayAgo) + if err != nil { + glog.Errorf("Cannot count sent emails during window for tenant %s: %v", tenantID, err) + return false + } + + if sentDuringWindow >= int64(r.limitPerTenant) { + glog.Warningf("Reached limit for sent emails during window for tenant %s", tenantID) + return false + } + + return true +} + +// PersistEmailSendEvent stores email sent event +func (r *RateLimiterService) PersistEmailSendEvent(tenantID string) error { + err := r.dbConnection.InsertEmailSentByTenant(tenantID) + if err != nil { + return fmt.Errorf("failed register sent email: %v", err) + } + return nil +} diff --git a/emailsender/pkg/email/ratelimiter_test.go b/emailsender/pkg/email/ratelimiter_test.go new file mode 100644 index 0000000000..14b8c3a6cf --- /dev/null +++ b/emailsender/pkg/email/ratelimiter_test.go @@ -0,0 +1,82 @@ +package email + +import ( + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +var limitPerTenant = 20 +var testTenantID = "test-tenant-id" + +type MockDatabaseClient struct { + calledInsertEmailSentByTenant bool + calledCountEmailSentByTenantFrom bool + + InsertEmailSentByTenantFunc func(tenantID string) error + CountEmailSentByTenantFromFunc func(tenantID string, from time.Time) (int64, error) +} + +func (m *MockDatabaseClient) InsertEmailSentByTenant(tenantID string) error { + m.calledInsertEmailSentByTenant = true + return m.InsertEmailSentByTenantFunc(tenantID) +} + +func (m *MockDatabaseClient) CountEmailSentByTenantSince(tenantID string, from time.Time) (int64, error) { + m.calledCountEmailSentByTenantFrom = true + return m.CountEmailSentByTenantFromFunc(tenantID, from) +} + +func TestAllowTrue_Success(t *testing.T) { + mockDatabaseClient := &MockDatabaseClient{ + CountEmailSentByTenantFromFunc: func(tenantID string, from time.Time) (int64, error) { + return int64(limitPerTenant - 1), nil + }, + } + + service := RateLimiterService{ + limitPerTenant: limitPerTenant, + dbConnection: mockDatabaseClient, + } + + allowed := service.IsAllowed(testTenantID) + + assert.True(t, allowed) + assert.True(t, mockDatabaseClient.calledCountEmailSentByTenantFrom) +} + +func TestAllowFalse_LimitReached(t *testing.T) { + mockDatabaseClient := &MockDatabaseClient{ + CountEmailSentByTenantFromFunc: func(tenantID string, from time.Time) (int64, error) { + return int64(limitPerTenant + 1), nil + }, + } + + service := RateLimiterService{ + limitPerTenant: limitPerTenant, + dbConnection: mockDatabaseClient, + } + + allowed := service.IsAllowed(testTenantID) + + assert.False(t, allowed) + assert.True(t, mockDatabaseClient.calledCountEmailSentByTenantFrom) +} + +func TestPersistEmailSendEvent(t *testing.T) { + mockDatabaseClient := &MockDatabaseClient{ + InsertEmailSentByTenantFunc: func(tenantID string) error { + return nil + }, + } + + service := RateLimiterService{ + limitPerTenant: limitPerTenant, + dbConnection: mockDatabaseClient, + } + + err := service.PersistEmailSendEvent(testTenantID) + + assert.NoError(t, err) + assert.True(t, mockDatabaseClient.calledInsertEmailSentByTenant) +} diff --git a/emailsender/pkg/email/ses.go b/emailsender/pkg/email/ses.go index 9d917cf914..3df0df3868 100644 --- a/emailsender/pkg/email/ses.go +++ b/emailsender/pkg/email/ses.go @@ -4,6 +4,8 @@ package email import ( "context" "fmt" + "github.com/aws/aws-sdk-go-v2/aws/retry" + "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" @@ -18,8 +20,12 @@ type SES struct { } // NewSES creates a new SES instance with initialised AWS SES client using AWS Config -func NewSES(ctx context.Context) (*SES, error) { - cfg, err := config.LoadDefaultConfig(ctx) +func NewSES(ctx context.Context, maxBackoffDelay time.Duration, maxAttempts int) (*SES, error) { + retryerWithBackoff := retry.AddWithMaxBackoffDelay(retry.NewStandard(), maxBackoffDelay) + awsRetryer := config.WithRetryer(func() aws.Retryer { + return retry.AddWithMaxAttempts(retryerWithBackoff, maxAttempts) + }) + cfg, err := config.LoadDefaultConfig(ctx, awsRetryer) if err != nil { return nil, fmt.Errorf("unable to laod AWS SDK config: %v", err) }