From a0da5e0e25e7c97dd349af725a739b4104392225 Mon Sep 17 00:00:00 2001 From: David Wolinsky Date: Sat, 29 Apr 2023 20:10:53 -0700 Subject: [PATCH] [python] add a transaction management layer This provides a framework for managing as many transactions from a single account at once * The AccountSequenceNumber allocates up to 100 outstanding sequence numbers to maximize the number of concurrent transactions in the happy path. * The transaction manager provides async workers that push a transaction from submission through to validating completion Together they provide the basic harness for scaling transaction submission on the Aptos blockchain from a single account. --- .../sdk/aptos_sdk/account_sequence_number.py | 178 ++++++++++++++ .../sdk/aptos_sdk/transaction_worker.py | 225 ++++++++++++++++++ 2 files changed, 403 insertions(+) create mode 100644 ecosystem/python/sdk/aptos_sdk/account_sequence_number.py create mode 100644 ecosystem/python/sdk/aptos_sdk/transaction_worker.py diff --git a/ecosystem/python/sdk/aptos_sdk/account_sequence_number.py b/ecosystem/python/sdk/aptos_sdk/account_sequence_number.py new file mode 100644 index 0000000000000..266a25047d708 --- /dev/null +++ b/ecosystem/python/sdk/aptos_sdk/account_sequence_number.py @@ -0,0 +1,178 @@ +# Copyright © Aptos Foundation +# SPDX-License-Identifier: Apache-2.0 +import asyncio +import logging +import time +from typing import Optional + +from aptos_sdk.account_address import AccountAddress +from aptos_sdk.async_client import RestClient + + +class AccountSequenceNumber: + """ + A managed wrapper around sequence numbers that implements the trivial flow control used by the + Aptos faucet: + * Submit up to 50 transactions per account in parallel with a timeout of 20 seconds + * If local assumes 50 are in flight, determine the actual committed state from the network + * If there are less than 50 due to some being committed, adjust the window + * If 50 are in flight Wait .1 seconds before re-evaluating + * If ever waiting more than 30 seconds restart the sequence number to the current on-chain state + Assumptions: + * Accounts are expected to be managed by a single AccountSequenceNumber and not used otherwise. + * They are initialized to the current on-chain state, so if there are already transactions in + flight, they make take some time to reset. + * Accounts are automatically initialized if not explicitly + + Notes: + * This is co-routine safe, that is many async tasks can be reading from this concurrently. + * The synchronize method will create a barrier that prevents additional next_sequence_number + calls until it is complete. + * This only manages the distribution of sequence numbers it does not help handle transaction + failures. + """ + + client: RestClient + account: AccountAddress + lock = asyncio.Lock + + maximum_in_flight: int = 100 + maximum_wait_time = 30 + sleep_time = 0.01 + + last_committed_number: Optional[int] + current_number: Optional[int] + + def __init__(self, client: RestClient, account: AccountAddress): + self.client = client + self.account = account + self.lock = asyncio.Lock() + + self.last_uncommitted_number = None + self.current_number = None + + async def next_sequence_number(self, block: bool = True) -> Optional[int]: + """ + Returns the next sequence number available on this account. This leverages a lock to + guarantee first-in, first-out ordering of requests. + """ + await self.lock.acquire() + try: + if self.last_uncommitted_number is None or self.current_number is None: + await self.initialize() + if ( + self.current_number - self.last_uncommitted_number + >= self.maximum_in_flight + ): + await self.__update() + start_time = time.time() + while ( + self.current_number - self.last_uncommitted_number + >= self.maximum_in_flight + ): + if not block: + return None + await asyncio.sleep(self.sleep_time) + if time.time() - start_time > self.maximum_wait_time: + logging.warn( + f"Waited over 30 seconds for a transaction to commit, resyncing {self.account.address().hex()}" + ) + await self.initialize() + else: + await self.__update() + next_number = self.current_number + self.current_number += 1 + finally: + self.lock.release() + return next_number + + async def initialize(self): + """Optional initializer. called by next_sequence_number if not called prior.""" + self.current_number = await self.__current_sequence_number() + self.last_uncommitted_number = self.current_number + + async def synchronize(self): + """ + Poll the network until all submitted transactions have either been committed or until + the maximum wait time has elapsed. This will prevent any calls to next_sequence_number + until this called has returned. + """ + if self.last_uncommitted_number == self.current_number: + return + + await self.lock.acquire() + try: + await self.__update() + start_time = time.time() + while self.last_uncommitted_number != self.current_number: + print(f"{self.last_uncommitted_number} {self.current_number}") + if time.time() - start_time > self.maximum_wait_time: + logging.warn( + f"Waited over 30 seconds for a transaction to commit, resyncing {self.account.address}" + ) + await self.initialize() + else: + await asyncio.sleep(self.sleep_time) + await self.__update() + finally: + self.lock.release() + + async def __update(self): + self.last_uncommitted_number = await self.__current_sequence_number() + return self.last_uncommitted_number + + async def __current_sequence_number(self) -> int: + return await self.client.account_sequence_number(self.account) + + +import unittest +import unittest.mock + + +class Test(unittest.IsolatedAsyncioTestCase): + async def test_common_path(self): + """ + Verifies that: + * AccountSequenceNumber returns sequential numbers starting from 0 + * When the account has been updated on-chain include that in computations 100 -> 105 + * Ensure that none is returned if the call for next_sequence_number would block + * Ensure that synchronize completes if the value matches on-chain + """ + patcher = unittest.mock.patch( + "aptos_sdk.async_client.RestClient.account_sequence_number", return_value=0 + ) + patcher.start() + + rest_client = RestClient("https://fullnode.devnet.aptoslabs.com/v1") + account_sequence_number = AccountSequenceNumber( + rest_client, AccountAddress.from_hex("b0b") + ) + last_seq_num = 0 + for seq_num in range(5): + last_seq_num = await account_sequence_number.next_sequence_number() + self.assertEqual(last_seq_num, seq_num) + + patcher.stop() + patcher = unittest.mock.patch( + "aptos_sdk.async_client.RestClient.account_sequence_number", return_value=5 + ) + patcher.start() + + for seq_num in range(AccountSequenceNumber.maximum_in_flight): + last_seq_num = await account_sequence_number.next_sequence_number() + self.assertEqual(last_seq_num, seq_num + 5) + + self.assertEqual( + await account_sequence_number.next_sequence_number(block=False), None + ) + next_sequence_number = last_seq_num + 1 + patcher.stop() + patcher = unittest.mock.patch( + "aptos_sdk.async_client.RestClient.account_sequence_number", + return_value=next_sequence_number, + ) + patcher.start() + + self.assertNotEqual(account_sequence_number.current_number, last_seq_num) + await account_sequence_number.synchronize() + self.assertEqual(account_sequence_number.current_number, next_sequence_number) diff --git a/ecosystem/python/sdk/aptos_sdk/transaction_worker.py b/ecosystem/python/sdk/aptos_sdk/transaction_worker.py new file mode 100644 index 0000000000000..ce0a7648a22d2 --- /dev/null +++ b/ecosystem/python/sdk/aptos_sdk/transaction_worker.py @@ -0,0 +1,225 @@ +# Copyright © Aptos Foundation +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import logging +import typing + +from aptos_sdk.account import Account +from aptos_sdk.account_address import AccountAddress +from aptos_sdk.account_sequence_number import AccountSequenceNumber +from aptos_sdk.async_client import RestClient +from aptos_sdk.transactions import SignedTransaction, TransactionPayload + + +class TransactionWorker: + """ + The TransactionWorker provides a simple framework for receiving payloads to be processed. It + acquires new sequence numbers and calls into the callback to produce a signed transaction, and + then submits the transaction. In another task, it waits for resolution of the submission + process or get pre-execution validation error. + + Note: This is not a particularly robust solution, as it lacks any framework to handle failed + transactions with functionality like retries or checking whether the framework is online. + This is the responsibility of a higher-level framework. + """ + + _account: Account + _account_sequence_number: AccountSequenceNumber + _rest_client: RestClient + _transaction_generator: typing.Callable[ + [Account, int], typing.Awaitable[SignedTransaction] + ] + _started: bool + _stopped: bool + _outstanding_transactions: asyncio.Queue + _outstanding_transactions_task: typing.Optional[asyncio.Task] + _processed_transactions: asyncio.Queue + _process_transactions_task: typing.Optional[asyncio.Task] + + def __init__( + self, + account: Account, + rest_client: RestClient, + transaction_generator: typing.Callable[ + [Account, int], typing.Awaitable[SignedTransaction] + ], + ): + self._account = account + self._account_sequence_number = AccountSequenceNumber( + rest_client, account.address() + ) + self._rest_client = rest_client + self._transaction_generator = transaction_generator + + self._started = False + self._stopped = False + self._outstanding_transactions = asyncio.Queue() + self._processed_transactions = asyncio.Queue() + + def account(self) -> AccountAddress: + return self._account.address() + + async def _submit_transactions_task(self): + try: + while True: + sequence_number = ( + await self._account_sequence_number.next_sequence_number() + ) + transaction = await self._transaction_generator( + self._account, sequence_number + ) + txn_hash_awaitable = self._rest_client.submit_bcs_transaction( + transaction + ) + await self._outstanding_transactions.put( + (txn_hash_awaitable, sequence_number) + ) + except asyncio.CancelledError: + return + except Exception as e: + # This is insufficient, if we hit this we either need to bail or resolve the potential errors + logging.error(e, exc_info=True) + + async def _process_transactions_task(self): + try: + while True: + # Always start waiting for one + ( + txn_awaitable, + sequence_number, + ) = await self._outstanding_transactions.get() + awaitables = [txn_awaitable] + sequence_numbers = [sequence_number] + + # Only acquire if there are more + while not self._outstanding_transactions.empty(): + ( + txn_awaitable, + sequence_number, + ) = await self._outstanding_transactions.get() + awaitables.append(txn_awaitable) + sequence_numbers.append(sequence_number) + outputs = await asyncio.gather(*awaitables, return_exceptions=True) + + for (output, sequence_number) in zip(outputs, sequence_numbers): + if isinstance(output, BaseException): + await self._processed_transactions.put( + (sequence_number, None, output) + ) + else: + await self._processed_transactions.put( + (sequence_number, output, None) + ) + except asyncio.CancelledError: + return + except Exception as e: + # This is insufficient, if we hit this we either need to bail or resolve the potential errors + logging.error(e, exc_info=True) + + async def next_processed_transaction( + self, + ) -> (int, typing.Optional[str], typing.Optional[Exception]): + return await self._processed_transactions.get() + + def stop(self): + """Stop the tasks for managing transactions""" + if not self._started: + raise Exception("Start not yet called") + if self._stopped: + raise Exception("Already stopped") + self._stopped = True + + self._submit_transactions_task.cancel() + self._process_transactions_task.cancel() + + def start(self): + """Begin the tasks for managing transactions""" + if self._started: + raise Exception("Already started") + self._started = True + + self._submit_transactions_task = asyncio.create_task( + self._submit_transactions_task() + ) + self._process_transactions_task = asyncio.create_task( + self._process_transactions_task() + ) + + +class TransactionQueue: + """Provides a queue model for pushing transactions into the TransactionWorker.""" + + _client: RestClient + _outstanding_transactions: asyncio.Queue + + def __init__(self, client: RestClient): + self._client = client + self._outstanding_transactions = asyncio.Queue() + + async def push(self, payload: TransactionPayload): + await self._outstanding_transactions.put(payload) + + async def next(self, sender: Account, sequence_number: int) -> SignedTransaction: + payload = await self._outstanding_transactions.get() + return await self._client.create_bcs_signed_transaction( + sender, payload, sequence_number=sequence_number + ) + + +import unittest +import unittest.mock + +from aptos_sdk.bcs import Serializer +from aptos_sdk.transactions import EntryFunction, TransactionArgument + + +class Test(unittest.IsolatedAsyncioTestCase): + async def test_common_path(self): + transaction_arguments = [ + TransactionArgument(AccountAddress.from_hex("b0b"), Serializer.struct), + TransactionArgument(100, Serializer.u64), + ] + payload = EntryFunction.natural( + "0x1::aptos_accounts", + "transfer", + [], + transaction_arguments, + ) + + seq_num_patcher = unittest.mock.patch( + "aptos_sdk.async_client.RestClient.account_sequence_number", return_value=0 + ) + seq_num_patcher.start() + submit_txn_patcher = unittest.mock.patch( + "aptos_sdk.async_client.RestClient.submit_bcs_transaction", + return_value="0xff", + ) + submit_txn_patcher.start() + + rest_client = RestClient("https://fullnode.devnet.aptoslabs.com/v1") + txn_queue = TransactionQueue(rest_client) + txn_worker = TransactionWorker(Account.generate(), rest_client, txn_queue.next) + txn_worker.start() + + await txn_queue.push(payload) + processed_txn = await txn_worker.next_processed_transaction() + self.assertEqual(processed_txn[0], 0) + self.assertEqual(processed_txn[1], "0xff") + self.assertEqual(processed_txn[2], None) + + submit_txn_patcher.stop() + exception = Exception("Power overwhelming") + submit_txn_patcher = unittest.mock.patch( + "aptos_sdk.async_client.RestClient.submit_bcs_transaction", + side_effect=exception, + ) + submit_txn_patcher.start() + + await txn_queue.push(payload) + processed_txn = await txn_worker.next_processed_transaction() + self.assertEqual(processed_txn[0], 1) + self.assertEqual(processed_txn[1], None) + self.assertEqual(processed_txn[2], exception) + + txn_worker.stop()