Skip to content

Commit

Permalink
ROX-23260: Add Rate Limiter to Email Sender (#1887)
Browse files Browse the repository at this point in the history
Add Rate Limiter to Email Sender
  • Loading branch information
kurlov authored Jun 17, 2024
1 parent 5bd9b61 commit 0af0262
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 22 deletions.
5 changes: 2 additions & 3 deletions emailsender/cmd/app/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,14 @@ func main() {

// initialize components
dbConnection := db.NewDatabaseConnection(dbCfg)
// TODO(ROX-23260): connect Rate Limiter to Email Sender
_ = email.NewRateLimiterService(dbConnection)
rateLimiter := email.NewRateLimiterService(dbConnection)
sesClient, err := email.NewSES(ctx, cfg.SesMaxBackoffDelay, cfg.SesMaxAttempts)
if err != nil {
glog.Errorf("Failed to initialise SES Client: %v", err)
os.Exit(1)
}

emailSender := email.NewEmailSender(cfg.SenderAddress, sesClient)
emailSender := email.NewEmailSender(cfg.SenderAddress, sesClient, rateLimiter)
emailHandler := api.NewEmailHandler(emailSender)

router, err := api.SetupRoutes(cfg.AuthConfig, emailHandler)
Expand Down
7 changes: 2 additions & 5 deletions emailsender/pkg/api/emailhandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,13 @@ func (eh *EmailHandler) SendEmail(w http.ResponseWriter, r *http.Request) {
return
}

sub, err := claims.GetSubject()
tenantID, err := claims.GetSubject()
if err != nil {
shared.HandleError(r, w, errors.Unauthenticated("failed to get sub claim"))
return
}

// TODO: use sub for rate limiting later on instead of printing it here
glog.Info(sub)

if err := eh.emailSender.Send(r.Context(), request.To, request.RawMessage); err != nil {
if err := eh.emailSender.Send(r.Context(), request.To, request.RawMessage, tenantID); err != nil {
eh.errorResponse(w, "Cannot send email", http.StatusInternalServerError)
return
}
Expand Down
2 changes: 1 addition & 1 deletion emailsender/pkg/api/emailhandler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type MockEmailSender struct {
SendFunc func(ctx context.Context, to []string, rawMessage []byte) error
}

func (m *MockEmailSender) Send(ctx context.Context, to []string, rawMessage []byte) error {
func (m *MockEmailSender) Send(ctx context.Context, to []string, rawMessage []byte, tenantID string) error {
return m.SendFunc(ctx, to, rawMessage)
}

Expand Down
29 changes: 17 additions & 12 deletions emailsender/pkg/email/sender.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,42 +4,47 @@ import (
"bytes"
"context"
"fmt"

"github.com/golang/glog"
)

const fromTemplate = "From: RHACS Cloud Service <%s>\r\n"

// Sender defines the interface to send emails
type Sender interface {
Send(ctx context.Context, to []string, rawMessage []byte) error
Send(ctx context.Context, to []string, rawMessage []byte, tenantID string) error
}

// MailSender is the default implementation for the Sender interface
type MailSender struct {
from string
ses *SES
from string
ses *SES
rateLimiter RateLimiter
}

// NewEmailSender returns a new MailSender instance
func NewEmailSender(from string, ses *SES) *MailSender {
func NewEmailSender(from string, ses *SES, rateLimiter RateLimiter) *MailSender {
return &MailSender{
from: from,
ses: ses,
from: from,
ses: ses,
rateLimiter: rateLimiter,
}
}

// Send sends an email to the given AWS SES
func (s *MailSender) Send(ctx context.Context, to []string, rawMessage []byte) error {
// Even though AWS adds the from handler we need to set it the the message to show
// an alias in email inboxes that is more human friendly ([email protected] vs. RHACS Cloud Service)
func (s *MailSender) Send(ctx context.Context, to []string, rawMessage []byte, tenantID string) error {
// Even though AWS adds the "from" handler we need to set it to the message to show
// an alias in email inboxes. It is more human friendly ([email protected] vs. RHACS Cloud Service)
if !s.rateLimiter.IsAllowed(tenantID) {
return fmt.Errorf("rate limit exceeded for tenant: %s", tenantID)
}
fromBytes := []byte(fmt.Sprintf(fromTemplate, s.from))
raw := bytes.Join([][]byte{fromBytes, rawMessage}, nil)
_, err := s.ses.SendRawEmail(ctx, s.from, to, raw)
if err != nil {
glog.Errorf("Failed sending email: %v", err)
return fmt.Errorf("failed to send email: %v", err)
}
if err = s.rateLimiter.PersistEmailSendEvent(tenantID); err != nil {
return fmt.Errorf("failed to store email sent event for teantnt %s: %v", tenantID, err)
}

return nil
}
56 changes: 55 additions & 1 deletion emailsender/pkg/email/sender_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,24 @@ import (
"github.com/aws/aws-sdk-go-v2/service/ses"
)

type MockedRateLimiter struct {
calledIsAllowed bool
calledPersistEmailSendEvent bool

IsAllowedFunc func(tenantID string) bool
PersistEmailSendEventFunc func(tenantID string) error
}

func (m *MockedRateLimiter) IsAllowed(tenantID string) bool {
m.calledIsAllowed = true
return m.IsAllowedFunc(tenantID)
}

func (m *MockedRateLimiter) PersistEmailSendEvent(tenantID string) error {
m.calledPersistEmailSendEvent = true
return m.PersistEmailSendEventFunc(tenantID)
}

func TestSend_Success(t *testing.T) {
from := "[email protected]"
to := []string{"[email protected]", "[email protected]"}
Expand All @@ -21,6 +39,7 @@ func TestSend_Success(t *testing.T) {
messageBuf.WriteString(textBody)
rawMessage := messageBuf.Bytes()
called := false
tenantID := "test-tenant-id"

mockClient := &MockSESClient{
SendRawEmailFunc: func(ctx context.Context, params *ses.SendRawEmailInput, optFns ...func(*ses.Options)) (*ses.SendRawEmailOutput, error) {
Expand All @@ -30,14 +49,49 @@ func TestSend_Success(t *testing.T) {
}, nil
},
}
mockedRateLimiter := &MockedRateLimiter{
IsAllowedFunc: func(tenantID string) bool {
return true
},
PersistEmailSendEventFunc: func(tenantID string) error {
return nil
},
}
mockedSES := &SES{sesClient: mockClient}
mockedSender := MailSender{
from,
mockedSES,
mockedRateLimiter,
}

err := mockedSender.Send(context.Background(), to, rawMessage)
err := mockedSender.Send(context.Background(), to, rawMessage, tenantID)

assert.NoError(t, err)
assert.True(t, called)
assert.True(t, mockedRateLimiter.calledIsAllowed)
assert.True(t, mockedRateLimiter.calledPersistEmailSendEvent)
}

func TestSend_LimitExceeded(t *testing.T) {
var messageBuf bytes.Buffer
rawMessage := messageBuf.Bytes()

mockClient := &MockSESClient{}
mockedRateLimiter := &MockedRateLimiter{
IsAllowedFunc: func(tenantID string) bool {
return false
},
}
mockedSES := &SES{sesClient: mockClient}
mockedSender := MailSender{
"[email protected]",
mockedSES,
mockedRateLimiter,
}

err := mockedSender.Send(context.Background(), []string{"[email protected]"}, rawMessage, "test-tenant-id")

assert.ErrorContains(t, err, "rate limit exceeded")
assert.True(t, mockedRateLimiter.calledIsAllowed)
assert.False(t, mockedRateLimiter.calledPersistEmailSendEvent)
}

0 comments on commit 0af0262

Please sign in to comment.