Skip to content

Commit

Permalink
Merge pull request #6821 from Tribler/fix/6802
Browse files Browse the repository at this point in the history
Replace `get` by `select.limit` in BandwidthDatabase.get_latest_transaction
  • Loading branch information
drew2a authored Mar 23, 2022
2 parents 4055d9f + 94cdc67 commit f68dc90
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions src/tribler/core/components/bandwidth_accounting/db/database.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pathlib import Path
from typing import List, Optional, Union

from pony.orm import Database, count, db_session, select, sum
from pony.orm import Database, count, db_session, desc, select

from tribler.core.components.bandwidth_accounting.db import history, misc, transaction as db_transaction
from tribler.core.components.bandwidth_accounting.db.transaction import BandwidthTransactionData
Expand All @@ -28,6 +28,7 @@ def __init__(self, db_path: Union[Path, type(MEMORY_DB)], my_pub_key: bytes,
self.store_all_transactions = store_all_transactions

self.database = Database()

# This attribute is internally called by Pony on startup, though pylint cannot detect it
# with the static analysis.
# pylint: disable=unused-variable
Expand Down Expand Up @@ -77,7 +78,7 @@ def get_my_latest_transactions(self, limit: Optional[int] = None) -> List[Bandwi
"""
results = []
db_txs = select(tx for tx in self.BandwidthTransaction
if tx.public_key_a == self.my_pub_key or tx.public_key_b == self.my_pub_key)\
if tx.public_key_a == self.my_pub_key or tx.public_key_b == self.my_pub_key) \
.limit(limit)
for db_tx in db_txs:
results.append(BandwidthTransactionData.from_db(db_tx))
Expand All @@ -91,8 +92,12 @@ def get_latest_transaction(self, public_key_a: bytes, public_key_b: bytes) -> Ba
:param public_key_b: The public key of the party receiving the bandwidth.
:return The latest transaction between the two specified parties, or None if no such transaction exists.
"""
db_obj = self.BandwidthTransaction.get(public_key_a=public_key_a, public_key_b=public_key_b)
return BandwidthTransactionData.from_db(db_obj) if db_obj else None
db_tx = select(tx for tx in self.BandwidthTransaction
if public_key_a == tx.public_key_a and public_key_b == tx.public_key_b) \
.order_by(lambda tx: desc(tx.sequence_number)) \
.first()

return BandwidthTransactionData.from_db(db_tx) if db_tx else None

@db_session
def get_latest_transactions(self, public_key: bytes, limit: Optional[int] = 100) -> List[BandwidthTransactionData]:
Expand All @@ -103,7 +108,7 @@ def get_latest_transactions(self, public_key: bytes, limit: Optional[int] = 100)
:return The latest transactions of the specified public key, or an empty list if no transactions exist.
"""
db_txs = select(tx for tx in self.BandwidthTransaction
if public_key in (tx.public_key_a, tx.public_key_b))\
if public_key in (tx.public_key_a, tx.public_key_b)) \
.limit(limit)
return [BandwidthTransactionData.from_db(db_txn) for db_txn in db_txs]

Expand All @@ -114,8 +119,8 @@ def get_total_taken(self, public_key: bytes) -> int:
:param public_key: The public key of the peer of which we want to determine the total taken.
:return The total amount of bandwidth taken by the specified peer, in bytes.
"""
return sum(transaction.amount for transaction in self.BandwidthTransaction
if transaction.public_key_a == public_key)
return select(transaction.amount for transaction in self.BandwidthTransaction
if transaction.public_key_a == public_key).sum()

@db_session
def get_total_given(self, public_key: bytes) -> int:
Expand All @@ -124,8 +129,8 @@ def get_total_given(self, public_key: bytes) -> int:
:param public_key: The public key of the peer of which we want to determine the total given.
:return The total amount of bandwidth given by the specified peer, in bytes.
"""
return sum(transaction.amount for transaction in self.BandwidthTransaction
if transaction.public_key_b == public_key)
return select(transaction.amount for transaction in self.BandwidthTransaction
if transaction.public_key_b == public_key).sum()

@db_session
def get_balance(self, public_key: bytes) -> int:
Expand Down

0 comments on commit f68dc90

Please sign in to comment.