From 8bcfb9e1dd53b467b5e954dfb5a9bbe16c8786db Mon Sep 17 00:00:00 2001 From: ytx1991 Date: Tue, 13 Jun 2023 12:11:49 -0700 Subject: [PATCH] Add filters for get_transaction_count --- chia/rpc/wallet_rpc_api.py | 7 ++++++- chia/rpc/wallet_rpc_client.py | 13 +++++++++---- chia/wallet/wallet_transaction_store.py | 23 +++++++++++++++++++++-- tests/wallet/rpc/test_wallet_rpc.py | 7 +++++++ tests/wallet/test_transaction_store.py | 18 ++++++++++++++++++ 5 files changed, 61 insertions(+), 7 deletions(-) diff --git a/chia/rpc/wallet_rpc_api.py b/chia/rpc/wallet_rpc_api.py index 209bde21f519..1a00cc9e12ab 100644 --- a/chia/rpc/wallet_rpc_api.py +++ b/chia/rpc/wallet_rpc_api.py @@ -923,7 +923,12 @@ async def get_transactions(self, request: Dict) -> EndpointResult: async def get_transaction_count(self, request: Dict) -> EndpointResult: wallet_id = int(request["wallet_id"]) - count = await self.service.wallet_state_manager.tx_store.get_transaction_count_for_wallet(wallet_id) + type_filter = None + if "type_filter" in request: + type_filter = TransactionTypeFilter.from_json_dict(request["type_filter"]) + count = await self.service.wallet_state_manager.tx_store.get_transaction_count_for_wallet( + wallet_id, confirmed=request.get("confirmed", None), type_filter=type_filter + ) return { "count": count, "wallet_id": wallet_id, diff --git a/chia/rpc/wallet_rpc_client.py b/chia/rpc/wallet_rpc_client.py index 92676f981f99..a58e10bb531c 100644 --- a/chia/rpc/wallet_rpc_client.py +++ b/chia/rpc/wallet_rpc_client.py @@ -164,11 +164,16 @@ async def get_transactions( async def get_transaction_count( self, wallet_id: int, + confirmed: Optional[bool] = None, + type_filter: Optional[TransactionTypeFilter] = None, ) -> List[TransactionRecord]: - res = await self.fetch( - "get_transaction_count", - {"wallet_id": wallet_id}, - ) + request: Dict[str, Any] = {"wallet_id": wallet_id} + if type_filter is not None: + request["type_filter"] = type_filter.to_json_dict() + + if confirmed is not None: + request["confirmed"] = confirmed + res = await self.fetch("get_transaction_count", request) return res["count"] async def get_next_address(self, wallet_id: int, new_address: bool) -> str: diff --git a/chia/wallet/wallet_transaction_store.py b/chia/wallet/wallet_transaction_store.py index d823e20de1ae..5cc705da309f 100644 --- a/chia/wallet/wallet_transaction_store.py +++ b/chia/wallet/wallet_transaction_store.py @@ -309,10 +309,29 @@ async def get_transactions_between( return [TransactionRecord.from_bytes(row[0]) for row in rows] - async def get_transaction_count_for_wallet(self, wallet_id) -> int: + async def get_transaction_count_for_wallet( + self, + wallet_id: int, + confirmed: Optional[bool] = None, + type_filter: Optional[TransactionTypeFilter] = None, + ) -> int: + confirmed_str = "" + if confirmed is not None: + confirmed_str = f"AND confirmed={int(confirmed)}" + + if type_filter is None: + type_filter_str = "" + else: + type_filter_str = ( + f"AND type {'' if type_filter.mode == FilterMode.include else 'NOT'} " + f"IN ({','.join([str(x) for x in type_filter.values])})" + ) async with self.db_wrapper.reader_no_transaction() as conn: rows = list( - await conn.execute_fetchall("SELECT COUNT(*) FROM transaction_record where wallet_id=?", (wallet_id,)) + await conn.execute_fetchall( + f"SELECT COUNT(*) FROM transaction_record where wallet_id=? {type_filter_str} {confirmed_str}", + (wallet_id,), + ) ) return 0 if len(rows) == 0 else rows[0][0] diff --git a/tests/wallet/rpc/test_wallet_rpc.py b/tests/wallet/rpc/test_wallet_rpc.py index ac9c6b004520..5149d75185bf 100644 --- a/tests/wallet/rpc/test_wallet_rpc.py +++ b/tests/wallet/rpc/test_wallet_rpc.py @@ -859,6 +859,13 @@ async def test_get_transaction_count(wallet_rpc_environment: WalletRpcTestEnviro assert len(all_transactions) > 0 transaction_count = await client.get_transaction_count(1) assert transaction_count == len(all_transactions) + assert await client.get_transaction_count(1, confirmed=False) == 0 + assert ( + await client.get_transaction_count( + 1, type_filter=TransactionTypeFilter.include([TransactionType.INCOMING_CLAWBACK_SEND]) + ) + == 0 + ) @pytest.mark.asyncio diff --git a/tests/wallet/test_transaction_store.py b/tests/wallet/test_transaction_store.py index 08ce61e51e86..ca9313eeca5c 100644 --- a/tests/wallet/test_transaction_store.py +++ b/tests/wallet/test_transaction_store.py @@ -283,6 +283,24 @@ async def test_transaction_count_for_wallet() -> None: assert await store.get_transaction_count_for_wallet(1) == 5 assert await store.get_transaction_count_for_wallet(2) == 2 + assert ( + await store.get_transaction_count_for_wallet( + 1, True, type_filter=TransactionTypeFilter.include([TransactionType.OUTGOING_TX]) + ) + == 0 + ) + assert ( + await store.get_transaction_count_for_wallet( + 1, False, type_filter=TransactionTypeFilter.include([TransactionType.OUTGOING_CLAWBACK]) + ) + == 0 + ) + assert ( + await store.get_transaction_count_for_wallet( + 1, False, type_filter=TransactionTypeFilter.include([TransactionType.OUTGOING_TX]) + ) + == 5 + ) @pytest.mark.asyncio