diff --git a/jsonrpc/websocket.go b/jsonrpc/websocket.go index b3e436703a..2c1e207a54 100644 --- a/jsonrpc/websocket.go +++ b/jsonrpc/websocket.go @@ -18,14 +18,17 @@ type Websocket struct { log utils.SimpleLogger connParams *WebsocketConnParams listener NewRequestListener + + shutdown <-chan struct{} } -func NewWebsocket(rpc *Server, log utils.SimpleLogger) *Websocket { +func NewWebsocket(rpc *Server, shutdown <-chan struct{}, log utils.SimpleLogger) *Websocket { ws := &Websocket{ rpc: rpc, log: log, connParams: DefaultWebsocketConnParams(), listener: &SelectiveListener{}, + shutdown: shutdown, } return ws @@ -54,7 +57,19 @@ func (ws *Websocket) ServeHTTP(w http.ResponseWriter, r *http.Request) { // TODO include connection information, such as the remote address, in the logs. - wsc := newWebsocketConn(r.Context(), conn, ws.connParams) + ctx, cancel := context.WithCancel(r.Context()) + defer cancel() + go func() { + select { + case <-ws.shutdown: + cancel() + case <-ctx.Done(): + // in case websocket connection is closed and server is not in shutdown mode + // we need to release this goroutine from waiting + } + }() + + wsc := newWebsocketConn(ctx, conn, ws.connParams) for { _, wsc.r, err = wsc.conn.Reader(wsc.ctx) diff --git a/jsonrpc/websocket_test.go b/jsonrpc/websocket_test.go index 9baf704ae9..4f60377c02 100644 --- a/jsonrpc/websocket_test.go +++ b/jsonrpc/websocket_test.go @@ -20,7 +20,7 @@ func testConnection(t *testing.T, ctx context.Context, method jsonrpc.Method, li require.NoError(t, rpc.RegisterMethods(method)) // Server - srv := httptest.NewServer(jsonrpc.NewWebsocket(rpc, utils.NewNopZapLogger())) + srv := httptest.NewServer(jsonrpc.NewWebsocket(rpc, nil, utils.NewNopZapLogger())) // Client conn, resp, err := websocket.Dial(ctx, srv.URL, nil) //nolint:bodyclose // websocket package closes resp.Body for us. diff --git a/node/http.go b/node/http.go index 4226564b0a..89a6db60fc 100644 --- a/node/http.go +++ b/node/http.go @@ -52,6 +52,10 @@ func (h *httpService) Run(ctx context.Context) error { } } +func (h *httpService) registerOnShutdown(f func()) { + h.srv.RegisterOnShutdown(f) +} + func makeHTTPService(host string, port uint16, handler http.Handler) *httpService { portStr := strconv.FormatUint(uint64(port), 10) return &httpService{ @@ -108,9 +112,11 @@ func makeRPCOverWebsocket(host string, port uint16, servers map[string]*jsonrpc. listener = makeWSMetrics() } + shutdown := make(chan struct{}) + mux := http.NewServeMux() for path, server := range servers { - wsHandler := jsonrpc.NewWebsocket(server, log) + wsHandler := jsonrpc.NewWebsocket(server, shutdown, log) if listener != nil { wsHandler = wsHandler.WithListener(listener) } @@ -124,7 +130,12 @@ func makeRPCOverWebsocket(host string, port uint16, servers map[string]*jsonrpc. if corsEnabled { handler = cors.Default().Handler(handler) } - return makeHTTPService(host, port, handler) + + httpServ := makeHTTPService(host, port, handler) + httpServ.registerOnShutdown(func() { + close(shutdown) + }) + return httpServ } func makeMetrics(host string, port uint16) *httpService { diff --git a/rpc/events_test.go b/rpc/events_test.go index c2f1417791..6655b6166e 100644 --- a/rpc/events_test.go +++ b/rpc/events_test.go @@ -351,7 +351,7 @@ func TestMultipleSubscribeNewHeadsAndUnsubscribe(t *testing.T) { Params: []jsonrpc.Parameter{{Name: "id"}}, Handler: handler.Unsubscribe, })) - ws := jsonrpc.NewWebsocket(server, log) + ws := jsonrpc.NewWebsocket(server, nil, log) httpSrv := httptest.NewServer(ws) conn1, _, err := websocket.Dial(ctx, httpSrv.URL, nil) require.NoError(t, err) diff --git a/rpc/subscriptions.go b/rpc/subscriptions.go index b049c6ce0d..5edf1bcfbb 100644 --- a/rpc/subscriptions.go +++ b/rpc/subscriptions.go @@ -90,7 +90,6 @@ func (h *Handler) SubscribeEvents(ctx context.Context, fromAddr *felt.Felt, keys case <-subscriptionCtx.Done(): return case header := <-headerSub.Recv(): - h.processEvents(subscriptionCtx, w, id, header.Number, header.Number, fromAddr, keys) } }