From 3a9d624f44b4ed3b1634d0fd597e6c8a076ec144 Mon Sep 17 00:00:00 2001 From: Isaiah Becker-Mayer Date: Fri, 9 Feb 2024 11:22:01 -0800 Subject: [PATCH] [v15] backport #37520 and #37981 (#38032) * Read the bearer token over websocket endpoints instead of query parameter (#37520) * Read the bearer token over WS endpoints use the request context, not session Dont pass websocket by context lint resolve some comments Add TestWSAuthenticateRequest Close ws in handler deprecation notices, doc resolve comments resolve comments give a longer read/write deadline dont set write deadline, ws endpoints never did before and it breaks things convert frontend to use ws access token Resolove comments, move to using an explicit state fix ci reset read deadline prettier * update connectToHost * linter * read errors from websocket * missing /ws on ttyWsAddr and fix wrong onmessage * fix race in test * lint * skip TestTerminal as it takes 11 seconds to run * dont skip the test * resolve apiserver comments * Add an AuthenticatedWebSocket class * convert other clients to use AuthenticatedWebSocket * Converts `AuthenticatedWebSocket` into drop-in replacement for `WebSocket` (#37699) * Converts `AuthenticatedWebSocket` into drop-in replacement for `WebSocket` that automatically goes through Teleport's custom authentication process before facilitating any caller-defined communication. This also reverts previous-`WebSocket` users to their original state (sans the code for passing the bearer token in the query string), swapping in `AuthenticatedWebSocket` in place of `WebSocket`. * Create a single authnWsUpgrader with a comment justifying why we turn off CORS * recieving to receiving * resolve comments --------- Co-authored-by: Isaiah Becker-Mayer * Updates `desktopPlaybackHandle` to new ws paradigm (#37981) * Updates `desktopPlaybackHandle` to new ws paradigm This was mistakenly left out of https://github.com/gravitational/teleport/pull/37520. This commit also refactors `WithClusterAuthWebSocket` slightly for easier comprehension, and updates the vite config to facilitate the new websocket endpoints in development mode. * Update lib/web/apiserver.go Co-authored-by: Zac Bergquist --------- Co-authored-by: Zac Bergquist --------- Co-authored-by: Alex McGrath Co-authored-by: Zac Bergquist --- lib/web/apiserver.go | 218 +++++++++++++- lib/web/apiserver_test.go | 151 +++++++++- lib/web/assistant.go | 20 +- lib/web/command.go | 15 +- lib/web/desktop.go | 12 +- lib/web/desktop_playback.go | 11 +- lib/web/terminal.go | 22 +- lib/web/terminal_test.go | 8 +- .../src/Assist/context/AssistContext.tsx | 18 +- .../TerminalAssist/TerminalAssistContext.tsx | 11 +- .../teleport/src/Console/consoleContext.tsx | 3 +- .../src/DesktopSession/useTdpClientCanvas.tsx | 3 +- .../teleport/src/Player/DesktopPlayer.tsx | 5 +- web/packages/teleport/src/config.ts | 16 +- .../src/lib/AuthenticatedWebSocket.ts | 279 ++++++++++++++++++ web/packages/teleport/src/lib/tdp/client.ts | 7 +- web/packages/teleport/src/lib/term/tty.ts | 3 +- web/packages/teleport/src/types.ts | 8 + 18 files changed, 693 insertions(+), 117 deletions(-) create mode 100644 web/packages/teleport/src/lib/AuthenticatedWebSocket.ts diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 0977e13cfc1d4..50f0c41aa0d9a 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -40,6 +40,7 @@ import ( "time" "github.com/google/uuid" + "github.com/gorilla/websocket" "github.com/gravitational/oxy/ratelimit" "github.com/gravitational/roundtrip" "github.com/gravitational/trace" @@ -151,6 +152,11 @@ type Handler struct { // tracer is used to create spans. tracer oteltrace.Tracer + + // wsIODeadline is used to set a deadline for receiving a message from + // an authenticated websocket so unauthenticated sockets dont get left + // open. + wsIODeadline time.Duration } // HandlerOption is a functional argument - an option that can be passed @@ -365,6 +371,7 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) { ClusterFeatures: cfg.ClusterFeatures, healthCheckAppServer: cfg.HealthCheckAppServer, tracer: cfg.TracerProvider.Tracer(teleport.ComponentWeb), + wsIODeadline: wsIODeadline, } // Check for self-hosted vs Cloud. @@ -720,7 +727,10 @@ func (h *Handler) bindDefaultEndpoints() { h.DELETE("/webapi/sites/:site/locks/:uuid", h.WithClusterAuth(h.deleteClusterLock)) // active sessions handlers - h.GET("/webapi/sites/:site/connect", h.WithClusterAuth(h.siteNodeConnect)) // connect to an active session (via websocket) + // Deprecated: The connect/ws variant should be used instead. + // TODO(lxea): DELETE in v16 + h.GET("/webapi/sites/:site/connect", h.WithClusterAuthWebSocket(false, h.siteNodeConnect)) // connect to an active session (via websocket) + h.GET("/webapi/sites/:site/connect/ws", h.WithClusterAuthWebSocket(true, h.siteNodeConnect)) // connect to an active session (via websocket, with auth over websocket) h.GET("/webapi/sites/:site/sessions", h.WithClusterAuth(h.clusterActiveAndPendingSessionsGet)) // get list of active and pending sessions // Audit events handlers. @@ -828,9 +838,17 @@ func (h *Handler) bindDefaultEndpoints() { h.GET("/webapi/sites/:site/desktopservices", h.WithClusterAuth(h.clusterDesktopServicesGet)) h.GET("/webapi/sites/:site/desktops/:desktopName", h.WithClusterAuth(h.getDesktopHandle)) // GET /webapi/sites/:site/desktops/:desktopName/connect?access_token=&username=&width=&height= - h.GET("/webapi/sites/:site/desktops/:desktopName/connect", h.WithClusterAuth(h.desktopConnectHandle)) + // Deprecated: The connect/ws variant should be used instead. + // TODO(lxea): DELETE in v16 + h.GET("/webapi/sites/:site/desktops/:desktopName/connect", h.WithClusterAuthWebSocket(false, h.desktopConnectHandle)) + // GET /webapi/sites/:site/desktops/:desktopName/connect?username=&width=&height= + h.GET("/webapi/sites/:site/desktops/:desktopName/connect/ws", h.WithClusterAuthWebSocket(true, h.desktopConnectHandle)) // GET /webapi/sites/:site/desktopplayback/:sid?access_token= - h.GET("/webapi/sites/:site/desktopplayback/:sid", h.WithClusterAuth(h.desktopPlaybackHandle)) + // Deprecated: The desktopplayback/ws variant should be used instead. + // TODO(lxea): DELETE in v16 + h.GET("/webapi/sites/:site/desktopplayback/:sid", h.WithClusterAuthWebSocket(false, h.desktopPlaybackHandle)) + // GET /webapi/sites/:site/desktopplayback/:sid/ws + h.GET("/webapi/sites/:site/desktopplayback/:sid/ws", h.WithClusterAuthWebSocket(true, h.desktopPlaybackHandle)) h.GET("/webapi/sites/:site/desktops/:desktopName/active", h.WithClusterAuth(h.desktopIsActive)) // GET a Connection Diagnostics by its name @@ -889,7 +907,11 @@ func (h *Handler) bindDefaultEndpoints() { h.GET("/webapi/sites/:site/user-groups", h.WithClusterAuth(h.getUserGroups)) // WebSocket endpoint for the chat conversation - h.GET("/webapi/sites/:site/assistant", h.WithClusterAuth(h.assistant)) + // Deprecated: The connect/ws variant should be used instead. + // TODO(lxea): DELETE in v16 + h.GET("/webapi/sites/:site/assistant", h.WithClusterAuthWebSocket(false, h.assistant)) + // WebSocket endpoint for the chat conversation, websocket auth + h.GET("/webapi/sites/:site/assistant/ws", h.WithClusterAuthWebSocket(true, h.assistant)) // Sets the title for the conversation. h.POST("/webapi/assistant/conversations/:conversation_id/title", h.WithAuth(h.setAssistantTitle)) @@ -908,7 +930,11 @@ func (h *Handler) bindDefaultEndpoints() { h.GET("/webapi/assistant/conversations/:conversation_id", h.WithAuth(h.getAssistantConversationByID)) // Allows executing an arbitrary command on multiple nodes. - h.GET("/webapi/command/:site/execute", h.WithClusterAuth(h.executeCommand)) + // Deprecated: The execute/ws variant should be used instead. + // TODO(lxea): DELETE in v16 + h.GET("/webapi/command/:site/execute", h.WithClusterAuthWebSocket(false, h.executeCommand)) + // Allows executing an arbitrary command on multiple nodes, websocket auth. + h.GET("/webapi/command/:site/execute/ws", h.WithClusterAuthWebSocket(true, h.executeCommand)) // Fetches the user's preferences h.GET("/webapi/user/preferences", h.WithAuth(h.getUserPreferences)) @@ -2941,6 +2967,7 @@ func (h *Handler) siteNodeConnect( p httprouter.Params, sessionCtx *SessionContext, site reversetunnelclient.RemoteSite, + ws *websocket.Conn, ) (interface{}, error) { q := r.URL.Query() params := q.Get("params") @@ -3033,6 +3060,7 @@ func (h *Handler) siteNodeConnect( PROXYSigner: h.cfg.PROXYSigner, Tracker: tracker, PresenceChecker: h.cfg.PresenceChecker, + WebsocketConn: ws, } term, err := NewTerminal(ctx, terminalConfig) @@ -3731,6 +3759,9 @@ type ContextHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Pa // ClusterHandler is a authenticated handler that is called for some existing remote cluster type ClusterHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) +// ClusterWebsocketHandler is a authenticated websocket handler that is called for some existing remote cluster +type ClusterWebsocketHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite, ws *websocket.Conn) (interface{}, error) + // WithClusterAuth wraps a ClusterHandler to ensure that a request is authenticated to this proxy // (the same as WithAuth), as well as to grab the remoteSite (which can represent this local cluster // or a remote trusted cluster) as specified by the ":site" url parameter. @@ -3745,12 +3776,108 @@ func (h *Handler) WithClusterAuth(fn ClusterHandler) httprouter.Handle { }) } +func (h *Handler) writeErrToWebSocket(ws *websocket.Conn, err error) { + if err == nil { + return + } + errEnvelope := Envelope{ + Type: defaults.WebsocketError, + Payload: trace.UserMessage(err), + } + env, err := errEnvelope.Marshal() + if err != nil { + h.log.WithError(err).Error("error marshaling proto") + return + } + if err := ws.WriteMessage(websocket.BinaryMessage, env); err != nil { + h.log.WithError(err).Error("error writing proto") + return + } +} + +// authnWsUpgrader is an upgrader that allows any origin to connect to the websocket. +// This makes our lives easier in our automated tests. While ordinarily this would be +// used to enforce the same-origin policy, we don't need to worry about that for authenticated +// websockets, which also require a valid bearer token sent over the websocket after upgrade. +// Therefore even if an attacker were to connect to the websocket and trick the browser into +// sending the session cookie, they would still fail to send the bearer token needed to authenticate. +var authnWsUpgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { return true }, +} + +// WithClusterAuthWebSocket wraps a ClusterWebsocketHandler to ensure that a request is authenticated +// to this proxy via websocket if websocketAuth is true, or via query parameter if false (the same as WithAuth), as +// well as to grab the remoteSite (which can represent this local cluster or a remote trusted cluster) +// as specified by the ":site" url parameter. +// +// TODO(lxea): remove the 'websocketAuth' bool once the deprecated websocket handlers are removed +func (h *Handler) WithClusterAuthWebSocket(websocketAuth bool, fn ClusterWebsocketHandler) httprouter.Handle { + return httplib.MakeHandler(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (any, error) { + var sctx *SessionContext + var ws *websocket.Conn + var site reversetunnelclient.RemoteSite + var err error + + if websocketAuth { + sctx, ws, site, err = h.authenticateWSRequestWithCluster(w, r, p) + } else { + sctx, ws, site, err = h.authenticateWSRequestWithClusterDeprecated(w, r, p) + } + + if err != nil { + return nil, trace.Wrap(err) + } + // WS protocol requires the server send a close message + // which should be done by downstream users + defer ws.Close() + if _, err := fn(w, r, p, sctx, site, ws); err != nil { + h.writeErrToWebSocket(ws, err) + } + return nil, nil + }) +} + +// authenticateWSRequestWithCluster ensures that a request is +// authenticated to this proxy via websocket, returning the +// *SessionContext (same as AuthenticateRequest), and also grabs the +// remoteSite (which can represent this local cluster or a remote +// trusted cluster) as specified by the ":site" url parameter. +func (h *Handler) authenticateWSRequestWithCluster(w http.ResponseWriter, r *http.Request, p httprouter.Params) (*SessionContext, *websocket.Conn, reversetunnelclient.RemoteSite, error) { + sctx, ws, err := h.AuthenticateRequestWS(w, r) + if err != nil { + return nil, nil, nil, trace.Wrap(err) + } + + site, err := h.getSiteByParams(sctx, p) + if err != nil { + return nil, nil, nil, trace.Wrap(err) + } + + return sctx, ws, site, nil +} + +// TODO(lxea): remove once the deprecated websocket handlers are removed +func (h *Handler) authenticateWSRequestWithClusterDeprecated(w http.ResponseWriter, r *http.Request, p httprouter.Params) (*SessionContext, *websocket.Conn, reversetunnelclient.RemoteSite, error) { + sctx, site, err := h.authenticateRequestWithCluster(w, r, p) + if err != nil { + return nil, nil, nil, trace.Wrap(err) + } + ws, err := authnWsUpgrader.Upgrade(w, r, nil) + if err != nil { + return nil, nil, nil, trace.Wrap(err) + } + return sctx, ws, site, nil +} + // authenticateRequestWithCluster ensures that a request is authenticated // to this proxy, returning the *SessionContext (same as AuthenticateRequest), // and also grabs the remoteSite (which can represent this local cluster or a // remote trusted cluster) as specified by the ":site" url parameter. func (h *Handler) authenticateRequestWithCluster(w http.ResponseWriter, r *http.Request, p httprouter.Params) (*SessionContext, reversetunnelclient.RemoteSite, error) { sctx, err := h.AuthenticateRequest(w, r, true) + if err != nil { return nil, nil, trace.Wrap(err) } @@ -4068,9 +4195,7 @@ func rateLimitRequest(r *http.Request, limiter *limiter.RateLimiter) error { return trace.Wrap(err) } -// AuthenticateRequest authenticates request using combination of a session cookie -// and bearer token -func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, checkBearerToken bool) (*SessionContext, error) { +func (h *Handler) validateCookie(w http.ResponseWriter, r *http.Request) (*SessionContext, error) { const missingCookieMsg = "missing session cookie" cookie, err := r.Cookie(websession.CookieName) if err != nil || (cookie != nil && cookie.Value == "") { @@ -4085,6 +4210,17 @@ func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, ch clearSessionCookies((w)) return nil, trace.AccessDenied("need auth") } + + return sctx, nil +} + +// AuthenticateRequest authenticates request using combination of a session cookie +// and bearer token +func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, checkBearerToken bool) (*SessionContext, error) { + sctx, err := h.validateCookie(w, r) + if err != nil { + return nil, trace.Wrap(err) + } if checkBearerToken { creds, err := roundtrip.ParseAuthHeaders(r) if err != nil { @@ -4137,6 +4273,72 @@ func contextWithMFAResponseFromRequestHeader(ctx context.Context, requestHeader return ctx, nil } +type wsBearerToken struct { + Token string `json:"token"` +} + +type wsStatus struct { + Type string `json:"type"` + Status string `json:"status"` + Message string `json:"message,omitempty"` +} + +// wsIODeadline is used to set a deadline for receiving a message from +// an authenticated websocket so unauthenticated sockets dont get left +// open. +const wsIODeadline = time.Second * 4 + +// AuthenticateRequest authenticates request using combination of a session cookie +// and bearer token retrieved from a websocket +func (h *Handler) AuthenticateRequestWS(w http.ResponseWriter, r *http.Request) (*SessionContext, *websocket.Conn, error) { + sctx, err := h.validateCookie(w, r) + if err != nil { + return nil, nil, trace.Wrap(err) + } + ws, err := authnWsUpgrader.Upgrade(w, r, nil) + if err != nil { + return nil, nil, trace.ConnectionProblem(err, "Error upgrading to websocket: %v", err) + } + if err := ws.SetReadDeadline(time.Now().Add(wsIODeadline)); err != nil { + return nil, nil, trace.ConnectionProblem(err, "Error setting websocket read deadline: %v", err) + } + + var t wsBearerToken + if err := ws.ReadJSON(&t); err != nil { + return nil, nil, trace.Wrap(err) + } + if err := sctx.validateBearerToken(r.Context(), t.Token); err != nil { + writeErr := ws.WriteJSON(wsStatus{ + Type: "create_session_response", + Status: "error", + Message: "invalid token", + }) + if writeErr != nil { + log.Errorf("Error while writing invalid token error to websocket: %s", writeErr) + } + + return nil, nil, trace.Wrap(err) + } + + if err := ws.WriteJSON(wsStatus{ + Type: "create_session_response", + Status: "ok", + }); err != nil { + return nil, nil, trace.Wrap(err) + } + + // unset the deadline as downstream consumers should handle this themselves. + if err := ws.SetReadDeadline(time.Time{}); err != nil { + return nil, nil, trace.ConnectionProblem(err, "Error setting websocket read deadline: %v", err) + } + + if err := parseMFAResponseFromRequest(r); err != nil { + return nil, nil, trace.Wrap(err) + } + + return sctx, ws, nil +} + // ProxyWithRoles returns a reverse tunnel proxy verifying the permissions // of the given user. func (h *Handler) ProxyWithRoles(ctx *SessionContext) (reversetunnelclient.Tunnel, error) { diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 680ec4b727493..d47abc5c0a6b3 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -1355,14 +1355,25 @@ func TestSiteNodeConnectInvalidSessionID(t *testing.T) { ctx, cancel := context.WithCancel(s.ctx) t.Cleanup(cancel) - term, err := connectToHost(ctx, connectConfig{ + result := make(chan error) + + _, err := connectToHost(ctx, connectConfig{ pack: s.authPack(t, "foo"), host: s.node.ID(), proxy: s.webServer.Listener.Addr().String(), sessionID: "/../../../foo", + handlers: map[string]WSHandlerFunc{ + defaults.WebsocketError: func(ctx context.Context, e Envelope) { + if e.Payload == "/../../../foo is not a valid UUID" { + result <- errors.New(e.Payload) + } + close(result) + }, + }, }) - require.Error(t, err) - require.Nil(t, term) + require.NoError(t, err) + res := <-result + require.Error(t, res) } func TestResolveServerHostPort(t *testing.T) { @@ -1897,6 +1908,7 @@ func TestTerminal(t *testing.T) { host: s.node.ID(), proxy: s.webServer.Listener.Addr().String(), }) + require.NoError(t, err) t.Cleanup(func() { require.True(t, utils.IsOKNetworkError(term.Close())) }) @@ -8118,18 +8130,38 @@ func (r *testProxy) newClient(t *testing.T, opts ...roundtrip.ClientParam) *Test return &TestWebClient{clt, t} } +func makeAuthReqOverWS(ws *websocket.Conn, token string) error { + authReq, err := json.Marshal(struct { + Token string `json:"token"` + }{Token: token}) + if err != nil { + return trace.Wrap(err) + } + + if err := ws.WriteMessage(websocket.TextMessage, authReq); err != nil { + return trace.Wrap(err) + } + _, authRes, err := ws.ReadMessage() + if err != nil { + return trace.Wrap(err) + } + if !strings.Contains(string(authRes), `"status":"ok"`) { + return trace.AccessDenied("unexpected response") + } + return nil +} + func (r *testProxy) makeDesktopSession(t *testing.T, pack *authPack, sessionID session.ID, addr net.Addr) *websocket.Conn { u := url.URL{ Host: r.webURL.Host, Scheme: client.WSS, - Path: fmt.Sprintf("/webapi/sites/%s/desktops/%s/connect", currentSiteShortcut, "desktop1"), + Path: fmt.Sprintf("/webapi/sites/%s/desktops/%s/connect/ws", currentSiteShortcut, "desktop1"), } q := u.Query() q.Set("username", "marek") q.Set("width", "100") q.Set("height", "100") - q.Set(roundtrip.AccessTokenQueryParam, pack.session.Token) u.RawQuery = q.Encode() dialer := websocket.Dialer{} @@ -8144,6 +8176,10 @@ func (r *testProxy) makeDesktopSession(t *testing.T, pack *authPack, sessionID s ws, resp, err := dialer.Dial(u.String(), header) require.NoError(t, err) + + err = makeAuthReqOverWS(ws, pack.session.Token) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, ws.Close()) require.NoError(t, resp.Body.Close()) @@ -9135,6 +9171,111 @@ func (s *fakeKubeService) ListKubernetesResources(ctx context.Context, req *kube }, nil } +func TestWebSocketAuthenticateRequest(t *testing.T) { + t.Parallel() + ctx := context.Background() + env := newWebPack(t, 1) + proxy := env.proxies[0] + proxy.handler.handler.wsIODeadline = time.Second + pack := proxy.authPack(t, "test-user@example.com", nil) + for _, tc := range []struct { + name string + serverExpectError string + expectResponse wsStatus + token string + writeTimeout func() + readTimeout func() + }{ + { + name: "valid token", + expectResponse: wsStatus{ + Type: "create_session_response", + Status: "ok", + }, + token: pack.session.Token, + }, + { + name: "invalid token", + serverExpectError: "not found", + expectResponse: wsStatus{ + Type: "create_session_response", + Status: "error", + Message: "invalid token", + }, + token: "honk", + }, + { + name: "server read timeout", + serverExpectError: "i/o timeout", + token: pack.session.Token, + readTimeout: func() { + <-time.After(wsIODeadline * 3) + }, + }, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sctx, ws, err := proxy.handler.handler.AuthenticateRequestWS(w, r) + if err != nil { + if tc.serverExpectError == "" { + t.Errorf("unexpected error: %v", err) + } + if !strings.Contains(err.Error(), tc.serverExpectError) { + t.Errorf("unexpected error: %v", err) + return + } + return + } + t.Cleanup(func() { ws.Close() }) + if err == nil && tc.serverExpectError != "" { + t.Errorf("expected error, got nil") + return + } + + clt, err := sctx.GetClient() + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + _, err = clt.GetDomainName(ctx) + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + })) + + header := http.Header{} + for _, cookie := range pack.cookies { + header.Add("Cookie", cookie.String()) + } + + u := strings.Replace(server.URL, "http:", "ws:", 1) + conn, resp, err := websocket.DefaultDialer.Dial(u, header) + require.NoError(t, err) + t.Cleanup(func() { conn.Close() }) + t.Cleanup(func() { resp.Body.Close() }) + + if tc.readTimeout != nil { + tc.readTimeout() + } + err = conn.WriteJSON(wsBearerToken{ + Token: tc.token, + }) + require.NoError(t, err) + if tc.readTimeout != nil { + return // Reading will fail as the server will have closed the connection + } + + var status wsStatus + err = conn.ReadJSON(&status) + require.NoError(t, err) + require.Equal(t, tc.expectResponse, status) + }) + } +} + // TestSimultaneousAuthenticateRequest ensures that multiple authenticated // requests do not race to create a SessionContext. This would happen when // Proxies were deployed behind a round-robin load balancer. Only the Proxy diff --git a/lib/web/assistant.go b/lib/web/assistant.go index fca5fde9c4e19..da1caae195a17 100644 --- a/lib/web/assistant.go +++ b/lib/web/assistant.go @@ -332,9 +332,9 @@ func (h *Handler) generateAssistantTitle(_ http.ResponseWriter, r *http.Request, // This handler covers the main chat conversation as well as the // SSH completition (SSH command generation and output explanation). func (h *Handler) assistant(w http.ResponseWriter, r *http.Request, _ httprouter.Params, - sctx *SessionContext, site reversetunnelclient.RemoteSite, + sctx *SessionContext, site reversetunnelclient.RemoteSite, ws *websocket.Conn, ) (any, error) { - if err := runAssistant(h, w, r, sctx, site); err != nil { + if err := runAssistant(h, w, r, sctx, site, ws); err != nil { h.log.Warn(trace.DebugReport(err)) return nil, trace.Wrap(err) } @@ -420,7 +420,7 @@ func checkAssistEnabled(a auth.ClientI, ctx context.Context) error { // runAssistant upgrades the HTTP connection to a websocket and starts a chat loop. func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, - sctx *SessionContext, site reversetunnelclient.RemoteSite, + sctx *SessionContext, site reversetunnelclient.RemoteSite, ws *websocket.Conn, ) (err error) { q := r.URL.Query() conversationID := q.Get("conversation_id") @@ -455,20 +455,6 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, return trace.Wrap(err) } - upgrader := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { return true }, - } - - ws, err := upgrader.Upgrade(w, r, nil) - if err != nil { - errMsg := "Error upgrading to websocket" - h.log.WithError(err).Error(errMsg) - http.Error(w, errMsg, http.StatusInternalServerError) - return nil - } - // Note: This time should be longer than OpenAI response time. keepAliveInterval := netConfig.GetKeepAliveInterval() err = ws.SetReadDeadline(deadlineForInterval(keepAliveInterval)) diff --git a/lib/web/command.go b/lib/web/command.go index 9017b63ed1b34..bae2a66247b04 100644 --- a/lib/web/command.go +++ b/lib/web/command.go @@ -128,6 +128,7 @@ func (h *Handler) executeCommand( _ httprouter.Params, sessionCtx *SessionContext, site reversetunnelclient.RemoteSite, + rawWS *websocket.Conn, ) (any, error) { q := r.URL.Query() params := q.Get("params") @@ -171,20 +172,6 @@ func (h *Handler) executeCommand( clusterName := site.GetName() - upgrader := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { return true }, - } - - rawWS, err := upgrader.Upgrade(w, r, nil) - if err != nil { - errMsg := "Error upgrading to websocket" - h.log.WithError(err).Error(errMsg) - http.Error(w, errMsg, http.StatusInternalServerError) - return nil, nil - } - defer func() { rawWS.WriteMessage(websocket.CloseMessage, nil) rawWS.Close() diff --git a/lib/web/desktop.go b/lib/web/desktop.go index a89bcf01b17c3..abd663d897c1c 100644 --- a/lib/web/desktop.go +++ b/lib/web/desktop.go @@ -66,6 +66,7 @@ func (h *Handler) desktopConnectHandle( p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite, + ws *websocket.Conn, ) (interface{}, error) { desktopName := p.ByName("desktopName") if desktopName == "" { @@ -75,7 +76,7 @@ func (h *Handler) desktopConnectHandle( log := sctx.cfg.Log.WithField("desktop-name", desktopName).WithField("cluster-name", site.GetName()) log.Debug("New desktop access websocket connection") - if err := h.createDesktopConnection(w, r, desktopName, site.GetName(), log, sctx, site); err != nil { + if err := h.createDesktopConnection(w, r, desktopName, site.GetName(), log, sctx, site, ws); err != nil { // createDesktopConnection makes a best effort attempt to send an error to the user // (via websocket) before terminating the connection. We log the error here, but // return nil because our HTTP middleware will try to write the returned error in JSON @@ -94,15 +95,8 @@ func (h *Handler) createDesktopConnection( log *logrus.Entry, sctx *SessionContext, site reversetunnelclient.RemoteSite, + ws *websocket.Conn, ) error { - upgrader := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - } - ws, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return trace.Wrap(err) - } defer ws.Close() sendTDPError := func(err error) error { diff --git a/lib/web/desktop_playback.go b/lib/web/desktop_playback.go index df04c330eed7f..9c50cdcc153c7 100644 --- a/lib/web/desktop_playback.go +++ b/lib/web/desktop_playback.go @@ -38,6 +38,7 @@ func (h *Handler) desktopPlaybackHandle( p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite, + ws *websocket.Conn, ) (interface{}, error) { sID := p.ByName("sid") if sID == "" { @@ -49,16 +50,6 @@ func (h *Handler) desktopPlaybackHandle( return nil, trace.Wrap(err) } - upgrader := websocket.Upgrader{ - ReadBufferSize: 4096, - WriteBufferSize: 4096, - } - ws, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return nil, trace.Wrap(err) - } - defer ws.Close() - player, err := player.New(&player.Config{ Clock: h.clock, Log: h.log, diff --git a/lib/web/terminal.go b/lib/web/terminal.go index f8c6ebb17882f..3c40c7a18c15c 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -142,6 +142,7 @@ func NewTerminal(ctx context.Context, cfg TerminalHandlerConfig) (*TerminalHandl participantMode: cfg.ParticipantMode, tracker: cfg.Tracker, presenceChecker: cfg.PresenceChecker, + websocketConn: cfg.WebsocketConn, }, nil } @@ -191,6 +192,8 @@ type TerminalHandlerConfig struct { PresenceChecker PresenceChecker // Clock allows interaction with time. Clock clockwork.Clock + // WebsocketConn is the active websocket connection + WebsocketConn *websocket.Conn } func (t *TerminalHandlerConfig) CheckAndSetDefaults() error { @@ -317,6 +320,9 @@ type TerminalHandler struct { // clock used to interact with time. clock clockwork.Clock + + // websocketConn is the active websocket connection + websocketConn *websocket.Conn } // ServeHTTP builds a connection to the remote node and then pumps back two types of @@ -328,21 +334,9 @@ func (t *TerminalHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { t.ctx.AddClosers(t) defer t.ctx.RemoveCloser(t) - upgrader := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { return true }, - } - - ws, err := upgrader.Upgrade(w, r, nil) - if err != nil { - errMsg := "Error upgrading to websocket" - t.log.WithError(err).Error(errMsg) - http.Error(w, errMsg, http.StatusInternalServerError) - return - } + ws := t.websocketConn - err = ws.SetReadDeadline(deadlineForInterval(t.keepAliveInterval)) + err := ws.SetReadDeadline(deadlineForInterval(t.keepAliveInterval)) if err != nil { t.log.WithError(err).Error("Error setting websocket readline") return diff --git a/lib/web/terminal_test.go b/lib/web/terminal_test.go index ce8996308ed53..2e7edf5150f6f 100644 --- a/lib/web/terminal_test.go +++ b/lib/web/terminal_test.go @@ -33,7 +33,6 @@ import ( "github.com/gogo/protobuf/proto" "github.com/gorilla/websocket" - "github.com/gravitational/roundtrip" "github.com/gravitational/trace" "github.com/stretchr/testify/require" @@ -127,12 +126,11 @@ func connectToHost(ctx context.Context, cfg connectConfig) (*terminal, error) { u := url.URL{ Host: cfg.proxy, Scheme: client.WSS, - Path: "/v1/webapi/sites/-current-/connect", + Path: "/v1/webapi/sites/-current-/connect/ws", } q := u.Query() q.Set("params", string(data)) - q.Set(roundtrip.AccessTokenQueryParam, cfg.pack.session.Token) u.RawQuery = q.Encode() header := http.Header{} @@ -162,6 +160,10 @@ func connectToHost(ctx context.Context, cfg connectConfig) (*terminal, error) { return nil, trace.Wrap(err) } + if err := makeAuthReqOverWS(ws, cfg.pack.session.Token); err != nil { + return nil, trace.Wrap(err) + } + if cfg.pingHandler != nil { ws.SetPingHandler(func(message string) error { return cfg.pingHandler(ws, message) diff --git a/web/packages/teleport/src/Assist/context/AssistContext.tsx b/web/packages/teleport/src/Assist/context/AssistContext.tsx index e38ab3577936c..edce10821c056 100644 --- a/web/packages/teleport/src/Assist/context/AssistContext.tsx +++ b/web/packages/teleport/src/Assist/context/AssistContext.tsx @@ -31,7 +31,7 @@ import { AssistStateActionType, reducer } from 'teleport/Assist/context/state'; import { convertServerMessages } from 'teleport/Assist/context/utils'; import useStickyClusterId from 'teleport/useStickyClusterId'; import cfg from 'teleport/config'; -import { getAccessToken, getHostName } from 'teleport/services/api'; +import { getHostName } from 'teleport/services/api'; import { AccessRequestClientMessage, @@ -48,6 +48,7 @@ import { makeMfaAuthenticateChallenge, WebauthnAssertionResponse, } from 'teleport/services/auth'; +import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket'; import * as service from '../service'; import { @@ -84,9 +85,9 @@ let lastCommandExecutionResultId = 0; const TEN_MINUTES = 10 * 60 * 1000; export function AssistContextProvider(props: PropsWithChildren) { - const activeWebSocket = useRef(null); + const activeWebSocket = useRef(null); // TODO(ryan): this should be removed once https://github.com/gravitational/teleport.e/pull/1609 is implemented - const executeCommandWebSocket = useRef(null); + const executeCommandWebSocket = useRef(null); const refreshWebSocketTimeout = useRef(null); const { clusterId } = useStickyClusterId(); @@ -124,11 +125,10 @@ export function AssistContextProvider(props: PropsWithChildren) { } function setupWebSocket(conversationId: string, initialMessage?: string) { - activeWebSocket.current = new WebSocket( + activeWebSocket.current = new AuthenticatedWebSocket( cfg.getAssistConversationWebSocketUrl( getHostName(), clusterId, - getAccessToken(), conversationId ) ); @@ -348,7 +348,7 @@ export function AssistContextProvider(props: PropsWithChildren) { if ( !activeWebSocket.current || - activeWebSocket.current.readyState === WebSocket.CLOSED + activeWebSocket.current.readyState === AuthenticatedWebSocket.CLOSED ) { setupWebSocket(state.conversations.selectedId, data); } else { @@ -378,7 +378,8 @@ export function AssistContextProvider(props: PropsWithChildren) { function sendMfaChallenge(data: WebauthnAssertionResponse) { if ( !executeCommandWebSocket.current || - executeCommandWebSocket.current.readyState !== WebSocket.OPEN || + executeCommandWebSocket.current.readyState !== + AuthenticatedWebSocket.OPEN || !data ) { console.warn( @@ -446,12 +447,11 @@ export function AssistContextProvider(props: PropsWithChildren) { const url = cfg.getAssistExecuteCommandUrl( getHostName(), clusterId, - getAccessToken(), execParams ); const proto = new Protobuf(); - executeCommandWebSocket.current = new WebSocket(url); + executeCommandWebSocket.current = new AuthenticatedWebSocket(url); executeCommandWebSocket.current.binaryType = 'arraybuffer'; executeCommandWebSocket.current.onmessage = event => { diff --git a/web/packages/teleport/src/Console/DocumentSsh/TerminalAssist/TerminalAssistContext.tsx b/web/packages/teleport/src/Console/DocumentSsh/TerminalAssist/TerminalAssistContext.tsx index 9bcd78ce43027..ebdec9bfcf4dc 100644 --- a/web/packages/teleport/src/Console/DocumentSsh/TerminalAssist/TerminalAssistContext.tsx +++ b/web/packages/teleport/src/Console/DocumentSsh/TerminalAssist/TerminalAssistContext.tsx @@ -26,7 +26,7 @@ import React, { } from 'react'; import { Author, ServerMessage } from 'teleport/Assist/types'; -import { getAccessToken, getHostName } from 'teleport/services/api'; +import { getHostName } from 'teleport/services/api'; import useStickyClusterId from 'teleport/useStickyClusterId'; import cfg from 'teleport/config'; import { @@ -36,6 +36,7 @@ import { SuggestedCommandMessage, UserMessage, } from 'teleport/Console/DocumentSsh/TerminalAssist/types'; +import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket'; interface TerminalAssistContextValue { close: () => void; @@ -57,11 +58,10 @@ export function TerminalAssistContextProvider( const [visible, setVisible] = useState(false); - const socketRef = useRef(null); + const socketRef = useRef(null); const socketUrl = cfg.getAssistActionWebSocketUrl( getHostName(), clusterId, - getAccessToken(), 'ssh-cmdgen' ); @@ -72,7 +72,7 @@ export function TerminalAssistContextProvider( const [messages, setMessages] = useState([]); useEffect(() => { - socketRef.current = new WebSocket(socketUrl); + socketRef.current = new AuthenticatedWebSocket(socketUrl); socketRef.current.onmessage = e => { const data = JSON.parse(e.data) as ServerMessage; @@ -117,11 +117,10 @@ export function TerminalAssistContextProvider( const socketUrl = cfg.getAssistActionWebSocketUrl( getHostName(), clusterId, - getAccessToken(), 'ssh-explain' ); - const ws = new WebSocket(socketUrl); + const ws = new AuthenticatedWebSocket(socketUrl); ws.onopen = () => { ws.send(encodedOutput); diff --git a/web/packages/teleport/src/Console/consoleContext.tsx b/web/packages/teleport/src/Console/consoleContext.tsx index e696f18602e0e..76ba3013d7cac 100644 --- a/web/packages/teleport/src/Console/consoleContext.tsx +++ b/web/packages/teleport/src/Console/consoleContext.tsx @@ -24,7 +24,7 @@ import { W3CTraceContextPropagator } from '@opentelemetry/core'; import webSession from 'teleport/services/websession'; import history from 'teleport/services/history'; import cfg, { UrlResourcesParams, UrlSshParams } from 'teleport/config'; -import { getAccessToken, getHostName } from 'teleport/services/api'; +import { getHostName } from 'teleport/services/api'; import Tty from 'teleport/lib/term/tty'; import TtyAddressResolver from 'teleport/lib/term/ttyAddressResolver'; import serviceSession, { @@ -197,7 +197,6 @@ export default class ConsoleContext { const ttyUrl = cfg.api.ttyWsAddr .replace(':fqdn', getHostName()) - .replace(':token', getAccessToken()) .replace(':clusterId', clusterId) .replace(':traceparent', carrier['traceparent']); diff --git a/web/packages/teleport/src/DesktopSession/useTdpClientCanvas.tsx b/web/packages/teleport/src/DesktopSession/useTdpClientCanvas.tsx index 74153145af2b9..e5f3c986a3f21 100644 --- a/web/packages/teleport/src/DesktopSession/useTdpClientCanvas.tsx +++ b/web/packages/teleport/src/DesktopSession/useTdpClientCanvas.tsx @@ -30,7 +30,7 @@ import { PngFrame, SyncKeys, } from 'teleport/lib/tdp/codec'; -import { getAccessToken, getHostName } from 'teleport/services/api'; +import { getHostName } from 'teleport/services/api'; import cfg from 'teleport/config'; import { Sha256Digest } from 'teleport/lib/util'; @@ -85,7 +85,6 @@ export default function useTdpClientCanvas(props: Props) { .replace(':fqdn', getHostName()) .replace(':clusterId', clusterId) .replace(':desktopName', desktopName) - .replace(':token', getAccessToken()) .replace(':username', username); setTdpClient(new TdpClient(addr)); diff --git a/web/packages/teleport/src/Player/DesktopPlayer.tsx b/web/packages/teleport/src/Player/DesktopPlayer.tsx index c819d7e6581c8..8262835db060c 100644 --- a/web/packages/teleport/src/Player/DesktopPlayer.tsx +++ b/web/packages/teleport/src/Player/DesktopPlayer.tsx @@ -23,7 +23,7 @@ import { Indicator, Box, Alert, Flex } from 'design'; import cfg from 'teleport/config'; import { StatusEnum, formatDisplayTime } from 'teleport/lib/player'; import { PlayerClient, TdpClient } from 'teleport/lib/tdp'; -import { getAccessToken, getHostName } from 'teleport/services/api'; +import { getHostName } from 'teleport/services/api'; import TdpClientCanvas from 'teleport/components/TdpClientCanvas'; import ProgressBar from './ProgressBar'; @@ -157,8 +157,7 @@ const useDesktopPlayer = ({ clusterId, sid }) => { const url = cfg.api.desktopPlaybackWsAddr .replace(':fqdn', getHostName()) .replace(':clusterId', clusterId) - .replace(':sid', sid) - .replace(':token', getAccessToken()); + .replace(':sid', sid); return new PlayerClient({ url, setTime, setPlayerStatus, setStatusText }); }, [clusterId, sid, setTime, setPlayerStatus]); diff --git a/web/packages/teleport/src/config.ts b/web/packages/teleport/src/config.ts index ac3521c95bb86..c8d0748193fd5 100644 --- a/web/packages/teleport/src/config.ts +++ b/web/packages/teleport/src/config.ts @@ -196,12 +196,12 @@ const cfg = { desktopServicesPath: `/v1/webapi/sites/:clusterId/desktopservices?searchAsRoles=:searchAsRoles?&limit=:limit?&startKey=:startKey?&query=:query?&search=:search?&sort=:sort?`, desktopPath: `/v1/webapi/sites/:clusterId/desktops/:desktopName`, desktopWsAddr: - 'wss://:fqdn/v1/webapi/sites/:clusterId/desktops/:desktopName/connect?access_token=:token&username=:username', + 'wss://:fqdn/v1/webapi/sites/:clusterId/desktops/:desktopName/connect/ws?username=:username', desktopPlaybackWsAddr: - 'wss://:fqdn/v1/webapi/sites/:clusterId/desktopplayback/:sid?access_token=:token', + 'wss://:fqdn/v1/webapi/sites/:clusterId/desktopplayback/:sid/ws', desktopIsActive: '/v1/webapi/sites/:clusterId/desktops/:desktopName/active', ttyWsAddr: - 'wss://:fqdn/v1/webapi/sites/:clusterId/connect?access_token=:token¶ms=:params&traceparent=:traceparent', + 'wss://:fqdn/v1/webapi/sites/:clusterId/connect/ws?params=:params&traceparent=:traceparent', ttyPlaybackWsAddr: 'wss://:fqdn/v1/webapi/sites/:clusterId/ttyplayback/:sid?access_token=:token', // TODO(zmb3): get token out of URL activeAndPendingSessionsPath: '/v1/webapi/sites/:clusterId/sessions', @@ -310,11 +310,11 @@ const cfg = { '/v1/webapi/assistant/conversations/:conversationId/title', assistGenerateSummaryPath: '/v1/webapi/assistant/title/summary', assistConversationWebSocketPath: - 'wss://:hostname/v1/webapi/sites/:clusterId/assistant', + 'wss://:hostname/v1/webapi/sites/:clusterId/assistant/ws', assistConversationHistoryPath: '/v1/webapi/assistant/conversations/:conversationId', assistExecuteCommandWebSocketPath: - 'wss://:hostname/v1/webapi/command/:clusterId/execute', + 'wss://:hostname/v1/webapi/command/:clusterId/execute/ws', userPreferencesPath: '/v1/webapi/user/preferences', userClusterPreferencesPath: '/v1/webapi/user/preferences/:clusterId', @@ -857,12 +857,10 @@ const cfg = { getAssistConversationWebSocketUrl( hostname: string, clusterId: string, - accessToken: string, conversationId: string ) { const searchParams = new URLSearchParams(); - searchParams.set('access_token', accessToken); searchParams.set('conversation_id', conversationId); return ( @@ -876,12 +874,10 @@ const cfg = { getAssistActionWebSocketUrl( hostname: string, clusterId: string, - accessToken: string, action: string ) { const searchParams = new URLSearchParams(); - searchParams.set('access_token', accessToken); searchParams.set('action', action); return ( @@ -901,12 +897,10 @@ const cfg = { getAssistExecuteCommandUrl( hostname: string, clusterId: string, - accessToken: string, params: Record ) { const searchParams = new URLSearchParams(); - searchParams.set('access_token', accessToken); searchParams.set('params', JSON.stringify(params)); return ( diff --git a/web/packages/teleport/src/lib/AuthenticatedWebSocket.ts b/web/packages/teleport/src/lib/AuthenticatedWebSocket.ts new file mode 100644 index 0000000000000..4c1d0c4e5e281 --- /dev/null +++ b/web/packages/teleport/src/lib/AuthenticatedWebSocket.ts @@ -0,0 +1,279 @@ +/** + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +import { getAccessToken } from 'teleport/services/api'; +import { WebsocketStatus } from 'teleport/types'; + +/** + * `AuthenticatedWebSocket` is a drop-in replacement for + * the `WebSocket` class that handles Teleport's websocket + * authentication process. + */ +export class AuthenticatedWebSocket extends WebSocket { + private authenticated: boolean = false; + private openListeners: ((this: WebSocket, ev: Event) => any)[] = []; + private onopenInternal: ((this: WebSocket, ev: Event) => any) | null = null; + private messageListeners: ((this: WebSocket, ev: MessageEvent) => any)[] = []; + private onmessageInternal: + | ((this: WebSocket, ev: MessageEvent) => any) + | null = null; + private oncloseListeners: ((this: WebSocket, ev: CloseEvent) => any)[] = []; + private oncloseInternal: ((this: WebSocket, ev: CloseEvent) => any) | null = + null; + private onerrorListeners: ((this: WebSocket, ev: Event) => any)[] = []; + private onerrorInternal: ((this: WebSocket, ev: Event) => any) | null = null; + private binaryTypeInternal: BinaryType = 'blob'; // Default binaryType + private onopenEvent: Event | null = null; + + constructor(url: string | URL, protocols?: string | string[]) { + super(url, protocols); + // Set the binaryType to 'arraybuffer' to handle the authentication process. + super.binaryType = 'arraybuffer'; + + // The open event listener should immediately send the authentication token + super.onopen = (onopenEvent: Event) => { + super.send(JSON.stringify({ token: getAccessToken() })); + // Don't call the user defined onopen messages yet, wait for the authentication response. + this.onopenEvent = onopenEvent; + }; + + // The message event listener should handle the authentication response, + // and if it succeeds, set the binaryType to the user-defined value and + // trigger any user-added open listeners. + super.onmessage = (ev: MessageEvent) => { + // If not yet authenticated, handle the authentication response. + if (!this.authenticated) { + // Parse the message as a WebsocketStatus. + let authResponse: WebsocketStatus; + try { + authResponse = JSON.parse(ev.data) as WebsocketStatus; + } catch (e) { + this.triggerError('Error parsing JSON from websocket message: ' + e); + return; + } + + // Validate the WebsocketStatus. + if ( + !authResponse.type || + !authResponse.status || + !(authResponse.type === 'create_session_response') || + !(authResponse.status === 'ok' || authResponse.status === 'error') + ) { + this.triggerError( + 'Invalid auth response: ' + JSON.stringify(authResponse) + ); + return; + } + + // Authentication succeeded. + if (authResponse.status === 'ok') { + this.authenticated = true; + // Set the binaryType to the value set by the user (or back to the default 'blob'). + super.binaryType = this.binaryTypeInternal; + // Now that authentication is complete, trigger any user-added open listeners + // with the original onopen event. + this.openListeners.forEach(listener => + listener.call(this, this.onopenEvent) + ); + this.onopenInternal?.call(this, this.onopenEvent); + return; + } else { + // Authentication failed, authResponse.status === 'error'. + this.triggerError( + 'auth error connecting to websocket: ' + authResponse.message + ); + return; + } + } else { + // If authenticated, pass messages to user-added listeners. + this.messageListeners.forEach(listener => { + listener.call(this, ev); + }); + this.onmessageInternal?.call(this, ev); + } + }; + + // Set the 'close' event for cleanup. + super.onclose = (ev: CloseEvent) => { + // Trigger any user-added close listeners + this.oncloseListeners.forEach(listener => listener.call(this, ev)); + this.oncloseInternal?.call(this, ev); + this.authenticated = false; + }; + + // Set the 'error' event for cleanup. + super.onerror = (ev: Event) => { + // Trigger any user-added error listeners + this.onerrorListeners.forEach(listener => listener.call(this, ev)); + this.onerrorInternal?.call(this, ev); + this.authenticated = false; + }; + } + + // Authenticated send + override send(data: string | ArrayBufferLike | Blob | ArrayBufferView): void { + if (!this.authenticated) { + // This should be unreachable, but just in case. + this.triggerError( + 'Cannot send data before authentication is complete. Data: ' + data + ); + return; + } + super.send(data); + } + + // Override addEventListener to intercept these listeners and store them in + // our appropriate arrays. They are called in the appropriate places in the + // `onopen`, `onmessage`, `onclose`, and `onerror` methods set in the constructor. + override addEventListener( + type: K, + listener: (this: WebSocket, ev: WebSocketEventMap[K]) => any + ): void { + if (type === 'open') { + this.openListeners.push( + listener as (this: WebSocket, ev: WebSocketEventMap['open']) => any + ); + } else if (type === 'message') { + this.messageListeners.push( + listener as (this: WebSocket, ev: WebSocketEventMap['message']) => any + ); + } else if (type === 'close') { + this.oncloseListeners.push( + listener as (this: WebSocket, ev: WebSocketEventMap['close']) => any + ); + } else if (type === 'error') { + this.onerrorListeners.push( + listener as (this: WebSocket, ev: WebSocketEventMap['error']) => any + ); + } else { + // This should be unreachable, but just in case. + super.addEventListener(type, listener); + } + } + + // Override the onopen, onmessage, onclose, and onerror properties to store the user-defined + // listeners in the appropriate internal properties. These are called in the appropriate places + // in the `onopen`, `onmessage`, `onclose`, and `onerror` methods set in the constructor. + + override set onopen(listener: (this: WebSocket, ev: Event) => any | null) { + this.onopenInternal = listener; + } + + override get onopen(): ((this: WebSocket, ev: Event) => any) | null { + return this.onopenInternal; + } + + override set onmessage( + listener: ((this: WebSocket, ev: MessageEvent) => any) | null + ) { + this.onmessageInternal = listener; + } + + override get onmessage(): + | ((this: WebSocket, ev: MessageEvent) => any) + | null { + return this.onmessageInternal; + } + + override set onclose( + listener: ((this: WebSocket, ev: CloseEvent) => any) | null + ) { + this.oncloseInternal = listener; + } + + override get onclose(): ((this: WebSocket, ev: CloseEvent) => any) | null { + return this.oncloseInternal; + } + + override set onerror(listener: ((this: WebSocket, ev: Event) => any) | null) { + this.onerrorInternal = listener; + } + + override get onerror(): ((this: WebSocket, ev: Event) => any) | null { + return this.onerrorInternal; + } + + // Override the binaryType property to store the user-defined binaryType in the appropriate internal property. + // This is because we need to set the binaryType to 'arraybuffer' for the authentication process (see constructor), + // and only then can we set it to the user-defined value. + override set binaryType(binaryType: BinaryType) { + if (this.authenticated) { + super.binaryType = binaryType; + return; + } + + this.binaryTypeInternal = binaryType; + } + + override get binaryType(): BinaryType { + return this.binaryTypeInternal; + } + + // Override removeEventListener to support listeners removal for 'open', 'message', and 'close' events + override removeEventListener( + type: K, + listener: (this: WebSocket, ev: WebSocketEventMap[K]) => any + ): void { + if (type === 'open') { + const index = this.openListeners.indexOf( + listener as (this: WebSocket, ev: WebSocketEventMap['open']) => any + ); + if (index !== -1) { + this.openListeners.splice(index, 1); + } + } else if (type === 'message') { + const index = this.messageListeners.indexOf( + listener as (this: WebSocket, ev: WebSocketEventMap['message']) => any + ); + if (index !== -1) { + this.messageListeners.splice(index, 1); + } + } else if (type === 'close') { + const index = this.oncloseListeners.indexOf( + listener as (this: WebSocket, ev: WebSocketEventMap['close']) => any + ); + if (index !== -1) { + this.oncloseListeners.splice(index, 1); + } + } else if (type === 'error') { + const index = this.onerrorListeners.indexOf( + listener as (this: WebSocket, ev: WebSocketEventMap['error']) => any + ); + if (index !== -1) { + this.onerrorListeners.splice(index, 1); + } + } else { + // This should be unreachable, but just in case. + super.removeEventListener( + type, + listener as EventListenerOrEventListenerObject + ); + } + } + + // Method to manually trigger an error event. + private triggerError(errorMessage: string): void { + const errorEvent = new ErrorEvent('error', { + error: new Error(errorMessage), + message: errorMessage, + }); + + // Dispatch the event to trigger all listeners attached for 'error' events. + this.dispatchEvent(errorEvent); + } +} diff --git a/web/packages/teleport/src/lib/tdp/client.ts b/web/packages/teleport/src/lib/tdp/client.ts index f2584c95eadea..64b66dde7dffd 100644 --- a/web/packages/teleport/src/lib/tdp/client.ts +++ b/web/packages/teleport/src/lib/tdp/client.ts @@ -25,6 +25,7 @@ import init, { import { WebsocketCloseCode, TermEvent } from 'teleport/lib/term/enums'; import { EventEmitterWebAuthnSender } from 'teleport/lib/EventEmitterWebAuthnSender'; +import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket'; import Codec, { MessageType, @@ -85,12 +86,12 @@ export enum LogType { } // Client is the TDP client. It is responsible for connecting to a websocket serving the tdp server, -// sending client commands, and recieving and processing server messages. Its creator is responsible for +// sending client commands, and receiving and processing server messages. Its creator is responsible for // ensuring the websocket gets closed and all of its event listeners cleaned up when it is no longer in use. // For convenience, this can be done in one fell swoop by calling Client.shutdown(). export default class Client extends EventEmitterWebAuthnSender { protected codec: Codec; - protected socket: WebSocket | undefined; + protected socket: AuthenticatedWebSocket | undefined; private socketAddr: string; private sdManager: SharedDirectoryManager; private fastPathProcessor: FastPathProcessor | undefined; @@ -114,7 +115,7 @@ export default class Client extends EventEmitterWebAuthnSender { async connect(spec?: ClientScreenSpec) { await this.initWasm(); - this.socket = new WebSocket(this.socketAddr); + this.socket = new AuthenticatedWebSocket(this.socketAddr); this.socket.binaryType = 'arraybuffer'; this.socket.onopen = () => { diff --git a/web/packages/teleport/src/lib/term/tty.ts b/web/packages/teleport/src/lib/term/tty.ts index 88a18bcc5d246..fe45eb930d65a 100644 --- a/web/packages/teleport/src/lib/term/tty.ts +++ b/web/packages/teleport/src/lib/term/tty.ts @@ -20,6 +20,7 @@ import Logger from 'shared/libs/logger'; import { EventEmitterWebAuthnSender } from 'teleport/lib/EventEmitterWebAuthnSender'; import { WebauthnAssertionResponse } from 'teleport/services/auth'; +import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket'; import { EventType, TermEvent, WebsocketCloseCode } from './enums'; import { Protobuf, MessageTypeEnum } from './protobuf'; @@ -62,7 +63,7 @@ class Tty extends EventEmitterWebAuthnSender { connect(w: number, h: number) { const connStr = this._addressResolver.getConnStr(w, h); - this.socket = new WebSocket(connStr); + this.socket = new AuthenticatedWebSocket(connStr); this.socket.binaryType = 'arraybuffer'; this.socket.onopen = this._onOpenConnection; this.socket.onmessage = this._onMessage; diff --git a/web/packages/teleport/src/types.ts b/web/packages/teleport/src/types.ts index 3aabdcec6a52a..144db28946953 100644 --- a/web/packages/teleport/src/types.ts +++ b/web/packages/teleport/src/types.ts @@ -197,3 +197,11 @@ export enum RecommendationStatus { Notify = 'NOTIFY', Done = 'DONE', } + +// WebsocketStatus is used to indicate the auth status from a +// websocket connection +export type WebsocketStatus = { + type: string; + status: string; + message?: string; +};