Skip to content

Commit

Permalink
Websocket shutdown logic (#2277)
Browse files Browse the repository at this point in the history
  • Loading branch information
kirugan authored Dec 23, 2024
1 parent fce5078 commit 81262fd
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 7 deletions.
19 changes: 17 additions & 2 deletions jsonrpc/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,17 @@ type Websocket struct {
log utils.SimpleLogger
connParams *WebsocketConnParams
listener NewRequestListener

shutdown <-chan struct{}
}

func NewWebsocket(rpc *Server, log utils.SimpleLogger) *Websocket {
func NewWebsocket(rpc *Server, shutdown <-chan struct{}, log utils.SimpleLogger) *Websocket {
ws := &Websocket{
rpc: rpc,
log: log,
connParams: DefaultWebsocketConnParams(),
listener: &SelectiveListener{},
shutdown: shutdown,
}

return ws
Expand Down Expand Up @@ -54,7 +57,19 @@ func (ws *Websocket) ServeHTTP(w http.ResponseWriter, r *http.Request) {

// TODO include connection information, such as the remote address, in the logs.

wsc := newWebsocketConn(r.Context(), conn, ws.connParams)
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
go func() {
select {
case <-ws.shutdown:
cancel()
case <-ctx.Done():
// in case websocket connection is closed and server is not in shutdown mode
// we need to release this goroutine from waiting
}
}()

wsc := newWebsocketConn(ctx, conn, ws.connParams)

for {
_, wsc.r, err = wsc.conn.Reader(wsc.ctx)
Expand Down
2 changes: 1 addition & 1 deletion jsonrpc/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func testConnection(t *testing.T, ctx context.Context, method jsonrpc.Method, li
require.NoError(t, rpc.RegisterMethods(method))

// Server
srv := httptest.NewServer(jsonrpc.NewWebsocket(rpc, utils.NewNopZapLogger()))
srv := httptest.NewServer(jsonrpc.NewWebsocket(rpc, nil, utils.NewNopZapLogger()))

// Client
conn, resp, err := websocket.Dial(ctx, srv.URL, nil) //nolint:bodyclose // websocket package closes resp.Body for us.
Expand Down
15 changes: 13 additions & 2 deletions node/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ func (h *httpService) Run(ctx context.Context) error {
}
}

func (h *httpService) registerOnShutdown(f func()) {
h.srv.RegisterOnShutdown(f)
}

func makeHTTPService(host string, port uint16, handler http.Handler) *httpService {
portStr := strconv.FormatUint(uint64(port), 10)
return &httpService{
Expand Down Expand Up @@ -108,9 +112,11 @@ func makeRPCOverWebsocket(host string, port uint16, servers map[string]*jsonrpc.
listener = makeWSMetrics()
}

shutdown := make(chan struct{})

mux := http.NewServeMux()
for path, server := range servers {
wsHandler := jsonrpc.NewWebsocket(server, log)
wsHandler := jsonrpc.NewWebsocket(server, shutdown, log)
if listener != nil {
wsHandler = wsHandler.WithListener(listener)
}
Expand All @@ -124,7 +130,12 @@ func makeRPCOverWebsocket(host string, port uint16, servers map[string]*jsonrpc.
if corsEnabled {
handler = cors.Default().Handler(handler)
}
return makeHTTPService(host, port, handler)

httpServ := makeHTTPService(host, port, handler)
httpServ.registerOnShutdown(func() {
close(shutdown)
})
return httpServ
}

func makeMetrics(host string, port uint16) *httpService {
Expand Down
2 changes: 1 addition & 1 deletion rpc/events_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ func TestMultipleSubscribeNewHeadsAndUnsubscribe(t *testing.T) {
Params: []jsonrpc.Parameter{{Name: "id"}},
Handler: handler.Unsubscribe,
}))
ws := jsonrpc.NewWebsocket(server, log)
ws := jsonrpc.NewWebsocket(server, nil, log)
httpSrv := httptest.NewServer(ws)
conn1, _, err := websocket.Dial(ctx, httpSrv.URL, nil)
require.NoError(t, err)
Expand Down
1 change: 0 additions & 1 deletion rpc/subscriptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ func (h *Handler) SubscribeEvents(ctx context.Context, fromAddr *felt.Felt, keys
case <-subscriptionCtx.Done():
return
case header := <-headerSub.Recv():

h.processEvents(subscriptionCtx, w, id, header.Number, header.Number, fromAddr, keys)
}
}
Expand Down

0 comments on commit 81262fd

Please sign in to comment.