Skip to content

Commit

Permalink
[v13] Add CORS headers to the tsh login callback (#39721)
Browse files Browse the repository at this point in the history
* Add CORS headers to the tsh login callback

* Use wildcard origin, remove draft header
  • Loading branch information
espadolini authored Mar 22, 2024
1 parent 584cede commit 2b337fa
Showing 1 changed file with 23 additions and 7 deletions.
30 changes: 23 additions & 7 deletions lib/client/redirect.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,13 +236,14 @@ func (rd *Redirector) callback(w http.ResponseWriter, r *http.Request) (*auth.SS
return nil, trace.NotFound("path not found")
}

if r.URL.Query().Has("err") {
err := r.URL.Query().Get("err")
r.ParseForm()
if r.Form.Has("err") {
err := r.Form.Get("err")
return nil, trace.Errorf("identity provider callback failed with error: %v", err)
}

// Decrypt ciphertext to get login response.
plaintext, err := rd.key.Open([]byte(r.URL.Query().Get("response")))
plaintext, err := rd.key.Open([]byte(r.Form.Get("response")))
if err != nil {
return nil, trace.BadParameter("failed to decrypt response: in %v, err: %v", r.URL.String(), err)
}
Expand Down Expand Up @@ -275,6 +276,24 @@ func (rd *Redirector) wrapCallback(fn func(http.ResponseWriter, *http.Request) (
successURL := clone.String()

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Allow", "GET, OPTIONS, POST")
// CORS protects the _response_, and our response is always just a
// redirect to info/login_success or error/login so it's fine to share
// with the world; we could use the proxy URL as the origin, but that
// would break setups where the proxy public address that tsh is using
// is not the "main" one that ends up being used for the redirect after
// the IdP login
w.Header().Add("Access-Control-Allow-Origin", "*")
switch r.Method {
default:
w.WriteHeader(http.StatusMethodNotAllowed)
return
case http.MethodOptions:
w.WriteHeader(http.StatusOK)
return
case http.MethodGet, http.MethodPost:
}

response, err := fn(w, r)
if err != nil {
if trace.IsNotFound(err) {
Expand All @@ -284,18 +303,15 @@ func (rd *Redirector) wrapCallback(fn func(http.ResponseWriter, *http.Request) (
select {
case rd.errorC <- err:
case <-rd.context.Done():
http.Redirect(w, r, errorURL, http.StatusFound)
return
}
http.Redirect(w, r, errorURL, http.StatusFound)
return
}
select {
case rd.responseC <- response:
http.Redirect(w, r, successURL, http.StatusFound)
case <-rd.context.Done():
http.Redirect(w, r, errorURL, http.StatusFound)
return
}
http.Redirect(w, r, successURL, http.StatusFound)
})
}

0 comments on commit 2b337fa

Please sign in to comment.