diff --git a/src/tribler/core/components/bandwidth_accounting/db/database.py b/src/tribler/core/components/bandwidth_accounting/db/database.py index 620352f9d36..434ec5caa21 100644 --- a/src/tribler/core/components/bandwidth_accounting/db/database.py +++ b/src/tribler/core/components/bandwidth_accounting/db/database.py @@ -1,7 +1,7 @@ from pathlib import Path from typing import List, Optional, Union -from pony.orm import Database, count, db_session, select, sum +from pony.orm import Database, count, db_session, desc, select from tribler.core.components.bandwidth_accounting.db import history, misc, transaction as db_transaction from tribler.core.components.bandwidth_accounting.db.transaction import BandwidthTransactionData @@ -28,6 +28,7 @@ def __init__(self, db_path: Union[Path, type(MEMORY_DB)], my_pub_key: bytes, self.store_all_transactions = store_all_transactions self.database = Database() + # This attribute is internally called by Pony on startup, though pylint cannot detect it # with the static analysis. # pylint: disable=unused-variable @@ -77,7 +78,7 @@ def get_my_latest_transactions(self, limit: Optional[int] = None) -> List[Bandwi """ results = [] db_txs = select(tx for tx in self.BandwidthTransaction - if tx.public_key_a == self.my_pub_key or tx.public_key_b == self.my_pub_key)\ + if tx.public_key_a == self.my_pub_key or tx.public_key_b == self.my_pub_key) \ .limit(limit) for db_tx in db_txs: results.append(BandwidthTransactionData.from_db(db_tx)) @@ -91,8 +92,12 @@ def get_latest_transaction(self, public_key_a: bytes, public_key_b: bytes) -> Ba :param public_key_b: The public key of the party receiving the bandwidth. :return The latest transaction between the two specified parties, or None if no such transaction exists. """ - db_obj = self.BandwidthTransaction.get(public_key_a=public_key_a, public_key_b=public_key_b) - return BandwidthTransactionData.from_db(db_obj) if db_obj else None + db_tx = select(tx for tx in self.BandwidthTransaction + if public_key_a == tx.public_key_a and public_key_b == tx.public_key_b) \ + .order_by(lambda tx: desc(tx.sequence_number)) \ + .first() + + return BandwidthTransactionData.from_db(db_tx) if db_tx else None @db_session def get_latest_transactions(self, public_key: bytes, limit: Optional[int] = 100) -> List[BandwidthTransactionData]: @@ -103,7 +108,7 @@ def get_latest_transactions(self, public_key: bytes, limit: Optional[int] = 100) :return The latest transactions of the specified public key, or an empty list if no transactions exist. """ db_txs = select(tx for tx in self.BandwidthTransaction - if public_key in (tx.public_key_a, tx.public_key_b))\ + if public_key in (tx.public_key_a, tx.public_key_b)) \ .limit(limit) return [BandwidthTransactionData.from_db(db_txn) for db_txn in db_txs] @@ -114,8 +119,8 @@ def get_total_taken(self, public_key: bytes) -> int: :param public_key: The public key of the peer of which we want to determine the total taken. :return The total amount of bandwidth taken by the specified peer, in bytes. """ - return sum(transaction.amount for transaction in self.BandwidthTransaction - if transaction.public_key_a == public_key) + return select(transaction.amount for transaction in self.BandwidthTransaction + if transaction.public_key_a == public_key).sum() @db_session def get_total_given(self, public_key: bytes) -> int: @@ -124,8 +129,8 @@ def get_total_given(self, public_key: bytes) -> int: :param public_key: The public key of the peer of which we want to determine the total given. :return The total amount of bandwidth given by the specified peer, in bytes. """ - return sum(transaction.amount for transaction in self.BandwidthTransaction - if transaction.public_key_b == public_key) + return select(transaction.amount for transaction in self.BandwidthTransaction + if transaction.public_key_b == public_key).sum() @db_session def get_balance(self, public_key: bytes) -> int: