From 3981eb766c88009b1a478c403e21f8ad411ed8c9 Mon Sep 17 00:00:00 2001 From: Alex McGrath Date: Tue, 30 Jan 2024 18:33:45 +0000 Subject: [PATCH] read errors from websocket --- lib/web/apiserver.go | 18 ++++++++++++++++-- lib/web/apiserver_test.go | 7 +++++-- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index f4db3959d3916..43e31226c63bd 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -3747,6 +3747,12 @@ func (h *Handler) WithClusterAuth(fn ClusterHandler) httprouter.Handle { }) } +// WSError is used to write errors that previously occurred before a +// websocket got upgraded +type WSError struct { + Error string `json:"error"` +} + // WithClusterAuthWS wraps a ClusterWebsocketHandler to ensure that a request is authenticated // to this proxy via websocket if websocketAuth is true, or via query parameter if false (the same as WithAuth), as // well as to grab the remoteSite (which can represent this local cluster or a remote trusted cluster) @@ -3762,7 +3768,11 @@ func (h *Handler) WithClusterAuthWS(websocketAuth bool, fn ClusterWebsocketHandl } defer ws.Close() - return fn(w, r, p, sctx, site, ws) + _, err = fn(w, r, p, sctx, site, ws) + if err := ws.WriteJSON(WSError{Error: err.Error()}); err != nil { + h.log.WithError(err).Error("error writing json") + } + return nil, nil } sctx, site, err := h.authenticateRequestWithCluster(w, r, p) @@ -3783,7 +3793,11 @@ func (h *Handler) WithClusterAuthWS(websocketAuth bool, fn ClusterWebsocketHandl } defer ws.Close() - return fn(w, r, p, sctx, site, ws) + _, err = fn(w, r, p, sctx, site, ws) + if err := ws.WriteJSON(WSError{Error: err.Error()}); err != nil { + h.log.WithError(err).Error("error writing json") + } + return nil, nil }) } diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 89953c94ba92a..8a7834c1fd758 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -1360,8 +1360,11 @@ func TestSiteNodeConnectInvalidSessionID(t *testing.T) { proxy: s.webServer.Listener.Addr().String(), sessionID: "/../../../foo", }) - require.Error(t, err) - require.Nil(t, term) + require.NoError(t, err) + var wsError WSError + err = term.ws.ReadJSON(&wsError) + require.NoError(t, err) + require.Equal(t, "/../../../foo is not a valid UUID", wsError.Error) } func TestResolveServerHostPort(t *testing.T) {