From 2b337fa19e8268d65d93de911ae1529b9eb6cf78 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Fri, 22 Mar 2024 20:33:27 +0100 Subject: [PATCH] [v13] Add CORS headers to the `tsh login` callback (#39721) * Add CORS headers to the tsh login callback * Use wildcard origin, remove draft header --- lib/client/redirect.go | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/lib/client/redirect.go b/lib/client/redirect.go index e5f3fb82fa388..3a59d83e3e66d 100644 --- a/lib/client/redirect.go +++ b/lib/client/redirect.go @@ -236,13 +236,14 @@ func (rd *Redirector) callback(w http.ResponseWriter, r *http.Request) (*auth.SS return nil, trace.NotFound("path not found") } - if r.URL.Query().Has("err") { - err := r.URL.Query().Get("err") + r.ParseForm() + if r.Form.Has("err") { + err := r.Form.Get("err") return nil, trace.Errorf("identity provider callback failed with error: %v", err) } // Decrypt ciphertext to get login response. - plaintext, err := rd.key.Open([]byte(r.URL.Query().Get("response"))) + plaintext, err := rd.key.Open([]byte(r.Form.Get("response"))) if err != nil { return nil, trace.BadParameter("failed to decrypt response: in %v, err: %v", r.URL.String(), err) } @@ -275,6 +276,24 @@ func (rd *Redirector) wrapCallback(fn func(http.ResponseWriter, *http.Request) ( successURL := clone.String() return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Allow", "GET, OPTIONS, POST") + // CORS protects the _response_, and our response is always just a + // redirect to info/login_success or error/login so it's fine to share + // with the world; we could use the proxy URL as the origin, but that + // would break setups where the proxy public address that tsh is using + // is not the "main" one that ends up being used for the redirect after + // the IdP login + w.Header().Add("Access-Control-Allow-Origin", "*") + switch r.Method { + default: + w.WriteHeader(http.StatusMethodNotAllowed) + return + case http.MethodOptions: + w.WriteHeader(http.StatusOK) + return + case http.MethodGet, http.MethodPost: + } + response, err := fn(w, r) if err != nil { if trace.IsNotFound(err) { @@ -284,18 +303,15 @@ func (rd *Redirector) wrapCallback(fn func(http.ResponseWriter, *http.Request) ( select { case rd.errorC <- err: case <-rd.context.Done(): - http.Redirect(w, r, errorURL, http.StatusFound) - return } http.Redirect(w, r, errorURL, http.StatusFound) return } select { case rd.responseC <- response: + http.Redirect(w, r, successURL, http.StatusFound) case <-rd.context.Done(): http.Redirect(w, r, errorURL, http.StatusFound) - return } - http.Redirect(w, r, successURL, http.StatusFound) }) }