diff --git a/README.md b/README.md index 98ce920..6ee3c9d 100644 --- a/README.md +++ b/README.md @@ -70,3 +70,12 @@ root.addHandler(handler) # To get logging output from the QtradeAPI class for debugging: logging.getLogger('qtrade').setLevel(logging.DEBUG) ``` + +## Testing + +``` bash +pip3 install --user pytest +pytest +# Or alternatively +python3.7 -m pytest +``` diff --git a/qtrade_client/api.py b/qtrade_client/api.py index 98f6d57..98975aa 100644 --- a/qtrade_client/api.py +++ b/qtrade_client/api.py @@ -107,7 +107,7 @@ def set_hmac(self, hmac_pair): self.rs.auth = QtradeAuth(hmac_pair) def balances(self): - return {b['currency']: b['balance'] for b in self.get("/v1/user/balances")['balances']} + return {b['currency']: Decimal(b['balance']) for b in self.get("/v1/user/balances")['balances']} def get(self, endpoint, *args, **kwargs): return self._req('get', endpoint, *args, **kwargs) @@ -161,22 +161,19 @@ def order(self, order_type, price, value=None, amount=None, market_id=None, mark def balances_merged(self): """ Get total balances including order balances """ - bals = self.balances() - ords = self.orders(open=True) - for o in ords: - base_c = self.markets[o['market_id']]['base_currency']['code'] - market_c = self.markets[o['market_id']][ - 'market_currency']['code'] - if o['order_type'] == 'buy_limit': - bals[base_c] = str(Decimal(bals.setdefault( - base_c, 0)) + Decimal(o['base_amount'])) - if o['order_type'] == 'sell_limit': - bals[market_c] = str(Decimal(bals.setdefault( - market_c, 0)) + Decimal(o['market_amount_remaining'])) - return bals + bals = self.balances_all() + merged = {} + for k, v in list(bals['spendable'].items()) + list(bals['in_orders'].items()): + merged.setdefault(k, 0) + merged[k] += Decimal(v) + return merged def balances_all(self): - return {b['currency']: b for b in self.get("/v1/user/balances_all")['balances']} + all_bal = self.get("/v1/user/balances_all") + return { + "spendable": {b['currency']: Decimal(b['balance']) for b in all_bal['balances']}, + "in_orders": {b['currency']: Decimal(b['balance']) for b in all_bal['order_balances']}, + } def cancel_all_orders(self): for o in self.orders(open=True): @@ -208,23 +205,28 @@ def _refresh_tickers(self): self._tickers.update({m['id_hr']: m for m in res['markets']}) self._tickers_age = time.time() + @property + def currencies(self): + self._refresh_common() + return self._currencies_map + @property def markets(self): """ Markets may be indexed either by id or string """ - self._refresh_markets() + self._refresh_common() return self._markets_map - def _refresh_markets(self): + def _refresh_common(self): """ Lazy load and reload every market_update_interval. """ if self._markets_map is None or (time.time() - self._markets_age) > self.market_update_interval: # Index our market information by market string common = self.get("/v1/common") - self.currencies_map = {c['code']: c for c in common['currencies']} + self._currencies_map = {c['code']: c for c in common['currencies']} # Set some convenience keys so we can pass around just the dict for m in common['markets']: m['string'] = "{market_currency}_{base_currency}".format(**m) - m['base_currency'] = self.currencies_map[m['base_currency']] - m['market_currency'] = self.currencies_map[m['market_currency']] + m['base_currency'] = self._currencies_map[m['base_currency']] + m['market_currency'] = self._currencies_map[m['market_currency']] self._markets_map = {m['string']: m for m in common['markets']} self._markets_map.update({m['id']: m for m in common['markets']}) self._markets_age = time.time() diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000..eced74b --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,75 @@ +import pytest +import json +import requests +import unittest.mock as mock +from decimal import Decimal + +from qtrade_client.api import QtradeAPI, QtradeAuth + + +@pytest.fixture +def api(): + return QtradeAPI("http://localhost:9898/") + + +@mock.patch('time.time', mock.MagicMock(return_value=12345)) +def test_hmac(): + s = requests.Session() + s.auth = QtradeAuth("256:vwj043jtrw4o5igw4oi5jwoi45g") + r = s.prepare_request(requests.Request("GET", "http://google.com/")) + assert r.headers['Authorization'] == "HMAC-SHA256 256:iyfC4n+bE+3hLgMJns1Z67FKA7O5qm5PgDvZHGraMTQ=" + + +def test_balances(api): + api._req = mock.MagicMock(return_value=json.loads(""" + { + "balances": [ + { + "currency": "TAO", + "balance": "1" + }, + { + "currency": "ZANO", + "balance": "0.14355714" + }, + { + "currency": "VLS", + "balance": "0" + } + ] + }""")) + bal = api.balances() + assert bal == {'TAO': Decimal('1'), 'ZANO': Decimal('0.14355714'), 'VLS': Decimal('0')} + + +def test_balances_all(api): + api._req = mock.MagicMock(return_value=json.loads(""" + { + "balances": [ + { + "currency": "BIS", + "balance": "6.97936" + }, + { + "currency": "BTC", + "balance": "0.1970952" + } + ], + "order_balances": [ + { + "currency": "BAN", + "balance": "401184.76191351" + }, + { + "currency": "BTC", + "balance": "0.1708" + } + ], + "limit_used": 0, + "limit_remaining": 50000, + "limit": 50000 + }""")) + bal = api.balances_merged() + assert bal == {'BIS': Decimal('6.97936'), 'BTC': Decimal('0.3678952'), 'BAN': Decimal('401184.76191351')} + bal = api.balances_all() + assert bal == {'spendable': {'BIS': Decimal('6.97936'), 'BTC': Decimal('0.1970952')}, 'in_orders': {'BAN': Decimal('401184.76191351'), 'BTC': Decimal('0.1708')}}