diff --git a/Cargo.lock b/Cargo.lock index 93f24f8ead..58595f0837 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,15 +2,6 @@ # It is not intended for manual editing. version = 3 -[[package]] -name = "addr2line" -version = "0.15.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7a2e47a1fbe209ee101dd6d61285226744c6c8d3c21c8dc878ba6cb9f467f3a" -dependencies = [ - "gimli", -] - [[package]] name = "adler" version = "1.0.2" @@ -28,9 +19,9 @@ dependencies = [ [[package]] name = "aead" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "922b33332f54fc0ad13fa3e514601e8d30fb54e1f3eadc36643f6526db645621" +checksum = "6e3e798aa0c8239776f54415bc06f3d74b1850f3f830b45c35cfc80556973f70" dependencies = [ "generic-array", ] @@ -85,15 +76,15 @@ dependencies = [ [[package]] name = "aes-gcm" -version = "0.9.2" +version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc3be92e19a7ef47457b8e6f90707e12b6ac5d20c6f3866584fa3be0787d839f" +checksum = "b2a930fd487faaa92a30afa92cc9dd1526a5cff67124abbbb1c617ce070f4dcf" dependencies = [ - "aead 0.4.1", + "aead 0.4.2", "aes 0.7.4", "cipher 0.3.0", - "ctr 0.7.0", - "ghash 0.4.2", + "ctr 0.8.0", + "ghash 0.4.3", "subtle", ] @@ -173,9 +164,9 @@ checksum = "34fde25430d87a9388dadbe6e34d7f72a462c8b43ac8d309b42b0a8505d7e2a5" [[package]] name = "anyhow" -version = "1.0.41" +version = "1.0.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15af2628f6890fe2609a3b91bef4c83450512802e59489f9c1cb1fa5df064a61" +checksum = "28ae2b3dec75a406790005a200b1bd89785afc02517a00ca99ecfe093ee9e6cf" [[package]] name = "arc-swap" @@ -206,9 +197,9 @@ dependencies = [ [[package]] name = "async-stream" -version = "0.2.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22068c0c19514942eefcfd4daf8976ef1aad84e61539f95cd200c35202f80af5" +checksum = "171374e7e3b2504e0e5236e3b59260560f9fe94bfe9ac39ba5e4e929c5590625" dependencies = [ "async-stream-impl", "futures-core", @@ -216,24 +207,24 @@ dependencies = [ [[package]] name = "async-stream-impl" -version = "0.2.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25f9db3b38af870bf7e5cc649167533b493928e50744e2c30ae350230b414670" +checksum = "648ed8c8d2ce5409ccd57453d9d1b214b342a0d69376a6feda1fd6cae3299308" dependencies = [ - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", ] [[package]] name = "async-trait" -version = "0.1.50" +version = "0.1.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b98e84bbb4cbcdd97da190ba0c58a1bb0de2c1fdf67d159e192ed766aeca722" +checksum = "44318e776df68115a881de9a8fd1b9e53368d7a4a5ce4cc48517da3393233a5e" dependencies = [ - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", ] [[package]] @@ -259,21 +250,6 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a" -[[package]] -name = "backtrace" -version = "0.3.60" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7815ea54e4d821e791162e078acbebfd6d8c8939cd559c9335dceb1c8ca7282" -dependencies = [ - "addr2line", - "cc", - "cfg-if 1.0.0", - "libc", - "miniz_oxide", - "object", - "rustc-demangle", -] - [[package]] name = "base58-monero" version = "0.3.0" @@ -302,12 +278,6 @@ dependencies = [ "byteorder", ] -[[package]] -name = "base64" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b41b7ea54a0c9d92199de89e20e58d49f02f8e699814ef3fdf266f6f748d15c7" - [[package]] name = "base64" version = "0.12.3" @@ -337,7 +307,7 @@ version = "1.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" dependencies = [ - "serde 1.0.126", + "serde 1.0.129", ] [[package]] @@ -346,7 +316,7 @@ version = "0.56.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2da379dbebc0b76ef63ca68d8fc6e71c0f13e59432e0987e508c1820e6ab5239" dependencies = [ - "bitflags 1.2.1", + "bitflags 1.3.2", "cexpr", "clang-sys", "clap", @@ -355,12 +325,12 @@ dependencies = [ "lazycell", "log 0.4.14", "peeking_take_while", - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", "regex", "rustc-hash", "shlex", - "which", + "which 3.1.1", ] [[package]] @@ -377,9 +347,9 @@ checksum = "4efd02e230a02e18f92fc2735f44597385ed02ad8f831e7c1c1156ee5e1ab3a5" [[package]] name = "bitflags" -version = "1.2.1" +version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitstring" @@ -452,14 +422,14 @@ checksum = "771fe0050b883fcc3ea2359b1a96bcfbc090b7116eae7c3c512c7a083fdf23d3" [[package]] name = "bstr" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a40b47ad93e1a5404e6c18dec46b628214fee441c70f4ab5d6942142cc268a3d" +checksum = "90682c8d613ad3373e66de8c6411e0ae2ab2571e879d2efbf73558cc66f21279" dependencies = [ "lazy_static 1.4.0", "memchr", "regex-automata", - "serde 1.0.126", + "serde 1.0.129", ] [[package]] @@ -486,9 +456,9 @@ checksum = "9c59e7af012c713f529e7a3ee57ce9b31ddd858d4b512923602f74608b009631" [[package]] name = "bytemuck" -version = "1.7.0" +version = "1.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9966d2ab714d0f785dbac0a0396251a35280aeb42413281617d0209ab4898435" +checksum = "72957246c41db82b8ef88a5486143830adeb8227ef9837740bdec67724cf2c5b" [[package]] name = "byteorder" @@ -496,30 +466,20 @@ version = "1.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" -[[package]] -name = "bytes" -version = "0.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "206fdffcfa2df7cbe15601ef46c813fce0965eb3286db6b56c583b814b51c81c" -dependencies = [ - "byteorder", - "iovec", -] - [[package]] name = "bytes" version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e4cec68f03f32e44924783795810fa50a7035d8c8ebe78580ad7e6c703fba38" -dependencies = [ - "serde 1.0.126", -] [[package]] name = "bytes" -version = "1.0.1" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b700ce4376041dcd0a327fd0097c41095743c4c8af8887265942faf1100bd040" +checksum = "c4872d67bab6358e59559027aa3b9157c53d9358c51423c17554809a8858e0f8" +dependencies = [ + "serde 1.0.129", +] [[package]] name = "c_linked_list" @@ -541,11 +501,11 @@ checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53" [[package]] name = "cast" -version = "0.2.3" +version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b9434b9a5aa1450faa3f9cb14ea0e8c53bb5d2b3c1bfd1ab4fc03e9f33fbfb0" +checksum = "4c24dab4283a142afa2fdca129b80ad2c6284e073930f964c3a1293c225ee39a" dependencies = [ - "rustc_version 0.2.3", + "rustc_version 0.4.0", ] [[package]] @@ -569,20 +529,20 @@ dependencies = [ "heck", "indexmap", "log 0.4.14", - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", - "serde 1.0.126", + "serde 1.0.129", "serde_json", - "syn 1.0.73", + "syn 1.0.75", "tempfile", "toml 0.5.8", ] [[package]] name = "cc" -version = "1.0.68" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a72c244c1ff497a746a7e1fb3d14bd08420ecda70c8f25c7112f2781652d787" +checksum = "e70cc2f62c6ce1868963827bd677764c62d07c3d9a3e1fb1177ee1a9ab199eb2" [[package]] name = "cexpr" @@ -616,9 +576,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chacha20" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fee7ad89dc1128635074c268ee661f90c3f7e83d9fd12910608c36b47d6c3412" +checksum = "ea8756167ea0aca10e066cdbe7813bd71d2f24e69b0bc7b50509590cef2ce0b9" dependencies = [ "cfg-if 1.0.0", "cipher 0.3.0", @@ -628,11 +588,11 @@ dependencies = [ [[package]] name = "chacha20poly1305" -version = "0.8.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1580317203210c517b6d44794abfbe600698276db18127e37ad3e69bf5e848e5" +checksum = "175a11316f33592cf2b71416ee65283730b5b7849813c4891d02a12906ed9acc" dependencies = [ - "aead 0.4.1", + "aead 0.4.2", "chacha20", "cipher 0.3.0", "poly1305", @@ -654,7 +614,7 @@ dependencies = [ "libc", "num-integer", "num-traits 0.2.14", - "serde 1.0.126", + "serde 1.0.129", "time", "winapi 0.3.9", ] @@ -677,7 +637,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c6316c62053228eddd526a5e6deb6344c80bf2bc1e9786e7f90b3083e73197c1" dependencies = [ "bitstring", - "serde 1.0.126", + "serde 1.0.129", ] [[package]] @@ -723,7 +683,7 @@ checksum = "37e58ac78573c40708d45522f0d80fa2f01cc4f9b4e2bf749807255454312002" dependencies = [ "ansi_term 0.11.0", "atty", - "bitflags 1.2.1", + "bitflags 1.3.2", "strsim 0.8.0", "textwrap", "unicode-width", @@ -745,7 +705,7 @@ version = "0.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ddfc5b9aa5d4507acaf872de71051dfd0e309860e88966e1051e462a077aac4f" dependencies = [ - "bitflags 1.2.1", + "bitflags 1.3.2", ] [[package]] @@ -763,7 +723,7 @@ dependencies = [ "lazy_static 1.4.0", "nom 4.2.3", "rust-ini", - "serde 1.0.126", + "serde 1.0.129", "serde-hjson", "serde_json", "toml 0.4.10", @@ -827,7 +787,7 @@ dependencies = [ "clap", "criterion-plot", "csv", - "itertools", + "itertools 0.8.2", "lazy_static 1.4.0", "libc", "num-traits 0.2.14", @@ -836,7 +796,7 @@ dependencies = [ "rand_xoshiro", "rayon", "rayon-core", - "serde 1.0.126", + "serde 1.0.129", "serde_derive", "serde_json", "tinytemplate", @@ -851,7 +811,7 @@ checksum = "76f9212ddf2f4a9eb2d401635190600656a1f88a932ef53d06e7fa4c7e02fb8e" dependencies = [ "byteorder", "cast", - "itertools", + "itertools 0.8.2", ] [[package]] @@ -911,9 +871,9 @@ dependencies = [ [[package]] name = "crossbeam-deque" -version = "0.8.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94af6efb46fef72616855b036a624cf27ba656ffc9be1b9a3c931cfc7749a9a9" +checksum = "6455c0ca19f0d2fbf751b908d5c55c1f5cbc65e03c4225427254b46890bdde1e" dependencies = [ "cfg-if 1.0.0", "crossbeam-epoch", @@ -969,7 +929,7 @@ version = "0.17.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6f4919d60f26ae233e14233cc39746c8c8bb8cd7b05840ace83604917b51b6c7" dependencies = [ - "bitflags 1.2.1", + "bitflags 1.3.2", "crossterm_winapi", "lazy_static 1.4.0", "libc", @@ -1014,7 +974,7 @@ dependencies = [ "csv-core", "itoa", "ryu", - "serde 1.0.126", + "serde 1.0.129", ] [[package]] @@ -1037,18 +997,18 @@ dependencies = [ [[package]] name = "ctr" -version = "0.7.0" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a232f92a03f37dd7d7dd2adc67166c77e9cd88de5b019b9a9eecfaeaf7bfd481" +checksum = "049bb91fb4aaf0e3c7efa6cd5ef877dbbbd15b39dad06d9948de4ec8a75761ea" dependencies = [ "cipher 0.3.0", ] [[package]] name = "curl-sys" -version = "0.4.44+curl-7.77.0" +version = "0.4.45+curl-7.78.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b6d85e9322b193f117c966e79c2d6929ec08c02f339f950044aba12e20bbaf1" +checksum = "de9e5a72b1c744eb5dd20b2be4d7eb84625070bb5c4ab9b347b70464ab1e62eb" dependencies = [ "cc", "libc", @@ -1061,14 +1021,14 @@ dependencies = [ [[package]] name = "curve25519-dalek" -version = "3.1.0" +version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "639891fde0dbea823fc3d798a0fdf9d2f9440a42d64a78ab3488b0ca025117b3" +checksum = "0b9fdf9972b2bd6af2d913799d9ebc165ea4d2e65878e329d9c6b372c4491b61" dependencies = [ "byteorder", "digest", "rand_core 0.5.1", - "serde 1.0.126", + "serde 1.0.129", "subtle", "zeroize", ] @@ -1083,7 +1043,7 @@ dependencies = [ "digest", "packed_simd_2", "rand_core 0.6.3", - "serde 1.0.126", + "serde 1.0.129", "subtle-ng", "zeroize", ] @@ -1106,10 +1066,10 @@ checksum = "f0c960ae2da4de88a91b2d920c2a7233b400bc33cb28453a2987822d8392519b" dependencies = [ "fnv", "ident_case", - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", "strsim 0.9.3", - "syn 1.0.73", + "syn 1.0.75", ] [[package]] @@ -1120,7 +1080,7 @@ checksum = "d9b5a2f4ac4969822c62224815d069952656cadc7084fdca9751e6d959189b72" dependencies = [ "darling_core", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", ] [[package]] @@ -1135,9 +1095,9 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" dependencies = [ - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", ] [[package]] @@ -1159,9 +1119,9 @@ checksum = "a2658621297f2cf68762a6f7dc0bb7e1ff2cfd6583daef8ee0fed6f7ec468ec0" dependencies = [ "darling", "derive_builder_core", - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", ] [[package]] @@ -1171,9 +1131,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2791ea3e372c8495c0bc2033991d76b512cd799d07491fbd6890124db9458bef" dependencies = [ "darling", - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", ] [[package]] @@ -1210,9 +1170,9 @@ version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "45f5098f628d02a7a0f68ddba586fb61e80edec3bdc1be3b921f4ceec60858d3" dependencies = [ - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", ] [[package]] @@ -1263,9 +1223,9 @@ checksum = "56899898ce76aaf4a0f24d914c97ea6ed976d42fec6ad33fcbb0a1103e07b2b0" [[package]] name = "ed25519" -version = "1.1.1" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d0860415b12243916284c67a9be413e044ee6668247b99ba26d94b2bc06c8f6" +checksum = "4620d40f6d2601794401d6dd95a5cf69b6c157852539470eeda433a99b3c0efc" dependencies = [ "signature", ] @@ -1279,7 +1239,7 @@ dependencies = [ "curve25519-dalek", "ed25519", "rand 0.7.3", - "serde 1.0.126", + "serde 1.0.129", "sha2", "zeroize", ] @@ -1312,9 +1272,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c5f0096a91d210159eceb2ff5e1c4da18388a170e1e3ce948aac9c8fdbbf595" dependencies = [ "heck", - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", ] [[package]] @@ -1440,7 +1400,7 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2e9763c69ebaae630ba35f74888db465e49e259ba1bc0eda7d06f4a067615d82" dependencies = [ - "bitflags 1.2.1", + "bitflags 1.3.2", "fuchsia-zircon-sys", ] @@ -1458,9 +1418,9 @@ checksum = "3a471a38ef8ed83cd6e40aa59c1ffe17db6855c18e3604d9c4ed8c08ebc28678" [[package]] name = "futures" -version = "0.3.15" +version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e7e43a803dae2fa37c1f6a8fe121e1f7bf9548b4dfc0522a42f34145dadfc27" +checksum = "1adc00f486adfc9ce99f77d717836f0c5aa84965eb0b4f051f4e83f7cab53f8b" dependencies = [ "futures-channel", "futures-core", @@ -1473,9 +1433,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.15" +version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e682a68b29a882df0545c143dc3646daefe80ba479bcdede94d5a703de2871e2" +checksum = "74ed2411805f6e4e3d9bc904c95d5d423b89b3b25dc0250aa74729de20629ff9" dependencies = [ "futures-core", "futures-sink", @@ -1493,9 +1453,9 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.15" +version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0402f765d8a89a26043b889b26ce3c4679d268fa6bb22cd7c6aad98340e179d1" +checksum = "af51b1b4a7fdff033703db39de8802c673eb91855f2e0d47dcf3bf2c0ef01f99" [[package]] name = "futures-core-preview" @@ -1505,9 +1465,9 @@ checksum = "b35b6263fb1ef523c3056565fa67b1d16f0a8604ff12b11b08c25f28a734c60a" [[package]] name = "futures-executor" -version = "0.3.15" +version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "badaa6a909fac9e7236d0620a2f57f7664640c56575b71a7552fbd68deafab79" +checksum = "4d0d535a57b87e1ae31437b892713aee90cd2d7b0ee48727cd11fc72ef54761c" dependencies = [ "futures-core", "futures-task", @@ -1527,9 +1487,9 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.15" +version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acc499defb3b348f8d8f3f66415835a9131856ff7714bf10dadfc4ec4bdb29a1" +checksum = "0b0e06c393068f3a6ef246c75cdca793d6a46347e75286933e5e75fd2fd11582" [[package]] name = "futures-io-preview" @@ -1539,15 +1499,15 @@ checksum = "f4914ae450db1921a56c91bde97a27846287d062087d4a652efc09bb3a01ebda" [[package]] name = "futures-macro" -version = "0.3.15" +version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4c40298486cdf52cc00cd6d6987892ba502c7656a16a4192a9992b1ccedd121" +checksum = "c54913bae956fb8df7f4dc6fc90362aa72e69148e3f39041fbe8742d21e0ac57" dependencies = [ "autocfg 1.0.1", "proc-macro-hack", - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", ] [[package]] @@ -1566,9 +1526,9 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.15" +version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a57bead0ceff0d6dde8f465ecd96c9338121bb7717d3e7b108059531870c4282" +checksum = "c0f30aaa67363d119812743aa5f33c201a7a66329f97d1a887022971feea4b53" [[package]] name = "futures-sink-preview" @@ -1578,15 +1538,15 @@ checksum = "86f148ef6b69f75bb610d4f9a2336d4fc88c4b5b67129d1a340dd0fd362efeec" [[package]] name = "futures-task" -version = "0.3.15" +version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a16bef9fc1a4dddb5bee51c989e3fbba26569cbb0e31f5b303c184e3dd33dae" +checksum = "bbe54a98670017f3be909561f6ad13e810d9a51f3f061b902062ca3da80799f2" [[package]] name = "futures-test" -version = "0.3.15" +version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e771858b95154d86bc76b412e4cea3bc104803a7838179e5a1315d9c8a4c2b6" +checksum = "3a5ac667be097531d74ff9fff9c9da7820dd63afd2312bb9c6f589211ae32080" dependencies = [ "futures-core", "futures-executor", @@ -1595,7 +1555,7 @@ dependencies = [ "futures-sink", "futures-task", "futures-util", - "pin-project 1.0.7", + "pin-project 1.0.8", "pin-utils", ] @@ -1624,9 +1584,9 @@ dependencies = [ [[package]] name = "futures-util" -version = "0.3.15" +version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "feb5c238d27e2bf94ffdfd27b2c29e3df4a68c4193bb6427384259e2bf191967" +checksum = "67eb846bfd58e44a8481a00049e82c43e0ccb5d61f8dc071057cb19249dd4d78" dependencies = [ "autocfg 1.0.1", "futures 0.1.31", @@ -1731,27 +1691,21 @@ dependencies = [ [[package]] name = "ghash" -version = "0.4.2" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7bbd60caa311237d508927dbba7594b483db3ef05faa55172fcf89b1bcda7853" +checksum = "b442c439366184de619215247d24e908912b175e824a530253845ac4c251a5c1" dependencies = [ "opaque-debug", - "polyval 0.5.1", + "polyval 0.5.2", ] -[[package]] -name = "gimli" -version = "0.24.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e4075386626662786ddb0ec9081e7c7eeb1ba31951f447ca780ef9f5d568189" - [[package]] name = "git2" version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c7339329bfa14a00223244311560d11f8f489b453fb90092af97f267a6090ab0" dependencies = [ - "bitflags 1.2.1", + "bitflags 1.3.2", "libc", "libgit2-sys", "log 0.4.14", @@ -1788,11 +1742,11 @@ dependencies = [ [[package]] name = "h2" -version = "0.3.3" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "825343c4eef0b63f541f8903f395dc5beb362a979b5799a84062527ef1e37726" +checksum = "d7f3675cfef6a30c8031cf9e6493ebdc3bb3272a3fea3923c4210d1830e6a472" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "fnv", "futures-core", "futures-sink", @@ -1800,7 +1754,7 @@ dependencies = [ "http", "indexmap", "slab", - "tokio 1.9.0", + "tokio 1.10.1", "tokio-util 0.6.7", "tracing", ] @@ -1837,9 +1791,9 @@ checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" [[package]] name = "hex-literal" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76505e26b6ca3bbdbbb360b68472abbb80998c5fa5dc43672eca34f28258e138" +checksum = "21e4590e13640f19f249fe3e4eca5113bc4289f2497710378190e7f4bd96f45b" [[package]] name = "http" @@ -1847,7 +1801,7 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "527e8c9ac747e28542699a951517aa9a6945af506cd1f2e1b53a576c17b6cc11" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "fnv", "itoa", ] @@ -1864,20 +1818,20 @@ dependencies = [ [[package]] name = "http-body" -version = "0.4.2" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60daa14be0e0786db0f03a9e57cb404c9d756eed2b6c62b9ea98ec5743ec75a9" +checksum = "399c583b2979440c60be0821a6199eca73bc3c8dcd9d070d75ac726e2c6186e5" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "http", "pin-project-lite 0.2.7", ] [[package]] name = "httparse" -version = "1.4.1" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3a87b616e37e93c22fb19bcd386f02f3af5ea98a25670ad0fce773de23c5e68" +checksum = "acd94fdbe1d4ff688b67b04eee2e17bd50995534a61539e45adfefb45e5e5503" [[package]] name = "httpdate" @@ -1941,7 +1895,7 @@ dependencies = [ "httparse", "httpdate 0.3.2", "itoa", - "pin-project 1.0.7", + "pin-project 1.0.8", "socket2 0.3.19", "tokio 0.2.25", "tower-service", @@ -1951,28 +1905,40 @@ dependencies = [ [[package]] name = "hyper" -version = "0.14.11" +version = "0.14.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b61cf2d1aebcf6e6352c97b81dc2244ca29194be1b276f5d8ad5c6330fffb11" +checksum = "13f67199e765030fa08fe0bd581af683f0d5bc04ea09c2b1102012c5fb90e7fd" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "futures-channel", "futures-core", "futures-util", - "h2 0.3.3", + "h2 0.3.4", "http", - "http-body 0.4.2", + "http-body 0.4.3", "httparse", "httpdate 1.0.1", "itoa", "pin-project-lite 0.2.7", - "socket2 0.4.0", - "tokio 1.9.0", + "socket2 0.4.1", + "tokio 1.10.1", "tower-service", "tracing", "want", ] +[[package]] +name = "hyper-timeout" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" +dependencies = [ + "hyper 0.14.12", + "pin-project-lite 0.2.7", + "tokio 1.10.1", + "tokio-io-timeout", +] + [[package]] name = "hyper-tls" version = "0.4.3" @@ -1992,10 +1958,10 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" dependencies = [ - "bytes 1.0.1", - "hyper 0.14.11", + "bytes 1.1.0", + "hyper 0.14.12", "native-tls", - "tokio 1.9.0", + "tokio 1.10.1", "tokio-native-tls", ] @@ -2084,17 +2050,26 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69ddb889f9d0d08a67338271fa9b62996bc788c7796a5c18cf057420aaed5eaf" +dependencies = [ + "either", +] + [[package]] name = "itoa" -version = "0.4.7" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd25036021b0de88a0aff6b850051563c6516d0bf53f8638938edbb9de732736" +checksum = "b71991ff56294aa922b450139ee08b3bfc70982c6b2c7562771375cf73542dd4" [[package]] name = "js-sys" -version = "0.3.51" +version = "0.3.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83bdfbace3a0e81a4253f73b49e960b053e396a11012cbd49b9b74d6a2b67062" +checksum = "e4bf49d50e2961077d9c99f4b7997d770a1114f087c3c2e0069b36c13fc2979d" dependencies = [ "wasm-bindgen", ] @@ -2106,7 +2081,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "436f3455a8a4e9c7b14de9f1206198ee5d0bdc2db1b560339d2141093d7dd389" dependencies = [ "hyper 0.10.16", - "serde 1.0.126", + "serde 1.0.129", "serde_derive", "serde_json", ] @@ -2156,9 +2131,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.97" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12b8adadd720df158f4d70dfe7ccc6adb0472d7c55ca83445f6a5ab3e36f8fb6" +checksum = "a1fa8cddc8fbbee11227ef194b5317ed014b8acbf15139bd716a18ad3fe99ec5" [[package]] name = "libgit2-sys" @@ -2306,7 +2281,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51b9bbe6c47d51fc3e1a9b945965946b4c44142ab8792c50835a980d362c2710" dependencies = [ "cfg-if 1.0.0", - "serde 1.0.126", + "serde 1.0.129", ] [[package]] @@ -2329,7 +2304,7 @@ dependencies = [ "libc", "log 0.4.14", "log-mdc", - "serde 1.0.126", + "serde 1.0.129", "serde-value 0.5.3", "serde_derive", "serde_yaml", @@ -2355,7 +2330,7 @@ dependencies = [ "log-mdc", "parking_lot 0.11.1", "regex", - "serde 1.0.126", + "serde 1.0.129", "serde-value 0.7.0", "serde_json", "serde_yaml", @@ -2376,9 +2351,9 @@ dependencies = [ [[package]] name = "matches" -version = "0.1.8" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ffc5c5338469d4d3ea17d269fa8ea3512ad247247c30bd2df69e68309ed0a08" +checksum = "a3e378b66a060d48947b590737b30a1be76706c8dd7b8ba0f2fe3989c68a853f" [[package]] name = "md-5" @@ -2393,9 +2368,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.4.0" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b16bd47d9e329435e309c58469fe0791c2d0d1ba96ec0954152a5ae2b04387dc" +checksum = "308cc39be01b73d0d18f82a0e7b2a3df85245f84af96fdddc5d202d27e47b86a" [[package]] name = "memoffset" @@ -2434,9 +2409,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9753f12909fd8d923f75ae5c3258cae1ed3c8ec052e1b38c93c21a6d157f789c" dependencies = [ "migrations_internals", - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", ] [[package]] @@ -2506,17 +2481,6 @@ dependencies = [ "winapi 0.3.9", ] -[[package]] -name = "mio-uds" -version = "0.6.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afcb699eb26d4332647cc848492bbc15eafb26f08d0304550d5aa1f612e066f0" -dependencies = [ - "iovec", - "libc", - "mio 0.6.23", -] - [[package]] name = "miow" version = "0.2.2" @@ -2549,21 +2513,39 @@ dependencies = [ "fixed-hash", "hex", "hex-literal", - "serde 1.0.126", + "serde 1.0.129", "serde-big-array", "thiserror", "tiny-keccak", ] +[[package]] +name = "multiaddr" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48ee4ea82141951ac6379f964f71b20876d43712bea8faf6dd1a375e08a46499" +dependencies = [ + "arrayref", + "bs58", + "byteorder", + "data-encoding", + "multihash", + "percent-encoding 2.1.0", + "serde 1.0.129", + "static_assertions", + "unsigned-varint", + "url 2.2.2", +] + [[package]] name = "multihash" -version = "0.13.2" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dac63698b887d2d929306ea48b63760431ff8a24fac40ddb22f9c7f49fb7cab" +checksum = "752a61cd890ff691b4411423d23816d5866dd5621e4d1c5687a53b94b5a979d8" dependencies = [ "generic-array", "multihash-derive", - "unsigned-varint 0.5.1", + "unsigned-varint", ] [[package]] @@ -2574,9 +2556,9 @@ checksum = "424f6e86263cd5294cbd7f1e95746b95aca0e0d66bff31e5a40d6baa87b4aa99" dependencies = [ "proc-macro-crate", "proc-macro-error", - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", "synstructure", ] @@ -2588,9 +2570,9 @@ checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a" [[package]] name = "native-tls" -version = "0.2.7" +version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8d96b2e1c8da3957d58100b09f102c6d9cfdfced01b7ec5a8974044bb09dbd4" +checksum = "48ba9f7719b5a0f42f338907614285fb5fd70e53858141f69898a1fb7203b24d" dependencies = [ "lazy_static 1.4.0", "libc", @@ -2623,9 +2605,12 @@ checksum = "d36047f46c69ef97b60e7b069a26ce9a15cd8a7852eddb6991ea94a83ba36a78" [[package]] name = "nibble_vec" -version = "0.0.4" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8d77f3db4bce033f4d04db08079b2ef1c3d02b44e86f25d08886fafa7756ffa" +checksum = "77a5d83df9f36fe23f0c3648c6bbb8b0298bb5f1939c8f2704431371f4b84d43" +dependencies = [ + "smallvec", +] [[package]] name = "nix" @@ -2633,7 +2618,7 @@ version = "0.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "83450fe6a6142ddd95fb064b746083fc4ef1705fe81f64a64e1d4b39f54a1055" dependencies = [ - "bitflags 1.2.1", + "bitflags 1.3.2", "cc", "cfg-if 0.1.10", "libc", @@ -2730,7 +2715,7 @@ dependencies = [ "num-iter", "num-traits 0.2.14", "rand 0.7.3", - "serde 1.0.126", + "serde 1.0.129", "smallvec", "zeroize", ] @@ -2750,9 +2735,9 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "876a53fff98e03a936a674b29568b0e605f06b29372c2489ff4de23f1949743d" dependencies = [ - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", ] [[package]] @@ -2826,15 +2811,6 @@ dependencies = [ "libc", ] -[[package]] -name = "object" -version = "0.25.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a38f2be3697a57b4060074ff41b44c16870d916ad7877c17696e063257482bc7" -dependencies = [ - "memchr", -] - [[package]] name = "once_cell" version = "1.8.0" @@ -2849,11 +2825,11 @@ checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" [[package]] name = "openssl" -version = "0.10.35" +version = "0.10.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "549430950c79ae24e6d02e0b7404534ecf311d94cc9f861e9e4020187d13d885" +checksum = "8d9facdb76fec0b73c406f125d44d86fdad818d66fef0531eec9233ca425ff4a" dependencies = [ - "bitflags 1.2.1", + "bitflags 1.3.2", "cfg-if 1.0.0", "foreign-types", "libc", @@ -2878,9 +2854,9 @@ dependencies = [ [[package]] name = "openssl-sys" -version = "0.9.65" +version = "0.9.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a7907e3bfa08bb85105209cdfcb6c63d109f8f6c1ed6ca318fff5c1853fbc1d" +checksum = "1996d2d305e561b70d1ee0c53f1542833f4e1ac6ce9a6708b6ff2738ca67dc82" dependencies = [ "autocfg 1.0.1", "cc", @@ -2918,24 +2894,6 @@ dependencies = [ "libm 0.1.4", ] -[[package]] -name = "parity-multiaddr" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bfda2e46fc5e14122649e2645645a81ee5844e0fb2e727ef560cc71a8b2d801" -dependencies = [ - "arrayref", - "bs58", - "byteorder", - "data-encoding", - "multihash", - "percent-encoding 2.1.0", - "serde 1.0.126", - "static_assertions", - "unsigned-varint 0.6.0", - "url 2.2.2", -] - [[package]] name = "parking_lot" version = "0.10.2" @@ -2980,7 +2938,7 @@ dependencies = [ "cfg-if 1.0.0", "instant", "libc", - "redox_syscall 0.2.9", + "redox_syscall 0.2.10", "smallvec", "winapi 0.3.9", ] @@ -3099,11 +3057,11 @@ dependencies = [ [[package]] name = "pin-project" -version = "1.0.7" +version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7509cc106041c40a4518d2af7a61530e1eed0e6285296a3d8c5472806ccc4a4" +checksum = "576bc800220cc65dac09e99e97b08b358cfab6e17078de8dc5fee223bd2d0c08" dependencies = [ - "pin-project-internal 1.0.7", + "pin-project-internal 1.0.8", ] [[package]] @@ -3112,20 +3070,20 @@ version = "0.4.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3be26700300be6d9d23264c73211d8190e755b6b5ca7a1b28230025511b52a5e" dependencies = [ - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", ] [[package]] name = "pin-project-internal" -version = "1.0.7" +version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c950132583b500556b1efd71d45b319029f2b71518d979fcc208e16b42426f" +checksum = "6e8fe8163d14ce7f0cdac2e040116f22eac817edabff0be91e8aff7e9accf389" dependencies = [ - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", ] [[package]] @@ -3154,9 +3112,9 @@ checksum = "3831453b3449ceb48b6d9c7ad7c96d5ea673e9b470a1dc578c2ce6521230884c" [[package]] name = "poly1305" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fe800695325da85083cd23b56826fccb2e2dc29b218e7811a6f33bc93f414be" +checksum = "9fcffab1f78ebbdf4b93b68c1ffebc24037eedf271edaca795732b24e5e4e349" dependencies = [ "cpufeatures", "opaque-debug", @@ -3176,9 +3134,9 @@ dependencies = [ [[package]] name = "polyval" -version = "0.5.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e597450cbf209787f0e6de80bf3795c6b2356a380ee87837b545aded8dbc1823" +checksum = "a6ba6a405ef63530d6cb12802014b22f9c5751bd17cdcddbe9e46d5c8ae83287" dependencies = [ "cfg-if 1.0.0", "cpufeatures", @@ -3209,9 +3167,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" dependencies = [ "proc-macro-error-attr", - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", "version_check 0.9.3", ] @@ -3221,7 +3179,7 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" dependencies = [ - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", "version_check 0.9.3", ] @@ -3249,61 +3207,61 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.27" +version = "1.0.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0d8caf72986c1a598726adc988bb5984792ef84f5ee5aa50209145ee8077038" +checksum = "5c7ed8b8c7b886ea3ed7dde405212185f423ab44682667c8c6dd14aa1d9f6612" dependencies = [ "unicode-xid 0.2.2", ] [[package]] name = "prost" -version = "0.6.1" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce49aefe0a6144a45de32927c77bd2859a5f7677b55f220ae5b744e87389c212" +checksum = "de5e2533f59d08fcf364fd374ebda0692a70bd6d7e66ef97f306f45c6c5d8020" dependencies = [ - "bytes 0.5.6", + "bytes 1.1.0", "prost-derive", ] [[package]] name = "prost-build" -version = "0.6.1" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02b10678c913ecbd69350e8535c3aef91a8676c0773fc1d7b95cdd196d7f2f26" +checksum = "355f634b43cdd80724ee7848f95770e7e70eefa6dcf14fea676216573b8fd603" dependencies = [ - "bytes 0.5.6", + "bytes 1.1.0", "heck", - "itertools", + "itertools 0.10.1", "log 0.4.14", "multimap", "petgraph", "prost", "prost-types", "tempfile", - "which", + "which 4.2.2", ] [[package]] name = "prost-derive" -version = "0.6.1" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "537aa19b95acde10a12fec4301466386f757403de4cd4e5b4fa78fb5ecb18f72" +checksum = "600d2f334aa05acb02a755e217ef1ab6dea4d51b58b7846588b747edec04efba" dependencies = [ "anyhow", - "itertools", - "proc-macro2 1.0.27", + "itertools 0.10.1", + "proc-macro2 1.0.28", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", ] [[package]] name = "prost-types" -version = "0.6.1" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1834f67c0697c001304b75be76f67add9c89742eda3a085ad8ee0bb38c3417aa" +checksum = "603bbd6394701d13f3f25aada59c7de9d35a6a5887cfc156181234a44002771b" dependencies = [ - "bytes 0.5.6", + "bytes 1.1.0", "prost", ] @@ -3344,14 +3302,14 @@ version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3d0b9745dc2debf507c8422de05d7226cc1f0644216dfdfead988f9b1ab32a7" dependencies = [ - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", ] [[package]] name = "radix_trie" -version = "0.1.6" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d3681b28cd95acfb0560ea9441f82d6a4504fa3b15b97bd7b6e952131820e95" +checksum = "c069c179fcdc6a2fe24d8d18305cf085fdbd4f922c041943e203685d6a1c58fd" dependencies = [ "endian-type", "nibble_vec", @@ -3368,7 +3326,6 @@ dependencies = [ "rand_chacha 0.2.2", "rand_core 0.5.1", "rand_hc 0.2.0", - "rand_pcg", ] [[package]] @@ -3468,15 +3425,6 @@ dependencies = [ "winapi 0.3.9", ] -[[package]] -name = "rand_pcg" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16abd0c1b639e9eb4d7c50c0b8100b0d0f849be2349829c740fe8e6eb4816429" -dependencies = [ - "rand_core 0.5.1", -] - [[package]] name = "rand_xoshiro" version = "0.1.0" @@ -3493,7 +3441,7 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8da103db7f8022a2646a11e8f58de98d137089f90c3eb0bb54ed18f12ecb73b7" dependencies = [ - "bitflags 1.2.1", + "bitflags 1.3.2", "git2", "libc", "thiserror", @@ -3501,9 +3449,9 @@ dependencies = [ [[package]] name = "rayon" -version = "1.5.0" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b0d8e0819fadc20c74ea8373106ead0600e3a67ef1fe8da56e39b9ae7275674" +checksum = "c06aca804d41dbc8ba42dfd964f0d01334eceb64314b9ecf7c5fad5188a06d90" dependencies = [ "autocfg 1.0.1", "crossbeam-deque", @@ -3513,9 +3461,9 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.9.0" +version = "1.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ab346ac5921dc62ffa9f89b7a773907511cdfa5490c572ae9be1be33e8afa4a" +checksum = "d78120e2c850279833f1dd3582f730c4ab53ed95aeaaaa862a2a5c71b1656d8e" dependencies = [ "crossbeam-channel 0.5.1", "crossbeam-deque", @@ -3541,11 +3489,11 @@ checksum = "41cc0f7e4d5d4544e8861606a285bb08d3e70712ccc7d2b84d7c0ccfaf4b05ce" [[package]] name = "redox_syscall" -version = "0.2.9" +version = "0.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ab49abadf3f9e1c4bc499e8845e152ad87d2ad2d30371841171169e9d75feee" +checksum = "8383f39639269cde97d255a32bdb68c047337295414940c68bdd30c2e13203ff" dependencies = [ - "bitflags 1.2.1", + "bitflags 1.3.2", ] [[package]] @@ -3555,7 +3503,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "528532f3d801c87aec9def2add9ca802fe569e44a544afe633765267840abe64" dependencies = [ "getrandom 0.2.3", - "redox_syscall 0.2.9", + "redox_syscall 0.2.10", ] [[package]] @@ -3617,7 +3565,7 @@ dependencies = [ "native-tls", "percent-encoding 2.1.0", "pin-project-lite 0.2.7", - "serde 1.0.126", + "serde 1.0.129", "serde_json", "serde_urlencoded", "tokio 0.2.25", @@ -3636,13 +3584,13 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "246e9f61b9bb77df069a947682be06e31ac43ea37862e244a69f177694ea6d22" dependencies = [ "base64 0.13.0", - "bytes 1.0.1", + "bytes 1.1.0", "encoding_rs", "futures-core", "futures-util", "http", - "http-body 0.4.2", - "hyper 0.14.11", + "http-body 0.4.3", + "hyper 0.14.12", "hyper-tls 0.5.0", "ipnet", "js-sys", @@ -3652,10 +3600,10 @@ dependencies = [ "native-tls", "percent-encoding 2.1.0", "pin-project-lite 0.2.7", - "serde 1.0.126", + "serde 1.0.129", "serde_json", "serde_urlencoded", - "tokio 1.9.0", + "tokio 1.10.1", "tokio-native-tls", "url 2.2.2", "wasm-bindgen", @@ -3708,7 +3656,7 @@ checksum = "011e1d58446e9fa3af7cdc1fb91295b10621d3ac4cb3a85cc86385ee9ca50cd3" dependencies = [ "byteorder", "rmp", - "serde 1.0.126", + "serde 1.0.129", ] [[package]] @@ -3749,12 +3697,6 @@ version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3e52c148ef37f8c375d49d5a73aa70713125b7f19095948a923f80afdeb22ec2" -[[package]] -name = "rustc-demangle" -version = "0.1.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dead70b0b5e03e9c814bcb6b01e03e68f7c57a80aa48c72ec92152ab3e818d49" - [[package]] name = "rustc-hash" version = "1.1.0" @@ -3769,29 +3711,29 @@ checksum = "3e75f6a532d0fd9f7f13144f392b6ad56a32696bfcd9c78f797f16bbb6f072d6" [[package]] name = "rustc_version" -version = "0.2.3" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a" +checksum = "f0dfe2087c51c460008730de8b57e6a320782fbfb312e1f4d520e6c6fae155ee" dependencies = [ - "semver 0.9.0", + "semver 0.11.0", ] [[package]] name = "rustc_version" -version = "0.3.3" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0dfe2087c51c460008730de8b57e6a320782fbfb312e1f4d520e6c6fae155ee" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" dependencies = [ - "semver 0.11.0", + "semver 1.0.4", ] [[package]] name = "rustls" -version = "0.17.0" +version = "0.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0d4a31f5d68413404705d6982529b0e11a9aacd4839d1d6222ee3b8cb4015e1" +checksum = "35edb675feee39aec9c99fa5ff985081995a06d594114ae14cbe797ad7b7a6d7" dependencies = [ - "base64 0.11.0", + "base64 0.13.0", "log 0.4.14", "ring", "sct", @@ -3824,7 +3766,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "54a50e29610a5be68d4a586a5cce3bfb572ed2c2a74227e4168444b7bf4e5235" dependencies = [ "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", ] [[package]] @@ -3886,7 +3828,7 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "23a2ac85147a3a11d77ecf1bc7166ec0b92febfa4461c37944e180f319ece467" dependencies = [ - "bitflags 1.2.1", + "bitflags 1.3.2", "core-foundation", "core-foundation-sys", "libc", @@ -3903,35 +3845,20 @@ dependencies = [ "libc", ] -[[package]] -name = "semver" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d7eb9ef2c18661902cc47e535f9bc51b78acd254da71d375c2f6720d9a40403" -dependencies = [ - "semver-parser 0.7.0", -] - [[package]] name = "semver" version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f301af10236f6df4160f7c3f04eec6dbc70ace82d23326abad5edee88801c6b6" dependencies = [ - "semver-parser 0.10.2", + "semver-parser", ] [[package]] name = "semver" -version = "1.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f3aac57ee7f3272d8395c6e4f502f434f0e289fcd62876f70daa008c20dcabe" - -[[package]] -name = "semver-parser" -version = "0.7.0" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3" +checksum = "568a8e6258aa33c13358f81fd834adb854c6f7c9468520910a9b1e8fac068012" [[package]] name = "semver-parser" @@ -3950,9 +3877,9 @@ checksum = "9dad3f759919b92c3068c696c15c3d17238234498bbdcc80f2c469606f948ac8" [[package]] name = "serde" -version = "1.0.126" +version = "1.0.129" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec7505abeacaec74ae4778d9d9328fe5a5d04253220a85c4ee022239fc996d03" +checksum = "d1f72836d2aa753853178eda473a3b9d8e4eefdaf20523b919677e6de489f8f1" dependencies = [ "serde_derive", ] @@ -3963,7 +3890,7 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "18b20e7752957bbe9661cff4e0bb04d183d0948cdab2ea58cdb9df36a61dfe62" dependencies = [ - "serde 1.0.126", + "serde 1.0.129", "serde_derive", ] @@ -3987,7 +3914,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a663f873dedc4eac1a559d4c6bc0d0b2c34dc5ac4702e105014b8281489e44f" dependencies = [ "ordered-float 1.1.1", - "serde 1.0.126", + "serde 1.0.129", ] [[package]] @@ -3997,29 +3924,29 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3a1a3341211875ef120e117ea7fd5228530ae7e7036a779fdc9117be6b3282c" dependencies = [ "ordered-float 2.7.0", - "serde 1.0.126", + "serde 1.0.129", ] [[package]] name = "serde_derive" -version = "1.0.126" +version = "1.0.129" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "963a7dbc9895aeac7ac90e74f34a5d5261828f79df35cbed41e10189d3804d43" +checksum = "e57ae87ad533d9a56427558b516d0adac283614e347abf85b0dc0cbbf0a249f3" dependencies = [ - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", ] [[package]] name = "serde_json" -version = "1.0.64" +version = "1.0.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "799e97dc9fdae36a5c8b8f2cae9ce2ee9fdce2058c57a93e6099d919fd982f79" +checksum = "336b10da19a12ad094b59d870ebde26a45402e5b470add4b5fd03c5048a32127" dependencies = [ "itoa", "ryu", - "serde 1.0.126", + "serde 1.0.129", ] [[package]] @@ -4028,9 +3955,9 @@ version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "98d0516900518c29efa217c298fa1f4e6c6ffc85ae29fd7f4ee48f176e1a9ed5" dependencies = [ - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", ] [[package]] @@ -4051,26 +3978,26 @@ dependencies = [ "form_urlencoded", "itoa", "ryu", - "serde 1.0.126", + "serde 1.0.129", ] [[package]] name = "serde_yaml" -version = "0.8.17" +version = "0.8.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15654ed4ab61726bf918a39cb8d98a2e2995b002387807fa6ba58fdf7f59bb23" +checksum = "6375dbd828ed6964c3748e4ef6d18e7a175d408ffe184bca01698d0c73f915a9" dependencies = [ "dtoa", - "linked-hash-map 0.5.4", - "serde 1.0.126", + "indexmap", + "serde 1.0.129", "yaml-rust", ] [[package]] name = "sha-1" -version = "0.9.6" +version = "0.9.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c4cfa741c5832d0ef7fab46cabed29c2aae926db0b11bb2069edd8db5e64e16" +checksum = "1a0c8611594e2ab4ebbf06ec7cbbf0a99450b8570e96cbf5188b5d5f6ef18d81" dependencies = [ "block-buffer", "cfg-if 1.0.0", @@ -4106,9 +4033,9 @@ dependencies = [ [[package]] name = "sharded-slab" -version = "0.1.1" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79c719719ee05df97490f80a45acfc99e5a30ce98a1e4fb67aee422745ae14e3" +checksum = "740223c51853f3145fe7c90360d2d4232f2b62e3449489c207eccde818979982" dependencies = [ "lazy_static 1.4.0", ] @@ -4158,9 +4085,9 @@ dependencies = [ [[package]] name = "slab" -version = "0.4.3" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f173ac3d1a7e3b28003f40de0b5ce7fe2710f9b9dc3fc38664cebee46b3b6527" +checksum = "c307a32c1c5c437f38c7fd45d753050587732ba8628319fbdf12a7e289ccc590" [[package]] name = "smallvec" @@ -4174,7 +4101,7 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6142f7c25e94f6fd25a32c3348ec230df9109b463f59c8c7acc4bd34936babb7" dependencies = [ - "aes-gcm 0.9.2", + "aes-gcm 0.9.3", "blake2", "chacha20poly1305", "rand 0.8.4", @@ -4198,9 +4125,9 @@ dependencies = [ [[package]] name = "socket2" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e3dfc207c526015c632472a77be09cf1b6e46866581aecae5cc38fb4235dea2" +checksum = "765f090f0e423d2b55843402a07915add955e7d60657db13707a159727326cad" dependencies = [ "libc", "winapi 0.3.9", @@ -4268,9 +4195,9 @@ checksum = "7813934aecf5f51a54775e00068c237de98489463968231a51746bbbc03f9c10" dependencies = [ "heck", "proc-macro-error", - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", ] [[package]] @@ -4286,9 +4213,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e6e163a520367c465f59e0a61a23cfae3b10b6546d78b6f672a382be79f7110" dependencies = [ "heck", - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", ] [[package]] @@ -4298,9 +4225,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "87c85aa3f8ea653bfd3ddf25f7ee357ee4d204731f6aa9ad04002306f6e2774c" dependencies = [ "heck", - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", ] [[package]] @@ -4310,16 +4237,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e61bb0be289045cb80bfce000512e32d09f8337e54c186725da381377ad1f8d5" dependencies = [ "heck", - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", ] [[package]] name = "subtle" -version = "2.4.0" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e81da0851ada1f3e9d4312c704aa4f8806f0f9d69faaf8df2f3464b4a9437c2" +checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601" [[package]] name = "subtle-ng" @@ -4357,11 +4284,11 @@ dependencies = [ [[package]] name = "syn" -version = "1.0.73" +version = "1.0.75" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f71489ff30030d2ae598524f61326b902466f72a0fb1a8564c001cc63425bcc7" +checksum = "b7f58f7e8eaa0009c5fec437aabf511bd9933e4b2d7407bd05273c01a8906ea7" dependencies = [ - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", "unicode-xid 0.2.2", ] @@ -4377,13 +4304,13 @@ dependencies = [ [[package]] name = "synstructure" -version = "0.12.4" +version = "0.12.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b834f2d66f734cb897113e34aaff2f1ab4719ca946f9a7358dba8f8064148701" +checksum = "474aaa926faa1603c40b7885a9eaea29b444d1cb2850cb7c0e37bb1a4182f4fa" dependencies = [ - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", "unicode-xid 0.2.2", ] @@ -4397,7 +4324,6 @@ dependencies = [ "tari_common_types", "tari_comms", "tari_core", - "tari_crypto", "tari_wallet", "tonic", "tonic-build", @@ -4409,7 +4335,7 @@ version = "0.9.5" dependencies = [ "config", "dirs-next", - "futures 0.3.15", + "futures 0.3.16", "log 0.4.14", "qrcode", "rand 0.8.4", @@ -4424,7 +4350,7 @@ dependencies = [ "tari_p2p", "tari_wallet", "thiserror", - "tokio 0.2.25", + "tokio 1.10.1", "tonic", ] @@ -4436,7 +4362,7 @@ dependencies = [ "bincode", "chrono", "config", - "futures 0.3.15", + "futures 0.3.16", "log 0.4.14", "regex", "rustyline", @@ -4454,9 +4380,8 @@ dependencies = [ "tari_p2p", "tari_service_framework", "tari_shutdown", - "tari_wallet", "thiserror", - "tokio 0.2.25", + "tokio 1.10.1", "tonic", ] @@ -4473,7 +4398,7 @@ dependencies = [ "merlin", "rand 0.8.4", "rand_core 0.6.3", - "serde 1.0.126", + "serde 1.0.129", "serde_derive", "sha3", "subtle-ng", @@ -4491,10 +4416,10 @@ dependencies = [ "git2", "log 0.4.14", "log4rs 1.0.0", - "parity-multiaddr", + "multiaddr", "path-clean", "prost-build", - "serde 1.0.126", + "serde 1.0.129", "serde_json", "sha2", "structopt", @@ -4508,11 +4433,12 @@ dependencies = [ name = "tari_common_types" version = "0.9.5" dependencies = [ - "futures 0.3.15", + "futures 0.3.16", + "lazy_static 1.4.0", "rand 0.8.4", - "serde 1.0.126", + "serde 1.0.129", "tari_crypto", - "tokio 0.2.25", + "tokio 1.10.1", ] [[package]] @@ -4521,26 +4447,26 @@ version = "0.9.5" dependencies = [ "anyhow", "async-trait", - "bitflags 1.2.1", + "bitflags 1.3.2", "blake2", - "bytes 0.5.6", + "bytes 1.1.0", "chrono", "cidr", "clear_on_drop", "data-encoding", "digest", "env_logger 0.7.1", - "futures 0.3.15", + "futures 0.3.16", "lazy_static 1.4.0", "lmdb-zero", "log 0.4.14", + "multiaddr", "nom 5.1.2", "openssl", - "parity-multiaddr", - "pin-project 0.4.28", + "pin-project 1.0.8", "prost", "rand 0.8.4", - "serde 1.0.126", + "serde 1.0.129", "serde_derive", "serde_json", "snow", @@ -4552,10 +4478,10 @@ dependencies = [ "tari_test_utils", "tempfile", "thiserror", - "tokio 0.2.25", - "tokio-macros", - "tokio-util 0.2.0", - "tower", + "tokio 1.10.1", + "tokio-stream", + "tokio-util 0.6.7", + "tower 0.3.1", "tower-make", "yamux", ] @@ -4565,8 +4491,8 @@ name = "tari_comms_dht" version = "0.9.5" dependencies = [ "anyhow", - "bitflags 1.2.1", - "bytes 0.4.12", + "bitflags 1.3.2", + "bytes 0.5.6", "chacha20", "chrono", "clap", @@ -4574,7 +4500,7 @@ dependencies = [ "diesel_migrations", "digest", "env_logger 0.7.1", - "futures 0.3.15", + "futures 0.3.16", "futures-test-preview", "futures-util", "lazy_static 1.4.0", @@ -4586,7 +4512,7 @@ dependencies = [ "prost", "prost-types", "rand 0.8.4", - "serde 1.0.126", + "serde 1.0.129", "serde_derive", "serde_repr", "tari_common", @@ -4599,10 +4525,10 @@ dependencies = [ "tari_utilities", "tempfile", "thiserror", - "tokio 0.2.25", - "tokio-macros", - "tokio-test", - "tower", + "tokio 1.10.1", + "tokio-stream", + "tokio-test 0.4.2", + "tower 0.3.1", "tower-test", "ttl_cache", ] @@ -4611,15 +4537,14 @@ dependencies = [ name = "tari_comms_rpc_macros" version = "0.9.5" dependencies = [ - "futures 0.3.15", - "proc-macro2 1.0.27", + "futures 0.3.16", + "proc-macro2 1.0.28", "prost", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", "tari_comms", "tari_test_utils", - "tokio 0.2.25", - "tokio-macros", + "tokio 1.10.1", "tower-service", ] @@ -4627,11 +4552,11 @@ dependencies = [ name = "tari_console_wallet" version = "0.9.5" dependencies = [ - "bitflags 1.2.1", + "bitflags 1.3.2", "chrono", "chrono-english", "crossterm", - "futures 0.3.15", + "futures 0.3.16", "log 0.4.14", "qrcode", "rand 0.8.4", @@ -4652,7 +4577,7 @@ dependencies = [ "tari_shutdown", "tari_wallet", "thiserror", - "tokio 0.2.25", + "tokio 1.10.1", "tonic", "tui", "unicode-segmentation", @@ -4664,17 +4589,18 @@ name = "tari_core" version = "0.9.5" dependencies = [ "bincode", - "bitflags 1.2.1", + "bitflags 1.3.2", "blake2", - "bytes 0.4.12", + "bytes 0.5.6", "chrono", "config", "croaring", "digest", "env_logger 0.7.1", "fs2", - "futures 0.3.15", + "futures 0.3.16", "hex", + "lazy_static 1.4.0", "lmdb-zero", "log 0.4.14", "monero", @@ -4685,7 +4611,7 @@ dependencies = [ "prost-types", "rand 0.8.4", "randomx-rs", - "serde 1.0.126", + "serde 1.0.129", "serde_json", "sha3", "strum_macros 0.17.1", @@ -4703,8 +4629,7 @@ dependencies = [ "tari_test_utils", "tempfile", "thiserror", - "tokio 0.2.25", - "tokio-macros", + "tokio 1.10.1", "ttl_cache", "uint", ] @@ -4725,7 +4650,7 @@ dependencies = [ "merlin", "rand 0.8.4", "rmp-serde", - "serde 1.0.126", + "serde 1.0.129", "serde_json", "sha2", "sha3", @@ -4750,7 +4675,7 @@ version = "0.9.5" dependencies = [ "digest", "rand 0.8.4", - "serde 1.0.126", + "serde 1.0.129", "serde_derive", "serde_json", "sha2", @@ -4764,20 +4689,20 @@ version = "0.9.5" dependencies = [ "anyhow", "bincode", - "bytes 0.5.6", + "bytes 1.1.0", "chrono", "config", "derive-error", "env_logger 0.7.1", - "futures 0.3.15", + "futures 0.3.16", "futures-test", "hex", - "hyper 0.13.10", + "hyper 0.14.12", "jsonrpc", "log 0.4.14", "rand 0.8.4", - "reqwest 0.10.10", - "serde 1.0.126", + "reqwest 0.11.4", + "serde 1.0.129", "serde_json", "structopt", "tari_app_grpc", @@ -4787,8 +4712,7 @@ dependencies = [ "tari_crypto", "tari_utilities", "thiserror", - "tokio 0.2.25", - "tokio-macros", + "tokio 1.10.1", "tonic", "tracing", "tracing-futures", @@ -4803,7 +4727,7 @@ dependencies = [ "bufstream", "chrono", "crossbeam", - "futures 0.3.15", + "futures 0.3.16", "hex", "jsonrpc", "log 0.4.14", @@ -4812,7 +4736,7 @@ dependencies = [ "prost-types", "rand 0.8.4", "reqwest 0.11.4", - "serde 1.0.126", + "serde 1.0.129", "serde_json", "sha3", "tari_app_grpc", @@ -4822,7 +4746,7 @@ dependencies = [ "tari_crypto", "thiserror", "time", - "tokio 0.2.25", + "tokio 1.10.1", "tonic", ] @@ -4837,7 +4761,7 @@ dependencies = [ "digest", "log 0.4.14", "rand 0.8.4", - "serde 1.0.126", + "serde 1.0.129", "serde_json", "tari_crypto", "tari_infra_derive", @@ -4855,7 +4779,7 @@ dependencies = [ "clap", "env_logger 0.6.2", "fs2", - "futures 0.3.15", + "futures 0.3.16", "futures-timer", "lazy_static 1.4.0", "lmdb-zero", @@ -4865,8 +4789,8 @@ dependencies = [ "prost", "rand 0.8.4", "reqwest 0.10.10", - "semver 1.0.3", - "serde 1.0.126", + "semver 1.0.4", + "serde 1.0.129", "serde_derive", "stream-cancel", "tari_common", @@ -4880,9 +4804,9 @@ dependencies = [ "tari_utilities", "tempfile", "thiserror", - "tokio 0.2.25", - "tokio-macros", - "tower", + "tokio 1.10.1", + "tokio-stream", + "tower 0.3.1", "tower-service", "trust-dns-client", ] @@ -4893,15 +4817,14 @@ version = "0.9.5" dependencies = [ "anyhow", "async-trait", - "futures 0.3.15", + "futures 0.3.16", "futures-test", "log 0.4.14", "tari_shutdown", "tari_test_utils", "thiserror", - "tokio 0.2.25", - "tokio-macros", - "tower", + "tokio 1.10.1", + "tower 0.3.1", "tower-service", ] @@ -4909,8 +4832,8 @@ dependencies = [ name = "tari_shutdown" version = "0.9.5" dependencies = [ - "futures 0.3.15", - "tokio 0.2.25", + "futures 0.3.16", + "tokio 1.10.1", ] [[package]] @@ -4918,14 +4841,14 @@ name = "tari_storage" version = "0.9.5" dependencies = [ "bincode", - "bytes 0.4.12", + "bytes 0.5.6", "env_logger 0.6.2", "lmdb-zero", "log 0.4.14", "rand 0.8.4", "rmp", "rmp-serde", - "serde 1.0.126", + "serde 1.0.129", "serde_derive", "tari_utilities", "thiserror", @@ -4937,7 +4860,7 @@ version = "0.0.1" dependencies = [ "hex", "libc", - "serde 1.0.126", + "serde 1.0.129", "serde_json", "tari_app_grpc", "tari_common", @@ -4958,7 +4881,7 @@ dependencies = [ "config", "derive-error", "env_logger 0.7.1", - "futures 0.3.15", + "futures 0.3.16", "futures-test", "hex", "hyper 0.13.10", @@ -4966,7 +4889,7 @@ dependencies = [ "log 0.4.14", "rand 0.7.3", "reqwest 0.10.10", - "serde 1.0.126", + "serde 1.0.129", "serde_json", "structopt", "tari_app_grpc", @@ -4975,8 +4898,7 @@ dependencies = [ "tari_crypto", "tari_utilities", "thiserror", - "tokio 0.2.25", - "tokio-macros", + "tokio 1.10.1", "tonic", "tonic-build", "tracing", @@ -4989,13 +4911,13 @@ dependencies = [ name = "tari_test_utils" version = "0.9.5" dependencies = [ - "futures 0.3.15", + "futures 0.3.16", "futures-test", "lazy_static 1.4.0", "rand 0.8.4", "tari_shutdown", "tempfile", - "tokio 0.2.25", + "tokio 1.10.1", ] [[package]] @@ -5006,12 +4928,12 @@ checksum = "22966aea452f806a83b75d59d54d34f638e48b94a1ea2b2e0efce9aacf532635" dependencies = [ "base64 0.10.1", "bincode", - "bitflags 1.2.1", + "bitflags 1.3.2", "chrono", "clear_on_drop", "newtype-ops", "rand 0.7.3", - "serde 1.0.126", + "serde 1.0.129", "serde_json", "thiserror", ] @@ -5030,7 +4952,7 @@ dependencies = [ "digest", "env_logger 0.7.1", "fs2", - "futures 0.3.15", + "futures 0.3.16", "lazy_static 1.4.0", "libsqlite3-sys", "lmdb-zero", @@ -5038,7 +4960,7 @@ dependencies = [ "log4rs 1.0.0", "prost", "rand 0.8.4", - "serde 1.0.126", + "serde 1.0.129", "serde_json", "tari_common_types", "tari_comms", @@ -5054,9 +4976,8 @@ dependencies = [ "tempfile", "thiserror", "time", - "tokio 0.2.25", - "tokio-macros", - "tower", + "tokio 1.10.1", + "tower 0.3.1", ] [[package]] @@ -5065,7 +4986,7 @@ version = "0.17.4" dependencies = [ "chrono", "env_logger 0.7.1", - "futures 0.3.15", + "futures 0.3.16", "lazy_static 1.4.0", "libc", "log 0.4.14", @@ -5084,7 +5005,7 @@ dependencies = [ "tari_wallet", "tempfile", "thiserror", - "tokio 0.2.25", + "tokio 1.10.1", ] [[package]] @@ -5096,7 +5017,7 @@ dependencies = [ "cfg-if 1.0.0", "libc", "rand 0.8.4", - "redox_syscall 0.2.9", + "redox_syscall 0.2.10", "remove_dir_all", "winapi 0.3.9", ] @@ -5115,12 +5036,12 @@ name = "test_faucet" version = "0.9.5" dependencies = [ "rand 0.8.4", - "serde 1.0.126", + "serde 1.0.129", "serde_json", "tari_core", "tari_crypto", "tari_utilities", - "tokio 0.2.25", + "tokio 1.10.1", ] [[package]] @@ -5147,9 +5068,9 @@ version = "1.0.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "060d69a0afe7796bf42e9e2ff91f5ee691fb15c53d38b4b62a9a53eb23164745" dependencies = [ - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", ] [[package]] @@ -5198,15 +5119,15 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" dependencies = [ - "serde 1.0.126", + "serde 1.0.129", "serde_json", ] [[package]] name = "tinyvec" -version = "1.2.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b5220f05bb7de7f3f53c7c065e1199b3172696fe2db9f9c4d8ad9b4ee74c342" +checksum = "848a1e1181b9f6753b5e96a092749e29b11d19ede67dfbbd6c7dc7e0f49b5338" dependencies = [ "tinyvec_macros", ] @@ -5228,43 +5149,50 @@ dependencies = [ "futures-core", "iovec", "lazy_static 1.4.0", - "libc", "memchr", "mio 0.6.23", - "mio-uds", - "num_cpus", "pin-project-lite 0.1.12", - "signal-hook-registry", "slab", - "tokio-macros", - "winapi 0.3.9", ] [[package]] name = "tokio" -version = "1.9.0" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b7b349f11a7047e6d1276853e612d152f5e8a352c61917887cc2169e2366b4c" +checksum = "92036be488bb6594459f2e03b60e42df6f937fe6ca5c5ffdcb539c6b84dc40f5" dependencies = [ "autocfg 1.0.1", - "bytes 1.0.1", + "bytes 1.1.0", "libc", "memchr", "mio 0.7.13", "num_cpus", + "once_cell", "pin-project-lite 0.2.7", + "signal-hook-registry", + "tokio-macros", "winapi 0.3.9", ] +[[package]] +name = "tokio-io-timeout" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90c49f106be240de154571dd31fbe48acb10ba6c6dd6f6517ad603abffa42de9" +dependencies = [ + "pin-project-lite 0.2.7", + "tokio 1.10.1", +] + [[package]] name = "tokio-macros" -version = "0.2.6" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e44da00bfc73a25f814cd8d7e57a68a5c31b74b3152a0a1d1f590c97ed06265a" +checksum = "54473be61f4ebe4efd09cec9bd5d16fa51d70ea0192213d754d2d500457db110" dependencies = [ - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", ] [[package]] @@ -5274,7 +5202,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f7d995660bd2b7f8c1568414c1126076c13fbb725c40112dc0120b78eb9b717b" dependencies = [ "native-tls", - "tokio 1.9.0", + "tokio 1.10.1", +] + +[[package]] +name = "tokio-rustls" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc6844de72e57df1980054b38be3a9f4702aba4858be64dd700181a8a6d0e1b6" +dependencies = [ + "rustls", + "tokio 1.10.1", + "webpki", +] + +[[package]] +name = "tokio-stream" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b2f3f698253f03119ac0102beaa64f67a67e08074d03a22d18784104543727f" +dependencies = [ + "futures-core", + "pin-project-lite 0.2.7", + "tokio 1.10.1", + "tokio-util 0.6.7", ] [[package]] @@ -5289,26 +5240,25 @@ dependencies = [ ] [[package]] -name = "tokio-tls" -version = "0.3.1" +name = "tokio-test" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a70f4fcd7b3b24fb194f837560168208f669ca8cb70d0c4b862944452396343" +checksum = "53474327ae5e166530d17f2d956afcb4f8a004de581b3cae10f12006bc8163e3" dependencies = [ - "native-tls", - "tokio 0.2.25", + "async-stream", + "bytes 1.1.0", + "futures-core", + "tokio 1.10.1", + "tokio-stream", ] [[package]] -name = "tokio-util" -version = "0.2.0" +name = "tokio-tls" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "571da51182ec208780505a32528fc5512a8fe1443ab960b3f2f3ef093cd16930" +checksum = "9a70f4fcd7b3b24fb194f837560168208f669ca8cb70d0c4b862944452396343" dependencies = [ - "bytes 0.5.6", - "futures-core", - "futures-sink", - "log 0.4.14", - "pin-project-lite 0.1.12", + "native-tls", "tokio 0.2.25", ] @@ -5332,12 +5282,13 @@ version = "0.6.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1caa0b0c8d94a049db56b5acf8cba99dc0623aab1b26d5b5f5e2d945846b3592" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "futures-core", + "futures-io", "futures-sink", "log 0.4.14", "pin-project-lite 0.2.7", - "tokio 1.9.0", + "tokio 1.10.1", ] [[package]] @@ -5346,7 +5297,7 @@ version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "758664fc71a3a69038656bee8b6be6477d2a6c315a6b81f7081f591bffa4111f" dependencies = [ - "serde 1.0.126", + "serde 1.0.129", ] [[package]] @@ -5355,34 +5306,35 @@ version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a31142970826733df8241ef35dc040ef98c679ab14d7c3e54d827099b3acecaa" dependencies = [ - "serde 1.0.126", + "serde 1.0.129", ] [[package]] name = "tonic" -version = "0.2.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4afef9ce97ea39593992cf3fa00ff33b1ad5eb07665b31355df63a690e38c736" +checksum = "796c5e1cd49905e65dd8e700d4cb1dffcbfdb4fc9d017de08c1a537afd83627c" dependencies = [ "async-stream", "async-trait", - "base64 0.11.0", - "bytes 0.5.6", + "base64 0.13.0", + "bytes 1.1.0", "futures-core", "futures-util", + "h2 0.3.4", "http", - "http-body 0.3.1", - "hyper 0.13.10", + "http-body 0.4.3", + "hyper 0.14.12", + "hyper-timeout", "percent-encoding 2.1.0", - "pin-project 0.4.28", + "pin-project 1.0.8", "prost", "prost-derive", - "tokio 0.2.25", - "tokio-util 0.3.1", - "tower", - "tower-balance", - "tower-load", - "tower-make", + "tokio 1.10.1", + "tokio-stream", + "tokio-util 0.6.7", + "tower 0.4.8", + "tower-layer", "tower-service", "tracing", "tracing-futures", @@ -5390,14 +5342,14 @@ dependencies = [ [[package]] name = "tonic-build" -version = "0.2.0" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71d8d21cb568e802d77055ab7fcd43f0992206de5028de95c8d3a41118d32e8e" +checksum = "12b52d07035516c2b74337d2ac7746075e7dcae7643816c1b12c5ff8a7484c08" dependencies = [ - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "prost-build", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", ] [[package]] @@ -5419,23 +5371,21 @@ dependencies = [ ] [[package]] -name = "tower-balance" -version = "0.3.0" +name = "tower" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a792277613b7052448851efcf98a2c433e6f1d01460832dc60bef676bc275d4c" +checksum = "f60422bc7fefa2f3ec70359b8ff1caff59d785877eb70595904605bcc412470f" dependencies = [ "futures-core", "futures-util", "indexmap", - "pin-project 0.4.28", - "rand 0.7.3", + "pin-project 1.0.8", + "rand 0.8.4", "slab", - "tokio 0.2.25", - "tower-discover", + "tokio 1.10.1", + "tokio-stream", + "tokio-util 0.6.7", "tower-layer", - "tower-load", - "tower-make", - "tower-ready-cache", "tower-service", "tracing", ] @@ -5517,21 +5467,6 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ce50370d644a0364bf4877ffd4f76404156a248d104e2cc234cd391ea5cdc965" dependencies = [ - "tokio 0.2.25", - "tower-service", -] - -[[package]] -name = "tower-ready-cache" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4eabb6620e5481267e2ec832c780b31cad0c15dcb14ed825df5076b26b591e1f" -dependencies = [ - "futures-core", - "futures-util", - "indexmap", - "log 0.4.14", - "tokio 0.2.25", "tower-service", ] @@ -5563,7 +5498,7 @@ dependencies = [ "futures-util", "pin-project 0.4.28", "tokio 0.2.25", - "tokio-test", + "tokio-test 0.2.1", "tower-layer", "tower-service", ] @@ -5611,16 +5546,16 @@ version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c42e6fa53307c8a17e4ccd4dc81cf5ec38db9209f59b222210375b54ee40d1e2" dependencies = [ - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", ] [[package]] name = "tracing-core" -version = "0.1.18" +version = "0.1.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9ff14f98b1a4b289c6248a023c1c2fa1491062964e9fed67ab29c4e4da4a052" +checksum = "2ca517f43f0fb96e0c3072ed5c275fe5eece87e8cb52f4a77b69226d3b1c9df8" dependencies = [ "lazy_static 1.4.0", ] @@ -5631,7 +5566,7 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97d095ae15e245a057c8e8451bab9b3ee1e1f68e9ba2b4fbc18d0ac5237835f2" dependencies = [ - "pin-project 1.0.7", + "pin-project 1.0.8", "tracing", ] @@ -5652,22 +5587,22 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fb65ea441fbb84f9f6748fd496cf7f63ec9af5bca94dd86456978d055e8eb28b" dependencies = [ - "serde 1.0.126", + "serde 1.0.129", "tracing-core", ] [[package]] name = "tracing-subscriber" -version = "0.2.19" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab69019741fca4d98be3c62d2b75254528b5432233fd8a4d2739fec20278de48" +checksum = "b9cbe87a2fa7e35900ce5de20220a582a9483a7063811defce79d7cbd59d4cfe" dependencies = [ "ansi_term 0.12.1", "chrono", "lazy_static 1.4.0", "matchers", "regex", - "serde 1.0.126", + "serde 1.0.129", "serde_json", "sharded-slab", "smallvec", @@ -5686,47 +5621,54 @@ checksum = "efd1f82c56340fdf16f2a953d7bda4f8fdffba13d93b00844c25572110b26079" [[package]] name = "trust-dns-client" -version = "0.19.7" +version = "0.21.0-alpha.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e935ae5a26a2745fb5a6b95f0e206e1cfb7f00066892d2cf78a8fee87bc2e0c6" +checksum = "37532ce92c75c6174b1d51ed612e26c5fde66ef3f29aa10dbd84e7c5d9a0c27b" dependencies = [ "cfg-if 1.0.0", "chrono", "data-encoding", - "futures 0.3.15", + "futures-channel", + "futures-util", "lazy_static 1.4.0", "log 0.4.14", "radix_trie", - "rand 0.7.3", + "rand 0.8.4", "ring", "rustls", "thiserror", - "tokio 0.2.25", + "tokio 1.10.1", "trust-dns-proto", "webpki", ] [[package]] name = "trust-dns-proto" -version = "0.19.7" +version = "0.21.0-alpha.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cad71a0c0d68ab9941d2fb6e82f8fb2e86d9945b94e1661dd0aaea2b88215a9" +checksum = "4cd23117e93ea0e776abfd8a07c9e389d7ecd3377827858f21bd795ebdfefa36" dependencies = [ "async-trait", - "backtrace", "cfg-if 1.0.0", "data-encoding", "enum-as-inner", - "futures 0.3.15", + "futures-channel", + "futures-io", + "futures-util", "idna 0.2.3", + "ipnet", "lazy_static 1.4.0", "log 0.4.14", - "rand 0.7.3", + "rand 0.8.4", "ring", + "rustls", "smallvec", "thiserror", - "tokio 0.2.25", + "tinyvec", + "tokio 1.10.1", + "tokio-rustls", "url 2.2.2", + "webpki", ] [[package]] @@ -5759,7 +5701,7 @@ version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2eaeee894a1e9b90f80aa466fe59154fdb471980b5e104d8836fcea309ae17e" dependencies = [ - "bitflags 1.2.1", + "bitflags 1.3.2", "cassowary", "crossterm", "unicode-segmentation", @@ -5836,12 +5778,9 @@ dependencies = [ [[package]] name = "unicode-bidi" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eeb8be209bb1c96b7c177c7420d26e04eccacb0eeae6b980e35fcb74678107e0" -dependencies = [ - "matches", -] +checksum = "246f4c42e67e7a4e3c6106ff716a5d067d4132a642840b242e357e468a2a0085" [[package]] name = "unicode-normalization" @@ -5884,9 +5823,9 @@ checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3" [[package]] name = "universal-hash" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8326b2c654932e3e4f9196e69d08fdf7cfd718e1dc6f66b347e6024a0c961402" +checksum = "9f214e8f697e925001e66ec2c6e37a4ef93f0f78c2eed7814394e10c62025b05" dependencies = [ "generic-array", "subtle", @@ -5903,15 +5842,9 @@ dependencies = [ [[package]] name = "unsigned-varint" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7fdeedbf205afadfe39ae559b75c3240f24e257d0ca27e85f85cb82aa19ac35" - -[[package]] -name = "unsigned-varint" -version = "0.6.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35581ff83d4101e58b582e607120c7f5ffb17e632a980b1f38334d76b36908b2" +checksum = "5f8d425fafb8cd76bc3f22aace4af471d3156301d7508f2107e98fbeae10bc7f" [[package]] name = "untrusted" @@ -6007,36 +5940,36 @@ checksum = "1a143597ca7c7793eff794def352d41792a93c481eb1042423ff7ff72ba2c31f" [[package]] name = "wasm-bindgen" -version = "0.2.74" +version = "0.2.76" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d54ee1d4ed486f78874278e63e4069fc1ab9f6a18ca492076ffb90c5eb2997fd" +checksum = "8ce9b1b516211d33767048e5d47fa2a381ed8b76fc48d2ce4aa39877f9f183e0" dependencies = [ "cfg-if 1.0.0", - "serde 1.0.126", + "serde 1.0.129", "serde_json", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.74" +version = "0.2.76" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b33f6a0694ccfea53d94db8b2ed1c3a8a4c86dd936b13b9f0a15ec4a451b900" +checksum = "cfe8dc78e2326ba5f845f4b5bf548401604fa20b1dd1d365fb73b6c1d6364041" dependencies = [ "bumpalo", "lazy_static 1.4.0", "log 0.4.14", - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.24" +version = "0.4.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fba7978c679d53ce2d0ac80c8c175840feb849a161664365d1287b41f2e67f1" +checksum = "95fded345a6559c2cfee778d562300c581f7d4ff3edb9b0d230d69800d213972" dependencies = [ "cfg-if 1.0.0", "js-sys", @@ -6046,9 +5979,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.74" +version = "0.2.76" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "088169ca61430fe1e58b8096c24975251700e7b1f6fd91cc9d59b04fb9b18bd4" +checksum = "44468aa53335841d9d6b6c023eaab07c0cd4bddbcfdee3e2bb1e8d2cb8069fef" dependencies = [ "quote 1.0.9", "wasm-bindgen-macro-support", @@ -6056,28 +5989,28 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.74" +version = "0.2.76" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be2241542ff3d9f241f5e2cb6dd09b37efe786df8851c54957683a49f0987a97" +checksum = "0195807922713af1e67dc66132c7328206ed9766af3858164fb583eedc25fbad" dependencies = [ - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.74" +version = "0.2.76" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7cff876b8f18eed75a66cf49b65e7f967cb354a7aa16003fb55dbfd25b44b4f" +checksum = "acdb075a845574a1fa5f09fd77e43f7747599301ea3417a9fbffdeedfc1f4a29" [[package]] name = "web-sys" -version = "0.3.51" +version = "0.3.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e828417b379f3df7111d3a2a9e5753706cae29c41f7c4029ee9fd77f3e09e582" +checksum = "224b2f6b67919060055ef1a67807367c2066ed520c3862cc013d26cf893a783c" dependencies = [ "js-sys", "wasm-bindgen", @@ -6102,6 +6035,17 @@ dependencies = [ "libc", ] +[[package]] +name = "which" +version = "4.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea187a8ef279bc014ec368c27a920da2024d2a711109bfbe3440585d5cf27ad9" +dependencies = [ + "either", + "lazy_static 1.4.0", + "libc", +] + [[package]] name = "winapi" version = "0.2.8" @@ -6190,7 +6134,7 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e7d9028f208dd5e63c614be69f115c1b53cacc1111437d4c765185856666c107" dependencies = [ - "futures 0.3.15", + "futures 0.3.16", "log 0.4.14", "nohash-hasher", "parking_lot 0.11.1", @@ -6213,8 +6157,8 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2c1e130bebaeab2f23886bf9acbaca14b092408c452543c857f66399cd6dab1" dependencies = [ - "proc-macro2 1.0.27", + "proc-macro2 1.0.28", "quote 1.0.9", - "syn 1.0.73", + "syn 1.0.75", "synstructure", ] diff --git a/applications/tari_app_grpc/Cargo.toml b/applications/tari_app_grpc/Cargo.toml index bec7e05720..fac860a5f1 100644 --- a/applications/tari_app_grpc/Cargo.toml +++ b/applications/tari_app_grpc/Cargo.toml @@ -10,14 +10,16 @@ edition = "2018" [dependencies] tari_common_types = { version = "^0.9", path = "../../base_layer/common_types"} tari_core = { path = "../../base_layer/core"} -tari_wallet = { path = "../../base_layer/wallet"} -tari_crypto = "0.11.1" +tari_wallet = { path = "../../base_layer/wallet", optional = true} tari_comms = { path = "../../comms"} chrono = "0.4.6" -prost = "0.6" -prost-types = "0.6.1" -tonic = "0.2" +prost = "0.8" +prost-types = "0.8" +tonic = "0.5.2" [build-dependencies] -tonic-build = "0.2" +tonic-build = "0.5.2" + +[features] +wallet = ["tari_wallet"] diff --git a/applications/tari_app_grpc/src/conversions/block_header.rs b/applications/tari_app_grpc/src/conversions/block_header.rs index 3b660f21e1..c89bae5b78 100644 --- a/applications/tari_app_grpc/src/conversions/block_header.rs +++ b/applications/tari_app_grpc/src/conversions/block_header.rs @@ -25,8 +25,12 @@ use crate::{ tari_rpc as grpc, }; use std::convert::TryFrom; -use tari_core::{blocks::BlockHeader, proof_of_work::ProofOfWork, transactions::types::BlindingFactor}; -use tari_crypto::tari_utilities::{ByteArray, Hashable}; +use tari_core::{ + blocks::BlockHeader, + crypto::tari_utilities::{ByteArray, Hashable}, + proof_of_work::ProofOfWork, + transactions::types::BlindingFactor, +}; impl From<BlockHeader> for grpc::BlockHeader { fn from(h: BlockHeader) -> Self { diff --git a/applications/tari_app_grpc/src/conversions/com_signature.rs b/applications/tari_app_grpc/src/conversions/com_signature.rs index e10e48ffe8..1ffb26f08e 100644 --- a/applications/tari_app_grpc/src/conversions/com_signature.rs +++ b/applications/tari_app_grpc/src/conversions/com_signature.rs @@ -21,7 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use std::convert::TryFrom; -use tari_crypto::tari_utilities::ByteArray; +use tari_core::crypto::tari_utilities::ByteArray; use crate::tari_rpc as grpc; use tari_core::transactions::types::{ComSignature, Commitment, PrivateKey}; diff --git a/applications/tari_app_grpc/src/conversions/mod.rs b/applications/tari_app_grpc/src/conversions/mod.rs index e404d52d1a..f48a1b876d 100644 --- a/applications/tari_app_grpc/src/conversions/mod.rs +++ b/applications/tari_app_grpc/src/conversions/mod.rs @@ -59,7 +59,7 @@ pub use self::{ use crate::{tari_rpc as grpc, tari_rpc::BlockGroupRequest}; use prost_types::Timestamp; -use tari_crypto::tari_utilities::epoch_time::EpochTime; +use tari_core::crypto::tari_utilities::epoch_time::EpochTime; /// Utility function that converts a `EpochTime` to a `prost::Timestamp` pub fn datetime_to_timestamp(datetime: EpochTime) -> Timestamp { diff --git a/applications/tari_app_grpc/src/conversions/new_block_template.rs b/applications/tari_app_grpc/src/conversions/new_block_template.rs index 7d87900cd1..f4e01677f1 100644 --- a/applications/tari_app_grpc/src/conversions/new_block_template.rs +++ b/applications/tari_app_grpc/src/conversions/new_block_template.rs @@ -24,10 +24,10 @@ use crate::tari_rpc as grpc; use std::convert::{TryFrom, TryInto}; use tari_core::{ blocks::{NewBlockHeaderTemplate, NewBlockTemplate}, + crypto::tari_utilities::ByteArray, proof_of_work::ProofOfWork, transactions::types::BlindingFactor, }; -use tari_crypto::tari_utilities::ByteArray; impl From<NewBlockTemplate> for grpc::NewBlockTemplate { fn from(block: NewBlockTemplate) -> Self { let header = grpc::NewBlockHeaderTemplate { diff --git a/applications/tari_app_grpc/src/conversions/peer.rs b/applications/tari_app_grpc/src/conversions/peer.rs index f04d3fd8dd..f3bd151d0d 100644 --- a/applications/tari_app_grpc/src/conversions/peer.rs +++ b/applications/tari_app_grpc/src/conversions/peer.rs @@ -22,7 +22,7 @@ use crate::{conversions::datetime_to_timestamp, tari_rpc as grpc}; use tari_comms::{connectivity::ConnectivityStatus, net_address::MutliaddrWithStats, peer_manager::Peer}; -use tari_crypto::tari_utilities::ByteArray; +use tari_core::crypto::tari_utilities::ByteArray; impl From<Peer> for grpc::Peer { fn from(peer: Peer) -> Self { diff --git a/applications/tari_app_grpc/src/conversions/signature.rs b/applications/tari_app_grpc/src/conversions/signature.rs index d9883a338e..8468491a25 100644 --- a/applications/tari_app_grpc/src/conversions/signature.rs +++ b/applications/tari_app_grpc/src/conversions/signature.rs @@ -21,7 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use std::convert::TryFrom; -use tari_crypto::tari_utilities::ByteArray; +use tari_core::crypto::tari_utilities::ByteArray; use crate::tari_rpc as grpc; use tari_core::transactions::types::{PrivateKey, PublicKey, Signature}; diff --git a/applications/tari_app_grpc/src/conversions/transaction.rs b/applications/tari_app_grpc/src/conversions/transaction.rs index 97dcaf8795..cb9e91f63a 100644 --- a/applications/tari_app_grpc/src/conversions/transaction.rs +++ b/applications/tari_app_grpc/src/conversions/transaction.rs @@ -22,9 +22,10 @@ use crate::tari_rpc as grpc; use std::convert::{TryFrom, TryInto}; -use tari_core::transactions::transaction::Transaction; -use tari_crypto::{ristretto::RistrettoSecretKey, tari_utilities::ByteArray}; -use tari_wallet::{output_manager_service::TxId, transaction_service::storage::models}; +use tari_core::{ + crypto::{ristretto::RistrettoSecretKey, tari_utilities::ByteArray}, + transactions::transaction::Transaction, +}; impl From<Transaction> for grpc::Transaction { fn from(source: Transaction) -> Self { @@ -53,38 +54,44 @@ impl TryFrom<grpc::Transaction> for Transaction { } } -impl From<models::TransactionStatus> for grpc::TransactionStatus { - fn from(status: models::TransactionStatus) -> Self { - use models::TransactionStatus::*; - match status { - Completed => grpc::TransactionStatus::Completed, - Broadcast => grpc::TransactionStatus::Broadcast, - MinedUnconfirmed => grpc::TransactionStatus::MinedUnconfirmed, - MinedConfirmed => grpc::TransactionStatus::MinedConfirmed, - Imported => grpc::TransactionStatus::Imported, - Pending => grpc::TransactionStatus::Pending, - Coinbase => grpc::TransactionStatus::Coinbase, +#[cfg(feature = "wallet")] +mod wallet { + use super::*; + use tari_wallet::{output_manager_service::TxId, transaction_service::storage::models}; + + impl From<models::TransactionStatus> for grpc::TransactionStatus { + fn from(status: models::TransactionStatus) -> Self { + use models::TransactionStatus::*; + match status { + Completed => grpc::TransactionStatus::Completed, + Broadcast => grpc::TransactionStatus::Broadcast, + MinedUnconfirmed => grpc::TransactionStatus::MinedUnconfirmed, + MinedConfirmed => grpc::TransactionStatus::MinedConfirmed, + Imported => grpc::TransactionStatus::Imported, + Pending => grpc::TransactionStatus::Pending, + Coinbase => grpc::TransactionStatus::Coinbase, + } } } -} -impl From<models::TransactionDirection> for grpc::TransactionDirection { - fn from(status: models::TransactionDirection) -> Self { - use models::TransactionDirection::*; - match status { - Unknown => grpc::TransactionDirection::Unknown, - Inbound => grpc::TransactionDirection::Inbound, - Outbound => grpc::TransactionDirection::Outbound, + impl From<models::TransactionDirection> for grpc::TransactionDirection { + fn from(status: models::TransactionDirection) -> Self { + use models::TransactionDirection::*; + match status { + Unknown => grpc::TransactionDirection::Unknown, + Inbound => grpc::TransactionDirection::Inbound, + Outbound => grpc::TransactionDirection::Outbound, + } } } -} -impl grpc::TransactionInfo { - pub fn not_found(tx_id: TxId) -> Self { - Self { - tx_id, - status: grpc::TransactionStatus::NotFound as i32, - ..Default::default() + impl grpc::TransactionInfo { + pub fn not_found(tx_id: TxId) -> Self { + Self { + tx_id, + status: grpc::TransactionStatus::NotFound as i32, + ..Default::default() + } } } } diff --git a/applications/tari_app_grpc/src/conversions/transaction_input.rs b/applications/tari_app_grpc/src/conversions/transaction_input.rs index ed5793a2f2..692af44856 100644 --- a/applications/tari_app_grpc/src/conversions/transaction_input.rs +++ b/applications/tari_app_grpc/src/conversions/transaction_input.rs @@ -22,13 +22,15 @@ use crate::tari_rpc as grpc; use std::convert::{TryFrom, TryInto}; -use tari_core::transactions::{ - transaction::TransactionInput, - types::{Commitment, PublicKey}, -}; -use tari_crypto::{ - script::{ExecutionStack, TariScript}, - tari_utilities::{ByteArray, Hashable}, +use tari_core::{ + crypto::{ + script::{ExecutionStack, TariScript}, + tari_utilities::{ByteArray, Hashable}, + }, + transactions::{ + transaction::TransactionInput, + types::{Commitment, PublicKey}, + }, }; impl TryFrom<grpc::TransactionInput> for TransactionInput { diff --git a/applications/tari_app_grpc/src/conversions/transaction_kernel.rs b/applications/tari_app_grpc/src/conversions/transaction_kernel.rs index e394a6bce5..ac7b4c540d 100644 --- a/applications/tari_app_grpc/src/conversions/transaction_kernel.rs +++ b/applications/tari_app_grpc/src/conversions/transaction_kernel.rs @@ -22,12 +22,14 @@ use crate::tari_rpc as grpc; use std::convert::{TryFrom, TryInto}; -use tari_core::transactions::{ - tari_amount::MicroTari, - transaction::{KernelFeatures, TransactionKernel}, - types::Commitment, +use tari_core::{ + crypto::tari_utilities::{ByteArray, Hashable}, + transactions::{ + tari_amount::MicroTari, + transaction::{KernelFeatures, TransactionKernel}, + types::Commitment, + }, }; -use tari_crypto::tari_utilities::{ByteArray, Hashable}; impl TryFrom<grpc::TransactionKernel> for TransactionKernel { type Error = String; diff --git a/applications/tari_app_grpc/src/conversions/transaction_output.rs b/applications/tari_app_grpc/src/conversions/transaction_output.rs index b9556b2940..05b13f41e3 100644 --- a/applications/tari_app_grpc/src/conversions/transaction_output.rs +++ b/applications/tari_app_grpc/src/conversions/transaction_output.rs @@ -22,14 +22,16 @@ use crate::tari_rpc as grpc; use std::convert::{TryFrom, TryInto}; -use tari_core::transactions::{ - bullet_rangeproofs::BulletRangeProof, - transaction::TransactionOutput, - types::{Commitment, PublicKey}, -}; -use tari_crypto::{ - script::TariScript, - tari_utilities::{ByteArray, Hashable}, +use tari_core::{ + crypto::{ + script::TariScript, + tari_utilities::{ByteArray, Hashable}, + }, + transactions::{ + bullet_rangeproofs::BulletRangeProof, + transaction::TransactionOutput, + types::{Commitment, PublicKey}, + }, }; impl TryFrom<grpc::TransactionOutput> for TransactionOutput { diff --git a/applications/tari_app_grpc/src/conversions/unblinded_output.rs b/applications/tari_app_grpc/src/conversions/unblinded_output.rs index 94ac4c178d..fbf1276fd3 100644 --- a/applications/tari_app_grpc/src/conversions/unblinded_output.rs +++ b/applications/tari_app_grpc/src/conversions/unblinded_output.rs @@ -22,14 +22,16 @@ use crate::tari_rpc as grpc; use std::convert::{TryFrom, TryInto}; -use tari_core::transactions::{ - tari_amount::MicroTari, - transaction::UnblindedOutput, - types::{PrivateKey, PublicKey}, -}; -use tari_crypto::{ - script::{ExecutionStack, TariScript}, - tari_utilities::ByteArray, +use tari_core::{ + crypto::{ + script::{ExecutionStack, TariScript}, + tari_utilities::ByteArray, + }, + transactions::{ + tari_amount::MicroTari, + transaction::UnblindedOutput, + types::{PrivateKey, PublicKey}, + }, }; impl From<UnblindedOutput> for grpc::UnblindedOutput { diff --git a/applications/tari_app_utilities/Cargo.toml b/applications/tari_app_utilities/Cargo.toml index 333af5f959..7a6bd8a7de 100644 --- a/applications/tari_app_utilities/Cargo.toml +++ b/applications/tari_app_utilities/Cargo.toml @@ -9,21 +9,21 @@ tari_comms = { path = "../../comms"} tari_crypto = "0.11.1" tari_common = { path = "../../common" } tari_p2p = { path = "../../base_layer/p2p", features = ["auto-update"] } -tari_wallet = { path = "../../base_layer/wallet" } +tari_wallet = { path = "../../base_layer/wallet", optional = true } config = { version = "0.9.3" } -futures = { version = "^0.3.1", default-features = false, features = ["alloc"]} +futures = { version = "^0.3.16", default-features = false, features = ["alloc"]} qrcode = { version = "0.12" } dirs-next = "1.0.2" serde_json = "1.0" log = { version = "0.4.8", features = ["std"] } rand = "0.8" -tokio = { version="0.2.10", features = ["signal"] } +tokio = { version="^1.10", features = ["signal"] } structopt = { version = "0.3.13", default_features = false } strum = "^0.19" strum_macros = "^0.19" -thiserror = "^1.0.20" -tonic = "0.2" +thiserror = "^1.0.26" +tonic = "0.5.2" [dependencies.tari_core] path = "../../base_layer/core" @@ -33,3 +33,6 @@ features = ["transactions"] [build-dependencies] tari_common = { path = "../../common", features = ["build", "static-application-info"] } + +[features] +wallet = ["tari_wallet"] diff --git a/applications/tari_app_utilities/src/utilities.rs b/applications/tari_app_utilities/src/utilities.rs index d963f212a7..52655eae8d 100644 --- a/applications/tari_app_utilities/src/utilities.rs +++ b/applications/tari_app_utilities/src/utilities.rs @@ -35,13 +35,8 @@ use tari_comms::{ types::CommsPublicKey, utils::multiaddr::multiaddr_to_socketaddr, }; -use tari_core::tari_utilities::hex::Hex; +use tari_core::{tari_utilities::hex::Hex, transactions::emoji::EmojiId}; use tari_p2p::transport::{TorConfig, TransportType}; -use tari_wallet::{ - error::{WalletError, WalletStorageError}, - output_manager_service::error::OutputManagerError, - util::emoji::EmojiId, -}; use thiserror::Error; use tokio::{runtime, runtime::Runtime}; @@ -107,20 +102,6 @@ impl From<tari_common::ConfigError> for ExitCodes { } } -impl From<WalletError> for ExitCodes { - fn from(err: WalletError) -> Self { - error!(target: LOG_TARGET, "{}", err); - Self::WalletError(err.to_string()) - } -} - -impl From<OutputManagerError> for ExitCodes { - fn from(err: OutputManagerError) -> Self { - error!(target: LOG_TARGET, "{}", err); - Self::WalletError(err.to_string()) - } -} - impl From<ConnectivityError> for ExitCodes { fn from(err: ConnectivityError) -> Self { error!(target: LOG_TARGET, "{}", err); @@ -135,13 +116,36 @@ impl From<RpcError> for ExitCodes { } } -impl From<WalletStorageError> for ExitCodes { - fn from(err: WalletStorageError) -> Self { - use WalletStorageError::*; - match err { - NoPasswordError => ExitCodes::NoPassword, - IncorrectPassword => ExitCodes::IncorrectPassword, - e => ExitCodes::WalletError(e.to_string()), +#[cfg(feature = "wallet")] +mod wallet { + use super::*; + use tari_wallet::{ + error::{WalletError, WalletStorageError}, + output_manager_service::error::OutputManagerError, + }; + + impl From<WalletError> for ExitCodes { + fn from(err: WalletError) -> Self { + error!(target: LOG_TARGET, "{}", err); + Self::WalletError(err.to_string()) + } + } + + impl From<OutputManagerError> for ExitCodes { + fn from(err: OutputManagerError) -> Self { + error!(target: LOG_TARGET, "{}", err); + Self::WalletError(err.to_string()) + } + } + + impl From<WalletStorageError> for ExitCodes { + fn from(err: WalletStorageError) -> Self { + use WalletStorageError::*; + match err { + NoPasswordError => ExitCodes::NoPassword, + IncorrectPassword => ExitCodes::IncorrectPassword, + e => ExitCodes::WalletError(e.to_string()), + } } } } @@ -259,26 +263,22 @@ pub fn convert_socks_authentication(auth: SocksAuthentication) -> socks::Authent /// ## Returns /// A result containing the runtime on success, string indicating the error on failure pub fn setup_runtime(config: &GlobalConfig) -> Result<Runtime, String> { - info!( - target: LOG_TARGET, - "Configuring the node to run on up to {} core threads and {} mining threads.", - config.max_threads.unwrap_or(512), - config.num_mining_threads - ); + let mut builder = runtime::Builder::new_multi_thread(); - let mut builder = runtime::Builder::new(); - - if let Some(max_threads) = config.max_threads { - // Ensure that there are always enough threads for mining. - // e.g if the user sets max_threads = 2, mining_threads = 5 then 7 threads are available in total - builder.max_threads(max_threads + config.num_mining_threads); - } if let Some(core_threads) = config.core_threads { - builder.core_threads(core_threads); + info!( + target: LOG_TARGET, + "Configuring the node to run on up to {} core threads.", + config + .core_threads + .as_ref() + .map(ToString::to_string) + .unwrap_or_else(|| "<num cores>".to_string()), + ); + builder.worker_threads(core_threads); } builder - .threaded_scheduler() .enable_all() .build() .map_err(|e| format!("There was an error while building the node runtime. {}", e.to_string())) diff --git a/applications/tari_base_node/Cargo.toml b/applications/tari_base_node/Cargo.toml index 7ac4914c50..af94c84643 100644 --- a/applications/tari_base_node/Cargo.toml +++ b/applications/tari_base_node/Cargo.toml @@ -11,33 +11,32 @@ edition = "2018" tari_app_grpc = { path = "../tari_app_grpc" } tari_app_utilities = { path = "../tari_app_utilities" } tari_common = { path = "../../common" } -tari_comms = { path = "../../comms", features = ["rpc"]} -tari_comms_dht = { path = "../../comms/dht"} -tari_core = { path = "../../base_layer/core", default-features = false, features = ["transactions"]} +tari_comms = { path = "../../comms", features = ["rpc"] } +tari_comms_dht = { path = "../../comms/dht" } +tari_core = { path = "../../base_layer/core", default-features = false, features = ["transactions"] } tari_crypto = "0.11.1" tari_mmr = { path = "../../base_layer/mmr" } tari_p2p = { path = "../../base_layer/p2p", features = ["auto-update"] } -tari_service_framework = { path = "../../base_layer/service_framework"} -tari_shutdown = { path = "../../infrastructure/shutdown"} -tari_wallet = { path = "../../base_layer/wallet" } +tari_service_framework = { path = "../../base_layer/service_framework" } +tari_shutdown = { path = "../../infrastructure/shutdown" } anyhow = "1.0.32" bincode = "1.3.1" chrono = "0.4" config = { version = "0.9.3" } -futures = { version = "^0.3.1", default-features = false, features = ["alloc"]} +futures = { version = "^0.3.16", default-features = false, features = ["alloc"] } log = { version = "0.4.8", features = ["std"] } regex = "1" rustyline = "6.0" rustyline-derive = "0.3" -tokio = { version="0.2.10", features = ["signal"] } +tokio = { version = "^1.10", features = ["signal"] } strum = "^0.19" strum_macros = "0.18.0" -thiserror = "^1.0.20" -tonic = "0.2" +thiserror = "^1.0.26" +tonic = "0.5.2" [features] -avx2 = ["tari_core/avx2", "tari_crypto/avx2", "tari_p2p/avx2", "tari_wallet/avx2", "tari_comms/avx2", "tari_comms_dht/avx2"] +avx2 = ["tari_core/avx2", "tari_crypto/avx2", "tari_p2p/avx2", "tari_comms/avx2", "tari_comms_dht/avx2"] safe = [] diff --git a/applications/tari_base_node/src/bootstrap.rs b/applications/tari_base_node/src/bootstrap.rs index 9953483b4b..951c8ec5db 100644 --- a/applications/tari_base_node/src/bootstrap.rs +++ b/applications/tari_base_node/src/bootstrap.rs @@ -59,7 +59,6 @@ use tari_p2p::{ }; use tari_service_framework::{ServiceHandles, StackBuilder}; use tari_shutdown::ShutdownSignal; -use tokio::runtime; const LOG_TARGET: &str = "c::bn::initialization"; /// The minimum buffer size for the base node pubsub_connector channel @@ -84,8 +83,7 @@ where B: BlockchainBackend + 'static fs::create_dir_all(&config.peer_db_path)?; let buf_size = cmp::max(BASE_NODE_BUFFER_MIN_SIZE, config.buffer_size_base_node); - let (publisher, peer_message_subscriptions) = - pubsub_connector(runtime::Handle::current(), buf_size, config.buffer_rate_limit_base_node); + let (publisher, peer_message_subscriptions) = pubsub_connector(buf_size, config.buffer_rate_limit_base_node); let peer_message_subscriptions = Arc::new(peer_message_subscriptions); let node_config = BaseNodeServiceConfig::default(); // TODO - make this configurable diff --git a/applications/tari_base_node/src/builder.rs b/applications/tari_base_node/src/builder.rs index 87efdaf9bf..f3136f57e5 100644 --- a/applications/tari_base_node/src/builder.rs +++ b/applications/tari_base_node/src/builder.rs @@ -70,10 +70,8 @@ impl BaseNodeContext { pub async fn run(self) { info!(target: LOG_TARGET, "Tari base node has STARTED"); - if let Err(e) = self.state_machine().shutdown_signal().await { - warn!(target: LOG_TARGET, "Error shutting down Base Node State Machine: {}", e); - } - info!(target: LOG_TARGET, "Initiating communications stack shutdown"); + self.state_machine().shutdown_signal().wait().await; + info!(target: LOG_TARGET, "Waiting for communications stack shutdown"); self.base_node_comms.wait_until_shutdown().await; info!(target: LOG_TARGET, "Communications stack has shutdown"); diff --git a/applications/tari_base_node/src/command_handler.rs b/applications/tari_base_node/src/command_handler.rs index 114cee7ad2..255adb9fec 100644 --- a/applications/tari_base_node/src/command_handler.rs +++ b/applications/tari_base_node/src/command_handler.rs @@ -53,11 +53,13 @@ use tari_core::{ mempool::service::LocalMempoolService, proof_of_work::PowAlgorithm, tari_utilities::{hex::Hex, message_format::MessageFormat}, - transactions::types::{Commitment, HashOutput, Signature}, + transactions::{ + emoji::EmojiId, + types::{Commitment, HashOutput, Signature}, + }, }; use tari_crypto::{ristretto::RistrettoPublicKey, tari_utilities::Hashable}; use tari_p2p::auto_update::SoftwareUpdaterHandle; -use tari_wallet::util::emoji::EmojiId; use tokio::{runtime, sync::watch}; pub enum StatusOutput { @@ -101,7 +103,7 @@ impl CommandHandler { } pub fn status(&self, output: StatusOutput) { - let mut state_info = self.state_machine_info.clone(); + let state_info = self.state_machine_info.clone(); let mut node = self.node_service.clone(); let mut mempool = self.mempool_service.clone(); let peer_manager = self.peer_manager.clone(); @@ -115,8 +117,7 @@ impl CommandHandler { let version = format!("v{}", consts::APP_VERSION_NUMBER); status_line.add_field("", version); - let state = state_info.recv().await.unwrap(); - status_line.add_field("State", state.state_info.short_desc()); + status_line.add_field("State", state_info.borrow().state_info.short_desc()); let metadata = node.get_metadata().await.unwrap(); @@ -189,18 +190,8 @@ impl CommandHandler { /// Function to process the get-state-info command pub fn state_info(&self) { - let mut channel = self.state_machine_info.clone(); - self.executor.spawn(async move { - match channel.recv().await { - None => { - info!( - target: LOG_TARGET, - "Error communicating with state machine, channel could have been closed" - ); - }, - Some(data) => println!("Current state machine state:\n{}", data), - }; - }); + let watch = self.state_machine_info.clone(); + println!("Current state machine state:\n{}", *watch.borrow()); } /// Check for updates diff --git a/applications/tari_base_node/src/grpc/base_node_grpc_server.rs b/applications/tari_base_node/src/grpc/base_node_grpc_server.rs index 00474a5cb4..d4c8181fe9 100644 --- a/applications/tari_base_node/src/grpc/base_node_grpc_server.rs +++ b/applications/tari_base_node/src/grpc/base_node_grpc_server.rs @@ -26,6 +26,7 @@ use crate::{ helpers::{mean, median}, }, }; +use futures::{channel::mpsc, SinkExt}; use log::*; use std::{ cmp, @@ -40,7 +41,6 @@ use tari_comms::{Bytes, CommsNode}; use tari_core::{ base_node::{ comms_interface::{Broadcast, CommsInterfaceError}, - state_machine_service::states::BlockSyncInfo, LocalNodeCommsInterface, StateMachineHandle, }, @@ -54,7 +54,7 @@ use tari_core::{ }; use tari_crypto::tari_utilities::{message_format::MessageFormat, Hashable}; use tari_p2p::{auto_update::SoftwareUpdaterHandle, services::liveness::LivenessHandle}; -use tokio::{sync::mpsc, task}; +use tokio::task; use tonic::{Request, Response, Status}; const LOG_TARGET: &str = "tari::base_node::grpc"; @@ -995,32 +995,25 @@ impl tari_rpc::base_node_server::BaseNode for BaseNodeGrpcServer { ) -> Result<Response<tari_rpc::SyncInfoResponse>, Status> { debug!(target: LOG_TARGET, "Incoming GRPC request for BN sync data"); - let mut channel = self.state_machine_handle.get_status_info_watch(); - - let mut sync_info: Option<BlockSyncInfo> = None; - - if let Some(info) = channel.recv().await { - sync_info = info.state_info.get_block_sync_info(); - } - - let mut response = tari_rpc::SyncInfoResponse { - tip_height: 0, - local_height: 0, - peer_node_id: vec![], - }; - - if let Some(info) = sync_info { - let node_ids = info - .sync_peers - .iter() - .map(|x| x.to_string().as_bytes().to_vec()) - .collect(); - response = tari_rpc::SyncInfoResponse { - tip_height: info.tip_height, - local_height: info.local_height, - peer_node_id: node_ids, - }; - } + let response = self + .state_machine_handle + .get_status_info_watch() + .borrow() + .state_info + .get_block_sync_info() + .map(|info| { + let node_ids = info + .sync_peers + .iter() + .map(|x| x.to_string().as_bytes().to_vec()) + .collect(); + tari_rpc::SyncInfoResponse { + tip_height: info.tip_height.unwrap_or_default(), + local_height: info.local_height.unwrap_or_default(), + peer_node_id: node_ids, + } + }) + .unwrap_or_default(); debug!(target: LOG_TARGET, "Sending SyncData response to client"); Ok(Response::new(response)) diff --git a/applications/tari_base_node/src/main.rs b/applications/tari_base_node/src/main.rs index e40f3f9cc4..5a2fc1bbc6 100644 --- a/applications/tari_base_node/src/main.rs +++ b/applications/tari_base_node/src/main.rs @@ -27,6 +27,7 @@ #![deny(unused_must_use)] #![deny(unreachable_patterns)] #![deny(unknown_lints)] +#![allow(dead_code)] /// ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣠⣶⣿⣿⣿⣿⣶⣦⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ /// ⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣤⣾⣿⡿⠋⠀⠀⠀⠀⠉⠛⠿⣿⣿⣶⣤⣀⠀⠀⠀⠀⠀⠀⢰⣿⣾⣾⣾⣾⣾⣾⣾⣾⣾⣿⠀⠀⠀⣾⣾⣾⡀⠀⠀⠀⠀⢰⣾⣾⣾⣾⣿⣶⣶⡀⠀⠀⠀⢸⣾⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀ @@ -96,7 +97,7 @@ mod status_line; mod utils; use crate::command_handler::{CommandHandler, StatusOutput}; -use futures::{future::Fuse, pin_mut, FutureExt}; +use futures::{pin_mut, FutureExt}; use log::*; use parser::Parser; use rustyline::{config::OutputStreamType, error::ReadlineError, CompletionType, Config, EditMode, Editor}; @@ -117,7 +118,7 @@ use tari_shutdown::{Shutdown, ShutdownSignal}; use tokio::{ runtime, task, - time::{self, Delay}, + time::{self}, }; use tonic::transport::Server; @@ -142,7 +143,7 @@ fn main_inner() -> Result<(), ExitCodes> { debug!(target: LOG_TARGET, "Using configuration: {:?}", node_config); // Set up the Tokio runtime - let mut rt = setup_runtime(&node_config).map_err(|e| { + let rt = setup_runtime(&node_config).map_err(|e| { error!(target: LOG_TARGET, "{}", e); ExitCodes::UnknownError })?; @@ -296,26 +297,28 @@ async fn read_command(mut rustyline: Editor<Parser>) -> Result<(String, Editor<P .expect("Could not spawn rustyline task") } -fn status_interval(start_time: Instant) -> Fuse<Delay> { +fn status_interval(start_time: Instant) -> time::Sleep { let duration = match start_time.elapsed().as_secs() { 0..=120 => Duration::from_secs(5), _ => Duration::from_secs(30), }; - time::delay_for(duration).fuse() + time::sleep(duration) } async fn status_loop(command_handler: Arc<CommandHandler>, shutdown: Shutdown) { let start_time = Instant::now(); let mut shutdown_signal = shutdown.to_signal(); loop { - let mut interval = status_interval(start_time); - futures::select! { + let interval = status_interval(start_time); + tokio::select! { + biased; + _ = shutdown_signal.wait() => { + break; + } + _ = interval => { command_handler.status(StatusOutput::Log); }, - _ = shutdown_signal => { - break; - } } } } @@ -344,9 +347,9 @@ async fn cli_loop(parser: Parser, mut shutdown: Shutdown) { let start_time = Instant::now(); let mut software_update_notif = command_handler.get_software_updater().new_update_notifier().clone(); loop { - let mut interval = status_interval(start_time); - futures::select! { - res = read_command_fut => { + let interval = status_interval(start_time); + tokio::select! { + res = &mut read_command_fut => { match res { Ok((line, mut rustyline)) => { if let Some(p) = rustyline.helper_mut().as_deref_mut() { @@ -363,8 +366,8 @@ async fn cli_loop(parser: Parser, mut shutdown: Shutdown) { } } }, - resp = software_update_notif.recv().fuse() => { - if let Some(Some(update)) = resp { + Ok(_) = software_update_notif.changed() => { + if let Some(ref update) = *software_update_notif.borrow() { println!( "Version {} of the {} is available: {} (sha: {})", update.version(), @@ -377,7 +380,7 @@ async fn cli_loop(parser: Parser, mut shutdown: Shutdown) { _ = interval => { command_handler.status(StatusOutput::Full); }, - _ = shutdown_signal => { + _ = shutdown_signal.wait() => { break; } } diff --git a/applications/tari_console_wallet/Cargo.toml b/applications/tari_console_wallet/Cargo.toml index 5f43a75189..687ee1de94 100644 --- a/applications/tari_console_wallet/Cargo.toml +++ b/applications/tari_console_wallet/Cargo.toml @@ -8,18 +8,18 @@ edition = "2018" tari_wallet = { path = "../../base_layer/wallet" } tari_crypto = "0.11.1" tari_common = { path = "../../common" } -tari_app_utilities = { path = "../tari_app_utilities"} +tari_app_utilities = { path = "../tari_app_utilities", features = ["wallet"]} tari_comms = { path = "../../comms"} tari_comms_dht = { path = "../../comms/dht"} tari_p2p = { path = "../../base_layer/p2p" } -tari_app_grpc = { path = "../tari_app_grpc" } +tari_app_grpc = { path = "../tari_app_grpc", features = ["wallet"] } tari_shutdown = { path = "../../infrastructure/shutdown" } tari_key_manager = { path = "../../base_layer/key_manager" } bitflags = "1.2.1" chrono = { version = "0.4.6", features = ["serde"]} chrono-english = "0.1" -futures = { version = "^0.3.1", default-features = false, features = ["alloc"]} +futures = { version = "^0.3.16", default-features = false, features = ["alloc"]} crossterm = { version = "0.17"} rand = "0.8" unicode-width = "0.1" @@ -31,9 +31,9 @@ rpassword = "5.0" rustyline = "6.0" strum = "^0.19" strum_macros = "^0.19" -tokio = { version="0.2.10", features = ["signal"] } -thiserror = "1.0.20" -tonic = "0.2" +tokio = { version="^1.10", features = ["signal"] } +thiserror = "1.0.26" +tonic = "0.5.2" [dependencies.tari_core] path = "../../base_layer/core" diff --git a/applications/tari_console_wallet/src/automation/commands.rs b/applications/tari_console_wallet/src/automation/commands.rs index 608cc8a675..0818d9782c 100644 --- a/applications/tari_console_wallet/src/automation/commands.rs +++ b/applications/tari_console_wallet/src/automation/commands.rs @@ -26,7 +26,7 @@ use crate::{ utils::db::{CUSTOM_BASE_NODE_ADDRESS_KEY, CUSTOM_BASE_NODE_PUBLIC_KEY_KEY}, }; use chrono::{DateTime, Utc}; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt; use log::*; use std::{ fs::File, @@ -45,6 +45,7 @@ use tari_comms_dht::{envelope::NodeDestination, DhtDiscoveryRequester}; use tari_core::{ tari_utilities::hex::Hex, transactions::{ + emoji::EmojiId, tari_amount::{uT, MicroTari, Tari}, transaction::UnblindedOutput, types::PublicKey, @@ -54,12 +55,11 @@ use tari_crypto::ristretto::pedersen::PedersenCommitmentFactory; use tari_wallet::{ output_manager_service::{handle::OutputManagerHandle, TxId}, transaction_service::handle::{TransactionEvent, TransactionServiceHandle}, - util::emoji::EmojiId, WalletSqlite, }; use tokio::{ - sync::mpsc, - time::{delay_for, timeout}, + sync::{broadcast, mpsc}, + time::{sleep, timeout}, }; pub const LOG_TARGET: &str = "wallet::automation::commands"; @@ -175,21 +175,24 @@ pub async fn coin_split( Ok(tx_id) } -async fn wait_for_comms(connectivity_requester: &ConnectivityRequester) -> Result<bool, CommandError> { - let mut connectivity = connectivity_requester.get_event_subscription().fuse(); +async fn wait_for_comms(connectivity_requester: &ConnectivityRequester) -> Result<(), CommandError> { + let mut connectivity = connectivity_requester.get_event_subscription(); print!("Waiting for connectivity... "); - let mut timeout = delay_for(Duration::from_secs(30)).fuse(); + let timeout = sleep(Duration::from_secs(30)); + tokio::pin!(timeout); + let mut timeout = timeout.fuse(); loop { - futures::select! { - result = connectivity.select_next_some() => { - if let Ok(msg) = result { - if let ConnectivityEvent::PeerConnected(_) = (*msg).clone() { - println!("✅"); - return Ok(true); - } + tokio::select! { + biased; + + // Wait for the first base node connection + Ok(ConnectivityEvent::PeerConnected(conn)) = connectivity.recv() => { + if conn.peer_features().is_node() { + println!("✅"); + return Ok(()); } }, - () = timeout => { + () = &mut timeout => { println!("❌"); return Err(CommandError::Comms("Timed out".to_string())); } @@ -311,7 +314,7 @@ pub async fn make_it_rain( target: LOG_TARGET, "make-it-rain delaying for {:?} ms - scheduled to start at {}", delay_ms, start_time ); - delay_for(Duration::from_millis(delay_ms)).await; + sleep(Duration::from_millis(delay_ms)).await; let num_txs = (txps * duration as f64) as usize; let started_at = Utc::now(); @@ -352,10 +355,10 @@ pub async fn make_it_rain( let target_ms = (i as f64 / (txps / 1000.0)) as i64; if target_ms - actual_ms > 0 { // Maximum delay between Txs set to 120 s - delay_for(Duration::from_millis((target_ms - actual_ms).min(120_000i64) as u64)).await; + sleep(Duration::from_millis((target_ms - actual_ms).min(120_000i64) as u64)).await; } let delayed_for = Instant::now(); - let mut sender_clone = sender.clone(); + let sender_clone = sender.clone(); tokio::task::spawn(async move { let spawn_start = Instant::now(); // Send transaction @@ -432,7 +435,7 @@ pub async fn monitor_transactions( tx_ids: Vec<TxId>, wait_stage: TransactionStage, ) -> Vec<SentTransaction> { - let mut event_stream = transaction_service.get_event_stream_fused(); + let mut event_stream = transaction_service.get_event_stream(); let mut results = Vec::new(); debug!(target: LOG_TARGET, "monitor transactions wait_stage: {:?}", wait_stage); println!( @@ -442,104 +445,102 @@ pub async fn monitor_transactions( ); loop { - match event_stream.next().await { - Some(event_result) => match event_result { - Ok(event) => match &*event { - TransactionEvent::TransactionDirectSendResult(id, success) if tx_ids.contains(id) => { - debug!( - target: LOG_TARGET, - "tx direct send event for tx_id: {}, success: {}", *id, success - ); - if wait_stage == TransactionStage::DirectSendOrSaf { - results.push(SentTransaction { - id: *id, - stage: TransactionStage::DirectSendOrSaf, - }); - if results.len() == tx_ids.len() { - break; - } - } - }, - TransactionEvent::TransactionStoreForwardSendResult(id, success) if tx_ids.contains(id) => { - debug!( - target: LOG_TARGET, - "tx store and forward event for tx_id: {}, success: {}", *id, success - ); - if wait_stage == TransactionStage::DirectSendOrSaf { - results.push(SentTransaction { - id: *id, - stage: TransactionStage::DirectSendOrSaf, - }); - if results.len() == tx_ids.len() { - break; - } + match event_stream.recv().await { + Ok(event) => match &*event { + TransactionEvent::TransactionDirectSendResult(id, success) if tx_ids.contains(id) => { + debug!( + target: LOG_TARGET, + "tx direct send event for tx_id: {}, success: {}", *id, success + ); + if wait_stage == TransactionStage::DirectSendOrSaf { + results.push(SentTransaction { + id: *id, + stage: TransactionStage::DirectSendOrSaf, + }); + if results.len() == tx_ids.len() { + break; } - }, - TransactionEvent::ReceivedTransactionReply(id) if tx_ids.contains(id) => { - debug!(target: LOG_TARGET, "tx reply event for tx_id: {}", *id); - if wait_stage == TransactionStage::Negotiated { - results.push(SentTransaction { - id: *id, - stage: TransactionStage::Negotiated, - }); - if results.len() == tx_ids.len() { - break; - } + } + }, + TransactionEvent::TransactionStoreForwardSendResult(id, success) if tx_ids.contains(id) => { + debug!( + target: LOG_TARGET, + "tx store and forward event for tx_id: {}, success: {}", *id, success + ); + if wait_stage == TransactionStage::DirectSendOrSaf { + results.push(SentTransaction { + id: *id, + stage: TransactionStage::DirectSendOrSaf, + }); + if results.len() == tx_ids.len() { + break; } - }, - TransactionEvent::TransactionBroadcast(id) if tx_ids.contains(id) => { - debug!(target: LOG_TARGET, "tx mempool broadcast event for tx_id: {}", *id); - if wait_stage == TransactionStage::Broadcast { - results.push(SentTransaction { - id: *id, - stage: TransactionStage::Broadcast, - }); - if results.len() == tx_ids.len() { - break; - } + } + }, + TransactionEvent::ReceivedTransactionReply(id) if tx_ids.contains(id) => { + debug!(target: LOG_TARGET, "tx reply event for tx_id: {}", *id); + if wait_stage == TransactionStage::Negotiated { + results.push(SentTransaction { + id: *id, + stage: TransactionStage::Negotiated, + }); + if results.len() == tx_ids.len() { + break; } - }, - TransactionEvent::TransactionMinedUnconfirmed(id, confirmations) if tx_ids.contains(id) => { - debug!( - target: LOG_TARGET, - "tx mined unconfirmed event for tx_id: {}, confirmations: {}", *id, confirmations - ); - if wait_stage == TransactionStage::MinedUnconfirmed { - results.push(SentTransaction { - id: *id, - stage: TransactionStage::MinedUnconfirmed, - }); - if results.len() == tx_ids.len() { - break; - } + } + }, + TransactionEvent::TransactionBroadcast(id) if tx_ids.contains(id) => { + debug!(target: LOG_TARGET, "tx mempool broadcast event for tx_id: {}", *id); + if wait_stage == TransactionStage::Broadcast { + results.push(SentTransaction { + id: *id, + stage: TransactionStage::Broadcast, + }); + if results.len() == tx_ids.len() { + break; } - }, - TransactionEvent::TransactionMined(id) if tx_ids.contains(id) => { - debug!(target: LOG_TARGET, "tx mined confirmed event for tx_id: {}", *id); - if wait_stage == TransactionStage::Mined { - results.push(SentTransaction { - id: *id, - stage: TransactionStage::Mined, - }); - if results.len() == tx_ids.len() { - break; - } + } + }, + TransactionEvent::TransactionMinedUnconfirmed(id, confirmations) if tx_ids.contains(id) => { + debug!( + target: LOG_TARGET, + "tx mined unconfirmed event for tx_id: {}, confirmations: {}", *id, confirmations + ); + if wait_stage == TransactionStage::MinedUnconfirmed { + results.push(SentTransaction { + id: *id, + stage: TransactionStage::MinedUnconfirmed, + }); + if results.len() == tx_ids.len() { + break; } - }, - _ => {}, + } }, - Err(e) => { - eprintln!("RecvError in monitor_transactions: {:?}", e); - break; + TransactionEvent::TransactionMined(id) if tx_ids.contains(id) => { + debug!(target: LOG_TARGET, "tx mined confirmed event for tx_id: {}", *id); + if wait_stage == TransactionStage::Mined { + results.push(SentTransaction { + id: *id, + stage: TransactionStage::Mined, + }); + if results.len() == tx_ids.len() { + break; + } + } }, + _ => {}, }, - None => { - warn!( + // All event senders have gone (i.e. we take it that the node is shutting down) + Err(broadcast::error::RecvError::Closed) => { + debug!( target: LOG_TARGET, - "`None` result in event in monitor_transactions loop" + "All Transaction event senders have gone. Exiting `monitor_transactions` loop." ); break; }, + Err(err) => { + warn!(target: LOG_TARGET, "monitor_transactions: {}", err); + }, } } @@ -578,7 +579,8 @@ pub async fn command_runner( }, DiscoverPeer => { if !online { - online = wait_for_comms(&connectivity_requester).await?; + wait_for_comms(&connectivity_requester).await?; + online = true; } discover_peer(dht_service.clone(), parsed.args).await? }, diff --git a/applications/tari_console_wallet/src/grpc/wallet_grpc_server.rs b/applications/tari_console_wallet/src/grpc/wallet_grpc_server.rs index b7e53ae6a2..1e98102851 100644 --- a/applications/tari_console_wallet/src/grpc/wallet_grpc_server.rs +++ b/applications/tari_console_wallet/src/grpc/wallet_grpc_server.rs @@ -1,4 +1,4 @@ -use futures::future; +use futures::{channel::mpsc, future, SinkExt}; use log::*; use std::convert::TryFrom; use tari_app_grpc::{ @@ -41,7 +41,7 @@ use tari_wallet::{ transaction_service::{handle::TransactionServiceHandle, storage::models}, WalletSqlite, }; -use tokio::{sync::mpsc, task}; +use tokio::task; use tonic::{Request, Response, Status}; const LOG_TARGET: &str = "wallet::ui::grpc"; diff --git a/applications/tari_console_wallet/src/init/mod.rs b/applications/tari_console_wallet/src/init/mod.rs index eeae75cd92..6c1ee279b6 100644 --- a/applications/tari_console_wallet/src/init/mod.rs +++ b/applications/tari_console_wallet/src/init/mod.rs @@ -128,9 +128,15 @@ pub async fn change_password( return Err(ExitCodes::InputError("Passwords don't match!".to_string())); } - wallet.remove_encryption().await?; + wallet + .remove_encryption() + .await + .map_err(|e| ExitCodes::WalletError(e.to_string()))?; - wallet.apply_encryption(passphrase).await?; + wallet + .apply_encryption(passphrase) + .await + .map_err(|e| ExitCodes::WalletError(e.to_string()))?; println!("Wallet password changed successfully."); diff --git a/applications/tari_console_wallet/src/main.rs b/applications/tari_console_wallet/src/main.rs index 25f749034e..4952aa27db 100644 --- a/applications/tari_console_wallet/src/main.rs +++ b/applications/tari_console_wallet/src/main.rs @@ -56,8 +56,7 @@ fn main() { } fn main_inner() -> Result<(), ExitCodes> { - let mut runtime = tokio::runtime::Builder::new() - .threaded_scheduler() + let runtime = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() .expect("Failed to build a runtime!"); @@ -153,11 +152,8 @@ fn main_inner() -> Result<(), ExitCodes> { }; print!("\nShutting down wallet... "); - if shutdown.trigger().is_ok() { - runtime.block_on(wallet.wait_until_shutdown()); - } else { - error!(target: LOG_TARGET, "No listeners for the shutdown signal!"); - } + shutdown.trigger(); + runtime.block_on(wallet.wait_until_shutdown()); println!("Done."); result diff --git a/applications/tari_console_wallet/src/recovery.rs b/applications/tari_console_wallet/src/recovery.rs index 887995f0a5..e9eca033bc 100644 --- a/applications/tari_console_wallet/src/recovery.rs +++ b/applications/tari_console_wallet/src/recovery.rs @@ -21,7 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use chrono::offset::Local; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt; use log::*; use rustyline::Editor; use tari_app_utilities::utilities::ExitCodes; @@ -35,6 +35,7 @@ use tari_wallet::{ }; use crate::wallet_modes::PeerConfig; +use tokio::sync::broadcast; pub const LOG_TARGET: &str = "wallet::recovery"; @@ -97,13 +98,13 @@ pub async fn wallet_recovery(wallet: &WalletSqlite, base_node_config: &PeerConfi .with_retry_limit(10) .build_with_wallet(wallet, shutdown_signal); - let mut event_stream = recovery_task.get_event_receiver().fuse(); + let mut event_stream = recovery_task.get_event_receiver(); let recovery_join_handle = tokio::spawn(recovery_task.run()).fuse(); // Read recovery task events. The event stream will end once recovery has completed. - while let Some(event) = event_stream.next().await { - match event { + loop { + match event_stream.recv().await { Ok(UtxoScannerEvent::ConnectingToBaseNode(peer)) => { print!("Connecting to base node {}... ", peer); }, @@ -170,11 +171,13 @@ pub async fn wallet_recovery(wallet: &WalletSqlite, base_node_config: &PeerConfi info!(target: LOG_TARGET, "{}", stats); println!("{}", stats); }, - Err(e) => { - // Can occur if we read events too slowly (lagging/slow subscriber) + Err(e @ broadcast::error::RecvError::Lagged(_)) => { debug!(target: LOG_TARGET, "Error receiving Wallet recovery events: {}", e); continue; }, + Err(broadcast::error::RecvError::Closed) => { + break; + }, Ok(UtxoScannerEvent::ScanningFailed) => { error!(target: LOG_TARGET, "Wallet Recovery process failed and is exiting"); }, diff --git a/applications/tari_console_wallet/src/ui/components/base_node.rs b/applications/tari_console_wallet/src/ui/components/base_node.rs index d9a271e291..c51421d89f 100644 --- a/applications/tari_console_wallet/src/ui/components/base_node.rs +++ b/applications/tari_console_wallet/src/ui/components/base_node.rs @@ -42,9 +42,9 @@ impl BaseNode { impl<B: Backend> Component<B> for BaseNode { fn draw(&mut self, f: &mut Frame<B>, area: Rect, app_state: &AppState) where B: Backend { - let base_node_state = app_state.get_base_node_state(); + let current_online_status = app_state.get_wallet_connectivity().get_connectivity_status(); - let chain_info = match base_node_state.online { + let chain_info = match current_online_status { OnlineStatus::Connecting => Spans::from(vec![ Span::styled("Chain Tip:", Style::default().fg(Color::Magenta)), Span::raw(" "), @@ -56,6 +56,7 @@ impl<B: Backend> Component<B> for BaseNode { Span::styled("Offline", Style::default().fg(Color::Red)), ]), OnlineStatus::Online => { + let base_node_state = app_state.get_base_node_state(); if let Some(metadata) = base_node_state.clone().chain_metadata { let tip = metadata.height_of_longest_chain(); @@ -92,7 +93,7 @@ impl<B: Backend> Component<B> for BaseNode { Spans::from(vec![ Span::styled("Chain Tip:", Style::default().fg(Color::Magenta)), Span::raw(" "), - Span::styled("Error", Style::default().fg(Color::Red)), + Span::styled("Waiting for data...", Style::default().fg(Color::White)), ]) } }, diff --git a/applications/tari_console_wallet/src/ui/state/app_state.rs b/applications/tari_console_wallet/src/ui/state/app_state.rs index 6d9293ef47..be85727e07 100644 --- a/applications/tari_console_wallet/src/ui/state/app_state.rs +++ b/applications/tari_console_wallet/src/ui/state/app_state.rs @@ -34,7 +34,6 @@ use crate::{ wallet_modes::PeerConfig, }; use bitflags::bitflags; -use futures::{stream::Fuse, StreamExt}; use log::*; use qrcode::{render::unicode, QrCode}; use std::{ @@ -51,6 +50,7 @@ use tari_comms::{ NodeIdentity, }; use tari_core::transactions::{ + emoji::EmojiId, tari_amount::{uT, MicroTari}, types::PublicKey, }; @@ -66,7 +66,6 @@ use tari_wallet::{ storage::models::{CompletedTransaction, TransactionStatus}, }, types::ValidationRetryStrategy, - util::emoji::EmojiId, WalletSqlite, }; use tokio::{ @@ -84,6 +83,7 @@ pub struct AppState { completed_tx_filter: TransactionFilter, node_config: GlobalConfig, config: AppStateConfig, + wallet_connectivity: WalletConnectivityHandle, } impl AppState { @@ -95,6 +95,7 @@ impl AppState { base_node_config: PeerConfig, node_config: GlobalConfig, ) -> Self { + let wallet_connectivity = wallet.wallet_connectivity.clone(); let inner = AppStateInner::new(node_identity, network, wallet, base_node_selected, base_node_config); let cached_data = inner.data.clone(); @@ -105,6 +106,7 @@ impl AppState { completed_tx_filter: TransactionFilter::ABANDONED_COINBASES, node_config, config: AppStateConfig::default(), + wallet_connectivity, } } @@ -352,6 +354,10 @@ impl AppState { &self.cached_data.base_node_state } + pub fn get_wallet_connectivity(&self) -> WalletConnectivityHandle { + self.wallet_connectivity.clone() + } + pub fn get_selected_base_node(&self) -> &Peer { &self.cached_data.base_node_selected } @@ -641,24 +647,24 @@ impl AppStateInner { self.wallet.comms.shutdown_signal() } - pub fn get_transaction_service_event_stream(&self) -> Fuse<TransactionEventReceiver> { - self.wallet.transaction_service.get_event_stream_fused() + pub fn get_transaction_service_event_stream(&self) -> TransactionEventReceiver { + self.wallet.transaction_service.get_event_stream() } - pub fn get_output_manager_service_event_stream(&self) -> Fuse<OutputManagerEventReceiver> { - self.wallet.output_manager_service.get_event_stream_fused() + pub fn get_output_manager_service_event_stream(&self) -> OutputManagerEventReceiver { + self.wallet.output_manager_service.get_event_stream() } - pub fn get_connectivity_event_stream(&self) -> Fuse<ConnectivityEventRx> { - self.wallet.comms.connectivity().get_event_subscription().fuse() + pub fn get_connectivity_event_stream(&self) -> ConnectivityEventRx { + self.wallet.comms.connectivity().get_event_subscription() } pub fn get_wallet_connectivity(&self) -> WalletConnectivityHandle { self.wallet.wallet_connectivity.clone() } - pub fn get_base_node_event_stream(&self) -> Fuse<BaseNodeEventReceiver> { - self.wallet.base_node_service.clone().get_event_stream_fused() + pub fn get_base_node_event_stream(&self) -> BaseNodeEventReceiver { + self.wallet.base_node_service.get_event_stream() } pub async fn set_base_node_peer(&mut self, peer: Peer) -> Result<(), UiError> { diff --git a/applications/tari_console_wallet/src/ui/state/tasks.rs b/applications/tari_console_wallet/src/ui/state/tasks.rs index caf8073f56..85243660e6 100644 --- a/applications/tari_console_wallet/src/ui/state/tasks.rs +++ b/applications/tari_console_wallet/src/ui/state/tasks.rs @@ -21,11 +21,10 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::ui::{state::UiTransactionSendStatus, UiError}; -use futures::StreamExt; use tari_comms::types::CommsPublicKey; use tari_core::transactions::tari_amount::MicroTari; use tari_wallet::transaction_service::handle::{TransactionEvent, TransactionServiceHandle}; -use tokio::sync::watch; +use tokio::sync::{broadcast, watch}; const LOG_TARGET: &str = "wallet::console_wallet::tasks "; @@ -37,8 +36,8 @@ pub async fn send_transaction_task( mut transaction_service_handle: TransactionServiceHandle, result_tx: watch::Sender<UiTransactionSendStatus>, ) { - let _ = result_tx.broadcast(UiTransactionSendStatus::Initiated); - let mut event_stream = transaction_service_handle.get_event_stream_fused(); + let _ = result_tx.send(UiTransactionSendStatus::Initiated); + let mut event_stream = transaction_service_handle.get_event_stream(); let mut send_direct_received_result = (false, false); let mut send_saf_received_result = (false, false); match transaction_service_handle @@ -46,15 +45,15 @@ pub async fn send_transaction_task( .await { Err(e) => { - let _ = result_tx.broadcast(UiTransactionSendStatus::Error(UiError::from(e).to_string())); + let _ = result_tx.send(UiTransactionSendStatus::Error(UiError::from(e).to_string())); }, Ok(our_tx_id) => { - while let Some(event_result) = event_stream.next().await { - match event_result { + loop { + match event_stream.recv().await { Ok(event) => match &*event { TransactionEvent::TransactionDiscoveryInProgress(tx_id) => { if our_tx_id == *tx_id { - let _ = result_tx.broadcast(UiTransactionSendStatus::DiscoveryInProgress); + let _ = result_tx.send(UiTransactionSendStatus::DiscoveryInProgress); } }, TransactionEvent::TransactionDirectSendResult(tx_id, result) => { @@ -75,25 +74,28 @@ pub async fn send_transaction_task( }, TransactionEvent::TransactionCompletedImmediately(tx_id) => { if our_tx_id == *tx_id { - let _ = result_tx.broadcast(UiTransactionSendStatus::TransactionComplete); + let _ = result_tx.send(UiTransactionSendStatus::TransactionComplete); return; } }, _ => (), }, - Err(e) => { + Err(e @ broadcast::error::RecvError::Lagged(_)) => { log::warn!(target: LOG_TARGET, "Error reading from event broadcast channel {:?}", e); + continue; + }, + Err(broadcast::error::RecvError::Closed) => { break; }, } } if send_direct_received_result.1 { - let _ = result_tx.broadcast(UiTransactionSendStatus::SentDirect); + let _ = result_tx.send(UiTransactionSendStatus::SentDirect); } else if send_saf_received_result.1 { - let _ = result_tx.broadcast(UiTransactionSendStatus::SentViaSaf); + let _ = result_tx.send(UiTransactionSendStatus::SentViaSaf); } else { - let _ = result_tx.broadcast(UiTransactionSendStatus::Error( + let _ = result_tx.send(UiTransactionSendStatus::Error( "Transaction could not be sent".to_string(), )); } @@ -109,34 +111,37 @@ pub async fn send_one_sided_transaction_task( mut transaction_service_handle: TransactionServiceHandle, result_tx: watch::Sender<UiTransactionSendStatus>, ) { - let _ = result_tx.broadcast(UiTransactionSendStatus::Initiated); - let mut event_stream = transaction_service_handle.get_event_stream_fused(); + let _ = result_tx.send(UiTransactionSendStatus::Initiated); + let mut event_stream = transaction_service_handle.get_event_stream(); match transaction_service_handle .send_one_sided_transaction(public_key, amount, fee_per_gram, message) .await { Err(e) => { - let _ = result_tx.broadcast(UiTransactionSendStatus::Error(UiError::from(e).to_string())); + let _ = result_tx.send(UiTransactionSendStatus::Error(UiError::from(e).to_string())); }, Ok(our_tx_id) => { - while let Some(event_result) = event_stream.next().await { - match event_result { + loop { + match event_stream.recv().await { Ok(event) => { if let TransactionEvent::TransactionCompletedImmediately(tx_id) = &*event { if our_tx_id == *tx_id { - let _ = result_tx.broadcast(UiTransactionSendStatus::TransactionComplete); + let _ = result_tx.send(UiTransactionSendStatus::TransactionComplete); return; } } }, - Err(e) => { + Err(e @ broadcast::error::RecvError::Lagged(_)) => { log::warn!(target: LOG_TARGET, "Error reading from event broadcast channel {:?}", e); + continue; + }, + Err(broadcast::error::RecvError::Closed) => { break; }, } } - let _ = result_tx.broadcast(UiTransactionSendStatus::Error( + let _ = result_tx.send(UiTransactionSendStatus::Error( "One-sided transaction could not be sent".to_string(), )); }, diff --git a/applications/tari_console_wallet/src/ui/state/wallet_event_monitor.rs b/applications/tari_console_wallet/src/ui/state/wallet_event_monitor.rs index 2e20999667..e7df30b653 100644 --- a/applications/tari_console_wallet/src/ui/state/wallet_event_monitor.rs +++ b/applications/tari_console_wallet/src/ui/state/wallet_event_monitor.rs @@ -21,7 +21,6 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{notifier::Notifier, ui::state::AppStateInner}; -use futures::stream::StreamExt; use log::*; use std::sync::Arc; use tari_comms::{connectivity::ConnectivityEvent, peer_manager::Peer}; @@ -30,7 +29,7 @@ use tari_wallet::{ output_manager_service::{handle::OutputManagerEvent, TxId}, transaction_service::handle::TransactionEvent, }; -use tokio::sync::RwLock; +use tokio::sync::{broadcast, RwLock}; const LOG_TARGET: &str = "wallet::console_wallet::wallet_event_monitor"; @@ -55,14 +54,14 @@ impl WalletEventMonitor { let mut connectivity_events = self.app_state_inner.read().await.get_connectivity_event_stream(); let wallet_connectivity = self.app_state_inner.read().await.get_wallet_connectivity(); - let mut connectivity_status = wallet_connectivity.get_connectivity_status_watch().fuse(); + let mut connectivity_status = wallet_connectivity.get_connectivity_status_watch(); let mut base_node_events = self.app_state_inner.read().await.get_base_node_event_stream(); info!(target: LOG_TARGET, "Wallet Event Monitor starting"); loop { - futures::select! { - result = transaction_service_events.select_next_some() => { + tokio::select! { + result = transaction_service_events.recv() => { match result { Ok(msg) => { trace!(target: LOG_TARGET, "Wallet Event Monitor received wallet transaction service event {:?}", msg); @@ -104,18 +103,21 @@ impl WalletEventMonitor { _ => (), } }, - Err(_) => debug!(target: LOG_TARGET, "Lagging read on Transaction Service event broadcast channel"), + Err(broadcast::error::RecvError::Lagged(n)) => { + warn!(target: LOG_TARGET, "Missed {} from Transaction events", n); + } + Err(broadcast::error::RecvError::Closed) => {} } }, - status = connectivity_status.select_next_some() => { - trace!(target: LOG_TARGET, "Wallet Event Monitor received wallet connectivity status {:?}", status); + Ok(_) = connectivity_status.changed() => { + trace!(target: LOG_TARGET, "Wallet Event Monitor received wallet connectivity status changed"); self.trigger_peer_state_refresh().await; }, - result = connectivity_events.select_next_some() => { + result = connectivity_events.recv() => { match result { Ok(msg) => { trace!(target: LOG_TARGET, "Wallet Event Monitor received wallet connectivity event {:?}", msg); - match &*msg { + match msg { ConnectivityEvent::PeerDisconnected(_) | ConnectivityEvent::ManagedPeerDisconnected(_) | ConnectivityEvent::PeerConnected(_) => { @@ -125,10 +127,13 @@ impl WalletEventMonitor { _ => (), } }, - Err(_) => debug!(target: LOG_TARGET, "Lagging read on Connectivity event broadcast channel"), + Err(broadcast::error::RecvError::Lagged(n)) => { + warn!(target: LOG_TARGET, "Missed {} from Connectivity events", n); + } + Err(broadcast::error::RecvError::Closed) => {} } }, - result = base_node_events.select_next_some() => { + result = base_node_events.recv() => { match result { Ok(msg) => { trace!(target: LOG_TARGET, "Wallet Event Monitor received base node event {:?}", msg); @@ -141,10 +146,13 @@ impl WalletEventMonitor { } } }, - Err(_) => debug!(target: LOG_TARGET, "Lagging read on base node event broadcast channel"), + Err(broadcast::error::RecvError::Lagged(n)) => { + warn!(target: LOG_TARGET, "Missed {} from Base node Service events", n); + } + Err(broadcast::error::RecvError::Closed) => {} } }, - result = output_manager_service_events.select_next_some() => { + result = output_manager_service_events.recv() => { match result { Ok(msg) => { trace!(target: LOG_TARGET, "Output Manager Service Callback Handler event {:?}", msg); @@ -152,14 +160,13 @@ impl WalletEventMonitor { self.trigger_balance_refresh().await; } }, - Err(_e) => error!(target: LOG_TARGET, "Error reading from Output Manager Service event broadcast channel"), + Err(broadcast::error::RecvError::Lagged(n)) => { + warn!(target: LOG_TARGET, "Missed {} from Output Manager Service events", n); + } + Err(broadcast::error::RecvError::Closed) => {} } - }, - complete => { - info!(target: LOG_TARGET, "Wallet Event Monitor is exiting because all tasks have completed"); - break; }, - _ = shutdown_signal => { + _ = shutdown_signal.wait() => { info!(target: LOG_TARGET, "Wallet Event Monitor shutting down because the shutdown signal was received"); break; }, diff --git a/applications/tari_console_wallet/src/ui/ui_contact.rs b/applications/tari_console_wallet/src/ui/ui_contact.rs index 2d8be4a182..d87fc0de57 100644 --- a/applications/tari_console_wallet/src/ui/ui_contact.rs +++ b/applications/tari_console_wallet/src/ui/ui_contact.rs @@ -1,4 +1,5 @@ -use tari_wallet::{contacts_service::storage::database::Contact, util::emoji::EmojiId}; +use tari_core::transactions::emoji::EmojiId; +use tari_wallet::contacts_service::storage::database::Contact; #[derive(Debug, Clone)] pub struct UiContact { diff --git a/applications/tari_console_wallet/src/wallet_modes.rs b/applications/tari_console_wallet/src/wallet_modes.rs index a205523a00..23b1ee6330 100644 --- a/applications/tari_console_wallet/src/wallet_modes.rs +++ b/applications/tari_console_wallet/src/wallet_modes.rs @@ -239,7 +239,10 @@ pub fn tui_mode(config: WalletModeConfig, mut wallet: WalletSqlite) -> Result<() info!(target: LOG_TARGET, "Starting app"); - handle.enter(|| ui::run(app))?; + { + let _enter = handle.enter(); + ui::run(app)?; + } info!( target: LOG_TARGET, diff --git a/applications/tari_merge_mining_proxy/Cargo.toml b/applications/tari_merge_mining_proxy/Cargo.toml index 626d0502b6..2cd31b8b3d 100644 --- a/applications/tari_merge_mining_proxy/Cargo.toml +++ b/applications/tari_merge_mining_proxy/Cargo.toml @@ -13,33 +13,32 @@ envlog = ["env_logger"] [dependencies] tari_app_grpc = { path = "../tari_app_grpc" } -tari_common = { path = "../../common" } -tari_core = { path = "../../base_layer/core", default-features = false, features = ["transactions"]} -tari_app_utilities = { path = "../tari_app_utilities"} +tari_common = { path = "../../common" } +tari_core = { path = "../../base_layer/core", default-features = false, features = ["transactions"] } +tari_app_utilities = { path = "../tari_app_utilities" } tari_crypto = "0.11.1" tari_utilities = "^0.3" anyhow = "1.0.40" bincode = "1.3.1" -bytes = "0.5.6" +bytes = "1.1" chrono = "0.4.19" config = { version = "0.9.3" } derive-error = "0.0.4" env_logger = { version = "0.7.1", optional = true } futures = "0.3.5" hex = "0.4.2" -hyper = "0.13.7" +hyper = "0.14.12" jsonrpc = "0.11.0" log = { version = "0.4.8", features = ["std"] } rand = "0.8" -reqwest = {version = "0.10.8", features=["json"]} -serde = { version="1.0.106", features = ["derive"] } +reqwest = { version = "0.11.4", features = ["json"] } +serde = { version = "1.0.106", features = ["derive"] } serde_json = "1.0.57" structopt = { version = "0.3.13", default_features = false } -thiserror = "1.0.15" -tokio = "0.2.10" -tokio-macros = "0.2.5" -tonic = "0.2" +thiserror = "1.0.26" +tokio = { version = "1.10", features = ["macros"] } +tonic = "0.5.2" tracing = "0.1" tracing-futures = "0.2" tracing-subscriber = "0.2" diff --git a/applications/tari_merge_mining_proxy/src/main.rs b/applications/tari_merge_mining_proxy/src/main.rs index 7e15777977..0df27490bb 100644 --- a/applications/tari_merge_mining_proxy/src/main.rs +++ b/applications/tari_merge_mining_proxy/src/main.rs @@ -46,7 +46,7 @@ use tari_app_utilities::initialization::init_configuration; use tari_common::configuration::bootstrap::ApplicationType; use tokio::time::Duration; -#[tokio_macros::main] +#[tokio::main] async fn main() -> Result<(), anyhow::Error> { let (_, config, _) = init_configuration(ApplicationType::MergeMiningProxy)?; diff --git a/applications/tari_mining_node/Cargo.toml b/applications/tari_mining_node/Cargo.toml index a04938c048..87f72c0e92 100644 --- a/applications/tari_mining_node/Cargo.toml +++ b/applications/tari_mining_node/Cargo.toml @@ -17,15 +17,15 @@ crossbeam = "0.8" futures = "0.3" log = { version = "0.4", features = ["std"] } num_cpus = "1.13" -prost-types = "0.6" +prost-types = "0.8" rand = "0.8" sha3 = "0.9" serde = { version = "1.0", default_features = false, features = ["derive"] } -tonic = { version = "0.2", features = ["transport"] } -tokio = { version = "0.2", default_features = false, features = ["rt-core"] } +tonic = { version = "0.5.2", features = ["transport"] } +tokio = { version = "1.10", default_features = false, features = ["rt"] } thiserror = "1.0" jsonrpc = "0.11.0" -reqwest = { version = "0.11", features = ["blocking", "json"] } +reqwest = { version = "0.11", features = [ "json"] } serde_json = "1.0.57" native-tls = "0.2" bufstream = "0.1" @@ -35,5 +35,5 @@ hex = "0.4.2" [dev-dependencies] tari_crypto = "0.11.1" -prost-types = "0.6.1" +prost-types = "0.8" chrono = "0.4" diff --git a/applications/tari_mining_node/src/main.rs b/applications/tari_mining_node/src/main.rs index b1f538c1e4..ddf717070f 100644 --- a/applications/tari_mining_node/src/main.rs +++ b/applications/tari_mining_node/src/main.rs @@ -27,7 +27,7 @@ use tari_app_grpc::tari_rpc::{base_node_client::BaseNodeClient, wallet_client::W use tari_app_utilities::{initialization::init_configuration, utilities::ExitCodes}; use tari_common::{configuration::bootstrap::ApplicationType, ConfigBootstrap, DefaultConfigLoader, GlobalConfig}; use tari_core::blocks::BlockHeader; -use tokio::{runtime::Runtime, time::delay_for}; +use tokio::{runtime::Runtime, time::sleep}; use tonic::transport::Channel; use utils::{coinbase_request, extract_outputs_and_kernels}; @@ -144,7 +144,7 @@ async fn main_inner() -> Result<(), ExitCodes> { error!("Connection error: {:?}", err); loop { debug!("Holding for {:?}", config.wait_timeout()); - delay_for(config.wait_timeout()).await; + sleep(config.wait_timeout()).await; match connect(&config, &global).await { Ok((nc, wc)) => { node_conn = nc; @@ -168,7 +168,7 @@ async fn main_inner() -> Result<(), ExitCodes> { Err(err) => { error!("Error: {:?}", err); debug!("Holding for {:?}", config.wait_timeout()); - delay_for(config.wait_timeout()).await; + sleep(config.wait_timeout()).await; }, Ok(submitted) => { if submitted { diff --git a/applications/tari_stratum_transcoder/Cargo.toml b/applications/tari_stratum_transcoder/Cargo.toml index 29f95c82da..4cb899dddd 100644 --- a/applications/tari_stratum_transcoder/Cargo.toml +++ b/applications/tari_stratum_transcoder/Cargo.toml @@ -13,12 +13,12 @@ envlog = ["env_logger"] [dependencies] tari_app_grpc = { path = "../tari_app_grpc" } -tari_common = { path = "../../common" } -tari_core = { path = "../../base_layer/core", default-features = false, features = ["transactions"]} +tari_common = { path = "../../common" } +tari_core = { path = "../../base_layer/core", default-features = false, features = ["transactions"] } tari_crypto = "0.11.1" tari_utilities = "^0.3" bincode = "1.3.1" -bytes = "0.5.6" +bytes = "0.5" chrono = "0.4.19" config = { version = "0.9.3" } derive-error = "0.0.4" @@ -29,21 +29,20 @@ hyper = "0.13.7" jsonrpc = "0.11.0" log = { version = "0.4.8", features = ["std"] } rand = "0.7.2" -reqwest = {version = "0.10.8", features=["json"]} -serde = { version="1.0.106", features = ["derive"] } +reqwest = { version = "0.10.8", features = ["json"] } +serde = { version = "1.0.106", features = ["derive"] } serde_json = "1.0.57" structopt = { version = "0.3.13", default_features = false } -thiserror = "1.0.15" -tokio = "0.2.10" -tokio-macros = "0.2.5" -tonic = "0.2" +thiserror = "1.0.26" +tokio = { version = "^1.10", features = ["macros"] } +tonic = "0.5.2" tracing = "0.1" tracing-futures = "0.2" tracing-subscriber = "0.2" url = "2.1.1" [build-dependencies] -tonic-build = "0.2" +tonic-build = "0.5.2" [dev-dependencies] futures-test = "0.3.5" diff --git a/applications/test_faucet/Cargo.toml b/applications/test_faucet/Cargo.toml index 3ef3c8a4c1..a6cfc25bec 100644 --- a/applications/test_faucet/Cargo.toml +++ b/applications/test_faucet/Cargo.toml @@ -20,6 +20,6 @@ default-features = false features = ["transactions", "avx2"] [dependencies.tokio] -version = "^0.2.10" +version = "^1.10" default-features = false -features = ["fs", "blocking", "stream", "rt-threaded", "macros", "io-util", "sync"] +features = ["fs", "rt-multi-thread", "macros", "io-util", "sync"] diff --git a/base_layer/common_types/Cargo.toml b/base_layer/common_types/Cargo.toml index 2b85d4e9a9..5042a1ddfc 100644 --- a/base_layer/common_types/Cargo.toml +++ b/base_layer/common_types/Cargo.toml @@ -11,4 +11,5 @@ futures = {version = "^0.3.1", features = ["async-await"] } rand = "0.8" tari_crypto = "0.11.1" serde = { version = "1.0.106", features = ["derive"] } -tokio = { version="^0.2", features = ["blocking", "time", "sync"] } +tokio = { version="^1.10", features = [ "time", "sync"] } +lazy_static = "1.4.0" diff --git a/base_layer/common_types/src/waiting_requests.rs b/base_layer/common_types/src/waiting_requests.rs index a26119a5cb..67e6eed6ef 100644 --- a/base_layer/common_types/src/waiting_requests.rs +++ b/base_layer/common_types/src/waiting_requests.rs @@ -20,10 +20,9 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::channel::oneshot::Sender as OneshotSender; use rand::RngCore; use std::{collections::HashMap, sync::Arc, time::Instant}; -use tokio::sync::RwLock; +use tokio::sync::{oneshot::Sender as OneshotSender, RwLock}; pub type RequestKey = u64; diff --git a/base_layer/core/Cargo.toml b/base_layer/core/Cargo.toml index 6e802b7fbc..167963bf24 100644 --- a/base_layer/core/Cargo.toml +++ b/base_layer/core/Cargo.toml @@ -35,27 +35,28 @@ bincode = "1.1.4" bitflags = "1.0.4" blake2 = "^0.9.0" sha3 = "0.9" -bytes = "0.4.12" +bytes = "0.5" chrono = { version = "0.4.6", features = ["serde"]} croaring = { version = "=0.4.5", optional = true } digest = "0.9.0" -futures = {version = "^0.3.1", features = ["async-await"] } +futures = {version = "^0.3.16", features = ["async-await"] } fs2 = "0.3.0" hex = "0.4.2" +lazy_static = "1.4.0" lmdb-zero = "0.4.4" log = "0.4" monero = { version = "^0.13.0", features= ["serde_support"], optional = true } newtype-ops = "0.1.4" num = "0.3" -prost = "0.6.1" -prost-types = "0.6.1" +prost = "0.8.0" +prost-types = "0.8.0" rand = "0.8" randomx-rs = { version = "0.5.0", optional = true } serde = { version = "1.0.106", features = ["derive"] } serde_json = "1.0" strum_macros = "0.17.1" -thiserror = "1.0.20" -tokio = { version="^0.2", features = ["blocking", "time", "sync"] } +thiserror = "1.0.26" +tokio = { version="^1.10", features = [ "time", "sync", "macros"] } ttl_cache = "0.5.1" uint = { version = "0.9", default-features = false } num-format = "0.4.0" @@ -67,7 +68,6 @@ tari_test_utils = { version = "^0.9", path = "../../infrastructure/test_utils" } config = { version = "0.9.3" } env_logger = "0.7.0" tempfile = "3.1.0" -tokio-macros = "0.2.4" [build-dependencies] tari_common = { version = "^0.9", path="../../common", features = ["build"]} diff --git a/base_layer/core/src/base_node/chain_metadata_service/initializer.rs b/base_layer/core/src/base_node/chain_metadata_service/initializer.rs index 1310f22702..2700dc9d01 100644 --- a/base_layer/core/src/base_node/chain_metadata_service/initializer.rs +++ b/base_layer/core/src/base_node/chain_metadata_service/initializer.rs @@ -20,10 +20,8 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use super::{service::ChainMetadataService, LOG_TARGET}; +use super::service::ChainMetadataService; use crate::base_node::{chain_metadata_service::handle::ChainMetadataHandle, comms_interface::LocalNodeCommsInterface}; -use futures::{future, pin_mut}; -use log::*; use tari_comms::connectivity::ConnectivityRequester; use tari_p2p::services::liveness::LivenessHandle; use tari_service_framework::{async_trait, ServiceInitializationError, ServiceInitializer, ServiceInitializerContext}; @@ -40,15 +38,12 @@ impl ServiceInitializer for ChainMetadataServiceInitializer { let handle = ChainMetadataHandle::new(publisher.clone()); context.register_handle(handle); - context.spawn_when_ready(|handles| async move { + context.spawn_until_shutdown(|handles| { let liveness = handles.expect_handle::<LivenessHandle>(); let base_node = handles.expect_handle::<LocalNodeCommsInterface>(); let connectivity = handles.expect_handle::<ConnectivityRequester>(); - let service_run = ChainMetadataService::new(liveness, base_node, connectivity, publisher).run(); - pin_mut!(service_run); - future::select(service_run, handles.get_shutdown_signal()).await; - info!(target: LOG_TARGET, "ChainMetadataService has shut down"); + ChainMetadataService::new(liveness, base_node, connectivity, publisher).run() }); Ok(()) diff --git a/base_layer/core/src/base_node/chain_metadata_service/service.rs b/base_layer/core/src/base_node/chain_metadata_service/service.rs index 6c7da96719..332ade2350 100644 --- a/base_layer/core/src/base_node/chain_metadata_service/service.rs +++ b/base_layer/core/src/base_node/chain_metadata_service/service.rs @@ -29,7 +29,6 @@ use crate::{ chain_storage::BlockAddResult, proto::base_node as proto, }; -use futures::stream::StreamExt; use log::*; use num_format::{Locale, ToFormattedString}; use prost::Message; @@ -75,9 +74,9 @@ impl ChainMetadataService { /// Run the service pub async fn run(mut self) { - let mut liveness_event_stream = self.liveness.get_event_stream().fuse(); - let mut block_event_stream = self.base_node.get_block_event_stream().fuse(); - let mut connectivity_events = self.connectivity.get_event_subscription().fuse(); + let mut liveness_event_stream = self.liveness.get_event_stream(); + let mut block_event_stream = self.base_node.get_block_event_stream(); + let mut connectivity_events = self.connectivity.get_event_subscription(); log_if_error!( target: LOG_TARGET, @@ -86,47 +85,36 @@ impl ChainMetadataService { ); loop { - futures::select! { - block_event = block_event_stream.select_next_some() => { - if let Ok(block_event) = block_event { - log_if_error!( - level: debug, - target: LOG_TARGET, - "Failed to handle block event because '{}'", - self.handle_block_event(&block_event).await - ); - } + tokio::select! { + Ok(block_event) = block_event_stream.recv() => { + log_if_error!( + level: debug, + target: LOG_TARGET, + "Failed to handle block event because '{}'", + self.handle_block_event(&block_event).await + ); }, - liveness_event = liveness_event_stream.select_next_some() => { - if let Ok(event) = liveness_event { - log_if_error!( - target: LOG_TARGET, - "Failed to handle liveness event because '{}'", - self.handle_liveness_event(&*event).await - ); - } + Ok(event) = liveness_event_stream.recv() => { + log_if_error!( + target: LOG_TARGET, + "Failed to handle liveness event because '{}'", + self.handle_liveness_event(&*event).await + ); }, - event = connectivity_events.select_next_some() => { - if let Ok(event) = event { - self.handle_connectivity_event(&*event); - } - } - - complete => { - info!(target: LOG_TARGET, "ChainStateSyncService is exiting because all tasks have completed"); - break; + Ok(event) = connectivity_events.recv() => { + self.handle_connectivity_event(event); } } } } - fn handle_connectivity_event(&mut self, event: &ConnectivityEvent) { + fn handle_connectivity_event(&mut self, event: ConnectivityEvent) { use ConnectivityEvent::*; match event { PeerDisconnected(node_id) | ManagedPeerDisconnected(node_id) | PeerBanned(node_id) => { - if let Some(pos) = self.peer_chain_metadata.iter().position(|p| &p.node_id == node_id) { + if let Some(pos) = self.peer_chain_metadata.iter().position(|p| p.node_id == node_id) { debug!( target: LOG_TARGET, "Removing disconnected/banned peer `{}` from chain metadata list ", node_id @@ -298,6 +286,7 @@ impl ChainMetadataService { mod test { use super::*; use crate::base_node::comms_interface::{CommsInterfaceError, NodeCommsRequest, NodeCommsResponse}; + use futures::StreamExt; use std::convert::TryInto; use tari_comms::test_utils::{ mocks::{create_connectivity_mock, ConnectivityManagerMockState}, @@ -361,7 +350,7 @@ mod test { ) } - #[tokio_macros::test] + #[tokio::test] async fn update_liveness_chain_metadata() { let (mut service, liveness_mock_state, _, mut base_node_receiver) = setup(); @@ -370,11 +359,11 @@ mod test { let chain_metadata = proto_chain_metadata.clone().try_into().unwrap(); task::spawn(async move { - let base_node_req = base_node_receiver.select_next_some().await; - let (_req, reply_tx) = base_node_req.split(); - reply_tx - .send(Ok(NodeCommsResponse::ChainMetadata(chain_metadata))) - .unwrap(); + if let Some(base_node_req) = base_node_receiver.next().await { + base_node_req + .reply(Ok(NodeCommsResponse::ChainMetadata(chain_metadata))) + .unwrap(); + } }); service.update_liveness_chain_metadata().await.unwrap(); @@ -387,7 +376,7 @@ mod test { let chain_metadata = proto::ChainMetadata::decode(data.as_slice()).unwrap(); assert_eq!(chain_metadata.height_of_longest_chain, Some(123)); } - #[tokio_macros::test_basic] + #[tokio::test] async fn handle_liveness_event_ok() { let (mut service, _, _, _) = setup(); @@ -416,7 +405,7 @@ mod test { ); } - #[tokio_macros::test_basic] + #[tokio::test] async fn handle_liveness_event_banned_peer() { let (mut service, _, _, _) = setup(); @@ -450,7 +439,7 @@ mod test { .all(|p| &p.node_id != nodes[0].node_id())); } - #[tokio_macros::test_basic] + #[tokio::test] async fn handle_liveness_event_no_metadata() { let (mut service, _, _, _) = setup(); @@ -468,7 +457,7 @@ mod test { assert_eq!(service.peer_chain_metadata.len(), 0); } - #[tokio_macros::test_basic] + #[tokio::test] async fn handle_liveness_event_bad_metadata() { let (mut service, _, _, _) = setup(); diff --git a/base_layer/core/src/base_node/comms_interface/outbound_interface.rs b/base_layer/core/src/base_node/comms_interface/outbound_interface.rs index 753fe802d8..f3f08409b2 100644 --- a/base_layer/core/src/base_node/comms_interface/outbound_interface.rs +++ b/base_layer/core/src/base_node/comms_interface/outbound_interface.rs @@ -26,11 +26,11 @@ use crate::{ chain_storage::HistoricalBlock, transactions::{transaction::TransactionOutput, types::HashOutput}, }; -use futures::channel::mpsc::UnboundedSender; use log::*; use tari_common_types::{chain_metadata::ChainMetadata, types::BlockHash}; use tari_comms::peer_manager::NodeId; use tari_service_framework::{reply_channel::SenderService, Service}; +use tokio::sync::mpsc::UnboundedSender; pub const LOG_TARGET: &str = "c::bn::comms_interface::outbound_interface"; @@ -234,10 +234,8 @@ impl OutboundNodeCommsInterface { new_block: NewBlock, exclude_peers: Vec<NodeId>, ) -> Result<(), CommsInterfaceError> { - self.block_sender - .unbounded_send((new_block, exclude_peers)) - .map_err(|err| { - CommsInterfaceError::InternalChannelError(format!("Failed to send on block_sender: {}", err)) - }) + self.block_sender.send((new_block, exclude_peers)).map_err(|err| { + CommsInterfaceError::InternalChannelError(format!("Failed to send on block_sender: {}", err)) + }) } } diff --git a/base_layer/core/src/base_node/rpc/service.rs b/base_layer/core/src/base_node/rpc/service.rs index c50600ea9c..5adefae0e6 100644 --- a/base_layer/core/src/base_node/rpc/service.rs +++ b/base_layer/core/src/base_node/rpc/service.rs @@ -230,7 +230,7 @@ impl<B: BlockchainBackend + 'static> BaseNodeWalletService for BaseNodeWalletRpc // Determine if we are synced let status_watch = state_machine.get_status_info_watch(); - let is_synced = match (*status_watch.borrow()).state_info { + let is_synced = match status_watch.borrow().state_info { StateInfo::Listening(li) => li.is_synced(), _ => false, }; diff --git a/base_layer/core/src/base_node/service/initializer.rs b/base_layer/core/src/base_node/service/initializer.rs index ae6be0b519..11132db138 100644 --- a/base_layer/core/src/base_node/service/initializer.rs +++ b/base_layer/core/src/base_node/service/initializer.rs @@ -33,7 +33,7 @@ use crate::{ proto as shared_protos, proto::base_node as proto, }; -use futures::{channel::mpsc, future, Stream, StreamExt}; +use futures::{future, Stream, StreamExt}; use log::*; use std::{convert::TryFrom, sync::Arc}; use tari_comms_dht::Dht; @@ -50,7 +50,7 @@ use tari_service_framework::{ ServiceInitializer, ServiceInitializerContext, }; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc}; const LOG_TARGET: &str = "c::bn::service::initializer"; const SUBSCRIPTION_LABEL: &str = "Base Node"; @@ -151,7 +151,7 @@ where T: BlockchainBackend + 'static let inbound_block_stream = self.inbound_block_stream(); // Connect InboundNodeCommsInterface and OutboundNodeCommsInterface to BaseNodeService let (outbound_request_sender_service, outbound_request_stream) = reply_channel::unbounded(); - let (outbound_block_sender_service, outbound_block_stream) = mpsc::unbounded(); + let (outbound_block_sender_service, outbound_block_stream) = mpsc::unbounded_channel(); let (local_request_sender_service, local_request_stream) = reply_channel::unbounded(); let (local_block_sender_service, local_block_stream) = reply_channel::unbounded(); let outbound_nci = diff --git a/base_layer/core/src/base_node/service/service.rs b/base_layer/core/src/base_node/service/service.rs index db96b6ed9e..1d66cbf1b1 100644 --- a/base_layer/core/src/base_node/service/service.rs +++ b/base_layer/core/src/base_node/service/service.rs @@ -38,16 +38,7 @@ use crate::{ proto as shared_protos, proto::{base_node as proto, base_node::base_node_service_request::Request}, }; -use futures::{ - channel::{ - mpsc::{channel, Receiver, Sender, UnboundedReceiver}, - oneshot::Sender as OneshotSender, - }, - pin_mut, - stream::StreamExt, - SinkExt, - Stream, -}; +use futures::{pin_mut, stream::StreamExt, Stream}; use log::*; use rand::rngs::OsRng; use std::{convert::TryInto, sync::Arc, time::Duration}; @@ -64,7 +55,14 @@ use tari_comms_dht::{ use tari_crypto::tari_utilities::hex::Hex; use tari_p2p::{domain_message::DomainMessage, tari_message::TariMessageType}; use tari_service_framework::reply_channel::RequestContext; -use tokio::task; +use tokio::{ + sync::{ + mpsc, + mpsc::{Receiver, Sender, UnboundedReceiver}, + oneshot::Sender as OneshotSender, + }, + task, +}; const LOG_TARGET: &str = "c::bn::base_node_service::service"; @@ -134,7 +132,7 @@ where B: BlockchainBackend + 'static config: BaseNodeServiceConfig, state_machine_handle: StateMachineHandle, ) -> Self { - let (timeout_sender, timeout_receiver) = channel(100); + let (timeout_sender, timeout_receiver) = mpsc::channel(100); Self { outbound_message_service, inbound_nch, @@ -162,7 +160,7 @@ where B: BlockchainBackend + 'static { let outbound_request_stream = streams.outbound_request_stream.fuse(); pin_mut!(outbound_request_stream); - let outbound_block_stream = streams.outbound_block_stream.fuse(); + let outbound_block_stream = streams.outbound_block_stream; pin_mut!(outbound_block_stream); let inbound_request_stream = streams.inbound_request_stream.fuse(); pin_mut!(inbound_request_stream); @@ -177,53 +175,52 @@ where B: BlockchainBackend + 'static let timeout_receiver_stream = self .timeout_receiver_stream .take() - .expect("Base Node Service initialized without timeout_receiver_stream") - .fuse(); + .expect("Base Node Service initialized without timeout_receiver_stream"); pin_mut!(timeout_receiver_stream); loop { - futures::select! { + tokio::select! { // Outbound request messages from the OutboundNodeCommsInterface - outbound_request_context = outbound_request_stream.select_next_some() => { + Some(outbound_request_context) = outbound_request_stream.next() => { self.spawn_handle_outbound_request(outbound_request_context); }, // Outbound block messages from the OutboundNodeCommsInterface - (block, excluded_peers) = outbound_block_stream.select_next_some() => { + Some((block, excluded_peers)) = outbound_block_stream.recv() => { self.spawn_handle_outbound_block(block, excluded_peers); }, // Incoming request messages from the Comms layer - domain_msg = inbound_request_stream.select_next_some() => { + Some(domain_msg) = inbound_request_stream.next() => { self.spawn_handle_incoming_request(domain_msg); }, // Incoming response messages from the Comms layer - domain_msg = inbound_response_stream.select_next_some() => { + Some(domain_msg) = inbound_response_stream.next() => { self.spawn_handle_incoming_response(domain_msg); }, // Timeout events for waiting requests - timeout_request_key = timeout_receiver_stream.select_next_some() => { + Some(timeout_request_key) = timeout_receiver_stream.recv() => { self.spawn_handle_request_timeout(timeout_request_key); }, // Incoming block messages from the Comms layer - block_msg = inbound_block_stream.select_next_some() => { + Some(block_msg) = inbound_block_stream.next() => { self.spawn_handle_incoming_block(block_msg).await; } // Incoming local request messages from the LocalNodeCommsInterface and other local services - local_request_context = local_request_stream.select_next_some() => { + Some(local_request_context) = local_request_stream.next() => { self.spawn_handle_local_request(local_request_context); }, // Incoming local block messages from the LocalNodeCommsInterface e.g. miner and block sync - local_block_context = local_block_stream.select_next_some() => { + Some(local_block_context) = local_block_stream.next() => { self.spawn_handle_local_block(local_block_context); }, - complete => { - info!(target: LOG_TARGET, "Base Node service shutting down"); + else => { + info!(target: LOG_TARGET, "Base Node service shutting down because all streams ended"); break; } } @@ -646,9 +643,9 @@ async fn handle_request_timeout( Ok(()) } -fn spawn_request_timeout(mut timeout_sender: Sender<RequestKey>, request_key: RequestKey, timeout: Duration) { +fn spawn_request_timeout(timeout_sender: Sender<RequestKey>, request_key: RequestKey, timeout: Duration) { task::spawn(async move { - tokio::time::delay_for(timeout).await; + tokio::time::sleep(timeout).await; let _ = timeout_sender.send(request_key).await; }); } diff --git a/base_layer/core/src/base_node/state_machine_service/state_machine.rs b/base_layer/core/src/base_node/state_machine_service/state_machine.rs index 3c8e33ee06..383bc82d03 100644 --- a/base_layer/core/src/base_node/state_machine_service/state_machine.rs +++ b/base_layer/core/src/base_node/state_machine_service/state_machine.rs @@ -158,7 +158,7 @@ impl<B: BlockchainBackend + 'static> BaseNodeStateMachine<B> { state_info: self.info.clone(), }; - if let Err(e) = self.status_event_sender.broadcast(status) { + if let Err(e) = self.status_event_sender.send(status) { debug!(target: LOG_TARGET, "Error broadcasting a StatusEvent update: {}", e); } } diff --git a/base_layer/core/src/base_node/state_machine_service/states/block_sync.rs b/base_layer/core/src/base_node/state_machine_service/states/block_sync.rs index cc596d61d4..971c1ddd9a 100644 --- a/base_layer/core/src/base_node/state_machine_service/states/block_sync.rs +++ b/base_layer/core/src/base_node/state_machine_service/states/block_sync.rs @@ -76,11 +76,11 @@ impl BlockSync { false.into(), )); - let _ = status_event_sender.broadcast(StatusInfo { + let _ = status_event_sender.send(StatusInfo { bootstrapped, state_info: StateInfo::BlockSync(BlockSyncInfo { - tip_height: remote_tip_height, - local_height, + tip_height: Some(remote_tip_height), + local_height: Some(local_height), sync_peers: sync_peers.to_vec(), }), }); diff --git a/base_layer/core/src/base_node/state_machine_service/states/events_and_states.rs b/base_layer/core/src/base_node/state_machine_service/states/events_and_states.rs index 8c65886792..2202637744 100644 --- a/base_layer/core/src/base_node/state_machine_service/states/events_and_states.rs +++ b/base_layer/core/src/base_node/state_machine_service/states/events_and_states.rs @@ -170,12 +170,7 @@ impl StateInfo { pub fn short_desc(&self) -> String { match self { Self::StartUp => "Starting up".to_string(), - Self::HeaderSync(info) => format!( - "Syncing headers: {}/{} ({:.0}%)", - info.local_height, - info.tip_height, - info.local_height as f64 / info.tip_height as f64 * 100.0 - ), + Self::HeaderSync(info) => format!("Syncing headers: {}", info.sync_progress_string()), Self::HorizonSync(info) => match info.status { HorizonSyncStatus::Starting => "Starting horizon sync".to_string(), HorizonSyncStatus::Kernels(current, total) => format!( @@ -192,12 +187,7 @@ impl StateInfo { ), HorizonSyncStatus::Finalizing => "Finalizing horizon sync".to_string(), }, - Self::BlockSync(info) => format!( - "Syncing blocks: {}/{} ({:.0}%)", - info.local_height, - info.tip_height, - info.local_height as f64 / info.tip_height as f64 * 100.0 - ), + Self::BlockSync(info) => format!("Syncing blocks: {}", info.sync_progress_string()), Self::Listening(_) => "Listening".to_string(), } } @@ -261,8 +251,8 @@ impl Display for StatusInfo { #[derive(Clone, Debug, PartialEq)] /// This struct contains info that is use full for external viewing of state info pub struct BlockSyncInfo { - pub tip_height: u64, - pub local_height: u64, + pub tip_height: Option<u64>, + pub local_height: Option<u64>, pub sync_peers: Vec<NodeId>, } @@ -270,20 +260,27 @@ impl BlockSyncInfo { /// Creates a new blockSyncInfo pub fn new(tip_height: u64, local_height: u64, sync_peers: Vec<NodeId>) -> BlockSyncInfo { BlockSyncInfo { - tip_height, - local_height, + tip_height: Some(tip_height), + local_height: Some(local_height), sync_peers, } } + + pub fn sync_progress_string(&self) -> String { + self.local_height + .and_then(|h| self.tip_height.map(|t| (h, t))) + .map(|(h, t)| format!("{}/{} ({:.0}%)", h, t, (h as f64 / t as f64 * 100.0))) + .unwrap_or_else(|| "--".to_string()) + } } impl Display for BlockSyncInfo { - fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { - fmt.write_str("Syncing from the following peers: \n")?; + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { + writeln!(f, "Syncing from the following peers:")?; for peer in &self.sync_peers { - fmt.write_str(&format!("{}\n", peer))?; + writeln!(f, "{}", peer)?; } - fmt.write_str(&format!("Syncing {}/{}\n", self.local_height, self.tip_height)) + writeln!(f, "Syncing {}", self.sync_progress_string()) } } diff --git a/base_layer/core/src/base_node/state_machine_service/states/header_sync.rs b/base_layer/core/src/base_node/state_machine_service/states/header_sync.rs index 2acfd53206..9f1ae87673 100644 --- a/base_layer/core/src/base_node/state_machine_service/states/header_sync.rs +++ b/base_layer/core/src/base_node/state_machine_service/states/header_sync.rs @@ -74,15 +74,27 @@ impl HeaderSync { let status_event_sender = shared.status_event_sender.clone(); let bootstrapped = shared.is_bootstrapped(); - synchronizer.on_progress(move |current_height, remote_tip_height, sync_peers| { - let _ = status_event_sender.broadcast(StatusInfo { - bootstrapped, - state_info: StateInfo::HeaderSync(BlockSyncInfo { - tip_height: remote_tip_height, - local_height: current_height, - sync_peers: sync_peers.to_vec(), - }), - }); + synchronizer.on_progress(move |details, sync_peers| { + let status_info = match details { + Some((current_height, remote_tip_height)) => StatusInfo { + bootstrapped, + state_info: StateInfo::HeaderSync(BlockSyncInfo { + tip_height: Some(remote_tip_height), + local_height: Some(current_height), + sync_peers: sync_peers.to_vec(), + }), + }, + None => StatusInfo { + bootstrapped, + state_info: StateInfo::HeaderSync(BlockSyncInfo { + tip_height: None, + local_height: None, + sync_peers: sync_peers.to_vec(), + }), + }, + }; + + let _ = status_event_sender.send(status_info); }); let local_nci = shared.local_node_interface.clone(); diff --git a/base_layer/core/src/base_node/state_machine_service/states/listening.rs b/base_layer/core/src/base_node/state_machine_service/states/listening.rs index 0ea8157568..92349c269a 100644 --- a/base_layer/core/src/base_node/state_machine_service/states/listening.rs +++ b/base_layer/core/src/base_node/state_machine_service/states/listening.rs @@ -31,7 +31,6 @@ use crate::{ }, chain_storage::BlockchainBackend, }; -use futures::StreamExt; use log::*; use num_format::{Locale, ToFormattedString}; use serde::{Deserialize, Serialize}; @@ -118,7 +117,8 @@ impl Listening { info!(target: LOG_TARGET, "Listening for chain metadata updates"); shared.set_state_info(StateInfo::Listening(ListeningInfo::new(self.is_synced))); - while let Some(metadata_event) = shared.metadata_event_stream.next().await { + loop { + let metadata_event = shared.metadata_event_stream.recv().await; match metadata_event.as_ref().map(|v| v.deref()) { Ok(ChainMetadataEvent::PeerChainMetadataReceived(peer_metadata_list)) => { let mut peer_metadata_list = peer_metadata_list.clone(); @@ -199,16 +199,16 @@ impl Listening { if !self.is_synced { self.is_synced = true; + shared.set_state_info(StateInfo::Listening(ListeningInfo::new(true))); debug!(target: LOG_TARGET, "Initial sync achieved"); } - shared.set_state_info(StateInfo::Listening(ListeningInfo::new(true))); }, - Err(broadcast::RecvError::Lagged(n)) => { + Err(broadcast::error::RecvError::Lagged(n)) => { debug!(target: LOG_TARGET, "Metadata event subscriber lagged by {} item(s)", n); }, - Err(broadcast::RecvError::Closed) => { - // This should never happen because the while loop exits when the stream ends + Err(broadcast::error::RecvError::Closed) => { debug!(target: LOG_TARGET, "Metadata event subscriber closed"); + break; }, } } diff --git a/base_layer/core/src/base_node/state_machine_service/states/waiting.rs b/base_layer/core/src/base_node/state_machine_service/states/waiting.rs index 7ea2e7e2b0..aeaa5ab430 100644 --- a/base_layer/core/src/base_node/state_machine_service/states/waiting.rs +++ b/base_layer/core/src/base_node/state_machine_service/states/waiting.rs @@ -23,7 +23,7 @@ use crate::base_node::state_machine_service::states::{BlockSync, HeaderSync, HorizonStateSync, StateEvent}; use log::info; use std::time::Duration; -use tokio::time::delay_for; +use tokio::time::sleep; const LOG_TARGET: &str = "c::bn::state_machine_service::states::waiting"; @@ -41,7 +41,7 @@ impl Waiting { "The base node has started a WAITING state for {} seconds", self.timeout.as_secs() ); - delay_for(self.timeout).await; + sleep(self.timeout).await; info!( target: LOG_TARGET, "The base node waiting state has completed. Resuming normal operations" diff --git a/base_layer/core/src/base_node/sync/header_sync/synchronizer.rs b/base_layer/core/src/base_node/sync/header_sync/synchronizer.rs index 8ff487a8e5..d52050d228 100644 --- a/base_layer/core/src/base_node/sync/header_sync/synchronizer.rs +++ b/base_layer/core/src/base_node/sync/header_sync/synchronizer.rs @@ -82,7 +82,7 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { } pub fn on_progress<H>(&mut self, hook: H) - where H: FnMut(u64, u64, &[NodeId]) + Send + Sync + 'static { + where H: FnMut(Option<(u64, u64)>, &[NodeId]) + Send + Sync + 'static { self.hooks.add_on_progress_header_hook(hook); } @@ -93,6 +93,7 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { pub async fn synchronize(&mut self) -> Result<PeerConnection, BlockHeaderSyncError> { debug!(target: LOG_TARGET, "Starting header sync.",); + self.hooks.call_on_progress_header_hooks(None, self.sync_peers); let sync_peers = self.select_sync_peers().await?; info!( target: LOG_TARGET, @@ -259,7 +260,7 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { let latency = client.get_last_request_latency().await?; debug!( target: LOG_TARGET, - "Initiating header sync with peer `{}` (latency = {}ms)", + "Initiating header sync with peer `{}` (sync latency = {}ms)", conn.peer_node_id(), latency.unwrap_or_default().as_millis() ); @@ -270,6 +271,10 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { // We're ahead of this peer, try another peer if possible SyncStatus::Ahead => Err(BlockHeaderSyncError::NotInSync), SyncStatus::Lagging(split_info) => { + self.hooks.call_on_progress_header_hooks( + Some((split_info.local_tip_header.height(), split_info.remote_tip_height)), + self.sync_peers, + ); self.synchronize_headers(&peer, &mut client, *split_info).await?; Ok(()) }, @@ -481,7 +486,7 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { const COMMIT_EVERY_N_HEADERS: usize = 1000; // Peer returned no more than the max headers. This indicates that there are no further headers to request. - if self.header_validator.valid_headers().len() <= NUM_INITIAL_HEADERS_TO_REQUEST as usize { + if self.header_validator.valid_headers().len() < NUM_INITIAL_HEADERS_TO_REQUEST as usize { debug!(target: LOG_TARGET, "No further headers to download"); if !self.pending_chain_has_higher_pow(&split_info.local_tip_header)? { return Err(BlockHeaderSyncError::WeakerChain); @@ -561,7 +566,7 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { } self.hooks - .call_on_progress_header_hooks(current_height, split_info.remote_tip_height, self.sync_peers); + .call_on_progress_header_hooks(Some((current_height, split_info.remote_tip_height)), self.sync_peers); } if !has_switched_to_new_chain { diff --git a/base_layer/core/src/base_node/sync/header_sync/validator.rs b/base_layer/core/src/base_node/sync/header_sync/validator.rs index aff52b80fd..4a25f28aa7 100644 --- a/base_layer/core/src/base_node/sync/header_sync/validator.rs +++ b/base_layer/core/src/base_node/sync/header_sync/validator.rs @@ -283,7 +283,7 @@ mod test { mod initialize_state { use super::*; - #[tokio_macros::test_basic] + #[tokio::test] async fn it_initializes_state_to_given_header() { let (mut validator, _, tip) = setup_with_headers(1).await; validator.initialize_state(&tip.header().hash()).await.unwrap(); @@ -295,7 +295,7 @@ mod test { assert_eq!(state.current_height, 1); } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_errors_if_hash_does_not_exist() { let (mut validator, _) = setup(); let start_hash = vec![0; 32]; @@ -308,7 +308,7 @@ mod test { mod validate { use super::*; - #[tokio_macros::test_basic] + #[tokio::test] async fn it_passes_if_headers_are_valid() { let (mut validator, _, tip) = setup_with_headers(1).await; validator.initialize_state(tip.hash()).await.unwrap(); @@ -322,7 +322,7 @@ mod test { assert_eq!(validator.valid_headers().len(), 2); } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_fails_if_height_is_not_serial() { let (mut validator, _, tip) = setup_with_headers(2).await; validator.initialize_state(tip.hash()).await.unwrap(); diff --git a/base_layer/core/src/base_node/sync/hooks.rs b/base_layer/core/src/base_node/sync/hooks.rs index 71f0802926..d1e2628822 100644 --- a/base_layer/core/src/base_node/sync/hooks.rs +++ b/base_layer/core/src/base_node/sync/hooks.rs @@ -28,7 +28,7 @@ use tari_comms::peer_manager::NodeId; #[derive(Default)] pub(super) struct Hooks { - on_progress_header: Vec<Box<dyn FnMut(u64, u64, &[NodeId]) + Send + Sync>>, + on_progress_header: Vec<Box<dyn FnMut(Option<(u64, u64)>, &[NodeId]) + Send + Sync>>, on_progress_block: Vec<Box<dyn FnMut(Arc<ChainBlock>, u64, &[NodeId]) + Send + Sync>>, on_complete: Vec<Box<dyn FnMut(Arc<ChainBlock>) + Send + Sync>>, on_rewind: Vec<Box<dyn FnMut(Vec<Arc<ChainBlock>>) + Send + Sync>>, @@ -36,14 +36,14 @@ pub(super) struct Hooks { impl Hooks { pub fn add_on_progress_header_hook<H>(&mut self, hook: H) - where H: FnMut(u64, u64, &[NodeId]) + Send + Sync + 'static { + where H: FnMut(Option<(u64, u64)>, &[NodeId]) + Send + Sync + 'static { self.on_progress_header.push(Box::new(hook)); } - pub fn call_on_progress_header_hooks(&mut self, height: u64, remote_tip_height: u64, sync_peers: &[NodeId]) { + pub fn call_on_progress_header_hooks(&mut self, height_vs_remote: Option<(u64, u64)>, sync_peers: &[NodeId]) { self.on_progress_header .iter_mut() - .for_each(|f| (*f)(height, remote_tip_height, sync_peers)); + .for_each(|f| (*f)(height_vs_remote, sync_peers)); } pub fn add_on_progress_block_hook<H>(&mut self, hook: H) diff --git a/base_layer/core/src/base_node/sync/rpc/service.rs b/base_layer/core/src/base_node/sync/rpc/service.rs index fc2d621113..1c0b3eb4e1 100644 --- a/base_layer/core/src/base_node/sync/rpc/service.rs +++ b/base_layer/core/src/base_node/sync/rpc/service.rs @@ -35,12 +35,14 @@ use crate::{ SyncUtxosResponse, }, }; -use futures::{channel::mpsc, stream, SinkExt}; use log::*; use std::cmp; -use tari_comms::protocol::rpc::{Request, Response, RpcStatus, Streaming}; +use tari_comms::{ + protocol::rpc::{Request, Response, RpcStatus, Streaming}, + utils, +}; use tari_crypto::tari_utilities::hex::Hex; -use tokio::task; +use tokio::{sync::mpsc, task}; const LOG_TARGET: &str = "c::base_node::sync_rpc"; @@ -114,7 +116,7 @@ impl<B: BlockchainBackend + 'static> BaseNodeSyncService for BaseNodeSyncRpcServ // Number of blocks to load and push to the stream before loading the next batch const BATCH_SIZE: usize = 4; - let (mut tx, rx) = mpsc::channel(BATCH_SIZE); + let (tx, rx) = mpsc::channel(BATCH_SIZE); task::spawn(async move { let iter = NonOverlappingIntegerPairIter::new(start, end + 1, BATCH_SIZE); @@ -134,19 +136,16 @@ impl<B: BlockchainBackend + 'static> BaseNodeSyncService for BaseNodeSyncRpcServ break; }, Ok(blocks) => { - let mut blocks = stream::iter( - blocks - .into_iter() - .map(|hb| hb.try_into_block().map_err(RpcStatus::log_internal_error(LOG_TARGET))) - .map(|block| match block { - Ok(b) => Ok(proto::base_node::BlockBodyResponse::from(b)), - Err(err) => Err(err), - }) - .map(Ok), - ); + let blocks = blocks + .into_iter() + .map(|hb| hb.try_into_block().map_err(RpcStatus::log_internal_error(LOG_TARGET))) + .map(|block| match block { + Ok(b) => Ok(proto::base_node::BlockBodyResponse::from(b)), + Err(err) => Err(err), + }); // Ensure task stops if the peer prematurely stops their RPC session - if tx.send_all(&mut blocks).await.is_err() { + if utils::mpsc::send_all(&tx, blocks).await.is_err() { break; } }, @@ -202,7 +201,7 @@ impl<B: BlockchainBackend + 'static> BaseNodeSyncService for BaseNodeSyncRpcServ chunk_size ); - let (mut tx, rx) = mpsc::channel(chunk_size); + let (tx, rx) = mpsc::channel(chunk_size); task::spawn(async move { let iter = NonOverlappingIntegerPairIter::new( start_header.height + 1, @@ -224,10 +223,9 @@ impl<B: BlockchainBackend + 'static> BaseNodeSyncService for BaseNodeSyncRpcServ break; }, Ok(headers) => { - let mut headers = - stream::iter(headers.into_iter().map(proto::core::BlockHeader::from).map(Ok).map(Ok)); + let headers = headers.into_iter().map(proto::core::BlockHeader::from).map(Ok); // Ensure task stops if the peer prematurely stops their RPC session - if tx.send_all(&mut headers).await.is_err() { + if utils::mpsc::send_all(&tx, headers).await.is_err() { break; } }, @@ -339,7 +337,7 @@ impl<B: BlockchainBackend + 'static> BaseNodeSyncService for BaseNodeSyncRpcServ ) -> Result<Streaming<proto::types::TransactionKernel>, RpcStatus> { let req = request.into_message(); const BATCH_SIZE: usize = 1000; - let (mut tx, rx) = mpsc::channel(BATCH_SIZE); + let (tx, rx) = mpsc::channel(BATCH_SIZE); let db = self.db(); task::spawn(async move { @@ -379,15 +377,9 @@ impl<B: BlockchainBackend + 'static> BaseNodeSyncService for BaseNodeSyncRpcServ break; }, Ok(kernels) => { - let mut kernels = stream::iter( - kernels - .into_iter() - .map(proto::types::TransactionKernel::from) - .map(Ok) - .map(Ok), - ); + let kernels = kernels.into_iter().map(proto::types::TransactionKernel::from).map(Ok); // Ensure task stops if the peer prematurely stops their RPC session - if tx.send_all(&mut kernels).await.is_err() { + if utils::mpsc::send_all(&tx, kernels).await.is_err() { break; } }, diff --git a/base_layer/core/src/base_node/sync/rpc/sync_utxos_task.rs b/base_layer/core/src/base_node/sync/rpc/sync_utxos_task.rs index ef10b41c2f..8064aaf458 100644 --- a/base_layer/core/src/base_node/sync/rpc/sync_utxos_task.rs +++ b/base_layer/core/src/base_node/sync/rpc/sync_utxos_task.rs @@ -25,11 +25,11 @@ use crate::{ proto, proto::base_node::{SyncUtxo, SyncUtxosRequest, SyncUtxosResponse}, }; -use futures::{channel::mpsc, stream, SinkExt}; use log::*; use std::{cmp, sync::Arc, time::Instant}; -use tari_comms::protocol::rpc::RpcStatus; +use tari_comms::{protocol::rpc::RpcStatus, utils}; use tari_crypto::tari_utilities::{hex::Hex, Hashable}; +use tokio::sync::mpsc; const LOG_TARGET: &str = "c::base_node::sync_rpc::sync_utxo_task"; @@ -147,8 +147,7 @@ where B: BlockchainBackend + 'static utxos.len(), deleted_diff.cardinality(), ); - let mut utxos = stream::iter( - utxos + let utxos = utxos .into_iter() .enumerate() // Only include pruned UTXOs if include_pruned_utxos is true @@ -161,12 +160,10 @@ where B: BlockchainBackend + 'static mmr_index: start + i as u64, } }) - .map(Ok) - .map(Ok), - ); + .map(Ok); // Ensure task stops if the peer prematurely stops their RPC session - if tx.send_all(&mut utxos).await.is_err() { + if utils::mpsc::send_all(&tx, utxos).await.is_err() { break; } diff --git a/base_layer/core/src/base_node/sync/rpc/tests.rs b/base_layer/core/src/base_node/sync/rpc/tests.rs index 35611a9dda..61adc1aa3c 100644 --- a/base_layer/core/src/base_node/sync/rpc/tests.rs +++ b/base_layer/core/src/base_node/sync/rpc/tests.rs @@ -89,7 +89,7 @@ mod sync_blocks { use tari_comms::protocol::rpc::RpcStatusCode; use tari_test_utils::unpack_enum; - #[tokio_macros::test_basic] + #[tokio::test] async fn it_returns_not_found_if_unknown_hash() { let mut backend = create_mock_backend(); backend.expect_fetch().times(1).returning(|_| Ok(None)); @@ -103,7 +103,7 @@ mod sync_blocks { unpack_enum!(RpcStatusCode::NotFound = err.status_code()); } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_sends_an_empty_response() { let mut backend = create_mock_backend(); @@ -136,7 +136,7 @@ mod sync_blocks { assert!(streaming.next().await.is_none()); } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_streams_blocks_until_end() { let mut backend = create_mock_backend(); diff --git a/base_layer/core/src/lib.rs b/base_layer/core/src/lib.rs index 2e4bdc2f49..5a93b73576 100644 --- a/base_layer/core/src/lib.rs +++ b/base_layer/core/src/lib.rs @@ -20,7 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -// Needed to make futures::select! work +// Needed to make tokio::select! work #![recursion_limit = "512"] #![feature(shrink_to)] // #![cfg_attr(not(debug_assertions), deny(unused_variables))] diff --git a/base_layer/core/src/mempool/rpc/test.rs b/base_layer/core/src/mempool/rpc/test.rs index 64ad84f122..a9cbb2ee49 100644 --- a/base_layer/core/src/mempool/rpc/test.rs +++ b/base_layer/core/src/mempool/rpc/test.rs @@ -43,7 +43,7 @@ mod get_stats { use super::*; use crate::mempool::{MempoolService, StatsResponse}; - #[tokio_macros::test_basic] + #[tokio::test] async fn it_returns_the_stats() { let (service, mempool, req_mock, _tmpdir) = setup(); let expected_stats = StatsResponse { @@ -66,7 +66,7 @@ mod get_state { use super::*; use crate::mempool::{MempoolService, StateResponse}; - #[tokio_macros::test_basic] + #[tokio::test] async fn it_returns_the_state() { let (service, mempool, req_mock, _tmpdir) = setup(); let expected_state = StateResponse { @@ -94,7 +94,7 @@ mod get_tx_state_by_excess_sig { use tari_crypto::ristretto::{RistrettoPublicKey, RistrettoSecretKey}; use tari_test_utils::unpack_enum; - #[tokio_macros::test_basic] + #[tokio::test] async fn it_returns_the_storage_status() { let (service, mempool, req_mock, _tmpdir) = setup(); let expected = TxStorageResponse::UnconfirmedPool; @@ -116,7 +116,7 @@ mod get_tx_state_by_excess_sig { assert_eq!(mempool.get_call_count(), 1); } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_errors_on_invalid_signature() { let (service, _, req_mock, _tmpdir) = setup(); let status = service @@ -139,7 +139,7 @@ mod submit_transaction { use tari_crypto::ristretto::RistrettoSecretKey; use tari_test_utils::unpack_enum; - #[tokio_macros::test_basic] + #[tokio::test] async fn it_submits_transaction() { let (service, mempool, req_mock, _tmpdir) = setup(); let expected = TxStorageResponse::UnconfirmedPool; @@ -166,7 +166,7 @@ mod submit_transaction { assert_eq!(mempool.get_call_count(), 1); } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_errors_on_invalid_transaction() { let (service, _, req_mock, _tmpdir) = setup(); let status = service diff --git a/base_layer/core/src/mempool/service/initializer.rs b/base_layer/core/src/mempool/service/initializer.rs index cb57d58e94..a295daf96a 100644 --- a/base_layer/core/src/mempool/service/initializer.rs +++ b/base_layer/core/src/mempool/service/initializer.rs @@ -37,7 +37,7 @@ use crate::{ proto, transactions::transaction::Transaction, }; -use futures::{channel::mpsc, future, Stream, StreamExt}; +use futures::{Stream, StreamExt}; use log::*; use std::{convert::TryFrom, sync::Arc}; use tari_comms_dht::Dht; @@ -54,7 +54,7 @@ use tari_service_framework::{ ServiceInitializer, ServiceInitializerContext, }; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc}; const LOG_TARGET: &str = "c::bn::mempool_service::initializer"; const SUBSCRIPTION_LABEL: &str = "Mempool"; @@ -148,7 +148,7 @@ impl ServiceInitializer for MempoolServiceInitializer { let mempool_handle = MempoolHandle::new(request_sender); context.register_handle(mempool_handle); - let (outbound_tx_sender, outbound_tx_stream) = mpsc::unbounded(); + let (outbound_tx_sender, outbound_tx_stream) = mpsc::unbounded_channel(); let (outbound_request_sender_service, outbound_request_stream) = reply_channel::unbounded(); let (local_request_sender_service, local_request_stream) = reply_channel::unbounded(); let (mempool_state_event_publisher, _) = broadcast::channel(100); @@ -167,7 +167,7 @@ impl ServiceInitializer for MempoolServiceInitializer { context.register_handle(outbound_mp_interface); context.register_handle(local_mp_interface); - context.spawn_when_ready(move |handles| async move { + context.spawn_until_shutdown(move |handles| { let outbound_message_service = handles.expect_handle::<Dht>().outbound_requester(); let state_machine = handles.expect_handle::<StateMachineHandle>(); let base_node = handles.expect_handle::<LocalNodeCommsInterface>(); @@ -182,11 +182,7 @@ impl ServiceInitializer for MempoolServiceInitializer { block_event_stream: base_node.get_block_event_stream(), request_receiver, }; - let service = - MempoolService::new(outbound_message_service, inbound_handlers, config, state_machine).start(streams); - futures::pin_mut!(service); - future::select(service, handles.get_shutdown_signal()).await; - info!(target: LOG_TARGET, "Mempool Service shutdown"); + MempoolService::new(outbound_message_service, inbound_handlers, config, state_machine).start(streams) }); Ok(()) diff --git a/base_layer/core/src/mempool/service/local_service.rs b/base_layer/core/src/mempool/service/local_service.rs index c58af9912d..58e4bef4d3 100644 --- a/base_layer/core/src/mempool/service/local_service.rs +++ b/base_layer/core/src/mempool/service/local_service.rs @@ -146,7 +146,7 @@ mod test { } } - #[tokio_macros::test] + #[tokio::test] async fn mempool_stats() { let (event_publisher, _) = broadcast::channel(100); let (tx, rx) = unbounded(); @@ -157,7 +157,7 @@ mod test { assert_eq!(stats, request_stats()); } - #[tokio_macros::test] + #[tokio::test] async fn mempool_stats_from_multiple() { let (event_publisher, _) = broadcast::channel(100); let (tx, rx) = unbounded(); diff --git a/base_layer/core/src/mempool/service/outbound_interface.rs b/base_layer/core/src/mempool/service/outbound_interface.rs index 87cda226f3..f84a90a4a8 100644 --- a/base_layer/core/src/mempool/service/outbound_interface.rs +++ b/base_layer/core/src/mempool/service/outbound_interface.rs @@ -28,10 +28,10 @@ use crate::{ }, transactions::{transaction::Transaction, types::Signature}, }; -use futures::channel::mpsc::UnboundedSender; use log::*; use tari_comms::peer_manager::NodeId; use tari_service_framework::{reply_channel::SenderService, Service}; +use tokio::sync::mpsc::UnboundedSender; pub const LOG_TARGET: &str = "c::mp::service::outbound_interface"; @@ -71,15 +71,13 @@ impl OutboundMempoolServiceInterface { transaction: Transaction, exclude_peers: Vec<NodeId>, ) -> Result<(), MempoolServiceError> { - self.tx_sender - .unbounded_send((transaction, exclude_peers)) - .or_else(|e| { - { - error!(target: LOG_TARGET, "Could not broadcast transaction. {:?}", e); - Err(e) - } - .map_err(|_| MempoolServiceError::BroadcastFailed) - }) + self.tx_sender.send((transaction, exclude_peers)).or_else(|e| { + { + error!(target: LOG_TARGET, "Could not broadcast transaction. {:?}", e); + Err(e) + } + .map_err(|_| MempoolServiceError::BroadcastFailed) + }) } /// Check if the specified transaction is stored in the mempool of a remote base node. diff --git a/base_layer/core/src/mempool/service/service.rs b/base_layer/core/src/mempool/service/service.rs index b8ee487b9c..e82e8afa2d 100644 --- a/base_layer/core/src/mempool/service/service.rs +++ b/base_layer/core/src/mempool/service/service.rs @@ -38,13 +38,7 @@ use crate::{ proto, transactions::transaction::Transaction, }; -use futures::{ - channel::{mpsc, oneshot::Sender as OneshotSender}, - pin_mut, - stream::StreamExt, - SinkExt, - Stream, -}; +use futures::{pin_mut, stream::StreamExt, Stream}; use log::*; use rand::rngs::OsRng; use std::{convert::TryInto, sync::Arc, time::Duration}; @@ -58,7 +52,10 @@ use tari_comms_dht::{ use tari_crypto::tari_utilities::hex::Hex; use tari_p2p::{domain_message::DomainMessage, tari_message::TariMessageType}; use tari_service_framework::{reply_channel, reply_channel::RequestContext}; -use tokio::task; +use tokio::{ + sync::{mpsc, oneshot::Sender as OneshotSender}, + task, +}; const LOG_TARGET: &str = "c::mempool::service::service"; @@ -118,7 +115,7 @@ impl MempoolService { { let outbound_request_stream = streams.outbound_request_stream.fuse(); pin_mut!(outbound_request_stream); - let mut outbound_tx_stream = streams.outbound_tx_stream.fuse(); + let mut outbound_tx_stream = streams.outbound_tx_stream; let inbound_request_stream = streams.inbound_request_stream.fuse(); pin_mut!(inbound_request_stream); let inbound_response_stream = streams.inbound_response_stream.fuse(); @@ -127,70 +124,70 @@ impl MempoolService { pin_mut!(inbound_transaction_stream); let local_request_stream = streams.local_request_stream.fuse(); pin_mut!(local_request_stream); - let mut block_event_stream = streams.block_event_stream.fuse(); + let mut block_event_stream = streams.block_event_stream; let mut timeout_receiver_stream = self .timeout_receiver_stream .take() - .expect("Mempool Service initialized without timeout_receiver_stream") - .fuse(); + .expect("Mempool Service initialized without timeout_receiver_stream"); let mut request_receiver = streams.request_receiver; loop { - futures::select! { + tokio::select! { // Requests sent from the handle - request = request_receiver.select_next_some() => { + Some(request) = request_receiver.next() => { let (request, reply) = request.split(); let _ = reply.send(self.handle_request(request).await); }, // Outbound request messages from the OutboundMempoolServiceInterface - outbound_request_context = outbound_request_stream.select_next_some() => { + Some(outbound_request_context) = outbound_request_stream.next() => { self.spawn_handle_outbound_request(outbound_request_context); }, // Outbound tx messages from the OutboundMempoolServiceInterface - (txn, excluded_peers) = outbound_tx_stream.select_next_some() => { + Some((txn, excluded_peers)) = outbound_tx_stream.recv() => { self.spawn_handle_outbound_tx(txn, excluded_peers); }, // Incoming request messages from the Comms layer - domain_msg = inbound_request_stream.select_next_some() => { + Some(domain_msg) = inbound_request_stream.next() => { self.spawn_handle_incoming_request(domain_msg); }, // Incoming response messages from the Comms layer - domain_msg = inbound_response_stream.select_next_some() => { + Some(domain_msg) = inbound_response_stream.next() => { self.spawn_handle_incoming_response(domain_msg); }, // Incoming transaction messages from the Comms layer - transaction_msg = inbound_transaction_stream.select_next_some() => { + Some(transaction_msg) = inbound_transaction_stream.next() => { self.spawn_handle_incoming_tx(transaction_msg).await; } // Incoming local request messages from the LocalMempoolServiceInterface and other local services - local_request_context = local_request_stream.select_next_some() => { + Some(local_request_context) = local_request_stream.next() => { self.spawn_handle_local_request(local_request_context); }, // Block events from local Base Node. - block_event = block_event_stream.select_next_some() => { + block_event = block_event_stream.recv() => { if let Ok(block_event) = block_event { self.spawn_handle_block_event(block_event); } }, // Timeout events for waiting requests - timeout_request_key = timeout_receiver_stream.select_next_some() => { + Some(timeout_request_key) = timeout_receiver_stream.recv() => { self.spawn_handle_request_timeout(timeout_request_key); }, - complete => { + else => { info!(target: LOG_TARGET, "Mempool service shutting down"); break; } } } + Ok(()) } @@ -506,9 +503,9 @@ async fn handle_outbound_tx( Ok(()) } -fn spawn_request_timeout(mut timeout_sender: mpsc::Sender<RequestKey>, request_key: RequestKey, timeout: Duration) { +fn spawn_request_timeout(timeout_sender: mpsc::Sender<RequestKey>, request_key: RequestKey, timeout: Duration) { task::spawn(async move { - tokio::time::delay_for(timeout).await; + tokio::time::sleep(timeout).await; let _ = timeout_sender.send(request_key).await; }); } diff --git a/base_layer/core/src/mempool/sync_protocol/initializer.rs b/base_layer/core/src/mempool/sync_protocol/initializer.rs index df40de648a..535c32c77a 100644 --- a/base_layer/core/src/mempool/sync_protocol/initializer.rs +++ b/base_layer/core/src/mempool/sync_protocol/initializer.rs @@ -28,13 +28,17 @@ use crate::{ MempoolServiceConfig, }, }; -use futures::channel::mpsc; +use log::*; +use std::time::Duration; use tari_comms::{ connectivity::ConnectivityRequester, protocol::{ProtocolExtension, ProtocolExtensionContext, ProtocolExtensionError, ProtocolNotification}, Substream, }; use tari_service_framework::{async_trait, ServiceInitializationError, ServiceInitializer, ServiceInitializerContext}; +use tokio::{sync::mpsc, time::sleep}; + +const LOG_TARGET: &str = "c::mempool::sync_protocol"; pub struct MempoolSyncInitializer { config: MempoolServiceConfig, @@ -70,17 +74,29 @@ impl ServiceInitializer for MempoolSyncInitializer { let mempool = self.mempool.clone(); let notif_rx = self.notif_rx.take().unwrap(); - context.spawn_when_ready(move |handles| { + context.spawn_until_shutdown(move |handles| async move { let state_machine = handles.expect_handle::<StateMachineHandle>(); let connectivity = handles.expect_handle::<ConnectivityRequester>(); - MempoolSyncProtocol::new( - config, - notif_rx, - connectivity.get_event_subscription(), - mempool, - Some(state_machine), - ) - .run() + + let mut status_watch = state_machine.get_status_info_watch(); + if !status_watch.borrow().bootstrapped { + debug!(target: LOG_TARGET, "Waiting for node to bootstrap..."); + while status_watch.changed().await.is_ok() { + if status_watch.borrow().bootstrapped { + debug!(target: LOG_TARGET, "Node bootstrapped. Starting mempool sync protocol"); + break; + } + trace!( + target: LOG_TARGET, + "Mempool sync still on hold, waiting for bootstrap to finish", + ); + sleep(Duration::from_secs(1)).await; + } + } + + MempoolSyncProtocol::new(config, notif_rx, connectivity.get_event_subscription(), mempool) + .run() + .await; }); Ok(()) diff --git a/base_layer/core/src/mempool/sync_protocol/mod.rs b/base_layer/core/src/mempool/sync_protocol/mod.rs index 08219f2a8d..dc133f744d 100644 --- a/base_layer/core/src/mempool/sync_protocol/mod.rs +++ b/base_layer/core/src/mempool/sync_protocol/mod.rs @@ -73,12 +73,11 @@ mod initializer; pub use initializer::MempoolSyncInitializer; use crate::{ - base_node::StateMachineHandle, mempool::{async_mempool, proto, Mempool, MempoolServiceConfig}, proto as shared_proto, transactions::transaction::Transaction, }; -use futures::{stream, stream::Fuse, AsyncRead, AsyncWrite, SinkExt, Stream, StreamExt}; +use futures::{stream, SinkExt, Stream, StreamExt}; use log::*; use prost::Message; use std::{ @@ -88,7 +87,6 @@ use std::{ atomic::{AtomicUsize, Ordering}, Arc, }, - time::Duration, }; use tari_comms::{ connectivity::{ConnectivityEvent, ConnectivityEventRx}, @@ -101,7 +99,11 @@ use tari_comms::{ PeerConnection, }; use tari_crypto::tari_utilities::{hex::Hex, ByteArray}; -use tokio::{sync::Semaphore, task}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::Semaphore, + task, +}; const MAX_FRAME_SIZE: usize = 3 * 1024 * 1024; // 3 MiB const LOG_TARGET: &str = "c::mempool::sync_protocol"; @@ -111,11 +113,10 @@ pub static MEMPOOL_SYNC_PROTOCOL: Bytes = Bytes::from_static(b"t/mempool-sync/1" pub struct MempoolSyncProtocol<TSubstream> { config: MempoolServiceConfig, protocol_notifier: ProtocolNotificationRx<TSubstream>, - connectivity_events: Fuse<ConnectivityEventRx>, + connectivity_events: ConnectivityEventRx, mempool: Mempool, num_synched: Arc<AtomicUsize>, permits: Arc<Semaphore>, - state_machine: Option<StateMachineHandle>, } impl<TSubstream> MempoolSyncProtocol<TSubstream> @@ -126,54 +127,34 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static protocol_notifier: ProtocolNotificationRx<TSubstream>, connectivity_events: ConnectivityEventRx, mempool: Mempool, - state_machine: Option<StateMachineHandle>, ) -> Self { Self { config, protocol_notifier, - connectivity_events: connectivity_events.fuse(), + connectivity_events, mempool, num_synched: Arc::new(AtomicUsize::new(0)), permits: Arc::new(Semaphore::new(1)), - state_machine, } } pub async fn run(mut self) { info!(target: LOG_TARGET, "Mempool protocol handler has started"); - while let Some(ref v) = self.state_machine { - let status_watch = v.get_status_info_watch(); - if (*status_watch.borrow()).bootstrapped { - break; - } - trace!( - target: LOG_TARGET, - "Mempool sync still on hold, waiting for bootstrap to finish", - ); - tokio::time::delay_for(Duration::from_secs(1)).await; - } + loop { - futures::select! { - event = self.connectivity_events.select_next_some() => { - if let Ok(event) = event { - self.handle_connectivity_event(&*event).await; - } + tokio::select! { + Ok(event) = self.connectivity_events.recv() => { + self.handle_connectivity_event(event).await; }, - notif = self.protocol_notifier.select_next_some() => { + Some(notif) = self.protocol_notifier.recv() => { self.handle_protocol_notification(notif); } - - // protocol_notifier and connectivity_events are closed - complete => { - info!(target: LOG_TARGET, "Mempool protocol handler is shutting down"); - break; - } } } } - async fn handle_connectivity_event(&mut self, event: &ConnectivityEvent) { + async fn handle_connectivity_event(&mut self, event: ConnectivityEvent) { match event { // If this node is connecting to a peer ConnectivityEvent::PeerConnected(conn) if conn.direction().is_outbound() => { diff --git a/base_layer/core/src/mempool/sync_protocol/test.rs b/base_layer/core/src/mempool/sync_protocol/test.rs index dd77fe3c70..e68275170b 100644 --- a/base_layer/core/src/mempool/sync_protocol/test.rs +++ b/base_layer/core/src/mempool/sync_protocol/test.rs @@ -30,7 +30,7 @@ use crate::{ transactions::{helpers::create_tx, tari_amount::uT, transaction::Transaction}, validation::mocks::MockValidator, }; -use futures::{channel::mpsc, Sink, SinkExt, Stream, StreamExt}; +use futures::{Sink, SinkExt, Stream, StreamExt}; use std::{fmt, io, iter::repeat_with, sync::Arc}; use tari_comms::{ connectivity::{ConnectivityEvent, ConnectivityEventTx}, @@ -44,7 +44,10 @@ use tari_comms::{ BytesMut, }; use tari_crypto::tari_utilities::ByteArray; -use tokio::{sync::broadcast, task}; +use tokio::{ + sync::{broadcast, mpsc}, + task, +}; pub fn create_transactions(n: usize) -> Vec<Transaction> { repeat_with(|| { @@ -82,7 +85,6 @@ fn setup( protocol_notif_rx, connectivity_events_rx, mempool.clone(), - None, ); task::spawn(protocol.run()); @@ -90,7 +92,7 @@ fn setup( (protocol_notif_tx, connectivity_events_tx, mempool, transactions) } -#[tokio_macros::test_basic] +#[tokio::test] async fn empty_set() { let (_, connectivity_events_tx, mempool1, _) = setup(0); @@ -120,7 +122,7 @@ async fn empty_set() { assert_eq!(transactions.len(), 0); } -#[tokio_macros::test_basic] +#[tokio::test] async fn synchronise() { let (_, connectivity_events_tx, mempool1, transactions1) = setup(5); @@ -154,7 +156,7 @@ async fn synchronise() { assert!(transactions2.iter().all(|txn| transactions.contains(&txn))); } -#[tokio_macros::test_basic] +#[tokio::test] async fn duplicate_set() { let (_, connectivity_events_tx, mempool1, transactions1) = setup(2); @@ -189,9 +191,9 @@ async fn duplicate_set() { assert!(transactions2.iter().all(|txn| transactions.contains(&txn))); } -#[tokio_macros::test_basic] +#[tokio::test] async fn responder() { - let (mut protocol_notif, _, _, transactions1) = setup(2); + let (protocol_notif, _, _, transactions1) = setup(2); let node1 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); let node2 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -225,9 +227,9 @@ async fn responder() { // this. } -#[tokio_macros::test_basic] +#[tokio::test] async fn initiator_messages() { - let (mut protocol_notif, _, _, transactions1) = setup(2); + let (protocol_notif, _, _, transactions1) = setup(2); let node1 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -260,7 +262,7 @@ async fn initiator_messages() { assert_eq!(indexes.indexes, [0, 1]); } -#[tokio_macros::test_basic] +#[tokio::test] async fn responder_messages() { let (_, connectivity_events_tx, _, transactions1) = setup(1); diff --git a/base_layer/wallet/src/util/emoji.rs b/base_layer/core/src/transactions/emoji/emoji_id.rs similarity index 97% rename from base_layer/wallet/src/util/emoji.rs rename to base_layer/core/src/transactions/emoji/emoji_id.rs index 18ecdc174c..8234e41fad 100644 --- a/base_layer/wallet/src/util/emoji.rs +++ b/base_layer/core/src/transactions/emoji/emoji_id.rs @@ -20,12 +20,13 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::util::luhn::{checksum, is_valid}; +use super::luhn::{checksum, is_valid}; +use crate::transactions::types::PublicKey; +use lazy_static::lazy_static; use std::{ collections::HashMap, fmt::{Display, Error, Formatter}, }; -use tari_core::transactions::types::PublicKey; use tari_crypto::tari_utilities::{ hex::{Hex, HexError}, ByteArray, @@ -70,7 +71,7 @@ lazy_static! { /// # Example /// /// ``` -/// use tari_wallet::util::emoji::EmojiId; +/// use tari_core::transactions::emoji::EmojiId; /// /// assert!(EmojiId::is_valid("🐎🍴🌷🌟💻🐖🐩🐾🌟🐬🎧🐌🏦🐳🐎🐝🐢🔋👕🎸👿🍒🐓🎉💔🌹🏆🐬💡🎳🚦🍹🎒")); /// let eid = EmojiId::from_hex("70350e09c474809209824c6e6888707b7dd09959aa227343b5106382b856f73a").unwrap(); @@ -170,9 +171,7 @@ pub struct EmojiIdError; #[cfg(test)] mod test { - use crate::util::emoji::EmojiId; - use tari_core::transactions::types::PublicKey; - use tari_crypto::tari_utilities::hex::Hex; + use super::*; #[test] fn convert_key() { diff --git a/base_layer/wallet/src/util/luhn.rs b/base_layer/core/src/transactions/emoji/luhn.rs similarity index 100% rename from base_layer/wallet/src/util/luhn.rs rename to base_layer/core/src/transactions/emoji/luhn.rs diff --git a/base_layer/core/src/transactions/emoji/mod.rs b/base_layer/core/src/transactions/emoji/mod.rs new file mode 100644 index 0000000000..a9b0b1add0 --- /dev/null +++ b/base_layer/core/src/transactions/emoji/mod.rs @@ -0,0 +1,26 @@ +// Copyright 2019. The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +mod emoji_id; +pub use emoji_id::{emoji_set, EmojiId, EmojiIdError}; + +mod luhn; diff --git a/base_layer/core/src/transactions/mod.rs b/base_layer/core/src/transactions/mod.rs index 9d98bd394b..a027c4c4cf 100644 --- a/base_layer/core/src/transactions/mod.rs +++ b/base_layer/core/src/transactions/mod.rs @@ -11,10 +11,7 @@ pub use transaction_protocol::{recipient::ReceiverTransactionProtocol, sender::S #[macro_use] pub mod helpers; -#[cfg(any(feature = "base_node", feature = "transactions"))] -mod coinbase_builder; +pub mod emoji; -#[cfg(any(feature = "base_node", feature = "transactions"))] -pub use crate::transactions::coinbase_builder::CoinbaseBuildError; -#[cfg(any(feature = "base_node", feature = "transactions"))] -pub use crate::transactions::coinbase_builder::CoinbaseBuilder; +mod coinbase_builder; +pub use crate::transactions::coinbase_builder::{CoinbaseBuildError, CoinbaseBuilder}; diff --git a/base_layer/core/src/transactions/transaction_protocol/transaction_initializer.rs b/base_layer/core/src/transactions/transaction_protocol/transaction_initializer.rs index 0d5beb738d..46e8267c18 100644 --- a/base_layer/core/src/transactions/transaction_protocol/transaction_initializer.rs +++ b/base_layer/core/src/transactions/transaction_protocol/transaction_initializer.rs @@ -763,6 +763,7 @@ mod test { .with_output(output, p.sender_offset_private_key) .unwrap() .with_fee_per_gram(MicroTari(2)); + for _ in 0..MAX_TRANSACTION_INPUTS + 1 { let (utxo, input) = create_test_input(MicroTari(50), 0, &factories.commitment); builder.with_input(utxo, input); diff --git a/base_layer/core/tests/base_node_rpc.rs b/base_layer/core/tests/base_node_rpc.rs index 9b96512d47..453b9eef30 100644 --- a/base_layer/core/tests/base_node_rpc.rs +++ b/base_layer/core/tests/base_node_rpc.rs @@ -81,22 +81,19 @@ use tari_core::{ txn_schema, }; use tempfile::{tempdir, TempDir}; -use tokio::runtime::Runtime; -fn setup() -> ( +async fn setup() -> ( BaseNodeWalletRpcService<TempDatabase>, NodeInterfaces, RpcRequestMock, ConsensusManager, ChainBlock, UnblindedOutput, - Runtime, TempDir, ) { let network = NetworkConsensus::from(Network::LocalNet); let consensus_constants = network.create_consensus_constants(); let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let (block0, utxo0) = @@ -107,13 +104,14 @@ fn setup() -> ( let (mut base_node, _consensus_manager) = BaseNodeBuilder::new(network) .with_consensus_manager(consensus_manager.clone()) - .start(&mut runtime, temp_dir.path().to_str().unwrap()); + .start(temp_dir.path().to_str().unwrap()) + .await; base_node.mock_base_node_state_machine.publish_status(StatusInfo { bootstrapped: true, state_info: StateInfo::Listening(ListeningInfo::new(true)), }); - let request_mock = runtime.enter(|| RpcRequestMock::new(base_node.comms.peer_manager())); + let request_mock = RpcRequestMock::new(base_node.comms.peer_manager()); let service = BaseNodeWalletRpcService::new( base_node.blockchain_db.clone().into(), base_node.mempool_handle.clone(), @@ -126,16 +124,15 @@ fn setup() -> ( consensus_manager, block0, utxo0, - runtime, temp_dir, ) } -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn test_base_node_wallet_rpc() { +async fn test_base_node_wallet_rpc() { // Testing the submit_transaction() and transaction_query() rpc calls - let (service, mut base_node, request_mock, consensus_manager, block0, utxo0, mut runtime, _temp_dir) = setup(); + let (service, mut base_node, request_mock, consensus_manager, block0, utxo0, _temp_dir) = setup().await; let (txs1, utxos1) = schema_to_transaction(&[txn_schema!(from: vec![utxo0.clone()], to: vec![1 * T, 1 * T])]); let tx1 = (*txs1[0]).clone(); @@ -151,8 +148,8 @@ fn test_base_node_wallet_rpc() { // Query Tx1 let msg = SignatureProto::from(tx1_sig.clone()); let req = request_mock.request_with_context(Default::default(), msg); - let resp = - TxQueryResponse::try_from(runtime.block_on(service.transaction_query(req)).unwrap().into_message()).unwrap(); + let resp = service.transaction_query(req).await.unwrap().into_message(); + let resp = TxQueryResponse::try_from(resp).unwrap(); assert_eq!(resp.confirmations, 0); assert_eq!(resp.block_hash, None); @@ -162,13 +159,7 @@ fn test_base_node_wallet_rpc() { let msg = TransactionProto::from(tx2.clone()); let req = request_mock.request_with_context(Default::default(), msg); - let resp = TxSubmissionResponse::try_from( - runtime - .block_on(service.submit_transaction(req)) - .unwrap() - .into_message(), - ) - .unwrap(); + let resp = TxSubmissionResponse::try_from(service.submit_transaction(req).await.unwrap().into_message()).unwrap(); assert!(!resp.accepted); assert_eq!(resp.rejection_reason, TxSubmissionRejectionReason::Orphan); @@ -176,8 +167,7 @@ fn test_base_node_wallet_rpc() { // Query Tx2 to confirm it wasn't accepted let msg = SignatureProto::from(tx2_sig.clone()); let req = request_mock.request_with_context(Default::default(), msg); - let resp = - TxQueryResponse::try_from(runtime.block_on(service.transaction_query(req)).unwrap().into_message()).unwrap(); + let resp = TxQueryResponse::try_from(service.transaction_query(req).await.unwrap().into_message()).unwrap(); assert_eq!(resp.confirmations, 0); assert_eq!(resp.block_hash, None); @@ -189,24 +179,22 @@ fn test_base_node_wallet_rpc() { .prepare_block_merkle_roots(chain_block(&block0.block(), vec![tx1.clone()], &consensus_manager)) .unwrap(); - assert!(runtime - .block_on(base_node.local_nci.submit_block(block1.clone(), Broadcast::from(true))) - .is_ok()); + base_node + .local_nci + .submit_block(block1.clone(), Broadcast::from(true)) + .await + .unwrap(); // Check that subitting Tx2 will now be accepted let msg = TransactionProto::from(tx2); let req = request_mock.request_with_context(Default::default(), msg); - let resp = runtime - .block_on(service.submit_transaction(req)) - .unwrap() - .into_message(); + let resp = service.submit_transaction(req).await.unwrap().into_message(); assert!(resp.accepted); // Query Tx2 which should now be in the mempool let msg = SignatureProto::from(tx2_sig.clone()); let req = request_mock.request_with_context(Default::default(), msg); - let resp = - TxQueryResponse::try_from(runtime.block_on(service.transaction_query(req)).unwrap().into_message()).unwrap(); + let resp = TxQueryResponse::try_from(service.transaction_query(req).await.unwrap().into_message()).unwrap(); assert_eq!(resp.confirmations, 0); assert_eq!(resp.block_hash, None); @@ -215,13 +203,7 @@ fn test_base_node_wallet_rpc() { // Now if we submit Tx1 is should return as rejected as AlreadyMined as Tx1's kernel is present let msg = TransactionProto::from(tx1); let req = request_mock.request_with_context(Default::default(), msg); - let resp = TxSubmissionResponse::try_from( - runtime - .block_on(service.submit_transaction(req)) - .unwrap() - .into_message(), - ) - .unwrap(); + let resp = TxSubmissionResponse::try_from(service.submit_transaction(req).await.unwrap().into_message()).unwrap(); assert!(!resp.accepted); assert_eq!(resp.rejection_reason, TxSubmissionRejectionReason::AlreadyMined); @@ -233,13 +215,7 @@ fn test_base_node_wallet_rpc() { // Now if we submit Tx1 is should return as rejected as AlreadyMined let msg = TransactionProto::from(tx1b); let req = request_mock.request_with_context(Default::default(), msg); - let resp = TxSubmissionResponse::try_from( - runtime - .block_on(service.submit_transaction(req)) - .unwrap() - .into_message(), - ) - .unwrap(); + let resp = TxSubmissionResponse::try_from(service.submit_transaction(req).await.unwrap().into_message()).unwrap(); assert!(!resp.accepted); assert_eq!(resp.rejection_reason, TxSubmissionRejectionReason::DoubleSpend); @@ -253,15 +229,16 @@ fn test_base_node_wallet_rpc() { block2.header.output_mmr_size += 1; block2.header.kernel_mmr_size += 1; - runtime - .block_on(base_node.local_nci.submit_block(block2, Broadcast::from(true))) + base_node + .local_nci + .submit_block(block2, Broadcast::from(true)) + .await .unwrap(); // Query Tx1 which should be in block 1 with 1 confirmation let msg = SignatureProto::from(tx1_sig.clone()); let req = request_mock.request_with_context(Default::default(), msg); - let resp = - TxQueryResponse::try_from(runtime.block_on(service.transaction_query(req)).unwrap().into_message()).unwrap(); + let resp = TxQueryResponse::try_from(service.transaction_query(req).await.unwrap().into_message()).unwrap(); assert_eq!(resp.confirmations, 1); assert_eq!(resp.block_hash, Some(block1.hash())); @@ -271,10 +248,7 @@ fn test_base_node_wallet_rpc() { sigs: vec![SignatureProto::from(tx1_sig.clone()), SignatureProto::from(tx2_sig)], }; let req = request_mock.request_with_context(Default::default(), msg); - let response = runtime - .block_on(service.transaction_batch_query(req)) - .unwrap() - .into_message(); + let response = service.transaction_batch_query(req).await.unwrap().into_message(); for r in response.responses { let response = TxQueryBatchResponse::try_from(r).unwrap(); @@ -299,10 +273,7 @@ fn test_base_node_wallet_rpc() { let req = request_mock.request_with_context(Default::default(), msg); - let response = runtime - .block_on(service.fetch_matching_utxos(req)) - .unwrap() - .into_message(); + let response = service.fetch_matching_utxos(req).await.unwrap().into_message(); assert_eq!(response.outputs.len(), utxos1.len()); for output_proto in response.outputs.iter() { diff --git a/base_layer/core/tests/helpers/event_stream.rs b/base_layer/core/tests/helpers/event_stream.rs index b79b494900..5485467f4c 100644 --- a/base_layer/core/tests/helpers/event_stream.rs +++ b/base_layer/core/tests/helpers/event_stream.rs @@ -20,16 +20,17 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::{future, future::Either, FutureExt, Stream, StreamExt}; use std::time::Duration; +use tokio::{sync::broadcast, time}; #[allow(dead_code)] -pub async fn event_stream_next<TStream>(stream: &mut TStream, timeout: Duration) -> Option<TStream::Item> -where TStream: Stream + Unpin { - let either = future::select(stream.next(), tokio::time::delay_for(timeout).fuse()).await; - - match either { - Either::Left((v, _)) => v, - Either::Right(_) => None, +pub async fn event_stream_next<T: Clone>(stream: &mut broadcast::Receiver<T>, timeout: Duration) -> Option<T> { + tokio::select! { + item = stream.recv() => match item { + Ok(item) => Some(item), + Err(broadcast::error::RecvError::Closed) => None, + Err(broadcast::error::RecvError::Lagged(n)) => panic!("Lagged events channel {}", n), + }, + _ = time::sleep(timeout) => None } } diff --git a/base_layer/core/tests/helpers/mock_state_machine.rs b/base_layer/core/tests/helpers/mock_state_machine.rs index 7d49f93e85..0d4b6ce512 100644 --- a/base_layer/core/tests/helpers/mock_state_machine.rs +++ b/base_layer/core/tests/helpers/mock_state_machine.rs @@ -40,7 +40,7 @@ impl MockBaseNodeStateMachine { } pub fn publish_status(&mut self, status: StatusInfo) { - let _ = self.status_sender.broadcast(status); + let _ = self.status_sender.send(status); } pub fn get_initializer(&self) -> MockBaseNodeStateMachineInitializer { diff --git a/base_layer/core/tests/helpers/nodes.rs b/base_layer/core/tests/helpers/nodes.rs index ffe69c8034..06f2d5f8e0 100644 --- a/base_layer/core/tests/helpers/nodes.rs +++ b/base_layer/core/tests/helpers/nodes.rs @@ -21,9 +21,8 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::helpers::mock_state_machine::MockBaseNodeStateMachine; -use futures::Sink; use rand::rngs::OsRng; -use std::{error::Error, path::Path, sync::Arc, time::Duration}; +use std::{path::Path, sync::Arc, time::Duration}; use tari_common::configuration::Network; use tari_comms::{ peer_manager::{NodeIdentity, PeerFeatures}, @@ -60,13 +59,12 @@ use tari_core::{ }, }; use tari_p2p::{ - comms_connector::{pubsub_connector, InboundDomainConnector, PeerMessage}, + comms_connector::{pubsub_connector, InboundDomainConnector}, initialization::initialize_local_test_comms, services::liveness::{LivenessConfig, LivenessHandle, LivenessInitializer}, }; use tari_service_framework::{RegisterHandle, StackBuilder}; use tari_shutdown::Shutdown; -use tokio::runtime::Runtime; /// The NodeInterfaces is used as a container for providing access to all the services and interfaces of a single node. pub struct NodeInterfaces { @@ -91,7 +89,7 @@ pub struct NodeInterfaces { #[allow(dead_code)] impl NodeInterfaces { pub async fn shutdown(mut self) { - self.shutdown.trigger().unwrap(); + self.shutdown.trigger(); self.comms.wait_until_shutdown().await; } } @@ -182,7 +180,7 @@ impl BaseNodeBuilder { /// Build the test base node and start its services. #[allow(clippy::redundant_closure)] - pub fn start(self, runtime: &mut Runtime, data_path: &str) -> (NodeInterfaces, ConsensusManager) { + pub async fn start(self, data_path: &str) -> (NodeInterfaces, ConsensusManager) { let validators = self.validators.unwrap_or_else(|| { Validators::new( MockValidator::new(true), @@ -199,7 +197,6 @@ impl BaseNodeBuilder { let mempool = Mempool::new(self.mempool_config.unwrap_or_default(), Arc::new(mempool_validator)); let node_identity = self.node_identity.unwrap_or_else(|| random_node_identity()); let node_interfaces = setup_base_node_services( - runtime, node_identity, self.peers.unwrap_or_default(), blockchain_db, @@ -209,17 +206,19 @@ impl BaseNodeBuilder { self.mempool_service_config.unwrap_or_default(), self.liveness_service_config.unwrap_or_default(), data_path, - ); + ) + .await; (node_interfaces, consensus_manager) } } -#[allow(dead_code)] -pub fn wait_until_online(runtime: &mut Runtime, nodes: &[&NodeInterfaces]) { +pub async fn wait_until_online(nodes: &[&NodeInterfaces]) { for node in nodes { - runtime - .block_on(node.comms.connectivity().wait_for_connectivity(Duration::from_secs(10))) + node.comms + .connectivity() + .wait_for_connectivity(Duration::from_secs(10)) + .await .map_err(|err| format!("Node '{}' failed to go online {:?}", node.node_identity.node_id(), err)) .unwrap(); } @@ -227,10 +226,7 @@ pub fn wait_until_online(runtime: &mut Runtime, nodes: &[&NodeInterfaces]) { // Creates a network with two Base Nodes where each node in the network knows the other nodes in the network. #[allow(dead_code)] -pub fn create_network_with_2_base_nodes( - runtime: &mut Runtime, - data_path: &str, -) -> (NodeInterfaces, NodeInterfaces, ConsensusManager) { +pub async fn create_network_with_2_base_nodes(data_path: &str) -> (NodeInterfaces, NodeInterfaces, ConsensusManager) { let alice_node_identity = random_node_identity(); let bob_node_identity = random_node_identity(); @@ -238,22 +234,23 @@ pub fn create_network_with_2_base_nodes( let (alice_node, consensus_manager) = BaseNodeBuilder::new(network.into()) .with_node_identity(alice_node_identity.clone()) .with_peers(vec![bob_node_identity.clone()]) - .start(runtime, data_path); + .start(data_path) + .await; let (bob_node, consensus_manager) = BaseNodeBuilder::new(network.into()) .with_node_identity(bob_node_identity) .with_peers(vec![alice_node_identity]) .with_consensus_manager(consensus_manager) - .start(runtime, data_path); + .start(data_path) + .await; - wait_until_online(runtime, &[&alice_node, &bob_node]); + wait_until_online(&[&alice_node, &bob_node]).await; (alice_node, bob_node, consensus_manager) } // Creates a network with two Base Nodes where each node in the network knows the other nodes in the network. #[allow(dead_code)] -pub fn create_network_with_2_base_nodes_with_config<P: AsRef<Path>>( - runtime: &mut Runtime, +pub async fn create_network_with_2_base_nodes_with_config<P: AsRef<Path>>( base_node_service_config: BaseNodeServiceConfig, mempool_service_config: MempoolServiceConfig, liveness_service_config: LivenessConfig, @@ -269,7 +266,8 @@ pub fn create_network_with_2_base_nodes_with_config<P: AsRef<Path>>( .with_mempool_service_config(mempool_service_config) .with_liveness_service_config(liveness_service_config.clone()) .with_consensus_manager(consensus_manager) - .start(runtime, data_path.as_ref().join("alice").as_os_str().to_str().unwrap()); + .start(data_path.as_ref().join("alice").as_os_str().to_str().unwrap()) + .await; let (bob_node, consensus_manager) = BaseNodeBuilder::new(network.into()) .with_node_identity(bob_node_identity) .with_peers(vec![alice_node_identity]) @@ -277,35 +275,34 @@ pub fn create_network_with_2_base_nodes_with_config<P: AsRef<Path>>( .with_mempool_service_config(mempool_service_config) .with_liveness_service_config(liveness_service_config) .with_consensus_manager(consensus_manager) - .start(runtime, data_path.as_ref().join("bob").as_os_str().to_str().unwrap()); + .start(data_path.as_ref().join("bob").as_os_str().to_str().unwrap()) + .await; - wait_until_online(runtime, &[&alice_node, &bob_node]); + wait_until_online(&[&alice_node, &bob_node]).await; (alice_node, bob_node, consensus_manager) } // Creates a network with three Base Nodes where each node in the network knows the other nodes in the network. #[allow(dead_code)] -pub fn create_network_with_3_base_nodes( - runtime: &mut Runtime, +pub async fn create_network_with_3_base_nodes( data_path: &str, ) -> (NodeInterfaces, NodeInterfaces, NodeInterfaces, ConsensusManager) { let network = Network::LocalNet; let consensus_manager = ConsensusManagerBuilder::new(network).build(); create_network_with_3_base_nodes_with_config( - runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager, data_path, ) + .await } // Creates a network with three Base Nodes where each node in the network knows the other nodes in the network. #[allow(dead_code)] -pub fn create_network_with_3_base_nodes_with_config<P: AsRef<Path>>( - runtime: &mut Runtime, +pub async fn create_network_with_3_base_nodes_with_config<P: AsRef<Path>>( base_node_service_config: BaseNodeServiceConfig, mempool_service_config: MempoolServiceConfig, liveness_service_config: LivenessConfig, @@ -329,7 +326,8 @@ pub fn create_network_with_3_base_nodes_with_config<P: AsRef<Path>>( .with_mempool_service_config(mempool_service_config) .with_liveness_service_config(liveness_service_config.clone()) .with_consensus_manager(consensus_manager) - .start(runtime, data_path.as_ref().join("carol").as_os_str().to_str().unwrap()); + .start(data_path.as_ref().join("carol").as_os_str().to_str().unwrap()) + .await; let (bob_node, consensus_manager) = BaseNodeBuilder::new(network.into()) .with_node_identity(bob_node_identity.clone()) .with_peers(vec![carol_node_identity.clone()]) @@ -337,7 +335,8 @@ pub fn create_network_with_3_base_nodes_with_config<P: AsRef<Path>>( .with_mempool_service_config(mempool_service_config) .with_liveness_service_config(liveness_service_config.clone()) .with_consensus_manager(consensus_manager) - .start(runtime, data_path.as_ref().join("bob").as_os_str().to_str().unwrap()); + .start(data_path.as_ref().join("bob").as_os_str().to_str().unwrap()) + .await; let (alice_node, consensus_manager) = BaseNodeBuilder::new(network.into()) .with_node_identity(alice_node_identity) .with_peers(vec![bob_node_identity, carol_node_identity]) @@ -345,9 +344,10 @@ pub fn create_network_with_3_base_nodes_with_config<P: AsRef<Path>>( .with_mempool_service_config(mempool_service_config) .with_liveness_service_config(liveness_service_config) .with_consensus_manager(consensus_manager) - .start(runtime, data_path.as_ref().join("alice").as_os_str().to_str().unwrap()); + .start(data_path.as_ref().join("alice").as_os_str().to_str().unwrap()) + .await; - wait_until_online(runtime, &[&alice_node, &bob_node, &carol_node]); + wait_until_online(&[&alice_node, &bob_node, &carol_node]).await; (alice_node, bob_node, carol_node, consensus_manager) } @@ -365,16 +365,12 @@ pub fn random_node_identity() -> Arc<NodeIdentity> { // Helper function for starting the comms stack. #[allow(dead_code)] -async fn setup_comms_services<TSink>( +async fn setup_comms_services( node_identity: Arc<NodeIdentity>, peers: Vec<Arc<NodeIdentity>>, - publisher: InboundDomainConnector<TSink>, + publisher: InboundDomainConnector, data_path: &str, -) -> (CommsNode, Dht, MessagingEventSender, Shutdown) -where - TSink: Sink<Arc<PeerMessage>> + Clone + Unpin + Send + Sync + 'static, - TSink::Error: Error + Send + Sync, -{ +) -> (CommsNode, Dht, MessagingEventSender, Shutdown) { let peers = peers.into_iter().map(|p| p.to_peer()).collect(); let shutdown = Shutdown::new(); let (comms, dht, messaging_events) = initialize_local_test_comms( @@ -393,8 +389,7 @@ where // Helper function for starting the services of the Base node. #[allow(clippy::too_many_arguments)] -fn setup_base_node_services( - runtime: &mut Runtime, +async fn setup_base_node_services( node_identity: Arc<NodeIdentity>, peers: Vec<Arc<NodeIdentity>>, blockchain_db: BlockchainDatabase<TempDatabase>, @@ -405,14 +400,14 @@ fn setup_base_node_services( liveness_service_config: LivenessConfig, data_path: &str, ) -> NodeInterfaces { - let (publisher, subscription_factory) = pubsub_connector(runtime.handle().clone(), 100, 20); + let (publisher, subscription_factory) = pubsub_connector(100, 20); let subscription_factory = Arc::new(subscription_factory); let (comms, dht, messaging_events, shutdown) = - runtime.block_on(setup_comms_services(node_identity.clone(), peers, publisher, data_path)); + setup_comms_services(node_identity.clone(), peers, publisher, data_path).await; let mock_state_machine = MockBaseNodeStateMachine::new(); - let fut = StackBuilder::new(shutdown.to_signal()) + let handles = StackBuilder::new(shutdown.to_signal()) .add_initializer(RegisterHandle::new(dht)) .add_initializer(RegisterHandle::new(comms.connectivity())) .add_initializer(LivenessInitializer::new( @@ -433,9 +428,9 @@ fn setup_base_node_services( )) .add_initializer(mock_state_machine.get_initializer()) .add_initializer(ChainMetadataServiceInitializer) - .build(); - - let handles = runtime.block_on(fut).expect("Service initialization failed"); + .build() + .await + .unwrap(); let outbound_nci = handles.expect_handle::<OutboundNodeCommsInterface>(); let local_nci = handles.expect_handle::<LocalNodeCommsInterface>(); diff --git a/base_layer/core/tests/mempool.rs b/base_layer/core/tests/mempool.rs index 80e187a99a..973769f664 100644 --- a/base_layer/core/tests/mempool.rs +++ b/base_layer/core/tests/mempool.rs @@ -66,11 +66,10 @@ use tari_crypto::script; use tari_p2p::{services::liveness::LivenessConfig, tari_message::TariMessageType}; use tari_test_utils::async_assert_eventually; use tempfile::tempdir; -use tokio::runtime::Runtime; -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn test_insert_and_process_published_block() { +async fn test_insert_and_process_published_block() { let network = Network::LocalNet; let (mut store, mut blocks, mut outputs, consensus_manager) = create_new_blockchain(network); let mempool_validator = TxInputAndMaturityValidator::new(store.clone()); @@ -201,9 +200,9 @@ fn test_insert_and_process_published_block() { assert_eq!(stats.total_weight, 30); } -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn test_time_locked() { +async fn test_time_locked() { let network = Network::LocalNet; let (mut store, mut blocks, mut outputs, consensus_manager) = create_new_blockchain(network); let mempool_validator = TxInputAndMaturityValidator::new(store.clone()); @@ -245,9 +244,9 @@ fn test_time_locked() { assert_eq!(mempool.insert(tx2).unwrap(), TxStorageResponse::UnconfirmedPool); } -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn test_retrieve() { +async fn test_retrieve() { let network = Network::LocalNet; let (mut store, mut blocks, mut outputs, consensus_manager) = create_new_blockchain(network); let mempool_validator = TxInputAndMaturityValidator::new(store.clone()); @@ -331,9 +330,9 @@ fn test_retrieve() { assert!(retrieved_txs.contains(&tx2[1])); } -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn test_zero_conf() { +async fn test_zero_conf() { let network = Network::LocalNet; let (mut store, mut blocks, mut outputs, consensus_manager) = create_new_blockchain(network); let mempool_validator = TxInputAndMaturityValidator::new(store.clone()); @@ -631,9 +630,9 @@ fn test_zero_conf() { assert!(retrieved_txs.contains(&Arc::new(tx34))); } -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn test_reorg() { +async fn test_reorg() { let network = Network::LocalNet; let (mut db, mut blocks, mut outputs, consensus_manager) = create_new_blockchain(network); let mempool_validator = TxInputAndMaturityValidator::new(db.clone()); @@ -712,13 +711,12 @@ fn test_reorg() { mempool.process_reorg(vec![], vec![reorg_block4.into()]).unwrap(); } -#[test] // TODO: This test returns 0 in the unconfirmed pool, so might not catch errors. It should be updated to return better // data #[allow(clippy::identity_op)] -fn request_response_get_stats() { +#[tokio::test] +async fn request_response_get_stats() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; let consensus_constants = ConsensusConstantsBuilder::new(network) @@ -731,13 +729,13 @@ fn request_response_get_stats() { .with_block(block0) .build(); let (mut alice, bob, _consensus_manager) = create_network_with_2_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager, temp_dir.path(), - ); + ) + .await; // Create a tx spending the genesis output. Then create 2 orphan txs let (tx1, _, _) = spend_utxos(txn_schema!(from: vec![utxo], to: vec![2 * T, 2 * T, 2 * T])); @@ -759,21 +757,18 @@ fn request_response_get_stats() { assert_eq!(stats.reorg_txs, 0); assert_eq!(stats.total_weight, 0); - runtime.block_on(async { - // Alice will request mempool stats from Bob, and thus should be identical - let received_stats = alice.outbound_mp_interface.get_stats().await.unwrap(); - assert_eq!(received_stats.total_txs, 0); - assert_eq!(received_stats.unconfirmed_txs, 0); - assert_eq!(received_stats.reorg_txs, 0); - assert_eq!(received_stats.total_weight, 0); - }); + // Alice will request mempool stats from Bob, and thus should be identical + let received_stats = alice.outbound_mp_interface.get_stats().await.unwrap(); + assert_eq!(received_stats.total_txs, 0); + assert_eq!(received_stats.unconfirmed_txs, 0); + assert_eq!(received_stats.reorg_txs, 0); + assert_eq!(received_stats.total_weight, 0); } -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn request_response_get_tx_state_by_excess_sig() { +async fn request_response_get_tx_state_by_excess_sig() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; let consensus_constants = ConsensusConstantsBuilder::new(network) @@ -786,13 +781,13 @@ fn request_response_get_tx_state_by_excess_sig() { .with_block(block0) .build(); let (mut alice_node, bob_node, carol_node, _consensus_manager) = create_network_with_3_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager, temp_dir.path().to_str().unwrap(), - ); + ) + .await; let (tx, _, _) = spend_utxos(txn_schema!(from: vec![utxo.clone()], to: vec![2 * T, 2 * T, 2 * T])); let (unpublished_tx, _, _) = spend_utxos(txn_schema!(from: vec![utxo], to: vec![3 * T])); @@ -807,43 +802,40 @@ fn request_response_get_tx_state_by_excess_sig() { // Check that the transactions are in the expected pools. // Spending the coinbase utxo will be in the pending pool, because cb utxos have a maturity. // The orphan tx will be in the orphan pool, while the unadded tx won't be found - runtime.block_on(async { - let tx_excess_sig = tx.body.kernels()[0].excess_sig.clone(); - let unpublished_tx_excess_sig = unpublished_tx.body.kernels()[0].excess_sig.clone(); - let orphan_tx_excess_sig = orphan_tx.body.kernels()[0].excess_sig.clone(); - assert_eq!( - alice_node - .outbound_mp_interface - .get_tx_state_by_excess_sig(tx_excess_sig) - .await - .unwrap(), - TxStorageResponse::NotStored - ); - assert_eq!( - alice_node - .outbound_mp_interface - .get_tx_state_by_excess_sig(unpublished_tx_excess_sig) - .await - .unwrap(), - TxStorageResponse::NotStored - ); - assert_eq!( - alice_node - .outbound_mp_interface - .get_tx_state_by_excess_sig(orphan_tx_excess_sig) - .await - .unwrap(), - TxStorageResponse::NotStored - ); - }); + let tx_excess_sig = tx.body.kernels()[0].excess_sig.clone(); + let unpublished_tx_excess_sig = unpublished_tx.body.kernels()[0].excess_sig.clone(); + let orphan_tx_excess_sig = orphan_tx.body.kernels()[0].excess_sig.clone(); + assert_eq!( + alice_node + .outbound_mp_interface + .get_tx_state_by_excess_sig(tx_excess_sig) + .await + .unwrap(), + TxStorageResponse::NotStored + ); + assert_eq!( + alice_node + .outbound_mp_interface + .get_tx_state_by_excess_sig(unpublished_tx_excess_sig) + .await + .unwrap(), + TxStorageResponse::NotStored + ); + assert_eq!( + alice_node + .outbound_mp_interface + .get_tx_state_by_excess_sig(orphan_tx_excess_sig) + .await + .unwrap(), + TxStorageResponse::NotStored + ); } static EMISSION: [u64; 2] = [10, 10]; -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn receive_and_propagate_transaction() { +async fn receive_and_propagate_transaction() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; let consensus_constants = ConsensusConstantsBuilder::new(network) @@ -857,13 +849,13 @@ fn receive_and_propagate_transaction() { .build(); let (mut alice_node, mut bob_node, mut carol_node, _consensus_manager) = create_network_with_3_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager, temp_dir.path().to_str().unwrap(), - ); + ) + .await; alice_node.mock_base_node_state_machine.publish_status(StatusInfo { bootstrapped: true, state_info: StateInfo::Listening(ListeningInfo::new(true)), @@ -884,63 +876,61 @@ fn receive_and_propagate_transaction() { assert!(alice_node.mempool.insert(Arc::new(tx.clone())).is_ok()); assert!(alice_node.mempool.insert(Arc::new(orphan.clone())).is_ok()); - runtime.block_on(async { - alice_node - .outbound_message_service - .send_direct( - bob_node.node_identity.public_key().clone(), - OutboundDomainMessage::new(TariMessageType::NewTransaction, proto::types::Transaction::from(tx)), - ) - .await - .unwrap(); - alice_node - .outbound_message_service - .send_direct( - carol_node.node_identity.public_key().clone(), - OutboundDomainMessage::new(TariMessageType::NewTransaction, proto::types::Transaction::from(orphan)), - ) - .await - .unwrap(); + alice_node + .outbound_message_service + .send_direct( + bob_node.node_identity.public_key().clone(), + OutboundDomainMessage::new(TariMessageType::NewTransaction, proto::types::Transaction::from(tx)), + ) + .await + .unwrap(); + alice_node + .outbound_message_service + .send_direct( + carol_node.node_identity.public_key().clone(), + OutboundDomainMessage::new(TariMessageType::NewTransaction, proto::types::Transaction::from(orphan)), + ) + .await + .unwrap(); - async_assert_eventually!( - bob_node.mempool.has_tx_with_excess_sig(tx_excess_sig.clone()).unwrap(), - expect = TxStorageResponse::NotStored, - max_attempts = 20, - interval = Duration::from_millis(1000) - ); - async_assert_eventually!( - carol_node - .mempool - .has_tx_with_excess_sig(tx_excess_sig.clone()) - .unwrap(), - expect = TxStorageResponse::NotStored, - max_attempts = 10, - interval = Duration::from_millis(1000) - ); - // Carol got sent the orphan tx directly, so it will be in her mempool - async_assert_eventually!( - carol_node - .mempool - .has_tx_with_excess_sig(orphan_excess_sig.clone()) - .unwrap(), - expect = TxStorageResponse::NotStored, - max_attempts = 10, - interval = Duration::from_millis(1000) - ); - // It's difficult to test a negative here, but let's at least make sure that the orphan TX was not propagated - // by the time we check it - async_assert_eventually!( - bob_node - .mempool - .has_tx_with_excess_sig(orphan_excess_sig.clone()) - .unwrap(), - expect = TxStorageResponse::NotStored, - ); - }); + async_assert_eventually!( + bob_node.mempool.has_tx_with_excess_sig(tx_excess_sig.clone()).unwrap(), + expect = TxStorageResponse::NotStored, + max_attempts = 20, + interval = Duration::from_millis(1000) + ); + async_assert_eventually!( + carol_node + .mempool + .has_tx_with_excess_sig(tx_excess_sig.clone()) + .unwrap(), + expect = TxStorageResponse::NotStored, + max_attempts = 10, + interval = Duration::from_millis(1000) + ); + // Carol got sent the orphan tx directly, so it will be in her mempool + async_assert_eventually!( + carol_node + .mempool + .has_tx_with_excess_sig(orphan_excess_sig.clone()) + .unwrap(), + expect = TxStorageResponse::NotStored, + max_attempts = 10, + interval = Duration::from_millis(1000) + ); + // It's difficult to test a negative here, but let's at least make sure that the orphan TX was not propagated + // by the time we check it + async_assert_eventually!( + bob_node + .mempool + .has_tx_with_excess_sig(orphan_excess_sig.clone()) + .unwrap(), + expect = TxStorageResponse::NotStored, + ); } -#[test] -fn consensus_validation_large_tx() { +#[tokio::test] +async fn consensus_validation_large_tx() { let network = Network::LocalNet; // We dont want to compute the 19500 limit of local net, so we create smaller blocks let consensus_constants = ConsensusConstantsBuilder::new(network) @@ -1042,9 +1032,8 @@ fn consensus_validation_large_tx() { assert!(matches!(response, TxStorageResponse::NotStored)); } -#[test] -fn service_request_timeout() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn service_request_timeout() { let network = Network::LocalNet; let consensus_manager = ConsensusManager::builder(network).build(); let mempool_service_config = MempoolServiceConfig { @@ -1053,27 +1042,25 @@ fn service_request_timeout() { }; let temp_dir = tempdir().unwrap(); let (mut alice_node, bob_node, _consensus_manager) = create_network_with_2_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), mempool_service_config, LivenessConfig::default(), consensus_manager, temp_dir.path().to_str().unwrap(), - ); + ) + .await; - runtime.block_on(async { - bob_node.shutdown().await; + bob_node.shutdown().await; - match alice_node.outbound_mp_interface.get_stats().await { - Err(MempoolServiceError::RequestTimedOut) => {}, - _ => panic!(), - } - }); + match alice_node.outbound_mp_interface.get_stats().await { + Err(MempoolServiceError::RequestTimedOut) => {}, + _ => panic!(), + } } -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn block_event_and_reorg_event_handling() { +async fn block_event_and_reorg_event_handling() { // This test creates 2 nodes Alice and Bob // Then creates 2 chains B1 -> B2A (diff 1) and B1 -> B2B (diff 10) // There are 5 transactions created @@ -1086,7 +1073,6 @@ fn block_event_and_reorg_event_handling() { let network = Network::LocalNet; let consensus_constants = NetworkConsensus::from(network).create_consensus_constants(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let (block0, utxos0) = create_genesis_block_with_coinbase_value(&factories, 100_000_000.into(), &consensus_constants[0]); @@ -1095,13 +1081,13 @@ fn block_event_and_reorg_event_handling() { .with_block(block0.clone()) .build(); let (mut alice, mut bob, consensus_manager) = create_network_with_2_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager, temp_dir.path().to_str().unwrap(), - ); + ) + .await; alice.mock_base_node_state_machine.publish_status(StatusInfo { bootstrapped: true, state_info: StateInfo::Listening(ListeningInfo::new(true)), @@ -1135,88 +1121,86 @@ fn block_event_and_reorg_event_handling() { .prepare_block_merkle_roots(chain_block(block0.block(), vec![], &consensus_manager)) .unwrap(); - runtime.block_on(async { - // Add one empty block, so the coinbase UTXO is no longer time-locked. - assert!(bob - .local_nci - .submit_block(empty_block.clone(), Broadcast::from(true)) - .await - .is_ok()); - assert!(alice - .local_nci - .submit_block(empty_block.clone(), Broadcast::from(true)) - .await - .is_ok()); - alice.mempool.insert(Arc::new(tx1.clone())).unwrap(); - bob.mempool.insert(Arc::new(tx1.clone())).unwrap(); - let mut block1 = bob - .blockchain_db - .prepare_block_merkle_roots(chain_block(&empty_block, vec![tx1], &consensus_manager)) - .unwrap(); - find_header_with_achieved_difficulty(&mut block1.header, Difficulty::from(1)); - // Add Block1 - tx1 will be moved to the ReorgPool. - assert!(bob - .local_nci - .submit_block(block1.clone(), Broadcast::from(true)) - .await - .is_ok()); - async_assert_eventually!( - alice.mempool.has_tx_with_excess_sig(tx1_excess_sig.clone()).unwrap(), - expect = TxStorageResponse::ReorgPool, - max_attempts = 20, - interval = Duration::from_millis(1000) - ); - alice.mempool.insert(Arc::new(tx2a.clone())).unwrap(); - alice.mempool.insert(Arc::new(tx3a.clone())).unwrap(); - alice.mempool.insert(Arc::new(tx2b.clone())).unwrap(); - alice.mempool.insert(Arc::new(tx3b.clone())).unwrap(); - bob.mempool.insert(Arc::new(tx2a.clone())).unwrap(); - bob.mempool.insert(Arc::new(tx3a.clone())).unwrap(); - bob.mempool.insert(Arc::new(tx2b.clone())).unwrap(); - bob.mempool.insert(Arc::new(tx3b.clone())).unwrap(); - - let mut block2a = bob - .blockchain_db - .prepare_block_merkle_roots(chain_block(&block1, vec![tx2a, tx3a], &consensus_manager)) - .unwrap(); - find_header_with_achieved_difficulty(&mut block2a.header, Difficulty::from(1)); - // Block2b also builds on Block1 but has a stronger PoW - let mut block2b = bob - .blockchain_db - .prepare_block_merkle_roots(chain_block(&block1, vec![tx2b, tx3b], &consensus_manager)) - .unwrap(); - find_header_with_achieved_difficulty(&mut block2b.header, Difficulty::from(10)); - - // Add Block2a - tx2b and tx3b will be discarded as double spends. - assert!(bob - .local_nci - .submit_block(block2a.clone(), Broadcast::from(true)) - .await - .is_ok()); - - async_assert_eventually!( - bob.mempool.has_tx_with_excess_sig(tx2a_excess_sig.clone()).unwrap(), - expect = TxStorageResponse::ReorgPool, - max_attempts = 20, - interval = Duration::from_millis(1000) - ); - async_assert_eventually!( - alice.mempool.has_tx_with_excess_sig(tx2a_excess_sig.clone()).unwrap(), - expect = TxStorageResponse::ReorgPool, - max_attempts = 20, - interval = Duration::from_millis(1000) - ); - assert_eq!( - alice.mempool.has_tx_with_excess_sig(tx3a_excess_sig.clone()).unwrap(), - TxStorageResponse::ReorgPool - ); - assert_eq!( - alice.mempool.has_tx_with_excess_sig(tx2b_excess_sig.clone()).unwrap(), - TxStorageResponse::ReorgPool - ); - assert_eq!( - alice.mempool.has_tx_with_excess_sig(tx3b_excess_sig.clone()).unwrap(), - TxStorageResponse::ReorgPool - ); - }); + // Add one empty block, so the coinbase UTXO is no longer time-locked. + assert!(bob + .local_nci + .submit_block(empty_block.clone(), Broadcast::from(true)) + .await + .is_ok()); + assert!(alice + .local_nci + .submit_block(empty_block.clone(), Broadcast::from(true)) + .await + .is_ok()); + alice.mempool.insert(Arc::new(tx1.clone())).unwrap(); + bob.mempool.insert(Arc::new(tx1.clone())).unwrap(); + let mut block1 = bob + .blockchain_db + .prepare_block_merkle_roots(chain_block(&empty_block, vec![tx1], &consensus_manager)) + .unwrap(); + find_header_with_achieved_difficulty(&mut block1.header, Difficulty::from(1)); + // Add Block1 - tx1 will be moved to the ReorgPool. + assert!(bob + .local_nci + .submit_block(block1.clone(), Broadcast::from(true)) + .await + .is_ok()); + async_assert_eventually!( + alice.mempool.has_tx_with_excess_sig(tx1_excess_sig.clone()).unwrap(), + expect = TxStorageResponse::ReorgPool, + max_attempts = 20, + interval = Duration::from_millis(1000) + ); + alice.mempool.insert(Arc::new(tx2a.clone())).unwrap(); + alice.mempool.insert(Arc::new(tx3a.clone())).unwrap(); + alice.mempool.insert(Arc::new(tx2b.clone())).unwrap(); + alice.mempool.insert(Arc::new(tx3b.clone())).unwrap(); + bob.mempool.insert(Arc::new(tx2a.clone())).unwrap(); + bob.mempool.insert(Arc::new(tx3a.clone())).unwrap(); + bob.mempool.insert(Arc::new(tx2b.clone())).unwrap(); + bob.mempool.insert(Arc::new(tx3b.clone())).unwrap(); + + let mut block2a = bob + .blockchain_db + .prepare_block_merkle_roots(chain_block(&block1, vec![tx2a, tx3a], &consensus_manager)) + .unwrap(); + find_header_with_achieved_difficulty(&mut block2a.header, Difficulty::from(1)); + // Block2b also builds on Block1 but has a stronger PoW + let mut block2b = bob + .blockchain_db + .prepare_block_merkle_roots(chain_block(&block1, vec![tx2b, tx3b], &consensus_manager)) + .unwrap(); + find_header_with_achieved_difficulty(&mut block2b.header, Difficulty::from(10)); + + // Add Block2a - tx2b and tx3b will be discarded as double spends. + assert!(bob + .local_nci + .submit_block(block2a.clone(), Broadcast::from(true)) + .await + .is_ok()); + + async_assert_eventually!( + bob.mempool.has_tx_with_excess_sig(tx2a_excess_sig.clone()).unwrap(), + expect = TxStorageResponse::ReorgPool, + max_attempts = 20, + interval = Duration::from_millis(1000) + ); + async_assert_eventually!( + alice.mempool.has_tx_with_excess_sig(tx2a_excess_sig.clone()).unwrap(), + expect = TxStorageResponse::ReorgPool, + max_attempts = 20, + interval = Duration::from_millis(1000) + ); + assert_eq!( + alice.mempool.has_tx_with_excess_sig(tx3a_excess_sig.clone()).unwrap(), + TxStorageResponse::ReorgPool + ); + assert_eq!( + alice.mempool.has_tx_with_excess_sig(tx2b_excess_sig.clone()).unwrap(), + TxStorageResponse::ReorgPool + ); + assert_eq!( + alice.mempool.has_tx_with_excess_sig(tx3b_excess_sig.clone()).unwrap(), + TxStorageResponse::ReorgPool + ); } diff --git a/base_layer/core/tests/node_comms_interface.rs b/base_layer/core/tests/node_comms_interface.rs index 532d102bbe..97851f0d17 100644 --- a/base_layer/core/tests/node_comms_interface.rs +++ b/base_layer/core/tests/node_comms_interface.rs @@ -22,7 +22,7 @@ #[allow(dead_code)] mod helpers; -use futures::{channel::mpsc, StreamExt}; +use futures::StreamExt; use helpers::block_builders::append_block; use std::sync::Arc; use tari_common::configuration::Network; @@ -55,7 +55,7 @@ use tari_crypto::{ tari_utilities::hash::Hashable, }; use tari_service_framework::{reply_channel, reply_channel::Receiver}; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc}; // use crate::helpers::database::create_test_db; async fn test_request_responder( @@ -71,10 +71,10 @@ fn new_mempool() -> Mempool { Mempool::new(MempoolConfig::default(), Arc::new(mempool_validator)) } -#[tokio_macros::test] +#[tokio::test] async fn outbound_get_metadata() { let (request_sender, mut request_receiver) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let mut outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let metadata = ChainMetadata::new(5, vec![0u8], 3, 0, 5); @@ -86,7 +86,7 @@ async fn outbound_get_metadata() { assert_eq!(received_metadata.unwrap(), metadata); } -#[tokio_macros::test] +#[tokio::test] async fn inbound_get_metadata() { let store = create_test_blockchain_db(); let mempool = new_mempool(); @@ -95,7 +95,7 @@ async fn inbound_get_metadata() { let consensus_manager = ConsensusManager::builder(network).build(); let (block_event_sender, _) = broadcast::channel(50); let (request_sender, _) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender.clone()); let inbound_nch = InboundNodeCommsHandlers::new( block_event_sender, @@ -117,7 +117,7 @@ async fn inbound_get_metadata() { } } -#[tokio_macros::test] +#[tokio::test] async fn inbound_fetch_kernel_by_excess_sig() { let store = create_test_blockchain_db(); let mempool = new_mempool(); @@ -126,7 +126,7 @@ async fn inbound_fetch_kernel_by_excess_sig() { let consensus_manager = ConsensusManager::builder(network).build(); let (block_event_sender, _) = broadcast::channel(50); let (request_sender, _) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender.clone()); let inbound_nch = InboundNodeCommsHandlers::new( block_event_sender, @@ -149,10 +149,10 @@ async fn inbound_fetch_kernel_by_excess_sig() { } } -#[tokio_macros::test] +#[tokio::test] async fn outbound_fetch_headers() { let (request_sender, mut request_receiver) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let mut outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let mut header = BlockHeader::new(0); @@ -167,7 +167,7 @@ async fn outbound_fetch_headers() { assert_eq!(received_headers[0], header); } -#[tokio_macros::test] +#[tokio::test] async fn inbound_fetch_headers() { let store = create_test_blockchain_db(); let mempool = new_mempool(); @@ -175,7 +175,7 @@ async fn inbound_fetch_headers() { let consensus_manager = ConsensusManager::builder(network).build(); let (block_event_sender, _) = broadcast::channel(50); let (request_sender, _) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let inbound_nch = InboundNodeCommsHandlers::new( block_event_sender, @@ -197,11 +197,11 @@ async fn inbound_fetch_headers() { } } -#[tokio_macros::test] +#[tokio::test] async fn outbound_fetch_utxos() { let factories = CryptoFactories::default(); let (request_sender, mut request_receiver) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let mut outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let (utxo, _, _) = create_utxo( @@ -221,7 +221,7 @@ async fn outbound_fetch_utxos() { assert_eq!(received_utxos[0], utxo); } -#[tokio_macros::test] +#[tokio::test] async fn inbound_fetch_utxos() { let factories = CryptoFactories::default(); @@ -231,7 +231,7 @@ async fn inbound_fetch_utxos() { let consensus_manager = ConsensusManager::builder(network).build(); let (block_event_sender, _) = broadcast::channel(50); let (request_sender, _) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let inbound_nch = InboundNodeCommsHandlers::new( block_event_sender, @@ -264,11 +264,11 @@ async fn inbound_fetch_utxos() { } } -#[tokio_macros::test] +#[tokio::test] async fn outbound_fetch_txos() { let factories = CryptoFactories::default(); let (request_sender, mut request_receiver) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let mut outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let (txo1, _, _) = create_utxo( @@ -296,7 +296,7 @@ async fn outbound_fetch_txos() { assert_eq!(received_txos[1], txo2); } -#[tokio_macros::test] +#[tokio::test] async fn inbound_fetch_txos() { let factories = CryptoFactories::default(); let store = create_test_blockchain_db(); @@ -305,7 +305,7 @@ async fn inbound_fetch_txos() { let network = Network::LocalNet; let consensus_manager = ConsensusManager::builder(network).build(); let (request_sender, _) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let inbound_nch = InboundNodeCommsHandlers::new( block_event_sender, @@ -366,10 +366,10 @@ async fn inbound_fetch_txos() { } } -#[tokio_macros::test] +#[tokio::test] async fn outbound_fetch_blocks() { let (request_sender, mut request_receiver) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let mut outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let network = Network::LocalNet; let consensus_constants = NetworkConsensus::from(network).create_consensus_constants(); @@ -385,7 +385,7 @@ async fn outbound_fetch_blocks() { assert_eq!(received_blocks[0], block); } -#[tokio_macros::test] +#[tokio::test] async fn inbound_fetch_blocks() { let store = create_test_blockchain_db(); let mempool = new_mempool(); @@ -393,7 +393,7 @@ async fn inbound_fetch_blocks() { let network = Network::LocalNet; let consensus_manager = ConsensusManager::builder(network).build(); let (request_sender, _) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let inbound_nch = InboundNodeCommsHandlers::new( block_event_sender, @@ -415,7 +415,7 @@ async fn inbound_fetch_blocks() { } } -#[tokio_macros::test] +#[tokio::test] // Test needs to be updated to new pruned structure. async fn inbound_fetch_blocks_before_horizon_height() { let factories = CryptoFactories::default(); @@ -437,7 +437,7 @@ async fn inbound_fetch_blocks_before_horizon_height() { let mempool = Mempool::new(MempoolConfig::default(), Arc::new(mempool_validator)); let (block_event_sender, _) = broadcast::channel(50); let (request_sender, _) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let inbound_nch = InboundNodeCommsHandlers::new( block_event_sender, diff --git a/base_layer/core/tests/node_service.rs b/base_layer/core/tests/node_service.rs index 14f277f19b..d806b9b646 100644 --- a/base_layer/core/tests/node_service.rs +++ b/base_layer/core/tests/node_service.rs @@ -23,7 +23,6 @@ #[allow(dead_code)] mod helpers; use crate::helpers::block_builders::{construct_chained_blocks, create_coinbase}; -use futures::join; use helpers::{ block_builders::{ append_block, @@ -72,11 +71,9 @@ use tari_crypto::tari_utilities::hash::Hashable; use tari_p2p::services::liveness::LivenessConfig; use tari_test_utils::unpack_enum; use tempfile::tempdir; -use tokio::runtime::Runtime; -#[test] -fn request_response_get_metadata() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn request_response_get_metadata() { let factories = CryptoFactories::default(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; @@ -89,27 +86,24 @@ fn request_response_get_metadata() { .with_block(block0) .build(); let (mut alice_node, bob_node, carol_node, _consensus_manager) = create_network_with_3_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager, temp_dir.path().to_str().unwrap(), - ); + ) + .await; - runtime.block_on(async { - let received_metadata = alice_node.outbound_nci.get_metadata().await.unwrap(); - assert_eq!(received_metadata.height_of_longest_chain(), 0); + let received_metadata = alice_node.outbound_nci.get_metadata().await.unwrap(); + assert_eq!(received_metadata.height_of_longest_chain(), 0); - alice_node.shutdown().await; - bob_node.shutdown().await; - carol_node.shutdown().await; - }); + alice_node.shutdown().await; + bob_node.shutdown().await; + carol_node.shutdown().await; } -#[test] -fn request_and_response_fetch_blocks() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn request_and_response_fetch_blocks() { let factories = CryptoFactories::default(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; @@ -122,13 +116,13 @@ fn request_and_response_fetch_blocks() { .with_block(block0.clone()) .build(); let (mut alice_node, mut bob_node, carol_node, _) = create_network_with_3_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager.clone(), temp_dir.path().to_str().unwrap(), - ); + ) + .await; let mut blocks = vec![block0]; let db = &mut bob_node.blockchain_db; @@ -147,26 +141,23 @@ fn request_and_response_fetch_blocks() { .unwrap() .assert_added(); - runtime.block_on(async { - let received_blocks = alice_node.outbound_nci.fetch_blocks(vec![0]).await.unwrap(); - assert_eq!(received_blocks.len(), 1); - assert_eq!(received_blocks[0].block(), blocks[0].block()); + let received_blocks = alice_node.outbound_nci.fetch_blocks(vec![0]).await.unwrap(); + assert_eq!(received_blocks.len(), 1); + assert_eq!(received_blocks[0].block(), blocks[0].block()); - let received_blocks = alice_node.outbound_nci.fetch_blocks(vec![0, 1]).await.unwrap(); - assert_eq!(received_blocks.len(), 2); - assert_ne!(*received_blocks[0].block(), *received_blocks[1].block()); - assert!(received_blocks[0].block() == blocks[0].block() || received_blocks[1].block() == blocks[0].block()); - assert!(received_blocks[0].block() == blocks[1].block() || received_blocks[1].block() == blocks[1].block()); + let received_blocks = alice_node.outbound_nci.fetch_blocks(vec![0, 1]).await.unwrap(); + assert_eq!(received_blocks.len(), 2); + assert_ne!(*received_blocks[0].block(), *received_blocks[1].block()); + assert!(received_blocks[0].block() == blocks[0].block() || received_blocks[1].block() == blocks[0].block()); + assert!(received_blocks[0].block() == blocks[1].block() || received_blocks[1].block() == blocks[1].block()); - alice_node.shutdown().await; - bob_node.shutdown().await; - carol_node.shutdown().await; - }); + alice_node.shutdown().await; + bob_node.shutdown().await; + carol_node.shutdown().await; } -#[test] -fn request_and_response_fetch_blocks_with_hashes() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn request_and_response_fetch_blocks_with_hashes() { let factories = CryptoFactories::default(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; @@ -179,13 +170,13 @@ fn request_and_response_fetch_blocks_with_hashes() { .with_block(block0.clone()) .build(); let (mut alice_node, mut bob_node, carol_node, _) = create_network_with_3_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager.clone(), temp_dir.path().to_str().unwrap(), - ); + ) + .await; let mut blocks = vec![block0]; let db = &mut bob_node.blockchain_db; @@ -206,34 +197,31 @@ fn request_and_response_fetch_blocks_with_hashes() { .unwrap() .assert_added(); - runtime.block_on(async { - let received_blocks = alice_node - .outbound_nci - .fetch_blocks_with_hashes(vec![block0_hash.clone()]) - .await - .unwrap(); - assert_eq!(received_blocks.len(), 1); - assert_eq!(received_blocks[0].block(), blocks[0].block()); + let received_blocks = alice_node + .outbound_nci + .fetch_blocks_with_hashes(vec![block0_hash.clone()]) + .await + .unwrap(); + assert_eq!(received_blocks.len(), 1); + assert_eq!(received_blocks[0].block(), blocks[0].block()); - let received_blocks = alice_node - .outbound_nci - .fetch_blocks_with_hashes(vec![block0_hash.clone(), block1_hash.clone()]) - .await - .unwrap(); - assert_eq!(received_blocks.len(), 2); - assert_ne!(received_blocks[0], received_blocks[1]); - assert!(received_blocks[0].block() == blocks[0].block() || received_blocks[1].block() == blocks[0].block()); - assert!(received_blocks[0].block() == blocks[1].block() || received_blocks[1].block() == blocks[1].block()); - - alice_node.shutdown().await; - bob_node.shutdown().await; - carol_node.shutdown().await; - }); + let received_blocks = alice_node + .outbound_nci + .fetch_blocks_with_hashes(vec![block0_hash.clone(), block1_hash.clone()]) + .await + .unwrap(); + assert_eq!(received_blocks.len(), 2); + assert_ne!(received_blocks[0], received_blocks[1]); + assert!(received_blocks[0].block() == blocks[0].block() || received_blocks[1].block() == blocks[0].block()); + assert!(received_blocks[0].block() == blocks[1].block() || received_blocks[1].block() == blocks[1].block()); + + alice_node.shutdown().await; + bob_node.shutdown().await; + carol_node.shutdown().await; } -#[test] -fn propagate_and_forward_many_valid_blocks() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn propagate_and_forward_many_valid_blocks() { let temp_dir = tempdir().unwrap(); let factories = CryptoFactories::default(); // Alice will propagate a number of block hashes to bob, bob will receive it, request the full block, verify and @@ -261,24 +249,28 @@ fn propagate_and_forward_many_valid_blocks() { let (mut alice_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(alice_node_identity.clone()) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("alice").to_str().unwrap()); + .start(temp_dir.path().join("alice").to_str().unwrap()) + .await; let (mut bob_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(bob_node_identity.clone()) .with_peers(vec![alice_node_identity]) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("bob").to_str().unwrap()); + .start(temp_dir.path().join("bob").to_str().unwrap()) + .await; let (mut carol_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(carol_node_identity.clone()) .with_peers(vec![bob_node_identity.clone()]) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("carol").to_str().unwrap()); + .start(temp_dir.path().join("carol").to_str().unwrap()) + .await; let (mut dan_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(dan_node_identity) .with_peers(vec![carol_node_identity, bob_node_identity]) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("dan").to_str().unwrap()); + .start(temp_dir.path().join("dan").to_str().unwrap()) + .await; - wait_until_online(&mut runtime, &[&alice_node, &bob_node, &carol_node, &dan_node]); + wait_until_online(&[&alice_node, &bob_node, &carol_node, &dan_node]).await; alice_node.mock_base_node_state_machine.publish_status(StatusInfo { bootstrapped: true, state_info: StateInfo::Listening(ListeningInfo::new(true)), @@ -302,56 +294,49 @@ fn propagate_and_forward_many_valid_blocks() { let blocks = construct_chained_blocks(&alice_node.blockchain_db, block0, &rules, 5); - runtime.block_on(async { - for block in &blocks { - alice_node - .outbound_nci - .propagate_block(NewBlock::from(block.block()), vec![]) - .await - .unwrap(); - - let bob_block_event_fut = event_stream_next(&mut bob_block_event_stream, Duration::from_millis(20000)); - let carol_block_event_fut = event_stream_next(&mut carol_block_event_stream, Duration::from_millis(20000)); - let dan_block_event_fut = event_stream_next(&mut dan_block_event_stream, Duration::from_millis(20000)); - let (bob_block_event, carol_block_event, dan_block_event) = - join!(bob_block_event_fut, carol_block_event_fut, dan_block_event_fut); - let block_hash = block.hash(); - - if let BlockEvent::ValidBlockAdded(received_block, _, _) = &*bob_block_event.unwrap().unwrap() { - assert_eq!(&received_block.hash(), block_hash); - } else { - panic!("Bob's node did not receive and validate the expected block"); - } - if let BlockEvent::ValidBlockAdded(received_block, _block_add_result, _) = - &*carol_block_event.unwrap().unwrap() - { - assert_eq!(&received_block.hash(), block_hash); - } else { - panic!("Carol's node did not receive and validate the expected block"); - } - if let BlockEvent::ValidBlockAdded(received_block, _block_add_result, _) = - &*dan_block_event.unwrap().unwrap() - { - assert_eq!(&received_block.hash(), block_hash); - } else { - panic!("Dan's node did not receive and validate the expected block"); - } + for block in &blocks { + alice_node + .outbound_nci + .propagate_block(NewBlock::from(block.block()), vec![]) + .await + .unwrap(); + + let bob_block_event_fut = event_stream_next(&mut bob_block_event_stream, Duration::from_millis(20000)); + let carol_block_event_fut = event_stream_next(&mut carol_block_event_stream, Duration::from_millis(20000)); + let dan_block_event_fut = event_stream_next(&mut dan_block_event_stream, Duration::from_millis(20000)); + let (bob_block_event, carol_block_event, dan_block_event) = + tokio::join!(bob_block_event_fut, carol_block_event_fut, dan_block_event_fut); + let block_hash = block.hash(); + + if let BlockEvent::ValidBlockAdded(received_block, _, _) = &*bob_block_event.unwrap() { + assert_eq!(&received_block.hash(), block_hash); + } else { + panic!("Bob's node did not receive and validate the expected block"); } + if let BlockEvent::ValidBlockAdded(received_block, _block_add_result, _) = &*carol_block_event.unwrap() { + assert_eq!(&received_block.hash(), block_hash); + } else { + panic!("Carol's node did not receive and validate the expected block"); + } + if let BlockEvent::ValidBlockAdded(received_block, _block_add_result, _) = &*dan_block_event.unwrap() { + assert_eq!(&received_block.hash(), block_hash); + } else { + panic!("Dan's node did not receive and validate the expected block"); + } + } - alice_node.shutdown().await; - bob_node.shutdown().await; - carol_node.shutdown().await; - dan_node.shutdown().await; - }); + alice_node.shutdown().await; + bob_node.shutdown().await; + carol_node.shutdown().await; + dan_node.shutdown().await; } static EMISSION: [u64; 2] = [10, 10]; -#[test] -fn propagate_and_forward_invalid_block_hash() { +#[tokio::test] +async fn propagate_and_forward_invalid_block_hash() { // Alice will propagate a "made up" block hash to Bob, Bob will request the block from Alice. Alice will not be able // to provide the block and so Bob will not propagate the hash further to Carol. // alice -> bob -> carol - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let factories = CryptoFactories::default(); @@ -370,19 +355,22 @@ fn propagate_and_forward_invalid_block_hash() { let (mut alice_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(alice_node_identity.clone()) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("alice").to_str().unwrap()); + .start(temp_dir.path().join("alice").to_str().unwrap()) + .await; let (mut bob_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(bob_node_identity.clone()) .with_peers(vec![alice_node_identity]) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("bob").to_str().unwrap()); + .start(temp_dir.path().join("bob").to_str().unwrap()) + .await; let (mut carol_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(carol_node_identity) .with_peers(vec![bob_node_identity]) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("carol").to_str().unwrap()); + .start(temp_dir.path().join("carol").to_str().unwrap()) + .await; - wait_until_online(&mut runtime, &[&alice_node, &bob_node, &carol_node]); + wait_until_online(&[&alice_node, &bob_node, &carol_node]).await; alice_node.mock_base_node_state_machine.publish_status(StatusInfo { bootstrapped: true, state_info: StateInfo::Listening(ListeningInfo::new(true)), @@ -409,42 +397,37 @@ fn propagate_and_forward_invalid_block_hash() { let mut bob_message_events = bob_node.messaging_events.subscribe(); let mut carol_message_events = carol_node.messaging_events.subscribe(); - runtime.block_on(async { - alice_node - .outbound_nci - .propagate_block(NewBlock::from(block1.block()), vec![]) - .await - .unwrap(); + alice_node + .outbound_nci + .propagate_block(NewBlock::from(block1.block()), vec![]) + .await + .unwrap(); - // Alice propagated to Bob - // Bob received the invalid hash - let msg_event = event_stream_next(&mut bob_message_events, Duration::from_secs(10)) - .await - .unwrap() - .unwrap(); - unpack_enum!(MessagingEvent::MessageReceived(_a, _b) = &*msg_event); - // Sent the request for the block to Alice - // Bob received a response from Alice - let msg_event = event_stream_next(&mut bob_message_events, Duration::from_secs(10)) - .await - .unwrap() - .unwrap(); - unpack_enum!(MessagingEvent::MessageReceived(node_id, _a) = &*msg_event); - assert_eq!(&*node_id, alice_node.node_identity.node_id()); - // Checking a negative: Bob should not have propagated this hash to Carol. If Bob does, this assertion will be - // flaky. - let msg_event = event_stream_next(&mut carol_message_events, Duration::from_millis(500)).await; - assert!(msg_event.is_none()); - - alice_node.shutdown().await; - bob_node.shutdown().await; - carol_node.shutdown().await; - }); + // Alice propagated to Bob + // Bob received the invalid hash + let msg_event = event_stream_next(&mut bob_message_events, Duration::from_secs(10)) + .await + .unwrap(); + unpack_enum!(MessagingEvent::MessageReceived(_a, _b) = &*msg_event); + // Sent the request for the block to Alice + // Bob received a response from Alice + let msg_event = event_stream_next(&mut bob_message_events, Duration::from_secs(10)) + .await + .unwrap(); + unpack_enum!(MessagingEvent::MessageReceived(node_id, _a) = &*msg_event); + assert_eq!(&*node_id, alice_node.node_identity.node_id()); + // Checking a negative: Bob should not have propagated this hash to Carol. If Bob does, this assertion will be + // flaky. + let msg_event = event_stream_next(&mut carol_message_events, Duration::from_millis(500)).await; + assert!(msg_event.is_none()); + + alice_node.shutdown().await; + bob_node.shutdown().await; + carol_node.shutdown().await; } -#[test] -fn propagate_and_forward_invalid_block() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn propagate_and_forward_invalid_block() { let temp_dir = tempdir().unwrap(); let factories = CryptoFactories::default(); // Alice will propagate an invalid block to Carol and Bob, they will check the received block and not propagate the @@ -473,7 +456,8 @@ fn propagate_and_forward_invalid_block() { let (mut dan_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(dan_node_identity.clone()) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("dan").to_str().unwrap()); + .start(temp_dir.path().join("dan").to_str().unwrap()) + .await; let (mut carol_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(carol_node_identity.clone()) .with_peers(vec![dan_node_identity.clone()]) @@ -483,20 +467,23 @@ fn propagate_and_forward_invalid_block() { mock_validator.clone(), stateless_block_validator.clone(), ) - .start(&mut runtime, temp_dir.path().join("carol").to_str().unwrap()); + .start(temp_dir.path().join("carol").to_str().unwrap()) + .await; let (mut bob_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(bob_node_identity.clone()) .with_peers(vec![dan_node_identity]) .with_consensus_manager(rules) .with_validators(mock_validator.clone(), mock_validator, stateless_block_validator) - .start(&mut runtime, temp_dir.path().join("bob").to_str().unwrap()); + .start(temp_dir.path().join("bob").to_str().unwrap()) + .await; let (mut alice_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(alice_node_identity) .with_peers(vec![bob_node_identity, carol_node_identity]) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("alice").to_str().unwrap()); + .start(temp_dir.path().join("alice").to_str().unwrap()) + .await; - wait_until_online(&mut runtime, &[&alice_node, &bob_node, &carol_node, &dan_node]); + wait_until_online(&[&alice_node, &bob_node, &carol_node, &dan_node]).await; alice_node.mock_base_node_state_machine.publish_status(StatusInfo { bootstrapped: true, @@ -520,45 +507,42 @@ fn propagate_and_forward_invalid_block() { let block1 = append_block(&alice_node.blockchain_db, &block0, vec![], &rules, 1.into()).unwrap(); let block1_hash = block1.hash(); - runtime.block_on(async { - let mut bob_block_event_stream = bob_node.local_nci.get_block_event_stream(); - let mut carol_block_event_stream = carol_node.local_nci.get_block_event_stream(); - let mut dan_block_event_stream = dan_node.local_nci.get_block_event_stream(); - - assert!(alice_node - .outbound_nci - .propagate_block(NewBlock::from(block1.block()), vec![]) - .await - .is_ok()); - - let bob_block_event_fut = event_stream_next(&mut bob_block_event_stream, Duration::from_millis(20000)); - let carol_block_event_fut = event_stream_next(&mut carol_block_event_stream, Duration::from_millis(20000)); - let dan_block_event_fut = event_stream_next(&mut dan_block_event_stream, Duration::from_millis(5000)); - let (bob_block_event, carol_block_event, dan_block_event) = - join!(bob_block_event_fut, carol_block_event_fut, dan_block_event_fut); - - if let BlockEvent::AddBlockFailed(received_block, _) = &*bob_block_event.unwrap().unwrap() { - assert_eq!(&received_block.hash(), block1_hash); - } else { - panic!("Bob's node should have detected an invalid block"); - } - if let BlockEvent::AddBlockFailed(received_block, _) = &*carol_block_event.unwrap().unwrap() { - assert_eq!(&received_block.hash(), block1_hash); - } else { - panic!("Carol's node should have detected an invalid block"); - } - assert!(dan_block_event.is_none()); + let mut bob_block_event_stream = bob_node.local_nci.get_block_event_stream(); + let mut carol_block_event_stream = carol_node.local_nci.get_block_event_stream(); + let mut dan_block_event_stream = dan_node.local_nci.get_block_event_stream(); - alice_node.shutdown().await; - bob_node.shutdown().await; - carol_node.shutdown().await; - dan_node.shutdown().await; - }); + assert!(alice_node + .outbound_nci + .propagate_block(NewBlock::from(block1.block()), vec![]) + .await + .is_ok()); + + let bob_block_event_fut = event_stream_next(&mut bob_block_event_stream, Duration::from_millis(20000)); + let carol_block_event_fut = event_stream_next(&mut carol_block_event_stream, Duration::from_millis(20000)); + let dan_block_event_fut = event_stream_next(&mut dan_block_event_stream, Duration::from_millis(5000)); + let (bob_block_event, carol_block_event, dan_block_event) = + tokio::join!(bob_block_event_fut, carol_block_event_fut, dan_block_event_fut); + + if let BlockEvent::AddBlockFailed(received_block, _) = &*bob_block_event.unwrap() { + assert_eq!(&received_block.hash(), block1_hash); + } else { + panic!("Bob's node should have detected an invalid block"); + } + if let BlockEvent::AddBlockFailed(received_block, _) = &*carol_block_event.unwrap() { + assert_eq!(&received_block.hash(), block1_hash); + } else { + panic!("Carol's node should have detected an invalid block"); + } + assert!(dan_block_event.is_none()); + + alice_node.shutdown().await; + bob_node.shutdown().await; + carol_node.shutdown().await; + dan_node.shutdown().await; } -#[test] -fn service_request_timeout() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn service_request_timeout() { let network = Network::LocalNet; let consensus_manager = ConsensusManager::builder(network).build(); let base_node_service_config = BaseNodeServiceConfig { @@ -569,47 +553,42 @@ fn service_request_timeout() { }; let temp_dir = tempdir().unwrap(); let (mut alice_node, bob_node, _consensus_manager) = create_network_with_2_base_nodes_with_config( - &mut runtime, base_node_service_config, MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager, temp_dir.path().to_str().unwrap(), - ); + ) + .await; - runtime.block_on(async { - // Bob should not be reachable - bob_node.shutdown().await; - unpack_enum!(CommsInterfaceError::RequestTimedOut = alice_node.outbound_nci.get_metadata().await.unwrap_err()); - alice_node.shutdown().await; - }); + // Bob should not be reachable + bob_node.shutdown().await; + unpack_enum!(CommsInterfaceError::RequestTimedOut = alice_node.outbound_nci.get_metadata().await.unwrap_err()); + alice_node.shutdown().await; } -#[test] -fn local_get_metadata() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn local_get_metadata() { let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; - let (mut node, consensus_manager) = - BaseNodeBuilder::new(network.into()).start(&mut runtime, temp_dir.path().to_str().unwrap()); + let (mut node, consensus_manager) = BaseNodeBuilder::new(network.into()) + .start(temp_dir.path().to_str().unwrap()) + .await; let db = &node.blockchain_db; let block0 = db.fetch_block(0).unwrap().try_into_chain_block().unwrap(); let block1 = append_block(db, &block0, vec![], &consensus_manager, 1.into()).unwrap(); let block2 = append_block(db, &block1, vec![], &consensus_manager, 1.into()).unwrap(); - runtime.block_on(async { - let metadata = node.local_nci.get_metadata().await.unwrap(); - assert_eq!(metadata.height_of_longest_chain(), 2); - assert_eq!(metadata.best_block(), block2.hash()); + let metadata = node.local_nci.get_metadata().await.unwrap(); + assert_eq!(metadata.height_of_longest_chain(), 2); + assert_eq!(metadata.best_block(), block2.hash()); - node.shutdown().await; - }); + node.shutdown().await; } -#[test] -fn local_get_new_block_template_and_get_new_block() { +#[tokio::test] +async fn local_get_new_block_template_and_get_new_block() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; let consensus_constants = NetworkConsensus::from(network).create_consensus_constants(); @@ -620,7 +599,8 @@ fn local_get_new_block_template_and_get_new_block() { .build(); let (mut node, _rules) = BaseNodeBuilder::new(network.into()) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().to_str().unwrap()); + .start(temp_dir.path().to_str().unwrap()) + .await; let schema = [ txn_schema!(from: vec![outputs[1].clone()], to: vec![10_000 * uT, 20_000 * uT]), @@ -630,29 +610,26 @@ fn local_get_new_block_template_and_get_new_block() { assert!(node.mempool.insert(txs[0].clone()).is_ok()); assert!(node.mempool.insert(txs[1].clone()).is_ok()); - runtime.block_on(async { - let block_template = node - .local_nci - .get_new_block_template(PowAlgorithm::Sha3, 0) - .await - .unwrap(); - assert_eq!(block_template.header.height, 1); - assert_eq!(block_template.body.kernels().len(), 2); + let block_template = node + .local_nci + .get_new_block_template(PowAlgorithm::Sha3, 0) + .await + .unwrap(); + assert_eq!(block_template.header.height, 1); + assert_eq!(block_template.body.kernels().len(), 2); - let block = node.local_nci.get_new_block(block_template.clone()).await.unwrap(); - assert_eq!(block.header.height, 1); - assert_eq!(block.body, block_template.body); + let block = node.local_nci.get_new_block(block_template.clone()).await.unwrap(); + assert_eq!(block.header.height, 1); + assert_eq!(block.body, block_template.body); - node.blockchain_db.add_block(block.clone().into()).unwrap(); + node.blockchain_db.add_block(block.clone().into()).unwrap(); - node.shutdown().await; - }); + node.shutdown().await; } -#[test] -fn local_get_new_block_with_zero_conf() { +#[tokio::test] +async fn local_get_new_block_with_zero_conf() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; let consensus_constants = NetworkConsensus::from(network).create_consensus_constants(); @@ -668,7 +645,8 @@ fn local_get_new_block_with_zero_conf() { HeaderValidator::new(rules.clone()), OrphanBlockValidator::new(rules, factories.clone()), ) - .start(&mut runtime, temp_dir.path().to_str().unwrap()); + .start(temp_dir.path().to_str().unwrap()) + .await; let (tx01, tx01_out, _) = spend_utxos( txn_schema!(from: vec![outputs[1].clone()], to: vec![20_000 * uT], fee: 10*uT, lock: 0, features: OutputFeatures::default()), @@ -700,38 +678,35 @@ fn local_get_new_block_with_zero_conf() { TxStorageResponse::UnconfirmedPool ); - runtime.block_on(async { - let mut block_template = node - .local_nci - .get_new_block_template(PowAlgorithm::Sha3, 0) - .await - .unwrap(); - assert_eq!(block_template.header.height, 1); - assert_eq!(block_template.body.kernels().len(), 4); - let coinbase_value = rules.get_block_reward_at(1) + block_template.body.get_total_fee(); - let (output, kernel, _) = create_coinbase( - &factories, - coinbase_value, - rules.consensus_constants(1).coinbase_lock_height() + 1, - ); - block_template.body.add_kernel(kernel); - block_template.body.add_output(output); - block_template.body.sort(); - let block = node.local_nci.get_new_block(block_template.clone()).await.unwrap(); - assert_eq!(block.header.height, 1); - assert_eq!(block.body, block_template.body); - assert_eq!(block_template.body.kernels().len(), 5); - - node.blockchain_db.add_block(block.clone().into()).unwrap(); - - node.shutdown().await; - }); + let mut block_template = node + .local_nci + .get_new_block_template(PowAlgorithm::Sha3, 0) + .await + .unwrap(); + assert_eq!(block_template.header.height, 1); + assert_eq!(block_template.body.kernels().len(), 4); + let coinbase_value = rules.get_block_reward_at(1) + block_template.body.get_total_fee(); + let (output, kernel, _) = create_coinbase( + &factories, + coinbase_value, + rules.consensus_constants(1).coinbase_lock_height() + 1, + ); + block_template.body.add_kernel(kernel); + block_template.body.add_output(output); + block_template.body.sort(); + let block = node.local_nci.get_new_block(block_template.clone()).await.unwrap(); + assert_eq!(block.header.height, 1); + assert_eq!(block.body, block_template.body); + assert_eq!(block_template.body.kernels().len(), 5); + + node.blockchain_db.add_block(block.clone().into()).unwrap(); + + node.shutdown().await; } -#[test] -fn local_get_new_block_with_combined_transaction() { +#[tokio::test] +async fn local_get_new_block_with_combined_transaction() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; let consensus_constants = NetworkConsensus::from(network).create_consensus_constants(); @@ -747,7 +722,8 @@ fn local_get_new_block_with_combined_transaction() { HeaderValidator::new(rules.clone()), OrphanBlockValidator::new(rules, factories.clone()), ) - .start(&mut runtime, temp_dir.path().to_str().unwrap()); + .start(temp_dir.path().to_str().unwrap()) + .await; let (tx01, tx01_out, _) = spend_utxos( txn_schema!(from: vec![outputs[1].clone()], to: vec![20_000 * uT], fee: 10*uT, lock: 0, features: OutputFeatures::default()), @@ -774,41 +750,39 @@ fn local_get_new_block_with_combined_transaction() { TxStorageResponse::UnconfirmedPool ); - runtime.block_on(async { - let mut block_template = node - .local_nci - .get_new_block_template(PowAlgorithm::Sha3, 0) - .await - .unwrap(); - assert_eq!(block_template.header.height, 1); - assert_eq!(block_template.body.kernels().len(), 4); - let coinbase_value = rules.get_block_reward_at(1) + block_template.body.get_total_fee(); - let (output, kernel, _) = create_coinbase( - &factories, - coinbase_value, - rules.consensus_constants(1).coinbase_lock_height() + 1, - ); - block_template.body.add_kernel(kernel); - block_template.body.add_output(output); - block_template.body.sort(); - let block = node.local_nci.get_new_block(block_template.clone()).await.unwrap(); - assert_eq!(block.header.height, 1); - assert_eq!(block.body, block_template.body); - assert_eq!(block_template.body.kernels().len(), 5); - - node.blockchain_db.add_block(block.clone().into()).unwrap(); - - node.shutdown().await; - }); + let mut block_template = node + .local_nci + .get_new_block_template(PowAlgorithm::Sha3, 0) + .await + .unwrap(); + assert_eq!(block_template.header.height, 1); + assert_eq!(block_template.body.kernels().len(), 4); + let coinbase_value = rules.get_block_reward_at(1) + block_template.body.get_total_fee(); + let (output, kernel, _) = create_coinbase( + &factories, + coinbase_value, + rules.consensus_constants(1).coinbase_lock_height() + 1, + ); + block_template.body.add_kernel(kernel); + block_template.body.add_output(output); + block_template.body.sort(); + let block = node.local_nci.get_new_block(block_template.clone()).await.unwrap(); + assert_eq!(block.header.height, 1); + assert_eq!(block.body, block_template.body); + assert_eq!(block_template.body.kernels().len(), 5); + + node.blockchain_db.add_block(block.clone().into()).unwrap(); + + node.shutdown().await; } -#[test] -fn local_submit_block() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn local_submit_block() { let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; - let (mut node, consensus_manager) = - BaseNodeBuilder::new(network.into()).start(&mut runtime, temp_dir.path().to_str().unwrap()); + let (mut node, consensus_manager) = BaseNodeBuilder::new(network.into()) + .start(temp_dir.path().to_str().unwrap()) + .await; let db = &node.blockchain_db; let mut event_stream = node.local_nci.get_block_event_stream(); @@ -818,20 +792,18 @@ fn local_submit_block() { .unwrap(); block1.header.kernel_mmr_size += 1; block1.header.output_mmr_size += 1; - runtime.block_on(async { - node.local_nci - .submit_block(block1.clone(), Broadcast::from(true)) - .await - .unwrap(); + node.local_nci + .submit_block(block1.clone(), Broadcast::from(true)) + .await + .unwrap(); - let event = event_stream_next(&mut event_stream, Duration::from_millis(20000)).await; - if let BlockEvent::ValidBlockAdded(received_block, result, _) = &*event.unwrap().unwrap() { - assert_eq!(received_block.hash(), block1.hash()); - result.assert_added(); - } else { - panic!("Block validation failed"); - } + let event = event_stream_next(&mut event_stream, Duration::from_millis(20000)).await; + if let BlockEvent::ValidBlockAdded(received_block, result, _) = &*event.unwrap() { + assert_eq!(received_block.hash(), block1.hash()); + result.assert_added(); + } else { + panic!("Block validation failed"); + } - node.shutdown().await; - }); + node.shutdown().await; } diff --git a/base_layer/core/tests/node_state_machine.rs b/base_layer/core/tests/node_state_machine.rs index bcbeeea436..bbe9baa5c5 100644 --- a/base_layer/core/tests/node_state_machine.rs +++ b/base_layer/core/tests/node_state_machine.rs @@ -23,13 +23,12 @@ #[allow(dead_code)] mod helpers; -use futures::StreamExt; use helpers::{ block_builders::{append_block, chain_block, create_genesis_block}, chain_metadata::{random_peer_metadata, MockChainMetadata}, nodes::{create_network_with_2_base_nodes_with_config, wait_until_online, BaseNodeBuilder}, }; -use std::{thread, time::Duration}; +use std::time::Duration; use tari_common::configuration::Network; use tari_core::{ base_node::{ @@ -54,15 +53,14 @@ use tari_p2p::services::liveness::LivenessConfig; use tari_shutdown::Shutdown; use tempfile::tempdir; use tokio::{ - runtime::Runtime, sync::{broadcast, watch}, + task, time, }; static EMISSION: [u64; 2] = [10, 10]; -#[test] -fn test_listening_lagging() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn test_listening_lagging() { let factories = CryptoFactories::default(); let network = Network::LocalNet; let temp_dir = tempdir().unwrap(); @@ -75,7 +73,6 @@ fn test_listening_lagging() { .with_block(prev_block.clone()) .build(); let (alice_node, bob_node, consensus_manager) = create_network_with_2_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig { @@ -84,7 +81,8 @@ fn test_listening_lagging() { }, consensus_manager, temp_dir.path().to_str().unwrap(), - ); + ) + .await; let shutdown = Shutdown::new(); let (state_change_event_publisher, _) = broadcast::channel(10); let (status_event_sender, _status_event_receiver) = watch::channel(StatusInfo::new()); @@ -103,46 +101,44 @@ fn test_listening_lagging() { consensus_manager.clone(), shutdown.to_signal(), ); - wait_until_online(&mut runtime, &[&alice_node, &bob_node]); + wait_until_online(&[&alice_node, &bob_node]).await; - let await_event_task = runtime.spawn(async move { Listening::new().next_event(&mut alice_state_machine).await }); + let await_event_task = task::spawn(async move { Listening::new().next_event(&mut alice_state_machine).await }); - runtime.block_on(async move { - let bob_db = bob_node.blockchain_db; - let mut bob_local_nci = bob_node.local_nci; + let bob_db = bob_node.blockchain_db; + let mut bob_local_nci = bob_node.local_nci; - // Bob Block 1 - no block event - let prev_block = append_block(&bob_db, &prev_block, vec![], &consensus_manager, 3.into()).unwrap(); - // Bob Block 2 - with block event and liveness service metadata update - let mut prev_block = bob_db - .prepare_block_merkle_roots(chain_block(&prev_block.block(), vec![], &consensus_manager)) - .unwrap(); - prev_block.header.output_mmr_size += 1; - prev_block.header.kernel_mmr_size += 1; - bob_local_nci - .submit_block(prev_block, Broadcast::from(true)) - .await - .unwrap(); - assert_eq!(bob_db.get_height().unwrap(), 2); + // Bob Block 1 - no block event + let prev_block = append_block(&bob_db, &prev_block, vec![], &consensus_manager, 3.into()).unwrap(); + // Bob Block 2 - with block event and liveness service metadata update + let mut prev_block = bob_db + .prepare_block_merkle_roots(chain_block(&prev_block.block(), vec![], &consensus_manager)) + .unwrap(); + prev_block.header.output_mmr_size += 1; + prev_block.header.kernel_mmr_size += 1; + bob_local_nci + .submit_block(prev_block, Broadcast::from(true)) + .await + .unwrap(); + assert_eq!(bob_db.get_height().unwrap(), 2); - let next_event = time::timeout(Duration::from_secs(10), await_event_task) - .await - .expect("Alice did not emit `StateEvent::FallenBehind` within 10 seconds") - .unwrap(); + let next_event = time::timeout(Duration::from_secs(10), await_event_task) + .await + .expect("Alice did not emit `StateEvent::FallenBehind` within 10 seconds") + .unwrap(); - match next_event { - StateEvent::InitialSync => {}, - _ => panic!(), - } - }); + match next_event { + StateEvent::InitialSync => {}, + _ => panic!(), + } } -#[test] -fn test_event_channel() { +#[tokio::test] +async fn test_event_channel() { let temp_dir = tempdir().unwrap(); - let mut runtime = Runtime::new().unwrap(); - let (node, consensus_manager) = - BaseNodeBuilder::new(Network::Weatherwax.into()).start(&mut runtime, temp_dir.path().to_str().unwrap()); + let (node, consensus_manager) = BaseNodeBuilder::new(Network::Weatherwax.into()) + .start(temp_dir.path().to_str().unwrap()) + .await; // let shutdown = Shutdown::new(); let db = create_test_blockchain_db(); let shutdown = Shutdown::new(); @@ -165,24 +161,21 @@ fn test_event_channel() { shutdown.to_signal(), ); - runtime.spawn(state_machine.run()); + task::spawn(state_machine.run()); let PeerChainMetadata { node_id, chain_metadata, } = random_peer_metadata(10, 5_000); - runtime - .block_on(mock.publish_chain_metadata(&node_id, &chain_metadata)) + mock.publish_chain_metadata(&node_id, &chain_metadata) + .await .expect("Could not publish metadata"); - thread::sleep(Duration::from_millis(50)); - runtime.block_on(async { - let event = state_change_event_subscriber.next().await; - assert_eq!(*event.unwrap().unwrap(), StateEvent::Initialized); - let event = state_change_event_subscriber.next().await; - let event = event.unwrap().unwrap(); - match event.as_ref() { - StateEvent::InitialSync => (), - _ => panic!("Unexpected state was found:{:?}", event), - } - }); + let event = state_change_event_subscriber.recv().await; + assert_eq!(*event.unwrap(), StateEvent::Initialized); + let event = state_change_event_subscriber.recv().await; + let event = event.unwrap(); + match event.as_ref() { + StateEvent::InitialSync => (), + _ => panic!("Unexpected state was found:{:?}", event), + } } diff --git a/base_layer/key_manager/Cargo.toml b/base_layer/key_manager/Cargo.toml index e57d0bf8b4..a40415ce17 100644 --- a/base_layer/key_manager/Cargo.toml +++ b/base_layer/key_manager/Cargo.toml @@ -15,7 +15,7 @@ sha2 = "0.9.5" serde = "1.0.89" serde_derive = "1.0.89" serde_json = "1.0.39" -thiserror = "1.0.20" +thiserror = "1.0.26" [features] avx2 = ["tari_crypto/avx2"] diff --git a/base_layer/mmr/Cargo.toml b/base_layer/mmr/Cargo.toml index 38b723a5f2..0d725874fb 100644 --- a/base_layer/mmr/Cargo.toml +++ b/base_layer/mmr/Cargo.toml @@ -14,7 +14,7 @@ benches = ["criterion"] [dependencies] tari_utilities = "^0.3" -thiserror = "1.0.20" +thiserror = "1.0.26" digest = "0.9.0" log = "0.4" serde = { version = "1.0.97", features = ["derive"] } diff --git a/base_layer/p2p/Cargo.toml b/base_layer/p2p/Cargo.toml index e9d3a0291c..0d7aae7390 100644 --- a/base_layer/p2p/Cargo.toml +++ b/base_layer/p2p/Cargo.toml @@ -10,37 +10,38 @@ license = "BSD-3-Clause" edition = "2018" [dependencies] -tari_comms = { version = "^0.9", path = "../../comms"} -tari_comms_dht = { version = "^0.9", path = "../../comms/dht"} -tari_common = { version= "^0.9", path = "../../common" } +tari_comms = { version = "^0.9", path = "../../comms" } +tari_comms_dht = { version = "^0.9", path = "../../comms/dht" } +tari_common = { version = "^0.9", path = "../../common" } tari_crypto = "0.11.1" -tari_service_framework = { version = "^0.9", path = "../service_framework"} -tari_shutdown = { version = "^0.9", path="../../infrastructure/shutdown" } -tari_storage = { version = "^0.9", path = "../../infrastructure/storage"} +tari_service_framework = { version = "^0.9", path = "../service_framework" } +tari_shutdown = { version = "^0.9", path = "../../infrastructure/shutdown" } +tari_storage = { version = "^0.9", path = "../../infrastructure/storage" } tari_utilities = "^0.3" anyhow = "1.0.32" bytes = "0.5" -chrono = {version = "0.4.6", features = ["serde"]} +chrono = { version = "0.4.6", features = ["serde"] } fs2 = "0.3.0" -futures = {version = "^0.3.1"} +futures = { version = "^0.3.1" } lmdb-zero = "0.4.4" log = "0.4.6" -pgp = {version = "0.7.1", optional = true} -prost = "0.6.1" +pgp = { version = "0.7.1", optional = true } +prost = "=0.8.0" rand = "0.8" -reqwest = {version = "0.10", optional = true, default-features = false} +reqwest = { version = "0.10", optional = true, default-features = false } semver = "1.0.1" serde = "1.0.90" serde_derive = "1.0.90" -thiserror = "1.0.20" -tokio = {version="0.2.10", features=["blocking"]} +thiserror = "1.0.26" +tokio = { version = "1.10", features = ["macros"] } +tokio-stream = { version = "0.1.7", default-features = false, features = ["time"] } tower = "0.3.0-alpha.2" -tower-service = { version="0.3.0-alpha.2" } -trust-dns-client = {version="0.19.5", features=["dns-over-rustls"]} +tower-service = { version = "0.3.0-alpha.2" } +trust-dns-client = { version = "0.21.0-alpha.1", features = ["dns-over-rustls"] } [dev-dependencies] -tari_test_utils = { version = "^0.9", path="../../infrastructure/test_utils" } +tari_test_utils = { version = "^0.9", path = "../../infrastructure/test_utils" } clap = "2.33.0" env_logger = "0.6.2" @@ -48,7 +49,6 @@ futures-timer = "0.3.0" lazy_static = "1.3.0" stream-cancel = "0.4.4" tempfile = "3.1.0" -tokio-macros = "0.2.4" [dev-dependencies.log4rs] version = "^0.8" @@ -56,7 +56,7 @@ features = ["console_appender", "file_appender", "file", "yaml_format"] default-features = false [build-dependencies] -tari_common = { version = "^0.9", path="../../common", features = ["build"] } +tari_common = { version = "^0.9", path = "../../common", features = ["build"] } [features] test-mocks = [] diff --git a/base_layer/p2p/src/auto_update/dns.rs b/base_layer/p2p/src/auto_update/dns.rs index 7042b19919..64c3be7f5a 100644 --- a/base_layer/p2p/src/auto_update/dns.rs +++ b/base_layer/p2p/src/auto_update/dns.rs @@ -189,19 +189,26 @@ impl Display for UpdateSpec { #[cfg(test)] mod test { use super::*; + use crate::dns::mock; use trust_dns_client::{ - proto::rr::{rdata, RData, RecordType}, + op::Query, + proto::{ + rr::{rdata, Name, RData, RecordType}, + xfer::DnsResponse, + }, rr::Record, }; - fn create_txt_record(contents: Vec<&str>) -> Record { + fn create_txt_record(contents: Vec<&str>) -> DnsResponse { + let resp_query = Query::query(Name::from_str("test.local.").unwrap(), RecordType::A); let mut record = Record::new(); record .set_record_type(RecordType::TXT) .set_rdata(RData::TXT(rdata::TXT::new( contents.into_iter().map(ToString::to_string).collect(), ))); - record + + mock::message(resp_query, vec![record], vec![], vec![]).into() } mod update_spec { @@ -220,7 +227,6 @@ mod test { mod dns_software_update { use super::*; use crate::DEFAULT_DNS_NAME_SERVER; - use std::{collections::HashMap, iter::FromIterator}; impl AutoUpdateConfig { fn get_test_defaults() -> Self { @@ -238,15 +244,15 @@ mod test { } } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_ignores_non_conforming_txt_entries() { - let records = HashMap::from_iter([("test.local.", vec![ - create_txt_record(vec![":::"]), - create_txt_record(vec!["base-node:::"]), - create_txt_record(vec!["base-node::1.0:"]), - create_txt_record(vec!["base-node:android-armv7:0.1.0:abcdef"]), - create_txt_record(vec!["base-node:linux-x86_64:1.0.0:bada55"]), - ])]); + let records = vec![ + Ok(create_txt_record(vec![":::"])), + Ok(create_txt_record(vec!["base-node:::"])), + Ok(create_txt_record(vec!["base-node::1.0:"])), + Ok(create_txt_record(vec!["base-node:android-armv7:0.1.0:abcdef"])), + Ok(create_txt_record(vec!["base-node:linux-x86_64:1.0.0:bada55"])), + ]; let updater = DnsSoftwareUpdate { client: DnsClient::connect_mock(records).await.unwrap(), config: AutoUpdateConfig::get_test_defaults(), @@ -258,12 +264,12 @@ mod test { assert!(spec.is_none()); } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_returns_best_update() { - let records = HashMap::from_iter([("test.local.", vec![ - create_txt_record(vec!["base-node:linux-x86_64:1.0.0:abcdef"]), - create_txt_record(vec!["base-node:linux-x86_64:1.0.1:abcdef01"]), - ])]); + let records = vec![ + Ok(create_txt_record(vec!["base-node:linux-x86_64:1.0.0:abcdef"])), + Ok(create_txt_record(vec!["base-node:linux-x86_64:1.0.1:abcdef01"])), + ]; let updater = DnsSoftwareUpdate { client: DnsClient::connect_mock(records).await.unwrap(), config: AutoUpdateConfig::get_test_defaults(), diff --git a/base_layer/p2p/src/auto_update/service.rs b/base_layer/p2p/src/auto_update/service.rs index a786d84ad3..eebd9ec555 100644 --- a/base_layer/p2p/src/auto_update/service.rs +++ b/base_layer/p2p/src/auto_update/service.rs @@ -24,17 +24,15 @@ use crate::{ auto_update, auto_update::{AutoUpdateConfig, SoftwareUpdate, Version}, }; -use futures::{ - channel::{mpsc, oneshot}, - future::Either, - stream, - SinkExt, - StreamExt, -}; +use futures::{future::Either, stream, StreamExt}; use std::{env::consts, time::Duration}; use tari_common::configuration::bootstrap::ApplicationType; use tari_service_framework::{async_trait, ServiceInitializationError, ServiceInitializer, ServiceInitializerContext}; -use tokio::{sync::watch, time}; +use tokio::{ + sync::{mpsc, oneshot, watch}, + time, +}; +use tokio_stream::wrappers; const LOG_TARGET: &str = "app:auto-update"; @@ -94,19 +92,20 @@ impl SoftwareUpdaterService { new_update_notification: watch::Receiver<Option<SoftwareUpdate>>, ) { let mut interval_or_never = match self.check_interval { - Some(interval) => Either::Left(time::interval(interval)).fuse(), - None => Either::Right(stream::empty()).fuse(), + Some(interval) => Either::Left(wrappers::IntervalStream::new(time::interval(interval))), + None => Either::Right(stream::empty()), }; loop { let last_version = new_update_notification.borrow().clone(); - let maybe_update = futures::select! { - reply = request_rx.select_next_some() => { + let maybe_update = tokio::select! { + Some(reply) = request_rx.recv() => { let maybe_update = self.check_for_updates().await; let _ = reply.send(maybe_update.clone()); maybe_update }, + _ = interval_or_never.next() => { // Periodically, check for updates if configured to do so. // If an update is found the new update notifier will be triggered and any listeners notified @@ -121,7 +120,7 @@ impl SoftwareUpdaterService { .map(|up| up.version() < update.version()) .unwrap_or(true) { - let _ = notifier.broadcast(Some(update.clone())); + let _ = notifier.send(Some(update.clone())); } } } diff --git a/base_layer/p2p/src/comms_connector/inbound_connector.rs b/base_layer/p2p/src/comms_connector/inbound_connector.rs index ed16cd578d..6230d7a78e 100644 --- a/base_layer/p2p/src/comms_connector/inbound_connector.rs +++ b/base_layer/p2p/src/comms_connector/inbound_connector.rs @@ -22,45 +22,42 @@ use super::peer_message::PeerMessage; use anyhow::anyhow; -use futures::{task::Context, Future, Sink, SinkExt}; +use futures::{task::Context, Future}; use log::*; use std::{pin::Pin, sync::Arc, task::Poll}; use tari_comms::pipeline::PipelineError; use tari_comms_dht::{domain_message::MessageHeader, inbound::DecryptedDhtMessage}; +use tokio::sync::mpsc; use tower::Service; const LOG_TARGET: &str = "comms::middleware::inbound_connector"; /// This service receives DecryptedDhtMessage, deserializes the MessageHeader and /// sends a `PeerMessage` on the given sink. #[derive(Clone)] -pub struct InboundDomainConnector<TSink> { - sink: TSink, +pub struct InboundDomainConnector { + sink: mpsc::Sender<Arc<PeerMessage>>, } -impl<TSink> InboundDomainConnector<TSink> { - pub fn new(sink: TSink) -> Self { +impl InboundDomainConnector { + pub fn new(sink: mpsc::Sender<Arc<PeerMessage>>) -> Self { Self { sink } } } -impl<TSink> Service<DecryptedDhtMessage> for InboundDomainConnector<TSink> -where - TSink: Sink<Arc<PeerMessage>> + Unpin + Clone + 'static, - TSink::Error: std::error::Error + Send + Sync + 'static, -{ +impl Service<DecryptedDhtMessage> for InboundDomainConnector { type Error = PipelineError; - type Future = Pin<Box<dyn Future<Output = Result<(), PipelineError>>>>; + type Future = Pin<Box<dyn Future<Output = Result<(), PipelineError>> + Send>>; type Response = (); - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { - Pin::new(&mut self.sink).poll_ready(cx).map_err(Into::into) + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) } fn call(&mut self, msg: DecryptedDhtMessage) -> Self::Future { - let mut sink = self.sink.clone(); + let sink = self.sink.clone(); let future = async move { let peer_message = Self::construct_peer_message(msg)?; - // If this fails there is something wrong with the sink and the pubsub middleware should not + // If this fails the channel has closed and the pubsub middleware should not // continue sink.send(Arc::new(peer_message)).await?; @@ -70,7 +67,7 @@ where } } -impl<TSink> InboundDomainConnector<TSink> { +impl InboundDomainConnector { fn construct_peer_message(mut inbound_message: DecryptedDhtMessage) -> Result<PeerMessage, PipelineError> { let envelope_body = inbound_message .success_mut() @@ -107,41 +104,42 @@ impl<TSink> InboundDomainConnector<TSink> { } } -impl<TSink> Sink<DecryptedDhtMessage> for InboundDomainConnector<TSink> -where - TSink: Sink<Arc<PeerMessage>> + Unpin, - TSink::Error: Into<PipelineError> + Send + Sync + 'static, -{ - type Error = PipelineError; - - fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { - Pin::new(&mut self.sink).poll_ready(cx).map_err(Into::into) - } - - fn start_send(mut self: Pin<&mut Self>, item: DecryptedDhtMessage) -> Result<(), Self::Error> { - let item = Self::construct_peer_message(item)?; - Pin::new(&mut self.sink).start_send(Arc::new(item)).map_err(Into::into) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { - Pin::new(&mut self.sink).poll_flush(cx).map_err(Into::into) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { - Pin::new(&mut self.sink).poll_close(cx).map_err(Into::into) - } -} +// impl<TSink> Sink<DecryptedDhtMessage> for InboundDomainConnector<TSink> +// where +// TSink: Sink<Arc<PeerMessage>> + Unpin, +// TSink::Error: Into<PipelineError> + Send + Sync + 'static, +// { +// type Error = PipelineError; +// +// fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { +// Pin::new(&mut self.sink).poll_ready(cx).map_err(Into::into) +// } +// +// fn start_send(mut self: Pin<&mut Self>, item: DecryptedDhtMessage) -> Result<(), Self::Error> { +// let item = Self::construct_peer_message(item)?; +// Pin::new(&mut self.sink).start_send(Arc::new(item)).map_err(Into::into) +// } +// +// fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { +// Pin::new(&mut self.sink).poll_flush(cx).map_err(Into::into) +// } +// +// fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { +// Pin::new(&mut self.sink).poll_close(cx).map_err(Into::into) +// } +// } #[cfg(test)] mod test { use super::*; use crate::test_utils::{make_dht_inbound_message, make_node_identity}; - use futures::{channel::mpsc, executor::block_on, StreamExt}; + use futures::executor::block_on; use tari_comms::{message::MessageExt, wrap_in_envelope_body}; use tari_comms_dht::domain_message::MessageHeader; + use tokio::sync::mpsc; use tower::ServiceExt; - #[tokio_macros::test_basic] + #[tokio::test] async fn handle_message() { let (tx, mut rx) = mpsc::channel(1); let header = MessageHeader::new(123); @@ -151,12 +149,12 @@ mod test { let decrypted = DecryptedDhtMessage::succeeded(msg, None, inbound_message); InboundDomainConnector::new(tx).oneshot(decrypted).await.unwrap(); - let peer_message = block_on(rx.next()).unwrap(); + let peer_message = block_on(rx.recv()).unwrap(); assert_eq!(peer_message.message_header.message_type, 123); assert_eq!(peer_message.decode_message::<String>().unwrap(), "my message"); } - #[tokio_macros::test_basic] + #[tokio::test] async fn send_on_sink() { let (tx, mut rx) = mpsc::channel(1); let header = MessageHeader::new(123); @@ -165,14 +163,14 @@ mod test { let inbound_message = make_dht_inbound_message(&make_node_identity(), msg.to_encoded_bytes()); let decrypted = DecryptedDhtMessage::succeeded(msg, None, inbound_message); - InboundDomainConnector::new(tx).send(decrypted).await.unwrap(); + InboundDomainConnector::new(tx).call(decrypted).await.unwrap(); - let peer_message = block_on(rx.next()).unwrap(); + let peer_message = block_on(rx.recv()).unwrap(); assert_eq!(peer_message.message_header.message_type, 123); assert_eq!(peer_message.decode_message::<String>().unwrap(), "my message"); } - #[tokio_macros::test_basic] + #[tokio::test] async fn handle_message_fail_deserialize() { let (tx, mut rx) = mpsc::channel(1); let header = b"dodgy header".to_vec(); @@ -182,10 +180,11 @@ mod test { let decrypted = DecryptedDhtMessage::succeeded(msg, None, inbound_message); InboundDomainConnector::new(tx).oneshot(decrypted).await.unwrap_err(); - assert!(rx.try_next().unwrap().is_none()); + rx.close(); + assert!(rx.recv().await.is_none()); } - #[tokio_macros::test_basic] + #[tokio::test] async fn handle_message_fail_send() { // Drop the receiver of the channel, this is the only reason this middleware should return an error // from it's call function diff --git a/base_layer/p2p/src/comms_connector/pubsub.rs b/base_layer/p2p/src/comms_connector/pubsub.rs index ae01e8ced8..5051c3ae8d 100644 --- a/base_layer/p2p/src/comms_connector/pubsub.rs +++ b/base_layer/p2p/src/comms_connector/pubsub.rs @@ -22,11 +22,15 @@ use super::peer_message::PeerMessage; use crate::{comms_connector::InboundDomainConnector, tari_message::TariMessageType}; -use futures::{channel::mpsc, future, stream::Fuse, Stream, StreamExt}; +use futures::{future, Stream, StreamExt}; use log::*; use std::{cmp, fmt::Debug, sync::Arc, time::Duration}; use tari_comms::rate_limit::RateLimit; -use tokio::{runtime::Handle, sync::broadcast}; +use tokio::{ + sync::{broadcast, mpsc}, + task, +}; +use tokio_stream::wrappers; const LOG_TARGET: &str = "comms::middleware::pubsub"; @@ -35,16 +39,11 @@ const RATE_LIMIT_MIN_CAPACITY: usize = 5; const RATE_LIMIT_RESTOCK_INTERVAL: Duration = Duration::from_millis(1000); /// Alias for a pubsub-type domain connector -pub type PubsubDomainConnector = InboundDomainConnector<mpsc::Sender<Arc<PeerMessage>>>; +pub type PubsubDomainConnector = InboundDomainConnector; pub type SubscriptionFactory = TopicSubscriptionFactory<TariMessageType, Arc<PeerMessage>>; /// Connects `InboundDomainConnector` to a `tari_pubsub::TopicPublisher` through a buffered broadcast channel -pub fn pubsub_connector( - // TODO: Remove this arg in favor of task::spawn - executor: Handle, - buf_size: usize, - rate_limit: usize, -) -> (PubsubDomainConnector, SubscriptionFactory) { +pub fn pubsub_connector(buf_size: usize, rate_limit: usize) -> (PubsubDomainConnector, SubscriptionFactory) { let (publisher, subscription_factory) = pubsub_channel(buf_size); let (sender, receiver) = mpsc::channel(buf_size); trace!( @@ -55,8 +54,8 @@ pub fn pubsub_connector( ); // Spawn a task which forwards messages from the pubsub service to the TopicPublisher - executor.spawn(async move { - let forwarder = receiver + task::spawn(async move { + wrappers::ReceiverStream::new(receiver) // Rate limit the receiver; the sender will adhere to the limit .rate_limit(cmp::max(rate_limit, RATE_LIMIT_MIN_CAPACITY), RATE_LIMIT_RESTOCK_INTERVAL) // Map DomainMessage into a TopicPayload @@ -89,8 +88,7 @@ pub fn pubsub_connector( ); } future::ready(()) - }); - forwarder.await; + }).await; }); (InboundDomainConnector::new(sender), subscription_factory) } @@ -98,8 +96,8 @@ pub fn pubsub_connector( /// Create a topic-based pub-sub channel fn pubsub_channel<T, M>(size: usize) -> (TopicPublisher<T, M>, TopicSubscriptionFactory<T, M>) where - T: Clone + Debug + Send + Eq, - M: Send + Clone, + T: Clone + Debug + Send + Eq + 'static, + M: Send + Clone + 'static, { let (publisher, _) = broadcast::channel(size); (publisher.clone(), TopicSubscriptionFactory::new(publisher)) @@ -138,8 +136,8 @@ pub struct TopicSubscriptionFactory<T, M> { impl<T, M> TopicSubscriptionFactory<T, M> where - T: Clone + Eq + Debug + Send, - M: Clone + Send, + T: Clone + Eq + Debug + Send + 'static, + M: Clone + Send + 'static, { pub fn new(sender: broadcast::Sender<TopicPayload<T, M>>) -> Self { TopicSubscriptionFactory { sender } @@ -148,38 +146,23 @@ where /// Create a subscription stream to a particular topic. The provided label is used to identify which consumer is /// lagging. pub fn get_subscription(&self, topic: T, label: &'static str) -> impl Stream<Item = M> { - self.sender - .subscribe() - .filter_map({ - let topic = topic.clone(); - move |result| { - let opt = match result { - Ok(payload) => Some(payload), - Err(broadcast::RecvError::Closed) => None, - Err(broadcast::RecvError::Lagged(n)) => { - warn!( - target: LOG_TARGET, - "Subscription '{}' for topic '{:?}' lagged. {} message(s) dropped.", label, topic, n - ); - None - }, - }; - future::ready(opt) - } - }) - .filter_map(move |item| { - let opt = if item.topic() == &topic { - Some(item.message) - } else { - None + wrappers::BroadcastStream::new(self.sender.subscribe()).filter_map({ + let topic = topic.clone(); + move |result| { + let opt = match result { + Ok(payload) if *payload.topic() == topic => Some(payload.message), + Ok(_) => None, + Err(wrappers::errors::BroadcastStreamRecvError::Lagged(n)) => { + warn!( + target: LOG_TARGET, + "Subscription '{}' for topic '{:?}' lagged. {} message(s) dropped.", label, topic, n + ); + None + }, }; future::ready(opt) - }) - } - - /// Convenience function that returns a fused (`stream::Fuse`) version of the subscription stream. - pub fn get_subscription_fused(&self, topic: T, label: &'static str) -> Fuse<impl Stream<Item = M>> { - self.get_subscription(topic, label).fuse() + } + }) } } @@ -190,7 +173,7 @@ mod test { use std::time::Duration; use tari_test_utils::collect_stream; - #[tokio_macros::test_basic] + #[tokio::test] async fn topic_pub_sub() { let (publisher, subscriber_factory) = pubsub_channel(10); diff --git a/base_layer/p2p/src/dns/client.rs b/base_layer/p2p/src/dns/client.rs index 85e03f71e4..78093186ef 100644 --- a/base_layer/p2p/src/dns/client.rs +++ b/base_layer/p2p/src/dns/client.rs @@ -20,35 +20,28 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#[cfg(test)] +use crate::dns::mock::{DefaultOnSend, MockClientHandle}; + use super::DnsClientError; -use futures::future; +use futures::{future, FutureExt}; use std::{net::SocketAddr, sync::Arc}; use tari_shutdown::Shutdown; use tokio::{net::UdpSocket, task}; use trust_dns_client::{ - client::{AsyncClient, AsyncDnssecClient}, - op::{DnsResponse, Query}, - proto::{ - rr::dnssec::TrustAnchor, - udp::{UdpClientStream, UdpResponse}, - xfer::DnsRequestOptions, - DnsHandle, - }, + client::{AsyncClient, AsyncDnssecClient, ClientHandle}, + op::Query, + proto::{error::ProtoError, rr::dnssec::TrustAnchor, udp::UdpClientStream, xfer::DnsResponse, DnsHandle}, rr::{DNSClass, IntoName, RecordType}, serialize::binary::BinEncoder, }; -#[cfg(test)] -use std::collections::HashMap; -#[cfg(test)] -use trust_dns_client::{proto::xfer::DnsMultiplexerSerialResponse, rr::Record}; - #[derive(Clone)] pub enum DnsClient { - Secure(Client<AsyncDnssecClient<UdpResponse>>), - Normal(Client<AsyncClient<UdpResponse>>), + Secure(Client<AsyncDnssecClient>), + Normal(Client<AsyncClient>), #[cfg(test)] - Mock(Client<AsyncClient<DnsMultiplexerSerialResponse>>), + Mock(Client<MockClientHandle<DefaultOnSend, ProtoError>>), } impl DnsClient { @@ -63,18 +56,18 @@ impl DnsClient { } #[cfg(test)] - pub async fn connect_mock(records: HashMap<&'static str, Vec<Record>>) -> Result<Self, DnsClientError> { - let client = Client::connect_mock(records).await?; + pub async fn connect_mock(messages: Vec<Result<DnsResponse, ProtoError>>) -> Result<Self, DnsClientError> { + let client = Client::connect_mock(messages).await?; Ok(DnsClient::Mock(client)) } - pub async fn lookup(&mut self, query: Query, options: DnsRequestOptions) -> Result<DnsResponse, DnsClientError> { + pub async fn lookup(&mut self, query: Query) -> Result<DnsResponse, DnsClientError> { use DnsClient::*; match self { - Secure(ref mut client) => client.lookup(query, options).await, - Normal(ref mut client) => client.lookup(query, options).await, + Secure(ref mut client) => client.lookup(query).await, + Normal(ref mut client) => client.lookup(query).await, #[cfg(test)] - Mock(ref mut client) => client.lookup(query, options).await, + Mock(ref mut client) => client.lookup(query).await, } } @@ -85,11 +78,11 @@ impl DnsClient { .set_query_class(DNSClass::IN) .set_query_type(RecordType::TXT); - let response = self.lookup(query, Default::default()).await?; + let responses = self.lookup(query).await?; - let records = response - .messages() - .flat_map(|msg| msg.answers()) + let records = responses + .answers() + .iter() .map(|answer| { let data = answer.rdata(); let mut buf = Vec::new(); @@ -116,7 +109,7 @@ pub struct Client<C> { shutdown: Arc<Shutdown>, } -impl Client<AsyncDnssecClient<UdpResponse>> { +impl Client<AsyncDnssecClient> { pub async fn connect_secure(name_server: SocketAddr, trust_anchor: TrustAnchor) -> Result<Self, DnsClientError> { let shutdown = Shutdown::new(); let stream = UdpClientStream::<UdpSocket>::new(name_server); @@ -124,7 +117,7 @@ impl Client<AsyncDnssecClient<UdpResponse>> { .trust_anchor(trust_anchor) .build() .await?; - task::spawn(future::select(shutdown.to_signal(), background)); + task::spawn(future::select(shutdown.to_signal(), background.fuse())); Ok(Self { inner: client, @@ -133,12 +126,12 @@ impl Client<AsyncDnssecClient<UdpResponse>> { } } -impl Client<AsyncClient<UdpResponse>> { +impl Client<AsyncClient> { pub async fn connect(name_server: SocketAddr) -> Result<Self, DnsClientError> { let shutdown = Shutdown::new(); let stream = UdpClientStream::<UdpSocket>::new(name_server); let (client, background) = AsyncClient::connect(stream).await?; - task::spawn(future::select(shutdown.to_signal(), background)); + task::spawn(future::select(shutdown.to_signal(), background.fuse())); Ok(Self { inner: client, @@ -148,87 +141,31 @@ impl Client<AsyncClient<UdpResponse>> { } impl<C> Client<C> -where C: DnsHandle +where C: DnsHandle<Error = ProtoError> { - pub async fn lookup(&mut self, query: Query, options: DnsRequestOptions) -> Result<DnsResponse, DnsClientError> { - let resp = self.inner.lookup(query, options).await?; - Ok(resp) + pub async fn lookup(&mut self, query: Query) -> Result<DnsResponse, DnsClientError> { + let client_resp = self + .inner + .query(query.name().clone(), query.query_class(), query.query_type()) + .await?; + Ok(client_resp) } } #[cfg(test)] mod mock { use super::*; - use futures::{channel::mpsc, future, Stream, StreamExt}; - use std::{ - fmt, - fmt::Display, - net::SocketAddr, - pin::Pin, - sync::Arc, - task::{Context, Poll}, - }; + use crate::dns::mock::{DefaultOnSend, MockClientHandle}; + use std::sync::Arc; use tari_shutdown::Shutdown; - use tokio::task; - use trust_dns_client::{ - client::AsyncClient, - op::Message, - proto::{ - error::ProtoError, - xfer::{DnsClientStream, DnsMultiplexerSerialResponse, SerialMessage}, - StreamHandle, - }, - rr::Record, - }; - - pub struct MockStream { - receiver: mpsc::UnboundedReceiver<Vec<u8>>, - answers: HashMap<&'static str, Vec<Record>>, - } - - impl DnsClientStream for MockStream { - fn name_server_addr(&self) -> SocketAddr { - ([0u8, 0, 0, 0], 53).into() - } - } - - impl Display for MockStream { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "MockStream") - } - } - - impl Stream for MockStream { - type Item = Result<SerialMessage, ProtoError>; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - let req = match futures::ready!(self.receiver.poll_next_unpin(cx)) { - Some(r) => r, - None => return Poll::Ready(None), - }; - let req = Message::from_vec(&req).unwrap(); - let name = req.queries()[0].name().to_string(); - let mut msg = Message::new(); - let answers = self.answers.get(name.as_str()).into_iter().flatten().cloned(); - msg.set_id(req.id()).add_answers(answers); - Poll::Ready(Some(Ok(SerialMessage::new( - msg.to_vec().unwrap(), - self.name_server_addr(), - )))) - } - } - - impl Client<AsyncClient<DnsMultiplexerSerialResponse>> { - pub async fn connect_mock(answers: HashMap<&'static str, Vec<Record>>) -> Result<Self, ProtoError> { - let (tx, rx) = mpsc::unbounded(); - let stream = future::ready(Ok(MockStream { receiver: rx, answers })); - let (client, background) = AsyncClient::new(stream, Box::new(StreamHandle::new(tx)), None).await?; + use trust_dns_client::proto::error::ProtoError; - let shutdown = Shutdown::new(); - task::spawn(future::select(shutdown.to_signal(), background)); + impl Client<MockClientHandle<DefaultOnSend, ProtoError>> { + pub async fn connect_mock(messages: Vec<Result<DnsResponse, ProtoError>>) -> Result<Self, ProtoError> { + let client = MockClientHandle::mock(messages); Ok(Self { inner: client, - shutdown: Arc::new(shutdown), + shutdown: Arc::new(Shutdown::new()), }) } } diff --git a/base_layer/p2p/src/dns/mock.rs b/base_layer/p2p/src/dns/mock.rs new file mode 100644 index 0000000000..7f181dad90 --- /dev/null +++ b/base_layer/p2p/src/dns/mock.rs @@ -0,0 +1,106 @@ +// Copyright 2021, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use futures::{future, stream, Future}; +use std::{error::Error, pin::Pin, sync::Arc}; +use trust_dns_client::{ + op::{Message, Query}, + proto::{ + error::ProtoError, + xfer::{DnsHandle, DnsRequest, DnsResponse}, + }, + rr::Record, +}; + +#[derive(Clone)] +pub struct MockClientHandle<O: OnSend, E> { + messages: Arc<Vec<Result<DnsResponse, E>>>, + on_send: O, +} + +impl<E> MockClientHandle<DefaultOnSend, E> { + /// constructs a new MockClient which returns each Message one after the other + pub fn mock(messages: Vec<Result<DnsResponse, E>>) -> Self { + println!("MockClientHandle::mock message count: {}", messages.len()); + + MockClientHandle { + messages: Arc::new(messages), + on_send: DefaultOnSend, + } + } +} + +impl<O: OnSend + Unpin, E> DnsHandle for MockClientHandle<O, E> +where E: From<ProtoError> + Error + Clone + Send + Sync + Unpin + 'static +{ + type Error = E; + type Response = stream::Once<future::Ready<Result<DnsResponse, E>>>; + + fn send<R: Into<DnsRequest>>(&mut self, _: R) -> Self::Response { + let responses = (*self.messages) + .clone() + .into_iter() + .fold(Result::<_, E>::Ok(Message::new()), |msg, resp| { + msg.and_then(|mut msg| { + resp.map(move |resp| { + msg.add_answers(resp.answers().iter().cloned()); + msg + }) + }) + }) + .map(DnsResponse::from); + + let stream = stream::once(future::ready(responses)); + + // let stream = stream::unfold(messages, |mut msgs| async move { + // let msg = msgs.pop()?; + // Some((msg, msgs)) + // }); + stream + } +} + +pub fn message(query: Query, answers: Vec<Record>, name_servers: Vec<Record>, additionals: Vec<Record>) -> Message { + let mut message = Message::new(); + message.add_query(query); + message.insert_answers(answers); + message.insert_name_servers(name_servers); + message.insert_additionals(additionals); + message +} + +pub trait OnSend: Clone + Send + Sync + 'static { + fn on_send<E>( + &mut self, + response: Result<DnsResponse, E>, + ) -> Pin<Box<dyn Future<Output = Result<DnsResponse, E>> + Send>> + where + E: From<ProtoError> + Send + 'static, + { + Box::pin(future::ready(response)) + } +} + +#[derive(Clone)] +pub struct DefaultOnSend; + +impl OnSend for DefaultOnSend {} diff --git a/base_layer/p2p/src/dns/mod.rs b/base_layer/p2p/src/dns/mod.rs index 197b788236..8c49de1993 100644 --- a/base_layer/p2p/src/dns/mod.rs +++ b/base_layer/p2p/src/dns/mod.rs @@ -4,6 +4,9 @@ pub use client::DnsClient; mod error; pub use error::DnsClientError; +#[cfg(test)] +pub(crate) mod mock; + use trust_dns_client::proto::rr::dnssec::{public_key::Rsa, TrustAnchor}; #[inline] diff --git a/base_layer/p2p/src/initialization.rs b/base_layer/p2p/src/initialization.rs index 9264a5033c..0915479612 100644 --- a/base_layer/p2p/src/initialization.rs +++ b/base_layer/p2p/src/initialization.rs @@ -19,20 +19,20 @@ // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#![allow(dead_code)] use crate::{ - comms_connector::{InboundDomainConnector, PeerMessage, PubsubDomainConnector}, + comms_connector::{InboundDomainConnector, PubsubDomainConnector}, peer_seeds::{DnsSeedResolver, SeedPeer}, transport::{TorConfig, TransportType}, MAJOR_NETWORK_VERSION, MINOR_NETWORK_VERSION, }; use fs2::FileExt; -use futures::{channel::mpsc, future, Sink}; +use futures::future; use log::*; use rand::{distributions::Alphanumeric, thread_rng, Rng}; use std::{ - error::Error, fs::File, iter, net::SocketAddr, @@ -47,7 +47,6 @@ use tari_comms::{ multiaddr::Multiaddr, peer_manager::{NodeIdentity, Peer, PeerFeatures, PeerManagerError}, pipeline, - pipeline::SinkService, protocol::{ messaging::{MessagingEventSender, MessagingProtocolExtension}, rpc::RpcServer, @@ -71,7 +70,7 @@ use tari_storage::{ LMDBWrapper, }; use thiserror::Error; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc}; use tower::ServiceBuilder; const LOG_TARGET: &str = "p2p::initialization"; @@ -158,18 +157,14 @@ pub struct CommsConfig { } /// Initialize Tari Comms configured for tests -pub async fn initialize_local_test_comms<TSink>( +pub async fn initialize_local_test_comms( node_identity: Arc<NodeIdentity>, - connector: InboundDomainConnector<TSink>, + connector: InboundDomainConnector, data_path: &str, discovery_request_timeout: Duration, seed_peers: Vec<Peer>, shutdown_signal: ShutdownSignal, -) -> Result<(CommsNode, Dht, MessagingEventSender), CommsInitializationError> -where - TSink: Sink<Arc<PeerMessage>> + Unpin + Clone + Send + Sync + 'static, - TSink::Error: Error + Send + Sync, -{ +) -> Result<(CommsNode, Dht, MessagingEventSender), CommsInitializationError> { let peer_database_name = { let mut rng = thread_rng(); iter::repeat(()) @@ -230,7 +225,7 @@ where .with_inbound_pipeline( ServiceBuilder::new() .layer(dht.inbound_middleware_layer()) - .service(SinkService::new(connector)), + .service(connector), ) .build(); @@ -319,15 +314,11 @@ async fn initialize_hidden_service( builder.build().await } -async fn configure_comms_and_dht<TSink>( +async fn configure_comms_and_dht( builder: CommsBuilder, config: &CommsConfig, - connector: InboundDomainConnector<TSink>, -) -> Result<(UnspawnedCommsNode, Dht), CommsInitializationError> -where - TSink: Sink<Arc<PeerMessage>> + Unpin + Clone + Send + Sync + 'static, - TSink::Error: Error + Send + Sync, -{ + connector: InboundDomainConnector, +) -> Result<(UnspawnedCommsNode, Dht), CommsInitializationError> { let file_lock = acquire_exclusive_file_lock(&config.datastore_path)?; let datastore = LMDBBuilder::new() @@ -391,7 +382,7 @@ where .with_inbound_pipeline( ServiceBuilder::new() .layer(dht.inbound_middleware_layer()) - .service(SinkService::new(connector)), + .service(connector), ) .build(); diff --git a/base_layer/p2p/src/lib.rs b/base_layer/p2p/src/lib.rs index c21dace083..63ead30b11 100644 --- a/base_layer/p2p/src/lib.rs +++ b/base_layer/p2p/src/lib.rs @@ -20,7 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -// Needed to make futures::select! work +// Needed to make tokio::select! work #![recursion_limit = "256"] #![cfg_attr(not(debug_assertions), deny(unused_variables))] #![cfg_attr(not(debug_assertions), deny(unused_imports))] diff --git a/base_layer/p2p/src/peer_seeds.rs b/base_layer/p2p/src/peer_seeds.rs index 24ce3c4dda..4698770fa1 100644 --- a/base_layer/p2p/src/peer_seeds.rs +++ b/base_layer/p2p/src/peer_seeds.rs @@ -120,6 +120,8 @@ mod test { use super::*; use tari_utilities::hex::Hex; + const TEST_NAME: &str = "test.local."; + mod peer_seed { use super::*; @@ -182,76 +184,88 @@ mod test { mod peer_seed_resolver { use super::*; - use std::{collections::HashMap, iter::FromIterator}; - use trust_dns_client::rr::{rdata, RData, Record, RecordType}; + use crate::dns::mock; + use trust_dns_client::{ + proto::{ + op::Query, + rr::{DNSClass, Name}, + xfer::DnsResponse, + }, + rr::{rdata, RData, Record, RecordType}, + }; #[ignore = "This test requires network IO and is mostly useful during development"] - #[tokio_macros::test] - async fn it_returns_an_empty_vec_if_all_seeds_are_invalid() { + #[tokio::test] + async fn it_returns_seeds_from_real_address() { let mut resolver = DnsSeedResolver { client: DnsClient::connect("1.1.1.1:53".parse().unwrap()).await.unwrap(), }; - let seeds = resolver.resolve("tari.com").await.unwrap(); - assert!(seeds.is_empty()); + let seeds = resolver.resolve("seeds.weatherwax.tari.com").await.unwrap(); + assert!(!seeds.is_empty()); } - fn create_txt_record(contents: Vec<&str>) -> Record { + fn create_txt_record(contents: Vec<&str>) -> DnsResponse { + let mut resp_query = Query::query(Name::from_str(TEST_NAME).unwrap(), RecordType::TXT); + resp_query.set_query_class(DNSClass::IN); let mut record = Record::new(); record .set_record_type(RecordType::TXT) .set_rdata(RData::TXT(rdata::TXT::new( contents.into_iter().map(ToString::to_string).collect(), ))); - record + + mock::message(resp_query, vec![record], vec![], vec![]).into() } - #[tokio_macros::test] + #[tokio::test] async fn it_returns_peer_seeds() { - let records = HashMap::from_iter([("test.local.", vec![ + let records = vec![ // Multiple addresses(works) - create_txt_record(vec![ - "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a::/ip4/127.0.0.1/tcp/8000::/\ + Ok(create_txt_record(vec![ + "fab24c542183073996ddf3a6c73ff8b8562fed351d252ec5cb8f269d1ad92f0c::/ip4/127.0.0.1/tcp/8000::/\ onion3/bsmuof2cn4y2ysz253gzsvg3s72fcgh4f3qcm3hdlxdtcwe6al2dicyd:1234", - ]), + ])), // Misc - create_txt_record(vec!["v=spf1 include:_spf.spf.com ~all"]), + Ok(create_txt_record(vec!["v=spf1 include:_spf.spf.com ~all"])), // Single address (works) - create_txt_record(vec![ + Ok(create_txt_record(vec![ "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a::/ip4/127.0.0.1/tcp/8000", - ]), + ])), // Single address trailing delim - create_txt_record(vec![ + Ok(create_txt_record(vec![ "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a::/ip4/127.0.0.1/tcp/8000::", - ]), + ])), // Invalid public key - create_txt_record(vec![ + Ok(create_txt_record(vec![ "07e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a::/ip4/127.0.0.1/tcp/8000", - ]), + ])), // No Address with delim - create_txt_record(vec![ + Ok(create_txt_record(vec![ "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a::", - ]), + ])), // No Address no delim - create_txt_record(vec!["06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a"]), + Ok(create_txt_record(vec![ + "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a", + ])), // Invalid address - create_txt_record(vec![ + Ok(create_txt_record(vec![ "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a::/onion3/invalid:1234", - ]), - ])]); + ])), + ]; let mut resolver = DnsSeedResolver { client: DnsClient::connect_mock(records).await.unwrap(), }; - let seeds = resolver.resolve("test.local.").await.unwrap(); + let seeds = resolver.resolve(TEST_NAME).await.unwrap(); assert_eq!(seeds.len(), 2); assert_eq!( seeds[0].public_key.to_hex(), - "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a" + "fab24c542183073996ddf3a6c73ff8b8562fed351d252ec5cb8f269d1ad92f0c" ); + assert_eq!(seeds[0].addresses.len(), 2); assert_eq!( seeds[1].public_key.to_hex(), "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a" ); - assert_eq!(seeds[0].addresses.len(), 2); assert_eq!(seeds[1].addresses.len(), 1); } } diff --git a/base_layer/p2p/src/services/liveness/mock.rs b/base_layer/p2p/src/services/liveness/mock.rs index fa5f2f72c1..470cca84c2 100644 --- a/base_layer/p2p/src/services/liveness/mock.rs +++ b/base_layer/p2p/src/services/liveness/mock.rs @@ -36,9 +36,9 @@ use std::sync::{ RwLock, }; -use tari_crypto::tari_utilities::acquire_write_lock; +use tari_crypto::tari_utilities::{acquire_read_lock, acquire_write_lock}; use tari_service_framework::{reply_channel, reply_channel::RequestContext}; -use tokio::sync::{broadcast, broadcast::SendError}; +use tokio::sync::{broadcast, broadcast::error::SendError}; const LOG_TARGET: &str = "p2p::liveness_mock"; @@ -69,7 +69,8 @@ impl LivenessMockState { } pub async fn publish_event(&self, event: LivenessEvent) -> Result<(), SendError<Arc<LivenessEvent>>> { - acquire_write_lock!(self.event_publisher).send(Arc::new(event))?; + let lock = acquire_read_lock!(self.event_publisher); + lock.send(Arc::new(event))?; Ok(()) } diff --git a/base_layer/p2p/src/services/liveness/service.rs b/base_layer/p2p/src/services/liveness/service.rs index 65e85e0ea8..8e373e9abb 100644 --- a/base_layer/p2p/src/services/liveness/service.rs +++ b/base_layer/p2p/src/services/liveness/service.rs @@ -49,6 +49,7 @@ use tari_comms_dht::{ use tari_service_framework::reply_channel::RequestContext; use tari_shutdown::ShutdownSignal; use tokio::time; +use tokio_stream::wrappers; /// Service responsible for testing Liveness of Peers. pub struct LivenessService<THandleStream, TPingStream> { @@ -59,7 +60,7 @@ pub struct LivenessService<THandleStream, TPingStream> { connectivity: ConnectivityRequester, outbound_messaging: OutboundMessageRequester, event_publisher: LivenessEventSender, - shutdown_signal: Option<ShutdownSignal>, + shutdown_signal: ShutdownSignal, } impl<TRequestStream, TPingStream> LivenessService<TRequestStream, TPingStream> @@ -85,7 +86,7 @@ where connectivity, outbound_messaging, event_publisher, - shutdown_signal: Some(shutdown_signal), + shutdown_signal, config, } } @@ -100,39 +101,36 @@ where pin_mut!(request_stream); let mut ping_tick = match self.config.auto_ping_interval { - Some(interval) => Either::Left(time::interval_at((Instant::now() + interval).into(), interval)), + Some(interval) => Either::Left(wrappers::IntervalStream::new(time::interval_at( + (Instant::now() + interval).into(), + interval, + ))), None => Either::Right(futures::stream::iter(iter::empty())), - } - .fuse(); - - let mut shutdown_signal = self - .shutdown_signal - .take() - .expect("Liveness service initialized without shutdown signal"); + }; loop { - futures::select! { + tokio::select! { // Requests from the handle - request_context = request_stream.select_next_some() => { + Some(request_context) = request_stream.next() => { let (request, reply_tx) = request_context.split(); let _ = reply_tx.send(self.handle_request(request).await); }, // Tick events - _ = ping_tick.select_next_some() => { + Some(_) = ping_tick.next() => { if let Err(err) = self.start_ping_round().await { warn!(target: LOG_TARGET, "Error when pinging peers: {}", err); } }, // Incoming messages from the Comms layer - msg = ping_stream.select_next_some() => { + Some(msg) = ping_stream.next() => { if let Err(err) = self.handle_incoming_message(msg).await { warn!(target: LOG_TARGET, "Failed to handle incoming PingPong message: {}", err); } }, - _ = shutdown_signal => { + _ = self.shutdown_signal.wait() => { info!(target: LOG_TARGET, "Liveness service shutting down because the shutdown signal was received"); break; } @@ -306,10 +304,7 @@ mod test { proto::liveness::MetadataKey, services::liveness::{handle::LivenessHandle, state::Metadata}, }; - use futures::{ - channel::{mpsc, oneshot}, - stream, - }; + use futures::stream; use rand::rngs::OsRng; use std::time::Duration; use tari_comms::{ @@ -325,9 +320,12 @@ mod test { use tari_crypto::keys::PublicKey; use tari_service_framework::reply_channel; use tari_shutdown::Shutdown; - use tokio::{sync::broadcast, task}; + use tokio::{ + sync::{broadcast, mpsc, oneshot}, + task, + }; - #[tokio_macros::test_basic] + #[tokio::test] async fn get_ping_pong_count() { let mut state = LivenessState::new(); state.inc_pings_received(); @@ -369,7 +367,7 @@ mod test { assert_eq!(res, 2); } - #[tokio_macros::test] + #[tokio::test] async fn send_ping() { let (connectivity, mock) = create_connectivity_mock(); mock.spawn(); @@ -401,8 +399,8 @@ mod test { let node_id = NodeId::from_key(&pk); // Receive outbound request task::spawn(async move { - match outbound_rx.select_next_some().await { - DhtOutboundRequest::SendMessage(_, _, reply_tx) => { + match outbound_rx.recv().await { + Some(DhtOutboundRequest::SendMessage(_, _, reply_tx)) => { let (_, rx) = oneshot::channel(); reply_tx .send(SendMessageResponse::Queued( @@ -410,6 +408,7 @@ mod test { )) .unwrap(); }, + None => {}, } }); @@ -445,7 +444,7 @@ mod test { } } - #[tokio_macros::test] + #[tokio::test] async fn handle_message_ping() { let state = LivenessState::new(); @@ -478,10 +477,10 @@ mod test { task::spawn(service.run()); // Test oms got request to send message - unwrap_oms_send_msg!(outbound_rx.select_next_some().await); + unwrap_oms_send_msg!(outbound_rx.recv().await.unwrap()); } - #[tokio_macros::test_basic] + #[tokio::test] async fn handle_message_pong() { let mut state = LivenessState::new(); @@ -516,9 +515,9 @@ mod test { task::spawn(service.run()); // Listen for the pong event - let subscriber = publisher.subscribe(); + let mut subscriber = publisher.subscribe(); - let event = time::timeout(Duration::from_secs(10), subscriber.fuse().select_next_some()) + let event = time::timeout(Duration::from_secs(10), subscriber.recv()) .await .unwrap() .unwrap(); @@ -530,12 +529,12 @@ mod test { _ => panic!("Unexpected event"), } - shutdown.trigger().unwrap(); + shutdown.trigger(); // No further events (malicious_msg was ignored) - let mut subscriber = publisher.subscribe().fuse(); + let mut subscriber = publisher.subscribe(); drop(publisher); - let msg = subscriber.next().await; - assert!(msg.is_none()); + let msg = subscriber.recv().await; + assert!(msg.is_err()); } } diff --git a/base_layer/p2p/tests/services/liveness.rs b/base_layer/p2p/tests/services/liveness.rs index ab9a66cf59..2505c15543 100644 --- a/base_layer/p2p/tests/services/liveness.rs +++ b/base_layer/p2p/tests/services/liveness.rs @@ -35,16 +35,15 @@ use tari_p2p::{ }; use tari_service_framework::{RegisterHandle, StackBuilder}; use tari_shutdown::Shutdown; -use tari_test_utils::collect_stream; +use tari_test_utils::collect_try_recv; use tempfile::tempdir; -use tokio::runtime; pub async fn setup_liveness_service( node_identity: Arc<NodeIdentity>, peers: Vec<Arc<NodeIdentity>>, data_path: &str, ) -> (LivenessHandle, CommsNode, Dht, Shutdown) { - let (publisher, subscription_factory) = pubsub_connector(runtime::Handle::current(), 100, 20); + let (publisher, subscription_factory) = pubsub_connector(100, 20); let subscription_factory = Arc::new(subscription_factory); let shutdown = Shutdown::new(); let (comms, dht, _) = @@ -75,7 +74,7 @@ fn make_node_identity() -> Arc<NodeIdentity> { )) } -#[tokio_macros::test_basic] +#[tokio::test] async fn end_to_end() { let node_1_identity = make_node_identity(); let node_2_identity = make_node_identity(); @@ -114,34 +113,34 @@ async fn end_to_end() { liveness1.send_ping(node_2_identity.node_id().clone()).await.unwrap(); } - let events = collect_stream!(liveness1_event_stream, take = 18, timeout = Duration::from_secs(20),); + let events = collect_try_recv!(liveness1_event_stream, take = 18, timeout = Duration::from_secs(20)); let ping_count = events .iter() - .filter(|event| matches!(**(**event).as_ref().unwrap(), LivenessEvent::ReceivedPing(_))) + .filter(|event| matches!(&***event, LivenessEvent::ReceivedPing(_))) .count(); assert_eq!(ping_count, 10); let pong_count = events .iter() - .filter(|event| matches!(**(**event).as_ref().unwrap(), LivenessEvent::ReceivedPong(_))) + .filter(|event| matches!(&***event, LivenessEvent::ReceivedPong(_))) .count(); assert_eq!(pong_count, 8); - let events = collect_stream!(liveness2_event_stream, take = 18, timeout = Duration::from_secs(10),); + let events = collect_try_recv!(liveness2_event_stream, take = 18, timeout = Duration::from_secs(10)); let ping_count = events .iter() - .filter(|event| matches!(**(**event).as_ref().unwrap(), LivenessEvent::ReceivedPing(_))) + .filter(|event| matches!(&***event, LivenessEvent::ReceivedPing(_))) .count(); assert_eq!(ping_count, 8); let pong_count = events .iter() - .filter(|event| matches!(**(**event).as_ref().unwrap(), LivenessEvent::ReceivedPong(_))) + .filter(|event| matches!(&***event, LivenessEvent::ReceivedPong(_))) .count(); assert_eq!(pong_count, 10); diff --git a/base_layer/p2p/tests/support/comms_and_services.rs b/base_layer/p2p/tests/support/comms_and_services.rs index 40cc85a710..33bc8fdef7 100644 --- a/base_layer/p2p/tests/support/comms_and_services.rs +++ b/base_layer/p2p/tests/support/comms_and_services.rs @@ -20,27 +20,19 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::Sink; -use std::{error::Error, sync::Arc, time::Duration}; +use std::{sync::Arc, time::Duration}; use tari_comms::{peer_manager::NodeIdentity, protocol::messaging::MessagingEventSender, CommsNode}; use tari_comms_dht::Dht; -use tari_p2p::{ - comms_connector::{InboundDomainConnector, PeerMessage}, - initialization::initialize_local_test_comms, -}; +use tari_p2p::{comms_connector::InboundDomainConnector, initialization::initialize_local_test_comms}; use tari_shutdown::ShutdownSignal; -pub async fn setup_comms_services<TSink>( +pub async fn setup_comms_services( node_identity: Arc<NodeIdentity>, peers: Vec<Arc<NodeIdentity>>, - publisher: InboundDomainConnector<TSink>, + publisher: InboundDomainConnector, data_path: &str, shutdown_signal: ShutdownSignal, -) -> (CommsNode, Dht, MessagingEventSender) -where - TSink: Sink<Arc<PeerMessage>> + Clone + Unpin + Send + Sync + 'static, - TSink::Error: Error + Send + Sync, -{ +) -> (CommsNode, Dht, MessagingEventSender) { let peers = peers.into_iter().map(|ni| ni.to_peer()).collect(); let (comms, dht, messaging_events) = initialize_local_test_comms( node_identity, diff --git a/base_layer/service_framework/Cargo.toml b/base_layer/service_framework/Cargo.toml index 49e829c949..ead5e311f2 100644 --- a/base_layer/service_framework/Cargo.toml +++ b/base_layer/service_framework/Cargo.toml @@ -14,15 +14,15 @@ tari_shutdown = { version = "^0.9", path="../../infrastructure/shutdown" } anyhow = "1.0.32" async-trait = "0.1.50" -futures = { version = "^0.3.1", features=["async-await"]} +futures = { version = "^0.3.16", features=["async-await"]} log = "0.4.8" -thiserror = "1.0.20" -tokio = { version = "0.2.10" } +thiserror = "1.0.26" +tokio = {version="1.10", features=["rt"]} tower-service = { version="0.3.0" } [dev-dependencies] tari_test_utils = { version = "^0.9", path="../../infrastructure/test_utils" } +tokio = {version="1.10", features=["rt-multi-thread", "macros", "time"]} futures-test = { version = "0.3.3" } -tokio-macros = "0.2.5" tower = "0.3.1" diff --git a/base_layer/service_framework/examples/services/service_a.rs b/base_layer/service_framework/examples/services/service_a.rs index c898696415..d7adcfdffb 100644 --- a/base_layer/service_framework/examples/services/service_a.rs +++ b/base_layer/service_framework/examples/services/service_a.rs @@ -69,7 +69,7 @@ impl ServiceA { pin_mut!(request_stream); loop { - futures::select! { + tokio::select! { //Incoming request request_context = request_stream.select_next_some() => { println!("Handling Service A API Request"); diff --git a/base_layer/service_framework/examples/services/service_b.rs b/base_layer/service_framework/examples/services/service_b.rs index decf53ab14..7b483860af 100644 --- a/base_layer/service_framework/examples/services/service_b.rs +++ b/base_layer/service_framework/examples/services/service_b.rs @@ -31,7 +31,7 @@ use tari_service_framework::{ ServiceInitializerContext, }; use tari_shutdown::ShutdownSignal; -use tokio::time::delay_for; +use tokio::time::sleep; use tower::Service; pub struct ServiceB { @@ -67,7 +67,7 @@ impl ServiceB { pin_mut!(request_stream); loop { - futures::select! { + tokio::select! { //Incoming request request_context = request_stream.select_next_some() => { println!("Handling Service B API Request"); @@ -134,7 +134,7 @@ impl ServiceInitializer for ServiceBInitializer { println!("Service B has shutdown and initializer spawned task is now ending"); }); - delay_for(Duration::from_secs(10)).await; + sleep(Duration::from_secs(10)).await; Ok(()) } } diff --git a/base_layer/service_framework/examples/stack_builder_example.rs b/base_layer/service_framework/examples/stack_builder_example.rs index 35fd785ed7..6f150796e7 100644 --- a/base_layer/service_framework/examples/stack_builder_example.rs +++ b/base_layer/service_framework/examples/stack_builder_example.rs @@ -25,9 +25,9 @@ use crate::services::{ServiceAHandle, ServiceAInitializer, ServiceBHandle, Servi use std::time::Duration; use tari_service_framework::StackBuilder; use tari_shutdown::Shutdown; -use tokio::time::delay_for; +use tokio::time::sleep; -#[tokio_macros::main] +#[tokio::main] async fn main() { let mut shutdown = Shutdown::new(); let fut = StackBuilder::new(shutdown.to_signal()) @@ -40,7 +40,7 @@ async fn main() { let mut service_a_handle = handles.expect_handle::<ServiceAHandle>(); let mut service_b_handle = handles.expect_handle::<ServiceBHandle>(); - delay_for(Duration::from_secs(1)).await; + sleep(Duration::from_secs(1)).await; println!("----------------------------------------------------"); let response_b = service_b_handle.send_msg("Hello B".to_string()).await; println!("Response from Service B: {}", response_b); @@ -51,5 +51,5 @@ async fn main() { let _ = shutdown.trigger(); - delay_for(Duration::from_secs(5)).await; + sleep(Duration::from_secs(5)).await; } diff --git a/base_layer/service_framework/src/reply_channel.rs b/base_layer/service_framework/src/reply_channel.rs index 54cef8d90f..b0deffb190 100644 --- a/base_layer/service_framework/src/reply_channel.rs +++ b/base_layer/service_framework/src/reply_channel.rs @@ -20,26 +20,15 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::{ - channel::{ - mpsc::{self, SendError}, - oneshot, - }, - ready, - stream::FusedStream, - task::Context, - Future, - FutureExt, - Stream, - StreamExt, -}; +use futures::{ready, stream::FusedStream, task::Context, Future, FutureExt, Stream}; use std::{pin::Pin, task::Poll}; use thiserror::Error; +use tokio::sync::{mpsc, oneshot}; use tower_service::Service; /// Create a new Requester/Responder pair which wraps and calls the given service pub fn unbounded<TReq, TResp>() -> (SenderService<TReq, TResp>, Receiver<TReq, TResp>) { - let (tx, rx) = mpsc::unbounded(); + let (tx, rx) = mpsc::unbounded_channel(); (SenderService::new(tx), Receiver::new(rx)) } @@ -81,20 +70,15 @@ impl<TReq, TRes> Service<TReq> for SenderService<TReq, TRes> { type Future = TransportResponseFuture<TRes>; type Response = TRes; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { - self.tx.poll_ready(cx).map_err(|err| { - if err.is_disconnected() { - return TransportChannelError::ChannelClosed; - } - - unreachable!("unbounded channels can never be full"); - }) + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + // An unbounded sender is always ready (i.e. will never wait to send) + Poll::Ready(Ok(())) } fn call(&mut self, request: TReq) -> Self::Future { let (tx, rx) = oneshot::channel(); - if self.tx.unbounded_send((request, tx)).is_ok() { + if self.tx.send((request, tx)).is_ok() { TransportResponseFuture::new(rx) } else { // We're not able to send (rx closed) so return a future which resolves to @@ -106,8 +90,6 @@ impl<TReq, TRes> Service<TReq> for SenderService<TReq, TRes> { #[derive(Debug, Error, Eq, PartialEq, Clone)] pub enum TransportChannelError { - #[error("Error occurred when sending: `{0}`")] - SendError(#[from] SendError), #[error("Request was canceled")] Canceled, #[error("The response channel has closed")] @@ -188,23 +170,21 @@ impl<TReq, TResp> RequestContext<TReq, TResp> { } /// Receiver side of the reply channel. -/// This is functionally equivalent to `rx.map(|(req, reply_tx)| RequestContext::new(req, reply_tx))` -/// but is ergonomically better to use with the `futures::select` macro (implements FusedStream) -/// and has a short type signature. pub struct Receiver<TReq, TResp> { rx: Rx<TReq, TResp>, + is_closed: bool, } impl<TReq, TResp> FusedStream for Receiver<TReq, TResp> { fn is_terminated(&self) -> bool { - self.rx.is_terminated() + self.is_closed } } impl<TReq, TResp> Receiver<TReq, TResp> { // Create a new Responder pub fn new(rx: Rx<TReq, TResp>) -> Self { - Self { rx } + Self { rx, is_closed: false } } pub fn close(&mut self) { @@ -216,10 +196,17 @@ impl<TReq, TResp> Stream for Receiver<TReq, TResp> { type Item = RequestContext<TReq, TResp>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - match ready!(self.rx.poll_next_unpin(cx)) { + if self.is_terminated() { + return Poll::Ready(None); + } + + match ready!(self.rx.poll_recv(cx)) { Some((req, tx)) => Poll::Ready(Some(RequestContext::new(req, tx))), // Stream has closed, so we're done - None => Poll::Ready(None), + None => { + self.is_closed = true; + Poll::Ready(None) + }, } } } diff --git a/base_layer/service_framework/src/stack.rs b/base_layer/service_framework/src/stack.rs index 2489006d56..ad38a94be3 100644 --- a/base_layer/service_framework/src/stack.rs +++ b/base_layer/service_framework/src/stack.rs @@ -103,7 +103,7 @@ mod test { use tari_shutdown::Shutdown; use tower::service_fn; - #[tokio_macros::test] + #[tokio::test] async fn service_defn_simple() { // This is less of a test and more of a demo of using the short-hand implementation of ServiceInitializer let simple_initializer = |_: ServiceInitializerContext| Ok(()); @@ -155,7 +155,7 @@ mod test { } } - #[tokio_macros::test] + #[tokio::test] async fn service_stack_new() { let shared_state = Arc::new(AtomicUsize::new(0)); diff --git a/base_layer/tari_stratum_ffi/Cargo.toml b/base_layer/tari_stratum_ffi/Cargo.toml index 6598df9f06..9ed773d89f 100644 --- a/base_layer/tari_stratum_ffi/Cargo.toml +++ b/base_layer/tari_stratum_ffi/Cargo.toml @@ -14,7 +14,7 @@ tari_app_grpc = { path = "../../applications/tari_app_grpc" } tari_core = { path = "../../base_layer/core", default-features = false, features = ["transactions"]} tari_utilities = "^0.3" libc = "0.2.65" -thiserror = "1.0.20" +thiserror = "1.0.26" hex = "0.4.2" serde = { version="1.0.106", features = ["derive"] } serde_json = "1.0.57" diff --git a/base_layer/wallet/Cargo.toml b/base_layer/wallet/Cargo.toml index c2cf12ca22..b64cb34ab2 100644 --- a/base_layer/wallet/Cargo.toml +++ b/base_layer/wallet/Cargo.toml @@ -7,54 +7,53 @@ version = "0.9.5" edition = "2018" [dependencies] -tari_common_types = { version = "^0.9", path = "../../base_layer/common_types"} -tari_comms = { version = "^0.9", path = "../../comms"} +tari_common_types = { version = "^0.9", path = "../../base_layer/common_types" } +tari_comms = { version = "^0.9", path = "../../comms" } tari_comms_dht = { version = "^0.9", path = "../../comms/dht" } tari_crypto = "0.11.1" tari_key_manager = { version = "^0.9", path = "../key_manager" } tari_p2p = { version = "^0.9", path = "../p2p" } -tari_service_framework = { version = "^0.9", path = "../service_framework"} +tari_service_framework = { version = "^0.9", path = "../service_framework" } tari_shutdown = { version = "^0.9", path = "../../infrastructure/shutdown" } -tari_storage = { version = "^0.9", path = "../../infrastructure/storage"} +tari_storage = { version = "^0.9", path = "../../infrastructure/storage" } aes-gcm = "^0.8" blake2 = "0.9.0" -chrono = { version = "0.4.6", features = ["serde"]} +chrono = { version = "0.4.6", features = ["serde"] } crossbeam-channel = "0.3.8" digest = "0.9.0" -diesel = { version="1.4.7", features = ["sqlite", "serde_json", "chrono"]} +diesel = { version = "1.4.7", features = ["sqlite", "serde_json", "chrono"] } diesel_migrations = "1.4.0" -libsqlite3-sys = { version = ">=0.8.0, <0.13.0", features = ["bundled"], optional=true } +libsqlite3-sys = { version = ">=0.8.0, <0.13.0", features = ["bundled"], optional = true } fs2 = "0.3.0" -futures = { version = "^0.3.1", features =["compat", "std"]} +futures = { version = "^0.3.1", features = ["compat", "std"] } lazy_static = "1.4.0" log = "0.4.6" -log4rs = {version = "1.0.0", features = ["console_appender", "file_appender", "yaml_format"]} +log4rs = { version = "1.0.0", features = ["console_appender", "file_appender", "yaml_format"] } lmdb-zero = "0.4.4" rand = "0.8" -serde = {version = "1.0.89", features = ["derive"] } +serde = { version = "1.0.89", features = ["derive"] } serde_json = "1.0.39" -tokio = { version = "0.2.10", features = ["blocking", "sync"]} +tokio = { version = "1.10", features = ["sync", "macros"] } tower = "0.3.0-alpha.2" tempfile = "3.1.0" -time = {version = "0.1.39"} -thiserror = "1.0.20" +time = { version = "0.1.39" } +thiserror = "1.0.26" bincode = "1.3.1" [dependencies.tari_core] path = "../../base_layer/core" version = "^0.9" default-features = false -features = ["transactions", "mempool_proto", "base_node_proto",] +features = ["transactions", "mempool_proto", "base_node_proto", ] [dev-dependencies] -tari_p2p = { version = "^0.9", path = "../p2p", features=["test-mocks"]} -tari_comms_dht = { version = "^0.9", path = "../../comms/dht", features=["test-mocks"]} +tari_p2p = { version = "^0.9", path = "../p2p", features = ["test-mocks"] } +tari_comms_dht = { version = "^0.9", path = "../../comms/dht", features = ["test-mocks"] } tari_test_utils = { version = "^0.9", path = "../../infrastructure/test_utils" } lazy_static = "1.3.0" env_logger = "0.7.1" -prost = "0.6.1" -tokio-macros = "0.2.4" +prost = "0.8.0" [features] c_integration = [] diff --git a/base_layer/wallet/src/base_node_service/handle.rs b/base_layer/wallet/src/base_node_service/handle.rs index 4957823c72..f495479778 100644 --- a/base_layer/wallet/src/base_node_service/handle.rs +++ b/base_layer/wallet/src/base_node_service/handle.rs @@ -21,7 +21,6 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use super::{error::BaseNodeServiceError, service::BaseNodeState}; -use futures::{stream::Fuse, StreamExt}; use std::{sync::Arc, time::Duration}; use tari_common_types::chain_metadata::ChainMetadata; use tari_comms::peer_manager::Peer; @@ -72,8 +71,8 @@ impl BaseNodeServiceHandle { } } - pub fn get_event_stream_fused(&self) -> Fuse<BaseNodeEventReceiver> { - self.event_stream_sender.subscribe().fuse() + pub fn get_event_stream(&self) -> BaseNodeEventReceiver { + self.event_stream_sender.subscribe() } pub async fn get_chain_metadata(&mut self) -> Result<Option<ChainMetadata>, BaseNodeServiceError> { diff --git a/base_layer/wallet/src/base_node_service/mock_base_node_service.rs b/base_layer/wallet/src/base_node_service/mock_base_node_service.rs index 1bc57ed9d2..9aa981150d 100644 --- a/base_layer/wallet/src/base_node_service/mock_base_node_service.rs +++ b/base_layer/wallet/src/base_node_service/mock_base_node_service.rs @@ -20,13 +20,10 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{ - base_node_service::{ - error::BaseNodeServiceError, - handle::{BaseNodeServiceRequest, BaseNodeServiceResponse}, - service::BaseNodeState, - }, - connectivity_service::OnlineStatus, +use crate::base_node_service::{ + error::BaseNodeServiceError, + handle::{BaseNodeServiceRequest, BaseNodeServiceResponse}, + service::BaseNodeState, }; use futures::StreamExt; use tari_common_types::chain_metadata::ChainMetadata; @@ -81,30 +78,28 @@ impl MockBaseNodeService { /// Set the mock server state, either online and synced to a specific height, or offline with None pub fn set_base_node_state(&mut self, height: Option<u64>) { - let (chain_metadata, is_synced, online) = match height { + let (chain_metadata, is_synced) = match height { Some(height) => { let metadata = ChainMetadata::new(height, Vec::new(), 0, 0, 0); - (Some(metadata), Some(true), OnlineStatus::Online) + (Some(metadata), Some(true)) }, - None => (None, None, OnlineStatus::Offline), + None => (None, None), }; self.state = BaseNodeState { chain_metadata, is_synced, updated: None, latency: None, - online, } } pub fn set_default_base_node_state(&mut self) { - let metadata = ChainMetadata::new(std::u64::MAX, Vec::new(), 0, 0, 0); + let metadata = ChainMetadata::new(u64::MAX, Vec::new(), 0, 0, 0); self.state = BaseNodeState { chain_metadata: Some(metadata), is_synced: Some(true), updated: None, latency: None, - online: OnlineStatus::Online, } } diff --git a/base_layer/wallet/src/base_node_service/monitor.rs b/base_layer/wallet/src/base_node_service/monitor.rs index 5a2c3a7e76..8e0298ca27 100644 --- a/base_layer/wallet/src/base_node_service/monitor.rs +++ b/base_layer/wallet/src/base_node_service/monitor.rs @@ -25,7 +25,7 @@ use crate::{ handle::{BaseNodeEvent, BaseNodeEventSender}, service::BaseNodeState, }, - connectivity_service::{OnlineStatus, WalletConnectivityHandle}, + connectivity_service::WalletConnectivityHandle, error::WalletStorageError, storage::database::{WalletBackend, WalletDatabase}, }; @@ -33,7 +33,7 @@ use chrono::Utc; use log::*; use std::{convert::TryFrom, sync::Arc, time::Duration}; use tari_common_types::chain_metadata::ChainMetadata; -use tari_comms::{peer_manager::NodeId, protocol::rpc::RpcError}; +use tari_comms::protocol::rpc::RpcError; use tokio::{sync::RwLock, time}; const LOG_TARGET: &str = "wallet::base_node_service::chain_metadata_monitor"; @@ -78,9 +78,6 @@ impl<T: WalletBackend + 'static> BaseNodeMonitor<T> { }, Err(e @ BaseNodeMonitorError::RpcFailed(_)) => { warn!(target: LOG_TARGET, "Connectivity failure to base node: {}", e); - debug!(target: LOG_TARGET, "Setting as OFFLINE and retrying...",); - - self.set_offline().await; continue; }, Err(e @ BaseNodeMonitorError::InvalidBaseNodeResponse(_)) | @@ -96,34 +93,19 @@ impl<T: WalletBackend + 'static> BaseNodeMonitor<T> { ); } - async fn update_connectivity_status(&self) -> NodeId { - let mut watcher = self.wallet_connectivity.get_connectivity_status_watch(); - loop { - use OnlineStatus::*; - match watcher.recv().await.unwrap_or(Offline) { - Online => match self.wallet_connectivity.get_current_base_node_id() { - Some(node_id) => return node_id, - _ => continue, - }, - Connecting => { - self.set_connecting().await; - }, - Offline => { - self.set_offline().await; - }, - } - } - } - async fn monitor_node(&mut self) -> Result<(), BaseNodeMonitorError> { loop { - let peer_node_id = self.update_connectivity_status().await; let mut client = self .wallet_connectivity .obtain_base_node_wallet_rpc_client() .await .ok_or(BaseNodeMonitorError::NodeShuttingDown)?; + let base_node_id = match self.wallet_connectivity.get_current_base_node_id() { + Some(n) => n, + None => continue, + }; + let tip_info = client.get_tip_info().await?; let chain_metadata = tip_info @@ -138,7 +120,7 @@ impl<T: WalletBackend + 'static> BaseNodeMonitor<T> { debug!( target: LOG_TARGET, "Base node {} Tip: {} ({}) Latency: {} ms", - peer_node_id, + base_node_id, chain_metadata.height_of_longest_chain(), if is_synced { "Synced" } else { "Syncing..." }, latency.as_millis() @@ -151,11 +133,10 @@ impl<T: WalletBackend + 'static> BaseNodeMonitor<T> { is_synced: Some(is_synced), updated: Some(Utc::now().naive_utc()), latency: Some(latency), - online: OnlineStatus::Online, }) .await; - time::delay_for(self.interval).await + time::sleep(self.interval).await } // loop only exits on shutdown/error @@ -163,28 +144,6 @@ impl<T: WalletBackend + 'static> BaseNodeMonitor<T> { Ok(()) } - async fn set_connecting(&self) { - self.map_state(|_| BaseNodeState { - chain_metadata: None, - is_synced: None, - updated: Some(Utc::now().naive_utc()), - latency: None, - online: OnlineStatus::Connecting, - }) - .await; - } - - async fn set_offline(&self) { - self.map_state(|_| BaseNodeState { - chain_metadata: None, - is_synced: None, - updated: Some(Utc::now().naive_utc()), - latency: None, - online: OnlineStatus::Offline, - }) - .await; - } - async fn map_state<F>(&self, transform: F) where F: FnOnce(&BaseNodeState) -> BaseNodeState { let new_state = { diff --git a/base_layer/wallet/src/base_node_service/service.rs b/base_layer/wallet/src/base_node_service/service.rs index 3da987c8b1..eb2b91ebda 100644 --- a/base_layer/wallet/src/base_node_service/service.rs +++ b/base_layer/wallet/src/base_node_service/service.rs @@ -27,7 +27,7 @@ use super::{ }; use crate::{ base_node_service::monitor::BaseNodeMonitor, - connectivity_service::{OnlineStatus, WalletConnectivityHandle}, + connectivity_service::WalletConnectivityHandle, storage::database::{WalletBackend, WalletDatabase}, }; use chrono::NaiveDateTime; @@ -49,8 +49,6 @@ pub struct BaseNodeState { pub is_synced: Option<bool>, pub updated: Option<NaiveDateTime>, pub latency: Option<Duration>, - pub online: OnlineStatus, - // pub base_node_peer: Option<Peer>, } impl Default for BaseNodeState { @@ -60,7 +58,6 @@ impl Default for BaseNodeState { is_synced: None, updated: None, latency: None, - online: OnlineStatus::Connecting, } } } diff --git a/base_layer/wallet/src/connectivity_service/handle.rs b/base_layer/wallet/src/connectivity_service/handle.rs index ac218edc5e..5a35696e14 100644 --- a/base_layer/wallet/src/connectivity_service/handle.rs +++ b/base_layer/wallet/src/connectivity_service/handle.rs @@ -22,16 +22,12 @@ use super::service::OnlineStatus; use crate::connectivity_service::{error::WalletConnectivityError, watch::Watch}; -use futures::{ - channel::{mpsc, oneshot}, - SinkExt, -}; use tari_comms::{ peer_manager::{NodeId, Peer}, protocol::rpc::RpcClientLease, }; use tari_core::base_node::{rpc::BaseNodeWalletRpcClient, sync::rpc::BaseNodeSyncRpcClient}; -use tokio::sync::watch; +use tokio::sync::{mpsc, oneshot, watch}; pub enum WalletConnectivityRequest { ObtainBaseNodeWalletRpcClient(oneshot::Sender<RpcClientLease<BaseNodeWalletRpcClient>>), @@ -102,8 +98,8 @@ impl WalletConnectivityHandle { reply_rx.await.ok() } - pub async fn get_connectivity_status(&mut self) -> OnlineStatus { - self.online_status_rx.recv().await.unwrap_or(OnlineStatus::Offline) + pub fn get_connectivity_status(&mut self) -> OnlineStatus { + *self.online_status_rx.borrow() } pub fn get_connectivity_status_watch(&self) -> watch::Receiver<OnlineStatus> { diff --git a/base_layer/wallet/src/connectivity_service/initializer.rs b/base_layer/wallet/src/connectivity_service/initializer.rs index d0c2b94126..1610a834e3 100644 --- a/base_layer/wallet/src/connectivity_service/initializer.rs +++ b/base_layer/wallet/src/connectivity_service/initializer.rs @@ -30,8 +30,8 @@ use super::{handle::WalletConnectivityHandle, service::WalletConnectivityService, watch::Watch}; use crate::{base_node_service::config::BaseNodeServiceConfig, connectivity_service::service::OnlineStatus}; -use futures::channel::mpsc; use tari_service_framework::{async_trait, ServiceInitializationError, ServiceInitializer, ServiceInitializerContext}; +use tokio::sync::mpsc; pub struct WalletConnectivityInitializer { config: BaseNodeServiceConfig, @@ -59,8 +59,13 @@ impl ServiceInitializer for WalletConnectivityInitializer { context.spawn_until_shutdown(move |handles| { let connectivity = handles.expect_handle(); - let service = - WalletConnectivityService::new(config, receiver, base_node_watch, online_status_watch, connectivity); + let service = WalletConnectivityService::new( + config, + receiver, + base_node_watch.get_receiver(), + online_status_watch, + connectivity, + ); service.start() }); diff --git a/base_layer/wallet/src/connectivity_service/service.rs b/base_layer/wallet/src/connectivity_service/service.rs index c0cf474b96..ba3ff94c8b 100644 --- a/base_layer/wallet/src/connectivity_service/service.rs +++ b/base_layer/wallet/src/connectivity_service/service.rs @@ -24,15 +24,8 @@ use crate::{ base_node_service::config::BaseNodeServiceConfig, connectivity_service::{error::WalletConnectivityError, handle::WalletConnectivityRequest, watch::Watch}, }; -use core::mem; -use futures::{ - channel::{mpsc, oneshot}, - future, - future::Either, - stream::Fuse, - StreamExt, -}; use log::*; +use std::{mem, time::Duration}; use tari_comms::{ connectivity::ConnectivityRequester, peer_manager::{NodeId, Peer}, @@ -40,7 +33,11 @@ use tari_comms::{ PeerConnection, }; use tari_core::base_node::{rpc::BaseNodeWalletRpcClient, sync::rpc::BaseNodeSyncRpcClient}; -use tokio::time; +use tokio::{ + sync::{mpsc, oneshot, watch}, + time, + time::MissedTickBehavior, +}; const LOG_TARGET: &str = "wallet::connectivity"; @@ -54,9 +51,9 @@ pub enum OnlineStatus { pub struct WalletConnectivityService { config: BaseNodeServiceConfig, - request_stream: Fuse<mpsc::Receiver<WalletConnectivityRequest>>, + request_stream: mpsc::Receiver<WalletConnectivityRequest>, connectivity: ConnectivityRequester, - base_node_watch: Watch<Option<Peer>>, + base_node_watch: watch::Receiver<Option<Peer>>, pools: Option<ClientPoolContainer>, online_status_watch: Watch<OnlineStatus>, pending_requests: Vec<ReplyOneshot>, @@ -71,13 +68,13 @@ impl WalletConnectivityService { pub(super) fn new( config: BaseNodeServiceConfig, request_stream: mpsc::Receiver<WalletConnectivityRequest>, - base_node_watch: Watch<Option<Peer>>, + base_node_watch: watch::Receiver<Option<Peer>>, online_status_watch: Watch<OnlineStatus>, connectivity: ConnectivityRequester, ) -> Self { Self { config, - request_stream: request_stream.fuse(), + request_stream, connectivity, base_node_watch, pools: None, @@ -88,22 +85,40 @@ impl WalletConnectivityService { pub async fn start(mut self) { debug!(target: LOG_TARGET, "Wallet connectivity service has started."); - let mut base_node_watch_rx = self.base_node_watch.get_receiver().fuse(); + let mut check_connection = + time::interval_at(time::Instant::now() + Duration::from_secs(5), Duration::from_secs(5)); + check_connection.set_missed_tick_behavior(MissedTickBehavior::Delay); loop { - futures::select! { - req = self.request_stream.select_next_some() => { - self.handle_request(req).await; - }, - maybe_peer = base_node_watch_rx.select_next_some() => { - if maybe_peer.is_some() { + tokio::select! { + biased; + + Ok(_) = self.base_node_watch.changed() => { + if self.base_node_watch.borrow().is_some() { // This will block the rest until the connection is established. This is what we want. self.setup_base_node_connection().await; } + }, + + Some(req) = self.request_stream.recv() => { + self.handle_request(req).await; + }, + + _ = check_connection.tick() => { + self.check_connection().await; } } } } + async fn check_connection(&mut self) { + if let Some(pool) = self.pools.as_ref() { + if !pool.base_node_wallet_rpc_client.is_connected().await { + debug!(target: LOG_TARGET, "Peer connection lost. Attempting to reconnect..."); + self.setup_base_node_connection().await; + } + } + } + async fn handle_request(&mut self, request: WalletConnectivityRequest) { use WalletConnectivityRequest::*; match request { @@ -138,7 +153,6 @@ impl WalletConnectivityService { target: LOG_TARGET, "Base node connection failed: {}. Reconnecting...", e ); - self.trigger_reconnect(); self.pending_requests.push(reply.into()); }, }, @@ -169,7 +183,6 @@ impl WalletConnectivityService { target: LOG_TARGET, "Base node connection failed: {}. Reconnecting...", e ); - self.trigger_reconnect(); self.pending_requests.push(reply.into()); }, }, @@ -186,21 +199,6 @@ impl WalletConnectivityService { } } - fn trigger_reconnect(&mut self) { - let peer = self - .base_node_watch - .borrow() - .clone() - .expect("trigger_reconnect called before base node is set"); - // Trigger the watch so that a peer connection is reinitiated - self.set_base_node_peer(peer); - } - - fn set_base_node_peer(&mut self, peer: Peer) { - self.pools = None; - self.base_node_watch.broadcast(Some(peer)); - } - fn current_base_node(&self) -> Option<NodeId> { self.base_node_watch.borrow().as_ref().map(|p| p.node_id.clone()) } @@ -236,8 +234,8 @@ impl WalletConnectivityService { } else { self.set_online_status(OnlineStatus::Offline); } - error!(target: LOG_TARGET, "{}", e); - time::delay_for(self.config.base_node_monitor_refresh_interval).await; + warn!(target: LOG_TARGET, "{}", e); + time::sleep(self.config.base_node_monitor_refresh_interval).await; continue; }, } @@ -275,13 +273,15 @@ impl WalletConnectivityService { } async fn try_dial_peer(&mut self, peer: NodeId) -> Result<Option<PeerConnection>, WalletConnectivityError> { - let recv_fut = self.base_node_watch.recv(); - futures::pin_mut!(recv_fut); - let dial_fut = self.connectivity.dial_peer(peer); - futures::pin_mut!(dial_fut); - match future::select(recv_fut, dial_fut).await { - Either::Left(_) => Ok(None), - Either::Right((conn, _)) => Ok(Some(conn?)), + tokio::select! { + biased; + + _ = self.base_node_watch.changed() => { + Ok(None) + } + result = self.connectivity.dial_peer(peer) => { + Ok(Some(result?)) + } } } @@ -307,8 +307,8 @@ impl ReplyOneshot { pub fn is_canceled(&self) -> bool { use ReplyOneshot::*; match self { - WalletRpc(tx) => tx.is_canceled(), - SyncRpc(tx) => tx.is_canceled(), + WalletRpc(tx) => tx.is_closed(), + SyncRpc(tx) => tx.is_closed(), } } } diff --git a/base_layer/wallet/src/connectivity_service/test.rs b/base_layer/wallet/src/connectivity_service/test.rs index 7c24ef5b46..36495e2be9 100644 --- a/base_layer/wallet/src/connectivity_service/test.rs +++ b/base_layer/wallet/src/connectivity_service/test.rs @@ -70,7 +70,7 @@ async fn setup() -> ( (handle, mock_server, mock_state, shutdown) } -#[tokio_macros::test] +#[tokio::test] async fn it_dials_peer_when_base_node_is_set() { let (mut handle, mock_server, mock_state, _shutdown) = setup().await; let base_node_peer = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -92,7 +92,7 @@ async fn it_dials_peer_when_base_node_is_set() { assert!(rpc_client.is_connected()); } -#[tokio_macros::test] +#[tokio::test] async fn it_resolves_many_pending_rpc_session_requests() { let (mut handle, mock_server, mock_state, _shutdown) = setup().await; let base_node_peer = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -122,7 +122,7 @@ async fn it_resolves_many_pending_rpc_session_requests() { assert!(results.into_iter().map(Result::unwrap).all(convert::identity)); } -#[tokio_macros::test] +#[tokio::test] async fn it_changes_to_a_new_base_node() { let (mut handle, mock_server, mock_state, _shutdown) = setup().await; let base_node_peer1 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -155,7 +155,7 @@ async fn it_changes_to_a_new_base_node() { assert!(rpc_client.is_connected()); } -#[tokio_macros::test] +#[tokio::test] async fn it_gracefully_handles_connect_fail_reconnect() { let (mut handle, mock_server, mock_state, _shutdown) = setup().await; let base_node_peer = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -198,7 +198,7 @@ async fn it_gracefully_handles_connect_fail_reconnect() { pending_request.await.unwrap(); } -#[tokio_macros::test] +#[tokio::test] async fn it_gracefully_handles_multiple_connection_failures() { let (mut handle, mock_server, mock_state, _shutdown) = setup().await; let base_node_peer = build_node_identity(PeerFeatures::COMMUNICATION_NODE); diff --git a/base_layer/wallet/src/connectivity_service/watch.rs b/base_layer/wallet/src/connectivity_service/watch.rs index 1f1e868d47..4669b355f6 100644 --- a/base_layer/wallet/src/connectivity_service/watch.rs +++ b/base_layer/wallet/src/connectivity_service/watch.rs @@ -26,24 +26,19 @@ use tokio::sync::watch; #[derive(Clone)] pub struct Watch<T>(Arc<watch::Sender<T>>, watch::Receiver<T>); -impl<T: Clone> Watch<T> { +impl<T> Watch<T> { pub fn new(initial: T) -> Self { let (tx, rx) = watch::channel(initial); Self(Arc::new(tx), rx) } - #[allow(dead_code)] - pub async fn recv(&mut self) -> Option<T> { - self.receiver_mut().recv().await - } - pub fn borrow(&self) -> watch::Ref<'_, T> { self.receiver().borrow() } pub fn broadcast(&self, item: T) { - // SAFETY: broadcast becomes infallible because the receiver is owned in Watch and so has the same lifetime - if self.sender().broadcast(item).is_err() { + // PANIC: broadcast becomes infallible because the receiver is owned in Watch and so has the same lifetime + if self.sender().send(item).is_err() { // Result::expect requires E: fmt::Debug and `watch::SendError<T>` is not, this is equivalent panic!("watch internal receiver is dropped"); } @@ -53,10 +48,6 @@ impl<T: Clone> Watch<T> { &self.0 } - fn receiver_mut(&mut self) -> &mut watch::Receiver<T> { - &mut self.1 - } - pub fn receiver(&self) -> &watch::Receiver<T> { &self.1 } diff --git a/base_layer/wallet/src/contacts_service/service.rs b/base_layer/wallet/src/contacts_service/service.rs index f86f6b49cc..cfc8473202 100644 --- a/base_layer/wallet/src/contacts_service/service.rs +++ b/base_layer/wallet/src/contacts_service/service.rs @@ -76,8 +76,8 @@ where T: ContactsBackend + 'static info!(target: LOG_TARGET, "Contacts Service started"); loop { - futures::select! { - request_context = request_stream.select_next_some() => { + tokio::select! { + Some(request_context) = request_stream.next() => { let (request, reply_tx) = request_context.split(); let response = self.handle_request(request).await.map_err(|e| { error!(target: LOG_TARGET, "Error handling request: {:?}", e); @@ -88,14 +88,10 @@ where T: ContactsBackend + 'static e }); }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Contacts service shutting down because it received the shutdown signal"); break; } - complete => { - info!(target: LOG_TARGET, "Contacts service shutting down"); - break; - } } } info!(target: LOG_TARGET, "Contacts Service ended"); diff --git a/base_layer/wallet/src/lib.rs b/base_layer/wallet/src/lib.rs index bc0b4c1a04..22bce8bfdb 100644 --- a/base_layer/wallet/src/lib.rs +++ b/base_layer/wallet/src/lib.rs @@ -25,8 +25,6 @@ pub mod wallet; extern crate diesel; #[macro_use] extern crate diesel_migrations; -#[macro_use] -extern crate lazy_static; mod config; pub mod schema; diff --git a/base_layer/wallet/src/output_manager_service/handle.rs b/base_layer/wallet/src/output_manager_service/handle.rs index 659fab4a42..183f46ab4e 100644 --- a/base_layer/wallet/src/output_manager_service/handle.rs +++ b/base_layer/wallet/src/output_manager_service/handle.rs @@ -31,7 +31,6 @@ use crate::{ types::ValidationRetryStrategy, }; use aes_gcm::Aes256Gcm; -use futures::{stream::Fuse, StreamExt}; use std::{collections::HashMap, fmt, sync::Arc, time::Duration}; use tari_comms::types::CommsPublicKey; use tari_core::transactions::{ @@ -191,8 +190,8 @@ impl OutputManagerHandle { } } - pub fn get_event_stream_fused(&self) -> Fuse<OutputManagerEventReceiver> { - self.event_stream_sender.subscribe().fuse() + pub fn get_event_stream(&self) -> OutputManagerEventReceiver { + self.event_stream_sender.subscribe() } pub async fn add_output(&mut self, output: UnblindedOutput) -> Result<(), OutputManagerError> { diff --git a/base_layer/wallet/src/output_manager_service/service.rs b/base_layer/wallet/src/output_manager_service/service.rs index c8946b9ff9..77687c8845 100644 --- a/base_layer/wallet/src/output_manager_service/service.rs +++ b/base_layer/wallet/src/output_manager_service/service.rs @@ -166,9 +166,9 @@ where TBackend: OutputManagerBackend + 'static info!(target: LOG_TARGET, "Output Manager Service started"); loop { - futures::select! { - request_context = request_stream.select_next_some() => { - trace!(target: LOG_TARGET, "Handling Service API Request"); + tokio::select! { + Some(request_context) = request_stream.next() => { + trace!(target: LOG_TARGET, "Handling Service API Request"); let (request, reply_tx) = request_context.split(); let response = self.handle_request(request).await.map_err(|e| { warn!(target: LOG_TARGET, "Error handling request: {:?}", e); @@ -179,14 +179,10 @@ where TBackend: OutputManagerBackend + 'static e }); }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Output manager service shutting down because it received the shutdown signal"); break; } - complete => { - info!(target: LOG_TARGET, "Output manager service shutting down"); - break; - } } } info!(target: LOG_TARGET, "Output Manager Service ended"); diff --git a/base_layer/wallet/src/output_manager_service/tasks/txo_validation_task.rs b/base_layer/wallet/src/output_manager_service/tasks/txo_validation_task.rs index e3e022e4fb..de0e0e3c17 100644 --- a/base_layer/wallet/src/output_manager_service/tasks/txo_validation_task.rs +++ b/base_layer/wallet/src/output_manager_service/tasks/txo_validation_task.rs @@ -30,7 +30,7 @@ use crate::{ transaction_service::storage::models::TransactionStatus, types::ValidationRetryStrategy, }; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt; use log::*; use std::{cmp, collections::HashMap, convert::TryFrom, fmt, sync::Arc, time::Duration}; use tari_comms::{peer_manager::NodeId, types::CommsPublicKey, PeerConnection}; @@ -40,7 +40,7 @@ use tari_core::{ transactions::{transaction::TransactionOutput, types::Signature}, }; use tari_crypto::tari_utilities::{hash::Hashable, hex::Hex}; -use tokio::{sync::broadcast, time::delay_for}; +use tokio::{sync::broadcast, time::sleep}; const LOG_TARGET: &str = "wallet::output_manager_service::utxo_validation_task"; @@ -87,21 +87,17 @@ where TBackend: OutputManagerBackend + 'static /// The task that defines the execution of the protocol. pub async fn execute(mut self) -> Result<u64, OutputManagerProtocolError> { - let mut base_node_update_receiver = self - .base_node_update_receiver - .take() - .ok_or_else(|| { - OutputManagerProtocolError::new( - self.id, - OutputManagerError::ServiceError("A Base Node Update receiver was not provided".to_string()), - ) - })? - .fuse(); + let mut base_node_update_receiver = self.base_node_update_receiver.take().ok_or_else(|| { + OutputManagerProtocolError::new( + self.id, + OutputManagerError::ServiceError("A Base Node Update receiver was not provided".to_string()), + ) + })?; let mut shutdown = self.resources.shutdown_signal.clone(); let total_retries_str = match self.retry_strategy { - ValidationRetryStrategy::Limited(n) => format!("{}", n), + ValidationRetryStrategy::Limited(n) => n.to_string(), ValidationRetryStrategy::UntilSuccess => "∞".to_string(), }; @@ -180,14 +176,14 @@ where TBackend: OutputManagerBackend + 'static let base_node_node_id = NodeId::from_key(&self.base_node_public_key.clone()); let mut connection: Option<PeerConnection> = None; - let delay = delay_for(self.resources.config.peer_dial_retry_timeout); + let delay = sleep(self.resources.config.peer_dial_retry_timeout); debug!( target: LOG_TARGET, "Connecting to Base Node (Public Key: {})", self.base_node_public_key, ); - futures::select! { - dial_result = self.resources.connectivity_manager.dial_peer(base_node_node_id.clone()).fuse() => { + tokio::select! { + dial_result = self.resources.connectivity_manager.dial_peer(base_node_node_id.clone()) => { match dial_result { Ok(base_node_connection) => { connection = Some(base_node_connection); @@ -197,7 +193,7 @@ where TBackend: OutputManagerBackend + 'static }, } }, - new_base_node = base_node_update_receiver.select_next_some() => { + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(_) => { info!( @@ -228,7 +224,7 @@ where TBackend: OutputManagerBackend + 'static } } } - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "TXO Validation Protocol (Id: {}) shutting down because it received the shutdown signal", self.id); return Err(OutputManagerProtocolError::new(self.id, OutputManagerError::Shutdown)); }, @@ -236,7 +232,7 @@ where TBackend: OutputManagerBackend + 'static let mut base_node_connection = match connection { None => { - futures::select! { + tokio::select! { _ = delay.fuse() => { let _ = self .resources @@ -253,7 +249,7 @@ where TBackend: OutputManagerBackend + 'static retries += 1; continue; }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "TXO Validation Protocol (Id: {}) shutting down because it received the shutdown signal", self.id); return Err(OutputManagerProtocolError::new(self.id, OutputManagerError::Shutdown)); }, @@ -294,9 +290,9 @@ where TBackend: OutputManagerBackend + 'static batch_num, batch_total ); - let delay = delay_for(self.retry_delay); - futures::select! { - new_base_node = base_node_update_receiver.select_next_some() => { + let delay = sleep(self.retry_delay); + tokio::select! { + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(_bn) => { info!(target: LOG_TARGET, "TXO Validation protocol aborted due to Base Node Public key change" ); @@ -323,7 +319,7 @@ where TBackend: OutputManagerBackend + 'static } } }, - result = self.send_query_batch(batch.clone(), &mut client).fuse() => { + result = self.send_query_batch(batch.clone(), &mut client) => { match result { Ok(synced) => { self.base_node_synced = synced; @@ -374,7 +370,7 @@ where TBackend: OutputManagerBackend + 'static }, } }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "TXO Validation Protocol (Id: {}) shutting down because it received the shutdown signal", self.id); return Err(OutputManagerProtocolError::new(self.id, OutputManagerError::Shutdown)); }, diff --git a/base_layer/wallet/src/transaction_service/error.rs b/base_layer/wallet/src/transaction_service/error.rs index c197dd2024..05e9e4af2b 100644 --- a/base_layer/wallet/src/transaction_service/error.rs +++ b/base_layer/wallet/src/transaction_service/error.rs @@ -34,7 +34,7 @@ use tari_p2p::services::liveness::error::LivenessError; use tari_service_framework::reply_channel::TransportChannelError; use thiserror::Error; use time::OutOfRangeError; -use tokio::sync::broadcast::RecvError; +use tokio::sync::broadcast::error::RecvError; #[derive(Debug, Error)] pub enum TransactionServiceError { diff --git a/base_layer/wallet/src/transaction_service/handle.rs b/base_layer/wallet/src/transaction_service/handle.rs index f34a5f667f..7a6ab48649 100644 --- a/base_layer/wallet/src/transaction_service/handle.rs +++ b/base_layer/wallet/src/transaction_service/handle.rs @@ -28,7 +28,6 @@ use crate::{ }, }; use aes_gcm::Aes256Gcm; -use futures::{stream::Fuse, StreamExt}; use std::{collections::HashMap, fmt, sync::Arc}; use tari_comms::types::CommsPublicKey; use tari_core::transactions::{tari_amount::MicroTari, transaction::Transaction}; @@ -187,8 +186,8 @@ impl TransactionServiceHandle { } } - pub fn get_event_stream_fused(&self) -> Fuse<TransactionEventReceiver> { - self.event_stream_sender.subscribe().fuse() + pub fn get_event_stream(&self) -> TransactionEventReceiver { + self.event_stream_sender.subscribe() } pub async fn send_transaction( diff --git a/base_layer/wallet/src/transaction_service/protocols/transaction_broadcast_protocol.rs b/base_layer/wallet/src/transaction_service/protocols/transaction_broadcast_protocol.rs index 05191c8f8d..e2b0d7b871 100644 --- a/base_layer/wallet/src/transaction_service/protocols/transaction_broadcast_protocol.rs +++ b/base_layer/wallet/src/transaction_service/protocols/transaction_broadcast_protocol.rs @@ -32,7 +32,7 @@ use crate::{ }, }, }; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt; use log::*; use std::{convert::TryFrom, sync::Arc, time::Duration}; use tari_comms::{peer_manager::NodeId, types::CommsPublicKey, PeerConnection}; @@ -44,7 +44,7 @@ use tari_core::{ transactions::{transaction::Transaction, types::Signature}, }; use tari_crypto::tari_utilities::hex::Hex; -use tokio::{sync::broadcast, time::delay_for}; +use tokio::{sync::broadcast, time::sleep}; const LOG_TARGET: &str = "wallet::transaction_service::protocols::broadcast_protocol"; @@ -86,21 +86,13 @@ where TBackend: TransactionBackend + 'static /// The task that defines the execution of the protocol. pub async fn execute(mut self) -> Result<u64, TransactionServiceProtocolError> { - let mut timeout_update_receiver = self - .timeout_update_receiver - .take() - .ok_or_else(|| { - TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) - })? - .fuse(); + let mut timeout_update_receiver = self.timeout_update_receiver.take().ok_or_else(|| { + TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) + })?; - let mut base_node_update_receiver = self - .base_node_update_receiver - .take() - .ok_or_else(|| { - TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) - })? - .fuse(); + let mut base_node_update_receiver = self.base_node_update_receiver.take().ok_or_else(|| { + TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) + })?; let mut shutdown = self.resources.shutdown_signal.clone(); // Main protocol loop @@ -108,14 +100,14 @@ where TBackend: TransactionBackend + 'static let base_node_node_id = NodeId::from_key(&self.base_node_public_key); let mut connection: Option<PeerConnection> = None; - let delay = delay_for(self.timeout); + let delay = sleep(self.timeout); debug!( target: LOG_TARGET, "Connecting to Base Node (Public Key: {})", self.base_node_public_key, ); - futures::select! { - dial_result = self.resources.connectivity_manager.dial_peer(base_node_node_id.clone()).fuse() => { + tokio::select! { + dial_result = self.resources.connectivity_manager.dial_peer(base_node_node_id.clone()) => { match dial_result { Ok(base_node_connection) => { connection = Some(base_node_connection); @@ -139,7 +131,7 @@ where TBackend: TransactionBackend + 'static }, } }, - updated_timeout = timeout_update_receiver.select_next_some() => { + updated_timeout = timeout_update_receiver.recv() => { match updated_timeout { Ok(to) => { self.timeout = to; @@ -158,7 +150,7 @@ where TBackend: TransactionBackend + 'static } } }, - new_base_node = base_node_update_receiver.select_next_some() => { + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(bn) => { self.base_node_public_key = bn; @@ -179,7 +171,7 @@ where TBackend: TransactionBackend + 'static } } } - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Broadcast Protocol (TxId: {}) shutting down because it received the shutdown signal", self.tx_id); return Err(TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::Shutdown)) }, @@ -187,11 +179,11 @@ where TBackend: TransactionBackend + 'static let mut base_node_connection = match connection { None => { - futures::select! { + tokio::select! { _ = delay.fuse() => { continue; }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Broadcast Protocol (TxId: {}) shutting down because it received the shutdown signal", self.tx_id); return Err(TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::Shutdown)) }, @@ -243,10 +235,10 @@ where TBackend: TransactionBackend + 'static }, }; - let delay = delay_for(self.timeout); + let delay = sleep(self.timeout); loop { - futures::select! { - new_base_node = base_node_update_receiver.select_next_some() => { + tokio::select! { + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(bn) => { self.base_node_public_key = bn; @@ -315,7 +307,7 @@ where TBackend: TransactionBackend + 'static delay.await; break; }, - updated_timeout = timeout_update_receiver.select_next_some() => { + updated_timeout = timeout_update_receiver.recv() => { if let Ok(to) = updated_timeout { self.timeout = to; info!( @@ -332,7 +324,7 @@ where TBackend: TransactionBackend + 'static ); } }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Broadcast Protocol (TxId: {}) shutting down because it received the shutdown signal", self.tx_id); return Err(TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::Shutdown)) }, diff --git a/base_layer/wallet/src/transaction_service/protocols/transaction_coinbase_monitoring_protocol.rs b/base_layer/wallet/src/transaction_service/protocols/transaction_coinbase_monitoring_protocol.rs index 65fb58f601..d072eb77a8 100644 --- a/base_layer/wallet/src/transaction_service/protocols/transaction_coinbase_monitoring_protocol.rs +++ b/base_layer/wallet/src/transaction_service/protocols/transaction_coinbase_monitoring_protocol.rs @@ -29,7 +29,7 @@ use crate::{ storage::{database::TransactionBackend, models::CompletedTransaction}, }, }; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt; use log::*; use std::{convert::TryFrom, sync::Arc, time::Duration}; use tari_comms::{peer_manager::NodeId, types::CommsPublicKey, PeerConnection}; @@ -41,7 +41,7 @@ use tari_core::{ transactions::types::Signature, }; use tari_crypto::tari_utilities::{hex::Hex, Hashable}; -use tokio::{sync::broadcast, time::delay_for}; +use tokio::{sync::broadcast, time::sleep}; const LOG_TARGET: &str = "wallet::transaction_service::protocols::coinbase_monitoring"; @@ -86,21 +86,13 @@ where TBackend: TransactionBackend + 'static /// The task that defines the execution of the protocol. pub async fn execute(mut self) -> Result<u64, TransactionServiceProtocolError> { - let mut base_node_update_receiver = self - .base_node_update_receiver - .take() - .ok_or_else(|| { - TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) - })? - .fuse(); + let mut base_node_update_receiver = self.base_node_update_receiver.take().ok_or_else(|| { + TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) + })?; - let mut timeout_update_receiver = self - .timeout_update_receiver - .take() - .ok_or_else(|| { - TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) - })? - .fuse(); + let mut timeout_update_receiver = self.timeout_update_receiver.take().ok_or_else(|| { + TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) + })?; trace!( target: LOG_TARGET, @@ -173,7 +165,7 @@ where TBackend: TransactionBackend + 'static self.base_node_public_key, self.tx_id, ); - futures::select! { + tokio::select! { dial_result = self.resources.connectivity_manager.dial_peer(base_node_node_id.clone()).fuse() => { match dial_result { Ok(base_node_connection) => { @@ -203,7 +195,7 @@ where TBackend: TransactionBackend + 'static }, } }, - updated_timeout = timeout_update_receiver.select_next_some() => { + updated_timeout = timeout_update_receiver.recv() => { match updated_timeout { Ok(to) => { self.timeout = to; @@ -225,7 +217,7 @@ where TBackend: TransactionBackend + 'static } } }, - new_base_node = base_node_update_receiver.select_next_some() => { + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(bn) => { self.base_node_public_key = bn; @@ -248,7 +240,7 @@ where TBackend: TransactionBackend + 'static } } } - _ = shutdown => { + _ = shutdown.wait() => { info!( target: LOG_TARGET, "Coinbase Monitoring protocol (TxId: {}) shutting down because it received the shutdown \ @@ -259,14 +251,14 @@ where TBackend: TransactionBackend + 'static }, } - let delay = delay_for(self.timeout); + let delay = sleep(self.timeout); let mut base_node_connection = match connection { None => { - futures::select! { + tokio::select! { _ = delay.fuse() => { continue; }, - _ = shutdown => { + _ = shutdown.wait() => { info!( target: LOG_TARGET, "Coinbase Monitoring Protocol (TxId: {}) shutting down because it received the \ @@ -314,10 +306,10 @@ where TBackend: TransactionBackend + 'static TransactionServiceError::InvalidCompletedTransaction, )); } - let delay = delay_for(self.timeout).fuse(); + let delay = sleep(self.timeout).fuse(); loop { - futures::select! { - new_base_node = base_node_update_receiver.select_next_some() => { + tokio::select! { + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(bn) => { self.base_node_public_key = bn; @@ -392,7 +384,7 @@ where TBackend: TransactionBackend + 'static delay.await; break; }, - updated_timeout = timeout_update_receiver.select_next_some() => { + updated_timeout = timeout_update_receiver.recv() => { if let Ok(to) = updated_timeout { self.timeout = to; info!( @@ -411,7 +403,7 @@ where TBackend: TransactionBackend + 'static ); } }, - _ = shutdown => { + _ = shutdown.wait() => { info!( target: LOG_TARGET, "Coinbase Monitoring Protocol (TxId: {}) shutting down because it received the shutdown \ diff --git a/base_layer/wallet/src/transaction_service/protocols/transaction_receive_protocol.rs b/base_layer/wallet/src/transaction_service/protocols/transaction_receive_protocol.rs index 744bb6f1fc..058d2c4f74 100644 --- a/base_layer/wallet/src/transaction_service/protocols/transaction_receive_protocol.rs +++ b/base_layer/wallet/src/transaction_service/protocols/transaction_receive_protocol.rs @@ -34,21 +34,18 @@ use crate::{ }, }; use chrono::Utc; -use futures::{ - channel::{mpsc, oneshot}, - future::FutureExt, - StreamExt, -}; +use futures::future::FutureExt; use log::*; use std::sync::Arc; use tari_comms::types::CommsPublicKey; +use tokio::sync::{mpsc, oneshot}; use tari_core::transactions::{ transaction::Transaction, transaction_protocol::{recipient::RecipientState, sender::TransactionSenderMessage}, }; use tari_crypto::tari_utilities::Hashable; -use tokio::time::delay_for; +use tokio::time::sleep; const LOG_TARGET: &str = "wallet::transaction_service::protocols::receive_protocol"; const LOG_TARGET_STRESS: &str = "stress_test::receive_protocol"; @@ -263,7 +260,8 @@ where TBackend: TransactionBackend + 'static }, Some(t) => t, }; - let mut timeout_delay = delay_for(timeout_duration).fuse(); + let timeout_delay = sleep(timeout_duration).fuse(); + tokio::pin!(timeout_delay); // check to see if a resend is due let resend = match inbound_tx.last_send_timestamp { @@ -310,9 +308,9 @@ where TBackend: TransactionBackend + 'static let mut incoming_finalized_transaction = None; loop { loop { - let mut resend_timeout = delay_for(self.resources.config.transaction_resend_period).fuse(); - futures::select! { - (spk, tx_id, tx) = receiver.select_next_some() => { + let resend_timeout = sleep(self.resources.config.transaction_resend_period).fuse(); + tokio::select! { + Some((spk, tx_id, tx)) = receiver.recv() => { incoming_finalized_transaction = Some(tx); if inbound_tx.source_public_key != spk { warn!( @@ -325,16 +323,14 @@ where TBackend: TransactionBackend + 'static break; } }, - result = cancellation_receiver => { - if result.is_ok() { - info!(target: LOG_TARGET, "Cancelling Transaction Receive Protocol for TxId: {}", self.id); - return Err(TransactionServiceProtocolError::new( - self.id, - TransactionServiceError::TransactionCancelled, - )); - } + Ok(_) = &mut cancellation_receiver => { + info!(target: LOG_TARGET, "Cancelling Transaction Receive Protocol for TxId: {}", self.id); + return Err(TransactionServiceProtocolError::new( + self.id, + TransactionServiceError::TransactionCancelled, + )); }, - () = resend_timeout => { + _ = resend_timeout => { match send_transaction_reply( inbound_tx.clone(), self.resources.outbound_message_service.clone(), @@ -353,10 +349,10 @@ where TBackend: TransactionBackend + 'static ), } }, - () = timeout_delay => { + _ = &mut timeout_delay => { return self.timeout_transaction().await; } - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Receive Protocol (id: {}) shutting down because it received the shutdown signal", self.id); return Err(TransactionServiceProtocolError::new(self.id, TransactionServiceError::Shutdown)) } diff --git a/base_layer/wallet/src/transaction_service/protocols/transaction_send_protocol.rs b/base_layer/wallet/src/transaction_service/protocols/transaction_send_protocol.rs index cf30a7f928..e88c9013e6 100644 --- a/base_layer/wallet/src/transaction_service/protocols/transaction_send_protocol.rs +++ b/base_layer/wallet/src/transaction_service/protocols/transaction_send_protocol.rs @@ -20,12 +20,6 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use std::sync::Arc; - -use chrono::Utc; -use futures::{channel::mpsc::Receiver, FutureExt, StreamExt}; -use log::*; - use crate::transaction_service::{ config::TransactionRoutingMechanism, error::{TransactionServiceError, TransactionServiceProtocolError}, @@ -41,7 +35,10 @@ use crate::transaction_service::{ wait_on_dial::wait_on_dial, }, }; -use futures::channel::oneshot; +use chrono::Utc; +use futures::FutureExt; +use log::*; +use std::sync::Arc; use tari_comms::{peer_manager::NodeId, types::CommsPublicKey}; use tari_comms_dht::{ domain_message::OutboundDomainMessage, @@ -55,7 +52,10 @@ use tari_core::transactions::{ }; use tari_crypto::script; use tari_p2p::tari_message::TariMessageType; -use tokio::time::delay_for; +use tokio::{ + sync::{mpsc::Receiver, oneshot}, + time::sleep, +}; const LOG_TARGET: &str = "wallet::transaction_service::protocols::send_protocol"; const LOG_TARGET_STRESS: &str = "stress_test::send_protocol"; @@ -344,7 +344,8 @@ where TBackend: TransactionBackend + 'static }, Some(t) => t, }; - let mut timeout_delay = delay_for(timeout_duration).fuse(); + let timeout_delay = sleep(timeout_duration).fuse(); + tokio::pin!(timeout_delay); // check to see if a resend is due let resend = match outbound_tx.last_send_timestamp { @@ -390,9 +391,9 @@ where TBackend: TransactionBackend + 'static #[allow(unused_assignments)] let mut reply = None; loop { - let mut resend_timeout = delay_for(self.resources.config.transaction_resend_period).fuse(); - futures::select! { - (spk, rr) = receiver.select_next_some() => { + let resend_timeout = sleep(self.resources.config.transaction_resend_period).fuse(); + tokio::select! { + Some((spk, rr)) = receiver.recv() => { let rr_tx_id = rr.tx_id; reply = Some(rr); @@ -407,7 +408,7 @@ where TBackend: TransactionBackend + 'static break; } }, - result = cancellation_receiver => { + result = &mut cancellation_receiver => { if result.is_ok() { info!(target: LOG_TARGET, "Cancelling Transaction Send Protocol (TxId: {})", self.id); let _ = send_transaction_cancelled_message(self.id,self.dest_pubkey.clone(), self.resources.outbound_message_service.clone(), ).await.map_err(|e| { @@ -441,10 +442,10 @@ where TBackend: TransactionBackend + 'static .map_err(|e| TransactionServiceProtocolError::new(self.id, TransactionServiceError::from(e)))?; } }, - () = timeout_delay => { + () = &mut timeout_delay => { return self.timeout_transaction().await; } - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Send Protocol (id: {}) shutting down because it received the shutdown signal", self.id); return Err(TransactionServiceProtocolError::new(self.id, TransactionServiceError::Shutdown)) } diff --git a/base_layer/wallet/src/transaction_service/protocols/transaction_validation_protocol.rs b/base_layer/wallet/src/transaction_service/protocols/transaction_validation_protocol.rs index d0b2f7f6ac..dcf072c272 100644 --- a/base_layer/wallet/src/transaction_service/protocols/transaction_validation_protocol.rs +++ b/base_layer/wallet/src/transaction_service/protocols/transaction_validation_protocol.rs @@ -32,7 +32,7 @@ use crate::{ }, types::ValidationRetryStrategy, }; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt; use log::*; use std::{cmp, convert::TryFrom, sync::Arc, time::Duration}; use tari_comms::{peer_manager::NodeId, types::CommsPublicKey, PeerConnection}; @@ -43,7 +43,7 @@ use tari_core::{ }, proto::{base_node::Signatures as SignaturesProto, types::Signature as SignatureProto}, }; -use tokio::{sync::broadcast, time::delay_for}; +use tokio::{sync::broadcast, time::sleep}; const LOG_TARGET: &str = "wallet::transaction_service::protocols::validation_protocol"; @@ -94,14 +94,12 @@ where TBackend: TransactionBackend + 'static let mut timeout_update_receiver = self .timeout_update_receiver .take() - .ok_or_else(|| TransactionServiceProtocolError::new(self.id, TransactionServiceError::InvalidStateError))? - .fuse(); + .ok_or_else(|| TransactionServiceProtocolError::new(self.id, TransactionServiceError::InvalidStateError))?; let mut base_node_update_receiver = self .base_node_update_receiver .take() - .ok_or_else(|| TransactionServiceProtocolError::new(self.id, TransactionServiceError::InvalidStateError))? - .fuse(); + .ok_or_else(|| TransactionServiceProtocolError::new(self.id, TransactionServiceError::InvalidStateError))?; let mut shutdown = self.resources.shutdown_signal.clone(); @@ -158,13 +156,13 @@ where TBackend: TransactionBackend + 'static let base_node_node_id = NodeId::from_key(&self.base_node_public_key); let mut connection: Option<PeerConnection> = None; - let delay = delay_for(self.timeout); + let delay = sleep(self.timeout); debug!( target: LOG_TARGET, "Connecting to Base Node (Public Key: {})", self.base_node_public_key, ); - futures::select! { + tokio::select! { dial_result = self.resources.connectivity_manager.dial_peer(base_node_node_id.clone()).fuse() => { match dial_result { Ok(base_node_connection) => { @@ -175,7 +173,7 @@ where TBackend: TransactionBackend + 'static }, } }, - new_base_node = base_node_update_receiver.select_next_some() => { + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(_) => { @@ -204,7 +202,7 @@ where TBackend: TransactionBackend + 'static } } } - updated_timeout = timeout_update_receiver.select_next_some() => { + updated_timeout = timeout_update_receiver.recv() => { match updated_timeout { Ok(to) => { self.timeout = to; @@ -223,7 +221,7 @@ where TBackend: TransactionBackend + 'static } } }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Validation Protocol shutting down because it received the shutdown signal"); return Err(TransactionServiceProtocolError::new(self.id, TransactionServiceError::Shutdown)) }, @@ -231,7 +229,7 @@ where TBackend: TransactionBackend + 'static let mut base_node_connection = match connection { None => { - futures::select! { + tokio::select! { _ = delay.fuse() => { let _ = self .resources @@ -248,7 +246,7 @@ where TBackend: TransactionBackend + 'static retries += 1; continue; }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Validation Protocol shutting down because it received the shutdown signal"); return Err(TransactionServiceProtocolError::new(self.id, TransactionServiceError::Shutdown)) }, @@ -282,9 +280,9 @@ where TBackend: TransactionBackend + 'static } else { break 'main; }; - let delay = delay_for(self.timeout); - futures::select! { - new_base_node = base_node_update_receiver.select_next_some() => { + let delay = sleep(self.timeout); + tokio::select! { + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(_) => { info!(target: LOG_TARGET, "Aborting Transaction Validation Protocol as new Base node is set"); @@ -372,7 +370,7 @@ where TBackend: TransactionBackend + 'static }, } }, - updated_timeout = timeout_update_receiver.select_next_some() => { + updated_timeout = timeout_update_receiver.recv() => { match updated_timeout { Ok(to) => { self.timeout = to; @@ -391,7 +389,7 @@ where TBackend: TransactionBackend + 'static } } }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Validation Protocol shutting down because it received the shutdown signal"); return Err(TransactionServiceProtocolError::new(self.id, TransactionServiceError::Shutdown)) }, diff --git a/base_layer/wallet/src/transaction_service/service.rs b/base_layer/wallet/src/transaction_service/service.rs index c30fb7f412..c768a085ed 100644 --- a/base_layer/wallet/src/transaction_service/service.rs +++ b/base_layer/wallet/src/transaction_service/service.rs @@ -47,14 +47,7 @@ use crate::{ }; use chrono::{NaiveDateTime, Utc}; use digest::Digest; -use futures::{ - channel::{mpsc, mpsc::Sender, oneshot}, - pin_mut, - stream::FuturesUnordered, - SinkExt, - Stream, - StreamExt, -}; +use futures::{pin_mut, stream::FuturesUnordered, Stream, StreamExt}; use log::*; use rand::{rngs::OsRng, RngCore}; use std::{ @@ -85,7 +78,10 @@ use tari_crypto::{keys::DiffieHellmanSharedSecret, script, tari_utilities::ByteA use tari_p2p::domain_message::DomainMessage; use tari_service_framework::{reply_channel, reply_channel::Receiver}; use tari_shutdown::ShutdownSignal; -use tokio::{sync::broadcast, task::JoinHandle}; +use tokio::{ + sync::{broadcast, mpsc, mpsc::Sender, oneshot}, + task::JoinHandle, +}; const LOG_TARGET: &str = "wallet::transaction_service::service"; @@ -276,9 +272,9 @@ where info!(target: LOG_TARGET, "Transaction Service started"); loop { - futures::select! { + tokio::select! { //Incoming request - request_context = request_stream.select_next_some() => { + Some(request_context) = request_stream.next() => { // TODO: Remove time measurements; this is to aid in system testing only let start = Instant::now(); let (request, reply_tx) = request_context.split(); @@ -303,7 +299,7 @@ where ); }, // Incoming Transaction messages from the Comms layer - msg = transaction_stream.select_next_some() => { + Some(msg) = transaction_stream.next() => { // TODO: Remove time measurements; this is to aid in system testing only let start = Instant::now(); let (origin_public_key, inner_msg) = msg.clone().into_origin_and_inner(); @@ -333,7 +329,7 @@ where ); }, // Incoming Transaction Reply messages from the Comms layer - msg = transaction_reply_stream.select_next_some() => { + Some(msg) = transaction_reply_stream.next() => { // TODO: Remove time measurements; this is to aid in system testing only let start = Instant::now(); let (origin_public_key, inner_msg) = msg.clone().into_origin_and_inner(); @@ -364,7 +360,7 @@ where ); }, // Incoming Finalized Transaction messages from the Comms layer - msg = transaction_finalized_stream.select_next_some() => { + Some(msg) = transaction_finalized_stream.next() => { // TODO: Remove time measurements; this is to aid in system testing only let start = Instant::now(); let (origin_public_key, inner_msg) = msg.clone().into_origin_and_inner(); @@ -402,7 +398,7 @@ where ); }, // Incoming messages from the Comms layer - msg = base_node_response_stream.select_next_some() => { + Some(msg) = base_node_response_stream.next() => { // TODO: Remove time measurements; this is to aid in system testing only let start = Instant::now(); let (origin_public_key, inner_msg) = msg.clone().into_origin_and_inner(); @@ -421,7 +417,7 @@ where ); } // Incoming messages from the Comms layer - msg = transaction_cancelled_stream.select_next_some() => { + Some(msg) = transaction_cancelled_stream.next() => { // TODO: Remove time measurements; this is to aid in system testing only let start = Instant::now(); let (origin_public_key, inner_msg) = msg.clone().into_origin_and_inner(); @@ -436,7 +432,7 @@ where finish.duration_since(start).as_millis(), ); } - join_result = send_transaction_protocol_handles.select_next_some() => { + Some(join_result) = send_transaction_protocol_handles.next() => { trace!(target: LOG_TARGET, "Send Protocol for Transaction has ended with result {:?}", join_result); match join_result { Ok(join_result_inner) => self.complete_send_transaction_protocol( @@ -446,7 +442,7 @@ where Err(e) => error!(target: LOG_TARGET, "Error resolving Send Transaction Protocol: {:?}", e), }; } - join_result = receive_transaction_protocol_handles.select_next_some() => { + Some(join_result) = receive_transaction_protocol_handles.next() => { trace!(target: LOG_TARGET, "Receive Transaction Protocol has ended with result {:?}", join_result); match join_result { Ok(join_result_inner) => self.complete_receive_transaction_protocol( @@ -456,14 +452,14 @@ where Err(e) => error!(target: LOG_TARGET, "Error resolving Send Transaction Protocol: {:?}", e), }; } - join_result = transaction_broadcast_protocol_handles.select_next_some() => { + Some(join_result) = transaction_broadcast_protocol_handles.next() => { trace!(target: LOG_TARGET, "Transaction Broadcast protocol has ended with result {:?}", join_result); match join_result { Ok(join_result_inner) => self.complete_transaction_broadcast_protocol(join_result_inner).await, Err(e) => error!(target: LOG_TARGET, "Error resolving Broadcast Protocol: {:?}", e), }; } - join_result = coinbase_transaction_monitoring_protocol_handles.select_next_some() => { + Some(join_result) = coinbase_transaction_monitoring_protocol_handles.next() => { trace!(target: LOG_TARGET, "Coinbase transaction monitoring protocol has ended with result {:?}", join_result); match join_result { @@ -471,21 +467,17 @@ where Err(e) => error!(target: LOG_TARGET, "Error resolving Coinbase Monitoring protocol: {:?}", e), }; } - join_result = transaction_validation_protocol_handles.select_next_some() => { + Some(join_result) = transaction_validation_protocol_handles.next() => { trace!(target: LOG_TARGET, "Transaction Validation protocol has ended with result {:?}", join_result); match join_result { Ok(join_result_inner) => self.complete_transaction_validation_protocol(join_result_inner).await, Err(e) => error!(target: LOG_TARGET, "Error resolving Transaction Validation protocol: {:?}", e), }; } - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction service shutting down because it received the shutdown signal"); break; } - complete => { - info!(target: LOG_TARGET, "Transaction service shutting down"); - break; - } } } info!(target: LOG_TARGET, "Transaction service shut down"); diff --git a/base_layer/wallet/src/transaction_service/tasks/start_transaction_validation_and_broadcast_protocols.rs b/base_layer/wallet/src/transaction_service/tasks/start_transaction_validation_and_broadcast_protocols.rs index 61b3ee7d75..522bacdbb9 100644 --- a/base_layer/wallet/src/transaction_service/tasks/start_transaction_validation_and_broadcast_protocols.rs +++ b/base_layer/wallet/src/transaction_service/tasks/start_transaction_validation_and_broadcast_protocols.rs @@ -27,8 +27,8 @@ use crate::{ }, types::ValidationRetryStrategy, }; -use futures::StreamExt; use log::*; +use tokio::sync::broadcast; const LOG_TARGET: &str = "wallet::transaction_service::tasks::start_tx_validation_and_broadcast"; @@ -36,16 +36,16 @@ pub async fn start_transaction_validation_and_broadcast_protocols( mut handle: TransactionServiceHandle, retry_strategy: ValidationRetryStrategy, ) -> Result<(), TransactionServiceError> { - let mut event_stream = handle.get_event_stream_fused(); + let mut event_stream = handle.get_event_stream(); let our_id = handle.validate_transactions(retry_strategy).await?; // Now that its started we will spawn an task to monitor the event bus and when its successful we will start the // Broadcast protocols tokio::spawn(async move { - while let Some(event_item) = event_stream.next().await { - if let Ok(event) = event_item { - match (*event).clone() { + loop { + match event_stream.recv().await { + Ok(event) => match &*event { TransactionEvent::TransactionValidationSuccess(_id) => { info!( target: LOG_TARGET, @@ -59,19 +59,28 @@ pub async fn start_transaction_validation_and_broadcast_protocols( } }, TransactionEvent::TransactionValidationFailure(id) => { - if our_id == id { + if our_id == *id { error!(target: LOG_TARGET, "Transaction Validation failed!"); break; } }, _ => (), - } - } else { - warn!( - target: LOG_TARGET, - "Error reading from Transaction Service Event Stream" - ); - break; + }, + Err(e @ broadcast::error::RecvError::Lagged(_)) => { + warn!( + target: LOG_TARGET, + "start_transaction_validation_and_broadcast_protocols: {}", e + ); + continue; + }, + Err(broadcast::error::RecvError::Closed) => { + debug!( + target: LOG_TARGET, + "start_transaction_validation_and_broadcast_protocols is exiting because the event stream \ + closed", + ); + break; + }, } } }); diff --git a/base_layer/wallet/src/util/mod.rs b/base_layer/wallet/src/util/mod.rs index 9664a0e376..7217ac5056 100644 --- a/base_layer/wallet/src/util/mod.rs +++ b/base_layer/wallet/src/util/mod.rs @@ -20,6 +20,4 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -pub mod emoji; pub mod encryption; -pub mod luhn; diff --git a/base_layer/wallet/src/utxo_scanner_service/utxo_scanning.rs b/base_layer/wallet/src/utxo_scanner_service/utxo_scanning.rs index 6aaa38363f..341d853ad5 100644 --- a/base_layer/wallet/src/utxo_scanner_service/utxo_scanning.rs +++ b/base_layer/wallet/src/utxo_scanner_service/utxo_scanning.rs @@ -69,7 +69,7 @@ use tari_core::{ }; use tari_service_framework::{reply_channel, reply_channel::SenderService}; use tari_shutdown::ShutdownSignal; -use tokio::{sync::broadcast, task, time}; +use tokio::{sync::broadcast, task, time, time::MissedTickBehavior}; pub const LOG_TARGET: &str = "wallet::utxo_scanning"; @@ -715,35 +715,23 @@ where TBackend: WalletBackend + 'static let mut shutdown = self.shutdown_signal.clone(); let start_at = Instant::now() + Duration::from_secs(1); - let mut work_interval = time::interval_at(start_at.into(), self.scan_for_utxo_interval).fuse(); - let mut previous = Instant::now(); + let mut work_interval = time::interval_at(start_at.into(), self.scan_for_utxo_interval); + work_interval.set_missed_tick_behavior(MissedTickBehavior::Delay); loop { - futures::select! { - _ = work_interval.select_next_some() => { - // This bit of code prevents bottled up tokio interval events to be fired successively for the edge - // case where a computer wakes up from sleep. - if start_at.elapsed() > self.scan_for_utxo_interval && - previous.elapsed() < self.scan_for_utxo_interval.mul_f32(0.9) - { - debug!( - target: LOG_TARGET, - "UTXO scanning work interval event fired too quickly, not running the task" - ); - } else { - let running_flag = self.is_running.clone(); - if !running_flag.load(Ordering::SeqCst) { - let task = self.create_task(); - debug!(target: LOG_TARGET, "UTXO scanning service starting scan for utxos"); - task::spawn(async move { - if let Err(err) = task.run().await { - error!(target: LOG_TARGET, "Error scanning UTXOs: {}", err); - } - //we make sure the flag is set to false here - running_flag.store(false, Ordering::Relaxed); - }); - } + tokio::select! { + _ = work_interval.tick() => { + let running_flag = self.is_running.clone(); + if !running_flag.load(Ordering::SeqCst) { + let task = self.create_task(); + debug!(target: LOG_TARGET, "UTXO scanning service starting scan for utxos"); + task::spawn(async move { + if let Err(err) = task.run().await { + error!(target: LOG_TARGET, "Error scanning UTXOs: {}", err); + } + //we make sure the flag is set to false here + running_flag.store(false, Ordering::Relaxed); + }); } - previous = Instant::now(); }, request_context = request_stream.select_next_some() => { trace!(target: LOG_TARGET, "Handling Service API Request"); @@ -757,7 +745,7 @@ where TBackend: WalletBackend + 'static e }); }, - _ = shutdown => { + _ = shutdown.wait() => { // this will stop the task if its running, and let that thread exit gracefully self.is_running.store(false, Ordering::Relaxed); info!(target: LOG_TARGET, "UTXO scanning service shutting down because it received the shutdown signal"); diff --git a/base_layer/wallet/src/wallet.rs b/base_layer/wallet/src/wallet.rs index 1f91f3d625..a02ea1129d 100644 --- a/base_layer/wallet/src/wallet.rs +++ b/base_layer/wallet/src/wallet.rs @@ -77,7 +77,6 @@ use tari_key_manager::key_manager::KeyManager; use tari_p2p::{comms_connector::pubsub_connector, initialization, initialization::P2pInitializer}; use tari_service_framework::StackBuilder; use tari_shutdown::ShutdownSignal; -use tokio::runtime; const LOG_TARGET: &str = "wallet"; @@ -139,8 +138,7 @@ where let bn_service_db = wallet_database.clone(); let factories = config.clone().factories; - let (publisher, subscription_factory) = - pubsub_connector(runtime::Handle::current(), config.buffer_size, config.rate_limit); + let (publisher, subscription_factory) = pubsub_connector(config.buffer_size, config.rate_limit); let peer_message_subscription_factory = Arc::new(subscription_factory); let transport_type = config.comms_config.transport_type.clone(); diff --git a/base_layer/wallet/tests/output_manager_service/service.rs b/base_layer/wallet/tests/output_manager_service/service.rs index c6c23da53e..b2b647cef6 100644 --- a/base_layer/wallet/tests/output_manager_service/service.rs +++ b/base_layer/wallet/tests/output_manager_service/service.rs @@ -30,7 +30,7 @@ use rand::{rngs::OsRng, RngCore}; use std::{sync::Arc, thread, time::Duration}; use tari_comms::{ peer_manager::{NodeIdentity, PeerFeatures}, - protocol::rpc::{mock::MockRpcServer, NamedProtocolService, RpcStatus}, + protocol::rpc::{mock::MockRpcServer, NamedProtocolService, RpcClientConfig, RpcStatus}, test_utils::{ mocks::{create_connectivity_mock, ConnectivityManagerMockState}, node_identity::build_node_identity, @@ -74,7 +74,7 @@ use tari_wallet::{ service::OutputManagerService, storage::{ database::{DbKey, DbKeyValuePair, DbValue, OutputManagerBackend, OutputManagerDatabase, WriteOperation}, - models::DbUnblindedOutput, + models::{DbUnblindedOutput, OutputStatus}, sqlite_db::OutputManagerSqliteDatabase, }, TxId, @@ -83,13 +83,10 @@ use tari_wallet::{ transaction_service::handle::TransactionServiceHandle, types::ValidationRetryStrategy, }; - -use tari_comms::protocol::rpc::RpcClientConfig; -use tari_wallet::output_manager_service::storage::models::OutputStatus; use tokio::{ runtime::Runtime, sync::{broadcast, broadcast::channel}, - time::delay_for, + time, }; #[allow(clippy::type_complexity)] @@ -1375,7 +1372,7 @@ fn test_utxo_stxo_invalid_txo_validation() { let (mut oms, _shutdown, _ts, _mock_rpc_server, server_node_identity, rpc_service_state, _) = setup_output_manager_service(&mut runtime, backend, true); - let mut event_stream = oms.get_event_stream_fused(); + let mut event_stream = oms.get_event_stream(); let unspent_value1 = 500; let unspent_output1 = create_unblinded_output( @@ -1435,10 +1432,10 @@ fn test_utxo_stxo_invalid_txo_validation() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = sleep(Duration::from_secs(60)).fuse(); let mut success = false; loop { - futures::select! { + tokio::select! { event = event_stream.select_next_some() => { if let Ok(msg) = event { if let OutputManagerEvent::TxoValidationSuccess(_,TxoValidationType::Invalid) = (*msg).clone() { @@ -1475,10 +1472,10 @@ fn test_utxo_stxo_invalid_txo_validation() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = sleep(Duration::from_secs(60)).fuse(); let mut success = false; loop { - futures::select! { + tokio::select! { event = event_stream.select_next_some() => { if let Ok(msg) = event { if let OutputManagerEvent::TxoValidationSuccess(_,TxoValidationType::Unspent) = (*msg).clone() { @@ -1514,10 +1511,10 @@ fn test_utxo_stxo_invalid_txo_validation() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = sleep(Duration::from_secs(60)).fuse(); let mut success = false; loop { - futures::select! { + tokio::select! { event = event_stream.select_next_some() => { if let Ok(msg) = event { if let OutputManagerEvent::TxoValidationSuccess(_, TxoValidationType::Spent) = (*msg).clone() { @@ -1557,7 +1554,7 @@ fn test_base_node_switch_during_validation() { mut rpc_service_state, _connectivity_mock_state, ) = setup_output_manager_service(&mut runtime, backend, true); - let mut event_stream = oms.get_event_stream_fused(); + let mut event_stream = oms.get_event_stream(); let unspent_value1 = 500; let unspent_output1 = create_unblinded_output( @@ -1615,10 +1612,10 @@ fn test_base_node_switch_during_validation() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = sleep(Duration::from_secs(60)).fuse(); let mut abort = false; loop { - futures::select! { + tokio::select! { event = event_stream.select_next_some() => { if let Ok(msg) = event { if let OutputManagerEvent::TxoValidationAborted(_,_) = (*msg).clone() { @@ -1644,7 +1641,7 @@ fn test_txo_validation_connection_timeout_retries() { let (mut oms, _shutdown, _ts, _mock_rpc_server, server_node_identity, _rpc_service_state, _connectivity_mock_state) = setup_output_manager_service(&mut runtime, backend, false); - let mut event_stream = oms.get_event_stream_fused(); + let mut event_stream = oms.get_event_stream(); let unspent_value1 = 500; let unspent_output1 = create_unblinded_output( @@ -1675,11 +1672,11 @@ fn test_txo_validation_connection_timeout_retries() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = sleep(Duration::from_secs(60)).fuse(); let mut timeout = 0; let mut failed = 0; loop { - futures::select! { + tokio::select! { event = event_stream.select_next_some() => { if let Ok(msg) = event { match (*msg).clone() { @@ -1714,7 +1711,7 @@ fn test_txo_validation_rpc_error_retries() { let (mut oms, _shutdown, _ts, _mock_rpc_server, server_node_identity, rpc_service_state, _connectivity_mock_state) = setup_output_manager_service(&mut runtime, backend, true); - let mut event_stream = oms.get_event_stream_fused(); + let mut event_stream = oms.get_event_stream(); rpc_service_state.set_rpc_status_error(Some(RpcStatus::bad_request("blah".to_string()))); let unspent_value1 = 500; @@ -1746,10 +1743,10 @@ fn test_txo_validation_rpc_error_retries() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = sleep(Duration::from_secs(60)).fuse(); let mut failed = 0; loop { - futures::select! { + tokio::select! { event = event_stream.select_next_some() => { if let Ok(msg) = event { if let OutputManagerEvent::TxoValidationFailure(_,_) = (*msg).clone() { @@ -1785,7 +1782,7 @@ fn test_txo_validation_rpc_timeout() { mut rpc_service_state, _connectivity_mock_state, ) = setup_output_manager_service(&mut runtime, backend, true); - let mut event_stream = oms.get_event_stream_fused(); + let mut event_stream = oms.get_event_stream(); rpc_service_state.set_response_delay(Some(Duration::from_secs(120))); let unspent_value1 = 500; @@ -1817,7 +1814,7 @@ fn test_txo_validation_rpc_timeout() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for( + let mut delay = sleep( RpcClientConfig::default().deadline.unwrap() + RpcClientConfig::default().deadline_grace_period + Duration::from_secs(30), @@ -1825,7 +1822,7 @@ fn test_txo_validation_rpc_timeout() { .fuse(); let mut failed = 0; loop { - futures::select! { + tokio::select! { event = event_stream.select_next_some() => { if let Ok(msg) = event { if let OutputManagerEvent::TxoValidationFailure(_,_) = (*msg).clone() { @@ -1856,7 +1853,7 @@ fn test_txo_validation_base_node_not_synced() { let (mut oms, _shutdown, _ts, _mock_rpc_server, server_node_identity, rpc_service_state, _connectivity_mock_state) = setup_output_manager_service(&mut runtime, backend, true); - let mut event_stream = oms.get_event_stream_fused(); + let mut event_stream = oms.get_event_stream(); rpc_service_state.set_is_synced(false); let unspent_value1 = 500; @@ -1889,10 +1886,10 @@ fn test_txo_validation_base_node_not_synced() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = sleep(Duration::from_secs(60)).fuse(); let mut delayed = 0; loop { - futures::select! { + tokio::select! { event = event_stream.select_next_some() => { if let Ok(msg) = event { if let OutputManagerEvent::TxoValidationDelayed(_,_) = (*msg).clone() { @@ -1915,10 +1912,10 @@ fn test_txo_validation_base_node_not_synced() { rpc_service_state.set_utxos(vec![unspent_tx_output1]); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = time::sleep(Duration::from_secs(60)).fuse(); let mut success = false; loop { - futures::select! { + tokio::select! { event = event_stream.select_next_some() => { if let Ok(msg) = event { if let OutputManagerEvent::TxoValidationSuccess(_,_) = (*msg).clone() { diff --git a/base_layer/wallet/tests/output_manager_service/storage.rs b/base_layer/wallet/tests/output_manager_service/storage.rs index c0609da64c..00309569cc 100644 --- a/base_layer/wallet/tests/output_manager_service/storage.rs +++ b/base_layer/wallet/tests/output_manager_service/storage.rs @@ -429,7 +429,7 @@ pub fn test_key_manager_crud() { assert_eq!(read_state3.primary_key_index, 2); } -#[tokio_macros::test] +#[tokio::test] pub async fn test_short_term_encumberance() { let factories = CryptoFactories::default(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); @@ -510,7 +510,7 @@ pub async fn test_short_term_encumberance() { ); } -#[tokio_macros::test] +#[tokio::test] pub async fn test_no_duplicate_outputs() { let factories = CryptoFactories::default(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); diff --git a/base_layer/wallet/tests/support/rpc.rs b/base_layer/wallet/tests/support/rpc.rs index 0f7009bdd0..fd1aa0cdbd 100644 --- a/base_layer/wallet/tests/support/rpc.rs +++ b/base_layer/wallet/tests/support/rpc.rs @@ -57,7 +57,7 @@ use tari_core::{ types::Signature, }, }; -use tokio::time::delay_for; +use tokio::time::sleep; /// This macro unlocks a Mutex or RwLock. If the lock is /// poisoned (i.e. panic while unlocked) the last value @@ -212,7 +212,7 @@ impl BaseNodeWalletRpcMockState { return Ok((*lock).drain(..num_calls).collect()); } drop(lock); - delay_for(Duration::from_millis(100)).await; + sleep(Duration::from_millis(100)).await; } Err(format!( "Did not receive enough calls within the timeout period, received {}, expected {}.", @@ -234,7 +234,7 @@ impl BaseNodeWalletRpcMockState { return Ok((*lock).drain(..num_calls).collect()); } drop(lock); - delay_for(Duration::from_millis(100)).await; + sleep(Duration::from_millis(100)).await; } Err(format!( "Did not receive enough calls within the timeout period, received {}, expected {}.", @@ -256,7 +256,7 @@ impl BaseNodeWalletRpcMockState { return Ok((*lock).drain(..num_calls).collect()); } drop(lock); - delay_for(Duration::from_millis(100)).await; + sleep(Duration::from_millis(100)).await; } Err(format!( "Did not receive enough calls within the timeout period, received {}, expected {}.", @@ -276,7 +276,7 @@ impl BaseNodeWalletRpcMockState { return Ok((*lock).drain(..num_calls).collect()); } drop(lock); - delay_for(Duration::from_millis(100)).await; + sleep(Duration::from_millis(100)).await; } Err("Did not receive enough calls within the timeout period".to_string()) } @@ -318,7 +318,7 @@ impl BaseNodeWalletService for BaseNodeWalletRpcMockService { ) -> Result<Response<TxSubmissionResponseProto>, RpcStatus> { let delay_lock = *acquire_lock!(self.state.response_delay); if let Some(delay) = delay_lock { - delay_for(delay).await; + sleep(delay).await; } let message = request.into_message(); @@ -345,7 +345,7 @@ impl BaseNodeWalletService for BaseNodeWalletRpcMockService { ) -> Result<Response<TxQueryResponseProto>, RpcStatus> { let delay_lock = *acquire_lock!(self.state.response_delay); if let Some(delay) = delay_lock { - delay_for(delay).await; + sleep(delay).await; } let message = request.into_message(); @@ -371,7 +371,7 @@ impl BaseNodeWalletService for BaseNodeWalletRpcMockService { ) -> Result<Response<TxQueryBatchResponsesProto>, RpcStatus> { let delay_lock = *acquire_lock!(self.state.response_delay); if let Some(delay) = delay_lock { - delay_for(delay).await; + sleep(delay).await; } let message = request.into_message(); @@ -415,7 +415,7 @@ impl BaseNodeWalletService for BaseNodeWalletRpcMockService { ) -> Result<Response<FetchUtxosResponse>, RpcStatus> { let delay_lock = *acquire_lock!(self.state.response_delay); if let Some(delay) = delay_lock { - delay_for(delay).await; + sleep(delay).await; } let message = request.into_message(); @@ -448,7 +448,7 @@ impl BaseNodeWalletService for BaseNodeWalletRpcMockService { async fn get_tip_info(&self, _request: Request<()>) -> Result<Response<TipInfoResponse>, RpcStatus> { let delay_lock = *acquire_lock!(self.state.response_delay); if let Some(delay) = delay_lock { - delay_for(delay).await; + sleep(delay).await; } log::info!("Get tip info call received"); @@ -493,7 +493,7 @@ mod test { }; use tokio::time::Duration; - #[tokio_macros::test] + #[tokio::test] async fn test_wallet_rpc_mock() { let server_node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); let client_node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); diff --git a/base_layer/wallet/tests/transaction_service/service.rs b/base_layer/wallet/tests/transaction_service/service.rs index dc508f7321..51f1de392c 100644 --- a/base_layer/wallet/tests/transaction_service/service.rs +++ b/base_layer/wallet/tests/transaction_service/service.rs @@ -142,7 +142,7 @@ use tokio::{ runtime, runtime::{Builder, Runtime}, sync::{broadcast, broadcast::channel}, - time::delay_for, + time::sleep, }; fn create_runtime() -> Runtime { @@ -172,7 +172,7 @@ pub fn setup_transaction_service< discovery_request_timeout: Duration, shutdown_signal: ShutdownSignal, ) -> (TransactionServiceHandle, OutputManagerHandle, CommsNode) { - let (publisher, subscription_factory) = pubsub_connector(runtime.handle().clone(), 100, 20); + let (publisher, subscription_factory) = pubsub_connector(100, 20); let subscription_factory = Arc::new(subscription_factory); let (comms, dht) = runtime.block_on(setup_comms_services( node_identity, @@ -504,9 +504,9 @@ fn manage_single_transaction() { .block_on(alice_ts.set_base_node_public_key(base_node_identity.public_key().clone())) .unwrap(); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); - runtime.block_on(async { delay_for(Duration::from_secs(2)).await }); + runtime.block_on(async { sleep(Duration::from_secs(2)).await }); let (mut bob_ts, mut bob_oms, bob_comms) = setup_transaction_service( &mut runtime, @@ -524,7 +524,7 @@ fn manage_single_transaction() { .block_on(bob_ts.set_base_node_public_key(base_node_identity.public_key().clone())) .unwrap(); - let mut bob_event_stream = bob_ts.get_event_stream_fused(); + let mut bob_event_stream = bob_ts.get_event_stream(); let _ = runtime.block_on( bob_comms @@ -556,10 +556,10 @@ fn manage_single_transaction() { .expect("Alice sending tx"); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(90)).fuse(); + let mut delay = sleep(Duration::from_secs(90)).fuse(); let mut count = 0; loop { - futures::select! { + tokio::select! { _event = alice_event_stream.select_next_some() => { println!("alice: {:?}", &*_event.as_ref().unwrap()); count+=1; @@ -576,10 +576,10 @@ fn manage_single_transaction() { let mut tx_id = 0u64; runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(90)).fuse(); + let mut delay = sleep(Duration::from_secs(90)).fuse(); let mut finalized = 0; loop { - futures::select! { + tokio::select! { event = bob_event_stream.select_next_some() => { println!("bob: {:?}", &*event.as_ref().unwrap()); if let TransactionEvent::ReceivedFinalizedTransaction(id) = &*event.unwrap() { @@ -747,7 +747,7 @@ fn send_one_sided_transaction_to_other() { shutdown.to_signal(), ); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); runtime .block_on(alice_ts.set_base_node_public_key(base_node_identity.public_key().clone())) @@ -792,10 +792,10 @@ fn send_one_sided_transaction_to_other() { }); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let mut delay = sleep(Duration::from_secs(30)).fuse(); let mut found = false; loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { if let TransactionEvent::TransactionCompletedImmediately(id) = &*event.unwrap() { if id == &tx_id { @@ -1071,9 +1071,9 @@ fn manage_multiple_transactions() { Duration::from_secs(60), shutdown.to_signal(), ); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); // Spin up Bob and Carol let (mut bob_ts, mut bob_oms, bob_comms) = setup_transaction_service( @@ -1088,8 +1088,8 @@ fn manage_multiple_transactions() { Duration::from_secs(1), shutdown.to_signal(), ); - let mut bob_event_stream = bob_ts.get_event_stream_fused(); - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + let mut bob_event_stream = bob_ts.get_event_stream(); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); let (mut carol_ts, mut carol_oms, carol_comms) = setup_transaction_service( &mut runtime, @@ -1103,18 +1103,18 @@ fn manage_multiple_transactions() { Duration::from_secs(1), shutdown.to_signal(), ); - let mut carol_event_stream = carol_ts.get_event_stream_fused(); + let mut carol_event_stream = carol_ts.get_event_stream(); // Establish some connections beforehand, to reduce the amount of work done concurrently in tests // Connect Bob and Alice - runtime.block_on(async { delay_for(Duration::from_secs(3)).await }); + runtime.block_on(async { sleep(Duration::from_secs(3)).await }); let _ = runtime.block_on( bob_comms .connectivity() .dial_peer(alice_node_identity.node_id().clone()), ); - runtime.block_on(async { delay_for(Duration::from_secs(3)).await }); + runtime.block_on(async { sleep(Duration::from_secs(3)).await }); // Connect alice to carol let _ = runtime.block_on( @@ -1182,11 +1182,11 @@ fn manage_multiple_transactions() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(90)).fuse(); + let mut delay = sleep(Duration::from_secs(90)).fuse(); let mut tx_reply = 0; let mut finalized = 0; loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { match &*event.unwrap() { TransactionEvent::ReceivedTransactionReply(_) => tx_reply+=1, @@ -1210,11 +1210,11 @@ fn manage_multiple_transactions() { log::trace!("Alice received all Tx messages"); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(90)).fuse(); + let mut delay = sleep(Duration::from_secs(90)).fuse(); let mut tx_reply = 0; let mut finalized = 0; loop { - futures::select! { + tokio::select! { event = bob_event_stream.select_next_some() => { match &*event.unwrap() { TransactionEvent::ReceivedTransactionReply(_) => tx_reply+=1, @@ -1235,10 +1235,10 @@ fn manage_multiple_transactions() { }); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(90)).fuse(); + let mut delay = sleep(Duration::from_secs(90)).fuse(); let mut finalized = 0; loop { - futures::select! { + tokio::select! { event = carol_event_stream.select_next_some() => { if let TransactionEvent::ReceivedFinalizedTransaction(_) = &*event.unwrap() { finalized+=1 } }, @@ -1264,7 +1264,7 @@ fn manage_multiple_transactions() { assert_eq!(carol_pending_inbound.len(), 0); assert_eq!(carol_completed_tx.len(), 1); - shutdown.trigger().unwrap(); + shutdown.trigger(); runtime.block_on(async move { alice_comms.wait_until_shutdown().await; bob_comms.wait_until_shutdown().await; @@ -1303,7 +1303,7 @@ fn test_accepting_unknown_tx_id_and_malformed_reply() { _, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), alice_backend, oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); let (_utxo, uo) = make_input(&mut OsRng, MicroTari(250000), &factories.commitment); @@ -1355,10 +1355,10 @@ fn test_accepting_unknown_tx_id_and_malformed_reply() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let mut delay = sleep(Duration::from_secs(30)).fuse(); let mut errors = 0; loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { log::error!("ERROR: {:?}", event); if let TransactionEvent::Error(s) = &*event.unwrap() { @@ -1415,7 +1415,7 @@ fn finalize_tx_with_incorrect_pubkey() { _, _, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), alice_backend, alice_oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); let bob_node_identity = NodeIdentity::random(&mut OsRng, get_next_memory_address(), PeerFeatures::COMMUNICATION_NODE); @@ -1488,9 +1488,9 @@ fn finalize_tx_with_incorrect_pubkey() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(15)).fuse(); + let mut delay = sleep(Duration::from_secs(15)).fuse(); loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { if let TransactionEvent::ReceivedFinalizedTransaction(_) = (*event.unwrap()).clone() { panic!("Should not have received finalized event!"); @@ -1542,7 +1542,7 @@ fn finalize_tx_with_missing_output() { _, _, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), alice_backend, alice_oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); let bob_node_identity = NodeIdentity::random(&mut OsRng, get_next_memory_address(), PeerFeatures::COMMUNICATION_NODE); @@ -1623,9 +1623,9 @@ fn finalize_tx_with_missing_output() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(15)).fuse(); + let mut delay = sleep(Duration::from_secs(15)).fuse(); loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { if let TransactionEvent::ReceivedFinalizedTransaction(_) = (*event.unwrap()).clone() { panic!("Should not have received finalized event"); @@ -1714,7 +1714,7 @@ fn discovery_async_return_test() { Duration::from_secs(20), shutdown.to_signal(), ); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); let (_utxo, uo1a) = make_input(&mut OsRng, MicroTari(5500), &factories.commitment); runtime.block_on(alice_oms.add_output(uo1a)).unwrap(); @@ -1741,9 +1741,9 @@ fn discovery_async_return_test() { let mut txid = 0; let mut is_success = true; runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = sleep(Duration::from_secs(60)).fuse(); loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { if let TransactionEvent::TransactionDirectSendResult(tx_id, result) = (*event.unwrap()).clone() { txid = tx_id; @@ -1772,10 +1772,10 @@ fn discovery_async_return_test() { let mut success_result = false; let mut success_tx_id = 0u64; runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = sleep(Duration::from_secs(60)).fuse(); loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { if let TransactionEvent::TransactionDirectSendResult(tx_id, success) = &*event.unwrap() { success_result = *success; @@ -1794,9 +1794,9 @@ fn discovery_async_return_test() { assert!(success_result); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = sleep(Duration::from_secs(60)).fuse(); loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { if let TransactionEvent::ReceivedTransactionReply(tx_id) = &*event.unwrap() { if tx_id == &tx_id2 { @@ -1811,7 +1811,7 @@ fn discovery_async_return_test() { } }); - shutdown.trigger().unwrap(); + shutdown.trigger(); runtime.block_on(async move { alice_comms.wait_until_shutdown().await; carol_comms.wait_until_shutdown().await; @@ -2012,7 +2012,7 @@ fn test_transaction_cancellation() { ..Default::default() }), ); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); let alice_total_available = 250000 * uT; let (_utxo, uo) = make_input(&mut OsRng, alice_total_available, &factories.commitment); @@ -2030,9 +2030,9 @@ fn test_transaction_cancellation() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = sleep(Duration::from_secs(60)).fuse(); loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { if let TransactionEvent::TransactionStoreForwardSendResult(_,_) = &*event.unwrap() { break; @@ -2054,7 +2054,7 @@ fn test_transaction_cancellation() { None => (), Some(_) => break, } - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); if i >= 12 { panic!("Pending outbound transaction should have been added by now"); } @@ -2066,10 +2066,10 @@ fn test_transaction_cancellation() { // Wait for cancellation event, in an effort to nail down where the issue is for the flakey CI test runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = sleep(Duration::from_secs(60)).fuse(); let mut cancelled = false; loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { if let TransactionEvent::TransactionCancelled(_) = &*event.unwrap() { cancelled = true; @@ -2143,9 +2143,9 @@ fn test_transaction_cancellation() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = sleep(Duration::from_secs(60)).fuse(); loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { if let TransactionEvent::ReceivedTransaction(_) = &*event.unwrap() { break; @@ -2213,9 +2213,9 @@ fn test_transaction_cancellation() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = sleep(Duration::from_secs(60)).fuse(); loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { if let TransactionEvent::ReceivedTransaction(_) = &*event.unwrap() { break; @@ -2243,7 +2243,7 @@ fn test_transaction_cancellation() { ))) .unwrap(); - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); runtime .block_on(alice_ts.get_pending_inbound_transactions()) @@ -2257,10 +2257,10 @@ fn test_transaction_cancellation() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let mut delay = sleep(Duration::from_secs(30)).fuse(); let mut cancelled = false; loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { if let TransactionEvent::TransactionCancelled(_) = &*event.unwrap() { cancelled = true; @@ -2386,7 +2386,7 @@ fn test_direct_vs_saf_send_of_tx_reply_and_finalize() { .try_into() .unwrap(); - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); assert_eq!(bob_outbound_service.call_count(), 0, "Should be no more calls"); let (_wallet_backend, backend, oms_backend, _, _temp_dir) = make_wallet_databases(None); @@ -2428,7 +2428,7 @@ fn test_direct_vs_saf_send_of_tx_reply_and_finalize() { .try_into() .unwrap(); - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); assert_eq!(bob2_outbound_service.call_count(), 0, "Should be no more calls"); // Test finalize is sent Direct Only. @@ -2449,7 +2449,7 @@ fn test_direct_vs_saf_send_of_tx_reply_and_finalize() { let _ = alice_outbound_service.pop_call().unwrap(); let _ = alice_outbound_service.pop_call().unwrap(); - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); assert_eq!(alice_outbound_service.call_count(), 0, "Should be no more calls"); // Now to repeat sending so we can test the SAF send of the finalize message @@ -2520,7 +2520,7 @@ fn test_direct_vs_saf_send_of_tx_reply_and_finalize() { assert_eq!(alice_outbound_service.call_count(), 1); let _ = alice_outbound_service.pop_call(); - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); assert_eq!(alice_outbound_service.call_count(), 0, "Should be no more calls2"); } @@ -2548,7 +2548,7 @@ fn test_tx_direct_send_behaviour() { _, _, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), backend, oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); let (_utxo, uo) = make_input(&mut OsRng, 1000000 * uT, &factories.commitment); runtime.block_on(alice_output_manager.add_output(uo)).unwrap(); @@ -2576,11 +2576,11 @@ fn test_tx_direct_send_behaviour() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = sleep(Duration::from_secs(60)).fuse(); let mut direct_count = 0; let mut saf_count = 0; loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { match &*event.unwrap() { TransactionEvent::TransactionDirectSendResult(_, result) => if !result { direct_count+=1 }, @@ -2619,11 +2619,11 @@ fn test_tx_direct_send_behaviour() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = sleep(Duration::from_secs(60)).fuse(); let mut direct_count = 0; let mut saf_count = 0; loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { match &*event.unwrap() { TransactionEvent::TransactionDirectSendResult(_, result) => if !result { direct_count+=1 }, @@ -2663,10 +2663,10 @@ fn test_tx_direct_send_behaviour() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = sleep(Duration::from_secs(60)).fuse(); let mut direct_count = 0; loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { match &*event.unwrap() { TransactionEvent::TransactionDirectSendResult(_, result) => if *result { direct_count+=1 }, @@ -2705,10 +2705,10 @@ fn test_tx_direct_send_behaviour() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = sleep(Duration::from_secs(60)).fuse(); let mut saf_count = 0; loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { match &*event.unwrap() { TransactionEvent::TransactionStoreForwardSendResult(_, result) => if *result { saf_count+=1 @@ -2852,7 +2852,7 @@ fn test_restarting_transaction_protocols() { // Test that Bob's node restarts the send protocol let (mut bob_ts, _bob_oms, _bob_outbound_service, _, _, mut bob_tx_reply, _, _, _, _shutdown, _, _, _) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), bob_backend, bob_oms_backend, None); - let mut bob_event_stream = bob_ts.get_event_stream_fused(); + let mut bob_event_stream = bob_ts.get_event_stream(); runtime .block_on(bob_ts.set_base_node_public_key(base_node_identity.public_key().clone())) @@ -2864,10 +2864,10 @@ fn test_restarting_transaction_protocols() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(15)).fuse(); + let mut delay = sleep(Duration::from_secs(15)).fuse(); let mut received_reply = false; loop { - futures::select! { + tokio::select! { event = bob_event_stream.select_next_some() => { if let TransactionEvent::ReceivedTransactionReply(id) = (*event.unwrap()).clone() { assert_eq!(id, tx_id); @@ -2886,7 +2886,7 @@ fn test_restarting_transaction_protocols() { // Test Alice's node restarts the receive protocol let (mut alice_ts, _alice_oms, _alice_outbound_service, _, _, _, mut alice_tx_finalized, _, _, _shutdown, _, _, _) = setup_transaction_service_no_comms(&mut runtime, factories, alice_backend, alice_oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); runtime .block_on(alice_ts.set_base_node_public_key(base_node_identity.public_key().clone())) @@ -2906,10 +2906,10 @@ fn test_restarting_transaction_protocols() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(15)).fuse(); + let mut delay = sleep(Duration::from_secs(15)).fuse(); let mut received_finalized = false; loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { if let TransactionEvent::ReceivedFinalizedTransaction(id) = (*event.unwrap()).clone() { assert_eq!(id, tx_id); @@ -3046,7 +3046,7 @@ fn test_coinbase_monitoring_stuck_in_mempool() { server_node_identity, mut rpc_service_state, ) = setup_transaction_service_no_comms(&mut runtime, factories, backend, oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); rpc_service_state.set_response_delay(Some(Duration::from_secs(1))); let block_height_a = 10; @@ -3131,10 +3131,10 @@ fn test_coinbase_monitoring_stuck_in_mempool() { println!(" {}", e) } runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let mut delay = sleep(Duration::from_secs(30)).fuse(); let mut count = 0usize; loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { if let TransactionEvent::ReceivedFinalizedTransaction(tx_id) = &*event.unwrap() { if tx_id == &tx_id1 || tx_id == &tx_id2 { @@ -3170,10 +3170,10 @@ fn test_coinbase_monitoring_stuck_in_mempool() { println!(" {}", e) } runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let mut delay = sleep(Duration::from_secs(30)).fuse(); let mut count = 0usize; loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { if let TransactionEvent::TransactionCancelled(tx_id) = &*event.unwrap() { if tx_id == &tx_id1 || tx_id == &tx_id2 { @@ -3215,7 +3215,7 @@ fn test_coinbase_monitoring_with_base_node_change_and_mined() { server_node_identity, mut rpc_service_state, ) = setup_transaction_service_no_comms(&mut runtime, factories, backend, oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); rpc_service_state.set_response_delay(Some(Duration::from_secs(1))); let block_height_a = 10; @@ -3301,10 +3301,10 @@ fn test_coinbase_monitoring_with_base_node_change_and_mined() { println!(" {}", e) } runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let mut delay = sleep(Duration::from_secs(30)).fuse(); let mut count = 0usize; loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { if let TransactionEvent::TransactionMinedUnconfirmed(tx_id, _) = &*event.unwrap() { if tx_id == &tx_id1 || tx_id == &tx_id2 { @@ -3368,10 +3368,10 @@ fn test_coinbase_monitoring_with_base_node_change_and_mined() { println!(" {}", e) } runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let mut delay = sleep(Duration::from_secs(30)).fuse(); let mut count = 0usize; loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { if let TransactionEvent::TransactionMined(tx_id) = &*event.unwrap() { if tx_id == &tx_id1 || tx_id == &tx_id2 { @@ -3413,7 +3413,7 @@ fn test_coinbase_monitoring_mined_not_synced() { server_node_identity, mut rpc_service_state, ) = setup_transaction_service_no_comms(&mut runtime, factories, backend, oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); rpc_service_state.set_response_delay(Some(Duration::from_secs(1))); let block_height_a = 10; @@ -3499,10 +3499,10 @@ fn test_coinbase_monitoring_mined_not_synced() { println!(" {}", e) } runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let mut delay = sleep(Duration::from_secs(30)).fuse(); let mut count = 0usize; loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { if let TransactionEvent::ReceivedFinalizedTransaction(tx_id) = &*event.unwrap() { if tx_id == &tx_id1 || tx_id == &tx_id2 { @@ -3538,10 +3538,10 @@ fn test_coinbase_monitoring_mined_not_synced() { println!(" {}", e) } runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let mut delay = sleep(Duration::from_secs(30)).fuse(); let mut count = 0usize; loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { if let TransactionEvent::TransactionMined(tx_id) = &*event.unwrap() { if tx_id == &tx_id1 || tx_id == &tx_id2 { @@ -3760,7 +3760,7 @@ fn test_transaction_resending() { assert_eq!(bob_reply_message.tx_id, tx_id); } - runtime.block_on(async { delay_for(Duration::from_secs(2)).await }); + runtime.block_on(async { sleep(Duration::from_secs(2)).await }); // See if sending a second message too soon is ignored runtime .block_on(bob_tx_sender.send(create_dummy_message( @@ -3772,7 +3772,7 @@ fn test_transaction_resending() { assert!(bob_outbound_service.wait_call_count(1, Duration::from_secs(2)).is_err()); // Wait for the cooldown to expire but before the resend period has elapsed see if a repeat illicts a reponse. - runtime.block_on(async { delay_for(Duration::from_secs(2)).await }); + runtime.block_on(async { sleep(Duration::from_secs(2)).await }); runtime .block_on(bob_tx_sender.send(create_dummy_message( alice_sender_message.into(), @@ -3819,7 +3819,7 @@ fn test_transaction_resending() { .is_err()); // Wait for the cooldown to expire but before the resend period has elapsed see if a repeat illicts a reponse. - runtime.block_on(async { delay_for(Duration::from_secs(2)).await }); + runtime.block_on(async { sleep(Duration::from_secs(2)).await }); runtime .block_on(alice_tx_reply_sender.send(create_dummy_message( @@ -4143,7 +4143,7 @@ fn test_replying_to_cancelled_tx() { assert_eq!(data.tx_id, tx_id); } // Need a moment for Alice's wallet to finish writing to its database before cancelling - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); runtime.block_on(alice_ts.cancel_transaction(tx_id)).unwrap(); @@ -4193,7 +4193,7 @@ fn test_replying_to_cancelled_tx() { assert_eq!(bob_reply_message.tx_id, tx_id); // Wait for cooldown to expire - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); let _ = alice_outbound_service.take_calls(); @@ -4406,7 +4406,7 @@ fn test_transaction_timeout_cancellation() { ..Default::default() }), ); - let mut carol_event_stream = carol_ts.get_event_stream_fused(); + let mut carol_event_stream = carol_ts.get_event_stream(); runtime .block_on(carol_tx_sender.send(create_dummy_message( @@ -4431,10 +4431,10 @@ fn test_transaction_timeout_cancellation() { assert_eq!(carol_reply_message.tx_id, tx_id); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = sleep(Duration::from_secs(60)).fuse(); let mut transaction_cancelled = false; loop { - futures::select! { + tokio::select! { event = carol_event_stream.select_next_some() => { if let TransactionEvent::TransactionCancelled(t) = &*event.unwrap() { if t == &tx_id { @@ -4481,7 +4481,7 @@ fn transaction_service_tx_broadcast() { server_node_identity, rpc_service_state, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), backend, oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); runtime .block_on(alice_ts.set_base_node_public_key(server_node_identity.public_key().clone())) @@ -4609,10 +4609,10 @@ fn transaction_service_tx_broadcast() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = sleep(Duration::from_secs(60)).fuse(); let mut tx1_received = false; loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { if let TransactionEvent::ReceivedTransactionReply(tx_id) = &*event.unwrap(){ if tx_id == &tx_id1 { @@ -4655,10 +4655,10 @@ fn transaction_service_tx_broadcast() { }); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = sleep(Duration::from_secs(60)).fuse(); let mut tx1_mined = false; loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { if let TransactionEvent::TransactionMined(tx_id) = &*event.unwrap(){ if tx_id == &tx_id1 { @@ -4683,10 +4683,10 @@ fn transaction_service_tx_broadcast() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = sleep(Duration::from_secs(60)).fuse(); let mut tx2_received = false; loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { if let TransactionEvent::ReceivedTransactionReply(tx_id) = &*event.unwrap(){ if tx_id == &tx_id2 { @@ -4733,10 +4733,10 @@ fn transaction_service_tx_broadcast() { }); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = sleep(Duration::from_secs(60)).fuse(); let mut tx2_cancelled = false; loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { if let TransactionEvent::TransactionCancelled(tx_id) = &*event.unwrap(){ if tx_id == &tx_id2 { @@ -4851,14 +4851,14 @@ fn broadcast_all_completed_transactions_on_startup() { assert!(runtime.block_on(alice_ts.restart_broadcast_protocols()).is_ok()); - let mut event_stream = alice_ts.get_event_stream_fused(); + let mut event_stream = alice_ts.get_event_stream(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = sleep(Duration::from_secs(60)).fuse(); let mut found1 = false; let mut found2 = false; let mut found3 = false; loop { - futures::select! { + tokio::select! { event = event_stream.select_next_some() => { if let TransactionEvent::TransactionBroadcast(tx_id) = (*event.unwrap()).clone() { if tx_id == 1u64 { @@ -4916,7 +4916,7 @@ fn transaction_service_tx_broadcast_with_base_node_change() { server_node_identity, rpc_service_state, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), backend, oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); runtime .block_on(alice_ts.set_base_node_public_key(server_node_identity.public_key().clone())) @@ -4995,10 +4995,10 @@ fn transaction_service_tx_broadcast_with_base_node_change() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = sleep(Duration::from_secs(60)).fuse(); let mut tx1_received = false; loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { if let TransactionEvent::ReceivedTransactionReply(tx_id) = &*event.unwrap(){ if tx_id == &tx_id1 { @@ -5075,10 +5075,10 @@ fn transaction_service_tx_broadcast_with_base_node_change() { }); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = sleep(Duration::from_secs(60)).fuse(); let mut tx_mined = false; loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => { if let TransactionEvent::TransactionMined(_) = &*event.unwrap(){ tx_mined = true; diff --git a/base_layer/wallet/tests/transaction_service/transaction_protocols.rs b/base_layer/wallet/tests/transaction_service/transaction_protocols.rs index 66079613ca..5d6d2290de 100644 --- a/base_layer/wallet/tests/transaction_service/transaction_protocols.rs +++ b/base_layer/wallet/tests/transaction_service/transaction_protocols.rs @@ -80,7 +80,7 @@ use tari_wallet::{ types::ValidationRetryStrategy, }; use tempfile::{tempdir, TempDir}; -use tokio::{sync::broadcast, task, time::delay_for}; +use tokio::{sync::broadcast, task, time::sleep}; // Just in case other options become apparent in later testing #[derive(PartialEq)] @@ -230,7 +230,7 @@ pub async fn oms_reply_channel_task( } /// A happy path test by submitting a transaction into the mempool, have it mined but unconfirmed and then confirmed. -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_submit_success_i() { let ( @@ -392,12 +392,12 @@ async fn tx_broadcast_protocol_submit_success_i() { ); // Check that the appropriate events were emitted - let mut delay = delay_for(Duration::from_secs(5)).fuse(); + let mut delay = sleep(Duration::from_secs(5)).fuse(); let mut broadcast = false; let mut unconfirmed = false; let mut confirmed = false; loop { - futures::select! { + tokio::select! { event = event_stream.select_next_some() => { match &*event.unwrap() { TransactionEvent::TransactionMinedUnconfirmed(_, confirmations) => if *confirmations == 1 { @@ -426,7 +426,7 @@ async fn tx_broadcast_protocol_submit_success_i() { } /// Test submitting a transaction that is immediately rejected -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_submit_rejection() { let ( @@ -478,10 +478,10 @@ async fn tx_broadcast_protocol_submit_rejection() { assert!(db_completed_tx.is_err()); // Check that the appropriate events were emitted - let mut delay = delay_for(Duration::from_secs(1)).fuse(); + let mut delay = sleep(Duration::from_secs(1)).fuse(); let mut cancelled = false; loop { - futures::select! { + tokio::select! { event = event_stream.select_next_some() => { if let TransactionEvent::TransactionCancelled(_) = &*event.unwrap() { cancelled = true; @@ -498,7 +498,7 @@ async fn tx_broadcast_protocol_submit_rejection() { /// Test restarting a protocol which means the first step is a query not a submission, detecting the Tx is not in the /// mempool, resubmit the tx and then have it mined -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_restart_protocol_as_query() { let ( @@ -585,7 +585,7 @@ async fn tx_broadcast_protocol_restart_protocol_as_query() { /// This test will submit a Tx which will be accepted and then dropped from the mempool, resulting in a resubmit which /// will be rejected and result in a cancelled transaction -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_submit_success_followed_by_rejection() { let ( @@ -666,10 +666,10 @@ async fn tx_broadcast_protocol_submit_success_followed_by_rejection() { assert!(db_completed_tx.is_err()); // Check that the appropriate events were emitted - let mut delay = delay_for(Duration::from_secs(1)).fuse(); + let mut delay = sleep(Duration::from_secs(1)).fuse(); let mut cancelled = false; loop { - futures::select! { + tokio::select! { event = event_stream.select_next_some() => { if let TransactionEvent::TransactionCancelled(_) = &*event.unwrap() { cancelled = true; @@ -686,7 +686,7 @@ async fn tx_broadcast_protocol_submit_success_followed_by_rejection() { /// This test will submit a tx which is accepted and mined but unconfirmed, then the next query it will not exist /// resulting in a resubmission which we will let run to being mined with success -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_submit_mined_then_not_mined_resubmit_success() { let ( @@ -806,7 +806,7 @@ async fn tx_broadcast_protocol_submit_mined_then_not_mined_resubmit_success() { } /// Test being unable to connect and then connection becoming available. -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_connection_problem() { let ( @@ -839,10 +839,10 @@ async fn tx_broadcast_protocol_connection_problem() { let join_handle = task::spawn(protocol.execute()); // Check that the connection problem event was emitted at least twice - let mut delay = delay_for(Duration::from_secs(10)).fuse(); + let mut delay = sleep(Duration::from_secs(10)).fuse(); let mut connection_issues = 0; loop { - futures::select! { + tokio::select! { event = event_stream.select_next_some() => { if let TransactionEvent::TransactionBaseNodeConnectionProblem(_) = &*event.unwrap() { connection_issues+=1; @@ -878,7 +878,7 @@ async fn tx_broadcast_protocol_connection_problem() { } /// Submit a transaction that is Already Mined for the submission, the subsequent query should confirm the transaction -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_submit_already_mined() { let ( @@ -948,7 +948,7 @@ async fn tx_broadcast_protocol_submit_already_mined() { } /// A test to see that the broadcast protocol can handle a change to the base node address while it runs. -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_submit_and_base_node_gets_changed() { let ( @@ -1050,7 +1050,7 @@ async fn tx_broadcast_protocol_submit_and_base_node_gets_changed() { /// Validate completed transactions, will check that valid ones stay valid and incorrectly marked invalid tx become /// valid. -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_validation_protocol_tx_becomes_valid() { let ( @@ -1148,7 +1148,7 @@ async fn tx_validation_protocol_tx_becomes_valid() { } /// Validate completed transaction, the transaction should become invalid -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_validation_protocol_tx_becomes_invalid() { let ( @@ -1213,7 +1213,7 @@ async fn tx_validation_protocol_tx_becomes_invalid() { } /// Validate completed transactions, the transaction should become invalid -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_validation_protocol_tx_becomes_unconfirmed() { let ( @@ -1285,7 +1285,7 @@ async fn tx_validation_protocol_tx_becomes_unconfirmed() { /// Test the validation protocol reacts correctly to a change in base node and redoes the full validation based on the /// new base node -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_validation_protocol_tx_ends_on_base_node_end() { let ( @@ -1398,10 +1398,10 @@ async fn tx_validation_protocol_tx_ends_on_base_node_end() { let result = join_handle.await.unwrap(); assert!(result.is_ok()); - let mut delay = delay_for(Duration::from_secs(1)).fuse(); + let mut delay = sleep(Duration::from_secs(1)).fuse(); let mut aborted = false; loop { - futures::select! { + tokio::select! { event = event_stream.select_next_some() => { if let TransactionEvent::TransactionValidationAborted(_) = &*event.unwrap() { aborted = true; @@ -1416,7 +1416,7 @@ async fn tx_validation_protocol_tx_ends_on_base_node_end() { } /// Test the validation protocol reacts correctly when the RPC client returns an error between calls. -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_validation_protocol_rpc_client_broken_between_calls() { let ( @@ -1540,7 +1540,7 @@ async fn tx_validation_protocol_rpc_client_broken_between_calls() { /// Test the validation protocol reacts correctly when the RPC client returns an error between calls and only retry /// finite amount of times -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_validation_protocol_rpc_client_broken_finite_retries() { let ( @@ -1610,11 +1610,11 @@ async fn tx_validation_protocol_rpc_client_broken_finite_retries() { assert!(result.is_err()); // Check that the connection problem event was emitted at least twice - let mut delay = delay_for(Duration::from_secs(10)).fuse(); + let mut delay = sleep(Duration::from_secs(10)).fuse(); let mut timeouts = 0i32; let mut failures = 0i32; loop { - futures::select! { + tokio::select! { event = event_stream.select_next_some() => { log::error!("EVENT: {:?}", event); match &*event.unwrap() { @@ -1641,7 +1641,7 @@ async fn tx_validation_protocol_rpc_client_broken_finite_retries() { /// Validate completed transactions, will check that valid ones stay valid and incorrectly marked invalid tx become /// valid. -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_validation_protocol_base_node_not_synced() { let ( @@ -1711,11 +1711,11 @@ async fn tx_validation_protocol_base_node_not_synced() { let result = join_handle.await.unwrap(); assert!(result.is_err()); - let mut delay = delay_for(Duration::from_secs(10)).fuse(); + let mut delay = sleep(Duration::from_secs(10)).fuse(); let mut delayed = 0i32; let mut failures = 0i32; loop { - futures::select! { + tokio::select! { event = event_stream.select_next_some() => { match &*event.unwrap() { TransactionEvent::TransactionValidationDelayed(_) => { diff --git a/base_layer/wallet/tests/wallet/mod.rs b/base_layer/wallet/tests/wallet/mod.rs index 6aa231b09c..2e362d399f 100644 --- a/base_layer/wallet/tests/wallet/mod.rs +++ b/base_layer/wallet/tests/wallet/mod.rs @@ -71,7 +71,7 @@ use tari_wallet::{ WalletSqlite, }; use tempfile::tempdir; -use tokio::{runtime::Runtime, time::delay_for}; +use tokio::{runtime::Runtime, time::sleep}; fn create_peer(public_key: CommsPublicKey, net_address: Multiaddr) -> Peer { Peer::new( @@ -163,7 +163,7 @@ async fn create_wallet( .await } -#[tokio_macros::test] +#[tokio::test] async fn test_wallet() { let mut shutdown_a = Shutdown::new(); let mut shutdown_b = Shutdown::new(); @@ -227,7 +227,7 @@ async fn test_wallet() { .await .unwrap(); - let mut alice_event_stream = alice_wallet.transaction_service.get_event_stream_fused(); + let mut alice_event_stream = alice_wallet.transaction_service.get_event_stream(); let value = MicroTari::from(1000); let (_utxo, uo1) = make_input(&mut OsRng, MicroTari(2500), &factories.commitment); @@ -245,10 +245,10 @@ async fn test_wallet() { .await .unwrap(); - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = sleep(Duration::from_secs(60)).fuse(); let mut reply_count = false; loop { - futures::select! { + tokio::select! { event = alice_event_stream.select_next_some() => if let TransactionEvent::ReceivedTransactionReply(_) = &*event.unwrap() { reply_count = true; break; @@ -405,7 +405,7 @@ async fn test_wallet() { bob_wallet.wait_until_shutdown().await; } -#[tokio_macros::test] +#[tokio::test] async fn test_do_not_overwrite_master_key() { let factories = CryptoFactories::default(); let dir = tempdir().unwrap(); @@ -423,7 +423,7 @@ async fn test_do_not_overwrite_master_key() { ) .await .unwrap(); - shutdown.trigger().unwrap(); + shutdown.trigger(); wallet.wait_until_shutdown().await; // try to use a new master key to create a wallet using the existing wallet database @@ -457,7 +457,7 @@ async fn test_do_not_overwrite_master_key() { .unwrap(); } -#[tokio_macros::test] +#[tokio::test] async fn test_sign_message() { let factories = CryptoFactories::default(); let dir = tempdir().unwrap(); @@ -591,13 +591,13 @@ fn test_store_and_forward_send_tx() { .unwrap(); // Waiting here for a while to make sure the discovery retry is over - alice_runtime.block_on(async { delay_for(Duration::from_secs(60)).await }); + alice_runtime.block_on(async { sleep(Duration::from_secs(60)).await }); alice_runtime .block_on(alice_wallet.transaction_service.cancel_transaction(tx_id)) .unwrap(); - alice_runtime.block_on(async { delay_for(Duration::from_secs(60)).await }); + alice_runtime.block_on(async { sleep(Duration::from_secs(60)).await }); let carol_wallet = carol_runtime .block_on(create_wallet( @@ -610,7 +610,7 @@ fn test_store_and_forward_send_tx() { )) .unwrap(); - let mut carol_event_stream = carol_wallet.transaction_service.get_event_stream_fused(); + let mut carol_event_stream = carol_wallet.transaction_service.get_event_stream(); carol_runtime .block_on(carol_wallet.comms.peer_manager().add_peer(create_peer( @@ -623,12 +623,12 @@ fn test_store_and_forward_send_tx() { .unwrap(); carol_runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut delay = sleep(Duration::from_secs(60)).fuse(); let mut tx_recv = false; let mut tx_cancelled = false; loop { - futures::select! { + tokio::select! { event = carol_event_stream.select_next_some() => { match &*event.unwrap() { TransactionEvent::ReceivedTransaction(_) => tx_recv = true, @@ -655,7 +655,7 @@ fn test_store_and_forward_send_tx() { carol_runtime.block_on(carol_wallet.wait_until_shutdown()); } -#[tokio_macros::test] +#[tokio::test] async fn test_import_utxo() { let shutdown = Shutdown::new(); let factories = CryptoFactories::default(); diff --git a/base_layer/wallet_ffi/Cargo.toml b/base_layer/wallet_ffi/Cargo.toml index d7381c0d8e..868d301e97 100644 --- a/base_layer/wallet_ffi/Cargo.toml +++ b/base_layer/wallet_ffi/Cargo.toml @@ -17,11 +17,11 @@ tari_shutdown = { version = "^0.9", path = "../../infrastructure/shutdown" } tari_utilities = "^0.3" futures = { version = "^0.3.1", features =["compat", "std"]} -tokio = "0.2.10" +tokio = "1.10.1" libc = "0.2.65" rand = "0.8" chrono = { version = "0.4.6", features = ["serde"]} -thiserror = "1.0.20" +thiserror = "1.0.26" log = "0.4.6" log4rs = {version = "1.0.0", features = ["console_appender", "file_appender", "yaml_format"]} @@ -41,4 +41,3 @@ env_logger = "0.7.1" tari_key_manager = { version = "^0.9", path = "../key_manager" } tari_common_types = { version = "^0.9", path = "../../base_layer/common_types"} tari_test_utils = { version = "^0.9", path = "../../infrastructure/test_utils"} -tokio = { version="0.2.10" } diff --git a/base_layer/wallet_ffi/src/callback_handler.rs b/base_layer/wallet_ffi/src/callback_handler.rs index 6d00799f23..ea9856dd64 100644 --- a/base_layer/wallet_ffi/src/callback_handler.rs +++ b/base_layer/wallet_ffi/src/callback_handler.rs @@ -219,7 +219,7 @@ where TBackend: TransactionBackend + 'static info!(target: LOG_TARGET, "Transaction Service Callback Handler starting"); loop { - futures::select! { + tokio::select! { result = self.transaction_service_event_stream.select_next_some() => { match result { Ok(msg) => { diff --git a/base_layer/wallet_ffi/src/lib.rs b/base_layer/wallet_ffi/src/lib.rs index 57065f7f1f..8a644144f1 100644 --- a/base_layer/wallet_ffi/src/lib.rs +++ b/base_layer/wallet_ffi/src/lib.rs @@ -2943,8 +2943,8 @@ pub unsafe extern "C" fn wallet_create( // Start Callback Handler let callback_handler = CallbackHandler::new( TransactionDatabase::new(transaction_backend), - w.transaction_service.get_event_stream_fused(), - w.output_manager_service.get_event_stream_fused(), + w.transaction_service.get_event_stream(), + w.output_manager_service.get_event_stream(), w.dht_service.subscribe_dht_events().fuse(), w.comms.shutdown_signal(), w.comms.node_identity().public_key().clone(), diff --git a/common/Cargo.toml b/common/Cargo.toml index 87d98a6163..ac7b1aeff6 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -22,14 +22,14 @@ dirs-next = "1.0.2" get_if_addrs = "0.5.3" log = "0.4.8" log4rs = { version = "1.0.0", default_features= false, features = ["config_parsing", "threshold_filter"]} -multiaddr={package="parity-multiaddr", version = "0.11.0"} +multiaddr={version = "0.13.0"} sha2 = "0.9.5" path-clean = "0.1.0" tari_storage = { version = "^0.9", path = "../infrastructure/storage"} anyhow = { version = "1.0", optional = true } git2 = { version = "0.8", optional = true } -prost-build = { version = "0.6.1", optional = true } +prost-build = { version = "0.8.0", optional = true } toml = { version = "0.5", optional = true } [dev-dependencies] diff --git a/common/src/configuration/global.rs b/common/src/configuration/global.rs index 880e111cd1..941a7ba7fa 100644 --- a/common/src/configuration/global.rs +++ b/common/src/configuration/global.rs @@ -71,7 +71,6 @@ pub struct GlobalConfig { pub pruning_horizon: u64, pub pruned_mode_cleanup_interval: u64, pub core_threads: Option<usize>, - pub max_threads: Option<usize>, pub base_node_identity_file: PathBuf, pub public_address: Multiaddr, pub grpc_enabled: bool, @@ -270,10 +269,6 @@ fn convert_node_config( let core_threads = optional(cfg.get_int(&key).map(|n| n as usize)).map_err(|e| ConfigurationError::new(&key, &e.to_string()))?; - let key = config_string("base_node", &net_str, "max_threads"); - let max_threads = - optional(cfg.get_int(&key).map(|n| n as usize)).map_err(|e| ConfigurationError::new(&key, &e.to_string()))?; - // Max RandomX VMs let key = config_string("base_node", &net_str, "max_randomx_vms"); let max_randomx_vms = optional(cfg.get_int(&key).map(|n| n as usize)) @@ -712,7 +707,6 @@ fn convert_node_config( pruning_horizon, pruned_mode_cleanup_interval, core_threads, - max_threads, base_node_identity_file, public_address, grpc_enabled, diff --git a/common/src/dns/tests.rs b/common/src/dns/tests.rs index 955f22cf97..b7dc087517 100644 --- a/common/src/dns/tests.rs +++ b/common/src/dns/tests.rs @@ -48,7 +48,7 @@ use trust_dns_client::rr::{rdata, RData, Record, RecordType}; // Ignore as this test requires network IO #[ignore] -#[tokio_macros::test] +#[tokio::test] async fn it_returns_an_empty_vec_if_all_seeds_are_invalid() { let mut resolver = PeerSeedResolver::connect("1.1.1.1:53".parse().unwrap()).await.unwrap(); let seeds = resolver.resolve("tari.com").await.unwrap(); @@ -64,7 +64,7 @@ fn create_txt_record(contents: Vec<String>) -> Record { } #[allow(clippy::vec_init_then_push)] -#[tokio_macros::test] +#[tokio::test] async fn it_returns_peer_seeds() { let mut records = Vec::new(); // Multiple addresses(works) diff --git a/comms/Cargo.toml b/comms/Cargo.toml index 4ae5e3ae2e..4e94c7e008 100644 --- a/comms/Cargo.toml +++ b/comms/Cargo.toml @@ -12,51 +12,52 @@ edition = "2018" [dependencies] tari_crypto = "0.11.1" tari_storage = { version = "^0.9", path = "../infrastructure/storage" } -tari_shutdown = { version="^0.9", path = "../infrastructure/shutdown" } +tari_shutdown = { version = "^0.9", path = "../infrastructure/shutdown" } +anyhow = "1.0.32" async-trait = "0.1.36" bitflags = "1.0.4" blake2 = "0.9.0" -bytes = { version = "0.5.x", features=["serde"] } +bytes = { version = "1", features = ["serde"] } chrono = { version = "0.4.6", features = ["serde"] } cidr = "0.1.0" clear_on_drop = "=0.2.4" data-encoding = "2.2.0" digest = "0.9.0" -futures = { version = "^0.3", features = ["async-await"]} +futures = { version = "^0.3", features = ["async-await"] } lazy_static = "1.3.0" lmdb-zero = "0.4.4" log = { version = "0.4.0", features = ["std"] } -multiaddr = {version = "=0.11.0", package = "parity-multiaddr"} -nom = {version = "5.1.0", features=["std"], default-features=false} +multiaddr = { version = "0.13.0" } +nom = { version = "5.1.0", features = ["std"], default-features = false } openssl = { version = "0.10", features = ["vendored"] } -pin-project = "0.4.17" -prost = "=0.6.1" +pin-project = "1.0.8" +prost = "=0.8.0" rand = "0.8" serde = "1.0.119" serde_derive = "1.0.119" -snow = {version="=0.8.0", features=["default-resolver"]} -thiserror = "1.0.20" -tokio = {version="~0.2.19", features=["blocking", "time", "tcp", "dns", "sync", "stream", "signal"]} -tokio-util = {version="0.2.0", features=["codec"]} -tower= "0.3.1" +snow = { version = "=0.8.0", features = ["default-resolver"] } +thiserror = "1.0.26" +tokio = { version = "1.10", features = ["rt-multi-thread", "time", "sync", "signal", "net", "macros", "io-util"] } +tokio-stream = { version = "0.1.7", features = ["sync"] } +tokio-util = { version = "0.6.7", features = ["codec", "compat"] } +tower = "0.3.1" yamux = "=0.9.0" # RPC dependencies -tower-make = {version="0.3.0", optional=true} -anyhow = "1.0.32" +tower-make = { version = "0.3.0", optional = true } [dev-dependencies] -tari_test_utils = {version="^0.9", path="../infrastructure/test_utils"} -tari_comms_rpc_macros = {version="*", path="./rpc_macros"} +tari_test_utils = { version = "^0.9", path = "../infrastructure/test_utils" } +tari_comms_rpc_macros = { version = "*", path = "./rpc_macros" } env_logger = "0.7.0" serde_json = "1.0.39" -tokio-macros = "0.2.3" +#tokio = {version="1.8", features=["macros"]} tempfile = "3.1.0" [build-dependencies] -tari_common = { version = "^0.9", path="../common", features = ["build"]} +tari_common = { version = "^0.9", path = "../common", features = ["build"] } [features] avx2 = ["tari_crypto/avx2"] diff --git a/comms/dht/Cargo.toml b/comms/dht/Cargo.toml index c75e423543..d9aa0480b7 100644 --- a/comms/dht/Cargo.toml +++ b/comms/dht/Cargo.toml @@ -10,58 +10,57 @@ license = "BSD-3-Clause" edition = "2018" [dependencies] -tari_comms = { version = "^0.9", path = "../", features = ["rpc"]} -tari_comms_rpc_macros = { version = "^0.9", path = "../rpc_macros"} +tari_comms = { version = "^0.9", path = "../", features = ["rpc"] } +tari_comms_rpc_macros = { version = "^0.9", path = "../rpc_macros" } tari_crypto = "0.11.1" -tari_utilities = { version = "^0.3" } -tari_shutdown = { version = "^0.9", path = "../../infrastructure/shutdown"} -tari_storage = { version = "^0.9", path = "../../infrastructure/storage"} +tari_utilities = { version = "^0.3" } +tari_shutdown = { version = "^0.9", path = "../../infrastructure/shutdown" } +tari_storage = { version = "^0.9", path = "../../infrastructure/storage" } anyhow = "1.0.32" bitflags = "1.2.0" -bytes = "0.4.12" +bytes = "0.5" chacha20 = "0.7.1" chrono = "0.4.9" -diesel = {version="1.4.7", features = ["sqlite", "serde_json", "chrono", "numeric"]} +diesel = { version = "1.4.7", features = ["sqlite", "serde_json", "chrono", "numeric"] } diesel_migrations = "1.4.0" -libsqlite3-sys = { version = ">=0.8.0, <0.13.0", features = ["bundled"], optional=true } +libsqlite3-sys = { version = ">=0.8.0, <0.13.0", features = ["bundled"], optional = true } digest = "0.9.0" -futures= {version= "^0.3.1"} +futures = { version = "^0.3.1" } log = "0.4.8" -prost = "=0.6.1" -prost-types = "=0.6.1" +prost = "=0.8.0" +prost-types = "=0.8.0" rand = "0.8" serde = "1.0.90" serde_derive = "1.0.90" serde_repr = "0.1.5" -thiserror = "1.0.20" -tokio = {version="0.2.10", features=["rt-threaded", "blocking"]} -tower= "0.3.1" +thiserror = "1.0.26" +tokio = { version = "1.10", features = ["rt", "macros"] } +tower = "0.3.1" ttl_cache = "0.5.1" # tower-filter dependencies pin-project = "0.4" [dev-dependencies] -tari_test_utils = { version = "^0.9", path = "../../infrastructure/test_utils"} +tari_test_utils = { version = "^0.9", path = "../../infrastructure/test_utils" } env_logger = "0.7.0" futures-test = { version = "0.3.0-alpha.19", package = "futures-test-preview" } lmdb-zero = "0.4.4" tempfile = "3.1.0" -tokio-macros = "0.2.3" +tokio-stream = { version = "0.1.7", features = ["sync"] } petgraph = "0.5.1" clap = "2.33.0" # tower-filter dependencies tower-test = { version = "^0.3" } -tokio-test = "^0.2" -tokio = "^0.2" +tokio-test = "^0.4.2" futures-util = "^0.3.1" lazy_static = "1.4.0" [build-dependencies] -tari_common = { version = "^0.9", path="../../common"} +tari_common = { version = "^0.9", path = "../../common" } [features] test-mocks = [] diff --git a/comms/dht/examples/graphing_utilities/utilities.rs b/comms/dht/examples/graphing_utilities/utilities.rs index bfb081eb21..3a4fea7ae9 100644 --- a/comms/dht/examples/graphing_utilities/utilities.rs +++ b/comms/dht/examples/graphing_utilities/utilities.rs @@ -32,6 +32,7 @@ use petgraph::{ }; use std::{collections::HashMap, convert::TryFrom, fs, fs::File, io::Write, path::Path, process::Command, sync::Mutex}; use tari_comms::{connectivity::ConnectivitySelection, peer_manager::NodeId}; +use tari_test_utils::streams::convert_unbounded_mpsc_to_stream; const TEMP_GRAPH_OUTPUT_DIR: &str = "/tmp/memorynet_temp"; @@ -277,7 +278,9 @@ pub enum PythonRenderType { /// This function will drain the message event queue and then build a message propagation tree assuming the first sender /// is the starting node pub async fn track_join_message_drain_messaging_events(messaging_rx: &mut NodeEventRx) -> StableGraph<NodeId, String> { - let drain_fut = DrainBurst::new(messaging_rx); + let stream = convert_unbounded_mpsc_to_stream(messaging_rx); + tokio::pin!(stream); + let drain_fut = DrainBurst::new(&mut stream); let messages = drain_fut.await; let num_messages = messages.len(); diff --git a/comms/dht/examples/memory_net/drain_burst.rs b/comms/dht/examples/memory_net/drain_burst.rs index d2f5bce2be..f4374328d5 100644 --- a/comms/dht/examples/memory_net/drain_burst.rs +++ b/comms/dht/examples/memory_net/drain_burst.rs @@ -42,7 +42,7 @@ where St: ?Sized + Stream + Unpin let (lower_bound, upper_bound) = stream.size_hint(); Self { inner: stream, - collection: Vec::with_capacity(upper_bound.or(Some(lower_bound)).unwrap()), + collection: Vec::with_capacity(upper_bound.unwrap_or(lower_bound)), } } } @@ -71,14 +71,14 @@ mod test { use super::*; use futures::stream; - #[tokio_macros::test_basic] + #[runtime::test] async fn drain_terminating_stream() { let mut stream = stream::iter(1..10u8); let burst = DrainBurst::new(&mut stream).await; assert_eq!(burst, (1..10u8).into_iter().collect::<Vec<_>>()); } - #[tokio_macros::test_basic] + #[runtime::test] async fn drain_stream_with_pending() { let mut stream = stream::iter(1..10u8); let burst = DrainBurst::new(&mut stream).await; diff --git a/comms/dht/examples/memory_net/utilities.rs b/comms/dht/examples/memory_net/utilities.rs index bf1bf03ae4..88022977c7 100644 --- a/comms/dht/examples/memory_net/utilities.rs +++ b/comms/dht/examples/memory_net/utilities.rs @@ -22,7 +22,7 @@ #![allow(clippy::mutex_atomic)] use crate::memory_net::DrainBurst; -use futures::{channel::mpsc, future, StreamExt}; +use futures::future; use lazy_static::lazy_static; use rand::{rngs::OsRng, Rng}; use std::{ @@ -62,8 +62,13 @@ use tari_storage::{ lmdb_store::{LMDBBuilder, LMDBConfig}, LMDBWrapper, }; -use tari_test_utils::{paths::create_temporary_data_path, random}; -use tokio::{runtime, sync::broadcast, task, time}; +use tari_test_utils::{paths::create_temporary_data_path, random, streams::convert_unbounded_mpsc_to_stream}; +use tokio::{ + runtime, + sync::{broadcast, mpsc}, + task, + time, +}; use tower::ServiceBuilder; pub type NodeEventRx = mpsc::UnboundedReceiver<(NodeId, NodeId)>; @@ -154,7 +159,7 @@ pub async fn discovery(wallets: &[TestNode], messaging_events_rx: &mut NodeEvent start.elapsed() ); - time::delay_for(Duration::from_secs(5)).await; + time::sleep(Duration::from_secs(5)).await; total_messages += drain_messaging_events(messaging_events_rx, false).await; }, Err(err) => { @@ -166,7 +171,7 @@ pub async fn discovery(wallets: &[TestNode], messaging_events_rx: &mut NodeEvent err ); - time::delay_for(Duration::from_secs(5)).await; + time::sleep(Duration::from_secs(5)).await; total_messages += drain_messaging_events(messaging_events_rx, false).await; }, } @@ -298,7 +303,7 @@ pub async fn do_network_wide_propagation(nodes: &mut [TestNode], origin_node_ind let node_name = node.name.clone(); task::spawn(async move { - let result = time::timeout(Duration::from_secs(30), ims_rx.next()).await; + let result = time::timeout(Duration::from_secs(30), ims_rx.recv()).await; let mut is_success = false; match result { Ok(Some(msg)) => { @@ -450,21 +455,23 @@ pub async fn do_store_and_forward_message_propagation( for (idx, mut s) in neighbour_subs.into_iter().enumerate() { let neighbour = neighbours[idx].name.clone(); task::spawn(async move { - let msg = time::timeout(Duration::from_secs(2), s.next()).await; + let msg = time::timeout(Duration::from_secs(2), s.recv()).await; match msg { - Ok(Some(Ok(evt))) => { + Ok(Ok(evt)) => { if let MessagingEvent::MessageReceived(_, tag) = &*evt { println!("{} received propagated SAF message ({})", neighbour, tag); } }, - Ok(_) => {}, + Ok(Err(err)) => { + println!("{}", err); + }, Err(_) => println!("{} did not receive the SAF message", neighbour), } }); } banner!("⏰ Waiting a few seconds for messages to propagate around the network..."); - time::delay_for(Duration::from_secs(5)).await; + time::sleep(Duration::from_secs(5)).await; let mut total_messages = drain_messaging_events(messaging_rx, false).await; @@ -515,7 +522,7 @@ pub async fn do_store_and_forward_message_propagation( let mut num_msgs = 0; let mut succeeded = 0; loop { - let result = time::timeout(Duration::from_secs(10), wallet.ims_rx.as_mut().unwrap().next()).await; + let result = time::timeout(Duration::from_secs(10), wallet.ims_rx.as_mut().unwrap().recv()).await; num_msgs += 1; match result { Ok(msg) => { @@ -554,7 +561,9 @@ pub async fn do_store_and_forward_message_propagation( } pub async fn drain_messaging_events(messaging_rx: &mut NodeEventRx, show_logs: bool) -> usize { - let drain_fut = DrainBurst::new(messaging_rx); + let stream = convert_unbounded_mpsc_to_stream(messaging_rx); + tokio::pin!(stream); + let drain_fut = DrainBurst::new(&mut stream); if show_logs { let messages = drain_fut.await; let num_messages = messages.len(); @@ -694,42 +703,43 @@ impl TestNode { fn spawn_event_monitor( comms: &CommsNode, - messaging_events: MessagingEventReceiver, + mut messaging_events: MessagingEventReceiver, events_tx: mpsc::Sender<Arc<ConnectionManagerEvent>>, messaging_events_tx: NodeEventTx, quiet_mode: bool, ) { - let conn_man_event_sub = comms.subscribe_connection_manager_events(); + let mut conn_man_event_sub = comms.subscribe_connection_manager_events(); let executor = runtime::Handle::current(); - executor.spawn( - conn_man_event_sub - .filter(|r| future::ready(r.is_ok())) - .map(Result::unwrap) - .map(connection_manager_logger( - comms.node_identity().node_id().clone(), - quiet_mode, - )) - .map(Ok) - .forward(events_tx), - ); - let node_id = comms.node_identity().node_id().clone(); + executor.spawn(async move { + let mut logger = connection_manager_logger(node_id, quiet_mode); + loop { + match conn_man_event_sub.recv().await { + Ok(event) => { + events_tx.send(logger(event)).await.unwrap(); + }, + Err(broadcast::error::RecvError::Closed) => break, + Err(err) => log::error!("{}", err), + } + } + }); - executor.spawn( - messaging_events - .filter(|r| future::ready(r.is_ok())) - .map(Result::unwrap) - .filter_map(move |event| { - use MessagingEvent::*; - future::ready(match &*event { - MessageReceived(peer_node_id, _) => Some((Clone::clone(&*peer_node_id), node_id.clone())), - _ => None, - }) - }) - .map(Ok) - .forward(messaging_events_tx), - ); + let node_id = comms.node_identity().node_id().clone(); + executor.spawn(async move { + loop { + let event = messaging_events.recv().await; + use MessagingEvent::*; + match event.as_deref() { + Ok(MessageReceived(peer_node_id, _)) => { + messaging_events_tx + .send((Clone::clone(&*peer_node_id), node_id.clone())) + .unwrap(); + }, + _ => {}, + } + } + }); } #[inline] @@ -749,7 +759,7 @@ impl TestNode { } use ConnectionManagerEvent::*; loop { - let event = time::timeout(Duration::from_secs(30), self.conn_man_events_rx.next()) + let event = time::timeout(Duration::from_secs(30), self.conn_man_events_rx.recv()) .await .ok()??; @@ -763,7 +773,7 @@ impl TestNode { } pub async fn shutdown(mut self) { - self.shutdown.trigger().unwrap(); + self.shutdown.trigger(); self.comms.wait_until_shutdown().await; } } @@ -946,5 +956,5 @@ async fn setup_comms_dht( pub async fn take_a_break(num_nodes: usize) { banner!("Taking a break for a few seconds to let things settle..."); - time::delay_for(Duration::from_millis(num_nodes as u64 * 100)).await; + time::sleep(Duration::from_millis(num_nodes as u64 * 100)).await; } diff --git a/comms/dht/examples/memorynet.rs b/comms/dht/examples/memorynet.rs index 9cc28551bc..24bce254c4 100644 --- a/comms/dht/examples/memorynet.rs +++ b/comms/dht/examples/memorynet.rs @@ -49,15 +49,16 @@ use crate::memory_net::utilities::{ shutdown_all, take_a_break, }; -use futures::{channel::mpsc, future}; +use futures::future; use rand::{rngs::OsRng, Rng}; use std::{iter::repeat_with, time::Duration}; use tari_comms::peer_manager::PeerFeatures; +use tokio::sync::mpsc; // Size of network -const NUM_NODES: usize = 6; +const NUM_NODES: usize = 40; // Must be at least 2 -const NUM_WALLETS: usize = 50; +const NUM_WALLETS: usize = 5; const QUIET_MODE: bool = true; /// Number of neighbouring nodes each node should include in the connection pool const NUM_NEIGHBOURING_NODES: usize = 8; @@ -77,7 +78,7 @@ async fn main() { NUM_WALLETS ); - let (node_message_tx, mut messaging_events_rx) = mpsc::unbounded(); + let (node_message_tx, mut messaging_events_rx) = mpsc::unbounded_channel(); let seed_node = vec![ make_node( diff --git a/comms/dht/examples/memorynet_graph_network_join_multiple_seeds.rs b/comms/dht/examples/memorynet_graph_network_join_multiple_seeds.rs index a0a1bddc1e..c4401b6e22 100644 --- a/comms/dht/examples/memorynet_graph_network_join_multiple_seeds.rs +++ b/comms/dht/examples/memorynet_graph_network_join_multiple_seeds.rs @@ -65,9 +65,9 @@ use crate::{ }, }; use clap::{App, Arg}; -use futures::channel::mpsc; use std::{path::Path, time::Duration}; use tari_comms::peer_manager::PeerFeatures; +use tokio::sync::mpsc; #[tokio_macros::main] #[allow(clippy::same_item_push)] @@ -96,7 +96,7 @@ async fn main() { NUM_WALLETS ); - let (messaging_events_tx, _messaging_events_rx) = mpsc::unbounded(); + let (messaging_events_tx, _messaging_events_rx) = mpsc::unbounded_channel(); let mut seed_identities = Vec::new(); for _ in 0..NUM_SEED_NODES { diff --git a/comms/dht/examples/memorynet_graph_network_track_join.rs b/comms/dht/examples/memorynet_graph_network_track_join.rs index 259358e1cb..5326b0263e 100644 --- a/comms/dht/examples/memorynet_graph_network_track_join.rs +++ b/comms/dht/examples/memorynet_graph_network_track_join.rs @@ -73,9 +73,9 @@ use crate::{ }; use clap::{App, Arg}; use env_logger::Env; -use futures::channel::mpsc; use std::{path::Path, time::Duration}; use tari_comms::peer_manager::PeerFeatures; +use tokio::sync::mpsc; #[tokio_macros::main] #[allow(clippy::same_item_push)] @@ -106,7 +106,7 @@ async fn main() { NUM_WALLETS ); - let (messaging_events_tx, mut messaging_events_rx) = mpsc::unbounded(); + let (messaging_events_tx, mut messaging_events_rx) = mpsc::unbounded_channel(); let mut seed_identities = Vec::new(); for _ in 0..NUM_SEED_NODES { diff --git a/comms/dht/examples/memorynet_graph_network_track_propagation.rs b/comms/dht/examples/memorynet_graph_network_track_propagation.rs index fcc5debff3..a68220ddf7 100644 --- a/comms/dht/examples/memorynet_graph_network_track_propagation.rs +++ b/comms/dht/examples/memorynet_graph_network_track_propagation.rs @@ -73,8 +73,8 @@ use crate::{ }, }; use env_logger::Env; -use futures::channel::mpsc; use tari_comms::peer_manager::PeerFeatures; +use tokio::sync::mpsc; #[tokio_macros::main] #[allow(clippy::same_item_push)] @@ -105,7 +105,7 @@ async fn main() { NUM_WALLETS ); - let (messaging_events_tx, mut messaging_events_rx) = mpsc::unbounded(); + let (messaging_events_tx, mut messaging_events_rx) = mpsc::unbounded_channel(); let mut seed_identities = Vec::new(); for _ in 0..NUM_SEED_NODES { diff --git a/comms/dht/src/actor.rs b/comms/dht/src/actor.rs index c2c2d4e52a..97c56f2134 100644 --- a/comms/dht/src/actor.rs +++ b/comms/dht/src/actor.rs @@ -37,13 +37,7 @@ use crate::{ DhtConfig, }; use chrono::{DateTime, Utc}; -use futures::{ - channel::{mpsc, mpsc::SendError, oneshot}, - future::BoxFuture, - stream::{Fuse, FuturesUnordered}, - SinkExt, - StreamExt, -}; +use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; use log::*; use std::{cmp, fmt, fmt::Display, sync::Arc}; use tari_comms::{ @@ -54,7 +48,11 @@ use tari_comms::{ use tari_shutdown::ShutdownSignal; use tari_utilities::message_format::{MessageFormat, MessageFormatError}; use thiserror::Error; -use tokio::{task, time}; +use tokio::{ + sync::{mpsc, oneshot}, + task, + time, +}; const LOG_TARGET: &str = "comms::dht::actor"; @@ -62,8 +60,6 @@ const LOG_TARGET: &str = "comms::dht::actor"; pub enum DhtActorError { #[error("MPSC channel is disconnected")] ChannelDisconnected, - #[error("MPSC sender was unable to send because the channel buffer is full")] - SendBufferFull, #[error("Reply sender canceled the request")] ReplyCanceled, #[error("PeerManagerError: {0}")] @@ -84,15 +80,9 @@ pub enum DhtActorError { ConnectivityEventStreamClosed, } -impl From<SendError> for DhtActorError { - fn from(err: SendError) -> Self { - if err.is_disconnected() { - DhtActorError::ChannelDisconnected - } else if err.is_full() { - DhtActorError::SendBufferFull - } else { - unreachable!(); - } +impl<T> From<mpsc::error::SendError<T>> for DhtActorError { + fn from(_: mpsc::error::SendError<T>) -> Self { + DhtActorError::ChannelDisconnected } } @@ -186,8 +176,8 @@ pub struct DhtActor { outbound_requester: OutboundMessageRequester, connectivity: ConnectivityRequester, config: DhtConfig, - shutdown_signal: Option<ShutdownSignal>, - request_rx: Fuse<mpsc::Receiver<DhtRequest>>, + shutdown_signal: ShutdownSignal, + request_rx: mpsc::Receiver<DhtRequest>, msg_hash_dedup_cache: DedupCacheDatabase, } @@ -217,8 +207,8 @@ impl DhtActor { peer_manager, connectivity, node_identity, - shutdown_signal: Some(shutdown_signal), - request_rx: request_rx.fuse(), + shutdown_signal, + request_rx, } } @@ -247,33 +237,28 @@ impl DhtActor { let mut pending_jobs = FuturesUnordered::new(); - let mut dedup_cache_trim_ticker = time::interval(self.config.dedup_cache_trim_interval).fuse(); - - let mut shutdown_signal = self - .shutdown_signal - .take() - .expect("DhtActor initialized without shutdown_signal"); + let mut dedup_cache_trim_ticker = time::interval(self.config.dedup_cache_trim_interval); loop { - futures::select! { - request = self.request_rx.select_next_some() => { + tokio::select! { + Some(request) = self.request_rx.recv() => { trace!(target: LOG_TARGET, "DhtActor received request: {}", request); pending_jobs.push(self.request_handler(request)); }, - result = pending_jobs.select_next_some() => { + Some(result) = pending_jobs.next() => { if let Err(err) = result { debug!(target: LOG_TARGET, "Error when handling DHT request message. {}", err); } }, - _ = dedup_cache_trim_ticker.select_next_some() => { + _ = dedup_cache_trim_ticker.tick() => { if let Err(err) = self.msg_hash_dedup_cache.truncate().await { error!(target: LOG_TARGET, "Error when trimming message dedup cache: {:?}", err); } }, - _ = shutdown_signal => { + _ = self.shutdown_signal.wait() => { info!(target: LOG_TARGET, "DhtActor is shutting down because it received a shutdown signal."); self.mark_shutdown_time().await; break Ok(()); @@ -691,10 +676,13 @@ mod test { }; use chrono::{DateTime, Utc}; use std::time::Duration; - use tari_comms::test_utils::mocks::{create_connectivity_mock, create_peer_connection_mock_pair}; + use tari_comms::{ + runtime, + test_utils::mocks::{create_connectivity_mock, create_peer_connection_mock_pair}, + }; use tari_shutdown::Shutdown; use tari_test_utils::random; - use tokio::time::delay_for; + use tokio::time::sleep; async fn db_connection() -> DbConnection { let conn = DbConnection::connect_memory(random::string(8)).await.unwrap(); @@ -702,7 +690,7 @@ mod test { conn } - #[tokio_macros::test_basic] + #[runtime::test] async fn send_join_request() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -727,11 +715,11 @@ mod test { actor.spawn(); requester.send_join().await.unwrap(); - let (params, _) = unwrap_oms_send_msg!(out_rx.next().await.unwrap()); + let (params, _) = unwrap_oms_send_msg!(out_rx.recv().await.unwrap()); assert_eq!(params.dht_message_type, DhtMessageType::Join); } - #[tokio_macros::test_basic] + #[runtime::test] async fn insert_message_signature() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -773,7 +761,7 @@ mod test { assert!(!is_dup); } - #[tokio_macros::test_basic] + #[runtime::test] async fn dedup_cache_cleanup() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -848,7 +836,7 @@ mod test { } // Let the trim period expire; this will trim the dedup cache to capacity - delay_for(Duration::from_millis(trim_interval_ms * 2)).await; + sleep(Duration::from_millis(trim_interval_ms * 2)).await; // Verify that the last half of the signatures have been removed and can be re-inserted into cache for key in signatures.iter().take(capacity * 2).skip(capacity) { @@ -859,10 +847,10 @@ mod test { assert!(!is_dup); } - shutdown.trigger().unwrap(); + shutdown.trigger(); } - #[tokio_macros::test_basic] + #[runtime::test] async fn select_peers() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -973,7 +961,7 @@ mod test { assert_eq!(peers.len(), 1); } - #[tokio_macros::test_basic] + #[runtime::test] async fn get_and_set_metadata() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -1029,6 +1017,6 @@ mod test { .unwrap(); assert_eq!(got_ts, ts); - shutdown.trigger().unwrap(); + shutdown.trigger(); } } diff --git a/comms/dht/src/builder.rs b/comms/dht/src/builder.rs index 249ed3d369..a88fff0a3c 100644 --- a/comms/dht/src/builder.rs +++ b/comms/dht/src/builder.rs @@ -21,13 +21,13 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{dht::DhtInitializationError, outbound::DhtOutboundRequest, DbConnectionUrl, Dht, DhtConfig}; -use futures::channel::mpsc; use std::{sync::Arc, time::Duration}; use tari_comms::{ connectivity::ConnectivityRequester, peer_manager::{NodeIdentity, PeerManager}, }; use tari_shutdown::ShutdownSignal; +use tokio::sync::mpsc; pub struct DhtBuilder { node_identity: Arc<NodeIdentity>, diff --git a/comms/dht/src/connectivity/metrics.rs b/comms/dht/src/connectivity/metrics.rs index b7ee546d4e..ba456d9aa9 100644 --- a/comms/dht/src/connectivity/metrics.rs +++ b/comms/dht/src/connectivity/metrics.rs @@ -20,20 +20,16 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::{ - channel::{mpsc, mpsc::SendError, oneshot, oneshot::Canceled}, - future, - SinkExt, - StreamExt, -}; use log::*; use std::{ collections::{hash_map::Entry, HashMap, VecDeque}, - future::Future, time::{Duration, Instant}, }; use tari_comms::peer_manager::NodeId; -use tokio::task; +use tokio::{ + sync::{mpsc, oneshot}, + task, +}; const LOG_TARGET: &str = "comms::dht::metrics"; @@ -124,7 +120,7 @@ impl MetricsState { } pub struct MetricsCollector { - stream: Option<mpsc::Receiver<MetricOp>>, + stream: mpsc::Receiver<MetricOp>, state: MetricsState, } @@ -133,18 +129,17 @@ impl MetricsCollector { let (metrics_tx, metrics_rx) = mpsc::channel(500); let metrics_collector = MetricsCollectorHandle::new(metrics_tx); let collector = Self { - stream: Some(metrics_rx), + stream: metrics_rx, state: Default::default(), }; task::spawn(collector.run()); metrics_collector } - fn run(mut self) -> impl Future<Output = ()> { - self.stream.take().unwrap().for_each(move |op| { + async fn run(mut self) { + while let Some(op) = self.stream.recv().await { self.handle(op); - future::ready(()) - }) + } } fn handle(&mut self, op: MetricOp) { @@ -286,7 +281,7 @@ impl MetricsCollectorHandle { match self.inner.try_send(MetricOp::Write(write)) { Ok(_) => true, Err(err) => { - warn!(target: LOG_TARGET, "Failed to write metric: {}", err.into_send_error()); + warn!(target: LOG_TARGET, "Failed to write metric: {:?}", err); false }, } @@ -338,14 +333,14 @@ pub enum MetricsError { ReplyCancelled, } -impl From<mpsc::SendError> for MetricsError { - fn from(_: SendError) -> Self { +impl<T> From<mpsc::error::SendError<T>> for MetricsError { + fn from(_: mpsc::error::SendError<T>) -> Self { MetricsError::ChannelClosedUnexpectedly } } -impl From<oneshot::Canceled> for MetricsError { - fn from(_: Canceled) -> Self { +impl From<oneshot::error::RecvError> for MetricsError { + fn from(_: oneshot::error::RecvError) -> Self { MetricsError::ReplyCancelled } } diff --git a/comms/dht/src/connectivity/mod.rs b/comms/dht/src/connectivity/mod.rs index c9c855b6c1..a3c4471451 100644 --- a/comms/dht/src/connectivity/mod.rs +++ b/comms/dht/src/connectivity/mod.rs @@ -27,7 +27,6 @@ mod metrics; pub use metrics::{MetricsCollector, MetricsCollectorHandle}; use crate::{connectivity::metrics::MetricsError, event::DhtEvent, DhtActorError, DhtConfig, DhtRequester}; -use futures::{stream::Fuse, StreamExt}; use log::*; use std::{sync::Arc, time::Instant}; use tari_comms::{ @@ -78,11 +77,11 @@ pub struct DhtConnectivity { /// Used to track when the random peer pool was last refreshed random_pool_last_refresh: Option<Instant>, stats: Stats, - dht_events: Fuse<broadcast::Receiver<Arc<DhtEvent>>>, + dht_events: broadcast::Receiver<Arc<DhtEvent>>, metrics_collector: MetricsCollectorHandle, - shutdown_signal: Option<ShutdownSignal>, + shutdown_signal: ShutdownSignal, } impl DhtConnectivity { @@ -108,8 +107,8 @@ impl DhtConnectivity { metrics_collector, random_pool_last_refresh: None, stats: Stats::new(), - dht_events: dht_events.fuse(), - shutdown_signal: Some(shutdown_signal), + dht_events, + shutdown_signal, } } @@ -131,21 +130,15 @@ impl DhtConnectivity { }) } - pub async fn run(mut self, connectivity_events: ConnectivityEventRx) -> Result<(), DhtConnectivityError> { - let mut connectivity_events = connectivity_events.fuse(); - let mut shutdown_signal = self - .shutdown_signal - .take() - .expect("DhtConnectivity initialized without a shutdown_signal"); - + pub async fn run(mut self, mut connectivity_events: ConnectivityEventRx) -> Result<(), DhtConnectivityError> { debug!(target: LOG_TARGET, "DHT connectivity starting"); self.refresh_neighbour_pool().await?; - let mut ticker = time::interval(self.config.connectivity_update_interval).fuse(); + let mut ticker = time::interval(self.config.connectivity_update_interval); loop { - futures::select! { - event = connectivity_events.select_next_some() => { + tokio::select! { + event = connectivity_events.recv() => { if let Ok(event) = event { if let Err(err) = self.handle_connectivity_event(&event).await { debug!(target: LOG_TARGET, "Error handling connectivity event: {:?}", err); @@ -153,15 +146,13 @@ impl DhtConnectivity { } }, - event = self.dht_events.select_next_some() => { - if let Ok(event) = event { - if let Err(err) = self.handle_dht_event(&event).await { - debug!(target: LOG_TARGET, "Error handling DHT event: {:?}", err); - } - } + Ok(event) = self.dht_events.recv() => { + if let Err(err) = self.handle_dht_event(&event).await { + debug!(target: LOG_TARGET, "Error handling DHT event: {:?}", err); + } }, - _ = ticker.next() => { + _ = ticker.tick() => { if let Err(err) = self.refresh_random_pool_if_required().await { debug!(target: LOG_TARGET, "Error refreshing random peer pool: {:?}", err); } @@ -170,7 +161,7 @@ impl DhtConnectivity { } }, - _ = shutdown_signal => { + _ = self.shutdown_signal.wait() => { info!(target: LOG_TARGET, "DhtConnectivity shutting down because the shutdown signal was received"); break; } diff --git a/comms/dht/src/connectivity/test.rs b/comms/dht/src/connectivity/test.rs index 65dba8c235..d0e83c0aa5 100644 --- a/comms/dht/src/connectivity/test.rs +++ b/comms/dht/src/connectivity/test.rs @@ -30,6 +30,7 @@ use std::{iter::repeat_with, sync::Arc, time::Duration}; use tari_comms::{ connectivity::ConnectivityEvent, peer_manager::{Peer, PeerFeatures}, + runtime, test_utils::{ count_string_occurrences, mocks::{create_connectivity_mock, create_dummy_peer_connection, ConnectivityManagerMockState}, @@ -89,7 +90,7 @@ async fn setup( ) } -#[tokio_macros::test_basic] +#[runtime::test] async fn initialize() { let config = DhtConfig { num_neighbouring_nodes: 4, @@ -127,7 +128,7 @@ async fn initialize() { assert!(managed.iter().all(|n| !neighbours.contains(n))); } -#[tokio_macros::test_basic] +#[runtime::test] async fn added_neighbours() { let node_identity = make_node_identity(); let mut node_identities = @@ -173,7 +174,7 @@ async fn added_neighbours() { assert!(managed.contains(closer_peer.node_id())); } -#[tokio_macros::test_basic] +#[runtime::test] #[allow(clippy::redundant_closure)] async fn reinitialize_pools_when_offline() { let node_identity = make_node_identity(); @@ -215,7 +216,7 @@ async fn reinitialize_pools_when_offline() { assert_eq!(managed.len(), 5); } -#[tokio_macros::test_basic] +#[runtime::test] async fn insert_neighbour() { let node_identity = make_node_identity(); let node_identities = @@ -254,11 +255,13 @@ async fn insert_neighbour() { } mod metrics { + use super::*; mod collector { + use super::*; use crate::connectivity::MetricsCollector; use tari_comms::peer_manager::NodeId; - #[tokio_macros::test_basic] + #[runtime::test] async fn it_adds_message_received() { let mut metric_collector = MetricsCollector::spawn(); let node_id = NodeId::default(); @@ -273,7 +276,7 @@ mod metrics { assert_eq!(ts.count(), 100); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_clears_the_metrics() { let mut metric_collector = MetricsCollector::spawn(); let node_id = NodeId::default(); diff --git a/comms/dht/src/dedup/mod.rs b/comms/dht/src/dedup/mod.rs index 5428277af0..2a8bcc8aaa 100644 --- a/comms/dht/src/dedup/mod.rs +++ b/comms/dht/src/dedup/mod.rs @@ -138,7 +138,7 @@ mod test { #[test] fn process_message() { - let mut rt = Runtime::new().unwrap(); + let rt = Runtime::new().unwrap(); let spy = service_spy(); let (dht_requester, mock) = create_dht_actor_mock(1); diff --git a/comms/dht/src/dht.rs b/comms/dht/src/dht.rs index dcdeea5730..165e3d2f66 100644 --- a/comms/dht/src/dht.rs +++ b/comms/dht/src/dht.rs @@ -42,7 +42,7 @@ use crate::{ DhtActorError, DhtConfig, }; -use futures::{channel::mpsc, future, Future}; +use futures::{future, Future}; use log::*; use std::sync::Arc; use tari_comms::{ @@ -53,7 +53,7 @@ use tari_comms::{ }; use tari_shutdown::ShutdownSignal; use thiserror::Error; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc}; use tower::{layer::Layer, Service, ServiceBuilder}; const LOG_TARGET: &str = "comms::dht"; @@ -404,22 +404,23 @@ mod test { make_comms_inbound_message, make_dht_envelope, make_node_identity, + service_spy, }, DhtBuilder, }; - use futures::{channel::mpsc, StreamExt}; use std::{sync::Arc, time::Duration}; use tari_comms::{ message::{MessageExt, MessageTag}, pipeline::SinkService, + runtime, test_utils::mocks::create_connectivity_mock, wrap_in_envelope_body, }; use tari_shutdown::Shutdown; - use tokio::{task, time}; + use tokio::{sync::mpsc, task, time}; use tower::{layer::Layer, Service}; - #[tokio_macros::test_basic] + #[runtime::test] async fn stack_unencrypted() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -459,7 +460,7 @@ mod test { let msg = { service.call(inbound_message).await.unwrap(); - let msg = time::timeout(Duration::from_secs(10), out_rx.next()) + let msg = time::timeout(Duration::from_secs(10), out_rx.recv()) .await .unwrap() .unwrap(); @@ -469,7 +470,7 @@ mod test { assert_eq!(msg, b"secret"); } - #[tokio_macros::test_basic] + #[runtime::test] async fn stack_encrypted() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -509,7 +510,7 @@ mod test { let msg = { service.call(inbound_message).await.unwrap(); - let msg = time::timeout(Duration::from_secs(10), out_rx.next()) + let msg = time::timeout(Duration::from_secs(10), out_rx.recv()) .await .unwrap() .unwrap(); @@ -519,7 +520,7 @@ mod test { assert_eq!(msg, b"secret"); } - #[tokio_macros::test_basic] + #[runtime::test] async fn stack_forward() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -528,7 +529,6 @@ mod test { peer_manager.add_peer(node_identity.to_peer()).await.unwrap(); let (connectivity, _) = create_connectivity_mock(); - let (next_service_tx, mut next_service_rx) = mpsc::channel(10); let (oms_requester, oms_mock) = create_outbound_service_mock(1); // Send all outbound requests to the mock @@ -545,7 +545,8 @@ mod test { let oms_mock_state = oms_mock.get_state(); task::spawn(oms_mock.run()); - let mut service = dht.inbound_middleware_layer().layer(SinkService::new(next_service_tx)); + let spy = service_spy(); + let mut service = dht.inbound_middleware_layer().layer(spy.to_service()); let msg = wrap_in_envelope_body!(b"unencrypteable".to_vec()); @@ -574,10 +575,10 @@ mod test { assert_eq!(params.dht_header.unwrap().origin_mac, origin_mac); // Check the next service was not called - assert!(next_service_rx.try_next().is_err()); + assert_eq!(spy.call_count(), 0); } - #[tokio_macros::test_basic] + #[runtime::test] async fn stack_filter_saf_message() { let node_identity = make_client_identity(); let peer_manager = build_peer_manager(); @@ -600,9 +601,8 @@ mod test { .await .unwrap(); - let (next_service_tx, mut next_service_rx) = mpsc::channel(10); - - let mut service = dht.inbound_middleware_layer().layer(SinkService::new(next_service_tx)); + let spy = service_spy(); + let mut service = dht.inbound_middleware_layer().layer(spy.to_service()); let msg = wrap_in_envelope_body!(b"secret".to_vec()); let mut dht_envelope = make_dht_envelope( @@ -619,10 +619,6 @@ mod test { let inbound_message = make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().into()); service.call(inbound_message).await.unwrap_err(); - // This seems like the best way to tell that an open channel is empty without the test blocking indefinitely - assert_eq!( - format!("{}", next_service_rx.try_next().unwrap_err()), - "receiver channel is empty" - ); + assert_eq!(spy.call_count(), 0); } } diff --git a/comms/dht/src/discovery/error.rs b/comms/dht/src/discovery/error.rs index cf98e42d9d..c2a77f0c9a 100644 --- a/comms/dht/src/discovery/error.rs +++ b/comms/dht/src/discovery/error.rs @@ -21,9 +21,9 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::outbound::{message::SendFailure, DhtOutboundError}; -use futures::channel::mpsc::SendError; use tari_comms::peer_manager::PeerManagerError; use thiserror::Error; +use tokio::sync::mpsc::error::SendError; #[derive(Debug, Error)] pub enum DhtDiscoveryError { @@ -37,8 +37,6 @@ pub enum DhtDiscoveryError { InvalidNodeId, #[error("MPSC channel is disconnected")] ChannelDisconnected, - #[error("MPSC sender was unable to send because the channel buffer is full")] - SendBufferFull, #[error("The discovery request timed out")] DiscoveryTimeout, #[error("Failed to send discovery message: {0}")] @@ -56,14 +54,8 @@ impl DhtDiscoveryError { } } -impl From<SendError> for DhtDiscoveryError { - fn from(err: SendError) -> Self { - if err.is_disconnected() { - DhtDiscoveryError::ChannelDisconnected - } else if err.is_full() { - DhtDiscoveryError::SendBufferFull - } else { - unreachable!(); - } +impl<T> From<SendError<T>> for DhtDiscoveryError { + fn from(_: SendError<T>) -> Self { + DhtDiscoveryError::ChannelDisconnected } } diff --git a/comms/dht/src/discovery/requester.rs b/comms/dht/src/discovery/requester.rs index a286bcc6a5..a7317f79a2 100644 --- a/comms/dht/src/discovery/requester.rs +++ b/comms/dht/src/discovery/requester.rs @@ -21,16 +21,15 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{discovery::DhtDiscoveryError, envelope::NodeDestination, proto::dht::DiscoveryResponseMessage}; -use futures::{ - channel::{mpsc, oneshot}, - SinkExt, -}; use std::{ fmt::{Display, Error, Formatter}, time::Duration, }; use tari_comms::{peer_manager::Peer, types::CommsPublicKey}; -use tokio::time; +use tokio::{ + sync::{mpsc, oneshot}, + time, +}; #[derive(Debug)] pub enum DhtDiscoveryRequest { diff --git a/comms/dht/src/discovery/service.rs b/comms/dht/src/discovery/service.rs index 258eb67cea..2cebf571b4 100644 --- a/comms/dht/src/discovery/service.rs +++ b/comms/dht/src/discovery/service.rs @@ -27,11 +27,6 @@ use crate::{ proto::dht::{DiscoveryMessage, DiscoveryResponseMessage}, DhtConfig, }; -use futures::{ - channel::{mpsc, oneshot}, - future::FutureExt, - StreamExt, -}; use log::*; use rand::{rngs::OsRng, RngCore}; use std::{ @@ -47,7 +42,11 @@ use tari_comms::{ }; use tari_shutdown::ShutdownSignal; use tari_utilities::{hex::Hex, ByteArray}; -use tokio::{task, time}; +use tokio::{ + sync::{mpsc, oneshot}, + task, + time, +}; const LOG_TARGET: &str = "comms::dht::discovery_service"; @@ -72,8 +71,8 @@ pub struct DhtDiscoveryService { node_identity: Arc<NodeIdentity>, outbound_requester: OutboundMessageRequester, peer_manager: Arc<PeerManager>, - request_rx: Option<mpsc::Receiver<DhtDiscoveryRequest>>, - shutdown_signal: Option<ShutdownSignal>, + request_rx: mpsc::Receiver<DhtDiscoveryRequest>, + shutdown_signal: ShutdownSignal, inflight_discoveries: HashMap<u64, DiscoveryRequestState>, } @@ -91,8 +90,8 @@ impl DhtDiscoveryService { outbound_requester, node_identity, peer_manager, - shutdown_signal: Some(shutdown_signal), - request_rx: Some(request_rx), + shutdown_signal, + request_rx, inflight_discoveries: HashMap::new(), } } @@ -106,29 +105,19 @@ impl DhtDiscoveryService { pub async fn run(mut self) { info!(target: LOG_TARGET, "Dht discovery service started"); - let mut shutdown_signal = self - .shutdown_signal - .take() - .expect("DiscoveryService initialized without shutdown_signal") - .fuse(); - - let mut request_rx = self - .request_rx - .take() - .expect("DiscoveryService initialized without request_rx") - .fuse(); - loop { - futures::select! { - request = request_rx.select_next_some() => { - trace!(target: LOG_TARGET, "Received request '{}'", request); - self.handle_request(request).await; - }, + tokio::select! { + biased; - _ = shutdown_signal => { + _ = self.shutdown_signal.wait() => { info!(target: LOG_TARGET, "Discovery service is shutting down because the shutdown signal was received"); break; } + + Some(request) = self.request_rx.recv() => { + trace!(target: LOG_TARGET, "Received request '{}'", request); + self.handle_request(request).await; + }, } } } @@ -153,7 +142,7 @@ impl DhtDiscoveryService { let mut remaining_requests = HashMap::new(); for (nonce, request) in self.inflight_discoveries.drain() { // Exclude canceled requests - if request.reply_tx.is_canceled() { + if request.reply_tx.is_closed() { continue; } @@ -199,7 +188,7 @@ impl DhtDiscoveryService { ); for request in self.collect_all_discovery_requests(&public_key) { - if !reply_tx.is_canceled() { + if !reply_tx.is_closed() { let _ = request.reply_tx.send(Ok(peer.clone())); } } @@ -299,7 +288,7 @@ impl DhtDiscoveryService { self.inflight_discoveries = self .inflight_discoveries .drain() - .filter(|(_, state)| !state.reply_tx.is_canceled()) + .filter(|(_, state)| !state.reply_tx.is_closed()) .collect(); trace!( @@ -393,9 +382,10 @@ mod test { test_utils::{build_peer_manager, make_node_identity}, }; use std::time::Duration; + use tari_comms::runtime; use tari_shutdown::Shutdown; - #[tokio_macros::test_basic] + #[runtime::test] async fn send_discovery() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); diff --git a/comms/dht/src/envelope.rs b/comms/dht/src/envelope.rs index 0b93546dbb..b22d86b92c 100644 --- a/comms/dht/src/envelope.rs +++ b/comms/dht/src/envelope.rs @@ -22,6 +22,8 @@ use bitflags::bitflags; use bytes::Bytes; +use chrono::{DateTime, NaiveDateTime, Utc}; +use prost_types::Timestamp; use serde::{Deserialize, Serialize}; use std::{ cmp, @@ -30,14 +32,11 @@ use std::{ fmt::Display, }; use tari_comms::{message::MessageTag, peer_manager::NodeId, types::CommsPublicKey, NodeIdentity}; -use tari_utilities::{ByteArray, ByteArrayError}; +use tari_utilities::{epoch_time::EpochTime, ByteArray, ByteArrayError}; use thiserror::Error; // Re-export applicable protos pub use crate::proto::envelope::{dht_header::Destination, DhtEnvelope, DhtHeader, DhtMessageType}; -use chrono::{DateTime, NaiveDateTime, Utc}; -use prost_types::Timestamp; -use tari_utilities::epoch_time::EpochTime; /// Utility function that converts a `chrono::DateTime<Utc>` to a `prost::Timestamp` pub(crate) fn datetime_to_timestamp(datetime: DateTime<Utc>) -> Timestamp { diff --git a/comms/dht/src/inbound/decryption.rs b/comms/dht/src/inbound/decryption.rs index 9e3a6bbd3e..46fc3daa47 100644 --- a/comms/dht/src/inbound/decryption.rs +++ b/comms/dht/src/inbound/decryption.rs @@ -397,7 +397,12 @@ mod test { }; use futures::{executor::block_on, future}; use std::sync::Mutex; - use tari_comms::{message::MessageExt, test_utils::mocks::create_connectivity_mock, wrap_in_envelope_body}; + use tari_comms::{ + message::MessageExt, + runtime, + test_utils::mocks::create_connectivity_mock, + wrap_in_envelope_body, + }; use tari_test_utils::{counter_context, unpack_enum}; use tower::service_fn; @@ -469,7 +474,7 @@ mod test { assert_eq!(decrypted.decryption_result.unwrap_err(), inbound_msg.body); } - #[tokio_macros::test_basic] + #[runtime::test] async fn decrypt_inbound_fail_destination() { let (connectivity, mock) = create_connectivity_mock(); mock.spawn(); diff --git a/comms/dht/src/inbound/deserialize.rs b/comms/dht/src/inbound/deserialize.rs index b28a057cb5..a73c3b3cfb 100644 --- a/comms/dht/src/inbound/deserialize.rs +++ b/comms/dht/src/inbound/deserialize.rs @@ -137,9 +137,12 @@ mod test { service_spy, }, }; - use tari_comms::message::{MessageExt, MessageTag}; + use tari_comms::{ + message::{MessageExt, MessageTag}, + runtime, + }; - #[tokio_macros::test_basic] + #[runtime::test] async fn deserialize() { let spy = service_spy(); let peer_manager = build_peer_manager(); diff --git a/comms/dht/src/lib.rs b/comms/dht/src/lib.rs index cab2f8ab6f..e3b730022e 100644 --- a/comms/dht/src/lib.rs +++ b/comms/dht/src/lib.rs @@ -71,7 +71,7 @@ //! #use std::sync::Arc; //! #use tari_comms::CommsBuilder; //! #use tokio::runtime::Runtime; -//! #use futures::channel::mpsc; +//! #use tokio::sync::mpsc; //! //! let runtime = Runtime::new().unwrap(); //! // Channel from comms to inbound dht diff --git a/comms/dht/src/network_discovery/on_connect.rs b/comms/dht/src/network_discovery/on_connect.rs index b93657f061..cd162b3903 100644 --- a/comms/dht/src/network_discovery/on_connect.rs +++ b/comms/dht/src/network_discovery/on_connect.rs @@ -33,7 +33,7 @@ use crate::{ }; use futures::StreamExt; use log::*; -use std::{convert::TryInto, ops::Deref}; +use std::convert::TryInto; use tari_comms::{ connectivity::ConnectivityEvent, peer_manager::{NodeId, Peer}, @@ -62,8 +62,9 @@ impl OnConnect { pub async fn next_event(&mut self) -> StateEvent { let mut connectivity_events = self.context.connectivity.get_event_subscription(); - while let Some(event) = connectivity_events.next().await { - match event.as_ref().map(|e| e.deref()) { + loop { + let event = connectivity_events.recv().await; + match event { Ok(ConnectivityEvent::PeerConnected(conn)) => { if conn.peer_features().is_client() { continue; @@ -96,10 +97,10 @@ impl OnConnect { self.prev_synced.push(conn.peer_node_id().clone()); }, Ok(_) => { /* Nothing to do */ }, - Err(broadcast::RecvError::Lagged(n)) => { + Err(broadcast::error::RecvError::Lagged(n)) => { warn!(target: LOG_TARGET, "Lagged behind on {} connectivity event(s)", n) }, - Err(broadcast::RecvError::Closed) => { + Err(broadcast::error::RecvError::Closed) => { break; }, } diff --git a/comms/dht/src/network_discovery/test.rs b/comms/dht/src/network_discovery/test.rs index 54f596ee26..2f854627f1 100644 --- a/comms/dht/src/network_discovery/test.rs +++ b/comms/dht/src/network_discovery/test.rs @@ -28,12 +28,12 @@ use crate::{ test_utils::{build_peer_manager, make_node_identity}, DhtConfig, }; -use futures::StreamExt; use std::{iter, sync::Arc, time::Duration}; use tari_comms::{ connectivity::ConnectivityStatus, peer_manager::{Peer, PeerFeatures}, protocol::rpc::{mock::MockRpcServer, NamedProtocolService}, + runtime, test_utils::{ mocks::{create_connectivity_mock, ConnectivityManagerMockState}, node_identity::build_node_identity, @@ -97,7 +97,7 @@ mod state_machine { ) } - #[tokio_macros::test_basic] + #[runtime::test] #[allow(clippy::redundant_closure)] async fn it_fetches_peers() { const NUM_PEERS: usize = 3; @@ -139,7 +139,7 @@ mod state_machine { mock.get_peers.set_response(Ok(peers)).await; discovery_actor.spawn(); - let event = event_rx.next().await.unwrap().unwrap(); + let event = event_rx.recv().await.unwrap(); unpack_enum!(DhtEvent::NetworkDiscoveryPeersAdded(info) = &*event); assert!(info.has_new_neighbours()); assert_eq!(info.num_new_neighbours, NUM_PEERS); @@ -149,11 +149,11 @@ mod state_machine { assert_eq!(info.sync_peers, vec![peer_node_identity.node_id().clone()]); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_shuts_down() { let (discovery, _, _, _, _, mut shutdown) = setup(Default::default(), make_node_identity(), vec![]).await; - shutdown.trigger().unwrap(); + shutdown.trigger(); tokio::time::timeout(Duration::from_secs(5), discovery.run()) .await .unwrap(); @@ -200,7 +200,7 @@ mod discovery_ready { (node_identity, peer_manager, connectivity_mock, ready, context) } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_begins_aggressive_discovery() { let (_, pm, _, mut ready, _) = setup(Default::default()); let peers = build_many_node_identities(1, PeerFeatures::COMMUNICATION_NODE); @@ -212,14 +212,14 @@ mod discovery_ready { assert!(params.num_peers_to_request.is_none()); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_idles_if_no_sync_peers() { let (_, _, _, mut ready, _) = setup(Default::default()); let state_event = ready.next_event().await; unpack_enum!(StateEvent::Idle = state_event); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_idles_if_num_rounds_reached() { let config = NetworkDiscoveryConfig { min_desired_peers: 0, @@ -240,7 +240,7 @@ mod discovery_ready { unpack_enum!(StateEvent::Idle = state_event); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_transitions_to_on_connect() { let config = NetworkDiscoveryConfig { min_desired_peers: 0, diff --git a/comms/dht/src/network_discovery/waiting.rs b/comms/dht/src/network_discovery/waiting.rs index 73e8929ac5..f61dfc6b24 100644 --- a/comms/dht/src/network_discovery/waiting.rs +++ b/comms/dht/src/network_discovery/waiting.rs @@ -46,7 +46,7 @@ impl Waiting { target: LOG_TARGET, "Network discovery is IDLING for {:.0?}", self.duration ); - time::delay_for(self.duration).await; + time::sleep(self.duration).await; debug!(target: LOG_TARGET, "Network discovery resuming"); StateEvent::Ready } diff --git a/comms/dht/src/outbound/broadcast.rs b/comms/dht/src/outbound/broadcast.rs index 0aa9fab611..5f123d0867 100644 --- a/comms/dht/src/outbound/broadcast.rs +++ b/comms/dht/src/outbound/broadcast.rs @@ -39,7 +39,6 @@ use bytes::Bytes; use chrono::{DateTime, Utc}; use digest::Digest; use futures::{ - channel::oneshot, future, future::BoxFuture, stream::{self, StreamExt}, @@ -60,6 +59,7 @@ use tari_crypto::{ tari_utilities::{message_format::MessageFormat, ByteArray}, }; use tari_utilities::hex::Hex; +use tokio::sync::oneshot; use tower::{layer::Layer, Service, ServiceExt}; const LOG_TARGET: &str = "comms::dht::outbound::broadcast_middleware"; @@ -255,7 +255,7 @@ where S: Service<DhtOutboundMessage, Response = (), Error = PipelineError> match self.select_peers(broadcast_strategy.clone()).await { Ok(mut peers) => { - if reply_tx.is_canceled() { + if reply_tx.is_closed() { return Err(DhtOutboundError::ReplyChannelCanceled); } @@ -525,19 +525,19 @@ mod test { DhtDiscoveryMockState, }, }; - use futures::channel::oneshot; use rand::rngs::OsRng; use std::time::Duration; use tari_comms::{ multiaddr::Multiaddr, peer_manager::{NodeId, Peer, PeerFeatures, PeerFlags}, + runtime, types::CommsPublicKey, }; use tari_crypto::keys::PublicKey; use tari_test_utils::unpack_enum; - use tokio::task; + use tokio::{sync::oneshot, task}; - #[tokio_macros::test_basic] + #[runtime::test] async fn send_message_flood() { let pk = CommsPublicKey::default(); let example_peer = Peer::new( @@ -601,7 +601,7 @@ mod test { assert!(requests.iter().any(|msg| msg.destination_node_id == other_peer.node_id)); } - #[tokio_macros::test_basic] + #[runtime::test] async fn send_message_direct_not_found() { // Test for issue https://github.com/tari-project/tari/issues/959 @@ -645,7 +645,7 @@ mod test { assert_eq!(spy.call_count(), 0); } - #[tokio_macros::test_basic] + #[runtime::test] async fn send_message_direct_dht_discovery() { let node_identity = NodeIdentity::random( &mut OsRng, diff --git a/comms/dht/src/outbound/error.rs b/comms/dht/src/outbound/error.rs index 3f93dab043..090aea4965 100644 --- a/comms/dht/src/outbound/error.rs +++ b/comms/dht/src/outbound/error.rs @@ -20,16 +20,16 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::outbound::message::SendFailure; -use futures::channel::mpsc::SendError; +use crate::outbound::{message::SendFailure, DhtOutboundRequest}; use tari_comms::message::MessageError; use tari_crypto::{signatures::SchnorrSignatureError, tari_utilities::message_format::MessageFormatError}; use thiserror::Error; +use tokio::sync::mpsc::error::SendError; #[derive(Debug, Error)] pub enum DhtOutboundError { - #[error("SendError: {0}")] - SendError(#[from] SendError), + #[error("`Failed to send: {0}")] + SendError(#[from] SendError<DhtOutboundRequest>), #[error("MessageSerializationError: {0}")] MessageSerializationError(#[from] MessageError), #[error("MessageFormatError: {0}")] diff --git a/comms/dht/src/outbound/message.rs b/comms/dht/src/outbound/message.rs index 52356ec364..bb782dc2e5 100644 --- a/comms/dht/src/outbound/message.rs +++ b/comms/dht/src/outbound/message.rs @@ -25,7 +25,6 @@ use crate::{ outbound::{message_params::FinalSendMessageParams, message_send_state::MessageSendStates}, }; use bytes::Bytes; -use futures::channel::oneshot; use std::{fmt, fmt::Display, sync::Arc}; use tari_comms::{ message::{MessageTag, MessagingReplyTx}, @@ -34,6 +33,7 @@ use tari_comms::{ }; use tari_utilities::hex::Hex; use thiserror::Error; +use tokio::sync::oneshot; /// Determines if an outbound message should be Encrypted and, if so, for which public key #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/comms/dht/src/outbound/message_send_state.rs b/comms/dht/src/outbound/message_send_state.rs index 1576e87c70..b3ca43fcb2 100644 --- a/comms/dht/src/outbound/message_send_state.rs +++ b/comms/dht/src/outbound/message_send_state.rs @@ -250,9 +250,9 @@ impl Index<usize> for MessageSendStates { #[cfg(test)] mod test { use super::*; - use futures::channel::oneshot; use std::iter::repeat_with; - use tari_comms::message::MessagingReplyTx; + use tari_comms::{message::MessagingReplyTx, runtime}; + use tokio::sync::oneshot; fn create_send_state() -> (MessageSendState, MessagingReplyTx) { let (reply_tx, reply_rx) = oneshot::channel(); @@ -269,7 +269,7 @@ mod test { assert!(!states.is_empty()); } - #[tokio_macros::test_basic] + #[runtime::test] async fn wait_single() { let (state, mut reply_tx) = create_send_state(); let states = MessageSendStates::from(vec![state]); @@ -284,7 +284,7 @@ mod test { assert!(!states.wait_single().await); } - #[tokio_macros::test_basic] + #[runtime::test] #[allow(clippy::redundant_closure)] async fn wait_percentage_success() { let states = repeat_with(|| create_send_state()).take(10).collect::<Vec<_>>(); @@ -300,7 +300,7 @@ mod test { assert_eq!(failed.len(), 4); } - #[tokio_macros::test_basic] + #[runtime::test] #[allow(clippy::redundant_closure)] async fn wait_n_timeout() { let states = repeat_with(|| create_send_state()).take(10).collect::<Vec<_>>(); @@ -329,7 +329,7 @@ mod test { assert_eq!(failed.len(), 6); } - #[tokio_macros::test_basic] + #[runtime::test] #[allow(clippy::redundant_closure)] async fn wait_all() { let states = repeat_with(|| create_send_state()).take(10).collect::<Vec<_>>(); diff --git a/comms/dht/src/outbound/mock.rs b/comms/dht/src/outbound/mock.rs index f5c3f30665..66e2b7258e 100644 --- a/comms/dht/src/outbound/mock.rs +++ b/comms/dht/src/outbound/mock.rs @@ -31,11 +31,6 @@ use crate::{ }, }; use bytes::Bytes; -use futures::{ - channel::{mpsc, oneshot}, - stream::Fuse, - StreamExt, -}; use log::*; use std::{ sync::{Arc, Condvar, Mutex, RwLock}, @@ -45,7 +40,10 @@ use tari_comms::{ message::{MessageTag, MessagingReplyTx}, protocol::messaging::SendFailReason, }; -use tokio::time::delay_for; +use tokio::{ + sync::{mpsc, oneshot}, + time::sleep, +}; const LOG_TARGET: &str = "mock::outbound_requester"; @@ -54,7 +52,7 @@ const LOG_TARGET: &str = "mock::outbound_requester"; /// Each time a request is expected, handle_next should be called. pub fn create_outbound_service_mock(size: usize) -> (OutboundMessageRequester, OutboundServiceMock) { let (tx, rx) = mpsc::channel(size); - (OutboundMessageRequester::new(tx), OutboundServiceMock::new(rx.fuse())) + (OutboundMessageRequester::new(tx), OutboundServiceMock::new(rx)) } #[derive(Clone, Default)] @@ -149,12 +147,12 @@ impl OutboundServiceMockState { } pub struct OutboundServiceMock { - receiver: Fuse<mpsc::Receiver<DhtOutboundRequest>>, + receiver: mpsc::Receiver<DhtOutboundRequest>, mock_state: OutboundServiceMockState, } impl OutboundServiceMock { - pub fn new(receiver: Fuse<mpsc::Receiver<DhtOutboundRequest>>) -> Self { + pub fn new(receiver: mpsc::Receiver<DhtOutboundRequest>) -> Self { Self { receiver, mock_state: OutboundServiceMockState::new(), @@ -166,7 +164,7 @@ impl OutboundServiceMock { } pub async fn run(mut self) { - while let Some(req) = self.receiver.next().await { + while let Some(req) = self.receiver.recv().await { match req { DhtOutboundRequest::SendMessage(params, body, reply_tx) => { let behaviour = self.mock_state.get_behaviour(); @@ -192,7 +190,7 @@ impl OutboundServiceMock { ResponseType::QueuedSuccessDelay(delay) => { let (response, mut inner_reply_tx) = self.add_call((*params).clone(), body); reply_tx.send(response).expect("Reply channel cancelled"); - delay_for(delay).await; + sleep(delay).await; inner_reply_tx.reply_success(); }, resp => { diff --git a/comms/dht/src/outbound/requester.rs b/comms/dht/src/outbound/requester.rs index c1957a98b0..f8536c9d9c 100644 --- a/comms/dht/src/outbound/requester.rs +++ b/comms/dht/src/outbound/requester.rs @@ -32,12 +32,9 @@ use crate::{ MessageSendStates, }, }; -use futures::{ - channel::{mpsc, oneshot}, - SinkExt, -}; use log::*; use tari_comms::{message::MessageExt, peer_manager::NodeId, types::CommsPublicKey, wrap_in_envelope_body}; +use tokio::sync::{mpsc, oneshot}; const LOG_TARGET: &str = "comms::dht::requests::outbound"; diff --git a/comms/dht/src/outbound/serialize.rs b/comms/dht/src/outbound/serialize.rs index 97c7c0df58..195d2d3d39 100644 --- a/comms/dht/src/outbound/serialize.rs +++ b/comms/dht/src/outbound/serialize.rs @@ -137,9 +137,9 @@ mod test { use super::*; use crate::test_utils::{assert_send_static_service, create_outbound_message, service_spy}; use prost::Message; - use tari_comms::peer_manager::NodeId; + use tari_comms::{peer_manager::NodeId, runtime}; - #[tokio_macros::test_basic] + #[runtime::test] async fn serialize() { let spy = service_spy(); let mut serialize = SerializeLayer.layer(spy.to_service::<PipelineError>()); diff --git a/comms/dht/src/rpc/service.rs b/comms/dht/src/rpc/service.rs index e762ed4d79..84aef6e5ff 100644 --- a/comms/dht/src/rpc/service.rs +++ b/comms/dht/src/rpc/service.rs @@ -24,16 +24,16 @@ use crate::{ proto::rpc::{GetCloserPeersRequest, GetPeersRequest, GetPeersResponse}, rpc::DhtRpcService, }; -use futures::{channel::mpsc, stream, SinkExt}; use log::*; use std::{cmp, sync::Arc}; use tari_comms::{ peer_manager::{NodeId, Peer, PeerFeatures, PeerQuery}, protocol::rpc::{Request, RpcError, RpcStatus, Streaming}, + utils, PeerManager, }; use tari_utilities::ByteArray; -use tokio::task; +use tokio::{sync::mpsc, task}; const LOG_TARGET: &str = "comms::dht::rpc"; @@ -56,17 +56,15 @@ impl DhtRpcServiceImpl { // A maximum buffer size of 10 is selected arbitrarily and is to allow the producer/consumer some room to // buffer. - let (mut tx, rx) = mpsc::channel(cmp::min(10, peers.len() as usize)); + let (tx, rx) = mpsc::channel(cmp::min(10, peers.len() as usize)); task::spawn(async move { let iter = peers .into_iter() .map(|peer| GetPeersResponse { peer: Some(peer.into()), }) - .map(Ok) .map(Ok); - let mut stream = stream::iter(iter); - let _ = tx.send_all(&mut stream).await; + let _ = utils::mpsc::send_all(&tx, iter).await; }); Streaming::new(rx) diff --git a/comms/dht/src/rpc/test.rs b/comms/dht/src/rpc/test.rs index 764d49ba7f..cd70d65a0f 100644 --- a/comms/dht/src/rpc/test.rs +++ b/comms/dht/src/rpc/test.rs @@ -26,13 +26,16 @@ use crate::{ test_utils::build_peer_manager, }; use futures::StreamExt; -use std::{convert::TryInto, sync::Arc}; +use std::{convert::TryInto, sync::Arc, time::Duration}; use tari_comms::{ - peer_manager::{node_id::NodeDistance, PeerFeatures}, + peer_manager::{node_id::NodeDistance, NodeId, Peer, PeerFeatures}, protocol::rpc::{mock::RpcRequestMock, RpcStatusCode}, + runtime, test_utils::node_identity::{build_node_identity, ordered_node_identities_by_distance}, PeerManager, }; +use tari_test_utils::collect_recv; +use tari_utilities::ByteArray; fn setup() -> (DhtRpcServiceImpl, RpcRequestMock, Arc<PeerManager>) { let peer_manager = build_peer_manager(); @@ -45,10 +48,8 @@ fn setup() -> (DhtRpcServiceImpl, RpcRequestMock, Arc<PeerManager>) { // Unit tests for get_closer_peers request mod get_closer_peers { use super::*; - use tari_comms::peer_manager::{NodeId, Peer}; - use tari_utilities::ByteArray; - #[tokio_macros::test_basic] + #[runtime::test] async fn it_returns_empty_peer_stream() { let (service, mock, _) = setup(); let node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -66,7 +67,7 @@ mod get_closer_peers { assert!(next.is_none()); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_returns_closest_peers() { let (service, mock, peer_manager) = setup(); let node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -83,7 +84,7 @@ mod get_closer_peers { let req = mock.request_with_context(node_identity.node_id().clone(), req); let peers_stream = service.get_closer_peers(req).await.unwrap(); - let results = peers_stream.into_inner().collect::<Vec<_>>().await; + let results = collect_recv!(peers_stream.into_inner(), timeout = Duration::from_secs(10)); assert_eq!(results.len(), 10); let peers = results @@ -101,7 +102,7 @@ mod get_closer_peers { } } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_returns_n_peers() { let (service, mock, peer_manager) = setup(); @@ -123,7 +124,7 @@ mod get_closer_peers { assert_eq!(results.len(), 5); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_skips_excluded_peers() { let (service, mock, peer_manager) = setup(); @@ -142,12 +143,12 @@ mod get_closer_peers { let req = mock.request_with_context(node_identity.node_id().clone(), req); let peers_stream = service.get_closer_peers(req).await.unwrap(); - let results = peers_stream.into_inner().collect::<Vec<_>>().await; + let results = collect_recv!(peers_stream.into_inner(), timeout = Duration::from_secs(10)); let mut peers = results.into_iter().map(Result::unwrap).map(|r| r.peer.unwrap()); assert!(peers.all(|p| p.public_key != excluded_peer.public_key().as_bytes())); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_errors_if_maximum_n_exceeded() { let (service, mock, _) = setup(); let req = GetCloserPeersRequest { @@ -165,9 +166,10 @@ mod get_closer_peers { mod get_peers { use super::*; use crate::proto::rpc::GetPeersRequest; + use std::time::Duration; use tari_comms::{peer_manager::Peer, test_utils::node_identity::build_many_node_identities}; - #[tokio_macros::test_basic] + #[runtime::test] async fn it_returns_empty_peer_stream() { let (service, mock, _) = setup(); let node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -183,7 +185,7 @@ mod get_peers { assert!(next.is_none()); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_returns_all_peers() { let (service, mock, peer_manager) = setup(); let nodes = build_many_node_identities(3, PeerFeatures::COMMUNICATION_NODE); @@ -200,7 +202,7 @@ mod get_peers { .get_peers(mock.request_with_context(Default::default(), req)) .await .unwrap(); - let results = peers_stream.into_inner().collect::<Vec<_>>().await; + let results = collect_recv!(peers_stream.into_inner(), timeout = Duration::from_secs(10)); assert_eq!(results.len(), 5); let peers = results @@ -214,7 +216,7 @@ mod get_peers { assert_eq!(peers.iter().filter(|p| p.features.is_node()).count(), 3); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_excludes_clients() { let (service, mock, peer_manager) = setup(); let nodes = build_many_node_identities(3, PeerFeatures::COMMUNICATION_NODE); @@ -231,7 +233,7 @@ mod get_peers { .get_peers(mock.request_with_context(Default::default(), req)) .await .unwrap(); - let results = peers_stream.into_inner().collect::<Vec<_>>().await; + let results = collect_recv!(peers_stream.into_inner(), timeout = Duration::from_secs(10)); assert_eq!(results.len(), 3); let peers = results @@ -244,7 +246,7 @@ mod get_peers { assert!(peers.iter().all(|p| p.features.is_node())); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_returns_n_peers() { let (service, mock, peer_manager) = setup(); diff --git a/comms/dht/src/storage/connection.rs b/comms/dht/src/storage/connection.rs index 856a94315a..ee99f8b560 100644 --- a/comms/dht/src/storage/connection.rs +++ b/comms/dht/src/storage/connection.rs @@ -123,16 +123,17 @@ impl DbConnection { mod test { use super::*; use diesel::{expression::sql_literal::sql, sql_types::Integer, RunQueryDsl}; + use tari_comms::runtime; use tari_test_utils::random; - #[tokio_macros::test_basic] + #[runtime::test] async fn connect_and_migrate() { let conn = DbConnection::connect_memory(random::string(8)).await.unwrap(); let output = conn.migrate().await.unwrap(); assert!(output.starts_with("Running migration")); } - #[tokio_macros::test_basic] + #[runtime::test] async fn memory_connections() { let id = random::string(8); let conn = DbConnection::connect_memory(id.clone()).await.unwrap(); diff --git a/comms/dht/src/store_forward/database/mod.rs b/comms/dht/src/store_forward/database/mod.rs index 173d00e0ef..58ee06eb9c 100644 --- a/comms/dht/src/store_forward/database/mod.rs +++ b/comms/dht/src/store_forward/database/mod.rs @@ -255,9 +255,10 @@ impl StoreAndForwardDatabase { #[cfg(test)] mod test { use super::*; + use tari_comms::runtime; use tari_test_utils::random; - #[tokio_macros::test_basic] + #[runtime::test] async fn insert_messages() { let conn = DbConnection::connect_memory(random::string(8)).await.unwrap(); conn.migrate().await.unwrap(); @@ -277,7 +278,7 @@ mod test { assert_eq!(messages[1].body_hash, msg2.body_hash); } - #[tokio_macros::test_basic] + #[runtime::test] async fn remove_messages() { let conn = DbConnection::connect_memory(random::string(8)).await.unwrap(); conn.migrate().await.unwrap(); @@ -304,7 +305,7 @@ mod test { assert_eq!(messages[0].id, msg2_id); } - #[tokio_macros::test_basic] + #[runtime::test] async fn truncate_messages() { let conn = DbConnection::connect_memory(random::string(8)).await.unwrap(); conn.migrate().await.unwrap(); diff --git a/comms/dht/src/store_forward/forward.rs b/comms/dht/src/store_forward/forward.rs index 95ce5e2500..c856d11af3 100644 --- a/comms/dht/src/store_forward/forward.rs +++ b/comms/dht/src/store_forward/forward.rs @@ -262,14 +262,13 @@ mod test { outbound::mock::create_outbound_service_mock, test_utils::{make_dht_inbound_message, make_node_identity, service_spy}, }; - use futures::{channel::mpsc, executor::block_on}; - use tari_comms::wrap_in_envelope_body; - use tokio::runtime::Runtime; + use tari_comms::{runtime, runtime::task, wrap_in_envelope_body}; + use tokio::sync::mpsc; - #[test] - fn decryption_succeeded() { + #[runtime::test] + async fn decryption_succeeded() { let spy = service_spy(); - let (oms_tx, mut oms_rx) = mpsc::channel(1); + let (oms_tx, _) = mpsc::channel(1); let oms = OutboundMessageRequester::new(oms_tx); let mut service = ForwardLayer::new(oms, true).layer(spy.to_service::<PipelineError>()); @@ -280,18 +279,16 @@ mod test { Some(node_identity.public_key().clone()), inbound_msg, ); - block_on(service.call(msg)).unwrap(); + service.call(msg).await.unwrap(); assert!(spy.is_called()); - assert!(oms_rx.try_next().is_err()); } - #[test] - fn decryption_failed() { - let mut rt = Runtime::new().unwrap(); + #[runtime::test] + async fn decryption_failed() { let spy = service_spy(); let (oms_requester, oms_mock) = create_outbound_service_mock(1); let oms_mock_state = oms_mock.get_state(); - rt.spawn(oms_mock.run()); + task::spawn(oms_mock.run()); let mut service = ForwardLayer::new(oms_requester, true).layer(spy.to_service::<PipelineError>()); @@ -304,7 +301,7 @@ mod test { ); let header = inbound_msg.dht_header.clone(); let msg = DecryptedDhtMessage::failed(inbound_msg); - rt.block_on(service.call(msg)).unwrap(); + service.call(msg).await.unwrap(); assert!(spy.is_called()); assert_eq!(oms_mock_state.call_count(), 1); diff --git a/comms/dht/src/store_forward/saf_handler/layer.rs b/comms/dht/src/store_forward/saf_handler/layer.rs index 50b6ab7839..16e2760a1e 100644 --- a/comms/dht/src/store_forward/saf_handler/layer.rs +++ b/comms/dht/src/store_forward/saf_handler/layer.rs @@ -27,9 +27,9 @@ use crate::{ outbound::OutboundMessageRequester, store_forward::StoreAndForwardRequester, }; -use futures::channel::mpsc; use std::sync::Arc; use tari_comms::peer_manager::{NodeIdentity, PeerManager}; +use tokio::sync::mpsc; use tower::layer::Layer; pub struct MessageHandlerLayer { diff --git a/comms/dht/src/store_forward/saf_handler/middleware.rs b/comms/dht/src/store_forward/saf_handler/middleware.rs index 578fc1dcbc..641950e4f1 100644 --- a/comms/dht/src/store_forward/saf_handler/middleware.rs +++ b/comms/dht/src/store_forward/saf_handler/middleware.rs @@ -28,12 +28,13 @@ use crate::{ outbound::OutboundMessageRequester, store_forward::StoreAndForwardRequester, }; -use futures::{channel::mpsc, future::BoxFuture, task::Context}; +use futures::{future::BoxFuture, task::Context}; use std::{sync::Arc, task::Poll}; use tari_comms::{ peer_manager::{NodeIdentity, PeerManager}, pipeline::PipelineError, }; +use tokio::sync::mpsc; use tower::Service; #[derive(Clone)] diff --git a/comms/dht/src/store_forward/saf_handler/task.rs b/comms/dht/src/store_forward/saf_handler/task.rs index f3ba852118..efda074ddb 100644 --- a/comms/dht/src/store_forward/saf_handler/task.rs +++ b/comms/dht/src/store_forward/saf_handler/task.rs @@ -41,7 +41,7 @@ use crate::{ }; use chrono::{DateTime, NaiveDateTime, Utc}; use digest::Digest; -use futures::{channel::mpsc, future, stream, SinkExt, StreamExt}; +use futures::{future, stream, StreamExt}; use log::*; use prost::Message; use std::{convert::TryInto, sync::Arc}; @@ -53,6 +53,7 @@ use tari_comms::{ utils::signature, }; use tari_utilities::{convert::try_convert_all, ByteArray}; +use tokio::sync::mpsc; use tower::{Service, ServiceExt}; const LOG_TARGET: &str = "comms::dht::storeforward::handler"; @@ -567,14 +568,13 @@ mod test { }, }; use chrono::{Duration as OldDuration, Utc}; - use futures::channel::mpsc; use prost::Message; use std::time::Duration; - use tari_comms::{message::MessageExt, wrap_in_envelope_body}; + use tari_comms::{message::MessageExt, runtime, wrap_in_envelope_body}; use tari_crypto::tari_utilities::hex; - use tari_test_utils::collect_stream; + use tari_test_utils::collect_recv; use tari_utilities::hex::Hex; - use tokio::{runtime::Handle, task, time::delay_for}; + use tokio::{runtime::Handle, sync::mpsc, task, time::sleep}; // TODO: unit tests for static functions (check_signature, etc) @@ -602,7 +602,7 @@ mod test { } } - #[tokio_macros::test] + #[tokio::test] async fn request_stored_messages() { let spy = service_spy(); let (requester, mock_state) = create_store_and_forward_mock(); @@ -662,7 +662,7 @@ mod test { if oms_mock_state.call_count() >= 1 { break; } - delay_for(Duration::from_secs(5)).await; + sleep(Duration::from_secs(5)).await; } assert_eq!(oms_mock_state.call_count(), 1); @@ -724,7 +724,7 @@ mod test { if oms_mock_state.call_count() >= 1 { break; } - delay_for(Duration::from_secs(5)).await; + sleep(Duration::from_secs(5)).await; } assert_eq!(oms_mock_state.call_count(), 1); let call = oms_mock_state.pop_call().unwrap(); @@ -750,7 +750,7 @@ mod test { assert!(stored_messages.iter().any(|s| s.body == msg2.as_bytes())); } - #[tokio_macros::test_basic] + #[runtime::test] async fn receive_stored_messages() { let rt_handle = Handle::current(); let spy = service_spy(); @@ -845,7 +845,7 @@ mod test { assert!(msgs.contains(&b"A".to_vec())); assert!(msgs.contains(&b"B".to_vec())); assert!(msgs.contains(&b"Clear".to_vec())); - let signals = collect_stream!( + let signals = collect_recv!( saf_response_signal_receiver, take = 1, timeout = Duration::from_secs(20) diff --git a/comms/dht/src/store_forward/service.rs b/comms/dht/src/store_forward/service.rs index 5d06d85d56..7c8af0b731 100644 --- a/comms/dht/src/store_forward/service.rs +++ b/comms/dht/src/store_forward/service.rs @@ -36,12 +36,6 @@ use crate::{ DhtRequester, }; use chrono::{DateTime, NaiveDateTime, Utc}; -use futures::{ - channel::{mpsc, oneshot}, - stream::Fuse, - SinkExt, - StreamExt, -}; use log::*; use std::{convert::TryFrom, sync::Arc, time::Duration}; use tari_comms::{ @@ -51,7 +45,11 @@ use tari_comms::{ PeerManager, }; use tari_shutdown::ShutdownSignal; -use tokio::{task, time}; +use tokio::{ + sync::{mpsc, oneshot}, + task, + time, +}; const LOG_TARGET: &str = "comms::dht::storeforward::actor"; /// The interval to initiate a database cleanup. @@ -167,13 +165,13 @@ pub struct StoreAndForwardService { dht_requester: DhtRequester, database: StoreAndForwardDatabase, peer_manager: Arc<PeerManager>, - connection_events: Fuse<ConnectivityEventRx>, + connection_events: ConnectivityEventRx, outbound_requester: OutboundMessageRequester, - request_rx: Fuse<mpsc::Receiver<StoreAndForwardRequest>>, - shutdown_signal: Option<ShutdownSignal>, + request_rx: mpsc::Receiver<StoreAndForwardRequest>, + shutdown_signal: ShutdownSignal, num_received_saf_responses: Option<usize>, num_online_peers: Option<usize>, - saf_response_signal_rx: Fuse<mpsc::Receiver<()>>, + saf_response_signal_rx: mpsc::Receiver<()>, event_publisher: DhtEventSender, } @@ -196,13 +194,13 @@ impl StoreAndForwardService { database: StoreAndForwardDatabase::new(conn), peer_manager, dht_requester, - request_rx: request_rx.fuse(), - connection_events: connectivity.get_event_subscription().fuse(), + request_rx, + connection_events: connectivity.get_event_subscription(), outbound_requester, - shutdown_signal: Some(shutdown_signal), + shutdown_signal, num_received_saf_responses: Some(0), num_online_peers: None, - saf_response_signal_rx: saf_response_signal_rx.fuse(), + saf_response_signal_rx, event_publisher, } } @@ -213,20 +211,15 @@ impl StoreAndForwardService { } async fn run(mut self) { - let mut shutdown_signal = self - .shutdown_signal - .take() - .expect("StoreAndForwardActor initialized without shutdown_signal"); - - let mut cleanup_ticker = time::interval(CLEANUP_INTERVAL).fuse(); + let mut cleanup_ticker = time::interval(CLEANUP_INTERVAL); loop { - futures::select! { - request = self.request_rx.select_next_some() => { + tokio::select! { + Some(request) = self.request_rx.recv() => { self.handle_request(request).await; }, - event = self.connection_events.select_next_some() => { + event = self.connection_events.recv() => { if let Ok(event) = event { if let Err(err) = self.handle_connectivity_event(&event).await { error!(target: LOG_TARGET, "Error handling connection manager event: {:?}", err); @@ -234,20 +227,20 @@ impl StoreAndForwardService { } }, - _ = cleanup_ticker.select_next_some() => { + _ = cleanup_ticker.tick() => { if let Err(err) = self.cleanup().await { error!(target: LOG_TARGET, "Error when performing store and forward cleanup: {:?}", err); } }, - _ = self.saf_response_signal_rx.select_next_some() => { + Some(_) = self.saf_response_signal_rx.recv() => { if let Some(n) = self.num_received_saf_responses { self.num_received_saf_responses = Some(n + 1); self.check_saf_response_threshold(); } }, - _ = shutdown_signal => { + _ = self.shutdown_signal.wait() => { info!(target: LOG_TARGET, "StoreAndForwardActor is shutting down because the shutdown signal was triggered"); break; } diff --git a/comms/dht/src/store_forward/store.rs b/comms/dht/src/store_forward/store.rs index 4393f36518..5366912e54 100644 --- a/comms/dht/src/store_forward/store.rs +++ b/comms/dht/src/store_forward/store.rs @@ -447,11 +447,11 @@ mod test { }; use chrono::Utc; use std::time::Duration; - use tari_comms::wrap_in_envelope_body; + use tari_comms::{runtime, wrap_in_envelope_body}; use tari_test_utils::async_assert_eventually; use tari_utilities::hex::Hex; - #[tokio_macros::test_basic] + #[runtime::test] async fn cleartext_message_no_origin() { let (requester, mock_state) = create_store_and_forward_mock(); @@ -471,7 +471,7 @@ mod test { assert_eq!(messages.len(), 0); } - #[tokio_macros::test_basic] + #[runtime::test] async fn decryption_succeeded_no_store() { let (requester, mock_state) = create_store_and_forward_mock(); @@ -499,7 +499,7 @@ mod test { assert_eq!(mock_state.call_count(), 0); } - #[tokio_macros::test_basic] + #[runtime::test] async fn decryption_failed_should_store() { let (requester, mock_state) = create_store_and_forward_mock(); let spy = service_spy(); @@ -538,7 +538,7 @@ mod test { assert!(duration.num_seconds() <= 5); } - #[tokio_macros::test_basic] + #[runtime::test] async fn decryption_failed_banned_peer() { let (requester, mock_state) = create_store_and_forward_mock(); let spy = service_spy(); diff --git a/comms/dht/src/test_utils/dht_actor_mock.rs b/comms/dht/src/test_utils/dht_actor_mock.rs index ccc53c5a1e..3ac1306a85 100644 --- a/comms/dht/src/test_utils/dht_actor_mock.rs +++ b/comms/dht/src/test_utils/dht_actor_mock.rs @@ -25,7 +25,6 @@ use crate::{ actor::{DhtRequest, DhtRequester}, storage::DhtMetadataKey, }; -use futures::{channel::mpsc, stream::Fuse, StreamExt}; use std::{ collections::HashMap, sync::{ @@ -35,11 +34,11 @@ use std::{ }, }; use tari_comms::peer_manager::Peer; -use tokio::task; +use tokio::{sync::mpsc, task}; pub fn create_dht_actor_mock(buf_size: usize) -> (DhtRequester, DhtActorMock) { let (tx, rx) = mpsc::channel(buf_size); - (DhtRequester::new(tx), DhtActorMock::new(rx.fuse())) + (DhtRequester::new(tx), DhtActorMock::new(rx)) } #[derive(Default, Debug, Clone)] @@ -80,12 +79,12 @@ impl DhtMockState { } pub struct DhtActorMock { - receiver: Fuse<mpsc::Receiver<DhtRequest>>, + receiver: mpsc::Receiver<DhtRequest>, state: DhtMockState, } impl DhtActorMock { - pub fn new(receiver: Fuse<mpsc::Receiver<DhtRequest>>) -> Self { + pub fn new(receiver: mpsc::Receiver<DhtRequest>) -> Self { Self { receiver, state: DhtMockState::default(), @@ -101,7 +100,7 @@ impl DhtActorMock { } pub async fn run(mut self) { - while let Some(req) = self.receiver.next().await { + while let Some(req) = self.receiver.recv().await { self.handle_request(req).await; } } diff --git a/comms/dht/src/test_utils/dht_discovery_mock.rs b/comms/dht/src/test_utils/dht_discovery_mock.rs index 70575e2ae0..fbdf8e8284 100644 --- a/comms/dht/src/test_utils/dht_discovery_mock.rs +++ b/comms/dht/src/test_utils/dht_discovery_mock.rs @@ -24,7 +24,6 @@ use crate::{ discovery::{DhtDiscoveryRequest, DhtDiscoveryRequester}, test_utils::make_peer, }; -use futures::{channel::mpsc, stream::Fuse, StreamExt}; use log::*; use std::{ sync::{ @@ -35,15 +34,13 @@ use std::{ time::Duration, }; use tari_comms::peer_manager::Peer; +use tokio::sync::mpsc; const LOG_TARGET: &str = "comms::dht::discovery_mock"; pub fn create_dht_discovery_mock(buf_size: usize, timeout: Duration) -> (DhtDiscoveryRequester, DhtDiscoveryMock) { let (tx, rx) = mpsc::channel(buf_size); - ( - DhtDiscoveryRequester::new(tx, timeout), - DhtDiscoveryMock::new(rx.fuse()), - ) + (DhtDiscoveryRequester::new(tx, timeout), DhtDiscoveryMock::new(rx)) } #[derive(Debug, Clone)] @@ -75,12 +72,12 @@ impl DhtDiscoveryMockState { } pub struct DhtDiscoveryMock { - receiver: Fuse<mpsc::Receiver<DhtDiscoveryRequest>>, + receiver: mpsc::Receiver<DhtDiscoveryRequest>, state: DhtDiscoveryMockState, } impl DhtDiscoveryMock { - pub fn new(receiver: Fuse<mpsc::Receiver<DhtDiscoveryRequest>>) -> Self { + pub fn new(receiver: mpsc::Receiver<DhtDiscoveryRequest>) -> Self { Self { receiver, state: DhtDiscoveryMockState::new(), @@ -92,7 +89,7 @@ impl DhtDiscoveryMock { } pub async fn run(mut self) { - while let Some(req) = self.receiver.next().await { + while let Some(req) = self.receiver.recv().await { self.handle_request(req).await; } } diff --git a/comms/dht/src/test_utils/store_and_forward_mock.rs b/comms/dht/src/test_utils/store_and_forward_mock.rs index 0dd464c43a..72c0861a6d 100644 --- a/comms/dht/src/test_utils/store_and_forward_mock.rs +++ b/comms/dht/src/test_utils/store_and_forward_mock.rs @@ -23,7 +23,6 @@ use crate::store_forward::{StoreAndForwardRequest, StoreAndForwardRequester, StoredMessage}; use chrono::Utc; use digest::Digest; -use futures::{channel::mpsc, stream::Fuse, StreamExt}; use log::*; use rand::{rngs::OsRng, RngCore}; use std::sync::{ @@ -32,14 +31,17 @@ use std::sync::{ }; use tari_comms::types::Challenge; use tari_utilities::hex; -use tokio::{runtime, sync::RwLock}; +use tokio::{ + runtime, + sync::{mpsc, RwLock}, +}; const LOG_TARGET: &str = "comms::dht::discovery_mock"; pub fn create_store_and_forward_mock() -> (StoreAndForwardRequester, StoreAndForwardMockState) { let (tx, rx) = mpsc::channel(10); - let mock = StoreAndForwardMock::new(rx.fuse()); + let mock = StoreAndForwardMock::new(rx); let state = mock.get_shared_state(); runtime::Handle::current().spawn(mock.run()); (StoreAndForwardRequester::new(tx), state) @@ -90,12 +92,12 @@ impl StoreAndForwardMockState { } pub struct StoreAndForwardMock { - receiver: Fuse<mpsc::Receiver<StoreAndForwardRequest>>, + receiver: mpsc::Receiver<StoreAndForwardRequest>, state: StoreAndForwardMockState, } impl StoreAndForwardMock { - pub fn new(receiver: Fuse<mpsc::Receiver<StoreAndForwardRequest>>) -> Self { + pub fn new(receiver: mpsc::Receiver<StoreAndForwardRequest>) -> Self { Self { receiver, state: StoreAndForwardMockState::new(), @@ -107,7 +109,7 @@ impl StoreAndForwardMock { } pub async fn run(mut self) { - while let Some(req) = self.receiver.next().await { + while let Some(req) = self.receiver.recv().await { self.handle_request(req).await; } } diff --git a/comms/dht/tests/dht.rs b/comms/dht/tests/dht.rs index a5aed09970..5647bfefd7 100644 --- a/comms/dht/tests/dht.rs +++ b/comms/dht/tests/dht.rs @@ -20,7 +20,6 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::{channel::mpsc, StreamExt}; use rand::rngs::OsRng; use std::{sync::Arc, time::Duration}; use tari_comms::{ @@ -54,13 +53,16 @@ use tari_storage::{ }; use tari_test_utils::{ async_assert_eventually, - collect_stream, + collect_try_recv, paths::create_temporary_data_path, random, streams, unpack_enum, }; -use tokio::{sync::broadcast, time}; +use tokio::{ + sync::{broadcast, mpsc}, + time, +}; use tower::ServiceBuilder; struct TestNode { @@ -81,11 +83,11 @@ impl TestNode { } pub async fn next_inbound_message(&mut self, timeout: Duration) -> Option<DecryptedDhtMessage> { - time::timeout(timeout, self.inbound_messages.next()).await.ok()? + time::timeout(timeout, self.inbound_messages.recv()).await.ok()? } pub async fn shutdown(mut self) { - self.shutdown.trigger().unwrap(); + self.shutdown.trigger(); self.comms.wait_until_shutdown().await; } } @@ -205,7 +207,7 @@ async fn setup_comms_dht( (comms, dht, event_tx) } -#[tokio_macros::test] +#[tokio::test] #[allow(non_snake_case)] async fn dht_join_propagation() { // Create 3 nodes where only Node B knows A and C, but A and C want to talk to each other @@ -262,7 +264,7 @@ async fn dht_join_propagation() { node_C.shutdown().await; } -#[tokio_macros::test] +#[tokio::test] #[allow(non_snake_case)] async fn dht_discover_propagation() { // Create 4 nodes where A knows B, B knows A and C, C knows B and D, and D knows C @@ -318,7 +320,7 @@ async fn dht_discover_propagation() { assert!(node_D_peer_manager.exists(node_A.node_identity().public_key()).await); } -#[tokio_macros::test] +#[tokio::test] #[allow(non_snake_case)] async fn dht_store_forward() { let node_C_node_identity = make_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -370,7 +372,7 @@ async fn dht_store_forward() { .unwrap(); // Wait for node B to receive 2 propagation messages - collect_stream!(node_B_msg_events, take = 2, timeout = Duration::from_secs(20)); + collect_try_recv!(node_B_msg_events, take = 2, timeout = Duration::from_secs(20)); let mut node_C = make_node_with_node_identity(node_C_node_identity, Some(node_B.to_peer())).await; let mut node_C_dht_events = node_C.dht.subscribe_dht_events(); @@ -389,8 +391,8 @@ async fn dht_store_forward() { .await .unwrap(); // Wait for node C to and receive a response from the SAF request - let event = collect_stream!(node_C_msg_events, take = 1, timeout = Duration::from_secs(20)); - unpack_enum!(MessagingEvent::MessageReceived(_node_id, _msg) = &**event.get(0).unwrap().as_ref().unwrap()); + let event = collect_try_recv!(node_C_msg_events, take = 1, timeout = Duration::from_secs(20)); + unpack_enum!(MessagingEvent::MessageReceived(_node_id, _msg) = &*event.get(0).unwrap().as_ref()); let msg = node_C.next_inbound_message(Duration::from_secs(5)).await.unwrap(); assert_eq!( @@ -418,15 +420,15 @@ async fn dht_store_forward() { assert!(msgs.is_empty()); // Check that Node C emitted the StoreAndForwardMessagesReceived event when it went Online - let event = collect_stream!(node_C_dht_events, take = 1, timeout = Duration::from_secs(20)); - unpack_enum!(DhtEvent::StoreAndForwardMessagesReceived = &**event.get(0).unwrap().as_ref().unwrap()); + let event = collect_try_recv!(node_C_dht_events, take = 1, timeout = Duration::from_secs(20)); + unpack_enum!(DhtEvent::StoreAndForwardMessagesReceived = &*event.get(0).unwrap().as_ref()); node_A.shutdown().await; node_B.shutdown().await; node_C.shutdown().await; } -#[tokio_macros::test] +#[tokio::test] #[allow(non_snake_case)] async fn dht_propagate_dedup() { // Node D knows no one @@ -515,28 +517,28 @@ async fn dht_propagate_dedup() { node_D.shutdown().await; // Check the message flow BEFORE deduping - let received = filter_received(collect_stream!(node_A_messaging, timeout = Duration::from_secs(20))); + let received = filter_received(collect_try_recv!(node_A_messaging, timeout = Duration::from_secs(20))); // Expected race condition: If A->(B|C)->(C|B) before A->(C|B) then (C|B)->A if !received.is_empty() { assert_eq!(count_messages_received(&received, &[&node_B_id, &node_C_id]), 1); } - let received = filter_received(collect_stream!(node_B_messaging, timeout = Duration::from_secs(20))); + let received = filter_received(collect_try_recv!(node_B_messaging, timeout = Duration::from_secs(20))); let recv_count = count_messages_received(&received, &[&node_A_id, &node_C_id]); // Expected race condition: If A->B->C before A->C then C->B does not happen assert!((1..=2).contains(&recv_count)); - let received = filter_received(collect_stream!(node_C_messaging, timeout = Duration::from_secs(20))); + let received = filter_received(collect_try_recv!(node_C_messaging, timeout = Duration::from_secs(20))); let recv_count = count_messages_received(&received, &[&node_A_id, &node_B_id]); assert_eq!(recv_count, 2); assert_eq!(count_messages_received(&received, &[&node_D_id]), 0); - let received = filter_received(collect_stream!(node_D_messaging, timeout = Duration::from_secs(20))); + let received = filter_received(collect_try_recv!(node_D_messaging, timeout = Duration::from_secs(20))); assert_eq!(received.len(), 1); assert_eq!(count_messages_received(&received, &[&node_C_id]), 1); } -#[tokio_macros::test] +#[tokio::test] #[allow(non_snake_case)] async fn dht_propagate_message_contents_not_malleable_ban() { let node_C = make_node(PeerFeatures::COMMUNICATION_NODE, None).await; @@ -613,9 +615,9 @@ async fn dht_propagate_message_contents_not_malleable_ban() { let node_B_node_id = node_B.node_identity().node_id().clone(); // Node C should ban node B - let banned_node_id = streams::assert_in_stream( + let banned_node_id = streams::assert_in_broadcast( &mut connectivity_events, - |r| match &*r.unwrap() { + |r| match &*r { ConnectivityEvent::PeerBanned(node_id) => Some(node_id.clone()), _ => None, }, @@ -629,12 +631,9 @@ async fn dht_propagate_message_contents_not_malleable_ban() { node_C.shutdown().await; } -fn filter_received( - events: Vec<Result<Arc<MessagingEvent>, tokio::sync::broadcast::RecvError>>, -) -> Vec<Arc<MessagingEvent>> { +fn filter_received(events: Vec<Arc<MessagingEvent>>) -> Vec<Arc<MessagingEvent>> { events .into_iter() - .map(Result::unwrap) .filter(|e| match &**e { MessagingEvent::MessageReceived(_, _) => true, _ => unreachable!(), diff --git a/comms/examples/stress/error.rs b/comms/examples/stress/error.rs index 5cb9be1cb3..e87aae514e 100644 --- a/comms/examples/stress/error.rs +++ b/comms/examples/stress/error.rs @@ -19,10 +19,10 @@ // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -use futures::channel::{mpsc::SendError, oneshot}; use std::io; use tari_comms::{ connectivity::ConnectivityError, + message::OutboundMessage, peer_manager::PeerManagerError, tor, CommsBuilderError, @@ -30,7 +30,11 @@ use tari_comms::{ }; use tari_crypto::tari_utilities::message_format::MessageFormatError; use thiserror::Error; -use tokio::{task, time}; +use tokio::{ + sync::{mpsc::error::SendError, oneshot}, + task, + time, +}; #[derive(Debug, Error)] pub enum Error { @@ -48,12 +52,12 @@ pub enum Error { ConnectivityError(#[from] ConnectivityError), #[error("Message format error: {0}")] MessageFormatError(#[from] MessageFormatError), - #[error("Failed to send message")] - SendError(#[from] SendError), + #[error("Failed to send message: {0}")] + SendError(#[from] SendError<OutboundMessage>), #[error("JoinError: {0}")] JoinError(#[from] task::JoinError), #[error("Example did not exit cleanly: `{0}`")] - WaitTimeout(#[from] time::Elapsed), + WaitTimeout(#[from] time::error::Elapsed), #[error("IO error: {0}")] Io(#[from] io::Error), #[error("User quit")] @@ -63,5 +67,5 @@ pub enum Error { #[error("Unexpected EoF")] UnexpectedEof, #[error("Internal reply canceled")] - ReplyCanceled(#[from] oneshot::Canceled), + ReplyCanceled(#[from] oneshot::error::RecvError), } diff --git a/comms/examples/stress/node.rs b/comms/examples/stress/node.rs index 45ad0919ab..d060d18071 100644 --- a/comms/examples/stress/node.rs +++ b/comms/examples/stress/node.rs @@ -21,7 +21,6 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use super::{error::Error, STRESS_PROTOCOL_NAME, TOR_CONTROL_PORT_ADDR, TOR_SOCKS_ADDR}; -use futures::channel::mpsc; use rand::rngs::OsRng; use std::{convert, net::Ipv4Addr, path::Path, sync::Arc, time::Duration}; use tari_comms::{ @@ -43,7 +42,7 @@ use tari_storage::{ lmdb_store::{LMDBBuilder, LMDBConfig}, LMDBWrapper, }; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc}; pub async fn create( node_identity: Option<Arc<NodeIdentity>>, diff --git a/comms/examples/stress/service.rs b/comms/examples/stress/service.rs index 3e262cc38b..45e2bc0fd3 100644 --- a/comms/examples/stress/service.rs +++ b/comms/examples/stress/service.rs @@ -23,15 +23,7 @@ use super::error::Error; use crate::stress::{MAX_FRAME_SIZE, STRESS_PROTOCOL_NAME}; use bytes::{Buf, Bytes, BytesMut}; -use futures::{ - channel::{mpsc, oneshot}, - stream, - stream::Fuse, - AsyncReadExt, - AsyncWriteExt, - SinkExt, - StreamExt, -}; +use futures::{stream, SinkExt, StreamExt}; use rand::{rngs::OsRng, RngCore}; use std::{ iter::repeat_with, @@ -43,12 +35,19 @@ use tari_comms::{ message::{InboundMessage, OutboundMessage}, peer_manager::{NodeId, Peer}, protocol::{ProtocolEvent, ProtocolNotification}, + utils, CommsNode, PeerConnection, Substream, }; use tari_crypto::tari_utilities::hex::Hex; -use tokio::{sync::RwLock, task, task::JoinHandle, time}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + sync::{mpsc, oneshot, RwLock}, + task, + task::JoinHandle, + time, +}; pub fn start_service( comms_node: CommsNode, @@ -70,6 +69,7 @@ pub fn start_service( (task::spawn(service.start()), request_tx) } +#[derive(Debug)] pub enum StressTestServiceRequest { BeginProtocol(Peer, StressProtocol, oneshot::Sender<Result<(), Error>>), Shutdown, @@ -135,9 +135,9 @@ impl StressProtocol { } struct StressTestService { - request_rx: Fuse<mpsc::Receiver<StressTestServiceRequest>>, + request_rx: mpsc::Receiver<StressTestServiceRequest>, comms_node: CommsNode, - protocol_notif: Fuse<mpsc::Receiver<ProtocolNotification<Substream>>>, + protocol_notif: mpsc::Receiver<ProtocolNotification<Substream>>, shutdown: bool, inbound_rx: Arc<RwLock<mpsc::Receiver<InboundMessage>>>, @@ -153,9 +153,9 @@ impl StressTestService { outbound_tx: mpsc::Sender<OutboundMessage>, ) -> Self { Self { - request_rx: request_rx.fuse(), + request_rx, comms_node, - protocol_notif: protocol_notif.fuse(), + protocol_notif, shutdown: false, inbound_rx: Arc::new(RwLock::new(inbound_rx)), outbound_tx, @@ -163,23 +163,21 @@ impl StressTestService { } async fn start(mut self) -> Result<(), Error> { - let mut events = self.comms_node.subscribe_connectivity_events().fuse(); + let mut events = self.comms_node.subscribe_connectivity_events(); loop { - futures::select! { - event = events.select_next_some() => { - if let Ok(event) = event { - println!("{}", event); - } + tokio::select! { + Ok(event) = events.recv() => { + println!("{}", event); }, - request = self.request_rx.select_next_some() => { + Some(request) = self.request_rx.recv() => { if let Err(err) = self.handle_request(request).await { println!("Error: {}", err); } }, - notif = self.protocol_notif.select_next_some() => { + Some(notif) = self.protocol_notif.recv() => { self.handle_protocol_notification(notif).await; }, } @@ -431,7 +429,7 @@ async fn messaging_flood( peer: NodeId, protocol: StressProtocol, inbound_rx: Arc<RwLock<mpsc::Receiver<InboundMessage>>>, - mut outbound_tx: mpsc::Sender<OutboundMessage>, + outbound_tx: mpsc::Sender<OutboundMessage>, ) -> Result<(), Error> { let start = Instant::now(); let mut counter = 1u32; @@ -441,18 +439,15 @@ async fn messaging_flood( protocol.num_messages * protocol.message_size / 1024 / 1024 ); let outbound_task = task::spawn(async move { - let mut iter = stream::iter( - repeat_with(|| { - counter += 1; - - println!("Send MSG {}", counter); - OutboundMessage::new(peer.clone(), generate_message(counter, protocol.message_size as usize)) - }) - .take(protocol.num_messages as usize) - .map(Ok), - ); - outbound_tx.send_all(&mut iter).await?; - time::delay_for(Duration::from_secs(5)).await; + let iter = repeat_with(|| { + counter += 1; + + println!("Send MSG {}", counter); + OutboundMessage::new(peer.clone(), generate_message(counter, protocol.message_size as usize)) + }) + .take(protocol.num_messages as usize); + utils::mpsc::send_all(&outbound_tx, iter).await?; + time::sleep(Duration::from_secs(5)).await; outbound_tx .send(OutboundMessage::new(peer.clone(), Bytes::from_static(&[0u8; 4]))) .await?; @@ -462,7 +457,7 @@ async fn messaging_flood( let inbound_task = task::spawn(async move { let mut inbound_rx = inbound_rx.write().await; let mut msgs = vec![]; - while let Some(msg) = inbound_rx.next().await { + while let Some(msg) = inbound_rx.recv().await { let msg_id = decode_msg(msg.body); println!("GOT MSG {}", msg_id); if msgs.len() == protocol.num_messages as usize { @@ -497,6 +492,6 @@ fn generate_message(n: u32, size: usize) -> Bytes { fn decode_msg<T: prost::bytes::Buf>(msg: T) -> u32 { let mut buf = [0u8; 4]; - msg.bytes().copy_to_slice(&mut buf); + msg.chunk().copy_to_slice(&mut buf); u32::from_be_bytes(buf) } diff --git a/comms/examples/stress_test.rs b/comms/examples/stress_test.rs index 71c185aa3e..3a0c04f020 100644 --- a/comms/examples/stress_test.rs +++ b/comms/examples/stress_test.rs @@ -24,13 +24,13 @@ mod stress; use stress::{error::Error, prompt::user_prompt}; use crate::stress::{node, prompt::parse_from_short_str, service, service::StressTestServiceRequest}; -use futures::{channel::oneshot, future, future::Either, SinkExt}; +use futures::{future, future::Either}; use std::{env, net::Ipv4Addr, path::Path, process, sync::Arc, time::Duration}; use tari_crypto::tari_utilities::message_format::MessageFormat; use tempfile::Builder; -use tokio::time; +use tokio::{sync::oneshot, time}; -#[tokio_macros::main] +#[tokio::main] async fn main() { env_logger::init(); match run().await { @@ -99,7 +99,7 @@ async fn run() -> Result<(), Error> { } println!("Stress test service started!"); - let (handle, mut requester) = service::start_service(comms_node, protocol_notif, inbound_rx, outbound_tx); + let (handle, requester) = service::start_service(comms_node, protocol_notif, inbound_rx, outbound_tx); let mut last_peer = peer.as_ref().and_then(parse_from_short_str); diff --git a/comms/examples/tor.rs b/comms/examples/tor.rs index 734ef7718a..9186c69d43 100644 --- a/comms/examples/tor.rs +++ b/comms/examples/tor.rs @@ -1,7 +1,6 @@ use anyhow::anyhow; use bytes::Bytes; use chrono::Utc; -use futures::{channel::mpsc, SinkExt, StreamExt}; use rand::{rngs::OsRng, thread_rng, RngCore}; use std::{collections::HashMap, convert::identity, env, net::SocketAddr, path::Path, process, sync::Arc}; use tari_comms::{ @@ -21,7 +20,10 @@ use tari_storage::{ LMDBWrapper, }; use tempfile::Builder; -use tokio::{runtime, sync::broadcast}; +use tokio::{ + runtime, + sync::{broadcast, mpsc}, +}; // Tor example for tari_comms. // @@ -29,7 +31,7 @@ use tokio::{runtime, sync::broadcast}; type Error = anyhow::Error; -#[tokio_macros::main] +#[tokio::main] async fn main() { env_logger::init(); if let Err(err) = run().await { @@ -56,7 +58,7 @@ async fn run() -> Result<(), Error> { println!("Starting comms nodes...",); let temp_dir1 = Builder::new().prefix("tor-example1").tempdir().unwrap(); - let (comms_node1, inbound_rx1, mut outbound_tx1) = setup_node_with_tor( + let (comms_node1, inbound_rx1, outbound_tx1) = setup_node_with_tor( control_port_addr.clone(), temp_dir1.as_ref(), (9098u16, "127.0.0.1:0".parse::<SocketAddr>().unwrap()), @@ -208,11 +210,11 @@ async fn setup_node_with_tor<P: Into<tor::PortMapping>>( async fn start_ping_ponger( dest_node_id: NodeId, mut inbound_rx: mpsc::Receiver<InboundMessage>, - mut outbound_tx: mpsc::Sender<OutboundMessage>, + outbound_tx: mpsc::Sender<OutboundMessage>, ) -> Result<usize, Error> { let mut inflight_pings = HashMap::new(); let mut counter = 0; - while let Some(msg) = inbound_rx.next().await { + while let Some(msg) = inbound_rx.recv().await { counter += 1; let msg_str = String::from_utf8_lossy(&msg.body); diff --git a/comms/rpc_macros/Cargo.toml b/comms/rpc_macros/Cargo.toml index 3680ed81f9..5bc185bc90 100644 --- a/comms/rpc_macros/Cargo.toml +++ b/comms/rpc_macros/Cargo.toml @@ -13,16 +13,16 @@ edition = "2018" proc-macro = true [dependencies] +tari_comms = { version = "^0.9", path = "../", features = ["rpc"]} + proc-macro2 = "1.0.24" quote = "1.0.7" syn = {version = "1.0.38", features = ["fold"]} -tari_comms = { version = "^0.9", path = "../", features = ["rpc"]} [dev-dependencies] tari_test_utils = {version="^0.9", path="../../infrastructure/test_utils"} futures = "0.3.5" -prost = "0.6.1" -tokio = "0.2.22" -tokio-macros = "0.2.5" +prost = "0.8.0" +tokio = {version = "1", features = ["macros"]} tower-service = "0.3.0" diff --git a/comms/rpc_macros/src/generator.rs b/comms/rpc_macros/src/generator.rs index f3e8cbffd1..5f44066f19 100644 --- a/comms/rpc_macros/src/generator.rs +++ b/comms/rpc_macros/src/generator.rs @@ -215,8 +215,8 @@ impl RpcCodeGenerator { self.inner.ping().await } - pub fn close(&mut self) { - self.inner.close(); + pub async fn close(&mut self) { + self.inner.close().await; } }; diff --git a/comms/rpc_macros/tests/macro.rs b/comms/rpc_macros/tests/macro.rs index f3f05b4481..9dec9ff5ff 100644 --- a/comms/rpc_macros/tests/macro.rs +++ b/comms/rpc_macros/tests/macro.rs @@ -101,7 +101,7 @@ fn it_sets_the_protocol_name() { assert_eq!(TestClient::PROTOCOL_NAME, b"/test/protocol/123"); } -#[tokio_macros::test] +#[tokio::test] async fn it_returns_the_correct_type() { let mut server = TestServer::new(TestService::default()); let resp = server @@ -112,7 +112,7 @@ async fn it_returns_the_correct_type() { assert_eq!(u32::decode(v).unwrap(), 12); } -#[tokio_macros::test] +#[tokio::test] async fn it_correctly_maps_the_method_nums() { let service = TestService::default(); let spy = service.state.clone(); @@ -135,7 +135,7 @@ async fn it_correctly_maps_the_method_nums() { assert_eq!(*spy.read().await.get("unit").unwrap(), 1); } -#[tokio_macros::test] +#[tokio::test] async fn it_returns_an_error_for_invalid_method_nums() { let service = TestService::default(); let mut server = TestServer::new(service); @@ -147,7 +147,7 @@ async fn it_returns_an_error_for_invalid_method_nums() { unpack_enum!(RpcStatusCode::UnsupportedMethod = err.status_code()); } -#[tokio_macros::test] +#[tokio::test] async fn it_generates_client_calls() { let (sock_client, sock_server) = MemorySocket::new_pair(); let client = task::spawn(TestClient::connect(framing::canonical(sock_client, 1024))); diff --git a/comms/src/bounded_executor.rs b/comms/src/bounded_executor.rs index f82b4a9f84..e9f0a82d5c 100644 --- a/comms/src/bounded_executor.rs +++ b/comms/src/bounded_executor.rs @@ -143,7 +143,9 @@ impl BoundedExecutor { F: Future + Send + 'static, F::Output: Send + 'static, { - let permit = self.semaphore.clone().acquire_owned().await; + // SAFETY: acquire_owned only fails if the semaphore is closed (i.e self.semaphore.close() is called) - this + // never happens in this implementation + let permit = self.semaphore.clone().acquire_owned().await.expect("semaphore closed"); self.do_spawn(permit, future) } @@ -227,9 +229,9 @@ mod test { }, time::Duration, }; - use tokio::time::delay_for; + use tokio::time::sleep; - #[runtime::test_basic] + #[runtime::test] async fn spawn() { let flag = Arc::new(AtomicBool::new(false)); let flag_cloned = flag.clone(); @@ -238,7 +240,7 @@ mod test { // Spawn 1 let task1_fut = executor .spawn(async move { - delay_for(Duration::from_millis(1)).await; + sleep(Duration::from_millis(1)).await; flag_cloned.store(true, Ordering::SeqCst); }) .await; diff --git a/comms/src/builder/comms_node.rs b/comms/src/builder/comms_node.rs index 24c856f5ba..abd71e8952 100644 --- a/comms/src/builder/comms_node.rs +++ b/comms/src/builder/comms_node.rs @@ -46,11 +46,13 @@ use crate::{ CommsBuilder, Substream, }; -use futures::{channel::mpsc, AsyncRead, AsyncWrite}; use log::*; use std::{iter, sync::Arc}; use tari_shutdown::ShutdownSignal; -use tokio::sync::broadcast; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{broadcast, mpsc}, +}; const LOG_TARGET: &str = "comms::node"; diff --git a/comms/src/builder/mod.rs b/comms/src/builder/mod.rs index f5acc151a9..5da0d51793 100644 --- a/comms/src/builder/mod.rs +++ b/comms/src/builder/mod.rs @@ -51,10 +51,9 @@ use crate::{ tor, types::CommsDatabase, }; -use futures::channel::mpsc; use std::{fs::File, sync::Arc}; use tari_shutdown::ShutdownSignal; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc}; /// The `CommsBuilder` provides a simple builder API for getting Tari comms p2p messaging up and running. pub struct CommsBuilder { diff --git a/comms/src/builder/tests.rs b/comms/src/builder/tests.rs index d1ae9a0f9a..d4fe8cca97 100644 --- a/comms/src/builder/tests.rs +++ b/comms/src/builder/tests.rs @@ -42,19 +42,16 @@ use crate::{ CommsNode, }; use bytes::Bytes; -use futures::{ - channel::{mpsc, oneshot}, - stream::FuturesUnordered, - AsyncReadExt, - AsyncWriteExt, - SinkExt, - StreamExt, -}; +use futures::stream::FuturesUnordered; use std::{collections::HashSet, convert::identity, hash::Hash, time::Duration}; use tari_shutdown::{Shutdown, ShutdownSignal}; use tari_storage::HashmapDatabase; -use tari_test_utils::{collect_stream, unpack_enum}; -use tokio::{sync::broadcast, task}; +use tari_test_utils::{collect_recv, collect_stream, unpack_enum}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + sync::{broadcast, mpsc, oneshot}, + task, +}; async fn spawn_node( protocols: Protocols<Substream>, @@ -109,7 +106,7 @@ async fn spawn_node( (comms_node, inbound_rx, outbound_tx, messaging_events_sender) } -#[runtime::test_basic] +#[runtime::test] async fn peer_to_peer_custom_protocols() { static TEST_PROTOCOL: Bytes = Bytes::from_static(b"/tari/test"); static ANOTHER_TEST_PROTOCOL: Bytes = Bytes::from_static(b"/tari/test-again"); @@ -161,9 +158,9 @@ async fn peer_to_peer_custom_protocols() { // Check that both nodes get the PeerConnected event. We subscribe after the nodes are initialized // so we miss those events. - let next_event = conn_man_events2.next().await.unwrap().unwrap(); + let next_event = conn_man_events2.recv().await.unwrap(); unpack_enum!(ConnectionManagerEvent::PeerConnected(conn2) = &*next_event); - let next_event = conn_man_events1.next().await.unwrap().unwrap(); + let next_event = conn_man_events1.recv().await.unwrap(); unpack_enum!(ConnectionManagerEvent::PeerConnected(_conn) = &*next_event); // Let's speak both our test protocols @@ -176,7 +173,7 @@ async fn peer_to_peer_custom_protocols() { negotiated_substream2.stream.write_all(ANOTHER_TEST_MSG).await.unwrap(); // Read TEST_PROTOCOL message to node 2 from node 1 - let negotiation = test_protocol_rx2.next().await.unwrap(); + let negotiation = test_protocol_rx2.recv().await.unwrap(); assert_eq!(negotiation.protocol, TEST_PROTOCOL); unpack_enum!(ProtocolEvent::NewInboundSubstream(node_id, substream) = negotiation.event); assert_eq!(&node_id, node_identity1.node_id()); @@ -185,7 +182,7 @@ async fn peer_to_peer_custom_protocols() { assert_eq!(buf, TEST_MSG); // Read ANOTHER_TEST_PROTOCOL message to node 1 from node 2 - let negotiation = another_test_protocol_rx1.next().await.unwrap(); + let negotiation = another_test_protocol_rx1.recv().await.unwrap(); assert_eq!(negotiation.protocol, ANOTHER_TEST_PROTOCOL); unpack_enum!(ProtocolEvent::NewInboundSubstream(node_id, substream) = negotiation.event); assert_eq!(&node_id, node_identity2.node_id()); @@ -193,18 +190,18 @@ async fn peer_to_peer_custom_protocols() { substream.read_exact(&mut buf).await.unwrap(); assert_eq!(buf, ANOTHER_TEST_MSG); - shutdown.trigger().unwrap(); + shutdown.trigger(); comms_node1.wait_until_shutdown().await; comms_node2.wait_until_shutdown().await; } -#[runtime::test_basic] +#[runtime::test] async fn peer_to_peer_messaging() { const NUM_MSGS: usize = 100; let shutdown = Shutdown::new(); - let (comms_node1, mut inbound_rx1, mut outbound_tx1, _) = spawn_node(Protocols::new(), shutdown.to_signal()).await; - let (comms_node2, mut inbound_rx2, mut outbound_tx2, messaging_events2) = + let (comms_node1, mut inbound_rx1, outbound_tx1, _) = spawn_node(Protocols::new(), shutdown.to_signal()).await; + let (comms_node2, mut inbound_rx2, outbound_tx2, messaging_events2) = spawn_node(Protocols::new(), shutdown.to_signal()).await; let mut messaging_events2 = messaging_events2.subscribe(); @@ -238,14 +235,14 @@ async fn peer_to_peer_messaging() { outbound_tx1.send(outbound_msg).await.unwrap(); } - let messages1_to_2 = collect_stream!(inbound_rx2, take = NUM_MSGS, timeout = Duration::from_secs(10)); + let messages1_to_2 = collect_recv!(inbound_rx2, take = NUM_MSGS, timeout = Duration::from_secs(10)); let send_results = collect_stream!(replies, take = NUM_MSGS, timeout = Duration::from_secs(10)); send_results.into_iter().for_each(|r| { r.unwrap().unwrap(); }); - let events = collect_stream!(messaging_events2, take = NUM_MSGS, timeout = Duration::from_secs(10)); - events.into_iter().map(Result::unwrap).for_each(|m| { + let events = collect_recv!(messaging_events2, take = NUM_MSGS, timeout = Duration::from_secs(10)); + events.into_iter().for_each(|m| { unpack_enum!(MessagingEvent::MessageReceived(_n, _t) = &*m); }); @@ -258,7 +255,7 @@ async fn peer_to_peer_messaging() { outbound_tx2.send(outbound_msg).await.unwrap(); } - let messages2_to_1 = collect_stream!(inbound_rx1, take = NUM_MSGS, timeout = Duration::from_secs(10)); + let messages2_to_1 = collect_recv!(inbound_rx1, take = NUM_MSGS, timeout = Duration::from_secs(10)); // Check that we got all the messages let check_messages = |msgs: Vec<InboundMessage>| { @@ -279,13 +276,13 @@ async fn peer_to_peer_messaging() { comms_node2.wait_until_shutdown().await; } -#[runtime::test_basic] +#[runtime::test] async fn peer_to_peer_messaging_simultaneous() { const NUM_MSGS: usize = 10; let shutdown = Shutdown::new(); - let (comms_node1, mut inbound_rx1, mut outbound_tx1, _) = spawn_node(Protocols::new(), shutdown.to_signal()).await; - let (comms_node2, mut inbound_rx2, mut outbound_tx2, _) = spawn_node(Protocols::new(), shutdown.to_signal()).await; + let (comms_node1, mut inbound_rx1, outbound_tx1, _) = spawn_node(Protocols::new(), shutdown.to_signal()).await; + let (comms_node2, mut inbound_rx2, outbound_tx2, _) = spawn_node(Protocols::new(), shutdown.to_signal()).await; log::info!( "Peer1 = `{}`, Peer2 = `{}`", @@ -350,8 +347,8 @@ async fn peer_to_peer_messaging_simultaneous() { handle2.await.unwrap(); // Tasks are finished, let's see if all the messages made it though - let messages1_to_2 = collect_stream!(inbound_rx2, take = NUM_MSGS, timeout = Duration::from_secs(10)); - let messages2_to_1 = collect_stream!(inbound_rx1, take = NUM_MSGS, timeout = Duration::from_secs(10)); + let messages1_to_2 = collect_recv!(inbound_rx2, take = NUM_MSGS, timeout = Duration::from_secs(10)); + let messages2_to_1 = collect_recv!(inbound_rx1, take = NUM_MSGS, timeout = Duration::from_secs(10)); assert!(has_unique_elements(messages1_to_2.into_iter().map(|m| m.body))); assert!(has_unique_elements(messages2_to_1.into_iter().map(|m| m.body))); diff --git a/comms/src/common/rate_limit.rs b/comms/src/common/rate_limit.rs index 1705397d8e..5cbcb964b6 100644 --- a/comms/src/common/rate_limit.rs +++ b/comms/src/common/rate_limit.rs @@ -26,7 +26,7 @@ // This is slightly changed from the libra rate limiter implementation -use futures::{stream::Fuse, FutureExt, Stream, StreamExt}; +use futures::FutureExt; use pin_project::pin_project; use std::{ future::Future, @@ -36,10 +36,11 @@ use std::{ time::Duration, }; use tokio::{ - sync::{OwnedSemaphorePermit, Semaphore}, + sync::{AcquireError, OwnedSemaphorePermit, Semaphore}, time, time::Interval, }; +use tokio_stream::Stream; pub trait RateLimit: Stream { /// Consumes the stream and returns a rate-limited stream that only polls the underlying stream @@ -60,12 +61,12 @@ pub struct RateLimiter<T> { stream: T, /// An interval stream that "restocks" the permits #[pin] - interval: Fuse<Interval>, + interval: Interval, /// The maximum permits to issue capacity: usize, /// A semaphore that holds the permits permits: Arc<Semaphore>, - permit_future: Option<Pin<Box<dyn Future<Output = OwnedSemaphorePermit> + Send>>>, + permit_future: Option<Pin<Box<dyn Future<Output = Result<OwnedSemaphorePermit, AcquireError>> + Send>>>, permit_acquired: bool, } @@ -75,7 +76,7 @@ impl<T: Stream> RateLimiter<T> { stream, capacity, - interval: time::interval(restock_interval).fuse(), + interval: time::interval(restock_interval), // `interval` starts immediately, so we can start with zero permits permits: Arc::new(Semaphore::new(0)), permit_future: None, @@ -89,7 +90,7 @@ impl<T: Stream> Stream for RateLimiter<T> { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { // "Restock" permits once interval is ready - if let Poll::Ready(Some(_)) = self.as_mut().project().interval.poll_next(cx) { + if let Poll::Ready(_) = self.as_mut().project().interval.poll_tick(cx) { self.permits .add_permits(self.capacity - self.permits.available_permits()); } @@ -103,6 +104,8 @@ impl<T: Stream> Stream for RateLimiter<T> { } // Wait until a permit is acquired + // `unwrap()` is safe because acquire_owned only panics if the semaphore has closed, but we never close it + // for the lifetime of this instance let permit = futures::ready!(self .as_mut() .project() @@ -110,7 +113,8 @@ impl<T: Stream> Stream for RateLimiter<T> { .as_mut() .unwrap() .as_mut() - .poll(cx)); + .poll(cx)) + .unwrap(); // Don't release the permit on drop, `interval` will restock permits permit.forget(); let this = self.as_mut().project(); @@ -130,45 +134,54 @@ impl<T: Stream> Stream for RateLimiter<T> { mod test { use super::*; use crate::runtime; - use futures::{future::Either, stream}; + use futures::{stream, StreamExt}; - #[runtime::test_basic] + #[runtime::test] async fn rate_limit() { let repeater = stream::repeat(()); - let mut rate_limited = repeater.rate_limit(10, Duration::from_secs(100)).fuse(); + let mut rate_limited = repeater.rate_limit(10, Duration::from_secs(100)); - let mut timeout = time::delay_for(Duration::from_millis(50)).fuse(); + let timeout = time::sleep(Duration::from_millis(50)); + tokio::pin!(timeout); let mut count = 0usize; loop { - let either = futures::future::select(rate_limited.select_next_some(), timeout).await; - match either { - Either::Left((_, to)) => { + let item = tokio::select! { + biased; + _ = &mut timeout => None, + item = rate_limited.next() => item, + }; + + match item { + Some(_) => { count += 1; - timeout = to; }, - Either::Right(_) => break, + None => break, } } assert_eq!(count, 10); } - #[runtime::test_basic] + #[runtime::test] async fn rate_limit_restock() { let repeater = stream::repeat(()); - let mut rate_limited = repeater.rate_limit(10, Duration::from_millis(10)).fuse(); + let mut rate_limited = repeater.rate_limit(10, Duration::from_millis(10)); - let mut timeout = time::delay_for(Duration::from_millis(50)).fuse(); + let timeout = time::sleep(Duration::from_millis(50)); + tokio::pin!(timeout); let mut count = 0usize; loop { - let either = futures::future::select(rate_limited.select_next_some(), timeout).await; - match either { - Either::Left((_, to)) => { + let item = tokio::select! { + biased; + _ = &mut timeout => None, + item = rate_limited.next() => item, + }; + match item { + Some(_) => { count += 1; - timeout = to; }, - Either::Right(_) => break, + None => break, } } // Test that at least 1 restock happens. diff --git a/comms/src/compat.rs b/comms/src/compat.rs index 254876f14d..67b53c7c91 100644 --- a/comms/src/compat.rs +++ b/comms/src/compat.rs @@ -27,8 +27,9 @@ use std::{ io, pin::Pin, - task::{self, Poll}, + task::{self, Context, Poll}, }; +use tokio::io::ReadBuf; /// `IoCompat` provides a compatibility shim between the `AsyncRead`/`AsyncWrite` traits provided by /// the `futures` library and those provided by the `tokio` library since they are different and @@ -47,16 +48,16 @@ impl<T> IoCompat<T> { impl<T> tokio::io::AsyncRead for IoCompat<T> where T: futures::io::AsyncRead + Unpin { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context, buf: &mut [u8]) -> Poll<io::Result<usize>> { - futures::io::AsyncRead::poll_read(Pin::new(&mut self.inner), cx, buf) + fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<Result<usize>> { + futures::io::AsyncRead::poll_read(Pin::new(&mut self.inner), cx, buf.filled_mut()) } } impl<T> futures::io::AsyncRead for IoCompat<T> where T: tokio::io::AsyncRead + Unpin { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context, buf: &mut [u8]) -> Poll<io::Result<usize>> { - tokio::io::AsyncRead::poll_read(Pin::new(&mut self.inner), cx, buf) + fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context, buf: &mut [u8]) -> Poll<io::Result<()>> { + tokio::io::AsyncRead::poll_read(Pin::new(&mut self.inner), cx, &mut ReadBuf::new(buf)) } } diff --git a/comms/src/connection_manager/dial_state.rs b/comms/src/connection_manager/dial_state.rs index 0b07378747..07bbd1e631 100644 --- a/comms/src/connection_manager/dial_state.rs +++ b/comms/src/connection_manager/dial_state.rs @@ -24,8 +24,8 @@ use crate::{ connection_manager::{error::ConnectionManagerError, peer_connection::PeerConnection}, peer_manager::Peer, }; -use futures::channel::oneshot; use tari_shutdown::ShutdownSignal; +use tokio::sync::oneshot; /// The state of the dial request pub struct DialState { diff --git a/comms/src/connection_manager/dialer.rs b/comms/src/connection_manager/dialer.rs index b8fffb3029..f2b86bae38 100644 --- a/comms/src/connection_manager/dialer.rs +++ b/comms/src/connection_manager/dialer.rs @@ -39,22 +39,22 @@ use crate::{ types::CommsPublicKey, }; use futures::{ - channel::{mpsc, oneshot}, future, future::{BoxFuture, Either, FusedFuture}, pin_mut, - stream::{Fuse, FuturesUnordered}, - AsyncRead, - AsyncWrite, - AsyncWriteExt, + stream::FuturesUnordered, FutureExt, - SinkExt, - StreamExt, }; use log::*; use std::{collections::HashMap, sync::Arc, time::Duration}; use tari_shutdown::{Shutdown, ShutdownSignal}; -use tokio::{task::JoinHandle, time}; +use tokio::{ + io::{AsyncRead, AsyncWrite, AsyncWriteExt}, + sync::{mpsc, oneshot}, + task::JoinHandle, + time, +}; +use tokio_stream::StreamExt; const LOG_TARGET: &str = "comms::connection_manager::dialer"; @@ -78,7 +78,7 @@ pub struct Dialer<TTransport, TBackoff> { transport: TTransport, noise_config: NoiseConfig, backoff: Arc<TBackoff>, - request_rx: Fuse<mpsc::Receiver<DialerRequest>>, + request_rx: mpsc::Receiver<DialerRequest>, cancel_signals: HashMap<NodeId, Shutdown>, conn_man_notifier: mpsc::Sender<ConnectionManagerEvent>, shutdown: Option<ShutdownSignal>, @@ -111,7 +111,7 @@ where transport, noise_config, backoff: Arc::new(backoff), - request_rx: request_rx.fuse(), + request_rx, cancel_signals: Default::default(), conn_man_notifier, shutdown: Some(shutdown), @@ -138,16 +138,20 @@ where .expect("Establisher initialized without a shutdown"); debug!(target: LOG_TARGET, "Connection dialer started"); loop { - futures::select! { - request = self.request_rx.select_next_some() => self.handle_request(&mut pending_dials, request), - (dial_state, dial_result) = pending_dials.select_next_some() => { - self.handle_dial_result(dial_state, dial_result).await; - } - _ = shutdown => { + tokio::select! { + // Biased ordering is used because we already have the futures polled here in a fair order, and so wish to + // forgo the minor cost of the random ordering + biased; + + _ = &mut shutdown => { info!(target: LOG_TARGET, "Connection dialer shutting down because the shutdown signal was received"); self.cancel_all_dials(); break; } + Some((dial_state, dial_result)) = pending_dials.next() => { + self.handle_dial_result(dial_state, dial_result).await; + } + Some(request) = self.request_rx.recv() => self.handle_request(&mut pending_dials, request), } } } @@ -178,12 +182,7 @@ where self.cancel_signals.len() ); self.cancel_signals.drain().for_each(|(_, mut signal)| { - log_if_error_fmt!( - level: warn, - target: LOG_TARGET, - signal.trigger(), - "Shutdown trigger failed", - ); + signal.trigger(); }) } @@ -347,7 +346,6 @@ where cancel_signal: ShutdownSignal, ) -> Result<PeerConnection, ConnectionManagerError> { static CONNECTION_DIRECTION: ConnectionDirection = ConnectionDirection::Outbound; - let mut muxer = Yamux::upgrade_connection(socket, CONNECTION_DIRECTION) .await .map_err(|err| ConnectionManagerError::YamuxUpgradeFailure(err.to_string()))?; @@ -442,9 +440,9 @@ where current_state.peer.node_id.short_str(), backoff_duration.as_secs() ); - let mut delay = time::delay_for(backoff_duration).fuse(); - let mut cancel_signal = current_state.get_cancel_signal(); - futures::select! { + let delay = time::sleep(backoff_duration).fuse(); + let cancel_signal = current_state.get_cancel_signal(); + tokio::select! { _ = delay => { debug!(target: LOG_TARGET, "[Attempt {}] Connecting to peer '{}'", current_state.num_attempts(), current_state.peer.node_id.short_str()); match Self::dial_peer(current_state, &noise_config, ¤t_transport, config.network_info.network_byte).await { @@ -538,18 +536,13 @@ where // Try the next address continue; }, - Either::Right((cancel_result, _)) => { + // Canceled + Either::Right(_) => { debug!( target: LOG_TARGET, "Dial for peer '{}' cancelled", dial_state.peer.node_id.short_str() ); - log_if_error!( - level: warn, - target: LOG_TARGET, - cancel_result, - "Cancel channel error during dial: {}", - ); Err(ConnectionManagerError::DialCancelled) }, } diff --git a/comms/src/connection_manager/error.rs b/comms/src/connection_manager/error.rs index f7bbbaa564..5645a57e62 100644 --- a/comms/src/connection_manager/error.rs +++ b/comms/src/connection_manager/error.rs @@ -21,13 +21,13 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{ + connection_manager::PeerConnectionRequest, noise, peer_manager::PeerManagerError, protocol::{IdentityProtocolError, ProtocolError}, }; -use futures::channel::mpsc; use thiserror::Error; -use tokio::{time, time::Elapsed}; +use tokio::{sync::mpsc, time::error::Elapsed}; #[derive(Debug, Error, Clone)] pub enum ConnectionManagerError { @@ -108,14 +108,14 @@ pub enum PeerConnectionError { #[error("Internal oneshot reply channel was unexpectedly cancelled")] InternalReplyCancelled, #[error("Failed to send internal request: {0}")] - InternalRequestSendFailed(#[from] mpsc::SendError), + InternalRequestSendFailed(#[from] mpsc::error::SendError<PeerConnectionRequest>), #[error("Protocol error: {0}")] ProtocolError(#[from] ProtocolError), #[error("Protocol negotiation timeout")] ProtocolNegotiationTimeout, } -impl From<time::Elapsed> for PeerConnectionError { +impl From<Elapsed> for PeerConnectionError { fn from(_: Elapsed) -> Self { PeerConnectionError::ProtocolNegotiationTimeout } diff --git a/comms/src/connection_manager/listener.rs b/comms/src/connection_manager/listener.rs index 5984ffcca2..16e8042425 100644 --- a/comms/src/connection_manager/listener.rs +++ b/comms/src/connection_manager/listener.rs @@ -32,7 +32,6 @@ use crate::{ bounded_executor::BoundedExecutor, connection_manager::{ liveness::LivenessSession, - types::OneshotTrigger, wire_mode::{WireMode, LIVENESS_WIRE_MODE}, }, multiaddr::Multiaddr, @@ -46,17 +45,7 @@ use crate::{ utils::multiaddr::multiaddr_to_socketaddr, PeerManager, }; -use futures::{ - channel::mpsc, - future, - AsyncRead, - AsyncReadExt, - AsyncWrite, - AsyncWriteExt, - FutureExt, - SinkExt, - StreamExt, -}; +use futures::{future, FutureExt}; use log::*; use std::{ convert::TryInto, @@ -69,8 +58,13 @@ use std::{ time::Duration, }; use tari_crypto::tari_utilities::hex::Hex; -use tari_shutdown::ShutdownSignal; -use tokio::time; +use tari_shutdown::{oneshot_trigger, oneshot_trigger::OneshotTrigger, ShutdownSignal}; +use tokio::{ + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, + sync::mpsc, + time, +}; +use tokio_stream::StreamExt; const LOG_TARGET: &str = "comms::connection_manager::listener"; @@ -117,7 +111,7 @@ where bounded_executor: BoundedExecutor::from_current(config.max_simultaneous_inbound_connects), liveness_session_count: Arc::new(AtomicUsize::new(config.liveness_max_sessions)), config, - on_listening: OneshotTrigger::new(), + on_listening: oneshot_trigger::channel(), } } @@ -127,7 +121,7 @@ where // 'static lifetime as well as to flatten the oneshot result for ergonomics pub fn on_listening(&self) -> impl Future<Output = Result<Multiaddr, ConnectionManagerError>> + 'static { let signal = self.on_listening.to_signal(); - signal.map(|r| r.map_err(|_| ConnectionManagerError::ListenerOneshotCancelled)?) + signal.map(|r| r.ok_or_else(|| ConnectionManagerError::ListenerOneshotCancelled)?) } /// Set the supported protocols of this node to send to peers during the peer identity exchange @@ -146,31 +140,30 @@ where let mut shutdown_signal = self.shutdown_signal.clone(); match self.bind().await { - Ok((inbound, address)) => { + Ok((mut inbound, address)) => { info!(target: LOG_TARGET, "Listening for peer connections on '{}'", address); - self.on_listening.trigger(Ok(address)); - - let inbound = inbound.fuse(); - futures::pin_mut!(inbound); + self.on_listening.broadcast(Ok(address)); loop { - futures::select! { - inbound_result = inbound.select_next_some() => { + tokio::select! { + biased; + + _ = &mut shutdown_signal => { + info!(target: LOG_TARGET, "PeerListener is shutting down because the shutdown signal was triggered"); + break; + }, + Some(inbound_result) = inbound.next() => { if let Some((socket, peer_addr)) = log_if_error!(target: LOG_TARGET, inbound_result, "Inbound connection failed because '{error}'",) { self.spawn_listen_task(socket, peer_addr).await; } }, - _ = shutdown_signal => { - info!(target: LOG_TARGET, "PeerListener is shutting down because the shutdown signal was triggered"); - break; - }, } } }, Err(err) => { warn!(target: LOG_TARGET, "PeerListener was unable to start because '{}'", err); - self.on_listening.trigger(Err(err)); + self.on_listening.broadcast(Err(err)); }, } } @@ -237,7 +230,7 @@ where async fn spawn_listen_task(&self, mut socket: TTransport::Output, peer_addr: Multiaddr) { let node_identity = self.node_identity.clone(); let peer_manager = self.peer_manager.clone(); - let mut conn_man_notifier = self.conn_man_notifier.clone(); + let conn_man_notifier = self.conn_man_notifier.clone(); let noise_config = self.noise_config.clone(); let config = self.config.clone(); let our_supported_protocols = self.our_supported_protocols.clone(); @@ -316,7 +309,7 @@ where "No liveness sessions available or permitted for peer address '{}'", peer_addr ); - let _ = socket.close().await; + let _ = socket.shutdown().await; } }, Err(err) => { diff --git a/comms/src/connection_manager/liveness.rs b/comms/src/connection_manager/liveness.rs index 3c06889307..75ee2db13f 100644 --- a/comms/src/connection_manager/liveness.rs +++ b/comms/src/connection_manager/liveness.rs @@ -20,15 +20,16 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::compat::IoCompat; -use futures::{AsyncRead, AsyncWrite, Future, StreamExt}; +use futures::StreamExt; +use std::future::Future; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::codec::{Framed, LinesCodec, LinesCodecError}; /// Max line length accepted by the liveness session. const MAX_LINE_LENGTH: usize = 50; pub struct LivenessSession<TSocket> { - framed: Framed<IoCompat<TSocket>, LinesCodec>, + framed: Framed<TSocket, LinesCodec>, } impl<TSocket> LivenessSession<TSocket> @@ -36,7 +37,7 @@ where TSocket: AsyncRead + AsyncWrite + Unpin { pub fn new(socket: TSocket) -> Self { Self { - framed: Framed::new(IoCompat::new(socket), LinesCodec::new_with_max_length(MAX_LINE_LENGTH)), + framed: Framed::new(socket, LinesCodec::new_with_max_length(MAX_LINE_LENGTH)), } } @@ -52,13 +53,14 @@ mod test { use crate::{memsocket::MemorySocket, runtime}; use futures::SinkExt; use tokio::{time, time::Duration}; + use tokio_stream::StreamExt; - #[runtime::test_basic] + #[runtime::test] async fn echos() { let (inbound, outbound) = MemorySocket::new_pair(); let liveness = LivenessSession::new(inbound); let join_handle = runtime::current().spawn(liveness.run()); - let mut outbound = Framed::new(IoCompat::new(outbound), LinesCodec::new()); + let mut outbound = Framed::new(outbound, LinesCodec::new()); for _ in 0..10usize { outbound.send("ECHO".to_string()).await.unwrap() } diff --git a/comms/src/connection_manager/manager.rs b/comms/src/connection_manager/manager.rs index d1019f33d4..287de0ec2b 100644 --- a/comms/src/connection_manager/manager.rs +++ b/comms/src/connection_manager/manager.rs @@ -36,20 +36,17 @@ use crate::{ transports::{TcpTransport, Transport}, PeerManager, }; -use futures::{ - channel::{mpsc, oneshot}, - stream::Fuse, - AsyncRead, - AsyncWrite, - SinkExt, - StreamExt, -}; use log::*; use multiaddr::Multiaddr; use std::{fmt, sync::Arc}; use tari_shutdown::{Shutdown, ShutdownSignal}; use time::Duration; -use tokio::{sync::broadcast, task, time}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{broadcast, mpsc, oneshot}, + task, + time, +}; const LOG_TARGET: &str = "comms::connection_manager::manager"; @@ -155,8 +152,8 @@ impl ListenerInfo { } pub struct ConnectionManager<TTransport, TBackoff> { - request_rx: Fuse<mpsc::Receiver<ConnectionManagerRequest>>, - internal_event_rx: Fuse<mpsc::Receiver<ConnectionManagerEvent>>, + request_rx: mpsc::Receiver<ConnectionManagerRequest>, + internal_event_rx: mpsc::Receiver<ConnectionManagerEvent>, dialer_tx: mpsc::Sender<DialerRequest>, dialer: Option<Dialer<TTransport, TBackoff>>, listener: Option<PeerListener<TTransport>>, @@ -229,10 +226,10 @@ where Self { shutdown_signal: Some(shutdown_signal), - request_rx: request_rx.fuse(), + request_rx, peer_manager, protocols: Protocols::new(), - internal_event_rx: internal_event_rx.fuse(), + internal_event_rx, dialer_tx, dialer: Some(dialer), listener: Some(listener), @@ -263,7 +260,7 @@ where .take() .expect("ConnectionManager initialized without a shutdown"); - // Runs the listeners, waiting for a + // Runs the listeners. Sockets are bound and ready once this resolves match self.run_listeners().await { Ok(info) => { self.listener_info = Some(info); @@ -290,16 +287,16 @@ where .join(", ") ); loop { - futures::select! { - event = self.internal_event_rx.select_next_some() => { + tokio::select! { + Some(event) = self.internal_event_rx.recv() => { self.handle_event(event).await; }, - request = self.request_rx.select_next_some() => { + Some(request) = self.request_rx.recv() => { self.handle_request(request).await; }, - _ = shutdown => { + _ = &mut shutdown => { info!(target: LOG_TARGET, "ConnectionManager is shutting down because it received the shutdown signal"); break; } diff --git a/comms/src/connection_manager/peer_connection.rs b/comms/src/connection_manager/peer_connection.rs index 260befaeee..9d0e0d8c61 100644 --- a/comms/src/connection_manager/peer_connection.rs +++ b/comms/src/connection_manager/peer_connection.rs @@ -44,20 +44,21 @@ use crate::{ protocol::{ProtocolId, ProtocolNegotiation}, runtime, }; -use futures::{ - channel::{mpsc, oneshot}, - stream::Fuse, - SinkExt, - StreamExt, -}; use log::*; use multiaddr::Multiaddr; use std::{ fmt, - sync::atomic::{AtomicUsize, Ordering}, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, time::{Duration, Instant}, }; -use tokio::time; +use tokio::{ + sync::{mpsc, oneshot}, + time, +}; +use tokio_stream::StreamExt; const LOG_TARGET: &str = "comms::connection_manager::peer_connection"; @@ -128,7 +129,7 @@ pub struct PeerConnection { peer_node_id: NodeId, peer_features: PeerFeatures, request_tx: mpsc::Sender<PeerConnectionRequest>, - address: Multiaddr, + address: Arc<Multiaddr>, direction: ConnectionDirection, started_at: Instant, substream_counter: SubstreamCounter, @@ -149,7 +150,7 @@ impl PeerConnection { request_tx, peer_node_id, peer_features, - address, + address: Arc::new(address), direction, started_at: Instant::now(), substream_counter, @@ -291,9 +292,9 @@ impl PartialEq for PeerConnection { struct PeerConnectionActor { id: ConnectionId, peer_node_id: NodeId, - request_rx: Fuse<mpsc::Receiver<PeerConnectionRequest>>, + request_rx: mpsc::Receiver<PeerConnectionRequest>, direction: ConnectionDirection, - incoming_substreams: Fuse<IncomingSubstreams>, + incoming_substreams: IncomingSubstreams, control: Control, event_notifier: mpsc::Sender<ConnectionManagerEvent>, our_supported_protocols: Vec<ProtocolId>, @@ -317,8 +318,8 @@ impl PeerConnectionActor { peer_node_id, direction, control: connection.get_yamux_control(), - incoming_substreams: connection.incoming().fuse(), - request_rx: request_rx.fuse(), + incoming_substreams: connection.incoming(), + request_rx, event_notifier, our_supported_protocols, their_supported_protocols, @@ -327,8 +328,8 @@ impl PeerConnectionActor { pub async fn run(mut self) { loop { - futures::select! { - request = self.request_rx.select_next_some() => self.handle_request(request).await, + tokio::select! { + Some(request) = self.request_rx.recv() => self.handle_request(request).await, maybe_substream = self.incoming_substreams.next() => { match maybe_substream { @@ -352,7 +353,7 @@ impl PeerConnectionActor { } } } - self.request_rx.get_mut().close(); + self.request_rx.close(); } async fn handle_request(&mut self, request: PeerConnectionRequest) { diff --git a/comms/src/connection_manager/requester.rs b/comms/src/connection_manager/requester.rs index 3b86a88bc2..f6771a06db 100644 --- a/comms/src/connection_manager/requester.rs +++ b/comms/src/connection_manager/requester.rs @@ -25,12 +25,8 @@ use crate::{ connection_manager::manager::{ConnectionManagerEvent, ListenerInfo}, peer_manager::NodeId, }; -use futures::{ - channel::{mpsc, oneshot}, - SinkExt, -}; use std::sync::Arc; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc, oneshot}; /// Requests which are handled by the ConnectionManagerService #[derive(Debug)] diff --git a/comms/src/connection_manager/tests/listener_dialer.rs b/comms/src/connection_manager/tests/listener_dialer.rs index 586c0d5ec4..25e715bf50 100644 --- a/comms/src/connection_manager/tests/listener_dialer.rs +++ b/comms/src/connection_manager/tests/listener_dialer.rs @@ -36,20 +36,17 @@ use crate::{ test_utils::{node_identity::build_node_identity, test_node::build_peer_manager}, transports::MemoryTransport, }; -use futures::{ - channel::{mpsc, oneshot}, - AsyncReadExt, - AsyncWriteExt, - SinkExt, - StreamExt, -}; use multiaddr::Protocol; use std::{error::Error, time::Duration}; use tari_shutdown::Shutdown; use tari_test_utils::unpack_enum; -use tokio::time::timeout; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + sync::{mpsc, oneshot}, + time::timeout, +}; -#[runtime::test_basic] +#[runtime::test] async fn listen() -> Result<(), Box<dyn Error>> { let (event_tx, _) = mpsc::channel(1); let mut shutdown = Shutdown::new(); @@ -61,7 +58,7 @@ async fn listen() -> Result<(), Box<dyn Error>> { "/memory/0".parse()?, MemoryTransport, noise_config.clone(), - event_tx.clone(), + event_tx, peer_manager, node_identity, shutdown.to_signal(), @@ -72,12 +69,12 @@ async fn listen() -> Result<(), Box<dyn Error>> { unpack_enum!(Protocol::Memory(port) = bind_addr.pop().unwrap()); assert!(port > 0); - shutdown.trigger().unwrap(); + shutdown.trigger(); Ok(()) } -#[runtime::test_basic] +#[runtime::test] async fn smoke() { let rt_handle = runtime::current(); // This test sets up Dialer and Listener components, uses the Dialer to dial the Listener, @@ -108,7 +105,7 @@ async fn smoke() { let node_identity2 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); let noise_config2 = NoiseConfig::new(node_identity2.clone()); - let (mut request_tx, request_rx) = mpsc::channel(1); + let (request_tx, request_rx) = mpsc::channel(1); let peer_manager2 = build_peer_manager(); let mut dialer = Dialer::new( ConnectionManagerConfig::default(), @@ -148,11 +145,11 @@ async fn smoke() { } // Read PeerConnected events - we don't know which connection is which - unpack_enum!(ConnectionManagerEvent::PeerConnected(conn1) = event_rx.next().await.unwrap()); - unpack_enum!(ConnectionManagerEvent::PeerConnected(_conn2) = event_rx.next().await.unwrap()); + unpack_enum!(ConnectionManagerEvent::PeerConnected(conn1) = event_rx.recv().await.unwrap()); + unpack_enum!(ConnectionManagerEvent::PeerConnected(_conn2) = event_rx.recv().await.unwrap()); // Next event should be a NewInboundSubstream has been received - let listen_event = event_rx.next().await.unwrap(); + let listen_event = event_rx.recv().await.unwrap(); { unpack_enum!(ConnectionManagerEvent::NewInboundSubstream(node_id, proto, in_stream) = listen_event); assert_eq!(&*node_id, node_identity2.node_id()); @@ -165,7 +162,7 @@ async fn smoke() { conn1.disconnect().await.unwrap(); - shutdown.trigger().unwrap(); + shutdown.trigger(); let peer2 = peer_manager1.find_by_node_id(node_identity2.node_id()).await.unwrap(); let peer1 = peer_manager2.find_by_node_id(node_identity1.node_id()).await.unwrap(); @@ -176,7 +173,7 @@ async fn smoke() { timeout(Duration::from_secs(5), dialer_fut).await.unwrap().unwrap(); } -#[runtime::test_basic] +#[runtime::test] async fn banned() { let rt_handle = runtime::current(); let (event_tx, mut event_rx) = mpsc::channel(10); @@ -209,7 +206,7 @@ async fn banned() { peer_manager1.add_peer(peer).await.unwrap(); let noise_config2 = NoiseConfig::new(node_identity2.clone()); - let (mut request_tx, request_rx) = mpsc::channel(1); + let (request_tx, request_rx) = mpsc::channel(1); let peer_manager2 = build_peer_manager(); let mut dialer = Dialer::new( ConnectionManagerConfig::default(), @@ -241,10 +238,10 @@ async fn banned() { let err = reply_rx.await.unwrap().unwrap_err(); unpack_enum!(ConnectionManagerError::IdentityProtocolError(_err) = err); - unpack_enum!(ConnectionManagerEvent::PeerInboundConnectFailed(err) = event_rx.next().await.unwrap()); + unpack_enum!(ConnectionManagerEvent::PeerInboundConnectFailed(err) = event_rx.recv().await.unwrap()); unpack_enum!(ConnectionManagerError::PeerBanned = err); - shutdown.trigger().unwrap(); + shutdown.trigger(); timeout(Duration::from_secs(5), dialer_fut).await.unwrap().unwrap(); } diff --git a/comms/src/connection_manager/tests/manager.rs b/comms/src/connection_manager/tests/manager.rs index bca876eff4..910280b4cf 100644 --- a/comms/src/connection_manager/tests/manager.rs +++ b/comms/src/connection_manager/tests/manager.rs @@ -41,19 +41,17 @@ use crate::{ }, transports::{MemoryTransport, TcpTransport}, }; -use futures::{ - channel::{mpsc, oneshot}, - future, - AsyncReadExt, - AsyncWriteExt, - StreamExt, -}; +use futures::future; use std::time::Duration; use tari_shutdown::Shutdown; -use tari_test_utils::{collect_stream, unpack_enum}; -use tokio::{runtime::Handle, sync::broadcast}; +use tari_test_utils::{collect_try_recv, unpack_enum}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + runtime::Handle, + sync::{broadcast, mpsc, oneshot}, +}; -#[runtime::test_basic] +#[runtime::test] async fn connect_to_nonexistent_peer() { let rt_handle = Handle::current(); let node_identity = build_node_identity(PeerFeatures::empty()); @@ -83,10 +81,10 @@ async fn connect_to_nonexistent_peer() { unpack_enum!(ConnectionManagerError::PeerManagerError(err) = err); unpack_enum!(PeerManagerError::PeerNotFoundError = err); - shutdown.trigger().unwrap(); + shutdown.trigger(); } -#[runtime::test_basic] +#[runtime::test] async fn dial_success() { static TEST_PROTO: ProtocolId = ProtocolId::from_static(b"/test/valid"); let shutdown = Shutdown::new(); @@ -159,7 +157,7 @@ async fn dial_success() { assert_eq!(peer2.supported_protocols, [&IDENTITY_PROTOCOL, &TEST_PROTO]); assert_eq!(peer2.user_agent, "node2"); - let event = subscription2.next().await.unwrap().unwrap(); + let event = subscription2.recv().await.unwrap(); unpack_enum!(ConnectionManagerEvent::PeerConnected(conn_in) = &*event); assert_eq!(conn_in.peer_node_id(), node_identity1.node_id()); @@ -179,7 +177,7 @@ async fn dial_success() { const MSG: &[u8] = b"Welease Woger!"; substream_out.stream.write_all(MSG).await.unwrap(); - let protocol_in = proto_rx2.next().await.unwrap(); + let protocol_in = proto_rx2.recv().await.unwrap(); assert_eq!(protocol_in.protocol, &TEST_PROTO); unpack_enum!(ProtocolEvent::NewInboundSubstream(node_id, substream_in) = protocol_in.event); assert_eq!(&node_id, node_identity1.node_id()); @@ -189,7 +187,7 @@ async fn dial_success() { assert_eq!(buf, MSG); } -#[runtime::test_basic] +#[runtime::test] async fn dial_success_aux_tcp_listener() { static TEST_PROTO: ProtocolId = ProtocolId::from_static(b"/test/valid"); let shutdown = Shutdown::new(); @@ -271,7 +269,7 @@ async fn dial_success_aux_tcp_listener() { const MSG: &[u8] = b"Welease Woger!"; substream_out.stream.write_all(MSG).await.unwrap(); - let protocol_in = proto_rx1.next().await.unwrap(); + let protocol_in = proto_rx1.recv().await.unwrap(); assert_eq!(protocol_in.protocol, &TEST_PROTO); unpack_enum!(ProtocolEvent::NewInboundSubstream(node_id, substream_in) = protocol_in.event); assert_eq!(&node_id, node_identity2.node_id()); @@ -281,7 +279,7 @@ async fn dial_success_aux_tcp_listener() { assert_eq!(buf, MSG); } -#[runtime::test_basic] +#[runtime::test] async fn simultaneous_dial_events() { let mut shutdown = Shutdown::new(); @@ -360,29 +358,22 @@ async fn simultaneous_dial_events() { _ => panic!("unexpected simultaneous dial result"), } - let event = subscription2.next().await.unwrap().unwrap(); + let event = subscription2.recv().await.unwrap(); assert!(count_string_occurrences(&[event], &["PeerConnected", "PeerInboundConnectFailed"]) >= 1); - shutdown.trigger().unwrap(); + shutdown.trigger(); drop(conn_man1); drop(conn_man2); - let _events1 = collect_stream!(subscription1, timeout = Duration::from_secs(5)) - .into_iter() - .map(Result::unwrap) - .collect::<Vec<_>>(); - - let _events2 = collect_stream!(subscription2, timeout = Duration::from_secs(5)) - .into_iter() - .map(Result::unwrap) - .collect::<Vec<_>>(); + let _events1 = collect_try_recv!(subscription1, timeout = Duration::from_secs(5)); + let _events2 = collect_try_recv!(subscription2, timeout = Duration::from_secs(5)); // TODO: Investigate why two PeerDisconnected events are sometimes received // assert!(count_string_occurrences(&events1, &["PeerDisconnected"]) >= 1); // assert!(count_string_occurrences(&events2, &["PeerDisconnected"]) >= 1); } -#[tokio_macros::test_basic] +#[runtime::test] async fn dial_cancelled() { let mut shutdown = Shutdown::new(); @@ -429,13 +420,10 @@ async fn dial_cancelled() { let err = dial_result.await.unwrap().unwrap_err(); unpack_enum!(ConnectionManagerError::DialCancelled = err); - shutdown.trigger().unwrap(); + shutdown.trigger(); drop(conn_man1); - let events1 = collect_stream!(subscription1, timeout = Duration::from_secs(5)) - .into_iter() - .map(Result::unwrap) - .collect::<Vec<_>>(); + let events1 = collect_try_recv!(subscription1, timeout = Duration::from_secs(5)); assert_eq!(events1.len(), 1); unpack_enum!(ConnectionManagerEvent::PeerConnectFailed(node_id, err) = &*events1[0]); diff --git a/comms/src/connection_manager/types.rs b/comms/src/connection_manager/types.rs index ddb8f6de8f..c92b2b717a 100644 --- a/comms/src/connection_manager/types.rs +++ b/comms/src/connection_manager/types.rs @@ -20,11 +20,6 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::{ - channel::oneshot, - future::{Fuse, Shared}, - FutureExt, -}; use std::fmt; /// Direction of the connection relative to this node @@ -47,29 +42,3 @@ impl fmt::Display for ConnectionDirection { write!(f, "{:?}", self) } } - -pub type OneshotSignal<T> = Shared<Fuse<oneshot::Receiver<T>>>; -pub struct OneshotTrigger<T>(Option<oneshot::Sender<T>>, OneshotSignal<T>); - -impl<T: Clone> OneshotTrigger<T> { - pub fn new() -> Self { - let (tx, rx) = oneshot::channel(); - Self(Some(tx), rx.fuse().shared()) - } - - pub fn to_signal(&self) -> OneshotSignal<T> { - self.1.clone() - } - - pub fn trigger(&mut self, item: T) { - if let Some(tx) = self.0.take() { - let _ = tx.send(item); - } - } -} - -impl<T: Clone> Default for OneshotTrigger<T> { - fn default() -> Self { - Self::new() - } -} diff --git a/comms/src/connectivity/manager.rs b/comms/src/connectivity/manager.rs index b692c8c5af..67331d7c80 100644 --- a/comms/src/connectivity/manager.rs +++ b/comms/src/connectivity/manager.rs @@ -34,6 +34,7 @@ use crate::{ ConnectionManagerEvent, ConnectionManagerRequester, }, + connectivity::ConnectivityEventTx, peer_manager::NodeId, runtime::task, utils::datetime::format_duration, @@ -41,7 +42,6 @@ use crate::{ PeerConnection, PeerManager, }; -use futures::{channel::mpsc, stream::Fuse, StreamExt}; use log::*; use nom::lib::std::collections::hash_map::Entry; use std::{ @@ -52,7 +52,7 @@ use std::{ time::{Duration, Instant}, }; use tari_shutdown::ShutdownSignal; -use tokio::{sync::broadcast, task::JoinHandle, time}; +use tokio::{sync::mpsc, task::JoinHandle, time}; const LOG_TARGET: &str = "comms::connectivity::manager"; @@ -71,7 +71,7 @@ const LOG_TARGET: &str = "comms::connectivity::manager"; pub struct ConnectivityManager { pub config: ConnectivityConfig, pub request_rx: mpsc::Receiver<ConnectivityRequest>, - pub event_tx: broadcast::Sender<Arc<ConnectivityEvent>>, + pub event_tx: ConnectivityEventTx, pub connection_manager: ConnectionManagerRequester, pub peer_manager: Arc<PeerManager>, pub node_identity: Arc<NodeIdentity>, @@ -83,7 +83,7 @@ impl ConnectivityManager { ConnectivityManagerActor { config: self.config, status: ConnectivityStatus::Initializing, - request_rx: self.request_rx.fuse(), + request_rx: self.request_rx, connection_manager: self.connection_manager, peer_manager: self.peer_manager.clone(), event_tx: self.event_tx, @@ -139,12 +139,12 @@ impl fmt::Display for ConnectivityStatus { pub struct ConnectivityManagerActor { config: ConnectivityConfig, status: ConnectivityStatus, - request_rx: Fuse<mpsc::Receiver<ConnectivityRequest>>, + request_rx: mpsc::Receiver<ConnectivityRequest>, connection_manager: ConnectionManagerRequester, node_identity: Arc<NodeIdentity>, shutdown_signal: Option<ShutdownSignal>, peer_manager: Arc<PeerManager>, - event_tx: broadcast::Sender<Arc<ConnectivityEvent>>, + event_tx: ConnectivityEventTx, connection_stats: HashMap<NodeId, PeerConnectionStats>, managed_peers: Vec<NodeId>, @@ -163,7 +163,7 @@ impl ConnectivityManagerActor { .take() .expect("ConnectivityManager initialized without a shutdown_signal"); - let mut connection_manager_events = self.connection_manager.get_event_subscription().fuse(); + let mut connection_manager_events = self.connection_manager.get_event_subscription(); let interval = self.config.connection_pool_refresh_interval; let mut ticker = time::interval_at( @@ -172,18 +172,17 @@ impl ConnectivityManagerActor { .expect("connection_pool_refresh_interval cause overflow") .into(), interval, - ) - .fuse(); + ); self.publish_event(ConnectivityEvent::ConnectivityStateInitialized); loop { - futures::select! { - req = self.request_rx.select_next_some() => { + tokio::select! { + Some(req) = self.request_rx.recv() => { self.handle_request(req).await; }, - event = connection_manager_events.select_next_some() => { + event = connection_manager_events.recv() => { if let Ok(event) = event { if let Err(err) = self.handle_connection_manager_event(&event).await { error!(target:LOG_TARGET, "Error handling connection manager event: {:?}", err); @@ -191,13 +190,13 @@ impl ConnectivityManagerActor { } }, - _ = ticker.next() => { + _ = ticker.tick() => { if let Err(err) = self.refresh_connection_pool().await { error!(target: LOG_TARGET, "Error when refreshing connection pools: {:?}", err); } }, - _ = shutdown_signal => { + _ = &mut shutdown_signal => { info!(target: LOG_TARGET, "ConnectivityManager is shutting down because it received the shutdown signal"); self.disconnect_all().await; break; @@ -807,7 +806,7 @@ impl ConnectivityManagerActor { fn publish_event(&mut self, event: ConnectivityEvent) { // A send operation can only fail if there are no subscribers, so it is safe to ignore the error - let _ = self.event_tx.send(Arc::new(event)); + let _ = self.event_tx.send(event); } async fn ban_peer( @@ -847,7 +846,7 @@ impl ConnectivityManagerActor { fn delayed_close(conn: PeerConnection, delay: Duration) { task::spawn(async move { - time::delay_for(delay).await; + time::sleep(delay).await; debug!( target: LOG_TARGET, "Closing connection from peer `{}` after delay", diff --git a/comms/src/connectivity/requester.rs b/comms/src/connectivity/requester.rs index b092b80c4d..073661fa22 100644 --- a/comms/src/connectivity/requester.rs +++ b/comms/src/connectivity/requester.rs @@ -31,22 +31,20 @@ use crate::{ peer_manager::NodeId, PeerConnection, }; -use futures::{ - channel::{mpsc, oneshot}, - SinkExt, - StreamExt, -}; use log::*; use std::{ fmt, - sync::Arc, time::{Duration, Instant}, }; -use tokio::{sync::broadcast, time}; +use tokio::{ + sync::{broadcast, broadcast::error::RecvError, mpsc, oneshot}, + time, +}; + const LOG_TARGET: &str = "comms::connectivity::requester"; -pub type ConnectivityEventRx = broadcast::Receiver<Arc<ConnectivityEvent>>; -pub type ConnectivityEventTx = broadcast::Sender<Arc<ConnectivityEvent>>; +pub type ConnectivityEventRx = broadcast::Receiver<ConnectivityEvent>; +pub type ConnectivityEventTx = broadcast::Sender<ConnectivityEvent>; #[derive(Debug, Clone)] pub enum ConnectivityEvent { @@ -254,24 +252,23 @@ impl ConnectivityRequester { let mut last_known_peer_count = status.num_connected_nodes(); loop { debug!(target: LOG_TARGET, "Waiting for connectivity event"); - let recv_result = time::timeout(remaining, connectivity_events.next()) + let recv_result = time::timeout(remaining, connectivity_events.recv()) .await - .map_err(|_| ConnectivityError::OnlineWaitTimeout(last_known_peer_count))? - .ok_or(ConnectivityError::ConnectivityEventStreamClosed)?; + .map_err(|_| ConnectivityError::OnlineWaitTimeout(last_known_peer_count))?; remaining = timeout .checked_sub(start.elapsed()) .ok_or(ConnectivityError::OnlineWaitTimeout(last_known_peer_count))?; match recv_result { - Ok(event) => match &*event { + Ok(event) => match event { ConnectivityEvent::ConnectivityStateOnline(_) => { info!(target: LOG_TARGET, "Connectivity is ONLINE."); break Ok(()); }, ConnectivityEvent::ConnectivityStateDegraded(n) => { warn!(target: LOG_TARGET, "Connectivity is DEGRADED ({} peer(s))", n); - last_known_peer_count = *n; + last_known_peer_count = n; }, ConnectivityEvent::ConnectivityStateOffline => { warn!( @@ -287,14 +284,14 @@ impl ConnectivityRequester { ); }, }, - Err(broadcast::RecvError::Closed) => { + Err(RecvError::Closed) => { error!( target: LOG_TARGET, "Connectivity event stream closed unexpectedly. System may be shutting down." ); break Err(ConnectivityError::ConnectivityEventStreamClosed); }, - Err(broadcast::RecvError::Lagged(n)) => { + Err(RecvError::Lagged(n)) => { warn!(target: LOG_TARGET, "Lagging behind on {} connectivity event(s)", n); // We lagged, so could have missed the state change. Check it explicitly. let status = self.get_connectivity_status().await?; diff --git a/comms/src/connectivity/selection.rs b/comms/src/connectivity/selection.rs index 931b78163d..c300d2d353 100644 --- a/comms/src/connectivity/selection.rs +++ b/comms/src/connectivity/selection.rs @@ -135,8 +135,8 @@ mod test { peer_manager::node_id::NodeDistance, test_utils::{mocks::create_dummy_peer_connection, node_id, node_identity::build_node_identity}, }; - use futures::channel::mpsc; use std::iter::repeat_with; + use tokio::sync::mpsc; fn create_pool_with_connections(n: usize) -> (ConnectionPool, Vec<mpsc::Receiver<PeerConnectionRequest>>) { let mut pool = ConnectionPool::new(); diff --git a/comms/src/connectivity/test.rs b/comms/src/connectivity/test.rs index a4fec1e896..948d083e94 100644 --- a/comms/src/connectivity/test.rs +++ b/comms/src/connectivity/test.rs @@ -28,6 +28,7 @@ use super::{ }; use crate::{ connection_manager::{ConnectionManagerError, ConnectionManagerEvent}, + connectivity::ConnectivityEventRx, peer_manager::{Peer, PeerFeatures}, runtime, runtime::task, @@ -39,18 +40,18 @@ use crate::{ NodeIdentity, PeerManager, }; -use futures::{channel::mpsc, future}; +use futures::future; use std::{sync::Arc, time::Duration}; use tari_shutdown::Shutdown; -use tari_test_utils::{collect_stream, streams, unpack_enum}; -use tokio::sync::broadcast; +use tari_test_utils::{collect_try_recv, streams, unpack_enum}; +use tokio::sync::{broadcast, mpsc}; #[allow(clippy::type_complexity)] fn setup_connectivity_manager( config: ConnectivityConfig, ) -> ( ConnectivityRequester, - broadcast::Receiver<Arc<ConnectivityEvent>>, + ConnectivityEventRx, Arc<NodeIdentity>, Arc<PeerManager>, ConnectionManagerMockState, @@ -100,7 +101,7 @@ async fn add_test_peers(peer_manager: &PeerManager, n: usize) -> Vec<Peer> { peers } -#[runtime::test_basic] +#[runtime::test] async fn connecting_peers() { let (mut connectivity, mut event_stream, node_identity, peer_manager, cm_mock_state, _shutdown) = setup_connectivity_manager(Default::default()); @@ -117,15 +118,15 @@ async fn connecting_peers() { .map(|(_, _, conn, _)| conn) .collect::<Vec<_>>(); - let mut events = collect_stream!(event_stream, take = 1, timeout = Duration::from_secs(10)); - unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = &*events.remove(0).unwrap()); + let mut events = collect_try_recv!(event_stream, take = 1, timeout = Duration::from_secs(10)); + unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = events.remove(0)); // All connections succeeded for conn in &connections { cm_mock_state.publish_event(ConnectionManagerEvent::PeerConnected(conn.clone())); } - let _events = collect_stream!(event_stream, take = 11, timeout = Duration::from_secs(10)); + let _events = collect_try_recv!(event_stream, take = 11, timeout = Duration::from_secs(10)); let connection_states = connectivity.get_all_connection_states().await.unwrap(); assert_eq!(connection_states.len(), 10); @@ -135,7 +136,7 @@ async fn connecting_peers() { } } -#[runtime::test_basic] +#[runtime::test] async fn add_many_managed_peers() { let (mut connectivity, mut event_stream, node_identity, peer_manager, cm_mock_state, _shutdown) = setup_connectivity_manager(Default::default()); @@ -156,8 +157,8 @@ async fn add_many_managed_peers() { .await .unwrap(); - let mut events = collect_stream!(event_stream, take = 1, timeout = Duration::from_secs(10)); - unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = &*events.remove(0).unwrap()); + let mut events = collect_try_recv!(event_stream, take = 1, timeout = Duration::from_secs(10)); + unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = events.remove(0)); // First 5 succeeded for conn in &connections { @@ -172,10 +173,10 @@ async fn add_many_managed_peers() { )); } - let events = collect_stream!(event_stream, take = 9, timeout = Duration::from_secs(10)); + let events = collect_try_recv!(event_stream, take = 9, timeout = Duration::from_secs(10)); let n = events .iter() - .find_map(|event| match &**event.as_ref().unwrap() { + .find_map(|event| match event { ConnectivityEvent::ConnectivityStateOnline(n) => Some(n), ConnectivityEvent::ConnectivityStateDegraded(_) => None, ConnectivityEvent::PeerConnected(_) => None, @@ -205,7 +206,7 @@ async fn add_many_managed_peers() { } } -#[runtime::test_basic] +#[runtime::test] async fn online_then_offline() { let (mut connectivity, mut event_stream, node_identity, peer_manager, cm_mock_state, _shutdown) = setup_connectivity_manager(Default::default()); @@ -244,8 +245,8 @@ async fn online_then_offline() { .await .unwrap(); - let mut events = collect_stream!(event_stream, take = 1, timeout = Duration::from_secs(10)); - unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = &*events.remove(0).unwrap()); + let mut events = collect_try_recv!(event_stream, take = 1, timeout = Duration::from_secs(10)); + unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = events.remove(0)); for conn in connections.iter().skip(1) { cm_mock_state.publish_event(ConnectionManagerEvent::PeerConnected(conn.clone())); @@ -269,9 +270,9 @@ async fn online_then_offline() { )); } - streams::assert_in_stream( + streams::assert_in_broadcast( &mut event_stream, - |item| match &*item.unwrap() { + |item| match item { ConnectivityEvent::ConnectivityStateDegraded(2) => Some(()), _ => None, }, @@ -289,9 +290,9 @@ async fn online_then_offline() { )); } - streams::assert_in_stream( + streams::assert_in_broadcast( &mut event_stream, - |item| match &*item.unwrap() { + |item| match item { ConnectivityEvent::ConnectivityStateOffline => Some(()), _ => None, }, @@ -303,20 +304,20 @@ async fn online_then_offline() { assert!(is_offline); } -#[runtime::test_basic] +#[runtime::test] async fn ban_peer() { let (mut connectivity, mut event_stream, node_identity, peer_manager, cm_mock_state, _shutdown) = setup_connectivity_manager(Default::default()); let peer = add_test_peers(&peer_manager, 1).await.pop().unwrap(); let (conn, _, _, _) = create_peer_connection_mock_pair(node_identity.to_peer(), peer.clone()).await; - let mut events = collect_stream!(event_stream, take = 1, timeout = Duration::from_secs(10)); - unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = &*events.remove(0).unwrap()); + let mut events = collect_try_recv!(event_stream, take = 1, timeout = Duration::from_secs(10)); + unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = events.remove(0)); cm_mock_state.publish_event(ConnectionManagerEvent::PeerConnected(conn.clone())); - let mut events = collect_stream!(event_stream, take = 2, timeout = Duration::from_secs(10)); - unpack_enum!(ConnectivityEvent::PeerConnected(_conn) = &*events.remove(0).unwrap()); - unpack_enum!(ConnectivityEvent::ConnectivityStateOnline(_n) = &*events.remove(0).unwrap()); + let mut events = collect_try_recv!(event_stream, take = 2, timeout = Duration::from_secs(10)); + unpack_enum!(ConnectivityEvent::PeerConnected(_conn) = events.remove(0)); + unpack_enum!(ConnectivityEvent::ConnectivityStateOnline(_n) = events.remove(0)); let conn = connectivity.get_connection(peer.node_id.clone()).await.unwrap(); assert!(conn.is_some()); @@ -329,13 +330,12 @@ async fn ban_peer() { // We can always expect a single PeerBanned because we do not publish a disconnected event from the connection // manager In a real system, peer disconnect and peer banned events may happen in any order and should always be // completely fine. - let event = collect_stream!(event_stream, take = 1, timeout = Duration::from_secs(10)) + let event = collect_try_recv!(event_stream, take = 1, timeout = Duration::from_secs(10)) .pop() - .unwrap() .unwrap(); - unpack_enum!(ConnectivityEvent::PeerBanned(node_id) = &*event); - assert_eq!(node_id, &peer.node_id); + unpack_enum!(ConnectivityEvent::PeerBanned(node_id) = event); + assert_eq!(node_id, peer.node_id); let peer = peer_manager.find_by_node_id(&peer.node_id).await.unwrap(); assert!(peer.is_banned()); @@ -344,7 +344,7 @@ async fn ban_peer() { assert!(conn.is_none()); } -#[runtime::test_basic] +#[runtime::test] async fn peer_selection() { let config = ConnectivityConfig { min_connectivity: 1.0, @@ -370,15 +370,15 @@ async fn peer_selection() { .await .unwrap(); - let mut events = collect_stream!(event_stream, take = 1, timeout = Duration::from_secs(10)); - unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = &*events.remove(0).unwrap()); + let mut events = collect_try_recv!(event_stream, take = 1, timeout = Duration::from_secs(10)); + unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = events.remove(0)); // 10 connections for conn in &connections { cm_mock_state.publish_event(ConnectionManagerEvent::PeerConnected(conn.clone())); } // Wait for all peers to be connected (i.e. for the connection manager events to be received) - let mut _events = collect_stream!(event_stream, take = 12, timeout = Duration::from_secs(10)); + let mut _events = collect_try_recv!(event_stream, take = 12, timeout = Duration::from_secs(10)); let conns = connectivity .select_connections(ConnectivitySelection::random_nodes(10, vec![connections[0] diff --git a/comms/src/framing.rs b/comms/src/framing.rs index 1e6b67691e..06ccc00c30 100644 --- a/comms/src/framing.rs +++ b/comms/src/framing.rs @@ -20,17 +20,16 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::compat::IoCompat; -use futures::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::codec::{Framed, LengthDelimitedCodec}; /// Tari comms canonical framing -pub type CanonicalFraming<T> = Framed<IoCompat<T>, LengthDelimitedCodec>; +pub type CanonicalFraming<T> = Framed<T, LengthDelimitedCodec>; pub fn canonical<T>(stream: T, max_frame_len: usize) -> CanonicalFraming<T> where T: AsyncRead + AsyncWrite + Unpin { Framed::new( - IoCompat::new(stream), + stream, LengthDelimitedCodec::builder() .max_frame_length(max_frame_len) .new_codec(), diff --git a/comms/src/lib.rs b/comms/src/lib.rs index 4f3bb178d2..2949d4fe4f 100644 --- a/comms/src/lib.rs +++ b/comms/src/lib.rs @@ -40,13 +40,12 @@ pub use multiplexing::Substream; mod noise; mod proto; -mod runtime; pub mod backoff; pub mod bounded_executor; -pub mod compat; pub mod memsocket; pub mod protocol; +pub mod runtime; #[macro_use] pub mod message; pub mod net_address; diff --git a/comms/src/memsocket/mod.rs b/comms/src/memsocket/mod.rs index 3d102bf8dc..ed77fc6146 100644 --- a/comms/src/memsocket/mod.rs +++ b/comms/src/memsocket/mod.rs @@ -26,17 +26,21 @@ use bytes::{Buf, Bytes}; use futures::{ channel::mpsc::{self, UnboundedReceiver, UnboundedSender}, - io::{AsyncRead, AsyncWrite, Error, ErrorKind, Result}, ready, stream::{FusedStream, Stream}, task::{Context, Poll}, }; use std::{ + cmp, collections::{hash_map::Entry, HashMap}, num::NonZeroU16, pin::Pin, sync::Mutex, }; +use tokio::{ + io, + io::{AsyncRead, AsyncWrite, ErrorKind, ReadBuf}, +}; lazy_static! { static ref SWITCHBOARD: Mutex<SwitchBoard> = Mutex::new(SwitchBoard(HashMap::default(), 1)); @@ -114,6 +118,7 @@ pub fn release_memsocket_port(port: NonZeroU16) { /// use std::io::Result; /// /// use tari_comms::memsocket::{MemoryListener, MemorySocket}; +/// use tokio::io::*; /// use futures::prelude::*; /// /// async fn write_stormlight(mut stream: MemorySocket) -> Result<()> { @@ -170,7 +175,7 @@ impl MemoryListener { /// ``` /// /// [`local_addr`]: #method.local_addr - pub fn bind(port: u16) -> Result<Self> { + pub fn bind(port: u16) -> io::Result<Self> { let mut switchboard = (&*SWITCHBOARD).lock().unwrap(); // Get the port we should bind to. If 0 was given, use a random port @@ -262,11 +267,11 @@ impl MemoryListener { Incoming { inner: self } } - fn poll_accept(&mut self, context: &mut Context) -> Poll<Result<MemorySocket>> { + fn poll_accept(&mut self, context: &mut Context) -> Poll<io::Result<MemorySocket>> { match Pin::new(&mut self.incoming).poll_next(context) { Poll::Ready(Some(socket)) => Poll::Ready(Ok(socket)), Poll::Ready(None) => { - let err = Error::new(ErrorKind::Other, "MemoryListener unknown error"); + let err = io::Error::new(ErrorKind::Other, "MemoryListener unknown error"); Poll::Ready(Err(err)) }, Poll::Pending => Poll::Pending, @@ -283,7 +288,7 @@ pub struct Incoming<'a> { } impl<'a> Stream for Incoming<'a> { - type Item = Result<MemorySocket>; + type Item = io::Result<MemorySocket>; fn poll_next(mut self: Pin<&mut Self>, context: &mut Context) -> Poll<Option<Self::Item>> { let socket = ready!(self.inner.poll_accept(context)?); @@ -302,6 +307,7 @@ impl<'a> Stream for Incoming<'a> { /// /// ```rust, no_run /// use futures::prelude::*; +/// use tokio::io::*; /// use tari_comms::memsocket::MemorySocket; /// /// # async fn run() -> ::std::io::Result<()> { @@ -371,7 +377,7 @@ impl MemorySocket { /// let socket = MemorySocket::connect(16)?; /// # Ok(())} /// ``` - pub fn connect(port: u16) -> Result<MemorySocket> { + pub fn connect(port: u16) -> io::Result<MemorySocket> { let mut switchboard = (&*SWITCHBOARD).lock().unwrap(); // Find port to connect to @@ -399,13 +405,13 @@ impl MemorySocket { impl AsyncRead for MemorySocket { /// Attempt to read from the `AsyncRead` into `buf`. - fn poll_read(mut self: Pin<&mut Self>, mut context: &mut Context, buf: &mut [u8]) -> Poll<Result<usize>> { + fn poll_read(mut self: Pin<&mut Self>, mut context: &mut Context, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> { if self.incoming.is_terminated() { if self.seen_eof { return Poll::Ready(Err(ErrorKind::UnexpectedEof.into())); } else { self.seen_eof = true; - return Poll::Ready(Ok(0)); + return Poll::Ready(Ok(())); } } @@ -413,22 +419,23 @@ impl AsyncRead for MemorySocket { loop { // If we're already filled up the buffer then we can return - if bytes_read == buf.len() { - return Poll::Ready(Ok(bytes_read)); + if buf.remaining() == 0 { + return Poll::Ready(Ok(())); } match self.current_buffer { // We have data to copy to buf Some(ref mut current_buffer) if !current_buffer.is_empty() => { - let bytes_to_read = ::std::cmp::min(buf.len() - bytes_read, current_buffer.len()); - debug_assert!(bytes_to_read > 0); - - buf[bytes_read..(bytes_read + bytes_to_read)] - .copy_from_slice(current_buffer.slice(0..bytes_to_read).as_ref()); + let bytes_to_read = cmp::min(buf.remaining(), current_buffer.len()); + if bytes_to_read > 0 { + buf.initialize_unfilled_to(bytes_to_read) + .copy_from_slice(¤t_buffer.slice(..bytes_to_read)); + buf.advance(bytes_to_read); - current_buffer.advance(bytes_to_read); + current_buffer.advance(bytes_to_read); - bytes_read += bytes_to_read; + bytes_read += bytes_to_read; + } }, // Either we've exhausted our current buffer or don't have one @@ -438,13 +445,13 @@ impl AsyncRead for MemorySocket { Poll::Pending => { // If we've read anything up to this point return the bytes read if bytes_read > 0 { - return Poll::Ready(Ok(bytes_read)); + return Poll::Ready(Ok(())); } else { return Poll::Pending; } }, Poll::Ready(Some(buf)) => Some(buf), - Poll::Ready(None) => return Poll::Ready(Ok(bytes_read)), + Poll::Ready(None) => return Poll::Ready(Ok(())), } }; }, @@ -455,14 +462,14 @@ impl AsyncRead for MemorySocket { impl AsyncWrite for MemorySocket { /// Attempt to write bytes from `buf` into the outgoing channel. - fn poll_write(mut self: Pin<&mut Self>, context: &mut Context, buf: &[u8]) -> Poll<Result<usize>> { + fn poll_write(mut self: Pin<&mut Self>, context: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> { let len = buf.len(); match self.outgoing.poll_ready(context) { Poll::Ready(Ok(())) => { if let Err(e) = self.outgoing.start_send(Bytes::copy_from_slice(buf)) { if e.is_disconnected() { - return Poll::Ready(Err(Error::new(ErrorKind::BrokenPipe, e))); + return Poll::Ready(Err(io::Error::new(ErrorKind::BrokenPipe, e))); } // Unbounded channels should only ever have "Disconnected" errors @@ -471,7 +478,7 @@ impl AsyncWrite for MemorySocket { }, Poll::Ready(Err(e)) => { if e.is_disconnected() { - return Poll::Ready(Err(Error::new(ErrorKind::BrokenPipe, e))); + return Poll::Ready(Err(io::Error::new(ErrorKind::BrokenPipe, e))); } // Unbounded channels should only ever have "Disconnected" errors @@ -484,12 +491,12 @@ impl AsyncWrite for MemorySocket { } /// Attempt to flush the channel. Cannot Fail. - fn poll_flush(self: Pin<&mut Self>, _context: &mut Context) -> Poll<Result<()>> { + fn poll_flush(self: Pin<&mut Self>, _context: &mut Context) -> Poll<io::Result<()>> { Poll::Ready(Ok(())) } /// Attempt to close the channel. Cannot Fail. - fn poll_close(self: Pin<&mut Self>, _context: &mut Context) -> Poll<Result<()>> { + fn poll_shutdown(self: Pin<&mut Self>, _context: &mut Context) -> Poll<io::Result<()>> { self.outgoing.close_channel(); Poll::Ready(Ok(())) @@ -499,15 +506,12 @@ impl AsyncWrite for MemorySocket { #[cfg(test)] mod test { use super::*; - use futures::{ - executor::block_on, - io::{AsyncReadExt, AsyncWriteExt}, - stream::StreamExt, - }; - use std::io::Result; + use crate::runtime; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio_stream::StreamExt; #[test] - fn listener_bind() -> Result<()> { + fn listener_bind() -> io::Result<()> { let port = acquire_next_memsocket_port().into(); let listener = MemoryListener::bind(port)?; assert_eq!(listener.local_addr(), port); @@ -515,172 +519,187 @@ mod test { Ok(()) } - #[test] - fn simple_connect() -> Result<()> { + #[runtime::test] + async fn simple_connect() -> io::Result<()> { let port = acquire_next_memsocket_port().into(); let mut listener = MemoryListener::bind(port)?; let mut dialer = MemorySocket::connect(port)?; - let mut listener_socket = block_on(listener.incoming().next()).unwrap()?; + let mut listener_socket = listener.incoming().next().await.unwrap()?; - block_on(dialer.write_all(b"foo"))?; - block_on(dialer.flush())?; + dialer.write_all(b"foo").await?; + dialer.flush().await?; let mut buf = [0; 3]; - block_on(listener_socket.read_exact(&mut buf))?; + listener_socket.read_exact(&mut buf).await?; assert_eq!(&buf, b"foo"); Ok(()) } - #[test] - fn listen_on_port_zero() -> Result<()> { + #[runtime::test] + async fn listen_on_port_zero() -> io::Result<()> { let mut listener = MemoryListener::bind(0)?; let listener_addr = listener.local_addr(); let mut dialer = MemorySocket::connect(listener_addr)?; - let mut listener_socket = block_on(listener.incoming().next()).unwrap()?; + let mut listener_socket = listener.incoming().next().await.unwrap()?; - block_on(dialer.write_all(b"foo"))?; - block_on(dialer.flush())?; + dialer.write_all(b"foo").await?; + dialer.flush().await?; let mut buf = [0; 3]; - block_on(listener_socket.read_exact(&mut buf))?; + listener_socket.read_exact(&mut buf).await?; assert_eq!(&buf, b"foo"); - block_on(listener_socket.write_all(b"bar"))?; - block_on(listener_socket.flush())?; + listener_socket.write_all(b"bar").await?; + listener_socket.flush().await?; let mut buf = [0; 3]; - block_on(dialer.read_exact(&mut buf))?; + dialer.read_exact(&mut buf).await?; assert_eq!(&buf, b"bar"); Ok(()) } - #[test] - fn listener_correctly_frees_port_on_drop() -> Result<()> { - fn connect_on_port(port: u16) -> Result<()> { - let mut listener = MemoryListener::bind(port)?; - let mut dialer = MemorySocket::connect(port)?; - let mut listener_socket = block_on(listener.incoming().next()).unwrap()?; + #[runtime::test] + async fn listener_correctly_frees_port_on_drop() { + async fn connect_on_port(port: u16) { + let mut listener = MemoryListener::bind(port).unwrap(); + let mut dialer = MemorySocket::connect(port).unwrap(); + let mut listener_socket = listener.incoming().next().await.unwrap().unwrap(); - block_on(dialer.write_all(b"foo"))?; - block_on(dialer.flush())?; + dialer.write_all(b"foo").await.unwrap(); + dialer.flush().await.unwrap(); let mut buf = [0; 3]; - block_on(listener_socket.read_exact(&mut buf))?; + let n = listener_socket.read_exact(&mut buf).await.unwrap(); + assert_eq!(n, 3); assert_eq!(&buf, b"foo"); - - Ok(()) } let port = acquire_next_memsocket_port().into(); - connect_on_port(port)?; - connect_on_port(port)?; - - Ok(()) + connect_on_port(port).await; + connect_on_port(port).await; } - #[test] - fn simple_write_read() -> Result<()> { + #[runtime::test] + async fn simple_write_read() -> io::Result<()> { let (mut a, mut b) = MemorySocket::new_pair(); - block_on(a.write_all(b"hello world"))?; - block_on(a.flush())?; + a.write_all(b"hello world").await?; + a.flush().await?; drop(a); let mut v = Vec::new(); - block_on(b.read_to_end(&mut v))?; + b.read_to_end(&mut v).await?; assert_eq!(v, b"hello world"); Ok(()) } - #[test] - fn partial_read() -> Result<()> { + #[runtime::test] + async fn partial_read() -> io::Result<()> { let (mut a, mut b) = MemorySocket::new_pair(); - block_on(a.write_all(b"foobar"))?; - block_on(a.flush())?; + a.write_all(b"foobar").await?; + a.flush().await?; let mut buf = [0; 3]; - block_on(b.read_exact(&mut buf))?; + b.read_exact(&mut buf).await?; assert_eq!(&buf, b"foo"); - block_on(b.read_exact(&mut buf))?; + b.read_exact(&mut buf).await?; assert_eq!(&buf, b"bar"); Ok(()) } - #[test] - fn partial_read_write_both_sides() -> Result<()> { + #[runtime::test] + async fn partial_read_write_both_sides() -> io::Result<()> { let (mut a, mut b) = MemorySocket::new_pair(); - block_on(a.write_all(b"foobar"))?; - block_on(a.flush())?; - block_on(b.write_all(b"stormlight"))?; - block_on(b.flush())?; + a.write_all(b"foobar").await?; + a.flush().await?; + b.write_all(b"stormlight").await?; + b.flush().await?; let mut buf_a = [0; 5]; let mut buf_b = [0; 3]; - block_on(a.read_exact(&mut buf_a))?; + a.read_exact(&mut buf_a).await?; assert_eq!(&buf_a, b"storm"); - block_on(b.read_exact(&mut buf_b))?; + b.read_exact(&mut buf_b).await?; assert_eq!(&buf_b, b"foo"); - block_on(a.read_exact(&mut buf_a))?; + a.read_exact(&mut buf_a).await?; assert_eq!(&buf_a, b"light"); - block_on(b.read_exact(&mut buf_b))?; + b.read_exact(&mut buf_b).await?; assert_eq!(&buf_b, b"bar"); Ok(()) } - #[test] - fn many_small_writes() -> Result<()> { + #[runtime::test] + async fn many_small_writes() -> io::Result<()> { let (mut a, mut b) = MemorySocket::new_pair(); - block_on(a.write_all(b"words"))?; - block_on(a.write_all(b" "))?; - block_on(a.write_all(b"of"))?; - block_on(a.write_all(b" "))?; - block_on(a.write_all(b"radiance"))?; - block_on(a.flush())?; + a.write_all(b"words").await?; + a.write_all(b" ").await?; + a.flush().await?; + a.write_all(b"of").await?; + a.write_all(b" ").await?; + a.flush().await?; + a.write_all(b"radiance").await?; + a.flush().await?; drop(a); let mut buf = [0; 17]; - block_on(b.read_exact(&mut buf))?; + b.read_exact(&mut buf).await?; assert_eq!(&buf, b"words of radiance"); Ok(()) } - #[test] - fn read_zero_bytes() -> Result<()> { + #[runtime::test] + async fn large_writes() -> io::Result<()> { + let (mut a, mut b) = MemorySocket::new_pair(); + + let large_data = vec![123u8; 1024]; + a.write_all(&large_data).await?; + a.flush().await?; + drop(a); + + let mut buf = Vec::new(); + b.read_to_end(&mut buf).await?; + assert_eq!(buf.len(), 1024); + + Ok(()) + } + + #[runtime::test] + async fn read_zero_bytes() -> io::Result<()> { let (mut a, mut b) = MemorySocket::new_pair(); - block_on(a.write_all(b"way of kings"))?; - block_on(a.flush())?; + a.write_all(b"way of kings").await?; + a.flush().await?; let mut buf = [0; 12]; - block_on(b.read_exact(&mut buf[0..0]))?; + b.read_exact(&mut buf[0..0]).await?; assert_eq!(buf, [0; 12]); - block_on(b.read_exact(&mut buf))?; + b.read_exact(&mut buf).await?; assert_eq!(&buf, b"way of kings"); Ok(()) } - #[test] - fn read_bytes_with_large_buffer() -> Result<()> { + #[runtime::test] + async fn read_bytes_with_large_buffer() -> io::Result<()> { let (mut a, mut b) = MemorySocket::new_pair(); - block_on(a.write_all(b"way of kings"))?; - block_on(a.flush())?; + a.write_all(b"way of kings").await?; + a.flush().await?; let mut buf = [0; 20]; - let bytes_read = block_on(b.read(&mut buf))?; + let bytes_read = b.read(&mut buf).await?; assert_eq!(bytes_read, 12); assert_eq!(&buf[0..12], b"way of kings"); diff --git a/comms/src/message/outbound.rs b/comms/src/message/outbound.rs index 25f0899143..a08c9604ce 100644 --- a/comms/src/message/outbound.rs +++ b/comms/src/message/outbound.rs @@ -22,11 +22,11 @@ use crate::{message::MessageTag, peer_manager::NodeId, protocol::messaging::SendFailReason}; use bytes::Bytes; -use futures::channel::oneshot; use std::{ fmt, fmt::{Error, Formatter}, }; +use tokio::sync::oneshot; pub type MessagingReplyResult = Result<(), SendFailReason>; pub type MessagingReplyRx = oneshot::Receiver<MessagingReplyResult>; diff --git a/comms/src/multiplexing/yamux.rs b/comms/src/multiplexing/yamux.rs index 4c78b38174..f7275cc7d0 100644 --- a/comms/src/multiplexing/yamux.rs +++ b/comms/src/multiplexing/yamux.rs @@ -21,18 +21,15 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{connection_manager::ConnectionDirection, runtime}; -use futures::{ - channel::mpsc, - io::{AsyncRead, AsyncWrite}, - stream::FusedStream, - task::Context, - SinkExt, - Stream, - StreamExt, -}; +use futures::{task::Context, Stream}; use log::*; use std::{future::Future, io, pin::Pin, sync::Arc, task::Poll}; use tari_shutdown::{Shutdown, ShutdownSignal}; +use tokio::{ + io::{AsyncRead, AsyncWrite, ReadBuf}, + sync::mpsc, +}; +use tokio_util::compat::{Compat, FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; use yamux::Mode; type IncomingRx = mpsc::Receiver<yamux::Stream>; @@ -73,7 +70,7 @@ impl Yamux { config.set_receive_window(RECEIVE_WINDOW); let substream_counter = SubstreamCounter::new(); - let connection = yamux::Connection::new(socket, config, mode); + let connection = yamux::Connection::new(socket.compat(), config, mode); let control = Control::new(connection.control(), substream_counter.clone()); let incoming = Self::spawn_incoming_stream_worker(connection, substream_counter.clone()); @@ -91,12 +88,11 @@ impl Yamux { counter: SubstreamCounter, ) -> IncomingSubstreams where - TSocket: AsyncRead + AsyncWrite + Unpin + Send + 'static, + TSocket: futures::AsyncRead + futures::AsyncWrite + Unpin + Send + 'static, { let shutdown = Shutdown::new(); let (incoming_tx, incoming_rx) = mpsc::channel(10); - let stream = yamux::into_stream(connection).boxed(); - let incoming = IncomingWorker::new(stream, incoming_tx, shutdown.to_signal()); + let incoming = IncomingWorker::new(connection, incoming_tx, shutdown.to_signal()); runtime::task::spawn(incoming.run()); IncomingSubstreams::new(incoming_rx, counter, shutdown) } @@ -125,10 +121,6 @@ impl Yamux { pub(crate) fn substream_counter(&self) -> SubstreamCounter { self.substream_counter.clone() } - - pub fn is_terminated(&self) -> bool { - self.incoming.is_terminated() - } } #[derive(Clone)] @@ -149,7 +141,7 @@ impl Control { pub async fn open_stream(&mut self) -> Result<Substream, ConnectionError> { let stream = self.inner.open_stream().await?; Ok(Substream { - stream, + stream: stream.compat(), counter_guard: self.substream_counter.new_guard(), }) } @@ -188,19 +180,13 @@ impl IncomingSubstreams { } } -impl FusedStream for IncomingSubstreams { - fn is_terminated(&self) -> bool { - self.inner.is_terminated() - } -} - impl Stream for IncomingSubstreams { type Item = Substream; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - match futures::ready!(Pin::new(&mut self.inner).poll_next(cx)) { + match futures::ready!(Pin::new(&mut self.inner).poll_recv(cx)) { Some(stream) => Poll::Ready(Some(Substream { - stream, + stream: stream.compat(), counter_guard: self.substream_counter.new_guard(), })), None => Poll::Ready(None), @@ -216,17 +202,17 @@ impl Drop for IncomingSubstreams { #[derive(Debug)] pub struct Substream { - stream: yamux::Stream, + stream: Compat<yamux::Stream>, counter_guard: CounterGuard, } -impl AsyncRead for Substream { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> { +impl tokio::io::AsyncRead for Substream { + fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> { Pin::new(&mut self.stream).poll_read(cx, buf) } } -impl AsyncWrite for Substream { +impl tokio::io::AsyncWrite for Substream { fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> { Pin::new(&mut self.stream).poll_write(cx, buf) } @@ -235,54 +221,73 @@ impl AsyncWrite for Substream { Pin::new(&mut self.stream).poll_flush(cx) } - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { - Pin::new(&mut self.stream).poll_close(cx) + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.stream).poll_shutdown(cx) } } -struct IncomingWorker<S> { - inner: S, +struct IncomingWorker<TSocket> { + connection: yamux::Connection<TSocket>, sender: mpsc::Sender<yamux::Stream>, shutdown_signal: ShutdownSignal, } -impl<S> IncomingWorker<S> -where S: Stream<Item = Result<yamux::Stream, yamux::ConnectionError>> + Unpin +impl<TSocket> IncomingWorker<TSocket> +where TSocket: futures::AsyncRead + futures::AsyncWrite + Unpin + Send + 'static /* */ { - pub fn new(stream: S, sender: IncomingTx, shutdown_signal: ShutdownSignal) -> Self { + pub fn new(connection: yamux::Connection<TSocket>, sender: IncomingTx, shutdown_signal: ShutdownSignal) -> Self { Self { - inner: stream, + connection, sender, shutdown_signal, } } pub async fn run(mut self) { - let mut mux_stream = self.inner.take_until(&mut self.shutdown_signal); - while let Some(result) = mux_stream.next().await { - match result { - Ok(stream) => { - if self.sender.send(stream).await.is_err() { - debug!( - target: LOG_TARGET, - "Incoming peer substream task is shutting down because the internal stream sender channel \ - was closed" - ); - break; + loop { + tokio::select! { + biased; + + _ = &mut self.shutdown_signal => { + let mut control = self.connection.control(); + if let Err(err) = control.close().await { + error!(target: LOG_TARGET, "Failed to close yamux connection: {}", err); + } + break + } + + result = self.connection.next_stream() => { + match result { + Ok(Some(stream)) => { + if self.sender.send(stream).await.is_err() { + debug!( + target: LOG_TARGET, + "Incoming peer substream task is shutting down because the internal stream sender channel \ + was closed" + ); + break; + } + }, + Ok(None) =>{ + debug!( + target: LOG_TARGET, + "Incoming peer substream completed. IncomingWorker exiting" + ); + break; + } + Err(err) => { + debug!( + target: LOG_TARGET, + "Incoming peer substream task received an error because '{}'", err + ); + break; + }, } - }, - Err(err) => { - debug!( - target: LOG_TARGET, - "Incoming peer substream task received an error because '{}'", err - ); - break; - }, + } } } debug!(target: LOG_TARGET, "Incoming peer substream task is shutting down"); - self.sender.close_channel(); } } @@ -317,15 +322,12 @@ mod test { runtime, runtime::task, }; - use futures::{ - future, - io::{AsyncReadExt, AsyncWriteExt}, - StreamExt, - }; use std::{io, time::Duration}; use tari_test_utils::collect_stream; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio_stream::StreamExt; - #[runtime::test_basic] + #[runtime::test] async fn open_substream() -> io::Result<()> { let (dialer, listener) = MemorySocket::new_pair(); let msg = b"The Way of Kings"; @@ -340,7 +342,7 @@ mod test { substream.write_all(msg).await.unwrap(); substream.flush().await.unwrap(); - substream.close().await.unwrap(); + substream.shutdown().await.unwrap(); }); let mut listener = Yamux::upgrade_connection(listener, ConnectionDirection::Inbound) @@ -352,13 +354,16 @@ mod test { .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no substream"))?; let mut buf = Vec::new(); - let _ = future::select(substream.read_to_end(&mut buf), listener.next()).await; + tokio::select! { + _ = substream.read_to_end(&mut buf) => {}, + _ = listener.next() => {}, + }; assert_eq!(buf, msg); Ok(()) } - #[runtime::test_basic] + #[runtime::test] async fn substream_count() { const NUM_SUBSTREAMS: usize = 10; let (dialer, listener) = MemorySocket::new_pair(); @@ -392,7 +397,7 @@ mod test { assert_eq!(listener.substream_count(), 0); } - #[runtime::test_basic] + #[runtime::test] async fn close() -> io::Result<()> { let (dialer, listener) = MemorySocket::new_pair(); let msg = b"Words of Radiance"; @@ -421,7 +426,7 @@ mod test { assert_eq!(buf, msg); // Close the substream and then try to write to it - substream.close().await?; + substream.shutdown().await?; let result = substream.write_all(b"ignored message").await; match result { @@ -432,7 +437,7 @@ mod test { Ok(()) } - #[runtime::test_basic] + #[runtime::test] async fn send_big_message() -> io::Result<()> { #[allow(non_upper_case_globals)] static MiB: usize = 1 << 20; @@ -453,7 +458,7 @@ mod test { let mut buf = vec![0u8; MSG_LEN]; substream.read_exact(&mut buf).await.unwrap(); - substream.close().await.unwrap(); + substream.shutdown().await.unwrap(); assert_eq!(buf.len(), MSG_LEN); assert_eq!(buf, vec![0xAAu8; MSG_LEN]); @@ -472,7 +477,7 @@ mod test { let msg = vec![0xAAu8; MSG_LEN]; substream.write_all(msg.as_slice()).await?; - substream.close().await?; + substream.shutdown().await?; drop(substream); assert_eq!(incoming.substream_count(), 0); diff --git a/comms/src/noise/config.rs b/comms/src/noise/config.rs index 7776ade335..946ddfe53a 100644 --- a/comms/src/noise/config.rs +++ b/comms/src/noise/config.rs @@ -31,11 +31,11 @@ use crate::{ }, peer_manager::NodeIdentity, }; -use futures::{AsyncRead, AsyncWrite}; use log::*; use snow::{self, params::NoiseParams}; use std::sync::Arc; use tari_crypto::tari_utilities::ByteArray; +use tokio::io::{AsyncRead, AsyncWrite}; const LOG_TARGET: &str = "comms::noise"; pub(super) const NOISE_IX_PARAMETER: &str = "Noise_IX_25519_ChaChaPoly_BLAKE2b"; @@ -95,10 +95,15 @@ impl NoiseConfig { #[cfg(test)] mod test { use super::*; - use crate::{memsocket::MemorySocket, peer_manager::PeerFeatures, test_utils::node_identity::build_node_identity}; - use futures::{future, AsyncReadExt, AsyncWriteExt, FutureExt}; + use crate::{ + memsocket::MemorySocket, + peer_manager::PeerFeatures, + runtime, + test_utils::node_identity::build_node_identity, + }; + use futures::{future, FutureExt}; use snow::params::{BaseChoice, CipherChoice, DHChoice, HandshakePattern, HashChoice}; - use tokio::runtime::Runtime; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; fn check_noise_params(config: &NoiseConfig) { assert_eq!(config.parameters.hash, HashChoice::Blake2b); @@ -117,39 +122,35 @@ mod test { assert_eq!(config.node_identity.public_key(), node_identity.public_key()); } - #[test] - fn upgrade_socket() { - let mut rt = Runtime::new().unwrap(); - + #[runtime::test] + async fn upgrade_socket() { let node_identity1 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); let config1 = NoiseConfig::new(node_identity1.clone()); let node_identity2 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); let config2 = NoiseConfig::new(node_identity2.clone()); - rt.block_on(async move { - let (in_socket, out_socket) = MemorySocket::new_pair(); - let (mut socket_in, mut socket_out) = future::join( - config1.upgrade_socket(in_socket, ConnectionDirection::Inbound), - config2.upgrade_socket(out_socket, ConnectionDirection::Outbound), - ) - .map(|(s1, s2)| (s1.unwrap(), s2.unwrap())) - .await; - - let in_pubkey = socket_in.get_remote_public_key().unwrap(); - let out_pubkey = socket_out.get_remote_public_key().unwrap(); - - assert_eq!(&in_pubkey, node_identity2.public_key()); - assert_eq!(&out_pubkey, node_identity1.public_key()); - - let sample = b"Children of time"; - socket_in.write_all(sample).await.unwrap(); - socket_in.flush().await.unwrap(); - socket_in.close().await.unwrap(); - - let mut read_buf = Vec::with_capacity(16); - socket_out.read_to_end(&mut read_buf).await.unwrap(); - assert_eq!(read_buf, sample); - }); + let (in_socket, out_socket) = MemorySocket::new_pair(); + let (mut socket_in, mut socket_out) = future::join( + config1.upgrade_socket(in_socket, ConnectionDirection::Inbound), + config2.upgrade_socket(out_socket, ConnectionDirection::Outbound), + ) + .map(|(s1, s2)| (s1.unwrap(), s2.unwrap())) + .await; + + let in_pubkey = socket_in.get_remote_public_key().unwrap(); + let out_pubkey = socket_out.get_remote_public_key().unwrap(); + + assert_eq!(&in_pubkey, node_identity2.public_key()); + assert_eq!(&out_pubkey, node_identity1.public_key()); + + let sample = b"Children of time"; + socket_in.write_all(sample).await.unwrap(); + socket_in.flush().await.unwrap(); + socket_in.shutdown().await.unwrap(); + + let mut read_buf = Vec::with_capacity(16); + socket_out.read_to_end(&mut read_buf).await.unwrap(); + assert_eq!(read_buf, sample); } } diff --git a/comms/src/noise/socket.rs b/comms/src/noise/socket.rs index eaf02a60b0..33d35f89ca 100644 --- a/comms/src/noise/socket.rs +++ b/comms/src/noise/socket.rs @@ -26,27 +26,27 @@ //! Noise Socket +use crate::types::CommsPublicKey; use futures::ready; use log::*; use snow::{error::StateProblem, HandshakeState, TransportState}; use std::{ + cmp, convert::TryInto, io, pin::Pin, task::{Context, Poll}, }; -// use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use crate::types::CommsPublicKey; -use futures::{io::Error, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tari_crypto::tari_utilities::ByteArray; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; const LOG_TARGET: &str = "comms::noise::socket"; -const MAX_PAYLOAD_LENGTH: usize = u16::max_value() as usize; // 65535 +const MAX_PAYLOAD_LENGTH: usize = u16::MAX as usize; // 65535 // The maximum number of bytes that we can buffer is 16 bytes less than u16::max_value() because // encrypted messages include a tag along with the payload. -const MAX_WRITE_BUFFER_LENGTH: usize = u16::max_value() as usize - 16; // 65519 +const MAX_WRITE_BUFFER_LENGTH: usize = u16::MAX as usize - 16; // 65519 /// Collection of buffers used for buffering data during the various read/write states of a /// NoiseSocket @@ -223,7 +223,12 @@ where TSocket: AsyncRead, { loop { - let n = ready!(socket.as_mut().poll_read(&mut context, &mut buf[*offset..]))?; + let mut read_buf = ReadBuf::new(&mut buf[*offset..]); + let prev_rem = read_buf.remaining(); + ready!(socket.as_mut().poll_read(&mut context, &mut read_buf))?; + let n = prev_rem + .checked_sub(read_buf.remaining()) + .expect("buffer underflow: prev_rem < read_buf.remaining()"); trace!( target: LOG_TARGET, "poll_read_exact: read {}/{} bytes", @@ -320,7 +325,7 @@ where TSocket: AsyncRead + Unpin decrypted_len, ref mut offset, } => { - let bytes_to_copy = ::std::cmp::min(decrypted_len as usize - *offset, buf.len()); + let bytes_to_copy = cmp::min(decrypted_len as usize - *offset, buf.len()); buf[..bytes_to_copy] .copy_from_slice(&self.buffers.read_decrypted[*offset..(*offset + bytes_to_copy)]); trace!( @@ -351,8 +356,11 @@ where TSocket: AsyncRead + Unpin impl<TSocket> AsyncRead for NoiseSocket<TSocket> where TSocket: AsyncRead + Unpin { - fn poll_read(self: Pin<&mut Self>, context: &mut Context, buf: &mut [u8]) -> Poll<io::Result<usize>> { - self.get_mut().poll_read(context, buf) + fn poll_read(self: Pin<&mut Self>, context: &mut Context, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> { + let slice = buf.initialize_unfilled(); + let n = futures::ready!(self.get_mut().poll_read(context, slice))?; + buf.advance(n); + Poll::Ready(Ok(())) } } @@ -501,8 +509,8 @@ where TSocket: AsyncWrite + Unpin self.get_mut().poll_flush(cx) } - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> { - Pin::new(&mut self.socket).poll_close(cx) + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + Pin::new(&mut self.socket).poll_shutdown(cx) } } @@ -531,7 +539,7 @@ where TSocket: AsyncRead + AsyncWrite + Unpin target: LOG_TARGET, "Noise handshake failed because '{:?}'. Closing socket.", err ); - self.socket.close().await?; + self.socket.shutdown().await?; Err(err) }, } @@ -644,7 +652,6 @@ mod test { use futures::future::join; use snow::{params::NoiseParams, Builder, Error, Keypair}; use std::io; - use tokio::runtime::Runtime; async fn build_test_connection( ) -> Result<((Keypair, Handshake<MemorySocket>), (Keypair, Handshake<MemorySocket>)), Error> { @@ -707,7 +714,7 @@ mod test { dialer_socket.write_all(b" ").await?; dialer_socket.write_all(b"archive").await?; dialer_socket.flush().await?; - dialer_socket.close().await?; + dialer_socket.shutdown().await?; let mut buf = Vec::new(); listener_socket.read_to_end(&mut buf).await?; @@ -745,51 +752,60 @@ mod test { Ok(()) } - #[test] - fn u16_max_writes() -> io::Result<()> { - // Current thread runtime stack overflows, so the full tokio runtime is used here - let mut rt = Runtime::new().unwrap(); - rt.block_on(async move { - let ((_dialer_keypair, dialer), (_listener_keypair, listener)) = build_test_connection().await.unwrap(); + #[runtime::test] + async fn u16_max_writes() -> io::Result<()> { + let ((_dialer_keypair, dialer), (_listener_keypair, listener)) = build_test_connection().await.unwrap(); - let (mut a, mut b) = perform_handshake(dialer, listener).await?; + let (mut a, mut b) = perform_handshake(dialer, listener).await?; - let buf_send = [1; MAX_PAYLOAD_LENGTH]; - a.write_all(&buf_send).await?; - a.flush().await?; + let buf_send = [1; MAX_PAYLOAD_LENGTH + 1]; + a.write_all(&buf_send).await?; + a.flush().await?; - let mut buf_receive = [0; MAX_PAYLOAD_LENGTH]; - b.read_exact(&mut buf_receive).await?; - assert_eq!(&buf_receive[..], &buf_send[..]); + let mut buf_receive = [0; MAX_PAYLOAD_LENGTH + 1]; + b.read_exact(&mut buf_receive).await?; + assert_eq!(&buf_receive[..], &buf_send[..]); - Ok(()) - }) + Ok(()) } - #[test] - fn unexpected_eof() -> io::Result<()> { - // Current thread runtime stack overflows, so the full tokio runtime is used here - let mut rt = Runtime::new().unwrap(); - rt.block_on(async move { - let ((_dialer_keypair, dialer), (_listener_keypair, listener)) = build_test_connection().await.unwrap(); + #[runtime::test] + async fn larger_writes() -> io::Result<()> { + let ((_dialer_keypair, dialer), (_listener_keypair, listener)) = build_test_connection().await.unwrap(); - let (mut a, mut b) = perform_handshake(dialer, listener).await?; + let (mut a, mut b) = perform_handshake(dialer, listener).await?; - let buf_send = [1; MAX_PAYLOAD_LENGTH]; - a.write_all(&buf_send).await?; - a.flush().await?; + let buf_send = [1; MAX_PAYLOAD_LENGTH * 2 + 1024]; + a.write_all(&buf_send).await?; + a.flush().await?; - a.socket.close().await.unwrap(); - drop(a); + let mut buf_receive = [0; MAX_PAYLOAD_LENGTH * 2 + 1024]; + b.read_exact(&mut buf_receive).await?; + assert_eq!(&buf_receive[..], &buf_send[..]); - let mut buf_receive = [0; MAX_PAYLOAD_LENGTH]; - b.read_exact(&mut buf_receive).await.unwrap(); - assert_eq!(&buf_receive[..], &buf_send[..]); + Ok(()) + } + + #[runtime::test] + async fn unexpected_eof() -> io::Result<()> { + let ((_dialer_keypair, dialer), (_listener_keypair, listener)) = build_test_connection().await.unwrap(); - let err = b.read_exact(&mut buf_receive).await.unwrap_err(); - assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof); + let (mut a, mut b) = perform_handshake(dialer, listener).await?; - Ok(()) - }) + let buf_send = [1; MAX_PAYLOAD_LENGTH]; + a.write_all(&buf_send).await?; + a.flush().await?; + + a.socket.shutdown().await.unwrap(); + drop(a); + + let mut buf_receive = [0; MAX_PAYLOAD_LENGTH]; + b.read_exact(&mut buf_receive).await.unwrap(); + assert_eq!(&buf_receive[..], &buf_send[..]); + + let err = b.read_exact(&mut buf_receive).await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof); + + Ok(()) } } diff --git a/comms/src/peer_manager/manager.rs b/comms/src/peer_manager/manager.rs index 3828c82ab5..d7df51a2ea 100644 --- a/comms/src/peer_manager/manager.rs +++ b/comms/src/peer_manager/manager.rs @@ -337,7 +337,7 @@ mod test { peer } - #[runtime::test_basic] + #[runtime::test] async fn get_broadcast_identities() { // Create peer manager with random peers let peer_manager = PeerManager::new(HashmapDatabase::new(), None).unwrap(); @@ -446,7 +446,7 @@ mod test { assert_ne!(identities1, identities2); } - #[runtime::test_basic] + #[runtime::test] async fn calc_region_threshold() { let n = 5; // Create peer manager with random peers @@ -514,7 +514,7 @@ mod test { } } - #[runtime::test_basic] + #[runtime::test] async fn closest_peers() { let n = 5; // Create peer manager with random peers @@ -548,7 +548,7 @@ mod test { } } - #[runtime::test_basic] + #[runtime::test] async fn add_or_update_online_peer() { let peer_manager = PeerManager::new(HashmapDatabase::new(), None).unwrap(); let mut peer = create_test_peer(false, PeerFeatures::COMMUNICATION_NODE); diff --git a/comms/src/pipeline/builder.rs b/comms/src/pipeline/builder.rs index 40a38d10a3..9ae90fabec 100644 --- a/comms/src/pipeline/builder.rs +++ b/comms/src/pipeline/builder.rs @@ -24,8 +24,8 @@ use crate::{ message::{InboundMessage, OutboundMessage}, pipeline::SinkService, }; -use futures::channel::mpsc; use thiserror::Error; +use tokio::sync::mpsc; use tower::Service; const DEFAULT_MAX_CONCURRENT_TASKS: usize = 50; @@ -99,9 +99,7 @@ where TOutSvc: Service<TOutReq> + Clone + Send + 'static, TInSvc: Service<InboundMessage> + Clone + Send + 'static, { - fn build_outbound( - &mut self, - ) -> Result<OutboundPipelineConfig<mpsc::Receiver<TOutReq>, TOutSvc>, PipelineBuilderError> { + fn build_outbound(&mut self) -> Result<OutboundPipelineConfig<TOutReq, TOutSvc>, PipelineBuilderError> { let (out_sender, out_receiver) = mpsc::channel(self.outbound_buffer_size); let in_receiver = self @@ -137,9 +135,9 @@ where } } -pub struct OutboundPipelineConfig<TInStream, TPipeline> { +pub struct OutboundPipelineConfig<TInItem, TPipeline> { /// Messages read from this stream are passed to the pipeline - pub in_receiver: TInStream, + pub in_receiver: mpsc::Receiver<TInItem>, /// Receiver of `OutboundMessage`s coming from the pipeline pub out_receiver: mpsc::Receiver<OutboundMessage>, /// The pipeline (`tower::Service`) to run for each in_stream message @@ -149,7 +147,7 @@ pub struct OutboundPipelineConfig<TInStream, TPipeline> { pub struct Config<TInSvc, TOutSvc, TOutReq> { pub max_concurrent_inbound_tasks: usize, pub inbound: TInSvc, - pub outbound: OutboundPipelineConfig<mpsc::Receiver<TOutReq>, TOutSvc>, + pub outbound: OutboundPipelineConfig<TOutReq, TOutSvc>, } #[derive(Debug, Error)] diff --git a/comms/src/pipeline/inbound.rs b/comms/src/pipeline/inbound.rs index 0b2116bc37..1f135640a7 100644 --- a/comms/src/pipeline/inbound.rs +++ b/comms/src/pipeline/inbound.rs @@ -21,10 +21,11 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::bounded_executor::BoundedExecutor; -use futures::{future::FusedFuture, stream::FusedStream, Stream, StreamExt}; +use futures::future::FusedFuture; use log::*; use std::fmt::Display; use tari_shutdown::ShutdownSignal; +use tokio::sync::mpsc; use tower::{Service, ServiceExt}; const LOG_TARGET: &str = "comms::pipeline::inbound"; @@ -33,22 +34,26 @@ const LOG_TARGET: &str = "comms::pipeline::inbound"; /// The difference between this can ServiceExt::call_all is /// that ServicePipeline doesn't keep the result of the service /// call and that it spawns a task for each incoming item. -pub struct Inbound<TSvc, TStream> { +pub struct Inbound<TSvc, TMsg> { executor: BoundedExecutor, service: TSvc, - stream: TStream, + stream: mpsc::Receiver<TMsg>, shutdown_signal: ShutdownSignal, } -impl<TSvc, TStream> Inbound<TSvc, TStream> +impl<TSvc, TMsg> Inbound<TSvc, TMsg> where - TStream: Stream + FusedStream + Unpin, - TStream::Item: Send + 'static, - TSvc: Service<TStream::Item> + Clone + Send + 'static, + TMsg: Send + 'static, + TSvc: Service<TMsg> + Clone + Send + 'static, TSvc::Error: Display + Send, TSvc::Future: Send, { - pub fn new(executor: BoundedExecutor, stream: TStream, service: TSvc, shutdown_signal: ShutdownSignal) -> Self { + pub fn new( + executor: BoundedExecutor, + stream: mpsc::Receiver<TMsg>, + service: TSvc, + shutdown_signal: ShutdownSignal, + ) -> Self { Self { executor, service, @@ -59,7 +64,7 @@ where } pub async fn run(mut self) { - while let Some(item) = self.stream.next().await { + while let Some(item) = self.stream.recv().await { // Check if the shutdown signal has been triggered. // If there are messages in the stream, drop them. Otherwise the stream is empty, // it will return None and the while loop will end. @@ -100,21 +105,25 @@ where mod test { use super::*; use crate::runtime; - use futures::{channel::mpsc, future, stream}; + use futures::future; use std::time::Duration; use tari_shutdown::Shutdown; - use tari_test_utils::collect_stream; - use tokio::{runtime::Handle, time}; + use tari_test_utils::collect_recv; + use tokio::{sync::mpsc, time}; use tower::service_fn; - #[runtime::test_basic] + #[runtime::test] async fn run() { let items = vec![1, 2, 3, 4, 5, 6]; - let stream = stream::iter(items.clone()).fuse(); + let (tx, mut stream) = mpsc::channel(items.len()); + for i in items.clone() { + tx.send(i).await.unwrap(); + } + stream.close(); - let (mut out_tx, mut out_rx) = mpsc::channel(items.len()); + let (out_tx, mut out_rx) = mpsc::channel(items.len()); - let executor = Handle::current(); + let executor = runtime::current(); let shutdown = Shutdown::new(); let pipeline = Inbound::new( BoundedExecutor::new(executor.clone(), 1), @@ -125,9 +134,10 @@ mod test { }), shutdown.to_signal(), ); + let spawned_task = executor.spawn(pipeline.run()); - let received = collect_stream!(out_rx, take = items.len(), timeout = Duration::from_secs(10)); + let received = collect_recv!(out_rx, take = items.len(), timeout = Duration::from_secs(10)); assert!(received.iter().all(|i| items.contains(i))); // Check that this task ends because the stream has closed diff --git a/comms/src/pipeline/mod.rs b/comms/src/pipeline/mod.rs index 1039374c65..4ea2da9c53 100644 --- a/comms/src/pipeline/mod.rs +++ b/comms/src/pipeline/mod.rs @@ -44,7 +44,7 @@ pub(crate) use inbound::Inbound; mod outbound; pub(crate) use outbound::Outbound; -mod translate_sink; -pub use translate_sink::TranslateSink; +// mod translate_sink; +// pub use translate_sink::TranslateSink; pub type PipelineError = anyhow::Error; diff --git a/comms/src/pipeline/outbound.rs b/comms/src/pipeline/outbound.rs index c860166ad0..54facdb92b 100644 --- a/comms/src/pipeline/outbound.rs +++ b/comms/src/pipeline/outbound.rs @@ -25,34 +25,33 @@ use crate::{ pipeline::builder::OutboundPipelineConfig, protocol::messaging::MessagingRequest, }; -use futures::{channel::mpsc, future, future::Either, stream::FusedStream, SinkExt, Stream, StreamExt}; +use futures::future::Either; use log::*; use std::fmt::Display; -use tokio::runtime; +use tokio::{runtime, sync::mpsc}; use tower::{Service, ServiceExt}; const LOG_TARGET: &str = "comms::pipeline::outbound"; -pub struct Outbound<TPipeline, TStream> { +pub struct Outbound<TPipeline, TItem> { /// Executor used to spawn a pipeline for each received item on the stream executor: runtime::Handle, /// Outbound pipeline configuration containing the pipeline and it's in and out streams - config: OutboundPipelineConfig<TStream, TPipeline>, + config: OutboundPipelineConfig<TItem, TPipeline>, /// Request sender for Messaging messaging_request_tx: mpsc::Sender<MessagingRequest>, } -impl<TPipeline, TStream> Outbound<TPipeline, TStream> +impl<TPipeline, TItem> Outbound<TPipeline, TItem> where - TStream: Stream + FusedStream + Unpin, - TStream::Item: Send + 'static, - TPipeline: Service<TStream::Item, Response = ()> + Clone + Send + 'static, + TItem: Send + 'static, + TPipeline: Service<TItem, Response = ()> + Clone + Send + 'static, TPipeline::Error: Display + Send, TPipeline::Future: Send, { pub fn new( executor: runtime::Handle, - config: OutboundPipelineConfig<TStream, TPipeline>, + config: OutboundPipelineConfig<TItem, TPipeline>, messaging_request_tx: mpsc::Sender<MessagingRequest>, ) -> Self { Self { @@ -64,10 +63,13 @@ where pub async fn run(mut self) { loop { - let either = future::select(self.config.in_receiver.next(), self.config.out_receiver.next()).await; + let either = tokio::select! { + next = self.config.in_receiver.recv() => Either::Left(next), + next = self.config.out_receiver.recv() => Either::Right(next) + }; match either { // Pipeline IN received a message. Spawn a new task for the pipeline - Either::Left((Some(msg), _)) => { + Either::Left(Some(msg)) => { let pipeline = self.config.pipeline.clone(); self.executor.spawn(async move { if let Err(err) = pipeline.oneshot(msg).await { @@ -76,7 +78,7 @@ where }); }, // Pipeline IN channel closed - Either::Left((None, _)) => { + Either::Left(None) => { info!( target: LOG_TARGET, "Outbound pipeline is shutting down because the in channel closed" @@ -84,7 +86,7 @@ where break; }, // Pipeline OUT received a message - Either::Right((Some(out_msg), _)) => { + Either::Right(Some(out_msg)) => { if self.messaging_request_tx.is_closed() { // MessagingRequest channel closed break; @@ -92,7 +94,7 @@ where self.send_messaging_request(out_msg).await; }, // Pipeline OUT channel closed - Either::Right((None, _)) => { + Either::Right(None) => { info!( target: LOG_TARGET, "Outbound pipeline is shutting down because the out channel closed" @@ -117,19 +119,22 @@ where #[cfg(test)] mod test { use super::*; - use crate::{pipeline::SinkService, runtime}; + use crate::{pipeline::SinkService, runtime, utils}; use bytes::Bytes; - use futures::stream; use std::time::Duration; - use tari_test_utils::{collect_stream, unpack_enum}; + use tari_test_utils::{collect_recv, unpack_enum}; use tokio::{runtime::Handle, time}; - #[runtime::test_basic] + #[runtime::test] async fn run() { const NUM_ITEMS: usize = 10; - let items = - (0..NUM_ITEMS).map(|i| OutboundMessage::new(Default::default(), Bytes::copy_from_slice(&i.to_be_bytes()))); - let stream = stream::iter(items).fuse(); + let (tx, in_receiver) = mpsc::channel(NUM_ITEMS); + utils::mpsc::send_all( + &tx, + (0..NUM_ITEMS).map(|i| OutboundMessage::new(Default::default(), Bytes::copy_from_slice(&i.to_be_bytes()))), + ) + .await + .unwrap(); let (out_tx, out_rx) = mpsc::channel(NUM_ITEMS); let (msg_tx, mut msg_rx) = mpsc::channel(NUM_ITEMS); let executor = Handle::current(); @@ -137,7 +142,7 @@ mod test { let pipeline = Outbound::new( executor.clone(), OutboundPipelineConfig { - in_receiver: stream, + in_receiver, out_receiver: out_rx, pipeline: SinkService::new(out_tx), }, @@ -146,7 +151,8 @@ mod test { let spawned_task = executor.spawn(pipeline.run()); - let requests = collect_stream!(msg_rx, take = NUM_ITEMS, timeout = Duration::from_millis(5)); + msg_rx.close(); + let requests = collect_recv!(msg_rx, timeout = Duration::from_millis(5)); for req in requests { unpack_enum!(MessagingRequest::SendMessage(_o) = req); } diff --git a/comms/src/pipeline/sink.rs b/comms/src/pipeline/sink.rs index a455aaf320..bb6f5c270e 100644 --- a/comms/src/pipeline/sink.rs +++ b/comms/src/pipeline/sink.rs @@ -21,8 +21,8 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use super::PipelineError; -use futures::{future::BoxFuture, task::Context, FutureExt, Sink, SinkExt}; -use std::{pin::Pin, task::Poll}; +use futures::{future::BoxFuture, task::Context, FutureExt}; +use std::task::Poll; use tower::Service; /// A service which forwards and messages it gets to the given Sink @@ -35,22 +35,44 @@ impl<TSink> SinkService<TSink> { } } -impl<T, TSink> Service<T> for SinkService<TSink> -where - T: Send + 'static, - TSink: Sink<T> + Unpin + Clone + Send + 'static, - TSink::Error: Into<PipelineError> + Send + 'static, +// impl<T, TSink> Service<T> for SinkService<TSink> +// where +// T: Send + 'static, +// TSink: Sink<T> + Unpin + Clone + Send + 'static, +// TSink::Error: Into<PipelineError> + Send + 'static, +// { +// type Error = PipelineError; +// type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>; +// type Response = (); +// +// fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { +// Pin::new(&mut self.0).poll_ready(cx).map_err(Into::into) +// } +// +// fn call(&mut self, item: T) -> Self::Future { +// let mut sink = self.0.clone(); +// async move { sink.send(item).await.map_err(Into::into) }.boxed() +// } +// } + +impl<T> Service<T> for SinkService<tokio::sync::mpsc::Sender<T>> +where T: Send + 'static { type Error = PipelineError; type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>; type Response = (); - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { - Pin::new(&mut self.0).poll_ready(cx).map_err(Into::into) + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) } fn call(&mut self, item: T) -> Self::Future { - let mut sink = self.0.clone(); - async move { sink.send(item).await.map_err(Into::into) }.boxed() + let sink = self.0.clone(); + async move { + sink.send(item) + .await + .map_err(|_| anyhow::anyhow!("sink closed in sink service")) + } + .boxed() } } diff --git a/comms/src/pipeline/translate_sink.rs b/comms/src/pipeline/translate_sink.rs index 6a2bcad56a..606c038299 100644 --- a/comms/src/pipeline/translate_sink.rs +++ b/comms/src/pipeline/translate_sink.rs @@ -93,9 +93,10 @@ where F: FnMut(I) -> Option<O> mod test { use super::*; use crate::runtime; - use futures::{channel::mpsc, SinkExt, StreamExt}; + use futures::{SinkExt, StreamExt}; + use tokio::sync::mpsc; - #[runtime::test_basic] + #[runtime::test] async fn check_translates() { let (tx, mut rx) = mpsc::channel(1); diff --git a/comms/src/protocol/identity.rs b/comms/src/protocol/identity.rs index df2900e1af..6bcfca1adf 100644 --- a/comms/src/protocol/identity.rs +++ b/comms/src/protocol/identity.rs @@ -20,19 +20,21 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{ - compat::IoCompat, connection_manager::ConnectionDirection, message::MessageExt, peer_manager::NodeIdentity, proto::identity::PeerIdentityMsg, protocol::{NodeNetworkInfo, ProtocolError, ProtocolId, ProtocolNegotiation}, }; -use futures::{AsyncRead, AsyncWrite, SinkExt, StreamExt}; +use futures::{SinkExt, StreamExt}; use log::*; use prost::Message; use std::{io, time::Duration}; use thiserror::Error; -use tokio::time; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + time, +}; use tokio_util::codec::{Framed, LengthDelimitedCodec}; pub static IDENTITY_PROTOCOL: ProtocolId = ProtocolId::from_static(b"t/identity/1.0"); @@ -77,7 +79,7 @@ where debug_assert_eq!(proto, IDENTITY_PROTOCOL); // Create length-delimited frame codec - let framed = Framed::new(IoCompat::new(socket), LengthDelimitedCodec::new()); + let framed = Framed::new(socket, LengthDelimitedCodec::new()); let (mut sink, mut stream) = framed.split(); let supported_protocols = our_supported_protocols.into_iter().map(|p| p.to_vec()).collect(); @@ -134,8 +136,8 @@ pub enum IdentityProtocolError { ProtocolVersionMismatch, } -impl From<time::Elapsed> for IdentityProtocolError { - fn from(_: time::Elapsed) -> Self { +impl From<time::error::Elapsed> for IdentityProtocolError { + fn from(_: time::error::Elapsed) -> Self { IdentityProtocolError::Timeout } } @@ -170,7 +172,7 @@ mod test { }; use futures::{future, StreamExt}; - #[runtime::test_basic] + #[runtime::test] async fn identity_exchange() { let transport = MemoryTransport; let addr = "/memory/0".parse().unwrap(); @@ -219,7 +221,7 @@ mod test { assert_eq!(identity2.addresses, vec![node_identity2.public_address().to_vec()]); } - #[runtime::test_basic] + #[runtime::test] async fn fail_cases() { let transport = MemoryTransport; let addr = "/memory/0".parse().unwrap(); diff --git a/comms/src/protocol/messaging/error.rs b/comms/src/protocol/messaging/error.rs index 6f078cec23..91d0e786ba 100644 --- a/comms/src/protocol/messaging/error.rs +++ b/comms/src/protocol/messaging/error.rs @@ -20,10 +20,15 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{connection_manager::PeerConnectionError, peer_manager::PeerManagerError, protocol::ProtocolError}; -use futures::channel::mpsc; +use crate::{ + connection_manager::PeerConnectionError, + message::OutboundMessage, + peer_manager::PeerManagerError, + protocol::ProtocolError, +}; use std::io; use thiserror::Error; +use tokio::sync::mpsc; #[derive(Debug, Error)] pub enum InboundMessagingError { @@ -46,7 +51,7 @@ pub enum MessagingProtocolError { #[error("IO Error: {0}")] Io(#[from] io::Error), #[error("Sender error: {0}")] - SenderError(#[from] mpsc::SendError), + SenderError(#[from] mpsc::error::SendError<OutboundMessage>), #[error("Stream closed due to inactivity")] Inactivity, } diff --git a/comms/src/protocol/messaging/extension.rs b/comms/src/protocol/messaging/extension.rs index 241a152a5b..f216ddd04e 100644 --- a/comms/src/protocol/messaging/extension.rs +++ b/comms/src/protocol/messaging/extension.rs @@ -34,8 +34,8 @@ use crate::{ runtime, runtime::task, }; -use futures::channel::mpsc; use std::fmt; +use tokio::sync::mpsc; use tower::Service; /// Buffer size for inbound messages from _all_ peers. This should be large enough to buffer quite a few incoming diff --git a/comms/src/protocol/messaging/forward.rs b/comms/src/protocol/messaging/forward.rs new file mode 100644 index 0000000000..ce035e8fb8 --- /dev/null +++ b/comms/src/protocol/messaging/forward.rs @@ -0,0 +1,110 @@ +// Copyright 2021, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Copied from futures rs + +use futures::{ + future::{FusedFuture, Future}, + ready, + stream::{Fuse, StreamExt}, + task::{Context, Poll}, + Sink, + Stream, + TryStream, +}; +use pin_project::pin_project; +use std::pin::Pin; + +/// Future for the [`forward`](super::StreamExt::forward) method. +#[pin_project(project = ForwardProj)] +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct Forward<St, Si, Item> { + #[pin] + sink: Option<Si>, + #[pin] + stream: Fuse<St>, + buffered_item: Option<Item>, +} + +impl<St, Si, Item> Forward<St, Si, Item> +where St: TryStream +{ + pub(crate) fn new(stream: St, sink: Si) -> Self { + Self { + sink: Some(sink), + stream: stream.fuse(), + buffered_item: None, + } + } +} + +impl<St, Si, Item, E> FusedFuture for Forward<St, Si, Item> +where + Si: Sink<Item, Error = E>, + St: Stream<Item = Result<Item, E>>, +{ + fn is_terminated(&self) -> bool { + self.sink.is_none() + } +} + +impl<St, Si, Item, E> Future for Forward<St, Si, Item> +where + Si: Sink<Item, Error = E>, + St: Stream<Item = Result<Item, E>>, +{ + type Output = Result<(), E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let ForwardProj { + mut sink, + mut stream, + buffered_item, + } = self.project(); + let mut si = sink.as_mut().as_pin_mut().expect("polled `Forward` after completion"); + + loop { + // If we've got an item buffered already, we need to write it to the + // sink before we can do anything else + if buffered_item.is_some() { + ready!(si.as_mut().poll_ready(cx))?; + si.as_mut().start_send(buffered_item.take().unwrap())?; + } + + match stream.as_mut().poll_next(cx)? { + Poll::Ready(Some(item)) => { + *buffered_item = Some(item); + }, + Poll::Ready(None) => { + ready!(si.poll_close(cx))?; + sink.set(None); + return Poll::Ready(Ok(())); + }, + Poll::Pending => { + ready!(si.poll_flush(cx))?; + return Poll::Pending; + }, + } + } + } +} diff --git a/comms/src/protocol/messaging/inbound.rs b/comms/src/protocol/messaging/inbound.rs index 643b07fc45..aa592fe8c8 100644 --- a/comms/src/protocol/messaging/inbound.rs +++ b/comms/src/protocol/messaging/inbound.rs @@ -26,10 +26,13 @@ use crate::{ peer_manager::NodeId, protocol::messaging::{MessagingEvent, MessagingProtocol}, }; -use futures::{channel::mpsc, future::Either, AsyncRead, AsyncWrite, SinkExt, StreamExt}; +use futures::{future::Either, StreamExt}; use log::*; use std::{sync::Arc, time::Duration}; -use tokio::sync::broadcast; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{broadcast, mpsc}, +}; const LOG_TARGET: &str = "comms::protocol::messaging::inbound"; @@ -61,7 +64,7 @@ impl InboundMessaging { } } - pub async fn run<S>(mut self, socket: S) + pub async fn run<S>(self, socket: S) where S: AsyncRead + AsyncWrite + Unpin { let peer = &self.peer; debug!( @@ -70,48 +73,40 @@ impl InboundMessaging { peer.short_str() ); - let (mut sink, stream) = MessagingProtocol::framed(socket).split(); - - if let Err(err) = sink.close().await { - debug!( - target: LOG_TARGET, - "Error closing sink half for peer `{}`: {}", - peer.short_str(), - err - ); - } - let stream = stream.rate_limit(self.rate_limit_capacity, self.rate_limit_restock_interval); + let stream = + MessagingProtocol::framed(socket).rate_limit(self.rate_limit_capacity, self.rate_limit_restock_interval); - let mut stream = match self.inactivity_timeout { - Some(timeout) => Either::Left(tokio::stream::StreamExt::timeout(stream, timeout)), + let stream = match self.inactivity_timeout { + Some(timeout) => Either::Left(tokio_stream::StreamExt::timeout(stream, timeout)), None => Either::Right(stream.map(Ok)), }; + tokio::pin!(stream); while let Some(result) = stream.next().await { match result { Ok(Ok(raw_msg)) => { - let inbound_msg = InboundMessage::new(peer.clone(), raw_msg.clone().freeze()); + let msg_len = raw_msg.len(); + let inbound_msg = InboundMessage::new(peer.clone(), raw_msg.freeze()); debug!( target: LOG_TARGET, "Received message {} from peer '{}' ({} bytes)", inbound_msg.tag, peer.short_str(), - raw_msg.len() + msg_len ); let event = MessagingEvent::MessageReceived(inbound_msg.source_peer.clone(), inbound_msg.tag); if let Err(err) = self.inbound_message_tx.send(inbound_msg).await { + let tag = err.0.tag; warn!( target: LOG_TARGET, - "Failed to send InboundMessage for peer '{}' because '{}'", + "Failed to send InboundMessage {} for peer '{}' because inbound message channel closed", + tag, peer.short_str(), - err ); - if err.is_disconnected() { - break; - } + break; } let _ = self.messaging_events_tx.send(Arc::new(event)); diff --git a/comms/src/protocol/messaging/mod.rs b/comms/src/protocol/messaging/mod.rs index 88fca6af05..1fa347b3a2 100644 --- a/comms/src/protocol/messaging/mod.rs +++ b/comms/src/protocol/messaging/mod.rs @@ -27,9 +27,9 @@ mod extension; pub use extension::MessagingProtocolExtension; mod error; +mod forward; mod inbound; mod outbound; - mod protocol; pub use protocol::{ MessagingEvent, diff --git a/comms/src/protocol/messaging/outbound.rs b/comms/src/protocol/messaging/outbound.rs index 377f7a2d8a..8ab1298c88 100644 --- a/comms/src/protocol/messaging/outbound.rs +++ b/comms/src/protocol/messaging/outbound.rs @@ -29,13 +29,13 @@ use crate::{ peer_manager::NodeId, protocol::messaging::protocol::MESSAGING_PROTOCOL, }; -use futures::{channel::mpsc, future::Either, SinkExt, StreamExt}; +use futures::{future::Either, StreamExt, TryStreamExt}; use log::*; use std::{ io, time::{Duration, Instant}, }; -use tokio::stream as tokio_stream; +use tokio::sync::mpsc as tokiompsc; const LOG_TARGET: &str = "comms::protocol::messaging::outbound"; /// The number of times to retry sending a failed message before publishing a SendMessageFailed event. @@ -45,8 +45,8 @@ const MAX_SEND_RETRIES: usize = 1; pub struct OutboundMessaging { connectivity: ConnectivityRequester, - request_rx: mpsc::UnboundedReceiver<OutboundMessage>, - messaging_events_tx: mpsc::Sender<MessagingEvent>, + request_rx: tokiompsc::UnboundedReceiver<OutboundMessage>, + messaging_events_tx: tokiompsc::Sender<MessagingEvent>, peer_node_id: NodeId, inactivity_timeout: Option<Duration>, } @@ -54,8 +54,8 @@ pub struct OutboundMessaging { impl OutboundMessaging { pub fn new( connectivity: ConnectivityRequester, - messaging_events_tx: mpsc::Sender<MessagingEvent>, - request_rx: mpsc::UnboundedReceiver<OutboundMessage>, + messaging_events_tx: tokiompsc::Sender<MessagingEvent>, + request_rx: tokiompsc::UnboundedReceiver<OutboundMessage>, peer_node_id: NodeId, inactivity_timeout: Option<Duration>, ) -> Self { @@ -75,7 +75,7 @@ impl OutboundMessaging { self.peer_node_id.short_str() ); let peer_node_id = self.peer_node_id.clone(); - let mut messaging_events_tx = self.messaging_events_tx.clone(); + let messaging_events_tx = self.messaging_events_tx.clone(); match self.run_inner().await { Ok(_) => { debug!( @@ -211,7 +211,7 @@ impl OutboundMessaging { ); let substream = substream.stream; - let (sink, _) = MessagingProtocol::framed(substream).split(); + let framed = MessagingProtocol::framed(substream); let Self { request_rx, @@ -219,30 +219,30 @@ impl OutboundMessaging { .. } = self; + // Convert unbounded channel to a stream + let stream = futures::stream::unfold(request_rx, |mut rx| async move { + let v = rx.recv().await; + v.map(|v| (v, rx)) + }); + let stream = match inactivity_timeout { Some(timeout) => { - let s = tokio_stream::StreamExt::timeout(request_rx, timeout).map(|r| match r { - Ok(s) => Ok(s), - Err(_) => Err(io::Error::new( - io::ErrorKind::TimedOut, - MessagingProtocolError::Inactivity, - )), - }); + let s = tokio_stream::StreamExt::timeout(stream, timeout) + .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, MessagingProtocolError::Inactivity)); Either::Left(s) }, - None => Either::Right(request_rx.map(Ok)), + None => Either::Right(stream.map(Ok)), }; - stream - .map(|msg| { - msg.map(|mut out_msg| { - trace!(target: LOG_TARGET, "Message buffered for sending {}", out_msg); - out_msg.reply_success(); - out_msg.body - }) + let stream = stream.map(|msg| { + msg.map(|mut out_msg| { + trace!(target: LOG_TARGET, "Message buffered for sending {}", out_msg); + out_msg.reply_success(); + out_msg.body }) - .forward(sink) - .await?; + }); + + super::forward::Forward::new(stream, framed).await?; debug!( target: LOG_TARGET, @@ -256,7 +256,7 @@ impl OutboundMessaging { // Close the request channel so that we can read all the remaining messages and flush them // to a failed event self.request_rx.close(); - while let Some(mut out_msg) = self.request_rx.next().await { + while let Some(mut out_msg) = self.request_rx.recv().await { out_msg.reply_fail(reason); let _ = self .messaging_events_tx diff --git a/comms/src/protocol/messaging/protocol.rs b/comms/src/protocol/messaging/protocol.rs index 6e53e80e6c..39bd8d6242 100644 --- a/comms/src/protocol/messaging/protocol.rs +++ b/comms/src/protocol/messaging/protocol.rs @@ -22,7 +22,6 @@ use super::error::MessagingProtocolError; use crate::{ - compat::IoCompat, connectivity::{ConnectivityEvent, ConnectivityRequester}, framing, message::{InboundMessage, MessageTag, OutboundMessage}, @@ -36,7 +35,6 @@ use crate::{ runtime::task, }; use bytes::Bytes; -use futures::{channel::mpsc, stream::Fuse, AsyncRead, AsyncWrite, SinkExt, StreamExt}; use log::*; use std::{ collections::{hash_map::Entry, HashMap}, @@ -46,7 +44,10 @@ use std::{ }; use tari_shutdown::{Shutdown, ShutdownSignal}; use thiserror::Error; -use tokio::sync::broadcast; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{broadcast, mpsc}, +}; use tokio_util::codec::{Framed, LengthDelimitedCodec}; const LOG_TARGET: &str = "comms::protocol::messaging"; @@ -106,13 +107,13 @@ impl fmt::Display for MessagingEvent { pub struct MessagingProtocol { config: MessagingConfig, connectivity: ConnectivityRequester, - proto_notification: Fuse<mpsc::Receiver<ProtocolNotification<Substream>>>, + proto_notification: mpsc::Receiver<ProtocolNotification<Substream>>, active_queues: HashMap<NodeId, mpsc::UnboundedSender<OutboundMessage>>, - request_rx: Fuse<mpsc::Receiver<MessagingRequest>>, + request_rx: mpsc::Receiver<MessagingRequest>, messaging_events_tx: MessagingEventSender, inbound_message_tx: mpsc::Sender<InboundMessage>, internal_messaging_event_tx: mpsc::Sender<MessagingEvent>, - internal_messaging_event_rx: Fuse<mpsc::Receiver<MessagingEvent>>, + internal_messaging_event_rx: mpsc::Receiver<MessagingEvent>, shutdown_signal: ShutdownSignal, complete_trigger: Shutdown, } @@ -133,11 +134,11 @@ impl MessagingProtocol { Self { config, connectivity, - proto_notification: proto_notification.fuse(), - request_rx: request_rx.fuse(), + proto_notification, + request_rx, active_queues: Default::default(), messaging_events_tx, - internal_messaging_event_rx: internal_messaging_event_rx.fuse(), + internal_messaging_event_rx, internal_messaging_event_tx, inbound_message_tx, shutdown_signal, @@ -151,15 +152,15 @@ impl MessagingProtocol { pub async fn run(mut self) { let mut shutdown_signal = self.shutdown_signal.clone(); - let mut connectivity_events = self.connectivity.get_event_subscription().fuse(); + let mut connectivity_events = self.connectivity.get_event_subscription(); loop { - futures::select! { - event = self.internal_messaging_event_rx.select_next_some() => { + tokio::select! { + Some(event) = self.internal_messaging_event_rx.recv() => { self.handle_internal_messaging_event(event).await; }, - req = self.request_rx.select_next_some() => { + Some(req) = self.request_rx.recv() => { if let Err(err) = self.handle_request(req).await { error!( target: LOG_TARGET, @@ -169,17 +170,17 @@ impl MessagingProtocol { } }, - event = connectivity_events.select_next_some() => { + event = connectivity_events.recv() => { if let Ok(event) = event { self.handle_connectivity_event(&event); } } - notification = self.proto_notification.select_next_some() => { + Some(notification) = self.proto_notification.recv() => { self.handle_protocol_notification(notification).await; }, - _ = shutdown_signal => { + _ = &mut shutdown_signal => { info!(target: LOG_TARGET, "MessagingProtocol is shutting down because the shutdown signal was triggered"); break; } @@ -188,7 +189,7 @@ impl MessagingProtocol { } #[inline] - pub fn framed<TSubstream>(socket: TSubstream) -> Framed<IoCompat<TSubstream>, LengthDelimitedCodec> + pub fn framed<TSubstream>(socket: TSubstream) -> Framed<TSubstream, LengthDelimitedCodec> where TSubstream: AsyncRead + AsyncWrite + Unpin { framing::canonical(socket, MAX_FRAME_LENGTH) } @@ -198,11 +199,9 @@ impl MessagingProtocol { #[allow(clippy::single_match)] match event { PeerConnectionWillClose(node_id, _) => { - // If the peer connection will close, cut off the pipe to send further messages. - // Any messages in the channel will be sent (hopefully) before the connection is disconnected. - if let Some(sender) = self.active_queues.remove(node_id) { - sender.close_channel(); - } + // If the peer connection will close, cut off the pipe to send further messages by dropping the sender. + // Any messages in the channel may be sent before the connection is disconnected. + let _ = self.active_queues.remove(node_id); }, _ => {}, } @@ -272,7 +271,7 @@ impl MessagingProtocol { debug!(target: LOG_TARGET, "Sending message {}", out_msg); let tag = out_msg.tag; - match sender.send(out_msg).await { + match sender.send(out_msg) { Ok(_) => { debug!(target: LOG_TARGET, "Message ({}) dispatched to outbound handler", tag,); Ok(()) @@ -293,7 +292,7 @@ impl MessagingProtocol { peer_node_id: NodeId, inactivity_timeout: Option<Duration>, ) -> mpsc::UnboundedSender<OutboundMessage> { - let (msg_tx, msg_rx) = mpsc::unbounded(); + let (msg_tx, msg_rx) = mpsc::unbounded_channel(); let outbound_messaging = OutboundMessaging::new(connectivity, events_tx, msg_rx, peer_node_id, inactivity_timeout); task::spawn(outbound_messaging.run()); diff --git a/comms/src/protocol/messaging/test.rs b/comms/src/protocol/messaging/test.rs index e954af31f2..5d066eb4c2 100644 --- a/comms/src/protocol/messaging/test.rs +++ b/comms/src/protocol/messaging/test.rs @@ -49,18 +49,16 @@ use crate::{ types::{CommsDatabase, CommsPublicKey}, }; use bytes::Bytes; -use futures::{ - channel::{mpsc, oneshot}, - stream::FuturesUnordered, - SinkExt, - StreamExt, -}; +use futures::{stream::FuturesUnordered, SinkExt, StreamExt}; use rand::rngs::OsRng; use std::{io, sync::Arc, time::Duration}; use tari_crypto::keys::PublicKey; use tari_shutdown::Shutdown; -use tari_test_utils::{collect_stream, unpack_enum}; -use tokio::{sync::broadcast, time}; +use tari_test_utils::{collect_recv, collect_stream, unpack_enum}; +use tokio::{ + sync::{broadcast, mpsc, oneshot}, + time, +}; static TEST_MSG1: Bytes = Bytes::from_static(b"TEST_MSG1"); @@ -110,9 +108,9 @@ async fn spawn_messaging_protocol() -> ( ) } -#[runtime::test_basic] +#[runtime::test] async fn new_inbound_substream_handling() { - let (peer_manager, _, _, mut proto_tx, _, mut inbound_msg_rx, mut events_rx, _shutdown) = + let (peer_manager, _, _, proto_tx, _, mut inbound_msg_rx, mut events_rx, _shutdown) = spawn_messaging_protocol().await; let expected_node_id = node_id::random(); @@ -148,7 +146,7 @@ async fn new_inbound_substream_handling() { framed_theirs.send(TEST_MSG1.clone()).await.unwrap(); - let in_msg = time::timeout(Duration::from_secs(5), inbound_msg_rx.next()) + let in_msg = time::timeout(Duration::from_secs(5), inbound_msg_rx.recv()) .await .unwrap() .unwrap(); @@ -156,19 +154,18 @@ async fn new_inbound_substream_handling() { assert_eq!(in_msg.body, TEST_MSG1); let expected_tag = in_msg.tag; - let event = time::timeout(Duration::from_secs(5), events_rx.next()) + let event = time::timeout(Duration::from_secs(5), events_rx.recv()) .await .unwrap() - .unwrap() .unwrap(); unpack_enum!(MessagingEvent::MessageReceived(node_id, tag) = &*event); assert_eq!(tag, &expected_tag); assert_eq!(*node_id, expected_node_id); } -#[runtime::test_basic] +#[runtime::test] async fn send_message_request() { - let (_, node_identity, conn_man_mock, _, mut request_tx, _, _, _shutdown) = spawn_messaging_protocol().await; + let (_, node_identity, conn_man_mock, _, request_tx, _, _, _shutdown) = spawn_messaging_protocol().await; let peer_node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -192,9 +189,9 @@ async fn send_message_request() { assert_eq!(peer_conn_mock1.call_count(), 1); } -#[runtime::test_basic] +#[runtime::test] async fn send_message_dial_failed() { - let (_, _, conn_manager_mock, _, mut request_tx, _, mut event_tx, _shutdown) = spawn_messaging_protocol().await; + let (_, _, conn_manager_mock, _, request_tx, _, mut event_tx, _shutdown) = spawn_messaging_protocol().await; let node_id = node_id::random(); let out_msg = OutboundMessage::new(node_id, TEST_MSG1.clone()); @@ -202,7 +199,7 @@ async fn send_message_dial_failed() { // Send a message to node 2 request_tx.send(MessagingRequest::SendMessage(out_msg)).await.unwrap(); - let event = event_tx.next().await.unwrap().unwrap(); + let event = event_tx.recv().await.unwrap(); unpack_enum!(MessagingEvent::SendMessageFailed(out_msg, reason) = &*event); unpack_enum!(SendFailReason::PeerDialFailed = reason); assert_eq!(out_msg.tag, expected_out_msg_tag); @@ -212,7 +209,7 @@ async fn send_message_dial_failed() { assert!(calls.iter().all(|evt| evt.starts_with("DialPeer"))); } -#[runtime::test_basic] +#[runtime::test] async fn send_message_substream_bulk_failure() { const NUM_MSGS: usize = 10; let (_, node_identity, conn_manager_mock, _, mut request_tx, _, mut events_rx, _shutdown) = @@ -258,19 +255,18 @@ async fn send_message_substream_bulk_failure() { } // Check that the outbound handler closed - let event = time::timeout(Duration::from_secs(10), events_rx.next()) + let event = time::timeout(Duration::from_secs(10), events_rx.recv()) .await .unwrap() - .unwrap() .unwrap(); unpack_enum!(MessagingEvent::OutboundProtocolExited(node_id) = &*event); assert_eq!(node_id, peer_node_id); } -#[runtime::test_basic] +#[runtime::test] async fn many_concurrent_send_message_requests() { const NUM_MSGS: usize = 100; - let (_, _, conn_man_mock, _, mut request_tx, _, _, _shutdown) = spawn_messaging_protocol().await; + let (_, _, conn_man_mock, _, request_tx, _, _, _shutdown) = spawn_messaging_protocol().await; let node_identity1 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); let node_identity2 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -315,10 +311,10 @@ async fn many_concurrent_send_message_requests() { assert_eq!(peer_conn_mock1.call_count(), 1); } -#[runtime::test_basic] +#[runtime::test] async fn many_concurrent_send_message_requests_that_fail() { const NUM_MSGS: usize = 100; - let (_, _, _, _, mut request_tx, _, mut events_rx, _shutdown) = spawn_messaging_protocol().await; + let (_, _, _, _, request_tx, _, mut events_rx, _shutdown) = spawn_messaging_protocol().await; let node_id2 = node_id::random(); @@ -339,10 +335,9 @@ async fn many_concurrent_send_message_requests_that_fail() { } // Check that we got message success events - let events = collect_stream!(events_rx, take = NUM_MSGS, timeout = Duration::from_secs(10)); + let events = collect_recv!(events_rx, take = NUM_MSGS, timeout = Duration::from_secs(10)); assert_eq!(events.len(), NUM_MSGS); for event in events { - let event = event.unwrap(); unpack_enum!(MessagingEvent::SendMessageFailed(out_msg, reason) = &*event); unpack_enum!(SendFailReason::PeerDialFailed = reason); // Assert that each tag is emitted only once @@ -357,7 +352,7 @@ async fn many_concurrent_send_message_requests_that_fail() { assert_eq!(msg_tags.len(), 0); } -#[runtime::test_basic] +#[runtime::test] async fn inactivity_timeout() { let node_identity = build_node_identity(PeerFeatures::COMMUNICATION_CLIENT); let (inbound_msg_tx, mut inbound_msg_rx) = mpsc::channel(5); @@ -381,13 +376,13 @@ async fn inactivity_timeout() { let mut framed = MessagingProtocol::framed(socket_out); for _ in 0..5u8 { framed.send(Bytes::from_static(b"some message")).await.unwrap(); - time::delay_for(Duration::from_millis(1)).await; + time::sleep(Duration::from_millis(1)).await; } - time::delay_for(Duration::from_millis(10)).await; + time::sleep(Duration::from_millis(10)).await; let err = framed.send(Bytes::from_static(b"another message")).await.unwrap_err(); assert_eq!(err.kind(), io::ErrorKind::BrokenPipe); - let _ = collect_stream!(inbound_msg_rx, take = 5, timeout = Duration::from_secs(10)); + let _ = collect_recv!(inbound_msg_rx, take = 5, timeout = Duration::from_secs(10)); } diff --git a/comms/src/protocol/negotiation.rs b/comms/src/protocol/negotiation.rs index 5178d790e6..326910dc99 100644 --- a/comms/src/protocol/negotiation.rs +++ b/comms/src/protocol/negotiation.rs @@ -23,9 +23,9 @@ use super::{ProtocolError, ProtocolId}; use bitflags::bitflags; use bytes::{Bytes, BytesMut}; -use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use log::*; use std::convert::TryInto; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; const LOG_TARGET: &str = "comms::connection_manager::protocol"; @@ -204,7 +204,7 @@ mod test { use futures::future; use tari_test_utils::unpack_enum; - #[runtime::test_basic] + #[runtime::test] async fn negotiate_success() { let (mut initiator, mut responder) = MemorySocket::new_pair(); let mut negotiate_out = ProtocolNegotiation::new(&mut initiator); @@ -229,7 +229,7 @@ mod test { assert_eq!(out_proto.unwrap(), ProtocolId::from_static(b"A")); } - #[runtime::test_basic] + #[runtime::test] async fn negotiate_fail() { let (mut initiator, mut responder) = MemorySocket::new_pair(); let mut negotiate_out = ProtocolNegotiation::new(&mut initiator); @@ -254,7 +254,7 @@ mod test { unpack_enum!(ProtocolError::ProtocolOutboundNegotiationFailed(_s) = out_proto.unwrap_err()); } - #[runtime::test_basic] + #[runtime::test] async fn negotiate_fail_max_rounds() { let (mut initiator, mut responder) = MemorySocket::new_pair(); let mut negotiate_out = ProtocolNegotiation::new(&mut initiator); @@ -279,7 +279,7 @@ mod test { unpack_enum!(ProtocolError::ProtocolNegotiationTerminatedByPeer = out_proto.unwrap_err()); } - #[runtime::test_basic] + #[runtime::test] async fn negotiate_success_optimistic() { let (mut initiator, mut responder) = MemorySocket::new_pair(); let mut negotiate_out = ProtocolNegotiation::new(&mut initiator); @@ -300,7 +300,7 @@ mod test { out_proto.unwrap(); } - #[runtime::test_basic] + #[runtime::test] async fn negotiate_fail_optimistic() { let (mut initiator, mut responder) = MemorySocket::new_pair(); let mut negotiate_out = ProtocolNegotiation::new(&mut initiator); diff --git a/comms/src/protocol/protocols.rs b/comms/src/protocol/protocols.rs index 936ef15f34..14d253196e 100644 --- a/comms/src/protocol/protocols.rs +++ b/comms/src/protocol/protocols.rs @@ -32,8 +32,8 @@ use crate::{ }, Substream, }; -use futures::{channel::mpsc, SinkExt}; use std::collections::HashMap; +use tokio::sync::mpsc; pub type ProtocolNotificationTx<TSubstream> = mpsc::Sender<ProtocolNotification<TSubstream>>; pub type ProtocolNotificationRx<TSubstream> = mpsc::Receiver<ProtocolNotification<TSubstream>>; @@ -143,7 +143,6 @@ impl ProtocolExtension for Protocols<Substream> { mod test { use super::*; use crate::runtime; - use futures::StreamExt; use tari_test_utils::unpack_enum; #[test] @@ -160,7 +159,7 @@ mod test { assert!(protocols.get_supported_protocols().iter().all(|p| protos.contains(p))); } - #[runtime::test_basic] + #[runtime::test] async fn notify() { let (tx, mut rx) = mpsc::channel(1); let protos = [ProtocolId::from_static(b"/tari/test/1")]; @@ -172,12 +171,12 @@ mod test { .await .unwrap(); - let notification = rx.next().await.unwrap(); + let notification = rx.recv().await.unwrap(); unpack_enum!(ProtocolEvent::NewInboundSubstream(peer_id, _s) = notification.event); assert_eq!(peer_id, NodeId::new()); } - #[runtime::test_basic] + #[runtime::test] async fn notify_fail_not_registered() { let mut protocols = Protocols::<()>::new(); diff --git a/comms/src/protocol/rpc/body.rs b/comms/src/protocol/rpc/body.rs index 6079508729..e563d6483e 100644 --- a/comms/src/protocol/rpc/body.rs +++ b/comms/src/protocol/rpc/body.rs @@ -27,7 +27,6 @@ use crate::{ }; use bytes::BytesMut; use futures::{ - channel::mpsc, ready, stream::BoxStream, task::{Context, Poll}, @@ -37,6 +36,7 @@ use futures::{ use pin_project::pin_project; use prost::bytes::Buf; use std::{fmt, marker::PhantomData, pin::Pin}; +use tokio::sync::mpsc; pub trait IntoBody { fn into_body(self) -> Body; @@ -205,8 +205,8 @@ impl Buf for BodyBytes { self.0.as_ref().map(Buf::remaining).unwrap_or(0) } - fn bytes(&self) -> &[u8] { - self.0.as_ref().map(Buf::bytes).unwrap_or(&[]) + fn chunk(&self) -> &[u8] { + self.0.as_ref().map(Buf::chunk).unwrap_or(&[]) } fn advance(&mut self, cnt: usize) { @@ -227,7 +227,7 @@ impl<T> Streaming<T> { } pub fn empty() -> Self { - let (_, rx) = mpsc::channel(0); + let (_, rx) = mpsc::channel(1); Self { inner: rx } } @@ -240,7 +240,7 @@ impl<T: prost::Message> Stream for Streaming<T> { type Item = Result<Bytes, RpcStatus>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - match ready!(self.inner.poll_next_unpin(cx)) { + match ready!(Pin::new(&mut self.inner).poll_recv(cx)) { Some(result) => { let result = result.map(|msg| msg.to_encoded_bytes().into()); Poll::Ready(Some(result)) @@ -275,7 +275,7 @@ impl<T: prost::Message + Default + Unpin> Stream for ClientStreaming<T> { type Item = Result<T, RpcStatus>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - match ready!(self.inner.poll_next_unpin(cx)) { + match ready!(Pin::new(&mut self.inner).poll_recv(cx)) { Some(Ok(resp)) => { // The streaming protocol dictates that an empty finish flag MUST be sent to indicate a terminated // stream. This empty response need not be emitted to downsteam consumers. @@ -298,7 +298,7 @@ mod test { use futures::{stream, StreamExt}; use prost::Message; - #[runtime::test_basic] + #[runtime::test] async fn single_body() { let mut body = Body::single(123u32.to_encoded_bytes()); let bytes = body.next().await.unwrap().unwrap(); @@ -306,7 +306,7 @@ mod test { assert_eq!(u32::decode(bytes).unwrap(), 123u32); } - #[runtime::test_basic] + #[runtime::test] async fn streaming_body() { let body = Body::streaming(stream::repeat(Bytes::new()).map(Ok).take(10)); let body = body.collect::<Vec<_>>().await; diff --git a/comms/src/protocol/rpc/client.rs b/comms/src/protocol/rpc/client.rs index 79cd5540a7..0b61b0c451 100644 --- a/comms/src/protocol/rpc/client.rs +++ b/comms/src/protocol/rpc/client.rs @@ -41,11 +41,8 @@ use crate::{ }; use bytes::Bytes; use futures::{ - channel::{mpsc, oneshot}, future::{BoxFuture, Either}, task::{Context, Poll}, - AsyncRead, - AsyncWrite, FutureExt, SinkExt, StreamExt, @@ -58,9 +55,15 @@ use std::{ fmt, future::Future, marker::PhantomData, + sync::Arc, time::{Duration, Instant}, }; -use tokio::time; +use tari_shutdown::{Shutdown, ShutdownSignal}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{mpsc, oneshot, Mutex}, + time, +}; use tower::{Service, ServiceExt}; const LOG_TARGET: &str = "comms::rpc::client"; @@ -81,9 +84,11 @@ impl RpcClient { TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static, { let (request_tx, request_rx) = mpsc::channel(1); - let connector = ClientConnector::new(request_tx); + let shutdown = Shutdown::new(); + let shutdown_signal = shutdown.to_signal(); + let connector = ClientConnector::new(request_tx, shutdown); let (ready_tx, ready_rx) = oneshot::channel(); - task::spawn(RpcClientWorker::new(config, request_rx, framed, ready_tx, protocol_name).run()); + task::spawn(RpcClientWorker::new(config, request_rx, framed, ready_tx, protocol_name, shutdown_signal).run()); ready_rx .await .expect("ready_rx oneshot is never dropped without a reply")?; @@ -101,7 +106,7 @@ impl RpcClient { let request = BaseRequest::new(method.into(), req_bytes.into()); let mut resp = self.call_inner(request).await?; - let resp = resp.next().await.ok_or(RpcError::ServerClosedRequest)??; + let resp = resp.recv().await.ok_or(RpcError::ServerClosedRequest)??; let resp = R::decode(resp.into_message())?; Ok(resp) @@ -123,8 +128,8 @@ impl RpcClient { } /// Close the RPC session. Any subsequent calls will error. - pub fn close(&mut self) { - self.connector.close() + pub async fn close(&mut self) { + self.connector.close().await; } pub fn is_connected(&self) -> bool { @@ -260,15 +265,20 @@ impl Default for RpcClientConfig { #[derive(Clone)] pub struct ClientConnector { inner: mpsc::Sender<ClientRequest>, + shutdown: Arc<Mutex<Shutdown>>, } impl ClientConnector { - pub(self) fn new(sender: mpsc::Sender<ClientRequest>) -> Self { - Self { inner: sender } + pub(self) fn new(sender: mpsc::Sender<ClientRequest>, shutdown: Shutdown) -> Self { + Self { + inner: sender, + shutdown: Arc::new(Mutex::new(shutdown)), + } } - pub fn close(&mut self) { - self.inner.close_channel(); + pub async fn close(&mut self) { + let mut lock = self.shutdown.lock().await; + lock.trigger(); } pub async fn get_last_request_latency(&mut self) -> Result<Option<Duration>, RpcError> { @@ -308,13 +318,13 @@ impl Service<BaseRequest<Bytes>> for ClientConnector { type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>; type Response = mpsc::Receiver<Result<Response<Bytes>, RpcStatus>>; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { - self.inner.poll_ready_unpin(cx).map_err(|_| RpcError::ClientClosed) + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) } fn call(&mut self, request: BaseRequest<Bytes>) -> Self::Future { let (reply, reply_rx) = oneshot::channel(); - let mut inner = self.inner.clone(); + let inner = self.inner.clone(); async move { inner .send(ClientRequest::SendRequest { request, reply }) @@ -337,6 +347,7 @@ pub struct RpcClientWorker<TSubstream> { ready_tx: Option<oneshot::Sender<Result<(), RpcError>>>, last_request_latency: Option<Duration>, protocol_id: ProtocolId, + shutdown_signal: ShutdownSignal, } impl<TSubstream> RpcClientWorker<TSubstream> @@ -348,6 +359,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send framed: CanonicalFraming<TSubstream>, ready_tx: oneshot::Sender<Result<(), RpcError>>, protocol_id: ProtocolId, + shutdown_signal: ShutdownSignal, ) -> Self { Self { config, @@ -357,6 +369,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send ready_tx: Some(ready_tx), last_request_latency: None, protocol_id, + shutdown_signal, } } @@ -395,26 +408,26 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send }, } - while let Some(req) = self.request_rx.next().await { - use ClientRequest::*; - match req { - SendRequest { request, reply } => { - if let Err(err) = self.do_request_response(request, reply).await { - error!(target: LOG_TARGET, "Unexpected error: {}. Worker is terminating.", err); - break; - } - }, - GetLastRequestLatency(reply) => { - let _ = reply.send(self.last_request_latency); - }, - SendPing(reply) => { - if let Err(err) = self.do_ping_pong(reply).await { - error!(target: LOG_TARGET, "Unexpected error: {}. Worker is terminating.", err); - break; + loop { + tokio::select! { + biased; + _ = &mut self.shutdown_signal => { + break; + } + req = self.request_rx.recv() => { + match req { + Some(req) => { + if let Err(err) = self.handle_request(req).await { + error!(target: LOG_TARGET, "Unexpected error: {}. Worker is terminating.", err); + break; + } + } + None => break, } - }, + } } } + if let Err(err) = self.framed.close().await { debug!(target: LOG_TARGET, "IO Error when closing substream: {}", err); } @@ -426,6 +439,22 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send ); } + async fn handle_request(&mut self, req: ClientRequest) -> Result<(), RpcError> { + use ClientRequest::*; + match req { + SendRequest { request, reply } => { + self.do_request_response(request, reply).await?; + }, + GetLastRequestLatency(reply) => { + let _ = reply.send(self.last_request_latency); + }, + SendPing(reply) => { + self.do_ping_pong(reply).await?; + }, + } + Ok(()) + } + async fn do_ping_pong(&mut self, reply: oneshot::Sender<Result<Duration, RpcStatus>>) -> Result<(), RpcError> { let ack = proto::rpc::RpcRequest { flags: RpcMessageFlags::ACK.bits() as u32, @@ -492,10 +521,10 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send let start = Instant::now(); self.framed.send(req.to_encoded_bytes().into()).await?; - let (mut response_tx, response_rx) = mpsc::channel(10); - if reply.send(response_rx).is_err() { + let (response_tx, response_rx) = mpsc::channel(10); + if let Err(mut rx) = reply.send(response_rx) { debug!(target: LOG_TARGET, "Client request was cancelled."); - response_tx.close_channel(); + rx.close(); } loop { @@ -522,8 +551,20 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send method, start.elapsed() ); - let _ = response_tx.send(Err(RpcStatus::timed_out("Response timed out"))).await; - response_tx.close_channel(); + if !response_tx.is_closed() { + let _ = response_tx.send(Err(RpcStatus::timed_out("Response timed out"))).await; + } + break; + }, + Err(RpcError::ClientClosed) => { + debug!( + target: LOG_TARGET, + "Request {} (method={}) was closed after {:.0?} (read_reply)", + request_id, + method, + start.elapsed() + ); + self.request_rx.close(); break; }, Err(err) => return Err(err), @@ -546,7 +587,6 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send let _ = response_tx.send(Ok(resp)).await; } if is_finished { - response_tx.close_channel(); break; } }, @@ -555,7 +595,6 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send if !response_tx.is_closed() { let _ = response_tx.send(Err(err)).await; } - response_tx.close_channel(); break; }, Err(err @ RpcError::ResponseIdDidNotMatchRequest { .. }) | @@ -580,7 +619,15 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send None => Either::Right(self.framed.next().map(Ok)), }; - match next_msg_fut.await { + let result = tokio::select! { + biased; + _ = &mut self.shutdown_signal => { + return Err(RpcError::ClientClosed); + } + result = next_msg_fut => result, + }; + + match result { Ok(Some(Ok(resp))) => Ok(proto::rpc::RpcResponse::decode(resp)?), Ok(Some(Err(err))) => Err(err.into()), Ok(None) => Err(RpcError::ServerClosedRequest), diff --git a/comms/src/protocol/rpc/client_pool.rs b/comms/src/protocol/rpc/client_pool.rs index 6829b41265..7cf99ed419 100644 --- a/comms/src/protocol/rpc/client_pool.rs +++ b/comms/src/protocol/rpc/client_pool.rs @@ -61,6 +61,11 @@ where T: RpcPoolClient + From<RpcClient> + NamedProtocolService + Clone let mut pool = self.pool.lock().await; pool.get_least_used_or_connect().await } + + pub async fn is_connected(&self) -> bool { + let pool = self.pool.lock().await; + pool.is_connected() + } } #[derive(Clone)] @@ -111,6 +116,10 @@ where T: RpcPoolClient + From<RpcClient> + NamedProtocolService + Clone } } + pub fn is_connected(&self) -> bool { + self.connection.is_connected() + } + pub(super) fn refresh_num_active_connections(&mut self) -> usize { self.prune(); self.clients.len() diff --git a/comms/src/protocol/rpc/handshake.rs b/comms/src/protocol/rpc/handshake.rs index 7c4dca1ae9..f73b65d74a 100644 --- a/comms/src/protocol/rpc/handshake.rs +++ b/comms/src/protocol/rpc/handshake.rs @@ -22,11 +22,14 @@ use crate::{framing::CanonicalFraming, message::MessageExt, proto, protocol::rpc::error::HandshakeRejectReason}; use bytes::BytesMut; -use futures::{AsyncRead, AsyncWrite, SinkExt, StreamExt}; +use futures::{SinkExt, StreamExt}; use log::*; use prost::{DecodeError, Message}; use std::{io, time::Duration}; -use tokio::time; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + time, +}; const LOG_TARGET: &str = "comms::rpc::handshake"; @@ -138,7 +141,7 @@ where T: AsyncRead + AsyncWrite + Unpin } } - async fn recv_next_frame(&mut self) -> Result<Option<Result<BytesMut, io::Error>>, time::Elapsed> { + async fn recv_next_frame(&mut self) -> Result<Option<Result<BytesMut, io::Error>>, time::error::Elapsed> { match self.timeout { Some(timeout) => time::timeout(timeout, self.framed.next()).await, None => Ok(self.framed.next().await), diff --git a/comms/src/protocol/rpc/mod.rs b/comms/src/protocol/rpc/mod.rs index d4e91fa8e4..2244979adf 100644 --- a/comms/src/protocol/rpc/mod.rs +++ b/comms/src/protocol/rpc/mod.rs @@ -80,6 +80,7 @@ pub mod __macro_reexports { }, Bytes, }; - pub use futures::{future, future::BoxFuture, AsyncRead, AsyncWrite}; + pub use futures::{future, future::BoxFuture}; + pub use tokio::io::{AsyncRead, AsyncWrite}; pub use tower::Service; } diff --git a/comms/src/protocol/rpc/server/error.rs b/comms/src/protocol/rpc/server/error.rs index 5078c6c588..6972cec60b 100644 --- a/comms/src/protocol/rpc/server/error.rs +++ b/comms/src/protocol/rpc/server/error.rs @@ -21,9 +21,9 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::protocol::rpc::handshake::RpcHandshakeError; -use futures::channel::oneshot; use prost::DecodeError; use std::io; +use tokio::sync::oneshot; #[derive(Debug, thiserror::Error)] pub enum RpcServerError { @@ -41,8 +41,8 @@ pub enum RpcServerError { ProtocolServiceNotFound(String), } -impl From<oneshot::Canceled> for RpcServerError { - fn from(_: oneshot::Canceled) -> Self { +impl From<oneshot::error::RecvError> for RpcServerError { + fn from(_: oneshot::error::RecvError) -> Self { RpcServerError::RequestCanceled } } diff --git a/comms/src/protocol/rpc/server/handle.rs b/comms/src/protocol/rpc/server/handle.rs index 89bf8dd3b9..972d91429d 100644 --- a/comms/src/protocol/rpc/server/handle.rs +++ b/comms/src/protocol/rpc/server/handle.rs @@ -21,10 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use super::RpcServerError; -use futures::{ - channel::{mpsc, oneshot}, - SinkExt, -}; +use tokio::sync::{mpsc, oneshot}; #[derive(Debug)] pub enum RpcServerRequest { diff --git a/comms/src/protocol/rpc/server/mock.rs b/comms/src/protocol/rpc/server/mock.rs index 19741a0a1a..69659ba03b 100644 --- a/comms/src/protocol/rpc/server/mock.rs +++ b/comms/src/protocol/rpc/server/mock.rs @@ -42,6 +42,7 @@ use crate::{ ProtocolNotificationTx, }, test_utils::mocks::{create_connectivity_mock, create_peer_connection_mock_pair, ConnectivityManagerMockState}, + utils, NodeIdentity, PeerConnection, PeerManager, @@ -49,7 +50,7 @@ use crate::{ }; use async_trait::async_trait; use bytes::Bytes; -use futures::{channel::mpsc, future::BoxFuture, stream, SinkExt}; +use futures::future::BoxFuture; use std::{ collections::HashMap, future, @@ -57,7 +58,7 @@ use std::{ task::{Context, Poll}, }; use tokio::{ - sync::{Mutex, RwLock}, + sync::{mpsc, Mutex, RwLock}, task, }; use tower::Service; @@ -139,9 +140,13 @@ pub trait RpcMock { { method_state.requests.write().await.push(request.into_message()); let resp = method_state.response.read().await.clone()?; - let (mut tx, rx) = mpsc::channel(resp.len()); - let mut resp = stream::iter(resp.into_iter().map(Ok).map(Ok)); - tx.send_all(&mut resp).await.unwrap(); + let (tx, rx) = mpsc::channel(resp.len()); + match utils::mpsc::send_all(&tx, resp.into_iter().map(Ok)).await { + Ok(_) => {}, + // This is done because tokio mpsc channels give the item back to you in the error, and our item doesn't + // impl Debug, so we can't use unwrap, expect etc + Err(_) => panic!("send error"), + } Ok(Streaming::new(rx)) } } @@ -234,7 +239,7 @@ where let peer_node_id = peer.node_id.clone(); let (_, our_conn_mock, peer_conn, _) = create_peer_connection_mock_pair(peer, self.our_node.to_peer()).await; - let mut protocol_tx = self.protocol_tx.clone(); + let protocol_tx = self.protocol_tx.clone(); task::spawn(async move { while let Some(substream) = our_conn_mock.next_incoming_substream().await { let proto_notif = ProtocolNotification::new( diff --git a/comms/src/protocol/rpc/server/mod.rs b/comms/src/protocol/rpc/server/mod.rs index 7615b1497e..56f4575649 100644 --- a/comms/src/protocol/rpc/server/mod.rs +++ b/comms/src/protocol/rpc/server/mod.rs @@ -53,7 +53,7 @@ use crate::{ protocol::{ProtocolEvent, ProtocolId, ProtocolNotification, ProtocolNotificationRx}, Bytes, }; -use futures::{channel::mpsc, AsyncRead, AsyncWrite, SinkExt, StreamExt}; +use futures::SinkExt; use log::*; use prost::Message; use std::{ @@ -61,7 +61,12 @@ use std::{ future::Future, time::{Duration, Instant}, }; -use tokio::time; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::mpsc, + time, +}; +use tokio_stream::StreamExt; use tower::Service; use tower_make::MakeService; @@ -198,7 +203,7 @@ pub(super) struct PeerRpcServer<TSvc, TSubstream, TCommsProvider> { service: TSvc, protocol_notifications: Option<ProtocolNotificationRx<TSubstream>>, comms_provider: TCommsProvider, - request_rx: Option<mpsc::Receiver<RpcServerRequest>>, + request_rx: mpsc::Receiver<RpcServerRequest>, } impl<TSvc, TSubstream, TCommsProvider> PeerRpcServer<TSvc, TSubstream, TCommsProvider> @@ -233,7 +238,7 @@ where service, protocol_notifications: Some(protocol_notifications), comms_provider, - request_rx: Some(request_rx), + request_rx, } } @@ -243,24 +248,19 @@ where .take() .expect("PeerRpcServer initialized without protocol_notifications"); - let mut requests = self - .request_rx - .take() - .expect("PeerRpcServer initialized without request_rx"); - loop { - futures::select! { - maybe_notif = protocol_notifs.next() => { - match maybe_notif { - Some(notif) => self.handle_protocol_notification(notif).await?, - // No more protocol notifications to come, so we're done - None => break, - } - } - - req = requests.select_next_some() => { + tokio::select! { + maybe_notif = protocol_notifs.recv() => { + match maybe_notif { + Some(notif) => self.handle_protocol_notification(notif).await?, + // No more protocol notifications to come, so we're done + None => break, + } + } + + Some(req) = self.request_rx.recv() => { self.handle_request(req).await; - }, + }, } } diff --git a/comms/src/protocol/rpc/server/router.rs b/comms/src/protocol/rpc/server/router.rs index 9d03c6535d..1d40988075 100644 --- a/comms/src/protocol/rpc/server/router.rs +++ b/comms/src/protocol/rpc/server/router.rs @@ -44,14 +44,15 @@ use crate::{ Bytes, }; use futures::{ - channel::mpsc, future::BoxFuture, task::{Context, Poll}, - AsyncRead, - AsyncWrite, FutureExt, }; use std::sync::Arc; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::mpsc, +}; use tower::Service; use tower_make::MakeService; @@ -329,7 +330,7 @@ mod test { } } - #[runtime::test_basic] + #[runtime::test] async fn find_route() { let server = RpcServer::new(); let mut router = Router::new(server, HelloService).add_service(GoodbyeService); diff --git a/comms/src/protocol/rpc/test/client_pool.rs b/comms/src/protocol/rpc/test/client_pool.rs index e1eb957d5f..d95e1d22b6 100644 --- a/comms/src/protocol/rpc/test/client_pool.rs +++ b/comms/src/protocol/rpc/test/client_pool.rs @@ -39,13 +39,13 @@ use crate::{ runtime::task, test_utils::mocks::{new_peer_connection_mock_pair, PeerConnectionMockState}, }; -use futures::{channel::mpsc, SinkExt}; use tari_shutdown::Shutdown; use tari_test_utils::{async_assert_eventually, unpack_enum}; +use tokio::sync::mpsc; async fn setup(num_concurrent_sessions: usize) -> (PeerConnection, PeerConnectionMockState, Shutdown) { let (conn1, conn1_state, conn2, conn2_state) = new_peer_connection_mock_pair().await; - let (mut notif_tx, notif_rx) = mpsc::channel(1); + let (notif_tx, notif_rx) = mpsc::channel(1); let shutdown = Shutdown::new(); let (context, _) = create_mocked_rpc_context(); @@ -148,15 +148,15 @@ mod lazy_pool { async fn it_prunes_disconnected_sessions() { let (conn, mock_state, _shutdown) = setup(2).await; let mut pool = LazyPool::<GreetingClient>::new(conn, 2, Default::default()); - let mut conn1 = pool.get_least_used_or_connect().await.unwrap(); + let mut client1 = pool.get_least_used_or_connect().await.unwrap(); assert_eq!(mock_state.num_open_substreams(), 1); - let _conn2 = pool.get_least_used_or_connect().await.unwrap(); + let _client2 = pool.get_least_used_or_connect().await.unwrap(); assert_eq!(mock_state.num_open_substreams(), 2); - conn1.close(); - drop(conn1); + client1.close().await; + drop(client1); async_assert_eventually!(mock_state.num_open_substreams(), expect = 1); assert_eq!(pool.refresh_num_active_connections(), 1); - let _conn3 = pool.get_least_used_or_connect().await.unwrap(); + let _client3 = pool.get_least_used_or_connect().await.unwrap(); assert_eq!(pool.refresh_num_active_connections(), 2); assert_eq!(mock_state.num_open_substreams(), 2); } diff --git a/comms/src/protocol/rpc/test/comms_integration.rs b/comms/src/protocol/rpc/test/comms_integration.rs index 9d23088f07..f43f921081 100644 --- a/comms/src/protocol/rpc/test/comms_integration.rs +++ b/comms/src/protocol/rpc/test/comms_integration.rs @@ -37,7 +37,7 @@ use crate::{ use tari_shutdown::Shutdown; use tari_test_utils::unpack_enum; -#[runtime::test_basic] +#[runtime::test] async fn run_service() { let node_identity1 = build_node_identity(Default::default()); let rpc_service = MockRpcService::new(); diff --git a/comms/src/protocol/rpc/test/greeting_service.rs b/comms/src/protocol/rpc/test/greeting_service.rs index 0e190473dd..445099e7c3 100644 --- a/comms/src/protocol/rpc/test/greeting_service.rs +++ b/comms/src/protocol/rpc/test/greeting_service.rs @@ -26,12 +26,16 @@ use crate::{ rpc::{NamedProtocolService, Request, Response, RpcError, RpcServerError, RpcStatus, Streaming}, ProtocolId, }, + utils, }; use core::iter; -use futures::{channel::mpsc, stream, SinkExt, StreamExt}; use std::{sync::Arc, time::Duration}; use tari_crypto::tari_utilities::hex::Hex; -use tokio::{sync::RwLock, task, time}; +use tokio::{ + sync::{mpsc, RwLock}, + task, + time, +}; #[async_trait] // #[tari_rpc(protocol_name = "/tari/greeting/1.0", server_struct = GreetingServer, client_struct = GreetingClient)] @@ -91,20 +95,11 @@ impl GreetingRpc for GreetingService { } async fn get_greetings(&self, request: Request<u32>) -> Result<Streaming<String>, RpcStatus> { - let (mut tx, rx) = mpsc::channel(1); + let (tx, rx) = mpsc::channel(1); let num = *request.message(); let greetings = self.greetings[..num as usize].to_vec(); task::spawn(async move { - let iter = greetings.into_iter().map(Ok); - let mut stream = stream::iter(iter) - // "Extra" Result::Ok is to satisfy send_all - .map(Ok); - match tx.send_all(&mut stream).await { - Ok(_) => {}, - Err(_err) => { - // Log error - }, - } + let _ = utils::mpsc::send_all(&tx, greetings.into_iter().map(Ok)).await; }); Ok(Streaming::new(rx)) @@ -118,7 +113,7 @@ impl GreetingRpc for GreetingService { } async fn streaming_error2(&self, _: Request<()>) -> Result<Streaming<String>, RpcStatus> { - let (mut tx, rx) = mpsc::channel(2); + let (tx, rx) = mpsc::channel(2); tx.send(Ok("This is ok".to_string())).await.unwrap(); tx.send(Err(RpcStatus::bad_request("This is a problem"))).await.unwrap(); @@ -151,7 +146,7 @@ impl SlowGreetingService { impl GreetingRpc for SlowGreetingService { async fn say_hello(&self, _: Request<SayHelloRequest>) -> Result<Response<SayHelloResponse>, RpcStatus> { let delay = *self.delay.read().await; - time::delay_for(delay).await; + time::sleep(delay).await; Ok(Response::new(SayHelloResponse { greeting: "took a while to load".to_string(), })) @@ -376,8 +371,8 @@ impl GreetingClient { self.inner.ping().await } - pub fn close(&mut self) { - self.inner.close(); + pub async fn close(&mut self) { + self.inner.close().await; } } diff --git a/comms/src/protocol/rpc/test/handshake.rs b/comms/src/protocol/rpc/test/handshake.rs index cdd79746f2..9a21628012 100644 --- a/comms/src/protocol/rpc/test/handshake.rs +++ b/comms/src/protocol/rpc/test/handshake.rs @@ -33,7 +33,7 @@ use crate::{ }; use tari_test_utils::unpack_enum; -#[runtime::test_basic] +#[runtime::test] async fn it_performs_the_handshake() { let (client, server) = MemorySocket::new_pair(); @@ -51,7 +51,7 @@ async fn it_performs_the_handshake() { assert!(SUPPORTED_RPC_VERSIONS.contains(&v)); } -#[runtime::test_basic] +#[runtime::test] async fn it_rejects_the_handshake() { let (client, server) = MemorySocket::new_pair(); diff --git a/comms/src/protocol/rpc/test/smoke.rs b/comms/src/protocol/rpc/test/smoke.rs index a762ac4c9c..3149e794e3 100644 --- a/comms/src/protocol/rpc/test/smoke.rs +++ b/comms/src/protocol/rpc/test/smoke.rs @@ -52,12 +52,15 @@ use crate::{ test_utils::node_identity::build_node_identity, NodeIdentity, }; -use futures::{channel::mpsc, future, future::Either, SinkExt, StreamExt}; +use futures::{future, future::Either, StreamExt}; use std::{sync::Arc, time::Duration}; use tari_crypto::tari_utilities::hex::Hex; use tari_shutdown::Shutdown; use tari_test_utils::unpack_enum; -use tokio::{sync::RwLock, task}; +use tokio::{ + sync::{mpsc, RwLock}, + task, +}; pub(super) async fn setup_service<T: GreetingRpc>( service_impl: T, @@ -85,7 +88,7 @@ pub(super) async fn setup_service<T: GreetingRpc>( futures::pin_mut!(fut); match future::select(shutdown_signal, fut).await { - Either::Left((r, _)) => r.unwrap(), + Either::Left(_) => {}, Either::Right((r, _)) => r.unwrap(), } } @@ -97,7 +100,7 @@ pub(super) async fn setup<T: GreetingRpc>( service_impl: T, num_concurrent_sessions: usize, ) -> (MemorySocket, task::JoinHandle<()>, Arc<NodeIdentity>, Shutdown) { - let (mut notif_tx, server_hnd, context, shutdown) = setup_service(service_impl, num_concurrent_sessions).await; + let (notif_tx, server_hnd, context, shutdown) = setup_service(service_impl, num_concurrent_sessions).await; let (inbound, outbound) = MemorySocket::new_pair(); let node_identity = build_node_identity(Default::default()); @@ -114,7 +117,7 @@ pub(super) async fn setup<T: GreetingRpc>( (outbound, server_hnd, node_identity, shutdown) } -#[runtime::test_basic] +#[runtime::test] async fn request_response_errors_and_streaming() { let (socket, server_hnd, node_identity, mut shutdown) = setup(GreetingService::default(), 1).await; @@ -171,7 +174,7 @@ async fn request_response_errors_and_streaming() { let pk_hex = client.get_public_key_hex().await.unwrap(); assert_eq!(pk_hex, node_identity.public_key().to_hex()); - client.close(); + client.close().await; let err = client .say_hello(SayHelloRequest { @@ -181,13 +184,20 @@ async fn request_response_errors_and_streaming() { .await .unwrap_err(); - unpack_enum!(RpcError::ClientClosed = err); + match err { + // Because of the race between closing the request stream and sending on that stream in the above call + // We can either get "this client was closed" or "the request you made was cancelled". + // If we delay some small time, we'll always get the former (but arbitrary delays cause flakiness and should be + // avoided) + RpcError::ClientClosed | RpcError::RequestCancelled => {}, + err => panic!("Unexpected error {:?}", err), + } - shutdown.trigger().unwrap(); + shutdown.trigger(); server_hnd.await.unwrap(); } -#[runtime::test_basic] +#[runtime::test] async fn concurrent_requests() { let (socket, _, _, _shutdown) = setup(GreetingService::default(), 1).await; @@ -227,7 +237,7 @@ async fn concurrent_requests() { assert_eq!(spawned2.await.unwrap(), GreetingService::DEFAULT_GREETINGS[..5]); } -#[runtime::test_basic] +#[runtime::test] async fn response_too_big() { let (socket, _, _, _shutdown) = setup(GreetingService::new(&[]), 1).await; @@ -248,7 +258,7 @@ async fn response_too_big() { let _ = client.reply_with_msg_of_size(max_size as u64).await.unwrap(); } -#[runtime::test_basic] +#[runtime::test] async fn ping_latency() { let (socket, _, _, _shutdown) = setup(GreetingService::new(&[]), 1).await; @@ -261,11 +271,11 @@ async fn ping_latency() { assert!(latency.as_secs() < 5); } -#[runtime::test_basic] +#[runtime::test] async fn server_shutdown_before_connect() { let (socket, _, _, mut shutdown) = setup(GreetingService::new(&[]), 1).await; let framed = framing::canonical(socket, 1024); - shutdown.trigger().unwrap(); + shutdown.trigger(); let err = GreetingClient::connect(framed).await.unwrap_err(); assert!(matches!( @@ -274,7 +284,7 @@ async fn server_shutdown_before_connect() { )); } -#[runtime::test_basic] +#[runtime::test] async fn timeout() { let delay = Arc::new(RwLock::new(Duration::from_secs(10))); let (socket, _, _, _shutdown) = setup(SlowGreetingService::new(delay.clone()), 1).await; @@ -298,9 +308,9 @@ async fn timeout() { assert_eq!(resp.greeting, "took a while to load"); } -#[runtime::test_basic] +#[runtime::test] async fn unknown_protocol() { - let (mut notif_tx, _, _, _shutdown) = setup_service(GreetingService::new(&[]), 1).await; + let (notif_tx, _, _, _shutdown) = setup_service(GreetingService::new(&[]), 1).await; let (inbound, socket) = MemorySocket::new_pair(); let node_identity = build_node_identity(Default::default()); @@ -324,7 +334,7 @@ async fn unknown_protocol() { )); } -#[runtime::test_basic] +#[runtime::test] async fn rejected_no_sessions_available() { let (socket, _, _, _shutdown) = setup(GreetingService::new(&[]), 0).await; let framed = framing::canonical(socket, 1024); diff --git a/comms/src/runtime.rs b/comms/src/runtime.rs index a6f7e615f7..48752c05d0 100644 --- a/comms/src/runtime.rs +++ b/comms/src/runtime.rs @@ -25,10 +25,7 @@ use tokio::runtime; // Re-export pub use tokio::{runtime::Handle, task}; -#[cfg(test)] -pub use tokio_macros::test; -#[cfg(test)] -pub use tokio_macros::test_basic; +pub use tokio::test; /// Return the current tokio executor. Panics if the tokio runtime is not started. #[inline] diff --git a/comms/src/socks/client.rs b/comms/src/socks/client.rs index b9d7dc255d..4e1383054e 100644 --- a/comms/src/socks/client.rs +++ b/comms/src/socks/client.rs @@ -23,7 +23,6 @@ // Acknowledgement to @sticnarf for tokio-socks on which this code is based use super::error::SocksError; use data_encoding::BASE32; -use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use multiaddr::{Multiaddr, Protocol}; use std::{ borrow::Cow, @@ -31,6 +30,7 @@ use std::{ fmt::Formatter, net::{Ipv4Addr, Ipv6Addr}, }; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; pub type Result<T> = std::result::Result<T, SocksError>; diff --git a/comms/src/test_utils/mocks/connection_manager.rs b/comms/src/test_utils/mocks/connection_manager.rs index 1637074ff7..28d95a3e55 100644 --- a/comms/src/test_utils/mocks/connection_manager.rs +++ b/comms/src/test_utils/mocks/connection_manager.rs @@ -31,7 +31,6 @@ use crate::{ peer_manager::NodeId, runtime::task, }; -use futures::{channel::mpsc, lock::Mutex, stream::Fuse, StreamExt}; use std::{ collections::HashMap, sync::{ @@ -39,14 +38,14 @@ use std::{ Arc, }, }; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc, Mutex}; pub fn create_connection_manager_mock() -> (ConnectionManagerRequester, ConnectionManagerMock) { let (tx, rx) = mpsc::channel(10); let (event_tx, _) = broadcast::channel(10); ( ConnectionManagerRequester::new(tx, event_tx.clone()), - ConnectionManagerMock::new(rx.fuse(), event_tx), + ConnectionManagerMock::new(rx, event_tx), ) } @@ -97,13 +96,13 @@ impl ConnectionManagerMockState { } pub struct ConnectionManagerMock { - receiver: Fuse<mpsc::Receiver<ConnectionManagerRequest>>, + receiver: mpsc::Receiver<ConnectionManagerRequest>, state: ConnectionManagerMockState, } impl ConnectionManagerMock { pub fn new( - receiver: Fuse<mpsc::Receiver<ConnectionManagerRequest>>, + receiver: mpsc::Receiver<ConnectionManagerRequest>, event_tx: broadcast::Sender<Arc<ConnectionManagerEvent>>, ) -> Self { Self { @@ -121,7 +120,7 @@ impl ConnectionManagerMock { } pub async fn run(mut self) { - while let Some(req) = self.receiver.next().await { + while let Some(req) = self.receiver.recv().await { self.handle_request(req).await; } } diff --git a/comms/src/test_utils/mocks/connectivity_manager.rs b/comms/src/test_utils/mocks/connectivity_manager.rs index c3479006d2..ddb1439aab 100644 --- a/comms/src/test_utils/mocks/connectivity_manager.rs +++ b/comms/src/test_utils/mocks/connectivity_manager.rs @@ -22,32 +22,36 @@ use crate::{ connection_manager::{ConnectionManagerError, PeerConnection}, - connectivity::{ConnectivityEvent, ConnectivityRequest, ConnectivityRequester, ConnectivityStatus}, + connectivity::{ + ConnectivityEvent, + ConnectivityEventTx, + ConnectivityRequest, + ConnectivityRequester, + ConnectivityStatus, + }, peer_manager::NodeId, runtime::task, }; -use futures::{ - channel::{mpsc, oneshot}, - lock::Mutex, - stream::Fuse, - StreamExt, -}; +use futures::lock::Mutex; use std::{collections::HashMap, sync::Arc, time::Duration}; -use tokio::{sync::broadcast, time}; +use tokio::{ + sync::{broadcast, mpsc, oneshot}, + time, +}; pub fn create_connectivity_mock() -> (ConnectivityRequester, ConnectivityManagerMock) { let (tx, rx) = mpsc::channel(10); let (event_tx, _) = broadcast::channel(10); ( ConnectivityRequester::new(tx, event_tx.clone()), - ConnectivityManagerMock::new(rx.fuse(), event_tx), + ConnectivityManagerMock::new(rx, event_tx), ) } #[derive(Debug, Clone)] pub struct ConnectivityManagerMockState { inner: Arc<Mutex<State>>, - event_tx: broadcast::Sender<Arc<ConnectivityEvent>>, + event_tx: ConnectivityEventTx, } #[derive(Debug, Default)] @@ -61,7 +65,7 @@ struct State { } impl ConnectivityManagerMockState { - pub fn new(event_tx: broadcast::Sender<Arc<ConnectivityEvent>>) -> Self { + pub fn new(event_tx: ConnectivityEventTx) -> Self { Self { event_tx, inner: Default::default(), @@ -132,7 +136,7 @@ impl ConnectivityManagerMockState { count, self.call_count().await ); - time::delay_for(Duration::from_millis(100)).await; + time::sleep(Duration::from_millis(100)).await; } } @@ -156,9 +160,8 @@ impl ConnectivityManagerMockState { .await } - #[allow(dead_code)] pub fn publish_event(&self, event: ConnectivityEvent) { - self.event_tx.send(Arc::new(event)).unwrap(); + self.event_tx.send(event).unwrap(); } pub(self) async fn with_state<F, R>(&self, f: F) -> R @@ -169,15 +172,12 @@ impl ConnectivityManagerMockState { } pub struct ConnectivityManagerMock { - receiver: Fuse<mpsc::Receiver<ConnectivityRequest>>, + receiver: mpsc::Receiver<ConnectivityRequest>, state: ConnectivityManagerMockState, } impl ConnectivityManagerMock { - pub fn new( - receiver: Fuse<mpsc::Receiver<ConnectivityRequest>>, - event_tx: broadcast::Sender<Arc<ConnectivityEvent>>, - ) -> Self { + pub fn new(receiver: mpsc::Receiver<ConnectivityRequest>, event_tx: ConnectivityEventTx) -> Self { Self { receiver, state: ConnectivityManagerMockState::new(event_tx), @@ -195,7 +195,7 @@ impl ConnectivityManagerMock { } pub async fn run(mut self) { - while let Some(req) = self.receiver.next().await { + while let Some(req) = self.receiver.recv().await { self.handle_request(req).await; } } diff --git a/comms/src/test_utils/mocks/peer_connection.rs b/comms/src/test_utils/mocks/peer_connection.rs index 00756cda1c..cc3d54165f 100644 --- a/comms/src/test_utils/mocks/peer_connection.rs +++ b/comms/src/test_utils/mocks/peer_connection.rs @@ -34,15 +34,18 @@ use crate::{ peer_manager::{NodeId, Peer, PeerFeatures}, test_utils::{node_identity::build_node_identity, transport}, }; -use futures::{channel::mpsc, lock::Mutex, StreamExt}; use std::sync::{ atomic::{AtomicUsize, Ordering}, Arc, }; -use tokio::runtime::Handle; +use tokio::{ + runtime::Handle, + sync::{mpsc, Mutex}, +}; +use tokio_stream::StreamExt; pub fn create_dummy_peer_connection(node_id: NodeId) -> (PeerConnection, mpsc::Receiver<PeerConnectionRequest>) { - let (tx, rx) = mpsc::channel(0); + let (tx, rx) = mpsc::channel(1); ( PeerConnection::new( 1, @@ -114,7 +117,7 @@ pub async fn new_peer_connection_mock_pair() -> ( create_peer_connection_mock_pair(peer1, peer2).await } -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct PeerConnectionMockState { call_count: Arc<AtomicUsize>, mux_control: Arc<Mutex<multiplexing::Control>>, @@ -181,7 +184,7 @@ impl PeerConnectionMock { } pub async fn run(mut self) { - while let Some(req) = self.receiver.next().await { + while let Some(req) = self.receiver.recv().await { self.handle_request(req).await; } } diff --git a/comms/src/test_utils/test_node.rs b/comms/src/test_utils/test_node.rs index a150a765f8..3e5d7229a0 100644 --- a/comms/src/test_utils/test_node.rs +++ b/comms/src/test_utils/test_node.rs @@ -29,12 +29,14 @@ use crate::{ protocol::Protocols, transports::Transport, }; -use futures::{channel::mpsc, AsyncRead, AsyncWrite}; use rand::rngs::OsRng; use std::{sync::Arc, time::Duration}; use tari_shutdown::ShutdownSignal; use tari_storage::HashmapDatabase; -use tokio::sync::broadcast; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{broadcast, mpsc}, +}; #[derive(Clone, Debug)] pub struct TestNodeConfig { diff --git a/comms/src/tor/control_client/client.rs b/comms/src/tor/control_client/client.rs index 5d14c770b7..573bce60c0 100644 --- a/comms/src/tor/control_client/client.rs +++ b/comms/src/tor/control_client/client.rs @@ -34,10 +34,13 @@ use crate::{ tor::control_client::{event::TorControlEvent, monitor::spawn_monitor}, transports::{TcpTransport, Transport}, }; -use futures::{channel::mpsc, AsyncRead, AsyncWrite, SinkExt, StreamExt}; use log::*; use std::{borrow::Cow, fmt, fmt::Display, num::NonZeroU16}; -use tokio::sync::broadcast; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{broadcast, mpsc}, +}; +use tokio_stream::wrappers::BroadcastStream; /// Client for the Tor control port. /// @@ -80,8 +83,8 @@ impl TorControlPortClient { &self.event_tx } - pub fn get_event_stream(&self) -> broadcast::Receiver<TorControlEvent> { - self.event_tx.subscribe() + pub fn get_event_stream(&self) -> BroadcastStream<TorControlEvent> { + BroadcastStream::new(self.event_tx.subscribe()) } /// Authenticate with the tor control port @@ -232,8 +235,7 @@ impl TorControlPortClient { } async fn receive_line(&mut self) -> Result<ResponseLine, TorClientError> { - let line = self.output_stream.next().await.ok_or(TorClientError::UnexpectedEof)?; - + let line = self.output_stream.recv().await.ok_or(TorClientError::UnexpectedEof)?; Ok(line) } } @@ -273,9 +275,11 @@ mod test { runtime, tor::control_client::{test_server, test_server::canned_responses, types::PrivateKey}, }; - use futures::{future, AsyncWriteExt}; + use futures::future; use std::net::SocketAddr; use tari_test_utils::unpack_enum; + use tokio::io::AsyncWriteExt; + use tokio_stream::StreamExt; async fn setup_test() -> (TorControlPortClient, test_server::State) { let (_, mock_state, socket) = test_server::spawn().await; @@ -298,7 +302,7 @@ mod test { let _out_sock = result_out.unwrap(); let (mut in_sock, _) = result_in.unwrap().unwrap(); in_sock.write(b"test123").await.unwrap(); - in_sock.close().await.unwrap(); + in_sock.shutdown().await.unwrap(); } #[runtime::test] diff --git a/comms/src/tor/control_client/monitor.rs b/comms/src/tor/control_client/monitor.rs index a5191b466d..72eb6b88ef 100644 --- a/comms/src/tor/control_client/monitor.rs +++ b/comms/src/tor/control_client/monitor.rs @@ -21,11 +21,14 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use super::{event::TorControlEvent, parsers, response::ResponseLine, LOG_TARGET}; -use crate::{compat::IoCompat, runtime::task}; -use futures::{channel::mpsc, future, future::Either, AsyncRead, AsyncWrite, SinkExt, Stream, StreamExt}; +use crate::runtime::task; +use futures::{future::Either, SinkExt, Stream, StreamExt}; use log::*; use std::fmt; -use tokio::sync::broadcast; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{broadcast, mpsc}, +}; use tokio_util::codec::{Framed, LinesCodec}; pub fn spawn_monitor<TSocket>( @@ -36,16 +39,19 @@ pub fn spawn_monitor<TSocket>( where TSocket: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - let (mut responses_tx, responses_rx) = mpsc::channel(100); + let (responses_tx, responses_rx) = mpsc::channel(100); task::spawn(async move { - let framed = Framed::new(IoCompat::new(socket), LinesCodec::new()); + let framed = Framed::new(socket, LinesCodec::new()); let (mut sink, mut stream) = framed.split(); loop { - let either = future::select(cmd_rx.next(), stream.next()).await; + let either = tokio::select! { + next = cmd_rx.recv() => Either::Left(next), + next = stream.next() => Either::Right(next), + }; match either { // Received a command to send to the control server - Either::Left((Some(line), _)) => { + Either::Left(Some(line)) => { trace!(target: LOG_TARGET, "Writing command of length '{}'", line.len()); if let Err(err) = sink.send(line).await { error!( @@ -56,7 +62,7 @@ where } }, // Command stream ended - Either::Left((None, _)) => { + Either::Left(None) => { debug!( target: LOG_TARGET, "Tor control server command receiver closed. Monitor is exiting." @@ -65,7 +71,7 @@ where }, // Received a line from the control server - Either::Right((Some(Ok(line)), _)) => { + Either::Right(Some(Ok(line))) => { trace!(target: LOG_TARGET, "Read line of length '{}'", line.len()); match parsers::response_line(&line) { Ok(mut line) => { @@ -95,7 +101,7 @@ where }, // Error receiving a line from the control server - Either::Right((Some(Err(err)), _)) => { + Either::Right(Some(Err(err))) => { error!( target: LOG_TARGET, "Line framing error when reading from tor control server: '{:?}'. Monitor is exiting.", err @@ -103,7 +109,7 @@ where break; }, // The control server disconnected - Either::Right((None, _)) => { + Either::Right(None) => { cmd_rx.close(); debug!( target: LOG_TARGET, diff --git a/comms/src/tor/control_client/test_server.rs b/comms/src/tor/control_client/test_server.rs index 5a5e1b3b7c..1741cfc721 100644 --- a/comms/src/tor/control_client/test_server.rs +++ b/comms/src/tor/control_client/test_server.rs @@ -20,13 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{ - compat::IoCompat, - memsocket::MemorySocket, - multiaddr::Multiaddr, - runtime, - test_utils::transport::build_connected_sockets, -}; +use crate::{memsocket::MemorySocket, multiaddr::Multiaddr, runtime, test_utils::transport::build_connected_sockets}; use futures::{lock::Mutex, stream, SinkExt, StreamExt}; use std::sync::Arc; use tokio_util::codec::{Framed, LinesCodec}; @@ -82,7 +76,7 @@ impl TorControlPortTestServer { } pub async fn run(self) { - let mut framed = Framed::new(IoCompat::new(self.socket), LinesCodec::new()); + let mut framed = Framed::new(self.socket, LinesCodec::new()); let state = self.state; while let Some(msg) = framed.next().await { state.request_lines.lock().await.push(msg.unwrap()); diff --git a/comms/src/tor/hidden_service/controller.rs b/comms/src/tor/hidden_service/controller.rs index e19818df58..74b89808fb 100644 --- a/comms/src/tor/hidden_service/controller.rs +++ b/comms/src/tor/hidden_service/controller.rs @@ -214,7 +214,7 @@ impl HiddenServiceController { "Failed to reestablish connection with tor control server because '{:?}'", err ); warn!(target: LOG_TARGET, "Will attempt again in 5 seconds..."); - time::delay_for(Duration::from_secs(5)).await; + time::sleep(Duration::from_secs(5)).await; }, Either::Right(_) => { diff --git a/comms/src/transports/memory.rs b/comms/src/transports/memory.rs index 074e728274..8e3a1d7a91 100644 --- a/comms/src/transports/memory.rs +++ b/comms/src/transports/memory.rs @@ -129,7 +129,8 @@ impl Stream for Listener { mod test { use super::*; use crate::runtime; - use futures::{future::join, stream::StreamExt, AsyncReadExt, AsyncWriteExt}; + use futures::{future::join, stream::StreamExt}; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; #[runtime::test] async fn simple_listen_and_dial() -> Result<(), ::std::io::Error> { diff --git a/comms/src/transports/mod.rs b/comms/src/transports/mod.rs index 0c04e50a72..75b16db975 100644 --- a/comms/src/transports/mod.rs +++ b/comms/src/transports/mod.rs @@ -24,8 +24,8 @@ // Copyright (c) The Libra Core Contributors // SPDX-License-Identifier: Apache-2.0 -use futures::Stream; use multiaddr::Multiaddr; +use tokio_stream::Stream; mod dns; mod helpers; @@ -37,7 +37,7 @@ mod socks; pub use socks::{SocksConfig, SocksTransport}; mod tcp; -pub use tcp::{TcpSocket, TcpTransport}; +pub use tcp::TcpTransport; mod tcp_with_tor; pub use tcp_with_tor::TcpWithTorTransport; diff --git a/comms/src/transports/socks.rs b/comms/src/transports/socks.rs index 2027f34561..7dc87ef0e4 100644 --- a/comms/src/transports/socks.rs +++ b/comms/src/transports/socks.rs @@ -24,12 +24,13 @@ use crate::{ multiaddr::Multiaddr, socks, socks::Socks5Client, - transports::{dns::SystemDnsResolver, tcp::TcpTransport, TcpSocket, Transport}, + transports::{dns::SystemDnsResolver, tcp::TcpTransport, Transport}, }; -use std::{io, time::Duration}; +use std::io; +use tokio::net::TcpStream; -/// SO_KEEPALIVE setting for the SOCKS TCP connection -const SOCKS_SO_KEEPALIVE: Duration = Duration::from_millis(1500); +// /// SO_KEEPALIVE setting for the SOCKS TCP connection +// const SOCKS_SO_KEEPALIVE: Duration = Duration::from_millis(1500); #[derive(Clone, Debug)] pub struct SocksConfig { @@ -57,7 +58,7 @@ impl SocksTransport { pub fn create_socks_tcp_transport() -> TcpTransport { let mut tcp_transport = TcpTransport::new(); tcp_transport.set_nodelay(true); - tcp_transport.set_keepalive(Some(SOCKS_SO_KEEPALIVE)); + // .set_keepalive(Some(SOCKS_SO_KEEPALIVE)) tcp_transport.set_dns_resolver(SystemDnsResolver); tcp_transport } @@ -66,7 +67,7 @@ impl SocksTransport { tcp: TcpTransport, socks_config: SocksConfig, dest_addr: Multiaddr, - ) -> io::Result<TcpSocket> { + ) -> io::Result<TcpStream> { // Create a new connection to the SOCKS proxy let socks_conn = tcp.dial(socks_config.proxy_address).await?; let mut client = Socks5Client::new(socks_conn); diff --git a/comms/src/transports/tcp.rs b/comms/src/transports/tcp.rs index 8112cca203..6b47e7c357 100644 --- a/comms/src/transports/tcp.rs +++ b/comms/src/transports/tcp.rs @@ -25,44 +25,42 @@ use crate::{ transports::dns::{DnsResolverRef, SystemDnsResolver}, utils::multiaddr::socketaddr_to_multiaddr, }; -use futures::{io::Error, ready, AsyncRead, AsyncWrite, Future, FutureExt, Stream}; +use futures::{ready, FutureExt}; use multiaddr::Multiaddr; use std::{ + future::Future, io, pin::Pin, sync::Arc, task::{Context, Poll}, - time::Duration, -}; -use tokio::{ - io::{AsyncRead as TokioAsyncRead, AsyncWrite as TokioAsyncWrite}, - net::{TcpListener, TcpStream}, }; +use tokio::net::{TcpListener, TcpStream}; +use tokio_stream::Stream; /// Transport implementation for TCP #[derive(Clone)] pub struct TcpTransport { - recv_buffer_size: Option<usize>, - send_buffer_size: Option<usize>, + // recv_buffer_size: Option<usize>, + // send_buffer_size: Option<usize>, ttl: Option<u32>, - #[allow(clippy::option_option)] - keepalive: Option<Option<Duration>>, + // #[allow(clippy::option_option)] + // keepalive: Option<Option<Duration>>, nodelay: Option<bool>, dns_resolver: DnsResolverRef, } impl TcpTransport { - #[doc("Sets `SO_RCVBUF` i.e the size of the receive buffer.")] - setter_mut!(set_recv_buffer_size, recv_buffer_size, Option<usize>); - - #[doc("Sets `SO_SNDBUF` i.e. the size of the send buffer.")] - setter_mut!(set_send_buffer_size, send_buffer_size, Option<usize>); + // #[doc("Sets `SO_RCVBUF` i.e the size of the receive buffer.")] + // setter_mut!(set_recv_buffer_size, recv_buffer_size, Option<usize>); + // + // #[doc("Sets `SO_SNDBUF` i.e. the size of the send buffer.")] + // setter_mut!(set_send_buffer_size, send_buffer_size, Option<usize>); #[doc("Sets `IP_TTL` i.e. the TTL of packets sent from this socket.")] setter_mut!(set_ttl, ttl, Option<u32>); - #[doc("Sets `SO_KEEPALIVE` i.e. the interval to send keepalive probes, or None to disable.")] - setter_mut!(set_keepalive, keepalive, Option<Option<Duration>>); + // #[doc("Sets `SO_KEEPALIVE` i.e. the interval to send keepalive probes, or None to disable.")] + // setter_mut!(set_keepalive, keepalive, Option<Option<Duration>>); #[doc("Sets `TCP_NODELAY` i.e disable Nagle's algorithm if set to true.")] setter_mut!(set_nodelay, nodelay, Option<bool>); @@ -81,9 +79,10 @@ impl TcpTransport { /// Apply socket options to `TcpStream`. fn configure(&self, socket: &TcpStream) -> io::Result<()> { - if let Some(keepalive) = self.keepalive { - socket.set_keepalive(keepalive)?; - } + // https://github.com/rust-lang/rust/issues/69774 + // if let Some(keepalive) = self.keepalive { + // socket.set_keepalive(keepalive)?; + // } if let Some(ttl) = self.ttl { socket.set_ttl(ttl)?; @@ -93,13 +92,13 @@ impl TcpTransport { socket.set_nodelay(nodelay)?; } - if let Some(recv_buffer_size) = self.recv_buffer_size { - socket.set_recv_buffer_size(recv_buffer_size)?; - } - - if let Some(send_buffer_size) = self.send_buffer_size { - socket.set_send_buffer_size(send_buffer_size)?; - } + // if let Some(recv_buffer_size) = self.recv_buffer_size { + // socket.set_recv_buffer_size(recv_buffer_size)?; + // } + // + // if let Some(send_buffer_size) = self.send_buffer_size { + // socket.set_send_buffer_size(send_buffer_size)?; + // } Ok(()) } @@ -108,10 +107,10 @@ impl TcpTransport { impl Default for TcpTransport { fn default() -> Self { Self { - recv_buffer_size: None, - send_buffer_size: None, + // recv_buffer_size: None, + // send_buffer_size: None, ttl: None, - keepalive: None, + // keepalive: None, nodelay: None, dns_resolver: Arc::new(SystemDnsResolver), } @@ -122,7 +121,7 @@ impl Default for TcpTransport { impl Transport for TcpTransport { type Error = io::Error; type Listener = TcpInbound; - type Output = TcpSocket; + type Output = TcpStream; async fn listen(&self, addr: Multiaddr) -> Result<(Self::Listener, Multiaddr), Self::Error> { let socket_addr = self @@ -161,12 +160,12 @@ impl<F> TcpOutbound<F> { impl<F> Future for TcpOutbound<F> where F: Future<Output = io::Result<TcpStream>> + Unpin { - type Output = io::Result<TcpSocket>; + type Output = io::Result<TcpStream>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let socket = ready!(Pin::new(&mut self.future).poll(cx))?; - self.config.configure(&socket)?; - Poll::Ready(Ok(TcpSocket::new(socket))) + let stream = ready!(Pin::new(&mut self.future).poll(cx))?; + self.config.configure(&stream)?; + Poll::Ready(Ok(stream)) } } @@ -184,52 +183,14 @@ impl TcpInbound { } impl Stream for TcpInbound { - type Item = io::Result<(TcpSocket, Multiaddr)>; + type Item = io::Result<(TcpStream, Multiaddr)>; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { let (socket, addr) = ready!(self.listener.poll_accept(cx))?; // Configure each socket self.config.configure(&socket)?; let peer_addr = socketaddr_to_multiaddr(&addr); - Poll::Ready(Some(Ok((TcpSocket::new(socket), peer_addr)))) - } -} - -/// TcpSocket is a wrapper struct for tokio `TcpStream` and implements -/// `futures-rs` AsyncRead/Write -pub struct TcpSocket { - inner: TcpStream, -} - -impl TcpSocket { - pub fn new(stream: TcpStream) -> Self { - Self { inner: stream } - } -} - -impl AsyncWrite for TcpSocket { - fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> { - Pin::new(&mut self.inner).poll_write(cx, buf) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> { - Pin::new(&mut self.inner).poll_flush(cx) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> { - Pin::new(&mut self.inner).poll_shutdown(cx) - } -} - -impl AsyncRead for TcpSocket { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<Result<usize, Error>> { - Pin::new(&mut self.inner).poll_read(cx, buf) - } -} - -impl From<TcpStream> for TcpSocket { - fn from(stream: TcpStream) -> Self { - Self { inner: stream } + Poll::Ready(Some(Ok((socket, peer_addr)))) } } @@ -240,16 +201,15 @@ mod test { #[test] fn configure() { let mut tcp = TcpTransport::new(); - tcp.set_send_buffer_size(123) - .set_recv_buffer_size(456) - .set_nodelay(true) - .set_ttl(789) - .set_keepalive(Some(Duration::from_millis(100))); - - assert_eq!(tcp.send_buffer_size, Some(123)); - assert_eq!(tcp.recv_buffer_size, Some(456)); + // tcp.set_send_buffer_size(123) + // .set_recv_buffer_size(456) + tcp.set_nodelay(true).set_ttl(789); + // .set_keepalive(Some(Duration::from_millis(100))); + + // assert_eq!(tcp.send_buffer_size, Some(123)); + // assert_eq!(tcp.recv_buffer_size, Some(456)); assert_eq!(tcp.nodelay, Some(true)); assert_eq!(tcp.ttl, Some(789)); - assert_eq!(tcp.keepalive, Some(Some(Duration::from_millis(100)))); + // assert_eq!(tcp.keepalive, Some(Some(Duration::from_millis(100)))); } } diff --git a/comms/src/transports/tcp_with_tor.rs b/comms/src/transports/tcp_with_tor.rs index de54e17bb5..be800d9bfb 100644 --- a/comms/src/transports/tcp_with_tor.rs +++ b/comms/src/transports/tcp_with_tor.rs @@ -21,16 +21,10 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use super::Transport; -use crate::transports::{ - dns::TorDnsResolver, - helpers::is_onion_address, - SocksConfig, - SocksTransport, - TcpSocket, - TcpTransport, -}; +use crate::transports::{dns::TorDnsResolver, helpers::is_onion_address, SocksConfig, SocksTransport, TcpTransport}; use multiaddr::Multiaddr; use std::io; +use tokio::net::TcpStream; /// Transport implementation for TCP with Tor support #[derive(Clone, Default)] @@ -69,7 +63,7 @@ impl TcpWithTorTransport { impl Transport for TcpWithTorTransport { type Error = io::Error; type Listener = <TcpTransport as Transport>::Listener; - type Output = TcpSocket; + type Output = TcpStream; async fn listen(&self, addr: Multiaddr) -> Result<(Self::Listener, Multiaddr), Self::Error> { self.tcp_transport.listen(addr).await diff --git a/comms/src/utils/mod.rs b/comms/src/utils/mod.rs index 697a758c1f..e543e644bb 100644 --- a/comms/src/utils/mod.rs +++ b/comms/src/utils/mod.rs @@ -22,5 +22,6 @@ pub mod cidr; pub mod datetime; +pub mod mpsc; pub mod multiaddr; pub mod signature; diff --git a/comms/src/utils/mpsc.rs b/comms/src/utils/mpsc.rs new file mode 100644 index 0000000000..8ded39967f --- /dev/null +++ b/comms/src/utils/mpsc.rs @@ -0,0 +1,33 @@ +// Copyright 2021, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use tokio::sync::mpsc; + +pub async fn send_all<T, I: IntoIterator<Item = T>>( + sender: &mpsc::Sender<T>, + iter: I, +) -> Result<(), mpsc::error::SendError<T>> { + for item in iter { + sender.send(item).await?; + } + Ok(()) +} diff --git a/comms/tests/greeting_service.rs b/comms/tests/greeting_service.rs index c13ae842e4..e70b6e16ff 100644 --- a/comms/tests/greeting_service.rs +++ b/comms/tests/greeting_service.rs @@ -23,14 +23,14 @@ #![cfg(feature = "rpc")] use core::iter; -use futures::{channel::mpsc, stream, SinkExt, StreamExt}; use std::{cmp, time::Duration}; use tari_comms::{ async_trait, protocol::rpc::{Request, Response, RpcStatus, Streaming}, + utils, }; use tari_comms_rpc_macros::tari_rpc; -use tokio::{task, time}; +use tokio::{sync::mpsc, task, time}; #[tari_rpc(protocol_name = b"t/greeting/1", server_struct = GreetingServer, client_struct = GreetingClient)] pub trait GreetingRpc: Send + Sync + 'static { @@ -85,15 +85,9 @@ impl GreetingRpc for GreetingService { async fn get_greetings(&self, request: Request<u32>) -> Result<Streaming<String>, RpcStatus> { let num = *request.message(); - let (mut tx, rx) = mpsc::channel(num as usize); + let (tx, rx) = mpsc::channel(num as usize); let greetings = self.greetings[..cmp::min(num as usize + 1, self.greetings.len())].to_vec(); - task::spawn(async move { - let iter = greetings.into_iter().map(Ok); - let mut stream = stream::iter(iter) - // "Extra" Result::Ok is to satisfy send_all - .map(Ok); - tx.send_all(&mut stream).await.unwrap(); - }); + task::spawn(async move { utils::mpsc::send_all(&tx, greetings.into_iter().map(Ok)).await }); Ok(Streaming::new(rx)) } @@ -113,7 +107,7 @@ impl GreetingRpc for GreetingService { item_size, num_items, } = request.into_message(); - let (mut tx, rx) = mpsc::channel(10); + let (tx, rx) = mpsc::channel(10); let t = std::time::Instant::now(); task::spawn(async move { let item = iter::repeat(0u8).take(item_size as usize).collect::<Vec<_>>(); @@ -136,7 +130,7 @@ impl GreetingRpc for GreetingService { } async fn slow_response(&self, request: Request<u64>) -> Result<Response<()>, RpcStatus> { - time::delay_for(Duration::from_secs(request.into_message())).await; + time::sleep(Duration::from_secs(request.into_message())).await; Ok(Response::new(())) } } diff --git a/comms/tests/rpc_stress.rs b/comms/tests/rpc_stress.rs index 0376a7dddc..933e158398 100644 --- a/comms/tests/rpc_stress.rs +++ b/comms/tests/rpc_stress.rs @@ -155,6 +155,7 @@ async fn run_stress_test(test_params: Params) { } future::join_all(tasks).await.into_iter().for_each(Result::unwrap); + log::info!("Stress test took {:.2?}", time.elapsed()); } @@ -259,7 +260,7 @@ async fn high_contention_high_concurrency() { .await; } -#[tokio_macros::test] +#[tokio::test] async fn run() { // let _ = env_logger::try_init(); log_timing("quick", quick()).await; diff --git a/comms/tests/substream_stress.rs b/comms/tests/substream_stress.rs index cbae6e8e52..a72eff03f3 100644 --- a/comms/tests/substream_stress.rs +++ b/comms/tests/substream_stress.rs @@ -23,7 +23,7 @@ mod helpers; use helpers::create_comms; -use futures::{channel::mpsc, future, SinkExt, StreamExt}; +use futures::{future, SinkExt, StreamExt}; use std::time::Duration; use tari_comms::{ framing, @@ -35,7 +35,7 @@ use tari_comms::{ }; use tari_shutdown::{Shutdown, ShutdownSignal}; use tari_test_utils::unpack_enum; -use tokio::{task, time::Instant}; +use tokio::{sync::mpsc, task, time::Instant}; const PROTOCOL_NAME: &[u8] = b"test/dummy/protocol"; @@ -79,7 +79,7 @@ async fn run_stress_test(num_substreams: usize, num_iterations: usize, payload_s task::spawn({ let sample = sample.clone(); async move { - while let Some(event) = notif_rx.next().await { + while let Some(event) = notif_rx.recv().await { unpack_enum!(ProtocolEvent::NewInboundSubstream(_n, remote_substream) = event.event); let mut remote_substream = framing::canonical(remote_substream, frame_size); @@ -150,7 +150,7 @@ async fn run_stress_test(num_substreams: usize, num_iterations: usize, payload_s println!("avg t = {}ms", avg); } -#[tokio_macros::test] +#[tokio::test] async fn many_at_frame_limit() { const NUM_SUBSTREAMS: usize = 20; const NUM_ITERATIONS_PER_STREAM: usize = 100; diff --git a/infrastructure/shutdown/Cargo.toml b/infrastructure/shutdown/Cargo.toml index 1102ec037a..176bace6e2 100644 --- a/infrastructure/shutdown/Cargo.toml +++ b/infrastructure/shutdown/Cargo.toml @@ -12,7 +12,7 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -futures = "^0.3.1" +futures = "^0.3" [dev-dependencies] -tokio = {version="^0.2", features=["rt-core"]} +tokio = {version="1", default-features = false, features = ["rt", "macros"]} diff --git a/infrastructure/shutdown/src/lib.rs b/infrastructure/shutdown/src/lib.rs index f0054bd843..e79f45842f 100644 --- a/infrastructure/shutdown/src/lib.rs +++ b/infrastructure/shutdown/src/lib.rs @@ -26,14 +26,16 @@ #![deny(unused_must_use)] #![deny(unreachable_patterns)] #![deny(unknown_lints)] -use futures::{ - channel::{oneshot, oneshot::Canceled}, - future::{Fuse, FusedFuture, Shared}, + +pub mod oneshot_trigger; + +use crate::oneshot_trigger::OneshotSignal; +use futures::future::FusedFuture; +use std::{ + future::Future, + pin::Pin, task::{Context, Poll}, - Future, - FutureExt, }; -use std::pin::Pin; /// Trigger for shutdowns. /// @@ -42,71 +44,63 @@ use std::pin::Pin; /// /// _Note_: This will trigger when dropped, so the `Shutdown` instance should be held as /// long as required by the application. -pub struct Shutdown { - trigger: Option<oneshot::Sender<()>>, - signal: ShutdownSignal, - on_triggered: Option<Box<dyn FnOnce() + Send + Sync>>, -} - +pub struct Shutdown(oneshot_trigger::OneshotTrigger<()>); impl Shutdown { - /// Create a new Shutdown pub fn new() -> Self { - let (tx, rx) = oneshot::channel(); - Self { - trigger: Some(tx), - signal: rx.fuse().shared(), - on_triggered: None, - } + Self(oneshot_trigger::OneshotTrigger::new()) } - /// Set the on_triggered callback - pub fn on_triggered<F>(&mut self, on_trigger: F) -> &mut Self - where F: FnOnce() + Send + Sync + 'static { - self.on_triggered = Some(Box::new(on_trigger)); - self + pub fn trigger(&mut self) { + self.0.broadcast(()); } - /// Convert this into a ShutdownSignal without consuming the - /// struct. - pub fn to_signal(&self) -> ShutdownSignal { - self.signal.clone() + pub fn is_triggered(&self) -> bool { + self.0.is_used() } - /// Trigger any listening signals - pub fn trigger(&mut self) -> Result<(), ShutdownError> { - match self.trigger.take() { - Some(trigger) => { - trigger.send(()).map_err(|_| ShutdownError)?; + pub fn to_signal(&self) -> ShutdownSignal { + self.0.to_signal().into() + } +} - if let Some(on_triggered) = self.on_triggered.take() { - on_triggered(); - } +/// Receiver end of a shutdown signal. Once received the consumer should shut down. +#[derive(Debug, Clone)] +pub struct ShutdownSignal(oneshot_trigger::OneshotSignal<()>); - Ok(()) - }, - None => Ok(()), - } +impl ShutdownSignal { + pub fn is_triggered(&self) -> bool { + self.0.is_terminated() } - pub fn is_triggered(&self) -> bool { - self.trigger.is_none() + /// Wait for the shutdown signal to trigger. + pub fn wait(&mut self) -> &mut Self { + self } } -impl Drop for Shutdown { - fn drop(&mut self) { - let _ = self.trigger(); +impl Future for ShutdownSignal { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + match Pin::new(&mut self.0).poll(cx) { + // Whether `trigger()` was called Some(()), or the Shutdown dropped (None) we want to resolve this future + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Poll::Pending, + } } } -impl Default for Shutdown { - fn default() -> Self { - Self::new() +impl FusedFuture for ShutdownSignal { + fn is_terminated(&self) -> bool { + self.0.is_terminated() } } -/// Receiver end of a shutdown signal. Once received the consumer should shut down. -pub type ShutdownSignal = Shared<Fuse<oneshot::Receiver<()>>>; +impl From<oneshot_trigger::OneshotSignal<()>> for ShutdownSignal { + fn from(inner: OneshotSignal<()>) -> Self { + Self(inner) + } +} #[derive(Debug, Clone, Default)] pub struct OptionalShutdownSignal(Option<ShutdownSignal>); @@ -137,11 +131,11 @@ impl OptionalShutdownSignal { } impl Future for OptionalShutdownSignal { - type Output = Result<(), Canceled>; + type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { match self.0.as_mut() { - Some(inner) => inner.poll_unpin(cx), + Some(inner) => Pin::new(inner).poll(cx), None => Poll::Pending, } } @@ -165,73 +159,50 @@ impl FusedFuture for OptionalShutdownSignal { } } -#[derive(Debug)] -pub struct ShutdownError; - #[cfg(test)] mod test { use super::*; - use std::sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }; - use tokio::runtime::Runtime; - - #[test] - fn trigger() { - let rt = Runtime::new().unwrap(); + use tokio::task; + + #[tokio::test] + async fn trigger() { let mut shutdown = Shutdown::new(); let signal = shutdown.to_signal(); assert!(!shutdown.is_triggered()); - rt.spawn(async move { - signal.await.unwrap(); + let fut = task::spawn(async move { + signal.await; }); - shutdown.trigger().unwrap(); + shutdown.trigger(); + assert!(shutdown.is_triggered()); // Shutdown::trigger is idempotent - shutdown.trigger().unwrap(); + shutdown.trigger(); assert!(shutdown.is_triggered()); + fut.await.unwrap(); } - #[test] - fn signal_clone() { - let rt = Runtime::new().unwrap(); + #[tokio::test] + async fn signal_clone() { let mut shutdown = Shutdown::new(); let signal = shutdown.to_signal(); let signal_clone = signal.clone(); - rt.spawn(async move { - signal_clone.await.unwrap(); - signal.await.unwrap(); + let fut = task::spawn(async move { + signal_clone.await; + signal.await; }); - shutdown.trigger().unwrap(); + shutdown.trigger(); + fut.await.unwrap(); } - #[test] - fn drop_trigger() { - let rt = Runtime::new().unwrap(); + #[tokio::test] + async fn drop_trigger() { let shutdown = Shutdown::new(); let signal = shutdown.to_signal(); let signal_clone = signal.clone(); - rt.spawn(async move { - signal_clone.await.unwrap(); - signal.await.unwrap(); + let fut = task::spawn(async move { + signal_clone.await; + signal.await; }); drop(shutdown); - } - - #[test] - fn on_trigger() { - let rt = Runtime::new().unwrap(); - let spy = Arc::new(AtomicBool::new(false)); - let spy_clone = Arc::clone(&spy); - let mut shutdown = Shutdown::new(); - shutdown.on_triggered(move || { - spy_clone.store(true, Ordering::SeqCst); - }); - let signal = shutdown.to_signal(); - rt.spawn(async move { - signal.await.unwrap(); - }); - shutdown.trigger().unwrap(); - assert!(spy.load(Ordering::SeqCst)); + fut.await.unwrap(); } } diff --git a/infrastructure/shutdown/src/oneshot_trigger.rs b/infrastructure/shutdown/src/oneshot_trigger.rs new file mode 100644 index 0000000000..7d2ee5b46a --- /dev/null +++ b/infrastructure/shutdown/src/oneshot_trigger.rs @@ -0,0 +1,106 @@ +// Copyright 2021, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use futures::{ + channel::{oneshot, oneshot::Receiver}, + future::{FusedFuture, Shared}, + FutureExt, +}; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +pub fn channel<T: Clone>() -> OneshotTrigger<T> { + OneshotTrigger::new() +} + +pub struct OneshotTrigger<T> { + sender: Option<oneshot::Sender<T>>, + signal: OneshotSignal<T>, +} + +impl<T: Clone> OneshotTrigger<T> { + pub fn new() -> Self { + let (tx, rx) = oneshot::channel(); + Self { + sender: Some(tx), + signal: rx.shared().into(), + } + } + + pub fn to_signal(&self) -> OneshotSignal<T> { + self.signal.clone() + } + + pub fn broadcast(&mut self, item: T) { + if let Some(tx) = self.sender.take() { + let _ = tx.send(item); + } + } + + pub fn is_used(&self) -> bool { + self.sender.is_none() + } +} + +impl<T: Clone> Default for OneshotTrigger<T> { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug, Clone)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct OneshotSignal<T> { + inner: Shared<oneshot::Receiver<T>>, +} + +impl<T: Clone> From<Shared<oneshot::Receiver<T>>> for OneshotSignal<T> { + fn from(inner: Shared<Receiver<T>>) -> Self { + Self { inner } + } +} + +impl<T: Clone> Future for OneshotSignal<T> { + type Output = Option<T>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + if self.inner.is_terminated() { + return Poll::Ready(None); + } + + match Pin::new(&mut self.inner).poll(cx) { + Poll::Ready(Ok(v)) => Poll::Ready(Some(v)), + // Channel canceled + Poll::Ready(Err(_)) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +impl<T: Clone> FusedFuture for OneshotSignal<T> { + fn is_terminated(&self) -> bool { + self.inner.is_terminated() + } +} diff --git a/infrastructure/storage/Cargo.toml b/infrastructure/storage/Cargo.toml index 5db74c5097..3915bde787 100644 --- a/infrastructure/storage/Cargo.toml +++ b/infrastructure/storage/Cargo.toml @@ -13,13 +13,13 @@ edition = "2018" bincode = "1.1" log = "0.4.0" lmdb-zero = "0.4.4" -thiserror = "1.0.20" +thiserror = "1.0.26" rmp = "0.8.7" rmp-serde = "0.13.7" serde = "1.0.80" serde_derive = "1.0.80" tari_utilities = "^0.3" -bytes = "0.4.12" +bytes = "0.5" [dev-dependencies] rand = "0.8" diff --git a/infrastructure/test_utils/Cargo.toml b/infrastructure/test_utils/Cargo.toml index bee18392c9..c2ea538ffb 100644 --- a/infrastructure/test_utils/Cargo.toml +++ b/infrastructure/test_utils/Cargo.toml @@ -10,9 +10,10 @@ license = "BSD-3-Clause" [dependencies] tari_shutdown = {version="*", path="../shutdown"} + futures-test = { version = "^0.3.1" } futures = {version= "^0.3.1"} rand = "0.8" -tokio = {version= "0.2.10", features=["rt-threaded", "time", "io-driver"]} +tokio = {version= "1.10", features=["rt-multi-thread", "time"]} lazy_static = "1.3.0" tempfile = "3.1.0" diff --git a/infrastructure/test_utils/src/futures/async_assert_eventually.rs b/infrastructure/test_utils/src/futures/async_assert_eventually.rs index cddfb29b37..f9a1f3ee9d 100644 --- a/infrastructure/test_utils/src/futures/async_assert_eventually.rs +++ b/infrastructure/test_utils/src/futures/async_assert_eventually.rs @@ -46,7 +46,7 @@ macro_rules! async_assert_eventually { $max_attempts ); } - tokio::time::delay_for($interval).await; + tokio::time::sleep($interval).await; value = $check_expr; } }}; @@ -82,7 +82,7 @@ macro_rules! async_assert { $max_attempts ); } - tokio::time::delay_for($interval).await; + tokio::time::sleep($interval).await; } }}; ($check_expr:expr$(,)?) => {{ diff --git a/infrastructure/test_utils/src/runtime.rs b/infrastructure/test_utils/src/runtime.rs index 7b8c5faa57..22ee5962a6 100644 --- a/infrastructure/test_utils/src/runtime.rs +++ b/infrastructure/test_utils/src/runtime.rs @@ -26,12 +26,8 @@ use tari_shutdown::Shutdown; use tokio::{runtime, runtime::Runtime, task, task::JoinError}; pub fn create_runtime() -> Runtime { - tokio::runtime::Builder::new() - .threaded_scheduler() - .enable_io() - .enable_time() - .max_threads(8) - .core_threads(4) + tokio::runtime::Builder::new_multi_thread() + .enable_all() .build() .expect("Could not create runtime") } @@ -49,7 +45,7 @@ where F: Future<Output = ()> + Send + 'static { /// Create a runtime and report if it panics. If there are tasks still running after the panic, this /// will carry on running forever. -// #[deprecated(note = "use tokio_macros::test instead")] +// #[deprecated(note = "use tokio::test instead")] pub fn test_async<F>(f: F) where F: FnOnce(&mut TestRuntime) { let mut rt = TestRuntime::from(create_runtime()); diff --git a/infrastructure/test_utils/src/streams/mod.rs b/infrastructure/test_utils/src/streams/mod.rs index a70e588f7e..c3f31d6b42 100644 --- a/infrastructure/test_utils/src/streams/mod.rs +++ b/infrastructure/test_utils/src/streams/mod.rs @@ -20,8 +20,9 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::{Stream, StreamExt}; -use std::{collections::HashMap, hash::Hash, time::Duration}; +use futures::{stream, Stream, StreamExt}; +use std::{borrow::BorrowMut, collections::HashMap, hash::Hash, time::Duration}; +use tokio::sync::{broadcast, mpsc}; #[allow(dead_code)] #[allow(clippy::mutable_key_type)] // Note: Clippy Breaks with Interior Mutability Error @@ -54,7 +55,6 @@ where #[macro_export] macro_rules! collect_stream { ($stream:expr, take=$take:expr, timeout=$timeout:expr $(,)?) => {{ - use futures::{Stream, StreamExt}; use tokio::time; // Evaluate $stream once, NOT in the loop 🐛🚨 @@ -62,14 +62,17 @@ macro_rules! collect_stream { let mut items = Vec::new(); loop { - if let Some(item) = time::timeout($timeout, stream.next()).await.expect( - format!( - "Timeout before stream could collect {} item(s). Got {} item(s).", - $take, - items.len() + if let Some(item) = time::timeout($timeout, futures::stream::StreamExt::next(stream)) + .await + .expect( + format!( + "Timeout before stream could collect {} item(s). Got {} item(s).", + $take, + items.len() + ) + .as_str(), ) - .as_str(), - ) { + { items.push(item); if items.len() == $take { break items; @@ -80,11 +83,86 @@ macro_rules! collect_stream { } }}; ($stream:expr, timeout=$timeout:expr $(,)?) => {{ - use futures::StreamExt; use tokio::time; + let mut stream = &mut $stream; let mut items = Vec::new(); - while let Some(item) = time::timeout($timeout, $stream.next()) + while let Some(item) = time::timeout($timeout, futures::stream::StreamExt::next($stream)) + .await + .expect(format!("Timeout before stream was closed. Got {} items.", items.len()).as_str()) + { + items.push(item); + } + items + }}; +} + +#[macro_export] +macro_rules! collect_recv { + ($stream:expr, take=$take:expr, timeout=$timeout:expr $(,)?) => {{ + use tokio::time; + + // Evaluate $stream once, NOT in the loop 🐛🚨 + let mut stream = &mut $stream; + + let mut items = Vec::new(); + loop { + let item = time::timeout($timeout, stream.recv()).await.expect(&format!( + "Timeout before stream could collect {} item(s). Got {} item(s).", + $take, + items.len() + )); + + items.push(item.expect(&format!("{}/{} recv ended early", items.len(), $take))); + if items.len() == $take { + break items; + } + } + }}; + ($stream:expr, timeout=$timeout:expr $(,)?) => {{ + use tokio::time; + + let mut stream = &mut $stream; + + let mut items = Vec::new(); + while let Some(item) = time::timeout($timeout, stream.recv()) + .await + .expect(format!("Timeout before stream was closed. Got {} items.", items.len()).as_str()) + { + items.push(item); + } + items + }}; +} + +#[macro_export] +macro_rules! collect_try_recv { + ($stream:expr, take=$take:expr, timeout=$timeout:expr $(,)?) => {{ + use tokio::time; + + // Evaluate $stream once, NOT in the loop 🐛🚨 + let mut stream = &mut $stream; + + let mut items = Vec::new(); + loop { + let item = time::timeout($timeout, stream.recv()).await.expect(&format!( + "Timeout before stream could collect {} item(s). Got {} item(s).", + $take, + items.len() + )); + + items.push(item.expect(&format!("{}/{} recv returned unexpected result", items.len(), $take))); + if items.len() == $take { + break items; + } + } + }}; + ($stream:expr, timeout=$timeout:expr $(,)?) => {{ + use tokio::time; + + let mut stream = &mut $stream; + let mut items = Vec::new(); + while let Ok(item) = time::timeout($timeout, stream.recv()) .await .expect(format!("Timeout before stream was closed. Got {} items.", items.len()).as_str()) { @@ -139,3 +217,56 @@ where } } } + +pub async fn assert_in_mpsc<T, P, R>(rx: &mut mpsc::Receiver<T>, mut predicate: P, timeout: Duration) -> R +where P: FnMut(T) -> Option<R> { + loop { + if let Some(item) = tokio::time::timeout(timeout, rx.recv()) + .await + .expect("Timeout before stream emitted") + { + if let Some(r) = (predicate)(item) { + break r; + } + } else { + panic!("Predicate did not return true before the mpsc stream ended"); + } + } +} + +pub async fn assert_in_broadcast<T, P, R>(rx: &mut broadcast::Receiver<T>, mut predicate: P, timeout: Duration) -> R +where + P: FnMut(T) -> Option<R>, + T: Clone, +{ + loop { + if let Ok(item) = tokio::time::timeout(timeout, rx.recv()) + .await + .expect("Timeout before stream emitted") + { + if let Some(r) = (predicate)(item) { + break r; + } + } else { + panic!("Predicate did not return true before the broadcast channel ended"); + } + } +} + +pub fn convert_mpsc_to_stream<'a, T>(rx: &'a mut mpsc::Receiver<T>) -> impl Stream<Item = T> + 'a { + stream::unfold(rx, |rx| async move { rx.recv().await.map(|t| (t, rx)) }) +} + +pub fn convert_unbounded_mpsc_to_stream<'a, T>(rx: &'a mut mpsc::UnboundedReceiver<T>) -> impl Stream<Item = T> + 'a { + stream::unfold(rx, |rx| async move { rx.recv().await.map(|t| (t, rx)) }) +} + +pub fn convert_broadcast_to_stream<'a, T, S>(rx: S) -> impl Stream<Item = Result<T, broadcast::error::RecvError>> + 'a +where + T: Clone + Send + 'static, + S: BorrowMut<broadcast::Receiver<T>> + 'a, +{ + stream::unfold(rx, |mut rx| async move { + Some(rx.borrow_mut().recv().await).map(|t| (t, rx)) + }) +}