Skip to content

Commit

Permalink
Merge pull request #800 from nyaruka/redis_connections
Browse files Browse the repository at this point in the history
Aggressively close Redis connections
  • Loading branch information
rowanseymour authored Nov 6, 2024
2 parents 8af0bd8 + 9c173f8 commit f727f99
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 67 deletions.
8 changes: 8 additions & 0 deletions handlers/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"net/http"

"github.com/gomodule/redigo/redis"
"github.com/nyaruka/courier"
"github.com/nyaruka/gocommon/httpx"
)
Expand Down Expand Up @@ -148,3 +149,10 @@ func (h *BaseHandler) WriteRequestError(ctx context.Context, w http.ResponseWrit
func (h *BaseHandler) WriteRequestIgnored(ctx context.Context, w http.ResponseWriter, details string) error {
return courier.WriteIgnored(w, details)
}

// WithRedisConn is a utility to execute some code with a redis connection
func (h *BaseHandler) WithRedisConn(fn func(rc redis.Conn)) {
rc := h.Backend().RedisPool().Get()
defer rc.Close()
fn(rc)
}
15 changes: 10 additions & 5 deletions handlers/firebase/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,15 +338,17 @@ func (h *handler) sendWithCredsJSON(msg courier.MsgOut, res *courier.SendResult,
}

func (h *handler) getAccessToken(channel courier.Channel) (string, error) {
rc := h.Backend().RedisPool().Get()
defer rc.Close()

tokenKey := fmt.Sprintf("channel-token:%s", channel.UUID())

h.fetchTokenMutex.Lock()
defer h.fetchTokenMutex.Unlock()

token, err := redis.String(rc.Do("GET", tokenKey))
var token string
var err error
h.WithRedisConn(func(rc redis.Conn) {
token, err = redis.String(rc.Do("GET", tokenKey))
})

if err != nil && err != redis.ErrNil {
return "", fmt.Errorf("error reading cached access token: %w", err)
}
Expand All @@ -360,7 +362,10 @@ func (h *handler) getAccessToken(channel courier.Channel) (string, error) {
return "", fmt.Errorf("error fetching new access token: %w", err)
}

_, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second))
h.WithRedisConn(func(rc redis.Conn) {
_, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second))
})

if err != nil {
return "", fmt.Errorf("error updating cached access token: %w", err)
}
Expand Down
20 changes: 10 additions & 10 deletions handlers/hormuud/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,10 @@ func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.Sen
// FetchToken gets the current token for this channel, either from Redis if cached or by requesting it
func (h *handler) FetchToken(ctx context.Context, channel courier.Channel, msg courier.MsgOut, username, password string, clog *courier.ChannelLog) (string, error) {
// first check whether we have it in redis
rc := h.Backend().RedisPool().Get()
token, _ := redis.String(rc.Do("GET", fmt.Sprintf("hm_token_%s", channel.UUID())))
rc.Close()
var token string
h.WithRedisConn(func(rc redis.Conn) {
token, _ = redis.String(rc.Do("GET", fmt.Sprintf("hm_token_%s", channel.UUID())))
})

// got a token, use it
if token != "" {
Expand Down Expand Up @@ -172,13 +173,12 @@ func (h *handler) FetchToken(ctx context.Context, channel courier.Channel, msg c
}

// we got a token, cache it to redis with an expiration from the response(we default to 60 minutes)
rc = h.Backend().RedisPool().Get()
defer rc.Close()

_, err = rc.Do("SETEX", fmt.Sprintf("hm_token_%s", channel.UUID()), expiration, token)
if err != nil {
slog.Error("error caching HM access token", "error", err)
}
h.WithRedisConn(func(rc redis.Conn) {
_, err = rc.Do("SETEX", fmt.Sprintf("hm_token_%s", channel.UUID()), expiration, token)
if err != nil {
slog.Error("error caching HM access token", "error", err)
}
})

return token, nil
}
27 changes: 16 additions & 11 deletions handlers/jiochat/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ type mtPayload struct {
}

func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.SendResult, clog *courier.ChannelLog) error {
accessToken, err := h.getAccessToken(ctx, msg.Channel(), clog)
accessToken, err := h.getAccessToken(msg.Channel(), clog)
if err != nil {
return courier.ErrChannelConfig
}
Expand Down Expand Up @@ -198,7 +198,7 @@ func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.Sen

// DescribeURN handles Jiochat contact details
func (h *handler) DescribeURN(ctx context.Context, channel courier.Channel, urn urns.URN, clog *courier.ChannelLog) (map[string]string, error) {
accessToken, err := h.getAccessToken(ctx, channel, clog)
accessToken, err := h.getAccessToken(channel, clog)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -237,7 +237,7 @@ func (h *handler) BuildAttachmentRequest(ctx context.Context, b courier.Backend,
return nil, err
}

accessToken, err := h.getAccessToken(ctx, channel, clog)
accessToken, err := h.getAccessToken(channel, clog)
if err != nil {
return nil, err
}
Expand All @@ -250,16 +250,18 @@ func (h *handler) BuildAttachmentRequest(ctx context.Context, b courier.Backend,

var _ courier.AttachmentRequestBuilder = (*handler)(nil)

func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, clog *courier.ChannelLog) (string, error) {
rc := h.Backend().RedisPool().Get()
defer rc.Close()

func (h *handler) getAccessToken(channel courier.Channel, clog *courier.ChannelLog) (string, error) {
tokenKey := fmt.Sprintf("channel-token:%s", channel.UUID())

h.fetchTokenMutex.Lock()
defer h.fetchTokenMutex.Unlock()

token, err := redis.String(rc.Do("GET", tokenKey))
var token string
var err error
h.WithRedisConn(func(rc redis.Conn) {
token, err = redis.String(rc.Do("GET", tokenKey))
})

if err != nil && err != redis.ErrNil {
return "", fmt.Errorf("error reading cached access token: %w", err)
}
Expand All @@ -268,12 +270,15 @@ func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, c
return token, nil
}

token, expires, err := h.fetchAccessToken(ctx, channel, clog)
token, expires, err := h.fetchAccessToken(channel, clog)
if err != nil {
return "", fmt.Errorf("error fetching new access token: %w", err)
}

_, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second))
h.WithRedisConn(func(rc redis.Conn) {
_, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second))
})

if err != nil {
return "", fmt.Errorf("error updating cached access token: %w", err)
}
Expand All @@ -288,7 +293,7 @@ type fetchPayload struct {
}

// fetchAccessToken tries to fetch a new token for our channel
func (h *handler) fetchAccessToken(ctx context.Context, channel courier.Channel, clog *courier.ChannelLog) (string, time.Duration, error) {
func (h *handler) fetchAccessToken(channel courier.Channel, clog *courier.ChannelLog) (string, time.Duration, error) {
tokenURL, _ := url.Parse(fmt.Sprintf("%s/%s", sendURL, "auth/token.action"))
payload := &fetchPayload{
GrantType: "client_credentials",
Expand Down
23 changes: 14 additions & 9 deletions handlers/mtn/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ type mtPayload struct {
}

func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.SendResult, clog *courier.ChannelLog) error {
accessToken, err := h.getAccessToken(ctx, msg.Channel(), clog)
accessToken, err := h.getAccessToken(msg.Channel(), clog)
if err != nil {
return courier.ErrChannelConfig
}
Expand Down Expand Up @@ -175,16 +175,18 @@ func (h *handler) RedactValues(ch courier.Channel) []string {
}
}

func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, clog *courier.ChannelLog) (string, error) {
rc := h.Backend().RedisPool().Get()
defer rc.Close()

func (h *handler) getAccessToken(channel courier.Channel, clog *courier.ChannelLog) (string, error) {
tokenKey := fmt.Sprintf("channel-token:%s", channel.UUID())

h.fetchTokenMutex.Lock()
defer h.fetchTokenMutex.Unlock()

token, err := redis.String(rc.Do("GET", tokenKey))
var token string
var err error
h.WithRedisConn(func(rc redis.Conn) {
token, err = redis.String(rc.Do("GET", tokenKey))
})

if err != nil && err != redis.ErrNil {
return "", fmt.Errorf("error reading cached access token: %w", err)
}
Expand All @@ -193,12 +195,15 @@ func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, c
return token, nil
}

token, expires, err := h.fetchAccessToken(ctx, channel, clog)
token, expires, err := h.fetchAccessToken(channel, clog)
if err != nil {
return "", fmt.Errorf("error fetching new access token: %w", err)
}

_, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second))
h.WithRedisConn(func(rc redis.Conn) {
_, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second))
})

if err != nil {
return "", fmt.Errorf("error updating cached access token: %w", err)
}
Expand All @@ -207,7 +212,7 @@ func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, c
}

// fetchAccessToken tries to fetch a new token for our channel, setting the result in redis
func (h *handler) fetchAccessToken(ctx context.Context, channel courier.Channel, clog *courier.ChannelLog) (string, time.Duration, error) {
func (h *handler) fetchAccessToken(channel courier.Channel, clog *courier.ChannelLog) (string, time.Duration, error) {
form := url.Values{
"client_id": []string{channel.StringConfigForKey(courier.ConfigAPIKey, "")},
"client_secret": []string{channel.StringConfigForKey(courier.ConfigAuthToken, "")},
Expand Down
27 changes: 16 additions & 11 deletions handlers/wechat/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ type mtPayload struct {
}

func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.SendResult, clog *courier.ChannelLog) error {
accessToken, err := h.getAccessToken(ctx, msg.Channel(), clog)
accessToken, err := h.getAccessToken(msg.Channel(), clog)
if err != nil {
return err
}
Expand Down Expand Up @@ -216,7 +216,7 @@ func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.Sen

// DescribeURN handles WeChat contact details
func (h *handler) DescribeURN(ctx context.Context, channel courier.Channel, urn urns.URN, clog *courier.ChannelLog) (map[string]string, error) {
accessToken, err := h.getAccessToken(ctx, channel, clog)
accessToken, err := h.getAccessToken(channel, clog)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -255,7 +255,7 @@ func (h *handler) RedactValues(ch courier.Channel) []string {

// BuildAttachmentRequest download media for message attachment
func (h *handler) BuildAttachmentRequest(ctx context.Context, b courier.Backend, channel courier.Channel, attachmentURL string, clog *courier.ChannelLog) (*http.Request, error) {
accessToken, err := h.getAccessToken(ctx, channel, clog)
accessToken, err := h.getAccessToken(channel, clog)
if err != nil {
return nil, err
}
Expand All @@ -275,16 +275,18 @@ func (h *handler) BuildAttachmentRequest(ctx context.Context, b courier.Backend,

var _ courier.AttachmentRequestBuilder = (*handler)(nil)

func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, clog *courier.ChannelLog) (string, error) {
rc := h.Backend().RedisPool().Get()
defer rc.Close()

func (h *handler) getAccessToken(channel courier.Channel, clog *courier.ChannelLog) (string, error) {
tokenKey := fmt.Sprintf("channel-token:%s", channel.UUID())

h.fetchTokenMutex.Lock()
defer h.fetchTokenMutex.Unlock()

token, err := redis.String(rc.Do("GET", tokenKey))
var token string
var err error
h.WithRedisConn(func(rc redis.Conn) {
token, err = redis.String(rc.Do("GET", tokenKey))
})

if err != nil && err != redis.ErrNil {
return "", fmt.Errorf("error reading cached access token: %w", err)
}
Expand All @@ -293,12 +295,15 @@ func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, c
return token, nil
}

token, expires, err := h.fetchAccessToken(ctx, channel, clog)
token, expires, err := h.fetchAccessToken(channel, clog)
if err != nil {
return "", fmt.Errorf("error fetching new access token: %w", err)
}

_, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second))
h.WithRedisConn(func(rc redis.Conn) {
_, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second))
})

if err != nil {
return "", fmt.Errorf("error updating cached access token: %w", err)
}
Expand All @@ -307,7 +312,7 @@ func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, c
}

// fetchAccessToken tries to fetch a new token for our channel, setting the result in redis
func (h *handler) fetchAccessToken(ctx context.Context, channel courier.Channel, clog *courier.ChannelLog) (string, time.Duration, error) {
func (h *handler) fetchAccessToken(channel courier.Channel, clog *courier.ChannelLog) (string, time.Duration, error) {
form := url.Values{
"grant_type": []string{"client_credential"},
"appid": []string{channel.StringConfigForKey(configAppID, "")},
Expand Down
Loading

0 comments on commit f727f99

Please sign in to comment.