From 278a515dfd599319745e111a7f296ff6373a5383 Mon Sep 17 00:00:00 2001 From: Byron Hambly Date: Wed, 1 Sep 2021 15:33:26 +0200 Subject: [PATCH] chore: merge development into installer branch (#3279) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update LibWallet recovery task event handling The recovery task broke out of its monitoring loop before getting the `UtxoScannerEvent::Completed` event. This PR just moves that break statement so that the final completed callback is made. Also ignore `test_store_and_forward_send_tx` due to it being flaky on CI; the functionality is covered by Cucumber tests. * Update LibWallet `wallet_import_utxo` function to include valid TariScript The `wallet_import_utxo` FFI function in LibWallet just used defaults for a number of the new UTXO fields when importing a faucet UTXO. The faucet UTXO provided by the client is just the spending key and amount. The `metadata_signature` and `sender_offset_public_key` can both remain as default values as they are not used in spending a UTXO. A Nop script is assumed and the spending key is used as the `script_private_key`. The final update is that the `input_data` is set to the public key of the spending key. To test that the base node is happy for a UTXO imported in this way to be spent, a Cucumber test is provided which imports a UTXO into a wallet, zeroes out the `metadata_signature` and `sender_offset_public_key`, and updates the `input_data` and `script_private_key` in the same way as described above. This imported faucet UTXO is then successfully spent to another wallet. 
* Move not found to status, and not boolean flag * Introduce cache update cool down to console wallet The current architecture of the wallet is that the AppState contains a cache of the current state that the UI draws from, plus a second instance of the data that is updated in the background when events are received from the wallet services, without interfering with drawing. When the background data has been updated by the event monitoring thread, it flips a flag telling the UI that the cache has been invalidated, so when the drawing is done on a Tick event the UI thread will clone the background data into the cache for future drawing calls. A problem was found where, when a large number of transactions were being processed by the wallet, the UI would become unresponsive. The reason is that with a large number of transactions there is quite a lot of AppState to copy from the background into the UI cache, which could take 300-400ms, and this cache was being invalidated very often as the transactions were being handled by the wallet services. This meant that a cache copy occurred after every single draw cycle and would block the processing of Key events for 300-400ms. This PR proposes a solution of introducing a cache update cooldown period, initially set to 2 seconds, so that once a cache update has occurred the soonest the next update can occur is 2 seconds later, giving the UI thread time to handle key events. In normal wallet operation state update events do not occur that often, so this approach will allow the cache update to occur as soon as the cache is invalidated but will force a fast subsequent update to wait at least 2 seconds. In the meantime the background data can be worked on in the background thread. * v0.9.2 * Improve cucumber tx reliability * Miningcore Transcoder Update proxy.rs WIP Added stratum configuration to miner. Mostly implemented stratum. Mostly implemented stratum controller. Partially implemented stratum miner. 
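The cache-update cooldown described in the console wallet change above can be sketched as follows. This is a minimal illustration of the gating logic, not the actual AppState code; the type and method names are hypothetical.

```rust
use std::time::{Duration, Instant};

/// Gates how often the UI cache may be refreshed from the background
/// AppState data. Illustrative type, not the wallet's actual struct.
pub struct CacheCooldown {
    cooldown: Duration,
    last_update: Option<Instant>,
}

impl CacheCooldown {
    pub fn new(cooldown: Duration) -> Self {
        Self { cooldown, last_update: None }
    }

    /// Called on a Tick event when the cache has been invalidated.
    /// Returns true (and records the time) if the cooldown has elapsed,
    /// so the expensive background-to-cache copy may proceed; otherwise
    /// the copy is deferred and the UI thread stays responsive.
    pub fn try_update(&mut self) -> bool {
        let now = Instant::now();
        match self.last_update {
            Some(prev) if now.duration_since(prev) < self.cooldown => false,
            _ => {
                self.last_update = Some(now);
                true
            }
        }
    }
}

fn main() {
    let mut cd = CacheCooldown::new(Duration::from_secs(2));
    assert!(cd.try_update());  // first invalidation copies immediately
    assert!(!cd.try_update()); // a fast follow-up must wait
}
```

The first invalidation is served immediately; only rapid follow-ups are deferred, matching the described behaviour under normal (infrequent) event rates.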
Rebased to latest dev Import PR#3006 Rebased to latest dev and updated version to 0.9.0 Fixed tari_stratum_ffi tests Clippy and cargo-fmt Bug fixes Return blockheader as json object. Retrieve recipients from params instead of directly from body of request. Fix bug in GetHeaderByHeight Update stratum miner to receive blockheader instead of block clippy update Update Implemented keepalive Bug fix for transfer results Implemented stratum error code response handling in tari_mining_node Rebase fix Update stratum.rs Update stratum.rs Review Comments Update and Fixes Added ResumeJob to MinerMessage. Fixed disconnection bug where miner would not resume solves on reconnect. Added transcoder_host_address config variable to stop using proxy_host_address as it is already used by mm_proxy; this enables them both to be run simultaneously. Update cucumber config variables * fix: add timeout to protocol notifications + log improvements - protocol notifications now have a set "safety" timeout. - add log for inbound comms pipeline concurrency usage * Fix: Correctly deal with new coinbase transactions for the same height A bug was observed where now and then a coinbase transaction would not conclude its monitoring, failing with this error: `ValueNotFound(CompletedTransaction(15645503741743694020))` This occurred when, due to a small network reorg, the miner requested a block for the same height. The new request would often have a different amount of fees, which meant the existing transaction needed to be cancelled in favour of the new transaction. However, once the transaction was cancelled, the coinbase monitoring task would check the DB to see if the transaction was still there so it could continue to poll the base node, and would get the ValueNotFound error. This is correct behaviour and so should not have been logged as an error. 
The error case also means that the TransactionCancelled event is never fired, which resulted in the console wallet UI not updating the state of that transaction, leaving it in the transaction list erroneously. So this PR handles that “Error” properly and sends the event before ending the coinbase monitoring task. * Improve co-mining cucumber test, increase key length * Added links to mobile wallets' repos * Fix search_kernel command * Stratum Transcoder Config Cleanup Moved stratum_transcoder config variables to its own section. Fixed bug in defaults for stratum_transcoder. Updated example config and cucumber variable to reflect changes in configuration. Added stratum mode configuration variables for tari_mining_node in sample config, commented out by default. * test: cucumber for all wallet commands * fix: in wallet block certain keys during popup * chore(deps): bump tar from 6.1.0 to 6.1.6 in /integration_tests Bumps [tar](https://github.com/npm/node-tar) from 6.1.0 to 6.1.6. - [Release notes](https://github.com/npm/node-tar/releases) - [Changelog](https://github.com/npm/node-tar/blob/main/CHANGELOG.md) - [Commits](https://github.com/npm/node-tar/compare/v6.1.0...v6.1.6) --- updated-dependencies: - dependency-name: tar dependency-type: indirect ... Signed-off-by: dependabot[bot] * simplify and improve test * fix: ban peer when merkle roots mismatch * Fix two cucumber tests to be less flaky - A "FundsPending" error returned from the coin split step will now be handled gracefully and not stop the test - Wallet A will monitor transactions to be at least broadcast when sending one-sided transactions to wallet B before mining commences * Fix UTXO scan edge case when PC awakes from sleep This PR fixes an issue where bottled-up tokio interval events are fired successively in the edge case where a computer wakes up from sleep. 
This was evident in the UTXO scanning service, where many UTXO scanning tasks would be started in parallel afterwards instead of only one. * update to use more sydney addresses * Fix flaky cucumber tests involving transactions Added exception handling to gRPC methods submitting transactions in cucumber integration test steps. Async timing is not always favourable to conclude a transaction; retrying a transaction if it returns an error is important for test robustness. * Washing Machine Update Fixed network bug in walletProcess. Fixed grpcPort bug in walletProcess for custom grpc address. Updated network in washing machine. Added routing mechanism. Added excludeTestEnvrs. Added notifications via webhook if a URL is supplied Update washing_machine.js Reduced output noisiness * feat: shared p2p rpc client session pool Adds an RPC client session pool to efficiently maintain multiple shared RPC sessions on the p2p network. * fix: better method for getting an open port in cucumber tests * fix(cucumber): update to @grpc/grpc-js library Potentially fixes GRPC connection issues we've been having * misc: update package lock * fix(cucumber): retry grpc connect before giving up * fix(cucumber): retry wallet grpc connections * fix wallet resize * feat: wallet connectivity service Adds a service responsible for wallet connectivity. This service is responsible for, and abstracts any complexity in, the management of base node connections and RPC session management. This PR makes use of this service in the base node monitoring service but does not "plumb" the WalletConnectivityService into the protocols. This is left as a TODO, but we should expect it to remove many lines of code and greatly simplify these protocols by removing the burden of connection management from the various wallet components. A number of simplifications were made to the peer connection and substream code, arguably reducing the places that bugs could hide. * Improve prune mode to stop panics. 
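The shared RPC client session pool introduced above can be sketched as a simple lazily-filled, round-robin pool. This is a synchronous stand-in under assumed names; the real pool manages async RPC sessions over the p2p network.

```rust
use std::collections::VecDeque;

/// Stand-in for an RPC client session (the real type wraps a p2p substream).
#[derive(Clone, Debug, PartialEq)]
pub struct Client(pub u32);

/// Hypothetical session pool: creates sessions lazily up to `capacity`,
/// then shares existing sessions round-robin.
pub struct ClientPool {
    capacity: usize,
    sessions: VecDeque<Client>,
    next_id: u32,
}

impl ClientPool {
    pub fn new(capacity: usize) -> Self {
        Self { capacity, sessions: VecDeque::new(), next_id: 0 }
    }

    /// Obtain a session: create a new one while under capacity,
    /// otherwise rotate and reuse an existing shared session.
    pub fn obtain(&mut self) -> Client {
        if self.sessions.len() < self.capacity {
            let c = Client(self.next_id);
            self.next_id += 1;
            self.sessions.push_back(c.clone());
            return c;
        }
        let c = self.sessions.pop_front().unwrap();
        self.sessions.push_back(c.clone());
        c
    }
}

fn main() {
    let mut pool = ClientPool::new(2);
    assert_eq!(pool.obtain(), Client(0));
    assert_eq!(pool.obtain(), Client(1));
    // Pool is at capacity: callers now share sessions round-robin.
    assert_eq!(pool.obtain(), Client(0));
}
```

Sharing a bounded set of sessions avoids the cost of a fresh RPC handshake per caller, which is the point of the pool.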
* fix build binaries * Update handling of SAF message propagation and deletion This PR adds two changes to the way SAF messages are handled, to fix two subtle bugs spotted while developing cucumber tests. The first issue was that when a Node propagated a SAF message it was storing to other nodes in its neighbourhood, the broadcast strategy it used only chose from currently connected base nodes. This meant that if the Node had an active connection to a Communication Client (wallet), it would not send the SAF message directly to that client, only to other base nodes in the network region. The wallet would therefore only receive new SAF messages when it actively requested them on connection, even though it was directly connected to the node. This PR adds a new broadcast strategy called `DirectOrClosestNodes`, which will first check if the node has a direct active connection and, if it does, just send the SAF message directly to its destination. The second issue was a subtle problem where, when a node started to send SAF messages to a destination, it would remove the messages from the database based only on whether the outbound messages were put onto the outbound message pipeline. The problem occurred when the TCP connection to that peer was actually broken: the sending of those messages would fail at the end of the pipeline, but the SAF messages had already been deleted from the database. This PR changes the way SAF messages are deleted. When a client asks a node for SAF messages it will also provide a timestamp of the most recent SAF message it has received. The Node will then send all SAF messages it has for the client since that timestamp, and will delete all SAF messages from before the specified timestamp. This serves as a form of Ack that the client has received the older messages at some point and they are no longer needed. 
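The timestamp-as-Ack deletion scheme above can be sketched as follows. The types and field names here are illustrative, not the actual store-and-forward schema; timestamps are simplified to integers.

```rust
/// Illustrative SAF message: `stored_at` stands in for the real timestamp.
#[derive(Clone, Debug, PartialEq)]
pub struct SafMessage {
    pub stored_at: u64,
    pub body: String,
}

/// Hypothetical per-client SAF store.
pub struct SafStore {
    messages: Vec<SafMessage>,
}

impl SafStore {
    pub fn new(messages: Vec<SafMessage>) -> Self {
        Self { messages }
    }

    /// The client supplies the timestamp of the newest SAF message it has
    /// already received. Messages newer than that are returned for
    /// delivery; messages at or before it are deleted, treating the
    /// timestamp as an implicit Ack. Newer messages stay stored until a
    /// later request acknowledges them.
    pub fn fetch_since(&mut self, since: u64) -> Vec<SafMessage> {
        let newer: Vec<SafMessage> = self
            .messages
            .iter()
            .filter(|m| m.stored_at > since)
            .cloned()
            .collect();
        self.messages.retain(|m| m.stored_at > since);
        newer
    }
}

fn main() {
    let mut store = SafStore::new(vec![
        SafMessage { stored_at: 100, body: "old".into() },
        SafMessage { stored_at: 200, body: "new".into() },
    ]);
    let delivered = store.fetch_since(150);
    assert_eq!(delivered.len(), 1);
    assert_eq!(delivered[0].body, "new");
}
```

Because deletion is driven by the client-supplied timestamp rather than by pipeline submission, a failed TCP send no longer loses undelivered messages.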
* fix flaky unit test * fix(dailys): use non-zero exit status when wallet recovery fails Non-zero error code on failure and return more information about the error. Adds logfile output for processes * Fix wallet CLI cucumber tests (1) When using the wallet CLI mode, the wallet is not always in the correct state to send transactions, depending on the last known metadata status from the previously connected base node. This PR introduces exception handling for wallet CLI commands so that it will automatically retry executing the command if it fails. A follow-up PR should let the wallet wait for a base node connection before it executes certain commands. (2) Added a new test step to explicitly test transaction statuses in the wallet. (3) Fixed the UTXO scanning service so it would not go into an infinite loop trying to scan UTXOs if the blockchain does not have any UTXOs yet. (4) Fixed an erroneous error return in the output manager service. * Update log4rs for tari_stratum_transcoder Fix typos * mark cucumber test as flaky * ci: remove flaky tests from ci * Mining worker name for tari_mining_node Adds optional field to config for tari_mining_node to allow the miner to be named when using the stratum configuration. Updated successful server connection status from warn to info. * chore(deps): bump jszip from 3.6.0 to 3.7.0 in /integration_tests Bumps [jszip](https://github.com/Stuk/jszip) from 3.6.0 to 3.7.0. - [Release notes](https://github.com/Stuk/jszip/releases) - [Changelog](https://github.com/Stuk/jszip/blob/master/CHANGES.md) - [Commits](https://github.com/Stuk/jszip/compare/v3.6.0...v3.7.0) --- updated-dependencies: - dependency-name: jszip dependency-type: direct:production ... Signed-off-by: dependabot[bot] * Set robust limits for a busy blockchain Updated limits for the base node and wallet that would be robust for a busy blockchain - this was simulated with two sizeable stress tests of 15,000 transactions each. 
* test: add wallet ffi testing * Add wallet reorg cucumber tests Added wallet reorg cucumber tests for coinbase and normal transactions. These will be failing until wallets can handle reorgs properly - failing steps have been commented out. * v0.9.3 * libwallet-0.17.3 * Add network selection to wallet_ffi Update lib.rs Review comments * re_add validation test * Cleanup stratum config terminal output in tari_mining_node * Cleanup stratum config terminal output in tari_mining_node * Fix stratum miner speed for tari_mining_node Removed call to thread::sleep accidentally left in while debugging a prior issue. Added nonce to display for solution. Co-Authored-By: Hansie Odendaal <39146854+hansieodendaal@users.noreply.github.com> Review comments Co-Authored-By: Hansie Odendaal <39146854+hansieodendaal@users.noreply.github.com> * Remove old unused integration tests * fix: ensure peers are added to peer list before recovery starts * fix: enforce unique commitments in utxo set Adds a unique commitment db index for the UTXO set as well as a unique commitment check in the block validator. * Update docs with pooled SHA3 mining * Handle receiver cancelling an inbound transaction that is later received This PR addresses the following scenario spotted by @stanimal: - NodeA sends to nodeB(offline) - NodeA goes offline - NodeB receives tx, and cancels it (weird I know) - NodeA comes online and broadcasts the transaction - NodeB is not aware of the transaction, transaction complete for NodeA This is handled by adding logic so that if a FinalizedTransaction is received with no active Receive Protocols, the database is checked for a matching cancelled inbound transaction from the same pubkey. If there is one, the receiver might as well restart that protocol and accept the finalized transaction. A cucumber test is provided to test this case. 
This required adding functionality to the Transaction and Output Manager services to reinstate a cancelled inbound transaction; unit tests are provided for that. * add rfc docs to include unique kernels * Fix console wallet buffer size bug * feat: use nodejs cron for dailies, improve washingmachine reporting - implement cron using `cron` nodejs library - move "utils" to "daily_tests" - daily_tests has its own dependencies and package.json - improve washing machine MM reporting - refactor dailies to allow them to be exported as modules * feat: add sync rpc client pool to wallet connectivity - add sync pool and `obtain_base_node_sync_rpc_client` - add `get_header` to base node rpc * Change RPC connection issues log status Not all pertinent RPC connection issues were logged as warning or error * fix dev after test merges * [skip ci] auto deploy tags to s3 * v0.9.4 * ci: Add libwallet iOS build * ci: Fix libwallet android github action * Expose `get_mempool_stats` via gRPC and add cucumber test This PR exposes the `get_mempool_stats` method via the base node gRPC interface. It also adds a cucumber test to add 5 transactions to the mempool and tests that the base node reports the correct stats. * fix: Fix or remove ignored tests in pow_data.rs This PR aimed to remove the ignore from the tests in pow_data.rs. These are failing tests, so the two that panicked were updated with the `should_panic` flag, but the out-of-memory test aborts and cannot be handled in a test, so it was removed. * Remove `test_harness` feature from wallet In the early days of wallet development we didn’t have a working base node or cucumber infrastructure for the Mobile Clients to be able to test transacting. The `test_harness` functionality in the wallet was created to allow wallet clients to generate test data. This is no longer used so this PR removes this code. This code was also used to generate test data in the wallet_ffi integration test. 
There will soon be Cucumber infrastructure to test the FFI wrapper, which is a far better way to perform this integration test, so that huge test is removed from the FFI library. * Make `send_transaction` request handling async in transaction service It was noted that under stress testing the transaction service select! loop was blocking for longer than 500ms at times. This turned out to be during the send_transaction calls, which did the initial transaction setup synchronously; this involves selecting UTXOs and building the initial transaction, which with a large UTXO database can take time. In order to reduce this impact the `handle_request(…)` function is changed in the transaction service. Instead of calling that function, waiting for a synchronous response, and then sending the response down the reply oneshot channel, the oneshot channel is now passed into the `handle_request(…)` function. For the send_transaction case the intensive work is then moved into the asynchronous `transaction_send_protocol` task and the reply channel is also sent into that task. The task is spawned and runs asynchronously, so the `handle_request` method can end and return to the select! loop. Once the task has completed the initial transaction setup it can send the response to the service API caller via the reply channel. This is only implemented in the asynchronous way for `send_transaction` in this PR; all the other API requests will still do their work synchronously and send the response over the reply channel, but the infrastructure is now there to convert any one of those API calls to reply asynchronously if needed. At the moment the other API requests don’t appear to take long enough to require this effort just yet. `transaction_sender_protocols` are indexed by their TxId, so in order to start the task with the ID before building the initial transaction, the tx_id is now generated in the Transaction Service. 
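The reply-channel handoff described above can be sketched as follows. To keep the sketch self-contained it uses a thread and `std::sync::mpsc` in place of a tokio task and oneshot channel; the function name and placeholder setup work are hypothetical, not the transaction service's actual API.

```rust
use std::sync::mpsc;
use std::thread;

/// Sketch of moving the reply channel into a spawned task so the service
/// event loop is not blocked by expensive transaction construction.
fn handle_send_transaction(amount: u64, reply: mpsc::Sender<Result<u64, String>>) {
    // The expensive work (UTXO selection, building the transaction) runs
    // inside the task; this function returns to the select! loop at once.
    thread::spawn(move || {
        let tx_id = amount.wrapping_mul(31); // placeholder for real setup
        let _ = reply.send(Ok(tx_id));
    });
}

fn main() {
    let (tx, rx) = mpsc::channel();
    handle_send_transaction(10, tx);
    // The API caller awaits the reply; the event loop never blocked.
    assert_eq!(rx.recv().unwrap(), Ok(310));
}
```

The key point is that `handle_send_transaction` returns immediately while ownership of the reply sender moves into the task, so the response still reaches the original caller.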
The Transaction Protocol builder is updated to allow an optional manual specification of the TxId. * fix: correct regexp for recovery and sync tests * fix: improve p2p RPC robustness - Correctly handles edge case where the client can receive a response after the deadline + grace period has expired due to extreme latency. - Minor code simplifications - Adds integration stress tests to comms - Increase client-side deadline grace period. This is needed because at any point we could be transferring a lot of data, causing delays, which the client must tolerate. - Minor performance optimisations (e.g. removed usage of `.split()` which uses a BiLock) * separate ffi tests * test: add more wallet ffi testing * test: Add integration test for ListHeaders * wallet: Add NodeId to console wallet Who Am I tab Just adds the console wallet’s own NodeId to the Who am I tab in the console wallet. * fix: division by zero * [wallet_ffi] Add null check for `transport_type` in FFI The `transport_type` argument is not checked for null before it is used in this method. This PR adds the check and appropriate error response. * wip * fix: bug in wallet base node peer switching - Connectivity retries connecting indefinitely; however, previously if this continuously failed and the user updated their peer, the new peer would not be read until the previous peer had been connected to. This PR fixes this. - Add protocol information to comms RPC logs - Cleanup peer state, making the wallet connectivity service the source of truth for the base node peer - Adds an `Ack` flag to the RPC protocol; this is not currently used but could be implemented on the client side in future if required without breaking the network (server supports it, client support may or may not be needed). * chore: better logging for lmdb_delete * test: cucumber check block heights in order test (#3219) ## Description Cucumber check block heights in order test ## Motivation and Context Cucumber tests ## How Has This Been Tested? 
npm test -- --name "Base node lists heights" ## Checklist: * [x] I'm merging against the `development` branch. * [x] I have squashed my commits into a single commit. * test: almost complete ffi wallet test (#3220) ## Description Almost complete ffi wallet test. - The async base node connection is not there. - The SAF test is broken (the ffi wallet doesn't change the status, but the receiver node has the correct status) ## How Has This Been Tested? npm test -- .\features\WalletFFI.feature ## Checklist: * [x] I'm merging against the `development` branch. * [x] I have squashed my commits into a single commit. * cucumber: Add mempool test for unconfirmed tx to mined tx (#3222) ## Description Add mempool test for unconfirmed tx to mined tx ## Motivation and Context cucumber tests ## How Has This Been Tested? npm test -- --name "Mempool unconfirmed transaction to mined transaction" ## Checklist: * [x] I'm merging against the `development` branch. * [x] I have squashed my commits into a single commit. * fix: chain error caused by zero-conf transactions and reorgs (#3223) ## Description If a zero-conf transaction was in a block and the block is rewound, the block_chain backend will try to delete an input which was never marked as unspent, as it was immediately spent. When we rewind such a block we should check for this and not try to delete an output that was never an unspent output on the chain. ## Motivation and Context ## How Has This Been Tested? ## Checklist: * [x] I'm merging against the `development` branch. * [x] I have squashed my commits into a single commit. * ci: prebuild mining node on cucumber tests (#3221) Co-authored-by: mergequeue[bot] <48659329+mergequeue[bot]@users.noreply.github.com> * Add extra logging detail for wiremode warnings (#3216) ## Description Minor log warning improvements for the case where a wire format byte is not received ## Motivation and Context Adding more log info for some observed warnings ## How Has This Been Tested? 
Replaced existing base node and started syncing a new node successfully ## Checklist: * [x] I'm merging against the `development` branch. * [x] I have squashed my commits into a single commit. * fix: edge-case fixes for wallet peer switching in console wallet (#3226) ## Description - set peer using the watch to allow the connectivity service to immediately be aware of the new peer - aborted the dial early if necessary, should the user set a different peer - slightly reduce busy-ness of the wallet monitor by monitoring for fewer comms connectivity events - monitor for wallet connectivity peer status changes to improve the responsiveness of the status ui update. ## Motivation and Context When a peer is selected, and the previous peer is offline, it appears as if the new peer is offline. This allows the state to be immediately updated (though there is still a delay where the frontend gets refreshed - probably waiting for a tick) The wallet event monitor was kept very busy with all the comms connectivity events, incl. events that have nothing to do with the wallet. ## How Has This Been Tested? Tested on existing wallet ## Checklist: * [x] I'm merging against the `development` branch. * [x] I have squashed my commits into a single commit. * chore: simpler pull request template (#3231) Simpler PR template * feat: add `ping()` to all comms RPC clients (#3227) ## Description Adds a `ping()` function to all comms RPC clients. `ping()` sends an RPC request with the `ACK` flag set. The server will immediately reply with an `ACK` response. This accurately measures RPC latency without a potentially slow backend. `ping()` is now used in the wallet monitor. 
## Motivation and Context Previously, `get_last_request_latency` would refer to the latency of `get_tip_info`, which will increase whenever the blockchain db is busy, e.g.: - the base node is syncing - another node(s) syncing from the base node - one or many wallets scanning UTXOs from the base node - one or many wallets running a recovery from the base node - lots of lmdb writes e.g large reorg A client wallet would, of course, have no idea that this is occurring and simply display poor latency. This could be perceived to be poor RPC performance, when in fact, there are a number of non-network/RPC-related reasons why a ping > 1/2 seconds is displayed. `get_last_request_latency` is a better measure when determining current base node performance (caveats: depending on the RPC method impl, current performance does not predict future performance). However, it is misleading to use this as a user-facing value presented as network latency. ## How Has This Been Tested? Unit test, console wallet test ## Checklist: * [x] I'm merging against the `development` branch. * [x] I have squashed my commits into a single commit. * fix: show warnings on console (#3225) Show warnings and errors on apps as well as logs Also moved a struct that was in the middle of a method into its own file * test: improve comms limits to combat network timeouts (#3224) ## Description This PR: - Improved various comms configuration settings to combat RPC connection and response timeouts - Improved wire mode logging The philosophy here is to rather wait for a connection or response than to abandon all and try again. These settings were tested on two separate systems performing system-level tests where RPC timeouts and connection problems were previously prevalent: - while base node/console wallet pairs were only monitoring the network or were linked to SHA3 or RandomX miners - while performing a stress test and compiling Rust at the same time. 
The former proved to run virtually without any errors while the latter registered some timeouts, especially when performing Rust compilations (~ two orders of magnitude less). **Edit:** - Fixed flaky cucumber test `Node should not sync from pruned node` - Fixed magic numbers in unit test `test_txo_validation_rpc_timeout` ## Motivation and Context See above ## How Has This Been Tested? See above ## Checklist: * [X] I'm merging against the `development` branch. * [X] I have squashed my commits into a single commit. * chore: add extra seed node (#3234) Description --- Add additional seed node Motivation and Context --- N/A How Has This Been Tested? --- Tested with new temporary console wallet * v0.9.5 * ci: add pr title check * test: Add rebuild-db integration test (#3232) Description: Adds a cucumber test to cover the `--rebuild-db` blockchain recovery functionality on the base node Motivation and Context: Improved test coverage How Has This Been Tested? `npm test -- --name "Blockchain database recovery"` * Remove OpenSSL from Windows runtime - Removed OpenSSL installers and dependencies from Windows runtime - Removed stdout as appender from console wallet's log4rs logger * fix: exit command and free up tokio thread (#3235) Description --- Exit command didn't exit the cli loop. Rustyline was holding up a tokio thread - used a blocking thread instead Motivation and Context --- Bug How Has This Been Tested? --- Manually on base node * fix: add status output to logs in non-interactive mode (#3244) Description --- Add status output to logs in non-interactive mode Motivation and Context --- Base node in non-interactive mode was not logging status How Has This Been Tested? --- Manually run base node in non-interactive mode * feat: add tab for error log to the wallet * feat: add support for forcing sync from seeds (#3228) ## Description Add support for forcing sync to a seed node. Specify index for the node from peer_seeds list. ## How Has This Been Tested? Manually. 
## Checklist: * [x] I'm merging against the `development` branch. * [x] I have squashed my commits into a single commit. * test: cucumber forced sync (#3230) ## Description Add cucumber test that tests forced sync to single node. ## How Has This Been Tested? npm test -- --name "Force sync many nodes agains one peer" ## Checklist: * [x] I'm merging against the `development` branch. * [x] I have squashed my commits into a single commit. * test: add multiple wallet recovery from peer (#3240) Description Added multiple wallet recovery from peer Motivation and Context Additional tests How Has This Been Tested? Manually (npm test -- --name "Multiple Wallet recovery from seed node") * feat: base_node prompt user to create id if not found (#3245) Description Prompt user to create id if not found Motivation and Context Improvement to base node startup, specifically on first run. How Has This Been Tested? Manually * fix: daily wallet recovery fixes (#3229) ## Description - Use the GRPC client connection given `WalletProcess` - Outputs error details in webhook in recovery cron job - Adds very basic mocha tests ## Motivation and Context Recovery daily is failing even though it succeeds. ## How Has This Been Tested? ## Checklist: * [x] I'm merging against the `development` branch. * [x] I have squashed my commits into a single commit. * feat!: tell an FFI client that a recovery is in progress on `wallet_create` (#3249) Description --- ### Note: This is a breaking change to LibWallet FFI Currently, if a wallet recovery was in progress and the wallet was shut down, the next time that wallet is started by an FFI client using the ‘wallet_create’ method there is no way for the FFI client to know that the recovery should be continued. The wallet is able to resume the recovery from where it left off, and it should do so as not to lose funds, but the FFI client must restart the recovery process with the same seed words. 
The FFI client has to do the restarting so that it can provide the callback through which the process is monitored. Furthermore, the wallet does not respond to P2P transaction negotiation messages if a recovery process is in progress, so it is important that an FFI client completes any outstanding recoveries ASAP. How Has This Been Tested? --- untested in the backend. * fix: fix base_node_service_config not read (#3251) Description --- Fixed the base_node_service_config not being initialized with values from the config file. Motivation and Context --- See above How Has This Been Tested? --- System level testing * test: add flag to have Cucumber exit when tests are complete (#3252) Description --- Add the `--exit` flag to the Cucumber CI commands to force Cucumber to end when the tests are completed. This doesn’t solve the issue where something is keeping the Cucumber process running due to a poor shutdown though. How Has This Been Tested? --- N/A * docs: rfc staged security (#3246) Description --- This Request for Comment (RFC) aims to describe Tari's ergonomic approach to securing funds in a hot wallet. The focus is on mobile wallets, but the strategy described here is equally applicable to console or desktop wallets. Motivation and Context --- This philosophy has been partially implemented in Aurora already but has not been captured in community documentation before. How Has This Been Tested? --- N/A * test: add tracing to comms via --tracing-enabled (#3238) Description --- Add tracing to comms to debug timings via the `--tracing-enabled` flag Motivation and Context --- It's currently difficult to understand the timings of network calls and errors in the application. How Has This Been Tested? 
--- Tested manually

* refactor: additional DB audit of methods (#2864) ## Description This provides an audit of the following db methods from the `WriteOperation` enum, removing some and increasing the security of the others:

```rust
InsertChainOrphanBlock(Arc),
InsertInput {
    header_hash: HashOutput,
    input: Box,
    mmr_position: u32,
},
InsertKernel {
    header_hash: HashOutput,
    kernel: Box,
    mmr_position: u32,
},
InsertOutput {
    header_hash: HashOutput,
    output: Box,
    mmr_position: u32,
},
DeleteHeader(u64),
DeleteOrphanChainTip(HashOutput),
InsertOrphanChainTip(HashOutput),
SetBestBlock {
    height: u64,
    hash: HashOutput,
    accumulated_difficulty: u128,
},
SetPruningHorizonConfig(u64),
SetPrunedHeight {
    height: u64,
    kernel_sum: Commitment,
    utxo_sum: Commitment,
},
```

## Motivation and Context ## How Has This Been Tested? Synced to tip and passed all unit tests. ## Checklist: * [x] I'm merging against the `development` branch. * [ ] I have squashed my commits into a single commit.

* feat: allow DHT to be configured to repropagate messages for a number of rounds (#3211) ## Description Use the dedup cache hit count to allow certain duplicate messages through a configurable number of times. ## Motivation and Context ~~This improves mempool synchronization.~~ Implements gossip repropagation that could be used for some message types in future. ## How Has This Been Tested? New unit test. More manual system tests need to be done. ## Checklist: * [x] I'm merging against the `development` branch. * [x] I have squashed my commits into a single commit.

* refactor: refactor wallet ffi cucumber tests (#3259) Description Refactored WalletFFI.feature into a working state, tested locally. Further refactoring and dead code removal would be beneficial. Motivation and Context Necessary to get WalletFFI.feature working. How Has This Been Tested?
Tested locally; each scenario tested with: `./node_modules/.bin/cucumber-js --name "${scenario_name}"`

* fix: add periodic connection check to wallet connectivity service (#3237) Description --- - Adds a periodic check of the connection status that attempts a reconnect if no longer connected. Previously it was assumed that this could be done lazily because some caller would always call `obtain_base_node_wallet_rpc_client`, but this may not be the case, so a periodic check is added. - Cleans up some state checking to use the wallet connectivity service. Motivation and Context --- Improves snappiness of the connectivity and chain state updates in the wallet. How Has This Been Tested? --- Manually on the console wallet + existing tests

* fix: send transactions to all connected peers (#3239) Description --- Send transactions to all connected peers, as we do with block propagation. Motivation and Context --- Alternative to #3211. How Has This Been Tested? --- Existing tests / single-line code change

* test: add random transactions to empty cucumber blocks (#3253) Description --- This adds random transactions, spending each other and unused coinbases, to fill in blocks. Motivation and Context --- This allows us to test more thoroughly, with all blocks having transactions and not just the blocks that were explicitly created with transactions. These are limited to 10 transactions per block so as not to make it too slow at the current validation speeds. This might be revisited at a later stage. How Has This Been Tested? --- Manually confirmed that the blocks do have the transactions in them. Ran all cucumber tests with the flags: critical and not broken and not flaky

* feat: improve basenode switch from listening to lagging mode (#3255) Description --- This PR changes the pace at which peer chain metadata is pushed to the listening state: it is now pushed every time a chain metadata ping or pong message is received.
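> Note: a minimal sketch of this change, with hypothetical names; the actual base node event types differ.

```rust
// Hypothetical sketch of the #3255 change; the event and function names are
// illustrative, not the actual base node types.
#[derive(Debug, PartialEq)]
pub enum LivenessEvent {
    ReceivedPing,
    ReceivedPong,
    Other,
}

/// Previously, metadata was pushed to the listening state only once all the
/// pings and pongs had been collected; now every chain metadata ping or pong
/// triggers a push, so a lagging chain is noticed sooner.
pub fn push_metadata_on(event: &LivenessEvent) -> bool {
    matches!(
        event,
        LivenessEvent::ReceivedPing | LivenessEvent::ReceivedPong
    )
}
```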
Motivation and Context --- This is introduced to allow a node to switch faster, rather than waiting until it has received all the pings and pongs from a node. How Has This Been Tested? --- Ran all unit tests and manually ran a node.

* fix: small display bug (#3257) Description --- The escape sequence was eating up the string "Starting recovery at height: ". How Has This Been Tested? --- Manually/visually.

* feat: add Igor testnet (#3256) Description --- This PR adds support for the Igor testnet to the repo. This involves adding Igor to the Network enum, adding an Igor genesis block, and adding a config file with the details of 4 Igor seed nodes (still to be rolled out). Motivation and Context --- We need a second testnet to test network switching. How Has This Been Tested? --- Manually ran the network to generate seed node details.

* chore: tokio 1 and other crate upgrades (#3258) Description --- - upgrades to tokio 1.10 - upgrades multiaddr to 1.13 - updates select loops to use tokio::select! - updates to use tokio mpsc and oneshot channels - removes the max_threads config - removes the tari_wallet dependency from the tari base node - moves the emoji id library out of tari wallet into tari core (in order to remove the dependency on `tari_wallet` for the tari base node) - moves waiting for bootstrap with mempool sync to the initializer - unit and integration test fixup - upgrades the following crates that use or are required by tokio 1: `bytes`, `prost`, `tonic`, `reqwest`, `hyper`, `trust-dns-client` ~~Include changes from https://github.com/tari-project/tari/pull/3237~~ merged Motivation and Context --- The Tokio runtime is perhaps the most critical dependency we have and was very out of date (0.2.x). This PR takes advantage of the bug fixes and optimisations of tokio 1.10. How Has This Been Tested?
--- - Existing unit and integration tests run and pass - Existing cucumber tests pass - Ran all tari applications (base node, console wallet, miner, mm proxy, stratum transcoder) - Ran a washing machine test on two upgraded wallets connected to an upgraded base node

* fix: off-by-one causing "no further headers to download" bug (#3264) Description --- When entering the `synchronize_headers` function, a chain of headers has been downloaded and validated but not committed. If there are fewer than 1000 (not equal to, as before), the function can exit without streaming more, as there are no more to send. This PR correctly handles the case where the node is exactly 1000 headers behind by: (1) correcting the off-by-one "no further headers to download" conditional, and (2) committing headers before starting streaming if the PoW is stronger, in case no further headers would be streamed. Motivation and Context --- Header sync ends prematurely when receiving exactly 1000 "pre-sync" headers. How Has This Been Tested? --- Manually - sync from scratch.

* fix: revert mining_node default logging config (#3262) Description --- The mining_node relies on stdout logging for the binary's output, and a recent global update to the logging config filtered debug and info messages out of stdout. This PR updates the default logging config for the mining node so that debug and info messages are logged to stdout again. How Has This Been Tested? --- Manually

* feat: allow network to be selected at application start (#3247) Description Network selection for applications. Motivation and Context Allows the network to be selected at application start. How Has This Been Tested? Manually

* feat: add ability to bypass rangeproof (#3265) Description --- Adds the ability to bypass rangeproof verification. Motivation and Context --- Warning: This should not be done by default as it can cause a fork.
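> Note: the intended default for this bypass can be sketched as below. `ValidationConfig` and `should_verify_range_proofs` are hypothetical names, not the actual validation configuration.

```rust
// Illustrative sketch of a rangeproof-bypass flag; the names here are
// hypothetical, not the actual consensus/validation configuration.
pub struct ValidationConfig {
    // Per the PR's warning: verification must remain the default, since
    // skipping it can cause a fork.
    pub bypass_range_proof_verification: bool,
}

impl Default for ValidationConfig {
    fn default() -> Self {
        Self {
            bypass_range_proof_verification: false,
        }
    }
}

// Verification runs unless the operator has explicitly opted out.
pub fn should_verify_range_proofs(cfg: &ValidationConfig) -> bool {
    !cfg.bypass_range_proof_verification
}
```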
By default this should always be set to verify rangeproofs, but in some scenarios you may want to disable it to quickly download a chain or to run on a slim device. Rangeproof verification also takes the majority of the time when profiling, so by disabling it we can monitor other performance bottlenecks. How Has This Been Tested? --- Manually > Note that I disabled checking of rangeproofs during wallet sending because it adds little value to validate a rangeproof that you created

* test: add trace tag to liveness data (#3269) Description --- Added trace tag info to the liveness log messages for improved tracing of ping-pong messages. Motivation and Context --- This will help to investigate why ping-pong messages are not robust when using a single forced sync peer. How Has This Been Tested? --- System-level testing

* fix: auto update continuously checks auto_update_check_interval is disabled (#3270) Description --- The update check ran continuously when auto_update_check_interval was disabled. Thanks @mikethetike for finding it and for the fix. Also adds a check for when no auto update URIs are configured. Motivation and Context --- Bug fix. When check_interval is disabled, stream::empty() is used to disable the update checking; however, it returns None continuously when polled, causing the update check to run continuously. Also sets MissedTickBehavior::Skip, which prevents bursts of checks if intervals are missed. How Has This Been Tested? --- Ran the base node with auto_update_check_interval = 0 (or, equivalently, without this setting set)

* fix: remove cucumber walletffi.js file that got re-included in rebase (#3271) Description Code was moved to ffiInterface.js and updated. It mistakenly got re-included when fixing a conflict in a rebase. Motivation and Context --- How Has This Been Tested?
---

* test: early subscription to connectivity events for mempool sync (#3272) Description --- Subscribe to connectivity events before waiting for the state machine to bootstrap. Motivation and Context --- The late subscription causes the cucumber scenario `Transactions are synced` to fail. It could also cause mempool sync not to happen in some fairly unlikely but possible cases in the base node. How Has This Been Tested? --- The cucumber scenario `Transactions are synced` passes.

* refactor: enable compile without sqlite, move emoji id and common types to tari_common_types (#3266) Description --- Moves emoji id and common types to tari_common_types. Motivation and Context --- The main problem here was a `tari_base_node` -> `tari_wallet` dependency for `EmojiId`. EmojiId in turn references PublicKey, so a whole bunch of types ended up moving around. > Note: Hidden in all of this is a feature to compile SQLite without having it installed as a lib. How Has This Been Tested? --- Manually

* fix: make logging less noisy (#3267) Description --- Remove logging of errors from tracing instrument macros. Motivation and Context --- The noise was reported as making the base node unusable. Hopefully we are not swallowing important information, but this is probably the right choice. How Has This Been Tested? --- Manually > ~~Note: This PR is based on #3266 to enable compilation without SQLite installed~~

* chore: add network to Base Node status line (#3278) Description --- We want to see the network in the base node status line. How Has This Been Tested?
--- manually Co-authored-by: Philip Robinson Co-authored-by: mergequeue[bot] <48659329+mergequeue[bot]@users.noreply.github.com> Co-authored-by: SW van Heerden Co-authored-by: Mike the Tike Co-authored-by: striderDM <51991544+StriderDM@users.noreply.github.com> Co-authored-by: Stanimal Co-authored-by: mongolsteppe <75075420+mongolsteppe@users.noreply.github.com> Co-authored-by: Martin Stefcek <35243812+Cifko@users.noreply.github.com> Co-authored-by: Martin Stefcek Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Hansie Odendaal Co-authored-by: Hansie Odendaal <39146854+hansieodendaal@users.noreply.github.com> Co-authored-by: Cayle Sharrock --- .circleci/config.yml | 4 +- Cargo.lock | 860 +++++----- Cargo.toml | 3 + applications/ffi_client/index.js | 2 + applications/tari_app_grpc/Cargo.toml | 13 +- .../src/conversions/block_header.rs | 3 +- .../src/conversions/com_signature.rs | 4 +- .../tari_app_grpc/src/conversions/mod.rs | 2 +- .../src/conversions/new_block_template.rs | 4 +- .../tari_app_grpc/src/conversions/peer.rs | 2 +- .../src/conversions/signature.rs | 4 +- .../src/conversions/transaction.rs | 65 +- .../src/conversions/transaction_input.rs | 6 +- .../src/conversions/transaction_kernel.rs | 2 +- .../src/conversions/transaction_output.rs | 15 +- .../src/conversions/unblinded_output.rs | 15 +- applications/tari_app_utilities/Cargo.toml | 15 +- .../src/identity_management.rs | 13 +- .../tari_app_utilities/src/initialization.rs | 31 +- .../tari_app_utilities/src/utilities.rs | 91 +- applications/tari_base_node/Cargo.toml | 22 +- applications/tari_base_node/src/bootstrap.rs | 10 +- applications/tari_base_node/src/builder.rs | 28 +- .../tari_base_node/src/command_handler.rs | 28 +- .../src/grpc/base_node_grpc_server.rs | 52 +- applications/tari_base_node/src/main.rs | 34 +- applications/tari_base_node/src/parser.rs | 8 +- applications/tari_base_node/src/recovery.rs | 22 +- 
applications/tari_console_wallet/Cargo.toml | 16 +- .../src/automation/command_parser.rs | 6 +- .../src/automation/commands.rs | 227 +-- .../src/grpc/wallet_grpc_server.rs | 7 +- .../tari_console_wallet/src/init/mod.rs | 26 +- applications/tari_console_wallet/src/main.rs | 12 +- .../tari_console_wallet/src/recovery.rs | 17 +- .../src/ui/components/base_node.rs | 9 +- .../src/ui/state/app_state.rs | 74 +- .../tari_console_wallet/src/ui/state/tasks.rs | 47 +- .../src/ui/state/wallet_event_monitor.rs | 47 +- .../tari_console_wallet/src/ui/ui_contact.rs | 3 +- .../tari_console_wallet/src/utils/db.rs | 3 +- .../tari_console_wallet/src/wallet_modes.rs | 5 +- .../tari_merge_mining_proxy/Cargo.toml | 21 +- .../tari_merge_mining_proxy/src/main.rs | 2 +- applications/tari_mining_node/Cargo.toml | 10 +- applications/tari_mining_node/src/main.rs | 20 +- .../tari_stratum_transcoder/Cargo.toml | 22 +- .../tari_stratum_transcoder/src/main.rs | 2 +- applications/test_faucet/Cargo.toml | 10 +- applications/test_faucet/src/main.rs | 26 +- base_layer/common_types/Cargo.toml | 4 +- .../src/util => common_types/src}/emoji.rs | 11 +- base_layer/common_types/src/lib.rs | 5 + .../src/util => common_types/src}/luhn.rs | 2 +- base_layer/common_types/src/types.rs | 24 - .../src/types/bullet_rangeproofs.rs | 110 ++ base_layer/common_types/src/types/mod.rs | 81 + .../common_types/src/waiting_requests.rs | 3 +- base_layer/core/Cargo.toml | 14 +- .../chain_metadata_service/initializer.rs | 11 +- .../chain_metadata_service/service.rs | 99 +- .../comms_interface/comms_request.rs | 8 +- .../comms_interface/comms_response.rs | 7 +- .../comms_interface/inbound_handlers.rs | 4 +- .../comms_interface/local_interface.rs | 6 +- .../comms_interface/outbound_interface.rs | 17 +- .../core/src/base_node/proto/request.rs | 2 +- .../core/src/base_node/proto/wallet_rpc.rs | 3 +- base_layer/core/src/base_node/rpc/service.rs | 5 +- .../core/src/base_node/service/initializer.rs | 6 +- 
.../core/src/base_node/service/service.rs | 53 +- .../state_machine_service/initializer.rs | 18 +- .../state_machine_service/state_machine.rs | 3 +- .../states/block_sync.rs | 4 +- .../states/events_and_states.rs | 60 +- .../states/header_sync.rs | 15 +- .../states/horizon_state_sync.rs | 23 +- .../horizon_state_synchronization.rs | 6 +- .../state_machine_service/states/listening.rs | 12 +- .../state_machine_service/states/waiting.rs | 4 +- .../sync/header_sync/synchronizer.rs | 62 +- .../base_node/sync/header_sync/validator.rs | 14 +- base_layer/core/src/base_node/sync/hooks.rs | 8 +- .../core/src/base_node/sync/rpc/service.rs | 48 +- .../src/base_node/sync/rpc/sync_utxos_task.rs | 13 +- .../core/src/base_node/sync/rpc/tests.rs | 6 +- .../core/src/base_node/sync/validators.rs | 13 +- base_layer/core/src/blocks/block.rs | 23 +- base_layer/core/src/blocks/block_header.rs | 8 +- base_layer/core/src/blocks/genesis_block.rs | 92 +- .../src/blocks/new_blockheader_template.rs | 3 +- .../src/chain_storage/accumulated_data.rs | 6 +- base_layer/core/src/chain_storage/async_db.rs | 10 +- .../src/chain_storage/blockchain_backend.rs | 10 +- .../src/chain_storage/blockchain_database.rs | 10 +- .../core/src/chain_storage/db_transaction.rs | 7 +- .../src/chain_storage/historical_block.rs | 2 +- .../core/src/chain_storage/horizon_data.rs | 3 +- .../core/src/chain_storage/lmdb_db/lmdb_db.rs | 3 +- .../core/src/chain_storage/lmdb_db/mod.rs | 6 +- .../core/src/chain_storage/pruned_output.rs | 3 +- .../core/src/consensus/consensus_constants.rs | 32 + .../core/src/consensus/consensus_manager.rs | 2 + base_layer/core/src/consensus/network.rs | 1 + base_layer/core/src/lib.rs | 2 +- base_layer/core/src/mempool/async_mempool.rs | 3 +- base_layer/core/src/mempool/mempool.rs | 3 +- .../core/src/mempool/mempool_storage.rs | 3 +- base_layer/core/src/mempool/mod.rs | 3 +- .../priority/prioritized_transaction.rs | 6 +- .../core/src/mempool/proto/state_response.rs | 6 +- 
.../core/src/mempool/reorg_pool/reorg_pool.rs | 3 +- .../mempool/reorg_pool/reorg_pool_storage.rs | 7 +- base_layer/core/src/mempool/rpc/test.rs | 12 +- base_layer/core/src/mempool/service/handle.rs | 3 +- .../src/mempool/service/inbound_handlers.rs | 2 +- .../core/src/mempool/service/initializer.rs | 14 +- .../core/src/mempool/service/local_service.rs | 7 +- .../src/mempool/service/outbound_interface.rs | 21 +- .../core/src/mempool/service/request.rs | 4 +- .../core/src/mempool/service/service.rs | 49 +- .../src/mempool/sync_protocol/initializer.rs | 38 +- .../core/src/mempool/sync_protocol/mod.rs | 47 +- .../core/src/mempool/sync_protocol/test.rs | 32 +- .../unconfirmed_pool/unconfirmed_pool.rs | 31 +- base_layer/core/src/proto/block.rs | 3 +- base_layer/core/src/proto/block_header.rs | 2 +- base_layer/core/src/proto/transaction.rs | 3 +- base_layer/core/src/proto/types_impls.rs | 4 +- .../core/src/test_helpers/blockchain.rs | 28 +- base_layer/core/src/test_helpers/mod.rs | 17 +- .../core/src/transactions/aggregated_body.rs | 38 +- .../src/transactions/bullet_rangeproofs.rs | 2 +- .../core/src/transactions/coinbase_builder.rs | 34 +- .../core/src/transactions/crypto_factories.rs | 45 + base_layer/core/src/transactions/helpers.rs | 29 +- base_layer/core/src/transactions/mod.rs | 13 +- .../core/src/transactions/transaction.rs | 92 +- .../transactions/transaction_protocol/mod.rs | 8 +- .../proto/recipient_signed_message.rs | 3 +- .../proto/transaction_sender.rs | 2 +- .../transaction_protocol/recipient.rs | 23 +- .../transaction_protocol/sender.rs | 56 +- .../transaction_protocol/single_receiver.rs | 32 +- .../transaction_initializer.rs | 52 +- base_layer/core/src/transactions/types.rs | 96 -- .../core/src/validation/block_validators.rs | 46 +- .../core/src/validation/chain_balance.rs | 14 +- base_layer/core/src/validation/error.rs | 3 +- base_layer/core/src/validation/helpers.rs | 16 +- base_layer/core/src/validation/mocks.rs | 4 +- 
base_layer/core/src/validation/test.rs | 12 +- base_layer/core/src/validation/traits.rs | 4 +- .../src/validation/transaction_validators.rs | 15 +- base_layer/core/tests/async_db.rs | 13 +- base_layer/core/tests/base_node_rpc.rs | 102 +- base_layer/core/tests/block_validation.rs | 15 +- .../chain_storage_tests/chain_storage.rs | 40 +- .../core/tests/helpers/block_builders.rs | 17 +- base_layer/core/tests/helpers/database.rs | 5 +- base_layer/core/tests/helpers/event_stream.rs | 17 +- .../core/tests/helpers/mock_state_machine.rs | 2 +- base_layer/core/tests/helpers/nodes.rs | 91 +- .../core/tests/helpers/sample_blockchains.rs | 7 +- .../core/tests/helpers/test_blockchain.rs | 21 +- base_layer/core/tests/mempool.rs | 439 +++-- base_layer/core/tests/node_comms_interface.rs | 63 +- base_layer/core/tests/node_service.rs | 588 ++++--- base_layer/core/tests/node_state_machine.rs | 109 +- base_layer/key_manager/Cargo.toml | 2 +- base_layer/mmr/Cargo.toml | 2 +- base_layer/p2p/Cargo.toml | 36 +- base_layer/p2p/examples/gen_tor_identity.rs | 2 +- base_layer/p2p/src/auto_update/dns.rs | 42 +- base_layer/p2p/src/auto_update/mod.rs | 8 +- base_layer/p2p/src/auto_update/service.rs | 42 +- .../src/comms_connector/inbound_connector.rs | 74 +- base_layer/p2p/src/comms_connector/pubsub.rs | 78 +- base_layer/p2p/src/dns/client.rs | 139 +- base_layer/p2p/src/dns/mock.rs | 105 ++ base_layer/p2p/src/dns/mod.rs | 3 + base_layer/p2p/src/initialization.rs | 33 +- base_layer/p2p/src/lib.rs | 2 - base_layer/p2p/src/peer_seeds.rs | 72 +- base_layer/p2p/src/services/liveness/mock.rs | 7 +- .../p2p/src/services/liveness/service.rs | 77 +- base_layer/p2p/tests/services/liveness.rs | 19 +- .../p2p/tests/support/comms_and_services.rs | 18 +- base_layer/service_framework/Cargo.toml | 8 +- .../examples/services/service_a.rs | 4 +- .../examples/services/service_b.rs | 8 +- .../examples/stack_builder_example.rs | 8 +- .../service_framework/src/reply_channel.rs | 57 +- 
base_layer/service_framework/src/stack.rs | 4 +- base_layer/tari_stratum_ffi/Cargo.toml | 2 +- base_layer/wallet/Cargo.toml | 38 +- .../wallet/src/base_node_service/handle.rs | 5 +- .../mock_base_node_service.rs | 21 +- .../wallet/src/base_node_service/monitor.rs | 59 +- .../wallet/src/base_node_service/service.rs | 5 +- base_layer/wallet/src/config.rs | 8 +- .../wallet/src/connectivity_service/handle.rs | 10 +- .../src/connectivity_service/initializer.rs | 11 +- .../src/connectivity_service/service.rs | 97 +- .../wallet/src/connectivity_service/test.rs | 22 +- .../wallet/src/connectivity_service/watch.rs | 15 +- .../wallet/src/contacts_service/service.rs | 10 +- .../src/contacts_service/storage/sqlite_db.rs | 4 +- base_layer/wallet/src/lib.rs | 2 - .../src/output_manager_service/handle.rs | 7 +- .../master_key_manager.rs | 6 +- .../wallet/src/output_manager_service/mod.rs | 31 +- .../recovery/standard_outputs_recoverer.rs | 18 +- .../src/output_manager_service/resources.rs | 10 +- .../src/output_manager_service/service.rs | 82 +- .../storage/database.rs | 7 +- .../output_manager_service/storage/models.rs | 14 +- .../storage/sqlite_db.rs | 96 +- .../tasks/txo_validation_task.rs | 49 +- base_layer/wallet/src/storage/database.rs | 2 +- .../wallet/src/transaction_service/error.rs | 2 +- .../wallet/src/transaction_service/handle.rs | 5 +- .../wallet/src/transaction_service/mod.rs | 41 +- .../transaction_broadcast_protocol.rs | 53 +- ...ransaction_coinbase_monitoring_protocol.rs | 58 +- .../protocols/transaction_receive_protocol.rs | 40 +- .../protocols/transaction_send_protocol.rs | 31 +- .../transaction_validation_protocol.rs | 34 +- .../wallet/src/transaction_service/service.rs | 108 +- .../transaction_service/storage/database.rs | 3 +- .../src/transaction_service/storage/models.rs | 2 +- .../transaction_service/storage/sqlite_db.rs | 88 +- ...tion_validation_and_broadcast_protocols.rs | 35 +- base_layer/wallet/src/util/mod.rs | 2 - 
.../wallet/src/utxo_scanner_service/mod.rs | 2 +- .../src/utxo_scanner_service/utxo_scanning.rs | 89 +- base_layer/wallet/src/wallet.rs | 74 +- .../wallet/tests/contacts_service/mod.rs | 2 +- .../tests/output_manager_service/service.rs | 1366 ++++++++------- .../tests/output_manager_service/storage.rs | 21 +- .../tests/support/comms_and_services.rs | 15 +- base_layer/wallet/tests/support/rpc.rs | 31 +- base_layer/wallet/tests/support/utils.rs | 2 +- .../tests/transaction_service/service.rs | 549 +++--- .../tests/transaction_service/storage.rs | 21 +- .../transaction_protocols.rs | 131 +- base_layer/wallet/tests/wallet/mod.rs | 97 +- base_layer/wallet_ffi/Cargo.toml | 6 +- base_layer/wallet_ffi/src/callback_handler.rs | 38 +- base_layer/wallet_ffi/src/lib.rs | 113 +- base_layer/wallet_ffi/src/tasks.rs | 8 +- common/Cargo.toml | 4 +- common/config/presets/tari_igor_config.toml | 535 ++++++ common/logging/log4rs_sample_mining_node.yml | 21 +- common/src/configuration/bootstrap.rs | 4 + common/src/configuration/global.rs | 10 +- common/src/configuration/network.rs | 3 + common/src/configuration/utils.rs | 49 +- common/src/dns/tests.rs | 4 +- common/src/lib.rs | 2 +- comms/Cargo.toml | 37 +- comms/dht/Cargo.toml | 39 +- .../examples/graphing_utilities/utilities.rs | 5 +- comms/dht/examples/memory_net/drain_burst.rs | 7 +- comms/dht/examples/memory_net/utilities.rs | 99 +- comms/dht/examples/memorynet.rs | 11 +- ...rynet_graph_network_join_multiple_seeds.rs | 6 +- .../memorynet_graph_network_track_join.rs | 6 +- ...morynet_graph_network_track_propagation.rs | 6 +- comms/dht/src/actor.rs | 225 +-- comms/dht/src/builder.rs | 7 +- comms/dht/src/config.rs | 6 + comms/dht/src/connectivity/metrics.rs | 33 +- comms/dht/src/connectivity/mod.rs | 37 +- comms/dht/src/connectivity/test.rs | 15 +- comms/dht/src/dedup/dedup_cache.rs | 114 +- comms/dht/src/dedup/mod.rs | 41 +- comms/dht/src/dht.rs | 96 +- comms/dht/src/discovery/error.rs | 16 +- comms/dht/src/discovery/requester.rs | 
9 +- comms/dht/src/discovery/service.rs | 54 +- comms/dht/src/domain_message.rs | 2 +- comms/dht/src/envelope.rs | 13 +- .../src/{tower_filter => filter}/future.rs | 31 +- .../dht/src/{tower_filter => filter}/layer.rs | 0 comms/dht/src/{tower_filter => filter}/mod.rs | 4 +- comms/dht/src/filter/predicate.rs | 13 + comms/dht/src/inbound/decryption.rs | 9 +- comms/dht/src/inbound/deserialize.rs | 7 +- comms/dht/src/inbound/dht_handler/task.rs | 14 + comms/dht/src/inbound/message.rs | 15 +- comms/dht/src/lib.rs | 4 +- comms/dht/src/network_discovery/on_connect.rs | 11 +- comms/dht/src/network_discovery/test.rs | 18 +- comms/dht/src/network_discovery/waiting.rs | 2 +- comms/dht/src/outbound/broadcast.rs | 36 +- comms/dht/src/outbound/error.rs | 12 +- comms/dht/src/outbound/message.rs | 2 +- comms/dht/src/outbound/message_params.rs | 10 +- comms/dht/src/outbound/message_send_state.rs | 12 +- comms/dht/src/outbound/mock.rs | 20 +- comms/dht/src/outbound/requester.rs | 5 +- comms/dht/src/outbound/serialize.rs | 4 +- comms/dht/src/rpc/service.rs | 10 +- comms/dht/src/rpc/test.rs | 36 +- comms/dht/src/storage/connection.rs | 5 +- comms/dht/src/storage/error.rs | 2 + comms/dht/src/store_forward/database/mod.rs | 7 +- comms/dht/src/store_forward/forward.rs | 28 +- .../src/store_forward/saf_handler/layer.rs | 2 +- .../store_forward/saf_handler/middleware.rs | 3 +- .../dht/src/store_forward/saf_handler/task.rs | 37 +- comms/dht/src/store_forward/service.rs | 47 +- comms/dht/src/store_forward/store.rs | 43 +- comms/dht/src/test_utils/dht_actor_mock.rs | 34 +- .../dht/src/test_utils/dht_discovery_mock.rs | 13 +- .../src/test_utils/store_and_forward_mock.rs | 14 +- comms/dht/src/tower_filter/predicate.rs | 25 - comms/dht/tests/dht.rs | 304 +++- comms/examples/stress/error.rs | 16 +- comms/examples/stress/node.rs | 3 +- comms/examples/stress/service.rs | 69 +- comms/examples/stress_test.rs | 8 +- comms/examples/tor.rs | 14 +- comms/rpc_macros/Cargo.toml | 8 +- 
comms/rpc_macros/src/generator.rs | 4 +- comms/rpc_macros/tests/macro.rs | 17 +- comms/src/bounded_executor.rs | 16 +- comms/src/builder/comms_node.rs | 6 +- comms/src/builder/mod.rs | 3 +- comms/src/builder/tests.rs | 53 +- comms/src/compat.rs | 11 +- comms/src/connection_manager/dial_state.rs | 2 +- comms/src/connection_manager/dialer.rs | 61 +- comms/src/connection_manager/error.rs | 8 +- comms/src/connection_manager/listener.rs | 53 +- comms/src/connection_manager/liveness.rs | 14 +- comms/src/connection_manager/manager.rs | 33 +- .../src/connection_manager/peer_connection.rs | 47 +- comms/src/connection_manager/requester.rs | 12 +- .../tests/listener_dialer.rs | 39 +- comms/src/connection_manager/tests/manager.rs | 56 +- comms/src/connection_manager/types.rs | 31 - comms/src/connectivity/manager.rs | 31 +- comms/src/connectivity/requester.rs | 31 +- comms/src/connectivity/selection.rs | 2 +- comms/src/connectivity/test.rs | 68 +- comms/src/framing.rs | 7 +- comms/src/lib.rs | 6 +- comms/src/memsocket/mod.rs | 223 +-- comms/src/message/outbound.rs | 2 +- comms/src/multiplexing/yamux.rs | 150 +- comms/src/noise/config.rs | 67 +- comms/src/noise/socket.rs | 112 +- comms/src/peer_manager/manager.rs | 8 +- comms/src/pipeline/builder.rs | 12 +- comms/src/pipeline/inbound.rs | 44 +- comms/src/pipeline/mod.rs | 3 - comms/src/pipeline/outbound.rs | 52 +- comms/src/pipeline/sink.rs | 24 +- comms/src/pipeline/translate_sink.rs | 5 +- comms/src/protocol/identity.rs | 20 +- comms/src/protocol/messaging/error.rs | 11 +- comms/src/protocol/messaging/extension.rs | 2 +- comms/src/protocol/messaging/forward.rs | 110 ++ comms/src/protocol/messaging/inbound.rs | 43 +- comms/src/protocol/messaging/mod.rs | 2 +- comms/src/protocol/messaging/outbound.rs | 71 +- comms/src/protocol/messaging/protocol.rs | 49 +- comms/src/protocol/messaging/test.rs | 57 +- comms/src/protocol/negotiation.rs | 12 +- comms/src/protocol/protocols.rs | 9 +- comms/src/protocol/rpc/body.rs | 16 +- 
comms/src/protocol/rpc/client.rs | 153 +- comms/src/protocol/rpc/client_pool.rs | 9 + comms/src/protocol/rpc/handshake.rs | 9 +- comms/src/protocol/rpc/mod.rs | 3 +- comms/src/protocol/rpc/server/error.rs | 6 +- comms/src/protocol/rpc/server/handle.rs | 5 +- comms/src/protocol/rpc/server/mock.rs | 17 +- comms/src/protocol/rpc/server/mod.rs | 40 +- comms/src/protocol/rpc/server/router.rs | 9 +- comms/src/protocol/rpc/test/client_pool.rs | 14 +- .../protocol/rpc/test/comms_integration.rs | 2 +- .../src/protocol/rpc/test/greeting_service.rs | 29 +- comms/src/protocol/rpc/test/handshake.rs | 4 +- comms/src/protocol/rpc/test/smoke.rs | 44 +- comms/src/{common => }/rate_limit.rs | 62 +- comms/src/runtime.rs | 5 +- comms/src/socks/client.rs | 8 +- .../test_utils/mocks/connection_manager.rs | 11 +- .../test_utils/mocks/connectivity_manager.rs | 40 +- comms/src/test_utils/mocks/peer_connection.rs | 13 +- comms/src/test_utils/test_node.rs | 6 +- comms/src/tor/control_client/client.rs | 20 +- comms/src/tor/control_client/monitor.rs | 28 +- comms/src/tor/control_client/test_server.rs | 10 +- comms/src/tor/hidden_service/controller.rs | 2 +- comms/src/transports/dns/tor.rs | 2 +- comms/src/transports/memory.rs | 3 +- comms/src/transports/mod.rs | 4 +- comms/src/transports/socks.rs | 13 +- comms/src/transports/tcp.rs | 130 +- comms/src/transports/tcp_with_tor.rs | 12 +- comms/src/utils/mod.rs | 1 + comms/src/{common/mod.rs => utils/mpsc.rs} | 14 +- comms/tests/greeting_service.rs | 18 +- comms/tests/rpc_stress.rs | 3 +- comms/tests/substream_stress.rs | 8 +- infrastructure/shutdown/Cargo.toml | 4 +- infrastructure/shutdown/src/lib.rs | 171 +- .../shutdown/src/oneshot_trigger.rs | 106 ++ infrastructure/storage/Cargo.toml | 4 +- infrastructure/test_utils/Cargo.toml | 3 +- .../src/futures/async_assert_eventually.rs | 4 +- infrastructure/test_utils/src/runtime.rs | 10 +- infrastructure/test_utils/src/streams/mod.rs | 159 +- integration_tests/features/Mempool.feature | 3 +- 
integration_tests/features/Reorgs.feature | 2 + integration_tests/features/StressTest.feature | 6 +- integration_tests/features/WalletFFI.feature | 165 +- integration_tests/features/support/steps.js | 681 +++++--- integration_tests/features/support/world.js | 52 +- integration_tests/helpers/ffi/byteVector.js | 41 +- integration_tests/helpers/ffi/commsConfig.js | 43 + .../helpers/ffi/completedTransaction.js | 93 +- .../helpers/ffi/completedTransactions.js | 31 +- integration_tests/helpers/ffi/contact.js | 40 +- integration_tests/helpers/ffi/contacts.js | 25 +- integration_tests/helpers/ffi/emojiSet.js | 36 + integration_tests/helpers/ffi/ffiInterface.js | 1473 +++++++++++++++++ .../helpers/ffi/pendingInboundTransaction.js | 60 +- .../helpers/ffi/pendingInboundTransactions.js | 32 +- .../helpers/ffi/pendingOutboundTransaction.js | 64 +- .../ffi/pendingOutboundTransactions.js | 32 +- integration_tests/helpers/ffi/privateKey.js | 67 + integration_tests/helpers/ffi/publicKey.js | 72 +- integration_tests/helpers/ffi/seedWords.js | 42 +- .../helpers/ffi/transportType.js | 85 + integration_tests/helpers/ffi/wallet.js | 449 +++++ integration_tests/helpers/walletFFIClient.js | 475 ++---- 441 files changed, 11777 insertions(+), 7742 deletions(-) rename base_layer/{wallet/src/util => common_types/src}/emoji.rs (97%) rename base_layer/{wallet/src/util => common_types/src}/luhn.rs (98%) delete mode 100644 base_layer/common_types/src/types.rs create mode 100644 base_layer/common_types/src/types/bullet_rangeproofs.rs create mode 100644 base_layer/common_types/src/types/mod.rs create mode 100644 base_layer/core/src/transactions/crypto_factories.rs create mode 100644 base_layer/p2p/src/dns/mock.rs create mode 100644 common/config/presets/tari_igor_config.toml rename comms/dht/src/{tower_filter => filter}/future.rs (66%) rename comms/dht/src/{tower_filter => filter}/layer.rs (100%) rename comms/dht/src/{tower_filter => filter}/mod.rs (92%) create mode 100644 
comms/dht/src/filter/predicate.rs delete mode 100644 comms/dht/src/tower_filter/predicate.rs create mode 100644 comms/src/protocol/messaging/forward.rs rename comms/src/{common => }/rate_limit.rs (79%) rename comms/src/{common/mod.rs => utils/mpsc.rs} (84%) create mode 100644 infrastructure/shutdown/src/oneshot_trigger.rs create mode 100644 integration_tests/helpers/ffi/commsConfig.js create mode 100644 integration_tests/helpers/ffi/emojiSet.js create mode 100644 integration_tests/helpers/ffi/ffiInterface.js create mode 100644 integration_tests/helpers/ffi/privateKey.js create mode 100644 integration_tests/helpers/ffi/transportType.js create mode 100644 integration_tests/helpers/ffi/wallet.js diff --git a/.circleci/config.yml b/.circleci/config.yml index 0478f5aaae..f665baf02c 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -129,10 +129,10 @@ commands: when: always - run: name: Run ffi cucumber scenarios - command: cd integration_tests && mkdir -p cucumber_output && node_modules/.bin/cucumber-js --tags "not @long-running and not @broken and not @flaky and @wallet-ffi" --format json:cucumber_output/tests-ffi.cucumber --exit + command: cd integration_tests && mkdir -p cucumber_output && node_modules/.bin/cucumber-js --tags "not @long-running and not @broken and not @flaky and @wallet-ffi" --format json:cucumber_output/tests_ffi.cucumber --exit - run: name: Generate report (ffi) - command: cd integration_tests && touch cucumber_output/tests-ffi.cucumber && node ./generate_report.js cucumber_output/tests-ffi.cucumber temp/reports/cucumber_ffi_report.html + command: cd integration_tests && node ./generate_report.js "cucumber_output/tests_ffi.cucumber" "temp/reports/cucumber_ffi_report.html" when: always # - run: # name: Run flaky/broken cucumber scenarios (Always pass) diff --git a/Cargo.lock b/Cargo.lock index f3a00b0139..b93e14637e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,15 +2,6 @@ # It is not intended for manual editing. 
version = 3 -[[package]] -name = "addr2line" -version = "0.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e61f2b7f93d2c7d2b08263acaa4a363b3e276806c68af6134c44f523bf1aacd" -dependencies = [ - "gimli", -] - [[package]] name = "adler" version = "1.0.2" @@ -28,40 +19,29 @@ dependencies = [ [[package]] name = "aead" -version = "0.4.2" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e3e798aa0c8239776f54415bc06f3d74b1850f3f830b45c35cfc80556973f70" +checksum = "0b613b8e1e3cf911a086f53f03bf286f52fd7a7258e4fa606f0ef220d39d8877" dependencies = [ "generic-array", ] -[[package]] -name = "aes" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd2bc6d3f370b5666245ff421e231cba4353df936e26986d2918e61a8fd6aef6" -dependencies = [ - "aes-soft 0.5.0", - "aesni 0.8.0", - "block-cipher", -] - [[package]] name = "aes" version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "884391ef1066acaa41e766ba8f596341b96e93ce34f9a43e7d24bf0a0eaf0561" dependencies = [ - "aes-soft 0.6.4", - "aesni 0.10.0", + "aes-soft", + "aesni", "cipher 0.2.5", ] [[package]] name = "aes" -version = "0.7.4" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "495ee669413bfbe9e8cace80f4d3d78e6d8c8d99579f97fb93bde351b185f2d4" +checksum = "9e8b47f52ea9bae42228d07ec09eb676433d7c4ed1ebdf0f1d1c29ed446f1ab8" dependencies = [ "cfg-if 1.0.0", "cipher 0.3.0", @@ -85,29 +65,18 @@ dependencies = [ [[package]] name = "aes-gcm" -version = "0.9.3" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2a930fd487faaa92a30afa92cc9dd1526a5cff67124abbbb1c617ce070f4dcf" +checksum = "df5f85a83a7d8b0442b6aa7b504b8212c1733da07b98aae43d4bc21b2cb3cdf6" dependencies = [ - "aead 0.4.2", - "aes 0.7.4", + "aead 0.4.3", + "aes 0.7.5", "cipher 0.3.0", "ctr 0.8.0", - "ghash 0.4.3", + "ghash 
0.4.4", "subtle", ] -[[package]] -name = "aes-soft" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63dd91889c49327ad7ef3b500fd1109dbd3c509a03db0d4a9ce413b79f575cb6" -dependencies = [ - "block-cipher", - "byteorder", - "opaque-debug", -] - [[package]] name = "aes-soft" version = "0.6.4" @@ -118,16 +87,6 @@ dependencies = [ "opaque-debug", ] -[[package]] -name = "aesni" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a6fe808308bb07d393e2ea47780043ec47683fcf19cf5efc8ca51c50cc8c68a" -dependencies = [ - "block-cipher", - "opaque-debug", -] - [[package]] name = "aesni" version = "0.10.0" @@ -206,9 +165,9 @@ dependencies = [ [[package]] name = "async-stream" -version = "0.2.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22068c0c19514942eefcfd4daf8976ef1aad84e61539f95cd200c35202f80af5" +checksum = "171374e7e3b2504e0e5236e3b59260560f9fe94bfe9ac39ba5e4e929c5590625" dependencies = [ "async-stream-impl", "futures-core", @@ -216,9 +175,9 @@ dependencies = [ [[package]] name = "async-stream-impl" -version = "0.2.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25f9db3b38af870bf7e5cc649167533b493928e50744e2c30ae350230b414670" +checksum = "648ed8c8d2ce5409ccd57453d9d1b214b342a0d69376a6feda1fd6cae3299308" dependencies = [ "proc-macro2 1.0.28", "quote 1.0.9", @@ -259,21 +218,6 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a" -[[package]] -name = "backtrace" -version = "0.3.61" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7a905d892734eea339e896738c14b9afce22b5318f64b951e70bf3844419b01" -dependencies = [ - "addr2line", - "cc", - "cfg-if 1.0.0", - "libc", - "miniz_oxide", - "object", - "rustc-demangle", -] - [[package]] name = 
"base58-monero" version = "0.3.0" @@ -302,12 +246,6 @@ dependencies = [ "byteorder", ] -[[package]] -name = "base64" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b41b7ea54a0c9d92199de89e20e58d49f02f8e699814ef3fdf266f6f748d15c7" - [[package]] name = "base64" version = "0.12.3" @@ -337,7 +275,7 @@ version = "1.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" dependencies = [ - "serde 1.0.129", + "serde 1.0.130", ] [[package]] @@ -360,7 +298,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "which", + "which 3.1.1", ] [[package]] @@ -389,9 +327,9 @@ checksum = "3e54f7b7a46d7b183eb41e2d82965261fa8a1597c68b50aced268ee1fc70272d" [[package]] name = "blake2" -version = "0.9.1" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10a5720225ef5daecf08657f23791354e1685a8c91a4c60c7f3d3b2892f978f4" +checksum = "0a4e37d16930f5459780f5621038b6382b9bb37c19016f39fb6b5808d831f174" dependencies = [ "crypto-mac", "digest", @@ -435,12 +373,12 @@ checksum = "8d696c370c750c948ada61c69a0ee2cbbb9c50b1019ddb86d9317157a99c2cae" [[package]] name = "blowfish" -version = "0.6.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f06850ba969bc59388b2cc0a4f186fc6d9d37208863b15b84ae3866ac90ac06" +checksum = "32fa6a061124e37baba002e496d203e23ba3d7b73750be82dbfbc92913048a5b" dependencies = [ - "block-cipher", "byteorder", + "cipher 0.2.5", "opaque-debug", ] @@ -459,7 +397,7 @@ dependencies = [ "lazy_static 1.4.0", "memchr", "regex-automata", - "serde 1.0.129", + "serde 1.0.130", ] [[package]] @@ -496,30 +434,20 @@ version = "1.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" -[[package]] -name = "bytes" -version = "0.4.12" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "206fdffcfa2df7cbe15601ef46c813fce0965eb3286db6b56c583b814b51c81c" -dependencies = [ - "byteorder", - "iovec", -] - [[package]] name = "bytes" version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e4cec68f03f32e44924783795810fa50a7035d8c8ebe78580ad7e6c703fba38" -dependencies = [ - "serde 1.0.129", -] [[package]] name = "bytes" -version = "1.0.1" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b700ce4376041dcd0a327fd0097c41095743c4c8af8887265942faf1100bd040" +checksum = "c4872d67bab6358e59559027aa3b9157c53d9358c51423c17554809a8858e0f8" +dependencies = [ + "serde 1.0.130", +] [[package]] name = "c_linked_list" @@ -550,12 +478,12 @@ dependencies = [ [[package]] name = "cast5" -version = "0.8.0" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3ed1e6b53a3de8bafcce4b88867893c234e57f91686a4726d8e803771f0b55b" +checksum = "1285caf81ea1f1ece6b24414c521e625ad0ec94d880625c20f2e65d8d3f78823" dependencies = [ - "block-cipher", "byteorder", + "cipher 0.2.5", "opaque-debug", ] @@ -571,7 +499,7 @@ dependencies = [ "log 0.4.14", "proc-macro2 1.0.28", "quote 1.0.9", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "syn 1.0.75", "tempfile", @@ -595,11 +523,11 @@ dependencies = [ [[package]] name = "cfb-mode" -version = "0.5.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fa76b7293f89734378d27057d169dc68077ad34b21dbcabf1c0a646a9462592" +checksum = "1d6975e91054798d325f85f50115056d7deccf6817fe7f947c438ee45b119632" dependencies = [ - "stream-cipher", + "cipher 0.2.5", ] [[package]] @@ -616,9 +544,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chacha20" -version = "0.7.2" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"ea8756167ea0aca10e066cdbe7813bd71d2f24e69b0bc7b50509590cef2ce0b9" +checksum = "f08493fa7707effc63254c66c6ea908675912493cd67952eda23c09fae2610b1" dependencies = [ "cfg-if 1.0.0", "cipher 0.3.0", @@ -628,11 +556,11 @@ dependencies = [ [[package]] name = "chacha20poly1305" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "175a11316f33592cf2b71416ee65283730b5b7849813c4891d02a12906ed9acc" +checksum = "b6547abe025f4027edacd9edaa357aded014eecec42a5070d9b885c3c334aba2" dependencies = [ - "aead 0.4.2", + "aead 0.4.3", "chacha20", "cipher 0.3.0", "poly1305", @@ -654,7 +582,7 @@ dependencies = [ "libc", "num-integer", "num-traits 0.2.14", - "serde 1.0.129", + "serde 1.0.130", "time", "winapi 0.3.9", ] @@ -677,7 +605,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c6316c62053228eddd526a5e6deb6344c80bf2bc1e9786e7f90b3083e73197c1" dependencies = [ "bitstring", - "serde 1.0.129", + "serde 1.0.130", ] [[package]] @@ -763,7 +691,7 @@ dependencies = [ "lazy_static 1.4.0", "nom 4.2.3", "rust-ini", - "serde 1.0.129", + "serde 1.0.130", "serde-hjson", "serde_json", "toml 0.4.10", @@ -788,9 +716,9 @@ checksum = "ea221b5284a47e40033bf9b66f35f984ec0ea2931eb03505246cd27a963f981b" [[package]] name = "cpufeatures" -version = "0.1.5" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66c99696f6c9dd7f35d486b9d04d7e6e202aa3e8c40d553f2fdf5e7e0c6a71ef" +checksum = "95059428f66df56b63431fdb4e1947ed2190586af5c5a8a8b71122bdf5a7f469" dependencies = [ "libc", ] @@ -827,7 +755,7 @@ dependencies = [ "clap", "criterion-plot", "csv", - "itertools", + "itertools 0.8.2", "lazy_static 1.4.0", "libc", "num-traits 0.2.14", @@ -836,7 +764,7 @@ dependencies = [ "rand_xoshiro", "rayon", "rayon-core", - "serde 1.0.129", + "serde 1.0.130", "serde_derive", "serde_json", "tinytemplate", @@ -851,7 +779,7 @@ checksum = 
"76f9212ddf2f4a9eb2d401635190600656a1f88a932ef53d06e7fa4c7e02fb8e" dependencies = [ "byteorder", "cast", - "itertools", + "itertools 0.8.2", ] [[package]] @@ -1014,7 +942,7 @@ dependencies = [ "csv-core", "itoa", "ryu", - "serde 1.0.129", + "serde 1.0.130", ] [[package]] @@ -1068,7 +996,7 @@ dependencies = [ "byteorder", "digest", "rand_core 0.5.1", - "serde 1.0.129", + "serde 1.0.130", "subtle", "zeroize", ] @@ -1083,7 +1011,7 @@ dependencies = [ "digest", "packed_simd_2", "rand_core 0.6.3", - "serde 1.0.129", + "serde 1.0.130", "subtle-ng", "zeroize", ] @@ -1178,12 +1106,12 @@ dependencies = [ [[package]] name = "des" -version = "0.5.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e084b5048dec677e6c9f27d7abc551dde7d127cf4127fea82323c98a30d7fa0d" +checksum = "b24e7c748888aa2fa8bce21d8c64a52efc810663285315ac7476f7197a982fae" dependencies = [ - "block-cipher", "byteorder", + "cipher 0.2.5", "opaque-debug", ] @@ -1279,7 +1207,7 @@ dependencies = [ "curve25519-dalek", "ed25519", "rand 0.7.3", - "serde 1.0.129", + "serde 1.0.130", "sha2", "zeroize", ] @@ -1599,19 +1527,6 @@ dependencies = [ "pin-utils", ] -[[package]] -name = "futures-test-preview" -version = "0.3.0-alpha.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0813d833a213f893d1f07ccd9d49d3de100181d146053b2097b8b934d7945eb" -dependencies = [ - "futures-core-preview", - "futures-executor-preview", - "futures-io-preview", - "futures-util-preview", - "pin-utils", -] - [[package]] name = "futures-timer" version = "0.3.0" @@ -1731,20 +1646,14 @@ dependencies = [ [[package]] name = "ghash" -version = "0.4.3" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b442c439366184de619215247d24e908912b175e824a530253845ac4c251a5c1" +checksum = "1583cc1656d7839fd3732b80cf4f38850336cdb9b8ded1cd399ca62958de3c99" dependencies = [ "opaque-debug", - "polyval 0.5.2", + "polyval 0.5.3", ] -[[package]] 
-name = "gimli" -version = "0.25.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0a01e0497841a3b2db4f8afa483cce65f7e96a3498bd6c541734792aeac8fe7" - [[package]] name = "git2" version = "0.8.0" @@ -1792,7 +1701,7 @@ version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7f3675cfef6a30c8031cf9e6493ebdc3bb3272a3fea3923c4210d1830e6a472" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "fnv", "futures-core", "futures-sink", @@ -1847,7 +1756,7 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "527e8c9ac747e28542699a951517aa9a6945af506cd1f2e1b53a576c17b6cc11" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "fnv", "itoa", ] @@ -1868,7 +1777,7 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "399c583b2979440c60be0821a6199eca73bc3c8dcd9d070d75ac726e2c6186e5" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "http", "pin-project-lite 0.2.7", ] @@ -1955,7 +1864,7 @@ version = "0.14.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13f67199e765030fa08fe0bd581af683f0d5bc04ea09c2b1102012c5fb90e7fd" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "futures-channel", "futures-core", "futures-util", @@ -1973,6 +1882,18 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-timeout" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" +dependencies = [ + "hyper 0.14.12", + "pin-project-lite 0.2.7", + "tokio 1.10.1", + "tokio-io-timeout", +] + [[package]] name = "hyper-tls" version = "0.4.3" @@ -1992,7 +1913,7 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "hyper 0.14.12", "native-tls", "tokio 1.10.1", 
@@ -2037,7 +1958,7 @@ dependencies = [ "byteorder", "color_quant", "num-iter", - "num-rational", + "num-rational 0.3.2", "num-traits 0.2.14", ] @@ -2090,6 +2011,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69ddb889f9d0d08a67338271fa9b62996bc788c7796a5c18cf057420aaed5eaf" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "0.4.8" @@ -2112,7 +2042,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "436f3455a8a4e9c7b14de9f1206198ee5d0bdc2db1b560339d2141093d7dd389" dependencies = [ "hyper 0.10.16", - "serde 1.0.129", + "serde 1.0.130", "serde_derive", "serde_json", ] @@ -2162,9 +2092,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.100" +version = "0.2.101" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1fa8cddc8fbbee11227ef194b5317ed014b8acbf15139bd716a18ad3fe99ec5" +checksum = "3cb00336871be5ed2c8ed44b60ae9959dc5b9f08539422ed43f09e34ecaeba21" [[package]] name = "libgit2-sys" @@ -2289,9 +2219,9 @@ dependencies = [ [[package]] name = "lock_api" -version = "0.4.4" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0382880606dff6d15c9476c416d18690b72742aa7b605bb6dd6ec9030fbf07eb" +checksum = "712a4d093c9976e24e7dbca41db895dabcbac38eb5f4045393d17a95bdfb1109" dependencies = [ "scopeguard", ] @@ -2312,7 +2242,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51b9bbe6c47d51fc3e1a9b945965946b4c44142ab8792c50835a980d362c2710" dependencies = [ "cfg-if 1.0.0", - "serde 1.0.129", + "serde 1.0.130", ] [[package]] @@ -2335,7 +2265,7 @@ dependencies = [ "libc", "log 0.4.14", "log-mdc", - "serde 1.0.129", + "serde 1.0.130", "serde-value 0.5.3", "serde_derive", "serde_yaml", @@ -2359,9 +2289,9 @@ dependencies = [ "libc", "log 
0.4.14", "log-mdc", - "parking_lot 0.11.1", + "parking_lot 0.11.2", "regex", - "serde 1.0.129", + "serde 1.0.130", "serde-value 0.7.0", "serde_json", "serde_yaml", @@ -2512,17 +2442,6 @@ dependencies = [ "winapi 0.3.9", ] -[[package]] -name = "mio-uds" -version = "0.6.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afcb699eb26d4332647cc848492bbc15eafb26f08d0304550d5aa1f612e066f0" -dependencies = [ - "iovec", - "libc", - "mio 0.6.23", -] - [[package]] name = "miow" version = "0.2.2" @@ -2555,21 +2474,39 @@ dependencies = [ "fixed-hash", "hex", "hex-literal", - "serde 1.0.129", + "serde 1.0.130", "serde-big-array", "thiserror", "tiny-keccak", ] +[[package]] +name = "multiaddr" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48ee4ea82141951ac6379f964f71b20876d43712bea8faf6dd1a375e08a46499" +dependencies = [ + "arrayref", + "bs58", + "byteorder", + "data-encoding", + "multihash", + "percent-encoding 2.1.0", + "serde 1.0.130", + "static_assertions", + "unsigned-varint", + "url 2.2.2", +] + [[package]] name = "multihash" -version = "0.13.2" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dac63698b887d2d929306ea48b63760431ff8a24fac40ddb22f9c7f49fb7cab" +checksum = "752a61cd890ff691b4411423d23816d5866dd5621e4d1c5687a53b94b5a979d8" dependencies = [ "generic-array", "multihash-derive", - "unsigned-varint 0.5.1", + "unsigned-varint", ] [[package]] @@ -2629,9 +2566,12 @@ checksum = "d36047f46c69ef97b60e7b069a26ce9a15cd8a7852eddb6991ea94a83ba36a78" [[package]] name = "nibble_vec" -version = "0.0.4" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8d77f3db4bce033f4d04db08079b2ef1c3d02b44e86f25d08886fafa7756ffa" +checksum = "77a5d83df9f36fe23f0c3648c6bbb8b0298bb5f1939c8f2704431371f4b84d43" +dependencies = [ + "smallvec", +] [[package]] name = "nix" @@ -2693,10 +2633,24 @@ source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "8b7a8e9be5e039e2ff869df49155f1c06bd01ade2117ec783e56ab0932b67a8f" dependencies = [ "num-bigint 0.3.2", - "num-complex", + "num-complex 0.3.1", "num-integer", "num-iter", - "num-rational", + "num-rational 0.3.2", + "num-traits 0.2.14", +] + +[[package]] +name = "num" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43db66d1170d347f9a065114077f7dccb00c1b9478c89384490a3425279a4606" +dependencies = [ + "num-bigint 0.4.1", + "num-complex 0.4.0", + "num-integer", + "num-iter", + "num-rational 0.4.0", "num-traits 0.2.14", ] @@ -2722,6 +2676,17 @@ dependencies = [ "num-traits 0.2.14", ] +[[package]] +name = "num-bigint" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76e97c412795abf6c24ba30055a8f20642ea57ca12875220b854cfa501bf1e48" +dependencies = [ + "autocfg 1.0.1", + "num-integer", + "num-traits 0.2.14", +] + [[package]] name = "num-bigint-dig" version = "0.6.1" @@ -2736,7 +2701,7 @@ dependencies = [ "num-iter", "num-traits 0.2.14", "rand 0.7.3", - "serde 1.0.129", + "serde 1.0.130", "smallvec", "zeroize", ] @@ -2750,6 +2715,15 @@ dependencies = [ "num-traits 0.2.14", ] +[[package]] +name = "num-complex" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26873667bbbb7c5182d4a37c1add32cdf09f841af72da53318fdb81543c15085" +dependencies = [ + "num-traits 0.2.14", +] + [[package]] name = "num-derive" version = "0.3.3" @@ -2804,6 +2778,18 @@ dependencies = [ "num-traits 0.2.14", ] +[[package]] +name = "num-rational" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d41702bd167c2df5520b384281bc111a4b5efcf7fbc4c9c222c815b07e0a6a6a" +dependencies = [ + "autocfg 1.0.1", + "num-bigint 0.4.1", + "num-integer", + "num-traits 0.2.14", +] + [[package]] name = "num-traits" version = "0.1.43" @@ -2832,15 +2818,6 @@ dependencies = 
[ "libc", ] -[[package]] -name = "object" -version = "0.26.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee2766204889d09937d00bfbb7fec56bb2a199e2ade963cab19185d8a6104c7c" -dependencies = [ - "memchr", -] - [[package]] name = "once_cell" version = "1.8.0" @@ -2967,24 +2944,6 @@ dependencies = [ "libm 0.1.4", ] -[[package]] -name = "parity-multiaddr" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bfda2e46fc5e14122649e2645645a81ee5844e0fb2e727ef560cc71a8b2d801" -dependencies = [ - "arrayref", - "bs58", - "byteorder", - "data-encoding", - "multihash", - "percent-encoding 2.1.0", - "serde 1.0.129", - "static_assertions", - "unsigned-varint 0.6.0", - "url 2.2.2", -] - [[package]] name = "parking_lot" version = "0.10.2" @@ -2997,13 +2956,13 @@ dependencies = [ [[package]] name = "parking_lot" -version = "0.11.1" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d7744ac029df22dca6284efe4e898991d28e3085c706c972bcd7da4a27a15eb" +checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" dependencies = [ "instant", - "lock_api 0.4.4", - "parking_lot_core 0.8.3", + "lock_api 0.4.5", + "parking_lot_core 0.8.5", ] [[package]] @@ -3022,9 +2981,9 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.8.3" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa7a782938e745763fe6907fc6ba86946d72f49fe7e21de074e08128a99fb018" +checksum = "d76e8e1493bcac0d2766c42737f34458f1c8c50c0d23bcb24ea953affb273216" dependencies = [ "cfg-if 1.0.0", "instant", @@ -3090,11 +3049,11 @@ dependencies = [ [[package]] name = "pgp" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "501f8c2834bc16a23ae40932b9f924c6c5fc1d7cd1cc3536a532f37e81f603ed" +checksum = "856124b4d0a95badd3e1ad353edd7157fc6c6995767b78ef62848f3b296405ff" 
dependencies = [ - "aes 0.5.0", + "aes 0.6.0", "base64 0.12.3", "bitfield", "block-modes", @@ -3105,6 +3064,7 @@ dependencies = [ "cast5", "cfb-mode", "chrono", + "cipher 0.2.5", "circular", "clear_on_drop", "crc24", @@ -3203,9 +3163,9 @@ checksum = "3831453b3449ceb48b6d9c7ad7c96d5ea673e9b470a1dc578c2ce6521230884c" [[package]] name = "poly1305" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fcffab1f78ebbdf4b93b68c1ffebc24037eedf271edaca795732b24e5e4e349" +checksum = "048aeb476be11a4b6ca432ca569e375810de9294ae78f4774e78ea98a9246ede" dependencies = [ "cpufeatures", "opaque-debug", @@ -3225,9 +3185,9 @@ dependencies = [ [[package]] name = "polyval" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6ba6a405ef63530d6cb12802014b22f9c5751bd17cdcddbe9e46d5c8ae83287" +checksum = "8419d2b623c7c0896ff2d5d96e2cb4ede590fed28fcc34934f4c33c036e620a1" dependencies = [ "cfg-if 1.0.0", "cpufeatures", @@ -3307,40 +3267,40 @@ dependencies = [ [[package]] name = "prost" -version = "0.6.1" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce49aefe0a6144a45de32927c77bd2859a5f7677b55f220ae5b744e87389c212" +checksum = "de5e2533f59d08fcf364fd374ebda0692a70bd6d7e66ef97f306f45c6c5d8020" dependencies = [ - "bytes 0.5.6", + "bytes 1.1.0", "prost-derive", ] [[package]] name = "prost-build" -version = "0.6.1" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02b10678c913ecbd69350e8535c3aef91a8676c0773fc1d7b95cdd196d7f2f26" +checksum = "355f634b43cdd80724ee7848f95770e7e70eefa6dcf14fea676216573b8fd603" dependencies = [ - "bytes 0.5.6", + "bytes 1.1.0", "heck", - "itertools", + "itertools 0.10.1", "log 0.4.14", "multimap", "petgraph", "prost", "prost-types", "tempfile", - "which", + "which 4.2.2", ] [[package]] name = "prost-derive" -version = "0.6.1" +version = "0.8.0" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "537aa19b95acde10a12fec4301466386f757403de4cd4e5b4fa78fb5ecb18f72" +checksum = "600d2f334aa05acb02a755e217ef1ab6dea4d51b58b7846588b747edec04efba" dependencies = [ "anyhow", - "itertools", + "itertools 0.10.1", "proc-macro2 1.0.28", "quote 1.0.9", "syn 1.0.75", @@ -3348,11 +3308,11 @@ dependencies = [ [[package]] name = "prost-types" -version = "0.6.1" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1834f67c0697c001304b75be76f67add9c89742eda3a085ad8ee0bb38c3417aa" +checksum = "603bbd6394701d13f3f25aada59c7de9d35a6a5887cfc156181234a44002771b" dependencies = [ - "bytes 0.5.6", + "bytes 1.1.0", "prost", ] @@ -3398,9 +3358,9 @@ dependencies = [ [[package]] name = "radix_trie" -version = "0.1.6" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d3681b28cd95acfb0560ea9441f82d6a4504fa3b15b97bd7b6e952131820e95" +checksum = "c069c179fcdc6a2fe24d8d18305cf085fdbd4f922c041943e203685d6a1c58fd" dependencies = [ "endian-type", "nibble_vec", @@ -3417,7 +3377,6 @@ dependencies = [ "rand_chacha 0.2.2", "rand_core 0.5.1", "rand_hc 0.2.0", - "rand_pcg", ] [[package]] @@ -3517,15 +3476,6 @@ dependencies = [ "winapi 0.3.9", ] -[[package]] -name = "rand_pcg" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16abd0c1b639e9eb4d7c50c0b8100b0d0f849be2349829c740fe8e6eb4816429" -dependencies = [ - "rand_core 0.5.1", -] - [[package]] name = "rand_xoshiro" version = "0.1.0" @@ -3666,8 +3616,7 @@ dependencies = [ "native-tls", "percent-encoding 2.1.0", "pin-project-lite 0.2.7", - "serde 1.0.129", - "serde_json", + "serde 1.0.130", "serde_urlencoded", "tokio 0.2.25", "tokio-tls", @@ -3685,7 +3634,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "246e9f61b9bb77df069a947682be06e31ac43ea37862e244a69f177694ea6d22" dependencies = [ "base64 0.13.0", - 
"bytes 1.0.1", + "bytes 1.1.0", "encoding_rs", "futures-core", "futures-util", @@ -3701,7 +3650,7 @@ dependencies = [ "native-tls", "percent-encoding 2.1.0", "pin-project-lite 0.2.7", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "serde_urlencoded", "tokio 1.10.1", @@ -3757,7 +3706,7 @@ checksum = "011e1d58446e9fa3af7cdc1fb91295b10621d3ac4cb3a85cc86385ee9ca50cd3" dependencies = [ "byteorder", "rmp", - "serde 1.0.129", + "serde 1.0.130", ] [[package]] @@ -3798,12 +3747,6 @@ version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3e52c148ef37f8c375d49d5a73aa70713125b7f19095948a923f80afdeb22ec2" -[[package]] -name = "rustc-demangle" -version = "0.1.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ef03e0a2b150c7a90d01faf6254c9c48a41e95fb2a8c2ac1c6f0d2b9aefc342" - [[package]] name = "rustc-hash" version = "1.1.0" @@ -3836,11 +3779,11 @@ dependencies = [ [[package]] name = "rustls" -version = "0.17.0" +version = "0.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0d4a31f5d68413404705d6982529b0e11a9aacd4839d1d6222ee3b8cb4015e1" +checksum = "35edb675feee39aec9c99fa5ff985081995a06d594114ae14cbe797ad7b7a6d7" dependencies = [ - "base64 0.11.0", + "base64 0.13.0", "log 0.4.14", "ring", "sct", @@ -3931,22 +3874,23 @@ dependencies = [ [[package]] name = "security-framework" -version = "2.3.1" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23a2ac85147a3a11d77ecf1bc7166ec0b92febfa4461c37944e180f319ece467" +checksum = "5b9bd29cdffb8875b04f71c51058f940cf4e390bbfd2ce669c4f22cd70b492a5" dependencies = [ "bitflags 1.3.2", "core-foundation", "core-foundation-sys", "libc", + "num 0.4.0", "security-framework-sys", ] [[package]] name = "security-framework-sys" -version = "2.3.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"7e4effb91b4b8b6fb7732e670b6cee160278ff8e6bf485c7805d9e319d76e284" +checksum = "19133a286e494cc3311c165c4676ccb1fd47bed45b55f9d71fbd784ad4cea6f8" dependencies = [ "core-foundation-sys", "libc", @@ -3984,9 +3928,9 @@ checksum = "9dad3f759919b92c3068c696c15c3d17238234498bbdcc80f2c469606f948ac8" [[package]] name = "serde" -version = "1.0.129" +version = "1.0.130" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1f72836d2aa753853178eda473a3b9d8e4eefdaf20523b919677e6de489f8f1" +checksum = "f12d06de37cf59146fbdecab66aa99f9fe4f78722e3607577a5375d66bd0c913" dependencies = [ "serde_derive", ] @@ -3997,7 +3941,7 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "18b20e7752957bbe9661cff4e0bb04d183d0948cdab2ea58cdb9df36a61dfe62" dependencies = [ - "serde 1.0.129", + "serde 1.0.130", "serde_derive", ] @@ -4021,7 +3965,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a663f873dedc4eac1a559d4c6bc0d0b2c34dc5ac4702e105014b8281489e44f" dependencies = [ "ordered-float 1.1.1", - "serde 1.0.129", + "serde 1.0.130", ] [[package]] @@ -4031,14 +3975,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3a1a3341211875ef120e117ea7fd5228530ae7e7036a779fdc9117be6b3282c" dependencies = [ "ordered-float 2.7.0", - "serde 1.0.129", + "serde 1.0.130", ] [[package]] name = "serde_derive" -version = "1.0.129" +version = "1.0.130" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e57ae87ad533d9a56427558b516d0adac283614e347abf85b0dc0cbbf0a249f3" +checksum = "d7bc1a1ab1961464eae040d96713baa5a724a8152c1222492465b54322ec508b" dependencies = [ "proc-macro2 1.0.28", "quote 1.0.9", @@ -4047,13 +3991,13 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.66" +version = "1.0.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "336b10da19a12ad094b59d870ebde26a45402e5b470add4b5fd03c5048a32127" 
+checksum = "a7f9e390c27c3c0ce8bc5d725f6e4d30a29d26659494aa4b17535f7522c5c950" dependencies = [ "itoa", "ryu", - "serde 1.0.129", + "serde 1.0.130", ] [[package]] @@ -4085,26 +4029,26 @@ dependencies = [ "form_urlencoded", "itoa", "ryu", - "serde 1.0.129", + "serde 1.0.130", ] [[package]] name = "serde_yaml" -version = "0.8.19" +version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6375dbd828ed6964c3748e4ef6d18e7a175d408ffe184bca01698d0c73f915a9" +checksum = "ad104641f3c958dab30eb3010e834c2622d1f3f4c530fef1dee20ad9485f3c09" dependencies = [ "dtoa", "indexmap", - "serde 1.0.129", + "serde 1.0.130", "yaml-rust", ] [[package]] name = "sha-1" -version = "0.9.7" +version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a0c8611594e2ab4ebbf06ec7cbbf0a99450b8570e96cbf5188b5d5f6ef18d81" +checksum = "99cd6713db3cf16b6c84e06321e049a9b9f699826e16096d23bbcc44d15d51a6" dependencies = [ "block-buffer", "cfg-if 1.0.0", @@ -4115,9 +4059,9 @@ dependencies = [ [[package]] name = "sha2" -version = "0.9.5" +version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b362ae5752fd2137731f9fa25fd4d9058af34666ca1966fb969119cc35719f12" +checksum = "9204c41a1597a8c5af23c82d1c921cb01ec0a4c59e07a9c7306062829a3903f3" dependencies = [ "block-buffer", "cfg-if 1.0.0", @@ -4208,7 +4152,7 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6142f7c25e94f6fd25a32c3348ec230df9109b463f59c8c7acc4bd34936babb7" dependencies = [ - "aes-gcm 0.9.3", + "aes-gcm 0.9.4", "blake2", "chacha20poly1305", "rand 0.8.4", @@ -4261,16 +4205,6 @@ dependencies = [ "futures 0.1.31", ] -[[package]] -name = "stream-cipher" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c80e15f898d8d8f25db24c253ea615cc14acf418ff307822995814e7d42cfa89" -dependencies = [ - "block-cipher", - "generic-array", -] - [[package]] name = 
"strsim" version = "0.8.0" @@ -4452,13 +4386,14 @@ dependencies = [ "strum", "strum_macros 0.19.4", "tari_common", + "tari_common_types", "tari_comms", "tari_core", "tari_crypto", "tari_p2p", "tari_wallet", "thiserror", - "tokio 0.2.25", + "tokio 1.10.1", "tonic", ] @@ -4482,6 +4417,7 @@ dependencies = [ "tari_app_grpc", "tari_app_utilities", "tari_common", + "tari_common_types", "tari_comms", "tari_comms_dht", "tari_core", @@ -4490,9 +4426,8 @@ dependencies = [ "tari_p2p", "tari_service_framework", "tari_shutdown", - "tari_wallet", "thiserror", - "tokio 0.2.25", + "tokio 1.10.1", "tonic", "tracing", "tracing-opentelemetry", @@ -4512,7 +4447,7 @@ dependencies = [ "merlin", "rand 0.8.4", "rand_core 0.6.3", - "serde 1.0.129", + "serde 1.0.130", "serde_derive", "sha3", "subtle-ng", @@ -4530,12 +4465,12 @@ dependencies = [ "git2", "log 0.4.14", "log4rs 1.0.0", + "multiaddr", "opentelemetry", "opentelemetry-jaeger", - "parity-multiaddr", "path-clean", "prost-build", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "sha2", "structopt", @@ -4552,11 +4487,13 @@ dependencies = [ name = "tari_common_types" version = "0.9.5" dependencies = [ + "digest", "futures 0.3.16", + "lazy_static 1.4.0", "rand 0.8.4", - "serde 1.0.129", + "serde 1.0.130", "tari_crypto", - "tokio 0.2.25", + "tokio 1.10.1", ] [[package]] @@ -4567,7 +4504,7 @@ dependencies = [ "async-trait", "bitflags 1.3.2", "blake2", - "bytes 0.5.6", + "bytes 1.1.0", "chrono", "cidr", "clear_on_drop", @@ -4578,15 +4515,15 @@ dependencies = [ "lazy_static 1.4.0", "lmdb-zero", "log 0.4.14", + "multiaddr", "nom 5.1.2", "openssl", "opentelemetry", "opentelemetry-jaeger", - "parity-multiaddr", - "pin-project 0.4.28", + "pin-project 1.0.8", "prost", "rand 0.8.4", - "serde 1.0.129", + "serde 1.0.130", "serde_derive", "serde_json", "snow", @@ -4598,10 +4535,10 @@ dependencies = [ "tari_test_utils", "tempfile", "thiserror", - "tokio 0.2.25", - "tokio-macros", - "tokio-util 0.3.1", - "tower", + "tokio 1.10.1", + "tokio-stream", 
+ "tokio-util 0.6.7", + "tower 0.3.1", "tower-make", "tracing", "tracing-futures", @@ -4614,7 +4551,7 @@ version = "0.9.5" dependencies = [ "anyhow", "bitflags 1.3.2", - "bytes 0.4.12", + "bytes 0.5.6", "chacha20", "chrono", "clap", @@ -4623,7 +4560,7 @@ dependencies = [ "digest", "env_logger 0.7.1", "futures 0.3.16", - "futures-test-preview", + "futures-test", "futures-util", "lazy_static 1.4.0", "libsqlite3-sys", @@ -4634,7 +4571,7 @@ dependencies = [ "prost", "prost-types", "rand 0.8.4", - "serde 1.0.129", + "serde 1.0.130", "serde_derive", "serde_repr", "tari_common", @@ -4647,10 +4584,10 @@ dependencies = [ "tari_utilities", "tempfile", "thiserror", - "tokio 0.2.25", - "tokio-macros", - "tokio-test", - "tower", + "tokio 1.10.1", + "tokio-stream", + "tokio-test 0.4.2", + "tower 0.3.1", "tower-test", "ttl_cache", ] @@ -4666,8 +4603,7 @@ dependencies = [ "syn 1.0.75", "tari_comms", "tari_test_utils", - "tokio 0.2.25", - "tokio-macros", + "tokio 1.10.1", "tower-service", ] @@ -4693,6 +4629,7 @@ dependencies = [ "tari_app_grpc", "tari_app_utilities", "tari_common", + "tari_common_types", "tari_comms", "tari_comms_dht", "tari_core", @@ -4702,7 +4639,7 @@ dependencies = [ "tari_shutdown", "tari_wallet", "thiserror", - "tokio 0.2.25", + "tokio 1.10.1", "tonic", "tracing", "tracing-opentelemetry", @@ -4719,7 +4656,7 @@ dependencies = [ "bincode", "bitflags 1.3.2", "blake2", - "bytes 0.4.12", + "bytes 0.5.6", "chrono", "config", "croaring", @@ -4728,17 +4665,18 @@ dependencies = [ "fs2", "futures 0.3.16", "hex", + "lazy_static 1.4.0", "lmdb-zero", "log 0.4.14", "monero", "newtype-ops", - "num", + "num 0.3.1", "num-format", "prost", "prost-types", "rand 0.8.4", "randomx-rs", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "sha3", "strum_macros 0.17.1", @@ -4756,8 +4694,7 @@ dependencies = [ "tari_test_utils", "tempfile", "thiserror", - "tokio 0.2.25", - "tokio-macros", + "tokio 1.10.1", "tracing", "tracing-attributes", "tracing-futures", @@ -4781,7 +4718,7 @@ 
dependencies = [ "merlin", "rand 0.8.4", "rmp-serde", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "sha2", "sha3", @@ -4806,7 +4743,7 @@ version = "0.9.5" dependencies = [ "digest", "rand 0.8.4", - "serde 1.0.129", + "serde 1.0.130", "serde_derive", "serde_json", "sha2", @@ -4820,7 +4757,7 @@ version = "0.9.5" dependencies = [ "anyhow", "bincode", - "bytes 0.5.6", + "bytes 1.1.0", "chrono", "config", "derive-error", @@ -4828,12 +4765,12 @@ dependencies = [ "futures 0.3.16", "futures-test", "hex", - "hyper 0.13.10", + "hyper 0.14.12", "jsonrpc", "log 0.4.14", "rand 0.8.4", - "reqwest 0.10.10", - "serde 1.0.129", + "reqwest 0.11.4", + "serde 1.0.130", "serde_json", "structopt", "tari_app_grpc", @@ -4843,8 +4780,7 @@ dependencies = [ "tari_crypto", "tari_utilities", "thiserror", - "tokio 0.2.25", - "tokio-macros", + "tokio 1.10.1", "tonic", "tracing", "tracing-futures", @@ -4868,7 +4804,7 @@ dependencies = [ "prost-types", "rand 0.8.4", "reqwest 0.11.4", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "sha3", "tari_app_grpc", @@ -4878,7 +4814,7 @@ dependencies = [ "tari_crypto", "thiserror", "time", - "tokio 0.2.25", + "tokio 1.10.1", "tonic", ] @@ -4893,7 +4829,7 @@ dependencies = [ "digest", "log 0.4.14", "rand 0.8.4", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "tari_crypto", "tari_infra_derive", @@ -4922,7 +4858,7 @@ dependencies = [ "rand 0.8.4", "reqwest 0.10.10", "semver 1.0.4", - "serde 1.0.129", + "serde 1.0.130", "serde_derive", "stream-cancel", "tari_common", @@ -4936,9 +4872,9 @@ dependencies = [ "tari_utilities", "tempfile", "thiserror", - "tokio 0.2.25", - "tokio-macros", - "tower", + "tokio 1.10.1", + "tokio-stream", + "tower 0.3.1", "tower-service", "trust-dns-client", ] @@ -4955,9 +4891,8 @@ dependencies = [ "tari_shutdown", "tari_test_utils", "thiserror", - "tokio 0.2.25", - "tokio-macros", - "tower", + "tokio 1.10.1", + "tower 0.3.1", "tower-service", ] @@ -4966,7 +4901,7 @@ name = "tari_shutdown" version = "0.9.5" dependencies = 
[ "futures 0.3.16", - "tokio 0.2.25", + "tokio 1.10.1", ] [[package]] @@ -4974,14 +4909,14 @@ name = "tari_storage" version = "0.9.5" dependencies = [ "bincode", - "bytes 0.4.12", + "bytes 0.5.6", "env_logger 0.6.2", "lmdb-zero", "log 0.4.14", "rand 0.8.4", "rmp", "rmp-serde", - "serde 1.0.129", + "serde 1.0.130", "serde_derive", "tari_utilities", "thiserror", @@ -4993,7 +4928,7 @@ version = "0.0.1" dependencies = [ "hex", "libc", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "tari_app_grpc", "tari_common", @@ -5017,12 +4952,12 @@ dependencies = [ "futures 0.3.16", "futures-test", "hex", - "hyper 0.13.10", + "hyper 0.14.12", "jsonrpc", "log 0.4.14", "rand 0.7.3", - "reqwest 0.10.10", - "serde 1.0.129", + "reqwest 0.11.4", + "serde 1.0.130", "serde_json", "structopt", "tari_app_grpc", @@ -5031,8 +4966,7 @@ dependencies = [ "tari_crypto", "tari_utilities", "thiserror", - "tokio 0.2.25", - "tokio-macros", + "tokio 1.10.1", "tonic", "tonic-build", "tracing", @@ -5051,7 +4985,7 @@ dependencies = [ "rand 0.8.4", "tari_shutdown", "tempfile", - "tokio 0.2.25", + "tokio 1.10.1", ] [[package]] @@ -5067,7 +5001,7 @@ dependencies = [ "clear_on_drop", "newtype-ops", "rand 0.7.3", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "thiserror", ] @@ -5087,14 +5021,13 @@ dependencies = [ "env_logger 0.7.1", "fs2", "futures 0.3.16", - "lazy_static 1.4.0", "libsqlite3-sys", "lmdb-zero", "log 0.4.14", "log4rs 1.0.0", "prost", "rand 0.8.4", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "tari_common_types", "tari_comms", @@ -5110,9 +5043,8 @@ dependencies = [ "tempfile", "thiserror", "time", - "tokio 0.2.25", - "tokio-macros", - "tower", + "tokio 1.10.1", + "tower 0.3.1", ] [[package]] @@ -5140,7 +5072,7 @@ dependencies = [ "tari_wallet", "tempfile", "thiserror", - "tokio 0.2.25", + "tokio 1.10.1", ] [[package]] @@ -5171,12 +5103,13 @@ name = "test_faucet" version = "0.9.5" dependencies = [ "rand 0.8.4", - "serde 1.0.129", + "serde 1.0.130", "serde_json", + 
"tari_common_types", "tari_core", "tari_crypto", "tari_utilities", - "tokio 0.2.25", + "tokio 1.10.1", ] [[package]] @@ -5190,18 +5123,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.26" +version = "1.0.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93119e4feac1cbe6c798c34d3a53ea0026b0b1de6a120deef895137c0529bfe2" +checksum = "283d5230e63df9608ac7d9691adc1dfb6e701225436eb64d0b9a7f0a5a04f6ec" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.26" +version = "1.0.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "060d69a0afe7796bf42e9e2ff91f5ee691fb15c53d38b4b62a9a53eb23164745" +checksum = "fa3884228611f5cd3608e2d409bf7dce832e4eb3135e3f11addbd7e41bd68e71" dependencies = [ "proc-macro2 1.0.28", "quote 1.0.9", @@ -5276,7 +5209,7 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" dependencies = [ - "serde 1.0.129", + "serde 1.0.130", "serde_json", ] @@ -5306,16 +5239,10 @@ dependencies = [ "futures-core", "iovec", "lazy_static 1.4.0", - "libc", "memchr", "mio 0.6.23", - "mio-uds", - "num_cpus", "pin-project-lite 0.1.12", - "signal-hook-registry", "slab", - "tokio-macros", - "winapi 0.3.9", ] [[package]] @@ -5325,20 +5252,33 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92036be488bb6594459f2e03b60e42df6f937fe6ca5c5ffdcb539c6b84dc40f5" dependencies = [ "autocfg 1.0.1", - "bytes 1.0.1", + "bytes 1.1.0", "libc", "memchr", "mio 0.7.13", "num_cpus", + "once_cell", "pin-project-lite 0.2.7", + "signal-hook-registry", + "tokio-macros", "winapi 0.3.9", ] +[[package]] +name = "tokio-io-timeout" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90c49f106be240de154571dd31fbe48acb10ba6c6dd6f6517ad603abffa42de9" +dependencies = [ + "pin-project-lite 0.2.7", + "tokio 
1.10.1", +] + [[package]] name = "tokio-macros" -version = "0.2.6" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e44da00bfc73a25f814cd8d7e57a68a5c31b74b3152a0a1d1f590c97ed06265a" +checksum = "54473be61f4ebe4efd09cec9bd5d16fa51d70ea0192213d754d2d500457db110" dependencies = [ "proc-macro2 1.0.28", "quote 1.0.9", @@ -5355,6 +5295,17 @@ dependencies = [ "tokio 1.10.1", ] +[[package]] +name = "tokio-rustls" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc6844de72e57df1980054b38be3a9f4702aba4858be64dd700181a8a6d0e1b6" +dependencies = [ + "rustls", + "tokio 1.10.1", + "webpki", +] + [[package]] name = "tokio-stream" version = "0.1.7" @@ -5364,6 +5315,7 @@ dependencies = [ "futures-core", "pin-project-lite 0.2.7", "tokio 1.10.1", + "tokio-util 0.6.7", ] [[package]] @@ -5377,6 +5329,19 @@ dependencies = [ "tokio 0.2.25", ] +[[package]] +name = "tokio-test" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53474327ae5e166530d17f2d956afcb4f8a004de581b3cae10f12006bc8163e3" +dependencies = [ + "async-stream", + "bytes 1.1.0", + "futures-core", + "tokio 1.10.1", + "tokio-stream", +] + [[package]] name = "tokio-tls" version = "0.3.1" @@ -5407,8 +5372,9 @@ version = "0.6.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1caa0b0c8d94a049db56b5acf8cba99dc0623aab1b26d5b5f5e2d945846b3592" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "futures-core", + "futures-io", "futures-sink", "log 0.4.14", "pin-project-lite 0.2.7", @@ -5421,7 +5387,7 @@ version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "758664fc71a3a69038656bee8b6be6477d2a6c315a6b81f7081f591bffa4111f" dependencies = [ - "serde 1.0.129", + "serde 1.0.130", ] [[package]] @@ -5430,34 +5396,35 @@ version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"a31142970826733df8241ef35dc040ef98c679ab14d7c3e54d827099b3acecaa" dependencies = [ - "serde 1.0.129", + "serde 1.0.130", ] [[package]] name = "tonic" -version = "0.2.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4afef9ce97ea39593992cf3fa00ff33b1ad5eb07665b31355df63a690e38c736" +checksum = "796c5e1cd49905e65dd8e700d4cb1dffcbfdb4fc9d017de08c1a537afd83627c" dependencies = [ "async-stream", "async-trait", - "base64 0.11.0", - "bytes 0.5.6", + "base64 0.13.0", + "bytes 1.1.0", "futures-core", "futures-util", + "h2 0.3.4", "http", - "http-body 0.3.1", - "hyper 0.13.10", + "http-body 0.4.3", + "hyper 0.14.12", + "hyper-timeout", "percent-encoding 2.1.0", - "pin-project 0.4.28", + "pin-project 1.0.8", "prost", "prost-derive", - "tokio 0.2.25", - "tokio-util 0.3.1", - "tower", - "tower-balance", - "tower-load", - "tower-make", + "tokio 1.10.1", + "tokio-stream", + "tokio-util 0.6.7", + "tower 0.4.8", + "tower-layer", "tower-service", "tracing", "tracing-futures", @@ -5465,9 +5432,9 @@ dependencies = [ [[package]] name = "tonic-build" -version = "0.2.0" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71d8d21cb568e802d77055ab7fcd43f0992206de5028de95c8d3a41118d32e8e" +checksum = "12b52d07035516c2b74337d2ac7746075e7dcae7643816c1b12c5ff8a7484c08" dependencies = [ "proc-macro2 1.0.28", "prost-build", @@ -5494,23 +5461,21 @@ dependencies = [ ] [[package]] -name = "tower-balance" -version = "0.3.0" +name = "tower" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a792277613b7052448851efcf98a2c433e6f1d01460832dc60bef676bc275d4c" +checksum = "f60422bc7fefa2f3ec70359b8ff1caff59d785877eb70595904605bcc412470f" dependencies = [ "futures-core", "futures-util", "indexmap", - "pin-project 0.4.28", - "rand 0.7.3", + "pin-project 1.0.8", + "rand 0.8.4", "slab", - "tokio 0.2.25", - "tower-discover", + "tokio 1.10.1", + "tokio-stream", + 
"tokio-util 0.6.7", "tower-layer", - "tower-load", - "tower-make", - "tower-ready-cache", "tower-service", "tracing", ] @@ -5592,21 +5557,6 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ce50370d644a0364bf4877ffd4f76404156a248d104e2cc234cd391ea5cdc965" dependencies = [ - "tokio 0.2.25", - "tower-service", -] - -[[package]] -name = "tower-ready-cache" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4eabb6620e5481267e2ec832c780b31cad0c15dcb14ed825df5076b26b591e1f" -dependencies = [ - "futures-core", - "futures-util", - "indexmap", - "log 0.4.14", - "tokio 0.2.25", "tower-service", ] @@ -5638,7 +5588,7 @@ dependencies = [ "futures-util", "pin-project 0.4.28", "tokio 0.2.25", - "tokio-test", + "tokio-test 0.2.1", "tower-layer", "tower-service", ] @@ -5740,7 +5690,7 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fb65ea441fbb84f9f6748fd496cf7f63ec9af5bca94dd86456978d055e8eb28b" dependencies = [ - "serde 1.0.129", + "serde 1.0.130", "tracing-core", ] @@ -5755,7 +5705,7 @@ dependencies = [ "lazy_static 1.4.0", "matchers", "regex", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "sharded-slab", "smallvec", @@ -5774,47 +5724,54 @@ checksum = "efd1f82c56340fdf16f2a953d7bda4f8fdffba13d93b00844c25572110b26079" [[package]] name = "trust-dns-client" -version = "0.19.7" +version = "0.21.0-alpha.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e935ae5a26a2745fb5a6b95f0e206e1cfb7f00066892d2cf78a8fee87bc2e0c6" +checksum = "37532ce92c75c6174b1d51ed612e26c5fde66ef3f29aa10dbd84e7c5d9a0c27b" dependencies = [ "cfg-if 1.0.0", "chrono", "data-encoding", - "futures 0.3.16", + "futures-channel", + "futures-util", "lazy_static 1.4.0", "log 0.4.14", "radix_trie", - "rand 0.7.3", + "rand 0.8.4", "ring", "rustls", "thiserror", - "tokio 0.2.25", + "tokio 1.10.1", "trust-dns-proto", "webpki", ] [[package]] name = 
"trust-dns-proto" -version = "0.19.7" +version = "0.21.0-alpha.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cad71a0c0d68ab9941d2fb6e82f8fb2e86d9945b94e1661dd0aaea2b88215a9" +checksum = "4cd23117e93ea0e776abfd8a07c9e389d7ecd3377827858f21bd795ebdfefa36" dependencies = [ "async-trait", - "backtrace", "cfg-if 1.0.0", "data-encoding", "enum-as-inner", - "futures 0.3.16", + "futures-channel", + "futures-io", + "futures-util", "idna 0.2.3", + "ipnet", "lazy_static 1.4.0", "log 0.4.14", - "rand 0.7.3", + "rand 0.8.4", "ring", + "rustls", "smallvec", "thiserror", - "tokio 0.2.25", + "tinyvec", + "tokio 1.10.1", + "tokio-rustls", "url 2.2.2", + "webpki", ] [[package]] @@ -5856,12 +5813,12 @@ dependencies = [ [[package]] name = "twofish" -version = "0.4.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7a30db256d7388f6e08efa0a8e9e62ee34dd1af59706c76c9e8c97c2a500f12" +checksum = "0028f5982f23ecc9a1bc3008ead4c664f843ed5d78acd3d213b99ff50c441bc2" dependencies = [ - "block-cipher", "byteorder", + "cipher 0.2.5", "opaque-debug", ] @@ -5988,15 +5945,9 @@ dependencies = [ [[package]] name = "unsigned-varint" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7fdeedbf205afadfe39ae559b75c3240f24e257d0ca27e85f85cb82aa19ac35" - -[[package]] -name = "unsigned-varint" -version = "0.6.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35581ff83d4101e58b582e607120c7f5ffb17e632a980b1f38334d76b36908b2" +checksum = "5f8d425fafb8cd76bc3f22aace4af471d3156301d7508f2107e98fbeae10bc7f" [[package]] name = "untrusted" @@ -6097,7 +6048,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ce9b1b516211d33767048e5d47fa2a381ed8b76fc48d2ce4aa39877f9f183e0" dependencies = [ "cfg-if 1.0.0", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "wasm-bindgen-macro", ] @@ -6187,6 +6138,17 @@ 
dependencies = [ "libc", ] +[[package]] +name = "which" +version = "4.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea187a8ef279bc014ec368c27a920da2024d2a711109bfbe3440585d5cf27ad9" +dependencies = [ + "either", + "lazy_static 1.4.0", + "libc", +] + [[package]] name = "winapi" version = "0.2.8" @@ -6278,7 +6240,7 @@ dependencies = [ "futures 0.3.16", "log 0.4.14", "nohash-hasher", - "parking_lot 0.11.1", + "parking_lot 0.11.2", "rand 0.8.4", "static_assertions", ] diff --git a/Cargo.toml b/Cargo.toml index 2394f2eb9c..ac268c2585 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,3 +25,6 @@ members = [ "applications/tari_stratum_transcoder", "applications/tari_mining_node", ] +# +#[profile.release] +#debug = true diff --git a/applications/ffi_client/index.js b/applications/ffi_client/index.js index bf1a47d3c2..aa9473740b 100644 --- a/applications/ffi_client/index.js +++ b/applications/ffi_client/index.js @@ -1,6 +1,8 @@ // this is nasty // ¯\_(ツ)_/¯ +// TODO: Use implementation in cucumber tests instead (see helpers/ffi). 
+ const lib = require("./lib"); const ref = require("ref-napi"); const ffi = require("ffi-napi"); diff --git a/applications/tari_app_grpc/Cargo.toml b/applications/tari_app_grpc/Cargo.toml index bec7e05720..8a40dba4ee 100644 --- a/applications/tari_app_grpc/Cargo.toml +++ b/applications/tari_app_grpc/Cargo.toml @@ -10,14 +10,17 @@ edition = "2018" [dependencies] tari_common_types = { version = "^0.9", path = "../../base_layer/common_types"} tari_core = { path = "../../base_layer/core"} -tari_wallet = { path = "../../base_layer/wallet"} +tari_wallet = { path = "../../base_layer/wallet", optional = true} tari_crypto = "0.11.1" tari_comms = { path = "../../comms"} chrono = "0.4.6" -prost = "0.6" -prost-types = "0.6.1" -tonic = "0.2" +prost = "0.8" +prost-types = "0.8" +tonic = "0.5.2" [build-dependencies] -tonic-build = "0.2" +tonic-build = "0.5.2" + +[features] +wallet = ["tari_wallet"] \ No newline at end of file diff --git a/applications/tari_app_grpc/src/conversions/block_header.rs b/applications/tari_app_grpc/src/conversions/block_header.rs index 3b660f21e1..5bb37e87e3 100644 --- a/applications/tari_app_grpc/src/conversions/block_header.rs +++ b/applications/tari_app_grpc/src/conversions/block_header.rs @@ -25,7 +25,8 @@ use crate::{ tari_rpc as grpc, }; use std::convert::TryFrom; -use tari_core::{blocks::BlockHeader, proof_of_work::ProofOfWork, transactions::types::BlindingFactor}; +use tari_common_types::types::BlindingFactor; +use tari_core::{blocks::BlockHeader, proof_of_work::ProofOfWork}; use tari_crypto::tari_utilities::{ByteArray, Hashable}; impl From for grpc::BlockHeader { diff --git a/applications/tari_app_grpc/src/conversions/com_signature.rs b/applications/tari_app_grpc/src/conversions/com_signature.rs index e10e48ffe8..1924e1c054 100644 --- a/applications/tari_app_grpc/src/conversions/com_signature.rs +++ b/applications/tari_app_grpc/src/conversions/com_signature.rs @@ -21,10 +21,10 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF 
SUCH DAMAGE. use std::convert::TryFrom; -use tari_crypto::tari_utilities::ByteArray; +use tari_core::crypto::tari_utilities::ByteArray; use crate::tari_rpc as grpc; -use tari_core::transactions::types::{ComSignature, Commitment, PrivateKey}; +use tari_common_types::types::{ComSignature, Commitment, PrivateKey}; impl TryFrom for ComSignature { type Error = String; diff --git a/applications/tari_app_grpc/src/conversions/mod.rs b/applications/tari_app_grpc/src/conversions/mod.rs index e404d52d1a..f48a1b876d 100644 --- a/applications/tari_app_grpc/src/conversions/mod.rs +++ b/applications/tari_app_grpc/src/conversions/mod.rs @@ -59,7 +59,7 @@ pub use self::{ use crate::{tari_rpc as grpc, tari_rpc::BlockGroupRequest}; use prost_types::Timestamp; -use tari_crypto::tari_utilities::epoch_time::EpochTime; +use tari_core::crypto::tari_utilities::epoch_time::EpochTime; /// Utility function that converts a `EpochTime` to a `prost::Timestamp` pub fn datetime_to_timestamp(datetime: EpochTime) -> Timestamp { diff --git a/applications/tari_app_grpc/src/conversions/new_block_template.rs b/applications/tari_app_grpc/src/conversions/new_block_template.rs index 7d87900cd1..15c41c499b 100644 --- a/applications/tari_app_grpc/src/conversions/new_block_template.rs +++ b/applications/tari_app_grpc/src/conversions/new_block_template.rs @@ -22,12 +22,12 @@ use crate::tari_rpc as grpc; use std::convert::{TryFrom, TryInto}; +use tari_common_types::types::BlindingFactor; use tari_core::{ blocks::{NewBlockHeaderTemplate, NewBlockTemplate}, + crypto::tari_utilities::ByteArray, proof_of_work::ProofOfWork, - transactions::types::BlindingFactor, }; -use tari_crypto::tari_utilities::ByteArray; impl From for grpc::NewBlockTemplate { fn from(block: NewBlockTemplate) -> Self { let header = grpc::NewBlockHeaderTemplate { diff --git a/applications/tari_app_grpc/src/conversions/peer.rs b/applications/tari_app_grpc/src/conversions/peer.rs index f04d3fd8dd..f3bd151d0d 100644 --- 
a/applications/tari_app_grpc/src/conversions/peer.rs +++ b/applications/tari_app_grpc/src/conversions/peer.rs @@ -22,7 +22,7 @@ use crate::{conversions::datetime_to_timestamp, tari_rpc as grpc}; use tari_comms::{connectivity::ConnectivityStatus, net_address::MutliaddrWithStats, peer_manager::Peer}; -use tari_crypto::tari_utilities::ByteArray; +use tari_core::crypto::tari_utilities::ByteArray; impl From for grpc::Peer { fn from(peer: Peer) -> Self { diff --git a/applications/tari_app_grpc/src/conversions/signature.rs b/applications/tari_app_grpc/src/conversions/signature.rs index d9883a338e..2f0fe10cd3 100644 --- a/applications/tari_app_grpc/src/conversions/signature.rs +++ b/applications/tari_app_grpc/src/conversions/signature.rs @@ -21,10 +21,10 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use std::convert::TryFrom; -use tari_crypto::tari_utilities::ByteArray; +use tari_core::crypto::tari_utilities::ByteArray; use crate::tari_rpc as grpc; -use tari_core::transactions::types::{PrivateKey, PublicKey, Signature}; +use tari_common_types::types::{PrivateKey, PublicKey, Signature}; impl TryFrom for Signature { type Error = String; diff --git a/applications/tari_app_grpc/src/conversions/transaction.rs b/applications/tari_app_grpc/src/conversions/transaction.rs index 97dcaf8795..cb9e91f63a 100644 --- a/applications/tari_app_grpc/src/conversions/transaction.rs +++ b/applications/tari_app_grpc/src/conversions/transaction.rs @@ -22,9 +22,10 @@ use crate::tari_rpc as grpc; use std::convert::{TryFrom, TryInto}; -use tari_core::transactions::transaction::Transaction; -use tari_crypto::{ristretto::RistrettoSecretKey, tari_utilities::ByteArray}; -use tari_wallet::{output_manager_service::TxId, transaction_service::storage::models}; +use tari_core::{ + crypto::{ristretto::RistrettoSecretKey, tari_utilities::ByteArray}, + transactions::transaction::Transaction, +}; impl From for grpc::Transaction { fn from(source: Transaction) -> Self { @@ -53,38 
+54,44 @@ impl TryFrom for Transaction { } } -impl From for grpc::TransactionStatus { - fn from(status: models::TransactionStatus) -> Self { - use models::TransactionStatus::*; - match status { - Completed => grpc::TransactionStatus::Completed, - Broadcast => grpc::TransactionStatus::Broadcast, - MinedUnconfirmed => grpc::TransactionStatus::MinedUnconfirmed, - MinedConfirmed => grpc::TransactionStatus::MinedConfirmed, - Imported => grpc::TransactionStatus::Imported, - Pending => grpc::TransactionStatus::Pending, - Coinbase => grpc::TransactionStatus::Coinbase, +#[cfg(feature = "wallet")] +mod wallet { + use super::*; + use tari_wallet::{output_manager_service::TxId, transaction_service::storage::models}; + + impl From for grpc::TransactionStatus { + fn from(status: models::TransactionStatus) -> Self { + use models::TransactionStatus::*; + match status { + Completed => grpc::TransactionStatus::Completed, + Broadcast => grpc::TransactionStatus::Broadcast, + MinedUnconfirmed => grpc::TransactionStatus::MinedUnconfirmed, + MinedConfirmed => grpc::TransactionStatus::MinedConfirmed, + Imported => grpc::TransactionStatus::Imported, + Pending => grpc::TransactionStatus::Pending, + Coinbase => grpc::TransactionStatus::Coinbase, + } } } -} -impl From for grpc::TransactionDirection { - fn from(status: models::TransactionDirection) -> Self { - use models::TransactionDirection::*; - match status { - Unknown => grpc::TransactionDirection::Unknown, - Inbound => grpc::TransactionDirection::Inbound, - Outbound => grpc::TransactionDirection::Outbound, + impl From for grpc::TransactionDirection { + fn from(status: models::TransactionDirection) -> Self { + use models::TransactionDirection::*; + match status { + Unknown => grpc::TransactionDirection::Unknown, + Inbound => grpc::TransactionDirection::Inbound, + Outbound => grpc::TransactionDirection::Outbound, + } } } -} -impl grpc::TransactionInfo { - pub fn not_found(tx_id: TxId) -> Self { - Self { - tx_id, - status: 
grpc::TransactionStatus::NotFound as i32, - ..Default::default() + impl grpc::TransactionInfo { + pub fn not_found(tx_id: TxId) -> Self { + Self { + tx_id, + status: grpc::TransactionStatus::NotFound as i32, + ..Default::default() + } } } } diff --git a/applications/tari_app_grpc/src/conversions/transaction_input.rs b/applications/tari_app_grpc/src/conversions/transaction_input.rs index ed5793a2f2..48eebe04ad 100644 --- a/applications/tari_app_grpc/src/conversions/transaction_input.rs +++ b/applications/tari_app_grpc/src/conversions/transaction_input.rs @@ -22,10 +22,8 @@ use crate::tari_rpc as grpc; use std::convert::{TryFrom, TryInto}; -use tari_core::transactions::{ - transaction::TransactionInput, - types::{Commitment, PublicKey}, -}; +use tari_common_types::types::{Commitment, PublicKey}; +use tari_core::transactions::transaction::TransactionInput; use tari_crypto::{ script::{ExecutionStack, TariScript}, tari_utilities::{ByteArray, Hashable}, diff --git a/applications/tari_app_grpc/src/conversions/transaction_kernel.rs b/applications/tari_app_grpc/src/conversions/transaction_kernel.rs index e394a6bce5..7bf8664487 100644 --- a/applications/tari_app_grpc/src/conversions/transaction_kernel.rs +++ b/applications/tari_app_grpc/src/conversions/transaction_kernel.rs @@ -22,10 +22,10 @@ use crate::tari_rpc as grpc; use std::convert::{TryFrom, TryInto}; +use tari_common_types::types::Commitment; use tari_core::transactions::{ tari_amount::MicroTari, transaction::{KernelFeatures, TransactionKernel}, - types::Commitment, }; use tari_crypto::tari_utilities::{ByteArray, Hashable}; diff --git a/applications/tari_app_grpc/src/conversions/transaction_output.rs b/applications/tari_app_grpc/src/conversions/transaction_output.rs index b9556b2940..7b783e3498 100644 --- a/applications/tari_app_grpc/src/conversions/transaction_output.rs +++ b/applications/tari_app_grpc/src/conversions/transaction_output.rs @@ -22,14 +22,13 @@ use crate::tari_rpc as grpc; use std::convert::{TryFrom, 
TryInto}; -use tari_core::transactions::{ - bullet_rangeproofs::BulletRangeProof, - transaction::TransactionOutput, - types::{Commitment, PublicKey}, -}; -use tari_crypto::{ - script::TariScript, - tari_utilities::{ByteArray, Hashable}, +use tari_common_types::types::{BulletRangeProof, Commitment, PublicKey}; +use tari_core::{ + crypto::{ + script::TariScript, + tari_utilities::{ByteArray, Hashable}, + }, + transactions::transaction::TransactionOutput, }; impl TryFrom for TransactionOutput { diff --git a/applications/tari_app_grpc/src/conversions/unblinded_output.rs b/applications/tari_app_grpc/src/conversions/unblinded_output.rs index 94ac4c178d..bf9efa58bc 100644 --- a/applications/tari_app_grpc/src/conversions/unblinded_output.rs +++ b/applications/tari_app_grpc/src/conversions/unblinded_output.rs @@ -22,14 +22,13 @@ use crate::tari_rpc as grpc; use std::convert::{TryFrom, TryInto}; -use tari_core::transactions::{ - tari_amount::MicroTari, - transaction::UnblindedOutput, - types::{PrivateKey, PublicKey}, -}; -use tari_crypto::{ - script::{ExecutionStack, TariScript}, - tari_utilities::ByteArray, +use tari_common_types::types::{PrivateKey, PublicKey}; +use tari_core::{ + crypto::{ + script::{ExecutionStack, TariScript}, + tari_utilities::ByteArray, + }, + transactions::{tari_amount::MicroTari, transaction::UnblindedOutput}, }; impl From for grpc::UnblindedOutput { diff --git a/applications/tari_app_utilities/Cargo.toml b/applications/tari_app_utilities/Cargo.toml index 333af5f959..5fdae3d097 100644 --- a/applications/tari_app_utilities/Cargo.toml +++ b/applications/tari_app_utilities/Cargo.toml @@ -8,22 +8,23 @@ edition = "2018" tari_comms = { path = "../../comms"} tari_crypto = "0.11.1" tari_common = { path = "../../common" } +tari_common_types ={ path ="../../base_layer/common_types"} tari_p2p = { path = "../../base_layer/p2p", features = ["auto-update"] } -tari_wallet = { path = "../../base_layer/wallet" } +tari_wallet = { path = "../../base_layer/wallet", 
optional = true } config = { version = "0.9.3" } -futures = { version = "^0.3.1", default-features = false, features = ["alloc"]} +futures = { version = "^0.3.16", default-features = false, features = ["alloc"]} qrcode = { version = "0.12" } dirs-next = "1.0.2" serde_json = "1.0" log = { version = "0.4.8", features = ["std"] } rand = "0.8" -tokio = { version="0.2.10", features = ["signal"] } +tokio = { version="^1.10", features = ["signal"] } structopt = { version = "0.3.13", default_features = false } strum = "^0.19" strum_macros = "^0.19" -thiserror = "^1.0.20" -tonic = "0.2" +thiserror = "^1.0.26" +tonic = "0.5.2" [dependencies.tari_core] path = "../../base_layer/core" @@ -33,3 +34,7 @@ features = ["transactions"] [build-dependencies] tari_common = { path = "../../common", features = ["build", "static-application-info"] } + +[features] +# TODO: This crate is supposed to hold common logic. Move code from this feature into the crate that is more specific to the wallet +wallet = ["tari_wallet"] diff --git a/applications/tari_app_utilities/src/identity_management.rs b/applications/tari_app_utilities/src/identity_management.rs index 40bfcf5cee..013a0e8fc8 100644 --- a/applications/tari_app_utilities/src/identity_management.rs +++ b/applications/tari_app_utilities/src/identity_management.rs @@ -25,8 +25,8 @@ use log::*; use rand::rngs::OsRng; use std::{clone::Clone, fs, path::Path, string::ToString, sync::Arc}; use tari_common::configuration::bootstrap::prompt; +use tari_common_types::types::PrivateKey; use tari_comms::{multiaddr::Multiaddr, peer_manager::PeerFeatures, NodeIdentity}; -use tari_core::transactions::types::PrivateKey; use tari_crypto::{ keys::SecretKey, tari_utilities::{hex::Hex, message_format::MessageFormat}, @@ -55,14 +55,19 @@ pub fn setup_node_identity>( if !create_id { let prompt = prompt("Node identity does not exist.\nWould you like to to create one (Y/n)?"); if !prompt { - let msg = format!( + error!( + target: LOG_TARGET, "Node identity 
information not found. {}. You can update the configuration file to point to a \ valid node identity file, or re-run the node with the --create-id flag to create a new \ identity.", e ); - error!(target: LOG_TARGET, "{}", msg); - return Err(ExitCodes::ConfigError(msg)); + return Err(ExitCodes::ConfigError(format!( + "Node identity information not found. {}. You can update the configuration file to point to a \ + valid node identity file, or re-run the node with the --create-id flag to create a new \ + identity.", + e + ))); }; } diff --git a/applications/tari_app_utilities/src/initialization.rs b/applications/tari_app_utilities/src/initialization.rs index ad66210437..2497788307 100644 --- a/applications/tari_app_utilities/src/initialization.rs +++ b/applications/tari_app_utilities/src/initialization.rs @@ -1,8 +1,13 @@ use crate::{consts, utilities::ExitCodes}; use config::Config; -use std::path::PathBuf; +use std::{path::PathBuf, str::FromStr}; use structopt::StructOpt; -use tari_common::{configuration::bootstrap::ApplicationType, ConfigBootstrap, DatabaseType, GlobalConfig}; +use tari_common::{ + configuration::{bootstrap::ApplicationType, Network}, + ConfigBootstrap, + DatabaseType, + GlobalConfig, +}; pub const LOG_TARGET: &str = "tari::application"; @@ -27,6 +32,28 @@ pub fn init_configuration( let mut global_config = GlobalConfig::convert_from(application_type, cfg.clone()) .map_err(|err| ExitCodes::ConfigError(err.to_string()))?; check_file_paths(&mut global_config, &bootstrap); + + if let Some(str) = bootstrap.network.clone() { + log::info!(target: LOG_TARGET, "Network selection requested"); + let network = Network::from_str(&str); + match network { + Ok(network) => { + log::info!( + target: LOG_TARGET, + "Network selection successful, current network is: {}", + network + ); + global_config.network = network; + }, + Err(_) => { + log::warn!( + target: LOG_TARGET, + "Network selection was invalid, continuing with default network." 
+ ); + }, + } + } + Ok((bootstrap, global_config, cfg)) } diff --git a/applications/tari_app_utilities/src/utilities.rs b/applications/tari_app_utilities/src/utilities.rs index d963f212a7..23a1ebf9ab 100644 --- a/applications/tari_app_utilities/src/utilities.rs +++ b/applications/tari_app_utilities/src/utilities.rs @@ -20,9 +20,11 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::identity_management::load_from_json; use futures::future::Either; use log::*; +use thiserror::Error; +use tokio::{runtime, runtime::Runtime}; + use tari_common::{CommsTransport, GlobalConfig, SocksAuthentication, TorControlAuthentication}; use tari_comms::{ connectivity::ConnectivityError, @@ -37,13 +39,9 @@ use tari_comms::{ }; use tari_core::tari_utilities::hex::Hex; use tari_p2p::transport::{TorConfig, TransportType}; -use tari_wallet::{ - error::{WalletError, WalletStorageError}, - output_manager_service::error::OutputManagerError, - util::emoji::EmojiId, -}; -use thiserror::Error; -use tokio::{runtime, runtime::Runtime}; + +use crate::identity_management::load_from_json; +use tari_common_types::emoji::EmojiId; pub const LOG_TARGET: &str = "tari::application"; @@ -107,20 +105,6 @@ impl From for ExitCodes { } } -impl From for ExitCodes { - fn from(err: WalletError) -> Self { - error!(target: LOG_TARGET, "{}", err); - Self::WalletError(err.to_string()) - } -} - -impl From for ExitCodes { - fn from(err: OutputManagerError) -> Self { - error!(target: LOG_TARGET, "{}", err); - Self::WalletError(err.to_string()) - } -} - impl From for ExitCodes { fn from(err: ConnectivityError) -> Self { error!(target: LOG_TARGET, "{}", err); @@ -135,13 +119,36 @@ impl From for ExitCodes { } } -impl From for ExitCodes { - fn from(err: WalletStorageError) -> Self { - use WalletStorageError::*; - match err { - NoPasswordError => ExitCodes::NoPassword, - 
IncorrectPassword => ExitCodes::IncorrectPassword, - e => ExitCodes::WalletError(e.to_string()), +#[cfg(feature = "wallet")] +mod wallet { + use super::*; + use tari_wallet::{ + error::{WalletError, WalletStorageError}, + output_manager_service::error::OutputManagerError, + }; + + impl From for ExitCodes { + fn from(err: WalletError) -> Self { + error!(target: LOG_TARGET, "{}", err); + Self::WalletError(err.to_string()) + } + } + + impl From for ExitCodes { + fn from(err: OutputManagerError) -> Self { + error!(target: LOG_TARGET, "{}", err); + Self::WalletError(err.to_string()) + } + } + + impl From for ExitCodes { + fn from(err: WalletStorageError) -> Self { + use WalletStorageError::*; + match err { + NoPasswordError => ExitCodes::NoPassword, + IncorrectPassword => ExitCodes::IncorrectPassword, + e => ExitCodes::WalletError(e.to_string()), + } } } } @@ -259,26 +266,22 @@ pub fn convert_socks_authentication(auth: SocksAuthentication) -> socks::Authent /// ## Returns /// A result containing the runtime on success, string indicating the error on failure pub fn setup_runtime(config: &GlobalConfig) -> Result { - info!( - target: LOG_TARGET, - "Configuring the node to run on up to {} core threads and {} mining threads.", - config.max_threads.unwrap_or(512), - config.num_mining_threads - ); - - let mut builder = runtime::Builder::new(); + let mut builder = runtime::Builder::new_multi_thread(); - if let Some(max_threads) = config.max_threads { - // Ensure that there are always enough threads for mining. 
- // e.g if the user sets max_threads = 2, mining_threads = 5 then 7 threads are available in total - builder.max_threads(max_threads + config.num_mining_threads); - } if let Some(core_threads) = config.core_threads { - builder.core_threads(core_threads); + info!( + target: LOG_TARGET, + "Configuring the node to run on up to {} core threads.", + config + .core_threads + .as_ref() + .map(ToString::to_string) + .unwrap_or_else(|| "".to_string()), + ); + builder.worker_threads(core_threads); } builder - .threaded_scheduler() .enable_all() .build() .map_err(|e| format!("There was an error while building the node runtime. {}", e.to_string())) diff --git a/applications/tari_base_node/Cargo.toml b/applications/tari_base_node/Cargo.toml index db832f8a9a..b41db079ef 100644 --- a/applications/tari_base_node/Cargo.toml +++ b/applications/tari_base_node/Cargo.toml @@ -11,30 +11,30 @@ edition = "2018" tari_app_grpc = { path = "../tari_app_grpc" } tari_app_utilities = { path = "../tari_app_utilities" } tari_common = { path = "../../common" } -tari_comms = { path = "../../comms", features = ["rpc"]} -tari_comms_dht = { path = "../../comms/dht"} -tari_core = { path = "../../base_layer/core", default-features = false, features = ["transactions"]} +tari_comms = { path = "../../comms", features = ["rpc"] } +tari_common_types = {path = "../../base_layer/common_types"} +tari_comms_dht = { path = "../../comms/dht" } +tari_core = { path = "../../base_layer/core", default-features = false, features = ["transactions"] } tari_crypto = "0.11.1" tari_mmr = { path = "../../base_layer/mmr" } tari_p2p = { path = "../../base_layer/p2p", features = ["auto-update"] } -tari_service_framework = { path = "../../base_layer/service_framework"} -tari_shutdown = { path = "../../infrastructure/shutdown"} -tari_wallet = { path = "../../base_layer/wallet" } +tari_service_framework = { path = "../../base_layer/service_framework" } +tari_shutdown = { path = "../../infrastructure/shutdown" } anyhow = "1.0.32" 
bincode = "1.3.1" chrono = "0.4" config = { version = "0.9.3" } -futures = { version = "^0.3.1", default-features = false, features = ["alloc"]} +futures = { version = "^0.3.16", default-features = false, features = ["alloc"] } log = { version = "0.4.8", features = ["std"] } regex = "1" rustyline = "6.0" rustyline-derive = "0.3" -tokio = { version="0.2.10", features = ["signal"] } +tokio = { version = "^1.10", features = ["signal"] } strum = "^0.19" strum_macros = "0.18.0" -thiserror = "^1.0.20" -tonic = "0.2" +thiserror = "^1.0.26" +tonic = "0.5.2" tracing = "0.1.26" tracing-opentelemetry = "0.15.0" tracing-subscriber = "0.2.20" @@ -44,7 +44,7 @@ opentelemetry = { version = "0.16", default-features = false, features = ["trace opentelemetry-jaeger = { version="0.15", features=["rt-tokio"]} [features] -avx2 = ["tari_core/avx2", "tari_crypto/avx2", "tari_p2p/avx2", "tari_wallet/avx2", "tari_comms/avx2", "tari_comms_dht/avx2"] +avx2 = ["tari_core/avx2", "tari_crypto/avx2", "tari_p2p/avx2", "tari_comms/avx2", "tari_comms_dht/avx2"] safe = [] diff --git a/applications/tari_base_node/src/bootstrap.rs b/applications/tari_base_node/src/bootstrap.rs index cd10870e6d..7cb00ce527 100644 --- a/applications/tari_base_node/src/bootstrap.rs +++ b/applications/tari_base_node/src/bootstrap.rs @@ -20,9 +20,11 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+use std::{cmp, fs, str::FromStr, sync::Arc, time::Duration}; + use anyhow::anyhow; use log::*; -use std::{cmp, fs, str::FromStr, sync::Arc, time::Duration}; + use tari_app_utilities::{consts, identity_management, utilities::create_transport_type}; use tari_common::{configuration::bootstrap::ApplicationType, GlobalConfig}; use tari_comms::{peer_manager::Peer, protocol::rpc::RpcServer, NodeIdentity, UnspawnedCommsNode}; @@ -47,7 +49,7 @@ use tari_core::{ MempoolServiceInitializer, MempoolSyncInitializer, }, - transactions::types::CryptoFactories, + transactions::CryptoFactories, }; use tari_p2p::{ auto_update::{AutoUpdateConfig, SoftwareUpdaterService}, @@ -59,7 +61,6 @@ use tari_p2p::{ }; use tari_service_framework::{ServiceHandles, StackBuilder}; use tari_shutdown::ShutdownSignal; -use tokio::runtime; const LOG_TARGET: &str = "c::bn::initialization"; /// The minimum buffer size for the base node pubsub_connector channel @@ -84,8 +85,7 @@ where B: BlockchainBackend + 'static fs::create_dir_all(&config.peer_db_path)?; let buf_size = cmp::max(BASE_NODE_BUFFER_MIN_SIZE, config.buffer_size_base_node); - let (publisher, peer_message_subscriptions) = - pubsub_connector(runtime::Handle::current(), buf_size, config.buffer_rate_limit_base_node); + let (publisher, peer_message_subscriptions) = pubsub_connector(buf_size, config.buffer_rate_limit_base_node); let peer_message_subscriptions = Arc::new(peer_message_subscriptions); let node_config = BaseNodeServiceConfig { diff --git a/applications/tari_base_node/src/builder.rs b/applications/tari_base_node/src/builder.rs index dc36f64f59..ee374b339e 100644 --- a/applications/tari_base_node/src/builder.rs +++ b/applications/tari_base_node/src/builder.rs @@ -20,9 +20,11 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-use crate::bootstrap::BaseNodeBootstrapper; -use log::*; use std::sync::Arc; + +use log::*; +use tokio::sync::watch; + use tari_common::{configuration::Network, DatabaseType, GlobalConfig}; use tari_comms::{peer_manager::NodeIdentity, protocol::rpc::RpcServerHandle, CommsNode}; use tari_comms_dht::Dht; @@ -32,7 +34,7 @@ use tari_core::{ consensus::ConsensusManager, mempool::{service::LocalMempoolService, Mempool, MempoolConfig}, proof_of_work::randomx_factory::RandomXFactory, - transactions::types::CryptoFactories, + transactions::CryptoFactories, validation::{ block_validators::{BodyOnlyValidator, OrphanBlockValidator}, header_validator::HeaderValidator, @@ -48,7 +50,8 @@ use tari_core::{ use tari_p2p::{auto_update::SoftwareUpdaterHandle, services::liveness::LivenessHandle}; use tari_service_framework::ServiceHandles; use tari_shutdown::ShutdownSignal; -use tokio::sync::watch; + +use crate::bootstrap::BaseNodeBootstrapper; const LOG_TARGET: &str = "c::bn::initialization"; @@ -71,10 +74,8 @@ impl BaseNodeContext { pub async fn run(self) { info!(target: LOG_TARGET, "Tari base node has STARTED"); - if let Err(e) = self.state_machine().shutdown_signal().await { - warn!(target: LOG_TARGET, "Error shutting down Base Node State Machine: {}", e); - } - info!(target: LOG_TARGET, "Initiating communications stack shutdown"); + self.state_machine().shutdown_signal().wait().await; + info!(target: LOG_TARGET, "Waiting for communications stack shutdown"); self.base_node_comms.wait_until_shutdown().await; info!(target: LOG_TARGET, "Communications stack has shutdown"); @@ -222,7 +223,11 @@ async fn build_node_context( let validators = Validators::new( BodyOnlyValidator::default(), HeaderValidator::new(rules.clone()), - OrphanBlockValidator::new(rules.clone(), factories.clone()), + OrphanBlockValidator::new( + rules.clone(), + config.base_node_bypass_range_proof_verification, + factories.clone(), + ), ); let db_config = BlockchainDatabaseConfig { orphan_storage_capacity: 
config.orphan_storage_capacity, @@ -238,7 +243,10 @@ async fn build_node_context( cleanup_orphans_at_startup, )?; let mempool_validator = MempoolValidator::new(vec![ - Box::new(TxInternalConsistencyValidator::new(factories.clone())), + Box::new(TxInternalConsistencyValidator::new( + factories.clone(), + config.base_node_bypass_range_proof_verification, + )), Box::new(TxInputAndMaturityValidator::new(blockchain_db.clone())), Box::new(TxConsensusValidator::new(blockchain_db.clone())), ]); diff --git a/applications/tari_base_node/src/command_handler.rs b/applications/tari_base_node/src/command_handler.rs index 114cee7ad2..d626c4fa0a 100644 --- a/applications/tari_base_node/src/command_handler.rs +++ b/applications/tari_base_node/src/command_handler.rs @@ -34,6 +34,10 @@ use std::{ }; use tari_app_utilities::consts; use tari_common::GlobalConfig; +use tari_common_types::{ + emoji::EmojiId, + types::{Commitment, HashOutput, Signature}, +}; use tari_comms::{ connectivity::ConnectivityRequester, peer_manager::{NodeId, Peer, PeerFeatures, PeerManager, PeerManagerError, PeerQuery}, @@ -53,11 +57,9 @@ use tari_core::{ mempool::service::LocalMempoolService, proof_of_work::PowAlgorithm, tari_utilities::{hex::Hex, message_format::MessageFormat}, - transactions::types::{Commitment, HashOutput, Signature}, }; use tari_crypto::{ristretto::RistrettoPublicKey, tari_utilities::Hashable}; use tari_p2p::auto_update::SoftwareUpdaterHandle; -use tari_wallet::util::emoji::EmojiId; use tokio::{runtime, sync::watch}; pub enum StatusOutput { @@ -101,7 +103,7 @@ impl CommandHandler { } pub fn status(&self, output: StatusOutput) { - let mut state_info = self.state_machine_info.clone(); + let state_info = self.state_machine_info.clone(); let mut node = self.node_service.clone(); let mut mempool = self.mempool_service.clone(); let peer_manager = self.peer_manager.clone(); @@ -114,9 +116,9 @@ impl CommandHandler { let mut status_line = StatusLine::new(); let version = format!("v{}", 
consts::APP_VERSION_NUMBER); status_line.add_field("", version); - - let state = state_info.recv().await.unwrap(); - status_line.add_field("State", state.state_info.short_desc()); + let network = format!("{}", config.network); + status_line.add_field("", network); + status_line.add_field("State", state_info.borrow().state_info.short_desc()); let metadata = node.get_metadata().await.unwrap(); @@ -189,18 +191,8 @@ impl CommandHandler { /// Function to process the get-state-info command pub fn state_info(&self) { - let mut channel = self.state_machine_info.clone(); - self.executor.spawn(async move { - match channel.recv().await { - None => { - info!( - target: LOG_TARGET, - "Error communicating with state machine, channel could have been closed" - ); - }, - Some(data) => println!("Current state machine state:\n{}", data), - }; - }); + let watch = self.state_machine_info.clone(); + println!("Current state machine state:\n{}", *watch.borrow()); } /// Check for updates diff --git a/applications/tari_base_node/src/grpc/base_node_grpc_server.rs b/applications/tari_base_node/src/grpc/base_node_grpc_server.rs index 00474a5cb4..e3a19ef185 100644 --- a/applications/tari_base_node/src/grpc/base_node_grpc_server.rs +++ b/applications/tari_base_node/src/grpc/base_node_grpc_server.rs @@ -26,6 +26,7 @@ use crate::{ helpers::{mean, median}, }, }; +use futures::{channel::mpsc, SinkExt}; use log::*; use std::{ cmp, @@ -36,11 +37,11 @@ use tari_app_grpc::{ tari_rpc::{CalcType, Sorting}, }; use tari_app_utilities::consts; +use tari_common_types::types::Signature; use tari_comms::{Bytes, CommsNode}; use tari_core::{ base_node::{ comms_interface::{Broadcast, CommsInterfaceError}, - state_machine_service::states::BlockSyncInfo, LocalNodeCommsInterface, StateMachineHandle, }, @@ -50,11 +51,11 @@ use tari_core::{ crypto::tari_utilities::{hex::Hex, ByteArray}, mempool::{service::LocalMempoolService, TxStorageResponse}, proof_of_work::PowAlgorithm, - transactions::{transaction::Transaction, 
types::Signature}, + transactions::transaction::Transaction, }; use tari_crypto::tari_utilities::{message_format::MessageFormat, Hashable}; use tari_p2p::{auto_update::SoftwareUpdaterHandle, services::liveness::LivenessHandle}; -use tokio::{sync::mpsc, task}; +use tokio::task; use tonic::{Request, Response, Status}; const LOG_TARGET: &str = "tari::base_node::grpc"; @@ -995,32 +996,25 @@ impl tari_rpc::base_node_server::BaseNode for BaseNodeGrpcServer { ) -> Result, Status> { debug!(target: LOG_TARGET, "Incoming GRPC request for BN sync data"); - let mut channel = self.state_machine_handle.get_status_info_watch(); - - let mut sync_info: Option = None; - - if let Some(info) = channel.recv().await { - sync_info = info.state_info.get_block_sync_info(); - } - - let mut response = tari_rpc::SyncInfoResponse { - tip_height: 0, - local_height: 0, - peer_node_id: vec![], - }; - - if let Some(info) = sync_info { - let node_ids = info - .sync_peers - .iter() - .map(|x| x.to_string().as_bytes().to_vec()) - .collect(); - response = tari_rpc::SyncInfoResponse { - tip_height: info.tip_height, - local_height: info.local_height, - peer_node_id: node_ids, - }; - } + let response = self + .state_machine_handle + .get_status_info_watch() + .borrow() + .state_info + .get_block_sync_info() + .map(|info| { + let node_ids = info + .sync_peers + .iter() + .map(|x| x.to_string().as_bytes().to_vec()) + .collect(); + tari_rpc::SyncInfoResponse { + tip_height: info.tip_height, + local_height: info.local_height, + peer_node_id: node_ids, + } + }) + .unwrap_or_default(); debug!(target: LOG_TARGET, "Sending SyncData response to client"); Ok(Response::new(response)) diff --git a/applications/tari_base_node/src/main.rs b/applications/tari_base_node/src/main.rs index 8ef2c9b68a..cb9daf1904 100644 --- a/applications/tari_base_node/src/main.rs +++ b/applications/tari_base_node/src/main.rs @@ -96,7 +96,7 @@ mod status_line; mod utils; use crate::command_handler::{CommandHandler, StatusOutput}; -use 
futures::{future::Fuse, pin_mut, FutureExt}; +use futures::{pin_mut, FutureExt}; use log::*; use opentelemetry::{self, global, KeyValue}; use parser::Parser; @@ -119,7 +119,7 @@ use tari_shutdown::{Shutdown, ShutdownSignal}; use tokio::{ runtime, task, - time::{self, Delay}, + time::{self}, }; use tonic::transport::Server; use tracing_subscriber::{layer::SubscriberExt, Registry}; @@ -145,7 +145,7 @@ fn main_inner() -> Result<(), ExitCodes> { debug!(target: LOG_TARGET, "Using configuration: {:?}", node_config); // Set up the Tokio runtime - let mut rt = setup_runtime(&node_config).map_err(|e| { + let rt = setup_runtime(&node_config).map_err(|e| { error!(target: LOG_TARGET, "{}", e); ExitCodes::UnknownError })?; @@ -320,26 +320,28 @@ async fn read_command(mut rustyline: Editor) -> Result<(String, Editor

Fuse { +fn status_interval(start_time: Instant) -> time::Sleep { let duration = match start_time.elapsed().as_secs() { 0..=120 => Duration::from_secs(5), _ => Duration::from_secs(30), }; - time::delay_for(duration).fuse() + time::sleep(duration) } async fn status_loop(command_handler: Arc, shutdown: Shutdown) { let start_time = Instant::now(); let mut shutdown_signal = shutdown.to_signal(); loop { - let mut interval = status_interval(start_time); - futures::select! { + let interval = status_interval(start_time); + tokio::select! { + biased; + _ = shutdown_signal.wait() => { + break; + } + _ = interval => { command_handler.status(StatusOutput::Log); }, - _ = shutdown_signal => { - break; - } } } } @@ -368,9 +370,9 @@ async fn cli_loop(parser: Parser, mut shutdown: Shutdown) { let start_time = Instant::now(); let mut software_update_notif = command_handler.get_software_updater().new_update_notifier().clone(); loop { - let mut interval = status_interval(start_time); - futures::select! { - res = read_command_fut => { + let interval = status_interval(start_time); + tokio::select! 
{ + res = &mut read_command_fut => { match res { Ok((line, mut rustyline)) => { if let Some(p) = rustyline.helper_mut().as_deref_mut() { @@ -387,8 +389,8 @@ async fn cli_loop(parser: Parser, mut shutdown: Shutdown) { } } }, - resp = software_update_notif.recv().fuse() => { - if let Some(Some(update)) = resp { + Ok(_) = software_update_notif.changed() => { + if let Some(ref update) = *software_update_notif.borrow() { println!( "Version {} of the {} is available: {} (sha: {})", update.version(), @@ -401,7 +403,7 @@ async fn cli_loop(parser: Parser, mut shutdown: Shutdown) { _ = interval => { command_handler.status(StatusOutput::Full); }, - _ = shutdown_signal => { + _ = shutdown_signal.wait() => { break; } } diff --git a/applications/tari_base_node/src/parser.rs b/applications/tari_base_node/src/parser.rs index f280019ee0..79fdf27efc 100644 --- a/applications/tari_base_node/src/parser.rs +++ b/applications/tari_base_node/src/parser.rs @@ -40,12 +40,8 @@ use tari_app_utilities::utilities::{ parse_emoji_id_or_public_key, parse_emoji_id_or_public_key_or_node_id, }; -use tari_core::{ - crypto::tari_utilities::hex::from_hex, - proof_of_work::PowAlgorithm, - tari_utilities::hex::Hex, - transactions::types::{Commitment, PrivateKey, PublicKey, Signature}, -}; +use tari_common_types::types::{Commitment, PrivateKey, PublicKey, Signature}; +use tari_core::{crypto::tari_utilities::hex::from_hex, proof_of_work::PowAlgorithm, tari_utilities::hex::Hex}; use tari_shutdown::Shutdown; /// Enum representing commands used by the basenode diff --git a/applications/tari_base_node/src/recovery.rs b/applications/tari_base_node/src/recovery.rs index ce62fe36a3..8e0e0786c7 100644 --- a/applications/tari_base_node/src/recovery.rs +++ b/applications/tari_base_node/src/recovery.rs @@ -21,14 +21,16 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
// -use anyhow::anyhow; -use log::*; use std::{ fs, io::{self, Write}, path::Path, sync::Arc, }; + +use anyhow::anyhow; +use log::*; + use tari_app_utilities::utilities::ExitCodes; use tari_common::{configuration::Network, DatabaseType, GlobalConfig}; use tari_core::{ @@ -43,7 +45,7 @@ use tari_core::{ }, consensus::ConsensusManager, proof_of_work::randomx_factory::RandomXFactory, - transactions::types::CryptoFactories, + transactions::CryptoFactories, validation::{ block_validators::{BodyOnlyValidator, OrphanBlockValidator}, header_validator::HeaderValidator, @@ -98,7 +100,11 @@ pub async fn run_recovery(node_config: &GlobalConfig) -> Result<(), anyhow::Erro let validators = Validators::new( BodyOnlyValidator::default(), HeaderValidator::new(rules.clone()), - OrphanBlockValidator::new(rules.clone(), factories.clone()), + OrphanBlockValidator::new( + rules.clone(), + node_config.base_node_bypass_range_proof_verification, + factories.clone(), + ), ); let db_config = BlockchainDatabaseConfig { orphan_storage_capacity: node_config.orphan_storage_capacity, @@ -173,12 +179,12 @@ async fn do_recovery( db.add_block(Arc::new(block)) .await .map_err(|e| anyhow!("Stopped recovery at height {}, reason: {}", counter, e))?; - counter += 1; - if counter > max_height { - info!(target: LOG_TARGET, "Done with recovery, chain height {}", counter - 1); + if counter >= max_height { + info!(target: LOG_TARGET, "Done with recovery, chain height {}", counter); break; } - print!("\x1B[{}D\x1B[K", (counter + 1).to_string().chars().count()); + print!("\x1B[{}D\x1B[K", counter.to_string().len()); + counter += 1; } Ok(()) } diff --git a/applications/tari_console_wallet/Cargo.toml b/applications/tari_console_wallet/Cargo.toml index 5163910d37..17bddfe336 100644 --- a/applications/tari_console_wallet/Cargo.toml +++ b/applications/tari_console_wallet/Cargo.toml @@ -5,21 +5,22 @@ authors = ["The Tari Development Community"] edition = "2018" [dependencies] -tari_wallet = { path = 
"../../base_layer/wallet" } +tari_wallet = { path = "../../base_layer/wallet", features=["bundled_sqlite"] } tari_crypto = "0.11.1" tari_common = { path = "../../common" } -tari_app_utilities = { path = "../tari_app_utilities"} +tari_app_utilities = { path = "../tari_app_utilities", features = ["wallet"]} tari_comms = { path = "../../comms"} tari_comms_dht = { path = "../../comms/dht"} +tari_common_types = {path = "../../base_layer/common_types"} tari_p2p = { path = "../../base_layer/p2p" } -tari_app_grpc = { path = "../tari_app_grpc" } +tari_app_grpc = { path = "../tari_app_grpc", features = ["wallet"] } tari_shutdown = { path = "../../infrastructure/shutdown" } tari_key_manager = { path = "../../base_layer/key_manager" } bitflags = "1.2.1" chrono = { version = "0.4.6", features = ["serde"]} chrono-english = "0.1" -futures = { version = "^0.3.1", default-features = false, features = ["alloc"]} +futures = { version = "^0.3.16", default-features = false, features = ["alloc"]} crossterm = { version = "0.17"} rand = "0.8" unicode-width = "0.1" @@ -31,9 +32,9 @@ rpassword = "5.0" rustyline = "6.0" strum = "^0.19" strum_macros = "^0.19" -tokio = { version="0.2.10", features = ["signal"] } -thiserror = "1.0.20" -tonic = "0.2" +tokio = { version="^1.10", features = ["signal"] } +thiserror = "1.0.26" +tonic = "0.5.2" tracing = "0.1.26" tracing-opentelemetry = "0.15.0" @@ -43,7 +44,6 @@ tracing-subscriber = "0.2.20" opentelemetry = { version = "0.16", default-features = false, features = ["trace","rt-tokio"] } opentelemetry-jaeger = { version="0.15", features=["rt-tokio"]} - [dependencies.tari_core] path = "../../base_layer/core" version = "^0.9" diff --git a/applications/tari_console_wallet/src/automation/command_parser.rs b/applications/tari_console_wallet/src/automation/command_parser.rs index 9555ea7fba..90f0347918 100644 --- a/applications/tari_console_wallet/src/automation/command_parser.rs +++ b/applications/tari_console_wallet/src/automation/command_parser.rs @@ 
-32,7 +32,8 @@ use std::{ use tari_app_utilities::utilities::parse_emoji_id_or_public_key; use tari_comms::multiaddr::Multiaddr; -use tari_core::transactions::{tari_amount::MicroTari, types::PublicKey}; +use tari_common_types::types::PublicKey; +use tari_core::transactions::tari_amount::MicroTari; #[derive(Debug)] pub struct ParsedCommand { @@ -348,7 +349,8 @@ mod test { }; use rand::rngs::OsRng; use std::str::FromStr; - use tari_core::transactions::{tari_amount::MicroTari, types::PublicKey}; + use tari_common_types::types::PublicKey; + use tari_core::transactions::tari_amount::MicroTari; use tari_crypto::keys::PublicKey as PublicKeyTrait; #[test] diff --git a/applications/tari_console_wallet/src/automation/commands.rs b/applications/tari_console_wallet/src/automation/commands.rs index 608cc8a675..7bbdf8da44 100644 --- a/applications/tari_console_wallet/src/automation/commands.rs +++ b/applications/tari_console_wallet/src/automation/commands.rs @@ -21,12 +21,6 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
use super::error::CommandError; -use crate::{ - automation::command_parser::{ParsedArgument, ParsedCommand}, - utils::db::{CUSTOM_BASE_NODE_ADDRESS_KEY, CUSTOM_BASE_NODE_PUBLIC_KEY_KEY}, -}; -use chrono::{DateTime, Utc}; -use futures::{FutureExt, StreamExt}; use log::*; use std::{ fs::File, @@ -34,8 +28,18 @@ use std::{ str::FromStr, time::{Duration, Instant}, }; + +use chrono::{DateTime, Utc}; +use futures::FutureExt; use strum_macros::{Display, EnumIter, EnumString}; +use tari_crypto::ristretto::pedersen::PedersenCommitmentFactory; + +use crate::{ + automation::command_parser::{ParsedArgument, ParsedCommand}, + utils::db::{CUSTOM_BASE_NODE_ADDRESS_KEY, CUSTOM_BASE_NODE_PUBLIC_KEY_KEY}, +}; use tari_common::GlobalConfig; +use tari_common_types::{emoji::EmojiId, types::PublicKey}; use tari_comms::{ connectivity::{ConnectivityEvent, ConnectivityRequester}, multiaddr::Multiaddr, @@ -47,19 +51,16 @@ use tari_core::{ transactions::{ tari_amount::{uT, MicroTari, Tari}, transaction::UnblindedOutput, - types::PublicKey, }, }; -use tari_crypto::ristretto::pedersen::PedersenCommitmentFactory; use tari_wallet::{ output_manager_service::{handle::OutputManagerHandle, TxId}, transaction_service::handle::{TransactionEvent, TransactionServiceHandle}, - util::emoji::EmojiId, WalletSqlite, }; use tokio::{ - sync::mpsc, - time::{delay_for, timeout}, + sync::{broadcast, mpsc}, + time::{sleep, timeout}, }; pub const LOG_TARGET: &str = "wallet::automation::commands"; @@ -175,21 +176,22 @@ pub async fn coin_split( Ok(tx_id) } -async fn wait_for_comms(connectivity_requester: &ConnectivityRequester) -> Result { - let mut connectivity = connectivity_requester.get_event_subscription().fuse(); +async fn wait_for_comms(connectivity_requester: &ConnectivityRequester) -> Result<(), CommandError> { + let mut connectivity = connectivity_requester.get_event_subscription(); print!("Waiting for connectivity... 
"); - let mut timeout = delay_for(Duration::from_secs(30)).fuse(); + let timeout = sleep(Duration::from_secs(30)); + tokio::pin!(timeout); + let mut timeout = timeout.fuse(); loop { - futures::select! { - result = connectivity.select_next_some() => { - if let Ok(msg) = result { - if let ConnectivityEvent::PeerConnected(_) = (*msg).clone() { - println!("✅"); - return Ok(true); - } + tokio::select! { + // Wait for the first base node connection + Ok(ConnectivityEvent::PeerConnected(conn)) = connectivity.recv() => { + if conn.peer_features().is_node() { + println!("✅"); + return Ok(()); } }, - () = timeout => { + () = &mut timeout => { println!("❌"); return Err(CommandError::Comms("Timed out".to_string())); } @@ -311,7 +313,7 @@ pub async fn make_it_rain( target: LOG_TARGET, "make-it-rain delaying for {:?} ms - scheduled to start at {}", delay_ms, start_time ); - delay_for(Duration::from_millis(delay_ms)).await; + sleep(Duration::from_millis(delay_ms)).await; let num_txs = (txps * duration as f64) as usize; let started_at = Utc::now(); @@ -352,10 +354,10 @@ pub async fn make_it_rain( let target_ms = (i as f64 / (txps / 1000.0)) as i64; if target_ms - actual_ms > 0 { // Maximum delay between Txs set to 120 s - delay_for(Duration::from_millis((target_ms - actual_ms).min(120_000i64) as u64)).await; + sleep(Duration::from_millis((target_ms - actual_ms).min(120_000i64) as u64)).await; } let delayed_for = Instant::now(); - let mut sender_clone = sender.clone(); + let sender_clone = sender.clone(); tokio::task::spawn(async move { let spawn_start = Instant::now(); // Send transaction @@ -432,7 +434,7 @@ pub async fn monitor_transactions( tx_ids: Vec, wait_stage: TransactionStage, ) -> Vec { - let mut event_stream = transaction_service.get_event_stream_fused(); + let mut event_stream = transaction_service.get_event_stream(); let mut results = Vec::new(); debug!(target: LOG_TARGET, "monitor transactions wait_stage: {:?}", wait_stage); println!( @@ -442,104 +444,102 @@ pub async 
fn monitor_transactions( ); loop { - match event_stream.next().await { - Some(event_result) => match event_result { - Ok(event) => match &*event { - TransactionEvent::TransactionDirectSendResult(id, success) if tx_ids.contains(id) => { - debug!( - target: LOG_TARGET, - "tx direct send event for tx_id: {}, success: {}", *id, success - ); - if wait_stage == TransactionStage::DirectSendOrSaf { - results.push(SentTransaction { - id: *id, - stage: TransactionStage::DirectSendOrSaf, - }); - if results.len() == tx_ids.len() { - break; - } - } - }, - TransactionEvent::TransactionStoreForwardSendResult(id, success) if tx_ids.contains(id) => { - debug!( - target: LOG_TARGET, - "tx store and forward event for tx_id: {}, success: {}", *id, success - ); - if wait_stage == TransactionStage::DirectSendOrSaf { - results.push(SentTransaction { - id: *id, - stage: TransactionStage::DirectSendOrSaf, - }); - if results.len() == tx_ids.len() { - break; - } + match event_stream.recv().await { + Ok(event) => match &*event { + TransactionEvent::TransactionDirectSendResult(id, success) if tx_ids.contains(id) => { + debug!( + target: LOG_TARGET, + "tx direct send event for tx_id: {}, success: {}", *id, success + ); + if wait_stage == TransactionStage::DirectSendOrSaf { + results.push(SentTransaction { + id: *id, + stage: TransactionStage::DirectSendOrSaf, + }); + if results.len() == tx_ids.len() { + break; } - }, - TransactionEvent::ReceivedTransactionReply(id) if tx_ids.contains(id) => { - debug!(target: LOG_TARGET, "tx reply event for tx_id: {}", *id); - if wait_stage == TransactionStage::Negotiated { - results.push(SentTransaction { - id: *id, - stage: TransactionStage::Negotiated, - }); - if results.len() == tx_ids.len() { - break; - } + } + }, + TransactionEvent::TransactionStoreForwardSendResult(id, success) if tx_ids.contains(id) => { + debug!( + target: LOG_TARGET, + "tx store and forward event for tx_id: {}, success: {}", *id, success + ); + if wait_stage == 
TransactionStage::DirectSendOrSaf { + results.push(SentTransaction { + id: *id, + stage: TransactionStage::DirectSendOrSaf, + }); + if results.len() == tx_ids.len() { + break; } - }, - TransactionEvent::TransactionBroadcast(id) if tx_ids.contains(id) => { - debug!(target: LOG_TARGET, "tx mempool broadcast event for tx_id: {}", *id); - if wait_stage == TransactionStage::Broadcast { - results.push(SentTransaction { - id: *id, - stage: TransactionStage::Broadcast, - }); - if results.len() == tx_ids.len() { - break; - } + } + }, + TransactionEvent::ReceivedTransactionReply(id) if tx_ids.contains(id) => { + debug!(target: LOG_TARGET, "tx reply event for tx_id: {}", *id); + if wait_stage == TransactionStage::Negotiated { + results.push(SentTransaction { + id: *id, + stage: TransactionStage::Negotiated, + }); + if results.len() == tx_ids.len() { + break; } - }, - TransactionEvent::TransactionMinedUnconfirmed(id, confirmations) if tx_ids.contains(id) => { - debug!( - target: LOG_TARGET, - "tx mined unconfirmed event for tx_id: {}, confirmations: {}", *id, confirmations - ); - if wait_stage == TransactionStage::MinedUnconfirmed { - results.push(SentTransaction { - id: *id, - stage: TransactionStage::MinedUnconfirmed, - }); - if results.len() == tx_ids.len() { - break; - } + } + }, + TransactionEvent::TransactionBroadcast(id) if tx_ids.contains(id) => { + debug!(target: LOG_TARGET, "tx mempool broadcast event for tx_id: {}", *id); + if wait_stage == TransactionStage::Broadcast { + results.push(SentTransaction { + id: *id, + stage: TransactionStage::Broadcast, + }); + if results.len() == tx_ids.len() { + break; } - }, - TransactionEvent::TransactionMined(id) if tx_ids.contains(id) => { - debug!(target: LOG_TARGET, "tx mined confirmed event for tx_id: {}", *id); - if wait_stage == TransactionStage::Mined { - results.push(SentTransaction { - id: *id, - stage: TransactionStage::Mined, - }); - if results.len() == tx_ids.len() { - break; - } + } + }, + 
TransactionEvent::TransactionMinedUnconfirmed(id, confirmations) if tx_ids.contains(id) => { + debug!( + target: LOG_TARGET, + "tx mined unconfirmed event for tx_id: {}, confirmations: {}", *id, confirmations + ); + if wait_stage == TransactionStage::MinedUnconfirmed { + results.push(SentTransaction { + id: *id, + stage: TransactionStage::MinedUnconfirmed, + }); + if results.len() == tx_ids.len() { + break; } - }, - _ => {}, + } }, - Err(e) => { - eprintln!("RecvError in monitor_transactions: {:?}", e); - break; + TransactionEvent::TransactionMined(id) if tx_ids.contains(id) => { + debug!(target: LOG_TARGET, "tx mined confirmed event for tx_id: {}", *id); + if wait_stage == TransactionStage::Mined { + results.push(SentTransaction { + id: *id, + stage: TransactionStage::Mined, + }); + if results.len() == tx_ids.len() { + break; + } + } }, + _ => {}, }, - None => { - warn!( + // All event senders have gone (i.e. we take it that the node is shutting down) + Err(broadcast::error::RecvError::Closed) => { + debug!( target: LOG_TARGET, - "`None` result in event in monitor_transactions loop" + "All Transaction event senders have gone. Exiting `monitor_transactions` loop." ); break; }, + Err(err) => { + warn!(target: LOG_TARGET, "monitor_transactions: {}", err); + }, } } @@ -578,7 +578,8 @@ pub async fn command_runner( }, DiscoverPeer => { if !online { - online = wait_for_comms(&connectivity_requester).await?; + wait_for_comms(&connectivity_requester).await?; + online = true; } discover_peer(dht_service.clone(), parsed.args).await? 
}, diff --git a/applications/tari_console_wallet/src/grpc/wallet_grpc_server.rs b/applications/tari_console_wallet/src/grpc/wallet_grpc_server.rs index b7e53ae6a2..b0c7779431 100644 --- a/applications/tari_console_wallet/src/grpc/wallet_grpc_server.rs +++ b/applications/tari_console_wallet/src/grpc/wallet_grpc_server.rs @@ -1,4 +1,4 @@ -use futures::future; +use futures::{channel::mpsc, future, SinkExt}; use log::*; use std::convert::TryFrom; use tari_app_grpc::{ @@ -31,17 +31,18 @@ use tari_app_grpc::{ TransferResult, }, }; +use tari_common_types::types::Signature; use tari_comms::{types::CommsPublicKey, CommsNode}; use tari_core::{ tari_utilities::{hex::Hex, ByteArray}, - transactions::{tari_amount::MicroTari, transaction::UnblindedOutput, types::Signature}, + transactions::{tari_amount::MicroTari, transaction::UnblindedOutput}, }; use tari_wallet::{ output_manager_service::handle::OutputManagerHandle, transaction_service::{handle::TransactionServiceHandle, storage::models}, WalletSqlite, }; -use tokio::{sync::mpsc, task}; +use tokio::task; use tonic::{Request, Response, Status}; const LOG_TARGET: &str = "wallet::ui::grpc"; diff --git a/applications/tari_console_wallet/src/init/mod.rs b/applications/tari_console_wallet/src/init/mod.rs index b5c0c9d805..b877c6b729 100644 --- a/applications/tari_console_wallet/src/init/mod.rs +++ b/applications/tari_console_wallet/src/init/mod.rs @@ -20,23 +20,22 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-use crate::{ - utils::db::get_custom_base_node_peer_from_db, - wallet_modes::{PeerConfig, WalletMode}, -}; +use std::{fs, path::PathBuf, str::FromStr, sync::Arc}; + use log::*; use rpassword::prompt_password_stdout; use rustyline::Editor; -use std::{fs, path::PathBuf, str::FromStr, sync::Arc}; + use tari_app_utilities::utilities::{create_transport_type, ExitCodes}; use tari_common::{ConfigBootstrap, GlobalConfig}; +use tari_common_types::types::PrivateKey; use tari_comms::{ peer_manager::{Peer, PeerFeatures}, types::CommsSecretKey, NodeIdentity, }; use tari_comms_dht::{DbConnectionUrl, DhtConfig}; -use tari_core::transactions::types::{CryptoFactories, PrivateKey}; +use tari_core::transactions::CryptoFactories; use tari_p2p::{ initialization::CommsConfig, peer_seeds::SeedPeer, @@ -59,6 +58,11 @@ use tari_wallet::{ WalletSqlite, }; +use crate::{ + utils::db::get_custom_base_node_peer_from_db, + wallet_modes::{PeerConfig, WalletMode}, +}; + pub const LOG_TARGET: &str = "wallet::console_wallet::init"; /// The minimum buffer size for a tari application pubsub_connector channel const BASE_NODE_BUFFER_MIN_SIZE: usize = 30; @@ -128,9 +132,15 @@ pub async fn change_password( return Err(ExitCodes::InputError("Passwords don't match!".to_string())); } - wallet.remove_encryption().await?; + wallet + .remove_encryption() + .await + .map_err(|e| ExitCodes::WalletError(e.to_string()))?; - wallet.apply_encryption(passphrase).await?; + wallet + .apply_encryption(passphrase) + .await + .map_err(|e| ExitCodes::WalletError(e.to_string()))?; println!("Wallet password changed successfully."); diff --git a/applications/tari_console_wallet/src/main.rs b/applications/tari_console_wallet/src/main.rs index 4042eb27eb..23b917fb16 100644 --- a/applications/tari_console_wallet/src/main.rs +++ b/applications/tari_console_wallet/src/main.rs @@ -24,7 +24,7 @@ use recovery::prompt_private_key_from_seed_words; use std::{env, process}; use tari_app_utilities::{consts, 
initialization::init_configuration, utilities::ExitCodes}; use tari_common::{configuration::bootstrap::ApplicationType, ConfigBootstrap}; -use tari_core::transactions::types::PrivateKey; +use tari_common_types::types::PrivateKey; use tari_shutdown::Shutdown; use tracing_subscriber::{layer::SubscriberExt, Registry}; use wallet_modes::{command_mode, grpc_mode, recovery_mode, script_mode, tui_mode, WalletMode}; @@ -58,8 +58,7 @@ fn main() { } fn main_inner() -> Result<(), ExitCodes> { - let mut runtime = tokio::runtime::Builder::new() - .threaded_scheduler() + let runtime = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() .expect("Failed to build a runtime!"); @@ -156,11 +155,8 @@ fn main_inner() -> Result<(), ExitCodes> { }; print!("\nShutting down wallet... "); - if shutdown.trigger().is_ok() { - runtime.block_on(wallet.wait_until_shutdown()); - } else { - error!(target: LOG_TARGET, "No listeners for the shutdown signal!"); - } + shutdown.trigger(); + runtime.block_on(wallet.wait_until_shutdown()); println!("Done."); result diff --git a/applications/tari_console_wallet/src/recovery.rs b/applications/tari_console_wallet/src/recovery.rs index 887995f0a5..297b84e983 100644 --- a/applications/tari_console_wallet/src/recovery.rs +++ b/applications/tari_console_wallet/src/recovery.rs @@ -21,11 +21,11 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
use chrono::offset::Local; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt; use log::*; use rustyline::Editor; use tari_app_utilities::utilities::ExitCodes; -use tari_core::transactions::types::PrivateKey; +use tari_common_types::types::PrivateKey; use tari_key_manager::mnemonic::to_secretkey; use tari_shutdown::Shutdown; use tari_wallet::{ @@ -35,6 +35,7 @@ use tari_wallet::{ }; use crate::wallet_modes::PeerConfig; +use tokio::sync::broadcast; pub const LOG_TARGET: &str = "wallet::recovery"; @@ -97,13 +98,13 @@ pub async fn wallet_recovery(wallet: &WalletSqlite, base_node_config: &PeerConfi .with_retry_limit(10) .build_with_wallet(wallet, shutdown_signal); - let mut event_stream = recovery_task.get_event_receiver().fuse(); + let mut event_stream = recovery_task.get_event_receiver(); let recovery_join_handle = tokio::spawn(recovery_task.run()).fuse(); // Read recovery task events. The event stream will end once recovery has completed. - while let Some(event) = event_stream.next().await { - match event { + loop { + match event_stream.recv().await { Ok(UtxoScannerEvent::ConnectingToBaseNode(peer)) => { print!("Connecting to base node {}... 
", peer); }, @@ -170,11 +171,13 @@ pub async fn wallet_recovery(wallet: &WalletSqlite, base_node_config: &PeerConfi info!(target: LOG_TARGET, "{}", stats); println!("{}", stats); }, - Err(e) => { - // Can occur if we read events too slowly (lagging/slow subscriber) + Err(e @ broadcast::error::RecvError::Lagged(_)) => { debug!(target: LOG_TARGET, "Error receiving Wallet recovery events: {}", e); continue; }, + Err(broadcast::error::RecvError::Closed) => { + break; + }, Ok(UtxoScannerEvent::ScanningFailed) => { error!(target: LOG_TARGET, "Wallet Recovery process failed and is exiting"); }, diff --git a/applications/tari_console_wallet/src/ui/components/base_node.rs b/applications/tari_console_wallet/src/ui/components/base_node.rs index d9a271e291..ade233b90c 100644 --- a/applications/tari_console_wallet/src/ui/components/base_node.rs +++ b/applications/tari_console_wallet/src/ui/components/base_node.rs @@ -42,9 +42,9 @@ impl BaseNode { impl Component for BaseNode { fn draw(&mut self, f: &mut Frame, area: Rect, app_state: &AppState) where B: Backend { - let base_node_state = app_state.get_base_node_state(); + let current_online_status = app_state.get_wallet_connectivity().get_connectivity_status(); - let chain_info = match base_node_state.online { + let chain_info = match current_online_status { OnlineStatus::Connecting => Spans::from(vec![ Span::styled("Chain Tip:", Style::default().fg(Color::Magenta)), Span::raw(" "), @@ -56,7 +56,8 @@ impl Component for BaseNode { Span::styled("Offline", Style::default().fg(Color::Red)), ]), OnlineStatus::Online => { - if let Some(metadata) = base_node_state.clone().chain_metadata { + let base_node_state = app_state.get_base_node_state(); + if let Some(ref metadata) = base_node_state.chain_metadata { let tip = metadata.height_of_longest_chain(); let synced = base_node_state.is_synced.unwrap_or_default(); @@ -92,7 +93,7 @@ impl Component for BaseNode { Spans::from(vec![ Span::styled("Chain Tip:", 
Style::default().fg(Color::Magenta)), Span::raw(" "), - Span::styled("Error", Style::default().fg(Color::Red)), + Span::styled("Waiting for data...", Style::default().fg(Color::DarkGray)), ]) } }, diff --git a/applications/tari_console_wallet/src/ui/state/app_state.rs b/applications/tari_console_wallet/src/ui/state/app_state.rs index 6d9293ef47..b4a45df7b4 100644 --- a/applications/tari_console_wallet/src/ui/state/app_state.rs +++ b/applications/tari_console_wallet/src/ui/state/app_state.rs @@ -20,29 +20,23 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{ - notifier::Notifier, - ui::{ - state::{ - tasks::{send_one_sided_transaction_task, send_transaction_task}, - wallet_event_monitor::WalletEventMonitor, - }, - UiContact, - UiError, - }, - utils::db::{CUSTOM_BASE_NODE_ADDRESS_KEY, CUSTOM_BASE_NODE_PUBLIC_KEY_KEY}, - wallet_modes::PeerConfig, -}; -use bitflags::bitflags; -use futures::{stream::Fuse, StreamExt}; -use log::*; -use qrcode::{render::unicode, QrCode}; use std::{ collections::HashMap, sync::Arc, time::{Duration, Instant}, }; + +use bitflags::bitflags; +use log::*; +use qrcode::{render::unicode, QrCode}; +use tari_crypto::{ristretto::RistrettoPublicKey, tari_utilities::hex::Hex}; +use tokio::{ + sync::{watch, RwLock}, + task, +}; + use tari_common::{configuration::Network, GlobalConfig}; +use tari_common_types::{emoji::EmojiId, types::PublicKey}; use tari_comms::{ connectivity::ConnectivityEventRx, multiaddr::Multiaddr, @@ -50,11 +44,7 @@ use tari_comms::{ types::CommsPublicKey, NodeIdentity, }; -use tari_core::transactions::{ - tari_amount::{uT, MicroTari}, - types::PublicKey, -}; -use tari_crypto::{ristretto::RistrettoPublicKey, tari_utilities::hex::Hex}; +use tari_core::transactions::tari_amount::{uT, MicroTari}; use tari_shutdown::ShutdownSignal; use tari_wallet::{ 
base_node_service::{handle::BaseNodeEventReceiver, service::BaseNodeState}, @@ -66,12 +56,21 @@ use tari_wallet::{ storage::models::{CompletedTransaction, TransactionStatus}, }, types::ValidationRetryStrategy, - util::emoji::EmojiId, WalletSqlite, }; -use tokio::{ - sync::{watch, RwLock}, - task, + +use crate::{ + notifier::Notifier, + ui::{ + state::{ + tasks::{send_one_sided_transaction_task, send_transaction_task}, + wallet_event_monitor::WalletEventMonitor, + }, + UiContact, + UiError, + }, + utils::db::{CUSTOM_BASE_NODE_ADDRESS_KEY, CUSTOM_BASE_NODE_PUBLIC_KEY_KEY}, + wallet_modes::PeerConfig, }; const LOG_TARGET: &str = "wallet::console_wallet::app_state"; @@ -84,6 +83,7 @@ pub struct AppState { completed_tx_filter: TransactionFilter, node_config: GlobalConfig, config: AppStateConfig, + wallet_connectivity: WalletConnectivityHandle, } impl AppState { @@ -95,6 +95,7 @@ impl AppState { base_node_config: PeerConfig, node_config: GlobalConfig, ) -> Self { + let wallet_connectivity = wallet.wallet_connectivity.clone(); let inner = AppStateInner::new(node_identity, network, wallet, base_node_selected, base_node_config); let cached_data = inner.data.clone(); @@ -105,6 +106,7 @@ impl AppState { completed_tx_filter: TransactionFilter::ABANDONED_COINBASES, node_config, config: AppStateConfig::default(), + wallet_connectivity, } } @@ -352,6 +354,10 @@ impl AppState { &self.cached_data.base_node_state } + pub fn get_wallet_connectivity(&self) -> WalletConnectivityHandle { + self.wallet_connectivity.clone() + } + pub fn get_selected_base_node(&self) -> &Peer { &self.cached_data.base_node_selected } @@ -641,24 +647,24 @@ impl AppStateInner { self.wallet.comms.shutdown_signal() } - pub fn get_transaction_service_event_stream(&self) -> Fuse { - self.wallet.transaction_service.get_event_stream_fused() + pub fn get_transaction_service_event_stream(&self) -> TransactionEventReceiver { + self.wallet.transaction_service.get_event_stream() } - pub fn 
get_output_manager_service_event_stream(&self) -> Fuse { - self.wallet.output_manager_service.get_event_stream_fused() + pub fn get_output_manager_service_event_stream(&self) -> OutputManagerEventReceiver { + self.wallet.output_manager_service.get_event_stream() } - pub fn get_connectivity_event_stream(&self) -> Fuse { - self.wallet.comms.connectivity().get_event_subscription().fuse() + pub fn get_connectivity_event_stream(&self) -> ConnectivityEventRx { + self.wallet.comms.connectivity().get_event_subscription() } pub fn get_wallet_connectivity(&self) -> WalletConnectivityHandle { self.wallet.wallet_connectivity.clone() } - pub fn get_base_node_event_stream(&self) -> Fuse { - self.wallet.base_node_service.clone().get_event_stream_fused() + pub fn get_base_node_event_stream(&self) -> BaseNodeEventReceiver { + self.wallet.base_node_service.get_event_stream() } pub async fn set_base_node_peer(&mut self, peer: Peer) -> Result<(), UiError> { diff --git a/applications/tari_console_wallet/src/ui/state/tasks.rs b/applications/tari_console_wallet/src/ui/state/tasks.rs index caf8073f56..85243660e6 100644 --- a/applications/tari_console_wallet/src/ui/state/tasks.rs +++ b/applications/tari_console_wallet/src/ui/state/tasks.rs @@ -21,11 +21,10 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
use crate::ui::{state::UiTransactionSendStatus, UiError}; -use futures::StreamExt; use tari_comms::types::CommsPublicKey; use tari_core::transactions::tari_amount::MicroTari; use tari_wallet::transaction_service::handle::{TransactionEvent, TransactionServiceHandle}; -use tokio::sync::watch; +use tokio::sync::{broadcast, watch}; const LOG_TARGET: &str = "wallet::console_wallet::tasks "; @@ -37,8 +36,8 @@ pub async fn send_transaction_task( mut transaction_service_handle: TransactionServiceHandle, result_tx: watch::Sender, ) { - let _ = result_tx.broadcast(UiTransactionSendStatus::Initiated); - let mut event_stream = transaction_service_handle.get_event_stream_fused(); + let _ = result_tx.send(UiTransactionSendStatus::Initiated); + let mut event_stream = transaction_service_handle.get_event_stream(); let mut send_direct_received_result = (false, false); let mut send_saf_received_result = (false, false); match transaction_service_handle @@ -46,15 +45,15 @@ pub async fn send_transaction_task( .await { Err(e) => { - let _ = result_tx.broadcast(UiTransactionSendStatus::Error(UiError::from(e).to_string())); + let _ = result_tx.send(UiTransactionSendStatus::Error(UiError::from(e).to_string())); }, Ok(our_tx_id) => { - while let Some(event_result) = event_stream.next().await { - match event_result { + loop { + match event_stream.recv().await { Ok(event) => match &*event { TransactionEvent::TransactionDiscoveryInProgress(tx_id) => { if our_tx_id == *tx_id { - let _ = result_tx.broadcast(UiTransactionSendStatus::DiscoveryInProgress); + let _ = result_tx.send(UiTransactionSendStatus::DiscoveryInProgress); } }, TransactionEvent::TransactionDirectSendResult(tx_id, result) => { @@ -75,25 +74,28 @@ pub async fn send_transaction_task( }, TransactionEvent::TransactionCompletedImmediately(tx_id) => { if our_tx_id == *tx_id { - let _ = result_tx.broadcast(UiTransactionSendStatus::TransactionComplete); + let _ = result_tx.send(UiTransactionSendStatus::TransactionComplete); return; } }, 
_ => (), }, - Err(e) => { + Err(e @ broadcast::error::RecvError::Lagged(_)) => { log::warn!(target: LOG_TARGET, "Error reading from event broadcast channel {:?}", e); + continue; + }, + Err(broadcast::error::RecvError::Closed) => { break; }, } } if send_direct_received_result.1 { - let _ = result_tx.broadcast(UiTransactionSendStatus::SentDirect); + let _ = result_tx.send(UiTransactionSendStatus::SentDirect); } else if send_saf_received_result.1 { - let _ = result_tx.broadcast(UiTransactionSendStatus::SentViaSaf); + let _ = result_tx.send(UiTransactionSendStatus::SentViaSaf); } else { - let _ = result_tx.broadcast(UiTransactionSendStatus::Error( + let _ = result_tx.send(UiTransactionSendStatus::Error( "Transaction could not be sent".to_string(), )); } @@ -109,34 +111,37 @@ pub async fn send_one_sided_transaction_task( mut transaction_service_handle: TransactionServiceHandle, result_tx: watch::Sender, ) { - let _ = result_tx.broadcast(UiTransactionSendStatus::Initiated); - let mut event_stream = transaction_service_handle.get_event_stream_fused(); + let _ = result_tx.send(UiTransactionSendStatus::Initiated); + let mut event_stream = transaction_service_handle.get_event_stream(); match transaction_service_handle .send_one_sided_transaction(public_key, amount, fee_per_gram, message) .await { Err(e) => { - let _ = result_tx.broadcast(UiTransactionSendStatus::Error(UiError::from(e).to_string())); + let _ = result_tx.send(UiTransactionSendStatus::Error(UiError::from(e).to_string())); }, Ok(our_tx_id) => { - while let Some(event_result) = event_stream.next().await { - match event_result { + loop { + match event_stream.recv().await { Ok(event) => { if let TransactionEvent::TransactionCompletedImmediately(tx_id) = &*event { if our_tx_id == *tx_id { - let _ = result_tx.broadcast(UiTransactionSendStatus::TransactionComplete); + let _ = result_tx.send(UiTransactionSendStatus::TransactionComplete); return; } } }, - Err(e) => { + Err(e @ broadcast::error::RecvError::Lagged(_)) => 
{ log::warn!(target: LOG_TARGET, "Error reading from event broadcast channel {:?}", e); + continue; + }, + Err(broadcast::error::RecvError::Closed) => { break; }, } } - let _ = result_tx.broadcast(UiTransactionSendStatus::Error( + let _ = result_tx.send(UiTransactionSendStatus::Error( "One-sided transaction could not be sent".to_string(), )); }, diff --git a/applications/tari_console_wallet/src/ui/state/wallet_event_monitor.rs b/applications/tari_console_wallet/src/ui/state/wallet_event_monitor.rs index 2e20999667..e7df30b653 100644 --- a/applications/tari_console_wallet/src/ui/state/wallet_event_monitor.rs +++ b/applications/tari_console_wallet/src/ui/state/wallet_event_monitor.rs @@ -21,7 +21,6 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{notifier::Notifier, ui::state::AppStateInner}; -use futures::stream::StreamExt; use log::*; use std::sync::Arc; use tari_comms::{connectivity::ConnectivityEvent, peer_manager::Peer}; @@ -30,7 +29,7 @@ use tari_wallet::{ output_manager_service::{handle::OutputManagerEvent, TxId}, transaction_service::handle::TransactionEvent, }; -use tokio::sync::RwLock; +use tokio::sync::{broadcast, RwLock}; const LOG_TARGET: &str = "wallet::console_wallet::wallet_event_monitor"; @@ -55,14 +54,14 @@ impl WalletEventMonitor { let mut connectivity_events = self.app_state_inner.read().await.get_connectivity_event_stream(); let wallet_connectivity = self.app_state_inner.read().await.get_wallet_connectivity(); - let mut connectivity_status = wallet_connectivity.get_connectivity_status_watch().fuse(); + let mut connectivity_status = wallet_connectivity.get_connectivity_status_watch(); let mut base_node_events = self.app_state_inner.read().await.get_base_node_event_stream(); info!(target: LOG_TARGET, "Wallet Event Monitor starting"); loop { - futures::select! { - result = transaction_service_events.select_next_some() => { + tokio::select! 
{ + result = transaction_service_events.recv() => { match result { Ok(msg) => { trace!(target: LOG_TARGET, "Wallet Event Monitor received wallet transaction service event {:?}", msg); @@ -104,18 +103,21 @@ impl WalletEventMonitor { _ => (), } }, - Err(_) => debug!(target: LOG_TARGET, "Lagging read on Transaction Service event broadcast channel"), + Err(broadcast::error::RecvError::Lagged(n)) => { + warn!(target: LOG_TARGET, "Missed {} from Transaction events", n); + } + Err(broadcast::error::RecvError::Closed) => {} } }, - status = connectivity_status.select_next_some() => { - trace!(target: LOG_TARGET, "Wallet Event Monitor received wallet connectivity status {:?}", status); + Ok(_) = connectivity_status.changed() => { + trace!(target: LOG_TARGET, "Wallet Event Monitor received wallet connectivity status changed"); self.trigger_peer_state_refresh().await; }, - result = connectivity_events.select_next_some() => { + result = connectivity_events.recv() => { match result { Ok(msg) => { trace!(target: LOG_TARGET, "Wallet Event Monitor received wallet connectivity event {:?}", msg); - match &*msg { + match msg { ConnectivityEvent::PeerDisconnected(_) | ConnectivityEvent::ManagedPeerDisconnected(_) | ConnectivityEvent::PeerConnected(_) => { @@ -125,10 +127,13 @@ impl WalletEventMonitor { _ => (), } }, - Err(_) => debug!(target: LOG_TARGET, "Lagging read on Connectivity event broadcast channel"), + Err(broadcast::error::RecvError::Lagged(n)) => { + warn!(target: LOG_TARGET, "Missed {} from Connectivity events", n); + } + Err(broadcast::error::RecvError::Closed) => {} } }, - result = base_node_events.select_next_some() => { + result = base_node_events.recv() => { match result { Ok(msg) => { trace!(target: LOG_TARGET, "Wallet Event Monitor received base node event {:?}", msg); @@ -141,10 +146,13 @@ impl WalletEventMonitor { } } }, - Err(_) => debug!(target: LOG_TARGET, "Lagging read on base node event broadcast channel"), + Err(broadcast::error::RecvError::Lagged(n)) => { + 
warn!(target: LOG_TARGET, "Missed {} from Base node Service events", n); + } + Err(broadcast::error::RecvError::Closed) => {} } }, - result = output_manager_service_events.select_next_some() => { + result = output_manager_service_events.recv() => { match result { Ok(msg) => { trace!(target: LOG_TARGET, "Output Manager Service Callback Handler event {:?}", msg); @@ -152,14 +160,13 @@ impl WalletEventMonitor { self.trigger_balance_refresh().await; } }, - Err(_e) => error!(target: LOG_TARGET, "Error reading from Output Manager Service event broadcast channel"), + Err(broadcast::error::RecvError::Lagged(n)) => { + warn!(target: LOG_TARGET, "Missed {} from Output Manager Service events", n); + } + Err(broadcast::error::RecvError::Closed) => {} } - }, - complete => { - info!(target: LOG_TARGET, "Wallet Event Monitor is exiting because all tasks have completed"); - break; }, - _ = shutdown_signal => { + _ = shutdown_signal.wait() => { info!(target: LOG_TARGET, "Wallet Event Monitor shutting down because the shutdown signal was received"); break; }, diff --git a/applications/tari_console_wallet/src/ui/ui_contact.rs b/applications/tari_console_wallet/src/ui/ui_contact.rs index 2d8be4a182..49f83e8284 100644 --- a/applications/tari_console_wallet/src/ui/ui_contact.rs +++ b/applications/tari_console_wallet/src/ui/ui_contact.rs @@ -1,4 +1,5 @@ -use tari_wallet::{contacts_service::storage::database::Contact, util::emoji::EmojiId}; +use tari_common_types::emoji::EmojiId; +use tari_wallet::contacts_service::storage::database::Contact; #[derive(Debug, Clone)] pub struct UiContact { diff --git a/applications/tari_console_wallet/src/utils/db.rs b/applications/tari_console_wallet/src/utils/db.rs index 9dbf43cfd2..aee50c4e40 100644 --- a/applications/tari_console_wallet/src/utils/db.rs +++ b/applications/tari_console_wallet/src/utils/db.rs @@ -21,11 +21,12 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
use log::*; +use tari_common_types::types::PublicKey; use tari_comms::{ multiaddr::Multiaddr, peer_manager::{NodeId, Peer, PeerFeatures, PeerFlags}, }; -use tari_core::transactions::types::PublicKey; + use tari_crypto::tari_utilities::hex::Hex; use tari_wallet::WalletSqlite; diff --git a/applications/tari_console_wallet/src/wallet_modes.rs b/applications/tari_console_wallet/src/wallet_modes.rs index a205523a00..23b1ee6330 100644 --- a/applications/tari_console_wallet/src/wallet_modes.rs +++ b/applications/tari_console_wallet/src/wallet_modes.rs @@ -239,7 +239,10 @@ pub fn tui_mode(config: WalletModeConfig, mut wallet: WalletSqlite) -> Result<() info!(target: LOG_TARGET, "Starting app"); - handle.enter(|| ui::run(app))?; + { + let _enter = handle.enter(); + ui::run(app)?; + } info!( target: LOG_TARGET, diff --git a/applications/tari_merge_mining_proxy/Cargo.toml b/applications/tari_merge_mining_proxy/Cargo.toml index 626d0502b6..2cd31b8b3d 100644 --- a/applications/tari_merge_mining_proxy/Cargo.toml +++ b/applications/tari_merge_mining_proxy/Cargo.toml @@ -13,33 +13,32 @@ envlog = ["env_logger"] [dependencies] tari_app_grpc = { path = "../tari_app_grpc" } -tari_common = { path = "../../common" } -tari_core = { path = "../../base_layer/core", default-features = false, features = ["transactions"]} -tari_app_utilities = { path = "../tari_app_utilities"} +tari_common = { path = "../../common" } +tari_core = { path = "../../base_layer/core", default-features = false, features = ["transactions"] } +tari_app_utilities = { path = "../tari_app_utilities" } tari_crypto = "0.11.1" tari_utilities = "^0.3" anyhow = "1.0.40" bincode = "1.3.1" -bytes = "0.5.6" +bytes = "1.1" chrono = "0.4.19" config = { version = "0.9.3" } derive-error = "0.0.4" env_logger = { version = "0.7.1", optional = true } futures = "0.3.5" hex = "0.4.2" -hyper = "0.13.7" +hyper = "0.14.12" jsonrpc = "0.11.0" log = { version = "0.4.8", features = ["std"] } rand = "0.8" -reqwest = {version = "0.10.8", 
features=["json"]}
-serde = { version="1.0.106", features = ["derive"] }
+reqwest = { version = "0.11.4", features = ["json"] }
+serde = { version = "1.0.106", features = ["derive"] }
 serde_json = "1.0.57"
 structopt = { version = "0.3.13", default_features = false }
-thiserror = "1.0.15"
-tokio = "0.2.10"
-tokio-macros = "0.2.5"
-tonic = "0.2"
+thiserror = "1.0.26"
+tokio = { version = "1.10", features = ["macros"] }
+tonic = "0.5.2"
 tracing = "0.1"
 tracing-futures = "0.2"
 tracing-subscriber = "0.2"
diff --git a/applications/tari_merge_mining_proxy/src/main.rs b/applications/tari_merge_mining_proxy/src/main.rs
index 7e15777977..0df27490bb 100644
--- a/applications/tari_merge_mining_proxy/src/main.rs
+++ b/applications/tari_merge_mining_proxy/src/main.rs
@@ -46,7 +46,7 @@
 use tari_app_utilities::initialization::init_configuration;
 use tari_common::configuration::bootstrap::ApplicationType;
 use tokio::time::Duration;

-#[tokio_macros::main]
+#[tokio::main]
 async fn main() -> Result<(), anyhow::Error> {
     let (_, config, _) = init_configuration(ApplicationType::MergeMiningProxy)?;
diff --git a/applications/tari_mining_node/Cargo.toml b/applications/tari_mining_node/Cargo.toml
index a04938c048..f893c0bc06 100644
--- a/applications/tari_mining_node/Cargo.toml
+++ b/applications/tari_mining_node/Cargo.toml
@@ -17,15 +17,15 @@
 crossbeam = "0.8"
 futures = "0.3"
 log = { version = "0.4", features = ["std"] }
 num_cpus = "1.13"
-prost-types = "0.6"
+prost-types = "0.8"
 rand = "0.8"
 sha3 = "0.9"
 serde = { version = "1.0", default_features = false, features = ["derive"] }
-tonic = { version = "0.2", features = ["transport"] }
-tokio = { version = "0.2", default_features = false, features = ["rt-core"] }
+tonic = { version = "0.5.2", features = ["transport"] }
+tokio = { version = "1.10", default_features = false, features = ["rt-multi-thread"] }
 thiserror = "1.0"
 jsonrpc = "0.11.0"
-reqwest = { version = "0.11", features = ["blocking", "json"] }
+reqwest = { version = "0.11", features = [ "json"] }
 serde_json = "1.0.57"
 native-tls = "0.2"
 bufstream = "0.1"
@@ -35,5 +35,5 @@
 hex = "0.4.2"

 [dev-dependencies]
 tari_crypto = "0.11.1"
-prost-types = "0.6.1"
+prost-types = "0.8"
 chrono = "0.4"
diff --git a/applications/tari_mining_node/src/main.rs b/applications/tari_mining_node/src/main.rs
index b1f538c1e4..0a0bf31557 100644
--- a/applications/tari_mining_node/src/main.rs
+++ b/applications/tari_mining_node/src/main.rs
@@ -23,13 +23,6 @@
 use config::MinerConfig;
 use futures::stream::StreamExt;
 use log::*;
-use tari_app_grpc::tari_rpc::{base_node_client::BaseNodeClient, wallet_client::WalletClient};
-use tari_app_utilities::{initialization::init_configuration, utilities::ExitCodes};
-use tari_common::{configuration::bootstrap::ApplicationType, ConfigBootstrap, DefaultConfigLoader, GlobalConfig};
-use tari_core::blocks::BlockHeader;
-use tokio::{runtime::Runtime, time::delay_for};
-use tonic::transport::Channel;
-use utils::{coinbase_request, extract_outputs_and_kernels};

 mod config;
 mod difficulty;
@@ -53,10 +46,17 @@
 use std::{
     thread,
     time::Instant,
 };
+use tari_app_grpc::tari_rpc::{base_node_client::BaseNodeClient, wallet_client::WalletClient};
+use tari_app_utilities::{initialization::init_configuration, utilities::ExitCodes};
+use tari_common::{configuration::bootstrap::ApplicationType, ConfigBootstrap, DefaultConfigLoader, GlobalConfig};
+use tari_core::blocks::BlockHeader;
+use tokio::{runtime::Runtime, time::sleep};
+use tonic::transport::Channel;
+use utils::{coinbase_request, extract_outputs_and_kernels};

 /// Application entry point
 fn main() {
-    let mut rt = Runtime::new().expect("Failed to start tokio runtime");
+    let rt = Runtime::new().expect("Failed to start tokio runtime");
     match rt.block_on(main_inner()) {
         Ok(_) => std::process::exit(0),
         Err(exit_code) => {
@@ -144,7 +144,7 @@
             error!("Connection error: {:?}", err);
             loop {
                 debug!("Holding for {:?}", config.wait_timeout());
-                delay_for(config.wait_timeout()).await;
+                sleep(config.wait_timeout()).await;
                 match connect(&config, &global).await {
                     Ok((nc, wc)) => {
                         node_conn = nc;
@@ -168,7 +168,7 @@
                 Err(err) => {
                     error!("Error: {:?}", err);
                     debug!("Holding for {:?}", config.wait_timeout());
-                    delay_for(config.wait_timeout()).await;
+                    sleep(config.wait_timeout()).await;
                 },
                 Ok(submitted) => {
                     if submitted {
diff --git a/applications/tari_stratum_transcoder/Cargo.toml b/applications/tari_stratum_transcoder/Cargo.toml
index 29f95c82da..d29b91ead7 100644
--- a/applications/tari_stratum_transcoder/Cargo.toml
+++ b/applications/tari_stratum_transcoder/Cargo.toml
@@ -13,37 +13,37 @@
 envlog = ["env_logger"]

 [dependencies]
 tari_app_grpc = { path = "../tari_app_grpc" }
-tari_common = { path = "../../common" }
-tari_core = { path = "../../base_layer/core", default-features = false, features = ["transactions"]}
+tari_common = { path = "../../common" }
+tari_core = { path = "../../base_layer/core", default-features = false, features = ["transactions"] }
 tari_crypto = "0.11.1"
 tari_utilities = "^0.3"
+
 bincode = "1.3.1"
-bytes = "0.5.6"
+bytes = "0.5"
 chrono = "0.4.19"
 config = { version = "0.9.3" }
 derive-error = "0.0.4"
 env_logger = { version = "0.7.1", optional = true }
 futures = "0.3.5"
 hex = "0.4.2"
-hyper = "0.13.7"
+hyper = "0.14.12"
 jsonrpc = "0.11.0"
 log = { version = "0.4.8", features = ["std"] }
 rand = "0.7.2"
-reqwest = {version = "0.10.8", features=["json"]}
-serde = { version="1.0.106", features = ["derive"] }
+reqwest = { version = "0.11", features = ["json"] }
+serde = { version = "1.0.106", features = ["derive"] }
 serde_json = "1.0.57"
 structopt = { version = "0.3.13", default_features = false }
-thiserror = "1.0.15"
-tokio = "0.2.10"
-tokio-macros = "0.2.5"
-tonic = "0.2"
+thiserror = "1.0.26"
+tokio = { version = "^1.10", features = ["macros"] }
+tonic = "0.5.2"
 tracing = "0.1"
 tracing-futures = "0.2"
 tracing-subscriber = "0.2"
 url = "2.1.1"

 [build-dependencies]
-tonic-build = "0.2"
+tonic-build = "0.5.2"

 [dev-dependencies]
 futures-test = "0.3.5"
diff --git a/applications/tari_stratum_transcoder/src/main.rs b/applications/tari_stratum_transcoder/src/main.rs
index d55d551b5b..f742c92d6d 100644
--- a/applications/tari_stratum_transcoder/src/main.rs
+++ b/applications/tari_stratum_transcoder/src/main.rs
@@ -41,7 +41,7 @@
 use tari_app_grpc::tari_rpc as grpc;
 use tari_common::{configuration::bootstrap::ApplicationType, ConfigBootstrap, GlobalConfig};
 use tokio::time::Duration;

-#[tokio_macros::main]
+#[tokio::main]
 async fn main() -> Result<(), StratumTranscoderProxyError> {
     let config = initialize()?;
diff --git a/applications/test_faucet/Cargo.toml b/applications/test_faucet/Cargo.toml
index 3ef3c8a4c1..f686ece098 100644
--- a/applications/test_faucet/Cargo.toml
+++ b/applications/test_faucet/Cargo.toml
@@ -7,11 +7,13 @@ edition = "2018"

 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 [dependencies]
+tari_crypto = "0.11.1"
 tari_utilities = "^0.3"
+tari_common_types ={path="../../base_layer/common_types"}
+
+rand = "0.8"
 serde = { version = "1.0.97", features = ["derive"] }
 serde_json = "1.0"
-rand = "0.8"
-tari_crypto = "0.11.1"

 [dependencies.tari_core]
 version = "^0.9"
@@ -20,6 +22,6 @@
 default-features = false
 features = ["transactions", "avx2"]

 [dependencies.tokio]
-version = "^0.2.10"
+version = "^1.10"
 default-features = false
-features = ["fs", "blocking", "stream", "rt-threaded", "macros", "io-util", "sync"]
+features = ["fs", "rt-multi-thread", "macros", "io-util", "sync"]
diff --git a/applications/test_faucet/src/main.rs b/applications/test_faucet/src/main.rs
index 0aadee4d5e..65f75a4b5c 100644
--- a/applications/test_faucet/src/main.rs
+++ b/applications/test_faucet/src/main.rs
@@ -5,20 +5,22 @@
 #![deny(unused_must_use)]
 #![deny(unreachable_patterns)]
 #![deny(unknown_lints)]
-use serde::Serialize;
+
 use std::{fs::File, io::Write};
-use tari_core::{
-    tari_utilities::hex::Hex,
-    transactions::{
-        helpers,
-        tari_amount::{MicroTari, T},
-        transaction::{KernelFeatures, OutputFeatures, TransactionKernel, TransactionOutput},
-        types::{Commitment, CryptoFactories, PrivateKey},
-    },
-};
+
+use serde::Serialize;
 use tari_crypto::script;
 use tokio::{sync::mpsc, task};
+use tari_common_types::types::{Commitment, PrivateKey};
+use tari_core::transactions::{
+    helpers,
+    tari_amount::{MicroTari, T},
+    transaction::{KernelFeatures, OutputFeatures, TransactionKernel, TransactionOutput},
+    CryptoFactories,
+};
+use tari_crypto::tari_utilities::hex::Hex;
+
 const NUM_KEYS: usize = 4000;

 #[derive(Serialize)]
@@ -32,7 +34,7 @@ struct Key {
 /// UTXO generation is pretty slow (esp range proofs), so we'll use async threads to speed things up.
 /// We'll use blocking thread tasks to do the CPU intensive utxo generation, and then push the results
 /// through a channel where a file-writer is waiting to persist the results to disk.
-#[tokio::main(core_threads = 2, max_threads = 10)]
+#[tokio::main(worker_threads = 2)]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
     let num_keys: usize = std::env::args()
         .skip(1)
@@ -52,7 +54,7 @@
     // Use Rust's awesome Iterator trait to produce a sequence of values and output features.
     for (value, feature) in values.take(num_keys).zip(features.take(num_keys)) {
         let fc = factories.clone();
-        let mut txc = tx.clone();
+        let txc = tx.clone();
         // Notice the `spawn(.. spawn_blocking)` nested call here. If we don't do this, we're basically queuing up
         // blocking tasks, `await`ing them to finish, and then queueing up the next one. In effect we're running things
         // synchronously.
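The comment at the end of the `test_faucet/src/main.rs` hunk above describes a fan-out pattern: CPU-heavy UTXO generation runs on blocking tasks, each holding its own clone of a channel sender, while a single consumer drains results as they arrive. A rough stdlib-only analogue of that shape — plain threads and `std::sync::mpsc` standing in for tokio's `spawn_blocking` and async channel, with `expensive_work` as a made-up placeholder for the range-proof generation — might look like:

```rust
use std::sync::mpsc;
use std::thread;

// Placeholder for an expensive, CPU-bound step (range-proof generation in the faucet).
fn expensive_work(i: u64) -> u64 {
    (0..10_000u64).fold(i, |acc, x| acc.wrapping_add(x))
}

fn main() {
    let (tx, rx) = mpsc::channel();

    // Fan the blocking work out to worker threads; each worker gets its own
    // clone of the sender, so results stream to the consumer as they finish
    // instead of being awaited one at a time (the "synchronous" trap the
    // original comment warns about).
    let workers: Vec<_> = (0..4u64)
        .map(|i| {
            let txc = tx.clone(); // one sender clone per task, as in the diff
            thread::spawn(move || txc.send(expensive_work(i)).unwrap())
        })
        .collect();
    drop(tx); // drop the original sender so the receiver iterator terminates

    // The "file-writer" end: a single consumer drains results in arrival order.
    let results: Vec<u64> = rx.iter().collect();

    for w in workers {
        w.join().unwrap();
    }
    assert_eq!(results.len(), 4);
}
```

The key design point mirrored from the diff is that the producer side is cloned per task and the consumer is singular, so throughput is bounded by the workers, not by the writer.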
diff --git a/base_layer/common_types/Cargo.toml b/base_layer/common_types/Cargo.toml
index 2b85d4e9a9..b70379c97f 100644
--- a/base_layer/common_types/Cargo.toml
+++ b/base_layer/common_types/Cargo.toml
@@ -11,4 +11,6 @@
 futures = {version = "^0.3.1", features = ["async-await"] }
 rand = "0.8"
 tari_crypto = "0.11.1"
 serde = { version = "1.0.106", features = ["derive"] }
-tokio = { version="^0.2", features = ["blocking", "time", "sync"] }
+tokio = { version="^1.10", features = [ "time", "sync"] }
+lazy_static = "1.4.0"
+digest = "0.9.0"
\ No newline at end of file
diff --git a/base_layer/wallet/src/util/emoji.rs b/base_layer/common_types/src/emoji.rs
similarity index 97%
rename from base_layer/wallet/src/util/emoji.rs
rename to base_layer/common_types/src/emoji.rs
index 18ecdc174c..6d5b42aea8 100644
--- a/base_layer/wallet/src/util/emoji.rs
+++ b/base_layer/common_types/src/emoji.rs
@@ -20,12 +20,14 @@
 // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

-use crate::util::luhn::{checksum, is_valid};
+use crate::{
+    luhn::{checksum, is_valid},
+    types::PublicKey,
+};
 use std::{
     collections::HashMap,
     fmt::{Display, Error, Formatter},
 };
-use tari_core::transactions::types::PublicKey;
 use tari_crypto::tari_utilities::{
     hex::{Hex, HexError},
     ByteArray,
@@ -70,7 +72,7 @@ lazy_static! {
 /// # Example
 ///
 /// ```
-/// use tari_wallet::util::emoji::EmojiId;
+/// use tari_common_types::emoji::EmojiId;
 ///
 /// assert!(EmojiId::is_valid("🐎🍴🌷🌟💻🐖🐩🐾🌟🐬🎧🐌🏦🐳🐎🐝🐢🔋👕🎸👿🍒🐓🎉💔🌹🏆🐬💡🎳🚦🍹🎒"));
 /// let eid = EmojiId::from_hex("70350e09c474809209824c6e6888707b7dd09959aa227343b5106382b856f73a").unwrap();
@@ -170,8 +172,7 @@
 pub struct EmojiIdError;

 #[cfg(test)]
 mod test {
-    use crate::util::emoji::EmojiId;
-    use tari_core::transactions::types::PublicKey;
+    use crate::{emoji::EmojiId, types::PublicKey};
     use tari_crypto::tari_utilities::hex::Hex;

     #[test]
diff --git a/base_layer/common_types/src/lib.rs b/base_layer/common_types/src/lib.rs
index 03d3d25a62..df01ae7302 100644
--- a/base_layer/common_types/src/lib.rs
+++ b/base_layer/common_types/src/lib.rs
@@ -21,5 +21,10 @@
 // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

 pub mod chain_metadata;
+pub mod emoji;
+pub mod luhn;
 pub mod types;
 pub mod waiting_requests;
+
+#[macro_use]
+extern crate lazy_static;
diff --git a/base_layer/wallet/src/util/luhn.rs b/base_layer/common_types/src/luhn.rs
similarity index 98%
rename from base_layer/wallet/src/util/luhn.rs
rename to base_layer/common_types/src/luhn.rs
index 9a9996ef72..3225b42ebe 100644
--- a/base_layer/wallet/src/util/luhn.rs
+++ b/base_layer/common_types/src/luhn.rs
@@ -45,7 +45,7 @@ pub fn is_valid(arr: &[usize], dict_len: usize) -> bool {
 #[cfg(test)]
 mod test {
-    use crate::util::luhn::*;
+    use crate::luhn::{checksum, is_valid};

     #[test]
     fn luhn_6() {
diff --git a/base_layer/common_types/src/types.rs b/base_layer/common_types/src/types.rs
deleted file mode 100644
index 99c2789cbc..0000000000
--- a/base_layer/common_types/src/types.rs
+++ /dev/null
@@ -1,24 +0,0 @@
-// Copyright 2020. The Tari Project
-//
-// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
-// following conditions are met:
-//
-// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following
-// disclaimer.
-//
-// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the
-// following disclaimer in the documentation and/or other materials provided with the distribution.
-//
-// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote
-// products derived from this software without specific prior written permission.
-//
-// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
-// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
-// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
-// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-pub const BLOCK_HASH_LENGTH: usize = 32;
-pub type BlockHash = Vec<u8>;
diff --git a/base_layer/common_types/src/types/bullet_rangeproofs.rs b/base_layer/common_types/src/types/bullet_rangeproofs.rs
new file mode 100644
index 0000000000..a62dcd1228
--- /dev/null
+++ b/base_layer/common_types/src/types/bullet_rangeproofs.rs
@@ -0,0 +1,110 @@
+// Copyright 2019 The Tari Project
+//
+// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
+// following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following
+// disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the
+// following disclaimer in the documentation and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote
+// products derived from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
+// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+use crate::types::HashDigest;
+use digest::Digest;
+use serde::{
+    de::{self, Visitor},
+    Deserialize,
+    Deserializer,
+    Serialize,
+    Serializer,
+};
+use std::fmt;
+use tari_crypto::tari_utilities::{hex::*, ByteArray, ByteArrayError, Hashable};
+
+#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord)]
+pub struct BulletRangeProof(pub Vec<u8>);
+/// Implement the hashing function for RangeProof for use in the MMR
+impl Hashable for BulletRangeProof {
+    fn hash(&self) -> Vec<u8> {
+        HashDigest::new().chain(&self.0).finalize().to_vec()
+    }
+}
+
+impl ByteArray for BulletRangeProof {
+    fn to_vec(&self) -> Vec<u8> {
+        self.0.clone()
+    }
+
+    fn from_vec(v: &Vec<u8>) -> Result<Self, ByteArrayError> {
+        Ok(BulletRangeProof { 0: v.clone() })
+    }
+
+    fn from_bytes(bytes: &[u8]) -> Result<Self, ByteArrayError> {
+        Ok(BulletRangeProof { 0: bytes.to_vec() })
+    }
+
+    fn as_bytes(&self) -> &[u8] {
+        &self.0
+    }
+}
+
+impl From<Vec<u8>> for BulletRangeProof {
+    fn from(v: Vec<u8>) -> Self {
+        BulletRangeProof(v)
+    }
+}
+
+impl fmt::Display for BulletRangeProof {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "{}", self.to_hex())
+    }
+}
+
+impl Serialize for BulletRangeProof {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where S: Serializer {
+        if serializer.is_human_readable() {
+            self.to_hex().serialize(serializer)
+        } else {
+            serializer.serialize_bytes(self.as_bytes())
+        }
+    }
+}
+
+impl<'de> Deserialize<'de> for BulletRangeProof {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where D: Deserializer<'de> {
+        struct RangeProofVisitor;
+
+        impl<'de> Visitor<'de> for RangeProofVisitor {
+            type Value = BulletRangeProof;
+
+            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+                formatter.write_str("a bulletproof range proof in binary format")
+            }
+
+            fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
+            where E: de::Error {
+                BulletRangeProof::from_bytes(v).map_err(E::custom)
+            }
+        }
+
+        if deserializer.is_human_readable() {
+            let s = String::deserialize(deserializer)?;
+            BulletRangeProof::from_hex(&s).map_err(de::Error::custom)
+        } else {
+            deserializer.deserialize_bytes(RangeProofVisitor)
+        }
+    }
+}
diff --git a/base_layer/common_types/src/types/mod.rs b/base_layer/common_types/src/types/mod.rs
new file mode 100644
index 0000000000..e379d2bbac
--- /dev/null
+++ b/base_layer/common_types/src/types/mod.rs
@@ -0,0 +1,81 @@
+// Copyright 2020. The Tari Project
+//
+// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
+// following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following
+// disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the
+// following disclaimer in the documentation and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote
+// products derived from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
+// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+use tari_crypto::{
+    common::Blake256,
+    ristretto::{
+        pedersen::{PedersenCommitment, PedersenCommitmentFactory},
+        RistrettoComSig,
+        RistrettoPublicKey,
+        RistrettoSchnorr,
+        RistrettoSecretKey,
+    },
+};
+
+use tari_crypto::ristretto::dalek_range_proof::DalekRangeProofService;
+
+mod bullet_rangeproofs;
+
+pub use bullet_rangeproofs::BulletRangeProof;
+
+pub const BLOCK_HASH_LENGTH: usize = 32;
+pub type BlockHash = Vec<u8>;
+
+/// Define the explicit Signature implementation for the Tari base layer. A different signature scheme can be
+/// employed by redefining this type.
+pub type Signature = RistrettoSchnorr;
+/// Define the explicit Commitment Signature implementation for the Tari base layer.
+pub type ComSignature = RistrettoComSig;
+
+/// Define the explicit Commitment implementation for the Tari base layer.
+pub type Commitment = PedersenCommitment;
+pub type CommitmentFactory = PedersenCommitmentFactory;
+
+/// Define the explicit Public key implementation for the Tari base layer
+pub type PublicKey = RistrettoPublicKey;
+
+/// Define the explicit Secret key implementation for the Tari base layer.
+pub type PrivateKey = RistrettoSecretKey;
+pub type BlindingFactor = RistrettoSecretKey;
+
+/// Define the hash function that will be used to produce a signature challenge
+pub type SignatureHasher = Blake256;
+
+/// Specify the Hash function for general hashing
+pub type HashDigest = Blake256;
+
+/// Specify the digest type for signature challenges
+pub type Challenge = Blake256;
+
+/// The type of output that `Challenge` produces
+pub type MessageHash = Vec<u8>;
+
+/// Define the data type that is used to store results of `HashDigest`
+pub type HashOutput = Vec<u8>;
+
+pub const MAX_RANGE_PROOF_RANGE: usize = 64; // 2^64
+
+/// Specify the range proof type
+pub type RangeProofService = DalekRangeProofService;
+
+/// Specify the range proof
+pub type RangeProof = BulletRangeProof;
diff --git a/base_layer/common_types/src/waiting_requests.rs b/base_layer/common_types/src/waiting_requests.rs
index a26119a5cb..67e6eed6ef 100644
--- a/base_layer/common_types/src/waiting_requests.rs
+++ b/base_layer/common_types/src/waiting_requests.rs
@@ -20,10 +20,9 @@
 // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

-use futures::channel::oneshot::Sender as OneshotSender;
 use rand::RngCore;
 use std::{collections::HashMap, sync::Arc, time::Instant};
-use tokio::sync::RwLock;
+use tokio::sync::{oneshot::Sender as OneshotSender, RwLock};

 pub type RequestKey = u64;
diff --git a/base_layer/core/Cargo.toml b/base_layer/core/Cargo.toml
index 1212b2a40c..25e8ceb197 100644
--- a/base_layer/core/Cargo.toml
+++ b/base_layer/core/Cargo.toml
@@ -35,27 +35,28 @@
 bincode = "1.1.4"
 bitflags = "1.0.4"
 blake2 = "^0.9.0"
 sha3 = "0.9"
-bytes = "0.4.12"
+bytes = "0.5"
 chrono = { version = "0.4.6", features = ["serde"]}
 croaring = { version = "=0.4.5", optional = true }
 digest = "0.9.0"
-futures = {version = "^0.3.1", features = ["async-await"] }
+futures = {version = "^0.3.16", features = ["async-await"] }
 fs2 = "0.3.0"
 hex = "0.4.2"
+lazy_static = "1.4.0"
 lmdb-zero = "0.4.4"
 log = "0.4"
 monero = { version = "^0.13.0", features= ["serde_support"], optional = true }
 newtype-ops = "0.1.4"
 num = "0.3"
-prost = "0.6.1"
-prost-types = "0.6.1"
+prost = "0.8.0"
+prost-types = "0.8.0"
 rand = "0.8"
 randomx-rs = { version = "0.5.0", optional = true }
 serde = { version = "1.0.106", features = ["derive"] }
 serde_json = "1.0"
 strum_macros = "0.17.1"
-thiserror = "1.0.20"
-tokio = { version="^0.2", features = ["blocking", "time", "sync"] }
+thiserror = "1.0.26"
+tokio = { version="^1.10", features = [ "time", "sync", "macros"] }
 ttl_cache = "0.5.1"
 uint = { version = "0.9", default-features = false }
 num-format = "0.4.0"
@@ -70,7 +71,6 @@
 tari_test_utils = { version = "^0.9", path = "../../infrastructure/test_utils" }
 config = { version = "0.9.3" }
 env_logger = "0.7.0"
 tempfile = "3.1.0"
-tokio-macros = "0.2.4"

 [build-dependencies]
 tari_common = { version = "^0.9", path="../../common", features = ["build"]}
diff --git a/base_layer/core/src/base_node/chain_metadata_service/initializer.rs b/base_layer/core/src/base_node/chain_metadata_service/initializer.rs
index 1310f22702..2700dc9d01 100644
--- a/base_layer/core/src/base_node/chain_metadata_service/initializer.rs
+++ b/base_layer/core/src/base_node/chain_metadata_service/initializer.rs
@@ -20,10 +20,8 @@
 // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

-use super::{service::ChainMetadataService, LOG_TARGET};
+use super::service::ChainMetadataService;
 use crate::base_node::{chain_metadata_service::handle::ChainMetadataHandle, comms_interface::LocalNodeCommsInterface};
-use futures::{future, pin_mut};
-use log::*;
 use tari_comms::connectivity::ConnectivityRequester;
 use tari_p2p::services::liveness::LivenessHandle;
 use tari_service_framework::{async_trait, ServiceInitializationError, ServiceInitializer, ServiceInitializerContext};
@@ -40,15 +38,12 @@
         let handle = ChainMetadataHandle::new(publisher.clone());
         context.register_handle(handle);

-        context.spawn_when_ready(|handles| async move {
+        context.spawn_until_shutdown(|handles| {
             let liveness = handles.expect_handle::<LivenessHandle>();
             let base_node = handles.expect_handle::<LocalNodeCommsInterface>();
             let connectivity = handles.expect_handle::<ConnectivityRequester>();

-            let service_run = ChainMetadataService::new(liveness, base_node, connectivity, publisher).run();
-            pin_mut!(service_run);
-            future::select(service_run, handles.get_shutdown_signal()).await;
-            info!(target: LOG_TARGET, "ChainMetadataService has shut down");
+            ChainMetadataService::new(liveness, base_node, connectivity, publisher).run()
         });

         Ok(())
diff --git a/base_layer/core/src/base_node/chain_metadata_service/service.rs b/base_layer/core/src/base_node/chain_metadata_service/service.rs
index 6c7da96719..61110bc147 100644
--- a/base_layer/core/src/base_node/chain_metadata_service/service.rs
+++ b/base_layer/core/src/base_node/chain_metadata_service/service.rs
@@ -29,7 +29,6 @@
 use crate::{
     chain_storage::BlockAddResult,
     proto::base_node as proto,
 };
-use futures::stream::StreamExt;
 use log::*;
 use num_format::{Locale, ToFormattedString};
 use prost::Message;
@@ -75,9 +74,9 @@
     /// Run the service
     pub async fn run(mut self) {
-        let mut liveness_event_stream = self.liveness.get_event_stream().fuse();
-        let mut block_event_stream = self.base_node.get_block_event_stream().fuse();
-        let mut connectivity_events = self.connectivity.get_event_subscription().fuse();
+        let mut liveness_event_stream = self.liveness.get_event_stream();
+        let mut block_event_stream = self.base_node.get_block_event_stream();
+        let mut connectivity_events = self.connectivity.get_event_subscription();

         log_if_error!(
             target: LOG_TARGET,
@@ -86,47 +85,36 @@
         );

         loop {
-            futures::select! {
-                block_event = block_event_stream.select_next_some() => {
-                    if let Ok(block_event) = block_event {
-                        log_if_error!(
-                            level: debug,
-                            target: LOG_TARGET,
-                            "Failed to handle block event because '{}'",
-                            self.handle_block_event(&block_event).await
-                        );
-                    }
+            tokio::select! {
+                Ok(block_event) = block_event_stream.recv() => {
+                    log_if_error!(
+                        level: debug,
+                        target: LOG_TARGET,
+                        "Failed to handle block event because '{}'",
+                        self.handle_block_event(&block_event).await
+                    );
                 },
-                liveness_event = liveness_event_stream.select_next_some() => {
-                    if let Ok(event) = liveness_event {
-                        log_if_error!(
-                            target: LOG_TARGET,
-                            "Failed to handle liveness event because '{}'",
-                            self.handle_liveness_event(&*event).await
-                        );
-                    }
+                Ok(event) = liveness_event_stream.recv() => {
+                    log_if_error!(
+                        target: LOG_TARGET,
+                        "Failed to handle liveness event because '{}'",
+                        self.handle_liveness_event(&*event).await
+                    );
                 },
-                event = connectivity_events.select_next_some() => {
-                    if let Ok(event) = event {
-                        self.handle_connectivity_event(&*event);
-                    }
-                }
-
-                complete => {
-                    info!(target: LOG_TARGET, "ChainStateSyncService is exiting because all tasks have completed");
-                    break;
+                Ok(event) = connectivity_events.recv() => {
+                    self.handle_connectivity_event(event);
                 }
             }
         }
     }

-    fn handle_connectivity_event(&mut self, event: &ConnectivityEvent) {
+    fn handle_connectivity_event(&mut self, event: ConnectivityEvent) {
         use ConnectivityEvent::*;
         match event {
             PeerDisconnected(node_id) | ManagedPeerDisconnected(node_id) | PeerBanned(node_id) => {
-                if let Some(pos) = self.peer_chain_metadata.iter().position(|p| &p.node_id == node_id) {
+                if let Some(pos) = self.peer_chain_metadata.iter().position(|p| p.node_id == node_id) {
                     debug!(
                         target: LOG_TARGET,
                         "Removing disconnected/banned peer `{}` from chain metadata list ", node_id
@@ -164,7 +152,7 @@
     async fn handle_liveness_event(&mut self, event: &LivenessEvent) -> Result<(), ChainMetadataSyncError> {
         match event {
-            // Received a ping, check if our neighbour sent it and it contains ChainMetadata
+            // Received a ping, check if it contains ChainMetadata
             LivenessEvent::ReceivedPing(event) => {
                 trace!(
                     target: LOG_TARGET,
@@ -172,6 +160,7 @@
                     event.node_id
                 );
                 self.collect_chain_state_from_ping(&event.node_id, &event.metadata)?;
+                self.send_chain_metadata_to_event_publisher().await?;
             },
             // Received a pong, check if our neighbour sent it and it contains ChainMetadata
             LivenessEvent::ReceivedPong(event) => {
@@ -181,11 +170,7 @@
                     event.node_id
                 );
                 self.collect_chain_state_from_pong(&event.node_id, &event.metadata)?;
-
-                // All peers have responded in this round, send the chain metadata to the base node service
-                if self.peer_chain_metadata.len() >= self.peer_chain_metadata.capacity() {
-                    self.flush_chain_metadata_to_event_publisher().await?;
-                }
+                self.send_chain_metadata_to_event_publisher().await?;
             },
             // New ping round has begun
             LivenessEvent::PingRoundBroadcast(num_peers) => {
                 debug!(
@@ -193,11 +178,9 @@
                     target: LOG_TARGET,
                     "New chain metadata round sent to {} peer(s)", num_peers
                 );
-                // If we have chain metadata to send to the base node service, send them now
-                // because the next round of pings is happening.
-                self.flush_chain_metadata_to_event_publisher().await?;
                 // Ensure that we're waiting for the correct amount of peers to respond
                 // and have allocated space for their replies
+                self.resize_chainstate_buffer(*num_peers);
             },
         }
@@ -205,13 +188,13 @@
         Ok(())
     }

-    async fn flush_chain_metadata_to_event_publisher(&mut self) -> Result<(), ChainMetadataSyncError> {
-        let chain_metadata = self.peer_chain_metadata.drain(..).collect::<Vec<_>>();
-
+    async fn send_chain_metadata_to_event_publisher(&mut self) -> Result<(), ChainMetadataSyncError> {
         // send only fails if there are no subscribers.
         let _ = self
             .event_publisher
-            .send(Arc::new(ChainMetadataEvent::PeerChainMetadataReceived(chain_metadata)));
+            .send(Arc::new(ChainMetadataEvent::PeerChainMetadataReceived(
+                self.peer_chain_metadata.clone(),
+            )));

         Ok(())
     }
@@ -289,7 +272,6 @@
         self.peer_chain_metadata
             .push(PeerChainMetadata::new(node_id.clone(), chain_metadata));
-
         Ok(())
     }
 }
@@ -298,6 +280,7 @@
 mod test {
     use super::*;
     use crate::base_node::comms_interface::{CommsInterfaceError, NodeCommsRequest, NodeCommsResponse};
+    use futures::StreamExt;
     use std::convert::TryInto;
     use tari_comms::test_utils::{
         mocks::{create_connectivity_mock, ConnectivityManagerMockState},
@@ -361,7 +344,7 @@
         )
     }

-    #[tokio_macros::test]
+    #[tokio::test]
     async fn update_liveness_chain_metadata() {
         let (mut service, liveness_mock_state, _, mut base_node_receiver) = setup();

@@ -370,11 +353,11 @@
         let chain_metadata = proto_chain_metadata.clone().try_into().unwrap();

         task::spawn(async move {
-            let base_node_req = base_node_receiver.select_next_some().await;
-            let (_req, reply_tx) = base_node_req.split();
-            reply_tx
-                .send(Ok(NodeCommsResponse::ChainMetadata(chain_metadata)))
-                .unwrap();
+            if let Some(base_node_req) = base_node_receiver.next().await {
+                base_node_req
+                    .reply(Ok(NodeCommsResponse::ChainMetadata(chain_metadata)))
+                    .unwrap();
+            }
         });

         service.update_liveness_chain_metadata().await.unwrap();
@@ -387,7 +370,7 @@
         let chain_metadata = proto::ChainMetadata::decode(data.as_slice()).unwrap();
         assert_eq!(chain_metadata.height_of_longest_chain, Some(123));
     }
-    #[tokio_macros::test_basic]
+    #[tokio::test]
     async fn handle_liveness_event_ok() {
         let (mut service, _, _, _) = setup();

@@ -416,7 +399,7 @@
         );
     }

-    #[tokio_macros::test_basic]
+    #[tokio::test]
     async fn handle_liveness_event_banned_peer() {
         let (mut service, _, _, _) = setup();

@@ -442,7 +425,7 @@
             .peer_chain_metadata
             .iter()
             .any(|p| &p.node_id == nodes[0].node_id()));
-        service.handle_connectivity_event(&ConnectivityEvent::PeerBanned(nodes[0].node_id().clone()));
+        service.handle_connectivity_event(ConnectivityEvent::PeerBanned(nodes[0].node_id().clone()));
         // Check that banned peer was removed
         assert!(service
             .peer_chain_metadata
@@ -450,7 +433,7 @@
             .all(|p| &p.node_id != nodes[0].node_id()));
     }

-    #[tokio_macros::test_basic]
+    #[tokio::test]
     async fn handle_liveness_event_no_metadata() {
         let (mut service, _, _, _) = setup();

@@ -468,7 +451,7 @@
         assert_eq!(service.peer_chain_metadata.len(), 0);
     }

-    #[tokio_macros::test_basic]
+    #[tokio::test]
     async fn handle_liveness_event_bad_metadata() {
         let (mut service, _, _, _) = setup();
diff --git a/base_layer/core/src/base_node/comms_interface/comms_request.rs b/base_layer/core/src/base_node/comms_interface/comms_request.rs
index 2eec332b58..eef287d8f1 100644
--- a/base_layer/core/src/base_node/comms_interface/comms_request.rs
+++ b/base_layer/core/src/base_node/comms_interface/comms_request.rs
@@ -20,14 +20,10 @@
 // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

-use crate::{
-    blocks::NewBlockTemplate,
-    chain_storage::MmrTree,
-    proof_of_work::PowAlgorithm,
-    transactions::types::{Commitment, HashOutput, Signature},
-};
+use crate::{blocks::NewBlockTemplate, chain_storage::MmrTree, proof_of_work::PowAlgorithm};
 use serde::{Deserialize, Serialize};
 use std::fmt::{Display, Error, Formatter};
+use tari_common_types::types::{Commitment, HashOutput, Signature};
 use tari_crypto::tari_utilities::hex::Hex;

 /// A container for the parameters required for a FetchMmrState request.
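Several of the hunks above do nothing but repoint imports of `Commitment`, `HashOutput`, and `Signature` from `tari_core::transactions::types` to the new shared `tari_common_types::types` module, where (per `types/mod.rs`) they are plain type aliases over the Ristretto implementations, chosen so the scheme can be swapped by redefining one alias. A minimal sketch of that alias-indirection idea, using hypothetical `Vec<u8>`-backed stand-ins rather than the real tari_crypto types:

```rust
// Hypothetical stand-ins for the shared alias module; the real crate aliases
// concrete crypto types, e.g. `pub type Signature = RistrettoSchnorr;`.
mod common_types {
    // One shared definition point: downstream crates import these aliases
    // instead of naming the concrete implementation directly.
    pub type HashOutput = Vec<u8>;
    pub type BlockHash = Vec<u8>;
    pub const BLOCK_HASH_LENGTH: usize = 32;
}

// A downstream "core" function written only against the aliases, so swapping
// the underlying type is a one-line change in `common_types`.
fn describe_block(hash: &common_types::BlockHash) -> String {
    format!("block hash of {} bytes", hash.len())
}

fn main() {
    let h: common_types::HashOutput = vec![0u8; common_types::BLOCK_HASH_LENGTH];
    println!("{}", describe_block(&h));
}
```

Moving the aliases into a leaf crate (`tari_common_types`) also breaks the dependency cycle visible in the diff, where wallet code previously had to pull `PublicKey` out of `tari_core`.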
diff --git a/base_layer/core/src/base_node/comms_interface/comms_response.rs b/base_layer/core/src/base_node/comms_interface/comms_response.rs
index e275dc9c5a..8f7ec1b9e5 100644
--- a/base_layer/core/src/base_node/comms_interface/comms_response.rs
+++ b/base_layer/core/src/base_node/comms_interface/comms_response.rs
@@ -24,14 +24,11 @@
 use crate::{
     blocks::{block_header::BlockHeader, Block, NewBlockTemplate},
     chain_storage::HistoricalBlock,
     proof_of_work::Difficulty,
-    transactions::{
-        transaction::{TransactionKernel, TransactionOutput},
-        types::HashOutput,
-    },
+    transactions::transaction::{TransactionKernel, TransactionOutput},
 };
 use serde::{Deserialize, Serialize};
 use std::fmt::{self, Display, Formatter};
-use tari_common_types::chain_metadata::ChainMetadata;
+use tari_common_types::{chain_metadata::ChainMetadata, types::HashOutput};

 /// API Response enum
 #[derive(Debug, Serialize, Deserialize, Clone)]
diff --git a/base_layer/core/src/base_node/comms_interface/inbound_handlers.rs b/base_layer/core/src/base_node/comms_interface/inbound_handlers.rs
index 21ebc49201..e0d2c1268a 100644
--- a/base_layer/core/src/base_node/comms_interface/inbound_handlers.rs
+++ b/base_layer/core/src/base_node/comms_interface/inbound_handlers.rs
@@ -34,7 +34,7 @@
 use crate::{
     consensus::{ConsensusConstants, ConsensusManager},
     mempool::{async_mempool, Mempool},
     proof_of_work::{Difficulty, PowAlgorithm},
-    transactions::{transaction::TransactionKernel, types::HashOutput},
+    transactions::transaction::TransactionKernel,
 };
 use log::*;
 use std::{
@@ -42,7 +42,7 @@
     sync::Arc,
 };
 use strum_macros::Display;
-use tari_common_types::types::BlockHash;
+use tari_common_types::types::{BlockHash, HashOutput};
 use tari_comms::peer_manager::NodeId;
 use tari_crypto::tari_utilities::{hash::Hashable, hex::Hex};
 use tokio::sync::Semaphore;
diff --git a/base_layer/core/src/base_node/comms_interface/local_interface.rs b/base_layer/core/src/base_node/comms_interface/local_interface.rs
index a0f5bcf2c3..0a270a78e2 100644
--- a/base_layer/core/src/base_node/comms_interface/local_interface.rs
+++ b/base_layer/core/src/base_node/comms_interface/local_interface.rs
@@ -31,10 +31,7 @@
 use crate::{
     blocks::{Block, BlockHeader, NewBlockTemplate},
     chain_storage::HistoricalBlock,
     proof_of_work::PowAlgorithm,
-    transactions::{
-        transaction::TransactionKernel,
-        types::{Commitment, HashOutput, Signature},
-    },
+    transactions::transaction::TransactionKernel,
 };
 use std::sync::Arc;
 use tari_common_types::{chain_metadata::ChainMetadata, types::BlockHash};
@@ -47,6 +44,7 @@
 use crate::{
     base_node::comms_interface::comms_request::GetNewBlockTemplateRequest,
     transactions::transaction::TransactionOutput,
 };
+use tari_common_types::types::{Commitment, HashOutput, Signature};

 /// The InboundNodeCommsInterface provides an interface to request information from the current local node by other
 /// internal services.
diff --git a/base_layer/core/src/base_node/comms_interface/outbound_interface.rs b/base_layer/core/src/base_node/comms_interface/outbound_interface.rs
index 753fe802d8..433b5325de 100644
--- a/base_layer/core/src/base_node/comms_interface/outbound_interface.rs
+++ b/base_layer/core/src/base_node/comms_interface/outbound_interface.rs
@@ -24,13 +24,16 @@
 use crate::{
     base_node::comms_interface::{error::CommsInterfaceError, NodeCommsRequest, NodeCommsResponse},
     blocks::{block_header::BlockHeader, NewBlock},
     chain_storage::HistoricalBlock,
-    transactions::{transaction::TransactionOutput, types::HashOutput},
+    transactions::transaction::TransactionOutput,
 };
-use futures::channel::mpsc::UnboundedSender;
 use log::*;
-use tari_common_types::{chain_metadata::ChainMetadata, types::BlockHash};
+use tari_common_types::{
+    chain_metadata::ChainMetadata,
+    types::{BlockHash, HashOutput},
+};
 use tari_comms::peer_manager::NodeId;
 use tari_service_framework::{reply_channel::SenderService, Service};
+use tokio::sync::mpsc::UnboundedSender;

 pub const LOG_TARGET: &str = "c::bn::comms_interface::outbound_interface";
@@ -234,10 +237,7 @@
         new_block: NewBlock,
         exclude_peers: Vec<NodeId>,
     ) -> Result<(), CommsInterfaceError> {
-        self.block_sender
-            .unbounded_send((new_block, exclude_peers))
-            .map_err(|err| {
-                CommsInterfaceError::InternalChannelError(format!("Failed to send on block_sender: {}", err))
-            })
+        self.block_sender.send((new_block, exclude_peers)).map_err(|err| {
+            CommsInterfaceError::InternalChannelError(format!("Failed to send on block_sender: {}", err))
+        })
     }
 }
diff --git a/base_layer/core/src/base_node/proto/request.rs b/base_layer/core/src/base_node/proto/request.rs
index 45195f20b8..bf766bffc2 100644
--- a/base_layer/core/src/base_node/proto/request.rs
+++ b/base_layer/core/src/base_node/proto/request.rs
@@ -32,9 +32,9 @@
 use crate::{
         HashOutputs,
         },
     },
-    transactions::types::{Commitment, HashOutput, Signature},
 };
 use std::convert::{From, TryFrom, TryInto};
+use tari_common_types::types::{Commitment, HashOutput, Signature};
 use tari_crypto::tari_utilities::ByteArrayError;

 //---------------------------------- BaseNodeRequest --------------------------------------------//
diff --git a/base_layer/core/src/base_node/proto/wallet_rpc.rs b/base_layer/core/src/base_node/proto/wallet_rpc.rs
index 94f2f2d7f6..3183128b55 100644
--- a/base_layer/core/src/base_node/proto/wallet_rpc.rs
+++ b/base_layer/core/src/base_node/proto/wallet_rpc.rs
@@ -23,7 +23,6 @@
 use crate::{
     crypto::tari_utilities::ByteArrayError,
     proto::{base_node as proto, types},
-    transactions::types::Signature,
 };

 use serde::{Deserialize, Serialize};
@@ -31,7 +30,7 @@
 use std::{
     convert::TryFrom,
     fmt::{Display, Error, Formatter},
 };
-use tari_common_types::types::BlockHash;
+use tari_common_types::types::{BlockHash, Signature};

 #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
 pub struct TxSubmissionResponse {
diff --git a/base_layer/core/src/base_node/rpc/service.rs b/base_layer/core/src/base_node/rpc/service.rs
index 
c50600ea9c..dbd1b141e4 100644 --- a/base_layer/core/src/base_node/rpc/service.rs +++ b/base_layer/core/src/base_node/rpc/service.rs @@ -40,9 +40,10 @@ use crate::{ }, types::{Signature as SignatureProto, Transaction as TransactionProto}, }, - transactions::{transaction::Transaction, types::Signature}, + transactions::transaction::Transaction, }; use std::convert::TryFrom; +use tari_common_types::types::Signature; use tari_comms::protocol::rpc::{Request, Response, RpcStatus}; const LOG_TARGET: &str = "c::base_node::rpc"; @@ -230,7 +231,7 @@ impl BaseNodeWalletService for BaseNodeWalletRpc // Determine if we are synced let status_watch = state_machine.get_status_info_watch(); - let is_synced = match (*status_watch.borrow()).state_info { + let is_synced = match status_watch.borrow().state_info { StateInfo::Listening(li) => li.is_synced(), _ => false, }; diff --git a/base_layer/core/src/base_node/service/initializer.rs b/base_layer/core/src/base_node/service/initializer.rs index ae6be0b519..11132db138 100644 --- a/base_layer/core/src/base_node/service/initializer.rs +++ b/base_layer/core/src/base_node/service/initializer.rs @@ -33,7 +33,7 @@ use crate::{ proto as shared_protos, proto::base_node as proto, }; -use futures::{channel::mpsc, future, Stream, StreamExt}; +use futures::{future, Stream, StreamExt}; use log::*; use std::{convert::TryFrom, sync::Arc}; use tari_comms_dht::Dht; @@ -50,7 +50,7 @@ use tari_service_framework::{ ServiceInitializer, ServiceInitializerContext, }; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc}; const LOG_TARGET: &str = "c::bn::service::initializer"; const SUBSCRIPTION_LABEL: &str = "Base Node"; @@ -151,7 +151,7 @@ where T: BlockchainBackend + 'static let inbound_block_stream = self.inbound_block_stream(); // Connect InboundNodeCommsInterface and OutboundNodeCommsInterface to BaseNodeService let (outbound_request_sender_service, outbound_request_stream) = reply_channel::unbounded(); - let (outbound_block_sender_service, 
outbound_block_stream) = mpsc::unbounded(); + let (outbound_block_sender_service, outbound_block_stream) = mpsc::unbounded_channel(); let (local_request_sender_service, local_request_stream) = reply_channel::unbounded(); let (local_block_sender_service, local_block_stream) = reply_channel::unbounded(); let outbound_nci = diff --git a/base_layer/core/src/base_node/service/service.rs b/base_layer/core/src/base_node/service/service.rs index db96b6ed9e..1d66cbf1b1 100644 --- a/base_layer/core/src/base_node/service/service.rs +++ b/base_layer/core/src/base_node/service/service.rs @@ -38,16 +38,7 @@ use crate::{ proto as shared_protos, proto::{base_node as proto, base_node::base_node_service_request::Request}, }; -use futures::{ - channel::{ - mpsc::{channel, Receiver, Sender, UnboundedReceiver}, - oneshot::Sender as OneshotSender, - }, - pin_mut, - stream::StreamExt, - SinkExt, - Stream, -}; +use futures::{pin_mut, stream::StreamExt, Stream}; use log::*; use rand::rngs::OsRng; use std::{convert::TryInto, sync::Arc, time::Duration}; @@ -64,7 +55,14 @@ use tari_comms_dht::{ use tari_crypto::tari_utilities::hex::Hex; use tari_p2p::{domain_message::DomainMessage, tari_message::TariMessageType}; use tari_service_framework::reply_channel::RequestContext; -use tokio::task; +use tokio::{ + sync::{ + mpsc, + mpsc::{Receiver, Sender, UnboundedReceiver}, + oneshot::Sender as OneshotSender, + }, + task, +}; const LOG_TARGET: &str = "c::bn::base_node_service::service"; @@ -134,7 +132,7 @@ where B: BlockchainBackend + 'static config: BaseNodeServiceConfig, state_machine_handle: StateMachineHandle, ) -> Self { - let (timeout_sender, timeout_receiver) = channel(100); + let (timeout_sender, timeout_receiver) = mpsc::channel(100); Self { outbound_message_service, inbound_nch, @@ -162,7 +160,7 @@ where B: BlockchainBackend + 'static { let outbound_request_stream = streams.outbound_request_stream.fuse(); pin_mut!(outbound_request_stream); - let outbound_block_stream = 
streams.outbound_block_stream.fuse(); + let outbound_block_stream = streams.outbound_block_stream; pin_mut!(outbound_block_stream); let inbound_request_stream = streams.inbound_request_stream.fuse(); pin_mut!(inbound_request_stream); @@ -177,53 +175,52 @@ where B: BlockchainBackend + 'static let timeout_receiver_stream = self .timeout_receiver_stream .take() - .expect("Base Node Service initialized without timeout_receiver_stream") - .fuse(); + .expect("Base Node Service initialized without timeout_receiver_stream"); pin_mut!(timeout_receiver_stream); loop { - futures::select! { + tokio::select! { // Outbound request messages from the OutboundNodeCommsInterface - outbound_request_context = outbound_request_stream.select_next_some() => { + Some(outbound_request_context) = outbound_request_stream.next() => { self.spawn_handle_outbound_request(outbound_request_context); }, // Outbound block messages from the OutboundNodeCommsInterface - (block, excluded_peers) = outbound_block_stream.select_next_some() => { + Some((block, excluded_peers)) = outbound_block_stream.recv() => { self.spawn_handle_outbound_block(block, excluded_peers); }, // Incoming request messages from the Comms layer - domain_msg = inbound_request_stream.select_next_some() => { + Some(domain_msg) = inbound_request_stream.next() => { self.spawn_handle_incoming_request(domain_msg); }, // Incoming response messages from the Comms layer - domain_msg = inbound_response_stream.select_next_some() => { + Some(domain_msg) = inbound_response_stream.next() => { self.spawn_handle_incoming_response(domain_msg); }, // Timeout events for waiting requests - timeout_request_key = timeout_receiver_stream.select_next_some() => { + Some(timeout_request_key) = timeout_receiver_stream.recv() => { self.spawn_handle_request_timeout(timeout_request_key); }, // Incoming block messages from the Comms layer - block_msg = inbound_block_stream.select_next_some() => { + Some(block_msg) = inbound_block_stream.next() => { 
self.spawn_handle_incoming_block(block_msg).await; } // Incoming local request messages from the LocalNodeCommsInterface and other local services - local_request_context = local_request_stream.select_next_some() => { + Some(local_request_context) = local_request_stream.next() => { self.spawn_handle_local_request(local_request_context); }, // Incoming local block messages from the LocalNodeCommsInterface e.g. miner and block sync - local_block_context = local_block_stream.select_next_some() => { + Some(local_block_context) = local_block_stream.next() => { self.spawn_handle_local_block(local_block_context); }, - complete => { - info!(target: LOG_TARGET, "Base Node service shutting down"); + else => { + info!(target: LOG_TARGET, "Base Node service shutting down because all streams ended"); break; } } @@ -646,9 +643,9 @@ async fn handle_request_timeout( Ok(()) } -fn spawn_request_timeout(mut timeout_sender: Sender<RequestKey>, request_key: RequestKey, timeout: Duration) { +fn spawn_request_timeout(timeout_sender: Sender<RequestKey>, request_key: RequestKey, timeout: Duration) { task::spawn(async move { - tokio::time::delay_for(timeout).await; + tokio::time::sleep(timeout).await; let _ = timeout_sender.send(request_key).await; }); } diff --git a/base_layer/core/src/base_node/state_machine_service/initializer.rs b/base_layer/core/src/base_node/state_machine_service/initializer.rs index a6d4c73a0c..c58d62000f 100644 --- a/base_layer/core/src/base_node/state_machine_service/initializer.rs +++ b/base_layer/core/src/base_node/state_machine_service/initializer.rs @@ -20,6 +20,14 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
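The reworked `spawn_request_timeout` above (tokio `delay_for` renamed to `sleep`, and the sender no longer needs `mut`) can be sketched with std primitives. This is a synchronous stand-in to illustrate the pattern, not the async tokio version the patch uses:

```rust
use std::{sync::mpsc, thread, time::Duration};

type RequestKey = u64;

// Synchronous stand-in for spawn_request_timeout: after `timeout` elapses,
// push the request key onto the timeout channel so the service loop can
// expire the waiting request. The patch does the same with tokio::time::sleep
// and an async mpsc::Sender inside a spawned task.
fn spawn_request_timeout(timeout_sender: mpsc::Sender<RequestKey>, request_key: RequestKey, timeout: Duration) {
    thread::spawn(move || {
        thread::sleep(timeout);
        // The receiver may already be gone; ignore the error like the original.
        let _ = timeout_sender.send(request_key);
    });
}

fn main() {
    let (tx, rx) = mpsc::channel();
    spawn_request_timeout(tx, 42, Duration::from_millis(10));
    // The service loop would receive this key in its select! timeout arm.
    assert_eq!(rx.recv().unwrap(), 42);
}
```

In the real service this channel feeds one arm of the `tokio::select!` loop shown above, which is why the timeout fires without blocking request handling.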
+use std::sync::Arc; + +use log::*; +use tokio::sync::{broadcast, watch}; + +use tari_comms::{connectivity::ConnectivityRequester, PeerManager}; +use tari_service_framework::{async_trait, ServiceInitializationError, ServiceInitializer, ServiceInitializerContext}; + use crate::{ base_node::{ chain_metadata_service::ChainMetadataHandle, @@ -35,13 +43,8 @@ use crate::{ chain_storage::{async_db::AsyncBlockchainDb, BlockchainBackend}, consensus::ConsensusManager, proof_of_work::randomx_factory::RandomXFactory, - transactions::types::CryptoFactories, + transactions::CryptoFactories, }; -use log::*; -use std::sync::Arc; -use tari_comms::{connectivity::ConnectivityRequester, PeerManager}; -use tari_service_framework::{async_trait, ServiceInitializationError, ServiceInitializer, ServiceInitializerContext}; -use tokio::sync::{broadcast, watch}; const LOG_TARGET: &str = "c::bn::state_machine_service::initializer"; @@ -98,7 +101,8 @@ where B: BlockchainBackend + 'static let connectivity = handles.expect_handle::<ConnectivityRequester>(); let peer_manager = handles.expect_handle::<Arc<PeerManager>>(); - let sync_validators = SyncValidators::full_consensus(rules.clone(), factories); + let sync_validators = + SyncValidators::full_consensus(rules.clone(), factories, config.bypass_range_proof_verification); let max_randomx_vms = config.max_randomx_vms; let node = BaseNodeStateMachine::new( diff --git a/base_layer/core/src/base_node/state_machine_service/state_machine.rs b/base_layer/core/src/base_node/state_machine_service/state_machine.rs index 3c8e33ee06..ec99ccc989 100644 --- a/base_layer/core/src/base_node/state_machine_service/state_machine.rs +++ b/base_layer/core/src/base_node/state_machine_service/state_machine.rs @@ -52,6 +52,7 @@ pub struct BaseNodeStateMachineConfig { pub pruning_horizon: u64, pub max_randomx_vms: usize, pub blocks_behind_before_considered_lagging: u64, + pub bypass_range_proof_verification: bool, } /// A Tari full node, aka Base Node.
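The new `bypass_range_proof_verification` flag is read from the state machine config and threaded into `SyncValidators::full_consensus`. The wiring can be sketched with heavily simplified stand-in types (the real constructor also takes consensus rules and crypto factories, omitted here):

```rust
// Simplified stand-ins for the real config and validator types; illustrative only.
#[derive(Default)]
struct BaseNodeStateMachineConfig {
    bypass_range_proof_verification: bool, // new field added by this patch
}

struct SyncValidators {
    bypass_range_proof_verification: bool,
}

impl SyncValidators {
    // The real signature is full_consensus(rules, factories, bypass_flag).
    fn full_consensus(bypass_range_proof_verification: bool) -> Self {
        Self {
            bypass_range_proof_verification,
        }
    }
}

fn main() {
    let config = BaseNodeStateMachineConfig::default();
    let validators = SyncValidators::full_consensus(config.bypass_range_proof_verification);
    // Range proofs are verified unless the flag is explicitly enabled.
    assert!(!validators.bypass_range_proof_verification);
}
```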
@@ -158,7 +159,7 @@ impl BaseNodeStateMachine { state_info: self.info.clone(), }; - if let Err(e) = self.status_event_sender.broadcast(status) { + if let Err(e) = self.status_event_sender.send(status) { debug!(target: LOG_TARGET, "Error broadcasting a StatusEvent update: {}", e); } } diff --git a/base_layer/core/src/base_node/state_machine_service/states/block_sync.rs b/base_layer/core/src/base_node/state_machine_service/states/block_sync.rs index 689c9a8316..5c7710c371 100644 --- a/base_layer/core/src/base_node/state_machine_service/states/block_sync.rs +++ b/base_layer/core/src/base_node/state_machine_service/states/block_sync.rs @@ -67,7 +67,7 @@ impl BlockSync { let status_event_sender = shared.status_event_sender.clone(); let bootstrapped = shared.is_bootstrapped(); - let _ = status_event_sender.broadcast(StatusInfo { + let _ = status_event_sender.send(StatusInfo { bootstrapped, state_info: StateInfo::BlockSyncStarting, }); @@ -80,7 +80,7 @@ impl BlockSync { false.into(), )); - let _ = status_event_sender.broadcast(StatusInfo { + let _ = status_event_sender.send(StatusInfo { bootstrapped, state_info: StateInfo::BlockSync(BlockSyncInfo { tip_height: remote_tip_height, diff --git a/base_layer/core/src/base_node/state_machine_service/states/events_and_states.rs b/base_layer/core/src/base_node/state_machine_service/states/events_and_states.rs index 4e864624bc..0ef1f6e99b 100644 --- a/base_layer/core/src/base_node/state_machine_service/states/events_and_states.rs +++ b/base_layer/core/src/base_node/state_machine_service/states/events_and_states.rs @@ -160,7 +160,7 @@ impl Display for BaseNodeState { #[derive(Debug, Clone, PartialEq)] pub enum StateInfo { StartUp, - HeaderSync(BlockSyncInfo), + HeaderSync(Option<BlockSyncInfo>), HorizonSync(HorizonSyncInfo), BlockSyncStarting, BlockSync(BlockSyncInfo), @@ -169,15 +169,12 @@ impl StateInfo { pub fn short_desc(&self) -> String { + use StateInfo::*; match self { - Self::StartUp => "Starting up".to_string(), -
Self::HeaderSync(info) => format!( - "Syncing headers: {}/{} ({:.0}%)", - info.local_height, - info.tip_height, - info.local_height as f64 / info.tip_height as f64 * 100.0 - ), - Self::HorizonSync(info) => match info.status { + StartUp => "Starting up".to_string(), + HeaderSync(None) => "Starting header sync".to_string(), + HeaderSync(Some(info)) => format!("Syncing headers: {}", info.sync_progress_string()), + HorizonSync(info) => match info.status { HorizonSyncStatus::Starting => "Starting horizon sync".to_string(), HorizonSyncStatus::Kernels(current, total) => format!( "Syncing kernels: {}/{} ({:.0}%)", @@ -193,18 +190,16 @@ impl StateInfo { ), HorizonSyncStatus::Finalizing => "Finalizing horizon sync".to_string(), }, - Self::BlockSync(info) => format!( - "Syncing blocks with {}: {}/{} ({:.0}%) ", + BlockSync(info) => format!( + "Syncing blocks: ({}) {}", info.sync_peers .first() - .map(|s| s.short_str()) + .map(|n| n.short_str()) .unwrap_or_else(|| "".to_string()), - info.local_height, - info.tip_height, - info.local_height as f64 / info.tip_height as f64 * 100.0 + info.sync_progress_string() ), - Self::Listening(_) => "Listening".to_string(), - Self::BlockSyncStarting => "Starting block sync".to_string(), + Listening(_) => "Listening".to_string(), + BlockSyncStarting => "Starting block sync".to_string(), } } @@ -226,13 +221,15 @@ impl StateInfo { impl Display for StateInfo { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { + use StateInfo::*; match self { - Self::StartUp => write!(f, "Node starting up"), - Self::HeaderSync(info) => write!(f, "Synchronizing block headers: {}", info), - Self::HorizonSync(info) => write!(f, "Synchronizing horizon state: {}", info), - Self::BlockSync(info) => write!(f, "Synchronizing blocks: {}", info), - Self::Listening(info) => write!(f, "Listening: {}", info), - Self::BlockSyncStarting => write!(f, "Synchronizing blocks: Starting"), + StartUp => write!(f, "Node starting up"), + HeaderSync(Some(info)) => write!(f, 
"Synchronizing block headers: {}", info), + HeaderSync(None) => write!(f, "Synchronizing block headers: Starting"), + HorizonSync(info) => write!(f, "Synchronizing horizon state: {}", info), + BlockSync(info) => write!(f, "Synchronizing blocks: {}", info), + Listening(info) => write!(f, "Listening: {}", info), + BlockSyncStarting => write!(f, "Synchronizing blocks: Starting"), } } } @@ -282,15 +279,24 @@ impl BlockSyncInfo { sync_peers, } } + + pub fn sync_progress_string(&self) -> String { + format!( + "{}/{} ({:.0}%)", + self.local_height, + self.tip_height, + (self.local_height as f64 / self.tip_height as f64 * 100.0) + ) + } } impl Display for BlockSyncInfo { - fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { - fmt.write_str("Syncing from the following peers: \n")?; + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { + writeln!(f, "Syncing from the following peers:")?; for peer in &self.sync_peers { - fmt.write_str(&format!("{}\n", peer))?; + writeln!(f, "{}", peer)?; } - fmt.write_str(&format!("Syncing {}/{}\n", self.local_height, self.tip_height)) + writeln!(f, "Syncing {}", self.sync_progress_string()) } } diff --git a/base_layer/core/src/base_node/state_machine_service/states/header_sync.rs b/base_layer/core/src/base_node/state_machine_service/states/header_sync.rs index 2acfd53206..68f663d71f 100644 --- a/base_layer/core/src/base_node/state_machine_service/states/header_sync.rs +++ b/base_layer/core/src/base_node/state_machine_service/states/header_sync.rs @@ -74,14 +74,15 @@ impl HeaderSync { let status_event_sender = shared.status_event_sender.clone(); let bootstrapped = shared.is_bootstrapped(); - synchronizer.on_progress(move |current_height, remote_tip_height, sync_peers| { - let _ = status_event_sender.broadcast(StatusInfo { + synchronizer.on_progress(move |details, sync_peers| { + let details = details.map(|(current_height, remote_tip_height)| BlockSyncInfo { + tip_height: remote_tip_height, + local_height: 
current_height, + sync_peers: sync_peers.to_vec(), + }); + let _ = status_event_sender.send(StatusInfo { bootstrapped, - state_info: StateInfo::HeaderSync(BlockSyncInfo { - tip_height: remote_tip_height, - local_height: current_height, - sync_peers: sync_peers.to_vec(), - }), + state_info: StateInfo::HeaderSync(details), }); }); diff --git a/base_layer/core/src/base_node/state_machine_service/states/horizon_state_sync.rs b/base_layer/core/src/base_node/state_machine_service/states/horizon_state_sync.rs index 2fe036b19e..c1b8412028 100644 --- a/base_layer/core/src/base_node/state_machine_service/states/horizon_state_sync.rs +++ b/base_layer/core/src/base_node/state_machine_service/states/horizon_state_sync.rs @@ -25,26 +25,27 @@ // TODO: Move the horizon synchronizer to the `sync` module -mod config; - -pub use self::config::HorizonSyncConfig; - -mod error; +use log::*; pub use error::HorizonSyncError; - -mod horizon_state_synchronization; - use horizon_state_synchronization::HorizonStateSynchronization; +use tari_comms::PeerConnection; + +use crate::{base_node::BaseNodeStateMachine, chain_storage::BlockchainBackend, transactions::CryptoFactories}; use super::{ events_and_states::{HorizonSyncInfo, HorizonSyncStatus}, StateEvent, StateInfo, }; -use crate::{base_node::BaseNodeStateMachine, chain_storage::BlockchainBackend, transactions::types::CryptoFactories}; -use log::*; -use tari_comms::PeerConnection; + +pub use self::config::HorizonSyncConfig; + +mod config; + +mod error; + +mod horizon_state_synchronization; const LOG_TARGET: &str = "c::bn::state_machine_service::states::horizon_state_sync"; diff --git a/base_layer/core/src/base_node/state_machine_service/states/horizon_state_sync/horizon_state_synchronization.rs b/base_layer/core/src/base_node/state_machine_service/states/horizon_state_sync/horizon_state_synchronization.rs index 8347bcf521..3511ade206 100644 --- 
a/base_layer/core/src/base_node/state_machine_service/states/horizon_state_sync/horizon_state_synchronization.rs +++ b/base_layer/core/src/base_node/state_machine_service/states/horizon_state_sync/horizon_state_synchronization.rs @@ -39,10 +39,7 @@ use crate::{ SyncUtxosRequest, SyncUtxosResponse, }, - transactions::{ - transaction::{TransactionKernel, TransactionOutput}, - types::{HashDigest, RangeProofService}, - }, + transactions::transaction::{TransactionKernel, TransactionOutput}, }; use croaring::Bitmap; use futures::StreamExt; @@ -51,6 +48,7 @@ use std::{ convert::{TryFrom, TryInto}, sync::Arc, }; +use tari_common_types::types::{HashDigest, RangeProofService}; use tari_comms::PeerConnection; use tari_crypto::{ commitment::HomomorphicCommitment, diff --git a/base_layer/core/src/base_node/state_machine_service/states/listening.rs b/base_layer/core/src/base_node/state_machine_service/states/listening.rs index 0ea8157568..92349c269a 100644 --- a/base_layer/core/src/base_node/state_machine_service/states/listening.rs +++ b/base_layer/core/src/base_node/state_machine_service/states/listening.rs @@ -31,7 +31,6 @@ use crate::{ }, chain_storage::BlockchainBackend, }; -use futures::StreamExt; use log::*; use num_format::{Locale, ToFormattedString}; use serde::{Deserialize, Serialize}; @@ -118,7 +117,8 @@ impl Listening { info!(target: LOG_TARGET, "Listening for chain metadata updates"); shared.set_state_info(StateInfo::Listening(ListeningInfo::new(self.is_synced))); - while let Some(metadata_event) = shared.metadata_event_stream.next().await { + loop { + let metadata_event = shared.metadata_event_stream.recv().await; match metadata_event.as_ref().map(|v| v.deref()) { Ok(ChainMetadataEvent::PeerChainMetadataReceived(peer_metadata_list)) => { let mut peer_metadata_list = peer_metadata_list.clone(); @@ -199,16 +199,16 @@ impl Listening { if !self.is_synced { self.is_synced = true; + shared.set_state_info(StateInfo::Listening(ListeningInfo::new(true))); debug!(target: 
LOG_TARGET, "Initial sync achieved"); } - shared.set_state_info(StateInfo::Listening(ListeningInfo::new(true))); }, - Err(broadcast::RecvError::Lagged(n)) => { + Err(broadcast::error::RecvError::Lagged(n)) => { debug!(target: LOG_TARGET, "Metadata event subscriber lagged by {} item(s)", n); }, - Err(broadcast::RecvError::Closed) => { - // This should never happen because the while loop exits when the stream ends + Err(broadcast::error::RecvError::Closed) => { debug!(target: LOG_TARGET, "Metadata event subscriber closed"); + break; }, } } diff --git a/base_layer/core/src/base_node/state_machine_service/states/waiting.rs b/base_layer/core/src/base_node/state_machine_service/states/waiting.rs index 7ea2e7e2b0..aeaa5ab430 100644 --- a/base_layer/core/src/base_node/state_machine_service/states/waiting.rs +++ b/base_layer/core/src/base_node/state_machine_service/states/waiting.rs @@ -23,7 +23,7 @@ use crate::base_node::state_machine_service::states::{BlockSync, HeaderSync, HorizonStateSync, StateEvent}; use log::info; use std::time::Duration; -use tokio::time::delay_for; +use tokio::time::sleep; const LOG_TARGET: &str = "c::bn::state_machine_service::states::waiting"; @@ -41,7 +41,7 @@ impl Waiting { "The base node has started a WAITING state for {} seconds", self.timeout.as_secs() ); - delay_for(self.timeout).await; + sleep(self.timeout).await; info!( target: LOG_TARGET, "The base node waiting state has completed. 
Resuming normal operations" diff --git a/base_layer/core/src/base_node/sync/header_sync/synchronizer.rs b/base_layer/core/src/base_node/sync/header_sync/synchronizer.rs index b7483ff70e..fdd1ab6d64 100644 --- a/base_layer/core/src/base_node/sync/header_sync/synchronizer.rs +++ b/base_layer/core/src/base_node/sync/header_sync/synchronizer.rs @@ -32,7 +32,6 @@ use crate::{ base_node::{FindChainSplitRequest, SyncHeadersRequest}, }, tari_utilities::{hex::Hex, Hashable}, - transactions::types::HashOutput, validation::ValidationError, }; use futures::{future, stream::FuturesUnordered, StreamExt}; @@ -42,6 +41,7 @@ use std::{ sync::Arc, time::{Duration, Instant}, }; +use tari_common_types::types::HashOutput; use tari_comms::{ connectivity::{ConnectivityError, ConnectivityRequester, ConnectivitySelection}, peer_manager::NodeId, @@ -83,7 +83,7 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { } pub fn on_progress(&mut self, hook: H) - where H: FnMut(u64, u64, &[NodeId]) + Send + Sync + 'static { + where H: FnMut(Option<(u64, u64)>, &[NodeId]) + Send + Sync + 'static { self.hooks.add_on_progress_header_hook(hook); } @@ -94,6 +94,7 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { pub async fn synchronize(&mut self) -> Result { debug!(target: LOG_TARGET, "Starting header sync.",); + self.hooks.call_on_progress_header_hooks(None, self.sync_peers); let sync_peers = self.select_sync_peers().await?; info!( target: LOG_TARGET, @@ -261,7 +262,7 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { let latency = client.get_last_request_latency().await?; debug!( target: LOG_TARGET, - "Initiating header sync with peer `{}` (latency = {}ms)", + "Initiating header sync with peer `{}` (sync latency = {}ms)", conn.peer_node_id(), latency.unwrap_or_default().as_millis() ); @@ -272,6 +273,10 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { // We're ahead of this peer, try another peer if possible 
SyncStatus::Ahead => Err(BlockHeaderSyncError::NotInSync), SyncStatus::Lagging(split_info) => { + self.hooks.call_on_progress_header_hooks( + Some((split_info.local_tip_header.height(), split_info.remote_tip_height)), + self.sync_peers, + ); self.synchronize_headers(&peer, &mut client, *split_info).await?; Ok(()) }, @@ -482,46 +487,53 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { ) -> Result<(), BlockHeaderSyncError> { const COMMIT_EVERY_N_HEADERS: usize = 1000; - // Peer returned no more than the max headers. This indicates that there are no further headers to request. - if self.header_validator.valid_headers().len() <= NUM_INITIAL_HEADERS_TO_REQUEST as usize { - debug!(target: LOG_TARGET, "No further headers to download"); - if !self.pending_chain_has_higher_pow(&split_info.local_tip_header)? { - return Err(BlockHeaderSyncError::WeakerChain); - } + let mut has_switched_to_new_chain = false; + let pending_len = self.header_validator.valid_headers().len(); + // Find the hash to start syncing the rest of the headers. + // The expectation cannot fail because there has been at least one valid header returned (checked in + // determine_sync_status) + let (start_header_height, start_header_hash) = self + .header_validator + .current_valid_chain_tip_header() + .map(|h| (h.height(), h.hash().clone())) + .expect("synchronize_headers: expected there to be a valid tip header but it was None"); + + // If we already have a stronger chain at this point, switch over to it. + // just in case we happen to be exactly NUM_INITIAL_HEADERS_TO_REQUEST headers behind. + let has_better_pow = self.pending_chain_has_higher_pow(&split_info.local_tip_header)?; + if has_better_pow { debug!( target: LOG_TARGET, "Remote chain from peer {} has higher PoW. 
Switching", peer ); - // PoW is higher, switching over to the new chain self.switch_to_pending_chain(&split_info).await?; + has_switched_to_new_chain = true; + } + + if pending_len < NUM_INITIAL_HEADERS_TO_REQUEST as usize { + // Peer returned less than the number of requested headers. This indicates that we have all the available + // headers. + debug!(target: LOG_TARGET, "No further headers to download"); + if !has_better_pow { + return Err(BlockHeaderSyncError::WeakerChain); + } return Ok(()); } - // Find the hash to start syncing the rest of the headers. - // The expectation cannot fail because the number of headers has been checked in determine_sync_status - let start_header = - self.header_validator.valid_headers().last().expect( - "synchronize_headers: expected there to be at least one valid pending header but there were none", - ); - debug!( target: LOG_TARGET, - "Download remaining headers starting from header #{} from peer `{}`", - start_header.height(), - peer + "Download remaining headers starting from header #{} from peer `{}`", start_header_height, peer ); let request = SyncHeadersRequest { - start_hash: start_header.hash().clone(), + start_hash: start_header_hash, // To the tip! 
count: 0, }; let mut header_stream = client.sync_headers(request).await?; - debug!(target: LOG_TARGET, "Reading headers from peer `{}`", peer); - - let mut has_switched_to_new_chain = false; + debug!(target: LOG_TARGET, "Reading headers from peer `{}`", peer,); while let Some(header) = header_stream.next().await { let header = BlockHeader::try_from(header?).map_err(BlockHeaderSyncError::ReceivedInvalidHeader)?; @@ -563,7 +575,7 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { } self.hooks - .call_on_progress_header_hooks(current_height, split_info.remote_tip_height, self.sync_peers); + .call_on_progress_header_hooks(Some((current_height, split_info.remote_tip_height)), self.sync_peers); } if !has_switched_to_new_chain { diff --git a/base_layer/core/src/base_node/sync/header_sync/validator.rs b/base_layer/core/src/base_node/sync/header_sync/validator.rs index aff52b80fd..3c11c9a2d5 100644 --- a/base_layer/core/src/base_node/sync/header_sync/validator.rs +++ b/base_layer/core/src/base_node/sync/header_sync/validator.rs @@ -35,7 +35,6 @@ use crate::{ consensus::ConsensusManager, proof_of_work::{randomx_factory::RandomXFactory, PowAlgorithm}, tari_utilities::{epoch_time::EpochTime, hash::Hashable, hex::Hex}, - transactions::types::HashOutput, validation::helpers::{ check_header_timestamp_greater_than_median, check_pow_data, @@ -45,6 +44,7 @@ use crate::{ }; use log::*; use std::cmp::Ordering; +use tari_common_types::types::HashOutput; const LOG_TARGET: &str = "c::bn::header_sync"; @@ -115,6 +115,10 @@ impl BlockHeaderSyncValidator { Ok(()) } + pub fn current_valid_chain_tip_header(&self) -> Option<&ChainHeader> { + self.valid_headers().last() + } + pub fn validate(&mut self, header: BlockHeader) -> Result<(), BlockHeaderSyncError> { let state = self.state(); let expected_height = state.current_height + 1; @@ -283,7 +287,7 @@ mod test { mod initialize_state { use super::*; - #[tokio_macros::test_basic] + #[tokio::test] async fn 
it_initializes_state_to_given_header() { let (mut validator, _, tip) = setup_with_headers(1).await; validator.initialize_state(&tip.header().hash()).await.unwrap(); @@ -295,7 +299,7 @@ mod test { assert_eq!(state.current_height, 1); } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_errors_if_hash_does_not_exist() { let (mut validator, _) = setup(); let start_hash = vec![0; 32]; @@ -308,7 +312,7 @@ mod test { mod validate { use super::*; - #[tokio_macros::test_basic] + #[tokio::test] async fn it_passes_if_headers_are_valid() { let (mut validator, _, tip) = setup_with_headers(1).await; validator.initialize_state(tip.hash()).await.unwrap(); @@ -322,7 +326,7 @@ mod test { assert_eq!(validator.valid_headers().len(), 2); } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_fails_if_height_is_not_serial() { let (mut validator, _, tip) = setup_with_headers(2).await; validator.initialize_state(tip.hash()).await.unwrap(); diff --git a/base_layer/core/src/base_node/sync/hooks.rs b/base_layer/core/src/base_node/sync/hooks.rs index 71f0802926..d1e2628822 100644 --- a/base_layer/core/src/base_node/sync/hooks.rs +++ b/base_layer/core/src/base_node/sync/hooks.rs @@ -28,7 +28,7 @@ use tari_comms::peer_manager::NodeId; #[derive(Default)] pub(super) struct Hooks { - on_progress_header: Vec>, + on_progress_header: Vec, &[NodeId]) + Send + Sync>>, on_progress_block: Vec, u64, &[NodeId]) + Send + Sync>>, on_complete: Vec) + Send + Sync>>, on_rewind: Vec>) + Send + Sync>>, @@ -36,14 +36,14 @@ pub(super) struct Hooks { impl Hooks { pub fn add_on_progress_header_hook(&mut self, hook: H) - where H: FnMut(u64, u64, &[NodeId]) + Send + Sync + 'static { + where H: FnMut(Option<(u64, u64)>, &[NodeId]) + Send + Sync + 'static { self.on_progress_header.push(Box::new(hook)); } - pub fn call_on_progress_header_hooks(&mut self, height: u64, remote_tip_height: u64, sync_peers: &[NodeId]) { + pub fn call_on_progress_header_hooks(&mut self, height_vs_remote: Option<(u64, u64)>, 
sync_peers: &[NodeId]) { self.on_progress_header .iter_mut() - .for_each(|f| (*f)(height, remote_tip_height, sync_peers)); + .for_each(|f| (*f)(height_vs_remote, sync_peers)); } pub fn add_on_progress_block_hook<H>(&mut self, hook: H) diff --git a/base_layer/core/src/base_node/sync/rpc/service.rs b/base_layer/core/src/base_node/sync/rpc/service.rs index 776c2b7e01..e9df7073a2 100644 --- a/base_layer/core/src/base_node/sync/rpc/service.rs +++ b/base_layer/core/src/base_node/sync/rpc/service.rs @@ -35,12 +35,14 @@ use crate::{ SyncUtxosResponse, }, }; -use futures::{channel::mpsc, stream, SinkExt}; use log::*; use std::cmp; -use tari_comms::protocol::rpc::{Request, Response, RpcStatus, Streaming}; +use tari_comms::{ + protocol::rpc::{Request, Response, RpcStatus, Streaming}, + utils, +}; use tari_crypto::tari_utilities::hex::Hex; -use tokio::task; +use tokio::{sync::mpsc, task}; use tracing::{instrument, span, Instrument, Level}; const LOG_TARGET: &str = "c::base_node::sync_rpc"; @@ -116,7 +118,7 @@ impl BaseNodeSyncService for BaseNodeSyncRpcServ // Number of blocks to load and push to the stream before loading the next batch const BATCH_SIZE: usize = 4; - let (mut tx, rx) = mpsc::channel(BATCH_SIZE); + let (tx, rx) = mpsc::channel(BATCH_SIZE); let span = span!(Level::TRACE, "sync_rpc::block_sync::inner_worker"); task::spawn( @@ -138,19 +140,16 @@ impl BaseNodeSyncService for BaseNodeSyncRpcServ break; }, Ok(blocks) => { - let mut blocks = stream::iter( - blocks - .into_iter() - .map(|hb| hb.try_into_block().map_err(RpcStatus::log_internal_error(LOG_TARGET))) - .map(|block| match block { - Ok(b) => Ok(proto::base_node::BlockBodyResponse::from(b)), - Err(err) => Err(err), - }) - .map(Ok), - ); + let blocks = blocks + .into_iter() + .map(|hb| hb.try_into_block().map_err(RpcStatus::log_internal_error(LOG_TARGET))) + .map(|block| match block { + Ok(b) => Ok(proto::base_node::BlockBodyResponse::from(b)), + Err(err) => Err(err), + }); // Ensure task stops if the peer
prematurely stops their RPC session - if tx.send_all(&mut blocks).await.is_err() { + if utils::mpsc::send_all(&tx, blocks).await.is_err() { break; } }, @@ -209,7 +208,7 @@ impl BaseNodeSyncService for BaseNodeSyncRpcServ chunk_size ); - let (mut tx, rx) = mpsc::channel(chunk_size); + let (tx, rx) = mpsc::channel(chunk_size); let span = span!(Level::TRACE, "sync_rpc::sync_headers::inner_worker"); task::spawn( async move { @@ -233,10 +232,9 @@ impl BaseNodeSyncService for BaseNodeSyncRpcServ break; }, Ok(headers) => { - let mut headers = - stream::iter(headers.into_iter().map(proto::core::BlockHeader::from).map(Ok).map(Ok)); + let headers = headers.into_iter().map(proto::core::BlockHeader::from).map(Ok); // Ensure task stops if the peer prematurely stops their RPC session - if tx.send_all(&mut headers).await.is_err() { + if utils::mpsc::send_all(&tx, headers).await.is_err() { break; } }, @@ -354,7 +352,7 @@ impl BaseNodeSyncService for BaseNodeSyncRpcServ ) -> Result<Streaming<proto::types::TransactionKernel>, RpcStatus> { let req = request.into_message(); const BATCH_SIZE: usize = 1000; - let (mut tx, rx) = mpsc::channel(BATCH_SIZE); + let (tx, rx) = mpsc::channel(BATCH_SIZE); let db = self.db(); task::spawn(async move { @@ -394,15 +392,9 @@ impl BaseNodeSyncService for BaseNodeSyncRpcServ break; }, Ok(kernels) => { - let mut kernels = stream::iter( - kernels - .into_iter() - .map(proto::types::TransactionKernel::from) - .map(Ok) - .map(Ok), - ); + let kernels = kernels.into_iter().map(proto::types::TransactionKernel::from).map(Ok); // Ensure task stops if the peer prematurely stops their RPC session - if tx.send_all(&mut kernels).await.is_err() { + if utils::mpsc::send_all(&tx, kernels).await.is_err() { break; } }, diff --git a/base_layer/core/src/base_node/sync/rpc/sync_utxos_task.rs b/base_layer/core/src/base_node/sync/rpc/sync_utxos_task.rs index ef10b41c2f..8064aaf458 100644 --- a/base_layer/core/src/base_node/sync/rpc/sync_utxos_task.rs +++ b/base_layer/core/src/base_node/sync/rpc/sync_utxos_task.rs @@
-25,11 +25,11 @@ use crate::{ proto, proto::base_node::{SyncUtxo, SyncUtxosRequest, SyncUtxosResponse}, }; -use futures::{channel::mpsc, stream, SinkExt}; use log::*; use std::{cmp, sync::Arc, time::Instant}; -use tari_comms::protocol::rpc::RpcStatus; +use tari_comms::{protocol::rpc::RpcStatus, utils}; use tari_crypto::tari_utilities::{hex::Hex, Hashable}; +use tokio::sync::mpsc; const LOG_TARGET: &str = "c::base_node::sync_rpc::sync_utxo_task"; @@ -147,8 +147,7 @@ where B: BlockchainBackend + 'static utxos.len(), deleted_diff.cardinality(), ); - let mut utxos = stream::iter( - utxos + let utxos = utxos .into_iter() .enumerate() // Only include pruned UTXOs if include_pruned_utxos is true @@ -161,12 +160,10 @@ where B: BlockchainBackend + 'static mmr_index: start + i as u64, } }) - .map(Ok) - .map(Ok), - ); + .map(Ok); // Ensure task stops if the peer prematurely stops their RPC session - if tx.send_all(&mut utxos).await.is_err() { + if utils::mpsc::send_all(&tx, utxos).await.is_err() { break; } diff --git a/base_layer/core/src/base_node/sync/rpc/tests.rs b/base_layer/core/src/base_node/sync/rpc/tests.rs index 35611a9dda..61adc1aa3c 100644 --- a/base_layer/core/src/base_node/sync/rpc/tests.rs +++ b/base_layer/core/src/base_node/sync/rpc/tests.rs @@ -89,7 +89,7 @@ mod sync_blocks { use tari_comms::protocol::rpc::RpcStatusCode; use tari_test_utils::unpack_enum; - #[tokio_macros::test_basic] + #[tokio::test] async fn it_returns_not_found_if_unknown_hash() { let mut backend = create_mock_backend(); backend.expect_fetch().times(1).returning(|_| Ok(None)); @@ -103,7 +103,7 @@ mod sync_blocks { unpack_enum!(RpcStatusCode::NotFound = err.status_code()); } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_sends_an_empty_response() { let mut backend = create_mock_backend(); @@ -136,7 +136,7 @@ mod sync_blocks { assert!(streaming.next().await.is_none()); } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_streams_blocks_until_end() { let mut backend = 
create_mock_backend(); diff --git a/base_layer/core/src/base_node/sync/validators.rs b/base_layer/core/src/base_node/sync/validators.rs index d65af9a972..e5282cc604 100644 --- a/base_layer/core/src/base_node/sync/validators.rs +++ b/base_layer/core/src/base_node/sync/validators.rs @@ -20,10 +20,12 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use std::{fmt, sync::Arc}; + use crate::{ chain_storage::BlockchainBackend, consensus::ConsensusManager, - transactions::types::CryptoFactories, + transactions::CryptoFactories, validation::{ block_validators::BlockValidator, CandidateBlockBodyValidation, @@ -31,7 +33,6 @@ use crate::{ FinalHorizonStateValidation, }, }; -use std::{fmt, sync::Arc}; #[derive(Clone)] pub struct SyncValidators { @@ -51,9 +52,13 @@ impl SyncValidators { } } - pub fn full_consensus(rules: ConsensusManager, factories: CryptoFactories) -> Self { + pub fn full_consensus( + rules: ConsensusManager, + factories: CryptoFactories, + bypass_range_proof_verification: bool, + ) -> Self { Self::new( - BlockValidator::new(rules.clone(), factories.clone()), + BlockValidator::new(rules.clone(), bypass_range_proof_verification, factories.clone()), ChainBalanceValidator::<B>::new(rules, factories), ) } diff --git a/base_layer/core/src/blocks/block.rs b/base_layer/core/src/blocks/block.rs index 87124dbece..04cded3ada 100644 --- a/base_layer/core/src/blocks/block.rs +++ b/base_layer/core/src/blocks/block.rs @@ -23,6 +23,18 @@ // Portions of this file were originally copyrighted (c) 2018 The Grin Developers, issued under the Apache License, // Version 2.0, available at http://www.apache.org/licenses/LICENSE-2.0.
+use std::{ + fmt, + fmt::{Display, Formatter}, +}; + +use log::*; +use serde::{Deserialize, Serialize}; +use tari_crypto::tari_utilities::Hashable; +use thiserror::Error; + +use tari_common_types::types::BlockHash; + use crate::{ blocks::BlockHeader, chain_storage::MmrTree, @@ -33,18 +45,9 @@ use crate::{ aggregated_body::AggregateBody, tari_amount::MicroTari, transaction::{Transaction, TransactionError, TransactionInput, TransactionKernel, TransactionOutput}, - types::CryptoFactories, + CryptoFactories, }, }; -use log::*; -use serde::{Deserialize, Serialize}; -use std::{ - fmt, - fmt::{Display, Formatter}, -}; -use tari_common_types::types::BlockHash; -use tari_crypto::tari_utilities::Hashable; -use thiserror::Error; #[derive(Clone, Debug, PartialEq, Error)] pub enum BlockValidationError { diff --git a/base_layer/core/src/blocks/block_header.rs b/base_layer/core/src/blocks/block_header.rs index 3346d0945d..fb7aad681e 100644 --- a/base_layer/core/src/blocks/block_header.rs +++ b/base_layer/core/src/blocks/block_header.rs @@ -39,11 +39,7 @@ #[cfg(feature = "base_node")] use crate::blocks::{BlockBuilder, NewBlockHeaderTemplate}; - -use crate::{ - proof_of_work::{PowAlgorithm, PowError, ProofOfWork}, - transactions::types::{BlindingFactor, HashDigest}, -}; +use crate::proof_of_work::{PowAlgorithm, PowError, ProofOfWork}; use chrono::{DateTime, Utc}; use digest::Digest; use serde::{ @@ -57,7 +53,7 @@ use std::{ fmt, fmt::{Display, Error, Formatter}, }; -use tari_common_types::types::{BlockHash, BLOCK_HASH_LENGTH}; +use tari_common_types::types::{BlindingFactor, BlockHash, HashDigest, BLOCK_HASH_LENGTH}; use tari_crypto::tari_utilities::{epoch_time::EpochTime, hex::Hex, ByteArray, Hashable}; use thiserror::Error; diff --git a/base_layer/core/src/blocks/genesis_block.rs b/base_layer/core/src/blocks/genesis_block.rs index a346bea095..861b3e8650 100644 --- a/base_layer/core/src/blocks/genesis_block.rs +++ b/base_layer/core/src/blocks/genesis_block.rs @@ -30,15 +30,13 @@ 
use crate::{ chain_storage::{BlockHeaderAccumulatedData, ChainBlock}, transactions::{ aggregated_body::AggregateBody, - bullet_rangeproofs::BulletRangeProof, tari_amount::MicroTari, transaction::{KernelFeatures, OutputFeatures, OutputFlags, TransactionKernel, TransactionOutput}, - types::{Commitment, PrivateKey, PublicKey, Signature}, }, }; use chrono::DateTime; use std::sync::Arc; -use tari_common_types::types::BLOCK_HASH_LENGTH; +use tari_common_types::types::{BulletRangeProof, Commitment, PrivateKey, PublicKey, Signature, BLOCK_HASH_LENGTH}; use tari_crypto::{ script::TariScript, tari_utilities::{hash::Hashable, hex::*}, @@ -369,10 +367,96 @@ pub fn get_ridcully_genesis_block_raw() -> Block { } } +pub fn get_igor_genesis_block() -> ChainBlock { + // lets get the block + let block = get_igor_genesis_block_raw(); + + let accumulated_data = BlockHeaderAccumulatedData { + hash: block.hash(), + total_kernel_offset: block.header.total_kernel_offset.clone(), + achieved_difficulty: 1.into(), + total_accumulated_difficulty: 1, + accumulated_monero_difficulty: 1.into(), + accumulated_sha_difficulty: 1.into(), + target_difficulty: 1.into(), + }; + ChainBlock::try_construct(Arc::new(block), accumulated_data).unwrap() +} + +#[allow(deprecated)] +pub fn get_igor_genesis_block_raw() -> Block { + let sig = Signature::new( + PublicKey::from_hex("f2139d1cdbcfa670bbb60d4d03d9d50b0a522e674b11280e8064f6dc30e84133").unwrap(), + PrivateKey::from_hex("3ff7522d9a744ebf99c7b6664c0e2c8c64d2a7b902a98b78964766f9f7f2b107").unwrap(), + ); + let mut body = AggregateBody::new( + vec![], + vec![TransactionOutput { + features: OutputFeatures { + flags: OutputFlags::COINBASE_OUTPUT, + maturity: 60, + }, + commitment: Commitment::from_hex( + "fadafb12de96d90042dcbf839985aadb7ae88baa3446d5c6a17937ef2b36783e", + ) + .unwrap(), + proof: 
BulletRangeProof::from_hex("845c947cbf23683f6ff6a56d0aa55fca14a618f7476d4e29348c5cbadf2bb062b8da701a0f058eb69c88492895c3f034db194f6d1b2d29ea83c1a68cbdd19a3f90ae080cfd0315bb20cd05a462c4e06e708b015da1d70c0f87e8c7413b579008e43a6c8dc1edb72b0b67612e897d251ec55798184ff35c80d18262e98034677b73f2dcc7ae25c9119900aadaf04a16068bf57b9e8b9bb694331750dc8acc6102b8961be183419dce2f96c48ced9892e4cdb091dcda0d6a0bb4ed94fc0c63ca065f25ce1e560504d49970bcaac007f33368f15ffa0dd3f56bf799b66fa684fe0fbeb882aee4a6fe05a3ca7c488a6ba22779a42f0f5d875175b8ebc517dd49df20b4f04f027b7d22b7c62cb93727f35c18a0b776d95fac4ff5405d6ed3dbb7613152178cecea4b712aa6e6701804ded71d94cf67de2e86ae401499b39de81b7344185c9eb3bd570ac6121143a690f118d9413abb894729b6b3e057f4771b2c2204285151a56695257992f2b0331f27066270718b37ab472c339d2560c1f6559f3c4ce31ec7f7e2acdbebb1715951d8177283a1ccc2f393ce292956de5db4afde419c0264d5cc4758e6e2c07b730ad43819f3761658d63794cc8071b30f9d7cd622bece4f086b0ca6a04fee888856084543a99848f06334acf48cace58e5ef8c85412017c400b4ec92481ba6d745915aef40531db73d1d84d07d7fce25737629e0fc4ee71e7d505bfd382e362cd1ac03a67c93b8f20cb4285ce240cf1e000d48332ba32e713d6cdf6266449a0a156241f7b1b36753f46f1ecb8b1836625508c5f31bc7ebc1d7cd634272be02cc109bf86983a0591bf00bacea1287233fc12324846398be07d44e8e14bd78cd548415f6de60b5a0c43a84ac29f6a8ac0b1b748dd07a8a4124625e1055b5f5b19da79c319b6e465ca5df0eb70cb4e3dc399891ce90b").unwrap(), + // For genesis block: A default script can never be spent, intentionally + script: TariScript::default(), + // Script offset never checked for coinbase, thus can use default + sender_offset_public_key: Default::default(), + // For genesis block: Metadata signature will never be checked + metadata_signature: Default::default(), + }], + vec![TransactionKernel { + features: KernelFeatures::COINBASE_KERNEL, + fee: MicroTari(0), + lock_height: 0, + excess: Commitment::from_hex( + "f472cc347a1006b7390f9c93b3c62fba334fd99f6c9c1daf9302646cd4781f61", + ) + .unwrap(), + excess_sig: sig, + }], + ); + body.sort(); + 
// set genesis timestamp + let genesis = DateTime::parse_from_rfc2822("27 Aug 2021 06:00:00 +0200").unwrap(); + let timestamp = genesis.timestamp() as u64; + Block { + header: BlockHeader { + version: 0, + height: 0, + prev_hash: vec![0; BLOCK_HASH_LENGTH], + timestamp: timestamp.into(), + output_mr: from_hex("dcc44f39b65e5e1e526887e7d56f7b85e2ea44bd29bc5bc195e6e015d19e1c06").unwrap(), + witness_mr: from_hex("e4d7dab49a66358379a901b9a36c10f070aa9d7bdc8ae752947b6fc4e55d255f").unwrap(), + output_mmr_size: 1, + kernel_mr: from_hex("589bc62ac5d9139f921c68b8075c32d8d130024acaf3196d1d6a89df601e2bcf").unwrap(), + kernel_mmr_size: 1, + input_mr: vec![0; BLOCK_HASH_LENGTH], + total_kernel_offset: PrivateKey::from_hex( + "0000000000000000000000000000000000000000000000000000000000000000", + ) + .unwrap(), + total_script_offset: PrivateKey::from_hex( + "0000000000000000000000000000000000000000000000000000000000000000", + ) + .unwrap(), + nonce: 0, + pow: ProofOfWork { + pow_algo: PowAlgorithm::Sha3, + pow_data: vec![], + }, + }, + body, + } +} + #[cfg(test)] mod test { use super::*; - use crate::transactions::types::CryptoFactories; + use crate::transactions::CryptoFactories; #[test] fn weatherwax_genesis_sanity_check() { diff --git a/base_layer/core/src/blocks/new_blockheader_template.rs b/base_layer/core/src/blocks/new_blockheader_template.rs index 543a22a287..7fc902fdf0 100644 --- a/base_layer/core/src/blocks/new_blockheader_template.rs +++ b/base_layer/core/src/blocks/new_blockheader_template.rs @@ -23,11 +23,10 @@ use crate::{ blocks::block_header::{hash_serializer, BlockHeader}, proof_of_work::ProofOfWork, - transactions::types::BlindingFactor, }; use serde::{Deserialize, Serialize}; use std::fmt::{Display, Error, Formatter}; -use tari_common_types::types::BlockHash; +use tari_common_types::types::{BlindingFactor, BlockHash}; use tari_crypto::tari_utilities::hex::Hex; /// The NewBlockHeaderTemplate is used for the construction of a new mineable block. 
It contains all the metadata for diff --git a/base_layer/core/src/chain_storage/accumulated_data.rs b/base_layer/core/src/chain_storage/accumulated_data.rs index 40f0bac01b..eb6d50e499 100644 --- a/base_layer/core/src/chain_storage/accumulated_data.rs +++ b/base_layer/core/src/chain_storage/accumulated_data.rs @@ -25,10 +25,7 @@ use crate::{ chain_storage::ChainStorageError, proof_of_work::{AchievedTargetDifficulty, Difficulty, PowAlgorithm}, tari_utilities::Hashable, - transactions::{ - aggregated_body::AggregateBody, - types::{BlindingFactor, Commitment, HashOutput}, - }, + transactions::aggregated_body::AggregateBody, }; use croaring::Bitmap; use log::*; @@ -47,6 +44,7 @@ use std::{ fmt::{Display, Formatter}, sync::Arc, }; +use tari_common_types::types::{BlindingFactor, Commitment, HashOutput}; use tari_crypto::tari_utilities::hex::Hex; use tari_mmr::{pruned_hashset::PrunedHashSet, ArrayLike}; diff --git a/base_layer/core/src/chain_storage/async_db.rs b/base_layer/core/src/chain_storage/async_db.rs index 14e4f0e5cd..ca1aea97c5 100644 --- a/base_layer/core/src/chain_storage/async_db.rs +++ b/base_layer/core/src/chain_storage/async_db.rs @@ -42,16 +42,16 @@ use crate::{ common::rolling_vec::RollingVec, proof_of_work::{PowAlgorithm, TargetDifficultyWindow}, tari_utilities::epoch_time::EpochTime, - transactions::{ - transaction::{TransactionKernel, TransactionOutput}, - types::{Commitment, HashOutput, Signature}, - }, + transactions::transaction::{TransactionKernel, TransactionOutput}, }; use croaring::Bitmap; use log::*; use rand::{rngs::OsRng, RngCore}; use std::{mem, ops::RangeBounds, sync::Arc, time::Instant}; -use tari_common_types::{chain_metadata::ChainMetadata, types::BlockHash}; +use tari_common_types::{ + chain_metadata::ChainMetadata, + types::{BlockHash, Commitment, HashOutput, Signature}, +}; use tari_mmr::pruned_hashset::PrunedHashSet; const LOG_TARGET: &str = "c::bn::async_db"; diff --git a/base_layer/core/src/chain_storage/blockchain_backend.rs 
b/base_layer/core/src/chain_storage/blockchain_backend.rs index f5d3b6ad36..505d25dda7 100644 --- a/base_layer/core/src/chain_storage/blockchain_backend.rs +++ b/base_layer/core/src/chain_storage/blockchain_backend.rs @@ -14,13 +14,13 @@ use crate::{ HorizonData, MmrTree, }, - transactions::{ - transaction::{TransactionInput, TransactionKernel}, - types::{Commitment, HashOutput, Signature}, - }, + transactions::transaction::{TransactionInput, TransactionKernel}, }; use croaring::Bitmap; -use tari_common_types::chain_metadata::ChainMetadata; +use tari_common_types::{ + chain_metadata::ChainMetadata, + types::{Commitment, HashOutput, Signature}, +}; use tari_mmr::Hash; /// Identify behaviour for Blockchain database backends. Implementations must support `Send` and `Sync` so that diff --git a/base_layer/core/src/chain_storage/blockchain_database.rs b/base_layer/core/src/chain_storage/blockchain_database.rs index 867e7877e1..f2bf132cf4 100644 --- a/base_layer/core/src/chain_storage/blockchain_database.rs +++ b/base_layer/core/src/chain_storage/blockchain_database.rs @@ -46,10 +46,7 @@ use crate::{ consensus::{chain_strength_comparer::ChainStrengthComparer, ConsensusConstants, ConsensusManager}, proof_of_work::{monero_rx::MoneroPowData, PowAlgorithm, TargetDifficultyWindow}, tari_utilities::epoch_time::EpochTime, - transactions::{ - transaction::TransactionKernel, - types::{Commitment, HashDigest, HashOutput, Signature}, - }, + transactions::transaction::TransactionKernel, validation::{DifficultyCalculator, HeaderValidation, OrphanValidation, PostOrphanBodyValidation, ValidationError}, }; use croaring::Bitmap; @@ -64,7 +61,10 @@ use std::{ sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard}, time::Instant, }; -use tari_common_types::{chain_metadata::ChainMetadata, types::BlockHash}; +use tari_common_types::{ + chain_metadata::ChainMetadata, + types::{BlockHash, Commitment, HashDigest, HashOutput, Signature}, +}; use tari_crypto::tari_utilities::{hex::Hex, ByteArray, 
Hashable}; use tari_mmr::{MerkleMountainRange, MutableMmr}; use uint::static_assertions::_core::ops::RangeBounds; diff --git a/base_layer/core/src/chain_storage/db_transaction.rs b/base_layer/core/src/chain_storage/db_transaction.rs index dfe5947543..0d9736c39b 100644 --- a/base_layer/core/src/chain_storage/db_transaction.rs +++ b/base_layer/core/src/chain_storage/db_transaction.rs @@ -22,10 +22,7 @@ use crate::{ blocks::{Block, BlockHeader}, chain_storage::{error::ChainStorageError, ChainBlock, ChainHeader, MmrTree}, - transactions::{ - transaction::{TransactionKernel, TransactionOutput}, - types::{Commitment, HashOutput}, - }, + transactions::transaction::{TransactionKernel, TransactionOutput}, }; use croaring::Bitmap; use std::{ @@ -33,7 +30,7 @@ use std::{ fmt::{Display, Error, Formatter}, sync::Arc, }; -use tari_common_types::types::BlockHash; +use tari_common_types::types::{BlockHash, Commitment, HashOutput}; use tari_crypto::tari_utilities::{ hex::{to_hex, Hex}, Hashable, diff --git a/base_layer/core/src/chain_storage/historical_block.rs b/base_layer/core/src/chain_storage/historical_block.rs index 1188f7f27d..99fd45335f 100644 --- a/base_layer/core/src/chain_storage/historical_block.rs +++ b/base_layer/core/src/chain_storage/historical_block.rs @@ -23,10 +23,10 @@ use crate::{ blocks::{Block, BlockHeader}, chain_storage::{BlockHeaderAccumulatedData, ChainBlock, ChainStorageError}, - transactions::types::HashOutput, }; use serde::{Deserialize, Serialize}; use std::{fmt, fmt::Display, sync::Arc}; +use tari_common_types::types::HashOutput; use tari_crypto::tari_utilities::hex::Hex; /// The representation of a historical block in the blockchain. 
It is essentially identical to a protocol-defined diff --git a/base_layer/core/src/chain_storage/horizon_data.rs b/base_layer/core/src/chain_storage/horizon_data.rs index 1e6f542142..6213d490f3 100644 --- a/base_layer/core/src/chain_storage/horizon_data.rs +++ b/base_layer/core/src/chain_storage/horizon_data.rs @@ -1,5 +1,3 @@ -use crate::transactions::types::Commitment; - // Copyright 2021. The Tari Project // // Redistribution and use in source and binary forms, with or without modification, are permitted provided that the @@ -22,6 +20,7 @@ use crate::transactions::types::Commitment; // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use serde::{Deserialize, Serialize}; +use tari_common_types::types::Commitment; use tari_crypto::tari_utilities::ByteArray; #[derive(Clone, Debug, Serialize, Deserialize)] diff --git a/base_layer/core/src/chain_storage/lmdb_db/lmdb_db.rs b/base_layer/core/src/chain_storage/lmdb_db/lmdb_db.rs index 069765edbb..e49aee1fed 100644 --- a/base_layer/core/src/chain_storage/lmdb_db/lmdb_db.rs +++ b/base_layer/core/src/chain_storage/lmdb_db/lmdb_db.rs @@ -76,7 +76,6 @@ use crate::{ transactions::{ aggregated_body::AggregateBody, transaction::{TransactionInput, TransactionKernel, TransactionOutput}, - types::{Commitment, HashDigest, HashOutput, Signature}, }, }; use croaring::Bitmap; @@ -87,7 +86,7 @@ use serde::{Deserialize, Serialize}; use std::{convert::TryFrom, fmt, fs, fs::File, ops::Deref, path::Path, sync::Arc, time::Instant}; use tari_common_types::{ chain_metadata::ChainMetadata, - types::{BlockHash, BLOCK_HASH_LENGTH}, + types::{BlockHash, Commitment, HashDigest, HashOutput, Signature, BLOCK_HASH_LENGTH}, }; use tari_crypto::tari_utilities::{hash::Hashable, hex::Hex, ByteArray}; use tari_mmr::{pruned_hashset::PrunedHashSet, Hash, MerkleMountainRange, MutableMmr}; diff --git 
a/base_layer/core/src/chain_storage/lmdb_db/mod.rs b/base_layer/core/src/chain_storage/lmdb_db/mod.rs index 785b0363ee..f97c1c4878 100644 --- a/base_layer/core/src/chain_storage/lmdb_db/mod.rs +++ b/base_layer/core/src/chain_storage/lmdb_db/mod.rs @@ -24,12 +24,10 @@ mod lmdb; #[allow(clippy::module_inception)] mod lmdb_db; -use crate::transactions::{ - transaction::{TransactionInput, TransactionKernel, TransactionOutput}, - types::HashOutput, -}; +use crate::transactions::transaction::{TransactionInput, TransactionKernel, TransactionOutput}; pub use lmdb_db::{create_lmdb_database, create_recovery_lmdb_database, LMDBDatabase}; use serde::{Deserialize, Serialize}; +use tari_common_types::types::HashOutput; pub const LMDB_DB_METADATA: &str = "metadata"; pub const LMDB_DB_HEADERS: &str = "headers"; diff --git a/base_layer/core/src/chain_storage/pruned_output.rs b/base_layer/core/src/chain_storage/pruned_output.rs index 957c0e8c86..8c753f30a5 100644 --- a/base_layer/core/src/chain_storage/pruned_output.rs +++ b/base_layer/core/src/chain_storage/pruned_output.rs @@ -19,7 +19,8 @@ // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-use crate::transactions::{transaction::TransactionOutput, types::HashOutput}; +use crate::transactions::transaction::TransactionOutput; +use tari_common_types::types::HashOutput; #[allow(clippy::large_enum_variant)] #[derive(Debug, PartialEq)] diff --git a/base_layer/core/src/consensus/consensus_constants.rs b/base_layer/core/src/consensus/consensus_constants.rs index 52c0620e8c..85e5027479 100644 --- a/base_layer/core/src/consensus/consensus_constants.rs +++ b/base_layer/core/src/consensus/consensus_constants.rs @@ -356,6 +356,38 @@ impl ConsensusConstants { }] } + pub fn igor() -> Vec<ConsensusConstants> { + let mut algos = HashMap::new(); + // setting sha3/monero to 40/60 split + algos.insert(PowAlgorithm::Sha3, PowAlgorithmConstants { + max_target_time: 1800, + min_difficulty: 60_000_000.into(), + max_difficulty: u64::MAX.into(), + target_time: 300, + }); + algos.insert(PowAlgorithm::Monero, PowAlgorithmConstants { + max_target_time: 1200, + min_difficulty: 60_000.into(), + max_difficulty: u64::MAX.into(), + target_time: 200, + }); + vec![ConsensusConstants { + effective_from_height: 0, + coinbase_lock_height: 6, + blockchain_version: 1, + future_time_limit: 540, + difficulty_block_window: 90, + max_block_transaction_weight: 19500, + median_timestamp_count: 11, + emission_initial: 5_538_846_115 * uT, + emission_decay: &EMISSION_DECAY, + emission_tail: 100.into(), + max_randomx_seed_height: std::u64::MAX, + proof_of_work: algos, + faucet_value: (5000 * 4000) * T, + }] + } + pub fn mainnet() -> Vec<ConsensusConstants> { // Note these values are all placeholders for final values let difficulty_block_window = 90; diff --git a/base_layer/core/src/consensus/consensus_manager.rs b/base_layer/core/src/consensus/consensus_manager.rs index 0663d153ca..5ca1610e03 100644 --- a/base_layer/core/src/consensus/consensus_manager.rs +++ b/base_layer/core/src/consensus/consensus_manager.rs @@ -23,6 +23,7 @@ use crate::{ blocks::{ genesis_block::{ + get_igor_genesis_block, get_mainnet_genesis_block,
get_ridcully_genesis_block, get_stibbons_genesis_block, @@ -82,6 +83,7 @@ impl ConsensusManager { .gen_block .clone() .unwrap_or_else(get_weatherwax_genesis_block), + Network::Igor => get_igor_genesis_block(), } } diff --git a/base_layer/core/src/consensus/network.rs b/base_layer/core/src/consensus/network.rs index 55e17e3c98..0e2e598a0a 100644 --- a/base_layer/core/src/consensus/network.rs +++ b/base_layer/core/src/consensus/network.rs @@ -36,6 +36,7 @@ impl NetworkConsensus { Stibbons => ConsensusConstants::stibbons(), Weatherwax => ConsensusConstants::weatherwax(), LocalNet => ConsensusConstants::localnet(), + Igor => ConsensusConstants::igor(), } } diff --git a/base_layer/core/src/lib.rs b/base_layer/core/src/lib.rs index 2e4bdc2f49..5a93b73576 100644 --- a/base_layer/core/src/lib.rs +++ b/base_layer/core/src/lib.rs @@ -20,7 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -// Needed to make futures::select! work +// Needed to make tokio::select! work #![recursion_limit = "512"] #![feature(shrink_to)] // #![cfg_attr(not(debug_assertions), deny(unused_variables))] diff --git a/base_layer/core/src/mempool/async_mempool.rs b/base_layer/core/src/mempool/async_mempool.rs index 04c4314c26..9a9b0a5f8a 100644 --- a/base_layer/core/src/mempool/async_mempool.rs +++ b/base_layer/core/src/mempool/async_mempool.rs @@ -23,9 +23,10 @@ use crate::{ blocks::Block, mempool::{error::MempoolError, Mempool, StateResponse, StatsResponse, TxStorageResponse}, - transactions::{transaction::Transaction, types::Signature}, + transactions::transaction::Transaction, }; use std::sync::Arc; +use tari_common_types::types::Signature; macro_rules! 
make_async { ($fn:ident($($param1:ident:$ptype1:ty,$param2:ident:$ptype2:ty),+) -> $rtype:ty) => { diff --git a/base_layer/core/src/mempool/mempool.rs b/base_layer/core/src/mempool/mempool.rs index 97b3ceac15..865ca7b980 100644 --- a/base_layer/core/src/mempool/mempool.rs +++ b/base_layer/core/src/mempool/mempool.rs @@ -30,10 +30,11 @@ use crate::{ StatsResponse, TxStorageResponse, }, - transactions::{transaction::Transaction, types::Signature}, + transactions::transaction::Transaction, validation::MempoolTransactionValidation, }; use std::sync::{Arc, RwLock}; +use tari_common_types::types::Signature; /// The Mempool consists of an Unconfirmed Transaction Pool, Pending Pool, Orphan Pool and Reorg Pool and is responsible /// for managing and maintaining all unconfirmed transactions have not yet been included in a block, and transactions diff --git a/base_layer/core/src/mempool/mempool_storage.rs b/base_layer/core/src/mempool/mempool_storage.rs index b5c0a800c7..d2ccb38dbd 100644 --- a/base_layer/core/src/mempool/mempool_storage.rs +++ b/base_layer/core/src/mempool/mempool_storage.rs @@ -31,11 +31,12 @@ use crate::{ StatsResponse, TxStorageResponse, }, - transactions::{transaction::Transaction, types::Signature}, + transactions::transaction::Transaction, validation::{MempoolTransactionValidation, ValidationError}, }; use log::*; use std::sync::Arc; +use tari_common_types::types::Signature; use tari_crypto::tari_utilities::{hex::Hex, Hashable}; pub const LOG_TARGET: &str = "c::mp::mempool_storage"; diff --git a/base_layer/core/src/mempool/mod.rs b/base_layer/core/src/mempool/mod.rs index 1374d4f08e..afe8d5a69c 100644 --- a/base_layer/core/src/mempool/mod.rs +++ b/base_layer/core/src/mempool/mod.rs @@ -72,9 +72,10 @@ mod sync_protocol; #[cfg(feature = "base_node")] pub use sync_protocol::MempoolSyncInitializer; -use crate::transactions::{transaction::Transaction, types::Signature}; +use crate::transactions::transaction::Transaction; use core::fmt::{Display, Error, 
Formatter}; use serde::{Deserialize, Serialize}; +use tari_common_types::types::Signature; use tari_crypto::tari_utilities::hex::Hex; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] diff --git a/base_layer/core/src/mempool/priority/prioritized_transaction.rs b/base_layer/core/src/mempool/priority/prioritized_transaction.rs index 1080536d7b..cc82531461 100644 --- a/base_layer/core/src/mempool/priority/prioritized_transaction.rs +++ b/base_layer/core/src/mempool/priority/prioritized_transaction.rs @@ -20,11 +20,9 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{ - mempool::priority::PriorityError, - transactions::{transaction::Transaction, types::HashOutput}, -}; +use crate::{mempool::priority::PriorityError, transactions::transaction::Transaction}; use std::sync::Arc; +use tari_common_types::types::HashOutput; use tari_crypto::tari_utilities::message_format::MessageFormat; /// Create a unique unspent transaction priority based on the transaction fee, maturity of the oldest input UTXO and the diff --git a/base_layer/core/src/mempool/proto/state_response.rs b/base_layer/core/src/mempool/proto/state_response.rs index 80ab03dd0a..8b3af21ac0 100644 --- a/base_layer/core/src/mempool/proto/state_response.rs +++ b/base_layer/core/src/mempool/proto/state_response.rs @@ -23,10 +23,8 @@ use crate::mempool::{proto::mempool::StateResponse as ProtoStateResponse, StateResponse}; use std::convert::{TryFrom, TryInto}; // use crate::transactions::proto::types::Signature as ProtoSignature; -use crate::{ - mempool::proto::mempool::Signature as ProtoSignature, - transactions::types::{PrivateKey, PublicKey, Signature}, -}; +use crate::mempool::proto::mempool::Signature as ProtoSignature; +use tari_common_types::types::{PrivateKey, PublicKey, Signature}; use tari_crypto::tari_utilities::{ByteArray, ByteArrayError}; 
//---------------------------------- Signature --------------------------------------------// diff --git a/base_layer/core/src/mempool/reorg_pool/reorg_pool.rs b/base_layer/core/src/mempool/reorg_pool/reorg_pool.rs index 5626242474..5e0f12856e 100644 --- a/base_layer/core/src/mempool/reorg_pool/reorg_pool.rs +++ b/base_layer/core/src/mempool/reorg_pool/reorg_pool.rs @@ -26,7 +26,7 @@ use crate::{ consts::{MEMPOOL_REORG_POOL_CACHE_TTL, MEMPOOL_REORG_POOL_STORAGE_CAPACITY}, reorg_pool::{ReorgPoolError, ReorgPoolStorage}, }, - transactions::{transaction::Transaction, types::Signature}, + transactions::transaction::Transaction, }; use serde::{Deserialize, Serialize}; use std::{ @@ -34,6 +34,7 @@ use std::{ time::Duration, }; use tari_common::configuration::seconds; +use tari_common_types::types::Signature; /// Configuration for the ReorgPool #[derive(Clone, Copy, Deserialize, Serialize)] diff --git a/base_layer/core/src/mempool/reorg_pool/reorg_pool_storage.rs b/base_layer/core/src/mempool/reorg_pool/reorg_pool_storage.rs index 71fbb468ff..c178c33545 100644 --- a/base_layer/core/src/mempool/reorg_pool/reorg_pool_storage.rs +++ b/base_layer/core/src/mempool/reorg_pool/reorg_pool_storage.rs @@ -20,13 +20,10 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
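Several hunks above and below repeat the same refactor: shared types such as `Signature`, `HashOutput`, and `PublicKey` move from `crate::transactions::types` into the common `tari_common_types::types` module, and every consumer updates its `use` path. A minimal, self-contained sketch of that pattern (module layout and the `(u64, u64)` stand-in type are hypothetical, not the real Tari definitions):

```rust
mod common_types {
    pub mod types {
        // Stand-in alias for the real Ristretto Schnorr signature type.
        pub type Signature = (u64, u64);
    }
}

mod mempool {
    // Consumers now import shared types from the common module
    // instead of a deep path inside the core crate.
    use super::common_types::types::Signature;

    pub fn first_sig(sigs: &[Signature]) -> Option<&Signature> {
        sigs.first()
    }
}

fn main() {
    let sigs: Vec<common_types::types::Signature> = vec![(1, 2), (3, 4)];
    assert_eq!(mempool::first_sig(&sigs), Some(&(1, 2)));
    println!("ok");
}
```

Centralising the aliases in one crate means a later change of the underlying curve or hash type touches a single definition rather than every importing module.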
-use crate::{ - blocks::Block, - mempool::reorg_pool::reorg_pool::ReorgPoolConfig, - transactions::{transaction::Transaction, types::Signature}, -}; +use crate::{blocks::Block, mempool::reorg_pool::reorg_pool::ReorgPoolConfig, transactions::transaction::Transaction}; use log::*; use std::sync::Arc; +use tari_common_types::types::Signature; use tari_crypto::tari_utilities::hex::Hex; use ttl_cache::TtlCache; diff --git a/base_layer/core/src/mempool/rpc/test.rs b/base_layer/core/src/mempool/rpc/test.rs index 64ad84f122..a9cbb2ee49 100644 --- a/base_layer/core/src/mempool/rpc/test.rs +++ b/base_layer/core/src/mempool/rpc/test.rs @@ -43,7 +43,7 @@ mod get_stats { use super::*; use crate::mempool::{MempoolService, StatsResponse}; - #[tokio_macros::test_basic] + #[tokio::test] async fn it_returns_the_stats() { let (service, mempool, req_mock, _tmpdir) = setup(); let expected_stats = StatsResponse { @@ -66,7 +66,7 @@ mod get_state { use super::*; use crate::mempool::{MempoolService, StateResponse}; - #[tokio_macros::test_basic] + #[tokio::test] async fn it_returns_the_state() { let (service, mempool, req_mock, _tmpdir) = setup(); let expected_state = StateResponse { @@ -94,7 +94,7 @@ mod get_tx_state_by_excess_sig { use tari_crypto::ristretto::{RistrettoPublicKey, RistrettoSecretKey}; use tari_test_utils::unpack_enum; - #[tokio_macros::test_basic] + #[tokio::test] async fn it_returns_the_storage_status() { let (service, mempool, req_mock, _tmpdir) = setup(); let expected = TxStorageResponse::UnconfirmedPool; @@ -116,7 +116,7 @@ mod get_tx_state_by_excess_sig { assert_eq!(mempool.get_call_count(), 1); } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_errors_on_invalid_signature() { let (service, _, req_mock, _tmpdir) = setup(); let status = service @@ -139,7 +139,7 @@ mod submit_transaction { use tari_crypto::ristretto::RistrettoSecretKey; use tari_test_utils::unpack_enum; - #[tokio_macros::test_basic] + #[tokio::test] async fn it_submits_transaction() { let 
(service, mempool, req_mock, _tmpdir) = setup(); let expected = TxStorageResponse::UnconfirmedPool; @@ -166,7 +166,7 @@ mod submit_transaction { assert_eq!(mempool.get_call_count(), 1); } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_errors_on_invalid_transaction() { let (service, _, req_mock, _tmpdir) = setup(); let status = service diff --git a/base_layer/core/src/mempool/service/handle.rs b/base_layer/core/src/mempool/service/handle.rs index 662e411f15..6eebf2b958 100644 --- a/base_layer/core/src/mempool/service/handle.rs +++ b/base_layer/core/src/mempool/service/handle.rs @@ -28,8 +28,9 @@ use crate::{ StatsResponse, TxStorageResponse, }, - transactions::{transaction::Transaction, types::Signature}, + transactions::transaction::Transaction, }; +use tari_common_types::types::Signature; use tari_service_framework::{reply_channel::TrySenderService, Service}; #[derive(Clone)] diff --git a/base_layer/core/src/mempool/service/inbound_handlers.rs b/base_layer/core/src/mempool/service/inbound_handlers.rs index a4f9aa1aee..7f3f90c44a 100644 --- a/base_layer/core/src/mempool/service/inbound_handlers.rs +++ b/base_layer/core/src/mempool/service/inbound_handlers.rs @@ -122,7 +122,7 @@ impl MempoolInboundHandlers { if tx_storage.is_stored() { debug!( target: LOG_TARGET, - "Mempool already has transaction: {}", kernel_excess_sig + "Mempool already has transaction: {}.", kernel_excess_sig ); return Ok(tx_storage); } diff --git a/base_layer/core/src/mempool/service/initializer.rs b/base_layer/core/src/mempool/service/initializer.rs index cb57d58e94..a295daf96a 100644 --- a/base_layer/core/src/mempool/service/initializer.rs +++ b/base_layer/core/src/mempool/service/initializer.rs @@ -37,7 +37,7 @@ use crate::{ proto, transactions::transaction::Transaction, }; -use futures::{channel::mpsc, future, Stream, StreamExt}; +use futures::{Stream, StreamExt}; use log::*; use std::{convert::TryFrom, sync::Arc}; use tari_comms_dht::Dht; @@ -54,7 +54,7 @@ use 
tari_service_framework::{ ServiceInitializer, ServiceInitializerContext, }; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc}; const LOG_TARGET: &str = "c::bn::mempool_service::initializer"; const SUBSCRIPTION_LABEL: &str = "Mempool"; @@ -148,7 +148,7 @@ impl ServiceInitializer for MempoolServiceInitializer { let mempool_handle = MempoolHandle::new(request_sender); context.register_handle(mempool_handle); - let (outbound_tx_sender, outbound_tx_stream) = mpsc::unbounded(); + let (outbound_tx_sender, outbound_tx_stream) = mpsc::unbounded_channel(); let (outbound_request_sender_service, outbound_request_stream) = reply_channel::unbounded(); let (local_request_sender_service, local_request_stream) = reply_channel::unbounded(); let (mempool_state_event_publisher, _) = broadcast::channel(100); @@ -167,7 +167,7 @@ impl ServiceInitializer for MempoolServiceInitializer { context.register_handle(outbound_mp_interface); context.register_handle(local_mp_interface); - context.spawn_when_ready(move |handles| async move { + context.spawn_until_shutdown(move |handles| { let outbound_message_service = handles.expect_handle::().outbound_requester(); let state_machine = handles.expect_handle::(); let base_node = handles.expect_handle::(); @@ -182,11 +182,7 @@ impl ServiceInitializer for MempoolServiceInitializer { block_event_stream: base_node.get_block_event_stream(), request_receiver, }; - let service = - MempoolService::new(outbound_message_service, inbound_handlers, config, state_machine).start(streams); - futures::pin_mut!(service); - future::select(service, handles.get_shutdown_signal()).await; - info!(target: LOG_TARGET, "Mempool Service shutdown"); + MempoolService::new(outbound_message_service, inbound_handlers, config, state_machine).start(streams) }); Ok(()) diff --git a/base_layer/core/src/mempool/service/local_service.rs b/base_layer/core/src/mempool/service/local_service.rs index c58af9912d..05f3ac6779 100644 --- 
a/base_layer/core/src/mempool/service/local_service.rs +++ b/base_layer/core/src/mempool/service/local_service.rs @@ -28,8 +28,9 @@ use crate::{ StatsResponse, TxStorageResponse, }, - transactions::{transaction::Transaction, types::Signature}, + transactions::transaction::Transaction, }; +use tari_common_types::types::Signature; use tari_service_framework::{reply_channel::SenderService, Service}; use tokio::sync::broadcast; @@ -146,7 +147,7 @@ mod test { } } - #[tokio_macros::test] + #[tokio::test] async fn mempool_stats() { let (event_publisher, _) = broadcast::channel(100); let (tx, rx) = unbounded(); @@ -157,7 +158,7 @@ mod test { assert_eq!(stats, request_stats()); } - #[tokio_macros::test] + #[tokio::test] async fn mempool_stats_from_multiple() { let (event_publisher, _) = broadcast::channel(100); let (tx, rx) = unbounded(); diff --git a/base_layer/core/src/mempool/service/outbound_interface.rs b/base_layer/core/src/mempool/service/outbound_interface.rs index 87cda226f3..8606a8ce02 100644 --- a/base_layer/core/src/mempool/service/outbound_interface.rs +++ b/base_layer/core/src/mempool/service/outbound_interface.rs @@ -26,12 +26,13 @@ use crate::{ StatsResponse, TxStorageResponse, }, - transactions::{transaction::Transaction, types::Signature}, + transactions::transaction::Transaction, }; -use futures::channel::mpsc::UnboundedSender; use log::*; +use tari_common_types::types::Signature; use tari_comms::peer_manager::NodeId; use tari_service_framework::{reply_channel::SenderService, Service}; +use tokio::sync::mpsc::UnboundedSender; pub const LOG_TARGET: &str = "c::mp::service::outbound_interface"; @@ -71,15 +72,13 @@ impl OutboundMempoolServiceInterface { transaction: Transaction, exclude_peers: Vec, ) -> Result<(), MempoolServiceError> { - self.tx_sender - .unbounded_send((transaction, exclude_peers)) - .or_else(|e| { - { - error!(target: LOG_TARGET, "Could not broadcast transaction. 
{:?}", e); - Err(e) - } - .map_err(|_| MempoolServiceError::BroadcastFailed) - }) + self.tx_sender.send((transaction, exclude_peers)).or_else(|e| { + { + error!(target: LOG_TARGET, "Could not broadcast transaction. {:?}", e); + Err(e) + } + .map_err(|_| MempoolServiceError::BroadcastFailed) + }) } /// Check if the specified transaction is stored in the mempool of a remote base node. diff --git a/base_layer/core/src/mempool/service/request.rs b/base_layer/core/src/mempool/service/request.rs index 8437b84a18..a6d6910024 100644 --- a/base_layer/core/src/mempool/service/request.rs +++ b/base_layer/core/src/mempool/service/request.rs @@ -20,10 +20,10 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::transactions::{transaction::Transaction, types::Signature}; +use crate::transactions::transaction::Transaction; use core::fmt::{Display, Error, Formatter}; use serde::{Deserialize, Serialize}; -use tari_common_types::waiting_requests::RequestKey; +use tari_common_types::{types::Signature, waiting_requests::RequestKey}; use tari_crypto::tari_utilities::hex::Hex; /// API Request enum for Mempool requests. 
diff --git a/base_layer/core/src/mempool/service/service.rs b/base_layer/core/src/mempool/service/service.rs index b8ee487b9c..137dd9d0a0 100644 --- a/base_layer/core/src/mempool/service/service.rs +++ b/base_layer/core/src/mempool/service/service.rs @@ -38,13 +38,7 @@ use crate::{ proto, transactions::transaction::Transaction, }; -use futures::{ - channel::{mpsc, oneshot::Sender as OneshotSender}, - pin_mut, - stream::StreamExt, - SinkExt, - Stream, -}; +use futures::{pin_mut, stream::StreamExt, Stream}; use log::*; use rand::rngs::OsRng; use std::{convert::TryInto, sync::Arc, time::Duration}; @@ -58,7 +52,10 @@ use tari_comms_dht::{ use tari_crypto::tari_utilities::hex::Hex; use tari_p2p::{domain_message::DomainMessage, tari_message::TariMessageType}; use tari_service_framework::{reply_channel, reply_channel::RequestContext}; -use tokio::task; +use tokio::{ + sync::{mpsc, oneshot::Sender as OneshotSender}, + task, +}; const LOG_TARGET: &str = "c::mempool::service::service"; @@ -118,7 +115,7 @@ impl MempoolService { { let outbound_request_stream = streams.outbound_request_stream.fuse(); pin_mut!(outbound_request_stream); - let mut outbound_tx_stream = streams.outbound_tx_stream.fuse(); + let mut outbound_tx_stream = streams.outbound_tx_stream; let inbound_request_stream = streams.inbound_request_stream.fuse(); pin_mut!(inbound_request_stream); let inbound_response_stream = streams.inbound_response_stream.fuse(); @@ -127,70 +124,70 @@ impl MempoolService { pin_mut!(inbound_transaction_stream); let local_request_stream = streams.local_request_stream.fuse(); pin_mut!(local_request_stream); - let mut block_event_stream = streams.block_event_stream.fuse(); + let mut block_event_stream = streams.block_event_stream; let mut timeout_receiver_stream = self .timeout_receiver_stream .take() - .expect("Mempool Service initialized without timeout_receiver_stream") - .fuse(); + .expect("Mempool Service initialized without timeout_receiver_stream"); let mut request_receiver = 
streams.request_receiver; loop { - futures::select! { + tokio::select! { // Requests sent from the handle - request = request_receiver.select_next_some() => { + Some(request) = request_receiver.next() => { let (request, reply) = request.split(); let _ = reply.send(self.handle_request(request).await); }, // Outbound request messages from the OutboundMempoolServiceInterface - outbound_request_context = outbound_request_stream.select_next_some() => { + Some(outbound_request_context) = outbound_request_stream.next() => { self.spawn_handle_outbound_request(outbound_request_context); }, // Outbound tx messages from the OutboundMempoolServiceInterface - (txn, excluded_peers) = outbound_tx_stream.select_next_some() => { + Some((txn, excluded_peers)) = outbound_tx_stream.recv() => { self.spawn_handle_outbound_tx(txn, excluded_peers); }, // Incoming request messages from the Comms layer - domain_msg = inbound_request_stream.select_next_some() => { + Some(domain_msg) = inbound_request_stream.next() => { self.spawn_handle_incoming_request(domain_msg); }, // Incoming response messages from the Comms layer - domain_msg = inbound_response_stream.select_next_some() => { + Some(domain_msg) = inbound_response_stream.next() => { self.spawn_handle_incoming_response(domain_msg); }, // Incoming transaction messages from the Comms layer - transaction_msg = inbound_transaction_stream.select_next_some() => { + Some(transaction_msg) = inbound_transaction_stream.next() => { self.spawn_handle_incoming_tx(transaction_msg).await; } // Incoming local request messages from the LocalMempoolServiceInterface and other local services - local_request_context = local_request_stream.select_next_some() => { + Some(local_request_context) = local_request_stream.next() => { self.spawn_handle_local_request(local_request_context); }, // Block events from local Base Node. 
- block_event = block_event_stream.select_next_some() => { + block_event = block_event_stream.recv() => { if let Ok(block_event) = block_event { self.spawn_handle_block_event(block_event); } }, // Timeout events for waiting requests - timeout_request_key = timeout_receiver_stream.select_next_some() => { + Some(timeout_request_key) = timeout_receiver_stream.recv() => { self.spawn_handle_request_timeout(timeout_request_key); }, - complete => { + else => { info!(target: LOG_TARGET, "Mempool service shutting down"); break; } } } + Ok(()) } @@ -490,7 +487,7 @@ async fn handle_outbound_tx( exclude_peers: Vec, ) -> Result<(), MempoolServiceError> { let result = outbound_message_service - .propagate( + .flood( NodeDestination::Unknown, OutboundEncryption::ClearText, exclude_peers, @@ -506,9 +503,9 @@ async fn handle_outbound_tx( Ok(()) } -fn spawn_request_timeout(mut timeout_sender: mpsc::Sender, request_key: RequestKey, timeout: Duration) { +fn spawn_request_timeout(timeout_sender: mpsc::Sender, request_key: RequestKey, timeout: Duration) { task::spawn(async move { - tokio::time::delay_for(timeout).await; + tokio::time::sleep(timeout).await; let _ = timeout_sender.send(request_key).await; }); } diff --git a/base_layer/core/src/mempool/sync_protocol/initializer.rs b/base_layer/core/src/mempool/sync_protocol/initializer.rs index df40de648a..9af871121d 100644 --- a/base_layer/core/src/mempool/sync_protocol/initializer.rs +++ b/base_layer/core/src/mempool/sync_protocol/initializer.rs @@ -28,13 +28,17 @@ use crate::{ MempoolServiceConfig, }, }; -use futures::channel::mpsc; +use log::*; +use std::time::Duration; use tari_comms::{ connectivity::ConnectivityRequester, protocol::{ProtocolExtension, ProtocolExtensionContext, ProtocolExtensionError, ProtocolNotification}, Substream, }; use tari_service_framework::{async_trait, ServiceInitializationError, ServiceInitializer, ServiceInitializerContext}; +use tokio::{sync::mpsc, time::sleep}; + +const LOG_TARGET: &str = 
"c::mempool::sync_protocol"; pub struct MempoolSyncInitializer { config: MempoolServiceConfig, @@ -70,17 +74,31 @@ impl ServiceInitializer for MempoolSyncInitializer { let mempool = self.mempool.clone(); let notif_rx = self.notif_rx.take().unwrap(); - context.spawn_when_ready(move |handles| { + context.spawn_until_shutdown(move |handles| async move { let state_machine = handles.expect_handle::(); let connectivity = handles.expect_handle::(); - MempoolSyncProtocol::new( - config, - notif_rx, - connectivity.get_event_subscription(), - mempool, - Some(state_machine), - ) - .run() + // Ensure that we get an subscription ASAP so that we don't miss any connectivity events + let connectivity_event_subscription = connectivity.get_event_subscription(); + + let mut status_watch = state_machine.get_status_info_watch(); + if !status_watch.borrow().bootstrapped { + debug!(target: LOG_TARGET, "Waiting for node to bootstrap..."); + while status_watch.changed().await.is_ok() { + if status_watch.borrow().bootstrapped { + debug!(target: LOG_TARGET, "Node bootstrapped. 
Starting mempool sync protocol"); + break; + } + trace!( + target: LOG_TARGET, + "Mempool sync still on hold, waiting for bootstrap to finish", + ); + sleep(Duration::from_secs(1)).await; + } + } + + MempoolSyncProtocol::new(config, notif_rx, connectivity_event_subscription, mempool) + .run() + .await; }); Ok(()) diff --git a/base_layer/core/src/mempool/sync_protocol/mod.rs b/base_layer/core/src/mempool/sync_protocol/mod.rs index 08219f2a8d..dc133f744d 100644 --- a/base_layer/core/src/mempool/sync_protocol/mod.rs +++ b/base_layer/core/src/mempool/sync_protocol/mod.rs @@ -73,12 +73,11 @@ mod initializer; pub use initializer::MempoolSyncInitializer; use crate::{ - base_node::StateMachineHandle, mempool::{async_mempool, proto, Mempool, MempoolServiceConfig}, proto as shared_proto, transactions::transaction::Transaction, }; -use futures::{stream, stream::Fuse, AsyncRead, AsyncWrite, SinkExt, Stream, StreamExt}; +use futures::{stream, SinkExt, Stream, StreamExt}; use log::*; use prost::Message; use std::{ @@ -88,7 +87,6 @@ use std::{ atomic::{AtomicUsize, Ordering}, Arc, }, - time::Duration, }; use tari_comms::{ connectivity::{ConnectivityEvent, ConnectivityEventRx}, @@ -101,7 +99,11 @@ use tari_comms::{ PeerConnection, }; use tari_crypto::tari_utilities::{hex::Hex, ByteArray}; -use tokio::{sync::Semaphore, task}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::Semaphore, + task, +}; const MAX_FRAME_SIZE: usize = 3 * 1024 * 1024; // 3 MiB const LOG_TARGET: &str = "c::mempool::sync_protocol"; @@ -111,11 +113,10 @@ pub static MEMPOOL_SYNC_PROTOCOL: Bytes = Bytes::from_static(b"t/mempool-sync/1" pub struct MempoolSyncProtocol { config: MempoolServiceConfig, protocol_notifier: ProtocolNotificationRx, - connectivity_events: Fuse, + connectivity_events: ConnectivityEventRx, mempool: Mempool, num_synched: Arc, permits: Arc, - state_machine: Option, } impl MempoolSyncProtocol @@ -126,54 +127,34 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static 
protocol_notifier: ProtocolNotificationRx, connectivity_events: ConnectivityEventRx, mempool: Mempool, - state_machine: Option, ) -> Self { Self { config, protocol_notifier, - connectivity_events: connectivity_events.fuse(), + connectivity_events, mempool, num_synched: Arc::new(AtomicUsize::new(0)), permits: Arc::new(Semaphore::new(1)), - state_machine, } } pub async fn run(mut self) { info!(target: LOG_TARGET, "Mempool protocol handler has started"); - while let Some(ref v) = self.state_machine { - let status_watch = v.get_status_info_watch(); - if (*status_watch.borrow()).bootstrapped { - break; - } - trace!( - target: LOG_TARGET, - "Mempool sync still on hold, waiting for bootstrap to finish", - ); - tokio::time::delay_for(Duration::from_secs(1)).await; - } + loop { - futures::select! { - event = self.connectivity_events.select_next_some() => { - if let Ok(event) = event { - self.handle_connectivity_event(&*event).await; - } + tokio::select! { + Ok(event) = self.connectivity_events.recv() => { + self.handle_connectivity_event(event).await; }, - notif = self.protocol_notifier.select_next_some() => { + Some(notif) = self.protocol_notifier.recv() => { self.handle_protocol_notification(notif); } - - // protocol_notifier and connectivity_events are closed - complete => { - info!(target: LOG_TARGET, "Mempool protocol handler is shutting down"); - break; - } } } } - async fn handle_connectivity_event(&mut self, event: &ConnectivityEvent) { + async fn handle_connectivity_event(&mut self, event: ConnectivityEvent) { match event { // If this node is connecting to a peer ConnectivityEvent::PeerConnected(conn) if conn.direction().is_outbound() => { diff --git a/base_layer/core/src/mempool/sync_protocol/test.rs b/base_layer/core/src/mempool/sync_protocol/test.rs index dd77fe3c70..30f4b096d2 100644 --- a/base_layer/core/src/mempool/sync_protocol/test.rs +++ b/base_layer/core/src/mempool/sync_protocol/test.rs @@ -30,7 +30,7 @@ use crate::{ transactions::{helpers::create_tx, 
tari_amount::uT, transaction::Transaction}, validation::mocks::MockValidator, }; -use futures::{channel::mpsc, Sink, SinkExt, Stream, StreamExt}; +use futures::{Sink, SinkExt, Stream, StreamExt}; use std::{fmt, io, iter::repeat_with, sync::Arc}; use tari_comms::{ connectivity::{ConnectivityEvent, ConnectivityEventTx}, @@ -44,7 +44,10 @@ use tari_comms::{ BytesMut, }; use tari_crypto::tari_utilities::ByteArray; -use tokio::{sync::broadcast, task}; +use tokio::{ + sync::{broadcast, mpsc}, + task, +}; pub fn create_transactions(n: usize) -> Vec { repeat_with(|| { @@ -82,7 +85,6 @@ fn setup( protocol_notif_rx, connectivity_events_rx, mempool.clone(), - None, ); task::spawn(protocol.run()); @@ -90,7 +92,7 @@ fn setup( (protocol_notif_tx, connectivity_events_tx, mempool, transactions) } -#[tokio_macros::test_basic] +#[tokio::test] async fn empty_set() { let (_, connectivity_events_tx, mempool1, _) = setup(0); @@ -101,7 +103,7 @@ async fn empty_set() { // This node connected to a peer, so it should open the substream connectivity_events_tx - .send(Arc::new(ConnectivityEvent::PeerConnected(node2_conn))) + .send(ConnectivityEvent::PeerConnected(node2_conn)) .unwrap(); let substream = node1_mock.next_incoming_substream().await.unwrap(); @@ -120,7 +122,7 @@ async fn empty_set() { assert_eq!(transactions.len(), 0); } -#[tokio_macros::test_basic] +#[tokio::test] async fn synchronise() { let (_, connectivity_events_tx, mempool1, transactions1) = setup(5); @@ -131,7 +133,7 @@ async fn synchronise() { // This node connected to a peer, so it should open the substream connectivity_events_tx - .send(Arc::new(ConnectivityEvent::PeerConnected(node2_conn))) + .send(ConnectivityEvent::PeerConnected(node2_conn)) .unwrap(); let substream = node1_mock.next_incoming_substream().await.unwrap(); @@ -154,7 +156,7 @@ async fn synchronise() { assert!(transactions2.iter().all(|txn| transactions.contains(&txn))); } -#[tokio_macros::test_basic] +#[tokio::test] async fn duplicate_set() { let (_, 
connectivity_events_tx, mempool1, transactions1) = setup(2); @@ -165,7 +167,7 @@ async fn duplicate_set() { // This node connected to a peer, so it should open the substream connectivity_events_tx - .send(Arc::new(ConnectivityEvent::PeerConnected(node2_conn))) + .send(ConnectivityEvent::PeerConnected(node2_conn)) .unwrap(); let substream = node1_mock.next_incoming_substream().await.unwrap(); @@ -189,9 +191,9 @@ async fn duplicate_set() { assert!(transactions2.iter().all(|txn| transactions.contains(&txn))); } -#[tokio_macros::test_basic] +#[tokio::test] async fn responder() { - let (mut protocol_notif, _, _, transactions1) = setup(2); + let (protocol_notif, _, _, transactions1) = setup(2); let node1 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); let node2 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -225,9 +227,9 @@ async fn responder() { // this. } -#[tokio_macros::test_basic] +#[tokio::test] async fn initiator_messages() { - let (mut protocol_notif, _, _, transactions1) = setup(2); + let (protocol_notif, _, _, transactions1) = setup(2); let node1 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -260,7 +262,7 @@ async fn initiator_messages() { assert_eq!(indexes.indexes, [0, 1]); } -#[tokio_macros::test_basic] +#[tokio::test] async fn responder_messages() { let (_, connectivity_events_tx, _, transactions1) = setup(1); @@ -271,7 +273,7 @@ async fn responder_messages() { // This node connected to a peer, so it should open the substream connectivity_events_tx - .send(Arc::new(ConnectivityEvent::PeerConnected(node2_conn))) + .send(ConnectivityEvent::PeerConnected(node2_conn)) .unwrap(); let substream = node1_mock.next_incoming_substream().await.unwrap(); diff --git a/base_layer/core/src/mempool/unconfirmed_pool/unconfirmed_pool.rs b/base_layer/core/src/mempool/unconfirmed_pool/unconfirmed_pool.rs index db4c1e95ec..0a9d7a4e64 100644 --- a/base_layer/core/src/mempool/unconfirmed_pool/unconfirmed_pool.rs +++ 
b/base_layer/core/src/mempool/unconfirmed_pool/unconfirmed_pool.rs @@ -20,6 +20,15 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use std::{ + collections::{BTreeMap, HashMap}, + sync::Arc, +}; + +use log::*; +use serde::{Deserialize, Serialize}; +use tari_crypto::tari_utilities::{hex::Hex, Hashable}; + use crate::{ blocks::Block, mempool::{ @@ -27,18 +36,9 @@ use crate::{ priority::{FeePriority, PrioritizedTransaction}, unconfirmed_pool::UnconfirmedPoolError, }, - transactions::{ - transaction::Transaction, - types::{HashOutput, Signature}, - }, + transactions::transaction::Transaction, }; -use log::*; -use serde::{Deserialize, Serialize}; -use std::{ - collections::{BTreeMap, HashMap}, - sync::Arc, -}; -use tari_crypto::tari_utilities::{hex::Hex, Hashable}; +use tari_common_types::types::{HashOutput, Signature}; pub const LOG_TARGET: &str = "c::mp::unconfirmed_pool::unconfirmed_pool_storage"; @@ -474,7 +474,8 @@ impl UnconfirmedPool { #[cfg(test)] mod test { - use super::*; + use tari_common::configuration::Network; + use crate::{ consensus::ConsensusManagerBuilder, test_helpers::create_orphan_block, @@ -483,12 +484,14 @@ mod test { helpers::{TestParams, UtxoTestParams}, tari_amount::MicroTari, transaction::KernelFeatures, - types::{CryptoFactories, HashDigest}, + CryptoFactories, SenderTransactionProtocol, }, tx, }; - use tari_common::configuration::Network; + use tari_common_types::types::HashDigest; + + use super::*; #[test] fn test_find_duplicate_input() { diff --git a/base_layer/core/src/proto/block.rs b/base_layer/core/src/proto/block.rs index 94a2f7fd20..50778f5a92 100644 --- a/base_layer/core/src/proto/block.rs +++ b/base_layer/core/src/proto/block.rs @@ -25,10 +25,9 @@ use crate::{ blocks::{Block, NewBlock, NewBlockHeaderTemplate, NewBlockTemplate}, chain_storage::{BlockHeaderAccumulatedData, 
HistoricalBlock}, proof_of_work::ProofOfWork, - transactions::types::BlindingFactor, }; use std::convert::{TryFrom, TryInto}; -use tari_common_types::types::BLOCK_HASH_LENGTH; +use tari_common_types::types::{BlindingFactor, BLOCK_HASH_LENGTH}; use tari_crypto::tari_utilities::ByteArray; //---------------------------------- Block --------------------------------------------// diff --git a/base_layer/core/src/proto/block_header.rs b/base_layer/core/src/proto/block_header.rs index 4258836106..a2ac77689e 100644 --- a/base_layer/core/src/proto/block_header.rs +++ b/base_layer/core/src/proto/block_header.rs @@ -25,9 +25,9 @@ use crate::{ blocks::BlockHeader, proof_of_work::{PowAlgorithm, ProofOfWork}, proto::utils::{datetime_to_timestamp, timestamp_to_datetime}, - transactions::types::BlindingFactor, }; use std::convert::TryFrom; +use tari_common_types::types::BlindingFactor; use tari_crypto::tari_utilities::ByteArray; //---------------------------------- BlockHeader --------------------------------------------// diff --git a/base_layer/core/src/proto/transaction.rs b/base_layer/core/src/proto/transaction.rs index cb96c11c66..d500157360 100644 --- a/base_layer/core/src/proto/transaction.rs +++ b/base_layer/core/src/proto/transaction.rs @@ -27,7 +27,6 @@ use crate::{ tari_utilities::convert::try_convert_all, transactions::{ aggregated_body::AggregateBody, - bullet_rangeproofs::BulletRangeProof, tari_amount::MicroTari, transaction::{ KernelFeatures, @@ -38,10 +37,10 @@ use crate::{ TransactionKernel, TransactionOutput, }, - types::{BlindingFactor, Commitment, PublicKey}, }, }; use std::convert::{TryFrom, TryInto}; +use tari_common_types::types::{BlindingFactor, BulletRangeProof, Commitment, PublicKey}; use tari_crypto::{ script::{ExecutionStack, TariScript}, tari_utilities::{ByteArray, ByteArrayError}, diff --git a/base_layer/core/src/proto/types_impls.rs b/base_layer/core/src/proto/types_impls.rs index e978d86724..8e865345c1 100644 --- 
a/base_layer/core/src/proto/types_impls.rs +++ b/base_layer/core/src/proto/types_impls.rs @@ -21,7 +21,8 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use super::types as proto; -use crate::transactions::types::{ +use std::convert::TryFrom; +use tari_common_types::types::{ BlindingFactor, ComSignature, Commitment, @@ -30,7 +31,6 @@ use crate::transactions::types::{ PublicKey, Signature, }; -use std::convert::TryFrom; use tari_crypto::tari_utilities::{ByteArray, ByteArrayError}; //---------------------------------- Commitment --------------------------------------------// diff --git a/base_layer/core/src/test_helpers/blockchain.rs b/base_layer/core/src/test_helpers/blockchain.rs index e871ed10c9..d52df6deb7 100644 --- a/base_layer/core/src/test_helpers/blockchain.rs +++ b/base_layer/core/src/test_helpers/blockchain.rs @@ -20,6 +20,19 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
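Earlier in this patch, the sync-protocol initializer waits on a tokio `watch` channel until the node reports `bootstrapped == true` before starting mempool sync. A std-only analogue of that wait-for-condition handshake, using `Mutex` + `Condvar` in place of the watch channel (structure illustrative, not the Tari code):

```rust
use std::sync::{Arc, Condvar, Mutex};
use std::thread;

fn main() {
    // Shared flag plus condition variable: the "watch" side sets the flag
    // and notifies; the waiting side loops until it observes `true`.
    let state = Arc::new((Mutex::new(false), Condvar::new()));
    let state2 = Arc::clone(&state);

    let bootstrapper = thread::spawn(move || {
        let (lock, cvar) = &*state2;
        *lock.lock().unwrap() = true; // node finishes bootstrapping
        cvar.notify_all();
    });

    let (lock, cvar) = &*state;
    let mut bootstrapped = lock.lock().unwrap();
    while !*bootstrapped {
        // wait() releases the lock while parked, so the notifier can run.
        bootstrapped = cvar.wait(bootstrapped).unwrap();
    }
    assert!(*bootstrapped);
    drop(bootstrapped);
    bootstrapper.join().unwrap();
    println!("bootstrapped; sync can start");
}
```

The patch's version is async rather than blocking, but the loop-until-observed shape is the same: re-check the condition after every wakeup rather than trusting a single notification.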
+use std::{ + fs, + ops::Deref, + path::{Path, PathBuf}, +}; + +use croaring::Bitmap; + +use tari_common::configuration::Network; +use tari_common_types::chain_metadata::ChainMetadata; +use tari_storage::lmdb_store::LMDBConfig; +use tari_test_utils::paths::create_temporary_data_path; + use crate::{ blocks::{genesis_block::get_weatherwax_genesis_block, Block, BlockHeader}, chain_storage::{ @@ -45,7 +58,7 @@ use crate::{ consensus::{chain_strength_comparer::ChainStrengthComparerBuilder, ConsensusConstantsBuilder, ConsensusManager}, transactions::{ transaction::{TransactionInput, TransactionKernel}, - types::{Commitment, CryptoFactories, HashOutput, Signature}, + CryptoFactories, }, validation::{ block_validators::{BodyOnlyValidator, OrphanBlockValidator}, @@ -53,16 +66,7 @@ use crate::{ DifficultyCalculator, }, }; -use croaring::Bitmap; -use std::{ - fs, - ops::Deref, - path::{Path, PathBuf}, -}; -use tari_common::configuration::Network; -use tari_common_types::chain_metadata::ChainMetadata; -use tari_storage::lmdb_store::LMDBConfig; -use tari_test_utils::paths::create_temporary_data_path; +use tari_common_types::types::{Commitment, HashOutput, Signature}; /// Create a new blockchain database containing no blocks. pub fn create_new_blockchain() -> BlockchainDatabase { @@ -111,7 +115,7 @@ pub fn create_store_with_consensus(rules: ConsensusManager) -> BlockchainDatabas let validators = Validators::new( BodyOnlyValidator::default(), MockValidator::new(true), - OrphanBlockValidator::new(rules.clone(), factories), + OrphanBlockValidator::new(rules.clone(), false, factories), ); create_store_with_consensus_and_validators(rules, validators) } diff --git a/base_layer/core/src/test_helpers/mod.rs b/base_layer/core/src/test_helpers/mod.rs index a1055b75da..bbcb23da67 100644 --- a/base_layer/core/src/test_helpers/mod.rs +++ b/base_layer/core/src/test_helpers/mod.rs @@ -23,7 +23,13 @@ //! 
Common test helper functions that are small and useful enough to be included in the main crate, rather than the //! integration test folder. -pub mod blockchain; +use std::{iter, path::Path, sync::Arc}; + +use rand::{distributions::Alphanumeric, Rng}; + +use tari_common::configuration::Network; +use tari_comms::PeerManager; +use tari_storage::{lmdb_store::LMDBBuilder, LMDBWrapper}; use crate::{ blocks::{Block, BlockHeader}, @@ -34,15 +40,12 @@ use crate::{ transactions::{ tari_amount::T, transaction::{Transaction, UnblindedOutput}, - types::CryptoFactories, CoinbaseBuilder, + CryptoFactories, }, }; -use rand::{distributions::Alphanumeric, Rng}; -use std::{iter, path::Path, sync::Arc}; -use tari_common::configuration::Network; -use tari_comms::PeerManager; -use tari_storage::{lmdb_store::LMDBBuilder, LMDBWrapper}; + +pub mod blockchain; /// Create a partially constructed block using the provided set of transactions /// is chain_block, or rename it to `create_orphan_block` and drop the prev_block argument diff --git a/base_layer/core/src/transactions/aggregated_body.rs b/base_layer/core/src/transactions/aggregated_body.rs index d1dfb4f198..ac44b04b4d 100644 --- a/base_layer/core/src/transactions/aggregated_body.rs +++ b/base_layer/core/src/transactions/aggregated_body.rs @@ -1,3 +1,14 @@ +use std::fmt::{Display, Error, Formatter}; + +use log::*; +use serde::{Deserialize, Serialize}; +use tari_crypto::{ + commitment::HomomorphicCommitmentFactory, + keys::PublicKey as PublicKeyTrait, + ristretto::pedersen::PedersenCommitment, + tari_utilities::hex::Hex, +}; + // Copyright 2019, The Tari Project // // Redistribution and use in source and binary forms, with or without modification, are permitted provided that the @@ -19,20 +30,14 @@ // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF 
THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::transactions::{ - fee::Fee, - tari_amount::*, - transaction::*, - types::{BlindingFactor, Commitment, CommitmentFactory, CryptoFactories, PrivateKey, PublicKey, RangeProofService}, -}; -use log::*; -use serde::{Deserialize, Serialize}; -use std::fmt::{Display, Error, Formatter}; -use tari_crypto::{ - commitment::HomomorphicCommitmentFactory, - keys::PublicKey as PublicKeyTrait, - ristretto::pedersen::PedersenCommitment, - tari_utilities::hex::Hex, +use crate::transactions::{crypto_factories::CryptoFactories, fee::Fee, tari_amount::*, transaction::*}; +use tari_common_types::types::{ + BlindingFactor, + Commitment, + CommitmentFactory, + PrivateKey, + PublicKey, + RangeProofService, }; pub const LOG_TARGET: &str = "c::tx::aggregated_body"; @@ -307,6 +312,7 @@ impl AggregateBody { &self, tx_offset: &BlindingFactor, script_offset: &BlindingFactor, + bypass_range_proof_verification: bool, total_reward: MicroTari, factories: &CryptoFactories, ) -> Result<(), TransactionError> { @@ -316,7 +322,9 @@ impl AggregateBody { self.verify_kernel_signatures()?; self.validate_kernel_sum(total_offset, &factories.commitment)?; - self.validate_range_proofs(&factories.range_proof)?; + if !bypass_range_proof_verification { + self.validate_range_proofs(&factories.range_proof)?; + } self.verify_metadata_signatures()?; self.validate_script_offset(script_offset_g, &factories.commitment) } diff --git a/base_layer/core/src/transactions/bullet_rangeproofs.rs b/base_layer/core/src/transactions/bullet_rangeproofs.rs index 9d96e2bb03..5ba0a05923 100644 --- a/base_layer/core/src/transactions/bullet_rangeproofs.rs +++ b/base_layer/core/src/transactions/bullet_rangeproofs.rs @@ -20,7 +20,6 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
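The `aggregated_body.rs` hunk above threads a new `bypass_range_proof_verification` flag through `AggregateBody::validate_internal_consistency`, so the comparatively expensive range-proof check can be skipped while the other consistency checks still run. A minimal sketch of that gating pattern, where `Body`, the check methods, and the error variants are simplified stand-ins for the real Tari types:

```rust
// Sketch of the flag-gating pattern added to AggregateBody::validate_internal_consistency.
// `Body`, its fields, and `TransactionError` are simplified stand-ins, not the real types.
struct Body {
    proofs_valid: bool,
    sums_valid: bool,
}

#[derive(Debug, PartialEq)]
enum TransactionError {
    InvalidRangeProof,
    InvalidKernelSum,
}

impl Body {
    fn validate_kernel_sum(&self) -> Result<(), TransactionError> {
        if self.sums_valid {
            Ok(())
        } else {
            Err(TransactionError::InvalidKernelSum)
        }
    }

    fn validate_range_proofs(&self) -> Result<(), TransactionError> {
        if self.proofs_valid {
            Ok(())
        } else {
            Err(TransactionError::InvalidRangeProof)
        }
    }

    // The expensive range-proof check only runs when the bypass flag is false;
    // everything else is validated unconditionally.
    fn validate_internal_consistency(
        &self,
        bypass_range_proof_verification: bool,
    ) -> Result<(), TransactionError> {
        self.validate_kernel_sum()?;
        if !bypass_range_proof_verification {
            self.validate_range_proofs()?;
        }
        Ok(())
    }
}
```

The bypass stays opt-in and per-call: later hunks in this patch show `TransactionBuilder::build` and the sender protocol passing `true`, while the test suite passes `false` to exercise the full check.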
-use crate::transactions::types::HashDigest; use digest::Digest; use serde::{ de::{self, Visitor}, @@ -30,6 +29,7 @@ use serde::{ Serializer, }; use std::fmt; +use tari_common_types::types::HashDigest; use tari_crypto::tari_utilities::{byte_array::*, hash::*, hex::*}; #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord)] diff --git a/base_layer/core/src/transactions/coinbase_builder.rs b/base_layer/core/src/transactions/coinbase_builder.rs index 52cb4ac0ba..5091844e6a 100644 --- a/base_layer/core/src/transactions/coinbase_builder.rs +++ b/base_layer/core/src/transactions/coinbase_builder.rs @@ -21,12 +21,23 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // +use rand::rngs::OsRng; +use tari_crypto::{ + commitment::HomomorphicCommitmentFactory, + inputs, + keys::{PublicKey as PK, SecretKey}, + script, + script::TariScript, +}; +use thiserror::Error; + use crate::{ consensus::{ emission::{Emission, EmissionSchedule}, ConsensusConstants, }, transactions::{ + crypto_factories::CryptoFactories, tari_amount::{uT, MicroTari}, transaction::{ KernelBuilder, @@ -38,18 +49,9 @@ use crate::{ UnblindedOutput, }, transaction_protocol::{build_challenge, RewindData, TransactionMetadata}, - types::{BlindingFactor, CryptoFactories, PrivateKey, PublicKey, Signature}, }, }; -use rand::rngs::OsRng; -use tari_crypto::{ - commitment::HomomorphicCommitmentFactory, - inputs, - keys::{PublicKey as PK, SecretKey}, - script, - script::TariScript, -}; -use thiserror::Error; +use tari_common_types::types::{BlindingFactor, PrivateKey, PublicKey, Signature}; #[derive(Debug, Clone, Error, PartialEq)] pub enum CoinbaseBuildError { @@ -241,21 +243,24 @@ impl CoinbaseBuilder { #[cfg(test)] mod test { + use rand::rngs::OsRng; + use tari_crypto::{commitment::HomomorphicCommitmentFactory, keys::SecretKey as SecretKeyTrait}; + + use tari_common::configuration::Network; + use crate::{ consensus::{emission::Emission, ConsensusManager, ConsensusManagerBuilder}, 
transactions::{ coinbase_builder::CoinbaseBuildError, + crypto_factories::CryptoFactories, helpers::TestParams, tari_amount::uT, transaction::{KernelFeatures, OutputFeatures, OutputFlags, TransactionError}, transaction_protocol::RewindData, - types::{BlindingFactor, CryptoFactories, PrivateKey}, CoinbaseBuilder, }, }; - use rand::rngs::OsRng; - use tari_common::configuration::Network; - use tari_crypto::{commitment::HomomorphicCommitmentFactory, keys::SecretKey as SecretKeyTrait}; + use tari_common_types::types::{BlindingFactor, PrivateKey}; fn get_builder() -> (CoinbaseBuilder, ConsensusManager, CryptoFactories) { let network = Network::LocalNet; @@ -520,6 +525,7 @@ mod test { tx.body.validate_internal_consistency( &BlindingFactor::default(), &PrivateKey::default(), + false, block_reward, &factories ), diff --git a/base_layer/core/src/transactions/crypto_factories.rs b/base_layer/core/src/transactions/crypto_factories.rs new file mode 100644 index 0000000000..86270dc42d --- /dev/null +++ b/base_layer/core/src/transactions/crypto_factories.rs @@ -0,0 +1,45 @@ +use std::sync::Arc; + +use tari_common_types::types::{CommitmentFactory, RangeProofService, MAX_RANGE_PROOF_RANGE}; + +/// A convenience struct wrapping cryptographic factories that are used throughout the rest of the code base +/// Uses Arcs internally so calling clone on this is cheap, no need to wrap this in an Arc +pub struct CryptoFactories { + pub commitment: Arc<CommitmentFactory>, + pub range_proof: Arc<RangeProofService>, +} + +impl Default for CryptoFactories { + /// Return a default set of crypto factories based on Pedersen commitments with G and H defined in + /// [pedersen.rs](/infrastructure/crypto/src/ristretto/pedersen.rs), and an associated range proof factory with a + /// range of `[0; 2^64)`. + fn default() -> Self { + CryptoFactories::new(MAX_RANGE_PROOF_RANGE) + } +} + +impl CryptoFactories { + /// Create a new set of crypto factories.
+ /// + /// ## Parameters + /// + /// * `max_proof_range`: Sets the maximum value in range proofs, where `max = 2^max_proof_range` + pub fn new(max_proof_range: usize) -> Self { + let commitment = Arc::new(CommitmentFactory::default()); + let range_proof = Arc::new(RangeProofService::new(max_proof_range, &commitment).unwrap()); + Self { + commitment, + range_proof, + } + } +} + +/// Uses Arcs internally so calling clone on this is cheap, no need to wrap this in an Arc +impl Clone for CryptoFactories { + fn clone(&self) -> Self { + Self { + commitment: self.commitment.clone(), + range_proof: self.range_proof.clone(), + } + } +} diff --git a/base_layer/core/src/transactions/helpers.rs b/base_layer/core/src/transactions/helpers.rs index 8e6c4d4c7b..54f90cebb8 100644 --- a/base_layer/core/src/transactions/helpers.rs +++ b/base_layer/core/src/transactions/helpers.rs @@ -20,7 +20,22 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
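The new `crypto_factories.rs` module shown above depends on both factories being wrapped in `Arc`, so `Clone` only bumps two reference counts instead of rebuilding the expensive range-proof service. A self-contained sketch of that pattern, with stand-ins for `CommitmentFactory` and `RangeProofService` (the real types come from `tari_common_types`/`tari_crypto`):

```rust
use std::sync::Arc;

// Stand-ins for the real factory types; constructing a real RangeProofService
// is expensive, which is what motivates sharing it behind an Arc.
#[derive(Default)]
struct CommitmentFactory;
struct RangeProofService {
    max_bits: usize,
}

const MAX_RANGE_PROOF_RANGE: usize = 64; // 2^64, mirroring the constant in the diff

/// Mirrors the CryptoFactories pattern: Arcs internally, so clone is cheap.
struct CryptoFactories {
    commitment: Arc<CommitmentFactory>,
    range_proof: Arc<RangeProofService>,
}

impl Default for CryptoFactories {
    fn default() -> Self {
        CryptoFactories::new(MAX_RANGE_PROOF_RANGE)
    }
}

impl CryptoFactories {
    fn new(max_proof_range: usize) -> Self {
        let commitment = Arc::new(CommitmentFactory::default());
        let range_proof = Arc::new(RangeProofService {
            max_bits: max_proof_range,
        });
        Self {
            commitment,
            range_proof,
        }
    }
}

impl Clone for CryptoFactories {
    // Only the Arc pointers are cloned; the underlying factories are shared.
    fn clone(&self) -> Self {
        Self {
            commitment: self.commitment.clone(),
            range_proof: self.range_proof.clone(),
        }
    }
}
```

Because clones share the same allocations, callers can pass `CryptoFactories` by value freely without wrapping it in a further `Arc`.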
+use std::sync::Arc; + +use num::pow; +use rand::rngs::OsRng; +use tari_crypto::{ + commitment::HomomorphicCommitmentFactory, + common::Blake256, + inputs, + keys::{PublicKey as PK, SecretKey}, + range_proof::RangeProofService, + script, + script::{ExecutionStack, TariScript}, +}; + use crate::transactions::{ + crypto_factories::CryptoFactories, fee::Fee, tari_amount::MicroTari, transaction::{ @@ -34,21 +49,9 @@ use crate::transactions::{ UnblindedOutput, }, transaction_protocol::{build_challenge, TransactionMetadata}, - types::{Commitment, CommitmentFactory, CryptoFactories, PrivateKey, PublicKey, Signature}, SenderTransactionProtocol, }; -use num::pow; -use rand::rngs::OsRng; -use std::sync::Arc; -use tari_crypto::{ - commitment::HomomorphicCommitmentFactory, - common::Blake256, - inputs, - keys::{PublicKey as PK, SecretKey}, - range_proof::RangeProofService, - script, - script::{ExecutionStack, TariScript}, -}; +use tari_common_types::types::{Commitment, CommitmentFactory, PrivateKey, PublicKey, Signature}; pub fn create_test_input( amount: MicroTari, diff --git a/base_layer/core/src/transactions/mod.rs b/base_layer/core/src/transactions/mod.rs index 9d98bd394b..5653939bd4 100644 --- a/base_layer/core/src/transactions/mod.rs +++ b/base_layer/core/src/transactions/mod.rs @@ -1,20 +1,19 @@ pub mod aggregated_body; -pub mod bullet_rangeproofs; +mod crypto_factories; pub mod fee; pub mod tari_amount; pub mod transaction; #[allow(clippy::op_ref)] pub mod transaction_protocol; + +pub use crypto_factories::*; + pub mod types; // Re-export commonly used structs pub use transaction_protocol::{recipient::ReceiverTransactionProtocol, sender::SenderTransactionProtocol}; #[macro_use] pub mod helpers; -#[cfg(any(feature = "base_node", feature = "transactions"))] -mod coinbase_builder; -#[cfg(any(feature = "base_node", feature = "transactions"))] -pub use crate::transactions::coinbase_builder::CoinbaseBuildError; -#[cfg(any(feature = "base_node", feature = "transactions"))] 
-pub use crate::transactions::coinbase_builder::CoinbaseBuilder; +mod coinbase_builder; +pub use crate::transactions::coinbase_builder::{CoinbaseBuildError, CoinbaseBuilder}; diff --git a/base_layer/core/src/transactions/transaction.rs b/base_layer/core/src/transactions/transaction.rs index 81d972d90b..ce3e2c8aec 100644 --- a/base_layer/core/src/transactions/transaction.rs +++ b/base_layer/core/src/transactions/transaction.rs @@ -23,29 +23,6 @@ // Portions of this file were originally copyrighted (c) 2018 The Grin Developers, issued under the Apache License, // Version 2.0, available at http://www.apache.org/licenses/LICENSE-2.0. -use crate::transactions::{ - aggregated_body::AggregateBody, - tari_amount::{uT, MicroTari}, - transaction_protocol::{build_challenge, RewindData, TransactionMetadata}, - types::{ - BlindingFactor, - Challenge, - ComSignature, - Commitment, - CommitmentFactory, - CryptoFactories, - HashDigest, - MessageHash, - PrivateKey, - PublicKey, - RangeProof, - RangeProofService, - Signature, - }, -}; -use blake2::Digest; -use rand::rngs::OsRng; -use serde::{Deserialize, Serialize}; use std::{ cmp::{max, min, Ordering}, fmt, @@ -53,6 +30,10 @@ use std::{ hash::{Hash, Hasher}, ops::Add, }; + +use blake2::Digest; +use rand::rngs::OsRng; +use serde::{Deserialize, Serialize}; use tari_crypto::{ commitment::HomomorphicCommitmentFactory, keys::{PublicKey as PublicKeyTrait, SecretKey}, @@ -70,6 +51,27 @@ use tari_crypto::{ }; use thiserror::Error; +use crate::transactions::{ + aggregated_body::AggregateBody, + crypto_factories::CryptoFactories, + tari_amount::{uT, MicroTari}, + transaction_protocol::{build_challenge, RewindData, TransactionMetadata}, +}; +use tari_common_types::types::{ + BlindingFactor, + Challenge, + ComSignature, + Commitment, + CommitmentFactory, + HashDigest, + MessageHash, + PrivateKey, + PublicKey, + RangeProof, + RangeProofService, + Signature, +}; + // Tx_weight(inputs(12,500), outputs(500), kernels(1)) = 19,003, still well enough 
below block weight of 19,500 pub const MAX_TRANSACTION_INPUTS: usize = 12_500; pub const MAX_TRANSACTION_OUTPUTS: usize = 500; @@ -1109,12 +1111,18 @@ impl Transaction { #[allow(clippy::erasing_op)] // This is for 0 * uT pub fn validate_internal_consistency( &self, + bypass_range_proof_verification: bool, factories: &CryptoFactories, reward: Option, ) -> Result<(), TransactionError> { let reward = reward.unwrap_or_else(|| 0 * uT); - self.body - .validate_internal_consistency(&self.offset, &self.script_offset, reward, factories) + self.body.validate_internal_consistency( + &self.offset, + &self.script_offset, + bypass_range_proof_verification, + reward, + factories, + ) } pub fn get_body(&self) -> &AggregateBody { @@ -1264,7 +1272,7 @@ impl TransactionBuilder { if let (Some(script_offset), Some(offset)) = (self.script_offset, self.offset) { let (i, o, k) = self.body.dissolve(); let tx = Transaction::new(i, o, k, offset, script_offset); - tx.validate_internal_consistency(factories, self.reward)?; + tx.validate_internal_consistency(true, factories, self.reward)?; Ok(tx) } else { Err(TransactionError::ValidationError( @@ -1289,24 +1297,26 @@ impl Default for TransactionBuilder { #[cfg(test)] mod test { - use super::*; + use rand::{self, rngs::OsRng}; + use tari_crypto::{ + keys::{PublicKey as PublicKeyTrait, SecretKey as SecretKeyTrait}, + ristretto::pedersen::PedersenCommitmentFactory, + script, + script::ExecutionStack, + }; + use crate::{ transactions::{ helpers, helpers::{TestParams, UtxoTestParams}, tari_amount::T, transaction::OutputFeatures, - types::{BlindingFactor, PrivateKey, PublicKey, RangeProof}, }, txn_schema, }; - use rand::{self, rngs::OsRng}; - use tari_crypto::{ - keys::{PublicKey as PublicKeyTrait, SecretKey as SecretKeyTrait}, - ristretto::pedersen::PedersenCommitmentFactory, - script, - script::ExecutionStack, - }; + use tari_common_types::types::{BlindingFactor, PrivateKey, PublicKey}; + + use super::*; #[test] fn input_and_output_hash_match() { 
@@ -1514,7 +1524,7 @@ mod test { let (tx, _, _) = helpers::create_tx(5000.into(), 15.into(), 1, 2, 1, 4); let factories = CryptoFactories::default(); - assert!(tx.validate_internal_consistency(&factories, None).is_ok()); + assert!(tx.validate_internal_consistency(false, &factories, None).is_ok()); } #[test] @@ -1527,7 +1537,7 @@ assert_eq!(tx.body.kernels().len(), 1); let factories = CryptoFactories::default(); - assert!(tx.validate_internal_consistency(&factories, None).is_ok()); + assert!(tx.validate_internal_consistency(false, &factories, None).is_ok()); let schema = txn_schema!(from: vec![outputs[1].clone()], to: vec![1 * T, 2 * T]); let (tx2, _outputs, _) = helpers::spend_utxos(schema); @@ -1558,10 +1568,12 @@ } // Validate basic transaction where cut-through has not been applied. - assert!(tx3.validate_internal_consistency(&factories, None).is_ok()); + assert!(tx3.validate_internal_consistency(false, &factories, None).is_ok()); // tx3_cut_through has manual cut-through, it should not be possible so this should fail - assert!(tx3_cut_through.validate_internal_consistency(&factories, None).is_err()); + assert!(tx3_cut_through + .validate_internal_consistency(false, &factories, None) + .is_err()); } #[test] @@ -1598,7 +1610,7 @@ tx.body.inputs_mut()[0].input_data = stack; let factories = CryptoFactories::default(); - let err = tx.validate_internal_consistency(&factories, None).unwrap_err(); + let err = tx.validate_internal_consistency(false, &factories, None).unwrap_err(); assert!(matches!(err, TransactionError::InvalidSignatureError(_))); } diff --git a/base_layer/core/src/transactions/transaction_protocol/mod.rs b/base_layer/core/src/transactions/transaction_protocol/mod.rs index 2f66ef3643..5e140c4388 100644 --- a/base_layer/core/src/transactions/transaction_protocol/mod.rs +++ b/base_layer/core/src/transactions/transaction_protocol/mod.rs @@ -86,13 +86,11 @@ pub mod sender; pub mod single_receiver; pub mod
transaction_initializer; -use crate::transactions::{ - tari_amount::*, - transaction::TransactionError, - types::{Challenge, MessageHash, PrivateKey, PublicKey}, -}; +use crate::transactions::{tari_amount::*, transaction::TransactionError}; use digest::Digest; use serde::{Deserialize, Serialize}; +use tari_common_types::types::{MessageHash, PrivateKey, PublicKey}; +use tari_comms::types::Challenge; use tari_crypto::{ range_proof::{RangeProofError, REWIND_USER_MESSAGE_LENGTH}, signatures::SchnorrSignatureError, diff --git a/base_layer/core/src/transactions/transaction_protocol/proto/recipient_signed_message.rs b/base_layer/core/src/transactions/transaction_protocol/proto/recipient_signed_message.rs index 699cf3145a..c149874ef3 100644 --- a/base_layer/core/src/transactions/transaction_protocol/proto/recipient_signed_message.rs +++ b/base_layer/core/src/transactions/transaction_protocol/proto/recipient_signed_message.rs @@ -22,8 +22,9 @@ use super::protocol as proto; -use crate::transactions::{transaction_protocol::recipient::RecipientSignedMessage, types::PublicKey}; +use crate::transactions::transaction_protocol::recipient::RecipientSignedMessage; use std::convert::{TryFrom, TryInto}; +use tari_common_types::types::PublicKey; use tari_crypto::tari_utilities::ByteArray; impl TryFrom for RecipientSignedMessage { diff --git a/base_layer/core/src/transactions/transaction_protocol/proto/transaction_sender.rs b/base_layer/core/src/transactions/transaction_protocol/proto/transaction_sender.rs index 22f0e59306..14c0f7ee2c 100644 --- a/base_layer/core/src/transactions/transaction_protocol/proto/transaction_sender.rs +++ b/base_layer/core/src/transactions/transaction_protocol/proto/transaction_sender.rs @@ -28,8 +28,8 @@ use std::convert::{TryFrom, TryInto}; use tari_crypto::tari_utilities::ByteArray; // The generated _oneof_ enum -use crate::transactions::types::PublicKey; use proto::transaction_sender_message::Message as ProtoTxnSenderMessage; +use 
tari_common_types::types::PublicKey; use tari_crypto::script::TariScript; impl proto::TransactionSenderMessage { diff --git a/base_layer/core/src/transactions/transaction_protocol/recipient.rs b/base_layer/core/src/transactions/transaction_protocol/recipient.rs index e9f21c3778..2518f8f1de 100644 --- a/base_layer/core/src/transactions/transaction_protocol/recipient.rs +++ b/base_layer/core/src/transactions/transaction_protocol/recipient.rs @@ -20,7 +20,12 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use std::{collections::HashMap, fmt}; + +use serde::{Deserialize, Serialize}; + use crate::transactions::{ + crypto_factories::CryptoFactories, transaction::{OutputFeatures, TransactionOutput}, transaction_protocol::{ sender::{SingleRoundSenderData as SD, TransactionSenderMessage}, @@ -28,10 +33,8 @@ use crate::transactions::{ RewindData, TransactionProtocolError, }, - types::{CryptoFactories, MessageHash, PrivateKey, PublicKey, Signature}, }; -use serde::{Deserialize, Serialize}; -use std::{collections::HashMap, fmt}; +use tari_common_types::types::{MessageHash, PrivateKey, PublicKey, Signature}; #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] #[allow(clippy::large_enum_variant)] @@ -202,9 +205,16 @@ impl ReceiverTransactionProtocol { #[cfg(test)] mod test { + use rand::rngs::OsRng; + use tari_crypto::{ + commitment::HomomorphicCommitmentFactory, + keys::{PublicKey as PK, SecretKey as SecretKeyTrait}, + }; + use crate::{ crypto::script::TariScript, transactions::{ + crypto_factories::CryptoFactories, helpers::TestParams, tari_amount::*, transaction::OutputFeatures, @@ -214,15 +224,10 @@ mod test { RewindData, TransactionMetadata, }, - types::{CryptoFactories, PrivateKey, PublicKey, Signature}, ReceiverTransactionProtocol, }, }; - use rand::rngs::OsRng; - use tari_crypto::{ - 
commitment::HomomorphicCommitmentFactory, - keys::{PublicKey as PK, SecretKey as SecretKeyTrait}, - }; + use tari_common_types::types::{PrivateKey, PublicKey, Signature}; #[test] fn single_round_recipient() { diff --git a/base_layer/core/src/transactions/transaction_protocol/sender.rs b/base_layer/core/src/transactions/transaction_protocol/sender.rs index 0341dcbda1..c91097d32a 100644 --- a/base_layer/core/src/transactions/transaction_protocol/sender.rs +++ b/base_layer/core/src/transactions/transaction_protocol/sender.rs @@ -20,7 +20,19 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use std::fmt; + +use digest::Digest; +use serde::{Deserialize, Serialize}; +use tari_crypto::{ + keys::PublicKey as PublicKeyTrait, + ristretto::pedersen::{PedersenCommitment, PedersenCommitmentFactory}, + script::TariScript, + tari_utilities::ByteArray, +}; + use crate::transactions::{ + crypto_factories::CryptoFactories, tari_amount::*, transaction::{ KernelBuilder, @@ -42,17 +54,8 @@ use crate::transactions::{ TransactionMetadata, TransactionProtocolError as TPE, }, - types::{BlindingFactor, ComSignature, CryptoFactories, PrivateKey, PublicKey, RangeProofService, Signature}, -}; -use digest::Digest; -use serde::{Deserialize, Serialize}; -use std::fmt; -use tari_crypto::{ - keys::PublicKey as PublicKeyTrait, - ristretto::pedersen::{PedersenCommitment, PedersenCommitmentFactory}, - script::TariScript, - tari_utilities::ByteArray, }; +use tari_common_types::types::{BlindingFactor, ComSignature, PrivateKey, PublicKey, RangeProofService, Signature}; //---------------------------------------- Local Data types ----------------------------------------------------// @@ -562,7 +565,7 @@ impl SenderTransactionProtocol { } let transaction = result.unwrap(); let result = transaction - .validate_internal_consistency(factories, None) + 
.validate_internal_consistency(true, factories, None) .map_err(TPE::TransactionBuildError); if let Err(e) = result { self.state = SenderState::Failed(e.clone()); @@ -705,7 +708,20 @@ impl fmt::Display for SenderState { #[cfg(test)] mod test { + use rand::rngs::OsRng; + use tari_crypto::{ + commitment::HomomorphicCommitmentFactory, + common::Blake256, + keys::{PublicKey as PublicKeyTrait, SecretKey as SecretKeyTrait}, + range_proof::RangeProofService, + ristretto::pedersen::PedersenCommitmentFactory, + script, + script::{ExecutionStack, TariScript}, + tari_utilities::{hex::Hex, ByteArray}, + }; + use crate::transactions::{ + crypto_factories::CryptoFactories, fee::Fee, helpers::{create_test_input, create_unblinded_output, TestParams}, tari_amount::*, @@ -716,19 +732,8 @@ mod test { RewindData, TransactionProtocolError, }, - types::{CryptoFactories, PrivateKey, PublicKey, RangeProof}, - }; - use rand::rngs::OsRng; - use tari_crypto::{ - commitment::HomomorphicCommitmentFactory, - common::Blake256, - keys::{PublicKey as PublicKeyTrait, SecretKey as SecretKeyTrait}, - range_proof::RangeProofService, - ristretto::pedersen::PedersenCommitmentFactory, - script, - script::{ExecutionStack, TariScript}, - tari_utilities::{hex::Hex, ByteArray}, }; + use tari_common_types::types::{PrivateKey, PublicKey, RangeProof}; #[test] fn test_metadata_signature_finalize() { @@ -965,7 +970,10 @@ mod test { assert_eq!(tx.body.inputs().len(), 1); assert_eq!(tx.body.inputs()[0], utxo); assert_eq!(tx.body.outputs().len(), 2); - assert!(tx.clone().validate_internal_consistency(&factories, None).is_ok()); + assert!(tx + .clone() + .validate_internal_consistency(false, &factories, None) + .is_ok()); } #[test] diff --git a/base_layer/core/src/transactions/transaction_protocol/single_receiver.rs b/base_layer/core/src/transactions/transaction_protocol/single_receiver.rs index 5d60e64acc..7f6060fed6 100644 --- a/base_layer/core/src/transactions/transaction_protocol/single_receiver.rs +++ 
b/base_layer/core/src/transactions/transaction_protocol/single_receiver.rs @@ -20,7 +20,15 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use tari_crypto::{ + commitment::HomomorphicCommitmentFactory, + keys::PublicKey as PK, + range_proof::{RangeProofError, RangeProofService as RPS}, + tari_utilities::byte_array::ByteArray, +}; + use crate::transactions::{ + crypto_factories::CryptoFactories, transaction::{OutputFeatures, TransactionOutput}, transaction_protocol::{ build_challenge, @@ -29,14 +37,8 @@ use crate::transactions::{ RewindData, TransactionProtocolError as TPE, }, - types::{CryptoFactories, PrivateKey as SK, PublicKey, RangeProof, Signature}, -}; -use tari_crypto::{ - commitment::HomomorphicCommitmentFactory, - keys::PublicKey as PK, - range_proof::{RangeProofError, RangeProofService as RPS}, - tari_utilities::byte_array::ByteArray, }; +use tari_common_types::types::{PrivateKey as SK, PublicKey, RangeProof, Signature}; /// SingleReceiverTransactionProtocol represents the actions taken by the single receiver in the one-round Tari /// transaction protocol. The procedure is straightforward. 
Upon receiving the sender's information, the receiver: @@ -133,7 +135,15 @@ impl SingleReceiverTransactionProtocol { #[cfg(test)] mod test { + use rand::rngs::OsRng; + use tari_crypto::{ + commitment::HomomorphicCommitmentFactory, + keys::{PublicKey as PK, SecretKey as SK}, + script::TariScript, + }; + use crate::transactions::{ + crypto_factories::CryptoFactories, tari_amount::*, transaction::OutputFeatures, transaction_protocol::{ @@ -143,14 +153,8 @@ mod test { TransactionMetadata, TransactionProtocolError, }, - types::{CryptoFactories, PrivateKey, PublicKey}, - }; - use rand::rngs::OsRng; - use tari_crypto::{ - commitment::HomomorphicCommitmentFactory, - keys::{PublicKey as PK, SecretKey as SK}, - script::TariScript, }; + use tari_common_types::types::{PrivateKey, PublicKey}; fn generate_output_parms() -> (PrivateKey, PrivateKey, OutputFeatures) { let r = PrivateKey::random(&mut OsRng); diff --git a/base_layer/core/src/transactions/transaction_protocol/transaction_initializer.rs b/base_layer/core/src/transactions/transaction_protocol/transaction_initializer.rs index 0d5beb738d..b27a1e528f 100644 --- a/base_layer/core/src/transactions/transaction_protocol/transaction_initializer.rs +++ b/base_layer/core/src/transactions/transaction_protocol/transaction_initializer.rs @@ -20,7 +20,24 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+use std::{ + collections::HashMap, + fmt::{Debug, Error, Formatter}, +}; + +use digest::Digest; +use log::*; +use rand::rngs::OsRng; +use tari_crypto::{ + commitment::HomomorphicCommitmentFactory, + keys::{PublicKey as PublicKeyTrait, SecretKey}, + ristretto::pedersen::PedersenCommitmentFactory, + script::{ExecutionStack, TariScript}, + tari_utilities::fixed_set::FixedSet, +}; + use crate::transactions::{ + crypto_factories::CryptoFactories, fee::Fee, tari_amount::*, transaction::{ @@ -38,22 +55,8 @@ use crate::transactions::{ RewindData, TransactionMetadata, }, - types::{BlindingFactor, CryptoFactories, PrivateKey, PublicKey}, -}; -use digest::Digest; -use log::*; -use rand::rngs::OsRng; -use std::{ - collections::HashMap, - fmt::{Debug, Error, Formatter}, -}; -use tari_crypto::{ - commitment::HomomorphicCommitmentFactory, - keys::{PublicKey as PublicKeyTrait, SecretKey}, - ristretto::pedersen::PedersenCommitmentFactory, - script::{ExecutionStack, TariScript}, - tari_utilities::fixed_set::FixedSet, }; +use tari_common_types::types::{BlindingFactor, PrivateKey, PublicKey}; pub const LOG_TARGET: &str = "c::tx::tx_protocol::tx_initializer"; @@ -571,9 +574,18 @@ impl SenderTransactionInitializer { #[cfg(test)] mod test { + use rand::rngs::OsRng; + use tari_crypto::{ + common::Blake256, + keys::SecretKey, + script, + script::{ExecutionStack, TariScript}, + }; + use crate::{ consensus::{KERNEL_WEIGHT, WEIGHT_PER_INPUT, WEIGHT_PER_OUTPUT}, transactions::{ + crypto_factories::CryptoFactories, fee::Fee, helpers::{create_test_input, create_unblinded_output, TestParams, UtxoTestParams}, tari_amount::*, @@ -583,16 +595,9 @@ mod test { transaction_initializer::SenderTransactionInitializer, TransactionProtocolError, }, - types::{CryptoFactories, PrivateKey}, }, }; - use rand::rngs::OsRng; - use tari_crypto::{ - common::Blake256, - keys::SecretKey, - script, - script::{ExecutionStack, TariScript}, - }; + use tari_common_types::types::PrivateKey; /// One input, 2 outputs #[test] 
@@ -763,6 +768,7 @@ mod test { .with_output(output, p.sender_offset_private_key) .unwrap() .with_fee_per_gram(MicroTari(2)); + for _ in 0..MAX_TRANSACTION_INPUTS + 1 { let (utxo, input) = create_test_input(MicroTari(50), 0, &factories.commitment); builder.with_input(utxo, input); diff --git a/base_layer/core/src/transactions/types.rs b/base_layer/core/src/transactions/types.rs index 40b0fa625d..d051788bca 100644 --- a/base_layer/core/src/transactions/types.rs +++ b/base_layer/core/src/transactions/types.rs @@ -19,99 +19,3 @@ // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -use crate::transactions::bullet_rangeproofs::BulletRangeProof; -use std::sync::Arc; -use tari_crypto::{ - common::Blake256, - ristretto::{ - dalek_range_proof::DalekRangeProofService, - pedersen::{PedersenCommitment, PedersenCommitmentFactory}, - RistrettoComSig, - RistrettoPublicKey, - RistrettoSchnorr, - RistrettoSecretKey, - }, -}; - -/// Define the explicit Signature implementation for the Tari base layer. A different signature scheme can be -/// employed by redefining this type. -pub type Signature = RistrettoSchnorr; -/// Define the explicit Commitment Signature implementation for the Tari base layer. -pub type ComSignature = RistrettoComSig; - -/// Define the explicit Commitment implementation for the Tari base layer. -pub type Commitment = PedersenCommitment; -pub type CommitmentFactory = PedersenCommitmentFactory; - -/// Define the explicit Secret key implementation for the Tari base layer. 
-pub type PrivateKey = RistrettoSecretKey; -pub type BlindingFactor = RistrettoSecretKey; - -/// Define the hash function that will be used to produce a signature challenge -pub type SignatureHasher = Blake256; - -/// Define the explicit Public key implementation for the Tari base layer -pub type PublicKey = RistrettoPublicKey; - -/// Specify the Hash function for general hashing -pub type HashDigest = Blake256; - -/// Specify the digest type for signature challenges -pub type Challenge = Blake256; - -/// The type of output that `Challenge` produces -pub type MessageHash = Vec; - -/// Specify the range proof type -pub type RangeProofService = DalekRangeProofService; - -/// Specify the range proof -pub type RangeProof = BulletRangeProof; - -/// Define the data type that is used to store results of `HashDigest` -pub type HashOutput = Vec; - -pub const MAX_RANGE_PROOF_RANGE: usize = 64; // 2^64 - -/// A convenience struct wrapping cryptographic factories that are used through-out the rest of the code base -/// Uses Arcs internally so calling clone on this is cheap, no need to wrap this in an Arc -pub struct CryptoFactories { - pub commitment: Arc, - pub range_proof: Arc, -} - -impl Default for CryptoFactories { - /// Return a default set of crypto factories based on Pedersen commitments with G and H defined in - /// [pedersen.rs](/infrastructure/crypto/src/ristretto/pedersen.rs), and an associated range proof factory with a - /// range of `[0; 2^64)`. - fn default() -> Self { - CryptoFactories::new(MAX_RANGE_PROOF_RANGE) - } -} - -impl CryptoFactories { - /// Create a new set of crypto factories. 
- /// - /// ## Parameters - /// - /// * `max_proof_range`: Sets the maximum value in range proofs, where `max = 2^max_proof_range` - pub fn new(max_proof_range: usize) -> Self { - let commitment = Arc::new(CommitmentFactory::default()); - let range_proof = Arc::new(RangeProofService::new(max_proof_range, &commitment).unwrap()); - Self { - commitment, - range_proof, - } - } -} - -/// Uses Arc's internally so calling clone on this is cheap, no need to wrap this in an Arc -impl Clone for CryptoFactories { - fn clone(&self) -> Self { - Self { - commitment: self.commitment.clone(), - range_proof: self.range_proof.clone(), - } - } -} diff --git a/base_layer/core/src/validation/block_validators.rs b/base_layer/core/src/validation/block_validators.rs index 0c4ee76bfd..3908c28a5f 100644 --- a/base_layer/core/src/validation/block_validators.rs +++ b/base_layer/core/src/validation/block_validators.rs @@ -1,3 +1,13 @@ +use std::marker::PhantomData; + +use log::*; +use tari_crypto::{ + commitment::HomomorphicCommitmentFactory, + tari_utilities::{hash::Hashable, hex::Hex}, +}; + +use tari_common_types::chain_metadata::ChainMetadata; + // Copyright 2019.
The Tari Project // // Redistribution and use in source and binary forms, with or without modification, are permitted provided that the @@ -27,7 +37,7 @@ use crate::{ transactions::{ aggregated_body::AggregateBody, transaction::{KernelFeatures, OutputFlags, TransactionError}, - types::CryptoFactories, + CryptoFactories, }, validation::{ helpers::{check_accounting_balance, check_block_weight, check_coinbase_output, is_all_unique_and_sorted}, @@ -37,13 +47,6 @@ use crate::{ ValidationError, }, }; -use log::*; -use std::marker::PhantomData; -use tari_common_types::chain_metadata::ChainMetadata; -use tari_crypto::{ - commitment::HomomorphicCommitmentFactory, - tari_utilities::{hash::Hashable, hex::Hex}, -}; pub const LOG_TARGET: &str = "c::val::block_validators"; @@ -51,12 +54,17 @@ pub const LOG_TARGET: &str = "c::val::block_validators"; #[derive(Clone)] pub struct OrphanBlockValidator { rules: ConsensusManager, + bypass_range_proof_verification: bool, factories: CryptoFactories, } impl OrphanBlockValidator { - pub fn new(rules: ConsensusManager, factories: CryptoFactories) -> Self { - Self { rules, factories } + pub fn new(rules: ConsensusManager, bypass_range_proof_verification: bool, factories: CryptoFactories) -> Self { + Self { + rules, + bypass_range_proof_verification, + factories, + } + } } } @@ -101,7 +109,12 @@ impl OrphanValidation for OrphanBlockValidator { trace!(target: LOG_TARGET, "SV - Output constraints are ok for {} ", &block_id); check_coinbase_output(block, &self.rules, &self.factories)?; trace!(target: LOG_TARGET, "SV - Coinbase output is ok for {} ", &block_id); - check_accounting_balance(block, &self.rules, &self.factories)?; + check_accounting_balance( + block, + &self.rules, + self.bypass_range_proof_verification, + &self.factories, + )?; trace!(target: LOG_TARGET, "SV - accounting balance correct for {}", &block_id); debug!( target: LOG_TARGET, @@ -311,15 +324,17 @@ fn check_mmr_roots<B: BlockchainBackend>(block: &Block, db: &B) -> Result<(), ValidationError> { /// the block body
using the header. It is assumed that the `BlockHeader` has already been validated. pub struct BlockValidator<B: BlockchainBackend> { rules: ConsensusManager, + bypass_range_proof_verification: bool, factories: CryptoFactories, phantom_data: PhantomData<B>, } impl<B: BlockchainBackend> BlockValidator<B> { - pub fn new(rules: ConsensusManager, factories: CryptoFactories) -> Self { + pub fn new(rules: ConsensusManager, bypass_range_proof_verification: bool, factories: CryptoFactories) -> Self { Self { rules, factories, + bypass_range_proof_verification, phantom_data: Default::default(), } } } @@ -428,7 +443,12 @@ impl<B: BlockchainBackend> CandidateBlockBodyValidation<B> for BlockValidator<B> { self.check_inputs(block)?; self.check_outputs(block)?; - check_accounting_balance(block, &self.rules, &self.factories)?; + check_accounting_balance( + block, + &self.rules, + self.bypass_range_proof_verification, + &self.factories, + )?; trace!(target: LOG_TARGET, "SV - accounting balance correct for {}", &block_id); debug!( target: LOG_TARGET, diff --git a/base_layer/core/src/validation/chain_balance.rs b/base_layer/core/src/validation/chain_balance.rs index f620bcfdca..6efcc3cb6e 100644 --- a/base_layer/core/src/validation/chain_balance.rs +++ b/base_layer/core/src/validation/chain_balance.rs @@ -20,18 +20,18 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+use std::marker::PhantomData; + +use log::*; +use tari_crypto::commitment::HomomorphicCommitmentFactory; + use crate::{ chain_storage::BlockchainBackend, consensus::ConsensusManager, - transactions::{ - tari_amount::MicroTari, - types::{Commitment, CryptoFactories, PrivateKey}, - }, + transactions::{tari_amount::MicroTari, CryptoFactories}, validation::{FinalHorizonStateValidation, ValidationError}, }; -use log::*; -use std::marker::PhantomData; -use tari_crypto::commitment::HomomorphicCommitmentFactory; +use tari_common_types::types::{Commitment, PrivateKey}; const LOG_TARGET: &str = "c::bn::state_machine_service::states::horizon_state_sync::chain_balance"; diff --git a/base_layer/core/src/validation/error.rs b/base_layer/core/src/validation/error.rs index 1a9ee2fab6..e651079fb8 100644 --- a/base_layer/core/src/validation/error.rs +++ b/base_layer/core/src/validation/error.rs @@ -24,8 +24,9 @@ use crate::{ blocks::{block_header::BlockHeaderValidationError, BlockValidationError}, chain_storage::ChainStorageError, proof_of_work::{monero_rx::MergeMineError, PowError}, - transactions::{transaction::TransactionError, types::HashOutput}, + transactions::transaction::TransactionError, }; +use tari_common_types::types::HashOutput; use thiserror::Error; #[derive(Debug, Error)] diff --git a/base_layer/core/src/validation/helpers.rs b/base_layer/core/src/validation/helpers.rs index e6405a04d2..f0d1947b1e 100644 --- a/base_layer/core/src/validation/helpers.rs +++ b/base_layer/core/src/validation/helpers.rs @@ -20,6 +20,9 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+use log::*; +use tari_crypto::tari_utilities::{epoch_time::EpochTime, hash::Hashable, hex::Hex}; + use crate::{ blocks::{ block_header::{BlockHeader, BlockHeaderValidationError}, @@ -38,11 +41,9 @@ use crate::{ PowAlgorithm, PowError, }, - transactions::types::CryptoFactories, + transactions::CryptoFactories, validation::ValidationError, }; -use log::*; -use tari_crypto::tari_utilities::{epoch_time::EpochTime, hash::Hashable, hex::Hex}; pub const LOG_TARGET: &str = "c::val::helpers"; @@ -199,6 +200,7 @@ pub fn check_block_weight(block: &Block, consensus_constants: &ConsensusConstant pub fn check_accounting_balance( block: &Block, rules: &ConsensusManager, + bypass_range_proof_verification: bool, factories: &CryptoFactories, ) -> Result<(), ValidationError> { if block.header.height == 0 { @@ -210,7 +212,13 @@ pub fn check_accounting_balance( let total_coinbase = rules.calculate_coinbase_and_fees(block); block .body - .validate_internal_consistency(&offset, &script_offset, total_coinbase, factories) + .validate_internal_consistency( + &offset, + &script_offset, + bypass_range_proof_verification, + total_coinbase, + factories, + ) .map_err(|err| { warn!( target: LOG_TARGET, diff --git a/base_layer/core/src/validation/mocks.rs b/base_layer/core/src/validation/mocks.rs index 03c8951d3f..c2b3ffe5cf 100644 --- a/base_layer/core/src/validation/mocks.rs +++ b/base_layer/core/src/validation/mocks.rs @@ -24,7 +24,7 @@ use crate::{ blocks::{Block, BlockHeader}, chain_storage::{BlockchainBackend, ChainBlock}, proof_of_work::{sha3_difficulty, AchievedTargetDifficulty, Difficulty, PowAlgorithm}, - transactions::{transaction::Transaction, types::Commitment}, + transactions::transaction::Transaction, validation::{ error::ValidationError, CandidateBlockBodyValidation, @@ -40,7 +40,7 @@ use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, }; -use tari_common_types::chain_metadata::ChainMetadata; +use tari_common_types::{chain_metadata::ChainMetadata, types::Commitment}; 
#[derive(Clone)] pub struct MockValidator { diff --git a/base_layer/core/src/validation/test.rs b/base_layer/core/src/validation/test.rs index a5998fa3e2..e3a50914b2 100644 --- a/base_layer/core/src/validation/test.rs +++ b/base_layer/core/src/validation/test.rs @@ -20,6 +20,12 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use std::sync::Arc; + +use tari_crypto::{commitment::HomomorphicCommitment, script}; + +use tari_common::configuration::Network; + use crate::{ blocks::BlockHeader, chain_storage::{BlockHeaderAccumulatedData, ChainBlock, ChainHeader, DbTransaction}, @@ -31,13 +37,11 @@ use crate::{ helpers::{create_random_signature_from_s_key, create_utxo}, tari_amount::{uT, MicroTari}, transaction::{KernelBuilder, KernelFeatures, OutputFeatures, TransactionKernel}, - types::{Commitment, CryptoFactories}, + CryptoFactories, }, validation::{header_iter::HeaderIter, ChainBalanceValidator, FinalHorizonStateValidation}, }; -use std::sync::Arc; -use tari_common::configuration::Network; -use tari_crypto::{commitment::HomomorphicCommitment, script}; +use tari_common_types::types::Commitment; #[test] fn header_iter_empty_and_invalid_height() { diff --git a/base_layer/core/src/validation/traits.rs b/base_layer/core/src/validation/traits.rs index e7fabb449b..cc8c287b0a 100644 --- a/base_layer/core/src/validation/traits.rs +++ b/base_layer/core/src/validation/traits.rs @@ -24,10 +24,10 @@ use crate::{ blocks::{Block, BlockHeader}, chain_storage::{BlockchainBackend, ChainBlock}, proof_of_work::AchievedTargetDifficulty, - transactions::{transaction::Transaction, types::Commitment}, + transactions::transaction::Transaction, validation::{error::ValidationError, DifficultyCalculator}, }; -use tari_common_types::chain_metadata::ChainMetadata; +use tari_common_types::{chain_metadata::ChainMetadata, types::Commitment}; /// A validator 
that determines if a block body is valid, assuming that the header has already been /// validated diff --git a/base_layer/core/src/validation/transaction_validators.rs b/base_layer/core/src/validation/transaction_validators.rs index 59ff3bfc41..4f136aeea5 100644 --- a/base_layer/core/src/validation/transaction_validators.rs +++ b/base_layer/core/src/validation/transaction_validators.rs @@ -20,14 +20,15 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use log::*; + use crate::{ blocks::BlockValidationError, chain_storage::{BlockchainBackend, BlockchainDatabase, MmrTree}, crypto::tari_utilities::Hashable, - transactions::{transaction::Transaction, types::CryptoFactories}, + transactions::{transaction::Transaction, CryptoFactories}, validation::{MempoolTransactionValidation, ValidationError}, }; -use log::*; pub const LOG_TARGET: &str = "c::val::transaction_validators"; @@ -40,17 +41,21 @@ pub const LOG_TARGET: &str = "c::val::transaction_validators"; /// This function does NOT check that inputs come from the UTXO set pub struct TxInternalConsistencyValidator { factories: CryptoFactories, + bypass_range_proof_verification: bool, } impl TxInternalConsistencyValidator { - pub fn new(factories: CryptoFactories) -> Self { - Self { factories } + pub fn new(factories: CryptoFactories, bypass_range_proof_verification: bool) -> Self { + Self { + factories, + bypass_range_proof_verification, + } } } impl MempoolTransactionValidation for TxInternalConsistencyValidator { fn validate(&self, tx: &Transaction) -> Result<(), ValidationError> { - tx.validate_internal_consistency(&self.factories, None) + tx.validate_internal_consistency(self.bypass_range_proof_verification, &self.factories, None) .map_err(ValidationError::TransactionError)?; Ok(()) } diff --git a/base_layer/core/tests/async_db.rs b/base_layer/core/tests/async_db.rs index 
afedd9c7d9..a8e7902ed5 100644 --- a/base_layer/core/tests/async_db.rs +++ b/base_layer/core/tests/async_db.rs @@ -21,16 +21,17 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // -#[allow(dead_code)] -mod helpers; +use std::ops::Deref; + +use tari_crypto::{commitment::HomomorphicCommitmentFactory, tari_utilities::Hashable}; use helpers::{ block_builders::chain_block_with_new_coinbase, database::create_orphan_block, sample_blockchains::{create_blockchain_db_no_cut_through, create_new_blockchain}, }; -use std::ops::Deref; use tari_common::configuration::Network; +use tari_common_types::types::CommitmentFactory; use tari_core::{ blocks::Block, chain_storage::{async_db::AsyncBlockchainDb, BlockAddResult, PrunedOutput}, @@ -38,13 +39,15 @@ use tari_core::{ helpers::schema_to_transaction, tari_amount::T, transaction::{TransactionOutput, UnblindedOutput}, - types::{CommitmentFactory, CryptoFactories}, + CryptoFactories, }, txn_schema, }; -use tari_crypto::{commitment::HomomorphicCommitmentFactory, tari_utilities::Hashable}; use tari_test_utils::runtime::test_async; +#[allow(dead_code)] +mod helpers; + /// Finds the UTXO in a block corresponding to the unblinded output. We have to search for outputs because UTXOs get /// sorted in blocks, and so the order they were inserted in can change. fn find_utxo(output: &UnblindedOutput, block: &Block, factory: &CommitmentFactory) -> Option<TransactionOutput> { diff --git a/base_layer/core/tests/base_node_rpc.rs b/base_layer/core/tests/base_node_rpc.rs index 9b96512d47..e8b627dd57 100644 --- a/base_layer/core/tests/base_node_rpc.rs +++ b/base_layer/core/tests/base_node_rpc.rs @@ -42,13 +42,10 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-mod helpers; - -use crate::helpers::{ - block_builders::{chain_block, create_genesis_block_with_coinbase_value}, - nodes::{BaseNodeBuilder, NodeInterfaces}, -}; use std::convert::TryFrom; + +use tempfile::{tempdir, TempDir}; + use tari_common::configuration::Network; use tari_comms::protocol::rpc::mock::RpcRequestMock; use tari_core::{ @@ -76,27 +73,30 @@ use tari_core::{ helpers::schema_to_transaction, tari_amount::{uT, T}, transaction::{TransactionOutput, UnblindedOutput}, - types::CryptoFactories, + CryptoFactories, }, txn_schema, }; -use tempfile::{tempdir, TempDir}; -use tokio::runtime::Runtime; -fn setup() -> ( +use crate::helpers::{ + block_builders::{chain_block, create_genesis_block_with_coinbase_value}, + nodes::{BaseNodeBuilder, NodeInterfaces}, +}; + +mod helpers; + +async fn setup() -> ( BaseNodeWalletRpcService, NodeInterfaces, RpcRequestMock, ConsensusManager, ChainBlock, UnblindedOutput, - Runtime, TempDir, ) { let network = NetworkConsensus::from(Network::LocalNet); let consensus_constants = network.create_consensus_constants(); let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let (block0, utxo0) = @@ -107,13 +107,14 @@ fn setup() -> ( let (mut base_node, _consensus_manager) = BaseNodeBuilder::new(network) .with_consensus_manager(consensus_manager.clone()) - .start(&mut runtime, temp_dir.path().to_str().unwrap()); + .start(temp_dir.path().to_str().unwrap()) + .await; base_node.mock_base_node_state_machine.publish_status(StatusInfo { bootstrapped: true, state_info: StateInfo::Listening(ListeningInfo::new(true)), }); - let request_mock = runtime.enter(|| RpcRequestMock::new(base_node.comms.peer_manager())); + let request_mock = RpcRequestMock::new(base_node.comms.peer_manager()); let service = BaseNodeWalletRpcService::new( base_node.blockchain_db.clone().into(), base_node.mempool_handle.clone(), @@ -126,16 +127,15 @@ fn setup() -> ( consensus_manager, block0, utxo0, - 
runtime, temp_dir, ) } -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn test_base_node_wallet_rpc() { +async fn test_base_node_wallet_rpc() { // Testing the submit_transaction() and transaction_query() rpc calls - let (service, mut base_node, request_mock, consensus_manager, block0, utxo0, mut runtime, _temp_dir) = setup(); + let (service, mut base_node, request_mock, consensus_manager, block0, utxo0, _temp_dir) = setup().await; let (txs1, utxos1) = schema_to_transaction(&[txn_schema!(from: vec![utxo0.clone()], to: vec![1 * T, 1 * T])]); let tx1 = (*txs1[0]).clone(); @@ -151,8 +151,8 @@ fn test_base_node_wallet_rpc() { // Query Tx1 let msg = SignatureProto::from(tx1_sig.clone()); let req = request_mock.request_with_context(Default::default(), msg); - let resp = - TxQueryResponse::try_from(runtime.block_on(service.transaction_query(req)).unwrap().into_message()).unwrap(); + let resp = service.transaction_query(req).await.unwrap().into_message(); + let resp = TxQueryResponse::try_from(resp).unwrap(); assert_eq!(resp.confirmations, 0); assert_eq!(resp.block_hash, None); @@ -162,13 +162,7 @@ fn test_base_node_wallet_rpc() { let msg = TransactionProto::from(tx2.clone()); let req = request_mock.request_with_context(Default::default(), msg); - let resp = TxSubmissionResponse::try_from( - runtime - .block_on(service.submit_transaction(req)) - .unwrap() - .into_message(), - ) - .unwrap(); + let resp = TxSubmissionResponse::try_from(service.submit_transaction(req).await.unwrap().into_message()).unwrap(); assert!(!resp.accepted); assert_eq!(resp.rejection_reason, TxSubmissionRejectionReason::Orphan); @@ -176,8 +170,7 @@ fn test_base_node_wallet_rpc() { // Query Tx2 to confirm it wasn't accepted let msg = SignatureProto::from(tx2_sig.clone()); let req = request_mock.request_with_context(Default::default(), msg); - let resp = - TxQueryResponse::try_from(runtime.block_on(service.transaction_query(req)).unwrap().into_message()).unwrap(); + let resp = 
TxQueryResponse::try_from(service.transaction_query(req).await.unwrap().into_message()).unwrap(); assert_eq!(resp.confirmations, 0); assert_eq!(resp.block_hash, None); @@ -189,24 +182,22 @@ .prepare_block_merkle_roots(chain_block(&block0.block(), vec![tx1.clone()], &consensus_manager)) .unwrap(); - assert!(runtime - .block_on(base_node.local_nci.submit_block(block1.clone(), Broadcast::from(true))) - .is_ok()); + base_node + .local_nci + .submit_block(block1.clone(), Broadcast::from(true)) + .await + .unwrap(); // Check that submitting Tx2 will now be accepted let msg = TransactionProto::from(tx2); let req = request_mock.request_with_context(Default::default(), msg); - let resp = runtime - .block_on(service.submit_transaction(req)) - .unwrap() - .into_message(); + let resp = service.submit_transaction(req).await.unwrap().into_message(); assert!(resp.accepted); // Query Tx2 which should now be in the mempool let msg = SignatureProto::from(tx2_sig.clone()); let req = request_mock.request_with_context(Default::default(), msg); - let resp = - TxQueryResponse::try_from(runtime.block_on(service.transaction_query(req)).unwrap().into_message()).unwrap(); + let resp = TxQueryResponse::try_from(service.transaction_query(req).await.unwrap().into_message()).unwrap(); assert_eq!(resp.confirmations, 0); assert_eq!(resp.block_hash, None); @@ -215,13 +206,7 @@ // Now if we submit Tx1 it should return as rejected as AlreadyMined as Tx1's kernel is present let msg = TransactionProto::from(tx1); let req = request_mock.request_with_context(Default::default(), msg); - let resp = TxSubmissionResponse::try_from( - runtime - .block_on(service.submit_transaction(req)) - .unwrap() - .into_message(), - ) - .unwrap(); + let resp = TxSubmissionResponse::try_from(service.submit_transaction(req).await.unwrap().into_message()).unwrap(); assert!(!resp.accepted); assert_eq!(resp.rejection_reason,
TxSubmissionRejectionReason::AlreadyMined); @@ -233,13 +218,7 @@ // Now if we submit Tx1 it should return as rejected as AlreadyMined let msg = TransactionProto::from(tx1b); let req = request_mock.request_with_context(Default::default(), msg); - let resp = TxSubmissionResponse::try_from( - runtime - .block_on(service.submit_transaction(req)) - .unwrap() - .into_message(), - ) - .unwrap(); + let resp = TxSubmissionResponse::try_from(service.submit_transaction(req).await.unwrap().into_message()).unwrap(); assert!(!resp.accepted); assert_eq!(resp.rejection_reason, TxSubmissionRejectionReason::DoubleSpend); @@ -253,15 +232,16 @@ block2.header.output_mmr_size += 1; block2.header.kernel_mmr_size += 1; - runtime - .block_on(base_node.local_nci.submit_block(block2, Broadcast::from(true))) + base_node + .local_nci + .submit_block(block2, Broadcast::from(true)) + .await .unwrap(); // Query Tx1 which should be in block 1 with 1 confirmation let msg = SignatureProto::from(tx1_sig.clone()); let req = request_mock.request_with_context(Default::default(), msg); - let resp = - TxQueryResponse::try_from(runtime.block_on(service.transaction_query(req)).unwrap().into_message()).unwrap(); + let resp = TxQueryResponse::try_from(service.transaction_query(req).await.unwrap().into_message()).unwrap(); assert_eq!(resp.confirmations, 1); assert_eq!(resp.block_hash, Some(block1.hash())); @@ -271,10 +251,7 @@ sigs: vec![SignatureProto::from(tx1_sig.clone()), SignatureProto::from(tx2_sig)], }; let req = request_mock.request_with_context(Default::default(), msg); - let response = runtime - .block_on(service.transaction_batch_query(req)) - .unwrap() - .into_message(); + let response = service.transaction_batch_query(req).await.unwrap().into_message(); for r in response.responses { let response = TxQueryBatchResponse::try_from(r).unwrap(); @@ -299,10 +276,7 @@ let
req = request_mock.request_with_context(Default::default(), msg); - let response = runtime - .block_on(service.fetch_matching_utxos(req)) - .unwrap() - .into_message(); + let response = service.fetch_matching_utxos(req).await.unwrap().into_message(); assert_eq!(response.outputs.len(), utxos1.len()); for output_proto in response.outputs.iter() { diff --git a/base_layer/core/tests/block_validation.rs b/base_layer/core/tests/block_validation.rs index 6961e48dfd..1a91df6012 100644 --- a/base_layer/core/tests/block_validation.rs +++ b/base_layer/core/tests/block_validation.rs @@ -20,9 +20,11 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::helpers::{block_builders::chain_block_with_new_coinbase, test_blockchain::TestBlockchain}; -use monero::blockdata::block::Block as MoneroBlock; use std::sync::Arc; + +use monero::blockdata::block::Block as MoneroBlock; +use tari_crypto::inputs; + use tari_common::configuration::Network; use tari_core::{ blocks::{Block, BlockHeaderValidationError, BlockValidationError}, @@ -38,7 +40,7 @@ use tari_core::{ transactions::{ helpers::{schema_to_transaction, TestParams, UtxoTestParams}, tari_amount::T, - types::CryptoFactories, + CryptoFactories, }, txn_schema, validation::{ @@ -50,7 +52,8 @@ use tari_core::{ ValidationError, }, }; -use tari_crypto::inputs; + +use crate::helpers::{block_builders::chain_block_with_new_coinbase, test_blockchain::TestBlockchain}; mod helpers; @@ -63,7 +66,7 @@ fn test_genesis_block() { let validators = Validators::new( BodyOnlyValidator::default(), HeaderValidator::new(rules.clone()), - OrphanBlockValidator::new(rules.clone(), factories), + OrphanBlockValidator::new(rules.clone(), false, factories), ); let db = BlockchainDatabase::new( backend, @@ -216,7 +219,7 @@ fn inputs_are_not_malleable() { input_mut.input_data = malicious_input.input_data; 
input_mut.script_signature = malicious_input.script_signature; - let validator = BlockValidator::new(blockchain.consensus_manager().clone(), CryptoFactories::default()); + let validator = BlockValidator::new(blockchain.consensus_manager().clone(), true, CryptoFactories::default()); let err = validator .validate_body(&block, &*blockchain.store().db_read_access().unwrap()) .unwrap_err(); diff --git a/base_layer/core/tests/chain_storage_tests/chain_storage.rs b/base_layer/core/tests/chain_storage_tests/chain_storage.rs index b716858d4c..98a3aac848 100644 --- a/base_layer/core/tests/chain_storage_tests/chain_storage.rs +++ b/base_layer/core/tests/chain_storage_tests/chain_storage.rs @@ -20,24 +20,9 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -// use crate::helpers::database::create_test_db; -// use crate::helpers::database::create_store; -use crate::helpers::{ - block_builders::{ - append_block, - chain_block, - create_chain_header, - create_genesis_block, - find_header_with_achieved_difficulty, - generate_new_block, - generate_new_block_with_achieved_difficulty, - generate_new_block_with_coinbase, - }, - database::create_orphan_block, - sample_blockchains::{create_new_blockchain, create_new_blockchain_lmdb}, - test_blockchain::TestBlockchain, -}; use rand::{rngs::OsRng, RngCore}; +use tari_crypto::{script::StackItem, tari_utilities::Hashable}; + use tari_common::configuration::Network; use tari_common_types::types::BlockHash; use tari_core::{ @@ -63,16 +48,33 @@ use tari_core::{ transactions::{ helpers::{schema_to_transaction, spend_utxos}, tari_amount::{uT, MicroTari, T}, - types::CryptoFactories, + CryptoFactories, }, tx, txn_schema, validation::{mocks::MockValidator, DifficultyCalculator, ValidationError}, }; -use tari_crypto::{script::StackItem, tari_utilities::Hashable}; use tari_storage::lmdb_store::LMDBConfig; 
use tari_test_utils::{paths::create_temporary_data_path, unpack_enum}; +// use crate::helpers::database::create_test_db; +// use crate::helpers::database::create_store; +use crate::helpers::{ + block_builders::{ + append_block, + chain_block, + create_chain_header, + create_genesis_block, + find_header_with_achieved_difficulty, + generate_new_block, + generate_new_block_with_achieved_difficulty, + generate_new_block_with_coinbase, + }, + database::create_orphan_block, + sample_blockchains::{create_new_blockchain, create_new_blockchain_lmdb}, + test_blockchain::TestBlockchain, +}; + #[test] fn fetch_nonexistent_header() { let network = Network::LocalNet; diff --git a/base_layer/core/tests/helpers/block_builders.rs b/base_layer/core/tests/helpers/block_builders.rs index 6ff5c2102a..12d4659d62 100644 --- a/base_layer/core/tests/helpers/block_builders.rs +++ b/base_layer/core/tests/helpers/block_builders.rs @@ -20,10 +20,18 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+use std::{iter::repeat_with, sync::Arc}; + use croaring::Bitmap; use rand::{rngs::OsRng, RngCore}; -use std::{iter::repeat_with, sync::Arc}; +use tari_crypto::{ + keys::PublicKey as PublicKeyTrait, + script, + tari_utilities::{hash::Hashable, hex::Hex}, +}; + use tari_common::configuration::Network; +use tari_common_types::types::{Commitment, HashDigest, HashOutput, PublicKey}; use tari_core::{ blocks::{Block, BlockHeader, NewBlockTemplate}, chain_storage::{ @@ -57,14 +65,9 @@ use tari_core::{ TransactionOutput, UnblindedOutput, }, - types::{Commitment, CryptoFactories, HashDigest, HashOutput, PublicKey}, + CryptoFactories, }, }; -use tari_crypto::{ - keys::PublicKey as PublicKeyTrait, - script, - tari_utilities::{hash::Hashable, hex::Hex}, -}; use tari_mmr::MutableMmr; pub fn create_coinbase( diff --git a/base_layer/core/tests/helpers/database.rs b/base_layer/core/tests/helpers/database.rs index 0bafeed84b..e1132445a3 100644 --- a/base_layer/core/tests/helpers/database.rs +++ b/base_layer/core/tests/helpers/database.rs @@ -20,13 +20,14 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-use crate::helpers::block_builders::create_coinbase; use tari_core::{ blocks::{Block, BlockHeader, NewBlockTemplate}, consensus::{emission::Emission, ConsensusManager}, - transactions::{tari_amount::MicroTari, transaction::Transaction, types::CryptoFactories}, + transactions::{tari_amount::MicroTari, transaction::Transaction, CryptoFactories}, }; +use crate::helpers::block_builders::create_coinbase; + // use tari_test_utils::paths::create_temporary_data_path; /// Create a partially constructed block using the provided set of transactions diff --git a/base_layer/core/tests/helpers/event_stream.rs b/base_layer/core/tests/helpers/event_stream.rs index b79b494900..5485467f4c 100644 --- a/base_layer/core/tests/helpers/event_stream.rs +++ b/base_layer/core/tests/helpers/event_stream.rs @@ -20,16 +20,17 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::{future, future::Either, FutureExt, Stream, StreamExt}; use std::time::Duration; +use tokio::{sync::broadcast, time}; #[allow(dead_code)] -pub async fn event_stream_next<TStream>(stream: &mut TStream, timeout: Duration) -> Option<TStream::Item> -where TStream: Stream + Unpin { - let either = future::select(stream.next(), tokio::time::delay_for(timeout).fuse()).await; - - match either { - Either::Left((v, _)) => v, - Either::Right(_) => None, +pub async fn event_stream_next<T: Clone>(stream: &mut broadcast::Receiver<T>, timeout: Duration) -> Option<T> { + tokio::select!
{ + item = stream.recv() => match item { + Ok(item) => Some(item), + Err(broadcast::error::RecvError::Closed) => None, + Err(broadcast::error::RecvError::Lagged(n)) => panic!("Lagged events channel {}", n), + }, + _ = time::sleep(timeout) => None } } diff --git a/base_layer/core/tests/helpers/mock_state_machine.rs b/base_layer/core/tests/helpers/mock_state_machine.rs index 7d49f93e85..0d4b6ce512 100644 --- a/base_layer/core/tests/helpers/mock_state_machine.rs +++ b/base_layer/core/tests/helpers/mock_state_machine.rs @@ -40,7 +40,7 @@ impl MockBaseNodeStateMachine { } pub fn publish_status(&mut self, status: StatusInfo) { - let _ = self.status_sender.broadcast(status); + let _ = self.status_sender.send(status); } pub fn get_initializer(&self) -> MockBaseNodeStateMachineInitializer { diff --git a/base_layer/core/tests/helpers/nodes.rs b/base_layer/core/tests/helpers/nodes.rs index ffe69c8034..06f2d5f8e0 100644 --- a/base_layer/core/tests/helpers/nodes.rs +++ b/base_layer/core/tests/helpers/nodes.rs @@ -21,9 +21,8 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::helpers::mock_state_machine::MockBaseNodeStateMachine; -use futures::Sink; use rand::rngs::OsRng; -use std::{error::Error, path::Path, sync::Arc, time::Duration}; +use std::{path::Path, sync::Arc, time::Duration}; use tari_common::configuration::Network; use tari_comms::{ peer_manager::{NodeIdentity, PeerFeatures}, @@ -60,13 +59,12 @@ use tari_core::{ }, }; use tari_p2p::{ - comms_connector::{pubsub_connector, InboundDomainConnector, PeerMessage}, + comms_connector::{pubsub_connector, InboundDomainConnector}, initialization::initialize_local_test_comms, services::liveness::{LivenessConfig, LivenessHandle, LivenessInitializer}, }; use tari_service_framework::{RegisterHandle, StackBuilder}; use tari_shutdown::Shutdown; -use tokio::runtime::Runtime; /// The NodeInterfaces is used as a container for providing access to all the services and interfaces of a single node. 
pub struct NodeInterfaces { @@ -91,7 +89,7 @@ pub struct NodeInterfaces { #[allow(dead_code)] impl NodeInterfaces { pub async fn shutdown(mut self) { - self.shutdown.trigger().unwrap(); + self.shutdown.trigger(); self.comms.wait_until_shutdown().await; } } @@ -182,7 +180,7 @@ impl BaseNodeBuilder { /// Build the test base node and start its services. #[allow(clippy::redundant_closure)] - pub fn start(self, runtime: &mut Runtime, data_path: &str) -> (NodeInterfaces, ConsensusManager) { + pub async fn start(self, data_path: &str) -> (NodeInterfaces, ConsensusManager) { let validators = self.validators.unwrap_or_else(|| { Validators::new( MockValidator::new(true), @@ -199,7 +197,6 @@ impl BaseNodeBuilder { let mempool = Mempool::new(self.mempool_config.unwrap_or_default(), Arc::new(mempool_validator)); let node_identity = self.node_identity.unwrap_or_else(|| random_node_identity()); let node_interfaces = setup_base_node_services( - runtime, node_identity, self.peers.unwrap_or_default(), blockchain_db, @@ -209,17 +206,19 @@ impl BaseNodeBuilder { self.mempool_service_config.unwrap_or_default(), self.liveness_service_config.unwrap_or_default(), data_path, - ); + ) + .await; (node_interfaces, consensus_manager) } } -#[allow(dead_code)] -pub fn wait_until_online(runtime: &mut Runtime, nodes: &[&NodeInterfaces]) { +pub async fn wait_until_online(nodes: &[&NodeInterfaces]) { for node in nodes { - runtime - .block_on(node.comms.connectivity().wait_for_connectivity(Duration::from_secs(10))) + node.comms + .connectivity() + .wait_for_connectivity(Duration::from_secs(10)) + .await .map_err(|err| format!("Node '{}' failed to go online {:?}", node.node_identity.node_id(), err)) .unwrap(); } @@ -227,10 +226,7 @@ pub fn wait_until_online(runtime: &mut Runtime, nodes: &[&NodeInterfaces]) { // Creates a network with two Base Nodes where each node in the network knows the other nodes in the network. 
#[allow(dead_code)] -pub fn create_network_with_2_base_nodes( - runtime: &mut Runtime, - data_path: &str, -) -> (NodeInterfaces, NodeInterfaces, ConsensusManager) { +pub async fn create_network_with_2_base_nodes(data_path: &str) -> (NodeInterfaces, NodeInterfaces, ConsensusManager) { let alice_node_identity = random_node_identity(); let bob_node_identity = random_node_identity(); @@ -238,22 +234,23 @@ pub fn create_network_with_2_base_nodes( let (alice_node, consensus_manager) = BaseNodeBuilder::new(network.into()) .with_node_identity(alice_node_identity.clone()) .with_peers(vec![bob_node_identity.clone()]) - .start(runtime, data_path); + .start(data_path) + .await; let (bob_node, consensus_manager) = BaseNodeBuilder::new(network.into()) .with_node_identity(bob_node_identity) .with_peers(vec![alice_node_identity]) .with_consensus_manager(consensus_manager) - .start(runtime, data_path); + .start(data_path) + .await; - wait_until_online(runtime, &[&alice_node, &bob_node]); + wait_until_online(&[&alice_node, &bob_node]).await; (alice_node, bob_node, consensus_manager) } // Creates a network with two Base Nodes where each node in the network knows the other nodes in the network. 
#[allow(dead_code)] -pub fn create_network_with_2_base_nodes_with_config<P: AsRef<Path>>( - runtime: &mut Runtime, +pub async fn create_network_with_2_base_nodes_with_config<P: AsRef<Path>>( base_node_service_config: BaseNodeServiceConfig, mempool_service_config: MempoolServiceConfig, liveness_service_config: LivenessConfig, @@ -269,7 +266,8 @@ pub fn create_network_with_2_base_nodes_with_config<P: AsRef<Path>>( .with_mempool_service_config(mempool_service_config) .with_liveness_service_config(liveness_service_config.clone()) .with_consensus_manager(consensus_manager) - .start(runtime, data_path.as_ref().join("alice").as_os_str().to_str().unwrap()); + .start(data_path.as_ref().join("alice").as_os_str().to_str().unwrap()) + .await; let (bob_node, consensus_manager) = BaseNodeBuilder::new(network.into()) .with_node_identity(bob_node_identity) .with_peers(vec![alice_node_identity]) @@ -277,35 +275,34 @@ pub fn create_network_with_2_base_nodes_with_config<P: AsRef<Path>>( .with_mempool_service_config(mempool_service_config) .with_liveness_service_config(liveness_service_config) .with_consensus_manager(consensus_manager) - .start(runtime, data_path.as_ref().join("bob").as_os_str().to_str().unwrap()); + .start(data_path.as_ref().join("bob").as_os_str().to_str().unwrap()) + .await; - wait_until_online(runtime, &[&alice_node, &bob_node]); + wait_until_online(&[&alice_node, &bob_node]).await; (alice_node, bob_node, consensus_manager) } // Creates a network with three Base Nodes where each node in the network knows the other nodes in the network.
#[allow(dead_code)] -pub fn create_network_with_3_base_nodes( - runtime: &mut Runtime, +pub async fn create_network_with_3_base_nodes( data_path: &str, ) -> (NodeInterfaces, NodeInterfaces, NodeInterfaces, ConsensusManager) { let network = Network::LocalNet; let consensus_manager = ConsensusManagerBuilder::new(network).build(); create_network_with_3_base_nodes_with_config( - runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager, data_path, ) + .await } // Creates a network with three Base Nodes where each node in the network knows the other nodes in the network. #[allow(dead_code)] -pub fn create_network_with_3_base_nodes_with_config<P: AsRef<Path>>( - runtime: &mut Runtime, +pub async fn create_network_with_3_base_nodes_with_config<P: AsRef<Path>>( base_node_service_config: BaseNodeServiceConfig, mempool_service_config: MempoolServiceConfig, liveness_service_config: LivenessConfig, @@ -329,7 +326,8 @@ pub fn create_network_with_3_base_nodes_with_config<P: AsRef<Path>>( .with_mempool_service_config(mempool_service_config) .with_liveness_service_config(liveness_service_config.clone()) .with_consensus_manager(consensus_manager) - .start(runtime, data_path.as_ref().join("carol").as_os_str().to_str().unwrap()); + .start(data_path.as_ref().join("carol").as_os_str().to_str().unwrap()) + .await; let (bob_node, consensus_manager) = BaseNodeBuilder::new(network.into()) .with_node_identity(bob_node_identity.clone()) .with_peers(vec![carol_node_identity.clone()]) @@ -337,7 +335,8 @@ pub fn create_network_with_3_base_nodes_with_config<P: AsRef<Path>>( .with_mempool_service_config(mempool_service_config) .with_liveness_service_config(liveness_service_config.clone()) .with_consensus_manager(consensus_manager) - .start(runtime, data_path.as_ref().join("bob").as_os_str().to_str().unwrap()); + .start(data_path.as_ref().join("bob").as_os_str().to_str().unwrap()) + .await; let (alice_node, consensus_manager) = BaseNodeBuilder::new(network.into())
.with_node_identity(alice_node_identity) .with_peers(vec![bob_node_identity, carol_node_identity]) @@ -345,9 +344,10 @@ pub fn create_network_with_3_base_nodes_with_config<P: AsRef<Path>>( .with_mempool_service_config(mempool_service_config) .with_liveness_service_config(liveness_service_config) .with_consensus_manager(consensus_manager) - .start(runtime, data_path.as_ref().join("alice").as_os_str().to_str().unwrap()); + .start(data_path.as_ref().join("alice").as_os_str().to_str().unwrap()) + .await; - wait_until_online(runtime, &[&alice_node, &bob_node, &carol_node]); + wait_until_online(&[&alice_node, &bob_node, &carol_node]).await; (alice_node, bob_node, carol_node, consensus_manager) } @@ -365,16 +365,12 @@ pub fn random_node_identity() -> Arc<NodeIdentity> { // Helper function for starting the comms stack. #[allow(dead_code)] -async fn setup_comms_services<TSink>( +async fn setup_comms_services( node_identity: Arc<NodeIdentity>, peers: Vec<Arc<NodeIdentity>>, - publisher: InboundDomainConnector<TSink>, + publisher: InboundDomainConnector, data_path: &str, -) -> (CommsNode, Dht, MessagingEventSender, Shutdown) -where - TSink: Sink<Arc<PeerMessage>> + Clone + Unpin + Send + Sync + 'static, - TSink::Error: Error + Send + Sync, -{ +) -> (CommsNode, Dht, MessagingEventSender, Shutdown) { let peers = peers.into_iter().map(|p| p.to_peer()).collect(); let shutdown = Shutdown::new(); let (comms, dht, messaging_events) = initialize_local_test_comms( @@ -393,8 +389,7 @@ where // Helper function for starting the services of the Base node.
#[allow(clippy::too_many_arguments)] -fn setup_base_node_services( - runtime: &mut Runtime, +async fn setup_base_node_services( node_identity: Arc<NodeIdentity>, peers: Vec<Arc<NodeIdentity>>, blockchain_db: BlockchainDatabase<TempDatabase>, @@ -405,14 +400,14 @@ fn setup_base_node_services( liveness_service_config: LivenessConfig, data_path: &str, ) -> NodeInterfaces { - let (publisher, subscription_factory) = pubsub_connector(runtime.handle().clone(), 100, 20); + let (publisher, subscription_factory) = pubsub_connector(100, 20); let subscription_factory = Arc::new(subscription_factory); let (comms, dht, messaging_events, shutdown) = - runtime.block_on(setup_comms_services(node_identity.clone(), peers, publisher, data_path)); + setup_comms_services(node_identity.clone(), peers, publisher, data_path).await; let mock_state_machine = MockBaseNodeStateMachine::new(); - let fut = StackBuilder::new(shutdown.to_signal()) + let handles = StackBuilder::new(shutdown.to_signal()) .add_initializer(RegisterHandle::new(dht)) .add_initializer(RegisterHandle::new(comms.connectivity())) .add_initializer(LivenessInitializer::new( @@ -433,9 +428,9 @@ fn setup_base_node_services( )) .add_initializer(mock_state_machine.get_initializer()) .add_initializer(ChainMetadataServiceInitializer) - .build(); - - let handles = runtime.block_on(fut).expect("Service initialization failed"); + .build() + .await + .unwrap(); let outbound_nci = handles.expect_handle::<OutboundNodeCommsInterface>(); let local_nci = handles.expect_handle::<LocalNodeCommsInterface>(); diff --git a/base_layer/core/tests/helpers/sample_blockchains.rs b/base_layer/core/tests/helpers/sample_blockchains.rs index ddc6398fed..108d77c248 100644 --- a/base_layer/core/tests/helpers/sample_blockchains.rs +++ b/base_layer/core/tests/helpers/sample_blockchains.rs @@ -21,8 +21,6 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
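The `StackBuilder` change above awaits `build()` directly and then pulls service handles out of the result with `expect_handle::<T>()`. As a rough, std-only sketch of that type-keyed handle-registry idea (all names here are hypothetical, not the actual `tari_service_framework` API):

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;

// A minimal handle registry: services register a handle keyed by its type,
// and consumers fetch it back with a typed lookup -- the idea behind
// `handles.expect_handle::<T>()` in the service stack above.
#[derive(Default)]
struct Handles {
    map: HashMap<TypeId, Box<dyn Any>>,
}

impl Handles {
    fn register<T: Any>(&mut self, handle: T) {
        self.map.insert(TypeId::of::<T>(), Box::new(handle));
    }

    // Panics if the handle was never registered, mirroring "expect" semantics.
    fn expect_handle<T: Any + Clone>(&self) -> T {
        self.map
            .get(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast_ref::<T>())
            .cloned()
            .expect("handle not registered")
    }
}

#[derive(Clone, Debug, PartialEq)]
struct MempoolHandle(&'static str);

fn main() {
    let mut handles = Handles::default();
    handles.register(MempoolHandle("mempool"));
    println!("{:?}", handles.expect_handle::<MempoolHandle>());
}
```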
// -use crate::helpers::block_builders::{create_genesis_block, generate_new_block}; - use tari_common::configuration::Network; use tari_core::{ chain_storage::{ @@ -38,12 +36,15 @@ use tari_core::{ transactions::{ tari_amount::{uT, T}, transaction::UnblindedOutput, - types::CryptoFactories, + CryptoFactories, }, txn_schema, validation::DifficultyCalculator, }; use tari_storage::lmdb_store::LMDBConfig; + +use crate::helpers::block_builders::{create_genesis_block, generate_new_block}; + // use crate::helpers::database::{TempDatabase, create_store_with_consensus}; static EMISSION: [u64; 2] = [10, 10]; diff --git a/base_layer/core/tests/helpers/test_blockchain.rs b/base_layer/core/tests/helpers/test_blockchain.rs index e961cb14de..43291d5f4e 100644 --- a/base_layer/core/tests/helpers/test_blockchain.rs +++ b/base_layer/core/tests/helpers/test_blockchain.rs @@ -21,24 +21,27 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // -use crate::helpers::{ - block_builders::{chain_block_with_new_coinbase, find_header_with_achieved_difficulty}, - block_proxy::BlockProxy, - sample_blockchains::create_new_blockchain, - test_block_builder::{TestBlockBuilder, TestBlockBuilderInner}, -}; +use std::{collections::HashMap, sync::Arc}; + use log::*; use rand::{rngs::OsRng, RngCore}; -use std::{collections::HashMap, sync::Arc}; +use tari_crypto::tari_utilities::Hashable; + use tari_common::configuration::Network; use tari_core::{ blocks::Block, chain_storage::{BlockAddResult, BlockchainDatabase, ChainStorageError}, consensus::ConsensusManager, test_helpers::blockchain::TempDatabase, - transactions::{transaction::UnblindedOutput, types::CryptoFactories}, + transactions::{transaction::UnblindedOutput, CryptoFactories}, +}; + +use crate::helpers::{ + block_builders::{chain_block_with_new_coinbase, find_header_with_achieved_difficulty}, + block_proxy::BlockProxy, + sample_blockchains::create_new_blockchain, + test_block_builder::{TestBlockBuilder, 
TestBlockBuilderInner}, }; -use tari_crypto::tari_utilities::Hashable; const LOG_TARGET: &str = "tari_core::tests::helpers::test_blockchain"; diff --git a/base_layer/core/tests/mempool.rs b/base_layer/core/tests/mempool.rs index 80e187a99a..ae79003806 100644 --- a/base_layer/core/tests/mempool.rs +++ b/base_layer/core/tests/mempool.rs @@ -20,8 +20,11 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -#[allow(dead_code)] -mod helpers; +// use crate::helpers::database::create_store; +use std::{ops::Deref, sync::Arc, time::Duration}; + +use tari_crypto::{keys::PublicKey as PublicKeyTrait, script}; +use tempfile::tempdir; use helpers::{ block_builders::{ @@ -35,10 +38,8 @@ use helpers::{ nodes::{create_network_with_2_base_nodes_with_config, create_network_with_3_base_nodes_with_config}, sample_blockchains::{create_new_blockchain, create_new_blockchain_with_constants}, }; -use tari_crypto::keys::PublicKey as PublicKeyTrait; -// use crate::helpers::database::create_store; -use std::{ops::Deref, sync::Arc, time::Duration}; use tari_common::configuration::Network; +use tari_common_types::types::{Commitment, PrivateKey, PublicKey, Signature}; use tari_comms_dht::domain_message::OutboundDomainMessage; use tari_core::{ base_node::{ @@ -56,21 +57,20 @@ use tari_core::{ tari_amount::{uT, MicroTari, T}, transaction::{KernelBuilder, OutputFeatures, Transaction, TransactionOutput}, transaction_protocol::{build_challenge, TransactionMetadata}, - types::{Commitment, CryptoFactories, PrivateKey, PublicKey, Signature}, + CryptoFactories, }, tx, txn_schema, validation::transaction_validators::{TxConsensusValidator, TxInputAndMaturityValidator}, }; -use tari_crypto::script; use tari_p2p::{services::liveness::LivenessConfig, tari_message::TariMessageType}; use tari_test_utils::async_assert_eventually; -use tempfile::tempdir; -use 
tokio::runtime::Runtime; +#[allow(dead_code)] +mod helpers; -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn test_insert_and_process_published_block() { +async fn test_insert_and_process_published_block() { let network = Network::LocalNet; let (mut store, mut blocks, mut outputs, consensus_manager) = create_new_blockchain(network); let mempool_validator = TxInputAndMaturityValidator::new(store.clone()); @@ -201,9 +201,9 @@ fn test_insert_and_process_published_block() { assert_eq!(stats.total_weight, 30); } -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn test_time_locked() { +async fn test_time_locked() { let network = Network::LocalNet; let (mut store, mut blocks, mut outputs, consensus_manager) = create_new_blockchain(network); let mempool_validator = TxInputAndMaturityValidator::new(store.clone()); @@ -245,9 +245,9 @@ fn test_time_locked() { assert_eq!(mempool.insert(tx2).unwrap(), TxStorageResponse::UnconfirmedPool); } -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn test_retrieve() { +async fn test_retrieve() { let network = Network::LocalNet; let (mut store, mut blocks, mut outputs, consensus_manager) = create_new_blockchain(network); let mempool_validator = TxInputAndMaturityValidator::new(store.clone()); @@ -331,9 +331,9 @@ fn test_retrieve() { assert!(retrieved_txs.contains(&tx2[1])); } -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn test_zero_conf() { +async fn test_zero_conf() { let network = Network::LocalNet; let (mut store, mut blocks, mut outputs, consensus_manager) = create_new_blockchain(network); let mempool_validator = TxInputAndMaturityValidator::new(store.clone()); @@ -631,9 +631,9 @@ fn test_zero_conf() { assert!(retrieved_txs.contains(&Arc::new(tx34))); } -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn test_reorg() { +async fn test_reorg() { let network = Network::LocalNet; let (mut db, mut blocks, mut outputs, consensus_manager) = create_new_blockchain(network); let mempool_validator = 
TxInputAndMaturityValidator::new(db.clone()); @@ -712,13 +712,12 @@ fn test_reorg() { mempool.process_reorg(vec![], vec![reorg_block4.into()]).unwrap(); } -#[test] // TODO: This test returns 0 in the unconfirmed pool, so might not catch errors. It should be updated to return better // data #[allow(clippy::identity_op)] -fn request_response_get_stats() { +#[tokio::test] +async fn request_response_get_stats() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; let consensus_constants = ConsensusConstantsBuilder::new(network) @@ -731,13 +730,13 @@ fn request_response_get_stats() { .with_block(block0) .build(); let (mut alice, bob, _consensus_manager) = create_network_with_2_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager, temp_dir.path(), - ); + ) + .await; // Create a tx spending the genesis output. 
Then create 2 orphan txs let (tx1, _, _) = spend_utxos(txn_schema!(from: vec![utxo], to: vec![2 * T, 2 * T, 2 * T])); @@ -759,21 +758,18 @@ fn request_response_get_stats() { assert_eq!(stats.reorg_txs, 0); assert_eq!(stats.total_weight, 0); - runtime.block_on(async { - // Alice will request mempool stats from Bob, and thus should be identical - let received_stats = alice.outbound_mp_interface.get_stats().await.unwrap(); - assert_eq!(received_stats.total_txs, 0); - assert_eq!(received_stats.unconfirmed_txs, 0); - assert_eq!(received_stats.reorg_txs, 0); - assert_eq!(received_stats.total_weight, 0); - }); + // Alice will request mempool stats from Bob, and thus should be identical + let received_stats = alice.outbound_mp_interface.get_stats().await.unwrap(); + assert_eq!(received_stats.total_txs, 0); + assert_eq!(received_stats.unconfirmed_txs, 0); + assert_eq!(received_stats.reorg_txs, 0); + assert_eq!(received_stats.total_weight, 0); } -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn request_response_get_tx_state_by_excess_sig() { +async fn request_response_get_tx_state_by_excess_sig() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; let consensus_constants = ConsensusConstantsBuilder::new(network) @@ -786,13 +782,13 @@ fn request_response_get_tx_state_by_excess_sig() { .with_block(block0) .build(); let (mut alice_node, bob_node, carol_node, _consensus_manager) = create_network_with_3_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager, temp_dir.path().to_str().unwrap(), - ); + ) + .await; let (tx, _, _) = spend_utxos(txn_schema!(from: vec![utxo.clone()], to: vec![2 * T, 2 * T, 2 * T])); let (unpublished_tx, _, _) = spend_utxos(txn_schema!(from: vec![utxo], to: vec![3 * T])); @@ -807,43 +803,40 @@ fn request_response_get_tx_state_by_excess_sig() { 
// Check that the transactions are in the expected pools. // Spending the coinbase utxo will be in the pending pool, because cb utxos have a maturity. // The orphan tx will be in the orphan pool, while the unadded tx won't be found - runtime.block_on(async { - let tx_excess_sig = tx.body.kernels()[0].excess_sig.clone(); - let unpublished_tx_excess_sig = unpublished_tx.body.kernels()[0].excess_sig.clone(); - let orphan_tx_excess_sig = orphan_tx.body.kernels()[0].excess_sig.clone(); - assert_eq!( - alice_node - .outbound_mp_interface - .get_tx_state_by_excess_sig(tx_excess_sig) - .await - .unwrap(), - TxStorageResponse::NotStored - ); - assert_eq!( - alice_node - .outbound_mp_interface - .get_tx_state_by_excess_sig(unpublished_tx_excess_sig) - .await - .unwrap(), - TxStorageResponse::NotStored - ); - assert_eq!( - alice_node - .outbound_mp_interface - .get_tx_state_by_excess_sig(orphan_tx_excess_sig) - .await - .unwrap(), - TxStorageResponse::NotStored - ); - }); + let tx_excess_sig = tx.body.kernels()[0].excess_sig.clone(); + let unpublished_tx_excess_sig = unpublished_tx.body.kernels()[0].excess_sig.clone(); + let orphan_tx_excess_sig = orphan_tx.body.kernels()[0].excess_sig.clone(); + assert_eq!( + alice_node + .outbound_mp_interface + .get_tx_state_by_excess_sig(tx_excess_sig) + .await + .unwrap(), + TxStorageResponse::NotStored + ); + assert_eq!( + alice_node + .outbound_mp_interface + .get_tx_state_by_excess_sig(unpublished_tx_excess_sig) + .await + .unwrap(), + TxStorageResponse::NotStored + ); + assert_eq!( + alice_node + .outbound_mp_interface + .get_tx_state_by_excess_sig(orphan_tx_excess_sig) + .await + .unwrap(), + TxStorageResponse::NotStored + ); } static EMISSION: [u64; 2] = [10, 10]; -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn receive_and_propagate_transaction() { +async fn receive_and_propagate_transaction() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let 
network = Network::LocalNet; let consensus_constants = ConsensusConstantsBuilder::new(network) @@ -857,13 +850,13 @@ fn receive_and_propagate_transaction() { .build(); let (mut alice_node, mut bob_node, mut carol_node, _consensus_manager) = create_network_with_3_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager, temp_dir.path().to_str().unwrap(), - ); + ) + .await; alice_node.mock_base_node_state_machine.publish_status(StatusInfo { bootstrapped: true, state_info: StateInfo::Listening(ListeningInfo::new(true)), @@ -884,63 +877,61 @@ fn receive_and_propagate_transaction() { assert!(alice_node.mempool.insert(Arc::new(tx.clone())).is_ok()); assert!(alice_node.mempool.insert(Arc::new(orphan.clone())).is_ok()); - runtime.block_on(async { - alice_node - .outbound_message_service - .send_direct( - bob_node.node_identity.public_key().clone(), - OutboundDomainMessage::new(TariMessageType::NewTransaction, proto::types::Transaction::from(tx)), - ) - .await - .unwrap(); - alice_node - .outbound_message_service - .send_direct( - carol_node.node_identity.public_key().clone(), - OutboundDomainMessage::new(TariMessageType::NewTransaction, proto::types::Transaction::from(orphan)), - ) - .await - .unwrap(); + alice_node + .outbound_message_service + .send_direct( + bob_node.node_identity.public_key().clone(), + OutboundDomainMessage::new(TariMessageType::NewTransaction, proto::types::Transaction::from(tx)), + ) + .await + .unwrap(); + alice_node + .outbound_message_service + .send_direct( + carol_node.node_identity.public_key().clone(), + OutboundDomainMessage::new(TariMessageType::NewTransaction, proto::types::Transaction::from(orphan)), + ) + .await + .unwrap(); - async_assert_eventually!( - bob_node.mempool.has_tx_with_excess_sig(tx_excess_sig.clone()).unwrap(), - expect = TxStorageResponse::NotStored, - max_attempts = 20, - interval = Duration::from_millis(1000) - ); - 
async_assert_eventually!( - carol_node - .mempool - .has_tx_with_excess_sig(tx_excess_sig.clone()) - .unwrap(), - expect = TxStorageResponse::NotStored, - max_attempts = 10, - interval = Duration::from_millis(1000) - ); - // Carol got sent the orphan tx directly, so it will be in her mempool - async_assert_eventually!( - carol_node - .mempool - .has_tx_with_excess_sig(orphan_excess_sig.clone()) - .unwrap(), - expect = TxStorageResponse::NotStored, - max_attempts = 10, - interval = Duration::from_millis(1000) - ); - // It's difficult to test a negative here, but let's at least make sure that the orphan TX was not propagated - // by the time we check it - async_assert_eventually!( - bob_node - .mempool - .has_tx_with_excess_sig(orphan_excess_sig.clone()) - .unwrap(), - expect = TxStorageResponse::NotStored, - ); - }); + async_assert_eventually!( + bob_node.mempool.has_tx_with_excess_sig(tx_excess_sig.clone()).unwrap(), + expect = TxStorageResponse::NotStored, + max_attempts = 20, + interval = Duration::from_millis(1000) + ); + async_assert_eventually!( + carol_node + .mempool + .has_tx_with_excess_sig(tx_excess_sig.clone()) + .unwrap(), + expect = TxStorageResponse::NotStored, + max_attempts = 10, + interval = Duration::from_millis(1000) + ); + // Carol got sent the orphan tx directly, so it will be in her mempool + async_assert_eventually!( + carol_node + .mempool + .has_tx_with_excess_sig(orphan_excess_sig.clone()) + .unwrap(), + expect = TxStorageResponse::NotStored, + max_attempts = 10, + interval = Duration::from_millis(1000) + ); + // It's difficult to test a negative here, but let's at least make sure that the orphan TX was not propagated + // by the time we check it + async_assert_eventually!( + bob_node + .mempool + .has_tx_with_excess_sig(orphan_excess_sig.clone()) + .unwrap(), + expect = TxStorageResponse::NotStored, + ); } -#[test] -fn consensus_validation_large_tx() { +#[tokio::test] +async fn consensus_validation_large_tx() { let network = 
Network::LocalNet; // We don't want to compute the 19500 limit of local net, so we create smaller blocks let consensus_constants = ConsensusConstantsBuilder::new(network) @@ -1029,7 +1020,7 @@ fn consensus_validation_large_tx() { // make sure the tx was correctly made and is valid let factories = CryptoFactories::default(); - assert!(tx.validate_internal_consistency(&factories, None).is_ok()); + assert!(tx.validate_internal_consistency(true, &factories, None).is_ok()); let weight = tx.calculate_weight(); let height = blocks.len() as u64; @@ -1042,9 +1033,8 @@ fn consensus_validation_large_tx() { assert!(matches!(response, TxStorageResponse::NotStored)); } -#[test] -fn service_request_timeout() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn service_request_timeout() { let network = Network::LocalNet; let consensus_manager = ConsensusManager::builder(network).build(); let mempool_service_config = MempoolServiceConfig { @@ -1053,27 +1043,25 @@ fn service_request_timeout() { }; let temp_dir = tempdir().unwrap(); let (mut alice_node, bob_node, _consensus_manager) = create_network_with_2_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), mempool_service_config, LivenessConfig::default(), consensus_manager, temp_dir.path().to_str().unwrap(), - ); + ) + .await; - runtime.block_on(async { - bob_node.shutdown().await; + bob_node.shutdown().await; - match alice_node.outbound_mp_interface.get_stats().await { - Err(MempoolServiceError::RequestTimedOut) => {}, - _ => panic!(), - } - }); + match alice_node.outbound_mp_interface.get_stats().await { + Err(MempoolServiceError::RequestTimedOut) => {}, + _ => panic!(), + } } -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn block_event_and_reorg_event_handling() { +async fn block_event_and_reorg_event_handling() { // This test creates 2 nodes Alice and Bob // Then creates 2 chains B1 -> B2A (diff 1) and B1 -> B2B (diff 10) // There are 5 transactions created @@ -1086,7 +1074,6
@@ fn block_event_and_reorg_event_handling() { let network = Network::LocalNet; let consensus_constants = NetworkConsensus::from(network).create_consensus_constants(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let (block0, utxos0) = create_genesis_block_with_coinbase_value(&factories, 100_000_000.into(), &consensus_constants[0]); @@ -1095,13 +1082,13 @@ fn block_event_and_reorg_event_handling() { .with_block(block0.clone()) .build(); let (mut alice, mut bob, consensus_manager) = create_network_with_2_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager, temp_dir.path().to_str().unwrap(), - ); + ) + .await; alice.mock_base_node_state_machine.publish_status(StatusInfo { bootstrapped: true, state_info: StateInfo::Listening(ListeningInfo::new(true)), @@ -1135,88 +1122,86 @@ fn block_event_and_reorg_event_handling() { .prepare_block_merkle_roots(chain_block(block0.block(), vec![], &consensus_manager)) .unwrap(); - runtime.block_on(async { - // Add one empty block, so the coinbase UTXO is no longer time-locked. - assert!(bob - .local_nci - .submit_block(empty_block.clone(), Broadcast::from(true)) - .await - .is_ok()); - assert!(alice - .local_nci - .submit_block(empty_block.clone(), Broadcast::from(true)) - .await - .is_ok()); - alice.mempool.insert(Arc::new(tx1.clone())).unwrap(); - bob.mempool.insert(Arc::new(tx1.clone())).unwrap(); - let mut block1 = bob - .blockchain_db - .prepare_block_merkle_roots(chain_block(&empty_block, vec![tx1], &consensus_manager)) - .unwrap(); - find_header_with_achieved_difficulty(&mut block1.header, Difficulty::from(1)); - // Add Block1 - tx1 will be moved to the ReorgPool. 
- assert!(bob - .local_nci - .submit_block(block1.clone(), Broadcast::from(true)) - .await - .is_ok()); - async_assert_eventually!( - alice.mempool.has_tx_with_excess_sig(tx1_excess_sig.clone()).unwrap(), - expect = TxStorageResponse::ReorgPool, - max_attempts = 20, - interval = Duration::from_millis(1000) - ); - alice.mempool.insert(Arc::new(tx2a.clone())).unwrap(); - alice.mempool.insert(Arc::new(tx3a.clone())).unwrap(); - alice.mempool.insert(Arc::new(tx2b.clone())).unwrap(); - alice.mempool.insert(Arc::new(tx3b.clone())).unwrap(); - bob.mempool.insert(Arc::new(tx2a.clone())).unwrap(); - bob.mempool.insert(Arc::new(tx3a.clone())).unwrap(); - bob.mempool.insert(Arc::new(tx2b.clone())).unwrap(); - bob.mempool.insert(Arc::new(tx3b.clone())).unwrap(); - - let mut block2a = bob - .blockchain_db - .prepare_block_merkle_roots(chain_block(&block1, vec![tx2a, tx3a], &consensus_manager)) - .unwrap(); - find_header_with_achieved_difficulty(&mut block2a.header, Difficulty::from(1)); - // Block2b also builds on Block1 but has a stronger PoW - let mut block2b = bob - .blockchain_db - .prepare_block_merkle_roots(chain_block(&block1, vec![tx2b, tx3b], &consensus_manager)) - .unwrap(); - find_header_with_achieved_difficulty(&mut block2b.header, Difficulty::from(10)); - - // Add Block2a - tx2b and tx3b will be discarded as double spends. 
- assert!(bob - .local_nci - .submit_block(block2a.clone(), Broadcast::from(true)) - .await - .is_ok()); - - async_assert_eventually!( - bob.mempool.has_tx_with_excess_sig(tx2a_excess_sig.clone()).unwrap(), - expect = TxStorageResponse::ReorgPool, - max_attempts = 20, - interval = Duration::from_millis(1000) - ); - async_assert_eventually!( - alice.mempool.has_tx_with_excess_sig(tx2a_excess_sig.clone()).unwrap(), - expect = TxStorageResponse::ReorgPool, - max_attempts = 20, - interval = Duration::from_millis(1000) - ); - assert_eq!( - alice.mempool.has_tx_with_excess_sig(tx3a_excess_sig.clone()).unwrap(), - TxStorageResponse::ReorgPool - ); - assert_eq!( - alice.mempool.has_tx_with_excess_sig(tx2b_excess_sig.clone()).unwrap(), - TxStorageResponse::ReorgPool - ); - assert_eq!( - alice.mempool.has_tx_with_excess_sig(tx3b_excess_sig.clone()).unwrap(), - TxStorageResponse::ReorgPool - ); - }); + // Add one empty block, so the coinbase UTXO is no longer time-locked. + assert!(bob + .local_nci + .submit_block(empty_block.clone(), Broadcast::from(true)) + .await + .is_ok()); + assert!(alice + .local_nci + .submit_block(empty_block.clone(), Broadcast::from(true)) + .await + .is_ok()); + alice.mempool.insert(Arc::new(tx1.clone())).unwrap(); + bob.mempool.insert(Arc::new(tx1.clone())).unwrap(); + let mut block1 = bob + .blockchain_db + .prepare_block_merkle_roots(chain_block(&empty_block, vec![tx1], &consensus_manager)) + .unwrap(); + find_header_with_achieved_difficulty(&mut block1.header, Difficulty::from(1)); + // Add Block1 - tx1 will be moved to the ReorgPool. 
+ assert!(bob + .local_nci + .submit_block(block1.clone(), Broadcast::from(true)) + .await + .is_ok()); + async_assert_eventually!( + alice.mempool.has_tx_with_excess_sig(tx1_excess_sig.clone()).unwrap(), + expect = TxStorageResponse::ReorgPool, + max_attempts = 20, + interval = Duration::from_millis(1000) + ); + alice.mempool.insert(Arc::new(tx2a.clone())).unwrap(); + alice.mempool.insert(Arc::new(tx3a.clone())).unwrap(); + alice.mempool.insert(Arc::new(tx2b.clone())).unwrap(); + alice.mempool.insert(Arc::new(tx3b.clone())).unwrap(); + bob.mempool.insert(Arc::new(tx2a.clone())).unwrap(); + bob.mempool.insert(Arc::new(tx3a.clone())).unwrap(); + bob.mempool.insert(Arc::new(tx2b.clone())).unwrap(); + bob.mempool.insert(Arc::new(tx3b.clone())).unwrap(); + + let mut block2a = bob + .blockchain_db + .prepare_block_merkle_roots(chain_block(&block1, vec![tx2a, tx3a], &consensus_manager)) + .unwrap(); + find_header_with_achieved_difficulty(&mut block2a.header, Difficulty::from(1)); + // Block2b also builds on Block1 but has a stronger PoW + let mut block2b = bob + .blockchain_db + .prepare_block_merkle_roots(chain_block(&block1, vec![tx2b, tx3b], &consensus_manager)) + .unwrap(); + find_header_with_achieved_difficulty(&mut block2b.header, Difficulty::from(10)); + + // Add Block2a - tx2b and tx3b will be discarded as double spends. 
+ assert!(bob + .local_nci + .submit_block(block2a.clone(), Broadcast::from(true)) + .await + .is_ok()); + + async_assert_eventually!( + bob.mempool.has_tx_with_excess_sig(tx2a_excess_sig.clone()).unwrap(), + expect = TxStorageResponse::ReorgPool, + max_attempts = 20, + interval = Duration::from_millis(1000) + ); + async_assert_eventually!( + alice.mempool.has_tx_with_excess_sig(tx2a_excess_sig.clone()).unwrap(), + expect = TxStorageResponse::ReorgPool, + max_attempts = 20, + interval = Duration::from_millis(1000) + ); + assert_eq!( + alice.mempool.has_tx_with_excess_sig(tx3a_excess_sig.clone()).unwrap(), + TxStorageResponse::ReorgPool + ); + assert_eq!( + alice.mempool.has_tx_with_excess_sig(tx2b_excess_sig.clone()).unwrap(), + TxStorageResponse::ReorgPool + ); + assert_eq!( + alice.mempool.has_tx_with_excess_sig(tx3b_excess_sig.clone()).unwrap(), + TxStorageResponse::ReorgPool + ); } diff --git a/base_layer/core/tests/node_comms_interface.rs b/base_layer/core/tests/node_comms_interface.rs index 532d102bbe..a6096b8dc9 100644 --- a/base_layer/core/tests/node_comms_interface.rs +++ b/base_layer/core/tests/node_comms_interface.rs @@ -20,13 +20,12 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-#[allow(dead_code)] -mod helpers; -use futures::{channel::mpsc, StreamExt}; -use helpers::block_builders::append_block; use std::sync::Arc; + +use futures::StreamExt; +use helpers::block_builders::append_block; use tari_common::configuration::Network; -use tari_common_types::chain_metadata::ChainMetadata; +use tari_common_types::{chain_metadata::ChainMetadata, types::PublicKey}; use tari_comms::peer_manager::NodeId; use tari_core::{ base_node::{ @@ -42,7 +41,7 @@ use tari_core::{ helpers::{create_utxo, spend_utxos}, tari_amount::MicroTari, transaction::{OutputFeatures, TransactionOutput, UnblindedOutput}, - types::{CryptoFactories, PublicKey}, + CryptoFactories, }, txn_schema, validation::{mocks::MockValidator, transaction_validators::TxInputAndMaturityValidator}, @@ -56,6 +55,10 @@ use tari_crypto::{ }; use tari_service_framework::{reply_channel, reply_channel::Receiver}; use tokio::sync::broadcast; + +use tokio::sync::mpsc; +#[allow(dead_code)] +mod helpers; // use crate::helpers::database::create_test_db; async fn test_request_responder( @@ -71,10 +74,10 @@ fn new_mempool() -> Mempool { Mempool::new(MempoolConfig::default(), Arc::new(mempool_validator)) } -#[tokio_macros::test] +#[tokio::test] async fn outbound_get_metadata() { let (request_sender, mut request_receiver) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let mut outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let metadata = ChainMetadata::new(5, vec![0u8], 3, 0, 5); @@ -86,7 +89,7 @@ async fn outbound_get_metadata() { assert_eq!(received_metadata.unwrap(), metadata); } -#[tokio_macros::test] +#[tokio::test] async fn inbound_get_metadata() { let store = create_test_blockchain_db(); let mempool = new_mempool(); @@ -95,7 +98,7 @@ async fn inbound_get_metadata() { let consensus_manager = ConsensusManager::builder(network).build(); let (block_event_sender, _) = broadcast::channel(50); let 
(request_sender, _) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender.clone()); let inbound_nch = InboundNodeCommsHandlers::new( block_event_sender, @@ -117,7 +120,7 @@ async fn inbound_get_metadata() { } } -#[tokio_macros::test] +#[tokio::test] async fn inbound_fetch_kernel_by_excess_sig() { let store = create_test_blockchain_db(); let mempool = new_mempool(); @@ -126,7 +129,7 @@ async fn inbound_fetch_kernel_by_excess_sig() { let consensus_manager = ConsensusManager::builder(network).build(); let (block_event_sender, _) = broadcast::channel(50); let (request_sender, _) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender.clone()); let inbound_nch = InboundNodeCommsHandlers::new( block_event_sender, @@ -149,10 +152,10 @@ async fn inbound_fetch_kernel_by_excess_sig() { } } -#[tokio_macros::test] +#[tokio::test] async fn outbound_fetch_headers() { let (request_sender, mut request_receiver) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let mut outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let mut header = BlockHeader::new(0); @@ -167,7 +170,7 @@ async fn outbound_fetch_headers() { assert_eq!(received_headers[0], header); } -#[tokio_macros::test] +#[tokio::test] async fn inbound_fetch_headers() { let store = create_test_blockchain_db(); let mempool = new_mempool(); @@ -175,7 +178,7 @@ async fn inbound_fetch_headers() { let consensus_manager = ConsensusManager::builder(network).build(); let (block_event_sender, _) = broadcast::channel(50); let (request_sender, _) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let 
(block_sender, _) = mpsc::unbounded_channel(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let inbound_nch = InboundNodeCommsHandlers::new( block_event_sender, @@ -197,11 +200,11 @@ async fn inbound_fetch_headers() { } } -#[tokio_macros::test] +#[tokio::test] async fn outbound_fetch_utxos() { let factories = CryptoFactories::default(); let (request_sender, mut request_receiver) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let mut outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let (utxo, _, _) = create_utxo( @@ -221,7 +224,7 @@ async fn outbound_fetch_utxos() { assert_eq!(received_utxos[0], utxo); } -#[tokio_macros::test] +#[tokio::test] async fn inbound_fetch_utxos() { let factories = CryptoFactories::default(); @@ -231,7 +234,7 @@ async fn inbound_fetch_utxos() { let consensus_manager = ConsensusManager::builder(network).build(); let (block_event_sender, _) = broadcast::channel(50); let (request_sender, _) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let inbound_nch = InboundNodeCommsHandlers::new( block_event_sender, @@ -264,11 +267,11 @@ async fn inbound_fetch_utxos() { } } -#[tokio_macros::test] +#[tokio::test] async fn outbound_fetch_txos() { let factories = CryptoFactories::default(); let (request_sender, mut request_receiver) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let mut outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let (txo1, _, _) = create_utxo( @@ -296,7 +299,7 @@ async fn outbound_fetch_txos() { assert_eq!(received_txos[1], txo2); } -#[tokio_macros::test] +#[tokio::test] async fn inbound_fetch_txos() { let factories = 
CryptoFactories::default(); let store = create_test_blockchain_db(); @@ -305,7 +308,7 @@ async fn inbound_fetch_txos() { let network = Network::LocalNet; let consensus_manager = ConsensusManager::builder(network).build(); let (request_sender, _) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let inbound_nch = InboundNodeCommsHandlers::new( block_event_sender, @@ -366,10 +369,10 @@ async fn inbound_fetch_txos() { } } -#[tokio_macros::test] +#[tokio::test] async fn outbound_fetch_blocks() { let (request_sender, mut request_receiver) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let mut outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let network = Network::LocalNet; let consensus_constants = NetworkConsensus::from(network).create_consensus_constants(); @@ -385,7 +388,7 @@ async fn outbound_fetch_blocks() { assert_eq!(received_blocks[0], block); } -#[tokio_macros::test] +#[tokio::test] async fn inbound_fetch_blocks() { let store = create_test_blockchain_db(); let mempool = new_mempool(); @@ -393,7 +396,7 @@ async fn inbound_fetch_blocks() { let network = Network::LocalNet; let consensus_manager = ConsensusManager::builder(network).build(); let (request_sender, _) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let inbound_nch = InboundNodeCommsHandlers::new( block_event_sender, @@ -415,7 +418,7 @@ async fn inbound_fetch_blocks() { } } -#[tokio_macros::test] +#[tokio::test] // Test needs to be updated to new pruned structure. 
async fn inbound_fetch_blocks_before_horizon_height() { let factories = CryptoFactories::default(); @@ -437,7 +440,7 @@ async fn inbound_fetch_blocks_before_horizon_height() { let mempool = Mempool::new(MempoolConfig::default(), Arc::new(mempool_validator)); let (block_event_sender, _) = broadcast::channel(50); let (request_sender, _) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let inbound_nch = InboundNodeCommsHandlers::new( block_event_sender, diff --git a/base_layer/core/tests/node_service.rs b/base_layer/core/tests/node_service.rs index 14f277f19b..af128e5966 100644 --- a/base_layer/core/tests/node_service.rs +++ b/base_layer/core/tests/node_service.rs @@ -23,7 +23,6 @@ #[allow(dead_code)] mod helpers; use crate::helpers::block_builders::{construct_chained_blocks, create_coinbase}; -use futures::join; use helpers::{ block_builders::{ append_block, @@ -59,7 +58,7 @@ use tari_core::{ helpers::{schema_to_transaction, spend_utxos}, tari_amount::{uT, T}, transaction::OutputFeatures, - types::CryptoFactories, + CryptoFactories, }, txn_schema, validation::{ @@ -68,15 +67,13 @@ use tari_core::{ mocks::MockValidator, }, }; -use tari_crypto::tari_utilities::hash::Hashable; +use tari_crypto::tari_utilities::Hashable; use tari_p2p::services::liveness::LivenessConfig; use tari_test_utils::unpack_enum; use tempfile::tempdir; -use tokio::runtime::Runtime; -#[test] -fn request_response_get_metadata() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn request_response_get_metadata() { let factories = CryptoFactories::default(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; @@ -89,27 +86,24 @@ fn request_response_get_metadata() { .with_block(block0) .build(); let (mut alice_node, bob_node, carol_node, _consensus_manager) = create_network_with_3_base_nodes_with_config( - 
&mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager, temp_dir.path().to_str().unwrap(), - ); + ) + .await; - runtime.block_on(async { - let received_metadata = alice_node.outbound_nci.get_metadata().await.unwrap(); - assert_eq!(received_metadata.height_of_longest_chain(), 0); + let received_metadata = alice_node.outbound_nci.get_metadata().await.unwrap(); + assert_eq!(received_metadata.height_of_longest_chain(), 0); - alice_node.shutdown().await; - bob_node.shutdown().await; - carol_node.shutdown().await; - }); + alice_node.shutdown().await; + bob_node.shutdown().await; + carol_node.shutdown().await; } -#[test] -fn request_and_response_fetch_blocks() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn request_and_response_fetch_blocks() { let factories = CryptoFactories::default(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; @@ -122,13 +116,13 @@ fn request_and_response_fetch_blocks() { .with_block(block0.clone()) .build(); let (mut alice_node, mut bob_node, carol_node, _) = create_network_with_3_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager.clone(), temp_dir.path().to_str().unwrap(), - ); + ) + .await; let mut blocks = vec![block0]; let db = &mut bob_node.blockchain_db; @@ -147,26 +141,23 @@ fn request_and_response_fetch_blocks() { .unwrap() .assert_added(); - runtime.block_on(async { - let received_blocks = alice_node.outbound_nci.fetch_blocks(vec![0]).await.unwrap(); - assert_eq!(received_blocks.len(), 1); - assert_eq!(received_blocks[0].block(), blocks[0].block()); + let received_blocks = alice_node.outbound_nci.fetch_blocks(vec![0]).await.unwrap(); + assert_eq!(received_blocks.len(), 1); + assert_eq!(received_blocks[0].block(), blocks[0].block()); - let received_blocks = alice_node.outbound_nci.fetch_blocks(vec![0, 1]).await.unwrap(); - 
assert_eq!(received_blocks.len(), 2); - assert_ne!(*received_blocks[0].block(), *received_blocks[1].block()); - assert!(received_blocks[0].block() == blocks[0].block() || received_blocks[1].block() == blocks[0].block()); - assert!(received_blocks[0].block() == blocks[1].block() || received_blocks[1].block() == blocks[1].block()); + let received_blocks = alice_node.outbound_nci.fetch_blocks(vec![0, 1]).await.unwrap(); + assert_eq!(received_blocks.len(), 2); + assert_ne!(*received_blocks[0].block(), *received_blocks[1].block()); + assert!(received_blocks[0].block() == blocks[0].block() || received_blocks[1].block() == blocks[0].block()); + assert!(received_blocks[0].block() == blocks[1].block() || received_blocks[1].block() == blocks[1].block()); - alice_node.shutdown().await; - bob_node.shutdown().await; - carol_node.shutdown().await; - }); + alice_node.shutdown().await; + bob_node.shutdown().await; + carol_node.shutdown().await; } -#[test] -fn request_and_response_fetch_blocks_with_hashes() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn request_and_response_fetch_blocks_with_hashes() { let factories = CryptoFactories::default(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; @@ -179,13 +170,13 @@ fn request_and_response_fetch_blocks_with_hashes() { .with_block(block0.clone()) .build(); let (mut alice_node, mut bob_node, carol_node, _) = create_network_with_3_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager.clone(), temp_dir.path().to_str().unwrap(), - ); + ) + .await; let mut blocks = vec![block0]; let db = &mut bob_node.blockchain_db; @@ -206,34 +197,31 @@ fn request_and_response_fetch_blocks_with_hashes() { .unwrap() .assert_added(); - runtime.block_on(async { - let received_blocks = alice_node - .outbound_nci - .fetch_blocks_with_hashes(vec![block0_hash.clone()]) - .await - .unwrap(); - 
assert_eq!(received_blocks.len(), 1); - assert_eq!(received_blocks[0].block(), blocks[0].block()); + let received_blocks = alice_node + .outbound_nci + .fetch_blocks_with_hashes(vec![block0_hash.clone()]) + .await + .unwrap(); + assert_eq!(received_blocks.len(), 1); + assert_eq!(received_blocks[0].block(), blocks[0].block()); - let received_blocks = alice_node - .outbound_nci - .fetch_blocks_with_hashes(vec![block0_hash.clone(), block1_hash.clone()]) - .await - .unwrap(); - assert_eq!(received_blocks.len(), 2); - assert_ne!(received_blocks[0], received_blocks[1]); - assert!(received_blocks[0].block() == blocks[0].block() || received_blocks[1].block() == blocks[0].block()); - assert!(received_blocks[0].block() == blocks[1].block() || received_blocks[1].block() == blocks[1].block()); - - alice_node.shutdown().await; - bob_node.shutdown().await; - carol_node.shutdown().await; - }); + let received_blocks = alice_node + .outbound_nci + .fetch_blocks_with_hashes(vec![block0_hash.clone(), block1_hash.clone()]) + .await + .unwrap(); + assert_eq!(received_blocks.len(), 2); + assert_ne!(received_blocks[0], received_blocks[1]); + assert!(received_blocks[0].block() == blocks[0].block() || received_blocks[1].block() == blocks[0].block()); + assert!(received_blocks[0].block() == blocks[1].block() || received_blocks[1].block() == blocks[1].block()); + + alice_node.shutdown().await; + bob_node.shutdown().await; + carol_node.shutdown().await; } -#[test] -fn propagate_and_forward_many_valid_blocks() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn propagate_and_forward_many_valid_blocks() { let temp_dir = tempdir().unwrap(); let factories = CryptoFactories::default(); // Alice will propagate a number of block hashes to bob, bob will receive it, request the full block, verify and @@ -261,24 +249,28 @@ fn propagate_and_forward_many_valid_blocks() { let (mut alice_node, rules) = BaseNodeBuilder::new(network.into()) 
.with_node_identity(alice_node_identity.clone()) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("alice").to_str().unwrap()); + .start(temp_dir.path().join("alice").to_str().unwrap()) + .await; let (mut bob_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(bob_node_identity.clone()) .with_peers(vec![alice_node_identity]) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("bob").to_str().unwrap()); + .start(temp_dir.path().join("bob").to_str().unwrap()) + .await; let (mut carol_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(carol_node_identity.clone()) .with_peers(vec![bob_node_identity.clone()]) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("carol").to_str().unwrap()); + .start(temp_dir.path().join("carol").to_str().unwrap()) + .await; let (mut dan_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(dan_node_identity) .with_peers(vec![carol_node_identity, bob_node_identity]) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("dan").to_str().unwrap()); + .start(temp_dir.path().join("dan").to_str().unwrap()) + .await; - wait_until_online(&mut runtime, &[&alice_node, &bob_node, &carol_node, &dan_node]); + wait_until_online(&[&alice_node, &bob_node, &carol_node, &dan_node]).await; alice_node.mock_base_node_state_machine.publish_status(StatusInfo { bootstrapped: true, state_info: StateInfo::Listening(ListeningInfo::new(true)), @@ -302,56 +294,49 @@ fn propagate_and_forward_many_valid_blocks() { let blocks = construct_chained_blocks(&alice_node.blockchain_db, block0, &rules, 5); - runtime.block_on(async { - for block in &blocks { - alice_node - .outbound_nci - .propagate_block(NewBlock::from(block.block()), vec![]) - .await - .unwrap(); - - let bob_block_event_fut = event_stream_next(&mut bob_block_event_stream, Duration::from_millis(20000)); - let carol_block_event_fut = event_stream_next(&mut 
carol_block_event_stream, Duration::from_millis(20000)); - let dan_block_event_fut = event_stream_next(&mut dan_block_event_stream, Duration::from_millis(20000)); - let (bob_block_event, carol_block_event, dan_block_event) = - join!(bob_block_event_fut, carol_block_event_fut, dan_block_event_fut); - let block_hash = block.hash(); - - if let BlockEvent::ValidBlockAdded(received_block, _, _) = &*bob_block_event.unwrap().unwrap() { - assert_eq!(&received_block.hash(), block_hash); - } else { - panic!("Bob's node did not receive and validate the expected block"); - } - if let BlockEvent::ValidBlockAdded(received_block, _block_add_result, _) = - &*carol_block_event.unwrap().unwrap() - { - assert_eq!(&received_block.hash(), block_hash); - } else { - panic!("Carol's node did not receive and validate the expected block"); - } - if let BlockEvent::ValidBlockAdded(received_block, _block_add_result, _) = - &*dan_block_event.unwrap().unwrap() - { - assert_eq!(&received_block.hash(), block_hash); - } else { - panic!("Dan's node did not receive and validate the expected block"); - } + for block in &blocks { + alice_node + .outbound_nci + .propagate_block(NewBlock::from(block.block()), vec![]) + .await + .unwrap(); + + let bob_block_event_fut = event_stream_next(&mut bob_block_event_stream, Duration::from_millis(20000)); + let carol_block_event_fut = event_stream_next(&mut carol_block_event_stream, Duration::from_millis(20000)); + let dan_block_event_fut = event_stream_next(&mut dan_block_event_stream, Duration::from_millis(20000)); + let (bob_block_event, carol_block_event, dan_block_event) = + tokio::join!(bob_block_event_fut, carol_block_event_fut, dan_block_event_fut); + let block_hash = block.hash(); + + if let BlockEvent::ValidBlockAdded(received_block, _, _) = &*bob_block_event.unwrap() { + assert_eq!(&received_block.hash(), block_hash); + } else { + panic!("Bob's node did not receive and validate the expected block"); } + if let BlockEvent::ValidBlockAdded(received_block, 
_block_add_result, _) = &*carol_block_event.unwrap() { + assert_eq!(&received_block.hash(), block_hash); + } else { + panic!("Carol's node did not receive and validate the expected block"); + } + if let BlockEvent::ValidBlockAdded(received_block, _block_add_result, _) = &*dan_block_event.unwrap() { + assert_eq!(&received_block.hash(), block_hash); + } else { + panic!("Dan's node did not receive and validate the expected block"); + } + } - alice_node.shutdown().await; - bob_node.shutdown().await; - carol_node.shutdown().await; - dan_node.shutdown().await; - }); + alice_node.shutdown().await; + bob_node.shutdown().await; + carol_node.shutdown().await; + dan_node.shutdown().await; } static EMISSION: [u64; 2] = [10, 10]; -#[test] -fn propagate_and_forward_invalid_block_hash() { +#[tokio::test] +async fn propagate_and_forward_invalid_block_hash() { // Alice will propagate a "made up" block hash to Bob, Bob will request the block from Alice. Alice will not be able // to provide the block and so Bob will not propagate the hash further to Carol. 
// alice -> bob -> carol - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let factories = CryptoFactories::default(); @@ -370,19 +355,22 @@ fn propagate_and_forward_invalid_block_hash() { let (mut alice_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(alice_node_identity.clone()) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("alice").to_str().unwrap()); + .start(temp_dir.path().join("alice").to_str().unwrap()) + .await; let (mut bob_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(bob_node_identity.clone()) .with_peers(vec![alice_node_identity]) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("bob").to_str().unwrap()); + .start(temp_dir.path().join("bob").to_str().unwrap()) + .await; let (mut carol_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(carol_node_identity) .with_peers(vec![bob_node_identity]) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("carol").to_str().unwrap()); + .start(temp_dir.path().join("carol").to_str().unwrap()) + .await; - wait_until_online(&mut runtime, &[&alice_node, &bob_node, &carol_node]); + wait_until_online(&[&alice_node, &bob_node, &carol_node]).await; alice_node.mock_base_node_state_machine.publish_status(StatusInfo { bootstrapped: true, state_info: StateInfo::Listening(ListeningInfo::new(true)), @@ -409,42 +397,37 @@ fn propagate_and_forward_invalid_block_hash() { let mut bob_message_events = bob_node.messaging_events.subscribe(); let mut carol_message_events = carol_node.messaging_events.subscribe(); - runtime.block_on(async { - alice_node - .outbound_nci - .propagate_block(NewBlock::from(block1.block()), vec![]) - .await - .unwrap(); + alice_node + .outbound_nci + .propagate_block(NewBlock::from(block1.block()), vec![]) + .await + .unwrap(); - // Alice propagated to Bob - // Bob received the invalid hash - let msg_event = event_stream_next(&mut 
bob_message_events, Duration::from_secs(10)) - .await - .unwrap() - .unwrap(); - unpack_enum!(MessagingEvent::MessageReceived(_a, _b) = &*msg_event); - // Sent the request for the block to Alice - // Bob received a response from Alice - let msg_event = event_stream_next(&mut bob_message_events, Duration::from_secs(10)) - .await - .unwrap() - .unwrap(); - unpack_enum!(MessagingEvent::MessageReceived(node_id, _a) = &*msg_event); - assert_eq!(&*node_id, alice_node.node_identity.node_id()); - // Checking a negative: Bob should not have propagated this hash to Carol. If Bob does, this assertion will be - // flaky. - let msg_event = event_stream_next(&mut carol_message_events, Duration::from_millis(500)).await; - assert!(msg_event.is_none()); - - alice_node.shutdown().await; - bob_node.shutdown().await; - carol_node.shutdown().await; - }); + // Alice propagated to Bob + // Bob received the invalid hash + let msg_event = event_stream_next(&mut bob_message_events, Duration::from_secs(10)) + .await + .unwrap(); + unpack_enum!(MessagingEvent::MessageReceived(_a, _b) = &*msg_event); + // Sent the request for the block to Alice + // Bob received a response from Alice + let msg_event = event_stream_next(&mut bob_message_events, Duration::from_secs(10)) + .await + .unwrap(); + unpack_enum!(MessagingEvent::MessageReceived(node_id, _a) = &*msg_event); + assert_eq!(&*node_id, alice_node.node_identity.node_id()); + // Checking a negative: Bob should not have propagated this hash to Carol. If Bob does, this assertion will be + // flaky. 
+ let msg_event = event_stream_next(&mut carol_message_events, Duration::from_millis(500)).await; + assert!(msg_event.is_none()); + + alice_node.shutdown().await; + bob_node.shutdown().await; + carol_node.shutdown().await; } -#[test] -fn propagate_and_forward_invalid_block() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn propagate_and_forward_invalid_block() { let temp_dir = tempdir().unwrap(); let factories = CryptoFactories::default(); // Alice will propagate an invalid block to Carol and Bob, they will check the received block and not propagate the @@ -467,13 +450,14 @@ fn propagate_and_forward_invalid_block() { .with_consensus_constants(consensus_constants) .with_block(block0.clone()) .build(); - let stateless_block_validator = OrphanBlockValidator::new(rules.clone(), factories); + let stateless_block_validator = OrphanBlockValidator::new(rules.clone(), true, factories); let mock_validator = MockValidator::new(false); let (mut dan_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(dan_node_identity.clone()) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("dan").to_str().unwrap()); + .start(temp_dir.path().join("dan").to_str().unwrap()) + .await; let (mut carol_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(carol_node_identity.clone()) .with_peers(vec![dan_node_identity.clone()]) @@ -483,20 +467,23 @@ fn propagate_and_forward_invalid_block() { mock_validator.clone(), stateless_block_validator.clone(), ) - .start(&mut runtime, temp_dir.path().join("carol").to_str().unwrap()); + .start(temp_dir.path().join("carol").to_str().unwrap()) + .await; let (mut bob_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(bob_node_identity.clone()) .with_peers(vec![dan_node_identity]) .with_consensus_manager(rules) .with_validators(mock_validator.clone(), mock_validator, stateless_block_validator) - .start(&mut runtime, 
temp_dir.path().join("bob").to_str().unwrap()); + .start(temp_dir.path().join("bob").to_str().unwrap()) + .await; let (mut alice_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(alice_node_identity) .with_peers(vec![bob_node_identity, carol_node_identity]) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("alice").to_str().unwrap()); + .start(temp_dir.path().join("alice").to_str().unwrap()) + .await; - wait_until_online(&mut runtime, &[&alice_node, &bob_node, &carol_node, &dan_node]); + wait_until_online(&[&alice_node, &bob_node, &carol_node, &dan_node]).await; alice_node.mock_base_node_state_machine.publish_status(StatusInfo { bootstrapped: true, @@ -520,45 +507,42 @@ fn propagate_and_forward_invalid_block() { let block1 = append_block(&alice_node.blockchain_db, &block0, vec![], &rules, 1.into()).unwrap(); let block1_hash = block1.hash(); - runtime.block_on(async { - let mut bob_block_event_stream = bob_node.local_nci.get_block_event_stream(); - let mut carol_block_event_stream = carol_node.local_nci.get_block_event_stream(); - let mut dan_block_event_stream = dan_node.local_nci.get_block_event_stream(); - - assert!(alice_node - .outbound_nci - .propagate_block(NewBlock::from(block1.block()), vec![]) - .await - .is_ok()); - - let bob_block_event_fut = event_stream_next(&mut bob_block_event_stream, Duration::from_millis(20000)); - let carol_block_event_fut = event_stream_next(&mut carol_block_event_stream, Duration::from_millis(20000)); - let dan_block_event_fut = event_stream_next(&mut dan_block_event_stream, Duration::from_millis(5000)); - let (bob_block_event, carol_block_event, dan_block_event) = - join!(bob_block_event_fut, carol_block_event_fut, dan_block_event_fut); - - if let BlockEvent::AddBlockFailed(received_block, _) = &*bob_block_event.unwrap().unwrap() { - assert_eq!(&received_block.hash(), block1_hash); - } else { - panic!("Bob's node should have detected an invalid block"); - } - if let 
BlockEvent::AddBlockFailed(received_block, _) = &*carol_block_event.unwrap().unwrap() { - assert_eq!(&received_block.hash(), block1_hash); - } else { - panic!("Carol's node should have detected an invalid block"); - } - assert!(dan_block_event.is_none()); + let mut bob_block_event_stream = bob_node.local_nci.get_block_event_stream(); + let mut carol_block_event_stream = carol_node.local_nci.get_block_event_stream(); + let mut dan_block_event_stream = dan_node.local_nci.get_block_event_stream(); - alice_node.shutdown().await; - bob_node.shutdown().await; - carol_node.shutdown().await; - dan_node.shutdown().await; - }); + assert!(alice_node + .outbound_nci + .propagate_block(NewBlock::from(block1.block()), vec![]) + .await + .is_ok()); + + let bob_block_event_fut = event_stream_next(&mut bob_block_event_stream, Duration::from_millis(20000)); + let carol_block_event_fut = event_stream_next(&mut carol_block_event_stream, Duration::from_millis(20000)); + let dan_block_event_fut = event_stream_next(&mut dan_block_event_stream, Duration::from_millis(5000)); + let (bob_block_event, carol_block_event, dan_block_event) = + tokio::join!(bob_block_event_fut, carol_block_event_fut, dan_block_event_fut); + + if let BlockEvent::AddBlockFailed(received_block, _) = &*bob_block_event.unwrap() { + assert_eq!(&received_block.hash(), block1_hash); + } else { + panic!("Bob's node should have detected an invalid block"); + } + if let BlockEvent::AddBlockFailed(received_block, _) = &*carol_block_event.unwrap() { + assert_eq!(&received_block.hash(), block1_hash); + } else { + panic!("Carol's node should have detected an invalid block"); + } + assert!(dan_block_event.is_none()); + + alice_node.shutdown().await; + bob_node.shutdown().await; + carol_node.shutdown().await; + dan_node.shutdown().await; } -#[test] -fn service_request_timeout() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn service_request_timeout() { let network = Network::LocalNet; let 
consensus_manager = ConsensusManager::builder(network).build(); let base_node_service_config = BaseNodeServiceConfig { @@ -569,47 +553,42 @@ fn service_request_timeout() { }; let temp_dir = tempdir().unwrap(); let (mut alice_node, bob_node, _consensus_manager) = create_network_with_2_base_nodes_with_config( - &mut runtime, base_node_service_config, MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager, temp_dir.path().to_str().unwrap(), - ); + ) + .await; - runtime.block_on(async { - // Bob should not be reachable - bob_node.shutdown().await; - unpack_enum!(CommsInterfaceError::RequestTimedOut = alice_node.outbound_nci.get_metadata().await.unwrap_err()); - alice_node.shutdown().await; - }); + // Bob should not be reachable + bob_node.shutdown().await; + unpack_enum!(CommsInterfaceError::RequestTimedOut = alice_node.outbound_nci.get_metadata().await.unwrap_err()); + alice_node.shutdown().await; } -#[test] -fn local_get_metadata() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn local_get_metadata() { let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; - let (mut node, consensus_manager) = - BaseNodeBuilder::new(network.into()).start(&mut runtime, temp_dir.path().to_str().unwrap()); + let (mut node, consensus_manager) = BaseNodeBuilder::new(network.into()) + .start(temp_dir.path().to_str().unwrap()) + .await; let db = &node.blockchain_db; let block0 = db.fetch_block(0).unwrap().try_into_chain_block().unwrap(); let block1 = append_block(db, &block0, vec![], &consensus_manager, 1.into()).unwrap(); let block2 = append_block(db, &block1, vec![], &consensus_manager, 1.into()).unwrap(); - runtime.block_on(async { - let metadata = node.local_nci.get_metadata().await.unwrap(); - assert_eq!(metadata.height_of_longest_chain(), 2); - assert_eq!(metadata.best_block(), block2.hash()); + let metadata = node.local_nci.get_metadata().await.unwrap(); + assert_eq!(metadata.height_of_longest_chain(), 2); + 
assert_eq!(metadata.best_block(), block2.hash()); - node.shutdown().await; - }); + node.shutdown().await; } -#[test] -fn local_get_new_block_template_and_get_new_block() { +#[tokio::test] +async fn local_get_new_block_template_and_get_new_block() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; let consensus_constants = NetworkConsensus::from(network).create_consensus_constants(); @@ -620,7 +599,8 @@ fn local_get_new_block_template_and_get_new_block() { .build(); let (mut node, _rules) = BaseNodeBuilder::new(network.into()) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().to_str().unwrap()); + .start(temp_dir.path().to_str().unwrap()) + .await; let schema = [ txn_schema!(from: vec![outputs[1].clone()], to: vec![10_000 * uT, 20_000 * uT]), @@ -630,29 +610,26 @@ fn local_get_new_block_template_and_get_new_block() { assert!(node.mempool.insert(txs[0].clone()).is_ok()); assert!(node.mempool.insert(txs[1].clone()).is_ok()); - runtime.block_on(async { - let block_template = node - .local_nci - .get_new_block_template(PowAlgorithm::Sha3, 0) - .await - .unwrap(); - assert_eq!(block_template.header.height, 1); - assert_eq!(block_template.body.kernels().len(), 2); + let block_template = node + .local_nci + .get_new_block_template(PowAlgorithm::Sha3, 0) + .await + .unwrap(); + assert_eq!(block_template.header.height, 1); + assert_eq!(block_template.body.kernels().len(), 2); - let block = node.local_nci.get_new_block(block_template.clone()).await.unwrap(); - assert_eq!(block.header.height, 1); - assert_eq!(block.body, block_template.body); + let block = node.local_nci.get_new_block(block_template.clone()).await.unwrap(); + assert_eq!(block.header.height, 1); + assert_eq!(block.body, block_template.body); - node.blockchain_db.add_block(block.clone().into()).unwrap(); + node.blockchain_db.add_block(block.clone().into()).unwrap(); - 
node.shutdown().await; - }); + node.shutdown().await; } -#[test] -fn local_get_new_block_with_zero_conf() { +#[tokio::test] +async fn local_get_new_block_with_zero_conf() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; let consensus_constants = NetworkConsensus::from(network).create_consensus_constants(); @@ -666,9 +643,10 @@ fn local_get_new_block_with_zero_conf() { .with_validators( BodyOnlyValidator::default(), HeaderValidator::new(rules.clone()), - OrphanBlockValidator::new(rules, factories.clone()), + OrphanBlockValidator::new(rules, true, factories.clone()), ) - .start(&mut runtime, temp_dir.path().to_str().unwrap()); + .start(temp_dir.path().to_str().unwrap()) + .await; let (tx01, tx01_out, _) = spend_utxos( txn_schema!(from: vec![outputs[1].clone()], to: vec![20_000 * uT], fee: 10*uT, lock: 0, features: OutputFeatures::default()), @@ -700,38 +678,35 @@ fn local_get_new_block_with_zero_conf() { TxStorageResponse::UnconfirmedPool ); - runtime.block_on(async { - let mut block_template = node - .local_nci - .get_new_block_template(PowAlgorithm::Sha3, 0) - .await - .unwrap(); - assert_eq!(block_template.header.height, 1); - assert_eq!(block_template.body.kernels().len(), 4); - let coinbase_value = rules.get_block_reward_at(1) + block_template.body.get_total_fee(); - let (output, kernel, _) = create_coinbase( - &factories, - coinbase_value, - rules.consensus_constants(1).coinbase_lock_height() + 1, - ); - block_template.body.add_kernel(kernel); - block_template.body.add_output(output); - block_template.body.sort(); - let block = node.local_nci.get_new_block(block_template.clone()).await.unwrap(); - assert_eq!(block.header.height, 1); - assert_eq!(block.body, block_template.body); - assert_eq!(block_template.body.kernels().len(), 5); - - node.blockchain_db.add_block(block.clone().into()).unwrap(); - - node.shutdown().await; - }); + let mut block_template 
= node + .local_nci + .get_new_block_template(PowAlgorithm::Sha3, 0) + .await + .unwrap(); + assert_eq!(block_template.header.height, 1); + assert_eq!(block_template.body.kernels().len(), 4); + let coinbase_value = rules.get_block_reward_at(1) + block_template.body.get_total_fee(); + let (output, kernel, _) = create_coinbase( + &factories, + coinbase_value, + rules.consensus_constants(1).coinbase_lock_height() + 1, + ); + block_template.body.add_kernel(kernel); + block_template.body.add_output(output); + block_template.body.sort(); + let block = node.local_nci.get_new_block(block_template.clone()).await.unwrap(); + assert_eq!(block.header.height, 1); + assert_eq!(block.body, block_template.body); + assert_eq!(block_template.body.kernels().len(), 5); + + node.blockchain_db.add_block(block.clone().into()).unwrap(); + + node.shutdown().await; } -#[test] -fn local_get_new_block_with_combined_transaction() { +#[tokio::test] +async fn local_get_new_block_with_combined_transaction() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; let consensus_constants = NetworkConsensus::from(network).create_consensus_constants(); @@ -745,9 +720,10 @@ fn local_get_new_block_with_combined_transaction() { .with_validators( BodyOnlyValidator::default(), HeaderValidator::new(rules.clone()), - OrphanBlockValidator::new(rules, factories.clone()), + OrphanBlockValidator::new(rules, true, factories.clone()), ) - .start(&mut runtime, temp_dir.path().to_str().unwrap()); + .start(temp_dir.path().to_str().unwrap()) + .await; let (tx01, tx01_out, _) = spend_utxos( txn_schema!(from: vec![outputs[1].clone()], to: vec![20_000 * uT], fee: 10*uT, lock: 0, features: OutputFeatures::default()), @@ -774,41 +750,39 @@ fn local_get_new_block_with_combined_transaction() { TxStorageResponse::UnconfirmedPool ); - runtime.block_on(async { - let mut block_template = node - .local_nci - 
.get_new_block_template(PowAlgorithm::Sha3, 0) - .await - .unwrap(); - assert_eq!(block_template.header.height, 1); - assert_eq!(block_template.body.kernels().len(), 4); - let coinbase_value = rules.get_block_reward_at(1) + block_template.body.get_total_fee(); - let (output, kernel, _) = create_coinbase( - &factories, - coinbase_value, - rules.consensus_constants(1).coinbase_lock_height() + 1, - ); - block_template.body.add_kernel(kernel); - block_template.body.add_output(output); - block_template.body.sort(); - let block = node.local_nci.get_new_block(block_template.clone()).await.unwrap(); - assert_eq!(block.header.height, 1); - assert_eq!(block.body, block_template.body); - assert_eq!(block_template.body.kernels().len(), 5); - - node.blockchain_db.add_block(block.clone().into()).unwrap(); - - node.shutdown().await; - }); + let mut block_template = node + .local_nci + .get_new_block_template(PowAlgorithm::Sha3, 0) + .await + .unwrap(); + assert_eq!(block_template.header.height, 1); + assert_eq!(block_template.body.kernels().len(), 4); + let coinbase_value = rules.get_block_reward_at(1) + block_template.body.get_total_fee(); + let (output, kernel, _) = create_coinbase( + &factories, + coinbase_value, + rules.consensus_constants(1).coinbase_lock_height() + 1, + ); + block_template.body.add_kernel(kernel); + block_template.body.add_output(output); + block_template.body.sort(); + let block = node.local_nci.get_new_block(block_template.clone()).await.unwrap(); + assert_eq!(block.header.height, 1); + assert_eq!(block.body, block_template.body); + assert_eq!(block_template.body.kernels().len(), 5); + + node.blockchain_db.add_block(block.clone().into()).unwrap(); + + node.shutdown().await; } -#[test] -fn local_submit_block() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn local_submit_block() { let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; - let (mut node, consensus_manager) = - 
BaseNodeBuilder::new(network.into()).start(&mut runtime, temp_dir.path().to_str().unwrap()); + let (mut node, consensus_manager) = BaseNodeBuilder::new(network.into()) + .start(temp_dir.path().to_str().unwrap()) + .await; let db = &node.blockchain_db; let mut event_stream = node.local_nci.get_block_event_stream(); @@ -818,20 +792,18 @@ fn local_submit_block() { .unwrap(); block1.header.kernel_mmr_size += 1; block1.header.output_mmr_size += 1; - runtime.block_on(async { - node.local_nci - .submit_block(block1.clone(), Broadcast::from(true)) - .await - .unwrap(); + node.local_nci + .submit_block(block1.clone(), Broadcast::from(true)) + .await + .unwrap(); - let event = event_stream_next(&mut event_stream, Duration::from_millis(20000)).await; - if let BlockEvent::ValidBlockAdded(received_block, result, _) = &*event.unwrap().unwrap() { - assert_eq!(received_block.hash(), block1.hash()); - result.assert_added(); - } else { - panic!("Block validation failed"); - } + let event = event_stream_next(&mut event_stream, Duration::from_millis(20000)).await; + if let BlockEvent::ValidBlockAdded(received_block, result, _) = &*event.unwrap() { + assert_eq!(received_block.hash(), block1.hash()); + result.assert_added(); + } else { + panic!("Block validation failed"); + } - node.shutdown().await; - }); + node.shutdown().await; } diff --git a/base_layer/core/tests/node_state_machine.rs b/base_layer/core/tests/node_state_machine.rs index bcbeeea436..da9e5e3d95 100644 --- a/base_layer/core/tests/node_state_machine.rs +++ b/base_layer/core/tests/node_state_machine.rs @@ -20,16 +20,12 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-#[allow(dead_code)] -mod helpers; - -use futures::StreamExt; use helpers::{ block_builders::{append_block, chain_block, create_genesis_block}, chain_metadata::{random_peer_metadata, MockChainMetadata}, nodes::{create_network_with_2_base_nodes_with_config, wait_until_online, BaseNodeBuilder}, }; -use std::{thread, time::Duration}; +use std::time::Duration; use tari_common::configuration::Network; use tari_core::{ base_node::{ @@ -47,22 +43,24 @@ use tari_core::{ mempool::MempoolServiceConfig, proof_of_work::randomx_factory::RandomXFactory, test_helpers::blockchain::create_test_blockchain_db, - transactions::types::CryptoFactories, + transactions::CryptoFactories, validation::mocks::MockValidator, }; use tari_p2p::services::liveness::LivenessConfig; use tari_shutdown::Shutdown; use tempfile::tempdir; use tokio::{ - runtime::Runtime, sync::{broadcast, watch}, + task, time, }; +#[allow(dead_code)] +mod helpers; + static EMISSION: [u64; 2] = [10, 10]; -#[test] -fn test_listening_lagging() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn test_listening_lagging() { let factories = CryptoFactories::default(); let network = Network::LocalNet; let temp_dir = tempdir().unwrap(); @@ -75,7 +73,6 @@ fn test_listening_lagging() { .with_block(prev_block.clone()) .build(); let (alice_node, bob_node, consensus_manager) = create_network_with_2_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig { @@ -84,7 +81,8 @@ fn test_listening_lagging() { }, consensus_manager, temp_dir.path().to_str().unwrap(), - ); + ) + .await; let shutdown = Shutdown::new(); let (state_change_event_publisher, _) = broadcast::channel(10); let (status_event_sender, _status_event_receiver) = watch::channel(StatusInfo::new()); @@ -103,46 +101,44 @@ fn test_listening_lagging() { consensus_manager.clone(), shutdown.to_signal(), ); - wait_until_online(&mut runtime, &[&alice_node, &bob_node]); + 
wait_until_online(&[&alice_node, &bob_node]).await; - let await_event_task = runtime.spawn(async move { Listening::new().next_event(&mut alice_state_machine).await }); + let await_event_task = task::spawn(async move { Listening::new().next_event(&mut alice_state_machine).await }); - runtime.block_on(async move { - let bob_db = bob_node.blockchain_db; - let mut bob_local_nci = bob_node.local_nci; + let bob_db = bob_node.blockchain_db; + let mut bob_local_nci = bob_node.local_nci; - // Bob Block 1 - no block event - let prev_block = append_block(&bob_db, &prev_block, vec![], &consensus_manager, 3.into()).unwrap(); - // Bob Block 2 - with block event and liveness service metadata update - let mut prev_block = bob_db - .prepare_block_merkle_roots(chain_block(&prev_block.block(), vec![], &consensus_manager)) - .unwrap(); - prev_block.header.output_mmr_size += 1; - prev_block.header.kernel_mmr_size += 1; - bob_local_nci - .submit_block(prev_block, Broadcast::from(true)) - .await - .unwrap(); - assert_eq!(bob_db.get_height().unwrap(), 2); + // Bob Block 1 - no block event + let prev_block = append_block(&bob_db, &prev_block, vec![], &consensus_manager, 3.into()).unwrap(); + // Bob Block 2 - with block event and liveness service metadata update + let mut prev_block = bob_db + .prepare_block_merkle_roots(chain_block(&prev_block.block(), vec![], &consensus_manager)) + .unwrap(); + prev_block.header.output_mmr_size += 1; + prev_block.header.kernel_mmr_size += 1; + bob_local_nci + .submit_block(prev_block, Broadcast::from(true)) + .await + .unwrap(); + assert_eq!(bob_db.get_height().unwrap(), 2); - let next_event = time::timeout(Duration::from_secs(10), await_event_task) - .await - .expect("Alice did not emit `StateEvent::FallenBehind` within 10 seconds") - .unwrap(); + let next_event = time::timeout(Duration::from_secs(10), await_event_task) + .await + .expect("Alice did not emit `StateEvent::FallenBehind` within 10 seconds") + .unwrap(); - match next_event { - 
StateEvent::InitialSync => {}, - _ => panic!(), - } - }); + match next_event { + StateEvent::InitialSync => {}, + _ => panic!(), + } } -#[test] -fn test_event_channel() { +#[tokio::test] +async fn test_event_channel() { let temp_dir = tempdir().unwrap(); - let mut runtime = Runtime::new().unwrap(); - let (node, consensus_manager) = - BaseNodeBuilder::new(Network::Weatherwax.into()).start(&mut runtime, temp_dir.path().to_str().unwrap()); + let (node, consensus_manager) = BaseNodeBuilder::new(Network::Weatherwax.into()) + .start(temp_dir.path().to_str().unwrap()) + .await; // let shutdown = Shutdown::new(); let db = create_test_blockchain_db(); let shutdown = Shutdown::new(); @@ -165,24 +161,21 @@ fn test_event_channel() { shutdown.to_signal(), ); - runtime.spawn(state_machine.run()); + task::spawn(state_machine.run()); let PeerChainMetadata { node_id, chain_metadata, } = random_peer_metadata(10, 5_000); - runtime - .block_on(mock.publish_chain_metadata(&node_id, &chain_metadata)) + mock.publish_chain_metadata(&node_id, &chain_metadata) + .await .expect("Could not publish metadata"); - thread::sleep(Duration::from_millis(50)); - runtime.block_on(async { - let event = state_change_event_subscriber.next().await; - assert_eq!(*event.unwrap().unwrap(), StateEvent::Initialized); - let event = state_change_event_subscriber.next().await; - let event = event.unwrap().unwrap(); - match event.as_ref() { - StateEvent::InitialSync => (), - _ => panic!("Unexpected state was found:{:?}", event), - } - }); + let event = state_change_event_subscriber.recv().await; + assert_eq!(*event.unwrap(), StateEvent::Initialized); + let event = state_change_event_subscriber.recv().await; + let event = event.unwrap(); + match event.as_ref() { + StateEvent::InitialSync => (), + _ => panic!("Unexpected state was found:{:?}", event), + } } diff --git a/base_layer/key_manager/Cargo.toml b/base_layer/key_manager/Cargo.toml index e57d0bf8b4..a40415ce17 100644 --- a/base_layer/key_manager/Cargo.toml 
+++ b/base_layer/key_manager/Cargo.toml @@ -15,7 +15,7 @@ sha2 = "0.9.5" serde = "1.0.89" serde_derive = "1.0.89" serde_json = "1.0.39" -thiserror = "1.0.20" +thiserror = "1.0.26" [features] avx2 = ["tari_crypto/avx2"] diff --git a/base_layer/mmr/Cargo.toml b/base_layer/mmr/Cargo.toml index 38b723a5f2..0d725874fb 100644 --- a/base_layer/mmr/Cargo.toml +++ b/base_layer/mmr/Cargo.toml @@ -14,7 +14,7 @@ benches = ["criterion"] [dependencies] tari_utilities = "^0.3" -thiserror = "1.0.20" +thiserror = "1.0.26" digest = "0.9.0" log = "0.4" serde = { version = "1.0.97", features = ["derive"] } diff --git a/base_layer/p2p/Cargo.toml b/base_layer/p2p/Cargo.toml index e9d3a0291c..0d7aae7390 100644 --- a/base_layer/p2p/Cargo.toml +++ b/base_layer/p2p/Cargo.toml @@ -10,37 +10,38 @@ license = "BSD-3-Clause" edition = "2018" [dependencies] -tari_comms = { version = "^0.9", path = "../../comms"} -tari_comms_dht = { version = "^0.9", path = "../../comms/dht"} -tari_common = { version= "^0.9", path = "../../common" } +tari_comms = { version = "^0.9", path = "../../comms" } +tari_comms_dht = { version = "^0.9", path = "../../comms/dht" } +tari_common = { version = "^0.9", path = "../../common" } tari_crypto = "0.11.1" -tari_service_framework = { version = "^0.9", path = "../service_framework"} -tari_shutdown = { version = "^0.9", path="../../infrastructure/shutdown" } -tari_storage = { version = "^0.9", path = "../../infrastructure/storage"} +tari_service_framework = { version = "^0.9", path = "../service_framework" } +tari_shutdown = { version = "^0.9", path = "../../infrastructure/shutdown" } +tari_storage = { version = "^0.9", path = "../../infrastructure/storage" } tari_utilities = "^0.3" anyhow = "1.0.32" bytes = "0.5" -chrono = {version = "0.4.6", features = ["serde"]} +chrono = { version = "0.4.6", features = ["serde"] } fs2 = "0.3.0" -futures = {version = "^0.3.1"} +futures = { version = "^0.3.1" } lmdb-zero = "0.4.4" log = "0.4.6" -pgp = {version = "0.7.1", optional = true} 
-prost = "0.6.1" +pgp = { version = "0.7.1", optional = true } +prost = "=0.8.0" rand = "0.8" -reqwest = {version = "0.10", optional = true, default-features = false} +reqwest = { version = "0.10", optional = true, default-features = false } semver = "1.0.1" serde = "1.0.90" serde_derive = "1.0.90" -thiserror = "1.0.20" -tokio = {version="0.2.10", features=["blocking"]} +thiserror = "1.0.26" +tokio = { version = "1.10", features = ["macros"] } +tokio-stream = { version = "0.1.7", default-features = false, features = ["time"] } tower = "0.3.0-alpha.2" -tower-service = { version="0.3.0-alpha.2" } -trust-dns-client = {version="0.19.5", features=["dns-over-rustls"]} +tower-service = { version = "0.3.0-alpha.2" } +trust-dns-client = { version = "0.21.0-alpha.1", features = ["dns-over-rustls"] } [dev-dependencies] -tari_test_utils = { version = "^0.9", path="../../infrastructure/test_utils" } +tari_test_utils = { version = "^0.9", path = "../../infrastructure/test_utils" } clap = "2.33.0" env_logger = "0.6.2" @@ -48,7 +49,6 @@ futures-timer = "0.3.0" lazy_static = "1.3.0" stream-cancel = "0.4.4" tempfile = "3.1.0" -tokio-macros = "0.2.4" [dev-dependencies.log4rs] version = "^0.8" @@ -56,7 +56,7 @@ features = ["console_appender", "file_appender", "file", "yaml_format"] default-features = false [build-dependencies] -tari_common = { version = "^0.9", path="../../common", features = ["build"] } +tari_common = { version = "^0.9", path = "../../common", features = ["build"] } [features] test-mocks = [] diff --git a/base_layer/p2p/examples/gen_tor_identity.rs b/base_layer/p2p/examples/gen_tor_identity.rs index 52a2e3c785..c1a7693de3 100644 --- a/base_layer/p2p/examples/gen_tor_identity.rs +++ b/base_layer/p2p/examples/gen_tor_identity.rs @@ -39,7 +39,7 @@ fn to_abs_path(path: &str) -> String { } } -#[tokio_macros::main] +#[tokio::main] async fn main() { let matches = App::new("Tor identity file generator") .version("1.0") diff --git a/base_layer/p2p/src/auto_update/dns.rs 
b/base_layer/p2p/src/auto_update/dns.rs index 7042b19919..63fb4ed05e 100644 --- a/base_layer/p2p/src/auto_update/dns.rs +++ b/base_layer/p2p/src/auto_update/dns.rs @@ -32,7 +32,7 @@ use std::{ use tari_common::configuration::bootstrap::ApplicationType; use tari_utilities::hex::{from_hex, Hex}; -const LOG_TARGET: &str = "p2p::auto-update:dns"; +const LOG_TARGET: &str = "p2p::auto_update::dns"; pub struct DnsSoftwareUpdate { client: DnsClient, @@ -189,19 +189,26 @@ impl Display for UpdateSpec { #[cfg(test)] mod test { use super::*; + use crate::dns::mock; use trust_dns_client::{ - proto::rr::{rdata, RData, RecordType}, + op::Query, + proto::{ + rr::{rdata, Name, RData, RecordType}, + xfer::DnsResponse, + }, rr::Record, }; - fn create_txt_record(contents: Vec<&str>) -> Record { + fn create_txt_record(contents: Vec<&str>) -> DnsResponse { + let resp_query = Query::query(Name::from_str("test.local.").unwrap(), RecordType::A); let mut record = Record::new(); record .set_record_type(RecordType::TXT) .set_rdata(RData::TXT(rdata::TXT::new( contents.into_iter().map(ToString::to_string).collect(), ))); - record + + mock::message(resp_query, vec![record], vec![], vec![]).into() } mod update_spec { @@ -220,7 +227,6 @@ mod test { mod dns_software_update { use super::*; use crate::DEFAULT_DNS_NAME_SERVER; - use std::{collections::HashMap, iter::FromIterator}; impl AutoUpdateConfig { fn get_test_defaults() -> Self { @@ -238,15 +244,15 @@ mod test { } } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_ignores_non_conforming_txt_entries() { - let records = HashMap::from_iter([("test.local.", vec![ - create_txt_record(vec![":::"]), - create_txt_record(vec!["base-node:::"]), - create_txt_record(vec!["base-node::1.0:"]), - create_txt_record(vec!["base-node:android-armv7:0.1.0:abcdef"]), - create_txt_record(vec!["base-node:linux-x86_64:1.0.0:bada55"]), - ])]); + let records = vec![ + Ok(create_txt_record(vec![":::"])), + Ok(create_txt_record(vec!["base-node:::"])), + 
Ok(create_txt_record(vec!["base-node::1.0:"])), + Ok(create_txt_record(vec!["base-node:android-armv7:0.1.0:abcdef"])), + Ok(create_txt_record(vec!["base-node:linux-x86_64:1.0.0:bada55"])), + ]; let updater = DnsSoftwareUpdate { client: DnsClient::connect_mock(records).await.unwrap(), config: AutoUpdateConfig::get_test_defaults(), @@ -258,12 +264,12 @@ mod test { assert!(spec.is_none()); } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_returns_best_update() { - let records = HashMap::from_iter([("test.local.", vec![ - create_txt_record(vec!["base-node:linux-x86_64:1.0.0:abcdef"]), - create_txt_record(vec!["base-node:linux-x86_64:1.0.1:abcdef01"]), - ])]); + let records = vec![ + Ok(create_txt_record(vec!["base-node:linux-x86_64:1.0.0:abcdef"])), + Ok(create_txt_record(vec!["base-node:linux-x86_64:1.0.1:abcdef01"])), + ]; let updater = DnsSoftwareUpdate { client: DnsClient::connect_mock(records).await.unwrap(), config: AutoUpdateConfig::get_test_defaults(), diff --git a/base_layer/p2p/src/auto_update/mod.rs b/base_layer/p2p/src/auto_update/mod.rs index 21de03b693..8ce17b5f2b 100644 --- a/base_layer/p2p/src/auto_update/mod.rs +++ b/base_layer/p2p/src/auto_update/mod.rs @@ -46,7 +46,7 @@ use std::{ use tari_common::configuration::bootstrap::ApplicationType; use tari_utilities::hex::Hex; -const LOG_TARGET: &str = "p2p::auto-update"; +const LOG_TARGET: &str = "p2p::auto_update"; #[derive(Debug, Clone)] pub struct AutoUpdateConfig { @@ -58,6 +58,12 @@ pub struct AutoUpdateConfig { pub hashes_sig_url: String, } +impl AutoUpdateConfig { + pub fn is_update_enabled(&self) -> bool { + !self.update_uris.is_empty() + } +} + pub async fn check_for_updates( app: ApplicationType, arch: &str, diff --git a/base_layer/p2p/src/auto_update/service.rs b/base_layer/p2p/src/auto_update/service.rs index a786d84ad3..a235ec6fe0 100644 --- a/base_layer/p2p/src/auto_update/service.rs +++ b/base_layer/p2p/src/auto_update/service.rs @@ -24,19 +24,19 @@ use crate::{ auto_update, 
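The DNS auto-update tests above feed TXT records shaped like `app:arch:version:hash` and expect malformed entries (missing or empty fields) to be skipped. A hypothetical std-only sketch of that validation rule -- the field layout is inferred from the test data, not taken from the crate's actual parser:

```rust
// Split a TXT entry into its four colon-separated fields, rejecting entries
// with missing or empty fields, mirroring `it_ignores_non_conforming_txt_entries`.
fn parse_update(entry: &str) -> Option<(String, String, String, String)> {
    let mut parts = entry.splitn(4, ':');
    let app = parts.next()?.to_string();
    let arch = parts.next()?.to_string();
    let version = parts.next()?.to_string();
    let hash = parts.next()?.to_string();
    if [&app, &arch, &version, &hash].iter().any(|s| s.is_empty()) {
        return None; // non-conforming entry: ignore it
    }
    Some((app, arch, version, hash))
}
```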
auto_update::{AutoUpdateConfig, SoftwareUpdate, Version}, }; -use futures::{ - channel::{mpsc, oneshot}, - future::Either, - stream, - SinkExt, - StreamExt, -}; +use futures::{future::Either, stream, StreamExt}; +use log::*; use std::{env::consts, time::Duration}; use tari_common::configuration::bootstrap::ApplicationType; use tari_service_framework::{async_trait, ServiceInitializationError, ServiceInitializer, ServiceInitializerContext}; -use tokio::{sync::watch, time}; +use tokio::{ + sync::{mpsc, oneshot, watch}, + time, + time::MissedTickBehavior, +}; +use tokio_stream::wrappers; -const LOG_TARGET: &str = "app:auto-update"; +const LOG_TARGET: &str = "p2p::auto_update"; /// A watch notifier that contains the latest software update, if any pub type SoftwareUpdateNotifier = watch::Receiver>; @@ -94,20 +94,25 @@ impl SoftwareUpdaterService { new_update_notification: watch::Receiver>, ) { let mut interval_or_never = match self.check_interval { - Some(interval) => Either::Left(time::interval(interval)).fuse(), - None => Either::Right(stream::empty()).fuse(), + Some(interval) => { + let mut interval = time::interval(interval); + interval.set_missed_tick_behavior(MissedTickBehavior::Skip); + Either::Left(wrappers::IntervalStream::new(interval)) + }, + None => Either::Right(stream::empty()), }; loop { let last_version = new_update_notification.borrow().clone(); - let maybe_update = futures::select! { - reply = request_rx.select_next_some() => { + let maybe_update = tokio::select! { + Some(reply) = request_rx.recv() => { let maybe_update = self.check_for_updates().await; let _ = reply.send(maybe_update.clone()); maybe_update }, - _ = interval_or_never.next() => { + + Some(_) = interval_or_never.next() => { // Periodically, check for updates if configured to do so. 
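The updater's interval is configured with `MissedTickBehavior::Skip`, so a slow update check does not produce a burst of catch-up ticks afterwards. A rough sketch of that deadline arithmetic in plain std (this is an illustration of the idea, not tokio's implementation):

```rust
use std::time::{Duration, Instant};

// If work overran one or more periods, jump the deadline forward past `now`
// instead of firing once per missed period.
fn next_deadline(last: Instant, now: Instant, period: Duration) -> Instant {
    let mut next = last + period;
    while next <= now {
        next += period; // skip every tick that elapsed while we were busy
    }
    next
}
```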
// If an update is found the new update notifier will be triggered and any listeners notified self.check_for_updates().await @@ -121,7 +126,7 @@ impl SoftwareUpdaterService { .map(|up| up.version() < update.version()) .unwrap_or(true) { - let _ = notifier.broadcast(Some(update.clone())); + let _ = notifier.send(Some(update.clone())); } } } @@ -133,6 +138,13 @@ impl SoftwareUpdaterService { "Checking for updates ({})...", self.config.update_uris.join(", ") ); + if !self.config.is_update_enabled() { + warn!( + target: LOG_TARGET, + "Check for updates has been called but auto update has been disabled in the config" + ); + return None; + } let arch = format!("{}-{}", consts::OS, consts::ARCH); diff --git a/base_layer/p2p/src/comms_connector/inbound_connector.rs b/base_layer/p2p/src/comms_connector/inbound_connector.rs index ed16cd578d..6feed82be7 100644 --- a/base_layer/p2p/src/comms_connector/inbound_connector.rs +++ b/base_layer/p2p/src/comms_connector/inbound_connector.rs @@ -22,45 +22,42 @@ use super::peer_message::PeerMessage; use anyhow::anyhow; -use futures::{task::Context, Future, Sink, SinkExt}; +use futures::{task::Context, Future}; use log::*; use std::{pin::Pin, sync::Arc, task::Poll}; use tari_comms::pipeline::PipelineError; use tari_comms_dht::{domain_message::MessageHeader, inbound::DecryptedDhtMessage}; +use tokio::sync::mpsc; use tower::Service; const LOG_TARGET: &str = "comms::middleware::inbound_connector"; /// This service receives DecryptedDhtMessage, deserializes the MessageHeader and /// sends a `PeerMessage` on the given sink. 
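The `notifier.broadcast(..)` call becomes `notifier.send(..)` because tokio 1.x renamed the `watch::Sender` method; the semantics are unchanged. A minimal std sketch of those watch-channel semantics -- receivers only ever observe the most recent value (names here are illustrative):

```rust
use std::sync::{Arc, RwLock};

// A shared "latest value" cell: `send` replaces the value, `borrow` reads
// whatever was sent most recently, like tokio's watch channel.
#[derive(Clone)]
struct Watch<T>(Arc<RwLock<T>>);

impl<T: Clone> Watch<T> {
    fn new(initial: T) -> Self {
        Watch(Arc::new(RwLock::new(initial)))
    }

    fn send(&self, value: T) {
        *self.0.write().unwrap() = value;
    }

    fn borrow(&self) -> T {
        self.0.read().unwrap().clone()
    }
}
```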
#[derive(Clone)] -pub struct InboundDomainConnector { - sink: TSink, +pub struct InboundDomainConnector { + sink: mpsc::Sender>, } -impl InboundDomainConnector { - pub fn new(sink: TSink) -> Self { +impl InboundDomainConnector { + pub fn new(sink: mpsc::Sender>) -> Self { Self { sink } } } -impl Service for InboundDomainConnector -where - TSink: Sink> + Unpin + Clone + 'static, - TSink::Error: std::error::Error + Send + Sync + 'static, -{ +impl Service for InboundDomainConnector { type Error = PipelineError; - type Future = Pin>>>; + type Future = Pin> + Send>>; type Response = (); - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.sink).poll_ready(cx).map_err(Into::into) + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, msg: DecryptedDhtMessage) -> Self::Future { - let mut sink = self.sink.clone(); + let sink = self.sink.clone(); let future = async move { let peer_message = Self::construct_peer_message(msg)?; - // If this fails there is something wrong with the sink and the pubsub middleware should not + // If this fails the channel has closed and the pubsub middleware should not // continue sink.send(Arc::new(peer_message)).await?; @@ -70,7 +67,7 @@ where } } -impl InboundDomainConnector { +impl InboundDomainConnector { fn construct_peer_message(mut inbound_message: DecryptedDhtMessage) -> Result { let envelope_body = inbound_message .success_mut() @@ -107,41 +104,17 @@ impl InboundDomainConnector { } } -impl Sink for InboundDomainConnector -where - TSink: Sink> + Unpin, - TSink::Error: Into + Send + Sync + 'static, -{ - type Error = PipelineError; - - fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.sink).poll_ready(cx).map_err(Into::into) - } - - fn start_send(mut self: Pin<&mut Self>, item: DecryptedDhtMessage) -> Result<(), Self::Error> { - let item = Self::construct_peer_message(item)?; - Pin::new(&mut 
self.sink).start_send(Arc::new(item)).map_err(Into::into) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.sink).poll_flush(cx).map_err(Into::into) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.sink).poll_close(cx).map_err(Into::into) - } -} - #[cfg(test)] mod test { use super::*; use crate::test_utils::{make_dht_inbound_message, make_node_identity}; - use futures::{channel::mpsc, executor::block_on, StreamExt}; + use futures::executor::block_on; use tari_comms::{message::MessageExt, wrap_in_envelope_body}; use tari_comms_dht::domain_message::MessageHeader; + use tokio::sync::mpsc; use tower::ServiceExt; - #[tokio_macros::test_basic] + #[tokio::test] async fn handle_message() { let (tx, mut rx) = mpsc::channel(1); let header = MessageHeader::new(123); @@ -151,12 +124,12 @@ mod test { let decrypted = DecryptedDhtMessage::succeeded(msg, None, inbound_message); InboundDomainConnector::new(tx).oneshot(decrypted).await.unwrap(); - let peer_message = block_on(rx.next()).unwrap(); + let peer_message = block_on(rx.recv()).unwrap(); assert_eq!(peer_message.message_header.message_type, 123); assert_eq!(peer_message.decode_message::().unwrap(), "my message"); } - #[tokio_macros::test_basic] + #[tokio::test] async fn send_on_sink() { let (tx, mut rx) = mpsc::channel(1); let header = MessageHeader::new(123); @@ -165,14 +138,14 @@ mod test { let inbound_message = make_dht_inbound_message(&make_node_identity(), msg.to_encoded_bytes()); let decrypted = DecryptedDhtMessage::succeeded(msg, None, inbound_message); - InboundDomainConnector::new(tx).send(decrypted).await.unwrap(); + InboundDomainConnector::new(tx).call(decrypted).await.unwrap(); - let peer_message = block_on(rx.next()).unwrap(); + let peer_message = block_on(rx.recv()).unwrap(); assert_eq!(peer_message.message_header.message_type, 123); assert_eq!(peer_message.decode_message::().unwrap(), "my message"); } - 
#[tokio_macros::test_basic] + #[tokio::test] async fn handle_message_fail_deserialize() { let (tx, mut rx) = mpsc::channel(1); let header = b"dodgy header".to_vec(); @@ -182,10 +155,11 @@ mod test { let decrypted = DecryptedDhtMessage::succeeded(msg, None, inbound_message); InboundDomainConnector::new(tx).oneshot(decrypted).await.unwrap_err(); - assert!(rx.try_next().unwrap().is_none()); + rx.close(); + assert!(rx.recv().await.is_none()); } - #[tokio_macros::test_basic] + #[tokio::test] async fn handle_message_fail_send() { // Drop the receiver of the channel, this is the only reason this middleware should return an error // from it's call function diff --git a/base_layer/p2p/src/comms_connector/pubsub.rs b/base_layer/p2p/src/comms_connector/pubsub.rs index ae01e8ced8..198eee63f2 100644 --- a/base_layer/p2p/src/comms_connector/pubsub.rs +++ b/base_layer/p2p/src/comms_connector/pubsub.rs @@ -22,11 +22,15 @@ use super::peer_message::PeerMessage; use crate::{comms_connector::InboundDomainConnector, tari_message::TariMessageType}; -use futures::{channel::mpsc, future, stream::Fuse, Stream, StreamExt}; +use futures::{future, Stream, StreamExt}; use log::*; use std::{cmp, fmt::Debug, sync::Arc, time::Duration}; use tari_comms::rate_limit::RateLimit; -use tokio::{runtime::Handle, sync::broadcast}; +use tokio::{ + sync::{broadcast, mpsc}, + task, +}; +use tokio_stream::wrappers; const LOG_TARGET: &str = "comms::middleware::pubsub"; @@ -35,16 +39,11 @@ const RATE_LIMIT_MIN_CAPACITY: usize = 5; const RATE_LIMIT_RESTOCK_INTERVAL: Duration = Duration::from_millis(1000); /// Alias for a pubsub-type domain connector -pub type PubsubDomainConnector = InboundDomainConnector>>; +pub type PubsubDomainConnector = InboundDomainConnector; pub type SubscriptionFactory = TopicSubscriptionFactory>; /// Connects `InboundDomainConnector` to a `tari_pubsub::TopicPublisher` through a buffered broadcast channel -pub fn pubsub_connector( - // TODO: Remove this arg in favor of task::spawn - 
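The migrated test above closes the tokio receiver and asserts that `recv` yields `None`. Std channels have the same drained-then-disconnected behaviour, sketched here:

```rust
use std::sync::mpsc;

// Once every sender is dropped, the receiver drains any buffered items and
// then reports the channel as finished (iteration ends).
fn drain<T>(rx: mpsc::Receiver<T>) -> Vec<T> {
    rx.into_iter().collect()
}
```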
executor: Handle, - buf_size: usize, - rate_limit: usize, -) -> (PubsubDomainConnector, SubscriptionFactory) { +pub fn pubsub_connector(buf_size: usize, rate_limit: usize) -> (PubsubDomainConnector, SubscriptionFactory) { let (publisher, subscription_factory) = pubsub_channel(buf_size); let (sender, receiver) = mpsc::channel(buf_size); trace!( @@ -55,8 +54,8 @@ pub fn pubsub_connector( ); // Spawn a task which forwards messages from the pubsub service to the TopicPublisher - executor.spawn(async move { - let forwarder = receiver + task::spawn(async move { + wrappers::ReceiverStream::new(receiver) // Rate limit the receiver; the sender will adhere to the limit .rate_limit(cmp::max(rate_limit, RATE_LIMIT_MIN_CAPACITY), RATE_LIMIT_RESTOCK_INTERVAL) // Map DomainMessage into a TopicPayload @@ -89,8 +88,7 @@ pub fn pubsub_connector( ); } future::ready(()) - }); - forwarder.await; + }).await; }); (InboundDomainConnector::new(sender), subscription_factory) } @@ -98,8 +96,8 @@ pub fn pubsub_connector( /// Create a topic-based pub-sub channel fn pubsub_channel(size: usize) -> (TopicPublisher, TopicSubscriptionFactory) where - T: Clone + Debug + Send + Eq, - M: Send + Clone, + T: Clone + Debug + Send + Eq + 'static, + M: Send + Clone + 'static, { let (publisher, _) = broadcast::channel(size); (publisher.clone(), TopicSubscriptionFactory::new(publisher)) @@ -138,8 +136,8 @@ pub struct TopicSubscriptionFactory { impl TopicSubscriptionFactory where - T: Clone + Eq + Debug + Send, - M: Clone + Send, + T: Clone + Eq + Debug + Send + 'static, + M: Clone + Send + 'static, { pub fn new(sender: broadcast::Sender>) -> Self { TopicSubscriptionFactory { sender } @@ -148,38 +146,22 @@ where /// Create a subscription stream to a particular topic. The provided label is used to identify which consumer is /// lagging. 
pub fn get_subscription(&self, topic: T, label: &'static str) -> impl Stream { - self.sender - .subscribe() - .filter_map({ - let topic = topic.clone(); - move |result| { - let opt = match result { - Ok(payload) => Some(payload), - Err(broadcast::RecvError::Closed) => None, - Err(broadcast::RecvError::Lagged(n)) => { - warn!( - target: LOG_TARGET, - "Subscription '{}' for topic '{:?}' lagged. {} message(s) dropped.", label, topic, n - ); - None - }, - }; - future::ready(opt) - } - }) - .filter_map(move |item| { - let opt = if item.topic() == &topic { - Some(item.message) - } else { - None + wrappers::BroadcastStream::new(self.sender.subscribe()).filter_map({ + move |result| { + let opt = match result { + Ok(payload) if *payload.topic() == topic => Some(payload.message), + Ok(_) => None, + Err(wrappers::errors::BroadcastStreamRecvError::Lagged(n)) => { + warn!( + target: LOG_TARGET, + "Subscription '{}' for topic '{:?}' lagged. {} message(s) dropped.", label, topic, n + ); + None + }, }; future::ready(opt) - }) - } - - /// Convenience function that returns a fused (`stream::Fuse`) version of the subscription stream. - pub fn get_subscription_fused(&self, topic: T, label: &'static str) -> Fuse> { - self.get_subscription(topic, label).fuse() + } + }) } } @@ -190,7 +172,7 @@ mod test { use std::time::Duration; use tari_test_utils::collect_stream; - #[tokio_macros::test_basic] + #[tokio::test] async fn topic_pub_sub() { let (publisher, subscriber_factory) = pubsub_channel(10); diff --git a/base_layer/p2p/src/dns/client.rs b/base_layer/p2p/src/dns/client.rs index 85e03f71e4..78093186ef 100644 --- a/base_layer/p2p/src/dns/client.rs +++ b/base_layer/p2p/src/dns/client.rs @@ -20,35 +20,28 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+#[cfg(test)] +use crate::dns::mock::{DefaultOnSend, MockClientHandle}; + use super::DnsClientError; -use futures::future; +use futures::{future, FutureExt}; use std::{net::SocketAddr, sync::Arc}; use tari_shutdown::Shutdown; use tokio::{net::UdpSocket, task}; use trust_dns_client::{ - client::{AsyncClient, AsyncDnssecClient}, - op::{DnsResponse, Query}, - proto::{ - rr::dnssec::TrustAnchor, - udp::{UdpClientStream, UdpResponse}, - xfer::DnsRequestOptions, - DnsHandle, - }, + client::{AsyncClient, AsyncDnssecClient, ClientHandle}, + op::Query, + proto::{error::ProtoError, rr::dnssec::TrustAnchor, udp::UdpClientStream, xfer::DnsResponse, DnsHandle}, rr::{DNSClass, IntoName, RecordType}, serialize::binary::BinEncoder, }; -#[cfg(test)] -use std::collections::HashMap; -#[cfg(test)] -use trust_dns_client::{proto::xfer::DnsMultiplexerSerialResponse, rr::Record}; - #[derive(Clone)] pub enum DnsClient { - Secure(Client>), - Normal(Client>), + Secure(Client), + Normal(Client), #[cfg(test)] - Mock(Client>), + Mock(Client>), } impl DnsClient { @@ -63,18 +56,18 @@ impl DnsClient { } #[cfg(test)] - pub async fn connect_mock(records: HashMap<&'static str, Vec>) -> Result { - let client = Client::connect_mock(records).await?; + pub async fn connect_mock(messages: Vec>) -> Result { + let client = Client::connect_mock(messages).await?; Ok(DnsClient::Mock(client)) } - pub async fn lookup(&mut self, query: Query, options: DnsRequestOptions) -> Result { + pub async fn lookup(&mut self, query: Query) -> Result { use DnsClient::*; match self { - Secure(ref mut client) => client.lookup(query, options).await, - Normal(ref mut client) => client.lookup(query, options).await, + Secure(ref mut client) => client.lookup(query).await, + Normal(ref mut client) => client.lookup(query).await, #[cfg(test)] - Mock(ref mut client) => client.lookup(query, options).await, + Mock(ref mut client) => client.lookup(query).await, } } @@ -85,11 +78,11 @@ impl DnsClient { .set_query_class(DNSClass::IN) 
.set_query_type(RecordType::TXT); - let response = self.lookup(query, Default::default()).await?; + let responses = self.lookup(query).await?; - let records = response - .messages() - .flat_map(|msg| msg.answers()) + let records = responses + .answers() + .iter() .map(|answer| { let data = answer.rdata(); let mut buf = Vec::new(); @@ -116,7 +109,7 @@ pub struct Client { shutdown: Arc, } -impl Client> { +impl Client { pub async fn connect_secure(name_server: SocketAddr, trust_anchor: TrustAnchor) -> Result { let shutdown = Shutdown::new(); let stream = UdpClientStream::::new(name_server); @@ -124,7 +117,7 @@ impl Client> { .trust_anchor(trust_anchor) .build() .await?; - task::spawn(future::select(shutdown.to_signal(), background)); + task::spawn(future::select(shutdown.to_signal(), background.fuse())); Ok(Self { inner: client, @@ -133,12 +126,12 @@ impl Client> { } } -impl Client> { +impl Client { pub async fn connect(name_server: SocketAddr) -> Result { let shutdown = Shutdown::new(); let stream = UdpClientStream::::new(name_server); let (client, background) = AsyncClient::connect(stream).await?; - task::spawn(future::select(shutdown.to_signal(), background)); + task::spawn(future::select(shutdown.to_signal(), background.fuse())); Ok(Self { inner: client, @@ -148,87 +141,31 @@ impl Client> { } impl Client -where C: DnsHandle +where C: DnsHandle { - pub async fn lookup(&mut self, query: Query, options: DnsRequestOptions) -> Result { - let resp = self.inner.lookup(query, options).await?; - Ok(resp) + pub async fn lookup(&mut self, query: Query) -> Result { + let client_resp = self + .inner + .query(query.name().clone(), query.query_class(), query.query_type()) + .await?; + Ok(client_resp) } } #[cfg(test)] mod mock { use super::*; - use futures::{channel::mpsc, future, Stream, StreamExt}; - use std::{ - fmt, - fmt::Display, - net::SocketAddr, - pin::Pin, - sync::Arc, - task::{Context, Poll}, - }; + use crate::dns::mock::{DefaultOnSend, MockClientHandle}; + use 
std::sync::Arc; use tari_shutdown::Shutdown; - use tokio::task; - use trust_dns_client::{ - client::AsyncClient, - op::Message, - proto::{ - error::ProtoError, - xfer::{DnsClientStream, DnsMultiplexerSerialResponse, SerialMessage}, - StreamHandle, - }, - rr::Record, - }; - - pub struct MockStream { - receiver: mpsc::UnboundedReceiver>, - answers: HashMap<&'static str, Vec>, - } - - impl DnsClientStream for MockStream { - fn name_server_addr(&self) -> SocketAddr { - ([0u8, 0, 0, 0], 53).into() - } - } - - impl Display for MockStream { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "MockStream") - } - } - - impl Stream for MockStream { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let req = match futures::ready!(self.receiver.poll_next_unpin(cx)) { - Some(r) => r, - None => return Poll::Ready(None), - }; - let req = Message::from_vec(&req).unwrap(); - let name = req.queries()[0].name().to_string(); - let mut msg = Message::new(); - let answers = self.answers.get(name.as_str()).into_iter().flatten().cloned(); - msg.set_id(req.id()).add_answers(answers); - Poll::Ready(Some(Ok(SerialMessage::new( - msg.to_vec().unwrap(), - self.name_server_addr(), - )))) - } - } - - impl Client> { - pub async fn connect_mock(answers: HashMap<&'static str, Vec>) -> Result { - let (tx, rx) = mpsc::unbounded(); - let stream = future::ready(Ok(MockStream { receiver: rx, answers })); - let (client, background) = AsyncClient::new(stream, Box::new(StreamHandle::new(tx)), None).await?; + use trust_dns_client::proto::error::ProtoError; - let shutdown = Shutdown::new(); - task::spawn(future::select(shutdown.to_signal(), background)); + impl Client> { + pub async fn connect_mock(messages: Vec>) -> Result { + let client = MockClientHandle::mock(messages); Ok(Self { inner: client, - shutdown: Arc::new(shutdown), + shutdown: Arc::new(Shutdown::new()), }) } } diff --git a/base_layer/p2p/src/dns/mock.rs 
b/base_layer/p2p/src/dns/mock.rs new file mode 100644 index 0000000000..9dec5155cb --- /dev/null +++ b/base_layer/p2p/src/dns/mock.rs @@ -0,0 +1,105 @@ +// Copyright 2021, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +use futures::{future, stream, Future}; +use std::{error::Error, pin::Pin, sync::Arc}; +use trust_dns_client::{ + op::{Message, Query}, + proto::{ + error::ProtoError, + xfer::{DnsHandle, DnsRequest, DnsResponse}, + }, + rr::Record, +}; + +#[derive(Clone)] +pub struct MockClientHandle { + messages: Arc>>, + on_send: O, +} + +impl MockClientHandle { + /// Constructs a mock client handle that folds all supplied messages into a single response + pub fn mock(messages: Vec>) -> Self { + MockClientHandle { + messages: Arc::new(messages), + on_send: DefaultOnSend, + } + } +} + +impl DnsHandle for MockClientHandle +where E: From + Error + Clone + Send + Sync + Unpin + 'static +{ + type Error = E; + type Response = stream::Once>>; + + fn send>(&mut self, _: R) -> Self::Response { + let responses = (*self.messages) + .clone() + .into_iter() + .fold(Result::<_, E>::Ok(Message::new()), |msg, resp| { + msg.and_then(|mut msg| { + resp.map(move |resp| { + msg.add_answers(resp.answers().iter().cloned()); + msg + }) + }) + }) + .map(DnsResponse::from); + + stream::once(future::ready(responses)) + } +} + +pub fn message(query: Query, answers: Vec, name_servers: Vec, additionals: Vec) -> Message { + let mut message = Message::new(); + message.add_query(query); + message.insert_answers(answers); + message.insert_name_servers(name_servers); + message.insert_additionals(additionals); + message +} + +pub trait OnSend: Clone + Send + Sync + 'static { + fn on_send( + &mut self, + response: Result, + ) -> Pin> + Send>> + where + E: From + Send + 'static, + { + Box::pin(future::ready(response)) + } +} + +#[derive(Clone)] +pub struct DefaultOnSend; + +impl OnSend for DefaultOnSend {} diff --git a/base_layer/p2p/src/dns/mod.rs b/base_layer/p2p/src/dns/mod.rs index 197b788236..8c49de1993 ---
a/base_layer/p2p/src/dns/mod.rs +++ b/base_layer/p2p/src/dns/mod.rs @@ -4,6 +4,9 @@ pub use client::DnsClient; mod error; pub use error::DnsClientError; +#[cfg(test)] +pub(crate) mod mock; + use trust_dns_client::proto::rr::dnssec::{public_key::Rsa, TrustAnchor}; #[inline] diff --git a/base_layer/p2p/src/initialization.rs b/base_layer/p2p/src/initialization.rs index 9264a5033c..0915479612 100644 --- a/base_layer/p2p/src/initialization.rs +++ b/base_layer/p2p/src/initialization.rs @@ -19,20 +19,20 @@ // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#![allow(dead_code)] use crate::{ - comms_connector::{InboundDomainConnector, PeerMessage, PubsubDomainConnector}, + comms_connector::{InboundDomainConnector, PubsubDomainConnector}, peer_seeds::{DnsSeedResolver, SeedPeer}, transport::{TorConfig, TransportType}, MAJOR_NETWORK_VERSION, MINOR_NETWORK_VERSION, }; use fs2::FileExt; -use futures::{channel::mpsc, future, Sink}; +use futures::future; use log::*; use rand::{distributions::Alphanumeric, thread_rng, Rng}; use std::{ - error::Error, fs::File, iter, net::SocketAddr, @@ -47,7 +47,6 @@ use tari_comms::{ multiaddr::Multiaddr, peer_manager::{NodeIdentity, Peer, PeerFeatures, PeerManagerError}, pipeline, - pipeline::SinkService, protocol::{ messaging::{MessagingEventSender, MessagingProtocolExtension}, rpc::RpcServer, @@ -71,7 +70,7 @@ use tari_storage::{ LMDBWrapper, }; use thiserror::Error; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc}; use tower::ServiceBuilder; const LOG_TARGET: &str = "p2p::initialization"; @@ -158,18 +157,14 @@ pub struct CommsConfig { } /// Initialize Tari Comms configured for tests -pub async fn initialize_local_test_comms( +pub async fn initialize_local_test_comms( 
node_identity: Arc, - connector: InboundDomainConnector, + connector: InboundDomainConnector, data_path: &str, discovery_request_timeout: Duration, seed_peers: Vec, shutdown_signal: ShutdownSignal, -) -> Result<(CommsNode, Dht, MessagingEventSender), CommsInitializationError> -where - TSink: Sink> + Unpin + Clone + Send + Sync + 'static, - TSink::Error: Error + Send + Sync, -{ +) -> Result<(CommsNode, Dht, MessagingEventSender), CommsInitializationError> { let peer_database_name = { let mut rng = thread_rng(); iter::repeat(()) @@ -230,7 +225,7 @@ where .with_inbound_pipeline( ServiceBuilder::new() .layer(dht.inbound_middleware_layer()) - .service(SinkService::new(connector)), + .service(connector), ) .build(); @@ -319,15 +314,11 @@ async fn initialize_hidden_service( builder.build().await } -async fn configure_comms_and_dht( +async fn configure_comms_and_dht( builder: CommsBuilder, config: &CommsConfig, - connector: InboundDomainConnector, -) -> Result<(UnspawnedCommsNode, Dht), CommsInitializationError> -where - TSink: Sink> + Unpin + Clone + Send + Sync + 'static, - TSink::Error: Error + Send + Sync, -{ + connector: InboundDomainConnector, +) -> Result<(UnspawnedCommsNode, Dht), CommsInitializationError> { let file_lock = acquire_exclusive_file_lock(&config.datastore_path)?; let datastore = LMDBBuilder::new() @@ -391,7 +382,7 @@ where .with_inbound_pipeline( ServiceBuilder::new() .layer(dht.inbound_middleware_layer()) - .service(SinkService::new(connector)), + .service(connector), ) .build(); diff --git a/base_layer/p2p/src/lib.rs b/base_layer/p2p/src/lib.rs index c21dace083..93cb45d1db 100644 --- a/base_layer/p2p/src/lib.rs +++ b/base_layer/p2p/src/lib.rs @@ -20,8 +20,6 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -// Needed to make futures::select! 
work -#![recursion_limit = "256"] #![cfg_attr(not(debug_assertions), deny(unused_variables))] #![cfg_attr(not(debug_assertions), deny(unused_imports))] #![cfg_attr(not(debug_assertions), deny(dead_code))] diff --git a/base_layer/p2p/src/peer_seeds.rs b/base_layer/p2p/src/peer_seeds.rs index 24ce3c4dda..4698770fa1 100644 --- a/base_layer/p2p/src/peer_seeds.rs +++ b/base_layer/p2p/src/peer_seeds.rs @@ -120,6 +120,8 @@ mod test { use super::*; use tari_utilities::hex::Hex; + const TEST_NAME: &str = "test.local."; + mod peer_seed { use super::*; @@ -182,76 +184,88 @@ mod test { mod peer_seed_resolver { use super::*; - use std::{collections::HashMap, iter::FromIterator}; - use trust_dns_client::rr::{rdata, RData, Record, RecordType}; + use crate::dns::mock; + use trust_dns_client::{ + proto::{ + op::Query, + rr::{DNSClass, Name}, + xfer::DnsResponse, + }, + rr::{rdata, RData, Record, RecordType}, + }; #[ignore = "This test requires network IO and is mostly useful during development"] - #[tokio_macros::test] - async fn it_returns_an_empty_vec_if_all_seeds_are_invalid() { + #[tokio::test] + async fn it_returns_seeds_from_real_address() { let mut resolver = DnsSeedResolver { client: DnsClient::connect("1.1.1.1:53".parse().unwrap()).await.unwrap(), }; - let seeds = resolver.resolve("tari.com").await.unwrap(); - assert!(seeds.is_empty()); + let seeds = resolver.resolve("seeds.weatherwax.tari.com").await.unwrap(); + assert!(!seeds.is_empty()); } - fn create_txt_record(contents: Vec<&str>) -> Record { + fn create_txt_record(contents: Vec<&str>) -> DnsResponse { + let mut resp_query = Query::query(Name::from_str(TEST_NAME).unwrap(), RecordType::TXT); + resp_query.set_query_class(DNSClass::IN); let mut record = Record::new(); record .set_record_type(RecordType::TXT) .set_rdata(RData::TXT(rdata::TXT::new( contents.into_iter().map(ToString::to_string).collect(), ))); - record + + mock::message(resp_query, vec![record], vec![], vec![]).into() } - #[tokio_macros::test] + 
#[tokio::test] async fn it_returns_peer_seeds() { - let records = HashMap::from_iter([("test.local.", vec![ + let records = vec![ // Multiple addresses(works) - create_txt_record(vec![ - "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a::/ip4/127.0.0.1/tcp/8000::/\ + Ok(create_txt_record(vec![ + "fab24c542183073996ddf3a6c73ff8b8562fed351d252ec5cb8f269d1ad92f0c::/ip4/127.0.0.1/tcp/8000::/\ onion3/bsmuof2cn4y2ysz253gzsvg3s72fcgh4f3qcm3hdlxdtcwe6al2dicyd:1234", - ]), + ])), // Misc - create_txt_record(vec!["v=spf1 include:_spf.spf.com ~all"]), + Ok(create_txt_record(vec!["v=spf1 include:_spf.spf.com ~all"])), // Single address (works) - create_txt_record(vec![ + Ok(create_txt_record(vec![ "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a::/ip4/127.0.0.1/tcp/8000", - ]), + ])), // Single address trailing delim - create_txt_record(vec![ + Ok(create_txt_record(vec![ "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a::/ip4/127.0.0.1/tcp/8000::", - ]), + ])), // Invalid public key - create_txt_record(vec![ + Ok(create_txt_record(vec![ "07e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a::/ip4/127.0.0.1/tcp/8000", - ]), + ])), // No Address with delim - create_txt_record(vec![ + Ok(create_txt_record(vec![ "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a::", - ]), + ])), // No Address no delim - create_txt_record(vec!["06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a"]), + Ok(create_txt_record(vec![ + "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a", + ])), // Invalid address - create_txt_record(vec![ + Ok(create_txt_record(vec![ "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a::/onion3/invalid:1234", - ]), - ])]); + ])), + ]; let mut resolver = DnsSeedResolver { client: DnsClient::connect_mock(records).await.unwrap(), }; - let seeds = resolver.resolve("test.local.").await.unwrap(); + let seeds = resolver.resolve(TEST_NAME).await.unwrap(); 
assert_eq!(seeds.len(), 2); assert_eq!( seeds[0].public_key.to_hex(), - "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a" + "fab24c542183073996ddf3a6c73ff8b8562fed351d252ec5cb8f269d1ad92f0c" ); + assert_eq!(seeds[0].addresses.len(), 2); assert_eq!( seeds[1].public_key.to_hex(), "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a" ); - assert_eq!(seeds[0].addresses.len(), 2); assert_eq!(seeds[1].addresses.len(), 1); } } diff --git a/base_layer/p2p/src/services/liveness/mock.rs b/base_layer/p2p/src/services/liveness/mock.rs index fa5f2f72c1..470cca84c2 100644 --- a/base_layer/p2p/src/services/liveness/mock.rs +++ b/base_layer/p2p/src/services/liveness/mock.rs @@ -36,9 +36,9 @@ use std::sync::{ RwLock, }; -use tari_crypto::tari_utilities::acquire_write_lock; +use tari_crypto::tari_utilities::{acquire_read_lock, acquire_write_lock}; use tari_service_framework::{reply_channel, reply_channel::RequestContext}; -use tokio::sync::{broadcast, broadcast::SendError}; +use tokio::sync::{broadcast, broadcast::error::SendError}; const LOG_TARGET: &str = "p2p::liveness_mock"; @@ -69,7 +69,8 @@ impl LivenessMockState { } pub async fn publish_event(&self, event: LivenessEvent) -> Result<(), SendError>> { - acquire_write_lock!(self.event_publisher).send(Arc::new(event))?; + let lock = acquire_read_lock!(self.event_publisher); + lock.send(Arc::new(event))?; Ok(()) } diff --git a/base_layer/p2p/src/services/liveness/service.rs b/base_layer/p2p/src/services/liveness/service.rs index 65e85e0ea8..0f35122ea9 100644 --- a/base_layer/p2p/src/services/liveness/service.rs +++ b/base_layer/p2p/src/services/liveness/service.rs @@ -49,6 +49,7 @@ use tari_comms_dht::{ use tari_service_framework::reply_channel::RequestContext; use tari_shutdown::ShutdownSignal; use tokio::time; +use tokio_stream::wrappers; /// Service responsible for testing Liveness of Peers. 
pub struct LivenessService { @@ -59,7 +60,7 @@ pub struct LivenessService { connectivity: ConnectivityRequester, outbound_messaging: OutboundMessageRequester, event_publisher: LivenessEventSender, - shutdown_signal: Option, + shutdown_signal: ShutdownSignal, } impl LivenessService @@ -85,7 +86,7 @@ where connectivity, outbound_messaging, event_publisher, - shutdown_signal: Some(shutdown_signal), + shutdown_signal, config, } } @@ -100,39 +101,36 @@ where pin_mut!(request_stream); let mut ping_tick = match self.config.auto_ping_interval { - Some(interval) => Either::Left(time::interval_at((Instant::now() + interval).into(), interval)), + Some(interval) => Either::Left(wrappers::IntervalStream::new(time::interval_at( + (Instant::now() + interval).into(), + interval, + ))), None => Either::Right(futures::stream::iter(iter::empty())), - } - .fuse(); - - let mut shutdown_signal = self - .shutdown_signal - .take() - .expect("Liveness service initialized without shutdown signal"); + }; loop { - futures::select! { + tokio::select! 
{ // Requests from the handle - request_context = request_stream.select_next_some() => { + Some(request_context) = request_stream.next() => { let (request, reply_tx) = request_context.split(); let _ = reply_tx.send(self.handle_request(request).await); }, // Tick events - _ = ping_tick.select_next_some() => { + Some(_) = ping_tick.next() => { if let Err(err) = self.start_ping_round().await { warn!(target: LOG_TARGET, "Error when pinging peers: {}", err); } }, // Incoming messages from the Comms layer - msg = ping_stream.select_next_some() => { + Some(msg) = ping_stream.next() => { if let Err(err) = self.handle_incoming_message(msg).await { warn!(target: LOG_TARGET, "Failed to handle incoming PingPong message: {}", err); } }, - _ = shutdown_signal => { + _ = self.shutdown_signal.wait() => { info!(target: LOG_TARGET, "Liveness service shutting down because the shutdown signal was received"); break; } @@ -143,11 +141,13 @@ where async fn handle_incoming_message(&mut self, msg: DomainMessage) -> Result<(), LivenessError> { let DomainMessage::<_> { source_peer, + dht_header, inner: ping_pong_msg, .. } = msg; let node_id = source_peer.node_id; let public_key = source_peer.public_key; + let message_tag = dht_header.message_tag; match ping_pong_msg.kind().ok_or(LivenessError::InvalidPingPongType)? { PingPong::Ping => { @@ -157,9 +157,10 @@ where debug!( target: LOG_TARGET, - "Received ping from peer '{}' with useragent '{}'", + "Received ping from peer '{}' with useragent '{}' (Trace: {})", node_id.short_str(), source_peer.user_agent, + message_tag, ); let ping_event = PingPongEvent::new(node_id, None, ping_pong_msg.metadata.into()); @@ -169,9 +170,10 @@ where if !self.state.is_inflight(ping_pong_msg.nonce) { debug!( target: LOG_TARGET, - "Received Pong that was not requested from '{}' with useragent {}. Ignoring it.", + "Received Pong that was not requested from '{}' with useragent {}. Ignoring it. 
(Trace: {})", node_id.short_str(), source_peer.user_agent, + message_tag, ); return Ok(()); } @@ -179,10 +181,11 @@ where let maybe_latency = self.state.record_pong(ping_pong_msg.nonce); debug!( target: LOG_TARGET, - "Received pong from peer '{}' with useragent '{}'. {}", + "Received pong from peer '{}' with useragent '{}'. {} (Trace: {})", node_id.short_str(), source_peer.user_agent, maybe_latency.map(|ms| format!("Latency: {}ms", ms)).unwrap_or_default(), + message_tag, ); let pong_event = PingPongEvent::new(node_id, maybe_latency, ping_pong_msg.metadata.into()); @@ -306,10 +309,7 @@ mod test { proto::liveness::MetadataKey, services::liveness::{handle::LivenessHandle, state::Metadata}, }; - use futures::{ - channel::{mpsc, oneshot}, - stream, - }; + use futures::stream; use rand::rngs::OsRng; use std::time::Duration; use tari_comms::{ @@ -325,9 +325,12 @@ mod test { use tari_crypto::keys::PublicKey; use tari_service_framework::reply_channel; use tari_shutdown::Shutdown; - use tokio::{sync::broadcast, task}; + use tokio::{ + sync::{broadcast, mpsc, oneshot}, + task, + }; - #[tokio_macros::test_basic] + #[tokio::test] async fn get_ping_pong_count() { let mut state = LivenessState::new(); state.inc_pings_received(); @@ -369,7 +372,7 @@ mod test { assert_eq!(res, 2); } - #[tokio_macros::test] + #[tokio::test] async fn send_ping() { let (connectivity, mock) = create_connectivity_mock(); mock.spawn(); @@ -401,8 +404,9 @@ mod test { let node_id = NodeId::from_key(&pk); // Receive outbound request task::spawn(async move { - match outbound_rx.select_next_some().await { - DhtOutboundRequest::SendMessage(_, _, reply_tx) => { + #[allow(clippy::single_match)] + match outbound_rx.recv().await { + Some(DhtOutboundRequest::SendMessage(_, _, reply_tx)) => { let (_, rx) = oneshot::channel(); reply_tx .send(SendMessageResponse::Queued( @@ -410,6 +414,7 @@ mod test { )) .unwrap(); }, + None => {}, } }); @@ -445,7 +450,7 @@ mod test { } } - #[tokio_macros::test] + #[tokio::test] 
async fn handle_message_ping() { let state = LivenessState::new(); @@ -478,10 +483,10 @@ mod test { task::spawn(service.run()); // Test oms got request to send message - unwrap_oms_send_msg!(outbound_rx.select_next_some().await); + unwrap_oms_send_msg!(outbound_rx.recv().await.unwrap()); } - #[tokio_macros::test_basic] + #[tokio::test] async fn handle_message_pong() { let mut state = LivenessState::new(); @@ -516,9 +521,9 @@ mod test { task::spawn(service.run()); // Listen for the pong event - let subscriber = publisher.subscribe(); + let mut subscriber = publisher.subscribe(); - let event = time::timeout(Duration::from_secs(10), subscriber.fuse().select_next_some()) + let event = time::timeout(Duration::from_secs(10), subscriber.recv()) .await .unwrap() .unwrap(); @@ -530,12 +535,12 @@ mod test { _ => panic!("Unexpected event"), } - shutdown.trigger().unwrap(); + shutdown.trigger(); // No further events (malicious_msg was ignored) - let mut subscriber = publisher.subscribe().fuse(); + let mut subscriber = publisher.subscribe(); drop(publisher); - let msg = subscriber.next().await; - assert!(msg.is_none()); + let msg = subscriber.recv().await; + assert!(msg.is_err()); } } diff --git a/base_layer/p2p/tests/services/liveness.rs b/base_layer/p2p/tests/services/liveness.rs index ab9a66cf59..2505c15543 100644 --- a/base_layer/p2p/tests/services/liveness.rs +++ b/base_layer/p2p/tests/services/liveness.rs @@ -35,16 +35,15 @@ use tari_p2p::{ }; use tari_service_framework::{RegisterHandle, StackBuilder}; use tari_shutdown::Shutdown; -use tari_test_utils::collect_stream; +use tari_test_utils::collect_try_recv; use tempfile::tempdir; -use tokio::runtime; pub async fn setup_liveness_service( node_identity: Arc, peers: Vec>, data_path: &str, ) -> (LivenessHandle, CommsNode, Dht, Shutdown) { - let (publisher, subscription_factory) = pubsub_connector(runtime::Handle::current(), 100, 20); + let (publisher, subscription_factory) = pubsub_connector(100, 20); let subscription_factory 
= Arc::new(subscription_factory); let shutdown = Shutdown::new(); let (comms, dht, _) = @@ -75,7 +74,7 @@ fn make_node_identity() -> Arc { )) } -#[tokio_macros::test_basic] +#[tokio::test] async fn end_to_end() { let node_1_identity = make_node_identity(); let node_2_identity = make_node_identity(); @@ -114,34 +113,34 @@ async fn end_to_end() { liveness1.send_ping(node_2_identity.node_id().clone()).await.unwrap(); } - let events = collect_stream!(liveness1_event_stream, take = 18, timeout = Duration::from_secs(20),); + let events = collect_try_recv!(liveness1_event_stream, take = 18, timeout = Duration::from_secs(20)); let ping_count = events .iter() - .filter(|event| matches!(**(**event).as_ref().unwrap(), LivenessEvent::ReceivedPing(_))) + .filter(|event| matches!(&***event, LivenessEvent::ReceivedPing(_))) .count(); assert_eq!(ping_count, 10); let pong_count = events .iter() - .filter(|event| matches!(**(**event).as_ref().unwrap(), LivenessEvent::ReceivedPong(_))) + .filter(|event| matches!(&***event, LivenessEvent::ReceivedPong(_))) .count(); assert_eq!(pong_count, 8); - let events = collect_stream!(liveness2_event_stream, take = 18, timeout = Duration::from_secs(10),); + let events = collect_try_recv!(liveness2_event_stream, take = 18, timeout = Duration::from_secs(10)); let ping_count = events .iter() - .filter(|event| matches!(**(**event).as_ref().unwrap(), LivenessEvent::ReceivedPing(_))) + .filter(|event| matches!(&***event, LivenessEvent::ReceivedPing(_))) .count(); assert_eq!(ping_count, 8); let pong_count = events .iter() - .filter(|event| matches!(**(**event).as_ref().unwrap(), LivenessEvent::ReceivedPong(_))) + .filter(|event| matches!(&***event, LivenessEvent::ReceivedPong(_))) .count(); assert_eq!(pong_count, 10); diff --git a/base_layer/p2p/tests/support/comms_and_services.rs b/base_layer/p2p/tests/support/comms_and_services.rs index 40cc85a710..33bc8fdef7 100644 --- a/base_layer/p2p/tests/support/comms_and_services.rs +++ 
b/base_layer/p2p/tests/support/comms_and_services.rs @@ -20,27 +20,19 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::Sink; -use std::{error::Error, sync::Arc, time::Duration}; +use std::{sync::Arc, time::Duration}; use tari_comms::{peer_manager::NodeIdentity, protocol::messaging::MessagingEventSender, CommsNode}; use tari_comms_dht::Dht; -use tari_p2p::{ - comms_connector::{InboundDomainConnector, PeerMessage}, - initialization::initialize_local_test_comms, -}; +use tari_p2p::{comms_connector::InboundDomainConnector, initialization::initialize_local_test_comms}; use tari_shutdown::ShutdownSignal; -pub async fn setup_comms_services( +pub async fn setup_comms_services( node_identity: Arc, peers: Vec>, - publisher: InboundDomainConnector, + publisher: InboundDomainConnector, data_path: &str, shutdown_signal: ShutdownSignal, -) -> (CommsNode, Dht, MessagingEventSender) -where - TSink: Sink> + Clone + Unpin + Send + Sync + 'static, - TSink::Error: Error + Send + Sync, -{ +) -> (CommsNode, Dht, MessagingEventSender) { let peers = peers.into_iter().map(|ni| ni.to_peer()).collect(); let (comms, dht, messaging_events) = initialize_local_test_comms( node_identity, diff --git a/base_layer/service_framework/Cargo.toml b/base_layer/service_framework/Cargo.toml index 49e829c949..ead5e311f2 100644 --- a/base_layer/service_framework/Cargo.toml +++ b/base_layer/service_framework/Cargo.toml @@ -14,15 +14,15 @@ tari_shutdown = { version = "^0.9", path="../../infrastructure/shutdown" } anyhow = "1.0.32" async-trait = "0.1.50" -futures = { version = "^0.3.1", features=["async-await"]} +futures = { version = "^0.3.16", features=["async-await"]} log = "0.4.8" -thiserror = "1.0.20" -tokio = { version = "0.2.10" } +thiserror = "1.0.26" +tokio = {version="1.10", features=["rt"]} tower-service = { version="0.3.0" } 
[dev-dependencies] tari_test_utils = { version = "^0.9", path="../../infrastructure/test_utils" } +tokio = {version="1.10", features=["rt-multi-thread", "macros", "time"]} futures-test = { version = "0.3.3" } -tokio-macros = "0.2.5" tower = "0.3.1" diff --git a/base_layer/service_framework/examples/services/service_a.rs b/base_layer/service_framework/examples/services/service_a.rs index c898696415..dfbd9ace93 100644 --- a/base_layer/service_framework/examples/services/service_a.rs +++ b/base_layer/service_framework/examples/services/service_a.rs @@ -69,7 +69,7 @@ impl ServiceA { pin_mut!(request_stream); loop { - futures::select! { + tokio::select! { //Incoming request request_context = request_stream.select_next_some() => { println!("Handling Service A API Request"); @@ -82,7 +82,7 @@ impl ServiceA { response.push_str(request.clone().as_str()); let _ = reply_tx.send(response); }, - _ = shutdown_signal => { + _ = shutdown_signal.wait() => { println!("Service A shutting down because the shutdown signal was received"); break; } diff --git a/base_layer/service_framework/examples/services/service_b.rs b/base_layer/service_framework/examples/services/service_b.rs index decf53ab14..8e74408077 100644 --- a/base_layer/service_framework/examples/services/service_b.rs +++ b/base_layer/service_framework/examples/services/service_b.rs @@ -31,7 +31,7 @@ use tari_service_framework::{ ServiceInitializerContext, }; use tari_shutdown::ShutdownSignal; -use tokio::time::delay_for; +use tokio::time::sleep; use tower::Service; pub struct ServiceB { @@ -67,7 +67,7 @@ impl ServiceB { pin_mut!(request_stream); loop { - futures::select! { + tokio::select! 
{ //Incoming request request_context = request_stream.select_next_some() => { println!("Handling Service B API Request"); @@ -76,7 +76,7 @@ impl ServiceB { response.push_str(request.clone().as_str()); let _ = reply_tx.send(response); }, - _ = shutdown_signal => { + _ = shutdown_signal.wait() => { println!("Service B shutting down because the shutdown signal was received"); break; } @@ -134,7 +134,7 @@ impl ServiceInitializer for ServiceBInitializer { println!("Service B has shutdown and initializer spawned task is now ending"); }); - delay_for(Duration::from_secs(10)).await; + sleep(Duration::from_secs(10)).await; Ok(()) } } diff --git a/base_layer/service_framework/examples/stack_builder_example.rs b/base_layer/service_framework/examples/stack_builder_example.rs index 35fd785ed7..6f150796e7 100644 --- a/base_layer/service_framework/examples/stack_builder_example.rs +++ b/base_layer/service_framework/examples/stack_builder_example.rs @@ -25,9 +25,9 @@ use crate::services::{ServiceAHandle, ServiceAInitializer, ServiceBHandle, Servi use std::time::Duration; use tari_service_framework::StackBuilder; use tari_shutdown::Shutdown; -use tokio::time::delay_for; +use tokio::time::sleep; -#[tokio_macros::main] +#[tokio::main] async fn main() { let mut shutdown = Shutdown::new(); let fut = StackBuilder::new(shutdown.to_signal()) @@ -40,7 +40,7 @@ async fn main() { let mut service_a_handle = handles.expect_handle::(); let mut service_b_handle = handles.expect_handle::(); - delay_for(Duration::from_secs(1)).await; + sleep(Duration::from_secs(1)).await; println!("----------------------------------------------------"); let response_b = service_b_handle.send_msg("Hello B".to_string()).await; println!("Response from Service B: {}", response_b); @@ -51,5 +51,5 @@ async fn main() { let _ = shutdown.trigger(); - delay_for(Duration::from_secs(5)).await; + sleep(Duration::from_secs(5)).await; } diff --git a/base_layer/service_framework/src/reply_channel.rs 
b/base_layer/service_framework/src/reply_channel.rs index 54cef8d90f..4921ed9601 100644 --- a/base_layer/service_framework/src/reply_channel.rs +++ b/base_layer/service_framework/src/reply_channel.rs @@ -20,26 +20,15 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::{ - channel::{ - mpsc::{self, SendError}, - oneshot, - }, - ready, - stream::FusedStream, - task::Context, - Future, - FutureExt, - Stream, - StreamExt, -}; +use futures::{ready, stream::FusedStream, task::Context, Future, FutureExt, Stream}; use std::{pin::Pin, task::Poll}; use thiserror::Error; +use tokio::sync::{mpsc, oneshot}; use tower_service::Service; /// Create a new Requester/Responder pair which wraps and calls the given service pub fn unbounded() -> (SenderService, Receiver) { - let (tx, rx) = mpsc::unbounded(); + let (tx, rx) = mpsc::unbounded_channel(); (SenderService::new(tx), Receiver::new(rx)) } @@ -81,20 +70,15 @@ impl Service for SenderService { type Future = TransportResponseFuture; type Response = TRes; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.tx.poll_ready(cx).map_err(|err| { - if err.is_disconnected() { - return TransportChannelError::ChannelClosed; - } - - unreachable!("unbounded channels can never be full"); - }) + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + // An unbounded sender is always ready (i.e. 
will never wait to send) + Poll::Ready(Ok(())) } fn call(&mut self, request: TReq) -> Self::Future { let (tx, rx) = oneshot::channel(); - if self.tx.unbounded_send((request, tx)).is_ok() { + if self.tx.send((request, tx)).is_ok() { TransportResponseFuture::new(rx) } else { // We're not able to send (rx closed) so return a future which resolves to @@ -106,8 +90,6 @@ impl Service for SenderService { #[derive(Debug, Error, Eq, PartialEq, Clone)] pub enum TransportChannelError { - #[error("Error occurred when sending: `{0}`")] - SendError(#[from] SendError), #[error("Request was canceled")] Canceled, #[error("The response channel has closed")] @@ -188,23 +170,21 @@ impl RequestContext { } /// Receiver side of the reply channel. -/// This is functionally equivalent to `rx.map(|(req, reply_tx)| RequestContext::new(req, reply_tx))` -/// but is ergonomically better to use with the `futures::select` macro (implements FusedStream) -/// and has a short type signature. pub struct Receiver { rx: Rx, + is_closed: bool, } impl FusedStream for Receiver { fn is_terminated(&self) -> bool { - self.rx.is_terminated() + self.is_closed } } impl Receiver { // Create a new Responder pub fn new(rx: Rx) -> Self { - Self { rx } + Self { rx, is_closed: false } } pub fn close(&mut self) { @@ -216,10 +196,17 @@ impl Stream for Receiver { type Item = RequestContext; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match ready!(self.rx.poll_next_unpin(cx)) { + if self.is_terminated() { + return Poll::Ready(None); + } + + match ready!(self.rx.poll_recv(cx)) { Some((req, tx)) => Poll::Ready(Some(RequestContext::new(req, tx))), // Stream has closed, so we're done - None => Poll::Ready(None), + None => { + self.is_closed = true; + Poll::Ready(None) + }, } } } @@ -227,7 +214,7 @@ impl Stream for Receiver { #[cfg(test)] mod test { use super::*; - use futures::{executor::block_on, future}; + use futures::{executor::block_on, future, StreamExt}; use std::fmt::Debug; use 
tari_test_utils::unpack_enum; use tower::ServiceExt; @@ -247,7 +234,7 @@ mod test { async fn reply(mut rx: Rx, msg: TResp) where TResp: Debug { - match rx.next().await { + match rx.recv().await { Some((_, tx)) => { tx.send(msg).unwrap(); }, @@ -257,7 +244,7 @@ mod test { #[test] fn requestor_call() { - let (tx, rx) = mpsc::unbounded(); + let (tx, rx) = mpsc::unbounded_channel(); let requestor = SenderService::<_, _>::new(tx); let fut = future::join(requestor.oneshot("PING"), reply(rx, "PONG")); diff --git a/base_layer/service_framework/src/stack.rs b/base_layer/service_framework/src/stack.rs index 2489006d56..ad38a94be3 100644 --- a/base_layer/service_framework/src/stack.rs +++ b/base_layer/service_framework/src/stack.rs @@ -103,7 +103,7 @@ mod test { use tari_shutdown::Shutdown; use tower::service_fn; - #[tokio_macros::test] + #[tokio::test] async fn service_defn_simple() { // This is less of a test and more of a demo of using the short-hand implementation of ServiceInitializer let simple_initializer = |_: ServiceInitializerContext| Ok(()); @@ -155,7 +155,7 @@ mod test { } } - #[tokio_macros::test] + #[tokio::test] async fn service_stack_new() { let shared_state = Arc::new(AtomicUsize::new(0)); diff --git a/base_layer/tari_stratum_ffi/Cargo.toml b/base_layer/tari_stratum_ffi/Cargo.toml index 6598df9f06..9ed773d89f 100644 --- a/base_layer/tari_stratum_ffi/Cargo.toml +++ b/base_layer/tari_stratum_ffi/Cargo.toml @@ -14,7 +14,7 @@ tari_app_grpc = { path = "../../applications/tari_app_grpc" } tari_core = { path = "../../base_layer/core", default-features = false, features = ["transactions"]} tari_utilities = "^0.3" libc = "0.2.65" -thiserror = "1.0.20" +thiserror = "1.0.26" hex = "0.4.2" serde = { version="1.0.106", features = ["derive"] } serde_json = "1.0.57" diff --git a/base_layer/wallet/Cargo.toml b/base_layer/wallet/Cargo.toml index c2cf12ca22..639bdbd9fd 100644 --- a/base_layer/wallet/Cargo.toml +++ b/base_layer/wallet/Cargo.toml @@ -7,55 +7,53 @@ version = 
"0.9.5" edition = "2018" [dependencies] -tari_common_types = { version = "^0.9", path = "../../base_layer/common_types"} -tari_comms = { version = "^0.9", path = "../../comms"} +tari_common_types = { version = "^0.9", path = "../../base_layer/common_types" } +tari_comms = { version = "^0.9", path = "../../comms" } tari_comms_dht = { version = "^0.9", path = "../../comms/dht" } tari_crypto = "0.11.1" tari_key_manager = { version = "^0.9", path = "../key_manager" } tari_p2p = { version = "^0.9", path = "../p2p" } -tari_service_framework = { version = "^0.9", path = "../service_framework"} +tari_service_framework = { version = "^0.9", path = "../service_framework" } tari_shutdown = { version = "^0.9", path = "../../infrastructure/shutdown" } -tari_storage = { version = "^0.9", path = "../../infrastructure/storage"} +tari_storage = { version = "^0.9", path = "../../infrastructure/storage" } aes-gcm = "^0.8" blake2 = "0.9.0" -chrono = { version = "0.4.6", features = ["serde"]} +chrono = { version = "0.4.6", features = ["serde"] } crossbeam-channel = "0.3.8" digest = "0.9.0" -diesel = { version="1.4.7", features = ["sqlite", "serde_json", "chrono"]} +diesel = { version = "1.4.7", features = ["sqlite", "serde_json", "chrono"] } diesel_migrations = "1.4.0" -libsqlite3-sys = { version = ">=0.8.0, <0.13.0", features = ["bundled"], optional=true } +libsqlite3-sys = { version = ">=0.8.0, <0.13.0", features = ["bundled"], optional = true } fs2 = "0.3.0" -futures = { version = "^0.3.1", features =["compat", "std"]} -lazy_static = "1.4.0" +futures = { version = "^0.3.1", features = ["compat", "std"] } log = "0.4.6" -log4rs = {version = "1.0.0", features = ["console_appender", "file_appender", "yaml_format"]} +log4rs = { version = "1.0.0", features = ["console_appender", "file_appender", "yaml_format"] } lmdb-zero = "0.4.4" rand = "0.8" -serde = {version = "1.0.89", features = ["derive"] } +serde = { version = "1.0.89", features = ["derive"] } serde_json = "1.0.39" -tokio = { 
version = "0.2.10", features = ["blocking", "sync"]} +tokio = { version = "1.10", features = ["sync", "macros"] } tower = "0.3.0-alpha.2" tempfile = "3.1.0" -time = {version = "0.1.39"} -thiserror = "1.0.20" +time = { version = "0.1.39" } +thiserror = "1.0.26" bincode = "1.3.1" [dependencies.tari_core] path = "../../base_layer/core" version = "^0.9" default-features = false -features = ["transactions", "mempool_proto", "base_node_proto",] +features = ["transactions", "mempool_proto", "base_node_proto", ] [dev-dependencies] -tari_p2p = { version = "^0.9", path = "../p2p", features=["test-mocks"]} -tari_comms_dht = { version = "^0.9", path = "../../comms/dht", features=["test-mocks"]} +tari_p2p = { version = "^0.9", path = "../p2p", features = ["test-mocks"] } +tari_comms_dht = { version = "^0.9", path = "../../comms/dht", features = ["test-mocks"] } tari_test_utils = { version = "^0.9", path = "../../infrastructure/test_utils" } -lazy_static = "1.3.0" env_logger = "0.7.1" -prost = "0.6.1" -tokio-macros = "0.2.4" +prost = "0.8.0" [features] c_integration = [] avx2 = ["tari_crypto/avx2", "tari_core/avx2"] +bundled_sqlite = ["libsqlite3-sys"] diff --git a/base_layer/wallet/src/base_node_service/handle.rs b/base_layer/wallet/src/base_node_service/handle.rs index 4957823c72..f495479778 100644 --- a/base_layer/wallet/src/base_node_service/handle.rs +++ b/base_layer/wallet/src/base_node_service/handle.rs @@ -21,7 +21,6 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
use super::{error::BaseNodeServiceError, service::BaseNodeState}; -use futures::{stream::Fuse, StreamExt}; use std::{sync::Arc, time::Duration}; use tari_common_types::chain_metadata::ChainMetadata; use tari_comms::peer_manager::Peer; @@ -72,8 +71,8 @@ impl BaseNodeServiceHandle { } } - pub fn get_event_stream_fused(&self) -> Fuse { - self.event_stream_sender.subscribe().fuse() + pub fn get_event_stream(&self) -> BaseNodeEventReceiver { + self.event_stream_sender.subscribe() } pub async fn get_chain_metadata(&mut self) -> Result, BaseNodeServiceError> { diff --git a/base_layer/wallet/src/base_node_service/mock_base_node_service.rs b/base_layer/wallet/src/base_node_service/mock_base_node_service.rs index 1bc57ed9d2..9aa981150d 100644 --- a/base_layer/wallet/src/base_node_service/mock_base_node_service.rs +++ b/base_layer/wallet/src/base_node_service/mock_base_node_service.rs @@ -20,13 +20,10 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-use crate::{ - base_node_service::{ - error::BaseNodeServiceError, - handle::{BaseNodeServiceRequest, BaseNodeServiceResponse}, - service::BaseNodeState, - }, - connectivity_service::OnlineStatus, +use crate::base_node_service::{ + error::BaseNodeServiceError, + handle::{BaseNodeServiceRequest, BaseNodeServiceResponse}, + service::BaseNodeState, }; use futures::StreamExt; use tari_common_types::chain_metadata::ChainMetadata; @@ -81,30 +78,28 @@ impl MockBaseNodeService { /// Set the mock server state, either online and synced to a specific height, or offline with None pub fn set_base_node_state(&mut self, height: Option) { - let (chain_metadata, is_synced, online) = match height { + let (chain_metadata, is_synced) = match height { Some(height) => { let metadata = ChainMetadata::new(height, Vec::new(), 0, 0, 0); - (Some(metadata), Some(true), OnlineStatus::Online) + (Some(metadata), Some(true)) }, - None => (None, None, OnlineStatus::Offline), + None => (None, None), }; self.state = BaseNodeState { chain_metadata, is_synced, updated: None, latency: None, - online, } } pub fn set_default_base_node_state(&mut self) { - let metadata = ChainMetadata::new(std::u64::MAX, Vec::new(), 0, 0, 0); + let metadata = ChainMetadata::new(u64::MAX, Vec::new(), 0, 0, 0); self.state = BaseNodeState { chain_metadata: Some(metadata), is_synced: Some(true), updated: None, latency: None, - online: OnlineStatus::Online, } } diff --git a/base_layer/wallet/src/base_node_service/monitor.rs b/base_layer/wallet/src/base_node_service/monitor.rs index 5a2c3a7e76..8e0298ca27 100644 --- a/base_layer/wallet/src/base_node_service/monitor.rs +++ b/base_layer/wallet/src/base_node_service/monitor.rs @@ -25,7 +25,7 @@ use crate::{ handle::{BaseNodeEvent, BaseNodeEventSender}, service::BaseNodeState, }, - connectivity_service::{OnlineStatus, WalletConnectivityHandle}, + connectivity_service::WalletConnectivityHandle, error::WalletStorageError, storage::database::{WalletBackend, WalletDatabase}, }; @@ 
-33,7 +33,7 @@ use chrono::Utc; use log::*; use std::{convert::TryFrom, sync::Arc, time::Duration}; use tari_common_types::chain_metadata::ChainMetadata; -use tari_comms::{peer_manager::NodeId, protocol::rpc::RpcError}; +use tari_comms::protocol::rpc::RpcError; use tokio::{sync::RwLock, time}; const LOG_TARGET: &str = "wallet::base_node_service::chain_metadata_monitor"; @@ -78,9 +78,6 @@ impl BaseNodeMonitor { }, Err(e @ BaseNodeMonitorError::RpcFailed(_)) => { warn!(target: LOG_TARGET, "Connectivity failure to base node: {}", e); - debug!(target: LOG_TARGET, "Setting as OFFLINE and retrying...",); - - self.set_offline().await; continue; }, Err(e @ BaseNodeMonitorError::InvalidBaseNodeResponse(_)) | @@ -96,34 +93,19 @@ impl BaseNodeMonitor { ); } - async fn update_connectivity_status(&self) -> NodeId { - let mut watcher = self.wallet_connectivity.get_connectivity_status_watch(); - loop { - use OnlineStatus::*; - match watcher.recv().await.unwrap_or(Offline) { - Online => match self.wallet_connectivity.get_current_base_node_id() { - Some(node_id) => return node_id, - _ => continue, - }, - Connecting => { - self.set_connecting().await; - }, - Offline => { - self.set_offline().await; - }, - } - } - } - async fn monitor_node(&mut self) -> Result<(), BaseNodeMonitorError> { loop { - let peer_node_id = self.update_connectivity_status().await; let mut client = self .wallet_connectivity .obtain_base_node_wallet_rpc_client() .await .ok_or(BaseNodeMonitorError::NodeShuttingDown)?; + let base_node_id = match self.wallet_connectivity.get_current_base_node_id() { + Some(n) => n, + None => continue, + }; + let tip_info = client.get_tip_info().await?; let chain_metadata = tip_info @@ -138,7 +120,7 @@ impl BaseNodeMonitor { debug!( target: LOG_TARGET, "Base node {} Tip: {} ({}) Latency: {} ms", - peer_node_id, + base_node_id, chain_metadata.height_of_longest_chain(), if is_synced { "Synced" } else { "Syncing..." 
}, latency.as_millis() @@ -151,11 +133,10 @@ impl BaseNodeMonitor { is_synced: Some(is_synced), updated: Some(Utc::now().naive_utc()), latency: Some(latency), - online: OnlineStatus::Online, }) .await; - time::delay_for(self.interval).await + time::sleep(self.interval).await } // loop only exits on shutdown/error @@ -163,28 +144,6 @@ impl BaseNodeMonitor { Ok(()) } - async fn set_connecting(&self) { - self.map_state(|_| BaseNodeState { - chain_metadata: None, - is_synced: None, - updated: Some(Utc::now().naive_utc()), - latency: None, - online: OnlineStatus::Connecting, - }) - .await; - } - - async fn set_offline(&self) { - self.map_state(|_| BaseNodeState { - chain_metadata: None, - is_synced: None, - updated: Some(Utc::now().naive_utc()), - latency: None, - online: OnlineStatus::Offline, - }) - .await; - } - async fn map_state(&self, transform: F) where F: FnOnce(&BaseNodeState) -> BaseNodeState { let new_state = { diff --git a/base_layer/wallet/src/base_node_service/service.rs b/base_layer/wallet/src/base_node_service/service.rs index 3da987c8b1..eb2b91ebda 100644 --- a/base_layer/wallet/src/base_node_service/service.rs +++ b/base_layer/wallet/src/base_node_service/service.rs @@ -27,7 +27,7 @@ use super::{ }; use crate::{ base_node_service::monitor::BaseNodeMonitor, - connectivity_service::{OnlineStatus, WalletConnectivityHandle}, + connectivity_service::WalletConnectivityHandle, storage::database::{WalletBackend, WalletDatabase}, }; use chrono::NaiveDateTime; @@ -49,8 +49,6 @@ pub struct BaseNodeState { pub is_synced: Option, pub updated: Option, pub latency: Option, - pub online: OnlineStatus, - // pub base_node_peer: Option, } impl Default for BaseNodeState { @@ -60,7 +58,6 @@ impl Default for BaseNodeState { is_synced: None, updated: None, latency: None, - online: OnlineStatus::Connecting, } } } diff --git a/base_layer/wallet/src/config.rs b/base_layer/wallet/src/config.rs index cd17024068..3844fc13e7 100644 --- a/base_layer/wallet/src/config.rs +++ 
b/base_layer/wallet/src/config.rs @@ -20,14 +20,16 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use std::time::Duration; + +use tari_core::{consensus::NetworkConsensus, transactions::CryptoFactories}; +use tari_p2p::initialization::CommsConfig; + use crate::{ base_node_service::config::BaseNodeServiceConfig, output_manager_service::config::OutputManagerServiceConfig, transaction_service::config::TransactionServiceConfig, }; -use std::time::Duration; -use tari_core::{consensus::NetworkConsensus, transactions::types::CryptoFactories}; -use tari_p2p::initialization::CommsConfig; pub const KEY_MANAGER_COMMS_SECRET_KEY_BRANCH_KEY: &str = "comms"; diff --git a/base_layer/wallet/src/connectivity_service/handle.rs b/base_layer/wallet/src/connectivity_service/handle.rs index ac218edc5e..5a35696e14 100644 --- a/base_layer/wallet/src/connectivity_service/handle.rs +++ b/base_layer/wallet/src/connectivity_service/handle.rs @@ -22,16 +22,12 @@ use super::service::OnlineStatus; use crate::connectivity_service::{error::WalletConnectivityError, watch::Watch}; -use futures::{ - channel::{mpsc, oneshot}, - SinkExt, -}; use tari_comms::{ peer_manager::{NodeId, Peer}, protocol::rpc::RpcClientLease, }; use tari_core::base_node::{rpc::BaseNodeWalletRpcClient, sync::rpc::BaseNodeSyncRpcClient}; -use tokio::sync::watch; +use tokio::sync::{mpsc, oneshot, watch}; pub enum WalletConnectivityRequest { ObtainBaseNodeWalletRpcClient(oneshot::Sender>), @@ -102,8 +98,8 @@ impl WalletConnectivityHandle { reply_rx.await.ok() } - pub async fn get_connectivity_status(&mut self) -> OnlineStatus { - self.online_status_rx.recv().await.unwrap_or(OnlineStatus::Offline) + pub fn get_connectivity_status(&mut self) -> OnlineStatus { + *self.online_status_rx.borrow() } pub fn get_connectivity_status_watch(&self) -> watch::Receiver { diff --git 
a/base_layer/wallet/src/connectivity_service/initializer.rs b/base_layer/wallet/src/connectivity_service/initializer.rs index d0c2b94126..1610a834e3 100644 --- a/base_layer/wallet/src/connectivity_service/initializer.rs +++ b/base_layer/wallet/src/connectivity_service/initializer.rs @@ -30,8 +30,8 @@ use super::{handle::WalletConnectivityHandle, service::WalletConnectivityService, watch::Watch}; use crate::{base_node_service::config::BaseNodeServiceConfig, connectivity_service::service::OnlineStatus}; -use futures::channel::mpsc; use tari_service_framework::{async_trait, ServiceInitializationError, ServiceInitializer, ServiceInitializerContext}; +use tokio::sync::mpsc; pub struct WalletConnectivityInitializer { config: BaseNodeServiceConfig, @@ -59,8 +59,13 @@ impl ServiceInitializer for WalletConnectivityInitializer { context.spawn_until_shutdown(move |handles| { let connectivity = handles.expect_handle(); - let service = - WalletConnectivityService::new(config, receiver, base_node_watch, online_status_watch, connectivity); + let service = WalletConnectivityService::new( + config, + receiver, + base_node_watch.get_receiver(), + online_status_watch, + connectivity, + ); service.start() }); diff --git a/base_layer/wallet/src/connectivity_service/service.rs b/base_layer/wallet/src/connectivity_service/service.rs index c0cf474b96..950b9a9a72 100644 --- a/base_layer/wallet/src/connectivity_service/service.rs +++ b/base_layer/wallet/src/connectivity_service/service.rs @@ -24,15 +24,8 @@ use crate::{ base_node_service::config::BaseNodeServiceConfig, connectivity_service::{error::WalletConnectivityError, handle::WalletConnectivityRequest, watch::Watch}, }; -use core::mem; -use futures::{ - channel::{mpsc, oneshot}, - future, - future::Either, - stream::Fuse, - StreamExt, -}; use log::*; +use std::{mem, time::Duration}; use tari_comms::{ connectivity::ConnectivityRequester, peer_manager::{NodeId, Peer}, @@ -40,7 +33,11 @@ use tari_comms::{ PeerConnection, }; use 
tari_core::base_node::{rpc::BaseNodeWalletRpcClient, sync::rpc::BaseNodeSyncRpcClient}; -use tokio::time; +use tokio::{ + sync::{mpsc, oneshot, watch}, + time, + time::MissedTickBehavior, +}; const LOG_TARGET: &str = "wallet::connectivity"; @@ -54,9 +51,9 @@ pub enum OnlineStatus { pub struct WalletConnectivityService { config: BaseNodeServiceConfig, - request_stream: Fuse>, + request_stream: mpsc::Receiver, connectivity: ConnectivityRequester, - base_node_watch: Watch>, + base_node_watch: watch::Receiver>, pools: Option, online_status_watch: Watch, pending_requests: Vec, @@ -71,13 +68,13 @@ impl WalletConnectivityService { pub(super) fn new( config: BaseNodeServiceConfig, request_stream: mpsc::Receiver, - base_node_watch: Watch>, + base_node_watch: watch::Receiver>, online_status_watch: Watch, connectivity: ConnectivityRequester, ) -> Self { Self { config, - request_stream: request_stream.fuse(), + request_stream, connectivity, base_node_watch, pools: None, @@ -88,22 +85,41 @@ impl WalletConnectivityService { pub async fn start(mut self) { debug!(target: LOG_TARGET, "Wallet connectivity service has started."); - let mut base_node_watch_rx = self.base_node_watch.get_receiver().fuse(); + let mut check_connection = + time::interval_at(time::Instant::now() + Duration::from_secs(5), Duration::from_secs(5)); + check_connection.set_missed_tick_behavior(MissedTickBehavior::Delay); loop { - futures::select! { - req = self.request_stream.select_next_some() => { - self.handle_request(req).await; - }, - maybe_peer = base_node_watch_rx.select_next_some() => { - if maybe_peer.is_some() { + tokio::select! { + // BIASED: select branches are in order of priority + biased; + + Ok(_) = self.base_node_watch.changed() => { + if self.base_node_watch.borrow().is_some() { // This will block the rest until the connection is established. This is what we want. 
self.setup_base_node_connection().await; } + }, + + Some(req) = self.request_stream.recv() => { + self.handle_request(req).await; + }, + + _ = check_connection.tick() => { + self.check_connection().await; } } } } + async fn check_connection(&mut self) { + if let Some(pool) = self.pools.as_ref() { + if !pool.base_node_wallet_rpc_client.is_connected().await { + debug!(target: LOG_TARGET, "Peer connection lost. Attempting to reconnect..."); + self.setup_base_node_connection().await; + } + } + } + async fn handle_request(&mut self, request: WalletConnectivityRequest) { use WalletConnectivityRequest::*; match request { @@ -138,7 +154,6 @@ impl WalletConnectivityService { target: LOG_TARGET, "Base node connection failed: {}. Reconnecting...", e ); - self.trigger_reconnect(); self.pending_requests.push(reply.into()); }, }, @@ -169,7 +184,6 @@ impl WalletConnectivityService { target: LOG_TARGET, "Base node connection failed: {}. Reconnecting...", e ); - self.trigger_reconnect(); self.pending_requests.push(reply.into()); }, }, @@ -186,21 +200,6 @@ impl WalletConnectivityService { } } - fn trigger_reconnect(&mut self) { - let peer = self - .base_node_watch - .borrow() - .clone() - .expect("trigger_reconnect called before base node is set"); - // Trigger the watch so that a peer connection is reinitiated - self.set_base_node_peer(peer); - } - - fn set_base_node_peer(&mut self, peer: Peer) { - self.pools = None; - self.base_node_watch.broadcast(Some(peer)); - } - fn current_base_node(&self) -> Option { self.base_node_watch.borrow().as_ref().map(|p| p.node_id.clone()) } @@ -236,8 +235,8 @@ impl WalletConnectivityService { } else { self.set_online_status(OnlineStatus::Offline); } - error!(target: LOG_TARGET, "{}", e); - time::delay_for(self.config.base_node_monitor_refresh_interval).await; + warn!(target: LOG_TARGET, "{}", e); + time::sleep(self.config.base_node_monitor_refresh_interval).await; continue; }, } @@ -275,13 +274,15 @@ impl WalletConnectivityService { } async fn 
try_dial_peer(&mut self, peer: NodeId) -> Result, WalletConnectivityError> { - let recv_fut = self.base_node_watch.recv(); - futures::pin_mut!(recv_fut); - let dial_fut = self.connectivity.dial_peer(peer); - futures::pin_mut!(dial_fut); - match future::select(recv_fut, dial_fut).await { - Either::Left(_) => Ok(None), - Either::Right((conn, _)) => Ok(Some(conn?)), + tokio::select! { + biased; + + _ = self.base_node_watch.changed() => { + Ok(None) + } + result = self.connectivity.dial_peer(peer) => { + Ok(Some(result?)) + } } } @@ -307,8 +308,8 @@ impl ReplyOneshot { pub fn is_canceled(&self) -> bool { use ReplyOneshot::*; match self { - WalletRpc(tx) => tx.is_canceled(), - SyncRpc(tx) => tx.is_canceled(), + WalletRpc(tx) => tx.is_closed(), + SyncRpc(tx) => tx.is_closed(), } } } diff --git a/base_layer/wallet/src/connectivity_service/test.rs b/base_layer/wallet/src/connectivity_service/test.rs index 7c24ef5b46..9a8f5a2da9 100644 --- a/base_layer/wallet/src/connectivity_service/test.rs +++ b/base_layer/wallet/src/connectivity_service/test.rs @@ -23,7 +23,7 @@ use super::service::WalletConnectivityService; use crate::connectivity_service::{watch::Watch, OnlineStatus, WalletConnectivityHandle}; use core::convert; -use futures::{channel::mpsc, future}; +use futures::future; use std::{iter, sync::Arc}; use tari_comms::{ peer_manager::PeerFeatures, @@ -39,7 +39,10 @@ use tari_comms::{ }; use tari_shutdown::Shutdown; use tari_test_utils::runtime::spawn_until_shutdown; -use tokio::{sync::Barrier, task}; +use tokio::{ + sync::{mpsc, Barrier}, + task, +}; async fn setup() -> ( WalletConnectivityHandle, @@ -57,7 +60,7 @@ async fn setup() -> ( let service = WalletConnectivityService::new( Default::default(), rx, - base_node_watch, + base_node_watch.get_receiver(), online_status_watch, connectivity, ); @@ -70,7 +73,7 @@ async fn setup() -> ( (handle, mock_server, mock_state, shutdown) } -#[tokio_macros::test] +#[tokio::test] async fn it_dials_peer_when_base_node_is_set() { let 
(mut handle, mock_server, mock_state, _shutdown) = setup().await; let base_node_peer = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -92,7 +95,7 @@ async fn it_dials_peer_when_base_node_is_set() { assert!(rpc_client.is_connected()); } -#[tokio_macros::test] +#[tokio::test] async fn it_resolves_many_pending_rpc_session_requests() { let (mut handle, mock_server, mock_state, _shutdown) = setup().await; let base_node_peer = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -122,7 +125,7 @@ async fn it_resolves_many_pending_rpc_session_requests() { assert!(results.into_iter().map(Result::unwrap).all(convert::identity)); } -#[tokio_macros::test] +#[tokio::test] async fn it_changes_to_a_new_base_node() { let (mut handle, mock_server, mock_state, _shutdown) = setup().await; let base_node_peer1 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -138,7 +141,7 @@ async fn it_changes_to_a_new_base_node() { mock_state.await_call_count(2).await; mock_state.expect_dial_peer(base_node_peer1.node_id()).await; - assert_eq!(mock_state.count_calls_containing("AddManagedPeer").await, 2); + assert!(mock_state.count_calls_containing("AddManagedPeer").await >= 1); let _ = mock_state.take_calls().await; let rpc_client = handle.obtain_base_node_wallet_rpc_client().await.unwrap(); @@ -149,13 +152,12 @@ async fn it_changes_to_a_new_base_node() { mock_state.await_call_count(2).await; mock_state.expect_dial_peer(base_node_peer2.node_id()).await; - assert_eq!(mock_state.count_calls_containing("AddManagedPeer").await, 2); let rpc_client = handle.obtain_base_node_wallet_rpc_client().await.unwrap(); assert!(rpc_client.is_connected()); } -#[tokio_macros::test] +#[tokio::test] async fn it_gracefully_handles_connect_fail_reconnect() { let (mut handle, mock_server, mock_state, _shutdown) = setup().await; let base_node_peer = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -198,7 +200,7 @@ async fn it_gracefully_handles_connect_fail_reconnect() { 
 pending_request.await.unwrap();
 }
 
-#[tokio_macros::test]
+#[tokio::test]
 async fn it_gracefully_handles_multiple_connection_failures() {
     let (mut handle, mock_server, mock_state, _shutdown) = setup().await;
     let base_node_peer = build_node_identity(PeerFeatures::COMMUNICATION_NODE);
diff --git a/base_layer/wallet/src/connectivity_service/watch.rs b/base_layer/wallet/src/connectivity_service/watch.rs
index 1f1e868d47..4669b355f6 100644
--- a/base_layer/wallet/src/connectivity_service/watch.rs
+++ b/base_layer/wallet/src/connectivity_service/watch.rs
@@ -26,24 +26,19 @@ use tokio::sync::watch;
 
 #[derive(Clone)]
 pub struct Watch(Arc>, watch::Receiver);
 
-impl Watch {
+impl Watch {
     pub fn new(initial: T) -> Self {
         let (tx, rx) = watch::channel(initial);
         Self(Arc::new(tx), rx)
     }
 
-    #[allow(dead_code)]
-    pub async fn recv(&mut self) -> Option {
-        self.receiver_mut().recv().await
-    }
-
     pub fn borrow(&self) -> watch::Ref<'_, T> {
         self.receiver().borrow()
     }
 
     pub fn broadcast(&self, item: T) {
-        // SAFETY: broadcast becomes infallible because the receiver is owned in Watch and so has the same lifetime
-        if self.sender().broadcast(item).is_err() {
+        // PANIC: broadcast becomes infallible because the receiver is owned in Watch and so has the same lifetime
+        if self.sender().send(item).is_err() {
             // Result::expect requires E: fmt::Debug and `watch::SendError` is not, this is equivalent
             panic!("watch internal receiver is dropped");
         }
@@ -53,10 +48,6 @@ impl Watch {
         &self.0
     }
 
-    fn receiver_mut(&mut self) -> &mut watch::Receiver {
-        &mut self.1
-    }
-
     pub fn receiver(&self) -> &watch::Receiver {
         &self.1
     }
diff --git a/base_layer/wallet/src/contacts_service/service.rs b/base_layer/wallet/src/contacts_service/service.rs
index f86f6b49cc..cfc8473202 100644
--- a/base_layer/wallet/src/contacts_service/service.rs
+++ b/base_layer/wallet/src/contacts_service/service.rs
@@ -76,8 +76,8 @@ where T: ContactsBackend + 'static
         info!(target: LOG_TARGET, "Contacts Service started");
         loop {
-            futures::select! {
-                request_context = request_stream.select_next_some() => {
+            tokio::select! {
+                Some(request_context) = request_stream.next() => {
                     let (request, reply_tx) = request_context.split();
                     let response = self.handle_request(request).await.map_err(|e| {
                         error!(target: LOG_TARGET, "Error handling request: {:?}", e);
@@ -88,14 +88,10 @@ where T: ContactsBackend + 'static
                         e
                     });
                 },
-                _ = shutdown => {
+                _ = shutdown.wait() => {
                     info!(target: LOG_TARGET, "Contacts service shutting down because it received the shutdown signal");
                     break;
                 }
-                complete => {
-                    info!(target: LOG_TARGET, "Contacts service shutting down");
-                    break;
-                }
             }
         }
         info!(target: LOG_TARGET, "Contacts Service ended");
diff --git a/base_layer/wallet/src/contacts_service/storage/sqlite_db.rs b/base_layer/wallet/src/contacts_service/storage/sqlite_db.rs
index 8fc8234f56..b32f798777 100644
--- a/base_layer/wallet/src/contacts_service/storage/sqlite_db.rs
+++ b/base_layer/wallet/src/contacts_service/storage/sqlite_db.rs
@@ -30,7 +30,7 @@ use crate::{
 };
 use diesel::{prelude::*, result::Error as DieselError, SqliteConnection};
 use std::convert::TryFrom;
-use tari_core::transactions::types::PublicKey;
+use tari_common_types::types::PublicKey;
 use tari_crypto::tari_utilities::ByteArray;
 
 /// A Sqlite backend for the Output Manager Service. The Backend is accessed via a connection pool to the Sqlite file.
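The `watch.rs` hunk above replaces the removed `broadcast` method with tokio 1.x's `watch::Sender::send`, which only fails once every receiver has been dropped; because `Watch` owns one of its own receivers for its entire lifetime, the code treats a send error as an impossible state and panics. A minimal, std-only sketch of that invariant (using `std::sync::mpsc` as a stand-in, since the property being relied on — a send fails only after all receivers are gone — is the same):

```rust
use std::sync::mpsc;

// Returns true when a send fails only after the receiver has been dropped -
// the invariant Watch::broadcast relies on (and turns into a panic).
fn send_fails_after_receiver_drop() -> bool {
    let (tx, rx) = mpsc::channel();
    // While a receiver is alive, sends succeed.
    assert!(tx.send(1u32).is_ok());
    drop(rx);
    // With no receivers left, send returns Err(SendError).
    tx.send(2u32).is_err()
}

fn main() {
    assert!(send_fails_after_receiver_drop());
    println!("send_fails_after_receiver_drop: true");
}
```

Because the wallet's `Watch` never drops its internal receiver, the `Err` branch is unreachable in practice, which is why a `panic!` (rather than error propagation) is acceptable there.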
@@ -192,7 +192,7 @@ mod test { use diesel::{Connection, SqliteConnection}; use rand::rngs::OsRng; use std::convert::TryFrom; - use tari_core::transactions::types::{PrivateKey, PublicKey}; + use tari_common_types::types::{PrivateKey, PublicKey}; use tari_crypto::{ keys::{PublicKey as PublicKeyTrait, SecretKey as SecretKeyTrait}, tari_utilities::ByteArray, diff --git a/base_layer/wallet/src/lib.rs b/base_layer/wallet/src/lib.rs index bc0b4c1a04..22bce8bfdb 100644 --- a/base_layer/wallet/src/lib.rs +++ b/base_layer/wallet/src/lib.rs @@ -25,8 +25,6 @@ pub mod wallet; extern crate diesel; #[macro_use] extern crate diesel_migrations; -#[macro_use] -extern crate lazy_static; mod config; pub mod schema; diff --git a/base_layer/wallet/src/output_manager_service/handle.rs b/base_layer/wallet/src/output_manager_service/handle.rs index 659fab4a42..54b082c900 100644 --- a/base_layer/wallet/src/output_manager_service/handle.rs +++ b/base_layer/wallet/src/output_manager_service/handle.rs @@ -31,14 +31,13 @@ use crate::{ types::ValidationRetryStrategy, }; use aes_gcm::Aes256Gcm; -use futures::{stream::Fuse, StreamExt}; use std::{collections::HashMap, fmt, sync::Arc, time::Duration}; +use tari_common_types::types::PublicKey; use tari_comms::types::CommsPublicKey; use tari_core::transactions::{ tari_amount::MicroTari, transaction::{Transaction, TransactionInput, TransactionOutput, UnblindedOutput}, transaction_protocol::sender::TransactionSenderMessage, - types::PublicKey, ReceiverTransactionProtocol, SenderTransactionProtocol, }; @@ -191,8 +190,8 @@ impl OutputManagerHandle { } } - pub fn get_event_stream_fused(&self) -> Fuse { - self.event_stream_sender.subscribe().fuse() + pub fn get_event_stream(&self) -> OutputManagerEventReceiver { + self.event_stream_sender.subscribe() } pub async fn add_output(&mut self, output: UnblindedOutput) -> Result<(), OutputManagerError> { diff --git a/base_layer/wallet/src/output_manager_service/master_key_manager.rs 
b/base_layer/wallet/src/output_manager_service/master_key_manager.rs index 7b315569dc..4f33a909cf 100644 --- a/base_layer/wallet/src/output_manager_service/master_key_manager.rs +++ b/base_layer/wallet/src/output_manager_service/master_key_manager.rs @@ -30,10 +30,8 @@ use crate::{ }; use futures::lock::Mutex; use log::*; -use tari_core::transactions::{ - transaction_protocol::RewindData, - types::{PrivateKey, PublicKey}, -}; +use tari_common_types::types::{PrivateKey, PublicKey}; +use tari_core::transactions::transaction_protocol::RewindData; use tari_crypto::{keys::PublicKey as PublicKeyTrait, range_proof::REWIND_USER_MESSAGE_LENGTH}; use tari_key_manager::{ key_manager::KeyManager, diff --git a/base_layer/wallet/src/output_manager_service/mod.rs b/base_layer/wallet/src/output_manager_service/mod.rs index ce6fd70699..80f02f2445 100644 --- a/base_layer/wallet/src/output_manager_service/mod.rs +++ b/base_layer/wallet/src/output_manager_service/mod.rs @@ -20,22 +20,15 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-use crate::{ - base_node_service::handle::BaseNodeServiceHandle, - output_manager_service::{ - config::OutputManagerServiceConfig, - handle::OutputManagerHandle, - service::OutputManagerService, - storage::database::{OutputManagerBackend, OutputManagerDatabase}, - }, - transaction_service::handle::TransactionServiceHandle, -}; use futures::future; use log::*; +use tokio::sync::broadcast; + +pub(crate) use master_key_manager::MasterKeyManager; use tari_comms::{connectivity::ConnectivityRequester, types::CommsSecretKey}; use tari_core::{ consensus::{ConsensusConstantsBuilder, NetworkConsensus}, - transactions::types::CryptoFactories, + transactions::CryptoFactories, }; use tari_service_framework::{ async_trait, @@ -44,7 +37,18 @@ use tari_service_framework::{ ServiceInitializer, ServiceInitializerContext, }; -use tokio::sync::broadcast; +pub use tasks::TxoValidationType; + +use crate::{ + base_node_service::handle::BaseNodeServiceHandle, + output_manager_service::{ + config::OutputManagerServiceConfig, + handle::OutputManagerHandle, + service::OutputManagerService, + storage::database::{OutputManagerBackend, OutputManagerDatabase}, + }, + transaction_service::handle::TransactionServiceHandle, +}; pub mod config; pub mod error; @@ -57,9 +61,6 @@ pub mod service; pub mod storage; mod tasks; -pub(crate) use master_key_manager::MasterKeyManager; -pub use tasks::TxoValidationType; - const LOG_TARGET: &str = "wallet::output_manager_service::initializer"; pub type TxId = u64; diff --git a/base_layer/wallet/src/output_manager_service/recovery/standard_outputs_recoverer.rs b/base_layer/wallet/src/output_manager_service/recovery/standard_outputs_recoverer.rs index 64e4d510d2..1da885d7f1 100644 --- a/base_layer/wallet/src/output_manager_service/recovery/standard_outputs_recoverer.rs +++ b/base_layer/wallet/src/output_manager_service/recovery/standard_outputs_recoverer.rs @@ -20,6 +20,17 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 
ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use std::sync::Arc; + +use log::*; +use tari_crypto::{inputs, keys::PublicKey as PublicKeyTrait, tari_utilities::hex::Hex}; + +use tari_common_types::types::PublicKey; +use tari_core::transactions::{ + transaction::{TransactionOutput, UnblindedOutput}, + CryptoFactories, +}; + use crate::output_manager_service::{ error::OutputManagerError, storage::{ @@ -28,13 +39,6 @@ use crate::output_manager_service::{ }, MasterKeyManager, }; -use log::*; -use std::sync::Arc; -use tari_core::transactions::{ - transaction::{TransactionOutput, UnblindedOutput}, - types::{CryptoFactories, PublicKey}, -}; -use tari_crypto::{inputs, keys::PublicKey as PublicKeyTrait, tari_utilities::hex::Hex}; const LOG_TARGET: &str = "wallet::output_manager_service::recovery"; diff --git a/base_layer/wallet/src/output_manager_service/resources.rs b/base_layer/wallet/src/output_manager_service/resources.rs index d6e17b570b..f094b0b79c 100644 --- a/base_layer/wallet/src/output_manager_service/resources.rs +++ b/base_layer/wallet/src/output_manager_service/resources.rs @@ -20,6 +20,12 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+use std::sync::Arc; + +use tari_comms::{connectivity::ConnectivityRequester, types::CommsPublicKey}; +use tari_core::{consensus::ConsensusConstants, transactions::CryptoFactories}; +use tari_shutdown::ShutdownSignal; + use crate::{ output_manager_service::{ config::OutputManagerServiceConfig, @@ -29,10 +35,6 @@ use crate::{ }, transaction_service::handle::TransactionServiceHandle, }; -use std::sync::Arc; -use tari_comms::{connectivity::ConnectivityRequester, types::CommsPublicKey}; -use tari_core::{consensus::ConsensusConstants, transactions::types::CryptoFactories}; -use tari_shutdown::ShutdownSignal; /// This struct is a collection of the common resources that a async task in the service requires. #[derive(Clone)] diff --git a/base_layer/wallet/src/output_manager_service/service.rs b/base_layer/wallet/src/output_manager_service/service.rs index c8946b9ff9..bb1a98f505 100644 --- a/base_layer/wallet/src/output_manager_service/service.rs +++ b/base_layer/wallet/src/output_manager_service/service.rs @@ -20,38 +20,30 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-use crate::{ - base_node_service::handle::BaseNodeServiceHandle, - output_manager_service::{ - config::OutputManagerServiceConfig, - error::{OutputManagerError, OutputManagerProtocolError, OutputManagerStorageError}, - handle::{OutputManagerEventSender, OutputManagerRequest, OutputManagerResponse}, - recovery::StandardUtxoRecoverer, - resources::OutputManagerResources, - storage::{ - database::{OutputManagerBackend, OutputManagerDatabase, PendingTransactionOutputs}, - models::{DbUnblindedOutput, KnownOneSidedPaymentScript}, - }, - tasks::{TxoValidationTask, TxoValidationType}, - MasterKeyManager, - TxId, - }, - transaction_service::handle::TransactionServiceHandle, - types::{HashDigest, ValidationRetryStrategy}, +use std::{ + cmp::Ordering, + collections::HashMap, + fmt::{self, Display}, + sync::Arc, + time::Duration, }; + use blake2::Digest; use chrono::Utc; use diesel::result::{DatabaseErrorKind, Error as DieselError}; use futures::{pin_mut, StreamExt}; use log::*; use rand::{rngs::OsRng, RngCore}; -use std::{ - cmp::Ordering, - collections::HashMap, - fmt::{self, Display}, - sync::Arc, - time::Duration, +use tari_crypto::{ + inputs, + keys::{DiffieHellmanSharedSecret, PublicKey as PublicKeyTrait, SecretKey}, + script, + script::TariScript, + tari_utilities::{hex::Hex, ByteArray}, }; +use tokio::sync::broadcast; + +use tari_common_types::types::{PrivateKey, PublicKey}; use tari_comms::{ connectivity::ConnectivityRequester, types::{CommsPublicKey, CommsSecretKey}, @@ -70,22 +62,34 @@ use tari_core::{ UnblindedOutput, }, transaction_protocol::sender::TransactionSenderMessage, - types::{CryptoFactories, PrivateKey, PublicKey}, CoinbaseBuilder, + CryptoFactories, ReceiverTransactionProtocol, SenderTransactionProtocol, }, }; -use tari_crypto::{ - inputs, - keys::{DiffieHellmanSharedSecret, PublicKey as PublicKeyTrait, SecretKey}, - script, - script::TariScript, - tari_utilities::{hex::Hex, ByteArray}, -}; use tari_service_framework::reply_channel; use 
tari_shutdown::ShutdownSignal; -use tokio::sync::broadcast; + +use crate::{ + base_node_service::handle::BaseNodeServiceHandle, + output_manager_service::{ + config::OutputManagerServiceConfig, + error::{OutputManagerError, OutputManagerProtocolError, OutputManagerStorageError}, + handle::{OutputManagerEventSender, OutputManagerRequest, OutputManagerResponse}, + recovery::StandardUtxoRecoverer, + resources::OutputManagerResources, + storage::{ + database::{OutputManagerBackend, OutputManagerDatabase, PendingTransactionOutputs}, + models::{DbUnblindedOutput, KnownOneSidedPaymentScript}, + }, + tasks::{TxoValidationTask, TxoValidationType}, + MasterKeyManager, + TxId, + }, + transaction_service::handle::TransactionServiceHandle, + types::{HashDigest, ValidationRetryStrategy}, +}; const LOG_TARGET: &str = "wallet::output_manager_service"; const LOG_TARGET_STRESS: &str = "stress_test::output_manager_service"; @@ -166,9 +170,9 @@ where TBackend: OutputManagerBackend + 'static info!(target: LOG_TARGET, "Output Manager Service started"); loop { - futures::select! { - request_context = request_stream.select_next_some() => { - trace!(target: LOG_TARGET, "Handling Service API Request"); + tokio::select! 
{ + Some(request_context) = request_stream.next() => { + trace!(target: LOG_TARGET, "Handling Service API Request"); let (request, reply_tx) = request_context.split(); let response = self.handle_request(request).await.map_err(|e| { warn!(target: LOG_TARGET, "Error handling request: {:?}", e); @@ -179,14 +183,10 @@ where TBackend: OutputManagerBackend + 'static e }); }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Output manager service shutting down because it received the shutdown signal"); break; } - complete => { - info!(target: LOG_TARGET, "Output manager service shutting down"); - break; - } } } info!(target: LOG_TARGET, "Output Manager Service ended"); diff --git a/base_layer/wallet/src/output_manager_service/storage/database.rs b/base_layer/wallet/src/output_manager_service/storage/database.rs index 52d552e016..7344550c63 100644 --- a/base_layer/wallet/src/output_manager_service/storage/database.rs +++ b/base_layer/wallet/src/output_manager_service/storage/database.rs @@ -35,11 +35,8 @@ use std::{ sync::Arc, time::Duration, }; -use tari_core::transactions::{ - tari_amount::MicroTari, - transaction::TransactionOutput, - types::{BlindingFactor, Commitment, PrivateKey}, -}; +use tari_common_types::types::{BlindingFactor, Commitment, PrivateKey}; +use tari_core::transactions::{tari_amount::MicroTari, transaction::TransactionOutput}; const LOG_TARGET: &str = "wallet::output_manager_service::database"; diff --git a/base_layer/wallet/src/output_manager_service/storage/models.rs b/base_layer/wallet/src/output_manager_service/storage/models.rs index dd36eb6934..e0f00a0569 100644 --- a/base_layer/wallet/src/output_manager_service/storage/models.rs +++ b/base_layer/wallet/src/output_manager_service/storage/models.rs @@ -20,17 +20,17 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-use crate::output_manager_service::error::OutputManagerStorageError; use std::cmp::Ordering; + +use tari_crypto::script::{ExecutionStack, TariScript}; + +use tari_common_types::types::{Commitment, HashOutput, PrivateKey}; use tari_core::{ tari_utilities::hash::Hashable, - transactions::{ - transaction::UnblindedOutput, - transaction_protocol::RewindData, - types::{Commitment, CryptoFactories, HashOutput, PrivateKey}, - }, + transactions::{transaction::UnblindedOutput, transaction_protocol::RewindData, CryptoFactories}, }; -use tari_crypto::script::{ExecutionStack, TariScript}; + +use crate::output_manager_service::error::OutputManagerStorageError; #[derive(Debug, Clone)] pub struct DbUnblindedOutput { diff --git a/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs b/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs index 665c408bdf..052bad580b 100644 --- a/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs +++ b/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs @@ -20,6 +20,37 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+use std::{ + collections::HashMap, + convert::TryFrom, + str::from_utf8, + sync::{Arc, RwLock}, + time::Duration, +}; + +use aes_gcm::{aead::Error as AeadError, Aes256Gcm, Error}; +use chrono::{Duration as ChronoDuration, NaiveDateTime, Utc}; +use diesel::{prelude::*, result::Error as DieselError, SqliteConnection}; +use log::*; +use tari_crypto::{ + commitment::HomomorphicCommitmentFactory, + script::{ExecutionStack, TariScript}, + tari_utilities::{ + hex::{from_hex, Hex}, + ByteArray, + }, +}; + +use tari_common_types::types::{ComSignature, Commitment, PrivateKey, PublicKey}; +use tari_core::{ + tari_utilities::hash::Hashable, + transactions::{ + tari_amount::MicroTari, + transaction::{OutputFeatures, OutputFlags, TransactionOutput, UnblindedOutput}, + CryptoFactories, + }, +}; + use crate::{ output_manager_service::{ error::OutputManagerStorageError, @@ -41,33 +72,6 @@ use crate::{ storage::sqlite_utilities::WalletDbConnection, util::encryption::{decrypt_bytes_integral_nonce, encrypt_bytes_integral_nonce, Encryptable}, }; -use aes_gcm::{aead::Error as AeadError, Aes256Gcm, Error}; -use chrono::{Duration as ChronoDuration, NaiveDateTime, Utc}; -use diesel::{prelude::*, result::Error as DieselError, SqliteConnection}; -use log::*; -use std::{ - collections::HashMap, - convert::TryFrom, - str::from_utf8, - sync::{Arc, RwLock}, - time::Duration, -}; -use tari_core::{ - tari_utilities::hash::Hashable, - transactions::{ - tari_amount::MicroTari, - transaction::{OutputFeatures, OutputFlags, TransactionOutput, UnblindedOutput}, - types::{ComSignature, Commitment, CryptoFactories, PrivateKey, PublicKey}, - }, -}; -use tari_crypto::{ - commitment::HomomorphicCommitmentFactory, - script::{ExecutionStack, TariScript}, - tari_utilities::{ - hex::{from_hex, Hex}, - ByteArray, - }, -}; const LOG_TARGET: &str = "wallet::output_manager_service::database::sqlite_db"; @@ -1714,6 +1718,27 @@ impl Encryptable for KnownOneSidedPaymentScriptSql { #[cfg(test)] mod test { + use 
std::{convert::TryFrom, time::Duration}; + + use aes_gcm::{ + aead::{generic_array::GenericArray, NewAead}, + Aes256Gcm, + }; + use chrono::{Duration as ChronoDuration, Utc}; + use diesel::{Connection, SqliteConnection}; + use rand::{rngs::OsRng, RngCore}; + use tari_crypto::{keys::SecretKey, script}; + use tempfile::tempdir; + + use tari_common_types::types::{CommitmentFactory, PrivateKey}; + use tari_core::transactions::{ + helpers::{create_unblinded_output, TestParams as TestParamsHelpers}, + tari_amount::MicroTari, + transaction::{OutputFeatures, TransactionInput, UnblindedOutput}, + CryptoFactories, + }; + use tari_test_utils::random; + use crate::{ output_manager_service::storage::{ database::{DbKey, KeyManagerState, OutputManagerBackend}, @@ -1731,23 +1756,6 @@ mod test { storage::sqlite_utilities::WalletDbConnection, util::encryption::Encryptable, }; - use aes_gcm::{ - aead::{generic_array::GenericArray, NewAead}, - Aes256Gcm, - }; - use chrono::{Duration as ChronoDuration, Utc}; - use diesel::{Connection, SqliteConnection}; - use rand::{rngs::OsRng, RngCore}; - use std::{convert::TryFrom, time::Duration}; - use tari_core::transactions::{ - helpers::{create_unblinded_output, TestParams as TestParamsHelpers}, - tari_amount::MicroTari, - transaction::{OutputFeatures, TransactionInput, UnblindedOutput}, - types::{CommitmentFactory, CryptoFactories, PrivateKey}, - }; - use tari_crypto::{keys::SecretKey, script}; - use tari_test_utils::random; - use tempfile::tempdir; pub fn make_input(val: MicroTari) -> (TransactionInput, UnblindedOutput) { let test_params = TestParamsHelpers::new(); diff --git a/base_layer/wallet/src/output_manager_service/tasks/txo_validation_task.rs b/base_layer/wallet/src/output_manager_service/tasks/txo_validation_task.rs index e3e022e4fb..e08059e16b 100644 --- a/base_layer/wallet/src/output_manager_service/tasks/txo_validation_task.rs +++ b/base_layer/wallet/src/output_manager_service/tasks/txo_validation_task.rs @@ -30,17 +30,18 @@ use 
crate::{ transaction_service::storage::models::TransactionStatus, types::ValidationRetryStrategy, }; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt; use log::*; use std::{cmp, collections::HashMap, convert::TryFrom, fmt, sync::Arc, time::Duration}; +use tari_common_types::types::Signature; use tari_comms::{peer_manager::NodeId, types::CommsPublicKey, PeerConnection}; use tari_core::{ base_node::rpc::BaseNodeWalletRpcClient, proto::base_node::FetchMatchingUtxos, - transactions::{transaction::TransactionOutput, types::Signature}, + transactions::transaction::TransactionOutput, }; use tari_crypto::tari_utilities::{hash::Hashable, hex::Hex}; -use tokio::{sync::broadcast, time::delay_for}; +use tokio::{sync::broadcast, time::sleep}; const LOG_TARGET: &str = "wallet::output_manager_service::utxo_validation_task"; @@ -87,21 +88,17 @@ where TBackend: OutputManagerBackend + 'static /// The task that defines the execution of the protocol. pub async fn execute(mut self) -> Result { - let mut base_node_update_receiver = self - .base_node_update_receiver - .take() - .ok_or_else(|| { - OutputManagerProtocolError::new( - self.id, - OutputManagerError::ServiceError("A Base Node Update receiver was not provided".to_string()), - ) - })? 
- .fuse(); + let mut base_node_update_receiver = self.base_node_update_receiver.take().ok_or_else(|| { + OutputManagerProtocolError::new( + self.id, + OutputManagerError::ServiceError("A Base Node Update receiver was not provided".to_string()), + ) + })?; let mut shutdown = self.resources.shutdown_signal.clone(); let total_retries_str = match self.retry_strategy { - ValidationRetryStrategy::Limited(n) => format!("{}", n), + ValidationRetryStrategy::Limited(n) => n.to_string(), ValidationRetryStrategy::UntilSuccess => "∞".to_string(), }; @@ -180,14 +177,14 @@ where TBackend: OutputManagerBackend + 'static let base_node_node_id = NodeId::from_key(&self.base_node_public_key.clone()); let mut connection: Option = None; - let delay = delay_for(self.resources.config.peer_dial_retry_timeout); + let delay = sleep(self.resources.config.peer_dial_retry_timeout); debug!( target: LOG_TARGET, "Connecting to Base Node (Public Key: {})", self.base_node_public_key, ); - futures::select! { - dial_result = self.resources.connectivity_manager.dial_peer(base_node_node_id.clone()).fuse() => { + tokio::select! 
{ + dial_result = self.resources.connectivity_manager.dial_peer(base_node_node_id.clone()) => { match dial_result { Ok(base_node_connection) => { connection = Some(base_node_connection); @@ -197,7 +194,7 @@ where TBackend: OutputManagerBackend + 'static }, } }, - new_base_node = base_node_update_receiver.select_next_some() => { + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(_) => { info!( @@ -228,7 +225,7 @@ where TBackend: OutputManagerBackend + 'static } } } - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "TXO Validation Protocol (Id: {}) shutting down because it received the shutdown signal", self.id); return Err(OutputManagerProtocolError::new(self.id, OutputManagerError::Shutdown)); }, @@ -236,7 +233,7 @@ where TBackend: OutputManagerBackend + 'static let mut base_node_connection = match connection { None => { - futures::select! { + tokio::select! { _ = delay.fuse() => { let _ = self .resources @@ -253,7 +250,7 @@ where TBackend: OutputManagerBackend + 'static retries += 1; continue; }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "TXO Validation Protocol (Id: {}) shutting down because it received the shutdown signal", self.id); return Err(OutputManagerProtocolError::new(self.id, OutputManagerError::Shutdown)); }, @@ -294,9 +291,9 @@ where TBackend: OutputManagerBackend + 'static batch_num, batch_total ); - let delay = delay_for(self.retry_delay); - futures::select! { - new_base_node = base_node_update_receiver.select_next_some() => { + let delay = sleep(self.retry_delay); + tokio::select! 
{ + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(_bn) => { info!(target: LOG_TARGET, "TXO Validation protocol aborted due to Base Node Public key change" ); @@ -323,7 +320,7 @@ where TBackend: OutputManagerBackend + 'static } } }, - result = self.send_query_batch(batch.clone(), &mut client).fuse() => { + result = self.send_query_batch(batch.clone(), &mut client) => { match result { Ok(synced) => { self.base_node_synced = synced; @@ -374,7 +371,7 @@ where TBackend: OutputManagerBackend + 'static }, } }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "TXO Validation Protocol (Id: {}) shutting down because it received the shutdown signal", self.id); return Err(OutputManagerProtocolError::new(self.id, OutputManagerError::Shutdown)); }, diff --git a/base_layer/wallet/src/storage/database.rs b/base_layer/wallet/src/storage/database.rs index 3d5cbf8838..d23d806298 100644 --- a/base_layer/wallet/src/storage/database.rs +++ b/base_layer/wallet/src/storage/database.rs @@ -373,7 +373,7 @@ mod test { #[test] fn test_database_crud() { - let mut runtime = Runtime::new().unwrap(); + let runtime = Runtime::new().unwrap(); let db_name = format!("{}.sqlite3", string(8).as_str()); let db_folder = tempdir().unwrap().path().to_str().unwrap().to_string(); diff --git a/base_layer/wallet/src/transaction_service/error.rs b/base_layer/wallet/src/transaction_service/error.rs index c197dd2024..05e9e4af2b 100644 --- a/base_layer/wallet/src/transaction_service/error.rs +++ b/base_layer/wallet/src/transaction_service/error.rs @@ -34,7 +34,7 @@ use tari_p2p::services::liveness::error::LivenessError; use tari_service_framework::reply_channel::TransportChannelError; use thiserror::Error; use time::OutOfRangeError; -use tokio::sync::broadcast::RecvError; +use tokio::sync::broadcast::error::RecvError; #[derive(Debug, Error)] pub enum TransactionServiceError { diff --git a/base_layer/wallet/src/transaction_service/handle.rs 
b/base_layer/wallet/src/transaction_service/handle.rs index f34a5f667f..7a6ab48649 100644 --- a/base_layer/wallet/src/transaction_service/handle.rs +++ b/base_layer/wallet/src/transaction_service/handle.rs @@ -28,7 +28,6 @@ use crate::{ }, }; use aes_gcm::Aes256Gcm; -use futures::{stream::Fuse, StreamExt}; use std::{collections::HashMap, fmt, sync::Arc}; use tari_comms::types::CommsPublicKey; use tari_core::transactions::{tari_amount::MicroTari, transaction::Transaction}; @@ -187,8 +186,8 @@ impl TransactionServiceHandle { } } - pub fn get_event_stream_fused(&self) -> Fuse { - self.event_stream_sender.subscribe().fuse() + pub fn get_event_stream(&self) -> TransactionEventReceiver { + self.event_stream_sender.subscribe() } pub async fn send_transaction( diff --git a/base_layer/wallet/src/transaction_service/mod.rs b/base_layer/wallet/src/transaction_service/mod.rs index 3efbade3c3..541d898770 100644 --- a/base_layer/wallet/src/transaction_service/mod.rs +++ b/base_layer/wallet/src/transaction_service/mod.rs @@ -20,31 +20,17 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-pub mod config; -pub mod error; -pub mod handle; -pub mod protocols; -pub mod service; -pub mod storage; -pub mod tasks; +use std::sync::Arc; -use crate::{ - output_manager_service::handle::OutputManagerHandle, - transaction_service::{ - config::TransactionServiceConfig, - handle::TransactionServiceHandle, - service::TransactionService, - storage::database::{TransactionBackend, TransactionDatabase}, - }, -}; use futures::{Stream, StreamExt}; use log::*; -use std::sync::Arc; +use tokio::sync::broadcast; + use tari_comms::{connectivity::ConnectivityRequester, peer_manager::NodeIdentity}; use tari_comms_dht::Dht; use tari_core::{ proto::base_node as base_node_proto, - transactions::{transaction_protocol::proto, types::CryptoFactories}, + transactions::{transaction_protocol::proto, CryptoFactories}, }; use tari_p2p::{ comms_connector::SubscriptionFactory, @@ -59,7 +45,24 @@ use tari_service_framework::{ ServiceInitializer, ServiceInitializerContext, }; -use tokio::sync::broadcast; + +use crate::{ + output_manager_service::handle::OutputManagerHandle, + transaction_service::{ + config::TransactionServiceConfig, + handle::TransactionServiceHandle, + service::TransactionService, + storage::database::{TransactionBackend, TransactionDatabase}, + }, +}; + +pub mod config; +pub mod error; +pub mod handle; +pub mod protocols; +pub mod service; +pub mod storage; +pub mod tasks; const LOG_TARGET: &str = "wallet::transaction_service"; const SUBSCRIPTION_LABEL: &str = "Transaction Service"; diff --git a/base_layer/wallet/src/transaction_service/protocols/transaction_broadcast_protocol.rs b/base_layer/wallet/src/transaction_service/protocols/transaction_broadcast_protocol.rs index 05191c8f8d..4a28383226 100644 --- a/base_layer/wallet/src/transaction_service/protocols/transaction_broadcast_protocol.rs +++ b/base_layer/wallet/src/transaction_service/protocols/transaction_broadcast_protocol.rs @@ -32,19 +32,20 @@ use crate::{ }, }, }; -use futures::{FutureExt, StreamExt}; +use 
futures::FutureExt; use log::*; use std::{convert::TryFrom, sync::Arc, time::Duration}; +use tari_common_types::types::Signature; use tari_comms::{peer_manager::NodeId, types::CommsPublicKey, PeerConnection}; use tari_core::{ base_node::{ proto::wallet_rpc::{TxLocation, TxQueryResponse, TxSubmissionRejectionReason, TxSubmissionResponse}, rpc::BaseNodeWalletRpcClient, }, - transactions::{transaction::Transaction, types::Signature}, + transactions::transaction::Transaction, }; use tari_crypto::tari_utilities::hex::Hex; -use tokio::{sync::broadcast, time::delay_for}; +use tokio::{sync::broadcast, time::sleep}; const LOG_TARGET: &str = "wallet::transaction_service::protocols::broadcast_protocol"; @@ -86,21 +87,13 @@ where TBackend: TransactionBackend + 'static /// The task that defines the execution of the protocol. pub async fn execute(mut self) -> Result { - let mut timeout_update_receiver = self - .timeout_update_receiver - .take() - .ok_or_else(|| { - TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) - })? - .fuse(); + let mut timeout_update_receiver = self.timeout_update_receiver.take().ok_or_else(|| { + TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) + })?; - let mut base_node_update_receiver = self - .base_node_update_receiver - .take() - .ok_or_else(|| { - TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) - })? 
- .fuse(); + let mut base_node_update_receiver = self.base_node_update_receiver.take().ok_or_else(|| { + TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) + })?; let mut shutdown = self.resources.shutdown_signal.clone(); // Main protocol loop @@ -108,14 +101,14 @@ where TBackend: TransactionBackend + 'static let base_node_node_id = NodeId::from_key(&self.base_node_public_key); let mut connection: Option = None; - let delay = delay_for(self.timeout); + let delay = sleep(self.timeout); debug!( target: LOG_TARGET, "Connecting to Base Node (Public Key: {})", self.base_node_public_key, ); - futures::select! { - dial_result = self.resources.connectivity_manager.dial_peer(base_node_node_id.clone()).fuse() => { + tokio::select! { + dial_result = self.resources.connectivity_manager.dial_peer(base_node_node_id.clone()) => { match dial_result { Ok(base_node_connection) => { connection = Some(base_node_connection); @@ -139,7 +132,7 @@ where TBackend: TransactionBackend + 'static }, } }, - updated_timeout = timeout_update_receiver.select_next_some() => { + updated_timeout = timeout_update_receiver.recv() => { match updated_timeout { Ok(to) => { self.timeout = to; @@ -158,7 +151,7 @@ where TBackend: TransactionBackend + 'static } } }, - new_base_node = base_node_update_receiver.select_next_some() => { + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(bn) => { self.base_node_public_key = bn; @@ -179,7 +172,7 @@ where TBackend: TransactionBackend + 'static } } } - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Broadcast Protocol (TxId: {}) shutting down because it received the shutdown signal", self.tx_id); return Err(TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::Shutdown)) }, @@ -187,11 +180,11 @@ where TBackend: TransactionBackend + 'static let mut base_node_connection = match connection { None => { - futures::select! { + tokio::select! 
{ _ = delay.fuse() => { continue; }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Broadcast Protocol (TxId: {}) shutting down because it received the shutdown signal", self.tx_id); return Err(TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::Shutdown)) }, @@ -243,10 +236,10 @@ where TBackend: TransactionBackend + 'static }, }; - let delay = delay_for(self.timeout); + let delay = sleep(self.timeout); loop { - futures::select! { - new_base_node = base_node_update_receiver.select_next_some() => { + tokio::select! { + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(bn) => { self.base_node_public_key = bn; @@ -315,7 +308,7 @@ where TBackend: TransactionBackend + 'static delay.await; break; }, - updated_timeout = timeout_update_receiver.select_next_some() => { + updated_timeout = timeout_update_receiver.recv() => { if let Ok(to) = updated_timeout { self.timeout = to; info!( @@ -332,7 +325,7 @@ where TBackend: TransactionBackend + 'static ); } }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Broadcast Protocol (TxId: {}) shutting down because it received the shutdown signal", self.tx_id); return Err(TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::Shutdown)) }, diff --git a/base_layer/wallet/src/transaction_service/protocols/transaction_coinbase_monitoring_protocol.rs b/base_layer/wallet/src/transaction_service/protocols/transaction_coinbase_monitoring_protocol.rs index 65fb58f601..f368d33d89 100644 --- a/base_layer/wallet/src/transaction_service/protocols/transaction_coinbase_monitoring_protocol.rs +++ b/base_layer/wallet/src/transaction_service/protocols/transaction_coinbase_monitoring_protocol.rs @@ -29,19 +29,17 @@ use crate::{ storage::{database::TransactionBackend, models::CompletedTransaction}, }, }; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt; use log::*; use std::{convert::TryFrom, 
sync::Arc, time::Duration}; +use tari_common_types::types::Signature; use tari_comms::{peer_manager::NodeId, types::CommsPublicKey, PeerConnection}; -use tari_core::{ - base_node::{ - proto::wallet_rpc::{TxLocation, TxQueryResponse}, - rpc::BaseNodeWalletRpcClient, - }, - transactions::types::Signature, +use tari_core::base_node::{ + proto::wallet_rpc::{TxLocation, TxQueryResponse}, + rpc::BaseNodeWalletRpcClient, }; use tari_crypto::tari_utilities::{hex::Hex, Hashable}; -use tokio::{sync::broadcast, time::delay_for}; +use tokio::{sync::broadcast, time::sleep}; const LOG_TARGET: &str = "wallet::transaction_service::protocols::coinbase_monitoring"; @@ -86,21 +84,13 @@ where TBackend: TransactionBackend + 'static /// The task that defines the execution of the protocol. pub async fn execute(mut self) -> Result { - let mut base_node_update_receiver = self - .base_node_update_receiver - .take() - .ok_or_else(|| { - TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) - })? - .fuse(); + let mut base_node_update_receiver = self.base_node_update_receiver.take().ok_or_else(|| { + TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) + })?; - let mut timeout_update_receiver = self - .timeout_update_receiver - .take() - .ok_or_else(|| { - TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) - })? - .fuse(); + let mut timeout_update_receiver = self.timeout_update_receiver.take().ok_or_else(|| { + TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) + })?; trace!( target: LOG_TARGET, @@ -173,7 +163,7 @@ where TBackend: TransactionBackend + 'static self.base_node_public_key, self.tx_id, ); - futures::select! { + tokio::select! 
{ dial_result = self.resources.connectivity_manager.dial_peer(base_node_node_id.clone()).fuse() => { match dial_result { Ok(base_node_connection) => { @@ -203,7 +193,7 @@ where TBackend: TransactionBackend + 'static }, } }, - updated_timeout = timeout_update_receiver.select_next_some() => { + updated_timeout = timeout_update_receiver.recv() => { match updated_timeout { Ok(to) => { self.timeout = to; @@ -225,7 +215,7 @@ where TBackend: TransactionBackend + 'static } } }, - new_base_node = base_node_update_receiver.select_next_some() => { + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(bn) => { self.base_node_public_key = bn; @@ -248,7 +238,7 @@ where TBackend: TransactionBackend + 'static } } } - _ = shutdown => { + _ = shutdown.wait() => { info!( target: LOG_TARGET, "Coinbase Monitoring protocol (TxId: {}) shutting down because it received the shutdown \ @@ -259,14 +249,14 @@ where TBackend: TransactionBackend + 'static }, } - let delay = delay_for(self.timeout); + let delay = sleep(self.timeout); let mut base_node_connection = match connection { None => { - futures::select! { + tokio::select! { _ = delay.fuse() => { continue; }, - _ = shutdown => { + _ = shutdown.wait() => { info!( target: LOG_TARGET, "Coinbase Monitoring Protocol (TxId: {}) shutting down because it received the \ @@ -314,10 +304,10 @@ where TBackend: TransactionBackend + 'static TransactionServiceError::InvalidCompletedTransaction, )); } - let delay = delay_for(self.timeout).fuse(); + let delay = sleep(self.timeout).fuse(); loop { - futures::select! { - new_base_node = base_node_update_receiver.select_next_some() => { + tokio::select! 
{ + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(bn) => { self.base_node_public_key = bn; @@ -392,7 +382,7 @@ where TBackend: TransactionBackend + 'static delay.await; break; }, - updated_timeout = timeout_update_receiver.select_next_some() => { + updated_timeout = timeout_update_receiver.recv() => { if let Ok(to) = updated_timeout { self.timeout = to; info!( @@ -411,7 +401,7 @@ where TBackend: TransactionBackend + 'static ); } }, - _ = shutdown => { + _ = shutdown.wait() => { info!( target: LOG_TARGET, "Coinbase Monitoring Protocol (TxId: {}) shutting down because it received the shutdown \ diff --git a/base_layer/wallet/src/transaction_service/protocols/transaction_receive_protocol.rs b/base_layer/wallet/src/transaction_service/protocols/transaction_receive_protocol.rs index 744bb6f1fc..0a6bd89fb2 100644 --- a/base_layer/wallet/src/transaction_service/protocols/transaction_receive_protocol.rs +++ b/base_layer/wallet/src/transaction_service/protocols/transaction_receive_protocol.rs @@ -34,21 +34,18 @@ use crate::{ }, }; use chrono::Utc; -use futures::{ - channel::{mpsc, oneshot}, - future::FutureExt, - StreamExt, -}; +use futures::future::FutureExt; use log::*; use std::sync::Arc; use tari_comms::types::CommsPublicKey; +use tokio::sync::{mpsc, oneshot}; use tari_core::transactions::{ transaction::Transaction, transaction_protocol::{recipient::RecipientState, sender::TransactionSenderMessage}, }; use tari_crypto::tari_utilities::Hashable; -use tokio::time::delay_for; +use tokio::time::sleep; const LOG_TARGET: &str = "wallet::transaction_service::protocols::receive_protocol"; const LOG_TARGET_STRESS: &str = "stress_test::receive_protocol"; @@ -263,7 +260,8 @@ where TBackend: TransactionBackend + 'static }, Some(t) => t, }; - let mut timeout_delay = delay_for(timeout_duration).fuse(); + let timeout_delay = sleep(timeout_duration).fuse(); + tokio::pin!(timeout_delay); // check to see if a resend is due let resend = match 
inbound_tx.last_send_timestamp { @@ -310,9 +308,9 @@ where TBackend: TransactionBackend + 'static let mut incoming_finalized_transaction = None; loop { loop { - let mut resend_timeout = delay_for(self.resources.config.transaction_resend_period).fuse(); - futures::select! { - (spk, tx_id, tx) = receiver.select_next_some() => { + let resend_timeout = sleep(self.resources.config.transaction_resend_period).fuse(); + tokio::select! { + Some((spk, tx_id, tx)) = receiver.recv() => { incoming_finalized_transaction = Some(tx); if inbound_tx.source_public_key != spk { warn!( @@ -325,16 +323,14 @@ where TBackend: TransactionBackend + 'static break; } }, - result = cancellation_receiver => { - if result.is_ok() { - info!(target: LOG_TARGET, "Cancelling Transaction Receive Protocol for TxId: {}", self.id); - return Err(TransactionServiceProtocolError::new( - self.id, - TransactionServiceError::TransactionCancelled, - )); - } + Ok(_) = &mut cancellation_receiver => { + info!(target: LOG_TARGET, "Cancelling Transaction Receive Protocol for TxId: {}", self.id); + return Err(TransactionServiceProtocolError::new( + self.id, + TransactionServiceError::TransactionCancelled, + )); }, - () = resend_timeout => { + _ = resend_timeout => { match send_transaction_reply( inbound_tx.clone(), self.resources.outbound_message_service.clone(), @@ -353,10 +349,10 @@ where TBackend: TransactionBackend + 'static ), } }, - () = timeout_delay => { + _ = &mut timeout_delay => { return self.timeout_transaction().await; } - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Receive Protocol (id: {}) shutting down because it received the shutdown signal", self.id); return Err(TransactionServiceProtocolError::new(self.id, TransactionServiceError::Shutdown)) } @@ -381,7 +377,7 @@ where TBackend: TransactionBackend + 'static ); finalized_transaction - .validate_internal_consistency(&self.resources.factories, None) + .validate_internal_consistency(true, 
&self.resources.factories, None) .map_err(|e| TransactionServiceProtocolError::new(self.id, TransactionServiceError::from(e)))?; // Find your own output in the transaction diff --git a/base_layer/wallet/src/transaction_service/protocols/transaction_send_protocol.rs b/base_layer/wallet/src/transaction_service/protocols/transaction_send_protocol.rs index cf30a7f928..e88c9013e6 100644 --- a/base_layer/wallet/src/transaction_service/protocols/transaction_send_protocol.rs +++ b/base_layer/wallet/src/transaction_service/protocols/transaction_send_protocol.rs @@ -20,12 +20,6 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use std::sync::Arc; - -use chrono::Utc; -use futures::{channel::mpsc::Receiver, FutureExt, StreamExt}; -use log::*; - use crate::transaction_service::{ config::TransactionRoutingMechanism, error::{TransactionServiceError, TransactionServiceProtocolError}, @@ -41,7 +35,10 @@ use crate::transaction_service::{ wait_on_dial::wait_on_dial, }, }; -use futures::channel::oneshot; +use chrono::Utc; +use futures::FutureExt; +use log::*; +use std::sync::Arc; use tari_comms::{peer_manager::NodeId, types::CommsPublicKey}; use tari_comms_dht::{ domain_message::OutboundDomainMessage, @@ -55,7 +52,10 @@ use tari_core::transactions::{ }; use tari_crypto::script; use tari_p2p::tari_message::TariMessageType; -use tokio::time::delay_for; +use tokio::{ + sync::{mpsc::Receiver, oneshot}, + time::sleep, +}; const LOG_TARGET: &str = "wallet::transaction_service::protocols::send_protocol"; const LOG_TARGET_STRESS: &str = "stress_test::send_protocol"; @@ -344,7 +344,8 @@ where TBackend: TransactionBackend + 'static }, Some(t) => t, }; - let mut timeout_delay = delay_for(timeout_duration).fuse(); + let timeout_delay = sleep(timeout_duration).fuse(); + tokio::pin!(timeout_delay); // check to see if a resend is due let resend = match 
outbound_tx.last_send_timestamp { @@ -390,9 +391,9 @@ where TBackend: TransactionBackend + 'static #[allow(unused_assignments)] let mut reply = None; loop { - let mut resend_timeout = delay_for(self.resources.config.transaction_resend_period).fuse(); - futures::select! { - (spk, rr) = receiver.select_next_some() => { + let resend_timeout = sleep(self.resources.config.transaction_resend_period).fuse(); + tokio::select! { + Some((spk, rr)) = receiver.recv() => { let rr_tx_id = rr.tx_id; reply = Some(rr); @@ -407,7 +408,7 @@ where TBackend: TransactionBackend + 'static break; } }, - result = cancellation_receiver => { + result = &mut cancellation_receiver => { if result.is_ok() { info!(target: LOG_TARGET, "Cancelling Transaction Send Protocol (TxId: {})", self.id); let _ = send_transaction_cancelled_message(self.id,self.dest_pubkey.clone(), self.resources.outbound_message_service.clone(), ).await.map_err(|e| { @@ -441,10 +442,10 @@ where TBackend: TransactionBackend + 'static .map_err(|e| TransactionServiceProtocolError::new(self.id, TransactionServiceError::from(e)))?; } }, - () = timeout_delay => { + () = &mut timeout_delay => { return self.timeout_transaction().await; } - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Send Protocol (id: {}) shutting down because it received the shutdown signal", self.id); return Err(TransactionServiceProtocolError::new(self.id, TransactionServiceError::Shutdown)) } diff --git a/base_layer/wallet/src/transaction_service/protocols/transaction_validation_protocol.rs b/base_layer/wallet/src/transaction_service/protocols/transaction_validation_protocol.rs index d0b2f7f6ac..dcf072c272 100644 --- a/base_layer/wallet/src/transaction_service/protocols/transaction_validation_protocol.rs +++ b/base_layer/wallet/src/transaction_service/protocols/transaction_validation_protocol.rs @@ -32,7 +32,7 @@ use crate::{ }, types::ValidationRetryStrategy, }; -use futures::{FutureExt, StreamExt}; +use 
futures::FutureExt; use log::*; use std::{cmp, convert::TryFrom, sync::Arc, time::Duration}; use tari_comms::{peer_manager::NodeId, types::CommsPublicKey, PeerConnection}; @@ -43,7 +43,7 @@ use tari_core::{ }, proto::{base_node::Signatures as SignaturesProto, types::Signature as SignatureProto}, }; -use tokio::{sync::broadcast, time::delay_for}; +use tokio::{sync::broadcast, time::sleep}; const LOG_TARGET: &str = "wallet::transaction_service::protocols::validation_protocol"; @@ -94,14 +94,12 @@ where TBackend: TransactionBackend + 'static let mut timeout_update_receiver = self .timeout_update_receiver .take() - .ok_or_else(|| TransactionServiceProtocolError::new(self.id, TransactionServiceError::InvalidStateError))? - .fuse(); + .ok_or_else(|| TransactionServiceProtocolError::new(self.id, TransactionServiceError::InvalidStateError))?; let mut base_node_update_receiver = self .base_node_update_receiver .take() - .ok_or_else(|| TransactionServiceProtocolError::new(self.id, TransactionServiceError::InvalidStateError))? - .fuse(); + .ok_or_else(|| TransactionServiceProtocolError::new(self.id, TransactionServiceError::InvalidStateError))?; let mut shutdown = self.resources.shutdown_signal.clone(); @@ -158,13 +156,13 @@ where TBackend: TransactionBackend + 'static let base_node_node_id = NodeId::from_key(&self.base_node_public_key); let mut connection: Option = None; - let delay = delay_for(self.timeout); + let delay = sleep(self.timeout); debug!( target: LOG_TARGET, "Connecting to Base Node (Public Key: {})", self.base_node_public_key, ); - futures::select! { + tokio::select! 
{ dial_result = self.resources.connectivity_manager.dial_peer(base_node_node_id.clone()).fuse() => { match dial_result { Ok(base_node_connection) => { @@ -175,7 +173,7 @@ where TBackend: TransactionBackend + 'static }, } }, - new_base_node = base_node_update_receiver.select_next_some() => { + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(_) => { @@ -204,7 +202,7 @@ where TBackend: TransactionBackend + 'static } } } - updated_timeout = timeout_update_receiver.select_next_some() => { + updated_timeout = timeout_update_receiver.recv() => { match updated_timeout { Ok(to) => { self.timeout = to; @@ -223,7 +221,7 @@ where TBackend: TransactionBackend + 'static } } }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Validation Protocol shutting down because it received the shutdown signal"); return Err(TransactionServiceProtocolError::new(self.id, TransactionServiceError::Shutdown)) }, @@ -231,7 +229,7 @@ where TBackend: TransactionBackend + 'static let mut base_node_connection = match connection { None => { - futures::select! { + tokio::select! { _ = delay.fuse() => { let _ = self .resources @@ -248,7 +246,7 @@ where TBackend: TransactionBackend + 'static retries += 1; continue; }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Validation Protocol shutting down because it received the shutdown signal"); return Err(TransactionServiceProtocolError::new(self.id, TransactionServiceError::Shutdown)) }, @@ -282,9 +280,9 @@ where TBackend: TransactionBackend + 'static } else { break 'main; }; - let delay = delay_for(self.timeout); - futures::select! { - new_base_node = base_node_update_receiver.select_next_some() => { + let delay = sleep(self.timeout); + tokio::select! 
{ + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(_) => { info!(target: LOG_TARGET, "Aborting Transaction Validation Protocol as new Base node is set"); @@ -372,7 +370,7 @@ where TBackend: TransactionBackend + 'static }, } }, - updated_timeout = timeout_update_receiver.select_next_some() => { + updated_timeout = timeout_update_receiver.recv() => { match updated_timeout { Ok(to) => { self.timeout = to; @@ -391,7 +389,7 @@ where TBackend: TransactionBackend + 'static } } }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Validation Protocol shutting down because it received the shutdown signal"); return Err(TransactionServiceProtocolError::new(self.id, TransactionServiceError::Shutdown)) }, diff --git a/base_layer/wallet/src/transaction_service/service.rs b/base_layer/wallet/src/transaction_service/service.rs index c30fb7f412..12aa052382 100644 --- a/base_layer/wallet/src/transaction_service/service.rs +++ b/base_layer/wallet/src/transaction_service/service.rs @@ -20,49 +20,22 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-use crate::{ - output_manager_service::{handle::OutputManagerHandle, TxId}, - transaction_service::{ - config::TransactionServiceConfig, - error::{TransactionServiceError, TransactionServiceProtocolError}, - handle::{TransactionEvent, TransactionEventSender, TransactionServiceRequest, TransactionServiceResponse}, - protocols::{ - transaction_broadcast_protocol::TransactionBroadcastProtocol, - transaction_coinbase_monitoring_protocol::TransactionCoinbaseMonitoringProtocol, - transaction_receive_protocol::{TransactionReceiveProtocol, TransactionReceiveProtocolStage}, - transaction_send_protocol::{TransactionSendProtocol, TransactionSendProtocolStage}, - transaction_validation_protocol::TransactionValidationProtocol, - }, - storage::{ - database::{TransactionBackend, TransactionDatabase}, - models::{CompletedTransaction, TransactionDirection, TransactionStatus}, - }, - tasks::{ - send_finalized_transaction::send_finalized_transaction_message, - send_transaction_cancelled::send_transaction_cancelled_message, - send_transaction_reply::send_transaction_reply, - }, - }, - types::{HashDigest, ValidationRetryStrategy}, -}; -use chrono::{NaiveDateTime, Utc}; -use digest::Digest; -use futures::{ - channel::{mpsc, mpsc::Sender, oneshot}, - pin_mut, - stream::FuturesUnordered, - SinkExt, - Stream, - StreamExt, -}; -use log::*; -use rand::{rngs::OsRng, RngCore}; use std::{ collections::{HashMap, HashSet}, convert::TryInto, sync::Arc, time::{Duration, Instant}, }; + +use chrono::{NaiveDateTime, Utc}; +use digest::Digest; +use futures::{pin_mut, stream::FuturesUnordered, Stream, StreamExt}; +use log::*; +use rand::{rngs::OsRng, RngCore}; +use tari_crypto::{keys::DiffieHellmanSharedSecret, script, tari_utilities::ByteArray}; +use tokio::{sync::broadcast, task::JoinHandle}; + +use tari_common_types::types::PrivateKey; use tari_comms::{connectivity::ConnectivityRequester, peer_manager::NodeIdentity, types::CommsPublicKey}; use tari_comms_dht::outbound::OutboundMessageRequester; use 
tari_core::{ @@ -77,15 +50,40 @@ use tari_core::{ sender::TransactionSenderMessage, RewindData, }, - types::{CryptoFactories, PrivateKey}, + CryptoFactories, ReceiverTransactionProtocol, }, }; -use tari_crypto::{keys::DiffieHellmanSharedSecret, script, tari_utilities::ByteArray}; use tari_p2p::domain_message::DomainMessage; use tari_service_framework::{reply_channel, reply_channel::Receiver}; use tari_shutdown::ShutdownSignal; -use tokio::{sync::broadcast, task::JoinHandle}; +use tokio::sync::{mpsc, mpsc::Sender, oneshot}; + +use crate::{ + output_manager_service::{handle::OutputManagerHandle, TxId}, + transaction_service::{ + config::TransactionServiceConfig, + error::{TransactionServiceError, TransactionServiceProtocolError}, + handle::{TransactionEvent, TransactionEventSender, TransactionServiceRequest, TransactionServiceResponse}, + protocols::{ + transaction_broadcast_protocol::TransactionBroadcastProtocol, + transaction_coinbase_monitoring_protocol::TransactionCoinbaseMonitoringProtocol, + transaction_receive_protocol::{TransactionReceiveProtocol, TransactionReceiveProtocolStage}, + transaction_send_protocol::{TransactionSendProtocol, TransactionSendProtocolStage}, + transaction_validation_protocol::TransactionValidationProtocol, + }, + storage::{ + database::{TransactionBackend, TransactionDatabase}, + models::{CompletedTransaction, TransactionDirection, TransactionStatus}, + }, + tasks::{ + send_finalized_transaction::send_finalized_transaction_message, + send_transaction_cancelled::send_transaction_cancelled_message, + send_transaction_reply::send_transaction_reply, + }, + }, + types::{HashDigest, ValidationRetryStrategy}, +}; const LOG_TARGET: &str = "wallet::transaction_service::service"; @@ -276,9 +274,9 @@ where info!(target: LOG_TARGET, "Transaction Service started"); loop { - futures::select! { + tokio::select! 
{ //Incoming request - request_context = request_stream.select_next_some() => { + Some(request_context) = request_stream.next() => { // TODO: Remove time measurements; this is to aid in system testing only let start = Instant::now(); let (request, reply_tx) = request_context.split(); @@ -303,7 +301,7 @@ where ); }, // Incoming Transaction messages from the Comms layer - msg = transaction_stream.select_next_some() => { + Some(msg) = transaction_stream.next() => { // TODO: Remove time measurements; this is to aid in system testing only let start = Instant::now(); let (origin_public_key, inner_msg) = msg.clone().into_origin_and_inner(); @@ -333,7 +331,7 @@ where ); }, // Incoming Transaction Reply messages from the Comms layer - msg = transaction_reply_stream.select_next_some() => { + Some(msg) = transaction_reply_stream.next() => { // TODO: Remove time measurements; this is to aid in system testing only let start = Instant::now(); let (origin_public_key, inner_msg) = msg.clone().into_origin_and_inner(); @@ -364,7 +362,7 @@ where ); }, // Incoming Finalized Transaction messages from the Comms layer - msg = transaction_finalized_stream.select_next_some() => { + Some(msg) = transaction_finalized_stream.next() => { // TODO: Remove time measurements; this is to aid in system testing only let start = Instant::now(); let (origin_public_key, inner_msg) = msg.clone().into_origin_and_inner(); @@ -402,7 +400,7 @@ where ); }, // Incoming messages from the Comms layer - msg = base_node_response_stream.select_next_some() => { + Some(msg) = base_node_response_stream.next() => { // TODO: Remove time measurements; this is to aid in system testing only let start = Instant::now(); let (origin_public_key, inner_msg) = msg.clone().into_origin_and_inner(); @@ -421,7 +419,7 @@ where ); } // Incoming messages from the Comms layer - msg = transaction_cancelled_stream.select_next_some() => { + Some(msg) = transaction_cancelled_stream.next() => { // TODO: Remove time measurements; this is to 
aid in system testing only let start = Instant::now(); let (origin_public_key, inner_msg) = msg.clone().into_origin_and_inner(); @@ -436,7 +434,7 @@ where finish.duration_since(start).as_millis(), ); } - join_result = send_transaction_protocol_handles.select_next_some() => { + Some(join_result) = send_transaction_protocol_handles.next() => { trace!(target: LOG_TARGET, "Send Protocol for Transaction has ended with result {:?}", join_result); match join_result { Ok(join_result_inner) => self.complete_send_transaction_protocol( @@ -446,7 +444,7 @@ where Err(e) => error!(target: LOG_TARGET, "Error resolving Send Transaction Protocol: {:?}", e), }; } - join_result = receive_transaction_protocol_handles.select_next_some() => { + Some(join_result) = receive_transaction_protocol_handles.next() => { trace!(target: LOG_TARGET, "Receive Transaction Protocol has ended with result {:?}", join_result); match join_result { Ok(join_result_inner) => self.complete_receive_transaction_protocol( @@ -456,14 +454,14 @@ where Err(e) => error!(target: LOG_TARGET, "Error resolving Send Transaction Protocol: {:?}", e), }; } - join_result = transaction_broadcast_protocol_handles.select_next_some() => { + Some(join_result) = transaction_broadcast_protocol_handles.next() => { trace!(target: LOG_TARGET, "Transaction Broadcast protocol has ended with result {:?}", join_result); match join_result { Ok(join_result_inner) => self.complete_transaction_broadcast_protocol(join_result_inner).await, Err(e) => error!(target: LOG_TARGET, "Error resolving Broadcast Protocol: {:?}", e), }; } - join_result = coinbase_transaction_monitoring_protocol_handles.select_next_some() => { + Some(join_result) = coinbase_transaction_monitoring_protocol_handles.next() => { trace!(target: LOG_TARGET, "Coinbase transaction monitoring protocol has ended with result {:?}", join_result); match join_result { @@ -471,21 +469,17 @@ where Err(e) => error!(target: LOG_TARGET, "Error resolving Coinbase Monitoring protocol: {:?}", 
e), }; } - join_result = transaction_validation_protocol_handles.select_next_some() => { + Some(join_result) = transaction_validation_protocol_handles.next() => { trace!(target: LOG_TARGET, "Transaction Validation protocol has ended with result {:?}", join_result); match join_result { Ok(join_result_inner) => self.complete_transaction_validation_protocol(join_result_inner).await, Err(e) => error!(target: LOG_TARGET, "Error resolving Transaction Validation protocol: {:?}", e), }; } - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction service shutting down because it received the shutdown signal"); break; } - complete => { - info!(target: LOG_TARGET, "Transaction service shutting down"); - break; - } } } info!(target: LOG_TARGET, "Transaction service shut down"); diff --git a/base_layer/wallet/src/transaction_service/storage/database.rs b/base_layer/wallet/src/transaction_service/storage/database.rs index aaad0b0618..7cbaa52c85 100644 --- a/base_layer/wallet/src/transaction_service/storage/database.rs +++ b/base_layer/wallet/src/transaction_service/storage/database.rs @@ -43,8 +43,9 @@ use std::{ fmt::{Display, Error, Formatter}, sync::Arc, }; +use tari_common_types::types::BlindingFactor; use tari_comms::types::CommsPublicKey; -use tari_core::transactions::{tari_amount::MicroTari, transaction::Transaction, types::BlindingFactor}; +use tari_core::transactions::{tari_amount::MicroTari, transaction::Transaction}; const LOG_TARGET: &str = "wallet::transaction_service::database"; diff --git a/base_layer/wallet/src/transaction_service/storage/models.rs b/base_layer/wallet/src/transaction_service/storage/models.rs index 37f84cc3fb..4d1f57f238 100644 --- a/base_layer/wallet/src/transaction_service/storage/models.rs +++ b/base_layer/wallet/src/transaction_service/storage/models.rs @@ -27,11 +27,11 @@ use std::{ convert::TryFrom, fmt::{Display, Error, Formatter}, }; +use tari_common_types::types::PrivateKey; use 
tari_comms::types::CommsPublicKey; use tari_core::transactions::{ tari_amount::MicroTari, transaction::Transaction, - types::PrivateKey, ReceiverTransactionProtocol, SenderTransactionProtocol, }; diff --git a/base_layer/wallet/src/transaction_service/storage/sqlite_db.rs b/base_layer/wallet/src/transaction_service/storage/sqlite_db.rs index 0cd55fc7f2..700590562e 100644 --- a/base_layer/wallet/src/transaction_service/storage/sqlite_db.rs +++ b/base_layer/wallet/src/transaction_service/storage/sqlite_db.rs @@ -20,6 +20,26 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use std::{ + collections::HashMap, + convert::TryFrom, + str::from_utf8, + sync::{Arc, MutexGuard, RwLock}, +}; + +use aes_gcm::{self, aead::Error as AeadError, Aes256Gcm}; +use chrono::{NaiveDateTime, Utc}; +use diesel::{prelude::*, result::Error as DieselError, SqliteConnection}; +use log::*; +use tari_crypto::tari_utilities::{ + hex::{from_hex, Hex}, + ByteArray, +}; + +use tari_common_types::types::PublicKey; +use tari_comms::types::CommsPublicKey; +use tari_core::transactions::tari_amount::MicroTari; + use crate::{ output_manager_service::TxId, schema::{completed_transactions, inbound_transactions, outbound_transactions}, @@ -40,22 +60,6 @@ use crate::{ }, util::encryption::{decrypt_bytes_integral_nonce, encrypt_bytes_integral_nonce, Encryptable}, }; -use aes_gcm::{self, aead::Error as AeadError, Aes256Gcm}; -use chrono::{NaiveDateTime, Utc}; -use diesel::{prelude::*, result::Error as DieselError, SqliteConnection}; -use log::*; -use std::{ - collections::HashMap, - convert::TryFrom, - str::from_utf8, - sync::{Arc, MutexGuard, RwLock}, -}; -use tari_comms::types::CommsPublicKey; -use tari_core::transactions::{tari_amount::MicroTari, types::PublicKey}; -use tari_crypto::tari_utilities::{ - hex::{from_hex, Hex}, - ByteArray, -}; const LOG_TARGET: &str = 
"wallet::transaction_service::database::sqlite_db"; @@ -1650,6 +1654,34 @@ impl From for UpdateCompletedTransactionSql { #[cfg(test)] mod test { + use std::convert::TryFrom; + + use aes_gcm::{ + aead::{generic_array::GenericArray, NewAead}, + Aes256Gcm, + }; + use chrono::Utc; + use diesel::{Connection, SqliteConnection}; + use rand::rngs::OsRng; + use tari_crypto::{ + keys::{PublicKey as PublicKeyTrait, SecretKey as SecretKeyTrait}, + script, + script::{ExecutionStack, TariScript}, + }; + use tempfile::tempdir; + + use tari_common_types::types::{HashDigest, PrivateKey, PublicKey}; + use tari_core::transactions::{ + helpers::{create_unblinded_output, TestParams}, + tari_amount::MicroTari, + transaction::{OutputFeatures, Transaction}, + transaction_protocol::sender::TransactionSenderMessage, + CryptoFactories, + ReceiverTransactionProtocol, + SenderTransactionProtocol, + }; + use tari_test_utils::random::string; + use crate::{ storage::sqlite_utilities::WalletDbConnection, transaction_service::storage::{ @@ -1670,30 +1702,6 @@ mod test { }, util::encryption::Encryptable, }; - use aes_gcm::{ - aead::{generic_array::GenericArray, NewAead}, - Aes256Gcm, - }; - use chrono::Utc; - use diesel::{Connection, SqliteConnection}; - use rand::rngs::OsRng; - use std::convert::TryFrom; - use tari_core::transactions::{ - helpers::{create_unblinded_output, TestParams}, - tari_amount::MicroTari, - transaction::{OutputFeatures, Transaction}, - transaction_protocol::sender::TransactionSenderMessage, - types::{CryptoFactories, HashDigest, PrivateKey, PublicKey}, - ReceiverTransactionProtocol, - SenderTransactionProtocol, - }; - use tari_crypto::{ - keys::{PublicKey as PublicKeyTrait, SecretKey as SecretKeyTrait}, - script, - script::{ExecutionStack, TariScript}, - }; - use tari_test_utils::random::string; - use tempfile::tempdir; #[test] fn test_crud() { diff --git a/base_layer/wallet/src/transaction_service/tasks/start_transaction_validation_and_broadcast_protocols.rs 
b/base_layer/wallet/src/transaction_service/tasks/start_transaction_validation_and_broadcast_protocols.rs index 61b3ee7d75..522bacdbb9 100644 --- a/base_layer/wallet/src/transaction_service/tasks/start_transaction_validation_and_broadcast_protocols.rs +++ b/base_layer/wallet/src/transaction_service/tasks/start_transaction_validation_and_broadcast_protocols.rs @@ -27,8 +27,8 @@ use crate::{ }, types::ValidationRetryStrategy, }; -use futures::StreamExt; use log::*; +use tokio::sync::broadcast; const LOG_TARGET: &str = "wallet::transaction_service::tasks::start_tx_validation_and_broadcast"; @@ -36,16 +36,16 @@ pub async fn start_transaction_validation_and_broadcast_protocols( mut handle: TransactionServiceHandle, retry_strategy: ValidationRetryStrategy, ) -> Result<(), TransactionServiceError> { - let mut event_stream = handle.get_event_stream_fused(); + let mut event_stream = handle.get_event_stream(); let our_id = handle.validate_transactions(retry_strategy).await?; // Now that its started we will spawn an task to monitor the event bus and when its successful we will start the // Broadcast protocols tokio::spawn(async move { - while let Some(event_item) = event_stream.next().await { - if let Ok(event) = event_item { - match (*event).clone() { + loop { + match event_stream.recv().await { + Ok(event) => match &*event { TransactionEvent::TransactionValidationSuccess(_id) => { info!( target: LOG_TARGET, @@ -59,19 +59,28 @@ pub async fn start_transaction_validation_and_broadcast_protocols( } }, TransactionEvent::TransactionValidationFailure(id) => { - if our_id == id { + if our_id == *id { error!(target: LOG_TARGET, "Transaction Validation failed!"); break; } }, _ => (), - } - } else { - warn!( - target: LOG_TARGET, - "Error reading from Transaction Service Event Stream" - ); - break; + }, + Err(e @ broadcast::error::RecvError::Lagged(_)) => { + warn!( + target: LOG_TARGET, + "start_transaction_validation_and_broadcast_protocols: {}", e + ); + continue; + }, + 
Err(broadcast::error::RecvError::Closed) => { + debug!( + target: LOG_TARGET, + "start_transaction_validation_and_broadcast_protocols is exiting because the event stream \ + closed", + ); + break; + }, } } }); diff --git a/base_layer/wallet/src/util/mod.rs b/base_layer/wallet/src/util/mod.rs index 9664a0e376..7217ac5056 100644 --- a/base_layer/wallet/src/util/mod.rs +++ b/base_layer/wallet/src/util/mod.rs @@ -20,6 +20,4 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -pub mod emoji; pub mod encryption; -pub mod luhn; diff --git a/base_layer/wallet/src/utxo_scanner_service/mod.rs b/base_layer/wallet/src/utxo_scanner_service/mod.rs index b0b475b96b..956a32848b 100644 --- a/base_layer/wallet/src/utxo_scanner_service/mod.rs +++ b/base_layer/wallet/src/utxo_scanner_service/mod.rs @@ -33,7 +33,7 @@ use futures::future; use log::*; use std::{sync::Arc, time::Duration}; use tari_comms::{connectivity::ConnectivityRequester, NodeIdentity}; -use tari_core::transactions::types::CryptoFactories; +use tari_core::transactions::CryptoFactories; use tari_service_framework::{ async_trait, reply_channel, diff --git a/base_layer/wallet/src/utxo_scanner_service/utxo_scanning.rs b/base_layer/wallet/src/utxo_scanner_service/utxo_scanning.rs index 6aaa38363f..31955c981c 100644 --- a/base_layer/wallet/src/utxo_scanner_service/utxo_scanning.rs +++ b/base_layer/wallet/src/utxo_scanner_service/utxo_scanning.rs @@ -20,24 +20,6 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
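The event-stream loop rewritten above distinguishes two receiver errors from `tokio::sync::broadcast`: `RecvError::Lagged` (the receiver missed messages but the channel is alive, so log and keep listening) from `RecvError::Closed` (all senders dropped, so exit). A minimal stdlib sketch of that control flow, using a hypothetical `RecvError` stand-in for tokio's type:

```rust
#[derive(Debug, PartialEq)]
enum RecvError {
    Lagged(u64), // receiver fell behind by n messages; recoverable
    Closed,      // all senders dropped; terminal
}

// Drains a sequence of receive results the way the rewritten loop does:
// continue on Lagged, break on Closed, collect successful events.
fn drain(events: Vec<Result<&'static str, RecvError>>) -> Vec<&'static str> {
    let mut handled = Vec::new();
    for item in events {
        match item {
            Ok(event) => handled.push(event),
            Err(RecvError::Lagged(n)) => {
                eprintln!("event stream lagged by {} messages, continuing", n);
                continue;
            },
            Err(RecvError::Closed) => break,
        }
    }
    handled
}

fn main() {
    let seen = drain(vec![
        Ok("TransactionValidationSuccess"),
        Err(RecvError::Lagged(3)),
        Ok("TransactionValidationSuccess"),
        Err(RecvError::Closed),
        Ok("never reached"),
    ]);
    assert_eq!(seen.len(), 2);
    println!("handled {} events", seen.len());
}
```

The old `futures::StreamExt` loop treated every stream error as fatal; splitting the two variants means a slow consumer no longer tears down the broadcast-monitoring task.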
-use crate::{ - error::WalletError, - output_manager_service::{handle::OutputManagerHandle, TxId}, - storage::{ - database::{WalletBackend, WalletDatabase}, - sqlite_db::WalletSqliteDatabase, - }, - transaction_service::handle::TransactionServiceHandle, - utxo_scanner_service::{ - error::UtxoScannerError, - handle::{UtxoScannerEvent, UtxoScannerRequest, UtxoScannerResponse}, - }, - WalletSqlite, -}; -use chrono::Utc; -use futures::{pin_mut, StreamExt}; -use log::*; -use serde::{Deserialize, Serialize}; use std::{ convert::TryFrom, sync::{ @@ -46,6 +28,14 @@ use std::{ }, time::{Duration, Instant}, }; + +use chrono::Utc; +use futures::{pin_mut, StreamExt}; +use log::*; +use serde::{Deserialize, Serialize}; +use tokio::{sync::broadcast, task, time}; + +use tari_common_types::types::HashOutput; use tari_comms::{ connectivity::ConnectivityRequester, peer_manager::NodeId, @@ -64,12 +54,27 @@ use tari_core::{ transactions::{ tari_amount::MicroTari, transaction::{TransactionOutput, UnblindedOutput}, - types::{CryptoFactories, HashOutput}, + CryptoFactories, }, }; use tari_service_framework::{reply_channel, reply_channel::SenderService}; use tari_shutdown::ShutdownSignal; -use tokio::{sync::broadcast, task, time}; + +use crate::{ + error::WalletError, + output_manager_service::{handle::OutputManagerHandle, TxId}, + storage::{ + database::{WalletBackend, WalletDatabase}, + sqlite_db::WalletSqliteDatabase, + }, + transaction_service::handle::TransactionServiceHandle, + utxo_scanner_service::{ + error::UtxoScannerError, + handle::{UtxoScannerEvent, UtxoScannerRequest, UtxoScannerResponse}, + }, + WalletSqlite, +}; +use tokio::time::MissedTickBehavior; pub const LOG_TARGET: &str = "wallet::utxo_scanning"; @@ -715,35 +720,23 @@ where TBackend: WalletBackend + 'static let mut shutdown = self.shutdown_signal.clone(); let start_at = Instant::now() + Duration::from_secs(1); - let mut work_interval = time::interval_at(start_at.into(), self.scan_for_utxo_interval).fuse(); - let mut 
previous = Instant::now(); + let mut work_interval = time::interval_at(start_at.into(), self.scan_for_utxo_interval); + work_interval.set_missed_tick_behavior(MissedTickBehavior::Delay); loop { - futures::select! { - _ = work_interval.select_next_some() => { - // This bit of code prevents bottled up tokio interval events to be fired successively for the edge - // case where a computer wakes up from sleep. - if start_at.elapsed() > self.scan_for_utxo_interval && - previous.elapsed() < self.scan_for_utxo_interval.mul_f32(0.9) - { - debug!( - target: LOG_TARGET, - "UTXO scanning work interval event fired too quickly, not running the task" - ); - } else { - let running_flag = self.is_running.clone(); - if !running_flag.load(Ordering::SeqCst) { - let task = self.create_task(); - debug!(target: LOG_TARGET, "UTXO scanning service starting scan for utxos"); - task::spawn(async move { - if let Err(err) = task.run().await { - error!(target: LOG_TARGET, "Error scanning UTXOs: {}", err); - } - //we make sure the flag is set to false here - running_flag.store(false, Ordering::Relaxed); - }); - } + tokio::select! 
{ + _ = work_interval.tick() => { + let running_flag = self.is_running.clone(); + if !running_flag.load(Ordering::SeqCst) { + let task = self.create_task(); + debug!(target: LOG_TARGET, "UTXO scanning service starting scan for utxos"); + task::spawn(async move { + if let Err(err) = task.run().await { + error!(target: LOG_TARGET, "Error scanning UTXOs: {}", err); + } + //we make sure the flag is set to false here + running_flag.store(false, Ordering::Relaxed); + }); } - previous = Instant::now(); }, request_context = request_stream.select_next_some() => { trace!(target: LOG_TARGET, "Handling Service API Request"); @@ -757,7 +750,7 @@ where TBackend: WalletBackend + 'static e }); }, - _ = shutdown => { + _ = shutdown.wait() => { // this will stop the task if its running, and let that thread exit gracefully self.is_running.store(false, Ordering::Relaxed); info!(target: LOG_TARGET, "UTXO scanning service shutting down because it received the shutdown signal"); diff --git a/base_layer/wallet/src/wallet.rs b/base_layer/wallet/src/wallet.rs index 1f91f3d625..24e4573181 100644 --- a/base_layer/wallet/src/wallet.rs +++ b/base_layer/wallet/src/wallet.rs @@ -20,28 +20,8 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
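The scanner loop above gates each interval tick on an `Arc<AtomicBool>` so that at most one scan task runs at a time, with the flag stored back to `false` when the task finishes. A stdlib sketch of that guard (the patch checks `load()` then spawns; the sketch below uses `compare_exchange`, which additionally closes the check-then-set race, and `try_claim` is an illustrative name, not from the patch):

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::Duration;

// Attempt to claim the "scan in progress" flag; returns true only for the
// caller that flips it false -> true atomically.
fn try_claim(flag: &AtomicBool) -> bool {
    flag.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
        .is_ok()
}

fn main() {
    let running = Arc::new(AtomicBool::new(false));
    let mut spawned = 0;
    for _ in 0..5 {
        // Stand-in for a work_interval tick firing.
        if try_claim(&running) {
            spawned += 1;
            let flag = Arc::clone(&running);
            thread::spawn(move || {
                thread::sleep(Duration::from_millis(50)); // simulated UTXO scan
                flag.store(false, Ordering::SeqCst); // release the flag when done
            });
        }
    }
    // The 50 ms "scan" outlives the quick loop, so only one task starts.
    assert_eq!(spawned, 1);
    println!("spawned {} scan task(s)", spawned);
}
```

This also shows why the patch can drop the old `previous`/`start_at` elapsed-time bookkeeping: `MissedTickBehavior::Delay` makes the interval itself absorb a backlog of ticks (for example after the machine wakes from sleep), and the flag absorbs any tick that still arrives mid-scan.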
-use crate::{ - base_node_service::{handle::BaseNodeServiceHandle, BaseNodeServiceInitializer}, - config::{WalletConfig, KEY_MANAGER_COMMS_SECRET_KEY_BRANCH_KEY}, - connectivity_service::{WalletConnectivityHandle, WalletConnectivityInitializer}, - contacts_service::{handle::ContactsServiceHandle, storage::database::ContactsBackend, ContactsServiceInitializer}, - error::WalletError, - output_manager_service::{ - error::OutputManagerError, - handle::OutputManagerHandle, - storage::{database::OutputManagerBackend, models::KnownOneSidedPaymentScript}, - OutputManagerServiceInitializer, - TxId, - }, - storage::database::{WalletBackend, WalletDatabase}, - transaction_service::{ - handle::TransactionServiceHandle, - storage::database::TransactionBackend, - TransactionServiceInitializer, - }, - types::KeyDigest, - utxo_scanner_service::{handle::UtxoScannerHandle, UtxoScannerServiceInitializer}, -}; +use std::{marker::PhantomData, sync::Arc}; + use aes_gcm::{ aead::{generic_array::GenericArray, NewAead}, Aes256Gcm, @@ -49,7 +29,17 @@ use aes_gcm::{ use digest::Digest; use log::*; use rand::rngs::OsRng; -use std::{marker::PhantomData, sync::Arc}; +use tari_crypto::{ + common::Blake256, + keys::SecretKey, + ristretto::{RistrettoPublicKey, RistrettoSchnorr, RistrettoSecretKey}, + script, + script::{ExecutionStack, TariScript}, + signatures::{SchnorrSignature, SchnorrSignatureError}, + tari_utilities::hex::Hex, +}; + +use tari_common_types::types::{ComSignature, PrivateKey, PublicKey}; use tari_comms::{ multiaddr::Multiaddr, peer_manager::{NodeId, Peer, PeerFeatures, PeerFlags}, @@ -62,22 +52,35 @@ use tari_comms_dht::{store_forward::StoreAndForwardRequester, Dht}; use tari_core::transactions::{ tari_amount::MicroTari, transaction::{OutputFeatures, UnblindedOutput}, - types::{ComSignature, CryptoFactories, PrivateKey, PublicKey}, -}; -use tari_crypto::{ - common::Blake256, - keys::SecretKey, - ristretto::{RistrettoPublicKey, RistrettoSchnorr, RistrettoSecretKey}, - script, - 
script::{ExecutionStack, TariScript}, - signatures::{SchnorrSignature, SchnorrSignatureError}, - tari_utilities::hex::Hex, + CryptoFactories, }; use tari_key_manager::key_manager::KeyManager; use tari_p2p::{comms_connector::pubsub_connector, initialization, initialization::P2pInitializer}; use tari_service_framework::StackBuilder; use tari_shutdown::ShutdownSignal; -use tokio::runtime; + +use crate::{ + base_node_service::{handle::BaseNodeServiceHandle, BaseNodeServiceInitializer}, + config::{WalletConfig, KEY_MANAGER_COMMS_SECRET_KEY_BRANCH_KEY}, + connectivity_service::{WalletConnectivityHandle, WalletConnectivityInitializer}, + contacts_service::{handle::ContactsServiceHandle, storage::database::ContactsBackend, ContactsServiceInitializer}, + error::WalletError, + output_manager_service::{ + error::OutputManagerError, + handle::OutputManagerHandle, + storage::{database::OutputManagerBackend, models::KnownOneSidedPaymentScript}, + OutputManagerServiceInitializer, + TxId, + }, + storage::database::{WalletBackend, WalletDatabase}, + transaction_service::{ + handle::TransactionServiceHandle, + storage::database::TransactionBackend, + TransactionServiceInitializer, + }, + types::KeyDigest, + utxo_scanner_service::{handle::UtxoScannerHandle, UtxoScannerServiceInitializer}, +}; const LOG_TARGET: &str = "wallet"; @@ -139,8 +142,7 @@ where let bn_service_db = wallet_database.clone(); let factories = config.clone().factories; - let (publisher, subscription_factory) = - pubsub_connector(runtime::Handle::current(), config.buffer_size, config.rate_limit); + let (publisher, subscription_factory) = pubsub_connector(config.buffer_size, config.rate_limit); let peer_message_subscription_factory = Arc::new(subscription_factory); let transport_type = config.comms_config.transport_type.clone(); diff --git a/base_layer/wallet/tests/contacts_service/mod.rs b/base_layer/wallet/tests/contacts_service/mod.rs index 80970c6a17..ed5ad5033c 100644 --- 
a/base_layer/wallet/tests/contacts_service/mod.rs +++ b/base_layer/wallet/tests/contacts_service/mod.rs @@ -22,7 +22,7 @@ use crate::support::data::get_temp_sqlite_database_connection; use rand::rngs::OsRng; -use tari_core::transactions::types::PublicKey; +use tari_common_types::types::PublicKey; use tari_crypto::keys::PublicKey as PublicKeyTrait; use tari_service_framework::StackBuilder; use tari_shutdown::Shutdown; diff --git a/base_layer/wallet/tests/output_manager_service/service.rs b/base_layer/wallet/tests/output_manager_service/service.rs index c6c23da53e..ceb4ed4c8b 100644 --- a/base_layer/wallet/tests/output_manager_service/service.rs +++ b/base_layer/wallet/tests/output_manager_service/service.rs @@ -19,18 +19,18 @@ // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
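The test hunks that follow convert sync `#[test]` fns that drove futures through `runtime.block_on(...)` into async `#[tokio::test]` fns that simply `.await`. A self-contained sketch of the two shapes, using a hand-rolled `block_on` (sufficient only for immediately-ready futures) to stand in for the tokio runtime, and a hypothetical `fee_estimate` in place of the real service call:

```rust
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};

// Minimal executor for this sketch: polls with a no-op waker until ready.
fn block_on<F: Future>(mut fut: F) -> F::Output {
    fn raw() -> RawWaker {
        fn no_op(_: *const ()) {}
        fn clone(_: *const ()) -> RawWaker {
            raw()
        }
        static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, no_op, no_op, no_op);
        RawWaker::new(std::ptr::null(), &VTABLE)
    }
    let waker = unsafe { Waker::from_raw(raw()) };
    let mut cx = Context::from_waker(&waker);
    let mut fut = unsafe { Pin::new_unchecked(&mut fut) };
    loop {
        if let Poll::Ready(v) = fut.as_mut().poll(&mut cx) {
            return v;
        }
    }
}

// Hypothetical async service call; the real one lives on OutputManagerHandle.
async fn fee_estimate(amount: u64, fee_per_gram: u64) -> u64 {
    amount / 100 + fee_per_gram
}

fn main() {
    // Old style: a sync test owns a Runtime and drives every call explicitly.
    let fee_old = block_on(fee_estimate(100, 25));
    // New style (#[tokio::test]): the test fn is itself async and just awaits.
    let fee_new = block_on(async { fee_estimate(100, 25).await });
    assert_eq!(fee_old, fee_new);
    assert_eq!(fee_new, 26);
    println!("fee = {}", fee_new);
}
```

Beyond brevity, the conversion removes the `&mut Runtime` parameter threaded through every test helper, which is why `setup_output_manager_service` and `setup_oms_with_bn_state` become `async fn`s in these hunks.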
- use crate::support::{ data::get_temp_sqlite_database_connection, rpc::{BaseNodeWalletRpcMockService, BaseNodeWalletRpcMockState}, utils::{make_input, make_input_with_features, TestParams}, }; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt; use rand::{rngs::OsRng, RngCore}; -use std::{sync::Arc, thread, time::Duration}; +use std::{sync::Arc, time::Duration}; +use tari_common_types::types::{PrivateKey, PublicKey}; use tari_comms::{ peer_manager::{NodeIdentity, PeerFeatures}, - protocol::rpc::{mock::MockRpcServer, NamedProtocolService, RpcStatus}, + protocol::rpc::{mock::MockRpcServer, NamedProtocolService, RpcClientConfig, RpcStatus}, test_utils::{ mocks::{create_connectivity_mock, ConnectivityManagerMockState}, node_identity::build_node_identity, @@ -51,12 +51,12 @@ use tari_core::{ sender::TransactionSenderMessage, single_receiver::SingleReceiverTransactionProtocol, }, - types::{CryptoFactories, PrivateKey, PublicKey}, + CryptoFactories, SenderTransactionProtocol, }, }; use tari_crypto::{ - hash::blake2::Blake256, + common::Blake256, inputs, keys::{PublicKey as PublicKeyTrait, SecretKey}, script, @@ -74,7 +74,7 @@ use tari_wallet::{ service::OutputManagerService, storage::{ database::{DbKey, DbKeyValuePair, DbValue, OutputManagerBackend, OutputManagerDatabase, WriteOperation}, - models::DbUnblindedOutput, + models::{DbUnblindedOutput, OutputStatus}, sqlite_db::OutputManagerSqliteDatabase, }, TxId, @@ -83,18 +83,14 @@ use tari_wallet::{ transaction_service::handle::TransactionServiceHandle, types::ValidationRetryStrategy, }; - -use tari_comms::protocol::rpc::RpcClientConfig; -use tari_wallet::output_manager_service::storage::models::OutputStatus; use tokio::{ - runtime::Runtime, sync::{broadcast, broadcast::channel}, - time::delay_for, + task, + time, }; #[allow(clippy::type_complexity)] -pub fn setup_output_manager_service( - runtime: &mut Runtime, +async fn setup_output_manager_service( backend: T, with_connection: bool, ) -> ( @@ -124,11 +120,11 
@@ pub fn setup_output_manager_service( let basenode_service_handle = BaseNodeServiceHandle::new(sender, event_publisher_bns); let mut mock_base_node_service = MockBaseNodeService::new(receiver_bns, shutdown.to_signal()); mock_base_node_service.set_default_base_node_state(); - runtime.spawn(mock_base_node_service.run()); + task::spawn(mock_base_node_service.run()); let (connectivity_manager, connectivity_mock) = create_connectivity_mock(); let connectivity_mock_state = connectivity_mock.get_shared_state(); - runtime.spawn(connectivity_mock.run()); + task::spawn(connectivity_mock.run()); let service = BaseNodeWalletRpcMockService::new(); let rpc_service_state = service.get_state(); @@ -137,43 +133,39 @@ pub fn setup_output_manager_service( let protocol_name = server.as_protocol_name(); let server_node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); - let mut mock_server = runtime - .handle() - .enter(|| MockRpcServer::new(server, server_node_identity.clone())); + let mut mock_server = MockRpcServer::new(server, server_node_identity.clone()); - runtime.handle().enter(|| mock_server.serve()); + mock_server.serve(); if with_connection { - let connection = runtime.block_on(async { - mock_server - .create_connection(server_node_identity.to_peer(), protocol_name.into()) - .await - }); - runtime.block_on(connectivity_mock_state.add_active_connection(connection)); + let connection = mock_server + .create_connection(server_node_identity.to_peer(), protocol_name.into()) + .await; + connectivity_mock_state.add_active_connection(connection).await; } - let output_manager_service = runtime - .block_on(OutputManagerService::new( - OutputManagerServiceConfig { - base_node_query_timeout: Duration::from_secs(10), - max_utxo_query_size: 2, - peer_dial_retry_timeout: Duration::from_secs(5), - ..Default::default() - }, - ts_handle.clone(), - oms_request_receiver, - OutputManagerDatabase::new(backend), - oms_event_publisher.clone(), - factories, - constants, - 
shutdown.to_signal(), - basenode_service_handle, - connectivity_manager, - CommsSecretKey::default(), - )) - .unwrap(); + let output_manager_service = OutputManagerService::new( + OutputManagerServiceConfig { + base_node_query_timeout: Duration::from_secs(10), + max_utxo_query_size: 2, + peer_dial_retry_timeout: Duration::from_secs(5), + ..Default::default() + }, + ts_handle.clone(), + oms_request_receiver, + OutputManagerDatabase::new(backend), + oms_event_publisher.clone(), + factories, + constants, + shutdown.to_signal(), + basenode_service_handle, + connectivity_manager, + CommsSecretKey::default(), + ) + .await + .unwrap(); let output_manager_service_handle = OutputManagerHandle::new(oms_request_sender, oms_event_publisher); - runtime.spawn(async move { output_manager_service.start().await.unwrap() }); + task::spawn(async move { output_manager_service.start().await.unwrap() }); ( output_manager_service_handle, @@ -218,8 +210,7 @@ async fn complete_transaction(mut stp: SenderTransactionProtocol, mut oms: Outpu stp.get_transaction().unwrap().clone() } -pub fn setup_oms_with_bn_state( - runtime: &mut Runtime, +pub async fn setup_oms_with_bn_state( backend: T, height: Option, ) -> ( @@ -246,35 +237,35 @@ pub fn setup_oms_with_bn_state( let base_node_service_handle = BaseNodeServiceHandle::new(sender, event_publisher_bns); let mut mock_base_node_service = MockBaseNodeService::new(receiver_bns, shutdown.to_signal()); mock_base_node_service.set_base_node_state(height); - runtime.spawn(mock_base_node_service.run()); + task::spawn(mock_base_node_service.run()); let (connectivity_manager, connectivity_mock) = create_connectivity_mock(); let _connectivity_mock_state = connectivity_mock.get_shared_state(); - runtime.spawn(connectivity_mock.run()); - - let output_manager_service = runtime - .block_on(OutputManagerService::new( - OutputManagerServiceConfig { - base_node_query_timeout: Duration::from_secs(10), - max_utxo_query_size: 2, - peer_dial_retry_timeout: 
Duration::from_secs(5), - ..Default::default() - }, - ts_handle.clone(), - oms_request_receiver, - OutputManagerDatabase::new(backend), - oms_event_publisher.clone(), - factories, - constants, - shutdown.to_signal(), - base_node_service_handle.clone(), - connectivity_manager, - CommsSecretKey::default(), - )) - .unwrap(); + task::spawn(connectivity_mock.run()); + + let output_manager_service = OutputManagerService::new( + OutputManagerServiceConfig { + base_node_query_timeout: Duration::from_secs(10), + max_utxo_query_size: 2, + peer_dial_retry_timeout: Duration::from_secs(5), + ..Default::default() + }, + ts_handle.clone(), + oms_request_receiver, + OutputManagerDatabase::new(backend), + oms_event_publisher.clone(), + factories, + constants, + shutdown.to_signal(), + base_node_service_handle.clone(), + connectivity_manager, + CommsSecretKey::default(), + ) + .await + .unwrap(); let output_manager_service_handle = OutputManagerHandle::new(oms_request_sender, oms_event_publisher); - runtime.spawn(async move { output_manager_service.start().await.unwrap() }); + task::spawn(async move { output_manager_service.start().await.unwrap() }); ( output_manager_service_handle, @@ -321,63 +312,65 @@ fn generate_sender_transaction_message(amount: MicroTari) -> (TxId, TransactionS ) } -#[test] -fn fee_estimate() { +#[tokio::test] +async fn fee_estimate() { let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let (_, uo) = make_input(&mut OsRng.clone(), MicroTari::from(3000), &factories.commitment); - runtime.block_on(oms.add_output(uo)).unwrap(); + oms.add_output(uo).await.unwrap(); // minimum fee let fee_per_gram = 
MicroTari::from(1); - let fee = runtime - .block_on(oms.fee_estimate(MicroTari::from(100), fee_per_gram, 1, 1)) + let fee = oms + .fee_estimate(MicroTari::from(100), fee_per_gram, 1, 1) + .await .unwrap(); assert_eq!(fee, MicroTari::from(100)); let fee_per_gram = MicroTari::from(25); for outputs in 1..5 { - let fee = runtime - .block_on(oms.fee_estimate(MicroTari::from(100), fee_per_gram, 1, outputs)) + let fee = oms + .fee_estimate(MicroTari::from(100), fee_per_gram, 1, outputs) + .await .unwrap(); assert_eq!(fee, Fee::calculate(fee_per_gram, 1, 1, outputs as usize)); } // not enough funds - let err = runtime - .block_on(oms.fee_estimate(MicroTari::from(2750), fee_per_gram, 1, 1)) + let err = oms + .fee_estimate(MicroTari::from(2750), fee_per_gram, 1, 1) + .await .unwrap_err(); assert!(matches!(err, OutputManagerError::NotEnoughFunds)); } #[allow(clippy::identity_op)] -#[test] -fn test_utxo_selection_no_chain_metadata() { +#[tokio::test] +async fn test_utxo_selection_no_chain_metadata() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); // no chain metadata let (mut oms, _shutdown, _, _) = - setup_oms_with_bn_state(&mut runtime, OutputManagerSqliteDatabase::new(connection, None), None); + setup_oms_with_bn_state(OutputManagerSqliteDatabase::new(connection, None), None).await; // no utxos - not enough funds let amount = MicroTari::from(1000); let fee_per_gram = MicroTari::from(10); - let err = runtime - .block_on(oms.prepare_transaction_to_send( + let err = oms + .prepare_transaction_to_send( OsRng.next_u64(), amount, fee_per_gram, None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap_err(); assert!(matches!(err, OutputManagerError::NotEnoughFunds)); @@ -389,24 +382,25 @@ fn test_utxo_selection_no_chain_metadata() { &factories.commitment, Some(OutputFeatures::with_maturity(i)), ); - runtime.block_on(oms.add_output(uo.clone())).unwrap(); + 
oms.add_output(uo.clone()).await.unwrap(); } // but we have no chain state so the lowest maturity should be used - let stp = runtime - .block_on(oms.prepare_transaction_to_send( + let stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), amount, fee_per_gram, None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); assert!(stp.get_tx_id().is_ok()); // test that lowest 2 maturities were encumbered - let utxos = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let utxos = oms.get_unspent_outputs().await.unwrap(); assert_eq!(utxos.len(), 8); for (index, utxo) in utxos.iter().enumerate() { let i = index as u64 + 3; @@ -415,34 +409,31 @@ fn test_utxo_selection_no_chain_metadata() { } // test that we can get a fee estimate with no chain metadata - let fee = runtime.block_on(oms.fee_estimate(amount, fee_per_gram, 1, 2)).unwrap(); + let fee = oms.fee_estimate(amount, fee_per_gram, 1, 2).await.unwrap(); assert_eq!(fee, MicroTari::from(300)); // test if a fee estimate would be possible with pending funds included // at this point 52000 uT is still spendable, with pending change incoming of 1690 uT // so instead of returning "not enough funds", return "funds pending" let spendable_amount = (3..=10).sum::() * amount; - let err = runtime - .block_on(oms.fee_estimate(spendable_amount, fee_per_gram, 1, 2)) + let err = oms + .fee_estimate(spendable_amount, fee_per_gram, 1, 2) + .await .unwrap_err(); assert!(matches!(err, OutputManagerError::FundsPending)); // test not enough funds let broke_amount = spendable_amount + MicroTari::from(2000); - let err = runtime - .block_on(oms.fee_estimate(broke_amount, fee_per_gram, 1, 2)) - .unwrap_err(); + let err = oms.fee_estimate(broke_amount, fee_per_gram, 1, 2).await.unwrap_err(); assert!(matches!(err, OutputManagerError::NotEnoughFunds)); // coin split uses the "Largest" selection strategy - let (_, _, fee, utxos_total_value) = runtime - .block_on(oms.create_coin_split(amount, 5, fee_per_gram, None)) - .unwrap(); + let 
(_, _, fee, utxos_total_value) = oms.create_coin_split(amount, 5, fee_per_gram, None).await.unwrap(); assert_eq!(fee, MicroTari::from(820)); assert_eq!(utxos_total_value, MicroTari::from(10_000)); // test that largest utxo was encumbered - let utxos = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let utxos = oms.get_unspent_outputs().await.unwrap(); assert_eq!(utxos.len(), 7); for (index, utxo) in utxos.iter().enumerate() { let i = index as u64 + 3; @@ -452,31 +443,28 @@ fn test_utxo_selection_no_chain_metadata() { } #[allow(clippy::identity_op)] -#[test] -fn test_utxo_selection_with_chain_metadata() { +#[tokio::test] +async fn test_utxo_selection_with_chain_metadata() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); // setup with chain metadata at a height of 6 - let (mut oms, _shutdown, _, _) = setup_oms_with_bn_state( - &mut runtime, - OutputManagerSqliteDatabase::new(connection, None), - Some(6), - ); + let (mut oms, _shutdown, _, _) = + setup_oms_with_bn_state(OutputManagerSqliteDatabase::new(connection, None), Some(6)).await; // no utxos - not enough funds let amount = MicroTari::from(1000); let fee_per_gram = MicroTari::from(10); - let err = runtime - .block_on(oms.prepare_transaction_to_send( + let err = oms + .prepare_transaction_to_send( OsRng.next_u64(), amount, fee_per_gram, None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap_err(); assert!(matches!(err, OutputManagerError::NotEnoughFunds)); @@ -488,52 +476,52 @@ fn test_utxo_selection_with_chain_metadata() { &factories.commitment, Some(OutputFeatures::with_maturity(i)), ); - runtime.block_on(oms.add_output(uo.clone())).unwrap(); + oms.add_output(uo.clone()).await.unwrap(); } - let utxos = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let utxos = oms.get_unspent_outputs().await.unwrap(); assert_eq!(utxos.len(), 10); // test fee estimates - let fee = 
runtime.block_on(oms.fee_estimate(amount, fee_per_gram, 1, 2)).unwrap(); + let fee = oms.fee_estimate(amount, fee_per_gram, 1, 2).await.unwrap(); assert_eq!(fee, MicroTari::from(310)); // test fee estimates are maturity aware // even though we have utxos for the fee, they can't be spent because they are not mature yet let spendable_amount = (1..=6).sum::() * amount; - let err = runtime - .block_on(oms.fee_estimate(spendable_amount, fee_per_gram, 1, 2)) + let err = oms + .fee_estimate(spendable_amount, fee_per_gram, 1, 2) + .await .unwrap_err(); assert!(matches!(err, OutputManagerError::NotEnoughFunds)); // test coin split is maturity aware - let (_, _, fee, utxos_total_value) = runtime - .block_on(oms.create_coin_split(amount, 5, fee_per_gram, None)) - .unwrap(); + let (_, _, fee, utxos_total_value) = oms.create_coin_split(amount, 5, fee_per_gram, None).await.unwrap(); assert_eq!(utxos_total_value, MicroTari::from(6_000)); assert_eq!(fee, MicroTari::from(820)); // test that largest spendable utxo was encumbered - let utxos = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let utxos = oms.get_unspent_outputs().await.unwrap(); assert_eq!(utxos.len(), 9); let found = utxos.iter().any(|u| u.value == 6 * amount); assert!(!found, "An unspendable utxo was selected"); // test transactions - let stp = runtime - .block_on(oms.prepare_transaction_to_send( + let stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), amount, fee_per_gram, None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); assert!(stp.get_tx_id().is_ok()); // test that utxos with the lowest 2 maturities were encumbered - let utxos = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let utxos = oms.get_unspent_outputs().await.unwrap(); assert_eq!(utxos.len(), 7); for utxo in utxos.iter() { assert_ne!(utxo.features.maturity, 1); @@ -543,20 +531,21 @@ fn test_utxo_selection_with_chain_metadata() { } // when the amount is greater than the largest utxo, then "Largest" selection 
strategy is used - let stp = runtime - .block_on(oms.prepare_transaction_to_send( + let stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), 6 * amount, fee_per_gram, None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); assert!(stp.get_tx_id().is_ok()); // test that utxos with the highest spendable 2 maturities were encumbered - let utxos = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let utxos = oms.get_unspent_outputs().await.unwrap(); assert_eq!(utxos.len(), 5); for utxo in utxos.iter() { assert_ne!(utxo.features.maturity, 4); @@ -566,22 +555,21 @@ fn test_utxo_selection_with_chain_metadata() { } } -#[test] -fn sending_transaction_and_confirmation() { +#[tokio::test] +async fn sending_transaction_and_confirmation() { let factories = CryptoFactories::default(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let mut runtime = Runtime::new().unwrap(); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend.clone(), true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend.clone(), true).await; let (_ti, uo) = make_input( &mut OsRng.clone(), MicroTari::from(100 + OsRng.next_u64() % 1000), &factories.commitment, ); - runtime.block_on(oms.add_output(uo.clone())).unwrap(); - match runtime.block_on(oms.add_output(uo)) { + oms.add_output(uo.clone()).await.unwrap(); + match oms.add_output(uo).await { Err(OutputManagerError::OutputManagerStorageError(OutputManagerStorageError::DuplicateOutput)) => {}, _ => panic!("Incorrect error message"), }; @@ -592,25 +580,26 @@ fn sending_transaction_and_confirmation() { MicroTari::from(100 + OsRng.next_u64() % 1000), &factories.commitment, ); - runtime.block_on(oms.add_output(uo)).unwrap(); + oms.add_output(uo).await.unwrap(); } - let stp = runtime - .block_on(oms.prepare_transaction_to_send( + let stp = oms + .prepare_transaction_to_send( 
OsRng.next_u64(), MicroTari::from(1000), MicroTari::from(20), None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); let sender_tx_id = stp.get_tx_id().unwrap(); - let tx = runtime.block_on(complete_transaction(stp, oms.clone())); + let tx = complete_transaction(stp, oms.clone()).await; - let rewind_public_keys = runtime.block_on(oms.get_rewind_public_keys()).unwrap(); + let rewind_public_keys = oms.get_rewind_public_keys().await.unwrap(); // 1 of the 2 outputs should be rewindable, there should be 2 outputs due to change but if we get unlucky enough // that there is no change we will skip this aspect of the test @@ -643,23 +632,23 @@ fn sending_transaction_and_confirmation() { assert_eq!(num_rewound, 1, "Should only be 1 rewindable output"); } - runtime - .block_on(oms.confirm_transaction(sender_tx_id, tx.body.inputs().clone(), tx.body.outputs().clone())) + oms.confirm_transaction(sender_tx_id, tx.body.inputs().clone(), tx.body.outputs().clone()) + .await .unwrap(); assert_eq!( - runtime.block_on(oms.get_pending_transactions()).unwrap().len(), + oms.get_pending_transactions().await.unwrap().len(), 0, "Should have no pending tx" ); assert_eq!( - runtime.block_on(oms.get_spent_outputs()).unwrap().len(), + oms.get_spent_outputs().await.unwrap().len(), tx.body.inputs().len(), "# Outputs should equal number of sent inputs" ); assert_eq!( - runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), - num_outputs + 1 - runtime.block_on(oms.get_spent_outputs()).unwrap().len() + tx.body.outputs().len() - 1, + oms.get_unspent_outputs().await.unwrap().len(), + num_outputs + 1 - oms.get_spent_outputs().await.unwrap().len() + tx.body.outputs().len() - 1, "Unspent outputs" ); @@ -675,16 +664,14 @@ fn sending_transaction_and_confirmation() { } } -#[test] -fn send_not_enough_funds() { +#[tokio::test] +async fn send_not_enough_funds() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); - let (connection, _tempdir) = 
get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let num_outputs = 20; for _i in 0..num_outputs { let (_ti, uo) = make_input( @@ -692,68 +679,70 @@ fn send_not_enough_funds() { MicroTari::from(100 + OsRng.next_u64() % 1000), &factories.commitment, ); - runtime.block_on(oms.add_output(uo)).unwrap(); + oms.add_output(uo).await.unwrap(); } - match runtime.block_on(oms.prepare_transaction_to_send( - OsRng.next_u64(), - MicroTari::from(num_outputs * 2000), - MicroTari::from(20), - None, - "".to_string(), - script!(Nop), - )) { + match oms + .prepare_transaction_to_send( + OsRng.next_u64(), + MicroTari::from(num_outputs * 2000), + MicroTari::from(20), + None, + "".to_string(), + script!(Nop), + ) + .await + { Err(OutputManagerError::NotEnoughFunds) => {}, _ => panic!(), } } -#[test] -fn send_no_change() { +#[tokio::test] +async fn send_no_change() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); - let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let fee_per_gram = MicroTari::from(20); let fee_without_change = Fee::calculate(fee_per_gram, 1, 2, 1); let value1 = 500; - runtime - .block_on(oms.add_output(create_unblinded_output( - script!(Nop), - OutputFeatures::default(), - TestParamsHelpers::new(), - MicroTari::from(value1), - ))) - .unwrap(); + oms.add_output(create_unblinded_output( + script!(Nop), + OutputFeatures::default(), + TestParamsHelpers::new(), + MicroTari::from(value1), + )) + .await + 
.unwrap(); let value2 = 800; - runtime - .block_on(oms.add_output(create_unblinded_output( - script!(Nop), - OutputFeatures::default(), - TestParamsHelpers::new(), - MicroTari::from(value2), - ))) - .unwrap(); + oms.add_output(create_unblinded_output( + script!(Nop), + OutputFeatures::default(), + TestParamsHelpers::new(), + MicroTari::from(value2), + )) + .await + .unwrap(); - let mut stp = runtime - .block_on(oms.prepare_transaction_to_send( + let mut stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), MicroTari::from(value1 + value2) - fee_without_change, fee_per_gram, None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); let sender_tx_id = stp.get_tx_id().unwrap(); assert_eq!(stp.get_amount_to_self().unwrap(), MicroTari::from(0)); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 1); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 1); let msg = stp.build_single_round_message().unwrap(); @@ -776,99 +765,91 @@ fn send_no_change() { let tx = stp.get_transaction().unwrap(); - runtime - .block_on(oms.confirm_transaction(sender_tx_id, tx.body.inputs().clone(), tx.body.outputs().clone())) + oms.confirm_transaction(sender_tx_id, tx.body.inputs().clone(), tx.body.outputs().clone()) + .await .unwrap(); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 0); - assert_eq!( - runtime.block_on(oms.get_spent_outputs()).unwrap().len(), - tx.body.inputs().len() - ); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), 0); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 0); + assert_eq!(oms.get_spent_outputs().await.unwrap().len(), tx.body.inputs().len()); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), 0); } -#[test] -fn send_not_enough_for_change() { - let mut runtime = Runtime::new().unwrap(); - +#[tokio::test] +async fn send_not_enough_for_change() { let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = 
OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let fee_per_gram = MicroTari::from(20); let fee_without_change = Fee::calculate(fee_per_gram, 1, 2, 1); let value1 = 500; - runtime - .block_on(oms.add_output(create_unblinded_output( - TariScript::default(), - OutputFeatures::default(), - TestParamsHelpers::new(), - MicroTari::from(value1), - ))) - .unwrap(); + oms.add_output(create_unblinded_output( + TariScript::default(), + OutputFeatures::default(), + TestParamsHelpers::new(), + MicroTari::from(value1), + )) + .await + .unwrap(); let value2 = 800; - runtime - .block_on(oms.add_output(create_unblinded_output( - TariScript::default(), - OutputFeatures::default(), - TestParamsHelpers::new(), - MicroTari::from(value2), - ))) - .unwrap(); + oms.add_output(create_unblinded_output( + TariScript::default(), + OutputFeatures::default(), + TestParamsHelpers::new(), + MicroTari::from(value2), + )) + .await + .unwrap(); - match runtime.block_on(oms.prepare_transaction_to_send( - OsRng.next_u64(), - MicroTari::from(value1 + value2 + 1) - fee_without_change, - MicroTari::from(20), - None, - "".to_string(), - script!(Nop), - )) { + match oms + .prepare_transaction_to_send( + OsRng.next_u64(), + MicroTari::from(value1 + value2 + 1) - fee_without_change, + MicroTari::from(20), + None, + "".to_string(), + script!(Nop), + ) + .await + { Err(OutputManagerError::NotEnoughFunds) => {}, _ => panic!(), } } -#[test] -fn receiving_and_confirmation() { - let mut runtime = Runtime::new().unwrap(); - +#[tokio::test] +async fn receiving_and_confirmation() { let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + 
let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let value = MicroTari::from(5000); let (tx_id, sender_message) = generate_sender_transaction_message(value); - let rtp = runtime.block_on(oms.get_recipient_transaction(sender_message)).unwrap(); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), 0); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 1); + let rtp = oms.get_recipient_transaction(sender_message).await.unwrap(); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), 0); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 1); let output = match rtp.state { RecipientState::Finalized(s) => s.output, RecipientState::Failed(_) => panic!("Should not be in Failed state"), }; - runtime - .block_on(oms.confirm_transaction(tx_id, vec![], vec![output])) - .unwrap(); + oms.confirm_transaction(tx_id, vec![], vec![output]).await.unwrap(); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 0); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), 1); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 0); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), 1); } -#[test] -fn cancel_transaction() { +#[tokio::test] +async fn cancel_transaction() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); - let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let num_outputs = 20; for _i in 0..num_outputs { @@ -877,46 +858,43 @@ fn cancel_transaction() { MicroTari::from(100 + OsRng.next_u64() % 1000), &factories.commitment, ); - runtime.block_on(oms.add_output(uo)).unwrap(); + 
oms.add_output(uo).await.unwrap(); } - let stp = runtime - .block_on(oms.prepare_transaction_to_send( + let stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), MicroTari::from(1000), MicroTari::from(20), None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); - match runtime.block_on(oms.cancel_transaction(1)) { + match oms.cancel_transaction(1).await { Err(OutputManagerError::OutputManagerStorageError(OutputManagerStorageError::ValueNotFound)) => {}, _ => panic!("Value should not exist"), } - runtime - .block_on(oms.cancel_transaction(stp.get_tx_id().unwrap())) - .unwrap(); + oms.cancel_transaction(stp.get_tx_id().unwrap()).await.unwrap(); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), num_outputs); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), num_outputs); } -#[test] -fn cancel_transaction_and_reinstate_inbound_tx() { - let mut runtime = Runtime::new().unwrap(); - +#[tokio::test] +async fn cancel_transaction_and_reinstate_inbound_tx() { let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend.clone(), true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend.clone(), true).await; let value = MicroTari::from(5000); let (tx_id, sender_message) = generate_sender_transaction_message(value); - let _rtp = runtime.block_on(oms.get_recipient_transaction(sender_message)).unwrap(); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), 0); + let _rtp = oms.get_recipient_transaction(sender_message).await.unwrap(); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), 0); - let pending_txs = runtime.block_on(oms.get_pending_transactions()).unwrap(); + let pending_txs = oms.get_pending_transactions().await.unwrap(); assert_eq!(pending_txs.len(), 1); @@ -928,7 +906,7 @@ fn 
cancel_transaction_and_reinstate_inbound_tx() { .unwrap() .clone(); - runtime.block_on(oms.cancel_transaction(tx_id)).unwrap(); + oms.cancel_transaction(tx_id).await.unwrap(); let cancelled_output = backend .fetch(&DbKey::OutputsByTxIdAndStatus(tx_id, OutputStatus::CancelledInbound)) @@ -942,28 +920,25 @@ fn cancel_transaction_and_reinstate_inbound_tx() { panic!("Should have found cancelled output"); } - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 0); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 0); - runtime - .block_on(oms.reinstate_cancelled_inbound_transaction(tx_id)) - .unwrap(); + oms.reinstate_cancelled_inbound_transaction(tx_id).await.unwrap(); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 1); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 1); - let balance = runtime.block_on(oms.get_balance()).unwrap(); + let balance = oms.get_balance().await.unwrap(); assert_eq!(balance.pending_incoming_balance, value); } -#[test] -fn timeout_transaction() { +#[tokio::test] +async fn timeout_transaction() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let num_outputs = 20; for _i in 0..num_outputs { @@ -972,50 +947,43 @@ fn timeout_transaction() { MicroTari::from(100 + OsRng.next_u64() % 1000), &factories.commitment, ); - runtime.block_on(oms.add_output(uo)).unwrap(); + oms.add_output(uo).await.unwrap(); } - let _stp = runtime - .block_on(oms.prepare_transaction_to_send( + let _stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), MicroTari::from(1000), MicroTari::from(20), None, "".to_string(), 
script!(Nop), - )) + ) + .await .unwrap(); - let remaining_outputs = runtime.block_on(oms.get_unspent_outputs()).unwrap().len(); + let remaining_outputs = oms.get_unspent_outputs().await.unwrap().len(); - thread::sleep(Duration::from_millis(2)); + time::sleep(Duration::from_millis(2)).await; - runtime - .block_on(oms.timeout_transactions(Duration::from_millis(1000))) - .unwrap(); + oms.timeout_transactions(Duration::from_millis(1000)).await.unwrap(); - assert_eq!( - runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), - remaining_outputs - ); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), remaining_outputs); - runtime - .block_on(oms.timeout_transactions(Duration::from_millis(1))) - .unwrap(); + oms.timeout_transactions(Duration::from_millis(1)).await.unwrap(); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), num_outputs); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), num_outputs); } -#[test] -fn test_get_balance() { +#[tokio::test] +async fn test_get_balance() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; - let balance = runtime.block_on(oms.get_balance()).unwrap(); + let balance = oms.get_balance().await.unwrap(); assert_eq!(MicroTari::from(0), balance.available_balance); @@ -1023,63 +991,62 @@ fn test_get_balance() { let output_val = MicroTari::from(2000); let (_ti, uo) = make_input(&mut OsRng.clone(), output_val, &factories.commitment); total += uo.value; - runtime.block_on(oms.add_output(uo)).unwrap(); + oms.add_output(uo).await.unwrap(); let (_ti, uo) = make_input(&mut OsRng.clone(), output_val, &factories.commitment); total += 
uo.value; - runtime.block_on(oms.add_output(uo)).unwrap(); + oms.add_output(uo).await.unwrap(); let send_value = MicroTari::from(1000); - let stp = runtime - .block_on(oms.prepare_transaction_to_send( + let stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), send_value, MicroTari::from(20), None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); let change_val = stp.get_change_amount().unwrap(); let recv_value = MicroTari::from(1500); let (_tx_id, sender_message) = generate_sender_transaction_message(recv_value); - let _rtp = runtime.block_on(oms.get_recipient_transaction(sender_message)).unwrap(); + let _rtp = oms.get_recipient_transaction(sender_message).await.unwrap(); - let balance = runtime.block_on(oms.get_balance()).unwrap(); + let balance = oms.get_balance().await.unwrap(); assert_eq!(output_val, balance.available_balance); assert_eq!(recv_value + change_val, balance.pending_incoming_balance); assert_eq!(output_val, balance.pending_outgoing_balance); } -#[test] -fn test_confirming_received_output() { - let mut runtime = Runtime::new().unwrap(); - +#[tokio::test] +async fn test_confirming_received_output() { let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let value = MicroTari::from(5000); let (tx_id, sender_message) = generate_sender_transaction_message(value); - let rtp = runtime.block_on(oms.get_recipient_transaction(sender_message)).unwrap(); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), 0); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 1); + let rtp = oms.get_recipient_transaction(sender_message).await.unwrap(); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), 0); + 
assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 1); let output = match rtp.state { RecipientState::Finalized(s) => s.output, RecipientState::Failed(_) => panic!("Should not be in Failed state"), }; - runtime - .block_on(oms.confirm_transaction(tx_id, vec![], vec![output.clone()])) + oms.confirm_transaction(tx_id, vec![], vec![output.clone()]) + .await .unwrap(); - assert_eq!(runtime.block_on(oms.get_balance()).unwrap().available_balance, value); + assert_eq!(oms.get_balance().await.unwrap().available_balance, value); let factories = CryptoFactories::default(); - let rewind_public_keys = runtime.block_on(oms.get_rewind_public_keys()).unwrap(); + let rewind_public_keys = oms.get_rewind_public_keys().await.unwrap(); let rewind_result = output .rewind_range_proof_value_only( &factories.range_proof, @@ -1090,99 +1057,100 @@ fn test_confirming_received_output() { assert_eq!(rewind_result.committed_value, value); } -#[test] -fn sending_transaction_with_short_term_clear() { +#[tokio::test] +async fn sending_transaction_with_short_term_clear() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend.clone(), true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend.clone(), true).await; let available_balance = 10_000 * uT; let (_ti, uo) = make_input(&mut OsRng.clone(), available_balance, &factories.commitment); - runtime.block_on(oms.add_output(uo)).unwrap(); + oms.add_output(uo).await.unwrap(); // Check that funds are encumbered and then unencumbered if the pending tx is not confirmed before restart - let _stp = runtime - .block_on(oms.prepare_transaction_to_send( + let _stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), MicroTari::from(1000), 
MicroTari::from(20), None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); - let balance = runtime.block_on(oms.get_balance()).unwrap(); + let balance = oms.get_balance().await.unwrap(); let expected_change = balance.pending_incoming_balance; assert_eq!(balance.pending_outgoing_balance, available_balance); drop(oms); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend.clone(), true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend.clone(), true).await; - let balance = runtime.block_on(oms.get_balance()).unwrap(); + let balance = oms.get_balance().await.unwrap(); assert_eq!(balance.available_balance, available_balance); // Check that an unconfirmed Pending Transaction can be cancelled - let stp = runtime - .block_on(oms.prepare_transaction_to_send( + let stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), MicroTari::from(1000), MicroTari::from(20), None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); let sender_tx_id = stp.get_tx_id().unwrap(); - let balance = runtime.block_on(oms.get_balance()).unwrap(); + let balance = oms.get_balance().await.unwrap(); assert_eq!(balance.pending_outgoing_balance, available_balance); - runtime.block_on(oms.cancel_transaction(sender_tx_id)).unwrap(); + oms.cancel_transaction(sender_tx_id).await.unwrap(); - let balance = runtime.block_on(oms.get_balance()).unwrap(); + let balance = oms.get_balance().await.unwrap(); assert_eq!(balance.available_balance, available_balance); // Check that if the pending tx is confirmed the encumbrance persists after restart - let stp = runtime - .block_on(oms.prepare_transaction_to_send( + let stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), MicroTari::from(1000), MicroTari::from(20), None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); let sender_tx_id = stp.get_tx_id().unwrap(); - runtime.block_on(oms.confirm_pending_transaction(sender_tx_id)).unwrap(); + 
oms.confirm_pending_transaction(sender_tx_id).await.unwrap(); drop(oms); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; - let balance = runtime.block_on(oms.get_balance()).unwrap(); + let balance = oms.get_balance().await.unwrap(); assert_eq!(balance.pending_outgoing_balance, available_balance); - let tx = runtime.block_on(complete_transaction(stp, oms.clone())); + let tx = complete_transaction(stp, oms.clone()).await; - runtime - .block_on(oms.confirm_transaction(sender_tx_id, tx.body.inputs().clone(), tx.body.outputs().clone())) + oms.confirm_transaction(sender_tx_id, tx.body.inputs().clone(), tx.body.outputs().clone()) + .await .unwrap(); - let balance = runtime.block_on(oms.get_balance()).unwrap(); + let balance = oms.get_balance().await.unwrap(); assert_eq!(balance.available_balance, expected_change); } -#[test] -fn coin_split_with_change() { +#[tokio::test] +async fn coin_split_with_change() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let val1 = 6_000 * uT; let val2 = 7_000 * uT; @@ -1190,14 +1158,15 @@ fn coin_split_with_change() { let (_ti, uo1) = make_input(&mut OsRng.clone(), val1, &factories.commitment); let (_ti, uo2) = make_input(&mut OsRng.clone(), val2, &factories.commitment); let (_ti, uo3) = make_input(&mut OsRng.clone(), val3, &factories.commitment); - assert!(runtime.block_on(oms.add_output(uo1)).is_ok()); - assert!(runtime.block_on(oms.add_output(uo2)).is_ok()); - assert!(runtime.block_on(oms.add_output(uo3)).is_ok()); + 
assert!(oms.add_output(uo1).await.is_ok()); + assert!(oms.add_output(uo2).await.is_ok()); + assert!(oms.add_output(uo3).await.is_ok()); let fee_per_gram = MicroTari::from(25); let split_count = 8; - let (_tx_id, coin_split_tx, fee, amount) = runtime - .block_on(oms.create_coin_split(1000.into(), split_count, fee_per_gram, None)) + let (_tx_id, coin_split_tx, fee, amount) = oms + .create_coin_split(1000.into(), split_count, fee_per_gram, None) + .await .unwrap(); assert_eq!(coin_split_tx.body.inputs().len(), 2); assert_eq!(coin_split_tx.body.outputs().len(), split_count + 1); @@ -1205,13 +1174,12 @@ fn coin_split_with_change() { assert_eq!(amount, val2 + val3); } -#[test] -fn coin_split_no_change() { +#[tokio::test] +async fn coin_split_no_change() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let fee_per_gram = MicroTari::from(25); let split_count = 15; @@ -1222,12 +1190,13 @@ fn coin_split_no_change() { let (_ti, uo1) = make_input(&mut OsRng.clone(), val1, &factories.commitment); let (_ti, uo2) = make_input(&mut OsRng.clone(), val2, &factories.commitment); let (_ti, uo3) = make_input(&mut OsRng.clone(), val3, &factories.commitment); - assert!(runtime.block_on(oms.add_output(uo1)).is_ok()); - assert!(runtime.block_on(oms.add_output(uo2)).is_ok()); - assert!(runtime.block_on(oms.add_output(uo3)).is_ok()); + assert!(oms.add_output(uo1).await.is_ok()); + assert!(oms.add_output(uo2).await.is_ok()); + assert!(oms.add_output(uo3).await.is_ok()); - let (_tx_id, coin_split_tx, fee, amount) = runtime - .block_on(oms.create_coin_split(1000.into(), split_count, fee_per_gram, None)) + let (_tx_id, 
coin_split_tx, fee, amount) = oms + .create_coin_split(1000.into(), split_count, fee_per_gram, None) + .await .unwrap(); assert_eq!(coin_split_tx.body.inputs().len(), 3); assert_eq!(coin_split_tx.body.outputs().len(), split_count); @@ -1235,13 +1204,12 @@ fn coin_split_no_change() { assert_eq!(amount, val1 + val2 + val3); } -#[test] -fn handle_coinbase() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn handle_coinbase() { let factories = CryptoFactories::default(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let reward1 = MicroTari::from(1000); let fees1 = MicroTari::from(500); @@ -1253,37 +1221,25 @@ fn handle_coinbase() { let fees3 = MicroTari::from(500); let value3 = reward3 + fees3; - let _ = runtime - .block_on(oms.get_coinbase_transaction(1, reward1, fees1, 1)) - .unwrap(); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), 0); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 1); - assert_eq!( - runtime.block_on(oms.get_balance()).unwrap().pending_incoming_balance, - value1 - ); - let _tx2 = runtime - .block_on(oms.get_coinbase_transaction(2, reward2, fees2, 1)) - .unwrap(); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), 0); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 1); - assert_eq!( - runtime.block_on(oms.get_balance()).unwrap().pending_incoming_balance, - value2 - ); - let tx3 = runtime - .block_on(oms.get_coinbase_transaction(3, reward3, fees3, 2)) - .unwrap(); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), 0); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 2); + let _ = 
oms.get_coinbase_transaction(1, reward1, fees1, 1).await.unwrap(); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), 0); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 1); + assert_eq!(oms.get_balance().await.unwrap().pending_incoming_balance, value1); + let _tx2 = oms.get_coinbase_transaction(2, reward2, fees2, 1).await.unwrap(); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), 0); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 1); + assert_eq!(oms.get_balance().await.unwrap().pending_incoming_balance, value2); + let tx3 = oms.get_coinbase_transaction(3, reward3, fees3, 2).await.unwrap(); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), 0); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 2); assert_eq!( - runtime.block_on(oms.get_balance()).unwrap().pending_incoming_balance, + oms.get_balance().await.unwrap().pending_incoming_balance, value2 + value3 ); let output = tx3.body.outputs()[0].clone(); - let rewind_public_keys = runtime.block_on(oms.get_rewind_public_keys()).unwrap(); + let rewind_public_keys = oms.get_rewind_public_keys().await.unwrap(); let rewind_result = output .rewind_range_proof_value_only( &factories.range_proof, @@ -1293,28 +1249,22 @@ fn handle_coinbase() { .unwrap(); assert_eq!(rewind_result.committed_value, value3); - runtime - .block_on(oms.confirm_transaction(3, vec![], vec![output])) - .unwrap(); + oms.confirm_transaction(3, vec![], vec![output]).await.unwrap(); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 1); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), 1); - assert_eq!(runtime.block_on(oms.get_balance()).unwrap().available_balance, value3); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 1); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), 1); + assert_eq!(oms.get_balance().await.unwrap().available_balance, value3); + 
assert_eq!(oms.get_balance().await.unwrap().pending_incoming_balance, value2); assert_eq!( - runtime.block_on(oms.get_balance()).unwrap().pending_incoming_balance, - value2 - ); - assert_eq!( - runtime.block_on(oms.get_balance()).unwrap().pending_outgoing_balance, + oms.get_balance().await.unwrap().pending_outgoing_balance, MicroTari::from(0) ); } -#[test] -fn test_utxo_stxo_invalid_txo_validation() { +#[tokio::test] +async fn test_utxo_stxo_invalid_txo_validation() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); @@ -1374,8 +1324,8 @@ fn test_utxo_stxo_invalid_txo_validation() { .unwrap(); let (mut oms, _shutdown, _ts, _mock_rpc_server, server_node_identity, rpc_service_state, _) = - setup_output_manager_service(&mut runtime, backend, true); - let mut event_stream = oms.get_event_stream_fused(); + setup_output_manager_service(backend, true).await; + let mut event_stream = oms.get_event_stream(); let unspent_value1 = 500; let unspent_output1 = create_unblinded_output( @@ -1386,7 +1336,7 @@ fn test_utxo_stxo_invalid_txo_validation() { ); let unspent_tx_output1 = unspent_output1.as_transaction_output(&factories).unwrap(); - runtime.block_on(oms.add_output(unspent_output1.clone())).unwrap(); + oms.add_output(unspent_output1.clone()).await.unwrap(); let unspent_value2 = 800; let unspent_output2 = create_unblinded_output( @@ -1396,7 +1346,7 @@ fn test_utxo_stxo_invalid_txo_validation() { MicroTari::from(unspent_value2), ); - runtime.block_on(oms.add_output(unspent_output2)).unwrap(); + oms.add_output(unspent_output2).await.unwrap(); let unspent_value3 = 900; let unspent_output3 = create_unblinded_output( @@ -1407,7 +1357,7 @@ fn test_utxo_stxo_invalid_txo_validation() { ); let unspent_tx_output3 = unspent_output3.as_transaction_output(&factories).unwrap(); - 
runtime.block_on(oms.add_output(unspent_output3.clone())).unwrap(); + oms.add_output(unspent_output3.clone()).await.unwrap(); let unspent_value4 = 901; let unspent_output4 = create_unblinded_output( @@ -1418,44 +1368,42 @@ fn test_utxo_stxo_invalid_txo_validation() { ); let unspent_tx_output4 = unspent_output4.as_transaction_output(&factories).unwrap(); - runtime.block_on(oms.add_output(unspent_output4.clone())).unwrap(); + oms.add_output(unspent_output4.clone()).await.unwrap(); rpc_service_state.set_utxos(vec![invalid_output.as_transaction_output(&factories).unwrap()]); - runtime - .block_on(oms.set_base_node_public_key(server_node_identity.public_key().clone())) + oms.set_base_node_public_key(server_node_identity.public_key().clone()) + .await .unwrap(); - runtime - .block_on(oms.validate_txos(TxoValidationType::Invalid, ValidationRetryStrategy::Limited(5))) + oms.validate_txos(TxoValidationType::Invalid, ValidationRetryStrategy::Limited(5)) + .await .unwrap(); - let _fetch_utxo_calls = runtime - .block_on(rpc_service_state.wait_pop_fetch_utxos_calls(1, Duration::from_secs(60))) + let _fetch_utxo_calls = rpc_service_state + .wait_pop_fetch_utxos_calls(1, Duration::from_secs(60)) + .await .unwrap(); - runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); - let mut success = false; - loop { - futures::select! { - event = event_stream.select_next_some() => { - if let Ok(msg) = event { - if let OutputManagerEvent::TxoValidationSuccess(_,TxoValidationType::Invalid) = (*msg).clone() { - success = true; - break; - }; - } - }, - () = delay => { - break; - }, - } + let delay = time::sleep(Duration::from_secs(60)).fuse(); + tokio::pin!(delay); + let mut success = false; + loop { + tokio::select! 
{ + Ok(event) = event_stream.recv() => { + if let OutputManagerEvent::TxoValidationSuccess(_,TxoValidationType::Invalid) = &*event { + success = true; + break; + } + }, + () = &mut delay => { + break; + }, } - assert!(success, "Did not receive validation success event"); - }); + } + assert!(success, "Did not receive validation success event"); - let outputs = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let outputs = oms.get_unspent_outputs().await.unwrap(); assert_eq!(outputs.len(), 5); @@ -1466,36 +1414,34 @@ fn test_utxo_stxo_invalid_txo_validation() { unspent_tx_output3, ]); - runtime - .block_on(oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::UntilSuccess)) + oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::UntilSuccess) + .await .unwrap(); - let _fetch_utxo_calls = runtime - .block_on(rpc_service_state.wait_pop_fetch_utxos_calls(3, Duration::from_secs(60))) + let _fetch_utxo_calls = rpc_service_state + .wait_pop_fetch_utxos_calls(3, Duration::from_secs(60)) + .await .unwrap(); - runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); - let mut success = false; - loop { - futures::select! { - event = event_stream.select_next_some() => { - if let Ok(msg) = event { - if let OutputManagerEvent::TxoValidationSuccess(_,TxoValidationType::Unspent) = (*msg).clone() { - success = true; - break; - }; - }; - }, - () = delay => { - break; - }, - } + let delay = time::sleep(Duration::from_secs(60)).fuse(); + tokio::pin!(delay); + let mut success = false; + loop { + tokio::select! 
{ + Ok(event) = event_stream.recv() => { + if let OutputManagerEvent::TxoValidationSuccess(_,TxoValidationType::Unspent) = &*event { + success = true; + break; + } + }, + () = &mut delay => { + break; + }, } - assert!(success, "Did not receive validation success event"); - }); + } + assert!(success, "Did not receive validation success event"); - let outputs = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let outputs = oms.get_unspent_outputs().await.unwrap(); assert_eq!(outputs.len(), 4); assert!(outputs.iter().any(|o| o == &unspent_output1)); @@ -1505,46 +1451,45 @@ fn test_utxo_stxo_invalid_txo_validation() { rpc_service_state.set_utxos(vec![spent_tx_output1]); - runtime - .block_on(oms.validate_txos(TxoValidationType::Spent, ValidationRetryStrategy::UntilSuccess)) + oms.validate_txos(TxoValidationType::Spent, ValidationRetryStrategy::UntilSuccess) + .await .unwrap(); - let _fetch_utxo_calls = runtime - .block_on(rpc_service_state.wait_pop_fetch_utxos_calls(1, Duration::from_secs(60))) + let _fetch_utxo_calls = rpc_service_state + .wait_pop_fetch_utxos_calls(1, Duration::from_secs(60)) + .await .unwrap(); - runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); - let mut success = false; - loop { - futures::select! { - event = event_stream.select_next_some() => { - if let Ok(msg) = event { - if let OutputManagerEvent::TxoValidationSuccess(_, TxoValidationType::Spent) = (*msg).clone() { - success = true; - break; - }; - } - }, - () = delay => { - break; - }, - } + let delay = time::sleep(Duration::from_secs(60)).fuse(); + tokio::pin!(delay); + let mut success = false; + loop { + tokio::select! 
{ + event = event_stream.recv() => { + if let Ok(msg) = event { + if let OutputManagerEvent::TxoValidationSuccess(_, TxoValidationType::Spent) = (*msg).clone() { + success = true; + break; + }; + } + }, + () = &mut delay => { + break; + }, } - assert!(success, "Did not receive validation success event"); - }); + } + assert!(success, "Did not receive validation success event"); - let outputs = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let outputs = oms.get_unspent_outputs().await.unwrap(); assert_eq!(outputs.len(), 5); assert!(outputs.iter().any(|o| o == &spent_output1)); } -#[test] -fn test_base_node_switch_during_validation() { +#[tokio::test] +async fn test_base_node_switch_during_validation() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); @@ -1556,8 +1501,8 @@ fn test_base_node_switch_during_validation() { server_node_identity, mut rpc_service_state, _connectivity_mock_state, - ) = setup_output_manager_service(&mut runtime, backend, true); - let mut event_stream = oms.get_event_stream_fused(); + ) = setup_output_manager_service(backend, true).await; + let mut event_stream = oms.get_event_stream(); let unspent_value1 = 500; let unspent_output1 = create_unblinded_output( @@ -1568,7 +1513,7 @@ fn test_base_node_switch_during_validation() { ); let unspent_tx_output1 = unspent_output1.as_transaction_output(&factories).unwrap(); - runtime.block_on(oms.add_output(unspent_output1)).unwrap(); + oms.add_output(unspent_output1).await.unwrap(); let unspent_value2 = 800; let unspent_output2 = create_unblinded_output( @@ -1578,7 +1523,7 @@ fn test_base_node_switch_during_validation() { MicroTari::from(unspent_value2), ); - runtime.block_on(oms.add_output(unspent_output2)).unwrap(); + oms.add_output(unspent_output2).await.unwrap(); let unspent_value3 = 900; let unspent_output3 = 
create_unblinded_output( @@ -1589,7 +1534,7 @@ fn test_base_node_switch_during_validation() { ); let unspent_tx_output3 = unspent_output3.as_transaction_output(&factories).unwrap(); - runtime.block_on(oms.add_output(unspent_output3)).unwrap(); + oms.add_output(unspent_output3).await.unwrap(); // First RPC server state rpc_service_state.set_utxos(vec![unspent_tx_output1, unspent_tx_output3]); @@ -1598,53 +1543,52 @@ fn test_base_node_switch_during_validation() { // New base node we will switch to let new_server_node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); - runtime - .block_on(oms.set_base_node_public_key(server_node_identity.public_key().clone())) + oms.set_base_node_public_key(server_node_identity.public_key().clone()) + .await .unwrap(); - runtime - .block_on(oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::UntilSuccess)) + oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::UntilSuccess) + .await .unwrap(); - let _fetch_utxo_calls = runtime - .block_on(rpc_service_state.wait_pop_fetch_utxos_calls(1, Duration::from_secs(60))) + let _fetch_utxo_calls = rpc_service_state + .wait_pop_fetch_utxos_calls(1, Duration::from_secs(60)) + .await .unwrap(); - runtime - .block_on(oms.set_base_node_public_key(new_server_node_identity.public_key().clone())) + oms.set_base_node_public_key(new_server_node_identity.public_key().clone()) + .await .unwrap(); - runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); - let mut abort = false; - loop { - futures::select! { - event = event_stream.select_next_some() => { - if let Ok(msg) = event { - if let OutputManagerEvent::TxoValidationAborted(_,_) = (*msg).clone() { - abort = true; - break; - } - } - }, - () = delay => { - break; - }, - } + let delay = time::sleep(Duration::from_secs(60)).fuse(); + tokio::pin!(delay); + let mut abort = false; + loop { + tokio::select! 
{ + event = event_stream.recv() => { + if let Ok(msg) = event { + if let OutputManagerEvent::TxoValidationAborted(_,_) = (*msg).clone() { + abort = true; + break; + } + } + }, + () = &mut delay => { + break; + }, } - assert!(abort, "Did not receive validation abort"); - }); + } + assert!(abort, "Did not receive validation abort"); } -#[test] -fn test_txo_validation_connection_timeout_retries() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn test_txo_validation_connection_timeout_retries() { let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); let (mut oms, _shutdown, _ts, _mock_rpc_server, server_node_identity, _rpc_service_state, _connectivity_mock_state) = - setup_output_manager_service(&mut runtime, backend, false); - let mut event_stream = oms.get_event_stream_fused(); + setup_output_manager_service(backend, false).await; + let mut event_stream = oms.get_event_stream(); let unspent_value1 = 500; let unspent_output1 = create_unblinded_output( @@ -1654,7 +1598,7 @@ fn test_txo_validation_connection_timeout_retries() { MicroTari::from(unspent_value1), ); - runtime.block_on(oms.add_output(unspent_output1)).unwrap(); + oms.add_output(unspent_output1).await.unwrap(); let unspent_value2 = 800; let unspent_output2 = create_unblinded_output( @@ -1664,57 +1608,54 @@ fn test_txo_validation_connection_timeout_retries() { MicroTari::from(unspent_value2), ); - runtime.block_on(oms.add_output(unspent_output2)).unwrap(); + oms.add_output(unspent_output2).await.unwrap(); - runtime - .block_on(oms.set_base_node_public_key(server_node_identity.public_key().clone())) + oms.set_base_node_public_key(server_node_identity.public_key().clone()) + .await .unwrap(); - runtime - .block_on(oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::Limited(1))) + oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::Limited(1)) + .await .unwrap(); - 
runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); - let mut timeout = 0; - let mut failed = 0; - loop { - futures::select! { - event = event_stream.select_next_some() => { - if let Ok(msg) = event { - match (*msg).clone() { - OutputManagerEvent::TxoValidationTimedOut(_,_) => { - timeout+=1; - }, - OutputManagerEvent::TxoValidationFailure(_,_) => { - failed+=1; - }, - _ => (), - } - }; - if timeout+failed >= 3 { - break; - } - }, - () = delay => { + let delay = time::sleep(Duration::from_secs(60)); + tokio::pin!(delay); + let mut timeout = 0; + let mut failed = 0; + loop { + tokio::select! { + Ok(event) = event_stream.recv() => { + match &*event { + OutputManagerEvent::TxoValidationTimedOut(_,_) => { + timeout+=1; + }, + OutputManagerEvent::TxoValidationFailure(_,_) => { + failed+=1; + }, + _ => (), + } + + if timeout+failed >= 3 { break; - }, - } + } + }, + () = &mut delay => { + break; + }, } - assert_eq!(failed, 1); - assert_eq!(timeout, 2); - }); + } + assert_eq!(failed, 1); + assert_eq!(timeout, 2); } -#[test] -fn test_txo_validation_rpc_error_retries() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn test_txo_validation_rpc_error_retries() { let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); let (mut oms, _shutdown, _ts, _mock_rpc_server, server_node_identity, rpc_service_state, _connectivity_mock_state) = - setup_output_manager_service(&mut runtime, backend, true); - let mut event_stream = oms.get_event_stream_fused(); + setup_output_manager_service(backend, true).await; + let mut event_stream = oms.get_event_stream(); rpc_service_state.set_rpc_status_error(Some(RpcStatus::bad_request("blah".to_string()))); let unspent_value1 = 500; @@ -1725,7 +1666,7 @@ fn test_txo_validation_rpc_error_retries() { MicroTari::from(unspent_value1), ); - runtime.block_on(oms.add_output(unspent_output1)).unwrap(); + 
oms.add_output(unspent_output1).await.unwrap(); let unspent_value2 = 800; let unspent_output2 = create_unblinded_output( @@ -1735,44 +1676,42 @@ fn test_txo_validation_rpc_error_retries() { MicroTari::from(unspent_value2), ); - runtime.block_on(oms.add_output(unspent_output2)).unwrap(); + oms.add_output(unspent_output2).await.unwrap(); - runtime - .block_on(oms.set_base_node_public_key(server_node_identity.public_key().clone())) + oms.set_base_node_public_key(server_node_identity.public_key().clone()) + .await .unwrap(); - runtime - .block_on(oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::Limited(1))) + oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::Limited(1)) + .await .unwrap(); - runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); - let mut failed = 0; - loop { - futures::select! { - event = event_stream.select_next_some() => { - if let Ok(msg) = event { - if let OutputManagerEvent::TxoValidationFailure(_,_) = (*msg).clone() { - failed+=1; - } - } - - if failed >= 1 { - break; + let delay = time::sleep(Duration::from_secs(60)).fuse(); + tokio::pin!(delay); + let mut failed = 0; + loop { + tokio::select! 
{ + event = event_stream.recv() => { + if let Ok(msg) = event { + if let OutputManagerEvent::TxoValidationFailure(_,_) = (*msg).clone() { + failed+=1; } - }, - () = delay => { + } + + if failed >= 1 { break; - }, - } + } + }, + () = &mut delay => { + break; + }, } - assert_eq!(failed, 1); - }); + } + assert_eq!(failed, 1); } -#[test] -fn test_txo_validation_rpc_timeout() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn test_txo_validation_rpc_timeout() { let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); @@ -1784,8 +1723,8 @@ fn test_txo_validation_rpc_timeout() { server_node_identity, mut rpc_service_state, _connectivity_mock_state, - ) = setup_output_manager_service(&mut runtime, backend, true); - let mut event_stream = oms.get_event_stream_fused(); + ) = setup_output_manager_service(backend, true).await; + let mut event_stream = oms.get_event_stream(); rpc_service_state.set_response_delay(Some(Duration::from_secs(120))); let unspent_value1 = 500; @@ -1796,7 +1735,7 @@ fn test_txo_validation_rpc_timeout() { MicroTari::from(unspent_value1), ); - runtime.block_on(oms.add_output(unspent_output1)).unwrap(); + oms.add_output(unspent_output1).await.unwrap(); let unspent_value2 = 800; let unspent_output2 = create_unblinded_output( @@ -1806,57 +1745,51 @@ fn test_txo_validation_rpc_timeout() { MicroTari::from(unspent_value2), ); - runtime.block_on(oms.add_output(unspent_output2)).unwrap(); + oms.add_output(unspent_output2).await.unwrap(); - runtime - .block_on(oms.set_base_node_public_key(server_node_identity.public_key().clone())) + oms.set_base_node_public_key(server_node_identity.public_key().clone()) + .await .unwrap(); - runtime - .block_on(oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::Limited(1))) + oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::Limited(1)) + .await .unwrap(); - runtime.block_on(async { - let 
mut delay = delay_for( - RpcClientConfig::default().deadline.unwrap() + - RpcClientConfig::default().deadline_grace_period + - Duration::from_secs(30), - ) - .fuse(); - let mut failed = 0; - loop { - futures::select! { - event = event_stream.select_next_some() => { - if let Ok(msg) = event { - if let OutputManagerEvent::TxoValidationFailure(_,_) = (*msg).clone() { - failed+=1; - } + let delay = + time::sleep(RpcClientConfig::default().timeout_with_grace_period().unwrap() + Duration::from_secs(30)).fuse(); + tokio::pin!(delay); + let mut failed = 0; + loop { + tokio::select! { + event = event_stream.recv() => { + if let Ok(msg) = event { + if let OutputManagerEvent::TxoValidationFailure(_,_) = &*msg { + failed+=1; } + } - if failed >= 1 { - break; - } - }, - () = delay => { + if failed >= 1 { break; - }, - } + } + }, + () = &mut delay => { + break; + }, } - assert_eq!(failed, 1); - }); + } + assert_eq!(failed, 1); } -#[test] -fn test_txo_validation_base_node_not_synced() { +#[tokio::test] +async fn test_txo_validation_base_node_not_synced() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); let (mut oms, _shutdown, _ts, _mock_rpc_server, server_node_identity, rpc_service_state, _connectivity_mock_state) = - setup_output_manager_service(&mut runtime, backend, true); - let mut event_stream = oms.get_event_stream_fused(); + setup_output_manager_service(backend, true).await; + let mut event_stream = oms.get_event_stream(); rpc_service_state.set_is_synced(false); let unspent_value1 = 500; @@ -1868,7 +1801,7 @@ fn test_txo_validation_base_node_not_synced() { ); let unspent_tx_output1 = unspent_output1.as_transaction_output(&factories).unwrap(); - runtime.block_on(oms.add_output(unspent_output1.clone())).unwrap(); + oms.add_output(unspent_output1.clone()).await.unwrap(); let unspent_value2 = 800; let 
unspent_output2 = create_unblinded_output( @@ -1878,74 +1811,67 @@ fn test_txo_validation_base_node_not_synced() { MicroTari::from(unspent_value2), ); - runtime.block_on(oms.add_output(unspent_output2)).unwrap(); + oms.add_output(unspent_output2).await.unwrap(); - runtime - .block_on(oms.set_base_node_public_key(server_node_identity.public_key().clone())) + oms.set_base_node_public_key(server_node_identity.public_key().clone()) + .await .unwrap(); - runtime - .block_on(oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::Limited(5))) + oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::Limited(5)) + .await .unwrap(); - runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); - let mut delayed = 0; - loop { - futures::select! { - event = event_stream.select_next_some() => { - if let Ok(msg) = event { - if let OutputManagerEvent::TxoValidationDelayed(_,_) = (*msg).clone() { - delayed+=1; - } - } - if delayed >= 2 { - break; - } - }, - () = delay => { + let delay = time::sleep(Duration::from_secs(60)).fuse(); + tokio::pin!(delay); + let mut delayed = 0; + loop { + tokio::select! { + Ok(event) = event_stream.recv() => { + if let OutputManagerEvent::TxoValidationDelayed(_,_) = &*event { + delayed += 1; + } + if delayed >= 2 { break; - }, - } + } + }, + () = &mut delay => { + break; + }, } - assert_eq!(delayed, 2); - }); + } + assert_eq!(delayed, 2); rpc_service_state.set_is_synced(true); rpc_service_state.set_utxos(vec![unspent_tx_output1]); - runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); - let mut success = false; - loop { - futures::select! 
{ - event = event_stream.select_next_some() => { - if let Ok(msg) = event { - if let OutputManagerEvent::TxoValidationSuccess(_,_) = (*msg).clone() { - success = true; - break; - } - } - }, - () = delay => { - break; - }, - } + let delay = time::sleep(Duration::from_secs(60)).fuse(); + tokio::pin!(delay); + let mut success = false; + loop { + tokio::select! { + Ok(event) = event_stream.recv() => { + if let OutputManagerEvent::TxoValidationSuccess(_,_) = &*event { + success = true; + break; + } + }, + () = &mut delay => { + break; + }, } - assert!(success, "Did not receive validation success event"); - }); + } + assert!(success, "Did not receive validation success event"); - let outputs = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let outputs = oms.get_unspent_outputs().await.unwrap(); assert_eq!(outputs.len(), 1); assert!(outputs.iter().any(|o| o == &unspent_output1)); } -#[test] -fn test_oms_key_manager_discrepancy() { +#[tokio::test] +async fn test_oms_key_manager_discrepancy() { let shutdown = Shutdown::new(); let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (_oms_request_sender, oms_request_receiver) = reply_channel::unbounded(); let (oms_event_publisher, _) = broadcast::channel(200); @@ -1959,7 +1885,7 @@ fn test_oms_key_manager_discrepancy() { let basenode_service_handle = BaseNodeServiceHandle::new(sender, event_publisher_bns); let mut mock_base_node_service = MockBaseNodeService::new(receiver_bns, shutdown.to_signal()); mock_base_node_service.set_default_base_node_state(); - runtime.spawn(mock_base_node_service.run()); + task::spawn(mock_base_node_service.run()); let (connectivity_manager, _connectivity_mock) = create_connectivity_mock(); @@ -1968,45 +1894,45 @@ fn test_oms_key_manager_discrepancy() { let master_key1 = CommsSecretKey::random(&mut OsRng); - let output_manager_service = runtime - .block_on(OutputManagerService::new( - OutputManagerServiceConfig::default(), - ts_handle.clone(), - 
oms_request_receiver, - db.clone(), - oms_event_publisher.clone(), - factories.clone(), - constants.clone(), - shutdown.to_signal(), - basenode_service_handle.clone(), - connectivity_manager.clone(), - master_key1.clone(), - )) - .unwrap(); + let output_manager_service = OutputManagerService::new( + OutputManagerServiceConfig::default(), + ts_handle.clone(), + oms_request_receiver, + db.clone(), + oms_event_publisher.clone(), + factories.clone(), + constants.clone(), + shutdown.to_signal(), + basenode_service_handle.clone(), + connectivity_manager.clone(), + master_key1.clone(), + ) + .await + .unwrap(); drop(output_manager_service); let (_oms_request_sender2, oms_request_receiver2) = reply_channel::unbounded(); - let output_manager_service2 = runtime - .block_on(OutputManagerService::new( - OutputManagerServiceConfig::default(), - ts_handle.clone(), - oms_request_receiver2, - db.clone(), - oms_event_publisher.clone(), - factories.clone(), - constants.clone(), - shutdown.to_signal(), - basenode_service_handle.clone(), - connectivity_manager.clone(), - master_key1, - )) - .expect("Should be able to make a new OMS with same master key"); + let output_manager_service2 = OutputManagerService::new( + OutputManagerServiceConfig::default(), + ts_handle.clone(), + oms_request_receiver2, + db.clone(), + oms_event_publisher.clone(), + factories.clone(), + constants.clone(), + shutdown.to_signal(), + basenode_service_handle.clone(), + connectivity_manager.clone(), + master_key1, + ) + .await + .expect("Should be able to make a new OMS with same master key"); drop(output_manager_service2); let (_oms_request_sender3, oms_request_receiver3) = reply_channel::unbounded(); let master_key2 = CommsSecretKey::random(&mut OsRng); - let output_manager_service3 = runtime.block_on(OutputManagerService::new( + let output_manager_service3 = OutputManagerService::new( OutputManagerServiceConfig::default(), ts_handle, oms_request_receiver3, @@ -2018,7 +1944,8 @@ fn 
test_oms_key_manager_discrepancy() { basenode_service_handle, connectivity_manager, master_key2, - )); + ) + .await; assert!(matches!( output_manager_service3, @@ -2026,26 +1953,25 @@ fn test_oms_key_manager_discrepancy() { )); } -#[test] -fn get_coinbase_tx_for_same_height() { +#[tokio::test] +async fn get_coinbase_tx_for_same_height() { let (connection, _tempdir) = get_temp_sqlite_database_connection(); - let mut runtime = Runtime::new().unwrap(); let (mut oms, _shutdown, _, _, _, _, _) = - setup_output_manager_service(&mut runtime, OutputManagerSqliteDatabase::new(connection, None), true); + setup_output_manager_service(OutputManagerSqliteDatabase::new(connection, None), true).await; - runtime - .block_on(oms.get_coinbase_transaction(1, 100_000.into(), 100.into(), 1)) + oms.get_coinbase_transaction(1, 100_000.into(), 100.into(), 1) + .await .unwrap(); - let pending_transactions = runtime.block_on(oms.get_pending_transactions()).unwrap(); + let pending_transactions = oms.get_pending_transactions().await.unwrap(); assert!(pending_transactions.values().any(|p| p.tx_id == 1)); - runtime - .block_on(oms.get_coinbase_transaction(2, 100_000.into(), 100.into(), 1)) + oms.get_coinbase_transaction(2, 100_000.into(), 100.into(), 1) + .await .unwrap(); - let pending_transactions = runtime.block_on(oms.get_pending_transactions()).unwrap(); + let pending_transactions = oms.get_pending_transactions().await.unwrap(); assert!(!pending_transactions.values().any(|p| p.tx_id == 1)); assert!(pending_transactions.values().any(|p| p.tx_id == 2)); } diff --git a/base_layer/wallet/tests/output_manager_service/storage.rs b/base_layer/wallet/tests/output_manager_service/storage.rs index c0609da64c..746d3cd9c5 100644 --- a/base_layer/wallet/tests/output_manager_service/storage.rs +++ b/base_layer/wallet/tests/output_manager_service/storage.rs @@ -20,7 +20,8 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS 
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::support::{data::get_temp_sqlite_database_connection, utils::make_input}; +use std::time::Duration; + use aes_gcm::{ aead::{generic_array::GenericArray, NewAead}, Aes256Gcm, @@ -28,14 +29,16 @@ use aes_gcm::{ use chrono::{Duration as ChronoDuration, Utc}; use diesel::result::{DatabaseErrorKind, Error::DatabaseError}; use rand::{rngs::OsRng, RngCore}; -use std::time::Duration; +use tari_crypto::{commitment::HomomorphicCommitmentFactory, keys::SecretKey, script::TariScript}; +use tokio::runtime::Runtime; + +use tari_common_types::types::PrivateKey; use tari_core::transactions::{ helpers::{create_unblinded_output, TestParams}, tari_amount::MicroTari, transaction::OutputFeatures, - types::{CryptoFactories, PrivateKey}, + CryptoFactories, }; -use tari_crypto::{commitment::HomomorphicCommitmentFactory, keys::SecretKey, script::TariScript}; use tari_wallet::output_manager_service::{ error::OutputManagerStorageError, service::Balance, @@ -46,11 +49,11 @@ use tari_wallet::output_manager_service::{ }, }; -use tokio::runtime::Runtime; +use crate::support::{data::get_temp_sqlite_database_connection, utils::make_input}; #[allow(clippy::same_item_push)] pub fn test_db_backend(backend: T) { - let mut runtime = Runtime::new().unwrap(); + let runtime = Runtime::new().unwrap(); let db = OutputManagerDatabase::new(backend); let factories = CryptoFactories::default(); @@ -392,7 +395,7 @@ pub fn test_output_manager_sqlite_db_encrypted() { #[test] pub fn test_key_manager_crud() { - let mut runtime = Runtime::new().unwrap(); + let runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); let db = OutputManagerDatabase::new(backend); @@ -429,7 +432,7 @@ pub fn test_key_manager_crud() { assert_eq!(read_state3.primary_key_index, 2); } -#[tokio_macros::test] +#[tokio::test] pub async fn 
test_short_term_encumberance() { let factories = CryptoFactories::default(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); @@ -510,7 +513,7 @@ pub async fn test_short_term_encumberance() { ); } -#[tokio_macros::test] +#[tokio::test] pub async fn test_no_duplicate_outputs() { let factories = CryptoFactories::default(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); diff --git a/base_layer/wallet/tests/support/comms_and_services.rs b/base_layer/wallet/tests/support/comms_and_services.rs index f9d8010ac7..1b1243d72b 100644 --- a/base_layer/wallet/tests/support/comms_and_services.rs +++ b/base_layer/wallet/tests/support/comms_and_services.rs @@ -20,8 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::Sink; -use std::{error::Error, sync::Arc, time::Duration}; +use std::{sync::Arc, time::Duration}; use tari_comms::{ message::MessageTag, multiaddr::Multiaddr, @@ -32,7 +31,7 @@ use tari_comms::{ }; use tari_comms_dht::{envelope::DhtMessageHeader, Dht}; use tari_p2p::{ - comms_connector::{InboundDomainConnector, PeerMessage}, + comms_connector::InboundDomainConnector, domain_message::DomainMessage, initialization::initialize_local_test_comms, }; @@ -43,18 +42,14 @@ pub fn get_next_memory_address() -> Multiaddr { format!("/memory/{}", port).parse().unwrap() } -pub async fn setup_comms_services( +pub async fn setup_comms_services( node_identity: Arc, peers: Vec>, - publisher: InboundDomainConnector, + publisher: InboundDomainConnector, database_path: String, discovery_request_timeout: Duration, shutdown_signal: ShutdownSignal, -) -> (CommsNode, Dht) -where - TSink: Sink> + Clone + Unpin + Send + Sync + 'static, - TSink::Error: Error + Send + Sync, -{ +) -> (CommsNode, Dht) { let peers = peers.into_iter().map(|ni| ni.to_peer()).collect(); let (comms, dht, _) = 
initialize_local_test_comms( node_identity, diff --git a/base_layer/wallet/tests/support/rpc.rs b/base_layer/wallet/tests/support/rpc.rs index 0f7009bdd0..29e99e3372 100644 --- a/base_layer/wallet/tests/support/rpc.rs +++ b/base_layer/wallet/tests/support/rpc.rs @@ -25,6 +25,7 @@ use std::{ sync::{Arc, Mutex}, time::{Duration, Instant}, }; +use tari_common_types::types::Signature; use tari_comms::protocol::rpc::{Request, Response, RpcStatus}; use tari_core::{ base_node::{ @@ -52,12 +53,9 @@ use tari_core::{ }, }, tari_utilities::Hashable, - transactions::{ - transaction::{Transaction, TransactionOutput}, - types::Signature, - }, + transactions::transaction::{Transaction, TransactionOutput}, }; -use tokio::time::delay_for; +use tokio::time::sleep; /// This macro unlocks a Mutex or RwLock. If the lock is /// poisoned (i.e. panic while unlocked) the last value @@ -212,7 +210,7 @@ impl BaseNodeWalletRpcMockState { return Ok((*lock).drain(..num_calls).collect()); } drop(lock); - delay_for(Duration::from_millis(100)).await; + sleep(Duration::from_millis(100)).await; } Err(format!( "Did not receive enough calls within the timeout period, received {}, expected {}.", @@ -234,7 +232,7 @@ impl BaseNodeWalletRpcMockState { return Ok((*lock).drain(..num_calls).collect()); } drop(lock); - delay_for(Duration::from_millis(100)).await; + sleep(Duration::from_millis(100)).await; } Err(format!( "Did not receive enough calls within the timeout period, received {}, expected {}.", @@ -256,7 +254,7 @@ impl BaseNodeWalletRpcMockState { return Ok((*lock).drain(..num_calls).collect()); } drop(lock); - delay_for(Duration::from_millis(100)).await; + sleep(Duration::from_millis(100)).await; } Err(format!( "Did not receive enough calls within the timeout period, received {}, expected {}.", @@ -276,7 +274,7 @@ impl BaseNodeWalletRpcMockState { return Ok((*lock).drain(..num_calls).collect()); } drop(lock); - delay_for(Duration::from_millis(100)).await; + sleep(Duration::from_millis(100)).await; } 
Err("Did not receive enough calls within the timeout period".to_string()) } @@ -318,7 +316,7 @@ impl BaseNodeWalletService for BaseNodeWalletRpcMockService { ) -> Result, RpcStatus> { let delay_lock = *acquire_lock!(self.state.response_delay); if let Some(delay) = delay_lock { - delay_for(delay).await; + sleep(delay).await; } let message = request.into_message(); @@ -345,7 +343,7 @@ impl BaseNodeWalletService for BaseNodeWalletRpcMockService { ) -> Result, RpcStatus> { let delay_lock = *acquire_lock!(self.state.response_delay); if let Some(delay) = delay_lock { - delay_for(delay).await; + sleep(delay).await; } let message = request.into_message(); @@ -371,7 +369,7 @@ impl BaseNodeWalletService for BaseNodeWalletRpcMockService { ) -> Result, RpcStatus> { let delay_lock = *acquire_lock!(self.state.response_delay); if let Some(delay) = delay_lock { - delay_for(delay).await; + sleep(delay).await; } let message = request.into_message(); @@ -415,7 +413,7 @@ impl BaseNodeWalletService for BaseNodeWalletRpcMockService { ) -> Result, RpcStatus> { let delay_lock = *acquire_lock!(self.state.response_delay); if let Some(delay) = delay_lock { - delay_for(delay).await; + sleep(delay).await; } let message = request.into_message(); @@ -448,7 +446,7 @@ impl BaseNodeWalletService for BaseNodeWalletRpcMockService { async fn get_tip_info(&self, _request: Request<()>) -> Result, RpcStatus> { let delay_lock = *acquire_lock!(self.state.response_delay); if let Some(delay) = delay_lock { - delay_for(delay).await; + sleep(delay).await; } log::info!("Get tip info call received"); @@ -483,17 +481,18 @@ mod test { }; use std::convert::TryFrom; + use tari_common_types::types::BlindingFactor; use tari_core::{ base_node::{ proto::wallet_rpc::{TxSubmissionRejectionReason, TxSubmissionResponse}, rpc::{BaseNodeWalletRpcClient, BaseNodeWalletRpcServer}, }, proto::base_node::{ChainMetadata, TipInfoResponse}, - transactions::{transaction::Transaction, types::BlindingFactor}, + 
transactions::transaction::Transaction, }; use tokio::time::Duration; - #[tokio_macros::test] + #[tokio::test] async fn test_wallet_rpc_mock() { let server_node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); let client_node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); diff --git a/base_layer/wallet/tests/support/utils.rs b/base_layer/wallet/tests/support/utils.rs index 630b8cd6fe..034c116c08 100644 --- a/base_layer/wallet/tests/support/utils.rs +++ b/base_layer/wallet/tests/support/utils.rs @@ -22,11 +22,11 @@ use rand::{CryptoRng, Rng}; use std::{fmt::Debug, thread, time::Duration}; +use tari_common_types::types::{CommitmentFactory, PrivateKey, PublicKey}; use tari_core::transactions::{ helpers::{create_unblinded_output, TestParams as TestParamsHelpers}, tari_amount::MicroTari, transaction::{OutputFeatures, TransactionInput, UnblindedOutput}, - types::{CommitmentFactory, PrivateKey, PublicKey}, }; use tari_crypto::{ keys::{PublicKey as PublicKeyTrait, SecretKey as SecretKeyTrait}, diff --git a/base_layer/wallet/tests/transaction_service/service.rs b/base_layer/wallet/tests/transaction_service/service.rs index dc508f7321..6e1265460d 100644 --- a/base_layer/wallet/tests/transaction_service/service.rs +++ b/base_layer/wallet/tests/transaction_service/service.rs @@ -20,30 +20,48 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-use crate::{ - support::{ - comms_and_services::{create_dummy_message, get_next_memory_address, setup_comms_services}, - rpc::{BaseNodeWalletRpcMockService, BaseNodeWalletRpcMockState}, - utils::{make_input, TestParams}, - }, - transaction_service::transaction_protocols::add_transaction_to_database, +use std::{ + convert::{TryFrom, TryInto}, + path::Path, + sync::Arc, + time::Duration, }; + use chrono::{Duration as ChronoDuration, Utc}; use futures::{ channel::{mpsc, mpsc::Sender}, FutureExt, SinkExt, - StreamExt, }; use prost::Message; use rand::{rngs::OsRng, RngCore}; -use std::{ - convert::{TryFrom, TryInto}, - path::Path, - sync::Arc, - time::Duration, +use tari_crypto::{ + commitment::HomomorphicCommitmentFactory, + common::Blake256, + inputs, + keys::{PublicKey as PK, SecretKey as SK}, + script, + script::{ExecutionStack, TariScript}, +}; +use tempfile::tempdir; +use tokio::{ + runtime, + runtime::{Builder, Runtime}, + sync::{broadcast, broadcast::channel}, +}; + +use crate::{ + support::{ + comms_and_services::{create_dummy_message, get_next_memory_address, setup_comms_services}, + rpc::{BaseNodeWalletRpcMockService, BaseNodeWalletRpcMockState}, + utils::{make_input, TestParams}, + }, + transaction_service::transaction_protocols::add_transaction_to_database, +}; +use tari_common_types::{ + chain_metadata::ChainMetadata, + types::{PrivateKey, PublicKey, Signature}, }; -use tari_common_types::chain_metadata::ChainMetadata; use tari_comms::{ message::EnvelopeBody, peer_manager::{NodeIdentity, PeerFeatures}, @@ -75,19 +93,11 @@ use tari_core::{ tari_amount::*, transaction::{KernelBuilder, KernelFeatures, OutputFeatures, Transaction}, transaction_protocol::{proto, recipient::RecipientSignedMessage, sender::TransactionSenderMessage}, - types::{CryptoFactories, PrivateKey, PublicKey, Signature}, + CryptoFactories, ReceiverTransactionProtocol, SenderTransactionProtocol, }, }; -use tari_crypto::{ - commitment::HomomorphicCommitmentFactory, - common::Blake256, - 
inputs, - keys::{PublicKey as PK, SecretKey as SK}, - script, - script::{ExecutionStack, TariScript}, -}; use tari_p2p::{comms_connector::pubsub_connector, domain_message::DomainMessage, Network}; use tari_service_framework::{reply_channel, RegisterHandle, StackBuilder}; use tari_shutdown::{Shutdown, ShutdownSignal}; @@ -137,19 +147,12 @@ use tari_wallet::{ }, types::{HashDigest, ValidationRetryStrategy}, }; -use tempfile::tempdir; -use tokio::{ - runtime, - runtime::{Builder, Runtime}, - sync::{broadcast, broadcast::channel}, - time::delay_for, -}; +use tokio::time::sleep; fn create_runtime() -> Runtime { - Builder::new() - .threaded_scheduler() + Builder::new_multi_thread() .enable_all() - .core_threads(8) + .worker_threads(8) .build() .unwrap() } @@ -172,7 +175,8 @@ pub fn setup_transaction_service< discovery_request_timeout: Duration, shutdown_signal: ShutdownSignal, ) -> (TransactionServiceHandle, OutputManagerHandle, CommsNode) { - let (publisher, subscription_factory) = pubsub_connector(runtime.handle().clone(), 100, 20); + let _enter = runtime.enter(); + let (publisher, subscription_factory) = pubsub_connector(100, 20); let subscription_factory = Arc::new(subscription_factory); let (comms, dht) = runtime.block_on(setup_comms_services( node_identity, @@ -303,11 +307,15 @@ pub fn setup_transaction_service_no_comms_and_oms_backend< let protocol_name = server.as_protocol_name(); let server_node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); - let mut mock_server = runtime - .handle() - .enter(|| MockRpcServer::new(server, server_node_identity.clone())); + let mut mock_server = { + let _enter = runtime.handle().enter(); + MockRpcServer::new(server, server_node_identity.clone()) + }; - runtime.handle().enter(|| mock_server.serve()); + { + let _enter = runtime.handle().enter(); + mock_server.serve(); + } let connection = runtime.block_on(async { mock_server @@ -504,9 +512,9 @@ fn manage_single_transaction() { 
.block_on(alice_ts.set_base_node_public_key(base_node_identity.public_key().clone())) .unwrap(); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); - runtime.block_on(async { delay_for(Duration::from_secs(2)).await }); + runtime.block_on(async { sleep(Duration::from_secs(2)).await }); let (mut bob_ts, mut bob_oms, bob_comms) = setup_transaction_service( &mut runtime, @@ -524,7 +532,7 @@ fn manage_single_transaction() { .block_on(bob_ts.set_base_node_public_key(base_node_identity.public_key().clone())) .unwrap(); - let mut bob_event_stream = bob_ts.get_event_stream_fused(); + let mut bob_event_stream = bob_ts.get_event_stream(); let _ = runtime.block_on( bob_comms @@ -556,18 +564,19 @@ fn manage_single_transaction() { .expect("Alice sending tx"); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(90)).fuse(); + let delay = sleep(Duration::from_secs(90)); + tokio::pin!(delay); let mut count = 0; loop { - futures::select! { - _event = alice_event_stream.select_next_some() => { + tokio::select! { + _event = alice_event_stream.recv() => { println!("alice: {:?}", &*_event.as_ref().unwrap()); count+=1; if count>=2 { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -576,18 +585,19 @@ fn manage_single_transaction() { let mut tx_id = 0u64; runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(90)).fuse(); + let delay = sleep(Duration::from_secs(90)); + tokio::pin!(delay); let mut finalized = 0; loop { - futures::select! { - event = bob_event_stream.select_next_some() => { + tokio::select! 
{ + event = bob_event_stream.recv() => { println!("bob: {:?}", &*event.as_ref().unwrap()); if let TransactionEvent::ReceivedFinalizedTransaction(id) = &*event.unwrap() { tx_id = *id; finalized+=1; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -747,7 +757,7 @@ fn send_one_sided_transaction_to_other() { shutdown.to_signal(), ); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); runtime .block_on(alice_ts.set_base_node_public_key(base_node_identity.public_key().clone())) @@ -792,11 +802,12 @@ fn send_one_sided_transaction_to_other() { }); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let delay = sleep(Duration::from_secs(30)); + tokio::pin!(delay); let mut found = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionCompletedImmediately(id) = &*event.unwrap() { if id == &tx_id { found = true; @@ -804,7 +815,7 @@ fn send_one_sided_transaction_to_other() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -1071,9 +1082,9 @@ fn manage_multiple_transactions() { Duration::from_secs(60), shutdown.to_signal(), ); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); // Spin up Bob and Carol let (mut bob_ts, mut bob_oms, bob_comms) = setup_transaction_service( @@ -1088,8 +1099,8 @@ fn manage_multiple_transactions() { Duration::from_secs(1), shutdown.to_signal(), ); - let mut bob_event_stream = bob_ts.get_event_stream_fused(); - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + let mut bob_event_stream = bob_ts.get_event_stream(); + runtime.block_on(async { 
sleep(Duration::from_secs(5)).await }); let (mut carol_ts, mut carol_oms, carol_comms) = setup_transaction_service( &mut runtime, @@ -1103,18 +1114,18 @@ fn manage_multiple_transactions() { Duration::from_secs(1), shutdown.to_signal(), ); - let mut carol_event_stream = carol_ts.get_event_stream_fused(); + let mut carol_event_stream = carol_ts.get_event_stream(); // Establish some connections beforehand, to reduce the amount of work done concurrently in tests // Connect Bob and Alice - runtime.block_on(async { delay_for(Duration::from_secs(3)).await }); + runtime.block_on(async { sleep(Duration::from_secs(3)).await }); let _ = runtime.block_on( bob_comms .connectivity() .dial_peer(alice_node_identity.node_id().clone()), ); - runtime.block_on(async { delay_for(Duration::from_secs(3)).await }); + runtime.block_on(async { sleep(Duration::from_secs(3)).await }); // Connect alice to carol let _ = runtime.block_on( @@ -1182,12 +1193,13 @@ fn manage_multiple_transactions() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(90)).fuse(); + let delay = sleep(Duration::from_secs(90)); + tokio::pin!(delay); let mut tx_reply = 0; let mut finalized = 0; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { match &*event.unwrap() { TransactionEvent::ReceivedTransactionReply(_) => tx_reply+=1, TransactionEvent::ReceivedFinalizedTransaction(_) => finalized+=1, @@ -1198,7 +1210,7 @@ fn manage_multiple_transactions() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -1210,12 +1222,14 @@ fn manage_multiple_transactions() { log::trace!("Alice received all Tx messages"); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(90)).fuse(); + let delay = sleep(Duration::from_secs(90)); + + tokio::pin!(delay); let mut tx_reply = 0; let mut finalized = 0; loop { - futures::select! 
{ - event = bob_event_stream.select_next_some() => { + tokio::select! { + event = bob_event_stream.recv() => { match &*event.unwrap() { TransactionEvent::ReceivedTransactionReply(_) => tx_reply+=1, TransactionEvent::ReceivedFinalizedTransaction(_) => finalized+=1, @@ -1225,7 +1239,7 @@ fn manage_multiple_transactions() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -1235,14 +1249,17 @@ fn manage_multiple_transactions() { }); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(90)).fuse(); + let delay = sleep(Duration::from_secs(90)); + tokio::pin!(delay); let mut finalized = 0; loop { - futures::select! { - event = carol_event_stream.select_next_some() => { + tokio::select! { + event = carol_event_stream.recv() => { if let TransactionEvent::ReceivedFinalizedTransaction(_) = &*event.unwrap() { finalized+=1 } }, - () = delay => { + () = &mut delay => { break; }, } @@ -1264,7 +1281,7 @@ fn manage_multiple_transactions() { assert_eq!(carol_pending_inbound.len(), 0); assert_eq!(carol_completed_tx.len(), 1); - shutdown.trigger().unwrap(); + shutdown.trigger(); runtime.block_on(async move { alice_comms.wait_until_shutdown().await; bob_comms.wait_until_shutdown().await; @@ -1303,7 +1320,7 @@ fn test_accepting_unknown_tx_id_and_malformed_reply() { _, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), alice_backend, oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); let (_utxo, uo) = make_input(&mut OsRng, MicroTari(250000), &factories.commitment); @@ -1355,11 +1372,14 @@ fn test_accepting_unknown_tx_id_and_malformed_reply() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let delay = sleep(Duration::from_secs(30)); + tokio::pin!(delay); let mut errors = 0; loop { - futures::select!
{ - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { log::error!("ERROR: {:?}", event); if let TransactionEvent::Error(s) = &*event.unwrap() { if s == &"TransactionProtocolError(TransactionBuildError(InvalidSignatureError(\"Verifying kernel signature\")))".to_string() @@ -1371,7 +1391,7 @@ fn test_accepting_unknown_tx_id_and_malformed_reply() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -1415,7 +1435,7 @@ fn finalize_tx_with_incorrect_pubkey() { _, _, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), alice_backend, alice_oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); let bob_node_identity = NodeIdentity::random(&mut OsRng, get_next_memory_address(), PeerFeatures::COMMUNICATION_NODE); @@ -1488,15 +1508,18 @@ fn finalize_tx_with_incorrect_pubkey() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(15)).fuse(); + let delay = sleep(Duration::from_secs(15)); + tokio::pin!(delay); loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select!
{ + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedFinalizedTransaction(_) = (*event.unwrap()).clone() { panic!("Should not have received finalized event!"); } }, - () = delay => { + () = &mut delay => { break; }, } @@ -1542,7 +1565,7 @@ fn finalize_tx_with_missing_output() { _, _, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), alice_backend, alice_oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); let bob_node_identity = NodeIdentity::random(&mut OsRng, get_next_memory_address(), PeerFeatures::COMMUNICATION_NODE); @@ -1623,15 +1646,18 @@ fn finalize_tx_with_missing_output() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(15)).fuse(); + let delay = sleep(Duration::from_secs(15)); + tokio::pin!(delay); loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select!
{ + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedFinalizedTransaction(_) = (*event.unwrap()).clone() { panic!("Should not have received finalized event"); } }, - () = delay => { + () = &mut delay => { break; }, } @@ -1648,8 +1674,7 @@ fn discovery_async_return_test() { let db_tempdir = tempdir().unwrap(); let db_folder = db_tempdir.path(); - let mut runtime = runtime::Builder::new() - .basic_scheduler() + let mut runtime = runtime::Builder::new_current_thread() .enable_time() .thread_name("discovery_async_return_test") .build() @@ -1714,7 +1739,7 @@ fn discovery_async_return_test() { Duration::from_secs(20), shutdown.to_signal(), ); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); let (_utxo, uo1a) = make_input(&mut OsRng, MicroTari(5500), &factories.commitment); runtime.block_on(alice_oms.add_output(uo1a)).unwrap(); @@ -1741,17 +1766,20 @@ fn discovery_async_return_test() { let mut txid = 0; let mut is_success = true; runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionDirectSendResult(tx_id, result) = (*event.unwrap()).clone() { txid = tx_id; is_success = result; break; } }, - () = delay => { + () = &mut delay => { panic!("Timeout while waiting for transaction to fail sending"); }, } @@ -1772,18 +1800,21 @@ fn discovery_async_return_test() { let mut success_result = false; let mut success_tx_id = 0u64; runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); loop { - futures::select!
{ - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionDirectSendResult(tx_id, success) = &*event.unwrap() { success_result = *success; success_tx_id = *tx_id; break; } }, - () = delay => { + () = &mut delay => { panic!("Timeout while waiting for transaction to successfully be sent"); }, } @@ -1794,24 +1825,26 @@ fn discovery_async_return_test() { assert!(success_result); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedTransactionReply(tx_id) = &*event.unwrap() { if tx_id == &tx_id2 { break; } } }, - () = delay => { + () = &mut delay => { panic!("Timeout while Alice was waiting for a transaction reply"); }, } } }); - shutdown.trigger().unwrap(); + shutdown.trigger(); runtime.block_on(async move { alice_comms.wait_until_shutdown().await; carol_comms.wait_until_shutdown().await; @@ -2012,7 +2045,7 @@ fn test_transaction_cancellation() { ..Default::default() }), ); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); let alice_total_available = 250000 * uT; let (_utxo, uo) = make_input(&mut OsRng, alice_total_available, &factories.commitment); @@ -2030,15 +2063,17 @@ fn test_transaction_cancellation() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); loop { - futures::select!
{ + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionStoreForwardSendResult(_,_) = &*event.unwrap() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2054,7 +2089,7 @@ fn test_transaction_cancellation() { None => (), Some(_) => break, } - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); if i >= 12 { panic!("Pending outbound transaction should have been added by now"); } @@ -2066,17 +2101,18 @@ fn test_transaction_cancellation() { // Wait for cancellation event, in an effort to nail down where the issue is for the flakey CI test runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut cancelled = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionCancelled(_) = &*event.unwrap() { cancelled = true; break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2143,15 +2179,16 @@ fn test_transaction_cancellation() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedTransaction(_) = &*event.unwrap() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2213,15 +2250,16 @@ fn test_transaction_cancellation() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! 
{ + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedTransaction(_) = &*event.unwrap() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2243,7 +2281,7 @@ fn test_transaction_cancellation() { ))) .unwrap(); - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); runtime .block_on(alice_ts.get_pending_inbound_transactions()) @@ -2257,17 +2295,18 @@ fn test_transaction_cancellation() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let delay = sleep(Duration::from_secs(30)); + tokio::pin!(delay); let mut cancelled = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionCancelled(_) = &*event.unwrap() { cancelled = true; break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2386,7 +2425,7 @@ fn test_direct_vs_saf_send_of_tx_reply_and_finalize() { .try_into() .unwrap(); - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); assert_eq!(bob_outbound_service.call_count(), 0, "Should be no more calls"); let (_wallet_backend, backend, oms_backend, _, _temp_dir) = make_wallet_databases(None); @@ -2428,7 +2467,7 @@ fn test_direct_vs_saf_send_of_tx_reply_and_finalize() { .try_into() .unwrap(); - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); assert_eq!(bob2_outbound_service.call_count(), 0, "Should be no more calls"); // Test finalize is sent Direct Only.
@@ -2449,7 +2488,7 @@ fn test_direct_vs_saf_send_of_tx_reply_and_finalize() { let _ = alice_outbound_service.pop_call().unwrap(); let _ = alice_outbound_service.pop_call().unwrap(); - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); assert_eq!(alice_outbound_service.call_count(), 0, "Should be no more calls"); // Now to repeat sending so we can test the SAF send of the finalize message @@ -2520,7 +2559,7 @@ fn test_direct_vs_saf_send_of_tx_reply_and_finalize() { assert_eq!(alice_outbound_service.call_count(), 1); let _ = alice_outbound_service.pop_call(); - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); assert_eq!(alice_outbound_service.call_count(), 0, "Should be no more calls2"); } @@ -2548,7 +2587,7 @@ fn test_tx_direct_send_behaviour() { _, _, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), backend, oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); let (_utxo, uo) = make_input(&mut OsRng, 1000000 * uT, &factories.commitment); runtime.block_on(alice_output_manager.add_output(uo)).unwrap(); @@ -2576,12 +2615,13 @@ fn test_tx_direct_send_behaviour() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut direct_count = 0; let mut saf_count = 0; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select!
{ + event = alice_event_stream.recv() => { match &*event.unwrap() { TransactionEvent::TransactionDirectSendResult(_, result) => if !result { direct_count+=1 }, TransactionEvent::TransactionStoreForwardSendResult(_, result) => if !result { saf_count+=1}, _ => (), @@ -2591,7 +2631,7 @@ fn test_tx_direct_send_behaviour() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2619,12 +2659,13 @@ fn test_tx_direct_send_behaviour() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut direct_count = 0; let mut saf_count = 0; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { match &*event.unwrap() { TransactionEvent::TransactionDirectSendResult(_, result) => if !result { direct_count+=1 }, TransactionEvent::TransactionStoreForwardSendResult(_, result) => if *result { saf_count+=1 @@ -2635,7 +2676,7 @@ fn test_tx_direct_send_behaviour() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2663,11 +2704,12 @@ fn test_tx_direct_send_behaviour() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut direct_count = 0; loop { - futures::select!
{ + event = alice_event_stream.recv() => { match &*event.unwrap() { TransactionEvent::TransactionDirectSendResult(_, result) => if *result { direct_count+=1 }, TransactionEvent::TransactionStoreForwardSendResult(_, _) => panic!("Should be no SAF messages"), _ => (), @@ -2678,7 +2720,7 @@ fn test_tx_direct_send_behaviour() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2705,11 +2747,12 @@ fn test_tx_direct_send_behaviour() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut saf_count = 0; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { match &*event.unwrap() { TransactionEvent::TransactionStoreForwardSendResult(_, result) => if *result { saf_count+=1 }, TransactionEvent::TransactionDirectSendResult(_, result) => if *result { panic!( @@ -2720,7 +2763,7 @@ fn test_tx_direct_send_behaviour() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2852,7 +2895,7 @@ fn test_restarting_transaction_protocols() { // Test that Bob's node restarts the send protocol let (mut bob_ts, _bob_oms, _bob_outbound_service, _, _, mut bob_tx_reply, _, _, _, _shutdown, _, _, _) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), bob_backend, bob_oms_backend, None); - let mut bob_event_stream = bob_ts.get_event_stream_fused(); + let mut bob_event_stream = bob_ts.get_event_stream(); runtime .block_on(bob_ts.set_base_node_public_key(base_node_identity.public_key().clone())) .unwrap(); @@ -2864,18 +2907,19 @@ fn test_restarting_transaction_protocols() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(15)).fuse(); + let delay = sleep(Duration::from_secs(15)); + tokio::pin!(delay); let mut received_reply = false; loop { - futures::select!
{ + event = bob_event_stream.recv() => { if let TransactionEvent::ReceivedTransactionReply(id) = (*event.unwrap()).clone() { assert_eq!(id, tx_id); received_reply = true; break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2886,7 +2930,7 @@ fn test_restarting_transaction_protocols() { // Test Alice's node restarts the receive protocol let (mut alice_ts, _alice_oms, _alice_outbound_service, _, _, _, mut alice_tx_finalized, _, _, _shutdown, _, _, _) = setup_transaction_service_no_comms(&mut runtime, factories, alice_backend, alice_oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); runtime .block_on(alice_ts.set_base_node_public_key(base_node_identity.public_key().clone())) @@ -2906,18 +2950,19 @@ fn test_restarting_transaction_protocols() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(15)).fuse(); + let delay = sleep(Duration::from_secs(15)); + tokio::pin!(delay); let mut received_finalized = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! 
{ + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedFinalizedTransaction(id) = (*event.unwrap()).clone() { assert_eq!(id, tx_id); received_finalized = true; break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -3046,7 +3091,7 @@ fn test_coinbase_monitoring_stuck_in_mempool() { server_node_identity, mut rpc_service_state, ) = setup_transaction_service_no_comms(&mut runtime, factories, backend, oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); rpc_service_state.set_response_delay(Some(Duration::from_secs(1))); let block_height_a = 10; @@ -3131,11 +3176,12 @@ fn test_coinbase_monitoring_stuck_in_mempool() { println!(" {}", e) } runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let delay = sleep(Duration::from_secs(30)); + tokio::pin!(delay); let mut count = 0usize; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedFinalizedTransaction(tx_id) = &*event.unwrap() { if tx_id == &tx_id1 || tx_id == &tx_id2 { count += 1; @@ -3145,7 +3191,7 @@ fn test_coinbase_monitoring_stuck_in_mempool() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -3170,11 +3216,12 @@ fn test_coinbase_monitoring_stuck_in_mempool() { println!(" {}", e) } runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let delay = sleep(Duration::from_secs(30)); + tokio::pin!(delay); let mut count = 0usize; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! 
{ + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionCancelled(tx_id) = &*event.unwrap() { if tx_id == &tx_id1 || tx_id == &tx_id2 { count += 1; @@ -3184,7 +3231,7 @@ fn test_coinbase_monitoring_stuck_in_mempool() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -3215,7 +3262,7 @@ fn test_coinbase_monitoring_with_base_node_change_and_mined() { server_node_identity, mut rpc_service_state, ) = setup_transaction_service_no_comms(&mut runtime, factories, backend, oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); rpc_service_state.set_response_delay(Some(Duration::from_secs(1))); let block_height_a = 10; @@ -3301,11 +3348,12 @@ fn test_coinbase_monitoring_with_base_node_change_and_mined() { println!(" {}", e) } runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let delay = sleep(Duration::from_secs(30)); + tokio::pin!(delay); let mut count = 0usize; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! 
{ + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionMinedUnconfirmed(tx_id, _) = &*event.unwrap() { if tx_id == &tx_id1 || tx_id == &tx_id2 { count += 1; @@ -3316,7 +3364,7 @@ fn test_coinbase_monitoring_with_base_node_change_and_mined() { } }, - () = delay => { + () = &mut delay => { break; }, } @@ -3331,10 +3379,14 @@ fn test_coinbase_monitoring_with_base_node_change_and_mined() { rpc_service_state.set_response_delay(Some(Duration::from_secs(1))); let new_server = BaseNodeWalletRpcServer::new(service); let protocol_name = new_server.as_protocol_name(); - let mut new_mock_server = runtime - .handle() - .enter(|| MockRpcServer::new(new_server, new_server_node_identity.clone())); - runtime.handle().enter(|| new_mock_server.serve()); + let mut new_mock_server = { + let _enter = runtime.handle().enter(); + MockRpcServer::new(new_server, new_server_node_identity.clone()) + }; + { + let _enter = runtime.handle().enter(); + new_mock_server.serve(); + } let connection = runtime.block_on(new_mock_server.create_connection(new_server_node_identity.to_peer(), protocol_name.into())); runtime.block_on(connectivity_mock_state.add_active_connection(connection)); @@ -3368,11 +3420,12 @@ fn test_coinbase_monitoring_with_base_node_change_and_mined() { println!(" {}", e) } runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let delay = sleep(Duration::from_secs(30)); + tokio::pin!(delay); let mut count = 0usize; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! 
{ + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionMined(tx_id) = &*event.unwrap() { if tx_id == &tx_id1 || tx_id == &tx_id2 { count += 1; @@ -3382,7 +3435,7 @@ fn test_coinbase_monitoring_with_base_node_change_and_mined() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -3413,7 +3466,7 @@ fn test_coinbase_monitoring_mined_not_synced() { server_node_identity, mut rpc_service_state, ) = setup_transaction_service_no_comms(&mut runtime, factories, backend, oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); rpc_service_state.set_response_delay(Some(Duration::from_secs(1))); let block_height_a = 10; @@ -3499,11 +3552,12 @@ fn test_coinbase_monitoring_mined_not_synced() { println!(" {}", e) } runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let delay = sleep(Duration::from_secs(30)); + tokio::pin!(delay); let mut count = 0usize; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedFinalizedTransaction(tx_id) = &*event.unwrap() { if tx_id == &tx_id1 || tx_id == &tx_id2 { count += 1; @@ -3513,7 +3567,7 @@ fn test_coinbase_monitoring_mined_not_synced() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -3538,11 +3592,12 @@ fn test_coinbase_monitoring_mined_not_synced() { println!(" {}", e) } runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let delay = sleep(Duration::from_secs(30)); + tokio::pin!(delay); let mut count = 0usize; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! 
{ + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionMined(tx_id) = &*event.unwrap() { if tx_id == &tx_id1 || tx_id == &tx_id2 { count += 1; @@ -3552,7 +3607,7 @@ fn test_coinbase_monitoring_mined_not_synced() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -3760,7 +3815,7 @@ fn test_transaction_resending() { assert_eq!(bob_reply_message.tx_id, tx_id); } - runtime.block_on(async { delay_for(Duration::from_secs(2)).await }); + runtime.block_on(async { sleep(Duration::from_secs(2)).await }); // See if sending a second message too soon is ignored runtime .block_on(bob_tx_sender.send(create_dummy_message( @@ -3772,7 +3827,7 @@ fn test_transaction_resending() { assert!(bob_outbound_service.wait_call_count(1, Duration::from_secs(2)).is_err()); // Wait for the cooldown to expire but before the resend period has elapsed see if a repeat elicits a response. - runtime.block_on(async { delay_for(Duration::from_secs(2)).await }); + runtime.block_on(async { sleep(Duration::from_secs(2)).await }); runtime .block_on(bob_tx_sender.send(create_dummy_message( alice_sender_message.into(), @@ -3819,7 +3874,7 @@ fn test_transaction_resending() { .is_err()); // Wait for the cooldown to expire but before the resend period has elapsed see if a repeat elicits a response.
- runtime.block_on(async { delay_for(Duration::from_secs(2)).await }); + runtime.block_on(async { sleep(Duration::from_secs(2)).await }); runtime .block_on(alice_tx_reply_sender.send(create_dummy_message( @@ -4143,7 +4198,7 @@ fn test_replying_to_cancelled_tx() { assert_eq!(data.tx_id, tx_id); } // Need a moment for Alice's wallet to finish writing to its database before cancelling - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); runtime.block_on(alice_ts.cancel_transaction(tx_id)).unwrap(); @@ -4193,7 +4248,7 @@ fn test_replying_to_cancelled_tx() { assert_eq!(bob_reply_message.tx_id, tx_id); // Wait for cooldown to expire - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); let _ = alice_outbound_service.take_calls(); @@ -4406,7 +4461,7 @@ fn test_transaction_timeout_cancellation() { ..Default::default() }), ); - let mut carol_event_stream = carol_ts.get_event_stream_fused(); + let mut carol_event_stream = carol_ts.get_event_stream(); runtime .block_on(carol_tx_sender.send(create_dummy_message( @@ -4431,11 +4486,12 @@ fn test_transaction_timeout_cancellation() { assert_eq!(carol_reply_message.tx_id, tx_id); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut transaction_cancelled = false; loop { - futures::select! { - event = carol_event_stream.select_next_some() => { + tokio::select! 
{ + event = carol_event_stream.recv() => { if let TransactionEvent::TransactionCancelled(t) = &*event.unwrap() { if t == &tx_id { transaction_cancelled = true; @@ -4443,7 +4499,7 @@ fn test_transaction_timeout_cancellation() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -4481,7 +4537,7 @@ fn transaction_service_tx_broadcast() { server_node_identity, rpc_service_state, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), backend, oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); runtime .block_on(alice_ts.set_base_node_public_key(server_node_identity.public_key().clone())) @@ -4609,11 +4665,12 @@ fn transaction_service_tx_broadcast() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut tx1_received = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedTransactionReply(tx_id) = &*event.unwrap(){ if tx_id == &tx_id1 { tx1_received = true; @@ -4621,7 +4678,7 @@ fn transaction_service_tx_broadcast() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -4655,11 +4712,12 @@ fn transaction_service_tx_broadcast() { }); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut tx1_mined = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! 
{ + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionMined(tx_id) = &*event.unwrap(){ if tx_id == &tx_id1 { tx1_mined = true; @@ -4667,7 +4725,7 @@ fn transaction_service_tx_broadcast() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -4683,11 +4741,12 @@ fn transaction_service_tx_broadcast() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut tx2_received = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedTransactionReply(tx_id) = &*event.unwrap(){ if tx_id == &tx_id2 { tx2_received = true; @@ -4695,7 +4754,7 @@ fn transaction_service_tx_broadcast() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -4733,11 +4792,12 @@ fn transaction_service_tx_broadcast() { }); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut tx2_cancelled = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! 
{ + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionCancelled(tx_id) = &*event.unwrap(){ if tx_id == &tx_id2 { tx2_cancelled = true; @@ -4745,7 +4805,7 @@ fn transaction_service_tx_broadcast() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -4851,15 +4911,16 @@ fn broadcast_all_completed_transactions_on_startup() { assert!(runtime.block_on(alice_ts.restart_broadcast_protocols()).is_ok()); - let mut event_stream = alice_ts.get_event_stream_fused(); + let mut event_stream = alice_ts.get_event_stream(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut found1 = false; let mut found2 = false; let mut found3 = false; loop { - futures::select! { - event = event_stream.select_next_some() => { + tokio::select! { + event = event_stream.recv() => { if let TransactionEvent::TransactionBroadcast(tx_id) = (*event.unwrap()).clone() { if tx_id == 1u64 { found1 = true @@ -4876,7 +4937,7 @@ fn broadcast_all_completed_transactions_on_startup() { } }, - () = delay => { + () = &mut delay => { break; }, } @@ -4916,7 +4977,7 @@ fn transaction_service_tx_broadcast_with_base_node_change() { server_node_identity, rpc_service_state, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), backend, oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); runtime .block_on(alice_ts.set_base_node_public_key(server_node_identity.public_key().clone())) @@ -4995,11 +5056,12 @@ fn transaction_service_tx_broadcast_with_base_node_change() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut tx1_received = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! 
{ + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedTransactionReply(tx_id) = &*event.unwrap(){ if tx_id == &tx_id1 { tx1_received = true; @@ -5007,7 +5069,7 @@ fn transaction_service_tx_broadcast_with_base_node_change() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -5038,11 +5100,15 @@ fn transaction_service_tx_broadcast_with_base_node_change() { let new_server = BaseNodeWalletRpcServer::new(service); let protocol_name = new_server.as_protocol_name(); - let mut new_mock_server = runtime - .handle() - .enter(|| MockRpcServer::new(new_server, new_server_node_identity.clone())); + let mut new_mock_server = { + let _enter = runtime.handle().enter(); + MockRpcServer::new(new_server, new_server_node_identity.clone()) + }; - runtime.handle().enter(|| new_mock_server.serve()); + { + let _enter = runtime.handle().enter(); + new_mock_server.serve(); + } let connection = runtime.block_on(new_mock_server.create_connection(new_server_node_identity.to_peer(), protocol_name.into())); @@ -5075,17 +5141,18 @@ fn transaction_service_tx_broadcast_with_base_node_change() { }); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut tx_mined = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! 
{ + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionMined(_) = &*event.unwrap(){ tx_mined = true; break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -5350,11 +5417,15 @@ fn start_validation_protocol_then_broadcast_protocol_change_base_node() { let new_server = BaseNodeWalletRpcServer::new(service); let protocol_name = new_server.as_protocol_name(); - let mut new_mock_server = runtime - .handle() - .enter(|| MockRpcServer::new(new_server, new_server_node_identity.clone())); + let mut new_mock_server = { + let _enter = runtime.handle().enter(); + MockRpcServer::new(new_server, new_server_node_identity.clone()) + }; - runtime.handle().enter(|| new_mock_server.serve()); + { + let _enter = runtime.handle().enter(); + new_mock_server.serve(); + } let connection = runtime.block_on(new_mock_server.create_connection(new_server_node_identity.to_peer(), protocol_name.into())); diff --git a/base_layer/wallet/tests/transaction_service/storage.rs b/base_layer/wallet/tests/transaction_service/storage.rs index 6ea420e6a4..5573bd63e5 100644 --- a/base_layer/wallet/tests/transaction_service/storage.rs +++ b/base_layer/wallet/tests/transaction_service/storage.rs @@ -26,20 +26,24 @@ use aes_gcm::{ }; use chrono::Utc; use rand::rngs::OsRng; +use tari_crypto::{ + keys::{PublicKey as PublicKeyTrait, SecretKey as SecretKeyTrait}, + script, + script::{ExecutionStack, TariScript}, +}; +use tempfile::tempdir; +use tokio::runtime::Runtime; + +use tari_common_types::types::{HashDigest, PrivateKey, PublicKey}; use tari_core::transactions::{ helpers::{create_unblinded_output, TestParams}, tari_amount::{uT, MicroTari}, transaction::{OutputFeatures, Transaction}, transaction_protocol::sender::TransactionSenderMessage, - types::{CryptoFactories, HashDigest, PrivateKey, PublicKey}, + CryptoFactories, ReceiverTransactionProtocol, SenderTransactionProtocol, }; -use tari_crypto::{ - keys::{PublicKey as PublicKeyTrait, SecretKey as SecretKeyTrait}, - script, - 
script::{ExecutionStack, TariScript}, -}; use tari_test_utils::random; use tari_wallet::{ storage::sqlite_utilities::run_migration_and_create_sqlite_connection, @@ -56,11 +60,8 @@ use tari_wallet::{ sqlite_db::TransactionServiceSqliteDatabase, }, }; -use tempfile::tempdir; -use tokio::runtime::Runtime; - pub fn test_db_backend(backend: T) { - let mut runtime = Runtime::new().unwrap(); + let runtime = Runtime::new().unwrap(); let mut db = TransactionDatabase::new(backend); let factories = CryptoFactories::default(); let input = create_unblinded_output( diff --git a/base_layer/wallet/tests/transaction_service/transaction_protocols.rs b/base_layer/wallet/tests/transaction_service/transaction_protocols.rs index 66079613ca..d3d8c16ca2 100644 --- a/base_layer/wallet/tests/transaction_service/transaction_protocols.rs +++ b/base_layer/wallet/tests/transaction_service/transaction_protocols.rs @@ -20,14 +20,9 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-use crate::support::{ - rpc::{BaseNodeWalletRpcMockService, BaseNodeWalletRpcMockState}, - utils::make_input, -}; use chrono::Utc; -use futures::{FutureExt, StreamExt}; +use futures::StreamExt; use rand::rngs::OsRng; -use std::{sync::Arc, thread::sleep, time::Duration}; use tari_comms::{ peer_manager::PeerFeatures, protocol::rpc::{mock::MockRpcServer, NamedProtocolService, RpcStatus}, @@ -48,7 +43,7 @@ use tari_core::{ transactions::{ helpers::schema_to_transaction, tari_amount::{uT, MicroTari, T}, - types::CryptoFactories, + CryptoFactories, }, txn_schema, }; @@ -80,7 +75,13 @@ use tari_wallet::{ types::ValidationRetryStrategy, }; use tempfile::{tempdir, TempDir}; -use tokio::{sync::broadcast, task, time::delay_for}; +use tokio::{sync::broadcast, task, time::sleep}; + +use crate::support::{ + rpc::{BaseNodeWalletRpcMockService, BaseNodeWalletRpcMockState}, + utils::make_input, +}; +use std::{sync::Arc, time::Duration}; // Just in case other options become apparent in later testing #[derive(PartialEq)] @@ -230,7 +231,7 @@ pub async fn oms_reply_channel_task( } /// A happy path test by submitting a transaction into the mempool, have it mined but unconfirmed and then confirmed. 
-#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_submit_success_i() { let ( @@ -245,7 +246,7 @@ async fn tx_broadcast_protocol_submit_success_i() { _temp_dir, mut transaction_event_receiver, ) = setup(TxProtocolTestConfig::WithConnection).await; - let mut event_stream = resources.event_publisher.subscribe().fuse(); + let mut event_stream = resources.event_publisher.subscribe(); let (base_node_update_publisher, _) = broadcast::channel(20); let protocol = TransactionBroadcastProtocol::new( @@ -353,7 +354,8 @@ async fn tx_broadcast_protocol_submit_success_i() { .unwrap(); // lets wait for the transaction service event to notify us of a confirmed tx // We need to do this to ensure that the wallet db has been updated to "Mined" - while let Some(v) = transaction_event_receiver.next().await { + loop { + let v = transaction_event_receiver.recv().await; let event = v.unwrap(); match (*event).clone() { TransactionEvent::TransactionMined(_) => { @@ -392,13 +394,14 @@ async fn tx_broadcast_protocol_submit_success_i() { ); // Check that the appropriate events were emitted - let mut delay = delay_for(Duration::from_secs(5)).fuse(); + let delay = sleep(Duration::from_secs(5)); + tokio::pin!(delay); let mut broadcast = false; let mut unconfirmed = false; let mut confirmed = false; loop { - futures::select! { - event = event_stream.select_next_some() => { + tokio::select! 
{ + event = event_stream.recv() => { match &*event.unwrap() { TransactionEvent::TransactionMinedUnconfirmed(_, confirmations) => if *confirmations == 1 { unconfirmed = true; @@ -412,7 +415,7 @@ async fn tx_broadcast_protocol_submit_success_i() { _ => (), } }, - () = delay => { + () = &mut delay => { break; }, } @@ -426,7 +429,7 @@ async fn tx_broadcast_protocol_submit_success_i() { } /// Test submitting a transaction that is immediately rejected -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_submit_rejection() { let ( @@ -441,7 +444,7 @@ async fn tx_broadcast_protocol_submit_rejection() { _temp_dir, _transaction_event_receiver, ) = setup(TxProtocolTestConfig::WithConnection).await; - let mut event_stream = resources.event_publisher.subscribe().fuse(); + let mut event_stream = resources.event_publisher.subscribe(); let (base_node_update_publisher, _) = broadcast::channel(20); add_transaction_to_database(1, 1 * T, true, None, resources.db.clone()).await; @@ -478,16 +481,17 @@ async fn tx_broadcast_protocol_submit_rejection() { assert!(db_completed_tx.is_err()); // Check that the appropriate events were emitted - let mut delay = delay_for(Duration::from_secs(1)).fuse(); + let delay = sleep(Duration::from_secs(1)); + tokio::pin!(delay); let mut cancelled = false; loop { - futures::select! { - event = event_stream.select_next_some() => { + tokio::select! 
{ + event = event_stream.recv() => { if let TransactionEvent::TransactionCancelled(_) = &*event.unwrap() { cancelled = true; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -498,7 +502,7 @@ async fn tx_broadcast_protocol_submit_rejection() { /// Test restarting a protocol which means the first step is a query not a submission, detecting the Tx is not in the /// mempool, resubmit the tx and then have it mined -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_restart_protocol_as_query() { let ( @@ -585,7 +589,7 @@ async fn tx_broadcast_protocol_restart_protocol_as_query() { /// This test will submit a Tx which will be accepted and then dropped from the mempool, resulting in a resubmit which /// will be rejected and result in a cancelled transaction -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_submit_success_followed_by_rejection() { let ( @@ -600,7 +604,7 @@ async fn tx_broadcast_protocol_submit_success_followed_by_rejection() { _temp_dir, _transaction_event_receiver, ) = setup(TxProtocolTestConfig::WithConnection).await; - let mut event_stream = resources.event_publisher.subscribe().fuse(); + let mut event_stream = resources.event_publisher.subscribe(); let (base_node_update_publisher, _) = broadcast::channel(20); add_transaction_to_database(1, 1 * T, true, None, resources.db.clone()).await; @@ -666,16 +670,17 @@ async fn tx_broadcast_protocol_submit_success_followed_by_rejection() { assert!(db_completed_tx.is_err()); // Check that the appropriate events were emitted - let mut delay = delay_for(Duration::from_secs(1)).fuse(); + let delay = sleep(Duration::from_secs(1)); + tokio::pin!(delay); let mut cancelled = false; loop { - futures::select! { - event = event_stream.select_next_some() => { + tokio::select! 
{ + event = event_stream.recv() => { if let TransactionEvent::TransactionCancelled(_) = &*event.unwrap() { cancelled = true; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -686,7 +691,7 @@ async fn tx_broadcast_protocol_submit_success_followed_by_rejection() { /// This test will submit a tx which is accepted and mined but unconfirmed, then the next query it will not exist /// resulting in a resubmission which we will let run to being mined with success -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_submit_mined_then_not_mined_resubmit_success() { let ( @@ -751,14 +756,14 @@ async fn tx_broadcast_protocol_submit_mined_then_not_mined_resubmit_success() { // Wait for the "TransactionMinedUnconfirmed" tx event to ensure that the wallet db state is "MinedUnconfirmed" let mut count = 0u16; - while let Some(v) = transaction_event_receiver.next().await { + loop { + let v = transaction_event_receiver.recv().await; let event = v.unwrap(); match (*event).clone() { TransactionEvent::TransactionMinedUnconfirmed(_, _) => { break; }, _ => { - sleep(Duration::from_millis(1000)); count += 1; if count >= 10 { break; @@ -806,7 +811,7 @@ async fn tx_broadcast_protocol_submit_mined_then_not_mined_resubmit_success() { } /// Test being unable to connect and then connection becoming available. 
-#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_connection_problem() { let ( @@ -823,7 +828,7 @@ async fn tx_broadcast_protocol_connection_problem() { ) = setup(TxProtocolTestConfig::WithoutConnection).await; let (base_node_update_publisher, _) = broadcast::channel(20); - let mut event_stream = resources.event_publisher.subscribe().fuse(); + let mut event_stream = resources.event_publisher.subscribe(); add_transaction_to_database(1, 1 * T, true, None, resources.db.clone()).await; @@ -839,11 +844,12 @@ async fn tx_broadcast_protocol_connection_problem() { let join_handle = task::spawn(protocol.execute()); // Check that the connection problem event was emitted at least twice - let mut delay = delay_for(Duration::from_secs(10)).fuse(); + let delay = sleep(Duration::from_secs(10)); + tokio::pin!(delay); let mut connection_issues = 0; loop { - futures::select! { - event = event_stream.select_next_some() => { + tokio::select! { + event = event_stream.recv() => { if let TransactionEvent::TransactionBaseNodeConnectionProblem(_) = &*event.unwrap() { connection_issues+=1; } @@ -851,7 +857,7 @@ async fn tx_broadcast_protocol_connection_problem() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -878,7 +884,7 @@ async fn tx_broadcast_protocol_connection_problem() { } /// Submit a transaction that is Already Mined for the submission, the subsequent query should confirm the transaction -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_submit_already_mined() { let ( @@ -948,7 +954,7 @@ async fn tx_broadcast_protocol_submit_already_mined() { } /// A test to see that the broadcast protocol can handle a change to the base node address while it runs. 
-#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_submit_and_base_node_gets_changed() { let ( @@ -1050,7 +1056,7 @@ async fn tx_broadcast_protocol_submit_and_base_node_gets_changed() { /// Validate completed transactions, will check that valid ones stay valid and incorrectly marked invalid tx become /// valid. -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_validation_protocol_tx_becomes_valid() { let ( @@ -1148,7 +1154,7 @@ async fn tx_validation_protocol_tx_becomes_valid() { } /// Validate completed transaction, the transaction should become invalid -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_validation_protocol_tx_becomes_invalid() { let ( @@ -1213,7 +1219,7 @@ async fn tx_validation_protocol_tx_becomes_invalid() { } /// Validate completed transactions, the transaction should become invalid -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_validation_protocol_tx_becomes_unconfirmed() { let ( @@ -1285,7 +1291,7 @@ async fn tx_validation_protocol_tx_becomes_unconfirmed() { /// Test the validation protocol reacts correctly to a change in base node and redoes the full validation based on the /// new base node -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_validation_protocol_tx_ends_on_base_node_end() { let ( @@ -1302,7 +1308,7 @@ async fn tx_validation_protocol_tx_ends_on_base_node_end() { ) = setup(TxProtocolTestConfig::WithConnection).await; let (base_node_update_publisher, _) = broadcast::channel(20); let (_timeout_update_publisher, _) = broadcast::channel(20); - let mut event_stream = resources.event_publisher.subscribe().fuse(); + let mut event_stream = resources.event_publisher.subscribe(); add_transaction_to_database( 1, @@ -1398,16 +1404,17 @@ async fn tx_validation_protocol_tx_ends_on_base_node_end() { let result = join_handle.await.unwrap(); assert!(result.is_ok()); - let 
mut delay = delay_for(Duration::from_secs(1)).fuse(); + let delay = sleep(Duration::from_secs(1)); + tokio::pin!(delay); let mut aborted = false; loop { - futures::select! { - event = event_stream.select_next_some() => { + tokio::select! { + event = event_stream.recv() => { if let TransactionEvent::TransactionValidationAborted(_) = &*event.unwrap() { aborted = true; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -1416,7 +1423,7 @@ async fn tx_validation_protocol_tx_ends_on_base_node_end() { } /// Test the validation protocol reacts correctly when the RPC client returns an error between calls. -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_validation_protocol_rpc_client_broken_between_calls() { let ( @@ -1540,7 +1547,7 @@ async fn tx_validation_protocol_rpc_client_broken_between_calls() { /// Test the validation protocol reacts correctly when the RPC client returns an error between calls and only retry /// finite amount of times -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_validation_protocol_rpc_client_broken_finite_retries() { let ( @@ -1557,7 +1564,7 @@ async fn tx_validation_protocol_rpc_client_broken_finite_retries() { ) = setup(TxProtocolTestConfig::WithConnection).await; let (base_node_update_publisher, _) = broadcast::channel(20); let (_timeout_update_publisher, _) = broadcast::channel(20); - let mut event_stream = resources.event_publisher.subscribe().fuse(); + let mut event_stream = resources.event_publisher.subscribe(); add_transaction_to_database( 1, 1 * T, @@ -1610,12 +1617,13 @@ async fn tx_validation_protocol_rpc_client_broken_finite_retries() { assert!(result.is_err()); // Check that the connection problem event was emitted at least twice - let mut delay = delay_for(Duration::from_secs(10)).fuse(); + let delay = sleep(Duration::from_secs(10)); + tokio::pin!(delay); let mut timeouts = 0i32; let mut failures = 0i32; loop { - futures::select! 
{ - event = event_stream.select_next_some() => { + tokio::select! { + event = event_stream.recv() => { log::error!("EVENT: {:?}", event); match &*event.unwrap() { TransactionEvent::TransactionValidationTimedOut(_) => { @@ -1630,7 +1638,7 @@ async fn tx_validation_protocol_rpc_client_broken_finite_retries() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -1641,7 +1649,7 @@ async fn tx_validation_protocol_rpc_client_broken_finite_retries() { /// Validate completed transactions, will check that valid ones stay valid and incorrectly marked invalid tx become /// valid. -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_validation_protocol_base_node_not_synced() { let ( @@ -1658,7 +1666,7 @@ async fn tx_validation_protocol_base_node_not_synced() { ) = setup(TxProtocolTestConfig::WithConnection).await; let (base_node_update_publisher, _) = broadcast::channel(20); let (_timeout_update_publisher, _) = broadcast::channel(20); - let mut event_stream = resources.event_publisher.subscribe().fuse(); + let mut event_stream = resources.event_publisher.subscribe(); add_transaction_to_database( 1, @@ -1711,12 +1719,13 @@ async fn tx_validation_protocol_base_node_not_synced() { let result = join_handle.await.unwrap(); assert!(result.is_err()); - let mut delay = delay_for(Duration::from_secs(10)).fuse(); + let delay = sleep(Duration::from_secs(10)); + tokio::pin!(delay); let mut delayed = 0i32; let mut failures = 0i32; loop { - futures::select! { - event = event_stream.select_next_some() => { + tokio::select! 
{ + event = event_stream.recv() => { match &*event.unwrap() { TransactionEvent::TransactionValidationDelayed(_) => { delayed +=1 ; @@ -1728,7 +1737,7 @@ async fn tx_validation_protocol_base_node_not_synced() { } }, - () = delay => { + () = &mut delay => { break; }, } diff --git a/base_layer/wallet/tests/wallet/mod.rs b/base_layer/wallet/tests/wallet/mod.rs index 6aa231b09c..c1320e5409 100644 --- a/base_layer/wallet/tests/wallet/mod.rs +++ b/base_layer/wallet/tests/wallet/mod.rs @@ -20,18 +20,27 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::support::{comms_and_services::get_next_memory_address, utils::make_input}; -use tari_core::transactions::transaction::OutputFeatures; +use std::{panic, path::Path, sync::Arc, time::Duration}; use aes_gcm::{ aead::{generic_array::GenericArray, NewAead}, Aes256Gcm, }; use digest::Digest; -use futures::{FutureExt, StreamExt}; use rand::rngs::OsRng; -use std::{panic, path::Path, sync::Arc, time::Duration}; -use tari_common_types::chain_metadata::ChainMetadata; +use tari_crypto::{ + common::Blake256, + inputs, + keys::{PublicKey as PublicKeyTrait, SecretKey}, + script, +}; +use tempfile::tempdir; +use tokio::runtime::Runtime; + +use tari_common_types::{ + chain_metadata::ChainMetadata, + types::{PrivateKey, PublicKey}, +}; use tari_comms::{ multiaddr::Multiaddr, peer_manager::{NodeId, NodeIdentity, Peer, PeerFeatures, PeerFlags}, @@ -41,13 +50,8 @@ use tari_comms_dht::DhtConfig; use tari_core::transactions::{ helpers::{create_unblinded_output, TestParams}, tari_amount::{uT, MicroTari}, - types::{CryptoFactories, PrivateKey, PublicKey}, -}; -use tari_crypto::{ - common::Blake256, - inputs, - keys::{PublicKey as PublicKeyTrait, SecretKey}, - script, + transaction::OutputFeatures, + CryptoFactories, }; use tari_p2p::{initialization::CommsConfig, transport::TransportType, Network, 
DEFAULT_DNS_NAME_SERVER}; use tari_shutdown::{Shutdown, ShutdownSignal}; @@ -70,8 +74,9 @@ use tari_wallet::{ WalletConfig, WalletSqlite, }; -use tempfile::tempdir; -use tokio::{runtime::Runtime, time::delay_for}; +use tokio::time::sleep; + +use crate::support::{comms_and_services::get_next_memory_address, utils::make_input}; fn create_peer(public_key: CommsPublicKey, net_address: Multiaddr) -> Peer { Peer::new( @@ -163,7 +168,7 @@ async fn create_wallet( .await } -#[tokio_macros::test] +#[tokio::test] async fn test_wallet() { let mut shutdown_a = Shutdown::new(); let mut shutdown_b = Shutdown::new(); @@ -227,7 +232,7 @@ async fn test_wallet() { .await .unwrap(); - let mut alice_event_stream = alice_wallet.transaction_service.get_event_stream_fused(); + let mut alice_event_stream = alice_wallet.transaction_service.get_event_stream(); let value = MicroTari::from(1000); let (_utxo, uo1) = make_input(&mut OsRng, MicroTari(2500), &factories.commitment); @@ -245,15 +250,16 @@ async fn test_wallet() { .await .unwrap(); - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut reply_count = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => if let TransactionEvent::ReceivedTransactionReply(_) = &*event.unwrap() { - reply_count = true; - break; - }, - () = delay => { + tokio::select! 
{ + event = alice_event_stream.recv() => if let TransactionEvent::ReceivedTransactionReply(_) = &*event.unwrap() { + reply_count = true; + break; + }, + () = &mut delay => { break; }, } @@ -298,7 +304,7 @@ async fn test_wallet() { } drop(alice_event_stream); - shutdown_a.trigger().unwrap(); + shutdown_a.trigger(); alice_wallet.wait_until_shutdown().await; let connection = @@ -343,7 +349,7 @@ async fn test_wallet() { alice_wallet.remove_encryption().await.unwrap(); - shutdown_a.trigger().unwrap(); + shutdown_a.trigger(); alice_wallet.wait_until_shutdown().await; let connection = @@ -379,7 +385,7 @@ async fn test_wallet() { .await .unwrap(); - shutdown_a.trigger().unwrap(); + shutdown_a.trigger(); alice_wallet.wait_until_shutdown().await; partial_wallet_backup(current_wallet_path.clone(), backup_wallet_path.clone()) @@ -400,12 +406,12 @@ async fn test_wallet() { let master_secret_key = backup_wallet_db.get_master_secret_key().await.unwrap(); assert!(master_secret_key.is_none()); - shutdown_b.trigger().unwrap(); + shutdown_b.trigger(); bob_wallet.wait_until_shutdown().await; } -#[tokio_macros::test] +#[tokio::test] async fn test_do_not_overwrite_master_key() { let factories = CryptoFactories::default(); let dir = tempdir().unwrap(); @@ -423,7 +429,7 @@ async fn test_do_not_overwrite_master_key() { ) .await .unwrap(); - shutdown.trigger().unwrap(); + shutdown.trigger(); wallet.wait_until_shutdown().await; // try to use a new master key to create a wallet using the existing wallet database @@ -457,7 +463,7 @@ async fn test_do_not_overwrite_master_key() { .unwrap(); } -#[tokio_macros::test] +#[tokio::test] async fn test_sign_message() { let factories = CryptoFactories::default(); let dir = tempdir().unwrap(); @@ -516,9 +522,9 @@ fn test_store_and_forward_send_tx() { let bob_db_tempdir = tempdir().unwrap(); let carol_db_tempdir = tempdir().unwrap(); - let mut alice_runtime = Runtime::new().expect("Failed to initialize tokio runtime"); - let mut bob_runtime = 
Runtime::new().expect("Failed to initialize tokio runtime"); - let mut carol_runtime = Runtime::new().expect("Failed to initialize tokio runtime"); + let alice_runtime = Runtime::new().expect("Failed to initialize tokio runtime"); + let bob_runtime = Runtime::new().expect("Failed to initialize tokio runtime"); + let carol_runtime = Runtime::new().expect("Failed to initialize tokio runtime"); let mut alice_wallet = alice_runtime .block_on(create_wallet( @@ -554,7 +560,7 @@ fn test_store_and_forward_send_tx() { )) .unwrap(); let carol_identity = (*carol_wallet.comms.node_identity()).clone(); - shutdown_c.trigger().unwrap(); + shutdown_c.trigger(); carol_runtime.block_on(carol_wallet.wait_until_shutdown()); alice_runtime @@ -591,13 +597,13 @@ fn test_store_and_forward_send_tx() { .unwrap(); // Waiting here for a while to make sure the discovery retry is over - alice_runtime.block_on(async { delay_for(Duration::from_secs(60)).await }); + alice_runtime.block_on(async { sleep(Duration::from_secs(60)).await }); alice_runtime .block_on(alice_wallet.transaction_service.cancel_transaction(tx_id)) .unwrap(); - alice_runtime.block_on(async { delay_for(Duration::from_secs(60)).await }); + alice_runtime.block_on(async { sleep(Duration::from_secs(60)).await }); let carol_wallet = carol_runtime .block_on(create_wallet( @@ -610,7 +616,7 @@ fn test_store_and_forward_send_tx() { )) .unwrap(); - let mut carol_event_stream = carol_wallet.transaction_service.get_event_stream_fused(); + let mut carol_event_stream = carol_wallet.transaction_service.get_event_stream(); carol_runtime .block_on(carol_wallet.comms.peer_manager().add_peer(create_peer( @@ -623,13 +629,14 @@ fn test_store_and_forward_send_tx() { .unwrap(); carol_runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut tx_recv = false; let mut tx_cancelled = false; loop { - futures::select! 
{ - event = carol_event_stream.select_next_some() => { + tokio::select! { + event = carol_event_stream.recv() => { match &*event.unwrap() { TransactionEvent::ReceivedTransaction(_) => tx_recv = true, TransactionEvent::TransactionCancelled(_) => tx_cancelled = true, @@ -639,7 +646,7 @@ fn test_store_and_forward_send_tx() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -647,15 +654,15 @@ fn test_store_and_forward_send_tx() { assert!(tx_recv, "Must have received a tx from alice"); assert!(tx_cancelled, "Must have received a cancel tx from alice"); }); - shutdown_a.trigger().unwrap(); - shutdown_b.trigger().unwrap(); - shutdown_c2.trigger().unwrap(); + shutdown_a.trigger(); + shutdown_b.trigger(); + shutdown_c2.trigger(); alice_runtime.block_on(alice_wallet.wait_until_shutdown()); bob_runtime.block_on(bob_wallet.wait_until_shutdown()); carol_runtime.block_on(carol_wallet.wait_until_shutdown()); } -#[tokio_macros::test] +#[tokio::test] async fn test_import_utxo() { let shutdown = Shutdown::new(); let factories = CryptoFactories::default(); diff --git a/base_layer/wallet_ffi/Cargo.toml b/base_layer/wallet_ffi/Cargo.toml index d7381c0d8e..b5d719981a 100644 --- a/base_layer/wallet_ffi/Cargo.toml +++ b/base_layer/wallet_ffi/Cargo.toml @@ -9,6 +9,7 @@ edition = "2018" [dependencies] tari_comms = { version = "^0.9", path = "../../comms", default-features = false} tari_comms_dht = { version = "^0.9", path = "../../comms/dht", default-features = false } +tari_common_types = {path="../common_types"} tari_crypto = "0.11.1" tari_key_manager = { version = "^0.9", path = "../key_manager" } tari_p2p = { version = "^0.9", path = "../p2p" } @@ -17,11 +18,11 @@ tari_shutdown = { version = "^0.9", path = "../../infrastructure/shutdown" } tari_utilities = "^0.3" futures = { version = "^0.3.1", features =["compat", "std"]} -tokio = "0.2.10" +tokio = "1.10.1" libc = "0.2.65" rand = "0.8" chrono = { version = "0.4.6", features = ["serde"]} -thiserror = "1.0.20" 
+thiserror = "1.0.26" log = "0.4.6" log4rs = {version = "1.0.0", features = ["console_appender", "file_appender", "yaml_format"]} @@ -41,4 +42,3 @@ env_logger = "0.7.1" tari_key_manager = { version = "^0.9", path = "../key_manager" } tari_common_types = { version = "^0.9", path = "../../base_layer/common_types"} tari_test_utils = { version = "^0.9", path = "../../infrastructure/test_utils"} -tokio = { version="0.2.10" } diff --git a/base_layer/wallet_ffi/src/callback_handler.rs b/base_layer/wallet_ffi/src/callback_handler.rs index 6d00799f23..c5af01c7b6 100644 --- a/base_layer/wallet_ffi/src/callback_handler.rs +++ b/base_layer/wallet_ffi/src/callback_handler.rs @@ -48,7 +48,6 @@ //! request_key is used to identify which request this callback references and a result of true means it was successful //! and false that the process timed out and new one will be started -use futures::{stream::Fuse, StreamExt}; use log::*; use tari_comms::types::CommsPublicKey; use tari_comms_dht::event::{DhtEvent, DhtEventReceiver}; @@ -96,9 +95,9 @@ where TBackend: TransactionBackend + 'static callback_transaction_validation_complete: unsafe extern "C" fn(u64, u8), callback_saf_messages_received: unsafe extern "C" fn(), db: TransactionDatabase, - transaction_service_event_stream: Fuse, - output_manager_service_event_stream: Fuse, - dht_event_stream: Fuse, + transaction_service_event_stream: TransactionEventReceiver, + output_manager_service_event_stream: OutputManagerEventReceiver, + dht_event_stream: DhtEventReceiver, shutdown_signal: Option, comms_public_key: CommsPublicKey, } @@ -109,9 +108,9 @@ where TBackend: TransactionBackend + 'static { pub fn new( db: TransactionDatabase, - transaction_service_event_stream: Fuse, - output_manager_service_event_stream: Fuse, - dht_event_stream: Fuse, + transaction_service_event_stream: TransactionEventReceiver, + output_manager_service_event_stream: OutputManagerEventReceiver, + dht_event_stream: DhtEventReceiver, shutdown_signal: 
ShutdownSignal, comms_public_key: CommsPublicKey, callback_received_transaction: unsafe extern "C" fn(*mut InboundTransaction), @@ -219,8 +218,8 @@ where TBackend: TransactionBackend + 'static info!(target: LOG_TARGET, "Transaction Service Callback Handler starting"); loop { - futures::select! { - result = self.transaction_service_event_stream.select_next_some() => { + tokio::select! { + result = self.transaction_service_event_stream.recv() => { match result { Ok(msg) => { trace!(target: LOG_TARGET, "Transaction Service Callback Handler event {:?}", msg); @@ -271,7 +270,7 @@ where TBackend: TransactionBackend + 'static Err(_e) => error!(target: LOG_TARGET, "Error reading from Transaction Service event broadcast channel"), } }, - result = self.output_manager_service_event_stream.select_next_some() => { + result = self.output_manager_service_event_stream.recv() => { match result { Ok(msg) => { trace!(target: LOG_TARGET, "Output Manager Service Callback Handler event {:?}", msg); @@ -295,7 +294,7 @@ where TBackend: TransactionBackend + 'static Err(_e) => error!(target: LOG_TARGET, "Error reading from Output Manager Service event broadcast channel"), } }, - result = self.dht_event_stream.select_next_some() => { + result = self.dht_event_stream.recv() => { match result { Ok(msg) => { trace!(target: LOG_TARGET, "DHT Callback Handler event {:?}", msg); @@ -306,11 +305,7 @@ where TBackend: TransactionBackend + 'static Err(_e) => error!(target: LOG_TARGET, "Error reading from DHT event broadcast channel"), } } - complete => { - info!(target: LOG_TARGET, "Callback Handler is exiting because all tasks have completed"); - break; - }, - _ = shutdown_signal => { + _ = shutdown_signal.wait() => { info!(target: LOG_TARGET, "Transaction Callback Handler shutting down because the shutdown signal was received"); break; }, @@ -585,18 +580,17 @@ where TBackend: TransactionBackend + 'static mod test { use crate::callback_handler::CallbackHandler; use chrono::Utc; - use 
futures::StreamExt; use rand::rngs::OsRng; use std::{ sync::{Arc, Mutex}, thread, time::Duration, }; + use tari_common_types::types::{BlindingFactor, PrivateKey, PublicKey}; use tari_comms_dht::event::DhtEvent; use tari_core::transactions::{ tari_amount::{uT, MicroTari}, transaction::Transaction, - types::{BlindingFactor, PrivateKey, PublicKey}, ReceiverTransactionProtocol, SenderTransactionProtocol, }; @@ -774,7 +768,7 @@ mod test { #[test] fn test_callback_handler() { - let mut runtime = Runtime::new().unwrap(); + let runtime = Runtime::new().unwrap(); let (_wallet_backend, backend, _oms_backend, _, _tempdir) = make_wallet_databases(None); let db = TransactionDatabase::new(backend); @@ -854,9 +848,9 @@ mod test { let shutdown_signal = Shutdown::new(); let callback_handler = CallbackHandler::new( db, - tx_receiver.fuse(), - oms_receiver.fuse(), - dht_receiver.fuse(), + tx_receiver, + oms_receiver, + dht_receiver, shutdown_signal.to_signal(), PublicKey::from_secret_key(&PrivateKey::random(&mut OsRng)), received_tx_callback, diff --git a/base_layer/wallet_ffi/src/lib.rs b/base_layer/wallet_ffi/src/lib.rs index d8b01d9b69..1f85e33d06 100644 --- a/base_layer/wallet_ffi/src/lib.rs +++ b/base_layer/wallet_ffi/src/lib.rs @@ -107,20 +107,18 @@ #[cfg(test)] #[macro_use] extern crate lazy_static; -mod callback_handler; -mod enums; -mod error; -mod tasks; -use crate::{ - callback_handler::CallbackHandler, - enums::SeedWordPushResult, - error::{InterfaceError, TransactionError}, - tasks::recovery_event_monitoring, -}; use core::ptr; -use error::LibWalletError; -use futures::StreamExt; +use std::{ + boxed::Box, + ffi::{CStr, CString}, + path::PathBuf, + slice, + str::FromStr, + sync::Arc, + time::Duration, +}; + use libc::{c_char, c_int, c_longlong, c_uchar, c_uint, c_ulonglong, c_ushort}; use log::{LevelFilter, *}; use log4rs::{ @@ -136,14 +134,19 @@ use log4rs::{ encode::pattern::PatternEncoder, }; use rand::rngs::OsRng; -use std::{ - boxed::Box, - ffi::{CStr, CString}, - 
path::PathBuf, - slice, - str::FromStr, - sync::Arc, - time::Duration, +use tari_crypto::{ + inputs, + keys::{PublicKey as PublicKeyTrait, SecretKey}, + script, + tari_utilities::ByteArray, +}; +use tari_utilities::{hex, hex::Hex}; +use tokio::runtime::Runtime; + +use error::LibWalletError; +use tari_common_types::{ + emoji::{emoji_set, EmojiId, EmojiIdError}, + types::{ComSignature, PublicKey}, }; use tari_comms::{ multiaddr::Multiaddr, @@ -154,23 +157,12 @@ use tari_comms::{ types::CommsSecretKey, }; use tari_comms_dht::{DbConnectionUrl, DhtConfig}; -use tari_core::transactions::{ - tari_amount::MicroTari, - transaction::OutputFeatures, - types::{ComSignature, CryptoFactories, PublicKey}, -}; -use tari_crypto::{ - inputs, - keys::{PublicKey as PublicKeyTrait, SecretKey}, - script, - tari_utilities::ByteArray, -}; +use tari_core::transactions::{tari_amount::MicroTari, transaction::OutputFeatures, CryptoFactories}; use tari_p2p::{ transport::{TorConfig, TransportType, TransportType::Tor}, Network, }; use tari_shutdown::Shutdown; -use tari_utilities::{hex, hex::Hex}; use tari_wallet::{ contacts_service::storage::database::Contact, error::{WalletError, WalletStorageError}, @@ -195,13 +187,23 @@ use tari_wallet::{ }, }, types::ValidationRetryStrategy, - util::emoji::{emoji_set, EmojiId, EmojiIdError}, utxo_scanner_service::utxo_scanning::{UtxoScannerService, RECOVERY_KEY}, Wallet, WalletConfig, WalletSqlite, }; -use tokio::runtime::Runtime; + +use crate::{ + callback_handler::CallbackHandler, + enums::SeedWordPushResult, + error::{InterfaceError, TransactionError}, + tasks::recovery_event_monitoring, +}; + +mod callback_handler; +mod enums; +mod error; +mod tasks; const LOG_TARGET: &str = "wallet_ffi"; @@ -209,7 +211,7 @@ pub type TariTransportType = tari_p2p::transport::TransportType; pub type TariPublicKey = tari_comms::types::CommsPublicKey; pub type TariPrivateKey = tari_comms::types::CommsSecretKey; pub type TariCommsConfig = 
tari_p2p::initialization::CommsConfig; -pub type TariExcess = tari_core::transactions::types::Commitment; +pub type TariExcess = tari_common_types::types::Commitment; pub type TariExcessPublicNonce = tari_crypto::ristretto::RistrettoPublicKey; pub type TariExcessSignature = tari_crypto::ristretto::RistrettoSecretKey; @@ -917,14 +919,14 @@ pub unsafe extern "C" fn seed_words_push_word( (*seed_words).0.push(word_string); if (*seed_words).0.len() >= 24 { - if let Err(e) = TariPrivateKey::from_mnemonic(&(*seed_words).0) { + return if let Err(e) = TariPrivateKey::from_mnemonic(&(*seed_words).0) { log::error!(target: LOG_TARGET, "Problem building private key from seed phrase"); error = LibWalletError::from(e).code; ptr::swap(error_out, &mut error as *mut c_int); - return SeedWordPushResult::InvalidSeedPhrase as u8; + SeedWordPushResult::InvalidSeedPhrase as u8 } else { - return SeedWordPushResult::SeedPhraseComplete as u8; - } + SeedWordPushResult::SeedPhraseComplete as u8 + }; } SeedWordPushResult::SuccessfulPush as u8 @@ -2858,7 +2860,7 @@ pub unsafe extern "C" fn wallet_create( match TariPrivateKey::from_mnemonic(&(*seed_words).0) { Ok(private_key) => Some(private_key), Err(e) => { - error!(target: LOG_TARGET, "Mnemonic Error for given seed words: {}", e); + error!(target: LOG_TARGET, "Mnemonic Error for given seed words: {:?}", e); error = LibWalletError::from(e).code; ptr::swap(error_out, &mut error as *mut c_int); return ptr::null_mut(); @@ -2866,7 +2868,7 @@ pub unsafe extern "C" fn wallet_create( } }; - let mut runtime = match Runtime::new() { + let runtime = match Runtime::new() { Ok(r) => r, Err(e) => { error = LibWalletError::from(InterfaceError::TokioError(e.to_string())).code; @@ -2947,15 +2949,15 @@ pub unsafe extern "C" fn wallet_create( // lets ensure the wallet tor_id is saved, this could have been changed during wallet startup if let Some(hs) = w.comms.hidden_service() { if let Err(e) = runtime.block_on(w.db.set_tor_identity(hs.tor_identity().clone())) 
{ - warn!(target: LOG_TARGET, "Could not save tor identity to db: {}", e); + warn!(target: LOG_TARGET, "Could not save tor identity to db: {:?}", e); } } // Start Callback Handler let callback_handler = CallbackHandler::new( TransactionDatabase::new(transaction_backend), - w.transaction_service.get_event_stream_fused(), - w.output_manager_service.get_event_stream_fused(), - w.dht_service.subscribe_dht_events().fuse(), + w.transaction_service.get_event_stream(), + w.output_manager_service.get_event_stream(), + w.dht_service.subscribe_dht_events(), w.comms.shutdown_signal(), w.comms.node_identity().public_key().clone(), callback_received_transaction, @@ -5154,7 +5156,7 @@ pub unsafe extern "C" fn file_partial_backup( let runtime = Runtime::new(); match runtime { - Ok(mut runtime) => match runtime.block_on(partial_wallet_backup(original_path, backup_path)) { + Ok(runtime) => match runtime.block_on(partial_wallet_backup(original_path, backup_path)) { Ok(_) => (), Err(e) => { error = LibWalletError::from(WalletError::WalletStorageError(e)).code; @@ -5281,10 +5283,8 @@ pub unsafe extern "C" fn emoji_set_destroy(emoji_set: *mut EmojiSet) { pub unsafe extern "C" fn wallet_destroy(wallet: *mut TariWallet) { if !wallet.is_null() { let mut w = Box::from_raw(wallet); - match w.shutdown.trigger() { - Err(_) => error!(target: LOG_TARGET, "No listeners for the shutdown signal!"), - Ok(()) => w.runtime.block_on(w.wallet.wait_until_shutdown()), - } + w.shutdown.trigger(); + w.runtime.block_on(w.wallet.wait_until_shutdown()); } } @@ -5306,21 +5306,24 @@ pub unsafe extern "C" fn log_debug_message(msg: *const c_char) { #[cfg(test)] mod test { - use crate::*; - use libc::{c_char, c_uchar, c_uint}; use std::{ ffi::CString, path::Path, str::{from_utf8, FromStr}, sync::Mutex, }; + + use libc::{c_char, c_uchar, c_uint}; + use tempfile::tempdir; + + use tari_common_types::emoji; use tari_test_utils::random; use tari_wallet::{ 
storage::sqlite_utilities::run_migration_and_create_sqlite_connection, transaction_service::storage::models::TransactionStatus, - util::emoji, }; - use tempfile::tempdir; + + use crate::*; fn type_of(_: T) -> String { std::any::type_name::().to_string() @@ -5781,7 +5784,7 @@ mod test { error_ptr, ); - let mut runtime = Runtime::new().unwrap(); + let runtime = Runtime::new().unwrap(); let connection = run_migration_and_create_sqlite_connection(&sql_database_path).expect("Could not open Sqlite db"); diff --git a/base_layer/wallet_ffi/src/tasks.rs b/base_layer/wallet_ffi/src/tasks.rs index 9c44c94106..9e67eaa091 100644 --- a/base_layer/wallet_ffi/src/tasks.rs +++ b/base_layer/wallet_ffi/src/tasks.rs @@ -20,7 +20,6 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::StreamExt; use log::*; use tari_crypto::tari_utilities::hex::Hex; use tari_wallet::{error::WalletError, utxo_scanner_service::handle::UtxoScannerEvent}; @@ -44,8 +43,8 @@ pub async fn recovery_event_monitoring( recovery_join_handle: JoinHandle>, recovery_progress_callback: unsafe extern "C" fn(u8, u64, u64), ) { - while let Some(event) = event_stream.next().await { - match event { + loop { + match event_stream.recv().await { Ok(UtxoScannerEvent::ConnectingToBaseNode(peer)) => { unsafe { (recovery_progress_callback)(RecoveryEvent::ConnectingToBaseNode as u8, 0u64, 0u64); @@ -139,6 +138,9 @@ pub async fn recovery_event_monitoring( } warn!(target: LOG_TARGET, "UTXO Scanner failed and exited",); }, + Err(broadcast::error::RecvError::Closed) => { + break; + }, Err(e) => { // Event lagging warn!(target: LOG_TARGET, "{}", e); diff --git a/common/Cargo.toml b/common/Cargo.toml index aa9d646654..5998d0bdfd 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -22,7 +22,7 @@ dirs-next = "1.0.2" get_if_addrs = "0.5.3" log = "0.4.8" log4rs = { version = 
"1.0.0", default_features= false, features = ["config_parsing", "threshold_filter"]} -multiaddr={package="parity-multiaddr", version = "0.11.0"} +multiaddr={version = "0.13.0"} sha2 = "0.9.5" path-clean = "0.1.0" tari_storage = { version = "^0.9", path = "../infrastructure/storage"} @@ -36,7 +36,7 @@ opentelemetry-jaeger = { version="0.15", features=["rt-tokio"]} anyhow = { version = "1.0", optional = true } git2 = { version = "0.8", optional = true } -prost-build = { version = "0.6.1", optional = true } +prost-build = { version = "0.8.0", optional = true } toml = { version = "0.5", optional = true } [dev-dependencies] diff --git a/common/config/presets/tari_igor_config.toml b/common/config/presets/tari_igor_config.toml new file mode 100644 index 0000000000..efb6eedf47 --- /dev/null +++ b/common/config/presets/tari_igor_config.toml @@ -0,0 +1,535 @@ +######################################################################################################################## +# # +# The Tari Network Configuration File # +# # +######################################################################################################################## + +# This file carries all the configuration options for running Tari-related nodes and infrastructure in one single +# file. As you'll notice, almost all configuration options are commented out. This is because they are either not +# needed, are for advanced users that know what they want to tweak, or are already set at their default values. If +# things are working fine, then there's no need to change anything here. +# +# Each major section is clearly marked so that you can quickly find the section you're looking for. This first +# section holds configuration options that are common to all sections. + +# A note about Logging - The logger is initialised before the configuration file is loaded.
For this reason, logging +# is not configured here, but in `~/.tari/log4rs.yml` (*nix / OsX) or `%USERPROFILE%\.tari\log4rs.yml` (Windows) by +# default, or the location specified in the TARI_LOGFILE environment variable. + +[common] +# Select the network to connect to. Valid options are: +# mainnet - the "real" Tari network (default) +# igor - the Second Tari test net +network = "igor" + +# Tari is a 100% peer-to-peer network, so there are no servers to hold messages for you while you're offline. +# Instead, we rely on our peers to hold messages for us while we're offline. This setting sets the maximum size of the +# message cache used for holding our peers' messages, in MB. +#message_cache_size = 10 + +# When storing messages for peers, hold onto them for at most this long before discarding them. The default is 1440 +# minutes, or 24 hrs. +#message_cache_ttl = 1440 + +# If peer nodes spam you with messages, or are otherwise badly behaved, they will be added to your denylist and banned. +# You can set a time limit to release that ban (in minutes), or otherwise ban them for life (-1). The default is to +# ban them for 10 days. +#denylist_ban_period = 1440 + +# The number of liveness sessions to allow. Liveness sessions can be established by liveness monitors over TCP by +# sending a 0x50 (P) as the first byte. Any message sent must be followed by a newline and be no longer than +# 50 characters. That message will be echoed back. +#liveness_max_sessions = 0 +#liveness_allowlist_cidrs = ["127.0.0.1/32"] + +# The buffer size constants for the publish/subscribe connector channel, connecting comms messages to the domain layer: +# - Buffer size for the base node (min value = 30, default value = 1500). +#buffer_size_base_node = 1500 +# - Buffer size for the console wallet (min value = 300, default value = 50000). +#buffer_size_console_wallet = 50000 +# The rate limit constants for the publish/subscribe connector channel, i.e.
the maximum number of inbound messages to +# accept - any rate attempting to exceed this limit will be throttled. +# - Rate limit for the base node (min value = 5, default value = 1000). +#buffer_rate_limit_base_node = 1000 +# - Rate limit for the console wallet (min value = 5, default value = 1000). +buffer_rate_limit_console_wallet = 1000 +# The message deduplication persistent cache size - messages with these hashes in the cache will only be processed once. +# The cache will also be trimmed down to size periodically (min value = 0, default value = 2500). +dedup_cache_capacity = 25000 + +# The timeout (s) for requesting blocks from a peer during blockchain sync (min value = 10 s, default value = 150 s). +#fetch_blocks_timeout = 150 + +# The timeout (s) for requesting UTXOs from a base node (min value = 10 s, default value = 600 s). +#fetch_utxos_timeout = 600 + +# The timeout (s) for requesting other base node services (min value = 10 s, default value = 180 s). +#service_request_timeout = 180 + +# The maximum simultaneous comms RPC sessions allowed (default value = 1000). Setting this to -1 will allow unlimited +# sessions. +rpc_max_simultaneous_sessions = 10000 + +# Auto Update +# +# The interval in seconds to check for software updates. Setting this to 0 disables checking. +# auto_update.check_interval = 300 +# Customize the hosts that are used to check for updates. These hosts must contain update information in DNS TXT records. +# auto_update.dns_hosts = ["updates.tari.com"] +# Customize the location of the update SHA hashes and maintainer-signed signature.
+# auto_update.hashes_url = "https://.../hashes.txt" +# auto_update.hashes_sig_url = "https://.../hashes.txt.sig" + +######################################################################################################################## +# # +# Wallet Configuration Options # +# # +######################################################################################################################## + +# If you are not running a wallet from this configuration, you can simply leave everything in this section commented out + +[wallet] +# Override common.network for wallet +# network = "igor" + +# The relative folder to store your local key data and transaction history. DO NOT EVER DELETE THIS FILE unless you +# a) have backed up your seed phrase and +# b) know what you are doing! +wallet_db_file = "wallet/wallet.dat" +console_wallet_db_file = "wallet/console-wallet.dat" + +# Console wallet password +# Should you wish to start your console wallet without typing in your password, the following options are available: +# 1. Start the console wallet with the --password=secret argument, or +# 2. Set the environment variable TARI_WALLET_PASSWORD=secret before starting the console wallet, or +# 3. Set the "password" key in this [wallet] section of the config +# password = "secret" + +# WalletNotify +# Allows you to execute a script or program when these transaction events are received by the console wallet: +# - transaction received +# - transaction sent +# - transaction cancelled +# - transaction mined but unconfirmed +# - transaction mined and confirmed +# An example script is available here: applications/tari_console_wallet/src/notifier/notify_example.sh +# notify = "/path/to/script" + +# This is the timeout period that will be used to monitor TXO queries to the base node (default = 60). Larger values +# are needed for wallets with many (>1000) TXOs to be validated. 
+base_node_query_timeout = 180 +# The amount of seconds added to the current time (Utc) which will then be used to check if the message has +# expired or not when processing the message (default = 10800). +#saf_expiry_duration = 10800 +# This is the number of block confirmations required for a transaction to be considered completely mined and +# confirmed. (default = 3) +#transaction_num_confirmations_required = 3 +# This is the timeout period that will be used for base node broadcast monitoring tasks (default = 60) +transaction_broadcast_monitoring_timeout = 180 +# This is the timeout period that will be used for chain monitoring tasks (default = 60) +#transaction_chain_monitoring_timeout = 60 +# This is the timeout period that will be used for sending transactions directly (default = 20) +transaction_direct_send_timeout = 180 +# This is the timeout period that will be used for sending transactions via broadcast mode (default = 60) +transaction_broadcast_send_timeout = 180 +# This is the size of the event channel used to communicate transaction status events to the wallet's UI. A busy console +# wallet doing thousands of bulk payments or used for stress testing needs a fairly big size (>10000) (default = 1000). +transaction_event_channel_size = 25000 +# This is the size of the event channel used to communicate base node events to the wallet. A busy console +# wallet doing thousands of bulk payments or used for stress testing needs a fairly big size (>3000) (default = 250). +base_node_event_channel_size = 3500 +# This is the size of the event channel used to communicate output manager events to the wallet. A busy console +# wallet doing thousands of bulk payments or used for stress testing needs a fairly big size (>3000) (default = 250). +output_manager_event_channel_size = 3500 +# This is the size of the event channel used to communicate base node update events to the wallet. 
A busy console +# wallet doing thousands of bulk payments or used for stress testing needs a fairly big size (>300) (default = 50). +base_node_update_publisher_channel_size = 500 +# If a large amount of tiny valued uT UTXOs are used as inputs to a transaction, the fee may be larger than +# the transaction amount. Set this value to `false` to allow spending of "dust" UTXOs for small valued +# transactions (default = true). +#prevent_fee_gt_amount = false +# This option specifies the transaction routing mechanism as being directly between wallets, making +# use of store and forward or using any combination of these. +# (options: "DirectOnly", "StoreAndForwardOnly", DirectAndStoreAndForward". default: "DirectAndStoreAndForward"). +#transaction_routing_mechanism = "DirectAndStoreAndForward" + +# UTXO scanning service interval (default = 12 hours, i.e. 60 * 60 * 12 seconds) +scan_for_utxo_interval = 180 + +# When running the console wallet in command mode, use these values to determine what "stage" and timeout to wait +# for sent transactions. +# The stages are: +# - "DirectSendOrSaf" - The transaction was initiated and was accepted via Direct Send or Store And Forward. +# - "Negotiated" - The recipient replied and the transaction was negotiated. +# - "Broadcast" - The transaction was broadcast to the base node mempool. +# - "MinedUnconfirmed" - The transaction was successfully detected as mined but unconfirmed on the blockchain. +# - "Mined" - The transaction was successfully detected as mined and confirmed on the blockchain. + +# The default values are: "Broadcast", 300 +#command_send_wait_stage = "Broadcast" +#command_send_wait_timeout = 300 + +# The base nodes that the wallet should use for service requests and tracking chain state. +# base_node_service_peers = ["public_key::net_address", ...] 
+# base_node_service_peers = ["e856839057aac496b9e25f10821116d02b58f20129e9b9ba681b830568e47c4d::/onion3/exe2zgehnw3tvrbef3ep6taiacr6sdyeb54be2s25fpru357r4skhtad:18141"] + +# Configuration for the wallet's base node service +# The refresh interval, defaults to 10 seconds +base_node_service_refresh_interval = 30 +# The maximum age of service requests in seconds, requests older than this are discarded +base_node_service_request_max_age = 180 + +#[base_node.transport.tor] +#control_address = "/ip4/127.0.0.1/tcp/9051" +#control_auth_type = "none" # or "password" +# Required for control_auth_type = "password" +#control_auth_password = "super-secure-password" + +# Wallet configuration options for testnet +[wallet.igor] +# -------------- Transport configuration -------------- +# Use TCP to connect to the Tari network. This transport can only communicate with TCP/IP addresses, so peers with +# e.g. tor onion addresses will not be contactable. +#transport = "tcp" +# The address and port to listen for peer connections over TCP. +#tcp_listener_address = "/ip4/0.0.0.0/tcp/18188" +# Configures a tor proxy used to connect to onion addresses. All other traffic uses direct TCP connections. +# This setting is optional however, if it is not specified, this node will not be able to connect to nodes that +# only advertise an onion address. +#tcp_tor_socks_address = "/ip4/127.0.0.1/tcp/36050" +#tcp_tor_socks_auth = "none" + +# Configures the node to run over a tor hidden service using the Tor proxy. This transport recognises ip/tcp, +# onion v2, onion v3 and dns addresses. +transport = "tor" +# Address of the tor control server +tor_control_address = "/ip4/127.0.0.1/tcp/9051" +# Authentication to use for the tor control server +tor_control_auth = "none" # or "password=xxxxxx" +# The onion port to use. 
+#tor_onion_port = 18141 +# The address to which traffic on the node's onion address will be forwarded +# tor_forward_address = "/ip4/127.0.0.1/tcp/0" +# Instead of attempting to get the SOCKS5 address from the tor control port, use this one. The default is to +# use the first address returned by the tor control port (GETINFO /net/listeners/socks). +#tor_socks_address_override= + +# Use a SOCKS5 proxy transport. This transport recognises any addresses supported by the proxy. +#transport = "socks5" +# The address of the SOCKS5 proxy +#socks5_proxy_address = "/ip4/127.0.0.1/tcp/9050" +# The address to which traffic will be forwarded +#socks5_listener_address = "/ip4/127.0.0.1/tcp/18188" +#socks5_auth = "none" # or "username_password=username:xxxxxxx" + +# Optionally bind an additional TCP socket for inbound Tari P2P protocol comms. +# Use cases include: +# - allowing wallets to locally connect to their base node, rather than through tor, when used in conjunction with `tor_proxy_bypass_addresses` +# - multiple P2P addresses, one public over DNS and one private over TOR +# - a "bridge" between TOR and TCP-only nodes +# auxilary_tcp_listener_address = "/ip4/127.0.0.1/tcp/9998" + +# When these addresses are encountered when dialing another peer, the tor proxy is bypassed and the connection is made +# directly over TCP. /ip4, /ip6, /dns, /dns4 and /dns6 are supported. +# tor_proxy_bypass_addresses = ["/dns4/my-foo-base-node/tcp/9998"] + +######################################################################################################################## +# # +# Base Node Configuration Options # +# # +######################################################################################################################## + +# If you are not running a Tari Base node, you can simply leave everything in this section commented out.
Base nodes +help maintain the security of the Tari token and are the surest way to preserve your privacy and be 100% sure that +no-one is cheating you out of your money. + +[base_node] +# Override common.network for base node +# network = "igor" + +# Configuration options for testnet +[base_node.igor] +# The type of database backend to use. Currently supported options are "memory" and "lmdb". LMDB is recommended for +# almost all use cases. +db_type = "lmdb" + +# db config defaults +# db_init_size_mb = 1000 +# db_grow_size_mb = 500 +# db_resize_threshold_mb = 100 + +# The maximum number of orphans that can be stored in the Orphan block pool. Default value is "720". +# orphan_storage_capacity = 720 +# The size that the orphan pool will be allowed to grow before it is cleaned out, with the threshold being tested every +# time before blocks are fetched and added. Default value is "0", which indicates the orphan pool will not be cleaned out. +#orphan_db_clean_out_threshold = 0 +# The pruning horizon that indicates how many full blocks without pruning must be kept by the base node. Default value +# is "0", which indicates an archival node without any pruning. +#pruning_horizon = 0 + +# The number of messages that will be permitted in the flood ban timespan of 100s (Default igor = 1000, +# default mainnet = 10000) +flood_ban_max_msg_count = 10000 + +# The relative path to store persistent data +data_dir = "igor" + +# When first logging onto the Tari network, you need to find a few peers to bootstrap the process. In the absence of +# any servers, this is a little more challenging than usual. Our best strategy is just to try and connect to the peers +# you knew about last time you ran the software. But what about when you run the software for the first time? That's +# where this allowlist comes in. It's a list of known Tari nodes that are likely to be around for a long time and that +# new nodes can use to introduce themselves to the network.
+# peer_seeds = ["public_key1::address1", "public_key2::address2",... ]
+peer_seeds = [
+ "8e7eb81e512f3d6347bf9b1ca9cd67d2c8e29f2836fc5bd608206505cc72af34::/onion3/l4wouomx42nezhzexjdzfh7pcou5l7df24ggmwgekuih7tkv2rsaokqd:18141",
+ "00b35047a341401bcd336b2a3d564280a72f6dc72ec4c739d30c502acce4e803::/onion3/ojhxd7z6ga7qrvjlr3px66u7eiwasmffnuklscbh5o7g6wrbysj45vid:18141",
+ "40a9d8573745072534bce7d0ecafe882b1c79570375a69841c08a98dee9ecb5f::/onion3/io37fylc2pupg4cte4siqlsmuszkeythgjsxs2i3prm6jyz2dtophaad:18141",
+ "126c7ee64f71aca36398b977dd31fbbe9f9dad615df96473fb655bef5709c540::/onion3/6ilmgndocop7ybgmcvivbdsetzr5ggj4hhsivievoa2dx2b43wqlrlid:18141",
+]
+
+# This allowlist provides a method to force syncing from any known nodes you may choose, for example if you have a
+# couple of nodes that you always want to have in sync.
+# force_sync_peers = ["public_key1::address1", "public_key2::address2",... ]
+force_sync_peers = [
+ #my known peer 1
+ #"public_key1::address1",
+ #my known peer 2
+ #"public_key1::address1",
+]
+
+# DNS seeds
+# The DNS records in these hostnames should provide TXT records as per https://github.com/tari-project/tari/pull/2319
+# Enter a domain name for the TXT records: seeds.tari.com
+dns_seeds =["seeds.igor.tari.com"]
+# The name server used to resolve DNS seeds (Default: "1.1.1.1:53")
+# dns_seeds_name_server = "1.1.1.1:53"
+# Set to true to only accept DNS records that pass DNSSEC validation (Default: true)
+dns_seeds_use_dnssec = false
+
+# Determines the method of syncing blocks when the node is lagging. If you are not struggling with syncing, then
+# it is recommended to leave this setting as is. Available values are ViaBestChainMetadata and ViaRandomPeer.
+#block_sync_strategy="ViaBestChainMetadata"
+
+# Configure the maximum number of threads available for base node operation. These threads are spawned lazily, so a higher
+# number is recommended.
+# max_threads = 512
+
+# The number of threads to spawn and keep active at all times. The default is the number of cores available on this node.
+# core_threads =
+
+# The node's publicly-accessible hostname. This is the host name that is advertised on the network so that
+# peers can find you.
+# _NOTE_: If using the `tor` transport type, public_address will be ignored and an onion address will be
+# automatically configured
+#public_address = "/ip4/172.2.3.4/tcp/18189"
+
+# Whether to allow test addresses like 127.0.0.1 to be accepted
+allow_test_addresses = false
+
+# Enable the gRPC server for the base node. Set this to true if you want to enable third-party wallet software
+grpc_enabled = true
+# The socket to expose for the gRPC base node server. This value is ignored if grpc_enabled is false.
+# Valid values here are IPv4 and IPv6 TCP sockets, local unix sockets (e.g. "ipc://base-node-grpc.sock.100")
+grpc_base_node_address = "127.0.0.1:18142"
+# The socket to expose for the gRPC wallet server. This value is ignored if grpc_enabled is false.
+# Valid values here are IPv4 and IPv6 TCP sockets, local unix sockets (e.g. "ipc://base-node-grpc.sock.100")
+grpc_console_wallet_address = "127.0.0.1:18143"
+
+# A path to the file that stores your node identity and secret key
+base_node_identity_file = "config/base_node_id.json"
+
+# A path to the file that stores your console wallet's node identity and secret key
+console_wallet_identity_file = "config/console_wallet_id.json"
+
+# -------------- Transport configuration --------------
+# Use TCP to connect to the Tari network. This transport can only communicate with TCP/IP addresses, so peers with
+# e.g. tor onion addresses will not be contactable.
+#transport = "tcp"
+# The address and port to listen for peer connections over TCP.
+#tcp_listener_address = "/ip4/0.0.0.0/tcp/18189"
+# Configures a tor proxy used to connect to onion addresses. All other traffic uses direct TCP connections.
+# This setting is optional; however, if it is not specified, this node will not be able to connect to nodes that
+# only advertise an onion address.
+#tcp_tor_socks_address = "/ip4/127.0.0.1/tcp/36050"
+#tcp_tor_socks_auth = "none"
+
+# Configures the node to run over a tor hidden service using the Tor proxy. This transport recognises ip/tcp,
+# onion v2, onion v3 and dns addresses.
+transport = "tor"
+# Address of the tor control server
+tor_control_address = "/ip4/127.0.0.1/tcp/9051"
+# Authentication to use for the tor control server
+tor_control_auth = "none" # or "password=xxxxxx"
+# The onion port to use.
+#tor_onion_port = 18141
+# The address to which traffic on the node's onion address will be forwarded
+# tor_forward_address = "/ip4/127.0.0.1/tcp/0"
+# Instead of attempting to get the SOCKS5 address from the tor control port, use this one. The default is to
+# use the first address returned by the tor control port (GETINFO /net/listeners/socks).
+#tor_socks_address_override=
+
+# Use a SOCKS5 proxy transport. This transport recognises any addresses supported by the proxy.
+#transport = "socks5"
+# The address of the SOCKS5 proxy
+#socks5_proxy_address = "/ip4/127.0.0.1/tcp/9050"
+# The address to which traffic will be forwarded
+#socks5_listener_address = "/ip4/127.0.0.1/tcp/18189"
+#socks5_auth = "none" # or "username_password=username:xxxxxxx"
+
+# A path to the file that stores the tor hidden service private key, if using the tor transport.
+base_node_tor_identity_file = "config/base_node_tor.json"
+
+# A path to the file that stores the console wallet's tor hidden service private key, if using the tor transport.
+console_wallet_tor_identity_file = "config/console_wallet_tor.json"
+
+# Optionally bind an additional TCP socket for inbound Tari P2P protocol comms.
+# Use cases include:
+# - allowing wallets to locally connect to their base node, rather than through tor, when used in conjunction with `tor_proxy_bypass_addresses`
+# - multiple P2P addresses, one public over DNS and one private over TOR
+# - a "bridge" between TOR and TCP-only nodes
+# auxilary_tcp_listener_address = "/ip4/127.0.0.1/tcp/9998"
+
+# When these addresses are encountered when dialing another peer, the tor proxy is bypassed and the connection is made
+# directly over TCP. /ip4, /ip6, /dns, /dns4 and /dns6 are supported.
+# tor_proxy_bypass_addresses = ["/dns4/my-foo-base-node/tcp/9998"]
+
+########################################################################################################################
+#                                                                                                                      #
+#                                            Mempool Configuration Options                                             #
+#                                                                                                                      #
+########################################################################################################################
+[mempool.igor]
+
+# The maximum number of transactions that can be stored in the Unconfirmed Transaction pool. This is the main waiting
+# area in the mempool and almost all transactions will end up in this pool before being mined. It's for this reason
+# that this parameter will have the greatest impact on actual memory usage by your mempool. If you are not mining,
+# you can reduce this parameter to reduce memory consumption by your node, at the expense of network bandwidth. For
+# reference, a single block can hold about 4,000 transactions
+# Default = 40,000 transactions
+# unconfirmed_pool_storage_capacity = 40000
+
+# The maximum number of transactions that can be stored in the Orphan Transaction pool. This pool keeps transactions
+# that are 'orphans', i.e. transactions with inputs that don't exist in the UTXO set. If you're not mining, and
+# memory usage is a concern, this can safely be set to zero. Even so, orphan transactions do not appear that often
+# (it's usually a short chain of spends that are broadcast in quick succession). The other potential source of orphan
+# transactions is DoS attacks, and setting the `tx_ttl` parameter to a low value is an effective countermeasure
+# in this case. Default: 250 transactions
+# orphan_pool_storage_capacity = 250
+
+# The maximum amount of time an orphan transaction will be permitted to stay in the mempool before being rejected.
+# This should be set long enough to allow the parent transaction to arrive, but low enough to thwart
+# DoS attacks. Default: 300 seconds
+#orphan_tx_ttl = 300
+
+# The maximum number of transactions that can be stored in the Pending Transaction pool. This pool holds transactions
+# that are valid, but cannot be included in a block yet because a consensus rule is holding them back, usually a
+# time lock. Once the conditions holding the transaction in the pending pool are resolved, the transaction will move
+# into the unconfirmed pool. Default: 5,000 transactions
+# pending_pool_storage_capacity = 5000
+
+# The ReorgPool consists of all transactions that have recently been added to blocks.
+# When a potential blockchain reorganization occurs the transactions can be recovered from the ReorgPool and can be
+# added back into the UnconfirmedPool. Transactions in the ReOrg pool have a limited time-to-live and will be removed
+# from the pool when the time-to-live threshold is reached. Also, when the capacity of the pool has been reached, the
+# oldest transactions will be removed to make space for incoming transactions. The pool capacity and TTL parameters
+# have the same meaning as those for the pending pool, but applied to the reorg pool.
+# Defaults: 10,000 transactions and 300 seconds
+#reorg_pool_storage_capacity = 10_000
+#reorg_tx_ttl = 300
+
+# The maximum number of transactions that can be skipped when compiling a set of highest priority transactions;
+# skipping over large transactions is performed in an attempt to fit more transactions into the remaining space.
+# This parameter only affects mining nodes. You can ignore it if you are only running a base node. Even so, changing
+# this parameter should not affect profitability in any meaningful way, since the transaction weights are selected to
+# closely mirror how much block space they take up.
+#weight_tx_skip_count = 20
+
+########################################################################################################################
+#                                                                                                                      #
+#                                        Validator Node Configuration Options                                          #
+#                                                                                                                      #
+########################################################################################################################
+
+# If you are not running a validator node, you can simply leave everything in this section commented out.
+
+[validator_node]
+
+# Enable the gRPC server for the validator node.
+#grpc_enabled = false
+
+# The socket to expose for the gRPC validator node server. This value is ignored if grpc_enabled is false.
+# Valid values here are IPv4 and IPv6 TCP sockets, local unix sockets (e.g. "ipc://base-node-grpc.sock.100")
+#grpc_address = "127.0.0.1:18042"
+
+########################################################################################################################
+#                                                                                                                      #
+#                                        Merge Mining Configuration Options                                            #
+#                                                                                                                      #
+########################################################################################################################
+
+[merge_mining_proxy.igor]
+
+# URL to monerod
+monerod_url = "http://monero-stagenet.exan.tech:38081" # stagenet
+#monerod_url = "http://18.133.59.45:28081" # testnet
+#monerod_url = "http://18.132.124.81:18081" # mainnet
+#monerod_url = "http://monero.exan.tech:18081" # mainnet alternative
+
+# Address of the tari_merge_mining_proxy application
+proxy_host_address = "127.0.0.1:7878"
+
+# In solo merged mining the block solution is submitted to the Monero blockchain (monerod) as well as to the Tari
+# blockchain, so this setting should be "true". In pool merged mining there is no point in submitting the solution
+# to the Monero blockchain, since the pool does that, so this setting should be "false". (default = true).
+proxy_submit_to_origin = true
+
+# If authentication is being used for curl
+monerod_use_auth = false
+
+# Username for curl
+monerod_username = ""
+
+# Password for curl
+monerod_password = ""
+
+# The merge mining proxy can either wait for the base node to achieve initial sync at startup before it enables mining,
+# or not. If merge mining starts before the base node has achieved initial sync, those Tari mined blocks will not be
+# accepted. (Default value = true; will wait for base node initial sync).
+#wait_for_initial_sync_at_startup = true + +[stratum_transcoder] + +# Address of the tari_stratum_transcoder application +transcoder_host_address = "127.0.0.1:7879" + +[mining_node] +# Number of mining threads +# Default: number of logical CPU cores +#num_mining_threads=8 + +# GRPC address of base node +# Default: value from `base_node.grpc_base_node_address` +#base_node_grpc_address = "127.0.0.1:18142" + +# GRPC address of console wallet +# Default: value from `base_node.grpc_console_wallet_address` +#wallet_grpc_address = "127.0.0.1:18143" + +# Start mining only when base node is bootstrapped +# and current block height is on the tip of network +# Default: true +#mine_on_tip_only=true + +# Will check tip with node every N seconds and restart mining +# if height already taken and option `mine_on_tip_only` is set +# to true +# Default: 30 seconds +#validate_tip_timeout_sec=30 + +# Stratum Mode configuration +# mining_pool_address = "miningcore.tarilabs.com:3052" +# mining_wallet_address = "YOUR_WALLET_PUBLIC_KEY" +# mining_worker_name = "worker1" diff --git a/common/logging/log4rs_sample_mining_node.yml b/common/logging/log4rs_sample_mining_node.yml index f0c8a965b8..16c4c43739 100644 --- a/common/logging/log4rs_sample_mining_node.yml +++ b/common/logging/log4rs_sample_mining_node.yml @@ -14,10 +14,6 @@ appenders: kind: console encoder: pattern: "{d(%Y-%m-%d %H:%M:%S.%f)} [{t}] {h({l}):5} {m}{n}" - filters: - - - kind: threshold - level: warn # An appender named "base_layer" that writes to a file with a custom pattern encoder mining_node: kind: rolling_file @@ -37,9 +33,22 @@ appenders: # Set the default logging level to "warn" and attach the "stdout" appender to the root root: - level: info + level: warn appenders: - stdout - - mining_node + +loggers: + # mining_node + tari::application: + level: debug + appenders: + - mining_node + additive: false + tari_mining_node: + level: debug + appenders: + - mining_node + - stdout + additive: false diff --git 
a/common/src/configuration/bootstrap.rs b/common/src/configuration/bootstrap.rs index 6e0a668541..79c98f8d97 100644 --- a/common/src/configuration/bootstrap.rs +++ b/common/src/configuration/bootstrap.rs @@ -148,6 +148,9 @@ pub struct ConfigBootstrap { pub miner_max_diff: Option, #[structopt(long, alias = "tracing")] pub tracing_enabled: bool, + /// Supply a network (overrides existing configuration) + #[structopt(long, alias = "network")] + pub network: Option, } fn normalize_path(path: PathBuf) -> PathBuf { @@ -183,6 +186,7 @@ impl Default for ConfigBootstrap { miner_min_diff: None, miner_max_diff: None, tracing_enabled: false, + network: None, } } } diff --git a/common/src/configuration/global.rs b/common/src/configuration/global.rs index 880e111cd1..be8cb38038 100644 --- a/common/src/configuration/global.rs +++ b/common/src/configuration/global.rs @@ -71,7 +71,6 @@ pub struct GlobalConfig { pub pruning_horizon: u64, pub pruned_mode_cleanup_interval: u64, pub core_threads: Option, - pub max_threads: Option, pub base_node_identity_file: PathBuf, pub public_address: Multiaddr, pub grpc_enabled: bool, @@ -137,6 +136,7 @@ pub struct GlobalConfig { pub mining_pool_address: String, pub mining_wallet_address: String, pub mining_worker_name: String, + pub base_node_bypass_range_proof_verification: bool, } impl GlobalConfig { @@ -270,10 +270,6 @@ fn convert_node_config( let core_threads = optional(cfg.get_int(&key).map(|n| n as usize)).map_err(|e| ConfigurationError::new(&key, &e.to_string()))?; - let key = config_string("base_node", &net_str, "max_threads"); - let max_threads = - optional(cfg.get_int(&key).map(|n| n as usize)).map_err(|e| ConfigurationError::new(&key, &e.to_string()))?; - // Max RandomX VMs let key = config_string("base_node", &net_str, "max_randomx_vms"); let max_randomx_vms = optional(cfg.get_int(&key).map(|n| n as usize)) @@ -376,6 +372,8 @@ fn convert_node_config( s.parse::() .map_err(|e| ConfigurationError::new(&key, &e.to_string())) })?; + let key 
= config_string("base_node", &net_str, "bypass_range_proof_verification"); + let base_node_bypass_range_proof_verification = cfg.get_bool(&key).unwrap_or(false); let key = config_string("base_node", &net_str, "dns_seeds_use_dnssec"); let dns_seeds_use_dnssec = cfg @@ -712,7 +710,6 @@ fn convert_node_config( pruning_horizon, pruned_mode_cleanup_interval, core_threads, - max_threads, base_node_identity_file, public_address, grpc_enabled, @@ -778,6 +775,7 @@ fn convert_node_config( mining_pool_address, mining_wallet_address, mining_worker_name, + base_node_bypass_range_proof_verification, }) } diff --git a/common/src/configuration/network.rs b/common/src/configuration/network.rs index c8a0d3fe4a..1498b1e623 100644 --- a/common/src/configuration/network.rs +++ b/common/src/configuration/network.rs @@ -37,6 +37,7 @@ pub enum Network { Ridcully = 0x21, Stibbons = 0x22, Weatherwax = 0x23, + Igor = 0x24, } impl Network { @@ -51,6 +52,7 @@ impl Network { Ridcully => "ridcully", Stibbons => "stibbons", Weatherwax => "weatherwax", + Igor => "igor", LocalNet => "localnet", } } @@ -73,6 +75,7 @@ impl FromStr for Network { "weatherwax" => Ok(Weatherwax), "mainnet" => Ok(MainNet), "localnet" => Ok(LocalNet), + "igor" => Ok(Igor), invalid => Err(ConfigurationError::new( "network", &format!("Invalid network option: {}", invalid), diff --git a/common/src/configuration/utils.rs b/common/src/configuration/utils.rs index 1f291cf0ce..3814deb8d9 100644 --- a/common/src/configuration/utils.rs +++ b/common/src/configuration/utils.rs @@ -192,7 +192,7 @@ pub fn default_config(bootstrap: &ConfigBootstrap) -> Config { .unwrap(); cfg.set_default( "base_node.weatherwax.data_dir", - default_subdir("stibbons/", Some(&bootstrap.base_path)), + default_subdir("weatherwax/", Some(&bootstrap.base_path)), ) .unwrap(); cfg.set_default( @@ -228,7 +228,6 @@ pub fn default_config(bootstrap: &ConfigBootstrap) -> Config { .unwrap(); cfg.set_default("base_node.weatherwax.grpc_console_wallet_address", 
"127.0.0.1:18143") .unwrap(); - cfg.set_default("base_node.weatherwax.dns_seeds_name_server", "1.1.1.1:53") .unwrap(); cfg.set_default("base_node.weatherwax.dns_seeds_use_dnssec", true) @@ -238,6 +237,28 @@ pub fn default_config(bootstrap: &ConfigBootstrap) -> Config { cfg.set_default("wallet.base_node_service_peers", Vec::::new()) .unwrap(); + //---------------------------------- Igor Defaults --------------------------------------------// + + cfg.set_default("base_node.igor.db_type", "lmdb").unwrap(); + cfg.set_default("base_node.igor.orphan_storage_capacity", 720).unwrap(); + cfg.set_default("base_node.igor.orphan_db_clean_out_threshold", 0) + .unwrap(); + cfg.set_default("base_node.igor.pruning_horizon", 0).unwrap(); + cfg.set_default("base_node.igor.pruned_mode_cleanup_interval", 50) + .unwrap(); + cfg.set_default("base_node.igor.flood_ban_max_msg_count", 1000).unwrap(); + cfg.set_default("base_node.igor.public_address", format!("{}/tcp/18141", local_ip_addr)) + .unwrap(); + cfg.set_default("base_node.igor.grpc_enabled", false).unwrap(); + cfg.set_default("base_node.igor.grpc_base_node_address", "127.0.0.1:18142") + .unwrap(); + cfg.set_default("base_node.igor.grpc_console_wallet_address", "127.0.0.1:18143") + .unwrap(); + cfg.set_default("base_node.igor.dns_seeds_name_server", "1.1.1.1:53") + .unwrap(); + cfg.set_default("base_node.igor.dns_seeds_use_dnssec", true).unwrap(); + cfg.set_default("base_node.igor.auto_ping_interval", 30).unwrap(); + set_transport_defaults(&mut cfg).unwrap(); set_merge_mining_defaults(&mut cfg); set_mining_node_defaults(&mut cfg); @@ -254,6 +275,8 @@ fn set_stratum_transcoder_defaults(cfg: &mut Config) { "127.0.0.1:7879", ) .unwrap(); + cfg.set_default("stratum_transcoder.igor.transcoder_host_address", "127.0.0.1:7879") + .unwrap(); } fn set_merge_mining_defaults(cfg: &mut Config) { @@ -289,6 +312,16 @@ fn set_merge_mining_defaults(cfg: &mut Config) { .unwrap(); 
cfg.set_default("merge_mining_proxy.weatherwax.wait_for_initial_sync_at_startup", true) .unwrap(); + cfg.set_default("merge_mining_proxy.igor.proxy_host_address", "127.0.0.1:7878") + .unwrap(); + cfg.set_default("merge_mining_proxy.igor.proxy_submit_to_origin", true) + .unwrap(); + cfg.set_default("merge_mining_proxy.igor.monerod_use_auth", "false") + .unwrap(); + cfg.set_default("merge_mining_proxy.igor.monerod_username", "").unwrap(); + cfg.set_default("merge_mining_proxy.igor.monerod_password", "").unwrap(); + cfg.set_default("merge_mining_proxy.igor.wait_for_initial_sync_at_startup", true) + .unwrap(); } fn set_mining_node_defaults(cfg: &mut Config) { @@ -372,6 +405,18 @@ fn set_transport_defaults(cfg: &mut Config) -> Result<(), config::ConfigError> { )?; cfg.set_default(&format!("{}.weatherwax.socks5_auth", app), "none")?; + + // igor + cfg.set_default(&format!("{}.igor.transport", app), "tor")?; + + cfg.set_default(&format!("{}.igor.tor_control_address", app), "/ip4/127.0.0.1/tcp/9051")?; + cfg.set_default(&format!("{}.igor.tor_control_auth", app), "none")?; + cfg.set_default(&format!("{}.igor.tor_forward_address", app), "/ip4/127.0.0.1/tcp/0")?; + cfg.set_default(&format!("{}.igor.tor_onion_port", app), "18141")?; + + cfg.set_default(&format!("{}.igor.socks5_proxy_address", app), "/ip4/0.0.0.0/tcp/9150")?; + + cfg.set_default(&format!("{}.igor.socks5_auth", app), "none")?; } Ok(()) } diff --git a/common/src/dns/tests.rs b/common/src/dns/tests.rs index 955f22cf97..b7dc087517 100644 --- a/common/src/dns/tests.rs +++ b/common/src/dns/tests.rs @@ -48,7 +48,7 @@ use trust_dns_client::rr::{rdata, RData, Record, RecordType}; // Ignore as this test requires network IO #[ignore] -#[tokio_macros::test] +#[tokio::test] async fn it_returns_an_empty_vec_if_all_seeds_are_invalid() { let mut resolver = PeerSeedResolver::connect("1.1.1.1:53".parse().unwrap()).await.unwrap(); let seeds = resolver.resolve("tari.com").await.unwrap(); @@ -64,7 +64,7 @@ fn 
create_txt_record(contents: Vec) -> Record { } #[allow(clippy::vec_init_then_push)] -#[tokio_macros::test] +#[tokio::test] async fn it_returns_peer_seeds() { let mut records = Vec::new(); // Multiple addresses(works) diff --git a/common/src/lib.rs b/common/src/lib.rs index 6f5c98a2e4..cb2c99d0c1 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -72,7 +72,7 @@ //! let config = args.load_configuration().unwrap(); //! let global = GlobalConfig::convert_from(ApplicationType::BaseNode, config).unwrap(); //! assert_eq!(global.network, Network::Weatherwax); -//! assert!(global.max_threads.is_none()); +//! assert!(global.core_threads.is_none()); //! # std::fs::remove_dir_all(temp_dir).unwrap(); //! ``` diff --git a/comms/Cargo.toml b/comms/Cargo.toml index 71b8eec12f..b781fa12bf 100644 --- a/comms/Cargo.toml +++ b/comms/Cargo.toml @@ -12,34 +12,36 @@ edition = "2018" [dependencies] tari_crypto = "0.11.1" tari_storage = { version = "^0.9", path = "../infrastructure/storage" } -tari_shutdown = { version="^0.9", path = "../infrastructure/shutdown" } +tari_shutdown = { version = "^0.9", path = "../infrastructure/shutdown" } +anyhow = "1.0.32" async-trait = "0.1.36" bitflags = "1.0.4" blake2 = "0.9.0" -bytes = { version = "0.5.x", features=["serde"] } +bytes = { version = "1", features = ["serde"] } chrono = { version = "0.4.6", features = ["serde"] } cidr = "0.1.0" clear_on_drop = "=0.2.4" data-encoding = "2.2.0" digest = "0.9.0" -futures = { version = "^0.3", features = ["async-await"]} +futures = { version = "^0.3", features = ["async-await"] } lazy_static = "1.3.0" lmdb-zero = "0.4.4" log = { version = "0.4.0", features = ["std"] } -multiaddr = {version = "=0.11.0", package = "parity-multiaddr"} -nom = {version = "5.1.0", features=["std"], default-features=false} +multiaddr = { version = "0.13.0" } +nom = { version = "5.1.0", features = ["std"], default-features = false } openssl = { version = "0.10", features = ["vendored"] } -pin-project = "0.4.17" -prost = 
"=0.6.1" +pin-project = "1.0.8" +prost = "=0.8.0" rand = "0.8" serde = "1.0.119" serde_derive = "1.0.119" -snow = {version="=0.8.0", features=["default-resolver"]} -thiserror = "1.0.20" -tokio = {version="~0.2.19", features=["blocking", "time", "tcp", "dns", "sync", "stream", "signal"]} -tokio-util = {version="0.3.1", features=["codec"]} -tower= "0.3.1" +snow = { version = "=0.8.0", features = ["default-resolver"] } +thiserror = "1.0.26" +tokio = { version = "1.10", features = ["rt-multi-thread", "time", "sync", "signal", "net", "macros", "io-util"] } +tokio-stream = { version = "0.1.7", features = ["sync"] } +tokio-util = { version = "0.6.7", features = ["codec", "compat"] } +tower = "0.3.1" tracing = "0.1.26" tracing-futures = "0.2.5" yamux = "=0.9.0" @@ -49,20 +51,19 @@ opentelemetry = { version = "0.16", default-features = false, features = ["trace opentelemetry-jaeger = { version="0.15", features=["rt-tokio"]} # RPC dependencies -tower-make = {version="0.3.0", optional=true} -anyhow = "1.0.32" +tower-make = { version = "0.3.0", optional = true } [dev-dependencies] -tari_test_utils = {version="^0.9", path="../infrastructure/test_utils"} -tari_comms_rpc_macros = {version="*", path="./rpc_macros"} +tari_test_utils = { version = "^0.9", path = "../infrastructure/test_utils" } +tari_comms_rpc_macros = { version = "*", path = "./rpc_macros" } env_logger = "0.7.0" serde_json = "1.0.39" -tokio-macros = "0.2.3" +#tokio = {version="1.8", features=["macros"]} tempfile = "3.1.0" [build-dependencies] -tari_common = { version = "^0.9", path="../common", features = ["build"]} +tari_common = { version = "^0.9", path = "../common", features = ["build"] } [features] avx2 = ["tari_crypto/avx2"] diff --git a/comms/dht/Cargo.toml b/comms/dht/Cargo.toml index c75e423543..77cf0d9e9d 100644 --- a/comms/dht/Cargo.toml +++ b/comms/dht/Cargo.toml @@ -10,58 +10,57 @@ license = "BSD-3-Clause" edition = "2018" [dependencies] -tari_comms = { version = "^0.9", path = "../", features = 
["rpc"]} -tari_comms_rpc_macros = { version = "^0.9", path = "../rpc_macros"} +tari_comms = { version = "^0.9", path = "../", features = ["rpc"] } +tari_comms_rpc_macros = { version = "^0.9", path = "../rpc_macros" } tari_crypto = "0.11.1" -tari_utilities = { version = "^0.3" } -tari_shutdown = { version = "^0.9", path = "../../infrastructure/shutdown"} -tari_storage = { version = "^0.9", path = "../../infrastructure/storage"} +tari_utilities = { version = "^0.3" } +tari_shutdown = { version = "^0.9", path = "../../infrastructure/shutdown" } +tari_storage = { version = "^0.9", path = "../../infrastructure/storage" } anyhow = "1.0.32" bitflags = "1.2.0" -bytes = "0.4.12" +bytes = "0.5" chacha20 = "0.7.1" chrono = "0.4.9" -diesel = {version="1.4.7", features = ["sqlite", "serde_json", "chrono", "numeric"]} +diesel = { version = "1.4.7", features = ["sqlite", "serde_json", "chrono", "numeric"] } diesel_migrations = "1.4.0" -libsqlite3-sys = { version = ">=0.8.0, <0.13.0", features = ["bundled"], optional=true } +libsqlite3-sys = { version = ">=0.8.0, <0.13.0", features = ["bundled"], optional = true } digest = "0.9.0" -futures= {version= "^0.3.1"} +futures = { version = "^0.3.1" } log = "0.4.8" -prost = "=0.6.1" -prost-types = "=0.6.1" +prost = "=0.8.0" +prost-types = "=0.8.0" rand = "0.8" serde = "1.0.90" serde_derive = "1.0.90" serde_repr = "0.1.5" -thiserror = "1.0.20" -tokio = {version="0.2.10", features=["rt-threaded", "blocking"]} -tower= "0.3.1" +thiserror = "1.0.26" +tokio = { version = "1.10", features = ["rt", "macros"] } +tower = "0.3.1" ttl_cache = "0.5.1" # tower-filter dependencies pin-project = "0.4" [dev-dependencies] -tari_test_utils = { version = "^0.9", path = "../../infrastructure/test_utils"} +tari_test_utils = { version = "^0.9", path = "../../infrastructure/test_utils" } env_logger = "0.7.0" -futures-test = { version = "0.3.0-alpha.19", package = "futures-test-preview" } +futures-test = { version = "0.3.5" } lmdb-zero = "0.4.4" tempfile = 
"3.1.0" -tokio-macros = "0.2.3" +tokio-stream = { version = "0.1.7", features = ["sync"] } petgraph = "0.5.1" clap = "2.33.0" # tower-filter dependencies tower-test = { version = "^0.3" } -tokio-test = "^0.2" -tokio = "^0.2" +tokio-test = "^0.4.2" futures-util = "^0.3.1" lazy_static = "1.4.0" [build-dependencies] -tari_common = { version = "^0.9", path="../../common"} +tari_common = { version = "^0.9", path = "../../common" } [features] test-mocks = [] diff --git a/comms/dht/examples/graphing_utilities/utilities.rs b/comms/dht/examples/graphing_utilities/utilities.rs index bfb081eb21..3a4fea7ae9 100644 --- a/comms/dht/examples/graphing_utilities/utilities.rs +++ b/comms/dht/examples/graphing_utilities/utilities.rs @@ -32,6 +32,7 @@ use petgraph::{ }; use std::{collections::HashMap, convert::TryFrom, fs, fs::File, io::Write, path::Path, process::Command, sync::Mutex}; use tari_comms::{connectivity::ConnectivitySelection, peer_manager::NodeId}; +use tari_test_utils::streams::convert_unbounded_mpsc_to_stream; const TEMP_GRAPH_OUTPUT_DIR: &str = "/tmp/memorynet_temp"; @@ -277,7 +278,9 @@ pub enum PythonRenderType { /// This function will drain the message event queue and then build a message propagation tree assuming the first sender /// is the starting node pub async fn track_join_message_drain_messaging_events(messaging_rx: &mut NodeEventRx) -> StableGraph { - let drain_fut = DrainBurst::new(messaging_rx); + let stream = convert_unbounded_mpsc_to_stream(messaging_rx); + tokio::pin!(stream); + let drain_fut = DrainBurst::new(&mut stream); let messages = drain_fut.await; let num_messages = messages.len(); diff --git a/comms/dht/examples/memory_net/drain_burst.rs b/comms/dht/examples/memory_net/drain_burst.rs index d2f5bce2be..59ed8723f4 100644 --- a/comms/dht/examples/memory_net/drain_burst.rs +++ b/comms/dht/examples/memory_net/drain_burst.rs @@ -42,7 +42,7 @@ where St: ?Sized + Stream + Unpin let (lower_bound, upper_bound) = stream.size_hint(); Self { inner: stream, 
- collection: Vec::with_capacity(upper_bound.or(Some(lower_bound)).unwrap()), + collection: Vec::with_capacity(upper_bound.unwrap_or(lower_bound)), } } } @@ -70,15 +70,16 @@ where St: ?Sized + Stream + Unpin mod test { use super::*; use futures::stream; + use tari_comms::runtime; - #[tokio_macros::test_basic] + #[runtime::test] async fn drain_terminating_stream() { let mut stream = stream::iter(1..10u8); let burst = DrainBurst::new(&mut stream).await; assert_eq!(burst, (1..10u8).into_iter().collect::>()); } - #[tokio_macros::test_basic] + #[runtime::test] async fn drain_stream_with_pending() { let mut stream = stream::iter(1..10u8); let burst = DrainBurst::new(&mut stream).await; diff --git a/comms/dht/examples/memory_net/utilities.rs b/comms/dht/examples/memory_net/utilities.rs index bf1bf03ae4..0d4a675e8e 100644 --- a/comms/dht/examples/memory_net/utilities.rs +++ b/comms/dht/examples/memory_net/utilities.rs @@ -22,7 +22,7 @@ #![allow(clippy::mutex_atomic)] use crate::memory_net::DrainBurst; -use futures::{channel::mpsc, future, StreamExt}; +use futures::future; use lazy_static::lazy_static; use rand::{rngs::OsRng, Rng}; use std::{ @@ -62,8 +62,13 @@ use tari_storage::{ lmdb_store::{LMDBBuilder, LMDBConfig}, LMDBWrapper, }; -use tari_test_utils::{paths::create_temporary_data_path, random}; -use tokio::{runtime, sync::broadcast, task, time}; +use tari_test_utils::{paths::create_temporary_data_path, random, streams::convert_unbounded_mpsc_to_stream}; +use tokio::{ + runtime, + sync::{broadcast, mpsc}, + task, + time, +}; use tower::ServiceBuilder; pub type NodeEventRx = mpsc::UnboundedReceiver<(NodeId, NodeId)>; @@ -154,7 +159,7 @@ pub async fn discovery(wallets: &[TestNode], messaging_events_rx: &mut NodeEvent start.elapsed() ); - time::delay_for(Duration::from_secs(5)).await; + time::sleep(Duration::from_secs(5)).await; total_messages += drain_messaging_events(messaging_events_rx, false).await; }, Err(err) => { @@ -166,7 +171,7 @@ pub async fn discovery(wallets: 
&[TestNode], messaging_events_rx: &mut NodeEvent err ); - time::delay_for(Duration::from_secs(5)).await; + time::sleep(Duration::from_secs(5)).await; total_messages += drain_messaging_events(messaging_events_rx, false).await; }, } @@ -298,7 +303,7 @@ pub async fn do_network_wide_propagation(nodes: &mut [TestNode], origin_node_ind let node_name = node.name.clone(); task::spawn(async move { - let result = time::timeout(Duration::from_secs(30), ims_rx.next()).await; + let result = time::timeout(Duration::from_secs(30), ims_rx.recv()).await; let mut is_success = false; match result { Ok(Some(msg)) => { @@ -450,21 +455,23 @@ pub async fn do_store_and_forward_message_propagation( for (idx, mut s) in neighbour_subs.into_iter().enumerate() { let neighbour = neighbours[idx].name.clone(); task::spawn(async move { - let msg = time::timeout(Duration::from_secs(2), s.next()).await; + let msg = time::timeout(Duration::from_secs(2), s.recv()).await; match msg { - Ok(Some(Ok(evt))) => { + Ok(Ok(evt)) => { if let MessagingEvent::MessageReceived(_, tag) = &*evt { println!("{} received propagated SAF message ({})", neighbour, tag); } }, - Ok(_) => {}, + Ok(Err(err)) => { + println!("{}", err); + }, Err(_) => println!("{} did not receive the SAF message", neighbour), } }); } banner!("⏰ Waiting a few seconds for messages to propagate around the network..."); - time::delay_for(Duration::from_secs(5)).await; + time::sleep(Duration::from_secs(5)).await; let mut total_messages = drain_messaging_events(messaging_rx, false).await; @@ -515,7 +522,7 @@ pub async fn do_store_and_forward_message_propagation( let mut num_msgs = 0; let mut succeeded = 0; loop { - let result = time::timeout(Duration::from_secs(10), wallet.ims_rx.as_mut().unwrap().next()).await; + let result = time::timeout(Duration::from_secs(10), wallet.ims_rx.as_mut().unwrap().recv()).await; num_msgs += 1; match result { Ok(msg) => { @@ -554,7 +561,9 @@ pub async fn do_store_and_forward_message_propagation( } pub async fn 
drain_messaging_events(messaging_rx: &mut NodeEventRx, show_logs: bool) -> usize { - let drain_fut = DrainBurst::new(messaging_rx); + let stream = convert_unbounded_mpsc_to_stream(messaging_rx); + tokio::pin!(stream); + let drain_fut = DrainBurst::new(&mut stream); if show_logs { let messages = drain_fut.await; let num_messages = messages.len(); @@ -694,42 +703,46 @@ impl TestNode { fn spawn_event_monitor( comms: &CommsNode, - messaging_events: MessagingEventReceiver, + mut messaging_events: MessagingEventReceiver, events_tx: mpsc::Sender>, messaging_events_tx: NodeEventTx, quiet_mode: bool, ) { - let conn_man_event_sub = comms.subscribe_connection_manager_events(); + let mut conn_man_event_sub = comms.subscribe_connection_manager_events(); let executor = runtime::Handle::current(); - executor.spawn( - conn_man_event_sub - .filter(|r| future::ready(r.is_ok())) - .map(Result::unwrap) - .map(connection_manager_logger( - comms.node_identity().node_id().clone(), - quiet_mode, - )) - .map(Ok) - .forward(events_tx), - ); - let node_id = comms.node_identity().node_id().clone(); + executor.spawn(async move { + let mut logger = connection_manager_logger(node_id, quiet_mode); + loop { + match conn_man_event_sub.recv().await { + Ok(event) => { + events_tx.send(logger(event)).await.unwrap(); + }, + Err(broadcast::error::RecvError::Closed) => break, + Err(err) => log::error!("{}", err), + } + } + }); - executor.spawn( - messaging_events - .filter(|r| future::ready(r.is_ok())) - .map(Result::unwrap) - .filter_map(move |event| { - use MessagingEvent::*; - future::ready(match &*event { - MessageReceived(peer_node_id, _) => Some((Clone::clone(&*peer_node_id), node_id.clone())), - _ => None, - }) - }) - .map(Ok) - .forward(messaging_events_tx), - ); + let node_id = comms.node_identity().node_id().clone(); + executor.spawn(async move { + loop { + let event = messaging_events.recv().await; + use MessagingEvent::*; + match event.as_deref() { + Ok(MessageReceived(peer_node_id, _)) => { 
+ messaging_events_tx + .send((Clone::clone(&*peer_node_id), node_id.clone())) + .unwrap(); + }, + Err(broadcast::error::RecvError::Closed) => { + break; + }, + _ => {}, + } + } + }); } #[inline] @@ -749,7 +762,7 @@ impl TestNode { } use ConnectionManagerEvent::*; loop { - let event = time::timeout(Duration::from_secs(30), self.conn_man_events_rx.next()) + let event = time::timeout(Duration::from_secs(30), self.conn_man_events_rx.recv()) .await .ok()??; @@ -763,7 +776,7 @@ impl TestNode { } pub async fn shutdown(mut self) { - self.shutdown.trigger().unwrap(); + self.shutdown.trigger(); self.comms.wait_until_shutdown().await; } } @@ -946,5 +959,5 @@ async fn setup_comms_dht( pub async fn take_a_break(num_nodes: usize) { banner!("Taking a break for a few seconds to let things settle..."); - time::delay_for(Duration::from_millis(num_nodes as u64 * 100)).await; + time::sleep(Duration::from_millis(num_nodes as u64 * 100)).await; } diff --git a/comms/dht/examples/memorynet.rs b/comms/dht/examples/memorynet.rs index 9cc28551bc..3ed35c05b7 100644 --- a/comms/dht/examples/memorynet.rs +++ b/comms/dht/examples/memorynet.rs @@ -49,15 +49,16 @@ use crate::memory_net::utilities::{ shutdown_all, take_a_break, }; -use futures::{channel::mpsc, future}; +use futures::future; use rand::{rngs::OsRng, Rng}; use std::{iter::repeat_with, time::Duration}; use tari_comms::peer_manager::PeerFeatures; +use tokio::sync::mpsc; // Size of network -const NUM_NODES: usize = 6; +const NUM_NODES: usize = 40; // Must be at least 2 -const NUM_WALLETS: usize = 50; +const NUM_WALLETS: usize = 5; const QUIET_MODE: bool = true; /// Number of neighbouring nodes each node should include in the connection pool const NUM_NEIGHBOURING_NODES: usize = 8; @@ -66,7 +67,7 @@ const NUM_RANDOM_NODES: usize = 4; /// The number of messages that should be propagated out const PROPAGATION_FACTOR: usize = 4; -#[tokio_macros::main] +#[tokio::main] #[allow(clippy::same_item_push)] async fn main() { env_logger::init(); @@ 
-77,7 +78,7 @@ async fn main() { NUM_WALLETS ); - let (node_message_tx, mut messaging_events_rx) = mpsc::unbounded(); + let (node_message_tx, mut messaging_events_rx) = mpsc::unbounded_channel(); let seed_node = vec![ make_node( diff --git a/comms/dht/examples/memorynet_graph_network_join_multiple_seeds.rs b/comms/dht/examples/memorynet_graph_network_join_multiple_seeds.rs index a0a1bddc1e..5d2759fb54 100644 --- a/comms/dht/examples/memorynet_graph_network_join_multiple_seeds.rs +++ b/comms/dht/examples/memorynet_graph_network_join_multiple_seeds.rs @@ -65,11 +65,11 @@ use crate::{ }, }; use clap::{App, Arg}; -use futures::channel::mpsc; use std::{path::Path, time::Duration}; use tari_comms::peer_manager::PeerFeatures; +use tokio::sync::mpsc; -#[tokio_macros::main] +#[tokio::main] #[allow(clippy::same_item_push)] async fn main() { env_logger::init(); @@ -96,7 +96,7 @@ async fn main() { NUM_WALLETS ); - let (messaging_events_tx, _messaging_events_rx) = mpsc::unbounded(); + let (messaging_events_tx, _messaging_events_rx) = mpsc::unbounded_channel(); let mut seed_identities = Vec::new(); for _ in 0..NUM_SEED_NODES { diff --git a/comms/dht/examples/memorynet_graph_network_track_join.rs b/comms/dht/examples/memorynet_graph_network_track_join.rs index 259358e1cb..0a0324b683 100644 --- a/comms/dht/examples/memorynet_graph_network_track_join.rs +++ b/comms/dht/examples/memorynet_graph_network_track_join.rs @@ -73,11 +73,11 @@ use crate::{ }; use clap::{App, Arg}; use env_logger::Env; -use futures::channel::mpsc; use std::{path::Path, time::Duration}; use tari_comms::peer_manager::PeerFeatures; +use tokio::sync::mpsc; -#[tokio_macros::main] +#[tokio::main] #[allow(clippy::same_item_push)] async fn main() { let _ = env_logger::from_env(Env::default()) @@ -106,7 +106,7 @@ async fn main() { NUM_WALLETS ); - let (messaging_events_tx, mut messaging_events_rx) = mpsc::unbounded(); + let (messaging_events_tx, mut messaging_events_rx) = mpsc::unbounded_channel(); let mut 
seed_identities = Vec::new(); for _ in 0..NUM_SEED_NODES { diff --git a/comms/dht/examples/memorynet_graph_network_track_propagation.rs b/comms/dht/examples/memorynet_graph_network_track_propagation.rs index fcc5debff3..d560b9f537 100644 --- a/comms/dht/examples/memorynet_graph_network_track_propagation.rs +++ b/comms/dht/examples/memorynet_graph_network_track_propagation.rs @@ -73,10 +73,10 @@ use crate::{ }, }; use env_logger::Env; -use futures::channel::mpsc; use tari_comms::peer_manager::PeerFeatures; +use tokio::sync::mpsc; -#[tokio_macros::main] +#[tokio::main] #[allow(clippy::same_item_push)] async fn main() { let _ = env_logger::from_env(Env::default()) @@ -105,7 +105,7 @@ async fn main() { NUM_WALLETS ); - let (messaging_events_tx, mut messaging_events_rx) = mpsc::unbounded(); + let (messaging_events_tx, mut messaging_events_rx) = mpsc::unbounded_channel(); let mut seed_identities = Vec::new(); for _ in 0..NUM_SEED_NODES { diff --git a/comms/dht/src/actor.rs b/comms/dht/src/actor.rs index c2c2d4e52a..2e453291ac 100644 --- a/comms/dht/src/actor.rs +++ b/comms/dht/src/actor.rs @@ -37,13 +37,7 @@ use crate::{ DhtConfig, }; use chrono::{DateTime, Utc}; -use futures::{ - channel::{mpsc, mpsc::SendError, oneshot}, - future::BoxFuture, - stream::{Fuse, FuturesUnordered}, - SinkExt, - StreamExt, -}; +use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; use log::*; use std::{cmp, fmt, fmt::Display, sync::Arc}; use tari_comms::{ @@ -51,10 +45,15 @@ use tari_comms::{ peer_manager::{NodeId, NodeIdentity, PeerFeatures, PeerManager, PeerManagerError, PeerQuery, PeerQuerySortBy}, types::CommsPublicKey, }; +use tari_crypto::tari_utilities::hex::Hex; use tari_shutdown::ShutdownSignal; use tari_utilities::message_format::{MessageFormat, MessageFormatError}; use thiserror::Error; -use tokio::{task, time}; +use tokio::{ + sync::{mpsc, oneshot}, + task, + time, +}; const LOG_TARGET: &str = "comms::dht::actor"; @@ -62,8 +61,6 @@ const LOG_TARGET: &str = 
"comms::dht::actor"; pub enum DhtActorError { #[error("MPSC channel is disconnected")] ChannelDisconnected, - #[error("MPSC sender was unable to send because the channel buffer is full")] - SendBufferFull, #[error("Reply sender canceled the request")] ReplyCanceled, #[error("PeerManagerError: {0}")] @@ -84,15 +81,9 @@ pub enum DhtActorError { ConnectivityEventStreamClosed, } -impl From for DhtActorError { - fn from(err: SendError) -> Self { - if err.is_disconnected() { - DhtActorError::ChannelDisconnected - } else if err.is_full() { - DhtActorError::SendBufferFull - } else { - unreachable!(); - } +impl From> for DhtActorError { + fn from(_: mpsc::error::SendError) -> Self { + DhtActorError::ChannelDisconnected } } @@ -101,9 +92,14 @@ impl From for DhtActorError { pub enum DhtRequest { /// Send a Join request to the network SendJoin, - /// Inserts a message signature to the msg hash cache. This operation replies with a boolean - /// which is true if the signature already exists in the cache, otherwise false - MsgHashCacheInsert(Vec, CommsPublicKey, oneshot::Sender), + /// Inserts a message signature to the msg hash cache. 
This operation replies with the number of times this message + /// has previously been seen (hit count) + MsgHashCacheInsert { + message_hash: Vec<u8>, + received_from: CommsPublicKey, + reply_tx: oneshot::Sender<u32>, + }, + GetMsgHashHitCount(Vec<u8>, oneshot::Sender<u32>), /// Fetch selected peers according to the broadcast strategy SelectPeers(BroadcastStrategy, oneshot::Sender<Vec<NodeId>>), GetMetadata(DhtMetadataKey, oneshot::Sender<Result<Option<Vec<u8>>, DhtActorError>>), @@ -114,12 +110,22 @@ impl Display for DhtRequest { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { use DhtRequest::*; match self { - SendJoin => f.write_str("SendJoin"), - MsgHashCacheInsert(_, _, _) => f.write_str("MsgHashCacheInsert"), - SelectPeers(s, _) => f.write_str(&format!("SelectPeers (Strategy={})", s)), - GetMetadata(key, _) => f.write_str(&format!("GetMetadata (key={})", key)), + SendJoin => write!(f, "SendJoin"), + MsgHashCacheInsert { + message_hash, + received_from, + .. + } => write!( + f, + "MsgHashCacheInsert(message hash: {}, received from: {})", + message_hash.to_hex(), + received_from.to_hex(), + ), + GetMsgHashHitCount(hash, _) => write!(f, "GetMsgHashHitCount({})", hash.to_hex()), + SelectPeers(s, _) => write!(f, "SelectPeers (Strategy={})", s), + GetMetadata(key, _) => write!(f, "GetMetadata (key={})", key), SetMetadata(key, value, _) => { - f.write_str(&format!("SetMetadata (key={}, value={} bytes)", key, value.len())) + write!(f, "SetMetadata (key={}, value={} bytes)", key, value.len()) }, } } @@ -147,14 +153,27 @@ impl DhtRequester { reply_rx.await.map_err(|_| DhtActorError::ReplyCanceled) } - pub async fn insert_message_hash( + pub async fn add_message_to_dedup_cache( &mut self, message_hash: Vec<u8>, - public_key: CommsPublicKey, - ) -> Result<bool, DhtActorError> { + received_from: CommsPublicKey, + ) -> Result<u32, DhtActorError> { + let (reply_tx, reply_rx) = oneshot::channel(); + self.sender + .send(DhtRequest::MsgHashCacheInsert { + message_hash, + received_from, + reply_tx, + }) + .await?; + + reply_rx.await.map_err(|_| 
DhtActorError::ReplyCanceled) + } + + pub async fn get_message_cache_hit_count(&mut self, message_hash: Vec) -> Result { let (reply_tx, reply_rx) = oneshot::channel(); self.sender - .send(DhtRequest::MsgHashCacheInsert(message_hash, public_key, reply_tx)) + .send(DhtRequest::GetMsgHashHitCount(message_hash, reply_tx)) .await?; reply_rx.await.map_err(|_| DhtActorError::ReplyCanceled) @@ -186,8 +205,8 @@ pub struct DhtActor { outbound_requester: OutboundMessageRequester, connectivity: ConnectivityRequester, config: DhtConfig, - shutdown_signal: Option, - request_rx: Fuse>, + shutdown_signal: ShutdownSignal, + request_rx: mpsc::Receiver, msg_hash_dedup_cache: DedupCacheDatabase, } @@ -217,8 +236,8 @@ impl DhtActor { peer_manager, connectivity, node_identity, - shutdown_signal: Some(shutdown_signal), - request_rx: request_rx.fuse(), + shutdown_signal, + request_rx, } } @@ -247,33 +266,28 @@ impl DhtActor { let mut pending_jobs = FuturesUnordered::new(); - let mut dedup_cache_trim_ticker = time::interval(self.config.dedup_cache_trim_interval).fuse(); - - let mut shutdown_signal = self - .shutdown_signal - .take() - .expect("DhtActor initialized without shutdown_signal"); + let mut dedup_cache_trim_ticker = time::interval(self.config.dedup_cache_trim_interval); loop { - futures::select! { - request = self.request_rx.select_next_some() => { + tokio::select! { + Some(request) = self.request_rx.recv() => { trace!(target: LOG_TARGET, "DhtActor received request: {}", request); pending_jobs.push(self.request_handler(request)); }, - result = pending_jobs.select_next_some() => { + Some(result) = pending_jobs.next() => { if let Err(err) = result { debug!(target: LOG_TARGET, "Error when handling DHT request message. 
{}", err); } }, - _ = dedup_cache_trim_ticker.select_next_some() => { - if let Err(err) = self.msg_hash_dedup_cache.truncate().await { + _ = dedup_cache_trim_ticker.tick() => { + if let Err(err) = self.msg_hash_dedup_cache.trim_entries().await { error!(target: LOG_TARGET, "Error when trimming message dedup cache: {:?}", err); } }, - _ = shutdown_signal => { + _ = self.shutdown_signal.wait() => { info!(target: LOG_TARGET, "DhtActor is shutting down because it received a shutdown signal."); self.mark_shutdown_time().await; break Ok(()); @@ -300,24 +314,36 @@ impl DhtActor { let outbound_requester = self.outbound_requester.clone(); Box::pin(Self::broadcast_join(node_identity, outbound_requester)) }, - MsgHashCacheInsert(hash, public_key, reply_tx) => { + MsgHashCacheInsert { + message_hash, + received_from, + reply_tx, + } => { let msg_hash_cache = self.msg_hash_dedup_cache.clone(); Box::pin(async move { - match msg_hash_cache.insert_body_hash_if_unique(hash, public_key).await { - Ok(already_exists) => { - let _ = reply_tx.send(already_exists).map_err(|_| DhtActorError::ReplyCanceled); + match msg_hash_cache.add_body_hash(message_hash, received_from).await { + Ok(hit_count) => { + let _ = reply_tx.send(hit_count); }, Err(err) => { warn!( target: LOG_TARGET, "Unable to update message dedup cache because {:?}", err ); - let _ = reply_tx.send(false).map_err(|_| DhtActorError::ReplyCanceled); + let _ = reply_tx.send(0); }, } Ok(()) }) }, + GetMsgHashHitCount(hash, reply_tx) => { + let msg_hash_cache = self.msg_hash_dedup_cache.clone(); + Box::pin(async move { + let hit_count = msg_hash_cache.get_hit_count(hash).await?; + let _ = reply_tx.send(hit_count); + Ok(()) + }) + }, SelectPeers(broadcast_strategy, reply_tx) => { let peer_manager = Arc::clone(&self.peer_manager); let node_identity = Arc::clone(&self.node_identity); @@ -690,11 +716,12 @@ mod test { test_utils::{build_peer_manager, make_client_identity, make_node_identity}, }; use chrono::{DateTime, Utc}; - use 
std::time::Duration; - use tari_comms::test_utils::mocks::{create_connectivity_mock, create_peer_connection_mock_pair}; + use tari_comms::{ + runtime, + test_utils::mocks::{create_connectivity_mock, create_peer_connection_mock_pair}, + }; use tari_shutdown::Shutdown; use tari_test_utils::random; - use tokio::time::delay_for; async fn db_connection() -> DbConnection { let conn = DbConnection::connect_memory(random::string(8)).await.unwrap(); @@ -702,7 +729,7 @@ mod test { conn } - #[tokio_macros::test_basic] + #[runtime::test] async fn send_join_request() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -727,11 +754,11 @@ mod test { actor.spawn(); requester.send_join().await.unwrap(); - let (params, _) = unwrap_oms_send_msg!(out_rx.next().await.unwrap()); + let (params, _) = unwrap_oms_send_msg!(out_rx.recv().await.unwrap()); assert_eq!(params.dht_message_type, DhtMessageType::Join); } - #[tokio_macros::test_basic] + #[runtime::test] async fn insert_message_signature() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -756,24 +783,24 @@ mod test { actor.spawn(); let signature = vec![1u8, 2, 3]; - let is_dup = requester - .insert_message_hash(signature.clone(), CommsPublicKey::default()) + let num_hits = requester + .add_message_to_dedup_cache(signature.clone(), CommsPublicKey::default()) .await .unwrap(); - assert!(!is_dup); - let is_dup = requester - .insert_message_hash(signature, CommsPublicKey::default()) + assert_eq!(num_hits, 1); + let num_hits = requester + .add_message_to_dedup_cache(signature, CommsPublicKey::default()) .await .unwrap(); - assert!(is_dup); - let is_dup = requester - .insert_message_hash(Vec::new(), CommsPublicKey::default()) + assert_eq!(num_hits, 2); + let num_hits = requester + .add_message_to_dedup_cache(Vec::new(), CommsPublicKey::default()) .await .unwrap(); - assert!(!is_dup); + assert_eq!(num_hits, 1); } - #[tokio_macros::test_basic] + #[runtime::test] async 
fn dedup_cache_cleanup() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -783,14 +810,12 @@ mod test { let (actor_tx, actor_rx) = mpsc::channel(1); let mut requester = DhtRequester::new(actor_tx); let outbound_requester = OutboundMessageRequester::new(out_tx); - let mut shutdown = Shutdown::new(); - let trim_interval_ms = 500; + let shutdown = Shutdown::new(); // Note: This must be equal or larger than the minimum dedup cache capacity for DedupCacheDatabase - let capacity = 120; + let capacity = 10; let actor = DhtActor::new( DhtConfig { dedup_cache_capacity: capacity, - dedup_cache_trim_interval: Duration::from_millis(trim_interval_ms), ..Default::default() }, db_connection().await, @@ -803,66 +828,64 @@ mod test { ); // Create signatures for double the dedup cache capacity - let mut signatures: Vec> = Vec::new(); - for i in 0..(capacity * 2) { - signatures.push(vec![1u8, 2, i as u8]) - } + let signatures = (0..(capacity * 2)).map(|i| vec![1u8, 2, i as u8]).collect::>(); - // Pre-populate the dedup cache; everything should be accepted due to cleanup ticker not active yet + // Pre-populate the dedup cache; everything should be accepted because the cleanup ticker has not run yet for key in &signatures { - let is_dup = actor + let num_hits = actor .msg_hash_dedup_cache - .insert_body_hash_if_unique(key.clone(), CommsPublicKey::default()) + .add_body_hash(key.clone(), CommsPublicKey::default()) .await .unwrap(); - assert!(!is_dup); + assert_eq!(num_hits, 1); } - // Try to re-insert all; everything should be marked as duplicates due to cleanup ticker not active yet + // Try to re-insert all; all hashes should have incremented their hit count for key in &signatures { - let is_dup = actor + let num_hits = actor .msg_hash_dedup_cache - .insert_body_hash_if_unique(key.clone(), CommsPublicKey::default()) + .add_body_hash(key.clone(), CommsPublicKey::default()) .await .unwrap(); - assert!(is_dup); + assert_eq!(num_hits, 2); } - // The 
cleanup ticker starts when the actor is spawned; the first cleanup event will fire immediately + let dedup_cache_db = actor.msg_hash_dedup_cache.clone(); + // The cleanup ticker starts when the actor is spawned; the first cleanup event will fire fairly soon after the + // task is running on a thread. To remove this race condition, we trim the cache in the test. + dedup_cache_db.trim_entries().await.unwrap(); actor.spawn(); // Verify that the last half of the signatures are still present in the cache for key in signatures.iter().take(capacity * 2).skip(capacity) { - let is_dup = requester - .insert_message_hash(key.clone(), CommsPublicKey::default()) + let num_hits = requester + .add_message_to_dedup_cache(key.clone(), CommsPublicKey::default()) .await .unwrap(); - assert!(is_dup); + assert_eq!(num_hits, 3); } // Verify that the first half of the signatures have been removed and can be re-inserted into cache for key in signatures.iter().take(capacity) { - let is_dup = requester - .insert_message_hash(key.clone(), CommsPublicKey::default()) + let num_hits = requester + .add_message_to_dedup_cache(key.clone(), CommsPublicKey::default()) .await .unwrap(); - assert!(!is_dup); + assert_eq!(num_hits, 1); } - // Let the trim period expire; this will trim the dedup cache to capacity - delay_for(Duration::from_millis(trim_interval_ms * 2)).await; + // Trim the database of excess entries + dedup_cache_db.trim_entries().await.unwrap(); // Verify that the last half of the signatures have been removed and can be re-inserted into cache for key in signatures.iter().take(capacity * 2).skip(capacity) { - let is_dup = requester - .insert_message_hash(key.clone(), CommsPublicKey::default()) + let num_hits = requester + .add_message_to_dedup_cache(key.clone(), CommsPublicKey::default()) .await .unwrap(); - assert!(!is_dup); + assert_eq!(num_hits, 1); } - - shutdown.trigger().unwrap(); } - #[tokio_macros::test_basic] + #[runtime::test] async fn select_peers() { let node_identity = 
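The dedup tests above exercise the new contract: `add_message_to_dedup_cache` no longer replies with an is-duplicate boolean but with the total hit count for the body hash (1 on first sight, 2 on the first duplicate, and so on). A hypothetical in-memory stand-in for `DedupCacheDatabase`, using only the standard library, captures the semantics:

```rust
use std::collections::HashMap;

// Illustrative stand-in for the SQL-backed DedupCacheDatabase: adding a body
// hash returns the cumulative number of hits recorded for it (inclusive of
// this one), and a lookup of an unseen hash reports zero hits.
#[derive(Default)]
struct DedupCache {
    hits: HashMap<Vec<u8>, u32>,
}

impl DedupCache {
    fn add_body_hash(&mut self, body_hash: Vec<u8>) -> u32 {
        let count = self.hits.entry(body_hash).or_insert(0);
        *count += 1;
        *count
    }

    fn get_hit_count(&self, body_hash: &[u8]) -> u32 {
        self.hits.get(body_hash).copied().unwrap_or(0)
    }
}

fn main() {
    let mut cache = DedupCache::default();
    assert_eq!(cache.add_body_hash(vec![1, 2, 3]), 1); // first sighting
    assert_eq!(cache.add_body_hash(vec![1, 2, 3]), 2); // first duplicate
    assert_eq!(cache.get_hit_count(&[1, 2, 3]), 2);
    assert_eq!(cache.get_hit_count(&[9, 9, 9]), 0); // never seen
}
```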
make_node_identity(); let peer_manager = build_peer_manager(); @@ -973,7 +996,7 @@ mod test { assert_eq!(peers.len(), 1); } - #[tokio_macros::test_basic] + #[runtime::test] async fn get_and_set_metadata() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -1029,6 +1052,6 @@ mod test { .unwrap(); assert_eq!(got_ts, ts); - shutdown.trigger().unwrap(); + shutdown.trigger(); } } diff --git a/comms/dht/src/builder.rs b/comms/dht/src/builder.rs index 249ed3d369..bd7fb2521c 100644 --- a/comms/dht/src/builder.rs +++ b/comms/dht/src/builder.rs @@ -21,13 +21,13 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{dht::DhtInitializationError, outbound::DhtOutboundRequest, DbConnectionUrl, Dht, DhtConfig}; -use futures::channel::mpsc; use std::{sync::Arc, time::Duration}; use tari_comms::{ connectivity::ConnectivityRequester, peer_manager::{NodeIdentity, PeerManager}, }; use tari_shutdown::ShutdownSignal; +use tokio::sync::mpsc; pub struct DhtBuilder { node_identity: Arc, @@ -99,6 +99,11 @@ impl DhtBuilder { self } + pub fn with_dedup_discard_hit_count(mut self, max_hit_count: usize) -> Self { + self.config.dedup_allowed_message_occurrences = max_hit_count; + self + } + pub fn with_num_random_nodes(mut self, n: usize) -> Self { self.config.num_random_nodes = n; self diff --git a/comms/dht/src/config.rs b/comms/dht/src/config.rs index 0612445dca..a1b553ebb6 100644 --- a/comms/dht/src/config.rs +++ b/comms/dht/src/config.rs @@ -72,6 +72,10 @@ pub struct DhtConfig { /// The periodic trim interval for items in the message hash cache /// Default: 300s (5 mins) pub dedup_cache_trim_interval: Duration, + /// The number of occurrences of a message is allowed to pass through the DHT pipeline before being + /// deduped/discarded + /// Default: 1 + pub dedup_allowed_message_occurrences: usize, /// The duration to wait for a peer discovery to complete before giving up. 
/// Default: 2 minutes pub discovery_request_timeout: Duration, @@ -136,6 +140,7 @@ impl DhtConfig { impl Default for DhtConfig { fn default() -> Self { + // NB: please remember to update field comments to reflect these defaults Self { num_neighbouring_nodes: 8, num_random_nodes: 4, @@ -151,6 +156,7 @@ impl Default for DhtConfig { saf_max_message_size: 512 * 1024, dedup_cache_capacity: 2_500, dedup_cache_trim_interval: Duration::from_secs(5 * 60), + dedup_allowed_message_occurrences: 1, database_url: DbConnectionUrl::Memory, discovery_request_timeout: Duration::from_secs(2 * 60), connectivity_update_interval: Duration::from_secs(2 * 60), diff --git a/comms/dht/src/connectivity/metrics.rs b/comms/dht/src/connectivity/metrics.rs index b7ee546d4e..ba456d9aa9 100644 --- a/comms/dht/src/connectivity/metrics.rs +++ b/comms/dht/src/connectivity/metrics.rs @@ -20,20 +20,16 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
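The new `dedup_allowed_message_occurrences` config field above (default 1) sets how many times a message may pass through the DHT pipeline before it is deduped. A hedged sketch of the resulting discard decision — the function name is illustrative, not taken from the pipeline code:

```rust
// A message is discarded once its dedup-cache hit count exceeds the number of
// occurrences the node allows through the pipeline.
fn should_discard(hit_count: usize, dedup_allowed_message_occurrences: usize) -> bool {
    hit_count > dedup_allowed_message_occurrences
}

fn main() {
    // With the default of 1, the first occurrence passes and duplicates are dropped.
    assert!(!should_discard(1, 1));
    assert!(should_discard(2, 1));
    // Raising the limit lets the same message pass through more than once.
    assert!(!should_discard(2, 3));
}
```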
-use futures::{ - channel::{mpsc, mpsc::SendError, oneshot, oneshot::Canceled}, - future, - SinkExt, - StreamExt, -}; use log::*; use std::{ collections::{hash_map::Entry, HashMap, VecDeque}, - future::Future, time::{Duration, Instant}, }; use tari_comms::peer_manager::NodeId; -use tokio::task; +use tokio::{ + sync::{mpsc, oneshot}, + task, +}; const LOG_TARGET: &str = "comms::dht::metrics"; @@ -124,7 +120,7 @@ impl MetricsState { } pub struct MetricsCollector { - stream: Option>, + stream: mpsc::Receiver, state: MetricsState, } @@ -133,18 +129,17 @@ impl MetricsCollector { let (metrics_tx, metrics_rx) = mpsc::channel(500); let metrics_collector = MetricsCollectorHandle::new(metrics_tx); let collector = Self { - stream: Some(metrics_rx), + stream: metrics_rx, state: Default::default(), }; task::spawn(collector.run()); metrics_collector } - fn run(mut self) -> impl Future { - self.stream.take().unwrap().for_each(move |op| { + async fn run(mut self) { + while let Some(op) = self.stream.recv().await { self.handle(op); - future::ready(()) - }) + } } fn handle(&mut self, op: MetricOp) { @@ -286,7 +281,7 @@ impl MetricsCollectorHandle { match self.inner.try_send(MetricOp::Write(write)) { Ok(_) => true, Err(err) => { - warn!(target: LOG_TARGET, "Failed to write metric: {}", err.into_send_error()); + warn!(target: LOG_TARGET, "Failed to write metric: {:?}", err); false }, } @@ -338,14 +333,14 @@ pub enum MetricsError { ReplyCancelled, } -impl From for MetricsError { - fn from(_: SendError) -> Self { +impl From> for MetricsError { + fn from(_: mpsc::error::SendError) -> Self { MetricsError::ChannelClosedUnexpectedly } } -impl From for MetricsError { - fn from(_: Canceled) -> Self { +impl From for MetricsError { + fn from(_: oneshot::error::RecvError) -> Self { MetricsError::ReplyCancelled } } diff --git a/comms/dht/src/connectivity/mod.rs b/comms/dht/src/connectivity/mod.rs index c9c855b6c1..a3c4471451 100644 --- a/comms/dht/src/connectivity/mod.rs +++ 
b/comms/dht/src/connectivity/mod.rs @@ -27,7 +27,6 @@ mod metrics; pub use metrics::{MetricsCollector, MetricsCollectorHandle}; use crate::{connectivity::metrics::MetricsError, event::DhtEvent, DhtActorError, DhtConfig, DhtRequester}; -use futures::{stream::Fuse, StreamExt}; use log::*; use std::{sync::Arc, time::Instant}; use tari_comms::{ @@ -78,11 +77,11 @@ pub struct DhtConnectivity { /// Used to track when the random peer pool was last refreshed random_pool_last_refresh: Option, stats: Stats, - dht_events: Fuse>>, + dht_events: broadcast::Receiver>, metrics_collector: MetricsCollectorHandle, - shutdown_signal: Option, + shutdown_signal: ShutdownSignal, } impl DhtConnectivity { @@ -108,8 +107,8 @@ impl DhtConnectivity { metrics_collector, random_pool_last_refresh: None, stats: Stats::new(), - dht_events: dht_events.fuse(), - shutdown_signal: Some(shutdown_signal), + dht_events, + shutdown_signal, } } @@ -131,21 +130,15 @@ impl DhtConnectivity { }) } - pub async fn run(mut self, connectivity_events: ConnectivityEventRx) -> Result<(), DhtConnectivityError> { - let mut connectivity_events = connectivity_events.fuse(); - let mut shutdown_signal = self - .shutdown_signal - .take() - .expect("DhtConnectivity initialized without a shutdown_signal"); - + pub async fn run(mut self, mut connectivity_events: ConnectivityEventRx) -> Result<(), DhtConnectivityError> { debug!(target: LOG_TARGET, "DHT connectivity starting"); self.refresh_neighbour_pool().await?; - let mut ticker = time::interval(self.config.connectivity_update_interval).fuse(); + let mut ticker = time::interval(self.config.connectivity_update_interval); loop { - futures::select! { - event = connectivity_events.select_next_some() => { + tokio::select! 
{ + event = connectivity_events.recv() => { if let Ok(event) = event { if let Err(err) = self.handle_connectivity_event(&event).await { debug!(target: LOG_TARGET, "Error handling connectivity event: {:?}", err); @@ -153,15 +146,13 @@ impl DhtConnectivity { } }, - event = self.dht_events.select_next_some() => { - if let Ok(event) = event { - if let Err(err) = self.handle_dht_event(&event).await { - debug!(target: LOG_TARGET, "Error handling DHT event: {:?}", err); - } - } + Ok(event) = self.dht_events.recv() => { + if let Err(err) = self.handle_dht_event(&event).await { + debug!(target: LOG_TARGET, "Error handling DHT event: {:?}", err); + } }, - _ = ticker.next() => { + _ = ticker.tick() => { if let Err(err) = self.refresh_random_pool_if_required().await { debug!(target: LOG_TARGET, "Error refreshing random peer pool: {:?}", err); } @@ -170,7 +161,7 @@ impl DhtConnectivity { } }, - _ = shutdown_signal => { + _ = self.shutdown_signal.wait() => { info!(target: LOG_TARGET, "DhtConnectivity shutting down because the shutdown signal was received"); break; } diff --git a/comms/dht/src/connectivity/test.rs b/comms/dht/src/connectivity/test.rs index 65dba8c235..d0e83c0aa5 100644 --- a/comms/dht/src/connectivity/test.rs +++ b/comms/dht/src/connectivity/test.rs @@ -30,6 +30,7 @@ use std::{iter::repeat_with, sync::Arc, time::Duration}; use tari_comms::{ connectivity::ConnectivityEvent, peer_manager::{Peer, PeerFeatures}, + runtime, test_utils::{ count_string_occurrences, mocks::{create_connectivity_mock, create_dummy_peer_connection, ConnectivityManagerMockState}, @@ -89,7 +90,7 @@ async fn setup( ) } -#[tokio_macros::test_basic] +#[runtime::test] async fn initialize() { let config = DhtConfig { num_neighbouring_nodes: 4, @@ -127,7 +128,7 @@ async fn initialize() { assert!(managed.iter().all(|n| !neighbours.contains(n))); } -#[tokio_macros::test_basic] +#[runtime::test] async fn added_neighbours() { let node_identity = make_node_identity(); let mut node_identities = @@ -173,7 
+174,7 @@ async fn added_neighbours() { assert!(managed.contains(closer_peer.node_id())); } -#[tokio_macros::test_basic] +#[runtime::test] #[allow(clippy::redundant_closure)] async fn reinitialize_pools_when_offline() { let node_identity = make_node_identity(); @@ -215,7 +216,7 @@ async fn reinitialize_pools_when_offline() { assert_eq!(managed.len(), 5); } -#[tokio_macros::test_basic] +#[runtime::test] async fn insert_neighbour() { let node_identity = make_node_identity(); let node_identities = @@ -254,11 +255,13 @@ async fn insert_neighbour() { } mod metrics { + use super::*; mod collector { + use super::*; use crate::connectivity::MetricsCollector; use tari_comms::peer_manager::NodeId; - #[tokio_macros::test_basic] + #[runtime::test] async fn it_adds_message_received() { let mut metric_collector = MetricsCollector::spawn(); let node_id = NodeId::default(); @@ -273,7 +276,7 @@ mod metrics { assert_eq!(ts.count(), 100); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_clears_the_metrics() { let mut metric_collector = MetricsCollector::spawn(); let node_id = NodeId::default(); diff --git a/comms/dht/src/dedup/dedup_cache.rs b/comms/dht/src/dedup/dedup_cache.rs index f8f5f6fcbf..8364f020a0 100644 --- a/comms/dht/src/dedup/dedup_cache.rs +++ b/comms/dht/src/dedup/dedup_cache.rs @@ -24,15 +24,23 @@ use crate::{ schema::dedup_cache, storage::{DbConnection, StorageError}, }; -use chrono::Utc; -use diesel::{dsl, result::DatabaseErrorKind, ExpressionMethods, QueryDsl, RunQueryDsl}; +use chrono::{NaiveDateTime, Utc}; +use diesel::{dsl, result::DatabaseErrorKind, ExpressionMethods, OptionalExtension, QueryDsl, RunQueryDsl}; use log::*; use tari_comms::types::CommsPublicKey; -use tari_crypto::tari_utilities::{hex::Hex, ByteArray}; -use tari_utilities::hex; +use tari_crypto::tari_utilities::hex::Hex; const LOG_TARGET: &str = "comms::dht::dedup_cache"; +#[derive(Queryable, PartialEq, Debug)] +struct DedupCacheEntry { + body_hash: String, + sender_public_key: String, + number_of_hits: i32, + stored_at: NaiveDateTime, + last_hit_at: NaiveDateTime, +} + #[derive(Clone)] pub struct DedupCacheDatabase { connection: DbConnection, @@ -48,36 +56,40 @@ impl DedupCacheDatabase { Self { connection, capacity } } - /// Inserts and returns Ok(true) if the item already existed and Ok(false) if it didn't, also updating hit stats - pub async fn insert_body_hash_if_unique( - &self, - body_hash: Vec<u8>, - public_key: CommsPublicKey, - ) -> Result<bool, StorageError> { - let body_hash = hex::to_hex(&body_hash.as_bytes()); - let public_key = public_key.to_hex(); - match self - .insert_body_hash_or_update_stats(body_hash.clone(), public_key.clone()) - .await - { - Ok(val) => { - if val == 0 { - warn!( - target: LOG_TARGET, - "Unable to insert new entry into message dedup cache" - ); - } - Ok(false) - }, - Err(e) => match e { - StorageError::UniqueViolation(_) => Ok(true), - _ => Err(e), - }, + /// Adds the body hash to the cache, returning the number of hits (inclusive) that have been recorded for this body + /// hash + pub async fn add_body_hash(&self, body_hash: Vec<u8>, public_key: CommsPublicKey) -> Result<u32, StorageError> { + let hit_count = self + .insert_body_hash_or_update_stats(body_hash.to_hex(), public_key.to_hex()) + .await?; + + if hit_count == 0 { + warn!( + target: LOG_TARGET, + "Unable to insert new entry into message dedup cache" + ); } + Ok(hit_count) + } + + pub async fn get_hit_count(&self, body_hash: Vec<u8>) -> Result<u32, StorageError> { + let hit_count = self + .connection + .with_connection_async(move |conn| { + dedup_cache::table + .select(dedup_cache::number_of_hits) + .filter(dedup_cache::body_hash.eq(&body_hash.to_hex())) + .get_result::<i32>(conn) + .optional() + .map_err(Into::into) + }) + .await?; + + Ok(hit_count.unwrap_or(0) as u32) } /// Trims the dedup cache to the configured limit by removing the oldest entries - pub async fn truncate(&self) -> Result<usize, StorageError> { + pub async fn trim_entries(&self) -> Result<usize, StorageError> { let capacity = self.capacity; self.connection .with_connection_async(move 
|conn| { @@ -109,40 +121,46 @@ impl DedupCacheDatabase { .await } - // Insert new row into the table or update existing row in an atomic fashion; more than one thread can access this - // table at the same time. + /// Insert new row into the table or updates an existing row. Returns the number of hits for this body hash. async fn insert_body_hash_or_update_stats( &self, body_hash: String, public_key: String, - ) -> Result { + ) -> Result { self.connection .with_connection_async(move |conn| { let insert_result = diesel::insert_into(dedup_cache::table) .values(( - dedup_cache::body_hash.eq(body_hash.clone()), - dedup_cache::sender_public_key.eq(public_key.clone()), + dedup_cache::body_hash.eq(&body_hash), + dedup_cache::sender_public_key.eq(&public_key), dedup_cache::number_of_hits.eq(1), dedup_cache::last_hit_at.eq(Utc::now().naive_utc()), )) .execute(conn); match insert_result { - Ok(val) => Ok(val), + Ok(1) => Ok(1), + Ok(n) => Err(StorageError::UnexpectedResult(format!( + "Expected exactly one row to be inserted. 
Got {}", + n + ))), Err(diesel::result::Error::DatabaseError(kind, e_info)) => match kind { DatabaseErrorKind::UniqueViolation => { // Update hit stats for the message - let result = - diesel::update(dedup_cache::table.filter(dedup_cache::body_hash.eq(&body_hash))) - .set(( - dedup_cache::sender_public_key.eq(public_key), - dedup_cache::number_of_hits.eq(dedup_cache::number_of_hits + 1), - dedup_cache::last_hit_at.eq(Utc::now().naive_utc()), - )) - .execute(conn); - match result { - Ok(_) => Err(StorageError::UniqueViolation(body_hash)), - Err(e) => Err(e.into()), - } + diesel::update(dedup_cache::table.filter(dedup_cache::body_hash.eq(&body_hash))) + .set(( + dedup_cache::sender_public_key.eq(&public_key), + dedup_cache::number_of_hits.eq(dedup_cache::number_of_hits + 1), + dedup_cache::last_hit_at.eq(Utc::now().naive_utc()), + )) + .execute(conn)?; + // TODO: Diesel support for RETURNING statements would remove this query, but is not + // available for Diesel + SQLite yet + let hits = dedup_cache::table + .select(dedup_cache::number_of_hits) + .filter(dedup_cache::body_hash.eq(&body_hash)) + .get_result::(conn)?; + + Ok(hits as u32) }, _ => Err(diesel::result::Error::DatabaseError(kind, e_info).into()), }, diff --git a/comms/dht/src/dedup/mod.rs b/comms/dht/src/dedup/mod.rs index 5428277af0..8bea19f39b 100644 --- a/comms/dht/src/dedup/mod.rs +++ b/comms/dht/src/dedup/mod.rs @@ -47,13 +47,15 @@ fn hash_inbound_message(message: &DhtInboundMessage) -> Vec { pub struct DedupMiddleware { next_service: S, dht_requester: DhtRequester, + allowed_message_occurrences: usize, } impl DedupMiddleware { - pub fn new(service: S, dht_requester: DhtRequester) -> Self { + pub fn new(service: S, dht_requester: DhtRequester, allowed_message_occurrences: usize) -> Self { Self { next_service: service, dht_requester, + allowed_message_occurrences, } } } @@ -71,9 +73,10 @@ where Poll::Ready(Ok(())) } - fn call(&mut self, message: DhtInboundMessage) -> Self::Future { + fn call(&mut self, 
mut message: DhtInboundMessage) -> Self::Future { let next_service = self.next_service.clone(); let mut dht_requester = self.dht_requester.clone(); + let allowed_message_occurrences = self.allowed_message_occurrences; Box::pin(async move { let hash = hash_inbound_message(&message); trace!( @@ -83,14 +86,17 @@ where message.tag, message.dht_header.message_tag ); - if dht_requester - .insert_message_hash(hash, message.source_peer.public_key.clone()) - .await? - { + + message.dedup_hit_count = dht_requester + .add_message_to_dedup_cache(hash, message.source_peer.public_key.clone()) + .await?; + + if message.dedup_hit_count as usize > allowed_message_occurrences { trace!( target: LOG_TARGET, - "Received duplicate message {} from peer '{}' (Trace: {}). Message discarded.", + "Received duplicate message {} (hit_count = {}) from peer '{}' (Trace: {}). Message discarded.", message.tag, + message.dedup_hit_count, message.source_peer.node_id.short_str(), message.dht_header.message_tag, ); @@ -99,8 +105,9 @@ where trace!( target: LOG_TARGET, - "Passing message {} onto next service (Trace: {})", + "Passing message {} (hit_count = {}) onto next service (Trace: {})", message.tag, + message.dedup_hit_count, message.dht_header.message_tag ); next_service.oneshot(message).await @@ -110,11 +117,15 @@ where pub struct DedupLayer { dht_requester: DhtRequester, + allowed_message_occurrences: usize, } impl DedupLayer { - pub fn new(dht_requester: DhtRequester) -> Self { - Self { dht_requester } + pub fn new(dht_requester: DhtRequester, allowed_message_occurrences: usize) -> Self { + Self { + dht_requester, + allowed_message_occurrences, + } } } @@ -122,7 +133,7 @@ impl Layer for DedupLayer { type Service = DedupMiddleware; fn layer(&self, service: S) -> Self::Service { - DedupMiddleware::new(service, self.dht_requester.clone()) + DedupMiddleware::new(service, self.dht_requester.clone(), self.allowed_message_occurrences) } } @@ -138,15 +149,15 @@ mod test { #[test] fn process_message() { 
- let mut rt = Runtime::new().unwrap(); + let rt = Runtime::new().unwrap(); let spy = service_spy(); let (dht_requester, mock) = create_dht_actor_mock(1); let mock_state = mock.get_shared_state(); - mock_state.set_signature_cache_insert(false); + mock_state.set_number_of_message_hits(1); rt.spawn(mock.run()); - let mut dedup = DedupLayer::new(dht_requester).layer(spy.to_service::()); + let mut dedup = DedupLayer::new(dht_requester, 3).layer(spy.to_service::()); panic_context!(cx); @@ -157,7 +168,7 @@ mod test { rt.block_on(dedup.call(msg.clone())).unwrap(); assert_eq!(spy.call_count(), 1); - mock_state.set_signature_cache_insert(true); + mock_state.set_number_of_message_hits(4); rt.block_on(dedup.call(msg)).unwrap(); assert_eq!(spy.call_count(), 1); // Drop dedup so that the DhtMock will stop running diff --git a/comms/dht/src/dht.rs b/comms/dht/src/dht.rs index dcdeea5730..9d29a70d79 100644 --- a/comms/dht/src/dht.rs +++ b/comms/dht/src/dht.rs @@ -26,6 +26,7 @@ use crate::{ connectivity::{DhtConnectivity, MetricsCollector, MetricsCollectorHandle}, discovery::{DhtDiscoveryRequest, DhtDiscoveryRequester, DhtDiscoveryService}, event::{DhtEventReceiver, DhtEventSender}, + filter, inbound, inbound::{DecryptedDhtMessage, DhtInboundMessage, MetricsLayer}, logging_middleware::MessageLoggingLayer, @@ -37,12 +38,11 @@ use crate::{ storage::{DbConnection, StorageError}, store_forward, store_forward::{StoreAndForwardError, StoreAndForwardRequest, StoreAndForwardRequester, StoreAndForwardService}, - tower_filter, DedupLayer, DhtActorError, DhtConfig, }; -use futures::{channel::mpsc, future, Future}; +use futures::Future; use log::*; use std::sync::Arc; use tari_comms::{ @@ -53,7 +53,7 @@ use tari_comms::{ }; use tari_shutdown::ShutdownSignal; use thiserror::Error; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc}; use tower::{layer::Layer, Service, ServiceBuilder}; const LOG_TARGET: &str = "comms::dht"; @@ -285,13 +285,14 @@ impl Dht { S: Service + Clone + Send 
+ Sync + 'static, S::Future: Send, { - // FIXME: There is an unresolved stack overflow issue on windows in debug mode during runtime, but not in - // release mode, related to the amount of layers. (issue #1416) ServiceBuilder::new() .layer(MetricsLayer::new(self.metrics_collector.clone())) .layer(inbound::DeserializeLayer::new(self.peer_manager.clone())) - .layer(DedupLayer::new(self.dht_requester())) - .layer(tower_filter::FilterLayer::new(self.unsupported_saf_messages_filter())) + .layer(DedupLayer::new( + self.dht_requester(), + self.config.dedup_allowed_message_occurrences, + )) + .layer(filter::FilterLayer::new(self.unsupported_saf_messages_filter())) .layer(MessageLoggingLayer::new(format!( "Inbound [{}]", self.node_identity.node_id().short_str() @@ -301,6 +302,7 @@ impl Dht { self.node_identity.clone(), self.connectivity.clone(), )) + .layer(filter::FilterLayer::new(filter_messages_to_rebroadcast)) .layer(store_forward::StoreLayer::new( self.config.clone(), Arc::clone(&self.peer_manager), @@ -363,34 +365,60 @@ impl Dht { /// Produces a filter predicate which disallows store and forward messages if that feature is not /// supported by the node. - fn unsupported_saf_messages_filter( - &self, - ) -> impl tower_filter::Predicate>> + Clone + Send - { + fn unsupported_saf_messages_filter(&self) -> impl filter::Predicate + Clone + Send { let node_identity = Arc::clone(&self.node_identity); move |msg: &DhtInboundMessage| { if node_identity.has_peer_features(PeerFeatures::DHT_STORE_FORWARD) { - return future::ready(Ok(())); + return true; } match msg.dht_header.message_type { DhtMessageType::SafRequestMessages => { // TODO: #banheuristic This is an indication of node misbehaviour - debug!( + warn!( "Received store and forward message from PublicKey={}. Store and forward feature is not \ supported by this node. 
Discarding message.", msg.source_peer.public_key ); - future::ready(Err(anyhow::anyhow!( - "Message filtered out because store and forward is not supported by this node", - ))) + false }, - _ => future::ready(Ok(())), + _ => true, } } } } +/// Provides the gossip filtering rules for an inbound message +fn filter_messages_to_rebroadcast(msg: &DecryptedDhtMessage) -> bool { + // Let the message through if: + // it isn't a duplicate (normal message), or + let should_continue = !msg.is_duplicate() || + ( + // it is a duplicate domain message (i.e. not DHT or SAF protocol message), and + msg.dht_header.message_type.is_domain_message() && + // it has an unknown destination (e.g complete transactions, blocks, misc. encrypted + // messages) we allow it to proceed, which in turn, re-propagates it for another round. + msg.dht_header.destination.is_unknown() + ); + + if should_continue { + // The message has been forwarded, but downstream middleware may be interested + debug!( + target: LOG_TARGET, + "[filter_messages_to_rebroadcast] Passing message {} to next service (Trace: {})", + msg.tag, + msg.dht_header.message_tag + ); + true + } else { + debug!( + target: LOG_TARGET, + "[filter_messages_to_rebroadcast] Discarding duplicate message {}", msg + ); + false + } +} + #[cfg(test)] mod test { use crate::{ @@ -404,22 +432,23 @@ mod test { make_comms_inbound_message, make_dht_envelope, make_node_identity, + service_spy, }, DhtBuilder, }; - use futures::{channel::mpsc, StreamExt}; use std::{sync::Arc, time::Duration}; use tari_comms::{ message::{MessageExt, MessageTag}, pipeline::SinkService, + runtime, test_utils::mocks::create_connectivity_mock, wrap_in_envelope_body, }; use tari_shutdown::Shutdown; - use tokio::{task, time}; + use tokio::{sync::mpsc, task, time}; use tower::{layer::Layer, Service}; - #[tokio_macros::test_basic] + #[runtime::test] async fn stack_unencrypted() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -459,7 +488,7 
@@ mod test { let msg = { service.call(inbound_message).await.unwrap(); - let msg = time::timeout(Duration::from_secs(10), out_rx.next()) + let msg = time::timeout(Duration::from_secs(10), out_rx.recv()) .await .unwrap() .unwrap(); @@ -469,7 +498,7 @@ mod test { assert_eq!(msg, b"secret"); } - #[tokio_macros::test_basic] + #[runtime::test] async fn stack_encrypted() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -509,7 +538,7 @@ mod test { let msg = { service.call(inbound_message).await.unwrap(); - let msg = time::timeout(Duration::from_secs(10), out_rx.next()) + let msg = time::timeout(Duration::from_secs(10), out_rx.recv()) .await .unwrap() .unwrap(); @@ -519,7 +548,7 @@ mod test { assert_eq!(msg, b"secret"); } - #[tokio_macros::test_basic] + #[runtime::test] async fn stack_forward() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -528,7 +557,6 @@ mod test { peer_manager.add_peer(node_identity.to_peer()).await.unwrap(); let (connectivity, _) = create_connectivity_mock(); - let (next_service_tx, mut next_service_rx) = mpsc::channel(10); let (oms_requester, oms_mock) = create_outbound_service_mock(1); // Send all outbound requests to the mock @@ -545,7 +573,8 @@ mod test { let oms_mock_state = oms_mock.get_state(); task::spawn(oms_mock.run()); - let mut service = dht.inbound_middleware_layer().layer(SinkService::new(next_service_tx)); + let spy = service_spy(); + let mut service = dht.inbound_middleware_layer().layer(spy.to_service()); let msg = wrap_in_envelope_body!(b"unencrypteable".to_vec()); @@ -574,10 +603,10 @@ mod test { assert_eq!(params.dht_header.unwrap().origin_mac, origin_mac); // Check the next service was not called - assert!(next_service_rx.try_next().is_err()); + assert_eq!(spy.call_count(), 0); } - #[tokio_macros::test_basic] + #[runtime::test] async fn stack_filter_saf_message() { let node_identity = make_client_identity(); let peer_manager = build_peer_manager(); @@ 
-600,9 +629,8 @@ mod test { .await .unwrap(); - let (next_service_tx, mut next_service_rx) = mpsc::channel(10); - - let mut service = dht.inbound_middleware_layer().layer(SinkService::new(next_service_tx)); + let spy = service_spy(); + let mut service = dht.inbound_middleware_layer().layer(spy.to_service()); let msg = wrap_in_envelope_body!(b"secret".to_vec()); let mut dht_envelope = make_dht_envelope( @@ -619,10 +647,6 @@ mod test { let inbound_message = make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().into()); service.call(inbound_message).await.unwrap_err(); - // This seems like the best way to tell that an open channel is empty without the test blocking indefinitely - assert_eq!( - format!("{}", next_service_rx.try_next().unwrap_err()), - "receiver channel is empty" - ); + assert_eq!(spy.call_count(), 0); } } diff --git a/comms/dht/src/discovery/error.rs b/comms/dht/src/discovery/error.rs index cf98e42d9d..c2a77f0c9a 100644 --- a/comms/dht/src/discovery/error.rs +++ b/comms/dht/src/discovery/error.rs @@ -21,9 +21,9 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
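The reworked tests above assert on a service spy's call count instead of draining a channel, and the dedup middleware earlier in this patch discards a message once its hit count exceeds `dedup_allowed_message_occurrences`. A minimal std-only sketch of both ideas, where `ServiceSpy`, `Message`, and `dedup_forward` are hypothetical names for illustration and not the crate's actual API:

```rust
use std::sync::{
    atomic::{AtomicUsize, Ordering},
    Arc,
};

/// Counts how many messages reach the "next service" (test-spy pattern).
#[derive(Clone, Default)]
struct ServiceSpy {
    calls: Arc<AtomicUsize>,
}

impl ServiceSpy {
    fn on_message(&self) {
        self.calls.fetch_add(1, Ordering::SeqCst);
    }

    fn call_count(&self) -> usize {
        self.calls.load(Ordering::SeqCst)
    }
}

struct Message {
    dedup_hit_count: u32,
}

/// Forward the message only while its hit count stays within the allowed occurrences.
fn dedup_forward(msg: &Message, allowed_message_occurrences: u32, next: &ServiceSpy) {
    if msg.dedup_hit_count > allowed_message_occurrences {
        return; // duplicate seen too many times: discard
    }
    next.on_message();
}

fn main() {
    let spy = ServiceSpy::default();
    // Hit counts 1 and 3 pass when 3 occurrences are allowed.
    dedup_forward(&Message { dedup_hit_count: 1 }, 3, &spy);
    dedup_forward(&Message { dedup_hit_count: 3 }, 3, &spy);
    // Hit count 4 exceeds the threshold and never reaches the spy.
    dedup_forward(&Message { dedup_hit_count: 4 }, 3, &spy);
    assert_eq!(spy.call_count(), 2);
}
```

Asserting on an atomic counter avoids the fragile "try to read from an open-but-empty channel" assertions that the patch removes.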
use crate::outbound::{message::SendFailure, DhtOutboundError}; -use futures::channel::mpsc::SendError; use tari_comms::peer_manager::PeerManagerError; use thiserror::Error; +use tokio::sync::mpsc::error::SendError; #[derive(Debug, Error)] pub enum DhtDiscoveryError { @@ -37,8 +37,6 @@ pub enum DhtDiscoveryError { InvalidNodeId, #[error("MPSC channel is disconnected")] ChannelDisconnected, - #[error("MPSC sender was unable to send because the channel buffer is full")] - SendBufferFull, #[error("The discovery request timed out")] DiscoveryTimeout, #[error("Failed to send discovery message: {0}")] @@ -56,14 +54,8 @@ impl DhtDiscoveryError { } } -impl From for DhtDiscoveryError { - fn from(err: SendError) -> Self { - if err.is_disconnected() { - DhtDiscoveryError::ChannelDisconnected - } else if err.is_full() { - DhtDiscoveryError::SendBufferFull - } else { - unreachable!(); - } +impl From> for DhtDiscoveryError { + fn from(_: SendError) -> Self { + DhtDiscoveryError::ChannelDisconnected } } diff --git a/comms/dht/src/discovery/requester.rs b/comms/dht/src/discovery/requester.rs index a286bcc6a5..a7317f79a2 100644 --- a/comms/dht/src/discovery/requester.rs +++ b/comms/dht/src/discovery/requester.rs @@ -21,16 +21,15 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
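The `From` conversion above collapses tokio's `SendError` into a single `ChannelDisconnected` variant: unlike the futures channel it replaces, a failed tokio `send` means the receiver is gone, so the `SendBufferFull` case disappears. A self-contained sketch of the same pattern using std's `mpsc` so it runs without tokio; the `DiscoveryError` enum here is a hypothetical stand-in for `DhtDiscoveryError`:

```rust
use std::sync::mpsc::{channel, SendError};

#[derive(Debug, PartialEq)]
enum DiscoveryError {
    ChannelDisconnected,
}

impl<T> From<SendError<T>> for DiscoveryError {
    // Any send failure means the receiving half was dropped.
    fn from(_: SendError<T>) -> Self {
        DiscoveryError::ChannelDisconnected
    }
}

fn main() {
    let (tx, rx) = channel::<u32>();
    drop(rx); // disconnect the receiver
    let err: DiscoveryError = tx.send(42).unwrap_err().into();
    assert_eq!(err, DiscoveryError::ChannelDisconnected);
}
```

Because the error carries no buffer-full case, callers no longer need to distinguish transient back-pressure from a dead channel.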
use crate::{discovery::DhtDiscoveryError, envelope::NodeDestination, proto::dht::DiscoveryResponseMessage}; -use futures::{ - channel::{mpsc, oneshot}, - SinkExt, -}; use std::{ fmt::{Display, Error, Formatter}, time::Duration, }; use tari_comms::{peer_manager::Peer, types::CommsPublicKey}; -use tokio::time; +use tokio::{ + sync::{mpsc, oneshot}, + time, +}; #[derive(Debug)] pub enum DhtDiscoveryRequest { diff --git a/comms/dht/src/discovery/service.rs b/comms/dht/src/discovery/service.rs index 258eb67cea..2cebf571b4 100644 --- a/comms/dht/src/discovery/service.rs +++ b/comms/dht/src/discovery/service.rs @@ -27,11 +27,6 @@ use crate::{ proto::dht::{DiscoveryMessage, DiscoveryResponseMessage}, DhtConfig, }; -use futures::{ - channel::{mpsc, oneshot}, - future::FutureExt, - StreamExt, -}; use log::*; use rand::{rngs::OsRng, RngCore}; use std::{ @@ -47,7 +42,11 @@ use tari_comms::{ }; use tari_shutdown::ShutdownSignal; use tari_utilities::{hex::Hex, ByteArray}; -use tokio::{task, time}; +use tokio::{ + sync::{mpsc, oneshot}, + task, + time, +}; const LOG_TARGET: &str = "comms::dht::discovery_service"; @@ -72,8 +71,8 @@ pub struct DhtDiscoveryService { node_identity: Arc, outbound_requester: OutboundMessageRequester, peer_manager: Arc, - request_rx: Option>, - shutdown_signal: Option, + request_rx: mpsc::Receiver, + shutdown_signal: ShutdownSignal, inflight_discoveries: HashMap, } @@ -91,8 +90,8 @@ impl DhtDiscoveryService { outbound_requester, node_identity, peer_manager, - shutdown_signal: Some(shutdown_signal), - request_rx: Some(request_rx), + shutdown_signal, + request_rx, inflight_discoveries: HashMap::new(), } } @@ -106,29 +105,19 @@ impl DhtDiscoveryService { pub async fn run(mut self) { info!(target: LOG_TARGET, "Dht discovery service started"); - let mut shutdown_signal = self - .shutdown_signal - .take() - .expect("DiscoveryService initialized without shutdown_signal") - .fuse(); - - let mut request_rx = self - .request_rx - .take() - 
.expect("DiscoveryService initialized without request_rx") - .fuse(); - loop { - futures::select! { - request = request_rx.select_next_some() => { - trace!(target: LOG_TARGET, "Received request '{}'", request); - self.handle_request(request).await; - }, + tokio::select! { + biased; - _ = shutdown_signal => { + _ = self.shutdown_signal.wait() => { info!(target: LOG_TARGET, "Discovery service is shutting down because the shutdown signal was received"); break; } + + Some(request) = self.request_rx.recv() => { + trace!(target: LOG_TARGET, "Received request '{}'", request); + self.handle_request(request).await; + }, } } } @@ -153,7 +142,7 @@ impl DhtDiscoveryService { let mut remaining_requests = HashMap::new(); for (nonce, request) in self.inflight_discoveries.drain() { // Exclude canceled requests - if request.reply_tx.is_canceled() { + if request.reply_tx.is_closed() { continue; } @@ -199,7 +188,7 @@ impl DhtDiscoveryService { ); for request in self.collect_all_discovery_requests(&public_key) { - if !reply_tx.is_canceled() { + if !reply_tx.is_closed() { let _ = request.reply_tx.send(Ok(peer.clone())); } } @@ -299,7 +288,7 @@ impl DhtDiscoveryService { self.inflight_discoveries = self .inflight_discoveries .drain() - .filter(|(_, state)| !state.reply_tx.is_canceled()) + .filter(|(_, state)| !state.reply_tx.is_closed()) .collect(); trace!( @@ -393,9 +382,10 @@ mod test { test_utils::{build_peer_manager, make_node_identity}, }; use std::time::Duration; + use tari_comms::runtime; use tari_shutdown::Shutdown; - #[tokio_macros::test_basic] + #[runtime::test] async fn send_discovery() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); diff --git a/comms/dht/src/domain_message.rs b/comms/dht/src/domain_message.rs index 2fe7af16fe..f565882725 100644 --- a/comms/dht/src/domain_message.rs +++ b/comms/dht/src/domain_message.rs @@ -33,7 +33,7 @@ impl ToProtoEnum for i32 { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct 
OutboundDomainMessage { inner: T, message_type: i32, diff --git a/comms/dht/src/envelope.rs b/comms/dht/src/envelope.rs index 0b93546dbb..c8b18d9e9b 100644 --- a/comms/dht/src/envelope.rs +++ b/comms/dht/src/envelope.rs @@ -22,6 +22,8 @@ use bitflags::bitflags; use bytes::Bytes; +use chrono::{DateTime, NaiveDateTime, Utc}; +use prost_types::Timestamp; use serde::{Deserialize, Serialize}; use std::{ cmp, @@ -30,14 +32,11 @@ use std::{ fmt::Display, }; use tari_comms::{message::MessageTag, peer_manager::NodeId, types::CommsPublicKey, NodeIdentity}; -use tari_utilities::{ByteArray, ByteArrayError}; +use tari_utilities::{epoch_time::EpochTime, ByteArray, ByteArrayError}; use thiserror::Error; // Re-export applicable protos pub use crate::proto::envelope::{dht_header::Destination, DhtEnvelope, DhtHeader, DhtMessageType}; -use chrono::{DateTime, NaiveDateTime, Utc}; -use prost_types::Timestamp; -use tari_utilities::epoch_time::EpochTime; /// Utility function that converts a `chrono::DateTime` to a `prost::Timestamp` pub(crate) fn datetime_to_timestamp(datetime: DateTime) -> Timestamp { @@ -106,8 +105,12 @@ impl DhtMessageFlags { } impl DhtMessageType { + pub fn is_domain_message(self) -> bool { + matches!(self, DhtMessageType::None) + } + pub fn is_dht_message(self) -> bool { - self.is_dht_discovery() || self.is_dht_join() + self.is_dht_discovery() || matches!(self, DhtMessageType::DiscoveryResponse) || self.is_dht_join() } pub fn is_dht_discovery(self) -> bool { diff --git a/comms/dht/src/tower_filter/future.rs b/comms/dht/src/filter/future.rs similarity index 66% rename from comms/dht/src/tower_filter/future.rs rename to comms/dht/src/filter/future.rs index 78b2c613e6..4559aeaadf 100644 --- a/comms/dht/src/tower_filter/future.rs +++ b/comms/dht/src/filter/future.rs @@ -13,16 +13,15 @@ use tower::Service; /// Filtered response future #[pin_project] #[derive(Debug)] -pub struct ResponseFuture +pub struct ResponseFuture where S: Service { #[pin] /// Response future state 
state: State, - #[pin] - /// Predicate future - check: T, + /// Predicate result + check: bool, /// Inner service service: S, @@ -35,12 +34,10 @@ enum State { WaitResponse(#[pin] U), } -impl ResponseFuture -where - F: Future>, - S: Service, +impl ResponseFuture +where S: Service { - pub(crate) fn new(request: Request, check: F, service: S) -> Self { + pub(crate) fn new(request: Request, check: bool, service: S) -> Self { ResponseFuture { state: State::Check(Some(request)), check, @@ -49,10 +46,8 @@ where } } -impl Future for ResponseFuture -where - F: Future>, - S: Service, +impl Future for ResponseFuture +where S: Service { type Output = Result; @@ -66,15 +61,13 @@ where .take() .expect("we either give it back or leave State::Check once we take"); - // Poll predicate - match this.check.as_mut().poll(cx)? { - Poll::Ready(_) => { + match this.check { + true => { let response = this.service.call(request); this.state.set(State::WaitResponse(response)); }, - Poll::Pending => { - this.state.set(State::Check(Some(request))); - return Poll::Pending; + false => { + return Poll::Ready(Ok(())); }, } }, diff --git a/comms/dht/src/tower_filter/layer.rs b/comms/dht/src/filter/layer.rs similarity index 100% rename from comms/dht/src/tower_filter/layer.rs rename to comms/dht/src/filter/layer.rs diff --git a/comms/dht/src/tower_filter/mod.rs b/comms/dht/src/filter/mod.rs similarity index 92% rename from comms/dht/src/tower_filter/mod.rs rename to comms/dht/src/filter/mod.rs index d1df2f27a7..e7f168161b 100644 --- a/comms/dht/src/tower_filter/mod.rs +++ b/comms/dht/src/filter/mod.rs @@ -33,11 +33,11 @@ impl Filter { impl Service for Filter where - T: Service + Clone, + T: Service + Clone, U: Predicate, { type Error = PipelineError; - type Future = ResponseFuture; + type Future = ResponseFuture; type Response = T::Response; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { diff --git a/comms/dht/src/filter/predicate.rs b/comms/dht/src/filter/predicate.rs new file mode 100644 
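The rewritten filter above swaps the predicate *future* for a plain `bool`, and the new `filter::Predicate` trait (added in `predicate.rs` below) has a blanket impl so any `Fn(&T) -> bool` closure can act as a filter rule. A standalone sketch of that blanket-impl pattern; the trait shape mirrors the patch but is reproduced here as an illustrative example, not the crate's exact source:

```rust
/// Checks a request and decides whether it should be forwarded.
trait Predicate<Request> {
    fn check(&mut self, request: &Request) -> bool;
}

/// Blanket impl: any closure `Fn(&T) -> bool` is automatically a Predicate.
impl<T, F> Predicate<T> for F
where F: Fn(&T) -> bool
{
    fn check(&mut self, request: &T) -> bool {
        self(request)
    }
}

fn main() {
    // A closure filtering out even numbers acts as a Predicate.
    let mut only_odd = |n: &i32| n % 2 != 0;
    assert!(Predicate::check(&mut only_odd, &3));
    assert!(!Predicate::check(&mut only_odd, &4));
}
```

This is what lets `filter::FilterLayer::new(filter_messages_to_rebroadcast)` take a bare function: the function satisfies the trait bound without any wrapper type.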
index 0000000000..024dee826d --- /dev/null +++ b/comms/dht/src/filter/predicate.rs @@ -0,0 +1,13 @@ +/// Checks a request +pub trait Predicate { + /// Check whether the given request should be forwarded. + fn check(&mut self, request: &Request) -> bool; +} + +impl Predicate for F +where F: Fn(&T) -> bool +{ + fn check(&mut self, request: &T) -> bool { + self(request) + } +} diff --git a/comms/dht/src/inbound/decryption.rs b/comms/dht/src/inbound/decryption.rs index 9e3a6bbd3e..46fc3daa47 100644 --- a/comms/dht/src/inbound/decryption.rs +++ b/comms/dht/src/inbound/decryption.rs @@ -397,7 +397,12 @@ mod test { }; use futures::{executor::block_on, future}; use std::sync::Mutex; - use tari_comms::{message::MessageExt, test_utils::mocks::create_connectivity_mock, wrap_in_envelope_body}; + use tari_comms::{ + message::MessageExt, + runtime, + test_utils::mocks::create_connectivity_mock, + wrap_in_envelope_body, + }; use tari_test_utils::{counter_context, unpack_enum}; use tower::service_fn; @@ -469,7 +474,7 @@ mod test { assert_eq!(decrypted.decryption_result.unwrap_err(), inbound_msg.body); } - #[tokio_macros::test_basic] + #[runtime::test] async fn decrypt_inbound_fail_destination() { let (connectivity, mock) = create_connectivity_mock(); mock.spawn(); diff --git a/comms/dht/src/inbound/deserialize.rs b/comms/dht/src/inbound/deserialize.rs index b28a057cb5..a73c3b3cfb 100644 --- a/comms/dht/src/inbound/deserialize.rs +++ b/comms/dht/src/inbound/deserialize.rs @@ -137,9 +137,12 @@ mod test { service_spy, }, }; - use tari_comms::message::{MessageExt, MessageTag}; + use tari_comms::{ + message::{MessageExt, MessageTag}, + runtime, + }; - #[tokio_macros::test_basic] + #[runtime::test] async fn deserialize() { let spy = service_spy(); let peer_manager = build_peer_manager(); diff --git a/comms/dht/src/inbound/dht_handler/task.rs b/comms/dht/src/inbound/dht_handler/task.rs index f45507a905..ec42bbd4fe 100644 --- a/comms/dht/src/inbound/dht_handler/task.rs +++ 
b/comms/dht/src/inbound/dht_handler/task.rs @@ -88,6 +88,20 @@ where S: Service return Ok(()); } + if message.is_duplicate() { + debug!( + target: LOG_TARGET, + "Received message ({}) that has already been received {} time(s). Last sent by peer '{}', passing on \ + to next service (Trace: {})", + message.tag, + message.dedup_hit_count, + message.source_peer.node_id.short_str(), + message.dht_header.message_tag, + ); + self.next_service.oneshot(message).await?; + return Ok(()); + } + trace!( target: LOG_TARGET, "Received DHT message type `{}` (Source peer: {}, Tag: {}, Trace: {})", diff --git a/comms/dht/src/inbound/message.rs b/comms/dht/src/inbound/message.rs index a49ae4b073..c9cdd103fd 100644 --- a/comms/dht/src/inbound/message.rs +++ b/comms/dht/src/inbound/message.rs @@ -43,6 +43,7 @@ pub struct DhtInboundMessage { pub dht_header: DhtMessageHeader, /// True if forwarded via store and forward, otherwise false pub is_saf_message: bool, + pub dedup_hit_count: u32, pub body: Vec, } impl DhtInboundMessage { @@ -53,6 +54,7 @@ impl DhtInboundMessage { dht_header, source_peer, is_saf_message: false, + dedup_hit_count: 0, body, } } @@ -62,11 +64,12 @@ impl Display for DhtInboundMessage { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> { write!( f, - "\n---- Inbound Message ---- \nSize: {} byte(s)\nType: {}\nPeer: {}\nHeader: {}\n{}\n----", + "\n---- Inbound Message ---- \nSize: {} byte(s)\nType: {}\nPeer: {}\nHit Count: {}\nHeader: {}\n{}\n----", self.body.len(), self.dht_header.message_type, self.source_peer, self.dht_header, + self.dedup_hit_count, self.tag, ) } @@ -86,6 +89,14 @@ pub struct DecryptedDhtMessage { pub is_saf_stored: Option, pub is_already_forwarded: bool, pub decryption_result: Result>, + pub dedup_hit_count: u32, +} + +impl DecryptedDhtMessage { + /// Returns true if this message has been received before, otherwise false if this is the first time + pub fn is_duplicate(&self) -> bool { + self.dedup_hit_count > 1 + } } impl 
DecryptedDhtMessage { @@ -104,6 +115,7 @@ impl DecryptedDhtMessage { is_saf_stored: None, is_already_forwarded: false, decryption_result: Ok(message_body), + dedup_hit_count: message.dedup_hit_count, } } @@ -118,6 +130,7 @@ impl DecryptedDhtMessage { is_saf_stored: None, is_already_forwarded: false, decryption_result: Err(message.body), + dedup_hit_count: message.dedup_hit_count, } } diff --git a/comms/dht/src/lib.rs b/comms/dht/src/lib.rs index cab2f8ab6f..710b354f7b 100644 --- a/comms/dht/src/lib.rs +++ b/comms/dht/src/lib.rs @@ -71,7 +71,7 @@ //! #use std::sync::Arc; //! #use tari_comms::CommsBuilder; //! #use tokio::runtime::Runtime; -//! #use futures::channel::mpsc; +//! #use tokio::sync::mpsc; //! //! let runtime = Runtime::new().unwrap(); //! // Channel from comms to inbound dht @@ -153,11 +153,11 @@ pub use storage::DbConnectionUrl; mod dedup; pub use dedup::DedupLayer; +mod filter; mod logging_middleware; mod proto; mod rpc; mod schema; -mod tower_filter; pub mod broadcast_strategy; pub mod domain_message; diff --git a/comms/dht/src/network_discovery/on_connect.rs b/comms/dht/src/network_discovery/on_connect.rs index b93657f061..cd162b3903 100644 --- a/comms/dht/src/network_discovery/on_connect.rs +++ b/comms/dht/src/network_discovery/on_connect.rs @@ -33,7 +33,7 @@ use crate::{ }; use futures::StreamExt; use log::*; -use std::{convert::TryInto, ops::Deref}; +use std::convert::TryInto; use tari_comms::{ connectivity::ConnectivityEvent, peer_manager::{NodeId, Peer}, @@ -62,8 +62,9 @@ impl OnConnect { pub async fn next_event(&mut self) -> StateEvent { let mut connectivity_events = self.context.connectivity.get_event_subscription(); - while let Some(event) = connectivity_events.next().await { - match event.as_ref().map(|e| e.deref()) { + loop { + let event = connectivity_events.recv().await; + match event { Ok(ConnectivityEvent::PeerConnected(conn)) => { if conn.peer_features().is_client() { continue; @@ -96,10 +97,10 @@ impl OnConnect { 
self.prev_synced.push(conn.peer_node_id().clone()); }, Ok(_) => { /* Nothing to do */ }, - Err(broadcast::RecvError::Lagged(n)) => { + Err(broadcast::error::RecvError::Lagged(n)) => { warn!(target: LOG_TARGET, "Lagged behind on {} connectivity event(s)", n) }, - Err(broadcast::RecvError::Closed) => { + Err(broadcast::error::RecvError::Closed) => { break; }, } diff --git a/comms/dht/src/network_discovery/test.rs b/comms/dht/src/network_discovery/test.rs index 54f596ee26..2f854627f1 100644 --- a/comms/dht/src/network_discovery/test.rs +++ b/comms/dht/src/network_discovery/test.rs @@ -28,12 +28,12 @@ use crate::{ test_utils::{build_peer_manager, make_node_identity}, DhtConfig, }; -use futures::StreamExt; use std::{iter, sync::Arc, time::Duration}; use tari_comms::{ connectivity::ConnectivityStatus, peer_manager::{Peer, PeerFeatures}, protocol::rpc::{mock::MockRpcServer, NamedProtocolService}, + runtime, test_utils::{ mocks::{create_connectivity_mock, ConnectivityManagerMockState}, node_identity::build_node_identity, @@ -97,7 +97,7 @@ mod state_machine { ) } - #[tokio_macros::test_basic] + #[runtime::test] #[allow(clippy::redundant_closure)] async fn it_fetches_peers() { const NUM_PEERS: usize = 3; @@ -139,7 +139,7 @@ mod state_machine { mock.get_peers.set_response(Ok(peers)).await; discovery_actor.spawn(); - let event = event_rx.next().await.unwrap().unwrap(); + let event = event_rx.recv().await.unwrap(); unpack_enum!(DhtEvent::NetworkDiscoveryPeersAdded(info) = &*event); assert!(info.has_new_neighbours()); assert_eq!(info.num_new_neighbours, NUM_PEERS); @@ -149,11 +149,11 @@ mod state_machine { assert_eq!(info.sync_peers, vec![peer_node_identity.node_id().clone()]); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_shuts_down() { let (discovery, _, _, _, _, mut shutdown) = setup(Default::default(), make_node_identity(), vec![]).await; - shutdown.trigger().unwrap(); + shutdown.trigger(); tokio::time::timeout(Duration::from_secs(5), discovery.run()) .await 
.unwrap(); @@ -200,7 +200,7 @@ mod discovery_ready { (node_identity, peer_manager, connectivity_mock, ready, context) } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_begins_aggressive_discovery() { let (_, pm, _, mut ready, _) = setup(Default::default()); let peers = build_many_node_identities(1, PeerFeatures::COMMUNICATION_NODE); @@ -212,14 +212,14 @@ mod discovery_ready { assert!(params.num_peers_to_request.is_none()); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_idles_if_no_sync_peers() { let (_, _, _, mut ready, _) = setup(Default::default()); let state_event = ready.next_event().await; unpack_enum!(StateEvent::Idle = state_event); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_idles_if_num_rounds_reached() { let config = NetworkDiscoveryConfig { min_desired_peers: 0, @@ -240,7 +240,7 @@ mod discovery_ready { unpack_enum!(StateEvent::Idle = state_event); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_transitions_to_on_connect() { let config = NetworkDiscoveryConfig { min_desired_peers: 0, diff --git a/comms/dht/src/network_discovery/waiting.rs b/comms/dht/src/network_discovery/waiting.rs index 73e8929ac5..f61dfc6b24 100644 --- a/comms/dht/src/network_discovery/waiting.rs +++ b/comms/dht/src/network_discovery/waiting.rs @@ -46,7 +46,7 @@ impl Waiting { target: LOG_TARGET, "Network discovery is IDLING for {:.0?}", self.duration ); - time::delay_for(self.duration).await; + time::sleep(self.duration).await; debug!(target: LOG_TARGET, "Network discovery resuming"); StateEvent::Ready } diff --git a/comms/dht/src/outbound/broadcast.rs b/comms/dht/src/outbound/broadcast.rs index 0aa9fab611..57c423df55 100644 --- a/comms/dht/src/outbound/broadcast.rs +++ b/comms/dht/src/outbound/broadcast.rs @@ -39,7 +39,6 @@ use bytes::Bytes; use chrono::{DateTime, Utc}; use digest::Digest; use futures::{ - channel::oneshot, future, future::BoxFuture, stream::{self, StreamExt}, @@ -60,6 +59,7 @@ use tari_crypto::{ 
tari_utilities::{message_format::MessageFormat, ByteArray}, }; use tari_utilities::hex::Hex; +use tokio::sync::oneshot; use tower::{layer::Layer, Service, ServiceExt}; const LOG_TARGET: &str = "comms::dht::outbound::broadcast_middleware"; @@ -251,11 +251,12 @@ where S: Service is_discovery_enabled, force_origin, dht_header, + tag, } = params; match self.select_peers(broadcast_strategy.clone()).await { Ok(mut peers) => { - if reply_tx.is_canceled() { + if reply_tx.is_closed() { return Err(DhtOutboundError::ReplyChannelCanceled); } @@ -320,6 +321,7 @@ where S: Service is_broadcast, body, Some(expires), + tag, ) .await { @@ -411,6 +413,7 @@ where S: Service is_broadcast: bool, body: Bytes, expires: Option<DateTime<Utc>>, + tag: Option<MessageTag>, ) -> Result<(Vec<DhtOutboundMessage>, Vec<MessageSendState>), DhtOutboundError> { let dht_flags = encryption.flags() | extra_flags; @@ -424,7 +427,7 @@ where S: Service // Construct a DhtOutboundMessage for each recipient let messages = selected_peers.into_iter().map(|node_id| { let (reply_tx, reply_rx) = oneshot::channel(); - let tag = MessageTag::new(); + let tag = tag.unwrap_or_else(MessageTag::new); let send_state = MessageSendState::new(tag, reply_rx); ( DhtOutboundMessage { @@ -448,7 +451,7 @@ where S: Service Ok(messages.unzip()) } - async fn add_to_dedup_cache(&mut self, body: &[u8], public_key: CommsPublicKey) -> Result<bool, DhtOutboundError> { + async fn add_to_dedup_cache(&mut self, body: &[u8], public_key: CommsPublicKey) -> Result<(), DhtOutboundError> { let hash = Challenge::new().chain(&body).finalize().to_vec(); trace!( target: LOG_TARGET, @@ -456,10 +459,19 @@ where S: Service hash.to_hex(), ); - self.dht_requester - .insert_message_hash(hash, public_key) + // Do not count messages we've broadcast towards the total hit count + let hit_count = self + .dht_requester + .get_message_cache_hit_count(hash.clone()) .await - .map_err(|_| DhtOutboundError::FailedToInsertMessageHash) + .map_err(|err| DhtOutboundError::FailedToInsertMessageHash(err.to_string()))?; + if hit_count == 0 { +
self.dht_requester + .add_message_to_dedup_cache(hash, public_key) + .await + .map_err(|err| DhtOutboundError::FailedToInsertMessageHash(err.to_string()))?; + } + Ok(()) } fn process_encryption( @@ -525,19 +537,19 @@ mod test { DhtDiscoveryMockState, }, }; - use futures::channel::oneshot; use rand::rngs::OsRng; use std::time::Duration; use tari_comms::{ multiaddr::Multiaddr, peer_manager::{NodeId, Peer, PeerFeatures, PeerFlags}, + runtime, types::CommsPublicKey, }; use tari_crypto::keys::PublicKey; use tari_test_utils::unpack_enum; - use tokio::task; + use tokio::{sync::oneshot, task}; - #[tokio_macros::test_basic] + #[runtime::test] async fn send_message_flood() { let pk = CommsPublicKey::default(); let example_peer = Peer::new( @@ -601,7 +613,7 @@ mod test { assert!(requests.iter().any(|msg| msg.destination_node_id == other_peer.node_id)); } - #[tokio_macros::test_basic] + #[runtime::test] async fn send_message_direct_not_found() { // Test for issue https://github.com/tari-project/tari/issues/959 @@ -645,7 +657,7 @@ mod test { assert_eq!(spy.call_count(), 0); } - #[tokio_macros::test_basic] + #[runtime::test] async fn send_message_direct_dht_discovery() { let node_identity = NodeIdentity::random( &mut OsRng, diff --git a/comms/dht/src/outbound/error.rs b/comms/dht/src/outbound/error.rs index 3f93dab043..b919e45134 100644 --- a/comms/dht/src/outbound/error.rs +++ b/comms/dht/src/outbound/error.rs @@ -20,16 +20,16 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-use crate::outbound::message::SendFailure; -use futures::channel::mpsc::SendError; +use crate::outbound::{message::SendFailure, DhtOutboundRequest}; use tari_comms::message::MessageError; use tari_crypto::{signatures::SchnorrSignatureError, tari_utilities::message_format::MessageFormatError}; use thiserror::Error; +use tokio::sync::mpsc::error::SendError; #[derive(Debug, Error)] pub enum DhtOutboundError { - #[error("SendError: {0}")] - SendError(#[from] SendError), + #[error("Failed to send: {0}")] + SendError(#[from] SendError<DhtOutboundRequest>), #[error("MessageSerializationError: {0}")] MessageSerializationError(#[from] MessageError), #[error("MessageFormatError: {0}")] @@ -48,8 +48,8 @@ pub enum DhtOutboundError { SendToOurselves, #[error("Discovery process failed")] DiscoveryFailed, - #[error("Failed to insert message hash")] - FailedToInsertMessageHash, + #[error("Failed to insert message hash: {0}")] + FailedToInsertMessageHash(String), #[error("Failed to send message: {0}")] SendMessageFailed(SendFailure), #[error("No messages were queued for sending")] diff --git a/comms/dht/src/outbound/message.rs b/comms/dht/src/outbound/message.rs index 52356ec364..bb782dc2e5 100644 --- a/comms/dht/src/outbound/message.rs +++ b/comms/dht/src/outbound/message.rs @@ -25,7 +25,6 @@ use crate::{ outbound::{message_params::FinalSendMessageParams, message_send_state::MessageSendStates}, }; use bytes::Bytes; -use futures::channel::oneshot; use std::{fmt, fmt::Display, sync::Arc}; use tari_comms::{ message::{MessageTag, MessagingReplyTx}, @@ -34,6 +33,7 @@ use tari_comms::{ }; use tari_utilities::hex::Hex; use thiserror::Error; +use tokio::sync::oneshot; /// Determines if an outbound message should be Encrypted and, if so, for which public key #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/comms/dht/src/outbound/message_params.rs b/comms/dht/src/outbound/message_params.rs index 0ad00bbc4e..3b38272c38 100644 --- a/comms/dht/src/outbound/message_params.rs +++
b/comms/dht/src/outbound/message_params.rs @@ -27,7 +27,7 @@ use crate::{ proto::envelope::DhtMessageType, }; use std::{fmt, fmt::Display}; -use tari_comms::{peer_manager::NodeId, types::CommsPublicKey}; +use tari_comms::{message::MessageTag, peer_manager::NodeId, types::CommsPublicKey}; /// Configuration for outbound messages. /// @@ -66,6 +66,7 @@ pub struct FinalSendMessageParams { pub dht_message_type: DhtMessageType, pub dht_message_flags: DhtMessageFlags, pub dht_header: Option<DhtMessageHeader>, + pub tag: Option<MessageTag>, } impl Default for FinalSendMessageParams { @@ -79,6 +80,7 @@ impl Default for FinalSendMessageParams { force_origin: false, is_discovery_enabled: false, dht_header: None, + tag: None, } } } @@ -171,6 +173,12 @@ impl SendMessageParams { self } + /// Set the message trace tag + pub fn with_tag(&mut self, tag: MessageTag) -> &mut Self { + self.params_mut().tag = Some(tag); + self + } + /// Set destination field in message header. pub fn with_destination(&mut self, destination: NodeDestination) -> &mut Self { self.params_mut().destination = destination; diff --git a/comms/dht/src/outbound/message_send_state.rs b/comms/dht/src/outbound/message_send_state.rs index 1576e87c70..b3ca43fcb2 100644 --- a/comms/dht/src/outbound/message_send_state.rs +++ b/comms/dht/src/outbound/message_send_state.rs @@ -250,9 +250,9 @@ impl Index<usize> for MessageSendStates { #[cfg(test)] mod test { use super::*; - use futures::channel::oneshot; use std::iter::repeat_with; - use tari_comms::message::MessagingReplyTx; + use tari_comms::{message::MessagingReplyTx, runtime}; + use tokio::sync::oneshot; fn create_send_state() -> (MessageSendState, MessagingReplyTx) { let (reply_tx, reply_rx) = oneshot::channel(); @@ -269,7 +269,7 @@ mod test { assert!(!states.is_empty()); } - #[tokio_macros::test_basic] + #[runtime::test] async fn wait_single() { let (state, mut reply_tx) = create_send_state(); let states = MessageSendStates::from(vec![state]); @@ -284,7 +284,7 @@
assert!(!states.wait_single().await); } - #[tokio_macros::test_basic] + #[runtime::test] #[allow(clippy::redundant_closure)] async fn wait_percentage_success() { let states = repeat_with(|| create_send_state()).take(10).collect::<Vec<_>>(); @@ -300,7 +300,7 @@ assert_eq!(failed.len(), 4); } - #[tokio_macros::test_basic] + #[runtime::test] #[allow(clippy::redundant_closure)] async fn wait_n_timeout() { let states = repeat_with(|| create_send_state()).take(10).collect::<Vec<_>>(); @@ -329,7 +329,7 @@ assert_eq!(failed.len(), 6); } - #[tokio_macros::test_basic] + #[runtime::test] #[allow(clippy::redundant_closure)] async fn wait_all() { let states = repeat_with(|| create_send_state()).take(10).collect::<Vec<_>>(); diff --git a/comms/dht/src/outbound/mock.rs b/comms/dht/src/outbound/mock.rs index f5c3f30665..66e2b7258e 100644 --- a/comms/dht/src/outbound/mock.rs +++ b/comms/dht/src/outbound/mock.rs @@ -31,11 +31,6 @@ use crate::{ }, }; use bytes::Bytes; -use futures::{ - channel::{mpsc, oneshot}, - stream::Fuse, - StreamExt, -}; use log::*; use std::{ sync::{Arc, Condvar, Mutex, RwLock}, @@ -45,7 +40,10 @@ use tari_comms::{ message::{MessageTag, MessagingReplyTx}, protocol::messaging::SendFailReason, }; -use tokio::time::delay_for; +use tokio::{ + sync::{mpsc, oneshot}, + time::sleep, +}; const LOG_TARGET: &str = "mock::outbound_requester"; @@ -54,7 +52,7 @@ const LOG_TARGET: &str = "mock::outbound_requester"; /// Each time a request is expected, handle_next should be called.
pub fn create_outbound_service_mock(size: usize) -> (OutboundMessageRequester, OutboundServiceMock) { let (tx, rx) = mpsc::channel(size); - (OutboundMessageRequester::new(tx), OutboundServiceMock::new(rx.fuse())) + (OutboundMessageRequester::new(tx), OutboundServiceMock::new(rx)) } #[derive(Clone, Default)] @@ -149,12 +147,12 @@ impl OutboundServiceMockState { } pub struct OutboundServiceMock { - receiver: Fuse<mpsc::Receiver<DhtOutboundRequest>>, + receiver: mpsc::Receiver<DhtOutboundRequest>, mock_state: OutboundServiceMockState, } impl OutboundServiceMock { - pub fn new(receiver: Fuse<mpsc::Receiver<DhtOutboundRequest>>) -> Self { + pub fn new(receiver: mpsc::Receiver<DhtOutboundRequest>) -> Self { Self { receiver, mock_state: OutboundServiceMockState::new(), @@ -166,7 +164,7 @@ impl OutboundServiceMock { } pub async fn run(mut self) { - while let Some(req) = self.receiver.next().await { + while let Some(req) = self.receiver.recv().await { match req { DhtOutboundRequest::SendMessage(params, body, reply_tx) => { let behaviour = self.mock_state.get_behaviour(); @@ -192,7 +190,7 @@ impl OutboundServiceMock { ResponseType::QueuedSuccessDelay(delay) => { let (response, mut inner_reply_tx) = self.add_call((*params).clone(), body); reply_tx.send(response).expect("Reply channel cancelled"); - delay_for(delay).await; + sleep(delay).await; inner_reply_tx.reply_success(); }, resp => { diff --git a/comms/dht/src/outbound/requester.rs b/comms/dht/src/outbound/requester.rs index c1957a98b0..f8536c9d9c 100644 --- a/comms/dht/src/outbound/requester.rs +++ b/comms/dht/src/outbound/requester.rs @@ -32,12 +32,9 @@ use crate::{ MessageSendStates, }, }; -use futures::{ - channel::{mpsc, oneshot}, - SinkExt, -}; use log::*; use tari_comms::{message::MessageExt, peer_manager::NodeId, types::CommsPublicKey, wrap_in_envelope_body}; +use tokio::sync::{mpsc, oneshot}; const LOG_TARGET: &str = "comms::dht::requests::outbound"; diff --git a/comms/dht/src/outbound/serialize.rs b/comms/dht/src/outbound/serialize.rs index 97c7c0df58..195d2d3d39 100644 --- a/comms/dht/src/outbound/serialize.rs +++
b/comms/dht/src/outbound/serialize.rs @@ -137,9 +137,9 @@ mod test { use super::*; use crate::test_utils::{assert_send_static_service, create_outbound_message, service_spy}; use prost::Message; - use tari_comms::peer_manager::NodeId; + use tari_comms::{peer_manager::NodeId, runtime}; - #[tokio_macros::test_basic] + #[runtime::test] async fn serialize() { let spy = service_spy(); let mut serialize = SerializeLayer.layer(spy.to_service::()); diff --git a/comms/dht/src/rpc/service.rs b/comms/dht/src/rpc/service.rs index e762ed4d79..84aef6e5ff 100644 --- a/comms/dht/src/rpc/service.rs +++ b/comms/dht/src/rpc/service.rs @@ -24,16 +24,16 @@ use crate::{ proto::rpc::{GetCloserPeersRequest, GetPeersRequest, GetPeersResponse}, rpc::DhtRpcService, }; -use futures::{channel::mpsc, stream, SinkExt}; use log::*; use std::{cmp, sync::Arc}; use tari_comms::{ peer_manager::{NodeId, Peer, PeerFeatures, PeerQuery}, protocol::rpc::{Request, RpcError, RpcStatus, Streaming}, + utils, PeerManager, }; use tari_utilities::ByteArray; -use tokio::task; +use tokio::{sync::mpsc, task}; const LOG_TARGET: &str = "comms::dht::rpc"; @@ -56,17 +56,15 @@ impl DhtRpcServiceImpl { // A maximum buffer size of 10 is selected arbitrarily and is to allow the producer/consumer some room to // buffer. 
- let (mut tx, rx) = mpsc::channel(cmp::min(10, peers.len() as usize)); + let (tx, rx) = mpsc::channel(cmp::min(10, peers.len() as usize)); task::spawn(async move { let iter = peers .into_iter() .map(|peer| GetPeersResponse { peer: Some(peer.into()), }) - .map(Ok) .map(Ok); - let mut stream = stream::iter(iter); - let _ = tx.send_all(&mut stream).await; + let _ = utils::mpsc::send_all(&tx, iter).await; }); Streaming::new(rx) diff --git a/comms/dht/src/rpc/test.rs b/comms/dht/src/rpc/test.rs index 764d49ba7f..cd70d65a0f 100644 --- a/comms/dht/src/rpc/test.rs +++ b/comms/dht/src/rpc/test.rs @@ -26,13 +26,16 @@ use crate::{ test_utils::build_peer_manager, }; use futures::StreamExt; -use std::{convert::TryInto, sync::Arc}; +use std::{convert::TryInto, sync::Arc, time::Duration}; use tari_comms::{ - peer_manager::{node_id::NodeDistance, PeerFeatures}, + peer_manager::{node_id::NodeDistance, NodeId, Peer, PeerFeatures}, protocol::rpc::{mock::RpcRequestMock, RpcStatusCode}, + runtime, test_utils::node_identity::{build_node_identity, ordered_node_identities_by_distance}, PeerManager, }; +use tari_test_utils::collect_recv; +use tari_utilities::ByteArray; fn setup() -> (DhtRpcServiceImpl, RpcRequestMock, Arc) { let peer_manager = build_peer_manager(); @@ -45,10 +48,8 @@ fn setup() -> (DhtRpcServiceImpl, RpcRequestMock, Arc) { // Unit tests for get_closer_peers request mod get_closer_peers { use super::*; - use tari_comms::peer_manager::{NodeId, Peer}; - use tari_utilities::ByteArray; - #[tokio_macros::test_basic] + #[runtime::test] async fn it_returns_empty_peer_stream() { let (service, mock, _) = setup(); let node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -66,7 +67,7 @@ mod get_closer_peers { assert!(next.is_none()); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_returns_closest_peers() { let (service, mock, peer_manager) = setup(); let node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -83,7 +84,7 @@ mod 
get_closer_peers { let req = mock.request_with_context(node_identity.node_id().clone(), req); let peers_stream = service.get_closer_peers(req).await.unwrap(); - let results = peers_stream.into_inner().collect::>().await; + let results = collect_recv!(peers_stream.into_inner(), timeout = Duration::from_secs(10)); assert_eq!(results.len(), 10); let peers = results @@ -101,7 +102,7 @@ mod get_closer_peers { } } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_returns_n_peers() { let (service, mock, peer_manager) = setup(); @@ -123,7 +124,7 @@ mod get_closer_peers { assert_eq!(results.len(), 5); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_skips_excluded_peers() { let (service, mock, peer_manager) = setup(); @@ -142,12 +143,12 @@ mod get_closer_peers { let req = mock.request_with_context(node_identity.node_id().clone(), req); let peers_stream = service.get_closer_peers(req).await.unwrap(); - let results = peers_stream.into_inner().collect::>().await; + let results = collect_recv!(peers_stream.into_inner(), timeout = Duration::from_secs(10)); let mut peers = results.into_iter().map(Result::unwrap).map(|r| r.peer.unwrap()); assert!(peers.all(|p| p.public_key != excluded_peer.public_key().as_bytes())); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_errors_if_maximum_n_exceeded() { let (service, mock, _) = setup(); let req = GetCloserPeersRequest { @@ -165,9 +166,10 @@ mod get_closer_peers { mod get_peers { use super::*; use crate::proto::rpc::GetPeersRequest; + use std::time::Duration; use tari_comms::{peer_manager::Peer, test_utils::node_identity::build_many_node_identities}; - #[tokio_macros::test_basic] + #[runtime::test] async fn it_returns_empty_peer_stream() { let (service, mock, _) = setup(); let node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -183,7 +185,7 @@ mod get_peers { assert!(next.is_none()); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_returns_all_peers() { let 
(service, mock, peer_manager) = setup(); let nodes = build_many_node_identities(3, PeerFeatures::COMMUNICATION_NODE); @@ -200,7 +202,7 @@ mod get_peers { .get_peers(mock.request_with_context(Default::default(), req)) .await .unwrap(); - let results = peers_stream.into_inner().collect::>().await; + let results = collect_recv!(peers_stream.into_inner(), timeout = Duration::from_secs(10)); assert_eq!(results.len(), 5); let peers = results @@ -214,7 +216,7 @@ mod get_peers { assert_eq!(peers.iter().filter(|p| p.features.is_node()).count(), 3); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_excludes_clients() { let (service, mock, peer_manager) = setup(); let nodes = build_many_node_identities(3, PeerFeatures::COMMUNICATION_NODE); @@ -231,7 +233,7 @@ mod get_peers { .get_peers(mock.request_with_context(Default::default(), req)) .await .unwrap(); - let results = peers_stream.into_inner().collect::>().await; + let results = collect_recv!(peers_stream.into_inner(), timeout = Duration::from_secs(10)); assert_eq!(results.len(), 3); let peers = results @@ -244,7 +246,7 @@ mod get_peers { assert!(peers.iter().all(|p| p.features.is_node())); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_returns_n_peers() { let (service, mock, peer_manager) = setup(); diff --git a/comms/dht/src/storage/connection.rs b/comms/dht/src/storage/connection.rs index 856a94315a..ee99f8b560 100644 --- a/comms/dht/src/storage/connection.rs +++ b/comms/dht/src/storage/connection.rs @@ -123,16 +123,17 @@ impl DbConnection { mod test { use super::*; use diesel::{expression::sql_literal::sql, sql_types::Integer, RunQueryDsl}; + use tari_comms::runtime; use tari_test_utils::random; - #[tokio_macros::test_basic] + #[runtime::test] async fn connect_and_migrate() { let conn = DbConnection::connect_memory(random::string(8)).await.unwrap(); let output = conn.migrate().await.unwrap(); assert!(output.starts_with("Running migration")); } - #[tokio_macros::test_basic] + #[runtime::test] 
async fn memory_connections() { let id = random::string(8); let conn = DbConnection::connect_memory(id.clone()).await.unwrap(); diff --git a/comms/dht/src/storage/error.rs b/comms/dht/src/storage/error.rs index ab9f52f78d..f5bf4f0596 100644 --- a/comms/dht/src/storage/error.rs +++ b/comms/dht/src/storage/error.rs @@ -40,4 +40,6 @@ pub enum StorageError { ResultError(#[from] diesel::result::Error), #[error("MessageFormatError: {0}")] MessageFormatError(#[from] MessageFormatError), + #[error("Unexpected result: {0}")] + UnexpectedResult(String), } diff --git a/comms/dht/src/store_forward/database/mod.rs b/comms/dht/src/store_forward/database/mod.rs index 173d00e0ef..58ee06eb9c 100644 --- a/comms/dht/src/store_forward/database/mod.rs +++ b/comms/dht/src/store_forward/database/mod.rs @@ -255,9 +255,10 @@ impl StoreAndForwardDatabase { #[cfg(test)] mod test { use super::*; + use tari_comms::runtime; use tari_test_utils::random; - #[tokio_macros::test_basic] + #[runtime::test] async fn insert_messages() { let conn = DbConnection::connect_memory(random::string(8)).await.unwrap(); conn.migrate().await.unwrap(); @@ -277,7 +278,7 @@ mod test { assert_eq!(messages[1].body_hash, msg2.body_hash); } - #[tokio_macros::test_basic] + #[runtime::test] async fn remove_messages() { let conn = DbConnection::connect_memory(random::string(8)).await.unwrap(); conn.migrate().await.unwrap(); @@ -304,7 +305,7 @@ mod test { assert_eq!(messages[0].id, msg2_id); } - #[tokio_macros::test_basic] + #[runtime::test] async fn truncate_messages() { let conn = DbConnection::connect_memory(random::string(8)).await.unwrap(); conn.migrate().await.unwrap(); diff --git a/comms/dht/src/store_forward/forward.rs b/comms/dht/src/store_forward/forward.rs index 95ce5e2500..d8de4fe048 100644 --- a/comms/dht/src/store_forward/forward.rs +++ b/comms/dht/src/store_forward/forward.rs @@ -153,7 +153,7 @@ where S: Service self.forward(&message).await?; } - // The message has been forwarded, but other middleware may be 
interested (i.e. StoreMiddleware) + // The message has been forwarded, but downstream middleware may be interested trace!( target: LOG_TARGET, "Passing message {} to next service (Trace: {})", @@ -205,8 +205,9 @@ where S: Service } let body = decryption_result - .clone() + .as_ref() .err() + .cloned() .expect("previous check that decryption failed"); let excluded_peers = vec![source_peer.node_id.clone()]; @@ -262,14 +263,13 @@ mod test { outbound::mock::create_outbound_service_mock, test_utils::{make_dht_inbound_message, make_node_identity, service_spy}, }; - use futures::{channel::mpsc, executor::block_on}; - use tari_comms::wrap_in_envelope_body; - use tokio::runtime::Runtime; + use tari_comms::{runtime, runtime::task, wrap_in_envelope_body}; + use tokio::sync::mpsc; - #[test] - fn decryption_succeeded() { + #[runtime::test] + async fn decryption_succeeded() { let spy = service_spy(); - let (oms_tx, mut oms_rx) = mpsc::channel(1); + let (oms_tx, _) = mpsc::channel(1); let oms = OutboundMessageRequester::new(oms_tx); let mut service = ForwardLayer::new(oms, true).layer(spy.to_service::()); @@ -280,18 +280,16 @@ mod test { Some(node_identity.public_key().clone()), inbound_msg, ); - block_on(service.call(msg)).unwrap(); + service.call(msg).await.unwrap(); assert!(spy.is_called()); - assert!(oms_rx.try_next().is_err()); } - #[test] - fn decryption_failed() { - let mut rt = Runtime::new().unwrap(); + #[runtime::test] + async fn decryption_failed() { let spy = service_spy(); let (oms_requester, oms_mock) = create_outbound_service_mock(1); let oms_mock_state = oms_mock.get_state(); - rt.spawn(oms_mock.run()); + task::spawn(oms_mock.run()); let mut service = ForwardLayer::new(oms_requester, true).layer(spy.to_service::()); @@ -304,7 +302,7 @@ mod test { ); let header = inbound_msg.dht_header.clone(); let msg = DecryptedDhtMessage::failed(inbound_msg); - rt.block_on(service.call(msg)).unwrap(); + service.call(msg).await.unwrap(); assert!(spy.is_called()); 
assert_eq!(oms_mock_state.call_count(), 1); diff --git a/comms/dht/src/store_forward/saf_handler/layer.rs b/comms/dht/src/store_forward/saf_handler/layer.rs index 50b6ab7839..16e2760a1e 100644 --- a/comms/dht/src/store_forward/saf_handler/layer.rs +++ b/comms/dht/src/store_forward/saf_handler/layer.rs @@ -27,9 +27,9 @@ use crate::{ outbound::OutboundMessageRequester, store_forward::StoreAndForwardRequester, }; -use futures::channel::mpsc; use std::sync::Arc; use tari_comms::peer_manager::{NodeIdentity, PeerManager}; +use tokio::sync::mpsc; use tower::layer::Layer; pub struct MessageHandlerLayer { diff --git a/comms/dht/src/store_forward/saf_handler/middleware.rs b/comms/dht/src/store_forward/saf_handler/middleware.rs index 578fc1dcbc..641950e4f1 100644 --- a/comms/dht/src/store_forward/saf_handler/middleware.rs +++ b/comms/dht/src/store_forward/saf_handler/middleware.rs @@ -28,12 +28,13 @@ use crate::{ outbound::OutboundMessageRequester, store_forward::StoreAndForwardRequester, }; -use futures::{channel::mpsc, future::BoxFuture, task::Context}; +use futures::{future::BoxFuture, task::Context}; use std::{sync::Arc, task::Poll}; use tari_comms::{ peer_manager::{NodeIdentity, PeerManager}, pipeline::PipelineError, }; +use tokio::sync::mpsc; use tower::Service; #[derive(Clone)] diff --git a/comms/dht/src/store_forward/saf_handler/task.rs b/comms/dht/src/store_forward/saf_handler/task.rs index f3ba852118..c6224a7af7 100644 --- a/comms/dht/src/store_forward/saf_handler/task.rs +++ b/comms/dht/src/store_forward/saf_handler/task.rs @@ -41,7 +41,7 @@ use crate::{ }; use chrono::{DateTime, NaiveDateTime, Utc}; use digest::Digest; -use futures::{channel::mpsc, future, stream, SinkExt, StreamExt}; +use futures::{future, stream, StreamExt}; use log::*; use prost::Message; use std::{convert::TryInto, sync::Arc}; @@ -53,6 +53,7 @@ use tari_comms::{ utils::signature, }; use tari_utilities::{convert::try_convert_all, ByteArray}; +use tokio::sync::mpsc; use tower::{Service, 
ServiceExt}; const LOG_TARGET: &str = "comms::dht::storeforward::handler"; @@ -103,6 +104,20 @@ where S: Service .take() .expect("DhtInboundMessageTask initialized without message"); + if message.is_duplicate() { + debug!( + target: LOG_TARGET, + "Received message ({}) that has already been received {} time(s). Last sent by peer '{}', passing on \ + (Trace: {})", + message.tag, + message.dedup_hit_count, + message.source_peer.node_id.short_str(), + message.dht_header.message_tag, + ); + self.next_service.oneshot(message).await?; + return Ok(()); + } + if message.dht_header.message_type.is_saf_message() && message.decryption_failed() { debug!( target: LOG_TARGET, @@ -460,7 +475,8 @@ where S: Service public_key: CommsPublicKey, ) -> Result<(), StoreAndForwardError> { let msg_hash = Challenge::new().chain(body).finalize().to_vec(); - if dht_requester.insert_message_hash(msg_hash, public_key).await? { + let hit_count = dht_requester.add_message_to_dedup_cache(msg_hash, public_key).await?; + if hit_count > 1 { Err(StoreAndForwardError::DuplicateMessage) } else { Ok(()) @@ -567,14 +583,13 @@ mod test { }, }; use chrono::{Duration as OldDuration, Utc}; - use futures::channel::mpsc; use prost::Message; use std::time::Duration; - use tari_comms::{message::MessageExt, wrap_in_envelope_body}; + use tari_comms::{message::MessageExt, runtime, wrap_in_envelope_body}; use tari_crypto::tari_utilities::hex; - use tari_test_utils::collect_stream; + use tari_test_utils::collect_recv; use tari_utilities::hex::Hex; - use tokio::{runtime::Handle, task, time::delay_for}; + use tokio::{runtime::Handle, sync::mpsc, task, time::sleep}; // TODO: unit tests for static functions (check_signature, etc) @@ -602,7 +617,7 @@ mod test { } } - #[tokio_macros::test] + #[tokio::test] async fn request_stored_messages() { let spy = service_spy(); let (requester, mock_state) = create_store_and_forward_mock(); @@ -662,7 +677,7 @@ mod test { if oms_mock_state.call_count() >= 1 { break; } - 
delay_for(Duration::from_secs(5)).await; + sleep(Duration::from_secs(5)).await; } assert_eq!(oms_mock_state.call_count(), 1); @@ -724,7 +739,7 @@ mod test { if oms_mock_state.call_count() >= 1 { break; } - delay_for(Duration::from_secs(5)).await; + sleep(Duration::from_secs(5)).await; } assert_eq!(oms_mock_state.call_count(), 1); let call = oms_mock_state.pop_call().unwrap(); @@ -750,7 +765,7 @@ mod test { assert!(stored_messages.iter().any(|s| s.body == msg2.as_bytes())); } - #[tokio_macros::test_basic] + #[runtime::test] async fn receive_stored_messages() { let rt_handle = Handle::current(); let spy = service_spy(); @@ -845,7 +860,7 @@ mod test { assert!(msgs.contains(&b"A".to_vec())); assert!(msgs.contains(&b"B".to_vec())); assert!(msgs.contains(&b"Clear".to_vec())); - let signals = collect_stream!( + let signals = collect_recv!( saf_response_signal_receiver, take = 1, timeout = Duration::from_secs(20) diff --git a/comms/dht/src/store_forward/service.rs b/comms/dht/src/store_forward/service.rs index 5d06d85d56..7c8af0b731 100644 --- a/comms/dht/src/store_forward/service.rs +++ b/comms/dht/src/store_forward/service.rs @@ -36,12 +36,6 @@ use crate::{ DhtRequester, }; use chrono::{DateTime, NaiveDateTime, Utc}; -use futures::{ - channel::{mpsc, oneshot}, - stream::Fuse, - SinkExt, - StreamExt, -}; use log::*; use std::{convert::TryFrom, sync::Arc, time::Duration}; use tari_comms::{ @@ -51,7 +45,11 @@ use tari_comms::{ PeerManager, }; use tari_shutdown::ShutdownSignal; -use tokio::{task, time}; +use tokio::{ + sync::{mpsc, oneshot}, + task, + time, +}; const LOG_TARGET: &str = "comms::dht::storeforward::actor"; /// The interval to initiate a database cleanup. 
@@ -167,13 +165,13 @@ pub struct StoreAndForwardService { dht_requester: DhtRequester, database: StoreAndForwardDatabase, peer_manager: Arc<PeerManager>, - connection_events: Fuse<ConnectivityEventRx>, + connection_events: ConnectivityEventRx, outbound_requester: OutboundMessageRequester, - request_rx: Fuse<mpsc::Receiver<StoreAndForwardRequest>>, - shutdown_signal: Option<ShutdownSignal>, + request_rx: mpsc::Receiver<StoreAndForwardRequest>, + shutdown_signal: ShutdownSignal, num_received_saf_responses: Option<usize>, num_online_peers: Option<usize>, - saf_response_signal_rx: Fuse<mpsc::Receiver<()>>, + saf_response_signal_rx: mpsc::Receiver<()>, event_publisher: DhtEventSender, } @@ -196,13 +194,13 @@ impl StoreAndForwardService { database: StoreAndForwardDatabase::new(conn), peer_manager, dht_requester, - request_rx: request_rx.fuse(), - connection_events: connectivity.get_event_subscription().fuse(), + request_rx, + connection_events: connectivity.get_event_subscription(), outbound_requester, - shutdown_signal: Some(shutdown_signal), + shutdown_signal, num_received_saf_responses: Some(0), num_online_peers: None, - saf_response_signal_rx: saf_response_signal_rx.fuse(), + saf_response_signal_rx, event_publisher, } } @@ -213,20 +211,15 @@ impl StoreAndForwardService { } async fn run(mut self) { - let mut shutdown_signal = self - .shutdown_signal - .take() - .expect("StoreAndForwardActor initialized without shutdown_signal"); - - let mut cleanup_ticker = time::interval(CLEANUP_INTERVAL).fuse(); + let mut cleanup_ticker = time::interval(CLEANUP_INTERVAL); loop { - futures::select! { - request = self.request_rx.select_next_some() => { + tokio::select!
{ + Some(request) = self.request_rx.recv() => { self.handle_request(request).await; }, - event = self.connection_events.select_next_some() => { + event = self.connection_events.recv() => { if let Ok(event) = event { if let Err(err) = self.handle_connectivity_event(&event).await { error!(target: LOG_TARGET, "Error handling connection manager event: {:?}", err); @@ -234,20 +227,20 @@ impl StoreAndForwardService { } }, - _ = cleanup_ticker.select_next_some() => { + _ = cleanup_ticker.tick() => { if let Err(err) = self.cleanup().await { error!(target: LOG_TARGET, "Error when performing store and forward cleanup: {:?}", err); } }, - _ = self.saf_response_signal_rx.select_next_some() => { + Some(_) = self.saf_response_signal_rx.recv() => { if let Some(n) = self.num_received_saf_responses { self.num_received_saf_responses = Some(n + 1); self.check_saf_response_threshold(); } }, - _ = shutdown_signal => { + _ = self.shutdown_signal.wait() => { info!(target: LOG_TARGET, "StoreAndForwardActor is shutting down because the shutdown signal was triggered"); break; } diff --git a/comms/dht/src/store_forward/store.rs b/comms/dht/src/store_forward/store.rs index 4393f36518..e9b88a37ad 100644 --- a/comms/dht/src/store_forward/store.rs +++ b/comms/dht/src/store_forward/store.rs @@ -122,16 +122,31 @@ where } fn call(&mut self, msg: DecryptedDhtMessage) -> Self::Future { - Box::pin( - StoreTask::new( - self.next_service.clone(), - self.config.clone(), - Arc::clone(&self.peer_manager), - Arc::clone(&self.node_identity), - self.saf_requester.clone(), + if msg.is_duplicate() { + trace!( + target: LOG_TARGET, + "Passing duplicate message {} to next service (Trace: {})", + msg.tag, + msg.dht_header.message_tag + ); + + let service = self.next_service.clone(); + Box::pin(async move { + let service = service.ready_oneshot().await?; + service.oneshot(msg).await + }) + } else { + Box::pin( + StoreTask::new( + self.next_service.clone(), + self.config.clone(), + Arc::clone(&self.peer_manager), + 
Arc::clone(&self.node_identity), + self.saf_requester.clone(), + ) + .handle(msg), ) - .handle(msg), - ) + } } } @@ -447,11 +462,11 @@ mod test { }; use chrono::Utc; use std::time::Duration; - use tari_comms::wrap_in_envelope_body; + use tari_comms::{runtime, wrap_in_envelope_body}; use tari_test_utils::async_assert_eventually; use tari_utilities::hex::Hex; - #[tokio_macros::test_basic] + #[runtime::test] async fn cleartext_message_no_origin() { let (requester, mock_state) = create_store_and_forward_mock(); @@ -471,7 +486,7 @@ mod test { assert_eq!(messages.len(), 0); } - #[tokio_macros::test_basic] + #[runtime::test] async fn decryption_succeeded_no_store() { let (requester, mock_state) = create_store_and_forward_mock(); @@ -499,7 +514,7 @@ mod test { assert_eq!(mock_state.call_count(), 0); } - #[tokio_macros::test_basic] + #[runtime::test] async fn decryption_failed_should_store() { let (requester, mock_state) = create_store_and_forward_mock(); let spy = service_spy(); @@ -538,7 +553,7 @@ mod test { assert!(duration.num_seconds() <= 5); } - #[tokio_macros::test_basic] + #[runtime::test] async fn decryption_failed_banned_peer() { let (requester, mock_state) = create_store_and_forward_mock(); let spy = service_spy(); diff --git a/comms/dht/src/test_utils/dht_actor_mock.rs b/comms/dht/src/test_utils/dht_actor_mock.rs index ccc53c5a1e..4cfd99f209 100644 --- a/comms/dht/src/test_utils/dht_actor_mock.rs +++ b/comms/dht/src/test_utils/dht_actor_mock.rs @@ -25,26 +25,25 @@ use crate::{ actor::{DhtRequest, DhtRequester}, storage::DhtMetadataKey, }; -use futures::{channel::mpsc, stream::Fuse, StreamExt}; use std::{ collections::HashMap, sync::{ - atomic::{AtomicBool, AtomicUsize, Ordering}, + atomic::{AtomicUsize, Ordering}, Arc, RwLock, }, }; use tari_comms::peer_manager::Peer; -use tokio::task; +use tokio::{sync::mpsc, task}; pub fn create_dht_actor_mock(buf_size: usize) -> (DhtRequester, DhtActorMock) { let (tx, rx) = mpsc::channel(buf_size); - (DhtRequester::new(tx), 
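The `store.rs` hunk above short-circuits duplicate messages past the `StoreTask` and straight to the next service. A minimal synchronous stand-in for that control flow (the real code is an async tower middleware; `Msg`, `stored`, and `delivered` here are illustrative stand-ins, not the DHT types):

```rust
// Sketch of the duplicate short-circuit: first-time messages go through the
// storage step, duplicates skip it and are passed straight through.
struct Msg {
    is_duplicate: bool,
    body: String,
}

fn handle(msg: Msg, stored: &mut Vec<String>, delivered: &mut Vec<String>) {
    if msg.is_duplicate {
        // Duplicates bypass the store step entirely.
        delivered.push(msg.body);
    } else {
        stored.push(msg.body.clone());
        delivered.push(msg.body);
    }
}

fn main() {
    let (mut stored, mut delivered) = (Vec::new(), Vec::new());
    handle(Msg { is_duplicate: false, body: "first".into() }, &mut stored, &mut delivered);
    handle(Msg { is_duplicate: true, body: "first".into() }, &mut stored, &mut delivered);
    assert_eq!(stored.len(), 1);   // stored only once
    assert_eq!(delivered.len(), 2); // but delivered both times
    println!("stored={} delivered={}", stored.len(), delivered.len());
}
```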
DhtActorMock::new(rx.fuse())) + (DhtRequester::new(tx), DhtActorMock::new(rx)) } #[derive(Default, Debug, Clone)] pub struct DhtMockState { - signature_cache_insert: Arc, + signature_cache_insert: Arc, call_count: Arc, select_peers: Arc>>, settings: Arc>>>, @@ -52,16 +51,11 @@ pub struct DhtMockState { impl DhtMockState { pub fn new() -> Self { - Self { - signature_cache_insert: Arc::new(AtomicBool::new(false)), - call_count: Arc::new(AtomicUsize::new(0)), - select_peers: Arc::new(RwLock::new(Vec::new())), - settings: Arc::new(RwLock::new(HashMap::new())), - } + Default::default() } - pub fn set_signature_cache_insert(&self, v: bool) -> &Self { - self.signature_cache_insert.store(v, Ordering::SeqCst); + pub fn set_number_of_message_hits(&self, v: u32) -> &Self { + self.signature_cache_insert.store(v as usize, Ordering::SeqCst); self } @@ -80,12 +74,12 @@ impl DhtMockState { } pub struct DhtActorMock { - receiver: Fuse>, + receiver: mpsc::Receiver, state: DhtMockState, } impl DhtActorMock { - pub fn new(receiver: Fuse>) -> Self { + pub fn new(receiver: mpsc::Receiver) -> Self { Self { receiver, state: DhtMockState::default(), @@ -101,7 +95,7 @@ impl DhtActorMock { } pub async fn run(mut self) { - while let Some(req) = self.receiver.next().await { + while let Some(req) = self.receiver.recv().await { self.handle_request(req).await; } } @@ -111,9 +105,13 @@ impl DhtActorMock { self.state.inc_call_count(); match req { SendJoin => {}, - MsgHashCacheInsert(_, _, reply_tx) => { + MsgHashCacheInsert { reply_tx, .. 
} => { + let v = self.state.signature_cache_insert.load(Ordering::SeqCst); + reply_tx.send(v as u32).unwrap(); + }, + GetMsgHashHitCount(_, reply_tx) => { let v = self.state.signature_cache_insert.load(Ordering::SeqCst); - reply_tx.send(v).unwrap(); + reply_tx.send(v as u32).unwrap(); }, SelectPeers(_, reply_tx) => { let lock = self.state.select_peers.read().unwrap(); diff --git a/comms/dht/src/test_utils/dht_discovery_mock.rs b/comms/dht/src/test_utils/dht_discovery_mock.rs index 70575e2ae0..fbdf8e8284 100644 --- a/comms/dht/src/test_utils/dht_discovery_mock.rs +++ b/comms/dht/src/test_utils/dht_discovery_mock.rs @@ -24,7 +24,6 @@ use crate::{ discovery::{DhtDiscoveryRequest, DhtDiscoveryRequester}, test_utils::make_peer, }; -use futures::{channel::mpsc, stream::Fuse, StreamExt}; use log::*; use std::{ sync::{ @@ -35,15 +34,13 @@ use std::{ time::Duration, }; use tari_comms::peer_manager::Peer; +use tokio::sync::mpsc; const LOG_TARGET: &str = "comms::dht::discovery_mock"; pub fn create_dht_discovery_mock(buf_size: usize, timeout: Duration) -> (DhtDiscoveryRequester, DhtDiscoveryMock) { let (tx, rx) = mpsc::channel(buf_size); - ( - DhtDiscoveryRequester::new(tx, timeout), - DhtDiscoveryMock::new(rx.fuse()), - ) + (DhtDiscoveryRequester::new(tx, timeout), DhtDiscoveryMock::new(rx)) } #[derive(Debug, Clone)] @@ -75,12 +72,12 @@ impl DhtDiscoveryMockState { } pub struct DhtDiscoveryMock { - receiver: Fuse>, + receiver: mpsc::Receiver, state: DhtDiscoveryMockState, } impl DhtDiscoveryMock { - pub fn new(receiver: Fuse>) -> Self { + pub fn new(receiver: mpsc::Receiver) -> Self { Self { receiver, state: DhtDiscoveryMockState::new(), @@ -92,7 +89,7 @@ impl DhtDiscoveryMock { } pub async fn run(mut self) { - while let Some(req) = self.receiver.next().await { + while let Some(req) = self.receiver.recv().await { self.handle_request(req).await; } } diff --git a/comms/dht/src/test_utils/store_and_forward_mock.rs b/comms/dht/src/test_utils/store_and_forward_mock.rs index 
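The mock-state change above replaces an `AtomicBool` ("was this signature seen?") with an `AtomicUsize` ("how many times was this message hash hit?"), renaming the setter to `set_number_of_message_hits`. A self-contained sketch of that shared-counter pattern, assuming simplified names (`MockState`, `hit_count`) rather than the real `DhtMockState` API:

```rust
use std::sync::{
    atomic::{AtomicUsize, Ordering},
    Arc,
};

// Shared mock state: a cloneable handle around an atomic hit counter, so the
// test and the mock task can both read/write it without locks.
#[derive(Default, Clone)]
struct MockState {
    message_hits: Arc<AtomicUsize>,
}

impl MockState {
    // Mirrors the renamed setter in the diff: store a hit *count*, not a flag.
    fn set_number_of_message_hits(&self, v: u32) -> &Self {
        self.message_hits.store(v as usize, Ordering::SeqCst);
        self
    }

    fn hit_count(&self) -> u32 {
        self.message_hits.load(Ordering::SeqCst) as u32
    }
}

fn main() {
    let state = MockState::default();
    let handle = state.clone(); // same underlying counter
    state.set_number_of_message_hits(3);
    assert_eq!(handle.hit_count(), 3);
    println!("hit_count = {}", handle.hit_count());
}
```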
0dd464c43a..72c0861a6d 100644 --- a/comms/dht/src/test_utils/store_and_forward_mock.rs +++ b/comms/dht/src/test_utils/store_and_forward_mock.rs @@ -23,7 +23,6 @@ use crate::store_forward::{StoreAndForwardRequest, StoreAndForwardRequester, StoredMessage}; use chrono::Utc; use digest::Digest; -use futures::{channel::mpsc, stream::Fuse, StreamExt}; use log::*; use rand::{rngs::OsRng, RngCore}; use std::sync::{ @@ -32,14 +31,17 @@ use std::sync::{ }; use tari_comms::types::Challenge; use tari_utilities::hex; -use tokio::{runtime, sync::RwLock}; +use tokio::{ + runtime, + sync::{mpsc, RwLock}, +}; const LOG_TARGET: &str = "comms::dht::discovery_mock"; pub fn create_store_and_forward_mock() -> (StoreAndForwardRequester, StoreAndForwardMockState) { let (tx, rx) = mpsc::channel(10); - let mock = StoreAndForwardMock::new(rx.fuse()); + let mock = StoreAndForwardMock::new(rx); let state = mock.get_shared_state(); runtime::Handle::current().spawn(mock.run()); (StoreAndForwardRequester::new(tx), state) @@ -90,12 +92,12 @@ impl StoreAndForwardMockState { } pub struct StoreAndForwardMock { - receiver: Fuse>, + receiver: mpsc::Receiver, state: StoreAndForwardMockState, } impl StoreAndForwardMock { - pub fn new(receiver: Fuse>) -> Self { + pub fn new(receiver: mpsc::Receiver) -> Self { Self { receiver, state: StoreAndForwardMockState::new(), @@ -107,7 +109,7 @@ impl StoreAndForwardMock { } pub async fn run(mut self) { - while let Some(req) = self.receiver.next().await { + while let Some(req) = self.receiver.recv().await { self.handle_request(req).await; } } diff --git a/comms/dht/src/tower_filter/predicate.rs b/comms/dht/src/tower_filter/predicate.rs deleted file mode 100644 index f86b9cc406..0000000000 --- a/comms/dht/src/tower_filter/predicate.rs +++ /dev/null @@ -1,25 +0,0 @@ -use std::future::Future; -use tari_comms::pipeline::PipelineError; - -/// Checks a request -pub trait Predicate { - /// The future returned by `check`. 
- type Future: Future>; - - /// Check whether the given request should be forwarded. - /// - /// If the future resolves with `Ok`, the request is forwarded to the inner service. - fn check(&mut self, request: &Request) -> Self::Future; -} - -impl Predicate for F -where - F: Fn(&T) -> U, - U: Future>, -{ - type Future = U; - - fn check(&mut self, request: &T) -> Self::Future { - self(request) - } -} diff --git a/comms/dht/tests/dht.rs b/comms/dht/tests/dht.rs index a5aed09970..761bb1badd 100644 --- a/comms/dht/tests/dht.rs +++ b/comms/dht/tests/dht.rs @@ -20,7 +20,6 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::{channel::mpsc, StreamExt}; use rand::rngs::OsRng; use std::{sync::Arc, time::Duration}; use tari_comms::{ @@ -46,6 +45,7 @@ use tari_comms_dht::{ DbConnectionUrl, Dht, DhtBuilder, + DhtConfig, }; use tari_shutdown::{Shutdown, ShutdownSignal}; use tari_storage::{ @@ -54,16 +54,20 @@ use tari_storage::{ }; use tari_test_utils::{ async_assert_eventually, - collect_stream, + collect_try_recv, paths::create_temporary_data_path, random, streams, unpack_enum, }; -use tokio::{sync::broadcast, time}; +use tokio::{ + sync::{broadcast, mpsc}, + time, +}; use tower::ServiceBuilder; struct TestNode { + name: String, comms: CommsNode, dht: Dht, inbound_messages: mpsc::Receiver, @@ -80,12 +84,16 @@ impl TestNode { self.comms.node_identity().to_peer() } + pub fn name(&self) -> &str { + &self.name + } + pub async fn next_inbound_message(&mut self, timeout: Duration) -> Option { - time::timeout(timeout, self.inbound_messages.next()).await.ok()? + time::timeout(timeout, self.inbound_messages.recv()).await.ok()? 
} pub async fn shutdown(mut self) { - self.shutdown.trigger().unwrap(); + self.shutdown.trigger(); self.comms.wait_until_shutdown().await; } } @@ -113,24 +121,36 @@ fn create_peer_storage() -> CommsDatabase { LMDBWrapper::new(Arc::new(peer_database)) } -async fn make_node(features: PeerFeatures, seed_peer: Option) -> TestNode { +async fn make_node>( + name: &str, + features: PeerFeatures, + dht_config: DhtConfig, + known_peers: I, +) -> TestNode { let node_identity = make_node_identity(features); - make_node_with_node_identity(node_identity, seed_peer).await + make_node_with_node_identity(name, node_identity, dht_config, known_peers).await } -async fn make_node_with_node_identity(node_identity: Arc, seed_peer: Option) -> TestNode { +async fn make_node_with_node_identity>( + name: &str, + node_identity: Arc, + dht_config: DhtConfig, + known_peers: I, +) -> TestNode { let (tx, inbound_messages) = mpsc::channel(10); let shutdown = Shutdown::new(); let (comms, dht, messaging_events) = setup_comms_dht( node_identity, create_peer_storage(), tx, - seed_peer.into_iter().collect(), + known_peers.into_iter().collect(), + dht_config, shutdown.to_signal(), ) .await; TestNode { + name: name.to_string(), comms, dht, inbound_messages, @@ -145,6 +165,7 @@ async fn setup_comms_dht( storage: CommsDatabase, inbound_tx: mpsc::Sender, peers: Vec, + dht_config: DhtConfig, shutdown_signal: ShutdownSignal, ) -> (CommsNode, Dht, MessagingEventSender) { // Create inbound and outbound channels @@ -168,11 +189,8 @@ async fn setup_comms_dht( comms.connectivity(), comms.shutdown_signal(), ) - .local_test() - .set_auto_store_and_forward_requests(false) + .with_config(dht_config) .with_database_url(DbConnectionUrl::MemoryShared(random::string(8))) - .with_discovery_timeout(Duration::from_secs(60)) - .with_num_neighbouring_nodes(8) .build() .await .unwrap(); @@ -205,17 +223,38 @@ async fn setup_comms_dht( (comms, dht, event_tx) } -#[tokio_macros::test] +fn dht_config() -> DhtConfig { + let mut 
config = DhtConfig::default_local_test(); + config.allow_test_addresses = true; + config.saf_auto_request = false; + config.discovery_request_timeout = Duration::from_secs(60); + config.num_neighbouring_nodes = 8; + config +} + +#[tokio::test] #[allow(non_snake_case)] async fn dht_join_propagation() { // Create 3 nodes where only Node B knows A and C, but A and C want to talk to each other // Node C knows no one - let node_C = make_node(PeerFeatures::COMMUNICATION_NODE, None).await; + let node_C = make_node("node_C", PeerFeatures::COMMUNICATION_NODE, dht_config(), None).await; // Node B knows about Node C - let node_B = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_C.to_peer())).await; + let node_B = make_node( + "node_B", + PeerFeatures::COMMUNICATION_NODE, + dht_config(), + Some(node_C.to_peer()), + ) + .await; // Node A knows about Node B - let node_A = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_B.to_peer())).await; + let node_A = make_node( + "node_A", + PeerFeatures::COMMUNICATION_NODE, + dht_config(), + Some(node_B.to_peer()), + ) + .await; node_A .comms @@ -262,19 +301,37 @@ async fn dht_join_propagation() { node_C.shutdown().await; } -#[tokio_macros::test] +#[tokio::test] #[allow(non_snake_case)] async fn dht_discover_propagation() { // Create 4 nodes where A knows B, B knows A and C, C knows B and D, and D knows C // Node D knows no one - let node_D = make_node(PeerFeatures::COMMUNICATION_CLIENT, None).await; + let node_D = make_node("node_D", PeerFeatures::COMMUNICATION_CLIENT, dht_config(), None).await; // Node C knows about Node D - let node_C = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_D.to_peer())).await; + let node_C = make_node( + "node_C", + PeerFeatures::COMMUNICATION_NODE, + dht_config(), + Some(node_D.to_peer()), + ) + .await; // Node B knows about Node C - let node_B = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_C.to_peer())).await; + let node_B = make_node( + "node_B", + 
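The test refactor above drops the chained builder calls (`.local_test().set_auto_store_and_forward_requests(false)…`) in favour of constructing a `DhtConfig` value up front and passing it into `make_node`. A self-contained sketch, with field names taken from the diff but the struct itself a stand-in for the real `DhtConfig`:

```rust
use std::time::Duration;

// Stand-in config struct: tests mutate fields on a default value instead of
// threading builder methods through every setup call.
#[derive(Clone, Debug)]
struct DhtConfig {
    allow_test_addresses: bool,
    saf_auto_request: bool,
    discovery_request_timeout: Duration,
    num_neighbouring_nodes: usize,
}

impl DhtConfig {
    // Hypothetical local-test defaults; the real values live in tari_comms_dht.
    fn default_local_test() -> Self {
        Self {
            allow_test_addresses: false,
            saf_auto_request: true,
            discovery_request_timeout: Duration::from_secs(30),
            num_neighbouring_nodes: 2,
        }
    }
}

// Mirrors the dht_config() helper in the diff.
fn dht_config() -> DhtConfig {
    let mut config = DhtConfig::default_local_test();
    config.allow_test_addresses = true;
    config.saf_auto_request = false;
    config.discovery_request_timeout = Duration::from_secs(60);
    config.num_neighbouring_nodes = 8;
    config
}

fn main() {
    let c = dht_config();
    assert!(c.allow_test_addresses && !c.saf_auto_request);
    assert_eq!(c.num_neighbouring_nodes, 8);
    println!("num_neighbouring_nodes = {}", c.num_neighbouring_nodes);
}
```

An advantage of this shape over the builder is that per-test overrides (such as `dedup_allowed_message_occurrences` later in this patch) are one field assignment on a cloned config.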
PeerFeatures::COMMUNICATION_NODE, + dht_config(), + Some(node_C.to_peer()), + ) + .await; // Node A knows about Node B - let node_A = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_B.to_peer())).await; + let node_A = make_node( + "node_A", + PeerFeatures::COMMUNICATION_NODE, + dht_config(), + Some(node_B.to_peer()), + ) + .await; log::info!( "NodeA = {}, NodeB = {}, Node C = {}, Node D = {}", node_A.node_identity().node_id().short_str(), @@ -318,14 +375,20 @@ async fn dht_discover_propagation() { assert!(node_D_peer_manager.exists(node_A.node_identity().public_key()).await); } -#[tokio_macros::test] +#[tokio::test] #[allow(non_snake_case)] async fn dht_store_forward() { let node_C_node_identity = make_node_identity(PeerFeatures::COMMUNICATION_NODE); // Node B knows about Node C - let node_B = make_node(PeerFeatures::COMMUNICATION_NODE, None).await; + let node_B = make_node("node_B", PeerFeatures::COMMUNICATION_NODE, dht_config(), None).await; // Node A knows about Node B - let node_A = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_B.to_peer())).await; + let node_A = make_node( + "node_A", + PeerFeatures::COMMUNICATION_NODE, + dht_config(), + Some(node_B.to_peer()), + ) + .await; log::info!( "NodeA = {}, NodeB = {}, Node C = {}", node_A.node_identity().node_id().short_str(), @@ -370,9 +433,10 @@ async fn dht_store_forward() { .unwrap(); // Wait for node B to receive 2 propagation messages - collect_stream!(node_B_msg_events, take = 2, timeout = Duration::from_secs(20)); + collect_try_recv!(node_B_msg_events, take = 2, timeout = Duration::from_secs(20)); - let mut node_C = make_node_with_node_identity(node_C_node_identity, Some(node_B.to_peer())).await; + let mut node_C = + make_node_with_node_identity("node_C", node_C_node_identity, dht_config(), Some(node_B.to_peer())).await; let mut node_C_dht_events = node_C.dht.subscribe_dht_events(); let mut node_C_msg_events = node_C.messaging_events.subscribe(); // Ask node B for messages @@ -389,8 +453,8 @@ 
async fn dht_store_forward() { .await .unwrap(); // Wait for node C to and receive a response from the SAF request - let event = collect_stream!(node_C_msg_events, take = 1, timeout = Duration::from_secs(20)); - unpack_enum!(MessagingEvent::MessageReceived(_node_id, _msg) = &**event.get(0).unwrap().as_ref().unwrap()); + let event = collect_try_recv!(node_C_msg_events, take = 1, timeout = Duration::from_secs(20)); + unpack_enum!(MessagingEvent::MessageReceived(_node_id, _msg) = &*event.get(0).unwrap().as_ref()); let msg = node_C.next_inbound_message(Duration::from_secs(5)).await.unwrap(); assert_eq!( @@ -418,25 +482,47 @@ async fn dht_store_forward() { assert!(msgs.is_empty()); // Check that Node C emitted the StoreAndForwardMessagesReceived event when it went Online - let event = collect_stream!(node_C_dht_events, take = 1, timeout = Duration::from_secs(20)); - unpack_enum!(DhtEvent::StoreAndForwardMessagesReceived = &**event.get(0).unwrap().as_ref().unwrap()); + let event = collect_try_recv!(node_C_dht_events, take = 1, timeout = Duration::from_secs(20)); + unpack_enum!(DhtEvent::StoreAndForwardMessagesReceived = &*event.get(0).unwrap().as_ref()); node_A.shutdown().await; node_B.shutdown().await; node_C.shutdown().await; } -#[tokio_macros::test] +#[tokio::test] #[allow(non_snake_case)] async fn dht_propagate_dedup() { + let mut config = dht_config(); + // For this test we want to exactly measure the path of a message, so we disable repropagation of messages (i.e + // allow 1 occurrence) + config.dedup_allowed_message_occurrences = 1; // Node D knows no one - let mut node_D = make_node(PeerFeatures::COMMUNICATION_NODE, None).await; + let mut node_D = make_node("node_D", PeerFeatures::COMMUNICATION_NODE, config.clone(), None).await; // Node C knows about Node D - let mut node_C = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_D.to_peer())).await; + let mut node_C = make_node( + "node_C", + PeerFeatures::COMMUNICATION_NODE, + config.clone(), + 
Some(node_D.to_peer()), + ) + .await; // Node B knows about Node C - let mut node_B = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_C.to_peer())).await; + let mut node_B = make_node( + "node_B", + PeerFeatures::COMMUNICATION_NODE, + config.clone(), + Some(node_C.to_peer()), + ) + .await; // Node A knows about Node B and C - let mut node_A = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_B.to_peer())).await; + let mut node_A = make_node( + "node_A", + PeerFeatures::COMMUNICATION_NODE, + config.clone(), + Some(node_B.to_peer()), + ) + .await; node_A.comms.peer_manager().add_peer(node_C.to_peer()).await.unwrap(); log::info!( "NodeA = {}, NodeB = {}, Node C = {}, Node D = {}", @@ -482,8 +568,7 @@ async fn dht_propagate_dedup() { .dht .outbound_requester() .propagate( - // Node D is a client node, so an destination is required for domain messages - NodeDestination::Unknown, // NodeId(Box::new(node_D.node_identity().node_id().clone())), + NodeDestination::Unknown, OutboundEncryption::EncryptFor(Box::new(node_D.node_identity().public_key().clone())), vec![], out_msg, @@ -496,6 +581,7 @@ async fn dht_propagate_dedup() { .await .expect("Node D expected an inbound message but it never arrived"); assert!(msg.decryption_succeeded()); + log::info!("Received message {}", msg.tag); let person = msg .decryption_result .unwrap() @@ -515,35 +601,150 @@ async fn dht_propagate_dedup() { node_D.shutdown().await; // Check the message flow BEFORE deduping - let received = filter_received(collect_stream!(node_A_messaging, timeout = Duration::from_secs(20))); + let received = filter_received(collect_try_recv!(node_A_messaging, timeout = Duration::from_secs(20))); // Expected race condition: If A->(B|C)->(C|B) before A->(C|B) then (C|B)->A if !received.is_empty() { assert_eq!(count_messages_received(&received, &[&node_B_id, &node_C_id]), 1); } - let received = filter_received(collect_stream!(node_B_messaging, timeout = Duration::from_secs(20))); + let received = 
filter_received(collect_try_recv!(node_B_messaging, timeout = Duration::from_secs(20))); let recv_count = count_messages_received(&received, &[&node_A_id, &node_C_id]); // Expected race condition: If A->B->C before A->C then C->B does not happen assert!((1..=2).contains(&recv_count)); - let received = filter_received(collect_stream!(node_C_messaging, timeout = Duration::from_secs(20))); + let received = filter_received(collect_try_recv!(node_C_messaging, timeout = Duration::from_secs(20))); let recv_count = count_messages_received(&received, &[&node_A_id, &node_B_id]); assert_eq!(recv_count, 2); assert_eq!(count_messages_received(&received, &[&node_D_id]), 0); - let received = filter_received(collect_stream!(node_D_messaging, timeout = Duration::from_secs(20))); + let received = filter_received(collect_try_recv!(node_D_messaging, timeout = Duration::from_secs(20))); assert_eq!(received.len(), 1); assert_eq!(count_messages_received(&received, &[&node_C_id]), 1); } -#[tokio_macros::test] +#[tokio::test] +#[allow(non_snake_case)] +async fn dht_repropagate() { + let mut config = dht_config(); + config.dedup_allowed_message_occurrences = 3; + let mut node_C = make_node("node_C", PeerFeatures::COMMUNICATION_NODE, config.clone(), []).await; + let mut node_B = make_node("node_B", PeerFeatures::COMMUNICATION_NODE, config.clone(), [ + node_C.to_peer() + ]) + .await; + let mut node_A = make_node("node_A", PeerFeatures::COMMUNICATION_NODE, config, [ + node_B.to_peer(), + node_C.to_peer(), + ]) + .await; + node_A.comms.peer_manager().add_peer(node_C.to_peer()).await.unwrap(); + node_B.comms.peer_manager().add_peer(node_C.to_peer()).await.unwrap(); + node_C.comms.peer_manager().add_peer(node_A.to_peer()).await.unwrap(); + node_C.comms.peer_manager().add_peer(node_B.to_peer()).await.unwrap(); + log::info!( + "NodeA = {}, NodeB = {}, Node C = {}", + node_A.node_identity().node_id().short_str(), + node_B.node_identity().node_id().short_str(), + 
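The `dedup_allowed_message_occurrences` setting exercised by `dht_propagate_dedup` (set to 1) and `dht_repropagate` (set to 3) can be sketched as a hit-counting cache: a message is handled until it has been seen more than the allowed number of times. This is a minimal std-only stand-in, not the real DHT dedup implementation:

```rust
use std::collections::HashMap;

// Dedup cache sketch: count occurrences per message hash and discard once the
// allowed count is exceeded. With 1, any repeat is dropped; with 3, a message
// may be repropagated twice more before being treated as a duplicate.
struct DedupCache {
    allowed_occurrences: usize,
    hits: HashMap<u64, usize>,
}

impl DedupCache {
    fn new(allowed_occurrences: usize) -> Self {
        Self { allowed_occurrences, hits: HashMap::new() }
    }

    /// Records a sighting of `msg_hash`; returns true if the message should
    /// still be handled, false if it is now considered a duplicate.
    fn record(&mut self, msg_hash: u64) -> bool {
        let count = self.hits.entry(msg_hash).or_insert(0);
        *count += 1;
        *count <= self.allowed_occurrences
    }
}

fn main() {
    let mut strict = DedupCache::new(1);
    assert!(strict.record(42));  // first occurrence: handled
    assert!(!strict.record(42)); // second: duplicate

    let mut lenient = DedupCache::new(3);
    assert!(lenient.record(7) && lenient.record(7) && lenient.record(7));
    assert!(!lenient.record(7)); // fourth occurrence is discarded
    println!("dedup ok");
}
```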
node_C.node_identity().node_id().short_str(), + ); + + // Connect the peers that should be connected + async fn connect_nodes(node1: &mut TestNode, node2: &mut TestNode) { + node1 + .comms + .connectivity() + .dial_peer(node2.node_identity().node_id().clone()) + .await + .unwrap(); + } + // Pre-connect nodes, this helps message passing be more deterministic + connect_nodes(&mut node_A, &mut node_B).await; + connect_nodes(&mut node_A, &mut node_C).await; + connect_nodes(&mut node_B, &mut node_C).await; + + #[derive(Clone, PartialEq, ::prost::Message)] + struct Person { + #[prost(string, tag = "1")] + name: String, + #[prost(uint32, tag = "2")] + age: u32, + } + + let out_msg = OutboundDomainMessage::new(123, Person { + name: "Alan Turing".into(), + age: 41, + }); + node_A + .dht + .outbound_requester() + .propagate( + NodeDestination::Unknown, + OutboundEncryption::ClearText, + vec![], + out_msg.clone(), + ) + .await + .unwrap(); + + async fn receive_and_repropagate(node: &mut TestNode, out_msg: &OutboundDomainMessage) { + let msg = node + .next_inbound_message(Duration::from_secs(10)) + .await + .unwrap_or_else(|| panic!("{} expected an inbound message but it never arrived", node.name())); + log::info!("Received message {}", msg.tag); + + node.dht + .outbound_requester() + .send_message( + SendMessageParams::new() + .propagate(NodeDestination::Unknown, vec![]) + .with_destination(NodeDestination::Unknown) + .with_tag(msg.tag) + .finish(), + out_msg.clone(), + ) + .await + .unwrap() + .resolve() + .await + .unwrap(); + } + + // This relies on the DHT being set with .with_dedup_discard_hit_count(3) + receive_and_repropagate(&mut node_B, &out_msg).await; + receive_and_repropagate(&mut node_C, &out_msg).await; + receive_and_repropagate(&mut node_A, &out_msg).await; + receive_and_repropagate(&mut node_B, &out_msg).await; + receive_and_repropagate(&mut node_C, &out_msg).await; + receive_and_repropagate(&mut node_A, &out_msg).await; + receive_and_repropagate(&mut node_B, 
&out_msg).await; + receive_and_repropagate(&mut node_C, &out_msg).await; + + node_A.shutdown().await; + node_B.shutdown().await; + node_C.shutdown().await; +} + +#[tokio::test] #[allow(non_snake_case)] async fn dht_propagate_message_contents_not_malleable_ban() { - let node_C = make_node(PeerFeatures::COMMUNICATION_NODE, None).await; + let node_C = make_node("node_C", PeerFeatures::COMMUNICATION_NODE, dht_config(), None).await; // Node B knows about Node C - let mut node_B = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_C.to_peer())).await; + let mut node_B = make_node( + "node_B", + PeerFeatures::COMMUNICATION_NODE, + dht_config(), + Some(node_C.to_peer()), + ) + .await; // Node A knows about Node B - let node_A = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_B.to_peer())).await; + let node_A = make_node( + "node_A", + PeerFeatures::COMMUNICATION_NODE, + dht_config(), + Some(node_B.to_peer()), + ) + .await; node_A.comms.peer_manager().add_peer(node_C.to_peer()).await.unwrap(); log::info!( "NodeA = {}, NodeB = {}", @@ -613,10 +814,10 @@ async fn dht_propagate_message_contents_not_malleable_ban() { let node_B_node_id = node_B.node_identity().node_id().clone(); // Node C should ban node B - let banned_node_id = streams::assert_in_stream( + let banned_node_id = streams::assert_in_broadcast( &mut connectivity_events, - |r| match &*r.unwrap() { - ConnectivityEvent::PeerBanned(node_id) => Some(node_id.clone()), + |r| match r { + ConnectivityEvent::PeerBanned(node_id) => Some(node_id), _ => None, }, Duration::from_secs(10), @@ -629,12 +830,9 @@ async fn dht_propagate_message_contents_not_malleable_ban() { node_C.shutdown().await; } -fn filter_received( - events: Vec, tokio::sync::broadcast::RecvError>>, -) -> Vec> { +fn filter_received(events: Vec>) -> Vec> { events .into_iter() - .map(Result::unwrap) .filter(|e| match &**e { MessagingEvent::MessageReceived(_, _) => true, _ => unreachable!(), diff --git a/comms/examples/stress/error.rs 
b/comms/examples/stress/error.rs index 5cb9be1cb3..e87aae514e 100644 --- a/comms/examples/stress/error.rs +++ b/comms/examples/stress/error.rs @@ -19,10 +19,10 @@ // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -use futures::channel::{mpsc::SendError, oneshot}; use std::io; use tari_comms::{ connectivity::ConnectivityError, + message::OutboundMessage, peer_manager::PeerManagerError, tor, CommsBuilderError, @@ -30,7 +30,11 @@ use tari_comms::{ }; use tari_crypto::tari_utilities::message_format::MessageFormatError; use thiserror::Error; -use tokio::{task, time}; +use tokio::{ + sync::{mpsc::error::SendError, oneshot}, + task, + time, +}; #[derive(Debug, Error)] pub enum Error { @@ -48,12 +52,12 @@ pub enum Error { ConnectivityError(#[from] ConnectivityError), #[error("Message format error: {0}")] MessageFormatError(#[from] MessageFormatError), - #[error("Failed to send message")] - SendError(#[from] SendError), + #[error("Failed to send message: {0}")] + SendError(#[from] SendError), #[error("JoinError: {0}")] JoinError(#[from] task::JoinError), #[error("Example did not exit cleanly: `{0}`")] - WaitTimeout(#[from] time::Elapsed), + WaitTimeout(#[from] time::error::Elapsed), #[error("IO error: {0}")] Io(#[from] io::Error), #[error("User quit")] @@ -63,5 +67,5 @@ pub enum Error { #[error("Unexpected EoF")] UnexpectedEof, #[error("Internal reply canceled")] - ReplyCanceled(#[from] oneshot::Canceled), + ReplyCanceled(#[from] oneshot::error::RecvError), } diff --git a/comms/examples/stress/node.rs b/comms/examples/stress/node.rs index 45ad0919ab..d060d18071 100644 --- a/comms/examples/stress/node.rs +++ b/comms/examples/stress/node.rs @@ -21,7 +21,6 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
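The `error.rs` hunk above swaps the error sources for their tokio 1.0 equivalents: `SendError` now carries the unsent message (hence the `{0}` in the display string), `time::Elapsed` moved to `time::error::Elapsed`, and `oneshot::Canceled` became `oneshot::error::RecvError`. The `#[from]` conversions it relies on can be sketched with std's mpsc, whose `SendError` is similarly parameterised over the message type (the `Error` enum here is a stand-in, written without `thiserror`):

```rust
use std::sync::mpsc;

// Stand-in error type: wraps a channel send failure via From, mirroring
// `SendError(#[from] SendError<OutboundMessage>)` in the stress example.
#[derive(Debug)]
enum Error {
    Send(mpsc::SendError<String>),
}

impl From<mpsc::SendError<String>> for Error {
    fn from(e: mpsc::SendError<String>) -> Self {
        Error::Send(e)
    }
}

fn try_send(tx: &mpsc::Sender<String>, msg: String) -> Result<(), Error> {
    tx.send(msg)?; // SendError<String> converted to Error via From
    Ok(())
}

fn main() {
    let (tx, rx) = mpsc::channel();
    assert!(try_send(&tx, "hello".into()).is_ok());
    drop(rx); // receiver gone: the next send fails and is converted via From
    assert!(try_send(&tx, "world".into()).is_err());
    println!("error conversion ok");
}
```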
use super::{error::Error, STRESS_PROTOCOL_NAME, TOR_CONTROL_PORT_ADDR, TOR_SOCKS_ADDR}; -use futures::channel::mpsc; use rand::rngs::OsRng; use std::{convert, net::Ipv4Addr, path::Path, sync::Arc, time::Duration}; use tari_comms::{ @@ -43,7 +42,7 @@ use tari_storage::{ lmdb_store::{LMDBBuilder, LMDBConfig}, LMDBWrapper, }; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc}; pub async fn create( node_identity: Option>, diff --git a/comms/examples/stress/service.rs b/comms/examples/stress/service.rs index 3e262cc38b..45e2bc0fd3 100644 --- a/comms/examples/stress/service.rs +++ b/comms/examples/stress/service.rs @@ -23,15 +23,7 @@ use super::error::Error; use crate::stress::{MAX_FRAME_SIZE, STRESS_PROTOCOL_NAME}; use bytes::{Buf, Bytes, BytesMut}; -use futures::{ - channel::{mpsc, oneshot}, - stream, - stream::Fuse, - AsyncReadExt, - AsyncWriteExt, - SinkExt, - StreamExt, -}; +use futures::{stream, SinkExt, StreamExt}; use rand::{rngs::OsRng, RngCore}; use std::{ iter::repeat_with, @@ -43,12 +35,19 @@ use tari_comms::{ message::{InboundMessage, OutboundMessage}, peer_manager::{NodeId, Peer}, protocol::{ProtocolEvent, ProtocolNotification}, + utils, CommsNode, PeerConnection, Substream, }; use tari_crypto::tari_utilities::hex::Hex; -use tokio::{sync::RwLock, task, task::JoinHandle, time}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + sync::{mpsc, oneshot, RwLock}, + task, + task::JoinHandle, + time, +}; pub fn start_service( comms_node: CommsNode, @@ -70,6 +69,7 @@ pub fn start_service( (task::spawn(service.start()), request_tx) } +#[derive(Debug)] pub enum StressTestServiceRequest { BeginProtocol(Peer, StressProtocol, oneshot::Sender>), Shutdown, @@ -135,9 +135,9 @@ impl StressProtocol { } struct StressTestService { - request_rx: Fuse>, + request_rx: mpsc::Receiver, comms_node: CommsNode, - protocol_notif: Fuse>>, + protocol_notif: mpsc::Receiver>, shutdown: bool, inbound_rx: Arc>>, @@ -153,9 +153,9 @@ impl StressTestService { outbound_tx: 
mpsc::Sender, ) -> Self { Self { - request_rx: request_rx.fuse(), + request_rx, comms_node, - protocol_notif: protocol_notif.fuse(), + protocol_notif, shutdown: false, inbound_rx: Arc::new(RwLock::new(inbound_rx)), outbound_tx, @@ -163,23 +163,21 @@ impl StressTestService { } async fn start(mut self) -> Result<(), Error> { - let mut events = self.comms_node.subscribe_connectivity_events().fuse(); + let mut events = self.comms_node.subscribe_connectivity_events(); loop { - futures::select! { - event = events.select_next_some() => { - if let Ok(event) = event { - println!("{}", event); - } + tokio::select! { + Ok(event) = events.recv() => { + println!("{}", event); }, - request = self.request_rx.select_next_some() => { + Some(request) = self.request_rx.recv() => { if let Err(err) = self.handle_request(request).await { println!("Error: {}", err); } }, - notif = self.protocol_notif.select_next_some() => { + Some(notif) = self.protocol_notif.recv() => { self.handle_protocol_notification(notif).await; }, } @@ -431,7 +429,7 @@ async fn messaging_flood( peer: NodeId, protocol: StressProtocol, inbound_rx: Arc>>, - mut outbound_tx: mpsc::Sender, + outbound_tx: mpsc::Sender, ) -> Result<(), Error> { let start = Instant::now(); let mut counter = 1u32; @@ -441,18 +439,15 @@ async fn messaging_flood( protocol.num_messages * protocol.message_size / 1024 / 1024 ); let outbound_task = task::spawn(async move { - let mut iter = stream::iter( - repeat_with(|| { - counter += 1; - - println!("Send MSG {}", counter); - OutboundMessage::new(peer.clone(), generate_message(counter, protocol.message_size as usize)) - }) - .take(protocol.num_messages as usize) - .map(Ok), - ); - outbound_tx.send_all(&mut iter).await?; - time::delay_for(Duration::from_secs(5)).await; + let iter = repeat_with(|| { + counter += 1; + + println!("Send MSG {}", counter); + OutboundMessage::new(peer.clone(), generate_message(counter, protocol.message_size as usize)) + }) + .take(protocol.num_messages as usize); + 
utils::mpsc::send_all(&outbound_tx, iter).await?; + time::sleep(Duration::from_secs(5)).await; outbound_tx .send(OutboundMessage::new(peer.clone(), Bytes::from_static(&[0u8; 4]))) .await?; @@ -462,7 +457,7 @@ async fn messaging_flood( let inbound_task = task::spawn(async move { let mut inbound_rx = inbound_rx.write().await; let mut msgs = vec![]; - while let Some(msg) = inbound_rx.next().await { + while let Some(msg) = inbound_rx.recv().await { let msg_id = decode_msg(msg.body); println!("GOT MSG {}", msg_id); if msgs.len() == protocol.num_messages as usize { @@ -497,6 +492,6 @@ fn generate_message(n: u32, size: usize) -> Bytes { fn decode_msg(msg: T) -> u32 { let mut buf = [0u8; 4]; - msg.bytes().copy_to_slice(&mut buf); + msg.chunk().copy_to_slice(&mut buf); u32::from_be_bytes(buf) } diff --git a/comms/examples/stress_test.rs b/comms/examples/stress_test.rs index 71c185aa3e..3a0c04f020 100644 --- a/comms/examples/stress_test.rs +++ b/comms/examples/stress_test.rs @@ -24,13 +24,13 @@ mod stress; use stress::{error::Error, prompt::user_prompt}; use crate::stress::{node, prompt::parse_from_short_str, service, service::StressTestServiceRequest}; -use futures::{channel::oneshot, future, future::Either, SinkExt}; +use futures::{future, future::Either}; use std::{env, net::Ipv4Addr, path::Path, process, sync::Arc, time::Duration}; use tari_crypto::tari_utilities::message_format::MessageFormat; use tempfile::Builder; -use tokio::time; +use tokio::{sync::oneshot, time}; -#[tokio_macros::main] +#[tokio::main] async fn main() { env_logger::init(); match run().await { @@ -99,7 +99,7 @@ async fn run() -> Result<(), Error> { } println!("Stress test service started!"); - let (handle, mut requester) = service::start_service(comms_node, protocol_notif, inbound_rx, outbound_tx); + let (handle, requester) = service::start_service(comms_node, protocol_notif, inbound_rx, outbound_tx); let mut last_peer = peer.as_ref().and_then(parse_from_short_str); diff --git a/comms/examples/tor.rs 
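The `decode_msg` change above (`msg.bytes()` → `msg.chunk()`) tracks the bytes 0.5 → 1.0 `Buf` API rename; the framing itself is unchanged: a `u32` message id encoded as 4 big-endian bytes at the start of the payload. A dependency-free sketch of that round trip (`generate_header` is a stand-in for the diff's `generate_message`, keeping only the id prefix):

```rust
// Encode a message id as a 4-byte big-endian prefix.
fn generate_header(n: u32) -> Vec<u8> {
    n.to_be_bytes().to_vec()
}

// Decode the id back out of the first 4 bytes of a payload.
fn decode_msg(msg: &[u8]) -> u32 {
    let mut buf = [0u8; 4];
    buf.copy_from_slice(&msg[..4]);
    u32::from_be_bytes(buf)
}

fn main() {
    let payload = generate_header(7);
    assert_eq!(decode_msg(&payload), 7);
    println!("roundtrip ok");
}
```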
b/comms/examples/tor.rs index 734ef7718a..9186c69d43 100644 --- a/comms/examples/tor.rs +++ b/comms/examples/tor.rs @@ -1,7 +1,6 @@ use anyhow::anyhow; use bytes::Bytes; use chrono::Utc; -use futures::{channel::mpsc, SinkExt, StreamExt}; use rand::{rngs::OsRng, thread_rng, RngCore}; use std::{collections::HashMap, convert::identity, env, net::SocketAddr, path::Path, process, sync::Arc}; use tari_comms::{ @@ -21,7 +20,10 @@ use tari_storage::{ LMDBWrapper, }; use tempfile::Builder; -use tokio::{runtime, sync::broadcast}; +use tokio::{ + runtime, + sync::{broadcast, mpsc}, +}; // Tor example for tari_comms. // @@ -29,7 +31,7 @@ use tokio::{runtime, sync::broadcast}; type Error = anyhow::Error; -#[tokio_macros::main] +#[tokio::main] async fn main() { env_logger::init(); if let Err(err) = run().await { @@ -56,7 +58,7 @@ async fn run() -> Result<(), Error> { println!("Starting comms nodes...",); let temp_dir1 = Builder::new().prefix("tor-example1").tempdir().unwrap(); - let (comms_node1, inbound_rx1, mut outbound_tx1) = setup_node_with_tor( + let (comms_node1, inbound_rx1, outbound_tx1) = setup_node_with_tor( control_port_addr.clone(), temp_dir1.as_ref(), (9098u16, "127.0.0.1:0".parse::().unwrap()), @@ -208,11 +210,11 @@ async fn setup_node_with_tor>( async fn start_ping_ponger( dest_node_id: NodeId, mut inbound_rx: mpsc::Receiver, - mut outbound_tx: mpsc::Sender, + outbound_tx: mpsc::Sender, ) -> Result { let mut inflight_pings = HashMap::new(); let mut counter = 0; - while let Some(msg) = inbound_rx.next().await { + while let Some(msg) = inbound_rx.recv().await { counter += 1; let msg_str = String::from_utf8_lossy(&msg.body); diff --git a/comms/rpc_macros/Cargo.toml b/comms/rpc_macros/Cargo.toml index 3680ed81f9..5bc185bc90 100644 --- a/comms/rpc_macros/Cargo.toml +++ b/comms/rpc_macros/Cargo.toml @@ -13,16 +13,16 @@ edition = "2018" proc-macro = true [dependencies] +tari_comms = { version = "^0.9", path = "../", features = ["rpc"]} + proc-macro2 = "1.0.24" quote = 
"1.0.7" syn = {version = "1.0.38", features = ["fold"]} -tari_comms = { version = "^0.9", path = "../", features = ["rpc"]} [dev-dependencies] tari_test_utils = {version="^0.9", path="../../infrastructure/test_utils"} futures = "0.3.5" -prost = "0.6.1" -tokio = "0.2.22" -tokio-macros = "0.2.5" +prost = "0.8.0" +tokio = {version = "1", features = ["macros"]} tower-service = "0.3.0" diff --git a/comms/rpc_macros/src/generator.rs b/comms/rpc_macros/src/generator.rs index f3e8cbffd1..5f44066f19 100644 --- a/comms/rpc_macros/src/generator.rs +++ b/comms/rpc_macros/src/generator.rs @@ -215,8 +215,8 @@ impl RpcCodeGenerator { self.inner.ping().await } - pub fn close(&mut self) { - self.inner.close(); + pub async fn close(&mut self) { + self.inner.close().await; } }; diff --git a/comms/rpc_macros/tests/macro.rs b/comms/rpc_macros/tests/macro.rs index f3f05b4481..b41f3f9914 100644 --- a/comms/rpc_macros/tests/macro.rs +++ b/comms/rpc_macros/tests/macro.rs @@ -20,7 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-use futures::{channel::mpsc, SinkExt, StreamExt}; +use futures::StreamExt; use prost::Message; use std::{collections::HashMap, ops::AddAssign, sync::Arc}; use tari_comms::{ @@ -34,7 +34,10 @@ use tari_comms::{ }; use tari_comms_rpc_macros::tari_rpc; use tari_test_utils::unpack_enum; -use tokio::{sync::RwLock, task}; +use tokio::{ + sync::{mpsc, RwLock}, + task, +}; use tower_service::Service; #[tari_rpc(protocol_name = b"/test/protocol/123", server_struct = TestServer, client_struct = TestClient)] @@ -80,7 +83,7 @@ impl Test for TestService { async fn server_streaming(&self, _: Request) -> Result, RpcStatus> { self.add_call("server_streaming").await; - let (mut tx, rx) = mpsc::channel(1); + let (tx, rx) = mpsc::channel(1); tx.send(Ok(1)).await.unwrap(); Ok(Streaming::new(rx)) } @@ -101,7 +104,7 @@ fn it_sets_the_protocol_name() { assert_eq!(TestClient::PROTOCOL_NAME, b"/test/protocol/123"); } -#[tokio_macros::test] +#[tokio::test] async fn it_returns_the_correct_type() { let mut server = TestServer::new(TestService::default()); let resp = server @@ -112,7 +115,7 @@ async fn it_returns_the_correct_type() { assert_eq!(u32::decode(v).unwrap(), 12); } -#[tokio_macros::test] +#[tokio::test] async fn it_correctly_maps_the_method_nums() { let service = TestService::default(); let spy = service.state.clone(); @@ -135,7 +138,7 @@ async fn it_correctly_maps_the_method_nums() { assert_eq!(*spy.read().await.get("unit").unwrap(), 1); } -#[tokio_macros::test] +#[tokio::test] async fn it_returns_an_error_for_invalid_method_nums() { let service = TestService::default(); let mut server = TestServer::new(service); @@ -147,7 +150,7 @@ async fn it_returns_an_error_for_invalid_method_nums() { unpack_enum!(RpcStatusCode::UnsupportedMethod = err.status_code()); } -#[tokio_macros::test] +#[tokio::test] async fn it_generates_client_calls() { let (sock_client, sock_server) = MemorySocket::new_pair(); let client = task::spawn(TestClient::connect(framing::canonical(sock_client, 1024))); diff 
--git a/comms/src/bounded_executor.rs b/comms/src/bounded_executor.rs index ee65e68476..3e36edb50b 100644 --- a/comms/src/bounded_executor.rs +++ b/comms/src/bounded_executor.rs @@ -145,7 +145,15 @@ impl BoundedExecutor { F::Output: Send + 'static, { let span = span!(Level::TRACE, "bounded_executor::waiting_time"); - let permit = self.semaphore.clone().acquire_owned().instrument(span).await; + // SAFETY: acquire_owned only fails if the semaphore is closed (i.e self.semaphore.close() is called) - this + // never happens in this implementation + let permit = self + .semaphore + .clone() + .acquire_owned() + .instrument(span) + .await + .expect("semaphore closed"); self.do_spawn(permit, future) } @@ -230,9 +238,9 @@ mod test { }, time::Duration, }; - use tokio::time::delay_for; + use tokio::time::sleep; - #[runtime::test_basic] + #[runtime::test] async fn spawn() { let flag = Arc::new(AtomicBool::new(false)); let flag_cloned = flag.clone(); @@ -241,7 +249,7 @@ mod test { // Spawn 1 let task1_fut = executor .spawn(async move { - delay_for(Duration::from_millis(1)).await; + sleep(Duration::from_millis(1)).await; flag_cloned.store(true, Ordering::SeqCst); }) .await; diff --git a/comms/src/builder/comms_node.rs b/comms/src/builder/comms_node.rs index 24c856f5ba..abd71e8952 100644 --- a/comms/src/builder/comms_node.rs +++ b/comms/src/builder/comms_node.rs @@ -46,11 +46,13 @@ use crate::{ CommsBuilder, Substream, }; -use futures::{channel::mpsc, AsyncRead, AsyncWrite}; use log::*; use std::{iter, sync::Arc}; use tari_shutdown::ShutdownSignal; -use tokio::sync::broadcast; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{broadcast, mpsc}, +}; const LOG_TARGET: &str = "comms::node"; diff --git a/comms/src/builder/mod.rs b/comms/src/builder/mod.rs index f5acc151a9..5da0d51793 100644 --- a/comms/src/builder/mod.rs +++ b/comms/src/builder/mod.rs @@ -51,10 +51,9 @@ use crate::{ tor, types::CommsDatabase, }; -use futures::channel::mpsc; use std::{fs::File, sync::Arc}; use 
tari_shutdown::ShutdownSignal; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc}; /// The `CommsBuilder` provides a simple builder API for getting Tari comms p2p messaging up and running. pub struct CommsBuilder { diff --git a/comms/src/builder/tests.rs b/comms/src/builder/tests.rs index d1ae9a0f9a..d4fe8cca97 100644 --- a/comms/src/builder/tests.rs +++ b/comms/src/builder/tests.rs @@ -42,19 +42,16 @@ use crate::{ CommsNode, }; use bytes::Bytes; -use futures::{ - channel::{mpsc, oneshot}, - stream::FuturesUnordered, - AsyncReadExt, - AsyncWriteExt, - SinkExt, - StreamExt, -}; +use futures::stream::FuturesUnordered; use std::{collections::HashSet, convert::identity, hash::Hash, time::Duration}; use tari_shutdown::{Shutdown, ShutdownSignal}; use tari_storage::HashmapDatabase; -use tari_test_utils::{collect_stream, unpack_enum}; -use tokio::{sync::broadcast, task}; +use tari_test_utils::{collect_recv, collect_stream, unpack_enum}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + sync::{broadcast, mpsc, oneshot}, + task, +}; async fn spawn_node( protocols: Protocols, @@ -109,7 +106,7 @@ async fn spawn_node( (comms_node, inbound_rx, outbound_tx, messaging_events_sender) } -#[runtime::test_basic] +#[runtime::test] async fn peer_to_peer_custom_protocols() { static TEST_PROTOCOL: Bytes = Bytes::from_static(b"/tari/test"); static ANOTHER_TEST_PROTOCOL: Bytes = Bytes::from_static(b"/tari/test-again"); @@ -161,9 +158,9 @@ async fn peer_to_peer_custom_protocols() { // Check that both nodes get the PeerConnected event. We subscribe after the nodes are initialized // so we miss those events. 
- let next_event = conn_man_events2.next().await.unwrap().unwrap(); + let next_event = conn_man_events2.recv().await.unwrap(); unpack_enum!(ConnectionManagerEvent::PeerConnected(conn2) = &*next_event); - let next_event = conn_man_events1.next().await.unwrap().unwrap(); + let next_event = conn_man_events1.recv().await.unwrap(); unpack_enum!(ConnectionManagerEvent::PeerConnected(_conn) = &*next_event); // Let's speak both our test protocols @@ -176,7 +173,7 @@ async fn peer_to_peer_custom_protocols() { negotiated_substream2.stream.write_all(ANOTHER_TEST_MSG).await.unwrap(); // Read TEST_PROTOCOL message to node 2 from node 1 - let negotiation = test_protocol_rx2.next().await.unwrap(); + let negotiation = test_protocol_rx2.recv().await.unwrap(); assert_eq!(negotiation.protocol, TEST_PROTOCOL); unpack_enum!(ProtocolEvent::NewInboundSubstream(node_id, substream) = negotiation.event); assert_eq!(&node_id, node_identity1.node_id()); @@ -185,7 +182,7 @@ async fn peer_to_peer_custom_protocols() { assert_eq!(buf, TEST_MSG); // Read ANOTHER_TEST_PROTOCOL message to node 1 from node 2 - let negotiation = another_test_protocol_rx1.next().await.unwrap(); + let negotiation = another_test_protocol_rx1.recv().await.unwrap(); assert_eq!(negotiation.protocol, ANOTHER_TEST_PROTOCOL); unpack_enum!(ProtocolEvent::NewInboundSubstream(node_id, substream) = negotiation.event); assert_eq!(&node_id, node_identity2.node_id()); @@ -193,18 +190,18 @@ async fn peer_to_peer_custom_protocols() { substream.read_exact(&mut buf).await.unwrap(); assert_eq!(buf, ANOTHER_TEST_MSG); - shutdown.trigger().unwrap(); + shutdown.trigger(); comms_node1.wait_until_shutdown().await; comms_node2.wait_until_shutdown().await; } -#[runtime::test_basic] +#[runtime::test] async fn peer_to_peer_messaging() { const NUM_MSGS: usize = 100; let shutdown = Shutdown::new(); - let (comms_node1, mut inbound_rx1, mut outbound_tx1, _) = spawn_node(Protocols::new(), shutdown.to_signal()).await; - let (comms_node2, mut 
inbound_rx2, mut outbound_tx2, messaging_events2) = + let (comms_node1, mut inbound_rx1, outbound_tx1, _) = spawn_node(Protocols::new(), shutdown.to_signal()).await; + let (comms_node2, mut inbound_rx2, outbound_tx2, messaging_events2) = spawn_node(Protocols::new(), shutdown.to_signal()).await; let mut messaging_events2 = messaging_events2.subscribe(); @@ -238,14 +235,14 @@ async fn peer_to_peer_messaging() { outbound_tx1.send(outbound_msg).await.unwrap(); } - let messages1_to_2 = collect_stream!(inbound_rx2, take = NUM_MSGS, timeout = Duration::from_secs(10)); + let messages1_to_2 = collect_recv!(inbound_rx2, take = NUM_MSGS, timeout = Duration::from_secs(10)); let send_results = collect_stream!(replies, take = NUM_MSGS, timeout = Duration::from_secs(10)); send_results.into_iter().for_each(|r| { r.unwrap().unwrap(); }); - let events = collect_stream!(messaging_events2, take = NUM_MSGS, timeout = Duration::from_secs(10)); - events.into_iter().map(Result::unwrap).for_each(|m| { + let events = collect_recv!(messaging_events2, take = NUM_MSGS, timeout = Duration::from_secs(10)); + events.into_iter().for_each(|m| { unpack_enum!(MessagingEvent::MessageReceived(_n, _t) = &*m); }); @@ -258,7 +255,7 @@ async fn peer_to_peer_messaging() { outbound_tx2.send(outbound_msg).await.unwrap(); } - let messages2_to_1 = collect_stream!(inbound_rx1, take = NUM_MSGS, timeout = Duration::from_secs(10)); + let messages2_to_1 = collect_recv!(inbound_rx1, take = NUM_MSGS, timeout = Duration::from_secs(10)); // Check that we got all the messages let check_messages = |msgs: Vec| { @@ -279,13 +276,13 @@ async fn peer_to_peer_messaging() { comms_node2.wait_until_shutdown().await; } -#[runtime::test_basic] +#[runtime::test] async fn peer_to_peer_messaging_simultaneous() { const NUM_MSGS: usize = 10; let shutdown = Shutdown::new(); - let (comms_node1, mut inbound_rx1, mut outbound_tx1, _) = spawn_node(Protocols::new(), shutdown.to_signal()).await; - let (comms_node2, mut inbound_rx2, mut 
outbound_tx2, _) = spawn_node(Protocols::new(), shutdown.to_signal()).await; + let (comms_node1, mut inbound_rx1, outbound_tx1, _) = spawn_node(Protocols::new(), shutdown.to_signal()).await; + let (comms_node2, mut inbound_rx2, outbound_tx2, _) = spawn_node(Protocols::new(), shutdown.to_signal()).await; log::info!( "Peer1 = `{}`, Peer2 = `{}`", @@ -350,8 +347,8 @@ async fn peer_to_peer_messaging_simultaneous() { handle2.await.unwrap(); // Tasks are finished, let's see if all the messages made it though - let messages1_to_2 = collect_stream!(inbound_rx2, take = NUM_MSGS, timeout = Duration::from_secs(10)); - let messages2_to_1 = collect_stream!(inbound_rx1, take = NUM_MSGS, timeout = Duration::from_secs(10)); + let messages1_to_2 = collect_recv!(inbound_rx2, take = NUM_MSGS, timeout = Duration::from_secs(10)); + let messages2_to_1 = collect_recv!(inbound_rx1, take = NUM_MSGS, timeout = Duration::from_secs(10)); assert!(has_unique_elements(messages1_to_2.into_iter().map(|m| m.body))); assert!(has_unique_elements(messages2_to_1.into_iter().map(|m| m.body))); diff --git a/comms/src/compat.rs b/comms/src/compat.rs index 254876f14d..67b53c7c91 100644 --- a/comms/src/compat.rs +++ b/comms/src/compat.rs @@ -27,8 +27,9 @@ use std::{ io, pin::Pin, - task::{self, Poll}, + task::{self, Context, Poll}, }; +use tokio::io::ReadBuf; /// `IoCompat` provides a compatibility shim between the `AsyncRead`/`AsyncWrite` traits provided by /// the `futures` library and those provided by the `tokio` library since they are different and @@ -47,16 +48,16 @@ impl IoCompat { impl tokio::io::AsyncRead for IoCompat where T: futures::io::AsyncRead + Unpin { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context, buf: &mut [u8]) -> Poll> { - futures::io::AsyncRead::poll_read(Pin::new(&mut self.inner), cx, buf) + fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { + futures::io::AsyncRead::poll_read(Pin::new(&mut self.inner), cx, buf.filled_mut()) } } 
impl futures::io::AsyncRead for IoCompat where T: tokio::io::AsyncRead + Unpin { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context, buf: &mut [u8]) -> Poll> { - tokio::io::AsyncRead::poll_read(Pin::new(&mut self.inner), cx, buf) + fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context, buf: &mut [u8]) -> Poll> { + tokio::io::AsyncRead::poll_read(Pin::new(&mut self.inner), cx, &mut ReadBuf::new(buf)) } } diff --git a/comms/src/connection_manager/dial_state.rs b/comms/src/connection_manager/dial_state.rs index 0b07378747..07bbd1e631 100644 --- a/comms/src/connection_manager/dial_state.rs +++ b/comms/src/connection_manager/dial_state.rs @@ -24,8 +24,8 @@ use crate::{ connection_manager::{error::ConnectionManagerError, peer_connection::PeerConnection}, peer_manager::Peer, }; -use futures::channel::oneshot; use tari_shutdown::ShutdownSignal; +use tokio::sync::oneshot; /// The state of the dial request pub struct DialState { diff --git a/comms/src/connection_manager/dialer.rs b/comms/src/connection_manager/dialer.rs index 538ffe89bd..8264286db3 100644 --- a/comms/src/connection_manager/dialer.rs +++ b/comms/src/connection_manager/dialer.rs @@ -39,22 +39,22 @@ use crate::{ types::CommsPublicKey, }; use futures::{ - channel::{mpsc, oneshot}, future, future::{BoxFuture, Either, FusedFuture}, pin_mut, - stream::{Fuse, FuturesUnordered}, - AsyncRead, - AsyncWrite, - AsyncWriteExt, + stream::FuturesUnordered, FutureExt, - SinkExt, - StreamExt, }; use log::*; use std::{collections::HashMap, sync::Arc, time::Duration}; use tari_shutdown::{Shutdown, ShutdownSignal}; -use tokio::{task::JoinHandle, time}; +use tokio::{ + io::{AsyncRead, AsyncWrite, AsyncWriteExt}, + sync::{mpsc, oneshot}, + task::JoinHandle, + time, +}; +use tokio_stream::StreamExt; use tracing::{self, span, Instrument, Level}; const LOG_TARGET: &str = "comms::connection_manager::dialer"; @@ -79,7 +79,7 @@ pub struct Dialer { transport: TTransport, noise_config: NoiseConfig, backoff: Arc, - 
request_rx: Fuse>, + request_rx: mpsc::Receiver, cancel_signals: HashMap, conn_man_notifier: mpsc::Sender, shutdown: Option, @@ -112,7 +112,7 @@ where transport, noise_config, backoff: Arc::new(backoff), - request_rx: request_rx.fuse(), + request_rx, cancel_signals: Default::default(), conn_man_notifier, shutdown: Some(shutdown), @@ -139,16 +139,20 @@ where .expect("Establisher initialized without a shutdown"); debug!(target: LOG_TARGET, "Connection dialer started"); loop { - futures::select! { - request = self.request_rx.select_next_some() => self.handle_request(&mut pending_dials, request), - (dial_state, dial_result) = pending_dials.select_next_some() => { - self.handle_dial_result(dial_state, dial_result).await; - } - _ = shutdown => { + tokio::select! { + // Biased ordering is used because we already have the futures polled here in a fair order, and so wish to + // forgo the minor cost of the random ordering + biased; + + _ = &mut shutdown => { info!(target: LOG_TARGET, "Connection dialer shutting down because the shutdown signal was received"); self.cancel_all_dials(); break; } + Some((dial_state, dial_result)) = pending_dials.next() => { + self.handle_dial_result(dial_state, dial_result).await; + } + Some(request) = self.request_rx.recv() => self.handle_request(&mut pending_dials, request), } } } @@ -179,12 +183,7 @@ where self.cancel_signals.len() ); self.cancel_signals.drain().for_each(|(_, mut signal)| { - log_if_error_fmt!( - level: warn, - target: LOG_TARGET, - signal.trigger(), - "Shutdown trigger failed", - ); + signal.trigger(); }) } @@ -339,7 +338,7 @@ where } #[allow(clippy::too_many_arguments)] - #[tracing::instrument(skip(peer_manager, socket, conn_man_notifier, config, cancel_signal), err)] + #[tracing::instrument(skip(peer_manager, socket, conn_man_notifier, config, cancel_signal))] async fn perform_socket_upgrade_procedure( peer_manager: Arc, node_identity: Arc, @@ -352,7 +351,6 @@ where cancel_signal: ShutdownSignal, ) -> Result { static 
CONNECTION_DIRECTION: ConnectionDirection = ConnectionDirection::Outbound; - let mut muxer = Yamux::upgrade_connection(socket, CONNECTION_DIRECTION) .await .map_err(|err| ConnectionManagerError::YamuxUpgradeFailure(err.to_string()))?; @@ -448,9 +446,9 @@ where current_state.peer.node_id.short_str(), backoff_duration.as_secs() ); - let mut delay = time::delay_for(backoff_duration).fuse(); - let mut cancel_signal = current_state.get_cancel_signal(); - futures::select! { + let delay = time::sleep(backoff_duration).fuse(); + let cancel_signal = current_state.get_cancel_signal(); + tokio::select! { _ = delay => { debug!(target: LOG_TARGET, "[Attempt {}] Connecting to peer '{}'", current_state.num_attempts(), current_state.peer.node_id.short_str()); match Self::dial_peer(current_state, &noise_config, ¤t_transport, config.network_info.network_byte).await { @@ -544,18 +542,13 @@ where // Try the next address continue; }, - Either::Right((cancel_result, _)) => { + // Canceled + Either::Right(_) => { debug!( target: LOG_TARGET, "Dial for peer '{}' cancelled", dial_state.peer.node_id.short_str() ); - log_if_error!( - level: warn, - target: LOG_TARGET, - cancel_result, - "Cancel channel error during dial: {}", - ); Err(ConnectionManagerError::DialCancelled) }, } diff --git a/comms/src/connection_manager/error.rs b/comms/src/connection_manager/error.rs index f7bbbaa564..5645a57e62 100644 --- a/comms/src/connection_manager/error.rs +++ b/comms/src/connection_manager/error.rs @@ -21,13 +21,13 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
use crate::{ + connection_manager::PeerConnectionRequest, noise, peer_manager::PeerManagerError, protocol::{IdentityProtocolError, ProtocolError}, }; -use futures::channel::mpsc; use thiserror::Error; -use tokio::{time, time::Elapsed}; +use tokio::{sync::mpsc, time::error::Elapsed}; #[derive(Debug, Error, Clone)] pub enum ConnectionManagerError { @@ -108,14 +108,14 @@ pub enum PeerConnectionError { #[error("Internal oneshot reply channel was unexpectedly cancelled")] InternalReplyCancelled, #[error("Failed to send internal request: {0}")] - InternalRequestSendFailed(#[from] mpsc::SendError), + InternalRequestSendFailed(#[from] mpsc::error::SendError), #[error("Protocol error: {0}")] ProtocolError(#[from] ProtocolError), #[error("Protocol negotiation timeout")] ProtocolNegotiationTimeout, } -impl From for PeerConnectionError { +impl From for PeerConnectionError { fn from(_: Elapsed) -> Self { PeerConnectionError::ProtocolNegotiationTimeout } diff --git a/comms/src/connection_manager/listener.rs b/comms/src/connection_manager/listener.rs index 60ae3c2d12..21c3771610 100644 --- a/comms/src/connection_manager/listener.rs +++ b/comms/src/connection_manager/listener.rs @@ -32,7 +32,6 @@ use crate::{ bounded_executor::BoundedExecutor, connection_manager::{ liveness::LivenessSession, - types::OneshotTrigger, wire_mode::{WireMode, LIVENESS_WIRE_MODE}, }, multiaddr::Multiaddr, @@ -46,17 +45,7 @@ use crate::{ utils::multiaddr::multiaddr_to_socketaddr, PeerManager, }; -use futures::{ - channel::mpsc, - future, - AsyncRead, - AsyncReadExt, - AsyncWrite, - AsyncWriteExt, - FutureExt, - SinkExt, - StreamExt, -}; +use futures::{future, FutureExt}; use log::*; use std::{ convert::TryInto, @@ -69,8 +58,13 @@ use std::{ time::Duration, }; use tari_crypto::tari_utilities::hex::Hex; -use tari_shutdown::ShutdownSignal; -use tokio::time; +use tari_shutdown::{oneshot_trigger, oneshot_trigger::OneshotTrigger, ShutdownSignal}; +use tokio::{ + io::{AsyncRead, AsyncReadExt, AsyncWrite, 
AsyncWriteExt}, + sync::mpsc, + time, +}; +use tokio_stream::StreamExt; use tracing::{span, Instrument, Level}; const LOG_TARGET: &str = "comms::connection_manager::listener"; @@ -118,7 +112,7 @@ where bounded_executor: BoundedExecutor::from_current(config.max_simultaneous_inbound_connects), liveness_session_count: Arc::new(AtomicUsize::new(config.liveness_max_sessions)), config, - on_listening: OneshotTrigger::new(), + on_listening: oneshot_trigger::channel(), } } @@ -128,7 +122,7 @@ where // 'static lifetime as well as to flatten the oneshot result for ergonomics pub fn on_listening(&self) -> impl Future> + 'static { let signal = self.on_listening.to_signal(); - signal.map(|r| r.map_err(|_| ConnectionManagerError::ListenerOneshotCancelled)?) + signal.map(|r| r.ok_or(ConnectionManagerError::ListenerOneshotCancelled)?) } /// Set the supported protocols of this node to send to peers during the peer identity exchange @@ -147,31 +141,30 @@ where let mut shutdown_signal = self.shutdown_signal.clone(); match self.bind().await { - Ok((inbound, address)) => { + Ok((mut inbound, address)) => { info!(target: LOG_TARGET, "Listening for peer connections on '{}'", address); - self.on_listening.trigger(Ok(address)); - - let inbound = inbound.fuse(); - futures::pin_mut!(inbound); + self.on_listening.broadcast(Ok(address)); loop { - futures::select! { - inbound_result = inbound.select_next_some() => { + tokio::select! 
{ + biased; + + _ = &mut shutdown_signal => { + info!(target: LOG_TARGET, "PeerListener is shutting down because the shutdown signal was triggered"); + break; + }, + Some(inbound_result) = inbound.next() => { if let Some((socket, peer_addr)) = log_if_error!(target: LOG_TARGET, inbound_result, "Inbound connection failed because '{error}'",) { self.spawn_listen_task(socket, peer_addr).await; } }, - _ = shutdown_signal => { - info!(target: LOG_TARGET, "PeerListener is shutting down because the shutdown signal was triggered"); - break; - }, } } }, Err(err) => { warn!(target: LOG_TARGET, "PeerListener was unable to start because '{}'", err); - self.on_listening.trigger(Err(err)); + self.on_listening.broadcast(Err(err)); }, } } @@ -238,7 +231,7 @@ where async fn spawn_listen_task(&self, mut socket: TTransport::Output, peer_addr: Multiaddr) { let node_identity = self.node_identity.clone(); let peer_manager = self.peer_manager.clone(); - let mut conn_man_notifier = self.conn_man_notifier.clone(); + let conn_man_notifier = self.conn_man_notifier.clone(); let noise_config = self.noise_config.clone(); let config = self.config.clone(); let our_supported_protocols = self.our_supported_protocols.clone(); @@ -318,7 +311,7 @@ where "No liveness sessions available or permitted for peer address '{}'", peer_addr ); - let _ = socket.close().await; + let _ = socket.shutdown().await; } }, Err(err) => { diff --git a/comms/src/connection_manager/liveness.rs b/comms/src/connection_manager/liveness.rs index 3c06889307..75ee2db13f 100644 --- a/comms/src/connection_manager/liveness.rs +++ b/comms/src/connection_manager/liveness.rs @@ -20,15 +20,16 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-use crate::compat::IoCompat; -use futures::{AsyncRead, AsyncWrite, Future, StreamExt}; +use futures::StreamExt; +use std::future::Future; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::codec::{Framed, LinesCodec, LinesCodecError}; /// Max line length accepted by the liveness session. const MAX_LINE_LENGTH: usize = 50; pub struct LivenessSession { - framed: Framed, LinesCodec>, + framed: Framed, } impl LivenessSession @@ -36,7 +37,7 @@ where TSocket: AsyncRead + AsyncWrite + Unpin { pub fn new(socket: TSocket) -> Self { Self { - framed: Framed::new(IoCompat::new(socket), LinesCodec::new_with_max_length(MAX_LINE_LENGTH)), + framed: Framed::new(socket, LinesCodec::new_with_max_length(MAX_LINE_LENGTH)), } } @@ -52,13 +53,14 @@ mod test { use crate::{memsocket::MemorySocket, runtime}; use futures::SinkExt; use tokio::{time, time::Duration}; + use tokio_stream::StreamExt; - #[runtime::test_basic] + #[runtime::test] async fn echos() { let (inbound, outbound) = MemorySocket::new_pair(); let liveness = LivenessSession::new(inbound); let join_handle = runtime::current().spawn(liveness.run()); - let mut outbound = Framed::new(IoCompat::new(outbound), LinesCodec::new()); + let mut outbound = Framed::new(outbound, LinesCodec::new()); for _ in 0..10usize { outbound.send("ECHO".to_string()).await.unwrap() } diff --git a/comms/src/connection_manager/manager.rs b/comms/src/connection_manager/manager.rs index 928b53611f..0c6de18f59 100644 --- a/comms/src/connection_manager/manager.rs +++ b/comms/src/connection_manager/manager.rs @@ -36,20 +36,17 @@ use crate::{ transports::{TcpTransport, Transport}, PeerManager, }; -use futures::{ - channel::{mpsc, oneshot}, - stream::Fuse, - AsyncRead, - AsyncWrite, - SinkExt, - StreamExt, -}; use log::*; use multiaddr::Multiaddr; use std::{fmt, sync::Arc}; use tari_shutdown::{Shutdown, ShutdownSignal}; use time::Duration; -use tokio::{sync::broadcast, task, time}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{broadcast, mpsc, 
oneshot}, + task, + time, +}; use tracing::{span, Instrument, Level}; const LOG_TARGET: &str = "comms::connection_manager::manager"; @@ -156,8 +153,8 @@ impl ListenerInfo { } pub struct ConnectionManager { - request_rx: Fuse>, - internal_event_rx: Fuse>, + request_rx: mpsc::Receiver, + internal_event_rx: mpsc::Receiver, dialer_tx: mpsc::Sender, dialer: Option>, listener: Option>, @@ -230,10 +227,10 @@ where Self { shutdown_signal: Some(shutdown_signal), - request_rx: request_rx.fuse(), + request_rx, peer_manager, protocols: Protocols::new(), - internal_event_rx: internal_event_rx.fuse(), + internal_event_rx, dialer_tx, dialer: Some(dialer), listener: Some(listener), @@ -266,7 +263,7 @@ where .take() .expect("ConnectionManager initialized without a shutdown"); - // Runs the listeners, waiting for a + // Runs the listeners. Sockets are bound and ready once this resolves match self.run_listeners().await { Ok(info) => { self.listener_info = Some(info); @@ -293,16 +290,16 @@ where .join(", ") ); loop { - futures::select! { - event = self.internal_event_rx.select_next_some() => { + tokio::select! 
{
+            Some(event) = self.internal_event_rx.recv() => {
                 self.handle_event(event).await;
             },
-            request = self.request_rx.select_next_some() => {
+            Some(request) = self.request_rx.recv() => {
                 self.handle_request(request).await;
             },
-            _ = shutdown => {
+            _ = &mut shutdown => {
                 info!(target: LOG_TARGET, "ConnectionManager is shutting down because it received the shutdown signal");
                 break;
             }
diff --git a/comms/src/connection_manager/peer_connection.rs b/comms/src/connection_manager/peer_connection.rs
index 6f0d90da5d..cceb6ae9bd 100644
--- a/comms/src/connection_manager/peer_connection.rs
+++ b/comms/src/connection_manager/peer_connection.rs
@@ -44,20 +44,21 @@ use crate::{
     protocol::{ProtocolId, ProtocolNegotiation},
     runtime,
 };
-use futures::{
-    channel::{mpsc, oneshot},
-    stream::Fuse,
-    SinkExt,
-    StreamExt,
-};
 use log::*;
 use multiaddr::Multiaddr;
 use std::{
     fmt,
-    sync::atomic::{AtomicUsize, Ordering},
+    sync::{
+        atomic::{AtomicUsize, Ordering},
+        Arc,
+    },
     time::{Duration, Instant},
 };
-use tokio::time;
+use tokio::{
+    sync::{mpsc, oneshot},
+    time,
+};
+use tokio_stream::StreamExt;
 use tracing::{self, span, Instrument, Level, Span};

 const LOG_TARGET: &str = "comms::connection_manager::peer_connection";
@@ -130,7 +131,7 @@ pub struct PeerConnection {
     peer_node_id: NodeId,
     peer_features: PeerFeatures,
     request_tx: mpsc::Sender,
-    address: Multiaddr,
+    address: Arc,
     direction: ConnectionDirection,
     started_at: Instant,
     substream_counter: SubstreamCounter,
@@ -151,7 +152,7 @@ impl PeerConnection {
             request_tx,
             peer_node_id,
             peer_features,
-            address,
+            address: Arc::new(address),
             direction,
             started_at: Instant::now(),
             substream_counter,
@@ -190,7 +191,7 @@ impl PeerConnection {
         self.substream_counter.get()
     }

-    #[tracing::instrument("peer_connection::open_substream", skip(self), err)]
+    #[tracing::instrument("peer_connection::open_substream", skip(self))]
     pub async fn open_substream(
         &mut self,
         protocol_id: &ProtocolId,
@@ -208,7 +209,7 @@ impl PeerConnection {
             .map_err(|_| PeerConnectionError::InternalReplyCancelled)?
     }

-    #[tracing::instrument("peer_connection::open_framed_substream", skip(self), err)]
+    #[tracing::instrument("peer_connection::open_framed_substream", skip(self))]
     pub async fn open_framed_substream(
         &mut self,
         protocol_id: &ProtocolId,
@@ -219,14 +220,14 @@ impl PeerConnection {
     }

     #[cfg(feature = "rpc")]
-    #[tracing::instrument("peer_connection::connect_rpc", skip(self), fields(peer_node_id = self.peer_node_id.to_string().as_str()), err)]
+    #[tracing::instrument("peer_connection::connect_rpc", skip(self), fields(peer_node_id = self.peer_node_id.to_string().as_str()))]
     pub async fn connect_rpc(&mut self) -> Result
     where T: From + NamedProtocolService {
         self.connect_rpc_using_builder(Default::default()).await
     }

     #[cfg(feature = "rpc")]
-    #[tracing::instrument("peer_connection::connect_rpc_with_builder", skip(self, builder), err)]
+    #[tracing::instrument("peer_connection::connect_rpc_with_builder", skip(self, builder))]
     pub async fn connect_rpc_using_builder(&mut self, builder: RpcClientBuilder) -> Result
     where T: From + NamedProtocolService {
         let protocol = ProtocolId::from_static(T::PROTOCOL_NAME);
@@ -301,9 +302,9 @@ impl PartialEq for PeerConnection {
 struct PeerConnectionActor {
     id: ConnectionId,
     peer_node_id: NodeId,
-    request_rx: Fuse>,
+    request_rx: mpsc::Receiver,
     direction: ConnectionDirection,
-    incoming_substreams: Fuse,
+    incoming_substreams: IncomingSubstreams,
     control: Control,
     event_notifier: mpsc::Sender,
     our_supported_protocols: Vec,
@@ -327,8 +328,8 @@ impl PeerConnectionActor {
             peer_node_id,
             direction,
             control: connection.get_yamux_control(),
-            incoming_substreams: connection.incoming().fuse(),
-            request_rx: request_rx.fuse(),
+            incoming_substreams: connection.incoming(),
+            request_rx,
             event_notifier,
             our_supported_protocols,
             their_supported_protocols,
@@ -337,8 +338,8 @@ impl PeerConnectionActor {
     pub async fn run(mut self) {
         loop {
-            futures::select! {
-                request = self.request_rx.select_next_some() => self.handle_request(request).await,
+            tokio::select! {
+                Some(request) = self.request_rx.recv() => self.handle_request(request).await,

                 maybe_substream = self.incoming_substreams.next() => {
                     match maybe_substream {
@@ -362,7 +363,7 @@ impl PeerConnectionActor {
                 }
             }
         }
-        self.request_rx.get_mut().close();
+        self.request_rx.close();
     }

     async fn handle_request(&mut self, request: PeerConnectionRequest) {
@@ -396,7 +397,7 @@ impl PeerConnectionActor {
         }
     }

-    #[tracing::instrument(skip(self, stream), err, fields(comms.direction="inbound"))]
+    #[tracing::instrument(skip(self, stream), fields(comms.direction="inbound"))]
     async fn handle_incoming_substream(&mut self, mut stream: Substream) -> Result<(), PeerConnectionError> {
         let selected_protocol = ProtocolNegotiation::new(&mut stream)
             .negotiate_protocol_inbound(&self.our_supported_protocols)
@@ -412,7 +413,7 @@ impl PeerConnectionActor {
         Ok(())
     }

-    #[tracing::instrument(skip(self), err)]
+    #[tracing::instrument(skip(self))]
     async fn open_negotiated_protocol_stream(
         &mut self,
         protocol: ProtocolId,
diff --git a/comms/src/connection_manager/requester.rs b/comms/src/connection_manager/requester.rs
index 1f3f5cc887..0007e59228 100644
--- a/comms/src/connection_manager/requester.rs
+++ b/comms/src/connection_manager/requester.rs
@@ -25,12 +25,8 @@ use crate::{
     connection_manager::manager::{ConnectionManagerEvent, ListenerInfo},
     peer_manager::NodeId,
 };
-use futures::{
-    channel::{mpsc, oneshot},
-    SinkExt,
-};
 use std::sync::Arc;
-use tokio::sync::broadcast;
+use tokio::sync::{broadcast, mpsc, oneshot};

 /// Requests which are handled by the ConnectionManagerService
 #[derive(Debug)]
@@ -78,7 +74,7 @@ impl ConnectionManagerRequester {
     }

     /// Attempt to connect to a remote peer
-    #[tracing::instrument(skip(self), err)]
+    #[tracing::instrument(skip(self))]
     pub async fn dial_peer(&mut self, node_id: NodeId) -> Result {
         let (reply_tx, reply_rx) = oneshot::channel();
         self.send_dial_peer(node_id, Some(reply_tx)).await?;
@@ -97,7 +93,7 @@ impl ConnectionManagerRequester {
     }

     /// Send instruction to ConnectionManager to dial a peer and return the result on the given oneshot
-    #[tracing::instrument(skip(self, reply_tx), err)]
+    #[tracing::instrument(skip(self, reply_tx))]
     pub(crate) async fn send_dial_peer(
         &mut self,
         node_id: NodeId,
@@ -124,7 +120,7 @@ impl ConnectionManagerRequester {
     }

     /// Send instruction to ConnectionManager to dial a peer without waiting for a result.
-    #[tracing::instrument(skip(self), err)]
+    #[tracing::instrument(skip(self))]
     pub(crate) async fn send_dial_peer_no_reply(&mut self, node_id: NodeId) -> Result<(), ConnectionManagerError> {
         self.send_dial_peer(node_id, None).await?;
         Ok(())
diff --git a/comms/src/connection_manager/tests/listener_dialer.rs b/comms/src/connection_manager/tests/listener_dialer.rs
index 586c0d5ec4..25e715bf50 100644
--- a/comms/src/connection_manager/tests/listener_dialer.rs
+++ b/comms/src/connection_manager/tests/listener_dialer.rs
@@ -36,20 +36,17 @@ use crate::{
     test_utils::{node_identity::build_node_identity, test_node::build_peer_manager},
     transports::MemoryTransport,
 };
-use futures::{
-    channel::{mpsc, oneshot},
-    AsyncReadExt,
-    AsyncWriteExt,
-    SinkExt,
-    StreamExt,
-};
 use multiaddr::Protocol;
 use std::{error::Error, time::Duration};
 use tari_shutdown::Shutdown;
 use tari_test_utils::unpack_enum;
-use tokio::time::timeout;
+use tokio::{
+    io::{AsyncReadExt, AsyncWriteExt},
+    sync::{mpsc, oneshot},
+    time::timeout,
+};

-#[runtime::test_basic]
+#[runtime::test]
 async fn listen() -> Result<(), Box> {
     let (event_tx, _) = mpsc::channel(1);
     let mut shutdown = Shutdown::new();
@@ -61,7 +58,7 @@ async fn listen() -> Result<(), Box> {
         "/memory/0".parse()?,
         MemoryTransport,
         noise_config.clone(),
-        event_tx.clone(),
+        event_tx,
         peer_manager,
         node_identity,
         shutdown.to_signal(),
@@ -72,12 +69,12 @@ async fn listen() -> Result<(), Box> {
     unpack_enum!(Protocol::Memory(port) =
bind_addr.pop().unwrap());
     assert!(port > 0);

-    shutdown.trigger().unwrap();
+    shutdown.trigger();

     Ok(())
 }

-#[runtime::test_basic]
+#[runtime::test]
 async fn smoke() {
     let rt_handle = runtime::current();
     // This test sets up Dialer and Listener components, uses the Dialer to dial the Listener,
@@ -108,7 +105,7 @@ async fn smoke() {
     let node_identity2 = build_node_identity(PeerFeatures::COMMUNICATION_NODE);
     let noise_config2 = NoiseConfig::new(node_identity2.clone());
-    let (mut request_tx, request_rx) = mpsc::channel(1);
+    let (request_tx, request_rx) = mpsc::channel(1);
     let peer_manager2 = build_peer_manager();
     let mut dialer = Dialer::new(
         ConnectionManagerConfig::default(),
@@ -148,11 +145,11 @@ async fn smoke() {
     }

     // Read PeerConnected events - we don't know which connection is which
-    unpack_enum!(ConnectionManagerEvent::PeerConnected(conn1) = event_rx.next().await.unwrap());
-    unpack_enum!(ConnectionManagerEvent::PeerConnected(_conn2) = event_rx.next().await.unwrap());
+    unpack_enum!(ConnectionManagerEvent::PeerConnected(conn1) = event_rx.recv().await.unwrap());
+    unpack_enum!(ConnectionManagerEvent::PeerConnected(_conn2) = event_rx.recv().await.unwrap());

     // Next event should be a NewInboundSubstream has been received
-    let listen_event = event_rx.next().await.unwrap();
+    let listen_event = event_rx.recv().await.unwrap();
     {
         unpack_enum!(ConnectionManagerEvent::NewInboundSubstream(node_id, proto, in_stream) = listen_event);
         assert_eq!(&*node_id, node_identity2.node_id());
@@ -165,7 +162,7 @@ async fn smoke() {

     conn1.disconnect().await.unwrap();

-    shutdown.trigger().unwrap();
+    shutdown.trigger();

     let peer2 = peer_manager1.find_by_node_id(node_identity2.node_id()).await.unwrap();
     let peer1 = peer_manager2.find_by_node_id(node_identity1.node_id()).await.unwrap();
@@ -176,7 +173,7 @@ async fn smoke() {
     timeout(Duration::from_secs(5), dialer_fut).await.unwrap().unwrap();
 }

-#[runtime::test_basic]
+#[runtime::test]
 async fn banned() {
     let rt_handle = runtime::current();
     let (event_tx, mut event_rx) = mpsc::channel(10);
@@ -209,7 +206,7 @@ async fn banned() {
     peer_manager1.add_peer(peer).await.unwrap();

     let noise_config2 = NoiseConfig::new(node_identity2.clone());
-    let (mut request_tx, request_rx) = mpsc::channel(1);
+    let (request_tx, request_rx) = mpsc::channel(1);
     let peer_manager2 = build_peer_manager();
     let mut dialer = Dialer::new(
         ConnectionManagerConfig::default(),
@@ -241,10 +238,10 @@ async fn banned() {
     let err = reply_rx.await.unwrap().unwrap_err();
     unpack_enum!(ConnectionManagerError::IdentityProtocolError(_err) = err);

-    unpack_enum!(ConnectionManagerEvent::PeerInboundConnectFailed(err) = event_rx.next().await.unwrap());
+    unpack_enum!(ConnectionManagerEvent::PeerInboundConnectFailed(err) = event_rx.recv().await.unwrap());
     unpack_enum!(ConnectionManagerError::PeerBanned = err);

-    shutdown.trigger().unwrap();
+    shutdown.trigger();
     timeout(Duration::from_secs(5), dialer_fut).await.unwrap().unwrap();
 }
diff --git a/comms/src/connection_manager/tests/manager.rs b/comms/src/connection_manager/tests/manager.rs
index bca876eff4..910280b4cf 100644
--- a/comms/src/connection_manager/tests/manager.rs
+++ b/comms/src/connection_manager/tests/manager.rs
@@ -41,19 +41,17 @@ use crate::{
     },
     transports::{MemoryTransport, TcpTransport},
 };
-use futures::{
-    channel::{mpsc, oneshot},
-    future,
-    AsyncReadExt,
-    AsyncWriteExt,
-    StreamExt,
-};
+use futures::future;
 use std::time::Duration;
 use tari_shutdown::Shutdown;
-use tari_test_utils::{collect_stream, unpack_enum};
-use tokio::{runtime::Handle, sync::broadcast};
+use tari_test_utils::{collect_try_recv, unpack_enum};
+use tokio::{
+    io::{AsyncReadExt, AsyncWriteExt},
+    runtime::Handle,
+    sync::{broadcast, mpsc, oneshot},
+};

-#[runtime::test_basic]
+#[runtime::test]
 async fn connect_to_nonexistent_peer() {
     let rt_handle = Handle::current();
     let node_identity = build_node_identity(PeerFeatures::empty());
@@ -83,10 +81,10 @@ async fn connect_to_nonexistent_peer() {
     unpack_enum!(ConnectionManagerError::PeerManagerError(err) = err);
     unpack_enum!(PeerManagerError::PeerNotFoundError = err);

-    shutdown.trigger().unwrap();
+    shutdown.trigger();
 }

-#[runtime::test_basic]
+#[runtime::test]
 async fn dial_success() {
     static TEST_PROTO: ProtocolId = ProtocolId::from_static(b"/test/valid");
     let shutdown = Shutdown::new();
@@ -159,7 +157,7 @@ async fn dial_success() {
     assert_eq!(peer2.supported_protocols, [&IDENTITY_PROTOCOL, &TEST_PROTO]);
     assert_eq!(peer2.user_agent, "node2");

-    let event = subscription2.next().await.unwrap().unwrap();
+    let event = subscription2.recv().await.unwrap();
     unpack_enum!(ConnectionManagerEvent::PeerConnected(conn_in) = &*event);
     assert_eq!(conn_in.peer_node_id(), node_identity1.node_id());
@@ -179,7 +177,7 @@ async fn dial_success() {
     const MSG: &[u8] = b"Welease Woger!";
     substream_out.stream.write_all(MSG).await.unwrap();

-    let protocol_in = proto_rx2.next().await.unwrap();
+    let protocol_in = proto_rx2.recv().await.unwrap();
     assert_eq!(protocol_in.protocol, &TEST_PROTO);
     unpack_enum!(ProtocolEvent::NewInboundSubstream(node_id, substream_in) = protocol_in.event);
     assert_eq!(&node_id, node_identity1.node_id());
@@ -189,7 +187,7 @@ async fn dial_success() {
     assert_eq!(buf, MSG);
 }

-#[runtime::test_basic]
+#[runtime::test]
 async fn dial_success_aux_tcp_listener() {
     static TEST_PROTO: ProtocolId = ProtocolId::from_static(b"/test/valid");
     let shutdown = Shutdown::new();
@@ -271,7 +269,7 @@ async fn dial_success_aux_tcp_listener() {
     const MSG: &[u8] = b"Welease Woger!";
     substream_out.stream.write_all(MSG).await.unwrap();

-    let protocol_in = proto_rx1.next().await.unwrap();
+    let protocol_in = proto_rx1.recv().await.unwrap();
     assert_eq!(protocol_in.protocol, &TEST_PROTO);
     unpack_enum!(ProtocolEvent::NewInboundSubstream(node_id, substream_in) = protocol_in.event);
     assert_eq!(&node_id, node_identity2.node_id());
@@ -281,7 +279,7 @@ async fn dial_success_aux_tcp_listener() {
     assert_eq!(buf, MSG);
 }

-#[runtime::test_basic]
+#[runtime::test]
 async fn simultaneous_dial_events() {
     let mut shutdown = Shutdown::new();
@@ -360,29 +358,22 @@ async fn simultaneous_dial_events() {
         _ => panic!("unexpected simultaneous dial result"),
     }

-    let event = subscription2.next().await.unwrap().unwrap();
+    let event = subscription2.recv().await.unwrap();
     assert!(count_string_occurrences(&[event], &["PeerConnected", "PeerInboundConnectFailed"]) >= 1);

-    shutdown.trigger().unwrap();
+    shutdown.trigger();
     drop(conn_man1);
     drop(conn_man2);

-    let _events1 = collect_stream!(subscription1, timeout = Duration::from_secs(5))
-        .into_iter()
-        .map(Result::unwrap)
-        .collect::>();
-
-    let _events2 = collect_stream!(subscription2, timeout = Duration::from_secs(5))
-        .into_iter()
-        .map(Result::unwrap)
-        .collect::>();
+    let _events1 = collect_try_recv!(subscription1, timeout = Duration::from_secs(5));
+    let _events2 = collect_try_recv!(subscription2, timeout = Duration::from_secs(5));

     // TODO: Investigate why two PeerDisconnected events are sometimes received
     // assert!(count_string_occurrences(&events1, &["PeerDisconnected"]) >= 1);
     // assert!(count_string_occurrences(&events2, &["PeerDisconnected"]) >= 1);
 }

-#[tokio_macros::test_basic]
+#[runtime::test]
 async fn dial_cancelled() {
     let mut shutdown = Shutdown::new();
@@ -429,13 +420,10 @@ async fn dial_cancelled() {
     let err = dial_result.await.unwrap().unwrap_err();
     unpack_enum!(ConnectionManagerError::DialCancelled = err);

-    shutdown.trigger().unwrap();
+    shutdown.trigger();
     drop(conn_man1);

-    let events1 = collect_stream!(subscription1, timeout = Duration::from_secs(5))
-        .into_iter()
-        .map(Result::unwrap)
-        .collect::>();
+    let events1 = collect_try_recv!(subscription1, timeout = Duration::from_secs(5));
     assert_eq!(events1.len(), 1);
     unpack_enum!(ConnectionManagerEvent::PeerConnectFailed(node_id, err) = &*events1[0]);
diff --git a/comms/src/connection_manager/types.rs b/comms/src/connection_manager/types.rs
index ddb8f6de8f..c92b2b717a
100644
--- a/comms/src/connection_manager/types.rs
+++ b/comms/src/connection_manager/types.rs
@@ -20,11 +20,6 @@
 // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

-use futures::{
-    channel::oneshot,
-    future::{Fuse, Shared},
-    FutureExt,
-};
 use std::fmt;

 /// Direction of the connection relative to this node
@@ -47,29 +42,3 @@ impl fmt::Display for ConnectionDirection {
         write!(f, "{:?}", self)
     }
 }
-
-pub type OneshotSignal = Shared>>;
-pub struct OneshotTrigger(Option>, OneshotSignal);
-
-impl OneshotTrigger {
-    pub fn new() -> Self {
-        let (tx, rx) = oneshot::channel();
-        Self(Some(tx), rx.fuse().shared())
-    }
-
-    pub fn to_signal(&self) -> OneshotSignal {
-        self.1.clone()
-    }
-
-    pub fn trigger(&mut self, item: T) {
-        if let Some(tx) = self.0.take() {
-            let _ = tx.send(item);
-        }
-    }
-}
-
-impl Default for OneshotTrigger {
-    fn default() -> Self {
-        Self::new()
-    }
-}
diff --git a/comms/src/connectivity/manager.rs b/comms/src/connectivity/manager.rs
index 35f37627c4..821253e84b 100644
--- a/comms/src/connectivity/manager.rs
+++ b/comms/src/connectivity/manager.rs
@@ -34,6 +34,7 @@ use crate::{
         ConnectionManagerEvent,
         ConnectionManagerRequester,
     },
+    connectivity::ConnectivityEventTx,
     peer_manager::NodeId,
     runtime::task,
     utils::datetime::format_duration,
@@ -41,7 +42,6 @@ use crate::{
     PeerConnection,
     PeerManager,
 };
-use futures::{channel::mpsc, stream::Fuse, StreamExt};
 use log::*;
 use nom::lib::std::collections::hash_map::Entry;
 use std::{
@@ -52,7 +52,7 @@ use std::{
     time::{Duration, Instant},
 };
 use tari_shutdown::ShutdownSignal;
-use tokio::{sync::broadcast, task::JoinHandle, time};
+use tokio::{sync::mpsc, task::JoinHandle, time};
 use tracing::{span, Instrument, Level};

 const LOG_TARGET: &str = "comms::connectivity::manager";
@@ -72,7 +72,7 @@ const LOG_TARGET: &str = "comms::connectivity::manager";
 pub struct
ConnectivityManager {
     pub config: ConnectivityConfig,
     pub request_rx: mpsc::Receiver,
-    pub event_tx: broadcast::Sender>,
+    pub event_tx: ConnectivityEventTx,
     pub connection_manager: ConnectionManagerRequester,
     pub peer_manager: Arc,
     pub node_identity: Arc,
@@ -84,7 +84,7 @@ impl ConnectivityManager {
         ConnectivityManagerActor {
             config: self.config,
             status: ConnectivityStatus::Initializing,
-            request_rx: self.request_rx.fuse(),
+            request_rx: self.request_rx,
             connection_manager: self.connection_manager,
             peer_manager: self.peer_manager.clone(),
             event_tx: self.event_tx,
@@ -140,12 +140,12 @@ impl fmt::Display for ConnectivityStatus {
 pub struct ConnectivityManagerActor {
     config: ConnectivityConfig,
     status: ConnectivityStatus,
-    request_rx: Fuse>,
+    request_rx: mpsc::Receiver,
     connection_manager: ConnectionManagerRequester,
     node_identity: Arc,
     shutdown_signal: Option,
     peer_manager: Arc,
-    event_tx: broadcast::Sender>,
+    event_tx: ConnectivityEventTx,
     connection_stats: HashMap,

     managed_peers: Vec,
@@ -165,7 +165,7 @@ impl ConnectivityManagerActor {
             .take()
             .expect("ConnectivityManager initialized without a shutdown_signal");

-        let mut connection_manager_events = self.connection_manager.get_event_subscription().fuse();
+        let mut connection_manager_events = self.connection_manager.get_event_subscription();

         let interval = self.config.connection_pool_refresh_interval;
         let mut ticker = time::interval_at(
@@ -174,18 +174,17 @@ impl ConnectivityManagerActor {
                 .expect("connection_pool_refresh_interval cause overflow")
                 .into(),
             interval,
-        )
-        .fuse();
+        );

         self.publish_event(ConnectivityEvent::ConnectivityStateInitialized);

         loop {
-            futures::select! {
-                req = self.request_rx.select_next_some() => {
+            tokio::select! {
+                Some(req) = self.request_rx.recv() => {
                     self.handle_request(req).await;
                 },
-                event = connection_manager_events.select_next_some() => {
+                event = connection_manager_events.recv() => {
                     if let Ok(event) = event {
                         if let Err(err) = self.handle_connection_manager_event(&event).await {
                             error!(target:LOG_TARGET, "Error handling connection manager event: {:?}", err);
@@ -193,13 +192,13 @@ impl ConnectivityManagerActor {
                     }
                 },

-                _ = ticker.next() => {
+                _ = ticker.tick() => {
                     if let Err(err) = self.refresh_connection_pool().await {
                         error!(target: LOG_TARGET, "Error when refreshing connection pools: {:?}", err);
                     }
                 },

-                _ = shutdown_signal => {
+                _ = &mut shutdown_signal => {
                     info!(target: LOG_TARGET, "ConnectivityManager is shutting down because it received the shutdown signal");
                     self.disconnect_all().await;
                     break;
@@ -823,7 +822,7 @@ impl ConnectivityManagerActor {
     fn publish_event(&mut self, event: ConnectivityEvent) {
         // A send operation can only fail if there are no subscribers, so it is safe to ignore the error
-        let _ = self.event_tx.send(Arc::new(event));
+        let _ = self.event_tx.send(event);
     }

     async fn ban_peer(
@@ -863,7 +862,7 @@ impl ConnectivityManagerActor {
 fn delayed_close(conn: PeerConnection, delay: Duration) {
     task::spawn(async move {
-        time::delay_for(delay).await;
+        time::sleep(delay).await;
         debug!(
             target: LOG_TARGET,
             "Closing connection from peer `{}` after delay",
diff --git a/comms/src/connectivity/requester.rs b/comms/src/connectivity/requester.rs
index 740c8a6c81..8fb86320b4 100644
--- a/comms/src/connectivity/requester.rs
+++ b/comms/src/connectivity/requester.rs
@@ -31,23 +31,21 @@ use crate::{
     peer_manager::NodeId,
     PeerConnection,
 };
-use futures::{
-    channel::{mpsc, oneshot},
-    SinkExt,
-    StreamExt,
-};
 use log::*;
 use std::{
     fmt,
-    sync::Arc,
     time::{Duration, Instant},
 };
-use tokio::{sync::broadcast, time};
+use tokio::{
+    sync::{broadcast, broadcast::error::RecvError, mpsc, oneshot},
+    time,
+};
+
 const LOG_TARGET: &str = "comms::connectivity::requester";
 use tracing;

-pub type ConnectivityEventRx = broadcast::Receiver>;
-pub type ConnectivityEventTx = broadcast::Sender>;
+pub type ConnectivityEventRx = broadcast::Receiver;
+pub type ConnectivityEventTx = broadcast::Sender;

 #[derive(Debug, Clone)]
 pub enum ConnectivityEvent {
@@ -128,7 +126,7 @@ impl ConnectivityRequester {
         self.event_tx.clone()
     }

-    #[tracing::instrument(skip(self), err)]
+    #[tracing::instrument(skip(self))]
     pub async fn dial_peer(&mut self, peer: NodeId) -> Result {
         let mut num_cancels = 0;
         loop {
@@ -264,24 +262,23 @@ impl ConnectivityRequester {
         let mut last_known_peer_count = status.num_connected_nodes();
         loop {
             debug!(target: LOG_TARGET, "Waiting for connectivity event");
-            let recv_result = time::timeout(remaining, connectivity_events.next())
+            let recv_result = time::timeout(remaining, connectivity_events.recv())
                 .await
-                .map_err(|_| ConnectivityError::OnlineWaitTimeout(last_known_peer_count))?
-                .ok_or(ConnectivityError::ConnectivityEventStreamClosed)?;
+                .map_err(|_| ConnectivityError::OnlineWaitTimeout(last_known_peer_count))?;

             remaining = timeout
                 .checked_sub(start.elapsed())
                 .ok_or(ConnectivityError::OnlineWaitTimeout(last_known_peer_count))?;

             match recv_result {
-                Ok(event) => match &*event {
+                Ok(event) => match event {
                     ConnectivityEvent::ConnectivityStateOnline(_) => {
                         info!(target: LOG_TARGET, "Connectivity is ONLINE.");
                         break Ok(());
                     },
                     ConnectivityEvent::ConnectivityStateDegraded(n) => {
                         warn!(target: LOG_TARGET, "Connectivity is DEGRADED ({} peer(s))", n);
-                        last_known_peer_count = *n;
+                        last_known_peer_count = n;
                     },
                     ConnectivityEvent::ConnectivityStateOffline => {
                         warn!(
@@ -297,14 +294,14 @@ impl ConnectivityRequester {
                         );
                     },
                 },
-                Err(broadcast::RecvError::Closed) => {
+                Err(RecvError::Closed) => {
                     error!(
                         target: LOG_TARGET,
                         "Connectivity event stream closed unexpectedly. System may be shutting down."
                     );
                     break Err(ConnectivityError::ConnectivityEventStreamClosed);
                 },
-                Err(broadcast::RecvError::Lagged(n)) => {
+                Err(RecvError::Lagged(n)) => {
                     warn!(target: LOG_TARGET, "Lagging behind on {} connectivity event(s)", n);
                     // We lagged, so could have missed the state change. Check it explicitly.
                     let status = self.get_connectivity_status().await?;
diff --git a/comms/src/connectivity/selection.rs b/comms/src/connectivity/selection.rs
index 931b78163d..c300d2d353 100644
--- a/comms/src/connectivity/selection.rs
+++ b/comms/src/connectivity/selection.rs
@@ -135,8 +135,8 @@ mod test {
         peer_manager::node_id::NodeDistance,
         test_utils::{mocks::create_dummy_peer_connection, node_id, node_identity::build_node_identity},
     };
-    use futures::channel::mpsc;
     use std::iter::repeat_with;
+    use tokio::sync::mpsc;

     fn create_pool_with_connections(n: usize) -> (ConnectionPool, Vec>) {
         let mut pool = ConnectionPool::new();
diff --git a/comms/src/connectivity/test.rs b/comms/src/connectivity/test.rs
index a4fec1e896..948d083e94 100644
--- a/comms/src/connectivity/test.rs
+++ b/comms/src/connectivity/test.rs
@@ -28,6 +28,7 @@ use super::{
 };
 use crate::{
     connection_manager::{ConnectionManagerError, ConnectionManagerEvent},
+    connectivity::ConnectivityEventRx,
     peer_manager::{Peer, PeerFeatures},
     runtime,
     runtime::task,
@@ -39,18 +40,18 @@ use crate::{
     NodeIdentity,
     PeerManager,
 };
-use futures::{channel::mpsc, future};
+use futures::future;
 use std::{sync::Arc, time::Duration};
 use tari_shutdown::Shutdown;
-use tari_test_utils::{collect_stream, streams, unpack_enum};
-use tokio::sync::broadcast;
+use tari_test_utils::{collect_try_recv, streams, unpack_enum};
+use tokio::sync::{broadcast, mpsc};

 #[allow(clippy::type_complexity)]
 fn setup_connectivity_manager(
     config: ConnectivityConfig,
 ) -> (
     ConnectivityRequester,
-    broadcast::Receiver>,
+    ConnectivityEventRx,
     Arc,
     Arc,
     ConnectionManagerMockState,
@@ -100,7 +101,7 @@ async fn add_test_peers(peer_manager: &PeerManager, n: usize) -> Vec {
     peers
 }

-#[runtime::test_basic]
+#[runtime::test]
 async fn connecting_peers() {
     let (mut connectivity, mut event_stream, node_identity, peer_manager, cm_mock_state, _shutdown) =
         setup_connectivity_manager(Default::default());
@@ -117,15 +118,15 @@ async fn connecting_peers() {
         .map(|(_, _, conn, _)| conn)
         .collect::>();

-    let mut events = collect_stream!(event_stream, take = 1, timeout = Duration::from_secs(10));
-    unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = &*events.remove(0).unwrap());
+    let mut events = collect_try_recv!(event_stream, take = 1, timeout = Duration::from_secs(10));
+    unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = events.remove(0));

     // All connections succeeded
     for conn in &connections {
         cm_mock_state.publish_event(ConnectionManagerEvent::PeerConnected(conn.clone()));
     }

-    let _events = collect_stream!(event_stream, take = 11, timeout = Duration::from_secs(10));
+    let _events = collect_try_recv!(event_stream, take = 11, timeout = Duration::from_secs(10));

     let connection_states = connectivity.get_all_connection_states().await.unwrap();
     assert_eq!(connection_states.len(), 10);
@@ -135,7 +136,7 @@ async fn connecting_peers() {
     }
 }

-#[runtime::test_basic]
+#[runtime::test]
 async fn add_many_managed_peers() {
     let (mut connectivity, mut event_stream, node_identity, peer_manager, cm_mock_state, _shutdown) =
         setup_connectivity_manager(Default::default());
@@ -156,8 +157,8 @@ async fn add_many_managed_peers() {
         .await
         .unwrap();

-    let mut events = collect_stream!(event_stream, take = 1, timeout = Duration::from_secs(10));
-    unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = &*events.remove(0).unwrap());
+    let mut events = collect_try_recv!(event_stream, take = 1, timeout = Duration::from_secs(10));
+    unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = events.remove(0));

     // First 5 succeeded
     for conn in &connections {
@@ -172,10 +173,10 @@ async fn add_many_managed_peers() {
         ));
     }

-    let events = collect_stream!(event_stream, take = 9, timeout = Duration::from_secs(10));
+    let events = collect_try_recv!(event_stream, take = 9, timeout = Duration::from_secs(10));
     let n = events
         .iter()
-        .find_map(|event| match &**event.as_ref().unwrap() {
+        .find_map(|event| match event {
             ConnectivityEvent::ConnectivityStateOnline(n) => Some(n),
             ConnectivityEvent::ConnectivityStateDegraded(_) => None,
             ConnectivityEvent::PeerConnected(_) => None,
@@ -205,7 +206,7 @@ async fn add_many_managed_peers() {
     }
 }

-#[runtime::test_basic]
+#[runtime::test]
 async fn online_then_offline() {
     let (mut connectivity, mut event_stream, node_identity, peer_manager, cm_mock_state, _shutdown) =
         setup_connectivity_manager(Default::default());
@@ -244,8 +245,8 @@ async fn online_then_offline() {
         .await
         .unwrap();

-    let mut events = collect_stream!(event_stream, take = 1, timeout = Duration::from_secs(10));
-    unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = &*events.remove(0).unwrap());
+    let mut events = collect_try_recv!(event_stream, take = 1, timeout = Duration::from_secs(10));
+    unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = events.remove(0));

     for conn in connections.iter().skip(1) {
         cm_mock_state.publish_event(ConnectionManagerEvent::PeerConnected(conn.clone()));
@@ -269,9 +270,9 @@ async fn online_then_offline() {
         ));
     }

-    streams::assert_in_stream(
+    streams::assert_in_broadcast(
         &mut event_stream,
-        |item| match &*item.unwrap() {
+        |item| match item {
             ConnectivityEvent::ConnectivityStateDegraded(2) => Some(()),
             _ => None,
         },
@@ -289,9 +290,9 @@ async fn online_then_offline() {
         ));
     }

-    streams::assert_in_stream(
+    streams::assert_in_broadcast(
         &mut event_stream,
-        |item| match &*item.unwrap() {
+        |item| match item {
             ConnectivityEvent::ConnectivityStateOffline => Some(()),
             _ => None,
         },
@@ -303,20 +304,20 @@ async fn online_then_offline() {
     assert!(is_offline);
 }

-#[runtime::test_basic]
+#[runtime::test]
 async fn ban_peer() {
     let (mut connectivity, mut event_stream, node_identity, peer_manager, cm_mock_state, _shutdown) =
         setup_connectivity_manager(Default::default());
     let peer = add_test_peers(&peer_manager, 1).await.pop().unwrap();
     let (conn, _, _, _) = create_peer_connection_mock_pair(node_identity.to_peer(), peer.clone()).await;

-    let mut events = collect_stream!(event_stream, take = 1, timeout = Duration::from_secs(10));
-    unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = &*events.remove(0).unwrap());
+    let mut events = collect_try_recv!(event_stream, take = 1, timeout = Duration::from_secs(10));
+    unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = events.remove(0));

     cm_mock_state.publish_event(ConnectionManagerEvent::PeerConnected(conn.clone()));

-    let mut events = collect_stream!(event_stream, take = 2, timeout = Duration::from_secs(10));
-    unpack_enum!(ConnectivityEvent::PeerConnected(_conn) = &*events.remove(0).unwrap());
-    unpack_enum!(ConnectivityEvent::ConnectivityStateOnline(_n) = &*events.remove(0).unwrap());
+    let mut events = collect_try_recv!(event_stream, take = 2, timeout = Duration::from_secs(10));
+    unpack_enum!(ConnectivityEvent::PeerConnected(_conn) = events.remove(0));
+    unpack_enum!(ConnectivityEvent::ConnectivityStateOnline(_n) = events.remove(0));

     let conn = connectivity.get_connection(peer.node_id.clone()).await.unwrap();
     assert!(conn.is_some());
@@ -329,13 +330,12 @@ async fn ban_peer() {
     // We can always expect a single PeerBanned because we do not publish a disconnected event from the connection
     // manager In a real system, peer disconnect and peer banned events may happen in any order and should always be
     // completely fine.
-    let event = collect_stream!(event_stream, take = 1, timeout = Duration::from_secs(10))
+    let event = collect_try_recv!(event_stream, take = 1, timeout = Duration::from_secs(10))
         .pop()
-        .unwrap()
         .unwrap();
-    unpack_enum!(ConnectivityEvent::PeerBanned(node_id) = &*event);
-    assert_eq!(node_id, &peer.node_id);
+    unpack_enum!(ConnectivityEvent::PeerBanned(node_id) = event);
+    assert_eq!(node_id, peer.node_id);

     let peer = peer_manager.find_by_node_id(&peer.node_id).await.unwrap();
     assert!(peer.is_banned());
@@ -344,7 +344,7 @@ async fn ban_peer() {
     assert!(conn.is_none());
 }

-#[runtime::test_basic]
+#[runtime::test]
 async fn peer_selection() {
     let config = ConnectivityConfig {
         min_connectivity: 1.0,
@@ -370,15 +370,15 @@ async fn peer_selection() {
         .await
         .unwrap();

-    let mut events = collect_stream!(event_stream, take = 1, timeout = Duration::from_secs(10));
-    unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = &*events.remove(0).unwrap());
+    let mut events = collect_try_recv!(event_stream, take = 1, timeout = Duration::from_secs(10));
+    unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = events.remove(0));

     // 10 connections
     for conn in &connections {
         cm_mock_state.publish_event(ConnectionManagerEvent::PeerConnected(conn.clone()));
     }

     // Wait for all peers to be connected (i.e.
for the connection manager events to be received) - let mut _events = collect_stream!(event_stream, take = 12, timeout = Duration::from_secs(10)); + let mut _events = collect_try_recv!(event_stream, take = 12, timeout = Duration::from_secs(10)); let conns = connectivity .select_connections(ConnectivitySelection::random_nodes(10, vec![connections[0] diff --git a/comms/src/framing.rs b/comms/src/framing.rs index 1e6b67691e..06ccc00c30 100644 --- a/comms/src/framing.rs +++ b/comms/src/framing.rs @@ -20,17 +20,16 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::compat::IoCompat; -use futures::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::codec::{Framed, LengthDelimitedCodec}; /// Tari comms canonical framing -pub type CanonicalFraming = Framed, LengthDelimitedCodec>; +pub type CanonicalFraming = Framed; pub fn canonical(stream: T, max_frame_len: usize) -> CanonicalFraming where T: AsyncRead + AsyncWrite + Unpin { Framed::new( - IoCompat::new(stream), + stream, LengthDelimitedCodec::builder() .max_frame_length(max_frame_len) .new_codec(), diff --git a/comms/src/lib.rs b/comms/src/lib.rs index 4f3bb178d2..de48e316e5 100644 --- a/comms/src/lib.rs +++ b/comms/src/lib.rs @@ -32,21 +32,19 @@ pub use peer_manager::{NodeIdentity, PeerManager}; pub mod framing; -mod common; -pub use common::rate_limit; +pub mod rate_limit; mod multiplexing; pub use multiplexing::Substream; mod noise; mod proto; -mod runtime; pub mod backoff; pub mod bounded_executor; -pub mod compat; pub mod memsocket; pub mod protocol; +pub mod runtime; #[macro_use] pub mod message; pub mod net_address; diff --git a/comms/src/memsocket/mod.rs b/comms/src/memsocket/mod.rs index 3d102bf8dc..ed77fc6146 100644 --- a/comms/src/memsocket/mod.rs +++ b/comms/src/memsocket/mod.rs @@ -26,17 +26,21 @@ use bytes::{Buf, Bytes}; use 
futures::{ channel::mpsc::{self, UnboundedReceiver, UnboundedSender}, - io::{AsyncRead, AsyncWrite, Error, ErrorKind, Result}, ready, stream::{FusedStream, Stream}, task::{Context, Poll}, }; use std::{ + cmp, collections::{hash_map::Entry, HashMap}, num::NonZeroU16, pin::Pin, sync::Mutex, }; +use tokio::{ + io, + io::{AsyncRead, AsyncWrite, ErrorKind, ReadBuf}, +}; lazy_static! { static ref SWITCHBOARD: Mutex = Mutex::new(SwitchBoard(HashMap::default(), 1)); @@ -114,6 +118,7 @@ pub fn release_memsocket_port(port: NonZeroU16) { /// use std::io::Result; /// /// use tari_comms::memsocket::{MemoryListener, MemorySocket}; +/// use tokio::io::*; /// use futures::prelude::*; /// /// async fn write_stormlight(mut stream: MemorySocket) -> Result<()> { @@ -170,7 +175,7 @@ impl MemoryListener { /// ``` /// /// [`local_addr`]: #method.local_addr - pub fn bind(port: u16) -> Result { + pub fn bind(port: u16) -> io::Result { let mut switchboard = (&*SWITCHBOARD).lock().unwrap(); // Get the port we should bind to. 
If 0 was given, use a random port @@ -262,11 +267,11 @@ impl MemoryListener { Incoming { inner: self } } - fn poll_accept(&mut self, context: &mut Context) -> Poll> { + fn poll_accept(&mut self, context: &mut Context) -> Poll> { match Pin::new(&mut self.incoming).poll_next(context) { Poll::Ready(Some(socket)) => Poll::Ready(Ok(socket)), Poll::Ready(None) => { - let err = Error::new(ErrorKind::Other, "MemoryListener unknown error"); + let err = io::Error::new(ErrorKind::Other, "MemoryListener unknown error"); Poll::Ready(Err(err)) }, Poll::Pending => Poll::Pending, @@ -283,7 +288,7 @@ pub struct Incoming<'a> { } impl<'a> Stream for Incoming<'a> { - type Item = Result; + type Item = io::Result; fn poll_next(mut self: Pin<&mut Self>, context: &mut Context) -> Poll> { let socket = ready!(self.inner.poll_accept(context)?); @@ -302,6 +307,7 @@ impl<'a> Stream for Incoming<'a> { /// /// ```rust, no_run /// use futures::prelude::*; +/// use tokio::io::*; /// use tari_comms::memsocket::MemorySocket; /// /// # async fn run() -> ::std::io::Result<()> { @@ -371,7 +377,7 @@ impl MemorySocket { /// let socket = MemorySocket::connect(16)?; /// # Ok(())} /// ``` - pub fn connect(port: u16) -> Result { + pub fn connect(port: u16) -> io::Result { let mut switchboard = (&*SWITCHBOARD).lock().unwrap(); // Find port to connect to @@ -399,13 +405,13 @@ impl MemorySocket { impl AsyncRead for MemorySocket { /// Attempt to read from the `AsyncRead` into `buf`. 
- fn poll_read(mut self: Pin<&mut Self>, mut context: &mut Context, buf: &mut [u8]) -> Poll<Result<usize>> { + fn poll_read(mut self: Pin<&mut Self>, mut context: &mut Context, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> { if self.incoming.is_terminated() { if self.seen_eof { return Poll::Ready(Err(ErrorKind::UnexpectedEof.into())); } else { self.seen_eof = true; - return Poll::Ready(Ok(0)); + return Poll::Ready(Ok(())); } } @@ -413,22 +419,23 @@ impl AsyncRead for MemorySocket { loop { // If we're already filled up the buffer then we can return - if bytes_read == buf.len() { - return Poll::Ready(Ok(bytes_read)); + if buf.remaining() == 0 { + return Poll::Ready(Ok(())); } match self.current_buffer { // We have data to copy to buf Some(ref mut current_buffer) if !current_buffer.is_empty() => { - let bytes_to_read = ::std::cmp::min(buf.len() - bytes_read, current_buffer.len()); - debug_assert!(bytes_to_read > 0); - - buf[bytes_read..(bytes_read + bytes_to_read)] - .copy_from_slice(current_buffer.slice(0..bytes_to_read).as_ref()); + let bytes_to_read = cmp::min(buf.remaining(), current_buffer.len()); + if bytes_to_read > 0 { + buf.initialize_unfilled_to(bytes_to_read) + .copy_from_slice(&current_buffer.slice(..bytes_to_read)); + buf.advance(bytes_to_read); - current_buffer.advance(bytes_to_read); + current_buffer.advance(bytes_to_read); - bytes_read += bytes_to_read; + bytes_read += bytes_to_read; + } }, // Either we've exhausted our current buffer or don't have one @@ -438,13 +445,13 @@ impl AsyncRead for MemorySocket { Poll::Pending => { // If we've read anything up to this point return the bytes read if bytes_read > 0 { - return Poll::Ready(Ok(bytes_read)); + return Poll::Ready(Ok(())); } else { return Poll::Pending; } }, Poll::Ready(Some(buf)) => Some(buf), - Poll::Ready(None) => return Poll::Ready(Ok(bytes_read)), + Poll::Ready(None) => return Poll::Ready(Ok(())), } }; }, @@ -455,14 +462,14 @@ impl AsyncWrite for MemorySocket { /// Attempt to write bytes from
`buf` into the outgoing channel. - fn poll_write(mut self: Pin<&mut Self>, context: &mut Context, buf: &[u8]) -> Poll<Result<usize>> { + fn poll_write(mut self: Pin<&mut Self>, context: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> { let len = buf.len(); match self.outgoing.poll_ready(context) { Poll::Ready(Ok(())) => { if let Err(e) = self.outgoing.start_send(Bytes::copy_from_slice(buf)) { if e.is_disconnected() { - return Poll::Ready(Err(Error::new(ErrorKind::BrokenPipe, e))); + return Poll::Ready(Err(io::Error::new(ErrorKind::BrokenPipe, e))); } // Unbounded channels should only ever have "Disconnected" errors @@ -471,7 +478,7 @@ impl AsyncWrite for MemorySocket { }, Poll::Ready(Err(e)) => { if e.is_disconnected() { - return Poll::Ready(Err(Error::new(ErrorKind::BrokenPipe, e))); + return Poll::Ready(Err(io::Error::new(ErrorKind::BrokenPipe, e))); } // Unbounded channels should only ever have "Disconnected" errors @@ -484,12 +491,12 @@ impl AsyncWrite for MemorySocket { } /// Attempt to flush the channel. Cannot Fail. - fn poll_flush(self: Pin<&mut Self>, _context: &mut Context) -> Poll<Result<()>> { + fn poll_flush(self: Pin<&mut Self>, _context: &mut Context) -> Poll<io::Result<()>> { Poll::Ready(Ok(())) } /// Attempt to close the channel. Cannot Fail.
- fn poll_close(self: Pin<&mut Self>, _context: &mut Context) -> Poll<Result<()>> { + fn poll_shutdown(self: Pin<&mut Self>, _context: &mut Context) -> Poll<io::Result<()>> { self.outgoing.close_channel(); Poll::Ready(Ok(())) @@ -499,15 +506,12 @@ impl AsyncWrite for MemorySocket { #[cfg(test)] mod test { use super::*; - use futures::{ - executor::block_on, - io::{AsyncReadExt, AsyncWriteExt}, - stream::StreamExt, - }; - use std::io::Result; + use crate::runtime; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio_stream::StreamExt; #[test] - fn listener_bind() -> Result<()> { + fn listener_bind() -> io::Result<()> { let port = acquire_next_memsocket_port().into(); let listener = MemoryListener::bind(port)?; assert_eq!(listener.local_addr(), port); @@ -515,172 +519,187 @@ mod test { Ok(()) } - #[test] - fn simple_connect() -> Result<()> { + #[runtime::test] + async fn simple_connect() -> io::Result<()> { let port = acquire_next_memsocket_port().into(); let mut listener = MemoryListener::bind(port)?; let mut dialer = MemorySocket::connect(port)?; - let mut listener_socket = block_on(listener.incoming().next()).unwrap()?; + let mut listener_socket = listener.incoming().next().await.unwrap()?; - block_on(dialer.write_all(b"foo"))?; - block_on(dialer.flush())?; + dialer.write_all(b"foo").await?; + dialer.flush().await?; let mut buf = [0; 3]; - block_on(listener_socket.read_exact(&mut buf))?; + listener_socket.read_exact(&mut buf).await?; assert_eq!(&buf, b"foo"); Ok(()) } - #[test] - fn listen_on_port_zero() -> Result<()> { + #[runtime::test] + async fn listen_on_port_zero() -> io::Result<()> { let mut listener = MemoryListener::bind(0)?; let listener_addr = listener.local_addr(); let mut dialer = MemorySocket::connect(listener_addr)?; - let mut listener_socket = block_on(listener.incoming().next()).unwrap()?; + let mut listener_socket = listener.incoming().next().await.unwrap()?; - block_on(dialer.write_all(b"foo"))?; - block_on(dialer.flush())?; + dialer.write_all(b"foo").await?; +
dialer.flush().await?; let mut buf = [0; 3]; - block_on(listener_socket.read_exact(&mut buf))?; + listener_socket.read_exact(&mut buf).await?; assert_eq!(&buf, b"foo"); - block_on(listener_socket.write_all(b"bar"))?; - block_on(listener_socket.flush())?; + listener_socket.write_all(b"bar").await?; + listener_socket.flush().await?; let mut buf = [0; 3]; - block_on(dialer.read_exact(&mut buf))?; + dialer.read_exact(&mut buf).await?; assert_eq!(&buf, b"bar"); Ok(()) } - #[test] - fn listener_correctly_frees_port_on_drop() -> Result<()> { - fn connect_on_port(port: u16) -> Result<()> { - let mut listener = MemoryListener::bind(port)?; - let mut dialer = MemorySocket::connect(port)?; - let mut listener_socket = block_on(listener.incoming().next()).unwrap()?; + #[runtime::test] + async fn listener_correctly_frees_port_on_drop() { + async fn connect_on_port(port: u16) { + let mut listener = MemoryListener::bind(port).unwrap(); + let mut dialer = MemorySocket::connect(port).unwrap(); + let mut listener_socket = listener.incoming().next().await.unwrap().unwrap(); - block_on(dialer.write_all(b"foo"))?; - block_on(dialer.flush())?; + dialer.write_all(b"foo").await.unwrap(); + dialer.flush().await.unwrap(); let mut buf = [0; 3]; - block_on(listener_socket.read_exact(&mut buf))?; + let n = listener_socket.read_exact(&mut buf).await.unwrap(); + assert_eq!(n, 3); assert_eq!(&buf, b"foo"); - - Ok(()) } let port = acquire_next_memsocket_port().into(); - connect_on_port(port)?; - connect_on_port(port)?; - - Ok(()) + connect_on_port(port).await; + connect_on_port(port).await; } - #[test] - fn simple_write_read() -> Result<()> { + #[runtime::test] + async fn simple_write_read() -> io::Result<()> { let (mut a, mut b) = MemorySocket::new_pair(); - block_on(a.write_all(b"hello world"))?; - block_on(a.flush())?; + a.write_all(b"hello world").await?; + a.flush().await?; drop(a); let mut v = Vec::new(); - block_on(b.read_to_end(&mut v))?; + b.read_to_end(&mut v).await?; assert_eq!(v, 
b"hello world"); Ok(()) } - #[test] - fn partial_read() -> Result<()> { + #[runtime::test] + async fn partial_read() -> io::Result<()> { let (mut a, mut b) = MemorySocket::new_pair(); - block_on(a.write_all(b"foobar"))?; - block_on(a.flush())?; + a.write_all(b"foobar").await?; + a.flush().await?; let mut buf = [0; 3]; - block_on(b.read_exact(&mut buf))?; + b.read_exact(&mut buf).await?; assert_eq!(&buf, b"foo"); - block_on(b.read_exact(&mut buf))?; + b.read_exact(&mut buf).await?; assert_eq!(&buf, b"bar"); Ok(()) } - #[test] - fn partial_read_write_both_sides() -> Result<()> { + #[runtime::test] + async fn partial_read_write_both_sides() -> io::Result<()> { let (mut a, mut b) = MemorySocket::new_pair(); - block_on(a.write_all(b"foobar"))?; - block_on(a.flush())?; - block_on(b.write_all(b"stormlight"))?; - block_on(b.flush())?; + a.write_all(b"foobar").await?; + a.flush().await?; + b.write_all(b"stormlight").await?; + b.flush().await?; let mut buf_a = [0; 5]; let mut buf_b = [0; 3]; - block_on(a.read_exact(&mut buf_a))?; + a.read_exact(&mut buf_a).await?; assert_eq!(&buf_a, b"storm"); - block_on(b.read_exact(&mut buf_b))?; + b.read_exact(&mut buf_b).await?; assert_eq!(&buf_b, b"foo"); - block_on(a.read_exact(&mut buf_a))?; + a.read_exact(&mut buf_a).await?; assert_eq!(&buf_a, b"light"); - block_on(b.read_exact(&mut buf_b))?; + b.read_exact(&mut buf_b).await?; assert_eq!(&buf_b, b"bar"); Ok(()) } - #[test] - fn many_small_writes() -> Result<()> { + #[runtime::test] + async fn many_small_writes() -> io::Result<()> { let (mut a, mut b) = MemorySocket::new_pair(); - block_on(a.write_all(b"words"))?; - block_on(a.write_all(b" "))?; - block_on(a.write_all(b"of"))?; - block_on(a.write_all(b" "))?; - block_on(a.write_all(b"radiance"))?; - block_on(a.flush())?; + a.write_all(b"words").await?; + a.write_all(b" ").await?; + a.flush().await?; + a.write_all(b"of").await?; + a.write_all(b" ").await?; + a.flush().await?; + a.write_all(b"radiance").await?; + a.flush().await?; 
drop(a); let mut buf = [0; 17]; - block_on(b.read_exact(&mut buf))?; + b.read_exact(&mut buf).await?; assert_eq!(&buf, b"words of radiance"); Ok(()) } - #[test] - fn read_zero_bytes() -> Result<()> { + #[runtime::test] + async fn large_writes() -> io::Result<()> { + let (mut a, mut b) = MemorySocket::new_pair(); + + let large_data = vec![123u8; 1024]; + a.write_all(&large_data).await?; + a.flush().await?; + drop(a); + + let mut buf = Vec::new(); + b.read_to_end(&mut buf).await?; + assert_eq!(buf.len(), 1024); + + Ok(()) + } + + #[runtime::test] + async fn read_zero_bytes() -> io::Result<()> { let (mut a, mut b) = MemorySocket::new_pair(); - block_on(a.write_all(b"way of kings"))?; - block_on(a.flush())?; + a.write_all(b"way of kings").await?; + a.flush().await?; let mut buf = [0; 12]; - block_on(b.read_exact(&mut buf[0..0]))?; + b.read_exact(&mut buf[0..0]).await?; assert_eq!(buf, [0; 12]); - block_on(b.read_exact(&mut buf))?; + b.read_exact(&mut buf).await?; assert_eq!(&buf, b"way of kings"); Ok(()) } - #[test] - fn read_bytes_with_large_buffer() -> Result<()> { + #[runtime::test] + async fn read_bytes_with_large_buffer() -> io::Result<()> { let (mut a, mut b) = MemorySocket::new_pair(); - block_on(a.write_all(b"way of kings"))?; - block_on(a.flush())?; + a.write_all(b"way of kings").await?; + a.flush().await?; let mut buf = [0; 20]; - let bytes_read = block_on(b.read(&mut buf))?; + let bytes_read = b.read(&mut buf).await?; assert_eq!(bytes_read, 12); assert_eq!(&buf[0..12], b"way of kings"); diff --git a/comms/src/message/outbound.rs b/comms/src/message/outbound.rs index 25f0899143..a08c9604ce 100644 --- a/comms/src/message/outbound.rs +++ b/comms/src/message/outbound.rs @@ -22,11 +22,11 @@ use crate::{message::MessageTag, peer_manager::NodeId, protocol::messaging::SendFailReason}; use bytes::Bytes; -use futures::channel::oneshot; use std::{ fmt, fmt::{Error, Formatter}, }; +use tokio::sync::oneshot; pub type MessagingReplyResult = Result<(), SendFailReason>; pub 
type MessagingReplyRx = oneshot::Receiver; diff --git a/comms/src/multiplexing/yamux.rs b/comms/src/multiplexing/yamux.rs index 1ba104ce04..17558133f2 100644 --- a/comms/src/multiplexing/yamux.rs +++ b/comms/src/multiplexing/yamux.rs @@ -21,17 +21,14 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{connection_manager::ConnectionDirection, runtime}; -use futures::{ - channel::mpsc, - io::{AsyncRead, AsyncWrite}, - stream::FusedStream, - task::Context, - SinkExt, - Stream, - StreamExt, -}; +use futures::{task::Context, Stream}; use std::{future::Future, io, pin::Pin, sync::Arc, task::Poll}; use tari_shutdown::{Shutdown, ShutdownSignal}; +use tokio::{ + io::{AsyncRead, AsyncWrite, ReadBuf}, + sync::mpsc, +}; +use tokio_util::compat::{Compat, FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; use tracing::{self, debug, error, event, Level}; use yamux::Mode; @@ -70,7 +67,7 @@ impl Yamux { config.set_receive_window(RECEIVE_WINDOW); let substream_counter = SubstreamCounter::new(); - let connection = yamux::Connection::new(socket, config, mode); + let connection = yamux::Connection::new(socket.compat(), config, mode); let control = Control::new(connection.control(), substream_counter.clone()); let incoming = Self::spawn_incoming_stream_worker(connection, substream_counter.clone()); @@ -88,12 +85,11 @@ impl Yamux { counter: SubstreamCounter, ) -> IncomingSubstreams where - TSocket: AsyncRead + AsyncWrite + Unpin + Send + 'static, + TSocket: futures::AsyncRead + futures::AsyncWrite + Unpin + Send + 'static, { let shutdown = Shutdown::new(); let (incoming_tx, incoming_rx) = mpsc::channel(10); - let stream = yamux::into_stream(connection).boxed(); - let incoming = IncomingWorker::new(stream, incoming_tx, shutdown.to_signal()); + let incoming = IncomingWorker::new(connection, incoming_tx, shutdown.to_signal()); runtime::task::spawn(incoming.run()); IncomingSubstreams::new(incoming_rx, counter, shutdown) } @@ -122,10 +118,6 @@ impl 
Yamux { pub(crate) fn substream_counter(&self) -> SubstreamCounter { self.substream_counter.clone() } - - pub fn is_terminated(&self) -> bool { - self.incoming.is_terminated() - } } #[derive(Clone)] @@ -146,7 +138,7 @@ impl Control { pub async fn open_stream(&mut self) -> Result { let stream = self.inner.open_stream().await?; Ok(Substream { - stream, + stream: stream.compat(), counter_guard: self.substream_counter.new_guard(), }) } @@ -185,19 +177,13 @@ impl IncomingSubstreams { } } -impl FusedStream for IncomingSubstreams { - fn is_terminated(&self) -> bool { - self.inner.is_terminated() - } -} - impl Stream for IncomingSubstreams { type Item = Substream; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match futures::ready!(Pin::new(&mut self.inner).poll_next(cx)) { + match futures::ready!(Pin::new(&mut self.inner).poll_recv(cx)) { Some(stream) => Poll::Ready(Some(Substream { - stream, + stream: stream.compat(), counter_guard: self.substream_counter.new_guard(), })), None => Poll::Ready(None), @@ -213,17 +199,17 @@ impl Drop for IncomingSubstreams { #[derive(Debug)] pub struct Substream { - stream: yamux::Stream, + stream: Compat, counter_guard: CounterGuard, } -impl AsyncRead for Substream { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { +impl tokio::io::AsyncRead for Substream { + fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { Pin::new(&mut self.stream).poll_read(cx, buf) } } -impl AsyncWrite for Substream { +impl tokio::io::AsyncWrite for Substream { fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { Pin::new(&mut self.stream).poll_write(cx, buf) } @@ -232,23 +218,23 @@ impl AsyncWrite for Substream { Pin::new(&mut self.stream).poll_flush(cx) } - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.stream).poll_close(cx) + fn poll_shutdown(mut self: Pin<&mut Self>, cx: 
&mut Context<'_>) -> Poll> { + Pin::new(&mut self.stream).poll_shutdown(cx) } } -struct IncomingWorker { - inner: S, +struct IncomingWorker { + connection: yamux::Connection, sender: mpsc::Sender, shutdown_signal: ShutdownSignal, } -impl IncomingWorker -where S: Stream> + Unpin +impl IncomingWorker +where TSocket: futures::AsyncRead + futures::AsyncWrite + Unpin + Send + 'static /* */ { - pub fn new(stream: S, sender: IncomingTx, shutdown_signal: ShutdownSignal) -> Self { + pub fn new(connection: yamux::Connection, sender: IncomingTx, shutdown_signal: ShutdownSignal) -> Self { Self { - inner: stream, + connection, sender, shutdown_signal, } @@ -256,37 +242,55 @@ where S: Stream> + Unpin #[tracing::instrument(name = "yamux::incoming_worker::run", skip(self))] pub async fn run(mut self) { - let mut mux_stream = self.inner.take_until(&mut self.shutdown_signal); - while let Some(result) = mux_stream.next().await { - match result { - Ok(stream) => { - event!(Level::TRACE, "yamux::stream received {}", stream); - if self.sender.send(stream).await.is_err() { - debug!( - target: LOG_TARGET, - "Incoming peer substream task is shutting down because the internal stream sender channel \ - was closed" - ); - break; + loop { + tokio::select! { + biased; + + _ = &mut self.shutdown_signal => { + let mut control = self.connection.control(); + if let Err(err) = control.close().await { + error!(target: LOG_TARGET, "Failed to close yamux connection: {}", err); } - }, - Err(err) => { - event!( + break + } + + result = self.connection.next_stream() => { + match result { + Ok(Some(stream)) => { + event!(Level::TRACE, "yamux::stream received {}", stream);if self.sender.send(stream).await.is_err() { + debug!( + target: LOG_TARGET, + "Incoming peer substream task is shutting down because the internal stream sender channel \ + was closed" + ); + break; + } + }, + Ok(None) =>{ + debug!( + target: LOG_TARGET, + "Incoming peer substream completed. 
IncomingWorker exiting" + ); + break; + } + Err(err) => { + event!( Level::ERROR, "Incoming peer substream task received an error because '{}'", err ); error!( - target: LOG_TARGET, - "Incoming peer substream task received an error because '{}'", err - ); - break; - }, + target: LOG_TARGET, + "Incoming peer substream task received an error because '{}'", err + ); + break; + }, + } + } } } debug!(target: LOG_TARGET, "Incoming peer substream task is shutting down"); - self.sender.close_channel(); } } @@ -321,15 +325,12 @@ mod test { runtime, runtime::task, }; - use futures::{ - future, - io::{AsyncReadExt, AsyncWriteExt}, - StreamExt, - }; use std::{io, time::Duration}; use tari_test_utils::collect_stream; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio_stream::StreamExt; - #[runtime::test_basic] + #[runtime::test] async fn open_substream() -> io::Result<()> { let (dialer, listener) = MemorySocket::new_pair(); let msg = b"The Way of Kings"; @@ -344,7 +345,7 @@ mod test { substream.write_all(msg).await.unwrap(); substream.flush().await.unwrap(); - substream.close().await.unwrap(); + substream.shutdown().await.unwrap(); }); let mut listener = Yamux::upgrade_connection(listener, ConnectionDirection::Inbound) @@ -356,13 +357,16 @@ mod test { .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no substream"))?; let mut buf = Vec::new(); - let _ = future::select(substream.read_to_end(&mut buf), listener.next()).await; + tokio::select! 
{ + _ = substream.read_to_end(&mut buf) => {}, + _ = listener.next() => {}, + }; assert_eq!(buf, msg); Ok(()) } - #[runtime::test_basic] + #[runtime::test] async fn substream_count() { const NUM_SUBSTREAMS: usize = 10; let (dialer, listener) = MemorySocket::new_pair(); @@ -396,7 +400,7 @@ mod test { assert_eq!(listener.substream_count(), 0); } - #[runtime::test_basic] + #[runtime::test] async fn close() -> io::Result<()> { let (dialer, listener) = MemorySocket::new_pair(); let msg = b"Words of Radiance"; @@ -425,7 +429,7 @@ mod test { assert_eq!(buf, msg); // Close the substream and then try to write to it - substream.close().await?; + substream.shutdown().await?; let result = substream.write_all(b"ignored message").await; match result { @@ -436,7 +440,7 @@ mod test { Ok(()) } - #[runtime::test_basic] + #[runtime::test] async fn send_big_message() -> io::Result<()> { #[allow(non_upper_case_globals)] static MiB: usize = 1 << 20; @@ -457,7 +461,7 @@ mod test { let mut buf = vec![0u8; MSG_LEN]; substream.read_exact(&mut buf).await.unwrap(); - substream.close().await.unwrap(); + substream.shutdown().await.unwrap(); assert_eq!(buf.len(), MSG_LEN); assert_eq!(buf, vec![0xAAu8; MSG_LEN]); @@ -476,7 +480,7 @@ mod test { let msg = vec![0xAAu8; MSG_LEN]; substream.write_all(msg.as_slice()).await?; - substream.close().await?; + substream.shutdown().await?; drop(substream); assert_eq!(incoming.substream_count(), 0); diff --git a/comms/src/noise/config.rs b/comms/src/noise/config.rs index 30cab07c48..23e1d34ec3 100644 --- a/comms/src/noise/config.rs +++ b/comms/src/noise/config.rs @@ -31,11 +31,11 @@ use crate::{ }, peer_manager::NodeIdentity, }; -use futures::{AsyncRead, AsyncWrite}; use log::*; use snow::{self, params::NoiseParams}; use std::sync::Arc; use tari_crypto::tari_utilities::ByteArray; +use tokio::io::{AsyncRead, AsyncWrite}; const LOG_TARGET: &str = "comms::noise"; pub(super) const NOISE_IX_PARAMETER: &str = "Noise_IX_25519_ChaChaPoly_BLAKE2b"; @@ -60,7 +60,7 @@ 
impl NoiseConfig { /// Upgrades the given socket to using the noise protocol. The upgraded socket and the peer's static key /// is returned. - #[tracing::instrument(name = "noise::upgrade_socket", skip(self, socket), err)] + #[tracing::instrument(name = "noise::upgrade_socket", skip(self, socket))] pub async fn upgrade_socket( &self, socket: TSocket, @@ -96,10 +96,15 @@ impl NoiseConfig { #[cfg(test)] mod test { use super::*; - use crate::{memsocket::MemorySocket, peer_manager::PeerFeatures, test_utils::node_identity::build_node_identity}; - use futures::{future, AsyncReadExt, AsyncWriteExt, FutureExt}; + use crate::{ + memsocket::MemorySocket, + peer_manager::PeerFeatures, + runtime, + test_utils::node_identity::build_node_identity, + }; + use futures::{future, FutureExt}; use snow::params::{BaseChoice, CipherChoice, DHChoice, HandshakePattern, HashChoice}; - use tokio::runtime::Runtime; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; fn check_noise_params(config: &NoiseConfig) { assert_eq!(config.parameters.hash, HashChoice::Blake2b); @@ -118,39 +123,35 @@ mod test { assert_eq!(config.node_identity.public_key(), node_identity.public_key()); } - #[test] - fn upgrade_socket() { - let mut rt = Runtime::new().unwrap(); - + #[runtime::test] + async fn upgrade_socket() { let node_identity1 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); let config1 = NoiseConfig::new(node_identity1.clone()); let node_identity2 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); let config2 = NoiseConfig::new(node_identity2.clone()); - rt.block_on(async move { - let (in_socket, out_socket) = MemorySocket::new_pair(); - let (mut socket_in, mut socket_out) = future::join( - config1.upgrade_socket(in_socket, ConnectionDirection::Inbound), - config2.upgrade_socket(out_socket, ConnectionDirection::Outbound), - ) - .map(|(s1, s2)| (s1.unwrap(), s2.unwrap())) - .await; - - let in_pubkey = socket_in.get_remote_public_key().unwrap(); - let out_pubkey = 
socket_out.get_remote_public_key().unwrap(); - - assert_eq!(&in_pubkey, node_identity2.public_key()); - assert_eq!(&out_pubkey, node_identity1.public_key()); - - let sample = b"Children of time"; - socket_in.write_all(sample).await.unwrap(); - socket_in.flush().await.unwrap(); - socket_in.close().await.unwrap(); - - let mut read_buf = Vec::with_capacity(16); - socket_out.read_to_end(&mut read_buf).await.unwrap(); - assert_eq!(read_buf, sample); - }); + let (in_socket, out_socket) = MemorySocket::new_pair(); + let (mut socket_in, mut socket_out) = future::join( + config1.upgrade_socket(in_socket, ConnectionDirection::Inbound), + config2.upgrade_socket(out_socket, ConnectionDirection::Outbound), + ) + .map(|(s1, s2)| (s1.unwrap(), s2.unwrap())) + .await; + + let in_pubkey = socket_in.get_remote_public_key().unwrap(); + let out_pubkey = socket_out.get_remote_public_key().unwrap(); + + assert_eq!(&in_pubkey, node_identity2.public_key()); + assert_eq!(&out_pubkey, node_identity1.public_key()); + + let sample = b"Children of time"; + socket_in.write_all(sample).await.unwrap(); + socket_in.flush().await.unwrap(); + socket_in.shutdown().await.unwrap(); + + let mut read_buf = Vec::with_capacity(16); + socket_out.read_to_end(&mut read_buf).await.unwrap(); + assert_eq!(read_buf, sample); } } diff --git a/comms/src/noise/socket.rs b/comms/src/noise/socket.rs index eaf02a60b0..33d35f89ca 100644 --- a/comms/src/noise/socket.rs +++ b/comms/src/noise/socket.rs @@ -26,27 +26,27 @@ //! 
Noise Socket +use crate::types::CommsPublicKey; use futures::ready; use log::*; use snow::{error::StateProblem, HandshakeState, TransportState}; use std::{ + cmp, convert::TryInto, io, pin::Pin, task::{Context, Poll}, }; -// use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use crate::types::CommsPublicKey; -use futures::{io::Error, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tari_crypto::tari_utilities::ByteArray; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; const LOG_TARGET: &str = "comms::noise::socket"; -const MAX_PAYLOAD_LENGTH: usize = u16::max_value() as usize; // 65535 +const MAX_PAYLOAD_LENGTH: usize = u16::MAX as usize; // 65535 // The maximum number of bytes that we can buffer is 16 bytes less than u16::max_value() because // encrypted messages include a tag along with the payload. -const MAX_WRITE_BUFFER_LENGTH: usize = u16::max_value() as usize - 16; // 65519 +const MAX_WRITE_BUFFER_LENGTH: usize = u16::MAX as usize - 16; // 65519 /// Collection of buffers used for buffering data during the various read/write states of a /// NoiseSocket @@ -223,7 +223,12 @@ where TSocket: AsyncRead, { loop { - let n = ready!(socket.as_mut().poll_read(&mut context, &mut buf[*offset..]))?; + let mut read_buf = ReadBuf::new(&mut buf[*offset..]); + let prev_rem = read_buf.remaining(); + ready!(socket.as_mut().poll_read(&mut context, &mut read_buf))?; + let n = prev_rem + .checked_sub(read_buf.remaining()) + .expect("buffer underflow: prev_rem < read_buf.remaining()"); trace!( target: LOG_TARGET, "poll_read_exact: read {}/{} bytes", @@ -320,7 +325,7 @@ where TSocket: AsyncRead + Unpin decrypted_len, ref mut offset, } => { - let bytes_to_copy = ::std::cmp::min(decrypted_len as usize - *offset, buf.len()); + let bytes_to_copy = cmp::min(decrypted_len as usize - *offset, buf.len()); buf[..bytes_to_copy] .copy_from_slice(&self.buffers.read_decrypted[*offset..(*offset + bytes_to_copy)]); trace!( @@ -351,8 +356,11 
@@ where TSocket: AsyncRead + Unpin impl AsyncRead for NoiseSocket where TSocket: AsyncRead + Unpin { - fn poll_read(self: Pin<&mut Self>, context: &mut Context, buf: &mut [u8]) -> Poll> { - self.get_mut().poll_read(context, buf) + fn poll_read(self: Pin<&mut Self>, context: &mut Context, buf: &mut ReadBuf<'_>) -> Poll> { + let slice = buf.initialize_unfilled(); + let n = futures::ready!(self.get_mut().poll_read(context, slice))?; + buf.advance(n); + Poll::Ready(Ok(())) } } @@ -501,8 +509,8 @@ where TSocket: AsyncWrite + Unpin self.get_mut().poll_flush(cx) } - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.socket).poll_close(cx) + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.socket).poll_shutdown(cx) } } @@ -531,7 +539,7 @@ where TSocket: AsyncRead + AsyncWrite + Unpin target: LOG_TARGET, "Noise handshake failed because '{:?}'. Closing socket.", err ); - self.socket.close().await?; + self.socket.shutdown().await?; Err(err) }, } @@ -644,7 +652,6 @@ mod test { use futures::future::join; use snow::{params::NoiseParams, Builder, Error, Keypair}; use std::io; - use tokio::runtime::Runtime; async fn build_test_connection( ) -> Result<((Keypair, Handshake), (Keypair, Handshake)), Error> { @@ -707,7 +714,7 @@ mod test { dialer_socket.write_all(b" ").await?; dialer_socket.write_all(b"archive").await?; dialer_socket.flush().await?; - dialer_socket.close().await?; + dialer_socket.shutdown().await?; let mut buf = Vec::new(); listener_socket.read_to_end(&mut buf).await?; @@ -745,51 +752,60 @@ mod test { Ok(()) } - #[test] - fn u16_max_writes() -> io::Result<()> { - // Current thread runtime stack overflows, so the full tokio runtime is used here - let mut rt = Runtime::new().unwrap(); - rt.block_on(async move { - let ((_dialer_keypair, dialer), (_listener_keypair, listener)) = build_test_connection().await.unwrap(); + #[runtime::test] + async fn u16_max_writes() -> io::Result<()> 
{ + let ((_dialer_keypair, dialer), (_listener_keypair, listener)) = build_test_connection().await.unwrap(); - let (mut a, mut b) = perform_handshake(dialer, listener).await?; + let (mut a, mut b) = perform_handshake(dialer, listener).await?; - let buf_send = [1; MAX_PAYLOAD_LENGTH]; - a.write_all(&buf_send).await?; - a.flush().await?; + let buf_send = [1; MAX_PAYLOAD_LENGTH + 1]; + a.write_all(&buf_send).await?; + a.flush().await?; - let mut buf_receive = [0; MAX_PAYLOAD_LENGTH]; - b.read_exact(&mut buf_receive).await?; - assert_eq!(&buf_receive[..], &buf_send[..]); + let mut buf_receive = [0; MAX_PAYLOAD_LENGTH + 1]; + b.read_exact(&mut buf_receive).await?; + assert_eq!(&buf_receive[..], &buf_send[..]); - Ok(()) - }) + Ok(()) } - #[test] - fn unexpected_eof() -> io::Result<()> { - // Current thread runtime stack overflows, so the full tokio runtime is used here - let mut rt = Runtime::new().unwrap(); - rt.block_on(async move { - let ((_dialer_keypair, dialer), (_listener_keypair, listener)) = build_test_connection().await.unwrap(); + #[runtime::test] + async fn larger_writes() -> io::Result<()> { + let ((_dialer_keypair, dialer), (_listener_keypair, listener)) = build_test_connection().await.unwrap(); - let (mut a, mut b) = perform_handshake(dialer, listener).await?; + let (mut a, mut b) = perform_handshake(dialer, listener).await?; - let buf_send = [1; MAX_PAYLOAD_LENGTH]; - a.write_all(&buf_send).await?; - a.flush().await?; + let buf_send = [1; MAX_PAYLOAD_LENGTH * 2 + 1024]; + a.write_all(&buf_send).await?; + a.flush().await?; - a.socket.close().await.unwrap(); - drop(a); + let mut buf_receive = [0; MAX_PAYLOAD_LENGTH * 2 + 1024]; + b.read_exact(&mut buf_receive).await?; + assert_eq!(&buf_receive[..], &buf_send[..]); - let mut buf_receive = [0; MAX_PAYLOAD_LENGTH]; - b.read_exact(&mut buf_receive).await.unwrap(); - assert_eq!(&buf_receive[..], &buf_send[..]); + Ok(()) + } + + #[runtime::test] + async fn unexpected_eof() -> io::Result<()> { + let 
((_dialer_keypair, dialer), (_listener_keypair, listener)) = build_test_connection().await.unwrap(); - let err = b.read_exact(&mut buf_receive).await.unwrap_err(); - assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof); + let (mut a, mut b) = perform_handshake(dialer, listener).await?; - Ok(()) - }) + let buf_send = [1; MAX_PAYLOAD_LENGTH]; + a.write_all(&buf_send).await?; + a.flush().await?; + + a.socket.shutdown().await.unwrap(); + drop(a); + + let mut buf_receive = [0; MAX_PAYLOAD_LENGTH]; + b.read_exact(&mut buf_receive).await.unwrap(); + assert_eq!(&buf_receive[..], &buf_send[..]); + + let err = b.read_exact(&mut buf_receive).await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof); + + Ok(()) } } diff --git a/comms/src/peer_manager/manager.rs b/comms/src/peer_manager/manager.rs index 3828c82ab5..d7df51a2ea 100644 --- a/comms/src/peer_manager/manager.rs +++ b/comms/src/peer_manager/manager.rs @@ -337,7 +337,7 @@ mod test { peer } - #[runtime::test_basic] + #[runtime::test] async fn get_broadcast_identities() { // Create peer manager with random peers let peer_manager = PeerManager::new(HashmapDatabase::new(), None).unwrap(); @@ -446,7 +446,7 @@ mod test { assert_ne!(identities1, identities2); } - #[runtime::test_basic] + #[runtime::test] async fn calc_region_threshold() { let n = 5; // Create peer manager with random peers @@ -514,7 +514,7 @@ mod test { } } - #[runtime::test_basic] + #[runtime::test] async fn closest_peers() { let n = 5; // Create peer manager with random peers @@ -548,7 +548,7 @@ mod test { } } - #[runtime::test_basic] + #[runtime::test] async fn add_or_update_online_peer() { let peer_manager = PeerManager::new(HashmapDatabase::new(), None).unwrap(); let mut peer = create_test_peer(false, PeerFeatures::COMMUNICATION_NODE); diff --git a/comms/src/pipeline/builder.rs b/comms/src/pipeline/builder.rs index 40a38d10a3..9ae90fabec 100644 --- a/comms/src/pipeline/builder.rs +++ b/comms/src/pipeline/builder.rs @@ -24,8 +24,8 @@ use 
crate::{ message::{InboundMessage, OutboundMessage}, pipeline::SinkService, }; -use futures::channel::mpsc; use thiserror::Error; +use tokio::sync::mpsc; use tower::Service; const DEFAULT_MAX_CONCURRENT_TASKS: usize = 50; @@ -99,9 +99,7 @@ where TOutSvc: Service + Clone + Send + 'static, TInSvc: Service + Clone + Send + 'static, { - fn build_outbound( - &mut self, - ) -> Result, TOutSvc>, PipelineBuilderError> { + fn build_outbound(&mut self) -> Result, PipelineBuilderError> { let (out_sender, out_receiver) = mpsc::channel(self.outbound_buffer_size); let in_receiver = self @@ -137,9 +135,9 @@ where } } -pub struct OutboundPipelineConfig { +pub struct OutboundPipelineConfig { /// Messages read from this stream are passed to the pipeline - pub in_receiver: TInStream, + pub in_receiver: mpsc::Receiver, /// Receiver of `OutboundMessage`s coming from the pipeline pub out_receiver: mpsc::Receiver, /// The pipeline (`tower::Service`) to run for each in_stream message @@ -149,7 +147,7 @@ pub struct OutboundPipelineConfig { pub struct Config { pub max_concurrent_inbound_tasks: usize, pub inbound: TInSvc, - pub outbound: OutboundPipelineConfig, TOutSvc>, + pub outbound: OutboundPipelineConfig, } #[derive(Debug, Error)] diff --git a/comms/src/pipeline/inbound.rs b/comms/src/pipeline/inbound.rs index 0b2116bc37..1f135640a7 100644 --- a/comms/src/pipeline/inbound.rs +++ b/comms/src/pipeline/inbound.rs @@ -21,10 +21,11 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
use crate::bounded_executor::BoundedExecutor; -use futures::{future::FusedFuture, stream::FusedStream, Stream, StreamExt}; +use futures::future::FusedFuture; use log::*; use std::fmt::Display; use tari_shutdown::ShutdownSignal; +use tokio::sync::mpsc; use tower::{Service, ServiceExt}; const LOG_TARGET: &str = "comms::pipeline::inbound"; @@ -33,22 +34,26 @@ const LOG_TARGET: &str = "comms::pipeline::inbound"; /// The difference between this can ServiceExt::call_all is /// that ServicePipeline doesn't keep the result of the service /// call and that it spawns a task for each incoming item. -pub struct Inbound { +pub struct Inbound { executor: BoundedExecutor, service: TSvc, - stream: TStream, + stream: mpsc::Receiver, shutdown_signal: ShutdownSignal, } -impl Inbound +impl Inbound where - TStream: Stream + FusedStream + Unpin, - TStream::Item: Send + 'static, - TSvc: Service + Clone + Send + 'static, + TMsg: Send + 'static, + TSvc: Service + Clone + Send + 'static, TSvc::Error: Display + Send, TSvc::Future: Send, { - pub fn new(executor: BoundedExecutor, stream: TStream, service: TSvc, shutdown_signal: ShutdownSignal) -> Self { + pub fn new( + executor: BoundedExecutor, + stream: mpsc::Receiver, + service: TSvc, + shutdown_signal: ShutdownSignal, + ) -> Self { Self { executor, service, @@ -59,7 +64,7 @@ where } pub async fn run(mut self) { - while let Some(item) = self.stream.next().await { + while let Some(item) = self.stream.recv().await { // Check if the shutdown signal has been triggered. // If there are messages in the stream, drop them. Otherwise the stream is empty, // it will return None and the while loop will end. 
@@ -100,21 +105,25 @@ where mod test { use super::*; use crate::runtime; - use futures::{channel::mpsc, future, stream}; + use futures::future; use std::time::Duration; use tari_shutdown::Shutdown; - use tari_test_utils::collect_stream; - use tokio::{runtime::Handle, time}; + use tari_test_utils::collect_recv; + use tokio::{sync::mpsc, time}; use tower::service_fn; - #[runtime::test_basic] + #[runtime::test] async fn run() { let items = vec![1, 2, 3, 4, 5, 6]; - let stream = stream::iter(items.clone()).fuse(); + let (tx, mut stream) = mpsc::channel(items.len()); + for i in items.clone() { + tx.send(i).await.unwrap(); + } + stream.close(); - let (mut out_tx, mut out_rx) = mpsc::channel(items.len()); + let (out_tx, mut out_rx) = mpsc::channel(items.len()); - let executor = Handle::current(); + let executor = runtime::current(); let shutdown = Shutdown::new(); let pipeline = Inbound::new( BoundedExecutor::new(executor.clone(), 1), @@ -125,9 +134,10 @@ mod test { }), shutdown.to_signal(), ); + let spawned_task = executor.spawn(pipeline.run()); - let received = collect_stream!(out_rx, take = items.len(), timeout = Duration::from_secs(10)); + let received = collect_recv!(out_rx, take = items.len(), timeout = Duration::from_secs(10)); assert!(received.iter().all(|i| items.contains(i))); // Check that this task ends because the stream has closed diff --git a/comms/src/pipeline/mod.rs b/comms/src/pipeline/mod.rs index 1039374c65..a2a057354b 100644 --- a/comms/src/pipeline/mod.rs +++ b/comms/src/pipeline/mod.rs @@ -44,7 +44,4 @@ pub(crate) use inbound::Inbound; mod outbound; pub(crate) use outbound::Outbound; -mod translate_sink; -pub use translate_sink::TranslateSink; - pub type PipelineError = anyhow::Error; diff --git a/comms/src/pipeline/outbound.rs b/comms/src/pipeline/outbound.rs index c860166ad0..54facdb92b 100644 --- a/comms/src/pipeline/outbound.rs +++ b/comms/src/pipeline/outbound.rs @@ -25,34 +25,33 @@ use crate::{ pipeline::builder::OutboundPipelineConfig, 
protocol::messaging::MessagingRequest, }; -use futures::{channel::mpsc, future, future::Either, stream::FusedStream, SinkExt, Stream, StreamExt}; +use futures::future::Either; use log::*; use std::fmt::Display; -use tokio::runtime; +use tokio::{runtime, sync::mpsc}; use tower::{Service, ServiceExt}; const LOG_TARGET: &str = "comms::pipeline::outbound"; -pub struct Outbound { +pub struct Outbound { /// Executor used to spawn a pipeline for each received item on the stream executor: runtime::Handle, /// Outbound pipeline configuration containing the pipeline and it's in and out streams - config: OutboundPipelineConfig, + config: OutboundPipelineConfig, /// Request sender for Messaging messaging_request_tx: mpsc::Sender, } -impl Outbound +impl Outbound where - TStream: Stream + FusedStream + Unpin, - TStream::Item: Send + 'static, - TPipeline: Service + Clone + Send + 'static, + TItem: Send + 'static, + TPipeline: Service + Clone + Send + 'static, TPipeline::Error: Display + Send, TPipeline::Future: Send, { pub fn new( executor: runtime::Handle, - config: OutboundPipelineConfig, + config: OutboundPipelineConfig, messaging_request_tx: mpsc::Sender, ) -> Self { Self { @@ -64,10 +63,13 @@ where pub async fn run(mut self) { loop { - let either = future::select(self.config.in_receiver.next(), self.config.out_receiver.next()).await; + let either = tokio::select! { + next = self.config.in_receiver.recv() => Either::Left(next), + next = self.config.out_receiver.recv() => Either::Right(next) + }; match either { // Pipeline IN received a message. 
Spawn a new task for the pipeline - Either::Left((Some(msg), _)) => { + Either::Left(Some(msg)) => { let pipeline = self.config.pipeline.clone(); self.executor.spawn(async move { if let Err(err) = pipeline.oneshot(msg).await { @@ -76,7 +78,7 @@ where }); }, // Pipeline IN channel closed - Either::Left((None, _)) => { + Either::Left(None) => { info!( target: LOG_TARGET, "Outbound pipeline is shutting down because the in channel closed" @@ -84,7 +86,7 @@ where break; }, // Pipeline OUT received a message - Either::Right((Some(out_msg), _)) => { + Either::Right(Some(out_msg)) => { if self.messaging_request_tx.is_closed() { // MessagingRequest channel closed break; @@ -92,7 +94,7 @@ where self.send_messaging_request(out_msg).await; }, // Pipeline OUT channel closed - Either::Right((None, _)) => { + Either::Right(None) => { info!( target: LOG_TARGET, "Outbound pipeline is shutting down because the out channel closed" @@ -117,19 +119,22 @@ where #[cfg(test)] mod test { use super::*; - use crate::{pipeline::SinkService, runtime}; + use crate::{pipeline::SinkService, runtime, utils}; use bytes::Bytes; - use futures::stream; use std::time::Duration; - use tari_test_utils::{collect_stream, unpack_enum}; + use tari_test_utils::{collect_recv, unpack_enum}; use tokio::{runtime::Handle, time}; - #[runtime::test_basic] + #[runtime::test] async fn run() { const NUM_ITEMS: usize = 10; - let items = - (0..NUM_ITEMS).map(|i| OutboundMessage::new(Default::default(), Bytes::copy_from_slice(&i.to_be_bytes()))); - let stream = stream::iter(items).fuse(); + let (tx, in_receiver) = mpsc::channel(NUM_ITEMS); + utils::mpsc::send_all( + &tx, + (0..NUM_ITEMS).map(|i| OutboundMessage::new(Default::default(), Bytes::copy_from_slice(&i.to_be_bytes()))), + ) + .await + .unwrap(); let (out_tx, out_rx) = mpsc::channel(NUM_ITEMS); let (msg_tx, mut msg_rx) = mpsc::channel(NUM_ITEMS); let executor = Handle::current(); @@ -137,7 +142,7 @@ mod test { let pipeline = Outbound::new( executor.clone(), 
OutboundPipelineConfig { - in_receiver: stream, + in_receiver, out_receiver: out_rx, pipeline: SinkService::new(out_tx), }, @@ -146,7 +151,8 @@ mod test { let spawned_task = executor.spawn(pipeline.run()); - let requests = collect_stream!(msg_rx, take = NUM_ITEMS, timeout = Duration::from_millis(5)); + msg_rx.close(); + let requests = collect_recv!(msg_rx, timeout = Duration::from_millis(5)); for req in requests { unpack_enum!(MessagingRequest::SendMessage(_o) = req); } diff --git a/comms/src/pipeline/sink.rs b/comms/src/pipeline/sink.rs index a455aaf320..1b524e92a5 100644 --- a/comms/src/pipeline/sink.rs +++ b/comms/src/pipeline/sink.rs @@ -21,8 +21,8 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use super::PipelineError; -use futures::{future::BoxFuture, task::Context, FutureExt, Sink, SinkExt}; -use std::{pin::Pin, task::Poll}; +use futures::{future::BoxFuture, task::Context, FutureExt}; +use std::task::Poll; use tower::Service; /// A service which forwards and messages it gets to the given Sink @@ -35,22 +35,24 @@ impl SinkService { } } -impl Service for SinkService -where - T: Send + 'static, - TSink: Sink + Unpin + Clone + Send + 'static, - TSink::Error: Into + Send + 'static, +impl Service for SinkService> +where T: Send + 'static { type Error = PipelineError; type Future = BoxFuture<'static, Result>; type Response = (); - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_ready(cx).map_err(Into::into) + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, item: T) -> Self::Future { - let mut sink = self.0.clone(); - async move { sink.send(item).await.map_err(Into::into) }.boxed() + let sink = self.0.clone(); + async move { + sink.send(item) + .await + .map_err(|_| anyhow::anyhow!("sink closed in sink service")) + } + .boxed() } } diff --git a/comms/src/pipeline/translate_sink.rs b/comms/src/pipeline/translate_sink.rs index 
6a2bcad56a..606c038299 100644 --- a/comms/src/pipeline/translate_sink.rs +++ b/comms/src/pipeline/translate_sink.rs @@ -93,9 +93,10 @@ where F: FnMut(I) -> Option mod test { use super::*; use crate::runtime; - use futures::{channel::mpsc, SinkExt, StreamExt}; + use futures::{SinkExt, StreamExt}; + use tokio::sync::mpsc; - #[runtime::test_basic] + #[runtime::test] async fn check_translates() { let (tx, mut rx) = mpsc::channel(1); diff --git a/comms/src/protocol/identity.rs b/comms/src/protocol/identity.rs index 2c4eba1db5..b468a946f5 100644 --- a/comms/src/protocol/identity.rs +++ b/comms/src/protocol/identity.rs @@ -20,26 +20,28 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{ - compat::IoCompat, connection_manager::ConnectionDirection, message::MessageExt, peer_manager::NodeIdentity, proto::identity::PeerIdentityMsg, protocol::{NodeNetworkInfo, ProtocolError, ProtocolId, ProtocolNegotiation}, }; -use futures::{AsyncRead, AsyncWrite, SinkExt, StreamExt}; +use futures::{SinkExt, StreamExt}; use log::*; use prost::Message; use std::{io, time::Duration}; use thiserror::Error; -use tokio::time; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + time, +}; use tokio_util::codec::{Framed, LengthDelimitedCodec}; use tracing; pub static IDENTITY_PROTOCOL: ProtocolId = ProtocolId::from_static(b"t/identity/1.0"); const LOG_TARGET: &str = "comms::protocol::identity"; -#[tracing::instrument(skip(socket, our_supported_protocols), err)] +#[tracing::instrument(skip(socket, our_supported_protocols))] pub async fn identity_exchange<'p, TSocket, P>( node_identity: &NodeIdentity, direction: ConnectionDirection, @@ -79,7 +81,7 @@ where debug_assert_eq!(proto, IDENTITY_PROTOCOL); // Create length-delimited frame codec - let framed = Framed::new(IoCompat::new(socket), LengthDelimitedCodec::new()); + let framed = Framed::new(socket, 
LengthDelimitedCodec::new()); let (mut sink, mut stream) = framed.split(); let supported_protocols = our_supported_protocols.into_iter().map(|p| p.to_vec()).collect(); @@ -136,8 +138,8 @@ pub enum IdentityProtocolError { ProtocolVersionMismatch, } -impl From for IdentityProtocolError { - fn from(_: time::Elapsed) -> Self { +impl From for IdentityProtocolError { + fn from(_: time::error::Elapsed) -> Self { IdentityProtocolError::Timeout } } @@ -172,7 +174,7 @@ mod test { }; use futures::{future, StreamExt}; - #[runtime::test_basic] + #[runtime::test] async fn identity_exchange() { let transport = MemoryTransport; let addr = "/memory/0".parse().unwrap(); @@ -221,7 +223,7 @@ mod test { assert_eq!(identity2.addresses, vec![node_identity2.public_address().to_vec()]); } - #[runtime::test_basic] + #[runtime::test] async fn fail_cases() { let transport = MemoryTransport; let addr = "/memory/0".parse().unwrap(); diff --git a/comms/src/protocol/messaging/error.rs b/comms/src/protocol/messaging/error.rs index 6f078cec23..91d0e786ba 100644 --- a/comms/src/protocol/messaging/error.rs +++ b/comms/src/protocol/messaging/error.rs @@ -20,10 +20,15 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-use crate::{connection_manager::PeerConnectionError, peer_manager::PeerManagerError, protocol::ProtocolError}; -use futures::channel::mpsc; +use crate::{ + connection_manager::PeerConnectionError, + message::OutboundMessage, + peer_manager::PeerManagerError, + protocol::ProtocolError, +}; use std::io; use thiserror::Error; +use tokio::sync::mpsc; #[derive(Debug, Error)] pub enum InboundMessagingError { @@ -46,7 +51,7 @@ pub enum MessagingProtocolError { #[error("IO Error: {0}")] Io(#[from] io::Error), #[error("Sender error: {0}")] - SenderError(#[from] mpsc::SendError), + SenderError(#[from] mpsc::error::SendError), #[error("Stream closed due to inactivity")] Inactivity, } diff --git a/comms/src/protocol/messaging/extension.rs b/comms/src/protocol/messaging/extension.rs index 241a152a5b..f216ddd04e 100644 --- a/comms/src/protocol/messaging/extension.rs +++ b/comms/src/protocol/messaging/extension.rs @@ -34,8 +34,8 @@ use crate::{ runtime, runtime::task, }; -use futures::channel::mpsc; use std::fmt; +use tokio::sync::mpsc; use tower::Service; /// Buffer size for inbound messages from _all_ peers. This should be large enough to buffer quite a few incoming diff --git a/comms/src/protocol/messaging/forward.rs b/comms/src/protocol/messaging/forward.rs new file mode 100644 index 0000000000..ce035e8fb8 --- /dev/null +++ b/comms/src/protocol/messaging/forward.rs @@ -0,0 +1,110 @@ +// Copyright 2021, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. 
Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Copied from futures-rs + +use futures::{ + future::{FusedFuture, Future}, + ready, + stream::{Fuse, StreamExt}, + task::{Context, Poll}, + Sink, + Stream, + TryStream, +}; +use pin_project::pin_project; +use std::pin::Pin; + +/// Future for the `forward` combinator, copied from futures-rs. 
+#[pin_project(project = ForwardProj)] +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct Forward { + #[pin] + sink: Option, + #[pin] + stream: Fuse, + buffered_item: Option, +} + +impl Forward +where St: TryStream +{ + pub(crate) fn new(stream: St, sink: Si) -> Self { + Self { + sink: Some(sink), + stream: stream.fuse(), + buffered_item: None, + } + } +} + +impl FusedFuture for Forward +where + Si: Sink, + St: Stream>, +{ + fn is_terminated(&self) -> bool { + self.sink.is_none() + } +} + +impl Future for Forward +where + Si: Sink, + St: Stream>, +{ + type Output = Result<(), E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let ForwardProj { + mut sink, + mut stream, + buffered_item, + } = self.project(); + let mut si = sink.as_mut().as_pin_mut().expect("polled `Forward` after completion"); + + loop { + // If we've got an item buffered already, we need to write it to the + // sink before we can do anything else + if buffered_item.is_some() { + ready!(si.as_mut().poll_ready(cx))?; + si.as_mut().start_send(buffered_item.take().unwrap())?; + } + + match stream.as_mut().poll_next(cx)? { + Poll::Ready(Some(item)) => { + *buffered_item = Some(item); + }, + Poll::Ready(None) => { + ready!(si.poll_close(cx))?; + sink.set(None); + return Poll::Ready(Ok(())); + }, + Poll::Pending => { + ready!(si.poll_flush(cx))?; + return Poll::Pending; + }, + } + } + } +} diff --git a/comms/src/protocol/messaging/inbound.rs b/comms/src/protocol/messaging/inbound.rs index 643b07fc45..e65f81f8ec 100644 --- a/comms/src/protocol/messaging/inbound.rs +++ b/comms/src/protocol/messaging/inbound.rs @@ -21,15 +21,18 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
use crate::{ - common::rate_limit::RateLimit, message::InboundMessage, peer_manager::NodeId, protocol::messaging::{MessagingEvent, MessagingProtocol}, + rate_limit::RateLimit, }; -use futures::{channel::mpsc, future::Either, AsyncRead, AsyncWrite, SinkExt, StreamExt}; +use futures::{future::Either, StreamExt}; use log::*; use std::{sync::Arc, time::Duration}; -use tokio::sync::broadcast; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{broadcast, mpsc}, +}; const LOG_TARGET: &str = "comms::protocol::messaging::inbound"; @@ -61,7 +64,7 @@ impl InboundMessaging { } } - pub async fn run(mut self, socket: S) + pub async fn run(self, socket: S) where S: AsyncRead + AsyncWrite + Unpin { let peer = &self.peer; debug!( @@ -70,48 +73,40 @@ impl InboundMessaging { peer.short_str() ); - let (mut sink, stream) = MessagingProtocol::framed(socket).split(); - - if let Err(err) = sink.close().await { - debug!( - target: LOG_TARGET, - "Error closing sink half for peer `{}`: {}", - peer.short_str(), - err - ); - } - let stream = stream.rate_limit(self.rate_limit_capacity, self.rate_limit_restock_interval); + let stream = + MessagingProtocol::framed(socket).rate_limit(self.rate_limit_capacity, self.rate_limit_restock_interval); - let mut stream = match self.inactivity_timeout { - Some(timeout) => Either::Left(tokio::stream::StreamExt::timeout(stream, timeout)), + let stream = match self.inactivity_timeout { + Some(timeout) => Either::Left(tokio_stream::StreamExt::timeout(stream, timeout)), None => Either::Right(stream.map(Ok)), }; + tokio::pin!(stream); while let Some(result) = stream.next().await { match result { Ok(Ok(raw_msg)) => { - let inbound_msg = InboundMessage::new(peer.clone(), raw_msg.clone().freeze()); + let msg_len = raw_msg.len(); + let inbound_msg = InboundMessage::new(peer.clone(), raw_msg.freeze()); debug!( target: LOG_TARGET, "Received message {} from peer '{}' ({} bytes)", inbound_msg.tag, peer.short_str(), - raw_msg.len() + msg_len ); let event = 
MessagingEvent::MessageReceived(inbound_msg.source_peer.clone(), inbound_msg.tag); if let Err(err) = self.inbound_message_tx.send(inbound_msg).await { + let tag = err.0.tag; warn!( target: LOG_TARGET, - "Failed to send InboundMessage for peer '{}' because '{}'", + "Failed to send InboundMessage {} for peer '{}' because inbound message channel closed", + tag, peer.short_str(), - err ); - if err.is_disconnected() { - break; - } + break; } let _ = self.messaging_events_tx.send(Arc::new(event)); diff --git a/comms/src/protocol/messaging/mod.rs b/comms/src/protocol/messaging/mod.rs index 88fca6af05..1fa347b3a2 100644 --- a/comms/src/protocol/messaging/mod.rs +++ b/comms/src/protocol/messaging/mod.rs @@ -27,9 +27,9 @@ mod extension; pub use extension::MessagingProtocolExtension; mod error; +mod forward; mod inbound; mod outbound; - mod protocol; pub use protocol::{ MessagingEvent, diff --git a/comms/src/protocol/messaging/outbound.rs b/comms/src/protocol/messaging/outbound.rs index 9d47895338..5ba1d36ec0 100644 --- a/comms/src/protocol/messaging/outbound.rs +++ b/comms/src/protocol/messaging/outbound.rs @@ -29,14 +29,13 @@ use crate::{ peer_manager::NodeId, protocol::messaging::protocol::MESSAGING_PROTOCOL, }; -use futures::{channel::mpsc, future::Either, SinkExt, StreamExt}; -use log::*; +use futures::{future::Either, StreamExt, TryStreamExt}; use std::{ io, time::{Duration, Instant}, }; -use tokio::stream as tokio_stream; -use tracing::{event, span, Instrument, Level}; +use tokio::sync::mpsc as tokiompsc; +use tracing::{debug, error, event, span, Instrument, Level}; const LOG_TARGET: &str = "comms::protocol::messaging::outbound"; /// The number of times to retry sending a failed message before publishing a SendMessageFailed event. 
@@ -46,8 +45,8 @@ const MAX_SEND_RETRIES: usize = 1; pub struct OutboundMessaging { connectivity: ConnectivityRequester, - request_rx: mpsc::UnboundedReceiver, - messaging_events_tx: mpsc::Sender, + request_rx: tokiompsc::UnboundedReceiver, + messaging_events_tx: tokiompsc::Sender, peer_node_id: NodeId, inactivity_timeout: Option, } @@ -55,8 +54,8 @@ pub struct OutboundMessaging { impl OutboundMessaging { pub fn new( connectivity: ConnectivityRequester, - messaging_events_tx: mpsc::Sender, - request_rx: mpsc::UnboundedReceiver, + messaging_events_tx: tokiompsc::Sender, + request_rx: tokiompsc::UnboundedReceiver, peer_node_id: NodeId, inactivity_timeout: Option, ) -> Self { @@ -82,7 +81,7 @@ impl OutboundMessaging { self.peer_node_id.short_str() ); let peer_node_id = self.peer_node_id.clone(); - let mut messaging_events_tx = self.messaging_events_tx.clone(); + let messaging_events_tx = self.messaging_events_tx.clone(); match self.run_inner().await { Ok(_) => { event!( @@ -107,9 +106,18 @@ impl OutboundMessaging { peer_node_id.short_str() ); }, - Err(err) => { - event!(Level::ERROR, "Outbound messaging substream failed:{}", err); - debug!(target: LOG_TARGET, "Outbound messaging substream failed: {}", err); + Err(err) => match err { + MessagingProtocolError::PeerDialFailed => { + debug!( + target: LOG_TARGET, + "Outbound messaging substream failed due to a dial failure. 
Most likely the peer is offline \ + or doesn't exist: {}", + err + ); + }, + _ => { + error!(target: LOG_TARGET, "Outbound messaging substream failed:{}", err); + }, }, } @@ -131,7 +139,6 @@ impl OutboundMessaging { break substream; }, Err(err) => { - event!(Level::ERROR, "Error establishing messaging protocol"); if attempts >= MAX_SEND_RETRIES { debug!( target: LOG_TARGET, @@ -265,7 +272,7 @@ impl OutboundMessaging { ); let substream = substream.stream; - let (sink, _) = MessagingProtocol::framed(substream).split(); + let framed = MessagingProtocol::framed(substream); let Self { request_rx, @@ -273,30 +280,30 @@ impl OutboundMessaging { .. } = self; + // Convert unbounded channel to a stream + let stream = futures::stream::unfold(request_rx, |mut rx| async move { + let v = rx.recv().await; + v.map(|v| (v, rx)) + }); + let stream = match inactivity_timeout { Some(timeout) => { - let s = tokio_stream::StreamExt::timeout(request_rx, timeout).map(|r| match r { - Ok(s) => Ok(s), - Err(_) => Err(io::Error::new( - io::ErrorKind::TimedOut, - MessagingProtocolError::Inactivity, - )), - }); + let s = tokio_stream::StreamExt::timeout(stream, timeout) + .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, MessagingProtocolError::Inactivity)); Either::Left(s) }, - None => Either::Right(request_rx.map(Ok)), + None => Either::Right(stream.map(Ok)), }; - stream - .map(|msg| { - msg.map(|mut out_msg| { - event!(Level::DEBUG, "Message buffered for sending {}", out_msg); - out_msg.reply_success(); - out_msg.body - }) + let stream = stream.map(|msg| { + msg.map(|mut out_msg| { + event!(Level::DEBUG, "Message buffered for sending {}", out_msg); + out_msg.reply_success(); + out_msg.body }) - .forward(sink) - .await?; + }); + + super::forward::Forward::new(stream, framed).await?; debug!( target: LOG_TARGET, @@ -310,7 +317,7 @@ impl OutboundMessaging { // Close the request channel so that we can read all the remaining messages and flush them // to a failed event self.request_rx.close(); 
- while let Some(mut out_msg) = self.request_rx.next().await { + while let Some(mut out_msg) = self.request_rx.recv().await { out_msg.reply_fail(reason); let _ = self .messaging_events_tx diff --git a/comms/src/protocol/messaging/protocol.rs b/comms/src/protocol/messaging/protocol.rs index 1f6fe029ab..988b4ada21 100644 --- a/comms/src/protocol/messaging/protocol.rs +++ b/comms/src/protocol/messaging/protocol.rs @@ -22,7 +22,6 @@ use super::error::MessagingProtocolError; use crate::{ - compat::IoCompat, connectivity::{ConnectivityEvent, ConnectivityRequester}, framing, message::{InboundMessage, MessageTag, OutboundMessage}, @@ -36,7 +35,6 @@ use crate::{ runtime::task, }; use bytes::Bytes; -use futures::{channel::mpsc, stream::Fuse, AsyncRead, AsyncWrite, SinkExt, StreamExt}; use log::*; use std::{ collections::{hash_map::Entry, HashMap}, @@ -46,7 +44,10 @@ use std::{ }; use tari_shutdown::{Shutdown, ShutdownSignal}; use thiserror::Error; -use tokio::sync::broadcast; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{broadcast, mpsc}, +}; use tokio_util::codec::{Framed, LengthDelimitedCodec}; const LOG_TARGET: &str = "comms::protocol::messaging"; @@ -106,13 +107,13 @@ impl fmt::Display for MessagingEvent { pub struct MessagingProtocol { config: MessagingConfig, connectivity: ConnectivityRequester, - proto_notification: Fuse>>, + proto_notification: mpsc::Receiver>, active_queues: HashMap>, - request_rx: Fuse>, + request_rx: mpsc::Receiver, messaging_events_tx: MessagingEventSender, inbound_message_tx: mpsc::Sender, internal_messaging_event_tx: mpsc::Sender, - internal_messaging_event_rx: Fuse>, + internal_messaging_event_rx: mpsc::Receiver, shutdown_signal: ShutdownSignal, complete_trigger: Shutdown, } @@ -133,11 +134,11 @@ impl MessagingProtocol { Self { config, connectivity, - proto_notification: proto_notification.fuse(), - request_rx: request_rx.fuse(), + proto_notification, + request_rx, active_queues: Default::default(), messaging_events_tx, - 
internal_messaging_event_rx: internal_messaging_event_rx.fuse(), + internal_messaging_event_rx, internal_messaging_event_tx, inbound_message_tx, shutdown_signal, @@ -151,15 +152,15 @@ impl MessagingProtocol { pub async fn run(mut self) { let mut shutdown_signal = self.shutdown_signal.clone(); - let mut connectivity_events = self.connectivity.get_event_subscription().fuse(); + let mut connectivity_events = self.connectivity.get_event_subscription(); loop { - futures::select! { - event = self.internal_messaging_event_rx.select_next_some() => { + tokio::select! { + Some(event) = self.internal_messaging_event_rx.recv() => { self.handle_internal_messaging_event(event).await; }, - req = self.request_rx.select_next_some() => { + Some(req) = self.request_rx.recv() => { if let Err(err) = self.handle_request(req).await { error!( target: LOG_TARGET, @@ -169,17 +170,17 @@ impl MessagingProtocol { } }, - event = connectivity_events.select_next_some() => { + event = connectivity_events.recv() => { if let Ok(event) = event { self.handle_connectivity_event(&event); } } - notification = self.proto_notification.select_next_some() => { + Some(notification) = self.proto_notification.recv() => { self.handle_protocol_notification(notification).await; }, - _ = shutdown_signal => { + _ = &mut shutdown_signal => { info!(target: LOG_TARGET, "MessagingProtocol is shutting down because the shutdown signal was triggered"); break; } @@ -188,7 +189,7 @@ impl MessagingProtocol { } #[inline] - pub fn framed(socket: TSubstream) -> Framed, LengthDelimitedCodec> + pub fn framed(socket: TSubstream) -> Framed where TSubstream: AsyncRead + AsyncWrite + Unpin { framing::canonical(socket, MAX_FRAME_LENGTH) } @@ -198,11 +199,9 @@ impl MessagingProtocol { #[allow(clippy::single_match)] match event { PeerConnectionWillClose(node_id, _) => { - // If the peer connection will close, cut off the pipe to send further messages. 
- // Any messages in the channel will be sent (hopefully) before the connection is disconnected. - if let Some(sender) = self.active_queues.remove(node_id) { - sender.close_channel(); - } + // If the peer connection will close, cut off the pipe to send further messages by dropping the sender. + // Any messages in the channel may be sent before the connection is disconnected. + let _ = self.active_queues.remove(node_id); }, _ => {}, } @@ -263,7 +262,7 @@ impl MessagingProtocol { let sender = Self::spawn_outbound_handler( self.connectivity.clone(), self.internal_messaging_event_tx.clone(), - peer_node_id.clone(), + peer_node_id, self.config.inactivity_timeout, ); break entry.insert(sender); @@ -273,7 +272,7 @@ impl MessagingProtocol { debug!(target: LOG_TARGET, "Sending message {}", out_msg); let tag = out_msg.tag; - match sender.send(out_msg).await { + match sender.send(out_msg) { Ok(_) => { debug!(target: LOG_TARGET, "Message ({}) dispatched to outbound handler", tag,); Ok(()) @@ -294,7 +293,7 @@ impl MessagingProtocol { peer_node_id: NodeId, inactivity_timeout: Option, ) -> mpsc::UnboundedSender { - let (msg_tx, msg_rx) = mpsc::unbounded(); + let (msg_tx, msg_rx) = mpsc::unbounded_channel(); let outbound_messaging = OutboundMessaging::new(connectivity, events_tx, msg_rx, peer_node_id, inactivity_timeout); task::spawn(outbound_messaging.run()); diff --git a/comms/src/protocol/messaging/test.rs b/comms/src/protocol/messaging/test.rs index e954af31f2..5d066eb4c2 100644 --- a/comms/src/protocol/messaging/test.rs +++ b/comms/src/protocol/messaging/test.rs @@ -49,18 +49,16 @@ use crate::{ types::{CommsDatabase, CommsPublicKey}, }; use bytes::Bytes; -use futures::{ - channel::{mpsc, oneshot}, - stream::FuturesUnordered, - SinkExt, - StreamExt, -}; +use futures::{stream::FuturesUnordered, SinkExt, StreamExt}; use rand::rngs::OsRng; use std::{io, sync::Arc, time::Duration}; use tari_crypto::keys::PublicKey; use tari_shutdown::Shutdown; -use 
tari_test_utils::{collect_stream, unpack_enum}; -use tokio::{sync::broadcast, time}; +use tari_test_utils::{collect_recv, collect_stream, unpack_enum}; +use tokio::{ + sync::{broadcast, mpsc, oneshot}, + time, +}; static TEST_MSG1: Bytes = Bytes::from_static(b"TEST_MSG1"); @@ -110,9 +108,9 @@ async fn spawn_messaging_protocol() -> ( ) } -#[runtime::test_basic] +#[runtime::test] async fn new_inbound_substream_handling() { - let (peer_manager, _, _, mut proto_tx, _, mut inbound_msg_rx, mut events_rx, _shutdown) = + let (peer_manager, _, _, proto_tx, _, mut inbound_msg_rx, mut events_rx, _shutdown) = spawn_messaging_protocol().await; let expected_node_id = node_id::random(); @@ -148,7 +146,7 @@ async fn new_inbound_substream_handling() { framed_theirs.send(TEST_MSG1.clone()).await.unwrap(); - let in_msg = time::timeout(Duration::from_secs(5), inbound_msg_rx.next()) + let in_msg = time::timeout(Duration::from_secs(5), inbound_msg_rx.recv()) .await .unwrap() .unwrap(); @@ -156,19 +154,18 @@ async fn new_inbound_substream_handling() { assert_eq!(in_msg.body, TEST_MSG1); let expected_tag = in_msg.tag; - let event = time::timeout(Duration::from_secs(5), events_rx.next()) + let event = time::timeout(Duration::from_secs(5), events_rx.recv()) .await .unwrap() - .unwrap() .unwrap(); unpack_enum!(MessagingEvent::MessageReceived(node_id, tag) = &*event); assert_eq!(tag, &expected_tag); assert_eq!(*node_id, expected_node_id); } -#[runtime::test_basic] +#[runtime::test] async fn send_message_request() { - let (_, node_identity, conn_man_mock, _, mut request_tx, _, _, _shutdown) = spawn_messaging_protocol().await; + let (_, node_identity, conn_man_mock, _, request_tx, _, _, _shutdown) = spawn_messaging_protocol().await; let peer_node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -192,9 +189,9 @@ async fn send_message_request() { assert_eq!(peer_conn_mock1.call_count(), 1); } -#[runtime::test_basic] +#[runtime::test] async fn send_message_dial_failed() { - let 
(_, _, conn_manager_mock, _, mut request_tx, _, mut event_tx, _shutdown) = spawn_messaging_protocol().await; + let (_, _, conn_manager_mock, _, request_tx, _, mut event_tx, _shutdown) = spawn_messaging_protocol().await; let node_id = node_id::random(); let out_msg = OutboundMessage::new(node_id, TEST_MSG1.clone()); @@ -202,7 +199,7 @@ async fn send_message_dial_failed() { // Send a message to node 2 request_tx.send(MessagingRequest::SendMessage(out_msg)).await.unwrap(); - let event = event_tx.next().await.unwrap().unwrap(); + let event = event_tx.recv().await.unwrap(); unpack_enum!(MessagingEvent::SendMessageFailed(out_msg, reason) = &*event); unpack_enum!(SendFailReason::PeerDialFailed = reason); assert_eq!(out_msg.tag, expected_out_msg_tag); @@ -212,7 +209,7 @@ async fn send_message_dial_failed() { assert!(calls.iter().all(|evt| evt.starts_with("DialPeer"))); } -#[runtime::test_basic] +#[runtime::test] async fn send_message_substream_bulk_failure() { const NUM_MSGS: usize = 10; let (_, node_identity, conn_manager_mock, _, mut request_tx, _, mut events_rx, _shutdown) = @@ -258,19 +255,18 @@ async fn send_message_substream_bulk_failure() { } // Check that the outbound handler closed - let event = time::timeout(Duration::from_secs(10), events_rx.next()) + let event = time::timeout(Duration::from_secs(10), events_rx.recv()) .await .unwrap() - .unwrap() .unwrap(); unpack_enum!(MessagingEvent::OutboundProtocolExited(node_id) = &*event); assert_eq!(node_id, peer_node_id); } -#[runtime::test_basic] +#[runtime::test] async fn many_concurrent_send_message_requests() { const NUM_MSGS: usize = 100; - let (_, _, conn_man_mock, _, mut request_tx, _, _, _shutdown) = spawn_messaging_protocol().await; + let (_, _, conn_man_mock, _, request_tx, _, _, _shutdown) = spawn_messaging_protocol().await; let node_identity1 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); let node_identity2 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -315,10 +311,10 @@ async fn 
many_concurrent_send_message_requests() { assert_eq!(peer_conn_mock1.call_count(), 1); } -#[runtime::test_basic] +#[runtime::test] async fn many_concurrent_send_message_requests_that_fail() { const NUM_MSGS: usize = 100; - let (_, _, _, _, mut request_tx, _, mut events_rx, _shutdown) = spawn_messaging_protocol().await; + let (_, _, _, _, request_tx, _, mut events_rx, _shutdown) = spawn_messaging_protocol().await; let node_id2 = node_id::random(); @@ -339,10 +335,9 @@ async fn many_concurrent_send_message_requests_that_fail() { } // Check that we got message success events - let events = collect_stream!(events_rx, take = NUM_MSGS, timeout = Duration::from_secs(10)); + let events = collect_recv!(events_rx, take = NUM_MSGS, timeout = Duration::from_secs(10)); assert_eq!(events.len(), NUM_MSGS); for event in events { - let event = event.unwrap(); unpack_enum!(MessagingEvent::SendMessageFailed(out_msg, reason) = &*event); unpack_enum!(SendFailReason::PeerDialFailed = reason); // Assert that each tag is emitted only once @@ -357,7 +352,7 @@ async fn many_concurrent_send_message_requests_that_fail() { assert_eq!(msg_tags.len(), 0); } -#[runtime::test_basic] +#[runtime::test] async fn inactivity_timeout() { let node_identity = build_node_identity(PeerFeatures::COMMUNICATION_CLIENT); let (inbound_msg_tx, mut inbound_msg_rx) = mpsc::channel(5); @@ -381,13 +376,13 @@ async fn inactivity_timeout() { let mut framed = MessagingProtocol::framed(socket_out); for _ in 0..5u8 { framed.send(Bytes::from_static(b"some message")).await.unwrap(); - time::delay_for(Duration::from_millis(1)).await; + time::sleep(Duration::from_millis(1)).await; } - time::delay_for(Duration::from_millis(10)).await; + time::sleep(Duration::from_millis(10)).await; let err = framed.send(Bytes::from_static(b"another message")).await.unwrap_err(); assert_eq!(err.kind(), io::ErrorKind::BrokenPipe); - let _ = collect_stream!(inbound_msg_rx, take = 5, timeout = Duration::from_secs(10)); + let _ = 
collect_recv!(inbound_msg_rx, take = 5, timeout = Duration::from_secs(10)); } diff --git a/comms/src/protocol/negotiation.rs b/comms/src/protocol/negotiation.rs index 5178d790e6..326910dc99 100644 --- a/comms/src/protocol/negotiation.rs +++ b/comms/src/protocol/negotiation.rs @@ -23,9 +23,9 @@ use super::{ProtocolError, ProtocolId}; use bitflags::bitflags; use bytes::{Bytes, BytesMut}; -use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use log::*; use std::convert::TryInto; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; const LOG_TARGET: &str = "comms::connection_manager::protocol"; @@ -204,7 +204,7 @@ mod test { use futures::future; use tari_test_utils::unpack_enum; - #[runtime::test_basic] + #[runtime::test] async fn negotiate_success() { let (mut initiator, mut responder) = MemorySocket::new_pair(); let mut negotiate_out = ProtocolNegotiation::new(&mut initiator); @@ -229,7 +229,7 @@ mod test { assert_eq!(out_proto.unwrap(), ProtocolId::from_static(b"A")); } - #[runtime::test_basic] + #[runtime::test] async fn negotiate_fail() { let (mut initiator, mut responder) = MemorySocket::new_pair(); let mut negotiate_out = ProtocolNegotiation::new(&mut initiator); @@ -254,7 +254,7 @@ mod test { unpack_enum!(ProtocolError::ProtocolOutboundNegotiationFailed(_s) = out_proto.unwrap_err()); } - #[runtime::test_basic] + #[runtime::test] async fn negotiate_fail_max_rounds() { let (mut initiator, mut responder) = MemorySocket::new_pair(); let mut negotiate_out = ProtocolNegotiation::new(&mut initiator); @@ -279,7 +279,7 @@ mod test { unpack_enum!(ProtocolError::ProtocolNegotiationTerminatedByPeer = out_proto.unwrap_err()); } - #[runtime::test_basic] + #[runtime::test] async fn negotiate_success_optimistic() { let (mut initiator, mut responder) = MemorySocket::new_pair(); let mut negotiate_out = ProtocolNegotiation::new(&mut initiator); @@ -300,7 +300,7 @@ mod test { out_proto.unwrap(); } - #[runtime::test_basic] + #[runtime::test] async fn 
negotiate_fail_optimistic() { let (mut initiator, mut responder) = MemorySocket::new_pair(); let mut negotiate_out = ProtocolNegotiation::new(&mut initiator); diff --git a/comms/src/protocol/protocols.rs b/comms/src/protocol/protocols.rs index 936ef15f34..14d253196e 100644 --- a/comms/src/protocol/protocols.rs +++ b/comms/src/protocol/protocols.rs @@ -32,8 +32,8 @@ use crate::{ }, Substream, }; -use futures::{channel::mpsc, SinkExt}; use std::collections::HashMap; +use tokio::sync::mpsc; pub type ProtocolNotificationTx = mpsc::Sender>; pub type ProtocolNotificationRx = mpsc::Receiver>; @@ -143,7 +143,6 @@ impl ProtocolExtension for Protocols { mod test { use super::*; use crate::runtime; - use futures::StreamExt; use tari_test_utils::unpack_enum; #[test] @@ -160,7 +159,7 @@ mod test { assert!(protocols.get_supported_protocols().iter().all(|p| protos.contains(p))); } - #[runtime::test_basic] + #[runtime::test] async fn notify() { let (tx, mut rx) = mpsc::channel(1); let protos = [ProtocolId::from_static(b"/tari/test/1")]; @@ -172,12 +171,12 @@ mod test { .await .unwrap(); - let notification = rx.next().await.unwrap(); + let notification = rx.recv().await.unwrap(); unpack_enum!(ProtocolEvent::NewInboundSubstream(peer_id, _s) = notification.event); assert_eq!(peer_id, NodeId::new()); } - #[runtime::test_basic] + #[runtime::test] async fn notify_fail_not_registered() { let mut protocols = Protocols::<()>::new(); diff --git a/comms/src/protocol/rpc/body.rs b/comms/src/protocol/rpc/body.rs index 6079508729..e563d6483e 100644 --- a/comms/src/protocol/rpc/body.rs +++ b/comms/src/protocol/rpc/body.rs @@ -27,7 +27,6 @@ use crate::{ }; use bytes::BytesMut; use futures::{ - channel::mpsc, ready, stream::BoxStream, task::{Context, Poll}, @@ -37,6 +36,7 @@ use futures::{ use pin_project::pin_project; use prost::bytes::Buf; use std::{fmt, marker::PhantomData, pin::Pin}; +use tokio::sync::mpsc; pub trait IntoBody { fn into_body(self) -> Body; @@ -205,8 +205,8 @@ impl Buf for 
BodyBytes { self.0.as_ref().map(Buf::remaining).unwrap_or(0) } - fn bytes(&self) -> &[u8] { - self.0.as_ref().map(Buf::bytes).unwrap_or(&[]) + fn chunk(&self) -> &[u8] { + self.0.as_ref().map(Buf::chunk).unwrap_or(&[]) } fn advance(&mut self, cnt: usize) { @@ -227,7 +227,7 @@ impl Streaming { } pub fn empty() -> Self { - let (_, rx) = mpsc::channel(0); + let (_, rx) = mpsc::channel(1); Self { inner: rx } } @@ -240,7 +240,7 @@ impl Stream for Streaming { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match ready!(self.inner.poll_next_unpin(cx)) { + match ready!(Pin::new(&mut self.inner).poll_recv(cx)) { Some(result) => { let result = result.map(|msg| msg.to_encoded_bytes().into()); Poll::Ready(Some(result)) @@ -275,7 +275,7 @@ impl Stream for ClientStreaming { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match ready!(self.inner.poll_next_unpin(cx)) { + match ready!(Pin::new(&mut self.inner).poll_recv(cx)) { Some(Ok(resp)) => { // The streaming protocol dictates that an empty finish flag MUST be sent to indicate a terminated // stream. This empty response need not be emitted to downsteam consumers. 
@@ -298,7 +298,7 @@ mod test { use futures::{stream, StreamExt}; use prost::Message; - #[runtime::test_basic] + #[runtime::test] async fn single_body() { let mut body = Body::single(123u32.to_encoded_bytes()); let bytes = body.next().await.unwrap().unwrap(); @@ -306,7 +306,7 @@ mod test { assert_eq!(u32::decode(bytes).unwrap(), 123u32); } - #[runtime::test_basic] + #[runtime::test] async fn streaming_body() { let body = Body::streaming(stream::repeat(Bytes::new()).map(Ok).take(10)); let body = body.collect::>().await; diff --git a/comms/src/protocol/rpc/client.rs b/comms/src/protocol/rpc/client.rs index 2806366f9f..6c994fcaf8 100644 --- a/comms/src/protocol/rpc/client.rs +++ b/comms/src/protocol/rpc/client.rs @@ -41,11 +41,8 @@ use crate::{ }; use bytes::Bytes; use futures::{ - channel::{mpsc, oneshot}, future::{BoxFuture, Either}, task::{Context, Poll}, - AsyncRead, - AsyncWrite, FutureExt, SinkExt, StreamExt, @@ -58,9 +55,15 @@ use std::{ fmt, future::Future, marker::PhantomData, + sync::Arc, time::{Duration, Instant}, }; -use tokio::time; +use tari_shutdown::{Shutdown, ShutdownSignal}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{mpsc, oneshot, Mutex}, + time, +}; use tower::{Service, ServiceExt}; use tracing::{event, span, Instrument, Level}; @@ -82,14 +85,16 @@ impl RpcClient { TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static, { let (request_tx, request_rx) = mpsc::channel(1); - let connector = ClientConnector::new(request_tx); + let shutdown = Shutdown::new(); + let shutdown_signal = shutdown.to_signal(); + let connector = ClientConnector::new(request_tx, shutdown); let (ready_tx, ready_rx) = oneshot::channel(); let tracing_id = tracing::Span::current().id(); task::spawn({ let span = span!(Level::TRACE, "start_rpc_worker"); span.follows_from(tracing_id); - RpcClientWorker::new(config, request_rx, framed, ready_tx, protocol_name) + RpcClientWorker::new(config, request_rx, framed, ready_tx, protocol_name, shutdown_signal) .run() 
.instrument(span) }); @@ -110,7 +115,7 @@ impl RpcClient { let request = BaseRequest::new(method.into(), req_bytes.into()); let mut resp = self.call_inner(request).await?; - let resp = resp.next().await.ok_or(RpcError::ServerClosedRequest)??; + let resp = resp.recv().await.ok_or(RpcError::ServerClosedRequest)??; let resp = R::decode(resp.into_message())?; Ok(resp) @@ -132,8 +137,8 @@ impl RpcClient { } /// Close the RPC session. Any subsequent calls will error. - pub fn close(&mut self) { - self.connector.close() + pub async fn close(&mut self) { + self.connector.close().await; } pub fn is_connected(&self) -> bool { @@ -269,15 +274,20 @@ impl Default for RpcClientConfig { #[derive(Clone)] pub struct ClientConnector { inner: mpsc::Sender, + shutdown: Arc>, } impl ClientConnector { - pub(self) fn new(sender: mpsc::Sender) -> Self { - Self { inner: sender } + pub(self) fn new(sender: mpsc::Sender, shutdown: Shutdown) -> Self { + Self { + inner: sender, + shutdown: Arc::new(Mutex::new(shutdown)), + } } - pub fn close(&mut self) { - self.inner.close_channel(); + pub async fn close(&mut self) { + let mut lock = self.shutdown.lock().await; + lock.trigger(); } pub async fn get_last_request_latency(&mut self) -> Result, RpcError> { @@ -317,13 +327,13 @@ impl Service> for ClientConnector { type Future = BoxFuture<'static, Result>; type Response = mpsc::Receiver, RpcStatus>>; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready_unpin(cx).map_err(|_| RpcError::ClientClosed) + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, request: BaseRequest) -> Self::Future { let (reply, reply_rx) = oneshot::channel(); - let mut inner = self.inner.clone(); + let inner = self.inner.clone(); async move { inner .send(ClientRequest::SendRequest { request, reply }) @@ -346,6 +356,7 @@ pub struct RpcClientWorker { ready_tx: Option>>, last_request_latency: Option, protocol_id: ProtocolId, + shutdown_signal: 
ShutdownSignal, } impl RpcClientWorker @@ -357,6 +368,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send framed: CanonicalFraming, ready_tx: oneshot::Sender>, protocol_id: ProtocolId, + shutdown_signal: ShutdownSignal, ) -> Self { Self { config, @@ -366,6 +378,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send ready_tx: Some(ready_tx), last_request_latency: None, protocol_id, + shutdown_signal, } } @@ -405,26 +418,26 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send }, } - while let Some(req) = self.request_rx.next().await { - use ClientRequest::*; - match req { - SendRequest { request, reply } => { - if let Err(err) = self.do_request_response(request, reply).await { - error!(target: LOG_TARGET, "Unexpected error: {}. Worker is terminating.", err); - break; - } - }, - GetLastRequestLatency(reply) => { - let _ = reply.send(self.last_request_latency); - }, - SendPing(reply) => { - if let Err(err) = self.do_ping_pong(reply).await { - error!(target: LOG_TARGET, "Unexpected error: {}. Worker is terminating.", err); - break; + loop { + tokio::select! { + biased; + _ = &mut self.shutdown_signal => { + break; + } + req = self.request_rx.recv() => { + match req { + Some(req) => { + if let Err(err) = self.handle_request(req).await { + error!(target: LOG_TARGET, "Unexpected error: {}. 
Worker is terminating.", err); + break; + } + } + None => break, } - }, + } } } + if let Err(err) = self.framed.close().await { debug!(target: LOG_TARGET, "IO Error when closing substream: {}", err); } @@ -436,6 +449,22 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send ); } + async fn handle_request(&mut self, req: ClientRequest) -> Result<(), RpcError> { + use ClientRequest::*; + match req { + SendRequest { request, reply } => { + self.do_request_response(request, reply).await?; + }, + GetLastRequestLatency(reply) => { + let _ = reply.send(self.last_request_latency); + }, + SendPing(reply) => { + self.do_ping_pong(reply).await?; + }, + } + Ok(()) + } + async fn do_ping_pong(&mut self, reply: oneshot::Sender>) -> Result<(), RpcError> { let ack = proto::rpc::RpcRequest { flags: RpcMessageFlags::ACK.bits() as u32, @@ -482,7 +511,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send Ok(()) } - #[tracing::instrument(name = "rpc_do_request_response", skip(self, reply), err)] + #[tracing::instrument(name = "rpc_do_request_response", skip(self, reply))] async fn do_request_response( &mut self, request: BaseRequest, @@ -501,14 +530,30 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send debug!(target: LOG_TARGET, "Sending request: {}", req); let start = Instant::now(); + if reply.is_closed() { + event!(Level::WARN, "Client request was cancelled before request was sent"); + warn!( + target: LOG_TARGET, + "Client request was cancelled before request was sent" + ); + } self.framed.send(req.to_encoded_bytes().into()).await?; - let (mut response_tx, response_rx) = mpsc::channel(10); - if reply.send(response_rx).is_err() { - event!(Level::WARN, "Client request was cancelled"); - warn!(target: LOG_TARGET, "Client request was cancelled."); - response_tx.close_channel(); - // TODO: Should this not exit here? 
+ let (response_tx, response_rx) = mpsc::channel(10); + if let Err(mut rx) = reply.send(response_rx) { + event!(Level::WARN, "Client request was cancelled after request was sent"); + warn!( + target: LOG_TARGET, + "Client request was cancelled after request was sent" + ); + rx.close(); + // RPC is strictly request/response + // If the client drops the RpcClient request at this point after the , we have two options: + // 1. Obey the protocol: receive the response + // 2. Close the RPC session and return an error (seems brittle and unexpected) + // Option 1 has the disadvantage when receiving large/many streamed responses. + // TODO: Detect if all handles to the client handles have been dropped. If so, + // immediately close the RPC session } loop { @@ -537,8 +582,20 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send start.elapsed() ); event!(Level::ERROR, "Response timed out"); - let _ = response_tx.send(Err(RpcStatus::timed_out("Response timed out"))).await; - response_tx.close_channel(); + if !response_tx.is_closed() { + let _ = response_tx.send(Err(RpcStatus::timed_out("Response timed out"))).await; + } + break; + }, + Err(RpcError::ClientClosed) => { + debug!( + target: LOG_TARGET, + "Request {} (method={}) was closed after {:.0?} (read_reply)", + request_id, + method, + start.elapsed() + ); + self.request_rx.close(); break; }, Err(err) => { @@ -564,7 +621,6 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send let _ = response_tx.send(Ok(resp)).await; } if is_finished { - response_tx.close_channel(); break; } }, @@ -573,7 +629,6 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send if !response_tx.is_closed() { let _ = response_tx.send(Err(err)).await; } - response_tx.close_channel(); break; }, Err(err @ RpcError::ResponseIdDidNotMatchRequest { .. }) | @@ -598,7 +653,15 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send None => Either::Right(self.framed.next().map(Ok)), }; - match next_msg_fut.await { + let result = tokio::select! 
{ + biased; + _ = &mut self.shutdown_signal => { + return Err(RpcError::ClientClosed); + } + result = next_msg_fut => result, + }; + + match result { Ok(Some(Ok(resp))) => Ok(proto::rpc::RpcResponse::decode(resp)?), Ok(Some(Err(err))) => Err(err.into()), Ok(None) => Err(RpcError::ServerClosedRequest), diff --git a/comms/src/protocol/rpc/client_pool.rs b/comms/src/protocol/rpc/client_pool.rs index 6829b41265..7cf99ed419 100644 --- a/comms/src/protocol/rpc/client_pool.rs +++ b/comms/src/protocol/rpc/client_pool.rs @@ -61,6 +61,11 @@ where T: RpcPoolClient + From + NamedProtocolService + Clone let mut pool = self.pool.lock().await; pool.get_least_used_or_connect().await } + + pub async fn is_connected(&self) -> bool { + let pool = self.pool.lock().await; + pool.is_connected() + } } #[derive(Clone)] @@ -111,6 +116,10 @@ where T: RpcPoolClient + From + NamedProtocolService + Clone } } + pub fn is_connected(&self) -> bool { + self.connection.is_connected() + } + pub(super) fn refresh_num_active_connections(&mut self) -> usize { self.prune(); self.clients.len() diff --git a/comms/src/protocol/rpc/handshake.rs b/comms/src/protocol/rpc/handshake.rs index 3abd62cef6..b39c15e6d7 100644 --- a/comms/src/protocol/rpc/handshake.rs +++ b/comms/src/protocol/rpc/handshake.rs @@ -22,10 +22,13 @@ use crate::{framing::CanonicalFraming, message::MessageExt, proto, protocol::rpc::error::HandshakeRejectReason}; use bytes::BytesMut; -use futures::{AsyncRead, AsyncWrite, SinkExt, StreamExt}; +use futures::{SinkExt, StreamExt}; use prost::{DecodeError, Message}; use std::{io, time::Duration}; -use tokio::time; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + time, +}; use tracing::{debug, error, event, span, warn, Instrument, Level}; const LOG_TARGET: &str = "comms::rpc::handshake"; @@ -168,7 +171,7 @@ where T: AsyncRead + AsyncWrite + Unpin } #[tracing::instrument(name = "rpc::receive_handshake_reply", skip(self), err)] - async fn recv_next_frame(&mut self) -> Result>, time::Elapsed> { + 
async fn recv_next_frame(&mut self) -> Result>, time::error::Elapsed> { match self.timeout { Some(timeout) => time::timeout(timeout, self.framed.next()).await, None => Ok(self.framed.next().await), diff --git a/comms/src/protocol/rpc/mod.rs b/comms/src/protocol/rpc/mod.rs index d4e91fa8e4..2244979adf 100644 --- a/comms/src/protocol/rpc/mod.rs +++ b/comms/src/protocol/rpc/mod.rs @@ -80,6 +80,7 @@ pub mod __macro_reexports { }, Bytes, }; - pub use futures::{future, future::BoxFuture, AsyncRead, AsyncWrite}; + pub use futures::{future, future::BoxFuture}; + pub use tokio::io::{AsyncRead, AsyncWrite}; pub use tower::Service; } diff --git a/comms/src/protocol/rpc/server/error.rs b/comms/src/protocol/rpc/server/error.rs index 5078c6c588..6972cec60b 100644 --- a/comms/src/protocol/rpc/server/error.rs +++ b/comms/src/protocol/rpc/server/error.rs @@ -21,9 +21,9 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::protocol::rpc::handshake::RpcHandshakeError; -use futures::channel::oneshot; use prost::DecodeError; use std::io; +use tokio::sync::oneshot; #[derive(Debug, thiserror::Error)] pub enum RpcServerError { @@ -41,8 +41,8 @@ pub enum RpcServerError { ProtocolServiceNotFound(String), } -impl From for RpcServerError { - fn from(_: oneshot::Canceled) -> Self { +impl From for RpcServerError { + fn from(_: oneshot::error::RecvError) -> Self { RpcServerError::RequestCanceled } } diff --git a/comms/src/protocol/rpc/server/handle.rs b/comms/src/protocol/rpc/server/handle.rs index 89bf8dd3b9..972d91429d 100644 --- a/comms/src/protocol/rpc/server/handle.rs +++ b/comms/src/protocol/rpc/server/handle.rs @@ -21,10 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
use super::RpcServerError; -use futures::{ - channel::{mpsc, oneshot}, - SinkExt, -}; +use tokio::sync::{mpsc, oneshot}; #[derive(Debug)] pub enum RpcServerRequest { diff --git a/comms/src/protocol/rpc/server/mock.rs b/comms/src/protocol/rpc/server/mock.rs index 19741a0a1a..69659ba03b 100644 --- a/comms/src/protocol/rpc/server/mock.rs +++ b/comms/src/protocol/rpc/server/mock.rs @@ -42,6 +42,7 @@ use crate::{ ProtocolNotificationTx, }, test_utils::mocks::{create_connectivity_mock, create_peer_connection_mock_pair, ConnectivityManagerMockState}, + utils, NodeIdentity, PeerConnection, PeerManager, @@ -49,7 +50,7 @@ use crate::{ }; use async_trait::async_trait; use bytes::Bytes; -use futures::{channel::mpsc, future::BoxFuture, stream, SinkExt}; +use futures::future::BoxFuture; use std::{ collections::HashMap, future, @@ -57,7 +58,7 @@ use std::{ task::{Context, Poll}, }; use tokio::{ - sync::{Mutex, RwLock}, + sync::{mpsc, Mutex, RwLock}, task, }; use tower::Service; @@ -139,9 +140,13 @@ pub trait RpcMock { { method_state.requests.write().await.push(request.into_message()); let resp = method_state.response.read().await.clone()?; - let (mut tx, rx) = mpsc::channel(resp.len()); - let mut resp = stream::iter(resp.into_iter().map(Ok).map(Ok)); - tx.send_all(&mut resp).await.unwrap(); + let (tx, rx) = mpsc::channel(resp.len()); + match utils::mpsc::send_all(&tx, resp.into_iter().map(Ok)).await { + Ok(_) => {}, + // This is done because tokio mpsc channels give the item back to you in the error, and our item doesn't + // impl Debug, so we can't use unwrap, expect etc + Err(_) => panic!("send error"), + } Ok(Streaming::new(rx)) } } @@ -234,7 +239,7 @@ where let peer_node_id = peer.node_id.clone(); let (_, our_conn_mock, peer_conn, _) = create_peer_connection_mock_pair(peer, self.our_node.to_peer()).await; - let mut protocol_tx = self.protocol_tx.clone(); + let protocol_tx = self.protocol_tx.clone(); task::spawn(async move { while let Some(substream) = 
our_conn_mock.next_incoming_substream().await { let proto_notif = ProtocolNotification::new( diff --git a/comms/src/protocol/rpc/server/mod.rs b/comms/src/protocol/rpc/server/mod.rs index b2b7dbf76f..88fdb7ee61 100644 --- a/comms/src/protocol/rpc/server/mod.rs +++ b/comms/src/protocol/rpc/server/mod.rs @@ -53,14 +53,19 @@ use crate::{ protocol::{ProtocolEvent, ProtocolId, ProtocolNotification, ProtocolNotificationRx}, Bytes, }; -use futures::{channel::mpsc, AsyncRead, AsyncWrite, SinkExt, StreamExt}; +use futures::SinkExt; use prost::Message; use std::{ borrow::Cow, future::Future, time::{Duration, Instant}, }; -use tokio::time; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::mpsc, + time, +}; +use tokio_stream::StreamExt; use tower::Service; use tower_make::MakeService; use tracing::{debug, error, instrument, span, trace, warn, Instrument, Level}; @@ -198,7 +203,7 @@ pub(super) struct PeerRpcServer { service: TSvc, protocol_notifications: Option>, comms_provider: TCommsProvider, - request_rx: Option>, + request_rx: mpsc::Receiver, } impl PeerRpcServer @@ -233,7 +238,7 @@ where service, protocol_notifications: Some(protocol_notifications), comms_provider, - request_rx: Some(request_rx), + request_rx, } } @@ -243,24 +248,19 @@ where .take() .expect("PeerRpcServer initialized without protocol_notifications"); - let mut requests = self - .request_rx - .take() - .expect("PeerRpcServer initialized without request_rx"); - loop { - futures::select! { - maybe_notif = protocol_notifs.next() => { - match maybe_notif { - Some(notif) => self.handle_protocol_notification(notif).await?, - // No more protocol notifications to come, so we're done - None => break, - } - } - - req = requests.select_next_some() => { + tokio::select! 
{ + maybe_notif = protocol_notifs.recv() => { + match maybe_notif { + Some(notif) => self.handle_protocol_notification(notif).await?, + // No more protocol notifications to come, so we're done + None => break, + } + } + + Some(req) = self.request_rx.recv() => { self.handle_request(req).await; - }, + }, } } diff --git a/comms/src/protocol/rpc/server/router.rs b/comms/src/protocol/rpc/server/router.rs index 9d03c6535d..1d40988075 100644 --- a/comms/src/protocol/rpc/server/router.rs +++ b/comms/src/protocol/rpc/server/router.rs @@ -44,14 +44,15 @@ use crate::{ Bytes, }; use futures::{ - channel::mpsc, future::BoxFuture, task::{Context, Poll}, - AsyncRead, - AsyncWrite, FutureExt, }; use std::sync::Arc; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::mpsc, +}; use tower::Service; use tower_make::MakeService; @@ -329,7 +330,7 @@ mod test { } } - #[runtime::test_basic] + #[runtime::test] async fn find_route() { let server = RpcServer::new(); let mut router = Router::new(server, HelloService).add_service(GoodbyeService); diff --git a/comms/src/protocol/rpc/test/client_pool.rs b/comms/src/protocol/rpc/test/client_pool.rs index e1eb957d5f..d95e1d22b6 100644 --- a/comms/src/protocol/rpc/test/client_pool.rs +++ b/comms/src/protocol/rpc/test/client_pool.rs @@ -39,13 +39,13 @@ use crate::{ runtime::task, test_utils::mocks::{new_peer_connection_mock_pair, PeerConnectionMockState}, }; -use futures::{channel::mpsc, SinkExt}; use tari_shutdown::Shutdown; use tari_test_utils::{async_assert_eventually, unpack_enum}; +use tokio::sync::mpsc; async fn setup(num_concurrent_sessions: usize) -> (PeerConnection, PeerConnectionMockState, Shutdown) { let (conn1, conn1_state, conn2, conn2_state) = new_peer_connection_mock_pair().await; - let (mut notif_tx, notif_rx) = mpsc::channel(1); + let (notif_tx, notif_rx) = mpsc::channel(1); let shutdown = Shutdown::new(); let (context, _) = create_mocked_rpc_context(); @@ -148,15 +148,15 @@ mod lazy_pool { async fn 
it_prunes_disconnected_sessions() { let (conn, mock_state, _shutdown) = setup(2).await; let mut pool = LazyPool::::new(conn, 2, Default::default()); - let mut conn1 = pool.get_least_used_or_connect().await.unwrap(); + let mut client1 = pool.get_least_used_or_connect().await.unwrap(); assert_eq!(mock_state.num_open_substreams(), 1); - let _conn2 = pool.get_least_used_or_connect().await.unwrap(); + let _client2 = pool.get_least_used_or_connect().await.unwrap(); assert_eq!(mock_state.num_open_substreams(), 2); - conn1.close(); - drop(conn1); + client1.close().await; + drop(client1); async_assert_eventually!(mock_state.num_open_substreams(), expect = 1); assert_eq!(pool.refresh_num_active_connections(), 1); - let _conn3 = pool.get_least_used_or_connect().await.unwrap(); + let _client3 = pool.get_least_used_or_connect().await.unwrap(); assert_eq!(pool.refresh_num_active_connections(), 2); assert_eq!(mock_state.num_open_substreams(), 2); } diff --git a/comms/src/protocol/rpc/test/comms_integration.rs b/comms/src/protocol/rpc/test/comms_integration.rs index 9d23088f07..f43f921081 100644 --- a/comms/src/protocol/rpc/test/comms_integration.rs +++ b/comms/src/protocol/rpc/test/comms_integration.rs @@ -37,7 +37,7 @@ use crate::{ use tari_shutdown::Shutdown; use tari_test_utils::unpack_enum; -#[runtime::test_basic] +#[runtime::test] async fn run_service() { let node_identity1 = build_node_identity(Default::default()); let rpc_service = MockRpcService::new(); diff --git a/comms/src/protocol/rpc/test/greeting_service.rs b/comms/src/protocol/rpc/test/greeting_service.rs index 0e190473dd..445099e7c3 100644 --- a/comms/src/protocol/rpc/test/greeting_service.rs +++ b/comms/src/protocol/rpc/test/greeting_service.rs @@ -26,12 +26,16 @@ use crate::{ rpc::{NamedProtocolService, Request, Response, RpcError, RpcServerError, RpcStatus, Streaming}, ProtocolId, }, + utils, }; use core::iter; -use futures::{channel::mpsc, stream, SinkExt, StreamExt}; use std::{sync::Arc, time::Duration}; use 
tari_crypto::tari_utilities::hex::Hex; -use tokio::{sync::RwLock, task, time}; +use tokio::{ + sync::{mpsc, RwLock}, + task, + time, +}; #[async_trait] // #[tari_rpc(protocol_name = "/tari/greeting/1.0", server_struct = GreetingServer, client_struct = GreetingClient)] @@ -91,20 +95,11 @@ impl GreetingRpc for GreetingService { } async fn get_greetings(&self, request: Request) -> Result, RpcStatus> { - let (mut tx, rx) = mpsc::channel(1); + let (tx, rx) = mpsc::channel(1); let num = *request.message(); let greetings = self.greetings[..num as usize].to_vec(); task::spawn(async move { - let iter = greetings.into_iter().map(Ok); - let mut stream = stream::iter(iter) - // "Extra" Result::Ok is to satisfy send_all - .map(Ok); - match tx.send_all(&mut stream).await { - Ok(_) => {}, - Err(_err) => { - // Log error - }, - } + let _ = utils::mpsc::send_all(&tx, greetings.into_iter().map(Ok)).await; }); Ok(Streaming::new(rx)) @@ -118,7 +113,7 @@ impl GreetingRpc for GreetingService { } async fn streaming_error2(&self, _: Request<()>) -> Result, RpcStatus> { - let (mut tx, rx) = mpsc::channel(2); + let (tx, rx) = mpsc::channel(2); tx.send(Ok("This is ok".to_string())).await.unwrap(); tx.send(Err(RpcStatus::bad_request("This is a problem"))).await.unwrap(); @@ -151,7 +146,7 @@ impl SlowGreetingService { impl GreetingRpc for SlowGreetingService { async fn say_hello(&self, _: Request) -> Result, RpcStatus> { let delay = *self.delay.read().await; - time::delay_for(delay).await; + time::sleep(delay).await; Ok(Response::new(SayHelloResponse { greeting: "took a while to load".to_string(), })) @@ -376,8 +371,8 @@ impl GreetingClient { self.inner.ping().await } - pub fn close(&mut self) { - self.inner.close(); + pub async fn close(&mut self) { + self.inner.close().await; } } diff --git a/comms/src/protocol/rpc/test/handshake.rs b/comms/src/protocol/rpc/test/handshake.rs index cdd79746f2..9a21628012 100644 --- a/comms/src/protocol/rpc/test/handshake.rs +++ 
b/comms/src/protocol/rpc/test/handshake.rs @@ -33,7 +33,7 @@ use crate::{ }; use tari_test_utils::unpack_enum; -#[runtime::test_basic] +#[runtime::test] async fn it_performs_the_handshake() { let (client, server) = MemorySocket::new_pair(); @@ -51,7 +51,7 @@ async fn it_performs_the_handshake() { assert!(SUPPORTED_RPC_VERSIONS.contains(&v)); } -#[runtime::test_basic] +#[runtime::test] async fn it_rejects_the_handshake() { let (client, server) = MemorySocket::new_pair(); diff --git a/comms/src/protocol/rpc/test/smoke.rs b/comms/src/protocol/rpc/test/smoke.rs index a762ac4c9c..3149e794e3 100644 --- a/comms/src/protocol/rpc/test/smoke.rs +++ b/comms/src/protocol/rpc/test/smoke.rs @@ -52,12 +52,15 @@ use crate::{ test_utils::node_identity::build_node_identity, NodeIdentity, }; -use futures::{channel::mpsc, future, future::Either, SinkExt, StreamExt}; +use futures::{future, future::Either, StreamExt}; use std::{sync::Arc, time::Duration}; use tari_crypto::tari_utilities::hex::Hex; use tari_shutdown::Shutdown; use tari_test_utils::unpack_enum; -use tokio::{sync::RwLock, task}; +use tokio::{ + sync::{mpsc, RwLock}, + task, +}; pub(super) async fn setup_service( service_impl: T, @@ -85,7 +88,7 @@ pub(super) async fn setup_service( futures::pin_mut!(fut); match future::select(shutdown_signal, fut).await { - Either::Left((r, _)) => r.unwrap(), + Either::Left(_) => {}, Either::Right((r, _)) => r.unwrap(), } } @@ -97,7 +100,7 @@ pub(super) async fn setup( service_impl: T, num_concurrent_sessions: usize, ) -> (MemorySocket, task::JoinHandle<()>, Arc, Shutdown) { - let (mut notif_tx, server_hnd, context, shutdown) = setup_service(service_impl, num_concurrent_sessions).await; + let (notif_tx, server_hnd, context, shutdown) = setup_service(service_impl, num_concurrent_sessions).await; let (inbound, outbound) = MemorySocket::new_pair(); let node_identity = build_node_identity(Default::default()); @@ -114,7 +117,7 @@ pub(super) async fn setup( (outbound, server_hnd, node_identity, 
shutdown) } -#[runtime::test_basic] +#[runtime::test] async fn request_response_errors_and_streaming() { let (socket, server_hnd, node_identity, mut shutdown) = setup(GreetingService::default(), 1).await; @@ -171,7 +174,7 @@ async fn request_response_errors_and_streaming() { let pk_hex = client.get_public_key_hex().await.unwrap(); assert_eq!(pk_hex, node_identity.public_key().to_hex()); - client.close(); + client.close().await; let err = client .say_hello(SayHelloRequest { @@ -181,13 +184,20 @@ async fn request_response_errors_and_streaming() { .await .unwrap_err(); - unpack_enum!(RpcError::ClientClosed = err); + match err { + // Because of the race between closing the request stream and sending on that stream in the above call + // We can either get "this client was closed" or "the request you made was cancelled". + // If we delay some small time, we'll always get the former (but arbitrary delays cause flakiness and should be + // avoided) + RpcError::ClientClosed | RpcError::RequestCancelled => {}, + err => panic!("Unexpected error {:?}", err), + } - shutdown.trigger().unwrap(); + shutdown.trigger(); server_hnd.await.unwrap(); } -#[runtime::test_basic] +#[runtime::test] async fn concurrent_requests() { let (socket, _, _, _shutdown) = setup(GreetingService::default(), 1).await; @@ -227,7 +237,7 @@ async fn concurrent_requests() { assert_eq!(spawned2.await.unwrap(), GreetingService::DEFAULT_GREETINGS[..5]); } -#[runtime::test_basic] +#[runtime::test] async fn response_too_big() { let (socket, _, _, _shutdown) = setup(GreetingService::new(&[]), 1).await; @@ -248,7 +258,7 @@ async fn response_too_big() { let _ = client.reply_with_msg_of_size(max_size as u64).await.unwrap(); } -#[runtime::test_basic] +#[runtime::test] async fn ping_latency() { let (socket, _, _, _shutdown) = setup(GreetingService::new(&[]), 1).await; @@ -261,11 +271,11 @@ async fn ping_latency() { assert!(latency.as_secs() < 5); } -#[runtime::test_basic] +#[runtime::test] async fn 
server_shutdown_before_connect() { let (socket, _, _, mut shutdown) = setup(GreetingService::new(&[]), 1).await; let framed = framing::canonical(socket, 1024); - shutdown.trigger().unwrap(); + shutdown.trigger(); let err = GreetingClient::connect(framed).await.unwrap_err(); assert!(matches!( @@ -274,7 +284,7 @@ async fn server_shutdown_before_connect() { )); } -#[runtime::test_basic] +#[runtime::test] async fn timeout() { let delay = Arc::new(RwLock::new(Duration::from_secs(10))); let (socket, _, _, _shutdown) = setup(SlowGreetingService::new(delay.clone()), 1).await; @@ -298,9 +308,9 @@ async fn timeout() { assert_eq!(resp.greeting, "took a while to load"); } -#[runtime::test_basic] +#[runtime::test] async fn unknown_protocol() { - let (mut notif_tx, _, _, _shutdown) = setup_service(GreetingService::new(&[]), 1).await; + let (notif_tx, _, _, _shutdown) = setup_service(GreetingService::new(&[]), 1).await; let (inbound, socket) = MemorySocket::new_pair(); let node_identity = build_node_identity(Default::default()); @@ -324,7 +334,7 @@ async fn unknown_protocol() { )); } -#[runtime::test_basic] +#[runtime::test] async fn rejected_no_sessions_available() { let (socket, _, _, _shutdown) = setup(GreetingService::new(&[]), 0).await; let framed = framing::canonical(socket, 1024); diff --git a/comms/src/common/rate_limit.rs b/comms/src/rate_limit.rs similarity index 79% rename from comms/src/common/rate_limit.rs rename to comms/src/rate_limit.rs index 1705397d8e..3a06c70040 100644 --- a/comms/src/common/rate_limit.rs +++ b/comms/src/rate_limit.rs @@ -26,7 +26,7 @@ // This is slightly changed from the libra rate limiter implementation -use futures::{stream::Fuse, FutureExt, Stream, StreamExt}; +use futures::FutureExt; use pin_project::pin_project; use std::{ future::Future, @@ -36,10 +36,11 @@ use std::{ time::Duration, }; use tokio::{ - sync::{OwnedSemaphorePermit, Semaphore}, + sync::{AcquireError, OwnedSemaphorePermit, Semaphore}, time, time::Interval, }; +use 
tokio_stream::Stream; pub trait RateLimit: Stream { /// Consumes the stream and returns a rate-limited stream that only polls the underlying stream @@ -60,12 +61,13 @@ pub struct RateLimiter { stream: T, /// An interval stream that "restocks" the permits #[pin] - interval: Fuse, + interval: Interval, /// The maximum permits to issue capacity: usize, /// A semaphore that holds the permits permits: Arc, - permit_future: Option + Send>>>, + #[allow(clippy::type_complexity)] + permit_future: Option> + Send>>>, permit_acquired: bool, } @@ -75,7 +77,7 @@ impl RateLimiter { stream, capacity, - interval: time::interval(restock_interval).fuse(), + interval: time::interval(restock_interval), // `interval` starts immediately, so we can start with zero permits permits: Arc::new(Semaphore::new(0)), permit_future: None, @@ -89,7 +91,7 @@ impl Stream for RateLimiter { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { // "Restock" permits once interval is ready - if let Poll::Ready(Some(_)) = self.as_mut().project().interval.poll_next(cx) { + if self.as_mut().project().interval.poll_tick(cx).is_ready() { self.permits .add_permits(self.capacity - self.permits.available_permits()); } @@ -103,6 +105,8 @@ impl Stream for RateLimiter { } // Wait until a permit is acquired + // `unwrap()` is safe because acquire_owned only panics if the semaphore has closed, but we never close it + // for the lifetime of this instance let permit = futures::ready!(self .as_mut() .project() @@ -110,7 +114,8 @@ impl Stream for RateLimiter { .as_mut() .unwrap() .as_mut() - .poll(cx)); + .poll(cx)) + .unwrap(); // Don't release the permit on drop, `interval` will restock permits permit.forget(); let this = self.as_mut().project(); @@ -130,45 +135,54 @@ impl Stream for RateLimiter { mod test { use super::*; use crate::runtime; - use futures::{future::Either, stream}; + use futures::{stream, StreamExt}; - #[runtime::test_basic] + #[runtime::test] async fn rate_limit() { let repeater = 
stream::repeat(()); - let mut rate_limited = repeater.rate_limit(10, Duration::from_secs(100)).fuse(); + let mut rate_limited = repeater.rate_limit(10, Duration::from_secs(100)); - let mut timeout = time::delay_for(Duration::from_millis(50)).fuse(); + let timeout = time::sleep(Duration::from_millis(50)); + tokio::pin!(timeout); let mut count = 0usize; loop { - let either = futures::future::select(rate_limited.select_next_some(), timeout).await; - match either { - Either::Left((_, to)) => { + let item = tokio::select! { + biased; + _ = &mut timeout => None, + item = rate_limited.next() => item, + }; + + match item { + Some(_) => { count += 1; - timeout = to; }, - Either::Right(_) => break, + None => break, } } assert_eq!(count, 10); } - #[runtime::test_basic] + #[runtime::test] async fn rate_limit_restock() { let repeater = stream::repeat(()); - let mut rate_limited = repeater.rate_limit(10, Duration::from_millis(10)).fuse(); + let mut rate_limited = repeater.rate_limit(10, Duration::from_millis(10)); - let mut timeout = time::delay_for(Duration::from_millis(50)).fuse(); + let timeout = time::sleep(Duration::from_millis(50)); + tokio::pin!(timeout); let mut count = 0usize; loop { - let either = futures::future::select(rate_limited.select_next_some(), timeout).await; - match either { - Either::Left((_, to)) => { + let item = tokio::select! { + biased; + _ = &mut timeout => None, + item = rate_limited.next() => item, + }; + match item { + Some(_) => { count += 1; - timeout = to; }, - Either::Right(_) => break, + None => break, } } // Test that at least 1 restock happens. diff --git a/comms/src/runtime.rs b/comms/src/runtime.rs index a6f7e615f7..48752c05d0 100644 --- a/comms/src/runtime.rs +++ b/comms/src/runtime.rs @@ -25,10 +25,7 @@ use tokio::runtime; // Re-export pub use tokio::{runtime::Handle, task}; -#[cfg(test)] -pub use tokio_macros::test; -#[cfg(test)] -pub use tokio_macros::test_basic; +pub use tokio::test; /// Return the current tokio executor. 
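The `RateLimiter` hunks above swap a fused `Interval` stream for `poll_tick` and top the semaphore back up to `capacity` on every tick (`add_permits(capacity - available_permits())`). A minimal synchronous token-bucket sketch of that restock pattern — `TokenBucket` is an illustrative name, not a type from the patch, and std's `Instant` stands in for tokio's `Interval`:

```rust
use std::time::{Duration, Instant};

/// Token bucket mirroring the RateLimiter's restock logic: at most
/// `capacity` acquisitions per interval; on restock the bucket is refilled
/// to exactly `capacity`, never beyond it.
struct TokenBucket {
    capacity: usize,
    available: usize,
    restock_interval: Duration,
    last_restock: Instant,
}

impl TokenBucket {
    fn new(capacity: usize, restock_interval: Duration) -> Self {
        Self {
            capacity,
            available: capacity,
            restock_interval,
            last_restock: Instant::now(),
        }
    }

    fn try_acquire(&mut self) -> bool {
        // Restock to full when the interval has elapsed, the synchronous
        // analogue of `add_permits(capacity - available_permits())` on tick.
        if self.last_restock.elapsed() >= self.restock_interval {
            self.available = self.capacity;
            self.last_restock = Instant::now();
        }
        if self.available > 0 {
            self.available -= 1;
            true
        } else {
            false
        }
    }
}
```

Refilling to `capacity` rather than adding a fixed amount is what makes unused permits non-cumulative, which is why the `rate_limit` test above sees exactly 10 items before its timeout fires.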
Panics if the tokio runtime is not started. #[inline] diff --git a/comms/src/socks/client.rs b/comms/src/socks/client.rs index f155f2f420..7580c1dbcf 100644 --- a/comms/src/socks/client.rs +++ b/comms/src/socks/client.rs @@ -23,7 +23,6 @@ // Acknowledgement to @sticnarf for tokio-socks on which this code is based use super::error::SocksError; use data_encoding::BASE32; -use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use multiaddr::{Multiaddr, Protocol}; use std::{ borrow::Cow, @@ -31,6 +30,7 @@ use std::{ fmt::Formatter, net::{Ipv4Addr, Ipv6Addr}, }; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; pub type Result = std::result::Result; @@ -104,7 +104,7 @@ where TSocket: AsyncRead + AsyncWrite + Unpin /// Connects to a address through a SOCKS5 proxy and returns the 'upgraded' socket. This consumes the /// `Socks5Client` as once connected, the socks protocol does not recognise any further commands. - #[tracing::instrument(name = "socks::connect", skip(self), err)] + #[tracing::instrument(name = "socks::connect", skip(self))] pub async fn connect(mut self, address: &Multiaddr) -> Result<(TSocket, Multiaddr)> { let address = self.execute_command(Command::Connect, address).await?; Ok((self.protocol.socket, address)) @@ -112,7 +112,7 @@ where TSocket: AsyncRead + AsyncWrite + Unpin /// Requests the tor proxy to resolve a DNS address is resolved into an IP address. /// This operation only works with the tor SOCKS proxy. - #[tracing::instrument(name = "socks:tor_resolve", skip(self), err)] + #[tracing::instrument(name = "socks:tor_resolve", skip(self))] pub async fn tor_resolve(&mut self, address: &Multiaddr) -> Result { // Tor resolve does not return the port back let (dns, rest) = multiaddr_split_first(&address); @@ -126,7 +126,7 @@ where TSocket: AsyncRead + AsyncWrite + Unpin /// Requests the tor proxy to reverse resolve an IP address into a DNS address if it is able. /// This operation only works with the tor SOCKS proxy. 
- #[tracing::instrument(name = "socks::tor_resolve_ptr", skip(self), err)] + #[tracing::instrument(name = "socks::tor_resolve_ptr", skip(self))] pub async fn tor_resolve_ptr(&mut self, address: &Multiaddr) -> Result { self.execute_command(Command::TorResolvePtr, address).await } diff --git a/comms/src/test_utils/mocks/connection_manager.rs b/comms/src/test_utils/mocks/connection_manager.rs index cc489af60e..ece7224f44 100644 --- a/comms/src/test_utils/mocks/connection_manager.rs +++ b/comms/src/test_utils/mocks/connection_manager.rs @@ -31,7 +31,6 @@ use crate::{ peer_manager::NodeId, runtime::task, }; -use futures::{channel::mpsc, lock::Mutex, stream::Fuse, StreamExt}; use std::{ collections::HashMap, sync::{ @@ -39,14 +38,14 @@ use std::{ Arc, }, }; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc, Mutex}; pub fn create_connection_manager_mock() -> (ConnectionManagerRequester, ConnectionManagerMock) { let (tx, rx) = mpsc::channel(10); let (event_tx, _) = broadcast::channel(10); ( ConnectionManagerRequester::new(tx, event_tx.clone()), - ConnectionManagerMock::new(rx.fuse(), event_tx), + ConnectionManagerMock::new(rx, event_tx), ) } @@ -97,13 +96,13 @@ impl ConnectionManagerMockState { } pub struct ConnectionManagerMock { - receiver: Fuse>, + receiver: mpsc::Receiver, state: ConnectionManagerMockState, } impl ConnectionManagerMock { pub fn new( - receiver: Fuse>, + receiver: mpsc::Receiver, event_tx: broadcast::Sender>, ) -> Self { Self { @@ -121,7 +120,7 @@ impl ConnectionManagerMock { } pub async fn run(mut self) { - while let Some(req) = self.receiver.next().await { + while let Some(req) = self.receiver.recv().await { self.handle_request(req).await; } } diff --git a/comms/src/test_utils/mocks/connectivity_manager.rs b/comms/src/test_utils/mocks/connectivity_manager.rs index 122a60127b..78394f428b 100644 --- a/comms/src/test_utils/mocks/connectivity_manager.rs +++ b/comms/src/test_utils/mocks/connectivity_manager.rs @@ -22,32 +22,36 @@ use crate::{ 
connection_manager::{ConnectionManagerError, PeerConnection}, - connectivity::{ConnectivityEvent, ConnectivityRequest, ConnectivityRequester, ConnectivityStatus}, + connectivity::{ + ConnectivityEvent, + ConnectivityEventTx, + ConnectivityRequest, + ConnectivityRequester, + ConnectivityStatus, + }, peer_manager::NodeId, runtime::task, }; -use futures::{ - channel::{mpsc, oneshot}, - lock::Mutex, - stream::Fuse, - StreamExt, -}; +use futures::lock::Mutex; use std::{collections::HashMap, sync::Arc, time::Duration}; -use tokio::{sync::broadcast, time}; +use tokio::{ + sync::{broadcast, mpsc, oneshot}, + time, +}; pub fn create_connectivity_mock() -> (ConnectivityRequester, ConnectivityManagerMock) { let (tx, rx) = mpsc::channel(10); let (event_tx, _) = broadcast::channel(10); ( ConnectivityRequester::new(tx, event_tx.clone()), - ConnectivityManagerMock::new(rx.fuse(), event_tx), + ConnectivityManagerMock::new(rx, event_tx), ) } #[derive(Debug, Clone)] pub struct ConnectivityManagerMockState { inner: Arc>, - event_tx: broadcast::Sender>, + event_tx: ConnectivityEventTx, } #[derive(Debug, Default)] @@ -61,7 +65,7 @@ struct State { } impl ConnectivityManagerMockState { - pub fn new(event_tx: broadcast::Sender>) -> Self { + pub fn new(event_tx: ConnectivityEventTx) -> Self { Self { event_tx, inner: Default::default(), @@ -132,7 +136,7 @@ impl ConnectivityManagerMockState { count, self.call_count().await ); - time::delay_for(Duration::from_millis(100)).await; + time::sleep(Duration::from_millis(100)).await; } } @@ -156,9 +160,8 @@ impl ConnectivityManagerMockState { .await } - #[allow(dead_code)] pub fn publish_event(&self, event: ConnectivityEvent) { - self.event_tx.send(Arc::new(event)).unwrap(); + self.event_tx.send(event).unwrap(); } pub(self) async fn with_state(&self, f: F) -> R @@ -169,15 +172,12 @@ impl ConnectivityManagerMockState { } pub struct ConnectivityManagerMock { - receiver: Fuse>, + receiver: mpsc::Receiver, state: ConnectivityManagerMockState, } impl 
ConnectivityManagerMock { - pub fn new( - receiver: Fuse>, - event_tx: broadcast::Sender>, - ) -> Self { + pub fn new(receiver: mpsc::Receiver, event_tx: ConnectivityEventTx) -> Self { Self { receiver, state: ConnectivityManagerMockState::new(event_tx), @@ -195,7 +195,7 @@ impl ConnectivityManagerMock { } pub async fn run(mut self) { - while let Some(req) = self.receiver.next().await { + while let Some(req) = self.receiver.recv().await { self.handle_request(req).await; } } diff --git a/comms/src/test_utils/mocks/peer_connection.rs b/comms/src/test_utils/mocks/peer_connection.rs index 13b0c77bbf..bccdec9575 100644 --- a/comms/src/test_utils/mocks/peer_connection.rs +++ b/comms/src/test_utils/mocks/peer_connection.rs @@ -34,15 +34,18 @@ use crate::{ peer_manager::{NodeId, Peer, PeerFeatures}, test_utils::{node_identity::build_node_identity, transport}, }; -use futures::{channel::mpsc, lock::Mutex, StreamExt}; use std::sync::{ atomic::{AtomicUsize, Ordering}, Arc, }; -use tokio::runtime::Handle; +use tokio::{ + runtime::Handle, + sync::{mpsc, Mutex}, +}; +use tokio_stream::StreamExt; pub fn create_dummy_peer_connection(node_id: NodeId) -> (PeerConnection, mpsc::Receiver) { - let (tx, rx) = mpsc::channel(0); + let (tx, rx) = mpsc::channel(1); ( PeerConnection::new( 1, @@ -114,7 +117,7 @@ pub async fn new_peer_connection_mock_pair() -> ( create_peer_connection_mock_pair(peer1, peer2).await } -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct PeerConnectionMockState { call_count: Arc, mux_control: Arc>, @@ -181,7 +184,7 @@ impl PeerConnectionMock { } pub async fn run(mut self) { - while let Some(req) = self.receiver.next().await { + while let Some(req) = self.receiver.recv().await { self.handle_request(req).await; } } diff --git a/comms/src/test_utils/test_node.rs b/comms/src/test_utils/test_node.rs index a150a765f8..3e5d7229a0 100644 --- a/comms/src/test_utils/test_node.rs +++ b/comms/src/test_utils/test_node.rs @@ -29,12 +29,14 @@ use crate::{ protocol::Protocols, 
transports::Transport, }; -use futures::{channel::mpsc, AsyncRead, AsyncWrite}; use rand::rngs::OsRng; use std::{sync::Arc, time::Duration}; use tari_shutdown::ShutdownSignal; use tari_storage::HashmapDatabase; -use tokio::sync::broadcast; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{broadcast, mpsc}, +}; #[derive(Clone, Debug)] pub struct TestNodeConfig { diff --git a/comms/src/tor/control_client/client.rs b/comms/src/tor/control_client/client.rs index 5d14c770b7..573bce60c0 100644 --- a/comms/src/tor/control_client/client.rs +++ b/comms/src/tor/control_client/client.rs @@ -34,10 +34,13 @@ use crate::{ tor::control_client::{event::TorControlEvent, monitor::spawn_monitor}, transports::{TcpTransport, Transport}, }; -use futures::{channel::mpsc, AsyncRead, AsyncWrite, SinkExt, StreamExt}; use log::*; use std::{borrow::Cow, fmt, fmt::Display, num::NonZeroU16}; -use tokio::sync::broadcast; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{broadcast, mpsc}, +}; +use tokio_stream::wrappers::BroadcastStream; /// Client for the Tor control port. 
/// @@ -80,8 +83,8 @@ impl TorControlPortClient { &self.event_tx } - pub fn get_event_stream(&self) -> broadcast::Receiver { - self.event_tx.subscribe() + pub fn get_event_stream(&self) -> BroadcastStream { + BroadcastStream::new(self.event_tx.subscribe()) } /// Authenticate with the tor control port @@ -232,8 +235,7 @@ impl TorControlPortClient { } async fn receive_line(&mut self) -> Result { - let line = self.output_stream.next().await.ok_or(TorClientError::UnexpectedEof)?; - + let line = self.output_stream.recv().await.ok_or(TorClientError::UnexpectedEof)?; Ok(line) } } @@ -273,9 +275,11 @@ mod test { runtime, tor::control_client::{test_server, test_server::canned_responses, types::PrivateKey}, }; - use futures::{future, AsyncWriteExt}; + use futures::future; use std::net::SocketAddr; use tari_test_utils::unpack_enum; + use tokio::io::AsyncWriteExt; + use tokio_stream::StreamExt; async fn setup_test() -> (TorControlPortClient, test_server::State) { let (_, mock_state, socket) = test_server::spawn().await; @@ -298,7 +302,7 @@ mod test { let _out_sock = result_out.unwrap(); let (mut in_sock, _) = result_in.unwrap().unwrap(); in_sock.write(b"test123").await.unwrap(); - in_sock.close().await.unwrap(); + in_sock.shutdown().await.unwrap(); } #[runtime::test] diff --git a/comms/src/tor/control_client/monitor.rs b/comms/src/tor/control_client/monitor.rs index a5191b466d..72eb6b88ef 100644 --- a/comms/src/tor/control_client/monitor.rs +++ b/comms/src/tor/control_client/monitor.rs @@ -21,11 +21,14 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
use super::{event::TorControlEvent, parsers, response::ResponseLine, LOG_TARGET}; -use crate::{compat::IoCompat, runtime::task}; -use futures::{channel::mpsc, future, future::Either, AsyncRead, AsyncWrite, SinkExt, Stream, StreamExt}; +use crate::runtime::task; +use futures::{future::Either, SinkExt, Stream, StreamExt}; use log::*; use std::fmt; -use tokio::sync::broadcast; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{broadcast, mpsc}, +}; use tokio_util::codec::{Framed, LinesCodec}; pub fn spawn_monitor( @@ -36,16 +39,19 @@ pub fn spawn_monitor( where TSocket: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - let (mut responses_tx, responses_rx) = mpsc::channel(100); + let (responses_tx, responses_rx) = mpsc::channel(100); task::spawn(async move { - let framed = Framed::new(IoCompat::new(socket), LinesCodec::new()); + let framed = Framed::new(socket, LinesCodec::new()); let (mut sink, mut stream) = framed.split(); loop { - let either = future::select(cmd_rx.next(), stream.next()).await; + let either = tokio::select! { + next = cmd_rx.recv() => Either::Left(next), + next = stream.next() => Either::Right(next), + }; match either { // Received a command to send to the control server - Either::Left((Some(line), _)) => { + Either::Left(Some(line)) => { trace!(target: LOG_TARGET, "Writing command of length '{}'", line.len()); if let Err(err) = sink.send(line).await { error!( @@ -56,7 +62,7 @@ where } }, // Command stream ended - Either::Left((None, _)) => { + Either::Left(None) => { debug!( target: LOG_TARGET, "Tor control server command receiver closed. Monitor is exiting." 
@@ -65,7 +71,7 @@ where }, // Received a line from the control server - Either::Right((Some(Ok(line)), _)) => { + Either::Right(Some(Ok(line))) => { trace!(target: LOG_TARGET, "Read line of length '{}'", line.len()); match parsers::response_line(&line) { Ok(mut line) => { @@ -95,7 +101,7 @@ where }, // Error receiving a line from the control server - Either::Right((Some(Err(err)), _)) => { + Either::Right(Some(Err(err))) => { error!( target: LOG_TARGET, "Line framing error when reading from tor control server: '{:?}'. Monitor is exiting.", err @@ -103,7 +109,7 @@ where break; }, // The control server disconnected - Either::Right((None, _)) => { + Either::Right(None) => { cmd_rx.close(); debug!( target: LOG_TARGET, diff --git a/comms/src/tor/control_client/test_server.rs b/comms/src/tor/control_client/test_server.rs index 5a5e1b3b7c..1741cfc721 100644 --- a/comms/src/tor/control_client/test_server.rs +++ b/comms/src/tor/control_client/test_server.rs @@ -20,13 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
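The monitor hunks above replace `future::select` over two fused streams with a `tokio::select!` that normalises both sources into a `futures::future::Either` before a single `match`. A synchronous sketch of that two-source shape using std channels — `next_event` is an illustrative stand-in, and `try_recv` here approximates `select!`'s "first ready branch" rather than truly blocking on both:

```rust
use std::sync::mpsc;

enum Either<L, R> {
    Left(L),
    Right(R),
}

/// Fold two message sources (commands to write, lines read back) into one
/// Either so a single match can handle all four cases, as the monitor does.
fn next_event(
    cmd_rx: &mpsc::Receiver<String>,
    line_rx: &mpsc::Receiver<String>,
) -> Either<Option<String>, Option<String>> {
    match cmd_rx.try_recv() {
        Ok(cmd) => Either::Left(Some(cmd)),
        // Sender dropped: the command stream has ended.
        Err(mpsc::TryRecvError::Disconnected) => Either::Left(None),
        Err(mpsc::TryRecvError::Empty) => match line_rx.try_recv() {
            Ok(line) => Either::Right(Some(line)),
            _ => Either::Right(None),
        },
    }
}
```

Keeping the `Either` wrapper (rather than matching inside each `select!` arm) means the four-way `match` in the monitor loop survives the migration with only its patterns changed, which is exactly what the diff shows.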
-use crate::{ - compat::IoCompat, - memsocket::MemorySocket, - multiaddr::Multiaddr, - runtime, - test_utils::transport::build_connected_sockets, -}; +use crate::{memsocket::MemorySocket, multiaddr::Multiaddr, runtime, test_utils::transport::build_connected_sockets}; use futures::{lock::Mutex, stream, SinkExt, StreamExt}; use std::sync::Arc; use tokio_util::codec::{Framed, LinesCodec}; @@ -82,7 +76,7 @@ impl TorControlPortTestServer { } pub async fn run(self) { - let mut framed = Framed::new(IoCompat::new(self.socket), LinesCodec::new()); + let mut framed = Framed::new(self.socket, LinesCodec::new()); let state = self.state; while let Some(msg) = framed.next().await { state.request_lines.lock().await.push(msg.unwrap()); diff --git a/comms/src/tor/hidden_service/controller.rs b/comms/src/tor/hidden_service/controller.rs index e19818df58..74b89808fb 100644 --- a/comms/src/tor/hidden_service/controller.rs +++ b/comms/src/tor/hidden_service/controller.rs @@ -214,7 +214,7 @@ impl HiddenServiceController { "Failed to reestablish connection with tor control server because '{:?}'", err ); warn!(target: LOG_TARGET, "Will attempt again in 5 seconds..."); - time::delay_for(Duration::from_secs(5)).await; + time::sleep(Duration::from_secs(5)).await; }, Either::Right(_) => { diff --git a/comms/src/transports/dns/tor.rs b/comms/src/transports/dns/tor.rs index 737971af9a..e198efdc0b 100644 --- a/comms/src/transports/dns/tor.rs +++ b/comms/src/transports/dns/tor.rs @@ -69,7 +69,7 @@ impl DnsResolver for TorDnsResolver { let resolved = match client.tor_resolve(&addr).await { Ok(a) => a, Err(err) => { - error!(target: LOG_TARGET, "{}", err); + error!(target: LOG_TARGET, "Error resolving address: {}", err); return Err(err.into()); }, }; diff --git a/comms/src/transports/memory.rs b/comms/src/transports/memory.rs index 074e728274..8e3a1d7a91 100644 --- a/comms/src/transports/memory.rs +++ b/comms/src/transports/memory.rs @@ -129,7 +129,8 @@ impl Stream for Listener { mod test { use 
super::*; use crate::runtime; - use futures::{future::join, stream::StreamExt, AsyncReadExt, AsyncWriteExt}; + use futures::{future::join, stream::StreamExt}; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; #[runtime::test] async fn simple_listen_and_dial() -> Result<(), ::std::io::Error> { diff --git a/comms/src/transports/mod.rs b/comms/src/transports/mod.rs index 0c04e50a72..75b16db975 100644 --- a/comms/src/transports/mod.rs +++ b/comms/src/transports/mod.rs @@ -24,8 +24,8 @@ // Copyright (c) The Libra Core Contributors // SPDX-License-Identifier: Apache-2.0 -use futures::Stream; use multiaddr::Multiaddr; +use tokio_stream::Stream; mod dns; mod helpers; @@ -37,7 +37,7 @@ mod socks; pub use socks::{SocksConfig, SocksTransport}; mod tcp; -pub use tcp::{TcpSocket, TcpTransport}; +pub use tcp::TcpTransport; mod tcp_with_tor; pub use tcp_with_tor::TcpWithTorTransport; diff --git a/comms/src/transports/socks.rs b/comms/src/transports/socks.rs index 2027f34561..7dc87ef0e4 100644 --- a/comms/src/transports/socks.rs +++ b/comms/src/transports/socks.rs @@ -24,12 +24,13 @@ use crate::{ multiaddr::Multiaddr, socks, socks::Socks5Client, - transports::{dns::SystemDnsResolver, tcp::TcpTransport, TcpSocket, Transport}, + transports::{dns::SystemDnsResolver, tcp::TcpTransport, Transport}, }; -use std::{io, time::Duration}; +use std::io; +use tokio::net::TcpStream; -/// SO_KEEPALIVE setting for the SOCKS TCP connection -const SOCKS_SO_KEEPALIVE: Duration = Duration::from_millis(1500); +// /// SO_KEEPALIVE setting for the SOCKS TCP connection +// const SOCKS_SO_KEEPALIVE: Duration = Duration::from_millis(1500); #[derive(Clone, Debug)] pub struct SocksConfig { @@ -57,7 +58,7 @@ impl SocksTransport { pub fn create_socks_tcp_transport() -> TcpTransport { let mut tcp_transport = TcpTransport::new(); tcp_transport.set_nodelay(true); - tcp_transport.set_keepalive(Some(SOCKS_SO_KEEPALIVE)); + // .set_keepalive(Some(SOCKS_SO_KEEPALIVE)) tcp_transport.set_dns_resolver(SystemDnsResolver); 
tcp_transport } @@ -66,7 +67,7 @@ impl SocksTransport { tcp: TcpTransport, socks_config: SocksConfig, dest_addr: Multiaddr, - ) -> io::Result { + ) -> io::Result { // Create a new connection to the SOCKS proxy let socks_conn = tcp.dial(socks_config.proxy_address).await?; let mut client = Socks5Client::new(socks_conn); diff --git a/comms/src/transports/tcp.rs b/comms/src/transports/tcp.rs index 8112cca203..6b47e7c357 100644 --- a/comms/src/transports/tcp.rs +++ b/comms/src/transports/tcp.rs @@ -25,44 +25,42 @@ use crate::{ transports::dns::{DnsResolverRef, SystemDnsResolver}, utils::multiaddr::socketaddr_to_multiaddr, }; -use futures::{io::Error, ready, AsyncRead, AsyncWrite, Future, FutureExt, Stream}; +use futures::{ready, FutureExt}; use multiaddr::Multiaddr; use std::{ + future::Future, io, pin::Pin, sync::Arc, task::{Context, Poll}, - time::Duration, -}; -use tokio::{ - io::{AsyncRead as TokioAsyncRead, AsyncWrite as TokioAsyncWrite}, - net::{TcpListener, TcpStream}, }; +use tokio::net::{TcpListener, TcpStream}; +use tokio_stream::Stream; /// Transport implementation for TCP #[derive(Clone)] pub struct TcpTransport { - recv_buffer_size: Option, - send_buffer_size: Option, + // recv_buffer_size: Option, + // send_buffer_size: Option, ttl: Option, - #[allow(clippy::option_option)] - keepalive: Option>, + // #[allow(clippy::option_option)] + // keepalive: Option>, nodelay: Option, dns_resolver: DnsResolverRef, } impl TcpTransport { - #[doc("Sets `SO_RCVBUF` i.e the size of the receive buffer.")] - setter_mut!(set_recv_buffer_size, recv_buffer_size, Option); - - #[doc("Sets `SO_SNDBUF` i.e. the size of the send buffer.")] - setter_mut!(set_send_buffer_size, send_buffer_size, Option); + // #[doc("Sets `SO_RCVBUF` i.e the size of the receive buffer.")] + // setter_mut!(set_recv_buffer_size, recv_buffer_size, Option); + // + // #[doc("Sets `SO_SNDBUF` i.e. 
the size of the send buffer.")] + // setter_mut!(set_send_buffer_size, send_buffer_size, Option); #[doc("Sets `IP_TTL` i.e. the TTL of packets sent from this socket.")] setter_mut!(set_ttl, ttl, Option); - #[doc("Sets `SO_KEEPALIVE` i.e. the interval to send keepalive probes, or None to disable.")] - setter_mut!(set_keepalive, keepalive, Option>); + // #[doc("Sets `SO_KEEPALIVE` i.e. the interval to send keepalive probes, or None to disable.")] + // setter_mut!(set_keepalive, keepalive, Option>); #[doc("Sets `TCP_NODELAY` i.e disable Nagle's algorithm if set to true.")] setter_mut!(set_nodelay, nodelay, Option); @@ -81,9 +79,10 @@ impl TcpTransport { /// Apply socket options to `TcpStream`. fn configure(&self, socket: &TcpStream) -> io::Result<()> { - if let Some(keepalive) = self.keepalive { - socket.set_keepalive(keepalive)?; - } + // https://github.com/rust-lang/rust/issues/69774 + // if let Some(keepalive) = self.keepalive { + // socket.set_keepalive(keepalive)?; + // } if let Some(ttl) = self.ttl { socket.set_ttl(ttl)?; @@ -93,13 +92,13 @@ impl TcpTransport { socket.set_nodelay(nodelay)?; } - if let Some(recv_buffer_size) = self.recv_buffer_size { - socket.set_recv_buffer_size(recv_buffer_size)?; - } - - if let Some(send_buffer_size) = self.send_buffer_size { - socket.set_send_buffer_size(send_buffer_size)?; - } + // if let Some(recv_buffer_size) = self.recv_buffer_size { + // socket.set_recv_buffer_size(recv_buffer_size)?; + // } + // + // if let Some(send_buffer_size) = self.send_buffer_size { + // socket.set_send_buffer_size(send_buffer_size)?; + // } Ok(()) } @@ -108,10 +107,10 @@ impl TcpTransport { impl Default for TcpTransport { fn default() -> Self { Self { - recv_buffer_size: None, - send_buffer_size: None, + // recv_buffer_size: None, + // send_buffer_size: None, ttl: None, - keepalive: None, + // keepalive: None, nodelay: None, dns_resolver: Arc::new(SystemDnsResolver), } @@ -122,7 +121,7 @@ impl Default for TcpTransport { impl Transport for 
TcpTransport { type Error = io::Error; type Listener = TcpInbound; - type Output = TcpSocket; + type Output = TcpStream; async fn listen(&self, addr: Multiaddr) -> Result<(Self::Listener, Multiaddr), Self::Error> { let socket_addr = self @@ -161,12 +160,12 @@ impl TcpOutbound { impl Future for TcpOutbound where F: Future> + Unpin { - type Output = io::Result; + type Output = io::Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let socket = ready!(Pin::new(&mut self.future).poll(cx))?; - self.config.configure(&socket)?; - Poll::Ready(Ok(TcpSocket::new(socket))) + let stream = ready!(Pin::new(&mut self.future).poll(cx))?; + self.config.configure(&stream)?; + Poll::Ready(Ok(stream)) } } @@ -184,52 +183,14 @@ impl TcpInbound { } impl Stream for TcpInbound { - type Item = io::Result<(TcpSocket, Multiaddr)>; + type Item = io::Result<(TcpStream, Multiaddr)>; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let (socket, addr) = ready!(self.listener.poll_accept(cx))?; // Configure each socket self.config.configure(&socket)?; let peer_addr = socketaddr_to_multiaddr(&addr); - Poll::Ready(Some(Ok((TcpSocket::new(socket), peer_addr)))) - } -} - -/// TcpSocket is a wrapper struct for tokio `TcpStream` and implements -/// `futures-rs` AsyncRead/Write -pub struct TcpSocket { - inner: TcpStream, -} - -impl TcpSocket { - pub fn new(stream: TcpStream) -> Self { - Self { inner: stream } - } -} - -impl AsyncWrite for TcpSocket { - fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { - Pin::new(&mut self.inner).poll_write(cx, buf) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_flush(cx) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_shutdown(cx) - } -} - -impl AsyncRead for TcpSocket { - fn poll_read(mut 
self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { - Pin::new(&mut self.inner).poll_read(cx, buf) - } -} - -impl From for TcpSocket { - fn from(stream: TcpStream) -> Self { - Self { inner: stream } + Poll::Ready(Some(Ok((socket, peer_addr)))) } } @@ -240,16 +201,15 @@ mod test { #[test] fn configure() { let mut tcp = TcpTransport::new(); - tcp.set_send_buffer_size(123) - .set_recv_buffer_size(456) - .set_nodelay(true) - .set_ttl(789) - .set_keepalive(Some(Duration::from_millis(100))); - - assert_eq!(tcp.send_buffer_size, Some(123)); - assert_eq!(tcp.recv_buffer_size, Some(456)); + // tcp.set_send_buffer_size(123) + // .set_recv_buffer_size(456) + tcp.set_nodelay(true).set_ttl(789); + // .set_keepalive(Some(Duration::from_millis(100))); + + // assert_eq!(tcp.send_buffer_size, Some(123)); + // assert_eq!(tcp.recv_buffer_size, Some(456)); assert_eq!(tcp.nodelay, Some(true)); assert_eq!(tcp.ttl, Some(789)); - assert_eq!(tcp.keepalive, Some(Some(Duration::from_millis(100)))); + // assert_eq!(tcp.keepalive, Some(Some(Duration::from_millis(100)))); } } diff --git a/comms/src/transports/tcp_with_tor.rs b/comms/src/transports/tcp_with_tor.rs index de54e17bb5..be800d9bfb 100644 --- a/comms/src/transports/tcp_with_tor.rs +++ b/comms/src/transports/tcp_with_tor.rs @@ -21,16 +21,10 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
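The hunk above deletes the `TcpSocket` newtype: it existed only to bridge tokio 0.2's `TcpStream` into the futures-rs `AsyncRead`/`AsyncWrite` traits by forwarding every poll method, and tokio 1.x streams implement the needed traits directly. A minimal synchronous sketch of that pure-delegation wrapper shape, using `std::io::Read` in place of the async traits (the `Adapter`/`read_all` names are illustrative, not from the patch):

```rust
use std::io::{self, Read};

// Hypothetical newtype mirroring the removed `TcpSocket` adapter: it owns an
// inner reader and forwards each call unchanged, adding no behaviour of its own.
struct Adapter<R: Read> {
    inner: R,
}

impl<R: Read> Read for Adapter<R> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // Pure delegation, the same shape the wrapper used for poll_read/poll_write.
        self.inner.read(buf)
    }
}

fn read_all<R: Read>(reader: R) -> io::Result<String> {
    let mut adapter = Adapter { inner: reader };
    let mut out = String::new();
    adapter.read_to_string(&mut out)?;
    Ok(out)
}

fn main() {
    let out = read_all(&b"hello"[..]).unwrap();
    assert_eq!(out, "hello");
}
```

Once the inner type implements the trait the caller needs, the wrapper is dead weight, which is exactly why the diff replaces `TcpSocket` with `TcpStream` everywhere.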
use super::Transport; -use crate::transports::{ - dns::TorDnsResolver, - helpers::is_onion_address, - SocksConfig, - SocksTransport, - TcpSocket, - TcpTransport, -}; +use crate::transports::{dns::TorDnsResolver, helpers::is_onion_address, SocksConfig, SocksTransport, TcpTransport}; use multiaddr::Multiaddr; use std::io; +use tokio::net::TcpStream; /// Transport implementation for TCP with Tor support #[derive(Clone, Default)] @@ -69,7 +63,7 @@ impl TcpWithTorTransport { impl Transport for TcpWithTorTransport { type Error = io::Error; type Listener = ::Listener; - type Output = TcpSocket; + type Output = TcpStream; async fn listen(&self, addr: Multiaddr) -> Result<(Self::Listener, Multiaddr), Self::Error> { self.tcp_transport.listen(addr).await diff --git a/comms/src/utils/mod.rs b/comms/src/utils/mod.rs index 697a758c1f..e543e644bb 100644 --- a/comms/src/utils/mod.rs +++ b/comms/src/utils/mod.rs @@ -22,5 +22,6 @@ pub mod cidr; pub mod datetime; +pub mod mpsc; pub mod multiaddr; pub mod signature; diff --git a/comms/src/common/mod.rs b/comms/src/utils/mpsc.rs similarity index 84% rename from comms/src/common/mod.rs rename to comms/src/utils/mpsc.rs index 2ac0cd0a6f..8ded39967f 100644 --- a/comms/src/common/mod.rs +++ b/comms/src/utils/mpsc.rs @@ -1,4 +1,4 @@ -// Copyright 2020, The Tari Project +// Copyright 2021, The Tari Project // // Redistribution and use in source and binary forms, with or without modification, are permitted provided that the // following conditions are met: @@ -20,4 +20,14 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-pub mod rate_limit; +use tokio::sync::mpsc; + +pub async fn send_all>( + sender: &mpsc::Sender, + iter: I, +) -> Result<(), mpsc::error::SendError> { + for item in iter { + sender.send(item).await?; + } + Ok(()) +} diff --git a/comms/tests/greeting_service.rs b/comms/tests/greeting_service.rs index c13ae842e4..e70b6e16ff 100644 --- a/comms/tests/greeting_service.rs +++ b/comms/tests/greeting_service.rs @@ -23,14 +23,14 @@ #![cfg(feature = "rpc")] use core::iter; -use futures::{channel::mpsc, stream, SinkExt, StreamExt}; use std::{cmp, time::Duration}; use tari_comms::{ async_trait, protocol::rpc::{Request, Response, RpcStatus, Streaming}, + utils, }; use tari_comms_rpc_macros::tari_rpc; -use tokio::{task, time}; +use tokio::{sync::mpsc, task, time}; #[tari_rpc(protocol_name = b"t/greeting/1", server_struct = GreetingServer, client_struct = GreetingClient)] pub trait GreetingRpc: Send + Sync + 'static { @@ -85,15 +85,9 @@ impl GreetingRpc for GreetingService { async fn get_greetings(&self, request: Request) -> Result, RpcStatus> { let num = *request.message(); - let (mut tx, rx) = mpsc::channel(num as usize); + let (tx, rx) = mpsc::channel(num as usize); let greetings = self.greetings[..cmp::min(num as usize + 1, self.greetings.len())].to_vec(); - task::spawn(async move { - let iter = greetings.into_iter().map(Ok); - let mut stream = stream::iter(iter) - // "Extra" Result::Ok is to satisfy send_all - .map(Ok); - tx.send_all(&mut stream).await.unwrap(); - }); + task::spawn(async move { utils::mpsc::send_all(&tx, greetings.into_iter().map(Ok)).await }); Ok(Streaming::new(rx)) } @@ -113,7 +107,7 @@ impl GreetingRpc for GreetingService { item_size, num_items, } = request.into_message(); - let (mut tx, rx) = mpsc::channel(10); + let (tx, rx) = mpsc::channel(10); let t = std::time::Instant::now(); task::spawn(async move { let item = iter::repeat(0u8).take(item_size as usize).collect::>(); @@ -136,7 +130,7 @@ impl GreetingRpc for GreetingService { } async fn 
slow_response(&self, request: Request) -> Result, RpcStatus> { - time::delay_for(Duration::from_secs(request.into_message())).await; + time::sleep(Duration::from_secs(request.into_message())).await; Ok(Response::new(())) } } diff --git a/comms/tests/rpc_stress.rs b/comms/tests/rpc_stress.rs index 0376a7dddc..933e158398 100644 --- a/comms/tests/rpc_stress.rs +++ b/comms/tests/rpc_stress.rs @@ -155,6 +155,7 @@ async fn run_stress_test(test_params: Params) { } future::join_all(tasks).await.into_iter().for_each(Result::unwrap); + log::info!("Stress test took {:.2?}", time.elapsed()); } @@ -259,7 +260,7 @@ async fn high_contention_high_concurrency() { .await; } -#[tokio_macros::test] +#[tokio::test] async fn run() { // let _ = env_logger::try_init(); log_timing("quick", quick()).await; diff --git a/comms/tests/substream_stress.rs b/comms/tests/substream_stress.rs index cbae6e8e52..a72eff03f3 100644 --- a/comms/tests/substream_stress.rs +++ b/comms/tests/substream_stress.rs @@ -23,7 +23,7 @@ mod helpers; use helpers::create_comms; -use futures::{channel::mpsc, future, SinkExt, StreamExt}; +use futures::{future, SinkExt, StreamExt}; use std::time::Duration; use tari_comms::{ framing, @@ -35,7 +35,7 @@ use tari_comms::{ }; use tari_shutdown::{Shutdown, ShutdownSignal}; use tari_test_utils::unpack_enum; -use tokio::{task, time::Instant}; +use tokio::{sync::mpsc, task, time::Instant}; const PROTOCOL_NAME: &[u8] = b"test/dummy/protocol"; @@ -79,7 +79,7 @@ async fn run_stress_test(num_substreams: usize, num_iterations: usize, payload_s task::spawn({ let sample = sample.clone(); async move { - while let Some(event) = notif_rx.next().await { + while let Some(event) = notif_rx.recv().await { unpack_enum!(ProtocolEvent::NewInboundSubstream(_n, remote_substream) = event.event); let mut remote_substream = framing::canonical(remote_substream, frame_size); @@ -150,7 +150,7 @@ async fn run_stress_test(num_substreams: usize, num_iterations: usize, payload_s println!("avg t = {}ms", 
avg); } -#[tokio_macros::test] +#[tokio::test] async fn many_at_frame_limit() { const NUM_SUBSTREAMS: usize = 20; const NUM_ITERATIONS_PER_STREAM: usize = 100; diff --git a/infrastructure/shutdown/Cargo.toml b/infrastructure/shutdown/Cargo.toml index 1102ec037a..176bace6e2 100644 --- a/infrastructure/shutdown/Cargo.toml +++ b/infrastructure/shutdown/Cargo.toml @@ -12,7 +12,7 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -futures = "^0.3.1" +futures = "^0.3" [dev-dependencies] -tokio = {version="^0.2", features=["rt-core"]} +tokio = {version="1", default-features = false, features = ["rt", "macros"]} diff --git a/infrastructure/shutdown/src/lib.rs b/infrastructure/shutdown/src/lib.rs index f0054bd843..5afd56a239 100644 --- a/infrastructure/shutdown/src/lib.rs +++ b/infrastructure/shutdown/src/lib.rs @@ -26,14 +26,16 @@ #![deny(unused_must_use)] #![deny(unreachable_patterns)] #![deny(unknown_lints)] -use futures::{ - channel::{oneshot, oneshot::Canceled}, - future::{Fuse, FusedFuture, Shared}, + +pub mod oneshot_trigger; + +use crate::oneshot_trigger::OneshotSignal; +use futures::future::FusedFuture; +use std::{ + future::Future, + pin::Pin, task::{Context, Poll}, - Future, - FutureExt, }; -use std::pin::Pin; /// Trigger for shutdowns. /// @@ -42,71 +44,69 @@ use std::pin::Pin; /// /// _Note_: This will trigger when dropped, so the `Shutdown` instance should be held as /// long as required by the application. 
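The `utils::mpsc::send_all` helper added earlier in this patch replaces the futures-rs `SinkExt::send_all` combinator with a plain loop over `tokio::sync::mpsc::Sender::send`. A synchronous sketch of the same shape using std's channel (no runtime needed; the real helper is `async` and awaits each send):

```rust
use std::sync::mpsc;

// Sketch of the send_all pattern: push every item from an iterator into a
// channel, stopping at the first send error (i.e. the receiver was dropped).
fn send_all<T, I: IntoIterator<Item = T>>(
    sender: &mpsc::Sender<T>,
    iter: I,
) -> Result<(), mpsc::SendError<T>> {
    for item in iter {
        sender.send(item)?;
    }
    Ok(())
}

fn main() {
    let (tx, rx) = mpsc::channel();
    send_all(&tx, 0..5).expect("receiver is alive");
    drop(tx); // close the channel so the draining iterator below terminates
    let received: Vec<i32> = rx.iter().collect();
    assert_eq!(received, vec![0, 1, 2, 3, 4]);
}
```

This is why the greeting-service hunk could drop the `stream::iter(...).map(Ok)` dance: the plain loop needs no `Sink` impl and no extra `Result` wrapping.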
-pub struct Shutdown { - trigger: Option>, - signal: ShutdownSignal, - on_triggered: Option>, -} - +pub struct Shutdown(oneshot_trigger::OneshotTrigger<()>); impl Shutdown { - /// Create a new Shutdown pub fn new() -> Self { - let (tx, rx) = oneshot::channel(); - Self { - trigger: Some(tx), - signal: rx.fuse().shared(), - on_triggered: None, - } + Self(oneshot_trigger::OneshotTrigger::new()) } - /// Set the on_triggered callback - pub fn on_triggered(&mut self, on_trigger: F) -> &mut Self - where F: FnOnce() + Send + Sync + 'static { - self.on_triggered = Some(Box::new(on_trigger)); - self + pub fn trigger(&mut self) { + self.0.broadcast(()); + } + + pub fn is_triggered(&self) -> bool { + self.0.is_used() } - /// Convert this into a ShutdownSignal without consuming the - /// struct. pub fn to_signal(&self) -> ShutdownSignal { - self.signal.clone() + self.0.to_signal().into() } +} - /// Trigger any listening signals - pub fn trigger(&mut self) -> Result<(), ShutdownError> { - match self.trigger.take() { - Some(trigger) => { - trigger.send(()).map_err(|_| ShutdownError)?; +impl Default for Shutdown { + fn default() -> Self { + Self::new() + } +} - if let Some(on_triggered) = self.on_triggered.take() { - on_triggered(); - } +/// Receiver end of a shutdown signal. Once received the consumer should shut down. +#[derive(Debug, Clone)] +pub struct ShutdownSignal(oneshot_trigger::OneshotSignal<()>); - Ok(()) - }, - None => Ok(()), - } +impl ShutdownSignal { + pub fn is_triggered(&self) -> bool { + self.0.is_terminated() } - pub fn is_triggered(&self) -> bool { - self.trigger.is_none() + /// Wait for the shutdown signal to trigger. 
+ pub fn wait(&mut self) -> &mut Self { + self } } -impl Drop for Shutdown { - fn drop(&mut self) { - let _ = self.trigger(); +impl Future for ShutdownSignal { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match Pin::new(&mut self.0).poll(cx) { + // Whether `trigger()` was called Some(()), or the Shutdown dropped (None) we want to resolve this future + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Poll::Pending, + } } } -impl Default for Shutdown { - fn default() -> Self { - Self::new() +impl FusedFuture for ShutdownSignal { + fn is_terminated(&self) -> bool { + self.0.is_terminated() } } -/// Receiver end of a shutdown signal. Once received the consumer should shut down. -pub type ShutdownSignal = Shared>>; +impl From> for ShutdownSignal { + fn from(inner: OneshotSignal<()>) -> Self { + Self(inner) + } +} #[derive(Debug, Clone, Default)] pub struct OptionalShutdownSignal(Option); @@ -137,11 +137,11 @@ impl OptionalShutdownSignal { } impl Future for OptionalShutdownSignal { - type Output = Result<(), Canceled>; + type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.0.as_mut() { - Some(inner) => inner.poll_unpin(cx), + Some(inner) => Pin::new(inner).poll(cx), None => Poll::Pending, } } @@ -165,73 +165,50 @@ impl FusedFuture for OptionalShutdownSignal { } } -#[derive(Debug)] -pub struct ShutdownError; - #[cfg(test)] mod test { use super::*; - use std::sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }; - use tokio::runtime::Runtime; - - #[test] - fn trigger() { - let rt = Runtime::new().unwrap(); + use tokio::task; + + #[tokio::test] + async fn trigger() { let mut shutdown = Shutdown::new(); let signal = shutdown.to_signal(); assert!(!shutdown.is_triggered()); - rt.spawn(async move { - signal.await.unwrap(); + let fut = task::spawn(async move { + signal.await; }); - shutdown.trigger().unwrap(); + shutdown.trigger(); + assert!(shutdown.is_triggered()); // 
Shutdown::trigger is idempotent - shutdown.trigger().unwrap(); + shutdown.trigger(); assert!(shutdown.is_triggered()); + fut.await.unwrap(); } - #[test] - fn signal_clone() { - let rt = Runtime::new().unwrap(); + #[tokio::test] + async fn signal_clone() { let mut shutdown = Shutdown::new(); let signal = shutdown.to_signal(); let signal_clone = signal.clone(); - rt.spawn(async move { - signal_clone.await.unwrap(); - signal.await.unwrap(); + let fut = task::spawn(async move { + signal_clone.await; + signal.await; }); - shutdown.trigger().unwrap(); + shutdown.trigger(); + fut.await.unwrap(); } - #[test] - fn drop_trigger() { - let rt = Runtime::new().unwrap(); + #[tokio::test] + async fn drop_trigger() { let shutdown = Shutdown::new(); let signal = shutdown.to_signal(); let signal_clone = signal.clone(); - rt.spawn(async move { - signal_clone.await.unwrap(); - signal.await.unwrap(); + let fut = task::spawn(async move { + signal_clone.await; + signal.await; }); drop(shutdown); - } - - #[test] - fn on_trigger() { - let rt = Runtime::new().unwrap(); - let spy = Arc::new(AtomicBool::new(false)); - let spy_clone = Arc::clone(&spy); - let mut shutdown = Shutdown::new(); - shutdown.on_triggered(move || { - spy_clone.store(true, Ordering::SeqCst); - }); - let signal = shutdown.to_signal(); - rt.spawn(async move { - signal.await.unwrap(); - }); - shutdown.trigger().unwrap(); - assert!(spy.load(Ordering::SeqCst)); + fut.await.unwrap(); } } diff --git a/infrastructure/shutdown/src/oneshot_trigger.rs b/infrastructure/shutdown/src/oneshot_trigger.rs new file mode 100644 index 0000000000..7d2ee5b46a --- /dev/null +++ b/infrastructure/shutdown/src/oneshot_trigger.rs @@ -0,0 +1,106 @@ +// Copyright 2021, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. 
Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
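The refactored `Shutdown` above keeps its old contract with much less machinery: the oneshot sender sits in an `Option`, `trigger()` takes it (so repeat calls are no-ops), `is_triggered()` checks whether it was taken, and dropping the struct drops the sender so waiting signals still resolve. A std-only sketch of that shape with a single receiver standing in for `ShutdownSignal` (the real signal is a cloneable `Shared` future):

```rust
use std::sync::mpsc;

// Minimal synchronous analogue of the patch's Shutdown/OneshotTrigger shape.
struct Shutdown {
    sender: Option<mpsc::Sender<()>>,
}

impl Shutdown {
    fn new() -> (Self, mpsc::Receiver<()>) {
        let (tx, rx) = mpsc::channel();
        (Self { sender: Some(tx) }, rx)
    }

    /// Idempotent: the sender is taken on the first call, later calls are no-ops.
    fn trigger(&mut self) {
        if let Some(tx) = self.sender.take() {
            let _ = tx.send(());
        }
    }

    fn is_triggered(&self) -> bool {
        self.sender.is_none()
    }
}

fn main() {
    let (mut shutdown, signal) = Shutdown::new();
    assert!(!shutdown.is_triggered());
    shutdown.trigger();
    shutdown.trigger(); // idempotent, as the `trigger` test in the patch asserts
    assert!(shutdown.is_triggered());
    // In the real code the signal resolves whether trigger() was called or the
    // Shutdown was dropped; here an explicit trigger yields a buffered message.
    assert!(signal.recv().is_ok());
}
```

Collapsing `Result<(), ShutdownError>` into a plain `()` return is what lets all the `shutdown.trigger().unwrap()` call sites in the tests become bare `shutdown.trigger()`.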
+ +use futures::{ + channel::{oneshot, oneshot::Receiver}, + future::{FusedFuture, Shared}, + FutureExt, +}; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +pub fn channel() -> OneshotTrigger { + OneshotTrigger::new() +} + +pub struct OneshotTrigger { + sender: Option>, + signal: OneshotSignal, +} + +impl OneshotTrigger { + pub fn new() -> Self { + let (tx, rx) = oneshot::channel(); + Self { + sender: Some(tx), + signal: rx.shared().into(), + } + } + + pub fn to_signal(&self) -> OneshotSignal { + self.signal.clone() + } + + pub fn broadcast(&mut self, item: T) { + if let Some(tx) = self.sender.take() { + let _ = tx.send(item); + } + } + + pub fn is_used(&self) -> bool { + self.sender.is_none() + } +} + +impl Default for OneshotTrigger { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug, Clone)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct OneshotSignal { + inner: Shared>, +} + +impl From>> for OneshotSignal { + fn from(inner: Shared>) -> Self { + Self { inner } + } +} + +impl Future for OneshotSignal { + type Output = Option; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if self.inner.is_terminated() { + return Poll::Ready(None); + } + + match Pin::new(&mut self.inner).poll(cx) { + Poll::Ready(Ok(v)) => Poll::Ready(Some(v)), + // Channel canceled + Poll::Ready(Err(_)) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +impl FusedFuture for OneshotSignal { + fn is_terminated(&self) -> bool { + self.inner.is_terminated() + } +} diff --git a/infrastructure/storage/Cargo.toml b/infrastructure/storage/Cargo.toml index 5db74c5097..3915bde787 100644 --- a/infrastructure/storage/Cargo.toml +++ b/infrastructure/storage/Cargo.toml @@ -13,13 +13,13 @@ edition = "2018" bincode = "1.1" log = "0.4.0" lmdb-zero = "0.4.4" -thiserror = "1.0.20" +thiserror = "1.0.26" rmp = "0.8.7" rmp-serde = "0.13.7" serde = "1.0.80" serde_derive = "1.0.80" tari_utilities = "^0.3" 
-bytes = "0.4.12" +bytes = "0.5" [dev-dependencies] rand = "0.8" diff --git a/infrastructure/test_utils/Cargo.toml b/infrastructure/test_utils/Cargo.toml index bee18392c9..c2ea538ffb 100644 --- a/infrastructure/test_utils/Cargo.toml +++ b/infrastructure/test_utils/Cargo.toml @@ -10,9 +10,10 @@ license = "BSD-3-Clause" [dependencies] tari_shutdown = {version="*", path="../shutdown"} + futures-test = { version = "^0.3.1" } futures = {version= "^0.3.1"} rand = "0.8" -tokio = {version= "0.2.10", features=["rt-threaded", "time", "io-driver"]} +tokio = {version= "1.10", features=["rt-multi-thread", "time"]} lazy_static = "1.3.0" tempfile = "3.1.0" diff --git a/infrastructure/test_utils/src/futures/async_assert_eventually.rs b/infrastructure/test_utils/src/futures/async_assert_eventually.rs index cddfb29b37..f9a1f3ee9d 100644 --- a/infrastructure/test_utils/src/futures/async_assert_eventually.rs +++ b/infrastructure/test_utils/src/futures/async_assert_eventually.rs @@ -46,7 +46,7 @@ macro_rules! async_assert_eventually { $max_attempts ); } - tokio::time::delay_for($interval).await; + tokio::time::sleep($interval).await; value = $check_expr; } }}; @@ -82,7 +82,7 @@ macro_rules! async_assert { $max_attempts ); } - tokio::time::delay_for($interval).await; + tokio::time::sleep($interval).await; } }}; ($check_expr:expr$(,)?) 
=> {{ diff --git a/infrastructure/test_utils/src/runtime.rs b/infrastructure/test_utils/src/runtime.rs index 7b8c5faa57..22ee5962a6 100644 --- a/infrastructure/test_utils/src/runtime.rs +++ b/infrastructure/test_utils/src/runtime.rs @@ -26,12 +26,8 @@ use tari_shutdown::Shutdown; use tokio::{runtime, runtime::Runtime, task, task::JoinError}; pub fn create_runtime() -> Runtime { - tokio::runtime::Builder::new() - .threaded_scheduler() - .enable_io() - .enable_time() - .max_threads(8) - .core_threads(4) + tokio::runtime::Builder::new_multi_thread() + .enable_all() .build() .expect("Could not create runtime") } @@ -49,7 +45,7 @@ where F: Future + Send + 'static { /// Create a runtime and report if it panics. If there are tasks still running after the panic, this /// will carry on running forever. -// #[deprecated(note = "use tokio_macros::test instead")] +// #[deprecated(note = "use tokio::test instead")] pub fn test_async(f: F) where F: FnOnce(&mut TestRuntime) { let mut rt = TestRuntime::from(create_runtime()); diff --git a/infrastructure/test_utils/src/streams/mod.rs b/infrastructure/test_utils/src/streams/mod.rs index a70e588f7e..177bea14b1 100644 --- a/infrastructure/test_utils/src/streams/mod.rs +++ b/infrastructure/test_utils/src/streams/mod.rs @@ -20,8 +20,9 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::{Stream, StreamExt}; -use std::{collections::HashMap, hash::Hash, time::Duration}; +use futures::{stream, Stream, StreamExt}; +use std::{borrow::BorrowMut, collections::HashMap, hash::Hash, time::Duration}; +use tokio::sync::{broadcast, mpsc}; #[allow(dead_code)] #[allow(clippy::mutable_key_type)] // Note: Clippy Breaks with Interior Mutability Error @@ -54,7 +55,6 @@ where #[macro_export] macro_rules! collect_stream { ($stream:expr, take=$take:expr, timeout=$timeout:expr $(,)?) 
=> {{ - use futures::{Stream, StreamExt}; use tokio::time; // Evaluate $stream once, NOT in the loop 🐛🚨 @@ -62,14 +62,17 @@ macro_rules! collect_stream { let mut items = Vec::new(); loop { - if let Some(item) = time::timeout($timeout, stream.next()).await.expect( - format!( - "Timeout before stream could collect {} item(s). Got {} item(s).", - $take, - items.len() + if let Some(item) = time::timeout($timeout, futures::stream::StreamExt::next(stream)) + .await + .expect( + format!( + "Timeout before stream could collect {} item(s). Got {} item(s).", + $take, + items.len() + ) + .as_str(), ) - .as_str(), - ) { + { items.push(item); if items.len() == $take { break items; @@ -80,11 +83,86 @@ macro_rules! collect_stream { } }}; ($stream:expr, timeout=$timeout:expr $(,)?) => {{ - use futures::StreamExt; use tokio::time; + let mut stream = &mut $stream; + let mut items = Vec::new(); + while let Some(item) = time::timeout($timeout, futures::stream::StreamExt::next($stream)) + .await + .expect(format!("Timeout before stream was closed. Got {} items.", items.len()).as_str()) + { + items.push(item); + } + items + }}; +} + +#[macro_export] +macro_rules! collect_recv { + ($stream:expr, take=$take:expr, timeout=$timeout:expr $(,)?) => {{ + use tokio::time; + + // Evaluate $stream once, NOT in the loop 🐛🚨 + let mut stream = &mut $stream; + + let mut items = Vec::new(); + loop { + let item = time::timeout($timeout, stream.recv()).await.expect(&format!( + "Timeout before stream could collect {} item(s). Got {} item(s).", + $take, + items.len() + )); + + items.push(item.expect(&format!("{}/{} recv ended early", items.len(), $take))); + if items.len() == $take { + break items; + } + } + }}; + ($stream:expr, timeout=$timeout:expr $(,)?) => {{ + use tokio::time; + + let mut stream = &mut $stream; + + let mut items = Vec::new(); + while let Some(item) = time::timeout($timeout, stream.recv()) + .await + .expect(format!("Timeout before stream was closed. 
Got {} items.", items.len()).as_str()) + { + items.push(item); + } + items + }}; +} + +#[macro_export] +macro_rules! collect_try_recv { + ($stream:expr, take=$take:expr, timeout=$timeout:expr $(,)?) => {{ + use tokio::time; + + // Evaluate $stream once, NOT in the loop 🐛🚨 + let mut stream = &mut $stream; + let mut items = Vec::new(); - while let Some(item) = time::timeout($timeout, $stream.next()) + loop { + let item = time::timeout($timeout, stream.recv()).await.expect(&format!( + "Timeout before stream could collect {} item(s). Got {} item(s).", + $take, + items.len() + )); + + items.push(item.expect(&format!("{}/{} recv returned unexpected result", items.len(), $take))); + if items.len() == $take { + break items; + } + } + }}; + ($stream:expr, timeout=$timeout:expr $(,)?) => {{ + use tokio::time; + + let mut stream = &mut $stream; + let mut items = Vec::new(); + while let Ok(item) = time::timeout($timeout, stream.recv()) .await .expect(format!("Timeout before stream was closed. Got {} items.", items.len()).as_str()) { @@ -102,9 +180,9 @@ macro_rules! collect_stream { /// # use std::time::Duration; /// # use tari_test_utils::collect_stream_count; /// -/// let mut rt = Runtime::new().unwrap(); +/// let rt = Runtime::new().unwrap(); /// let mut stream = stream::iter(vec![1,2,2,3,2]); -/// assert_eq!(rt.block_on(async { collect_stream_count!(stream, timeout=Duration::from_secs(1)) }).get(&2), Some(&3)); +/// assert_eq!(rt.block_on(async { collect_stream_count!(&mut stream, timeout=Duration::from_secs(1)) }).get(&2), Some(&3)); /// ``` #[macro_export] macro_rules! 
collect_stream_count { @@ -139,3 +217,56 @@ where } } } + +pub async fn assert_in_mpsc(rx: &mut mpsc::Receiver, mut predicate: P, timeout: Duration) -> R +where P: FnMut(T) -> Option { + loop { + if let Some(item) = tokio::time::timeout(timeout, rx.recv()) + .await + .expect("Timeout before stream emitted") + { + if let Some(r) = (predicate)(item) { + break r; + } + } else { + panic!("Predicate did not return true before the mpsc stream ended"); + } + } +} + +pub async fn assert_in_broadcast(rx: &mut broadcast::Receiver, mut predicate: P, timeout: Duration) -> R +where + P: FnMut(T) -> Option, + T: Clone, +{ + loop { + if let Ok(item) = tokio::time::timeout(timeout, rx.recv()) + .await + .expect("Timeout before stream emitted") + { + if let Some(r) = (predicate)(item) { + break r; + } + } else { + panic!("Predicate did not return true before the broadcast channel ended"); + } + } +} + +pub fn convert_mpsc_to_stream(rx: &mut mpsc::Receiver) -> impl Stream + '_ { + stream::unfold(rx, |rx| async move { rx.recv().await.map(|t| (t, rx)) }) +} + +pub fn convert_unbounded_mpsc_to_stream(rx: &mut mpsc::UnboundedReceiver) -> impl Stream + '_ { + stream::unfold(rx, |rx| async move { rx.recv().await.map(|t| (t, rx)) }) +} + +pub fn convert_broadcast_to_stream<'a, T, S>(rx: S) -> impl Stream> + 'a +where + T: Clone + Send + 'static, + S: BorrowMut> + 'a, +{ + stream::unfold(rx, |mut rx| async move { + Some(rx.borrow_mut().recv().await).map(|t| (t, rx)) + }) +} diff --git a/integration_tests/features/Mempool.feature b/integration_tests/features/Mempool.feature index 36113cd9b9..a77d837a08 100644 --- a/integration_tests/features/Mempool.feature +++ b/integration_tests/features/Mempool.feature @@ -86,6 +86,7 @@ Feature: Mempool # Collects 7 coinbases into one wallet, send 7 transactions # Stronger chain # + Given I do not expect all automated transactions to succeed Given I have a seed node SEED_A And I have a base node NODE_A1 connected to seed SEED_A And I have wallet WALLET_A1 
connected to seed node SEED_A @@ -198,4 +199,4 @@ Feature: Mempool When I submit transaction TX1 to BN1 Then I wait until base node BN1 has 1 unconfirmed transactions in its mempool When I mine 1 blocks on BN1 - Then I wait until base node BN1 has 0 unconfirmed transactions in its mempool \ No newline at end of file + Then I wait until base node BN1 has 0 unconfirmed transactions in its mempool diff --git a/integration_tests/features/Reorgs.feature b/integration_tests/features/Reorgs.feature index 008df39e8e..d249f05015 100644 --- a/integration_tests/features/Reorgs.feature +++ b/integration_tests/features/Reorgs.feature @@ -96,6 +96,7 @@ Feature: Reorgs @critical @reorg Scenario: Zero-conf reorg with spending + Given I do not expect all automated transactions to succeed Given I have a base node NODE1 connected to all seed nodes Given I have a base node NODE2 connected to node NODE1 When I mine 14 blocks on NODE1 @@ -142,6 +143,7 @@ Feature: Reorgs # Chain 1a: # Mine X1 blocks (orphan_storage_capacity default set to 10) # + Given I do not expect all automated transactions to succeed Given I have a seed node SEED_A1 # Add multiple base nodes to ensure more robust comms And I have a base node NODE_A1 connected to seed SEED_A1 diff --git a/integration_tests/features/StressTest.feature b/integration_tests/features/StressTest.feature index 0425f45644..1e975b03f5 100644 --- a/integration_tests/features/StressTest.feature +++ b/integration_tests/features/StressTest.feature @@ -12,18 +12,18 @@ Feature: Stress Test And I have stress-test wallet WALLET_B connected to the seed node NODE2 with broadcast monitoring timeout # There need to be at least as many mature coinbase UTXOs in the wallet coin splits required for the number of transactions When I merge mine blocks via PROXY - Then all nodes are at current tip height + Then all nodes are on the same chain tip When I wait for wallet WALLET_A to have at least 5100000000 uT Then I coin split tari in wallet WALLET_A to produce 
UTXOs of 5000 uT each with fee_per_gram 20 uT When I merge mine 3 blocks via PROXY When I merge mine blocks via PROXY - Then all nodes are at current tip height + Then all nodes are on the same chain tip Then wallet WALLET_A detects all transactions as Mined_Confirmed When I send transactions of 1111 uT each from wallet WALLET_A to wallet WALLET_B at fee_per_gram 20 # Mine enough blocks for the first block of transactions to be confirmed. When I merge mine 4 blocks via PROXY - Then all nodes are at current tip height + Then all nodes are on the same chain tip # Now wait until all transactions are detected as confirmed in WALLET_A, continue to mine blocks if transactions # are not found to be confirmed as sometimes the previous mining occurs faster than transactions are submitted # to the mempool diff --git a/integration_tests/features/WalletFFI.feature b/integration_tests/features/WalletFFI.feature index 11b905051e..9432989270 100644 --- a/integration_tests/features/WalletFFI.feature +++ b/integration_tests/features/WalletFFI.feature @@ -1,129 +1,136 @@ @wallet-ffi Feature: Wallet FFI + # Increase heap memory available to nodejs if frequent crashing occurs with + # the error being similar to this: `0x1a32cd5 V8_Fatal(char const*, ...)` - Scenario: As a client I want to send Tari to a Public Key - # It's a subtest of "As a client I want to retrieve a list of transactions I have made and received" - - Scenario: As a client I want to specify a custom fee when I send tari - # It's a subtest of "As a client I want to retrieve a list of transactions I have made and received" - - Scenario: As a client I want to receive Tari via my Public Key while I am online - # It's a subtest of "As a client I want to retrieve a list of transactions I have made and received" + # It's just calling the encrypt function, we don't test if it's actually encrypted + Scenario: As a client I want to be able to protect my wallet with a passphrase + Given I have a base node BASE + And I have a ffi
wallet FFI_WALLET connected to base node BASE + And I set passphrase PASSPHRASE of ffi wallet FFI_WALLET + And I stop ffi wallet FFI_WALLET - @long-running @broken - Scenario: As a client I want to receive Tari via my Public Key sent while I am offline when I come back online + Scenario: As a client I want to see my whoami info Given I have a base node BASE - And I have wallet SENDER connected to base node BASE - And I have mining node MINER connected to base node BASE and wallet SENDER - And mining node MINER mines 4 blocks - Then I wait for wallet SENDER to have at least 1000000 uT And I have a ffi wallet FFI_WALLET connected to base node BASE - And I stop wallet FFI_WALLET + Then I want to get public key of ffi wallet FFI_WALLET + And I want to get emoji id of ffi wallet FFI_WALLET + And I stop ffi wallet FFI_WALLET + + Scenario: As a client I want to be able to restore my ffi wallet from seed words + Given I have a base node BASE + And I have wallet SPECTATOR connected to base node BASE + And I have mining node MINER connected to base node BASE and wallet SPECTATOR + And mining node MINER mines 10 blocks + Then I wait for wallet SPECTATOR to have at least 1000000 uT + Then I recover wallet SPECTATOR into ffi wallet FFI_WALLET from seed words on node BASE + And I wait for ffi wallet FFI_WALLET to have at least 1000000 uT + And I stop ffi wallet FFI_WALLET + + Scenario: As a client I want to set the base node + Given I have a base node BASE1 + Given I have a base node BASE2 + And I have a ffi wallet FFI_WALLET connected to base node BASE1 + And I set base node BASE2 for ffi wallet FFI_WALLET + And I stop ffi wallet FFI_WALLET + And I stop node BASE1 And I wait 5 seconds - And I send 2000000 uT from wallet SENDER to wallet FFI_WALLET at fee 100 + And I restart ffi wallet FFI_WALLET + # Possibly check SAF messages, no way to get current connected base node peer from the library itself afaik + # Good idea just to add a fn to do this to the library. 
+ # Then I wait for ffi wallet FFI_WALLET to receive 1 SAF message And I wait 5 seconds - And I start wallet FFI_WALLET - And wallet SENDER detects all transactions are at least Broadcast - And mining node MINER mines 10 blocks - Then I wait for ffi wallet FFI_WALLET to have at least 1000000 uT + And I stop ffi wallet FFI_WALLET - @long-running - Scenario: As a client I want to retrieve a list of transactions I have made and received + Scenario: As a client I want to cancel a transaction Given I have a base node BASE And I have wallet SENDER connected to base node BASE And I have mining node MINER connected to base node BASE and wallet SENDER - And mining node MINER mines 4 blocks + And mining node MINER mines 10 blocks Then I wait for wallet SENDER to have at least 1000000 uT And I have a ffi wallet FFI_WALLET connected to base node BASE And I send 2000000 uT from wallet SENDER to wallet FFI_WALLET at fee 100 And wallet SENDER detects all transactions are at least Broadcast And mining node MINER mines 10 blocks Then I wait for ffi wallet FFI_WALLET to have at least 1000000 uT - And Check callbacks for finished inbound tx on ffi wallet FFI_WALLET And I have wallet RECEIVER connected to base node BASE + And I stop wallet RECEIVER And I send 1000000 uT from ffi wallet FFI_WALLET to wallet RECEIVER at fee 100 - And ffi wallet FFI_WALLET has 1 broadcast transaction - And mining node MINER mines 4 blocks - Then I wait for wallet RECEIVER to have at least 1000000 uT - And Check callbacks for finished outbound tx on ffi wallet FFI_WALLET - And I have 1 received and 1 send transaction in ffi wallet FFI_WALLET - And I start STXO validation on wallet FFI_WALLET - And I start UTXO validation on wallet FFI_WALLET - - # It's just calling the encrypt function, we don't test if it's actually encrypted - Scenario: As a client I want to be able to protect my wallet with a passphrase - Given I have a base node BASE - And I have a ffi wallet FFI_WALLET connected to base node BASE - 
And I set passphrase PASSPHRASE of ffi wallet FFI_WALLET + Then I wait for ffi wallet FFI_WALLET to have 1 pending outbound transaction + Then I cancel all outbound transactions on ffi wallet FFI_WALLET and it will cancel 1 transaction + And I stop ffi wallet FFI_WALLET Scenario: As a client I want to manage contacts Given I have a base node BASE And I have a ffi wallet FFI_WALLET connected to base node BASE And I have wallet WALLET connected to base node BASE + And I wait 5 seconds And I add contact with alias ALIAS and pubkey WALLET to ffi wallet FFI_WALLET Then I have contact with alias ALIAS and pubkey WALLET in ffi wallet FFI_WALLET When I remove contact with alias ALIAS from ffi wallet FFI_WALLET Then I don't have contact with alias ALIAS in ffi wallet FFI_WALLET + And I stop ffi wallet FFI_WALLET - Scenario: As a client I want to set the base node (should be persisted) - Given I have a base node BASE1 - Given I have a base node BASE2 - And I have a ffi wallet FFI_WALLET connected to base node BASE1 - And I set base node BASE2 for ffi wallet FFI_WALLET - Then BASE2 is connected to FFI_WALLET - And I stop wallet FFI_WALLET - And I wait 5 seconds - And I start wallet FFI_WALLET - Then BASE2 is connected to FFI_WALLET - - Scenario: As a client I want to see my public_key, emoji ID, address (whoami) - Given I have a base node BASE - And I have a ffi wallet FFI_WALLET connected to base node BASE - Then I want to get public key of ffi wallet FFI_WALLET - And I want to get emoji id of ffi wallet FFI_WALLET - - Scenario: As a client I want to get my balance - # It's a subtest of "As a client I want to retrieve a list of transactions I have made and received" - - @long-running - Scenario: As a client I want to cancel a transaction + Scenario: As a client I want to retrieve a list of transactions I have made and received Given I have a base node BASE And I have wallet SENDER connected to base node BASE And I have mining node MINER connected to base node BASE and wallet 
SENDER - And mining node MINER mines 4 blocks + And mining node MINER mines 10 blocks Then I wait for wallet SENDER to have at least 1000000 uT And I have a ffi wallet FFI_WALLET connected to base node BASE And I send 2000000 uT from wallet SENDER to wallet FFI_WALLET at fee 100 - And wallet SENDER detects all transactions are at least Broadcast And mining node MINER mines 10 blocks Then I wait for ffi wallet FFI_WALLET to have at least 1000000 uT And I have wallet RECEIVER connected to base node BASE - And I stop wallet RECEIVER And I send 1000000 uT from ffi wallet FFI_WALLET to wallet RECEIVER at fee 100 - Then I wait for ffi wallet FFI_WALLET to have 1 pending outbound transaction - Then I cancel all transactions on ffi wallet FFI_WALLET and it will cancel 1 transaction + And mining node MINER mines 10 blocks + Then I wait for wallet RECEIVER to have at least 1000000 uT + And I have 1 received and 1 send transaction in ffi wallet FFI_WALLET + And I start STXO validation on ffi wallet FFI_WALLET + And I start UTXO validation on ffi wallet FFI_WALLET + And I stop ffi wallet FFI_WALLET - @long-running - Scenario: As a client I want to be able to restore my wallet from seed words + Scenario: As a client I want to receive Tari via my Public Key sent while I am offline when I come back online Given I have a base node BASE - And I have wallet WALLET connected to base node BASE - And I have mining node MINER connected to base node BASE and wallet WALLET - And mining node MINER mines 4 blocks - Then I wait for wallet WALLET to have at least 1000000 uT - Then I recover wallet WALLET into ffi wallet FFI_WALLET from seed words on node BASE - And I wait for recovery of wallet FFI_WALLET to finish - And I wait for ffi wallet FFI_WALLET to have at least 1000000 uT + And I have wallet SENDER connected to base node BASE + And I have mining node MINER connected to base node BASE and wallet SENDER + And mining node MINER mines 10 blocks + Then I wait for wallet SENDER to have at 
least 1000000 uT + And I have a ffi wallet FFI_WALLET connected to base node BASE + And I stop ffi wallet FFI_WALLET + And I wait 10 seconds + And I send 2000000 uT from wallet SENDER to wallet FFI_WALLET at fee 100 + And I wait 5 seconds + And I restart ffi wallet FFI_WALLET + Then I wait for ffi wallet FFI_WALLET to receive 1 transaction + Then I wait for ffi wallet FFI_WALLET to receive 1 finalization + # Assume tx will be mined to reduce time taken for test, balance is tested in later scenarios. + # And mining node MINER mines 10 blocks + # Then I wait for ffi wallet FFI_WALLET to have at least 1000000 uT + And I stop ffi wallet FFI_WALLET + + # Scenario: As a client I want to get my balance + # It's a subtest of "As a client I want to retrieve a list of transactions I have made and received" + + #Scenario: As a client I want to send Tari to a Public Key + # It's a subtest of "As a client I want to retrieve a list of transactions I have made and received" - Scenario: As a client I want to be able to initiate TXO and TX validation with the specifed base node. + #Scenario: As a client I want to specify a custom fee when I send tari # It's a subtest of "As a client I want to retrieve a list of transactions I have made and received" - Scenario: As a client I want async feedback about the progress of sending and receiving a transaction + #Scenario: As a client I want to receive Tari via my Public Key while I am online # It's a subtest of "As a client I want to retrieve a list of transactions I have made and received" - Scenario: As a client I want async feedback about my connection status to the specifed Base Node + # Scenario: As a client I want to be able to initiate TXO and TX validation with the specified base node.
+ # It's a subtest of "As a client I want to retrieve a list of transactions I have made and received" + + # Scenario: As a client I want feedback about the progress of sending and receiving a transaction + # It's a subtest of "As a client I want to retrieve a list of transactions I have made and received" - Scenario: As a client I want async feedback about the wallet restoration process + # Scenario: As a client I want feedback about my connection status to the specified Base Node + + # Scenario: As a client I want feedback about the wallet restoration process + # As a client I want to be able to restore my wallet from seed words - Scenario: As a client I want async feedback about TXO and TX validation processes -# It's a subtest of "As a client I want to retrieve a list of transactions I have made and received" + # Scenario: As a client I want feedback about TXO and TX validation processes + # It's a subtest of "As a client I want to retrieve a list of transactions I have made and received" diff --git a/integration_tests/features/support/steps.js b/integration_tests/features/support/steps.js index fcedbaeeff..ef93992b5e 100644 --- a/integration_tests/features/support/steps.js +++ b/integration_tests/features/support/steps.js @@ -37,6 +37,22 @@ Given("I have {int} seed nodes", { timeout: 20 * 1000 }, async function (n) { await Promise.all(promises); }); +Given( + /I do not expect all automated transactions to succeed/, + { timeout: 20 * 1000 }, + async function () { + this.checkAutoTransactions = false; + } +); + +Given( + /I expect all automated transactions to succeed/, + { timeout: 20 * 1000 }, + async function () { + this.checkAutoTransactions = true; + } +); + Given( /I have a base node (.*) connected to all seed nodes/, { timeout: 20 * 1000 }, @@ -942,21 +958,6 @@ Then( } ); -Then( - "all nodes are at current tip height", - { timeout: 1200 * 1000 }, - async function () { - const height = parseInt(this.tipHeight); - console.log("Wait for all nodes to reach
height of", height); - await this.forEachClientAsync(async (client, name) => { - await waitFor(async () => client.getTipHeight(), height, 1200 * 1000); - const currTip = await client.getTipHeight(); - console.log(`Node ${name} is at tip: ${currTip} (expected ${height})`); - expect(currTip).to.equal(height); - }); - } -); - Then( /all nodes are at the same height as node (.*)/, { timeout: 1200 * 1000 }, @@ -1106,20 +1107,24 @@ When(/I spend outputs (.*) via (.*)/, async function (inputs, node) { expect(this.lastResult.result).to.equal("ACCEPTED"); }); -Then(/(.*) has (.*) in (.*) state/, async function (node, txn, pool) { - const client = this.getClient(node); - const sig = this.transactions[txn].body.kernels[0].excess_sig; - await waitFor( - async () => await client.transactionStateResult(sig), - pool, - 1200 * 1000 - ); - this.lastResult = await this.getClient(node).transactionState( - this.transactions[txn].body.kernels[0].excess_sig - ); - console.log(`Node ${node} response is: ${this.lastResult.result}`); - expect(this.lastResult.result).to.equal(pool); -}); +Then( + /(.*) has (.*) in (.*) state/, + { timeout: 21 * 60 * 1000 }, + async function (node, txn, pool) { + const client = this.getClient(node); + const sig = this.transactions[txn].body.kernels[0].excess_sig; + await waitForPredicate( + async () => (await client.transactionStateResult(sig)) === pool, + 20 * 60 * 1000, + 1000 + ); + this.lastResult = await this.getClient(node).transactionState( + this.transactions[txn].body.kernels[0].excess_sig + ); + console.log(`Node ${node} response is: ${this.lastResult.result}`); + expect(this.lastResult.result).to.equal(pool); + } +); // The number is rounded down. E.g. if 1% can fail out of 17, that is 16.83 have to succeed. // It's means at least 16 have to succeed. 
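The rewritten `(.*) has (.*) in (.*) state` step above replaces the `waitFor` call with a predicate-based poll that takes an explicit overall timeout and poll interval. The real `waitForPredicate` lives in the repo's helpers; the sketch below is only an assumption inferred from the call site in this diff (`predicate, timeoutMs, pollMs`), with a hypothetical `demo` stand-in for the gRPC client call:

```javascript
// Sketch of a waitForPredicate helper matching the call site in the diff.
// Assumed signature: (asyncPredicate, timeoutMs, pollMs).
const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms));

async function waitForPredicate(predicate, timeoutMs, pollMs) {
  const deadline = Date.now() + timeoutMs;
  while (Date.now() < deadline) {
    // Re-evaluate the predicate each cycle; return as soon as it holds.
    if (await predicate()) {
      return true;
    }
    await sleep(pollMs);
  }
  throw new Error(`Predicate not satisfied within ${timeoutMs} ms`);
}

// Usage mirroring the step definition: poll until a (hypothetical) node
// client reports the expected transaction pool state.
async function demo() {
  let calls = 0;
  // Stand-in for client.transactionStateResult(sig): "MINED" on 3rd poll.
  const fakeState = async () => (++calls >= 3 ? "MINED" : "MEMPOOL");
  await waitForPredicate(async () => (await fakeState()) === "MINED", 5000, 10);
  return calls;
}
```

Compared with the removed `waitFor` variant, the predicate form lets the step compare against the expected pool state directly instead of passing the expected value into the wait helper.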
@@ -1186,11 +1191,16 @@ When( /I mine a block on (.*) with coinbase (.*)/, { timeout: 600 * 1000 }, async function (name, coinbaseName) { + const tipHeight = await this.getClient(name).getTipHeight(); + let autoTransactionResult = await this.createTransactions( + name, + tipHeight + 1 + ); + expect(autoTransactionResult).to.equal(true); await this.mineBlock(name, 0, (candidate) => { this.addOutput(coinbaseName, candidate.originalTemplate.coinbase); return candidate; }); - this.tipHeight += 1; } ); @@ -1198,11 +1208,25 @@ When( /I mine (\d+) custom weight blocks on (.*) with weight (\d+)/, { timeout: -1 }, async function (numBlocks, name, weight) { + const tipHeight = await this.getClient(name).getTipHeight(); for (let i = 0; i < numBlocks; i++) { + let autoTransactionResult = await this.createTransactions( + name, + tipHeight + i + 1 + ); + expect(autoTransactionResult).to.equal(true); // If a block cannot be mined quickly enough (or the process has frozen), timeout. - await withTimeout(60 * 1000, this.mineBlock(name, parseInt(weight))); + await withTimeout( + 60 * 1000, + this.mineBlock(name, parseInt(weight), (candidate) => { + this.addTransactionOutput( + tipHeight + i + 1 + 2, + candidate.originalTemplate.coinbase + ); + return candidate; + }) + ); } - this.tipHeight += parseInt(numBlocks); } ); @@ -1244,10 +1268,24 @@ When( /I mine (\d+) blocks on (.*)/, { timeout: -1 }, async function (numBlocks, name) { + const tipHeight = await this.getClient(name).getTipHeight(); for (let i = 0; i < numBlocks; i++) { - await withTimeout(60 * 1000, this.mineBlock(name, 0)); + let autoTransactionResult = await this.createTransactions( + name, + tipHeight + i + 1 + ); + expect(autoTransactionResult).to.equal(true); + await withTimeout( + 60 * 1000, + this.mineBlock(name, 0, (candidate) => { + this.addTransactionOutput( + tipHeight + i + 1 + 2, + candidate.originalTemplate.coinbase + ); + return candidate; + }) + ); } - this.tipHeight += parseInt(numBlocks); } ); @@ -1257,7 
+1295,13 @@ When( async function (numBlocks, walletName, nodeName) { const nodeClient = this.getClient(nodeName); const walletClient = await this.getWallet(walletName).connectClient(); + const tipHeight = await this.getClient(nodeName).getTipHeight(); for (let i = 0; i < numBlocks; i++) { + let autoTransactionResult = await this.createTransactions( + nodeName, + tipHeight + 1 + i + ); + expect(autoTransactionResult).to.equal(true); await nodeClient.mineBlock(walletClient); } } @@ -1270,7 +1314,6 @@ When( for (let i = 0; i < numBlocks; i++) { await this.mergeMineBlock(mmProxy); } - this.tipHeight += parseInt(numBlocks); } ); @@ -1282,7 +1325,8 @@ When( /I co-mine (.*) blocks via merge mining proxy (.*) and base node (.*) with wallet (.*)/, { timeout: 1200 * 1000 }, async function (numBlocks, mmProxy, node, wallet) { - this.lastResult = this.tipHeight; + let tipHeight = await this.getClient(node).getTipHeight(); + this.lastResult = tipHeight; const baseNodeMiningPromise = await this.baseNodeMineBlocksUntilHeightIncreasedBy( node, @@ -1295,13 +1339,13 @@ When( ); await Promise.all([baseNodeMiningPromise, mergeMiningPromise]).then( ([res1, res2]) => { - this.tipHeight = Math.max(res1, res2); - this.lastResult = this.tipHeight - this.lastResult; + tipHeight = Math.max(res1, res2); + this.lastResult = tipHeight - this.lastResult; console.log( "Co-mining", numBlocks, "blocks concluded, tip at", - this.tipHeight + tipHeight ); } ); @@ -1312,7 +1356,6 @@ When( /I co-mine (.*) blocks via merge mining proxy (.*) and mining node (.*)/, { timeout: 6000 * 1000 }, async function (numBlocks, mmProxy, miner) { - this.lastResult = this.tipHeight; const sha3MiningPromise = this.sha3MineBlocksUntilHeightIncreasedBy( miner, numBlocks, @@ -1324,13 +1367,14 @@ When( ); await Promise.all([sha3MiningPromise, mergeMiningPromise]).then( ([res1, res2]) => { - this.tipHeight = Math.max(res1, res2); - this.lastResult = this.tipHeight - this.lastResult; console.log( "Co-mining", numBlocks, - 
"blocks concluded, tip at", - this.tipHeight + "blocks concluded, tips at [", + res1, + ",", + res2, + "]" ); } ); @@ -1340,10 +1384,20 @@ When( When( /I mine but do not submit a block (.*) on (.*)/, async function (blockName, nodeName) { + const tipHeight = await this.getClient(nodeName).getTipHeight(); + let autoTransactionResult = await this.createTransactions( + nodeName, + tipHeight + 1 + ); + expect(autoTransactionResult).to.equal(true); await this.mineBlock( nodeName, null, (block) => { + this.addTransactionOutput( + tipHeight + 2, + block.originalTemplate.coinbase + ); this.saveBlock(blockName, block); return false; }, @@ -1362,7 +1416,15 @@ When( const client = this.getClient(node); const template = client.getPreviousBlockTemplate(atHeight); const candidate = await client.getMinedCandidateBlock(0, template); - + let autoTransactionResult = await this.createTransactions( + node, + parseInt(atHeight) + ); + expect(autoTransactionResult).to.equal(true); + this.addTransactionOutput( + parseInt(atHeight) + 1, + candidate.originalTemplate.coinbase + ); await client.submitBlock( candidate.template, (block) => { @@ -2577,8 +2639,13 @@ Then( if (await walletClient.isTransactionMinedConfirmed(txIds[i])) { return true; } else { + const tipHeight = await this.getClient(nodeName).getTipHeight(); + let autoTransactionResult = await this.createTransactions( + nodeName, + tipHeight + 1 + ); + expect(autoTransactionResult).to.equal(true); await nodeClient.mineBlock(walletClient); - this.tipHeight += 1; return false; } }, @@ -2632,7 +2699,6 @@ Then( return true; } else { await this.mergeMineBlock(mmProxy); - this.tipHeight += 1; return false; } }, @@ -3379,29 +3445,6 @@ When( } ); -When( - "I have a ffi wallet {word} connected to base node {word}", - { timeout: 20 * 1000 }, - async function (name, node) { - let wallet = await this.createAndAddFFIWallet(name); - let peer = this.nodes[node].peerAddress().split("::"); - await wallet.addBaseNodePeer(peer[0], peer[1]); - } -); - 
-Then( - "I want to get public key of ffi wallet {word}", - { timeout: 20 * 1000 }, - async function (name) { - let wallet = this.getWallet(name); - let public_key = await wallet.getPublicKey(); - expect(public_key.length).to.be.equal( - 64, - `Public key has wrong length : ${public_key}` - ); - } -); - Then( /I wait until base node (.*) has (.*) unconfirmed transactions in its mempool/, { timeout: 180 * 1000 }, @@ -3429,57 +3472,120 @@ Then( ); Then( - "I want to get emoji id of ffi wallet {word}", + /node (.*) lists heights (\d+) to (\d+)/, + async function (node, first, last) { + const client = this.getClient(node); + const start = first; + const end = last; + let heights = []; + + for (let i = start; i <= end; i++) { + heights.push(i); + } + const blocks = await client.getBlocks(heights); + const results = blocks.map((result) => + parseInt(result.block.header.height) + ); + let i = 0; // for ordering check + for (let height = start; height <= end; height++) { + expect(results[i]).equal(height); + i++; + } + } +); + +Then( + "I wait for recovery of wallet {word} to finish", + { timeout: 600 * 1000 }, + async function (wallet_name) { + const wallet = this.getWallet(wallet_name); + while (wallet.recoveryInProgress) { + await sleep(1000); + } + expect(wallet.recoveryProgress[1]).to.be.greaterThan(0); + expect(wallet.recoveryProgress[0]).to.be.equal(wallet.recoveryProgress[1]); + } +); + +When( + "I have {int} base nodes with pruning horizon {int} force syncing on node {word}", + { timeout: 190 * 1000 }, + async function (nodes_count, horizon, force_sync_to) { + const promises = []; + const force_sync_address = this.getNode(force_sync_to).peerAddress(); + for (let i = 0; i < nodes_count; i++) { + const base_node = this.createNode(`BaseNode${i}`, { + pruningHorizon: horizon, + }); + base_node.setPeerSeeds([force_sync_address]); + base_node.setForceSyncPeers([force_sync_address]); + promises.push( + base_node.startNew().then(() => this.addNode(`BaseNode${i}`, 
base_node)) + ); + } + await Promise.all(promises); + } +); + +//region FFI +When( + "I have ffi wallet {word} connected to base node {word}", { timeout: 20 * 1000 }, - async function (name) { + async function (name, node) { + let wallet = await this.createAndAddFFIWallet(name); + let peer = this.nodes[node].peerAddress().split("::"); + wallet.addBaseNodePeer(peer[0], peer[1]); + } +); + +Then( + "I want to get public key of ffi wallet {word}", + { timeout: 20 * 1000 }, + function (name) { let wallet = this.getWallet(name); - let emoji_id = await wallet.getEmojiId(); - expect(emoji_id.length).to.be.equal( - 22 * 3, // 22 emojis, 3 bytes per one emoji - `Emoji id has wrong length : ${emoji_id}` + let public_key = wallet.identify(); + expect(public_key.length).to.be.equal( + 64, + `Public key has wrong length : ${public_key}` ); } ); Then( - "I wait for ffi wallet {word} to have at least {int} uT", - { timeout: 60 * 1000 }, - async function (name, amount) { + "I want to get emoji id of ffi wallet {word}", + { timeout: 20 * 1000 }, + async function (name) { let wallet = this.getWallet(name); - let retries = 1; - let balance = 0; - const retries_limit = 12; - while (retries <= retries_limit) { - balance = await wallet.getBalance(); - if (balance >= amount) { - break; - } - await sleep(5000); - ++retries; - } - expect(balance, "Balance is not enough").to.be.greaterThanOrEqual(amount); + let emoji_id = wallet.identifyEmoji(); + console.log(emoji_id); + expect(emoji_id.length).to.be.equal( + 22 * 3, // 22 emojis, 3 bytes per one emoji + `Emoji id has wrong length : ${emoji_id}` + ); } ); When( "I send {int} uT from ffi wallet {word} to wallet {word} at fee {int}", { timeout: 20 * 1000 }, - async function (amount, sender, receiver, fee) { - await this.getWallet(sender).sendTransaction( - await this.getWalletPubkey(receiver), + function (amount, sender, receiver, fee) { + let ffi_wallet = this.getWallet(sender); + let result = ffi_wallet.sendTransaction( + 
this.getWalletPubkey(receiver), amount, fee, `Send from ffi ${sender} to ${receiver} at fee ${fee}` ); + console.log(result); } ); When( "I set passphrase {word} of ffi wallet {word}", { timeout: 20 * 1000 }, - async function (passphrase, name) { + function (passphrase, name) { let wallet = this.getWallet(name); - await wallet.applyEncryption(passphrase); + wallet.applyEncryption(passphrase); } ); @@ -3488,17 +3594,29 @@ Then( { timeout: 120 * 1000 }, async function (received, send, name) { let wallet = this.getWallet(name); - let [outbound, inbound] = await wallet.getCompletedTransactions(); - let retries = 1; - const retries_limit = 23; - while ( - (inbound != received || outbound != send) && - retries <= retries_limit - ) { - await sleep(5000); - [outbound, inbound] = await wallet.getCompletedTransactions(); - ++retries; + let completed = wallet.getCompletedTxs(); + let inbound = 0; + let outbound = 0; + let length = completed.getLength(); + let inboundTxs = wallet.getInboundTxs(); + inbound += inboundTxs.getLength(); + inboundTxs.destroy(); + let outboundTxs = wallet.getOutboundTxs(); + outbound += outboundTxs.getLength(); + outboundTxs.destroy(); + for (let i = 0; i < length; i++) { + { + let tx = completed.getAt(i); + if (tx.isOutbound()) { + outbound++; + } else { + inbound++; + } + tx.destroy(); + } } + completed.destroy(); + expect(outbound, "Outbound transaction count mismatch").to.be.equal(send); expect(inbound, "Inbound transaction count mismatch").to.be.equal(received); } @@ -3526,70 +3644,86 @@ Then( When( "I add contact with alias {word} and pubkey {word} to ffi wallet {word}", { timeout: 20 * 1000 }, - async function (alias, wallet_name, ffi_wallet_name) { + function (alias, wallet_name, ffi_wallet_name) { let ffi_wallet = this.getWallet(ffi_wallet_name); - await ffi_wallet.addContact(alias, await this.getWalletPubkey(wallet_name)); + ffi_wallet.addContact(alias, this.getWalletPubkey(wallet_name)); } ); Then( "I have contact with alias {word} and 
pubkey {word} in ffi wallet {word}", { timeout: 20 * 1000 }, - async function (alias, wallet_name, ffi_wallet_name) { + function (alias, wallet_name, ffi_wallet_name) { + let wallet = this.getWalletPubkey(wallet_name); let ffi_wallet = this.getWallet(ffi_wallet_name); - expect(await this.getWalletPubkey(wallet_name)).to.be.equal( - await ffi_wallet.getContact(alias) - ); + let contacts = ffi_wallet.getContactList(); + let length = contacts.getLength(); + let found = false; + for (let i = 0; i < length; i++) { + { + let contact = contacts.getAt(i); + let hex = contact.getPubkeyHex(); + if (wallet === hex) { + found = true; + } + contact.destroy(); + } + } + contacts.destroy(); + expect(found).to.be.equal(true); } ); When( "I remove contact with alias {word} from ffi wallet {word}", { timeout: 20 * 1000 }, - async function (alias, walllet_name) { - let wallet = this.getWallet(walllet_name); - await wallet.removeContact(alias); + function (alias, wallet_name) { + let ffi_wallet = this.getWallet(wallet_name); + let contacts = ffi_wallet.getContactList(); + let length = contacts.getLength(); + for (let i = 0; i < length; i++) { + { + let contact = contacts.getAt(i); + let calias = contact.getAlias(); + if (alias === calias) { + ffi_wallet.removeContact(contact); + } + contact.destroy(); + } + } + contacts.destroy(); } ); Then( "I don't have contact with alias {word} in ffi wallet {word}", { timeout: 20 * 1000 }, - async function (alias, wallet_name) { - let wallet = this.getWallet(wallet_name); - expect(await wallet.getContact("alias")).to.be.undefined; - } -); - -Then( - /node (.*) lists heights (\d+) to (\d+)/, - async function (node, first, last) { - const client = this.getClient(node); - const start = first; - const end = last; - let heights = []; - - for (let i = start; i <= end; i++) { - heights.push(i); - } - const blocks = await client.getBlocks(heights); - const results = blocks.map((result) => - parseInt(result.block.header.height) - ); - let i = 0; // for 
ordering check - for (let height = start; height <= end; height++) { - expect(results[i]).equal(height); - i++; + function (alias, wallet_name) { + let ffi_wallet = this.getWallet(wallet_name); + let contacts = ffi_wallet.getContactList(); + let length = contacts.getLength(); + let found = false; + for (let i = 0; i < length; i++) { + { + let contact = contacts.getAt(i); + let calias = contact.getAlias(); + if (alias === calias) { + found = true; + } + contact.destroy(); + } } + contacts.destroy(); + expect(found).to.be.equal(false); } ); When( "I set base node {word} for ffi wallet {word}", - async function (node, wallet_name) { + function (node, wallet_name) { let wallet = this.getWallet(wallet_name); let peer = this.nodes[node].peerAddress().split("::"); - await wallet.addBaseNodePeer(peer[0], peer[1]); + wallet.addBaseNodePeer(peer[0], peer[1]); } ); @@ -3598,26 +3732,48 @@ Then( { timeout: 120 * 1000 }, async function (wallet_name, count) { let wallet = this.getWallet(wallet_name); - let broadcast = await wallet.getOutboundTransactionsCount(); + let broadcast = wallet.getOutboundTransactions(); + let length = broadcast.getLength(); + broadcast.destroy(); let retries = 1; const retries_limit = 24; - while (broadcast != count && retries <= retries_limit) { + while (length != count && retries <= retries_limit) { await sleep(5000); - broadcast = await wallet.getOutboundTransactionsCount(); + broadcast = wallet.getOutboundTransactions(); + length = broadcast.getLength(); + broadcast.destroy(); ++retries; } - expect(broadcast, "Number of pending messages mismatch").to.be.equal(count); + expect(length, "Number of pending messages mismatch").to.be.equal(count); } ); Then( - "I cancel all transactions on ffi wallet {word} and it will cancel {int} transaction", + "I cancel all outbound transactions on ffi wallet {word} and it will cancel {int} transaction", async function (wallet_name, count) { const wallet = this.getWallet(wallet_name); - expect( - await 
wallet.cancelAllOutboundTransactions(), - "Number of cancelled transactions" - ).to.be.equal(count); + let txs = wallet.getOutboundTransactions(); + let cancelled = 0; + for (let i = 0; i < txs.getLength(); i++) { + let tx = txs.getAt(i); + let cancellation = wallet.cancelPendingTransaction(tx.getTransactionID()); + tx.destroy(); + if (cancellation) { + cancelled++; + } + } + txs.destroy(); + expect(cancelled).to.be.equal(count); + } +); + +Given( + /I have a ffi wallet (.*) connected to base node (.*)/, + { timeout: 20 * 1000 }, + async function (walletName, nodeName) { + let ffi_wallet = await this.createAndAddFFIWallet(walletName, null); + let peer = this.nodes[nodeName].peerAddress().split("::"); + ffi_wallet.addBaseNodePeer(peer[0], peer[1]); } ); @@ -3634,78 +3790,207 @@ Then( seed_words_text ); let peer = this.nodes[node].peerAddress().split("::"); - await ffi_wallet.addBaseNodePeer(peer[0], peer[1]); - await ffi_wallet.startRecovery(peer[0]); + ffi_wallet.addBaseNodePeer(peer[0], peer[1]); + ffi_wallet.startRecovery(peer[0]); } ); Then( - "I wait for recovery of wallet {word} to finish", - { timeout: 600 * 1000 }, + "Check callbacks for finished inbound tx on ffi wallet {word}", async function (wallet_name) { const wallet = this.getWallet(wallet_name); - while (wallet.recoveryInProgress) { - await sleep(1000); + expect(wallet.receivedTransaction).to.be.greaterThanOrEqual(1); + expect(wallet.transactionBroadcast).to.be.greaterThanOrEqual(1); + wallet.clearCallbackCounters(); + } +); + +Then( + "Check callbacks for finished outbound tx on ffi wallet {word}", + async function (wallet_name) { + const wallet = this.getWallet(wallet_name); + expect(wallet.receivedTransactionReply).to.be.greaterThanOrEqual(1); + expect(wallet.transactionBroadcast).to.be.greaterThanOrEqual(1); + wallet.clearCallbackCounters(); + } +); + +Then( + /I wait for ffi wallet (.*) to receive (.*) transaction/, + { timeout: 710 * 1000 }, + async function (wallet_name, amount) { + let wallet 
= this.getWallet(wallet_name); + + console.log("\n"); + console.log( + "Waiting for " + wallet_name + " to receive " + amount + " transaction(s)" + ); + + await waitFor( + async () => { + return wallet.getCounters().received >= amount; + }, + true, + 700 * 1000, + 5 * 1000, + 5 + ); + + if (!(wallet.getCounters().received >= amount)) { + console.log("Counter not adequate!"); + } else { + console.log(wallet.getCounters()); } - expect(wallet.recoveryProgress[1]).to.be.greaterThan(0); - expect(wallet.recoveryProgress[0]).to.be.equal(wallet.recoveryProgress[1]); + expect(wallet.getCounters().received >= amount).to.equal(true); } ); -Then("I start STXO validation on wallet {word}", async function (wallet_name) { - const wallet = this.getWallet(wallet_name); - await wallet.startStxoValidation(); - while (!wallet.stxo_validation_complete) { - await sleep(1000); +Then( + /I wait for ffi wallet (.*) to receive (.*) finalization/, + { timeout: 710 * 1000 }, + async function (wallet_name, amount) { + let wallet = this.getWallet(wallet_name); + + console.log("\n"); + console.log( + "Waiting for " + + wallet_name + + " to receive " + + amount + + " transaction finalization(s)" + ); + + await waitFor( + async () => { + return wallet.getCounters().finalized >= amount; + }, + true, + 700 * 1000, + 5 * 1000, + 5 + ); + + if (!(wallet.getCounters().finalized >= amount)) { + console.log("Counter not adequate!"); + } else { + console.log(wallet.getCounters()); + } + expect(wallet.getCounters().finalized >= amount).to.equal(true); } - expect(wallet.stxo_validation_result).to.be.equal(0); -}); +); -Then("I start UTXO validation on wallet {word}", async function (wallet_name) { - const wallet = this.getWallet(wallet_name); - await wallet.startUtxoValidation(); - while (!wallet.utxo_validation_complete) { - await sleep(1000); +Then( + /I wait for ffi wallet (.*) to receive (.*) SAF message/, + { timeout: 710 * 1000 }, + async function (wallet_name, amount) { + let wallet = 
this.getWallet(wallet_name); + + console.log("\n"); + console.log( + "Waiting for " + + wallet_name + + " to receive " + + amount + + " SAF message(s)" + ); + + await waitFor( + async () => { + return wallet.getCounters().saf >= amount; + }, + true, + 700 * 1000, + 5 * 1000, + 5 + ); + + if (!(wallet.getCounters().saf >= amount)) { + console.log("Counter not adequate!"); + } else { + console.log(wallet.getCounters()); + } + expect(wallet.getCounters().saf >= amount).to.equal(true); } - expect(wallet.utxo_validation_result).to.be.equal(0); -}); +); Then( - "Check callbacks for finished inbound tx on ffi wallet {word}", - async function (wallet_name) { + /I wait for ffi wallet (.*) to have at least (.*) uT/, + { timeout: 710 * 1000 }, + async function (wallet_name, amount) { + let wallet = this.getWallet(wallet_name); + + console.log("\n"); + console.log( + "Waiting for " + wallet_name + " balance to be at least " + amount + " uT" + ); + + let count = 0; + + while (!(wallet.getBalance().available >= amount)) { + await sleep(1000); + count++; + if (count > 700) { + break; + } + } + + let balance = wallet.getBalance().available; + + if (!(balance >= amount)) { + console.log("Balance not adequate!"); + } else { + console.log(wallet.getBalance()); + } + expect(balance >= amount).to.equal(true); + } +); + +Then( + "I wait for recovery of ffi wallet {word} to finish", + { timeout: 600 * 1000 }, + async function (wallet_name) { const wallet = this.getWallet(wallet_name); - expect(wallet.receivedTransaction).to.be.greaterThanOrEqual(1); - expect(wallet.transactionBroadcast).to.be.greaterThanOrEqual(1); - wallet.clearCallbackCounters(); + while (!wallet.recoveryFinished) { + await sleep(1000); + } } ); +When(/I start ffi wallet (.*)/, async function (walletName) { let wallet = this.getWallet(walletName); await wallet.startNew(null, null); }); +When(/I restart ffi wallet (.*)/, async function (walletName) { let wallet = this.getWallet(walletName); + await
wallet.restart(); +}); + +When(/I stop ffi wallet (.*)/, function (walletName) { + let wallet = this.getWallet(walletName); + wallet.stop(); + wallet.resetCounters(); +}); + Then( - "Check callbacks for finished outbound tx on ffi wallet {word}", + "I start STXO validation on ffi wallet {word}", async function (wallet_name) { const wallet = this.getWallet(wallet_name); - expect(wallet.receivedTransactionReply).to.be.greaterThanOrEqual(1); - expect(wallet.transactionBroadcast).to.be.greaterThanOrEqual(1); - wallet.clearCallbackCounters(); + await wallet.startStxoValidation(); + while (!wallet.getStxoValidationStatus().stxo_validation_complete) { + await sleep(1000); + } } ); -When( - "I have {int} base nodes with pruning horizon {int} force syncing on node {word}", - { timeout: 190 * 1000 }, - async function (nodes_count, horizon, force_sync_to) { - const promises = []; - const force_sync_address = this.getNode(force_sync_to).peerAddress(); - for (let i = 0; i < nodes_count; i++) { - const base_node = this.createNode(`BaseNode${i}`, { - pruningHorizon: horizon, - }); - base_node.setPeerSeeds([force_sync_address]); - base_node.setForceSyncPeers([force_sync_address]); - promises.push( - base_node.startNew().then(() => this.addNode(`BaseNode${i}`, base_node)) - ); +Then( + "I start UTXO validation on ffi wallet {word}", + async function (wallet_name) { + const wallet = this.getWallet(wallet_name); + await wallet.startUtxoValidation(); + while (!wallet.getUtxoValidationStatus().utxo_validation_complete) { + await sleep(1000); } - await Promise.all(promises); } ); +//endregion diff --git a/integration_tests/features/support/world.js b/integration_tests/features/support/world.js index 6ca3d0f699..91e8b45f16 100644 --- a/integration_tests/features/support/world.js +++ b/integration_tests/features/support/world.js @@ -5,6 +5,7 @@ const MergeMiningProxyProcess = require("../../helpers/mergeMiningProxyProcess") const WalletProcess = require("../../helpers/walletProcess"); 
const WalletFFIClient = require("../../helpers/walletFFIClient"); const MiningNodeProcess = require("../../helpers/miningNodeProcess"); +const TransactionBuilder = require("../../helpers/transactionBuilder"); const glob = require("glob"); const fs = require("fs"); const archiver = require("archiver"); @@ -12,7 +13,7 @@ class CustomWorld { constructor({ attach, parameters }) { // this.variable = 0; this.attach = attach; - + this.checkAutoTransactions = true; this.seeds = {}; this.nodes = {}; this.proxies = {}; @@ -23,6 +24,7 @@ class CustomWorld { this.clients = {}; this.headers = {}; this.outputs = {}; + this.transactionOutputs = {}; this.testrun = `run${Date.now()}`; this.lastResult = null; this.blocks = {}; @@ -30,7 +32,6 @@ class CustomWorld { this.peers = {}; this.transactionsMap = new Map(); this.resultStack = []; - this.tipHeight = 0; this.logFilePathBaseNode = parameters.logFilePathBaseNode || "./log4rs/base_node.yml"; this.logFilePathProxy = parameters.logFilePathProxy || "./log4rs/proxy.yml"; @@ -106,11 +107,11 @@ class CustomWorld { this.walletPubkeys[name] = walletInfo.public_key; } - async createAndAddFFIWallet(name, seed_words) { + async createAndAddFFIWallet(name, seed_words = null, passphrase = null) { const wallet = new WalletFFIClient(name); - await wallet.startNew(seed_words); + await wallet.startNew(seed_words, passphrase); this.walletsFFI[name] = wallet; - this.walletPubkeys[name] = await wallet.getPublicKey(); + this.walletPubkeys[name] = wallet.identify(); return wallet; } @@ -126,6 +127,47 @@ class CustomWorld { this.outputs[name] = output; } + addTransactionOutput(spendHeight, output) { + if (this.transactionOutputs[spendHeight] == null) { + this.transactionOutputs[spendHeight] = [output]; + } else { + this.transactionOutputs[spendHeight].push(output); + } + } + + async createTransactions(name, height) { + let result = true; + const txInputs = this.transactionOutputs[height]; + if (txInputs == null) { + return result; + } + let i = 0; + for 
(const input of txInputs) { + const txn = new TransactionBuilder(); + txn.addInput(input); + const txOutput = txn.addOutput(txn.getSpendableAmount()); + this.addTransactionOutput(height + 1, txOutput); + const completedTx = txn.build(); + const submitResult = await this.getClient(name).submitTransaction( + completedTx + ); + if (this.checkAutoTransactions && submitResult.result != "ACCEPTED") { + result = false; + } + if (submitResult.result == "ACCEPTED") { + i++; + } + if (i > 9) { + //this is to make sure the blocks stay relatively empty so that the tests don't take too long + break; + } + } + console.log( + `Created ${i} transactions for node: ${name} at height: ${height}` + ); + return result; + } + async mineBlock(name, weight, beforeSubmit, onError) { await this.clients[name].mineBlockWithoutWallet( beforeSubmit, diff --git a/integration_tests/helpers/ffi/byteVector.js b/integration_tests/helpers/ffi/byteVector.js index 51f5d338bd..245cb4320e 100644 --- a/integration_tests/helpers/ffi/byteVector.js +++ b/integration_tests/helpers/ffi/byteVector.js @@ -1,28 +1,51 @@ -const WalletFFI = require("./walletFFI"); +const InterfaceFFI = require("./ffiInterface"); class ByteVector { #byte_vector_ptr; - constructor(byte_vector_ptr) { - this.#byte_vector_ptr = byte_vector_ptr; + pointerAssign(ptr) { + // Prevent pointer from being leaked in case of re-assignment + if (this.#byte_vector_ptr) { + this.destroy(); + this.#byte_vector_ptr = ptr; + } else { + this.#byte_vector_ptr = ptr; + } } - static async fromBuffer(buffer) { - let buf = Buffer.from(buffer, "utf-8"); // get the bytes + fromBytes(input) { + let buf = Buffer.from(input, "utf-8"); // ensure encoding is utf-8, js default is utf-16 let len = buf.length; // get the length - return new ByteVector(await WalletFFI.byteVectorCreate(buf, len)); + let result = new ByteVector(); + result.pointerAssign(InterfaceFFI.byteVectorCreate(buf, len)); + return result; + } + + getBytes() { + let result = []; + for (let i = 0; i
< this.getLength(); i++) { + result.push(this.getAt(i)); + } + return result; } getLength() { - return WalletFFI.byteVectorGetLength(this.#byte_vector_ptr); + return InterfaceFFI.byteVectorGetLength(this.#byte_vector_ptr); } getAt(position) { - return WalletFFI.byteVectorGetAt(this.#byte_vector_ptr, position); + return InterfaceFFI.byteVectorGetAt(this.#byte_vector_ptr, position); + } + + getPtr() { + return this.#byte_vector_ptr; } destroy() { - return WalletFFI.byteVectorDestroy(this.#byte_vector_ptr); + if (this.#byte_vector_ptr) { + InterfaceFFI.byteVectorDestroy(this.#byte_vector_ptr); + this.#byte_vector_ptr = undefined; //prevent double free segfault + } } } diff --git a/integration_tests/helpers/ffi/commsConfig.js b/integration_tests/helpers/ffi/commsConfig.js new file mode 100644 index 0000000000..9bb9ddcb7a --- /dev/null +++ b/integration_tests/helpers/ffi/commsConfig.js @@ -0,0 +1,43 @@ +const InterfaceFFI = require("./ffiInterface"); +const utf8 = require("utf8"); + +class CommsConfig { + #comms_config_ptr; + + constructor( + public_address, + transport_ptr, + database_name, + datastore_path, + discovery_timeout_in_secs, + saf_message_duration_in_secs, + network + ) { + let sanitize_address = utf8.encode(public_address); + let sanitize_db_name = utf8.encode(database_name); + let sanitize_db_path = utf8.encode(datastore_path); + let sanitize_network = utf8.encode(network); + this.#comms_config_ptr = InterfaceFFI.commsConfigCreate( + sanitize_address, + transport_ptr, + sanitize_db_name, + sanitize_db_path, + discovery_timeout_in_secs, + saf_message_duration_in_secs, + sanitize_network + ); + } + + getPtr() { + return this.#comms_config_ptr; + } + + destroy() { + if (this.#comms_config_ptr) { + InterfaceFFI.commsConfigDestroy(this.#comms_config_ptr); + this.#comms_config_ptr = undefined; //prevent double free segfault + } + } +} + +module.exports = CommsConfig; diff --git a/integration_tests/helpers/ffi/completedTransaction.js 
b/integration_tests/helpers/ffi/completedTransaction.js index a7a21c28cd..cc23f22ecf 100644 --- a/integration_tests/helpers/ffi/completedTransaction.js +++ b/integration_tests/helpers/ffi/completedTransaction.js @@ -1,23 +1,104 @@ -const WalletFFI = require("./walletFFI"); +const InterfaceFFI = require("./ffiInterface"); +const PublicKey = require("./publicKey"); class CompletedTransaction { #tari_completed_transaction_ptr; - constructor(tari_completed_transaction_ptr) { - this.#tari_completed_transaction_ptr = tari_completed_transaction_ptr; + pointerAssign(ptr) { + if (this.#tari_completed_transaction_ptr) { + this.destroy(); + this.#tari_completed_transaction_ptr = ptr; + } else { + this.#tari_completed_transaction_ptr = ptr; + } + } + + getPtr() { + return this.#tari_completed_transaction_ptr; } isOutbound() { - return WalletFFI.completedTransactionIsOutbound( + return InterfaceFFI.completedTransactionIsOutbound( this.#tari_completed_transaction_ptr ); } - destroy() { - return WalletFFI.completedTransactionDestroy( + getDestinationPublicKey() { + let result = new PublicKey(); + result.pointerAssign( + InterfaceFFI.completedTransactionGetDestinationPublicKey( + this.#tari_completed_transaction_ptr + ) + ); + return result; + } + + getSourcePublicKey() { + let result = new PublicKey(); + result.pointerAssign( + InterfaceFFI.completedTransactionGetSourcePublicKey( + this.#tari_completed_transaction_ptr + ) + ); + return result; + } + + getAmount() { + return InterfaceFFI.completedTransactionGetAmount( + this.#tari_completed_transaction_ptr + ); + } + + getFee() { + return InterfaceFFI.completedTransactionGetFee( + this.#tari_completed_transaction_ptr + ); + } + + getMessage() { + return InterfaceFFI.completedTransactionGetMessage( + this.#tari_completed_transaction_ptr + ); + } + + getStatus() { + return InterfaceFFI.completedTransactionGetStatus( + this.#tari_completed_transaction_ptr + ); + } + + getTransactionID() { + return 
InterfaceFFI.completedTransactionGetTransactionId( + this.#tari_completed_transaction_ptr + ); + } + + getTimestamp() { + return InterfaceFFI.completedTransactionGetTimestamp( + this.#tari_completed_transaction_ptr + ); + } + + isValid() { + return InterfaceFFI.completedTransactionIsValid( + this.#tari_completed_transaction_ptr + ); + } + + getConfirmations() { + return InterfaceFFI.completedTransactionGetConfirmations( this.#tari_completed_transaction_ptr ); } + + destroy() { + if (this.#tari_completed_transaction_ptr) { + InterfaceFFI.completedTransactionDestroy( + this.#tari_completed_transaction_ptr + ); + this.#tari_completed_transaction_ptr = undefined; //prevent double free segfault + } + } } module.exports = CompletedTransaction; diff --git a/integration_tests/helpers/ffi/completedTransactions.js b/integration_tests/helpers/ffi/completedTransactions.js index d2d4c96156..2b8387bb72 100644 --- a/integration_tests/helpers/ffi/completedTransactions.js +++ b/integration_tests/helpers/ffi/completedTransactions.js @@ -1,38 +1,37 @@ const CompletedTransaction = require("./completedTransaction"); -const WalletFFI = require("./walletFFI"); +const InterfaceFFI = require("./ffiInterface"); class CompletedTransactions { #tari_completed_transactions_ptr; - constructor(tari_completed_transactions_ptr) { - this.#tari_completed_transactions_ptr = tari_completed_transactions_ptr; - } - - static async fromWallet(wallet) { - return new CompletedTransactions( - await WalletFFI.walletGetCompletedTransactions(wallet) - ); + constructor(ptr) { + this.#tari_completed_transactions_ptr = ptr; } getLength() { - return WalletFFI.completedTransactionsGetLength( + return InterfaceFFI.completedTransactionsGetLength( this.#tari_completed_transactions_ptr ); } - async getAt(position) { - return new CompletedTransaction( - await WalletFFI.completedTransactionsGetAt( + getAt(position) { + let result = new CompletedTransaction(); + result.pointerAssign( + 
InterfaceFFI.completedTransactionsGetAt( this.#tari_completed_transactions_ptr, position ) ); + return result; } destroy() { - return WalletFFI.completedTransactionsDestroy( - this.#tari_completed_transactions_ptr - ); + if (this.#tari_completed_transactions_ptr) { + InterfaceFFI.completedTransactionsDestroy( + this.#tari_completed_transactions_ptr + ); + this.#tari_completed_transactions_ptr = undefined; //prevent double free segfault + } } } diff --git a/integration_tests/helpers/ffi/contact.js b/integration_tests/helpers/ffi/contact.js index 184c684a2b..ea72376e75 100644 --- a/integration_tests/helpers/ffi/contact.js +++ b/integration_tests/helpers/ffi/contact.js @@ -1,32 +1,52 @@ const PublicKey = require("./publicKey"); -const WalletFFI = require("./walletFFI"); +const InterfaceFFI = require("./ffiInterface"); class Contact { #tari_contact_ptr; - constructor(tari_contact_ptr) { - this.#tari_contact_ptr = tari_contact_ptr; + pointerAssign(ptr) { + // Prevent pointer from being leaked in case of re-assignment + if (this.#tari_contact_ptr) { + this.destroy(); + this.#tari_contact_ptr = ptr; + } else { + this.#tari_contact_ptr = ptr; + } } getPtr() { return this.#tari_contact_ptr; } - async getAlias() { - const alias = await WalletFFI.contactGetAlias(this.#tari_contact_ptr); + getAlias() { + const alias = InterfaceFFI.contactGetAlias(this.#tari_contact_ptr); const result = alias.readCString(); - await WalletFFI.stringDestroy(alias); + InterfaceFFI.stringDestroy(alias); return result; } - async getPubkey() { - return new PublicKey( - await WalletFFI.contactGetPublicKey(this.#tari_contact_ptr) + getPubkey() { + let result = new PublicKey(); + result.pointerAssign( + InterfaceFFI.contactGetPublicKey(this.#tari_contact_ptr) ); + return result; + } + + getPubkeyHex() { + let result = ""; + let pk = new PublicKey(); + pk.pointerAssign(InterfaceFFI.contactGetPublicKey(this.#tari_contact_ptr)); + result = pk.getHex(); + pk.destroy(); + return result; } destroy() { - 
return WalletFFI.contactDestroy(this.#tari_contact_ptr); + if (this.#tari_contact_ptr) { + InterfaceFFI.contactDestroy(this.#tari_contact_ptr); + this.#tari_contact_ptr = undefined; //prevent double free segfault + } } } diff --git a/integration_tests/helpers/ffi/contacts.js b/integration_tests/helpers/ffi/contacts.js index d8803874ab..1f7db81fcc 100644 --- a/integration_tests/helpers/ffi/contacts.js +++ b/integration_tests/helpers/ffi/contacts.js @@ -1,29 +1,30 @@ const Contact = require("./contact"); -const WalletFFI = require("./walletFFI"); +const InterfaceFFI = require("./ffiInterface"); class Contacts { #tari_contacts_ptr; - constructor(tari_contacts_ptr) { - this.#tari_contacts_ptr = tari_contacts_ptr; - } - - static async fromWallet(wallet) { - return new Contacts(await WalletFFI.walletGetContacts(wallet)); + constructor(ptr) { + this.#tari_contacts_ptr = ptr; } getLength() { - return WalletFFI.contactsGetLength(this.#tari_contacts_ptr); + return InterfaceFFI.contactsGetLength(this.#tari_contacts_ptr); } - async getAt(position) { - return new Contact( - await WalletFFI.contactsGetAt(this.#tari_contacts_ptr, position) + getAt(position) { + let result = new Contact(); + result.pointerAssign( + InterfaceFFI.contactsGetAt(this.#tari_contacts_ptr, position) + ); + return result; } destroy() { - return WalletFFI.contactsDestroy(this.#tari_contacts_ptr); + if (this.#tari_contacts_ptr) { + InterfaceFFI.contactsDestroy(this.#tari_contacts_ptr); + this.#tari_contacts_ptr = undefined; //prevent double free segfault + } } } diff --git a/integration_tests/helpers/ffi/emojiSet.js b/integration_tests/helpers/ffi/emojiSet.js new file mode 100644 index 0000000000..f94e2ae746 --- /dev/null +++ b/integration_tests/helpers/ffi/emojiSet.js @@ -0,0 +1,41 @@ +const InterfaceFFI = require("./ffiInterface"); +const ByteVector = require("./byteVector"); + +class EmojiSet { + #emoji_set_ptr; + + constructor() { + this.#emoji_set_ptr = InterfaceFFI.getEmojiSet(); + } + + getLength() { + return
InterfaceFFI.emojiSetGetLength(this.#emoji_set_ptr); + } + + getAt(position) { + let result = new ByteVector(); + result.pointerAssign( + InterfaceFFI.emojiSetGetAt(this.#emoji_set_ptr, position) + ); + return result; + } + + list() { + let set = []; + for (let i = 0; i < this.getLength(); i++) { + let item = this.getAt(i); + set.push(Buffer.from(item.getBytes(), "utf-8").toString()); + item.destroy(); + } + return set; + } + + destroy() { + if (this.#emoji_set_ptr) { + InterfaceFFI.emojiSetDestroy(this.#emoji_set_ptr); + this.#emoji_set_ptr = undefined; //prevent double free segfault + } + } +} + +module.exports = EmojiSet; diff --git a/integration_tests/helpers/ffi/ffiInterface.js b/integration_tests/helpers/ffi/ffiInterface.js new file mode 100644 index 0000000000..e39632cd05 --- /dev/null +++ b/integration_tests/helpers/ffi/ffiInterface.js @@ -0,0 +1,1473 @@ +/** + * This library was AUTO-GENERATED. Do not modify manually! + */ + +const { expect } = require("chai"); +const ffi = require("ffi-napi"); +const ref = require("ref-napi"); +const dateFormat = require("dateformat"); +const { spawn } = require("child_process"); +const fs = require("fs"); + +class InterfaceFFI { + //region Compile + static compile() { + return new Promise((resolve, _reject) => { + const cmd = "cargo"; + const args = [ + "build", + "--release", + "--package", + "tari_wallet_ffi", + "-Z", + "unstable-options", + "--out-dir", + process.cwd() + "/temp/out", + ]; + const baseDir = `./temp/base_nodes/${dateFormat( + new Date(), + "yyyymmddHHMM" + )}/WalletFFI-compile`; + if (!fs.existsSync(baseDir)) { + fs.mkdirSync(baseDir, { recursive: true }); + fs.mkdirSync(baseDir + "/log", { recursive: true }); + } + const ps = spawn(cmd, args, { + cwd: baseDir, + env: { ...process.env }, + }); + ps.on("close", (_code) => { + resolve(ps); + }); + ps.stderr.on("data", (data) => { + console.log("stderr : ", data.toString()); + }); + ps.on("error", (error) => { + console.log("error : ", error.toString()); + }); + expect(ps.error).to.be.an("undefined"); + this.#ps = ps; + });
+ } + //endregion + + //region Interface + static #fn; + + static #loaded = false; + static #ps = null; + + static async Init() { + if (this.#loaded) { + return; + } + + this.#loaded = true; + await this.compile(); + const outputProcess = `${process.cwd()}/temp/out/${ + process.platform === "win32" ? "" : "lib" + }tari_wallet_ffi`; + + // Load the library + this.#fn = ffi.Library(outputProcess, { + transport_memory_create: ["pointer", ["void"]], + transport_tcp_create: ["pointer", ["string", "int*"]], + transport_tor_create: [ + "pointer", + ["string", "pointer", "ushort", "string", "string", "int*"], + ], + transport_memory_get_address: ["char*", ["pointer", "int*"]], + transport_type_destroy: ["void", ["pointer"]], + string_destroy: ["void", ["string"]], + byte_vector_create: ["pointer", ["uchar*", "uint", "int*"]], + byte_vector_get_at: ["uchar", ["pointer", "uint", "int*"]], + byte_vector_get_length: ["uint", ["pointer", "int*"]], + byte_vector_destroy: ["void", ["pointer"]], + public_key_create: ["pointer", ["pointer", "int*"]], + public_key_get_bytes: ["pointer", ["pointer", "int*"]], + public_key_from_private_key: ["pointer", ["pointer", "int*"]], + public_key_from_hex: ["pointer", ["string", "int*"]], + public_key_destroy: ["void", ["pointer"]], + public_key_to_emoji_id: ["char*", ["pointer", "int*"]], + emoji_id_to_public_key: ["pointer", ["string", "int*"]], + private_key_create: ["pointer", ["pointer", "int*"]], + private_key_generate: ["pointer", ["void"]], + private_key_get_bytes: ["pointer", ["pointer", "int*"]], + private_key_from_hex: ["pointer", ["string", "int*"]], + private_key_destroy: ["void", ["pointer"]], + seed_words_create: ["pointer", ["void"]], + seed_words_get_length: ["uint", ["pointer", "int*"]], + seed_words_get_at: ["char*", ["pointer", "uint", "int*"]], + seed_words_push_word: ["uchar", ["pointer", "string", "int*"]], + seed_words_destroy: ["void", ["pointer"]], + contact_create: ["pointer", ["string", "pointer", "int*"]], + 
contact_get_alias: ["char*", ["pointer", "int*"]], + contact_get_public_key: ["pointer", ["pointer", "int*"]], + contact_destroy: ["void", ["pointer"]], + contacts_get_length: ["uint", ["pointer", "int*"]], + contacts_get_at: ["pointer", ["pointer", "uint", "int*"]], + contacts_destroy: ["void", ["pointer"]], + completed_transaction_get_destination_public_key: [ + "pointer", + ["pointer", "int*"], + ], + completed_transaction_get_source_public_key: [ + "pointer", + ["pointer", "int*"], + ], + completed_transaction_get_amount: ["uint64", ["pointer", "int*"]], + completed_transaction_get_fee: ["uint64", ["pointer", "int*"]], + completed_transaction_get_message: ["char*", ["pointer", "int*"]], + completed_transaction_get_status: ["int", ["pointer", "int*"]], + completed_transaction_get_transaction_id: ["uint64", ["pointer", "int*"]], + completed_transaction_get_timestamp: ["uint64", ["pointer", "int*"]], + completed_transaction_is_valid: ["bool", ["pointer", "int*"]], + completed_transaction_is_outbound: ["bool", ["pointer", "int*"]], + completed_transaction_get_confirmations: ["uint64", ["pointer", "int*"]], + completed_transaction_destroy: ["void", ["pointer"]], + //completed_transaction_get_excess: [ + //this.tari_excess_ptr, + // [this.tari_completed_transaction_ptr, "int*"], + //], + //completed_transaction_get_public_nonce: [ + // this.tari_excess_public_nonce_ptr, + // [this.tari_completed_transaction_ptr, "int*"], + //], + //completed_transaction_get_signature: [ + // this.tari_excess_signature_ptr, + // [this.tari_completed_transaction_ptr, "int*"], + //], + // excess_destroy: ["void", [this.tari_excess_ptr]], + // nonce_destroy: ["void", [this.tari_excess_public_nonce_ptr]], + // signature_destroy: ["void", [this.tari_excess_signature_ptr]], + completed_transactions_get_length: ["uint", ["pointer", "int*"]], + completed_transactions_get_at: ["pointer", ["pointer", "uint", "int*"]], + completed_transactions_destroy: ["void", ["pointer"]], + 
pending_outbound_transaction_get_transaction_id: [ + "uint64", + ["pointer", "int*"], + ], + pending_outbound_transaction_get_destination_public_key: [ + "pointer", + ["pointer", "int*"], + ], + pending_outbound_transaction_get_amount: ["uint64", ["pointer", "int*"]], + pending_outbound_transaction_get_fee: ["uint64", ["pointer", "int*"]], + pending_outbound_transaction_get_message: ["char*", ["pointer", "int*"]], + pending_outbound_transaction_get_timestamp: [ + "uint64", + ["pointer", "int*"], + ], + pending_outbound_transaction_get_status: ["int", ["pointer", "int*"]], + pending_outbound_transaction_destroy: ["void", ["pointer"]], + pending_outbound_transactions_get_length: ["uint", ["pointer", "int*"]], + pending_outbound_transactions_get_at: [ + "pointer", + ["pointer", "uint", "int*"], + ], + pending_outbound_transactions_destroy: ["void", ["pointer"]], + pending_inbound_transaction_get_transaction_id: [ + "uint64", + ["pointer", "int*"], + ], + pending_inbound_transaction_get_source_public_key: [ + "pointer", + ["pointer", "int*"], + ], + pending_inbound_transaction_get_message: ["char*", ["pointer", "int*"]], + pending_inbound_transaction_get_amount: ["uint64", ["pointer", "int*"]], + pending_inbound_transaction_get_timestamp: [ + "uint64", + ["pointer", "int*"], + ], + pending_inbound_transaction_get_status: ["int", ["pointer", "int*"]], + pending_inbound_transaction_destroy: ["void", ["pointer"]], + pending_inbound_transactions_get_length: ["uint", ["pointer", "int*"]], + pending_inbound_transactions_get_at: [ + "pointer", + ["pointer", "uint", "int*"], + ], + pending_inbound_transactions_destroy: ["void", ["pointer"]], + comms_config_create: [ + "pointer", + [ + "string", + "pointer", + "string", + "string", + "uint64", + "uint64", + "string", + "int*", + ], + ], + comms_config_destroy: ["void", ["pointer"]], + wallet_create: [ + "pointer", + [ + "pointer", + "string", + "uint", + "uint", + "string", + "pointer", + "pointer", + "pointer", + "pointer", + 
"pointer", + "pointer", + "pointer", + "pointer", + "pointer", + "pointer", + "pointer", + "pointer", + "pointer", + "pointer", + "pointer", + "bool*", + "int*", + ], + ], + wallet_sign_message: ["char*", ["pointer", "string", "int*"]], + wallet_verify_message_signature: [ + "bool", + ["pointer", "pointer", "string", "string", "int*"], + ], + wallet_add_base_node_peer: [ + "bool", + ["pointer", "pointer", "string", "int*"], + ], + wallet_upsert_contact: ["bool", ["pointer", "pointer", "int*"]], + wallet_remove_contact: ["bool", ["pointer", "pointer", "int*"]], + wallet_get_available_balance: ["uint64", ["pointer", "int*"]], + wallet_get_pending_incoming_balance: ["uint64", ["pointer", "int*"]], + wallet_get_pending_outgoing_balance: ["uint64", ["pointer", "int*"]], + wallet_get_fee_estimate: [ + "uint64", + ["pointer", "uint64", "uint64", "uint64", "uint64", "int*"], + ], + wallet_get_num_confirmations_required: ["uint64", ["pointer", "int*"]], + wallet_set_num_confirmations_required: [ + "void", + ["pointer", "uint64", "int*"], + ], + wallet_send_transaction: [ + "uint64", + ["pointer", "pointer", "uint64", "uint64", "string", "int*"], + ], + wallet_get_contacts: ["pointer", ["pointer", "int*"]], + wallet_get_completed_transactions: ["pointer", ["pointer", "int*"]], + wallet_get_pending_outbound_transactions: [ + "pointer", + ["pointer", "int*"], + ], + wallet_get_public_key: ["pointer", ["pointer", "int*"]], + wallet_get_pending_inbound_transactions: ["pointer", ["pointer", "int*"]], + wallet_get_cancelled_transactions: ["pointer", ["pointer", "int*"]], + wallet_get_completed_transaction_by_id: [ + "pointer", + ["pointer", "uint64", "int*"], + ], + wallet_get_pending_outbound_transaction_by_id: [ + "pointer", + ["pointer", "uint64", "int*"], + ], + wallet_get_pending_inbound_transaction_by_id: [ + "pointer", + ["pointer", "uint64", "int*"], + ], + wallet_get_cancelled_transaction_by_id: [ + "pointer", + ["pointer", "uint64", "int*"], + ], + wallet_import_utxo: [ 
+ "uint64", + ["pointer", "uint64", "pointer", "pointer", "string", "int*"], + ], + wallet_start_utxo_validation: ["uint64", ["pointer", "int*"]], + wallet_start_stxo_validation: ["uint64", ["pointer", "int*"]], + wallet_start_invalid_txo_validation: ["uint64", ["pointer", "int*"]], + wallet_start_transaction_validation: ["uint64", ["pointer", "int*"]], + wallet_restart_transaction_broadcast: ["bool", ["pointer", "int*"]], + wallet_set_low_power_mode: ["void", ["pointer", "int*"]], + wallet_set_normal_power_mode: ["void", ["pointer", "int*"]], + wallet_cancel_pending_transaction: [ + "bool", + ["pointer", "uint64", "int*"], + ], + wallet_coin_split: [ + "uint64", + ["pointer", "uint64", "uint64", "uint64", "string", "uint64", "int*"], + ], + wallet_get_seed_words: ["pointer", ["pointer", "int*"]], + wallet_apply_encryption: ["void", ["pointer", "string", "int*"]], + wallet_remove_encryption: ["void", ["pointer", "int*"]], + wallet_set_key_value: ["bool", ["pointer", "string", "string", "int*"]], + wallet_get_value: ["char*", ["pointer", "string", "int*"]], + wallet_clear_value: ["bool", ["pointer", "string", "int*"]], + wallet_is_recovery_in_progress: ["bool", ["pointer", "int*"]], + wallet_start_recovery: [ + "bool", + ["pointer", "pointer", "pointer", "int*"], + ], + wallet_destroy: ["void", ["pointer"]], + file_partial_backup: ["void", ["string", "string", "int*"]], + log_debug_message: ["void", ["string"]], + get_emoji_set: ["pointer", ["void"]], + emoji_set_destroy: ["void", ["pointer"]], + emoji_set_get_at: ["pointer", ["pointer", "uint", "int*"]], + emoji_set_get_length: ["uint", ["pointer", "int*"]], + }); + } + //endregion + + static checkErrorResult(error, error_name) { + expect(error.deref()).to.equal(0, `Error in ${error_name}`); + } + + //region Helpers + static initError() { + let error = Buffer.alloc(4); + error.writeInt32LE(-1, 0); + error.type = ref.types.int; + return error; + } + + static initBool() { + let boolean = ref.alloc(ref.types.bool); + 
return boolean; + } + + static filePartialBackup(original_file_path, backup_file_path) { + let error = this.initError(); + let result = this.#fn.file_partial_backup( + original_file_path, + backup_file_path, + error + ); + this.checkErrorResult(error, `filePartialBackup`); + return result; + } + + static logDebugMessage(msg) { + this.#fn.log_debug_message(msg); + } + //endregion + + //region String + static stringDestroy(s) { + this.#fn.string_destroy(s); + } + //endregion + + // region ByteVector + static byteVectorCreate(byte_array, element_count) { + let error = this.initError(); + let result = this.#fn.byte_vector_create(byte_array, element_count, error); + this.checkErrorResult(error, `byteVectorCreate`); + return result; + } + + static byteVectorGetAt(ptr, i) { + let error = this.initError(); + let result = this.#fn.byte_vector_get_at(ptr, i, error); + this.checkErrorResult(error, `byteVectorGetAt`); + return result; + } + + static byteVectorGetLength(ptr) { + let error = this.initError(); + let result = this.#fn.byte_vector_get_length(ptr, error); + this.checkErrorResult(error, `byteVectorGetLength`); + return result; + } + + static byteVectorDestroy(ptr) { + this.#fn.byte_vector_destroy(ptr); + } + //endregion + + //region PrivateKey + static privateKeyCreate(ptr) { + let error = this.initError(); + let result = this.#fn.private_key_create(ptr, error); + this.checkErrorResult(error, `privateKeyCreate`); + return result; + } + + static privateKeyGenerate() { + return this.#fn.private_key_generate(); + } + + static privateKeyGetBytes(ptr) { + let error = this.initError(); + let result = this.#fn.private_key_get_bytes(ptr, error); + this.checkErrorResult(error, "privateKeyGetBytes"); + return result; + } + + static privateKeyFromHex(hex) { + let error = this.initError(); + let result = this.#fn.private_key_from_hex(hex, error); + this.checkErrorResult(error, "privateKeyFromHex"); + return result; + } + + static privateKeyDestroy(ptr) { + 
this.#fn.private_key_destroy(ptr); + } + + //endregion + + //region PublicKey + static publicKeyCreate(ptr) { + let error = this.initError(); + let result = this.#fn.public_key_create(ptr, error); + this.checkErrorResult(error, `publicKeyCreate`); + return result; + } + + static publicKeyGetBytes(ptr) { + let error = this.initError(); + let result = this.#fn.public_key_get_bytes(ptr, error); + this.checkErrorResult(error, `publicKeyGetBytes`); + return result; + } + + static publicKeyFromPrivateKey(ptr) { + let error = this.initError(); + let result = this.#fn.public_key_from_private_key(ptr, error); + this.checkErrorResult(error, `publicKeyFromPrivateKey`); + return result; + } + + static publicKeyFromHex(hex) { + let error = this.initError(); + let result = this.#fn.public_key_from_hex(hex, error); + this.checkErrorResult(error, `publicKeyFromHex`); + return result; + } + + static emojiIdToPublicKey(emoji) { + let error = this.initError(); + let result = this.#fn.emoji_id_to_public_key(emoji, error); + this.checkErrorResult(error, `emojiIdToPublicKey`); + return result; + } + + static publicKeyToEmojiId(ptr) { + let error = this.initError(); + let result = this.#fn.public_key_to_emoji_id(ptr, error); + this.checkErrorResult(error, `publicKeyToEmojiId`); + return result; + } + + static publicKeyDestroy(ptr) { + this.#fn.public_key_destroy(ptr); + } + //endregion + + //region TransportType + static transportMemoryCreate() { + return this.#fn.transport_memory_create(); + } + + static transportTcpCreate(listener_address) { + let error = this.initError(); + let result = this.#fn.transport_tcp_create(listener_address, error); + this.checkErrorResult(error, `transportTcpCreate`); + return result; + } + + static transportTorCreate( + control_server_address, + tor_cookie, + tor_port, + socks_username, + socks_password + ) { + let error = this.initError(); + let result = this.#fn.transport_tor_create( + control_server_address, + tor_cookie, + tor_port, + socks_username, + 
socks_password, + error + ); + this.checkErrorResult(error, `transportTorCreate`); + return result; + } + + static transportMemoryGetAddress(transport) { + let error = this.initError(); + let result = this.#fn.transport_memory_get_address(transport, error); + this.checkErrorResult(error, `transportMemoryGetAddress`); + return result; + } + + static transportTypeDestroy(transport) { + this.#fn.transport_type_destroy(transport); + } + //endregion + + //region EmojiSet + static getEmojiSet() { + return this.#fn.get_emoji_set(); + } + + static emojiSetDestroy(ptr) { + this.#fn.emoji_set_destroy(ptr); + } + + static emojiSetGetAt(ptr, position) { + let error = this.initError(); + let result = this.#fn.emoji_set_get_at(ptr, position, error); + this.checkErrorResult(error, `emojiSetGetAt`); + return result; + } + + static emojiSetGetLength(ptr) { + let error = this.initError(); + let result = this.#fn.emoji_set_get_length(ptr, error); + this.checkErrorResult(error, `emojiSetGetLength`); + return result; + } + //endregion + + //region SeedWords + static seedWordsCreate() { + return this.#fn.seed_words_create(); + } + + static seedWordsGetLength(ptr) { + let error = this.initError(); + let result = this.#fn.seed_words_get_length(ptr, error); + this.checkErrorResult(error, `seedWordsGetLength`); + return result; + } + + static seedWordsGetAt(ptr, position) { + let error = this.initError(); + let result = this.#fn.seed_words_get_at(ptr, position, error); + this.checkErrorResult(error, `seedWordsGetAt`); + return result; + } + + static seedWordsPushWord(ptr, word) { + let error = this.initError(); + let result = this.#fn.seed_words_push_word(ptr, word, error); + this.checkErrorResult(error, `seedWordsPushWord`); + return result; + } + + static seedWordsDestroy(ptr) { + this.#fn.seed_words_destroy(ptr); + } + //endregion + + //region CommsConfig + static commsConfigCreate( + public_address, + transport, + database_name, + datastore_path, + discovery_timeout_in_secs, +
saf_message_duration_in_secs, + network + ) { + let error = this.initError(); + let result = this.#fn.comms_config_create( + public_address, + transport, + database_name, + datastore_path, + discovery_timeout_in_secs, + saf_message_duration_in_secs, + network, + error + ); + this.checkErrorResult(error, `commsConfigCreate`); + return result; + } + + static commsConfigDestroy(ptr) { + this.#fn.comms_config_destroy(ptr); + } + //endregion + + //region Contact + static contactCreate(alias, public_key) { + let error = this.initError(); + let result = this.#fn.contact_create(alias, public_key, error); + this.checkErrorResult(error, `contactCreate`); + return result; + } + + static contactGetAlias(ptr) { + let error = this.initError(); + let result = this.#fn.contact_get_alias(ptr, error); + this.checkErrorResult(error, `contactGetAlias`); + return result; + } + + static contactGetPublicKey(ptr) { + let error = this.initError(); + let result = this.#fn.contact_get_public_key(ptr, error); + this.checkErrorResult(error, `contactGetPublicKey`); + return result; + } + + static contactDestroy(ptr) { + this.#fn.contact_destroy(ptr); + } + //endregion + + //region Contacts (List) + static contactsGetLength(ptr) { + let error = this.initError(); + let result = this.#fn.contacts_get_length(ptr, error); + this.checkErrorResult(error, `contactsGetLength`); + return result; + } + + static contactsGetAt(ptr, position) { + let error = this.initError(); + let result = this.#fn.contacts_get_at(ptr, position, error); + this.checkErrorResult(error, `contactsGetAt`); + return result; + } + + static contactsDestroy(ptr) { + this.#fn.contacts_destroy(ptr); + } + //endregion + + //region CompletedTransaction + static completedTransactionGetDestinationPublicKey(ptr) { + let error = this.initError(); + let result = this.#fn.completed_transaction_get_destination_public_key( + ptr, + error + ); + this.checkErrorResult(error, `completedTransactionGetDestinationPublicKey`); + return result; + } + + 
static completedTransactionGetSourcePublicKey(ptr) { + let error = this.initError(); + let result = this.#fn.completed_transaction_get_source_public_key( + ptr, + error + ); + this.checkErrorResult(error, `completedTransactionGetSourcePublicKey`); + return result; + } + + static completedTransactionGetAmount(ptr) { + let error = this.initError(); + let result = this.#fn.completed_transaction_get_amount(ptr, error); + this.checkErrorResult(error, `completedTransactionGetAmount`); + return result; + } + + static completedTransactionGetFee(ptr) { + let error = this.initError(); + let result = this.#fn.completed_transaction_get_fee(ptr, error); + this.checkErrorResult(error, `completedTransactionGetFee`); + return result; + } + + static completedTransactionGetMessage(ptr) { + let error = this.initError(); + let result = this.#fn.completed_transaction_get_message(ptr, error); + this.checkErrorResult(error, `completedTransactionGetMessage`); + return result; + } + + static completedTransactionGetStatus(ptr) { + let error = this.initError(); + let result = this.#fn.completed_transaction_get_status(ptr, error); + this.checkErrorResult(error, `completedTransactionGetStatus`); + return result; + } + + static completedTransactionGetTransactionId(ptr) { + let error = this.initError(); + let result = this.#fn.completed_transaction_get_transaction_id(ptr, error); + this.checkErrorResult(error, `completedTransactionGetTransactionId`); + return result; + } + + static completedTransactionGetTimestamp(ptr) { + let error = this.initError(); + let result = this.#fn.completed_transaction_get_timestamp(ptr, error); + this.checkErrorResult(error, `completedTransactionGetTimestamp`); + return result; + } + + static completedTransactionIsValid(ptr) { + let error = this.initError(); + let result = this.#fn.completed_transaction_is_valid(ptr, error); + this.checkErrorResult(error, `completedTransactionIsValid`); + return result; + } + + static completedTransactionIsOutbound(ptr) { + let 
error = this.initError(); + let result = this.#fn.completed_transaction_is_outbound(ptr, error); + this.checkErrorResult(error, `completedTransactionIsOutbound`); + return result; + } + + static completedTransactionGetConfirmations(ptr) { + let error = this.initError(); + let result = this.#fn.completed_transaction_get_confirmations(ptr, error); + this.checkErrorResult(error, `completedTransactionGetConfirmations`); + return result; + } + + static completedTransactionDestroy(ptr) { + this.#fn.completed_transaction_destroy(ptr); + } + + //endregion + + /* + //Flagged as design flaw in the FFI lib + + static completedTransactionGetExcess(transaction) { + return new Promise((resolve, reject) => + this.#fn.completed_transaction_get_excess.async( + transaction, + this.error, + this.checkAsyncRes(resolve, reject, "completedTransactionGetExcess") + ) + ); + } + + static completedTransactionGetPublicNonce(transaction) { + return new Promise((resolve, reject) => + this.#fn.completed_transaction_get_public_nonce.async( + transaction, + this.error, + this.checkAsyncRes( + resolve, + reject, + "completedTransactionGetPublicNonce" + ) + ) + ); + } + + static completedTransactionGetSignature(transaction) { + return new Promise((resolve, reject) => + this.#fn.completed_transaction_get_signature.async( + transaction, + this.error, + this.checkAsyncRes(resolve, reject, "completedTransactionGetSignature") + ) + ); + } + + static excessDestroy(excess) { + return new Promise((resolve, reject) => + this.#fn.excess_destroy.async( + excess, + this.checkAsyncRes(resolve, reject, "excessDestroy") + ) + ); + } + + static nonceDestroy(nonce) { + return new Promise((resolve, reject) => + this.#fn.nonce_destroy.async( + nonce, + this.checkAsyncRes(resolve, reject, "nonceDestroy") + ) + ); + } + + static signatureDestroy(signature) { + return new Promise((resolve, reject) => + this.#fn.signature_destroy.async( + signature, + this.checkAsyncRes(resolve, reject, "signatureDestroy") + ) + ); 
+ } + */ + + //region CompletedTransactions (List) + static completedTransactionsGetLength(ptr) { + let error = this.initError(); + let result = this.#fn.completed_transactions_get_length(ptr, error); + this.checkErrorResult(error, `completedTransactionsGetLength`); + return result; + } + + static completedTransactionsGetAt(ptr, position) { + let error = this.initError(); + let result = this.#fn.completed_transactions_get_at(ptr, position, error); + this.checkErrorResult(error, `completedTransactionsGetAt`); + return result; + } + + static completedTransactionsDestroy(transactions) { + this.#fn.completed_transactions_destroy(transactions); + } + //endregion + + //region PendingOutboundTransaction + static pendingOutboundTransactionGetTransactionId(ptr) { + let error = this.initError(); + let result = this.#fn.pending_outbound_transaction_get_transaction_id( + ptr, + error + ); + this.checkErrorResult(error, `pendingOutboundTransactionGetTransactionId`); + return result; + } + + static pendingOutboundTransactionGetDestinationPublicKey(ptr) { + let error = this.initError(); + let result = + this.#fn.pending_outbound_transaction_get_destination_public_key( + ptr, + error + ); + this.checkErrorResult( + error, + `pendingOutboundTransactionGetDestinationPublicKey` + ); + return result; + } + + static pendingOutboundTransactionGetAmount(ptr) { + let error = this.initError(); + let result = this.#fn.pending_outbound_transaction_get_amount(ptr, error); + this.checkErrorResult(error, `pendingOutboundTransactionGetAmount`); + return result; + } + + static pendingOutboundTransactionGetFee(ptr) { + let error = this.initError(); + let result = this.#fn.pending_outbound_transaction_get_fee(ptr, error); + this.checkErrorResult(error, `pendingOutboundTransactionGetFee`); + return result; + } + + static pendingOutboundTransactionGetMessage(ptr) { + let error = this.initError(); + let result = this.#fn.pending_outbound_transaction_get_message(ptr, error); + this.checkErrorResult(error, 
`pendingOutboundTransactionGetMessage`); + return result; + } + + static pendingOutboundTransactionGetTimestamp(ptr) { + let error = this.initError(); + let result = this.#fn.pending_outbound_transaction_get_timestamp( + ptr, + error + ); + this.checkErrorResult(error, `pendingOutboundTransactionGetTimestamp`); + return result; + } + + static pendingOutboundTransactionGetStatus(ptr) { + let error = this.initError(); + let result = this.#fn.pending_outbound_transaction_get_status(ptr, error); + this.checkErrorResult(error, `pendingOutboundTransactionGetStatus`); + return result; + } + + static pendingOutboundTransactionDestroy(ptr) { + this.#fn.pending_outbound_transaction_destroy(ptr); + } + //endregion + + //region PendingOutboundTransactions (List) + static pendingOutboundTransactionsGetLength(ptr) { + let error = this.initError(); + let result = this.#fn.pending_outbound_transactions_get_length(ptr, error); + this.checkErrorResult(error, `pendingOutboundTransactionsGetLength`); + return result; + } + + static pendingOutboundTransactionsGetAt(ptr, position) { + let error = this.initError(); + let result = this.#fn.pending_outbound_transactions_get_at( + ptr, + position, + error + ); + this.checkErrorResult(error, `pendingOutboundTransactionsGetAt`); + return result; + } + + static pendingOutboundTransactionsDestroy(ptr) { + this.#fn.pending_outbound_transactions_destroy(ptr); + } + //endregion + + //region PendingInboundTransaction + static pendingInboundTransactionGetTransactionId(ptr) { + let error = this.initError(); + let result = this.#fn.pending_inbound_transaction_get_transaction_id( + ptr, + error + ); + this.checkErrorResult(error, `pendingInboundTransactionGetTransactionId`); + return result; + } + + static pendingInboundTransactionGetSourcePublicKey(ptr) { + let error = this.initError(); + let result = this.#fn.pending_inbound_transaction_get_source_public_key( + ptr, + error + ); + this.checkErrorResult(error, 
`pendingInboundTransactionGetSourcePublicKey`); + return result; + } + + static pendingInboundTransactionGetMessage(ptr) { + let error = this.initError(); + let result = this.#fn.pending_inbound_transaction_get_message(ptr, error); + this.checkErrorResult(error, `pendingInboundTransactionGetMessage`); + return result; + } + + static pendingInboundTransactionGetAmount(ptr) { + let error = this.initError(); + let result = this.#fn.pending_inbound_transaction_get_amount(ptr, error); + this.checkErrorResult(error, `pendingInboundTransactionGetAmount`); + return result; + } + + static pendingInboundTransactionGetTimestamp(ptr) { + let error = this.initError(); + let result = this.#fn.pending_inbound_transaction_get_timestamp(ptr, error); + this.checkErrorResult(error, `pendingInboundTransactionGetTimestamp`); + return result; + } + + static pendingInboundTransactionGetStatus(ptr) { + let error = this.initError(); + let result = this.#fn.pending_inbound_transaction_get_status(ptr, error); + this.checkErrorResult(error, `pendingInboundTransactionGetStatus`); + return result; + } + + static pendingInboundTransactionDestroy(ptr) { + this.#fn.pending_inbound_transaction_destroy(ptr); + } + //endregion + + //region PendingInboundTransactions (List) + static pendingInboundTransactionsGetLength(ptr) { + let error = this.initError(); + let result = this.#fn.pending_inbound_transactions_get_length(ptr, error); + this.checkErrorResult(error, `pendingInboundTransactionsGetLength`); + return result; + } + + static pendingInboundTransactionsGetAt(ptr, position) { + let error = this.initError(); + let result = this.#fn.pending_inbound_transactions_get_at( + ptr, + position, + error + ); + this.checkErrorResult(error, `pendingInboundTransactionsGetAt`); + return result; + } + + static pendingInboundTransactionsDestroy(ptr) { + this.#fn.pending_inbound_transactions_destroy(ptr); + } + //endregion + + //region Wallet + + //region Callbacks + static createCallbackReceivedTransaction(fn) { 
+ return ffi.Callback("void", ["pointer"], fn); + } + + static createCallbackReceivedTransactionReply(fn) { + return ffi.Callback("void", ["pointer"], fn); + } + + static createCallbackReceivedFinalizedTransaction(fn) { + return ffi.Callback("void", ["pointer"], fn); + } + + static createCallbackTransactionBroadcast(fn) { + return ffi.Callback("void", ["pointer"], fn); + } + + static createCallbackTransactionMined(fn) { + return ffi.Callback("void", ["pointer"], fn); + } + + static createCallbackTransactionMinedUnconfirmed(fn) { + return ffi.Callback("void", ["pointer", "uint64"], fn); + } + + static createCallbackDirectSendResult(fn) { + return ffi.Callback("void", ["uint64", "bool"], fn); + } + + static createCallbackStoreAndForwardSendResult(fn) { + return ffi.Callback("void", ["uint64", "bool"], fn); + } + + static createCallbackTransactionCancellation(fn) { + return ffi.Callback("void", ["pointer"], fn); + } + static createCallbackUtxoValidationComplete(fn) { + return ffi.Callback("void", ["uint64", "uchar"], fn); + } + static createCallbackStxoValidationComplete(fn) { + return ffi.Callback("void", ["uint64", "uchar"], fn); + } + static createCallbackInvalidTxoValidationComplete(fn) { + return ffi.Callback("void", ["uint64", "uchar"], fn); + } + static createCallbackTransactionValidationComplete(fn) { + return ffi.Callback("void", ["uint64", "uchar"], fn); + } + static createCallbackSafMessageReceived(fn) { + return ffi.Callback("void", ["void"], fn); + } + static createRecoveryProgressCallback(fn) { + return ffi.Callback("void", ["uchar", "uint64", "uint64"], fn); + } + //endregion + + static walletCreate( + config, + log_path, + num_rolling_log_files, + size_per_log_file_bytes, + passphrase, + seed_words, + callback_received_transaction, + callback_received_transaction_reply, + callback_received_finalized_transaction, + callback_transaction_broadcast, + callback_transaction_mined, + callback_transaction_mined_unconfirmed, + callback_direct_send_result, + 
callback_store_and_forward_send_result, + callback_transaction_cancellation, + callback_utxo_validation_complete, + callback_stxo_validation_complete, + callback_invalid_txo_validation_complete, + callback_transaction_validation_complete, + callback_saf_message_received + ) { + let error = this.initError(); + let recovery_in_progress = this.initBool(); + + let result = this.#fn.wallet_create( + config, + log_path, + num_rolling_log_files, + size_per_log_file_bytes, + passphrase, + seed_words, + callback_received_transaction, + callback_received_transaction_reply, + callback_received_finalized_transaction, + callback_transaction_broadcast, + callback_transaction_mined, + callback_transaction_mined_unconfirmed, + callback_direct_send_result, + callback_store_and_forward_send_result, + callback_transaction_cancellation, + callback_utxo_validation_complete, + callback_stxo_validation_complete, + callback_invalid_txo_validation_complete, + callback_transaction_validation_complete, + callback_saf_message_received, + recovery_in_progress, + error + ); + this.checkErrorResult(error, `walletCreate`); + if (recovery_in_progress) { + console.log("Wallet recovery is in progress"); + } + return result; + } + + static walletGetPublicKey(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_get_public_key(ptr, error); + this.checkErrorResult(error, `walletGetPublicKey`); + return result; + } + + static walletSignMessage(ptr, msg) { + let error = this.initError(); + let result = this.#fn.wallet_sign_message(ptr, msg, error); + this.checkErrorResult(error, `walletSignMessage`); + return result; + } + + static walletVerifyMessageSignature(ptr, public_key_ptr, hex_sig_nonce, msg) { + let error = this.initError(); + let result = this.#fn.wallet_verify_message_signature( + ptr, + public_key_ptr, + hex_sig_nonce, + msg, + error + ); + this.checkErrorResult(error, `walletVerifyMessageSignature`); + return result; + } + + static walletAddBaseNodePeer(ptr, public_key_ptr, 
address) { + let error = this.initError(); + let result = this.#fn.wallet_add_base_node_peer( + ptr, + public_key_ptr, + address, + error + ); + this.checkErrorResult(error, `walletAddBaseNodePeer`); + return result; + } + + static walletUpsertContact(ptr, contact_ptr) { + let error = this.initError(); + let result = this.#fn.wallet_upsert_contact(ptr, contact_ptr, error); + this.checkErrorResult(error, `walletUpsertContact`); + return result; + } + + static walletRemoveContact(ptr, contact_ptr) { + let error = this.initError(); + let result = this.#fn.wallet_remove_contact(ptr, contact_ptr, error); + this.checkErrorResult(error, `walletRemoveContact`); + return result; + } + + static walletGetAvailableBalance(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_get_available_balance(ptr, error); + this.checkErrorResult(error, `walletGetAvailableBalance`); + return result; + } + + static walletGetPendingIncomingBalance(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_get_pending_incoming_balance(ptr, error); + this.checkErrorResult(error, `walletGetPendingIncomingBalance`); + return result; + } + + static walletGetPendingOutgoingBalance(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_get_pending_outgoing_balance(ptr, error); + this.checkErrorResult(error, `walletGetPendingOutgoingBalance`); + return result; + } + + static walletGetFeeEstimate( + ptr, + amount, + fee_per_gram, + num_kernels, + num_outputs + ) { + let error = this.initError(); + let result = this.#fn.wallet_get_fee_estimate( + ptr, + amount, + fee_per_gram, + num_kernels, + num_outputs, + error + ); + this.checkErrorResult(error, `walletGetFeeEstimate`); + return result; + } + + static walletGetNumConfirmationsRequired(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_get_num_confirmations_required(ptr, error); + this.checkErrorResult(error, `walletGetNumConfirmationsRequired`); + return result; + } + + static 
walletSetNumConfirmationsRequired(ptr, num) { + let error = this.initError(); + this.#fn.wallet_set_num_confirmations_required(ptr, num, error); + this.checkErrorResult(error, `walletSetNumConfirmationsRequired`); + } + + static walletSendTransaction( + ptr, + destination, + amount, + fee_per_gram, + message + ) { + let error = this.initError(); + let result = this.#fn.wallet_send_transaction( + ptr, + destination, + amount, + fee_per_gram, + message, + error + ); + this.checkErrorResult(error, `walletSendTransaction`); + return result; + } + + static walletGetContacts(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_get_contacts(ptr, error); + this.checkErrorResult(error, `walletGetContacts`); + return result; + } + + static walletGetCompletedTransactions(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_get_completed_transactions(ptr, error); + this.checkErrorResult(error, `walletGetCompletedTransactions`); + return result; + } + + static walletGetPendingOutboundTransactions(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_get_pending_outbound_transactions(ptr, error); + this.checkErrorResult(error, `walletGetPendingOutboundTransactions`); + return result; + } + + static walletGetPendingInboundTransactions(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_get_pending_inbound_transactions(ptr, error); + this.checkErrorResult(error, `walletGetPendingInboundTransactions`); + return result; + } + + static walletGetCancelledTransactions(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_get_cancelled_transactions(ptr, error); + this.checkErrorResult(error, `walletGetCancelledTransactions`); + return result; + } + + static walletGetCompletedTransactionById(ptr, transaction_id) { + let error = this.initError(); + let result = this.#fn.wallet_get_completed_transaction_by_id( + ptr, + transaction_id, + error + ); + this.checkErrorResult(error, 
`walletGetCompletedTransactionById`); + return result; + } + + static walletGetPendingOutboundTransactionById(ptr, transaction_id) { + let error = this.initError(); + let result = this.#fn.wallet_get_pending_outbound_transaction_by_id( + ptr, + transaction_id, + error + ); + this.checkErrorResult(error, `walletGetPendingOutboundTransactionById`); + return result; + } + + static walletGetPendingInboundTransactionById(ptr, transaction_id) { + let error = this.initError(); + let result = this.#fn.wallet_get_pending_inbound_transaction_by_id( + ptr, + transaction_id, + error + ); + this.checkErrorResult(error, `walletGetPendingInboundTransactionById`); + return result; + } + + static walletGetCancelledTransactionById(ptr, transaction_id) { + let error = this.initError(); + let result = this.#fn.wallet_get_cancelled_transaction_by_id( + ptr, + transaction_id, + error + ); + this.checkErrorResult(error, `walletGetCancelledTransactionById`); + return result; + } + + static walletImportUtxo( + ptr, + amount, + spending_key_ptr, + source_public_key_ptr, + message + ) { + let error = this.initError(); + let result = this.#fn.wallet_import_utxo( + ptr, + amount, + spending_key_ptr, + source_public_key_ptr, + message, + error + ); + this.checkErrorResult(error, `walletImportUtxo`); + return result; + } + + static walletStartUtxoValidation(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_start_utxo_validation(ptr, error); + this.checkErrorResult(error, `walletStartUtxoValidation`); + return result; + } + + static walletStartStxoValidation(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_start_stxo_validation(ptr, error); + this.checkErrorResult(error, `walletStartStxoValidation`); + return result; + } + + static walletStartInvalidTxoValidation(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_start_invalid_txo_validation(ptr, error); + this.checkErrorResult(error, `walletStartInvalidTxoValidation`); + return result; 
+ } + + static walletStartTransactionValidation(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_start_transaction_validation(ptr, error); + this.checkErrorResult(error, `walletStartTransactionValidation`); + return result; + } + + static walletRestartTransactionBroadcast(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_restart_transaction_broadcast(ptr, error); + this.checkErrorResult(error, `walletRestartTransactionBroadcast`); + return result; + } + + static walletSetLowPowerMode(ptr) { + let error = this.initError(); + this.#fn.wallet_set_low_power_mode(ptr, error); + this.checkErrorResult(error, `walletSetLowPowerMode`); + } + + static walletSetNormalPowerMode(ptr) { + let error = this.initError(); + this.#fn.wallet_set_normal_power_mode(ptr, error); + this.checkErrorResult(error, `walletSetNormalPowerMode`); + } + + static walletCancelPendingTransaction(ptr, transaction_id) { + let error = this.initError(); + let result = this.#fn.wallet_cancel_pending_transaction( + ptr, + transaction_id, + error + ); + this.checkErrorResult(error, `walletCancelPendingTransaction`); + return result; + } + + static walletCoinSplit(ptr, amount, count, fee, msg, lock_height) { + let error = this.initError(); + let result = this.#fn.wallet_coin_split( + ptr, + amount, + count, + fee, + msg, + lock_height, + error + ); + this.checkErrorResult(error, `walletCoinSplit`); + return result; + } + + static walletGetSeedWords(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_get_seed_words(ptr, error); + this.checkErrorResult(error, `walletGetSeedWords`); + return result; + } + + static walletApplyEncryption(ptr, passphrase) { + let error = this.initError(); + this.#fn.wallet_apply_encryption(ptr, passphrase, error); + this.checkErrorResult(error, `walletApplyEncryption`); + } + + static walletRemoveEncryption(ptr) { + let error = this.initError(); + this.#fn.wallet_remove_encryption(ptr, error); + 
this.checkErrorResult(error, `walletRemoveEncryption`); + } + + static walletSetKeyValue(ptr, key_ptr, value) { + let error = this.initError(); + let result = this.#fn.wallet_set_key_value(ptr, key_ptr, value, error); + this.checkErrorResult(error, `walletSetKeyValue`); + return result; + } + + static walletGetValue(ptr, key_ptr) { + let error = this.initError(); + let result = this.#fn.wallet_get_value(ptr, key_ptr, error); + this.checkErrorResult(error, `walletGetValue`); + return result; + } + + static walletClearValue(ptr, key_ptr) { + let error = this.initError(); + let result = this.#fn.wallet_clear_value(ptr, key_ptr, error); + this.checkErrorResult(error, `walletClearValue`); + return result; + } + + static walletIsRecoveryInProgress(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_is_recovery_in_progress(ptr, error); + this.checkErrorResult(error, `walletIsRecoveryInProgress`); + return result; + } + + static walletStartRecovery( + ptr, + base_node_public_key_ptr, + recovery_progress_callback + ) { + let error = this.initError(); + let result = this.#fn.wallet_start_recovery( + ptr, + base_node_public_key_ptr, + recovery_progress_callback, + error + ); + this.checkErrorResult(error, `walletStartRecovery`); + return result; + } + + static walletDestroy(ptr) { + this.#fn.wallet_destroy(ptr); + } + //endregion +} +module.exports = InterfaceFFI; diff --git a/integration_tests/helpers/ffi/pendingInboundTransaction.js b/integration_tests/helpers/ffi/pendingInboundTransaction.js index 32cae202fb..dc2071e24b 100644 --- a/integration_tests/helpers/ffi/pendingInboundTransaction.js +++ b/integration_tests/helpers/ffi/pendingInboundTransaction.js @@ -1,24 +1,70 @@ -const WalletFFI = require("./walletFFI"); +const InterfaceFFI = require("./ffiInterface"); +const PublicKey = require("./publicKey"); class PendingInboundTransaction { #tari_pending_inbound_transaction_ptr; - constructor(tari_pending_inbound_transaction_ptr) { - 
this.#tari_pending_inbound_transaction_ptr = - tari_pending_inbound_transaction_ptr; + pointerAssign(ptr) { + if (this.#tari_pending_inbound_transaction_ptr) { + this.destroy(); + this.#tari_pending_inbound_transaction_ptr = ptr; + } else { + this.#tari_pending_inbound_transaction_ptr = ptr; + } + } + + getPtr() { + return this.#tari_pending_inbound_transaction_ptr; + } + + getSourcePublicKey() { + let result = new PublicKey(); + result.pointerAssign( + InterfaceFFI.pendingInboundTransactionGetSourcePublicKey( + this.#tari_pending_inbound_transaction_ptr + ) + ); + return result; + } + + getAmount() { + return InterfaceFFI.pendingInboundTransactionGetAmount( + this.#tari_pending_inbound_transaction_ptr + ); + } + + getMessage() { + return InterfaceFFI.pendingInboundTransactionGetMessage( + this.#tari_pending_inbound_transaction_ptr + ); } getStatus() { - return WalletFFI.pendingInboundTransactionGetStatus( + return InterfaceFFI.pendingInboundTransactionGetStatus( this.#tari_pending_inbound_transaction_ptr ); } - destroy() { - return WalletFFI.pendingInboundTransactionDestroy( + getTransactionID() { + return InterfaceFFI.pendingInboundTransactionGetTransactionId( + this.#tari_pending_inbound_transaction_ptr + ); + } + + getTimestamp() { + return InterfaceFFI.pendingInboundTransactionGetTimestamp( this.#tari_pending_inbound_transaction_ptr ); } + + destroy() { + if (this.#tari_pending_inbound_transaction_ptr) { + InterfaceFFI.pendingInboundTransactionDestroy( + this.#tari_pending_inbound_transaction_ptr + ); + this.#tari_pending_inbound_transaction_ptr = undefined; //prevent double free segfault + } + } } module.exports = PendingInboundTransaction; diff --git a/integration_tests/helpers/ffi/pendingInboundTransactions.js b/integration_tests/helpers/ffi/pendingInboundTransactions.js index 6246b03429..2500b41a04 100644 --- a/integration_tests/helpers/ffi/pendingInboundTransactions.js +++ b/integration_tests/helpers/ffi/pendingInboundTransactions.js @@ -1,39 +1,37 @@ 
const PendingInboundTransaction = require("./pendingInboundTransaction"); -const WalletFFI = require("./walletFFI"); +const InterfaceFFI = require("./ffiInterface"); class PendingInboundTransactions { #tari_pending_inbound_transactions_ptr; - constructor(tari_pending_inbound_transactions_ptr) { - this.#tari_pending_inbound_transactions_ptr = - tari_pending_inbound_transactions_ptr; - } - - static async fromWallet(wallet) { - return new PendingInboundTransactions( - await WalletFFI.walletGetPendingInboundTransactions(wallet) - ); + constructor(ptr) { + this.#tari_pending_inbound_transactions_ptr = ptr; } getLength() { - return WalletFFI.pendingInboundTransactionsGetLength( + return InterfaceFFI.pendingInboundTransactionsGetLength( this.#tari_pending_inbound_transactions_ptr ); } - async getAt(position) { - return new PendingInboundTransaction( - await WalletFFI.pendingInboundTransactionsGetAt( + getAt(position) { + let result = new PendingInboundTransaction(); + result.pointerAssign( + InterfaceFFI.pendingInboundTransactionsGetAt( this.#tari_pending_inbound_transactions_ptr, position ) ); + return result; } destroy() { - return WalletFFI.pendingInboundTransactionsDestroy( - this.#tari_pending_inbound_transactions_ptr - ); + if (this.#tari_pending_inbound_transactions_ptr) { + InterfaceFFI.pendingInboundTransactionsDestroy( + this.#tari_pending_inbound_transactions_ptr + ); + this.#tari_pending_inbound_transactions_ptr = undefined; //prevent double free segfault + } } } diff --git a/integration_tests/helpers/ffi/pendingOutboundTransaction.js b/integration_tests/helpers/ffi/pendingOutboundTransaction.js index eed2d722bb..0fc2ca47b9 100644 --- a/integration_tests/helpers/ffi/pendingOutboundTransaction.js +++ b/integration_tests/helpers/ffi/pendingOutboundTransaction.js @@ -1,30 +1,76 @@ -const WalletFFI = require("./walletFFI"); +const InterfaceFFI = require("./ffiInterface"); +const PublicKey = require("./publicKey"); class PendingOutboundTransaction { 
#tari_pending_outbound_transaction_ptr; - constructor(tari_pending_outbound_transaction_ptr) { - this.#tari_pending_outbound_transaction_ptr = - tari_pending_outbound_transaction_ptr; + pointerAssign(ptr) { + // Destroy the old pointer before re-assignment so it is not leaked + if (this.#tari_pending_outbound_transaction_ptr) { + this.destroy(); + this.#tari_pending_outbound_transaction_ptr = ptr; + } else { + this.#tari_pending_outbound_transaction_ptr = ptr; + } } - getTransactionId() { - return WalletFFI.pendingOutboundTransactionGetTransactionId( + getPtr() { + return this.#tari_pending_outbound_transaction_ptr; + } + + getDestinationPublicKey() { + let result = new PublicKey(); + result.pointerAssign( + InterfaceFFI.pendingOutboundTransactionGetDestinationPublicKey( + this.#tari_pending_outbound_transaction_ptr + ) + ); + return result; + } + + getAmount() { + return InterfaceFFI.pendingOutboundTransactionGetAmount( + this.#tari_pending_outbound_transaction_ptr + ); + } + + getFee() { + return InterfaceFFI.pendingOutboundTransactionGetFee( + this.#tari_pending_outbound_transaction_ptr + ); + } + + getMessage() { + return InterfaceFFI.pendingOutboundTransactionGetMessage( this.#tari_pending_outbound_transaction_ptr ); } getStatus() { - return WalletFFI.pendingOutboundTransactionGetStatus( + return InterfaceFFI.pendingOutboundTransactionGetStatus( this.#tari_pending_outbound_transaction_ptr ); } - destroy() { - return WalletFFI.pendingOutboundTransactionDestroy( + getTransactionID() { + return InterfaceFFI.pendingOutboundTransactionGetTransactionId( + this.#tari_pending_outbound_transaction_ptr + ); + } + + getTimestamp() { + return InterfaceFFI.pendingOutboundTransactionGetTimestamp( this.#tari_pending_outbound_transaction_ptr ); } + + destroy() { + if (this.#tari_pending_outbound_transaction_ptr) { + InterfaceFFI.pendingOutboundTransactionDestroy( + this.#tari_pending_outbound_transaction_ptr + ); + this.#tari_pending_outbound_transaction_ptr = undefined; //prevent double free segfault + } + } } module.exports = 
PendingOutboundTransaction; diff --git a/integration_tests/helpers/ffi/pendingOutboundTransactions.js b/integration_tests/helpers/ffi/pendingOutboundTransactions.js index 28e408563d..45de06033b 100644 --- a/integration_tests/helpers/ffi/pendingOutboundTransactions.js +++ b/integration_tests/helpers/ffi/pendingOutboundTransactions.js @@ -1,39 +1,37 @@ const PendingOutboundTransaction = require("./pendingOutboundTransaction"); -const WalletFFI = require("./walletFFI"); +const InterfaceFFI = require("./ffiInterface"); class PendingOutboundTransactions { #tari_pending_outbound_transactions_ptr; - constructor(tari_pending_outbound_transactions_ptr) { - this.#tari_pending_outbound_transactions_ptr = - tari_pending_outbound_transactions_ptr; - } - - static async fromWallet(wallet) { - return new PendingOutboundTransactions( - await WalletFFI.walletGetPendingOutboundTransactions(wallet) - ); + constructor(ptr) { + this.#tari_pending_outbound_transactions_ptr = ptr; } getLength() { - return WalletFFI.pendingOutboundTransactionsGetLength( + return InterfaceFFI.pendingOutboundTransactionsGetLength( this.#tari_pending_outbound_transactions_ptr ); } - async getAt(position) { - return new PendingOutboundTransaction( - await WalletFFI.pendingOutboundTransactionsGetAt( + getAt(position) { + let result = new PendingOutboundTransaction(); + result.pointerAssign( + InterfaceFFI.pendingOutboundTransactionsGetAt( this.#tari_pending_outbound_transactions_ptr, position ) ); + return result; } destroy() { - return WalletFFI.pendingOutboundTransactionsDestroy( - this.#tari_pending_outbound_transactions_ptr - ); + if (this.#tari_pending_outbound_transactions_ptr) { + InterfaceFFI.pendingOutboundTransactionsDestroy( + this.#tari_pending_outbound_transactions_ptr + ); + this.#tari_pending_outbound_transactions_ptr = undefined; //prevent double free segfault + } } } diff --git a/integration_tests/helpers/ffi/privateKey.js b/integration_tests/helpers/ffi/privateKey.js new file mode 100644 index 
0000000000..7115ab8a1d --- /dev/null +++ b/integration_tests/helpers/ffi/privateKey.js @@ -0,0 +1,67 @@ +const InterfaceFFI = require("./ffiInterface"); +const ByteVector = require("./byteVector"); +const utf8 = require("utf8"); + +class PrivateKey { + #tari_private_key_ptr; + + pointerAssign(ptr) { + // Prevent pointer from being leaked in case of re-assignment + if (this.#tari_private_key_ptr) { + this.destroy(); + this.#tari_private_key_ptr = ptr; + } else { + this.#tari_private_key_ptr = ptr; + } + } + + generate() { + this.#tari_private_key_ptr = InterfaceFFI.privateKeyGenerate(); + } + + fromHexString(hex) { + let sanitize = utf8.encode(hex); // Make sure it's not UTF-16 encoded (JS default) + let result = new PrivateKey(); + result.pointerAssign(InterfaceFFI.privateKeyFromHex(sanitize)); + return result; + } + + fromByteVector(byte_vector) { + let result = new PrivateKey(); + result.pointerAssign(InterfaceFFI.privateKeyCreate(byte_vector.getPtr())); + return result; + } + + getPtr() { + return this.#tari_private_key_ptr; + } + + getBytes() { + let result = new ByteVector(); + result.pointerAssign( + InterfaceFFI.privateKeyGetBytes(this.#tari_private_key_ptr) + ); + return result; + } + + getHex() { + const bytes = this.getBytes(); + const length = bytes.getLength(); + let byte_array = new Uint8Array(length); + for (let i = 0; i < length; i++) { + byte_array[i] = bytes.getAt(i); + } + bytes.destroy(); + let buffer = Buffer.from(byte_array, 0); + return buffer.toString("hex"); + } + + destroy() { + if (this.#tari_private_key_ptr) { + InterfaceFFI.privateKeyDestroy(this.#tari_private_key_ptr); + this.#tari_private_key_ptr = undefined; //prevent double free segfault + } + } +} + +module.exports = PrivateKey; diff --git a/integration_tests/helpers/ffi/publicKey.js b/integration_tests/helpers/ffi/publicKey.js index 1165aa193d..7e1476c3d5 100644 --- a/integration_tests/helpers/ffi/publicKey.js +++ b/integration_tests/helpers/ffi/publicKey.js @@ -1,64 +1,82 @@ 
-const WalletFFI = require("./walletFFI"); +const InterfaceFFI = require("./ffiInterface"); const ByteVector = require("./byteVector"); const utf8 = require("utf8"); class PublicKey { #tari_public_key_ptr; - constructor(public_key) { - this.#tari_public_key_ptr = public_key; + pointerAssign(ptr) { + // Prevent pointer from being leaked in case of re-assignment + if (this.#tari_public_key_ptr) { + this.destroy(); + this.#tari_public_key_ptr = ptr; + } else { + this.#tari_public_key_ptr = ptr; + } } - static fromPubkey(public_key) { - return new PublicKey(public_key); + static fromPrivateKey(key) { + let result = new PublicKey(); + result.pointerAssign(InterfaceFFI.publicKeyFromPrivateKey(key.getPtr())); + return result; } - static async fromWallet(wallet) { - return new PublicKey(await WalletFFI.walletGetPublicKey(wallet)); + static fromHexString(hex) { + let sanitize = utf8.encode(hex); // Make sure it's not UTF-16 encoded (JS default) + let result = new PublicKey(); + result.pointerAssign(InterfaceFFI.publicKeyFromHex(sanitize)); + return result; } - static async fromString(public_key_hex) { - let sanitize = utf8.encode(public_key_hex); // Make sure it's not UTF-16 encoded (JS default) - return new PublicKey(await WalletFFI.publicKeyFromHex(sanitize)); + static fromEmojiID(emoji) { + let sanitize = utf8.encode(emoji); // Make sure it's not UTF-16 encoded (JS default) + let result = new PublicKey(); + result.pointerAssign(InterfaceFFI.emojiIdToPublicKey(sanitize)); + return result; } - static async fromBytes(bytes) { - return new PublicKey(await WalletFFI.publicKeyCreate(bytes)); + static fromByteVector(byte_vector) { + let result = new PublicKey(); + result.pointerAssign(InterfaceFFI.publicKeyCreate(byte_vector.getPtr())); + return result; } getPtr() { return this.#tari_public_key_ptr; } - async getBytes() { - return new ByteVector( - await WalletFFI.publicKeyGetBytes(this.#tari_public_key_ptr) + getBytes() { + let result = new ByteVector(); + result.pointerAssign( + 
InterfaceFFI.publicKeyGetBytes(this.#tari_public_key_ptr) ); + return result; } - async getHex() { - const bytes = await this.getBytes(); - const length = await bytes.getLength(); + getHex() { + const bytes = this.getBytes(); + const length = bytes.getLength(); let byte_array = new Uint8Array(length); - for (let i = 0; i < length; ++i) { - byte_array[i] = await bytes.getAt(i); + for (let i = 0; i < length; i++) { + byte_array[i] = bytes.getAt(i); } - await bytes.destroy(); + bytes.destroy(); let buffer = Buffer.from(byte_array, 0); return buffer.toString("hex"); } - async getEmojiId() { - const emoji_id = await WalletFFI.publicKeyToEmojiId( - this.#tari_public_key_ptr - ); + getEmojiId() { + const emoji_id = InterfaceFFI.publicKeyToEmojiId(this.#tari_public_key_ptr); const result = emoji_id.readCString(); - await WalletFFI.stringDestroy(emoji_id); + InterfaceFFI.stringDestroy(emoji_id); return result; } destroy() { - return WalletFFI.publicKeyDestroy(this.#tari_public_key_ptr); + if (this.#tari_public_key_ptr) { + InterfaceFFI.publicKeyDestroy(this.#tari_public_key_ptr); + this.#tari_public_key_ptr = undefined; //prevent double free segfault + } } } diff --git a/integration_tests/helpers/ffi/seedWords.js b/integration_tests/helpers/ffi/seedWords.js index 86c05cab48..e191bc38a9 100644 --- a/integration_tests/helpers/ffi/seedWords.js +++ b/integration_tests/helpers/ffi/seedWords.js @@ -1,45 +1,55 @@ -const WalletFFI = require("./walletFFI"); +const InterfaceFFI = require("./ffiInterface"); +const utf8 = require("utf8"); class SeedWords { #tari_seed_words_ptr; - constructor(tari_seed_words_ptr) { - this.#tari_seed_words_ptr = tari_seed_words_ptr; + pointerAssign(ptr) { + // Prevent pointer from being leaked in case of re-assignment + if (this.#tari_seed_words_ptr) { + this.destroy(); + this.#tari_seed_words_ptr = ptr; + } else { + this.#tari_seed_words_ptr = ptr; + } } - static async fromString(seed_words_text) { - const seed_words = await WalletFFI.seedWordsCreate(); 
+ static fromText(seed_words_text) { + const seed_words = new SeedWords(); + seed_words.pointerAssign(InterfaceFFI.seedWordsCreate()); const seed_words_list = seed_words_text.split(" "); for (const seed_word of seed_words_list) { - await WalletFFI.seedWordsPushWord(seed_words, seed_word); + InterfaceFFI.seedWordsPushWord( + seed_words.getPtr(), + utf8.encode(seed_word) + ); } - return new SeedWords(seed_words); - } - - static async fromWallet(wallet) { - return new SeedWords(await WalletFFI.walletGetSeedWords(wallet)); + return seed_words; } getLength() { - return WalletFFI.seedWordsGetLength(this.#tari_seed_words_ptr); + return InterfaceFFI.seedWordsGetLength(this.#tari_seed_words_ptr); } getPtr() { return this.#tari_seed_words_ptr; } - async getAt(position) { - const seed_word = await WalletFFI.seedWordsGetAt( + getAt(position) { + const seed_word = InterfaceFFI.seedWordsGetAt( this.#tari_seed_words_ptr, position ); const result = seed_word.readCString(); - await WalletFFI.stringDestroy(seed_word); + InterfaceFFI.stringDestroy(seed_word); return result; } destroy() { - return WalletFFI.seedWordsDestroy(this.#tari_seed_words_ptr); + if (this.#tari_seed_words_ptr) { + InterfaceFFI.seedWordsDestroy(this.#tari_seed_words_ptr); + this.#tari_seed_words_ptr = undefined; //prevent double free segfault + } } } diff --git a/integration_tests/helpers/ffi/transportType.js b/integration_tests/helpers/ffi/transportType.js new file mode 100644 index 0000000000..0826c423b5 --- /dev/null +++ b/integration_tests/helpers/ffi/transportType.js @@ -0,0 +1,85 @@ +const InterfaceFFI = require("./ffiInterface"); +const utf8 = require("utf8"); + +class TransportType { + #tari_transport_type_ptr; + #type = "None"; + + pointerAssign(ptr, type) { + // Prevent pointer from being leaked in case of re-assignment + if (this.#tari_transport_type_ptr) { + this.destroy(); + this.#tari_transport_type_ptr = ptr; + this.#type = type; + } else { + this.#tari_transport_type_ptr = ptr; + this.#type = 
type; + } + } + + getPtr() { + return this.#tari_transport_type_ptr; + } + + getType() { + return this.#type; + } + + static createMemory() { + let result = new TransportType(); + result.pointerAssign(InterfaceFFI.transportMemoryCreate(), "Memory"); + return result; + } + + static createTCP(listener_address) { + let sanitize = utf8.encode(listener_address); // Make sure it's not UTF-16 encoded (JS default) + let result = new TransportType(); + result.pointerAssign(InterfaceFFI.transportTcpCreate(sanitize), "TCP"); + return result; + } + + static createTor( + control_server_address, + tor_cookie, + tor_port, + socks_username, + socks_password + ) { + let sanitize_address = utf8.encode(control_server_address); + let sanitize_username = utf8.encode(socks_username); + let sanitize_password = utf8.encode(socks_password); + let result = new TransportType(); + result.pointerAssign( + InterfaceFFI.transportTorCreate( + sanitize_address, + tor_cookie.getPtr(), + tor_port, + sanitize_username, + sanitize_password + ), + "Tor" + ); + return result; + } + + getAddress() { + if (this.#type === "Memory") { + let c_address = InterfaceFFI.transportMemoryGetAddress(this.getPtr()); + let result = c_address.readCString(); + InterfaceFFI.stringDestroy(c_address); + return result; + } else { + return "N/A"; + } + } + + destroy() { + this.#type = "None"; + if (this.#tari_transport_type_ptr) { + InterfaceFFI.transportTypeDestroy(this.#tari_transport_type_ptr); + this.#tari_transport_type_ptr = undefined; //prevent double free segfault + } + } +} + +module.exports = TransportType; diff --git a/integration_tests/helpers/ffi/wallet.js b/integration_tests/helpers/ffi/wallet.js new file mode 100644 index 0000000000..fea21fe682 --- /dev/null +++ b/integration_tests/helpers/ffi/wallet.js @@ -0,0 +1,449 @@ +const InterfaceFFI = require("./ffiInterface"); +const PublicKey = require("./publicKey"); +const CompletedTransaction = require("./completedTransaction"); +const CompletedTransactions = 
require("./completedTransactions"); +const PendingInboundTransaction = require("./pendingInboundTransaction"); +const PendingInboundTransactions = require("./pendingInboundTransactions"); +const PendingOutboundTransactions = require("./pendingOutboundTransactions"); +const Contact = require("./contact"); +const Contacts = require("./contacts"); + +const utf8 = require("utf8"); + +class Wallet { + #wallet_ptr; + #log_path = ""; + receivedTransaction = 0; + receivedTransactionReply = 0; + transactionBroadcast = 0; + transactionMined = 0; + saf_messages = 0; + + utxo_validation_complete = false; + utxo_validation_result = 0; + stxo_validation_complete = false; + stxo_validation_result = 0; + + getUtxoValidationStatus() { + return { + utxo_validation_complete: this.utxo_validation_complete, + utxo_validation_result: this.utxo_validation_result, + }; + } + + getStxoValidationStatus() { + return { + stxo_validation_complete: this.stxo_validation_complete, + stxo_validation_result: this.stxo_validation_result, + }; + } + + clearCallbackCounters() { + this.receivedTransaction = + this.receivedTransactionReply = + this.transactionBroadcast = + this.transactionMined = + this.saf_messages = + this.cancelled = + this.minedunconfirmed = + this.finalized = + 0; + } + + getCounters() { + return { + received: this.receivedTransaction, + replyreceived: this.receivedTransactionReply, + broadcast: this.transactionBroadcast, + finalized: this.finalized, + minedunconfirmed: this.minedunconfirmed, + cancelled: this.cancelled, + mined: this.transactionMined, + saf: this.saf_messages, + }; + } + + constructor( + comms_config_ptr, + log_path, + passphrase, + seed_words_ptr, + num_rolling_log_file = 50, + log_size_bytes = 102400 + ) { + this.receivedTransaction = 0; + this.receivedTransactionReply = 0; + this.transactionBroadcast = 0; + this.transactionMined = 0; + this.saf_messages = 0; + this.cancelled = 0; + this.minedunconfirmed = 0; + this.finalized = 0; + this.recoveryFinished = true; 
+ let sanitize = null; + let words = null; + if (passphrase) { + sanitize = utf8.encode(passphrase); + } + if (seed_words_ptr) { + words = seed_words_ptr; + } + this.#log_path = log_path; + this.#wallet_ptr = InterfaceFFI.walletCreate( + comms_config_ptr, + utf8.encode(this.#log_path), //`${this.baseDir}/log/wallet.log`, + num_rolling_log_file, + log_size_bytes, + sanitize, + words, + this.#callback_received_transaction, + this.#callback_received_transaction_reply, + this.#callback_received_finalized_transaction, + this.#callback_transaction_broadcast, + this.#callback_transaction_mined, + this.#callback_transaction_mined_unconfirmed, + this.#callback_direct_send_result, + this.#callback_store_and_forward_send_result, + this.#callback_transaction_cancellation, + this.#callback_utxo_validation_complete, + this.#callback_stxo_validation_complete, + this.#callback_invalid_txo_validation_complete, + this.#callback_transaction_validation_complete, + this.#callback_saf_message_received + ); + } + + //region Callbacks + #onReceivedTransaction = (ptr) => { + // Arrow function keeps `this` bound to the outer class instance; a plain callback would get `this` as null + let tx = new PendingInboundTransaction(); + tx.pointerAssign(ptr); + console.log( + `${new Date().toISOString()} received Transaction with txID ${tx.getTransactionID()}` + ); + tx.destroy(); + this.receivedTransaction += 1; + }; + + #onReceivedTransactionReply = (ptr) => { + let tx = new CompletedTransaction(); + tx.pointerAssign(ptr); + console.log( + `${new Date().toISOString()} received reply for Transaction with txID ${tx.getTransactionID()}.` + ); + tx.destroy(); + this.receivedTransactionReply += 1; + }; + + #onReceivedFinalizedTransaction = (ptr) => { + let tx = new CompletedTransaction(); + tx.pointerAssign(ptr); + console.log( + `${new Date().toISOString()} received finalization for Transaction with txID ${tx.getTransactionID()}.` + ); + tx.destroy(); + this.finalized += 1; + }; + + #onTransactionBroadcast = (ptr) => { + let tx = new 
CompletedTransaction(); + tx.pointerAssign(ptr); + console.log( + `${new Date().toISOString()} Transaction with txID ${tx.getTransactionID()} was broadcast.` + ); + tx.destroy(); + this.transactionBroadcast += 1; + }; + + #onTransactionMined = (ptr) => { + let tx = new CompletedTransaction(); + tx.pointerAssign(ptr); + console.log( + `${new Date().toISOString()} Transaction with txID ${tx.getTransactionID()} was mined.` + ); + tx.destroy(); + this.transactionMined += 1; + }; + + #onTransactionMinedUnconfirmed = (ptr, confirmations) => { + let tx = new CompletedTransaction(); + tx.pointerAssign(ptr); + console.log( + `${new Date().toISOString()} Transaction with txID ${tx.getTransactionID()} is mined unconfirmed with ${confirmations} confirmations.` + ); + tx.destroy(); + this.minedunconfirmed += 1; + }; + + #onTransactionCancellation = (ptr) => { + let tx = new CompletedTransaction(); + tx.pointerAssign(ptr); + console.log( + `${new Date().toISOString()} Transaction with txID ${tx.getTransactionID()} was cancelled` + ); + tx.destroy(); + this.cancelled += 1; + }; + + #onDirectSendResult = (id, success) => { + console.log( + `${new Date().toISOString()} callbackDirectSendResult(${id},${success})` + ); + }; + + #onStoreAndForwardSendResult = (id, success) => { + console.log( + `${new Date().toISOString()} callbackStoreAndForwardSendResult(${id},${success})` + ); + }; + + #onUtxoValidationComplete = (request_key, validation_results) => { + console.log( + `${new Date().toISOString()} callbackUtxoValidationComplete(${request_key},${validation_results})` + ); + this.utxo_validation_complete = true; + this.utxo_validation_result = validation_results; + }; + + #onStxoValidationComplete = (request_key, validation_results) => { + console.log( + `${new Date().toISOString()} callbackStxoValidationComplete(${request_key},${validation_results})` + ); + this.stxo_validation_complete = true; + this.stxo_validation_result = validation_results; + }; + + 
#onInvalidTxoValidationComplete = (request_key, validation_results) => { + console.log( + `${new Date().toISOString()} callbackInvalidTxoValidationComplete(${request_key},${validation_results})` + ); + //this.invalidtxo_validation_complete = true; + //this.invalidtxo_validation_result = validation_results; + }; + + #onTransactionValidationComplete = (request_key, validation_results) => { + console.log( + `${new Date().toISOString()} callbackTransactionValidationComplete(${request_key},${validation_results})` + ); + //this.transaction_validation_complete = true; + //this.transaction_validation_result = validation_results; + }; + + #onSafMessageReceived = () => { + console.log(`${new Date().toISOString()} callbackSafMessageReceived()`); + this.saf_messages += 1; + }; + + #onRecoveryProgress = (a, b, c) => { + console.log( + `${new Date().toISOString()} recoveryProgressCallback(${a},${b},${c})` + ); + if (a === 4) { + console.log(`Recovery completed, funds recovered: ${c} uT`); + } + }; + + #callback_received_transaction = + InterfaceFFI.createCallbackReceivedTransaction(this.#onReceivedTransaction); + #callback_received_transaction_reply = + InterfaceFFI.createCallbackReceivedTransactionReply( + this.#onReceivedTransactionReply + ); + #callback_received_finalized_transaction = + InterfaceFFI.createCallbackReceivedFinalizedTransaction( + this.#onReceivedFinalizedTransaction + ); + #callback_transaction_broadcast = + InterfaceFFI.createCallbackTransactionBroadcast( + this.#onTransactionBroadcast + ); + #callback_transaction_mined = InterfaceFFI.createCallbackTransactionMined( + this.#onTransactionMined + ); + #callback_transaction_mined_unconfirmed = + InterfaceFFI.createCallbackTransactionMinedUnconfirmed( + this.#onTransactionMinedUnconfirmed + ); + #callback_direct_send_result = InterfaceFFI.createCallbackDirectSendResult( + this.#onDirectSendResult + ); + #callback_store_and_forward_send_result = + InterfaceFFI.createCallbackStoreAndForwardSendResult( + 
this.#onStoreAndForwardSendResult + ); + #callback_transaction_cancellation = + InterfaceFFI.createCallbackTransactionCancellation( + this.#onTransactionCancellation + ); + #callback_utxo_validation_complete = + InterfaceFFI.createCallbackUtxoValidationComplete( + this.#onUtxoValidationComplete + ); + #callback_stxo_validation_complete = + InterfaceFFI.createCallbackStxoValidationComplete( + this.#onStxoValidationComplete + ); + #callback_invalid_txo_validation_complete = + InterfaceFFI.createCallbackInvalidTxoValidationComplete( + this.#onInvalidTxoValidationComplete + ); + #callback_transaction_validation_complete = + InterfaceFFI.createCallbackTransactionValidationComplete( + this.#onTransactionValidationComplete + ); + #callback_saf_message_received = + InterfaceFFI.createCallbackSafMessageReceived(this.#onSafMessageReceived); + #recoveryProgressCallback = InterfaceFFI.createRecoveryProgressCallback( + this.#onRecoveryProgress + ); + //endregion + + startRecovery(base_node_pubkey) { + let node_pubkey = PublicKey.fromHexString(utf8.encode(base_node_pubkey)); + InterfaceFFI.walletStartRecovery( + this.#wallet_ptr, + node_pubkey.getPtr(), + this.#recoveryProgressCallback + ); + node_pubkey.destroy(); + } + + recoveryInProgress() { + return InterfaceFFI.walletIsRecoveryInProgress(this.#wallet_ptr); + } + + getPublicKey() { + let ptr = InterfaceFFI.walletGetPublicKey(this.#wallet_ptr); + let pk = new PublicKey(); + pk.pointerAssign(ptr); + let result = pk.getHex(); + pk.destroy(); + return result; + } + + getEmojiId() { + let ptr = InterfaceFFI.walletGetPublicKey(this.#wallet_ptr); + let pk = new PublicKey(); + pk.pointerAssign(ptr); + let result = pk.getEmojiId(); + pk.destroy(); + return result; + } + + getBalance() { + let available = InterfaceFFI.walletGetAvailableBalance(this.#wallet_ptr); + let pendingIncoming = InterfaceFFI.walletGetPendingIncomingBalance( + this.#wallet_ptr + ); + let pendingOutgoing = InterfaceFFI.walletGetPendingOutgoingBalance( + 
this.#wallet_ptr + ); + return { + pendingIn: pendingIncoming, + pendingOut: pendingOutgoing, + available: available, + }; + } + + addBaseNodePeer(public_key_hex, address) { + let public_key = PublicKey.fromHexString(utf8.encode(public_key_hex)); + let result = InterfaceFFI.walletAddBaseNodePeer( + this.#wallet_ptr, + public_key.getPtr(), + utf8.encode(address) + ); + public_key.destroy(); + return result; + } + + sendTransaction(destination, amount, fee_per_gram, message) { + let dest_public_key = PublicKey.fromHexString(utf8.encode(destination)); + let result = InterfaceFFI.walletSendTransaction( + this.#wallet_ptr, + dest_public_key.getPtr(), + amount, + fee_per_gram, + utf8.encode(message) + ); + dest_public_key.destroy(); + return result; + } + + applyEncryption(passphrase) { + InterfaceFFI.walletApplyEncryption( + this.#wallet_ptr, + utf8.encode(passphrase) + ); + } + + getCompletedTransactions() { + let list_ptr = InterfaceFFI.walletGetCompletedTransactions( + this.#wallet_ptr + ); + return new CompletedTransactions(list_ptr); + } + + getInboundTransactions() { + let list_ptr = InterfaceFFI.walletGetPendingInboundTransactions( + this.#wallet_ptr + ); + return new PendingInboundTransactions(list_ptr); + } + + getOutboundTransactions() { + let list_ptr = InterfaceFFI.walletGetPendingOutboundTransactions( + this.#wallet_ptr + ); + return new PendingOutboundTransactions(list_ptr); + } + + getContacts() { + let list_ptr = InterfaceFFI.walletGetContacts(this.#wallet_ptr); + return new Contacts(list_ptr); + } + + addContact(alias, pubkey_hex) { + let public_key = PublicKey.fromHexString(utf8.encode(pubkey_hex)); + let contact = new Contact(); + contact.pointerAssign( + InterfaceFFI.contactCreate(utf8.encode(alias), public_key.getPtr()) + ); + let result = InterfaceFFI.walletUpsertContact( + this.#wallet_ptr, + contact.getPtr() + ); + contact.destroy(); + public_key.destroy(); + return result; + } + + removeContact(contact) { + let result = 
InterfaceFFI.walletRemoveContact( + this.#wallet_ptr, + contact.getPtr() + ); + contact.destroy(); + return result; + } + + cancelPendingTransaction(tx_id) { + return InterfaceFFI.walletCancelPendingTransaction(this.#wallet_ptr, tx_id); + } + + startUtxoValidation() { + return InterfaceFFI.walletStartUtxoValidation(this.#wallet_ptr); + } + + startStxoValidation() { + return InterfaceFFI.walletStartStxoValidation(this.#wallet_ptr); + } + + destroy() { + if (this.#wallet_ptr) { + InterfaceFFI.walletDestroy(this.#wallet_ptr); + this.#wallet_ptr = undefined; //prevent double free segfault + } + } +} + +module.exports = Wallet; diff --git a/integration_tests/helpers/walletFFIClient.js b/integration_tests/helpers/walletFFIClient.js index 60596c8e34..8835f03dee 100644 --- a/integration_tests/helpers/walletFFIClient.js +++ b/integration_tests/helpers/walletFFIClient.js @@ -1,440 +1,187 @@ -const WalletFFI = require("./ffi/walletFFI"); +const SeedWords = require("./ffi/seedWords"); +const TransportType = require("./ffi/transportType"); +const CommsConfig = require("./ffi/commsConfig"); +const Wallet = require("./ffi/wallet"); const { getFreePort } = require("./util"); const dateFormat = require("dateformat"); -const { expect } = require("chai"); -const PublicKey = require("./ffi/publicKey"); -const CompletedTransactions = require("./ffi/completedTransactions"); -const PendingOutboundTransactions = require("./ffi/pendingOutboundTransactions"); -const Contact = require("./ffi/contact"); -const Contacts = require("./ffi/contacts"); -const SeedWords = require("./ffi/seedWords"); +const InterfaceFFI = require("./ffi/ffiInterface"); class WalletFFIClient { #name; #wallet; #comms_config; + #transport; + #seed_words; + #pass_phrase; #port; - #callback_received_transaction; - #callback_received_transaction_reply; - #callback_received_finalized_transaction; - #callback_transaction_broadcast; - #callback_transaction_mined; - #callback_transaction_mined_unconfirmed; - 
#callback_direct_send_result; - #callback_store_and_forward_send_result; - #callback_transaction_cancellation; - #callback_utxo_validation_complete; - #callback_stxo_validation_complete; - #callback_invalid_txo_validation_complete; - #callback_transaction_validation_complete; - #callback_saf_message_received; - #recovery_progress_callback; - - #callbackReceivedTransaction = (..._args) => { - console.log(`${new Date().toISOString()} callbackReceivedTransaction`); - this.receivedTransaction += 1; - }; - #callbackReceivedTransactionReply = (..._args) => { - console.log(`${new Date().toISOString()} callbackReceivedTransactionReply`); - this.receivedTransactionReply += 1; - }; - #callbackReceivedFinalizedTransaction = (..._args) => { - console.log( - `${new Date().toISOString()} callbackReceivedFinalizedTransaction` - ); - }; - #callbackTransactionBroadcast = (..._args) => { - console.log(`${new Date().toISOString()} callbackTransactionBroadcast`); - this.transactionBroadcast += 1; - }; - #callbackTransactionMined = (..._args) => { - console.log(`${new Date().toISOString()} callbackTransactionMined`); - this.transactionMined += 1; - }; - #callbackTransactionMinedUnconfirmed = (..._args) => { - console.log( - `${new Date().toISOString()} callbackTransactionMinedUnconfirmed` - ); - }; - #callbackDirectSendResult = (..._args) => { - console.log(`${new Date().toISOString()} callbackDirectSendResult`); - }; - #callbackStoreAndForwardSendResult = (..._args) => { - console.log( - `${new Date().toISOString()} callbackStoreAndForwardSendResult` - ); - }; - #callbackTransactionCancellation = (..._args) => { - console.log(`${new Date().toISOString()} callbackTransactionCancellation`); - }; - #callbackUtxoValidationComplete = (_request_key, validation_results) => { - console.log(`${new Date().toISOString()} callbackUtxoValidationComplete`); - this.utxo_validation_complete = true; - this.utxo_validation_result = validation_results; - }; - #callbackStxoValidationComplete = 
(_request_key, validation_results) => { - console.log(`${new Date().toISOString()} callbackStxoValidationComplete`); - this.stxo_validation_complete = true; - this.stxo_validation_result = validation_results; - }; - #callbackInvalidTxoValidationComplete = (..._args) => { - console.log( - `${new Date().toISOString()} callbackInvalidTxoValidationComplete` - ); - }; - #callbackTransactionValidationComplete = (..._args) => { - console.log( - `${new Date().toISOString()} callbackTransactionValidationComplete` - ); - }; - #callbackSafMessageReceived = (..._args) => { - console.log(`${new Date().toISOString()} callbackSafMessageReceived`); - }; - #recoveryProgressCallback = (a, b, c) => { - console.log(`${new Date().toISOString()} recoveryProgressCallback`); - if (a == 3) - // Progress - this.recoveryProgress = [b, c]; - if (a == 4) - // Completed - this.recoveryInProgress = false; - }; - - clearCallbackCounters() { - this.receivedTransaction = - this.receivedTransactionReply = - this.transactionBroadcast = - this.transactionMined = - 0; - } + baseDir = ""; constructor(name) { - this.#wallet = null; this.#name = name; - this.baseDir = ""; - this.clearCallbackCounters(); - - // Create the ffi callbacks - this.#callback_received_transaction = - WalletFFI.createCallbackReceivedTransaction( - this.#callbackReceivedTransaction - ); - this.#callback_received_transaction_reply = - WalletFFI.createCallbackReceivedTransactionReply( - this.#callbackReceivedTransactionReply - ); - this.#callback_received_finalized_transaction = - WalletFFI.createCallbackReceivedFinalizedTransaction( - this.#callbackReceivedFinalizedTransaction - ); - this.#callback_transaction_broadcast = - WalletFFI.createCallbackTransactionBroadcast( - this.#callbackTransactionBroadcast - ); - this.#callback_transaction_mined = WalletFFI.createCallbackTransactionMined( - this.#callbackTransactionMined - ); - this.#callback_transaction_mined_unconfirmed = - WalletFFI.createCallbackTransactionMinedUnconfirmed( - 
this.#callbackTransactionMinedUnconfirmed - ); - this.#callback_direct_send_result = - WalletFFI.createCallbackDirectSendResult(this.#callbackDirectSendResult); - this.#callback_store_and_forward_send_result = - WalletFFI.createCallbackStoreAndForwardSendResult( - this.#callbackStoreAndForwardSendResult - ); - this.#callback_transaction_cancellation = - WalletFFI.createCallbackTransactionCancellation( - this.#callbackTransactionCancellation - ); - this.#callback_utxo_validation_complete = - WalletFFI.createCallbackUtxoValidationComplete( - this.#callbackUtxoValidationComplete - ); - this.#callback_stxo_validation_complete = - WalletFFI.createCallbackStxoValidationComplete( - this.#callbackStxoValidationComplete - ); - this.#callback_invalid_txo_validation_complete = - WalletFFI.createCallbackInvalidTxoValidationComplete( - this.#callbackInvalidTxoValidationComplete - ); - this.#callback_transaction_validation_complete = - WalletFFI.createCallbackTransactionValidationComplete( - this.#callbackTransactionValidationComplete - ); - this.#callback_saf_message_received = - WalletFFI.createCallbackSafMessageReceived( - this.#callbackSafMessageReceived - ); - this.#recovery_progress_callback = WalletFFI.createRecoveryProgressCallback( - this.#recoveryProgressCallback - ); } static async Init() { - await WalletFFI.Init(); + await InterfaceFFI.Init(); } - async startNew(seed_words_text) { + async startNew(seed_words_text, pass_phrase) { this.#port = await getFreePort(19000, 25000); const name = `WalletFFI${this.#port}-${this.#name}`; this.baseDir = `./temp/base_nodes/${dateFormat( new Date(), "yyyymmddHHMM" )}/${name}`; - const tcp = await WalletFFI.transportTcpCreate( - `/ip4/0.0.0.0/tcp/${this.#port}` - ); - this.#comms_config = await WalletFFI.commsConfigCreate( + this.#transport = TransportType.createTCP(`/ip4/0.0.0.0/tcp/${this.#port}`); + this.#comms_config = new CommsConfig( `/ip4/0.0.0.0/tcp/${this.#port}`, - tcp, + this.#transport.getPtr(), "wallet.dat", 
this.baseDir, 30, 600, "localnet" ); - await this.start(seed_words_text); + this.#start(seed_words_text, pass_phrase); } - async start(seed_words_text) { - let seed_words; - let seed_words_ptr = WalletFFI.NULL; - if (seed_words_text) { - seed_words = await SeedWords.fromString(seed_words_text); - seed_words_ptr = seed_words.getPtr(); - } - this.#wallet = await WalletFFI.walletCreate( - this.#comms_config, - `${this.baseDir}/log/wallet.log`, - 50, - 102400, - WalletFFI.NULL, - seed_words_ptr, - this.#callback_received_transaction, - this.#callback_received_transaction_reply, - this.#callback_received_finalized_transaction, - this.#callback_transaction_broadcast, - this.#callback_transaction_mined, - this.#callback_transaction_mined_unconfirmed, - this.#callback_direct_send_result, - this.#callback_store_and_forward_send_result, - this.#callback_transaction_cancellation, - this.#callback_utxo_validation_complete, - this.#callback_stxo_validation_complete, - this.#callback_invalid_txo_validation_complete, - this.#callback_transaction_validation_complete, - this.#callback_saf_message_received + async restart(seed_words_text, pass_phrase) { + this.#transport = TransportType.createTCP(`/ip4/0.0.0.0/tcp/${this.#port}`); + this.#comms_config = new CommsConfig( + `/ip4/0.0.0.0/tcp/${this.#port}`, + this.#transport.getPtr(), + "wallet.dat", + this.baseDir, + 30, + 600, + "localnet" ); - if (seed_words) await seed_words.destroy(); + this.#start(seed_words_text, pass_phrase); } - async startRecovery(base_node_pubkey) { - const node_pubkey = await PublicKey.fromString(base_node_pubkey); - expect( - await WalletFFI.walletStartRecovery( - this.#wallet, - node_pubkey.getPtr(), - this.#recovery_progress_callback - ) - ).to.be.true; - node_pubkey.destroy(); - this.recoveryInProgress = true; + getStxoValidationStatus() { + return this.#wallet.getStxoValidationStatus(); } - recoveryInProgress() { - return this.recoveryInProgress; + getUtxoValidationStatus() { + return 
+    this.#wallet.getUtxoValidationStatus();
+  }
+  identify() {
+    return this.#wallet.getPublicKey();
   }
-  async stop() {
-    await WalletFFI.walletDestroy(this.#wallet);
+  identifyEmoji() {
+    return this.#wallet.getEmojiId();
   }
-  async getPublicKey() {
-    const public_key = await PublicKey.fromWallet(this.#wallet);
-    const public_key_hex = public_key.getHex();
-    public_key.destroy();
-    return public_key_hex;
+  getBalance() {
+    return this.#wallet.getBalance();
   }
-  async getEmojiId() {
-    const public_key = await PublicKey.fromWallet(this.#wallet);
-    const emoji_id = await public_key.getEmojiId();
-    public_key.destroy();
-    return emoji_id;
+  addBaseNodePeer(public_key_hex, address) {
+    return this.#wallet.addBaseNodePeer(public_key_hex, address);
   }
-  async getBalance() {
-    return await WalletFFI.walletGetAvailableBalance(this.#wallet);
+  addContact(alias, pubkey_hex) {
+    return this.#wallet.addContact(alias, pubkey_hex);
   }
-  async addBaseNodePeer(public_key_hex, address) {
-    const public_key = await PublicKey.fromString(public_key_hex);
-    expect(
-      await WalletFFI.walletAddBaseNodePeer(
-        this.#wallet,
-        public_key.getPtr(),
-        address
-      )
-    ).to.be.true;
-    await public_key.destroy();
+  getContactList() {
+    return this.#wallet.getContacts();
   }
-  async sendTransaction(destination, amount, fee_per_gram, message) {
-    const dest_public_key = await PublicKey.fromString(destination);
-    const result = await WalletFFI.walletSendTransaction(
-      this.#wallet,
-      dest_public_key.getPtr(),
-      amount,
-      fee_per_gram,
-      message
-    );
-    await dest_public_key.destroy();
-    return result;
+  getCompletedTxs() {
+    return this.#wallet.getCompletedTransactions();
   }
-  async applyEncryption(passphrase) {
-    await WalletFFI.walletApplyEncryption(this.#wallet, passphrase);
+  getInboundTxs() {
+    return this.#wallet.getInboundTransactions();
   }
-  async getCompletedTransactions() {
-    const txs = await CompletedTransactions.fromWallet(this.#wallet);
-    const length = await txs.getLength();
-    let outbound = 0;
-    let inbound = 0;
-    for (let i = 0; i < length; ++i) {
-      const tx = await txs.getAt(i);
-      if (await tx.isOutbound()) {
-        ++outbound;
-      } else {
-        ++inbound;
-      }
-      tx.destroy();
-    }
-    txs.destroy();
-    return [outbound, inbound];
+  getOutboundTxs() {
+    return this.#wallet.getOutboundTransactions();
   }
-  async getBroadcastTransactionsCount() {
-    let broadcast_tx_cnt = 0;
-    const txs = await PendingOutboundTransactions.fromWallet(this.#wallet);
-    const length = await txs.getLength();
-    for (let i = 0; i < length; ++i) {
-      const tx = await txs.getAt(i);
-      const status = await tx.getStatus();
-      tx.destroy();
-      if (status === 1) {
-        // Broadcast
-        broadcast_tx_cnt += 1;
-      }
-    }
-    await txs.destroy();
-    return broadcast_tx_cnt;
+  removeContact(contact) {
+    return this.#wallet.removeContact(contact);
   }
-  async getOutboundTransactionsCount() {
-    let outbound_tx_cnt = 0;
-    const txs = await PendingOutboundTransactions.fromWallet(this.#wallet);
-    const length = await txs.getLength();
-    for (let i = 0; i < length; ++i) {
-      const tx = await txs.getAt(i);
-      const status = await tx.getStatus();
-      if (status === 4) {
-        // Pending
-        outbound_tx_cnt += 1;
-      }
-      tx.destroy();
-    }
-    await txs.destroy();
-    return outbound_tx_cnt;
+  startRecovery(base_node_pubkey) {
+    this.#wallet.startRecovery(base_node_pubkey);
   }
-  async addContact(alias, pubkey_hex) {
-    const public_key = await PublicKey.fromString(pubkey_hex);
-    const contact = new Contact(
-      await WalletFFI.contactCreate(alias, public_key.getPtr())
-    );
-    public_key.destroy();
-    expect(await WalletFFI.walletUpsertContact(this.#wallet, contact.getPtr()))
-      .to.be.true;
-    contact.destroy();
+  checkRecoveryInProgress() {
+    return this.#wallet.recoveryInProgress();
   }
-  async #findContact(lookup_alias) {
-    const contacts = await Contacts.fromWallet(this.#wallet);
-    const length = await contacts.getLength();
-    let contact;
-    for (let i = 0; i < length; ++i) {
-      contact = await contacts.getAt(i);
-      const alias = await contact.getAlias();
-      const found = alias === lookup_alias;
-      if (found) {
-        break;
-      }
-      contact.destroy();
-      contact = undefined;
-    }
-    contacts.destroy();
-    return contact;
+  applyEncryption(passphrase) {
+    this.#wallet.applyEncryption(passphrase);
   }
-  async getContact(alias) {
-    const contact = await this.#findContact(alias);
-    if (contact) {
-      const pubkey = await contact.getPubkey();
-      const pubkey_hex = pubkey.getHex();
-      pubkey.destroy();
-      contact.destroy();
-      return pubkey_hex;
-    }
+  startStxoValidation() {
+    this.#wallet.startStxoValidation();
   }
-  async removeContact(alias) {
-    const contact = await this.#findContact(alias);
-    if (contact) {
-      expect(
-        await WalletFFI.walletRemoveContact(this.#wallet, contact.getPtr())
-      ).to.be.true;
-      contact.destroy();
-    }
+  startUtxoValidation() {
+    this.#wallet.startUtxoValidation();
+  }
+
+  getCounters() {
+    return this.#wallet.getCounters();
+  }
+  resetCounters() {
+    this.#wallet.clearCallbackCounters();
   }
-  async identify() {
-    return {
-      public_key: await this.getPublicKey(),
-    };
+  sendTransaction(destination, amount, fee_per_gram, message) {
+    return this.#wallet.sendTransaction(
+      destination,
+      amount,
+      fee_per_gram,
+      message
+    );
   }
-  async cancelAllOutboundTransactions() {
-    const txs = await PendingOutboundTransactions.fromWallet(this.#wallet);
-    const length = await txs.getLength();
-    let cancelled = 0;
-    for (let i = 0; i < length; ++i) {
-      const tx = await txs.getAt(i);
-      if (
-        await WalletFFI.walletCancelPendingTransaction(
-          this.#wallet,
-          await tx.getTransactionId()
-        )
-      ) {
-        ++cancelled;
-      }
-      tx.destroy();
+  #start(
+    seed_words_text,
+    pass_phrase,
+    rolling_log_files = 50,
+    byte_size_per_log = 102400
+  ) {
+    this.#pass_phrase = pass_phrase;
+    if (seed_words_text) {
+      let seed_words = SeedWords.fromText(seed_words_text);
+      this.#seed_words = seed_words;
     }
-    txs.destroy();
-    return cancelled;
+
+    let log_path = `${this.baseDir}/log/wallet.log`;
+    this.#wallet = new Wallet(
+      this.#comms_config.getPtr(),
+      log_path,
+      this.#pass_phrase,
+      this.#seed_words ? this.#seed_words.getPtr() : null,
+      rolling_log_files,
+      byte_size_per_log
+    );
   }
-  startUtxoValidation() {
-    this.utxo_validation_complete = false;
-    return WalletFFI.walletStartUtxoValidation(this.#wallet);
+  getOutboundTransactions() {
+    return this.#wallet.getOutboundTransactions();
   }
-  startStxoValidation() {
-    this.stxo_validation_complete = false;
-    return WalletFFI.walletStartStxoValidation(this.#wallet);
+  cancelPendingTransaction(tx_id) {
+    return this.#wallet.cancelPendingTransaction(tx_id);
+  }
+
+  stop() {
+    if (this.#wallet) {
+      this.#wallet.destroy();
+    }
+    if (this.#comms_config) {
+      this.#comms_config.destroy();
+    }
+    if (this.#seed_words) {
+      this.#seed_words.destroy();
+    }
   }
 }
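For context, the pattern this diff applies is uniform: the test client stops driving raw FFI calls itself (creating `PublicKey` objects, iterating transaction collections, manually calling `destroy()` on each) and instead forwards to a single long-lived `Wallet` wrapper that owns the native handle and the callback counters. A minimal illustrative sketch of that delegation shape, where `MockWallet` and `Client` are hypothetical stand-ins rather than names from the Tari codebase:

```javascript
// Illustrative sketch only: a thin synchronous client delegating to one
// wrapper object, mirroring the shape of the refactored methods above.
class MockWallet {
  #counters = { tx_received: 0, tx_broadcast: 0 };

  getPublicKey() {
    return "a1b2c3"; // stand-in for the wallet's hex public key
  }

  getBalance() {
    return 1000; // stand-in balance in microTari
  }

  getCounters() {
    // Return a copy so callers cannot mutate internal state.
    return { ...this.#counters };
  }

  clearCallbackCounters() {
    for (const key of Object.keys(this.#counters)) {
      this.#counters[key] = 0;
    }
  }

  destroy() {
    // The real wrapper would free the native FFI handle here.
  }
}

class Client {
  #wallet = new MockWallet();

  // Each public method is a one-line delegate; no per-call FFI
  // object creation or destroy() bookkeeping remains in the client.
  identify() {
    return this.#wallet.getPublicKey();
  }

  getBalance() {
    return this.#wallet.getBalance();
  }

  resetCounters() {
    this.#wallet.clearCallbackCounters();
  }

  stop() {
    if (this.#wallet) {
      this.#wallet.destroy();
    }
  }
}

const client = new Client();
console.log(client.identify()); // "a1b2c3"
console.log(client.getBalance()); // 1000
client.stop();
```

The design benefit is that lifetime management (who frees the native handle, when counters reset) lives in one place, so the cucumber step definitions become simple synchronous calls.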