Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Websocket shutdown logic #2277

Merged
merged 10 commits into from
Dec 23, 2024
19 changes: 17 additions & 2 deletions jsonrpc/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,17 @@ type Websocket struct {
log utils.SimpleLogger
connParams *WebsocketConnParams
listener NewRequestListener

shutdown <-chan struct{}
}

func NewWebsocket(rpc *Server, log utils.SimpleLogger) *Websocket {
func NewWebsocket(rpc *Server, shutdown <-chan struct{}, log utils.SimpleLogger) *Websocket {
ws := &Websocket{
rpc: rpc,
log: log,
connParams: DefaultWebsocketConnParams(),
listener: &SelectiveListener{},
shutdown: shutdown,
}

return ws
Expand Down Expand Up @@ -54,7 +57,19 @@ func (ws *Websocket) ServeHTTP(w http.ResponseWriter, r *http.Request) {

// TODO include connection information, such as the remote address, in the logs.

wsc := newWebsocketConn(r.Context(), conn, ws.connParams)
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
go func() {
select {
case <-ws.shutdown:
cancel()
case <-ctx.Done():
// in case websocket connection is closed and server is not in shutdown mode
// we need to release this goroutine from waiting
}
kirugan marked this conversation as resolved.
Show resolved Hide resolved
}()

wsc := newWebsocketConn(ctx, conn, ws.connParams)

for {
_, wsc.r, err = wsc.conn.Reader(wsc.ctx)
Expand Down
2 changes: 1 addition & 1 deletion jsonrpc/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func testConnection(t *testing.T, ctx context.Context, method jsonrpc.Method, li
require.NoError(t, rpc.RegisterMethods(method))

// Server
srv := httptest.NewServer(jsonrpc.NewWebsocket(rpc, utils.NewNopZapLogger()))
srv := httptest.NewServer(jsonrpc.NewWebsocket(rpc, nil, utils.NewNopZapLogger()))
kirugan marked this conversation as resolved.
Show resolved Hide resolved

// Client
conn, resp, err := websocket.Dial(ctx, srv.URL, nil) //nolint:bodyclose // websocket package closes resp.Body for us.
Expand Down
15 changes: 13 additions & 2 deletions node/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@
}
}

func (h *httpService) registerOnShutdown(f func()) {
h.srv.RegisterOnShutdown(f)
}

func makeHTTPService(host string, port uint16, handler http.Handler) *httpService {
portStr := strconv.FormatUint(uint64(port), 10)
return &httpService{
Expand Down Expand Up @@ -108,9 +112,11 @@
listener = makeWSMetrics()
}

shutdown := make(chan struct{})

mux := http.NewServeMux()
for path, server := range servers {
wsHandler := jsonrpc.NewWebsocket(server, log)
wsHandler := jsonrpc.NewWebsocket(server, shutdown, log)
if listener != nil {
wsHandler = wsHandler.WithListener(listener)
}
Expand All @@ -124,7 +130,12 @@
if corsEnabled {
handler = cors.Default().Handler(handler)
}
return makeHTTPService(host, port, handler)

service := makeHTTPService(host, port, handler)

Check failure on line 134 in node/http.go

View workflow job for this annotation

GitHub Actions / lint

importShadow: shadow of imported from 'github.com/NethermindEth/juno/service' package 'service' (gocritic)
service.registerOnShutdown(func() {
close(shutdown)
})
return service
}

func makeMetrics(host string, port uint16) *httpService {
Expand Down
4 changes: 4 additions & 0 deletions rpc/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ type EventsChunk struct {
ContinuationToken string `json:"continuation_token,omitempty"`
}

type SubscriptionID struct {
ID uint64 `json:"subscription_id"`
}

/****************************************************
Events Handlers
*****************************************************/
Expand Down
2 changes: 1 addition & 1 deletion rpc/events_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ func TestMultipleSubscribeNewHeadsAndUnsubscribe(t *testing.T) {
Params: []jsonrpc.Parameter{{Name: "id"}},
Handler: handler.Unsubscribe,
}))
ws := jsonrpc.NewWebsocket(server, log)
ws := jsonrpc.NewWebsocket(server, nil, log)
kirugan marked this conversation as resolved.
Show resolved Hide resolved
httpSrv := httptest.NewServer(ws)
conn1, _, err := websocket.Dial(ctx, httpSrv.URL, nil)
require.NoError(t, err)
Expand Down
7 changes: 7 additions & 0 deletions rpc/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,14 @@ var (
ErrUnsupportedTxVersion = &jsonrpc.Error{Code: 61, Message: "the transaction version is not supported"}
ErrUnsupportedContractClassVersion = &jsonrpc.Error{Code: 62, Message: "the contract class version is not supported"}
ErrUnexpectedError = &jsonrpc.Error{Code: 63, Message: "An unexpected error occurred"}
ErrTooManyBlocksBack = &jsonrpc.Error{Code: 68, Message: "Cannot go back more than 1024 blocks"}

// These errors can be only be returned by Juno-specific methods.
ErrSubscriptionNotFound = &jsonrpc.Error{Code: 100, Message: "Subscription not found"}
)

const (
maxBlocksBack = 1024
maxEventChunkSize = 10240
maxEventFilterKeys = 1024
traceCacheSize = 128
Expand Down Expand Up @@ -311,6 +313,11 @@ func (h *Handler) Methods() ([]jsonrpc.Method, string) { //nolint: funlen
Name: "starknet_specVersion",
Handler: h.SpecVersion,
},
{
Name: "starknet_subscribeEvents",
Params: []jsonrpc.Parameter{{Name: "from_address"}, {Name: "keys"}, {Name: "block", Optional: true}},
Handler: h.SubscribeEvents,
},
{
Name: "juno_subscribeNewHeads",
Handler: h.SubscribeNewHeads,
Expand Down
179 changes: 179 additions & 0 deletions rpc/subscriptions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
package rpc

import (
"context"
"encoding/json"
"sync"

"github.com/NethermindEth/juno/blockchain"
"github.com/NethermindEth/juno/core"
"github.com/NethermindEth/juno/core/felt"
"github.com/NethermindEth/juno/jsonrpc"
)

const subscribeEventsChunkSize = 1024

func (h *Handler) SubscribeEvents(ctx context.Context, fromAddr *felt.Felt, keys [][]felt.Felt,
blockID *BlockID,
) (*SubscriptionID, *jsonrpc.Error) {
w, ok := jsonrpc.ConnFromContext(ctx)
if !ok {
return nil, jsonrpc.Err(jsonrpc.MethodNotFound, nil)
}

lenKeys := len(keys)
for _, k := range keys {
lenKeys += len(k)
}
if lenKeys > maxEventFilterKeys {
return nil, ErrTooManyKeysInFilter
}

var requestedHeader *core.Header
headHeader, err := h.bcReader.HeadsHeader()
if err != nil {
return nil, ErrInternal.CloneWithData(err.Error())
}

if blockID == nil {
requestedHeader = headHeader
} else {
var rpcErr *jsonrpc.Error
requestedHeader, rpcErr = h.blockHeaderByID(blockID)
if rpcErr != nil {
return nil, rpcErr
}

// Todo: should the pending block be included in the head count?
if headHeader.Number >= maxBlocksBack && requestedHeader.Number <= headHeader.Number-maxBlocksBack {
return nil, ErrTooManyBlocksBack
}
}

id := h.idgen()
subscriptionCtx, subscriptionCtxCancel := context.WithCancel(ctx)
sub := &subscription{
cancel: subscriptionCtxCancel,
conn: w,
}
h.mu.Lock()
h.subscriptions[id] = sub
h.mu.Unlock()

headerSub := h.newHeads.Subscribe()
sub.wg.Go(func() {
defer func() {
h.unsubscribe(sub, id)
headerSub.Unsubscribe()
}()

// The specification doesn't enforce ordering of events therefore events from new blocks can be sent before
// old blocks.
// Todo: see if sub's wg can be used?
wg := sync.WaitGroup{}
wg.Add(1)

go func() {
defer wg.Done()

for {
select {
case <-subscriptionCtx.Done():
return
case header := <-headerSub.Recv():
h.processEvents(subscriptionCtx, w, id, header.Number, header.Number, fromAddr, keys)
}
}
}()

h.processEvents(subscriptionCtx, w, id, requestedHeader.Number, headHeader.Number, fromAddr, keys)

wg.Wait()
})

return &SubscriptionID{ID: id}, nil
}

func (h *Handler) processEvents(ctx context.Context, w jsonrpc.Conn, id, from, to uint64, fromAddr *felt.Felt, keys [][]felt.Felt) {
filter, err := h.bcReader.EventFilter(fromAddr, keys)
if err != nil {
h.log.Warnw("Error creating event filter", "err", err)
return
}
defer h.callAndLogErr(filter.Close, "Error closing event filter in events subscription")

if err = setEventFilterRange(filter, &BlockID{Number: from}, &BlockID{Number: to}, to); err != nil {
h.log.Warnw("Error setting event filter range", "err", err)
return
}

var cToken *blockchain.ContinuationToken
filteredEvents, cToken, err := filter.Events(cToken, subscribeEventsChunkSize)
if err != nil {
h.log.Warnw("Error filtering events", "err", err)
return
}

err = sendEvents(ctx, w, filteredEvents, id)
if err != nil {
h.log.Warnw("Error sending events", "err", err)
return
}

for cToken != nil {
filteredEvents, cToken, err = filter.Events(cToken, subscribeEventsChunkSize)
if err != nil {
h.log.Warnw("Error filtering events", "err", err)
return
}

err = sendEvents(ctx, w, filteredEvents, id)
if err != nil {
h.log.Warnw("Error sending events", "err", err)
return
}
}
}

func sendEvents(ctx context.Context, w jsonrpc.Conn, events []*blockchain.FilteredEvent, id uint64) error {
for _, event := range events {
select {
case <-ctx.Done():
return ctx.Err()
default:
// Pending block doesn't have a number
var blockNumber *uint64
if event.BlockHash != nil {
blockNumber = &(event.BlockNumber)
}
emittedEvent := &EmittedEvent{
BlockNumber: blockNumber,
BlockHash: event.BlockHash,
TransactionHash: event.TransactionHash,
Event: &Event{
From: event.From,
Keys: event.Keys,
Data: event.Data,
},
}

resp, err := json.Marshal(jsonrpc.Request{
Version: "2.0",
Method: "starknet_subscriptionEvents",
Params: map[string]any{
"subscription_id": id,
"result": emittedEvent,
},
})
if err != nil {
return err
}

_, err = w.Write(resp)
if err != nil {
return err
}
}
}
return nil
}
Loading
Loading