diff --git a/node/client.go b/node/client.go index 3e8f0bc..990a8aa 100644 --- a/node/client.go +++ b/node/client.go @@ -540,3 +540,28 @@ func (c *client) GetAccounts(ctx context.Context) ([]eth.Address, error) { return accountList, nil } + +func (c *client) GetBalance(ctx context.Context, addr eth.Address, numberOrTag eth.BlockNumberOrTag) (uint64, error) { + request := jsonrpc.Request{ + ID: jsonrpc.ID{Num: 1}, + Method: "eth_getBalance", + Params: jsonrpc.MustParams(addr, &numberOrTag), + } + + applyContext(ctx, &request) + response, err := c.Request(ctx, &request) + if err != nil { + return 0, errors.Wrap(err, "could not make request") + } + if response.Error != nil { + return 0, errors.New(string(*response.Error)) + } + + q := eth.Quantity{} + err = json.Unmarshal(response.Result, &q) + if err != nil { + return 0, errors.Wrap(err, "could not decode result") + } + + return q.UInt64(), nil +} diff --git a/node/interfaces.go b/node/interfaces.go index 11d1c65..d30e0fe 100644 --- a/node/interfaces.go +++ b/node/interfaces.go @@ -38,9 +38,12 @@ type Client interface { // is used to call read-only functions of a smart contract Call(ctx context.Context, msg eth.Transaction, numberOrTag eth.BlockNumberOrTag) (string, error) - // executes get_accounts and retrieves address array + // executes get_accounts and retrieves address array GetAccounts(ctx context.Context) ([]eth.Address, error) + // GetBalance returns available balance + GetBalance(ctx context.Context, addr eth.Address, numberOrTag eth.BlockNumberOrTag) (uint64, error) + // ChainId returns the chain id ChainId(ctx context.Context) (string, error) diff --git a/node/mocks/node.go b/node/mocks/node.go index fa717d6..70d7dc0 100644 --- a/node/mocks/node.go +++ b/node/mocks/node.go @@ -185,7 +185,7 @@ func (m *MockClient) Call(ctx context.Context, msg eth.Transaction) (string, err // Call indicates an expected call of Call. func (mr *MockClientMockRecorder) Call(ctx, msg interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Call", reflect.TypeOf((*MockClient)(nil).EstimateGas), ctx, msg) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Call", reflect.TypeOf((*MockClient)(nil).Call), ctx, msg) } // GetAccounts mocks base method. @@ -200,7 +200,22 @@ func (m *MockClient) GetAccounts(ctx context.Context) ([]eth.Address, error) { // Call indicates an expected call of GetAccounts. func (mr *MockClientMockRecorder) GetAccounts(ctx, msg interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccounts", reflect.TypeOf((*MockClient)(nil).EstimateGas), ctx, msg) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccounts", reflect.TypeOf((*MockClient)(nil).GetAccounts), ctx, msg) +} + +// GetBalance mocks base method. +func (m *MockClient) GetBalance(ctx context.Context) (uint64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetBalance", ctx) + ret0, _ := ret[0].(uint64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Call indicates an expected call of GetBalance. +func (mr *MockClientMockRecorder) GetBalance(ctx, msg interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBalance", reflect.TypeOf((*MockClient)(nil).GetBalance), ctx, msg) } // ChainId mocks base method. diff --git a/node/rpc_test.go b/node/rpc_test.go index 36dcd2f..1c24cb4 100644 --- a/node/rpc_test.go +++ b/node/rpc_test.go @@ -17,7 +17,7 @@ import ( func getClient(t *testing.T, ctx context.Context) node.Client { base_url := os.Getenv("ETHLIBS_TEST_URL") if base_url == "" { - t.Skip("ETHLIBS_TEST_URL not set, skipping test. Set to a valid websocket URL to execute this test.") + t.Skip("ETHLIBS_TEST_URL not set, skipping test. Set to a valid http/ws URL to execute this test.") } auth_id := os.Getenv("AUTH_ID") if auth_id == "" { @@ -66,6 +66,15 @@ func TestConnection_Get_Accounts(t *testing.T) { require.Empty(t, accountList) } +func TestConnection_Get_Balance(t *testing.T) { + ctx := context.Background() + conn := getClient(t, ctx) + + bal, err := conn.GetBalance(ctx, *eth.MustAddress("0x148772F29058DcC772613260b078dCa8C14afF6c"), *eth.MustBlockNumberOrTag("latest")) + require.NoError(t, err) + require.NotNil(t, bal) +} + func TestConnection_GetTransactionCount(t *testing.T) { ctx := context.Background() conn := getClient(t, ctx)