Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

issues/168: Add ReloadProtector middleware #171

Merged
merged 11 commits into from
Oct 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Most recent version is listed first.

## v0.0.24
- Set session cookie only if non-empty: https://github.com/komuw/ong/pull/170
- Add ReloadProtector middleware: https://github.com/komuw/ong/pull/171

## v0.0.23
- ong/client: Add log id http header: https://github.com/komuw/ong/pull/166
Expand Down
26 changes: 15 additions & 11 deletions middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,10 @@ func allDefaultMiddlewares(
// 7. Cors since we might get pre-flight requests and we don't want those to go through all the middlewares for performance reasons.
// 8. Csrf since this one is a bit more involved perf-wise.
// 9. Gzip since it is very involved perf-wise.
// 10. Session since we want sessions to saved as soon as possible.
// 10. ReloadProtector, ideally I feel like it should come earlier but I'm yet to figure out where.
// 11. Session since we want sessions to saved as soon as possible.
//
// user -> Panic -> Log -> RateLimiter -> LoadShedder -> HttpsRedirector -> SecurityHeaders -> Cors -> Csrf -> Gzip -> Session -> actual-handler
// user -> Panic -> Log -> RateLimiter -> LoadShedder -> HttpsRedirector -> SecurityHeaders -> Cors -> Csrf -> Gzip -> ReloadProtector -> Session -> actual-handler

// We have disabled Gzip for now, since it is about 2.5times slower than no-gzip for a 50MB sample response.
// see: https://github.com/komuw/ong/issues/85
Expand All @@ -105,9 +106,12 @@ func allDefaultMiddlewares(
SecurityHeaders(
Cors(
Csrf(
Session(
wrappedHandler,
secretKey,
ReloadProtector(
Session(
wrappedHandler,
secretKey,
domain,
),
domain,
),
secretKey,
Expand All @@ -133,7 +137,7 @@ func allDefaultMiddlewares(

// All is a middleware that allows all http methods.
//
// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Session] & [Csrf] middleware.
// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Csrf], [ReloadProtector] & [Session] middleware.
// As such, it provides the features and functionalities of all those middlewares.
func All(wrappedHandler http.HandlerFunc, o Opts) http.HandlerFunc {
return allDefaultMiddlewares(
Expand All @@ -150,7 +154,7 @@ func all(wrappedHandler http.HandlerFunc) http.HandlerFunc {

// Get is a middleware that only allows http GET requests and http OPTIONS requests.
//
// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Session] & [Csrf] middleware.
// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Csrf], [ReloadProtector] & [Session] middleware.
// As such, it provides the features and functionalities of all those middlewares.
func Get(wrappedHandler http.HandlerFunc, o Opts) http.HandlerFunc {
return allDefaultMiddlewares(
Expand Down Expand Up @@ -181,7 +185,7 @@ func get(wrappedHandler http.HandlerFunc) http.HandlerFunc {

// Post is a middleware that only allows http POST requests and http OPTIONS requests.
//
// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Session] & [Csrf] middleware.
// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Csrf], [ReloadProtector] & [Session] middleware.
// As such, it provides the features and functionalities of all those middlewares.
func Post(wrappedHandler http.HandlerFunc, o Opts) http.HandlerFunc {
return allDefaultMiddlewares(
Expand Down Expand Up @@ -210,7 +214,7 @@ func post(wrappedHandler http.HandlerFunc) http.HandlerFunc {

// Head is a middleware that only allows http HEAD requests and http OPTIONS requests.
//
// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Session] & [Csrf] middleware.
// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Csrf], [ReloadProtector] & [Session] middleware.
// As such, it provides the features and functionalities of all those middlewares.
func Head(wrappedHandler http.HandlerFunc, o Opts) http.HandlerFunc {
return allDefaultMiddlewares(
Expand Down Expand Up @@ -239,7 +243,7 @@ func head(wrappedHandler http.HandlerFunc) http.HandlerFunc {

// Put is a middleware that only allows http PUT requests and http OPTIONS requests.
//
// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Session] & [Csrf] middleware.
// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Csrf], [ReloadProtector] & [Session] middleware.
// As such, it provides the features and functionalities of all those middlewares.
func Put(wrappedHandler http.HandlerFunc, o Opts) http.HandlerFunc {
return allDefaultMiddlewares(
Expand Down Expand Up @@ -268,7 +272,7 @@ func put(wrappedHandler http.HandlerFunc) http.HandlerFunc {

// Delete is a middleware that only allows http DELETE requests and http OPTIONS requests.
//
// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Session] & [Csrf] middleware.
// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Csrf], [ReloadProtector] & [Session] middleware.
// As such, it provides the features and functionalities of all those middlewares.
func Delete(wrappedHandler http.HandlerFunc, o Opts) http.HandlerFunc {
return allDefaultMiddlewares(
Expand Down
69 changes: 69 additions & 0 deletions middleware/reload_protect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package middleware

import (
"fmt"
"net/http"
"strings"
"time"

"github.com/komuw/ong/cookie"
"golang.org/x/exp/slices"
)

const reloadProtectCookiePrefix = "ong_form_reload_protect"

// ReloadProtector is a middleware that attempts to provides protection against a form re-submission when a user reloads/refreshes an already submitted web page/form.
//
// If such a situation is detected; this middleware will issue a http GET redirect to the same url.
func ReloadProtector(wrappedHandler http.HandlerFunc, domain string) http.HandlerFunc {
safeMethods := []string{
// safe methods under rfc7231: https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1
http.MethodGet,
http.MethodHead,
http.MethodOptions,
http.MethodTrace,
}
return func(w http.ResponseWriter, r *http.Request) {
// It is possible for one to send a form without having added the requiste form http header.
if !slices.Contains(safeMethods, r.Method) {
// This could be a http POST/DELETE/etc

theCookie := fmt.Sprintf("%s-%s",
reloadProtectCookiePrefix,
strings.ReplaceAll(r.URL.EscapedPath(), "/", ""),
)

// todo: should we check if gotCookie.MaxAge > 0
gotCookie, err := r.Cookie(theCookie)
if err == nil && gotCookie != nil {
// It means that the form had been submitted before.

cookie.Delete(
w,
theCookie,
domain,
)
http.Redirect(
w,
r,
r.URL.String(),
// http 303(StatusSeeOther) is guaranteed by the spec to always use http GET.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/303
http.StatusSeeOther,
)
return
} else {
cookie.Set(
w,
theCookie,
"YES",
domain,
1*time.Hour,
false,
)
}
}

wrappedHandler(w, r)
}
}
152 changes: 152 additions & 0 deletions middleware/reload_protect_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
package middleware

import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"sync"
"testing"

"github.com/akshayjshah/attest"
)

func someReloadProtectorHandler(msg, expectedFormName, expectedFormValue string) http.HandlerFunc {
// count is state that is affected by form submission.
// eg, when a form is submitted; we create a new user.
count := 0
return func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodPost {
err := r.ParseForm()
if err != nil {
panic(err)
}
val := r.Form.Get(expectedFormName)
if val != expectedFormValue {
panic(fmt.Sprintf("expected = %v got = %v", expectedFormValue, val))
}

count = count + 1
if count > 1 {
// form re-submission happened
panic("form re-submission happened")
}
}

fmt.Fprint(w, msg)
}
}

func TestReloadProtector(t *testing.T) {
t.Parallel()

t.Run("middleware succeds", func(t *testing.T) {
t.Parallel()

msg := "hello"
domain := "localhost"
expectedFormName := "user_name"
expectedFormValue := "John Doe"
wrappedHandler := ReloadProtector(someReloadProtectorHandler(msg, expectedFormName, expectedFormValue), domain)

rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/someUri", nil)
wrappedHandler.ServeHTTP(rec, req)

res := rec.Result()
defer res.Body.Close()

rb, err := io.ReadAll(res.Body)
attest.Ok(t, err)

attest.Equal(t, res.StatusCode, http.StatusOK)
attest.Equal(t, string(rb), msg)
})

t.Run("re-submission protected", func(t *testing.T) {
t.Parallel()

msg := "hello"
domain := "localhost"
expectedFormName := "user_name"
expectedFormValue := "John Doe"
wrappedHandler := ReloadProtector(someReloadProtectorHandler(msg, expectedFormName, expectedFormValue), domain)

req := httptest.NewRequest(http.MethodPost, "/someUri", nil)
err := req.ParseForm()
attest.Ok(t, err)
req.Form.Add(expectedFormName, expectedFormValue)

var addedCookie *http.Cookie
{
// first form submission
rec := httptest.NewRecorder()
wrappedHandler.ServeHTTP(rec, req)

res := rec.Result()
defer res.Body.Close()

rb, errR := io.ReadAll(res.Body)
attest.Ok(t, errR)

attest.Equal(t, res.StatusCode, http.StatusOK)
attest.Equal(t, string(rb), msg)

attest.Equal(t, len(res.Cookies()), 1)
attest.Subsequence(t, res.Cookies()[0].Name, reloadProtectCookiePrefix)
addedCookie = res.Cookies()[0]
}

{
// second form submission
req.AddCookie(addedCookie)
rec := httptest.NewRecorder()
wrappedHandler.ServeHTTP(rec, req)

res := rec.Result()
defer res.Body.Close()

rb, errR := io.ReadAll(res.Body)
attest.Ok(t, errR)

attest.Equal(t, res.StatusCode, http.StatusSeeOther)
attest.Equal(t, string(rb), "")
attest.Equal(t, len(res.Cookies()), 0)
}
})

t.Run("concurrency safe", func(t *testing.T) {
t.Parallel()

msg := "hello"
domain := "localhost"
expectedFormName := "user_name"
expectedFormValue := "John Doe"
wrappedHandler := ReloadProtector(someReloadProtectorHandler(msg, expectedFormName, expectedFormValue), domain)

runhandler := func() {
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/someUri", nil)
wrappedHandler.ServeHTTP(rec, req)

res := rec.Result()
defer res.Body.Close()

rb, err := io.ReadAll(res.Body)
attest.Ok(t, err)

attest.Equal(t, res.StatusCode, http.StatusOK)
attest.Equal(t, string(rb), msg)
}

wg := &sync.WaitGroup{}
for rN := 0; rN <= 10; rN++ {
wg.Add(1)
go func() {
defer wg.Done()
runhandler()
}()
}
wg.Wait()
})
}
8 changes: 4 additions & 4 deletions middleware/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func bigMap() map[string]string {
return y
}

func someTestHandler(msg, key, value string) http.HandlerFunc {
func someSessionHandler(msg, key, value string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
sess.Set(r, key, value)
sess.SetM(r, bigMap())
Expand All @@ -42,7 +42,7 @@ func TestSession(t *testing.T) {
domain := "localhost"
key := "name"
value := "John Doe"
wrappedHandler := Session(someTestHandler(msg, key, value), secretKey, domain)
wrappedHandler := Session(someSessionHandler(msg, key, value), secretKey, domain)

rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/someUri", nil)
Expand All @@ -66,7 +66,7 @@ func TestSession(t *testing.T) {
domain := "localhost"
key := "name"
value := "John Doe"
wrappedHandler := Session(someTestHandler(msg, key, value), secretKey, domain)
wrappedHandler := Session(someSessionHandler(msg, key, value), secretKey, domain)

ts := httptest.NewServer(
wrappedHandler,
Expand Down Expand Up @@ -115,7 +115,7 @@ func TestSession(t *testing.T) {
domain := "localhost"
key := "name"
value := "John Doe"
wrappedHandler := Session(someTestHandler(msg, key, value), secretKey, domain)
wrappedHandler := Session(someSessionHandler(msg, key, value), secretKey, domain)

runhandler := func() {
rec := httptest.NewRecorder()
Expand Down