Skip to content

Commit

Permalink
Added query feature to bandwidth accounting
Browse files Browse the repository at this point in the history
  • Loading branch information
devos50 committed Oct 27, 2020
1 parent 6dd092a commit b8b4d75
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from asyncio import Future
from binascii import unhexlify
from pathlib import Path
from random import choice
from typing import Dict

from ipv8.community import Community
Expand All @@ -15,7 +16,9 @@
from tribler_core.modules.bandwidth_accounting import EMPTY_SIGNATURE
from tribler_core.modules.bandwidth_accounting.cache import BandwidthTransactionSignCache
from tribler_core.modules.bandwidth_accounting.database import BandwidthDatabase
from tribler_core.modules.bandwidth_accounting.payload import BandwidthTransactionPayload
from tribler_core.modules.bandwidth_accounting.payload import BandwidthTransactionPayload, \
BandwidthTransactionQueryPayload
from tribler_core.modules.bandwidth_accounting.settings import BandwidthAccountingSettings
from tribler_core.modules.bandwidth_accounting.transaction import BandwidthTransactionData
from tribler_core.utilities.unicode import hexlify

Expand All @@ -34,6 +37,7 @@ def __init__(self, *args, **kwargs) -> None:
:param persistence: The database that stores transactions, will be created if not provided.
:param database_path: The path at which the database will be created. Defaults to the current working directory.
"""
self.settings = kwargs.pop('settings', BandwidthAccountingSettings())
self.database = kwargs.pop('database', None)
self.database_path = Path(kwargs.pop('database_path', ''))

Expand All @@ -46,6 +50,9 @@ def __init__(self, *args, **kwargs) -> None:
self.database = BandwidthDatabase(self.database_path, self.my_pk)

self.add_message_handler(BandwidthTransactionPayload, self.received_transaction)
self.add_message_handler(BandwidthTransactionQueryPayload, self.received_query)

self.register_task("query_peers", self.query_random_peer, interval=self.settings.outgoing_query_interval)

self.logger.info("Started bandwidth accounting community with public key %s", hexlify(self.my_pk))

Expand Down Expand Up @@ -75,22 +82,22 @@ def do_payout(self, peer: Peer, amount: int) -> Future:
with db_session:
self.database.BandwidthTransaction.insert(tx)
cache = self.request_cache.add(BandwidthTransactionSignCache(self, tx))
self.send_transaction(tx, peer, cache.number)
self.send_transaction(tx, peer.address, cache.number)

return cache.future

def send_transaction(self, transaction: BandwidthTransactionData, peer: Peer, request_id: int) -> None:
def send_transaction(self, transaction: BandwidthTransactionData, address: Address, request_id: int) -> None:
"""
Send a provided transaction to another party.
:param transaction: The BandwidthTransaction to send to the other party.
:param peer: The peer that will receive the transaction.
:param peer: The IP address and port of the peer.
:param request_id: The identifier of the message, is usually provided by a request cache.
"""
payload = BandwidthTransactionPayload.from_transaction(transaction, request_id)
packet = self._ez_pack(self._prefix, 1, [payload], False)
self.endpoint.send(peer.address, packet)
self.endpoint.send(address, packet)

async def received_transaction(self, source_address: Address, data: bytes) -> None:
def received_transaction(self, source_address: Address, data: bytes) -> None:
"""
Callback when we receive a transaction from another peer.
:param source_address: The network address of the peer that has sent us the transaction.
Expand All @@ -103,38 +110,73 @@ async def received_transaction(self, source_address: Address, data: bytes) -> No
self.logger.info("Transaction %s not valid, ignoring it", tx)
return

latest_tx = self.database.get_latest_transaction(tx.public_key_a, tx.public_key_b)

if payload.public_key_b == self.my_peer.public_key.key_to_bin():
from_peer = Peer(payload.public_key_a, source_address)
if latest_tx:
# Check if the amount in the received transaction is higher than the amount of the latest one
# in the database.
if payload.amount > latest_tx.amount:
# Sign it, store it, and send it back
if payload.public_key_a == self.my_pk or payload.public_key_b == self.my_pk:
# This transaction involves this peer.
latest_tx = self.database.get_latest_transaction(tx.public_key_a, tx.public_key_b)
if payload.public_key_b == self.my_peer.public_key.key_to_bin():
from_peer = Peer(payload.public_key_a, source_address)
if latest_tx:
# Check if the amount in the received transaction is higher than the amount of the latest one
# in the database.
if payload.amount > latest_tx.amount:
# Sign it, store it, and send it back
tx.sign(self.my_peer.key, as_a=False)
self.database.BandwidthTransaction.insert(tx)
self.send_transaction(tx, from_peer.address, payload.request_id)
else:
self.logger.info("Received older bandwidth transaction - sending back the latest one")
self.send_transaction(latest_tx, from_peer.address, payload.request_id)
else:
# This transaction is the first one with party A. Sign it, store it, and send it back.
tx.sign(self.my_peer.key, as_a=False)
self.database.BandwidthTransaction.insert(tx)
self.send_transaction(tx, from_peer, payload.request_id)
else:
self.logger.info("Received older bandwidth transaction - sending back the latest one")
self.send_transaction(latest_tx, from_peer, payload.request_id)
else:
# This transaction is the first one with party A. Sign it, store it, and send it back.
tx.sign(self.my_peer.key, as_a=False)
self.database.BandwidthTransaction.insert(tx)
from_peer = Peer(payload.public_key_a, source_address)
self.send_transaction(tx, from_peer, payload.request_id)
elif payload.public_key_a == self.my_peer.public_key.key_to_bin():
# It seems that we initiated this transaction. Check if we are waiting for it.
cache = self.request_cache.get("bandwidth-tx-sign", payload.request_id)
if not cache:
self.logger.info("Received bandwidth transaction %s without associated cache entry, ignoring it", tx)
return

if not latest_tx or (latest_tx and latest_tx.amount >= tx.amount):
self.database.BandwidthTransaction.insert(tx)

cache.future.set_result(tx)
from_peer = Peer(payload.public_key_a, source_address)
self.send_transaction(tx, from_peer.address, payload.request_id)
elif payload.public_key_a == self.my_peer.public_key.key_to_bin():
# It seems that we initiated this transaction. Check if we are waiting for it.
cache = self.request_cache.get("bandwidth-tx-sign", payload.request_id)
if not cache:
self.logger.info("Received bandwidth transaction %s without associated cache entry, ignoring it",
tx)
return

if not latest_tx or (latest_tx and latest_tx.amount >= tx.amount):
self.database.BandwidthTransaction.insert(tx)

cache.future.set_result(tx)
else:
# This transaction involves two unknown peers. We can add it to our database.
self.database.BandwidthTransaction.insert(tx)

def query_random_peer(self) -> None:
"""
Query a random peer neighbouring peer and ask their bandwidth transactions.
"""
peers = list(self.network.verified_peers)
random_peer = choice(peers)
if random_peer:
self.query_transactions(random_peer)

def query_transactions(self, peer: Peer) -> None:
"""
Query the transactions of a specific peer and ask for their bandwidth transactions.
:param peer: The peer to send the query to.
"""
self.logger.info("Querying the transactions of peer %s:%d", *peer.address)
payload = BandwidthTransactionQueryPayload()
packet = self._ez_pack(self._prefix, 2, [payload], False)
self.endpoint.send(peer.address, packet)

def received_query(self, source_address: Address, data: bytes) -> None:
"""
We received a query from another peer.
:param source_address: The network address of the peer that has sent us the query.
:param data: The serialized, raw data in the packet.
"""
my_txs = self.database.get_latest_transactions(limit=self.settings.max_tx_returned_in_query)
self.logger.debug("Sending %d bandwidth transaction(s) to peer %s:%d", len(my_txs), *source_address)
for tx in my_txs:
self.send_transaction(tx, source_address, 0)

def get_statistics(self) -> Dict:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,21 @@ def __init__(self, db_path: Path, my_pub_key: bytes) -> None:
with db_session:
self.MiscData(name="db_version", value=str(self.CURRENT_DB_VERSION))

@db_session
def get_latest_transactions(self, limit=None) -> List[BandwidthTransactionData]:
"""
Return all latest transactions involving you.
:param limit: An optional integer, to limit the number of results returned. Pass None to get all results.
:return A list containing all latest transactions involving you.
"""
results = []
db_txs = select(tx for tx in self.BandwidthTransaction
if tx.public_key_a == self.my_pub_key or tx.public_key_b == self.my_pub_key)\
.limit(limit)
for db_tx in db_txs:
results.append(BandwidthTransactionData.from_db(db_tx))
return results

@db_session
def get_latest_transaction(self, public_key_a: bytes, public_key_b: bytes) -> BandwidthTransactionData:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,8 @@ def from_transaction(cls, transaction: BandwidthTransaction, request_id: int) ->
transaction.timestamp,
request_id
)


@vp_compile
class BandwidthTransactionQueryPayload(VariablePayload):
msg_id = 2
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from dataclasses import dataclass


@dataclass
class BandwidthAccountingSettings:
"""
This class contains several settings related to the bandwidth accounting mechanism.
"""
outgoing_query_interval: int = 30 # The interval at which we send out queries to other peers, in seconds.
max_tx_returned_in_query: int = 10 # The maximum number of bandwidth transactions to return in response to a query.
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async def test_invalid_transaction(self):
tx.signature_a = b"invalid"
self.nodes[0].overlay.database.BandwidthTransaction.insert(tx)
cache = self.nodes[0].overlay.request_cache.add(BandwidthTransactionSignCache(self.nodes[0].overlay, tx))
self.nodes[0].overlay.send_transaction(tx, other_peer, cache.number)
self.nodes[0].overlay.send_transaction(tx, other_peer.address, cache.number)

await self.deliver_messages()

Expand All @@ -75,7 +75,7 @@ async def test_ignore_unknown_transaction(self):

tx = BandwidthTransactionData(1, pk1, pk2, EMPTY_SIGNATURE, EMPTY_SIGNATURE, 1000)
tx.sign(self.nodes[0].my_peer.key, as_a=True)
self.nodes[0].overlay.send_transaction(tx, self.nodes[1].my_peer, 1234)
self.nodes[0].overlay.send_transaction(tx, self.nodes[1].my_peer.address, 1234)
await self.deliver_messages()
assert not self.nodes[0].overlay.database.get_latest_transaction(pk1, pk2)

Expand All @@ -91,14 +91,44 @@ async def test_concurrent_transaction_out_of_order(self):

# Send them in reverse order
cache = self.nodes[0].overlay.request_cache.add(BandwidthTransactionSignCache(self.nodes[0].overlay, tx1))
self.nodes[0].overlay.send_transaction(tx2, self.nodes[1].my_peer, cache.number)
self.nodes[0].overlay.send_transaction(tx2, self.nodes[1].my_peer.address, cache.number)
await self.deliver_messages()

# This one should be ignored by node 1
cache = self.nodes[0].overlay.request_cache.add(BandwidthTransactionSignCache(self.nodes[0].overlay, tx1))
self.nodes[0].overlay.send_transaction(tx1, self.nodes[1].my_peer, cache.number)
self.nodes[0].overlay.send_transaction(tx1, self.nodes[1].my_peer.address, cache.number)
await self.deliver_messages()

# Both parties should have the transaction with amount 2000 in their database
assert self.nodes[0].overlay.database.get_total_taken(pk1) == 2000
assert self.nodes[1].overlay.database.get_total_taken(pk1) == 2000

async def test_querying_peer(self):
"""
Test whether node C can query node B to get the transaction between A and B.
"""
await self.nodes[0].overlay.do_payout(self.nodes[1].overlay.my_peer, 500)

# Add an additional node to the experiment
self.add_node_to_experiment(self.create_node())
self.nodes[2].overlay.query_transactions(self.nodes[1].my_peer)

await self.deliver_messages()

pk1 = self.nodes[0].my_peer.public_key.key_to_bin()
assert self.nodes[2].overlay.database.get_total_taken(pk1) == 500

async def test_query_random_peer(self):
"""
Test whether node C can query node B to get the transaction between A and B.
"""
await self.nodes[0].overlay.do_payout(self.nodes[1].overlay.my_peer, 500)

# Add an additional node to the experiment
self.add_node_to_experiment(self.create_node())
self.nodes[2].overlay.query_random_peer()

await self.deliver_messages()

pk1 = self.nodes[0].my_peer.public_key.key_to_bin()
assert self.nodes[2].overlay.database.get_total_taken(pk1) == 500
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,21 @@ def test_add_transaction(bandwidth_db):
assert latest_tx.amount == 4000


@db_session
def test_get_latest_transactions(bandwidth_db):
assert not bandwidth_db.get_latest_transactions()

tx1 = BandwidthTransactionData(1, b"a", bandwidth_db.my_pub_key, EMPTY_SIGNATURE, EMPTY_SIGNATURE, 3000)
bandwidth_db.BandwidthTransaction.insert(tx1)
tx2 = BandwidthTransactionData(1, bandwidth_db.my_pub_key, b"c", EMPTY_SIGNATURE, EMPTY_SIGNATURE, 3000)
bandwidth_db.BandwidthTransaction.insert(tx2)
tx3 = BandwidthTransactionData(1, b"c", b"d", EMPTY_SIGNATURE, EMPTY_SIGNATURE, 3000)
bandwidth_db.BandwidthTransaction.insert(tx3)

assert len(bandwidth_db.get_latest_transactions()) == 2
assert len(bandwidth_db.get_latest_transactions(limit=1)) == 1


@db_session
def test_get_latest_transaction(bandwidth_db):
assert not bandwidth_db.get_latest_transaction(b"a", b"b")
Expand Down

0 comments on commit b8b4d75

Please sign in to comment.