Skip to content

Commit

Permalink
pr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
otherview committed Sep 13, 2024
1 parent a2569e0 commit b753aea
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 62 deletions.
11 changes: 11 additions & 0 deletions test/datagen/thor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package datagen

import (
"crypto/rand"
"github.com/vechain/thor/v2/thor"
)

func RandBytes32() (b thor.Bytes32) {
rand.Read(b[:])
return
}
20 changes: 10 additions & 10 deletions thorclient/thorclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,43 +202,43 @@ func (c *Client) ChainTag() (byte, error) {
}

// SubscribeBlocks subscribes to block updates over WebSocket.
func (c *Client) SubscribeBlocks() (*common.Subscription[*blocks.JSONCollapsedBlock], error) {
func (c *Client) SubscribeBlocks(pos string) (*common.Subscription[*blocks.JSONCollapsedBlock], error) {
if c.wsConn == nil {
return nil, fmt.Errorf("not a websocket typed client")
}
return c.wsConn.SubscribeBlocks("")
return c.wsConn.SubscribeBlocks(pos)
}

// SubscribeEvents subscribes to event updates over WebSocket.
func (c *Client) SubscribeEvents() (*common.Subscription[*subscriptions.EventMessage], error) {
func (c *Client) SubscribeEvents(pos string, filter *subscriptions.EventFilter) (*common.Subscription[*subscriptions.EventMessage], error) {
if c.wsConn == nil {
return nil, fmt.Errorf("not a websocket typed client")
}
return c.wsConn.SubscribeEvents("")
return c.wsConn.SubscribeEvents(pos, filter)
}

// SubscribeTransfers subscribes to transfer updates over WebSocket.
func (c *Client) SubscribeTransfers() (*common.Subscription[*subscriptions.TransferMessage], error) {
func (c *Client) SubscribeTransfers(pos string, filter *subscriptions.TransferFilter) (*common.Subscription[*subscriptions.TransferMessage], error) {
if c.wsConn == nil {
return nil, fmt.Errorf("not a websocket typed client")
}
return c.wsConn.SubscribeTransfers("")
return c.wsConn.SubscribeTransfers(pos, filter)
}

// SubscribeBeats2 subscribes to Beat2 message updates over WebSocket.
func (c *Client) SubscribeBeats2() (*common.Subscription[*subscriptions.Beat2Message], error) {
func (c *Client) SubscribeBeats2(pos string) (*common.Subscription[*subscriptions.Beat2Message], error) {
if c.wsConn == nil {
return nil, fmt.Errorf("not a websocket typed client")
}
return c.wsConn.SubscribeBeats2("")
return c.wsConn.SubscribeBeats2(pos)
}

// SubscribeTxPool subscribes to pending transaction updates over WebSocket.
func (c *Client) SubscribeTxPool() (*common.Subscription[*subscriptions.PendingTxIDMessage], error) {
func (c *Client) SubscribeTxPool(txID *thor.Bytes32) (*common.Subscription[*subscriptions.PendingTxIDMessage], error) {
if c.wsConn == nil {
return nil, fmt.Errorf("not a websocket typed client")
}
return c.wsConn.SubscribeTxPool("")
return c.wsConn.SubscribeTxPool(txID)
}

// convertToBatchCallData converts a transaction and sender address to batch call data format.
Expand Down
69 changes: 57 additions & 12 deletions thorclient/wsclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package wsclient

import (
"fmt"
"github.com/vechain/thor/v2/thor"
"net/url"
"strings"
"time"
Expand Down Expand Up @@ -54,8 +55,30 @@ func NewClient(url string) (*Client, error) {

// SubscribeEvents subscribes to blockchain events based on the provided query.
// It returns a Subscription that streams event messages or an error if the connection fails.
func (c *Client) SubscribeEvents(query string) (*common.Subscription[*subscriptions.EventMessage], error) {
conn, err := c.connect("/subscriptions/event", query)
func (c *Client) SubscribeEvents(pos string, filter *subscriptions.EventFilter) (*common.Subscription[*subscriptions.EventMessage], error) {
queryValues := &url.Values{}
queryValues.Add("pos", pos)
if filter != nil {
if filter.Address != nil {
queryValues.Add("address", filter.Address.String())
}
if filter.Topic0 != nil {
queryValues.Add("topic0", filter.Topic0.String())
}
if filter.Topic1 != nil {
queryValues.Add("topic1", filter.Topic1.String())
}
if filter.Topic2 != nil {
queryValues.Add("topic2", filter.Topic2.String())
}
if filter.Topic3 != nil {
queryValues.Add("topic3", filter.Topic3.String())
}
if filter.Topic4 != nil {
queryValues.Add("topic4", filter.Topic4.String())
}
}
conn, err := c.connect("/subscriptions/event", queryValues)
if err != nil {
return nil, fmt.Errorf("unable to connect - %w", err)
}
Expand All @@ -65,8 +88,10 @@ func (c *Client) SubscribeEvents(query string) (*common.Subscription[*subscripti

// SubscribeBlocks subscribes to block updates based on the provided query.
// It returns a Subscription that streams block messages or an error if the connection fails.
func (c *Client) SubscribeBlocks(query string) (*common.Subscription[*blocks.JSONCollapsedBlock], error) {
conn, err := c.connect("/subscriptions/block", query)
func (c *Client) SubscribeBlocks(pos string) (*common.Subscription[*blocks.JSONCollapsedBlock], error) {
queryValues := &url.Values{}
queryValues.Add("pos", pos)
conn, err := c.connect("/subscriptions/block", queryValues)
if err != nil {
return nil, fmt.Errorf("unable to connect - %w", err)
}
Expand All @@ -76,8 +101,21 @@ func (c *Client) SubscribeBlocks(query string) (*common.Subscription[*blocks.JSO

// SubscribeTransfers subscribes to transfer events based on the provided query.
// It returns a Subscription that streams transfer messages or an error if the connection fails.
func (c *Client) SubscribeTransfers(query string) (*common.Subscription[*subscriptions.TransferMessage], error) {
conn, err := c.connect("/subscriptions/transfer", query)
func (c *Client) SubscribeTransfers(pos string, filter *subscriptions.TransferFilter) (*common.Subscription[*subscriptions.TransferMessage], error) {
queryValues := &url.Values{}
queryValues.Add("pos", pos)
if filter != nil {
if filter.TxOrigin != nil {
queryValues.Add("txOrigin", filter.TxOrigin.String())
}
if filter.Sender != nil {
queryValues.Add("sender", filter.Sender.String())
}
if filter.Recipient != nil {
queryValues.Add("recipient", filter.Recipient.String())
}
}
conn, err := c.connect("/subscriptions/transfer", queryValues)
if err != nil {
return nil, fmt.Errorf("unable to connect - %w", err)
}
Expand All @@ -87,8 +125,13 @@ func (c *Client) SubscribeTransfers(query string) (*common.Subscription[*subscri

// SubscribeTxPool subscribes to pending transaction pool updates based on the provided query.
// It returns a Subscription that streams pending transaction messages or an error if the connection fails.
func (c *Client) SubscribeTxPool(query string) (*common.Subscription[*subscriptions.PendingTxIDMessage], error) {
conn, err := c.connect("/subscriptions/txpool", query)
func (c *Client) SubscribeTxPool(txID *thor.Bytes32) (*common.Subscription[*subscriptions.PendingTxIDMessage], error) {
queryValues := &url.Values{}
if txID != nil {
queryValues.Add("id", txID.String())
}

conn, err := c.connect("/subscriptions/txpool", queryValues)
if err != nil {
return nil, fmt.Errorf("unable to connect - %w", err)
}
Expand All @@ -98,8 +141,10 @@ func (c *Client) SubscribeTxPool(query string) (*common.Subscription[*subscripti

// SubscribeBeats2 subscribes to Beat2 messages based on the provided query.
// It returns a Subscription that streams Beat2 messages or an error if the connection fails.
func (c *Client) SubscribeBeats2(query string) (*common.Subscription[*subscriptions.Beat2Message], error) {
conn, err := c.connect("/subscriptions/beat2", query)
func (c *Client) SubscribeBeats2(pos string) (*common.Subscription[*subscriptions.Beat2Message], error) {
queryValues := &url.Values{}
queryValues.Add("pos", pos)
conn, err := c.connect("/subscriptions/beat2", queryValues)
if err != nil {
return nil, fmt.Errorf("unable to connect - %w", err)
}
Expand Down Expand Up @@ -148,12 +193,12 @@ func subscribe[T any](conn *websocket.Conn) *common.Subscription[*T] {

// connect establishes a WebSocket connection to the specified endpoint and query.
// It returns the connection or an error if the connection fails.
func (c *Client) connect(endpoint, rawQuery string) (*websocket.Conn, error) {
func (c *Client) connect(endpoint string, queryValues *url.Values) (*websocket.Conn, error) {
u := url.URL{
Scheme: c.scheme,
Host: c.host,
Path: endpoint,
RawQuery: rawQuery,
RawQuery: queryValues.Encode(),
}

conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
Expand Down
Loading

0 comments on commit b753aea

Please sign in to comment.