diff --git a/rpc/events.go b/rpc/events.go
index 7a322fde1e..943315c3b5 100644
--- a/rpc/events.go
+++ b/rpc/events.go
@@ -44,6 +44,10 @@ type EventsChunk struct {
 	ContinuationToken string `json:"continuation_token,omitempty"`
 }
 
+type SubscriptionID struct {
+	ID uint64 `json:"subscription_id"`
+}
+
 /****************************************************
 		Events Handlers
 *****************************************************/
diff --git a/rpc/handlers.go b/rpc/handlers.go
index 4d4d35b508..a69f3b5edc 100644
--- a/rpc/handlers.go
+++ b/rpc/handlers.go
@@ -65,12 +65,15 @@ var (
 	ErrUnsupportedTxVersion            = &jsonrpc.Error{Code: 61, Message: "the transaction version is not supported"}
 	ErrUnsupportedContractClassVersion = &jsonrpc.Error{Code: 62, Message: "the contract class version is not supported"}
 	ErrUnexpectedError                 = &jsonrpc.Error{Code: 63, Message: "An unexpected error occurred"}
+	ErrTooManyBlocksBack               = &jsonrpc.Error{Code: 68, Message: "Cannot go back more than 1024 blocks"}
+	ErrCallOnPending                   = &jsonrpc.Error{Code: 69, Message: "This method does not support being called on the pending block"}
 
 	// These errors can be only be returned by Juno-specific methods.
 	ErrSubscriptionNotFound = &jsonrpc.Error{Code: 100, Message: "Subscription not found"}
 )
 
 const (
+	maxBlocksBack      = 1024
 	maxEventChunkSize  = 10240
 	maxEventFilterKeys = 1024
 	traceCacheSize     = 128
@@ -334,6 +337,11 @@ func (h *Handler) Methods() ([]jsonrpc.Method, string) { //nolint: funlen
 			Name:    "starknet_specVersion",
 			Handler: h.SpecVersion,
 		},
+		{
+			Name:    "starknet_subscribeEvents",
+			Params:  []jsonrpc.Parameter{{Name: "from_address"}, {Name: "keys"}, {Name: "block", Optional: true}},
+			Handler: h.SubscribeEvents,
+		},
 		{
 			Name:    "juno_subscribeNewHeads",
 			Handler: h.SubscribeNewHeads,
diff --git a/rpc/subscriptions.go b/rpc/subscriptions.go
new file mode 100644
index 0000000000..40586400db
--- /dev/null
+++ b/rpc/subscriptions.go
@@ -0,0 +1,180 @@
+package rpc
+
+import (
+	"context"
+	"encoding/json"
+	"sync"
+
+	"github.com/NethermindEth/juno/blockchain"
+	"github.com/NethermindEth/juno/core"
+	"github.com/NethermindEth/juno/core/felt"
+	"github.com/NethermindEth/juno/jsonrpc"
+)
+
+const subscribeEventsChunkSize = 1024
+
+func (h *Handler) SubscribeEvents(ctx context.Context, fromAddr *felt.Felt, keys [][]felt.Felt,
+	blockID *BlockID,
+) (*SubscriptionID, *jsonrpc.Error) {
+	w, ok := jsonrpc.ConnFromContext(ctx)
+	if !ok {
+		return nil, jsonrpc.Err(jsonrpc.MethodNotFound, nil)
+	}
+
+	lenKeys := len(keys)
+	for _, k := range keys {
+		lenKeys += len(k)
+	}
+	if lenKeys > maxEventFilterKeys {
+		return nil, ErrTooManyKeysInFilter
+	}
+
+	var requestedHeader *core.Header
+	headHeader, err := h.bcReader.HeadsHeader()
+	if err != nil {
+		return nil, ErrInternal.CloneWithData(err.Error())
+	}
+
+	if blockID == nil {
+		requestedHeader = headHeader
+	} else {
+		if blockID.Pending {
+			return nil, ErrCallOnPending
+		}
+
+		var rpcErr *jsonrpc.Error
+		requestedHeader, rpcErr = h.blockHeaderByID(blockID)
+		if rpcErr != nil {
+			return nil, rpcErr
+		}
+
+		if headHeader.Number >= maxBlocksBack && requestedHeader.Number <= headHeader.Number-maxBlocksBack {
+			return nil, ErrTooManyBlocksBack
+		}
+	}
+
+	id := h.idgen()
+	subscriptionCtx, subscriptionCtxCancel := context.WithCancel(ctx)
+	sub := &subscription{
+		cancel: subscriptionCtxCancel,
+		conn:   w,
+	}
+	h.mu.Lock()
+	h.subscriptions[id] = sub
+	h.mu.Unlock()
+
+	headerSub := h.newHeads.Subscribe()
+	sub.wg.Go(func() {
+		defer func() {
+			h.unsubscribe(sub, id)
+			headerSub.Unsubscribe()
+		}()
+
+		// The specification doesn't enforce ordering of events, therefore events from new blocks can be sent before
+		// events from old blocks.
+		// Todo: see if sub's wg can be used?
+		wg := sync.WaitGroup{}
+		wg.Add(1)
+
+		go func() {
+			defer wg.Done()
+
+			for {
+				select {
+				case <-subscriptionCtx.Done():
+					return
+				case header := <-headerSub.Recv():
+
+					h.processEvents(subscriptionCtx, w, id, header.Number, header.Number, fromAddr, keys)
+				}
+			}
+		}()
+
+		h.processEvents(subscriptionCtx, w, id, requestedHeader.Number, headHeader.Number, fromAddr, keys)
+
+		wg.Wait()
+	})
+
+	return &SubscriptionID{ID: id}, nil
+}
+
+func (h *Handler) processEvents(ctx context.Context, w jsonrpc.Conn, id, from, to uint64, fromAddr *felt.Felt, keys [][]felt.Felt) {
+	filter, err := h.bcReader.EventFilter(fromAddr, keys)
+	if err != nil {
+		h.log.Warnw("Error creating event filter", "err", err)
+		return
+	}
+
+	defer func() {
+		h.callAndLogErr(filter.Close, "Error closing event filter in events subscription")
+	}()
+
+	if err = setEventFilterRange(filter, &BlockID{Number: from}, &BlockID{Number: to}, to); err != nil {
+		h.log.Warnw("Error setting event filter range", "err", err)
+		return
+	}
+
+	filteredEvents, cToken, err := filter.Events(nil, subscribeEventsChunkSize)
+	if err != nil {
+		h.log.Warnw("Error filtering events", "err", err)
+		return
+	}
+
+	err = sendEvents(ctx, w, filteredEvents, id)
+	if err != nil {
+		h.log.Warnw("Error sending events", "err", err)
+		return
+	}
+
+	for cToken != nil {
+		filteredEvents, cToken, err = filter.Events(cToken, subscribeEventsChunkSize)
+		if err != nil {
+			h.log.Warnw("Error filtering events", "err", err)
+			return
+		}
+
+		err = sendEvents(ctx, w, filteredEvents, id)
+		if err != nil {
+			h.log.Warnw("Error sending events", "err", err)
+			return
+		}
+	}
+}
+
+func sendEvents(ctx context.Context, w jsonrpc.Conn, events []*blockchain.FilteredEvent, id uint64) error {
+	for _, event := range events {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		default:
+			emittedEvent := &EmittedEvent{
+				BlockNumber:     &event.BlockNumber, // This will always be filled since subscribeEvents cannot be called on the pending block
+				BlockHash:       event.BlockHash,
+				TransactionHash: event.TransactionHash,
+				Event: &Event{
+					From: event.From,
+					Keys: event.Keys,
+					Data: event.Data,
+				},
+			}
+
+			resp, err := json.Marshal(jsonrpc.Request{
+				Version: "2.0",
+				Method:  "starknet_subscriptionEvents",
+				Params: map[string]any{
+					"subscription_id": id,
+					"result":          emittedEvent,
+				},
+			})
+			if err != nil {
+				return err
+			}
+
+			_, err = w.Write(resp)
+			if err != nil {
+				return err
+			}
+		}
+	}
+	return nil
+}
diff --git a/rpc/subscriptions_test.go b/rpc/subscriptions_test.go
new file mode 100644
index 0000000000..9271b01ace
--- /dev/null
+++ b/rpc/subscriptions_test.go
@@ -0,0 +1,338 @@
+package rpc
+
+import (
+	"context"
+	"encoding/json"
+	"io"
+	"net"
+	"testing"
+	"time"
+
+	"github.com/NethermindEth/juno/blockchain"
+	"github.com/NethermindEth/juno/clients/feeder"
+	"github.com/NethermindEth/juno/core"
+	"github.com/NethermindEth/juno/core/felt"
+	"github.com/NethermindEth/juno/feed"
+	"github.com/NethermindEth/juno/jsonrpc"
+	"github.com/NethermindEth/juno/mocks"
+	adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder"
+	"github.com/NethermindEth/juno/utils"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"go.uber.org/mock/gomock"
+)
+
+// Due to the difference in how some test files in rpc use "package rpc" vs "package rpc_test", it was easiest to
+// copy the fakeConn here.
+// Todo: move all the subscription related tests here
+type fakeConn struct {
+	w io.Writer
+}
+
+func (fc *fakeConn) Write(p []byte) (int, error) {
+	return fc.w.Write(p)
+}
+
+func (fc *fakeConn) Equal(other jsonrpc.Conn) bool {
+	fc2, ok := other.(*fakeConn)
+	if !ok {
+		return false
+	}
+	return fc.w == fc2.w
+}
+
+func TestSubscribeEvents(t *testing.T) {
+	log := utils.NewNopZapLogger()
+
+	t.Run("Return error if too many keys in filter", func(t *testing.T) {
+		mockCtrl := gomock.NewController(t)
+		t.Cleanup(mockCtrl.Finish)
+
+		mockChain := mocks.NewMockReader(mockCtrl)
+		mockSyncer := mocks.NewMockSyncReader(mockCtrl)
+		handler := New(mockChain, mockSyncer, nil, "", log)
+
+		keys := make([][]felt.Felt, 1024+1)
+		fromAddr := new(felt.Felt).SetBytes([]byte("from_address"))
+
+		serverConn, clientConn := net.Pipe()
+		t.Cleanup(func() {
+			require.NoError(t, serverConn.Close())
+			require.NoError(t, clientConn.Close())
+		})
+
+		subCtx := context.WithValue(context.Background(), jsonrpc.ConnKey{}, &fakeConn{w: serverConn})
+
+		id, rpcErr := handler.SubscribeEvents(subCtx, fromAddr, keys, nil)
+		assert.Zero(t, id)
+		assert.Equal(t, ErrTooManyKeysInFilter, rpcErr)
+	})
+
+	t.Run("Return error if called on pending block", func(t *testing.T) {
+		mockCtrl := gomock.NewController(t)
+		t.Cleanup(mockCtrl.Finish)
+
+		mockChain := mocks.NewMockReader(mockCtrl)
+		mockSyncer := mocks.NewMockSyncReader(mockCtrl)
+		handler := New(mockChain, mockSyncer, nil, "", log)
+
+		keys := make([][]felt.Felt, 1)
+		fromAddr := new(felt.Felt).SetBytes([]byte("from_address"))
+		blockID := &BlockID{Pending: true}
+
+		serverConn, clientConn := net.Pipe()
+		t.Cleanup(func() {
+			require.NoError(t, serverConn.Close())
+			require.NoError(t, clientConn.Close())
+		})
+
+		subCtx := context.WithValue(context.Background(), jsonrpc.ConnKey{}, &fakeConn{w: serverConn})
+
+		mockChain.EXPECT().HeadsHeader().Return(&core.Header{Number: 1}, nil)
+
+		id, rpcErr := handler.SubscribeEvents(subCtx, fromAddr, keys, blockID)
+		assert.Zero(t, id)
+		assert.Equal(t, ErrCallOnPending, rpcErr)
+	})
+
+	t.Run("Return error if block is too far back", func(t *testing.T) {
+		mockCtrl := gomock.NewController(t)
+		t.Cleanup(mockCtrl.Finish)
+
+		mockChain := mocks.NewMockReader(mockCtrl)
+		mockSyncer := mocks.NewMockSyncReader(mockCtrl)
+		handler := New(mockChain, mockSyncer, nil, "", log)
+
+		keys := make([][]felt.Felt, 1)
+		fromAddr := new(felt.Felt).SetBytes([]byte("from_address"))
+		blockID := &BlockID{Number: 0}
+
+		serverConn, clientConn := net.Pipe()
+		t.Cleanup(func() {
+			require.NoError(t, serverConn.Close())
+			require.NoError(t, clientConn.Close())
+		})
+
+		subCtx := context.WithValue(context.Background(), jsonrpc.ConnKey{}, &fakeConn{w: serverConn})
+
+		// Note: the end of the window doesn't need to be tested because if the requested block number is greater
+		// than the head, a block not found error will be returned. This behaviour has been tested in various other
+		// tests, and we don't need to test it here again.
+		t.Run("head is 1024", func(t *testing.T) {
+			mockChain.EXPECT().HeadsHeader().Return(&core.Header{Number: 1024}, nil)
+			mockChain.EXPECT().BlockHeaderByNumber(blockID.Number).Return(&core.Header{Number: 0}, nil)
+
+			id, rpcErr := handler.SubscribeEvents(subCtx, fromAddr, keys, blockID)
+			assert.Zero(t, id)
+			assert.Equal(t, ErrTooManyBlocksBack, rpcErr)
+		})
+
+		t.Run("head is more than 1024", func(t *testing.T) {
+			mockChain.EXPECT().HeadsHeader().Return(&core.Header{Number: 2024}, nil)
+			mockChain.EXPECT().BlockHeaderByNumber(blockID.Number).Return(&core.Header{Number: 0}, nil)
+
+			id, rpcErr := handler.SubscribeEvents(subCtx, fromAddr, keys, blockID)
+			assert.Zero(t, id)
+			assert.Equal(t, ErrTooManyBlocksBack, rpcErr)
+		})
+	})
+
+	n := utils.Ptr(utils.Sepolia)
+	client := feeder.NewTestClient(t, n)
+	gw := adaptfeeder.New(client)
+
+	b1, err := gw.BlockByNumber(context.Background(), 56377)
+	require.NoError(t, err)
+
+	fromAddr := new(felt.Felt).SetBytes([]byte("some address"))
+	keys := [][]felt.Felt{{*new(felt.Felt).SetBytes([]byte("key1"))}}
+
+	filteredEvents := []*blockchain.FilteredEvent{
+		{
+			Event:           b1.Receipts[0].Events[0],
+			BlockNumber:     b1.Number,
+			BlockHash:       new(felt.Felt).SetBytes([]byte("b1")),
+			TransactionHash: b1.Transactions[0].Hash(),
+		},
+		{
+			Event:           b1.Receipts[1].Events[0],
+			BlockNumber:     b1.Number + 1,
+			BlockHash:       new(felt.Felt).SetBytes([]byte("b2")),
+			TransactionHash: b1.Transactions[1].Hash(),
+		},
+	}
+
+	var emittedEvents []*EmittedEvent
+	for _, e := range filteredEvents {
+		emittedEvents = append(emittedEvents, &EmittedEvent{
+			Event: &Event{
+				From: e.From,
+				Keys: e.Keys,
+				Data: e.Data,
+			},
+			BlockHash:       e.BlockHash,
+			BlockNumber:     &e.BlockNumber,
+			TransactionHash: e.TransactionHash,
+		})
+	}
+
+	t.Run("Events from old blocks", func(t *testing.T) {
+		mockCtrl := gomock.NewController(t)
+		t.Cleanup(mockCtrl.Finish)
+
+		mockChain := mocks.NewMockReader(mockCtrl)
+		mockSyncer := mocks.NewMockSyncReader(mockCtrl)
+		mockEventFilterer := mocks.NewMockEventFilterer(mockCtrl)
+		handler := New(mockChain, mockSyncer, nil, "", log)
+
+		mockChain.EXPECT().HeadsHeader().Return(&core.Header{Number: b1.Number}, nil)
+		mockChain.EXPECT().BlockHeaderByNumber(b1.Number).Return(b1.Header, nil)
+		mockChain.EXPECT().EventFilter(fromAddr, keys).Return(mockEventFilterer, nil)
+
+		mockEventFilterer.EXPECT().SetRangeEndBlockByNumber(gomock.Any(), gomock.Any()).Return(nil).MaxTimes(2)
+		mockEventFilterer.EXPECT().Events(gomock.Any(), gomock.Any()).Return(filteredEvents, nil, nil)
+		mockEventFilterer.EXPECT().Close().AnyTimes()
+
+		serverConn, clientConn := net.Pipe()
+		t.Cleanup(func() {
+			require.NoError(t, serverConn.Close())
+			require.NoError(t, clientConn.Close())
+		})
+
+		ctx, cancel := context.WithCancel(context.Background())
+		subCtx := context.WithValue(ctx, jsonrpc.ConnKey{}, &fakeConn{w: serverConn})
+		id, rpcErr := handler.SubscribeEvents(subCtx, fromAddr, keys, &BlockID{Number: b1.Number})
+		require.Nil(t, rpcErr)
+
+		var marshalledResponses [][]byte
+		for _, e := range emittedEvents {
+			resp, err := marshalSubscriptionResponse(e, id.ID)
+			require.NoError(t, err)
+			marshalledResponses = append(marshalledResponses, resp)
+		}
+
+		for _, m := range marshalledResponses {
+			got := make([]byte, len(m))
+			_, err := clientConn.Read(got)
+			require.NoError(t, err)
+			assert.Equal(t, string(m), string(got))
+		}
+		cancel()
+	})
+
+	t.Run("Events when continuation token is not nil", func(t *testing.T) {
+		mockCtrl := gomock.NewController(t)
+		t.Cleanup(mockCtrl.Finish)
+
+		mockChain := mocks.NewMockReader(mockCtrl)
+		mockSyncer := mocks.NewMockSyncReader(mockCtrl)
+		mockEventFilterer := mocks.NewMockEventFilterer(mockCtrl)
+		handler := New(mockChain, mockSyncer, nil, "", log)
+
+		mockChain.EXPECT().HeadsHeader().Return(&core.Header{Number: b1.Number}, nil)
+		mockChain.EXPECT().BlockHeaderByNumber(b1.Number).Return(b1.Header, nil)
+		mockChain.EXPECT().EventFilter(fromAddr, keys).Return(mockEventFilterer, nil)
+
+		cToken := new(blockchain.ContinuationToken)
+		mockEventFilterer.EXPECT().SetRangeEndBlockByNumber(gomock.Any(), gomock.Any()).Return(nil).MaxTimes(2)
+		mockEventFilterer.EXPECT().Events(gomock.Any(), gomock.Any()).Return(
+			[]*blockchain.FilteredEvent{filteredEvents[0]}, cToken, nil)
+		mockEventFilterer.EXPECT().Events(gomock.Any(), gomock.Any()).Return(
+			[]*blockchain.FilteredEvent{filteredEvents[1]}, nil, nil)
+		mockEventFilterer.EXPECT().Close().AnyTimes()
+
+		serverConn, clientConn := net.Pipe()
+		t.Cleanup(func() {
+			require.NoError(t, serverConn.Close())
+			require.NoError(t, clientConn.Close())
+		})
+
+		ctx, cancel := context.WithCancel(context.Background())
+		subCtx := context.WithValue(ctx, jsonrpc.ConnKey{}, &fakeConn{w: serverConn})
+		id, rpcErr := handler.SubscribeEvents(subCtx, fromAddr, keys, &BlockID{Number: b1.Number})
+		require.Nil(t, rpcErr)
+
+		var marshalledResponses [][]byte
+		for _, e := range emittedEvents {
+			resp, err := marshalSubscriptionResponse(e, id.ID)
+			require.NoError(t, err)
+			marshalledResponses = append(marshalledResponses, resp)
+		}
+
+		for _, m := range marshalledResponses {
+			got := make([]byte, len(m))
+			_, err := clientConn.Read(got)
+			require.NoError(t, err)
+			assert.Equal(t, string(m), string(got))
+		}
+		cancel()
+	})
+
+	t.Run("Events from new blocks", func(t *testing.T) {
+		mockCtrl := gomock.NewController(t)
+		t.Cleanup(mockCtrl.Finish)
+
+		mockChain := mocks.NewMockReader(mockCtrl)
+		mockSyncer := mocks.NewMockSyncReader(mockCtrl)
+		mockEventFilterer := mocks.NewMockEventFilterer(mockCtrl)
+
+		handler := New(mockChain, mockSyncer, nil, "", log)
+		headerFeed := feed.New[*core.Header]()
+		handler.newHeads = headerFeed
+
+		mockChain.EXPECT().HeadsHeader().Return(&core.Header{Number: b1.Number}, nil)
+		mockChain.EXPECT().EventFilter(fromAddr, keys).Return(mockEventFilterer, nil)
+
+		mockEventFilterer.EXPECT().SetRangeEndBlockByNumber(gomock.Any(), gomock.Any()).Return(nil).MaxTimes(2)
+		mockEventFilterer.EXPECT().Events(gomock.Any(), gomock.Any()).Return([]*blockchain.FilteredEvent{filteredEvents[0]}, nil, nil)
+		mockEventFilterer.EXPECT().Close().AnyTimes()
+
+		serverConn, clientConn := net.Pipe()
+		t.Cleanup(func() {
+			require.NoError(t, serverConn.Close())
+			require.NoError(t, clientConn.Close())
+		})
+
+		ctx, cancel := context.WithCancel(context.Background())
+		subCtx := context.WithValue(ctx, jsonrpc.ConnKey{}, &fakeConn{w: serverConn})
+		id, rpcErr := handler.SubscribeEvents(subCtx, fromAddr, keys, nil)
+		require.Nil(t, rpcErr)
+
+		resp, err := marshalSubscriptionResponse(emittedEvents[0], id.ID)
+		require.NoError(t, err)
+
+		got := make([]byte, len(resp))
+		_, err = clientConn.Read(got)
+		require.NoError(t, err)
+		assert.Equal(t, string(resp), string(got))
+
+		mockChain.EXPECT().EventFilter(fromAddr, keys).Return(mockEventFilterer, nil)
+
+		mockEventFilterer.EXPECT().SetRangeEndBlockByNumber(gomock.Any(), gomock.Any()).Return(nil).MaxTimes(2)
+		mockEventFilterer.EXPECT().Events(gomock.Any(), gomock.Any()).Return([]*blockchain.FilteredEvent{filteredEvents[1]}, nil, nil)
+
+		headerFeed.Send(&core.Header{Number: b1.Number + 1})
+
+		resp, err = marshalSubscriptionResponse(emittedEvents[1], id.ID)
+		require.NoError(t, err)
+
+		got = make([]byte, len(resp))
+		_, err = clientConn.Read(got)
+		require.NoError(t, err)
+		assert.Equal(t, string(resp), string(got))
+
+		cancel()
+		time.Sleep(100 * time.Millisecond)
+	})
+}
+
+func marshalSubscriptionResponse(e *EmittedEvent, id uint64) ([]byte, error) {
+	return json.Marshal(jsonrpc.Request{
+		Version: "2.0",
+		Method:  "starknet_subscriptionEvents",
+		Params: map[string]any{
+			"subscription_id": id,
+			"result":          e,
+		},
+	})
+}
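
Reviewer note: the standalone sketch below (not part of the diff) shows what a single starknet_subscriptionEvents notification is expected to look like on the wire, based on the jsonrpc.Request built in sendEvents. The envelope keys ("jsonrpc", "method", "params") and the event field names (from_address, keys, data, block_hash, block_number, transaction_hash) are assumptions taken from Juno's existing jsonrpc.Request and EmittedEvent/Event JSON tags, not something this diff defines; the hex values are placeholders.

// Illustrative only: prints the assumed shape of one subscription notification.
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	notification := map[string]any{
		"jsonrpc": "2.0",
		"method":  "starknet_subscriptionEvents",
		"params": map[string]any{
			"subscription_id": 42, // value previously returned by starknet_subscribeEvents
			"result": map[string]any{
				"from_address":     "0x1234", // placeholder felt values, not real Sepolia data
				"keys":             []string{"0xabcd"},
				"data":             []string{"0x1", "0x2"},
				"block_hash":       "0x5678",
				"block_number":     56377,
				"transaction_hash": "0x9abc",
			},
		},
	}

	out, err := json.MarshalIndent(notification, "", "  ")
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out))
}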
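
Also for reviewers, a small standalone sketch of the maxBlocksBack window check from SubscribeEvents, to make the boundary explicit: with a head of 2024, a requested block of 1000 (exactly 1024 back) is rejected with ErrTooManyBlocksBack, 1001 is accepted, and heads below 1024 never reject. This mirrors the "head is 1024" test case above.

// Standalone sketch of the window check; same predicate as in SubscribeEvents,
// copied here only to illustrate the boundary behaviour.
package main

import "fmt"

const maxBlocksBack = 1024

func tooFarBack(head, requested uint64) bool {
	return head >= maxBlocksBack && requested <= head-maxBlocksBack
}

func main() {
	fmt.Println(tooFarBack(2024, 1000)) // true: exactly 1024 blocks back is rejected
	fmt.Println(tooFarBack(2024, 1001)) // false: 1023 blocks back is accepted
	fmt.Println(tooFarBack(1024, 0))    // true: matches the "head is 1024" test case
	fmt.Println(tooFarBack(100, 0))     // false: heads below 1024 never reject
}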