Skip to content

Commit

Permalink
Agressively close redis connections
Browse files Browse the repository at this point in the history
  • Loading branch information
rowanseymour committed Nov 6, 2024
1 parent 8af0bd8 commit 5c8ee9a
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 41 deletions.
9 changes: 6 additions & 3 deletions handlers/firebase/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,15 +338,15 @@ func (h *handler) sendWithCredsJSON(msg courier.MsgOut, res *courier.SendResult,
}

func (h *handler) getAccessToken(channel courier.Channel) (string, error) {
rc := h.Backend().RedisPool().Get()
defer rc.Close()

tokenKey := fmt.Sprintf("channel-token:%s", channel.UUID())

h.fetchTokenMutex.Lock()
defer h.fetchTokenMutex.Unlock()

rc := h.Backend().RedisPool().Get()
token, err := redis.String(rc.Do("GET", tokenKey))
rc.Close()

if err != nil && err != redis.ErrNil {
return "", fmt.Errorf("error reading cached access token: %w", err)
}
Expand All @@ -360,7 +360,10 @@ func (h *handler) getAccessToken(channel courier.Channel) (string, error) {
return "", fmt.Errorf("error fetching new access token: %w", err)
}

rc = h.Backend().RedisPool().Get()
_, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second))
rc.Close()

if err != nil {
return "", fmt.Errorf("error updating cached access token: %w", err)
}
Expand Down
4 changes: 2 additions & 2 deletions handlers/hormuud/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,9 @@ func (h *handler) FetchToken(ctx context.Context, channel courier.Channel, msg c

// we got a token, cache it to redis with an expiration from the response(we default to 60 minutes)
rc = h.Backend().RedisPool().Get()
defer rc.Close()

_, err = rc.Do("SETEX", fmt.Sprintf("hm_token_%s", channel.UUID()), expiration, token)
rc.Close()

if err != nil {
slog.Error("error caching HM access token", "error", err)
}
Expand Down
21 changes: 12 additions & 9 deletions handlers/jiochat/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ type mtPayload struct {
}

func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.SendResult, clog *courier.ChannelLog) error {
accessToken, err := h.getAccessToken(ctx, msg.Channel(), clog)
accessToken, err := h.getAccessToken(msg.Channel(), clog)
if err != nil {
return courier.ErrChannelConfig
}
Expand Down Expand Up @@ -198,7 +198,7 @@ func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.Sen

// DescribeURN handles Jiochat contact details
func (h *handler) DescribeURN(ctx context.Context, channel courier.Channel, urn urns.URN, clog *courier.ChannelLog) (map[string]string, error) {
accessToken, err := h.getAccessToken(ctx, channel, clog)
accessToken, err := h.getAccessToken(channel, clog)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -237,7 +237,7 @@ func (h *handler) BuildAttachmentRequest(ctx context.Context, b courier.Backend,
return nil, err
}

accessToken, err := h.getAccessToken(ctx, channel, clog)
accessToken, err := h.getAccessToken(channel, clog)
if err != nil {
return nil, err
}
Expand All @@ -250,16 +250,16 @@ func (h *handler) BuildAttachmentRequest(ctx context.Context, b courier.Backend,

var _ courier.AttachmentRequestBuilder = (*handler)(nil)

func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, clog *courier.ChannelLog) (string, error) {
rc := h.Backend().RedisPool().Get()
defer rc.Close()

func (h *handler) getAccessToken(channel courier.Channel, clog *courier.ChannelLog) (string, error) {
tokenKey := fmt.Sprintf("channel-token:%s", channel.UUID())

h.fetchTokenMutex.Lock()
defer h.fetchTokenMutex.Unlock()

rc := h.Backend().RedisPool().Get()
token, err := redis.String(rc.Do("GET", tokenKey))
rc.Close()

if err != nil && err != redis.ErrNil {
return "", fmt.Errorf("error reading cached access token: %w", err)
}
Expand All @@ -268,12 +268,15 @@ func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, c
return token, nil
}

token, expires, err := h.fetchAccessToken(ctx, channel, clog)
token, expires, err := h.fetchAccessToken(channel, clog)
if err != nil {
return "", fmt.Errorf("error fetching new access token: %w", err)
}

rc = h.Backend().RedisPool().Get()
_, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second))
rc.Close()

if err != nil {
return "", fmt.Errorf("error updating cached access token: %w", err)
}
Expand All @@ -288,7 +291,7 @@ type fetchPayload struct {
}

// fetchAccessToken tries to fetch a new token for our channel
func (h *handler) fetchAccessToken(ctx context.Context, channel courier.Channel, clog *courier.ChannelLog) (string, time.Duration, error) {
func (h *handler) fetchAccessToken(channel courier.Channel, clog *courier.ChannelLog) (string, time.Duration, error) {
tokenURL, _ := url.Parse(fmt.Sprintf("%s/%s", sendURL, "auth/token.action"))
payload := &fetchPayload{
GrantType: "client_credentials",
Expand Down
17 changes: 10 additions & 7 deletions handlers/mtn/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ type mtPayload struct {
}

func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.SendResult, clog *courier.ChannelLog) error {
accessToken, err := h.getAccessToken(ctx, msg.Channel(), clog)
accessToken, err := h.getAccessToken(msg.Channel(), clog)
if err != nil {
return courier.ErrChannelConfig
}
Expand Down Expand Up @@ -175,16 +175,16 @@ func (h *handler) RedactValues(ch courier.Channel) []string {
}
}

func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, clog *courier.ChannelLog) (string, error) {
rc := h.Backend().RedisPool().Get()
defer rc.Close()

func (h *handler) getAccessToken(channel courier.Channel, clog *courier.ChannelLog) (string, error) {
tokenKey := fmt.Sprintf("channel-token:%s", channel.UUID())

h.fetchTokenMutex.Lock()
defer h.fetchTokenMutex.Unlock()

rc := h.Backend().RedisPool().Get()
token, err := redis.String(rc.Do("GET", tokenKey))
rc.Close()

if err != nil && err != redis.ErrNil {
return "", fmt.Errorf("error reading cached access token: %w", err)
}
Expand All @@ -193,12 +193,15 @@ func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, c
return token, nil
}

token, expires, err := h.fetchAccessToken(ctx, channel, clog)
token, expires, err := h.fetchAccessToken(channel, clog)
if err != nil {
return "", fmt.Errorf("error fetching new access token: %w", err)
}

rc = h.Backend().RedisPool().Get()
_, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second))
rc.Close()

if err != nil {
return "", fmt.Errorf("error updating cached access token: %w", err)
}
Expand All @@ -207,7 +210,7 @@ func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, c
}

// fetchAccessToken tries to fetch a new token for our channel, setting the result in redis
func (h *handler) fetchAccessToken(ctx context.Context, channel courier.Channel, clog *courier.ChannelLog) (string, time.Duration, error) {
func (h *handler) fetchAccessToken(channel courier.Channel, clog *courier.ChannelLog) (string, time.Duration, error) {
form := url.Values{
"client_id": []string{channel.StringConfigForKey(courier.ConfigAPIKey, "")},
"client_secret": []string{channel.StringConfigForKey(courier.ConfigAuthToken, "")},
Expand Down
21 changes: 12 additions & 9 deletions handlers/wechat/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ type mtPayload struct {
}

func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.SendResult, clog *courier.ChannelLog) error {
accessToken, err := h.getAccessToken(ctx, msg.Channel(), clog)
accessToken, err := h.getAccessToken(msg.Channel(), clog)
if err != nil {
return err
}
Expand Down Expand Up @@ -216,7 +216,7 @@ func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.Sen

// DescribeURN handles WeChat contact details
func (h *handler) DescribeURN(ctx context.Context, channel courier.Channel, urn urns.URN, clog *courier.ChannelLog) (map[string]string, error) {
accessToken, err := h.getAccessToken(ctx, channel, clog)
accessToken, err := h.getAccessToken(channel, clog)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -255,7 +255,7 @@ func (h *handler) RedactValues(ch courier.Channel) []string {

// BuildAttachmentRequest download media for message attachment
func (h *handler) BuildAttachmentRequest(ctx context.Context, b courier.Backend, channel courier.Channel, attachmentURL string, clog *courier.ChannelLog) (*http.Request, error) {
accessToken, err := h.getAccessToken(ctx, channel, clog)
accessToken, err := h.getAccessToken(channel, clog)
if err != nil {
return nil, err
}
Expand All @@ -275,16 +275,16 @@ func (h *handler) BuildAttachmentRequest(ctx context.Context, b courier.Backend,

var _ courier.AttachmentRequestBuilder = (*handler)(nil)

func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, clog *courier.ChannelLog) (string, error) {
rc := h.Backend().RedisPool().Get()
defer rc.Close()

func (h *handler) getAccessToken(channel courier.Channel, clog *courier.ChannelLog) (string, error) {
tokenKey := fmt.Sprintf("channel-token:%s", channel.UUID())

h.fetchTokenMutex.Lock()
defer h.fetchTokenMutex.Unlock()

rc := h.Backend().RedisPool().Get()
token, err := redis.String(rc.Do("GET", tokenKey))
rc.Close()

if err != nil && err != redis.ErrNil {
return "", fmt.Errorf("error reading cached access token: %w", err)
}
Expand All @@ -293,12 +293,15 @@ func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, c
return token, nil
}

token, expires, err := h.fetchAccessToken(ctx, channel, clog)
token, expires, err := h.fetchAccessToken(channel, clog)
if err != nil {
return "", fmt.Errorf("error fetching new access token: %w", err)
}

rc = h.Backend().RedisPool().Get()
_, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second))
rc.Close()

if err != nil {
return "", fmt.Errorf("error updating cached access token: %w", err)
}
Expand All @@ -307,7 +310,7 @@ func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, c
}

// fetchAccessToken tries to fetch a new token for our channel, setting the result in redis
func (h *handler) fetchAccessToken(ctx context.Context, channel courier.Channel, clog *courier.ChannelLog) (string, time.Duration, error) {
func (h *handler) fetchAccessToken(channel courier.Channel, clog *courier.ChannelLog) (string, time.Duration, error) {
form := url.Values{
"grant_type": []string{"client_credential"},
"appid": []string{channel.StringConfigForKey(configAppID, "")},
Expand Down
28 changes: 17 additions & 11 deletions handlers/whatsapp_legacy/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
"time"

"github.com/buger/jsonparser"
"github.com/gomodule/redigo/redis"
"github.com/nyaruka/courier"
"github.com/nyaruka/courier/handlers"
"github.com/nyaruka/courier/utils"
Expand Down Expand Up @@ -495,9 +494,6 @@ type mtErrorPayload struct {
const maxMsgLength = 4096

func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.SendResult, clog *courier.ChannelLog) error {
conn := h.Backend().RedisPool().Get()
defer conn.Close()

// get our token
token := msg.Channel().StringConfigForKey(courier.ConfigAuthToken, "")
urlStr := msg.Channel().StringConfigForKey(courier.ConfigBaseURL, "")
Expand All @@ -519,7 +515,7 @@ func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.Sen

for _, payload := range payloads {
externalID := ""
wppID, externalID, err = h.sendWhatsAppMsg(conn, msg, sendPath, payload, clog)
wppID, externalID, err = h.sendWhatsAppMsg(msg, sendPath, payload, clog)
if err != nil {
return err
}
Expand Down Expand Up @@ -562,7 +558,7 @@ func buildPayloads(msg courier.MsgOut, h *handler, clog *courier.ChannelLog) ([]
for attachmentCount, attachment := range msg.Attachments() {

mimeType, mediaURL := handlers.SplitAttachment(attachment)
mediaID, err := h.fetchMediaID(msg, mimeType, mediaURL, clog)
mediaID, err := h.fetchMediaID(msg, mediaURL, clog)
if err != nil {
slog.Error("error while uploading media to whatsapp", "error", err, "channel_uuid", msg.Channel().UUID())
}
Expand Down Expand Up @@ -817,14 +813,15 @@ func buildPayloads(msg courier.MsgOut, h *handler, clog *courier.ChannelLog) ([]
}

// fetchMediaID tries to fetch the id for the uploaded media, setting the result in redis.
func (h *handler) fetchMediaID(msg courier.MsgOut, mimeType, mediaURL string, clog *courier.ChannelLog) (string, error) {
func (h *handler) fetchMediaID(msg courier.MsgOut, mediaURL string, clog *courier.ChannelLog) (string, error) {
// check in cache first
rc := h.Backend().RedisPool().Get()
defer rc.Close()

cacheKey := fmt.Sprintf(mediaCacheKeyPattern, msg.Channel().UUID())
mediaCache := redisx.NewIntervalHash(cacheKey, time.Hour*24, 2)

rc := h.Backend().RedisPool().Get()
mediaID, err := mediaCache.Get(rc, mediaURL)
rc.Close()

if err != nil {
return "", fmt.Errorf("error reading media id from redis: %s : %s: %w", cacheKey, mediaURL, err)
} else if mediaID != "" {
Expand Down Expand Up @@ -885,15 +882,18 @@ func (h *handler) fetchMediaID(msg courier.MsgOut, mimeType, mediaURL string, cl
}

// put in cache
rc = h.Backend().RedisPool().Get()
err = mediaCache.Set(rc, mediaURL, mediaID)
rc.Close()

if err != nil {
return "", fmt.Errorf("error setting media id in cache: %w", err)
}

return mediaID, nil
}

func (h *handler) sendWhatsAppMsg(rc redis.Conn, msg courier.MsgOut, sendPath *url.URL, payload any, clog *courier.ChannelLog) (string, string, error) {
func (h *handler) sendWhatsAppMsg(msg courier.MsgOut, sendPath *url.URL, payload any, clog *courier.ChannelLog) (string, string, error) {
jsonBody := jsonx.MustMarshal(payload)

req, _ := http.NewRequest(http.MethodPost, sendPath.String(), bytes.NewReader(jsonBody))
Expand All @@ -906,12 +906,15 @@ func (h *handler) sendWhatsAppMsg(rc redis.Conn, msg courier.MsgOut, sendPath *u

if resp != nil && (resp.StatusCode == 429 || resp.StatusCode == 503) {
rateLimitKey := fmt.Sprintf("rate_limit:%s", msg.Channel().UUID())

rc := h.Backend().RedisPool().Get()
rc.Do("SET", rateLimitKey, "engaged")

// The rate limit is 50 requests per second
// We pause sending 2 seconds so the limit count is reset
// TODO: In the future we should the header value when available
rc.Do("EXPIRE", rateLimitKey, 2)
rc.Close()

return "", "", courier.ErrConnectionThrottled
}
Expand All @@ -923,11 +926,14 @@ func (h *handler) sendWhatsAppMsg(rc redis.Conn, msg courier.MsgOut, sendPath *u
if err == nil && len(errPayload.Errors) > 0 {
if hasTiersError(*errPayload) {
rateLimitBulkKey := fmt.Sprintf("rate_limit_bulk:%s", msg.Channel().UUID())

rc := h.Backend().RedisPool().Get()
rc.Do("SET", rateLimitBulkKey, "engaged")

// The WA tiers spam rate limit hit
// We pause the bulk queue for 24 hours and 5min
rc.Do("EXPIRE", rateLimitBulkKey, (60*60*24)+(5*60))
rc.Close()

return "", "", courier.ErrConnectionThrottled
}
Expand Down

0 comments on commit 5c8ee9a

Please sign in to comment.