From 93baef713dbe0d61959ac8fe9b21b68b6751f0c8 Mon Sep 17 00:00:00 2001 From: $MY_NAME Date: Tue, 25 Oct 2022 15:00:38 +0300 Subject: [PATCH 01/10] f --- example/main.go | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/example/main.go b/example/main.go index fbd1863f..15f2aa90 100644 --- a/example/main.go +++ b/example/main.go @@ -208,6 +208,7 @@ func (m myAPI) login() http.HandlerFunc { reqL := m.l.WithCtx(r.Context()) if r.Method != http.MethodPost { + // Show html page with login form. data := struct { CsrfTokenName string CsrfTokenValue string @@ -223,6 +224,25 @@ func (m myAPI) login() http.HandlerFunc { } return } + refreshCookie := "ong_refresh_akja_cookie" + cookieOne, err := r.Cookie(refreshCookie) + if cookieOne != nil && cookieOne.Name == refreshCookie { + cookie.Delete(w, refreshCookie, "localhost") + reqL.WithImmediate().Info(log.F{"msg": "already submitted", "cookieOne": cookieOne}) + http.Redirect( + w, + r, + r.URL.String(), + // http 303(StatusSeeOther) is guaranteed by the spec to always use http GET. + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/303 + http.StatusSeeOther, + ) + return + } + + defer func() { + cookie.Set(w, refreshCookie, "YES", "localhost", 1*time.Hour, false) + }() if err = r.ParseForm(); err != nil { panic(err) From 8a84e2451fc9bdf7fa8045f6837e0c4240c26a7e Mon Sep 17 00:00:00 2001 From: $MY_NAME Date: Wed, 26 Oct 2022 14:04:40 +0300 Subject: [PATCH 02/10] f --- example/main.go | 20 --------- middleware/reload_protect.go | 79 ++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 20 deletions(-) create mode 100644 middleware/reload_protect.go diff --git a/example/main.go b/example/main.go index 15f2aa90..fbd1863f 100644 --- a/example/main.go +++ b/example/main.go @@ -208,7 +208,6 @@ func (m myAPI) login() http.HandlerFunc { reqL := m.l.WithCtx(r.Context()) if r.Method != http.MethodPost { - // Show html page with login form. data := struct { CsrfTokenName string CsrfTokenValue string @@ -224,25 +223,6 @@ func (m myAPI) login() http.HandlerFunc { } return } - refreshCookie := "ong_refresh_akja_cookie" - cookieOne, err := r.Cookie(refreshCookie) - if cookieOne != nil && cookieOne.Name == refreshCookie { - cookie.Delete(w, refreshCookie, "localhost") - reqL.WithImmediate().Info(log.F{"msg": "already submitted", "cookieOne": cookieOne}) - http.Redirect( - w, - r, - r.URL.String(), - // http 303(StatusSeeOther) is guaranteed by the spec to always use http GET. - // https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/303 - http.StatusSeeOther, - ) - return - } - - defer func() { - cookie.Set(w, refreshCookie, "YES", "localhost", 1*time.Hour, false) - }() if err = r.ParseForm(); err != nil { panic(err) diff --git a/middleware/reload_protect.go b/middleware/reload_protect.go new file mode 100644 index 00000000..8498b5a7 --- /dev/null +++ b/middleware/reload_protect.go @@ -0,0 +1,79 @@ +package middleware + +import ( + "fmt" + "net/http" + "strings" + "time" + + "github.com/komuw/ong/cookie" + "golang.org/x/exp/slices" +) + +// TODO: docs. +// ReloadProtect blah against Form blah +func ReloadProtect(wrappedHandler http.HandlerFunc, domain string) http.HandlerFunc { + safeMethods := []string{ + // safe methods under rfc7231: https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1 + http.MethodGet, + http.MethodHead, + http.MethodOptions, + http.MethodTrace, + } + return func(w http.ResponseWriter, r *http.Request) { + theCookie := fmt.Sprintf( + "ong_form_reload_protect-%s", + strings.ReplaceAll(r.URL.EscapedPath(), "/", ""), + ) + + if !slices.Contains(safeMethods, r.Method) { + // This could be a http POST/DELETE/etc + defer func() { + cookie.Set( + w, + theCookie, + "YES", + domain, + 3*time.Hour, + false, + ) + }() + + gotCookie, _ := r.Cookie(theCookie) + if gotCookie != nil { + // It means that the form had been submitted before. + + cookie.Delete( + w, + theCookie, + domain, + ) + http.Redirect( + w, + r, + r.URL.String(), + // http 303(StatusSeeOther) is guaranteed by the spec to always use http GET. + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/303 + http.StatusSeeOther, + ) + return + } + } + + // // TODO: check if request method is safe + + // ct, _, err := mime.ParseMediaType(r.Header.Get(ctHeader)) + // if err == nil && (ct == formUrlEncoded || ct == multiformData) { + // // For POST requests that; + // // - are not form data. + // // - have no cookies. + // // - are not using http authentication. + // // then it is okay to not validate csrf for them. + // // This is especially useful for REST API endpoints. + // // see: https://github.com/komuw/ong/issues/76 + // break + // } + + wrappedHandler(w, r) + } +} From 7f60cdab34dcb2e9c069c62ced3893e0d156fd2a Mon Sep 17 00:00:00 2001 From: $MY_NAME Date: Wed, 26 Oct 2022 15:17:11 +0300 Subject: [PATCH 03/10] f --- middleware/reload_protect.go | 41 ++++---- middleware/reload_protect_test.go | 152 ++++++++++++++++++++++++++++++ middleware/session_test.go | 8 +- 3 files changed, 181 insertions(+), 20 deletions(-) create mode 100644 middleware/reload_protect_test.go diff --git a/middleware/reload_protect.go b/middleware/reload_protect.go index 8498b5a7..9238a9fa 100644 --- a/middleware/reload_protect.go +++ b/middleware/reload_protect.go @@ -10,6 +10,8 @@ import ( "golang.org/x/exp/slices" ) +const reloadProtectCookiePrefix = "ong_form_reload_protect" + // TODO: docs. // ReloadProtect blah against Form blah func ReloadProtect(wrappedHandler http.HandlerFunc, domain string) http.HandlerFunc { @@ -21,26 +23,22 @@ func ReloadProtect(wrappedHandler http.HandlerFunc, domain string) http.HandlerF http.MethodTrace, } return func(w http.ResponseWriter, r *http.Request) { - theCookie := fmt.Sprintf( - "ong_form_reload_protect-%s", - strings.ReplaceAll(r.URL.EscapedPath(), "/", ""), - ) - + // It is possible for one to send a form without having added the requiste form http header. if !slices.Contains(safeMethods, r.Method) { // This could be a http POST/DELETE/etc - defer func() { - cookie.Set( - w, - theCookie, - "YES", - domain, - 3*time.Hour, - false, - ) - }() - gotCookie, _ := r.Cookie(theCookie) + theCookie := fmt.Sprintf("%s-%s", + reloadProtectCookiePrefix, + strings.ReplaceAll(r.URL.EscapedPath(), "/", ""), + ) + + gotCookie, err := r.Cookie(theCookie) if gotCookie != nil { + fmt.Println("\t gotCookie.MaxAge: ", gotCookie.MaxAge, " :: ", gotCookie) + } + + // TODO: && gotCookie.MaxAge > 0 + if err == nil && gotCookie != nil { // It means that the form had been submitted before. cookie.Delete( @@ -57,6 +55,16 @@ func ReloadProtect(wrappedHandler http.HandlerFunc, domain string) http.HandlerF http.StatusSeeOther, ) return + } else { + fmt.Println("setting cookie.") + cookie.Set( + w, + theCookie, + "YES", + domain, + 1*time.Hour, + false, + ) } } @@ -74,6 +82,7 @@ func ReloadProtect(wrappedHandler http.HandlerFunc, domain string) http.HandlerF // break // } + fmt.Println("\t handler called.....") wrappedHandler(w, r) } } diff --git a/middleware/reload_protect_test.go b/middleware/reload_protect_test.go new file mode 100644 index 00000000..36b593cc --- /dev/null +++ b/middleware/reload_protect_test.go @@ -0,0 +1,152 @@ +package middleware + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "sync" + "testing" + + "github.com/akshayjshah/attest" +) + +func someReloadProtectHandler(msg, expectedFormName, expectedFormValue string) http.HandlerFunc { + // count is state that is affected by form submission. + // eg, when a form is submitted; we create a new user. + count := 0 + return func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + err := r.ParseForm() + if err != nil { + panic(err) + } + val := r.Form.Get(expectedFormName) + if val != expectedFormValue { + panic(fmt.Sprintf("expected = %v got = %v", expectedFormValue, val)) + } + + count = count + 1 + if count > 1 { + // form re-submission happened + panic("form re-submission happened") + } + } + + fmt.Fprint(w, msg) + } +} + +func TestReloadProtect(t *testing.T) { + t.Parallel() + + t.Run("middleware succeds", func(t *testing.T) { + t.Parallel() + + msg := "hello" + domain := "localhost" + expectedFormName := "user_name" + expectedFormValue := "John Doe" + wrappedHandler := ReloadProtect(someReloadProtectHandler(msg, expectedFormName, expectedFormValue), domain) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/someUri", nil) + wrappedHandler.ServeHTTP(rec, req) + + res := rec.Result() + defer res.Body.Close() + + rb, err := io.ReadAll(res.Body) + attest.Ok(t, err) + + attest.Equal(t, res.StatusCode, http.StatusOK) + attest.Equal(t, string(rb), msg) + }) + + t.Run("todo", func(t *testing.T) { + t.Parallel() + + msg := "hello" + domain := "localhost" + expectedFormName := "user_name" + expectedFormValue := "John Doe" + wrappedHandler := ReloadProtect(someReloadProtectHandler(msg, expectedFormName, expectedFormValue), domain) + + req := httptest.NewRequest(http.MethodPost, "/someUri", nil) + err := req.ParseForm() + attest.Ok(t, err) + req.Form.Add(expectedFormName, expectedFormValue) + + var addedCookie *http.Cookie + { + // first form submission + rec := httptest.NewRecorder() + wrappedHandler.ServeHTTP(rec, req) + + res := rec.Result() + defer res.Body.Close() + + rb, err := io.ReadAll(res.Body) + attest.Ok(t, err) + + attest.Equal(t, res.StatusCode, http.StatusOK) + attest.Equal(t, string(rb), msg) + + attest.Equal(t, len(res.Cookies()), 1) + attest.Subsequence(t, res.Cookies()[0].Name, reloadProtectCookiePrefix) + addedCookie = res.Cookies()[0] + } + + { + // second form submission + req.AddCookie(addedCookie) + rec := httptest.NewRecorder() + wrappedHandler.ServeHTTP(rec, req) + + res := rec.Result() + defer res.Body.Close() + + rb, err := io.ReadAll(res.Body) + attest.Ok(t, err) + + attest.Equal(t, res.StatusCode, http.StatusSeeOther) + attest.Equal(t, string(rb), "") + attest.Equal(t, len(res.Cookies()), 0) + } + }) + + t.Run("concurrency safe", func(t *testing.T) { + t.Parallel() + + msg := "hello" + domain := "localhost" + expectedFormName := "user_name" + expectedFormValue := "John Doe" + wrappedHandler := ReloadProtect(someReloadProtectHandler(msg, expectedFormName, expectedFormValue), domain) + + runhandler := func() { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/someUri", nil) + wrappedHandler.ServeHTTP(rec, req) + + res := rec.Result() + defer res.Body.Close() + + rb, err := io.ReadAll(res.Body) + attest.Ok(t, err) + + attest.Equal(t, res.StatusCode, http.StatusOK) + attest.Equal(t, string(rb), msg) + } + + wg := &sync.WaitGroup{} + for rN := 0; rN <= 10; rN++ { + wg.Add(1) + go func() { + defer wg.Done() + runhandler() + }() + } + wg.Wait() + }) +} diff --git a/middleware/session_test.go b/middleware/session_test.go index abc96258..a09416e5 100644 --- a/middleware/session_test.go +++ b/middleware/session_test.go @@ -23,7 +23,7 @@ func bigMap() map[string]string { return y } -func someTestHandler(msg, key, value string) http.HandlerFunc { +func someSessionHandler(msg, key, value string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { sess.Set(r, key, value) sess.SetM(r, bigMap()) @@ -42,7 +42,7 @@ func TestSession(t *testing.T) { domain := "localhost" key := "name" value := "John Doe" - wrappedHandler := Session(someTestHandler(msg, key, value), secretKey, domain) + wrappedHandler := Session(someSessionHandler(msg, key, value), secretKey, domain) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -66,7 +66,7 @@ func TestSession(t *testing.T) { domain := "localhost" key := "name" value := "John Doe" - wrappedHandler := Session(someTestHandler(msg, key, value), secretKey, domain) + wrappedHandler := Session(someSessionHandler(msg, key, value), secretKey, domain) ts := httptest.NewServer( wrappedHandler, @@ -115,7 +115,7 @@ func TestSession(t *testing.T) { domain := "localhost" key := "name" value := "John Doe" - wrappedHandler := Session(someTestHandler(msg, key, value), secretKey, domain) + wrappedHandler := Session(someSessionHandler(msg, key, value), secretKey, domain) runhandler := func() { rec := httptest.NewRecorder() From a5a86d28cf985a8a61b7f1de90abcab364bffef9 Mon Sep 17 00:00:00 2001 From: $MY_NAME Date: Wed, 26 Oct 2022 15:20:57 +0300 Subject: [PATCH 04/10] f --- middleware/reload_protect_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/middleware/reload_protect_test.go b/middleware/reload_protect_test.go index 36b593cc..fb12cc7a 100644 --- a/middleware/reload_protect_test.go +++ b/middleware/reload_protect_test.go @@ -63,7 +63,7 @@ func TestReloadProtect(t *testing.T) { attest.Equal(t, string(rb), msg) }) - t.Run("todo", func(t *testing.T) { + t.Run("re-submission protected", func(t *testing.T) { t.Parallel() msg := "hello" @@ -112,6 +112,7 @@ func TestReloadProtect(t *testing.T) { attest.Equal(t, res.StatusCode, http.StatusSeeOther) attest.Equal(t, string(rb), "") attest.Equal(t, len(res.Cookies()), 0) + addedCookie = nil } }) From 69a21671699e48c90c93aa08e0696efc275fe326 Mon Sep 17 00:00:00 2001 From: $MY_NAME Date: Wed, 26 Oct 2022 15:23:36 +0300 Subject: [PATCH 05/10] f --- middleware/reload_protect.go | 22 +--------------------- 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/middleware/reload_protect.go b/middleware/reload_protect.go index 9238a9fa..3a021fc1 100644 --- a/middleware/reload_protect.go +++ b/middleware/reload_protect.go @@ -32,12 +32,8 @@ func ReloadProtect(wrappedHandler http.HandlerFunc, domain string) http.HandlerF strings.ReplaceAll(r.URL.EscapedPath(), "/", ""), ) + // todo: should we check if gotCookie.MaxAge > 0 gotCookie, err := r.Cookie(theCookie) - if gotCookie != nil { - fmt.Println("\t gotCookie.MaxAge: ", gotCookie.MaxAge, " :: ", gotCookie) - } - - // TODO: && gotCookie.MaxAge > 0 if err == nil && gotCookie != nil { // It means that the form had been submitted before. @@ -56,7 +52,6 @@ func ReloadProtect(wrappedHandler http.HandlerFunc, domain string) http.HandlerF ) return } else { - fmt.Println("setting cookie.") cookie.Set( w, theCookie, @@ -68,21 +63,6 @@ func ReloadProtect(wrappedHandler http.HandlerFunc, domain string) http.HandlerF } } - // // TODO: check if request method is safe - - // ct, _, err := mime.ParseMediaType(r.Header.Get(ctHeader)) - // if err == nil && (ct == formUrlEncoded || ct == multiformData) { - // // For POST requests that; - // // - are not form data. - // // - have no cookies. - // // - are not using http authentication. - // // then it is okay to not validate csrf for them. - // // This is especially useful for REST API endpoints. - // // see: https://github.com/komuw/ong/issues/76 - // break - // } - - fmt.Println("\t handler called.....") wrappedHandler(w, r) } } From bdbab07117022218d825a77eef3c194ed376ba00 Mon Sep 17 00:00:00 2001 From: $MY_NAME Date: Wed, 26 Oct 2022 15:28:51 +0300 Subject: [PATCH 06/10] f --- middleware/middleware.go | 26 +++++++++++++++----------- middleware/reload_protect.go | 4 ++-- middleware/reload_protect_test.go | 10 +++++----- 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/middleware/middleware.go b/middleware/middleware.go index 90c7c482..2fa0103f 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -90,9 +90,10 @@ func allDefaultMiddlewares( // 7. Cors since we might get pre-flight requests and we don't want those to go through all the middlewares for performance reasons. // 8. Csrf since this one is a bit more involved perf-wise. // 9. Gzip since it is very involved perf-wise. - // 10. Session since we want sessions to saved as soon as possible. + // 10. ReloadProtector, ideally I feel like it should come earlier but I'm yet to figure out where. + // 11. Session since we want sessions to saved as soon as possible. // - // user -> Panic -> Log -> RateLimiter -> LoadShedder -> HttpsRedirector -> SecurityHeaders -> Cors -> Csrf -> Gzip -> Session -> actual-handler + // user -> Panic -> Log -> RateLimiter -> LoadShedder -> HttpsRedirector -> SecurityHeaders -> Cors -> Csrf -> Gzip -> ReloadProtector -> Session -> actual-handler // We have disabled Gzip for now, since it is about 2.5times slower than no-gzip for a 50MB sample response. // see: https://github.com/komuw/ong/issues/85 @@ -105,9 +106,12 @@ func allDefaultMiddlewares( SecurityHeaders( Cors( Csrf( - Session( - wrappedHandler, - secretKey, + ReloadProtector( + Session( + wrappedHandler, + secretKey, + domain, + ), domain, ), secretKey, @@ -133,7 +137,7 @@ func allDefaultMiddlewares( // All is a middleware that allows all http methods. // -// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Session] & [Csrf] middleware. +// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Csrf], [ReloadProtector] & [Session] middleware. // As such, it provides the features and functionalities of all those middlewares. func All(wrappedHandler http.HandlerFunc, o Opts) http.HandlerFunc { return allDefaultMiddlewares( @@ -150,7 +154,7 @@ func all(wrappedHandler http.HandlerFunc) http.HandlerFunc { // Get is a middleware that only allows http GET requests and http OPTIONS requests. // -// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Session] & [Csrf] middleware. +// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Csrf], [ReloadProtector] & [Session] middleware. // As such, it provides the features and functionalities of all those middlewares. func Get(wrappedHandler http.HandlerFunc, o Opts) http.HandlerFunc { return allDefaultMiddlewares( @@ -181,7 +185,7 @@ func get(wrappedHandler http.HandlerFunc) http.HandlerFunc { // Post is a middleware that only allows http POST requests and http OPTIONS requests. // -// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Session] & [Csrf] middleware. +// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Csrf], [ReloadProtector] & [Session] middleware. // As such, it provides the features and functionalities of all those middlewares. func Post(wrappedHandler http.HandlerFunc, o Opts) http.HandlerFunc { return allDefaultMiddlewares( @@ -210,7 +214,7 @@ func post(wrappedHandler http.HandlerFunc) http.HandlerFunc { // Head is a middleware that only allows http HEAD requests and http OPTIONS requests. // -// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Session] & [Csrf] middleware. +// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Csrf], [ReloadProtector] & [Session] middleware. // As such, it provides the features and functionalities of all those middlewares. func Head(wrappedHandler http.HandlerFunc, o Opts) http.HandlerFunc { return allDefaultMiddlewares( @@ -239,7 +243,7 @@ func head(wrappedHandler http.HandlerFunc) http.HandlerFunc { // Put is a middleware that only allows http PUT requests and http OPTIONS requests. // -// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Session] & [Csrf] middleware. +// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Csrf], [ReloadProtector] & [Session] middleware. // As such, it provides the features and functionalities of all those middlewares. func Put(wrappedHandler http.HandlerFunc, o Opts) http.HandlerFunc { return allDefaultMiddlewares( @@ -268,7 +272,7 @@ func put(wrappedHandler http.HandlerFunc) http.HandlerFunc { // Delete is a middleware that only allows http DELETE requests and http OPTIONS requests. // -// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Session] & [Csrf] middleware. +// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Csrf], [ReloadProtector] & [Session] middleware. // As such, it provides the features and functionalities of all those middlewares. func Delete(wrappedHandler http.HandlerFunc, o Opts) http.HandlerFunc { return allDefaultMiddlewares( diff --git a/middleware/reload_protect.go b/middleware/reload_protect.go index 3a021fc1..49985f37 100644 --- a/middleware/reload_protect.go +++ b/middleware/reload_protect.go @@ -13,8 +13,8 @@ import ( const reloadProtectCookiePrefix = "ong_form_reload_protect" // TODO: docs. -// ReloadProtect blah against Form blah -func ReloadProtect(wrappedHandler http.HandlerFunc, domain string) http.HandlerFunc { +// ReloadProtector blah against Form blah +func ReloadProtector(wrappedHandler http.HandlerFunc, domain string) http.HandlerFunc { safeMethods := []string{ // safe methods under rfc7231: https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1 http.MethodGet, diff --git a/middleware/reload_protect_test.go b/middleware/reload_protect_test.go index fb12cc7a..9b85ae44 100644 --- a/middleware/reload_protect_test.go +++ b/middleware/reload_protect_test.go @@ -11,7 +11,7 @@ import ( "github.com/akshayjshah/attest" ) -func someReloadProtectHandler(msg, expectedFormName, expectedFormValue string) http.HandlerFunc { +func someReloadProtectorHandler(msg, expectedFormName, expectedFormValue string) http.HandlerFunc { // count is state that is affected by form submission. // eg, when a form is submitted; we create a new user. count := 0 @@ -37,7 +37,7 @@ func someReloadProtectHandler(msg, expectedFormName, expectedFormValue string) h } } -func TestReloadProtect(t *testing.T) { +func TestReloadProtector(t *testing.T) { t.Parallel() t.Run("middleware succeds", func(t *testing.T) { @@ -47,7 +47,7 @@ func TestReloadProtect(t *testing.T) { domain := "localhost" expectedFormName := "user_name" expectedFormValue := "John Doe" - wrappedHandler := ReloadProtect(someReloadProtectHandler(msg, expectedFormName, expectedFormValue), domain) + wrappedHandler := ReloadProtector(someReloadProtectorHandler(msg, expectedFormName, expectedFormValue), domain) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -70,7 +70,7 @@ func TestReloadProtect(t *testing.T) { domain := "localhost" expectedFormName := "user_name" expectedFormValue := "John Doe" - wrappedHandler := ReloadProtect(someReloadProtectHandler(msg, expectedFormName, expectedFormValue), domain) + wrappedHandler := ReloadProtector(someReloadProtectorHandler(msg, expectedFormName, expectedFormValue), domain) req := httptest.NewRequest(http.MethodPost, "/someUri", nil) err := req.ParseForm() @@ -123,7 +123,7 @@ func TestReloadProtect(t *testing.T) { domain := "localhost" expectedFormName := "user_name" expectedFormValue := "John Doe" - wrappedHandler := ReloadProtect(someReloadProtectHandler(msg, expectedFormName, expectedFormValue), domain) + wrappedHandler := ReloadProtector(someReloadProtectorHandler(msg, expectedFormName, expectedFormValue), domain) runhandler := func() { rec := httptest.NewRecorder() From 21934c95651f9f282330f8f6e36fd7c6b959e3ad Mon Sep 17 00:00:00 2001 From: $MY_NAME Date: Wed, 26 Oct 2022 15:34:45 +0300 Subject: [PATCH 07/10] f --- middleware/reload_protect.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/middleware/reload_protect.go b/middleware/reload_protect.go index 49985f37..d025e4bb 100644 --- a/middleware/reload_protect.go +++ b/middleware/reload_protect.go @@ -12,8 +12,9 @@ import ( const reloadProtectCookiePrefix = "ong_form_reload_protect" -// TODO: docs. -// ReloadProtector blah against Form blah +// ReloadProtector is a middleware that attempts to provides protection against a form re-submission when a user reloads/refreshes an already submitted web page/form. +// +// If such a situation is detected; this middleware will issue a http GET redirect to the same url. func ReloadProtector(wrappedHandler http.HandlerFunc, domain string) http.HandlerFunc { safeMethods := []string{ // safe methods under rfc7231: https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1 From 73296753cda7540fdb2e917d113dacb3e1aed9de Mon Sep 17 00:00:00 2001 From: $MY_NAME Date: Wed, 26 Oct 2022 15:36:08 +0300 Subject: [PATCH 08/10] f --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6494a1a2..3b1591cb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ Most recent version is listed first. ## v0.0.24 - Set session cookie only if non-empty: https://github.com/komuw/ong/pull/170 +- Add ReloadProtector middleware: https://github.com/komuw/ong/pull/171 ## v0.0.23 - ong/client: Add log id http header: https://github.com/komuw/ong/pull/166 From 27b439400448cfaf14463a5a31353b6c6729dcbb Mon Sep 17 00:00:00 2001 From: $MY_NAME Date: Wed, 26 Oct 2022 15:37:23 +0300 Subject: [PATCH 09/10] f --- middleware/reload_protect_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/middleware/reload_protect_test.go b/middleware/reload_protect_test.go index 9b85ae44..e2c0ad05 100644 --- a/middleware/reload_protect_test.go +++ b/middleware/reload_protect_test.go @@ -86,8 +86,8 @@ func TestReloadProtector(t *testing.T) { res := rec.Result() defer res.Body.Close() - rb, err := io.ReadAll(res.Body) - attest.Ok(t, err) + rb, errR := io.ReadAll(res.Body) + attest.Ok(t, errR) attest.Equal(t, res.StatusCode, http.StatusOK) attest.Equal(t, string(rb), msg) @@ -106,8 +106,8 @@ func TestReloadProtector(t *testing.T) { res := rec.Result() defer res.Body.Close() - rb, err := io.ReadAll(res.Body) - attest.Ok(t, err) + rb, errR := io.ReadAll(res.Body) + attest.Ok(t, errR) attest.Equal(t, res.StatusCode, http.StatusSeeOther) attest.Equal(t, string(rb), "") From 8d21749b003ad705055947938be61d5df01db76d Mon Sep 17 00:00:00 2001 From: $MY_NAME Date: Wed, 26 Oct 2022 15:40:39 +0300 Subject: [PATCH 10/10] f --- middleware/reload_protect_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/middleware/reload_protect_test.go b/middleware/reload_protect_test.go index e2c0ad05..b2863033 100644 --- a/middleware/reload_protect_test.go +++ b/middleware/reload_protect_test.go @@ -112,7 +112,6 @@ func TestReloadProtector(t *testing.T) { attest.Equal(t, res.StatusCode, http.StatusSeeOther) attest.Equal(t, string(rb), "") attest.Equal(t, len(res.Cookies()), 0) - addedCookie = nil } })