Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace get by select.limit in BandwidthDatabase.get_latest_transaction #6821

Merged
merged 3 commits into from
Mar 23, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions src/tribler/core/components/bandwidth_accounting/db/database.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pathlib import Path
from typing import List, Optional, Union

from pony.orm import Database, count, db_session, select, sum
from pony.orm import Database, count, db_session, desc, select

from tribler.core.components.bandwidth_accounting.db import history, misc, transaction as db_transaction
from tribler.core.components.bandwidth_accounting.db.transaction import BandwidthTransactionData
Expand All @@ -28,6 +28,7 @@ def __init__(self, db_path: Union[Path, type(MEMORY_DB)], my_pub_key: bytes,
self.store_all_transactions = store_all_transactions

self.database = Database()

# This attribute is internally called by Pony on startup, though pylint cannot detect it
# with the static analysis.
# pylint: disable=unused-variable
Expand Down Expand Up @@ -77,7 +78,7 @@ def get_my_latest_transactions(self, limit: Optional[int] = None) -> List[Bandwi
"""
results = []
db_txs = select(tx for tx in self.BandwidthTransaction
if tx.public_key_a == self.my_pub_key or tx.public_key_b == self.my_pub_key)\
if tx.public_key_a == self.my_pub_key or tx.public_key_b == self.my_pub_key) \
.limit(limit)
for db_tx in db_txs:
results.append(BandwidthTransactionData.from_db(db_tx))
Expand All @@ -91,8 +92,12 @@ def get_latest_transaction(self, public_key_a: bytes, public_key_b: bytes) -> Ba
:param public_key_b: The public key of the party receiving the bandwidth.
:return The latest transaction between the two specified parties, or None if no such transaction exists.
"""
db_obj = self.BandwidthTransaction.get(public_key_a=public_key_a, public_key_b=public_key_b)
return BandwidthTransactionData.from_db(db_obj) if db_obj else None
db_tx = select(tx for tx in self.BandwidthTransaction
if public_key_a == tx.public_key_a and public_key_b == tx.public_key_b) \
.order_by(lambda tx: desc(tx.sequence_number)) \
.first()

return BandwidthTransactionData.from_db(db_tx) if db_tx else None

@db_session
def get_latest_transactions(self, public_key: bytes, limit: Optional[int] = 100) -> List[BandwidthTransactionData]:
Expand All @@ -103,7 +108,7 @@ def get_latest_transactions(self, public_key: bytes, limit: Optional[int] = 100)
:return The latest transactions of the specified public key, or an empty list if no transactions exist.
"""
db_txs = select(tx for tx in self.BandwidthTransaction
if public_key in (tx.public_key_a, tx.public_key_b))\
if public_key in (tx.public_key_a, tx.public_key_b)) \
.limit(limit)
return [BandwidthTransactionData.from_db(db_txn) for db_txn in db_txs]

Expand All @@ -114,8 +119,8 @@ def get_total_taken(self, public_key: bytes) -> int:
:param public_key: The public key of the peer of which we want to determine the total taken.
:return The total amount of bandwidth taken by the specified peer, in bytes.
"""
return sum(transaction.amount for transaction in self.BandwidthTransaction
if transaction.public_key_a == public_key)
return select(transaction.amount for transaction in self.BandwidthTransaction
if transaction.public_key_a == public_key).sum()

@db_session
def get_total_given(self, public_key: bytes) -> int:
Expand All @@ -124,8 +129,8 @@ def get_total_given(self, public_key: bytes) -> int:
:param public_key: The public key of the peer of which we want to determine the total given.
:return The total amount of bandwidth given by the specified peer, in bytes.
"""
return sum(transaction.amount for transaction in self.BandwidthTransaction
if transaction.public_key_b == public_key)
return select(transaction.amount for transaction in self.BandwidthTransaction
if transaction.public_key_b == public_key).sum()

@db_session
def get_balance(self, public_key: bytes) -> int:
Expand Down