Skip to content

Commit

Permalink
[v15] backport #37520 and #37981 (#38032)
Browse files Browse the repository at this point in the history
* Read the bearer token over websocket endpoints instead of query parameter (#37520)

* Read the bearer token over WS endpoints

use the request context, not session

Dont pass websocket by context

lint

resolve some comments

Add TestWSAuthenticateRequest

Close ws in handler

deprecation notices, doc

resolve comments

resolve comments

give a longer read/write deadline

dont set write deadline, ws endpoints never did before and it breaks things

convert frontend to use ws access token

Resolove comments, move to using an explicit state

fix ci

reset read deadline

prettier

* update connectToHost

* linter

* read errors from websocket

* missing /ws on ttyWsAddr and fix wrong onmessage

* fix race in test

* lint

* skip TestTerminal as it takes 11 seconds to run

* dont skip the test

* resolve apiserver comments

* Add an AuthenticatedWebSocket class

* convert other clients to use AuthenticatedWebSocket

* Converts `AuthenticatedWebSocket` into drop-in replacement for `WebSocket` (#37699)

* Converts `AuthenticatedWebSocket` into drop-in replacement for `WebSocket`
that automatically goes through Teleport's custom authentication process
before facilitating any caller-defined communication.

This also reverts previous-`WebSocket` users to their original state
(sans the code for passing the bearer token in the query string),
swapping in `AuthenticatedWebSocket` in place of `WebSocket`.

* Create a single authnWsUpgrader with a comment justifying why we turn off CORS

* recieving to receiving

* resolve comments

---------

Co-authored-by: Isaiah Becker-Mayer <[email protected]>

* Updates `desktopPlaybackHandle` to new ws paradigm (#37981)

* Updates `desktopPlaybackHandle` to new ws paradigm

This was mistakenly left out of #37520.
This commit also refactors `WithClusterAuthWebSocket` slightly for easier
comprehension, and updates the vite config to facilitate the new websocket
endpoints in development mode.

* Update lib/web/apiserver.go

Co-authored-by: Zac Bergquist <[email protected]>

---------

Co-authored-by: Zac Bergquist <[email protected]>

---------

Co-authored-by: Alex McGrath <[email protected]>
Co-authored-by: Zac Bergquist <[email protected]>
  • Loading branch information
3 people authored Feb 9, 2024
1 parent 5b149a2 commit 3a9d624
Show file tree
Hide file tree
Showing 18 changed files with 693 additions and 117 deletions.
218 changes: 210 additions & 8 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import (
"time"

"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/gravitational/oxy/ratelimit"
"github.com/gravitational/roundtrip"
"github.com/gravitational/trace"
Expand Down Expand Up @@ -151,6 +152,11 @@ type Handler struct {

// tracer is used to create spans.
tracer oteltrace.Tracer

// wsIODeadline is used to set a deadline for receiving a message from
// an authenticated websocket so unauthenticated sockets dont get left
// open.
wsIODeadline time.Duration
}

// HandlerOption is a functional argument - an option that can be passed
Expand Down Expand Up @@ -365,6 +371,7 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) {
ClusterFeatures: cfg.ClusterFeatures,
healthCheckAppServer: cfg.HealthCheckAppServer,
tracer: cfg.TracerProvider.Tracer(teleport.ComponentWeb),
wsIODeadline: wsIODeadline,
}

// Check for self-hosted vs Cloud.
Expand Down Expand Up @@ -720,7 +727,10 @@ func (h *Handler) bindDefaultEndpoints() {
h.DELETE("/webapi/sites/:site/locks/:uuid", h.WithClusterAuth(h.deleteClusterLock))

// active sessions handlers
h.GET("/webapi/sites/:site/connect", h.WithClusterAuth(h.siteNodeConnect)) // connect to an active session (via websocket)
// Deprecated: The connect/ws variant should be used instead.
// TODO(lxea): DELETE in v16
h.GET("/webapi/sites/:site/connect", h.WithClusterAuthWebSocket(false, h.siteNodeConnect)) // connect to an active session (via websocket)
h.GET("/webapi/sites/:site/connect/ws", h.WithClusterAuthWebSocket(true, h.siteNodeConnect)) // connect to an active session (via websocket, with auth over websocket)
h.GET("/webapi/sites/:site/sessions", h.WithClusterAuth(h.clusterActiveAndPendingSessionsGet)) // get list of active and pending sessions

// Audit events handlers.
Expand Down Expand Up @@ -828,9 +838,17 @@ func (h *Handler) bindDefaultEndpoints() {
h.GET("/webapi/sites/:site/desktopservices", h.WithClusterAuth(h.clusterDesktopServicesGet))
h.GET("/webapi/sites/:site/desktops/:desktopName", h.WithClusterAuth(h.getDesktopHandle))
// GET /webapi/sites/:site/desktops/:desktopName/connect?access_token=<bearer_token>&username=<username>&width=<width>&height=<height>
h.GET("/webapi/sites/:site/desktops/:desktopName/connect", h.WithClusterAuth(h.desktopConnectHandle))
// Deprecated: The connect/ws variant should be used instead.
// TODO(lxea): DELETE in v16
h.GET("/webapi/sites/:site/desktops/:desktopName/connect", h.WithClusterAuthWebSocket(false, h.desktopConnectHandle))
// GET /webapi/sites/:site/desktops/:desktopName/connect?username=<username>&width=<width>&height=<height>
h.GET("/webapi/sites/:site/desktops/:desktopName/connect/ws", h.WithClusterAuthWebSocket(true, h.desktopConnectHandle))
// GET /webapi/sites/:site/desktopplayback/:sid?access_token=<bearer_token>
h.GET("/webapi/sites/:site/desktopplayback/:sid", h.WithClusterAuth(h.desktopPlaybackHandle))
// Deprecated: The desktopplayback/ws variant should be used instead.
// TODO(lxea): DELETE in v16
h.GET("/webapi/sites/:site/desktopplayback/:sid", h.WithClusterAuthWebSocket(false, h.desktopPlaybackHandle))
// GET /webapi/sites/:site/desktopplayback/:sid/ws
h.GET("/webapi/sites/:site/desktopplayback/:sid/ws", h.WithClusterAuthWebSocket(true, h.desktopPlaybackHandle))
h.GET("/webapi/sites/:site/desktops/:desktopName/active", h.WithClusterAuth(h.desktopIsActive))

// GET a Connection Diagnostics by its name
Expand Down Expand Up @@ -889,7 +907,11 @@ func (h *Handler) bindDefaultEndpoints() {
h.GET("/webapi/sites/:site/user-groups", h.WithClusterAuth(h.getUserGroups))

// WebSocket endpoint for the chat conversation
h.GET("/webapi/sites/:site/assistant", h.WithClusterAuth(h.assistant))
// Deprecated: The connect/ws variant should be used instead.
// TODO(lxea): DELETE in v16
h.GET("/webapi/sites/:site/assistant", h.WithClusterAuthWebSocket(false, h.assistant))
// WebSocket endpoint for the chat conversation, websocket auth
h.GET("/webapi/sites/:site/assistant/ws", h.WithClusterAuthWebSocket(true, h.assistant))

// Sets the title for the conversation.
h.POST("/webapi/assistant/conversations/:conversation_id/title", h.WithAuth(h.setAssistantTitle))
Expand All @@ -908,7 +930,11 @@ func (h *Handler) bindDefaultEndpoints() {
h.GET("/webapi/assistant/conversations/:conversation_id", h.WithAuth(h.getAssistantConversationByID))

// Allows executing an arbitrary command on multiple nodes.
h.GET("/webapi/command/:site/execute", h.WithClusterAuth(h.executeCommand))
// Deprecated: The execute/ws variant should be used instead.
// TODO(lxea): DELETE in v16
h.GET("/webapi/command/:site/execute", h.WithClusterAuthWebSocket(false, h.executeCommand))
// Allows executing an arbitrary command on multiple nodes, websocket auth.
h.GET("/webapi/command/:site/execute/ws", h.WithClusterAuthWebSocket(true, h.executeCommand))

// Fetches the user's preferences
h.GET("/webapi/user/preferences", h.WithAuth(h.getUserPreferences))
Expand Down Expand Up @@ -2941,6 +2967,7 @@ func (h *Handler) siteNodeConnect(
p httprouter.Params,
sessionCtx *SessionContext,
site reversetunnelclient.RemoteSite,
ws *websocket.Conn,
) (interface{}, error) {
q := r.URL.Query()
params := q.Get("params")
Expand Down Expand Up @@ -3033,6 +3060,7 @@ func (h *Handler) siteNodeConnect(
PROXYSigner: h.cfg.PROXYSigner,
Tracker: tracker,
PresenceChecker: h.cfg.PresenceChecker,
WebsocketConn: ws,
}

term, err := NewTerminal(ctx, terminalConfig)
Expand Down Expand Up @@ -3731,6 +3759,9 @@ type ContextHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Pa
// ClusterHandler is a authenticated handler that is called for some existing remote cluster
type ClusterHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error)

// ClusterWebsocketHandler is a authenticated websocket handler that is called for some existing remote cluster
type ClusterWebsocketHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite, ws *websocket.Conn) (interface{}, error)

// WithClusterAuth wraps a ClusterHandler to ensure that a request is authenticated to this proxy
// (the same as WithAuth), as well as to grab the remoteSite (which can represent this local cluster
// or a remote trusted cluster) as specified by the ":site" url parameter.
Expand All @@ -3745,12 +3776,108 @@ func (h *Handler) WithClusterAuth(fn ClusterHandler) httprouter.Handle {
})
}

func (h *Handler) writeErrToWebSocket(ws *websocket.Conn, err error) {
if err == nil {
return
}
errEnvelope := Envelope{
Type: defaults.WebsocketError,
Payload: trace.UserMessage(err),
}
env, err := errEnvelope.Marshal()
if err != nil {
h.log.WithError(err).Error("error marshaling proto")
return
}
if err := ws.WriteMessage(websocket.BinaryMessage, env); err != nil {
h.log.WithError(err).Error("error writing proto")
return
}
}

// authnWsUpgrader is an upgrader that allows any origin to connect to the websocket.
// This makes our lives easier in our automated tests. While ordinarily this would be
// used to enforce the same-origin policy, we don't need to worry about that for authenticated
// websockets, which also require a valid bearer token sent over the websocket after upgrade.
// Therefore even if an attacker were to connect to the websocket and trick the browser into
// sending the session cookie, they would still fail to send the bearer token needed to authenticate.
var authnWsUpgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool { return true },
}

// WithClusterAuthWebSocket wraps a ClusterWebsocketHandler to ensure that a request is authenticated
// to this proxy via websocket if websocketAuth is true, or via query parameter if false (the same as WithAuth), as
// well as to grab the remoteSite (which can represent this local cluster or a remote trusted cluster)
// as specified by the ":site" url parameter.
//
// TODO(lxea): remove the 'websocketAuth' bool once the deprecated websocket handlers are removed
func (h *Handler) WithClusterAuthWebSocket(websocketAuth bool, fn ClusterWebsocketHandler) httprouter.Handle {
return httplib.MakeHandler(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (any, error) {
var sctx *SessionContext
var ws *websocket.Conn
var site reversetunnelclient.RemoteSite
var err error

if websocketAuth {
sctx, ws, site, err = h.authenticateWSRequestWithCluster(w, r, p)
} else {
sctx, ws, site, err = h.authenticateWSRequestWithClusterDeprecated(w, r, p)
}

if err != nil {
return nil, trace.Wrap(err)
}
// WS protocol requires the server send a close message
// which should be done by downstream users
defer ws.Close()
if _, err := fn(w, r, p, sctx, site, ws); err != nil {
h.writeErrToWebSocket(ws, err)
}
return nil, nil
})
}

// authenticateWSRequestWithCluster ensures that a request is
// authenticated to this proxy via websocket, returning the
// *SessionContext (same as AuthenticateRequest), and also grabs the
// remoteSite (which can represent this local cluster or a remote
// trusted cluster) as specified by the ":site" url parameter.
func (h *Handler) authenticateWSRequestWithCluster(w http.ResponseWriter, r *http.Request, p httprouter.Params) (*SessionContext, *websocket.Conn, reversetunnelclient.RemoteSite, error) {
sctx, ws, err := h.AuthenticateRequestWS(w, r)
if err != nil {
return nil, nil, nil, trace.Wrap(err)
}

site, err := h.getSiteByParams(sctx, p)
if err != nil {
return nil, nil, nil, trace.Wrap(err)
}

return sctx, ws, site, nil
}

// TODO(lxea): remove once the deprecated websocket handlers are removed
func (h *Handler) authenticateWSRequestWithClusterDeprecated(w http.ResponseWriter, r *http.Request, p httprouter.Params) (*SessionContext, *websocket.Conn, reversetunnelclient.RemoteSite, error) {
sctx, site, err := h.authenticateRequestWithCluster(w, r, p)
if err != nil {
return nil, nil, nil, trace.Wrap(err)
}
ws, err := authnWsUpgrader.Upgrade(w, r, nil)
if err != nil {
return nil, nil, nil, trace.Wrap(err)
}
return sctx, ws, site, nil
}

// authenticateRequestWithCluster ensures that a request is authenticated
// to this proxy, returning the *SessionContext (same as AuthenticateRequest),
// and also grabs the remoteSite (which can represent this local cluster or a
// remote trusted cluster) as specified by the ":site" url parameter.
func (h *Handler) authenticateRequestWithCluster(w http.ResponseWriter, r *http.Request, p httprouter.Params) (*SessionContext, reversetunnelclient.RemoteSite, error) {
sctx, err := h.AuthenticateRequest(w, r, true)

if err != nil {
return nil, nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -4068,9 +4195,7 @@ func rateLimitRequest(r *http.Request, limiter *limiter.RateLimiter) error {
return trace.Wrap(err)
}

// AuthenticateRequest authenticates request using combination of a session cookie
// and bearer token
func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, checkBearerToken bool) (*SessionContext, error) {
func (h *Handler) validateCookie(w http.ResponseWriter, r *http.Request) (*SessionContext, error) {
const missingCookieMsg = "missing session cookie"
cookie, err := r.Cookie(websession.CookieName)
if err != nil || (cookie != nil && cookie.Value == "") {
Expand All @@ -4085,6 +4210,17 @@ func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, ch
clearSessionCookies((w))
return nil, trace.AccessDenied("need auth")
}

return sctx, nil
}

// AuthenticateRequest authenticates request using combination of a session cookie
// and bearer token
func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, checkBearerToken bool) (*SessionContext, error) {
sctx, err := h.validateCookie(w, r)
if err != nil {
return nil, trace.Wrap(err)
}
if checkBearerToken {
creds, err := roundtrip.ParseAuthHeaders(r)
if err != nil {
Expand Down Expand Up @@ -4137,6 +4273,72 @@ func contextWithMFAResponseFromRequestHeader(ctx context.Context, requestHeader
return ctx, nil
}

type wsBearerToken struct {
Token string `json:"token"`
}

type wsStatus struct {
Type string `json:"type"`
Status string `json:"status"`
Message string `json:"message,omitempty"`
}

// wsIODeadline is used to set a deadline for receiving a message from
// an authenticated websocket so unauthenticated sockets dont get left
// open.
const wsIODeadline = time.Second * 4

// AuthenticateRequest authenticates request using combination of a session cookie
// and bearer token retrieved from a websocket
func (h *Handler) AuthenticateRequestWS(w http.ResponseWriter, r *http.Request) (*SessionContext, *websocket.Conn, error) {
sctx, err := h.validateCookie(w, r)
if err != nil {
return nil, nil, trace.Wrap(err)
}
ws, err := authnWsUpgrader.Upgrade(w, r, nil)
if err != nil {
return nil, nil, trace.ConnectionProblem(err, "Error upgrading to websocket: %v", err)
}
if err := ws.SetReadDeadline(time.Now().Add(wsIODeadline)); err != nil {
return nil, nil, trace.ConnectionProblem(err, "Error setting websocket read deadline: %v", err)
}

var t wsBearerToken
if err := ws.ReadJSON(&t); err != nil {
return nil, nil, trace.Wrap(err)
}
if err := sctx.validateBearerToken(r.Context(), t.Token); err != nil {
writeErr := ws.WriteJSON(wsStatus{
Type: "create_session_response",
Status: "error",
Message: "invalid token",
})
if writeErr != nil {
log.Errorf("Error while writing invalid token error to websocket: %s", writeErr)
}

return nil, nil, trace.Wrap(err)
}

if err := ws.WriteJSON(wsStatus{
Type: "create_session_response",
Status: "ok",
}); err != nil {
return nil, nil, trace.Wrap(err)
}

// unset the deadline as downstream consumers should handle this themselves.
if err := ws.SetReadDeadline(time.Time{}); err != nil {
return nil, nil, trace.ConnectionProblem(err, "Error setting websocket read deadline: %v", err)
}

if err := parseMFAResponseFromRequest(r); err != nil {
return nil, nil, trace.Wrap(err)
}

return sctx, ws, nil
}

// ProxyWithRoles returns a reverse tunnel proxy verifying the permissions
// of the given user.
func (h *Handler) ProxyWithRoles(ctx *SessionContext) (reversetunnelclient.Tunnel, error) {
Expand Down
Loading

0 comments on commit 3a9d624

Please sign in to comment.