Skip to content

Commit

Permalink
Config.StatusCode, ignore non 202 or 404 response
Browse files Browse the repository at this point in the history
  • Loading branch information
mcuadros committed May 21, 2021
1 parent 7e2ab11 commit 286e58c
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 4 deletions.
20 changes: 18 additions & 2 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ type Config struct {
TTL time.Duration `default:"1m"`
// Methods methods to be cached.
Methods []string `default:"[GET]"`
// StatusCode method to be cached.
StatusCode []int `default:"[200,404]"`
// IgnoreQuery if true the Query values from the requests are ignored on
// the key generation.
IgnoreQuery bool
Expand All @@ -32,7 +34,6 @@ func New(cfg *Config, cache *freecache.Cache) echo.MiddlewareFunc {
}

defaults.SetDefaults(cfg)

m := &CacheMiddleware{cfg: cfg, cache: cache}
return m.Handler
}
Expand Down Expand Up @@ -93,14 +94,29 @@ func (m *CacheMiddleware) readCache(key []byte, c echo.Context) error {
}

func (m *CacheMiddleware) cacheResult(key []byte, r *ResponseRecorder) error {
b, err := r.Result().Encode()
e := r.Result()
b, err := e.Encode()
if err != nil {
return fmt.Errorf("unable to read recorded response: %s", err)
}

if !m.isStatusCacheable(e) {
return nil
}

return m.cache.Set(key, b, int(m.cfg.TTL.Seconds()))
}

func (m *CacheMiddleware) isStatusCacheable(e *CacheEntry) bool {
for _, status := range m.cfg.StatusCode {
if e.StatusCode == status {
return true
}
}

return false
}

func (m *CacheMiddleware) isCacheable(r *http.Request) bool {
if m.cfg.Cache != nil {
return m.cfg.Cache(r)
Expand Down
19 changes: 18 additions & 1 deletion cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,30 @@ func TestCache_Methods(t *testing.T) {
assertRequest(t, resp, http.StatusOK, "test_4")
}

func TestCache_StatusCode(t *testing.T) {
client := getCachedServerWithCode(t, &Config{StatusCode: []int{200, 404}}, http.StatusInternalServerError)
defer client.Close()

resp, err := http.Get(client.URL)
assert.NoError(t, err)
assertRequest(t, resp, http.StatusOK, "test_1")

resp, err = http.Get(client.URL)
assert.NoError(t, err)
assertRequest(t, resp, http.StatusOK, "test_2")
}

func getCachedServer(t *testing.T, cfg *Config) *httptest.Server {
return getCachedServerWithCode(t, cfg, http.StatusOK)
}

func getCachedServerWithCode(t *testing.T, cfg *Config, status int) *httptest.Server {
e := echo.New()

var i int
h := New(cfg, freecache.NewCache(42*1024*1024))(func(c echo.Context) error {
i++
return c.String(http.StatusOK, fmt.Sprintf("test_%d", i))
return c.String(status, fmt.Sprintf("test_%d", i))
})

return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down
1 change: 0 additions & 1 deletion response.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ func (w *ResponseRecorder) WriteHeader(statusCode int) {

func (r *ResponseRecorder) Result() *CacheEntry {
r.copyHeaders()
r.ResponseWriter = nil

return &CacheEntry{
Header: r.headers,
Expand Down

0 comments on commit 286e58c

Please sign in to comment.