Skip to content

Commit

Permalink
read errors from websocket
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex McGrath committed Jan 30, 2024
1 parent 5b39b8a commit 3981eb7
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
18 changes: 16 additions & 2 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -3747,6 +3747,12 @@ func (h *Handler) WithClusterAuth(fn ClusterHandler) httprouter.Handle {
})
}

// WSError is used to write errors that previously occurred before a
// websocket got upgraded
type WSError struct {
Error string `json:"error"`
}

// WithClusterAuthWS wraps a ClusterWebsocketHandler to ensure that a request is authenticated
// to this proxy via websocket if websocketAuth is true, or via query parameter if false (the same as WithAuth), as
// well as to grab the remoteSite (which can represent this local cluster or a remote trusted cluster)
Expand All @@ -3762,7 +3768,11 @@ func (h *Handler) WithClusterAuthWS(websocketAuth bool, fn ClusterWebsocketHandl
}

defer ws.Close()
return fn(w, r, p, sctx, site, ws)
_, err = fn(w, r, p, sctx, site, ws)
if err := ws.WriteJSON(WSError{Error: err.Error()}); err != nil {
h.log.WithError(err).Error("error writing json")
}
return nil, nil
}

sctx, site, err := h.authenticateRequestWithCluster(w, r, p)
Expand All @@ -3783,7 +3793,11 @@ func (h *Handler) WithClusterAuthWS(websocketAuth bool, fn ClusterWebsocketHandl
}

defer ws.Close()
return fn(w, r, p, sctx, site, ws)
_, err = fn(w, r, p, sctx, site, ws)
if err := ws.WriteJSON(WSError{Error: err.Error()}); err != nil {
h.log.WithError(err).Error("error writing json")
}
return nil, nil
})
}

Expand Down
7 changes: 5 additions & 2 deletions lib/web/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1360,8 +1360,11 @@ func TestSiteNodeConnectInvalidSessionID(t *testing.T) {
proxy: s.webServer.Listener.Addr().String(),
sessionID: "/../../../foo",
})
require.Error(t, err)
require.Nil(t, term)
require.NoError(t, err)
var wsError WSError
err = term.ws.ReadJSON(&wsError)
require.NoError(t, err)
require.Equal(t, "/../../../foo is not a valid UUID", wsError.Error)
}

func TestResolveServerHostPort(t *testing.T) {
Expand Down

0 comments on commit 3981eb7

Please sign in to comment.