diff --git a/Cargo.lock b/Cargo.lock index 250a52339e..46ad019662 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,15 +2,6 @@ # It is not intended for manual editing. version = 3 -[[package]] -name = "addr2line" -version = "0.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e61f2b7f93d2c7d2b08263acaa4a363b3e276806c68af6134c44f523bf1aacd" -dependencies = [ - "gimli", -] - [[package]] name = "adler" version = "1.0.2" @@ -28,40 +19,29 @@ dependencies = [ [[package]] name = "aead" -version = "0.4.2" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e3e798aa0c8239776f54415bc06f3d74b1850f3f830b45c35cfc80556973f70" +checksum = "0b613b8e1e3cf911a086f53f03bf286f52fd7a7258e4fa606f0ef220d39d8877" dependencies = [ "generic-array", ] -[[package]] -name = "aes" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd2bc6d3f370b5666245ff421e231cba4353df936e26986d2918e61a8fd6aef6" -dependencies = [ - "aes-soft 0.5.0", - "aesni 0.8.0", - "block-cipher", -] - [[package]] name = "aes" version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "884391ef1066acaa41e766ba8f596341b96e93ce34f9a43e7d24bf0a0eaf0561" dependencies = [ - "aes-soft 0.6.4", - "aesni 0.10.0", + "aes-soft", + "aesni", "cipher 0.2.5", ] [[package]] name = "aes" -version = "0.7.4" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "495ee669413bfbe9e8cace80f4d3d78e6d8c8d99579f97fb93bde351b185f2d4" +checksum = "9e8b47f52ea9bae42228d07ec09eb676433d7c4ed1ebdf0f1d1c29ed446f1ab8" dependencies = [ "cfg-if 1.0.0", "cipher 0.3.0", @@ -85,29 +65,18 @@ dependencies = [ [[package]] name = "aes-gcm" -version = "0.9.3" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2a930fd487faaa92a30afa92cc9dd1526a5cff67124abbbb1c617ce070f4dcf" +checksum = "df5f85a83a7d8b0442b6aa7b504b8212c1733da07b98aae43d4bc21b2cb3cdf6" dependencies = [ - "aead 0.4.2", - "aes 0.7.4", + "aead 0.4.3", + "aes 0.7.5", "cipher 0.3.0", "ctr 0.8.0", - "ghash 0.4.3", + "ghash 0.4.4", "subtle", ] -[[package]] -name = "aes-soft" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63dd91889c49327ad7ef3b500fd1109dbd3c509a03db0d4a9ce413b79f575cb6" -dependencies = [ - "block-cipher", - "byteorder", - "opaque-debug", -] - [[package]] name = "aes-soft" version = "0.6.4" @@ -118,16 +87,6 @@ dependencies = [ "opaque-debug", ] -[[package]] -name = "aesni" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a6fe808308bb07d393e2ea47780043ec47683fcf19cf5efc8ca51c50cc8c68a" -dependencies = [ - "block-cipher", - "opaque-debug", -] - [[package]] name = "aesni" version = "0.10.0" @@ -206,9 +165,9 @@ dependencies = [ [[package]] name = "async-stream" -version = "0.2.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22068c0c19514942eefcfd4daf8976ef1aad84e61539f95cd200c35202f80af5" +checksum = "171374e7e3b2504e0e5236e3b59260560f9fe94bfe9ac39ba5e4e929c5590625" dependencies = [ "async-stream-impl", "futures-core", @@ -216,9 +175,9 @@ dependencies = [ [[package]] name = "async-stream-impl" -version = "0.2.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25f9db3b38af870bf7e5cc649167533b493928e50744e2c30ae350230b414670" +checksum = 
"648ed8c8d2ce5409ccd57453d9d1b214b342a0d69376a6feda1fd6cae3299308" dependencies = [ "proc-macro2 1.0.28", "quote 1.0.9", @@ -259,21 +218,6 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a" -[[package]] -name = "backtrace" -version = "0.3.61" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7a905d892734eea339e896738c14b9afce22b5318f64b951e70bf3844419b01" -dependencies = [ - "addr2line", - "cc", - "cfg-if 1.0.0", - "libc", - "miniz_oxide", - "object", - "rustc-demangle", -] - [[package]] name = "base58-monero" version = "0.3.0" @@ -302,12 +246,6 @@ dependencies = [ "byteorder", ] -[[package]] -name = "base64" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b41b7ea54a0c9d92199de89e20e58d49f02f8e699814ef3fdf266f6f748d15c7" - [[package]] name = "base64" version = "0.12.3" @@ -337,7 +275,7 @@ version = "1.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" dependencies = [ - "serde 1.0.129", + "serde 1.0.130", ] [[package]] @@ -360,7 +298,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "which", + "which 3.1.1", ] [[package]] @@ -389,9 +327,9 @@ checksum = "3e54f7b7a46d7b183eb41e2d82965261fa8a1597c68b50aced268ee1fc70272d" [[package]] name = "blake2" -version = "0.9.1" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10a5720225ef5daecf08657f23791354e1685a8c91a4c60c7f3d3b2892f978f4" +checksum = "0a4e37d16930f5459780f5621038b6382b9bb37c19016f39fb6b5808d831f174" dependencies = [ "crypto-mac", "digest", @@ -435,12 +373,12 @@ checksum = "8d696c370c750c948ada61c69a0ee2cbbb9c50b1019ddb86d9317157a99c2cae" [[package]] name = "blowfish" -version = "0.6.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f06850ba969bc59388b2cc0a4f186fc6d9d37208863b15b84ae3866ac90ac06" +checksum = "32fa6a061124e37baba002e496d203e23ba3d7b73750be82dbfbc92913048a5b" dependencies = [ - "block-cipher", "byteorder", + "cipher 0.2.5", "opaque-debug", ] @@ -459,7 +397,7 @@ dependencies = [ "lazy_static 1.4.0", "memchr", "regex-automata", - "serde 1.0.129", + "serde 1.0.130", ] [[package]] @@ -496,30 +434,20 @@ version = "1.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" -[[package]] -name = "bytes" -version = "0.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "206fdffcfa2df7cbe15601ef46c813fce0965eb3286db6b56c583b814b51c81c" -dependencies = [ - "byteorder", - "iovec", -] - [[package]] name = "bytes" version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e4cec68f03f32e44924783795810fa50a7035d8c8ebe78580ad7e6c703fba38" -dependencies = [ - "serde 1.0.129", -] [[package]] name = "bytes" -version = "1.0.1" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b700ce4376041dcd0a327fd0097c41095743c4c8af8887265942faf1100bd040" +checksum = "c4872d67bab6358e59559027aa3b9157c53d9358c51423c17554809a8858e0f8" +dependencies = [ + "serde 1.0.130", +] [[package]] name = "c_linked_list" @@ -550,12 +478,12 @@ dependencies = [ [[package]] name = "cast5" -version = "0.8.0" +version = "0.9.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3ed1e6b53a3de8bafcce4b88867893c234e57f91686a4726d8e803771f0b55b" +checksum = "1285caf81ea1f1ece6b24414c521e625ad0ec94d880625c20f2e65d8d3f78823" dependencies = [ - "block-cipher", "byteorder", + "cipher 0.2.5", "opaque-debug", ] @@ -571,7 +499,7 @@ dependencies = [ "log 0.4.14", "proc-macro2 1.0.28", "quote 1.0.9", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "syn 1.0.75", "tempfile", @@ -595,11 +523,11 @@ dependencies = [ [[package]] name = "cfb-mode" -version = "0.5.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fa76b7293f89734378d27057d169dc68077ad34b21dbcabf1c0a646a9462592" +checksum = "1d6975e91054798d325f85f50115056d7deccf6817fe7f947c438ee45b119632" dependencies = [ - "stream-cipher", + "cipher 0.2.5", ] [[package]] @@ -616,9 +544,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chacha20" -version = "0.7.2" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea8756167ea0aca10e066cdbe7813bd71d2f24e69b0bc7b50509590cef2ce0b9" +checksum = "f08493fa7707effc63254c66c6ea908675912493cd67952eda23c09fae2610b1" dependencies = [ "cfg-if 1.0.0", "cipher 0.3.0", @@ -628,11 +556,11 @@ dependencies = [ [[package]] name = "chacha20poly1305" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "175a11316f33592cf2b71416ee65283730b5b7849813c4891d02a12906ed9acc" +checksum = "b6547abe025f4027edacd9edaa357aded014eecec42a5070d9b885c3c334aba2" dependencies = [ - "aead 0.4.2", + "aead 0.4.3", "chacha20", "cipher 0.3.0", "poly1305", @@ -654,7 +582,7 @@ dependencies = [ "libc", "num-integer", "num-traits 0.2.14", - "serde 1.0.129", + "serde 1.0.130", "time", "winapi 0.3.9", ] @@ -677,7 +605,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c6316c62053228eddd526a5e6deb6344c80bf2bc1e9786e7f90b3083e73197c1" dependencies = [ "bitstring", - "serde 1.0.129", + "serde 1.0.130", ] [[package]] @@ -763,7 +691,7 @@ dependencies = [ "lazy_static 1.4.0", "nom 4.2.3", "rust-ini", - "serde 1.0.129", + "serde 1.0.130", "serde-hjson", "serde_json", "toml 0.4.10", @@ -788,9 +716,9 @@ checksum = "ea221b5284a47e40033bf9b66f35f984ec0ea2931eb03505246cd27a963f981b" [[package]] name = "cpufeatures" -version = "0.1.5" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66c99696f6c9dd7f35d486b9d04d7e6e202aa3e8c40d553f2fdf5e7e0c6a71ef" +checksum = "95059428f66df56b63431fdb4e1947ed2190586af5c5a8a8b71122bdf5a7f469" dependencies = [ "libc", ] @@ -827,7 +755,7 @@ dependencies = [ "clap", "criterion-plot", "csv", - "itertools", + "itertools 0.8.2", "lazy_static 1.4.0", "libc", "num-traits 0.2.14", @@ -836,7 +764,7 @@ dependencies = [ "rand_xoshiro", "rayon", "rayon-core", - "serde 1.0.129", + "serde 1.0.130", "serde_derive", "serde_json", "tinytemplate", @@ -851,7 +779,7 @@ checksum = "76f9212ddf2f4a9eb2d401635190600656a1f88a932ef53d06e7fa4c7e02fb8e" dependencies = [ "byteorder", "cast", - "itertools", + "itertools 0.8.2", ] [[package]] @@ -1014,7 +942,7 @@ dependencies = [ "csv-core", "itoa", "ryu", - "serde 1.0.129", + "serde 1.0.130", ] [[package]] @@ -1068,7 +996,7 @@ dependencies = [ "byteorder", "digest", "rand_core 0.5.1", - "serde 1.0.129", + "serde 1.0.130", "subtle", "zeroize", ] @@ -1083,7 +1011,7 @@ dependencies = [ "digest", "packed_simd_2", "rand_core 0.6.3", - "serde 
1.0.129", + "serde 1.0.130", "subtle-ng", "zeroize", ] @@ -1178,12 +1106,12 @@ dependencies = [ [[package]] name = "des" -version = "0.5.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e084b5048dec677e6c9f27d7abc551dde7d127cf4127fea82323c98a30d7fa0d" +checksum = "b24e7c748888aa2fa8bce21d8c64a52efc810663285315ac7476f7197a982fae" dependencies = [ - "block-cipher", "byteorder", + "cipher 0.2.5", "opaque-debug", ] @@ -1279,7 +1207,7 @@ dependencies = [ "curve25519-dalek", "ed25519", "rand 0.7.3", - "serde 1.0.129", + "serde 1.0.130", "sha2", "zeroize", ] @@ -1599,19 +1527,6 @@ dependencies = [ "pin-utils", ] -[[package]] -name = "futures-test-preview" -version = "0.3.0-alpha.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0813d833a213f893d1f07ccd9d49d3de100181d146053b2097b8b934d7945eb" -dependencies = [ - "futures-core-preview", - "futures-executor-preview", - "futures-io-preview", - "futures-util-preview", - "pin-utils", -] - [[package]] name = "futures-timer" version = "0.3.0" @@ -1731,20 +1646,14 @@ dependencies = [ [[package]] name = "ghash" -version = "0.4.3" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b442c439366184de619215247d24e908912b175e824a530253845ac4c251a5c1" +checksum = "1583cc1656d7839fd3732b80cf4f38850336cdb9b8ded1cd399ca62958de3c99" dependencies = [ "opaque-debug", - "polyval 0.5.2", + "polyval 0.5.3", ] -[[package]] -name = "gimli" -version = "0.25.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0a01e0497841a3b2db4f8afa483cce65f7e96a3498bd6c541734792aeac8fe7" - [[package]] name = "git2" version = "0.8.0" @@ -1792,7 +1701,7 @@ version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7f3675cfef6a30c8031cf9e6493ebdc3bb3272a3fea3923c4210d1830e6a472" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "fnv", "futures-core", "futures-sink", @@ -1847,7 +1756,7 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "527e8c9ac747e28542699a951517aa9a6945af506cd1f2e1b53a576c17b6cc11" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "fnv", "itoa", ] @@ -1868,7 +1777,7 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "399c583b2979440c60be0821a6199eca73bc3c8dcd9d070d75ac726e2c6186e5" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "http", "pin-project-lite 0.2.7", ] @@ -1955,7 +1864,7 @@ version = "0.14.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13f67199e765030fa08fe0bd581af683f0d5bc04ea09c2b1102012c5fb90e7fd" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "futures-channel", "futures-core", "futures-util", @@ -1973,6 +1882,18 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-timeout" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" +dependencies = [ + "hyper 0.14.12", + "pin-project-lite 0.2.7", + "tokio 1.10.1", + "tokio-io-timeout", +] + [[package]] name = "hyper-tls" version = "0.4.3" @@ -1992,7 +1913,7 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "hyper 0.14.12", "native-tls", "tokio 1.10.1", @@ -2037,7 +1958,7 @@ dependencies = [ "byteorder", 
"color_quant", "num-iter", - "num-rational", + "num-rational 0.3.2", "num-traits 0.2.14", ] @@ -2090,6 +2011,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69ddb889f9d0d08a67338271fa9b62996bc788c7796a5c18cf057420aaed5eaf" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "0.4.8" @@ -2112,7 +2042,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "436f3455a8a4e9c7b14de9f1206198ee5d0bdc2db1b560339d2141093d7dd389" dependencies = [ "hyper 0.10.16", - "serde 1.0.129", + "serde 1.0.130", "serde_derive", "serde_json", ] @@ -2162,9 +2092,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.100" +version = "0.2.101" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1fa8cddc8fbbee11227ef194b5317ed014b8acbf15139bd716a18ad3fe99ec5" +checksum = "3cb00336871be5ed2c8ed44b60ae9959dc5b9f08539422ed43f09e34ecaeba21" [[package]] name = "libgit2-sys" @@ -2289,9 +2219,9 @@ dependencies = [ [[package]] name = "lock_api" -version = "0.4.4" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0382880606dff6d15c9476c416d18690b72742aa7b605bb6dd6ec9030fbf07eb" +checksum = "712a4d093c9976e24e7dbca41db895dabcbac38eb5f4045393d17a95bdfb1109" dependencies = [ "scopeguard", ] @@ -2312,7 +2242,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51b9bbe6c47d51fc3e1a9b945965946b4c44142ab8792c50835a980d362c2710" dependencies = [ "cfg-if 1.0.0", - "serde 1.0.129", + "serde 1.0.130", ] [[package]] @@ -2335,7 +2265,7 @@ dependencies = [ "libc", "log 0.4.14", "log-mdc", - "serde 1.0.129", + "serde 1.0.130", "serde-value 0.5.3", "serde_derive", "serde_yaml", @@ -2359,9 +2289,9 @@ dependencies = [ "libc", "log 0.4.14", "log-mdc", - "parking_lot 0.11.1", + "parking_lot 0.11.2", "regex", - "serde 1.0.129", + "serde 1.0.130", "serde-value 0.7.0", "serde_json", "serde_yaml", @@ -2512,17 +2442,6 @@ dependencies = [ "winapi 0.3.9", ] -[[package]] -name = "mio-uds" -version = "0.6.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afcb699eb26d4332647cc848492bbc15eafb26f08d0304550d5aa1f612e066f0" -dependencies = [ - "iovec", - "libc", - "mio 0.6.23", -] - [[package]] name = "miow" version = "0.2.2" @@ -2555,21 +2474,39 @@ dependencies = [ "fixed-hash", "hex", "hex-literal", - "serde 1.0.129", + "serde 1.0.130", "serde-big-array", "thiserror", "tiny-keccak", ] +[[package]] +name = "multiaddr" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48ee4ea82141951ac6379f964f71b20876d43712bea8faf6dd1a375e08a46499" +dependencies = [ + "arrayref", + "bs58", + "byteorder", + "data-encoding", + "multihash", + "percent-encoding 2.1.0", + "serde 1.0.130", + "static_assertions", + "unsigned-varint", + "url 2.2.2", +] + [[package]] name = "multihash" -version = "0.13.2" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dac63698b887d2d929306ea48b63760431ff8a24fac40ddb22f9c7f49fb7cab" +checksum = "752a61cd890ff691b4411423d23816d5866dd5621e4d1c5687a53b94b5a979d8" dependencies = [ "generic-array", "multihash-derive", - "unsigned-varint 0.5.1", + "unsigned-varint", ] [[package]] @@ -2629,9 +2566,12 @@ checksum = "d36047f46c69ef97b60e7b069a26ce9a15cd8a7852eddb6991ea94a83ba36a78" [[package]] name = 
"nibble_vec" -version = "0.0.4" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8d77f3db4bce033f4d04db08079b2ef1c3d02b44e86f25d08886fafa7756ffa" +checksum = "77a5d83df9f36fe23f0c3648c6bbb8b0298bb5f1939c8f2704431371f4b84d43" +dependencies = [ + "smallvec", +] [[package]] name = "nix" @@ -2693,10 +2633,24 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b7a8e9be5e039e2ff869df49155f1c06bd01ade2117ec783e56ab0932b67a8f" dependencies = [ "num-bigint 0.3.2", - "num-complex", + "num-complex 0.3.1", "num-integer", "num-iter", - "num-rational", + "num-rational 0.3.2", + "num-traits 0.2.14", +] + +[[package]] +name = "num" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43db66d1170d347f9a065114077f7dccb00c1b9478c89384490a3425279a4606" +dependencies = [ + "num-bigint 0.4.1", + "num-complex 0.4.0", + "num-integer", + "num-iter", + "num-rational 0.4.0", "num-traits 0.2.14", ] @@ -2722,6 +2676,17 @@ dependencies = [ "num-traits 0.2.14", ] +[[package]] +name = "num-bigint" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76e97c412795abf6c24ba30055a8f20642ea57ca12875220b854cfa501bf1e48" +dependencies = [ + "autocfg 1.0.1", + "num-integer", + "num-traits 0.2.14", +] + [[package]] name = "num-bigint-dig" version = "0.6.1" @@ -2736,7 +2701,7 @@ dependencies = [ "num-iter", "num-traits 0.2.14", "rand 0.7.3", - "serde 1.0.129", + "serde 1.0.130", "smallvec", "zeroize", ] @@ -2750,6 +2715,15 @@ dependencies = [ "num-traits 0.2.14", ] +[[package]] +name = "num-complex" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26873667bbbb7c5182d4a37c1add32cdf09f841af72da53318fdb81543c15085" +dependencies = [ + "num-traits 0.2.14", +] + [[package]] name = "num-derive" version = "0.3.3" @@ -2804,6 +2778,18 @@ dependencies = [ "num-traits 0.2.14", ] +[[package]] +name = "num-rational" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d41702bd167c2df5520b384281bc111a4b5efcf7fbc4c9c222c815b07e0a6a6a" +dependencies = [ + "autocfg 1.0.1", + "num-bigint 0.4.1", + "num-integer", + "num-traits 0.2.14", +] + [[package]] name = "num-traits" version = "0.1.43" @@ -2832,15 +2818,6 @@ dependencies = [ "libc", ] -[[package]] -name = "object" -version = "0.26.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee2766204889d09937d00bfbb7fec56bb2a199e2ade963cab19185d8a6104c7c" -dependencies = [ - "memchr", -] - [[package]] name = "once_cell" version = "1.8.0" @@ -2904,14 +2881,14 @@ checksum = "e1cf9b1c4e9a6c4de793c632496fa490bdc0e1eea73f0c91394f7b6990935d22" dependencies = [ "async-trait", "crossbeam-channel 0.5.1", - "futures 0.3.15", + "futures 0.3.16", "js-sys", "lazy_static 1.4.0", "percent-encoding 2.1.0", - "pin-project 1.0.7", + "pin-project 1.0.8", "rand 0.8.4", "thiserror", - "tokio 1.9.0", + "tokio 1.10.1", "tokio-stream", ] @@ -2927,7 +2904,7 @@ dependencies = [ "opentelemetry-semantic-conventions", "thiserror", "thrift", - "tokio 1.9.0", + "tokio 1.10.1", ] [[package]] @@ -2967,24 +2944,6 @@ dependencies = [ "libm 0.1.4", ] -[[package]] -name = "parity-multiaddr" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bfda2e46fc5e14122649e2645645a81ee5844e0fb2e727ef560cc71a8b2d801" -dependencies = [ - "arrayref", - "bs58", - "byteorder", - "data-encoding", - "multihash", - 
"percent-encoding 2.1.0", - "serde 1.0.129", - "static_assertions", - "unsigned-varint 0.6.0", - "url 2.2.2", -] - [[package]] name = "parking_lot" version = "0.10.2" @@ -2997,13 +2956,13 @@ dependencies = [ [[package]] name = "parking_lot" -version = "0.11.1" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d7744ac029df22dca6284efe4e898991d28e3085c706c972bcd7da4a27a15eb" +checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" dependencies = [ "instant", - "lock_api 0.4.4", - "parking_lot_core 0.8.3", + "lock_api 0.4.5", + "parking_lot_core 0.8.5", ] [[package]] @@ -3022,9 +2981,9 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.8.3" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa7a782938e745763fe6907fc6ba86946d72f49fe7e21de074e08128a99fb018" +checksum = "d76e8e1493bcac0d2766c42737f34458f1c8c50c0d23bcb24ea953affb273216" dependencies = [ "cfg-if 1.0.0", "instant", @@ -3090,11 +3049,11 @@ dependencies = [ [[package]] name = "pgp" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "501f8c2834bc16a23ae40932b9f924c6c5fc1d7cd1cc3536a532f37e81f603ed" +checksum = "856124b4d0a95badd3e1ad353edd7157fc6c6995767b78ef62848f3b296405ff" dependencies = [ - "aes 0.5.0", + "aes 0.6.0", "base64 0.12.3", "bitfield", "block-modes", @@ -3105,6 +3064,7 @@ dependencies = [ "cast5", "cfb-mode", "chrono", + "cipher 0.2.5", "circular", "clear_on_drop", "crc24", @@ -3203,9 +3163,9 @@ checksum = "3831453b3449ceb48b6d9c7ad7c96d5ea673e9b470a1dc578c2ce6521230884c" [[package]] name = "poly1305" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fcffab1f78ebbdf4b93b68c1ffebc24037eedf271edaca795732b24e5e4e349" +checksum = "048aeb476be11a4b6ca432ca569e375810de9294ae78f4774e78ea98a9246ede" dependencies = [ "cpufeatures", "opaque-debug", @@ -3225,9 +3185,9 @@ dependencies = [ [[package]] name = "polyval" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6ba6a405ef63530d6cb12802014b22f9c5751bd17cdcddbe9e46d5c8ae83287" +checksum = "8419d2b623c7c0896ff2d5d96e2cb4ede590fed28fcc34934f4c33c036e620a1" dependencies = [ "cfg-if 1.0.0", "cpufeatures", @@ -3307,40 +3267,40 @@ dependencies = [ [[package]] name = "prost" -version = "0.6.1" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce49aefe0a6144a45de32927c77bd2859a5f7677b55f220ae5b744e87389c212" +checksum = "de5e2533f59d08fcf364fd374ebda0692a70bd6d7e66ef97f306f45c6c5d8020" dependencies = [ - "bytes 0.5.6", + "bytes 1.1.0", "prost-derive", ] [[package]] name = "prost-build" -version = "0.6.1" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02b10678c913ecbd69350e8535c3aef91a8676c0773fc1d7b95cdd196d7f2f26" +checksum = "355f634b43cdd80724ee7848f95770e7e70eefa6dcf14fea676216573b8fd603" dependencies = [ - "bytes 0.5.6", + "bytes 1.1.0", "heck", - "itertools", + "itertools 0.10.1", "log 0.4.14", "multimap", "petgraph", "prost", "prost-types", "tempfile", - "which", + "which 4.2.2", ] [[package]] name = "prost-derive" -version = "0.6.1" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "537aa19b95acde10a12fec4301466386f757403de4cd4e5b4fa78fb5ecb18f72" +checksum = 
"600d2f334aa05acb02a755e217ef1ab6dea4d51b58b7846588b747edec04efba" dependencies = [ "anyhow", - "itertools", + "itertools 0.10.1", "proc-macro2 1.0.28", "quote 1.0.9", "syn 1.0.75", @@ -3348,11 +3308,11 @@ dependencies = [ [[package]] name = "prost-types" -version = "0.6.1" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1834f67c0697c001304b75be76f67add9c89742eda3a085ad8ee0bb38c3417aa" +checksum = "603bbd6394701d13f3f25aada59c7de9d35a6a5887cfc156181234a44002771b" dependencies = [ - "bytes 0.5.6", + "bytes 1.1.0", "prost", ] @@ -3398,9 +3358,9 @@ dependencies = [ [[package]] name = "radix_trie" -version = "0.1.6" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d3681b28cd95acfb0560ea9441f82d6a4504fa3b15b97bd7b6e952131820e95" +checksum = "c069c179fcdc6a2fe24d8d18305cf085fdbd4f922c041943e203685d6a1c58fd" dependencies = [ "endian-type", "nibble_vec", @@ -3417,7 +3377,6 @@ dependencies = [ "rand_chacha 0.2.2", "rand_core 0.5.1", "rand_hc 0.2.0", - "rand_pcg", ] [[package]] @@ -3517,15 +3476,6 @@ dependencies = [ "winapi 0.3.9", ] -[[package]] -name = "rand_pcg" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16abd0c1b639e9eb4d7c50c0b8100b0d0f849be2349829c740fe8e6eb4816429" -dependencies = [ - "rand_core 0.5.1", -] - [[package]] name = "rand_xoshiro" version = "0.1.0" @@ -3666,8 +3616,7 @@ dependencies = [ "native-tls", "percent-encoding 2.1.0", "pin-project-lite 0.2.7", - "serde 1.0.129", - "serde_json", + "serde 1.0.130", "serde_urlencoded", "tokio 0.2.25", "tokio-tls", @@ -3685,7 +3634,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "246e9f61b9bb77df069a947682be06e31ac43ea37862e244a69f177694ea6d22" dependencies = [ "base64 0.13.0", - "bytes 1.0.1", + "bytes 1.1.0", "encoding_rs", "futures-core", "futures-util", @@ -3701,7 +3650,7 @@ dependencies = [ "native-tls", "percent-encoding 2.1.0", "pin-project-lite 0.2.7", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "serde_urlencoded", "tokio 1.10.1", @@ -3757,7 +3706,7 @@ checksum = "011e1d58446e9fa3af7cdc1fb91295b10621d3ac4cb3a85cc86385ee9ca50cd3" dependencies = [ "byteorder", "rmp", - "serde 1.0.129", + "serde 1.0.130", ] [[package]] @@ -3798,12 +3747,6 @@ version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3e52c148ef37f8c375d49d5a73aa70713125b7f19095948a923f80afdeb22ec2" -[[package]] -name = "rustc-demangle" -version = "0.1.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ef03e0a2b150c7a90d01faf6254c9c48a41e95fb2a8c2ac1c6f0d2b9aefc342" - [[package]] name = "rustc-hash" version = "1.1.0" @@ -3836,11 +3779,11 @@ dependencies = [ [[package]] name = "rustls" -version = "0.17.0" +version = "0.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0d4a31f5d68413404705d6982529b0e11a9aacd4839d1d6222ee3b8cb4015e1" +checksum = "35edb675feee39aec9c99fa5ff985081995a06d594114ae14cbe797ad7b7a6d7" dependencies = [ - "base64 0.11.0", + "base64 0.13.0", "log 0.4.14", "ring", "sct", @@ -3931,22 +3874,23 @@ dependencies = [ [[package]] name = "security-framework" -version = "2.3.1" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23a2ac85147a3a11d77ecf1bc7166ec0b92febfa4461c37944e180f319ece467" +checksum = "5b9bd29cdffb8875b04f71c51058f940cf4e390bbfd2ce669c4f22cd70b492a5" dependencies = [ "bitflags 1.3.2", "core-foundation", 
"core-foundation-sys", "libc", + "num 0.4.0", "security-framework-sys", ] [[package]] name = "security-framework-sys" -version = "2.3.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e4effb91b4b8b6fb7732e670b6cee160278ff8e6bf485c7805d9e319d76e284" +checksum = "19133a286e494cc3311c165c4676ccb1fd47bed45b55f9d71fbd784ad4cea6f8" dependencies = [ "core-foundation-sys", "libc", @@ -3984,9 +3928,9 @@ checksum = "9dad3f759919b92c3068c696c15c3d17238234498bbdcc80f2c469606f948ac8" [[package]] name = "serde" -version = "1.0.129" +version = "1.0.130" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1f72836d2aa753853178eda473a3b9d8e4eefdaf20523b919677e6de489f8f1" +checksum = "f12d06de37cf59146fbdecab66aa99f9fe4f78722e3607577a5375d66bd0c913" dependencies = [ "serde_derive", ] @@ -3997,7 +3941,7 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "18b20e7752957bbe9661cff4e0bb04d183d0948cdab2ea58cdb9df36a61dfe62" dependencies = [ - "serde 1.0.129", + "serde 1.0.130", "serde_derive", ] @@ -4021,7 +3965,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a663f873dedc4eac1a559d4c6bc0d0b2c34dc5ac4702e105014b8281489e44f" dependencies = [ "ordered-float 1.1.1", - "serde 1.0.129", + "serde 1.0.130", ] [[package]] @@ -4031,14 +3975,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3a1a3341211875ef120e117ea7fd5228530ae7e7036a779fdc9117be6b3282c" dependencies = [ "ordered-float 2.7.0", - "serde 1.0.129", + "serde 1.0.130", ] [[package]] name = "serde_derive" -version = "1.0.129" +version = "1.0.130" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e57ae87ad533d9a56427558b516d0adac283614e347abf85b0dc0cbbf0a249f3" +checksum = "d7bc1a1ab1961464eae040d96713baa5a724a8152c1222492465b54322ec508b" dependencies = [ "proc-macro2 1.0.28", "quote 1.0.9", @@ -4047,13 +3991,13 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.66" +version = "1.0.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "336b10da19a12ad094b59d870ebde26a45402e5b470add4b5fd03c5048a32127" +checksum = "a7f9e390c27c3c0ce8bc5d725f6e4d30a29d26659494aa4b17535f7522c5c950" dependencies = [ "itoa", "ryu", - "serde 1.0.129", + "serde 1.0.130", ] [[package]] @@ -4085,26 +4029,26 @@ dependencies = [ "form_urlencoded", "itoa", "ryu", - "serde 1.0.129", + "serde 1.0.130", ] [[package]] name = "serde_yaml" -version = "0.8.19" +version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6375dbd828ed6964c3748e4ef6d18e7a175d408ffe184bca01698d0c73f915a9" +checksum = "ad104641f3c958dab30eb3010e834c2622d1f3f4c530fef1dee20ad9485f3c09" dependencies = [ "dtoa", "indexmap", - "serde 1.0.129", + "serde 1.0.130", "yaml-rust", ] [[package]] name = "sha-1" -version = "0.9.7" +version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a0c8611594e2ab4ebbf06ec7cbbf0a99450b8570e96cbf5188b5d5f6ef18d81" +checksum = "99cd6713db3cf16b6c84e06321e049a9b9f699826e16096d23bbcc44d15d51a6" dependencies = [ "block-buffer", "cfg-if 1.0.0", @@ -4115,9 +4059,9 @@ dependencies = [ [[package]] name = "sha2" -version = "0.9.5" +version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b362ae5752fd2137731f9fa25fd4d9058af34666ca1966fb969119cc35719f12" +checksum = "9204c41a1597a8c5af23c82d1c921cb01ec0a4c59e07a9c7306062829a3903f3" 
dependencies = [ "block-buffer", "cfg-if 1.0.0", @@ -4208,7 +4152,7 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6142f7c25e94f6fd25a32c3348ec230df9109b463f59c8c7acc4bd34936babb7" dependencies = [ - "aes-gcm 0.9.3", + "aes-gcm 0.9.4", "blake2", "chacha20poly1305", "rand 0.8.4", @@ -4261,16 +4205,6 @@ dependencies = [ "futures 0.1.31", ] -[[package]] -name = "stream-cipher" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c80e15f898d8d8f25db24c253ea615cc14acf418ff307822995814e7d42cfa89" -dependencies = [ - "block-cipher", - "generic-array", -] - [[package]] name = "strsim" version = "0.8.0" @@ -4431,7 +4365,6 @@ dependencies = [ "tari_common_types", "tari_comms", "tari_core", - "tari_crypto", "tari_wallet", "tonic", "tonic-build", @@ -4458,7 +4391,7 @@ dependencies = [ "tari_p2p", "tari_wallet", "thiserror", - "tokio 0.2.25", + "tokio 1.10.1", "tonic", ] @@ -4490,9 +4423,8 @@ dependencies = [ "tari_p2p", "tari_service_framework", "tari_shutdown", - "tari_wallet", "thiserror", - "tokio 0.2.25", + "tokio 1.10.1", "tonic", "tracing", "tracing-opentelemetry", @@ -4512,7 +4444,7 @@ dependencies = [ "merlin", "rand 0.8.4", "rand_core 0.6.3", - "serde 1.0.129", + "serde 1.0.130", "serde_derive", "sha3", "subtle-ng", @@ -4530,12 +4462,12 @@ dependencies = [ "git2", "log 0.4.14", "log4rs 1.0.0", + "multiaddr", "opentelemetry", "opentelemetry-jaeger", - "parity-multiaddr", "path-clean", "prost-build", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "sha2", "structopt", @@ -4553,10 +4485,11 @@ name = "tari_common_types" version = "0.9.5" dependencies = [ "futures 0.3.16", + "lazy_static 1.4.0", "rand 0.8.4", - "serde 1.0.129", + "serde 1.0.130", "tari_crypto", - "tokio 0.2.25", + "tokio 1.10.1", ] [[package]] @@ -4567,7 +4500,7 @@ dependencies = [ "async-trait", "bitflags 1.3.2", "blake2", - "bytes 0.5.6", + "bytes 1.1.0", "chrono", "cidr", "clear_on_drop", @@ -4578,15 +4511,15 @@ dependencies = [ "lazy_static 1.4.0", "lmdb-zero", "log 0.4.14", + "multiaddr", "nom 5.1.2", "openssl", "opentelemetry", "opentelemetry-jaeger", - "parity-multiaddr", - "pin-project 0.4.28", + "pin-project 1.0.8", "prost", "rand 0.8.4", - "serde 1.0.129", + "serde 1.0.130", "serde_derive", "serde_json", "snow", @@ -4598,10 +4531,10 @@ dependencies = [ "tari_test_utils", "tempfile", "thiserror", - "tokio 0.2.25", - "tokio-macros", - "tokio-util 0.3.1", - "tower", + "tokio 1.10.1", + "tokio-stream", + "tokio-util 0.6.7", + "tower 0.3.1", "tower-make", "tracing", "tracing-futures", @@ -4614,7 +4547,7 @@ version = "0.9.5" dependencies = [ "anyhow", "bitflags 1.3.2", - "bytes 0.4.12", + "bytes 0.5.6", "chacha20", "chrono", "clap", @@ -4623,7 +4556,7 @@ dependencies = [ "digest", "env_logger 0.7.1", "futures 0.3.16", - "futures-test-preview", + "futures-test", "futures-util", "lazy_static 1.4.0", "libsqlite3-sys", @@ -4634,7 +4567,7 @@ dependencies = [ "prost", "prost-types", "rand 0.8.4", - "serde 1.0.129", + "serde 1.0.130", "serde_derive", "serde_repr", "tari_common", @@ -4647,10 +4580,10 @@ dependencies = [ "tari_utilities", "tempfile", "thiserror", - "tokio 0.2.25", - "tokio-macros", - "tokio-test", - "tower", + "tokio 1.10.1", + "tokio-stream", + "tokio-test 0.4.2", + "tower 0.3.1", "tower-test", "ttl_cache", ] @@ -4666,8 +4599,7 @@ dependencies = [ "syn 1.0.75", "tari_comms", "tari_test_utils", - "tokio 0.2.25", - "tokio-macros", + "tokio 1.10.1", "tower-service", ] @@ -4702,7 +4634,7 @@ dependencies = [ "tari_shutdown", 
"tari_wallet", "thiserror", - "tokio 0.2.25", + "tokio 1.10.1", "tonic", "tracing", "tracing-opentelemetry", @@ -4719,7 +4651,7 @@ dependencies = [ "bincode", "bitflags 1.3.2", "blake2", - "bytes 0.4.12", + "bytes 0.5.6", "chrono", "config", "croaring", @@ -4728,17 +4660,18 @@ dependencies = [ "fs2", "futures 0.3.16", "hex", + "lazy_static 1.4.0", "lmdb-zero", "log 0.4.14", "monero", "newtype-ops", - "num", + "num 0.3.1", "num-format", "prost", "prost-types", "rand 0.8.4", "randomx-rs", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "sha3", "strum_macros 0.17.1", @@ -4756,8 +4689,7 @@ dependencies = [ "tari_test_utils", "tempfile", "thiserror", - "tokio 0.2.25", - "tokio-macros", + "tokio 1.10.1", "tracing", "tracing-attributes", "tracing-futures", @@ -4781,7 +4713,7 @@ dependencies = [ "merlin", "rand 0.8.4", "rmp-serde", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "sha2", "sha3", @@ -4806,7 +4738,7 @@ version = "0.9.5" dependencies = [ "digest", "rand 0.8.4", - "serde 1.0.129", + "serde 1.0.130", "serde_derive", "serde_json", "sha2", @@ -4820,7 +4752,7 @@ version = "0.9.5" dependencies = [ "anyhow", "bincode", - "bytes 0.5.6", + "bytes 1.1.0", "chrono", "config", "derive-error", @@ -4828,12 +4760,12 @@ dependencies = [ "futures 0.3.16", "futures-test", "hex", - "hyper 0.13.10", + "hyper 0.14.12", "jsonrpc", "log 0.4.14", "rand 0.8.4", - "reqwest 0.10.10", - "serde 1.0.129", + "reqwest 0.11.4", + "serde 1.0.130", "serde_json", "structopt", "tari_app_grpc", @@ -4843,8 +4775,7 @@ dependencies = [ "tari_crypto", "tari_utilities", "thiserror", - "tokio 0.2.25", - "tokio-macros", + "tokio 1.10.1", "tonic", "tracing", "tracing-futures", @@ -4868,7 +4799,7 @@ dependencies = [ "prost-types", "rand 0.8.4", "reqwest 0.11.4", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "sha3", "tari_app_grpc", @@ -4878,7 +4809,7 @@ dependencies = [ "tari_crypto", "thiserror", "time", - "tokio 0.2.25", + "tokio 1.10.1", "tonic", ] @@ -4893,7 +4824,7 @@ dependencies = [ "digest", "log 0.4.14", "rand 0.8.4", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "tari_crypto", "tari_infra_derive", @@ -4922,7 +4853,7 @@ dependencies = [ "rand 0.8.4", "reqwest 0.10.10", "semver 1.0.4", - "serde 1.0.129", + "serde 1.0.130", "serde_derive", "stream-cancel", "tari_common", @@ -4936,9 +4867,9 @@ dependencies = [ "tari_utilities", "tempfile", "thiserror", - "tokio 0.2.25", - "tokio-macros", - "tower", + "tokio 1.10.1", + "tokio-stream", + "tower 0.3.1", "tower-service", "trust-dns-client", ] @@ -4955,9 +4886,8 @@ dependencies = [ "tari_shutdown", "tari_test_utils", "thiserror", - "tokio 0.2.25", - "tokio-macros", - "tower", + "tokio 1.10.1", + "tower 0.3.1", "tower-service", ] @@ -4966,7 +4896,7 @@ name = "tari_shutdown" version = "0.9.5" dependencies = [ "futures 0.3.16", - "tokio 0.2.25", + "tokio 1.10.1", ] [[package]] @@ -4974,14 +4904,14 @@ name = "tari_storage" version = "0.9.5" dependencies = [ "bincode", - "bytes 0.4.12", + "bytes 0.5.6", "env_logger 0.6.2", "lmdb-zero", "log 0.4.14", "rand 0.8.4", "rmp", "rmp-serde", - "serde 1.0.129", + "serde 1.0.130", "serde_derive", "tari_utilities", "thiserror", @@ -4993,7 +4923,7 @@ version = "0.0.1" dependencies = [ "hex", "libc", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "tari_app_grpc", "tari_common", @@ -5017,12 +4947,12 @@ dependencies = [ "futures 0.3.16", "futures-test", "hex", - "hyper 0.13.10", + "hyper 0.14.12", "jsonrpc", "log 0.4.14", "rand 0.7.3", - "reqwest 0.10.10", - "serde 1.0.129", + "reqwest 0.11.4", + "serde 1.0.130", 
"serde_json", "structopt", "tari_app_grpc", @@ -5031,8 +4961,7 @@ dependencies = [ "tari_crypto", "tari_utilities", "thiserror", - "tokio 0.2.25", - "tokio-macros", + "tokio 1.10.1", "tonic", "tonic-build", "tracing", @@ -5051,7 +4980,7 @@ dependencies = [ "rand 0.8.4", "tari_shutdown", "tempfile", - "tokio 0.2.25", + "tokio 1.10.1", ] [[package]] @@ -5067,7 +4996,7 @@ dependencies = [ "clear_on_drop", "newtype-ops", "rand 0.7.3", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "thiserror", ] @@ -5094,7 +5023,7 @@ dependencies = [ "log4rs 1.0.0", "prost", "rand 0.8.4", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "tari_common_types", "tari_comms", @@ -5110,9 +5039,8 @@ dependencies = [ "tempfile", "thiserror", "time", - "tokio 0.2.25", - "tokio-macros", - "tower", + "tokio 1.10.1", + "tower 0.3.1", ] [[package]] @@ -5140,7 +5068,7 @@ dependencies = [ "tari_wallet", "tempfile", "thiserror", - "tokio 0.2.25", + "tokio 1.10.1", ] [[package]] @@ -5171,12 +5099,12 @@ name = "test_faucet" version = "0.9.5" dependencies = [ "rand 0.8.4", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "tari_core", "tari_crypto", "tari_utilities", - "tokio 0.2.25", + "tokio 1.10.1", ] [[package]] @@ -5190,18 +5118,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.26" +version = "1.0.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93119e4feac1cbe6c798c34d3a53ea0026b0b1de6a120deef895137c0529bfe2" +checksum = "283d5230e63df9608ac7d9691adc1dfb6e701225436eb64d0b9a7f0a5a04f6ec" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.26" +version = "1.0.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "060d69a0afe7796bf42e9e2ff91f5ee691fb15c53d38b4b62a9a53eb23164745" +checksum = "fa3884228611f5cd3608e2d409bf7dce832e4eb3135e3f11addbd7e41bd68e71" dependencies = [ "proc-macro2 1.0.28", "quote 1.0.9", @@ -5276,7 +5204,7 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" dependencies = [ - "serde 1.0.129", + "serde 1.0.130", "serde_json", ] @@ -5306,16 +5234,10 @@ dependencies = [ "futures-core", "iovec", "lazy_static 1.4.0", - "libc", "memchr", "mio 0.6.23", - "mio-uds", - "num_cpus", "pin-project-lite 0.1.12", - "signal-hook-registry", "slab", - "tokio-macros", - "winapi 0.3.9", ] [[package]] @@ -5325,20 +5247,33 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92036be488bb6594459f2e03b60e42df6f937fe6ca5c5ffdcb539c6b84dc40f5" dependencies = [ "autocfg 1.0.1", - "bytes 1.0.1", + "bytes 1.1.0", "libc", "memchr", "mio 0.7.13", "num_cpus", + "once_cell", "pin-project-lite 0.2.7", + "signal-hook-registry", + "tokio-macros", "winapi 0.3.9", ] +[[package]] +name = "tokio-io-timeout" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90c49f106be240de154571dd31fbe48acb10ba6c6dd6f6517ad603abffa42de9" +dependencies = [ + "pin-project-lite 0.2.7", + "tokio 1.10.1", +] + [[package]] name = "tokio-macros" -version = "0.2.6" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e44da00bfc73a25f814cd8d7e57a68a5c31b74b3152a0a1d1f590c97ed06265a" +checksum = "54473be61f4ebe4efd09cec9bd5d16fa51d70ea0192213d754d2d500457db110" dependencies = [ "proc-macro2 1.0.28", "quote 1.0.9", @@ -5355,6 +5290,17 @@ dependencies = [ "tokio 1.10.1", ] +[[package]] +name = "tokio-rustls" +version = 
"0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc6844de72e57df1980054b38be3a9f4702aba4858be64dd700181a8a6d0e1b6" +dependencies = [ + "rustls", + "tokio 1.10.1", + "webpki", +] + [[package]] name = "tokio-stream" version = "0.1.7" @@ -5363,7 +5309,8 @@ checksum = "7b2f3f698253f03119ac0102beaa64f67a67e08074d03a22d18784104543727f" dependencies = [ "futures-core", "pin-project-lite 0.2.7", - "tokio 1.9.0", + "tokio 1.10.1", + "tokio-util 0.6.7", ] [[package]] @@ -5377,6 +5324,19 @@ dependencies = [ "tokio 0.2.25", ] +[[package]] +name = "tokio-test" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53474327ae5e166530d17f2d956afcb4f8a004de581b3cae10f12006bc8163e3" +dependencies = [ + "async-stream", + "bytes 1.1.0", + "futures-core", + "tokio 1.10.1", + "tokio-stream", +] + [[package]] name = "tokio-tls" version = "0.3.1" @@ -5407,8 +5367,9 @@ version = "0.6.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1caa0b0c8d94a049db56b5acf8cba99dc0623aab1b26d5b5f5e2d945846b3592" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "futures-core", + "futures-io", "futures-sink", "log 0.4.14", "pin-project-lite 0.2.7", @@ -5421,7 +5382,7 @@ version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "758664fc71a3a69038656bee8b6be6477d2a6c315a6b81f7081f591bffa4111f" dependencies = [ - "serde 1.0.129", + "serde 1.0.130", ] [[package]] @@ -5430,34 +5391,35 @@ version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a31142970826733df8241ef35dc040ef98c679ab14d7c3e54d827099b3acecaa" dependencies = [ - "serde 1.0.129", + "serde 1.0.130", ] [[package]] name = "tonic" -version = "0.2.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4afef9ce97ea39593992cf3fa00ff33b1ad5eb07665b31355df63a690e38c736" +checksum = "796c5e1cd49905e65dd8e700d4cb1dffcbfdb4fc9d017de08c1a537afd83627c" dependencies = [ "async-stream", "async-trait", - "base64 0.11.0", - "bytes 0.5.6", + "base64 0.13.0", + "bytes 1.1.0", "futures-core", "futures-util", + "h2 0.3.4", "http", - "http-body 0.3.1", - "hyper 0.13.10", + "http-body 0.4.3", + "hyper 0.14.12", + "hyper-timeout", "percent-encoding 2.1.0", - "pin-project 0.4.28", + "pin-project 1.0.8", "prost", "prost-derive", - "tokio 0.2.25", - "tokio-util 0.3.1", - "tower", - "tower-balance", - "tower-load", - "tower-make", + "tokio 1.10.1", + "tokio-stream", + "tokio-util 0.6.7", + "tower 0.4.8", + "tower-layer", "tower-service", "tracing", "tracing-futures", @@ -5465,9 +5427,9 @@ dependencies = [ [[package]] name = "tonic-build" -version = "0.2.0" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71d8d21cb568e802d77055ab7fcd43f0992206de5028de95c8d3a41118d32e8e" +checksum = "12b52d07035516c2b74337d2ac7746075e7dcae7643816c1b12c5ff8a7484c08" dependencies = [ "proc-macro2 1.0.28", "prost-build", @@ -5494,23 +5456,21 @@ dependencies = [ ] [[package]] -name = "tower-balance" -version = "0.3.0" +name = "tower" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a792277613b7052448851efcf98a2c433e6f1d01460832dc60bef676bc275d4c" +checksum = "f60422bc7fefa2f3ec70359b8ff1caff59d785877eb70595904605bcc412470f" dependencies = [ "futures-core", "futures-util", "indexmap", - "pin-project 0.4.28", - "rand 0.7.3", + "pin-project 1.0.8", + "rand 0.8.4", "slab", - "tokio 0.2.25", - 
"tower-discover", + "tokio 1.10.1", + "tokio-stream", + "tokio-util 0.6.7", "tower-layer", - "tower-load", - "tower-make", - "tower-ready-cache", "tower-service", "tracing", ] @@ -5592,21 +5552,6 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ce50370d644a0364bf4877ffd4f76404156a248d104e2cc234cd391ea5cdc965" dependencies = [ - "tokio 0.2.25", - "tower-service", -] - -[[package]] -name = "tower-ready-cache" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4eabb6620e5481267e2ec832c780b31cad0c15dcb14ed825df5076b26b591e1f" -dependencies = [ - "futures-core", - "futures-util", - "indexmap", - "log 0.4.14", - "tokio 0.2.25", "tower-service", ] @@ -5638,7 +5583,7 @@ dependencies = [ "futures-util", "pin-project 0.4.28", "tokio 0.2.25", - "tokio-test", + "tokio-test 0.2.1", "tower-layer", "tower-service", ] @@ -5740,7 +5685,7 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fb65ea441fbb84f9f6748fd496cf7f63ec9af5bca94dd86456978d055e8eb28b" dependencies = [ - "serde 1.0.129", + "serde 1.0.130", "tracing-core", ] @@ -5755,7 +5700,7 @@ dependencies = [ "lazy_static 1.4.0", "matchers", "regex", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "sharded-slab", "smallvec", @@ -5774,47 +5719,54 @@ checksum = "efd1f82c56340fdf16f2a953d7bda4f8fdffba13d93b00844c25572110b26079" [[package]] name = "trust-dns-client" -version = "0.19.7" +version = "0.21.0-alpha.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e935ae5a26a2745fb5a6b95f0e206e1cfb7f00066892d2cf78a8fee87bc2e0c6" +checksum = "37532ce92c75c6174b1d51ed612e26c5fde66ef3f29aa10dbd84e7c5d9a0c27b" dependencies = [ "cfg-if 1.0.0", "chrono", "data-encoding", - "futures 0.3.16", + "futures-channel", + "futures-util", "lazy_static 1.4.0", "log 0.4.14", "radix_trie", - "rand 0.7.3", + "rand 0.8.4", "ring", "rustls", "thiserror", - "tokio 0.2.25", + "tokio 1.10.1", "trust-dns-proto", "webpki", ] [[package]] name = "trust-dns-proto" -version = "0.19.7" +version = "0.21.0-alpha.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cad71a0c0d68ab9941d2fb6e82f8fb2e86d9945b94e1661dd0aaea2b88215a9" +checksum = "4cd23117e93ea0e776abfd8a07c9e389d7ecd3377827858f21bd795ebdfefa36" dependencies = [ "async-trait", - "backtrace", "cfg-if 1.0.0", "data-encoding", "enum-as-inner", - "futures 0.3.16", + "futures-channel", + "futures-io", + "futures-util", "idna 0.2.3", + "ipnet", "lazy_static 1.4.0", "log 0.4.14", - "rand 0.7.3", + "rand 0.8.4", "ring", + "rustls", "smallvec", "thiserror", - "tokio 0.2.25", + "tinyvec", + "tokio 1.10.1", + "tokio-rustls", "url 2.2.2", + "webpki", ] [[package]] @@ -5856,12 +5808,12 @@ dependencies = [ [[package]] name = "twofish" -version = "0.4.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7a30db256d7388f6e08efa0a8e9e62ee34dd1af59706c76c9e8c97c2a500f12" +checksum = "0028f5982f23ecc9a1bc3008ead4c664f843ed5d78acd3d213b99ff50c441bc2" dependencies = [ - "block-cipher", "byteorder", + "cipher 0.2.5", "opaque-debug", ] @@ -5988,15 +5940,9 @@ dependencies = [ [[package]] name = "unsigned-varint" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7fdeedbf205afadfe39ae559b75c3240f24e257d0ca27e85f85cb82aa19ac35" - -[[package]] -name = "unsigned-varint" -version = "0.6.0" +version = "0.7.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "35581ff83d4101e58b582e607120c7f5ffb17e632a980b1f38334d76b36908b2" +checksum = "5f8d425fafb8cd76bc3f22aace4af471d3156301d7508f2107e98fbeae10bc7f" [[package]] name = "untrusted" @@ -6097,7 +6043,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ce9b1b516211d33767048e5d47fa2a381ed8b76fc48d2ce4aa39877f9f183e0" dependencies = [ "cfg-if 1.0.0", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "wasm-bindgen-macro", ] @@ -6187,6 +6133,17 @@ dependencies = [ "libc", ] +[[package]] +name = "which" +version = "4.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea187a8ef279bc014ec368c27a920da2024d2a711109bfbe3440585d5cf27ad9" +dependencies = [ + "either", + "lazy_static 1.4.0", + "libc", +] + [[package]] name = "winapi" version = "0.2.8" @@ -6278,7 +6235,7 @@ dependencies = [ "futures 0.3.16", "log 0.4.14", "nohash-hasher", - "parking_lot 0.11.1", + "parking_lot 0.11.2", "rand 0.8.4", "static_assertions", ] diff --git a/applications/tari_app_grpc/Cargo.toml b/applications/tari_app_grpc/Cargo.toml index bec7e05720..fac860a5f1 100644 --- a/applications/tari_app_grpc/Cargo.toml +++ b/applications/tari_app_grpc/Cargo.toml @@ -10,14 +10,16 @@ edition = "2018" [dependencies] tari_common_types = { version = "^0.9", path = "../../base_layer/common_types"} tari_core = { path = "../../base_layer/core"} -tari_wallet = { path = "../../base_layer/wallet"} -tari_crypto = "0.11.1" +tari_wallet = { path = "../../base_layer/wallet", optional = true} tari_comms = { path = "../../comms"} chrono = "0.4.6" -prost = "0.6" -prost-types = "0.6.1" -tonic = "0.2" +prost = "0.8" +prost-types = "0.8" +tonic = "0.5.2" [build-dependencies] -tonic-build = "0.2" +tonic-build = "0.5.2" + +[features] +wallet = ["tari_wallet"] diff --git a/applications/tari_app_grpc/src/conversions/block_header.rs b/applications/tari_app_grpc/src/conversions/block_header.rs index 3b660f21e1..c89bae5b78 100644 --- a/applications/tari_app_grpc/src/conversions/block_header.rs +++ b/applications/tari_app_grpc/src/conversions/block_header.rs @@ -25,8 +25,12 @@ use crate::{ tari_rpc as grpc, }; use std::convert::TryFrom; -use tari_core::{blocks::BlockHeader, proof_of_work::ProofOfWork, transactions::types::BlindingFactor}; -use tari_crypto::tari_utilities::{ByteArray, Hashable}; +use tari_core::{ + blocks::BlockHeader, + crypto::tari_utilities::{ByteArray, Hashable}, + proof_of_work::ProofOfWork, + transactions::types::BlindingFactor, +}; impl From for grpc::BlockHeader { fn from(h: BlockHeader) -> Self { diff --git a/applications/tari_app_grpc/src/conversions/com_signature.rs b/applications/tari_app_grpc/src/conversions/com_signature.rs index e10e48ffe8..1ffb26f08e 100644 --- a/applications/tari_app_grpc/src/conversions/com_signature.rs +++ b/applications/tari_app_grpc/src/conversions/com_signature.rs @@ -21,7 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
use std::convert::TryFrom; -use tari_crypto::tari_utilities::ByteArray; +use tari_core::crypto::tari_utilities::ByteArray; use crate::tari_rpc as grpc; use tari_core::transactions::types::{ComSignature, Commitment, PrivateKey}; diff --git a/applications/tari_app_grpc/src/conversions/mod.rs b/applications/tari_app_grpc/src/conversions/mod.rs index e404d52d1a..f48a1b876d 100644 --- a/applications/tari_app_grpc/src/conversions/mod.rs +++ b/applications/tari_app_grpc/src/conversions/mod.rs @@ -59,7 +59,7 @@ pub use self::{ use crate::{tari_rpc as grpc, tari_rpc::BlockGroupRequest}; use prost_types::Timestamp; -use tari_crypto::tari_utilities::epoch_time::EpochTime; +use tari_core::crypto::tari_utilities::epoch_time::EpochTime; /// Utility function that converts a `EpochTime` to a `prost::Timestamp` pub fn datetime_to_timestamp(datetime: EpochTime) -> Timestamp { diff --git a/applications/tari_app_grpc/src/conversions/new_block_template.rs b/applications/tari_app_grpc/src/conversions/new_block_template.rs index 7d87900cd1..f4e01677f1 100644 --- a/applications/tari_app_grpc/src/conversions/new_block_template.rs +++ b/applications/tari_app_grpc/src/conversions/new_block_template.rs @@ -24,10 +24,10 @@ use crate::tari_rpc as grpc; use std::convert::{TryFrom, TryInto}; use tari_core::{ blocks::{NewBlockHeaderTemplate, NewBlockTemplate}, + crypto::tari_utilities::ByteArray, proof_of_work::ProofOfWork, transactions::types::BlindingFactor, }; -use tari_crypto::tari_utilities::ByteArray; impl From for grpc::NewBlockTemplate { fn from(block: NewBlockTemplate) -> Self { let header = grpc::NewBlockHeaderTemplate { diff --git a/applications/tari_app_grpc/src/conversions/peer.rs b/applications/tari_app_grpc/src/conversions/peer.rs index f04d3fd8dd..f3bd151d0d 100644 --- a/applications/tari_app_grpc/src/conversions/peer.rs +++ b/applications/tari_app_grpc/src/conversions/peer.rs @@ -22,7 +22,7 @@ use crate::{conversions::datetime_to_timestamp, tari_rpc as grpc}; use tari_comms::{connectivity::ConnectivityStatus, net_address::MutliaddrWithStats, peer_manager::Peer}; -use tari_crypto::tari_utilities::ByteArray; +use tari_core::crypto::tari_utilities::ByteArray; impl From for grpc::Peer { fn from(peer: Peer) -> Self { diff --git a/applications/tari_app_grpc/src/conversions/signature.rs b/applications/tari_app_grpc/src/conversions/signature.rs index d9883a338e..8468491a25 100644 --- a/applications/tari_app_grpc/src/conversions/signature.rs +++ b/applications/tari_app_grpc/src/conversions/signature.rs @@ -21,7 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
use std::convert::TryFrom; -use tari_crypto::tari_utilities::ByteArray; +use tari_core::crypto::tari_utilities::ByteArray; use crate::tari_rpc as grpc; use tari_core::transactions::types::{PrivateKey, PublicKey, Signature}; diff --git a/applications/tari_app_grpc/src/conversions/transaction.rs b/applications/tari_app_grpc/src/conversions/transaction.rs index 97dcaf8795..cb9e91f63a 100644 --- a/applications/tari_app_grpc/src/conversions/transaction.rs +++ b/applications/tari_app_grpc/src/conversions/transaction.rs @@ -22,9 +22,10 @@ use crate::tari_rpc as grpc; use std::convert::{TryFrom, TryInto}; -use tari_core::transactions::transaction::Transaction; -use tari_crypto::{ristretto::RistrettoSecretKey, tari_utilities::ByteArray}; -use tari_wallet::{output_manager_service::TxId, transaction_service::storage::models}; +use tari_core::{ + crypto::{ristretto::RistrettoSecretKey, tari_utilities::ByteArray}, + transactions::transaction::Transaction, +}; impl From for grpc::Transaction { fn from(source: Transaction) -> Self { @@ -53,38 +54,44 @@ impl TryFrom for Transaction { } } -impl From for grpc::TransactionStatus { - fn from(status: models::TransactionStatus) -> Self { - use models::TransactionStatus::*; - match status { - Completed => grpc::TransactionStatus::Completed, - Broadcast => grpc::TransactionStatus::Broadcast, - MinedUnconfirmed => grpc::TransactionStatus::MinedUnconfirmed, - MinedConfirmed => grpc::TransactionStatus::MinedConfirmed, - Imported => grpc::TransactionStatus::Imported, - Pending => grpc::TransactionStatus::Pending, - Coinbase => grpc::TransactionStatus::Coinbase, +#[cfg(feature = "wallet")] +mod wallet { + use super::*; + use tari_wallet::{output_manager_service::TxId, transaction_service::storage::models}; + + impl From for grpc::TransactionStatus { + fn from(status: models::TransactionStatus) -> Self { + use models::TransactionStatus::*; + match status { + Completed => grpc::TransactionStatus::Completed, + Broadcast => grpc::TransactionStatus::Broadcast, + MinedUnconfirmed => grpc::TransactionStatus::MinedUnconfirmed, + MinedConfirmed => grpc::TransactionStatus::MinedConfirmed, + Imported => grpc::TransactionStatus::Imported, + Pending => grpc::TransactionStatus::Pending, + Coinbase => grpc::TransactionStatus::Coinbase, + } } } -} -impl From for grpc::TransactionDirection { - fn from(status: models::TransactionDirection) -> Self { - use models::TransactionDirection::*; - match status { - Unknown => grpc::TransactionDirection::Unknown, - Inbound => grpc::TransactionDirection::Inbound, - Outbound => grpc::TransactionDirection::Outbound, + impl From for grpc::TransactionDirection { + fn from(status: models::TransactionDirection) -> Self { + use models::TransactionDirection::*; + match status { + Unknown => grpc::TransactionDirection::Unknown, + Inbound => grpc::TransactionDirection::Inbound, + Outbound => grpc::TransactionDirection::Outbound, + } } } -} -impl grpc::TransactionInfo { - pub fn not_found(tx_id: TxId) -> Self { - Self { - tx_id, - status: grpc::TransactionStatus::NotFound as i32, - ..Default::default() + impl grpc::TransactionInfo { + pub fn not_found(tx_id: TxId) -> Self { + Self { + tx_id, + status: grpc::TransactionStatus::NotFound as i32, + ..Default::default() + } } } } diff --git a/applications/tari_app_grpc/src/conversions/transaction_input.rs b/applications/tari_app_grpc/src/conversions/transaction_input.rs index ed5793a2f2..692af44856 100644 --- a/applications/tari_app_grpc/src/conversions/transaction_input.rs +++ 
b/applications/tari_app_grpc/src/conversions/transaction_input.rs @@ -22,13 +22,15 @@ use crate::tari_rpc as grpc; use std::convert::{TryFrom, TryInto}; -use tari_core::transactions::{ - transaction::TransactionInput, - types::{Commitment, PublicKey}, -}; -use tari_crypto::{ - script::{ExecutionStack, TariScript}, - tari_utilities::{ByteArray, Hashable}, +use tari_core::{ + crypto::{ + script::{ExecutionStack, TariScript}, + tari_utilities::{ByteArray, Hashable}, + }, + transactions::{ + transaction::TransactionInput, + types::{Commitment, PublicKey}, + }, }; impl TryFrom for TransactionInput { diff --git a/applications/tari_app_grpc/src/conversions/transaction_kernel.rs b/applications/tari_app_grpc/src/conversions/transaction_kernel.rs index e394a6bce5..ac7b4c540d 100644 --- a/applications/tari_app_grpc/src/conversions/transaction_kernel.rs +++ b/applications/tari_app_grpc/src/conversions/transaction_kernel.rs @@ -22,12 +22,14 @@ use crate::tari_rpc as grpc; use std::convert::{TryFrom, TryInto}; -use tari_core::transactions::{ - tari_amount::MicroTari, - transaction::{KernelFeatures, TransactionKernel}, - types::Commitment, +use tari_core::{ + crypto::tari_utilities::{ByteArray, Hashable}, + transactions::{ + tari_amount::MicroTari, + transaction::{KernelFeatures, TransactionKernel}, + types::Commitment, + }, }; -use tari_crypto::tari_utilities::{ByteArray, Hashable}; impl TryFrom for TransactionKernel { type Error = String; diff --git a/applications/tari_app_grpc/src/conversions/transaction_output.rs b/applications/tari_app_grpc/src/conversions/transaction_output.rs index b9556b2940..05b13f41e3 100644 --- a/applications/tari_app_grpc/src/conversions/transaction_output.rs +++ b/applications/tari_app_grpc/src/conversions/transaction_output.rs @@ -22,14 +22,16 @@ use crate::tari_rpc as grpc; use std::convert::{TryFrom, TryInto}; -use tari_core::transactions::{ - bullet_rangeproofs::BulletRangeProof, - transaction::TransactionOutput, - types::{Commitment, PublicKey}, -}; -use tari_crypto::{ - script::TariScript, - tari_utilities::{ByteArray, Hashable}, +use tari_core::{ + crypto::{ + script::TariScript, + tari_utilities::{ByteArray, Hashable}, + }, + transactions::{ + bullet_rangeproofs::BulletRangeProof, + transaction::TransactionOutput, + types::{Commitment, PublicKey}, + }, }; impl TryFrom for TransactionOutput { diff --git a/applications/tari_app_grpc/src/conversions/unblinded_output.rs b/applications/tari_app_grpc/src/conversions/unblinded_output.rs index 94ac4c178d..fbf1276fd3 100644 --- a/applications/tari_app_grpc/src/conversions/unblinded_output.rs +++ b/applications/tari_app_grpc/src/conversions/unblinded_output.rs @@ -22,14 +22,16 @@ use crate::tari_rpc as grpc; use std::convert::{TryFrom, TryInto}; -use tari_core::transactions::{ - tari_amount::MicroTari, - transaction::UnblindedOutput, - types::{PrivateKey, PublicKey}, -}; -use tari_crypto::{ - script::{ExecutionStack, TariScript}, - tari_utilities::ByteArray, +use tari_core::{ + crypto::{ + script::{ExecutionStack, TariScript}, + tari_utilities::ByteArray, + }, + transactions::{ + tari_amount::MicroTari, + transaction::UnblindedOutput, + types::{PrivateKey, PublicKey}, + }, }; impl From for grpc::UnblindedOutput { diff --git a/applications/tari_app_utilities/Cargo.toml b/applications/tari_app_utilities/Cargo.toml index 333af5f959..fb0f1ef840 100644 --- a/applications/tari_app_utilities/Cargo.toml +++ b/applications/tari_app_utilities/Cargo.toml @@ -9,21 +9,21 @@ tari_comms = { path = "../../comms"} tari_crypto = "0.11.1" 
tari_common = { path = "../../common" } tari_p2p = { path = "../../base_layer/p2p", features = ["auto-update"] } -tari_wallet = { path = "../../base_layer/wallet" } +tari_wallet = { path = "../../base_layer/wallet", optional = true } config = { version = "0.9.3" } -futures = { version = "^0.3.1", default-features = false, features = ["alloc"]} +futures = { version = "^0.3.16", default-features = false, features = ["alloc"]} qrcode = { version = "0.12" } dirs-next = "1.0.2" serde_json = "1.0" log = { version = "0.4.8", features = ["std"] } rand = "0.8" -tokio = { version="0.2.10", features = ["signal"] } +tokio = { version="^1.10", features = ["signal"] } structopt = { version = "0.3.13", default_features = false } strum = "^0.19" strum_macros = "^0.19" -thiserror = "^1.0.20" -tonic = "0.2" +thiserror = "^1.0.26" +tonic = "0.5.2" [dependencies.tari_core] path = "../../base_layer/core" @@ -33,3 +33,7 @@ features = ["transactions"] [build-dependencies] tari_common = { path = "../../common", features = ["build", "static-application-info"] } + +[features] +# TODO: This crate is supposed to hold common logic. Move code from this feature into the crate that is more specific to the wallet +wallet = ["tari_wallet"] diff --git a/applications/tari_app_utilities/src/utilities.rs b/applications/tari_app_utilities/src/utilities.rs index d963f212a7..52655eae8d 100644 --- a/applications/tari_app_utilities/src/utilities.rs +++ b/applications/tari_app_utilities/src/utilities.rs @@ -35,13 +35,8 @@ use tari_comms::{ types::CommsPublicKey, utils::multiaddr::multiaddr_to_socketaddr, }; -use tari_core::tari_utilities::hex::Hex; +use tari_core::{tari_utilities::hex::Hex, transactions::emoji::EmojiId}; use tari_p2p::transport::{TorConfig, TransportType}; -use tari_wallet::{ - error::{WalletError, WalletStorageError}, - output_manager_service::error::OutputManagerError, - util::emoji::EmojiId, -}; use thiserror::Error; use tokio::{runtime, runtime::Runtime}; @@ -107,20 +102,6 @@ impl From for ExitCodes { } } -impl From for ExitCodes { - fn from(err: WalletError) -> Self { - error!(target: LOG_TARGET, "{}", err); - Self::WalletError(err.to_string()) - } -} - -impl From for ExitCodes { - fn from(err: OutputManagerError) -> Self { - error!(target: LOG_TARGET, "{}", err); - Self::WalletError(err.to_string()) - } -} - impl From for ExitCodes { fn from(err: ConnectivityError) -> Self { error!(target: LOG_TARGET, "{}", err); @@ -135,13 +116,36 @@ impl From for ExitCodes { } } -impl From for ExitCodes { - fn from(err: WalletStorageError) -> Self { - use WalletStorageError::*; - match err { - NoPasswordError => ExitCodes::NoPassword, - IncorrectPassword => ExitCodes::IncorrectPassword, - e => ExitCodes::WalletError(e.to_string()), +#[cfg(feature = "wallet")] +mod wallet { + use super::*; + use tari_wallet::{ + error::{WalletError, WalletStorageError}, + output_manager_service::error::OutputManagerError, + }; + + impl From for ExitCodes { + fn from(err: WalletError) -> Self { + error!(target: LOG_TARGET, "{}", err); + Self::WalletError(err.to_string()) + } + } + + impl From for ExitCodes { + fn from(err: OutputManagerError) -> Self { + error!(target: LOG_TARGET, "{}", err); + Self::WalletError(err.to_string()) + } + } + + impl From for ExitCodes { + fn from(err: WalletStorageError) -> Self { + use WalletStorageError::*; + match err { + NoPasswordError => ExitCodes::NoPassword, + IncorrectPassword => ExitCodes::IncorrectPassword, + e => ExitCodes::WalletError(e.to_string()), + } } } } @@ -259,26 +263,22 @@ pub fn 
convert_socks_authentication(auth: SocksAuthentication) -> socks::Authent /// ## Returns /// A result containing the runtime on success, string indicating the error on failure pub fn setup_runtime(config: &GlobalConfig) -> Result { - info!( - target: LOG_TARGET, - "Configuring the node to run on up to {} core threads and {} mining threads.", - config.max_threads.unwrap_or(512), - config.num_mining_threads - ); + let mut builder = runtime::Builder::new_multi_thread(); - let mut builder = runtime::Builder::new(); - - if let Some(max_threads) = config.max_threads { - // Ensure that there are always enough threads for mining. - // e.g if the user sets max_threads = 2, mining_threads = 5 then 7 threads are available in total - builder.max_threads(max_threads + config.num_mining_threads); - } if let Some(core_threads) = config.core_threads { - builder.core_threads(core_threads); + info!( + target: LOG_TARGET, + "Configuring the node to run on up to {} core threads.", + config + .core_threads + .as_ref() + .map(ToString::to_string) + .unwrap_or_else(|| "".to_string()), + ); + builder.worker_threads(core_threads); } builder - .threaded_scheduler() .enable_all() .build() .map_err(|e| format!("There was an error while building the node runtime. {}", e.to_string())) diff --git a/applications/tari_base_node/Cargo.toml b/applications/tari_base_node/Cargo.toml index db832f8a9a..ef9b0f48ff 100644 --- a/applications/tari_base_node/Cargo.toml +++ b/applications/tari_base_node/Cargo.toml @@ -11,30 +11,29 @@ edition = "2018" tari_app_grpc = { path = "../tari_app_grpc" } tari_app_utilities = { path = "../tari_app_utilities" } tari_common = { path = "../../common" } -tari_comms = { path = "../../comms", features = ["rpc"]} -tari_comms_dht = { path = "../../comms/dht"} -tari_core = { path = "../../base_layer/core", default-features = false, features = ["transactions"]} +tari_comms = { path = "../../comms", features = ["rpc"] } +tari_comms_dht = { path = "../../comms/dht" } +tari_core = { path = "../../base_layer/core", default-features = false, features = ["transactions"] } tari_crypto = "0.11.1" tari_mmr = { path = "../../base_layer/mmr" } tari_p2p = { path = "../../base_layer/p2p", features = ["auto-update"] } -tari_service_framework = { path = "../../base_layer/service_framework"} -tari_shutdown = { path = "../../infrastructure/shutdown"} -tari_wallet = { path = "../../base_layer/wallet" } +tari_service_framework = { path = "../../base_layer/service_framework" } +tari_shutdown = { path = "../../infrastructure/shutdown" } anyhow = "1.0.32" bincode = "1.3.1" chrono = "0.4" config = { version = "0.9.3" } -futures = { version = "^0.3.1", default-features = false, features = ["alloc"]} +futures = { version = "^0.3.16", default-features = false, features = ["alloc"] } log = { version = "0.4.8", features = ["std"] } regex = "1" rustyline = "6.0" rustyline-derive = "0.3" -tokio = { version="0.2.10", features = ["signal"] } +tokio = { version = "^1.10", features = ["signal"] } strum = "^0.19" strum_macros = "0.18.0" -thiserror = "^1.0.20" -tonic = "0.2" +thiserror = "^1.0.26" +tonic = "0.5.2" tracing = "0.1.26" tracing-opentelemetry = "0.15.0" tracing-subscriber = "0.2.20" @@ -44,7 +43,7 @@ opentelemetry = { version = "0.16", default-features = false, features = ["trace opentelemetry-jaeger = { version="0.15", features=["rt-tokio"]} [features] -avx2 = ["tari_core/avx2", "tari_crypto/avx2", "tari_p2p/avx2", "tari_wallet/avx2", "tari_comms/avx2", "tari_comms_dht/avx2"] +avx2 = ["tari_core/avx2", "tari_crypto/avx2", 
"tari_p2p/avx2", "tari_comms/avx2", "tari_comms_dht/avx2"] safe = [] diff --git a/applications/tari_base_node/src/bootstrap.rs b/applications/tari_base_node/src/bootstrap.rs index cd10870e6d..a2a58dccb2 100644 --- a/applications/tari_base_node/src/bootstrap.rs +++ b/applications/tari_base_node/src/bootstrap.rs @@ -59,7 +59,6 @@ use tari_p2p::{ }; use tari_service_framework::{ServiceHandles, StackBuilder}; use tari_shutdown::ShutdownSignal; -use tokio::runtime; const LOG_TARGET: &str = "c::bn::initialization"; /// The minimum buffer size for the base node pubsub_connector channel @@ -84,8 +83,7 @@ where B: BlockchainBackend + 'static fs::create_dir_all(&config.peer_db_path)?; let buf_size = cmp::max(BASE_NODE_BUFFER_MIN_SIZE, config.buffer_size_base_node); - let (publisher, peer_message_subscriptions) = - pubsub_connector(runtime::Handle::current(), buf_size, config.buffer_rate_limit_base_node); + let (publisher, peer_message_subscriptions) = pubsub_connector(buf_size, config.buffer_rate_limit_base_node); let peer_message_subscriptions = Arc::new(peer_message_subscriptions); let node_config = BaseNodeServiceConfig { diff --git a/applications/tari_base_node/src/builder.rs b/applications/tari_base_node/src/builder.rs index dc36f64f59..e8519e2e74 100644 --- a/applications/tari_base_node/src/builder.rs +++ b/applications/tari_base_node/src/builder.rs @@ -71,10 +71,8 @@ impl BaseNodeContext { pub async fn run(self) { info!(target: LOG_TARGET, "Tari base node has STARTED"); - if let Err(e) = self.state_machine().shutdown_signal().await { - warn!(target: LOG_TARGET, "Error shutting down Base Node State Machine: {}", e); - } - info!(target: LOG_TARGET, "Initiating communications stack shutdown"); + self.state_machine().shutdown_signal().wait().await; + info!(target: LOG_TARGET, "Waiting for communications stack shutdown"); self.base_node_comms.wait_until_shutdown().await; info!(target: LOG_TARGET, "Communications stack has shutdown"); diff --git a/applications/tari_base_node/src/command_handler.rs b/applications/tari_base_node/src/command_handler.rs index 114cee7ad2..255adb9fec 100644 --- a/applications/tari_base_node/src/command_handler.rs +++ b/applications/tari_base_node/src/command_handler.rs @@ -53,11 +53,13 @@ use tari_core::{ mempool::service::LocalMempoolService, proof_of_work::PowAlgorithm, tari_utilities::{hex::Hex, message_format::MessageFormat}, - transactions::types::{Commitment, HashOutput, Signature}, + transactions::{ + emoji::EmojiId, + types::{Commitment, HashOutput, Signature}, + }, }; use tari_crypto::{ristretto::RistrettoPublicKey, tari_utilities::Hashable}; use tari_p2p::auto_update::SoftwareUpdaterHandle; -use tari_wallet::util::emoji::EmojiId; use tokio::{runtime, sync::watch}; pub enum StatusOutput { @@ -101,7 +103,7 @@ impl CommandHandler { } pub fn status(&self, output: StatusOutput) { - let mut state_info = self.state_machine_info.clone(); + let state_info = self.state_machine_info.clone(); let mut node = self.node_service.clone(); let mut mempool = self.mempool_service.clone(); let peer_manager = self.peer_manager.clone(); @@ -115,8 +117,7 @@ impl CommandHandler { let version = format!("v{}", consts::APP_VERSION_NUMBER); status_line.add_field("", version); - let state = state_info.recv().await.unwrap(); - status_line.add_field("State", state.state_info.short_desc()); + status_line.add_field("State", state_info.borrow().state_info.short_desc()); let metadata = node.get_metadata().await.unwrap(); @@ -189,18 +190,8 @@ impl CommandHandler { /// Function to process the 
get-state-info command pub fn state_info(&self) { - let mut channel = self.state_machine_info.clone(); - self.executor.spawn(async move { - match channel.recv().await { - None => { - info!( - target: LOG_TARGET, - "Error communicating with state machine, channel could have been closed" - ); - }, - Some(data) => println!("Current state machine state:\n{}", data), - }; - }); + let watch = self.state_machine_info.clone(); + println!("Current state machine state:\n{}", *watch.borrow()); } /// Check for updates diff --git a/applications/tari_base_node/src/grpc/base_node_grpc_server.rs b/applications/tari_base_node/src/grpc/base_node_grpc_server.rs index 00474a5cb4..f6ee5645c2 100644 --- a/applications/tari_base_node/src/grpc/base_node_grpc_server.rs +++ b/applications/tari_base_node/src/grpc/base_node_grpc_server.rs @@ -26,6 +26,7 @@ use crate::{ helpers::{mean, median}, }, }; +use futures::{channel::mpsc, SinkExt}; use log::*; use std::{ cmp, @@ -40,7 +41,6 @@ use tari_comms::{Bytes, CommsNode}; use tari_core::{ base_node::{ comms_interface::{Broadcast, CommsInterfaceError}, - state_machine_service::states::BlockSyncInfo, LocalNodeCommsInterface, StateMachineHandle, }, @@ -54,7 +54,7 @@ use tari_core::{ }; use tari_crypto::tari_utilities::{message_format::MessageFormat, Hashable}; use tari_p2p::{auto_update::SoftwareUpdaterHandle, services::liveness::LivenessHandle}; -use tokio::{sync::mpsc, task}; +use tokio::task; use tonic::{Request, Response, Status}; const LOG_TARGET: &str = "tari::base_node::grpc"; @@ -995,32 +995,25 @@ impl tari_rpc::base_node_server::BaseNode for BaseNodeGrpcServer { ) -> Result, Status> { debug!(target: LOG_TARGET, "Incoming GRPC request for BN sync data"); - let mut channel = self.state_machine_handle.get_status_info_watch(); - - let mut sync_info: Option = None; - - if let Some(info) = channel.recv().await { - sync_info = info.state_info.get_block_sync_info(); - } - - let mut response = tari_rpc::SyncInfoResponse { - tip_height: 0, - local_height: 0, - peer_node_id: vec![], - }; - - if let Some(info) = sync_info { - let node_ids = info - .sync_peers - .iter() - .map(|x| x.to_string().as_bytes().to_vec()) - .collect(); - response = tari_rpc::SyncInfoResponse { - tip_height: info.tip_height, - local_height: info.local_height, - peer_node_id: node_ids, - }; - } + let response = self + .state_machine_handle + .get_status_info_watch() + .borrow() + .state_info + .get_block_sync_info() + .map(|info| { + let node_ids = info + .sync_peers + .iter() + .map(|x| x.to_string().as_bytes().to_vec()) + .collect(); + tari_rpc::SyncInfoResponse { + tip_height: info.tip_height, + local_height: info.local_height, + peer_node_id: node_ids, + } + }) + .unwrap_or_default(); debug!(target: LOG_TARGET, "Sending SyncData response to client"); Ok(Response::new(response)) diff --git a/applications/tari_base_node/src/main.rs b/applications/tari_base_node/src/main.rs index 8ef2c9b68a..cb9daf1904 100644 --- a/applications/tari_base_node/src/main.rs +++ b/applications/tari_base_node/src/main.rs @@ -96,7 +96,7 @@ mod status_line; mod utils; use crate::command_handler::{CommandHandler, StatusOutput}; -use futures::{future::Fuse, pin_mut, FutureExt}; +use futures::{pin_mut, FutureExt}; use log::*; use opentelemetry::{self, global, KeyValue}; use parser::Parser; @@ -119,7 +119,7 @@ use tari_shutdown::{Shutdown, ShutdownSignal}; use tokio::{ runtime, task, - time::{self, Delay}, + time::{self}, }; use tonic::transport::Server; use tracing_subscriber::{layer::SubscriberExt, Registry}; @@ -145,7 +145,7 
@@ fn main_inner() -> Result<(), ExitCodes> { debug!(target: LOG_TARGET, "Using configuration: {:?}", node_config); // Set up the Tokio runtime - let mut rt = setup_runtime(&node_config).map_err(|e| { + let rt = setup_runtime(&node_config).map_err(|e| { error!(target: LOG_TARGET, "{}", e); ExitCodes::UnknownError })?; @@ -320,26 +320,28 @@ async fn read_command(mut rustyline: Editor<Parser>) -> Result<(String, Editor<Parser>

Fuse { +fn status_interval(start_time: Instant) -> time::Sleep { let duration = match start_time.elapsed().as_secs() { 0..=120 => Duration::from_secs(5), _ => Duration::from_secs(30), }; - time::delay_for(duration).fuse() + time::sleep(duration) } async fn status_loop(command_handler: Arc, shutdown: Shutdown) { let start_time = Instant::now(); let mut shutdown_signal = shutdown.to_signal(); loop { - let mut interval = status_interval(start_time); - futures::select! { + let interval = status_interval(start_time); + tokio::select! { + biased; + _ = shutdown_signal.wait() => { + break; + } + _ = interval => { command_handler.status(StatusOutput::Log); }, - _ = shutdown_signal => { - break; - } } } } @@ -368,9 +370,9 @@ async fn cli_loop(parser: Parser, mut shutdown: Shutdown) { let start_time = Instant::now(); let mut software_update_notif = command_handler.get_software_updater().new_update_notifier().clone(); loop { - let mut interval = status_interval(start_time); - futures::select! { - res = read_command_fut => { + let interval = status_interval(start_time); + tokio::select! { + res = &mut read_command_fut => { match res { Ok((line, mut rustyline)) => { if let Some(p) = rustyline.helper_mut().as_deref_mut() { @@ -387,8 +389,8 @@ async fn cli_loop(parser: Parser, mut shutdown: Shutdown) { } } }, - resp = software_update_notif.recv().fuse() => { - if let Some(Some(update)) = resp { + Ok(_) = software_update_notif.changed() => { + if let Some(ref update) = *software_update_notif.borrow() { println!( "Version {} of the {} is available: {} (sha: {})", update.version(), @@ -401,7 +403,7 @@ async fn cli_loop(parser: Parser, mut shutdown: Shutdown) { _ = interval => { command_handler.status(StatusOutput::Full); }, - _ = shutdown_signal => { + _ = shutdown_signal.wait() => { break; } } diff --git a/applications/tari_console_wallet/Cargo.toml b/applications/tari_console_wallet/Cargo.toml index 5163910d37..f1bfea2fdf 100644 --- a/applications/tari_console_wallet/Cargo.toml +++ b/applications/tari_console_wallet/Cargo.toml @@ -8,18 +8,18 @@ edition = "2018" tari_wallet = { path = "../../base_layer/wallet" } tari_crypto = "0.11.1" tari_common = { path = "../../common" } -tari_app_utilities = { path = "../tari_app_utilities"} +tari_app_utilities = { path = "../tari_app_utilities", features = ["wallet"]} tari_comms = { path = "../../comms"} tari_comms_dht = { path = "../../comms/dht"} tari_p2p = { path = "../../base_layer/p2p" } -tari_app_grpc = { path = "../tari_app_grpc" } +tari_app_grpc = { path = "../tari_app_grpc", features = ["wallet"] } tari_shutdown = { path = "../../infrastructure/shutdown" } tari_key_manager = { path = "../../base_layer/key_manager" } bitflags = "1.2.1" chrono = { version = "0.4.6", features = ["serde"]} chrono-english = "0.1" -futures = { version = "^0.3.1", default-features = false, features = ["alloc"]} +futures = { version = "^0.3.16", default-features = false, features = ["alloc"]} crossterm = { version = "0.17"} rand = "0.8" unicode-width = "0.1" @@ -31,9 +31,9 @@ rpassword = "5.0" rustyline = "6.0" strum = "^0.19" strum_macros = "^0.19" -tokio = { version="0.2.10", features = ["signal"] } -thiserror = "1.0.20" -tonic = "0.2" +tokio = { version="^1.10", features = ["signal"] } +thiserror = "1.0.26" +tonic = "0.5.2" tracing = "0.1.26" tracing-opentelemetry = "0.15.0" diff --git a/applications/tari_console_wallet/src/automation/commands.rs b/applications/tari_console_wallet/src/automation/commands.rs index 608cc8a675..8266e1a4db 100644 --- 
a/applications/tari_console_wallet/src/automation/commands.rs +++ b/applications/tari_console_wallet/src/automation/commands.rs @@ -26,7 +26,7 @@ use crate::{ utils::db::{CUSTOM_BASE_NODE_ADDRESS_KEY, CUSTOM_BASE_NODE_PUBLIC_KEY_KEY}, }; use chrono::{DateTime, Utc}; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt; use log::*; use std::{ fs::File, @@ -45,6 +45,7 @@ use tari_comms_dht::{envelope::NodeDestination, DhtDiscoveryRequester}; use tari_core::{ tari_utilities::hex::Hex, transactions::{ + emoji::EmojiId, tari_amount::{uT, MicroTari, Tari}, transaction::UnblindedOutput, types::PublicKey, @@ -54,12 +55,11 @@ use tari_crypto::ristretto::pedersen::PedersenCommitmentFactory; use tari_wallet::{ output_manager_service::{handle::OutputManagerHandle, TxId}, transaction_service::handle::{TransactionEvent, TransactionServiceHandle}, - util::emoji::EmojiId, WalletSqlite, }; use tokio::{ - sync::mpsc, - time::{delay_for, timeout}, + sync::{broadcast, mpsc}, + time::{sleep, timeout}, }; pub const LOG_TARGET: &str = "wallet::automation::commands"; @@ -175,21 +175,22 @@ pub async fn coin_split( Ok(tx_id) } -async fn wait_for_comms(connectivity_requester: &ConnectivityRequester) -> Result { - let mut connectivity = connectivity_requester.get_event_subscription().fuse(); +async fn wait_for_comms(connectivity_requester: &ConnectivityRequester) -> Result<(), CommandError> { + let mut connectivity = connectivity_requester.get_event_subscription(); print!("Waiting for connectivity... "); - let mut timeout = delay_for(Duration::from_secs(30)).fuse(); + let timeout = sleep(Duration::from_secs(30)); + tokio::pin!(timeout); + let mut timeout = timeout.fuse(); loop { - futures::select! { - result = connectivity.select_next_some() => { - if let Ok(msg) = result { - if let ConnectivityEvent::PeerConnected(_) = (*msg).clone() { - println!("✅"); - return Ok(true); - } + tokio::select! 
{ + // Wait for the first base node connection + Ok(ConnectivityEvent::PeerConnected(conn)) = connectivity.recv() => { + if conn.peer_features().is_node() { + println!("✅"); + return Ok(()); } }, - () = timeout => { + () = &mut timeout => { println!("❌"); return Err(CommandError::Comms("Timed out".to_string())); } @@ -311,7 +312,7 @@ pub async fn make_it_rain( target: LOG_TARGET, "make-it-rain delaying for {:?} ms - scheduled to start at {}", delay_ms, start_time ); - delay_for(Duration::from_millis(delay_ms)).await; + sleep(Duration::from_millis(delay_ms)).await; let num_txs = (txps * duration as f64) as usize; let started_at = Utc::now(); @@ -352,10 +353,10 @@ pub async fn make_it_rain( let target_ms = (i as f64 / (txps / 1000.0)) as i64; if target_ms - actual_ms > 0 { // Maximum delay between Txs set to 120 s - delay_for(Duration::from_millis((target_ms - actual_ms).min(120_000i64) as u64)).await; + sleep(Duration::from_millis((target_ms - actual_ms).min(120_000i64) as u64)).await; } let delayed_for = Instant::now(); - let mut sender_clone = sender.clone(); + let sender_clone = sender.clone(); tokio::task::spawn(async move { let spawn_start = Instant::now(); // Send transaction @@ -432,7 +433,7 @@ pub async fn monitor_transactions( tx_ids: Vec, wait_stage: TransactionStage, ) -> Vec { - let mut event_stream = transaction_service.get_event_stream_fused(); + let mut event_stream = transaction_service.get_event_stream(); let mut results = Vec::new(); debug!(target: LOG_TARGET, "monitor transactions wait_stage: {:?}", wait_stage); println!( @@ -442,104 +443,102 @@ pub async fn monitor_transactions( ); loop { - match event_stream.next().await { - Some(event_result) => match event_result { - Ok(event) => match &*event { - TransactionEvent::TransactionDirectSendResult(id, success) if tx_ids.contains(id) => { - debug!( - target: LOG_TARGET, - "tx direct send event for tx_id: {}, success: {}", *id, success - ); - if wait_stage == TransactionStage::DirectSendOrSaf { - results.push(SentTransaction { - id: *id, - stage: TransactionStage::DirectSendOrSaf, - }); - if results.len() == tx_ids.len() { - break; - } - } - }, - TransactionEvent::TransactionStoreForwardSendResult(id, success) if tx_ids.contains(id) => { - debug!( - target: LOG_TARGET, - "tx store and forward event for tx_id: {}, success: {}", *id, success - ); - if wait_stage == TransactionStage::DirectSendOrSaf { - results.push(SentTransaction { - id: *id, - stage: TransactionStage::DirectSendOrSaf, - }); - if results.len() == tx_ids.len() { - break; - } + match event_stream.recv().await { + Ok(event) => match &*event { + TransactionEvent::TransactionDirectSendResult(id, success) if tx_ids.contains(id) => { + debug!( + target: LOG_TARGET, + "tx direct send event for tx_id: {}, success: {}", *id, success + ); + if wait_stage == TransactionStage::DirectSendOrSaf { + results.push(SentTransaction { + id: *id, + stage: TransactionStage::DirectSendOrSaf, + }); + if results.len() == tx_ids.len() { + break; } - }, - TransactionEvent::ReceivedTransactionReply(id) if tx_ids.contains(id) => { - debug!(target: LOG_TARGET, "tx reply event for tx_id: {}", *id); - if wait_stage == TransactionStage::Negotiated { - results.push(SentTransaction { - id: *id, - stage: TransactionStage::Negotiated, - }); - if results.len() == tx_ids.len() { - break; - } + } + }, + TransactionEvent::TransactionStoreForwardSendResult(id, success) if tx_ids.contains(id) => { + debug!( + target: LOG_TARGET, + "tx store and forward event for tx_id: {}, success: {}", *id, success + 
); + if wait_stage == TransactionStage::DirectSendOrSaf { + results.push(SentTransaction { + id: *id, + stage: TransactionStage::DirectSendOrSaf, + }); + if results.len() == tx_ids.len() { + break; } - }, - TransactionEvent::TransactionBroadcast(id) if tx_ids.contains(id) => { - debug!(target: LOG_TARGET, "tx mempool broadcast event for tx_id: {}", *id); - if wait_stage == TransactionStage::Broadcast { - results.push(SentTransaction { - id: *id, - stage: TransactionStage::Broadcast, - }); - if results.len() == tx_ids.len() { - break; - } + } + }, + TransactionEvent::ReceivedTransactionReply(id) if tx_ids.contains(id) => { + debug!(target: LOG_TARGET, "tx reply event for tx_id: {}", *id); + if wait_stage == TransactionStage::Negotiated { + results.push(SentTransaction { + id: *id, + stage: TransactionStage::Negotiated, + }); + if results.len() == tx_ids.len() { + break; } - }, - TransactionEvent::TransactionMinedUnconfirmed(id, confirmations) if tx_ids.contains(id) => { - debug!( - target: LOG_TARGET, - "tx mined unconfirmed event for tx_id: {}, confirmations: {}", *id, confirmations - ); - if wait_stage == TransactionStage::MinedUnconfirmed { - results.push(SentTransaction { - id: *id, - stage: TransactionStage::MinedUnconfirmed, - }); - if results.len() == tx_ids.len() { - break; - } + } + }, + TransactionEvent::TransactionBroadcast(id) if tx_ids.contains(id) => { + debug!(target: LOG_TARGET, "tx mempool broadcast event for tx_id: {}", *id); + if wait_stage == TransactionStage::Broadcast { + results.push(SentTransaction { + id: *id, + stage: TransactionStage::Broadcast, + }); + if results.len() == tx_ids.len() { + break; } - }, - TransactionEvent::TransactionMined(id) if tx_ids.contains(id) => { - debug!(target: LOG_TARGET, "tx mined confirmed event for tx_id: {}", *id); - if wait_stage == TransactionStage::Mined { - results.push(SentTransaction { - id: *id, - stage: TransactionStage::Mined, - }); - if results.len() == tx_ids.len() { - break; - } + } + }, + TransactionEvent::TransactionMinedUnconfirmed(id, confirmations) if tx_ids.contains(id) => { + debug!( + target: LOG_TARGET, + "tx mined unconfirmed event for tx_id: {}, confirmations: {}", *id, confirmations + ); + if wait_stage == TransactionStage::MinedUnconfirmed { + results.push(SentTransaction { + id: *id, + stage: TransactionStage::MinedUnconfirmed, + }); + if results.len() == tx_ids.len() { + break; } - }, - _ => {}, + } }, - Err(e) => { - eprintln!("RecvError in monitor_transactions: {:?}", e); - break; + TransactionEvent::TransactionMined(id) if tx_ids.contains(id) => { + debug!(target: LOG_TARGET, "tx mined confirmed event for tx_id: {}", *id); + if wait_stage == TransactionStage::Mined { + results.push(SentTransaction { + id: *id, + stage: TransactionStage::Mined, + }); + if results.len() == tx_ids.len() { + break; + } + } }, + _ => {}, }, - None => { - warn!( + // All event senders have gone (i.e. we take it that the node is shutting down) + Err(broadcast::error::RecvError::Closed) => { + debug!( target: LOG_TARGET, - "`None` result in event in monitor_transactions loop" + "All Transaction event senders have gone. Exiting `monitor_transactions` loop." ); break; }, + Err(err) => { + warn!(target: LOG_TARGET, "monitor_transactions: {}", err); + }, } } @@ -578,7 +577,8 @@ pub async fn command_runner( }, DiscoverPeer => { if !online { - online = wait_for_comms(&connectivity_requester).await?; + wait_for_comms(&connectivity_requester).await?; + online = true; } discover_peer(dht_service.clone(), parsed.args).await? 
}, diff --git a/applications/tari_console_wallet/src/grpc/wallet_grpc_server.rs b/applications/tari_console_wallet/src/grpc/wallet_grpc_server.rs index b7e53ae6a2..1e98102851 100644 --- a/applications/tari_console_wallet/src/grpc/wallet_grpc_server.rs +++ b/applications/tari_console_wallet/src/grpc/wallet_grpc_server.rs @@ -1,4 +1,4 @@ -use futures::future; +use futures::{channel::mpsc, future, SinkExt}; use log::*; use std::convert::TryFrom; use tari_app_grpc::{ @@ -41,7 +41,7 @@ use tari_wallet::{ transaction_service::{handle::TransactionServiceHandle, storage::models}, WalletSqlite, }; -use tokio::{sync::mpsc, task}; +use tokio::task; use tonic::{Request, Response, Status}; const LOG_TARGET: &str = "wallet::ui::grpc"; diff --git a/applications/tari_console_wallet/src/init/mod.rs b/applications/tari_console_wallet/src/init/mod.rs index b5c0c9d805..69c6eadfd8 100644 --- a/applications/tari_console_wallet/src/init/mod.rs +++ b/applications/tari_console_wallet/src/init/mod.rs @@ -128,9 +128,15 @@ pub async fn change_password( return Err(ExitCodes::InputError("Passwords don't match!".to_string())); } - wallet.remove_encryption().await?; + wallet + .remove_encryption() + .await + .map_err(|e| ExitCodes::WalletError(e.to_string()))?; - wallet.apply_encryption(passphrase).await?; + wallet + .apply_encryption(passphrase) + .await + .map_err(|e| ExitCodes::WalletError(e.to_string()))?; println!("Wallet password changed successfully."); diff --git a/applications/tari_console_wallet/src/main.rs b/applications/tari_console_wallet/src/main.rs index 4042eb27eb..8dbccac4a0 100644 --- a/applications/tari_console_wallet/src/main.rs +++ b/applications/tari_console_wallet/src/main.rs @@ -58,8 +58,7 @@ fn main() { } fn main_inner() -> Result<(), ExitCodes> { - let mut runtime = tokio::runtime::Builder::new() - .threaded_scheduler() + let runtime = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() .expect("Failed to build a runtime!"); @@ -156,11 +155,8 @@ fn main_inner() -> Result<(), ExitCodes> { }; print!("\nShutting down wallet... "); - if shutdown.trigger().is_ok() { - runtime.block_on(wallet.wait_until_shutdown()); - } else { - error!(target: LOG_TARGET, "No listeners for the shutdown signal!"); - } + shutdown.trigger(); + runtime.block_on(wallet.wait_until_shutdown()); println!("Done."); result diff --git a/applications/tari_console_wallet/src/recovery.rs b/applications/tari_console_wallet/src/recovery.rs index 887995f0a5..e9eca033bc 100644 --- a/applications/tari_console_wallet/src/recovery.rs +++ b/applications/tari_console_wallet/src/recovery.rs @@ -21,7 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use chrono::offset::Local; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt; use log::*; use rustyline::Editor; use tari_app_utilities::utilities::ExitCodes; @@ -35,6 +35,7 @@ use tari_wallet::{ }; use crate::wallet_modes::PeerConfig; +use tokio::sync::broadcast; pub const LOG_TARGET: &str = "wallet::recovery"; @@ -97,13 +98,13 @@ pub async fn wallet_recovery(wallet: &WalletSqlite, base_node_config: &PeerConfi .with_retry_limit(10) .build_with_wallet(wallet, shutdown_signal); - let mut event_stream = recovery_task.get_event_receiver().fuse(); + let mut event_stream = recovery_task.get_event_receiver(); let recovery_join_handle = tokio::spawn(recovery_task.run()).fuse(); // Read recovery task events. The event stream will end once recovery has completed. 
- while let Some(event) = event_stream.next().await { - match event { + loop { + match event_stream.recv().await { Ok(UtxoScannerEvent::ConnectingToBaseNode(peer)) => { print!("Connecting to base node {}... ", peer); }, @@ -170,11 +171,13 @@ pub async fn wallet_recovery(wallet: &WalletSqlite, base_node_config: &PeerConfi info!(target: LOG_TARGET, "{}", stats); println!("{}", stats); }, - Err(e) => { - // Can occur if we read events too slowly (lagging/slow subscriber) + Err(e @ broadcast::error::RecvError::Lagged(_)) => { debug!(target: LOG_TARGET, "Error receiving Wallet recovery events: {}", e); continue; }, + Err(broadcast::error::RecvError::Closed) => { + break; + }, Ok(UtxoScannerEvent::ScanningFailed) => { error!(target: LOG_TARGET, "Wallet Recovery process failed and is exiting"); }, diff --git a/applications/tari_console_wallet/src/ui/components/base_node.rs b/applications/tari_console_wallet/src/ui/components/base_node.rs index c51421d89f..ade233b90c 100644 --- a/applications/tari_console_wallet/src/ui/components/base_node.rs +++ b/applications/tari_console_wallet/src/ui/components/base_node.rs @@ -57,7 +57,7 @@ impl Component for BaseNode { ]), OnlineStatus::Online => { let base_node_state = app_state.get_base_node_state(); - if let Some(metadata) = base_node_state.clone().chain_metadata { + if let Some(ref metadata) = base_node_state.chain_metadata { let tip = metadata.height_of_longest_chain(); let synced = base_node_state.is_synced.unwrap_or_default(); @@ -93,7 +93,7 @@ impl Component for BaseNode { Spans::from(vec![ Span::styled("Chain Tip:", Style::default().fg(Color::Magenta)), Span::raw(" "), - Span::styled("Waiting for data...", Style::default().fg(Color::White)), + Span::styled("Waiting for data...", Style::default().fg(Color::DarkGray)), ]) } }, diff --git a/applications/tari_console_wallet/src/ui/state/app_state.rs b/applications/tari_console_wallet/src/ui/state/app_state.rs index 4736151b9d..be85727e07 100644 --- a/applications/tari_console_wallet/src/ui/state/app_state.rs +++ b/applications/tari_console_wallet/src/ui/state/app_state.rs @@ -34,7 +34,6 @@ use crate::{ wallet_modes::PeerConfig, }; use bitflags::bitflags; -use futures::{stream::Fuse, StreamExt}; use log::*; use qrcode::{render::unicode, QrCode}; use std::{ @@ -51,6 +50,7 @@ use tari_comms::{ NodeIdentity, }; use tari_core::transactions::{ + emoji::EmojiId, tari_amount::{uT, MicroTari}, types::PublicKey, }; @@ -66,7 +66,6 @@ use tari_wallet::{ storage::models::{CompletedTransaction, TransactionStatus}, }, types::ValidationRetryStrategy, - util::emoji::EmojiId, WalletSqlite, }; use tokio::{ @@ -648,24 +647,24 @@ impl AppStateInner { self.wallet.comms.shutdown_signal() } - pub fn get_transaction_service_event_stream(&self) -> Fuse { - self.wallet.transaction_service.get_event_stream_fused() + pub fn get_transaction_service_event_stream(&self) -> TransactionEventReceiver { + self.wallet.transaction_service.get_event_stream() } - pub fn get_output_manager_service_event_stream(&self) -> Fuse { - self.wallet.output_manager_service.get_event_stream_fused() + pub fn get_output_manager_service_event_stream(&self) -> OutputManagerEventReceiver { + self.wallet.output_manager_service.get_event_stream() } - pub fn get_connectivity_event_stream(&self) -> Fuse { - self.wallet.comms.connectivity().get_event_subscription().fuse() + pub fn get_connectivity_event_stream(&self) -> ConnectivityEventRx { + self.wallet.comms.connectivity().get_event_subscription() } pub fn get_wallet_connectivity(&self) -> 
WalletConnectivityHandle { self.wallet.wallet_connectivity.clone() } - pub fn get_base_node_event_stream(&self) -> Fuse { - self.wallet.base_node_service.clone().get_event_stream_fused() + pub fn get_base_node_event_stream(&self) -> BaseNodeEventReceiver { + self.wallet.base_node_service.get_event_stream() } pub async fn set_base_node_peer(&mut self, peer: Peer) -> Result<(), UiError> { diff --git a/applications/tari_console_wallet/src/ui/state/tasks.rs b/applications/tari_console_wallet/src/ui/state/tasks.rs index caf8073f56..85243660e6 100644 --- a/applications/tari_console_wallet/src/ui/state/tasks.rs +++ b/applications/tari_console_wallet/src/ui/state/tasks.rs @@ -21,11 +21,10 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::ui::{state::UiTransactionSendStatus, UiError}; -use futures::StreamExt; use tari_comms::types::CommsPublicKey; use tari_core::transactions::tari_amount::MicroTari; use tari_wallet::transaction_service::handle::{TransactionEvent, TransactionServiceHandle}; -use tokio::sync::watch; +use tokio::sync::{broadcast, watch}; const LOG_TARGET: &str = "wallet::console_wallet::tasks "; @@ -37,8 +36,8 @@ pub async fn send_transaction_task( mut transaction_service_handle: TransactionServiceHandle, result_tx: watch::Sender, ) { - let _ = result_tx.broadcast(UiTransactionSendStatus::Initiated); - let mut event_stream = transaction_service_handle.get_event_stream_fused(); + let _ = result_tx.send(UiTransactionSendStatus::Initiated); + let mut event_stream = transaction_service_handle.get_event_stream(); let mut send_direct_received_result = (false, false); let mut send_saf_received_result = (false, false); match transaction_service_handle @@ -46,15 +45,15 @@ pub async fn send_transaction_task( .await { Err(e) => { - let _ = result_tx.broadcast(UiTransactionSendStatus::Error(UiError::from(e).to_string())); + let _ = result_tx.send(UiTransactionSendStatus::Error(UiError::from(e).to_string())); }, Ok(our_tx_id) => { - while let Some(event_result) = event_stream.next().await { - match event_result { + loop { + match event_stream.recv().await { Ok(event) => match &*event { TransactionEvent::TransactionDiscoveryInProgress(tx_id) => { if our_tx_id == *tx_id { - let _ = result_tx.broadcast(UiTransactionSendStatus::DiscoveryInProgress); + let _ = result_tx.send(UiTransactionSendStatus::DiscoveryInProgress); } }, TransactionEvent::TransactionDirectSendResult(tx_id, result) => { @@ -75,25 +74,28 @@ pub async fn send_transaction_task( }, TransactionEvent::TransactionCompletedImmediately(tx_id) => { if our_tx_id == *tx_id { - let _ = result_tx.broadcast(UiTransactionSendStatus::TransactionComplete); + let _ = result_tx.send(UiTransactionSendStatus::TransactionComplete); return; } }, _ => (), }, - Err(e) => { + Err(e @ broadcast::error::RecvError::Lagged(_)) => { log::warn!(target: LOG_TARGET, "Error reading from event broadcast channel {:?}", e); + continue; + }, + Err(broadcast::error::RecvError::Closed) => { break; }, } } if send_direct_received_result.1 { - let _ = result_tx.broadcast(UiTransactionSendStatus::SentDirect); + let _ = result_tx.send(UiTransactionSendStatus::SentDirect); } else if send_saf_received_result.1 { - let _ = result_tx.broadcast(UiTransactionSendStatus::SentViaSaf); + let _ = result_tx.send(UiTransactionSendStatus::SentViaSaf); } else { - let _ = result_tx.broadcast(UiTransactionSendStatus::Error( + let _ = result_tx.send(UiTransactionSendStatus::Error( "Transaction could not be sent".to_string(), )); } @@ -109,34 +111,37 @@ 
pub async fn send_one_sided_transaction_task( mut transaction_service_handle: TransactionServiceHandle, result_tx: watch::Sender, ) { - let _ = result_tx.broadcast(UiTransactionSendStatus::Initiated); - let mut event_stream = transaction_service_handle.get_event_stream_fused(); + let _ = result_tx.send(UiTransactionSendStatus::Initiated); + let mut event_stream = transaction_service_handle.get_event_stream(); match transaction_service_handle .send_one_sided_transaction(public_key, amount, fee_per_gram, message) .await { Err(e) => { - let _ = result_tx.broadcast(UiTransactionSendStatus::Error(UiError::from(e).to_string())); + let _ = result_tx.send(UiTransactionSendStatus::Error(UiError::from(e).to_string())); }, Ok(our_tx_id) => { - while let Some(event_result) = event_stream.next().await { - match event_result { + loop { + match event_stream.recv().await { Ok(event) => { if let TransactionEvent::TransactionCompletedImmediately(tx_id) = &*event { if our_tx_id == *tx_id { - let _ = result_tx.broadcast(UiTransactionSendStatus::TransactionComplete); + let _ = result_tx.send(UiTransactionSendStatus::TransactionComplete); return; } } }, - Err(e) => { + Err(e @ broadcast::error::RecvError::Lagged(_)) => { log::warn!(target: LOG_TARGET, "Error reading from event broadcast channel {:?}", e); + continue; + }, + Err(broadcast::error::RecvError::Closed) => { break; }, } } - let _ = result_tx.broadcast(UiTransactionSendStatus::Error( + let _ = result_tx.send(UiTransactionSendStatus::Error( "One-sided transaction could not be sent".to_string(), )); }, diff --git a/applications/tari_console_wallet/src/ui/state/wallet_event_monitor.rs b/applications/tari_console_wallet/src/ui/state/wallet_event_monitor.rs index 2e20999667..e7df30b653 100644 --- a/applications/tari_console_wallet/src/ui/state/wallet_event_monitor.rs +++ b/applications/tari_console_wallet/src/ui/state/wallet_event_monitor.rs @@ -21,7 +21,6 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{notifier::Notifier, ui::state::AppStateInner}; -use futures::stream::StreamExt; use log::*; use std::sync::Arc; use tari_comms::{connectivity::ConnectivityEvent, peer_manager::Peer}; @@ -30,7 +29,7 @@ use tari_wallet::{ output_manager_service::{handle::OutputManagerEvent, TxId}, transaction_service::handle::TransactionEvent, }; -use tokio::sync::RwLock; +use tokio::sync::{broadcast, RwLock}; const LOG_TARGET: &str = "wallet::console_wallet::wallet_event_monitor"; @@ -55,14 +54,14 @@ impl WalletEventMonitor { let mut connectivity_events = self.app_state_inner.read().await.get_connectivity_event_stream(); let wallet_connectivity = self.app_state_inner.read().await.get_wallet_connectivity(); - let mut connectivity_status = wallet_connectivity.get_connectivity_status_watch().fuse(); + let mut connectivity_status = wallet_connectivity.get_connectivity_status_watch(); let mut base_node_events = self.app_state_inner.read().await.get_base_node_event_stream(); info!(target: LOG_TARGET, "Wallet Event Monitor starting"); loop { - futures::select! { - result = transaction_service_events.select_next_some() => { + tokio::select! 
{ + result = transaction_service_events.recv() => { match result { Ok(msg) => { trace!(target: LOG_TARGET, "Wallet Event Monitor received wallet transaction service event {:?}", msg); @@ -104,18 +103,21 @@ impl WalletEventMonitor { _ => (), } }, - Err(_) => debug!(target: LOG_TARGET, "Lagging read on Transaction Service event broadcast channel"), + Err(broadcast::error::RecvError::Lagged(n)) => { + warn!(target: LOG_TARGET, "Missed {} from Transaction events", n); + } + Err(broadcast::error::RecvError::Closed) => {} } }, - status = connectivity_status.select_next_some() => { - trace!(target: LOG_TARGET, "Wallet Event Monitor received wallet connectivity status {:?}", status); + Ok(_) = connectivity_status.changed() => { + trace!(target: LOG_TARGET, "Wallet Event Monitor received wallet connectivity status changed"); self.trigger_peer_state_refresh().await; }, - result = connectivity_events.select_next_some() => { + result = connectivity_events.recv() => { match result { Ok(msg) => { trace!(target: LOG_TARGET, "Wallet Event Monitor received wallet connectivity event {:?}", msg); - match &*msg { + match msg { ConnectivityEvent::PeerDisconnected(_) | ConnectivityEvent::ManagedPeerDisconnected(_) | ConnectivityEvent::PeerConnected(_) => { @@ -125,10 +127,13 @@ impl WalletEventMonitor { _ => (), } }, - Err(_) => debug!(target: LOG_TARGET, "Lagging read on Connectivity event broadcast channel"), + Err(broadcast::error::RecvError::Lagged(n)) => { + warn!(target: LOG_TARGET, "Missed {} from Connectivity events", n); + } + Err(broadcast::error::RecvError::Closed) => {} } }, - result = base_node_events.select_next_some() => { + result = base_node_events.recv() => { match result { Ok(msg) => { trace!(target: LOG_TARGET, "Wallet Event Monitor received base node event {:?}", msg); @@ -141,10 +146,13 @@ impl WalletEventMonitor { } } }, - Err(_) => debug!(target: LOG_TARGET, "Lagging read on base node event broadcast channel"), + Err(broadcast::error::RecvError::Lagged(n)) => { + warn!(target: LOG_TARGET, "Missed {} from Base node Service events", n); + } + Err(broadcast::error::RecvError::Closed) => {} } }, - result = output_manager_service_events.select_next_some() => { + result = output_manager_service_events.recv() => { match result { Ok(msg) => { trace!(target: LOG_TARGET, "Output Manager Service Callback Handler event {:?}", msg); @@ -152,14 +160,13 @@ impl WalletEventMonitor { self.trigger_balance_refresh().await; } }, - Err(_e) => error!(target: LOG_TARGET, "Error reading from Output Manager Service event broadcast channel"), + Err(broadcast::error::RecvError::Lagged(n)) => { + warn!(target: LOG_TARGET, "Missed {} from Output Manager Service events", n); + } + Err(broadcast::error::RecvError::Closed) => {} } - }, - complete => { - info!(target: LOG_TARGET, "Wallet Event Monitor is exiting because all tasks have completed"); - break; }, - _ = shutdown_signal => { + _ = shutdown_signal.wait() => { info!(target: LOG_TARGET, "Wallet Event Monitor shutting down because the shutdown signal was received"); break; }, diff --git a/applications/tari_console_wallet/src/ui/ui_contact.rs b/applications/tari_console_wallet/src/ui/ui_contact.rs index 2d8be4a182..d87fc0de57 100644 --- a/applications/tari_console_wallet/src/ui/ui_contact.rs +++ b/applications/tari_console_wallet/src/ui/ui_contact.rs @@ -1,4 +1,5 @@ -use tari_wallet::{contacts_service::storage::database::Contact, util::emoji::EmojiId}; +use tari_core::transactions::emoji::EmojiId; +use tari_wallet::contacts_service::storage::database::Contact; 
#[derive(Debug, Clone)] pub struct UiContact { diff --git a/applications/tari_console_wallet/src/wallet_modes.rs b/applications/tari_console_wallet/src/wallet_modes.rs index a205523a00..23b1ee6330 100644 --- a/applications/tari_console_wallet/src/wallet_modes.rs +++ b/applications/tari_console_wallet/src/wallet_modes.rs @@ -239,7 +239,10 @@ pub fn tui_mode(config: WalletModeConfig, mut wallet: WalletSqlite) -> Result<() info!(target: LOG_TARGET, "Starting app"); - handle.enter(|| ui::run(app))?; + { + let _enter = handle.enter(); + ui::run(app)?; + } info!( target: LOG_TARGET, diff --git a/applications/tari_merge_mining_proxy/Cargo.toml b/applications/tari_merge_mining_proxy/Cargo.toml index 626d0502b6..2cd31b8b3d 100644 --- a/applications/tari_merge_mining_proxy/Cargo.toml +++ b/applications/tari_merge_mining_proxy/Cargo.toml @@ -13,33 +13,32 @@ envlog = ["env_logger"] [dependencies] tari_app_grpc = { path = "../tari_app_grpc" } -tari_common = { path = "../../common" } -tari_core = { path = "../../base_layer/core", default-features = false, features = ["transactions"]} -tari_app_utilities = { path = "../tari_app_utilities"} +tari_common = { path = "../../common" } +tari_core = { path = "../../base_layer/core", default-features = false, features = ["transactions"] } +tari_app_utilities = { path = "../tari_app_utilities" } tari_crypto = "0.11.1" tari_utilities = "^0.3" anyhow = "1.0.40" bincode = "1.3.1" -bytes = "0.5.6" +bytes = "1.1" chrono = "0.4.19" config = { version = "0.9.3" } derive-error = "0.0.4" env_logger = { version = "0.7.1", optional = true } futures = "0.3.5" hex = "0.4.2" -hyper = "0.13.7" +hyper = "0.14.12" jsonrpc = "0.11.0" log = { version = "0.4.8", features = ["std"] } rand = "0.8" -reqwest = {version = "0.10.8", features=["json"]} -serde = { version="1.0.106", features = ["derive"] } +reqwest = { version = "0.11.4", features = ["json"] } +serde = { version = "1.0.106", features = ["derive"] } serde_json = "1.0.57" structopt = { version = "0.3.13", default_features = false } -thiserror = "1.0.15" -tokio = "0.2.10" -tokio-macros = "0.2.5" -tonic = "0.2" +thiserror = "1.0.26" +tokio = { version = "1.10", features = ["macros"] } +tonic = "0.5.2" tracing = "0.1" tracing-futures = "0.2" tracing-subscriber = "0.2" diff --git a/applications/tari_merge_mining_proxy/src/main.rs b/applications/tari_merge_mining_proxy/src/main.rs index 7e15777977..0df27490bb 100644 --- a/applications/tari_merge_mining_proxy/src/main.rs +++ b/applications/tari_merge_mining_proxy/src/main.rs @@ -46,7 +46,7 @@ use tari_app_utilities::initialization::init_configuration; use tari_common::configuration::bootstrap::ApplicationType; use tokio::time::Duration; -#[tokio_macros::main] +#[tokio::main] async fn main() -> Result<(), anyhow::Error> { let (_, config, _) = init_configuration(ApplicationType::MergeMiningProxy)?; diff --git a/applications/tari_mining_node/Cargo.toml b/applications/tari_mining_node/Cargo.toml index a04938c048..f893c0bc06 100644 --- a/applications/tari_mining_node/Cargo.toml +++ b/applications/tari_mining_node/Cargo.toml @@ -17,15 +17,15 @@ crossbeam = "0.8" futures = "0.3" log = { version = "0.4", features = ["std"] } num_cpus = "1.13" -prost-types = "0.6" +prost-types = "0.8" rand = "0.8" sha3 = "0.9" serde = { version = "1.0", default_features = false, features = ["derive"] } -tonic = { version = "0.2", features = ["transport"] } -tokio = { version = "0.2", default_features = false, features = ["rt-core"] } +tonic = { version = "0.5.2", features = ["transport"] } +tokio = { 
version = "1.10", default_features = false, features = ["rt-multi-thread"] } thiserror = "1.0" jsonrpc = "0.11.0" -reqwest = { version = "0.11", features = ["blocking", "json"] } +reqwest = { version = "0.11", features = [ "json"] } serde_json = "1.0.57" native-tls = "0.2" bufstream = "0.1" @@ -35,5 +35,5 @@ hex = "0.4.2" [dev-dependencies] tari_crypto = "0.11.1" -prost-types = "0.6.1" +prost-types = "0.8" chrono = "0.4" diff --git a/applications/tari_mining_node/src/main.rs b/applications/tari_mining_node/src/main.rs index b1f538c1e4..0a0bf31557 100644 --- a/applications/tari_mining_node/src/main.rs +++ b/applications/tari_mining_node/src/main.rs @@ -23,13 +23,6 @@ use config::MinerConfig; use futures::stream::StreamExt; use log::*; -use tari_app_grpc::tari_rpc::{base_node_client::BaseNodeClient, wallet_client::WalletClient}; -use tari_app_utilities::{initialization::init_configuration, utilities::ExitCodes}; -use tari_common::{configuration::bootstrap::ApplicationType, ConfigBootstrap, DefaultConfigLoader, GlobalConfig}; -use tari_core::blocks::BlockHeader; -use tokio::{runtime::Runtime, time::delay_for}; -use tonic::transport::Channel; -use utils::{coinbase_request, extract_outputs_and_kernels}; mod config; mod difficulty; @@ -53,10 +46,17 @@ use std::{ thread, time::Instant, }; +use tari_app_grpc::tari_rpc::{base_node_client::BaseNodeClient, wallet_client::WalletClient}; +use tari_app_utilities::{initialization::init_configuration, utilities::ExitCodes}; +use tari_common::{configuration::bootstrap::ApplicationType, ConfigBootstrap, DefaultConfigLoader, GlobalConfig}; +use tari_core::blocks::BlockHeader; +use tokio::{runtime::Runtime, time::sleep}; +use tonic::transport::Channel; +use utils::{coinbase_request, extract_outputs_and_kernels}; /// Application entry point fn main() { - let mut rt = Runtime::new().expect("Failed to start tokio runtime"); + let rt = Runtime::new().expect("Failed to start tokio runtime"); match rt.block_on(main_inner()) { Ok(_) => std::process::exit(0), Err(exit_code) => { @@ -144,7 +144,7 @@ async fn main_inner() -> Result<(), ExitCodes> { error!("Connection error: {:?}", err); loop { debug!("Holding for {:?}", config.wait_timeout()); - delay_for(config.wait_timeout()).await; + sleep(config.wait_timeout()).await; match connect(&config, &global).await { Ok((nc, wc)) => { node_conn = nc; @@ -168,7 +168,7 @@ async fn main_inner() -> Result<(), ExitCodes> { Err(err) => { error!("Error: {:?}", err); debug!("Holding for {:?}", config.wait_timeout()); - delay_for(config.wait_timeout()).await; + sleep(config.wait_timeout()).await; }, Ok(submitted) => { if submitted { diff --git a/applications/tari_stratum_transcoder/Cargo.toml b/applications/tari_stratum_transcoder/Cargo.toml index 29f95c82da..d29b91ead7 100644 --- a/applications/tari_stratum_transcoder/Cargo.toml +++ b/applications/tari_stratum_transcoder/Cargo.toml @@ -13,37 +13,37 @@ envlog = ["env_logger"] [dependencies] tari_app_grpc = { path = "../tari_app_grpc" } -tari_common = { path = "../../common" } -tari_core = { path = "../../base_layer/core", default-features = false, features = ["transactions"]} +tari_common = { path = "../../common" } +tari_core = { path = "../../base_layer/core", default-features = false, features = ["transactions"] } tari_crypto = "0.11.1" tari_utilities = "^0.3" + bincode = "1.3.1" -bytes = "0.5.6" +bytes = "0.5" chrono = "0.4.19" config = { version = "0.9.3" } derive-error = "0.0.4" env_logger = { version = "0.7.1", optional = true } futures = "0.3.5" hex = "0.4.2" -hyper = 
"0.13.7" +hyper = "0.14.12" jsonrpc = "0.11.0" log = { version = "0.4.8", features = ["std"] } rand = "0.7.2" -reqwest = {version = "0.10.8", features=["json"]} -serde = { version="1.0.106", features = ["derive"] } +reqwest = { version = "0.11", features = ["json"] } +serde = { version = "1.0.106", features = ["derive"] } serde_json = "1.0.57" structopt = { version = "0.3.13", default_features = false } -thiserror = "1.0.15" -tokio = "0.2.10" -tokio-macros = "0.2.5" -tonic = "0.2" +thiserror = "1.0.26" +tokio = { version = "^1.10", features = ["macros"] } +tonic = "0.5.2" tracing = "0.1" tracing-futures = "0.2" tracing-subscriber = "0.2" url = "2.1.1" [build-dependencies] -tonic-build = "0.2" +tonic-build = "0.5.2" [dev-dependencies] futures-test = "0.3.5" diff --git a/applications/tari_stratum_transcoder/src/main.rs b/applications/tari_stratum_transcoder/src/main.rs index d55d551b5b..f742c92d6d 100644 --- a/applications/tari_stratum_transcoder/src/main.rs +++ b/applications/tari_stratum_transcoder/src/main.rs @@ -41,7 +41,7 @@ use tari_app_grpc::tari_rpc as grpc; use tari_common::{configuration::bootstrap::ApplicationType, ConfigBootstrap, GlobalConfig}; use tokio::time::Duration; -#[tokio_macros::main] +#[tokio::main] async fn main() -> Result<(), StratumTranscoderProxyError> { let config = initialize()?; diff --git a/applications/test_faucet/Cargo.toml b/applications/test_faucet/Cargo.toml index 3ef3c8a4c1..5fbdbaf11b 100644 --- a/applications/test_faucet/Cargo.toml +++ b/applications/test_faucet/Cargo.toml @@ -7,11 +7,12 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +tari_crypto = "0.11.1" tari_utilities = "^0.3" + +rand = "0.8" serde = { version = "1.0.97", features = ["derive"] } serde_json = "1.0" -rand = "0.8" -tari_crypto = "0.11.1" [dependencies.tari_core] version = "^0.9" @@ -20,6 +21,6 @@ default-features = false features = ["transactions", "avx2"] [dependencies.tokio] -version = "^0.2.10" +version = "^1.10" default-features = false -features = ["fs", "blocking", "stream", "rt-threaded", "macros", "io-util", "sync"] +features = ["fs", "rt-multi-thread", "macros", "io-util", "sync"] diff --git a/applications/test_faucet/src/main.rs b/applications/test_faucet/src/main.rs index 0aadee4d5e..5c5590d0a8 100644 --- a/applications/test_faucet/src/main.rs +++ b/applications/test_faucet/src/main.rs @@ -32,7 +32,7 @@ struct Key { /// UTXO generation is pretty slow (esp range proofs), so we'll use async threads to speed things up. /// We'll use blocking thread tasks to do the CPU intensive utxo generation, and then push the results /// through a channel where a file-writer is waiting to persist the results to disk. -#[tokio::main(core_threads = 2, max_threads = 10)] +#[tokio::main(worker_threads = 2)] async fn main() -> Result<(), Box> { let num_keys: usize = std::env::args() .skip(1) @@ -52,7 +52,7 @@ async fn main() -> Result<(), Box> { // Use Rust's awesome Iterator trait to produce a sequence of values and output features. for (value, feature) in values.take(num_keys).zip(features.take(num_keys)) { let fc = factories.clone(); - let mut txc = tx.clone(); + let txc = tx.clone(); // Notice the `spawn(.. spawn_blocking)` nested call here. If we don't do this, we're basically queuing up // blocking tasks, `await`ing them to finish, and then queueing up the next one. In effect we're running things // synchronously. 
diff --git a/base_layer/common_types/Cargo.toml b/base_layer/common_types/Cargo.toml index 2b85d4e9a9..5042a1ddfc 100644 --- a/base_layer/common_types/Cargo.toml +++ b/base_layer/common_types/Cargo.toml @@ -11,4 +11,5 @@ futures = {version = "^0.3.1", features = ["async-await"] } rand = "0.8" tari_crypto = "0.11.1" serde = { version = "1.0.106", features = ["derive"] } -tokio = { version="^0.2", features = ["blocking", "time", "sync"] } +tokio = { version="^1.10", features = [ "time", "sync"] } +lazy_static = "1.4.0" diff --git a/base_layer/common_types/src/waiting_requests.rs b/base_layer/common_types/src/waiting_requests.rs index a26119a5cb..67e6eed6ef 100644 --- a/base_layer/common_types/src/waiting_requests.rs +++ b/base_layer/common_types/src/waiting_requests.rs @@ -20,10 +20,9 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::channel::oneshot::Sender as OneshotSender; use rand::RngCore; use std::{collections::HashMap, sync::Arc, time::Instant}; -use tokio::sync::RwLock; +use tokio::sync::{oneshot::Sender as OneshotSender, RwLock}; pub type RequestKey = u64; diff --git a/base_layer/core/Cargo.toml b/base_layer/core/Cargo.toml index 1212b2a40c..25e8ceb197 100644 --- a/base_layer/core/Cargo.toml +++ b/base_layer/core/Cargo.toml @@ -35,27 +35,28 @@ bincode = "1.1.4" bitflags = "1.0.4" blake2 = "^0.9.0" sha3 = "0.9" -bytes = "0.4.12" +bytes = "0.5" chrono = { version = "0.4.6", features = ["serde"]} croaring = { version = "=0.4.5", optional = true } digest = "0.9.0" -futures = {version = "^0.3.1", features = ["async-await"] } +futures = {version = "^0.3.16", features = ["async-await"] } fs2 = "0.3.0" hex = "0.4.2" +lazy_static = "1.4.0" lmdb-zero = "0.4.4" log = "0.4" monero = { version = "^0.13.0", features= ["serde_support"], optional = true } newtype-ops = "0.1.4" num = "0.3" -prost = "0.6.1" -prost-types = "0.6.1" +prost = "0.8.0" +prost-types = "0.8.0" rand = "0.8" randomx-rs = { version = "0.5.0", optional = true } serde = { version = "1.0.106", features = ["derive"] } serde_json = "1.0" strum_macros = "0.17.1" -thiserror = "1.0.20" -tokio = { version="^0.2", features = ["blocking", "time", "sync"] } +thiserror = "1.0.26" +tokio = { version="^1.10", features = [ "time", "sync", "macros"] } ttl_cache = "0.5.1" uint = { version = "0.9", default-features = false } num-format = "0.4.0" @@ -70,7 +71,6 @@ tari_test_utils = { version = "^0.9", path = "../../infrastructure/test_utils" } config = { version = "0.9.3" } env_logger = "0.7.0" tempfile = "3.1.0" -tokio-macros = "0.2.4" [build-dependencies] tari_common = { version = "^0.9", path="../../common", features = ["build"]} diff --git a/base_layer/core/src/base_node/chain_metadata_service/initializer.rs b/base_layer/core/src/base_node/chain_metadata_service/initializer.rs index 1310f22702..2700dc9d01 100644 --- a/base_layer/core/src/base_node/chain_metadata_service/initializer.rs +++ b/base_layer/core/src/base_node/chain_metadata_service/initializer.rs @@ -20,10 +20,8 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-use super::{service::ChainMetadataService, LOG_TARGET}; +use super::service::ChainMetadataService; use crate::base_node::{chain_metadata_service::handle::ChainMetadataHandle, comms_interface::LocalNodeCommsInterface}; -use futures::{future, pin_mut}; -use log::*; use tari_comms::connectivity::ConnectivityRequester; use tari_p2p::services::liveness::LivenessHandle; use tari_service_framework::{async_trait, ServiceInitializationError, ServiceInitializer, ServiceInitializerContext}; @@ -40,15 +38,12 @@ impl ServiceInitializer for ChainMetadataServiceInitializer { let handle = ChainMetadataHandle::new(publisher.clone()); context.register_handle(handle); - context.spawn_when_ready(|handles| async move { + context.spawn_until_shutdown(|handles| { let liveness = handles.expect_handle::(); let base_node = handles.expect_handle::(); let connectivity = handles.expect_handle::(); - let service_run = ChainMetadataService::new(liveness, base_node, connectivity, publisher).run(); - pin_mut!(service_run); - future::select(service_run, handles.get_shutdown_signal()).await; - info!(target: LOG_TARGET, "ChainMetadataService has shut down"); + ChainMetadataService::new(liveness, base_node, connectivity, publisher).run() }); Ok(()) diff --git a/base_layer/core/src/base_node/chain_metadata_service/service.rs b/base_layer/core/src/base_node/chain_metadata_service/service.rs index 625fedb042..61110bc147 100644 --- a/base_layer/core/src/base_node/chain_metadata_service/service.rs +++ b/base_layer/core/src/base_node/chain_metadata_service/service.rs @@ -29,7 +29,6 @@ use crate::{ chain_storage::BlockAddResult, proto::base_node as proto, }; -use futures::stream::StreamExt; use log::*; use num_format::{Locale, ToFormattedString}; use prost::Message; @@ -75,9 +74,9 @@ impl ChainMetadataService { /// Run the service pub async fn run(mut self) { - let mut liveness_event_stream = self.liveness.get_event_stream().fuse(); - let mut block_event_stream = self.base_node.get_block_event_stream().fuse(); - let mut connectivity_events = self.connectivity.get_event_subscription().fuse(); + let mut liveness_event_stream = self.liveness.get_event_stream(); + let mut block_event_stream = self.base_node.get_block_event_stream(); + let mut connectivity_events = self.connectivity.get_event_subscription(); log_if_error!( target: LOG_TARGET, @@ -86,47 +85,36 @@ impl ChainMetadataService { ); loop { - futures::select! { - block_event = block_event_stream.select_next_some() => { - if let Ok(block_event) = block_event { - log_if_error!( - level: debug, - target: LOG_TARGET, - "Failed to handle block event because '{}'", - self.handle_block_event(&block_event).await - ); - } + tokio::select! 
{ + Ok(block_event) = block_event_stream.recv() => { + log_if_error!( + level: debug, + target: LOG_TARGET, + "Failed to handle block event because '{}'", + self.handle_block_event(&block_event).await + ); }, - liveness_event = liveness_event_stream.select_next_some() => { - if let Ok(event) = liveness_event { - log_if_error!( - target: LOG_TARGET, - "Failed to handle liveness event because '{}'", - self.handle_liveness_event(&*event).await - ); - } + Ok(event) = liveness_event_stream.recv() => { + log_if_error!( + target: LOG_TARGET, + "Failed to handle liveness event because '{}'", + self.handle_liveness_event(&*event).await + ); }, - event = connectivity_events.select_next_some() => { - if let Ok(event) = event { - self.handle_connectivity_event(&*event); - } - } - - complete => { - info!(target: LOG_TARGET, "ChainStateSyncService is exiting because all tasks have completed"); - break; + Ok(event) = connectivity_events.recv() => { + self.handle_connectivity_event(event); } } } } - fn handle_connectivity_event(&mut self, event: &ConnectivityEvent) { + fn handle_connectivity_event(&mut self, event: ConnectivityEvent) { use ConnectivityEvent::*; match event { PeerDisconnected(node_id) | ManagedPeerDisconnected(node_id) | PeerBanned(node_id) => { - if let Some(pos) = self.peer_chain_metadata.iter().position(|p| &p.node_id == node_id) { + if let Some(pos) = self.peer_chain_metadata.iter().position(|p| p.node_id == node_id) { debug!( target: LOG_TARGET, "Removing disconnected/banned peer `{}` from chain metadata list ", node_id @@ -292,6 +280,7 @@ impl ChainMetadataService { mod test { use super::*; use crate::base_node::comms_interface::{CommsInterfaceError, NodeCommsRequest, NodeCommsResponse}; + use futures::StreamExt; use std::convert::TryInto; use tari_comms::test_utils::{ mocks::{create_connectivity_mock, ConnectivityManagerMockState}, @@ -355,7 +344,7 @@ mod test { ) } - #[tokio_macros::test] + #[tokio::test] async fn update_liveness_chain_metadata() { let (mut service, liveness_mock_state, _, mut base_node_receiver) = setup(); @@ -364,11 +353,11 @@ mod test { let chain_metadata = proto_chain_metadata.clone().try_into().unwrap(); task::spawn(async move { - let base_node_req = base_node_receiver.select_next_some().await; - let (_req, reply_tx) = base_node_req.split(); - reply_tx - .send(Ok(NodeCommsResponse::ChainMetadata(chain_metadata))) - .unwrap(); + if let Some(base_node_req) = base_node_receiver.next().await { + base_node_req + .reply(Ok(NodeCommsResponse::ChainMetadata(chain_metadata))) + .unwrap(); + } }); service.update_liveness_chain_metadata().await.unwrap(); @@ -381,7 +370,7 @@ mod test { let chain_metadata = proto::ChainMetadata::decode(data.as_slice()).unwrap(); assert_eq!(chain_metadata.height_of_longest_chain, Some(123)); } - #[tokio_macros::test_basic] + #[tokio::test] async fn handle_liveness_event_ok() { let (mut service, _, _, _) = setup(); @@ -410,7 +399,7 @@ mod test { ); } - #[tokio_macros::test_basic] + #[tokio::test] async fn handle_liveness_event_banned_peer() { let (mut service, _, _, _) = setup(); @@ -436,7 +425,7 @@ mod test { .peer_chain_metadata .iter() .any(|p| &p.node_id == nodes[0].node_id())); - service.handle_connectivity_event(&ConnectivityEvent::PeerBanned(nodes[0].node_id().clone())); + service.handle_connectivity_event(ConnectivityEvent::PeerBanned(nodes[0].node_id().clone())); // Check that banned peer was removed assert!(service .peer_chain_metadata @@ -444,7 +433,7 @@ mod test { .all(|p| &p.node_id != nodes[0].node_id())); } - 
#[tokio_macros::test_basic] + #[tokio::test] async fn handle_liveness_event_no_metadata() { let (mut service, _, _, _) = setup(); @@ -462,7 +451,7 @@ mod test { assert_eq!(service.peer_chain_metadata.len(), 0); } - #[tokio_macros::test_basic] + #[tokio::test] async fn handle_liveness_event_bad_metadata() { let (mut service, _, _, _) = setup(); diff --git a/base_layer/core/src/base_node/comms_interface/outbound_interface.rs b/base_layer/core/src/base_node/comms_interface/outbound_interface.rs index 753fe802d8..f3f08409b2 100644 --- a/base_layer/core/src/base_node/comms_interface/outbound_interface.rs +++ b/base_layer/core/src/base_node/comms_interface/outbound_interface.rs @@ -26,11 +26,11 @@ use crate::{ chain_storage::HistoricalBlock, transactions::{transaction::TransactionOutput, types::HashOutput}, }; -use futures::channel::mpsc::UnboundedSender; use log::*; use tari_common_types::{chain_metadata::ChainMetadata, types::BlockHash}; use tari_comms::peer_manager::NodeId; use tari_service_framework::{reply_channel::SenderService, Service}; +use tokio::sync::mpsc::UnboundedSender; pub const LOG_TARGET: &str = "c::bn::comms_interface::outbound_interface"; @@ -234,10 +234,8 @@ impl OutboundNodeCommsInterface { new_block: NewBlock, exclude_peers: Vec, ) -> Result<(), CommsInterfaceError> { - self.block_sender - .unbounded_send((new_block, exclude_peers)) - .map_err(|err| { - CommsInterfaceError::InternalChannelError(format!("Failed to send on block_sender: {}", err)) - }) + self.block_sender.send((new_block, exclude_peers)).map_err(|err| { + CommsInterfaceError::InternalChannelError(format!("Failed to send on block_sender: {}", err)) + }) } } diff --git a/base_layer/core/src/base_node/rpc/service.rs b/base_layer/core/src/base_node/rpc/service.rs index c50600ea9c..5adefae0e6 100644 --- a/base_layer/core/src/base_node/rpc/service.rs +++ b/base_layer/core/src/base_node/rpc/service.rs @@ -230,7 +230,7 @@ impl BaseNodeWalletService for BaseNodeWalletRpc // Determine if we are synced let status_watch = state_machine.get_status_info_watch(); - let is_synced = match (*status_watch.borrow()).state_info { + let is_synced = match status_watch.borrow().state_info { StateInfo::Listening(li) => li.is_synced(), _ => false, }; diff --git a/base_layer/core/src/base_node/service/initializer.rs b/base_layer/core/src/base_node/service/initializer.rs index ae6be0b519..11132db138 100644 --- a/base_layer/core/src/base_node/service/initializer.rs +++ b/base_layer/core/src/base_node/service/initializer.rs @@ -33,7 +33,7 @@ use crate::{ proto as shared_protos, proto::base_node as proto, }; -use futures::{channel::mpsc, future, Stream, StreamExt}; +use futures::{future, Stream, StreamExt}; use log::*; use std::{convert::TryFrom, sync::Arc}; use tari_comms_dht::Dht; @@ -50,7 +50,7 @@ use tari_service_framework::{ ServiceInitializer, ServiceInitializerContext, }; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc}; const LOG_TARGET: &str = "c::bn::service::initializer"; const SUBSCRIPTION_LABEL: &str = "Base Node"; @@ -151,7 +151,7 @@ where T: BlockchainBackend + 'static let inbound_block_stream = self.inbound_block_stream(); // Connect InboundNodeCommsInterface and OutboundNodeCommsInterface to BaseNodeService let (outbound_request_sender_service, outbound_request_stream) = reply_channel::unbounded(); - let (outbound_block_sender_service, outbound_block_stream) = mpsc::unbounded(); + let (outbound_block_sender_service, outbound_block_stream) = mpsc::unbounded_channel(); let (local_request_sender_service, 
local_request_stream) = reply_channel::unbounded(); let (local_block_sender_service, local_block_stream) = reply_channel::unbounded(); let outbound_nci = diff --git a/base_layer/core/src/base_node/service/service.rs b/base_layer/core/src/base_node/service/service.rs index db96b6ed9e..1d66cbf1b1 100644 --- a/base_layer/core/src/base_node/service/service.rs +++ b/base_layer/core/src/base_node/service/service.rs @@ -38,16 +38,7 @@ use crate::{ proto as shared_protos, proto::{base_node as proto, base_node::base_node_service_request::Request}, }; -use futures::{ - channel::{ - mpsc::{channel, Receiver, Sender, UnboundedReceiver}, - oneshot::Sender as OneshotSender, - }, - pin_mut, - stream::StreamExt, - SinkExt, - Stream, -}; +use futures::{pin_mut, stream::StreamExt, Stream}; use log::*; use rand::rngs::OsRng; use std::{convert::TryInto, sync::Arc, time::Duration}; @@ -64,7 +55,14 @@ use tari_comms_dht::{ use tari_crypto::tari_utilities::hex::Hex; use tari_p2p::{domain_message::DomainMessage, tari_message::TariMessageType}; use tari_service_framework::reply_channel::RequestContext; -use tokio::task; +use tokio::{ + sync::{ + mpsc, + mpsc::{Receiver, Sender, UnboundedReceiver}, + oneshot::Sender as OneshotSender, + }, + task, +}; const LOG_TARGET: &str = "c::bn::base_node_service::service"; @@ -134,7 +132,7 @@ where B: BlockchainBackend + 'static config: BaseNodeServiceConfig, state_machine_handle: StateMachineHandle, ) -> Self { - let (timeout_sender, timeout_receiver) = channel(100); + let (timeout_sender, timeout_receiver) = mpsc::channel(100); Self { outbound_message_service, inbound_nch, @@ -162,7 +160,7 @@ where B: BlockchainBackend + 'static { let outbound_request_stream = streams.outbound_request_stream.fuse(); pin_mut!(outbound_request_stream); - let outbound_block_stream = streams.outbound_block_stream.fuse(); + let outbound_block_stream = streams.outbound_block_stream; pin_mut!(outbound_block_stream); let inbound_request_stream = streams.inbound_request_stream.fuse(); pin_mut!(inbound_request_stream); @@ -177,53 +175,52 @@ where B: BlockchainBackend + 'static let timeout_receiver_stream = self .timeout_receiver_stream .take() - .expect("Base Node Service initialized without timeout_receiver_stream") - .fuse(); + .expect("Base Node Service initialized without timeout_receiver_stream"); pin_mut!(timeout_receiver_stream); loop { - futures::select! { + tokio::select! 
{ // Outbound request messages from the OutboundNodeCommsInterface - outbound_request_context = outbound_request_stream.select_next_some() => { + Some(outbound_request_context) = outbound_request_stream.next() => { self.spawn_handle_outbound_request(outbound_request_context); }, // Outbound block messages from the OutboundNodeCommsInterface - (block, excluded_peers) = outbound_block_stream.select_next_some() => { + Some((block, excluded_peers)) = outbound_block_stream.recv() => { self.spawn_handle_outbound_block(block, excluded_peers); }, // Incoming request messages from the Comms layer - domain_msg = inbound_request_stream.select_next_some() => { + Some(domain_msg) = inbound_request_stream.next() => { self.spawn_handle_incoming_request(domain_msg); }, // Incoming response messages from the Comms layer - domain_msg = inbound_response_stream.select_next_some() => { + Some(domain_msg) = inbound_response_stream.next() => { self.spawn_handle_incoming_response(domain_msg); }, // Timeout events for waiting requests - timeout_request_key = timeout_receiver_stream.select_next_some() => { + Some(timeout_request_key) = timeout_receiver_stream.recv() => { self.spawn_handle_request_timeout(timeout_request_key); }, // Incoming block messages from the Comms layer - block_msg = inbound_block_stream.select_next_some() => { + Some(block_msg) = inbound_block_stream.next() => { self.spawn_handle_incoming_block(block_msg).await; } // Incoming local request messages from the LocalNodeCommsInterface and other local services - local_request_context = local_request_stream.select_next_some() => { + Some(local_request_context) = local_request_stream.next() => { self.spawn_handle_local_request(local_request_context); }, // Incoming local block messages from the LocalNodeCommsInterface e.g. 
miner and block sync - local_block_context = local_block_stream.select_next_some() => { + Some(local_block_context) = local_block_stream.next() => { self.spawn_handle_local_block(local_block_context); }, - complete => { - info!(target: LOG_TARGET, "Base Node service shutting down"); + else => { + info!(target: LOG_TARGET, "Base Node service shutting down because all streams ended"); break; } } @@ -646,9 +643,9 @@ async fn handle_request_timeout( Ok(()) } -fn spawn_request_timeout(mut timeout_sender: Sender, request_key: RequestKey, timeout: Duration) { +fn spawn_request_timeout(timeout_sender: Sender, request_key: RequestKey, timeout: Duration) { task::spawn(async move { - tokio::time::delay_for(timeout).await; + tokio::time::sleep(timeout).await; let _ = timeout_sender.send(request_key).await; }); } diff --git a/base_layer/core/src/base_node/state_machine_service/state_machine.rs b/base_layer/core/src/base_node/state_machine_service/state_machine.rs index 3c8e33ee06..383bc82d03 100644 --- a/base_layer/core/src/base_node/state_machine_service/state_machine.rs +++ b/base_layer/core/src/base_node/state_machine_service/state_machine.rs @@ -158,7 +158,7 @@ impl BaseNodeStateMachine { state_info: self.info.clone(), }; - if let Err(e) = self.status_event_sender.broadcast(status) { + if let Err(e) = self.status_event_sender.send(status) { debug!(target: LOG_TARGET, "Error broadcasting a StatusEvent update: {}", e); } } diff --git a/base_layer/core/src/base_node/state_machine_service/states/block_sync.rs b/base_layer/core/src/base_node/state_machine_service/states/block_sync.rs index 689c9a8316..5c7710c371 100644 --- a/base_layer/core/src/base_node/state_machine_service/states/block_sync.rs +++ b/base_layer/core/src/base_node/state_machine_service/states/block_sync.rs @@ -67,7 +67,7 @@ impl BlockSync { let status_event_sender = shared.status_event_sender.clone(); let bootstrapped = shared.is_bootstrapped(); - let _ = status_event_sender.broadcast(StatusInfo { + let _ = status_event_sender.send(StatusInfo { bootstrapped, state_info: StateInfo::BlockSyncStarting, }); @@ -80,7 +80,7 @@ impl BlockSync { false.into(), )); - let _ = status_event_sender.broadcast(StatusInfo { + let _ = status_event_sender.send(StatusInfo { bootstrapped, state_info: StateInfo::BlockSync(BlockSyncInfo { tip_height: remote_tip_height, diff --git a/base_layer/core/src/base_node/state_machine_service/states/events_and_states.rs b/base_layer/core/src/base_node/state_machine_service/states/events_and_states.rs index 4e864624bc..0ef1f6e99b 100644 --- a/base_layer/core/src/base_node/state_machine_service/states/events_and_states.rs +++ b/base_layer/core/src/base_node/state_machine_service/states/events_and_states.rs @@ -160,7 +160,7 @@ impl Display for BaseNodeState { #[derive(Debug, Clone, PartialEq)] pub enum StateInfo { StartUp, - HeaderSync(BlockSyncInfo), + HeaderSync(Option), HorizonSync(HorizonSyncInfo), BlockSyncStarting, BlockSync(BlockSyncInfo), @@ -169,15 +169,12 @@ pub enum StateInfo { impl StateInfo { pub fn short_desc(&self) -> String { + use StateInfo::*; match self { - Self::StartUp => "Starting up".to_string(), - Self::HeaderSync(info) => format!( - "Syncing headers: {}/{} ({:.0}%)", - info.local_height, - info.tip_height, - info.local_height as f64 / info.tip_height as f64 * 100.0 - ), - Self::HorizonSync(info) => match info.status { + StartUp => "Starting up".to_string(), + HeaderSync(None) => "Starting header sync".to_string(), + HeaderSync(Some(info)) => format!("Syncing headers: {}", 
info.sync_progress_string()), + HorizonSync(info) => match info.status { HorizonSyncStatus::Starting => "Starting horizon sync".to_string(), HorizonSyncStatus::Kernels(current, total) => format!( "Syncing kernels: {}/{} ({:.0}%)", @@ -193,18 +190,16 @@ impl StateInfo { ), HorizonSyncStatus::Finalizing => "Finalizing horizon sync".to_string(), }, - Self::BlockSync(info) => format!( - "Syncing blocks with {}: {}/{} ({:.0}%) ", + BlockSync(info) => format!( + "Syncing blocks: ({}) {}", info.sync_peers .first() - .map(|s| s.short_str()) + .map(|n| n.short_str()) .unwrap_or_else(|| "".to_string()), - info.local_height, - info.tip_height, - info.local_height as f64 / info.tip_height as f64 * 100.0 + info.sync_progress_string() ), - Self::Listening(_) => "Listening".to_string(), - Self::BlockSyncStarting => "Starting block sync".to_string(), + Listening(_) => "Listening".to_string(), + BlockSyncStarting => "Starting block sync".to_string(), } } @@ -226,13 +221,15 @@ impl StateInfo { impl Display for StateInfo { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { + use StateInfo::*; match self { - Self::StartUp => write!(f, "Node starting up"), - Self::HeaderSync(info) => write!(f, "Synchronizing block headers: {}", info), - Self::HorizonSync(info) => write!(f, "Synchronizing horizon state: {}", info), - Self::BlockSync(info) => write!(f, "Synchronizing blocks: {}", info), - Self::Listening(info) => write!(f, "Listening: {}", info), - Self::BlockSyncStarting => write!(f, "Synchronizing blocks: Starting"), + StartUp => write!(f, "Node starting up"), + HeaderSync(Some(info)) => write!(f, "Synchronizing block headers: {}", info), + HeaderSync(None) => write!(f, "Synchronizing block headers: Starting"), + HorizonSync(info) => write!(f, "Synchronizing horizon state: {}", info), + BlockSync(info) => write!(f, "Synchronizing blocks: {}", info), + Listening(info) => write!(f, "Listening: {}", info), + BlockSyncStarting => write!(f, "Synchronizing blocks: Starting"), } } } @@ -282,15 +279,24 @@ impl BlockSyncInfo { sync_peers, } } + + pub fn sync_progress_string(&self) -> String { + format!( + "{}/{} ({:.0}%)", + self.local_height, + self.tip_height, + (self.local_height as f64 / self.tip_height as f64 * 100.0) + ) + } } impl Display for BlockSyncInfo { - fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { - fmt.write_str("Syncing from the following peers: \n")?; + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { + writeln!(f, "Syncing from the following peers:")?; for peer in &self.sync_peers { - fmt.write_str(&format!("{}\n", peer))?; + writeln!(f, "{}", peer)?; } - fmt.write_str(&format!("Syncing {}/{}\n", self.local_height, self.tip_height)) + writeln!(f, "Syncing {}", self.sync_progress_string()) } } diff --git a/base_layer/core/src/base_node/state_machine_service/states/header_sync.rs b/base_layer/core/src/base_node/state_machine_service/states/header_sync.rs index 2acfd53206..68f663d71f 100644 --- a/base_layer/core/src/base_node/state_machine_service/states/header_sync.rs +++ b/base_layer/core/src/base_node/state_machine_service/states/header_sync.rs @@ -74,14 +74,15 @@ impl HeaderSync { let status_event_sender = shared.status_event_sender.clone(); let bootstrapped = shared.is_bootstrapped(); - synchronizer.on_progress(move |current_height, remote_tip_height, sync_peers| { - let _ = status_event_sender.broadcast(StatusInfo { + synchronizer.on_progress(move |details, sync_peers| { + let details = details.map(|(current_height, remote_tip_height)| 
BlockSyncInfo { + tip_height: remote_tip_height, + local_height: current_height, + sync_peers: sync_peers.to_vec(), + }); + let _ = status_event_sender.send(StatusInfo { bootstrapped, - state_info: StateInfo::HeaderSync(BlockSyncInfo { - tip_height: remote_tip_height, - local_height: current_height, - sync_peers: sync_peers.to_vec(), - }), + state_info: StateInfo::HeaderSync(details), }); }); diff --git a/base_layer/core/src/base_node/state_machine_service/states/listening.rs b/base_layer/core/src/base_node/state_machine_service/states/listening.rs index 0ea8157568..92349c269a 100644 --- a/base_layer/core/src/base_node/state_machine_service/states/listening.rs +++ b/base_layer/core/src/base_node/state_machine_service/states/listening.rs @@ -31,7 +31,6 @@ use crate::{ }, chain_storage::BlockchainBackend, }; -use futures::StreamExt; use log::*; use num_format::{Locale, ToFormattedString}; use serde::{Deserialize, Serialize}; @@ -118,7 +117,8 @@ impl Listening { info!(target: LOG_TARGET, "Listening for chain metadata updates"); shared.set_state_info(StateInfo::Listening(ListeningInfo::new(self.is_synced))); - while let Some(metadata_event) = shared.metadata_event_stream.next().await { + loop { + let metadata_event = shared.metadata_event_stream.recv().await; match metadata_event.as_ref().map(|v| v.deref()) { Ok(ChainMetadataEvent::PeerChainMetadataReceived(peer_metadata_list)) => { let mut peer_metadata_list = peer_metadata_list.clone(); @@ -199,16 +199,16 @@ impl Listening { if !self.is_synced { self.is_synced = true; + shared.set_state_info(StateInfo::Listening(ListeningInfo::new(true))); debug!(target: LOG_TARGET, "Initial sync achieved"); } - shared.set_state_info(StateInfo::Listening(ListeningInfo::new(true))); }, - Err(broadcast::RecvError::Lagged(n)) => { + Err(broadcast::error::RecvError::Lagged(n)) => { debug!(target: LOG_TARGET, "Metadata event subscriber lagged by {} item(s)", n); }, - Err(broadcast::RecvError::Closed) => { - // This should never happen because the while loop exits when the stream ends + Err(broadcast::error::RecvError::Closed) => { debug!(target: LOG_TARGET, "Metadata event subscriber closed"); + break; }, } } diff --git a/base_layer/core/src/base_node/state_machine_service/states/waiting.rs b/base_layer/core/src/base_node/state_machine_service/states/waiting.rs index 7ea2e7e2b0..aeaa5ab430 100644 --- a/base_layer/core/src/base_node/state_machine_service/states/waiting.rs +++ b/base_layer/core/src/base_node/state_machine_service/states/waiting.rs @@ -23,7 +23,7 @@ use crate::base_node::state_machine_service::states::{BlockSync, HeaderSync, HorizonStateSync, StateEvent}; use log::info; use std::time::Duration; -use tokio::time::delay_for; +use tokio::time::sleep; const LOG_TARGET: &str = "c::bn::state_machine_service::states::waiting"; @@ -41,7 +41,7 @@ impl Waiting { "The base node has started a WAITING state for {} seconds", self.timeout.as_secs() ); - delay_for(self.timeout).await; + sleep(self.timeout).await; info!( target: LOG_TARGET, "The base node waiting state has completed. 
Resuming normal operations" diff --git a/base_layer/core/src/base_node/sync/header_sync/synchronizer.rs b/base_layer/core/src/base_node/sync/header_sync/synchronizer.rs index b7483ff70e..cfb29410d3 100644 --- a/base_layer/core/src/base_node/sync/header_sync/synchronizer.rs +++ b/base_layer/core/src/base_node/sync/header_sync/synchronizer.rs @@ -83,7 +83,7 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { } pub fn on_progress(&mut self, hook: H) - where H: FnMut(u64, u64, &[NodeId]) + Send + Sync + 'static { + where H: FnMut(Option<(u64, u64)>, &[NodeId]) + Send + Sync + 'static { self.hooks.add_on_progress_header_hook(hook); } @@ -94,6 +94,7 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { pub async fn synchronize(&mut self) -> Result { debug!(target: LOG_TARGET, "Starting header sync.",); + self.hooks.call_on_progress_header_hooks(None, self.sync_peers); let sync_peers = self.select_sync_peers().await?; info!( target: LOG_TARGET, @@ -261,7 +262,7 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { let latency = client.get_last_request_latency().await?; debug!( target: LOG_TARGET, - "Initiating header sync with peer `{}` (latency = {}ms)", + "Initiating header sync with peer `{}` (sync latency = {}ms)", conn.peer_node_id(), latency.unwrap_or_default().as_millis() ); @@ -272,6 +273,10 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { // We're ahead of this peer, try another peer if possible SyncStatus::Ahead => Err(BlockHeaderSyncError::NotInSync), SyncStatus::Lagging(split_info) => { + self.hooks.call_on_progress_header_hooks( + Some((split_info.local_tip_header.height(), split_info.remote_tip_height)), + self.sync_peers, + ); self.synchronize_headers(&peer, &mut client, *split_info).await?; Ok(()) }, @@ -483,7 +488,7 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { const COMMIT_EVERY_N_HEADERS: usize = 1000; // Peer returned no more than the max headers. This indicates that there are no further headers to request. - if self.header_validator.valid_headers().len() <= NUM_INITIAL_HEADERS_TO_REQUEST as usize { + if self.header_validator.valid_headers().len() < NUM_INITIAL_HEADERS_TO_REQUEST as usize { debug!(target: LOG_TARGET, "No further headers to download"); if !self.pending_chain_has_higher_pow(&split_info.local_tip_header)? 
{ return Err(BlockHeaderSyncError::WeakerChain); @@ -563,7 +568,7 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { } self.hooks - .call_on_progress_header_hooks(current_height, split_info.remote_tip_height, self.sync_peers); + .call_on_progress_header_hooks(Some((current_height, split_info.remote_tip_height)), self.sync_peers); } if !has_switched_to_new_chain { diff --git a/base_layer/core/src/base_node/sync/header_sync/validator.rs b/base_layer/core/src/base_node/sync/header_sync/validator.rs index aff52b80fd..4a25f28aa7 100644 --- a/base_layer/core/src/base_node/sync/header_sync/validator.rs +++ b/base_layer/core/src/base_node/sync/header_sync/validator.rs @@ -283,7 +283,7 @@ mod test { mod initialize_state { use super::*; - #[tokio_macros::test_basic] + #[tokio::test] async fn it_initializes_state_to_given_header() { let (mut validator, _, tip) = setup_with_headers(1).await; validator.initialize_state(&tip.header().hash()).await.unwrap(); @@ -295,7 +295,7 @@ mod test { assert_eq!(state.current_height, 1); } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_errors_if_hash_does_not_exist() { let (mut validator, _) = setup(); let start_hash = vec![0; 32]; @@ -308,7 +308,7 @@ mod test { mod validate { use super::*; - #[tokio_macros::test_basic] + #[tokio::test] async fn it_passes_if_headers_are_valid() { let (mut validator, _, tip) = setup_with_headers(1).await; validator.initialize_state(tip.hash()).await.unwrap(); @@ -322,7 +322,7 @@ mod test { assert_eq!(validator.valid_headers().len(), 2); } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_fails_if_height_is_not_serial() { let (mut validator, _, tip) = setup_with_headers(2).await; validator.initialize_state(tip.hash()).await.unwrap(); diff --git a/base_layer/core/src/base_node/sync/hooks.rs b/base_layer/core/src/base_node/sync/hooks.rs index 71f0802926..d1e2628822 100644 --- a/base_layer/core/src/base_node/sync/hooks.rs +++ b/base_layer/core/src/base_node/sync/hooks.rs @@ -28,7 +28,7 @@ use tari_comms::peer_manager::NodeId; #[derive(Default)] pub(super) struct Hooks { - on_progress_header: Vec>, + on_progress_header: Vec, &[NodeId]) + Send + Sync>>, on_progress_block: Vec, u64, &[NodeId]) + Send + Sync>>, on_complete: Vec) + Send + Sync>>, on_rewind: Vec>) + Send + Sync>>, @@ -36,14 +36,14 @@ pub(super) struct Hooks { impl Hooks { pub fn add_on_progress_header_hook(&mut self, hook: H) - where H: FnMut(u64, u64, &[NodeId]) + Send + Sync + 'static { + where H: FnMut(Option<(u64, u64)>, &[NodeId]) + Send + Sync + 'static { self.on_progress_header.push(Box::new(hook)); } - pub fn call_on_progress_header_hooks(&mut self, height: u64, remote_tip_height: u64, sync_peers: &[NodeId]) { + pub fn call_on_progress_header_hooks(&mut self, height_vs_remote: Option<(u64, u64)>, sync_peers: &[NodeId]) { self.on_progress_header .iter_mut() - .for_each(|f| (*f)(height, remote_tip_height, sync_peers)); + .for_each(|f| (*f)(height_vs_remote, sync_peers)); } pub fn add_on_progress_block_hook(&mut self, hook: H) diff --git a/base_layer/core/src/base_node/sync/rpc/service.rs b/base_layer/core/src/base_node/sync/rpc/service.rs index 776c2b7e01..e9df7073a2 100644 --- a/base_layer/core/src/base_node/sync/rpc/service.rs +++ b/base_layer/core/src/base_node/sync/rpc/service.rs @@ -35,12 +35,14 @@ use crate::{ SyncUtxosResponse, }, }; -use futures::{channel::mpsc, stream, SinkExt}; use log::*; use std::cmp; -use tari_comms::protocol::rpc::{Request, Response, RpcStatus, Streaming}; +use tari_comms::{ + 
protocol::rpc::{Request, Response, RpcStatus, Streaming}, + utils, +}; use tari_crypto::tari_utilities::hex::Hex; -use tokio::task; +use tokio::{sync::mpsc, task}; use tracing::{instrument, span, Instrument, Level}; const LOG_TARGET: &str = "c::base_node::sync_rpc"; @@ -116,7 +118,7 @@ impl BaseNodeSyncService for BaseNodeSyncRpcServ // Number of blocks to load and push to the stream before loading the next batch const BATCH_SIZE: usize = 4; - let (mut tx, rx) = mpsc::channel(BATCH_SIZE); + let (tx, rx) = mpsc::channel(BATCH_SIZE); let span = span!(Level::TRACE, "sync_rpc::block_sync::inner_worker"); task::spawn( @@ -138,19 +140,16 @@ impl BaseNodeSyncService for BaseNodeSyncRpcServ break; }, Ok(blocks) => { - let mut blocks = stream::iter( - blocks - .into_iter() - .map(|hb| hb.try_into_block().map_err(RpcStatus::log_internal_error(LOG_TARGET))) - .map(|block| match block { - Ok(b) => Ok(proto::base_node::BlockBodyResponse::from(b)), - Err(err) => Err(err), - }) - .map(Ok), - ); + let blocks = blocks + .into_iter() + .map(|hb| hb.try_into_block().map_err(RpcStatus::log_internal_error(LOG_TARGET))) + .map(|block| match block { + Ok(b) => Ok(proto::base_node::BlockBodyResponse::from(b)), + Err(err) => Err(err), + }); // Ensure task stops if the peer prematurely stops their RPC session - if tx.send_all(&mut blocks).await.is_err() { + if utils::mpsc::send_all(&tx, blocks).await.is_err() { break; } }, @@ -209,7 +208,7 @@ impl BaseNodeSyncService for BaseNodeSyncRpcServ chunk_size ); - let (mut tx, rx) = mpsc::channel(chunk_size); + let (tx, rx) = mpsc::channel(chunk_size); let span = span!(Level::TRACE, "sync_rpc::sync_headers::inner_worker"); task::spawn( async move { @@ -233,10 +232,9 @@ impl BaseNodeSyncService for BaseNodeSyncRpcServ break; }, Ok(headers) => { - let mut headers = - stream::iter(headers.into_iter().map(proto::core::BlockHeader::from).map(Ok).map(Ok)); + let headers = headers.into_iter().map(proto::core::BlockHeader::from).map(Ok); // Ensure task stops if the peer prematurely stops their RPC session - if tx.send_all(&mut headers).await.is_err() { + if utils::mpsc::send_all(&tx, headers).await.is_err() { break; } }, @@ -354,7 +352,7 @@ impl BaseNodeSyncService for BaseNodeSyncRpcServ ) -> Result, RpcStatus> { let req = request.into_message(); const BATCH_SIZE: usize = 1000; - let (mut tx, rx) = mpsc::channel(BATCH_SIZE); + let (tx, rx) = mpsc::channel(BATCH_SIZE); let db = self.db(); task::spawn(async move { @@ -394,15 +392,9 @@ impl BaseNodeSyncService for BaseNodeSyncRpcServ break; }, Ok(kernels) => { - let mut kernels = stream::iter( - kernels - .into_iter() - .map(proto::types::TransactionKernel::from) - .map(Ok) - .map(Ok), - ); + let kernels = kernels.into_iter().map(proto::types::TransactionKernel::from).map(Ok); // Ensure task stops if the peer prematurely stops their RPC session - if tx.send_all(&mut kernels).await.is_err() { + if utils::mpsc::send_all(&tx, kernels).await.is_err() { break; } }, diff --git a/base_layer/core/src/base_node/sync/rpc/sync_utxos_task.rs b/base_layer/core/src/base_node/sync/rpc/sync_utxos_task.rs index ef10b41c2f..8064aaf458 100644 --- a/base_layer/core/src/base_node/sync/rpc/sync_utxos_task.rs +++ b/base_layer/core/src/base_node/sync/rpc/sync_utxos_task.rs @@ -25,11 +25,11 @@ use crate::{ proto, proto::base_node::{SyncUtxo, SyncUtxosRequest, SyncUtxosResponse}, }; -use futures::{channel::mpsc, stream, SinkExt}; use log::*; use std::{cmp, sync::Arc, time::Instant}; -use tari_comms::protocol::rpc::RpcStatus; +use 
tari_comms::{protocol::rpc::RpcStatus, utils}; use tari_crypto::tari_utilities::{hex::Hex, Hashable}; +use tokio::sync::mpsc; const LOG_TARGET: &str = "c::base_node::sync_rpc::sync_utxo_task"; @@ -147,8 +147,7 @@ where B: BlockchainBackend + 'static utxos.len(), deleted_diff.cardinality(), ); - let mut utxos = stream::iter( - utxos + let utxos = utxos .into_iter() .enumerate() // Only include pruned UTXOs if include_pruned_utxos is true @@ -161,12 +160,10 @@ where B: BlockchainBackend + 'static mmr_index: start + i as u64, } }) - .map(Ok) - .map(Ok), - ); + .map(Ok); // Ensure task stops if the peer prematurely stops their RPC session - if tx.send_all(&mut utxos).await.is_err() { + if utils::mpsc::send_all(&tx, utxos).await.is_err() { break; } diff --git a/base_layer/core/src/base_node/sync/rpc/tests.rs b/base_layer/core/src/base_node/sync/rpc/tests.rs index 35611a9dda..61adc1aa3c 100644 --- a/base_layer/core/src/base_node/sync/rpc/tests.rs +++ b/base_layer/core/src/base_node/sync/rpc/tests.rs @@ -89,7 +89,7 @@ mod sync_blocks { use tari_comms::protocol::rpc::RpcStatusCode; use tari_test_utils::unpack_enum; - #[tokio_macros::test_basic] + #[tokio::test] async fn it_returns_not_found_if_unknown_hash() { let mut backend = create_mock_backend(); backend.expect_fetch().times(1).returning(|_| Ok(None)); @@ -103,7 +103,7 @@ mod sync_blocks { unpack_enum!(RpcStatusCode::NotFound = err.status_code()); } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_sends_an_empty_response() { let mut backend = create_mock_backend(); @@ -136,7 +136,7 @@ mod sync_blocks { assert!(streaming.next().await.is_none()); } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_streams_blocks_until_end() { let mut backend = create_mock_backend(); diff --git a/base_layer/core/src/lib.rs b/base_layer/core/src/lib.rs index 2e4bdc2f49..5a93b73576 100644 --- a/base_layer/core/src/lib.rs +++ b/base_layer/core/src/lib.rs @@ -20,7 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -// Needed to make futures::select! work +// Needed to make tokio::select! 
work #![recursion_limit = "512"] #![feature(shrink_to)] // #![cfg_attr(not(debug_assertions), deny(unused_variables))] diff --git a/base_layer/core/src/mempool/rpc/test.rs b/base_layer/core/src/mempool/rpc/test.rs index 64ad84f122..a9cbb2ee49 100644 --- a/base_layer/core/src/mempool/rpc/test.rs +++ b/base_layer/core/src/mempool/rpc/test.rs @@ -43,7 +43,7 @@ mod get_stats { use super::*; use crate::mempool::{MempoolService, StatsResponse}; - #[tokio_macros::test_basic] + #[tokio::test] async fn it_returns_the_stats() { let (service, mempool, req_mock, _tmpdir) = setup(); let expected_stats = StatsResponse { @@ -66,7 +66,7 @@ mod get_state { use super::*; use crate::mempool::{MempoolService, StateResponse}; - #[tokio_macros::test_basic] + #[tokio::test] async fn it_returns_the_state() { let (service, mempool, req_mock, _tmpdir) = setup(); let expected_state = StateResponse { @@ -94,7 +94,7 @@ mod get_tx_state_by_excess_sig { use tari_crypto::ristretto::{RistrettoPublicKey, RistrettoSecretKey}; use tari_test_utils::unpack_enum; - #[tokio_macros::test_basic] + #[tokio::test] async fn it_returns_the_storage_status() { let (service, mempool, req_mock, _tmpdir) = setup(); let expected = TxStorageResponse::UnconfirmedPool; @@ -116,7 +116,7 @@ mod get_tx_state_by_excess_sig { assert_eq!(mempool.get_call_count(), 1); } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_errors_on_invalid_signature() { let (service, _, req_mock, _tmpdir) = setup(); let status = service @@ -139,7 +139,7 @@ mod submit_transaction { use tari_crypto::ristretto::RistrettoSecretKey; use tari_test_utils::unpack_enum; - #[tokio_macros::test_basic] + #[tokio::test] async fn it_submits_transaction() { let (service, mempool, req_mock, _tmpdir) = setup(); let expected = TxStorageResponse::UnconfirmedPool; @@ -166,7 +166,7 @@ mod submit_transaction { assert_eq!(mempool.get_call_count(), 1); } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_errors_on_invalid_transaction() { let (service, _, req_mock, _tmpdir) = setup(); let status = service diff --git a/base_layer/core/src/mempool/service/initializer.rs b/base_layer/core/src/mempool/service/initializer.rs index cb57d58e94..a295daf96a 100644 --- a/base_layer/core/src/mempool/service/initializer.rs +++ b/base_layer/core/src/mempool/service/initializer.rs @@ -37,7 +37,7 @@ use crate::{ proto, transactions::transaction::Transaction, }; -use futures::{channel::mpsc, future, Stream, StreamExt}; +use futures::{Stream, StreamExt}; use log::*; use std::{convert::TryFrom, sync::Arc}; use tari_comms_dht::Dht; @@ -54,7 +54,7 @@ use tari_service_framework::{ ServiceInitializer, ServiceInitializerContext, }; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc}; const LOG_TARGET: &str = "c::bn::mempool_service::initializer"; const SUBSCRIPTION_LABEL: &str = "Mempool"; @@ -148,7 +148,7 @@ impl ServiceInitializer for MempoolServiceInitializer { let mempool_handle = MempoolHandle::new(request_sender); context.register_handle(mempool_handle); - let (outbound_tx_sender, outbound_tx_stream) = mpsc::unbounded(); + let (outbound_tx_sender, outbound_tx_stream) = mpsc::unbounded_channel(); let (outbound_request_sender_service, outbound_request_stream) = reply_channel::unbounded(); let (local_request_sender_service, local_request_stream) = reply_channel::unbounded(); let (mempool_state_event_publisher, _) = broadcast::channel(100); @@ -167,7 +167,7 @@ impl ServiceInitializer for MempoolServiceInitializer { context.register_handle(outbound_mp_interface); 
context.register_handle(local_mp_interface); - context.spawn_when_ready(move |handles| async move { + context.spawn_until_shutdown(move |handles| { let outbound_message_service = handles.expect_handle::().outbound_requester(); let state_machine = handles.expect_handle::(); let base_node = handles.expect_handle::(); @@ -182,11 +182,7 @@ impl ServiceInitializer for MempoolServiceInitializer { block_event_stream: base_node.get_block_event_stream(), request_receiver, }; - let service = - MempoolService::new(outbound_message_service, inbound_handlers, config, state_machine).start(streams); - futures::pin_mut!(service); - future::select(service, handles.get_shutdown_signal()).await; - info!(target: LOG_TARGET, "Mempool Service shutdown"); + MempoolService::new(outbound_message_service, inbound_handlers, config, state_machine).start(streams) }); Ok(()) diff --git a/base_layer/core/src/mempool/service/local_service.rs b/base_layer/core/src/mempool/service/local_service.rs index c58af9912d..58e4bef4d3 100644 --- a/base_layer/core/src/mempool/service/local_service.rs +++ b/base_layer/core/src/mempool/service/local_service.rs @@ -146,7 +146,7 @@ mod test { } } - #[tokio_macros::test] + #[tokio::test] async fn mempool_stats() { let (event_publisher, _) = broadcast::channel(100); let (tx, rx) = unbounded(); @@ -157,7 +157,7 @@ mod test { assert_eq!(stats, request_stats()); } - #[tokio_macros::test] + #[tokio::test] async fn mempool_stats_from_multiple() { let (event_publisher, _) = broadcast::channel(100); let (tx, rx) = unbounded(); diff --git a/base_layer/core/src/mempool/service/outbound_interface.rs b/base_layer/core/src/mempool/service/outbound_interface.rs index 87cda226f3..f84a90a4a8 100644 --- a/base_layer/core/src/mempool/service/outbound_interface.rs +++ b/base_layer/core/src/mempool/service/outbound_interface.rs @@ -28,10 +28,10 @@ use crate::{ }, transactions::{transaction::Transaction, types::Signature}, }; -use futures::channel::mpsc::UnboundedSender; use log::*; use tari_comms::peer_manager::NodeId; use tari_service_framework::{reply_channel::SenderService, Service}; +use tokio::sync::mpsc::UnboundedSender; pub const LOG_TARGET: &str = "c::mp::service::outbound_interface"; @@ -71,15 +71,13 @@ impl OutboundMempoolServiceInterface { transaction: Transaction, exclude_peers: Vec, ) -> Result<(), MempoolServiceError> { - self.tx_sender - .unbounded_send((transaction, exclude_peers)) - .or_else(|e| { - { - error!(target: LOG_TARGET, "Could not broadcast transaction. {:?}", e); - Err(e) - } - .map_err(|_| MempoolServiceError::BroadcastFailed) - }) + self.tx_sender.send((transaction, exclude_peers)).or_else(|e| { + { + error!(target: LOG_TARGET, "Could not broadcast transaction. {:?}", e); + Err(e) + } + .map_err(|_| MempoolServiceError::BroadcastFailed) + }) } /// Check if the specified transaction is stored in the mempool of a remote base node. 
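The mempool service diff that follows repeats the event-loop migration applied to the base node service earlier in this change set: fused `futures::select!` arms over `select_next_some()` become `tokio::select!` arms that pattern-match on the `Option`/`Result` returned by `next()`/`recv()`, with an `else` branch standing in for `complete`. A minimal sketch of that shape, assuming tokio 1.x and using hypothetical `requests` (tokio mpsc) and `events` (tokio broadcast) channels rather than the mempool's actual streams:

use tokio::sync::{broadcast, mpsc};

async fn event_loop(mut requests: mpsc::Receiver<String>, mut events: broadcast::Receiver<u64>) {
    loop {
        tokio::select! {
            // mpsc recv() yields None once every Sender is dropped; the pattern
            // then stops matching, so no .fuse() / select_next_some() is needed.
            Some(req) = requests.recv() => {
                println!("request: {}", req);
            },
            // broadcast recv() returns Err when lagged or closed; a non-matching
            // result just disables this branch for the current select! call.
            Ok(event) = events.recv() => {
                println!("event: {}", event);
            },
            // Runs when no branch can produce a match, replacing `complete =>`.
            else => {
                println!("all channels closed, shutting down");
                break;
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let (req_tx, req_rx) = mpsc::channel(8);
    let (event_tx, event_rx) = broadcast::channel(8);
    req_tx.send("hello".to_string()).await.unwrap();
    event_tx.send(42).unwrap();
    drop(req_tx);
    drop(event_tx);
    event_loop(req_rx, event_rx).await;
}

Because a branch whose pattern fails to match is simply disabled for that `select!` invocation, the explicit `if let Ok(..)` unwrapping inside each arm of the old `futures::select!` loops is no longer needed, which is what most of the arm-by-arm changes in the hunk below amount to.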
diff --git a/base_layer/core/src/mempool/service/service.rs b/base_layer/core/src/mempool/service/service.rs index e7c2918fb3..137dd9d0a0 100644 --- a/base_layer/core/src/mempool/service/service.rs +++ b/base_layer/core/src/mempool/service/service.rs @@ -38,13 +38,7 @@ use crate::{ proto, transactions::transaction::Transaction, }; -use futures::{ - channel::{mpsc, oneshot::Sender as OneshotSender}, - pin_mut, - stream::StreamExt, - SinkExt, - Stream, -}; +use futures::{pin_mut, stream::StreamExt, Stream}; use log::*; use rand::rngs::OsRng; use std::{convert::TryInto, sync::Arc, time::Duration}; @@ -58,7 +52,10 @@ use tari_comms_dht::{ use tari_crypto::tari_utilities::hex::Hex; use tari_p2p::{domain_message::DomainMessage, tari_message::TariMessageType}; use tari_service_framework::{reply_channel, reply_channel::RequestContext}; -use tokio::task; +use tokio::{ + sync::{mpsc, oneshot::Sender as OneshotSender}, + task, +}; const LOG_TARGET: &str = "c::mempool::service::service"; @@ -118,7 +115,7 @@ impl MempoolService { { let outbound_request_stream = streams.outbound_request_stream.fuse(); pin_mut!(outbound_request_stream); - let mut outbound_tx_stream = streams.outbound_tx_stream.fuse(); + let mut outbound_tx_stream = streams.outbound_tx_stream; let inbound_request_stream = streams.inbound_request_stream.fuse(); pin_mut!(inbound_request_stream); let inbound_response_stream = streams.inbound_response_stream.fuse(); @@ -127,70 +124,70 @@ impl MempoolService { pin_mut!(inbound_transaction_stream); let local_request_stream = streams.local_request_stream.fuse(); pin_mut!(local_request_stream); - let mut block_event_stream = streams.block_event_stream.fuse(); + let mut block_event_stream = streams.block_event_stream; let mut timeout_receiver_stream = self .timeout_receiver_stream .take() - .expect("Mempool Service initialized without timeout_receiver_stream") - .fuse(); + .expect("Mempool Service initialized without timeout_receiver_stream"); let mut request_receiver = streams.request_receiver; loop { - futures::select! { + tokio::select! 
{ // Requests sent from the handle - request = request_receiver.select_next_some() => { + Some(request) = request_receiver.next() => { let (request, reply) = request.split(); let _ = reply.send(self.handle_request(request).await); }, // Outbound request messages from the OutboundMempoolServiceInterface - outbound_request_context = outbound_request_stream.select_next_some() => { + Some(outbound_request_context) = outbound_request_stream.next() => { self.spawn_handle_outbound_request(outbound_request_context); }, // Outbound tx messages from the OutboundMempoolServiceInterface - (txn, excluded_peers) = outbound_tx_stream.select_next_some() => { + Some((txn, excluded_peers)) = outbound_tx_stream.recv() => { self.spawn_handle_outbound_tx(txn, excluded_peers); }, // Incoming request messages from the Comms layer - domain_msg = inbound_request_stream.select_next_some() => { + Some(domain_msg) = inbound_request_stream.next() => { self.spawn_handle_incoming_request(domain_msg); }, // Incoming response messages from the Comms layer - domain_msg = inbound_response_stream.select_next_some() => { + Some(domain_msg) = inbound_response_stream.next() => { self.spawn_handle_incoming_response(domain_msg); }, // Incoming transaction messages from the Comms layer - transaction_msg = inbound_transaction_stream.select_next_some() => { + Some(transaction_msg) = inbound_transaction_stream.next() => { self.spawn_handle_incoming_tx(transaction_msg).await; } // Incoming local request messages from the LocalMempoolServiceInterface and other local services - local_request_context = local_request_stream.select_next_some() => { + Some(local_request_context) = local_request_stream.next() => { self.spawn_handle_local_request(local_request_context); }, // Block events from local Base Node. 
- block_event = block_event_stream.select_next_some() => { + block_event = block_event_stream.recv() => { if let Ok(block_event) = block_event { self.spawn_handle_block_event(block_event); } }, // Timeout events for waiting requests - timeout_request_key = timeout_receiver_stream.select_next_some() => { + Some(timeout_request_key) = timeout_receiver_stream.recv() => { self.spawn_handle_request_timeout(timeout_request_key); }, - complete => { + else => { info!(target: LOG_TARGET, "Mempool service shutting down"); break; } } } + Ok(()) } @@ -506,9 +503,9 @@ async fn handle_outbound_tx( Ok(()) } -fn spawn_request_timeout(mut timeout_sender: mpsc::Sender, request_key: RequestKey, timeout: Duration) { +fn spawn_request_timeout(timeout_sender: mpsc::Sender, request_key: RequestKey, timeout: Duration) { task::spawn(async move { - tokio::time::delay_for(timeout).await; + tokio::time::sleep(timeout).await; let _ = timeout_sender.send(request_key).await; }); } diff --git a/base_layer/core/src/mempool/sync_protocol/initializer.rs b/base_layer/core/src/mempool/sync_protocol/initializer.rs index df40de648a..535c32c77a 100644 --- a/base_layer/core/src/mempool/sync_protocol/initializer.rs +++ b/base_layer/core/src/mempool/sync_protocol/initializer.rs @@ -28,13 +28,17 @@ use crate::{ MempoolServiceConfig, }, }; -use futures::channel::mpsc; +use log::*; +use std::time::Duration; use tari_comms::{ connectivity::ConnectivityRequester, protocol::{ProtocolExtension, ProtocolExtensionContext, ProtocolExtensionError, ProtocolNotification}, Substream, }; use tari_service_framework::{async_trait, ServiceInitializationError, ServiceInitializer, ServiceInitializerContext}; +use tokio::{sync::mpsc, time::sleep}; + +const LOG_TARGET: &str = "c::mempool::sync_protocol"; pub struct MempoolSyncInitializer { config: MempoolServiceConfig, @@ -70,17 +74,29 @@ impl ServiceInitializer for MempoolSyncInitializer { let mempool = self.mempool.clone(); let notif_rx = self.notif_rx.take().unwrap(); - context.spawn_when_ready(move |handles| { + context.spawn_until_shutdown(move |handles| async move { let state_machine = handles.expect_handle::(); let connectivity = handles.expect_handle::(); - MempoolSyncProtocol::new( - config, - notif_rx, - connectivity.get_event_subscription(), - mempool, - Some(state_machine), - ) - .run() + + let mut status_watch = state_machine.get_status_info_watch(); + if !status_watch.borrow().bootstrapped { + debug!(target: LOG_TARGET, "Waiting for node to bootstrap..."); + while status_watch.changed().await.is_ok() { + if status_watch.borrow().bootstrapped { + debug!(target: LOG_TARGET, "Node bootstrapped. 
Starting mempool sync protocol"); + break; + } + trace!( + target: LOG_TARGET, + "Mempool sync still on hold, waiting for bootstrap to finish", + ); + sleep(Duration::from_secs(1)).await; + } + } + + MempoolSyncProtocol::new(config, notif_rx, connectivity.get_event_subscription(), mempool) + .run() + .await; }); Ok(()) diff --git a/base_layer/core/src/mempool/sync_protocol/mod.rs b/base_layer/core/src/mempool/sync_protocol/mod.rs index 08219f2a8d..dc133f744d 100644 --- a/base_layer/core/src/mempool/sync_protocol/mod.rs +++ b/base_layer/core/src/mempool/sync_protocol/mod.rs @@ -73,12 +73,11 @@ mod initializer; pub use initializer::MempoolSyncInitializer; use crate::{ - base_node::StateMachineHandle, mempool::{async_mempool, proto, Mempool, MempoolServiceConfig}, proto as shared_proto, transactions::transaction::Transaction, }; -use futures::{stream, stream::Fuse, AsyncRead, AsyncWrite, SinkExt, Stream, StreamExt}; +use futures::{stream, SinkExt, Stream, StreamExt}; use log::*; use prost::Message; use std::{ @@ -88,7 +87,6 @@ use std::{ atomic::{AtomicUsize, Ordering}, Arc, }, - time::Duration, }; use tari_comms::{ connectivity::{ConnectivityEvent, ConnectivityEventRx}, @@ -101,7 +99,11 @@ use tari_comms::{ PeerConnection, }; use tari_crypto::tari_utilities::{hex::Hex, ByteArray}; -use tokio::{sync::Semaphore, task}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::Semaphore, + task, +}; const MAX_FRAME_SIZE: usize = 3 * 1024 * 1024; // 3 MiB const LOG_TARGET: &str = "c::mempool::sync_protocol"; @@ -111,11 +113,10 @@ pub static MEMPOOL_SYNC_PROTOCOL: Bytes = Bytes::from_static(b"t/mempool-sync/1" pub struct MempoolSyncProtocol { config: MempoolServiceConfig, protocol_notifier: ProtocolNotificationRx, - connectivity_events: Fuse, + connectivity_events: ConnectivityEventRx, mempool: Mempool, num_synched: Arc, permits: Arc, - state_machine: Option, } impl MempoolSyncProtocol @@ -126,54 +127,34 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static protocol_notifier: ProtocolNotificationRx, connectivity_events: ConnectivityEventRx, mempool: Mempool, - state_machine: Option, ) -> Self { Self { config, protocol_notifier, - connectivity_events: connectivity_events.fuse(), + connectivity_events, mempool, num_synched: Arc::new(AtomicUsize::new(0)), permits: Arc::new(Semaphore::new(1)), - state_machine, } } pub async fn run(mut self) { info!(target: LOG_TARGET, "Mempool protocol handler has started"); - while let Some(ref v) = self.state_machine { - let status_watch = v.get_status_info_watch(); - if (*status_watch.borrow()).bootstrapped { - break; - } - trace!( - target: LOG_TARGET, - "Mempool sync still on hold, waiting for bootstrap to finish", - ); - tokio::time::delay_for(Duration::from_secs(1)).await; - } + loop { - futures::select! { - event = self.connectivity_events.select_next_some() => { - if let Ok(event) = event { - self.handle_connectivity_event(&*event).await; - } + tokio::select! 
{ + Ok(event) = self.connectivity_events.recv() => { + self.handle_connectivity_event(event).await; }, - notif = self.protocol_notifier.select_next_some() => { + Some(notif) = self.protocol_notifier.recv() => { self.handle_protocol_notification(notif); } - - // protocol_notifier and connectivity_events are closed - complete => { - info!(target: LOG_TARGET, "Mempool protocol handler is shutting down"); - break; - } } } } - async fn handle_connectivity_event(&mut self, event: &ConnectivityEvent) { + async fn handle_connectivity_event(&mut self, event: ConnectivityEvent) { match event { // If this node is connecting to a peer ConnectivityEvent::PeerConnected(conn) if conn.direction().is_outbound() => { diff --git a/base_layer/core/src/mempool/sync_protocol/test.rs b/base_layer/core/src/mempool/sync_protocol/test.rs index dd77fe3c70..30f4b096d2 100644 --- a/base_layer/core/src/mempool/sync_protocol/test.rs +++ b/base_layer/core/src/mempool/sync_protocol/test.rs @@ -30,7 +30,7 @@ use crate::{ transactions::{helpers::create_tx, tari_amount::uT, transaction::Transaction}, validation::mocks::MockValidator, }; -use futures::{channel::mpsc, Sink, SinkExt, Stream, StreamExt}; +use futures::{Sink, SinkExt, Stream, StreamExt}; use std::{fmt, io, iter::repeat_with, sync::Arc}; use tari_comms::{ connectivity::{ConnectivityEvent, ConnectivityEventTx}, @@ -44,7 +44,10 @@ use tari_comms::{ BytesMut, }; use tari_crypto::tari_utilities::ByteArray; -use tokio::{sync::broadcast, task}; +use tokio::{ + sync::{broadcast, mpsc}, + task, +}; pub fn create_transactions(n: usize) -> Vec { repeat_with(|| { @@ -82,7 +85,6 @@ fn setup( protocol_notif_rx, connectivity_events_rx, mempool.clone(), - None, ); task::spawn(protocol.run()); @@ -90,7 +92,7 @@ fn setup( (protocol_notif_tx, connectivity_events_tx, mempool, transactions) } -#[tokio_macros::test_basic] +#[tokio::test] async fn empty_set() { let (_, connectivity_events_tx, mempool1, _) = setup(0); @@ -101,7 +103,7 @@ async fn empty_set() { // This node connected to a peer, so it should open the substream connectivity_events_tx - .send(Arc::new(ConnectivityEvent::PeerConnected(node2_conn))) + .send(ConnectivityEvent::PeerConnected(node2_conn)) .unwrap(); let substream = node1_mock.next_incoming_substream().await.unwrap(); @@ -120,7 +122,7 @@ async fn empty_set() { assert_eq!(transactions.len(), 0); } -#[tokio_macros::test_basic] +#[tokio::test] async fn synchronise() { let (_, connectivity_events_tx, mempool1, transactions1) = setup(5); @@ -131,7 +133,7 @@ async fn synchronise() { // This node connected to a peer, so it should open the substream connectivity_events_tx - .send(Arc::new(ConnectivityEvent::PeerConnected(node2_conn))) + .send(ConnectivityEvent::PeerConnected(node2_conn)) .unwrap(); let substream = node1_mock.next_incoming_substream().await.unwrap(); @@ -154,7 +156,7 @@ async fn synchronise() { assert!(transactions2.iter().all(|txn| transactions.contains(&txn))); } -#[tokio_macros::test_basic] +#[tokio::test] async fn duplicate_set() { let (_, connectivity_events_tx, mempool1, transactions1) = setup(2); @@ -165,7 +167,7 @@ async fn duplicate_set() { // This node connected to a peer, so it should open the substream connectivity_events_tx - .send(Arc::new(ConnectivityEvent::PeerConnected(node2_conn))) + .send(ConnectivityEvent::PeerConnected(node2_conn)) .unwrap(); let substream = node1_mock.next_incoming_substream().await.unwrap(); @@ -189,9 +191,9 @@ async fn duplicate_set() { assert!(transactions2.iter().all(|txn| transactions.contains(&txn))); } 
-#[tokio_macros::test_basic] +#[tokio::test] async fn responder() { - let (mut protocol_notif, _, _, transactions1) = setup(2); + let (protocol_notif, _, _, transactions1) = setup(2); let node1 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); let node2 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -225,9 +227,9 @@ async fn responder() { // this. } -#[tokio_macros::test_basic] +#[tokio::test] async fn initiator_messages() { - let (mut protocol_notif, _, _, transactions1) = setup(2); + let (protocol_notif, _, _, transactions1) = setup(2); let node1 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -260,7 +262,7 @@ async fn initiator_messages() { assert_eq!(indexes.indexes, [0, 1]); } -#[tokio_macros::test_basic] +#[tokio::test] async fn responder_messages() { let (_, connectivity_events_tx, _, transactions1) = setup(1); @@ -271,7 +273,7 @@ async fn responder_messages() { // This node connected to a peer, so it should open the substream connectivity_events_tx - .send(Arc::new(ConnectivityEvent::PeerConnected(node2_conn))) + .send(ConnectivityEvent::PeerConnected(node2_conn)) .unwrap(); let substream = node1_mock.next_incoming_substream().await.unwrap(); diff --git a/base_layer/wallet/src/util/emoji.rs b/base_layer/core/src/transactions/emoji/emoji_id.rs similarity index 97% rename from base_layer/wallet/src/util/emoji.rs rename to base_layer/core/src/transactions/emoji/emoji_id.rs index 18ecdc174c..8234e41fad 100644 --- a/base_layer/wallet/src/util/emoji.rs +++ b/base_layer/core/src/transactions/emoji/emoji_id.rs @@ -20,12 +20,13 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::util::luhn::{checksum, is_valid}; +use super::luhn::{checksum, is_valid}; +use crate::transactions::types::PublicKey; +use lazy_static::lazy_static; use std::{ collections::HashMap, fmt::{Display, Error, Formatter}, }; -use tari_core::transactions::types::PublicKey; use tari_crypto::tari_utilities::{ hex::{Hex, HexError}, ByteArray, @@ -70,7 +71,7 @@ lazy_static! { /// # Example /// /// ``` -/// use tari_wallet::util::emoji::EmojiId; +/// use tari_core::transactions::emoji::EmojiId; /// /// assert!(EmojiId::is_valid("🐎🍴🌷🌟💻🐖🐩🐾🌟🐬🎧🐌🏦🐳🐎🐝🐢🔋👕🎸👿🍒🐓🎉💔🌹🏆🐬💡🎳🚦🍹🎒")); /// let eid = EmojiId::from_hex("70350e09c474809209824c6e6888707b7dd09959aa227343b5106382b856f73a").unwrap(); @@ -170,9 +171,7 @@ pub struct EmojiIdError; #[cfg(test)] mod test { - use crate::util::emoji::EmojiId; - use tari_core::transactions::types::PublicKey; - use tari_crypto::tari_utilities::hex::Hex; + use super::*; #[test] fn convert_key() { diff --git a/base_layer/wallet/src/util/luhn.rs b/base_layer/core/src/transactions/emoji/luhn.rs similarity index 99% rename from base_layer/wallet/src/util/luhn.rs rename to base_layer/core/src/transactions/emoji/luhn.rs index 9a9996ef72..d3cc83f508 100644 --- a/base_layer/wallet/src/util/luhn.rs +++ b/base_layer/core/src/transactions/emoji/luhn.rs @@ -45,7 +45,7 @@ pub fn is_valid(arr: &[usize], dict_len: usize) -> bool { #[cfg(test)] mod test { - use crate::util::luhn::*; + use super::*; #[test] fn luhn_6() { diff --git a/base_layer/core/src/transactions/emoji/mod.rs b/base_layer/core/src/transactions/emoji/mod.rs new file mode 100644 index 0000000000..a9b0b1add0 --- /dev/null +++ b/base_layer/core/src/transactions/emoji/mod.rs @@ -0,0 +1,26 @@ +// Copyright 2019. 
The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +mod emoji_id; +pub use emoji_id::{emoji_set, EmojiId, EmojiIdError}; + +mod luhn; diff --git a/base_layer/core/src/transactions/mod.rs b/base_layer/core/src/transactions/mod.rs index 9d98bd394b..a027c4c4cf 100644 --- a/base_layer/core/src/transactions/mod.rs +++ b/base_layer/core/src/transactions/mod.rs @@ -11,10 +11,7 @@ pub use transaction_protocol::{recipient::ReceiverTransactionProtocol, sender::S #[macro_use] pub mod helpers; -#[cfg(any(feature = "base_node", feature = "transactions"))] -mod coinbase_builder; +pub mod emoji; -#[cfg(any(feature = "base_node", feature = "transactions"))] -pub use crate::transactions::coinbase_builder::CoinbaseBuildError; -#[cfg(any(feature = "base_node", feature = "transactions"))] -pub use crate::transactions::coinbase_builder::CoinbaseBuilder; +mod coinbase_builder; +pub use crate::transactions::coinbase_builder::{CoinbaseBuildError, CoinbaseBuilder}; diff --git a/base_layer/core/src/transactions/transaction_protocol/transaction_initializer.rs b/base_layer/core/src/transactions/transaction_protocol/transaction_initializer.rs index 0d5beb738d..46e8267c18 100644 --- a/base_layer/core/src/transactions/transaction_protocol/transaction_initializer.rs +++ b/base_layer/core/src/transactions/transaction_protocol/transaction_initializer.rs @@ -763,6 +763,7 @@ mod test { .with_output(output, p.sender_offset_private_key) .unwrap() .with_fee_per_gram(MicroTari(2)); + for _ in 0..MAX_TRANSACTION_INPUTS + 1 { let (utxo, input) = create_test_input(MicroTari(50), 0, &factories.commitment); builder.with_input(utxo, input); diff --git a/base_layer/core/tests/base_node_rpc.rs b/base_layer/core/tests/base_node_rpc.rs index 9b96512d47..453b9eef30 100644 --- a/base_layer/core/tests/base_node_rpc.rs +++ b/base_layer/core/tests/base_node_rpc.rs @@ -81,22 +81,19 @@ use tari_core::{ txn_schema, }; use tempfile::{tempdir, TempDir}; -use tokio::runtime::Runtime; -fn setup() -> ( +async fn setup() -> ( BaseNodeWalletRpcService, NodeInterfaces, RpcRequestMock, ConsensusManager, 
ChainBlock, UnblindedOutput, - Runtime, TempDir, ) { let network = NetworkConsensus::from(Network::LocalNet); let consensus_constants = network.create_consensus_constants(); let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let (block0, utxo0) = @@ -107,13 +104,14 @@ fn setup() -> ( let (mut base_node, _consensus_manager) = BaseNodeBuilder::new(network) .with_consensus_manager(consensus_manager.clone()) - .start(&mut runtime, temp_dir.path().to_str().unwrap()); + .start(temp_dir.path().to_str().unwrap()) + .await; base_node.mock_base_node_state_machine.publish_status(StatusInfo { bootstrapped: true, state_info: StateInfo::Listening(ListeningInfo::new(true)), }); - let request_mock = runtime.enter(|| RpcRequestMock::new(base_node.comms.peer_manager())); + let request_mock = RpcRequestMock::new(base_node.comms.peer_manager()); let service = BaseNodeWalletRpcService::new( base_node.blockchain_db.clone().into(), base_node.mempool_handle.clone(), @@ -126,16 +124,15 @@ fn setup() -> ( consensus_manager, block0, utxo0, - runtime, temp_dir, ) } -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn test_base_node_wallet_rpc() { +async fn test_base_node_wallet_rpc() { // Testing the submit_transaction() and transaction_query() rpc calls - let (service, mut base_node, request_mock, consensus_manager, block0, utxo0, mut runtime, _temp_dir) = setup(); + let (service, mut base_node, request_mock, consensus_manager, block0, utxo0, _temp_dir) = setup().await; let (txs1, utxos1) = schema_to_transaction(&[txn_schema!(from: vec![utxo0.clone()], to: vec![1 * T, 1 * T])]); let tx1 = (*txs1[0]).clone(); @@ -151,8 +148,8 @@ fn test_base_node_wallet_rpc() { // Query Tx1 let msg = SignatureProto::from(tx1_sig.clone()); let req = request_mock.request_with_context(Default::default(), msg); - let resp = - TxQueryResponse::try_from(runtime.block_on(service.transaction_query(req)).unwrap().into_message()).unwrap(); + let resp = service.transaction_query(req).await.unwrap().into_message(); + let resp = TxQueryResponse::try_from(resp).unwrap(); assert_eq!(resp.confirmations, 0); assert_eq!(resp.block_hash, None); @@ -162,13 +159,7 @@ fn test_base_node_wallet_rpc() { let msg = TransactionProto::from(tx2.clone()); let req = request_mock.request_with_context(Default::default(), msg); - let resp = TxSubmissionResponse::try_from( - runtime - .block_on(service.submit_transaction(req)) - .unwrap() - .into_message(), - ) - .unwrap(); + let resp = TxSubmissionResponse::try_from(service.submit_transaction(req).await.unwrap().into_message()).unwrap(); assert!(!resp.accepted); assert_eq!(resp.rejection_reason, TxSubmissionRejectionReason::Orphan); @@ -176,8 +167,7 @@ fn test_base_node_wallet_rpc() { // Query Tx2 to confirm it wasn't accepted let msg = SignatureProto::from(tx2_sig.clone()); let req = request_mock.request_with_context(Default::default(), msg); - let resp = - TxQueryResponse::try_from(runtime.block_on(service.transaction_query(req)).unwrap().into_message()).unwrap(); + let resp = TxQueryResponse::try_from(service.transaction_query(req).await.unwrap().into_message()).unwrap(); assert_eq!(resp.confirmations, 0); assert_eq!(resp.block_hash, None); @@ -189,24 +179,22 @@ fn test_base_node_wallet_rpc() { .prepare_block_merkle_roots(chain_block(&block0.block(), vec![tx1.clone()], &consensus_manager)) .unwrap(); - assert!(runtime - .block_on(base_node.local_nci.submit_block(block1.clone(), Broadcast::from(true))) - .is_ok()); + base_node + 
.local_nci + .submit_block(block1.clone(), Broadcast::from(true)) + .await + .unwrap(); // Check that subitting Tx2 will now be accepted let msg = TransactionProto::from(tx2); let req = request_mock.request_with_context(Default::default(), msg); - let resp = runtime - .block_on(service.submit_transaction(req)) - .unwrap() - .into_message(); + let resp = service.submit_transaction(req).await.unwrap().into_message(); assert!(resp.accepted); // Query Tx2 which should now be in the mempool let msg = SignatureProto::from(tx2_sig.clone()); let req = request_mock.request_with_context(Default::default(), msg); - let resp = - TxQueryResponse::try_from(runtime.block_on(service.transaction_query(req)).unwrap().into_message()).unwrap(); + let resp = TxQueryResponse::try_from(service.transaction_query(req).await.unwrap().into_message()).unwrap(); assert_eq!(resp.confirmations, 0); assert_eq!(resp.block_hash, None); @@ -215,13 +203,7 @@ fn test_base_node_wallet_rpc() { // Now if we submit Tx1 is should return as rejected as AlreadyMined as Tx1's kernel is present let msg = TransactionProto::from(tx1); let req = request_mock.request_with_context(Default::default(), msg); - let resp = TxSubmissionResponse::try_from( - runtime - .block_on(service.submit_transaction(req)) - .unwrap() - .into_message(), - ) - .unwrap(); + let resp = TxSubmissionResponse::try_from(service.submit_transaction(req).await.unwrap().into_message()).unwrap(); assert!(!resp.accepted); assert_eq!(resp.rejection_reason, TxSubmissionRejectionReason::AlreadyMined); @@ -233,13 +215,7 @@ fn test_base_node_wallet_rpc() { // Now if we submit Tx1 is should return as rejected as AlreadyMined let msg = TransactionProto::from(tx1b); let req = request_mock.request_with_context(Default::default(), msg); - let resp = TxSubmissionResponse::try_from( - runtime - .block_on(service.submit_transaction(req)) - .unwrap() - .into_message(), - ) - .unwrap(); + let resp = TxSubmissionResponse::try_from(service.submit_transaction(req).await.unwrap().into_message()).unwrap(); assert!(!resp.accepted); assert_eq!(resp.rejection_reason, TxSubmissionRejectionReason::DoubleSpend); @@ -253,15 +229,16 @@ fn test_base_node_wallet_rpc() { block2.header.output_mmr_size += 1; block2.header.kernel_mmr_size += 1; - runtime - .block_on(base_node.local_nci.submit_block(block2, Broadcast::from(true))) + base_node + .local_nci + .submit_block(block2, Broadcast::from(true)) + .await .unwrap(); // Query Tx1 which should be in block 1 with 1 confirmation let msg = SignatureProto::from(tx1_sig.clone()); let req = request_mock.request_with_context(Default::default(), msg); - let resp = - TxQueryResponse::try_from(runtime.block_on(service.transaction_query(req)).unwrap().into_message()).unwrap(); + let resp = TxQueryResponse::try_from(service.transaction_query(req).await.unwrap().into_message()).unwrap(); assert_eq!(resp.confirmations, 1); assert_eq!(resp.block_hash, Some(block1.hash())); @@ -271,10 +248,7 @@ fn test_base_node_wallet_rpc() { sigs: vec![SignatureProto::from(tx1_sig.clone()), SignatureProto::from(tx2_sig)], }; let req = request_mock.request_with_context(Default::default(), msg); - let response = runtime - .block_on(service.transaction_batch_query(req)) - .unwrap() - .into_message(); + let response = service.transaction_batch_query(req).await.unwrap().into_message(); for r in response.responses { let response = TxQueryBatchResponse::try_from(r).unwrap(); @@ -299,10 +273,7 @@ fn test_base_node_wallet_rpc() { let req = 
request_mock.request_with_context(Default::default(), msg); - let response = runtime - .block_on(service.fetch_matching_utxos(req)) - .unwrap() - .into_message(); + let response = service.fetch_matching_utxos(req).await.unwrap().into_message(); assert_eq!(response.outputs.len(), utxos1.len()); for output_proto in response.outputs.iter() { diff --git a/base_layer/core/tests/helpers/event_stream.rs b/base_layer/core/tests/helpers/event_stream.rs index b79b494900..5485467f4c 100644 --- a/base_layer/core/tests/helpers/event_stream.rs +++ b/base_layer/core/tests/helpers/event_stream.rs @@ -20,16 +20,17 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::{future, future::Either, FutureExt, Stream, StreamExt}; use std::time::Duration; +use tokio::{sync::broadcast, time}; #[allow(dead_code)] -pub async fn event_stream_next(stream: &mut TStream, timeout: Duration) -> Option -where TStream: Stream + Unpin { - let either = future::select(stream.next(), tokio::time::delay_for(timeout).fuse()).await; - - match either { - Either::Left((v, _)) => v, - Either::Right(_) => None, +pub async fn event_stream_next(stream: &mut broadcast::Receiver, timeout: Duration) -> Option { + tokio::select! { + item = stream.recv() => match item { + Ok(item) => Some(item), + Err(broadcast::error::RecvError::Closed) => None, + Err(broadcast::error::RecvError::Lagged(n)) => panic!("Lagged events channel {}", n), + }, + _ = time::sleep(timeout) => None } } diff --git a/base_layer/core/tests/helpers/mock_state_machine.rs b/base_layer/core/tests/helpers/mock_state_machine.rs index 7d49f93e85..0d4b6ce512 100644 --- a/base_layer/core/tests/helpers/mock_state_machine.rs +++ b/base_layer/core/tests/helpers/mock_state_machine.rs @@ -40,7 +40,7 @@ impl MockBaseNodeStateMachine { } pub fn publish_status(&mut self, status: StatusInfo) { - let _ = self.status_sender.broadcast(status); + let _ = self.status_sender.send(status); } pub fn get_initializer(&self) -> MockBaseNodeStateMachineInitializer { diff --git a/base_layer/core/tests/helpers/nodes.rs b/base_layer/core/tests/helpers/nodes.rs index ffe69c8034..06f2d5f8e0 100644 --- a/base_layer/core/tests/helpers/nodes.rs +++ b/base_layer/core/tests/helpers/nodes.rs @@ -21,9 +21,8 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::helpers::mock_state_machine::MockBaseNodeStateMachine; -use futures::Sink; use rand::rngs::OsRng; -use std::{error::Error, path::Path, sync::Arc, time::Duration}; +use std::{path::Path, sync::Arc, time::Duration}; use tari_common::configuration::Network; use tari_comms::{ peer_manager::{NodeIdentity, PeerFeatures}, @@ -60,13 +59,12 @@ use tari_core::{ }, }; use tari_p2p::{ - comms_connector::{pubsub_connector, InboundDomainConnector, PeerMessage}, + comms_connector::{pubsub_connector, InboundDomainConnector}, initialization::initialize_local_test_comms, services::liveness::{LivenessConfig, LivenessHandle, LivenessInitializer}, }; use tari_service_framework::{RegisterHandle, StackBuilder}; use tari_shutdown::Shutdown; -use tokio::runtime::Runtime; /// The NodeInterfaces is used as a container for providing access to all the services and interfaces of a single node. 
pub struct NodeInterfaces { @@ -91,7 +89,7 @@ pub struct NodeInterfaces { #[allow(dead_code)] impl NodeInterfaces { pub async fn shutdown(mut self) { - self.shutdown.trigger().unwrap(); + self.shutdown.trigger(); self.comms.wait_until_shutdown().await; } } @@ -182,7 +180,7 @@ impl BaseNodeBuilder { /// Build the test base node and start its services. #[allow(clippy::redundant_closure)] - pub fn start(self, runtime: &mut Runtime, data_path: &str) -> (NodeInterfaces, ConsensusManager) { + pub async fn start(self, data_path: &str) -> (NodeInterfaces, ConsensusManager) { let validators = self.validators.unwrap_or_else(|| { Validators::new( MockValidator::new(true), @@ -199,7 +197,6 @@ impl BaseNodeBuilder { let mempool = Mempool::new(self.mempool_config.unwrap_or_default(), Arc::new(mempool_validator)); let node_identity = self.node_identity.unwrap_or_else(|| random_node_identity()); let node_interfaces = setup_base_node_services( - runtime, node_identity, self.peers.unwrap_or_default(), blockchain_db, @@ -209,17 +206,19 @@ impl BaseNodeBuilder { self.mempool_service_config.unwrap_or_default(), self.liveness_service_config.unwrap_or_default(), data_path, - ); + ) + .await; (node_interfaces, consensus_manager) } } -#[allow(dead_code)] -pub fn wait_until_online(runtime: &mut Runtime, nodes: &[&NodeInterfaces]) { +pub async fn wait_until_online(nodes: &[&NodeInterfaces]) { for node in nodes { - runtime - .block_on(node.comms.connectivity().wait_for_connectivity(Duration::from_secs(10))) + node.comms + .connectivity() + .wait_for_connectivity(Duration::from_secs(10)) + .await .map_err(|err| format!("Node '{}' failed to go online {:?}", node.node_identity.node_id(), err)) .unwrap(); } @@ -227,10 +226,7 @@ pub fn wait_until_online(runtime: &mut Runtime, nodes: &[&NodeInterfaces]) { // Creates a network with two Base Nodes where each node in the network knows the other nodes in the network. #[allow(dead_code)] -pub fn create_network_with_2_base_nodes( - runtime: &mut Runtime, - data_path: &str, -) -> (NodeInterfaces, NodeInterfaces, ConsensusManager) { +pub async fn create_network_with_2_base_nodes(data_path: &str) -> (NodeInterfaces, NodeInterfaces, ConsensusManager) { let alice_node_identity = random_node_identity(); let bob_node_identity = random_node_identity(); @@ -238,22 +234,23 @@ pub fn create_network_with_2_base_nodes( let (alice_node, consensus_manager) = BaseNodeBuilder::new(network.into()) .with_node_identity(alice_node_identity.clone()) .with_peers(vec![bob_node_identity.clone()]) - .start(runtime, data_path); + .start(data_path) + .await; let (bob_node, consensus_manager) = BaseNodeBuilder::new(network.into()) .with_node_identity(bob_node_identity) .with_peers(vec![alice_node_identity]) .with_consensus_manager(consensus_manager) - .start(runtime, data_path); + .start(data_path) + .await; - wait_until_online(runtime, &[&alice_node, &bob_node]); + wait_until_online(&[&alice_node, &bob_node]).await; (alice_node, bob_node, consensus_manager) } // Creates a network with two Base Nodes where each node in the network knows the other nodes in the network. 
#[allow(dead_code)] -pub fn create_network_with_2_base_nodes_with_config>( - runtime: &mut Runtime, +pub async fn create_network_with_2_base_nodes_with_config>( base_node_service_config: BaseNodeServiceConfig, mempool_service_config: MempoolServiceConfig, liveness_service_config: LivenessConfig, @@ -269,7 +266,8 @@ pub fn create_network_with_2_base_nodes_with_config>( .with_mempool_service_config(mempool_service_config) .with_liveness_service_config(liveness_service_config.clone()) .with_consensus_manager(consensus_manager) - .start(runtime, data_path.as_ref().join("alice").as_os_str().to_str().unwrap()); + .start(data_path.as_ref().join("alice").as_os_str().to_str().unwrap()) + .await; let (bob_node, consensus_manager) = BaseNodeBuilder::new(network.into()) .with_node_identity(bob_node_identity) .with_peers(vec![alice_node_identity]) @@ -277,35 +275,34 @@ pub fn create_network_with_2_base_nodes_with_config>( .with_mempool_service_config(mempool_service_config) .with_liveness_service_config(liveness_service_config) .with_consensus_manager(consensus_manager) - .start(runtime, data_path.as_ref().join("bob").as_os_str().to_str().unwrap()); + .start(data_path.as_ref().join("bob").as_os_str().to_str().unwrap()) + .await; - wait_until_online(runtime, &[&alice_node, &bob_node]); + wait_until_online(&[&alice_node, &bob_node]).await; (alice_node, bob_node, consensus_manager) } // Creates a network with three Base Nodes where each node in the network knows the other nodes in the network. #[allow(dead_code)] -pub fn create_network_with_3_base_nodes( - runtime: &mut Runtime, +pub async fn create_network_with_3_base_nodes( data_path: &str, ) -> (NodeInterfaces, NodeInterfaces, NodeInterfaces, ConsensusManager) { let network = Network::LocalNet; let consensus_manager = ConsensusManagerBuilder::new(network).build(); create_network_with_3_base_nodes_with_config( - runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager, data_path, ) + .await } // Creates a network with three Base Nodes where each node in the network knows the other nodes in the network. 
#[allow(dead_code)] -pub fn create_network_with_3_base_nodes_with_config>( - runtime: &mut Runtime, +pub async fn create_network_with_3_base_nodes_with_config>( base_node_service_config: BaseNodeServiceConfig, mempool_service_config: MempoolServiceConfig, liveness_service_config: LivenessConfig, @@ -329,7 +326,8 @@ pub fn create_network_with_3_base_nodes_with_config>( .with_mempool_service_config(mempool_service_config) .with_liveness_service_config(liveness_service_config.clone()) .with_consensus_manager(consensus_manager) - .start(runtime, data_path.as_ref().join("carol").as_os_str().to_str().unwrap()); + .start(data_path.as_ref().join("carol").as_os_str().to_str().unwrap()) + .await; let (bob_node, consensus_manager) = BaseNodeBuilder::new(network.into()) .with_node_identity(bob_node_identity.clone()) .with_peers(vec![carol_node_identity.clone()]) @@ -337,7 +335,8 @@ pub fn create_network_with_3_base_nodes_with_config>( .with_mempool_service_config(mempool_service_config) .with_liveness_service_config(liveness_service_config.clone()) .with_consensus_manager(consensus_manager) - .start(runtime, data_path.as_ref().join("bob").as_os_str().to_str().unwrap()); + .start(data_path.as_ref().join("bob").as_os_str().to_str().unwrap()) + .await; let (alice_node, consensus_manager) = BaseNodeBuilder::new(network.into()) .with_node_identity(alice_node_identity) .with_peers(vec![bob_node_identity, carol_node_identity]) @@ -345,9 +344,10 @@ pub fn create_network_with_3_base_nodes_with_config>( .with_mempool_service_config(mempool_service_config) .with_liveness_service_config(liveness_service_config) .with_consensus_manager(consensus_manager) - .start(runtime, data_path.as_ref().join("alice").as_os_str().to_str().unwrap()); + .start(data_path.as_ref().join("alice").as_os_str().to_str().unwrap()) + .await; - wait_until_online(runtime, &[&alice_node, &bob_node, &carol_node]); + wait_until_online(&[&alice_node, &bob_node, &carol_node]).await; (alice_node, bob_node, carol_node, consensus_manager) } @@ -365,16 +365,12 @@ pub fn random_node_identity() -> Arc { // Helper function for starting the comms stack. #[allow(dead_code)] -async fn setup_comms_services( +async fn setup_comms_services( node_identity: Arc, peers: Vec>, - publisher: InboundDomainConnector, + publisher: InboundDomainConnector, data_path: &str, -) -> (CommsNode, Dht, MessagingEventSender, Shutdown) -where - TSink: Sink> + Clone + Unpin + Send + Sync + 'static, - TSink::Error: Error + Send + Sync, -{ +) -> (CommsNode, Dht, MessagingEventSender, Shutdown) { let peers = peers.into_iter().map(|p| p.to_peer()).collect(); let shutdown = Shutdown::new(); let (comms, dht, messaging_events) = initialize_local_test_comms( @@ -393,8 +389,7 @@ where // Helper function for starting the services of the Base node. 
#[allow(clippy::too_many_arguments)] -fn setup_base_node_services( - runtime: &mut Runtime, +async fn setup_base_node_services( node_identity: Arc, peers: Vec>, blockchain_db: BlockchainDatabase, @@ -405,14 +400,14 @@ fn setup_base_node_services( liveness_service_config: LivenessConfig, data_path: &str, ) -> NodeInterfaces { - let (publisher, subscription_factory) = pubsub_connector(runtime.handle().clone(), 100, 20); + let (publisher, subscription_factory) = pubsub_connector(100, 20); let subscription_factory = Arc::new(subscription_factory); let (comms, dht, messaging_events, shutdown) = - runtime.block_on(setup_comms_services(node_identity.clone(), peers, publisher, data_path)); + setup_comms_services(node_identity.clone(), peers, publisher, data_path).await; let mock_state_machine = MockBaseNodeStateMachine::new(); - let fut = StackBuilder::new(shutdown.to_signal()) + let handles = StackBuilder::new(shutdown.to_signal()) .add_initializer(RegisterHandle::new(dht)) .add_initializer(RegisterHandle::new(comms.connectivity())) .add_initializer(LivenessInitializer::new( @@ -433,9 +428,9 @@ fn setup_base_node_services( )) .add_initializer(mock_state_machine.get_initializer()) .add_initializer(ChainMetadataServiceInitializer) - .build(); - - let handles = runtime.block_on(fut).expect("Service initialization failed"); + .build() + .await + .unwrap(); let outbound_nci = handles.expect_handle::(); let local_nci = handles.expect_handle::(); diff --git a/base_layer/core/tests/mempool.rs b/base_layer/core/tests/mempool.rs index 80e187a99a..973769f664 100644 --- a/base_layer/core/tests/mempool.rs +++ b/base_layer/core/tests/mempool.rs @@ -66,11 +66,10 @@ use tari_crypto::script; use tari_p2p::{services::liveness::LivenessConfig, tari_message::TariMessageType}; use tari_test_utils::async_assert_eventually; use tempfile::tempdir; -use tokio::runtime::Runtime; -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn test_insert_and_process_published_block() { +async fn test_insert_and_process_published_block() { let network = Network::LocalNet; let (mut store, mut blocks, mut outputs, consensus_manager) = create_new_blockchain(network); let mempool_validator = TxInputAndMaturityValidator::new(store.clone()); @@ -201,9 +200,9 @@ fn test_insert_and_process_published_block() { assert_eq!(stats.total_weight, 30); } -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn test_time_locked() { +async fn test_time_locked() { let network = Network::LocalNet; let (mut store, mut blocks, mut outputs, consensus_manager) = create_new_blockchain(network); let mempool_validator = TxInputAndMaturityValidator::new(store.clone()); @@ -245,9 +244,9 @@ fn test_time_locked() { assert_eq!(mempool.insert(tx2).unwrap(), TxStorageResponse::UnconfirmedPool); } -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn test_retrieve() { +async fn test_retrieve() { let network = Network::LocalNet; let (mut store, mut blocks, mut outputs, consensus_manager) = create_new_blockchain(network); let mempool_validator = TxInputAndMaturityValidator::new(store.clone()); @@ -331,9 +330,9 @@ fn test_retrieve() { assert!(retrieved_txs.contains(&tx2[1])); } -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn test_zero_conf() { +async fn test_zero_conf() { let network = Network::LocalNet; let (mut store, mut blocks, mut outputs, consensus_manager) = create_new_blockchain(network); let mempool_validator = TxInputAndMaturityValidator::new(store.clone()); @@ -631,9 +630,9 @@ fn test_zero_conf() { 
assert!(retrieved_txs.contains(&Arc::new(tx34))); } -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn test_reorg() { +async fn test_reorg() { let network = Network::LocalNet; let (mut db, mut blocks, mut outputs, consensus_manager) = create_new_blockchain(network); let mempool_validator = TxInputAndMaturityValidator::new(db.clone()); @@ -712,13 +711,12 @@ fn test_reorg() { mempool.process_reorg(vec![], vec![reorg_block4.into()]).unwrap(); } -#[test] // TODO: This test returns 0 in the unconfirmed pool, so might not catch errors. It should be updated to return better // data #[allow(clippy::identity_op)] -fn request_response_get_stats() { +#[tokio::test] +async fn request_response_get_stats() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; let consensus_constants = ConsensusConstantsBuilder::new(network) @@ -731,13 +729,13 @@ fn request_response_get_stats() { .with_block(block0) .build(); let (mut alice, bob, _consensus_manager) = create_network_with_2_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager, temp_dir.path(), - ); + ) + .await; // Create a tx spending the genesis output. Then create 2 orphan txs let (tx1, _, _) = spend_utxos(txn_schema!(from: vec![utxo], to: vec![2 * T, 2 * T, 2 * T])); @@ -759,21 +757,18 @@ fn request_response_get_stats() { assert_eq!(stats.reorg_txs, 0); assert_eq!(stats.total_weight, 0); - runtime.block_on(async { - // Alice will request mempool stats from Bob, and thus should be identical - let received_stats = alice.outbound_mp_interface.get_stats().await.unwrap(); - assert_eq!(received_stats.total_txs, 0); - assert_eq!(received_stats.unconfirmed_txs, 0); - assert_eq!(received_stats.reorg_txs, 0); - assert_eq!(received_stats.total_weight, 0); - }); + // Alice will request mempool stats from Bob, and thus should be identical + let received_stats = alice.outbound_mp_interface.get_stats().await.unwrap(); + assert_eq!(received_stats.total_txs, 0); + assert_eq!(received_stats.unconfirmed_txs, 0); + assert_eq!(received_stats.reorg_txs, 0); + assert_eq!(received_stats.total_weight, 0); } -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn request_response_get_tx_state_by_excess_sig() { +async fn request_response_get_tx_state_by_excess_sig() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; let consensus_constants = ConsensusConstantsBuilder::new(network) @@ -786,13 +781,13 @@ fn request_response_get_tx_state_by_excess_sig() { .with_block(block0) .build(); let (mut alice_node, bob_node, carol_node, _consensus_manager) = create_network_with_3_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager, temp_dir.path().to_str().unwrap(), - ); + ) + .await; let (tx, _, _) = spend_utxos(txn_schema!(from: vec![utxo.clone()], to: vec![2 * T, 2 * T, 2 * T])); let (unpublished_tx, _, _) = spend_utxos(txn_schema!(from: vec![utxo], to: vec![3 * T])); @@ -807,43 +802,40 @@ fn request_response_get_tx_state_by_excess_sig() { // Check that the transactions are in the expected pools. // Spending the coinbase utxo will be in the pending pool, because cb utxos have a maturity. 
// The orphan tx will be in the orphan pool, while the unadded tx won't be found - runtime.block_on(async { - let tx_excess_sig = tx.body.kernels()[0].excess_sig.clone(); - let unpublished_tx_excess_sig = unpublished_tx.body.kernels()[0].excess_sig.clone(); - let orphan_tx_excess_sig = orphan_tx.body.kernels()[0].excess_sig.clone(); - assert_eq!( - alice_node - .outbound_mp_interface - .get_tx_state_by_excess_sig(tx_excess_sig) - .await - .unwrap(), - TxStorageResponse::NotStored - ); - assert_eq!( - alice_node - .outbound_mp_interface - .get_tx_state_by_excess_sig(unpublished_tx_excess_sig) - .await - .unwrap(), - TxStorageResponse::NotStored - ); - assert_eq!( - alice_node - .outbound_mp_interface - .get_tx_state_by_excess_sig(orphan_tx_excess_sig) - .await - .unwrap(), - TxStorageResponse::NotStored - ); - }); + let tx_excess_sig = tx.body.kernels()[0].excess_sig.clone(); + let unpublished_tx_excess_sig = unpublished_tx.body.kernels()[0].excess_sig.clone(); + let orphan_tx_excess_sig = orphan_tx.body.kernels()[0].excess_sig.clone(); + assert_eq!( + alice_node + .outbound_mp_interface + .get_tx_state_by_excess_sig(tx_excess_sig) + .await + .unwrap(), + TxStorageResponse::NotStored + ); + assert_eq!( + alice_node + .outbound_mp_interface + .get_tx_state_by_excess_sig(unpublished_tx_excess_sig) + .await + .unwrap(), + TxStorageResponse::NotStored + ); + assert_eq!( + alice_node + .outbound_mp_interface + .get_tx_state_by_excess_sig(orphan_tx_excess_sig) + .await + .unwrap(), + TxStorageResponse::NotStored + ); } static EMISSION: [u64; 2] = [10, 10]; -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn receive_and_propagate_transaction() { +async fn receive_and_propagate_transaction() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; let consensus_constants = ConsensusConstantsBuilder::new(network) @@ -857,13 +849,13 @@ fn receive_and_propagate_transaction() { .build(); let (mut alice_node, mut bob_node, mut carol_node, _consensus_manager) = create_network_with_3_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager, temp_dir.path().to_str().unwrap(), - ); + ) + .await; alice_node.mock_base_node_state_machine.publish_status(StatusInfo { bootstrapped: true, state_info: StateInfo::Listening(ListeningInfo::new(true)), @@ -884,63 +876,61 @@ fn receive_and_propagate_transaction() { assert!(alice_node.mempool.insert(Arc::new(tx.clone())).is_ok()); assert!(alice_node.mempool.insert(Arc::new(orphan.clone())).is_ok()); - runtime.block_on(async { - alice_node - .outbound_message_service - .send_direct( - bob_node.node_identity.public_key().clone(), - OutboundDomainMessage::new(TariMessageType::NewTransaction, proto::types::Transaction::from(tx)), - ) - .await - .unwrap(); - alice_node - .outbound_message_service - .send_direct( - carol_node.node_identity.public_key().clone(), - OutboundDomainMessage::new(TariMessageType::NewTransaction, proto::types::Transaction::from(orphan)), - ) - .await - .unwrap(); + alice_node + .outbound_message_service + .send_direct( + bob_node.node_identity.public_key().clone(), + OutboundDomainMessage::new(TariMessageType::NewTransaction, proto::types::Transaction::from(tx)), + ) + .await + .unwrap(); + alice_node + .outbound_message_service + .send_direct( + carol_node.node_identity.public_key().clone(), + 
OutboundDomainMessage::new(TariMessageType::NewTransaction, proto::types::Transaction::from(orphan)), + ) + .await + .unwrap(); - async_assert_eventually!( - bob_node.mempool.has_tx_with_excess_sig(tx_excess_sig.clone()).unwrap(), - expect = TxStorageResponse::NotStored, - max_attempts = 20, - interval = Duration::from_millis(1000) - ); - async_assert_eventually!( - carol_node - .mempool - .has_tx_with_excess_sig(tx_excess_sig.clone()) - .unwrap(), - expect = TxStorageResponse::NotStored, - max_attempts = 10, - interval = Duration::from_millis(1000) - ); - // Carol got sent the orphan tx directly, so it will be in her mempool - async_assert_eventually!( - carol_node - .mempool - .has_tx_with_excess_sig(orphan_excess_sig.clone()) - .unwrap(), - expect = TxStorageResponse::NotStored, - max_attempts = 10, - interval = Duration::from_millis(1000) - ); - // It's difficult to test a negative here, but let's at least make sure that the orphan TX was not propagated - // by the time we check it - async_assert_eventually!( - bob_node - .mempool - .has_tx_with_excess_sig(orphan_excess_sig.clone()) - .unwrap(), - expect = TxStorageResponse::NotStored, - ); - }); + async_assert_eventually!( + bob_node.mempool.has_tx_with_excess_sig(tx_excess_sig.clone()).unwrap(), + expect = TxStorageResponse::NotStored, + max_attempts = 20, + interval = Duration::from_millis(1000) + ); + async_assert_eventually!( + carol_node + .mempool + .has_tx_with_excess_sig(tx_excess_sig.clone()) + .unwrap(), + expect = TxStorageResponse::NotStored, + max_attempts = 10, + interval = Duration::from_millis(1000) + ); + // Carol got sent the orphan tx directly, so it will be in her mempool + async_assert_eventually!( + carol_node + .mempool + .has_tx_with_excess_sig(orphan_excess_sig.clone()) + .unwrap(), + expect = TxStorageResponse::NotStored, + max_attempts = 10, + interval = Duration::from_millis(1000) + ); + // It's difficult to test a negative here, but let's at least make sure that the orphan TX was not propagated + // by the time we check it + async_assert_eventually!( + bob_node + .mempool + .has_tx_with_excess_sig(orphan_excess_sig.clone()) + .unwrap(), + expect = TxStorageResponse::NotStored, + ); } -#[test] -fn consensus_validation_large_tx() { +#[tokio::test] +async fn consensus_validation_large_tx() { let network = Network::LocalNet; // We dont want to compute the 19500 limit of local net, so we create smaller blocks let consensus_constants = ConsensusConstantsBuilder::new(network) @@ -1042,9 +1032,8 @@ fn consensus_validation_large_tx() { assert!(matches!(response, TxStorageResponse::NotStored)); } -#[test] -fn service_request_timeout() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn service_request_timeout() { let network = Network::LocalNet; let consensus_manager = ConsensusManager::builder(network).build(); let mempool_service_config = MempoolServiceConfig { @@ -1053,27 +1042,25 @@ fn service_request_timeout() { }; let temp_dir = tempdir().unwrap(); let (mut alice_node, bob_node, _consensus_manager) = create_network_with_2_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), mempool_service_config, LivenessConfig::default(), consensus_manager, temp_dir.path().to_str().unwrap(), - ); + ) + .await; - runtime.block_on(async { - bob_node.shutdown().await; + bob_node.shutdown().await; - match alice_node.outbound_mp_interface.get_stats().await { - Err(MempoolServiceError::RequestTimedOut) => {}, - _ => panic!(), - } - }); + match 
alice_node.outbound_mp_interface.get_stats().await { + Err(MempoolServiceError::RequestTimedOut) => {}, + _ => panic!(), + } } -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn block_event_and_reorg_event_handling() { +async fn block_event_and_reorg_event_handling() { // This test creates 2 nodes Alice and Bob // Then creates 2 chains B1 -> B2A (diff 1) and B1 -> B2B (diff 10) // There are 5 transactions created @@ -1086,7 +1073,6 @@ fn block_event_and_reorg_event_handling() { let network = Network::LocalNet; let consensus_constants = NetworkConsensus::from(network).create_consensus_constants(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let (block0, utxos0) = create_genesis_block_with_coinbase_value(&factories, 100_000_000.into(), &consensus_constants[0]); @@ -1095,13 +1081,13 @@ fn block_event_and_reorg_event_handling() { .with_block(block0.clone()) .build(); let (mut alice, mut bob, consensus_manager) = create_network_with_2_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager, temp_dir.path().to_str().unwrap(), - ); + ) + .await; alice.mock_base_node_state_machine.publish_status(StatusInfo { bootstrapped: true, state_info: StateInfo::Listening(ListeningInfo::new(true)), @@ -1135,88 +1121,86 @@ fn block_event_and_reorg_event_handling() { .prepare_block_merkle_roots(chain_block(block0.block(), vec![], &consensus_manager)) .unwrap(); - runtime.block_on(async { - // Add one empty block, so the coinbase UTXO is no longer time-locked. - assert!(bob - .local_nci - .submit_block(empty_block.clone(), Broadcast::from(true)) - .await - .is_ok()); - assert!(alice - .local_nci - .submit_block(empty_block.clone(), Broadcast::from(true)) - .await - .is_ok()); - alice.mempool.insert(Arc::new(tx1.clone())).unwrap(); - bob.mempool.insert(Arc::new(tx1.clone())).unwrap(); - let mut block1 = bob - .blockchain_db - .prepare_block_merkle_roots(chain_block(&empty_block, vec![tx1], &consensus_manager)) - .unwrap(); - find_header_with_achieved_difficulty(&mut block1.header, Difficulty::from(1)); - // Add Block1 - tx1 will be moved to the ReorgPool. - assert!(bob - .local_nci - .submit_block(block1.clone(), Broadcast::from(true)) - .await - .is_ok()); - async_assert_eventually!( - alice.mempool.has_tx_with_excess_sig(tx1_excess_sig.clone()).unwrap(), - expect = TxStorageResponse::ReorgPool, - max_attempts = 20, - interval = Duration::from_millis(1000) - ); - alice.mempool.insert(Arc::new(tx2a.clone())).unwrap(); - alice.mempool.insert(Arc::new(tx3a.clone())).unwrap(); - alice.mempool.insert(Arc::new(tx2b.clone())).unwrap(); - alice.mempool.insert(Arc::new(tx3b.clone())).unwrap(); - bob.mempool.insert(Arc::new(tx2a.clone())).unwrap(); - bob.mempool.insert(Arc::new(tx3a.clone())).unwrap(); - bob.mempool.insert(Arc::new(tx2b.clone())).unwrap(); - bob.mempool.insert(Arc::new(tx3b.clone())).unwrap(); - - let mut block2a = bob - .blockchain_db - .prepare_block_merkle_roots(chain_block(&block1, vec![tx2a, tx3a], &consensus_manager)) - .unwrap(); - find_header_with_achieved_difficulty(&mut block2a.header, Difficulty::from(1)); - // Block2b also builds on Block1 but has a stronger PoW - let mut block2b = bob - .blockchain_db - .prepare_block_merkle_roots(chain_block(&block1, vec![tx2b, tx3b], &consensus_manager)) - .unwrap(); - find_header_with_achieved_difficulty(&mut block2b.header, Difficulty::from(10)); - - // Add Block2a - tx2b and tx3b will be discarded as double spends. 
- assert!(bob - .local_nci - .submit_block(block2a.clone(), Broadcast::from(true)) - .await - .is_ok()); - - async_assert_eventually!( - bob.mempool.has_tx_with_excess_sig(tx2a_excess_sig.clone()).unwrap(), - expect = TxStorageResponse::ReorgPool, - max_attempts = 20, - interval = Duration::from_millis(1000) - ); - async_assert_eventually!( - alice.mempool.has_tx_with_excess_sig(tx2a_excess_sig.clone()).unwrap(), - expect = TxStorageResponse::ReorgPool, - max_attempts = 20, - interval = Duration::from_millis(1000) - ); - assert_eq!( - alice.mempool.has_tx_with_excess_sig(tx3a_excess_sig.clone()).unwrap(), - TxStorageResponse::ReorgPool - ); - assert_eq!( - alice.mempool.has_tx_with_excess_sig(tx2b_excess_sig.clone()).unwrap(), - TxStorageResponse::ReorgPool - ); - assert_eq!( - alice.mempool.has_tx_with_excess_sig(tx3b_excess_sig.clone()).unwrap(), - TxStorageResponse::ReorgPool - ); - }); + // Add one empty block, so the coinbase UTXO is no longer time-locked. + assert!(bob + .local_nci + .submit_block(empty_block.clone(), Broadcast::from(true)) + .await + .is_ok()); + assert!(alice + .local_nci + .submit_block(empty_block.clone(), Broadcast::from(true)) + .await + .is_ok()); + alice.mempool.insert(Arc::new(tx1.clone())).unwrap(); + bob.mempool.insert(Arc::new(tx1.clone())).unwrap(); + let mut block1 = bob + .blockchain_db + .prepare_block_merkle_roots(chain_block(&empty_block, vec![tx1], &consensus_manager)) + .unwrap(); + find_header_with_achieved_difficulty(&mut block1.header, Difficulty::from(1)); + // Add Block1 - tx1 will be moved to the ReorgPool. + assert!(bob + .local_nci + .submit_block(block1.clone(), Broadcast::from(true)) + .await + .is_ok()); + async_assert_eventually!( + alice.mempool.has_tx_with_excess_sig(tx1_excess_sig.clone()).unwrap(), + expect = TxStorageResponse::ReorgPool, + max_attempts = 20, + interval = Duration::from_millis(1000) + ); + alice.mempool.insert(Arc::new(tx2a.clone())).unwrap(); + alice.mempool.insert(Arc::new(tx3a.clone())).unwrap(); + alice.mempool.insert(Arc::new(tx2b.clone())).unwrap(); + alice.mempool.insert(Arc::new(tx3b.clone())).unwrap(); + bob.mempool.insert(Arc::new(tx2a.clone())).unwrap(); + bob.mempool.insert(Arc::new(tx3a.clone())).unwrap(); + bob.mempool.insert(Arc::new(tx2b.clone())).unwrap(); + bob.mempool.insert(Arc::new(tx3b.clone())).unwrap(); + + let mut block2a = bob + .blockchain_db + .prepare_block_merkle_roots(chain_block(&block1, vec![tx2a, tx3a], &consensus_manager)) + .unwrap(); + find_header_with_achieved_difficulty(&mut block2a.header, Difficulty::from(1)); + // Block2b also builds on Block1 but has a stronger PoW + let mut block2b = bob + .blockchain_db + .prepare_block_merkle_roots(chain_block(&block1, vec![tx2b, tx3b], &consensus_manager)) + .unwrap(); + find_header_with_achieved_difficulty(&mut block2b.header, Difficulty::from(10)); + + // Add Block2a - tx2b and tx3b will be discarded as double spends. 
+ assert!(bob + .local_nci + .submit_block(block2a.clone(), Broadcast::from(true)) + .await + .is_ok()); + + async_assert_eventually!( + bob.mempool.has_tx_with_excess_sig(tx2a_excess_sig.clone()).unwrap(), + expect = TxStorageResponse::ReorgPool, + max_attempts = 20, + interval = Duration::from_millis(1000) + ); + async_assert_eventually!( + alice.mempool.has_tx_with_excess_sig(tx2a_excess_sig.clone()).unwrap(), + expect = TxStorageResponse::ReorgPool, + max_attempts = 20, + interval = Duration::from_millis(1000) + ); + assert_eq!( + alice.mempool.has_tx_with_excess_sig(tx3a_excess_sig.clone()).unwrap(), + TxStorageResponse::ReorgPool + ); + assert_eq!( + alice.mempool.has_tx_with_excess_sig(tx2b_excess_sig.clone()).unwrap(), + TxStorageResponse::ReorgPool + ); + assert_eq!( + alice.mempool.has_tx_with_excess_sig(tx3b_excess_sig.clone()).unwrap(), + TxStorageResponse::ReorgPool + ); } diff --git a/base_layer/core/tests/node_comms_interface.rs b/base_layer/core/tests/node_comms_interface.rs index 532d102bbe..97851f0d17 100644 --- a/base_layer/core/tests/node_comms_interface.rs +++ b/base_layer/core/tests/node_comms_interface.rs @@ -22,7 +22,7 @@ #[allow(dead_code)] mod helpers; -use futures::{channel::mpsc, StreamExt}; +use futures::StreamExt; use helpers::block_builders::append_block; use std::sync::Arc; use tari_common::configuration::Network; @@ -55,7 +55,7 @@ use tari_crypto::{ tari_utilities::hash::Hashable, }; use tari_service_framework::{reply_channel, reply_channel::Receiver}; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc}; // use crate::helpers::database::create_test_db; async fn test_request_responder( @@ -71,10 +71,10 @@ fn new_mempool() -> Mempool { Mempool::new(MempoolConfig::default(), Arc::new(mempool_validator)) } -#[tokio_macros::test] +#[tokio::test] async fn outbound_get_metadata() { let (request_sender, mut request_receiver) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let mut outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let metadata = ChainMetadata::new(5, vec![0u8], 3, 0, 5); @@ -86,7 +86,7 @@ async fn outbound_get_metadata() { assert_eq!(received_metadata.unwrap(), metadata); } -#[tokio_macros::test] +#[tokio::test] async fn inbound_get_metadata() { let store = create_test_blockchain_db(); let mempool = new_mempool(); @@ -95,7 +95,7 @@ async fn inbound_get_metadata() { let consensus_manager = ConsensusManager::builder(network).build(); let (block_event_sender, _) = broadcast::channel(50); let (request_sender, _) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender.clone()); let inbound_nch = InboundNodeCommsHandlers::new( block_event_sender, @@ -117,7 +117,7 @@ async fn inbound_get_metadata() { } } -#[tokio_macros::test] +#[tokio::test] async fn inbound_fetch_kernel_by_excess_sig() { let store = create_test_blockchain_db(); let mempool = new_mempool(); @@ -126,7 +126,7 @@ async fn inbound_fetch_kernel_by_excess_sig() { let consensus_manager = ConsensusManager::builder(network).build(); let (block_event_sender, _) = broadcast::channel(50); let (request_sender, _) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender.clone()); let 
inbound_nch = InboundNodeCommsHandlers::new( block_event_sender, @@ -149,10 +149,10 @@ async fn inbound_fetch_kernel_by_excess_sig() { } } -#[tokio_macros::test] +#[tokio::test] async fn outbound_fetch_headers() { let (request_sender, mut request_receiver) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let mut outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let mut header = BlockHeader::new(0); @@ -167,7 +167,7 @@ async fn outbound_fetch_headers() { assert_eq!(received_headers[0], header); } -#[tokio_macros::test] +#[tokio::test] async fn inbound_fetch_headers() { let store = create_test_blockchain_db(); let mempool = new_mempool(); @@ -175,7 +175,7 @@ async fn inbound_fetch_headers() { let consensus_manager = ConsensusManager::builder(network).build(); let (block_event_sender, _) = broadcast::channel(50); let (request_sender, _) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let inbound_nch = InboundNodeCommsHandlers::new( block_event_sender, @@ -197,11 +197,11 @@ async fn inbound_fetch_headers() { } } -#[tokio_macros::test] +#[tokio::test] async fn outbound_fetch_utxos() { let factories = CryptoFactories::default(); let (request_sender, mut request_receiver) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let mut outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let (utxo, _, _) = create_utxo( @@ -221,7 +221,7 @@ async fn outbound_fetch_utxos() { assert_eq!(received_utxos[0], utxo); } -#[tokio_macros::test] +#[tokio::test] async fn inbound_fetch_utxos() { let factories = CryptoFactories::default(); @@ -231,7 +231,7 @@ async fn inbound_fetch_utxos() { let consensus_manager = ConsensusManager::builder(network).build(); let (block_event_sender, _) = broadcast::channel(50); let (request_sender, _) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let inbound_nch = InboundNodeCommsHandlers::new( block_event_sender, @@ -264,11 +264,11 @@ async fn inbound_fetch_utxos() { } } -#[tokio_macros::test] +#[tokio::test] async fn outbound_fetch_txos() { let factories = CryptoFactories::default(); let (request_sender, mut request_receiver) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let mut outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let (txo1, _, _) = create_utxo( @@ -296,7 +296,7 @@ async fn outbound_fetch_txos() { assert_eq!(received_txos[1], txo2); } -#[tokio_macros::test] +#[tokio::test] async fn inbound_fetch_txos() { let factories = CryptoFactories::default(); let store = create_test_blockchain_db(); @@ -305,7 +305,7 @@ async fn inbound_fetch_txos() { let network = Network::LocalNet; let consensus_manager = ConsensusManager::builder(network).build(); let (request_sender, _) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let inbound_nch = InboundNodeCommsHandlers::new( block_event_sender, 
@@ -366,10 +366,10 @@ async fn inbound_fetch_txos() { } } -#[tokio_macros::test] +#[tokio::test] async fn outbound_fetch_blocks() { let (request_sender, mut request_receiver) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let mut outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let network = Network::LocalNet; let consensus_constants = NetworkConsensus::from(network).create_consensus_constants(); @@ -385,7 +385,7 @@ async fn outbound_fetch_blocks() { assert_eq!(received_blocks[0], block); } -#[tokio_macros::test] +#[tokio::test] async fn inbound_fetch_blocks() { let store = create_test_blockchain_db(); let mempool = new_mempool(); @@ -393,7 +393,7 @@ async fn inbound_fetch_blocks() { let network = Network::LocalNet; let consensus_manager = ConsensusManager::builder(network).build(); let (request_sender, _) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let inbound_nch = InboundNodeCommsHandlers::new( block_event_sender, @@ -415,7 +415,7 @@ async fn inbound_fetch_blocks() { } } -#[tokio_macros::test] +#[tokio::test] // Test needs to be updated to new pruned structure. async fn inbound_fetch_blocks_before_horizon_height() { let factories = CryptoFactories::default(); @@ -437,7 +437,7 @@ async fn inbound_fetch_blocks_before_horizon_height() { let mempool = Mempool::new(MempoolConfig::default(), Arc::new(mempool_validator)); let (block_event_sender, _) = broadcast::channel(50); let (request_sender, _) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let inbound_nch = InboundNodeCommsHandlers::new( block_event_sender, diff --git a/base_layer/core/tests/node_service.rs b/base_layer/core/tests/node_service.rs index 14f277f19b..d806b9b646 100644 --- a/base_layer/core/tests/node_service.rs +++ b/base_layer/core/tests/node_service.rs @@ -23,7 +23,6 @@ #[allow(dead_code)] mod helpers; use crate::helpers::block_builders::{construct_chained_blocks, create_coinbase}; -use futures::join; use helpers::{ block_builders::{ append_block, @@ -72,11 +71,9 @@ use tari_crypto::tari_utilities::hash::Hashable; use tari_p2p::services::liveness::LivenessConfig; use tari_test_utils::unpack_enum; use tempfile::tempdir; -use tokio::runtime::Runtime; -#[test] -fn request_response_get_metadata() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn request_response_get_metadata() { let factories = CryptoFactories::default(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; @@ -89,27 +86,24 @@ fn request_response_get_metadata() { .with_block(block0) .build(); let (mut alice_node, bob_node, carol_node, _consensus_manager) = create_network_with_3_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager, temp_dir.path().to_str().unwrap(), - ); + ) + .await; - runtime.block_on(async { - let received_metadata = alice_node.outbound_nci.get_metadata().await.unwrap(); - assert_eq!(received_metadata.height_of_longest_chain(), 0); + let received_metadata = alice_node.outbound_nci.get_metadata().await.unwrap(); + assert_eq!(received_metadata.height_of_longest_chain(), 0); - 
alice_node.shutdown().await; - bob_node.shutdown().await; - carol_node.shutdown().await; - }); + alice_node.shutdown().await; + bob_node.shutdown().await; + carol_node.shutdown().await; } -#[test] -fn request_and_response_fetch_blocks() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn request_and_response_fetch_blocks() { let factories = CryptoFactories::default(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; @@ -122,13 +116,13 @@ fn request_and_response_fetch_blocks() { .with_block(block0.clone()) .build(); let (mut alice_node, mut bob_node, carol_node, _) = create_network_with_3_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager.clone(), temp_dir.path().to_str().unwrap(), - ); + ) + .await; let mut blocks = vec![block0]; let db = &mut bob_node.blockchain_db; @@ -147,26 +141,23 @@ fn request_and_response_fetch_blocks() { .unwrap() .assert_added(); - runtime.block_on(async { - let received_blocks = alice_node.outbound_nci.fetch_blocks(vec![0]).await.unwrap(); - assert_eq!(received_blocks.len(), 1); - assert_eq!(received_blocks[0].block(), blocks[0].block()); + let received_blocks = alice_node.outbound_nci.fetch_blocks(vec![0]).await.unwrap(); + assert_eq!(received_blocks.len(), 1); + assert_eq!(received_blocks[0].block(), blocks[0].block()); - let received_blocks = alice_node.outbound_nci.fetch_blocks(vec![0, 1]).await.unwrap(); - assert_eq!(received_blocks.len(), 2); - assert_ne!(*received_blocks[0].block(), *received_blocks[1].block()); - assert!(received_blocks[0].block() == blocks[0].block() || received_blocks[1].block() == blocks[0].block()); - assert!(received_blocks[0].block() == blocks[1].block() || received_blocks[1].block() == blocks[1].block()); + let received_blocks = alice_node.outbound_nci.fetch_blocks(vec![0, 1]).await.unwrap(); + assert_eq!(received_blocks.len(), 2); + assert_ne!(*received_blocks[0].block(), *received_blocks[1].block()); + assert!(received_blocks[0].block() == blocks[0].block() || received_blocks[1].block() == blocks[0].block()); + assert!(received_blocks[0].block() == blocks[1].block() || received_blocks[1].block() == blocks[1].block()); - alice_node.shutdown().await; - bob_node.shutdown().await; - carol_node.shutdown().await; - }); + alice_node.shutdown().await; + bob_node.shutdown().await; + carol_node.shutdown().await; } -#[test] -fn request_and_response_fetch_blocks_with_hashes() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn request_and_response_fetch_blocks_with_hashes() { let factories = CryptoFactories::default(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; @@ -179,13 +170,13 @@ fn request_and_response_fetch_blocks_with_hashes() { .with_block(block0.clone()) .build(); let (mut alice_node, mut bob_node, carol_node, _) = create_network_with_3_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager.clone(), temp_dir.path().to_str().unwrap(), - ); + ) + .await; let mut blocks = vec![block0]; let db = &mut bob_node.blockchain_db; @@ -206,34 +197,31 @@ fn request_and_response_fetch_blocks_with_hashes() { .unwrap() .assert_added(); - runtime.block_on(async { - let received_blocks = alice_node - .outbound_nci - .fetch_blocks_with_hashes(vec![block0_hash.clone()]) - .await - .unwrap(); - assert_eq!(received_blocks.len(), 1); - assert_eq!(received_blocks[0].block(), 
blocks[0].block()); + let received_blocks = alice_node + .outbound_nci + .fetch_blocks_with_hashes(vec![block0_hash.clone()]) + .await + .unwrap(); + assert_eq!(received_blocks.len(), 1); + assert_eq!(received_blocks[0].block(), blocks[0].block()); - let received_blocks = alice_node - .outbound_nci - .fetch_blocks_with_hashes(vec![block0_hash.clone(), block1_hash.clone()]) - .await - .unwrap(); - assert_eq!(received_blocks.len(), 2); - assert_ne!(received_blocks[0], received_blocks[1]); - assert!(received_blocks[0].block() == blocks[0].block() || received_blocks[1].block() == blocks[0].block()); - assert!(received_blocks[0].block() == blocks[1].block() || received_blocks[1].block() == blocks[1].block()); - - alice_node.shutdown().await; - bob_node.shutdown().await; - carol_node.shutdown().await; - }); + let received_blocks = alice_node + .outbound_nci + .fetch_blocks_with_hashes(vec![block0_hash.clone(), block1_hash.clone()]) + .await + .unwrap(); + assert_eq!(received_blocks.len(), 2); + assert_ne!(received_blocks[0], received_blocks[1]); + assert!(received_blocks[0].block() == blocks[0].block() || received_blocks[1].block() == blocks[0].block()); + assert!(received_blocks[0].block() == blocks[1].block() || received_blocks[1].block() == blocks[1].block()); + + alice_node.shutdown().await; + bob_node.shutdown().await; + carol_node.shutdown().await; } -#[test] -fn propagate_and_forward_many_valid_blocks() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn propagate_and_forward_many_valid_blocks() { let temp_dir = tempdir().unwrap(); let factories = CryptoFactories::default(); // Alice will propagate a number of block hashes to bob, bob will receive it, request the full block, verify and @@ -261,24 +249,28 @@ fn propagate_and_forward_many_valid_blocks() { let (mut alice_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(alice_node_identity.clone()) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("alice").to_str().unwrap()); + .start(temp_dir.path().join("alice").to_str().unwrap()) + .await; let (mut bob_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(bob_node_identity.clone()) .with_peers(vec![alice_node_identity]) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("bob").to_str().unwrap()); + .start(temp_dir.path().join("bob").to_str().unwrap()) + .await; let (mut carol_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(carol_node_identity.clone()) .with_peers(vec![bob_node_identity.clone()]) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("carol").to_str().unwrap()); + .start(temp_dir.path().join("carol").to_str().unwrap()) + .await; let (mut dan_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(dan_node_identity) .with_peers(vec![carol_node_identity, bob_node_identity]) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("dan").to_str().unwrap()); + .start(temp_dir.path().join("dan").to_str().unwrap()) + .await; - wait_until_online(&mut runtime, &[&alice_node, &bob_node, &carol_node, &dan_node]); + wait_until_online(&[&alice_node, &bob_node, &carol_node, &dan_node]).await; alice_node.mock_base_node_state_machine.publish_status(StatusInfo { bootstrapped: true, state_info: StateInfo::Listening(ListeningInfo::new(true)), @@ -302,56 +294,49 @@ fn propagate_and_forward_many_valid_blocks() { let blocks = construct_chained_blocks(&alice_node.blockchain_db, block0, &rules, 5); - 
runtime.block_on(async { - for block in &blocks { - alice_node - .outbound_nci - .propagate_block(NewBlock::from(block.block()), vec![]) - .await - .unwrap(); - - let bob_block_event_fut = event_stream_next(&mut bob_block_event_stream, Duration::from_millis(20000)); - let carol_block_event_fut = event_stream_next(&mut carol_block_event_stream, Duration::from_millis(20000)); - let dan_block_event_fut = event_stream_next(&mut dan_block_event_stream, Duration::from_millis(20000)); - let (bob_block_event, carol_block_event, dan_block_event) = - join!(bob_block_event_fut, carol_block_event_fut, dan_block_event_fut); - let block_hash = block.hash(); - - if let BlockEvent::ValidBlockAdded(received_block, _, _) = &*bob_block_event.unwrap().unwrap() { - assert_eq!(&received_block.hash(), block_hash); - } else { - panic!("Bob's node did not receive and validate the expected block"); - } - if let BlockEvent::ValidBlockAdded(received_block, _block_add_result, _) = - &*carol_block_event.unwrap().unwrap() - { - assert_eq!(&received_block.hash(), block_hash); - } else { - panic!("Carol's node did not receive and validate the expected block"); - } - if let BlockEvent::ValidBlockAdded(received_block, _block_add_result, _) = - &*dan_block_event.unwrap().unwrap() - { - assert_eq!(&received_block.hash(), block_hash); - } else { - panic!("Dan's node did not receive and validate the expected block"); - } + for block in &blocks { + alice_node + .outbound_nci + .propagate_block(NewBlock::from(block.block()), vec![]) + .await + .unwrap(); + + let bob_block_event_fut = event_stream_next(&mut bob_block_event_stream, Duration::from_millis(20000)); + let carol_block_event_fut = event_stream_next(&mut carol_block_event_stream, Duration::from_millis(20000)); + let dan_block_event_fut = event_stream_next(&mut dan_block_event_stream, Duration::from_millis(20000)); + let (bob_block_event, carol_block_event, dan_block_event) = + tokio::join!(bob_block_event_fut, carol_block_event_fut, dan_block_event_fut); + let block_hash = block.hash(); + + if let BlockEvent::ValidBlockAdded(received_block, _, _) = &*bob_block_event.unwrap() { + assert_eq!(&received_block.hash(), block_hash); + } else { + panic!("Bob's node did not receive and validate the expected block"); } + if let BlockEvent::ValidBlockAdded(received_block, _block_add_result, _) = &*carol_block_event.unwrap() { + assert_eq!(&received_block.hash(), block_hash); + } else { + panic!("Carol's node did not receive and validate the expected block"); + } + if let BlockEvent::ValidBlockAdded(received_block, _block_add_result, _) = &*dan_block_event.unwrap() { + assert_eq!(&received_block.hash(), block_hash); + } else { + panic!("Dan's node did not receive and validate the expected block"); + } + } - alice_node.shutdown().await; - bob_node.shutdown().await; - carol_node.shutdown().await; - dan_node.shutdown().await; - }); + alice_node.shutdown().await; + bob_node.shutdown().await; + carol_node.shutdown().await; + dan_node.shutdown().await; } static EMISSION: [u64; 2] = [10, 10]; -#[test] -fn propagate_and_forward_invalid_block_hash() { +#[tokio::test] +async fn propagate_and_forward_invalid_block_hash() { // Alice will propagate a "made up" block hash to Bob, Bob will request the block from Alice. Alice will not be able // to provide the block and so Bob will not propagate the hash further to Carol. 
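// Illustrative sketch, not part of the patch. The hunks above repeat one mechanical
// migration: a #[test] that built a tokio 0.2 Runtime and wrapped its body in
// runtime.block_on(async { .. }) becomes an async #[tokio::test] under tokio 1.x, with the
// setup helpers awaited directly. A minimal standalone version of the pattern, assuming a
// dev-dependency on tokio with the "macros" and "rt" features; `fetch_height` is a
// hypothetical helper used only so the example compiles.
async fn fetch_height() -> u64 {
    0
}

#[tokio::test]
async fn request_metadata_sketch() {
    // Previously: Runtime::new().unwrap().block_on(async { .. })
    let height = fetch_height().await;
    assert_eq!(height, 0);
}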
// alice -> bob -> carol - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let factories = CryptoFactories::default(); @@ -370,19 +355,22 @@ fn propagate_and_forward_invalid_block_hash() { let (mut alice_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(alice_node_identity.clone()) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("alice").to_str().unwrap()); + .start(temp_dir.path().join("alice").to_str().unwrap()) + .await; let (mut bob_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(bob_node_identity.clone()) .with_peers(vec![alice_node_identity]) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("bob").to_str().unwrap()); + .start(temp_dir.path().join("bob").to_str().unwrap()) + .await; let (mut carol_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(carol_node_identity) .with_peers(vec![bob_node_identity]) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("carol").to_str().unwrap()); + .start(temp_dir.path().join("carol").to_str().unwrap()) + .await; - wait_until_online(&mut runtime, &[&alice_node, &bob_node, &carol_node]); + wait_until_online(&[&alice_node, &bob_node, &carol_node]).await; alice_node.mock_base_node_state_machine.publish_status(StatusInfo { bootstrapped: true, state_info: StateInfo::Listening(ListeningInfo::new(true)), @@ -409,42 +397,37 @@ fn propagate_and_forward_invalid_block_hash() { let mut bob_message_events = bob_node.messaging_events.subscribe(); let mut carol_message_events = carol_node.messaging_events.subscribe(); - runtime.block_on(async { - alice_node - .outbound_nci - .propagate_block(NewBlock::from(block1.block()), vec![]) - .await - .unwrap(); + alice_node + .outbound_nci + .propagate_block(NewBlock::from(block1.block()), vec![]) + .await + .unwrap(); - // Alice propagated to Bob - // Bob received the invalid hash - let msg_event = event_stream_next(&mut bob_message_events, Duration::from_secs(10)) - .await - .unwrap() - .unwrap(); - unpack_enum!(MessagingEvent::MessageReceived(_a, _b) = &*msg_event); - // Sent the request for the block to Alice - // Bob received a response from Alice - let msg_event = event_stream_next(&mut bob_message_events, Duration::from_secs(10)) - .await - .unwrap() - .unwrap(); - unpack_enum!(MessagingEvent::MessageReceived(node_id, _a) = &*msg_event); - assert_eq!(&*node_id, alice_node.node_identity.node_id()); - // Checking a negative: Bob should not have propagated this hash to Carol. If Bob does, this assertion will be - // flaky. - let msg_event = event_stream_next(&mut carol_message_events, Duration::from_millis(500)).await; - assert!(msg_event.is_none()); - - alice_node.shutdown().await; - bob_node.shutdown().await; - carol_node.shutdown().await; - }); + // Alice propagated to Bob + // Bob received the invalid hash + let msg_event = event_stream_next(&mut bob_message_events, Duration::from_secs(10)) + .await + .unwrap(); + unpack_enum!(MessagingEvent::MessageReceived(_a, _b) = &*msg_event); + // Sent the request for the block to Alice + // Bob received a response from Alice + let msg_event = event_stream_next(&mut bob_message_events, Duration::from_secs(10)) + .await + .unwrap(); + unpack_enum!(MessagingEvent::MessageReceived(node_id, _a) = &*msg_event); + assert_eq!(&*node_id, alice_node.node_identity.node_id()); + // Checking a negative: Bob should not have propagated this hash to Carol. If Bob does, this assertion will be + // flaky. 
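// Illustrative sketch, not part of the patch. The block/transaction sender channels in the
// test hunks move from futures::channel::mpsc::unbounded() to
// tokio::sync::mpsc::unbounded_channel(); the sender keeps a synchronous send(), while the
// receiving side awaits recv() (which returns None once every sender is dropped) instead of
// polling a Stream. Standalone example with an arbitrary payload type.
use tokio::sync::mpsc;

#[tokio::test]
async fn unbounded_channel_sketch() {
    let (block_sender, mut block_receiver) = mpsc::unbounded_channel::<String>();

    // UnboundedSender::send is not async and only fails if the receiver has been dropped.
    block_sender.send("block hash".to_string()).unwrap();
    drop(block_sender); // close the channel so recv() can eventually return None

    assert_eq!(block_receiver.recv().await.as_deref(), Some("block hash"));
    assert!(block_receiver.recv().await.is_none());
}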
+ let msg_event = event_stream_next(&mut carol_message_events, Duration::from_millis(500)).await; + assert!(msg_event.is_none()); + + alice_node.shutdown().await; + bob_node.shutdown().await; + carol_node.shutdown().await; } -#[test] -fn propagate_and_forward_invalid_block() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn propagate_and_forward_invalid_block() { let temp_dir = tempdir().unwrap(); let factories = CryptoFactories::default(); // Alice will propagate an invalid block to Carol and Bob, they will check the received block and not propagate the @@ -473,7 +456,8 @@ fn propagate_and_forward_invalid_block() { let (mut dan_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(dan_node_identity.clone()) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("dan").to_str().unwrap()); + .start(temp_dir.path().join("dan").to_str().unwrap()) + .await; let (mut carol_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(carol_node_identity.clone()) .with_peers(vec![dan_node_identity.clone()]) @@ -483,20 +467,23 @@ fn propagate_and_forward_invalid_block() { mock_validator.clone(), stateless_block_validator.clone(), ) - .start(&mut runtime, temp_dir.path().join("carol").to_str().unwrap()); + .start(temp_dir.path().join("carol").to_str().unwrap()) + .await; let (mut bob_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(bob_node_identity.clone()) .with_peers(vec![dan_node_identity]) .with_consensus_manager(rules) .with_validators(mock_validator.clone(), mock_validator, stateless_block_validator) - .start(&mut runtime, temp_dir.path().join("bob").to_str().unwrap()); + .start(temp_dir.path().join("bob").to_str().unwrap()) + .await; let (mut alice_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(alice_node_identity) .with_peers(vec![bob_node_identity, carol_node_identity]) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("alice").to_str().unwrap()); + .start(temp_dir.path().join("alice").to_str().unwrap()) + .await; - wait_until_online(&mut runtime, &[&alice_node, &bob_node, &carol_node, &dan_node]); + wait_until_online(&[&alice_node, &bob_node, &carol_node, &dan_node]).await; alice_node.mock_base_node_state_machine.publish_status(StatusInfo { bootstrapped: true, @@ -520,45 +507,42 @@ fn propagate_and_forward_invalid_block() { let block1 = append_block(&alice_node.blockchain_db, &block0, vec![], &rules, 1.into()).unwrap(); let block1_hash = block1.hash(); - runtime.block_on(async { - let mut bob_block_event_stream = bob_node.local_nci.get_block_event_stream(); - let mut carol_block_event_stream = carol_node.local_nci.get_block_event_stream(); - let mut dan_block_event_stream = dan_node.local_nci.get_block_event_stream(); - - assert!(alice_node - .outbound_nci - .propagate_block(NewBlock::from(block1.block()), vec![]) - .await - .is_ok()); - - let bob_block_event_fut = event_stream_next(&mut bob_block_event_stream, Duration::from_millis(20000)); - let carol_block_event_fut = event_stream_next(&mut carol_block_event_stream, Duration::from_millis(20000)); - let dan_block_event_fut = event_stream_next(&mut dan_block_event_stream, Duration::from_millis(5000)); - let (bob_block_event, carol_block_event, dan_block_event) = - join!(bob_block_event_fut, carol_block_event_fut, dan_block_event_fut); - - if let BlockEvent::AddBlockFailed(received_block, _) = &*bob_block_event.unwrap().unwrap() { - assert_eq!(&received_block.hash(), block1_hash); - } else { - 
panic!("Bob's node should have detected an invalid block"); - } - if let BlockEvent::AddBlockFailed(received_block, _) = &*carol_block_event.unwrap().unwrap() { - assert_eq!(&received_block.hash(), block1_hash); - } else { - panic!("Carol's node should have detected an invalid block"); - } - assert!(dan_block_event.is_none()); + let mut bob_block_event_stream = bob_node.local_nci.get_block_event_stream(); + let mut carol_block_event_stream = carol_node.local_nci.get_block_event_stream(); + let mut dan_block_event_stream = dan_node.local_nci.get_block_event_stream(); - alice_node.shutdown().await; - bob_node.shutdown().await; - carol_node.shutdown().await; - dan_node.shutdown().await; - }); + assert!(alice_node + .outbound_nci + .propagate_block(NewBlock::from(block1.block()), vec![]) + .await + .is_ok()); + + let bob_block_event_fut = event_stream_next(&mut bob_block_event_stream, Duration::from_millis(20000)); + let carol_block_event_fut = event_stream_next(&mut carol_block_event_stream, Duration::from_millis(20000)); + let dan_block_event_fut = event_stream_next(&mut dan_block_event_stream, Duration::from_millis(5000)); + let (bob_block_event, carol_block_event, dan_block_event) = + tokio::join!(bob_block_event_fut, carol_block_event_fut, dan_block_event_fut); + + if let BlockEvent::AddBlockFailed(received_block, _) = &*bob_block_event.unwrap() { + assert_eq!(&received_block.hash(), block1_hash); + } else { + panic!("Bob's node should have detected an invalid block"); + } + if let BlockEvent::AddBlockFailed(received_block, _) = &*carol_block_event.unwrap() { + assert_eq!(&received_block.hash(), block1_hash); + } else { + panic!("Carol's node should have detected an invalid block"); + } + assert!(dan_block_event.is_none()); + + alice_node.shutdown().await; + bob_node.shutdown().await; + carol_node.shutdown().await; + dan_node.shutdown().await; } -#[test] -fn service_request_timeout() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn service_request_timeout() { let network = Network::LocalNet; let consensus_manager = ConsensusManager::builder(network).build(); let base_node_service_config = BaseNodeServiceConfig { @@ -569,47 +553,42 @@ fn service_request_timeout() { }; let temp_dir = tempdir().unwrap(); let (mut alice_node, bob_node, _consensus_manager) = create_network_with_2_base_nodes_with_config( - &mut runtime, base_node_service_config, MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager, temp_dir.path().to_str().unwrap(), - ); + ) + .await; - runtime.block_on(async { - // Bob should not be reachable - bob_node.shutdown().await; - unpack_enum!(CommsInterfaceError::RequestTimedOut = alice_node.outbound_nci.get_metadata().await.unwrap_err()); - alice_node.shutdown().await; - }); + // Bob should not be reachable + bob_node.shutdown().await; + unpack_enum!(CommsInterfaceError::RequestTimedOut = alice_node.outbound_nci.get_metadata().await.unwrap_err()); + alice_node.shutdown().await; } -#[test] -fn local_get_metadata() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn local_get_metadata() { let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; - let (mut node, consensus_manager) = - BaseNodeBuilder::new(network.into()).start(&mut runtime, temp_dir.path().to_str().unwrap()); + let (mut node, consensus_manager) = BaseNodeBuilder::new(network.into()) + .start(temp_dir.path().to_str().unwrap()) + .await; let db = &node.blockchain_db; let block0 = db.fetch_block(0).unwrap().try_into_chain_block().unwrap(); 
let block1 = append_block(db, &block0, vec![], &consensus_manager, 1.into()).unwrap(); let block2 = append_block(db, &block1, vec![], &consensus_manager, 1.into()).unwrap(); - runtime.block_on(async { - let metadata = node.local_nci.get_metadata().await.unwrap(); - assert_eq!(metadata.height_of_longest_chain(), 2); - assert_eq!(metadata.best_block(), block2.hash()); + let metadata = node.local_nci.get_metadata().await.unwrap(); + assert_eq!(metadata.height_of_longest_chain(), 2); + assert_eq!(metadata.best_block(), block2.hash()); - node.shutdown().await; - }); + node.shutdown().await; } -#[test] -fn local_get_new_block_template_and_get_new_block() { +#[tokio::test] +async fn local_get_new_block_template_and_get_new_block() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; let consensus_constants = NetworkConsensus::from(network).create_consensus_constants(); @@ -620,7 +599,8 @@ fn local_get_new_block_template_and_get_new_block() { .build(); let (mut node, _rules) = BaseNodeBuilder::new(network.into()) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().to_str().unwrap()); + .start(temp_dir.path().to_str().unwrap()) + .await; let schema = [ txn_schema!(from: vec![outputs[1].clone()], to: vec![10_000 * uT, 20_000 * uT]), @@ -630,29 +610,26 @@ fn local_get_new_block_template_and_get_new_block() { assert!(node.mempool.insert(txs[0].clone()).is_ok()); assert!(node.mempool.insert(txs[1].clone()).is_ok()); - runtime.block_on(async { - let block_template = node - .local_nci - .get_new_block_template(PowAlgorithm::Sha3, 0) - .await - .unwrap(); - assert_eq!(block_template.header.height, 1); - assert_eq!(block_template.body.kernels().len(), 2); + let block_template = node + .local_nci + .get_new_block_template(PowAlgorithm::Sha3, 0) + .await + .unwrap(); + assert_eq!(block_template.header.height, 1); + assert_eq!(block_template.body.kernels().len(), 2); - let block = node.local_nci.get_new_block(block_template.clone()).await.unwrap(); - assert_eq!(block.header.height, 1); - assert_eq!(block.body, block_template.body); + let block = node.local_nci.get_new_block(block_template.clone()).await.unwrap(); + assert_eq!(block.header.height, 1); + assert_eq!(block.body, block_template.body); - node.blockchain_db.add_block(block.clone().into()).unwrap(); + node.blockchain_db.add_block(block.clone().into()).unwrap(); - node.shutdown().await; - }); + node.shutdown().await; } -#[test] -fn local_get_new_block_with_zero_conf() { +#[tokio::test] +async fn local_get_new_block_with_zero_conf() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; let consensus_constants = NetworkConsensus::from(network).create_consensus_constants(); @@ -668,7 +645,8 @@ fn local_get_new_block_with_zero_conf() { HeaderValidator::new(rules.clone()), OrphanBlockValidator::new(rules, factories.clone()), ) - .start(&mut runtime, temp_dir.path().to_str().unwrap()); + .start(temp_dir.path().to_str().unwrap()) + .await; let (tx01, tx01_out, _) = spend_utxos( txn_schema!(from: vec![outputs[1].clone()], to: vec![20_000 * uT], fee: 10*uT, lock: 0, features: OutputFeatures::default()), @@ -700,38 +678,35 @@ fn local_get_new_block_with_zero_conf() { TxStorageResponse::UnconfirmedPool ); - runtime.block_on(async { - let mut block_template = node - .local_nci - .get_new_block_template(PowAlgorithm::Sha3, 0) - .await - 
.unwrap(); - assert_eq!(block_template.header.height, 1); - assert_eq!(block_template.body.kernels().len(), 4); - let coinbase_value = rules.get_block_reward_at(1) + block_template.body.get_total_fee(); - let (output, kernel, _) = create_coinbase( - &factories, - coinbase_value, - rules.consensus_constants(1).coinbase_lock_height() + 1, - ); - block_template.body.add_kernel(kernel); - block_template.body.add_output(output); - block_template.body.sort(); - let block = node.local_nci.get_new_block(block_template.clone()).await.unwrap(); - assert_eq!(block.header.height, 1); - assert_eq!(block.body, block_template.body); - assert_eq!(block_template.body.kernels().len(), 5); - - node.blockchain_db.add_block(block.clone().into()).unwrap(); - - node.shutdown().await; - }); + let mut block_template = node + .local_nci + .get_new_block_template(PowAlgorithm::Sha3, 0) + .await + .unwrap(); + assert_eq!(block_template.header.height, 1); + assert_eq!(block_template.body.kernels().len(), 4); + let coinbase_value = rules.get_block_reward_at(1) + block_template.body.get_total_fee(); + let (output, kernel, _) = create_coinbase( + &factories, + coinbase_value, + rules.consensus_constants(1).coinbase_lock_height() + 1, + ); + block_template.body.add_kernel(kernel); + block_template.body.add_output(output); + block_template.body.sort(); + let block = node.local_nci.get_new_block(block_template.clone()).await.unwrap(); + assert_eq!(block.header.height, 1); + assert_eq!(block.body, block_template.body); + assert_eq!(block_template.body.kernels().len(), 5); + + node.blockchain_db.add_block(block.clone().into()).unwrap(); + + node.shutdown().await; } -#[test] -fn local_get_new_block_with_combined_transaction() { +#[tokio::test] +async fn local_get_new_block_with_combined_transaction() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; let consensus_constants = NetworkConsensus::from(network).create_consensus_constants(); @@ -747,7 +722,8 @@ fn local_get_new_block_with_combined_transaction() { HeaderValidator::new(rules.clone()), OrphanBlockValidator::new(rules, factories.clone()), ) - .start(&mut runtime, temp_dir.path().to_str().unwrap()); + .start(temp_dir.path().to_str().unwrap()) + .await; let (tx01, tx01_out, _) = spend_utxos( txn_schema!(from: vec![outputs[1].clone()], to: vec![20_000 * uT], fee: 10*uT, lock: 0, features: OutputFeatures::default()), @@ -774,41 +750,39 @@ fn local_get_new_block_with_combined_transaction() { TxStorageResponse::UnconfirmedPool ); - runtime.block_on(async { - let mut block_template = node - .local_nci - .get_new_block_template(PowAlgorithm::Sha3, 0) - .await - .unwrap(); - assert_eq!(block_template.header.height, 1); - assert_eq!(block_template.body.kernels().len(), 4); - let coinbase_value = rules.get_block_reward_at(1) + block_template.body.get_total_fee(); - let (output, kernel, _) = create_coinbase( - &factories, - coinbase_value, - rules.consensus_constants(1).coinbase_lock_height() + 1, - ); - block_template.body.add_kernel(kernel); - block_template.body.add_output(output); - block_template.body.sort(); - let block = node.local_nci.get_new_block(block_template.clone()).await.unwrap(); - assert_eq!(block.header.height, 1); - assert_eq!(block.body, block_template.body); - assert_eq!(block_template.body.kernels().len(), 5); - - node.blockchain_db.add_block(block.clone().into()).unwrap(); - - node.shutdown().await; - }); + let mut block_template = node + .local_nci + 
.get_new_block_template(PowAlgorithm::Sha3, 0) + .await + .unwrap(); + assert_eq!(block_template.header.height, 1); + assert_eq!(block_template.body.kernels().len(), 4); + let coinbase_value = rules.get_block_reward_at(1) + block_template.body.get_total_fee(); + let (output, kernel, _) = create_coinbase( + &factories, + coinbase_value, + rules.consensus_constants(1).coinbase_lock_height() + 1, + ); + block_template.body.add_kernel(kernel); + block_template.body.add_output(output); + block_template.body.sort(); + let block = node.local_nci.get_new_block(block_template.clone()).await.unwrap(); + assert_eq!(block.header.height, 1); + assert_eq!(block.body, block_template.body); + assert_eq!(block_template.body.kernels().len(), 5); + + node.blockchain_db.add_block(block.clone().into()).unwrap(); + + node.shutdown().await; } -#[test] -fn local_submit_block() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn local_submit_block() { let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; - let (mut node, consensus_manager) = - BaseNodeBuilder::new(network.into()).start(&mut runtime, temp_dir.path().to_str().unwrap()); + let (mut node, consensus_manager) = BaseNodeBuilder::new(network.into()) + .start(temp_dir.path().to_str().unwrap()) + .await; let db = &node.blockchain_db; let mut event_stream = node.local_nci.get_block_event_stream(); @@ -818,20 +792,18 @@ fn local_submit_block() { .unwrap(); block1.header.kernel_mmr_size += 1; block1.header.output_mmr_size += 1; - runtime.block_on(async { - node.local_nci - .submit_block(block1.clone(), Broadcast::from(true)) - .await - .unwrap(); + node.local_nci + .submit_block(block1.clone(), Broadcast::from(true)) + .await + .unwrap(); - let event = event_stream_next(&mut event_stream, Duration::from_millis(20000)).await; - if let BlockEvent::ValidBlockAdded(received_block, result, _) = &*event.unwrap().unwrap() { - assert_eq!(received_block.hash(), block1.hash()); - result.assert_added(); - } else { - panic!("Block validation failed"); - } + let event = event_stream_next(&mut event_stream, Duration::from_millis(20000)).await; + if let BlockEvent::ValidBlockAdded(received_block, result, _) = &*event.unwrap() { + assert_eq!(received_block.hash(), block1.hash()); + result.assert_added(); + } else { + panic!("Block validation failed"); + } - node.shutdown().await; - }); + node.shutdown().await; } diff --git a/base_layer/core/tests/node_state_machine.rs b/base_layer/core/tests/node_state_machine.rs index bcbeeea436..bbe9baa5c5 100644 --- a/base_layer/core/tests/node_state_machine.rs +++ b/base_layer/core/tests/node_state_machine.rs @@ -23,13 +23,12 @@ #[allow(dead_code)] mod helpers; -use futures::StreamExt; use helpers::{ block_builders::{append_block, chain_block, create_genesis_block}, chain_metadata::{random_peer_metadata, MockChainMetadata}, nodes::{create_network_with_2_base_nodes_with_config, wait_until_online, BaseNodeBuilder}, }; -use std::{thread, time::Duration}; +use std::time::Duration; use tari_common::configuration::Network; use tari_core::{ base_node::{ @@ -54,15 +53,14 @@ use tari_p2p::services::liveness::LivenessConfig; use tari_shutdown::Shutdown; use tempfile::tempdir; use tokio::{ - runtime::Runtime, sync::{broadcast, watch}, + task, time, }; static EMISSION: [u64; 2] = [10, 10]; -#[test] -fn test_listening_lagging() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn test_listening_lagging() { let factories = CryptoFactories::default(); let network = Network::LocalNet; let temp_dir = 
tempdir().unwrap(); @@ -75,7 +73,6 @@ fn test_listening_lagging() { .with_block(prev_block.clone()) .build(); let (alice_node, bob_node, consensus_manager) = create_network_with_2_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig { @@ -84,7 +81,8 @@ fn test_listening_lagging() { }, consensus_manager, temp_dir.path().to_str().unwrap(), - ); + ) + .await; let shutdown = Shutdown::new(); let (state_change_event_publisher, _) = broadcast::channel(10); let (status_event_sender, _status_event_receiver) = watch::channel(StatusInfo::new()); @@ -103,46 +101,44 @@ fn test_listening_lagging() { consensus_manager.clone(), shutdown.to_signal(), ); - wait_until_online(&mut runtime, &[&alice_node, &bob_node]); + wait_until_online(&[&alice_node, &bob_node]).await; - let await_event_task = runtime.spawn(async move { Listening::new().next_event(&mut alice_state_machine).await }); + let await_event_task = task::spawn(async move { Listening::new().next_event(&mut alice_state_machine).await }); - runtime.block_on(async move { - let bob_db = bob_node.blockchain_db; - let mut bob_local_nci = bob_node.local_nci; + let bob_db = bob_node.blockchain_db; + let mut bob_local_nci = bob_node.local_nci; - // Bob Block 1 - no block event - let prev_block = append_block(&bob_db, &prev_block, vec![], &consensus_manager, 3.into()).unwrap(); - // Bob Block 2 - with block event and liveness service metadata update - let mut prev_block = bob_db - .prepare_block_merkle_roots(chain_block(&prev_block.block(), vec![], &consensus_manager)) - .unwrap(); - prev_block.header.output_mmr_size += 1; - prev_block.header.kernel_mmr_size += 1; - bob_local_nci - .submit_block(prev_block, Broadcast::from(true)) - .await - .unwrap(); - assert_eq!(bob_db.get_height().unwrap(), 2); + // Bob Block 1 - no block event + let prev_block = append_block(&bob_db, &prev_block, vec![], &consensus_manager, 3.into()).unwrap(); + // Bob Block 2 - with block event and liveness service metadata update + let mut prev_block = bob_db + .prepare_block_merkle_roots(chain_block(&prev_block.block(), vec![], &consensus_manager)) + .unwrap(); + prev_block.header.output_mmr_size += 1; + prev_block.header.kernel_mmr_size += 1; + bob_local_nci + .submit_block(prev_block, Broadcast::from(true)) + .await + .unwrap(); + assert_eq!(bob_db.get_height().unwrap(), 2); - let next_event = time::timeout(Duration::from_secs(10), await_event_task) - .await - .expect("Alice did not emit `StateEvent::FallenBehind` within 10 seconds") - .unwrap(); + let next_event = time::timeout(Duration::from_secs(10), await_event_task) + .await + .expect("Alice did not emit `StateEvent::FallenBehind` within 10 seconds") + .unwrap(); - match next_event { - StateEvent::InitialSync => {}, - _ => panic!(), - } - }); + match next_event { + StateEvent::InitialSync => {}, + _ => panic!(), + } } -#[test] -fn test_event_channel() { +#[tokio::test] +async fn test_event_channel() { let temp_dir = tempdir().unwrap(); - let mut runtime = Runtime::new().unwrap(); - let (node, consensus_manager) = - BaseNodeBuilder::new(Network::Weatherwax.into()).start(&mut runtime, temp_dir.path().to_str().unwrap()); + let (node, consensus_manager) = BaseNodeBuilder::new(Network::Weatherwax.into()) + .start(temp_dir.path().to_str().unwrap()) + .await; // let shutdown = Shutdown::new(); let db = create_test_blockchain_db(); let shutdown = Shutdown::new(); @@ -165,24 +161,21 @@ fn test_event_channel() { shutdown.to_signal(), ); - 
runtime.spawn(state_machine.run()); + task::spawn(state_machine.run()); let PeerChainMetadata { node_id, chain_metadata, } = random_peer_metadata(10, 5_000); - runtime - .block_on(mock.publish_chain_metadata(&node_id, &chain_metadata)) + mock.publish_chain_metadata(&node_id, &chain_metadata) + .await .expect("Could not publish metadata"); - thread::sleep(Duration::from_millis(50)); - runtime.block_on(async { - let event = state_change_event_subscriber.next().await; - assert_eq!(*event.unwrap().unwrap(), StateEvent::Initialized); - let event = state_change_event_subscriber.next().await; - let event = event.unwrap().unwrap(); - match event.as_ref() { - StateEvent::InitialSync => (), - _ => panic!("Unexpected state was found:{:?}", event), - } - }); + let event = state_change_event_subscriber.recv().await; + assert_eq!(*event.unwrap(), StateEvent::Initialized); + let event = state_change_event_subscriber.recv().await; + let event = event.unwrap(); + match event.as_ref() { + StateEvent::InitialSync => (), + _ => panic!("Unexpected state was found:{:?}", event), + } } diff --git a/base_layer/key_manager/Cargo.toml b/base_layer/key_manager/Cargo.toml index e57d0bf8b4..a40415ce17 100644 --- a/base_layer/key_manager/Cargo.toml +++ b/base_layer/key_manager/Cargo.toml @@ -15,7 +15,7 @@ sha2 = "0.9.5" serde = "1.0.89" serde_derive = "1.0.89" serde_json = "1.0.39" -thiserror = "1.0.20" +thiserror = "1.0.26" [features] avx2 = ["tari_crypto/avx2"] diff --git a/base_layer/mmr/Cargo.toml b/base_layer/mmr/Cargo.toml index 38b723a5f2..0d725874fb 100644 --- a/base_layer/mmr/Cargo.toml +++ b/base_layer/mmr/Cargo.toml @@ -14,7 +14,7 @@ benches = ["criterion"] [dependencies] tari_utilities = "^0.3" -thiserror = "1.0.20" +thiserror = "1.0.26" digest = "0.9.0" log = "0.4" serde = { version = "1.0.97", features = ["derive"] } diff --git a/base_layer/p2p/Cargo.toml b/base_layer/p2p/Cargo.toml index e9d3a0291c..0d7aae7390 100644 --- a/base_layer/p2p/Cargo.toml +++ b/base_layer/p2p/Cargo.toml @@ -10,37 +10,38 @@ license = "BSD-3-Clause" edition = "2018" [dependencies] -tari_comms = { version = "^0.9", path = "../../comms"} -tari_comms_dht = { version = "^0.9", path = "../../comms/dht"} -tari_common = { version= "^0.9", path = "../../common" } +tari_comms = { version = "^0.9", path = "../../comms" } +tari_comms_dht = { version = "^0.9", path = "../../comms/dht" } +tari_common = { version = "^0.9", path = "../../common" } tari_crypto = "0.11.1" -tari_service_framework = { version = "^0.9", path = "../service_framework"} -tari_shutdown = { version = "^0.9", path="../../infrastructure/shutdown" } -tari_storage = { version = "^0.9", path = "../../infrastructure/storage"} +tari_service_framework = { version = "^0.9", path = "../service_framework" } +tari_shutdown = { version = "^0.9", path = "../../infrastructure/shutdown" } +tari_storage = { version = "^0.9", path = "../../infrastructure/storage" } tari_utilities = "^0.3" anyhow = "1.0.32" bytes = "0.5" -chrono = {version = "0.4.6", features = ["serde"]} +chrono = { version = "0.4.6", features = ["serde"] } fs2 = "0.3.0" -futures = {version = "^0.3.1"} +futures = { version = "^0.3.1" } lmdb-zero = "0.4.4" log = "0.4.6" -pgp = {version = "0.7.1", optional = true} -prost = "0.6.1" +pgp = { version = "0.7.1", optional = true } +prost = "=0.8.0" rand = "0.8" -reqwest = {version = "0.10", optional = true, default-features = false} +reqwest = { version = "0.10", optional = true, default-features = false } semver = "1.0.1" serde = "1.0.90" serde_derive = "1.0.90" -thiserror = 
"1.0.20" -tokio = {version="0.2.10", features=["blocking"]} +thiserror = "1.0.26" +tokio = { version = "1.10", features = ["macros"] } +tokio-stream = { version = "0.1.7", default-features = false, features = ["time"] } tower = "0.3.0-alpha.2" -tower-service = { version="0.3.0-alpha.2" } -trust-dns-client = {version="0.19.5", features=["dns-over-rustls"]} +tower-service = { version = "0.3.0-alpha.2" } +trust-dns-client = { version = "0.21.0-alpha.1", features = ["dns-over-rustls"] } [dev-dependencies] -tari_test_utils = { version = "^0.9", path="../../infrastructure/test_utils" } +tari_test_utils = { version = "^0.9", path = "../../infrastructure/test_utils" } clap = "2.33.0" env_logger = "0.6.2" @@ -48,7 +49,6 @@ futures-timer = "0.3.0" lazy_static = "1.3.0" stream-cancel = "0.4.4" tempfile = "3.1.0" -tokio-macros = "0.2.4" [dev-dependencies.log4rs] version = "^0.8" @@ -56,7 +56,7 @@ features = ["console_appender", "file_appender", "file", "yaml_format"] default-features = false [build-dependencies] -tari_common = { version = "^0.9", path="../../common", features = ["build"] } +tari_common = { version = "^0.9", path = "../../common", features = ["build"] } [features] test-mocks = [] diff --git a/base_layer/p2p/examples/gen_tor_identity.rs b/base_layer/p2p/examples/gen_tor_identity.rs index 52a2e3c785..c1a7693de3 100644 --- a/base_layer/p2p/examples/gen_tor_identity.rs +++ b/base_layer/p2p/examples/gen_tor_identity.rs @@ -39,7 +39,7 @@ fn to_abs_path(path: &str) -> String { } } -#[tokio_macros::main] +#[tokio::main] async fn main() { let matches = App::new("Tor identity file generator") .version("1.0") diff --git a/base_layer/p2p/src/auto_update/dns.rs b/base_layer/p2p/src/auto_update/dns.rs index 7042b19919..64c3be7f5a 100644 --- a/base_layer/p2p/src/auto_update/dns.rs +++ b/base_layer/p2p/src/auto_update/dns.rs @@ -189,19 +189,26 @@ impl Display for UpdateSpec { #[cfg(test)] mod test { use super::*; + use crate::dns::mock; use trust_dns_client::{ - proto::rr::{rdata, RData, RecordType}, + op::Query, + proto::{ + rr::{rdata, Name, RData, RecordType}, + xfer::DnsResponse, + }, rr::Record, }; - fn create_txt_record(contents: Vec<&str>) -> Record { + fn create_txt_record(contents: Vec<&str>) -> DnsResponse { + let resp_query = Query::query(Name::from_str("test.local.").unwrap(), RecordType::A); let mut record = Record::new(); record .set_record_type(RecordType::TXT) .set_rdata(RData::TXT(rdata::TXT::new( contents.into_iter().map(ToString::to_string).collect(), ))); - record + + mock::message(resp_query, vec![record], vec![], vec![]).into() } mod update_spec { @@ -220,7 +227,6 @@ mod test { mod dns_software_update { use super::*; use crate::DEFAULT_DNS_NAME_SERVER; - use std::{collections::HashMap, iter::FromIterator}; impl AutoUpdateConfig { fn get_test_defaults() -> Self { @@ -238,15 +244,15 @@ mod test { } } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_ignores_non_conforming_txt_entries() { - let records = HashMap::from_iter([("test.local.", vec![ - create_txt_record(vec![":::"]), - create_txt_record(vec!["base-node:::"]), - create_txt_record(vec!["base-node::1.0:"]), - create_txt_record(vec!["base-node:android-armv7:0.1.0:abcdef"]), - create_txt_record(vec!["base-node:linux-x86_64:1.0.0:bada55"]), - ])]); + let records = vec![ + Ok(create_txt_record(vec![":::"])), + Ok(create_txt_record(vec!["base-node:::"])), + Ok(create_txt_record(vec!["base-node::1.0:"])), + Ok(create_txt_record(vec!["base-node:android-armv7:0.1.0:abcdef"])), + 
Ok(create_txt_record(vec!["base-node:linux-x86_64:1.0.0:bada55"])), + ]; let updater = DnsSoftwareUpdate { client: DnsClient::connect_mock(records).await.unwrap(), config: AutoUpdateConfig::get_test_defaults(), @@ -258,12 +264,12 @@ mod test { assert!(spec.is_none()); } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_returns_best_update() { - let records = HashMap::from_iter([("test.local.", vec![ - create_txt_record(vec!["base-node:linux-x86_64:1.0.0:abcdef"]), - create_txt_record(vec!["base-node:linux-x86_64:1.0.1:abcdef01"]), - ])]); + let records = vec![ + Ok(create_txt_record(vec!["base-node:linux-x86_64:1.0.0:abcdef"])), + Ok(create_txt_record(vec!["base-node:linux-x86_64:1.0.1:abcdef01"])), + ]; let updater = DnsSoftwareUpdate { client: DnsClient::connect_mock(records).await.unwrap(), config: AutoUpdateConfig::get_test_defaults(), diff --git a/base_layer/p2p/src/auto_update/service.rs b/base_layer/p2p/src/auto_update/service.rs index a786d84ad3..eebd9ec555 100644 --- a/base_layer/p2p/src/auto_update/service.rs +++ b/base_layer/p2p/src/auto_update/service.rs @@ -24,17 +24,15 @@ use crate::{ auto_update, auto_update::{AutoUpdateConfig, SoftwareUpdate, Version}, }; -use futures::{ - channel::{mpsc, oneshot}, - future::Either, - stream, - SinkExt, - StreamExt, -}; +use futures::{future::Either, stream, StreamExt}; use std::{env::consts, time::Duration}; use tari_common::configuration::bootstrap::ApplicationType; use tari_service_framework::{async_trait, ServiceInitializationError, ServiceInitializer, ServiceInitializerContext}; -use tokio::{sync::watch, time}; +use tokio::{ + sync::{mpsc, oneshot, watch}, + time, +}; +use tokio_stream::wrappers; const LOG_TARGET: &str = "app:auto-update"; @@ -94,19 +92,20 @@ impl SoftwareUpdaterService { new_update_notification: watch::Receiver>, ) { let mut interval_or_never = match self.check_interval { - Some(interval) => Either::Left(time::interval(interval)).fuse(), - None => Either::Right(stream::empty()).fuse(), + Some(interval) => Either::Left(wrappers::IntervalStream::new(time::interval(interval))), + None => Either::Right(stream::empty()), }; loop { let last_version = new_update_notification.borrow().clone(); - let maybe_update = futures::select! { - reply = request_rx.select_next_some() => { + let maybe_update = tokio::select! { + Some(reply) = request_rx.recv() => { let maybe_update = self.check_for_updates().await; let _ = reply.send(maybe_update.clone()); maybe_update }, + _ = interval_or_never.next() => { // Periodically, check for updates if configured to do so. 
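// Illustrative sketch, not part of the patch. The updater loop above swaps futures::select!
// over fused streams for tokio::select! over plain async calls: requests arrive via
// tokio::sync::mpsc::Receiver::recv(), and the periodic check is a tokio::time::interval
// wrapped in tokio_stream::wrappers::IntervalStream so it can be polled with
// StreamExt::next(). tokio::select! disables a branch whose future resolves to a
// non-matching value instead of requiring fused streams. A reduced version of that shape,
// assuming tokio-stream with the "time" feature; the types, capacity and interval are
// arbitrary example values.
use std::time::Duration;
use tokio::{sync::mpsc, time};
use tokio_stream::{wrappers::IntervalStream, StreamExt};

async fn update_check_loop(mut request_rx: mpsc::Receiver<&'static str>) {
    let mut ticks = IntervalStream::new(time::interval(Duration::from_secs(30)));
    loop {
        tokio::select! {
            maybe_request = request_rx.recv() => {
                match maybe_request {
                    // An explicit "check now" request from a caller.
                    Some(who) => println!("checking for updates on behalf of {}", who),
                    // Every sender dropped: shut the loop down.
                    None => break,
                }
            }
            // The periodic branch; interval streams never end, they just keep ticking.
            _ = ticks.next() => {
                println!("periodic update check");
            }
        }
    }
}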
// If an update is found the new update notifier will be triggered and any listeners notified @@ -121,7 +120,7 @@ impl SoftwareUpdaterService { .map(|up| up.version() < update.version()) .unwrap_or(true) { - let _ = notifier.broadcast(Some(update.clone())); + let _ = notifier.send(Some(update.clone())); } } } diff --git a/base_layer/p2p/src/comms_connector/inbound_connector.rs b/base_layer/p2p/src/comms_connector/inbound_connector.rs index ed16cd578d..6feed82be7 100644 --- a/base_layer/p2p/src/comms_connector/inbound_connector.rs +++ b/base_layer/p2p/src/comms_connector/inbound_connector.rs @@ -22,45 +22,42 @@ use super::peer_message::PeerMessage; use anyhow::anyhow; -use futures::{task::Context, Future, Sink, SinkExt}; +use futures::{task::Context, Future}; use log::*; use std::{pin::Pin, sync::Arc, task::Poll}; use tari_comms::pipeline::PipelineError; use tari_comms_dht::{domain_message::MessageHeader, inbound::DecryptedDhtMessage}; +use tokio::sync::mpsc; use tower::Service; const LOG_TARGET: &str = "comms::middleware::inbound_connector"; /// This service receives DecryptedDhtMessage, deserializes the MessageHeader and /// sends a `PeerMessage` on the given sink. #[derive(Clone)] -pub struct InboundDomainConnector { - sink: TSink, +pub struct InboundDomainConnector { + sink: mpsc::Sender>, } -impl InboundDomainConnector { - pub fn new(sink: TSink) -> Self { +impl InboundDomainConnector { + pub fn new(sink: mpsc::Sender>) -> Self { Self { sink } } } -impl Service for InboundDomainConnector -where - TSink: Sink> + Unpin + Clone + 'static, - TSink::Error: std::error::Error + Send + Sync + 'static, -{ +impl Service for InboundDomainConnector { type Error = PipelineError; - type Future = Pin>>>; + type Future = Pin> + Send>>; type Response = (); - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.sink).poll_ready(cx).map_err(Into::into) + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, msg: DecryptedDhtMessage) -> Self::Future { - let mut sink = self.sink.clone(); + let sink = self.sink.clone(); let future = async move { let peer_message = Self::construct_peer_message(msg)?; - // If this fails there is something wrong with the sink and the pubsub middleware should not + // If this fails the channel has closed and the pubsub middleware should not // continue sink.send(Arc::new(peer_message)).await?; @@ -70,7 +67,7 @@ where } } -impl InboundDomainConnector { +impl InboundDomainConnector { fn construct_peer_message(mut inbound_message: DecryptedDhtMessage) -> Result { let envelope_body = inbound_message .success_mut() @@ -107,41 +104,17 @@ impl InboundDomainConnector { } } -impl Sink for InboundDomainConnector -where - TSink: Sink> + Unpin, - TSink::Error: Into + Send + Sync + 'static, -{ - type Error = PipelineError; - - fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.sink).poll_ready(cx).map_err(Into::into) - } - - fn start_send(mut self: Pin<&mut Self>, item: DecryptedDhtMessage) -> Result<(), Self::Error> { - let item = Self::construct_peer_message(item)?; - Pin::new(&mut self.sink).start_send(Arc::new(item)).map_err(Into::into) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.sink).poll_flush(cx).map_err(Into::into) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.sink).poll_close(cx).map_err(Into::into) - } -} - #[cfg(test)] mod test { use super::*; use 
crate::test_utils::{make_dht_inbound_message, make_node_identity}; - use futures::{channel::mpsc, executor::block_on, StreamExt}; + use futures::executor::block_on; use tari_comms::{message::MessageExt, wrap_in_envelope_body}; use tari_comms_dht::domain_message::MessageHeader; + use tokio::sync::mpsc; use tower::ServiceExt; - #[tokio_macros::test_basic] + #[tokio::test] async fn handle_message() { let (tx, mut rx) = mpsc::channel(1); let header = MessageHeader::new(123); @@ -151,12 +124,12 @@ mod test { let decrypted = DecryptedDhtMessage::succeeded(msg, None, inbound_message); InboundDomainConnector::new(tx).oneshot(decrypted).await.unwrap(); - let peer_message = block_on(rx.next()).unwrap(); + let peer_message = block_on(rx.recv()).unwrap(); assert_eq!(peer_message.message_header.message_type, 123); assert_eq!(peer_message.decode_message::().unwrap(), "my message"); } - #[tokio_macros::test_basic] + #[tokio::test] async fn send_on_sink() { let (tx, mut rx) = mpsc::channel(1); let header = MessageHeader::new(123); @@ -165,14 +138,14 @@ mod test { let inbound_message = make_dht_inbound_message(&make_node_identity(), msg.to_encoded_bytes()); let decrypted = DecryptedDhtMessage::succeeded(msg, None, inbound_message); - InboundDomainConnector::new(tx).send(decrypted).await.unwrap(); + InboundDomainConnector::new(tx).call(decrypted).await.unwrap(); - let peer_message = block_on(rx.next()).unwrap(); + let peer_message = block_on(rx.recv()).unwrap(); assert_eq!(peer_message.message_header.message_type, 123); assert_eq!(peer_message.decode_message::().unwrap(), "my message"); } - #[tokio_macros::test_basic] + #[tokio::test] async fn handle_message_fail_deserialize() { let (tx, mut rx) = mpsc::channel(1); let header = b"dodgy header".to_vec(); @@ -182,10 +155,11 @@ mod test { let decrypted = DecryptedDhtMessage::succeeded(msg, None, inbound_message); InboundDomainConnector::new(tx).oneshot(decrypted).await.unwrap_err(); - assert!(rx.try_next().unwrap().is_none()); + rx.close(); + assert!(rx.recv().await.is_none()); } - #[tokio_macros::test_basic] + #[tokio::test] async fn handle_message_fail_send() { // Drop the receiver of the channel, this is the only reason this middleware should return an error // from it's call function diff --git a/base_layer/p2p/src/comms_connector/pubsub.rs b/base_layer/p2p/src/comms_connector/pubsub.rs index ae01e8ced8..198eee63f2 100644 --- a/base_layer/p2p/src/comms_connector/pubsub.rs +++ b/base_layer/p2p/src/comms_connector/pubsub.rs @@ -22,11 +22,15 @@ use super::peer_message::PeerMessage; use crate::{comms_connector::InboundDomainConnector, tari_message::TariMessageType}; -use futures::{channel::mpsc, future, stream::Fuse, Stream, StreamExt}; +use futures::{future, Stream, StreamExt}; use log::*; use std::{cmp, fmt::Debug, sync::Arc, time::Duration}; use tari_comms::rate_limit::RateLimit; -use tokio::{runtime::Handle, sync::broadcast}; +use tokio::{ + sync::{broadcast, mpsc}, + task, +}; +use tokio_stream::wrappers; const LOG_TARGET: &str = "comms::middleware::pubsub"; @@ -35,16 +39,11 @@ const RATE_LIMIT_MIN_CAPACITY: usize = 5; const RATE_LIMIT_RESTOCK_INTERVAL: Duration = Duration::from_millis(1000); /// Alias for a pubsub-type domain connector -pub type PubsubDomainConnector = InboundDomainConnector>>; +pub type PubsubDomainConnector = InboundDomainConnector; pub type SubscriptionFactory = TopicSubscriptionFactory>; /// Connects `InboundDomainConnector` to a `tari_pubsub::TopicPublisher` through a buffered broadcast channel -pub fn pubsub_connector( - // TODO: 
Remove this arg in favor of task::spawn - executor: Handle, - buf_size: usize, - rate_limit: usize, -) -> (PubsubDomainConnector, SubscriptionFactory) { +pub fn pubsub_connector(buf_size: usize, rate_limit: usize) -> (PubsubDomainConnector, SubscriptionFactory) { let (publisher, subscription_factory) = pubsub_channel(buf_size); let (sender, receiver) = mpsc::channel(buf_size); trace!( @@ -55,8 +54,8 @@ pub fn pubsub_connector( ); // Spawn a task which forwards messages from the pubsub service to the TopicPublisher - executor.spawn(async move { - let forwarder = receiver + task::spawn(async move { + wrappers::ReceiverStream::new(receiver) // Rate limit the receiver; the sender will adhere to the limit .rate_limit(cmp::max(rate_limit, RATE_LIMIT_MIN_CAPACITY), RATE_LIMIT_RESTOCK_INTERVAL) // Map DomainMessage into a TopicPayload @@ -89,8 +88,7 @@ pub fn pubsub_connector( ); } future::ready(()) - }); - forwarder.await; + }).await; }); (InboundDomainConnector::new(sender), subscription_factory) } @@ -98,8 +96,8 @@ pub fn pubsub_connector( /// Create a topic-based pub-sub channel fn pubsub_channel(size: usize) -> (TopicPublisher, TopicSubscriptionFactory) where - T: Clone + Debug + Send + Eq, - M: Send + Clone, + T: Clone + Debug + Send + Eq + 'static, + M: Send + Clone + 'static, { let (publisher, _) = broadcast::channel(size); (publisher.clone(), TopicSubscriptionFactory::new(publisher)) @@ -138,8 +136,8 @@ pub struct TopicSubscriptionFactory { impl TopicSubscriptionFactory where - T: Clone + Eq + Debug + Send, - M: Clone + Send, + T: Clone + Eq + Debug + Send + 'static, + M: Clone + Send + 'static, { pub fn new(sender: broadcast::Sender>) -> Self { TopicSubscriptionFactory { sender } @@ -148,38 +146,22 @@ where /// Create a subscription stream to a particular topic. The provided label is used to identify which consumer is /// lagging. pub fn get_subscription(&self, topic: T, label: &'static str) -> impl Stream { - self.sender - .subscribe() - .filter_map({ - let topic = topic.clone(); - move |result| { - let opt = match result { - Ok(payload) => Some(payload), - Err(broadcast::RecvError::Closed) => None, - Err(broadcast::RecvError::Lagged(n)) => { - warn!( - target: LOG_TARGET, - "Subscription '{}' for topic '{:?}' lagged. {} message(s) dropped.", label, topic, n - ); - None - }, - }; - future::ready(opt) - } - }) - .filter_map(move |item| { - let opt = if item.topic() == &topic { - Some(item.message) - } else { - None + wrappers::BroadcastStream::new(self.sender.subscribe()).filter_map({ + move |result| { + let opt = match result { + Ok(payload) if *payload.topic() == topic => Some(payload.message), + Ok(_) => None, + Err(wrappers::errors::BroadcastStreamRecvError::Lagged(n)) => { + warn!( + target: LOG_TARGET, + "Subscription '{}' for topic '{:?}' lagged. {} message(s) dropped.", label, topic, n + ); + None + }, }; future::ready(opt) - }) - } - - /// Convenience function that returns a fused (`stream::Fuse`) version of the subscription stream. 
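// Illustrative sketch, not part of the patch. In tokio 1.x a broadcast::Receiver no longer
// implements Stream, so get_subscription above wraps it in
// tokio_stream::wrappers::BroadcastStream, which yields Result<T, BroadcastStreamRecvError>
// and reports Lagged(n) when a slow subscriber has had n messages dropped. Standalone
// example; it assumes tokio-stream is compiled with its "sync" feature, and the capacity
// and payload are arbitrary.
use tokio::sync::broadcast;
use tokio_stream::{
    wrappers::{errors::BroadcastStreamRecvError, BroadcastStream},
    StreamExt,
};

#[tokio::test]
async fn broadcast_stream_sketch() {
    let (sender, receiver) = broadcast::channel::<&'static str>(8);
    let mut subscription = BroadcastStream::new(receiver);

    sender.send("block event").unwrap();
    drop(sender); // closing the channel ends the stream once buffered items are drained

    let mut seen = Vec::new();
    while let Some(item) = subscription.next().await {
        match item {
            Ok(msg) => seen.push(msg),
            // A lagging subscriber learns how many messages it missed.
            Err(BroadcastStreamRecvError::Lagged(missed)) => println!("dropped {}", missed),
        }
    }
    assert_eq!(seen, vec!["block event"]);
}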
- pub fn get_subscription_fused(&self, topic: T, label: &'static str) -> Fuse> { - self.get_subscription(topic, label).fuse() + } + }) } } @@ -190,7 +172,7 @@ mod test { use std::time::Duration; use tari_test_utils::collect_stream; - #[tokio_macros::test_basic] + #[tokio::test] async fn topic_pub_sub() { let (publisher, subscriber_factory) = pubsub_channel(10); diff --git a/base_layer/p2p/src/dns/client.rs b/base_layer/p2p/src/dns/client.rs index 85e03f71e4..78093186ef 100644 --- a/base_layer/p2p/src/dns/client.rs +++ b/base_layer/p2p/src/dns/client.rs @@ -20,35 +20,28 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#[cfg(test)] +use crate::dns::mock::{DefaultOnSend, MockClientHandle}; + use super::DnsClientError; -use futures::future; +use futures::{future, FutureExt}; use std::{net::SocketAddr, sync::Arc}; use tari_shutdown::Shutdown; use tokio::{net::UdpSocket, task}; use trust_dns_client::{ - client::{AsyncClient, AsyncDnssecClient}, - op::{DnsResponse, Query}, - proto::{ - rr::dnssec::TrustAnchor, - udp::{UdpClientStream, UdpResponse}, - xfer::DnsRequestOptions, - DnsHandle, - }, + client::{AsyncClient, AsyncDnssecClient, ClientHandle}, + op::Query, + proto::{error::ProtoError, rr::dnssec::TrustAnchor, udp::UdpClientStream, xfer::DnsResponse, DnsHandle}, rr::{DNSClass, IntoName, RecordType}, serialize::binary::BinEncoder, }; -#[cfg(test)] -use std::collections::HashMap; -#[cfg(test)] -use trust_dns_client::{proto::xfer::DnsMultiplexerSerialResponse, rr::Record}; - #[derive(Clone)] pub enum DnsClient { - Secure(Client>), - Normal(Client>), + Secure(Client), + Normal(Client), #[cfg(test)] - Mock(Client>), + Mock(Client>), } impl DnsClient { @@ -63,18 +56,18 @@ impl DnsClient { } #[cfg(test)] - pub async fn connect_mock(records: HashMap<&'static str, Vec>) -> Result { - let client = Client::connect_mock(records).await?; + pub async fn connect_mock(messages: Vec>) -> Result { + let client = Client::connect_mock(messages).await?; Ok(DnsClient::Mock(client)) } - pub async fn lookup(&mut self, query: Query, options: DnsRequestOptions) -> Result { + pub async fn lookup(&mut self, query: Query) -> Result { use DnsClient::*; match self { - Secure(ref mut client) => client.lookup(query, options).await, - Normal(ref mut client) => client.lookup(query, options).await, + Secure(ref mut client) => client.lookup(query).await, + Normal(ref mut client) => client.lookup(query).await, #[cfg(test)] - Mock(ref mut client) => client.lookup(query, options).await, + Mock(ref mut client) => client.lookup(query).await, } } @@ -85,11 +78,11 @@ impl DnsClient { .set_query_class(DNSClass::IN) .set_query_type(RecordType::TXT); - let response = self.lookup(query, Default::default()).await?; + let responses = self.lookup(query).await?; - let records = response - .messages() - .flat_map(|msg| msg.answers()) + let records = responses + .answers() + .iter() .map(|answer| { let data = answer.rdata(); let mut buf = Vec::new(); @@ -116,7 +109,7 @@ pub struct Client { shutdown: Arc, } -impl Client> { +impl Client { pub async fn connect_secure(name_server: SocketAddr, trust_anchor: TrustAnchor) -> Result { let shutdown = Shutdown::new(); let stream = UdpClientStream::::new(name_server); @@ -124,7 +117,7 @@ impl Client> { .trust_anchor(trust_anchor) .build() .await?; - task::spawn(future::select(shutdown.to_signal(), background)); + task::spawn(future::select(shutdown.to_signal(), 
background.fuse())); Ok(Self { inner: client, @@ -133,12 +126,12 @@ impl Client> { } } -impl Client> { +impl Client { pub async fn connect(name_server: SocketAddr) -> Result { let shutdown = Shutdown::new(); let stream = UdpClientStream::::new(name_server); let (client, background) = AsyncClient::connect(stream).await?; - task::spawn(future::select(shutdown.to_signal(), background)); + task::spawn(future::select(shutdown.to_signal(), background.fuse())); Ok(Self { inner: client, @@ -148,87 +141,31 @@ impl Client> { } impl Client -where C: DnsHandle +where C: DnsHandle { - pub async fn lookup(&mut self, query: Query, options: DnsRequestOptions) -> Result { - let resp = self.inner.lookup(query, options).await?; - Ok(resp) + pub async fn lookup(&mut self, query: Query) -> Result { + let client_resp = self + .inner + .query(query.name().clone(), query.query_class(), query.query_type()) + .await?; + Ok(client_resp) } } #[cfg(test)] mod mock { use super::*; - use futures::{channel::mpsc, future, Stream, StreamExt}; - use std::{ - fmt, - fmt::Display, - net::SocketAddr, - pin::Pin, - sync::Arc, - task::{Context, Poll}, - }; + use crate::dns::mock::{DefaultOnSend, MockClientHandle}; + use std::sync::Arc; use tari_shutdown::Shutdown; - use tokio::task; - use trust_dns_client::{ - client::AsyncClient, - op::Message, - proto::{ - error::ProtoError, - xfer::{DnsClientStream, DnsMultiplexerSerialResponse, SerialMessage}, - StreamHandle, - }, - rr::Record, - }; - - pub struct MockStream { - receiver: mpsc::UnboundedReceiver>, - answers: HashMap<&'static str, Vec>, - } - - impl DnsClientStream for MockStream { - fn name_server_addr(&self) -> SocketAddr { - ([0u8, 0, 0, 0], 53).into() - } - } - - impl Display for MockStream { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "MockStream") - } - } - - impl Stream for MockStream { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let req = match futures::ready!(self.receiver.poll_next_unpin(cx)) { - Some(r) => r, - None => return Poll::Ready(None), - }; - let req = Message::from_vec(&req).unwrap(); - let name = req.queries()[0].name().to_string(); - let mut msg = Message::new(); - let answers = self.answers.get(name.as_str()).into_iter().flatten().cloned(); - msg.set_id(req.id()).add_answers(answers); - Poll::Ready(Some(Ok(SerialMessage::new( - msg.to_vec().unwrap(), - self.name_server_addr(), - )))) - } - } - - impl Client> { - pub async fn connect_mock(answers: HashMap<&'static str, Vec>) -> Result { - let (tx, rx) = mpsc::unbounded(); - let stream = future::ready(Ok(MockStream { receiver: rx, answers })); - let (client, background) = AsyncClient::new(stream, Box::new(StreamHandle::new(tx)), None).await?; + use trust_dns_client::proto::error::ProtoError; - let shutdown = Shutdown::new(); - task::spawn(future::select(shutdown.to_signal(), background)); + impl Client> { + pub async fn connect_mock(messages: Vec>) -> Result { + let client = MockClientHandle::mock(messages); Ok(Self { inner: client, - shutdown: Arc::new(shutdown), + shutdown: Arc::new(Shutdown::new()), }) } } diff --git a/base_layer/p2p/src/dns/mock.rs b/base_layer/p2p/src/dns/mock.rs new file mode 100644 index 0000000000..9dec5155cb --- /dev/null +++ b/base_layer/p2p/src/dns/mock.rs @@ -0,0 +1,105 @@ +// Copyright 2021, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. 
Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use futures::{future, stream, Future}; +use std::{error::Error, pin::Pin, sync::Arc}; +use trust_dns_client::{ + op::{Message, Query}, + proto::{ + error::ProtoError, + xfer::{DnsHandle, DnsRequest, DnsResponse}, + }, + rr::Record, +}; + +#[derive(Clone)] +pub struct MockClientHandle { + messages: Arc>>, + on_send: O, +} + +impl MockClientHandle { + /// constructs a new MockClient which returns each Message one after the other + pub fn mock(messages: Vec>) -> Self { + println!("MockClientHandle::mock message count: {}", messages.len()); + + MockClientHandle { + messages: Arc::new(messages), + on_send: DefaultOnSend, + } + } +} + +impl DnsHandle for MockClientHandle +where E: From + Error + Clone + Send + Sync + Unpin + 'static +{ + type Error = E; + type Response = stream::Once>>; + + fn send>(&mut self, _: R) -> Self::Response { + let responses = (*self.messages) + .clone() + .into_iter() + .fold(Result::<_, E>::Ok(Message::new()), |msg, resp| { + msg.and_then(|mut msg| { + resp.map(move |resp| { + msg.add_answers(resp.answers().iter().cloned()); + msg + }) + }) + }) + .map(DnsResponse::from); + + // let stream = stream::unfold(messages, |mut msgs| async move { + // let msg = msgs.pop()?; + // Some((msg, msgs)) + // }); + + stream::once(future::ready(responses)) + } +} + +pub fn message(query: Query, answers: Vec, name_servers: Vec, additionals: Vec) -> Message { + let mut message = Message::new(); + message.add_query(query); + message.insert_answers(answers); + message.insert_name_servers(name_servers); + message.insert_additionals(additionals); + message +} + +pub trait OnSend: Clone + Send + Sync + 'static { + fn on_send( + &mut self, + response: Result, + ) -> Pin> + Send>> + where + E: From + Send + 'static, + { + Box::pin(future::ready(response)) + } +} + +#[derive(Clone)] +pub struct DefaultOnSend; + +impl OnSend for DefaultOnSend {} diff --git a/base_layer/p2p/src/dns/mod.rs b/base_layer/p2p/src/dns/mod.rs index 197b788236..8c49de1993 100644 --- a/base_layer/p2p/src/dns/mod.rs +++ b/base_layer/p2p/src/dns/mod.rs @@ -4,6 +4,9 @@ pub use client::DnsClient; mod error; pub use error::DnsClientError; +#[cfg(test)] +pub(crate) mod mock; + use 
trust_dns_client::proto::rr::dnssec::{public_key::Rsa, TrustAnchor}; #[inline] diff --git a/base_layer/p2p/src/initialization.rs b/base_layer/p2p/src/initialization.rs index 9264a5033c..0915479612 100644 --- a/base_layer/p2p/src/initialization.rs +++ b/base_layer/p2p/src/initialization.rs @@ -19,20 +19,20 @@ // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#![allow(dead_code)] use crate::{ - comms_connector::{InboundDomainConnector, PeerMessage, PubsubDomainConnector}, + comms_connector::{InboundDomainConnector, PubsubDomainConnector}, peer_seeds::{DnsSeedResolver, SeedPeer}, transport::{TorConfig, TransportType}, MAJOR_NETWORK_VERSION, MINOR_NETWORK_VERSION, }; use fs2::FileExt; -use futures::{channel::mpsc, future, Sink}; +use futures::future; use log::*; use rand::{distributions::Alphanumeric, thread_rng, Rng}; use std::{ - error::Error, fs::File, iter, net::SocketAddr, @@ -47,7 +47,6 @@ use tari_comms::{ multiaddr::Multiaddr, peer_manager::{NodeIdentity, Peer, PeerFeatures, PeerManagerError}, pipeline, - pipeline::SinkService, protocol::{ messaging::{MessagingEventSender, MessagingProtocolExtension}, rpc::RpcServer, @@ -71,7 +70,7 @@ use tari_storage::{ LMDBWrapper, }; use thiserror::Error; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc}; use tower::ServiceBuilder; const LOG_TARGET: &str = "p2p::initialization"; @@ -158,18 +157,14 @@ pub struct CommsConfig { } /// Initialize Tari Comms configured for tests -pub async fn initialize_local_test_comms( +pub async fn initialize_local_test_comms( node_identity: Arc, - connector: InboundDomainConnector, + connector: InboundDomainConnector, data_path: &str, discovery_request_timeout: Duration, seed_peers: Vec, shutdown_signal: ShutdownSignal, -) -> Result<(CommsNode, Dht, MessagingEventSender), CommsInitializationError> -where - TSink: Sink> + Unpin + Clone + Send + Sync + 'static, - TSink::Error: Error + Send + Sync, -{ +) -> Result<(CommsNode, Dht, MessagingEventSender), CommsInitializationError> { let peer_database_name = { let mut rng = thread_rng(); iter::repeat(()) @@ -230,7 +225,7 @@ where .with_inbound_pipeline( ServiceBuilder::new() .layer(dht.inbound_middleware_layer()) - .service(SinkService::new(connector)), + .service(connector), ) .build(); @@ -319,15 +314,11 @@ async fn initialize_hidden_service( builder.build().await } -async fn configure_comms_and_dht( +async fn configure_comms_and_dht( builder: CommsBuilder, config: &CommsConfig, - connector: InboundDomainConnector, -) -> Result<(UnspawnedCommsNode, Dht), CommsInitializationError> -where - TSink: Sink> + Unpin + Clone + Send + Sync + 'static, - TSink::Error: Error + Send + Sync, -{ + connector: InboundDomainConnector, +) -> Result<(UnspawnedCommsNode, Dht), CommsInitializationError> { let file_lock = acquire_exclusive_file_lock(&config.datastore_path)?; let datastore = LMDBBuilder::new() @@ -391,7 +382,7 @@ where .with_inbound_pipeline( ServiceBuilder::new() .layer(dht.inbound_middleware_layer()) - .service(SinkService::new(connector)), + .service(connector), ) .build(); diff --git a/base_layer/p2p/src/lib.rs b/base_layer/p2p/src/lib.rs index c21dace083..93cb45d1db 100644 --- a/base_layer/p2p/src/lib.rs +++ b/base_layer/p2p/src/lib.rs @@ -20,8 +20,6 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
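The MockClientHandle added above answers every request by folding its canned messages into a single reply, short-circuiting on the first error. The fold is easier to read with plain types; this sketch keeps only that accumulate-or-fail logic and uses hypothetical FakeMessage/merge_messages names rather than the trust_dns_client types:

```rust
// Simplified stand-in for trust_dns_client::op::Message.
#[derive(Default, Clone, Debug, PartialEq)]
struct FakeMessage {
    answers: Vec<String>,
}

// Merge a list of canned results into one message, or return the first error
// encountered, mirroring the fold in MockClientHandle::send.
fn merge_messages<E>(messages: Vec<Result<FakeMessage, E>>) -> Result<FakeMessage, E> {
    messages.into_iter().fold(Ok(FakeMessage::default()), |acc, next| {
        acc.and_then(|mut merged| {
            next.map(|msg| {
                merged.answers.extend(msg.answers);
                merged
            })
        })
    })
}
```

In the real mock the merged message is wrapped in a DnsResponse and yielded from a one-item stream::once, which satisfies the stream type that DnsHandle::send must return.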
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -// Needed to make futures::select! work -#![recursion_limit = "256"] #![cfg_attr(not(debug_assertions), deny(unused_variables))] #![cfg_attr(not(debug_assertions), deny(unused_imports))] #![cfg_attr(not(debug_assertions), deny(dead_code))] diff --git a/base_layer/p2p/src/peer_seeds.rs b/base_layer/p2p/src/peer_seeds.rs index 24ce3c4dda..4698770fa1 100644 --- a/base_layer/p2p/src/peer_seeds.rs +++ b/base_layer/p2p/src/peer_seeds.rs @@ -120,6 +120,8 @@ mod test { use super::*; use tari_utilities::hex::Hex; + const TEST_NAME: &str = "test.local."; + mod peer_seed { use super::*; @@ -182,76 +184,88 @@ mod test { mod peer_seed_resolver { use super::*; - use std::{collections::HashMap, iter::FromIterator}; - use trust_dns_client::rr::{rdata, RData, Record, RecordType}; + use crate::dns::mock; + use trust_dns_client::{ + proto::{ + op::Query, + rr::{DNSClass, Name}, + xfer::DnsResponse, + }, + rr::{rdata, RData, Record, RecordType}, + }; #[ignore = "This test requires network IO and is mostly useful during development"] - #[tokio_macros::test] - async fn it_returns_an_empty_vec_if_all_seeds_are_invalid() { + #[tokio::test] + async fn it_returns_seeds_from_real_address() { let mut resolver = DnsSeedResolver { client: DnsClient::connect("1.1.1.1:53".parse().unwrap()).await.unwrap(), }; - let seeds = resolver.resolve("tari.com").await.unwrap(); - assert!(seeds.is_empty()); + let seeds = resolver.resolve("seeds.weatherwax.tari.com").await.unwrap(); + assert!(!seeds.is_empty()); } - fn create_txt_record(contents: Vec<&str>) -> Record { + fn create_txt_record(contents: Vec<&str>) -> DnsResponse { + let mut resp_query = Query::query(Name::from_str(TEST_NAME).unwrap(), RecordType::TXT); + resp_query.set_query_class(DNSClass::IN); let mut record = Record::new(); record .set_record_type(RecordType::TXT) .set_rdata(RData::TXT(rdata::TXT::new( contents.into_iter().map(ToString::to_string).collect(), ))); - record + + mock::message(resp_query, vec![record], vec![], vec![]).into() } - #[tokio_macros::test] + #[tokio::test] async fn it_returns_peer_seeds() { - let records = HashMap::from_iter([("test.local.", vec![ + let records = vec![ // Multiple addresses(works) - create_txt_record(vec![ - "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a::/ip4/127.0.0.1/tcp/8000::/\ + Ok(create_txt_record(vec![ + "fab24c542183073996ddf3a6c73ff8b8562fed351d252ec5cb8f269d1ad92f0c::/ip4/127.0.0.1/tcp/8000::/\ onion3/bsmuof2cn4y2ysz253gzsvg3s72fcgh4f3qcm3hdlxdtcwe6al2dicyd:1234", - ]), + ])), // Misc - create_txt_record(vec!["v=spf1 include:_spf.spf.com ~all"]), + Ok(create_txt_record(vec!["v=spf1 include:_spf.spf.com ~all"])), // Single address (works) - create_txt_record(vec![ + Ok(create_txt_record(vec![ "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a::/ip4/127.0.0.1/tcp/8000", - ]), + ])), // Single address trailing delim - create_txt_record(vec![ + Ok(create_txt_record(vec![ "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a::/ip4/127.0.0.1/tcp/8000::", - ]), + ])), // Invalid public key - create_txt_record(vec![ + Ok(create_txt_record(vec![ "07e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a::/ip4/127.0.0.1/tcp/8000", - ]), + ])), // No Address with delim - create_txt_record(vec![ + Ok(create_txt_record(vec![ "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a::", - ]), + ])), // No Address no delim - 
create_txt_record(vec!["06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a"]), + Ok(create_txt_record(vec![ + "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a", + ])), // Invalid address - create_txt_record(vec![ + Ok(create_txt_record(vec![ "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a::/onion3/invalid:1234", - ]), - ])]); + ])), + ]; let mut resolver = DnsSeedResolver { client: DnsClient::connect_mock(records).await.unwrap(), }; - let seeds = resolver.resolve("test.local.").await.unwrap(); + let seeds = resolver.resolve(TEST_NAME).await.unwrap(); assert_eq!(seeds.len(), 2); assert_eq!( seeds[0].public_key.to_hex(), - "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a" + "fab24c542183073996ddf3a6c73ff8b8562fed351d252ec5cb8f269d1ad92f0c" ); + assert_eq!(seeds[0].addresses.len(), 2); assert_eq!( seeds[1].public_key.to_hex(), "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a" ); - assert_eq!(seeds[0].addresses.len(), 2); assert_eq!(seeds[1].addresses.len(), 1); } } diff --git a/base_layer/p2p/src/services/liveness/mock.rs b/base_layer/p2p/src/services/liveness/mock.rs index fa5f2f72c1..470cca84c2 100644 --- a/base_layer/p2p/src/services/liveness/mock.rs +++ b/base_layer/p2p/src/services/liveness/mock.rs @@ -36,9 +36,9 @@ use std::sync::{ RwLock, }; -use tari_crypto::tari_utilities::acquire_write_lock; +use tari_crypto::tari_utilities::{acquire_read_lock, acquire_write_lock}; use tari_service_framework::{reply_channel, reply_channel::RequestContext}; -use tokio::sync::{broadcast, broadcast::SendError}; +use tokio::sync::{broadcast, broadcast::error::SendError}; const LOG_TARGET: &str = "p2p::liveness_mock"; @@ -69,7 +69,8 @@ impl LivenessMockState { } pub async fn publish_event(&self, event: LivenessEvent) -> Result<(), SendError>> { - acquire_write_lock!(self.event_publisher).send(Arc::new(event))?; + let lock = acquire_read_lock!(self.event_publisher); + lock.send(Arc::new(event))?; Ok(()) } diff --git a/base_layer/p2p/src/services/liveness/service.rs b/base_layer/p2p/src/services/liveness/service.rs index 65e85e0ea8..6f933f8e5c 100644 --- a/base_layer/p2p/src/services/liveness/service.rs +++ b/base_layer/p2p/src/services/liveness/service.rs @@ -49,6 +49,7 @@ use tari_comms_dht::{ use tari_service_framework::reply_channel::RequestContext; use tari_shutdown::ShutdownSignal; use tokio::time; +use tokio_stream::wrappers; /// Service responsible for testing Liveness of Peers. pub struct LivenessService { @@ -59,7 +60,7 @@ pub struct LivenessService { connectivity: ConnectivityRequester, outbound_messaging: OutboundMessageRequester, event_publisher: LivenessEventSender, - shutdown_signal: Option, + shutdown_signal: ShutdownSignal, } impl LivenessService @@ -85,7 +86,7 @@ where connectivity, outbound_messaging, event_publisher, - shutdown_signal: Some(shutdown_signal), + shutdown_signal, config, } } @@ -100,39 +101,36 @@ where pin_mut!(request_stream); let mut ping_tick = match self.config.auto_ping_interval { - Some(interval) => Either::Left(time::interval_at((Instant::now() + interval).into(), interval)), + Some(interval) => Either::Left(wrappers::IntervalStream::new(time::interval_at( + (Instant::now() + interval).into(), + interval, + ))), None => Either::Right(futures::stream::iter(iter::empty())), - } - .fuse(); - - let mut shutdown_signal = self - .shutdown_signal - .take() - .expect("Liveness service initialized without shutdown signal"); + }; loop { - futures::select! { + tokio::select! 
{ // Requests from the handle - request_context = request_stream.select_next_some() => { + Some(request_context) = request_stream.next() => { let (request, reply_tx) = request_context.split(); let _ = reply_tx.send(self.handle_request(request).await); }, // Tick events - _ = ping_tick.select_next_some() => { + Some(_) = ping_tick.next() => { if let Err(err) = self.start_ping_round().await { warn!(target: LOG_TARGET, "Error when pinging peers: {}", err); } }, // Incoming messages from the Comms layer - msg = ping_stream.select_next_some() => { + Some(msg) = ping_stream.next() => { if let Err(err) = self.handle_incoming_message(msg).await { warn!(target: LOG_TARGET, "Failed to handle incoming PingPong message: {}", err); } }, - _ = shutdown_signal => { + _ = self.shutdown_signal.wait() => { info!(target: LOG_TARGET, "Liveness service shutting down because the shutdown signal was received"); break; } @@ -306,10 +304,7 @@ mod test { proto::liveness::MetadataKey, services::liveness::{handle::LivenessHandle, state::Metadata}, }; - use futures::{ - channel::{mpsc, oneshot}, - stream, - }; + use futures::stream; use rand::rngs::OsRng; use std::time::Duration; use tari_comms::{ @@ -325,9 +320,12 @@ mod test { use tari_crypto::keys::PublicKey; use tari_service_framework::reply_channel; use tari_shutdown::Shutdown; - use tokio::{sync::broadcast, task}; + use tokio::{ + sync::{broadcast, mpsc, oneshot}, + task, + }; - #[tokio_macros::test_basic] + #[tokio::test] async fn get_ping_pong_count() { let mut state = LivenessState::new(); state.inc_pings_received(); @@ -369,7 +367,7 @@ mod test { assert_eq!(res, 2); } - #[tokio_macros::test] + #[tokio::test] async fn send_ping() { let (connectivity, mock) = create_connectivity_mock(); mock.spawn(); @@ -401,8 +399,9 @@ mod test { let node_id = NodeId::from_key(&pk); // Receive outbound request task::spawn(async move { - match outbound_rx.select_next_some().await { - DhtOutboundRequest::SendMessage(_, _, reply_tx) => { + #[allow(clippy::single_match)] + match outbound_rx.recv().await { + Some(DhtOutboundRequest::SendMessage(_, _, reply_tx)) => { let (_, rx) = oneshot::channel(); reply_tx .send(SendMessageResponse::Queued( @@ -410,6 +409,7 @@ mod test { )) .unwrap(); }, + None => {}, } }); @@ -445,7 +445,7 @@ mod test { } } - #[tokio_macros::test] + #[tokio::test] async fn handle_message_ping() { let state = LivenessState::new(); @@ -478,10 +478,10 @@ mod test { task::spawn(service.run()); // Test oms got request to send message - unwrap_oms_send_msg!(outbound_rx.select_next_some().await); + unwrap_oms_send_msg!(outbound_rx.recv().await.unwrap()); } - #[tokio_macros::test_basic] + #[tokio::test] async fn handle_message_pong() { let mut state = LivenessState::new(); @@ -516,9 +516,9 @@ mod test { task::spawn(service.run()); // Listen for the pong event - let subscriber = publisher.subscribe(); + let mut subscriber = publisher.subscribe(); - let event = time::timeout(Duration::from_secs(10), subscriber.fuse().select_next_some()) + let event = time::timeout(Duration::from_secs(10), subscriber.recv()) .await .unwrap() .unwrap(); @@ -530,12 +530,12 @@ mod test { _ => panic!("Unexpected event"), } - shutdown.trigger().unwrap(); + shutdown.trigger(); // No further events (malicious_msg was ignored) - let mut subscriber = publisher.subscribe().fuse(); + let mut subscriber = publisher.subscribe(); drop(publisher); - let msg = subscriber.next().await; - assert!(msg.is_none()); + let msg = subscriber.recv().await; + assert!(msg.is_err()); } } diff --git 
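The liveness loop above is representative of the select! migration across these services: the optional auto-ping interval becomes a tokio_stream IntervalStream (wrapped in Either so the disabled case has the same type), request and message channels are polled with Some(..) patterns, and shutdown is awaited explicitly. A self-contained sketch of that shape, where Request and the watch-based shutdown are stand-ins for the tari types:

```rust
use futures::{future::Either, stream, StreamExt};
use std::time::Duration;
use tokio::{
    sync::{mpsc, watch},
    time,
};
use tokio_stream::wrappers::IntervalStream;

struct Request;

async fn run(
    mut requests: mpsc::Receiver<Request>,
    auto_ping_interval: Option<Duration>,
    mut shutdown: watch::Receiver<bool>,
) {
    // Only create a ticker when auto-pinging is configured; Either lets both
    // arms be driven as the same Stream type.
    let mut ping_tick = match auto_ping_interval {
        Some(interval) => Either::Left(IntervalStream::new(time::interval(interval))),
        None => Either::Right(stream::pending::<time::Instant>()),
    };

    loop {
        tokio::select! {
            Some(_request) = requests.recv() => { /* handle API request */ },
            Some(_tick) = ping_tick.next() => { /* start a ping round */ },
            _ = shutdown.changed() => break,
        }
    }
}
```

Unlike futures::select!, there is no `complete =>` arm: a branch whose channel has closed is simply disabled by its Some(..) pattern.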
a/base_layer/p2p/tests/services/liveness.rs b/base_layer/p2p/tests/services/liveness.rs index ab9a66cf59..2505c15543 100644 --- a/base_layer/p2p/tests/services/liveness.rs +++ b/base_layer/p2p/tests/services/liveness.rs @@ -35,16 +35,15 @@ use tari_p2p::{ }; use tari_service_framework::{RegisterHandle, StackBuilder}; use tari_shutdown::Shutdown; -use tari_test_utils::collect_stream; +use tari_test_utils::collect_try_recv; use tempfile::tempdir; -use tokio::runtime; pub async fn setup_liveness_service( node_identity: Arc, peers: Vec>, data_path: &str, ) -> (LivenessHandle, CommsNode, Dht, Shutdown) { - let (publisher, subscription_factory) = pubsub_connector(runtime::Handle::current(), 100, 20); + let (publisher, subscription_factory) = pubsub_connector(100, 20); let subscription_factory = Arc::new(subscription_factory); let shutdown = Shutdown::new(); let (comms, dht, _) = @@ -75,7 +74,7 @@ fn make_node_identity() -> Arc { )) } -#[tokio_macros::test_basic] +#[tokio::test] async fn end_to_end() { let node_1_identity = make_node_identity(); let node_2_identity = make_node_identity(); @@ -114,34 +113,34 @@ async fn end_to_end() { liveness1.send_ping(node_2_identity.node_id().clone()).await.unwrap(); } - let events = collect_stream!(liveness1_event_stream, take = 18, timeout = Duration::from_secs(20),); + let events = collect_try_recv!(liveness1_event_stream, take = 18, timeout = Duration::from_secs(20)); let ping_count = events .iter() - .filter(|event| matches!(**(**event).as_ref().unwrap(), LivenessEvent::ReceivedPing(_))) + .filter(|event| matches!(&***event, LivenessEvent::ReceivedPing(_))) .count(); assert_eq!(ping_count, 10); let pong_count = events .iter() - .filter(|event| matches!(**(**event).as_ref().unwrap(), LivenessEvent::ReceivedPong(_))) + .filter(|event| matches!(&***event, LivenessEvent::ReceivedPong(_))) .count(); assert_eq!(pong_count, 8); - let events = collect_stream!(liveness2_event_stream, take = 18, timeout = Duration::from_secs(10),); + let events = collect_try_recv!(liveness2_event_stream, take = 18, timeout = Duration::from_secs(10)); let ping_count = events .iter() - .filter(|event| matches!(**(**event).as_ref().unwrap(), LivenessEvent::ReceivedPing(_))) + .filter(|event| matches!(&***event, LivenessEvent::ReceivedPing(_))) .count(); assert_eq!(ping_count, 8); let pong_count = events .iter() - .filter(|event| matches!(**(**event).as_ref().unwrap(), LivenessEvent::ReceivedPong(_))) + .filter(|event| matches!(&***event, LivenessEvent::ReceivedPong(_))) .count(); assert_eq!(pong_count, 10); diff --git a/base_layer/p2p/tests/support/comms_and_services.rs b/base_layer/p2p/tests/support/comms_and_services.rs index 40cc85a710..33bc8fdef7 100644 --- a/base_layer/p2p/tests/support/comms_and_services.rs +++ b/base_layer/p2p/tests/support/comms_and_services.rs @@ -20,27 +20,19 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
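A recurring change in these tests is the move to tokio::sync::broadcast receivers: events arrive via recv() (usually bounded by a timeout), and a closed channel surfaces as Err rather than None, which is why assertions switch from is_none() to is_err(). A small sketch of those semantics, with hypothetical event values:

```rust
use std::time::Duration;
use tokio::{sync::broadcast, time::timeout};

#[tokio::test]
async fn broadcast_event_semantics() {
    let (tx, mut rx) = broadcast::channel::<&'static str>(16);
    tx.send("ping").unwrap();

    // Receive under a timeout, similar to what collect_try_recv! does above.
    let event = timeout(Duration::from_secs(1), rx.recv()).await.unwrap().unwrap();
    assert_eq!(event, "ping");

    // Once every sender is dropped, recv() yields Err(RecvError::Closed).
    drop(tx);
    assert!(rx.recv().await.is_err());
}
```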
-use futures::Sink; -use std::{error::Error, sync::Arc, time::Duration}; +use std::{sync::Arc, time::Duration}; use tari_comms::{peer_manager::NodeIdentity, protocol::messaging::MessagingEventSender, CommsNode}; use tari_comms_dht::Dht; -use tari_p2p::{ - comms_connector::{InboundDomainConnector, PeerMessage}, - initialization::initialize_local_test_comms, -}; +use tari_p2p::{comms_connector::InboundDomainConnector, initialization::initialize_local_test_comms}; use tari_shutdown::ShutdownSignal; -pub async fn setup_comms_services( +pub async fn setup_comms_services( node_identity: Arc, peers: Vec>, - publisher: InboundDomainConnector, + publisher: InboundDomainConnector, data_path: &str, shutdown_signal: ShutdownSignal, -) -> (CommsNode, Dht, MessagingEventSender) -where - TSink: Sink> + Clone + Unpin + Send + Sync + 'static, - TSink::Error: Error + Send + Sync, -{ +) -> (CommsNode, Dht, MessagingEventSender) { let peers = peers.into_iter().map(|ni| ni.to_peer()).collect(); let (comms, dht, messaging_events) = initialize_local_test_comms( node_identity, diff --git a/base_layer/service_framework/Cargo.toml b/base_layer/service_framework/Cargo.toml index 49e829c949..ead5e311f2 100644 --- a/base_layer/service_framework/Cargo.toml +++ b/base_layer/service_framework/Cargo.toml @@ -14,15 +14,15 @@ tari_shutdown = { version = "^0.9", path="../../infrastructure/shutdown" } anyhow = "1.0.32" async-trait = "0.1.50" -futures = { version = "^0.3.1", features=["async-await"]} +futures = { version = "^0.3.16", features=["async-await"]} log = "0.4.8" -thiserror = "1.0.20" -tokio = { version = "0.2.10" } +thiserror = "1.0.26" +tokio = {version="1.10", features=["rt"]} tower-service = { version="0.3.0" } [dev-dependencies] tari_test_utils = { version = "^0.9", path="../../infrastructure/test_utils" } +tokio = {version="1.10", features=["rt-multi-thread", "macros", "time"]} futures-test = { version = "0.3.3" } -tokio-macros = "0.2.5" tower = "0.3.1" diff --git a/base_layer/service_framework/examples/services/service_a.rs b/base_layer/service_framework/examples/services/service_a.rs index c898696415..dfbd9ace93 100644 --- a/base_layer/service_framework/examples/services/service_a.rs +++ b/base_layer/service_framework/examples/services/service_a.rs @@ -69,7 +69,7 @@ impl ServiceA { pin_mut!(request_stream); loop { - futures::select! { + tokio::select! { //Incoming request request_context = request_stream.select_next_some() => { println!("Handling Service A API Request"); @@ -82,7 +82,7 @@ impl ServiceA { response.push_str(request.clone().as_str()); let _ = reply_tx.send(response); }, - _ = shutdown_signal => { + _ = shutdown_signal.wait() => { println!("Service A shutting down because the shutdown signal was received"); break; } diff --git a/base_layer/service_framework/examples/services/service_b.rs b/base_layer/service_framework/examples/services/service_b.rs index decf53ab14..8e74408077 100644 --- a/base_layer/service_framework/examples/services/service_b.rs +++ b/base_layer/service_framework/examples/services/service_b.rs @@ -31,7 +31,7 @@ use tari_service_framework::{ ServiceInitializerContext, }; use tari_shutdown::ShutdownSignal; -use tokio::time::delay_for; +use tokio::time::sleep; use tower::Service; pub struct ServiceB { @@ -67,7 +67,7 @@ impl ServiceB { pin_mut!(request_stream); loop { - futures::select! { + tokio::select! 
{ //Incoming request request_context = request_stream.select_next_some() => { println!("Handling Service B API Request"); @@ -76,7 +76,7 @@ impl ServiceB { response.push_str(request.clone().as_str()); let _ = reply_tx.send(response); }, - _ = shutdown_signal => { + _ = shutdown_signal.wait() => { println!("Service B shutting down because the shutdown signal was received"); break; } @@ -134,7 +134,7 @@ impl ServiceInitializer for ServiceBInitializer { println!("Service B has shutdown and initializer spawned task is now ending"); }); - delay_for(Duration::from_secs(10)).await; + sleep(Duration::from_secs(10)).await; Ok(()) } } diff --git a/base_layer/service_framework/examples/stack_builder_example.rs b/base_layer/service_framework/examples/stack_builder_example.rs index 35fd785ed7..6f150796e7 100644 --- a/base_layer/service_framework/examples/stack_builder_example.rs +++ b/base_layer/service_framework/examples/stack_builder_example.rs @@ -25,9 +25,9 @@ use crate::services::{ServiceAHandle, ServiceAInitializer, ServiceBHandle, Servi use std::time::Duration; use tari_service_framework::StackBuilder; use tari_shutdown::Shutdown; -use tokio::time::delay_for; +use tokio::time::sleep; -#[tokio_macros::main] +#[tokio::main] async fn main() { let mut shutdown = Shutdown::new(); let fut = StackBuilder::new(shutdown.to_signal()) @@ -40,7 +40,7 @@ async fn main() { let mut service_a_handle = handles.expect_handle::(); let mut service_b_handle = handles.expect_handle::(); - delay_for(Duration::from_secs(1)).await; + sleep(Duration::from_secs(1)).await; println!("----------------------------------------------------"); let response_b = service_b_handle.send_msg("Hello B".to_string()).await; println!("Response from Service B: {}", response_b); @@ -51,5 +51,5 @@ async fn main() { let _ = shutdown.trigger(); - delay_for(Duration::from_secs(5)).await; + sleep(Duration::from_secs(5)).await; } diff --git a/base_layer/service_framework/src/reply_channel.rs b/base_layer/service_framework/src/reply_channel.rs index 54cef8d90f..4921ed9601 100644 --- a/base_layer/service_framework/src/reply_channel.rs +++ b/base_layer/service_framework/src/reply_channel.rs @@ -20,26 +20,15 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::{ - channel::{ - mpsc::{self, SendError}, - oneshot, - }, - ready, - stream::FusedStream, - task::Context, - Future, - FutureExt, - Stream, - StreamExt, -}; +use futures::{ready, stream::FusedStream, task::Context, Future, FutureExt, Stream}; use std::{pin::Pin, task::Poll}; use thiserror::Error; +use tokio::sync::{mpsc, oneshot}; use tower_service::Service; /// Create a new Requester/Responder pair which wraps and calls the given service pub fn unbounded() -> (SenderService, Receiver) { - let (tx, rx) = mpsc::unbounded(); + let (tx, rx) = mpsc::unbounded_channel(); (SenderService::new(tx), Receiver::new(rx)) } @@ -81,20 +70,15 @@ impl Service for SenderService { type Future = TransportResponseFuture; type Response = TRes; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.tx.poll_ready(cx).map_err(|err| { - if err.is_disconnected() { - return TransportChannelError::ChannelClosed; - } - - unreachable!("unbounded channels can never be full"); - }) + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + // An unbounded sender is always ready (i.e. 
will never wait to send) + Poll::Ready(Ok(())) } fn call(&mut self, request: TReq) -> Self::Future { let (tx, rx) = oneshot::channel(); - if self.tx.unbounded_send((request, tx)).is_ok() { + if self.tx.send((request, tx)).is_ok() { TransportResponseFuture::new(rx) } else { // We're not able to send (rx closed) so return a future which resolves to @@ -106,8 +90,6 @@ impl Service for SenderService { #[derive(Debug, Error, Eq, PartialEq, Clone)] pub enum TransportChannelError { - #[error("Error occurred when sending: `{0}`")] - SendError(#[from] SendError), #[error("Request was canceled")] Canceled, #[error("The response channel has closed")] @@ -188,23 +170,21 @@ impl RequestContext { } /// Receiver side of the reply channel. -/// This is functionally equivalent to `rx.map(|(req, reply_tx)| RequestContext::new(req, reply_tx))` -/// but is ergonomically better to use with the `futures::select` macro (implements FusedStream) -/// and has a short type signature. pub struct Receiver { rx: Rx, + is_closed: bool, } impl FusedStream for Receiver { fn is_terminated(&self) -> bool { - self.rx.is_terminated() + self.is_closed } } impl Receiver { // Create a new Responder pub fn new(rx: Rx) -> Self { - Self { rx } + Self { rx, is_closed: false } } pub fn close(&mut self) { @@ -216,10 +196,17 @@ impl Stream for Receiver { type Item = RequestContext; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match ready!(self.rx.poll_next_unpin(cx)) { + if self.is_terminated() { + return Poll::Ready(None); + } + + match ready!(self.rx.poll_recv(cx)) { Some((req, tx)) => Poll::Ready(Some(RequestContext::new(req, tx))), // Stream has closed, so we're done - None => Poll::Ready(None), + None => { + self.is_closed = true; + Poll::Ready(None) + }, } } } @@ -227,7 +214,7 @@ impl Stream for Receiver { #[cfg(test)] mod test { use super::*; - use futures::{executor::block_on, future}; + use futures::{executor::block_on, future, StreamExt}; use std::fmt::Debug; use tari_test_utils::unpack_enum; use tower::ServiceExt; @@ -247,7 +234,7 @@ mod test { async fn reply(mut rx: Rx, msg: TResp) where TResp: Debug { - match rx.next().await { + match rx.recv().await { Some((_, tx)) => { tx.send(msg).unwrap(); }, @@ -257,7 +244,7 @@ mod test { #[test] fn requestor_call() { - let (tx, rx) = mpsc::unbounded(); + let (tx, rx) = mpsc::unbounded_channel(); let requestor = SenderService::<_, _>::new(tx); let fut = future::join(requestor.oneshot("PING"), reply(rx, "PONG")); diff --git a/base_layer/service_framework/src/stack.rs b/base_layer/service_framework/src/stack.rs index 2489006d56..ad38a94be3 100644 --- a/base_layer/service_framework/src/stack.rs +++ b/base_layer/service_framework/src/stack.rs @@ -103,7 +103,7 @@ mod test { use tari_shutdown::Shutdown; use tower::service_fn; - #[tokio_macros::test] + #[tokio::test] async fn service_defn_simple() { // This is less of a test and more of a demo of using the short-hand implementation of ServiceInitializer let simple_initializer = |_: ServiceInitializerContext| Ok(()); @@ -155,7 +155,7 @@ mod test { } } - #[tokio_macros::test] + #[tokio::test] async fn service_stack_new() { let shared_state = Arc::new(AtomicUsize::new(0)); diff --git a/base_layer/tari_stratum_ffi/Cargo.toml b/base_layer/tari_stratum_ffi/Cargo.toml index 6598df9f06..9ed773d89f 100644 --- a/base_layer/tari_stratum_ffi/Cargo.toml +++ b/base_layer/tari_stratum_ffi/Cargo.toml @@ -14,7 +14,7 @@ tari_app_grpc = { path = "../../applications/tari_app_grpc" } tari_core = { path = "../../base_layer/core", 
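The reply_channel rewrite above swaps the futures mpsc/oneshot pair for tokio's: each request travels down an unbounded channel together with a oneshot sender for its reply, and poll_ready can always return Ready because an unbounded sender never applies backpressure. A stripped-down sketch of that request/reply shape, with hypothetical call/serve helpers in place of the tower Service plumbing:

```rust
use tokio::sync::{mpsc, oneshot};

type Reply<T> = oneshot::Sender<T>;

fn unbounded<TReq, TRes>() -> (
    mpsc::UnboundedSender<(TReq, Reply<TRes>)>,
    mpsc::UnboundedReceiver<(TReq, Reply<TRes>)>,
) {
    mpsc::unbounded_channel()
}

// Caller side: send the request along with a oneshot for the reply.
// send() only fails if the receiver half has been dropped.
async fn call(tx: &mpsc::UnboundedSender<(String, Reply<String>)>, req: String) -> Option<String> {
    let (reply_tx, reply_rx) = oneshot::channel();
    tx.send((req, reply_tx)).ok()?;
    reply_rx.await.ok()
}

// Responder side: the equivalent of the Receiver/RequestContext pair above.
async fn serve(mut rx: mpsc::UnboundedReceiver<(String, Reply<String>)>) {
    while let Some((req, reply_tx)) = rx.recv().await {
        let _ = reply_tx.send(format!("handled: {}", req));
    }
}
```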
default-features = false, features = ["transactions"]} tari_utilities = "^0.3" libc = "0.2.65" -thiserror = "1.0.20" +thiserror = "1.0.26" hex = "0.4.2" serde = { version="1.0.106", features = ["derive"] } serde_json = "1.0.57" diff --git a/base_layer/wallet/Cargo.toml b/base_layer/wallet/Cargo.toml index c2cf12ca22..b64cb34ab2 100644 --- a/base_layer/wallet/Cargo.toml +++ b/base_layer/wallet/Cargo.toml @@ -7,54 +7,53 @@ version = "0.9.5" edition = "2018" [dependencies] -tari_common_types = { version = "^0.9", path = "../../base_layer/common_types"} -tari_comms = { version = "^0.9", path = "../../comms"} +tari_common_types = { version = "^0.9", path = "../../base_layer/common_types" } +tari_comms = { version = "^0.9", path = "../../comms" } tari_comms_dht = { version = "^0.9", path = "../../comms/dht" } tari_crypto = "0.11.1" tari_key_manager = { version = "^0.9", path = "../key_manager" } tari_p2p = { version = "^0.9", path = "../p2p" } -tari_service_framework = { version = "^0.9", path = "../service_framework"} +tari_service_framework = { version = "^0.9", path = "../service_framework" } tari_shutdown = { version = "^0.9", path = "../../infrastructure/shutdown" } -tari_storage = { version = "^0.9", path = "../../infrastructure/storage"} +tari_storage = { version = "^0.9", path = "../../infrastructure/storage" } aes-gcm = "^0.8" blake2 = "0.9.0" -chrono = { version = "0.4.6", features = ["serde"]} +chrono = { version = "0.4.6", features = ["serde"] } crossbeam-channel = "0.3.8" digest = "0.9.0" -diesel = { version="1.4.7", features = ["sqlite", "serde_json", "chrono"]} +diesel = { version = "1.4.7", features = ["sqlite", "serde_json", "chrono"] } diesel_migrations = "1.4.0" -libsqlite3-sys = { version = ">=0.8.0, <0.13.0", features = ["bundled"], optional=true } +libsqlite3-sys = { version = ">=0.8.0, <0.13.0", features = ["bundled"], optional = true } fs2 = "0.3.0" -futures = { version = "^0.3.1", features =["compat", "std"]} +futures = { version = "^0.3.1", features = ["compat", "std"] } lazy_static = "1.4.0" log = "0.4.6" -log4rs = {version = "1.0.0", features = ["console_appender", "file_appender", "yaml_format"]} +log4rs = { version = "1.0.0", features = ["console_appender", "file_appender", "yaml_format"] } lmdb-zero = "0.4.4" rand = "0.8" -serde = {version = "1.0.89", features = ["derive"] } +serde = { version = "1.0.89", features = ["derive"] } serde_json = "1.0.39" -tokio = { version = "0.2.10", features = ["blocking", "sync"]} +tokio = { version = "1.10", features = ["sync", "macros"] } tower = "0.3.0-alpha.2" tempfile = "3.1.0" -time = {version = "0.1.39"} -thiserror = "1.0.20" +time = { version = "0.1.39" } +thiserror = "1.0.26" bincode = "1.3.1" [dependencies.tari_core] path = "../../base_layer/core" version = "^0.9" default-features = false -features = ["transactions", "mempool_proto", "base_node_proto",] +features = ["transactions", "mempool_proto", "base_node_proto", ] [dev-dependencies] -tari_p2p = { version = "^0.9", path = "../p2p", features=["test-mocks"]} -tari_comms_dht = { version = "^0.9", path = "../../comms/dht", features=["test-mocks"]} +tari_p2p = { version = "^0.9", path = "../p2p", features = ["test-mocks"] } +tari_comms_dht = { version = "^0.9", path = "../../comms/dht", features = ["test-mocks"] } tari_test_utils = { version = "^0.9", path = "../../infrastructure/test_utils" } lazy_static = "1.3.0" env_logger = "0.7.1" -prost = "0.6.1" -tokio-macros = "0.2.4" +prost = "0.8.0" [features] c_integration = [] diff --git 
a/base_layer/wallet/src/base_node_service/handle.rs b/base_layer/wallet/src/base_node_service/handle.rs index 4957823c72..f495479778 100644 --- a/base_layer/wallet/src/base_node_service/handle.rs +++ b/base_layer/wallet/src/base_node_service/handle.rs @@ -21,7 +21,6 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use super::{error::BaseNodeServiceError, service::BaseNodeState}; -use futures::{stream::Fuse, StreamExt}; use std::{sync::Arc, time::Duration}; use tari_common_types::chain_metadata::ChainMetadata; use tari_comms::peer_manager::Peer; @@ -72,8 +71,8 @@ impl BaseNodeServiceHandle { } } - pub fn get_event_stream_fused(&self) -> Fuse { - self.event_stream_sender.subscribe().fuse() + pub fn get_event_stream(&self) -> BaseNodeEventReceiver { + self.event_stream_sender.subscribe() } pub async fn get_chain_metadata(&mut self) -> Result, BaseNodeServiceError> { diff --git a/base_layer/wallet/src/base_node_service/monitor.rs b/base_layer/wallet/src/base_node_service/monitor.rs index 0bd1180c3e..8e0298ca27 100644 --- a/base_layer/wallet/src/base_node_service/monitor.rs +++ b/base_layer/wallet/src/base_node_service/monitor.rs @@ -136,7 +136,7 @@ impl BaseNodeMonitor { }) .await; - time::delay_for(self.interval).await + time::sleep(self.interval).await } // loop only exits on shutdown/error diff --git a/base_layer/wallet/src/connectivity_service/handle.rs b/base_layer/wallet/src/connectivity_service/handle.rs index ac17805edb..5a35696e14 100644 --- a/base_layer/wallet/src/connectivity_service/handle.rs +++ b/base_layer/wallet/src/connectivity_service/handle.rs @@ -22,16 +22,12 @@ use super::service::OnlineStatus; use crate::connectivity_service::{error::WalletConnectivityError, watch::Watch}; -use futures::{ - channel::{mpsc, oneshot}, - SinkExt, -}; use tari_comms::{ peer_manager::{NodeId, Peer}, protocol::rpc::RpcClientLease, }; use tari_core::base_node::{rpc::BaseNodeWalletRpcClient, sync::rpc::BaseNodeSyncRpcClient}; -use tokio::sync::watch; +use tokio::sync::{mpsc, oneshot, watch}; pub enum WalletConnectivityRequest { ObtainBaseNodeWalletRpcClient(oneshot::Sender>), diff --git a/base_layer/wallet/src/connectivity_service/initializer.rs b/base_layer/wallet/src/connectivity_service/initializer.rs index d0c2b94126..1610a834e3 100644 --- a/base_layer/wallet/src/connectivity_service/initializer.rs +++ b/base_layer/wallet/src/connectivity_service/initializer.rs @@ -30,8 +30,8 @@ use super::{handle::WalletConnectivityHandle, service::WalletConnectivityService, watch::Watch}; use crate::{base_node_service::config::BaseNodeServiceConfig, connectivity_service::service::OnlineStatus}; -use futures::channel::mpsc; use tari_service_framework::{async_trait, ServiceInitializationError, ServiceInitializer, ServiceInitializerContext}; +use tokio::sync::mpsc; pub struct WalletConnectivityInitializer { config: BaseNodeServiceConfig, @@ -59,8 +59,13 @@ impl ServiceInitializer for WalletConnectivityInitializer { context.spawn_until_shutdown(move |handles| { let connectivity = handles.expect_handle(); - let service = - WalletConnectivityService::new(config, receiver, base_node_watch, online_status_watch, connectivity); + let service = WalletConnectivityService::new( + config, + receiver, + base_node_watch.get_receiver(), + online_status_watch, + connectivity, + ); service.start() }); diff --git a/base_layer/wallet/src/connectivity_service/service.rs b/base_layer/wallet/src/connectivity_service/service.rs index cf0901339f..950b9a9a72 100644 --- 
a/base_layer/wallet/src/connectivity_service/service.rs +++ b/base_layer/wallet/src/connectivity_service/service.rs @@ -24,16 +24,8 @@ use crate::{ base_node_service::config::BaseNodeServiceConfig, connectivity_service::{error::WalletConnectivityError, handle::WalletConnectivityRequest, watch::Watch}, }; -use core::mem; -use futures::{ - channel::{mpsc, oneshot}, - future, - future::Either, - stream::Fuse, - FutureExt, - StreamExt, -}; use log::*; +use std::{mem, time::Duration}; use tari_comms::{ connectivity::ConnectivityRequester, peer_manager::{NodeId, Peer}, @@ -41,7 +33,11 @@ use tari_comms::{ PeerConnection, }; use tari_core::base_node::{rpc::BaseNodeWalletRpcClient, sync::rpc::BaseNodeSyncRpcClient}; -use tokio::{time, time::Duration}; +use tokio::{ + sync::{mpsc, oneshot, watch}, + time, + time::MissedTickBehavior, +}; const LOG_TARGET: &str = "wallet::connectivity"; @@ -55,9 +51,9 @@ pub enum OnlineStatus { pub struct WalletConnectivityService { config: BaseNodeServiceConfig, - request_stream: Fuse>, + request_stream: mpsc::Receiver, connectivity: ConnectivityRequester, - base_node_watch: Watch>, + base_node_watch: watch::Receiver>, pools: Option, online_status_watch: Watch, pending_requests: Vec, @@ -72,13 +68,13 @@ impl WalletConnectivityService { pub(super) fn new( config: BaseNodeServiceConfig, request_stream: mpsc::Receiver, - base_node_watch: Watch>, + base_node_watch: watch::Receiver>, online_status_watch: Watch, connectivity: ConnectivityRequester, ) -> Self { Self { config, - request_stream: request_stream.fuse(), + request_stream, connectivity, base_node_watch, pools: None, @@ -89,20 +85,26 @@ impl WalletConnectivityService { pub async fn start(mut self) { debug!(target: LOG_TARGET, "Wallet connectivity service has started."); - let mut base_node_watch_rx = self.base_node_watch.get_receiver().fuse(); + let mut check_connection = + time::interval_at(time::Instant::now() + Duration::from_secs(5), Duration::from_secs(5)); + check_connection.set_missed_tick_behavior(MissedTickBehavior::Delay); loop { - let mut check_connection = time::delay_for(Duration::from_secs(1)).fuse(); - futures::select! { - req = self.request_stream.select_next_some() => { - self.handle_request(req).await; - }, - maybe_peer = base_node_watch_rx.select_next_some() => { - if maybe_peer.is_some() { + tokio::select! { + // BIASED: select branches are in order of priority + biased; + + Ok(_) = self.base_node_watch.changed() => { + if self.base_node_watch.borrow().is_some() { // This will block the rest until the connection is established. This is what we want. 
self.setup_base_node_connection().await; } }, - _ = check_connection => { + + Some(req) = self.request_stream.recv() => { + self.handle_request(req).await; + }, + + _ = check_connection.tick() => { self.check_connection().await; } } @@ -234,7 +236,7 @@ impl WalletConnectivityService { self.set_online_status(OnlineStatus::Offline); } warn!(target: LOG_TARGET, "{}", e); - time::delay_for(self.config.base_node_monitor_refresh_interval).await; + time::sleep(self.config.base_node_monitor_refresh_interval).await; continue; }, } @@ -272,13 +274,15 @@ impl WalletConnectivityService { } async fn try_dial_peer(&mut self, peer: NodeId) -> Result, WalletConnectivityError> { - let recv_fut = self.base_node_watch.recv(); - futures::pin_mut!(recv_fut); - let dial_fut = self.connectivity.dial_peer(peer); - futures::pin_mut!(dial_fut); - match future::select(recv_fut, dial_fut).await { - Either::Left(_) => Ok(None), - Either::Right((conn, _)) => Ok(Some(conn?)), + tokio::select! { + biased; + + _ = self.base_node_watch.changed() => { + Ok(None) + } + result = self.connectivity.dial_peer(peer) => { + Ok(Some(result?)) + } } } @@ -304,8 +308,8 @@ impl ReplyOneshot { pub fn is_canceled(&self) -> bool { use ReplyOneshot::*; match self { - WalletRpc(tx) => tx.is_canceled(), - SyncRpc(tx) => tx.is_canceled(), + WalletRpc(tx) => tx.is_closed(), + SyncRpc(tx) => tx.is_closed(), } } } diff --git a/base_layer/wallet/src/connectivity_service/test.rs b/base_layer/wallet/src/connectivity_service/test.rs index 7c24ef5b46..9a8f5a2da9 100644 --- a/base_layer/wallet/src/connectivity_service/test.rs +++ b/base_layer/wallet/src/connectivity_service/test.rs @@ -23,7 +23,7 @@ use super::service::WalletConnectivityService; use crate::connectivity_service::{watch::Watch, OnlineStatus, WalletConnectivityHandle}; use core::convert; -use futures::{channel::mpsc, future}; +use futures::future; use std::{iter, sync::Arc}; use tari_comms::{ peer_manager::PeerFeatures, @@ -39,7 +39,10 @@ use tari_comms::{ }; use tari_shutdown::Shutdown; use tari_test_utils::runtime::spawn_until_shutdown; -use tokio::{sync::Barrier, task}; +use tokio::{ + sync::{mpsc, Barrier}, + task, +}; async fn setup() -> ( WalletConnectivityHandle, @@ -57,7 +60,7 @@ async fn setup() -> ( let service = WalletConnectivityService::new( Default::default(), rx, - base_node_watch, + base_node_watch.get_receiver(), online_status_watch, connectivity, ); @@ -70,7 +73,7 @@ async fn setup() -> ( (handle, mock_server, mock_state, shutdown) } -#[tokio_macros::test] +#[tokio::test] async fn it_dials_peer_when_base_node_is_set() { let (mut handle, mock_server, mock_state, _shutdown) = setup().await; let base_node_peer = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -92,7 +95,7 @@ async fn it_dials_peer_when_base_node_is_set() { assert!(rpc_client.is_connected()); } -#[tokio_macros::test] +#[tokio::test] async fn it_resolves_many_pending_rpc_session_requests() { let (mut handle, mock_server, mock_state, _shutdown) = setup().await; let base_node_peer = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -122,7 +125,7 @@ async fn it_resolves_many_pending_rpc_session_requests() { assert!(results.into_iter().map(Result::unwrap).all(convert::identity)); } -#[tokio_macros::test] +#[tokio::test] async fn it_changes_to_a_new_base_node() { let (mut handle, mock_server, mock_state, _shutdown) = setup().await; let base_node_peer1 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -138,7 +141,7 @@ async fn it_changes_to_a_new_base_node() { 
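The connectivity service loop above combines three tokio 1.x idioms: a biased select! so base-node changes preempt queued requests, watch::Receiver::changed()/borrow() in place of the old async recv() on the watch, and an interval with MissedTickBehavior::Delay for the periodic connection check. A sketch of that loop shape, with Request and a String peer id standing in for the tari types:

```rust
use std::time::Duration;
use tokio::{
    sync::{mpsc, watch},
    time::{self, MissedTickBehavior},
};

struct Request;

async fn run(mut requests: mpsc::Receiver<Request>, mut base_node: watch::Receiver<Option<String>>) {
    let mut check_connection =
        time::interval_at(time::Instant::now() + Duration::from_secs(5), Duration::from_secs(5));
    // If ticks are missed (e.g. while blocked establishing a connection),
    // fire the next one after a full delay instead of bursting to catch up.
    check_connection.set_missed_tick_behavior(MissedTickBehavior::Delay);

    loop {
        tokio::select! {
            // Branches are polled in written order, so a base node change
            // always wins over pending requests and ticks.
            biased;

            Ok(_) = base_node.changed() => {
                if base_node.borrow().is_some() {
                    // (re)establish the base node connection
                }
            },
            Some(_request) = requests.recv() => { /* handle request */ },
            _ = check_connection.tick() => { /* periodic connection check */ },
        }
    }
}
```

The try_dial_peer change follows the same idea in miniature: racing connectivity.dial_peer against base_node_watch.changed() lets a newly selected base node cancel an in-flight dial.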
mock_state.await_call_count(2).await; mock_state.expect_dial_peer(base_node_peer1.node_id()).await; - assert_eq!(mock_state.count_calls_containing("AddManagedPeer").await, 2); + assert!(mock_state.count_calls_containing("AddManagedPeer").await >= 1); let _ = mock_state.take_calls().await; let rpc_client = handle.obtain_base_node_wallet_rpc_client().await.unwrap(); @@ -149,13 +152,12 @@ async fn it_changes_to_a_new_base_node() { mock_state.await_call_count(2).await; mock_state.expect_dial_peer(base_node_peer2.node_id()).await; - assert_eq!(mock_state.count_calls_containing("AddManagedPeer").await, 2); let rpc_client = handle.obtain_base_node_wallet_rpc_client().await.unwrap(); assert!(rpc_client.is_connected()); } -#[tokio_macros::test] +#[tokio::test] async fn it_gracefully_handles_connect_fail_reconnect() { let (mut handle, mock_server, mock_state, _shutdown) = setup().await; let base_node_peer = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -198,7 +200,7 @@ async fn it_gracefully_handles_connect_fail_reconnect() { pending_request.await.unwrap(); } -#[tokio_macros::test] +#[tokio::test] async fn it_gracefully_handles_multiple_connection_failures() { let (mut handle, mock_server, mock_state, _shutdown) = setup().await; let base_node_peer = build_node_identity(PeerFeatures::COMMUNICATION_NODE); diff --git a/base_layer/wallet/src/connectivity_service/watch.rs b/base_layer/wallet/src/connectivity_service/watch.rs index 1f1e868d47..4669b355f6 100644 --- a/base_layer/wallet/src/connectivity_service/watch.rs +++ b/base_layer/wallet/src/connectivity_service/watch.rs @@ -26,24 +26,19 @@ use tokio::sync::watch; #[derive(Clone)] pub struct Watch(Arc>, watch::Receiver); -impl Watch { +impl Watch { pub fn new(initial: T) -> Self { let (tx, rx) = watch::channel(initial); Self(Arc::new(tx), rx) } - #[allow(dead_code)] - pub async fn recv(&mut self) -> Option { - self.receiver_mut().recv().await - } - pub fn borrow(&self) -> watch::Ref<'_, T> { self.receiver().borrow() } pub fn broadcast(&self, item: T) { - // SAFETY: broadcast becomes infallible because the receiver is owned in Watch and so has the same lifetime - if self.sender().broadcast(item).is_err() { + // PANIC: broadcast becomes infallible because the receiver is owned in Watch and so has the same lifetime + if self.sender().send(item).is_err() { // Result::expect requires E: fmt::Debug and `watch::SendError` is not, this is equivalent panic!("watch internal receiver is dropped"); } @@ -53,10 +48,6 @@ impl Watch { &self.0 } - fn receiver_mut(&mut self) -> &mut watch::Receiver { - &mut self.1 - } - pub fn receiver(&self) -> &watch::Receiver { &self.1 } diff --git a/base_layer/wallet/src/contacts_service/service.rs b/base_layer/wallet/src/contacts_service/service.rs index f86f6b49cc..cfc8473202 100644 --- a/base_layer/wallet/src/contacts_service/service.rs +++ b/base_layer/wallet/src/contacts_service/service.rs @@ -76,8 +76,8 @@ where T: ContactsBackend + 'static info!(target: LOG_TARGET, "Contacts Service started"); loop { - futures::select! { - request_context = request_stream.select_next_some() => { + tokio::select! 
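The Watch changes above track the tokio 1.x watch API: the sender's method is send rather than broadcast, the receiver no longer has an async recv (callers use changed() and borrow()), and sending only fails once every receiver is gone, which the wrapper rules out by holding one. A sketch of that wrapper under those assumptions:

```rust
use std::sync::Arc;
use tokio::sync::watch;

#[derive(Clone)]
pub struct Watch<T>(Arc<watch::Sender<T>>, watch::Receiver<T>);

impl<T> Watch<T> {
    pub fn new(initial: T) -> Self {
        let (tx, rx) = watch::channel(initial);
        Self(Arc::new(tx), rx)
    }

    pub fn borrow(&self) -> watch::Ref<'_, T> {
        self.1.borrow()
    }

    pub fn broadcast(&self, item: T) {
        // Infallible in practice: this struct always owns a receiver, so the
        // sender can never observe the channel as closed.
        if self.0.send(item).is_err() {
            panic!("watch internal receiver is dropped");
        }
    }

    pub fn get_receiver(&self) -> watch::Receiver<T> {
        self.1.clone()
    }
}
```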
{ + Some(request_context) = request_stream.next() => { let (request, reply_tx) = request_context.split(); let response = self.handle_request(request).await.map_err(|e| { error!(target: LOG_TARGET, "Error handling request: {:?}", e); @@ -88,14 +88,10 @@ where T: ContactsBackend + 'static e }); }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Contacts service shutting down because it received the shutdown signal"); break; } - complete => { - info!(target: LOG_TARGET, "Contacts service shutting down"); - break; - } } } info!(target: LOG_TARGET, "Contacts Service ended"); diff --git a/base_layer/wallet/src/lib.rs b/base_layer/wallet/src/lib.rs index bc0b4c1a04..22bce8bfdb 100644 --- a/base_layer/wallet/src/lib.rs +++ b/base_layer/wallet/src/lib.rs @@ -25,8 +25,6 @@ pub mod wallet; extern crate diesel; #[macro_use] extern crate diesel_migrations; -#[macro_use] -extern crate lazy_static; mod config; pub mod schema; diff --git a/base_layer/wallet/src/output_manager_service/handle.rs b/base_layer/wallet/src/output_manager_service/handle.rs index 659fab4a42..183f46ab4e 100644 --- a/base_layer/wallet/src/output_manager_service/handle.rs +++ b/base_layer/wallet/src/output_manager_service/handle.rs @@ -31,7 +31,6 @@ use crate::{ types::ValidationRetryStrategy, }; use aes_gcm::Aes256Gcm; -use futures::{stream::Fuse, StreamExt}; use std::{collections::HashMap, fmt, sync::Arc, time::Duration}; use tari_comms::types::CommsPublicKey; use tari_core::transactions::{ @@ -191,8 +190,8 @@ impl OutputManagerHandle { } } - pub fn get_event_stream_fused(&self) -> Fuse { - self.event_stream_sender.subscribe().fuse() + pub fn get_event_stream(&self) -> OutputManagerEventReceiver { + self.event_stream_sender.subscribe() } pub async fn add_output(&mut self, output: UnblindedOutput) -> Result<(), OutputManagerError> { diff --git a/base_layer/wallet/src/output_manager_service/service.rs b/base_layer/wallet/src/output_manager_service/service.rs index c8946b9ff9..77687c8845 100644 --- a/base_layer/wallet/src/output_manager_service/service.rs +++ b/base_layer/wallet/src/output_manager_service/service.rs @@ -166,9 +166,9 @@ where TBackend: OutputManagerBackend + 'static info!(target: LOG_TARGET, "Output Manager Service started"); loop { - futures::select! { - request_context = request_stream.select_next_some() => { - trace!(target: LOG_TARGET, "Handling Service API Request"); + tokio::select! 
{ + Some(request_context) = request_stream.next() => { + trace!(target: LOG_TARGET, "Handling Service API Request"); let (request, reply_tx) = request_context.split(); let response = self.handle_request(request).await.map_err(|e| { warn!(target: LOG_TARGET, "Error handling request: {:?}", e); @@ -179,14 +179,10 @@ where TBackend: OutputManagerBackend + 'static e }); }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Output manager service shutting down because it received the shutdown signal"); break; } - complete => { - info!(target: LOG_TARGET, "Output manager service shutting down"); - break; - } } } info!(target: LOG_TARGET, "Output Manager Service ended"); diff --git a/base_layer/wallet/src/output_manager_service/tasks/txo_validation_task.rs b/base_layer/wallet/src/output_manager_service/tasks/txo_validation_task.rs index e3e022e4fb..de0e0e3c17 100644 --- a/base_layer/wallet/src/output_manager_service/tasks/txo_validation_task.rs +++ b/base_layer/wallet/src/output_manager_service/tasks/txo_validation_task.rs @@ -30,7 +30,7 @@ use crate::{ transaction_service::storage::models::TransactionStatus, types::ValidationRetryStrategy, }; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt; use log::*; use std::{cmp, collections::HashMap, convert::TryFrom, fmt, sync::Arc, time::Duration}; use tari_comms::{peer_manager::NodeId, types::CommsPublicKey, PeerConnection}; @@ -40,7 +40,7 @@ use tari_core::{ transactions::{transaction::TransactionOutput, types::Signature}, }; use tari_crypto::tari_utilities::{hash::Hashable, hex::Hex}; -use tokio::{sync::broadcast, time::delay_for}; +use tokio::{sync::broadcast, time::sleep}; const LOG_TARGET: &str = "wallet::output_manager_service::utxo_validation_task"; @@ -87,21 +87,17 @@ where TBackend: OutputManagerBackend + 'static /// The task that defines the execution of the protocol. pub async fn execute(mut self) -> Result { - let mut base_node_update_receiver = self - .base_node_update_receiver - .take() - .ok_or_else(|| { - OutputManagerProtocolError::new( - self.id, - OutputManagerError::ServiceError("A Base Node Update receiver was not provided".to_string()), - ) - })? - .fuse(); + let mut base_node_update_receiver = self.base_node_update_receiver.take().ok_or_else(|| { + OutputManagerProtocolError::new( + self.id, + OutputManagerError::ServiceError("A Base Node Update receiver was not provided".to_string()), + ) + })?; let mut shutdown = self.resources.shutdown_signal.clone(); let total_retries_str = match self.retry_strategy { - ValidationRetryStrategy::Limited(n) => format!("{}", n), + ValidationRetryStrategy::Limited(n) => n.to_string(), ValidationRetryStrategy::UntilSuccess => "∞".to_string(), }; @@ -180,14 +176,14 @@ where TBackend: OutputManagerBackend + 'static let base_node_node_id = NodeId::from_key(&self.base_node_public_key.clone()); let mut connection: Option = None; - let delay = delay_for(self.resources.config.peer_dial_retry_timeout); + let delay = sleep(self.resources.config.peer_dial_retry_timeout); debug!( target: LOG_TARGET, "Connecting to Base Node (Public Key: {})", self.base_node_public_key, ); - futures::select! { - dial_result = self.resources.connectivity_manager.dial_peer(base_node_node_id.clone()).fuse() => { + tokio::select! 
{ + dial_result = self.resources.connectivity_manager.dial_peer(base_node_node_id.clone()) => { match dial_result { Ok(base_node_connection) => { connection = Some(base_node_connection); @@ -197,7 +193,7 @@ where TBackend: OutputManagerBackend + 'static }, } }, - new_base_node = base_node_update_receiver.select_next_some() => { + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(_) => { info!( @@ -228,7 +224,7 @@ where TBackend: OutputManagerBackend + 'static } } } - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "TXO Validation Protocol (Id: {}) shutting down because it received the shutdown signal", self.id); return Err(OutputManagerProtocolError::new(self.id, OutputManagerError::Shutdown)); }, @@ -236,7 +232,7 @@ where TBackend: OutputManagerBackend + 'static let mut base_node_connection = match connection { None => { - futures::select! { + tokio::select! { _ = delay.fuse() => { let _ = self .resources @@ -253,7 +249,7 @@ where TBackend: OutputManagerBackend + 'static retries += 1; continue; }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "TXO Validation Protocol (Id: {}) shutting down because it received the shutdown signal", self.id); return Err(OutputManagerProtocolError::new(self.id, OutputManagerError::Shutdown)); }, @@ -294,9 +290,9 @@ where TBackend: OutputManagerBackend + 'static batch_num, batch_total ); - let delay = delay_for(self.retry_delay); - futures::select! { - new_base_node = base_node_update_receiver.select_next_some() => { + let delay = sleep(self.retry_delay); + tokio::select! { + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(_bn) => { info!(target: LOG_TARGET, "TXO Validation protocol aborted due to Base Node Public key change" ); @@ -323,7 +319,7 @@ where TBackend: OutputManagerBackend + 'static } } }, - result = self.send_query_batch(batch.clone(), &mut client).fuse() => { + result = self.send_query_batch(batch.clone(), &mut client) => { match result { Ok(synced) => { self.base_node_synced = synced; @@ -374,7 +370,7 @@ where TBackend: OutputManagerBackend + 'static }, } }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "TXO Validation Protocol (Id: {}) shutting down because it received the shutdown signal", self.id); return Err(OutputManagerProtocolError::new(self.id, OutputManagerError::Shutdown)); }, diff --git a/base_layer/wallet/src/storage/database.rs b/base_layer/wallet/src/storage/database.rs index 3d5cbf8838..d23d806298 100644 --- a/base_layer/wallet/src/storage/database.rs +++ b/base_layer/wallet/src/storage/database.rs @@ -373,7 +373,7 @@ mod test { #[test] fn test_database_crud() { - let mut runtime = Runtime::new().unwrap(); + let runtime = Runtime::new().unwrap(); let db_name = format!("{}.sqlite3", string(8).as_str()); let db_folder = tempdir().unwrap().path().to_str().unwrap().to_string(); diff --git a/base_layer/wallet/src/transaction_service/error.rs b/base_layer/wallet/src/transaction_service/error.rs index c197dd2024..05e9e4af2b 100644 --- a/base_layer/wallet/src/transaction_service/error.rs +++ b/base_layer/wallet/src/transaction_service/error.rs @@ -34,7 +34,7 @@ use tari_p2p::services::liveness::error::LivenessError; use tari_service_framework::reply_channel::TransportChannelError; use thiserror::Error; use time::OutOfRangeError; -use tokio::sync::broadcast::RecvError; +use tokio::sync::broadcast::error::RecvError; #[derive(Debug, Error)] pub enum TransactionServiceError { diff --git 
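The TXO validation loop above, like the transaction protocols that follow it, keeps the same retry skeleton after the migration: build a sleep for the retry window, race the dial against base-node/timeout updates (broadcast recv()) and shutdown, and if nothing connected, wait out the sleep before trying again. A condensed sketch of that skeleton, where dial() and the channels stand in for the connectivity manager and the update/shutdown signals:

```rust
use std::time::Duration;
use tokio::{
    sync::{broadcast, watch},
    time::sleep,
};

// Stand-in for connectivity_manager.dial_peer(..); always fails here.
async fn dial() -> Result<(), ()> {
    Err(())
}

async fn run(mut updates: broadcast::Receiver<String>, mut shutdown: watch::Receiver<bool>) {
    loop {
        let delay = sleep(Duration::from_secs(30));

        tokio::select! {
            result = dial() => {
                if result.is_ok() {
                    // connected: carry on with the protocol
                    return;
                }
            },
            update = updates.recv() => {
                if update.is_ok() {
                    // base node or timeout changed: restart the attempt now
                    continue;
                }
            },
            _ = shutdown.changed() => return,
        }

        // Dial failed: wait out the retry window (or shut down) and try again.
        tokio::select! {
            _ = delay => {},
            _ = shutdown.changed() => return,
        }
    }
}
```

tokio::select! pins the futures it is given, so a Sleep can be awaited in a branch directly; futures::select! required FusedFuture, which is what the remaining .fuse() calls were for.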
a/base_layer/wallet/src/transaction_service/handle.rs b/base_layer/wallet/src/transaction_service/handle.rs index f34a5f667f..7a6ab48649 100644 --- a/base_layer/wallet/src/transaction_service/handle.rs +++ b/base_layer/wallet/src/transaction_service/handle.rs @@ -28,7 +28,6 @@ use crate::{ }, }; use aes_gcm::Aes256Gcm; -use futures::{stream::Fuse, StreamExt}; use std::{collections::HashMap, fmt, sync::Arc}; use tari_comms::types::CommsPublicKey; use tari_core::transactions::{tari_amount::MicroTari, transaction::Transaction}; @@ -187,8 +186,8 @@ impl TransactionServiceHandle { } } - pub fn get_event_stream_fused(&self) -> Fuse { - self.event_stream_sender.subscribe().fuse() + pub fn get_event_stream(&self) -> TransactionEventReceiver { + self.event_stream_sender.subscribe() } pub async fn send_transaction( diff --git a/base_layer/wallet/src/transaction_service/protocols/transaction_broadcast_protocol.rs b/base_layer/wallet/src/transaction_service/protocols/transaction_broadcast_protocol.rs index 05191c8f8d..e2b0d7b871 100644 --- a/base_layer/wallet/src/transaction_service/protocols/transaction_broadcast_protocol.rs +++ b/base_layer/wallet/src/transaction_service/protocols/transaction_broadcast_protocol.rs @@ -32,7 +32,7 @@ use crate::{ }, }, }; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt; use log::*; use std::{convert::TryFrom, sync::Arc, time::Duration}; use tari_comms::{peer_manager::NodeId, types::CommsPublicKey, PeerConnection}; @@ -44,7 +44,7 @@ use tari_core::{ transactions::{transaction::Transaction, types::Signature}, }; use tari_crypto::tari_utilities::hex::Hex; -use tokio::{sync::broadcast, time::delay_for}; +use tokio::{sync::broadcast, time::sleep}; const LOG_TARGET: &str = "wallet::transaction_service::protocols::broadcast_protocol"; @@ -86,21 +86,13 @@ where TBackend: TransactionBackend + 'static /// The task that defines the execution of the protocol. pub async fn execute(mut self) -> Result { - let mut timeout_update_receiver = self - .timeout_update_receiver - .take() - .ok_or_else(|| { - TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) - })? - .fuse(); + let mut timeout_update_receiver = self.timeout_update_receiver.take().ok_or_else(|| { + TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) + })?; - let mut base_node_update_receiver = self - .base_node_update_receiver - .take() - .ok_or_else(|| { - TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) - })? - .fuse(); + let mut base_node_update_receiver = self.base_node_update_receiver.take().ok_or_else(|| { + TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) + })?; let mut shutdown = self.resources.shutdown_signal.clone(); // Main protocol loop @@ -108,14 +100,14 @@ where TBackend: TransactionBackend + 'static let base_node_node_id = NodeId::from_key(&self.base_node_public_key); let mut connection: Option = None; - let delay = delay_for(self.timeout); + let delay = sleep(self.timeout); debug!( target: LOG_TARGET, "Connecting to Base Node (Public Key: {})", self.base_node_public_key, ); - futures::select! { - dial_result = self.resources.connectivity_manager.dial_peer(base_node_node_id.clone()).fuse() => { + tokio::select! 
{ + dial_result = self.resources.connectivity_manager.dial_peer(base_node_node_id.clone()) => { match dial_result { Ok(base_node_connection) => { connection = Some(base_node_connection); @@ -139,7 +131,7 @@ where TBackend: TransactionBackend + 'static }, } }, - updated_timeout = timeout_update_receiver.select_next_some() => { + updated_timeout = timeout_update_receiver.recv() => { match updated_timeout { Ok(to) => { self.timeout = to; @@ -158,7 +150,7 @@ where TBackend: TransactionBackend + 'static } } }, - new_base_node = base_node_update_receiver.select_next_some() => { + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(bn) => { self.base_node_public_key = bn; @@ -179,7 +171,7 @@ where TBackend: TransactionBackend + 'static } } } - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Broadcast Protocol (TxId: {}) shutting down because it received the shutdown signal", self.tx_id); return Err(TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::Shutdown)) }, @@ -187,11 +179,11 @@ where TBackend: TransactionBackend + 'static let mut base_node_connection = match connection { None => { - futures::select! { + tokio::select! { _ = delay.fuse() => { continue; }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Broadcast Protocol (TxId: {}) shutting down because it received the shutdown signal", self.tx_id); return Err(TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::Shutdown)) }, @@ -243,10 +235,10 @@ where TBackend: TransactionBackend + 'static }, }; - let delay = delay_for(self.timeout); + let delay = sleep(self.timeout); loop { - futures::select! { - new_base_node = base_node_update_receiver.select_next_some() => { + tokio::select! 
{ + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(bn) => { self.base_node_public_key = bn; @@ -315,7 +307,7 @@ where TBackend: TransactionBackend + 'static delay.await; break; }, - updated_timeout = timeout_update_receiver.select_next_some() => { + updated_timeout = timeout_update_receiver.recv() => { if let Ok(to) = updated_timeout { self.timeout = to; info!( @@ -332,7 +324,7 @@ where TBackend: TransactionBackend + 'static ); } }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Broadcast Protocol (TxId: {}) shutting down because it received the shutdown signal", self.tx_id); return Err(TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::Shutdown)) }, diff --git a/base_layer/wallet/src/transaction_service/protocols/transaction_coinbase_monitoring_protocol.rs b/base_layer/wallet/src/transaction_service/protocols/transaction_coinbase_monitoring_protocol.rs index 65fb58f601..d072eb77a8 100644 --- a/base_layer/wallet/src/transaction_service/protocols/transaction_coinbase_monitoring_protocol.rs +++ b/base_layer/wallet/src/transaction_service/protocols/transaction_coinbase_monitoring_protocol.rs @@ -29,7 +29,7 @@ use crate::{ storage::{database::TransactionBackend, models::CompletedTransaction}, }, }; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt; use log::*; use std::{convert::TryFrom, sync::Arc, time::Duration}; use tari_comms::{peer_manager::NodeId, types::CommsPublicKey, PeerConnection}; @@ -41,7 +41,7 @@ use tari_core::{ transactions::types::Signature, }; use tari_crypto::tari_utilities::{hex::Hex, Hashable}; -use tokio::{sync::broadcast, time::delay_for}; +use tokio::{sync::broadcast, time::sleep}; const LOG_TARGET: &str = "wallet::transaction_service::protocols::coinbase_monitoring"; @@ -86,21 +86,13 @@ where TBackend: TransactionBackend + 'static /// The task that defines the execution of the protocol. pub async fn execute(mut self) -> Result { - let mut base_node_update_receiver = self - .base_node_update_receiver - .take() - .ok_or_else(|| { - TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) - })? - .fuse(); + let mut base_node_update_receiver = self.base_node_update_receiver.take().ok_or_else(|| { + TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) + })?; - let mut timeout_update_receiver = self - .timeout_update_receiver - .take() - .ok_or_else(|| { - TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) - })? - .fuse(); + let mut timeout_update_receiver = self.timeout_update_receiver.take().ok_or_else(|| { + TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) + })?; trace!( target: LOG_TARGET, @@ -173,7 +165,7 @@ where TBackend: TransactionBackend + 'static self.base_node_public_key, self.tx_id, ); - futures::select! { + tokio::select! 
{ dial_result = self.resources.connectivity_manager.dial_peer(base_node_node_id.clone()).fuse() => { match dial_result { Ok(base_node_connection) => { @@ -203,7 +195,7 @@ where TBackend: TransactionBackend + 'static }, } }, - updated_timeout = timeout_update_receiver.select_next_some() => { + updated_timeout = timeout_update_receiver.recv() => { match updated_timeout { Ok(to) => { self.timeout = to; @@ -225,7 +217,7 @@ where TBackend: TransactionBackend + 'static } } }, - new_base_node = base_node_update_receiver.select_next_some() => { + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(bn) => { self.base_node_public_key = bn; @@ -248,7 +240,7 @@ where TBackend: TransactionBackend + 'static } } } - _ = shutdown => { + _ = shutdown.wait() => { info!( target: LOG_TARGET, "Coinbase Monitoring protocol (TxId: {}) shutting down because it received the shutdown \ @@ -259,14 +251,14 @@ where TBackend: TransactionBackend + 'static }, } - let delay = delay_for(self.timeout); + let delay = sleep(self.timeout); let mut base_node_connection = match connection { None => { - futures::select! { + tokio::select! { _ = delay.fuse() => { continue; }, - _ = shutdown => { + _ = shutdown.wait() => { info!( target: LOG_TARGET, "Coinbase Monitoring Protocol (TxId: {}) shutting down because it received the \ @@ -314,10 +306,10 @@ where TBackend: TransactionBackend + 'static TransactionServiceError::InvalidCompletedTransaction, )); } - let delay = delay_for(self.timeout).fuse(); + let delay = sleep(self.timeout).fuse(); loop { - futures::select! { - new_base_node = base_node_update_receiver.select_next_some() => { + tokio::select! { + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(bn) => { self.base_node_public_key = bn; @@ -392,7 +384,7 @@ where TBackend: TransactionBackend + 'static delay.await; break; }, - updated_timeout = timeout_update_receiver.select_next_some() => { + updated_timeout = timeout_update_receiver.recv() => { if let Ok(to) = updated_timeout { self.timeout = to; info!( @@ -411,7 +403,7 @@ where TBackend: TransactionBackend + 'static ); } }, - _ = shutdown => { + _ = shutdown.wait() => { info!( target: LOG_TARGET, "Coinbase Monitoring Protocol (TxId: {}) shutting down because it received the shutdown \ diff --git a/base_layer/wallet/src/transaction_service/protocols/transaction_receive_protocol.rs b/base_layer/wallet/src/transaction_service/protocols/transaction_receive_protocol.rs index 744bb6f1fc..058d2c4f74 100644 --- a/base_layer/wallet/src/transaction_service/protocols/transaction_receive_protocol.rs +++ b/base_layer/wallet/src/transaction_service/protocols/transaction_receive_protocol.rs @@ -34,21 +34,18 @@ use crate::{ }, }; use chrono::Utc; -use futures::{ - channel::{mpsc, oneshot}, - future::FutureExt, - StreamExt, -}; +use futures::future::FutureExt; use log::*; use std::sync::Arc; use tari_comms::types::CommsPublicKey; +use tokio::sync::{mpsc, oneshot}; use tari_core::transactions::{ transaction::Transaction, transaction_protocol::{recipient::RecipientState, sender::TransactionSenderMessage}, }; use tari_crypto::tari_utilities::Hashable; -use tokio::time::delay_for; +use tokio::time::sleep; const LOG_TARGET: &str = "wallet::transaction_service::protocols::receive_protocol"; const LOG_TARGET_STRESS: &str = "stress_test::receive_protocol"; @@ -263,7 +260,8 @@ where TBackend: TransactionBackend + 'static }, Some(t) => t, }; - let mut timeout_delay = delay_for(timeout_duration).fuse(); + let timeout_delay = 
sleep(timeout_duration).fuse(); + tokio::pin!(timeout_delay); // check to see if a resend is due let resend = match inbound_tx.last_send_timestamp { @@ -310,9 +308,9 @@ where TBackend: TransactionBackend + 'static let mut incoming_finalized_transaction = None; loop { loop { - let mut resend_timeout = delay_for(self.resources.config.transaction_resend_period).fuse(); - futures::select! { - (spk, tx_id, tx) = receiver.select_next_some() => { + let resend_timeout = sleep(self.resources.config.transaction_resend_period).fuse(); + tokio::select! { + Some((spk, tx_id, tx)) = receiver.recv() => { incoming_finalized_transaction = Some(tx); if inbound_tx.source_public_key != spk { warn!( @@ -325,16 +323,14 @@ where TBackend: TransactionBackend + 'static break; } }, - result = cancellation_receiver => { - if result.is_ok() { - info!(target: LOG_TARGET, "Cancelling Transaction Receive Protocol for TxId: {}", self.id); - return Err(TransactionServiceProtocolError::new( - self.id, - TransactionServiceError::TransactionCancelled, - )); - } + Ok(_) = &mut cancellation_receiver => { + info!(target: LOG_TARGET, "Cancelling Transaction Receive Protocol for TxId: {}", self.id); + return Err(TransactionServiceProtocolError::new( + self.id, + TransactionServiceError::TransactionCancelled, + )); }, - () = resend_timeout => { + _ = resend_timeout => { match send_transaction_reply( inbound_tx.clone(), self.resources.outbound_message_service.clone(), @@ -353,10 +349,10 @@ where TBackend: TransactionBackend + 'static ), } }, - () = timeout_delay => { + _ = &mut timeout_delay => { return self.timeout_transaction().await; } - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Receive Protocol (id: {}) shutting down because it received the shutdown signal", self.id); return Err(TransactionServiceProtocolError::new(self.id, TransactionServiceError::Shutdown)) } diff --git a/base_layer/wallet/src/transaction_service/protocols/transaction_send_protocol.rs b/base_layer/wallet/src/transaction_service/protocols/transaction_send_protocol.rs index cf30a7f928..e88c9013e6 100644 --- a/base_layer/wallet/src/transaction_service/protocols/transaction_send_protocol.rs +++ b/base_layer/wallet/src/transaction_service/protocols/transaction_send_protocol.rs @@ -20,12 +20,6 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
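// The receive-protocol hunks above replace futures::select!/delay_for with tokio::select!/sleep.
// In tokio 1.x, time::sleep returns a Sleep future that is not Unpin, so it must be pinned with
// tokio::pin! before it can be polled by reference across loop iterations, and channel receivers
// are awaited with .recv() instead of select_next_some(). A minimal, self-contained sketch of
// that shape, assuming illustrative durations and messages (none of the names below come from
// the wallet code):
use std::time::Duration;
use tokio::{sync::mpsc, time::sleep};

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<&str>(8);
    tokio::spawn(async move {
        let _ = tx.send("hello").await;
    });

    let timeout = sleep(Duration::from_millis(250));
    tokio::pin!(timeout); // needed so `&mut timeout` can be selected on more than once

    loop {
        tokio::select! {
            Some(msg) = rx.recv() => println!("received: {}", msg),
            _ = &mut timeout => {
                println!("timed out");
                break;
            },
        }
    }
}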
-use std::sync::Arc; - -use chrono::Utc; -use futures::{channel::mpsc::Receiver, FutureExt, StreamExt}; -use log::*; - use crate::transaction_service::{ config::TransactionRoutingMechanism, error::{TransactionServiceError, TransactionServiceProtocolError}, @@ -41,7 +35,10 @@ use crate::transaction_service::{ wait_on_dial::wait_on_dial, }, }; -use futures::channel::oneshot; +use chrono::Utc; +use futures::FutureExt; +use log::*; +use std::sync::Arc; use tari_comms::{peer_manager::NodeId, types::CommsPublicKey}; use tari_comms_dht::{ domain_message::OutboundDomainMessage, @@ -55,7 +52,10 @@ use tari_core::transactions::{ }; use tari_crypto::script; use tari_p2p::tari_message::TariMessageType; -use tokio::time::delay_for; +use tokio::{ + sync::{mpsc::Receiver, oneshot}, + time::sleep, +}; const LOG_TARGET: &str = "wallet::transaction_service::protocols::send_protocol"; const LOG_TARGET_STRESS: &str = "stress_test::send_protocol"; @@ -344,7 +344,8 @@ where TBackend: TransactionBackend + 'static }, Some(t) => t, }; - let mut timeout_delay = delay_for(timeout_duration).fuse(); + let timeout_delay = sleep(timeout_duration).fuse(); + tokio::pin!(timeout_delay); // check to see if a resend is due let resend = match outbound_tx.last_send_timestamp { @@ -390,9 +391,9 @@ where TBackend: TransactionBackend + 'static #[allow(unused_assignments)] let mut reply = None; loop { - let mut resend_timeout = delay_for(self.resources.config.transaction_resend_period).fuse(); - futures::select! { - (spk, rr) = receiver.select_next_some() => { + let resend_timeout = sleep(self.resources.config.transaction_resend_period).fuse(); + tokio::select! { + Some((spk, rr)) = receiver.recv() => { let rr_tx_id = rr.tx_id; reply = Some(rr); @@ -407,7 +408,7 @@ where TBackend: TransactionBackend + 'static break; } }, - result = cancellation_receiver => { + result = &mut cancellation_receiver => { if result.is_ok() { info!(target: LOG_TARGET, "Cancelling Transaction Send Protocol (TxId: {})", self.id); let _ = send_transaction_cancelled_message(self.id,self.dest_pubkey.clone(), self.resources.outbound_message_service.clone(), ).await.map_err(|e| { @@ -441,10 +442,10 @@ where TBackend: TransactionBackend + 'static .map_err(|e| TransactionServiceProtocolError::new(self.id, TransactionServiceError::from(e)))?; } }, - () = timeout_delay => { + () = &mut timeout_delay => { return self.timeout_transaction().await; } - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Send Protocol (id: {}) shutting down because it received the shutdown signal", self.id); return Err(TransactionServiceProtocolError::new(self.id, TransactionServiceError::Shutdown)) } diff --git a/base_layer/wallet/src/transaction_service/protocols/transaction_validation_protocol.rs b/base_layer/wallet/src/transaction_service/protocols/transaction_validation_protocol.rs index d0b2f7f6ac..dcf072c272 100644 --- a/base_layer/wallet/src/transaction_service/protocols/transaction_validation_protocol.rs +++ b/base_layer/wallet/src/transaction_service/protocols/transaction_validation_protocol.rs @@ -32,7 +32,7 @@ use crate::{ }, types::ValidationRetryStrategy, }; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt; use log::*; use std::{cmp, convert::TryFrom, sync::Arc, time::Duration}; use tari_comms::{peer_manager::NodeId, types::CommsPublicKey, PeerConnection}; @@ -43,7 +43,7 @@ use tari_core::{ }, proto::{base_node::Signatures as SignaturesProto, types::Signature as SignatureProto}, }; -use tokio::{sync::broadcast, 
time::delay_for}; +use tokio::{sync::broadcast, time::sleep}; const LOG_TARGET: &str = "wallet::transaction_service::protocols::validation_protocol"; @@ -94,14 +94,12 @@ where TBackend: TransactionBackend + 'static let mut timeout_update_receiver = self .timeout_update_receiver .take() - .ok_or_else(|| TransactionServiceProtocolError::new(self.id, TransactionServiceError::InvalidStateError))? - .fuse(); + .ok_or_else(|| TransactionServiceProtocolError::new(self.id, TransactionServiceError::InvalidStateError))?; let mut base_node_update_receiver = self .base_node_update_receiver .take() - .ok_or_else(|| TransactionServiceProtocolError::new(self.id, TransactionServiceError::InvalidStateError))? - .fuse(); + .ok_or_else(|| TransactionServiceProtocolError::new(self.id, TransactionServiceError::InvalidStateError))?; let mut shutdown = self.resources.shutdown_signal.clone(); @@ -158,13 +156,13 @@ where TBackend: TransactionBackend + 'static let base_node_node_id = NodeId::from_key(&self.base_node_public_key); let mut connection: Option = None; - let delay = delay_for(self.timeout); + let delay = sleep(self.timeout); debug!( target: LOG_TARGET, "Connecting to Base Node (Public Key: {})", self.base_node_public_key, ); - futures::select! { + tokio::select! { dial_result = self.resources.connectivity_manager.dial_peer(base_node_node_id.clone()).fuse() => { match dial_result { Ok(base_node_connection) => { @@ -175,7 +173,7 @@ where TBackend: TransactionBackend + 'static }, } }, - new_base_node = base_node_update_receiver.select_next_some() => { + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(_) => { @@ -204,7 +202,7 @@ where TBackend: TransactionBackend + 'static } } } - updated_timeout = timeout_update_receiver.select_next_some() => { + updated_timeout = timeout_update_receiver.recv() => { match updated_timeout { Ok(to) => { self.timeout = to; @@ -223,7 +221,7 @@ where TBackend: TransactionBackend + 'static } } }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Validation Protocol shutting down because it received the shutdown signal"); return Err(TransactionServiceProtocolError::new(self.id, TransactionServiceError::Shutdown)) }, @@ -231,7 +229,7 @@ where TBackend: TransactionBackend + 'static let mut base_node_connection = match connection { None => { - futures::select! { + tokio::select! { _ = delay.fuse() => { let _ = self .resources @@ -248,7 +246,7 @@ where TBackend: TransactionBackend + 'static retries += 1; continue; }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Validation Protocol shutting down because it received the shutdown signal"); return Err(TransactionServiceProtocolError::new(self.id, TransactionServiceError::Shutdown)) }, @@ -282,9 +280,9 @@ where TBackend: TransactionBackend + 'static } else { break 'main; }; - let delay = delay_for(self.timeout); - futures::select! { - new_base_node = base_node_update_receiver.select_next_some() => { + let delay = sleep(self.timeout); + tokio::select! 
{ + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(_) => { info!(target: LOG_TARGET, "Aborting Transaction Validation Protocol as new Base node is set"); @@ -372,7 +370,7 @@ where TBackend: TransactionBackend + 'static }, } }, - updated_timeout = timeout_update_receiver.select_next_some() => { + updated_timeout = timeout_update_receiver.recv() => { match updated_timeout { Ok(to) => { self.timeout = to; @@ -391,7 +389,7 @@ where TBackend: TransactionBackend + 'static } } }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Validation Protocol shutting down because it received the shutdown signal"); return Err(TransactionServiceProtocolError::new(self.id, TransactionServiceError::Shutdown)) }, diff --git a/base_layer/wallet/src/transaction_service/service.rs b/base_layer/wallet/src/transaction_service/service.rs index c30fb7f412..c768a085ed 100644 --- a/base_layer/wallet/src/transaction_service/service.rs +++ b/base_layer/wallet/src/transaction_service/service.rs @@ -47,14 +47,7 @@ use crate::{ }; use chrono::{NaiveDateTime, Utc}; use digest::Digest; -use futures::{ - channel::{mpsc, mpsc::Sender, oneshot}, - pin_mut, - stream::FuturesUnordered, - SinkExt, - Stream, - StreamExt, -}; +use futures::{pin_mut, stream::FuturesUnordered, Stream, StreamExt}; use log::*; use rand::{rngs::OsRng, RngCore}; use std::{ @@ -85,7 +78,10 @@ use tari_crypto::{keys::DiffieHellmanSharedSecret, script, tari_utilities::ByteA use tari_p2p::domain_message::DomainMessage; use tari_service_framework::{reply_channel, reply_channel::Receiver}; use tari_shutdown::ShutdownSignal; -use tokio::{sync::broadcast, task::JoinHandle}; +use tokio::{ + sync::{broadcast, mpsc, mpsc::Sender, oneshot}, + task::JoinHandle, +}; const LOG_TARGET: &str = "wallet::transaction_service::service"; @@ -276,9 +272,9 @@ where info!(target: LOG_TARGET, "Transaction Service started"); loop { - futures::select! { + tokio::select! 
{ //Incoming request - request_context = request_stream.select_next_some() => { + Some(request_context) = request_stream.next() => { // TODO: Remove time measurements; this is to aid in system testing only let start = Instant::now(); let (request, reply_tx) = request_context.split(); @@ -303,7 +299,7 @@ where ); }, // Incoming Transaction messages from the Comms layer - msg = transaction_stream.select_next_some() => { + Some(msg) = transaction_stream.next() => { // TODO: Remove time measurements; this is to aid in system testing only let start = Instant::now(); let (origin_public_key, inner_msg) = msg.clone().into_origin_and_inner(); @@ -333,7 +329,7 @@ where ); }, // Incoming Transaction Reply messages from the Comms layer - msg = transaction_reply_stream.select_next_some() => { + Some(msg) = transaction_reply_stream.next() => { // TODO: Remove time measurements; this is to aid in system testing only let start = Instant::now(); let (origin_public_key, inner_msg) = msg.clone().into_origin_and_inner(); @@ -364,7 +360,7 @@ where ); }, // Incoming Finalized Transaction messages from the Comms layer - msg = transaction_finalized_stream.select_next_some() => { + Some(msg) = transaction_finalized_stream.next() => { // TODO: Remove time measurements; this is to aid in system testing only let start = Instant::now(); let (origin_public_key, inner_msg) = msg.clone().into_origin_and_inner(); @@ -402,7 +398,7 @@ where ); }, // Incoming messages from the Comms layer - msg = base_node_response_stream.select_next_some() => { + Some(msg) = base_node_response_stream.next() => { // TODO: Remove time measurements; this is to aid in system testing only let start = Instant::now(); let (origin_public_key, inner_msg) = msg.clone().into_origin_and_inner(); @@ -421,7 +417,7 @@ where ); } // Incoming messages from the Comms layer - msg = transaction_cancelled_stream.select_next_some() => { + Some(msg) = transaction_cancelled_stream.next() => { // TODO: Remove time measurements; this is to aid in system testing only let start = Instant::now(); let (origin_public_key, inner_msg) = msg.clone().into_origin_and_inner(); @@ -436,7 +432,7 @@ where finish.duration_since(start).as_millis(), ); } - join_result = send_transaction_protocol_handles.select_next_some() => { + Some(join_result) = send_transaction_protocol_handles.next() => { trace!(target: LOG_TARGET, "Send Protocol for Transaction has ended with result {:?}", join_result); match join_result { Ok(join_result_inner) => self.complete_send_transaction_protocol( @@ -446,7 +442,7 @@ where Err(e) => error!(target: LOG_TARGET, "Error resolving Send Transaction Protocol: {:?}", e), }; } - join_result = receive_transaction_protocol_handles.select_next_some() => { + Some(join_result) = receive_transaction_protocol_handles.next() => { trace!(target: LOG_TARGET, "Receive Transaction Protocol has ended with result {:?}", join_result); match join_result { Ok(join_result_inner) => self.complete_receive_transaction_protocol( @@ -456,14 +452,14 @@ where Err(e) => error!(target: LOG_TARGET, "Error resolving Send Transaction Protocol: {:?}", e), }; } - join_result = transaction_broadcast_protocol_handles.select_next_some() => { + Some(join_result) = transaction_broadcast_protocol_handles.next() => { trace!(target: LOG_TARGET, "Transaction Broadcast protocol has ended with result {:?}", join_result); match join_result { Ok(join_result_inner) => self.complete_transaction_broadcast_protocol(join_result_inner).await, Err(e) => error!(target: LOG_TARGET, "Error resolving Broadcast 
Protocol: {:?}", e), }; } - join_result = coinbase_transaction_monitoring_protocol_handles.select_next_some() => { + Some(join_result) = coinbase_transaction_monitoring_protocol_handles.next() => { trace!(target: LOG_TARGET, "Coinbase transaction monitoring protocol has ended with result {:?}", join_result); match join_result { @@ -471,21 +467,17 @@ where Err(e) => error!(target: LOG_TARGET, "Error resolving Coinbase Monitoring protocol: {:?}", e), }; } - join_result = transaction_validation_protocol_handles.select_next_some() => { + Some(join_result) = transaction_validation_protocol_handles.next() => { trace!(target: LOG_TARGET, "Transaction Validation protocol has ended with result {:?}", join_result); match join_result { Ok(join_result_inner) => self.complete_transaction_validation_protocol(join_result_inner).await, Err(e) => error!(target: LOG_TARGET, "Error resolving Transaction Validation protocol: {:?}", e), }; } - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction service shutting down because it received the shutdown signal"); break; } - complete => { - info!(target: LOG_TARGET, "Transaction service shutting down"); - break; - } } } info!(target: LOG_TARGET, "Transaction service shut down"); diff --git a/base_layer/wallet/src/transaction_service/tasks/start_transaction_validation_and_broadcast_protocols.rs b/base_layer/wallet/src/transaction_service/tasks/start_transaction_validation_and_broadcast_protocols.rs index 61b3ee7d75..522bacdbb9 100644 --- a/base_layer/wallet/src/transaction_service/tasks/start_transaction_validation_and_broadcast_protocols.rs +++ b/base_layer/wallet/src/transaction_service/tasks/start_transaction_validation_and_broadcast_protocols.rs @@ -27,8 +27,8 @@ use crate::{ }, types::ValidationRetryStrategy, }; -use futures::StreamExt; use log::*; +use tokio::sync::broadcast; const LOG_TARGET: &str = "wallet::transaction_service::tasks::start_tx_validation_and_broadcast"; @@ -36,16 +36,16 @@ pub async fn start_transaction_validation_and_broadcast_protocols( mut handle: TransactionServiceHandle, retry_strategy: ValidationRetryStrategy, ) -> Result<(), TransactionServiceError> { - let mut event_stream = handle.get_event_stream_fused(); + let mut event_stream = handle.get_event_stream(); let our_id = handle.validate_transactions(retry_strategy).await?; // Now that its started we will spawn an task to monitor the event bus and when its successful we will start the // Broadcast protocols tokio::spawn(async move { - while let Some(event_item) = event_stream.next().await { - if let Ok(event) = event_item { - match (*event).clone() { + loop { + match event_stream.recv().await { + Ok(event) => match &*event { TransactionEvent::TransactionValidationSuccess(_id) => { info!( target: LOG_TARGET, @@ -59,19 +59,28 @@ pub async fn start_transaction_validation_and_broadcast_protocols( } }, TransactionEvent::TransactionValidationFailure(id) => { - if our_id == id { + if our_id == *id { error!(target: LOG_TARGET, "Transaction Validation failed!"); break; } }, _ => (), - } - } else { - warn!( - target: LOG_TARGET, - "Error reading from Transaction Service Event Stream" - ); - break; + }, + Err(e @ broadcast::error::RecvError::Lagged(_)) => { + warn!( + target: LOG_TARGET, + "start_transaction_validation_and_broadcast_protocols: {}", e + ); + continue; + }, + Err(broadcast::error::RecvError::Closed) => { + debug!( + target: LOG_TARGET, + "start_transaction_validation_and_broadcast_protocols is exiting because the event stream \ + closed", + ); + 
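// The match above is the new tokio 1.x shape for reading a broadcast event stream: the error
// type now lives at tokio::sync::broadcast::error::RecvError, Lagged(n) only means n events were
// dropped for a slow receiver (safe to continue), and Closed means every sender has been dropped.
// A standalone sketch of the same consumption pattern, using an illustrative u32 payload rather
// than the wallet's TransactionEvent:
use tokio::sync::broadcast;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = broadcast::channel::<u32>(4);
    tx.send(1).unwrap();
    drop(tx); // dropping the last sender closes the channel and ends the loop

    loop {
        match rx.recv().await {
            Ok(event) => println!("event: {}", event),
            Err(broadcast::error::RecvError::Lagged(skipped)) => {
                eprintln!("receiver lagged, {} events skipped", skipped);
                continue;
            },
            Err(broadcast::error::RecvError::Closed) => break,
        }
    }
}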
break; + }, } } }); diff --git a/base_layer/wallet/src/util/mod.rs b/base_layer/wallet/src/util/mod.rs index 9664a0e376..7217ac5056 100644 --- a/base_layer/wallet/src/util/mod.rs +++ b/base_layer/wallet/src/util/mod.rs @@ -20,6 +20,4 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -pub mod emoji; pub mod encryption; -pub mod luhn; diff --git a/base_layer/wallet/src/utxo_scanner_service/utxo_scanning.rs b/base_layer/wallet/src/utxo_scanner_service/utxo_scanning.rs index 6aaa38363f..341d853ad5 100644 --- a/base_layer/wallet/src/utxo_scanner_service/utxo_scanning.rs +++ b/base_layer/wallet/src/utxo_scanner_service/utxo_scanning.rs @@ -69,7 +69,7 @@ use tari_core::{ }; use tari_service_framework::{reply_channel, reply_channel::SenderService}; use tari_shutdown::ShutdownSignal; -use tokio::{sync::broadcast, task, time}; +use tokio::{sync::broadcast, task, time, time::MissedTickBehavior}; pub const LOG_TARGET: &str = "wallet::utxo_scanning"; @@ -715,35 +715,23 @@ where TBackend: WalletBackend + 'static let mut shutdown = self.shutdown_signal.clone(); let start_at = Instant::now() + Duration::from_secs(1); - let mut work_interval = time::interval_at(start_at.into(), self.scan_for_utxo_interval).fuse(); - let mut previous = Instant::now(); + let mut work_interval = time::interval_at(start_at.into(), self.scan_for_utxo_interval); + work_interval.set_missed_tick_behavior(MissedTickBehavior::Delay); loop { - futures::select! { - _ = work_interval.select_next_some() => { - // This bit of code prevents bottled up tokio interval events to be fired successively for the edge - // case where a computer wakes up from sleep. - if start_at.elapsed() > self.scan_for_utxo_interval && - previous.elapsed() < self.scan_for_utxo_interval.mul_f32(0.9) - { - debug!( - target: LOG_TARGET, - "UTXO scanning work interval event fired too quickly, not running the task" - ); - } else { - let running_flag = self.is_running.clone(); - if !running_flag.load(Ordering::SeqCst) { - let task = self.create_task(); - debug!(target: LOG_TARGET, "UTXO scanning service starting scan for utxos"); - task::spawn(async move { - if let Err(err) = task.run().await { - error!(target: LOG_TARGET, "Error scanning UTXOs: {}", err); - } - //we make sure the flag is set to false here - running_flag.store(false, Ordering::Relaxed); - }); - } + tokio::select! 
{ + _ = work_interval.tick() => { + let running_flag = self.is_running.clone(); + if !running_flag.load(Ordering::SeqCst) { + let task = self.create_task(); + debug!(target: LOG_TARGET, "UTXO scanning service starting scan for utxos"); + task::spawn(async move { + if let Err(err) = task.run().await { + error!(target: LOG_TARGET, "Error scanning UTXOs: {}", err); + } + //we make sure the flag is set to false here + running_flag.store(false, Ordering::Relaxed); + }); } - previous = Instant::now(); }, request_context = request_stream.select_next_some() => { trace!(target: LOG_TARGET, "Handling Service API Request"); @@ -757,7 +745,7 @@ where TBackend: WalletBackend + 'static e }); }, - _ = shutdown => { + _ = shutdown.wait() => { // this will stop the task if its running, and let that thread exit gracefully self.is_running.store(false, Ordering::Relaxed); info!(target: LOG_TARGET, "UTXO scanning service shutting down because it received the shutdown signal"); diff --git a/base_layer/wallet/src/wallet.rs b/base_layer/wallet/src/wallet.rs index 1f91f3d625..a02ea1129d 100644 --- a/base_layer/wallet/src/wallet.rs +++ b/base_layer/wallet/src/wallet.rs @@ -77,7 +77,6 @@ use tari_key_manager::key_manager::KeyManager; use tari_p2p::{comms_connector::pubsub_connector, initialization, initialization::P2pInitializer}; use tari_service_framework::StackBuilder; use tari_shutdown::ShutdownSignal; -use tokio::runtime; const LOG_TARGET: &str = "wallet"; @@ -139,8 +138,7 @@ where let bn_service_db = wallet_database.clone(); let factories = config.clone().factories; - let (publisher, subscription_factory) = - pubsub_connector(runtime::Handle::current(), config.buffer_size, config.rate_limit); + let (publisher, subscription_factory) = pubsub_connector(config.buffer_size, config.rate_limit); let peer_message_subscription_factory = Arc::new(subscription_factory); let transport_type = config.comms_config.transport_type.clone(); diff --git a/base_layer/wallet/tests/output_manager_service/service.rs b/base_layer/wallet/tests/output_manager_service/service.rs index c6c23da53e..66cf5e734f 100644 --- a/base_layer/wallet/tests/output_manager_service/service.rs +++ b/base_layer/wallet/tests/output_manager_service/service.rs @@ -25,12 +25,12 @@ use crate::support::{ rpc::{BaseNodeWalletRpcMockService, BaseNodeWalletRpcMockState}, utils::{make_input, make_input_with_features, TestParams}, }; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt; use rand::{rngs::OsRng, RngCore}; -use std::{sync::Arc, thread, time::Duration}; +use std::{sync::Arc, time::Duration}; use tari_comms::{ peer_manager::{NodeIdentity, PeerFeatures}, - protocol::rpc::{mock::MockRpcServer, NamedProtocolService, RpcStatus}, + protocol::rpc::{mock::MockRpcServer, NamedProtocolService, RpcClientConfig, RpcStatus}, test_utils::{ mocks::{create_connectivity_mock, ConnectivityManagerMockState}, node_identity::build_node_identity, @@ -74,7 +74,7 @@ use tari_wallet::{ service::OutputManagerService, storage::{ database::{DbKey, DbKeyValuePair, DbValue, OutputManagerBackend, OutputManagerDatabase, WriteOperation}, - models::DbUnblindedOutput, + models::{DbUnblindedOutput, OutputStatus}, sqlite_db::OutputManagerSqliteDatabase, }, TxId, @@ -83,18 +83,14 @@ use tari_wallet::{ transaction_service::handle::TransactionServiceHandle, types::ValidationRetryStrategy, }; - -use tari_comms::protocol::rpc::RpcClientConfig; -use tari_wallet::output_manager_service::storage::models::OutputStatus; use tokio::{ - runtime::Runtime, sync::{broadcast, 
broadcast::channel}, - time::delay_for, + task, + time, }; #[allow(clippy::type_complexity)] -pub fn setup_output_manager_service( - runtime: &mut Runtime, +async fn setup_output_manager_service( backend: T, with_connection: bool, ) -> ( @@ -124,11 +120,11 @@ pub fn setup_output_manager_service( let basenode_service_handle = BaseNodeServiceHandle::new(sender, event_publisher_bns); let mut mock_base_node_service = MockBaseNodeService::new(receiver_bns, shutdown.to_signal()); mock_base_node_service.set_default_base_node_state(); - runtime.spawn(mock_base_node_service.run()); + task::spawn(mock_base_node_service.run()); let (connectivity_manager, connectivity_mock) = create_connectivity_mock(); let connectivity_mock_state = connectivity_mock.get_shared_state(); - runtime.spawn(connectivity_mock.run()); + task::spawn(connectivity_mock.run()); let service = BaseNodeWalletRpcMockService::new(); let rpc_service_state = service.get_state(); @@ -137,43 +133,39 @@ pub fn setup_output_manager_service( let protocol_name = server.as_protocol_name(); let server_node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); - let mut mock_server = runtime - .handle() - .enter(|| MockRpcServer::new(server, server_node_identity.clone())); + let mut mock_server = MockRpcServer::new(server, server_node_identity.clone()); - runtime.handle().enter(|| mock_server.serve()); + mock_server.serve(); if with_connection { - let connection = runtime.block_on(async { - mock_server - .create_connection(server_node_identity.to_peer(), protocol_name.into()) - .await - }); - runtime.block_on(connectivity_mock_state.add_active_connection(connection)); + let connection = mock_server + .create_connection(server_node_identity.to_peer(), protocol_name.into()) + .await; + connectivity_mock_state.add_active_connection(connection).await; } - let output_manager_service = runtime - .block_on(OutputManagerService::new( - OutputManagerServiceConfig { - base_node_query_timeout: Duration::from_secs(10), - max_utxo_query_size: 2, - peer_dial_retry_timeout: Duration::from_secs(5), - ..Default::default() - }, - ts_handle.clone(), - oms_request_receiver, - OutputManagerDatabase::new(backend), - oms_event_publisher.clone(), - factories, - constants, - shutdown.to_signal(), - basenode_service_handle, - connectivity_manager, - CommsSecretKey::default(), - )) - .unwrap(); + let output_manager_service = OutputManagerService::new( + OutputManagerServiceConfig { + base_node_query_timeout: Duration::from_secs(10), + max_utxo_query_size: 2, + peer_dial_retry_timeout: Duration::from_secs(5), + ..Default::default() + }, + ts_handle.clone(), + oms_request_receiver, + OutputManagerDatabase::new(backend), + oms_event_publisher.clone(), + factories, + constants, + shutdown.to_signal(), + basenode_service_handle, + connectivity_manager, + CommsSecretKey::default(), + ) + .await + .unwrap(); let output_manager_service_handle = OutputManagerHandle::new(oms_request_sender, oms_event_publisher); - runtime.spawn(async move { output_manager_service.start().await.unwrap() }); + task::spawn(async move { output_manager_service.start().await.unwrap() }); ( output_manager_service_handle, @@ -218,8 +210,7 @@ async fn complete_transaction(mut stp: SenderTransactionProtocol, mut oms: Outpu stp.get_transaction().unwrap().clone() } -pub fn setup_oms_with_bn_state( - runtime: &mut Runtime, +pub async fn setup_oms_with_bn_state( backend: T, height: Option, ) -> ( @@ -246,35 +237,35 @@ pub fn setup_oms_with_bn_state( let base_node_service_handle = 
BaseNodeServiceHandle::new(sender, event_publisher_bns); let mut mock_base_node_service = MockBaseNodeService::new(receiver_bns, shutdown.to_signal()); mock_base_node_service.set_base_node_state(height); - runtime.spawn(mock_base_node_service.run()); + task::spawn(mock_base_node_service.run()); let (connectivity_manager, connectivity_mock) = create_connectivity_mock(); let _connectivity_mock_state = connectivity_mock.get_shared_state(); - runtime.spawn(connectivity_mock.run()); - - let output_manager_service = runtime - .block_on(OutputManagerService::new( - OutputManagerServiceConfig { - base_node_query_timeout: Duration::from_secs(10), - max_utxo_query_size: 2, - peer_dial_retry_timeout: Duration::from_secs(5), - ..Default::default() - }, - ts_handle.clone(), - oms_request_receiver, - OutputManagerDatabase::new(backend), - oms_event_publisher.clone(), - factories, - constants, - shutdown.to_signal(), - base_node_service_handle.clone(), - connectivity_manager, - CommsSecretKey::default(), - )) - .unwrap(); + task::spawn(connectivity_mock.run()); + + let output_manager_service = OutputManagerService::new( + OutputManagerServiceConfig { + base_node_query_timeout: Duration::from_secs(10), + max_utxo_query_size: 2, + peer_dial_retry_timeout: Duration::from_secs(5), + ..Default::default() + }, + ts_handle.clone(), + oms_request_receiver, + OutputManagerDatabase::new(backend), + oms_event_publisher.clone(), + factories, + constants, + shutdown.to_signal(), + base_node_service_handle.clone(), + connectivity_manager, + CommsSecretKey::default(), + ) + .await + .unwrap(); let output_manager_service_handle = OutputManagerHandle::new(oms_request_sender, oms_event_publisher); - runtime.spawn(async move { output_manager_service.start().await.unwrap() }); + task::spawn(async move { output_manager_service.start().await.unwrap() }); ( output_manager_service_handle, @@ -321,63 +312,65 @@ fn generate_sender_transaction_message(amount: MicroTari) -> (TxId, TransactionS ) } -#[test] -fn fee_estimate() { +#[tokio::test] +async fn fee_estimate() { let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let (_, uo) = make_input(&mut OsRng.clone(), MicroTari::from(3000), &factories.commitment); - runtime.block_on(oms.add_output(uo)).unwrap(); + oms.add_output(uo).await.unwrap(); // minimum fee let fee_per_gram = MicroTari::from(1); - let fee = runtime - .block_on(oms.fee_estimate(MicroTari::from(100), fee_per_gram, 1, 1)) + let fee = oms + .fee_estimate(MicroTari::from(100), fee_per_gram, 1, 1) + .await .unwrap(); assert_eq!(fee, MicroTari::from(100)); let fee_per_gram = MicroTari::from(25); for outputs in 1..5 { - let fee = runtime - .block_on(oms.fee_estimate(MicroTari::from(100), fee_per_gram, 1, outputs)) + let fee = oms + .fee_estimate(MicroTari::from(100), fee_per_gram, 1, outputs) + .await .unwrap(); assert_eq!(fee, Fee::calculate(fee_per_gram, 1, 1, outputs as usize)); } // not enough funds - let err = runtime - .block_on(oms.fee_estimate(MicroTari::from(2750), fee_per_gram, 1, 1)) + let err = oms + .fee_estimate(MicroTari::from(2750), fee_per_gram, 1, 1) + .await .unwrap_err(); assert!(matches!(err, OutputManagerError::NotEnoughFunds)); } 
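// fee_estimate above shows the test conversion applied throughout this file: the hand-built
// Runtime and its runtime.block_on(..)/runtime.spawn(..) calls become an async #[tokio::test]
// body that .awaits the handles directly and starts background services with task::spawn. A
// minimal sketch of that shape, assuming a placeholder async fn rather than the wallet's output
// manager service:
use std::time::Duration;
use tokio::{task, time};

async fn answer() -> u32 {
    time::sleep(Duration::from_millis(1)).await;
    42
}

#[tokio::test]
async fn runs_on_the_per_test_tokio_runtime() {
    // background work is spawned the same way the setup helpers now use task::spawn(..)
    let handle = task::spawn(answer());
    assert_eq!(handle.await.unwrap(), 42);
}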
#[allow(clippy::identity_op)] -#[test] -fn test_utxo_selection_no_chain_metadata() { +#[tokio::test] +async fn test_utxo_selection_no_chain_metadata() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); // no chain metadata let (mut oms, _shutdown, _, _) = - setup_oms_with_bn_state(&mut runtime, OutputManagerSqliteDatabase::new(connection, None), None); + setup_oms_with_bn_state(OutputManagerSqliteDatabase::new(connection, None), None).await; // no utxos - not enough funds let amount = MicroTari::from(1000); let fee_per_gram = MicroTari::from(10); - let err = runtime - .block_on(oms.prepare_transaction_to_send( + let err = oms + .prepare_transaction_to_send( OsRng.next_u64(), amount, fee_per_gram, None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap_err(); assert!(matches!(err, OutputManagerError::NotEnoughFunds)); @@ -389,24 +382,25 @@ fn test_utxo_selection_no_chain_metadata() { &factories.commitment, Some(OutputFeatures::with_maturity(i)), ); - runtime.block_on(oms.add_output(uo.clone())).unwrap(); + oms.add_output(uo.clone()).await.unwrap(); } // but we have no chain state so the lowest maturity should be used - let stp = runtime - .block_on(oms.prepare_transaction_to_send( + let stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), amount, fee_per_gram, None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); assert!(stp.get_tx_id().is_ok()); // test that lowest 2 maturities were encumbered - let utxos = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let utxos = oms.get_unspent_outputs().await.unwrap(); assert_eq!(utxos.len(), 8); for (index, utxo) in utxos.iter().enumerate() { let i = index as u64 + 3; @@ -415,34 +409,31 @@ fn test_utxo_selection_no_chain_metadata() { } // test that we can get a fee estimate with no chain metadata - let fee = runtime.block_on(oms.fee_estimate(amount, fee_per_gram, 1, 2)).unwrap(); + let fee = oms.fee_estimate(amount, fee_per_gram, 1, 2).await.unwrap(); assert_eq!(fee, MicroTari::from(300)); // test if a fee estimate would be possible with pending funds included // at this point 52000 uT is still spendable, with pending change incoming of 1690 uT // so instead of returning "not enough funds", return "funds pending" let spendable_amount = (3..=10).sum::() * amount; - let err = runtime - .block_on(oms.fee_estimate(spendable_amount, fee_per_gram, 1, 2)) + let err = oms + .fee_estimate(spendable_amount, fee_per_gram, 1, 2) + .await .unwrap_err(); assert!(matches!(err, OutputManagerError::FundsPending)); // test not enough funds let broke_amount = spendable_amount + MicroTari::from(2000); - let err = runtime - .block_on(oms.fee_estimate(broke_amount, fee_per_gram, 1, 2)) - .unwrap_err(); + let err = oms.fee_estimate(broke_amount, fee_per_gram, 1, 2).await.unwrap_err(); assert!(matches!(err, OutputManagerError::NotEnoughFunds)); // coin split uses the "Largest" selection strategy - let (_, _, fee, utxos_total_value) = runtime - .block_on(oms.create_coin_split(amount, 5, fee_per_gram, None)) - .unwrap(); + let (_, _, fee, utxos_total_value) = oms.create_coin_split(amount, 5, fee_per_gram, None).await.unwrap(); assert_eq!(fee, MicroTari::from(820)); assert_eq!(utxos_total_value, MicroTari::from(10_000)); // test that largest utxo was encumbered - let utxos = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let utxos = oms.get_unspent_outputs().await.unwrap(); assert_eq!(utxos.len(), 7); for (index, utxo) in 
utxos.iter().enumerate() { let i = index as u64 + 3; @@ -452,31 +443,28 @@ fn test_utxo_selection_no_chain_metadata() { } #[allow(clippy::identity_op)] -#[test] -fn test_utxo_selection_with_chain_metadata() { +#[tokio::test] +async fn test_utxo_selection_with_chain_metadata() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); // setup with chain metadata at a height of 6 - let (mut oms, _shutdown, _, _) = setup_oms_with_bn_state( - &mut runtime, - OutputManagerSqliteDatabase::new(connection, None), - Some(6), - ); + let (mut oms, _shutdown, _, _) = + setup_oms_with_bn_state(OutputManagerSqliteDatabase::new(connection, None), Some(6)).await; // no utxos - not enough funds let amount = MicroTari::from(1000); let fee_per_gram = MicroTari::from(10); - let err = runtime - .block_on(oms.prepare_transaction_to_send( + let err = oms + .prepare_transaction_to_send( OsRng.next_u64(), amount, fee_per_gram, None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap_err(); assert!(matches!(err, OutputManagerError::NotEnoughFunds)); @@ -488,52 +476,52 @@ fn test_utxo_selection_with_chain_metadata() { &factories.commitment, Some(OutputFeatures::with_maturity(i)), ); - runtime.block_on(oms.add_output(uo.clone())).unwrap(); + oms.add_output(uo.clone()).await.unwrap(); } - let utxos = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let utxos = oms.get_unspent_outputs().await.unwrap(); assert_eq!(utxos.len(), 10); // test fee estimates - let fee = runtime.block_on(oms.fee_estimate(amount, fee_per_gram, 1, 2)).unwrap(); + let fee = oms.fee_estimate(amount, fee_per_gram, 1, 2).await.unwrap(); assert_eq!(fee, MicroTari::from(310)); // test fee estimates are maturity aware // even though we have utxos for the fee, they can't be spent because they are not mature yet let spendable_amount = (1..=6).sum::() * amount; - let err = runtime - .block_on(oms.fee_estimate(spendable_amount, fee_per_gram, 1, 2)) + let err = oms + .fee_estimate(spendable_amount, fee_per_gram, 1, 2) + .await .unwrap_err(); assert!(matches!(err, OutputManagerError::NotEnoughFunds)); // test coin split is maturity aware - let (_, _, fee, utxos_total_value) = runtime - .block_on(oms.create_coin_split(amount, 5, fee_per_gram, None)) - .unwrap(); + let (_, _, fee, utxos_total_value) = oms.create_coin_split(amount, 5, fee_per_gram, None).await.unwrap(); assert_eq!(utxos_total_value, MicroTari::from(6_000)); assert_eq!(fee, MicroTari::from(820)); // test that largest spendable utxo was encumbered - let utxos = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let utxos = oms.get_unspent_outputs().await.unwrap(); assert_eq!(utxos.len(), 9); let found = utxos.iter().any(|u| u.value == 6 * amount); assert!(!found, "An unspendable utxo was selected"); // test transactions - let stp = runtime - .block_on(oms.prepare_transaction_to_send( + let stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), amount, fee_per_gram, None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); assert!(stp.get_tx_id().is_ok()); // test that utxos with the lowest 2 maturities were encumbered - let utxos = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let utxos = oms.get_unspent_outputs().await.unwrap(); assert_eq!(utxos.len(), 7); for utxo in utxos.iter() { assert_ne!(utxo.features.maturity, 1); @@ -543,20 +531,21 @@ fn test_utxo_selection_with_chain_metadata() { } // when the amount is greater than the largest utxo, then "Largest" 
selection strategy is used - let stp = runtime - .block_on(oms.prepare_transaction_to_send( + let stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), 6 * amount, fee_per_gram, None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); assert!(stp.get_tx_id().is_ok()); // test that utxos with the highest spendable 2 maturities were encumbered - let utxos = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let utxos = oms.get_unspent_outputs().await.unwrap(); assert_eq!(utxos.len(), 5); for utxo in utxos.iter() { assert_ne!(utxo.features.maturity, 4); @@ -566,22 +555,21 @@ fn test_utxo_selection_with_chain_metadata() { } } -#[test] -fn sending_transaction_and_confirmation() { +#[tokio::test] +async fn sending_transaction_and_confirmation() { let factories = CryptoFactories::default(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let mut runtime = Runtime::new().unwrap(); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend.clone(), true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend.clone(), true).await; let (_ti, uo) = make_input( &mut OsRng.clone(), MicroTari::from(100 + OsRng.next_u64() % 1000), &factories.commitment, ); - runtime.block_on(oms.add_output(uo.clone())).unwrap(); - match runtime.block_on(oms.add_output(uo)) { + oms.add_output(uo.clone()).await.unwrap(); + match oms.add_output(uo).await { Err(OutputManagerError::OutputManagerStorageError(OutputManagerStorageError::DuplicateOutput)) => {}, _ => panic!("Incorrect error message"), }; @@ -592,25 +580,26 @@ fn sending_transaction_and_confirmation() { MicroTari::from(100 + OsRng.next_u64() % 1000), &factories.commitment, ); - runtime.block_on(oms.add_output(uo)).unwrap(); + oms.add_output(uo).await.unwrap(); } - let stp = runtime - .block_on(oms.prepare_transaction_to_send( + let stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), MicroTari::from(1000), MicroTari::from(20), None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); let sender_tx_id = stp.get_tx_id().unwrap(); - let tx = runtime.block_on(complete_transaction(stp, oms.clone())); + let tx = complete_transaction(stp, oms.clone()).await; - let rewind_public_keys = runtime.block_on(oms.get_rewind_public_keys()).unwrap(); + let rewind_public_keys = oms.get_rewind_public_keys().await.unwrap(); // 1 of the 2 outputs should be rewindable, there should be 2 outputs due to change but if we get unlucky enough // that there is no change we will skip this aspect of the test @@ -643,23 +632,23 @@ fn sending_transaction_and_confirmation() { assert_eq!(num_rewound, 1, "Should only be 1 rewindable output"); } - runtime - .block_on(oms.confirm_transaction(sender_tx_id, tx.body.inputs().clone(), tx.body.outputs().clone())) + oms.confirm_transaction(sender_tx_id, tx.body.inputs().clone(), tx.body.outputs().clone()) + .await .unwrap(); assert_eq!( - runtime.block_on(oms.get_pending_transactions()).unwrap().len(), + oms.get_pending_transactions().await.unwrap().len(), 0, "Should have no pending tx" ); assert_eq!( - runtime.block_on(oms.get_spent_outputs()).unwrap().len(), + oms.get_spent_outputs().await.unwrap().len(), tx.body.inputs().len(), "# Outputs should equal number of sent inputs" ); assert_eq!( - runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), - num_outputs + 1 - runtime.block_on(oms.get_spent_outputs()).unwrap().len() + tx.body.outputs().len() - 1, + 
oms.get_unspent_outputs().await.unwrap().len(), + num_outputs + 1 - oms.get_spent_outputs().await.unwrap().len() + tx.body.outputs().len() - 1, "Unspent outputs" ); @@ -675,16 +664,14 @@ fn sending_transaction_and_confirmation() { } } -#[test] -fn send_not_enough_funds() { +#[tokio::test] +async fn send_not_enough_funds() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); - let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let num_outputs = 20; for _i in 0..num_outputs { let (_ti, uo) = make_input( @@ -692,68 +679,70 @@ fn send_not_enough_funds() { MicroTari::from(100 + OsRng.next_u64() % 1000), &factories.commitment, ); - runtime.block_on(oms.add_output(uo)).unwrap(); + oms.add_output(uo).await.unwrap(); } - match runtime.block_on(oms.prepare_transaction_to_send( - OsRng.next_u64(), - MicroTari::from(num_outputs * 2000), - MicroTari::from(20), - None, - "".to_string(), - script!(Nop), - )) { + match oms + .prepare_transaction_to_send( + OsRng.next_u64(), + MicroTari::from(num_outputs * 2000), + MicroTari::from(20), + None, + "".to_string(), + script!(Nop), + ) + .await + { Err(OutputManagerError::NotEnoughFunds) => {}, _ => panic!(), } } -#[test] -fn send_no_change() { +#[tokio::test] +async fn send_no_change() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); - let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let fee_per_gram = MicroTari::from(20); let fee_without_change = Fee::calculate(fee_per_gram, 1, 2, 1); let value1 = 500; - runtime - .block_on(oms.add_output(create_unblinded_output( - script!(Nop), - OutputFeatures::default(), - TestParamsHelpers::new(), - MicroTari::from(value1), - ))) - .unwrap(); + oms.add_output(create_unblinded_output( + script!(Nop), + OutputFeatures::default(), + TestParamsHelpers::new(), + MicroTari::from(value1), + )) + .await + .unwrap(); let value2 = 800; - runtime - .block_on(oms.add_output(create_unblinded_output( - script!(Nop), - OutputFeatures::default(), - TestParamsHelpers::new(), - MicroTari::from(value2), - ))) - .unwrap(); + oms.add_output(create_unblinded_output( + script!(Nop), + OutputFeatures::default(), + TestParamsHelpers::new(), + MicroTari::from(value2), + )) + .await + .unwrap(); - let mut stp = runtime - .block_on(oms.prepare_transaction_to_send( + let mut stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), MicroTari::from(value1 + value2) - fee_without_change, fee_per_gram, None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); let sender_tx_id = stp.get_tx_id().unwrap(); assert_eq!(stp.get_amount_to_self().unwrap(), MicroTari::from(0)); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 1); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 1); let msg = stp.build_single_round_message().unwrap(); @@ -776,99 +765,91 @@ fn send_no_change() { let tx = stp.get_transaction().unwrap(); - runtime - .block_on(oms.confirm_transaction(sender_tx_id, 
tx.body.inputs().clone(), tx.body.outputs().clone())) + oms.confirm_transaction(sender_tx_id, tx.body.inputs().clone(), tx.body.outputs().clone()) + .await .unwrap(); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 0); - assert_eq!( - runtime.block_on(oms.get_spent_outputs()).unwrap().len(), - tx.body.inputs().len() - ); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), 0); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 0); + assert_eq!(oms.get_spent_outputs().await.unwrap().len(), tx.body.inputs().len()); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), 0); } -#[test] -fn send_not_enough_for_change() { - let mut runtime = Runtime::new().unwrap(); - +#[tokio::test] +async fn send_not_enough_for_change() { let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let fee_per_gram = MicroTari::from(20); let fee_without_change = Fee::calculate(fee_per_gram, 1, 2, 1); let value1 = 500; - runtime - .block_on(oms.add_output(create_unblinded_output( - TariScript::default(), - OutputFeatures::default(), - TestParamsHelpers::new(), - MicroTari::from(value1), - ))) - .unwrap(); + oms.add_output(create_unblinded_output( + TariScript::default(), + OutputFeatures::default(), + TestParamsHelpers::new(), + MicroTari::from(value1), + )) + .await + .unwrap(); let value2 = 800; - runtime - .block_on(oms.add_output(create_unblinded_output( - TariScript::default(), - OutputFeatures::default(), - TestParamsHelpers::new(), - MicroTari::from(value2), - ))) - .unwrap(); + oms.add_output(create_unblinded_output( + TariScript::default(), + OutputFeatures::default(), + TestParamsHelpers::new(), + MicroTari::from(value2), + )) + .await + .unwrap(); - match runtime.block_on(oms.prepare_transaction_to_send( - OsRng.next_u64(), - MicroTari::from(value1 + value2 + 1) - fee_without_change, - MicroTari::from(20), - None, - "".to_string(), - script!(Nop), - )) { + match oms + .prepare_transaction_to_send( + OsRng.next_u64(), + MicroTari::from(value1 + value2 + 1) - fee_without_change, + MicroTari::from(20), + None, + "".to_string(), + script!(Nop), + ) + .await + { Err(OutputManagerError::NotEnoughFunds) => {}, _ => panic!(), } } -#[test] -fn receiving_and_confirmation() { - let mut runtime = Runtime::new().unwrap(); - +#[tokio::test] +async fn receiving_and_confirmation() { let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let value = MicroTari::from(5000); let (tx_id, sender_message) = generate_sender_transaction_message(value); - let rtp = runtime.block_on(oms.get_recipient_transaction(sender_message)).unwrap(); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), 0); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 1); + let rtp = oms.get_recipient_transaction(sender_message).await.unwrap(); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), 0); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 1); let output = match 
rtp.state { RecipientState::Finalized(s) => s.output, RecipientState::Failed(_) => panic!("Should not be in Failed state"), }; - runtime - .block_on(oms.confirm_transaction(tx_id, vec![], vec![output])) - .unwrap(); + oms.confirm_transaction(tx_id, vec![], vec![output]).await.unwrap(); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 0); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), 1); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 0); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), 1); } -#[test] -fn cancel_transaction() { +#[tokio::test] +async fn cancel_transaction() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); - let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let num_outputs = 20; for _i in 0..num_outputs { @@ -877,46 +858,43 @@ fn cancel_transaction() { MicroTari::from(100 + OsRng.next_u64() % 1000), &factories.commitment, ); - runtime.block_on(oms.add_output(uo)).unwrap(); + oms.add_output(uo).await.unwrap(); } - let stp = runtime - .block_on(oms.prepare_transaction_to_send( + let stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), MicroTari::from(1000), MicroTari::from(20), None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); - match runtime.block_on(oms.cancel_transaction(1)) { + match oms.cancel_transaction(1).await { Err(OutputManagerError::OutputManagerStorageError(OutputManagerStorageError::ValueNotFound)) => {}, _ => panic!("Value should not exist"), } - runtime - .block_on(oms.cancel_transaction(stp.get_tx_id().unwrap())) - .unwrap(); + oms.cancel_transaction(stp.get_tx_id().unwrap()).await.unwrap(); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), num_outputs); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), num_outputs); } -#[test] -fn cancel_transaction_and_reinstate_inbound_tx() { - let mut runtime = Runtime::new().unwrap(); - +#[tokio::test] +async fn cancel_transaction_and_reinstate_inbound_tx() { let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend.clone(), true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend.clone(), true).await; let value = MicroTari::from(5000); let (tx_id, sender_message) = generate_sender_transaction_message(value); - let _rtp = runtime.block_on(oms.get_recipient_transaction(sender_message)).unwrap(); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), 0); + let _rtp = oms.get_recipient_transaction(sender_message).await.unwrap(); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), 0); - let pending_txs = runtime.block_on(oms.get_pending_transactions()).unwrap(); + let pending_txs = oms.get_pending_transactions().await.unwrap(); assert_eq!(pending_txs.len(), 1); @@ -928,7 +906,7 @@ fn cancel_transaction_and_reinstate_inbound_tx() { .unwrap() .clone(); - runtime.block_on(oms.cancel_transaction(tx_id)).unwrap(); + oms.cancel_transaction(tx_id).await.unwrap(); let cancelled_output = backend .fetch(&DbKey::OutputsByTxIdAndStatus(tx_id, 
OutputStatus::CancelledInbound)) @@ -942,28 +920,25 @@ fn cancel_transaction_and_reinstate_inbound_tx() { panic!("Should have found cancelled output"); } - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 0); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 0); - runtime - .block_on(oms.reinstate_cancelled_inbound_transaction(tx_id)) - .unwrap(); + oms.reinstate_cancelled_inbound_transaction(tx_id).await.unwrap(); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 1); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 1); - let balance = runtime.block_on(oms.get_balance()).unwrap(); + let balance = oms.get_balance().await.unwrap(); assert_eq!(balance.pending_incoming_balance, value); } -#[test] -fn timeout_transaction() { +#[tokio::test] +async fn timeout_transaction() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let num_outputs = 20; for _i in 0..num_outputs { @@ -972,50 +947,43 @@ fn timeout_transaction() { MicroTari::from(100 + OsRng.next_u64() % 1000), &factories.commitment, ); - runtime.block_on(oms.add_output(uo)).unwrap(); + oms.add_output(uo).await.unwrap(); } - let _stp = runtime - .block_on(oms.prepare_transaction_to_send( + let _stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), MicroTari::from(1000), MicroTari::from(20), None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); - let remaining_outputs = runtime.block_on(oms.get_unspent_outputs()).unwrap().len(); + let remaining_outputs = oms.get_unspent_outputs().await.unwrap().len(); - thread::sleep(Duration::from_millis(2)); + time::sleep(Duration::from_millis(2)).await; - runtime - .block_on(oms.timeout_transactions(Duration::from_millis(1000))) - .unwrap(); + oms.timeout_transactions(Duration::from_millis(1000)).await.unwrap(); - assert_eq!( - runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), - remaining_outputs - ); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), remaining_outputs); - runtime - .block_on(oms.timeout_transactions(Duration::from_millis(1))) - .unwrap(); + oms.timeout_transactions(Duration::from_millis(1)).await.unwrap(); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), num_outputs); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), num_outputs); } -#[test] -fn test_get_balance() { +#[tokio::test] +async fn test_get_balance() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; - let balance = runtime.block_on(oms.get_balance()).unwrap(); + let balance = oms.get_balance().await.unwrap(); assert_eq!(MicroTari::from(0), balance.available_balance); @@ -1023,63 +991,62 @@ fn test_get_balance() { let output_val = MicroTari::from(2000); let (_ti, uo) = make_input(&mut OsRng.clone(), output_val, &factories.commitment); total 
+= uo.value; - runtime.block_on(oms.add_output(uo)).unwrap(); + oms.add_output(uo).await.unwrap(); let (_ti, uo) = make_input(&mut OsRng.clone(), output_val, &factories.commitment); total += uo.value; - runtime.block_on(oms.add_output(uo)).unwrap(); + oms.add_output(uo).await.unwrap(); let send_value = MicroTari::from(1000); - let stp = runtime - .block_on(oms.prepare_transaction_to_send( + let stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), send_value, MicroTari::from(20), None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); let change_val = stp.get_change_amount().unwrap(); let recv_value = MicroTari::from(1500); let (_tx_id, sender_message) = generate_sender_transaction_message(recv_value); - let _rtp = runtime.block_on(oms.get_recipient_transaction(sender_message)).unwrap(); + let _rtp = oms.get_recipient_transaction(sender_message).await.unwrap(); - let balance = runtime.block_on(oms.get_balance()).unwrap(); + let balance = oms.get_balance().await.unwrap(); assert_eq!(output_val, balance.available_balance); assert_eq!(recv_value + change_val, balance.pending_incoming_balance); assert_eq!(output_val, balance.pending_outgoing_balance); } -#[test] -fn test_confirming_received_output() { - let mut runtime = Runtime::new().unwrap(); - +#[tokio::test] +async fn test_confirming_received_output() { let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let value = MicroTari::from(5000); let (tx_id, sender_message) = generate_sender_transaction_message(value); - let rtp = runtime.block_on(oms.get_recipient_transaction(sender_message)).unwrap(); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), 0); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 1); + let rtp = oms.get_recipient_transaction(sender_message).await.unwrap(); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), 0); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 1); let output = match rtp.state { RecipientState::Finalized(s) => s.output, RecipientState::Failed(_) => panic!("Should not be in Failed state"), }; - runtime - .block_on(oms.confirm_transaction(tx_id, vec![], vec![output.clone()])) + oms.confirm_transaction(tx_id, vec![], vec![output.clone()]) + .await .unwrap(); - assert_eq!(runtime.block_on(oms.get_balance()).unwrap().available_balance, value); + assert_eq!(oms.get_balance().await.unwrap().available_balance, value); let factories = CryptoFactories::default(); - let rewind_public_keys = runtime.block_on(oms.get_rewind_public_keys()).unwrap(); + let rewind_public_keys = oms.get_rewind_public_keys().await.unwrap(); let rewind_result = output .rewind_range_proof_value_only( &factories.range_proof, @@ -1090,99 +1057,100 @@ fn test_confirming_received_output() { assert_eq!(rewind_result.committed_value, value); } -#[test] -fn sending_transaction_with_short_term_clear() { +#[tokio::test] +async fn sending_transaction_with_short_term_clear() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, 
backend.clone(), true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend.clone(), true).await; let available_balance = 10_000 * uT; let (_ti, uo) = make_input(&mut OsRng.clone(), available_balance, &factories.commitment); - runtime.block_on(oms.add_output(uo)).unwrap(); + oms.add_output(uo).await.unwrap(); // Check that funds are encumbered and then unencumbered if the pending tx is not confirmed before restart - let _stp = runtime - .block_on(oms.prepare_transaction_to_send( + let _stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), MicroTari::from(1000), MicroTari::from(20), None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); - let balance = runtime.block_on(oms.get_balance()).unwrap(); + let balance = oms.get_balance().await.unwrap(); let expected_change = balance.pending_incoming_balance; assert_eq!(balance.pending_outgoing_balance, available_balance); drop(oms); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend.clone(), true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend.clone(), true).await; - let balance = runtime.block_on(oms.get_balance()).unwrap(); + let balance = oms.get_balance().await.unwrap(); assert_eq!(balance.available_balance, available_balance); // Check that a unconfirm Pending Transaction can be cancelled - let stp = runtime - .block_on(oms.prepare_transaction_to_send( + let stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), MicroTari::from(1000), MicroTari::from(20), None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); let sender_tx_id = stp.get_tx_id().unwrap(); - let balance = runtime.block_on(oms.get_balance()).unwrap(); + let balance = oms.get_balance().await.unwrap(); assert_eq!(balance.pending_outgoing_balance, available_balance); - runtime.block_on(oms.cancel_transaction(sender_tx_id)).unwrap(); + oms.cancel_transaction(sender_tx_id).await.unwrap(); - let balance = runtime.block_on(oms.get_balance()).unwrap(); + let balance = oms.get_balance().await.unwrap(); assert_eq!(balance.available_balance, available_balance); // Check that is the pending tx is confirmed that the encumberance persists after restart - let stp = runtime - .block_on(oms.prepare_transaction_to_send( + let stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), MicroTari::from(1000), MicroTari::from(20), None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); let sender_tx_id = stp.get_tx_id().unwrap(); - runtime.block_on(oms.confirm_pending_transaction(sender_tx_id)).unwrap(); + oms.confirm_pending_transaction(sender_tx_id).await.unwrap(); drop(oms); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; - let balance = runtime.block_on(oms.get_balance()).unwrap(); + let balance = oms.get_balance().await.unwrap(); assert_eq!(balance.pending_outgoing_balance, available_balance); - let tx = runtime.block_on(complete_transaction(stp, oms.clone())); + let tx = complete_transaction(stp, oms.clone()).await; - runtime - .block_on(oms.confirm_transaction(sender_tx_id, tx.body.inputs().clone(), tx.body.outputs().clone())) + oms.confirm_transaction(sender_tx_id, tx.body.inputs().clone(), tx.body.outputs().clone()) + .await .unwrap(); - let balance = runtime.block_on(oms.get_balance()).unwrap(); + let balance = oms.get_balance().await.unwrap(); assert_eq!(balance.available_balance, 
expected_change); } -#[test] -fn coin_split_with_change() { +#[tokio::test] +async fn coin_split_with_change() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let val1 = 6_000 * uT; let val2 = 7_000 * uT; @@ -1190,14 +1158,15 @@ fn coin_split_with_change() { let (_ti, uo1) = make_input(&mut OsRng.clone(), val1, &factories.commitment); let (_ti, uo2) = make_input(&mut OsRng.clone(), val2, &factories.commitment); let (_ti, uo3) = make_input(&mut OsRng.clone(), val3, &factories.commitment); - assert!(runtime.block_on(oms.add_output(uo1)).is_ok()); - assert!(runtime.block_on(oms.add_output(uo2)).is_ok()); - assert!(runtime.block_on(oms.add_output(uo3)).is_ok()); + assert!(oms.add_output(uo1).await.is_ok()); + assert!(oms.add_output(uo2).await.is_ok()); + assert!(oms.add_output(uo3).await.is_ok()); let fee_per_gram = MicroTari::from(25); let split_count = 8; - let (_tx_id, coin_split_tx, fee, amount) = runtime - .block_on(oms.create_coin_split(1000.into(), split_count, fee_per_gram, None)) + let (_tx_id, coin_split_tx, fee, amount) = oms + .create_coin_split(1000.into(), split_count, fee_per_gram, None) + .await .unwrap(); assert_eq!(coin_split_tx.body.inputs().len(), 2); assert_eq!(coin_split_tx.body.outputs().len(), split_count + 1); @@ -1205,13 +1174,12 @@ fn coin_split_with_change() { assert_eq!(amount, val2 + val3); } -#[test] -fn coin_split_no_change() { +#[tokio::test] +async fn coin_split_no_change() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let fee_per_gram = MicroTari::from(25); let split_count = 15; @@ -1222,12 +1190,13 @@ fn coin_split_no_change() { let (_ti, uo1) = make_input(&mut OsRng.clone(), val1, &factories.commitment); let (_ti, uo2) = make_input(&mut OsRng.clone(), val2, &factories.commitment); let (_ti, uo3) = make_input(&mut OsRng.clone(), val3, &factories.commitment); - assert!(runtime.block_on(oms.add_output(uo1)).is_ok()); - assert!(runtime.block_on(oms.add_output(uo2)).is_ok()); - assert!(runtime.block_on(oms.add_output(uo3)).is_ok()); + assert!(oms.add_output(uo1).await.is_ok()); + assert!(oms.add_output(uo2).await.is_ok()); + assert!(oms.add_output(uo3).await.is_ok()); - let (_tx_id, coin_split_tx, fee, amount) = runtime - .block_on(oms.create_coin_split(1000.into(), split_count, fee_per_gram, None)) + let (_tx_id, coin_split_tx, fee, amount) = oms + .create_coin_split(1000.into(), split_count, fee_per_gram, None) + .await .unwrap(); assert_eq!(coin_split_tx.body.inputs().len(), 3); assert_eq!(coin_split_tx.body.outputs().len(), split_count); @@ -1235,13 +1204,12 @@ fn coin_split_no_change() { assert_eq!(amount, val1 + val2 + val3); } -#[test] -fn handle_coinbase() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn handle_coinbase() { let factories = CryptoFactories::default(); let (connection, _tempdir) 
= get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let reward1 = MicroTari::from(1000); let fees1 = MicroTari::from(500); @@ -1253,37 +1221,25 @@ fn handle_coinbase() { let fees3 = MicroTari::from(500); let value3 = reward3 + fees3; - let _ = runtime - .block_on(oms.get_coinbase_transaction(1, reward1, fees1, 1)) - .unwrap(); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), 0); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 1); - assert_eq!( - runtime.block_on(oms.get_balance()).unwrap().pending_incoming_balance, - value1 - ); - let _tx2 = runtime - .block_on(oms.get_coinbase_transaction(2, reward2, fees2, 1)) - .unwrap(); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), 0); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 1); - assert_eq!( - runtime.block_on(oms.get_balance()).unwrap().pending_incoming_balance, - value2 - ); - let tx3 = runtime - .block_on(oms.get_coinbase_transaction(3, reward3, fees3, 2)) - .unwrap(); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), 0); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 2); + let _ = oms.get_coinbase_transaction(1, reward1, fees1, 1).await.unwrap(); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), 0); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 1); + assert_eq!(oms.get_balance().await.unwrap().pending_incoming_balance, value1); + let _tx2 = oms.get_coinbase_transaction(2, reward2, fees2, 1).await.unwrap(); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), 0); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 1); + assert_eq!(oms.get_balance().await.unwrap().pending_incoming_balance, value2); + let tx3 = oms.get_coinbase_transaction(3, reward3, fees3, 2).await.unwrap(); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), 0); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 2); assert_eq!( - runtime.block_on(oms.get_balance()).unwrap().pending_incoming_balance, + oms.get_balance().await.unwrap().pending_incoming_balance, value2 + value3 ); let output = tx3.body.outputs()[0].clone(); - let rewind_public_keys = runtime.block_on(oms.get_rewind_public_keys()).unwrap(); + let rewind_public_keys = oms.get_rewind_public_keys().await.unwrap(); let rewind_result = output .rewind_range_proof_value_only( &factories.range_proof, @@ -1293,28 +1249,22 @@ fn handle_coinbase() { .unwrap(); assert_eq!(rewind_result.committed_value, value3); - runtime - .block_on(oms.confirm_transaction(3, vec![], vec![output])) - .unwrap(); + oms.confirm_transaction(3, vec![], vec![output]).await.unwrap(); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 1); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), 1); - assert_eq!(runtime.block_on(oms.get_balance()).unwrap().available_balance, value3); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 1); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), 1); + assert_eq!(oms.get_balance().await.unwrap().available_balance, value3); + assert_eq!(oms.get_balance().await.unwrap().pending_incoming_balance, value2); assert_eq!( - 
runtime.block_on(oms.get_balance()).unwrap().pending_incoming_balance, - value2 - ); - assert_eq!( - runtime.block_on(oms.get_balance()).unwrap().pending_outgoing_balance, + oms.get_balance().await.unwrap().pending_outgoing_balance, MicroTari::from(0) ); } -#[test] -fn test_utxo_stxo_invalid_txo_validation() { +#[tokio::test] +async fn test_utxo_stxo_invalid_txo_validation() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); @@ -1374,8 +1324,8 @@ fn test_utxo_stxo_invalid_txo_validation() { .unwrap(); let (mut oms, _shutdown, _ts, _mock_rpc_server, server_node_identity, rpc_service_state, _) = - setup_output_manager_service(&mut runtime, backend, true); - let mut event_stream = oms.get_event_stream_fused(); + setup_output_manager_service(backend, true).await; + let mut event_stream = oms.get_event_stream(); let unspent_value1 = 500; let unspent_output1 = create_unblinded_output( @@ -1386,7 +1336,7 @@ fn test_utxo_stxo_invalid_txo_validation() { ); let unspent_tx_output1 = unspent_output1.as_transaction_output(&factories).unwrap(); - runtime.block_on(oms.add_output(unspent_output1.clone())).unwrap(); + oms.add_output(unspent_output1.clone()).await.unwrap(); let unspent_value2 = 800; let unspent_output2 = create_unblinded_output( @@ -1396,7 +1346,7 @@ fn test_utxo_stxo_invalid_txo_validation() { MicroTari::from(unspent_value2), ); - runtime.block_on(oms.add_output(unspent_output2)).unwrap(); + oms.add_output(unspent_output2).await.unwrap(); let unspent_value3 = 900; let unspent_output3 = create_unblinded_output( @@ -1407,7 +1357,7 @@ fn test_utxo_stxo_invalid_txo_validation() { ); let unspent_tx_output3 = unspent_output3.as_transaction_output(&factories).unwrap(); - runtime.block_on(oms.add_output(unspent_output3.clone())).unwrap(); + oms.add_output(unspent_output3.clone()).await.unwrap(); let unspent_value4 = 901; let unspent_output4 = create_unblinded_output( @@ -1418,44 +1368,42 @@ fn test_utxo_stxo_invalid_txo_validation() { ); let unspent_tx_output4 = unspent_output4.as_transaction_output(&factories).unwrap(); - runtime.block_on(oms.add_output(unspent_output4.clone())).unwrap(); + oms.add_output(unspent_output4.clone()).await.unwrap(); rpc_service_state.set_utxos(vec![invalid_output.as_transaction_output(&factories).unwrap()]); - runtime - .block_on(oms.set_base_node_public_key(server_node_identity.public_key().clone())) + oms.set_base_node_public_key(server_node_identity.public_key().clone()) + .await .unwrap(); - runtime - .block_on(oms.validate_txos(TxoValidationType::Invalid, ValidationRetryStrategy::Limited(5))) + oms.validate_txos(TxoValidationType::Invalid, ValidationRetryStrategy::Limited(5)) + .await .unwrap(); - let _fetch_utxo_calls = runtime - .block_on(rpc_service_state.wait_pop_fetch_utxos_calls(1, Duration::from_secs(60))) + let _fetch_utxo_calls = rpc_service_state + .wait_pop_fetch_utxos_calls(1, Duration::from_secs(60)) + .await .unwrap(); - runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); - let mut success = false; - loop { - futures::select! 
{ - event = event_stream.select_next_some() => { - if let Ok(msg) = event { - if let OutputManagerEvent::TxoValidationSuccess(_,TxoValidationType::Invalid) = (*msg).clone() { - success = true; - break; - }; - } - }, - () = delay => { - break; - }, - } + let delay = time::sleep(Duration::from_secs(60)).fuse(); + tokio::pin!(delay); + let mut success = false; + loop { + tokio::select! { + Ok(event) = event_stream.recv() => { + if let OutputManagerEvent::TxoValidationSuccess(_,TxoValidationType::Invalid) = &*event { + success = true; + break; + } + }, + () = &mut delay => { + break; + }, } - assert!(success, "Did not receive validation success event"); - }); + } + assert!(success, "Did not receive validation success event"); - let outputs = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let outputs = oms.get_unspent_outputs().await.unwrap(); assert_eq!(outputs.len(), 5); @@ -1466,36 +1414,34 @@ fn test_utxo_stxo_invalid_txo_validation() { unspent_tx_output3, ]); - runtime - .block_on(oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::UntilSuccess)) + oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::UntilSuccess) + .await .unwrap(); - let _fetch_utxo_calls = runtime - .block_on(rpc_service_state.wait_pop_fetch_utxos_calls(3, Duration::from_secs(60))) + let _fetch_utxo_calls = rpc_service_state + .wait_pop_fetch_utxos_calls(3, Duration::from_secs(60)) + .await .unwrap(); - runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); - let mut success = false; - loop { - futures::select! { - event = event_stream.select_next_some() => { - if let Ok(msg) = event { - if let OutputManagerEvent::TxoValidationSuccess(_,TxoValidationType::Unspent) = (*msg).clone() { - success = true; - break; - }; - }; - }, - () = delay => { - break; - }, - } + let delay = time::sleep(Duration::from_secs(60)).fuse(); + tokio::pin!(delay); + let mut success = false; + loop { + tokio::select! { + Ok(event) = event_stream.recv() => { + if let OutputManagerEvent::TxoValidationSuccess(_,TxoValidationType::Unspent) = &*event { + success = true; + break; + } + }, + () = &mut delay => { + break; + }, } - assert!(success, "Did not receive validation success event"); - }); + } + assert!(success, "Did not receive validation success event"); - let outputs = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let outputs = oms.get_unspent_outputs().await.unwrap(); assert_eq!(outputs.len(), 4); assert!(outputs.iter().any(|o| o == &unspent_output1)); @@ -1505,46 +1451,45 @@ fn test_utxo_stxo_invalid_txo_validation() { rpc_service_state.set_utxos(vec![spent_tx_output1]); - runtime - .block_on(oms.validate_txos(TxoValidationType::Spent, ValidationRetryStrategy::UntilSuccess)) + oms.validate_txos(TxoValidationType::Spent, ValidationRetryStrategy::UntilSuccess) + .await .unwrap(); - let _fetch_utxo_calls = runtime - .block_on(rpc_service_state.wait_pop_fetch_utxos_calls(1, Duration::from_secs(60))) + let _fetch_utxo_calls = rpc_service_state + .wait_pop_fetch_utxos_calls(1, Duration::from_secs(60)) + .await .unwrap(); - runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); - let mut success = false; - loop { - futures::select! 
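// Illustrative sketch (not part of the patch): the event-wait loops in these tests move
// from `futures::select!` over a fused stream and `delay_for` to `tokio::select!` over a
// broadcast receiver and `time::sleep`. `sleep` returns a future that is not `Unpin`, so
// it is pinned once with `tokio::pin!` and then polled by mutable reference on every loop
// iteration. The channel and string payload below are stand-ins for the wallet's event
// stream and `OutputManagerEvent`.
use std::time::Duration;
use tokio::{sync::broadcast, time};

#[tokio::main]
async fn main() {
    let (tx, mut event_stream) = broadcast::channel::<&'static str>(16);
    tx.send("validation_success").unwrap();

    let delay = time::sleep(Duration::from_secs(5));
    tokio::pin!(delay);

    let mut success = false;
    loop {
        tokio::select! {
            // A refutable pattern is allowed here; if the receiver yields a value that
            // does not match (e.g. an Err from a lagged receiver), the branch is disabled.
            Ok(event) = event_stream.recv() => {
                if event == "validation_success" {
                    success = true;
                    break;
                }
            },
            // Overall timeout for the whole loop.
            () = &mut delay => {
                break;
            },
        }
    }
    assert!(success, "did not receive the expected event");
}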
{ - event = event_stream.select_next_some() => { - if let Ok(msg) = event { - if let OutputManagerEvent::TxoValidationSuccess(_, TxoValidationType::Spent) = (*msg).clone() { - success = true; - break; - }; - } - }, - () = delay => { - break; - }, - } + let delay = time::sleep(Duration::from_secs(60)).fuse(); + tokio::pin!(delay); + let mut success = false; + loop { + tokio::select! { + event = event_stream.recv() => { + if let Ok(msg) = event { + if let OutputManagerEvent::TxoValidationSuccess(_, TxoValidationType::Spent) = (*msg).clone() { + success = true; + break; + }; + } + }, + () = &mut delay => { + break; + }, } - assert!(success, "Did not receive validation success event"); - }); + } + assert!(success, "Did not receive validation success event"); - let outputs = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let outputs = oms.get_unspent_outputs().await.unwrap(); assert_eq!(outputs.len(), 5); assert!(outputs.iter().any(|o| o == &spent_output1)); } -#[test] -fn test_base_node_switch_during_validation() { +#[tokio::test] +async fn test_base_node_switch_during_validation() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); @@ -1556,8 +1501,8 @@ fn test_base_node_switch_during_validation() { server_node_identity, mut rpc_service_state, _connectivity_mock_state, - ) = setup_output_manager_service(&mut runtime, backend, true); - let mut event_stream = oms.get_event_stream_fused(); + ) = setup_output_manager_service(backend, true).await; + let mut event_stream = oms.get_event_stream(); let unspent_value1 = 500; let unspent_output1 = create_unblinded_output( @@ -1568,7 +1513,7 @@ fn test_base_node_switch_during_validation() { ); let unspent_tx_output1 = unspent_output1.as_transaction_output(&factories).unwrap(); - runtime.block_on(oms.add_output(unspent_output1)).unwrap(); + oms.add_output(unspent_output1).await.unwrap(); let unspent_value2 = 800; let unspent_output2 = create_unblinded_output( @@ -1578,7 +1523,7 @@ fn test_base_node_switch_during_validation() { MicroTari::from(unspent_value2), ); - runtime.block_on(oms.add_output(unspent_output2)).unwrap(); + oms.add_output(unspent_output2).await.unwrap(); let unspent_value3 = 900; let unspent_output3 = create_unblinded_output( @@ -1589,7 +1534,7 @@ fn test_base_node_switch_during_validation() { ); let unspent_tx_output3 = unspent_output3.as_transaction_output(&factories).unwrap(); - runtime.block_on(oms.add_output(unspent_output3)).unwrap(); + oms.add_output(unspent_output3).await.unwrap(); // First RPC server state rpc_service_state.set_utxos(vec![unspent_tx_output1, unspent_tx_output3]); @@ -1598,53 +1543,52 @@ fn test_base_node_switch_during_validation() { // New base node we will switch to let new_server_node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); - runtime - .block_on(oms.set_base_node_public_key(server_node_identity.public_key().clone())) + oms.set_base_node_public_key(server_node_identity.public_key().clone()) + .await .unwrap(); - runtime - .block_on(oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::UntilSuccess)) + oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::UntilSuccess) + .await .unwrap(); - let _fetch_utxo_calls = runtime - .block_on(rpc_service_state.wait_pop_fetch_utxos_calls(1, Duration::from_secs(60))) + let _fetch_utxo_calls = rpc_service_state + 
.wait_pop_fetch_utxos_calls(1, Duration::from_secs(60)) + .await .unwrap(); - runtime - .block_on(oms.set_base_node_public_key(new_server_node_identity.public_key().clone())) + oms.set_base_node_public_key(new_server_node_identity.public_key().clone()) + .await .unwrap(); - runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); - let mut abort = false; - loop { - futures::select! { - event = event_stream.select_next_some() => { - if let Ok(msg) = event { - if let OutputManagerEvent::TxoValidationAborted(_,_) = (*msg).clone() { - abort = true; - break; - } - } - }, - () = delay => { - break; - }, - } + let delay = time::sleep(Duration::from_secs(60)).fuse(); + tokio::pin!(delay); + let mut abort = false; + loop { + tokio::select! { + event = event_stream.recv() => { + if let Ok(msg) = event { + if let OutputManagerEvent::TxoValidationAborted(_,_) = (*msg).clone() { + abort = true; + break; + } + } + }, + () = &mut delay => { + break; + }, } - assert!(abort, "Did not receive validation abort"); - }); + } + assert!(abort, "Did not receive validation abort"); } -#[test] -fn test_txo_validation_connection_timeout_retries() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn test_txo_validation_connection_timeout_retries() { let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); let (mut oms, _shutdown, _ts, _mock_rpc_server, server_node_identity, _rpc_service_state, _connectivity_mock_state) = - setup_output_manager_service(&mut runtime, backend, false); - let mut event_stream = oms.get_event_stream_fused(); + setup_output_manager_service(backend, false).await; + let mut event_stream = oms.get_event_stream(); let unspent_value1 = 500; let unspent_output1 = create_unblinded_output( @@ -1654,7 +1598,7 @@ fn test_txo_validation_connection_timeout_retries() { MicroTari::from(unspent_value1), ); - runtime.block_on(oms.add_output(unspent_output1)).unwrap(); + oms.add_output(unspent_output1).await.unwrap(); let unspent_value2 = 800; let unspent_output2 = create_unblinded_output( @@ -1664,57 +1608,54 @@ fn test_txo_validation_connection_timeout_retries() { MicroTari::from(unspent_value2), ); - runtime.block_on(oms.add_output(unspent_output2)).unwrap(); + oms.add_output(unspent_output2).await.unwrap(); - runtime - .block_on(oms.set_base_node_public_key(server_node_identity.public_key().clone())) + oms.set_base_node_public_key(server_node_identity.public_key().clone()) + .await .unwrap(); - runtime - .block_on(oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::Limited(1))) + oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::Limited(1)) + .await .unwrap(); - runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); - let mut timeout = 0; - let mut failed = 0; - loop { - futures::select! { - event = event_stream.select_next_some() => { - if let Ok(msg) = event { - match (*msg).clone() { - OutputManagerEvent::TxoValidationTimedOut(_,_) => { - timeout+=1; - }, - OutputManagerEvent::TxoValidationFailure(_,_) => { - failed+=1; - }, - _ => (), - } - }; - if timeout+failed >= 3 { - break; - } - }, - () = delay => { + let delay = time::sleep(Duration::from_secs(60)); + tokio::pin!(delay); + let mut timeout = 0; + let mut failed = 0; + loop { + tokio::select! 
{ + Ok(event) = event_stream.recv() => { + match &*event { + OutputManagerEvent::TxoValidationTimedOut(_,_) => { + timeout+=1; + }, + OutputManagerEvent::TxoValidationFailure(_,_) => { + failed+=1; + }, + _ => (), + } + + if timeout+failed >= 3 { break; - }, - } + } + }, + () = &mut delay => { + break; + }, } - assert_eq!(failed, 1); - assert_eq!(timeout, 2); - }); + } + assert_eq!(failed, 1); + assert_eq!(timeout, 2); } -#[test] -fn test_txo_validation_rpc_error_retries() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn test_txo_validation_rpc_error_retries() { let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); let (mut oms, _shutdown, _ts, _mock_rpc_server, server_node_identity, rpc_service_state, _connectivity_mock_state) = - setup_output_manager_service(&mut runtime, backend, true); - let mut event_stream = oms.get_event_stream_fused(); + setup_output_manager_service(backend, true).await; + let mut event_stream = oms.get_event_stream(); rpc_service_state.set_rpc_status_error(Some(RpcStatus::bad_request("blah".to_string()))); let unspent_value1 = 500; @@ -1725,7 +1666,7 @@ fn test_txo_validation_rpc_error_retries() { MicroTari::from(unspent_value1), ); - runtime.block_on(oms.add_output(unspent_output1)).unwrap(); + oms.add_output(unspent_output1).await.unwrap(); let unspent_value2 = 800; let unspent_output2 = create_unblinded_output( @@ -1735,44 +1676,42 @@ fn test_txo_validation_rpc_error_retries() { MicroTari::from(unspent_value2), ); - runtime.block_on(oms.add_output(unspent_output2)).unwrap(); + oms.add_output(unspent_output2).await.unwrap(); - runtime - .block_on(oms.set_base_node_public_key(server_node_identity.public_key().clone())) + oms.set_base_node_public_key(server_node_identity.public_key().clone()) + .await .unwrap(); - runtime - .block_on(oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::Limited(1))) + oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::Limited(1)) + .await .unwrap(); - runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); - let mut failed = 0; - loop { - futures::select! { - event = event_stream.select_next_some() => { - if let Ok(msg) = event { - if let OutputManagerEvent::TxoValidationFailure(_,_) = (*msg).clone() { - failed+=1; - } - } - - if failed >= 1 { - break; + let delay = time::sleep(Duration::from_secs(60)).fuse(); + tokio::pin!(delay); + let mut failed = 0; + loop { + tokio::select! 
{ + event = event_stream.recv() => { + if let Ok(msg) = event { + if let OutputManagerEvent::TxoValidationFailure(_,_) = (*msg).clone() { + failed+=1; } - }, - () = delay => { + } + + if failed >= 1 { break; - }, - } + } + }, + () = &mut delay => { + break; + }, } - assert_eq!(failed, 1); - }); + } + assert_eq!(failed, 1); } -#[test] -fn test_txo_validation_rpc_timeout() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn test_txo_validation_rpc_timeout() { let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); @@ -1784,8 +1723,8 @@ fn test_txo_validation_rpc_timeout() { server_node_identity, mut rpc_service_state, _connectivity_mock_state, - ) = setup_output_manager_service(&mut runtime, backend, true); - let mut event_stream = oms.get_event_stream_fused(); + ) = setup_output_manager_service(backend, true).await; + let mut event_stream = oms.get_event_stream(); rpc_service_state.set_response_delay(Some(Duration::from_secs(120))); let unspent_value1 = 500; @@ -1796,7 +1735,7 @@ fn test_txo_validation_rpc_timeout() { MicroTari::from(unspent_value1), ); - runtime.block_on(oms.add_output(unspent_output1)).unwrap(); + oms.add_output(unspent_output1).await.unwrap(); let unspent_value2 = 800; let unspent_output2 = create_unblinded_output( @@ -1806,57 +1745,51 @@ fn test_txo_validation_rpc_timeout() { MicroTari::from(unspent_value2), ); - runtime.block_on(oms.add_output(unspent_output2)).unwrap(); + oms.add_output(unspent_output2).await.unwrap(); - runtime - .block_on(oms.set_base_node_public_key(server_node_identity.public_key().clone())) + oms.set_base_node_public_key(server_node_identity.public_key().clone()) + .await .unwrap(); - runtime - .block_on(oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::Limited(1))) + oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::Limited(1)) + .await .unwrap(); - runtime.block_on(async { - let mut delay = delay_for( - RpcClientConfig::default().deadline.unwrap() + - RpcClientConfig::default().deadline_grace_period + - Duration::from_secs(30), - ) - .fuse(); - let mut failed = 0; - loop { - futures::select! { - event = event_stream.select_next_some() => { - if let Ok(msg) = event { - if let OutputManagerEvent::TxoValidationFailure(_,_) = (*msg).clone() { - failed+=1; - } + let delay = + time::sleep(RpcClientConfig::default().timeout_with_grace_period().unwrap() + Duration::from_secs(30)).fuse(); + tokio::pin!(delay); + let mut failed = 0; + loop { + tokio::select! 
{ + event = event_stream.recv() => { + if let Ok(msg) = event { + if let OutputManagerEvent::TxoValidationFailure(_,_) = &*msg { + failed+=1; } + } - if failed >= 1 { - break; - } - }, - () = delay => { + if failed >= 1 { break; - }, - } + } + }, + () = &mut delay => { + break; + }, } - assert_eq!(failed, 1); - }); + } + assert_eq!(failed, 1); } -#[test] -fn test_txo_validation_base_node_not_synced() { +#[tokio::test] +async fn test_txo_validation_base_node_not_synced() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); let (mut oms, _shutdown, _ts, _mock_rpc_server, server_node_identity, rpc_service_state, _connectivity_mock_state) = - setup_output_manager_service(&mut runtime, backend, true); - let mut event_stream = oms.get_event_stream_fused(); + setup_output_manager_service(backend, true).await; + let mut event_stream = oms.get_event_stream(); rpc_service_state.set_is_synced(false); let unspent_value1 = 500; @@ -1868,7 +1801,7 @@ fn test_txo_validation_base_node_not_synced() { ); let unspent_tx_output1 = unspent_output1.as_transaction_output(&factories).unwrap(); - runtime.block_on(oms.add_output(unspent_output1.clone())).unwrap(); + oms.add_output(unspent_output1.clone()).await.unwrap(); let unspent_value2 = 800; let unspent_output2 = create_unblinded_output( @@ -1878,74 +1811,67 @@ fn test_txo_validation_base_node_not_synced() { MicroTari::from(unspent_value2), ); - runtime.block_on(oms.add_output(unspent_output2)).unwrap(); + oms.add_output(unspent_output2).await.unwrap(); - runtime - .block_on(oms.set_base_node_public_key(server_node_identity.public_key().clone())) + oms.set_base_node_public_key(server_node_identity.public_key().clone()) + .await .unwrap(); - runtime - .block_on(oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::Limited(5))) + oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::Limited(5)) + .await .unwrap(); - runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); - let mut delayed = 0; - loop { - futures::select! { - event = event_stream.select_next_some() => { - if let Ok(msg) = event { - if let OutputManagerEvent::TxoValidationDelayed(_,_) = (*msg).clone() { - delayed+=1; - } - } - if delayed >= 2 { - break; - } - }, - () = delay => { + let delay = time::sleep(Duration::from_secs(60)).fuse(); + tokio::pin!(delay); + let mut delayed = 0; + loop { + tokio::select! { + Ok(event) = event_stream.recv() => { + if let OutputManagerEvent::TxoValidationDelayed(_,_) = &*event { + delayed += 1; + } + if delayed >= 2 { break; - }, - } + } + }, + () = &mut delay => { + break; + }, } - assert_eq!(delayed, 2); - }); + } + assert_eq!(delayed, 2); rpc_service_state.set_is_synced(true); rpc_service_state.set_utxos(vec![unspent_tx_output1]); - runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); - let mut success = false; - loop { - futures::select! { - event = event_stream.select_next_some() => { - if let Ok(msg) = event { - if let OutputManagerEvent::TxoValidationSuccess(_,_) = (*msg).clone() { - success = true; - break; - } - } - }, - () = delay => { - break; - }, - } + let delay = time::sleep(Duration::from_secs(60)).fuse(); + tokio::pin!(delay); + let mut success = false; + loop { + tokio::select! 
{ + Ok(event) = event_stream.recv() => { + if let OutputManagerEvent::TxoValidationSuccess(_,_) = &*event { + success = true; + break; + } + }, + () = &mut delay => { + break; + }, } - assert!(success, "Did not receive validation success event"); - }); + } + assert!(success, "Did not receive validation success event"); - let outputs = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let outputs = oms.get_unspent_outputs().await.unwrap(); assert_eq!(outputs.len(), 1); assert!(outputs.iter().any(|o| o == &unspent_output1)); } -#[test] -fn test_oms_key_manager_discrepancy() { +#[tokio::test] +async fn test_oms_key_manager_discrepancy() { let shutdown = Shutdown::new(); let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (_oms_request_sender, oms_request_receiver) = reply_channel::unbounded(); let (oms_event_publisher, _) = broadcast::channel(200); @@ -1959,7 +1885,7 @@ fn test_oms_key_manager_discrepancy() { let basenode_service_handle = BaseNodeServiceHandle::new(sender, event_publisher_bns); let mut mock_base_node_service = MockBaseNodeService::new(receiver_bns, shutdown.to_signal()); mock_base_node_service.set_default_base_node_state(); - runtime.spawn(mock_base_node_service.run()); + task::spawn(mock_base_node_service.run()); let (connectivity_manager, _connectivity_mock) = create_connectivity_mock(); @@ -1968,45 +1894,45 @@ fn test_oms_key_manager_discrepancy() { let master_key1 = CommsSecretKey::random(&mut OsRng); - let output_manager_service = runtime - .block_on(OutputManagerService::new( - OutputManagerServiceConfig::default(), - ts_handle.clone(), - oms_request_receiver, - db.clone(), - oms_event_publisher.clone(), - factories.clone(), - constants.clone(), - shutdown.to_signal(), - basenode_service_handle.clone(), - connectivity_manager.clone(), - master_key1.clone(), - )) - .unwrap(); + let output_manager_service = OutputManagerService::new( + OutputManagerServiceConfig::default(), + ts_handle.clone(), + oms_request_receiver, + db.clone(), + oms_event_publisher.clone(), + factories.clone(), + constants.clone(), + shutdown.to_signal(), + basenode_service_handle.clone(), + connectivity_manager.clone(), + master_key1.clone(), + ) + .await + .unwrap(); drop(output_manager_service); let (_oms_request_sender2, oms_request_receiver2) = reply_channel::unbounded(); - let output_manager_service2 = runtime - .block_on(OutputManagerService::new( - OutputManagerServiceConfig::default(), - ts_handle.clone(), - oms_request_receiver2, - db.clone(), - oms_event_publisher.clone(), - factories.clone(), - constants.clone(), - shutdown.to_signal(), - basenode_service_handle.clone(), - connectivity_manager.clone(), - master_key1, - )) - .expect("Should be able to make a new OMS with same master key"); + let output_manager_service2 = OutputManagerService::new( + OutputManagerServiceConfig::default(), + ts_handle.clone(), + oms_request_receiver2, + db.clone(), + oms_event_publisher.clone(), + factories.clone(), + constants.clone(), + shutdown.to_signal(), + basenode_service_handle.clone(), + connectivity_manager.clone(), + master_key1, + ) + .await + .expect("Should be able to make a new OMS with same master key"); drop(output_manager_service2); let (_oms_request_sender3, oms_request_receiver3) = reply_channel::unbounded(); let master_key2 = CommsSecretKey::random(&mut OsRng); - let output_manager_service3 = runtime.block_on(OutputManagerService::new( + let output_manager_service3 = OutputManagerService::new( OutputManagerServiceConfig::default(), 
ts_handle, oms_request_receiver3, @@ -2018,7 +1944,8 @@ fn test_oms_key_manager_discrepancy() { basenode_service_handle, connectivity_manager, master_key2, - )); + ) + .await; assert!(matches!( output_manager_service3, @@ -2026,26 +1953,25 @@ fn test_oms_key_manager_discrepancy() { )); } -#[test] -fn get_coinbase_tx_for_same_height() { +#[tokio::test] +async fn get_coinbase_tx_for_same_height() { let (connection, _tempdir) = get_temp_sqlite_database_connection(); - let mut runtime = Runtime::new().unwrap(); let (mut oms, _shutdown, _, _, _, _, _) = - setup_output_manager_service(&mut runtime, OutputManagerSqliteDatabase::new(connection, None), true); + setup_output_manager_service(OutputManagerSqliteDatabase::new(connection, None), true).await; - runtime - .block_on(oms.get_coinbase_transaction(1, 100_000.into(), 100.into(), 1)) + oms.get_coinbase_transaction(1, 100_000.into(), 100.into(), 1) + .await .unwrap(); - let pending_transactions = runtime.block_on(oms.get_pending_transactions()).unwrap(); + let pending_transactions = oms.get_pending_transactions().await.unwrap(); assert!(pending_transactions.values().any(|p| p.tx_id == 1)); - runtime - .block_on(oms.get_coinbase_transaction(2, 100_000.into(), 100.into(), 1)) + oms.get_coinbase_transaction(2, 100_000.into(), 100.into(), 1) + .await .unwrap(); - let pending_transactions = runtime.block_on(oms.get_pending_transactions()).unwrap(); + let pending_transactions = oms.get_pending_transactions().await.unwrap(); assert!(!pending_transactions.values().any(|p| p.tx_id == 1)); assert!(pending_transactions.values().any(|p| p.tx_id == 2)); } diff --git a/base_layer/wallet/tests/output_manager_service/storage.rs b/base_layer/wallet/tests/output_manager_service/storage.rs index c0609da64c..c9e0dd2938 100644 --- a/base_layer/wallet/tests/output_manager_service/storage.rs +++ b/base_layer/wallet/tests/output_manager_service/storage.rs @@ -50,7 +50,7 @@ use tokio::runtime::Runtime; #[allow(clippy::same_item_push)] pub fn test_db_backend(backend: T) { - let mut runtime = Runtime::new().unwrap(); + let runtime = Runtime::new().unwrap(); let db = OutputManagerDatabase::new(backend); let factories = CryptoFactories::default(); @@ -392,7 +392,7 @@ pub fn test_output_manager_sqlite_db_encrypted() { #[test] pub fn test_key_manager_crud() { - let mut runtime = Runtime::new().unwrap(); + let runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); let db = OutputManagerDatabase::new(backend); @@ -429,7 +429,7 @@ pub fn test_key_manager_crud() { assert_eq!(read_state3.primary_key_index, 2); } -#[tokio_macros::test] +#[tokio::test] pub async fn test_short_term_encumberance() { let factories = CryptoFactories::default(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); @@ -510,7 +510,7 @@ pub async fn test_short_term_encumberance() { ); } -#[tokio_macros::test] +#[tokio::test] pub async fn test_no_duplicate_outputs() { let factories = CryptoFactories::default(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); diff --git a/base_layer/wallet/tests/support/comms_and_services.rs b/base_layer/wallet/tests/support/comms_and_services.rs index f9d8010ac7..1b1243d72b 100644 --- a/base_layer/wallet/tests/support/comms_and_services.rs +++ b/base_layer/wallet/tests/support/comms_and_services.rs @@ -20,8 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF 
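// Illustrative sketch (not part of the patch): the two small API shifts behind the
// storage.rs hunks above. In tokio 1.x `Runtime::block_on` takes `&self`, so the runtime
// binding no longer needs `mut`, and the test attribute is written `#[tokio::test]`
// instead of going through the separate `tokio_macros` crate.
use tokio::runtime::Runtime;

#[test]
fn block_on_without_mut() {
    // tokio 0.2 needed `let mut runtime` because `block_on` took `&mut self`.
    let runtime = Runtime::new().unwrap();
    assert_eq!(runtime.block_on(async { 2 + 2 }), 4);
}

#[tokio::test]
async fn plain_async_test() {
    assert_eq!(2 + 2, 4);
}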
THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::Sink; -use std::{error::Error, sync::Arc, time::Duration}; +use std::{sync::Arc, time::Duration}; use tari_comms::{ message::MessageTag, multiaddr::Multiaddr, @@ -32,7 +31,7 @@ use tari_comms::{ }; use tari_comms_dht::{envelope::DhtMessageHeader, Dht}; use tari_p2p::{ - comms_connector::{InboundDomainConnector, PeerMessage}, + comms_connector::InboundDomainConnector, domain_message::DomainMessage, initialization::initialize_local_test_comms, }; @@ -43,18 +42,14 @@ pub fn get_next_memory_address() -> Multiaddr { format!("/memory/{}", port).parse().unwrap() } -pub async fn setup_comms_services( +pub async fn setup_comms_services( node_identity: Arc, peers: Vec>, - publisher: InboundDomainConnector, + publisher: InboundDomainConnector, database_path: String, discovery_request_timeout: Duration, shutdown_signal: ShutdownSignal, -) -> (CommsNode, Dht) -where - TSink: Sink> + Clone + Unpin + Send + Sync + 'static, - TSink::Error: Error + Send + Sync, -{ +) -> (CommsNode, Dht) { let peers = peers.into_iter().map(|ni| ni.to_peer()).collect(); let (comms, dht, _) = initialize_local_test_comms( node_identity, diff --git a/base_layer/wallet/tests/support/rpc.rs b/base_layer/wallet/tests/support/rpc.rs index 0f7009bdd0..fd1aa0cdbd 100644 --- a/base_layer/wallet/tests/support/rpc.rs +++ b/base_layer/wallet/tests/support/rpc.rs @@ -57,7 +57,7 @@ use tari_core::{ types::Signature, }, }; -use tokio::time::delay_for; +use tokio::time::sleep; /// This macro unlocks a Mutex or RwLock. If the lock is /// poisoned (i.e. panic while unlocked) the last value @@ -212,7 +212,7 @@ impl BaseNodeWalletRpcMockState { return Ok((*lock).drain(..num_calls).collect()); } drop(lock); - delay_for(Duration::from_millis(100)).await; + sleep(Duration::from_millis(100)).await; } Err(format!( "Did not receive enough calls within the timeout period, received {}, expected {}.", @@ -234,7 +234,7 @@ impl BaseNodeWalletRpcMockState { return Ok((*lock).drain(..num_calls).collect()); } drop(lock); - delay_for(Duration::from_millis(100)).await; + sleep(Duration::from_millis(100)).await; } Err(format!( "Did not receive enough calls within the timeout period, received {}, expected {}.", @@ -256,7 +256,7 @@ impl BaseNodeWalletRpcMockState { return Ok((*lock).drain(..num_calls).collect()); } drop(lock); - delay_for(Duration::from_millis(100)).await; + sleep(Duration::from_millis(100)).await; } Err(format!( "Did not receive enough calls within the timeout period, received {}, expected {}.", @@ -276,7 +276,7 @@ impl BaseNodeWalletRpcMockState { return Ok((*lock).drain(..num_calls).collect()); } drop(lock); - delay_for(Duration::from_millis(100)).await; + sleep(Duration::from_millis(100)).await; } Err("Did not receive enough calls within the timeout period".to_string()) } @@ -318,7 +318,7 @@ impl BaseNodeWalletService for BaseNodeWalletRpcMockService { ) -> Result, RpcStatus> { let delay_lock = *acquire_lock!(self.state.response_delay); if let Some(delay) = delay_lock { - delay_for(delay).await; + sleep(delay).await; } let message = request.into_message(); @@ -345,7 +345,7 @@ impl BaseNodeWalletService for BaseNodeWalletRpcMockService { ) -> Result, RpcStatus> { let delay_lock = *acquire_lock!(self.state.response_delay); if let Some(delay) = delay_lock { - delay_for(delay).await; + sleep(delay).await; } let message = request.into_message(); @@ -371,7 +371,7 @@ impl BaseNodeWalletService for BaseNodeWalletRpcMockService { ) -> Result, RpcStatus> { let 
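// Illustrative sketch (not part of the patch): `tokio::time::delay_for` (0.2) was renamed
// to `tokio::time::sleep` (1.x), which is the only change to the mock RPC service's
// poll-and-wait helpers around this point. A stand-alone version of such a helper, with a
// hypothetical `condition` closure in place of the mock's shared call log, could look
// like this:
use std::time::Duration;
use tokio::time::sleep;

pub async fn wait_for(mut condition: impl FnMut() -> bool, timeout: Duration) -> Result<(), String> {
    let mut elapsed = Duration::ZERO;
    while elapsed < timeout {
        if condition() {
            return Ok(());
        }
        // Re-check the condition every 100ms until the timeout expires.
        sleep(Duration::from_millis(100)).await;
        elapsed += Duration::from_millis(100);
    }
    Err("Did not observe the condition within the timeout period".to_string())
}

#[tokio::test]
async fn wait_for_succeeds_immediately() {
    assert!(wait_for(|| true, Duration::from_millis(500)).await.is_ok());
}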
delay_lock = *acquire_lock!(self.state.response_delay); if let Some(delay) = delay_lock { - delay_for(delay).await; + sleep(delay).await; } let message = request.into_message(); @@ -415,7 +415,7 @@ impl BaseNodeWalletService for BaseNodeWalletRpcMockService { ) -> Result, RpcStatus> { let delay_lock = *acquire_lock!(self.state.response_delay); if let Some(delay) = delay_lock { - delay_for(delay).await; + sleep(delay).await; } let message = request.into_message(); @@ -448,7 +448,7 @@ impl BaseNodeWalletService for BaseNodeWalletRpcMockService { async fn get_tip_info(&self, _request: Request<()>) -> Result, RpcStatus> { let delay_lock = *acquire_lock!(self.state.response_delay); if let Some(delay) = delay_lock { - delay_for(delay).await; + sleep(delay).await; } log::info!("Get tip info call received"); @@ -493,7 +493,7 @@ mod test { }; use tokio::time::Duration; - #[tokio_macros::test] + #[tokio::test] async fn test_wallet_rpc_mock() { let server_node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); let client_node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); diff --git a/base_layer/wallet/tests/transaction_service/service.rs b/base_layer/wallet/tests/transaction_service/service.rs index dc508f7321..f36557c8fa 100644 --- a/base_layer/wallet/tests/transaction_service/service.rs +++ b/base_layer/wallet/tests/transaction_service/service.rs @@ -33,7 +33,6 @@ use futures::{ channel::{mpsc, mpsc::Sender}, FutureExt, SinkExt, - StreamExt, }; use prost::Message; use rand::{rngs::OsRng, RngCore}; @@ -142,14 +141,13 @@ use tokio::{ runtime, runtime::{Builder, Runtime}, sync::{broadcast, broadcast::channel}, - time::delay_for, + time::sleep, }; fn create_runtime() -> Runtime { - Builder::new() - .threaded_scheduler() + Builder::new_multi_thread() .enable_all() - .core_threads(8) + .worker_threads(8) .build() .unwrap() } @@ -172,7 +170,8 @@ pub fn setup_transaction_service< discovery_request_timeout: Duration, shutdown_signal: ShutdownSignal, ) -> (TransactionServiceHandle, OutputManagerHandle, CommsNode) { - let (publisher, subscription_factory) = pubsub_connector(runtime.handle().clone(), 100, 20); + let _enter = runtime.enter(); + let (publisher, subscription_factory) = pubsub_connector(100, 20); let subscription_factory = Arc::new(subscription_factory); let (comms, dht) = runtime.block_on(setup_comms_services( node_identity, @@ -303,11 +302,15 @@ pub fn setup_transaction_service_no_comms_and_oms_backend< let protocol_name = server.as_protocol_name(); let server_node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); - let mut mock_server = runtime - .handle() - .enter(|| MockRpcServer::new(server, server_node_identity.clone())); + let mut mock_server = { + let _enter = runtime.handle().enter(); + MockRpcServer::new(server, server_node_identity.clone()) + }; - runtime.handle().enter(|| mock_server.serve()); + { + let _enter = runtime.handle().enter(); + mock_server.serve(); + } let connection = runtime.block_on(async { mock_server @@ -504,9 +507,9 @@ fn manage_single_transaction() { .block_on(alice_ts.set_base_node_public_key(base_node_identity.public_key().clone())) .unwrap(); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); - runtime.block_on(async { delay_for(Duration::from_secs(2)).await }); + runtime.block_on(async { sleep(Duration::from_secs(2)).await }); let (mut bob_ts, mut bob_oms, bob_comms) = setup_transaction_service( &mut runtime, @@ -524,7 +527,7 @@ fn 
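// Illustrative sketch (not part of the patch): the runtime plumbing changes used by the
// service.rs test helpers above. The builder moves from
// `Builder::new().threaded_scheduler().core_threads(n)` to
// `Builder::new_multi_thread().worker_threads(n)`, and `Handle::enter` now returns a
// guard that keeps the runtime context active until it is dropped, instead of taking a
// closure.
use tokio::runtime::{Builder, Runtime};

fn create_runtime() -> Runtime {
    Builder::new_multi_thread()
        .enable_all()
        .worker_threads(8)
        .build()
        .unwrap()
}

fn main() {
    let runtime = create_runtime();

    // Anything that needs the runtime context (spawning tasks, building mock servers,
    // creating timers) can run while the guard is alive.
    let join_handle = {
        let _enter = runtime.handle().enter();
        tokio::spawn(async { 40 + 2 })
    };

    assert_eq!(runtime.block_on(join_handle).unwrap(), 42);
}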
manage_single_transaction() { .block_on(bob_ts.set_base_node_public_key(base_node_identity.public_key().clone())) .unwrap(); - let mut bob_event_stream = bob_ts.get_event_stream_fused(); + let mut bob_event_stream = bob_ts.get_event_stream(); let _ = runtime.block_on( bob_comms @@ -556,18 +559,19 @@ fn manage_single_transaction() { .expect("Alice sending tx"); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(90)).fuse(); + let delay = sleep(Duration::from_secs(90)); + tokio::pin!(delay); let mut count = 0; loop { - futures::select! { - _event = alice_event_stream.select_next_some() => { + tokio::select! { + _event = alice_event_stream.recv() => { println!("alice: {:?}", &*_event.as_ref().unwrap()); count+=1; if count>=2 { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -576,18 +580,19 @@ fn manage_single_transaction() { let mut tx_id = 0u64; runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(90)).fuse(); + let delay = sleep(Duration::from_secs(90)); + tokio::pin!(delay); let mut finalized = 0; loop { - futures::select! { - event = bob_event_stream.select_next_some() => { + tokio::select! { + event = bob_event_stream.recv() => { println!("bob: {:?}", &*event.as_ref().unwrap()); if let TransactionEvent::ReceivedFinalizedTransaction(id) = &*event.unwrap() { tx_id = *id; finalized+=1; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -747,7 +752,7 @@ fn send_one_sided_transaction_to_other() { shutdown.to_signal(), ); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); runtime .block_on(alice_ts.set_base_node_public_key(base_node_identity.public_key().clone())) @@ -792,11 +797,12 @@ fn send_one_sided_transaction_to_other() { }); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let delay = sleep(Duration::from_secs(30)); + tokio::pin!(delay); let mut found = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! 
{ + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionCompletedImmediately(id) = &*event.unwrap() { if id == &tx_id { found = true; @@ -804,7 +810,7 @@ fn send_one_sided_transaction_to_other() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -1071,9 +1077,9 @@ fn manage_multiple_transactions() { Duration::from_secs(60), shutdown.to_signal(), ); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); // Spin up Bob and Carol let (mut bob_ts, mut bob_oms, bob_comms) = setup_transaction_service( @@ -1088,8 +1094,8 @@ fn manage_multiple_transactions() { Duration::from_secs(1), shutdown.to_signal(), ); - let mut bob_event_stream = bob_ts.get_event_stream_fused(); - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + let mut bob_event_stream = bob_ts.get_event_stream(); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); let (mut carol_ts, mut carol_oms, carol_comms) = setup_transaction_service( &mut runtime, @@ -1103,18 +1109,18 @@ fn manage_multiple_transactions() { Duration::from_secs(1), shutdown.to_signal(), ); - let mut carol_event_stream = carol_ts.get_event_stream_fused(); + let mut carol_event_stream = carol_ts.get_event_stream(); // Establish some connections beforehand, to reduce the amount of work done concurrently in tests // Connect Bob and Alice - runtime.block_on(async { delay_for(Duration::from_secs(3)).await }); + runtime.block_on(async { sleep(Duration::from_secs(3)).await }); let _ = runtime.block_on( bob_comms .connectivity() .dial_peer(alice_node_identity.node_id().clone()), ); - runtime.block_on(async { delay_for(Duration::from_secs(3)).await }); + runtime.block_on(async { sleep(Duration::from_secs(3)).await }); // Connect alice to carol let _ = runtime.block_on( @@ -1182,12 +1188,13 @@ fn manage_multiple_transactions() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(90)).fuse(); + let delay = sleep(Duration::from_secs(90)); + tokio::pin!(delay); let mut tx_reply = 0; let mut finalized = 0; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { match &*event.unwrap() { TransactionEvent::ReceivedTransactionReply(_) => tx_reply+=1, TransactionEvent::ReceivedFinalizedTransaction(_) => finalized+=1, @@ -1198,7 +1205,7 @@ fn manage_multiple_transactions() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -1210,12 +1217,14 @@ fn manage_multiple_transactions() { log::trace!("Alice received all Tx messages"); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(90)).fuse(); + let delay = sleep(Duration::from_secs(90)); + + tokio::pin!(delay); let mut tx_reply = 0; let mut finalized = 0; loop { - futures::select! { - event = bob_event_stream.select_next_some() => { + tokio::select! 
{ + event = bob_event_stream.recv() => { match &*event.unwrap() { TransactionEvent::ReceivedTransactionReply(_) => tx_reply+=1, TransactionEvent::ReceivedFinalizedTransaction(_) => finalized+=1, @@ -1225,7 +1234,7 @@ fn manage_multiple_transactions() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -1235,14 +1244,17 @@ fn manage_multiple_transactions() { }); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(90)).fuse(); + let delay = sleep(Duration::from_secs(90)); + tokio::pin!(delay); + + tokio::pin!(delay); let mut finalized = 0; loop { - futures::select! { - event = carol_event_stream.select_next_some() => { + tokio::select! { + event = carol_event_stream.recv() => { if let TransactionEvent::ReceivedFinalizedTransaction(_) = &*event.unwrap() { finalized+=1 } }, - () = delay => { + () = &mut delay => { break; }, } @@ -1264,7 +1276,7 @@ fn manage_multiple_transactions() { assert_eq!(carol_pending_inbound.len(), 0); assert_eq!(carol_completed_tx.len(), 1); - shutdown.trigger().unwrap(); + shutdown.trigger(); runtime.block_on(async move { alice_comms.wait_until_shutdown().await; bob_comms.wait_until_shutdown().await; @@ -1303,7 +1315,7 @@ fn test_accepting_unknown_tx_id_and_malformed_reply() { _, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), alice_backend, oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); let (_utxo, uo) = make_input(&mut OsRng, MicroTari(250000), &factories.commitment); @@ -1355,11 +1367,14 @@ fn test_accepting_unknown_tx_id_and_malformed_reply() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let delay = sleep(Duration::from_secs(30)); +tokio::pin!(delay); + + tokio::pin!(delay); let mut errors = 0; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { log::error!("ERROR: {:?}", event); if let TransactionEvent::Error(s) = &*event.unwrap() { if s == &"TransactionProtocolError(TransactionBuildError(InvalidSignatureError(\"Verifying kernel signature\")))".to_string() @@ -1371,7 +1386,7 @@ fn test_accepting_unknown_tx_id_and_malformed_reply() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -1415,7 +1430,7 @@ fn finalize_tx_with_incorrect_pubkey() { _, _, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), alice_backend, alice_oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); let bob_node_identity = NodeIdentity::random(&mut OsRng, get_next_memory_address(), PeerFeatures::COMMUNICATION_NODE); @@ -1488,15 +1503,18 @@ fn finalize_tx_with_incorrect_pubkey() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(15)).fuse(); + let delay = sleep(Duration::from_secs(15)); + tokio::pin!(delay); + + tokio::pin!(delay); loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! 
{ + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedFinalizedTransaction(_) = (*event.unwrap()).clone() { panic!("Should not have received finalized event!"); } }, - () = delay => { + () = &mut delay => { break; }, } @@ -1542,7 +1560,7 @@ fn finalize_tx_with_missing_output() { _, _, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), alice_backend, alice_oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); let bob_node_identity = NodeIdentity::random(&mut OsRng, get_next_memory_address(), PeerFeatures::COMMUNICATION_NODE); @@ -1623,15 +1641,18 @@ fn finalize_tx_with_missing_output() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(15)).fuse(); + let delay = sleep(Duration::from_secs(15)); + tokio::pin!(delay); + + tokio::pin!(delay); loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedFinalizedTransaction(_) = (*event.unwrap()).clone() { panic!("Should not have received finalized event"); } }, - () = delay => { + () = &mut delay => { break; }, } @@ -1648,8 +1669,7 @@ fn discovery_async_return_test() { let db_tempdir = tempdir().unwrap(); let db_folder = db_tempdir.path(); - let mut runtime = runtime::Builder::new() - .basic_scheduler() + let mut runtime = runtime::Builder::new_current_thread() .enable_time() .thread_name("discovery_async_return_test") .build() @@ -1714,7 +1734,7 @@ fn discovery_async_return_test() { Duration::from_secs(20), shutdown.to_signal(), ); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); let (_utxo, uo1a) = make_input(&mut OsRng, MicroTari(5500), &factories.commitment); runtime.block_on(alice_oms.add_output(uo1a)).unwrap(); @@ -1741,17 +1761,20 @@ fn discovery_async_return_test() { let mut txid = 0; let mut is_success = true; runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); + + tokio::pin!(delay); loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionDirectSendResult(tx_id, result) = (*event.unwrap()).clone() { txid = tx_id; is_success = result; break; } }, - () = delay => { + () = &mut delay => { panic!("Timeout while waiting for transaction to fail sending"); }, } @@ -1772,18 +1795,21 @@ fn discovery_async_return_test() { let mut success_result = false; let mut success_tx_id = 0u64; runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); + + tokio::pin!(delay); loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! 
{ + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionDirectSendResult(tx_id, success) = &*event.unwrap() { success_result = *success; success_tx_id = *tx_id; break; } }, - () = delay => { + () = &mut delay => { panic!("Timeout while waiting for transaction to successfully be sent"); }, } @@ -1794,24 +1820,26 @@ fn discovery_async_return_test() { assert!(success_result); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); + tokio::pin!(delay); loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedTransactionReply(tx_id) = &*event.unwrap() { if tx_id == &tx_id2 { break; } } }, - () = delay => { + () = &mut delay => { panic!("Timeout while Alice was waiting for a transaction reply"); }, } } }); - shutdown.trigger().unwrap(); + shutdown.trigger(); runtime.block_on(async move { alice_comms.wait_until_shutdown().await; carol_comms.wait_until_shutdown().await; @@ -2012,7 +2040,7 @@ fn test_transaction_cancellation() { ..Default::default() }), ); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); let alice_total_available = 250000 * uT; let (_utxo, uo) = make_input(&mut OsRng, alice_total_available, &factories.commitment); @@ -2030,15 +2058,17 @@ fn test_transaction_cancellation() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); + tokio::pin!(delay); loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionStoreForwardSendResult(_,_) = &*event.unwrap() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2054,7 +2084,7 @@ fn test_transaction_cancellation() { None => (), Some(_) => break, } - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); if i >= 12 { panic!("Pending outbound transaction should have been added by now"); } @@ -2066,17 +2096,18 @@ fn test_transaction_cancellation() { // Wait for cancellation event, in an effort to nail down where the issue is for the flakey CI test runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut cancelled = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionCancelled(_) = &*event.unwrap() { cancelled = true; break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2143,15 +2174,16 @@ fn test_transaction_cancellation() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! 
{ + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedTransaction(_) = &*event.unwrap() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2213,15 +2245,16 @@ fn test_transaction_cancellation() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedTransaction(_) = &*event.unwrap() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2243,7 +2276,7 @@ fn test_transaction_cancellation() { ))) .unwrap(); - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); runtime .block_on(alice_ts.get_pending_inbound_transactions()) @@ -2257,17 +2290,18 @@ fn test_transaction_cancellation() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let delay = sleep(Duration::from_secs(30)).fuse(); + tokio::pin!(delay); let mut cancelled = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionCancelled(_) = &*event.unwrap() { cancelled = true; break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2386,7 +2420,7 @@ fn test_direct_vs_saf_send_of_tx_reply_and_finalize() { .try_into() .unwrap(); - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); assert_eq!(bob_outbound_service.call_count(), 0, "Should be no more calls"); let (_wallet_backend, backend, oms_backend, _, _temp_dir) = make_wallet_databases(None); @@ -2428,7 +2462,7 @@ fn test_direct_vs_saf_send_of_tx_reply_and_finalize() { .try_into() .unwrap(); - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); assert_eq!(bob2_outbound_service.call_count(), 0, "Should be no more calls"); // Test finalize is sent Direct Only. 
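Note on the pattern repeated throughout the hunks above: `delay_for(..).fuse()` with `futures::select!` becomes `sleep(..)`, `tokio::pin!` and `tokio::select!`, with the timer polled by `&mut` reference. A minimal, self-contained sketch of that loop under tokio 1.x (the `Event` type, channel and timings here are illustrative, not taken from the wallet code):

```rust
use std::time::Duration;
use tokio::{sync::broadcast, time::sleep};

#[derive(Clone, Debug)]
struct Event(u64);

#[tokio::main]
async fn main() {
    let (tx, mut rx) = broadcast::channel::<Event>(16);
    tx.send(Event(1)).unwrap();

    // `Sleep` is !Unpin in tokio 1.x, so pin it once and poll it by
    // reference on every iteration instead of recreating it.
    let delay = sleep(Duration::from_secs(1));
    tokio::pin!(delay);

    loop {
        tokio::select! {
            // `recv()` returns Result<Event, RecvError>; a closed or
            // lagged channel shows up as an error, not as end-of-stream.
            event = rx.recv() => match event {
                Ok(ev) => println!("got {:?}", ev),
                Err(_) => break,
            },
            // The same pinned timer is polled each time round the loop,
            // mirroring the `() = &mut delay` arms in the tests.
            () = &mut delay => {
                println!("timed out waiting for events");
                break;
            },
        }
    }
}
```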
@@ -2449,7 +2483,7 @@ fn test_direct_vs_saf_send_of_tx_reply_and_finalize() { let _ = alice_outbound_service.pop_call().unwrap(); let _ = alice_outbound_service.pop_call().unwrap(); - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); assert_eq!(alice_outbound_service.call_count(), 0, "Should be no more calls"); // Now to repeat sending so we can test the SAF send of the finalize message @@ -2520,7 +2554,7 @@ fn test_direct_vs_saf_send_of_tx_reply_and_finalize() { assert_eq!(alice_outbound_service.call_count(), 1); let _ = alice_outbound_service.pop_call(); - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); assert_eq!(alice_outbound_service.call_count(), 0, "Should be no more calls2"); } @@ -2548,7 +2582,7 @@ fn test_tx_direct_send_behaviour() { _, _, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), backend, oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); let (_utxo, uo) = make_input(&mut OsRng, 1000000 * uT, &factories.commitment); runtime.block_on(alice_output_manager.add_output(uo)).unwrap(); @@ -2576,12 +2610,13 @@ fn test_tx_direct_send_behaviour() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); +tokio::pin!(delay); let mut direct_count = 0; let mut saf_count = 0; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { match &*event.unwrap() { TransactionEvent::TransactionDirectSendResult(_, result) => if !result { direct_count+=1 }, TransactionEvent::TransactionStoreForwardSendResult(_, result) => if !result { saf_count+=1}, _ => (), @@ -2591,7 +2626,7 @@ fn test_tx_direct_send_behaviour() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2619,12 +2654,13 @@ fn test_tx_direct_send_behaviour() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); +tokio::pin!(delay); let mut direct_count = 0; let mut saf_count = 0; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { match &*event.unwrap() { TransactionEvent::TransactionDirectSendResult(_, result) => if !result { direct_count+=1 }, TransactionEvent::TransactionStoreForwardSendResult(_, result) => if *result { saf_count+=1 @@ -2635,7 +2671,7 @@ fn test_tx_direct_send_behaviour() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2663,11 +2699,12 @@ fn test_tx_direct_send_behaviour() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); +tokio::pin!(delay); let mut direct_count = 0; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! 
{ + event = alice_event_stream.recv() => { match &*event.unwrap() { TransactionEvent::TransactionDirectSendResult(_, result) => if *result { direct_count+=1 }, TransactionEvent::TransactionStoreForwardSendResult(_, _) => panic!("Should be no SAF messages"), @@ -2678,7 +2715,7 @@ fn test_tx_direct_send_behaviour() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2705,11 +2742,12 @@ fn test_tx_direct_send_behaviour() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); +tokio::pin!(delay); let mut saf_count = 0; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { match &*event.unwrap() { TransactionEvent::TransactionStoreForwardSendResult(_, result) => if *result { saf_count+=1 }, TransactionEvent::TransactionDirectSendResult(_, result) => if *result { panic!( @@ -2720,7 +2758,7 @@ fn test_tx_direct_send_behaviour() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2852,7 +2890,7 @@ fn test_restarting_transaction_protocols() { // Test that Bob's node restarts the send protocol let (mut bob_ts, _bob_oms, _bob_outbound_service, _, _, mut bob_tx_reply, _, _, _, _shutdown, _, _, _) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), bob_backend, bob_oms_backend, None); - let mut bob_event_stream = bob_ts.get_event_stream_fused(); + let mut bob_event_stream = bob_ts.get_event_stream(); runtime .block_on(bob_ts.set_base_node_public_key(base_node_identity.public_key().clone())) @@ -2864,18 +2902,19 @@ fn test_restarting_transaction_protocols() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(15)).fuse(); + let delay = sleep(Duration::from_secs(15)); + tokio::pin!(delay); let mut received_reply = false; loop { - futures::select! { - event = bob_event_stream.select_next_some() => { + tokio::select! { + event = bob_event_stream.recv() => { if let TransactionEvent::ReceivedTransactionReply(id) = (*event.unwrap()).clone() { assert_eq!(id, tx_id); received_reply = true; break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2886,7 +2925,7 @@ fn test_restarting_transaction_protocols() { // Test Alice's node restarts the receive protocol let (mut alice_ts, _alice_oms, _alice_outbound_service, _, _, _, mut alice_tx_finalized, _, _, _shutdown, _, _, _) = setup_transaction_service_no_comms(&mut runtime, factories, alice_backend, alice_oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); runtime .block_on(alice_ts.set_base_node_public_key(base_node_identity.public_key().clone())) @@ -2906,18 +2945,19 @@ fn test_restarting_transaction_protocols() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(15)).fuse(); + let delay = sleep(Duration::from_secs(15)); + tokio::pin!(delay); let mut received_finalized = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! 
{ + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedFinalizedTransaction(id) = (*event.unwrap()).clone() { assert_eq!(id, tx_id); received_finalized = true; break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -3046,7 +3086,7 @@ fn test_coinbase_monitoring_stuck_in_mempool() { server_node_identity, mut rpc_service_state, ) = setup_transaction_service_no_comms(&mut runtime, factories, backend, oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); rpc_service_state.set_response_delay(Some(Duration::from_secs(1))); let block_height_a = 10; @@ -3131,11 +3171,12 @@ fn test_coinbase_monitoring_stuck_in_mempool() { println!(" {}", e) } runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let delay = sleep(Duration::from_secs(30)); + tokio::pin!(delay); let mut count = 0usize; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedFinalizedTransaction(tx_id) = &*event.unwrap() { if tx_id == &tx_id1 || tx_id == &tx_id2 { count += 1; @@ -3145,7 +3186,7 @@ fn test_coinbase_monitoring_stuck_in_mempool() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -3170,11 +3211,12 @@ fn test_coinbase_monitoring_stuck_in_mempool() { println!(" {}", e) } runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let delay = sleep(Duration::from_secs(30)); + tokio::pin!(delay); let mut count = 0usize; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionCancelled(tx_id) = &*event.unwrap() { if tx_id == &tx_id1 || tx_id == &tx_id2 { count += 1; @@ -3184,7 +3226,7 @@ fn test_coinbase_monitoring_stuck_in_mempool() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -3215,7 +3257,7 @@ fn test_coinbase_monitoring_with_base_node_change_and_mined() { server_node_identity, mut rpc_service_state, ) = setup_transaction_service_no_comms(&mut runtime, factories, backend, oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); rpc_service_state.set_response_delay(Some(Duration::from_secs(1))); let block_height_a = 10; @@ -3301,11 +3343,12 @@ fn test_coinbase_monitoring_with_base_node_change_and_mined() { println!(" {}", e) } runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let delay = sleep(Duration::from_secs(30)); + tokio::pin!(delay); let mut count = 0usize; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! 
{ + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionMinedUnconfirmed(tx_id, _) = &*event.unwrap() { if tx_id == &tx_id1 || tx_id == &tx_id2 { count += 1; @@ -3316,7 +3359,7 @@ fn test_coinbase_monitoring_with_base_node_change_and_mined() { } }, - () = delay => { + () = &mut delay => { break; }, } @@ -3331,10 +3374,14 @@ fn test_coinbase_monitoring_with_base_node_change_and_mined() { rpc_service_state.set_response_delay(Some(Duration::from_secs(1))); let new_server = BaseNodeWalletRpcServer::new(service); let protocol_name = new_server.as_protocol_name(); - let mut new_mock_server = runtime - .handle() - .enter(|| MockRpcServer::new(new_server, new_server_node_identity.clone())); - runtime.handle().enter(|| new_mock_server.serve()); + let mut new_mock_server = { + let _enter = runtime.handle().enter(); + MockRpcServer::new(new_server, new_server_node_identity.clone()) + }; + { + let _enter = runtime.handle().enter(); + new_mock_server.serve(); + } let connection = runtime.block_on(new_mock_server.create_connection(new_server_node_identity.to_peer(), protocol_name.into())); runtime.block_on(connectivity_mock_state.add_active_connection(connection)); @@ -3368,11 +3415,12 @@ fn test_coinbase_monitoring_with_base_node_change_and_mined() { println!(" {}", e) } runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let delay = sleep(Duration::from_secs(30)); + tokio::pin!(delay); let mut count = 0usize; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionMined(tx_id) = &*event.unwrap() { if tx_id == &tx_id1 || tx_id == &tx_id2 { count += 1; @@ -3382,7 +3430,7 @@ fn test_coinbase_monitoring_with_base_node_change_and_mined() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -3413,7 +3461,7 @@ fn test_coinbase_monitoring_mined_not_synced() { server_node_identity, mut rpc_service_state, ) = setup_transaction_service_no_comms(&mut runtime, factories, backend, oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); rpc_service_state.set_response_delay(Some(Duration::from_secs(1))); let block_height_a = 10; @@ -3499,11 +3547,12 @@ fn test_coinbase_monitoring_mined_not_synced() { println!(" {}", e) } runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let delay = sleep(Duration::from_secs(30)); + tokio::pin!(delay); let mut count = 0usize; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedFinalizedTransaction(tx_id) = &*event.unwrap() { if tx_id == &tx_id1 || tx_id == &tx_id2 { count += 1; @@ -3513,7 +3562,7 @@ fn test_coinbase_monitoring_mined_not_synced() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -3538,11 +3587,12 @@ fn test_coinbase_monitoring_mined_not_synced() { println!(" {}", e) } runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let delay = sleep(Duration::from_secs(30)); + tokio::pin!(delay); let mut count = 0usize; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! 
{ + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionMined(tx_id) = &*event.unwrap() { if tx_id == &tx_id1 || tx_id == &tx_id2 { count += 1; @@ -3552,7 +3602,7 @@ fn test_coinbase_monitoring_mined_not_synced() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -3760,7 +3810,7 @@ fn test_transaction_resending() { assert_eq!(bob_reply_message.tx_id, tx_id); } - runtime.block_on(async { delay_for(Duration::from_secs(2)).await }); + runtime.block_on(async { sleep(Duration::from_secs(2)).await }); // See if sending a second message too soon is ignored runtime .block_on(bob_tx_sender.send(create_dummy_message( @@ -3772,7 +3822,7 @@ fn test_transaction_resending() { assert!(bob_outbound_service.wait_call_count(1, Duration::from_secs(2)).is_err()); // Wait for the cooldown to expire but before the resend period has elapsed see if a repeat illicts a reponse. - runtime.block_on(async { delay_for(Duration::from_secs(2)).await }); + runtime.block_on(async { sleep(Duration::from_secs(2)).await }); runtime .block_on(bob_tx_sender.send(create_dummy_message( alice_sender_message.into(), @@ -3819,7 +3869,7 @@ fn test_transaction_resending() { .is_err()); // Wait for the cooldown to expire but before the resend period has elapsed see if a repeat illicts a reponse. - runtime.block_on(async { delay_for(Duration::from_secs(2)).await }); + runtime.block_on(async { sleep(Duration::from_secs(2)).await }); runtime .block_on(alice_tx_reply_sender.send(create_dummy_message( @@ -4143,7 +4193,7 @@ fn test_replying_to_cancelled_tx() { assert_eq!(data.tx_id, tx_id); } // Need a moment for Alice's wallet to finish writing to its database before cancelling - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); runtime.block_on(alice_ts.cancel_transaction(tx_id)).unwrap(); @@ -4193,7 +4243,7 @@ fn test_replying_to_cancelled_tx() { assert_eq!(bob_reply_message.tx_id, tx_id); // Wait for cooldown to expire - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); let _ = alice_outbound_service.take_calls(); @@ -4406,7 +4456,7 @@ fn test_transaction_timeout_cancellation() { ..Default::default() }), ); - let mut carol_event_stream = carol_ts.get_event_stream_fused(); + let mut carol_event_stream = carol_ts.get_event_stream(); runtime .block_on(carol_tx_sender.send(create_dummy_message( @@ -4431,11 +4481,12 @@ fn test_transaction_timeout_cancellation() { assert_eq!(carol_reply_message.tx_id, tx_id); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut transaction_cancelled = false; loop { - futures::select! { - event = carol_event_stream.select_next_some() => { + tokio::select! 
{ + event = carol_event_stream.recv() => { if let TransactionEvent::TransactionCancelled(t) = &*event.unwrap() { if t == &tx_id { transaction_cancelled = true; @@ -4443,7 +4494,7 @@ fn test_transaction_timeout_cancellation() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -4481,7 +4532,7 @@ fn transaction_service_tx_broadcast() { server_node_identity, rpc_service_state, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), backend, oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); runtime .block_on(alice_ts.set_base_node_public_key(server_node_identity.public_key().clone())) @@ -4609,11 +4660,12 @@ fn transaction_service_tx_broadcast() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut tx1_received = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedTransactionReply(tx_id) = &*event.unwrap(){ if tx_id == &tx_id1 { tx1_received = true; @@ -4621,7 +4673,7 @@ fn transaction_service_tx_broadcast() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -4655,11 +4707,12 @@ fn transaction_service_tx_broadcast() { }); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut tx1_mined = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionMined(tx_id) = &*event.unwrap(){ if tx_id == &tx_id1 { tx1_mined = true; @@ -4667,7 +4720,7 @@ fn transaction_service_tx_broadcast() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -4683,11 +4736,12 @@ fn transaction_service_tx_broadcast() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut tx2_received = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedTransactionReply(tx_id) = &*event.unwrap(){ if tx_id == &tx_id2 { tx2_received = true; @@ -4695,7 +4749,7 @@ fn transaction_service_tx_broadcast() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -4733,11 +4787,12 @@ fn transaction_service_tx_broadcast() { }); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut tx2_cancelled = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! 
{ + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionCancelled(tx_id) = &*event.unwrap(){ if tx_id == &tx_id2 { tx2_cancelled = true; @@ -4745,7 +4800,7 @@ fn transaction_service_tx_broadcast() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -4851,15 +4906,16 @@ fn broadcast_all_completed_transactions_on_startup() { assert!(runtime.block_on(alice_ts.restart_broadcast_protocols()).is_ok()); - let mut event_stream = alice_ts.get_event_stream_fused(); + let mut event_stream = alice_ts.get_event_stream(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut found1 = false; let mut found2 = false; let mut found3 = false; loop { - futures::select! { - event = event_stream.select_next_some() => { + tokio::select! { + event = event_stream.recv() => { if let TransactionEvent::TransactionBroadcast(tx_id) = (*event.unwrap()).clone() { if tx_id == 1u64 { found1 = true @@ -4876,7 +4932,7 @@ fn broadcast_all_completed_transactions_on_startup() { } }, - () = delay => { + () = &mut delay => { break; }, } @@ -4916,7 +4972,7 @@ fn transaction_service_tx_broadcast_with_base_node_change() { server_node_identity, rpc_service_state, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), backend, oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); runtime .block_on(alice_ts.set_base_node_public_key(server_node_identity.public_key().clone())) @@ -4995,11 +5051,12 @@ fn transaction_service_tx_broadcast_with_base_node_change() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut tx1_received = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedTransactionReply(tx_id) = &*event.unwrap(){ if tx_id == &tx_id1 { tx1_received = true; @@ -5007,7 +5064,7 @@ fn transaction_service_tx_broadcast_with_base_node_change() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -5038,11 +5095,15 @@ fn transaction_service_tx_broadcast_with_base_node_change() { let new_server = BaseNodeWalletRpcServer::new(service); let protocol_name = new_server.as_protocol_name(); - let mut new_mock_server = runtime - .handle() - .enter(|| MockRpcServer::new(new_server, new_server_node_identity.clone())); + let mut new_mock_server = { + let _enter = runtime.handle().enter(); + MockRpcServer::new(new_server, new_server_node_identity.clone()) + }; - runtime.handle().enter(|| new_mock_server.serve()); + { + let _enter = runtime.handle().enter(); + new_mock_server.serve(); + } let connection = runtime.block_on(new_mock_server.create_connection(new_server_node_identity.to_peer(), protocol_name.into())); @@ -5075,17 +5136,18 @@ fn transaction_service_tx_broadcast_with_base_node_change() { }); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut tx_mined = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! 
{ + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionMined(_) = &*event.unwrap(){ tx_mined = true; break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -5350,11 +5412,15 @@ fn start_validation_protocol_then_broadcast_protocol_change_base_node() { let new_server = BaseNodeWalletRpcServer::new(service); let protocol_name = new_server.as_protocol_name(); - let mut new_mock_server = runtime - .handle() - .enter(|| MockRpcServer::new(new_server, new_server_node_identity.clone())); + let mut new_mock_server = { + let _enter = runtime.handle().enter(); + MockRpcServer::new(new_server, new_server_node_identity.clone()) + }; - runtime.handle().enter(|| new_mock_server.serve()); + { + let _enter = runtime.handle().enter(); + new_mock_server.serve(); + } let connection = runtime.block_on(new_mock_server.create_connection(new_server_node_identity.to_peer(), protocol_name.into())); diff --git a/base_layer/wallet/tests/transaction_service/storage.rs b/base_layer/wallet/tests/transaction_service/storage.rs index 6ea420e6a4..5a2e251403 100644 --- a/base_layer/wallet/tests/transaction_service/storage.rs +++ b/base_layer/wallet/tests/transaction_service/storage.rs @@ -60,7 +60,7 @@ use tempfile::tempdir; use tokio::runtime::Runtime; pub fn test_db_backend(backend: T) { - let mut runtime = Runtime::new().unwrap(); + let runtime = Runtime::new().unwrap(); let mut db = TransactionDatabase::new(backend); let factories = CryptoFactories::default(); let input = create_unblinded_output( diff --git a/base_layer/wallet/tests/transaction_service/transaction_protocols.rs b/base_layer/wallet/tests/transaction_service/transaction_protocols.rs index 66079613ca..7379c682ab 100644 --- a/base_layer/wallet/tests/transaction_service/transaction_protocols.rs +++ b/base_layer/wallet/tests/transaction_service/transaction_protocols.rs @@ -25,9 +25,9 @@ use crate::support::{ utils::make_input, }; use chrono::Utc; -use futures::{FutureExt, StreamExt}; +use futures::StreamExt; use rand::rngs::OsRng; -use std::{sync::Arc, thread::sleep, time::Duration}; +use std::{sync::Arc, time::Duration}; use tari_comms::{ peer_manager::PeerFeatures, protocol::rpc::{mock::MockRpcServer, NamedProtocolService, RpcStatus}, @@ -80,7 +80,7 @@ use tari_wallet::{ types::ValidationRetryStrategy, }; use tempfile::{tempdir, TempDir}; -use tokio::{sync::broadcast, task, time::delay_for}; +use tokio::{sync::broadcast, task, time::sleep}; // Just in case other options become apparent in later testing #[derive(PartialEq)] @@ -230,7 +230,7 @@ pub async fn oms_reply_channel_task( } /// A happy path test by submitting a transaction into the mempool, have it mined but unconfirmed and then confirmed. 
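The mock RPC server hunks above (and earlier in this file) replace tokio 0.2's closure-based `runtime.handle().enter(|| ...)` with the RAII guard returned by `Handle::enter` in tokio 1.x. A small sketch of the guard pattern, assuming tokio 1.x with the multi-threaded runtime; the spawned task is illustrative:

```rust
use tokio::runtime::Runtime;

fn main() {
    let runtime = Runtime::new().expect("failed to build runtime");

    // tokio 0.2: runtime.handle().enter(|| tokio::spawn(...));
    // tokio 1.x: the returned guard keeps this thread inside the runtime
    // context until it is dropped at the end of the block.
    let join = {
        let _enter = runtime.handle().enter();
        // While the guard is alive, code that needs a runtime context
        // (spawning tasks, building timers, serving RPC mocks) works even
        // though this is a plain, non-runtime thread.
        tokio::spawn(async { 21 + 21 })
    };

    // The guard has been dropped; drive the spawned task as usual.
    assert_eq!(runtime.block_on(join).unwrap(), 42);
}
```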
-#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_submit_success_i() { let ( @@ -245,7 +245,7 @@ async fn tx_broadcast_protocol_submit_success_i() { _temp_dir, mut transaction_event_receiver, ) = setup(TxProtocolTestConfig::WithConnection).await; - let mut event_stream = resources.event_publisher.subscribe().fuse(); + let mut event_stream = resources.event_publisher.subscribe(); let (base_node_update_publisher, _) = broadcast::channel(20); let protocol = TransactionBroadcastProtocol::new( @@ -353,7 +353,8 @@ async fn tx_broadcast_protocol_submit_success_i() { .unwrap(); // lets wait for the transaction service event to notify us of a confirmed tx // We need to do this to ensure that the wallet db has been updated to "Mined" - while let Some(v) = transaction_event_receiver.next().await { + loop { + let v = transaction_event_receiver.recv().await; let event = v.unwrap(); match (*event).clone() { TransactionEvent::TransactionMined(_) => { @@ -392,13 +393,14 @@ async fn tx_broadcast_protocol_submit_success_i() { ); // Check that the appropriate events were emitted - let mut delay = delay_for(Duration::from_secs(5)).fuse(); + let delay = sleep(Duration::from_secs(5)); + tokio::pin!(delay); let mut broadcast = false; let mut unconfirmed = false; let mut confirmed = false; loop { - futures::select! { - event = event_stream.select_next_some() => { + tokio::select! { + event = event_stream.recv() => { match &*event.unwrap() { TransactionEvent::TransactionMinedUnconfirmed(_, confirmations) => if *confirmations == 1 { unconfirmed = true; @@ -412,7 +414,7 @@ async fn tx_broadcast_protocol_submit_success_i() { _ => (), } }, - () = delay => { + () = &mut delay => { break; }, } @@ -426,7 +428,7 @@ async fn tx_broadcast_protocol_submit_success_i() { } /// Test submitting a transaction that is immediately rejected -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_submit_rejection() { let ( @@ -441,7 +443,7 @@ async fn tx_broadcast_protocol_submit_rejection() { _temp_dir, _transaction_event_receiver, ) = setup(TxProtocolTestConfig::WithConnection).await; - let mut event_stream = resources.event_publisher.subscribe().fuse(); + let mut event_stream = resources.event_publisher.subscribe(); let (base_node_update_publisher, _) = broadcast::channel(20); add_transaction_to_database(1, 1 * T, true, None, resources.db.clone()).await; @@ -478,16 +480,17 @@ async fn tx_broadcast_protocol_submit_rejection() { assert!(db_completed_tx.is_err()); // Check that the appropriate events were emitted - let mut delay = delay_for(Duration::from_secs(1)).fuse(); + let delay = sleep(Duration::from_secs(1)); + tokio::pin!(delay); let mut cancelled = false; loop { - futures::select! { - event = event_stream.select_next_some() => { + tokio::select! 
{ + event = event_stream.recv() => { if let TransactionEvent::TransactionCancelled(_) = &*event.unwrap() { cancelled = true; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -498,7 +501,7 @@ async fn tx_broadcast_protocol_submit_rejection() { /// Test restarting a protocol which means the first step is a query not a submission, detecting the Tx is not in the /// mempool, resubmit the tx and then have it mined -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_restart_protocol_as_query() { let ( @@ -585,7 +588,7 @@ async fn tx_broadcast_protocol_restart_protocol_as_query() { /// This test will submit a Tx which will be accepted and then dropped from the mempool, resulting in a resubmit which /// will be rejected and result in a cancelled transaction -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_submit_success_followed_by_rejection() { let ( @@ -600,7 +603,7 @@ async fn tx_broadcast_protocol_submit_success_followed_by_rejection() { _temp_dir, _transaction_event_receiver, ) = setup(TxProtocolTestConfig::WithConnection).await; - let mut event_stream = resources.event_publisher.subscribe().fuse(); + let mut event_stream = resources.event_publisher.subscribe(); let (base_node_update_publisher, _) = broadcast::channel(20); add_transaction_to_database(1, 1 * T, true, None, resources.db.clone()).await; @@ -666,16 +669,17 @@ async fn tx_broadcast_protocol_submit_success_followed_by_rejection() { assert!(db_completed_tx.is_err()); // Check that the appropriate events were emitted - let mut delay = delay_for(Duration::from_secs(1)).fuse(); + let delay = sleep(Duration::from_secs(1)); + tokio::pin!(delay); let mut cancelled = false; loop { - futures::select! { - event = event_stream.select_next_some() => { + tokio::select! { + event = event_stream.recv() => { if let TransactionEvent::TransactionCancelled(_) = &*event.unwrap() { cancelled = true; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -686,7 +690,7 @@ async fn tx_broadcast_protocol_submit_success_followed_by_rejection() { /// This test will submit a tx which is accepted and mined but unconfirmed, then the next query it will not exist /// resulting in a resubmission which we will let run to being mined with success -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_submit_mined_then_not_mined_resubmit_success() { let ( @@ -751,14 +755,14 @@ async fn tx_broadcast_protocol_submit_mined_then_not_mined_resubmit_success() { // Wait for the "TransactionMinedUnconfirmed" tx event to ensure that the wallet db state is "MinedUnconfirmed" let mut count = 0u16; - while let Some(v) = transaction_event_receiver.next().await { + loop { + let v = transaction_event_receiver.recv().await; let event = v.unwrap(); match (*event).clone() { TransactionEvent::TransactionMinedUnconfirmed(_, _) => { break; }, _ => { - sleep(Duration::from_millis(1000)); count += 1; if count >= 10 { break; @@ -806,7 +810,7 @@ async fn tx_broadcast_protocol_submit_mined_then_not_mined_resubmit_success() { } /// Test being unable to connect and then connection becoming available. 
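The hunk above also swaps the `StreamExt`-style `while let Some(v) = receiver.next().await` for an explicit `recv()` loop, since a tokio 1.x broadcast `Receiver` is not a `Stream` without the `tokio-stream` wrapper. A sketch of that consumption pattern in a test, assuming tokio's `macros`, `rt` and `sync` features; the `String` payload is illustrative:

```rust
use tokio::sync::broadcast;

// Drain a broadcast receiver with `recv()` until the channel is closed.
async fn drain(mut rx: broadcast::Receiver<String>) -> Vec<String> {
    let mut seen = Vec::new();
    loop {
        match rx.recv().await {
            Ok(msg) => seen.push(msg),
            // Closed (all senders dropped) or Lagged ends the loop here;
            // a real handler might want to treat Lagged differently.
            Err(_) => break,
        }
    }
    seen
}

#[tokio::test]
async fn drains_all_buffered_events() {
    let (tx, rx) = broadcast::channel(8);
    tx.send("one".to_string()).unwrap();
    tx.send("two".to_string()).unwrap();
    // Dropping the sender closes the channel, so `drain` terminates after
    // the buffered values have been delivered.
    drop(tx);
    assert_eq!(drain(rx).await, vec!["one".to_string(), "two".to_string()]);
}
```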
-#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_connection_problem() { let ( @@ -823,7 +827,7 @@ async fn tx_broadcast_protocol_connection_problem() { ) = setup(TxProtocolTestConfig::WithoutConnection).await; let (base_node_update_publisher, _) = broadcast::channel(20); - let mut event_stream = resources.event_publisher.subscribe().fuse(); + let mut event_stream = resources.event_publisher.subscribe(); add_transaction_to_database(1, 1 * T, true, None, resources.db.clone()).await; @@ -839,11 +843,12 @@ async fn tx_broadcast_protocol_connection_problem() { let join_handle = task::spawn(protocol.execute()); // Check that the connection problem event was emitted at least twice - let mut delay = delay_for(Duration::from_secs(10)).fuse(); + let delay = sleep(Duration::from_secs(10)); + tokio::pin!(delay); let mut connection_issues = 0; loop { - futures::select! { - event = event_stream.select_next_some() => { + tokio::select! { + event = event_stream.recv() => { if let TransactionEvent::TransactionBaseNodeConnectionProblem(_) = &*event.unwrap() { connection_issues+=1; } @@ -851,7 +856,7 @@ async fn tx_broadcast_protocol_connection_problem() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -878,7 +883,7 @@ async fn tx_broadcast_protocol_connection_problem() { } /// Submit a transaction that is Already Mined for the submission, the subsequent query should confirm the transaction -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_submit_already_mined() { let ( @@ -948,7 +953,7 @@ async fn tx_broadcast_protocol_submit_already_mined() { } /// A test to see that the broadcast protocol can handle a change to the base node address while it runs. -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_submit_and_base_node_gets_changed() { let ( @@ -1050,7 +1055,7 @@ async fn tx_broadcast_protocol_submit_and_base_node_gets_changed() { /// Validate completed transactions, will check that valid ones stay valid and incorrectly marked invalid tx become /// valid. 
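For reference, the runtime construction changes made earlier in this diff (`threaded_scheduler`/`core_threads` and `basic_scheduler` going away) follow the tokio 1.x builder API sketched below; thread names and counts are illustrative, and the `full` feature set is assumed:

```rust
use tokio::runtime::{Builder, Runtime};

// Multi-threaded runtime: `new_multi_thread` + `worker_threads` replace
// `threaded_scheduler` + `core_threads` from tokio 0.2.
fn multi_thread_runtime() -> Runtime {
    Builder::new_multi_thread()
        .enable_all()
        .worker_threads(8)
        .build()
        .expect("failed to build multi-threaded runtime")
}

// Current-thread runtime: `new_current_thread` replaces `basic_scheduler`.
fn current_thread_runtime() -> Runtime {
    Builder::new_current_thread()
        .enable_time()
        .thread_name("example-runtime")
        .build()
        .expect("failed to build current-thread runtime")
}

fn main() {
    multi_thread_runtime().block_on(async { println!("multi-threaded") });
    current_thread_runtime().block_on(async { println!("current-thread") });
}
```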
-#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_validation_protocol_tx_becomes_valid() { let ( @@ -1148,7 +1153,7 @@ async fn tx_validation_protocol_tx_becomes_valid() { } /// Validate completed transaction, the transaction should become invalid -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_validation_protocol_tx_becomes_invalid() { let ( @@ -1213,7 +1218,7 @@ async fn tx_validation_protocol_tx_becomes_invalid() { } /// Validate completed transactions, the transaction should become invalid -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_validation_protocol_tx_becomes_unconfirmed() { let ( @@ -1285,7 +1290,7 @@ async fn tx_validation_protocol_tx_becomes_unconfirmed() { /// Test the validation protocol reacts correctly to a change in base node and redoes the full validation based on the /// new base node -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_validation_protocol_tx_ends_on_base_node_end() { let ( @@ -1302,7 +1307,7 @@ async fn tx_validation_protocol_tx_ends_on_base_node_end() { ) = setup(TxProtocolTestConfig::WithConnection).await; let (base_node_update_publisher, _) = broadcast::channel(20); let (_timeout_update_publisher, _) = broadcast::channel(20); - let mut event_stream = resources.event_publisher.subscribe().fuse(); + let mut event_stream = resources.event_publisher.subscribe(); add_transaction_to_database( 1, @@ -1398,16 +1403,17 @@ async fn tx_validation_protocol_tx_ends_on_base_node_end() { let result = join_handle.await.unwrap(); assert!(result.is_ok()); - let mut delay = delay_for(Duration::from_secs(1)).fuse(); + let delay = sleep(Duration::from_secs(1)); + tokio::pin!(delay); let mut aborted = false; loop { - futures::select! { - event = event_stream.select_next_some() => { + tokio::select! { + event = event_stream.recv() => { if let TransactionEvent::TransactionValidationAborted(_) = &*event.unwrap() { aborted = true; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -1416,7 +1422,7 @@ async fn tx_validation_protocol_tx_ends_on_base_node_end() { } /// Test the validation protocol reacts correctly when the RPC client returns an error between calls. 
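The attribute swap that repeats through these test files is the move from the separately named `#[tokio_macros::test]` to the `#[tokio::test]` macro shipped with tokio 1.x. A minimal sketch, assuming the `macros`, `rt-multi-thread` and `time` features; both test bodies are illustrative:

```rust
use std::time::Duration;

// By default #[tokio::test] builds a single-threaded (current-thread)
// runtime for each test.
#[tokio::test]
async fn runs_on_a_current_thread_runtime() {
    tokio::time::sleep(Duration::from_millis(1)).await;
}

// A multi-threaded test runtime is opted into explicitly.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn runs_on_a_multi_threaded_runtime() {
    let handle = tokio::spawn(async { 2 + 2 });
    assert_eq!(handle.await.unwrap(), 4);
}
```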
-#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_validation_protocol_rpc_client_broken_between_calls() { let ( @@ -1540,7 +1546,7 @@ async fn tx_validation_protocol_rpc_client_broken_between_calls() { /// Test the validation protocol reacts correctly when the RPC client returns an error between calls and only retry /// finite amount of times -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_validation_protocol_rpc_client_broken_finite_retries() { let ( @@ -1557,7 +1563,7 @@ async fn tx_validation_protocol_rpc_client_broken_finite_retries() { ) = setup(TxProtocolTestConfig::WithConnection).await; let (base_node_update_publisher, _) = broadcast::channel(20); let (_timeout_update_publisher, _) = broadcast::channel(20); - let mut event_stream = resources.event_publisher.subscribe().fuse(); + let mut event_stream = resources.event_publisher.subscribe(); add_transaction_to_database( 1, 1 * T, @@ -1610,12 +1616,13 @@ async fn tx_validation_protocol_rpc_client_broken_finite_retries() { assert!(result.is_err()); // Check that the connection problem event was emitted at least twice - let mut delay = delay_for(Duration::from_secs(10)).fuse(); + let delay = sleep(Duration::from_secs(10)); + tokio::pin!(delay); let mut timeouts = 0i32; let mut failures = 0i32; loop { - futures::select! { - event = event_stream.select_next_some() => { + tokio::select! { + event = event_stream.recv() => { log::error!("EVENT: {:?}", event); match &*event.unwrap() { TransactionEvent::TransactionValidationTimedOut(_) => { @@ -1630,7 +1637,7 @@ async fn tx_validation_protocol_rpc_client_broken_finite_retries() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -1641,7 +1648,7 @@ async fn tx_validation_protocol_rpc_client_broken_finite_retries() { /// Validate completed transactions, will check that valid ones stay valid and incorrectly marked invalid tx become /// valid. -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_validation_protocol_base_node_not_synced() { let ( @@ -1658,7 +1665,7 @@ async fn tx_validation_protocol_base_node_not_synced() { ) = setup(TxProtocolTestConfig::WithConnection).await; let (base_node_update_publisher, _) = broadcast::channel(20); let (_timeout_update_publisher, _) = broadcast::channel(20); - let mut event_stream = resources.event_publisher.subscribe().fuse(); + let mut event_stream = resources.event_publisher.subscribe(); add_transaction_to_database( 1, @@ -1711,12 +1718,13 @@ async fn tx_validation_protocol_base_node_not_synced() { let result = join_handle.await.unwrap(); assert!(result.is_err()); - let mut delay = delay_for(Duration::from_secs(10)).fuse(); + let delay = sleep(Duration::from_secs(10)); + tokio::pin!(delay); let mut delayed = 0i32; let mut failures = 0i32; loop { - futures::select! { - event = event_stream.select_next_some() => { + tokio::select! 
{ + event = event_stream.recv() => { match &*event.unwrap() { TransactionEvent::TransactionValidationDelayed(_) => { delayed +=1 ; @@ -1728,7 +1736,7 @@ async fn tx_validation_protocol_base_node_not_synced() { } }, - () = delay => { + () = &mut delay => { break; }, } diff --git a/base_layer/wallet/tests/wallet/mod.rs b/base_layer/wallet/tests/wallet/mod.rs index 6aa231b09c..1959ee349f 100644 --- a/base_layer/wallet/tests/wallet/mod.rs +++ b/base_layer/wallet/tests/wallet/mod.rs @@ -28,7 +28,6 @@ use aes_gcm::{ Aes256Gcm, }; use digest::Digest; -use futures::{FutureExt, StreamExt}; use rand::rngs::OsRng; use std::{panic, path::Path, sync::Arc, time::Duration}; use tari_common_types::chain_metadata::ChainMetadata; @@ -71,7 +70,7 @@ use tari_wallet::{ WalletSqlite, }; use tempfile::tempdir; -use tokio::{runtime::Runtime, time::delay_for}; +use tokio::{runtime::Runtime, time::sleep}; fn create_peer(public_key: CommsPublicKey, net_address: Multiaddr) -> Peer { Peer::new( @@ -163,7 +162,7 @@ async fn create_wallet( .await } -#[tokio_macros::test] +#[tokio::test] async fn test_wallet() { let mut shutdown_a = Shutdown::new(); let mut shutdown_b = Shutdown::new(); @@ -227,7 +226,7 @@ async fn test_wallet() { .await .unwrap(); - let mut alice_event_stream = alice_wallet.transaction_service.get_event_stream_fused(); + let mut alice_event_stream = alice_wallet.transaction_service.get_event_stream(); let value = MicroTari::from(1000); let (_utxo, uo1) = make_input(&mut OsRng, MicroTari(2500), &factories.commitment); @@ -245,15 +244,16 @@ async fn test_wallet() { .await .unwrap(); - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut reply_count = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => if let TransactionEvent::ReceivedTransactionReply(_) = &*event.unwrap() { - reply_count = true; - break; - }, - () = delay => { + tokio::select! 
{ + event = alice_event_stream.recv() => if let TransactionEvent::ReceivedTransactionReply(_) = &*event.unwrap() { + reply_count = true; + break; + }, + () = &mut delay => { break; }, } @@ -298,7 +298,7 @@ async fn test_wallet() { } drop(alice_event_stream); - shutdown_a.trigger().unwrap(); + shutdown_a.trigger(); alice_wallet.wait_until_shutdown().await; let connection = @@ -343,7 +343,7 @@ async fn test_wallet() { alice_wallet.remove_encryption().await.unwrap(); - shutdown_a.trigger().unwrap(); + shutdown_a.trigger(); alice_wallet.wait_until_shutdown().await; let connection = @@ -379,7 +379,7 @@ async fn test_wallet() { .await .unwrap(); - shutdown_a.trigger().unwrap(); + shutdown_a.trigger(); alice_wallet.wait_until_shutdown().await; partial_wallet_backup(current_wallet_path.clone(), backup_wallet_path.clone()) @@ -400,12 +400,12 @@ async fn test_wallet() { let master_secret_key = backup_wallet_db.get_master_secret_key().await.unwrap(); assert!(master_secret_key.is_none()); - shutdown_b.trigger().unwrap(); + shutdown_b.trigger(); bob_wallet.wait_until_shutdown().await; } -#[tokio_macros::test] +#[tokio::test] async fn test_do_not_overwrite_master_key() { let factories = CryptoFactories::default(); let dir = tempdir().unwrap(); @@ -423,7 +423,7 @@ async fn test_do_not_overwrite_master_key() { ) .await .unwrap(); - shutdown.trigger().unwrap(); + shutdown.trigger(); wallet.wait_until_shutdown().await; // try to use a new master key to create a wallet using the existing wallet database @@ -457,7 +457,7 @@ async fn test_do_not_overwrite_master_key() { .unwrap(); } -#[tokio_macros::test] +#[tokio::test] async fn test_sign_message() { let factories = CryptoFactories::default(); let dir = tempdir().unwrap(); @@ -516,9 +516,9 @@ fn test_store_and_forward_send_tx() { let bob_db_tempdir = tempdir().unwrap(); let carol_db_tempdir = tempdir().unwrap(); - let mut alice_runtime = Runtime::new().expect("Failed to initialize tokio runtime"); - let mut bob_runtime = Runtime::new().expect("Failed to initialize tokio runtime"); - let mut carol_runtime = Runtime::new().expect("Failed to initialize tokio runtime"); + let alice_runtime = Runtime::new().expect("Failed to initialize tokio runtime"); + let bob_runtime = Runtime::new().expect("Failed to initialize tokio runtime"); + let carol_runtime = Runtime::new().expect("Failed to initialize tokio runtime"); let mut alice_wallet = alice_runtime .block_on(create_wallet( @@ -554,7 +554,7 @@ fn test_store_and_forward_send_tx() { )) .unwrap(); let carol_identity = (*carol_wallet.comms.node_identity()).clone(); - shutdown_c.trigger().unwrap(); + shutdown_c.trigger(); carol_runtime.block_on(carol_wallet.wait_until_shutdown()); alice_runtime @@ -591,13 +591,13 @@ fn test_store_and_forward_send_tx() { .unwrap(); // Waiting here for a while to make sure the discovery retry is over - alice_runtime.block_on(async { delay_for(Duration::from_secs(60)).await }); + alice_runtime.block_on(async { sleep(Duration::from_secs(60)).await }); alice_runtime .block_on(alice_wallet.transaction_service.cancel_transaction(tx_id)) .unwrap(); - alice_runtime.block_on(async { delay_for(Duration::from_secs(60)).await }); + alice_runtime.block_on(async { sleep(Duration::from_secs(60)).await }); let carol_wallet = carol_runtime .block_on(create_wallet( @@ -610,7 +610,7 @@ fn test_store_and_forward_send_tx() { )) .unwrap(); - let mut carol_event_stream = carol_wallet.transaction_service.get_event_stream_fused(); + let mut carol_event_stream = 
carol_wallet.transaction_service.get_event_stream(); carol_runtime .block_on(carol_wallet.comms.peer_manager().add_peer(create_peer( @@ -623,13 +623,14 @@ fn test_store_and_forward_send_tx() { .unwrap(); carol_runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut tx_recv = false; let mut tx_cancelled = false; loop { - futures::select! { - event = carol_event_stream.select_next_some() => { + tokio::select! { + event = carol_event_stream.recv() => { match &*event.unwrap() { TransactionEvent::ReceivedTransaction(_) => tx_recv = true, TransactionEvent::TransactionCancelled(_) => tx_cancelled = true, @@ -639,7 +640,7 @@ fn test_store_and_forward_send_tx() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -647,15 +648,15 @@ fn test_store_and_forward_send_tx() { assert!(tx_recv, "Must have received a tx from alice"); assert!(tx_cancelled, "Must have received a cancel tx from alice"); }); - shutdown_a.trigger().unwrap(); - shutdown_b.trigger().unwrap(); - shutdown_c2.trigger().unwrap(); + shutdown_a.trigger(); + shutdown_b.trigger(); + shutdown_c2.trigger(); alice_runtime.block_on(alice_wallet.wait_until_shutdown()); bob_runtime.block_on(bob_wallet.wait_until_shutdown()); carol_runtime.block_on(carol_wallet.wait_until_shutdown()); } -#[tokio_macros::test] +#[tokio::test] async fn test_import_utxo() { let shutdown = Shutdown::new(); let factories = CryptoFactories::default(); diff --git a/base_layer/wallet_ffi/Cargo.toml b/base_layer/wallet_ffi/Cargo.toml index d7381c0d8e..868d301e97 100644 --- a/base_layer/wallet_ffi/Cargo.toml +++ b/base_layer/wallet_ffi/Cargo.toml @@ -17,11 +17,11 @@ tari_shutdown = { version = "^0.9", path = "../../infrastructure/shutdown" } tari_utilities = "^0.3" futures = { version = "^0.3.1", features =["compat", "std"]} -tokio = "0.2.10" +tokio = "1.10.1" libc = "0.2.65" rand = "0.8" chrono = { version = "0.4.6", features = ["serde"]} -thiserror = "1.0.20" +thiserror = "1.0.26" log = "0.4.6" log4rs = {version = "1.0.0", features = ["console_appender", "file_appender", "yaml_format"]} @@ -41,4 +41,3 @@ env_logger = "0.7.1" tari_key_manager = { version = "^0.9", path = "../key_manager" } tari_common_types = { version = "^0.9", path = "../../base_layer/common_types"} tari_test_utils = { version = "^0.9", path = "../../infrastructure/test_utils"} -tokio = { version="0.2.10" } diff --git a/base_layer/wallet_ffi/src/callback_handler.rs b/base_layer/wallet_ffi/src/callback_handler.rs index 6d00799f23..bb649bdfac 100644 --- a/base_layer/wallet_ffi/src/callback_handler.rs +++ b/base_layer/wallet_ffi/src/callback_handler.rs @@ -48,7 +48,6 @@ //! request_key is used to identify which request this callback references and a result of true means it was successful //! 
and false that the process timed out and new one will be started -use futures::{stream::Fuse, StreamExt}; use log::*; use tari_comms::types::CommsPublicKey; use tari_comms_dht::event::{DhtEvent, DhtEventReceiver}; @@ -96,9 +95,9 @@ where TBackend: TransactionBackend + 'static callback_transaction_validation_complete: unsafe extern "C" fn(u64, u8), callback_saf_messages_received: unsafe extern "C" fn(), db: TransactionDatabase, - transaction_service_event_stream: Fuse, - output_manager_service_event_stream: Fuse, - dht_event_stream: Fuse, + transaction_service_event_stream: TransactionEventReceiver, + output_manager_service_event_stream: OutputManagerEventReceiver, + dht_event_stream: DhtEventReceiver, shutdown_signal: Option, comms_public_key: CommsPublicKey, } @@ -109,9 +108,9 @@ where TBackend: TransactionBackend + 'static { pub fn new( db: TransactionDatabase, - transaction_service_event_stream: Fuse, - output_manager_service_event_stream: Fuse, - dht_event_stream: Fuse, + transaction_service_event_stream: TransactionEventReceiver, + output_manager_service_event_stream: OutputManagerEventReceiver, + dht_event_stream: DhtEventReceiver, shutdown_signal: ShutdownSignal, comms_public_key: CommsPublicKey, callback_received_transaction: unsafe extern "C" fn(*mut InboundTransaction), @@ -219,8 +218,8 @@ where TBackend: TransactionBackend + 'static info!(target: LOG_TARGET, "Transaction Service Callback Handler starting"); loop { - futures::select! { - result = self.transaction_service_event_stream.select_next_some() => { + tokio::select! { + result = self.transaction_service_event_stream.recv() => { match result { Ok(msg) => { trace!(target: LOG_TARGET, "Transaction Service Callback Handler event {:?}", msg); @@ -271,7 +270,7 @@ where TBackend: TransactionBackend + 'static Err(_e) => error!(target: LOG_TARGET, "Error reading from Transaction Service event broadcast channel"), } }, - result = self.output_manager_service_event_stream.select_next_some() => { + result = self.output_manager_service_event_stream.recv() => { match result { Ok(msg) => { trace!(target: LOG_TARGET, "Output Manager Service Callback Handler event {:?}", msg); @@ -295,7 +294,7 @@ where TBackend: TransactionBackend + 'static Err(_e) => error!(target: LOG_TARGET, "Error reading from Output Manager Service event broadcast channel"), } }, - result = self.dht_event_stream.select_next_some() => { + result = self.dht_event_stream.recv() => { match result { Ok(msg) => { trace!(target: LOG_TARGET, "DHT Callback Handler event {:?}", msg); @@ -306,11 +305,7 @@ where TBackend: TransactionBackend + 'static Err(_e) => error!(target: LOG_TARGET, "Error reading from DHT event broadcast channel"), } } - complete => { - info!(target: LOG_TARGET, "Callback Handler is exiting because all tasks have completed"); - break; - }, - _ = shutdown_signal => { + _ = shutdown_signal.wait() => { info!(target: LOG_TARGET, "Transaction Callback Handler shutting down because the shutdown signal was received"); break; }, @@ -585,7 +580,6 @@ where TBackend: TransactionBackend + 'static mod test { use crate::callback_handler::CallbackHandler; use chrono::Utc; - use futures::StreamExt; use rand::rngs::OsRng; use std::{ sync::{Arc, Mutex}, @@ -774,7 +768,7 @@ mod test { #[test] fn test_callback_handler() { - let mut runtime = Runtime::new().unwrap(); + let runtime = Runtime::new().unwrap(); let (_wallet_backend, backend, _oms_backend, _, _tempdir) = make_wallet_databases(None); let db = TransactionDatabase::new(backend); @@ -854,9 +848,9 @@ mod test { let 
shutdown_signal = Shutdown::new(); let callback_handler = CallbackHandler::new( db, - tx_receiver.fuse(), - oms_receiver.fuse(), - dht_receiver.fuse(), + tx_receiver, + oms_receiver, + dht_receiver, shutdown_signal.to_signal(), PublicKey::from_secret_key(&PrivateKey::random(&mut OsRng)), received_tx_callback, diff --git a/base_layer/wallet_ffi/src/lib.rs b/base_layer/wallet_ffi/src/lib.rs index 2aae750b3e..ba163970a5 100644 --- a/base_layer/wallet_ffi/src/lib.rs +++ b/base_layer/wallet_ffi/src/lib.rs @@ -120,7 +120,6 @@ use crate::{ }; use core::ptr; use error::LibWalletError; -use futures::StreamExt; use libc::{c_char, c_int, c_longlong, c_uchar, c_uint, c_ulonglong, c_ushort}; use log::{LevelFilter, *}; use log4rs::{ @@ -155,6 +154,7 @@ use tari_comms::{ }; use tari_comms_dht::{DbConnectionUrl, DhtConfig}; use tari_core::transactions::{ + emoji::{emoji_set, EmojiId, EmojiIdError}, tari_amount::MicroTari, transaction::OutputFeatures, types::{ComSignature, CryptoFactories, PublicKey}, @@ -195,7 +195,6 @@ use tari_wallet::{ }, }, types::ValidationRetryStrategy, - util::emoji::{emoji_set, EmojiId, EmojiIdError}, utxo_scanner_service::utxo_scanning::{UtxoScannerService, RECOVERY_KEY}, Wallet, WalletConfig, @@ -2866,7 +2865,7 @@ pub unsafe extern "C" fn wallet_create( } }; - let mut runtime = match Runtime::new() { + let runtime = match Runtime::new() { Ok(r) => r, Err(e) => { error = LibWalletError::from(InterfaceError::TokioError(e.to_string())).code; @@ -2953,9 +2952,9 @@ pub unsafe extern "C" fn wallet_create( // Start Callback Handler let callback_handler = CallbackHandler::new( TransactionDatabase::new(transaction_backend), - w.transaction_service.get_event_stream_fused(), - w.output_manager_service.get_event_stream_fused(), - w.dht_service.subscribe_dht_events().fuse(), + w.transaction_service.get_event_stream(), + w.output_manager_service.get_event_stream(), + w.dht_service.subscribe_dht_events(), w.comms.shutdown_signal(), w.comms.node_identity().public_key().clone(), callback_received_transaction, @@ -5154,7 +5153,7 @@ pub unsafe extern "C" fn file_partial_backup( let runtime = Runtime::new(); match runtime { - Ok(mut runtime) => match runtime.block_on(partial_wallet_backup(original_path, backup_path)) { + Ok(runtime) => match runtime.block_on(partial_wallet_backup(original_path, backup_path)) { Ok(_) => (), Err(e) => { error = LibWalletError::from(WalletError::WalletStorageError(e)).code; @@ -5281,10 +5280,8 @@ pub unsafe extern "C" fn emoji_set_destroy(emoji_set: *mut EmojiSet) { pub unsafe extern "C" fn wallet_destroy(wallet: *mut TariWallet) { if !wallet.is_null() { let mut w = Box::from_raw(wallet); - match w.shutdown.trigger() { - Err(_) => error!(target: LOG_TARGET, "No listeners for the shutdown signal!"), - Ok(()) => w.runtime.block_on(w.wallet.wait_until_shutdown()), - } + w.shutdown.trigger(); + w.runtime.block_on(w.wallet.wait_until_shutdown()); } } @@ -5314,11 +5311,11 @@ mod test { str::{from_utf8, FromStr}, sync::Mutex, }; + use tari_core::transactions::emoji; use tari_test_utils::random; use tari_wallet::{ storage::sqlite_utilities::run_migration_and_create_sqlite_connection, transaction_service::storage::models::TransactionStatus, - util::emoji, }; use tempfile::tempdir; @@ -5781,7 +5778,7 @@ mod test { error_ptr, ); - let mut runtime = Runtime::new().unwrap(); + let runtime = Runtime::new().unwrap(); let connection = run_migration_and_create_sqlite_connection(&sql_database_path).expect("Could not open Sqlite db"); diff --git a/base_layer/wallet_ffi/src/tasks.rs 
b/base_layer/wallet_ffi/src/tasks.rs index 9c44c94106..9e67eaa091 100644 --- a/base_layer/wallet_ffi/src/tasks.rs +++ b/base_layer/wallet_ffi/src/tasks.rs @@ -20,7 +20,6 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::StreamExt; use log::*; use tari_crypto::tari_utilities::hex::Hex; use tari_wallet::{error::WalletError, utxo_scanner_service::handle::UtxoScannerEvent}; @@ -44,8 +43,8 @@ pub async fn recovery_event_monitoring( recovery_join_handle: JoinHandle>, recovery_progress_callback: unsafe extern "C" fn(u8, u64, u64), ) { - while let Some(event) = event_stream.next().await { - match event { + loop { + match event_stream.recv().await { Ok(UtxoScannerEvent::ConnectingToBaseNode(peer)) => { unsafe { (recovery_progress_callback)(RecoveryEvent::ConnectingToBaseNode as u8, 0u64, 0u64); @@ -139,6 +138,9 @@ pub async fn recovery_event_monitoring( } warn!(target: LOG_TARGET, "UTXO Scanner failed and exited",); }, + Err(broadcast::error::RecvError::Closed) => { + break; + }, Err(e) => { // Event lagging warn!(target: LOG_TARGET, "{}", e); diff --git a/common/Cargo.toml b/common/Cargo.toml index aa9d646654..5998d0bdfd 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -22,7 +22,7 @@ dirs-next = "1.0.2" get_if_addrs = "0.5.3" log = "0.4.8" log4rs = { version = "1.0.0", default_features= false, features = ["config_parsing", "threshold_filter"]} -multiaddr={package="parity-multiaddr", version = "0.11.0"} +multiaddr={version = "0.13.0"} sha2 = "0.9.5" path-clean = "0.1.0" tari_storage = { version = "^0.9", path = "../infrastructure/storage"} @@ -36,7 +36,7 @@ opentelemetry-jaeger = { version="0.15", features=["rt-tokio"]} anyhow = { version = "1.0", optional = true } git2 = { version = "0.8", optional = true } -prost-build = { version = "0.6.1", optional = true } +prost-build = { version = "0.8.0", optional = true } toml = { version = "0.5", optional = true } [dev-dependencies] diff --git a/common/src/configuration/global.rs b/common/src/configuration/global.rs index 880e111cd1..941a7ba7fa 100644 --- a/common/src/configuration/global.rs +++ b/common/src/configuration/global.rs @@ -71,7 +71,6 @@ pub struct GlobalConfig { pub pruning_horizon: u64, pub pruned_mode_cleanup_interval: u64, pub core_threads: Option, - pub max_threads: Option, pub base_node_identity_file: PathBuf, pub public_address: Multiaddr, pub grpc_enabled: bool, @@ -270,10 +269,6 @@ fn convert_node_config( let core_threads = optional(cfg.get_int(&key).map(|n| n as usize)).map_err(|e| ConfigurationError::new(&key, &e.to_string()))?; - let key = config_string("base_node", &net_str, "max_threads"); - let max_threads = - optional(cfg.get_int(&key).map(|n| n as usize)).map_err(|e| ConfigurationError::new(&key, &e.to_string()))?; - // Max RandomX VMs let key = config_string("base_node", &net_str, "max_randomx_vms"); let max_randomx_vms = optional(cfg.get_int(&key).map(|n| n as usize)) @@ -712,7 +707,6 @@ fn convert_node_config( pruning_horizon, pruned_mode_cleanup_interval, core_threads, - max_threads, base_node_identity_file, public_address, grpc_enabled, diff --git a/common/src/dns/tests.rs b/common/src/dns/tests.rs index 955f22cf97..b7dc087517 100644 --- a/common/src/dns/tests.rs +++ b/common/src/dns/tests.rs @@ -48,7 +48,7 @@ use trust_dns_client::rr::{rdata, RData, Record, RecordType}; // Ignore as this test requires network IO #[ignore] -#[tokio_macros::test] 
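Note: the `#[tokio_macros::test]` → `#[tokio::test]` rename seen here recurs throughout this diff. On tokio 1.x the attribute macros ship with the tokio crate itself (behind the "macros" feature), so the separate tokio-macros dev-dependency is dropped. A minimal sketch of the new form, with a hypothetical test body rather than code from this repository:

use std::time::Duration;
use tokio::time::sleep;

// tokio 1.x: #[tokio::test] builds a runtime per test; the old tokio_macros
// crate and its test / test_basic variants are no longer needed.
#[tokio::test]
async fn timer_fires() {
    sleep(Duration::from_millis(10)).await;
}
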
+#[tokio::test] async fn it_returns_an_empty_vec_if_all_seeds_are_invalid() { let mut resolver = PeerSeedResolver::connect("1.1.1.1:53".parse().unwrap()).await.unwrap(); let seeds = resolver.resolve("tari.com").await.unwrap(); @@ -64,7 +64,7 @@ fn create_txt_record(contents: Vec) -> Record { } #[allow(clippy::vec_init_then_push)] -#[tokio_macros::test] +#[tokio::test] async fn it_returns_peer_seeds() { let mut records = Vec::new(); // Multiple addresses(works) diff --git a/common/src/lib.rs b/common/src/lib.rs index 6f5c98a2e4..cb2c99d0c1 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -72,7 +72,7 @@ //! let config = args.load_configuration().unwrap(); //! let global = GlobalConfig::convert_from(ApplicationType::BaseNode, config).unwrap(); //! assert_eq!(global.network, Network::Weatherwax); -//! assert!(global.max_threads.is_none()); +//! assert!(global.core_threads.is_none()); //! # std::fs::remove_dir_all(temp_dir).unwrap(); //! ``` diff --git a/comms/Cargo.toml b/comms/Cargo.toml index 71b8eec12f..b781fa12bf 100644 --- a/comms/Cargo.toml +++ b/comms/Cargo.toml @@ -12,34 +12,36 @@ edition = "2018" [dependencies] tari_crypto = "0.11.1" tari_storage = { version = "^0.9", path = "../infrastructure/storage" } -tari_shutdown = { version="^0.9", path = "../infrastructure/shutdown" } +tari_shutdown = { version = "^0.9", path = "../infrastructure/shutdown" } +anyhow = "1.0.32" async-trait = "0.1.36" bitflags = "1.0.4" blake2 = "0.9.0" -bytes = { version = "0.5.x", features=["serde"] } +bytes = { version = "1", features = ["serde"] } chrono = { version = "0.4.6", features = ["serde"] } cidr = "0.1.0" clear_on_drop = "=0.2.4" data-encoding = "2.2.0" digest = "0.9.0" -futures = { version = "^0.3", features = ["async-await"]} +futures = { version = "^0.3", features = ["async-await"] } lazy_static = "1.3.0" lmdb-zero = "0.4.4" log = { version = "0.4.0", features = ["std"] } -multiaddr = {version = "=0.11.0", package = "parity-multiaddr"} -nom = {version = "5.1.0", features=["std"], default-features=false} +multiaddr = { version = "0.13.0" } +nom = { version = "5.1.0", features = ["std"], default-features = false } openssl = { version = "0.10", features = ["vendored"] } -pin-project = "0.4.17" -prost = "=0.6.1" +pin-project = "1.0.8" +prost = "=0.8.0" rand = "0.8" serde = "1.0.119" serde_derive = "1.0.119" -snow = {version="=0.8.0", features=["default-resolver"]} -thiserror = "1.0.20" -tokio = {version="~0.2.19", features=["blocking", "time", "tcp", "dns", "sync", "stream", "signal"]} -tokio-util = {version="0.3.1", features=["codec"]} -tower= "0.3.1" +snow = { version = "=0.8.0", features = ["default-resolver"] } +thiserror = "1.0.26" +tokio = { version = "1.10", features = ["rt-multi-thread", "time", "sync", "signal", "net", "macros", "io-util"] } +tokio-stream = { version = "0.1.7", features = ["sync"] } +tokio-util = { version = "0.6.7", features = ["codec", "compat"] } +tower = "0.3.1" tracing = "0.1.26" tracing-futures = "0.2.5" yamux = "=0.9.0" @@ -49,20 +51,19 @@ opentelemetry = { version = "0.16", default-features = false, features = ["trace opentelemetry-jaeger = { version="0.15", features=["rt-tokio"]} # RPC dependencies -tower-make = {version="0.3.0", optional=true} -anyhow = "1.0.32" +tower-make = { version = "0.3.0", optional = true } [dev-dependencies] -tari_test_utils = {version="^0.9", path="../infrastructure/test_utils"} -tari_comms_rpc_macros = {version="*", path="./rpc_macros"} +tari_test_utils = { version = "^0.9", path = "../infrastructure/test_utils" } 
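Note: the comms manifest now pulls in tokio-stream with the "sync" feature because tokio 1.x channel types no longer implement Stream themselves; where combinator-style code is still wanted, the wrapper types restore it. An illustrative sketch (not code from this repository, and assuming tokio-stream's "sync" feature) of wrapping a broadcast receiver:

use tokio::sync::broadcast;
use tokio_stream::{wrappers::BroadcastStream, StreamExt};

#[tokio::main]
async fn main() {
    let (tx, rx) = broadcast::channel::<u32>(16);
    tx.send(1).unwrap();
    tx.send(2).unwrap();
    drop(tx); // close the channel so the wrapped stream terminates

    // BroadcastStream yields Result items; Lagged errors are simply dropped here.
    let received: Vec<u32> = BroadcastStream::new(rx)
        .filter_map(|item| item.ok())
        .collect()
        .await;
    assert_eq!(received, vec![1, 2]);
}
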
+tari_comms_rpc_macros = { version = "*", path = "./rpc_macros" } env_logger = "0.7.0" serde_json = "1.0.39" -tokio-macros = "0.2.3" +#tokio = {version="1.8", features=["macros"]} tempfile = "3.1.0" [build-dependencies] -tari_common = { version = "^0.9", path="../common", features = ["build"]} +tari_common = { version = "^0.9", path = "../common", features = ["build"] } [features] avx2 = ["tari_crypto/avx2"] diff --git a/comms/dht/Cargo.toml b/comms/dht/Cargo.toml index c75e423543..77cf0d9e9d 100644 --- a/comms/dht/Cargo.toml +++ b/comms/dht/Cargo.toml @@ -10,58 +10,57 @@ license = "BSD-3-Clause" edition = "2018" [dependencies] -tari_comms = { version = "^0.9", path = "../", features = ["rpc"]} -tari_comms_rpc_macros = { version = "^0.9", path = "../rpc_macros"} +tari_comms = { version = "^0.9", path = "../", features = ["rpc"] } +tari_comms_rpc_macros = { version = "^0.9", path = "../rpc_macros" } tari_crypto = "0.11.1" -tari_utilities = { version = "^0.3" } -tari_shutdown = { version = "^0.9", path = "../../infrastructure/shutdown"} -tari_storage = { version = "^0.9", path = "../../infrastructure/storage"} +tari_utilities = { version = "^0.3" } +tari_shutdown = { version = "^0.9", path = "../../infrastructure/shutdown" } +tari_storage = { version = "^0.9", path = "../../infrastructure/storage" } anyhow = "1.0.32" bitflags = "1.2.0" -bytes = "0.4.12" +bytes = "0.5" chacha20 = "0.7.1" chrono = "0.4.9" -diesel = {version="1.4.7", features = ["sqlite", "serde_json", "chrono", "numeric"]} +diesel = { version = "1.4.7", features = ["sqlite", "serde_json", "chrono", "numeric"] } diesel_migrations = "1.4.0" -libsqlite3-sys = { version = ">=0.8.0, <0.13.0", features = ["bundled"], optional=true } +libsqlite3-sys = { version = ">=0.8.0, <0.13.0", features = ["bundled"], optional = true } digest = "0.9.0" -futures= {version= "^0.3.1"} +futures = { version = "^0.3.1" } log = "0.4.8" -prost = "=0.6.1" -prost-types = "=0.6.1" +prost = "=0.8.0" +prost-types = "=0.8.0" rand = "0.8" serde = "1.0.90" serde_derive = "1.0.90" serde_repr = "0.1.5" -thiserror = "1.0.20" -tokio = {version="0.2.10", features=["rt-threaded", "blocking"]} -tower= "0.3.1" +thiserror = "1.0.26" +tokio = { version = "1.10", features = ["rt", "macros"] } +tower = "0.3.1" ttl_cache = "0.5.1" # tower-filter dependencies pin-project = "0.4" [dev-dependencies] -tari_test_utils = { version = "^0.9", path = "../../infrastructure/test_utils"} +tari_test_utils = { version = "^0.9", path = "../../infrastructure/test_utils" } env_logger = "0.7.0" -futures-test = { version = "0.3.0-alpha.19", package = "futures-test-preview" } +futures-test = { version = "0.3.5" } lmdb-zero = "0.4.4" tempfile = "3.1.0" -tokio-macros = "0.2.3" +tokio-stream = { version = "0.1.7", features = ["sync"] } petgraph = "0.5.1" clap = "2.33.0" # tower-filter dependencies tower-test = { version = "^0.3" } -tokio-test = "^0.2" -tokio = "^0.2" +tokio-test = "^0.4.2" futures-util = "^0.3.1" lazy_static = "1.4.0" [build-dependencies] -tari_common = { version = "^0.9", path="../../common"} +tari_common = { version = "^0.9", path = "../../common" } [features] test-mocks = [] diff --git a/comms/dht/examples/graphing_utilities/utilities.rs b/comms/dht/examples/graphing_utilities/utilities.rs index bfb081eb21..3a4fea7ae9 100644 --- a/comms/dht/examples/graphing_utilities/utilities.rs +++ b/comms/dht/examples/graphing_utilities/utilities.rs @@ -32,6 +32,7 @@ use petgraph::{ }; use std::{collections::HashMap, convert::TryFrom, fs, fs::File, io::Write, path::Path, process::Command, 
sync::Mutex}; use tari_comms::{connectivity::ConnectivitySelection, peer_manager::NodeId}; +use tari_test_utils::streams::convert_unbounded_mpsc_to_stream; const TEMP_GRAPH_OUTPUT_DIR: &str = "/tmp/memorynet_temp"; @@ -277,7 +278,9 @@ pub enum PythonRenderType { /// This function will drain the message event queue and then build a message propagation tree assuming the first sender /// is the starting node pub async fn track_join_message_drain_messaging_events(messaging_rx: &mut NodeEventRx) -> StableGraph { - let drain_fut = DrainBurst::new(messaging_rx); + let stream = convert_unbounded_mpsc_to_stream(messaging_rx); + tokio::pin!(stream); + let drain_fut = DrainBurst::new(&mut stream); let messages = drain_fut.await; let num_messages = messages.len(); diff --git a/comms/dht/examples/memory_net/drain_burst.rs b/comms/dht/examples/memory_net/drain_burst.rs index d2f5bce2be..59ed8723f4 100644 --- a/comms/dht/examples/memory_net/drain_burst.rs +++ b/comms/dht/examples/memory_net/drain_burst.rs @@ -42,7 +42,7 @@ where St: ?Sized + Stream + Unpin let (lower_bound, upper_bound) = stream.size_hint(); Self { inner: stream, - collection: Vec::with_capacity(upper_bound.or(Some(lower_bound)).unwrap()), + collection: Vec::with_capacity(upper_bound.unwrap_or(lower_bound)), } } } @@ -70,15 +70,16 @@ where St: ?Sized + Stream + Unpin mod test { use super::*; use futures::stream; + use tari_comms::runtime; - #[tokio_macros::test_basic] + #[runtime::test] async fn drain_terminating_stream() { let mut stream = stream::iter(1..10u8); let burst = DrainBurst::new(&mut stream).await; assert_eq!(burst, (1..10u8).into_iter().collect::>()); } - #[tokio_macros::test_basic] + #[runtime::test] async fn drain_stream_with_pending() { let mut stream = stream::iter(1..10u8); let burst = DrainBurst::new(&mut stream).await; diff --git a/comms/dht/examples/memory_net/utilities.rs b/comms/dht/examples/memory_net/utilities.rs index bf1bf03ae4..0d4a675e8e 100644 --- a/comms/dht/examples/memory_net/utilities.rs +++ b/comms/dht/examples/memory_net/utilities.rs @@ -22,7 +22,7 @@ #![allow(clippy::mutex_atomic)] use crate::memory_net::DrainBurst; -use futures::{channel::mpsc, future, StreamExt}; +use futures::future; use lazy_static::lazy_static; use rand::{rngs::OsRng, Rng}; use std::{ @@ -62,8 +62,13 @@ use tari_storage::{ lmdb_store::{LMDBBuilder, LMDBConfig}, LMDBWrapper, }; -use tari_test_utils::{paths::create_temporary_data_path, random}; -use tokio::{runtime, sync::broadcast, task, time}; +use tari_test_utils::{paths::create_temporary_data_path, random, streams::convert_unbounded_mpsc_to_stream}; +use tokio::{ + runtime, + sync::{broadcast, mpsc}, + task, + time, +}; use tower::ServiceBuilder; pub type NodeEventRx = mpsc::UnboundedReceiver<(NodeId, NodeId)>; @@ -154,7 +159,7 @@ pub async fn discovery(wallets: &[TestNode], messaging_events_rx: &mut NodeEvent start.elapsed() ); - time::delay_for(Duration::from_secs(5)).await; + time::sleep(Duration::from_secs(5)).await; total_messages += drain_messaging_events(messaging_events_rx, false).await; }, Err(err) => { @@ -166,7 +171,7 @@ pub async fn discovery(wallets: &[TestNode], messaging_events_rx: &mut NodeEvent err ); - time::delay_for(Duration::from_secs(5)).await; + time::sleep(Duration::from_secs(5)).await; total_messages += drain_messaging_events(messaging_events_rx, false).await; }, } @@ -298,7 +303,7 @@ pub async fn do_network_wide_propagation(nodes: &mut [TestNode], origin_node_ind let node_name = node.name.clone(); task::spawn(async move { - let result = 
time::timeout(Duration::from_secs(30), ims_rx.next()).await; + let result = time::timeout(Duration::from_secs(30), ims_rx.recv()).await; let mut is_success = false; match result { Ok(Some(msg)) => { @@ -450,21 +455,23 @@ pub async fn do_store_and_forward_message_propagation( for (idx, mut s) in neighbour_subs.into_iter().enumerate() { let neighbour = neighbours[idx].name.clone(); task::spawn(async move { - let msg = time::timeout(Duration::from_secs(2), s.next()).await; + let msg = time::timeout(Duration::from_secs(2), s.recv()).await; match msg { - Ok(Some(Ok(evt))) => { + Ok(Ok(evt)) => { if let MessagingEvent::MessageReceived(_, tag) = &*evt { println!("{} received propagated SAF message ({})", neighbour, tag); } }, - Ok(_) => {}, + Ok(Err(err)) => { + println!("{}", err); + }, Err(_) => println!("{} did not receive the SAF message", neighbour), } }); } banner!("⏰ Waiting a few seconds for messages to propagate around the network..."); - time::delay_for(Duration::from_secs(5)).await; + time::sleep(Duration::from_secs(5)).await; let mut total_messages = drain_messaging_events(messaging_rx, false).await; @@ -515,7 +522,7 @@ pub async fn do_store_and_forward_message_propagation( let mut num_msgs = 0; let mut succeeded = 0; loop { - let result = time::timeout(Duration::from_secs(10), wallet.ims_rx.as_mut().unwrap().next()).await; + let result = time::timeout(Duration::from_secs(10), wallet.ims_rx.as_mut().unwrap().recv()).await; num_msgs += 1; match result { Ok(msg) => { @@ -554,7 +561,9 @@ pub async fn do_store_and_forward_message_propagation( } pub async fn drain_messaging_events(messaging_rx: &mut NodeEventRx, show_logs: bool) -> usize { - let drain_fut = DrainBurst::new(messaging_rx); + let stream = convert_unbounded_mpsc_to_stream(messaging_rx); + tokio::pin!(stream); + let drain_fut = DrainBurst::new(&mut stream); if show_logs { let messages = drain_fut.await; let num_messages = messages.len(); @@ -694,42 +703,46 @@ impl TestNode { fn spawn_event_monitor( comms: &CommsNode, - messaging_events: MessagingEventReceiver, + mut messaging_events: MessagingEventReceiver, events_tx: mpsc::Sender>, messaging_events_tx: NodeEventTx, quiet_mode: bool, ) { - let conn_man_event_sub = comms.subscribe_connection_manager_events(); + let mut conn_man_event_sub = comms.subscribe_connection_manager_events(); let executor = runtime::Handle::current(); - executor.spawn( - conn_man_event_sub - .filter(|r| future::ready(r.is_ok())) - .map(Result::unwrap) - .map(connection_manager_logger( - comms.node_identity().node_id().clone(), - quiet_mode, - )) - .map(Ok) - .forward(events_tx), - ); - let node_id = comms.node_identity().node_id().clone(); + executor.spawn(async move { + let mut logger = connection_manager_logger(node_id, quiet_mode); + loop { + match conn_man_event_sub.recv().await { + Ok(event) => { + events_tx.send(logger(event)).await.unwrap(); + }, + Err(broadcast::error::RecvError::Closed) => break, + Err(err) => log::error!("{}", err), + } + } + }); - executor.spawn( - messaging_events - .filter(|r| future::ready(r.is_ok())) - .map(Result::unwrap) - .filter_map(move |event| { - use MessagingEvent::*; - future::ready(match &*event { - MessageReceived(peer_node_id, _) => Some((Clone::clone(&*peer_node_id), node_id.clone())), - _ => None, - }) - }) - .map(Ok) - .forward(messaging_events_tx), - ); + let node_id = comms.node_identity().node_id().clone(); + executor.spawn(async move { + loop { + let event = messaging_events.recv().await; + use MessagingEvent::*; + match event.as_deref() { + 
Ok(MessageReceived(peer_node_id, _)) => { + messaging_events_tx + .send((Clone::clone(&*peer_node_id), node_id.clone())) + .unwrap(); + }, + Err(broadcast::error::RecvError::Closed) => { + break; + }, + _ => {}, + } + } + }); } #[inline] @@ -749,7 +762,7 @@ impl TestNode { } use ConnectionManagerEvent::*; loop { - let event = time::timeout(Duration::from_secs(30), self.conn_man_events_rx.next()) + let event = time::timeout(Duration::from_secs(30), self.conn_man_events_rx.recv()) .await .ok()??; @@ -763,7 +776,7 @@ impl TestNode { } pub async fn shutdown(mut self) { - self.shutdown.trigger().unwrap(); + self.shutdown.trigger(); self.comms.wait_until_shutdown().await; } } @@ -946,5 +959,5 @@ async fn setup_comms_dht( pub async fn take_a_break(num_nodes: usize) { banner!("Taking a break for a few seconds to let things settle..."); - time::delay_for(Duration::from_millis(num_nodes as u64 * 100)).await; + time::sleep(Duration::from_millis(num_nodes as u64 * 100)).await; } diff --git a/comms/dht/examples/memorynet.rs b/comms/dht/examples/memorynet.rs index 9cc28551bc..3ed35c05b7 100644 --- a/comms/dht/examples/memorynet.rs +++ b/comms/dht/examples/memorynet.rs @@ -49,15 +49,16 @@ use crate::memory_net::utilities::{ shutdown_all, take_a_break, }; -use futures::{channel::mpsc, future}; +use futures::future; use rand::{rngs::OsRng, Rng}; use std::{iter::repeat_with, time::Duration}; use tari_comms::peer_manager::PeerFeatures; +use tokio::sync::mpsc; // Size of network -const NUM_NODES: usize = 6; +const NUM_NODES: usize = 40; // Must be at least 2 -const NUM_WALLETS: usize = 50; +const NUM_WALLETS: usize = 5; const QUIET_MODE: bool = true; /// Number of neighbouring nodes each node should include in the connection pool const NUM_NEIGHBOURING_NODES: usize = 8; @@ -66,7 +67,7 @@ const NUM_RANDOM_NODES: usize = 4; /// The number of messages that should be propagated out const PROPAGATION_FACTOR: usize = 4; -#[tokio_macros::main] +#[tokio::main] #[allow(clippy::same_item_push)] async fn main() { env_logger::init(); @@ -77,7 +78,7 @@ async fn main() { NUM_WALLETS ); - let (node_message_tx, mut messaging_events_rx) = mpsc::unbounded(); + let (node_message_tx, mut messaging_events_rx) = mpsc::unbounded_channel(); let seed_node = vec![ make_node( diff --git a/comms/dht/examples/memorynet_graph_network_join_multiple_seeds.rs b/comms/dht/examples/memorynet_graph_network_join_multiple_seeds.rs index a0a1bddc1e..5d2759fb54 100644 --- a/comms/dht/examples/memorynet_graph_network_join_multiple_seeds.rs +++ b/comms/dht/examples/memorynet_graph_network_join_multiple_seeds.rs @@ -65,11 +65,11 @@ use crate::{ }, }; use clap::{App, Arg}; -use futures::channel::mpsc; use std::{path::Path, time::Duration}; use tari_comms::peer_manager::PeerFeatures; +use tokio::sync::mpsc; -#[tokio_macros::main] +#[tokio::main] #[allow(clippy::same_item_push)] async fn main() { env_logger::init(); @@ -96,7 +96,7 @@ async fn main() { NUM_WALLETS ); - let (messaging_events_tx, _messaging_events_rx) = mpsc::unbounded(); + let (messaging_events_tx, _messaging_events_rx) = mpsc::unbounded_channel(); let mut seed_identities = Vec::new(); for _ in 0..NUM_SEED_NODES { diff --git a/comms/dht/examples/memorynet_graph_network_track_join.rs b/comms/dht/examples/memorynet_graph_network_track_join.rs index 259358e1cb..0a0324b683 100644 --- a/comms/dht/examples/memorynet_graph_network_track_join.rs +++ b/comms/dht/examples/memorynet_graph_network_track_join.rs @@ -73,11 +73,11 @@ use crate::{ }; use clap::{App, Arg}; use env_logger::Env; -use 
futures::channel::mpsc; use std::{path::Path, time::Duration}; use tari_comms::peer_manager::PeerFeatures; +use tokio::sync::mpsc; -#[tokio_macros::main] +#[tokio::main] #[allow(clippy::same_item_push)] async fn main() { let _ = env_logger::from_env(Env::default()) @@ -106,7 +106,7 @@ async fn main() { NUM_WALLETS ); - let (messaging_events_tx, mut messaging_events_rx) = mpsc::unbounded(); + let (messaging_events_tx, mut messaging_events_rx) = mpsc::unbounded_channel(); let mut seed_identities = Vec::new(); for _ in 0..NUM_SEED_NODES { diff --git a/comms/dht/examples/memorynet_graph_network_track_propagation.rs b/comms/dht/examples/memorynet_graph_network_track_propagation.rs index fcc5debff3..d560b9f537 100644 --- a/comms/dht/examples/memorynet_graph_network_track_propagation.rs +++ b/comms/dht/examples/memorynet_graph_network_track_propagation.rs @@ -73,10 +73,10 @@ use crate::{ }, }; use env_logger::Env; -use futures::channel::mpsc; use tari_comms::peer_manager::PeerFeatures; +use tokio::sync::mpsc; -#[tokio_macros::main] +#[tokio::main] #[allow(clippy::same_item_push)] async fn main() { let _ = env_logger::from_env(Env::default()) @@ -105,7 +105,7 @@ async fn main() { NUM_WALLETS ); - let (messaging_events_tx, mut messaging_events_rx) = mpsc::unbounded(); + let (messaging_events_tx, mut messaging_events_rx) = mpsc::unbounded_channel(); let mut seed_identities = Vec::new(); for _ in 0..NUM_SEED_NODES { diff --git a/comms/dht/src/actor.rs b/comms/dht/src/actor.rs index 326d465458..2e453291ac 100644 --- a/comms/dht/src/actor.rs +++ b/comms/dht/src/actor.rs @@ -37,13 +37,7 @@ use crate::{ DhtConfig, }; use chrono::{DateTime, Utc}; -use futures::{ - channel::{mpsc, mpsc::SendError, oneshot}, - future::BoxFuture, - stream::{Fuse, FuturesUnordered}, - SinkExt, - StreamExt, -}; +use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; use log::*; use std::{cmp, fmt, fmt::Display, sync::Arc}; use tari_comms::{ @@ -55,7 +49,11 @@ use tari_crypto::tari_utilities::hex::Hex; use tari_shutdown::ShutdownSignal; use tari_utilities::message_format::{MessageFormat, MessageFormatError}; use thiserror::Error; -use tokio::{task, time}; +use tokio::{ + sync::{mpsc, oneshot}, + task, + time, +}; const LOG_TARGET: &str = "comms::dht::actor"; @@ -63,8 +61,6 @@ const LOG_TARGET: &str = "comms::dht::actor"; pub enum DhtActorError { #[error("MPSC channel is disconnected")] ChannelDisconnected, - #[error("MPSC sender was unable to send because the channel buffer is full")] - SendBufferFull, #[error("Reply sender canceled the request")] ReplyCanceled, #[error("PeerManagerError: {0}")] @@ -85,15 +81,9 @@ pub enum DhtActorError { ConnectivityEventStreamClosed, } -impl From for DhtActorError { - fn from(err: SendError) -> Self { - if err.is_disconnected() { - DhtActorError::ChannelDisconnected - } else if err.is_full() { - DhtActorError::SendBufferFull - } else { - unreachable!(); - } +impl From> for DhtActorError { + fn from(_: mpsc::error::SendError) -> Self { + DhtActorError::ChannelDisconnected } } @@ -215,8 +205,8 @@ pub struct DhtActor { outbound_requester: OutboundMessageRequester, connectivity: ConnectivityRequester, config: DhtConfig, - shutdown_signal: Option, - request_rx: Fuse>, + shutdown_signal: ShutdownSignal, + request_rx: mpsc::Receiver, msg_hash_dedup_cache: DedupCacheDatabase, } @@ -246,8 +236,8 @@ impl DhtActor { peer_manager, connectivity, node_identity, - shutdown_signal: Some(shutdown_signal), - request_rx: request_rx.fuse(), + shutdown_signal, + request_rx, } } @@ -276,33 +266,28 
@@ impl DhtActor { let mut pending_jobs = FuturesUnordered::new(); - let mut dedup_cache_trim_ticker = time::interval(self.config.dedup_cache_trim_interval).fuse(); - - let mut shutdown_signal = self - .shutdown_signal - .take() - .expect("DhtActor initialized without shutdown_signal"); + let mut dedup_cache_trim_ticker = time::interval(self.config.dedup_cache_trim_interval); loop { - futures::select! { - request = self.request_rx.select_next_some() => { + tokio::select! { + Some(request) = self.request_rx.recv() => { trace!(target: LOG_TARGET, "DhtActor received request: {}", request); pending_jobs.push(self.request_handler(request)); }, - result = pending_jobs.select_next_some() => { + Some(result) = pending_jobs.next() => { if let Err(err) = result { debug!(target: LOG_TARGET, "Error when handling DHT request message. {}", err); } }, - _ = dedup_cache_trim_ticker.select_next_some() => { + _ = dedup_cache_trim_ticker.tick() => { if let Err(err) = self.msg_hash_dedup_cache.trim_entries().await { error!(target: LOG_TARGET, "Error when trimming message dedup cache: {:?}", err); } }, - _ = shutdown_signal => { + _ = self.shutdown_signal.wait() => { info!(target: LOG_TARGET, "DhtActor is shutting down because it received a shutdown signal."); self.mark_shutdown_time().await; break Ok(()); @@ -731,7 +716,10 @@ mod test { test_utils::{build_peer_manager, make_client_identity, make_node_identity}, }; use chrono::{DateTime, Utc}; - use tari_comms::test_utils::mocks::{create_connectivity_mock, create_peer_connection_mock_pair}; + use tari_comms::{ + runtime, + test_utils::mocks::{create_connectivity_mock, create_peer_connection_mock_pair}, + }; use tari_shutdown::Shutdown; use tari_test_utils::random; @@ -741,7 +729,7 @@ mod test { conn } - #[tokio_macros::test_basic] + #[runtime::test] async fn send_join_request() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -766,11 +754,11 @@ mod test { actor.spawn(); requester.send_join().await.unwrap(); - let (params, _) = unwrap_oms_send_msg!(out_rx.next().await.unwrap()); + let (params, _) = unwrap_oms_send_msg!(out_rx.recv().await.unwrap()); assert_eq!(params.dht_message_type, DhtMessageType::Join); } - #[tokio_macros::test_basic] + #[runtime::test] async fn insert_message_signature() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -812,7 +800,7 @@ mod test { assert_eq!(num_hits, 1); } - #[tokio_macros::test_basic] + #[runtime::test] async fn dedup_cache_cleanup() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -897,7 +885,7 @@ mod test { } } - #[tokio_macros::test_basic] + #[runtime::test] async fn select_peers() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -1008,7 +996,7 @@ mod test { assert_eq!(peers.len(), 1); } - #[tokio_macros::test_basic] + #[runtime::test] async fn get_and_set_metadata() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -1064,6 +1052,6 @@ mod test { .unwrap(); assert_eq!(got_ts, ts); - shutdown.trigger().unwrap(); + shutdown.trigger(); } } diff --git a/comms/dht/src/builder.rs b/comms/dht/src/builder.rs index b2fe763098..bd7fb2521c 100644 --- a/comms/dht/src/builder.rs +++ b/comms/dht/src/builder.rs @@ -21,13 +21,13 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
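Note: the DhtActor::run rewrite above is the canonical shape of this migration: receivers and intervals are polled directly by tokio::select! (no .fuse() calls or Fuse<...> fields), and the shutdown signal is awaited via .wait(). A stripped-down sketch of that loop shape, with generic names and a oneshot channel standing in for tari_shutdown's ShutdownSignal:

use std::time::Duration;
use futures::stream::FuturesUnordered;
use futures::StreamExt;
use tokio::sync::{mpsc, oneshot};
use tokio::time;

// Illustrative only: a generic actor loop in the same shape as the one above.
async fn run_actor(mut request_rx: mpsc::Receiver<u32>, mut shutdown_rx: oneshot::Receiver<()>) {
    let mut pending_jobs = FuturesUnordered::new();
    let mut ticker = time::interval(Duration::from_millis(100));

    loop {
        tokio::select! {
            Some(request) = request_rx.recv() => {
                pending_jobs.push(async move { request * 2 });
            },
            Some(result) = pending_jobs.next() => {
                println!("job finished: {}", result);
            },
            _ = ticker.tick() => {
                // periodic maintenance, e.g. trimming a dedup cache
            },
            _ = &mut shutdown_rx => {
                println!("shutting down");
                break;
            },
        }
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::channel(8);
    let (shutdown_tx, shutdown_rx) = oneshot::channel();
    let handle = tokio::spawn(run_actor(rx, shutdown_rx));
    tx.send(21).await.unwrap();
    time::sleep(Duration::from_millis(200)).await;
    let _ = shutdown_tx.send(());
    handle.await.unwrap();
}

The Some(..) = rx.recv() patterns matter: a closed channel yields None, which fails the pattern and disables that branch for the rest of the select! call, so the loop keeps waiting on the ticker and shutdown arms instead of spinning.
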
use crate::{dht::DhtInitializationError, outbound::DhtOutboundRequest, DbConnectionUrl, Dht, DhtConfig}; -use futures::channel::mpsc; use std::{sync::Arc, time::Duration}; use tari_comms::{ connectivity::ConnectivityRequester, peer_manager::{NodeIdentity, PeerManager}, }; use tari_shutdown::ShutdownSignal; +use tokio::sync::mpsc; pub struct DhtBuilder { node_identity: Arc, diff --git a/comms/dht/src/connectivity/metrics.rs b/comms/dht/src/connectivity/metrics.rs index b7ee546d4e..ba456d9aa9 100644 --- a/comms/dht/src/connectivity/metrics.rs +++ b/comms/dht/src/connectivity/metrics.rs @@ -20,20 +20,16 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::{ - channel::{mpsc, mpsc::SendError, oneshot, oneshot::Canceled}, - future, - SinkExt, - StreamExt, -}; use log::*; use std::{ collections::{hash_map::Entry, HashMap, VecDeque}, - future::Future, time::{Duration, Instant}, }; use tari_comms::peer_manager::NodeId; -use tokio::task; +use tokio::{ + sync::{mpsc, oneshot}, + task, +}; const LOG_TARGET: &str = "comms::dht::metrics"; @@ -124,7 +120,7 @@ impl MetricsState { } pub struct MetricsCollector { - stream: Option>, + stream: mpsc::Receiver, state: MetricsState, } @@ -133,18 +129,17 @@ impl MetricsCollector { let (metrics_tx, metrics_rx) = mpsc::channel(500); let metrics_collector = MetricsCollectorHandle::new(metrics_tx); let collector = Self { - stream: Some(metrics_rx), + stream: metrics_rx, state: Default::default(), }; task::spawn(collector.run()); metrics_collector } - fn run(mut self) -> impl Future { - self.stream.take().unwrap().for_each(move |op| { + async fn run(mut self) { + while let Some(op) = self.stream.recv().await { self.handle(op); - future::ready(()) - }) + } } fn handle(&mut self, op: MetricOp) { @@ -286,7 +281,7 @@ impl MetricsCollectorHandle { match self.inner.try_send(MetricOp::Write(write)) { Ok(_) => true, Err(err) => { - warn!(target: LOG_TARGET, "Failed to write metric: {}", err.into_send_error()); + warn!(target: LOG_TARGET, "Failed to write metric: {:?}", err); false }, } @@ -338,14 +333,14 @@ pub enum MetricsError { ReplyCancelled, } -impl From for MetricsError { - fn from(_: SendError) -> Self { +impl From> for MetricsError { + fn from(_: mpsc::error::SendError) -> Self { MetricsError::ChannelClosedUnexpectedly } } -impl From for MetricsError { - fn from(_: Canceled) -> Self { +impl From for MetricsError { + fn from(_: oneshot::error::RecvError) -> Self { MetricsError::ReplyCancelled } } diff --git a/comms/dht/src/connectivity/mod.rs b/comms/dht/src/connectivity/mod.rs index c9c855b6c1..a3c4471451 100644 --- a/comms/dht/src/connectivity/mod.rs +++ b/comms/dht/src/connectivity/mod.rs @@ -27,7 +27,6 @@ mod metrics; pub use metrics::{MetricsCollector, MetricsCollectorHandle}; use crate::{connectivity::metrics::MetricsError, event::DhtEvent, DhtActorError, DhtConfig, DhtRequester}; -use futures::{stream::Fuse, StreamExt}; use log::*; use std::{sync::Arc, time::Instant}; use tari_comms::{ @@ -78,11 +77,11 @@ pub struct DhtConnectivity { /// Used to track when the random peer pool was last refreshed random_pool_last_refresh: Option, stats: Stats, - dht_events: Fuse>>, + dht_events: broadcast::Receiver>, metrics_collector: MetricsCollectorHandle, - shutdown_signal: Option, + shutdown_signal: ShutdownSignal, } impl DhtConnectivity { @@ -108,8 +107,8 @@ impl DhtConnectivity { metrics_collector, 
random_pool_last_refresh: None, stats: Stats::new(), - dht_events: dht_events.fuse(), - shutdown_signal: Some(shutdown_signal), + dht_events, + shutdown_signal, } } @@ -131,21 +130,15 @@ impl DhtConnectivity { }) } - pub async fn run(mut self, connectivity_events: ConnectivityEventRx) -> Result<(), DhtConnectivityError> { - let mut connectivity_events = connectivity_events.fuse(); - let mut shutdown_signal = self - .shutdown_signal - .take() - .expect("DhtConnectivity initialized without a shutdown_signal"); - + pub async fn run(mut self, mut connectivity_events: ConnectivityEventRx) -> Result<(), DhtConnectivityError> { debug!(target: LOG_TARGET, "DHT connectivity starting"); self.refresh_neighbour_pool().await?; - let mut ticker = time::interval(self.config.connectivity_update_interval).fuse(); + let mut ticker = time::interval(self.config.connectivity_update_interval); loop { - futures::select! { - event = connectivity_events.select_next_some() => { + tokio::select! { + event = connectivity_events.recv() => { if let Ok(event) = event { if let Err(err) = self.handle_connectivity_event(&event).await { debug!(target: LOG_TARGET, "Error handling connectivity event: {:?}", err); @@ -153,15 +146,13 @@ impl DhtConnectivity { } }, - event = self.dht_events.select_next_some() => { - if let Ok(event) = event { - if let Err(err) = self.handle_dht_event(&event).await { - debug!(target: LOG_TARGET, "Error handling DHT event: {:?}", err); - } - } + Ok(event) = self.dht_events.recv() => { + if let Err(err) = self.handle_dht_event(&event).await { + debug!(target: LOG_TARGET, "Error handling DHT event: {:?}", err); + } }, - _ = ticker.next() => { + _ = ticker.tick() => { if let Err(err) = self.refresh_random_pool_if_required().await { debug!(target: LOG_TARGET, "Error refreshing random peer pool: {:?}", err); } @@ -170,7 +161,7 @@ impl DhtConnectivity { } }, - _ = shutdown_signal => { + _ = self.shutdown_signal.wait() => { info!(target: LOG_TARGET, "DhtConnectivity shutting down because the shutdown signal was received"); break; } diff --git a/comms/dht/src/connectivity/test.rs b/comms/dht/src/connectivity/test.rs index 65dba8c235..d0e83c0aa5 100644 --- a/comms/dht/src/connectivity/test.rs +++ b/comms/dht/src/connectivity/test.rs @@ -30,6 +30,7 @@ use std::{iter::repeat_with, sync::Arc, time::Duration}; use tari_comms::{ connectivity::ConnectivityEvent, peer_manager::{Peer, PeerFeatures}, + runtime, test_utils::{ count_string_occurrences, mocks::{create_connectivity_mock, create_dummy_peer_connection, ConnectivityManagerMockState}, @@ -89,7 +90,7 @@ async fn setup( ) } -#[tokio_macros::test_basic] +#[runtime::test] async fn initialize() { let config = DhtConfig { num_neighbouring_nodes: 4, @@ -127,7 +128,7 @@ async fn initialize() { assert!(managed.iter().all(|n| !neighbours.contains(n))); } -#[tokio_macros::test_basic] +#[runtime::test] async fn added_neighbours() { let node_identity = make_node_identity(); let mut node_identities = @@ -173,7 +174,7 @@ async fn added_neighbours() { assert!(managed.contains(closer_peer.node_id())); } -#[tokio_macros::test_basic] +#[runtime::test] #[allow(clippy::redundant_closure)] async fn reinitialize_pools_when_offline() { let node_identity = make_node_identity(); @@ -215,7 +216,7 @@ async fn reinitialize_pools_when_offline() { assert_eq!(managed.len(), 5); } -#[tokio_macros::test_basic] +#[runtime::test] async fn insert_neighbour() { let node_identity = make_node_identity(); let node_identities = @@ -254,11 +255,13 @@ async fn insert_neighbour() { } mod metrics { + 
use super::*; mod collector { + use super::*; use crate::connectivity::MetricsCollector; use tari_comms::peer_manager::NodeId; - #[tokio_macros::test_basic] + #[runtime::test] async fn it_adds_message_received() { let mut metric_collector = MetricsCollector::spawn(); let node_id = NodeId::default(); @@ -273,7 +276,7 @@ mod metrics { assert_eq!(ts.count(), 100); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_clears_the_metrics() { let mut metric_collector = MetricsCollector::spawn(); let node_id = NodeId::default(); diff --git a/comms/dht/src/dedup/mod.rs b/comms/dht/src/dedup/mod.rs index 7968a3e286..8bea19f39b 100644 --- a/comms/dht/src/dedup/mod.rs +++ b/comms/dht/src/dedup/mod.rs @@ -149,7 +149,7 @@ mod test { #[test] fn process_message() { - let mut rt = Runtime::new().unwrap(); + let rt = Runtime::new().unwrap(); let spy = service_spy(); let (dht_requester, mock) = create_dht_actor_mock(1); diff --git a/comms/dht/src/dht.rs b/comms/dht/src/dht.rs index 2ae0ef561c..9d29a70d79 100644 --- a/comms/dht/src/dht.rs +++ b/comms/dht/src/dht.rs @@ -42,7 +42,7 @@ use crate::{ DhtActorError, DhtConfig, }; -use futures::{channel::mpsc, Future}; +use futures::Future; use log::*; use std::sync::Arc; use tari_comms::{ @@ -53,7 +53,7 @@ use tari_comms::{ }; use tari_shutdown::ShutdownSignal; use thiserror::Error; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc}; use tower::{layer::Layer, Service, ServiceBuilder}; const LOG_TARGET: &str = "comms::dht"; @@ -432,22 +432,23 @@ mod test { make_comms_inbound_message, make_dht_envelope, make_node_identity, + service_spy, }, DhtBuilder, }; - use futures::{channel::mpsc, StreamExt}; use std::{sync::Arc, time::Duration}; use tari_comms::{ message::{MessageExt, MessageTag}, pipeline::SinkService, + runtime, test_utils::mocks::create_connectivity_mock, wrap_in_envelope_body, }; use tari_shutdown::Shutdown; - use tokio::{task, time}; + use tokio::{sync::mpsc, task, time}; use tower::{layer::Layer, Service}; - #[tokio_macros::test_basic] + #[runtime::test] async fn stack_unencrypted() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -487,7 +488,7 @@ mod test { let msg = { service.call(inbound_message).await.unwrap(); - let msg = time::timeout(Duration::from_secs(10), out_rx.next()) + let msg = time::timeout(Duration::from_secs(10), out_rx.recv()) .await .unwrap() .unwrap(); @@ -497,7 +498,7 @@ mod test { assert_eq!(msg, b"secret"); } - #[tokio_macros::test_basic] + #[runtime::test] async fn stack_encrypted() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -537,7 +538,7 @@ mod test { let msg = { service.call(inbound_message).await.unwrap(); - let msg = time::timeout(Duration::from_secs(10), out_rx.next()) + let msg = time::timeout(Duration::from_secs(10), out_rx.recv()) .await .unwrap() .unwrap(); @@ -547,7 +548,7 @@ mod test { assert_eq!(msg, b"secret"); } - #[tokio_macros::test_basic] + #[runtime::test] async fn stack_forward() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -556,7 +557,6 @@ mod test { peer_manager.add_peer(node_identity.to_peer()).await.unwrap(); let (connectivity, _) = create_connectivity_mock(); - let (next_service_tx, mut next_service_rx) = mpsc::channel(10); let (oms_requester, oms_mock) = create_outbound_service_mock(1); // Send all outbound requests to the mock @@ -573,7 +573,8 @@ mod test { let oms_mock_state = oms_mock.get_state(); task::spawn(oms_mock.run()); - let mut service = 
dht.inbound_middleware_layer().layer(SinkService::new(next_service_tx)); + let spy = service_spy(); + let mut service = dht.inbound_middleware_layer().layer(spy.to_service()); let msg = wrap_in_envelope_body!(b"unencrypteable".to_vec()); @@ -602,10 +603,10 @@ mod test { assert_eq!(params.dht_header.unwrap().origin_mac, origin_mac); // Check the next service was not called - assert!(next_service_rx.try_next().is_err()); + assert_eq!(spy.call_count(), 0); } - #[tokio_macros::test_basic] + #[runtime::test] async fn stack_filter_saf_message() { let node_identity = make_client_identity(); let peer_manager = build_peer_manager(); @@ -628,9 +629,8 @@ mod test { .await .unwrap(); - let (next_service_tx, mut next_service_rx) = mpsc::channel(10); - - let mut service = dht.inbound_middleware_layer().layer(SinkService::new(next_service_tx)); + let spy = service_spy(); + let mut service = dht.inbound_middleware_layer().layer(spy.to_service()); let msg = wrap_in_envelope_body!(b"secret".to_vec()); let mut dht_envelope = make_dht_envelope( @@ -647,10 +647,6 @@ mod test { let inbound_message = make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().into()); service.call(inbound_message).await.unwrap_err(); - // This seems like the best way to tell that an open channel is empty without the test blocking indefinitely - assert_eq!( - format!("{}", next_service_rx.try_next().unwrap_err()), - "receiver channel is empty" - ); + assert_eq!(spy.call_count(), 0); } } diff --git a/comms/dht/src/discovery/error.rs b/comms/dht/src/discovery/error.rs index cf98e42d9d..c2a77f0c9a 100644 --- a/comms/dht/src/discovery/error.rs +++ b/comms/dht/src/discovery/error.rs @@ -21,9 +21,9 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::outbound::{message::SendFailure, DhtOutboundError}; -use futures::channel::mpsc::SendError; use tari_comms::peer_manager::PeerManagerError; use thiserror::Error; +use tokio::sync::mpsc::error::SendError; #[derive(Debug, Error)] pub enum DhtDiscoveryError { @@ -37,8 +37,6 @@ pub enum DhtDiscoveryError { InvalidNodeId, #[error("MPSC channel is disconnected")] ChannelDisconnected, - #[error("MPSC sender was unable to send because the channel buffer is full")] - SendBufferFull, #[error("The discovery request timed out")] DiscoveryTimeout, #[error("Failed to send discovery message: {0}")] @@ -56,14 +54,8 @@ impl DhtDiscoveryError { } } -impl From for DhtDiscoveryError { - fn from(err: SendError) -> Self { - if err.is_disconnected() { - DhtDiscoveryError::ChannelDisconnected - } else if err.is_full() { - DhtDiscoveryError::SendBufferFull - } else { - unreachable!(); - } +impl From> for DhtDiscoveryError { + fn from(_: SendError) -> Self { + DhtDiscoveryError::ChannelDisconnected } } diff --git a/comms/dht/src/discovery/requester.rs b/comms/dht/src/discovery/requester.rs index a286bcc6a5..a7317f79a2 100644 --- a/comms/dht/src/discovery/requester.rs +++ b/comms/dht/src/discovery/requester.rs @@ -21,16 +21,15 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
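Note: the From<SendError<T>> impl above reflects how tokio 1.x bounded channels report errors: send().await parks until capacity is available and only fails once the receiver is gone, so the old SendBufferFull variant has no async-send equivalent; back-pressure surfaces as TrySendError::Full from try_send instead. A small illustrative sketch of that behaviour, independent of this codebase:

use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TrySendError;

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::channel::<u8>(1);

    tx.send(1).await.expect("receiver still alive");
    // The buffer is full now, so a non-blocking send reports Full rather than failing.
    assert!(matches!(tx.try_send(2), Err(TrySendError::Full(2))));

    drop(rx);
    // Once the receiver is dropped, send() fails and hands back the rejected value.
    let err = tx.send(3).await.unwrap_err();
    assert_eq!(err.0, 3);
}
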
use crate::{discovery::DhtDiscoveryError, envelope::NodeDestination, proto::dht::DiscoveryResponseMessage}; -use futures::{ - channel::{mpsc, oneshot}, - SinkExt, -}; use std::{ fmt::{Display, Error, Formatter}, time::Duration, }; use tari_comms::{peer_manager::Peer, types::CommsPublicKey}; -use tokio::time; +use tokio::{ + sync::{mpsc, oneshot}, + time, +}; #[derive(Debug)] pub enum DhtDiscoveryRequest { diff --git a/comms/dht/src/discovery/service.rs b/comms/dht/src/discovery/service.rs index 258eb67cea..2cebf571b4 100644 --- a/comms/dht/src/discovery/service.rs +++ b/comms/dht/src/discovery/service.rs @@ -27,11 +27,6 @@ use crate::{ proto::dht::{DiscoveryMessage, DiscoveryResponseMessage}, DhtConfig, }; -use futures::{ - channel::{mpsc, oneshot}, - future::FutureExt, - StreamExt, -}; use log::*; use rand::{rngs::OsRng, RngCore}; use std::{ @@ -47,7 +42,11 @@ use tari_comms::{ }; use tari_shutdown::ShutdownSignal; use tari_utilities::{hex::Hex, ByteArray}; -use tokio::{task, time}; +use tokio::{ + sync::{mpsc, oneshot}, + task, + time, +}; const LOG_TARGET: &str = "comms::dht::discovery_service"; @@ -72,8 +71,8 @@ pub struct DhtDiscoveryService { node_identity: Arc, outbound_requester: OutboundMessageRequester, peer_manager: Arc, - request_rx: Option>, - shutdown_signal: Option, + request_rx: mpsc::Receiver, + shutdown_signal: ShutdownSignal, inflight_discoveries: HashMap, } @@ -91,8 +90,8 @@ impl DhtDiscoveryService { outbound_requester, node_identity, peer_manager, - shutdown_signal: Some(shutdown_signal), - request_rx: Some(request_rx), + shutdown_signal, + request_rx, inflight_discoveries: HashMap::new(), } } @@ -106,29 +105,19 @@ impl DhtDiscoveryService { pub async fn run(mut self) { info!(target: LOG_TARGET, "Dht discovery service started"); - let mut shutdown_signal = self - .shutdown_signal - .take() - .expect("DiscoveryService initialized without shutdown_signal") - .fuse(); - - let mut request_rx = self - .request_rx - .take() - .expect("DiscoveryService initialized without request_rx") - .fuse(); - loop { - futures::select! { - request = request_rx.select_next_some() => { - trace!(target: LOG_TARGET, "Received request '{}'", request); - self.handle_request(request).await; - }, + tokio::select! 
{ + biased; - _ = shutdown_signal => { + _ = self.shutdown_signal.wait() => { info!(target: LOG_TARGET, "Discovery service is shutting down because the shutdown signal was received"); break; } + + Some(request) = self.request_rx.recv() => { + trace!(target: LOG_TARGET, "Received request '{}'", request); + self.handle_request(request).await; + }, } } } @@ -153,7 +142,7 @@ impl DhtDiscoveryService { let mut remaining_requests = HashMap::new(); for (nonce, request) in self.inflight_discoveries.drain() { // Exclude canceled requests - if request.reply_tx.is_canceled() { + if request.reply_tx.is_closed() { continue; } @@ -199,7 +188,7 @@ impl DhtDiscoveryService { ); for request in self.collect_all_discovery_requests(&public_key) { - if !reply_tx.is_canceled() { + if !reply_tx.is_closed() { let _ = request.reply_tx.send(Ok(peer.clone())); } } @@ -299,7 +288,7 @@ impl DhtDiscoveryService { self.inflight_discoveries = self .inflight_discoveries .drain() - .filter(|(_, state)| !state.reply_tx.is_canceled()) + .filter(|(_, state)| !state.reply_tx.is_closed()) .collect(); trace!( @@ -393,9 +382,10 @@ mod test { test_utils::{build_peer_manager, make_node_identity}, }; use std::time::Duration; + use tari_comms::runtime; use tari_shutdown::Shutdown; - #[tokio_macros::test_basic] + #[runtime::test] async fn send_discovery() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); diff --git a/comms/dht/src/envelope.rs b/comms/dht/src/envelope.rs index fb7e02050f..c8b18d9e9b 100644 --- a/comms/dht/src/envelope.rs +++ b/comms/dht/src/envelope.rs @@ -22,6 +22,8 @@ use bitflags::bitflags; use bytes::Bytes; +use chrono::{DateTime, NaiveDateTime, Utc}; +use prost_types::Timestamp; use serde::{Deserialize, Serialize}; use std::{ cmp, @@ -30,14 +32,11 @@ use std::{ fmt::Display, }; use tari_comms::{message::MessageTag, peer_manager::NodeId, types::CommsPublicKey, NodeIdentity}; -use tari_utilities::{ByteArray, ByteArrayError}; +use tari_utilities::{epoch_time::EpochTime, ByteArray, ByteArrayError}; use thiserror::Error; // Re-export applicable protos pub use crate::proto::envelope::{dht_header::Destination, DhtEnvelope, DhtHeader, DhtMessageType}; -use chrono::{DateTime, NaiveDateTime, Utc}; -use prost_types::Timestamp; -use tari_utilities::epoch_time::EpochTime; /// Utility function that converts a `chrono::DateTime` to a `prost::Timestamp` pub(crate) fn datetime_to_timestamp(datetime: DateTime) -> Timestamp { diff --git a/comms/dht/src/inbound/decryption.rs b/comms/dht/src/inbound/decryption.rs index 9e3a6bbd3e..46fc3daa47 100644 --- a/comms/dht/src/inbound/decryption.rs +++ b/comms/dht/src/inbound/decryption.rs @@ -397,7 +397,12 @@ mod test { }; use futures::{executor::block_on, future}; use std::sync::Mutex; - use tari_comms::{message::MessageExt, test_utils::mocks::create_connectivity_mock, wrap_in_envelope_body}; + use tari_comms::{ + message::MessageExt, + runtime, + test_utils::mocks::create_connectivity_mock, + wrap_in_envelope_body, + }; use tari_test_utils::{counter_context, unpack_enum}; use tower::service_fn; @@ -469,7 +474,7 @@ mod test { assert_eq!(decrypted.decryption_result.unwrap_err(), inbound_msg.body); } - #[tokio_macros::test_basic] + #[runtime::test] async fn decrypt_inbound_fail_destination() { let (connectivity, mock) = create_connectivity_mock(); mock.spawn(); diff --git a/comms/dht/src/inbound/deserialize.rs b/comms/dht/src/inbound/deserialize.rs index b28a057cb5..a73c3b3cfb 100644 --- a/comms/dht/src/inbound/deserialize.rs +++ 
b/comms/dht/src/inbound/deserialize.rs @@ -137,9 +137,12 @@ mod test { service_spy, }, }; - use tari_comms::message::{MessageExt, MessageTag}; + use tari_comms::{ + message::{MessageExt, MessageTag}, + runtime, + }; - #[tokio_macros::test_basic] + #[runtime::test] async fn deserialize() { let spy = service_spy(); let peer_manager = build_peer_manager(); diff --git a/comms/dht/src/lib.rs b/comms/dht/src/lib.rs index 7ef238d16e..710b354f7b 100644 --- a/comms/dht/src/lib.rs +++ b/comms/dht/src/lib.rs @@ -71,7 +71,7 @@ //! #use std::sync::Arc; //! #use tari_comms::CommsBuilder; //! #use tokio::runtime::Runtime; -//! #use futures::channel::mpsc; +//! #use tokio::sync::mpsc; //! //! let runtime = Runtime::new().unwrap(); //! // Channel from comms to inbound dht diff --git a/comms/dht/src/network_discovery/on_connect.rs b/comms/dht/src/network_discovery/on_connect.rs index b93657f061..cd162b3903 100644 --- a/comms/dht/src/network_discovery/on_connect.rs +++ b/comms/dht/src/network_discovery/on_connect.rs @@ -33,7 +33,7 @@ use crate::{ }; use futures::StreamExt; use log::*; -use std::{convert::TryInto, ops::Deref}; +use std::convert::TryInto; use tari_comms::{ connectivity::ConnectivityEvent, peer_manager::{NodeId, Peer}, @@ -62,8 +62,9 @@ impl OnConnect { pub async fn next_event(&mut self) -> StateEvent { let mut connectivity_events = self.context.connectivity.get_event_subscription(); - while let Some(event) = connectivity_events.next().await { - match event.as_ref().map(|e| e.deref()) { + loop { + let event = connectivity_events.recv().await; + match event { Ok(ConnectivityEvent::PeerConnected(conn)) => { if conn.peer_features().is_client() { continue; @@ -96,10 +97,10 @@ impl OnConnect { self.prev_synced.push(conn.peer_node_id().clone()); }, Ok(_) => { /* Nothing to do */ }, - Err(broadcast::RecvError::Lagged(n)) => { + Err(broadcast::error::RecvError::Lagged(n)) => { warn!(target: LOG_TARGET, "Lagged behind on {} connectivity event(s)", n) }, - Err(broadcast::RecvError::Closed) => { + Err(broadcast::error::RecvError::Closed) => { break; }, } diff --git a/comms/dht/src/network_discovery/test.rs b/comms/dht/src/network_discovery/test.rs index 54f596ee26..2f854627f1 100644 --- a/comms/dht/src/network_discovery/test.rs +++ b/comms/dht/src/network_discovery/test.rs @@ -28,12 +28,12 @@ use crate::{ test_utils::{build_peer_manager, make_node_identity}, DhtConfig, }; -use futures::StreamExt; use std::{iter, sync::Arc, time::Duration}; use tari_comms::{ connectivity::ConnectivityStatus, peer_manager::{Peer, PeerFeatures}, protocol::rpc::{mock::MockRpcServer, NamedProtocolService}, + runtime, test_utils::{ mocks::{create_connectivity_mock, ConnectivityManagerMockState}, node_identity::build_node_identity, @@ -97,7 +97,7 @@ mod state_machine { ) } - #[tokio_macros::test_basic] + #[runtime::test] #[allow(clippy::redundant_closure)] async fn it_fetches_peers() { const NUM_PEERS: usize = 3; @@ -139,7 +139,7 @@ mod state_machine { mock.get_peers.set_response(Ok(peers)).await; discovery_actor.spawn(); - let event = event_rx.next().await.unwrap().unwrap(); + let event = event_rx.recv().await.unwrap(); unpack_enum!(DhtEvent::NetworkDiscoveryPeersAdded(info) = &*event); assert!(info.has_new_neighbours()); assert_eq!(info.num_new_neighbours, NUM_PEERS); @@ -149,11 +149,11 @@ mod state_machine { assert_eq!(info.sync_peers, vec![peer_node_identity.node_id().clone()]); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_shuts_down() { let (discovery, _, _, _, _, mut shutdown) = setup(Default::default(), 
make_node_identity(), vec![]).await; - shutdown.trigger().unwrap(); + shutdown.trigger(); tokio::time::timeout(Duration::from_secs(5), discovery.run()) .await .unwrap(); @@ -200,7 +200,7 @@ mod discovery_ready { (node_identity, peer_manager, connectivity_mock, ready, context) } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_begins_aggressive_discovery() { let (_, pm, _, mut ready, _) = setup(Default::default()); let peers = build_many_node_identities(1, PeerFeatures::COMMUNICATION_NODE); @@ -212,14 +212,14 @@ mod discovery_ready { assert!(params.num_peers_to_request.is_none()); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_idles_if_no_sync_peers() { let (_, _, _, mut ready, _) = setup(Default::default()); let state_event = ready.next_event().await; unpack_enum!(StateEvent::Idle = state_event); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_idles_if_num_rounds_reached() { let config = NetworkDiscoveryConfig { min_desired_peers: 0, @@ -240,7 +240,7 @@ mod discovery_ready { unpack_enum!(StateEvent::Idle = state_event); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_transitions_to_on_connect() { let config = NetworkDiscoveryConfig { min_desired_peers: 0, diff --git a/comms/dht/src/network_discovery/waiting.rs b/comms/dht/src/network_discovery/waiting.rs index 73e8929ac5..f61dfc6b24 100644 --- a/comms/dht/src/network_discovery/waiting.rs +++ b/comms/dht/src/network_discovery/waiting.rs @@ -46,7 +46,7 @@ impl Waiting { target: LOG_TARGET, "Network discovery is IDLING for {:.0?}", self.duration ); - time::delay_for(self.duration).await; + time::sleep(self.duration).await; debug!(target: LOG_TARGET, "Network discovery resuming"); StateEvent::Ready } diff --git a/comms/dht/src/outbound/broadcast.rs b/comms/dht/src/outbound/broadcast.rs index 3107b41fc5..57c423df55 100644 --- a/comms/dht/src/outbound/broadcast.rs +++ b/comms/dht/src/outbound/broadcast.rs @@ -39,7 +39,6 @@ use bytes::Bytes; use chrono::{DateTime, Utc}; use digest::Digest; use futures::{ - channel::oneshot, future, future::BoxFuture, stream::{self, StreamExt}, @@ -60,6 +59,7 @@ use tari_crypto::{ tari_utilities::{message_format::MessageFormat, ByteArray}, }; use tari_utilities::hex::Hex; +use tokio::sync::oneshot; use tower::{layer::Layer, Service, ServiceExt}; const LOG_TARGET: &str = "comms::dht::outbound::broadcast_middleware"; @@ -256,7 +256,7 @@ where S: Service match self.select_peers(broadcast_strategy.clone()).await { Ok(mut peers) => { - if reply_tx.is_canceled() { + if reply_tx.is_closed() { return Err(DhtOutboundError::ReplyChannelCanceled); } @@ -537,19 +537,19 @@ mod test { DhtDiscoveryMockState, }, }; - use futures::channel::oneshot; use rand::rngs::OsRng; use std::time::Duration; use tari_comms::{ multiaddr::Multiaddr, peer_manager::{NodeId, Peer, PeerFeatures, PeerFlags}, + runtime, types::CommsPublicKey, }; use tari_crypto::keys::PublicKey; use tari_test_utils::unpack_enum; - use tokio::task; + use tokio::{sync::oneshot, task}; - #[tokio_macros::test_basic] + #[runtime::test] async fn send_message_flood() { let pk = CommsPublicKey::default(); let example_peer = Peer::new( @@ -613,7 +613,7 @@ mod test { assert!(requests.iter().any(|msg| msg.destination_node_id == other_peer.node_id)); } - #[tokio_macros::test_basic] + #[runtime::test] async fn send_message_direct_not_found() { // Test for issue https://github.com/tari-project/tari/issues/959 @@ -657,7 +657,7 @@ mod test { assert_eq!(spy.call_count(), 0); } - #[tokio_macros::test_basic] + #[runtime::test] 
async fn send_message_direct_dht_discovery() { let node_identity = NodeIdentity::random( &mut OsRng, diff --git a/comms/dht/src/outbound/error.rs b/comms/dht/src/outbound/error.rs index 6e1a8156c4..b919e45134 100644 --- a/comms/dht/src/outbound/error.rs +++ b/comms/dht/src/outbound/error.rs @@ -20,16 +20,16 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::outbound::message::SendFailure; -use futures::channel::mpsc::SendError; +use crate::outbound::{message::SendFailure, DhtOutboundRequest}; use tari_comms::message::MessageError; use tari_crypto::{signatures::SchnorrSignatureError, tari_utilities::message_format::MessageFormatError}; use thiserror::Error; +use tokio::sync::mpsc::error::SendError; #[derive(Debug, Error)] pub enum DhtOutboundError { - #[error("SendError: {0}")] - SendError(#[from] SendError), + #[error("`Failed to send: {0}")] + SendError(#[from] SendError), #[error("MessageSerializationError: {0}")] MessageSerializationError(#[from] MessageError), #[error("MessageFormatError: {0}")] diff --git a/comms/dht/src/outbound/message.rs b/comms/dht/src/outbound/message.rs index 52356ec364..bb782dc2e5 100644 --- a/comms/dht/src/outbound/message.rs +++ b/comms/dht/src/outbound/message.rs @@ -25,7 +25,6 @@ use crate::{ outbound::{message_params::FinalSendMessageParams, message_send_state::MessageSendStates}, }; use bytes::Bytes; -use futures::channel::oneshot; use std::{fmt, fmt::Display, sync::Arc}; use tari_comms::{ message::{MessageTag, MessagingReplyTx}, @@ -34,6 +33,7 @@ use tari_comms::{ }; use tari_utilities::hex::Hex; use thiserror::Error; +use tokio::sync::oneshot; /// Determines if an outbound message should be Encrypted and, if so, for which public key #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/comms/dht/src/outbound/message_send_state.rs b/comms/dht/src/outbound/message_send_state.rs index 1576e87c70..b3ca43fcb2 100644 --- a/comms/dht/src/outbound/message_send_state.rs +++ b/comms/dht/src/outbound/message_send_state.rs @@ -250,9 +250,9 @@ impl Index for MessageSendStates { #[cfg(test)] mod test { use super::*; - use futures::channel::oneshot; use std::iter::repeat_with; - use tari_comms::message::MessagingReplyTx; + use tari_comms::{message::MessagingReplyTx, runtime}; + use tokio::sync::oneshot; fn create_send_state() -> (MessageSendState, MessagingReplyTx) { let (reply_tx, reply_rx) = oneshot::channel(); @@ -269,7 +269,7 @@ mod test { assert!(!states.is_empty()); } - #[tokio_macros::test_basic] + #[runtime::test] async fn wait_single() { let (state, mut reply_tx) = create_send_state(); let states = MessageSendStates::from(vec![state]); @@ -284,7 +284,7 @@ mod test { assert!(!states.wait_single().await); } - #[tokio_macros::test_basic] + #[runtime::test] #[allow(clippy::redundant_closure)] async fn wait_percentage_success() { let states = repeat_with(|| create_send_state()).take(10).collect::>(); @@ -300,7 +300,7 @@ mod test { assert_eq!(failed.len(), 4); } - #[tokio_macros::test_basic] + #[runtime::test] #[allow(clippy::redundant_closure)] async fn wait_n_timeout() { let states = repeat_with(|| create_send_state()).take(10).collect::>(); @@ -329,7 +329,7 @@ mod test { assert_eq!(failed.len(), 6); } - #[tokio_macros::test_basic] + #[runtime::test] #[allow(clippy::redundant_closure)] async fn wait_all() { let states = repeat_with(|| create_send_state()).take(10).collect::>(); diff --git 
a/comms/dht/src/outbound/mock.rs b/comms/dht/src/outbound/mock.rs index f5c3f30665..66e2b7258e 100644 --- a/comms/dht/src/outbound/mock.rs +++ b/comms/dht/src/outbound/mock.rs @@ -31,11 +31,6 @@ use crate::{ }, }; use bytes::Bytes; -use futures::{ - channel::{mpsc, oneshot}, - stream::Fuse, - StreamExt, -}; use log::*; use std::{ sync::{Arc, Condvar, Mutex, RwLock}, @@ -45,7 +40,10 @@ use tari_comms::{ message::{MessageTag, MessagingReplyTx}, protocol::messaging::SendFailReason, }; -use tokio::time::delay_for; +use tokio::{ + sync::{mpsc, oneshot}, + time::sleep, +}; const LOG_TARGET: &str = "mock::outbound_requester"; @@ -54,7 +52,7 @@ const LOG_TARGET: &str = "mock::outbound_requester"; /// Each time a request is expected, handle_next should be called. pub fn create_outbound_service_mock(size: usize) -> (OutboundMessageRequester, OutboundServiceMock) { let (tx, rx) = mpsc::channel(size); - (OutboundMessageRequester::new(tx), OutboundServiceMock::new(rx.fuse())) + (OutboundMessageRequester::new(tx), OutboundServiceMock::new(rx)) } #[derive(Clone, Default)] @@ -149,12 +147,12 @@ impl OutboundServiceMockState { } pub struct OutboundServiceMock { - receiver: Fuse>, + receiver: mpsc::Receiver, mock_state: OutboundServiceMockState, } impl OutboundServiceMock { - pub fn new(receiver: Fuse>) -> Self { + pub fn new(receiver: mpsc::Receiver) -> Self { Self { receiver, mock_state: OutboundServiceMockState::new(), @@ -166,7 +164,7 @@ impl OutboundServiceMock { } pub async fn run(mut self) { - while let Some(req) = self.receiver.next().await { + while let Some(req) = self.receiver.recv().await { match req { DhtOutboundRequest::SendMessage(params, body, reply_tx) => { let behaviour = self.mock_state.get_behaviour(); @@ -192,7 +190,7 @@ impl OutboundServiceMock { ResponseType::QueuedSuccessDelay(delay) => { let (response, mut inner_reply_tx) = self.add_call((*params).clone(), body); reply_tx.send(response).expect("Reply channel cancelled"); - delay_for(delay).await; + sleep(delay).await; inner_reply_tx.reply_success(); }, resp => { diff --git a/comms/dht/src/outbound/requester.rs b/comms/dht/src/outbound/requester.rs index c1957a98b0..f8536c9d9c 100644 --- a/comms/dht/src/outbound/requester.rs +++ b/comms/dht/src/outbound/requester.rs @@ -32,12 +32,9 @@ use crate::{ MessageSendStates, }, }; -use futures::{ - channel::{mpsc, oneshot}, - SinkExt, -}; use log::*; use tari_comms::{message::MessageExt, peer_manager::NodeId, types::CommsPublicKey, wrap_in_envelope_body}; +use tokio::sync::{mpsc, oneshot}; const LOG_TARGET: &str = "comms::dht::requests::outbound"; diff --git a/comms/dht/src/outbound/serialize.rs b/comms/dht/src/outbound/serialize.rs index 97c7c0df58..195d2d3d39 100644 --- a/comms/dht/src/outbound/serialize.rs +++ b/comms/dht/src/outbound/serialize.rs @@ -137,9 +137,9 @@ mod test { use super::*; use crate::test_utils::{assert_send_static_service, create_outbound_message, service_spy}; use prost::Message; - use tari_comms::peer_manager::NodeId; + use tari_comms::{peer_manager::NodeId, runtime}; - #[tokio_macros::test_basic] + #[runtime::test] async fn serialize() { let spy = service_spy(); let mut serialize = SerializeLayer.layer(spy.to_service::()); diff --git a/comms/dht/src/rpc/service.rs b/comms/dht/src/rpc/service.rs index e762ed4d79..84aef6e5ff 100644 --- a/comms/dht/src/rpc/service.rs +++ b/comms/dht/src/rpc/service.rs @@ -24,16 +24,16 @@ use crate::{ proto::rpc::{GetCloserPeersRequest, GetPeersRequest, GetPeersResponse}, rpc::DhtRpcService, }; -use futures::{channel::mpsc, stream, 
SinkExt}; use log::*; use std::{cmp, sync::Arc}; use tari_comms::{ peer_manager::{NodeId, Peer, PeerFeatures, PeerQuery}, protocol::rpc::{Request, RpcError, RpcStatus, Streaming}, + utils, PeerManager, }; use tari_utilities::ByteArray; -use tokio::task; +use tokio::{sync::mpsc, task}; const LOG_TARGET: &str = "comms::dht::rpc"; @@ -56,17 +56,15 @@ impl DhtRpcServiceImpl { // A maximum buffer size of 10 is selected arbitrarily and is to allow the producer/consumer some room to // buffer. - let (mut tx, rx) = mpsc::channel(cmp::min(10, peers.len() as usize)); + let (tx, rx) = mpsc::channel(cmp::min(10, peers.len() as usize)); task::spawn(async move { let iter = peers .into_iter() .map(|peer| GetPeersResponse { peer: Some(peer.into()), }) - .map(Ok) .map(Ok); - let mut stream = stream::iter(iter); - let _ = tx.send_all(&mut stream).await; + let _ = utils::mpsc::send_all(&tx, iter).await; }); Streaming::new(rx) diff --git a/comms/dht/src/rpc/test.rs b/comms/dht/src/rpc/test.rs index 764d49ba7f..cd70d65a0f 100644 --- a/comms/dht/src/rpc/test.rs +++ b/comms/dht/src/rpc/test.rs @@ -26,13 +26,16 @@ use crate::{ test_utils::build_peer_manager, }; use futures::StreamExt; -use std::{convert::TryInto, sync::Arc}; +use std::{convert::TryInto, sync::Arc, time::Duration}; use tari_comms::{ - peer_manager::{node_id::NodeDistance, PeerFeatures}, + peer_manager::{node_id::NodeDistance, NodeId, Peer, PeerFeatures}, protocol::rpc::{mock::RpcRequestMock, RpcStatusCode}, + runtime, test_utils::node_identity::{build_node_identity, ordered_node_identities_by_distance}, PeerManager, }; +use tari_test_utils::collect_recv; +use tari_utilities::ByteArray; fn setup() -> (DhtRpcServiceImpl, RpcRequestMock, Arc) { let peer_manager = build_peer_manager(); @@ -45,10 +48,8 @@ fn setup() -> (DhtRpcServiceImpl, RpcRequestMock, Arc) { // Unit tests for get_closer_peers request mod get_closer_peers { use super::*; - use tari_comms::peer_manager::{NodeId, Peer}; - use tari_utilities::ByteArray; - #[tokio_macros::test_basic] + #[runtime::test] async fn it_returns_empty_peer_stream() { let (service, mock, _) = setup(); let node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -66,7 +67,7 @@ mod get_closer_peers { assert!(next.is_none()); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_returns_closest_peers() { let (service, mock, peer_manager) = setup(); let node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -83,7 +84,7 @@ mod get_closer_peers { let req = mock.request_with_context(node_identity.node_id().clone(), req); let peers_stream = service.get_closer_peers(req).await.unwrap(); - let results = peers_stream.into_inner().collect::>().await; + let results = collect_recv!(peers_stream.into_inner(), timeout = Duration::from_secs(10)); assert_eq!(results.len(), 10); let peers = results @@ -101,7 +102,7 @@ mod get_closer_peers { } } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_returns_n_peers() { let (service, mock, peer_manager) = setup(); @@ -123,7 +124,7 @@ mod get_closer_peers { assert_eq!(results.len(), 5); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_skips_excluded_peers() { let (service, mock, peer_manager) = setup(); @@ -142,12 +143,12 @@ mod get_closer_peers { let req = mock.request_with_context(node_identity.node_id().clone(), req); let peers_stream = service.get_closer_peers(req).await.unwrap(); - let results = peers_stream.into_inner().collect::>().await; + let results = collect_recv!(peers_stream.into_inner(), timeout = 
Duration::from_secs(10)); let mut peers = results.into_iter().map(Result::unwrap).map(|r| r.peer.unwrap()); assert!(peers.all(|p| p.public_key != excluded_peer.public_key().as_bytes())); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_errors_if_maximum_n_exceeded() { let (service, mock, _) = setup(); let req = GetCloserPeersRequest { @@ -165,9 +166,10 @@ mod get_closer_peers { mod get_peers { use super::*; use crate::proto::rpc::GetPeersRequest; + use std::time::Duration; use tari_comms::{peer_manager::Peer, test_utils::node_identity::build_many_node_identities}; - #[tokio_macros::test_basic] + #[runtime::test] async fn it_returns_empty_peer_stream() { let (service, mock, _) = setup(); let node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -183,7 +185,7 @@ mod get_peers { assert!(next.is_none()); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_returns_all_peers() { let (service, mock, peer_manager) = setup(); let nodes = build_many_node_identities(3, PeerFeatures::COMMUNICATION_NODE); @@ -200,7 +202,7 @@ mod get_peers { .get_peers(mock.request_with_context(Default::default(), req)) .await .unwrap(); - let results = peers_stream.into_inner().collect::>().await; + let results = collect_recv!(peers_stream.into_inner(), timeout = Duration::from_secs(10)); assert_eq!(results.len(), 5); let peers = results @@ -214,7 +216,7 @@ mod get_peers { assert_eq!(peers.iter().filter(|p| p.features.is_node()).count(), 3); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_excludes_clients() { let (service, mock, peer_manager) = setup(); let nodes = build_many_node_identities(3, PeerFeatures::COMMUNICATION_NODE); @@ -231,7 +233,7 @@ mod get_peers { .get_peers(mock.request_with_context(Default::default(), req)) .await .unwrap(); - let results = peers_stream.into_inner().collect::>().await; + let results = collect_recv!(peers_stream.into_inner(), timeout = Duration::from_secs(10)); assert_eq!(results.len(), 3); let peers = results @@ -244,7 +246,7 @@ mod get_peers { assert!(peers.iter().all(|p| p.features.is_node())); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_returns_n_peers() { let (service, mock, peer_manager) = setup(); diff --git a/comms/dht/src/storage/connection.rs b/comms/dht/src/storage/connection.rs index 856a94315a..ee99f8b560 100644 --- a/comms/dht/src/storage/connection.rs +++ b/comms/dht/src/storage/connection.rs @@ -123,16 +123,17 @@ impl DbConnection { mod test { use super::*; use diesel::{expression::sql_literal::sql, sql_types::Integer, RunQueryDsl}; + use tari_comms::runtime; use tari_test_utils::random; - #[tokio_macros::test_basic] + #[runtime::test] async fn connect_and_migrate() { let conn = DbConnection::connect_memory(random::string(8)).await.unwrap(); let output = conn.migrate().await.unwrap(); assert!(output.starts_with("Running migration")); } - #[tokio_macros::test_basic] + #[runtime::test] async fn memory_connections() { let id = random::string(8); let conn = DbConnection::connect_memory(id.clone()).await.unwrap(); diff --git a/comms/dht/src/store_forward/database/mod.rs b/comms/dht/src/store_forward/database/mod.rs index 173d00e0ef..58ee06eb9c 100644 --- a/comms/dht/src/store_forward/database/mod.rs +++ b/comms/dht/src/store_forward/database/mod.rs @@ -255,9 +255,10 @@ impl StoreAndForwardDatabase { #[cfg(test)] mod test { use super::*; + use tari_comms::runtime; use tari_test_utils::random; - #[tokio_macros::test_basic] + #[runtime::test] async fn insert_messages() { let conn = 
DbConnection::connect_memory(random::string(8)).await.unwrap(); conn.migrate().await.unwrap(); @@ -277,7 +278,7 @@ mod test { assert_eq!(messages[1].body_hash, msg2.body_hash); } - #[tokio_macros::test_basic] + #[runtime::test] async fn remove_messages() { let conn = DbConnection::connect_memory(random::string(8)).await.unwrap(); conn.migrate().await.unwrap(); @@ -304,7 +305,7 @@ mod test { assert_eq!(messages[0].id, msg2_id); } - #[tokio_macros::test_basic] + #[runtime::test] async fn truncate_messages() { let conn = DbConnection::connect_memory(random::string(8)).await.unwrap(); conn.migrate().await.unwrap(); diff --git a/comms/dht/src/store_forward/forward.rs b/comms/dht/src/store_forward/forward.rs index 013917e6e4..d8de4fe048 100644 --- a/comms/dht/src/store_forward/forward.rs +++ b/comms/dht/src/store_forward/forward.rs @@ -263,14 +263,13 @@ mod test { outbound::mock::create_outbound_service_mock, test_utils::{make_dht_inbound_message, make_node_identity, service_spy}, }; - use futures::{channel::mpsc, executor::block_on}; - use tari_comms::wrap_in_envelope_body; - use tokio::runtime::Runtime; + use tari_comms::{runtime, runtime::task, wrap_in_envelope_body}; + use tokio::sync::mpsc; - #[test] - fn decryption_succeeded() { + #[runtime::test] + async fn decryption_succeeded() { let spy = service_spy(); - let (oms_tx, mut oms_rx) = mpsc::channel(1); + let (oms_tx, _) = mpsc::channel(1); let oms = OutboundMessageRequester::new(oms_tx); let mut service = ForwardLayer::new(oms, true).layer(spy.to_service::()); @@ -281,18 +280,16 @@ mod test { Some(node_identity.public_key().clone()), inbound_msg, ); - block_on(service.call(msg)).unwrap(); + service.call(msg).await.unwrap(); assert!(spy.is_called()); - assert!(oms_rx.try_next().is_err()); } - #[test] - fn decryption_failed() { - let mut rt = Runtime::new().unwrap(); + #[runtime::test] + async fn decryption_failed() { let spy = service_spy(); let (oms_requester, oms_mock) = create_outbound_service_mock(1); let oms_mock_state = oms_mock.get_state(); - rt.spawn(oms_mock.run()); + task::spawn(oms_mock.run()); let mut service = ForwardLayer::new(oms_requester, true).layer(spy.to_service::()); @@ -305,7 +302,7 @@ mod test { ); let header = inbound_msg.dht_header.clone(); let msg = DecryptedDhtMessage::failed(inbound_msg); - rt.block_on(service.call(msg)).unwrap(); + service.call(msg).await.unwrap(); assert!(spy.is_called()); assert_eq!(oms_mock_state.call_count(), 1); diff --git a/comms/dht/src/store_forward/saf_handler/layer.rs b/comms/dht/src/store_forward/saf_handler/layer.rs index 50b6ab7839..16e2760a1e 100644 --- a/comms/dht/src/store_forward/saf_handler/layer.rs +++ b/comms/dht/src/store_forward/saf_handler/layer.rs @@ -27,9 +27,9 @@ use crate::{ outbound::OutboundMessageRequester, store_forward::StoreAndForwardRequester, }; -use futures::channel::mpsc; use std::sync::Arc; use tari_comms::peer_manager::{NodeIdentity, PeerManager}; +use tokio::sync::mpsc; use tower::layer::Layer; pub struct MessageHandlerLayer { diff --git a/comms/dht/src/store_forward/saf_handler/middleware.rs b/comms/dht/src/store_forward/saf_handler/middleware.rs index 578fc1dcbc..641950e4f1 100644 --- a/comms/dht/src/store_forward/saf_handler/middleware.rs +++ b/comms/dht/src/store_forward/saf_handler/middleware.rs @@ -28,12 +28,13 @@ use crate::{ outbound::OutboundMessageRequester, store_forward::StoreAndForwardRequester, }; -use futures::{channel::mpsc, future::BoxFuture, task::Context}; +use futures::{future::BoxFuture, task::Context}; use std::{sync::Arc, 
task::Poll}; use tari_comms::{ peer_manager::{NodeIdentity, PeerManager}, pipeline::PipelineError, }; +use tokio::sync::mpsc; use tower::Service; #[derive(Clone)] diff --git a/comms/dht/src/store_forward/saf_handler/task.rs b/comms/dht/src/store_forward/saf_handler/task.rs index bcba88493c..c6224a7af7 100644 --- a/comms/dht/src/store_forward/saf_handler/task.rs +++ b/comms/dht/src/store_forward/saf_handler/task.rs @@ -41,7 +41,7 @@ use crate::{ }; use chrono::{DateTime, NaiveDateTime, Utc}; use digest::Digest; -use futures::{channel::mpsc, future, stream, SinkExt, StreamExt}; +use futures::{future, stream, StreamExt}; use log::*; use prost::Message; use std::{convert::TryInto, sync::Arc}; @@ -53,6 +53,7 @@ use tari_comms::{ utils::signature, }; use tari_utilities::{convert::try_convert_all, ByteArray}; +use tokio::sync::mpsc; use tower::{Service, ServiceExt}; const LOG_TARGET: &str = "comms::dht::storeforward::handler"; @@ -582,14 +583,13 @@ mod test { }, }; use chrono::{Duration as OldDuration, Utc}; - use futures::channel::mpsc; use prost::Message; use std::time::Duration; - use tari_comms::{message::MessageExt, wrap_in_envelope_body}; + use tari_comms::{message::MessageExt, runtime, wrap_in_envelope_body}; use tari_crypto::tari_utilities::hex; - use tari_test_utils::collect_stream; + use tari_test_utils::collect_recv; use tari_utilities::hex::Hex; - use tokio::{runtime::Handle, task, time::delay_for}; + use tokio::{runtime::Handle, sync::mpsc, task, time::sleep}; // TODO: unit tests for static functions (check_signature, etc) @@ -617,7 +617,7 @@ mod test { } } - #[tokio_macros::test] + #[tokio::test] async fn request_stored_messages() { let spy = service_spy(); let (requester, mock_state) = create_store_and_forward_mock(); @@ -677,7 +677,7 @@ mod test { if oms_mock_state.call_count() >= 1 { break; } - delay_for(Duration::from_secs(5)).await; + sleep(Duration::from_secs(5)).await; } assert_eq!(oms_mock_state.call_count(), 1); @@ -739,7 +739,7 @@ mod test { if oms_mock_state.call_count() >= 1 { break; } - delay_for(Duration::from_secs(5)).await; + sleep(Duration::from_secs(5)).await; } assert_eq!(oms_mock_state.call_count(), 1); let call = oms_mock_state.pop_call().unwrap(); @@ -765,7 +765,7 @@ mod test { assert!(stored_messages.iter().any(|s| s.body == msg2.as_bytes())); } - #[tokio_macros::test_basic] + #[runtime::test] async fn receive_stored_messages() { let rt_handle = Handle::current(); let spy = service_spy(); @@ -860,7 +860,7 @@ mod test { assert!(msgs.contains(&b"A".to_vec())); assert!(msgs.contains(&b"B".to_vec())); assert!(msgs.contains(&b"Clear".to_vec())); - let signals = collect_stream!( + let signals = collect_recv!( saf_response_signal_receiver, take = 1, timeout = Duration::from_secs(20) diff --git a/comms/dht/src/store_forward/service.rs b/comms/dht/src/store_forward/service.rs index 5d06d85d56..7c8af0b731 100644 --- a/comms/dht/src/store_forward/service.rs +++ b/comms/dht/src/store_forward/service.rs @@ -36,12 +36,6 @@ use crate::{ DhtRequester, }; use chrono::{DateTime, NaiveDateTime, Utc}; -use futures::{ - channel::{mpsc, oneshot}, - stream::Fuse, - SinkExt, - StreamExt, -}; use log::*; use std::{convert::TryFrom, sync::Arc, time::Duration}; use tari_comms::{ @@ -51,7 +45,11 @@ use tari_comms::{ PeerManager, }; use tari_shutdown::ShutdownSignal; -use tokio::{task, time}; +use tokio::{ + sync::{mpsc, oneshot}, + task, + time, +}; const LOG_TARGET: &str = "comms::dht::storeforward::actor"; /// The interval to initiate a database cleanup. 
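The store-and-forward service whose imports were just switched to tokio's sync primitives rewrites its run loop further down with tokio::select! over plain receivers, dropping the fused futures streams. A minimal sketch of that loop shape, assuming a tokio 1.x runtime and using illustrative channel types rather than the service's actual request and shutdown types:

use std::time::Duration;
use tokio::{sync::mpsc, time};

// Illustrative request/shutdown channels; the real service has its own request
// enum and a tari_shutdown signal, which are not reproduced here.
async fn run(mut request_rx: mpsc::Receiver<String>, mut shutdown_rx: mpsc::Receiver<()>) {
    // tokio::time::interval replaces the fused interval stream; tick() is awaited directly.
    let mut cleanup_ticker = time::interval(Duration::from_secs(30));
    loop {
        tokio::select! {
            Some(request) = request_rx.recv() => {
                println!("handling request: {}", request);
            },
            _ = cleanup_ticker.tick() => {
                println!("running periodic cleanup");
            },
            _ = shutdown_rx.recv() => {
                println!("shutting down");
                break;
            },
        }
    }
}

#[tokio::main]
async fn main() {
    let (request_tx, request_rx) = mpsc::channel(10);
    let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
    let handle = tokio::spawn(run(request_rx, shutdown_rx));
    request_tx.send("hello".to_string()).await.unwrap();
    shutdown_tx.send(()).await.unwrap();
    handle.await.unwrap();
}

The same reshaping (receiver.recv() in select! arms, time::sleep instead of delay_for) recurs across the DHT services and tests in this patch.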
@@ -167,13 +165,13 @@ pub struct StoreAndForwardService { dht_requester: DhtRequester, database: StoreAndForwardDatabase, peer_manager: Arc, - connection_events: Fuse, + connection_events: ConnectivityEventRx, outbound_requester: OutboundMessageRequester, - request_rx: Fuse>, - shutdown_signal: Option, + request_rx: mpsc::Receiver, + shutdown_signal: ShutdownSignal, num_received_saf_responses: Option, num_online_peers: Option, - saf_response_signal_rx: Fuse>, + saf_response_signal_rx: mpsc::Receiver<()>, event_publisher: DhtEventSender, } @@ -196,13 +194,13 @@ impl StoreAndForwardService { database: StoreAndForwardDatabase::new(conn), peer_manager, dht_requester, - request_rx: request_rx.fuse(), - connection_events: connectivity.get_event_subscription().fuse(), + request_rx, + connection_events: connectivity.get_event_subscription(), outbound_requester, - shutdown_signal: Some(shutdown_signal), + shutdown_signal, num_received_saf_responses: Some(0), num_online_peers: None, - saf_response_signal_rx: saf_response_signal_rx.fuse(), + saf_response_signal_rx, event_publisher, } } @@ -213,20 +211,15 @@ impl StoreAndForwardService { } async fn run(mut self) { - let mut shutdown_signal = self - .shutdown_signal - .take() - .expect("StoreAndForwardActor initialized without shutdown_signal"); - - let mut cleanup_ticker = time::interval(CLEANUP_INTERVAL).fuse(); + let mut cleanup_ticker = time::interval(CLEANUP_INTERVAL); loop { - futures::select! { - request = self.request_rx.select_next_some() => { + tokio::select! { + Some(request) = self.request_rx.recv() => { self.handle_request(request).await; }, - event = self.connection_events.select_next_some() => { + event = self.connection_events.recv() => { if let Ok(event) = event { if let Err(err) = self.handle_connectivity_event(&event).await { error!(target: LOG_TARGET, "Error handling connection manager event: {:?}", err); @@ -234,20 +227,20 @@ impl StoreAndForwardService { } }, - _ = cleanup_ticker.select_next_some() => { + _ = cleanup_ticker.tick() => { if let Err(err) = self.cleanup().await { error!(target: LOG_TARGET, "Error when performing store and forward cleanup: {:?}", err); } }, - _ = self.saf_response_signal_rx.select_next_some() => { + Some(_) = self.saf_response_signal_rx.recv() => { if let Some(n) = self.num_received_saf_responses { self.num_received_saf_responses = Some(n + 1); self.check_saf_response_threshold(); } }, - _ = shutdown_signal => { + _ = self.shutdown_signal.wait() => { info!(target: LOG_TARGET, "StoreAndForwardActor is shutting down because the shutdown signal was triggered"); break; } diff --git a/comms/dht/src/store_forward/store.rs b/comms/dht/src/store_forward/store.rs index c23133f23d..e9b88a37ad 100644 --- a/comms/dht/src/store_forward/store.rs +++ b/comms/dht/src/store_forward/store.rs @@ -462,11 +462,11 @@ mod test { }; use chrono::Utc; use std::time::Duration; - use tari_comms::wrap_in_envelope_body; + use tari_comms::{runtime, wrap_in_envelope_body}; use tari_test_utils::async_assert_eventually; use tari_utilities::hex::Hex; - #[tokio_macros::test_basic] + #[runtime::test] async fn cleartext_message_no_origin() { let (requester, mock_state) = create_store_and_forward_mock(); @@ -486,7 +486,7 @@ mod test { assert_eq!(messages.len(), 0); } - #[tokio_macros::test_basic] + #[runtime::test] async fn decryption_succeeded_no_store() { let (requester, mock_state) = create_store_and_forward_mock(); @@ -514,7 +514,7 @@ mod test { assert_eq!(mock_state.call_count(), 0); } - #[tokio_macros::test_basic] + #[runtime::test] 
async fn decryption_failed_should_store() { let (requester, mock_state) = create_store_and_forward_mock(); let spy = service_spy(); @@ -553,7 +553,7 @@ mod test { assert!(duration.num_seconds() <= 5); } - #[tokio_macros::test_basic] + #[runtime::test] async fn decryption_failed_banned_peer() { let (requester, mock_state) = create_store_and_forward_mock(); let spy = service_spy(); diff --git a/comms/dht/src/test_utils/dht_actor_mock.rs b/comms/dht/src/test_utils/dht_actor_mock.rs index a64292b9ed..4cfd99f209 100644 --- a/comms/dht/src/test_utils/dht_actor_mock.rs +++ b/comms/dht/src/test_utils/dht_actor_mock.rs @@ -25,7 +25,6 @@ use crate::{ actor::{DhtRequest, DhtRequester}, storage::DhtMetadataKey, }; -use futures::{channel::mpsc, stream::Fuse, StreamExt}; use std::{ collections::HashMap, sync::{ @@ -35,11 +34,11 @@ use std::{ }, }; use tari_comms::peer_manager::Peer; -use tokio::task; +use tokio::{sync::mpsc, task}; pub fn create_dht_actor_mock(buf_size: usize) -> (DhtRequester, DhtActorMock) { let (tx, rx) = mpsc::channel(buf_size); - (DhtRequester::new(tx), DhtActorMock::new(rx.fuse())) + (DhtRequester::new(tx), DhtActorMock::new(rx)) } #[derive(Default, Debug, Clone)] @@ -75,12 +74,12 @@ impl DhtMockState { } pub struct DhtActorMock { - receiver: Fuse>, + receiver: mpsc::Receiver, state: DhtMockState, } impl DhtActorMock { - pub fn new(receiver: Fuse>) -> Self { + pub fn new(receiver: mpsc::Receiver) -> Self { Self { receiver, state: DhtMockState::default(), @@ -96,7 +95,7 @@ impl DhtActorMock { } pub async fn run(mut self) { - while let Some(req) = self.receiver.next().await { + while let Some(req) = self.receiver.recv().await { self.handle_request(req).await; } } diff --git a/comms/dht/src/test_utils/dht_discovery_mock.rs b/comms/dht/src/test_utils/dht_discovery_mock.rs index 70575e2ae0..fbdf8e8284 100644 --- a/comms/dht/src/test_utils/dht_discovery_mock.rs +++ b/comms/dht/src/test_utils/dht_discovery_mock.rs @@ -24,7 +24,6 @@ use crate::{ discovery::{DhtDiscoveryRequest, DhtDiscoveryRequester}, test_utils::make_peer, }; -use futures::{channel::mpsc, stream::Fuse, StreamExt}; use log::*; use std::{ sync::{ @@ -35,15 +34,13 @@ use std::{ time::Duration, }; use tari_comms::peer_manager::Peer; +use tokio::sync::mpsc; const LOG_TARGET: &str = "comms::dht::discovery_mock"; pub fn create_dht_discovery_mock(buf_size: usize, timeout: Duration) -> (DhtDiscoveryRequester, DhtDiscoveryMock) { let (tx, rx) = mpsc::channel(buf_size); - ( - DhtDiscoveryRequester::new(tx, timeout), - DhtDiscoveryMock::new(rx.fuse()), - ) + (DhtDiscoveryRequester::new(tx, timeout), DhtDiscoveryMock::new(rx)) } #[derive(Debug, Clone)] @@ -75,12 +72,12 @@ impl DhtDiscoveryMockState { } pub struct DhtDiscoveryMock { - receiver: Fuse>, + receiver: mpsc::Receiver, state: DhtDiscoveryMockState, } impl DhtDiscoveryMock { - pub fn new(receiver: Fuse>) -> Self { + pub fn new(receiver: mpsc::Receiver) -> Self { Self { receiver, state: DhtDiscoveryMockState::new(), @@ -92,7 +89,7 @@ impl DhtDiscoveryMock { } pub async fn run(mut self) { - while let Some(req) = self.receiver.next().await { + while let Some(req) = self.receiver.recv().await { self.handle_request(req).await; } } diff --git a/comms/dht/src/test_utils/store_and_forward_mock.rs b/comms/dht/src/test_utils/store_and_forward_mock.rs index 0dd464c43a..72c0861a6d 100644 --- a/comms/dht/src/test_utils/store_and_forward_mock.rs +++ b/comms/dht/src/test_utils/store_and_forward_mock.rs @@ -23,7 +23,6 @@ use crate::store_forward::{StoreAndForwardRequest, 
StoreAndForwardRequester, StoredMessage}; use chrono::Utc; use digest::Digest; -use futures::{channel::mpsc, stream::Fuse, StreamExt}; use log::*; use rand::{rngs::OsRng, RngCore}; use std::sync::{ @@ -32,14 +31,17 @@ use std::sync::{ }; use tari_comms::types::Challenge; use tari_utilities::hex; -use tokio::{runtime, sync::RwLock}; +use tokio::{ + runtime, + sync::{mpsc, RwLock}, +}; const LOG_TARGET: &str = "comms::dht::discovery_mock"; pub fn create_store_and_forward_mock() -> (StoreAndForwardRequester, StoreAndForwardMockState) { let (tx, rx) = mpsc::channel(10); - let mock = StoreAndForwardMock::new(rx.fuse()); + let mock = StoreAndForwardMock::new(rx); let state = mock.get_shared_state(); runtime::Handle::current().spawn(mock.run()); (StoreAndForwardRequester::new(tx), state) @@ -90,12 +92,12 @@ impl StoreAndForwardMockState { } pub struct StoreAndForwardMock { - receiver: Fuse>, + receiver: mpsc::Receiver, state: StoreAndForwardMockState, } impl StoreAndForwardMock { - pub fn new(receiver: Fuse>) -> Self { + pub fn new(receiver: mpsc::Receiver) -> Self { Self { receiver, state: StoreAndForwardMockState::new(), @@ -107,7 +109,7 @@ impl StoreAndForwardMock { } pub async fn run(mut self) { - while let Some(req) = self.receiver.next().await { + while let Some(req) = self.receiver.recv().await { self.handle_request(req).await; } } diff --git a/comms/dht/tests/dht.rs b/comms/dht/tests/dht.rs index 44e9f42718..761bb1badd 100644 --- a/comms/dht/tests/dht.rs +++ b/comms/dht/tests/dht.rs @@ -20,7 +20,6 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::{channel::mpsc, StreamExt}; use rand::rngs::OsRng; use std::{sync::Arc, time::Duration}; use tari_comms::{ @@ -55,13 +54,16 @@ use tari_storage::{ }; use tari_test_utils::{ async_assert_eventually, - collect_stream, + collect_try_recv, paths::create_temporary_data_path, random, streams, unpack_enum, }; -use tokio::{sync::broadcast, time}; +use tokio::{ + sync::{broadcast, mpsc}, + time, +}; use tower::ServiceBuilder; struct TestNode { @@ -87,11 +89,11 @@ impl TestNode { } pub async fn next_inbound_message(&mut self, timeout: Duration) -> Option { - time::timeout(timeout, self.inbound_messages.next()).await.ok()? + time::timeout(timeout, self.inbound_messages.recv()).await.ok()? 
} pub async fn shutdown(mut self) { - self.shutdown.trigger().unwrap(); + self.shutdown.trigger(); self.comms.wait_until_shutdown().await; } } @@ -230,7 +232,7 @@ fn dht_config() -> DhtConfig { config } -#[tokio_macros::test] +#[tokio::test] #[allow(non_snake_case)] async fn dht_join_propagation() { // Create 3 nodes where only Node B knows A and C, but A and C want to talk to each other @@ -299,7 +301,7 @@ async fn dht_join_propagation() { node_C.shutdown().await; } -#[tokio_macros::test] +#[tokio::test] #[allow(non_snake_case)] async fn dht_discover_propagation() { // Create 4 nodes where A knows B, B knows A and C, C knows B and D, and D knows C @@ -373,7 +375,7 @@ async fn dht_discover_propagation() { assert!(node_D_peer_manager.exists(node_A.node_identity().public_key()).await); } -#[tokio_macros::test] +#[tokio::test] #[allow(non_snake_case)] async fn dht_store_forward() { let node_C_node_identity = make_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -431,7 +433,7 @@ async fn dht_store_forward() { .unwrap(); // Wait for node B to receive 2 propagation messages - collect_stream!(node_B_msg_events, take = 2, timeout = Duration::from_secs(20)); + collect_try_recv!(node_B_msg_events, take = 2, timeout = Duration::from_secs(20)); let mut node_C = make_node_with_node_identity("node_C", node_C_node_identity, dht_config(), Some(node_B.to_peer())).await; @@ -451,8 +453,8 @@ async fn dht_store_forward() { .await .unwrap(); // Wait for node C to and receive a response from the SAF request - let event = collect_stream!(node_C_msg_events, take = 1, timeout = Duration::from_secs(20)); - unpack_enum!(MessagingEvent::MessageReceived(_node_id, _msg) = &**event.get(0).unwrap().as_ref().unwrap()); + let event = collect_try_recv!(node_C_msg_events, take = 1, timeout = Duration::from_secs(20)); + unpack_enum!(MessagingEvent::MessageReceived(_node_id, _msg) = &*event.get(0).unwrap().as_ref()); let msg = node_C.next_inbound_message(Duration::from_secs(5)).await.unwrap(); assert_eq!( @@ -480,15 +482,15 @@ async fn dht_store_forward() { assert!(msgs.is_empty()); // Check that Node C emitted the StoreAndForwardMessagesReceived event when it went Online - let event = collect_stream!(node_C_dht_events, take = 1, timeout = Duration::from_secs(20)); - unpack_enum!(DhtEvent::StoreAndForwardMessagesReceived = &**event.get(0).unwrap().as_ref().unwrap()); + let event = collect_try_recv!(node_C_dht_events, take = 1, timeout = Duration::from_secs(20)); + unpack_enum!(DhtEvent::StoreAndForwardMessagesReceived = &*event.get(0).unwrap().as_ref()); node_A.shutdown().await; node_B.shutdown().await; node_C.shutdown().await; } -#[tokio_macros::test] +#[tokio::test] #[allow(non_snake_case)] async fn dht_propagate_dedup() { let mut config = dht_config(); @@ -599,28 +601,28 @@ async fn dht_propagate_dedup() { node_D.shutdown().await; // Check the message flow BEFORE deduping - let received = filter_received(collect_stream!(node_A_messaging, timeout = Duration::from_secs(20))); + let received = filter_received(collect_try_recv!(node_A_messaging, timeout = Duration::from_secs(20))); // Expected race condition: If A->(B|C)->(C|B) before A->(C|B) then (C|B)->A if !received.is_empty() { assert_eq!(count_messages_received(&received, &[&node_B_id, &node_C_id]), 1); } - let received = filter_received(collect_stream!(node_B_messaging, timeout = Duration::from_secs(20))); + let received = filter_received(collect_try_recv!(node_B_messaging, timeout = Duration::from_secs(20))); let recv_count = count_messages_received(&received, 
&[&node_A_id, &node_C_id]); // Expected race condition: If A->B->C before A->C then C->B does not happen assert!((1..=2).contains(&recv_count)); - let received = filter_received(collect_stream!(node_C_messaging, timeout = Duration::from_secs(20))); + let received = filter_received(collect_try_recv!(node_C_messaging, timeout = Duration::from_secs(20))); let recv_count = count_messages_received(&received, &[&node_A_id, &node_B_id]); assert_eq!(recv_count, 2); assert_eq!(count_messages_received(&received, &[&node_D_id]), 0); - let received = filter_received(collect_stream!(node_D_messaging, timeout = Duration::from_secs(20))); + let received = filter_received(collect_try_recv!(node_D_messaging, timeout = Duration::from_secs(20))); assert_eq!(received.len(), 1); assert_eq!(count_messages_received(&received, &[&node_C_id]), 1); } -#[tokio_macros::test] +#[tokio::test] #[allow(non_snake_case)] async fn dht_repropagate() { let mut config = dht_config(); @@ -723,7 +725,7 @@ async fn dht_repropagate() { node_C.shutdown().await; } -#[tokio_macros::test] +#[tokio::test] #[allow(non_snake_case)] async fn dht_propagate_message_contents_not_malleable_ban() { let node_C = make_node("node_C", PeerFeatures::COMMUNICATION_NODE, dht_config(), None).await; @@ -812,10 +814,10 @@ async fn dht_propagate_message_contents_not_malleable_ban() { let node_B_node_id = node_B.node_identity().node_id().clone(); // Node C should ban node B - let banned_node_id = streams::assert_in_stream( + let banned_node_id = streams::assert_in_broadcast( &mut connectivity_events, - |r| match &*r.unwrap() { - ConnectivityEvent::PeerBanned(node_id) => Some(node_id.clone()), + |r| match r { + ConnectivityEvent::PeerBanned(node_id) => Some(node_id), _ => None, }, Duration::from_secs(10), @@ -828,12 +830,9 @@ async fn dht_propagate_message_contents_not_malleable_ban() { node_C.shutdown().await; } -fn filter_received( - events: Vec, tokio::sync::broadcast::RecvError>>, -) -> Vec> { +fn filter_received(events: Vec>) -> Vec> { events .into_iter() - .map(Result::unwrap) .filter(|e| match &**e { MessagingEvent::MessageReceived(_, _) => true, _ => unreachable!(), diff --git a/comms/examples/stress/error.rs b/comms/examples/stress/error.rs index 5cb9be1cb3..e87aae514e 100644 --- a/comms/examples/stress/error.rs +++ b/comms/examples/stress/error.rs @@ -19,10 +19,10 @@ // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -use futures::channel::{mpsc::SendError, oneshot}; use std::io; use tari_comms::{ connectivity::ConnectivityError, + message::OutboundMessage, peer_manager::PeerManagerError, tor, CommsBuilderError, @@ -30,7 +30,11 @@ use tari_comms::{ }; use tari_crypto::tari_utilities::message_format::MessageFormatError; use thiserror::Error; -use tokio::{task, time}; +use tokio::{ + sync::{mpsc::error::SendError, oneshot}, + task, + time, +}; #[derive(Debug, Error)] pub enum Error { @@ -48,12 +52,12 @@ pub enum Error { ConnectivityError(#[from] ConnectivityError), #[error("Message format error: {0}")] MessageFormatError(#[from] MessageFormatError), - #[error("Failed to send message")] - SendError(#[from] SendError), + #[error("Failed to send message: {0}")] + SendError(#[from] SendError), #[error("JoinError: {0}")] JoinError(#[from] task::JoinError), #[error("Example did not exit cleanly: `{0}`")] - WaitTimeout(#[from] time::Elapsed), + WaitTimeout(#[from] 
time::error::Elapsed), #[error("IO error: {0}")] Io(#[from] io::Error), #[error("User quit")] @@ -63,5 +67,5 @@ pub enum Error { #[error("Unexpected EoF")] UnexpectedEof, #[error("Internal reply canceled")] - ReplyCanceled(#[from] oneshot::Canceled), + ReplyCanceled(#[from] oneshot::error::RecvError), } diff --git a/comms/examples/stress/node.rs b/comms/examples/stress/node.rs index 45ad0919ab..d060d18071 100644 --- a/comms/examples/stress/node.rs +++ b/comms/examples/stress/node.rs @@ -21,7 +21,6 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use super::{error::Error, STRESS_PROTOCOL_NAME, TOR_CONTROL_PORT_ADDR, TOR_SOCKS_ADDR}; -use futures::channel::mpsc; use rand::rngs::OsRng; use std::{convert, net::Ipv4Addr, path::Path, sync::Arc, time::Duration}; use tari_comms::{ @@ -43,7 +42,7 @@ use tari_storage::{ lmdb_store::{LMDBBuilder, LMDBConfig}, LMDBWrapper, }; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc}; pub async fn create( node_identity: Option>, diff --git a/comms/examples/stress/service.rs b/comms/examples/stress/service.rs index 3e262cc38b..45e2bc0fd3 100644 --- a/comms/examples/stress/service.rs +++ b/comms/examples/stress/service.rs @@ -23,15 +23,7 @@ use super::error::Error; use crate::stress::{MAX_FRAME_SIZE, STRESS_PROTOCOL_NAME}; use bytes::{Buf, Bytes, BytesMut}; -use futures::{ - channel::{mpsc, oneshot}, - stream, - stream::Fuse, - AsyncReadExt, - AsyncWriteExt, - SinkExt, - StreamExt, -}; +use futures::{stream, SinkExt, StreamExt}; use rand::{rngs::OsRng, RngCore}; use std::{ iter::repeat_with, @@ -43,12 +35,19 @@ use tari_comms::{ message::{InboundMessage, OutboundMessage}, peer_manager::{NodeId, Peer}, protocol::{ProtocolEvent, ProtocolNotification}, + utils, CommsNode, PeerConnection, Substream, }; use tari_crypto::tari_utilities::hex::Hex; -use tokio::{sync::RwLock, task, task::JoinHandle, time}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + sync::{mpsc, oneshot, RwLock}, + task, + task::JoinHandle, + time, +}; pub fn start_service( comms_node: CommsNode, @@ -70,6 +69,7 @@ pub fn start_service( (task::spawn(service.start()), request_tx) } +#[derive(Debug)] pub enum StressTestServiceRequest { BeginProtocol(Peer, StressProtocol, oneshot::Sender>), Shutdown, @@ -135,9 +135,9 @@ impl StressProtocol { } struct StressTestService { - request_rx: Fuse>, + request_rx: mpsc::Receiver, comms_node: CommsNode, - protocol_notif: Fuse>>, + protocol_notif: mpsc::Receiver>, shutdown: bool, inbound_rx: Arc>>, @@ -153,9 +153,9 @@ impl StressTestService { outbound_tx: mpsc::Sender, ) -> Self { Self { - request_rx: request_rx.fuse(), + request_rx, comms_node, - protocol_notif: protocol_notif.fuse(), + protocol_notif, shutdown: false, inbound_rx: Arc::new(RwLock::new(inbound_rx)), outbound_tx, @@ -163,23 +163,21 @@ impl StressTestService { } async fn start(mut self) -> Result<(), Error> { - let mut events = self.comms_node.subscribe_connectivity_events().fuse(); + let mut events = self.comms_node.subscribe_connectivity_events(); loop { - futures::select! { - event = events.select_next_some() => { - if let Ok(event) = event { - println!("{}", event); - } + tokio::select! 
{ + Ok(event) = events.recv() => { + println!("{}", event); }, - request = self.request_rx.select_next_some() => { + Some(request) = self.request_rx.recv() => { if let Err(err) = self.handle_request(request).await { println!("Error: {}", err); } }, - notif = self.protocol_notif.select_next_some() => { + Some(notif) = self.protocol_notif.recv() => { self.handle_protocol_notification(notif).await; }, } @@ -431,7 +429,7 @@ async fn messaging_flood( peer: NodeId, protocol: StressProtocol, inbound_rx: Arc>>, - mut outbound_tx: mpsc::Sender, + outbound_tx: mpsc::Sender, ) -> Result<(), Error> { let start = Instant::now(); let mut counter = 1u32; @@ -441,18 +439,15 @@ async fn messaging_flood( protocol.num_messages * protocol.message_size / 1024 / 1024 ); let outbound_task = task::spawn(async move { - let mut iter = stream::iter( - repeat_with(|| { - counter += 1; - - println!("Send MSG {}", counter); - OutboundMessage::new(peer.clone(), generate_message(counter, protocol.message_size as usize)) - }) - .take(protocol.num_messages as usize) - .map(Ok), - ); - outbound_tx.send_all(&mut iter).await?; - time::delay_for(Duration::from_secs(5)).await; + let iter = repeat_with(|| { + counter += 1; + + println!("Send MSG {}", counter); + OutboundMessage::new(peer.clone(), generate_message(counter, protocol.message_size as usize)) + }) + .take(protocol.num_messages as usize); + utils::mpsc::send_all(&outbound_tx, iter).await?; + time::sleep(Duration::from_secs(5)).await; outbound_tx .send(OutboundMessage::new(peer.clone(), Bytes::from_static(&[0u8; 4]))) .await?; @@ -462,7 +457,7 @@ async fn messaging_flood( let inbound_task = task::spawn(async move { let mut inbound_rx = inbound_rx.write().await; let mut msgs = vec![]; - while let Some(msg) = inbound_rx.next().await { + while let Some(msg) = inbound_rx.recv().await { let msg_id = decode_msg(msg.body); println!("GOT MSG {}", msg_id); if msgs.len() == protocol.num_messages as usize { @@ -497,6 +492,6 @@ fn generate_message(n: u32, size: usize) -> Bytes { fn decode_msg(msg: T) -> u32 { let mut buf = [0u8; 4]; - msg.bytes().copy_to_slice(&mut buf); + msg.chunk().copy_to_slice(&mut buf); u32::from_be_bytes(buf) } diff --git a/comms/examples/stress_test.rs b/comms/examples/stress_test.rs index 71c185aa3e..3a0c04f020 100644 --- a/comms/examples/stress_test.rs +++ b/comms/examples/stress_test.rs @@ -24,13 +24,13 @@ mod stress; use stress::{error::Error, prompt::user_prompt}; use crate::stress::{node, prompt::parse_from_short_str, service, service::StressTestServiceRequest}; -use futures::{channel::oneshot, future, future::Either, SinkExt}; +use futures::{future, future::Either}; use std::{env, net::Ipv4Addr, path::Path, process, sync::Arc, time::Duration}; use tari_crypto::tari_utilities::message_format::MessageFormat; use tempfile::Builder; -use tokio::time; +use tokio::{sync::oneshot, time}; -#[tokio_macros::main] +#[tokio::main] async fn main() { env_logger::init(); match run().await { @@ -99,7 +99,7 @@ async fn run() -> Result<(), Error> { } println!("Stress test service started!"); - let (handle, mut requester) = service::start_service(comms_node, protocol_notif, inbound_rx, outbound_tx); + let (handle, requester) = service::start_service(comms_node, protocol_notif, inbound_rx, outbound_tx); let mut last_peer = peer.as_ref().and_then(parse_from_short_str); diff --git a/comms/examples/tor.rs b/comms/examples/tor.rs index 734ef7718a..9186c69d43 100644 --- a/comms/examples/tor.rs +++ b/comms/examples/tor.rs @@ -1,7 +1,6 @@ use anyhow::anyhow; use bytes::Bytes; 
use chrono::Utc; -use futures::{channel::mpsc, SinkExt, StreamExt}; use rand::{rngs::OsRng, thread_rng, RngCore}; use std::{collections::HashMap, convert::identity, env, net::SocketAddr, path::Path, process, sync::Arc}; use tari_comms::{ @@ -21,7 +20,10 @@ use tari_storage::{ LMDBWrapper, }; use tempfile::Builder; -use tokio::{runtime, sync::broadcast}; +use tokio::{ + runtime, + sync::{broadcast, mpsc}, +}; // Tor example for tari_comms. // @@ -29,7 +31,7 @@ use tokio::{runtime, sync::broadcast}; type Error = anyhow::Error; -#[tokio_macros::main] +#[tokio::main] async fn main() { env_logger::init(); if let Err(err) = run().await { @@ -56,7 +58,7 @@ async fn run() -> Result<(), Error> { println!("Starting comms nodes...",); let temp_dir1 = Builder::new().prefix("tor-example1").tempdir().unwrap(); - let (comms_node1, inbound_rx1, mut outbound_tx1) = setup_node_with_tor( + let (comms_node1, inbound_rx1, outbound_tx1) = setup_node_with_tor( control_port_addr.clone(), temp_dir1.as_ref(), (9098u16, "127.0.0.1:0".parse::().unwrap()), @@ -208,11 +210,11 @@ async fn setup_node_with_tor>( async fn start_ping_ponger( dest_node_id: NodeId, mut inbound_rx: mpsc::Receiver, - mut outbound_tx: mpsc::Sender, + outbound_tx: mpsc::Sender, ) -> Result { let mut inflight_pings = HashMap::new(); let mut counter = 0; - while let Some(msg) = inbound_rx.next().await { + while let Some(msg) = inbound_rx.recv().await { counter += 1; let msg_str = String::from_utf8_lossy(&msg.body); diff --git a/comms/rpc_macros/Cargo.toml b/comms/rpc_macros/Cargo.toml index 3680ed81f9..5bc185bc90 100644 --- a/comms/rpc_macros/Cargo.toml +++ b/comms/rpc_macros/Cargo.toml @@ -13,16 +13,16 @@ edition = "2018" proc-macro = true [dependencies] +tari_comms = { version = "^0.9", path = "../", features = ["rpc"]} + proc-macro2 = "1.0.24" quote = "1.0.7" syn = {version = "1.0.38", features = ["fold"]} -tari_comms = { version = "^0.9", path = "../", features = ["rpc"]} [dev-dependencies] tari_test_utils = {version="^0.9", path="../../infrastructure/test_utils"} futures = "0.3.5" -prost = "0.6.1" -tokio = "0.2.22" -tokio-macros = "0.2.5" +prost = "0.8.0" +tokio = {version = "1", features = ["macros"]} tower-service = "0.3.0" diff --git a/comms/rpc_macros/src/generator.rs b/comms/rpc_macros/src/generator.rs index f3e8cbffd1..5f44066f19 100644 --- a/comms/rpc_macros/src/generator.rs +++ b/comms/rpc_macros/src/generator.rs @@ -215,8 +215,8 @@ impl RpcCodeGenerator { self.inner.ping().await } - pub fn close(&mut self) { - self.inner.close(); + pub async fn close(&mut self) { + self.inner.close().await; } }; diff --git a/comms/rpc_macros/tests/macro.rs b/comms/rpc_macros/tests/macro.rs index f3f05b4481..b41f3f9914 100644 --- a/comms/rpc_macros/tests/macro.rs +++ b/comms/rpc_macros/tests/macro.rs @@ -20,7 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
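The examples and RPC macro tests in this part of the patch move from futures channels on tokio 0.2 to tokio 1, where an mpsc Sender is used through a shared reference (no SinkExt or mut binding required) and a Receiver exposes recv() rather than implementing Stream. A small self-contained sketch of that channel API, with an arbitrary u32 payload and buffer size:

use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<u32>(8);

    // The sender is cloneable and send() takes &self, so no `mut` binding is needed.
    let sender = tx.clone();
    tokio::spawn(async move {
        for n in 0..3 {
            sender.send(n).await.expect("receiver dropped");
        }
    });
    drop(tx); // drop the original sender so recv() returns None once the task finishes

    while let Some(n) = rx.recv().await {
        println!("got {}", n);
    }
}

This is why several function signatures in the diff lose their `mut` on sender parameters while receivers gain one.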
-use futures::{channel::mpsc, SinkExt, StreamExt}; +use futures::StreamExt; use prost::Message; use std::{collections::HashMap, ops::AddAssign, sync::Arc}; use tari_comms::{ @@ -34,7 +34,10 @@ use tari_comms::{ }; use tari_comms_rpc_macros::tari_rpc; use tari_test_utils::unpack_enum; -use tokio::{sync::RwLock, task}; +use tokio::{ + sync::{mpsc, RwLock}, + task, +}; use tower_service::Service; #[tari_rpc(protocol_name = b"/test/protocol/123", server_struct = TestServer, client_struct = TestClient)] @@ -80,7 +83,7 @@ impl Test for TestService { async fn server_streaming(&self, _: Request) -> Result, RpcStatus> { self.add_call("server_streaming").await; - let (mut tx, rx) = mpsc::channel(1); + let (tx, rx) = mpsc::channel(1); tx.send(Ok(1)).await.unwrap(); Ok(Streaming::new(rx)) } @@ -101,7 +104,7 @@ fn it_sets_the_protocol_name() { assert_eq!(TestClient::PROTOCOL_NAME, b"/test/protocol/123"); } -#[tokio_macros::test] +#[tokio::test] async fn it_returns_the_correct_type() { let mut server = TestServer::new(TestService::default()); let resp = server @@ -112,7 +115,7 @@ async fn it_returns_the_correct_type() { assert_eq!(u32::decode(v).unwrap(), 12); } -#[tokio_macros::test] +#[tokio::test] async fn it_correctly_maps_the_method_nums() { let service = TestService::default(); let spy = service.state.clone(); @@ -135,7 +138,7 @@ async fn it_correctly_maps_the_method_nums() { assert_eq!(*spy.read().await.get("unit").unwrap(), 1); } -#[tokio_macros::test] +#[tokio::test] async fn it_returns_an_error_for_invalid_method_nums() { let service = TestService::default(); let mut server = TestServer::new(service); @@ -147,7 +150,7 @@ async fn it_returns_an_error_for_invalid_method_nums() { unpack_enum!(RpcStatusCode::UnsupportedMethod = err.status_code()); } -#[tokio_macros::test] +#[tokio::test] async fn it_generates_client_calls() { let (sock_client, sock_server) = MemorySocket::new_pair(); let client = task::spawn(TestClient::connect(framing::canonical(sock_client, 1024))); diff --git a/comms/src/bounded_executor.rs b/comms/src/bounded_executor.rs index ee65e68476..3e36edb50b 100644 --- a/comms/src/bounded_executor.rs +++ b/comms/src/bounded_executor.rs @@ -145,7 +145,15 @@ impl BoundedExecutor { F::Output: Send + 'static, { let span = span!(Level::TRACE, "bounded_executor::waiting_time"); - let permit = self.semaphore.clone().acquire_owned().instrument(span).await; + // SAFETY: acquire_owned only fails if the semaphore is closed (i.e self.semaphore.close() is called) - this + // never happens in this implementation + let permit = self + .semaphore + .clone() + .acquire_owned() + .instrument(span) + .await + .expect("semaphore closed"); self.do_spawn(permit, future) } @@ -230,9 +238,9 @@ mod test { }, time::Duration, }; - use tokio::time::delay_for; + use tokio::time::sleep; - #[runtime::test_basic] + #[runtime::test] async fn spawn() { let flag = Arc::new(AtomicBool::new(false)); let flag_cloned = flag.clone(); @@ -241,7 +249,7 @@ mod test { // Spawn 1 let task1_fut = executor .spawn(async move { - delay_for(Duration::from_millis(1)).await; + sleep(Duration::from_millis(1)).await; flag_cloned.store(true, Ordering::SeqCst); }) .await; diff --git a/comms/src/builder/comms_node.rs b/comms/src/builder/comms_node.rs index 24c856f5ba..abd71e8952 100644 --- a/comms/src/builder/comms_node.rs +++ b/comms/src/builder/comms_node.rs @@ -46,11 +46,13 @@ use crate::{ CommsBuilder, Substream, }; -use futures::{channel::mpsc, AsyncRead, AsyncWrite}; use log::*; use std::{iter, sync::Arc}; use 
tari_shutdown::ShutdownSignal; -use tokio::sync::broadcast; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{broadcast, mpsc}, +}; const LOG_TARGET: &str = "comms::node"; diff --git a/comms/src/builder/mod.rs b/comms/src/builder/mod.rs index f5acc151a9..5da0d51793 100644 --- a/comms/src/builder/mod.rs +++ b/comms/src/builder/mod.rs @@ -51,10 +51,9 @@ use crate::{ tor, types::CommsDatabase, }; -use futures::channel::mpsc; use std::{fs::File, sync::Arc}; use tari_shutdown::ShutdownSignal; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc}; /// The `CommsBuilder` provides a simple builder API for getting Tari comms p2p messaging up and running. pub struct CommsBuilder { diff --git a/comms/src/builder/tests.rs b/comms/src/builder/tests.rs index d1ae9a0f9a..d4fe8cca97 100644 --- a/comms/src/builder/tests.rs +++ b/comms/src/builder/tests.rs @@ -42,19 +42,16 @@ use crate::{ CommsNode, }; use bytes::Bytes; -use futures::{ - channel::{mpsc, oneshot}, - stream::FuturesUnordered, - AsyncReadExt, - AsyncWriteExt, - SinkExt, - StreamExt, -}; +use futures::stream::FuturesUnordered; use std::{collections::HashSet, convert::identity, hash::Hash, time::Duration}; use tari_shutdown::{Shutdown, ShutdownSignal}; use tari_storage::HashmapDatabase; -use tari_test_utils::{collect_stream, unpack_enum}; -use tokio::{sync::broadcast, task}; +use tari_test_utils::{collect_recv, collect_stream, unpack_enum}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + sync::{broadcast, mpsc, oneshot}, + task, +}; async fn spawn_node( protocols: Protocols, @@ -109,7 +106,7 @@ async fn spawn_node( (comms_node, inbound_rx, outbound_tx, messaging_events_sender) } -#[runtime::test_basic] +#[runtime::test] async fn peer_to_peer_custom_protocols() { static TEST_PROTOCOL: Bytes = Bytes::from_static(b"/tari/test"); static ANOTHER_TEST_PROTOCOL: Bytes = Bytes::from_static(b"/tari/test-again"); @@ -161,9 +158,9 @@ async fn peer_to_peer_custom_protocols() { // Check that both nodes get the PeerConnected event. We subscribe after the nodes are initialized // so we miss those events. 
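Note: the comment above ("we subscribe after the nodes are initialized so we miss those events") reflects how `tokio::sync::broadcast` works: a subscriber only sees values sent after it subscribes, and `recv()` returns `Result<T, RecvError>` directly, which is why the assertions that follow need only a single `unwrap()`. A minimal sketch with illustrative event strings rather than the crate's event types:

use tokio::sync::broadcast;

#[tokio::main]
async fn main() {
    let (tx, mut rx1) = broadcast::channel::<&'static str>(8);

    // rx1 already existed, so it will see this event.
    tx.send("PeerConnected").unwrap();

    // A receiver created after the send never sees it.
    let mut rx2 = tx.subscribe();
    tx.send("PeerDisconnected").unwrap();

    assert_eq!(rx1.recv().await.unwrap(), "PeerConnected");
    assert_eq!(rx1.recv().await.unwrap(), "PeerDisconnected");
    assert_eq!(rx2.recv().await.unwrap(), "PeerDisconnected");
}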
- let next_event = conn_man_events2.next().await.unwrap().unwrap(); + let next_event = conn_man_events2.recv().await.unwrap(); unpack_enum!(ConnectionManagerEvent::PeerConnected(conn2) = &*next_event); - let next_event = conn_man_events1.next().await.unwrap().unwrap(); + let next_event = conn_man_events1.recv().await.unwrap(); unpack_enum!(ConnectionManagerEvent::PeerConnected(_conn) = &*next_event); // Let's speak both our test protocols @@ -176,7 +173,7 @@ async fn peer_to_peer_custom_protocols() { negotiated_substream2.stream.write_all(ANOTHER_TEST_MSG).await.unwrap(); // Read TEST_PROTOCOL message to node 2 from node 1 - let negotiation = test_protocol_rx2.next().await.unwrap(); + let negotiation = test_protocol_rx2.recv().await.unwrap(); assert_eq!(negotiation.protocol, TEST_PROTOCOL); unpack_enum!(ProtocolEvent::NewInboundSubstream(node_id, substream) = negotiation.event); assert_eq!(&node_id, node_identity1.node_id()); @@ -185,7 +182,7 @@ async fn peer_to_peer_custom_protocols() { assert_eq!(buf, TEST_MSG); // Read ANOTHER_TEST_PROTOCOL message to node 1 from node 2 - let negotiation = another_test_protocol_rx1.next().await.unwrap(); + let negotiation = another_test_protocol_rx1.recv().await.unwrap(); assert_eq!(negotiation.protocol, ANOTHER_TEST_PROTOCOL); unpack_enum!(ProtocolEvent::NewInboundSubstream(node_id, substream) = negotiation.event); assert_eq!(&node_id, node_identity2.node_id()); @@ -193,18 +190,18 @@ async fn peer_to_peer_custom_protocols() { substream.read_exact(&mut buf).await.unwrap(); assert_eq!(buf, ANOTHER_TEST_MSG); - shutdown.trigger().unwrap(); + shutdown.trigger(); comms_node1.wait_until_shutdown().await; comms_node2.wait_until_shutdown().await; } -#[runtime::test_basic] +#[runtime::test] async fn peer_to_peer_messaging() { const NUM_MSGS: usize = 100; let shutdown = Shutdown::new(); - let (comms_node1, mut inbound_rx1, mut outbound_tx1, _) = spawn_node(Protocols::new(), shutdown.to_signal()).await; - let (comms_node2, mut inbound_rx2, mut outbound_tx2, messaging_events2) = + let (comms_node1, mut inbound_rx1, outbound_tx1, _) = spawn_node(Protocols::new(), shutdown.to_signal()).await; + let (comms_node2, mut inbound_rx2, outbound_tx2, messaging_events2) = spawn_node(Protocols::new(), shutdown.to_signal()).await; let mut messaging_events2 = messaging_events2.subscribe(); @@ -238,14 +235,14 @@ async fn peer_to_peer_messaging() { outbound_tx1.send(outbound_msg).await.unwrap(); } - let messages1_to_2 = collect_stream!(inbound_rx2, take = NUM_MSGS, timeout = Duration::from_secs(10)); + let messages1_to_2 = collect_recv!(inbound_rx2, take = NUM_MSGS, timeout = Duration::from_secs(10)); let send_results = collect_stream!(replies, take = NUM_MSGS, timeout = Duration::from_secs(10)); send_results.into_iter().for_each(|r| { r.unwrap().unwrap(); }); - let events = collect_stream!(messaging_events2, take = NUM_MSGS, timeout = Duration::from_secs(10)); - events.into_iter().map(Result::unwrap).for_each(|m| { + let events = collect_recv!(messaging_events2, take = NUM_MSGS, timeout = Duration::from_secs(10)); + events.into_iter().for_each(|m| { unpack_enum!(MessagingEvent::MessageReceived(_n, _t) = &*m); }); @@ -258,7 +255,7 @@ async fn peer_to_peer_messaging() { outbound_tx2.send(outbound_msg).await.unwrap(); } - let messages2_to_1 = collect_stream!(inbound_rx1, take = NUM_MSGS, timeout = Duration::from_secs(10)); + let messages2_to_1 = collect_recv!(inbound_rx1, take = NUM_MSGS, timeout = Duration::from_secs(10)); // Check that we got all the messages let check_messages = 
|msgs: Vec| { @@ -279,13 +276,13 @@ async fn peer_to_peer_messaging() { comms_node2.wait_until_shutdown().await; } -#[runtime::test_basic] +#[runtime::test] async fn peer_to_peer_messaging_simultaneous() { const NUM_MSGS: usize = 10; let shutdown = Shutdown::new(); - let (comms_node1, mut inbound_rx1, mut outbound_tx1, _) = spawn_node(Protocols::new(), shutdown.to_signal()).await; - let (comms_node2, mut inbound_rx2, mut outbound_tx2, _) = spawn_node(Protocols::new(), shutdown.to_signal()).await; + let (comms_node1, mut inbound_rx1, outbound_tx1, _) = spawn_node(Protocols::new(), shutdown.to_signal()).await; + let (comms_node2, mut inbound_rx2, outbound_tx2, _) = spawn_node(Protocols::new(), shutdown.to_signal()).await; log::info!( "Peer1 = `{}`, Peer2 = `{}`", @@ -350,8 +347,8 @@ async fn peer_to_peer_messaging_simultaneous() { handle2.await.unwrap(); // Tasks are finished, let's see if all the messages made it though - let messages1_to_2 = collect_stream!(inbound_rx2, take = NUM_MSGS, timeout = Duration::from_secs(10)); - let messages2_to_1 = collect_stream!(inbound_rx1, take = NUM_MSGS, timeout = Duration::from_secs(10)); + let messages1_to_2 = collect_recv!(inbound_rx2, take = NUM_MSGS, timeout = Duration::from_secs(10)); + let messages2_to_1 = collect_recv!(inbound_rx1, take = NUM_MSGS, timeout = Duration::from_secs(10)); assert!(has_unique_elements(messages1_to_2.into_iter().map(|m| m.body))); assert!(has_unique_elements(messages2_to_1.into_iter().map(|m| m.body))); diff --git a/comms/src/compat.rs b/comms/src/compat.rs index 254876f14d..67b53c7c91 100644 --- a/comms/src/compat.rs +++ b/comms/src/compat.rs @@ -27,8 +27,9 @@ use std::{ io, pin::Pin, - task::{self, Poll}, + task::{self, Context, Poll}, }; +use tokio::io::ReadBuf; /// `IoCompat` provides a compatibility shim between the `AsyncRead`/`AsyncWrite` traits provided by /// the `futures` library and those provided by the `tokio` library since they are different and @@ -47,16 +48,16 @@ impl IoCompat { impl tokio::io::AsyncRead for IoCompat where T: futures::io::AsyncRead + Unpin { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context, buf: &mut [u8]) -> Poll> { - futures::io::AsyncRead::poll_read(Pin::new(&mut self.inner), cx, buf) + fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { + futures::io::AsyncRead::poll_read(Pin::new(&mut self.inner), cx, buf.filled_mut()) } } impl futures::io::AsyncRead for IoCompat where T: tokio::io::AsyncRead + Unpin { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context, buf: &mut [u8]) -> Poll> { - tokio::io::AsyncRead::poll_read(Pin::new(&mut self.inner), cx, buf) + fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context, buf: &mut [u8]) -> Poll> { + tokio::io::AsyncRead::poll_read(Pin::new(&mut self.inner), cx, &mut ReadBuf::new(buf)) } } diff --git a/comms/src/connection_manager/dial_state.rs b/comms/src/connection_manager/dial_state.rs index 0b07378747..07bbd1e631 100644 --- a/comms/src/connection_manager/dial_state.rs +++ b/comms/src/connection_manager/dial_state.rs @@ -24,8 +24,8 @@ use crate::{ connection_manager::{error::ConnectionManagerError, peer_connection::PeerConnection}, peer_manager::Peer, }; -use futures::channel::oneshot; use tari_shutdown::ShutdownSignal; +use tokio::sync::oneshot; /// The state of the dial request pub struct DialState { diff --git a/comms/src/connection_manager/dialer.rs b/comms/src/connection_manager/dialer.rs index 538ffe89bd..df1acff4db 100644 --- a/comms/src/connection_manager/dialer.rs 
+++ b/comms/src/connection_manager/dialer.rs @@ -39,22 +39,22 @@ use crate::{ types::CommsPublicKey, }; use futures::{ - channel::{mpsc, oneshot}, future, future::{BoxFuture, Either, FusedFuture}, pin_mut, - stream::{Fuse, FuturesUnordered}, - AsyncRead, - AsyncWrite, - AsyncWriteExt, + stream::FuturesUnordered, FutureExt, - SinkExt, - StreamExt, }; use log::*; use std::{collections::HashMap, sync::Arc, time::Duration}; use tari_shutdown::{Shutdown, ShutdownSignal}; -use tokio::{task::JoinHandle, time}; +use tokio::{ + io::{AsyncRead, AsyncWrite, AsyncWriteExt}, + sync::{mpsc, oneshot}, + task::JoinHandle, + time, +}; +use tokio_stream::StreamExt; use tracing::{self, span, Instrument, Level}; const LOG_TARGET: &str = "comms::connection_manager::dialer"; @@ -79,7 +79,7 @@ pub struct Dialer { transport: TTransport, noise_config: NoiseConfig, backoff: Arc, - request_rx: Fuse>, + request_rx: mpsc::Receiver, cancel_signals: HashMap, conn_man_notifier: mpsc::Sender, shutdown: Option, @@ -112,7 +112,7 @@ where transport, noise_config, backoff: Arc::new(backoff), - request_rx: request_rx.fuse(), + request_rx, cancel_signals: Default::default(), conn_man_notifier, shutdown: Some(shutdown), @@ -139,16 +139,20 @@ where .expect("Establisher initialized without a shutdown"); debug!(target: LOG_TARGET, "Connection dialer started"); loop { - futures::select! { - request = self.request_rx.select_next_some() => self.handle_request(&mut pending_dials, request), - (dial_state, dial_result) = pending_dials.select_next_some() => { - self.handle_dial_result(dial_state, dial_result).await; - } - _ = shutdown => { + tokio::select! { + // Biased ordering is used because we already have the futures polled here in a fair order, and so wish to + // forgo the minor cost of the random ordering + biased; + + _ = &mut shutdown => { info!(target: LOG_TARGET, "Connection dialer shutting down because the shutdown signal was received"); self.cancel_all_dials(); break; } + Some((dial_state, dial_result)) = pending_dials.next() => { + self.handle_dial_result(dial_state, dial_result).await; + } + Some(request) = self.request_rx.recv() => self.handle_request(&mut pending_dials, request), } } } @@ -179,12 +183,7 @@ where self.cancel_signals.len() ); self.cancel_signals.drain().for_each(|(_, mut signal)| { - log_if_error_fmt!( - level: warn, - target: LOG_TARGET, - signal.trigger(), - "Shutdown trigger failed", - ); + signal.trigger(); }) } @@ -352,7 +351,6 @@ where cancel_signal: ShutdownSignal, ) -> Result { static CONNECTION_DIRECTION: ConnectionDirection = ConnectionDirection::Outbound; - let mut muxer = Yamux::upgrade_connection(socket, CONNECTION_DIRECTION) .await .map_err(|err| ConnectionManagerError::YamuxUpgradeFailure(err.to_string()))?; @@ -448,9 +446,9 @@ where current_state.peer.node_id.short_str(), backoff_duration.as_secs() ); - let mut delay = time::delay_for(backoff_duration).fuse(); - let mut cancel_signal = current_state.get_cancel_signal(); - futures::select! { + let delay = time::sleep(backoff_duration).fuse(); + let cancel_signal = current_state.get_cancel_signal(); + tokio::select! 
{ _ = delay => { debug!(target: LOG_TARGET, "[Attempt {}] Connecting to peer '{}'", current_state.num_attempts(), current_state.peer.node_id.short_str()); match Self::dial_peer(current_state, &noise_config, &current_transport, config.network_info.network_byte).await { @@ -544,18 +542,13 @@ where // Try the next address continue; }, - Either::Right((cancel_result, _)) => { + // Canceled + Either::Right(_) => { debug!( target: LOG_TARGET, "Dial for peer '{}' cancelled", dial_state.peer.node_id.short_str() ); - log_if_error!( - level: warn, - target: LOG_TARGET, - cancel_result, - "Cancel channel error during dial: {}", - ); Err(ConnectionManagerError::DialCancelled) }, } diff --git a/comms/src/connection_manager/error.rs b/comms/src/connection_manager/error.rs index f7bbbaa564..5645a57e62 100644 --- a/comms/src/connection_manager/error.rs +++ b/comms/src/connection_manager/error.rs @@ -21,13 +21,13 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{ + connection_manager::PeerConnectionRequest, noise, peer_manager::PeerManagerError, protocol::{IdentityProtocolError, ProtocolError}, }; -use futures::channel::mpsc; use thiserror::Error; -use tokio::{time, time::Elapsed}; +use tokio::{sync::mpsc, time::error::Elapsed}; #[derive(Debug, Error, Clone)] pub enum ConnectionManagerError { @@ -108,14 +108,14 @@ pub enum PeerConnectionError { #[error("Internal oneshot reply channel was unexpectedly cancelled")] InternalReplyCancelled, #[error("Failed to send internal request: {0}")] - InternalRequestSendFailed(#[from] mpsc::SendError), + InternalRequestSendFailed(#[from] mpsc::error::SendError), #[error("Protocol error: {0}")] ProtocolError(#[from] ProtocolError), #[error("Protocol negotiation timeout")] ProtocolNegotiationTimeout, } -impl From for PeerConnectionError { +impl From for PeerConnectionError { fn from(_: Elapsed) -> Self { PeerConnectionError::ProtocolNegotiationTimeout } diff --git a/comms/src/connection_manager/listener.rs b/comms/src/connection_manager/listener.rs index 60ae3c2d12..21c3771610 100644 --- a/comms/src/connection_manager/listener.rs +++ b/comms/src/connection_manager/listener.rs @@ -32,7 +32,6 @@ use crate::{ bounded_executor::BoundedExecutor, connection_manager::{ liveness::LivenessSession, - types::OneshotTrigger, wire_mode::{WireMode, LIVENESS_WIRE_MODE}, }, multiaddr::Multiaddr, @@ -46,17 +45,7 @@ use crate::{ utils::multiaddr::multiaddr_to_socketaddr, PeerManager, }; -use futures::{ - channel::mpsc, - future, - AsyncRead, - AsyncReadExt, - AsyncWrite, - AsyncWriteExt, - FutureExt, - SinkExt, - StreamExt, -}; +use futures::{future, FutureExt}; use log::*; use std::{ convert::TryInto, @@ -69,8 +58,13 @@ use std::{ time::Duration, }; use tari_crypto::tari_utilities::hex::Hex; -use tari_shutdown::ShutdownSignal; -use tokio::time; +use tari_shutdown::{oneshot_trigger, oneshot_trigger::OneshotTrigger, ShutdownSignal}; +use tokio::{ + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, + sync::mpsc, + time, +}; +use tokio_stream::StreamExt; use tracing::{span, Instrument, Level}; const LOG_TARGET: &str = "comms::connection_manager::listener"; @@ -118,7 +112,7 @@ where bounded_executor: BoundedExecutor::from_current(config.max_simultaneous_inbound_connects), liveness_session_count: Arc::new(AtomicUsize::new(config.liveness_max_sessions)), config, - on_listening: OneshotTrigger::new(), + on_listening: oneshot_trigger::channel(), } } @@ -128,7 +122,7 @@ where // 'static lifetime as well as to flatten the oneshot result for ergonomics pub fn
on_listening(&self) -> impl Future> + 'static { let signal = self.on_listening.to_signal(); - signal.map(|r| r.map_err(|_| ConnectionManagerError::ListenerOneshotCancelled)?) + signal.map(|r| r.ok_or(ConnectionManagerError::ListenerOneshotCancelled)?) } /// Set the supported protocols of this node to send to peers during the peer identity exchange @@ -147,31 +141,30 @@ where let mut shutdown_signal = self.shutdown_signal.clone(); match self.bind().await { - Ok((inbound, address)) => { + Ok((mut inbound, address)) => { info!(target: LOG_TARGET, "Listening for peer connections on '{}'", address); - self.on_listening.trigger(Ok(address)); - - let inbound = inbound.fuse(); - futures::pin_mut!(inbound); + self.on_listening.broadcast(Ok(address)); loop { - futures::select! { - inbound_result = inbound.select_next_some() => { + tokio::select! { + biased; + + _ = &mut shutdown_signal => { + info!(target: LOG_TARGET, "PeerListener is shutting down because the shutdown signal was triggered"); + break; + }, + Some(inbound_result) = inbound.next() => { if let Some((socket, peer_addr)) = log_if_error!(target: LOG_TARGET, inbound_result, "Inbound connection failed because '{error}'",) { self.spawn_listen_task(socket, peer_addr).await; } }, - _ = shutdown_signal => { - info!(target: LOG_TARGET, "PeerListener is shutting down because the shutdown signal was triggered"); - break; - }, } } }, Err(err) => { warn!(target: LOG_TARGET, "PeerListener was unable to start because '{}'", err); - self.on_listening.trigger(Err(err)); + self.on_listening.broadcast(Err(err)); }, } } @@ -238,7 +231,7 @@ where async fn spawn_listen_task(&self, mut socket: TTransport::Output, peer_addr: Multiaddr) { let node_identity = self.node_identity.clone(); let peer_manager = self.peer_manager.clone(); - let mut conn_man_notifier = self.conn_man_notifier.clone(); + let conn_man_notifier = self.conn_man_notifier.clone(); let noise_config = self.noise_config.clone(); let config = self.config.clone(); let our_supported_protocols = self.our_supported_protocols.clone(); @@ -318,7 +311,7 @@ where "No liveness sessions available or permitted for peer address '{}'", peer_addr ); - let _ = socket.close().await; + let _ = socket.shutdown().await; } }, Err(err) => { diff --git a/comms/src/connection_manager/liveness.rs b/comms/src/connection_manager/liveness.rs index 3c06889307..75ee2db13f 100644 --- a/comms/src/connection_manager/liveness.rs +++ b/comms/src/connection_manager/liveness.rs @@ -20,15 +20,16 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::compat::IoCompat; -use futures::{AsyncRead, AsyncWrite, Future, StreamExt}; +use futures::StreamExt; +use std::future::Future; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::codec::{Framed, LinesCodec, LinesCodecError}; /// Max line length accepted by the liveness session. 
const MAX_LINE_LENGTH: usize = 50; pub struct LivenessSession { - framed: Framed, LinesCodec>, + framed: Framed, } impl LivenessSession @@ -36,7 +37,7 @@ where TSocket: AsyncRead + AsyncWrite + Unpin { pub fn new(socket: TSocket) -> Self { Self { - framed: Framed::new(IoCompat::new(socket), LinesCodec::new_with_max_length(MAX_LINE_LENGTH)), + framed: Framed::new(socket, LinesCodec::new_with_max_length(MAX_LINE_LENGTH)), } } @@ -52,13 +53,14 @@ mod test { use crate::{memsocket::MemorySocket, runtime}; use futures::SinkExt; use tokio::{time, time::Duration}; + use tokio_stream::StreamExt; - #[runtime::test_basic] + #[runtime::test] async fn echos() { let (inbound, outbound) = MemorySocket::new_pair(); let liveness = LivenessSession::new(inbound); let join_handle = runtime::current().spawn(liveness.run()); - let mut outbound = Framed::new(IoCompat::new(outbound), LinesCodec::new()); + let mut outbound = Framed::new(outbound, LinesCodec::new()); for _ in 0..10usize { outbound.send("ECHO".to_string()).await.unwrap() } diff --git a/comms/src/connection_manager/manager.rs b/comms/src/connection_manager/manager.rs index 928b53611f..0c6de18f59 100644 --- a/comms/src/connection_manager/manager.rs +++ b/comms/src/connection_manager/manager.rs @@ -36,20 +36,17 @@ use crate::{ transports::{TcpTransport, Transport}, PeerManager, }; -use futures::{ - channel::{mpsc, oneshot}, - stream::Fuse, - AsyncRead, - AsyncWrite, - SinkExt, - StreamExt, -}; use log::*; use multiaddr::Multiaddr; use std::{fmt, sync::Arc}; use tari_shutdown::{Shutdown, ShutdownSignal}; use time::Duration; -use tokio::{sync::broadcast, task, time}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{broadcast, mpsc, oneshot}, + task, + time, +}; use tracing::{span, Instrument, Level}; const LOG_TARGET: &str = "comms::connection_manager::manager"; @@ -156,8 +153,8 @@ impl ListenerInfo { } pub struct ConnectionManager { - request_rx: Fuse>, - internal_event_rx: Fuse>, + request_rx: mpsc::Receiver, + internal_event_rx: mpsc::Receiver, dialer_tx: mpsc::Sender, dialer: Option>, listener: Option>, @@ -230,10 +227,10 @@ where Self { shutdown_signal: Some(shutdown_signal), - request_rx: request_rx.fuse(), + request_rx, peer_manager, protocols: Protocols::new(), - internal_event_rx: internal_event_rx.fuse(), + internal_event_rx, dialer_tx, dialer: Some(dialer), listener: Some(listener), @@ -266,7 +263,7 @@ where .take() .expect("ConnectionManager initialized without a shutdown"); - // Runs the listeners, waiting for a + // Runs the listeners. Sockets are bound and ready once this resolves match self.run_listeners().await { Ok(info) => { self.listener_info = Some(info); @@ -293,16 +290,16 @@ where .join(", ") ); loop { - futures::select! { - event = self.internal_event_rx.select_next_some() => { + tokio::select! 
{ + Some(event) = self.internal_event_rx.recv() => { self.handle_event(event).await; }, - request = self.request_rx.select_next_some() => { + Some(request) = self.request_rx.recv() => { self.handle_request(request).await; }, - _ = shutdown => { + _ = &mut shutdown => { info!(target: LOG_TARGET, "ConnectionManager is shutting down because it received the shutdown signal"); break; } diff --git a/comms/src/connection_manager/peer_connection.rs b/comms/src/connection_manager/peer_connection.rs index 6f0d90da5d..975c4969a4 100644 --- a/comms/src/connection_manager/peer_connection.rs +++ b/comms/src/connection_manager/peer_connection.rs @@ -44,20 +44,21 @@ use crate::{ protocol::{ProtocolId, ProtocolNegotiation}, runtime, }; -use futures::{ - channel::{mpsc, oneshot}, - stream::Fuse, - SinkExt, - StreamExt, -}; use log::*; use multiaddr::Multiaddr; use std::{ fmt, - sync::atomic::{AtomicUsize, Ordering}, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, time::{Duration, Instant}, }; -use tokio::time; +use tokio::{ + sync::{mpsc, oneshot}, + time, +}; +use tokio_stream::StreamExt; use tracing::{self, span, Instrument, Level, Span}; const LOG_TARGET: &str = "comms::connection_manager::peer_connection"; @@ -130,7 +131,7 @@ pub struct PeerConnection { peer_node_id: NodeId, peer_features: PeerFeatures, request_tx: mpsc::Sender, - address: Multiaddr, + address: Arc, direction: ConnectionDirection, started_at: Instant, substream_counter: SubstreamCounter, @@ -151,7 +152,7 @@ impl PeerConnection { request_tx, peer_node_id, peer_features, - address, + address: Arc::new(address), direction, started_at: Instant::now(), substream_counter, @@ -301,9 +302,9 @@ impl PartialEq for PeerConnection { struct PeerConnectionActor { id: ConnectionId, peer_node_id: NodeId, - request_rx: Fuse>, + request_rx: mpsc::Receiver, direction: ConnectionDirection, - incoming_substreams: Fuse, + incoming_substreams: IncomingSubstreams, control: Control, event_notifier: mpsc::Sender, our_supported_protocols: Vec, @@ -327,8 +328,8 @@ impl PeerConnectionActor { peer_node_id, direction, control: connection.get_yamux_control(), - incoming_substreams: connection.incoming().fuse(), - request_rx: request_rx.fuse(), + incoming_substreams: connection.incoming(), + request_rx, event_notifier, our_supported_protocols, their_supported_protocols, @@ -337,8 +338,8 @@ impl PeerConnectionActor { pub async fn run(mut self) { loop { - futures::select! { - request = self.request_rx.select_next_some() => self.handle_request(request).await, + tokio::select! 
{ + Some(request) = self.request_rx.recv() => self.handle_request(request).await, maybe_substream = self.incoming_substreams.next() => { match maybe_substream { @@ -362,7 +363,7 @@ impl PeerConnectionActor { } } } - self.request_rx.get_mut().close(); + self.request_rx.close(); } async fn handle_request(&mut self, request: PeerConnectionRequest) { diff --git a/comms/src/connection_manager/requester.rs b/comms/src/connection_manager/requester.rs index 1f3f5cc887..c091bd9bb4 100644 --- a/comms/src/connection_manager/requester.rs +++ b/comms/src/connection_manager/requester.rs @@ -25,12 +25,8 @@ use crate::{ connection_manager::manager::{ConnectionManagerEvent, ListenerInfo}, peer_manager::NodeId, }; -use futures::{ - channel::{mpsc, oneshot}, - SinkExt, -}; use std::sync::Arc; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc, oneshot}; /// Requests which are handled by the ConnectionManagerService #[derive(Debug)] diff --git a/comms/src/connection_manager/tests/listener_dialer.rs b/comms/src/connection_manager/tests/listener_dialer.rs index 586c0d5ec4..25e715bf50 100644 --- a/comms/src/connection_manager/tests/listener_dialer.rs +++ b/comms/src/connection_manager/tests/listener_dialer.rs @@ -36,20 +36,17 @@ use crate::{ test_utils::{node_identity::build_node_identity, test_node::build_peer_manager}, transports::MemoryTransport, }; -use futures::{ - channel::{mpsc, oneshot}, - AsyncReadExt, - AsyncWriteExt, - SinkExt, - StreamExt, -}; use multiaddr::Protocol; use std::{error::Error, time::Duration}; use tari_shutdown::Shutdown; use tari_test_utils::unpack_enum; -use tokio::time::timeout; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + sync::{mpsc, oneshot}, + time::timeout, +}; -#[runtime::test_basic] +#[runtime::test] async fn listen() -> Result<(), Box> { let (event_tx, _) = mpsc::channel(1); let mut shutdown = Shutdown::new(); @@ -61,7 +58,7 @@ async fn listen() -> Result<(), Box> { "/memory/0".parse()?, MemoryTransport, noise_config.clone(), - event_tx.clone(), + event_tx, peer_manager, node_identity, shutdown.to_signal(), @@ -72,12 +69,12 @@ async fn listen() -> Result<(), Box> { unpack_enum!(Protocol::Memory(port) = bind_addr.pop().unwrap()); assert!(port > 0); - shutdown.trigger().unwrap(); + shutdown.trigger(); Ok(()) } -#[runtime::test_basic] +#[runtime::test] async fn smoke() { let rt_handle = runtime::current(); // This test sets up Dialer and Listener components, uses the Dialer to dial the Listener, @@ -108,7 +105,7 @@ async fn smoke() { let node_identity2 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); let noise_config2 = NoiseConfig::new(node_identity2.clone()); - let (mut request_tx, request_rx) = mpsc::channel(1); + let (request_tx, request_rx) = mpsc::channel(1); let peer_manager2 = build_peer_manager(); let mut dialer = Dialer::new( ConnectionManagerConfig::default(), @@ -148,11 +145,11 @@ async fn smoke() { } // Read PeerConnected events - we don't know which connection is which - unpack_enum!(ConnectionManagerEvent::PeerConnected(conn1) = event_rx.next().await.unwrap()); - unpack_enum!(ConnectionManagerEvent::PeerConnected(_conn2) = event_rx.next().await.unwrap()); + unpack_enum!(ConnectionManagerEvent::PeerConnected(conn1) = event_rx.recv().await.unwrap()); + unpack_enum!(ConnectionManagerEvent::PeerConnected(_conn2) = event_rx.recv().await.unwrap()); // Next event should be a NewInboundSubstream has been received - let listen_event = event_rx.next().await.unwrap(); + let listen_event = event_rx.recv().await.unwrap(); { 
unpack_enum!(ConnectionManagerEvent::NewInboundSubstream(node_id, proto, in_stream) = listen_event); assert_eq!(&*node_id, node_identity2.node_id()); @@ -165,7 +162,7 @@ async fn smoke() { conn1.disconnect().await.unwrap(); - shutdown.trigger().unwrap(); + shutdown.trigger(); let peer2 = peer_manager1.find_by_node_id(node_identity2.node_id()).await.unwrap(); let peer1 = peer_manager2.find_by_node_id(node_identity1.node_id()).await.unwrap(); @@ -176,7 +173,7 @@ async fn smoke() { timeout(Duration::from_secs(5), dialer_fut).await.unwrap().unwrap(); } -#[runtime::test_basic] +#[runtime::test] async fn banned() { let rt_handle = runtime::current(); let (event_tx, mut event_rx) = mpsc::channel(10); @@ -209,7 +206,7 @@ async fn banned() { peer_manager1.add_peer(peer).await.unwrap(); let noise_config2 = NoiseConfig::new(node_identity2.clone()); - let (mut request_tx, request_rx) = mpsc::channel(1); + let (request_tx, request_rx) = mpsc::channel(1); let peer_manager2 = build_peer_manager(); let mut dialer = Dialer::new( ConnectionManagerConfig::default(), @@ -241,10 +238,10 @@ async fn banned() { let err = reply_rx.await.unwrap().unwrap_err(); unpack_enum!(ConnectionManagerError::IdentityProtocolError(_err) = err); - unpack_enum!(ConnectionManagerEvent::PeerInboundConnectFailed(err) = event_rx.next().await.unwrap()); + unpack_enum!(ConnectionManagerEvent::PeerInboundConnectFailed(err) = event_rx.recv().await.unwrap()); unpack_enum!(ConnectionManagerError::PeerBanned = err); - shutdown.trigger().unwrap(); + shutdown.trigger(); timeout(Duration::from_secs(5), dialer_fut).await.unwrap().unwrap(); } diff --git a/comms/src/connection_manager/tests/manager.rs b/comms/src/connection_manager/tests/manager.rs index bca876eff4..910280b4cf 100644 --- a/comms/src/connection_manager/tests/manager.rs +++ b/comms/src/connection_manager/tests/manager.rs @@ -41,19 +41,17 @@ use crate::{ }, transports::{MemoryTransport, TcpTransport}, }; -use futures::{ - channel::{mpsc, oneshot}, - future, - AsyncReadExt, - AsyncWriteExt, - StreamExt, -}; +use futures::future; use std::time::Duration; use tari_shutdown::Shutdown; -use tari_test_utils::{collect_stream, unpack_enum}; -use tokio::{runtime::Handle, sync::broadcast}; +use tari_test_utils::{collect_try_recv, unpack_enum}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + runtime::Handle, + sync::{broadcast, mpsc, oneshot}, +}; -#[runtime::test_basic] +#[runtime::test] async fn connect_to_nonexistent_peer() { let rt_handle = Handle::current(); let node_identity = build_node_identity(PeerFeatures::empty()); @@ -83,10 +81,10 @@ async fn connect_to_nonexistent_peer() { unpack_enum!(ConnectionManagerError::PeerManagerError(err) = err); unpack_enum!(PeerManagerError::PeerNotFoundError = err); - shutdown.trigger().unwrap(); + shutdown.trigger(); } -#[runtime::test_basic] +#[runtime::test] async fn dial_success() { static TEST_PROTO: ProtocolId = ProtocolId::from_static(b"/test/valid"); let shutdown = Shutdown::new(); @@ -159,7 +157,7 @@ async fn dial_success() { assert_eq!(peer2.supported_protocols, [&IDENTITY_PROTOCOL, &TEST_PROTO]); assert_eq!(peer2.user_agent, "node2"); - let event = subscription2.next().await.unwrap().unwrap(); + let event = subscription2.recv().await.unwrap(); unpack_enum!(ConnectionManagerEvent::PeerConnected(conn_in) = &*event); assert_eq!(conn_in.peer_node_id(), node_identity1.node_id()); @@ -179,7 +177,7 @@ async fn dial_success() { const MSG: &[u8] = b"Welease Woger!"; substream_out.stream.write_all(MSG).await.unwrap(); - let protocol_in = 
proto_rx2.next().await.unwrap(); + let protocol_in = proto_rx2.recv().await.unwrap(); assert_eq!(protocol_in.protocol, &TEST_PROTO); unpack_enum!(ProtocolEvent::NewInboundSubstream(node_id, substream_in) = protocol_in.event); assert_eq!(&node_id, node_identity1.node_id()); @@ -189,7 +187,7 @@ async fn dial_success() { assert_eq!(buf, MSG); } -#[runtime::test_basic] +#[runtime::test] async fn dial_success_aux_tcp_listener() { static TEST_PROTO: ProtocolId = ProtocolId::from_static(b"/test/valid"); let shutdown = Shutdown::new(); @@ -271,7 +269,7 @@ async fn dial_success_aux_tcp_listener() { const MSG: &[u8] = b"Welease Woger!"; substream_out.stream.write_all(MSG).await.unwrap(); - let protocol_in = proto_rx1.next().await.unwrap(); + let protocol_in = proto_rx1.recv().await.unwrap(); assert_eq!(protocol_in.protocol, &TEST_PROTO); unpack_enum!(ProtocolEvent::NewInboundSubstream(node_id, substream_in) = protocol_in.event); assert_eq!(&node_id, node_identity2.node_id()); @@ -281,7 +279,7 @@ async fn dial_success_aux_tcp_listener() { assert_eq!(buf, MSG); } -#[runtime::test_basic] +#[runtime::test] async fn simultaneous_dial_events() { let mut shutdown = Shutdown::new(); @@ -360,29 +358,22 @@ async fn simultaneous_dial_events() { _ => panic!("unexpected simultaneous dial result"), } - let event = subscription2.next().await.unwrap().unwrap(); + let event = subscription2.recv().await.unwrap(); assert!(count_string_occurrences(&[event], &["PeerConnected", "PeerInboundConnectFailed"]) >= 1); - shutdown.trigger().unwrap(); + shutdown.trigger(); drop(conn_man1); drop(conn_man2); - let _events1 = collect_stream!(subscription1, timeout = Duration::from_secs(5)) - .into_iter() - .map(Result::unwrap) - .collect::>(); - - let _events2 = collect_stream!(subscription2, timeout = Duration::from_secs(5)) - .into_iter() - .map(Result::unwrap) - .collect::>(); + let _events1 = collect_try_recv!(subscription1, timeout = Duration::from_secs(5)); + let _events2 = collect_try_recv!(subscription2, timeout = Duration::from_secs(5)); // TODO: Investigate why two PeerDisconnected events are sometimes received // assert!(count_string_occurrences(&events1, &["PeerDisconnected"]) >= 1); // assert!(count_string_occurrences(&events2, &["PeerDisconnected"]) >= 1); } -#[tokio_macros::test_basic] +#[runtime::test] async fn dial_cancelled() { let mut shutdown = Shutdown::new(); @@ -429,13 +420,10 @@ async fn dial_cancelled() { let err = dial_result.await.unwrap().unwrap_err(); unpack_enum!(ConnectionManagerError::DialCancelled = err); - shutdown.trigger().unwrap(); + shutdown.trigger(); drop(conn_man1); - let events1 = collect_stream!(subscription1, timeout = Duration::from_secs(5)) - .into_iter() - .map(Result::unwrap) - .collect::>(); + let events1 = collect_try_recv!(subscription1, timeout = Duration::from_secs(5)); assert_eq!(events1.len(), 1); unpack_enum!(ConnectionManagerEvent::PeerConnectFailed(node_id, err) = &*events1[0]); diff --git a/comms/src/connection_manager/types.rs b/comms/src/connection_manager/types.rs index ddb8f6de8f..c92b2b717a 100644 --- a/comms/src/connection_manager/types.rs +++ b/comms/src/connection_manager/types.rs @@ -20,11 +20,6 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
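Note: the dialer, listener and connection-manager loops above all swap `futures::select!` for `tokio::select!` with a `biased;` clause so that the shutdown branch is polled first. A minimal sketch of that loop shape, using a plain `oneshot` channel as a stand-in for `tari_shutdown::ShutdownSignal` (all names are illustrative):

use tokio::sync::{mpsc, oneshot};

async fn event_loop(mut rx: mpsc::Receiver<String>, mut shutdown: oneshot::Receiver<()>) {
    loop {
        tokio::select! {
            // `biased` polls branches top-down instead of in random order,
            // so shutdown is always checked before new work is accepted.
            biased;

            _ = &mut shutdown => {
                println!("shutting down");
                break;
            }
            maybe_req = rx.recv() => match maybe_req {
                Some(req) => println!("handling request: {}", req),
                None => break, // all senders dropped
            },
        }
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::channel(8);
    let (stop_tx, stop_rx) = oneshot::channel();

    let handle = tokio::spawn(event_loop(rx, stop_rx));
    tx.send("dial peer".to_string()).await.unwrap();
    stop_tx.send(()).unwrap();
    handle.await.unwrap();
}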
-use futures::{ - channel::oneshot, - future::{Fuse, Shared}, - FutureExt, -}; use std::fmt; /// Direction of the connection relative to this node @@ -47,29 +42,3 @@ impl fmt::Display for ConnectionDirection { write!(f, "{:?}", self) } } - -pub type OneshotSignal = Shared>>; -pub struct OneshotTrigger(Option>, OneshotSignal); - -impl OneshotTrigger { - pub fn new() -> Self { - let (tx, rx) = oneshot::channel(); - Self(Some(tx), rx.fuse().shared()) - } - - pub fn to_signal(&self) -> OneshotSignal { - self.1.clone() - } - - pub fn trigger(&mut self, item: T) { - if let Some(tx) = self.0.take() { - let _ = tx.send(item); - } - } -} - -impl Default for OneshotTrigger { - fn default() -> Self { - Self::new() - } -} diff --git a/comms/src/connectivity/manager.rs b/comms/src/connectivity/manager.rs index 35f37627c4..821253e84b 100644 --- a/comms/src/connectivity/manager.rs +++ b/comms/src/connectivity/manager.rs @@ -34,6 +34,7 @@ use crate::{ ConnectionManagerEvent, ConnectionManagerRequester, }, + connectivity::ConnectivityEventTx, peer_manager::NodeId, runtime::task, utils::datetime::format_duration, @@ -41,7 +42,6 @@ use crate::{ PeerConnection, PeerManager, }; -use futures::{channel::mpsc, stream::Fuse, StreamExt}; use log::*; use nom::lib::std::collections::hash_map::Entry; use std::{ @@ -52,7 +52,7 @@ use std::{ time::{Duration, Instant}, }; use tari_shutdown::ShutdownSignal; -use tokio::{sync::broadcast, task::JoinHandle, time}; +use tokio::{sync::mpsc, task::JoinHandle, time}; use tracing::{span, Instrument, Level}; const LOG_TARGET: &str = "comms::connectivity::manager"; @@ -72,7 +72,7 @@ const LOG_TARGET: &str = "comms::connectivity::manager"; pub struct ConnectivityManager { pub config: ConnectivityConfig, pub request_rx: mpsc::Receiver, - pub event_tx: broadcast::Sender>, + pub event_tx: ConnectivityEventTx, pub connection_manager: ConnectionManagerRequester, pub peer_manager: Arc, pub node_identity: Arc, @@ -84,7 +84,7 @@ impl ConnectivityManager { ConnectivityManagerActor { config: self.config, status: ConnectivityStatus::Initializing, - request_rx: self.request_rx.fuse(), + request_rx: self.request_rx, connection_manager: self.connection_manager, peer_manager: self.peer_manager.clone(), event_tx: self.event_tx, @@ -140,12 +140,12 @@ impl fmt::Display for ConnectivityStatus { pub struct ConnectivityManagerActor { config: ConnectivityConfig, status: ConnectivityStatus, - request_rx: Fuse>, + request_rx: mpsc::Receiver, connection_manager: ConnectionManagerRequester, node_identity: Arc, shutdown_signal: Option, peer_manager: Arc, - event_tx: broadcast::Sender>, + event_tx: ConnectivityEventTx, connection_stats: HashMap, managed_peers: Vec, @@ -165,7 +165,7 @@ impl ConnectivityManagerActor { .take() .expect("ConnectivityManager initialized without a shutdown_signal"); - let mut connection_manager_events = self.connection_manager.get_event_subscription().fuse(); + let mut connection_manager_events = self.connection_manager.get_event_subscription(); let interval = self.config.connection_pool_refresh_interval; let mut ticker = time::interval_at( @@ -174,18 +174,17 @@ impl ConnectivityManagerActor { .expect("connection_pool_refresh_interval cause overflow") .into(), interval, - ) - .fuse(); + ); self.publish_event(ConnectivityEvent::ConnectivityStateInitialized); loop { - futures::select! { - req = self.request_rx.select_next_some() => { + tokio::select! 
{ + Some(req) = self.request_rx.recv() => { self.handle_request(req).await; }, - event = connection_manager_events.select_next_some() => { + event = connection_manager_events.recv() => { if let Ok(event) = event { if let Err(err) = self.handle_connection_manager_event(&event).await { error!(target:LOG_TARGET, "Error handling connection manager event: {:?}", err); @@ -193,13 +192,13 @@ impl ConnectivityManagerActor { } }, - _ = ticker.next() => { + _ = ticker.tick() => { if let Err(err) = self.refresh_connection_pool().await { error!(target: LOG_TARGET, "Error when refreshing connection pools: {:?}", err); } }, - _ = shutdown_signal => { + _ = &mut shutdown_signal => { info!(target: LOG_TARGET, "ConnectivityManager is shutting down because it received the shutdown signal"); self.disconnect_all().await; break; @@ -823,7 +822,7 @@ impl ConnectivityManagerActor { fn publish_event(&mut self, event: ConnectivityEvent) { // A send operation can only fail if there are no subscribers, so it is safe to ignore the error - let _ = self.event_tx.send(Arc::new(event)); + let _ = self.event_tx.send(event); } async fn ban_peer( @@ -863,7 +862,7 @@ impl ConnectivityManagerActor { fn delayed_close(conn: PeerConnection, delay: Duration) { task::spawn(async move { - time::delay_for(delay).await; + time::sleep(delay).await; debug!( target: LOG_TARGET, "Closing connection from peer `{}` after delay", diff --git a/comms/src/connectivity/requester.rs b/comms/src/connectivity/requester.rs index 740c8a6c81..2c1e77268d 100644 --- a/comms/src/connectivity/requester.rs +++ b/comms/src/connectivity/requester.rs @@ -31,23 +31,21 @@ use crate::{ peer_manager::NodeId, PeerConnection, }; -use futures::{ - channel::{mpsc, oneshot}, - SinkExt, - StreamExt, -}; use log::*; use std::{ fmt, - sync::Arc, time::{Duration, Instant}, }; -use tokio::{sync::broadcast, time}; +use tokio::{ + sync::{broadcast, broadcast::error::RecvError, mpsc, oneshot}, + time, +}; + const LOG_TARGET: &str = "comms::connectivity::requester"; use tracing; -pub type ConnectivityEventRx = broadcast::Receiver>; -pub type ConnectivityEventTx = broadcast::Sender>; +pub type ConnectivityEventRx = broadcast::Receiver; +pub type ConnectivityEventTx = broadcast::Sender; #[derive(Debug, Clone)] pub enum ConnectivityEvent { @@ -264,24 +262,23 @@ impl ConnectivityRequester { let mut last_known_peer_count = status.num_connected_nodes(); loop { debug!(target: LOG_TARGET, "Waiting for connectivity event"); - let recv_result = time::timeout(remaining, connectivity_events.next()) + let recv_result = time::timeout(remaining, connectivity_events.recv()) .await - .map_err(|_| ConnectivityError::OnlineWaitTimeout(last_known_peer_count))? 
- .ok_or(ConnectivityError::ConnectivityEventStreamClosed)?; + .map_err(|_| ConnectivityError::OnlineWaitTimeout(last_known_peer_count))?; remaining = timeout .checked_sub(start.elapsed()) .ok_or(ConnectivityError::OnlineWaitTimeout(last_known_peer_count))?; match recv_result { - Ok(event) => match &*event { + Ok(event) => match event { ConnectivityEvent::ConnectivityStateOnline(_) => { info!(target: LOG_TARGET, "Connectivity is ONLINE."); break Ok(()); }, ConnectivityEvent::ConnectivityStateDegraded(n) => { warn!(target: LOG_TARGET, "Connectivity is DEGRADED ({} peer(s))", n); - last_known_peer_count = *n; + last_known_peer_count = n; }, ConnectivityEvent::ConnectivityStateOffline => { warn!( @@ -297,14 +294,14 @@ impl ConnectivityRequester { ); }, }, - Err(broadcast::RecvError::Closed) => { + Err(RecvError::Closed) => { error!( target: LOG_TARGET, "Connectivity event stream closed unexpectedly. System may be shutting down." ); break Err(ConnectivityError::ConnectivityEventStreamClosed); }, - Err(broadcast::RecvError::Lagged(n)) => { + Err(RecvError::Lagged(n)) => { warn!(target: LOG_TARGET, "Lagging behind on {} connectivity event(s)", n); // We lagged, so could have missed the state change. Check it explicitly. let status = self.get_connectivity_status().await?; diff --git a/comms/src/connectivity/selection.rs b/comms/src/connectivity/selection.rs index 931b78163d..c300d2d353 100644 --- a/comms/src/connectivity/selection.rs +++ b/comms/src/connectivity/selection.rs @@ -135,8 +135,8 @@ mod test { peer_manager::node_id::NodeDistance, test_utils::{mocks::create_dummy_peer_connection, node_id, node_identity::build_node_identity}, }; - use futures::channel::mpsc; use std::iter::repeat_with; + use tokio::sync::mpsc; fn create_pool_with_connections(n: usize) -> (ConnectionPool, Vec>) { let mut pool = ConnectionPool::new(); diff --git a/comms/src/connectivity/test.rs b/comms/src/connectivity/test.rs index a4fec1e896..948d083e94 100644 --- a/comms/src/connectivity/test.rs +++ b/comms/src/connectivity/test.rs @@ -28,6 +28,7 @@ use super::{ }; use crate::{ connection_manager::{ConnectionManagerError, ConnectionManagerEvent}, + connectivity::ConnectivityEventRx, peer_manager::{Peer, PeerFeatures}, runtime, runtime::task, @@ -39,18 +40,18 @@ use crate::{ NodeIdentity, PeerManager, }; -use futures::{channel::mpsc, future}; +use futures::future; use std::{sync::Arc, time::Duration}; use tari_shutdown::Shutdown; -use tari_test_utils::{collect_stream, streams, unpack_enum}; -use tokio::sync::broadcast; +use tari_test_utils::{collect_try_recv, streams, unpack_enum}; +use tokio::sync::{broadcast, mpsc}; #[allow(clippy::type_complexity)] fn setup_connectivity_manager( config: ConnectivityConfig, ) -> ( ConnectivityRequester, - broadcast::Receiver>, + ConnectivityEventRx, Arc, Arc, ConnectionManagerMockState, @@ -100,7 +101,7 @@ async fn add_test_peers(peer_manager: &PeerManager, n: usize) -> Vec { peers } -#[runtime::test_basic] +#[runtime::test] async fn connecting_peers() { let (mut connectivity, mut event_stream, node_identity, peer_manager, cm_mock_state, _shutdown) = setup_connectivity_manager(Default::default()); @@ -117,15 +118,15 @@ async fn connecting_peers() { .map(|(_, _, conn, _)| conn) .collect::>(); - let mut events = collect_stream!(event_stream, take = 1, timeout = Duration::from_secs(10)); - unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = &*events.remove(0).unwrap()); + let mut events = collect_try_recv!(event_stream, take = 1, timeout = Duration::from_secs(10)); + 
unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = events.remove(0)); // All connections succeeded for conn in &connections { cm_mock_state.publish_event(ConnectionManagerEvent::PeerConnected(conn.clone())); } - let _events = collect_stream!(event_stream, take = 11, timeout = Duration::from_secs(10)); + let _events = collect_try_recv!(event_stream, take = 11, timeout = Duration::from_secs(10)); let connection_states = connectivity.get_all_connection_states().await.unwrap(); assert_eq!(connection_states.len(), 10); @@ -135,7 +136,7 @@ async fn connecting_peers() { } } -#[runtime::test_basic] +#[runtime::test] async fn add_many_managed_peers() { let (mut connectivity, mut event_stream, node_identity, peer_manager, cm_mock_state, _shutdown) = setup_connectivity_manager(Default::default()); @@ -156,8 +157,8 @@ async fn add_many_managed_peers() { .await .unwrap(); - let mut events = collect_stream!(event_stream, take = 1, timeout = Duration::from_secs(10)); - unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = &*events.remove(0).unwrap()); + let mut events = collect_try_recv!(event_stream, take = 1, timeout = Duration::from_secs(10)); + unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = events.remove(0)); // First 5 succeeded for conn in &connections { @@ -172,10 +173,10 @@ async fn add_many_managed_peers() { )); } - let events = collect_stream!(event_stream, take = 9, timeout = Duration::from_secs(10)); + let events = collect_try_recv!(event_stream, take = 9, timeout = Duration::from_secs(10)); let n = events .iter() - .find_map(|event| match &**event.as_ref().unwrap() { + .find_map(|event| match event { ConnectivityEvent::ConnectivityStateOnline(n) => Some(n), ConnectivityEvent::ConnectivityStateDegraded(_) => None, ConnectivityEvent::PeerConnected(_) => None, @@ -205,7 +206,7 @@ async fn add_many_managed_peers() { } } -#[runtime::test_basic] +#[runtime::test] async fn online_then_offline() { let (mut connectivity, mut event_stream, node_identity, peer_manager, cm_mock_state, _shutdown) = setup_connectivity_manager(Default::default()); @@ -244,8 +245,8 @@ async fn online_then_offline() { .await .unwrap(); - let mut events = collect_stream!(event_stream, take = 1, timeout = Duration::from_secs(10)); - unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = &*events.remove(0).unwrap()); + let mut events = collect_try_recv!(event_stream, take = 1, timeout = Duration::from_secs(10)); + unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = events.remove(0)); for conn in connections.iter().skip(1) { cm_mock_state.publish_event(ConnectionManagerEvent::PeerConnected(conn.clone())); @@ -269,9 +270,9 @@ async fn online_then_offline() { )); } - streams::assert_in_stream( + streams::assert_in_broadcast( &mut event_stream, - |item| match &*item.unwrap() { + |item| match item { ConnectivityEvent::ConnectivityStateDegraded(2) => Some(()), _ => None, }, @@ -289,9 +290,9 @@ async fn online_then_offline() { )); } - streams::assert_in_stream( + streams::assert_in_broadcast( &mut event_stream, - |item| match &*item.unwrap() { + |item| match item { ConnectivityEvent::ConnectivityStateOffline => Some(()), _ => None, }, @@ -303,20 +304,20 @@ async fn online_then_offline() { assert!(is_offline); } -#[runtime::test_basic] +#[runtime::test] async fn ban_peer() { let (mut connectivity, mut event_stream, node_identity, peer_manager, cm_mock_state, _shutdown) = setup_connectivity_manager(Default::default()); let peer = add_test_peers(&peer_manager, 
1).await.pop().unwrap(); let (conn, _, _, _) = create_peer_connection_mock_pair(node_identity.to_peer(), peer.clone()).await; - let mut events = collect_stream!(event_stream, take = 1, timeout = Duration::from_secs(10)); - unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = &*events.remove(0).unwrap()); + let mut events = collect_try_recv!(event_stream, take = 1, timeout = Duration::from_secs(10)); + unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = events.remove(0)); cm_mock_state.publish_event(ConnectionManagerEvent::PeerConnected(conn.clone())); - let mut events = collect_stream!(event_stream, take = 2, timeout = Duration::from_secs(10)); - unpack_enum!(ConnectivityEvent::PeerConnected(_conn) = &*events.remove(0).unwrap()); - unpack_enum!(ConnectivityEvent::ConnectivityStateOnline(_n) = &*events.remove(0).unwrap()); + let mut events = collect_try_recv!(event_stream, take = 2, timeout = Duration::from_secs(10)); + unpack_enum!(ConnectivityEvent::PeerConnected(_conn) = events.remove(0)); + unpack_enum!(ConnectivityEvent::ConnectivityStateOnline(_n) = events.remove(0)); let conn = connectivity.get_connection(peer.node_id.clone()).await.unwrap(); assert!(conn.is_some()); @@ -329,13 +330,12 @@ async fn ban_peer() { // We can always expect a single PeerBanned because we do not publish a disconnected event from the connection // manager In a real system, peer disconnect and peer banned events may happen in any order and should always be // completely fine. - let event = collect_stream!(event_stream, take = 1, timeout = Duration::from_secs(10)) + let event = collect_try_recv!(event_stream, take = 1, timeout = Duration::from_secs(10)) .pop() - .unwrap() .unwrap(); - unpack_enum!(ConnectivityEvent::PeerBanned(node_id) = &*event); - assert_eq!(node_id, &peer.node_id); + unpack_enum!(ConnectivityEvent::PeerBanned(node_id) = event); + assert_eq!(node_id, peer.node_id); let peer = peer_manager.find_by_node_id(&peer.node_id).await.unwrap(); assert!(peer.is_banned()); @@ -344,7 +344,7 @@ async fn ban_peer() { assert!(conn.is_none()); } -#[runtime::test_basic] +#[runtime::test] async fn peer_selection() { let config = ConnectivityConfig { min_connectivity: 1.0, @@ -370,15 +370,15 @@ async fn peer_selection() { .await .unwrap(); - let mut events = collect_stream!(event_stream, take = 1, timeout = Duration::from_secs(10)); - unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = &*events.remove(0).unwrap()); + let mut events = collect_try_recv!(event_stream, take = 1, timeout = Duration::from_secs(10)); + unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = events.remove(0)); // 10 connections for conn in &connections { cm_mock_state.publish_event(ConnectionManagerEvent::PeerConnected(conn.clone())); } // Wait for all peers to be connected (i.e. for the connection manager events to be received) - let mut _events = collect_stream!(event_stream, take = 12, timeout = Duration::from_secs(10)); + let mut _events = collect_try_recv!(event_stream, take = 12, timeout = Duration::from_secs(10)); let conns = connectivity .select_connections(ConnectivitySelection::random_nodes(10, vec![connections[0] diff --git a/comms/src/framing.rs b/comms/src/framing.rs index 1e6b67691e..06ccc00c30 100644 --- a/comms/src/framing.rs +++ b/comms/src/framing.rs @@ -20,17 +20,16 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
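Note: the connectivity changes above drop the `Arc` wrapper around broadcast events and match on `RecvError::Closed` / `RecvError::Lagged` from `tokio::sync::broadcast`, with `tokio::time::timeout` bounding each wait. A standalone sketch of that wait-for-event loop, using a hypothetical `Event` enum rather than the crate's `ConnectivityEvent`:

use std::time::Duration;
use tokio::{sync::broadcast, time};

#[derive(Clone, Debug)]
enum Event {
    Online,
    Degraded(usize),
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = broadcast::channel::<Event>(16);
    tx.send(Event::Degraded(1)).unwrap();
    tx.send(Event::Online).unwrap();

    loop {
        // The outer Err is the timeout; the inner Err is the broadcast error.
        match time::timeout(Duration::from_secs(5), rx.recv()).await {
            Err(_) => panic!("timed out waiting for an event"),
            Ok(Ok(Event::Online)) => {
                println!("online");
                break;
            },
            Ok(Ok(Event::Degraded(n))) => println!("degraded: {} peer(s)", n),
            Ok(Err(broadcast::error::RecvError::Lagged(n))) => {
                println!("lagged behind by {} event(s)", n)
            },
            Ok(Err(broadcast::error::RecvError::Closed)) => break,
        }
    }
}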
-use crate::compat::IoCompat; -use futures::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::codec::{Framed, LengthDelimitedCodec}; /// Tari comms canonical framing -pub type CanonicalFraming = Framed, LengthDelimitedCodec>; +pub type CanonicalFraming = Framed; pub fn canonical(stream: T, max_frame_len: usize) -> CanonicalFraming where T: AsyncRead + AsyncWrite + Unpin { Framed::new( - IoCompat::new(stream), + stream, LengthDelimitedCodec::builder() .max_frame_length(max_frame_len) .new_codec(), diff --git a/comms/src/lib.rs b/comms/src/lib.rs index 4f3bb178d2..de48e316e5 100644 --- a/comms/src/lib.rs +++ b/comms/src/lib.rs @@ -32,21 +32,19 @@ pub use peer_manager::{NodeIdentity, PeerManager}; pub mod framing; -mod common; -pub use common::rate_limit; +pub mod rate_limit; mod multiplexing; pub use multiplexing::Substream; mod noise; mod proto; -mod runtime; pub mod backoff; pub mod bounded_executor; -pub mod compat; pub mod memsocket; pub mod protocol; +pub mod runtime; #[macro_use] pub mod message; pub mod net_address; diff --git a/comms/src/memsocket/mod.rs b/comms/src/memsocket/mod.rs index 3d102bf8dc..ed77fc6146 100644 --- a/comms/src/memsocket/mod.rs +++ b/comms/src/memsocket/mod.rs @@ -26,17 +26,21 @@ use bytes::{Buf, Bytes}; use futures::{ channel::mpsc::{self, UnboundedReceiver, UnboundedSender}, - io::{AsyncRead, AsyncWrite, Error, ErrorKind, Result}, ready, stream::{FusedStream, Stream}, task::{Context, Poll}, }; use std::{ + cmp, collections::{hash_map::Entry, HashMap}, num::NonZeroU16, pin::Pin, sync::Mutex, }; +use tokio::{ + io, + io::{AsyncRead, AsyncWrite, ErrorKind, ReadBuf}, +}; lazy_static! { static ref SWITCHBOARD: Mutex = Mutex::new(SwitchBoard(HashMap::default(), 1)); @@ -114,6 +118,7 @@ pub fn release_memsocket_port(port: NonZeroU16) { /// use std::io::Result; /// /// use tari_comms::memsocket::{MemoryListener, MemorySocket}; +/// use tokio::io::*; /// use futures::prelude::*; /// /// async fn write_stormlight(mut stream: MemorySocket) -> Result<()> { @@ -170,7 +175,7 @@ impl MemoryListener { /// ``` /// /// [`local_addr`]: #method.local_addr - pub fn bind(port: u16) -> Result { + pub fn bind(port: u16) -> io::Result { let mut switchboard = (&*SWITCHBOARD).lock().unwrap(); // Get the port we should bind to. 
If 0 was given, use a random port @@ -262,11 +267,11 @@ impl MemoryListener { Incoming { inner: self } } - fn poll_accept(&mut self, context: &mut Context) -> Poll> { + fn poll_accept(&mut self, context: &mut Context) -> Poll> { match Pin::new(&mut self.incoming).poll_next(context) { Poll::Ready(Some(socket)) => Poll::Ready(Ok(socket)), Poll::Ready(None) => { - let err = Error::new(ErrorKind::Other, "MemoryListener unknown error"); + let err = io::Error::new(ErrorKind::Other, "MemoryListener unknown error"); Poll::Ready(Err(err)) }, Poll::Pending => Poll::Pending, @@ -283,7 +288,7 @@ pub struct Incoming<'a> { } impl<'a> Stream for Incoming<'a> { - type Item = Result; + type Item = io::Result; fn poll_next(mut self: Pin<&mut Self>, context: &mut Context) -> Poll> { let socket = ready!(self.inner.poll_accept(context)?); @@ -302,6 +307,7 @@ impl<'a> Stream for Incoming<'a> { /// /// ```rust, no_run /// use futures::prelude::*; +/// use tokio::io::*; /// use tari_comms::memsocket::MemorySocket; /// /// # async fn run() -> ::std::io::Result<()> { @@ -371,7 +377,7 @@ impl MemorySocket { /// let socket = MemorySocket::connect(16)?; /// # Ok(())} /// ``` - pub fn connect(port: u16) -> Result { + pub fn connect(port: u16) -> io::Result { let mut switchboard = (&*SWITCHBOARD).lock().unwrap(); // Find port to connect to @@ -399,13 +405,13 @@ impl MemorySocket { impl AsyncRead for MemorySocket { /// Attempt to read from the `AsyncRead` into `buf`. - fn poll_read(mut self: Pin<&mut Self>, mut context: &mut Context, buf: &mut [u8]) -> Poll> { + fn poll_read(mut self: Pin<&mut Self>, mut context: &mut Context, buf: &mut ReadBuf<'_>) -> Poll> { if self.incoming.is_terminated() { if self.seen_eof { return Poll::Ready(Err(ErrorKind::UnexpectedEof.into())); } else { self.seen_eof = true; - return Poll::Ready(Ok(0)); + return Poll::Ready(Ok(())); } } @@ -413,22 +419,23 @@ impl AsyncRead for MemorySocket { loop { // If we're already filled up the buffer then we can return - if bytes_read == buf.len() { - return Poll::Ready(Ok(bytes_read)); + if buf.remaining() == 0 { + return Poll::Ready(Ok(())); } match self.current_buffer { // We have data to copy to buf Some(ref mut current_buffer) if !current_buffer.is_empty() => { - let bytes_to_read = ::std::cmp::min(buf.len() - bytes_read, current_buffer.len()); - debug_assert!(bytes_to_read > 0); - - buf[bytes_read..(bytes_read + bytes_to_read)] - .copy_from_slice(current_buffer.slice(0..bytes_to_read).as_ref()); + let bytes_to_read = cmp::min(buf.remaining(), current_buffer.len()); + if bytes_to_read > 0 { + buf.initialize_unfilled_to(bytes_to_read) + .copy_from_slice(&current_buffer.slice(..bytes_to_read)); + buf.advance(bytes_to_read); - current_buffer.advance(bytes_to_read); + current_buffer.advance(bytes_to_read); - bytes_read += bytes_to_read; + bytes_read += bytes_to_read; + } }, // Either we've exhausted our current buffer or don't have one @@ -438,13 +445,13 @@ impl AsyncRead for MemorySocket { Poll::Pending => { // If we've read anything up to this point return the bytes read if bytes_read > 0 { - return Poll::Ready(Ok(bytes_read)); + return Poll::Ready(Ok(())); } else { return Poll::Pending; } }, Poll::Ready(Some(buf)) => Some(buf), - Poll::Ready(None) => return Poll::Ready(Ok(bytes_read)), + Poll::Ready(None) => return Poll::Ready(Ok(())), } }; }, @@ -455,14 +462,14 @@ impl AsyncRead for MemorySocket { impl AsyncWrite for MemorySocket { /// Attempt to write bytes from `buf` into the outgoing channel.
- fn poll_write(mut self: Pin<&mut Self>, context: &mut Context, buf: &[u8]) -> Poll> { + fn poll_write(mut self: Pin<&mut Self>, context: &mut Context, buf: &[u8]) -> Poll> { let len = buf.len(); match self.outgoing.poll_ready(context) { Poll::Ready(Ok(())) => { if let Err(e) = self.outgoing.start_send(Bytes::copy_from_slice(buf)) { if e.is_disconnected() { - return Poll::Ready(Err(Error::new(ErrorKind::BrokenPipe, e))); + return Poll::Ready(Err(io::Error::new(ErrorKind::BrokenPipe, e))); } // Unbounded channels should only ever have "Disconnected" errors @@ -471,7 +478,7 @@ impl AsyncWrite for MemorySocket { }, Poll::Ready(Err(e)) => { if e.is_disconnected() { - return Poll::Ready(Err(Error::new(ErrorKind::BrokenPipe, e))); + return Poll::Ready(Err(io::Error::new(ErrorKind::BrokenPipe, e))); } // Unbounded channels should only ever have "Disconnected" errors @@ -484,12 +491,12 @@ impl AsyncWrite for MemorySocket { } /// Attempt to flush the channel. Cannot Fail. - fn poll_flush(self: Pin<&mut Self>, _context: &mut Context) -> Poll> { + fn poll_flush(self: Pin<&mut Self>, _context: &mut Context) -> Poll> { Poll::Ready(Ok(())) } /// Attempt to close the channel. Cannot Fail. - fn poll_close(self: Pin<&mut Self>, _context: &mut Context) -> Poll> { + fn poll_shutdown(self: Pin<&mut Self>, _context: &mut Context) -> Poll> { self.outgoing.close_channel(); Poll::Ready(Ok(())) @@ -499,15 +506,12 @@ impl AsyncWrite for MemorySocket { #[cfg(test)] mod test { use super::*; - use futures::{ - executor::block_on, - io::{AsyncReadExt, AsyncWriteExt}, - stream::StreamExt, - }; - use std::io::Result; + use crate::runtime; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio_stream::StreamExt; #[test] - fn listener_bind() -> Result<()> { + fn listener_bind() -> io::Result<()> { let port = acquire_next_memsocket_port().into(); let listener = MemoryListener::bind(port)?; assert_eq!(listener.local_addr(), port); @@ -515,172 +519,187 @@ mod test { Ok(()) } - #[test] - fn simple_connect() -> Result<()> { + #[runtime::test] + async fn simple_connect() -> io::Result<()> { let port = acquire_next_memsocket_port().into(); let mut listener = MemoryListener::bind(port)?; let mut dialer = MemorySocket::connect(port)?; - let mut listener_socket = block_on(listener.incoming().next()).unwrap()?; + let mut listener_socket = listener.incoming().next().await.unwrap()?; - block_on(dialer.write_all(b"foo"))?; - block_on(dialer.flush())?; + dialer.write_all(b"foo").await?; + dialer.flush().await?; let mut buf = [0; 3]; - block_on(listener_socket.read_exact(&mut buf))?; + listener_socket.read_exact(&mut buf).await?; assert_eq!(&buf, b"foo"); Ok(()) } - #[test] - fn listen_on_port_zero() -> Result<()> { + #[runtime::test] + async fn listen_on_port_zero() -> io::Result<()> { let mut listener = MemoryListener::bind(0)?; let listener_addr = listener.local_addr(); let mut dialer = MemorySocket::connect(listener_addr)?; - let mut listener_socket = block_on(listener.incoming().next()).unwrap()?; + let mut listener_socket = listener.incoming().next().await.unwrap()?; - block_on(dialer.write_all(b"foo"))?; - block_on(dialer.flush())?; + dialer.write_all(b"foo").await?; + dialer.flush().await?; let mut buf = [0; 3]; - block_on(listener_socket.read_exact(&mut buf))?; + listener_socket.read_exact(&mut buf).await?; assert_eq!(&buf, b"foo"); - block_on(listener_socket.write_all(b"bar"))?; - block_on(listener_socket.flush())?; + listener_socket.write_all(b"bar").await?; + listener_socket.flush().await?; let mut buf = [0; 3]; - 
block_on(dialer.read_exact(&mut buf))?; + dialer.read_exact(&mut buf).await?; assert_eq!(&buf, b"bar"); Ok(()) } - #[test] - fn listener_correctly_frees_port_on_drop() -> Result<()> { - fn connect_on_port(port: u16) -> Result<()> { - let mut listener = MemoryListener::bind(port)?; - let mut dialer = MemorySocket::connect(port)?; - let mut listener_socket = block_on(listener.incoming().next()).unwrap()?; + #[runtime::test] + async fn listener_correctly_frees_port_on_drop() { + async fn connect_on_port(port: u16) { + let mut listener = MemoryListener::bind(port).unwrap(); + let mut dialer = MemorySocket::connect(port).unwrap(); + let mut listener_socket = listener.incoming().next().await.unwrap().unwrap(); - block_on(dialer.write_all(b"foo"))?; - block_on(dialer.flush())?; + dialer.write_all(b"foo").await.unwrap(); + dialer.flush().await.unwrap(); let mut buf = [0; 3]; - block_on(listener_socket.read_exact(&mut buf))?; + let n = listener_socket.read_exact(&mut buf).await.unwrap(); + assert_eq!(n, 3); assert_eq!(&buf, b"foo"); - - Ok(()) } let port = acquire_next_memsocket_port().into(); - connect_on_port(port)?; - connect_on_port(port)?; - - Ok(()) + connect_on_port(port).await; + connect_on_port(port).await; } - #[test] - fn simple_write_read() -> Result<()> { + #[runtime::test] + async fn simple_write_read() -> io::Result<()> { let (mut a, mut b) = MemorySocket::new_pair(); - block_on(a.write_all(b"hello world"))?; - block_on(a.flush())?; + a.write_all(b"hello world").await?; + a.flush().await?; drop(a); let mut v = Vec::new(); - block_on(b.read_to_end(&mut v))?; + b.read_to_end(&mut v).await?; assert_eq!(v, b"hello world"); Ok(()) } - #[test] - fn partial_read() -> Result<()> { + #[runtime::test] + async fn partial_read() -> io::Result<()> { let (mut a, mut b) = MemorySocket::new_pair(); - block_on(a.write_all(b"foobar"))?; - block_on(a.flush())?; + a.write_all(b"foobar").await?; + a.flush().await?; let mut buf = [0; 3]; - block_on(b.read_exact(&mut buf))?; + b.read_exact(&mut buf).await?; assert_eq!(&buf, b"foo"); - block_on(b.read_exact(&mut buf))?; + b.read_exact(&mut buf).await?; assert_eq!(&buf, b"bar"); Ok(()) } - #[test] - fn partial_read_write_both_sides() -> Result<()> { + #[runtime::test] + async fn partial_read_write_both_sides() -> io::Result<()> { let (mut a, mut b) = MemorySocket::new_pair(); - block_on(a.write_all(b"foobar"))?; - block_on(a.flush())?; - block_on(b.write_all(b"stormlight"))?; - block_on(b.flush())?; + a.write_all(b"foobar").await?; + a.flush().await?; + b.write_all(b"stormlight").await?; + b.flush().await?; let mut buf_a = [0; 5]; let mut buf_b = [0; 3]; - block_on(a.read_exact(&mut buf_a))?; + a.read_exact(&mut buf_a).await?; assert_eq!(&buf_a, b"storm"); - block_on(b.read_exact(&mut buf_b))?; + b.read_exact(&mut buf_b).await?; assert_eq!(&buf_b, b"foo"); - block_on(a.read_exact(&mut buf_a))?; + a.read_exact(&mut buf_a).await?; assert_eq!(&buf_a, b"light"); - block_on(b.read_exact(&mut buf_b))?; + b.read_exact(&mut buf_b).await?; assert_eq!(&buf_b, b"bar"); Ok(()) } - #[test] - fn many_small_writes() -> Result<()> { + #[runtime::test] + async fn many_small_writes() -> io::Result<()> { let (mut a, mut b) = MemorySocket::new_pair(); - block_on(a.write_all(b"words"))?; - block_on(a.write_all(b" "))?; - block_on(a.write_all(b"of"))?; - block_on(a.write_all(b" "))?; - block_on(a.write_all(b"radiance"))?; - block_on(a.flush())?; + a.write_all(b"words").await?; + a.write_all(b" ").await?; + a.flush().await?; + a.write_all(b"of").await?; + a.write_all(b" ").await?; 
+ a.flush().await?; + a.write_all(b"radiance").await?; + a.flush().await?; drop(a); let mut buf = [0; 17]; - block_on(b.read_exact(&mut buf))?; + b.read_exact(&mut buf).await?; assert_eq!(&buf, b"words of radiance"); Ok(()) } - #[test] - fn read_zero_bytes() -> Result<()> { + #[runtime::test] + async fn large_writes() -> io::Result<()> { + let (mut a, mut b) = MemorySocket::new_pair(); + + let large_data = vec![123u8; 1024]; + a.write_all(&large_data).await?; + a.flush().await?; + drop(a); + + let mut buf = Vec::new(); + b.read_to_end(&mut buf).await?; + assert_eq!(buf.len(), 1024); + + Ok(()) + } + + #[runtime::test] + async fn read_zero_bytes() -> io::Result<()> { let (mut a, mut b) = MemorySocket::new_pair(); - block_on(a.write_all(b"way of kings"))?; - block_on(a.flush())?; + a.write_all(b"way of kings").await?; + a.flush().await?; let mut buf = [0; 12]; - block_on(b.read_exact(&mut buf[0..0]))?; + b.read_exact(&mut buf[0..0]).await?; assert_eq!(buf, [0; 12]); - block_on(b.read_exact(&mut buf))?; + b.read_exact(&mut buf).await?; assert_eq!(&buf, b"way of kings"); Ok(()) } - #[test] - fn read_bytes_with_large_buffer() -> Result<()> { + #[runtime::test] + async fn read_bytes_with_large_buffer() -> io::Result<()> { let (mut a, mut b) = MemorySocket::new_pair(); - block_on(a.write_all(b"way of kings"))?; - block_on(a.flush())?; + a.write_all(b"way of kings").await?; + a.flush().await?; let mut buf = [0; 20]; - let bytes_read = block_on(b.read(&mut buf))?; + let bytes_read = b.read(&mut buf).await?; assert_eq!(bytes_read, 12); assert_eq!(&buf[0..12], b"way of kings"); diff --git a/comms/src/message/outbound.rs b/comms/src/message/outbound.rs index 25f0899143..a08c9604ce 100644 --- a/comms/src/message/outbound.rs +++ b/comms/src/message/outbound.rs @@ -22,11 +22,11 @@ use crate::{message::MessageTag, peer_manager::NodeId, protocol::messaging::SendFailReason}; use bytes::Bytes; -use futures::channel::oneshot; use std::{ fmt, fmt::{Error, Formatter}, }; +use tokio::sync::oneshot; pub type MessagingReplyResult = Result<(), SendFailReason>; pub type MessagingReplyRx = oneshot::Receiver; diff --git a/comms/src/multiplexing/yamux.rs b/comms/src/multiplexing/yamux.rs index 1ba104ce04..17558133f2 100644 --- a/comms/src/multiplexing/yamux.rs +++ b/comms/src/multiplexing/yamux.rs @@ -21,17 +21,14 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
use crate::{connection_manager::ConnectionDirection, runtime}; -use futures::{ - channel::mpsc, - io::{AsyncRead, AsyncWrite}, - stream::FusedStream, - task::Context, - SinkExt, - Stream, - StreamExt, -}; +use futures::{task::Context, Stream}; use std::{future::Future, io, pin::Pin, sync::Arc, task::Poll}; use tari_shutdown::{Shutdown, ShutdownSignal}; +use tokio::{ + io::{AsyncRead, AsyncWrite, ReadBuf}, + sync::mpsc, +}; +use tokio_util::compat::{Compat, FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; use tracing::{self, debug, error, event, Level}; use yamux::Mode; @@ -70,7 +67,7 @@ impl Yamux { config.set_receive_window(RECEIVE_WINDOW); let substream_counter = SubstreamCounter::new(); - let connection = yamux::Connection::new(socket, config, mode); + let connection = yamux::Connection::new(socket.compat(), config, mode); let control = Control::new(connection.control(), substream_counter.clone()); let incoming = Self::spawn_incoming_stream_worker(connection, substream_counter.clone()); @@ -88,12 +85,11 @@ impl Yamux { counter: SubstreamCounter, ) -> IncomingSubstreams where - TSocket: AsyncRead + AsyncWrite + Unpin + Send + 'static, + TSocket: futures::AsyncRead + futures::AsyncWrite + Unpin + Send + 'static, { let shutdown = Shutdown::new(); let (incoming_tx, incoming_rx) = mpsc::channel(10); - let stream = yamux::into_stream(connection).boxed(); - let incoming = IncomingWorker::new(stream, incoming_tx, shutdown.to_signal()); + let incoming = IncomingWorker::new(connection, incoming_tx, shutdown.to_signal()); runtime::task::spawn(incoming.run()); IncomingSubstreams::new(incoming_rx, counter, shutdown) } @@ -122,10 +118,6 @@ impl Yamux { pub(crate) fn substream_counter(&self) -> SubstreamCounter { self.substream_counter.clone() } - - pub fn is_terminated(&self) -> bool { - self.incoming.is_terminated() - } } #[derive(Clone)] @@ -146,7 +138,7 @@ impl Control { pub async fn open_stream(&mut self) -> Result { let stream = self.inner.open_stream().await?; Ok(Substream { - stream, + stream: stream.compat(), counter_guard: self.substream_counter.new_guard(), }) } @@ -185,19 +177,13 @@ impl IncomingSubstreams { } } -impl FusedStream for IncomingSubstreams { - fn is_terminated(&self) -> bool { - self.inner.is_terminated() - } -} - impl Stream for IncomingSubstreams { type Item = Substream; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match futures::ready!(Pin::new(&mut self.inner).poll_next(cx)) { + match futures::ready!(Pin::new(&mut self.inner).poll_recv(cx)) { Some(stream) => Poll::Ready(Some(Substream { - stream, + stream: stream.compat(), counter_guard: self.substream_counter.new_guard(), })), None => Poll::Ready(None), @@ -213,17 +199,17 @@ impl Drop for IncomingSubstreams { #[derive(Debug)] pub struct Substream { - stream: yamux::Stream, + stream: Compat, counter_guard: CounterGuard, } -impl AsyncRead for Substream { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { +impl tokio::io::AsyncRead for Substream { + fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { Pin::new(&mut self.stream).poll_read(cx, buf) } } -impl AsyncWrite for Substream { +impl tokio::io::AsyncWrite for Substream { fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { Pin::new(&mut self.stream).poll_write(cx, buf) } @@ -232,23 +218,23 @@ impl AsyncWrite for Substream { Pin::new(&mut self.stream).poll_flush(cx) } - fn poll_close(mut self: Pin<&mut Self>, cx: &mut 
Context<'_>) -> Poll> { - Pin::new(&mut self.stream).poll_close(cx) + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.stream).poll_shutdown(cx) } } -struct IncomingWorker { - inner: S, +struct IncomingWorker { + connection: yamux::Connection, sender: mpsc::Sender, shutdown_signal: ShutdownSignal, } -impl IncomingWorker -where S: Stream> + Unpin +impl IncomingWorker +where TSocket: futures::AsyncRead + futures::AsyncWrite + Unpin + Send + 'static /* */ { - pub fn new(stream: S, sender: IncomingTx, shutdown_signal: ShutdownSignal) -> Self { + pub fn new(connection: yamux::Connection, sender: IncomingTx, shutdown_signal: ShutdownSignal) -> Self { Self { - inner: stream, + connection, sender, shutdown_signal, } @@ -256,37 +242,55 @@ where S: Stream> + Unpin #[tracing::instrument(name = "yamux::incoming_worker::run", skip(self))] pub async fn run(mut self) { - let mut mux_stream = self.inner.take_until(&mut self.shutdown_signal); - while let Some(result) = mux_stream.next().await { - match result { - Ok(stream) => { - event!(Level::TRACE, "yamux::stream received {}", stream); - if self.sender.send(stream).await.is_err() { - debug!( - target: LOG_TARGET, - "Incoming peer substream task is shutting down because the internal stream sender channel \ - was closed" - ); - break; + loop { + tokio::select! { + biased; + + _ = &mut self.shutdown_signal => { + let mut control = self.connection.control(); + if let Err(err) = control.close().await { + error!(target: LOG_TARGET, "Failed to close yamux connection: {}", err); } - }, - Err(err) => { - event!( + break + } + + result = self.connection.next_stream() => { + match result { + Ok(Some(stream)) => { + event!(Level::TRACE, "yamux::stream received {}", stream);if self.sender.send(stream).await.is_err() { + debug!( + target: LOG_TARGET, + "Incoming peer substream task is shutting down because the internal stream sender channel \ + was closed" + ); + break; + } + }, + Ok(None) =>{ + debug!( + target: LOG_TARGET, + "Incoming peer substream completed. IncomingWorker exiting" + ); + break; + } + Err(err) => { + event!( Level::ERROR, "Incoming peer substream task received an error because '{}'", err ); error!( - target: LOG_TARGET, - "Incoming peer substream task received an error because '{}'", err - ); - break; - }, + target: LOG_TARGET, + "Incoming peer substream task received an error because '{}'", err + ); + break; + }, + } + } } } debug!(target: LOG_TARGET, "Incoming peer substream task is shutting down"); - self.sender.close_channel(); } } @@ -321,15 +325,12 @@ mod test { runtime, runtime::task, }; - use futures::{ - future, - io::{AsyncReadExt, AsyncWriteExt}, - StreamExt, - }; use std::{io, time::Duration}; use tari_test_utils::collect_stream; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio_stream::StreamExt; - #[runtime::test_basic] + #[runtime::test] async fn open_substream() -> io::Result<()> { let (dialer, listener) = MemorySocket::new_pair(); let msg = b"The Way of Kings"; @@ -344,7 +345,7 @@ mod test { substream.write_all(msg).await.unwrap(); substream.flush().await.unwrap(); - substream.close().await.unwrap(); + substream.shutdown().await.unwrap(); }); let mut listener = Yamux::upgrade_connection(listener, ConnectionDirection::Inbound) @@ -356,13 +357,16 @@ mod test { .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no substream"))?; let mut buf = Vec::new(); - let _ = future::select(substream.read_to_end(&mut buf), listener.next()).await; + tokio::select! 
{ + _ = substream.read_to_end(&mut buf) => {}, + _ = listener.next() => {}, + }; assert_eq!(buf, msg); Ok(()) } - #[runtime::test_basic] + #[runtime::test] async fn substream_count() { const NUM_SUBSTREAMS: usize = 10; let (dialer, listener) = MemorySocket::new_pair(); @@ -396,7 +400,7 @@ mod test { assert_eq!(listener.substream_count(), 0); } - #[runtime::test_basic] + #[runtime::test] async fn close() -> io::Result<()> { let (dialer, listener) = MemorySocket::new_pair(); let msg = b"Words of Radiance"; @@ -425,7 +429,7 @@ mod test { assert_eq!(buf, msg); // Close the substream and then try to write to it - substream.close().await?; + substream.shutdown().await?; let result = substream.write_all(b"ignored message").await; match result { @@ -436,7 +440,7 @@ mod test { Ok(()) } - #[runtime::test_basic] + #[runtime::test] async fn send_big_message() -> io::Result<()> { #[allow(non_upper_case_globals)] static MiB: usize = 1 << 20; @@ -457,7 +461,7 @@ mod test { let mut buf = vec![0u8; MSG_LEN]; substream.read_exact(&mut buf).await.unwrap(); - substream.close().await.unwrap(); + substream.shutdown().await.unwrap(); assert_eq!(buf.len(), MSG_LEN); assert_eq!(buf, vec![0xAAu8; MSG_LEN]); @@ -476,7 +480,7 @@ mod test { let msg = vec![0xAAu8; MSG_LEN]; substream.write_all(msg.as_slice()).await?; - substream.close().await?; + substream.shutdown().await?; drop(substream); assert_eq!(incoming.substream_count(), 0); diff --git a/comms/src/noise/config.rs b/comms/src/noise/config.rs index 30cab07c48..8df40d7f05 100644 --- a/comms/src/noise/config.rs +++ b/comms/src/noise/config.rs @@ -31,11 +31,11 @@ use crate::{ }, peer_manager::NodeIdentity, }; -use futures::{AsyncRead, AsyncWrite}; use log::*; use snow::{self, params::NoiseParams}; use std::sync::Arc; use tari_crypto::tari_utilities::ByteArray; +use tokio::io::{AsyncRead, AsyncWrite}; const LOG_TARGET: &str = "comms::noise"; pub(super) const NOISE_IX_PARAMETER: &str = "Noise_IX_25519_ChaChaPoly_BLAKE2b"; @@ -96,10 +96,15 @@ impl NoiseConfig { #[cfg(test)] mod test { use super::*; - use crate::{memsocket::MemorySocket, peer_manager::PeerFeatures, test_utils::node_identity::build_node_identity}; - use futures::{future, AsyncReadExt, AsyncWriteExt, FutureExt}; + use crate::{ + memsocket::MemorySocket, + peer_manager::PeerFeatures, + runtime, + test_utils::node_identity::build_node_identity, + }; + use futures::{future, FutureExt}; use snow::params::{BaseChoice, CipherChoice, DHChoice, HandshakePattern, HashChoice}; - use tokio::runtime::Runtime; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; fn check_noise_params(config: &NoiseConfig) { assert_eq!(config.parameters.hash, HashChoice::Blake2b); @@ -118,39 +123,35 @@ mod test { assert_eq!(config.node_identity.public_key(), node_identity.public_key()); } - #[test] - fn upgrade_socket() { - let mut rt = Runtime::new().unwrap(); - + #[runtime::test] + async fn upgrade_socket() { let node_identity1 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); let config1 = NoiseConfig::new(node_identity1.clone()); let node_identity2 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); let config2 = NoiseConfig::new(node_identity2.clone()); - rt.block_on(async move { - let (in_socket, out_socket) = MemorySocket::new_pair(); - let (mut socket_in, mut socket_out) = future::join( - config1.upgrade_socket(in_socket, ConnectionDirection::Inbound), - config2.upgrade_socket(out_socket, ConnectionDirection::Outbound), - ) - .map(|(s1, s2)| (s1.unwrap(), s2.unwrap())) - .await; - - let in_pubkey = 
socket_in.get_remote_public_key().unwrap(); - let out_pubkey = socket_out.get_remote_public_key().unwrap(); - - assert_eq!(&in_pubkey, node_identity2.public_key()); - assert_eq!(&out_pubkey, node_identity1.public_key()); - - let sample = b"Children of time"; - socket_in.write_all(sample).await.unwrap(); - socket_in.flush().await.unwrap(); - socket_in.close().await.unwrap(); - - let mut read_buf = Vec::with_capacity(16); - socket_out.read_to_end(&mut read_buf).await.unwrap(); - assert_eq!(read_buf, sample); - }); + let (in_socket, out_socket) = MemorySocket::new_pair(); + let (mut socket_in, mut socket_out) = future::join( + config1.upgrade_socket(in_socket, ConnectionDirection::Inbound), + config2.upgrade_socket(out_socket, ConnectionDirection::Outbound), + ) + .map(|(s1, s2)| (s1.unwrap(), s2.unwrap())) + .await; + + let in_pubkey = socket_in.get_remote_public_key().unwrap(); + let out_pubkey = socket_out.get_remote_public_key().unwrap(); + + assert_eq!(&in_pubkey, node_identity2.public_key()); + assert_eq!(&out_pubkey, node_identity1.public_key()); + + let sample = b"Children of time"; + socket_in.write_all(sample).await.unwrap(); + socket_in.flush().await.unwrap(); + socket_in.shutdown().await.unwrap(); + + let mut read_buf = Vec::with_capacity(16); + socket_out.read_to_end(&mut read_buf).await.unwrap(); + assert_eq!(read_buf, sample); } } diff --git a/comms/src/noise/socket.rs b/comms/src/noise/socket.rs index eaf02a60b0..33d35f89ca 100644 --- a/comms/src/noise/socket.rs +++ b/comms/src/noise/socket.rs @@ -26,27 +26,27 @@ //! Noise Socket +use crate::types::CommsPublicKey; use futures::ready; use log::*; use snow::{error::StateProblem, HandshakeState, TransportState}; use std::{ + cmp, convert::TryInto, io, pin::Pin, task::{Context, Poll}, }; -// use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use crate::types::CommsPublicKey; -use futures::{io::Error, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tari_crypto::tari_utilities::ByteArray; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; const LOG_TARGET: &str = "comms::noise::socket"; -const MAX_PAYLOAD_LENGTH: usize = u16::max_value() as usize; // 65535 +const MAX_PAYLOAD_LENGTH: usize = u16::MAX as usize; // 65535 // The maximum number of bytes that we can buffer is 16 bytes less than u16::max_value() because // encrypted messages include a tag along with the payload. 
-const MAX_WRITE_BUFFER_LENGTH: usize = u16::max_value() as usize - 16; // 65519 +const MAX_WRITE_BUFFER_LENGTH: usize = u16::MAX as usize - 16; // 65519 /// Collection of buffers used for buffering data during the various read/write states of a /// NoiseSocket @@ -223,7 +223,12 @@ where TSocket: AsyncRead, { loop { - let n = ready!(socket.as_mut().poll_read(&mut context, &mut buf[*offset..]))?; + let mut read_buf = ReadBuf::new(&mut buf[*offset..]); + let prev_rem = read_buf.remaining(); + ready!(socket.as_mut().poll_read(&mut context, &mut read_buf))?; + let n = prev_rem + .checked_sub(read_buf.remaining()) + .expect("buffer underflow: prev_rem < read_buf.remaining()"); trace!( target: LOG_TARGET, "poll_read_exact: read {}/{} bytes", @@ -320,7 +325,7 @@ where TSocket: AsyncRead + Unpin decrypted_len, ref mut offset, } => { - let bytes_to_copy = ::std::cmp::min(decrypted_len as usize - *offset, buf.len()); + let bytes_to_copy = cmp::min(decrypted_len as usize - *offset, buf.len()); buf[..bytes_to_copy] .copy_from_slice(&self.buffers.read_decrypted[*offset..(*offset + bytes_to_copy)]); trace!( @@ -351,8 +356,11 @@ where TSocket: AsyncRead + Unpin impl AsyncRead for NoiseSocket where TSocket: AsyncRead + Unpin { - fn poll_read(self: Pin<&mut Self>, context: &mut Context, buf: &mut [u8]) -> Poll> { - self.get_mut().poll_read(context, buf) + fn poll_read(self: Pin<&mut Self>, context: &mut Context, buf: &mut ReadBuf<'_>) -> Poll> { + let slice = buf.initialize_unfilled(); + let n = futures::ready!(self.get_mut().poll_read(context, slice))?; + buf.advance(n); + Poll::Ready(Ok(())) } } @@ -501,8 +509,8 @@ where TSocket: AsyncWrite + Unpin self.get_mut().poll_flush(cx) } - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.socket).poll_close(cx) + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.socket).poll_shutdown(cx) } } @@ -531,7 +539,7 @@ where TSocket: AsyncRead + AsyncWrite + Unpin target: LOG_TARGET, "Noise handshake failed because '{:?}'. 
Closing socket.", err ); - self.socket.close().await?; + self.socket.shutdown().await?; Err(err) }, } @@ -644,7 +652,6 @@ mod test { use futures::future::join; use snow::{params::NoiseParams, Builder, Error, Keypair}; use std::io; - use tokio::runtime::Runtime; async fn build_test_connection( ) -> Result<((Keypair, Handshake), (Keypair, Handshake)), Error> { @@ -707,7 +714,7 @@ mod test { dialer_socket.write_all(b" ").await?; dialer_socket.write_all(b"archive").await?; dialer_socket.flush().await?; - dialer_socket.close().await?; + dialer_socket.shutdown().await?; let mut buf = Vec::new(); listener_socket.read_to_end(&mut buf).await?; @@ -745,51 +752,60 @@ mod test { Ok(()) } - #[test] - fn u16_max_writes() -> io::Result<()> { - // Current thread runtime stack overflows, so the full tokio runtime is used here - let mut rt = Runtime::new().unwrap(); - rt.block_on(async move { - let ((_dialer_keypair, dialer), (_listener_keypair, listener)) = build_test_connection().await.unwrap(); + #[runtime::test] + async fn u16_max_writes() -> io::Result<()> { + let ((_dialer_keypair, dialer), (_listener_keypair, listener)) = build_test_connection().await.unwrap(); - let (mut a, mut b) = perform_handshake(dialer, listener).await?; + let (mut a, mut b) = perform_handshake(dialer, listener).await?; - let buf_send = [1; MAX_PAYLOAD_LENGTH]; - a.write_all(&buf_send).await?; - a.flush().await?; + let buf_send = [1; MAX_PAYLOAD_LENGTH + 1]; + a.write_all(&buf_send).await?; + a.flush().await?; - let mut buf_receive = [0; MAX_PAYLOAD_LENGTH]; - b.read_exact(&mut buf_receive).await?; - assert_eq!(&buf_receive[..], &buf_send[..]); + let mut buf_receive = [0; MAX_PAYLOAD_LENGTH + 1]; + b.read_exact(&mut buf_receive).await?; + assert_eq!(&buf_receive[..], &buf_send[..]); - Ok(()) - }) + Ok(()) } - #[test] - fn unexpected_eof() -> io::Result<()> { - // Current thread runtime stack overflows, so the full tokio runtime is used here - let mut rt = Runtime::new().unwrap(); - rt.block_on(async move { - let ((_dialer_keypair, dialer), (_listener_keypair, listener)) = build_test_connection().await.unwrap(); + #[runtime::test] + async fn larger_writes() -> io::Result<()> { + let ((_dialer_keypair, dialer), (_listener_keypair, listener)) = build_test_connection().await.unwrap(); - let (mut a, mut b) = perform_handshake(dialer, listener).await?; + let (mut a, mut b) = perform_handshake(dialer, listener).await?; - let buf_send = [1; MAX_PAYLOAD_LENGTH]; - a.write_all(&buf_send).await?; - a.flush().await?; + let buf_send = [1; MAX_PAYLOAD_LENGTH * 2 + 1024]; + a.write_all(&buf_send).await?; + a.flush().await?; - a.socket.close().await.unwrap(); - drop(a); + let mut buf_receive = [0; MAX_PAYLOAD_LENGTH * 2 + 1024]; + b.read_exact(&mut buf_receive).await?; + assert_eq!(&buf_receive[..], &buf_send[..]); - let mut buf_receive = [0; MAX_PAYLOAD_LENGTH]; - b.read_exact(&mut buf_receive).await.unwrap(); - assert_eq!(&buf_receive[..], &buf_send[..]); + Ok(()) + } + + #[runtime::test] + async fn unexpected_eof() -> io::Result<()> { + let ((_dialer_keypair, dialer), (_listener_keypair, listener)) = build_test_connection().await.unwrap(); - let err = b.read_exact(&mut buf_receive).await.unwrap_err(); - assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof); + let (mut a, mut b) = perform_handshake(dialer, listener).await?; - Ok(()) - }) + let buf_send = [1; MAX_PAYLOAD_LENGTH]; + a.write_all(&buf_send).await?; + a.flush().await?; + + a.socket.shutdown().await.unwrap(); + drop(a); + + let mut buf_receive = [0; MAX_PAYLOAD_LENGTH]; + 
b.read_exact(&mut buf_receive).await.unwrap(); + assert_eq!(&buf_receive[..], &buf_send[..]); + + let err = b.read_exact(&mut buf_receive).await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof); + + Ok(()) } } diff --git a/comms/src/peer_manager/manager.rs b/comms/src/peer_manager/manager.rs index 3828c82ab5..d7df51a2ea 100644 --- a/comms/src/peer_manager/manager.rs +++ b/comms/src/peer_manager/manager.rs @@ -337,7 +337,7 @@ mod test { peer } - #[runtime::test_basic] + #[runtime::test] async fn get_broadcast_identities() { // Create peer manager with random peers let peer_manager = PeerManager::new(HashmapDatabase::new(), None).unwrap(); @@ -446,7 +446,7 @@ mod test { assert_ne!(identities1, identities2); } - #[runtime::test_basic] + #[runtime::test] async fn calc_region_threshold() { let n = 5; // Create peer manager with random peers @@ -514,7 +514,7 @@ mod test { } } - #[runtime::test_basic] + #[runtime::test] async fn closest_peers() { let n = 5; // Create peer manager with random peers @@ -548,7 +548,7 @@ mod test { } } - #[runtime::test_basic] + #[runtime::test] async fn add_or_update_online_peer() { let peer_manager = PeerManager::new(HashmapDatabase::new(), None).unwrap(); let mut peer = create_test_peer(false, PeerFeatures::COMMUNICATION_NODE); diff --git a/comms/src/pipeline/builder.rs b/comms/src/pipeline/builder.rs index 40a38d10a3..9ae90fabec 100644 --- a/comms/src/pipeline/builder.rs +++ b/comms/src/pipeline/builder.rs @@ -24,8 +24,8 @@ use crate::{ message::{InboundMessage, OutboundMessage}, pipeline::SinkService, }; -use futures::channel::mpsc; use thiserror::Error; +use tokio::sync::mpsc; use tower::Service; const DEFAULT_MAX_CONCURRENT_TASKS: usize = 50; @@ -99,9 +99,7 @@ where TOutSvc: Service + Clone + Send + 'static, TInSvc: Service + Clone + Send + 'static, { - fn build_outbound( - &mut self, - ) -> Result, TOutSvc>, PipelineBuilderError> { + fn build_outbound(&mut self) -> Result, PipelineBuilderError> { let (out_sender, out_receiver) = mpsc::channel(self.outbound_buffer_size); let in_receiver = self @@ -137,9 +135,9 @@ where } } -pub struct OutboundPipelineConfig { +pub struct OutboundPipelineConfig { /// Messages read from this stream are passed to the pipeline - pub in_receiver: TInStream, + pub in_receiver: mpsc::Receiver, /// Receiver of `OutboundMessage`s coming from the pipeline pub out_receiver: mpsc::Receiver, /// The pipeline (`tower::Service`) to run for each in_stream message @@ -149,7 +147,7 @@ pub struct OutboundPipelineConfig { pub struct Config { pub max_concurrent_inbound_tasks: usize, pub inbound: TInSvc, - pub outbound: OutboundPipelineConfig, TOutSvc>, + pub outbound: OutboundPipelineConfig, } #[derive(Debug, Error)] diff --git a/comms/src/pipeline/inbound.rs b/comms/src/pipeline/inbound.rs index 0b2116bc37..1f135640a7 100644 --- a/comms/src/pipeline/inbound.rs +++ b/comms/src/pipeline/inbound.rs @@ -21,10 +21,11 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
use crate::bounded_executor::BoundedExecutor; -use futures::{future::FusedFuture, stream::FusedStream, Stream, StreamExt}; +use futures::future::FusedFuture; use log::*; use std::fmt::Display; use tari_shutdown::ShutdownSignal; +use tokio::sync::mpsc; use tower::{Service, ServiceExt}; const LOG_TARGET: &str = "comms::pipeline::inbound"; @@ -33,22 +34,26 @@ const LOG_TARGET: &str = "comms::pipeline::inbound"; /// The difference between this can ServiceExt::call_all is /// that ServicePipeline doesn't keep the result of the service /// call and that it spawns a task for each incoming item. -pub struct Inbound { +pub struct Inbound { executor: BoundedExecutor, service: TSvc, - stream: TStream, + stream: mpsc::Receiver, shutdown_signal: ShutdownSignal, } -impl Inbound +impl Inbound where - TStream: Stream + FusedStream + Unpin, - TStream::Item: Send + 'static, - TSvc: Service + Clone + Send + 'static, + TMsg: Send + 'static, + TSvc: Service + Clone + Send + 'static, TSvc::Error: Display + Send, TSvc::Future: Send, { - pub fn new(executor: BoundedExecutor, stream: TStream, service: TSvc, shutdown_signal: ShutdownSignal) -> Self { + pub fn new( + executor: BoundedExecutor, + stream: mpsc::Receiver, + service: TSvc, + shutdown_signal: ShutdownSignal, + ) -> Self { Self { executor, service, @@ -59,7 +64,7 @@ where } pub async fn run(mut self) { - while let Some(item) = self.stream.next().await { + while let Some(item) = self.stream.recv().await { // Check if the shutdown signal has been triggered. // If there are messages in the stream, drop them. Otherwise the stream is empty, // it will return None and the while loop will end. @@ -100,21 +105,25 @@ where mod test { use super::*; use crate::runtime; - use futures::{channel::mpsc, future, stream}; + use futures::future; use std::time::Duration; use tari_shutdown::Shutdown; - use tari_test_utils::collect_stream; - use tokio::{runtime::Handle, time}; + use tari_test_utils::collect_recv; + use tokio::{sync::mpsc, time}; use tower::service_fn; - #[runtime::test_basic] + #[runtime::test] async fn run() { let items = vec![1, 2, 3, 4, 5, 6]; - let stream = stream::iter(items.clone()).fuse(); + let (tx, mut stream) = mpsc::channel(items.len()); + for i in items.clone() { + tx.send(i).await.unwrap(); + } + stream.close(); - let (mut out_tx, mut out_rx) = mpsc::channel(items.len()); + let (out_tx, mut out_rx) = mpsc::channel(items.len()); - let executor = Handle::current(); + let executor = runtime::current(); let shutdown = Shutdown::new(); let pipeline = Inbound::new( BoundedExecutor::new(executor.clone(), 1), @@ -125,9 +134,10 @@ mod test { }), shutdown.to_signal(), ); + let spawned_task = executor.spawn(pipeline.run()); - let received = collect_stream!(out_rx, take = items.len(), timeout = Duration::from_secs(10)); + let received = collect_recv!(out_rx, take = items.len(), timeout = Duration::from_secs(10)); assert!(received.iter().all(|i| items.contains(i))); // Check that this task ends because the stream has closed diff --git a/comms/src/pipeline/mod.rs b/comms/src/pipeline/mod.rs index 1039374c65..a2a057354b 100644 --- a/comms/src/pipeline/mod.rs +++ b/comms/src/pipeline/mod.rs @@ -44,7 +44,4 @@ pub(crate) use inbound::Inbound; mod outbound; pub(crate) use outbound::Outbound; -mod translate_sink; -pub use translate_sink::TranslateSink; - pub type PipelineError = anyhow::Error; diff --git a/comms/src/pipeline/outbound.rs b/comms/src/pipeline/outbound.rs index c860166ad0..54facdb92b 100644 --- a/comms/src/pipeline/outbound.rs +++ 
b/comms/src/pipeline/outbound.rs @@ -25,34 +25,33 @@ use crate::{ pipeline::builder::OutboundPipelineConfig, protocol::messaging::MessagingRequest, }; -use futures::{channel::mpsc, future, future::Either, stream::FusedStream, SinkExt, Stream, StreamExt}; +use futures::future::Either; use log::*; use std::fmt::Display; -use tokio::runtime; +use tokio::{runtime, sync::mpsc}; use tower::{Service, ServiceExt}; const LOG_TARGET: &str = "comms::pipeline::outbound"; -pub struct Outbound { +pub struct Outbound { /// Executor used to spawn a pipeline for each received item on the stream executor: runtime::Handle, /// Outbound pipeline configuration containing the pipeline and it's in and out streams - config: OutboundPipelineConfig, + config: OutboundPipelineConfig, /// Request sender for Messaging messaging_request_tx: mpsc::Sender, } -impl Outbound +impl Outbound where - TStream: Stream + FusedStream + Unpin, - TStream::Item: Send + 'static, - TPipeline: Service + Clone + Send + 'static, + TItem: Send + 'static, + TPipeline: Service + Clone + Send + 'static, TPipeline::Error: Display + Send, TPipeline::Future: Send, { pub fn new( executor: runtime::Handle, - config: OutboundPipelineConfig, + config: OutboundPipelineConfig, messaging_request_tx: mpsc::Sender, ) -> Self { Self { @@ -64,10 +63,13 @@ where pub async fn run(mut self) { loop { - let either = future::select(self.config.in_receiver.next(), self.config.out_receiver.next()).await; + let either = tokio::select! { + next = self.config.in_receiver.recv() => Either::Left(next), + next = self.config.out_receiver.recv() => Either::Right(next) + }; match either { // Pipeline IN received a message. Spawn a new task for the pipeline - Either::Left((Some(msg), _)) => { + Either::Left(Some(msg)) => { let pipeline = self.config.pipeline.clone(); self.executor.spawn(async move { if let Err(err) = pipeline.oneshot(msg).await { @@ -76,7 +78,7 @@ where }); }, // Pipeline IN channel closed - Either::Left((None, _)) => { + Either::Left(None) => { info!( target: LOG_TARGET, "Outbound pipeline is shutting down because the in channel closed" @@ -84,7 +86,7 @@ where break; }, // Pipeline OUT received a message - Either::Right((Some(out_msg), _)) => { + Either::Right(Some(out_msg)) => { if self.messaging_request_tx.is_closed() { // MessagingRequest channel closed break; @@ -92,7 +94,7 @@ where self.send_messaging_request(out_msg).await; }, // Pipeline OUT channel closed - Either::Right((None, _)) => { + Either::Right(None) => { info!( target: LOG_TARGET, "Outbound pipeline is shutting down because the out channel closed" @@ -117,19 +119,22 @@ where #[cfg(test)] mod test { use super::*; - use crate::{pipeline::SinkService, runtime}; + use crate::{pipeline::SinkService, runtime, utils}; use bytes::Bytes; - use futures::stream; use std::time::Duration; - use tari_test_utils::{collect_stream, unpack_enum}; + use tari_test_utils::{collect_recv, unpack_enum}; use tokio::{runtime::Handle, time}; - #[runtime::test_basic] + #[runtime::test] async fn run() { const NUM_ITEMS: usize = 10; - let items = - (0..NUM_ITEMS).map(|i| OutboundMessage::new(Default::default(), Bytes::copy_from_slice(&i.to_be_bytes()))); - let stream = stream::iter(items).fuse(); + let (tx, in_receiver) = mpsc::channel(NUM_ITEMS); + utils::mpsc::send_all( + &tx, + (0..NUM_ITEMS).map(|i| OutboundMessage::new(Default::default(), Bytes::copy_from_slice(&i.to_be_bytes()))), + ) + .await + .unwrap(); let (out_tx, out_rx) = mpsc::channel(NUM_ITEMS); let (msg_tx, mut msg_rx) = mpsc::channel(NUM_ITEMS); let 
executor = Handle::current(); @@ -137,7 +142,7 @@ mod test { let pipeline = Outbound::new( executor.clone(), OutboundPipelineConfig { - in_receiver: stream, + in_receiver, out_receiver: out_rx, pipeline: SinkService::new(out_tx), }, @@ -146,7 +151,8 @@ mod test { let spawned_task = executor.spawn(pipeline.run()); - let requests = collect_stream!(msg_rx, take = NUM_ITEMS, timeout = Duration::from_millis(5)); + msg_rx.close(); + let requests = collect_recv!(msg_rx, timeout = Duration::from_millis(5)); for req in requests { unpack_enum!(MessagingRequest::SendMessage(_o) = req); } diff --git a/comms/src/pipeline/sink.rs b/comms/src/pipeline/sink.rs index a455aaf320..1b524e92a5 100644 --- a/comms/src/pipeline/sink.rs +++ b/comms/src/pipeline/sink.rs @@ -21,8 +21,8 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use super::PipelineError; -use futures::{future::BoxFuture, task::Context, FutureExt, Sink, SinkExt}; -use std::{pin::Pin, task::Poll}; +use futures::{future::BoxFuture, task::Context, FutureExt}; +use std::task::Poll; use tower::Service; /// A service which forwards and messages it gets to the given Sink @@ -35,22 +35,24 @@ impl SinkService { } } -impl Service for SinkService -where - T: Send + 'static, - TSink: Sink + Unpin + Clone + Send + 'static, - TSink::Error: Into + Send + 'static, +impl Service for SinkService> +where T: Send + 'static { type Error = PipelineError; type Future = BoxFuture<'static, Result>; type Response = (); - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_ready(cx).map_err(Into::into) + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, item: T) -> Self::Future { - let mut sink = self.0.clone(); - async move { sink.send(item).await.map_err(Into::into) }.boxed() + let sink = self.0.clone(); + async move { + sink.send(item) + .await + .map_err(|_| anyhow::anyhow!("sink closed in sink service")) + } + .boxed() } } diff --git a/comms/src/pipeline/translate_sink.rs b/comms/src/pipeline/translate_sink.rs index 6a2bcad56a..606c038299 100644 --- a/comms/src/pipeline/translate_sink.rs +++ b/comms/src/pipeline/translate_sink.rs @@ -93,9 +93,10 @@ where F: FnMut(I) -> Option mod test { use super::*; use crate::runtime; - use futures::{channel::mpsc, SinkExt, StreamExt}; + use futures::{SinkExt, StreamExt}; + use tokio::sync::mpsc; - #[runtime::test_basic] + #[runtime::test] async fn check_translates() { let (tx, mut rx) = mpsc::channel(1); diff --git a/comms/src/protocol/identity.rs b/comms/src/protocol/identity.rs index 2c4eba1db5..f984117bb8 100644 --- a/comms/src/protocol/identity.rs +++ b/comms/src/protocol/identity.rs @@ -20,19 +20,21 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
use crate::{ - compat::IoCompat, connection_manager::ConnectionDirection, message::MessageExt, peer_manager::NodeIdentity, proto::identity::PeerIdentityMsg, protocol::{NodeNetworkInfo, ProtocolError, ProtocolId, ProtocolNegotiation}, }; -use futures::{AsyncRead, AsyncWrite, SinkExt, StreamExt}; +use futures::{SinkExt, StreamExt}; use log::*; use prost::Message; use std::{io, time::Duration}; use thiserror::Error; -use tokio::time; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + time, +}; use tokio_util::codec::{Framed, LengthDelimitedCodec}; use tracing; @@ -79,7 +81,7 @@ where debug_assert_eq!(proto, IDENTITY_PROTOCOL); // Create length-delimited frame codec - let framed = Framed::new(IoCompat::new(socket), LengthDelimitedCodec::new()); + let framed = Framed::new(socket, LengthDelimitedCodec::new()); let (mut sink, mut stream) = framed.split(); let supported_protocols = our_supported_protocols.into_iter().map(|p| p.to_vec()).collect(); @@ -136,8 +138,8 @@ pub enum IdentityProtocolError { ProtocolVersionMismatch, } -impl From for IdentityProtocolError { - fn from(_: time::Elapsed) -> Self { +impl From for IdentityProtocolError { + fn from(_: time::error::Elapsed) -> Self { IdentityProtocolError::Timeout } } @@ -172,7 +174,7 @@ mod test { }; use futures::{future, StreamExt}; - #[runtime::test_basic] + #[runtime::test] async fn identity_exchange() { let transport = MemoryTransport; let addr = "/memory/0".parse().unwrap(); @@ -221,7 +223,7 @@ mod test { assert_eq!(identity2.addresses, vec![node_identity2.public_address().to_vec()]); } - #[runtime::test_basic] + #[runtime::test] async fn fail_cases() { let transport = MemoryTransport; let addr = "/memory/0".parse().unwrap(); diff --git a/comms/src/protocol/messaging/error.rs b/comms/src/protocol/messaging/error.rs index 6f078cec23..91d0e786ba 100644 --- a/comms/src/protocol/messaging/error.rs +++ b/comms/src/protocol/messaging/error.rs @@ -20,10 +20,15 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{connection_manager::PeerConnectionError, peer_manager::PeerManagerError, protocol::ProtocolError}; -use futures::channel::mpsc; +use crate::{ + connection_manager::PeerConnectionError, + message::OutboundMessage, + peer_manager::PeerManagerError, + protocol::ProtocolError, +}; use std::io; use thiserror::Error; +use tokio::sync::mpsc; #[derive(Debug, Error)] pub enum InboundMessagingError { @@ -46,7 +51,7 @@ pub enum MessagingProtocolError { #[error("IO Error: {0}")] Io(#[from] io::Error), #[error("Sender error: {0}")] - SenderError(#[from] mpsc::SendError), + SenderError(#[from] mpsc::error::SendError), #[error("Stream closed due to inactivity")] Inactivity, } diff --git a/comms/src/protocol/messaging/extension.rs b/comms/src/protocol/messaging/extension.rs index 241a152a5b..f216ddd04e 100644 --- a/comms/src/protocol/messaging/extension.rs +++ b/comms/src/protocol/messaging/extension.rs @@ -34,8 +34,8 @@ use crate::{ runtime, runtime::task, }; -use futures::channel::mpsc; use std::fmt; +use tokio::sync::mpsc; use tower::Service; /// Buffer size for inbound messages from _all_ peers. 
This should be large enough to buffer quite a few incoming diff --git a/comms/src/protocol/messaging/forward.rs b/comms/src/protocol/messaging/forward.rs new file mode 100644 index 0000000000..ce035e8fb8 --- /dev/null +++ b/comms/src/protocol/messaging/forward.rs @@ -0,0 +1,110 @@ +// Copyright 2021, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Copied from futures rs + +use futures::{ + future::{FusedFuture, Future}, + ready, + stream::{Fuse, StreamExt}, + task::{Context, Poll}, + Sink, + Stream, + TryStream, +}; +use pin_project::pin_project; +use std::pin::Pin; + +/// Future for the [`forward`](super::StreamExt::forward) method. +#[pin_project(project = ForwardProj)] +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct Forward { + #[pin] + sink: Option, + #[pin] + stream: Fuse, + buffered_item: Option, +} + +impl Forward +where St: TryStream +{ + pub(crate) fn new(stream: St, sink: Si) -> Self { + Self { + sink: Some(sink), + stream: stream.fuse(), + buffered_item: None, + } + } +} + +impl FusedFuture for Forward +where + Si: Sink, + St: Stream>, +{ + fn is_terminated(&self) -> bool { + self.sink.is_none() + } +} + +impl Future for Forward +where + Si: Sink, + St: Stream>, +{ + type Output = Result<(), E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let ForwardProj { + mut sink, + mut stream, + buffered_item, + } = self.project(); + let mut si = sink.as_mut().as_pin_mut().expect("polled `Forward` after completion"); + + loop { + // If we've got an item buffered already, we need to write it to the + // sink before we can do anything else + if buffered_item.is_some() { + ready!(si.as_mut().poll_ready(cx))?; + si.as_mut().start_send(buffered_item.take().unwrap())?; + } + + match stream.as_mut().poll_next(cx)? 
{ + Poll::Ready(Some(item)) => { + *buffered_item = Some(item); + }, + Poll::Ready(None) => { + ready!(si.poll_close(cx))?; + sink.set(None); + return Poll::Ready(Ok(())); + }, + Poll::Pending => { + ready!(si.poll_flush(cx))?; + return Poll::Pending; + }, + } + } + } +} diff --git a/comms/src/protocol/messaging/inbound.rs b/comms/src/protocol/messaging/inbound.rs index 643b07fc45..e65f81f8ec 100644 --- a/comms/src/protocol/messaging/inbound.rs +++ b/comms/src/protocol/messaging/inbound.rs @@ -21,15 +21,18 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{ - common::rate_limit::RateLimit, message::InboundMessage, peer_manager::NodeId, protocol::messaging::{MessagingEvent, MessagingProtocol}, + rate_limit::RateLimit, }; -use futures::{channel::mpsc, future::Either, AsyncRead, AsyncWrite, SinkExt, StreamExt}; +use futures::{future::Either, StreamExt}; use log::*; use std::{sync::Arc, time::Duration}; -use tokio::sync::broadcast; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{broadcast, mpsc}, +}; const LOG_TARGET: &str = "comms::protocol::messaging::inbound"; @@ -61,7 +64,7 @@ impl InboundMessaging { } } - pub async fn run(mut self, socket: S) + pub async fn run(self, socket: S) where S: AsyncRead + AsyncWrite + Unpin { let peer = &self.peer; debug!( @@ -70,48 +73,40 @@ impl InboundMessaging { peer.short_str() ); - let (mut sink, stream) = MessagingProtocol::framed(socket).split(); - - if let Err(err) = sink.close().await { - debug!( - target: LOG_TARGET, - "Error closing sink half for peer `{}`: {}", - peer.short_str(), - err - ); - } - let stream = stream.rate_limit(self.rate_limit_capacity, self.rate_limit_restock_interval); + let stream = + MessagingProtocol::framed(socket).rate_limit(self.rate_limit_capacity, self.rate_limit_restock_interval); - let mut stream = match self.inactivity_timeout { - Some(timeout) => Either::Left(tokio::stream::StreamExt::timeout(stream, timeout)), + let stream = match self.inactivity_timeout { + Some(timeout) => Either::Left(tokio_stream::StreamExt::timeout(stream, timeout)), None => Either::Right(stream.map(Ok)), }; + tokio::pin!(stream); while let Some(result) = stream.next().await { match result { Ok(Ok(raw_msg)) => { - let inbound_msg = InboundMessage::new(peer.clone(), raw_msg.clone().freeze()); + let msg_len = raw_msg.len(); + let inbound_msg = InboundMessage::new(peer.clone(), raw_msg.freeze()); debug!( target: LOG_TARGET, "Received message {} from peer '{}' ({} bytes)", inbound_msg.tag, peer.short_str(), - raw_msg.len() + msg_len ); let event = MessagingEvent::MessageReceived(inbound_msg.source_peer.clone(), inbound_msg.tag); if let Err(err) = self.inbound_message_tx.send(inbound_msg).await { + let tag = err.0.tag; warn!( target: LOG_TARGET, - "Failed to send InboundMessage for peer '{}' because '{}'", + "Failed to send InboundMessage {} for peer '{}' because inbound message channel closed", + tag, peer.short_str(), - err ); - if err.is_disconnected() { - break; - } + break; } let _ = self.messaging_events_tx.send(Arc::new(event)); diff --git a/comms/src/protocol/messaging/mod.rs b/comms/src/protocol/messaging/mod.rs index 88fca6af05..1fa347b3a2 100644 --- a/comms/src/protocol/messaging/mod.rs +++ b/comms/src/protocol/messaging/mod.rs @@ -27,9 +27,9 @@ mod extension; pub use extension::MessagingProtocolExtension; mod error; +mod forward; mod inbound; mod outbound; - mod protocol; pub use protocol::{ MessagingEvent, diff --git a/comms/src/protocol/messaging/outbound.rs 
b/comms/src/protocol/messaging/outbound.rs index 9d47895338..f48a9215b3 100644 --- a/comms/src/protocol/messaging/outbound.rs +++ b/comms/src/protocol/messaging/outbound.rs @@ -29,13 +29,13 @@ use crate::{ peer_manager::NodeId, protocol::messaging::protocol::MESSAGING_PROTOCOL, }; -use futures::{channel::mpsc, future::Either, SinkExt, StreamExt}; +use futures::{future::Either, StreamExt, TryStreamExt}; use log::*; use std::{ io, time::{Duration, Instant}, }; -use tokio::stream as tokio_stream; +use tokio::sync::mpsc as tokiompsc; use tracing::{event, span, Instrument, Level}; const LOG_TARGET: &str = "comms::protocol::messaging::outbound"; @@ -46,8 +46,8 @@ const MAX_SEND_RETRIES: usize = 1; pub struct OutboundMessaging { connectivity: ConnectivityRequester, - request_rx: mpsc::UnboundedReceiver, - messaging_events_tx: mpsc::Sender, + request_rx: tokiompsc::UnboundedReceiver, + messaging_events_tx: tokiompsc::Sender, peer_node_id: NodeId, inactivity_timeout: Option, } @@ -55,8 +55,8 @@ pub struct OutboundMessaging { impl OutboundMessaging { pub fn new( connectivity: ConnectivityRequester, - messaging_events_tx: mpsc::Sender, - request_rx: mpsc::UnboundedReceiver, + messaging_events_tx: tokiompsc::Sender, + request_rx: tokiompsc::UnboundedReceiver, peer_node_id: NodeId, inactivity_timeout: Option, ) -> Self { @@ -82,7 +82,7 @@ impl OutboundMessaging { self.peer_node_id.short_str() ); let peer_node_id = self.peer_node_id.clone(); - let mut messaging_events_tx = self.messaging_events_tx.clone(); + let messaging_events_tx = self.messaging_events_tx.clone(); match self.run_inner().await { Ok(_) => { event!( @@ -265,7 +265,7 @@ impl OutboundMessaging { ); let substream = substream.stream; - let (sink, _) = MessagingProtocol::framed(substream).split(); + let framed = MessagingProtocol::framed(substream); let Self { request_rx, @@ -273,30 +273,30 @@ impl OutboundMessaging { .. 
} = self; + // Convert unbounded channel to a stream + let stream = futures::stream::unfold(request_rx, |mut rx| async move { + let v = rx.recv().await; + v.map(|v| (v, rx)) + }); + let stream = match inactivity_timeout { Some(timeout) => { - let s = tokio_stream::StreamExt::timeout(request_rx, timeout).map(|r| match r { - Ok(s) => Ok(s), - Err(_) => Err(io::Error::new( - io::ErrorKind::TimedOut, - MessagingProtocolError::Inactivity, - )), - }); + let s = tokio_stream::StreamExt::timeout(stream, timeout) + .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, MessagingProtocolError::Inactivity)); Either::Left(s) }, - None => Either::Right(request_rx.map(Ok)), + None => Either::Right(stream.map(Ok)), }; - stream - .map(|msg| { - msg.map(|mut out_msg| { - event!(Level::DEBUG, "Message buffered for sending {}", out_msg); - out_msg.reply_success(); - out_msg.body - }) + let stream = stream.map(|msg| { + msg.map(|mut out_msg| { + event!(Level::DEBUG, "Message buffered for sending {}", out_msg); + out_msg.reply_success(); + out_msg.body }) - .forward(sink) - .await?; + }); + + super::forward::Forward::new(stream, framed).await?; debug!( target: LOG_TARGET, @@ -310,7 +310,7 @@ impl OutboundMessaging { // Close the request channel so that we can read all the remaining messages and flush them // to a failed event self.request_rx.close(); - while let Some(mut out_msg) = self.request_rx.next().await { + while let Some(mut out_msg) = self.request_rx.recv().await { out_msg.reply_fail(reason); let _ = self .messaging_events_tx diff --git a/comms/src/protocol/messaging/protocol.rs b/comms/src/protocol/messaging/protocol.rs index 1f6fe029ab..988b4ada21 100644 --- a/comms/src/protocol/messaging/protocol.rs +++ b/comms/src/protocol/messaging/protocol.rs @@ -22,7 +22,6 @@ use super::error::MessagingProtocolError; use crate::{ - compat::IoCompat, connectivity::{ConnectivityEvent, ConnectivityRequester}, framing, message::{InboundMessage, MessageTag, OutboundMessage}, @@ -36,7 +35,6 @@ use crate::{ runtime::task, }; use bytes::Bytes; -use futures::{channel::mpsc, stream::Fuse, AsyncRead, AsyncWrite, SinkExt, StreamExt}; use log::*; use std::{ collections::{hash_map::Entry, HashMap}, @@ -46,7 +44,10 @@ use std::{ }; use tari_shutdown::{Shutdown, ShutdownSignal}; use thiserror::Error; -use tokio::sync::broadcast; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{broadcast, mpsc}, +}; use tokio_util::codec::{Framed, LengthDelimitedCodec}; const LOG_TARGET: &str = "comms::protocol::messaging"; @@ -106,13 +107,13 @@ impl fmt::Display for MessagingEvent { pub struct MessagingProtocol { config: MessagingConfig, connectivity: ConnectivityRequester, - proto_notification: Fuse>>, + proto_notification: mpsc::Receiver>, active_queues: HashMap>, - request_rx: Fuse>, + request_rx: mpsc::Receiver, messaging_events_tx: MessagingEventSender, inbound_message_tx: mpsc::Sender, internal_messaging_event_tx: mpsc::Sender, - internal_messaging_event_rx: Fuse>, + internal_messaging_event_rx: mpsc::Receiver, shutdown_signal: ShutdownSignal, complete_trigger: Shutdown, } @@ -133,11 +134,11 @@ impl MessagingProtocol { Self { config, connectivity, - proto_notification: proto_notification.fuse(), - request_rx: request_rx.fuse(), + proto_notification, + request_rx, active_queues: Default::default(), messaging_events_tx, - internal_messaging_event_rx: internal_messaging_event_rx.fuse(), + internal_messaging_event_rx, internal_messaging_event_tx, inbound_message_tx, shutdown_signal, @@ -151,15 +152,15 @@ impl MessagingProtocol { pub async 
fn run(mut self) { let mut shutdown_signal = self.shutdown_signal.clone(); - let mut connectivity_events = self.connectivity.get_event_subscription().fuse(); + let mut connectivity_events = self.connectivity.get_event_subscription(); loop { - futures::select! { - event = self.internal_messaging_event_rx.select_next_some() => { + tokio::select! { + Some(event) = self.internal_messaging_event_rx.recv() => { self.handle_internal_messaging_event(event).await; }, - req = self.request_rx.select_next_some() => { + Some(req) = self.request_rx.recv() => { if let Err(err) = self.handle_request(req).await { error!( target: LOG_TARGET, @@ -169,17 +170,17 @@ impl MessagingProtocol { } }, - event = connectivity_events.select_next_some() => { + event = connectivity_events.recv() => { if let Ok(event) = event { self.handle_connectivity_event(&event); } } - notification = self.proto_notification.select_next_some() => { + Some(notification) = self.proto_notification.recv() => { self.handle_protocol_notification(notification).await; }, - _ = shutdown_signal => { + _ = &mut shutdown_signal => { info!(target: LOG_TARGET, "MessagingProtocol is shutting down because the shutdown signal was triggered"); break; } @@ -188,7 +189,7 @@ impl MessagingProtocol { } #[inline] - pub fn framed(socket: TSubstream) -> Framed, LengthDelimitedCodec> + pub fn framed(socket: TSubstream) -> Framed where TSubstream: AsyncRead + AsyncWrite + Unpin { framing::canonical(socket, MAX_FRAME_LENGTH) } @@ -198,11 +199,9 @@ impl MessagingProtocol { #[allow(clippy::single_match)] match event { PeerConnectionWillClose(node_id, _) => { - // If the peer connection will close, cut off the pipe to send further messages. - // Any messages in the channel will be sent (hopefully) before the connection is disconnected. - if let Some(sender) = self.active_queues.remove(node_id) { - sender.close_channel(); - } + // If the peer connection will close, cut off the pipe to send further messages by dropping the sender. + // Any messages in the channel may be sent before the connection is disconnected. 
+ let _ = self.active_queues.remove(node_id); }, _ => {}, } @@ -263,7 +262,7 @@ impl MessagingProtocol { let sender = Self::spawn_outbound_handler( self.connectivity.clone(), self.internal_messaging_event_tx.clone(), - peer_node_id.clone(), + peer_node_id, self.config.inactivity_timeout, ); break entry.insert(sender); @@ -273,7 +272,7 @@ impl MessagingProtocol { debug!(target: LOG_TARGET, "Sending message {}", out_msg); let tag = out_msg.tag; - match sender.send(out_msg).await { + match sender.send(out_msg) { Ok(_) => { debug!(target: LOG_TARGET, "Message ({}) dispatched to outbound handler", tag,); Ok(()) @@ -294,7 +293,7 @@ impl MessagingProtocol { peer_node_id: NodeId, inactivity_timeout: Option, ) -> mpsc::UnboundedSender { - let (msg_tx, msg_rx) = mpsc::unbounded(); + let (msg_tx, msg_rx) = mpsc::unbounded_channel(); let outbound_messaging = OutboundMessaging::new(connectivity, events_tx, msg_rx, peer_node_id, inactivity_timeout); task::spawn(outbound_messaging.run()); diff --git a/comms/src/protocol/messaging/test.rs b/comms/src/protocol/messaging/test.rs index e954af31f2..5d066eb4c2 100644 --- a/comms/src/protocol/messaging/test.rs +++ b/comms/src/protocol/messaging/test.rs @@ -49,18 +49,16 @@ use crate::{ types::{CommsDatabase, CommsPublicKey}, }; use bytes::Bytes; -use futures::{ - channel::{mpsc, oneshot}, - stream::FuturesUnordered, - SinkExt, - StreamExt, -}; +use futures::{stream::FuturesUnordered, SinkExt, StreamExt}; use rand::rngs::OsRng; use std::{io, sync::Arc, time::Duration}; use tari_crypto::keys::PublicKey; use tari_shutdown::Shutdown; -use tari_test_utils::{collect_stream, unpack_enum}; -use tokio::{sync::broadcast, time}; +use tari_test_utils::{collect_recv, collect_stream, unpack_enum}; +use tokio::{ + sync::{broadcast, mpsc, oneshot}, + time, +}; static TEST_MSG1: Bytes = Bytes::from_static(b"TEST_MSG1"); @@ -110,9 +108,9 @@ async fn spawn_messaging_protocol() -> ( ) } -#[runtime::test_basic] +#[runtime::test] async fn new_inbound_substream_handling() { - let (peer_manager, _, _, mut proto_tx, _, mut inbound_msg_rx, mut events_rx, _shutdown) = + let (peer_manager, _, _, proto_tx, _, mut inbound_msg_rx, mut events_rx, _shutdown) = spawn_messaging_protocol().await; let expected_node_id = node_id::random(); @@ -148,7 +146,7 @@ async fn new_inbound_substream_handling() { framed_theirs.send(TEST_MSG1.clone()).await.unwrap(); - let in_msg = time::timeout(Duration::from_secs(5), inbound_msg_rx.next()) + let in_msg = time::timeout(Duration::from_secs(5), inbound_msg_rx.recv()) .await .unwrap() .unwrap(); @@ -156,19 +154,18 @@ async fn new_inbound_substream_handling() { assert_eq!(in_msg.body, TEST_MSG1); let expected_tag = in_msg.tag; - let event = time::timeout(Duration::from_secs(5), events_rx.next()) + let event = time::timeout(Duration::from_secs(5), events_rx.recv()) .await .unwrap() - .unwrap() .unwrap(); unpack_enum!(MessagingEvent::MessageReceived(node_id, tag) = &*event); assert_eq!(tag, &expected_tag); assert_eq!(*node_id, expected_node_id); } -#[runtime::test_basic] +#[runtime::test] async fn send_message_request() { - let (_, node_identity, conn_man_mock, _, mut request_tx, _, _, _shutdown) = spawn_messaging_protocol().await; + let (_, node_identity, conn_man_mock, _, request_tx, _, _, _shutdown) = spawn_messaging_protocol().await; let peer_node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -192,9 +189,9 @@ async fn send_message_request() { assert_eq!(peer_conn_mock1.call_count(), 1); } -#[runtime::test_basic] +#[runtime::test] async fn 
send_message_dial_failed() { - let (_, _, conn_manager_mock, _, mut request_tx, _, mut event_tx, _shutdown) = spawn_messaging_protocol().await; + let (_, _, conn_manager_mock, _, request_tx, _, mut event_tx, _shutdown) = spawn_messaging_protocol().await; let node_id = node_id::random(); let out_msg = OutboundMessage::new(node_id, TEST_MSG1.clone()); @@ -202,7 +199,7 @@ async fn send_message_dial_failed() { // Send a message to node 2 request_tx.send(MessagingRequest::SendMessage(out_msg)).await.unwrap(); - let event = event_tx.next().await.unwrap().unwrap(); + let event = event_tx.recv().await.unwrap(); unpack_enum!(MessagingEvent::SendMessageFailed(out_msg, reason) = &*event); unpack_enum!(SendFailReason::PeerDialFailed = reason); assert_eq!(out_msg.tag, expected_out_msg_tag); @@ -212,7 +209,7 @@ async fn send_message_dial_failed() { assert!(calls.iter().all(|evt| evt.starts_with("DialPeer"))); } -#[runtime::test_basic] +#[runtime::test] async fn send_message_substream_bulk_failure() { const NUM_MSGS: usize = 10; let (_, node_identity, conn_manager_mock, _, mut request_tx, _, mut events_rx, _shutdown) = @@ -258,19 +255,18 @@ async fn send_message_substream_bulk_failure() { } // Check that the outbound handler closed - let event = time::timeout(Duration::from_secs(10), events_rx.next()) + let event = time::timeout(Duration::from_secs(10), events_rx.recv()) .await .unwrap() - .unwrap() .unwrap(); unpack_enum!(MessagingEvent::OutboundProtocolExited(node_id) = &*event); assert_eq!(node_id, peer_node_id); } -#[runtime::test_basic] +#[runtime::test] async fn many_concurrent_send_message_requests() { const NUM_MSGS: usize = 100; - let (_, _, conn_man_mock, _, mut request_tx, _, _, _shutdown) = spawn_messaging_protocol().await; + let (_, _, conn_man_mock, _, request_tx, _, _, _shutdown) = spawn_messaging_protocol().await; let node_identity1 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); let node_identity2 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -315,10 +311,10 @@ async fn many_concurrent_send_message_requests() { assert_eq!(peer_conn_mock1.call_count(), 1); } -#[runtime::test_basic] +#[runtime::test] async fn many_concurrent_send_message_requests_that_fail() { const NUM_MSGS: usize = 100; - let (_, _, _, _, mut request_tx, _, mut events_rx, _shutdown) = spawn_messaging_protocol().await; + let (_, _, _, _, request_tx, _, mut events_rx, _shutdown) = spawn_messaging_protocol().await; let node_id2 = node_id::random(); @@ -339,10 +335,9 @@ async fn many_concurrent_send_message_requests_that_fail() { } // Check that we got message success events - let events = collect_stream!(events_rx, take = NUM_MSGS, timeout = Duration::from_secs(10)); + let events = collect_recv!(events_rx, take = NUM_MSGS, timeout = Duration::from_secs(10)); assert_eq!(events.len(), NUM_MSGS); for event in events { - let event = event.unwrap(); unpack_enum!(MessagingEvent::SendMessageFailed(out_msg, reason) = &*event); unpack_enum!(SendFailReason::PeerDialFailed = reason); // Assert that each tag is emitted only once @@ -357,7 +352,7 @@ async fn many_concurrent_send_message_requests_that_fail() { assert_eq!(msg_tags.len(), 0); } -#[runtime::test_basic] +#[runtime::test] async fn inactivity_timeout() { let node_identity = build_node_identity(PeerFeatures::COMMUNICATION_CLIENT); let (inbound_msg_tx, mut inbound_msg_rx) = mpsc::channel(5); @@ -381,13 +376,13 @@ async fn inactivity_timeout() { let mut framed = MessagingProtocol::framed(socket_out); for _ in 0..5u8 { framed.send(Bytes::from_static(b"some 
message")).await.unwrap(); - time::delay_for(Duration::from_millis(1)).await; + time::sleep(Duration::from_millis(1)).await; } - time::delay_for(Duration::from_millis(10)).await; + time::sleep(Duration::from_millis(10)).await; let err = framed.send(Bytes::from_static(b"another message")).await.unwrap_err(); assert_eq!(err.kind(), io::ErrorKind::BrokenPipe); - let _ = collect_stream!(inbound_msg_rx, take = 5, timeout = Duration::from_secs(10)); + let _ = collect_recv!(inbound_msg_rx, take = 5, timeout = Duration::from_secs(10)); } diff --git a/comms/src/protocol/negotiation.rs b/comms/src/protocol/negotiation.rs index 5178d790e6..326910dc99 100644 --- a/comms/src/protocol/negotiation.rs +++ b/comms/src/protocol/negotiation.rs @@ -23,9 +23,9 @@ use super::{ProtocolError, ProtocolId}; use bitflags::bitflags; use bytes::{Bytes, BytesMut}; -use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use log::*; use std::convert::TryInto; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; const LOG_TARGET: &str = "comms::connection_manager::protocol"; @@ -204,7 +204,7 @@ mod test { use futures::future; use tari_test_utils::unpack_enum; - #[runtime::test_basic] + #[runtime::test] async fn negotiate_success() { let (mut initiator, mut responder) = MemorySocket::new_pair(); let mut negotiate_out = ProtocolNegotiation::new(&mut initiator); @@ -229,7 +229,7 @@ mod test { assert_eq!(out_proto.unwrap(), ProtocolId::from_static(b"A")); } - #[runtime::test_basic] + #[runtime::test] async fn negotiate_fail() { let (mut initiator, mut responder) = MemorySocket::new_pair(); let mut negotiate_out = ProtocolNegotiation::new(&mut initiator); @@ -254,7 +254,7 @@ mod test { unpack_enum!(ProtocolError::ProtocolOutboundNegotiationFailed(_s) = out_proto.unwrap_err()); } - #[runtime::test_basic] + #[runtime::test] async fn negotiate_fail_max_rounds() { let (mut initiator, mut responder) = MemorySocket::new_pair(); let mut negotiate_out = ProtocolNegotiation::new(&mut initiator); @@ -279,7 +279,7 @@ mod test { unpack_enum!(ProtocolError::ProtocolNegotiationTerminatedByPeer = out_proto.unwrap_err()); } - #[runtime::test_basic] + #[runtime::test] async fn negotiate_success_optimistic() { let (mut initiator, mut responder) = MemorySocket::new_pair(); let mut negotiate_out = ProtocolNegotiation::new(&mut initiator); @@ -300,7 +300,7 @@ mod test { out_proto.unwrap(); } - #[runtime::test_basic] + #[runtime::test] async fn negotiate_fail_optimistic() { let (mut initiator, mut responder) = MemorySocket::new_pair(); let mut negotiate_out = ProtocolNegotiation::new(&mut initiator); diff --git a/comms/src/protocol/protocols.rs b/comms/src/protocol/protocols.rs index 936ef15f34..14d253196e 100644 --- a/comms/src/protocol/protocols.rs +++ b/comms/src/protocol/protocols.rs @@ -32,8 +32,8 @@ use crate::{ }, Substream, }; -use futures::{channel::mpsc, SinkExt}; use std::collections::HashMap; +use tokio::sync::mpsc; pub type ProtocolNotificationTx = mpsc::Sender>; pub type ProtocolNotificationRx = mpsc::Receiver>; @@ -143,7 +143,6 @@ impl ProtocolExtension for Protocols { mod test { use super::*; use crate::runtime; - use futures::StreamExt; use tari_test_utils::unpack_enum; #[test] @@ -160,7 +159,7 @@ mod test { assert!(protocols.get_supported_protocols().iter().all(|p| protos.contains(p))); } - #[runtime::test_basic] + #[runtime::test] async fn notify() { let (tx, mut rx) = mpsc::channel(1); let protos = [ProtocolId::from_static(b"/tari/test/1")]; @@ -172,12 +171,12 @@ mod test { .await .unwrap(); - let 
notification = rx.next().await.unwrap(); + let notification = rx.recv().await.unwrap(); unpack_enum!(ProtocolEvent::NewInboundSubstream(peer_id, _s) = notification.event); assert_eq!(peer_id, NodeId::new()); } - #[runtime::test_basic] + #[runtime::test] async fn notify_fail_not_registered() { let mut protocols = Protocols::<()>::new(); diff --git a/comms/src/protocol/rpc/body.rs b/comms/src/protocol/rpc/body.rs index 6079508729..e563d6483e 100644 --- a/comms/src/protocol/rpc/body.rs +++ b/comms/src/protocol/rpc/body.rs @@ -27,7 +27,6 @@ use crate::{ }; use bytes::BytesMut; use futures::{ - channel::mpsc, ready, stream::BoxStream, task::{Context, Poll}, @@ -37,6 +36,7 @@ use futures::{ use pin_project::pin_project; use prost::bytes::Buf; use std::{fmt, marker::PhantomData, pin::Pin}; +use tokio::sync::mpsc; pub trait IntoBody { fn into_body(self) -> Body; @@ -205,8 +205,8 @@ impl Buf for BodyBytes { self.0.as_ref().map(Buf::remaining).unwrap_or(0) } - fn bytes(&self) -> &[u8] { - self.0.as_ref().map(Buf::bytes).unwrap_or(&[]) + fn chunk(&self) -> &[u8] { + self.0.as_ref().map(Buf::chunk).unwrap_or(&[]) } fn advance(&mut self, cnt: usize) { @@ -227,7 +227,7 @@ impl Streaming { } pub fn empty() -> Self { - let (_, rx) = mpsc::channel(0); + let (_, rx) = mpsc::channel(1); Self { inner: rx } } @@ -240,7 +240,7 @@ impl Stream for Streaming { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match ready!(self.inner.poll_next_unpin(cx)) { + match ready!(Pin::new(&mut self.inner).poll_recv(cx)) { Some(result) => { let result = result.map(|msg| msg.to_encoded_bytes().into()); Poll::Ready(Some(result)) @@ -275,7 +275,7 @@ impl Stream for ClientStreaming { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match ready!(self.inner.poll_next_unpin(cx)) { + match ready!(Pin::new(&mut self.inner).poll_recv(cx)) { Some(Ok(resp)) => { // The streaming protocol dictates that an empty finish flag MUST be sent to indicate a terminated // stream. This empty response need not be emitted to downsteam consumers. 
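
The `Streaming`/`ClientStreaming` changes above swap a futures `mpsc::Receiver` (which implements `Stream` directly) for tokio's `mpsc::Receiver`, which only exposes `poll_recv`. The following is a minimal sketch of that adapter pattern, assuming tokio 1.x and the `futures` `Stream` trait; the `ReceiverStream` name here is illustrative and not part of this patch:

use std::{
    pin::Pin,
    task::{Context, Poll},
};

use futures::Stream;
use tokio::sync::mpsc;

// Illustrative only: adapt a tokio mpsc receiver into a futures `Stream` by
// delegating to `poll_recv`, mirroring the pattern used in the hunks above.
struct ReceiverStream<T> {
    inner: mpsc::Receiver<T>,
}

impl<T> Stream for ReceiverStream<T> {
    type Item = T;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
        // `Receiver::poll_recv` takes `&mut self`, and `Receiver` is `Unpin`,
        // so it is safe to take a mutable reference out of the `Pin`.
        self.get_mut().inner.poll_recv(cx)
    }
}

The `tokio-stream` crate, which this changeset already uses elsewhere (e.g. `BroadcastStream`), ships the same adapter as `tokio_stream::wrappers::ReceiverStream` if hand-rolling the `poll_recv` delegation is undesirable.
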
@@ -298,7 +298,7 @@ mod test { use futures::{stream, StreamExt}; use prost::Message; - #[runtime::test_basic] + #[runtime::test] async fn single_body() { let mut body = Body::single(123u32.to_encoded_bytes()); let bytes = body.next().await.unwrap().unwrap(); @@ -306,7 +306,7 @@ mod test { assert_eq!(u32::decode(bytes).unwrap(), 123u32); } - #[runtime::test_basic] + #[runtime::test] async fn streaming_body() { let body = Body::streaming(stream::repeat(Bytes::new()).map(Ok).take(10)); let body = body.collect::>().await; diff --git a/comms/src/protocol/rpc/client.rs b/comms/src/protocol/rpc/client.rs index 2806366f9f..cfc5c5a9c1 100644 --- a/comms/src/protocol/rpc/client.rs +++ b/comms/src/protocol/rpc/client.rs @@ -41,11 +41,8 @@ use crate::{ }; use bytes::Bytes; use futures::{ - channel::{mpsc, oneshot}, future::{BoxFuture, Either}, task::{Context, Poll}, - AsyncRead, - AsyncWrite, FutureExt, SinkExt, StreamExt, @@ -58,9 +55,15 @@ use std::{ fmt, future::Future, marker::PhantomData, + sync::Arc, time::{Duration, Instant}, }; -use tokio::time; +use tari_shutdown::{Shutdown, ShutdownSignal}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{mpsc, oneshot, Mutex}, + time, +}; use tower::{Service, ServiceExt}; use tracing::{event, span, Instrument, Level}; @@ -82,14 +85,16 @@ impl RpcClient { TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static, { let (request_tx, request_rx) = mpsc::channel(1); - let connector = ClientConnector::new(request_tx); + let shutdown = Shutdown::new(); + let shutdown_signal = shutdown.to_signal(); + let connector = ClientConnector::new(request_tx, shutdown); let (ready_tx, ready_rx) = oneshot::channel(); let tracing_id = tracing::Span::current().id(); task::spawn({ let span = span!(Level::TRACE, "start_rpc_worker"); span.follows_from(tracing_id); - RpcClientWorker::new(config, request_rx, framed, ready_tx, protocol_name) + RpcClientWorker::new(config, request_rx, framed, ready_tx, protocol_name, shutdown_signal) .run() .instrument(span) }); @@ -110,7 +115,7 @@ impl RpcClient { let request = BaseRequest::new(method.into(), req_bytes.into()); let mut resp = self.call_inner(request).await?; - let resp = resp.next().await.ok_or(RpcError::ServerClosedRequest)??; + let resp = resp.recv().await.ok_or(RpcError::ServerClosedRequest)??; let resp = R::decode(resp.into_message())?; Ok(resp) @@ -132,8 +137,8 @@ impl RpcClient { } /// Close the RPC session. Any subsequent calls will error. 
- pub fn close(&mut self) { - self.connector.close() + pub async fn close(&mut self) { + self.connector.close().await; } pub fn is_connected(&self) -> bool { @@ -269,15 +274,20 @@ impl Default for RpcClientConfig { #[derive(Clone)] pub struct ClientConnector { inner: mpsc::Sender, + shutdown: Arc>, } impl ClientConnector { - pub(self) fn new(sender: mpsc::Sender) -> Self { - Self { inner: sender } + pub(self) fn new(sender: mpsc::Sender, shutdown: Shutdown) -> Self { + Self { + inner: sender, + shutdown: Arc::new(Mutex::new(shutdown)), + } } - pub fn close(&mut self) { - self.inner.close_channel(); + pub async fn close(&mut self) { + let mut lock = self.shutdown.lock().await; + lock.trigger(); } pub async fn get_last_request_latency(&mut self) -> Result, RpcError> { @@ -317,13 +327,13 @@ impl Service> for ClientConnector { type Future = BoxFuture<'static, Result>; type Response = mpsc::Receiver, RpcStatus>>; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready_unpin(cx).map_err(|_| RpcError::ClientClosed) + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, request: BaseRequest) -> Self::Future { let (reply, reply_rx) = oneshot::channel(); - let mut inner = self.inner.clone(); + let inner = self.inner.clone(); async move { inner .send(ClientRequest::SendRequest { request, reply }) @@ -346,6 +356,7 @@ pub struct RpcClientWorker { ready_tx: Option>>, last_request_latency: Option, protocol_id: ProtocolId, + shutdown_signal: ShutdownSignal, } impl RpcClientWorker @@ -357,6 +368,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send framed: CanonicalFraming, ready_tx: oneshot::Sender>, protocol_id: ProtocolId, + shutdown_signal: ShutdownSignal, ) -> Self { Self { config, @@ -366,6 +378,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send ready_tx: Some(ready_tx), last_request_latency: None, protocol_id, + shutdown_signal, } } @@ -405,26 +418,26 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send }, } - while let Some(req) = self.request_rx.next().await { - use ClientRequest::*; - match req { - SendRequest { request, reply } => { - if let Err(err) = self.do_request_response(request, reply).await { - error!(target: LOG_TARGET, "Unexpected error: {}. Worker is terminating.", err); - break; - } - }, - GetLastRequestLatency(reply) => { - let _ = reply.send(self.last_request_latency); - }, - SendPing(reply) => { - if let Err(err) = self.do_ping_pong(reply).await { - error!(target: LOG_TARGET, "Unexpected error: {}. Worker is terminating.", err); - break; + loop { + tokio::select! { + biased; + _ = &mut self.shutdown_signal => { + break; + } + req = self.request_rx.recv() => { + match req { + Some(req) => { + if let Err(err) = self.handle_request(req).await { + error!(target: LOG_TARGET, "Unexpected error: {}. 
Worker is terminating.", err); + break; + } + } + None => break, } - }, + } } } + if let Err(err) = self.framed.close().await { debug!(target: LOG_TARGET, "IO Error when closing substream: {}", err); } @@ -436,6 +449,22 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send ); } + async fn handle_request(&mut self, req: ClientRequest) -> Result<(), RpcError> { + use ClientRequest::*; + match req { + SendRequest { request, reply } => { + self.do_request_response(request, reply).await?; + }, + GetLastRequestLatency(reply) => { + let _ = reply.send(self.last_request_latency); + }, + SendPing(reply) => { + self.do_ping_pong(reply).await?; + }, + } + Ok(()) + } + async fn do_ping_pong(&mut self, reply: oneshot::Sender>) -> Result<(), RpcError> { let ack = proto::rpc::RpcRequest { flags: RpcMessageFlags::ACK.bits() as u32, @@ -501,14 +530,30 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send debug!(target: LOG_TARGET, "Sending request: {}", req); let start = Instant::now(); + if reply.is_closed() { + event!(Level::WARN, "Client request was cancelled before request was sent"); + warn!( + target: LOG_TARGET, + "Client request was cancelled before request was sent" + ); + } self.framed.send(req.to_encoded_bytes().into()).await?; - let (mut response_tx, response_rx) = mpsc::channel(10); - if reply.send(response_rx).is_err() { - event!(Level::WARN, "Client request was cancelled"); - warn!(target: LOG_TARGET, "Client request was cancelled."); - response_tx.close_channel(); - // TODO: Should this not exit here? + let (response_tx, response_rx) = mpsc::channel(10); + if let Err(mut rx) = reply.send(response_rx) { + event!(Level::WARN, "Client request was cancelled after request was sent"); + warn!( + target: LOG_TARGET, + "Client request was cancelled after request was sent" + ); + rx.close(); + // RPC is strictly request/response + // If the client drops the RpcClient request at this point after the , we have two options: + // 1. Obey the protocol: receive the response + // 2. Close the RPC session and return an error (seems brittle and unexpected) + // Option 1 has the disadvantage when receiving large/many streamed responses. + // TODO: Detect if all handles to the client handles have been dropped. If so, + // immediately close the RPC session } loop { @@ -537,8 +582,20 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send start.elapsed() ); event!(Level::ERROR, "Response timed out"); - let _ = response_tx.send(Err(RpcStatus::timed_out("Response timed out"))).await; - response_tx.close_channel(); + if !response_tx.is_closed() { + let _ = response_tx.send(Err(RpcStatus::timed_out("Response timed out"))).await; + } + break; + }, + Err(RpcError::ClientClosed) => { + debug!( + target: LOG_TARGET, + "Request {} (method={}) was closed after {:.0?} (read_reply)", + request_id, + method, + start.elapsed() + ); + self.request_rx.close(); break; }, Err(err) => { @@ -564,7 +621,6 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send let _ = response_tx.send(Ok(resp)).await; } if is_finished { - response_tx.close_channel(); break; } }, @@ -573,7 +629,6 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send if !response_tx.is_closed() { let _ = response_tx.send(Err(err)).await; } - response_tx.close_channel(); break; }, Err(err @ RpcError::ResponseIdDidNotMatchRequest { .. }) | @@ -598,7 +653,15 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send None => Either::Right(self.framed.next().map(Ok)), }; - match next_msg_fut.await { + let result = tokio::select! 
{ + biased; + _ = &mut self.shutdown_signal => { + return Err(RpcError::ClientClosed); + } + result = next_msg_fut => result, + }; + + match result { Ok(Some(Ok(resp))) => Ok(proto::rpc::RpcResponse::decode(resp)?), Ok(Some(Err(err))) => Err(err.into()), Ok(None) => Err(RpcError::ServerClosedRequest), diff --git a/comms/src/protocol/rpc/handshake.rs b/comms/src/protocol/rpc/handshake.rs index 3abd62cef6..b39c15e6d7 100644 --- a/comms/src/protocol/rpc/handshake.rs +++ b/comms/src/protocol/rpc/handshake.rs @@ -22,10 +22,13 @@ use crate::{framing::CanonicalFraming, message::MessageExt, proto, protocol::rpc::error::HandshakeRejectReason}; use bytes::BytesMut; -use futures::{AsyncRead, AsyncWrite, SinkExt, StreamExt}; +use futures::{SinkExt, StreamExt}; use prost::{DecodeError, Message}; use std::{io, time::Duration}; -use tokio::time; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + time, +}; use tracing::{debug, error, event, span, warn, Instrument, Level}; const LOG_TARGET: &str = "comms::rpc::handshake"; @@ -168,7 +171,7 @@ where T: AsyncRead + AsyncWrite + Unpin } #[tracing::instrument(name = "rpc::receive_handshake_reply", skip(self), err)] - async fn recv_next_frame(&mut self) -> Result>, time::Elapsed> { + async fn recv_next_frame(&mut self) -> Result>, time::error::Elapsed> { match self.timeout { Some(timeout) => time::timeout(timeout, self.framed.next()).await, None => Ok(self.framed.next().await), diff --git a/comms/src/protocol/rpc/mod.rs b/comms/src/protocol/rpc/mod.rs index d4e91fa8e4..2244979adf 100644 --- a/comms/src/protocol/rpc/mod.rs +++ b/comms/src/protocol/rpc/mod.rs @@ -80,6 +80,7 @@ pub mod __macro_reexports { }, Bytes, }; - pub use futures::{future, future::BoxFuture, AsyncRead, AsyncWrite}; + pub use futures::{future, future::BoxFuture}; + pub use tokio::io::{AsyncRead, AsyncWrite}; pub use tower::Service; } diff --git a/comms/src/protocol/rpc/server/error.rs b/comms/src/protocol/rpc/server/error.rs index 5078c6c588..6972cec60b 100644 --- a/comms/src/protocol/rpc/server/error.rs +++ b/comms/src/protocol/rpc/server/error.rs @@ -21,9 +21,9 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::protocol::rpc::handshake::RpcHandshakeError; -use futures::channel::oneshot; use prost::DecodeError; use std::io; +use tokio::sync::oneshot; #[derive(Debug, thiserror::Error)] pub enum RpcServerError { @@ -41,8 +41,8 @@ pub enum RpcServerError { ProtocolServiceNotFound(String), } -impl From for RpcServerError { - fn from(_: oneshot::Canceled) -> Self { +impl From for RpcServerError { + fn from(_: oneshot::error::RecvError) -> Self { RpcServerError::RequestCanceled } } diff --git a/comms/src/protocol/rpc/server/handle.rs b/comms/src/protocol/rpc/server/handle.rs index 89bf8dd3b9..972d91429d 100644 --- a/comms/src/protocol/rpc/server/handle.rs +++ b/comms/src/protocol/rpc/server/handle.rs @@ -21,10 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
use super::RpcServerError; -use futures::{ - channel::{mpsc, oneshot}, - SinkExt, -}; +use tokio::sync::{mpsc, oneshot}; #[derive(Debug)] pub enum RpcServerRequest { diff --git a/comms/src/protocol/rpc/server/mock.rs b/comms/src/protocol/rpc/server/mock.rs index 19741a0a1a..69659ba03b 100644 --- a/comms/src/protocol/rpc/server/mock.rs +++ b/comms/src/protocol/rpc/server/mock.rs @@ -42,6 +42,7 @@ use crate::{ ProtocolNotificationTx, }, test_utils::mocks::{create_connectivity_mock, create_peer_connection_mock_pair, ConnectivityManagerMockState}, + utils, NodeIdentity, PeerConnection, PeerManager, @@ -49,7 +50,7 @@ use crate::{ }; use async_trait::async_trait; use bytes::Bytes; -use futures::{channel::mpsc, future::BoxFuture, stream, SinkExt}; +use futures::future::BoxFuture; use std::{ collections::HashMap, future, @@ -57,7 +58,7 @@ use std::{ task::{Context, Poll}, }; use tokio::{ - sync::{Mutex, RwLock}, + sync::{mpsc, Mutex, RwLock}, task, }; use tower::Service; @@ -139,9 +140,13 @@ pub trait RpcMock { { method_state.requests.write().await.push(request.into_message()); let resp = method_state.response.read().await.clone()?; - let (mut tx, rx) = mpsc::channel(resp.len()); - let mut resp = stream::iter(resp.into_iter().map(Ok).map(Ok)); - tx.send_all(&mut resp).await.unwrap(); + let (tx, rx) = mpsc::channel(resp.len()); + match utils::mpsc::send_all(&tx, resp.into_iter().map(Ok)).await { + Ok(_) => {}, + // This is done because tokio mpsc channels give the item back to you in the error, and our item doesn't + // impl Debug, so we can't use unwrap, expect etc + Err(_) => panic!("send error"), + } Ok(Streaming::new(rx)) } } @@ -234,7 +239,7 @@ where let peer_node_id = peer.node_id.clone(); let (_, our_conn_mock, peer_conn, _) = create_peer_connection_mock_pair(peer, self.our_node.to_peer()).await; - let mut protocol_tx = self.protocol_tx.clone(); + let protocol_tx = self.protocol_tx.clone(); task::spawn(async move { while let Some(substream) = our_conn_mock.next_incoming_substream().await { let proto_notif = ProtocolNotification::new( diff --git a/comms/src/protocol/rpc/server/mod.rs b/comms/src/protocol/rpc/server/mod.rs index b2b7dbf76f..88fdb7ee61 100644 --- a/comms/src/protocol/rpc/server/mod.rs +++ b/comms/src/protocol/rpc/server/mod.rs @@ -53,14 +53,19 @@ use crate::{ protocol::{ProtocolEvent, ProtocolId, ProtocolNotification, ProtocolNotificationRx}, Bytes, }; -use futures::{channel::mpsc, AsyncRead, AsyncWrite, SinkExt, StreamExt}; +use futures::SinkExt; use prost::Message; use std::{ borrow::Cow, future::Future, time::{Duration, Instant}, }; -use tokio::time; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::mpsc, + time, +}; +use tokio_stream::StreamExt; use tower::Service; use tower_make::MakeService; use tracing::{debug, error, instrument, span, trace, warn, Instrument, Level}; @@ -198,7 +203,7 @@ pub(super) struct PeerRpcServer { service: TSvc, protocol_notifications: Option>, comms_provider: TCommsProvider, - request_rx: Option>, + request_rx: mpsc::Receiver, } impl PeerRpcServer @@ -233,7 +238,7 @@ where service, protocol_notifications: Some(protocol_notifications), comms_provider, - request_rx: Some(request_rx), + request_rx, } } @@ -243,24 +248,19 @@ where .take() .expect("PeerRpcServer initialized without protocol_notifications"); - let mut requests = self - .request_rx - .take() - .expect("PeerRpcServer initialized without request_rx"); - loop { - futures::select! 
{ - maybe_notif = protocol_notifs.next() => { - match maybe_notif { - Some(notif) => self.handle_protocol_notification(notif).await?, - // No more protocol notifications to come, so we're done - None => break, - } - } - - req = requests.select_next_some() => { + tokio::select! { + maybe_notif = protocol_notifs.recv() => { + match maybe_notif { + Some(notif) => self.handle_protocol_notification(notif).await?, + // No more protocol notifications to come, so we're done + None => break, + } + } + + Some(req) = self.request_rx.recv() => { self.handle_request(req).await; - }, + }, } } diff --git a/comms/src/protocol/rpc/server/router.rs b/comms/src/protocol/rpc/server/router.rs index 9d03c6535d..1d40988075 100644 --- a/comms/src/protocol/rpc/server/router.rs +++ b/comms/src/protocol/rpc/server/router.rs @@ -44,14 +44,15 @@ use crate::{ Bytes, }; use futures::{ - channel::mpsc, future::BoxFuture, task::{Context, Poll}, - AsyncRead, - AsyncWrite, FutureExt, }; use std::sync::Arc; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::mpsc, +}; use tower::Service; use tower_make::MakeService; @@ -329,7 +330,7 @@ mod test { } } - #[runtime::test_basic] + #[runtime::test] async fn find_route() { let server = RpcServer::new(); let mut router = Router::new(server, HelloService).add_service(GoodbyeService); diff --git a/comms/src/protocol/rpc/test/client_pool.rs b/comms/src/protocol/rpc/test/client_pool.rs index e1eb957d5f..d95e1d22b6 100644 --- a/comms/src/protocol/rpc/test/client_pool.rs +++ b/comms/src/protocol/rpc/test/client_pool.rs @@ -39,13 +39,13 @@ use crate::{ runtime::task, test_utils::mocks::{new_peer_connection_mock_pair, PeerConnectionMockState}, }; -use futures::{channel::mpsc, SinkExt}; use tari_shutdown::Shutdown; use tari_test_utils::{async_assert_eventually, unpack_enum}; +use tokio::sync::mpsc; async fn setup(num_concurrent_sessions: usize) -> (PeerConnection, PeerConnectionMockState, Shutdown) { let (conn1, conn1_state, conn2, conn2_state) = new_peer_connection_mock_pair().await; - let (mut notif_tx, notif_rx) = mpsc::channel(1); + let (notif_tx, notif_rx) = mpsc::channel(1); let shutdown = Shutdown::new(); let (context, _) = create_mocked_rpc_context(); @@ -148,15 +148,15 @@ mod lazy_pool { async fn it_prunes_disconnected_sessions() { let (conn, mock_state, _shutdown) = setup(2).await; let mut pool = LazyPool::::new(conn, 2, Default::default()); - let mut conn1 = pool.get_least_used_or_connect().await.unwrap(); + let mut client1 = pool.get_least_used_or_connect().await.unwrap(); assert_eq!(mock_state.num_open_substreams(), 1); - let _conn2 = pool.get_least_used_or_connect().await.unwrap(); + let _client2 = pool.get_least_used_or_connect().await.unwrap(); assert_eq!(mock_state.num_open_substreams(), 2); - conn1.close(); - drop(conn1); + client1.close().await; + drop(client1); async_assert_eventually!(mock_state.num_open_substreams(), expect = 1); assert_eq!(pool.refresh_num_active_connections(), 1); - let _conn3 = pool.get_least_used_or_connect().await.unwrap(); + let _client3 = pool.get_least_used_or_connect().await.unwrap(); assert_eq!(pool.refresh_num_active_connections(), 2); assert_eq!(mock_state.num_open_substreams(), 2); } diff --git a/comms/src/protocol/rpc/test/comms_integration.rs b/comms/src/protocol/rpc/test/comms_integration.rs index 9d23088f07..f43f921081 100644 --- a/comms/src/protocol/rpc/test/comms_integration.rs +++ b/comms/src/protocol/rpc/test/comms_integration.rs @@ -37,7 +37,7 @@ use crate::{ use tari_shutdown::Shutdown; use tari_test_utils::unpack_enum; 
-#[runtime::test_basic] +#[runtime::test] async fn run_service() { let node_identity1 = build_node_identity(Default::default()); let rpc_service = MockRpcService::new(); diff --git a/comms/src/protocol/rpc/test/greeting_service.rs b/comms/src/protocol/rpc/test/greeting_service.rs index 0e190473dd..445099e7c3 100644 --- a/comms/src/protocol/rpc/test/greeting_service.rs +++ b/comms/src/protocol/rpc/test/greeting_service.rs @@ -26,12 +26,16 @@ use crate::{ rpc::{NamedProtocolService, Request, Response, RpcError, RpcServerError, RpcStatus, Streaming}, ProtocolId, }, + utils, }; use core::iter; -use futures::{channel::mpsc, stream, SinkExt, StreamExt}; use std::{sync::Arc, time::Duration}; use tari_crypto::tari_utilities::hex::Hex; -use tokio::{sync::RwLock, task, time}; +use tokio::{ + sync::{mpsc, RwLock}, + task, + time, +}; #[async_trait] // #[tari_rpc(protocol_name = "/tari/greeting/1.0", server_struct = GreetingServer, client_struct = GreetingClient)] @@ -91,20 +95,11 @@ impl GreetingRpc for GreetingService { } async fn get_greetings(&self, request: Request) -> Result, RpcStatus> { - let (mut tx, rx) = mpsc::channel(1); + let (tx, rx) = mpsc::channel(1); let num = *request.message(); let greetings = self.greetings[..num as usize].to_vec(); task::spawn(async move { - let iter = greetings.into_iter().map(Ok); - let mut stream = stream::iter(iter) - // "Extra" Result::Ok is to satisfy send_all - .map(Ok); - match tx.send_all(&mut stream).await { - Ok(_) => {}, - Err(_err) => { - // Log error - }, - } + let _ = utils::mpsc::send_all(&tx, greetings.into_iter().map(Ok)).await; }); Ok(Streaming::new(rx)) @@ -118,7 +113,7 @@ impl GreetingRpc for GreetingService { } async fn streaming_error2(&self, _: Request<()>) -> Result, RpcStatus> { - let (mut tx, rx) = mpsc::channel(2); + let (tx, rx) = mpsc::channel(2); tx.send(Ok("This is ok".to_string())).await.unwrap(); tx.send(Err(RpcStatus::bad_request("This is a problem"))).await.unwrap(); @@ -151,7 +146,7 @@ impl SlowGreetingService { impl GreetingRpc for SlowGreetingService { async fn say_hello(&self, _: Request) -> Result, RpcStatus> { let delay = *self.delay.read().await; - time::delay_for(delay).await; + time::sleep(delay).await; Ok(Response::new(SayHelloResponse { greeting: "took a while to load".to_string(), })) @@ -376,8 +371,8 @@ impl GreetingClient { self.inner.ping().await } - pub fn close(&mut self) { - self.inner.close(); + pub async fn close(&mut self) { + self.inner.close().await; } } diff --git a/comms/src/protocol/rpc/test/handshake.rs b/comms/src/protocol/rpc/test/handshake.rs index cdd79746f2..9a21628012 100644 --- a/comms/src/protocol/rpc/test/handshake.rs +++ b/comms/src/protocol/rpc/test/handshake.rs @@ -33,7 +33,7 @@ use crate::{ }; use tari_test_utils::unpack_enum; -#[runtime::test_basic] +#[runtime::test] async fn it_performs_the_handshake() { let (client, server) = MemorySocket::new_pair(); @@ -51,7 +51,7 @@ async fn it_performs_the_handshake() { assert!(SUPPORTED_RPC_VERSIONS.contains(&v)); } -#[runtime::test_basic] +#[runtime::test] async fn it_rejects_the_handshake() { let (client, server) = MemorySocket::new_pair(); diff --git a/comms/src/protocol/rpc/test/smoke.rs b/comms/src/protocol/rpc/test/smoke.rs index a762ac4c9c..3149e794e3 100644 --- a/comms/src/protocol/rpc/test/smoke.rs +++ b/comms/src/protocol/rpc/test/smoke.rs @@ -52,12 +52,15 @@ use crate::{ test_utils::node_identity::build_node_identity, NodeIdentity, }; -use futures::{channel::mpsc, future, future::Either, SinkExt, StreamExt}; +use futures::{future, 
future::Either, StreamExt}; use std::{sync::Arc, time::Duration}; use tari_crypto::tari_utilities::hex::Hex; use tari_shutdown::Shutdown; use tari_test_utils::unpack_enum; -use tokio::{sync::RwLock, task}; +use tokio::{ + sync::{mpsc, RwLock}, + task, +}; pub(super) async fn setup_service( service_impl: T, @@ -85,7 +88,7 @@ pub(super) async fn setup_service( futures::pin_mut!(fut); match future::select(shutdown_signal, fut).await { - Either::Left((r, _)) => r.unwrap(), + Either::Left(_) => {}, Either::Right((r, _)) => r.unwrap(), } } @@ -97,7 +100,7 @@ pub(super) async fn setup( service_impl: T, num_concurrent_sessions: usize, ) -> (MemorySocket, task::JoinHandle<()>, Arc, Shutdown) { - let (mut notif_tx, server_hnd, context, shutdown) = setup_service(service_impl, num_concurrent_sessions).await; + let (notif_tx, server_hnd, context, shutdown) = setup_service(service_impl, num_concurrent_sessions).await; let (inbound, outbound) = MemorySocket::new_pair(); let node_identity = build_node_identity(Default::default()); @@ -114,7 +117,7 @@ pub(super) async fn setup( (outbound, server_hnd, node_identity, shutdown) } -#[runtime::test_basic] +#[runtime::test] async fn request_response_errors_and_streaming() { let (socket, server_hnd, node_identity, mut shutdown) = setup(GreetingService::default(), 1).await; @@ -171,7 +174,7 @@ async fn request_response_errors_and_streaming() { let pk_hex = client.get_public_key_hex().await.unwrap(); assert_eq!(pk_hex, node_identity.public_key().to_hex()); - client.close(); + client.close().await; let err = client .say_hello(SayHelloRequest { @@ -181,13 +184,20 @@ async fn request_response_errors_and_streaming() { .await .unwrap_err(); - unpack_enum!(RpcError::ClientClosed = err); + match err { + // Because of the race between closing the request stream and sending on that stream in the above call + // We can either get "this client was closed" or "the request you made was cancelled". 
+ // If we delay some small time, we'll always get the former (but arbitrary delays cause flakiness and should be + // avoided) + RpcError::ClientClosed | RpcError::RequestCancelled => {}, + err => panic!("Unexpected error {:?}", err), + } - shutdown.trigger().unwrap(); + shutdown.trigger(); server_hnd.await.unwrap(); } -#[runtime::test_basic] +#[runtime::test] async fn concurrent_requests() { let (socket, _, _, _shutdown) = setup(GreetingService::default(), 1).await; @@ -227,7 +237,7 @@ async fn concurrent_requests() { assert_eq!(spawned2.await.unwrap(), GreetingService::DEFAULT_GREETINGS[..5]); } -#[runtime::test_basic] +#[runtime::test] async fn response_too_big() { let (socket, _, _, _shutdown) = setup(GreetingService::new(&[]), 1).await; @@ -248,7 +258,7 @@ async fn response_too_big() { let _ = client.reply_with_msg_of_size(max_size as u64).await.unwrap(); } -#[runtime::test_basic] +#[runtime::test] async fn ping_latency() { let (socket, _, _, _shutdown) = setup(GreetingService::new(&[]), 1).await; @@ -261,11 +271,11 @@ async fn ping_latency() { assert!(latency.as_secs() < 5); } -#[runtime::test_basic] +#[runtime::test] async fn server_shutdown_before_connect() { let (socket, _, _, mut shutdown) = setup(GreetingService::new(&[]), 1).await; let framed = framing::canonical(socket, 1024); - shutdown.trigger().unwrap(); + shutdown.trigger(); let err = GreetingClient::connect(framed).await.unwrap_err(); assert!(matches!( @@ -274,7 +284,7 @@ async fn server_shutdown_before_connect() { )); } -#[runtime::test_basic] +#[runtime::test] async fn timeout() { let delay = Arc::new(RwLock::new(Duration::from_secs(10))); let (socket, _, _, _shutdown) = setup(SlowGreetingService::new(delay.clone()), 1).await; @@ -298,9 +308,9 @@ async fn timeout() { assert_eq!(resp.greeting, "took a while to load"); } -#[runtime::test_basic] +#[runtime::test] async fn unknown_protocol() { - let (mut notif_tx, _, _, _shutdown) = setup_service(GreetingService::new(&[]), 1).await; + let (notif_tx, _, _, _shutdown) = setup_service(GreetingService::new(&[]), 1).await; let (inbound, socket) = MemorySocket::new_pair(); let node_identity = build_node_identity(Default::default()); @@ -324,7 +334,7 @@ async fn unknown_protocol() { )); } -#[runtime::test_basic] +#[runtime::test] async fn rejected_no_sessions_available() { let (socket, _, _, _shutdown) = setup(GreetingService::new(&[]), 0).await; let framed = framing::canonical(socket, 1024); diff --git a/comms/src/common/rate_limit.rs b/comms/src/rate_limit.rs similarity index 79% rename from comms/src/common/rate_limit.rs rename to comms/src/rate_limit.rs index 1705397d8e..3a06c70040 100644 --- a/comms/src/common/rate_limit.rs +++ b/comms/src/rate_limit.rs @@ -26,7 +26,7 @@ // This is slightly changed from the libra rate limiter implementation -use futures::{stream::Fuse, FutureExt, Stream, StreamExt}; +use futures::FutureExt; use pin_project::pin_project; use std::{ future::Future, @@ -36,10 +36,11 @@ use std::{ time::Duration, }; use tokio::{ - sync::{OwnedSemaphorePermit, Semaphore}, + sync::{AcquireError, OwnedSemaphorePermit, Semaphore}, time, time::Interval, }; +use tokio_stream::Stream; pub trait RateLimit: Stream { /// Consumes the stream and returns a rate-limited stream that only polls the underlying stream @@ -60,12 +61,13 @@ pub struct RateLimiter { stream: T, /// An interval stream that "restocks" the permits #[pin] - interval: Fuse, + interval: Interval, /// The maximum permits to issue capacity: usize, /// A semaphore that holds the permits permits: Arc, - 
permit_future: Option + Send>>>, + #[allow(clippy::type_complexity)] + permit_future: Option> + Send>>>, permit_acquired: bool, } @@ -75,7 +77,7 @@ impl RateLimiter { stream, capacity, - interval: time::interval(restock_interval).fuse(), + interval: time::interval(restock_interval), // `interval` starts immediately, so we can start with zero permits permits: Arc::new(Semaphore::new(0)), permit_future: None, @@ -89,7 +91,7 @@ impl Stream for RateLimiter { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { // "Restock" permits once interval is ready - if let Poll::Ready(Some(_)) = self.as_mut().project().interval.poll_next(cx) { + if self.as_mut().project().interval.poll_tick(cx).is_ready() { self.permits .add_permits(self.capacity - self.permits.available_permits()); } @@ -103,6 +105,8 @@ impl Stream for RateLimiter { } // Wait until a permit is acquired + // `unwrap()` is safe because acquire_owned only panics if the semaphore has closed, but we never close it + // for the lifetime of this instance let permit = futures::ready!(self .as_mut() .project() @@ -110,7 +114,8 @@ impl Stream for RateLimiter { .as_mut() .unwrap() .as_mut() - .poll(cx)); + .poll(cx)) + .unwrap(); // Don't release the permit on drop, `interval` will restock permits permit.forget(); let this = self.as_mut().project(); @@ -130,45 +135,54 @@ impl Stream for RateLimiter { mod test { use super::*; use crate::runtime; - use futures::{future::Either, stream}; + use futures::{stream, StreamExt}; - #[runtime::test_basic] + #[runtime::test] async fn rate_limit() { let repeater = stream::repeat(()); - let mut rate_limited = repeater.rate_limit(10, Duration::from_secs(100)).fuse(); + let mut rate_limited = repeater.rate_limit(10, Duration::from_secs(100)); - let mut timeout = time::delay_for(Duration::from_millis(50)).fuse(); + let timeout = time::sleep(Duration::from_millis(50)); + tokio::pin!(timeout); let mut count = 0usize; loop { - let either = futures::future::select(rate_limited.select_next_some(), timeout).await; - match either { - Either::Left((_, to)) => { + let item = tokio::select! { + biased; + _ = &mut timeout => None, + item = rate_limited.next() => item, + }; + + match item { + Some(_) => { count += 1; - timeout = to; }, - Either::Right(_) => break, + None => break, } } assert_eq!(count, 10); } - #[runtime::test_basic] + #[runtime::test] async fn rate_limit_restock() { let repeater = stream::repeat(()); - let mut rate_limited = repeater.rate_limit(10, Duration::from_millis(10)).fuse(); + let mut rate_limited = repeater.rate_limit(10, Duration::from_millis(10)); - let mut timeout = time::delay_for(Duration::from_millis(50)).fuse(); + let timeout = time::sleep(Duration::from_millis(50)); + tokio::pin!(timeout); let mut count = 0usize; loop { - let either = futures::future::select(rate_limited.select_next_some(), timeout).await; - match either { - Either::Left((_, to)) => { + let item = tokio::select! { + biased; + _ = &mut timeout => None, + item = rate_limited.next() => item, + }; + match item { + Some(_) => { count += 1; - timeout = to; }, - Either::Right(_) => break, + None => break, } } // Test that at least 1 restock happens. 
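
The `rate_limit.rs` changes above keep the semaphore-plus-interval design but move it onto tokio's `Interval::poll_tick` and the fallible `Semaphore::acquire_owned`. Below is a short usage sketch of the `RateLimit` trait as exercised by the tests above; the `tari_comms::rate_limit` import path is an assumption based on the file's new location (comms/src/rate_limit.rs) and may need adjusting to the actual re-export:

use std::time::Duration;

use futures::{stream, StreamExt};
// Assumed path: the module now lives at comms/src/rate_limit.rs.
use tari_comms::rate_limit::RateLimit;

#[tokio::main]
async fn main() {
    // Allow at most 10 items per 100 ms window: each interval tick tops the
    // semaphore back up to `capacity`, so no more than 10 permits are ever
    // available at once.
    let mut limited = stream::repeat(()).rate_limit(10, Duration::from_millis(100));

    // The first 10 items are yielded immediately (the interval ticks once on
    // start-up); the 11th waits for the next restock.
    for _ in 0..11 {
        limited.next().await;
    }
}

Because the restock sets the available permits back to `capacity` rather than accumulating them, a quiet period never builds up a burst larger than `capacity` items.
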
diff --git a/comms/src/runtime.rs b/comms/src/runtime.rs index a6f7e615f7..48752c05d0 100644 --- a/comms/src/runtime.rs +++ b/comms/src/runtime.rs @@ -25,10 +25,7 @@ use tokio::runtime; // Re-export pub use tokio::{runtime::Handle, task}; -#[cfg(test)] -pub use tokio_macros::test; -#[cfg(test)] -pub use tokio_macros::test_basic; +pub use tokio::test; /// Return the current tokio executor. Panics if the tokio runtime is not started. #[inline] diff --git a/comms/src/socks/client.rs b/comms/src/socks/client.rs index f155f2f420..253a0c0aec 100644 --- a/comms/src/socks/client.rs +++ b/comms/src/socks/client.rs @@ -23,7 +23,6 @@ // Acknowledgement to @sticnarf for tokio-socks on which this code is based use super::error::SocksError; use data_encoding::BASE32; -use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use multiaddr::{Multiaddr, Protocol}; use std::{ borrow::Cow, @@ -31,6 +30,7 @@ use std::{ fmt::Formatter, net::{Ipv4Addr, Ipv6Addr}, }; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; pub type Result = std::result::Result; diff --git a/comms/src/test_utils/mocks/connection_manager.rs b/comms/src/test_utils/mocks/connection_manager.rs index cc489af60e..ece7224f44 100644 --- a/comms/src/test_utils/mocks/connection_manager.rs +++ b/comms/src/test_utils/mocks/connection_manager.rs @@ -31,7 +31,6 @@ use crate::{ peer_manager::NodeId, runtime::task, }; -use futures::{channel::mpsc, lock::Mutex, stream::Fuse, StreamExt}; use std::{ collections::HashMap, sync::{ @@ -39,14 +38,14 @@ use std::{ Arc, }, }; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc, Mutex}; pub fn create_connection_manager_mock() -> (ConnectionManagerRequester, ConnectionManagerMock) { let (tx, rx) = mpsc::channel(10); let (event_tx, _) = broadcast::channel(10); ( ConnectionManagerRequester::new(tx, event_tx.clone()), - ConnectionManagerMock::new(rx.fuse(), event_tx), + ConnectionManagerMock::new(rx, event_tx), ) } @@ -97,13 +96,13 @@ impl ConnectionManagerMockState { } pub struct ConnectionManagerMock { - receiver: Fuse>, + receiver: mpsc::Receiver, state: ConnectionManagerMockState, } impl ConnectionManagerMock { pub fn new( - receiver: Fuse>, + receiver: mpsc::Receiver, event_tx: broadcast::Sender>, ) -> Self { Self { @@ -121,7 +120,7 @@ impl ConnectionManagerMock { } pub async fn run(mut self) { - while let Some(req) = self.receiver.next().await { + while let Some(req) = self.receiver.recv().await { self.handle_request(req).await; } } diff --git a/comms/src/test_utils/mocks/connectivity_manager.rs b/comms/src/test_utils/mocks/connectivity_manager.rs index 122a60127b..78394f428b 100644 --- a/comms/src/test_utils/mocks/connectivity_manager.rs +++ b/comms/src/test_utils/mocks/connectivity_manager.rs @@ -22,32 +22,36 @@ use crate::{ connection_manager::{ConnectionManagerError, PeerConnection}, - connectivity::{ConnectivityEvent, ConnectivityRequest, ConnectivityRequester, ConnectivityStatus}, + connectivity::{ + ConnectivityEvent, + ConnectivityEventTx, + ConnectivityRequest, + ConnectivityRequester, + ConnectivityStatus, + }, peer_manager::NodeId, runtime::task, }; -use futures::{ - channel::{mpsc, oneshot}, - lock::Mutex, - stream::Fuse, - StreamExt, -}; +use futures::lock::Mutex; use std::{collections::HashMap, sync::Arc, time::Duration}; -use tokio::{sync::broadcast, time}; +use tokio::{ + sync::{broadcast, mpsc, oneshot}, + time, +}; pub fn create_connectivity_mock() -> (ConnectivityRequester, ConnectivityManagerMock) { let (tx, rx) = mpsc::channel(10); let (event_tx, 
_) = broadcast::channel(10); ( ConnectivityRequester::new(tx, event_tx.clone()), - ConnectivityManagerMock::new(rx.fuse(), event_tx), + ConnectivityManagerMock::new(rx, event_tx), ) } #[derive(Debug, Clone)] pub struct ConnectivityManagerMockState { inner: Arc>, - event_tx: broadcast::Sender>, + event_tx: ConnectivityEventTx, } #[derive(Debug, Default)] @@ -61,7 +65,7 @@ struct State { } impl ConnectivityManagerMockState { - pub fn new(event_tx: broadcast::Sender>) -> Self { + pub fn new(event_tx: ConnectivityEventTx) -> Self { Self { event_tx, inner: Default::default(), @@ -132,7 +136,7 @@ impl ConnectivityManagerMockState { count, self.call_count().await ); - time::delay_for(Duration::from_millis(100)).await; + time::sleep(Duration::from_millis(100)).await; } } @@ -156,9 +160,8 @@ impl ConnectivityManagerMockState { .await } - #[allow(dead_code)] pub fn publish_event(&self, event: ConnectivityEvent) { - self.event_tx.send(Arc::new(event)).unwrap(); + self.event_tx.send(event).unwrap(); } pub(self) async fn with_state(&self, f: F) -> R @@ -169,15 +172,12 @@ impl ConnectivityManagerMockState { } pub struct ConnectivityManagerMock { - receiver: Fuse>, + receiver: mpsc::Receiver, state: ConnectivityManagerMockState, } impl ConnectivityManagerMock { - pub fn new( - receiver: Fuse>, - event_tx: broadcast::Sender>, - ) -> Self { + pub fn new(receiver: mpsc::Receiver, event_tx: ConnectivityEventTx) -> Self { Self { receiver, state: ConnectivityManagerMockState::new(event_tx), @@ -195,7 +195,7 @@ impl ConnectivityManagerMock { } pub async fn run(mut self) { - while let Some(req) = self.receiver.next().await { + while let Some(req) = self.receiver.recv().await { self.handle_request(req).await; } } diff --git a/comms/src/test_utils/mocks/peer_connection.rs b/comms/src/test_utils/mocks/peer_connection.rs index 13b0c77bbf..bccdec9575 100644 --- a/comms/src/test_utils/mocks/peer_connection.rs +++ b/comms/src/test_utils/mocks/peer_connection.rs @@ -34,15 +34,18 @@ use crate::{ peer_manager::{NodeId, Peer, PeerFeatures}, test_utils::{node_identity::build_node_identity, transport}, }; -use futures::{channel::mpsc, lock::Mutex, StreamExt}; use std::sync::{ atomic::{AtomicUsize, Ordering}, Arc, }; -use tokio::runtime::Handle; +use tokio::{ + runtime::Handle, + sync::{mpsc, Mutex}, +}; +use tokio_stream::StreamExt; pub fn create_dummy_peer_connection(node_id: NodeId) -> (PeerConnection, mpsc::Receiver) { - let (tx, rx) = mpsc::channel(0); + let (tx, rx) = mpsc::channel(1); ( PeerConnection::new( 1, @@ -114,7 +117,7 @@ pub async fn new_peer_connection_mock_pair() -> ( create_peer_connection_mock_pair(peer1, peer2).await } -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct PeerConnectionMockState { call_count: Arc, mux_control: Arc>, @@ -181,7 +184,7 @@ impl PeerConnectionMock { } pub async fn run(mut self) { - while let Some(req) = self.receiver.next().await { + while let Some(req) = self.receiver.recv().await { self.handle_request(req).await; } } diff --git a/comms/src/test_utils/test_node.rs b/comms/src/test_utils/test_node.rs index a150a765f8..3e5d7229a0 100644 --- a/comms/src/test_utils/test_node.rs +++ b/comms/src/test_utils/test_node.rs @@ -29,12 +29,14 @@ use crate::{ protocol::Protocols, transports::Transport, }; -use futures::{channel::mpsc, AsyncRead, AsyncWrite}; use rand::rngs::OsRng; use std::{sync::Arc, time::Duration}; use tari_shutdown::ShutdownSignal; use tari_storage::HashmapDatabase; -use tokio::sync::broadcast; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{broadcast, mpsc}, +}; 
#[derive(Clone, Debug)] pub struct TestNodeConfig { diff --git a/comms/src/tor/control_client/client.rs b/comms/src/tor/control_client/client.rs index 5d14c770b7..573bce60c0 100644 --- a/comms/src/tor/control_client/client.rs +++ b/comms/src/tor/control_client/client.rs @@ -34,10 +34,13 @@ use crate::{ tor::control_client::{event::TorControlEvent, monitor::spawn_monitor}, transports::{TcpTransport, Transport}, }; -use futures::{channel::mpsc, AsyncRead, AsyncWrite, SinkExt, StreamExt}; use log::*; use std::{borrow::Cow, fmt, fmt::Display, num::NonZeroU16}; -use tokio::sync::broadcast; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{broadcast, mpsc}, +}; +use tokio_stream::wrappers::BroadcastStream; /// Client for the Tor control port. /// @@ -80,8 +83,8 @@ impl TorControlPortClient { &self.event_tx } - pub fn get_event_stream(&self) -> broadcast::Receiver { - self.event_tx.subscribe() + pub fn get_event_stream(&self) -> BroadcastStream { + BroadcastStream::new(self.event_tx.subscribe()) } /// Authenticate with the tor control port @@ -232,8 +235,7 @@ impl TorControlPortClient { } async fn receive_line(&mut self) -> Result { - let line = self.output_stream.next().await.ok_or(TorClientError::UnexpectedEof)?; - + let line = self.output_stream.recv().await.ok_or(TorClientError::UnexpectedEof)?; Ok(line) } } @@ -273,9 +275,11 @@ mod test { runtime, tor::control_client::{test_server, test_server::canned_responses, types::PrivateKey}, }; - use futures::{future, AsyncWriteExt}; + use futures::future; use std::net::SocketAddr; use tari_test_utils::unpack_enum; + use tokio::io::AsyncWriteExt; + use tokio_stream::StreamExt; async fn setup_test() -> (TorControlPortClient, test_server::State) { let (_, mock_state, socket) = test_server::spawn().await; @@ -298,7 +302,7 @@ mod test { let _out_sock = result_out.unwrap(); let (mut in_sock, _) = result_in.unwrap().unwrap(); in_sock.write(b"test123").await.unwrap(); - in_sock.close().await.unwrap(); + in_sock.shutdown().await.unwrap(); } #[runtime::test] diff --git a/comms/src/tor/control_client/monitor.rs b/comms/src/tor/control_client/monitor.rs index a5191b466d..72eb6b88ef 100644 --- a/comms/src/tor/control_client/monitor.rs +++ b/comms/src/tor/control_client/monitor.rs @@ -21,11 +21,14 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use super::{event::TorControlEvent, parsers, response::ResponseLine, LOG_TARGET}; -use crate::{compat::IoCompat, runtime::task}; -use futures::{channel::mpsc, future, future::Either, AsyncRead, AsyncWrite, SinkExt, Stream, StreamExt}; +use crate::runtime::task; +use futures::{future::Either, SinkExt, Stream, StreamExt}; use log::*; use std::fmt; -use tokio::sync::broadcast; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{broadcast, mpsc}, +}; use tokio_util::codec::{Framed, LinesCodec}; pub fn spawn_monitor( @@ -36,16 +39,19 @@ pub fn spawn_monitor( where TSocket: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - let (mut responses_tx, responses_rx) = mpsc::channel(100); + let (responses_tx, responses_rx) = mpsc::channel(100); task::spawn(async move { - let framed = Framed::new(IoCompat::new(socket), LinesCodec::new()); + let framed = Framed::new(socket, LinesCodec::new()); let (mut sink, mut stream) = framed.split(); loop { - let either = future::select(cmd_rx.next(), stream.next()).await; + let either = tokio::select! 
{ + next = cmd_rx.recv() => Either::Left(next), + next = stream.next() => Either::Right(next), + }; match either { // Received a command to send to the control server - Either::Left((Some(line), _)) => { + Either::Left(Some(line)) => { trace!(target: LOG_TARGET, "Writing command of length '{}'", line.len()); if let Err(err) = sink.send(line).await { error!( @@ -56,7 +62,7 @@ where } }, // Command stream ended - Either::Left((None, _)) => { + Either::Left(None) => { debug!( target: LOG_TARGET, "Tor control server command receiver closed. Monitor is exiting." @@ -65,7 +71,7 @@ where }, // Received a line from the control server - Either::Right((Some(Ok(line)), _)) => { + Either::Right(Some(Ok(line))) => { trace!(target: LOG_TARGET, "Read line of length '{}'", line.len()); match parsers::response_line(&line) { Ok(mut line) => { @@ -95,7 +101,7 @@ where }, // Error receiving a line from the control server - Either::Right((Some(Err(err)), _)) => { + Either::Right(Some(Err(err))) => { error!( target: LOG_TARGET, "Line framing error when reading from tor control server: '{:?}'. Monitor is exiting.", err @@ -103,7 +109,7 @@ where break; }, // The control server disconnected - Either::Right((None, _)) => { + Either::Right(None) => { cmd_rx.close(); debug!( target: LOG_TARGET, diff --git a/comms/src/tor/control_client/test_server.rs b/comms/src/tor/control_client/test_server.rs index 5a5e1b3b7c..1741cfc721 100644 --- a/comms/src/tor/control_client/test_server.rs +++ b/comms/src/tor/control_client/test_server.rs @@ -20,13 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{ - compat::IoCompat, - memsocket::MemorySocket, - multiaddr::Multiaddr, - runtime, - test_utils::transport::build_connected_sockets, -}; +use crate::{memsocket::MemorySocket, multiaddr::Multiaddr, runtime, test_utils::transport::build_connected_sockets}; use futures::{lock::Mutex, stream, SinkExt, StreamExt}; use std::sync::Arc; use tokio_util::codec::{Framed, LinesCodec}; @@ -82,7 +76,7 @@ impl TorControlPortTestServer { } pub async fn run(self) { - let mut framed = Framed::new(IoCompat::new(self.socket), LinesCodec::new()); + let mut framed = Framed::new(self.socket, LinesCodec::new()); let state = self.state; while let Some(msg) = framed.next().await { state.request_lines.lock().await.push(msg.unwrap()); diff --git a/comms/src/tor/hidden_service/controller.rs b/comms/src/tor/hidden_service/controller.rs index e19818df58..74b89808fb 100644 --- a/comms/src/tor/hidden_service/controller.rs +++ b/comms/src/tor/hidden_service/controller.rs @@ -214,7 +214,7 @@ impl HiddenServiceController { "Failed to reestablish connection with tor control server because '{:?}'", err ); warn!(target: LOG_TARGET, "Will attempt again in 5 seconds..."); - time::delay_for(Duration::from_secs(5)).await; + time::sleep(Duration::from_secs(5)).await; }, Either::Right(_) => { diff --git a/comms/src/transports/memory.rs b/comms/src/transports/memory.rs index 074e728274..8e3a1d7a91 100644 --- a/comms/src/transports/memory.rs +++ b/comms/src/transports/memory.rs @@ -129,7 +129,8 @@ impl Stream for Listener { mod test { use super::*; use crate::runtime; - use futures::{future::join, stream::StreamExt, AsyncReadExt, AsyncWriteExt}; + use futures::{future::join, stream::StreamExt}; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; #[runtime::test] async fn simple_listen_and_dial() -> Result<(), 
::std::io::Error> { diff --git a/comms/src/transports/mod.rs b/comms/src/transports/mod.rs index 0c04e50a72..75b16db975 100644 --- a/comms/src/transports/mod.rs +++ b/comms/src/transports/mod.rs @@ -24,8 +24,8 @@ // Copyright (c) The Libra Core Contributors // SPDX-License-Identifier: Apache-2.0 -use futures::Stream; use multiaddr::Multiaddr; +use tokio_stream::Stream; mod dns; mod helpers; @@ -37,7 +37,7 @@ mod socks; pub use socks::{SocksConfig, SocksTransport}; mod tcp; -pub use tcp::{TcpSocket, TcpTransport}; +pub use tcp::TcpTransport; mod tcp_with_tor; pub use tcp_with_tor::TcpWithTorTransport; diff --git a/comms/src/transports/socks.rs b/comms/src/transports/socks.rs index 2027f34561..7dc87ef0e4 100644 --- a/comms/src/transports/socks.rs +++ b/comms/src/transports/socks.rs @@ -24,12 +24,13 @@ use crate::{ multiaddr::Multiaddr, socks, socks::Socks5Client, - transports::{dns::SystemDnsResolver, tcp::TcpTransport, TcpSocket, Transport}, + transports::{dns::SystemDnsResolver, tcp::TcpTransport, Transport}, }; -use std::{io, time::Duration}; +use std::io; +use tokio::net::TcpStream; -/// SO_KEEPALIVE setting for the SOCKS TCP connection -const SOCKS_SO_KEEPALIVE: Duration = Duration::from_millis(1500); +// /// SO_KEEPALIVE setting for the SOCKS TCP connection +// const SOCKS_SO_KEEPALIVE: Duration = Duration::from_millis(1500); #[derive(Clone, Debug)] pub struct SocksConfig { @@ -57,7 +58,7 @@ impl SocksTransport { pub fn create_socks_tcp_transport() -> TcpTransport { let mut tcp_transport = TcpTransport::new(); tcp_transport.set_nodelay(true); - tcp_transport.set_keepalive(Some(SOCKS_SO_KEEPALIVE)); + // .set_keepalive(Some(SOCKS_SO_KEEPALIVE)) tcp_transport.set_dns_resolver(SystemDnsResolver); tcp_transport } @@ -66,7 +67,7 @@ impl SocksTransport { tcp: TcpTransport, socks_config: SocksConfig, dest_addr: Multiaddr, - ) -> io::Result { + ) -> io::Result { // Create a new connection to the SOCKS proxy let socks_conn = tcp.dial(socks_config.proxy_address).await?; let mut client = Socks5Client::new(socks_conn); diff --git a/comms/src/transports/tcp.rs b/comms/src/transports/tcp.rs index 8112cca203..6b47e7c357 100644 --- a/comms/src/transports/tcp.rs +++ b/comms/src/transports/tcp.rs @@ -25,44 +25,42 @@ use crate::{ transports::dns::{DnsResolverRef, SystemDnsResolver}, utils::multiaddr::socketaddr_to_multiaddr, }; -use futures::{io::Error, ready, AsyncRead, AsyncWrite, Future, FutureExt, Stream}; +use futures::{ready, FutureExt}; use multiaddr::Multiaddr; use std::{ + future::Future, io, pin::Pin, sync::Arc, task::{Context, Poll}, - time::Duration, -}; -use tokio::{ - io::{AsyncRead as TokioAsyncRead, AsyncWrite as TokioAsyncWrite}, - net::{TcpListener, TcpStream}, }; +use tokio::net::{TcpListener, TcpStream}; +use tokio_stream::Stream; /// Transport implementation for TCP #[derive(Clone)] pub struct TcpTransport { - recv_buffer_size: Option, - send_buffer_size: Option, + // recv_buffer_size: Option, + // send_buffer_size: Option, ttl: Option, - #[allow(clippy::option_option)] - keepalive: Option>, + // #[allow(clippy::option_option)] + // keepalive: Option>, nodelay: Option, dns_resolver: DnsResolverRef, } impl TcpTransport { - #[doc("Sets `SO_RCVBUF` i.e the size of the receive buffer.")] - setter_mut!(set_recv_buffer_size, recv_buffer_size, Option); - - #[doc("Sets `SO_SNDBUF` i.e. 
the size of the send buffer.")] - setter_mut!(set_send_buffer_size, send_buffer_size, Option); + // #[doc("Sets `SO_RCVBUF` i.e the size of the receive buffer.")] + // setter_mut!(set_recv_buffer_size, recv_buffer_size, Option); + // + // #[doc("Sets `SO_SNDBUF` i.e. the size of the send buffer.")] + // setter_mut!(set_send_buffer_size, send_buffer_size, Option); #[doc("Sets `IP_TTL` i.e. the TTL of packets sent from this socket.")] setter_mut!(set_ttl, ttl, Option); - #[doc("Sets `SO_KEEPALIVE` i.e. the interval to send keepalive probes, or None to disable.")] - setter_mut!(set_keepalive, keepalive, Option>); + // #[doc("Sets `SO_KEEPALIVE` i.e. the interval to send keepalive probes, or None to disable.")] + // setter_mut!(set_keepalive, keepalive, Option>); #[doc("Sets `TCP_NODELAY` i.e disable Nagle's algorithm if set to true.")] setter_mut!(set_nodelay, nodelay, Option); @@ -81,9 +79,10 @@ impl TcpTransport { /// Apply socket options to `TcpStream`. fn configure(&self, socket: &TcpStream) -> io::Result<()> { - if let Some(keepalive) = self.keepalive { - socket.set_keepalive(keepalive)?; - } + // https://github.com/rust-lang/rust/issues/69774 + // if let Some(keepalive) = self.keepalive { + // socket.set_keepalive(keepalive)?; + // } if let Some(ttl) = self.ttl { socket.set_ttl(ttl)?; @@ -93,13 +92,13 @@ impl TcpTransport { socket.set_nodelay(nodelay)?; } - if let Some(recv_buffer_size) = self.recv_buffer_size { - socket.set_recv_buffer_size(recv_buffer_size)?; - } - - if let Some(send_buffer_size) = self.send_buffer_size { - socket.set_send_buffer_size(send_buffer_size)?; - } + // if let Some(recv_buffer_size) = self.recv_buffer_size { + // socket.set_recv_buffer_size(recv_buffer_size)?; + // } + // + // if let Some(send_buffer_size) = self.send_buffer_size { + // socket.set_send_buffer_size(send_buffer_size)?; + // } Ok(()) } @@ -108,10 +107,10 @@ impl TcpTransport { impl Default for TcpTransport { fn default() -> Self { Self { - recv_buffer_size: None, - send_buffer_size: None, + // recv_buffer_size: None, + // send_buffer_size: None, ttl: None, - keepalive: None, + // keepalive: None, nodelay: None, dns_resolver: Arc::new(SystemDnsResolver), } @@ -122,7 +121,7 @@ impl Default for TcpTransport { impl Transport for TcpTransport { type Error = io::Error; type Listener = TcpInbound; - type Output = TcpSocket; + type Output = TcpStream; async fn listen(&self, addr: Multiaddr) -> Result<(Self::Listener, Multiaddr), Self::Error> { let socket_addr = self @@ -161,12 +160,12 @@ impl TcpOutbound { impl Future for TcpOutbound where F: Future> + Unpin { - type Output = io::Result; + type Output = io::Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let socket = ready!(Pin::new(&mut self.future).poll(cx))?; - self.config.configure(&socket)?; - Poll::Ready(Ok(TcpSocket::new(socket))) + let stream = ready!(Pin::new(&mut self.future).poll(cx))?; + self.config.configure(&stream)?; + Poll::Ready(Ok(stream)) } } @@ -184,52 +183,14 @@ impl TcpInbound { } impl Stream for TcpInbound { - type Item = io::Result<(TcpSocket, Multiaddr)>; + type Item = io::Result<(TcpStream, Multiaddr)>; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let (socket, addr) = ready!(self.listener.poll_accept(cx))?; // Configure each socket self.config.configure(&socket)?; let peer_addr = socketaddr_to_multiaddr(&addr); - Poll::Ready(Some(Ok((TcpSocket::new(socket), peer_addr)))) - } -} - -/// TcpSocket is a 
wrapper struct for tokio `TcpStream` and implements -/// `futures-rs` AsyncRead/Write -pub struct TcpSocket { - inner: TcpStream, -} - -impl TcpSocket { - pub fn new(stream: TcpStream) -> Self { - Self { inner: stream } - } -} - -impl AsyncWrite for TcpSocket { - fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { - Pin::new(&mut self.inner).poll_write(cx, buf) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_flush(cx) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_shutdown(cx) - } -} - -impl AsyncRead for TcpSocket { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { - Pin::new(&mut self.inner).poll_read(cx, buf) - } -} - -impl From for TcpSocket { - fn from(stream: TcpStream) -> Self { - Self { inner: stream } + Poll::Ready(Some(Ok((socket, peer_addr)))) } } @@ -240,16 +201,15 @@ mod test { #[test] fn configure() { let mut tcp = TcpTransport::new(); - tcp.set_send_buffer_size(123) - .set_recv_buffer_size(456) - .set_nodelay(true) - .set_ttl(789) - .set_keepalive(Some(Duration::from_millis(100))); - - assert_eq!(tcp.send_buffer_size, Some(123)); - assert_eq!(tcp.recv_buffer_size, Some(456)); + // tcp.set_send_buffer_size(123) + // .set_recv_buffer_size(456) + tcp.set_nodelay(true).set_ttl(789); + // .set_keepalive(Some(Duration::from_millis(100))); + + // assert_eq!(tcp.send_buffer_size, Some(123)); + // assert_eq!(tcp.recv_buffer_size, Some(456)); assert_eq!(tcp.nodelay, Some(true)); assert_eq!(tcp.ttl, Some(789)); - assert_eq!(tcp.keepalive, Some(Some(Duration::from_millis(100)))); + // assert_eq!(tcp.keepalive, Some(Some(Duration::from_millis(100)))); } } diff --git a/comms/src/transports/tcp_with_tor.rs b/comms/src/transports/tcp_with_tor.rs index de54e17bb5..be800d9bfb 100644 --- a/comms/src/transports/tcp_with_tor.rs +++ b/comms/src/transports/tcp_with_tor.rs @@ -21,16 +21,10 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
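The hunk above removes the hand-rolled TcpSocket adapter because the transports now expose tokio's own AsyncRead/AsyncWrite directly. If a downstream component still expects the futures-io traits, the usual bridge is tokio-util's compat layer rather than a bespoke wrapper; the sketch below is illustrative only and assumes tokio-util is compiled with its "compat" feature, which this patch does not add.

use futures::io::AsyncWriteExt;
use tokio::net::TcpStream;
use tokio_util::compat::TokioAsyncWriteCompatExt;

// Drives a tokio 1.x TcpStream through the futures-io write trait.
async fn write_via_futures_io(addr: &str) -> std::io::Result<()> {
    let stream = TcpStream::connect(addr).await?;
    // compat_write() yields an adapter implementing futures::io::AsyncWrite.
    let mut compat = stream.compat_write();
    compat.write_all(b"ping").await?;
    compat.close().await?;
    Ok(())
}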
use super::Transport; -use crate::transports::{ - dns::TorDnsResolver, - helpers::is_onion_address, - SocksConfig, - SocksTransport, - TcpSocket, - TcpTransport, -}; +use crate::transports::{dns::TorDnsResolver, helpers::is_onion_address, SocksConfig, SocksTransport, TcpTransport}; use multiaddr::Multiaddr; use std::io; +use tokio::net::TcpStream; /// Transport implementation for TCP with Tor support #[derive(Clone, Default)] @@ -69,7 +63,7 @@ impl TcpWithTorTransport { impl Transport for TcpWithTorTransport { type Error = io::Error; type Listener = ::Listener; - type Output = TcpSocket; + type Output = TcpStream; async fn listen(&self, addr: Multiaddr) -> Result<(Self::Listener, Multiaddr), Self::Error> { self.tcp_transport.listen(addr).await diff --git a/comms/src/utils/mod.rs b/comms/src/utils/mod.rs index 697a758c1f..e543e644bb 100644 --- a/comms/src/utils/mod.rs +++ b/comms/src/utils/mod.rs @@ -22,5 +22,6 @@ pub mod cidr; pub mod datetime; +pub mod mpsc; pub mod multiaddr; pub mod signature; diff --git a/comms/src/common/mod.rs b/comms/src/utils/mpsc.rs similarity index 84% rename from comms/src/common/mod.rs rename to comms/src/utils/mpsc.rs index 2ac0cd0a6f..8ded39967f 100644 --- a/comms/src/common/mod.rs +++ b/comms/src/utils/mpsc.rs @@ -1,4 +1,4 @@ -// Copyright 2020, The Tari Project +// Copyright 2021, The Tari Project // // Redistribution and use in source and binary forms, with or without modification, are permitted provided that the // following conditions are met: @@ -20,4 +20,14 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -pub mod rate_limit; +use tokio::sync::mpsc; + +pub async fn send_all>( + sender: &mpsc::Sender, + iter: I, +) -> Result<(), mpsc::error::SendError> { + for item in iter { + sender.send(item).await?; + } + Ok(()) +} diff --git a/comms/tests/greeting_service.rs b/comms/tests/greeting_service.rs index c13ae842e4..e70b6e16ff 100644 --- a/comms/tests/greeting_service.rs +++ b/comms/tests/greeting_service.rs @@ -23,14 +23,14 @@ #![cfg(feature = "rpc")] use core::iter; -use futures::{channel::mpsc, stream, SinkExt, StreamExt}; use std::{cmp, time::Duration}; use tari_comms::{ async_trait, protocol::rpc::{Request, Response, RpcStatus, Streaming}, + utils, }; use tari_comms_rpc_macros::tari_rpc; -use tokio::{task, time}; +use tokio::{sync::mpsc, task, time}; #[tari_rpc(protocol_name = b"t/greeting/1", server_struct = GreetingServer, client_struct = GreetingClient)] pub trait GreetingRpc: Send + Sync + 'static { @@ -85,15 +85,9 @@ impl GreetingRpc for GreetingService { async fn get_greetings(&self, request: Request) -> Result, RpcStatus> { let num = *request.message(); - let (mut tx, rx) = mpsc::channel(num as usize); + let (tx, rx) = mpsc::channel(num as usize); let greetings = self.greetings[..cmp::min(num as usize + 1, self.greetings.len())].to_vec(); - task::spawn(async move { - let iter = greetings.into_iter().map(Ok); - let mut stream = stream::iter(iter) - // "Extra" Result::Ok is to satisfy send_all - .map(Ok); - tx.send_all(&mut stream).await.unwrap(); - }); + task::spawn(async move { utils::mpsc::send_all(&tx, greetings.into_iter().map(Ok)).await }); Ok(Streaming::new(rx)) } @@ -113,7 +107,7 @@ impl GreetingRpc for GreetingService { item_size, num_items, } = request.into_message(); - let (mut tx, rx) = mpsc::channel(10); + let (tx, rx) = mpsc::channel(10); let t = std::time::Instant::now(); 
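As a usage reference for the send_all helper added in comms/src/utils/mpsc.rs above: the producer side of a streaming RPC can now hand a whole batch to a bounded tokio channel in one call, as the greeting service does. The function below is a minimal sketch; the message type, channel capacity and function name are illustrative.

use tari_comms::utils;
use tokio::{sync::mpsc, task};

// Spawns a producer task that pushes every item into a bounded channel.
// A SendError from send_all only means the receiver was dropped early.
async fn stream_batch(items: Vec<String>) -> mpsc::Receiver<String> {
    let (tx, rx) = mpsc::channel(items.len().max(1));
    task::spawn(async move { utils::mpsc::send_all(&tx, items).await });
    rx
}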
task::spawn(async move { let item = iter::repeat(0u8).take(item_size as usize).collect::>(); @@ -136,7 +130,7 @@ impl GreetingRpc for GreetingService { } async fn slow_response(&self, request: Request) -> Result, RpcStatus> { - time::delay_for(Duration::from_secs(request.into_message())).await; + time::sleep(Duration::from_secs(request.into_message())).await; Ok(Response::new(())) } } diff --git a/comms/tests/rpc_stress.rs b/comms/tests/rpc_stress.rs index 0376a7dddc..933e158398 100644 --- a/comms/tests/rpc_stress.rs +++ b/comms/tests/rpc_stress.rs @@ -155,6 +155,7 @@ async fn run_stress_test(test_params: Params) { } future::join_all(tasks).await.into_iter().for_each(Result::unwrap); + log::info!("Stress test took {:.2?}", time.elapsed()); } @@ -259,7 +260,7 @@ async fn high_contention_high_concurrency() { .await; } -#[tokio_macros::test] +#[tokio::test] async fn run() { // let _ = env_logger::try_init(); log_timing("quick", quick()).await; diff --git a/comms/tests/substream_stress.rs b/comms/tests/substream_stress.rs index cbae6e8e52..a72eff03f3 100644 --- a/comms/tests/substream_stress.rs +++ b/comms/tests/substream_stress.rs @@ -23,7 +23,7 @@ mod helpers; use helpers::create_comms; -use futures::{channel::mpsc, future, SinkExt, StreamExt}; +use futures::{future, SinkExt, StreamExt}; use std::time::Duration; use tari_comms::{ framing, @@ -35,7 +35,7 @@ use tari_comms::{ }; use tari_shutdown::{Shutdown, ShutdownSignal}; use tari_test_utils::unpack_enum; -use tokio::{task, time::Instant}; +use tokio::{sync::mpsc, task, time::Instant}; const PROTOCOL_NAME: &[u8] = b"test/dummy/protocol"; @@ -79,7 +79,7 @@ async fn run_stress_test(num_substreams: usize, num_iterations: usize, payload_s task::spawn({ let sample = sample.clone(); async move { - while let Some(event) = notif_rx.next().await { + while let Some(event) = notif_rx.recv().await { unpack_enum!(ProtocolEvent::NewInboundSubstream(_n, remote_substream) = event.event); let mut remote_substream = framing::canonical(remote_substream, frame_size); @@ -150,7 +150,7 @@ async fn run_stress_test(num_substreams: usize, num_iterations: usize, payload_s println!("avg t = {}ms", avg); } -#[tokio_macros::test] +#[tokio::test] async fn many_at_frame_limit() { const NUM_SUBSTREAMS: usize = 20; const NUM_ITERATIONS_PER_STREAM: usize = 100; diff --git a/infrastructure/shutdown/Cargo.toml b/infrastructure/shutdown/Cargo.toml index 1102ec037a..176bace6e2 100644 --- a/infrastructure/shutdown/Cargo.toml +++ b/infrastructure/shutdown/Cargo.toml @@ -12,7 +12,7 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -futures = "^0.3.1" +futures = "^0.3" [dev-dependencies] -tokio = {version="^0.2", features=["rt-core"]} +tokio = {version="1", default-features = false, features = ["rt", "macros"]} diff --git a/infrastructure/shutdown/src/lib.rs b/infrastructure/shutdown/src/lib.rs index f0054bd843..5afd56a239 100644 --- a/infrastructure/shutdown/src/lib.rs +++ b/infrastructure/shutdown/src/lib.rs @@ -26,14 +26,16 @@ #![deny(unused_must_use)] #![deny(unreachable_patterns)] #![deny(unknown_lints)] -use futures::{ - channel::{oneshot, oneshot::Canceled}, - future::{Fuse, FusedFuture, Shared}, + +pub mod oneshot_trigger; + +use crate::oneshot_trigger::OneshotSignal; +use futures::future::FusedFuture; +use std::{ + future::Future, + pin::Pin, task::{Context, Poll}, - Future, - FutureExt, }; -use std::pin::Pin; /// Trigger for shutdowns. 
/// @@ -42,71 +44,69 @@ use std::pin::Pin; /// /// _Note_: This will trigger when dropped, so the `Shutdown` instance should be held as /// long as required by the application. -pub struct Shutdown { - trigger: Option>, - signal: ShutdownSignal, - on_triggered: Option>, -} - +pub struct Shutdown(oneshot_trigger::OneshotTrigger<()>); impl Shutdown { - /// Create a new Shutdown pub fn new() -> Self { - let (tx, rx) = oneshot::channel(); - Self { - trigger: Some(tx), - signal: rx.fuse().shared(), - on_triggered: None, - } + Self(oneshot_trigger::OneshotTrigger::new()) } - /// Set the on_triggered callback - pub fn on_triggered(&mut self, on_trigger: F) -> &mut Self - where F: FnOnce() + Send + Sync + 'static { - self.on_triggered = Some(Box::new(on_trigger)); - self + pub fn trigger(&mut self) { + self.0.broadcast(()); + } + + pub fn is_triggered(&self) -> bool { + self.0.is_used() } - /// Convert this into a ShutdownSignal without consuming the - /// struct. pub fn to_signal(&self) -> ShutdownSignal { - self.signal.clone() + self.0.to_signal().into() } +} - /// Trigger any listening signals - pub fn trigger(&mut self) -> Result<(), ShutdownError> { - match self.trigger.take() { - Some(trigger) => { - trigger.send(()).map_err(|_| ShutdownError)?; +impl Default for Shutdown { + fn default() -> Self { + Self::new() + } +} - if let Some(on_triggered) = self.on_triggered.take() { - on_triggered(); - } +/// Receiver end of a shutdown signal. Once received the consumer should shut down. +#[derive(Debug, Clone)] +pub struct ShutdownSignal(oneshot_trigger::OneshotSignal<()>); - Ok(()) - }, - None => Ok(()), - } +impl ShutdownSignal { + pub fn is_triggered(&self) -> bool { + self.0.is_terminated() } - pub fn is_triggered(&self) -> bool { - self.trigger.is_none() + /// Wait for the shutdown signal to trigger. + pub fn wait(&mut self) -> &mut Self { + self } } -impl Drop for Shutdown { - fn drop(&mut self) { - let _ = self.trigger(); +impl Future for ShutdownSignal { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match Pin::new(&mut self.0).poll(cx) { + // Whether `trigger()` was called Some(()), or the Shutdown dropped (None) we want to resolve this future + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Poll::Pending, + } } } -impl Default for Shutdown { - fn default() -> Self { - Self::new() +impl FusedFuture for ShutdownSignal { + fn is_terminated(&self) -> bool { + self.0.is_terminated() } } -/// Receiver end of a shutdown signal. Once received the consumer should shut down. 
-pub type ShutdownSignal = Shared>>; +impl From> for ShutdownSignal { + fn from(inner: OneshotSignal<()>) -> Self { + Self(inner) + } +} #[derive(Debug, Clone, Default)] pub struct OptionalShutdownSignal(Option); @@ -137,11 +137,11 @@ impl OptionalShutdownSignal { } impl Future for OptionalShutdownSignal { - type Output = Result<(), Canceled>; + type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.0.as_mut() { - Some(inner) => inner.poll_unpin(cx), + Some(inner) => Pin::new(inner).poll(cx), None => Poll::Pending, } } @@ -165,73 +165,50 @@ impl FusedFuture for OptionalShutdownSignal { } } -#[derive(Debug)] -pub struct ShutdownError; - #[cfg(test)] mod test { use super::*; - use std::sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }; - use tokio::runtime::Runtime; - - #[test] - fn trigger() { - let rt = Runtime::new().unwrap(); + use tokio::task; + + #[tokio::test] + async fn trigger() { let mut shutdown = Shutdown::new(); let signal = shutdown.to_signal(); assert!(!shutdown.is_triggered()); - rt.spawn(async move { - signal.await.unwrap(); + let fut = task::spawn(async move { + signal.await; }); - shutdown.trigger().unwrap(); + shutdown.trigger(); + assert!(shutdown.is_triggered()); // Shutdown::trigger is idempotent - shutdown.trigger().unwrap(); + shutdown.trigger(); assert!(shutdown.is_triggered()); + fut.await.unwrap(); } - #[test] - fn signal_clone() { - let rt = Runtime::new().unwrap(); + #[tokio::test] + async fn signal_clone() { let mut shutdown = Shutdown::new(); let signal = shutdown.to_signal(); let signal_clone = signal.clone(); - rt.spawn(async move { - signal_clone.await.unwrap(); - signal.await.unwrap(); + let fut = task::spawn(async move { + signal_clone.await; + signal.await; }); - shutdown.trigger().unwrap(); + shutdown.trigger(); + fut.await.unwrap(); } - #[test] - fn drop_trigger() { - let rt = Runtime::new().unwrap(); + #[tokio::test] + async fn drop_trigger() { let shutdown = Shutdown::new(); let signal = shutdown.to_signal(); let signal_clone = signal.clone(); - rt.spawn(async move { - signal_clone.await.unwrap(); - signal.await.unwrap(); + let fut = task::spawn(async move { + signal_clone.await; + signal.await; }); drop(shutdown); - } - - #[test] - fn on_trigger() { - let rt = Runtime::new().unwrap(); - let spy = Arc::new(AtomicBool::new(false)); - let spy_clone = Arc::clone(&spy); - let mut shutdown = Shutdown::new(); - shutdown.on_triggered(move || { - spy_clone.store(true, Ordering::SeqCst); - }); - let signal = shutdown.to_signal(); - rt.spawn(async move { - signal.await.unwrap(); - }); - shutdown.trigger().unwrap(); - assert!(spy.load(Ordering::SeqCst)); + fut.await.unwrap(); } } diff --git a/infrastructure/shutdown/src/oneshot_trigger.rs b/infrastructure/shutdown/src/oneshot_trigger.rs new file mode 100644 index 0000000000..7d2ee5b46a --- /dev/null +++ b/infrastructure/shutdown/src/oneshot_trigger.rs @@ -0,0 +1,106 @@ +// Copyright 2021, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. 
Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use futures::{ + channel::{oneshot, oneshot::Receiver}, + future::{FusedFuture, Shared}, + FutureExt, +}; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +pub fn channel() -> OneshotTrigger { + OneshotTrigger::new() +} + +pub struct OneshotTrigger { + sender: Option>, + signal: OneshotSignal, +} + +impl OneshotTrigger { + pub fn new() -> Self { + let (tx, rx) = oneshot::channel(); + Self { + sender: Some(tx), + signal: rx.shared().into(), + } + } + + pub fn to_signal(&self) -> OneshotSignal { + self.signal.clone() + } + + pub fn broadcast(&mut self, item: T) { + if let Some(tx) = self.sender.take() { + let _ = tx.send(item); + } + } + + pub fn is_used(&self) -> bool { + self.sender.is_none() + } +} + +impl Default for OneshotTrigger { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug, Clone)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct OneshotSignal { + inner: Shared>, +} + +impl From>> for OneshotSignal { + fn from(inner: Shared>) -> Self { + Self { inner } + } +} + +impl Future for OneshotSignal { + type Output = Option; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if self.inner.is_terminated() { + return Poll::Ready(None); + } + + match Pin::new(&mut self.inner).poll(cx) { + Poll::Ready(Ok(v)) => Poll::Ready(Some(v)), + // Channel canceled + Poll::Ready(Err(_)) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +impl FusedFuture for OneshotSignal { + fn is_terminated(&self) -> bool { + self.inner.is_terminated() + } +} diff --git a/infrastructure/storage/Cargo.toml b/infrastructure/storage/Cargo.toml index 5db74c5097..3915bde787 100644 --- a/infrastructure/storage/Cargo.toml +++ b/infrastructure/storage/Cargo.toml @@ -13,13 +13,13 @@ edition = "2018" bincode = "1.1" log = "0.4.0" lmdb-zero = "0.4.4" -thiserror = "1.0.20" +thiserror = "1.0.26" rmp = "0.8.7" rmp-serde = "0.13.7" serde = "1.0.80" serde_derive = "1.0.80" tari_utilities = "^0.3" -bytes = "0.4.12" +bytes = "0.5" [dev-dependencies] rand = "0.8" diff --git a/infrastructure/test_utils/Cargo.toml b/infrastructure/test_utils/Cargo.toml index bee18392c9..c2ea538ffb 100644 --- a/infrastructure/test_utils/Cargo.toml +++ b/infrastructure/test_utils/Cargo.toml @@ -10,9 +10,10 @@ license = "BSD-3-Clause" [dependencies] tari_shutdown = {version="*", path="../shutdown"} + futures-test = { version = "^0.3.1" } futures = {version= "^0.3.1"} rand = "0.8" -tokio = {version= "0.2.10", features=["rt-threaded", "time", "io-driver"]} +tokio = {version= "1.10", 
features=["rt-multi-thread", "time"]} lazy_static = "1.3.0" tempfile = "3.1.0" diff --git a/infrastructure/test_utils/src/futures/async_assert_eventually.rs b/infrastructure/test_utils/src/futures/async_assert_eventually.rs index cddfb29b37..f9a1f3ee9d 100644 --- a/infrastructure/test_utils/src/futures/async_assert_eventually.rs +++ b/infrastructure/test_utils/src/futures/async_assert_eventually.rs @@ -46,7 +46,7 @@ macro_rules! async_assert_eventually { $max_attempts ); } - tokio::time::delay_for($interval).await; + tokio::time::sleep($interval).await; value = $check_expr; } }}; @@ -82,7 +82,7 @@ macro_rules! async_assert { $max_attempts ); } - tokio::time::delay_for($interval).await; + tokio::time::sleep($interval).await; } }}; ($check_expr:expr$(,)?) => {{ diff --git a/infrastructure/test_utils/src/runtime.rs b/infrastructure/test_utils/src/runtime.rs index 7b8c5faa57..22ee5962a6 100644 --- a/infrastructure/test_utils/src/runtime.rs +++ b/infrastructure/test_utils/src/runtime.rs @@ -26,12 +26,8 @@ use tari_shutdown::Shutdown; use tokio::{runtime, runtime::Runtime, task, task::JoinError}; pub fn create_runtime() -> Runtime { - tokio::runtime::Builder::new() - .threaded_scheduler() - .enable_io() - .enable_time() - .max_threads(8) - .core_threads(4) + tokio::runtime::Builder::new_multi_thread() + .enable_all() .build() .expect("Could not create runtime") } @@ -49,7 +45,7 @@ where F: Future + Send + 'static { /// Create a runtime and report if it panics. If there are tasks still running after the panic, this /// will carry on running forever. -// #[deprecated(note = "use tokio_macros::test instead")] +// #[deprecated(note = "use tokio::test instead")] pub fn test_async(f: F) where F: FnOnce(&mut TestRuntime) { let mut rt = TestRuntime::from(create_runtime()); diff --git a/infrastructure/test_utils/src/streams/mod.rs b/infrastructure/test_utils/src/streams/mod.rs index a70e588f7e..177bea14b1 100644 --- a/infrastructure/test_utils/src/streams/mod.rs +++ b/infrastructure/test_utils/src/streams/mod.rs @@ -20,8 +20,9 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::{Stream, StreamExt}; -use std::{collections::HashMap, hash::Hash, time::Duration}; +use futures::{stream, Stream, StreamExt}; +use std::{borrow::BorrowMut, collections::HashMap, hash::Hash, time::Duration}; +use tokio::sync::{broadcast, mpsc}; #[allow(dead_code)] #[allow(clippy::mutable_key_type)] // Note: Clippy Breaks with Interior Mutability Error @@ -54,7 +55,6 @@ where #[macro_export] macro_rules! collect_stream { ($stream:expr, take=$take:expr, timeout=$timeout:expr $(,)?) => {{ - use futures::{Stream, StreamExt}; use tokio::time; // Evaluate $stream once, NOT in the loop 🐛🚨 @@ -62,14 +62,17 @@ macro_rules! collect_stream { let mut items = Vec::new(); loop { - if let Some(item) = time::timeout($timeout, stream.next()).await.expect( - format!( - "Timeout before stream could collect {} item(s). Got {} item(s).", - $take, - items.len() + if let Some(item) = time::timeout($timeout, futures::stream::StreamExt::next(stream)) + .await + .expect( + format!( + "Timeout before stream could collect {} item(s). Got {} item(s).", + $take, + items.len() + ) + .as_str(), ) - .as_str(), - ) { + { items.push(item); if items.len() == $take { break items; @@ -80,11 +83,86 @@ macro_rules! collect_stream { } }}; ($stream:expr, timeout=$timeout:expr $(,)?) 
=> {{ - use futures::StreamExt; use tokio::time; + let mut stream = &mut $stream; + let mut items = Vec::new(); + while let Some(item) = time::timeout($timeout, futures::stream::StreamExt::next($stream)) + .await + .expect(format!("Timeout before stream was closed. Got {} items.", items.len()).as_str()) + { + items.push(item); + } + items + }}; +} + +#[macro_export] +macro_rules! collect_recv { + ($stream:expr, take=$take:expr, timeout=$timeout:expr $(,)?) => {{ + use tokio::time; + + // Evaluate $stream once, NOT in the loop 🐛🚨 + let mut stream = &mut $stream; + + let mut items = Vec::new(); + loop { + let item = time::timeout($timeout, stream.recv()).await.expect(&format!( + "Timeout before stream could collect {} item(s). Got {} item(s).", + $take, + items.len() + )); + + items.push(item.expect(&format!("{}/{} recv ended early", items.len(), $take))); + if items.len() == $take { + break items; + } + } + }}; + ($stream:expr, timeout=$timeout:expr $(,)?) => {{ + use tokio::time; + + let mut stream = &mut $stream; + + let mut items = Vec::new(); + while let Some(item) = time::timeout($timeout, stream.recv()) + .await + .expect(format!("Timeout before stream was closed. Got {} items.", items.len()).as_str()) + { + items.push(item); + } + items + }}; +} + +#[macro_export] +macro_rules! collect_try_recv { + ($stream:expr, take=$take:expr, timeout=$timeout:expr $(,)?) => {{ + use tokio::time; + + // Evaluate $stream once, NOT in the loop 🐛🚨 + let mut stream = &mut $stream; + let mut items = Vec::new(); - while let Some(item) = time::timeout($timeout, $stream.next()) + loop { + let item = time::timeout($timeout, stream.recv()).await.expect(&format!( + "Timeout before stream could collect {} item(s). Got {} item(s).", + $take, + items.len() + )); + + items.push(item.expect(&format!("{}/{} recv returned unexpected result", items.len(), $take))); + if items.len() == $take { + break items; + } + } + }}; + ($stream:expr, timeout=$timeout:expr $(,)?) => {{ + use tokio::time; + + let mut stream = &mut $stream; + let mut items = Vec::new(); + while let Ok(item) = time::timeout($timeout, stream.recv()) .await .expect(format!("Timeout before stream was closed. Got {} items.", items.len()).as_str()) { @@ -102,9 +180,9 @@ macro_rules! collect_stream { /// # use std::time::Duration; /// # use tari_test_utils::collect_stream_count; /// -/// let mut rt = Runtime::new().unwrap(); +/// let rt = Runtime::new().unwrap(); /// let mut stream = stream::iter(vec![1,2,2,3,2]); -/// assert_eq!(rt.block_on(async { collect_stream_count!(stream, timeout=Duration::from_secs(1)) }).get(&2), Some(&3)); +/// assert_eq!(rt.block_on(async { collect_stream_count!(&mut stream, timeout=Duration::from_secs(1)) }).get(&2), Some(&3)); /// ``` #[macro_export] macro_rules! 
collect_stream_count { @@ -139,3 +217,56 @@ where } } } + +pub async fn assert_in_mpsc(rx: &mut mpsc::Receiver, mut predicate: P, timeout: Duration) -> R +where P: FnMut(T) -> Option { + loop { + if let Some(item) = tokio::time::timeout(timeout, rx.recv()) + .await + .expect("Timeout before stream emitted") + { + if let Some(r) = (predicate)(item) { + break r; + } + } else { + panic!("Predicate did not return true before the mpsc stream ended"); + } + } +} + +pub async fn assert_in_broadcast(rx: &mut broadcast::Receiver, mut predicate: P, timeout: Duration) -> R +where + P: FnMut(T) -> Option, + T: Clone, +{ + loop { + if let Ok(item) = tokio::time::timeout(timeout, rx.recv()) + .await + .expect("Timeout before stream emitted") + { + if let Some(r) = (predicate)(item) { + break r; + } + } else { + panic!("Predicate did not return true before the broadcast channel ended"); + } + } +} + +pub fn convert_mpsc_to_stream(rx: &mut mpsc::Receiver) -> impl Stream + '_ { + stream::unfold(rx, |rx| async move { rx.recv().await.map(|t| (t, rx)) }) +} + +pub fn convert_unbounded_mpsc_to_stream(rx: &mut mpsc::UnboundedReceiver) -> impl Stream + '_ { + stream::unfold(rx, |rx| async move { rx.recv().await.map(|t| (t, rx)) }) +} + +pub fn convert_broadcast_to_stream<'a, T, S>(rx: S) -> impl Stream> + 'a +where + T: Clone + Send + 'static, + S: BorrowMut> + 'a, +{ + stream::unfold(rx, |mut rx| async move { + Some(rx.borrow_mut().recv().await).map(|t| (t, rx)) + }) +} diff --git a/integration_tests/features/Mempool.feature b/integration_tests/features/Mempool.feature index 6894c5d781..a77d837a08 100644 --- a/integration_tests/features/Mempool.feature +++ b/integration_tests/features/Mempool.feature @@ -199,4 +199,4 @@ Feature: Mempool When I submit transaction TX1 to BN1 Then I wait until base node BN1 has 1 unconfirmed transactions in its mempool When I mine 1 blocks on BN1 - Then I wait until base node BN1 has 0 unconfirmed transactions in its mempool \ No newline at end of file + Then I wait until base node BN1 has 0 unconfirmed transactions in its mempool diff --git a/integration_tests/features/support/steps.js b/integration_tests/features/support/steps.js index 664d89715e..ef93992b5e 100644 --- a/integration_tests/features/support/steps.js +++ b/integration_tests/features/support/steps.js @@ -1107,20 +1107,24 @@ When(/I spend outputs (.*) via (.*)/, async function (inputs, node) { expect(this.lastResult.result).to.equal("ACCEPTED"); }); -Then(/(.*) has (.*) in (.*) state/, async function (node, txn, pool) { - const client = this.getClient(node); - const sig = this.transactions[txn].body.kernels[0].excess_sig; - await waitFor( - async () => await client.transactionStateResult(sig), - pool, - 1200 * 1000 - ); - this.lastResult = await this.getClient(node).transactionState( - this.transactions[txn].body.kernels[0].excess_sig - ); - console.log(`Node ${node} response is: ${this.lastResult.result}`); - expect(this.lastResult.result).to.equal(pool); -}); +Then( + /(.*) has (.*) in (.*) state/, + { timeout: 21 * 60 * 1000 }, + async function (node, txn, pool) { + const client = this.getClient(node); + const sig = this.transactions[txn].body.kernels[0].excess_sig; + await waitForPredicate( + async () => (await client.transactionStateResult(sig)) === pool, + 20 * 60 * 1000, + 1000 + ); + this.lastResult = await this.getClient(node).transactionState( + this.transactions[txn].body.kernels[0].excess_sig + ); + console.log(`Node ${node} response is: ${this.lastResult.result}`); + 
expect(this.lastResult.result).to.equal(pool); + } +); // The number is rounded down. E.g. if 1% can fail out of 17, that is 16.83 have to succeed. // It's means at least 16 have to succeed.
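Taken together, the shutdown rework and the tokio 1.x channel migration converge on one consumer pattern: a task selects between its tokio mpsc receiver and a ShutdownSignal, the same shape the tor control monitor loop above uses with tokio::select!. The sketch below shows that end to end with the new API; the Worker type, message type and channel capacity are illustrative, and it assumes tokio is built with its rt-multi-thread and macros features.

use tari_shutdown::{Shutdown, ShutdownSignal};
use tokio::sync::mpsc;

struct Worker {
    rx: mpsc::Receiver<String>,
    shutdown_signal: ShutdownSignal,
}

impl Worker {
    async fn run(mut self) {
        loop {
            tokio::select! {
                maybe_msg = self.rx.recv() => match maybe_msg {
                    Some(msg) => println!("worker received: {}", msg),
                    // All senders dropped; nothing more will arrive.
                    None => break,
                },
                // ShutdownSignal is now a plain Future with Output = (), so there
                // is no Canceled error left to unwrap when it resolves.
                _ = self.shutdown_signal.wait() => break,
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let mut shutdown = Shutdown::new();
    let (tx, rx) = mpsc::channel(10);
    let worker = Worker {
        rx,
        shutdown_signal: shutdown.to_signal(),
    };
    let handle = tokio::spawn(worker.run());

    tx.send("hello".to_string()).await.unwrap();
    // Triggering (or simply dropping) the Shutdown resolves every signal clone.
    shutdown.trigger();
    handle.await.unwrap();
}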