diff --git a/.circleci/config.yml b/.circleci/config.yml index 0478f5aaae..f665baf02c 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -129,10 +129,10 @@ commands: when: always - run: name: Run ffi cucumber scenarios - command: cd integration_tests && mkdir -p cucumber_output && node_modules/.bin/cucumber-js --tags "not @long-running and not @broken and not @flaky and @wallet-ffi" --format json:cucumber_output/tests-ffi.cucumber --exit + command: cd integration_tests && mkdir -p cucumber_output && node_modules/.bin/cucumber-js --tags "not @long-running and not @broken and not @flaky and @wallet-ffi" --format json:cucumber_output/tests_ffi.cucumber --exit - run: name: Generate report (ffi) - command: cd integration_tests && touch cucumber_output/tests-ffi.cucumber && node ./generate_report.js cucumber_output/tests-ffi.cucumber temp/reports/cucumber_ffi_report.html + command: cd integration_tests && node ./generate_report.js "cucumber_output/tests_ffi.cucumber" "temp/reports/cucumber_ffi_report.html" when: always # - run: # name: Run flaky/broken cucumber scenarios (Always pass) diff --git a/Cargo.lock b/Cargo.lock index f3a00b0139..b93e14637e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,15 +2,6 @@ # It is not intended for manual editing. version = 3 -[[package]] -name = "addr2line" -version = "0.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e61f2b7f93d2c7d2b08263acaa4a363b3e276806c68af6134c44f523bf1aacd" -dependencies = [ - "gimli", -] - [[package]] name = "adler" version = "1.0.2" @@ -28,40 +19,29 @@ dependencies = [ [[package]] name = "aead" -version = "0.4.2" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e3e798aa0c8239776f54415bc06f3d74b1850f3f830b45c35cfc80556973f70" +checksum = "0b613b8e1e3cf911a086f53f03bf286f52fd7a7258e4fa606f0ef220d39d8877" dependencies = [ "generic-array", ] -[[package]] -name = "aes" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd2bc6d3f370b5666245ff421e231cba4353df936e26986d2918e61a8fd6aef6" -dependencies = [ - "aes-soft 0.5.0", - "aesni 0.8.0", - "block-cipher", -] - [[package]] name = "aes" version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "884391ef1066acaa41e766ba8f596341b96e93ce34f9a43e7d24bf0a0eaf0561" dependencies = [ - "aes-soft 0.6.4", - "aesni 0.10.0", + "aes-soft", + "aesni", "cipher 0.2.5", ] [[package]] name = "aes" -version = "0.7.4" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "495ee669413bfbe9e8cace80f4d3d78e6d8c8d99579f97fb93bde351b185f2d4" +checksum = "9e8b47f52ea9bae42228d07ec09eb676433d7c4ed1ebdf0f1d1c29ed446f1ab8" dependencies = [ "cfg-if 1.0.0", "cipher 0.3.0", @@ -85,29 +65,18 @@ dependencies = [ [[package]] name = "aes-gcm" -version = "0.9.3" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2a930fd487faaa92a30afa92cc9dd1526a5cff67124abbbb1c617ce070f4dcf" +checksum = "df5f85a83a7d8b0442b6aa7b504b8212c1733da07b98aae43d4bc21b2cb3cdf6" dependencies = [ - "aead 0.4.2", - "aes 0.7.4", + "aead 0.4.3", + "aes 0.7.5", "cipher 0.3.0", "ctr 0.8.0", - "ghash 0.4.3", + "ghash 0.4.4", "subtle", ] -[[package]] -name = "aes-soft" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63dd91889c49327ad7ef3b500fd1109dbd3c509a03db0d4a9ce413b79f575cb6" -dependencies = [ - "block-cipher", - "byteorder", - "opaque-debug", -] - [[package]] name = "aes-soft" version = "0.6.4" @@ -118,16 +87,6 @@ dependencies = [ "opaque-debug", ] -[[package]] -name = "aesni" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a6fe808308bb07d393e2ea47780043ec47683fcf19cf5efc8ca51c50cc8c68a" -dependencies = [ - "block-cipher", - "opaque-debug", -] - [[package]] name = "aesni" version = "0.10.0" @@ -206,9 +165,9 @@ dependencies = [ [[package]] name = "async-stream" -version = "0.2.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22068c0c19514942eefcfd4daf8976ef1aad84e61539f95cd200c35202f80af5" +checksum = "171374e7e3b2504e0e5236e3b59260560f9fe94bfe9ac39ba5e4e929c5590625" dependencies = [ "async-stream-impl", "futures-core", @@ -216,9 +175,9 @@ dependencies = [ [[package]] name = "async-stream-impl" -version = "0.2.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25f9db3b38af870bf7e5cc649167533b493928e50744e2c30ae350230b414670" +checksum = "648ed8c8d2ce5409ccd57453d9d1b214b342a0d69376a6feda1fd6cae3299308" dependencies = [ "proc-macro2 1.0.28", "quote 1.0.9", @@ -259,21 +218,6 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a" -[[package]] -name = "backtrace" -version = "0.3.61" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7a905d892734eea339e896738c14b9afce22b5318f64b951e70bf3844419b01" -dependencies = [ - "addr2line", - "cc", - "cfg-if 1.0.0", - "libc", - "miniz_oxide", - "object", - "rustc-demangle", -] - [[package]] name = "base58-monero" version = "0.3.0" @@ -302,12 +246,6 @@ dependencies = [ "byteorder", ] -[[package]] -name = "base64" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b41b7ea54a0c9d92199de89e20e58d49f02f8e699814ef3fdf266f6f748d15c7" - [[package]] name = "base64" version = "0.12.3" @@ -337,7 +275,7 @@ version = "1.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" dependencies = [ - "serde 1.0.129", + "serde 1.0.130", ] [[package]] @@ -360,7 +298,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "which", + "which 3.1.1", ] [[package]] @@ -389,9 +327,9 @@ checksum = "3e54f7b7a46d7b183eb41e2d82965261fa8a1597c68b50aced268ee1fc70272d" [[package]] name = "blake2" -version = "0.9.1" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10a5720225ef5daecf08657f23791354e1685a8c91a4c60c7f3d3b2892f978f4" +checksum = "0a4e37d16930f5459780f5621038b6382b9bb37c19016f39fb6b5808d831f174" dependencies = [ "crypto-mac", "digest", @@ -435,12 +373,12 @@ checksum = "8d696c370c750c948ada61c69a0ee2cbbb9c50b1019ddb86d9317157a99c2cae" [[package]] name = "blowfish" -version = "0.6.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f06850ba969bc59388b2cc0a4f186fc6d9d37208863b15b84ae3866ac90ac06" +checksum = "32fa6a061124e37baba002e496d203e23ba3d7b73750be82dbfbc92913048a5b" dependencies = [ - "block-cipher", "byteorder", + "cipher 0.2.5", "opaque-debug", ] @@ -459,7 +397,7 @@ dependencies = [ "lazy_static 1.4.0", "memchr", "regex-automata", - "serde 1.0.129", + "serde 1.0.130", ] [[package]] @@ -496,30 +434,20 @@ version = "1.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" -[[package]] -name = "bytes" -version = "0.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "206fdffcfa2df7cbe15601ef46c813fce0965eb3286db6b56c583b814b51c81c" -dependencies = [ - "byteorder", - "iovec", -] - [[package]] name = "bytes" version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e4cec68f03f32e44924783795810fa50a7035d8c8ebe78580ad7e6c703fba38" -dependencies = [ - "serde 1.0.129", -] [[package]] name = "bytes" -version = "1.0.1" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b700ce4376041dcd0a327fd0097c41095743c4c8af8887265942faf1100bd040" +checksum = "c4872d67bab6358e59559027aa3b9157c53d9358c51423c17554809a8858e0f8" +dependencies = [ + "serde 1.0.130", +] [[package]] name = "c_linked_list" @@ -550,12 +478,12 @@ dependencies = [ [[package]] name = "cast5" -version = "0.8.0" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3ed1e6b53a3de8bafcce4b88867893c234e57f91686a4726d8e803771f0b55b" +checksum = "1285caf81ea1f1ece6b24414c521e625ad0ec94d880625c20f2e65d8d3f78823" dependencies = [ - "block-cipher", "byteorder", + "cipher 0.2.5", "opaque-debug", ] @@ -571,7 +499,7 @@ dependencies = [ "log 0.4.14", "proc-macro2 1.0.28", "quote 1.0.9", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "syn 1.0.75", "tempfile", @@ -595,11 +523,11 @@ dependencies = [ [[package]] name = "cfb-mode" -version = "0.5.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fa76b7293f89734378d27057d169dc68077ad34b21dbcabf1c0a646a9462592" +checksum = "1d6975e91054798d325f85f50115056d7deccf6817fe7f947c438ee45b119632" dependencies = [ - "stream-cipher", + "cipher 0.2.5", ] [[package]] @@ -616,9 +544,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chacha20" -version = "0.7.2" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea8756167ea0aca10e066cdbe7813bd71d2f24e69b0bc7b50509590cef2ce0b9" +checksum = "f08493fa7707effc63254c66c6ea908675912493cd67952eda23c09fae2610b1" dependencies = [ "cfg-if 1.0.0", "cipher 0.3.0", @@ -628,11 +556,11 @@ dependencies = [ [[package]] name = "chacha20poly1305" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "175a11316f33592cf2b71416ee65283730b5b7849813c4891d02a12906ed9acc" +checksum = "b6547abe025f4027edacd9edaa357aded014eecec42a5070d9b885c3c334aba2" dependencies = [ - "aead 0.4.2", + "aead 0.4.3", "chacha20", "cipher 0.3.0", "poly1305", @@ -654,7 +582,7 @@ dependencies = [ "libc", "num-integer", "num-traits 0.2.14", - "serde 1.0.129", + "serde 1.0.130", "time", "winapi 0.3.9", ] @@ -677,7 +605,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c6316c62053228eddd526a5e6deb6344c80bf2bc1e9786e7f90b3083e73197c1" dependencies = [ "bitstring", - "serde 1.0.129", + "serde 1.0.130", ] [[package]] @@ -763,7 +691,7 @@ dependencies = [ "lazy_static 1.4.0", "nom 4.2.3", "rust-ini", - "serde 1.0.129", + "serde 1.0.130", "serde-hjson", "serde_json", "toml 0.4.10", @@ -788,9 +716,9 @@ checksum = "ea221b5284a47e40033bf9b66f35f984ec0ea2931eb03505246cd27a963f981b" [[package]] name = "cpufeatures" -version = "0.1.5" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66c99696f6c9dd7f35d486b9d04d7e6e202aa3e8c40d553f2fdf5e7e0c6a71ef" +checksum = "95059428f66df56b63431fdb4e1947ed2190586af5c5a8a8b71122bdf5a7f469" dependencies = [ "libc", ] @@ -827,7 +755,7 @@ dependencies = [ "clap", "criterion-plot", "csv", - "itertools", + "itertools 0.8.2", "lazy_static 1.4.0", "libc", "num-traits 0.2.14", @@ -836,7 +764,7 @@ dependencies = [ "rand_xoshiro", "rayon", "rayon-core", - "serde 1.0.129", + "serde 1.0.130", "serde_derive", "serde_json", "tinytemplate", @@ -851,7 +779,7 @@ checksum = "76f9212ddf2f4a9eb2d401635190600656a1f88a932ef53d06e7fa4c7e02fb8e" dependencies = [ "byteorder", "cast", - "itertools", + "itertools 0.8.2", ] [[package]] @@ -1014,7 +942,7 @@ dependencies = [ "csv-core", "itoa", "ryu", - "serde 1.0.129", + "serde 1.0.130", ] [[package]] @@ -1068,7 +996,7 @@ dependencies = [ "byteorder", "digest", "rand_core 0.5.1", - "serde 1.0.129", + "serde 1.0.130", "subtle", "zeroize", ] @@ -1083,7 +1011,7 @@ dependencies = [ "digest", "packed_simd_2", "rand_core 0.6.3", - "serde 1.0.129", + "serde 1.0.130", "subtle-ng", "zeroize", ] @@ -1178,12 +1106,12 @@ dependencies = [ [[package]] name = "des" -version = "0.5.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e084b5048dec677e6c9f27d7abc551dde7d127cf4127fea82323c98a30d7fa0d" +checksum = "b24e7c748888aa2fa8bce21d8c64a52efc810663285315ac7476f7197a982fae" dependencies = [ - "block-cipher", "byteorder", + "cipher 0.2.5", "opaque-debug", ] @@ -1279,7 +1207,7 @@ dependencies = [ "curve25519-dalek", "ed25519", "rand 0.7.3", - "serde 1.0.129", + "serde 1.0.130", "sha2", "zeroize", ] @@ -1599,19 +1527,6 @@ dependencies = [ "pin-utils", ] -[[package]] -name = "futures-test-preview" -version = "0.3.0-alpha.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0813d833a213f893d1f07ccd9d49d3de100181d146053b2097b8b934d7945eb" -dependencies = [ - "futures-core-preview", - "futures-executor-preview", - "futures-io-preview", - "futures-util-preview", - "pin-utils", -] - [[package]] name = "futures-timer" version = "0.3.0" @@ -1731,20 +1646,14 @@ dependencies = [ [[package]] name = "ghash" -version = "0.4.3" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b442c439366184de619215247d24e908912b175e824a530253845ac4c251a5c1" +checksum = "1583cc1656d7839fd3732b80cf4f38850336cdb9b8ded1cd399ca62958de3c99" dependencies = [ "opaque-debug", - "polyval 0.5.2", + "polyval 0.5.3", ] -[[package]] -name = "gimli" -version = "0.25.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0a01e0497841a3b2db4f8afa483cce65f7e96a3498bd6c541734792aeac8fe7" - [[package]] name = "git2" version = "0.8.0" @@ -1792,7 +1701,7 @@ version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7f3675cfef6a30c8031cf9e6493ebdc3bb3272a3fea3923c4210d1830e6a472" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "fnv", "futures-core", "futures-sink", @@ -1847,7 +1756,7 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "527e8c9ac747e28542699a951517aa9a6945af506cd1f2e1b53a576c17b6cc11" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "fnv", "itoa", ] @@ -1868,7 +1777,7 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "399c583b2979440c60be0821a6199eca73bc3c8dcd9d070d75ac726e2c6186e5" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "http", "pin-project-lite 0.2.7", ] @@ -1955,7 +1864,7 @@ version = "0.14.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13f67199e765030fa08fe0bd581af683f0d5bc04ea09c2b1102012c5fb90e7fd" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "futures-channel", "futures-core", "futures-util", @@ -1973,6 +1882,18 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-timeout" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" +dependencies = [ + "hyper 0.14.12", + "pin-project-lite 0.2.7", + "tokio 1.10.1", + "tokio-io-timeout", +] + [[package]] name = "hyper-tls" version = "0.4.3" @@ -1992,7 +1913,7 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "hyper 0.14.12", "native-tls", "tokio 1.10.1", @@ -2037,7 +1958,7 @@ dependencies = [ "byteorder", "color_quant", "num-iter", - "num-rational", + "num-rational 0.3.2", "num-traits 0.2.14", ] @@ -2090,6 +2011,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69ddb889f9d0d08a67338271fa9b62996bc788c7796a5c18cf057420aaed5eaf" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "0.4.8" @@ -2112,7 +2042,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "436f3455a8a4e9c7b14de9f1206198ee5d0bdc2db1b560339d2141093d7dd389" dependencies = [ "hyper 0.10.16", - "serde 1.0.129", + "serde 1.0.130", "serde_derive", "serde_json", ] @@ -2162,9 +2092,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.100" +version = "0.2.101" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1fa8cddc8fbbee11227ef194b5317ed014b8acbf15139bd716a18ad3fe99ec5" +checksum = "3cb00336871be5ed2c8ed44b60ae9959dc5b9f08539422ed43f09e34ecaeba21" [[package]] name = "libgit2-sys" @@ -2289,9 +2219,9 @@ dependencies = [ [[package]] name = "lock_api" -version = "0.4.4" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0382880606dff6d15c9476c416d18690b72742aa7b605bb6dd6ec9030fbf07eb" +checksum = "712a4d093c9976e24e7dbca41db895dabcbac38eb5f4045393d17a95bdfb1109" dependencies = [ "scopeguard", ] @@ -2312,7 +2242,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51b9bbe6c47d51fc3e1a9b945965946b4c44142ab8792c50835a980d362c2710" dependencies = [ "cfg-if 1.0.0", - "serde 1.0.129", + "serde 1.0.130", ] [[package]] @@ -2335,7 +2265,7 @@ dependencies = [ "libc", "log 0.4.14", "log-mdc", - "serde 1.0.129", + "serde 1.0.130", "serde-value 0.5.3", "serde_derive", "serde_yaml", @@ -2359,9 +2289,9 @@ dependencies = [ "libc", "log 0.4.14", "log-mdc", - "parking_lot 0.11.1", + "parking_lot 0.11.2", "regex", - "serde 1.0.129", + "serde 1.0.130", "serde-value 0.7.0", "serde_json", "serde_yaml", @@ -2512,17 +2442,6 @@ dependencies = [ "winapi 0.3.9", ] -[[package]] -name = "mio-uds" -version = "0.6.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afcb699eb26d4332647cc848492bbc15eafb26f08d0304550d5aa1f612e066f0" -dependencies = [ - "iovec", - "libc", - "mio 0.6.23", -] - [[package]] name = "miow" version = "0.2.2" @@ -2555,21 +2474,39 @@ dependencies = [ "fixed-hash", "hex", "hex-literal", - "serde 1.0.129", + "serde 1.0.130", "serde-big-array", "thiserror", "tiny-keccak", ] +[[package]] +name = "multiaddr" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48ee4ea82141951ac6379f964f71b20876d43712bea8faf6dd1a375e08a46499" +dependencies = [ + "arrayref", + "bs58", + "byteorder", + "data-encoding", + "multihash", + "percent-encoding 2.1.0", + "serde 1.0.130", + "static_assertions", + "unsigned-varint", + "url 2.2.2", +] + [[package]] name = "multihash" -version = "0.13.2" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dac63698b887d2d929306ea48b63760431ff8a24fac40ddb22f9c7f49fb7cab" +checksum = "752a61cd890ff691b4411423d23816d5866dd5621e4d1c5687a53b94b5a979d8" dependencies = [ "generic-array", "multihash-derive", - "unsigned-varint 0.5.1", + "unsigned-varint", ] [[package]] @@ -2629,9 +2566,12 @@ checksum = "d36047f46c69ef97b60e7b069a26ce9a15cd8a7852eddb6991ea94a83ba36a78" [[package]] name = "nibble_vec" -version = "0.0.4" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8d77f3db4bce033f4d04db08079b2ef1c3d02b44e86f25d08886fafa7756ffa" +checksum = "77a5d83df9f36fe23f0c3648c6bbb8b0298bb5f1939c8f2704431371f4b84d43" +dependencies = [ + "smallvec", +] [[package]] name = "nix" @@ -2693,10 +2633,24 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b7a8e9be5e039e2ff869df49155f1c06bd01ade2117ec783e56ab0932b67a8f" dependencies = [ "num-bigint 0.3.2", - "num-complex", + "num-complex 0.3.1", "num-integer", "num-iter", - "num-rational", + "num-rational 0.3.2", + "num-traits 0.2.14", +] + +[[package]] +name = "num" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43db66d1170d347f9a065114077f7dccb00c1b9478c89384490a3425279a4606" +dependencies = [ + "num-bigint 0.4.1", + "num-complex 0.4.0", + "num-integer", + "num-iter", + "num-rational 0.4.0", "num-traits 0.2.14", ] @@ -2722,6 +2676,17 @@ dependencies = [ "num-traits 0.2.14", ] +[[package]] +name = "num-bigint" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76e97c412795abf6c24ba30055a8f20642ea57ca12875220b854cfa501bf1e48" +dependencies = [ + "autocfg 1.0.1", + "num-integer", + "num-traits 0.2.14", +] + [[package]] name = "num-bigint-dig" version = "0.6.1" @@ -2736,7 +2701,7 @@ dependencies = [ "num-iter", "num-traits 0.2.14", "rand 0.7.3", - "serde 1.0.129", + "serde 1.0.130", "smallvec", "zeroize", ] @@ -2750,6 +2715,15 @@ dependencies = [ "num-traits 0.2.14", ] +[[package]] +name = "num-complex" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26873667bbbb7c5182d4a37c1add32cdf09f841af72da53318fdb81543c15085" +dependencies = [ + "num-traits 0.2.14", +] + [[package]] name = "num-derive" version = "0.3.3" @@ -2804,6 +2778,18 @@ dependencies = [ "num-traits 0.2.14", ] +[[package]] +name = "num-rational" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d41702bd167c2df5520b384281bc111a4b5efcf7fbc4c9c222c815b07e0a6a6a" +dependencies = [ + "autocfg 1.0.1", + "num-bigint 0.4.1", + "num-integer", + "num-traits 0.2.14", +] + [[package]] name = "num-traits" version = "0.1.43" @@ -2832,15 +2818,6 @@ dependencies = [ "libc", ] -[[package]] -name = "object" -version = "0.26.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee2766204889d09937d00bfbb7fec56bb2a199e2ade963cab19185d8a6104c7c" -dependencies = [ - "memchr", -] - [[package]] name = "once_cell" version = "1.8.0" @@ -2967,24 +2944,6 @@ dependencies = [ "libm 0.1.4", ] -[[package]] -name = "parity-multiaddr" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bfda2e46fc5e14122649e2645645a81ee5844e0fb2e727ef560cc71a8b2d801" -dependencies = [ - "arrayref", - "bs58", - "byteorder", - "data-encoding", - "multihash", - "percent-encoding 2.1.0", - "serde 1.0.129", - "static_assertions", - "unsigned-varint 0.6.0", - "url 2.2.2", -] - [[package]] name = "parking_lot" version = "0.10.2" @@ -2997,13 +2956,13 @@ dependencies = [ [[package]] name = "parking_lot" -version = "0.11.1" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d7744ac029df22dca6284efe4e898991d28e3085c706c972bcd7da4a27a15eb" +checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" dependencies = [ "instant", - "lock_api 0.4.4", - "parking_lot_core 0.8.3", + "lock_api 0.4.5", + "parking_lot_core 0.8.5", ] [[package]] @@ -3022,9 +2981,9 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.8.3" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa7a782938e745763fe6907fc6ba86946d72f49fe7e21de074e08128a99fb018" +checksum = "d76e8e1493bcac0d2766c42737f34458f1c8c50c0d23bcb24ea953affb273216" dependencies = [ "cfg-if 1.0.0", "instant", @@ -3090,11 +3049,11 @@ dependencies = [ [[package]] name = "pgp" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "501f8c2834bc16a23ae40932b9f924c6c5fc1d7cd1cc3536a532f37e81f603ed" +checksum = "856124b4d0a95badd3e1ad353edd7157fc6c6995767b78ef62848f3b296405ff" dependencies = [ - "aes 0.5.0", + "aes 0.6.0", "base64 0.12.3", "bitfield", "block-modes", @@ -3105,6 +3064,7 @@ dependencies = [ "cast5", "cfb-mode", "chrono", + "cipher 0.2.5", "circular", "clear_on_drop", "crc24", @@ -3203,9 +3163,9 @@ checksum = "3831453b3449ceb48b6d9c7ad7c96d5ea673e9b470a1dc578c2ce6521230884c" [[package]] name = "poly1305" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fcffab1f78ebbdf4b93b68c1ffebc24037eedf271edaca795732b24e5e4e349" +checksum = "048aeb476be11a4b6ca432ca569e375810de9294ae78f4774e78ea98a9246ede" dependencies = [ "cpufeatures", "opaque-debug", @@ -3225,9 +3185,9 @@ dependencies = [ [[package]] name = "polyval" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6ba6a405ef63530d6cb12802014b22f9c5751bd17cdcddbe9e46d5c8ae83287" +checksum = "8419d2b623c7c0896ff2d5d96e2cb4ede590fed28fcc34934f4c33c036e620a1" dependencies = [ "cfg-if 1.0.0", "cpufeatures", @@ -3307,40 +3267,40 @@ dependencies = [ [[package]] name = "prost" -version = "0.6.1" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce49aefe0a6144a45de32927c77bd2859a5f7677b55f220ae5b744e87389c212" +checksum = "de5e2533f59d08fcf364fd374ebda0692a70bd6d7e66ef97f306f45c6c5d8020" dependencies = [ - "bytes 0.5.6", + "bytes 1.1.0", "prost-derive", ] [[package]] name = "prost-build" -version = "0.6.1" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02b10678c913ecbd69350e8535c3aef91a8676c0773fc1d7b95cdd196d7f2f26" +checksum = "355f634b43cdd80724ee7848f95770e7e70eefa6dcf14fea676216573b8fd603" dependencies = [ - "bytes 0.5.6", + "bytes 1.1.0", "heck", - "itertools", + "itertools 0.10.1", "log 0.4.14", "multimap", "petgraph", "prost", "prost-types", "tempfile", - "which", + "which 4.2.2", ] [[package]] name = "prost-derive" -version = "0.6.1" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "537aa19b95acde10a12fec4301466386f757403de4cd4e5b4fa78fb5ecb18f72" +checksum = "600d2f334aa05acb02a755e217ef1ab6dea4d51b58b7846588b747edec04efba" dependencies = [ "anyhow", - "itertools", + "itertools 0.10.1", "proc-macro2 1.0.28", "quote 1.0.9", "syn 1.0.75", @@ -3348,11 +3308,11 @@ dependencies = [ [[package]] name = "prost-types" -version = "0.6.1" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1834f67c0697c001304b75be76f67add9c89742eda3a085ad8ee0bb38c3417aa" +checksum = "603bbd6394701d13f3f25aada59c7de9d35a6a5887cfc156181234a44002771b" dependencies = [ - "bytes 0.5.6", + "bytes 1.1.0", "prost", ] @@ -3398,9 +3358,9 @@ dependencies = [ [[package]] name = "radix_trie" -version = "0.1.6" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d3681b28cd95acfb0560ea9441f82d6a4504fa3b15b97bd7b6e952131820e95" +checksum = "c069c179fcdc6a2fe24d8d18305cf085fdbd4f922c041943e203685d6a1c58fd" dependencies = [ "endian-type", "nibble_vec", @@ -3417,7 +3377,6 @@ dependencies = [ "rand_chacha 0.2.2", "rand_core 0.5.1", "rand_hc 0.2.0", - "rand_pcg", ] [[package]] @@ -3517,15 +3476,6 @@ dependencies = [ "winapi 0.3.9", ] -[[package]] -name = "rand_pcg" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16abd0c1b639e9eb4d7c50c0b8100b0d0f849be2349829c740fe8e6eb4816429" -dependencies = [ - "rand_core 0.5.1", -] - [[package]] name = "rand_xoshiro" version = "0.1.0" @@ -3666,8 +3616,7 @@ dependencies = [ "native-tls", "percent-encoding 2.1.0", "pin-project-lite 0.2.7", - "serde 1.0.129", - "serde_json", + "serde 1.0.130", "serde_urlencoded", "tokio 0.2.25", "tokio-tls", @@ -3685,7 +3634,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "246e9f61b9bb77df069a947682be06e31ac43ea37862e244a69f177694ea6d22" dependencies = [ "base64 0.13.0", - "bytes 1.0.1", + "bytes 1.1.0", "encoding_rs", "futures-core", "futures-util", @@ -3701,7 +3650,7 @@ dependencies = [ "native-tls", "percent-encoding 2.1.0", "pin-project-lite 0.2.7", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "serde_urlencoded", "tokio 1.10.1", @@ -3757,7 +3706,7 @@ checksum = "011e1d58446e9fa3af7cdc1fb91295b10621d3ac4cb3a85cc86385ee9ca50cd3" dependencies = [ "byteorder", "rmp", - "serde 1.0.129", + "serde 1.0.130", ] [[package]] @@ -3798,12 +3747,6 @@ version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3e52c148ef37f8c375d49d5a73aa70713125b7f19095948a923f80afdeb22ec2" -[[package]] -name = "rustc-demangle" -version = "0.1.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ef03e0a2b150c7a90d01faf6254c9c48a41e95fb2a8c2ac1c6f0d2b9aefc342" - [[package]] name = "rustc-hash" version = "1.1.0" @@ -3836,11 +3779,11 @@ dependencies = [ [[package]] name = "rustls" -version = "0.17.0" +version = "0.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0d4a31f5d68413404705d6982529b0e11a9aacd4839d1d6222ee3b8cb4015e1" +checksum = "35edb675feee39aec9c99fa5ff985081995a06d594114ae14cbe797ad7b7a6d7" dependencies = [ - "base64 0.11.0", + "base64 0.13.0", "log 0.4.14", "ring", "sct", @@ -3931,22 +3874,23 @@ dependencies = [ [[package]] name = "security-framework" -version = "2.3.1" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23a2ac85147a3a11d77ecf1bc7166ec0b92febfa4461c37944e180f319ece467" +checksum = "5b9bd29cdffb8875b04f71c51058f940cf4e390bbfd2ce669c4f22cd70b492a5" dependencies = [ "bitflags 1.3.2", "core-foundation", "core-foundation-sys", "libc", + "num 0.4.0", "security-framework-sys", ] [[package]] name = "security-framework-sys" -version = "2.3.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e4effb91b4b8b6fb7732e670b6cee160278ff8e6bf485c7805d9e319d76e284" +checksum = "19133a286e494cc3311c165c4676ccb1fd47bed45b55f9d71fbd784ad4cea6f8" dependencies = [ "core-foundation-sys", "libc", @@ -3984,9 +3928,9 @@ checksum = "9dad3f759919b92c3068c696c15c3d17238234498bbdcc80f2c469606f948ac8" [[package]] name = "serde" -version = "1.0.129" +version = "1.0.130" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1f72836d2aa753853178eda473a3b9d8e4eefdaf20523b919677e6de489f8f1" +checksum = "f12d06de37cf59146fbdecab66aa99f9fe4f78722e3607577a5375d66bd0c913" dependencies = [ "serde_derive", ] @@ -3997,7 +3941,7 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "18b20e7752957bbe9661cff4e0bb04d183d0948cdab2ea58cdb9df36a61dfe62" dependencies = [ - "serde 1.0.129", + "serde 1.0.130", "serde_derive", ] @@ -4021,7 +3965,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a663f873dedc4eac1a559d4c6bc0d0b2c34dc5ac4702e105014b8281489e44f" dependencies = [ "ordered-float 1.1.1", - "serde 1.0.129", + "serde 1.0.130", ] [[package]] @@ -4031,14 +3975,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3a1a3341211875ef120e117ea7fd5228530ae7e7036a779fdc9117be6b3282c" dependencies = [ "ordered-float 2.7.0", - "serde 1.0.129", + "serde 1.0.130", ] [[package]] name = "serde_derive" -version = "1.0.129" +version = "1.0.130" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e57ae87ad533d9a56427558b516d0adac283614e347abf85b0dc0cbbf0a249f3" +checksum = "d7bc1a1ab1961464eae040d96713baa5a724a8152c1222492465b54322ec508b" dependencies = [ "proc-macro2 1.0.28", "quote 1.0.9", @@ -4047,13 +3991,13 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.66" +version = "1.0.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "336b10da19a12ad094b59d870ebde26a45402e5b470add4b5fd03c5048a32127" +checksum = "a7f9e390c27c3c0ce8bc5d725f6e4d30a29d26659494aa4b17535f7522c5c950" dependencies = [ "itoa", "ryu", - "serde 1.0.129", + "serde 1.0.130", ] [[package]] @@ -4085,26 +4029,26 @@ dependencies = [ "form_urlencoded", "itoa", "ryu", - "serde 1.0.129", + "serde 1.0.130", ] [[package]] name = "serde_yaml" -version = "0.8.19" +version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6375dbd828ed6964c3748e4ef6d18e7a175d408ffe184bca01698d0c73f915a9" +checksum = "ad104641f3c958dab30eb3010e834c2622d1f3f4c530fef1dee20ad9485f3c09" dependencies = [ "dtoa", "indexmap", - "serde 1.0.129", + "serde 1.0.130", "yaml-rust", ] [[package]] name = "sha-1" -version = "0.9.7" +version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a0c8611594e2ab4ebbf06ec7cbbf0a99450b8570e96cbf5188b5d5f6ef18d81" +checksum = "99cd6713db3cf16b6c84e06321e049a9b9f699826e16096d23bbcc44d15d51a6" dependencies = [ "block-buffer", "cfg-if 1.0.0", @@ -4115,9 +4059,9 @@ dependencies = [ [[package]] name = "sha2" -version = "0.9.5" +version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b362ae5752fd2137731f9fa25fd4d9058af34666ca1966fb969119cc35719f12" +checksum = "9204c41a1597a8c5af23c82d1c921cb01ec0a4c59e07a9c7306062829a3903f3" dependencies = [ "block-buffer", "cfg-if 1.0.0", @@ -4208,7 +4152,7 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6142f7c25e94f6fd25a32c3348ec230df9109b463f59c8c7acc4bd34936babb7" dependencies = [ - "aes-gcm 0.9.3", + "aes-gcm 0.9.4", "blake2", "chacha20poly1305", "rand 0.8.4", @@ -4261,16 +4205,6 @@ dependencies = [ "futures 0.1.31", ] -[[package]] -name = "stream-cipher" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c80e15f898d8d8f25db24c253ea615cc14acf418ff307822995814e7d42cfa89" -dependencies = [ - "block-cipher", - "generic-array", -] - [[package]] name = "strsim" version = "0.8.0" @@ -4452,13 +4386,14 @@ dependencies = [ "strum", "strum_macros 0.19.4", "tari_common", + "tari_common_types", "tari_comms", "tari_core", "tari_crypto", "tari_p2p", "tari_wallet", "thiserror", - "tokio 0.2.25", + "tokio 1.10.1", "tonic", ] @@ -4482,6 +4417,7 @@ dependencies = [ "tari_app_grpc", "tari_app_utilities", "tari_common", + "tari_common_types", "tari_comms", "tari_comms_dht", "tari_core", @@ -4490,9 +4426,8 @@ dependencies = [ "tari_p2p", "tari_service_framework", "tari_shutdown", - "tari_wallet", "thiserror", - "tokio 0.2.25", + "tokio 1.10.1", "tonic", "tracing", "tracing-opentelemetry", @@ -4512,7 +4447,7 @@ dependencies = [ "merlin", "rand 0.8.4", "rand_core 0.6.3", - "serde 1.0.129", + "serde 1.0.130", "serde_derive", "sha3", "subtle-ng", @@ -4530,12 +4465,12 @@ dependencies = [ "git2", "log 0.4.14", "log4rs 1.0.0", + "multiaddr", "opentelemetry", "opentelemetry-jaeger", - "parity-multiaddr", "path-clean", "prost-build", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "sha2", "structopt", @@ -4552,11 +4487,13 @@ dependencies = [ name = "tari_common_types" version = "0.9.5" dependencies = [ + "digest", "futures 0.3.16", + "lazy_static 1.4.0", "rand 0.8.4", - "serde 1.0.129", + "serde 1.0.130", "tari_crypto", - "tokio 0.2.25", + "tokio 1.10.1", ] [[package]] @@ -4567,7 +4504,7 @@ dependencies = [ "async-trait", "bitflags 1.3.2", "blake2", - "bytes 0.5.6", + "bytes 1.1.0", "chrono", "cidr", "clear_on_drop", @@ -4578,15 +4515,15 @@ dependencies = [ "lazy_static 1.4.0", "lmdb-zero", "log 0.4.14", + "multiaddr", "nom 5.1.2", "openssl", "opentelemetry", "opentelemetry-jaeger", - "parity-multiaddr", - "pin-project 0.4.28", + "pin-project 1.0.8", "prost", "rand 0.8.4", - "serde 1.0.129", + "serde 1.0.130", "serde_derive", "serde_json", "snow", @@ -4598,10 +4535,10 @@ dependencies = [ "tari_test_utils", "tempfile", "thiserror", - "tokio 0.2.25", - "tokio-macros", - "tokio-util 0.3.1", - "tower", + "tokio 1.10.1", + "tokio-stream", + "tokio-util 0.6.7", + "tower 0.3.1", "tower-make", "tracing", "tracing-futures", @@ -4614,7 +4551,7 @@ version = "0.9.5" dependencies = [ "anyhow", "bitflags 1.3.2", - "bytes 0.4.12", + "bytes 0.5.6", "chacha20", "chrono", "clap", @@ -4623,7 +4560,7 @@ dependencies = [ "digest", "env_logger 0.7.1", "futures 0.3.16", - "futures-test-preview", + "futures-test", "futures-util", "lazy_static 1.4.0", "libsqlite3-sys", @@ -4634,7 +4571,7 @@ dependencies = [ "prost", "prost-types", "rand 0.8.4", - "serde 1.0.129", + "serde 1.0.130", "serde_derive", "serde_repr", "tari_common", @@ -4647,10 +4584,10 @@ dependencies = [ "tari_utilities", "tempfile", "thiserror", - "tokio 0.2.25", - "tokio-macros", - "tokio-test", - "tower", + "tokio 1.10.1", + "tokio-stream", + "tokio-test 0.4.2", + "tower 0.3.1", "tower-test", "ttl_cache", ] @@ -4666,8 +4603,7 @@ dependencies = [ "syn 1.0.75", "tari_comms", "tari_test_utils", - "tokio 0.2.25", - "tokio-macros", + "tokio 1.10.1", "tower-service", ] @@ -4693,6 +4629,7 @@ dependencies = [ "tari_app_grpc", "tari_app_utilities", "tari_common", + "tari_common_types", "tari_comms", "tari_comms_dht", "tari_core", @@ -4702,7 +4639,7 @@ dependencies = [ "tari_shutdown", "tari_wallet", "thiserror", - "tokio 0.2.25", + "tokio 1.10.1", "tonic", "tracing", "tracing-opentelemetry", @@ -4719,7 +4656,7 @@ dependencies = [ "bincode", "bitflags 1.3.2", "blake2", - "bytes 0.4.12", + "bytes 0.5.6", "chrono", "config", "croaring", @@ -4728,17 +4665,18 @@ dependencies = [ "fs2", "futures 0.3.16", "hex", + "lazy_static 1.4.0", "lmdb-zero", "log 0.4.14", "monero", "newtype-ops", - "num", + "num 0.3.1", "num-format", "prost", "prost-types", "rand 0.8.4", "randomx-rs", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "sha3", "strum_macros 0.17.1", @@ -4756,8 +4694,7 @@ dependencies = [ "tari_test_utils", "tempfile", "thiserror", - "tokio 0.2.25", - "tokio-macros", + "tokio 1.10.1", "tracing", "tracing-attributes", "tracing-futures", @@ -4781,7 +4718,7 @@ dependencies = [ "merlin", "rand 0.8.4", "rmp-serde", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "sha2", "sha3", @@ -4806,7 +4743,7 @@ version = "0.9.5" dependencies = [ "digest", "rand 0.8.4", - "serde 1.0.129", + "serde 1.0.130", "serde_derive", "serde_json", "sha2", @@ -4820,7 +4757,7 @@ version = "0.9.5" dependencies = [ "anyhow", "bincode", - "bytes 0.5.6", + "bytes 1.1.0", "chrono", "config", "derive-error", @@ -4828,12 +4765,12 @@ dependencies = [ "futures 0.3.16", "futures-test", "hex", - "hyper 0.13.10", + "hyper 0.14.12", "jsonrpc", "log 0.4.14", "rand 0.8.4", - "reqwest 0.10.10", - "serde 1.0.129", + "reqwest 0.11.4", + "serde 1.0.130", "serde_json", "structopt", "tari_app_grpc", @@ -4843,8 +4780,7 @@ dependencies = [ "tari_crypto", "tari_utilities", "thiserror", - "tokio 0.2.25", - "tokio-macros", + "tokio 1.10.1", "tonic", "tracing", "tracing-futures", @@ -4868,7 +4804,7 @@ dependencies = [ "prost-types", "rand 0.8.4", "reqwest 0.11.4", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "sha3", "tari_app_grpc", @@ -4878,7 +4814,7 @@ dependencies = [ "tari_crypto", "thiserror", "time", - "tokio 0.2.25", + "tokio 1.10.1", "tonic", ] @@ -4893,7 +4829,7 @@ dependencies = [ "digest", "log 0.4.14", "rand 0.8.4", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "tari_crypto", "tari_infra_derive", @@ -4922,7 +4858,7 @@ dependencies = [ "rand 0.8.4", "reqwest 0.10.10", "semver 1.0.4", - "serde 1.0.129", + "serde 1.0.130", "serde_derive", "stream-cancel", "tari_common", @@ -4936,9 +4872,9 @@ dependencies = [ "tari_utilities", "tempfile", "thiserror", - "tokio 0.2.25", - "tokio-macros", - "tower", + "tokio 1.10.1", + "tokio-stream", + "tower 0.3.1", "tower-service", "trust-dns-client", ] @@ -4955,9 +4891,8 @@ dependencies = [ "tari_shutdown", "tari_test_utils", "thiserror", - "tokio 0.2.25", - "tokio-macros", - "tower", + "tokio 1.10.1", + "tower 0.3.1", "tower-service", ] @@ -4966,7 +4901,7 @@ name = "tari_shutdown" version = "0.9.5" dependencies = [ "futures 0.3.16", - "tokio 0.2.25", + "tokio 1.10.1", ] [[package]] @@ -4974,14 +4909,14 @@ name = "tari_storage" version = "0.9.5" dependencies = [ "bincode", - "bytes 0.4.12", + "bytes 0.5.6", "env_logger 0.6.2", "lmdb-zero", "log 0.4.14", "rand 0.8.4", "rmp", "rmp-serde", - "serde 1.0.129", + "serde 1.0.130", "serde_derive", "tari_utilities", "thiserror", @@ -4993,7 +4928,7 @@ version = "0.0.1" dependencies = [ "hex", "libc", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "tari_app_grpc", "tari_common", @@ -5017,12 +4952,12 @@ dependencies = [ "futures 0.3.16", "futures-test", "hex", - "hyper 0.13.10", + "hyper 0.14.12", "jsonrpc", "log 0.4.14", "rand 0.7.3", - "reqwest 0.10.10", - "serde 1.0.129", + "reqwest 0.11.4", + "serde 1.0.130", "serde_json", "structopt", "tari_app_grpc", @@ -5031,8 +4966,7 @@ dependencies = [ "tari_crypto", "tari_utilities", "thiserror", - "tokio 0.2.25", - "tokio-macros", + "tokio 1.10.1", "tonic", "tonic-build", "tracing", @@ -5051,7 +4985,7 @@ dependencies = [ "rand 0.8.4", "tari_shutdown", "tempfile", - "tokio 0.2.25", + "tokio 1.10.1", ] [[package]] @@ -5067,7 +5001,7 @@ dependencies = [ "clear_on_drop", "newtype-ops", "rand 0.7.3", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "thiserror", ] @@ -5087,14 +5021,13 @@ dependencies = [ "env_logger 0.7.1", "fs2", "futures 0.3.16", - "lazy_static 1.4.0", "libsqlite3-sys", "lmdb-zero", "log 0.4.14", "log4rs 1.0.0", "prost", "rand 0.8.4", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "tari_common_types", "tari_comms", @@ -5110,9 +5043,8 @@ dependencies = [ "tempfile", "thiserror", "time", - "tokio 0.2.25", - "tokio-macros", - "tower", + "tokio 1.10.1", + "tower 0.3.1", ] [[package]] @@ -5140,7 +5072,7 @@ dependencies = [ "tari_wallet", "tempfile", "thiserror", - "tokio 0.2.25", + "tokio 1.10.1", ] [[package]] @@ -5171,12 +5103,13 @@ name = "test_faucet" version = "0.9.5" dependencies = [ "rand 0.8.4", - "serde 1.0.129", + "serde 1.0.130", "serde_json", + "tari_common_types", "tari_core", "tari_crypto", "tari_utilities", - "tokio 0.2.25", + "tokio 1.10.1", ] [[package]] @@ -5190,18 +5123,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.26" +version = "1.0.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93119e4feac1cbe6c798c34d3a53ea0026b0b1de6a120deef895137c0529bfe2" +checksum = "283d5230e63df9608ac7d9691adc1dfb6e701225436eb64d0b9a7f0a5a04f6ec" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.26" +version = "1.0.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "060d69a0afe7796bf42e9e2ff91f5ee691fb15c53d38b4b62a9a53eb23164745" +checksum = "fa3884228611f5cd3608e2d409bf7dce832e4eb3135e3f11addbd7e41bd68e71" dependencies = [ "proc-macro2 1.0.28", "quote 1.0.9", @@ -5276,7 +5209,7 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" dependencies = [ - "serde 1.0.129", + "serde 1.0.130", "serde_json", ] @@ -5306,16 +5239,10 @@ dependencies = [ "futures-core", "iovec", "lazy_static 1.4.0", - "libc", "memchr", "mio 0.6.23", - "mio-uds", - "num_cpus", "pin-project-lite 0.1.12", - "signal-hook-registry", "slab", - "tokio-macros", - "winapi 0.3.9", ] [[package]] @@ -5325,20 +5252,33 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92036be488bb6594459f2e03b60e42df6f937fe6ca5c5ffdcb539c6b84dc40f5" dependencies = [ "autocfg 1.0.1", - "bytes 1.0.1", + "bytes 1.1.0", "libc", "memchr", "mio 0.7.13", "num_cpus", + "once_cell", "pin-project-lite 0.2.7", + "signal-hook-registry", + "tokio-macros", "winapi 0.3.9", ] +[[package]] +name = "tokio-io-timeout" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90c49f106be240de154571dd31fbe48acb10ba6c6dd6f6517ad603abffa42de9" +dependencies = [ + "pin-project-lite 0.2.7", + "tokio 1.10.1", +] + [[package]] name = "tokio-macros" -version = "0.2.6" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e44da00bfc73a25f814cd8d7e57a68a5c31b74b3152a0a1d1f590c97ed06265a" +checksum = "54473be61f4ebe4efd09cec9bd5d16fa51d70ea0192213d754d2d500457db110" dependencies = [ "proc-macro2 1.0.28", "quote 1.0.9", @@ -5355,6 +5295,17 @@ dependencies = [ "tokio 1.10.1", ] +[[package]] +name = "tokio-rustls" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc6844de72e57df1980054b38be3a9f4702aba4858be64dd700181a8a6d0e1b6" +dependencies = [ + "rustls", + "tokio 1.10.1", + "webpki", +] + [[package]] name = "tokio-stream" version = "0.1.7" @@ -5364,6 +5315,7 @@ dependencies = [ "futures-core", "pin-project-lite 0.2.7", "tokio 1.10.1", + "tokio-util 0.6.7", ] [[package]] @@ -5377,6 +5329,19 @@ dependencies = [ "tokio 0.2.25", ] +[[package]] +name = "tokio-test" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53474327ae5e166530d17f2d956afcb4f8a004de581b3cae10f12006bc8163e3" +dependencies = [ + "async-stream", + "bytes 1.1.0", + "futures-core", + "tokio 1.10.1", + "tokio-stream", +] + [[package]] name = "tokio-tls" version = "0.3.1" @@ -5407,8 +5372,9 @@ version = "0.6.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1caa0b0c8d94a049db56b5acf8cba99dc0623aab1b26d5b5f5e2d945846b3592" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "futures-core", + "futures-io", "futures-sink", "log 0.4.14", "pin-project-lite 0.2.7", @@ -5421,7 +5387,7 @@ version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "758664fc71a3a69038656bee8b6be6477d2a6c315a6b81f7081f591bffa4111f" dependencies = [ - "serde 1.0.129", + "serde 1.0.130", ] [[package]] @@ -5430,34 +5396,35 @@ version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a31142970826733df8241ef35dc040ef98c679ab14d7c3e54d827099b3acecaa" dependencies = [ - "serde 1.0.129", + "serde 1.0.130", ] [[package]] name = "tonic" -version = "0.2.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4afef9ce97ea39593992cf3fa00ff33b1ad5eb07665b31355df63a690e38c736" +checksum = "796c5e1cd49905e65dd8e700d4cb1dffcbfdb4fc9d017de08c1a537afd83627c" dependencies = [ "async-stream", "async-trait", - "base64 0.11.0", - "bytes 0.5.6", + "base64 0.13.0", + "bytes 1.1.0", "futures-core", "futures-util", + "h2 0.3.4", "http", - "http-body 0.3.1", - "hyper 0.13.10", + "http-body 0.4.3", + "hyper 0.14.12", + "hyper-timeout", "percent-encoding 2.1.0", - "pin-project 0.4.28", + "pin-project 1.0.8", "prost", "prost-derive", - "tokio 0.2.25", - "tokio-util 0.3.1", - "tower", - "tower-balance", - "tower-load", - "tower-make", + "tokio 1.10.1", + "tokio-stream", + "tokio-util 0.6.7", + "tower 0.4.8", + "tower-layer", "tower-service", "tracing", "tracing-futures", @@ -5465,9 +5432,9 @@ dependencies = [ [[package]] name = "tonic-build" -version = "0.2.0" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71d8d21cb568e802d77055ab7fcd43f0992206de5028de95c8d3a41118d32e8e" +checksum = "12b52d07035516c2b74337d2ac7746075e7dcae7643816c1b12c5ff8a7484c08" dependencies = [ "proc-macro2 1.0.28", "prost-build", @@ -5494,23 +5461,21 @@ dependencies = [ ] [[package]] -name = "tower-balance" -version = "0.3.0" +name = "tower" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a792277613b7052448851efcf98a2c433e6f1d01460832dc60bef676bc275d4c" +checksum = "f60422bc7fefa2f3ec70359b8ff1caff59d785877eb70595904605bcc412470f" dependencies = [ "futures-core", "futures-util", "indexmap", - "pin-project 0.4.28", - "rand 0.7.3", + "pin-project 1.0.8", + "rand 0.8.4", "slab", - "tokio 0.2.25", - "tower-discover", + "tokio 1.10.1", + "tokio-stream", + "tokio-util 0.6.7", "tower-layer", - "tower-load", - "tower-make", - "tower-ready-cache", "tower-service", "tracing", ] @@ -5592,21 +5557,6 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ce50370d644a0364bf4877ffd4f76404156a248d104e2cc234cd391ea5cdc965" dependencies = [ - "tokio 0.2.25", - "tower-service", -] - -[[package]] -name = "tower-ready-cache" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4eabb6620e5481267e2ec832c780b31cad0c15dcb14ed825df5076b26b591e1f" -dependencies = [ - "futures-core", - "futures-util", - "indexmap", - "log 0.4.14", - "tokio 0.2.25", "tower-service", ] @@ -5638,7 +5588,7 @@ dependencies = [ "futures-util", "pin-project 0.4.28", "tokio 0.2.25", - "tokio-test", + "tokio-test 0.2.1", "tower-layer", "tower-service", ] @@ -5740,7 +5690,7 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fb65ea441fbb84f9f6748fd496cf7f63ec9af5bca94dd86456978d055e8eb28b" dependencies = [ - "serde 1.0.129", + "serde 1.0.130", "tracing-core", ] @@ -5755,7 +5705,7 @@ dependencies = [ "lazy_static 1.4.0", "matchers", "regex", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "sharded-slab", "smallvec", @@ -5774,47 +5724,54 @@ checksum = "efd1f82c56340fdf16f2a953d7bda4f8fdffba13d93b00844c25572110b26079" [[package]] name = "trust-dns-client" -version = "0.19.7" +version = "0.21.0-alpha.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e935ae5a26a2745fb5a6b95f0e206e1cfb7f00066892d2cf78a8fee87bc2e0c6" +checksum = "37532ce92c75c6174b1d51ed612e26c5fde66ef3f29aa10dbd84e7c5d9a0c27b" dependencies = [ "cfg-if 1.0.0", "chrono", "data-encoding", - "futures 0.3.16", + "futures-channel", + "futures-util", "lazy_static 1.4.0", "log 0.4.14", "radix_trie", - "rand 0.7.3", + "rand 0.8.4", "ring", "rustls", "thiserror", - "tokio 0.2.25", + "tokio 1.10.1", "trust-dns-proto", "webpki", ] [[package]] name = "trust-dns-proto" -version = "0.19.7" +version = "0.21.0-alpha.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cad71a0c0d68ab9941d2fb6e82f8fb2e86d9945b94e1661dd0aaea2b88215a9" +checksum = "4cd23117e93ea0e776abfd8a07c9e389d7ecd3377827858f21bd795ebdfefa36" dependencies = [ "async-trait", - "backtrace", "cfg-if 1.0.0", "data-encoding", "enum-as-inner", - "futures 0.3.16", + "futures-channel", + "futures-io", + "futures-util", "idna 0.2.3", + "ipnet", "lazy_static 1.4.0", "log 0.4.14", - "rand 0.7.3", + "rand 0.8.4", "ring", + "rustls", "smallvec", "thiserror", - "tokio 0.2.25", + "tinyvec", + "tokio 1.10.1", + "tokio-rustls", "url 2.2.2", + "webpki", ] [[package]] @@ -5856,12 +5813,12 @@ dependencies = [ [[package]] name = "twofish" -version = "0.4.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7a30db256d7388f6e08efa0a8e9e62ee34dd1af59706c76c9e8c97c2a500f12" +checksum = "0028f5982f23ecc9a1bc3008ead4c664f843ed5d78acd3d213b99ff50c441bc2" dependencies = [ - "block-cipher", "byteorder", + "cipher 0.2.5", "opaque-debug", ] @@ -5988,15 +5945,9 @@ dependencies = [ [[package]] name = "unsigned-varint" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7fdeedbf205afadfe39ae559b75c3240f24e257d0ca27e85f85cb82aa19ac35" - -[[package]] -name = "unsigned-varint" -version = "0.6.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35581ff83d4101e58b582e607120c7f5ffb17e632a980b1f38334d76b36908b2" +checksum = "5f8d425fafb8cd76bc3f22aace4af471d3156301d7508f2107e98fbeae10bc7f" [[package]] name = "untrusted" @@ -6097,7 +6048,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ce9b1b516211d33767048e5d47fa2a381ed8b76fc48d2ce4aa39877f9f183e0" dependencies = [ "cfg-if 1.0.0", - "serde 1.0.129", + "serde 1.0.130", "serde_json", "wasm-bindgen-macro", ] @@ -6187,6 +6138,17 @@ dependencies = [ "libc", ] +[[package]] +name = "which" +version = "4.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea187a8ef279bc014ec368c27a920da2024d2a711109bfbe3440585d5cf27ad9" +dependencies = [ + "either", + "lazy_static 1.4.0", + "libc", +] + [[package]] name = "winapi" version = "0.2.8" @@ -6278,7 +6240,7 @@ dependencies = [ "futures 0.3.16", "log 0.4.14", "nohash-hasher", - "parking_lot 0.11.1", + "parking_lot 0.11.2", "rand 0.8.4", "static_assertions", ] diff --git a/Cargo.toml b/Cargo.toml index 2394f2eb9c..ac268c2585 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,3 +25,6 @@ members = [ "applications/tari_stratum_transcoder", "applications/tari_mining_node", ] +# +#[profile.release] +#debug = true diff --git a/applications/ffi_client/index.js b/applications/ffi_client/index.js index bf1a47d3c2..aa9473740b 100644 --- a/applications/ffi_client/index.js +++ b/applications/ffi_client/index.js @@ -1,6 +1,8 @@ // this is nasty // ¯\_(ツ)_/¯ +// TODO: Use implementation in cucumber tests instead (see helpers/ffi). + const lib = require("./lib"); const ref = require("ref-napi"); const ffi = require("ffi-napi"); diff --git a/applications/tari_app_grpc/Cargo.toml b/applications/tari_app_grpc/Cargo.toml index bec7e05720..8a40dba4ee 100644 --- a/applications/tari_app_grpc/Cargo.toml +++ b/applications/tari_app_grpc/Cargo.toml @@ -10,14 +10,17 @@ edition = "2018" [dependencies] tari_common_types = { version = "^0.9", path = "../../base_layer/common_types"} tari_core = { path = "../../base_layer/core"} -tari_wallet = { path = "../../base_layer/wallet"} +tari_wallet = { path = "../../base_layer/wallet", optional = true} tari_crypto = "0.11.1" tari_comms = { path = "../../comms"} chrono = "0.4.6" -prost = "0.6" -prost-types = "0.6.1" -tonic = "0.2" +prost = "0.8" +prost-types = "0.8" +tonic = "0.5.2" [build-dependencies] -tonic-build = "0.2" +tonic-build = "0.5.2" + +[features] +wallet = ["tari_wallet"] \ No newline at end of file diff --git a/applications/tari_app_grpc/src/conversions/block_header.rs b/applications/tari_app_grpc/src/conversions/block_header.rs index 3b660f21e1..5bb37e87e3 100644 --- a/applications/tari_app_grpc/src/conversions/block_header.rs +++ b/applications/tari_app_grpc/src/conversions/block_header.rs @@ -25,7 +25,8 @@ use crate::{ tari_rpc as grpc, }; use std::convert::TryFrom; -use tari_core::{blocks::BlockHeader, proof_of_work::ProofOfWork, transactions::types::BlindingFactor}; +use tari_common_types::types::BlindingFactor; +use tari_core::{blocks::BlockHeader, proof_of_work::ProofOfWork}; use tari_crypto::tari_utilities::{ByteArray, Hashable}; impl From for grpc::BlockHeader { diff --git a/applications/tari_app_grpc/src/conversions/com_signature.rs b/applications/tari_app_grpc/src/conversions/com_signature.rs index e10e48ffe8..1924e1c054 100644 --- a/applications/tari_app_grpc/src/conversions/com_signature.rs +++ b/applications/tari_app_grpc/src/conversions/com_signature.rs @@ -21,10 +21,10 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use std::convert::TryFrom; -use tari_crypto::tari_utilities::ByteArray; +use tari_core::crypto::tari_utilities::ByteArray; use crate::tari_rpc as grpc; -use tari_core::transactions::types::{ComSignature, Commitment, PrivateKey}; +use tari_common_types::types::{ComSignature, Commitment, PrivateKey}; impl TryFrom for ComSignature { type Error = String; diff --git a/applications/tari_app_grpc/src/conversions/mod.rs b/applications/tari_app_grpc/src/conversions/mod.rs index e404d52d1a..f48a1b876d 100644 --- a/applications/tari_app_grpc/src/conversions/mod.rs +++ b/applications/tari_app_grpc/src/conversions/mod.rs @@ -59,7 +59,7 @@ pub use self::{ use crate::{tari_rpc as grpc, tari_rpc::BlockGroupRequest}; use prost_types::Timestamp; -use tari_crypto::tari_utilities::epoch_time::EpochTime; +use tari_core::crypto::tari_utilities::epoch_time::EpochTime; /// Utility function that converts a `EpochTime` to a `prost::Timestamp` pub fn datetime_to_timestamp(datetime: EpochTime) -> Timestamp { diff --git a/applications/tari_app_grpc/src/conversions/new_block_template.rs b/applications/tari_app_grpc/src/conversions/new_block_template.rs index 7d87900cd1..15c41c499b 100644 --- a/applications/tari_app_grpc/src/conversions/new_block_template.rs +++ b/applications/tari_app_grpc/src/conversions/new_block_template.rs @@ -22,12 +22,12 @@ use crate::tari_rpc as grpc; use std::convert::{TryFrom, TryInto}; +use tari_common_types::types::BlindingFactor; use tari_core::{ blocks::{NewBlockHeaderTemplate, NewBlockTemplate}, + crypto::tari_utilities::ByteArray, proof_of_work::ProofOfWork, - transactions::types::BlindingFactor, }; -use tari_crypto::tari_utilities::ByteArray; impl From for grpc::NewBlockTemplate { fn from(block: NewBlockTemplate) -> Self { let header = grpc::NewBlockHeaderTemplate { diff --git a/applications/tari_app_grpc/src/conversions/peer.rs b/applications/tari_app_grpc/src/conversions/peer.rs index f04d3fd8dd..f3bd151d0d 100644 --- a/applications/tari_app_grpc/src/conversions/peer.rs +++ b/applications/tari_app_grpc/src/conversions/peer.rs @@ -22,7 +22,7 @@ use crate::{conversions::datetime_to_timestamp, tari_rpc as grpc}; use tari_comms::{connectivity::ConnectivityStatus, net_address::MutliaddrWithStats, peer_manager::Peer}; -use tari_crypto::tari_utilities::ByteArray; +use tari_core::crypto::tari_utilities::ByteArray; impl From for grpc::Peer { fn from(peer: Peer) -> Self { diff --git a/applications/tari_app_grpc/src/conversions/signature.rs b/applications/tari_app_grpc/src/conversions/signature.rs index d9883a338e..2f0fe10cd3 100644 --- a/applications/tari_app_grpc/src/conversions/signature.rs +++ b/applications/tari_app_grpc/src/conversions/signature.rs @@ -21,10 +21,10 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use std::convert::TryFrom; -use tari_crypto::tari_utilities::ByteArray; +use tari_core::crypto::tari_utilities::ByteArray; use crate::tari_rpc as grpc; -use tari_core::transactions::types::{PrivateKey, PublicKey, Signature}; +use tari_common_types::types::{PrivateKey, PublicKey, Signature}; impl TryFrom for Signature { type Error = String; diff --git a/applications/tari_app_grpc/src/conversions/transaction.rs b/applications/tari_app_grpc/src/conversions/transaction.rs index 97dcaf8795..cb9e91f63a 100644 --- a/applications/tari_app_grpc/src/conversions/transaction.rs +++ b/applications/tari_app_grpc/src/conversions/transaction.rs @@ -22,9 +22,10 @@ use crate::tari_rpc as grpc; use std::convert::{TryFrom, TryInto}; -use tari_core::transactions::transaction::Transaction; -use tari_crypto::{ristretto::RistrettoSecretKey, tari_utilities::ByteArray}; -use tari_wallet::{output_manager_service::TxId, transaction_service::storage::models}; +use tari_core::{ + crypto::{ristretto::RistrettoSecretKey, tari_utilities::ByteArray}, + transactions::transaction::Transaction, +}; impl From for grpc::Transaction { fn from(source: Transaction) -> Self { @@ -53,38 +54,44 @@ impl TryFrom for Transaction { } } -impl From for grpc::TransactionStatus { - fn from(status: models::TransactionStatus) -> Self { - use models::TransactionStatus::*; - match status { - Completed => grpc::TransactionStatus::Completed, - Broadcast => grpc::TransactionStatus::Broadcast, - MinedUnconfirmed => grpc::TransactionStatus::MinedUnconfirmed, - MinedConfirmed => grpc::TransactionStatus::MinedConfirmed, - Imported => grpc::TransactionStatus::Imported, - Pending => grpc::TransactionStatus::Pending, - Coinbase => grpc::TransactionStatus::Coinbase, +#[cfg(feature = "wallet")] +mod wallet { + use super::*; + use tari_wallet::{output_manager_service::TxId, transaction_service::storage::models}; + + impl From for grpc::TransactionStatus { + fn from(status: models::TransactionStatus) -> Self { + use models::TransactionStatus::*; + match status { + Completed => grpc::TransactionStatus::Completed, + Broadcast => grpc::TransactionStatus::Broadcast, + MinedUnconfirmed => grpc::TransactionStatus::MinedUnconfirmed, + MinedConfirmed => grpc::TransactionStatus::MinedConfirmed, + Imported => grpc::TransactionStatus::Imported, + Pending => grpc::TransactionStatus::Pending, + Coinbase => grpc::TransactionStatus::Coinbase, + } } } -} -impl From for grpc::TransactionDirection { - fn from(status: models::TransactionDirection) -> Self { - use models::TransactionDirection::*; - match status { - Unknown => grpc::TransactionDirection::Unknown, - Inbound => grpc::TransactionDirection::Inbound, - Outbound => grpc::TransactionDirection::Outbound, + impl From for grpc::TransactionDirection { + fn from(status: models::TransactionDirection) -> Self { + use models::TransactionDirection::*; + match status { + Unknown => grpc::TransactionDirection::Unknown, + Inbound => grpc::TransactionDirection::Inbound, + Outbound => grpc::TransactionDirection::Outbound, + } } } -} -impl grpc::TransactionInfo { - pub fn not_found(tx_id: TxId) -> Self { - Self { - tx_id, - status: grpc::TransactionStatus::NotFound as i32, - ..Default::default() + impl grpc::TransactionInfo { + pub fn not_found(tx_id: TxId) -> Self { + Self { + tx_id, + status: grpc::TransactionStatus::NotFound as i32, + ..Default::default() + } } } } diff --git a/applications/tari_app_grpc/src/conversions/transaction_input.rs b/applications/tari_app_grpc/src/conversions/transaction_input.rs index ed5793a2f2..48eebe04ad 100644 --- a/applications/tari_app_grpc/src/conversions/transaction_input.rs +++ b/applications/tari_app_grpc/src/conversions/transaction_input.rs @@ -22,10 +22,8 @@ use crate::tari_rpc as grpc; use std::convert::{TryFrom, TryInto}; -use tari_core::transactions::{ - transaction::TransactionInput, - types::{Commitment, PublicKey}, -}; +use tari_common_types::types::{Commitment, PublicKey}; +use tari_core::transactions::transaction::TransactionInput; use tari_crypto::{ script::{ExecutionStack, TariScript}, tari_utilities::{ByteArray, Hashable}, diff --git a/applications/tari_app_grpc/src/conversions/transaction_kernel.rs b/applications/tari_app_grpc/src/conversions/transaction_kernel.rs index e394a6bce5..7bf8664487 100644 --- a/applications/tari_app_grpc/src/conversions/transaction_kernel.rs +++ b/applications/tari_app_grpc/src/conversions/transaction_kernel.rs @@ -22,10 +22,10 @@ use crate::tari_rpc as grpc; use std::convert::{TryFrom, TryInto}; +use tari_common_types::types::Commitment; use tari_core::transactions::{ tari_amount::MicroTari, transaction::{KernelFeatures, TransactionKernel}, - types::Commitment, }; use tari_crypto::tari_utilities::{ByteArray, Hashable}; diff --git a/applications/tari_app_grpc/src/conversions/transaction_output.rs b/applications/tari_app_grpc/src/conversions/transaction_output.rs index b9556b2940..7b783e3498 100644 --- a/applications/tari_app_grpc/src/conversions/transaction_output.rs +++ b/applications/tari_app_grpc/src/conversions/transaction_output.rs @@ -22,14 +22,13 @@ use crate::tari_rpc as grpc; use std::convert::{TryFrom, TryInto}; -use tari_core::transactions::{ - bullet_rangeproofs::BulletRangeProof, - transaction::TransactionOutput, - types::{Commitment, PublicKey}, -}; -use tari_crypto::{ - script::TariScript, - tari_utilities::{ByteArray, Hashable}, +use tari_common_types::types::{BulletRangeProof, Commitment, PublicKey}; +use tari_core::{ + crypto::{ + script::TariScript, + tari_utilities::{ByteArray, Hashable}, + }, + transactions::transaction::TransactionOutput, }; impl TryFrom for TransactionOutput { diff --git a/applications/tari_app_grpc/src/conversions/unblinded_output.rs b/applications/tari_app_grpc/src/conversions/unblinded_output.rs index 94ac4c178d..bf9efa58bc 100644 --- a/applications/tari_app_grpc/src/conversions/unblinded_output.rs +++ b/applications/tari_app_grpc/src/conversions/unblinded_output.rs @@ -22,14 +22,13 @@ use crate::tari_rpc as grpc; use std::convert::{TryFrom, TryInto}; -use tari_core::transactions::{ - tari_amount::MicroTari, - transaction::UnblindedOutput, - types::{PrivateKey, PublicKey}, -}; -use tari_crypto::{ - script::{ExecutionStack, TariScript}, - tari_utilities::ByteArray, +use tari_common_types::types::{PrivateKey, PublicKey}; +use tari_core::{ + crypto::{ + script::{ExecutionStack, TariScript}, + tari_utilities::ByteArray, + }, + transactions::{tari_amount::MicroTari, transaction::UnblindedOutput}, }; impl From for grpc::UnblindedOutput { diff --git a/applications/tari_app_utilities/Cargo.toml b/applications/tari_app_utilities/Cargo.toml index 333af5f959..5fdae3d097 100644 --- a/applications/tari_app_utilities/Cargo.toml +++ b/applications/tari_app_utilities/Cargo.toml @@ -8,22 +8,23 @@ edition = "2018" tari_comms = { path = "../../comms"} tari_crypto = "0.11.1" tari_common = { path = "../../common" } +tari_common_types ={ path ="../../base_layer/common_types"} tari_p2p = { path = "../../base_layer/p2p", features = ["auto-update"] } -tari_wallet = { path = "../../base_layer/wallet" } +tari_wallet = { path = "../../base_layer/wallet", optional = true } config = { version = "0.9.3" } -futures = { version = "^0.3.1", default-features = false, features = ["alloc"]} +futures = { version = "^0.3.16", default-features = false, features = ["alloc"]} qrcode = { version = "0.12" } dirs-next = "1.0.2" serde_json = "1.0" log = { version = "0.4.8", features = ["std"] } rand = "0.8" -tokio = { version="0.2.10", features = ["signal"] } +tokio = { version="^1.10", features = ["signal"] } structopt = { version = "0.3.13", default_features = false } strum = "^0.19" strum_macros = "^0.19" -thiserror = "^1.0.20" -tonic = "0.2" +thiserror = "^1.0.26" +tonic = "0.5.2" [dependencies.tari_core] path = "../../base_layer/core" @@ -33,3 +34,7 @@ features = ["transactions"] [build-dependencies] tari_common = { path = "../../common", features = ["build", "static-application-info"] } + +[features] +# TODO: This crate is supposed to hold common logic. Move code from this feature into the crate that is more specific to the wallet +wallet = ["tari_wallet"] diff --git a/applications/tari_app_utilities/src/identity_management.rs b/applications/tari_app_utilities/src/identity_management.rs index 40bfcf5cee..013a0e8fc8 100644 --- a/applications/tari_app_utilities/src/identity_management.rs +++ b/applications/tari_app_utilities/src/identity_management.rs @@ -25,8 +25,8 @@ use log::*; use rand::rngs::OsRng; use std::{clone::Clone, fs, path::Path, string::ToString, sync::Arc}; use tari_common::configuration::bootstrap::prompt; +use tari_common_types::types::PrivateKey; use tari_comms::{multiaddr::Multiaddr, peer_manager::PeerFeatures, NodeIdentity}; -use tari_core::transactions::types::PrivateKey; use tari_crypto::{ keys::SecretKey, tari_utilities::{hex::Hex, message_format::MessageFormat}, @@ -55,14 +55,19 @@ pub fn setup_node_identity>( if !create_id { let prompt = prompt("Node identity does not exist.\nWould you like to to create one (Y/n)?"); if !prompt { - let msg = format!( + error!( + target: LOG_TARGET, "Node identity information not found. {}. You can update the configuration file to point to a \ valid node identity file, or re-run the node with the --create-id flag to create a new \ identity.", e ); - error!(target: LOG_TARGET, "{}", msg); - return Err(ExitCodes::ConfigError(msg)); + return Err(ExitCodes::ConfigError(format!( + "Node identity information not found. {}. You can update the configuration file to point to a \ + valid node identity file, or re-run the node with the --create-id flag to create a new \ + identity.", + e + ))); }; } diff --git a/applications/tari_app_utilities/src/initialization.rs b/applications/tari_app_utilities/src/initialization.rs index ad66210437..2497788307 100644 --- a/applications/tari_app_utilities/src/initialization.rs +++ b/applications/tari_app_utilities/src/initialization.rs @@ -1,8 +1,13 @@ use crate::{consts, utilities::ExitCodes}; use config::Config; -use std::path::PathBuf; +use std::{path::PathBuf, str::FromStr}; use structopt::StructOpt; -use tari_common::{configuration::bootstrap::ApplicationType, ConfigBootstrap, DatabaseType, GlobalConfig}; +use tari_common::{ + configuration::{bootstrap::ApplicationType, Network}, + ConfigBootstrap, + DatabaseType, + GlobalConfig, +}; pub const LOG_TARGET: &str = "tari::application"; @@ -27,6 +32,28 @@ pub fn init_configuration( let mut global_config = GlobalConfig::convert_from(application_type, cfg.clone()) .map_err(|err| ExitCodes::ConfigError(err.to_string()))?; check_file_paths(&mut global_config, &bootstrap); + + if let Some(str) = bootstrap.network.clone() { + log::info!(target: LOG_TARGET, "Network selection requested"); + let network = Network::from_str(&str); + match network { + Ok(network) => { + log::info!( + target: LOG_TARGET, + "Network selection successful, current network is: {}", + network + ); + global_config.network = network; + }, + Err(_) => { + log::warn!( + target: LOG_TARGET, + "Network selection was invalid, continuing with default network." + ); + }, + } + } + Ok((bootstrap, global_config, cfg)) } diff --git a/applications/tari_app_utilities/src/utilities.rs b/applications/tari_app_utilities/src/utilities.rs index d963f212a7..23a1ebf9ab 100644 --- a/applications/tari_app_utilities/src/utilities.rs +++ b/applications/tari_app_utilities/src/utilities.rs @@ -20,9 +20,11 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::identity_management::load_from_json; use futures::future::Either; use log::*; +use thiserror::Error; +use tokio::{runtime, runtime::Runtime}; + use tari_common::{CommsTransport, GlobalConfig, SocksAuthentication, TorControlAuthentication}; use tari_comms::{ connectivity::ConnectivityError, @@ -37,13 +39,9 @@ use tari_comms::{ }; use tari_core::tari_utilities::hex::Hex; use tari_p2p::transport::{TorConfig, TransportType}; -use tari_wallet::{ - error::{WalletError, WalletStorageError}, - output_manager_service::error::OutputManagerError, - util::emoji::EmojiId, -}; -use thiserror::Error; -use tokio::{runtime, runtime::Runtime}; + +use crate::identity_management::load_from_json; +use tari_common_types::emoji::EmojiId; pub const LOG_TARGET: &str = "tari::application"; @@ -107,20 +105,6 @@ impl From for ExitCodes { } } -impl From for ExitCodes { - fn from(err: WalletError) -> Self { - error!(target: LOG_TARGET, "{}", err); - Self::WalletError(err.to_string()) - } -} - -impl From for ExitCodes { - fn from(err: OutputManagerError) -> Self { - error!(target: LOG_TARGET, "{}", err); - Self::WalletError(err.to_string()) - } -} - impl From for ExitCodes { fn from(err: ConnectivityError) -> Self { error!(target: LOG_TARGET, "{}", err); @@ -135,13 +119,36 @@ impl From for ExitCodes { } } -impl From for ExitCodes { - fn from(err: WalletStorageError) -> Self { - use WalletStorageError::*; - match err { - NoPasswordError => ExitCodes::NoPassword, - IncorrectPassword => ExitCodes::IncorrectPassword, - e => ExitCodes::WalletError(e.to_string()), +#[cfg(feature = "wallet")] +mod wallet { + use super::*; + use tari_wallet::{ + error::{WalletError, WalletStorageError}, + output_manager_service::error::OutputManagerError, + }; + + impl From for ExitCodes { + fn from(err: WalletError) -> Self { + error!(target: LOG_TARGET, "{}", err); + Self::WalletError(err.to_string()) + } + } + + impl From for ExitCodes { + fn from(err: OutputManagerError) -> Self { + error!(target: LOG_TARGET, "{}", err); + Self::WalletError(err.to_string()) + } + } + + impl From for ExitCodes { + fn from(err: WalletStorageError) -> Self { + use WalletStorageError::*; + match err { + NoPasswordError => ExitCodes::NoPassword, + IncorrectPassword => ExitCodes::IncorrectPassword, + e => ExitCodes::WalletError(e.to_string()), + } } } } @@ -259,26 +266,22 @@ pub fn convert_socks_authentication(auth: SocksAuthentication) -> socks::Authent /// ## Returns /// A result containing the runtime on success, string indicating the error on failure pub fn setup_runtime(config: &GlobalConfig) -> Result { - info!( - target: LOG_TARGET, - "Configuring the node to run on up to {} core threads and {} mining threads.", - config.max_threads.unwrap_or(512), - config.num_mining_threads - ); - - let mut builder = runtime::Builder::new(); + let mut builder = runtime::Builder::new_multi_thread(); - if let Some(max_threads) = config.max_threads { - // Ensure that there are always enough threads for mining. - // e.g if the user sets max_threads = 2, mining_threads = 5 then 7 threads are available in total - builder.max_threads(max_threads + config.num_mining_threads); - } if let Some(core_threads) = config.core_threads { - builder.core_threads(core_threads); + info!( + target: LOG_TARGET, + "Configuring the node to run on up to {} core threads.", + config + .core_threads + .as_ref() + .map(ToString::to_string) + .unwrap_or_else(|| "".to_string()), + ); + builder.worker_threads(core_threads); } builder - .threaded_scheduler() .enable_all() .build() .map_err(|e| format!("There was an error while building the node runtime. {}", e.to_string())) diff --git a/applications/tari_base_node/Cargo.toml b/applications/tari_base_node/Cargo.toml index db832f8a9a..b41db079ef 100644 --- a/applications/tari_base_node/Cargo.toml +++ b/applications/tari_base_node/Cargo.toml @@ -11,30 +11,30 @@ edition = "2018" tari_app_grpc = { path = "../tari_app_grpc" } tari_app_utilities = { path = "../tari_app_utilities" } tari_common = { path = "../../common" } -tari_comms = { path = "../../comms", features = ["rpc"]} -tari_comms_dht = { path = "../../comms/dht"} -tari_core = { path = "../../base_layer/core", default-features = false, features = ["transactions"]} +tari_comms = { path = "../../comms", features = ["rpc"] } +tari_common_types = {path = "../../base_layer/common_types"} +tari_comms_dht = { path = "../../comms/dht" } +tari_core = { path = "../../base_layer/core", default-features = false, features = ["transactions"] } tari_crypto = "0.11.1" tari_mmr = { path = "../../base_layer/mmr" } tari_p2p = { path = "../../base_layer/p2p", features = ["auto-update"] } -tari_service_framework = { path = "../../base_layer/service_framework"} -tari_shutdown = { path = "../../infrastructure/shutdown"} -tari_wallet = { path = "../../base_layer/wallet" } +tari_service_framework = { path = "../../base_layer/service_framework" } +tari_shutdown = { path = "../../infrastructure/shutdown" } anyhow = "1.0.32" bincode = "1.3.1" chrono = "0.4" config = { version = "0.9.3" } -futures = { version = "^0.3.1", default-features = false, features = ["alloc"]} +futures = { version = "^0.3.16", default-features = false, features = ["alloc"] } log = { version = "0.4.8", features = ["std"] } regex = "1" rustyline = "6.0" rustyline-derive = "0.3" -tokio = { version="0.2.10", features = ["signal"] } +tokio = { version = "^1.10", features = ["signal"] } strum = "^0.19" strum_macros = "0.18.0" -thiserror = "^1.0.20" -tonic = "0.2" +thiserror = "^1.0.26" +tonic = "0.5.2" tracing = "0.1.26" tracing-opentelemetry = "0.15.0" tracing-subscriber = "0.2.20" @@ -44,7 +44,7 @@ opentelemetry = { version = "0.16", default-features = false, features = ["trace opentelemetry-jaeger = { version="0.15", features=["rt-tokio"]} [features] -avx2 = ["tari_core/avx2", "tari_crypto/avx2", "tari_p2p/avx2", "tari_wallet/avx2", "tari_comms/avx2", "tari_comms_dht/avx2"] +avx2 = ["tari_core/avx2", "tari_crypto/avx2", "tari_p2p/avx2", "tari_comms/avx2", "tari_comms_dht/avx2"] safe = [] diff --git a/applications/tari_base_node/src/bootstrap.rs b/applications/tari_base_node/src/bootstrap.rs index cd10870e6d..7cb00ce527 100644 --- a/applications/tari_base_node/src/bootstrap.rs +++ b/applications/tari_base_node/src/bootstrap.rs @@ -20,9 +20,11 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use std::{cmp, fs, str::FromStr, sync::Arc, time::Duration}; + use anyhow::anyhow; use log::*; -use std::{cmp, fs, str::FromStr, sync::Arc, time::Duration}; + use tari_app_utilities::{consts, identity_management, utilities::create_transport_type}; use tari_common::{configuration::bootstrap::ApplicationType, GlobalConfig}; use tari_comms::{peer_manager::Peer, protocol::rpc::RpcServer, NodeIdentity, UnspawnedCommsNode}; @@ -47,7 +49,7 @@ use tari_core::{ MempoolServiceInitializer, MempoolSyncInitializer, }, - transactions::types::CryptoFactories, + transactions::CryptoFactories, }; use tari_p2p::{ auto_update::{AutoUpdateConfig, SoftwareUpdaterService}, @@ -59,7 +61,6 @@ use tari_p2p::{ }; use tari_service_framework::{ServiceHandles, StackBuilder}; use tari_shutdown::ShutdownSignal; -use tokio::runtime; const LOG_TARGET: &str = "c::bn::initialization"; /// The minimum buffer size for the base node pubsub_connector channel @@ -84,8 +85,7 @@ where B: BlockchainBackend + 'static fs::create_dir_all(&config.peer_db_path)?; let buf_size = cmp::max(BASE_NODE_BUFFER_MIN_SIZE, config.buffer_size_base_node); - let (publisher, peer_message_subscriptions) = - pubsub_connector(runtime::Handle::current(), buf_size, config.buffer_rate_limit_base_node); + let (publisher, peer_message_subscriptions) = pubsub_connector(buf_size, config.buffer_rate_limit_base_node); let peer_message_subscriptions = Arc::new(peer_message_subscriptions); let node_config = BaseNodeServiceConfig { diff --git a/applications/tari_base_node/src/builder.rs b/applications/tari_base_node/src/builder.rs index dc36f64f59..ee374b339e 100644 --- a/applications/tari_base_node/src/builder.rs +++ b/applications/tari_base_node/src/builder.rs @@ -20,9 +20,11 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::bootstrap::BaseNodeBootstrapper; -use log::*; use std::sync::Arc; + +use log::*; +use tokio::sync::watch; + use tari_common::{configuration::Network, DatabaseType, GlobalConfig}; use tari_comms::{peer_manager::NodeIdentity, protocol::rpc::RpcServerHandle, CommsNode}; use tari_comms_dht::Dht; @@ -32,7 +34,7 @@ use tari_core::{ consensus::ConsensusManager, mempool::{service::LocalMempoolService, Mempool, MempoolConfig}, proof_of_work::randomx_factory::RandomXFactory, - transactions::types::CryptoFactories, + transactions::CryptoFactories, validation::{ block_validators::{BodyOnlyValidator, OrphanBlockValidator}, header_validator::HeaderValidator, @@ -48,7 +50,8 @@ use tari_core::{ use tari_p2p::{auto_update::SoftwareUpdaterHandle, services::liveness::LivenessHandle}; use tari_service_framework::ServiceHandles; use tari_shutdown::ShutdownSignal; -use tokio::sync::watch; + +use crate::bootstrap::BaseNodeBootstrapper; const LOG_TARGET: &str = "c::bn::initialization"; @@ -71,10 +74,8 @@ impl BaseNodeContext { pub async fn run(self) { info!(target: LOG_TARGET, "Tari base node has STARTED"); - if let Err(e) = self.state_machine().shutdown_signal().await { - warn!(target: LOG_TARGET, "Error shutting down Base Node State Machine: {}", e); - } - info!(target: LOG_TARGET, "Initiating communications stack shutdown"); + self.state_machine().shutdown_signal().wait().await; + info!(target: LOG_TARGET, "Waiting for communications stack shutdown"); self.base_node_comms.wait_until_shutdown().await; info!(target: LOG_TARGET, "Communications stack has shutdown"); @@ -222,7 +223,11 @@ async fn build_node_context( let validators = Validators::new( BodyOnlyValidator::default(), HeaderValidator::new(rules.clone()), - OrphanBlockValidator::new(rules.clone(), factories.clone()), + OrphanBlockValidator::new( + rules.clone(), + config.base_node_bypass_range_proof_verification, + factories.clone(), + ), ); let db_config = BlockchainDatabaseConfig { orphan_storage_capacity: config.orphan_storage_capacity, @@ -238,7 +243,10 @@ async fn build_node_context( cleanup_orphans_at_startup, )?; let mempool_validator = MempoolValidator::new(vec![ - Box::new(TxInternalConsistencyValidator::new(factories.clone())), + Box::new(TxInternalConsistencyValidator::new( + factories.clone(), + config.base_node_bypass_range_proof_verification, + )), Box::new(TxInputAndMaturityValidator::new(blockchain_db.clone())), Box::new(TxConsensusValidator::new(blockchain_db.clone())), ]); diff --git a/applications/tari_base_node/src/command_handler.rs b/applications/tari_base_node/src/command_handler.rs index 114cee7ad2..d626c4fa0a 100644 --- a/applications/tari_base_node/src/command_handler.rs +++ b/applications/tari_base_node/src/command_handler.rs @@ -34,6 +34,10 @@ use std::{ }; use tari_app_utilities::consts; use tari_common::GlobalConfig; +use tari_common_types::{ + emoji::EmojiId, + types::{Commitment, HashOutput, Signature}, +}; use tari_comms::{ connectivity::ConnectivityRequester, peer_manager::{NodeId, Peer, PeerFeatures, PeerManager, PeerManagerError, PeerQuery}, @@ -53,11 +57,9 @@ use tari_core::{ mempool::service::LocalMempoolService, proof_of_work::PowAlgorithm, tari_utilities::{hex::Hex, message_format::MessageFormat}, - transactions::types::{Commitment, HashOutput, Signature}, }; use tari_crypto::{ristretto::RistrettoPublicKey, tari_utilities::Hashable}; use tari_p2p::auto_update::SoftwareUpdaterHandle; -use tari_wallet::util::emoji::EmojiId; use tokio::{runtime, sync::watch}; pub enum StatusOutput { @@ -101,7 +103,7 @@ impl CommandHandler { } pub fn status(&self, output: StatusOutput) { - let mut state_info = self.state_machine_info.clone(); + let state_info = self.state_machine_info.clone(); let mut node = self.node_service.clone(); let mut mempool = self.mempool_service.clone(); let peer_manager = self.peer_manager.clone(); @@ -114,9 +116,9 @@ impl CommandHandler { let mut status_line = StatusLine::new(); let version = format!("v{}", consts::APP_VERSION_NUMBER); status_line.add_field("", version); - - let state = state_info.recv().await.unwrap(); - status_line.add_field("State", state.state_info.short_desc()); + let network = format!("{}", config.network); + status_line.add_field("", network); + status_line.add_field("State", state_info.borrow().state_info.short_desc()); let metadata = node.get_metadata().await.unwrap(); @@ -189,18 +191,8 @@ impl CommandHandler { /// Function to process the get-state-info command pub fn state_info(&self) { - let mut channel = self.state_machine_info.clone(); - self.executor.spawn(async move { - match channel.recv().await { - None => { - info!( - target: LOG_TARGET, - "Error communicating with state machine, channel could have been closed" - ); - }, - Some(data) => println!("Current state machine state:\n{}", data), - }; - }); + let watch = self.state_machine_info.clone(); + println!("Current state machine state:\n{}", *watch.borrow()); } /// Check for updates diff --git a/applications/tari_base_node/src/grpc/base_node_grpc_server.rs b/applications/tari_base_node/src/grpc/base_node_grpc_server.rs index 00474a5cb4..e3a19ef185 100644 --- a/applications/tari_base_node/src/grpc/base_node_grpc_server.rs +++ b/applications/tari_base_node/src/grpc/base_node_grpc_server.rs @@ -26,6 +26,7 @@ use crate::{ helpers::{mean, median}, }, }; +use futures::{channel::mpsc, SinkExt}; use log::*; use std::{ cmp, @@ -36,11 +37,11 @@ use tari_app_grpc::{ tari_rpc::{CalcType, Sorting}, }; use tari_app_utilities::consts; +use tari_common_types::types::Signature; use tari_comms::{Bytes, CommsNode}; use tari_core::{ base_node::{ comms_interface::{Broadcast, CommsInterfaceError}, - state_machine_service::states::BlockSyncInfo, LocalNodeCommsInterface, StateMachineHandle, }, @@ -50,11 +51,11 @@ use tari_core::{ crypto::tari_utilities::{hex::Hex, ByteArray}, mempool::{service::LocalMempoolService, TxStorageResponse}, proof_of_work::PowAlgorithm, - transactions::{transaction::Transaction, types::Signature}, + transactions::transaction::Transaction, }; use tari_crypto::tari_utilities::{message_format::MessageFormat, Hashable}; use tari_p2p::{auto_update::SoftwareUpdaterHandle, services::liveness::LivenessHandle}; -use tokio::{sync::mpsc, task}; +use tokio::task; use tonic::{Request, Response, Status}; const LOG_TARGET: &str = "tari::base_node::grpc"; @@ -995,32 +996,25 @@ impl tari_rpc::base_node_server::BaseNode for BaseNodeGrpcServer { ) -> Result, Status> { debug!(target: LOG_TARGET, "Incoming GRPC request for BN sync data"); - let mut channel = self.state_machine_handle.get_status_info_watch(); - - let mut sync_info: Option = None; - - if let Some(info) = channel.recv().await { - sync_info = info.state_info.get_block_sync_info(); - } - - let mut response = tari_rpc::SyncInfoResponse { - tip_height: 0, - local_height: 0, - peer_node_id: vec![], - }; - - if let Some(info) = sync_info { - let node_ids = info - .sync_peers - .iter() - .map(|x| x.to_string().as_bytes().to_vec()) - .collect(); - response = tari_rpc::SyncInfoResponse { - tip_height: info.tip_height, - local_height: info.local_height, - peer_node_id: node_ids, - }; - } + let response = self + .state_machine_handle + .get_status_info_watch() + .borrow() + .state_info + .get_block_sync_info() + .map(|info| { + let node_ids = info + .sync_peers + .iter() + .map(|x| x.to_string().as_bytes().to_vec()) + .collect(); + tari_rpc::SyncInfoResponse { + tip_height: info.tip_height, + local_height: info.local_height, + peer_node_id: node_ids, + } + }) + .unwrap_or_default(); debug!(target: LOG_TARGET, "Sending SyncData response to client"); Ok(Response::new(response)) diff --git a/applications/tari_base_node/src/main.rs b/applications/tari_base_node/src/main.rs index 8ef2c9b68a..cb9daf1904 100644 --- a/applications/tari_base_node/src/main.rs +++ b/applications/tari_base_node/src/main.rs @@ -96,7 +96,7 @@ mod status_line; mod utils; use crate::command_handler::{CommandHandler, StatusOutput}; -use futures::{future::Fuse, pin_mut, FutureExt}; +use futures::{pin_mut, FutureExt}; use log::*; use opentelemetry::{self, global, KeyValue}; use parser::Parser; @@ -119,7 +119,7 @@ use tari_shutdown::{Shutdown, ShutdownSignal}; use tokio::{ runtime, task, - time::{self, Delay}, + time::{self}, }; use tonic::transport::Server; use tracing_subscriber::{layer::SubscriberExt, Registry}; @@ -145,7 +145,7 @@ fn main_inner() -> Result<(), ExitCodes> { debug!(target: LOG_TARGET, "Using configuration: {:?}", node_config); // Set up the Tokio runtime - let mut rt = setup_runtime(&node_config).map_err(|e| { + let rt = setup_runtime(&node_config).map_err(|e| { error!(target: LOG_TARGET, "{}", e); ExitCodes::UnknownError })?; @@ -320,26 +320,28 @@ async fn read_command(mut rustyline: Editor) -> Result<(String, Editor

Fuse { +fn status_interval(start_time: Instant) -> time::Sleep { let duration = match start_time.elapsed().as_secs() { 0..=120 => Duration::from_secs(5), _ => Duration::from_secs(30), }; - time::delay_for(duration).fuse() + time::sleep(duration) } async fn status_loop(command_handler: Arc, shutdown: Shutdown) { let start_time = Instant::now(); let mut shutdown_signal = shutdown.to_signal(); loop { - let mut interval = status_interval(start_time); - futures::select! { + let interval = status_interval(start_time); + tokio::select! { + biased; + _ = shutdown_signal.wait() => { + break; + } + _ = interval => { command_handler.status(StatusOutput::Log); }, - _ = shutdown_signal => { - break; - } } } } @@ -368,9 +370,9 @@ async fn cli_loop(parser: Parser, mut shutdown: Shutdown) { let start_time = Instant::now(); let mut software_update_notif = command_handler.get_software_updater().new_update_notifier().clone(); loop { - let mut interval = status_interval(start_time); - futures::select! { - res = read_command_fut => { + let interval = status_interval(start_time); + tokio::select! { + res = &mut read_command_fut => { match res { Ok((line, mut rustyline)) => { if let Some(p) = rustyline.helper_mut().as_deref_mut() { @@ -387,8 +389,8 @@ async fn cli_loop(parser: Parser, mut shutdown: Shutdown) { } } }, - resp = software_update_notif.recv().fuse() => { - if let Some(Some(update)) = resp { + Ok(_) = software_update_notif.changed() => { + if let Some(ref update) = *software_update_notif.borrow() { println!( "Version {} of the {} is available: {} (sha: {})", update.version(), @@ -401,7 +403,7 @@ async fn cli_loop(parser: Parser, mut shutdown: Shutdown) { _ = interval => { command_handler.status(StatusOutput::Full); }, - _ = shutdown_signal => { + _ = shutdown_signal.wait() => { break; } } diff --git a/applications/tari_base_node/src/parser.rs b/applications/tari_base_node/src/parser.rs index f280019ee0..79fdf27efc 100644 --- a/applications/tari_base_node/src/parser.rs +++ b/applications/tari_base_node/src/parser.rs @@ -40,12 +40,8 @@ use tari_app_utilities::utilities::{ parse_emoji_id_or_public_key, parse_emoji_id_or_public_key_or_node_id, }; -use tari_core::{ - crypto::tari_utilities::hex::from_hex, - proof_of_work::PowAlgorithm, - tari_utilities::hex::Hex, - transactions::types::{Commitment, PrivateKey, PublicKey, Signature}, -}; +use tari_common_types::types::{Commitment, PrivateKey, PublicKey, Signature}; +use tari_core::{crypto::tari_utilities::hex::from_hex, proof_of_work::PowAlgorithm, tari_utilities::hex::Hex}; use tari_shutdown::Shutdown; /// Enum representing commands used by the basenode diff --git a/applications/tari_base_node/src/recovery.rs b/applications/tari_base_node/src/recovery.rs index ce62fe36a3..8e0e0786c7 100644 --- a/applications/tari_base_node/src/recovery.rs +++ b/applications/tari_base_node/src/recovery.rs @@ -21,14 +21,16 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // -use anyhow::anyhow; -use log::*; use std::{ fs, io::{self, Write}, path::Path, sync::Arc, }; + +use anyhow::anyhow; +use log::*; + use tari_app_utilities::utilities::ExitCodes; use tari_common::{configuration::Network, DatabaseType, GlobalConfig}; use tari_core::{ @@ -43,7 +45,7 @@ use tari_core::{ }, consensus::ConsensusManager, proof_of_work::randomx_factory::RandomXFactory, - transactions::types::CryptoFactories, + transactions::CryptoFactories, validation::{ block_validators::{BodyOnlyValidator, OrphanBlockValidator}, header_validator::HeaderValidator, @@ -98,7 +100,11 @@ pub async fn run_recovery(node_config: &GlobalConfig) -> Result<(), anyhow::Erro let validators = Validators::new( BodyOnlyValidator::default(), HeaderValidator::new(rules.clone()), - OrphanBlockValidator::new(rules.clone(), factories.clone()), + OrphanBlockValidator::new( + rules.clone(), + node_config.base_node_bypass_range_proof_verification, + factories.clone(), + ), ); let db_config = BlockchainDatabaseConfig { orphan_storage_capacity: node_config.orphan_storage_capacity, @@ -173,12 +179,12 @@ async fn do_recovery( db.add_block(Arc::new(block)) .await .map_err(|e| anyhow!("Stopped recovery at height {}, reason: {}", counter, e))?; - counter += 1; - if counter > max_height { - info!(target: LOG_TARGET, "Done with recovery, chain height {}", counter - 1); + if counter >= max_height { + info!(target: LOG_TARGET, "Done with recovery, chain height {}", counter); break; } - print!("\x1B[{}D\x1B[K", (counter + 1).to_string().chars().count()); + print!("\x1B[{}D\x1B[K", counter.to_string().len()); + counter += 1; } Ok(()) } diff --git a/applications/tari_console_wallet/Cargo.toml b/applications/tari_console_wallet/Cargo.toml index 5163910d37..17bddfe336 100644 --- a/applications/tari_console_wallet/Cargo.toml +++ b/applications/tari_console_wallet/Cargo.toml @@ -5,21 +5,22 @@ authors = ["The Tari Development Community"] edition = "2018" [dependencies] -tari_wallet = { path = "../../base_layer/wallet" } +tari_wallet = { path = "../../base_layer/wallet", features=["bundled_sqlite"] } tari_crypto = "0.11.1" tari_common = { path = "../../common" } -tari_app_utilities = { path = "../tari_app_utilities"} +tari_app_utilities = { path = "../tari_app_utilities", features = ["wallet"]} tari_comms = { path = "../../comms"} tari_comms_dht = { path = "../../comms/dht"} +tari_common_types = {path = "../../base_layer/common_types"} tari_p2p = { path = "../../base_layer/p2p" } -tari_app_grpc = { path = "../tari_app_grpc" } +tari_app_grpc = { path = "../tari_app_grpc", features = ["wallet"] } tari_shutdown = { path = "../../infrastructure/shutdown" } tari_key_manager = { path = "../../base_layer/key_manager" } bitflags = "1.2.1" chrono = { version = "0.4.6", features = ["serde"]} chrono-english = "0.1" -futures = { version = "^0.3.1", default-features = false, features = ["alloc"]} +futures = { version = "^0.3.16", default-features = false, features = ["alloc"]} crossterm = { version = "0.17"} rand = "0.8" unicode-width = "0.1" @@ -31,9 +32,9 @@ rpassword = "5.0" rustyline = "6.0" strum = "^0.19" strum_macros = "^0.19" -tokio = { version="0.2.10", features = ["signal"] } -thiserror = "1.0.20" -tonic = "0.2" +tokio = { version="^1.10", features = ["signal"] } +thiserror = "1.0.26" +tonic = "0.5.2" tracing = "0.1.26" tracing-opentelemetry = "0.15.0" @@ -43,7 +44,6 @@ tracing-subscriber = "0.2.20" opentelemetry = { version = "0.16", default-features = false, features = ["trace","rt-tokio"] } opentelemetry-jaeger = { version="0.15", features=["rt-tokio"]} - [dependencies.tari_core] path = "../../base_layer/core" version = "^0.9" diff --git a/applications/tari_console_wallet/src/automation/command_parser.rs b/applications/tari_console_wallet/src/automation/command_parser.rs index 9555ea7fba..90f0347918 100644 --- a/applications/tari_console_wallet/src/automation/command_parser.rs +++ b/applications/tari_console_wallet/src/automation/command_parser.rs @@ -32,7 +32,8 @@ use std::{ use tari_app_utilities::utilities::parse_emoji_id_or_public_key; use tari_comms::multiaddr::Multiaddr; -use tari_core::transactions::{tari_amount::MicroTari, types::PublicKey}; +use tari_common_types::types::PublicKey; +use tari_core::transactions::tari_amount::MicroTari; #[derive(Debug)] pub struct ParsedCommand { @@ -348,7 +349,8 @@ mod test { }; use rand::rngs::OsRng; use std::str::FromStr; - use tari_core::transactions::{tari_amount::MicroTari, types::PublicKey}; + use tari_common_types::types::PublicKey; + use tari_core::transactions::tari_amount::MicroTari; use tari_crypto::keys::PublicKey as PublicKeyTrait; #[test] diff --git a/applications/tari_console_wallet/src/automation/commands.rs b/applications/tari_console_wallet/src/automation/commands.rs index 608cc8a675..7bbdf8da44 100644 --- a/applications/tari_console_wallet/src/automation/commands.rs +++ b/applications/tari_console_wallet/src/automation/commands.rs @@ -21,12 +21,6 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use super::error::CommandError; -use crate::{ - automation::command_parser::{ParsedArgument, ParsedCommand}, - utils::db::{CUSTOM_BASE_NODE_ADDRESS_KEY, CUSTOM_BASE_NODE_PUBLIC_KEY_KEY}, -}; -use chrono::{DateTime, Utc}; -use futures::{FutureExt, StreamExt}; use log::*; use std::{ fs::File, @@ -34,8 +28,18 @@ use std::{ str::FromStr, time::{Duration, Instant}, }; + +use chrono::{DateTime, Utc}; +use futures::FutureExt; use strum_macros::{Display, EnumIter, EnumString}; +use tari_crypto::ristretto::pedersen::PedersenCommitmentFactory; + +use crate::{ + automation::command_parser::{ParsedArgument, ParsedCommand}, + utils::db::{CUSTOM_BASE_NODE_ADDRESS_KEY, CUSTOM_BASE_NODE_PUBLIC_KEY_KEY}, +}; use tari_common::GlobalConfig; +use tari_common_types::{emoji::EmojiId, types::PublicKey}; use tari_comms::{ connectivity::{ConnectivityEvent, ConnectivityRequester}, multiaddr::Multiaddr, @@ -47,19 +51,16 @@ use tari_core::{ transactions::{ tari_amount::{uT, MicroTari, Tari}, transaction::UnblindedOutput, - types::PublicKey, }, }; -use tari_crypto::ristretto::pedersen::PedersenCommitmentFactory; use tari_wallet::{ output_manager_service::{handle::OutputManagerHandle, TxId}, transaction_service::handle::{TransactionEvent, TransactionServiceHandle}, - util::emoji::EmojiId, WalletSqlite, }; use tokio::{ - sync::mpsc, - time::{delay_for, timeout}, + sync::{broadcast, mpsc}, + time::{sleep, timeout}, }; pub const LOG_TARGET: &str = "wallet::automation::commands"; @@ -175,21 +176,22 @@ pub async fn coin_split( Ok(tx_id) } -async fn wait_for_comms(connectivity_requester: &ConnectivityRequester) -> Result { - let mut connectivity = connectivity_requester.get_event_subscription().fuse(); +async fn wait_for_comms(connectivity_requester: &ConnectivityRequester) -> Result<(), CommandError> { + let mut connectivity = connectivity_requester.get_event_subscription(); print!("Waiting for connectivity... "); - let mut timeout = delay_for(Duration::from_secs(30)).fuse(); + let timeout = sleep(Duration::from_secs(30)); + tokio::pin!(timeout); + let mut timeout = timeout.fuse(); loop { - futures::select! { - result = connectivity.select_next_some() => { - if let Ok(msg) = result { - if let ConnectivityEvent::PeerConnected(_) = (*msg).clone() { - println!("✅"); - return Ok(true); - } + tokio::select! { + // Wait for the first base node connection + Ok(ConnectivityEvent::PeerConnected(conn)) = connectivity.recv() => { + if conn.peer_features().is_node() { + println!("✅"); + return Ok(()); } }, - () = timeout => { + () = &mut timeout => { println!("❌"); return Err(CommandError::Comms("Timed out".to_string())); } @@ -311,7 +313,7 @@ pub async fn make_it_rain( target: LOG_TARGET, "make-it-rain delaying for {:?} ms - scheduled to start at {}", delay_ms, start_time ); - delay_for(Duration::from_millis(delay_ms)).await; + sleep(Duration::from_millis(delay_ms)).await; let num_txs = (txps * duration as f64) as usize; let started_at = Utc::now(); @@ -352,10 +354,10 @@ pub async fn make_it_rain( let target_ms = (i as f64 / (txps / 1000.0)) as i64; if target_ms - actual_ms > 0 { // Maximum delay between Txs set to 120 s - delay_for(Duration::from_millis((target_ms - actual_ms).min(120_000i64) as u64)).await; + sleep(Duration::from_millis((target_ms - actual_ms).min(120_000i64) as u64)).await; } let delayed_for = Instant::now(); - let mut sender_clone = sender.clone(); + let sender_clone = sender.clone(); tokio::task::spawn(async move { let spawn_start = Instant::now(); // Send transaction @@ -432,7 +434,7 @@ pub async fn monitor_transactions( tx_ids: Vec, wait_stage: TransactionStage, ) -> Vec { - let mut event_stream = transaction_service.get_event_stream_fused(); + let mut event_stream = transaction_service.get_event_stream(); let mut results = Vec::new(); debug!(target: LOG_TARGET, "monitor transactions wait_stage: {:?}", wait_stage); println!( @@ -442,104 +444,102 @@ pub async fn monitor_transactions( ); loop { - match event_stream.next().await { - Some(event_result) => match event_result { - Ok(event) => match &*event { - TransactionEvent::TransactionDirectSendResult(id, success) if tx_ids.contains(id) => { - debug!( - target: LOG_TARGET, - "tx direct send event for tx_id: {}, success: {}", *id, success - ); - if wait_stage == TransactionStage::DirectSendOrSaf { - results.push(SentTransaction { - id: *id, - stage: TransactionStage::DirectSendOrSaf, - }); - if results.len() == tx_ids.len() { - break; - } - } - }, - TransactionEvent::TransactionStoreForwardSendResult(id, success) if tx_ids.contains(id) => { - debug!( - target: LOG_TARGET, - "tx store and forward event for tx_id: {}, success: {}", *id, success - ); - if wait_stage == TransactionStage::DirectSendOrSaf { - results.push(SentTransaction { - id: *id, - stage: TransactionStage::DirectSendOrSaf, - }); - if results.len() == tx_ids.len() { - break; - } + match event_stream.recv().await { + Ok(event) => match &*event { + TransactionEvent::TransactionDirectSendResult(id, success) if tx_ids.contains(id) => { + debug!( + target: LOG_TARGET, + "tx direct send event for tx_id: {}, success: {}", *id, success + ); + if wait_stage == TransactionStage::DirectSendOrSaf { + results.push(SentTransaction { + id: *id, + stage: TransactionStage::DirectSendOrSaf, + }); + if results.len() == tx_ids.len() { + break; } - }, - TransactionEvent::ReceivedTransactionReply(id) if tx_ids.contains(id) => { - debug!(target: LOG_TARGET, "tx reply event for tx_id: {}", *id); - if wait_stage == TransactionStage::Negotiated { - results.push(SentTransaction { - id: *id, - stage: TransactionStage::Negotiated, - }); - if results.len() == tx_ids.len() { - break; - } + } + }, + TransactionEvent::TransactionStoreForwardSendResult(id, success) if tx_ids.contains(id) => { + debug!( + target: LOG_TARGET, + "tx store and forward event for tx_id: {}, success: {}", *id, success + ); + if wait_stage == TransactionStage::DirectSendOrSaf { + results.push(SentTransaction { + id: *id, + stage: TransactionStage::DirectSendOrSaf, + }); + if results.len() == tx_ids.len() { + break; } - }, - TransactionEvent::TransactionBroadcast(id) if tx_ids.contains(id) => { - debug!(target: LOG_TARGET, "tx mempool broadcast event for tx_id: {}", *id); - if wait_stage == TransactionStage::Broadcast { - results.push(SentTransaction { - id: *id, - stage: TransactionStage::Broadcast, - }); - if results.len() == tx_ids.len() { - break; - } + } + }, + TransactionEvent::ReceivedTransactionReply(id) if tx_ids.contains(id) => { + debug!(target: LOG_TARGET, "tx reply event for tx_id: {}", *id); + if wait_stage == TransactionStage::Negotiated { + results.push(SentTransaction { + id: *id, + stage: TransactionStage::Negotiated, + }); + if results.len() == tx_ids.len() { + break; } - }, - TransactionEvent::TransactionMinedUnconfirmed(id, confirmations) if tx_ids.contains(id) => { - debug!( - target: LOG_TARGET, - "tx mined unconfirmed event for tx_id: {}, confirmations: {}", *id, confirmations - ); - if wait_stage == TransactionStage::MinedUnconfirmed { - results.push(SentTransaction { - id: *id, - stage: TransactionStage::MinedUnconfirmed, - }); - if results.len() == tx_ids.len() { - break; - } + } + }, + TransactionEvent::TransactionBroadcast(id) if tx_ids.contains(id) => { + debug!(target: LOG_TARGET, "tx mempool broadcast event for tx_id: {}", *id); + if wait_stage == TransactionStage::Broadcast { + results.push(SentTransaction { + id: *id, + stage: TransactionStage::Broadcast, + }); + if results.len() == tx_ids.len() { + break; } - }, - TransactionEvent::TransactionMined(id) if tx_ids.contains(id) => { - debug!(target: LOG_TARGET, "tx mined confirmed event for tx_id: {}", *id); - if wait_stage == TransactionStage::Mined { - results.push(SentTransaction { - id: *id, - stage: TransactionStage::Mined, - }); - if results.len() == tx_ids.len() { - break; - } + } + }, + TransactionEvent::TransactionMinedUnconfirmed(id, confirmations) if tx_ids.contains(id) => { + debug!( + target: LOG_TARGET, + "tx mined unconfirmed event for tx_id: {}, confirmations: {}", *id, confirmations + ); + if wait_stage == TransactionStage::MinedUnconfirmed { + results.push(SentTransaction { + id: *id, + stage: TransactionStage::MinedUnconfirmed, + }); + if results.len() == tx_ids.len() { + break; } - }, - _ => {}, + } }, - Err(e) => { - eprintln!("RecvError in monitor_transactions: {:?}", e); - break; + TransactionEvent::TransactionMined(id) if tx_ids.contains(id) => { + debug!(target: LOG_TARGET, "tx mined confirmed event for tx_id: {}", *id); + if wait_stage == TransactionStage::Mined { + results.push(SentTransaction { + id: *id, + stage: TransactionStage::Mined, + }); + if results.len() == tx_ids.len() { + break; + } + } }, + _ => {}, }, - None => { - warn!( + // All event senders have gone (i.e. we take it that the node is shutting down) + Err(broadcast::error::RecvError::Closed) => { + debug!( target: LOG_TARGET, - "`None` result in event in monitor_transactions loop" + "All Transaction event senders have gone. Exiting `monitor_transactions` loop." ); break; }, + Err(err) => { + warn!(target: LOG_TARGET, "monitor_transactions: {}", err); + }, } } @@ -578,7 +578,8 @@ pub async fn command_runner( }, DiscoverPeer => { if !online { - online = wait_for_comms(&connectivity_requester).await?; + wait_for_comms(&connectivity_requester).await?; + online = true; } discover_peer(dht_service.clone(), parsed.args).await? }, diff --git a/applications/tari_console_wallet/src/grpc/wallet_grpc_server.rs b/applications/tari_console_wallet/src/grpc/wallet_grpc_server.rs index b7e53ae6a2..b0c7779431 100644 --- a/applications/tari_console_wallet/src/grpc/wallet_grpc_server.rs +++ b/applications/tari_console_wallet/src/grpc/wallet_grpc_server.rs @@ -1,4 +1,4 @@ -use futures::future; +use futures::{channel::mpsc, future, SinkExt}; use log::*; use std::convert::TryFrom; use tari_app_grpc::{ @@ -31,17 +31,18 @@ use tari_app_grpc::{ TransferResult, }, }; +use tari_common_types::types::Signature; use tari_comms::{types::CommsPublicKey, CommsNode}; use tari_core::{ tari_utilities::{hex::Hex, ByteArray}, - transactions::{tari_amount::MicroTari, transaction::UnblindedOutput, types::Signature}, + transactions::{tari_amount::MicroTari, transaction::UnblindedOutput}, }; use tari_wallet::{ output_manager_service::handle::OutputManagerHandle, transaction_service::{handle::TransactionServiceHandle, storage::models}, WalletSqlite, }; -use tokio::{sync::mpsc, task}; +use tokio::task; use tonic::{Request, Response, Status}; const LOG_TARGET: &str = "wallet::ui::grpc"; diff --git a/applications/tari_console_wallet/src/init/mod.rs b/applications/tari_console_wallet/src/init/mod.rs index b5c0c9d805..b877c6b729 100644 --- a/applications/tari_console_wallet/src/init/mod.rs +++ b/applications/tari_console_wallet/src/init/mod.rs @@ -20,23 +20,22 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{ - utils::db::get_custom_base_node_peer_from_db, - wallet_modes::{PeerConfig, WalletMode}, -}; +use std::{fs, path::PathBuf, str::FromStr, sync::Arc}; + use log::*; use rpassword::prompt_password_stdout; use rustyline::Editor; -use std::{fs, path::PathBuf, str::FromStr, sync::Arc}; + use tari_app_utilities::utilities::{create_transport_type, ExitCodes}; use tari_common::{ConfigBootstrap, GlobalConfig}; +use tari_common_types::types::PrivateKey; use tari_comms::{ peer_manager::{Peer, PeerFeatures}, types::CommsSecretKey, NodeIdentity, }; use tari_comms_dht::{DbConnectionUrl, DhtConfig}; -use tari_core::transactions::types::{CryptoFactories, PrivateKey}; +use tari_core::transactions::CryptoFactories; use tari_p2p::{ initialization::CommsConfig, peer_seeds::SeedPeer, @@ -59,6 +58,11 @@ use tari_wallet::{ WalletSqlite, }; +use crate::{ + utils::db::get_custom_base_node_peer_from_db, + wallet_modes::{PeerConfig, WalletMode}, +}; + pub const LOG_TARGET: &str = "wallet::console_wallet::init"; /// The minimum buffer size for a tari application pubsub_connector channel const BASE_NODE_BUFFER_MIN_SIZE: usize = 30; @@ -128,9 +132,15 @@ pub async fn change_password( return Err(ExitCodes::InputError("Passwords don't match!".to_string())); } - wallet.remove_encryption().await?; + wallet + .remove_encryption() + .await + .map_err(|e| ExitCodes::WalletError(e.to_string()))?; - wallet.apply_encryption(passphrase).await?; + wallet + .apply_encryption(passphrase) + .await + .map_err(|e| ExitCodes::WalletError(e.to_string()))?; println!("Wallet password changed successfully."); diff --git a/applications/tari_console_wallet/src/main.rs b/applications/tari_console_wallet/src/main.rs index 4042eb27eb..23b917fb16 100644 --- a/applications/tari_console_wallet/src/main.rs +++ b/applications/tari_console_wallet/src/main.rs @@ -24,7 +24,7 @@ use recovery::prompt_private_key_from_seed_words; use std::{env, process}; use tari_app_utilities::{consts, initialization::init_configuration, utilities::ExitCodes}; use tari_common::{configuration::bootstrap::ApplicationType, ConfigBootstrap}; -use tari_core::transactions::types::PrivateKey; +use tari_common_types::types::PrivateKey; use tari_shutdown::Shutdown; use tracing_subscriber::{layer::SubscriberExt, Registry}; use wallet_modes::{command_mode, grpc_mode, recovery_mode, script_mode, tui_mode, WalletMode}; @@ -58,8 +58,7 @@ fn main() { } fn main_inner() -> Result<(), ExitCodes> { - let mut runtime = tokio::runtime::Builder::new() - .threaded_scheduler() + let runtime = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() .expect("Failed to build a runtime!"); @@ -156,11 +155,8 @@ fn main_inner() -> Result<(), ExitCodes> { }; print!("\nShutting down wallet... "); - if shutdown.trigger().is_ok() { - runtime.block_on(wallet.wait_until_shutdown()); - } else { - error!(target: LOG_TARGET, "No listeners for the shutdown signal!"); - } + shutdown.trigger(); + runtime.block_on(wallet.wait_until_shutdown()); println!("Done."); result diff --git a/applications/tari_console_wallet/src/recovery.rs b/applications/tari_console_wallet/src/recovery.rs index 887995f0a5..297b84e983 100644 --- a/applications/tari_console_wallet/src/recovery.rs +++ b/applications/tari_console_wallet/src/recovery.rs @@ -21,11 +21,11 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use chrono::offset::Local; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt; use log::*; use rustyline::Editor; use tari_app_utilities::utilities::ExitCodes; -use tari_core::transactions::types::PrivateKey; +use tari_common_types::types::PrivateKey; use tari_key_manager::mnemonic::to_secretkey; use tari_shutdown::Shutdown; use tari_wallet::{ @@ -35,6 +35,7 @@ use tari_wallet::{ }; use crate::wallet_modes::PeerConfig; +use tokio::sync::broadcast; pub const LOG_TARGET: &str = "wallet::recovery"; @@ -97,13 +98,13 @@ pub async fn wallet_recovery(wallet: &WalletSqlite, base_node_config: &PeerConfi .with_retry_limit(10) .build_with_wallet(wallet, shutdown_signal); - let mut event_stream = recovery_task.get_event_receiver().fuse(); + let mut event_stream = recovery_task.get_event_receiver(); let recovery_join_handle = tokio::spawn(recovery_task.run()).fuse(); // Read recovery task events. The event stream will end once recovery has completed. - while let Some(event) = event_stream.next().await { - match event { + loop { + match event_stream.recv().await { Ok(UtxoScannerEvent::ConnectingToBaseNode(peer)) => { print!("Connecting to base node {}... ", peer); }, @@ -170,11 +171,13 @@ pub async fn wallet_recovery(wallet: &WalletSqlite, base_node_config: &PeerConfi info!(target: LOG_TARGET, "{}", stats); println!("{}", stats); }, - Err(e) => { - // Can occur if we read events too slowly (lagging/slow subscriber) + Err(e @ broadcast::error::RecvError::Lagged(_)) => { debug!(target: LOG_TARGET, "Error receiving Wallet recovery events: {}", e); continue; }, + Err(broadcast::error::RecvError::Closed) => { + break; + }, Ok(UtxoScannerEvent::ScanningFailed) => { error!(target: LOG_TARGET, "Wallet Recovery process failed and is exiting"); }, diff --git a/applications/tari_console_wallet/src/ui/components/base_node.rs b/applications/tari_console_wallet/src/ui/components/base_node.rs index d9a271e291..ade233b90c 100644 --- a/applications/tari_console_wallet/src/ui/components/base_node.rs +++ b/applications/tari_console_wallet/src/ui/components/base_node.rs @@ -42,9 +42,9 @@ impl BaseNode { impl Component for BaseNode { fn draw(&mut self, f: &mut Frame, area: Rect, app_state: &AppState) where B: Backend { - let base_node_state = app_state.get_base_node_state(); + let current_online_status = app_state.get_wallet_connectivity().get_connectivity_status(); - let chain_info = match base_node_state.online { + let chain_info = match current_online_status { OnlineStatus::Connecting => Spans::from(vec![ Span::styled("Chain Tip:", Style::default().fg(Color::Magenta)), Span::raw(" "), @@ -56,7 +56,8 @@ impl Component for BaseNode { Span::styled("Offline", Style::default().fg(Color::Red)), ]), OnlineStatus::Online => { - if let Some(metadata) = base_node_state.clone().chain_metadata { + let base_node_state = app_state.get_base_node_state(); + if let Some(ref metadata) = base_node_state.chain_metadata { let tip = metadata.height_of_longest_chain(); let synced = base_node_state.is_synced.unwrap_or_default(); @@ -92,7 +93,7 @@ impl Component for BaseNode { Spans::from(vec![ Span::styled("Chain Tip:", Style::default().fg(Color::Magenta)), Span::raw(" "), - Span::styled("Error", Style::default().fg(Color::Red)), + Span::styled("Waiting for data...", Style::default().fg(Color::DarkGray)), ]) } }, diff --git a/applications/tari_console_wallet/src/ui/state/app_state.rs b/applications/tari_console_wallet/src/ui/state/app_state.rs index 6d9293ef47..b4a45df7b4 100644 --- a/applications/tari_console_wallet/src/ui/state/app_state.rs +++ b/applications/tari_console_wallet/src/ui/state/app_state.rs @@ -20,29 +20,23 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{ - notifier::Notifier, - ui::{ - state::{ - tasks::{send_one_sided_transaction_task, send_transaction_task}, - wallet_event_monitor::WalletEventMonitor, - }, - UiContact, - UiError, - }, - utils::db::{CUSTOM_BASE_NODE_ADDRESS_KEY, CUSTOM_BASE_NODE_PUBLIC_KEY_KEY}, - wallet_modes::PeerConfig, -}; -use bitflags::bitflags; -use futures::{stream::Fuse, StreamExt}; -use log::*; -use qrcode::{render::unicode, QrCode}; use std::{ collections::HashMap, sync::Arc, time::{Duration, Instant}, }; + +use bitflags::bitflags; +use log::*; +use qrcode::{render::unicode, QrCode}; +use tari_crypto::{ristretto::RistrettoPublicKey, tari_utilities::hex::Hex}; +use tokio::{ + sync::{watch, RwLock}, + task, +}; + use tari_common::{configuration::Network, GlobalConfig}; +use tari_common_types::{emoji::EmojiId, types::PublicKey}; use tari_comms::{ connectivity::ConnectivityEventRx, multiaddr::Multiaddr, @@ -50,11 +44,7 @@ use tari_comms::{ types::CommsPublicKey, NodeIdentity, }; -use tari_core::transactions::{ - tari_amount::{uT, MicroTari}, - types::PublicKey, -}; -use tari_crypto::{ristretto::RistrettoPublicKey, tari_utilities::hex::Hex}; +use tari_core::transactions::tari_amount::{uT, MicroTari}; use tari_shutdown::ShutdownSignal; use tari_wallet::{ base_node_service::{handle::BaseNodeEventReceiver, service::BaseNodeState}, @@ -66,12 +56,21 @@ use tari_wallet::{ storage::models::{CompletedTransaction, TransactionStatus}, }, types::ValidationRetryStrategy, - util::emoji::EmojiId, WalletSqlite, }; -use tokio::{ - sync::{watch, RwLock}, - task, + +use crate::{ + notifier::Notifier, + ui::{ + state::{ + tasks::{send_one_sided_transaction_task, send_transaction_task}, + wallet_event_monitor::WalletEventMonitor, + }, + UiContact, + UiError, + }, + utils::db::{CUSTOM_BASE_NODE_ADDRESS_KEY, CUSTOM_BASE_NODE_PUBLIC_KEY_KEY}, + wallet_modes::PeerConfig, }; const LOG_TARGET: &str = "wallet::console_wallet::app_state"; @@ -84,6 +83,7 @@ pub struct AppState { completed_tx_filter: TransactionFilter, node_config: GlobalConfig, config: AppStateConfig, + wallet_connectivity: WalletConnectivityHandle, } impl AppState { @@ -95,6 +95,7 @@ impl AppState { base_node_config: PeerConfig, node_config: GlobalConfig, ) -> Self { + let wallet_connectivity = wallet.wallet_connectivity.clone(); let inner = AppStateInner::new(node_identity, network, wallet, base_node_selected, base_node_config); let cached_data = inner.data.clone(); @@ -105,6 +106,7 @@ impl AppState { completed_tx_filter: TransactionFilter::ABANDONED_COINBASES, node_config, config: AppStateConfig::default(), + wallet_connectivity, } } @@ -352,6 +354,10 @@ impl AppState { &self.cached_data.base_node_state } + pub fn get_wallet_connectivity(&self) -> WalletConnectivityHandle { + self.wallet_connectivity.clone() + } + pub fn get_selected_base_node(&self) -> &Peer { &self.cached_data.base_node_selected } @@ -641,24 +647,24 @@ impl AppStateInner { self.wallet.comms.shutdown_signal() } - pub fn get_transaction_service_event_stream(&self) -> Fuse { - self.wallet.transaction_service.get_event_stream_fused() + pub fn get_transaction_service_event_stream(&self) -> TransactionEventReceiver { + self.wallet.transaction_service.get_event_stream() } - pub fn get_output_manager_service_event_stream(&self) -> Fuse { - self.wallet.output_manager_service.get_event_stream_fused() + pub fn get_output_manager_service_event_stream(&self) -> OutputManagerEventReceiver { + self.wallet.output_manager_service.get_event_stream() } - pub fn get_connectivity_event_stream(&self) -> Fuse { - self.wallet.comms.connectivity().get_event_subscription().fuse() + pub fn get_connectivity_event_stream(&self) -> ConnectivityEventRx { + self.wallet.comms.connectivity().get_event_subscription() } pub fn get_wallet_connectivity(&self) -> WalletConnectivityHandle { self.wallet.wallet_connectivity.clone() } - pub fn get_base_node_event_stream(&self) -> Fuse { - self.wallet.base_node_service.clone().get_event_stream_fused() + pub fn get_base_node_event_stream(&self) -> BaseNodeEventReceiver { + self.wallet.base_node_service.get_event_stream() } pub async fn set_base_node_peer(&mut self, peer: Peer) -> Result<(), UiError> { diff --git a/applications/tari_console_wallet/src/ui/state/tasks.rs b/applications/tari_console_wallet/src/ui/state/tasks.rs index caf8073f56..85243660e6 100644 --- a/applications/tari_console_wallet/src/ui/state/tasks.rs +++ b/applications/tari_console_wallet/src/ui/state/tasks.rs @@ -21,11 +21,10 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::ui::{state::UiTransactionSendStatus, UiError}; -use futures::StreamExt; use tari_comms::types::CommsPublicKey; use tari_core::transactions::tari_amount::MicroTari; use tari_wallet::transaction_service::handle::{TransactionEvent, TransactionServiceHandle}; -use tokio::sync::watch; +use tokio::sync::{broadcast, watch}; const LOG_TARGET: &str = "wallet::console_wallet::tasks "; @@ -37,8 +36,8 @@ pub async fn send_transaction_task( mut transaction_service_handle: TransactionServiceHandle, result_tx: watch::Sender, ) { - let _ = result_tx.broadcast(UiTransactionSendStatus::Initiated); - let mut event_stream = transaction_service_handle.get_event_stream_fused(); + let _ = result_tx.send(UiTransactionSendStatus::Initiated); + let mut event_stream = transaction_service_handle.get_event_stream(); let mut send_direct_received_result = (false, false); let mut send_saf_received_result = (false, false); match transaction_service_handle @@ -46,15 +45,15 @@ pub async fn send_transaction_task( .await { Err(e) => { - let _ = result_tx.broadcast(UiTransactionSendStatus::Error(UiError::from(e).to_string())); + let _ = result_tx.send(UiTransactionSendStatus::Error(UiError::from(e).to_string())); }, Ok(our_tx_id) => { - while let Some(event_result) = event_stream.next().await { - match event_result { + loop { + match event_stream.recv().await { Ok(event) => match &*event { TransactionEvent::TransactionDiscoveryInProgress(tx_id) => { if our_tx_id == *tx_id { - let _ = result_tx.broadcast(UiTransactionSendStatus::DiscoveryInProgress); + let _ = result_tx.send(UiTransactionSendStatus::DiscoveryInProgress); } }, TransactionEvent::TransactionDirectSendResult(tx_id, result) => { @@ -75,25 +74,28 @@ pub async fn send_transaction_task( }, TransactionEvent::TransactionCompletedImmediately(tx_id) => { if our_tx_id == *tx_id { - let _ = result_tx.broadcast(UiTransactionSendStatus::TransactionComplete); + let _ = result_tx.send(UiTransactionSendStatus::TransactionComplete); return; } }, _ => (), }, - Err(e) => { + Err(e @ broadcast::error::RecvError::Lagged(_)) => { log::warn!(target: LOG_TARGET, "Error reading from event broadcast channel {:?}", e); + continue; + }, + Err(broadcast::error::RecvError::Closed) => { break; }, } } if send_direct_received_result.1 { - let _ = result_tx.broadcast(UiTransactionSendStatus::SentDirect); + let _ = result_tx.send(UiTransactionSendStatus::SentDirect); } else if send_saf_received_result.1 { - let _ = result_tx.broadcast(UiTransactionSendStatus::SentViaSaf); + let _ = result_tx.send(UiTransactionSendStatus::SentViaSaf); } else { - let _ = result_tx.broadcast(UiTransactionSendStatus::Error( + let _ = result_tx.send(UiTransactionSendStatus::Error( "Transaction could not be sent".to_string(), )); } @@ -109,34 +111,37 @@ pub async fn send_one_sided_transaction_task( mut transaction_service_handle: TransactionServiceHandle, result_tx: watch::Sender, ) { - let _ = result_tx.broadcast(UiTransactionSendStatus::Initiated); - let mut event_stream = transaction_service_handle.get_event_stream_fused(); + let _ = result_tx.send(UiTransactionSendStatus::Initiated); + let mut event_stream = transaction_service_handle.get_event_stream(); match transaction_service_handle .send_one_sided_transaction(public_key, amount, fee_per_gram, message) .await { Err(e) => { - let _ = result_tx.broadcast(UiTransactionSendStatus::Error(UiError::from(e).to_string())); + let _ = result_tx.send(UiTransactionSendStatus::Error(UiError::from(e).to_string())); }, Ok(our_tx_id) => { - while let Some(event_result) = event_stream.next().await { - match event_result { + loop { + match event_stream.recv().await { Ok(event) => { if let TransactionEvent::TransactionCompletedImmediately(tx_id) = &*event { if our_tx_id == *tx_id { - let _ = result_tx.broadcast(UiTransactionSendStatus::TransactionComplete); + let _ = result_tx.send(UiTransactionSendStatus::TransactionComplete); return; } } }, - Err(e) => { + Err(e @ broadcast::error::RecvError::Lagged(_)) => { log::warn!(target: LOG_TARGET, "Error reading from event broadcast channel {:?}", e); + continue; + }, + Err(broadcast::error::RecvError::Closed) => { break; }, } } - let _ = result_tx.broadcast(UiTransactionSendStatus::Error( + let _ = result_tx.send(UiTransactionSendStatus::Error( "One-sided transaction could not be sent".to_string(), )); }, diff --git a/applications/tari_console_wallet/src/ui/state/wallet_event_monitor.rs b/applications/tari_console_wallet/src/ui/state/wallet_event_monitor.rs index 2e20999667..e7df30b653 100644 --- a/applications/tari_console_wallet/src/ui/state/wallet_event_monitor.rs +++ b/applications/tari_console_wallet/src/ui/state/wallet_event_monitor.rs @@ -21,7 +21,6 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{notifier::Notifier, ui::state::AppStateInner}; -use futures::stream::StreamExt; use log::*; use std::sync::Arc; use tari_comms::{connectivity::ConnectivityEvent, peer_manager::Peer}; @@ -30,7 +29,7 @@ use tari_wallet::{ output_manager_service::{handle::OutputManagerEvent, TxId}, transaction_service::handle::TransactionEvent, }; -use tokio::sync::RwLock; +use tokio::sync::{broadcast, RwLock}; const LOG_TARGET: &str = "wallet::console_wallet::wallet_event_monitor"; @@ -55,14 +54,14 @@ impl WalletEventMonitor { let mut connectivity_events = self.app_state_inner.read().await.get_connectivity_event_stream(); let wallet_connectivity = self.app_state_inner.read().await.get_wallet_connectivity(); - let mut connectivity_status = wallet_connectivity.get_connectivity_status_watch().fuse(); + let mut connectivity_status = wallet_connectivity.get_connectivity_status_watch(); let mut base_node_events = self.app_state_inner.read().await.get_base_node_event_stream(); info!(target: LOG_TARGET, "Wallet Event Monitor starting"); loop { - futures::select! { - result = transaction_service_events.select_next_some() => { + tokio::select! { + result = transaction_service_events.recv() => { match result { Ok(msg) => { trace!(target: LOG_TARGET, "Wallet Event Monitor received wallet transaction service event {:?}", msg); @@ -104,18 +103,21 @@ impl WalletEventMonitor { _ => (), } }, - Err(_) => debug!(target: LOG_TARGET, "Lagging read on Transaction Service event broadcast channel"), + Err(broadcast::error::RecvError::Lagged(n)) => { + warn!(target: LOG_TARGET, "Missed {} from Transaction events", n); + } + Err(broadcast::error::RecvError::Closed) => {} } }, - status = connectivity_status.select_next_some() => { - trace!(target: LOG_TARGET, "Wallet Event Monitor received wallet connectivity status {:?}", status); + Ok(_) = connectivity_status.changed() => { + trace!(target: LOG_TARGET, "Wallet Event Monitor received wallet connectivity status changed"); self.trigger_peer_state_refresh().await; }, - result = connectivity_events.select_next_some() => { + result = connectivity_events.recv() => { match result { Ok(msg) => { trace!(target: LOG_TARGET, "Wallet Event Monitor received wallet connectivity event {:?}", msg); - match &*msg { + match msg { ConnectivityEvent::PeerDisconnected(_) | ConnectivityEvent::ManagedPeerDisconnected(_) | ConnectivityEvent::PeerConnected(_) => { @@ -125,10 +127,13 @@ impl WalletEventMonitor { _ => (), } }, - Err(_) => debug!(target: LOG_TARGET, "Lagging read on Connectivity event broadcast channel"), + Err(broadcast::error::RecvError::Lagged(n)) => { + warn!(target: LOG_TARGET, "Missed {} from Connectivity events", n); + } + Err(broadcast::error::RecvError::Closed) => {} } }, - result = base_node_events.select_next_some() => { + result = base_node_events.recv() => { match result { Ok(msg) => { trace!(target: LOG_TARGET, "Wallet Event Monitor received base node event {:?}", msg); @@ -141,10 +146,13 @@ impl WalletEventMonitor { } } }, - Err(_) => debug!(target: LOG_TARGET, "Lagging read on base node event broadcast channel"), + Err(broadcast::error::RecvError::Lagged(n)) => { + warn!(target: LOG_TARGET, "Missed {} from Base node Service events", n); + } + Err(broadcast::error::RecvError::Closed) => {} } }, - result = output_manager_service_events.select_next_some() => { + result = output_manager_service_events.recv() => { match result { Ok(msg) => { trace!(target: LOG_TARGET, "Output Manager Service Callback Handler event {:?}", msg); @@ -152,14 +160,13 @@ impl WalletEventMonitor { self.trigger_balance_refresh().await; } }, - Err(_e) => error!(target: LOG_TARGET, "Error reading from Output Manager Service event broadcast channel"), + Err(broadcast::error::RecvError::Lagged(n)) => { + warn!(target: LOG_TARGET, "Missed {} from Output Manager Service events", n); + } + Err(broadcast::error::RecvError::Closed) => {} } - }, - complete => { - info!(target: LOG_TARGET, "Wallet Event Monitor is exiting because all tasks have completed"); - break; }, - _ = shutdown_signal => { + _ = shutdown_signal.wait() => { info!(target: LOG_TARGET, "Wallet Event Monitor shutting down because the shutdown signal was received"); break; }, diff --git a/applications/tari_console_wallet/src/ui/ui_contact.rs b/applications/tari_console_wallet/src/ui/ui_contact.rs index 2d8be4a182..49f83e8284 100644 --- a/applications/tari_console_wallet/src/ui/ui_contact.rs +++ b/applications/tari_console_wallet/src/ui/ui_contact.rs @@ -1,4 +1,5 @@ -use tari_wallet::{contacts_service::storage::database::Contact, util::emoji::EmojiId}; +use tari_common_types::emoji::EmojiId; +use tari_wallet::contacts_service::storage::database::Contact; #[derive(Debug, Clone)] pub struct UiContact { diff --git a/applications/tari_console_wallet/src/utils/db.rs b/applications/tari_console_wallet/src/utils/db.rs index 9dbf43cfd2..aee50c4e40 100644 --- a/applications/tari_console_wallet/src/utils/db.rs +++ b/applications/tari_console_wallet/src/utils/db.rs @@ -21,11 +21,12 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use log::*; +use tari_common_types::types::PublicKey; use tari_comms::{ multiaddr::Multiaddr, peer_manager::{NodeId, Peer, PeerFeatures, PeerFlags}, }; -use tari_core::transactions::types::PublicKey; + use tari_crypto::tari_utilities::hex::Hex; use tari_wallet::WalletSqlite; diff --git a/applications/tari_console_wallet/src/wallet_modes.rs b/applications/tari_console_wallet/src/wallet_modes.rs index a205523a00..23b1ee6330 100644 --- a/applications/tari_console_wallet/src/wallet_modes.rs +++ b/applications/tari_console_wallet/src/wallet_modes.rs @@ -239,7 +239,10 @@ pub fn tui_mode(config: WalletModeConfig, mut wallet: WalletSqlite) -> Result<() info!(target: LOG_TARGET, "Starting app"); - handle.enter(|| ui::run(app))?; + { + let _enter = handle.enter(); + ui::run(app)?; + } info!( target: LOG_TARGET, diff --git a/applications/tari_merge_mining_proxy/Cargo.toml b/applications/tari_merge_mining_proxy/Cargo.toml index 626d0502b6..2cd31b8b3d 100644 --- a/applications/tari_merge_mining_proxy/Cargo.toml +++ b/applications/tari_merge_mining_proxy/Cargo.toml @@ -13,33 +13,32 @@ envlog = ["env_logger"] [dependencies] tari_app_grpc = { path = "../tari_app_grpc" } -tari_common = { path = "../../common" } -tari_core = { path = "../../base_layer/core", default-features = false, features = ["transactions"]} -tari_app_utilities = { path = "../tari_app_utilities"} +tari_common = { path = "../../common" } +tari_core = { path = "../../base_layer/core", default-features = false, features = ["transactions"] } +tari_app_utilities = { path = "../tari_app_utilities" } tari_crypto = "0.11.1" tari_utilities = "^0.3" anyhow = "1.0.40" bincode = "1.3.1" -bytes = "0.5.6" +bytes = "1.1" chrono = "0.4.19" config = { version = "0.9.3" } derive-error = "0.0.4" env_logger = { version = "0.7.1", optional = true } futures = "0.3.5" hex = "0.4.2" -hyper = "0.13.7" +hyper = "0.14.12" jsonrpc = "0.11.0" log = { version = "0.4.8", features = ["std"] } rand = "0.8" -reqwest = {version = "0.10.8", features=["json"]} -serde = { version="1.0.106", features = ["derive"] } +reqwest = { version = "0.11.4", features = ["json"] } +serde = { version = "1.0.106", features = ["derive"] } serde_json = "1.0.57" structopt = { version = "0.3.13", default_features = false } -thiserror = "1.0.15" -tokio = "0.2.10" -tokio-macros = "0.2.5" -tonic = "0.2" +thiserror = "1.0.26" +tokio = { version = "1.10", features = ["macros"] } +tonic = "0.5.2" tracing = "0.1" tracing-futures = "0.2" tracing-subscriber = "0.2" diff --git a/applications/tari_merge_mining_proxy/src/main.rs b/applications/tari_merge_mining_proxy/src/main.rs index 7e15777977..0df27490bb 100644 --- a/applications/tari_merge_mining_proxy/src/main.rs +++ b/applications/tari_merge_mining_proxy/src/main.rs @@ -46,7 +46,7 @@ use tari_app_utilities::initialization::init_configuration; use tari_common::configuration::bootstrap::ApplicationType; use tokio::time::Duration; -#[tokio_macros::main] +#[tokio::main] async fn main() -> Result<(), anyhow::Error> { let (_, config, _) = init_configuration(ApplicationType::MergeMiningProxy)?; diff --git a/applications/tari_mining_node/Cargo.toml b/applications/tari_mining_node/Cargo.toml index a04938c048..f893c0bc06 100644 --- a/applications/tari_mining_node/Cargo.toml +++ b/applications/tari_mining_node/Cargo.toml @@ -17,15 +17,15 @@ crossbeam = "0.8" futures = "0.3" log = { version = "0.4", features = ["std"] } num_cpus = "1.13" -prost-types = "0.6" +prost-types = "0.8" rand = "0.8" sha3 = "0.9" serde = { version = "1.0", default_features = false, features = ["derive"] } -tonic = { version = "0.2", features = ["transport"] } -tokio = { version = "0.2", default_features = false, features = ["rt-core"] } +tonic = { version = "0.5.2", features = ["transport"] } +tokio = { version = "1.10", default_features = false, features = ["rt-multi-thread"] } thiserror = "1.0" jsonrpc = "0.11.0" -reqwest = { version = "0.11", features = ["blocking", "json"] } +reqwest = { version = "0.11", features = [ "json"] } serde_json = "1.0.57" native-tls = "0.2" bufstream = "0.1" @@ -35,5 +35,5 @@ hex = "0.4.2" [dev-dependencies] tari_crypto = "0.11.1" -prost-types = "0.6.1" +prost-types = "0.8" chrono = "0.4" diff --git a/applications/tari_mining_node/src/main.rs b/applications/tari_mining_node/src/main.rs index b1f538c1e4..0a0bf31557 100644 --- a/applications/tari_mining_node/src/main.rs +++ b/applications/tari_mining_node/src/main.rs @@ -23,13 +23,6 @@ use config::MinerConfig; use futures::stream::StreamExt; use log::*; -use tari_app_grpc::tari_rpc::{base_node_client::BaseNodeClient, wallet_client::WalletClient}; -use tari_app_utilities::{initialization::init_configuration, utilities::ExitCodes}; -use tari_common::{configuration::bootstrap::ApplicationType, ConfigBootstrap, DefaultConfigLoader, GlobalConfig}; -use tari_core::blocks::BlockHeader; -use tokio::{runtime::Runtime, time::delay_for}; -use tonic::transport::Channel; -use utils::{coinbase_request, extract_outputs_and_kernels}; mod config; mod difficulty; @@ -53,10 +46,17 @@ use std::{ thread, time::Instant, }; +use tari_app_grpc::tari_rpc::{base_node_client::BaseNodeClient, wallet_client::WalletClient}; +use tari_app_utilities::{initialization::init_configuration, utilities::ExitCodes}; +use tari_common::{configuration::bootstrap::ApplicationType, ConfigBootstrap, DefaultConfigLoader, GlobalConfig}; +use tari_core::blocks::BlockHeader; +use tokio::{runtime::Runtime, time::sleep}; +use tonic::transport::Channel; +use utils::{coinbase_request, extract_outputs_and_kernels}; /// Application entry point fn main() { - let mut rt = Runtime::new().expect("Failed to start tokio runtime"); + let rt = Runtime::new().expect("Failed to start tokio runtime"); match rt.block_on(main_inner()) { Ok(_) => std::process::exit(0), Err(exit_code) => { @@ -144,7 +144,7 @@ async fn main_inner() -> Result<(), ExitCodes> { error!("Connection error: {:?}", err); loop { debug!("Holding for {:?}", config.wait_timeout()); - delay_for(config.wait_timeout()).await; + sleep(config.wait_timeout()).await; match connect(&config, &global).await { Ok((nc, wc)) => { node_conn = nc; @@ -168,7 +168,7 @@ async fn main_inner() -> Result<(), ExitCodes> { Err(err) => { error!("Error: {:?}", err); debug!("Holding for {:?}", config.wait_timeout()); - delay_for(config.wait_timeout()).await; + sleep(config.wait_timeout()).await; }, Ok(submitted) => { if submitted { diff --git a/applications/tari_stratum_transcoder/Cargo.toml b/applications/tari_stratum_transcoder/Cargo.toml index 29f95c82da..d29b91ead7 100644 --- a/applications/tari_stratum_transcoder/Cargo.toml +++ b/applications/tari_stratum_transcoder/Cargo.toml @@ -13,37 +13,37 @@ envlog = ["env_logger"] [dependencies] tari_app_grpc = { path = "../tari_app_grpc" } -tari_common = { path = "../../common" } -tari_core = { path = "../../base_layer/core", default-features = false, features = ["transactions"]} +tari_common = { path = "../../common" } +tari_core = { path = "../../base_layer/core", default-features = false, features = ["transactions"] } tari_crypto = "0.11.1" tari_utilities = "^0.3" + bincode = "1.3.1" -bytes = "0.5.6" +bytes = "0.5" chrono = "0.4.19" config = { version = "0.9.3" } derive-error = "0.0.4" env_logger = { version = "0.7.1", optional = true } futures = "0.3.5" hex = "0.4.2" -hyper = "0.13.7" +hyper = "0.14.12" jsonrpc = "0.11.0" log = { version = "0.4.8", features = ["std"] } rand = "0.7.2" -reqwest = {version = "0.10.8", features=["json"]} -serde = { version="1.0.106", features = ["derive"] } +reqwest = { version = "0.11", features = ["json"] } +serde = { version = "1.0.106", features = ["derive"] } serde_json = "1.0.57" structopt = { version = "0.3.13", default_features = false } -thiserror = "1.0.15" -tokio = "0.2.10" -tokio-macros = "0.2.5" -tonic = "0.2" +thiserror = "1.0.26" +tokio = { version = "^1.10", features = ["macros"] } +tonic = "0.5.2" tracing = "0.1" tracing-futures = "0.2" tracing-subscriber = "0.2" url = "2.1.1" [build-dependencies] -tonic-build = "0.2" +tonic-build = "0.5.2" [dev-dependencies] futures-test = "0.3.5" diff --git a/applications/tari_stratum_transcoder/src/main.rs b/applications/tari_stratum_transcoder/src/main.rs index d55d551b5b..f742c92d6d 100644 --- a/applications/tari_stratum_transcoder/src/main.rs +++ b/applications/tari_stratum_transcoder/src/main.rs @@ -41,7 +41,7 @@ use tari_app_grpc::tari_rpc as grpc; use tari_common::{configuration::bootstrap::ApplicationType, ConfigBootstrap, GlobalConfig}; use tokio::time::Duration; -#[tokio_macros::main] +#[tokio::main] async fn main() -> Result<(), StratumTranscoderProxyError> { let config = initialize()?; diff --git a/applications/test_faucet/Cargo.toml b/applications/test_faucet/Cargo.toml index 3ef3c8a4c1..f686ece098 100644 --- a/applications/test_faucet/Cargo.toml +++ b/applications/test_faucet/Cargo.toml @@ -7,11 +7,13 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +tari_crypto = "0.11.1" tari_utilities = "^0.3" +tari_common_types ={path="../../base_layer/common_types"} + +rand = "0.8" serde = { version = "1.0.97", features = ["derive"] } serde_json = "1.0" -rand = "0.8" -tari_crypto = "0.11.1" [dependencies.tari_core] version = "^0.9" @@ -20,6 +22,6 @@ default-features = false features = ["transactions", "avx2"] [dependencies.tokio] -version = "^0.2.10" +version = "^1.10" default-features = false -features = ["fs", "blocking", "stream", "rt-threaded", "macros", "io-util", "sync"] +features = ["fs", "rt-multi-thread", "macros", "io-util", "sync"] diff --git a/applications/test_faucet/src/main.rs b/applications/test_faucet/src/main.rs index 0aadee4d5e..65f75a4b5c 100644 --- a/applications/test_faucet/src/main.rs +++ b/applications/test_faucet/src/main.rs @@ -5,20 +5,22 @@ #![deny(unused_must_use)] #![deny(unreachable_patterns)] #![deny(unknown_lints)] -use serde::Serialize; + use std::{fs::File, io::Write}; -use tari_core::{ - tari_utilities::hex::Hex, - transactions::{ - helpers, - tari_amount::{MicroTari, T}, - transaction::{KernelFeatures, OutputFeatures, TransactionKernel, TransactionOutput}, - types::{Commitment, CryptoFactories, PrivateKey}, - }, -}; + +use serde::Serialize; use tari_crypto::script; use tokio::{sync::mpsc, task}; +use tari_common_types::types::{Commitment, PrivateKey}; +use tari_core::transactions::{ + helpers, + tari_amount::{MicroTari, T}, + transaction::{KernelFeatures, OutputFeatures, TransactionKernel, TransactionOutput}, + CryptoFactories, +}; +use tari_crypto::tari_utilities::hex::Hex; + const NUM_KEYS: usize = 4000; #[derive(Serialize)] @@ -32,7 +34,7 @@ struct Key { /// UTXO generation is pretty slow (esp range proofs), so we'll use async threads to speed things up. /// We'll use blocking thread tasks to do the CPU intensive utxo generation, and then push the results /// through a channel where a file-writer is waiting to persist the results to disk. -#[tokio::main(core_threads = 2, max_threads = 10)] +#[tokio::main(worker_threads = 2)] async fn main() -> Result<(), Box> { let num_keys: usize = std::env::args() .skip(1) @@ -52,7 +54,7 @@ async fn main() -> Result<(), Box> { // Use Rust's awesome Iterator trait to produce a sequence of values and output features. for (value, feature) in values.take(num_keys).zip(features.take(num_keys)) { let fc = factories.clone(); - let mut txc = tx.clone(); + let txc = tx.clone(); // Notice the `spawn(.. spawn_blocking)` nested call here. If we don't do this, we're basically queuing up // blocking tasks, `await`ing them to finish, and then queueing up the next one. In effect we're running things // synchronously. diff --git a/base_layer/common_types/Cargo.toml b/base_layer/common_types/Cargo.toml index 2b85d4e9a9..b70379c97f 100644 --- a/base_layer/common_types/Cargo.toml +++ b/base_layer/common_types/Cargo.toml @@ -11,4 +11,6 @@ futures = {version = "^0.3.1", features = ["async-await"] } rand = "0.8" tari_crypto = "0.11.1" serde = { version = "1.0.106", features = ["derive"] } -tokio = { version="^0.2", features = ["blocking", "time", "sync"] } +tokio = { version="^1.10", features = [ "time", "sync"] } +lazy_static = "1.4.0" +digest = "0.9.0" \ No newline at end of file diff --git a/base_layer/wallet/src/util/emoji.rs b/base_layer/common_types/src/emoji.rs similarity index 97% rename from base_layer/wallet/src/util/emoji.rs rename to base_layer/common_types/src/emoji.rs index 18ecdc174c..6d5b42aea8 100644 --- a/base_layer/wallet/src/util/emoji.rs +++ b/base_layer/common_types/src/emoji.rs @@ -20,12 +20,14 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::util::luhn::{checksum, is_valid}; +use crate::{ + luhn::{checksum, is_valid}, + types::PublicKey, +}; use std::{ collections::HashMap, fmt::{Display, Error, Formatter}, }; -use tari_core::transactions::types::PublicKey; use tari_crypto::tari_utilities::{ hex::{Hex, HexError}, ByteArray, @@ -70,7 +72,7 @@ lazy_static! { /// # Example /// /// ``` -/// use tari_wallet::util::emoji::EmojiId; +/// use tari_common_types::emoji::EmojiId; /// /// assert!(EmojiId::is_valid("🐎🍴🌷🌟💻🐖🐩🐾🌟🐬🎧🐌🏦🐳🐎🐝🐢🔋👕🎸👿🍒🐓🎉💔🌹🏆🐬💡🎳🚦🍹🎒")); /// let eid = EmojiId::from_hex("70350e09c474809209824c6e6888707b7dd09959aa227343b5106382b856f73a").unwrap(); @@ -170,8 +172,7 @@ pub struct EmojiIdError; #[cfg(test)] mod test { - use crate::util::emoji::EmojiId; - use tari_core::transactions::types::PublicKey; + use crate::{emoji::EmojiId, types::PublicKey}; use tari_crypto::tari_utilities::hex::Hex; #[test] diff --git a/base_layer/common_types/src/lib.rs b/base_layer/common_types/src/lib.rs index 03d3d25a62..df01ae7302 100644 --- a/base_layer/common_types/src/lib.rs +++ b/base_layer/common_types/src/lib.rs @@ -21,5 +21,10 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. pub mod chain_metadata; +pub mod emoji; +pub mod luhn; pub mod types; pub mod waiting_requests; + +#[macro_use] +extern crate lazy_static; diff --git a/base_layer/wallet/src/util/luhn.rs b/base_layer/common_types/src/luhn.rs similarity index 98% rename from base_layer/wallet/src/util/luhn.rs rename to base_layer/common_types/src/luhn.rs index 9a9996ef72..3225b42ebe 100644 --- a/base_layer/wallet/src/util/luhn.rs +++ b/base_layer/common_types/src/luhn.rs @@ -45,7 +45,7 @@ pub fn is_valid(arr: &[usize], dict_len: usize) -> bool { #[cfg(test)] mod test { - use crate::util::luhn::*; + use crate::luhn::{checksum, is_valid}; #[test] fn luhn_6() { diff --git a/base_layer/common_types/src/types.rs b/base_layer/common_types/src/types.rs deleted file mode 100644 index 99c2789cbc..0000000000 --- a/base_layer/common_types/src/types.rs +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2020. The Tari Project -// -// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the -// following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following -// disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the -// following disclaimer in the documentation and/or other materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote -// products derived from this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, -// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, -// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -pub const BLOCK_HASH_LENGTH: usize = 32; -pub type BlockHash = Vec; diff --git a/base_layer/common_types/src/types/bullet_rangeproofs.rs b/base_layer/common_types/src/types/bullet_rangeproofs.rs new file mode 100644 index 0000000000..a62dcd1228 --- /dev/null +++ b/base_layer/common_types/src/types/bullet_rangeproofs.rs @@ -0,0 +1,110 @@ +// Copyright 2019 The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use crate::types::HashDigest; +use digest::Digest; +use serde::{ + de::{self, Visitor}, + Deserialize, + Deserializer, + Serialize, + Serializer, +}; +use std::fmt; +use tari_crypto::tari_utilities::{hex::*, ByteArray, ByteArrayError, Hashable}; + +#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub struct BulletRangeProof(pub Vec); +/// Implement the hashing function for RangeProof for use in the MMR +impl Hashable for BulletRangeProof { + fn hash(&self) -> Vec { + HashDigest::new().chain(&self.0).finalize().to_vec() + } +} + +impl ByteArray for BulletRangeProof { + fn to_vec(&self) -> Vec { + self.0.clone() + } + + fn from_vec(v: &Vec) -> Result { + Ok(BulletRangeProof { 0: v.clone() }) + } + + fn from_bytes(bytes: &[u8]) -> Result { + Ok(BulletRangeProof { 0: bytes.to_vec() }) + } + + fn as_bytes(&self) -> &[u8] { + &self.0 + } +} + +impl From> for BulletRangeProof { + fn from(v: Vec) -> Self { + BulletRangeProof(v) + } +} + +impl fmt::Display for BulletRangeProof { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.to_hex()) + } +} + +impl Serialize for BulletRangeProof { + fn serialize(&self, serializer: S) -> Result + where S: Serializer { + if serializer.is_human_readable() { + self.to_hex().serialize(serializer) + } else { + serializer.serialize_bytes(self.as_bytes()) + } + } +} + +impl<'de> Deserialize<'de> for BulletRangeProof { + fn deserialize(deserializer: D) -> Result + where D: Deserializer<'de> { + struct RangeProofVisitor; + + impl<'de> Visitor<'de> for RangeProofVisitor { + type Value = BulletRangeProof; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a bulletproof range proof in binary format") + } + + fn visit_bytes(self, v: &[u8]) -> Result + where E: de::Error { + BulletRangeProof::from_bytes(v).map_err(E::custom) + } + } + + if deserializer.is_human_readable() { + let s = String::deserialize(deserializer)?; + BulletRangeProof::from_hex(&s).map_err(de::Error::custom) + } else { + deserializer.deserialize_bytes(RangeProofVisitor) + } + } +} diff --git a/base_layer/common_types/src/types/mod.rs b/base_layer/common_types/src/types/mod.rs new file mode 100644 index 0000000000..e379d2bbac --- /dev/null +++ b/base_layer/common_types/src/types/mod.rs @@ -0,0 +1,81 @@ +// Copyright 2020. The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use tari_crypto::{ + common::Blake256, + ristretto::{ + pedersen::{PedersenCommitment, PedersenCommitmentFactory}, + RistrettoComSig, + RistrettoPublicKey, + RistrettoSchnorr, + RistrettoSecretKey, + }, +}; + +use tari_crypto::ristretto::dalek_range_proof::DalekRangeProofService; + +mod bullet_rangeproofs; + +pub use bullet_rangeproofs::BulletRangeProof; + +pub const BLOCK_HASH_LENGTH: usize = 32; +pub type BlockHash = Vec; + +/// Define the explicit Signature implementation for the Tari base layer. A different signature scheme can be +/// employed by redefining this type. +pub type Signature = RistrettoSchnorr; +/// Define the explicit Commitment Signature implementation for the Tari base layer. +pub type ComSignature = RistrettoComSig; + +/// Define the explicit Commitment implementation for the Tari base layer. +pub type Commitment = PedersenCommitment; +pub type CommitmentFactory = PedersenCommitmentFactory; + +/// Define the explicit Public key implementation for the Tari base layer +pub type PublicKey = RistrettoPublicKey; + +/// Define the explicit Secret key implementation for the Tari base layer. +pub type PrivateKey = RistrettoSecretKey; +pub type BlindingFactor = RistrettoSecretKey; + +/// Define the hash function that will be used to produce a signature challenge +pub type SignatureHasher = Blake256; + +/// Specify the Hash function for general hashing +pub type HashDigest = Blake256; + +/// Specify the digest type for signature challenges +pub type Challenge = Blake256; + +/// The type of output that `Challenge` produces +pub type MessageHash = Vec; + +/// Define the data type that is used to store results of `HashDigest` +pub type HashOutput = Vec; + +pub const MAX_RANGE_PROOF_RANGE: usize = 64; // 2^64 + +/// Specify the range proof type +pub type RangeProofService = DalekRangeProofService; + +/// Specify the range proof +pub type RangeProof = BulletRangeProof; diff --git a/base_layer/common_types/src/waiting_requests.rs b/base_layer/common_types/src/waiting_requests.rs index a26119a5cb..67e6eed6ef 100644 --- a/base_layer/common_types/src/waiting_requests.rs +++ b/base_layer/common_types/src/waiting_requests.rs @@ -20,10 +20,9 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::channel::oneshot::Sender as OneshotSender; use rand::RngCore; use std::{collections::HashMap, sync::Arc, time::Instant}; -use tokio::sync::RwLock; +use tokio::sync::{oneshot::Sender as OneshotSender, RwLock}; pub type RequestKey = u64; diff --git a/base_layer/core/Cargo.toml b/base_layer/core/Cargo.toml index 1212b2a40c..25e8ceb197 100644 --- a/base_layer/core/Cargo.toml +++ b/base_layer/core/Cargo.toml @@ -35,27 +35,28 @@ bincode = "1.1.4" bitflags = "1.0.4" blake2 = "^0.9.0" sha3 = "0.9" -bytes = "0.4.12" +bytes = "0.5" chrono = { version = "0.4.6", features = ["serde"]} croaring = { version = "=0.4.5", optional = true } digest = "0.9.0" -futures = {version = "^0.3.1", features = ["async-await"] } +futures = {version = "^0.3.16", features = ["async-await"] } fs2 = "0.3.0" hex = "0.4.2" +lazy_static = "1.4.0" lmdb-zero = "0.4.4" log = "0.4" monero = { version = "^0.13.0", features= ["serde_support"], optional = true } newtype-ops = "0.1.4" num = "0.3" -prost = "0.6.1" -prost-types = "0.6.1" +prost = "0.8.0" +prost-types = "0.8.0" rand = "0.8" randomx-rs = { version = "0.5.0", optional = true } serde = { version = "1.0.106", features = ["derive"] } serde_json = "1.0" strum_macros = "0.17.1" -thiserror = "1.0.20" -tokio = { version="^0.2", features = ["blocking", "time", "sync"] } +thiserror = "1.0.26" +tokio = { version="^1.10", features = [ "time", "sync", "macros"] } ttl_cache = "0.5.1" uint = { version = "0.9", default-features = false } num-format = "0.4.0" @@ -70,7 +71,6 @@ tari_test_utils = { version = "^0.9", path = "../../infrastructure/test_utils" } config = { version = "0.9.3" } env_logger = "0.7.0" tempfile = "3.1.0" -tokio-macros = "0.2.4" [build-dependencies] tari_common = { version = "^0.9", path="../../common", features = ["build"]} diff --git a/base_layer/core/src/base_node/chain_metadata_service/initializer.rs b/base_layer/core/src/base_node/chain_metadata_service/initializer.rs index 1310f22702..2700dc9d01 100644 --- a/base_layer/core/src/base_node/chain_metadata_service/initializer.rs +++ b/base_layer/core/src/base_node/chain_metadata_service/initializer.rs @@ -20,10 +20,8 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use super::{service::ChainMetadataService, LOG_TARGET}; +use super::service::ChainMetadataService; use crate::base_node::{chain_metadata_service::handle::ChainMetadataHandle, comms_interface::LocalNodeCommsInterface}; -use futures::{future, pin_mut}; -use log::*; use tari_comms::connectivity::ConnectivityRequester; use tari_p2p::services::liveness::LivenessHandle; use tari_service_framework::{async_trait, ServiceInitializationError, ServiceInitializer, ServiceInitializerContext}; @@ -40,15 +38,12 @@ impl ServiceInitializer for ChainMetadataServiceInitializer { let handle = ChainMetadataHandle::new(publisher.clone()); context.register_handle(handle); - context.spawn_when_ready(|handles| async move { + context.spawn_until_shutdown(|handles| { let liveness = handles.expect_handle::(); let base_node = handles.expect_handle::(); let connectivity = handles.expect_handle::(); - let service_run = ChainMetadataService::new(liveness, base_node, connectivity, publisher).run(); - pin_mut!(service_run); - future::select(service_run, handles.get_shutdown_signal()).await; - info!(target: LOG_TARGET, "ChainMetadataService has shut down"); + ChainMetadataService::new(liveness, base_node, connectivity, publisher).run() }); Ok(()) diff --git a/base_layer/core/src/base_node/chain_metadata_service/service.rs b/base_layer/core/src/base_node/chain_metadata_service/service.rs index 6c7da96719..61110bc147 100644 --- a/base_layer/core/src/base_node/chain_metadata_service/service.rs +++ b/base_layer/core/src/base_node/chain_metadata_service/service.rs @@ -29,7 +29,6 @@ use crate::{ chain_storage::BlockAddResult, proto::base_node as proto, }; -use futures::stream::StreamExt; use log::*; use num_format::{Locale, ToFormattedString}; use prost::Message; @@ -75,9 +74,9 @@ impl ChainMetadataService { /// Run the service pub async fn run(mut self) { - let mut liveness_event_stream = self.liveness.get_event_stream().fuse(); - let mut block_event_stream = self.base_node.get_block_event_stream().fuse(); - let mut connectivity_events = self.connectivity.get_event_subscription().fuse(); + let mut liveness_event_stream = self.liveness.get_event_stream(); + let mut block_event_stream = self.base_node.get_block_event_stream(); + let mut connectivity_events = self.connectivity.get_event_subscription(); log_if_error!( target: LOG_TARGET, @@ -86,47 +85,36 @@ impl ChainMetadataService { ); loop { - futures::select! { - block_event = block_event_stream.select_next_some() => { - if let Ok(block_event) = block_event { - log_if_error!( - level: debug, - target: LOG_TARGET, - "Failed to handle block event because '{}'", - self.handle_block_event(&block_event).await - ); - } + tokio::select! { + Ok(block_event) = block_event_stream.recv() => { + log_if_error!( + level: debug, + target: LOG_TARGET, + "Failed to handle block event because '{}'", + self.handle_block_event(&block_event).await + ); }, - liveness_event = liveness_event_stream.select_next_some() => { - if let Ok(event) = liveness_event { - log_if_error!( - target: LOG_TARGET, - "Failed to handle liveness event because '{}'", - self.handle_liveness_event(&*event).await - ); - } + Ok(event) = liveness_event_stream.recv() => { + log_if_error!( + target: LOG_TARGET, + "Failed to handle liveness event because '{}'", + self.handle_liveness_event(&*event).await + ); }, - event = connectivity_events.select_next_some() => { - if let Ok(event) = event { - self.handle_connectivity_event(&*event); - } - } - - complete => { - info!(target: LOG_TARGET, "ChainStateSyncService is exiting because all tasks have completed"); - break; + Ok(event) = connectivity_events.recv() => { + self.handle_connectivity_event(event); } } } } - fn handle_connectivity_event(&mut self, event: &ConnectivityEvent) { + fn handle_connectivity_event(&mut self, event: ConnectivityEvent) { use ConnectivityEvent::*; match event { PeerDisconnected(node_id) | ManagedPeerDisconnected(node_id) | PeerBanned(node_id) => { - if let Some(pos) = self.peer_chain_metadata.iter().position(|p| &p.node_id == node_id) { + if let Some(pos) = self.peer_chain_metadata.iter().position(|p| p.node_id == node_id) { debug!( target: LOG_TARGET, "Removing disconnected/banned peer `{}` from chain metadata list ", node_id @@ -164,7 +152,7 @@ impl ChainMetadataService { async fn handle_liveness_event(&mut self, event: &LivenessEvent) -> Result<(), ChainMetadataSyncError> { match event { - // Received a ping, check if our neighbour sent it and it contains ChainMetadata + // Received a ping, check if it contains ChainMetadata LivenessEvent::ReceivedPing(event) => { trace!( target: LOG_TARGET, @@ -172,6 +160,7 @@ impl ChainMetadataService { event.node_id ); self.collect_chain_state_from_ping(&event.node_id, &event.metadata)?; + self.send_chain_metadata_to_event_publisher().await?; }, // Received a pong, check if our neighbour sent it and it contains ChainMetadata LivenessEvent::ReceivedPong(event) => { @@ -181,11 +170,7 @@ impl ChainMetadataService { event.node_id ); self.collect_chain_state_from_pong(&event.node_id, &event.metadata)?; - - // All peers have responded in this round, send the chain metadata to the base node service - if self.peer_chain_metadata.len() >= self.peer_chain_metadata.capacity() { - self.flush_chain_metadata_to_event_publisher().await?; - } + self.send_chain_metadata_to_event_publisher().await?; }, // New ping round has begun LivenessEvent::PingRoundBroadcast(num_peers) => { @@ -193,11 +178,9 @@ impl ChainMetadataService { target: LOG_TARGET, "New chain metadata round sent to {} peer(s)", num_peers ); - // If we have chain metadata to send to the base node service, send them now - // because the next round of pings is happening. - self.flush_chain_metadata_to_event_publisher().await?; // Ensure that we're waiting for the correct amount of peers to respond // and have allocated space for their replies + self.resize_chainstate_buffer(*num_peers); }, } @@ -205,13 +188,13 @@ impl ChainMetadataService { Ok(()) } - async fn flush_chain_metadata_to_event_publisher(&mut self) -> Result<(), ChainMetadataSyncError> { - let chain_metadata = self.peer_chain_metadata.drain(..).collect::>(); - + async fn send_chain_metadata_to_event_publisher(&mut self) -> Result<(), ChainMetadataSyncError> { // send only fails if there are no subscribers. let _ = self .event_publisher - .send(Arc::new(ChainMetadataEvent::PeerChainMetadataReceived(chain_metadata))); + .send(Arc::new(ChainMetadataEvent::PeerChainMetadataReceived( + self.peer_chain_metadata.clone(), + ))); Ok(()) } @@ -289,7 +272,6 @@ impl ChainMetadataService { self.peer_chain_metadata .push(PeerChainMetadata::new(node_id.clone(), chain_metadata)); - Ok(()) } } @@ -298,6 +280,7 @@ impl ChainMetadataService { mod test { use super::*; use crate::base_node::comms_interface::{CommsInterfaceError, NodeCommsRequest, NodeCommsResponse}; + use futures::StreamExt; use std::convert::TryInto; use tari_comms::test_utils::{ mocks::{create_connectivity_mock, ConnectivityManagerMockState}, @@ -361,7 +344,7 @@ mod test { ) } - #[tokio_macros::test] + #[tokio::test] async fn update_liveness_chain_metadata() { let (mut service, liveness_mock_state, _, mut base_node_receiver) = setup(); @@ -370,11 +353,11 @@ mod test { let chain_metadata = proto_chain_metadata.clone().try_into().unwrap(); task::spawn(async move { - let base_node_req = base_node_receiver.select_next_some().await; - let (_req, reply_tx) = base_node_req.split(); - reply_tx - .send(Ok(NodeCommsResponse::ChainMetadata(chain_metadata))) - .unwrap(); + if let Some(base_node_req) = base_node_receiver.next().await { + base_node_req + .reply(Ok(NodeCommsResponse::ChainMetadata(chain_metadata))) + .unwrap(); + } }); service.update_liveness_chain_metadata().await.unwrap(); @@ -387,7 +370,7 @@ mod test { let chain_metadata = proto::ChainMetadata::decode(data.as_slice()).unwrap(); assert_eq!(chain_metadata.height_of_longest_chain, Some(123)); } - #[tokio_macros::test_basic] + #[tokio::test] async fn handle_liveness_event_ok() { let (mut service, _, _, _) = setup(); @@ -416,7 +399,7 @@ mod test { ); } - #[tokio_macros::test_basic] + #[tokio::test] async fn handle_liveness_event_banned_peer() { let (mut service, _, _, _) = setup(); @@ -442,7 +425,7 @@ mod test { .peer_chain_metadata .iter() .any(|p| &p.node_id == nodes[0].node_id())); - service.handle_connectivity_event(&ConnectivityEvent::PeerBanned(nodes[0].node_id().clone())); + service.handle_connectivity_event(ConnectivityEvent::PeerBanned(nodes[0].node_id().clone())); // Check that banned peer was removed assert!(service .peer_chain_metadata @@ -450,7 +433,7 @@ mod test { .all(|p| &p.node_id != nodes[0].node_id())); } - #[tokio_macros::test_basic] + #[tokio::test] async fn handle_liveness_event_no_metadata() { let (mut service, _, _, _) = setup(); @@ -468,7 +451,7 @@ mod test { assert_eq!(service.peer_chain_metadata.len(), 0); } - #[tokio_macros::test_basic] + #[tokio::test] async fn handle_liveness_event_bad_metadata() { let (mut service, _, _, _) = setup(); diff --git a/base_layer/core/src/base_node/comms_interface/comms_request.rs b/base_layer/core/src/base_node/comms_interface/comms_request.rs index 2eec332b58..eef287d8f1 100644 --- a/base_layer/core/src/base_node/comms_interface/comms_request.rs +++ b/base_layer/core/src/base_node/comms_interface/comms_request.rs @@ -20,14 +20,10 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{ - blocks::NewBlockTemplate, - chain_storage::MmrTree, - proof_of_work::PowAlgorithm, - transactions::types::{Commitment, HashOutput, Signature}, -}; +use crate::{blocks::NewBlockTemplate, chain_storage::MmrTree, proof_of_work::PowAlgorithm}; use serde::{Deserialize, Serialize}; use std::fmt::{Display, Error, Formatter}; +use tari_common_types::types::{Commitment, HashOutput, Signature}; use tari_crypto::tari_utilities::hex::Hex; /// A container for the parameters required for a FetchMmrState request. diff --git a/base_layer/core/src/base_node/comms_interface/comms_response.rs b/base_layer/core/src/base_node/comms_interface/comms_response.rs index e275dc9c5a..8f7ec1b9e5 100644 --- a/base_layer/core/src/base_node/comms_interface/comms_response.rs +++ b/base_layer/core/src/base_node/comms_interface/comms_response.rs @@ -24,14 +24,11 @@ use crate::{ blocks::{block_header::BlockHeader, Block, NewBlockTemplate}, chain_storage::HistoricalBlock, proof_of_work::Difficulty, - transactions::{ - transaction::{TransactionKernel, TransactionOutput}, - types::HashOutput, - }, + transactions::transaction::{TransactionKernel, TransactionOutput}, }; use serde::{Deserialize, Serialize}; use std::fmt::{self, Display, Formatter}; -use tari_common_types::chain_metadata::ChainMetadata; +use tari_common_types::{chain_metadata::ChainMetadata, types::HashOutput}; /// API Response enum #[derive(Debug, Serialize, Deserialize, Clone)] diff --git a/base_layer/core/src/base_node/comms_interface/inbound_handlers.rs b/base_layer/core/src/base_node/comms_interface/inbound_handlers.rs index 21ebc49201..e0d2c1268a 100644 --- a/base_layer/core/src/base_node/comms_interface/inbound_handlers.rs +++ b/base_layer/core/src/base_node/comms_interface/inbound_handlers.rs @@ -34,7 +34,7 @@ use crate::{ consensus::{ConsensusConstants, ConsensusManager}, mempool::{async_mempool, Mempool}, proof_of_work::{Difficulty, PowAlgorithm}, - transactions::{transaction::TransactionKernel, types::HashOutput}, + transactions::transaction::TransactionKernel, }; use log::*; use std::{ @@ -42,7 +42,7 @@ use std::{ sync::Arc, }; use strum_macros::Display; -use tari_common_types::types::BlockHash; +use tari_common_types::types::{BlockHash, HashOutput}; use tari_comms::peer_manager::NodeId; use tari_crypto::tari_utilities::{hash::Hashable, hex::Hex}; use tokio::sync::Semaphore; diff --git a/base_layer/core/src/base_node/comms_interface/local_interface.rs b/base_layer/core/src/base_node/comms_interface/local_interface.rs index a0f5bcf2c3..0a270a78e2 100644 --- a/base_layer/core/src/base_node/comms_interface/local_interface.rs +++ b/base_layer/core/src/base_node/comms_interface/local_interface.rs @@ -31,10 +31,7 @@ use crate::{ blocks::{Block, BlockHeader, NewBlockTemplate}, chain_storage::HistoricalBlock, proof_of_work::PowAlgorithm, - transactions::{ - transaction::TransactionKernel, - types::{Commitment, HashOutput, Signature}, - }, + transactions::transaction::TransactionKernel, }; use std::sync::Arc; use tari_common_types::{chain_metadata::ChainMetadata, types::BlockHash}; @@ -47,6 +44,7 @@ use crate::{ base_node::comms_interface::comms_request::GetNewBlockTemplateRequest, transactions::transaction::TransactionOutput, }; +use tari_common_types::types::{Commitment, HashOutput, Signature}; /// The InboundNodeCommsInterface provides an interface to request information from the current local node by other /// internal services. diff --git a/base_layer/core/src/base_node/comms_interface/outbound_interface.rs b/base_layer/core/src/base_node/comms_interface/outbound_interface.rs index 753fe802d8..433b5325de 100644 --- a/base_layer/core/src/base_node/comms_interface/outbound_interface.rs +++ b/base_layer/core/src/base_node/comms_interface/outbound_interface.rs @@ -24,13 +24,16 @@ use crate::{ base_node::comms_interface::{error::CommsInterfaceError, NodeCommsRequest, NodeCommsResponse}, blocks::{block_header::BlockHeader, NewBlock}, chain_storage::HistoricalBlock, - transactions::{transaction::TransactionOutput, types::HashOutput}, + transactions::transaction::TransactionOutput, }; -use futures::channel::mpsc::UnboundedSender; use log::*; -use tari_common_types::{chain_metadata::ChainMetadata, types::BlockHash}; +use tari_common_types::{ + chain_metadata::ChainMetadata, + types::{BlockHash, HashOutput}, +}; use tari_comms::peer_manager::NodeId; use tari_service_framework::{reply_channel::SenderService, Service}; +use tokio::sync::mpsc::UnboundedSender; pub const LOG_TARGET: &str = "c::bn::comms_interface::outbound_interface"; @@ -234,10 +237,8 @@ impl OutboundNodeCommsInterface { new_block: NewBlock, exclude_peers: Vec, ) -> Result<(), CommsInterfaceError> { - self.block_sender - .unbounded_send((new_block, exclude_peers)) - .map_err(|err| { - CommsInterfaceError::InternalChannelError(format!("Failed to send on block_sender: {}", err)) - }) + self.block_sender.send((new_block, exclude_peers)).map_err(|err| { + CommsInterfaceError::InternalChannelError(format!("Failed to send on block_sender: {}", err)) + }) } } diff --git a/base_layer/core/src/base_node/proto/request.rs b/base_layer/core/src/base_node/proto/request.rs index 45195f20b8..bf766bffc2 100644 --- a/base_layer/core/src/base_node/proto/request.rs +++ b/base_layer/core/src/base_node/proto/request.rs @@ -32,9 +32,9 @@ use crate::{ HashOutputs, }, }, - transactions::types::{Commitment, HashOutput, Signature}, }; use std::convert::{From, TryFrom, TryInto}; +use tari_common_types::types::{Commitment, HashOutput, Signature}; use tari_crypto::tari_utilities::ByteArrayError; //---------------------------------- BaseNodeRequest --------------------------------------------// diff --git a/base_layer/core/src/base_node/proto/wallet_rpc.rs b/base_layer/core/src/base_node/proto/wallet_rpc.rs index 94f2f2d7f6..3183128b55 100644 --- a/base_layer/core/src/base_node/proto/wallet_rpc.rs +++ b/base_layer/core/src/base_node/proto/wallet_rpc.rs @@ -23,7 +23,6 @@ use crate::{ crypto::tari_utilities::ByteArrayError, proto::{base_node as proto, types}, - transactions::types::Signature, }; use serde::{Deserialize, Serialize}; @@ -31,7 +30,7 @@ use std::{ convert::TryFrom, fmt::{Display, Error, Formatter}, }; -use tari_common_types::types::BlockHash; +use tari_common_types::types::{BlockHash, Signature}; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct TxSubmissionResponse { diff --git a/base_layer/core/src/base_node/rpc/service.rs b/base_layer/core/src/base_node/rpc/service.rs index c50600ea9c..dbd1b141e4 100644 --- a/base_layer/core/src/base_node/rpc/service.rs +++ b/base_layer/core/src/base_node/rpc/service.rs @@ -40,9 +40,10 @@ use crate::{ }, types::{Signature as SignatureProto, Transaction as TransactionProto}, }, - transactions::{transaction::Transaction, types::Signature}, + transactions::transaction::Transaction, }; use std::convert::TryFrom; +use tari_common_types::types::Signature; use tari_comms::protocol::rpc::{Request, Response, RpcStatus}; const LOG_TARGET: &str = "c::base_node::rpc"; @@ -230,7 +231,7 @@ impl BaseNodeWalletService for BaseNodeWalletRpc // Determine if we are synced let status_watch = state_machine.get_status_info_watch(); - let is_synced = match (*status_watch.borrow()).state_info { + let is_synced = match status_watch.borrow().state_info { StateInfo::Listening(li) => li.is_synced(), _ => false, }; diff --git a/base_layer/core/src/base_node/service/initializer.rs b/base_layer/core/src/base_node/service/initializer.rs index ae6be0b519..11132db138 100644 --- a/base_layer/core/src/base_node/service/initializer.rs +++ b/base_layer/core/src/base_node/service/initializer.rs @@ -33,7 +33,7 @@ use crate::{ proto as shared_protos, proto::base_node as proto, }; -use futures::{channel::mpsc, future, Stream, StreamExt}; +use futures::{future, Stream, StreamExt}; use log::*; use std::{convert::TryFrom, sync::Arc}; use tari_comms_dht::Dht; @@ -50,7 +50,7 @@ use tari_service_framework::{ ServiceInitializer, ServiceInitializerContext, }; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc}; const LOG_TARGET: &str = "c::bn::service::initializer"; const SUBSCRIPTION_LABEL: &str = "Base Node"; @@ -151,7 +151,7 @@ where T: BlockchainBackend + 'static let inbound_block_stream = self.inbound_block_stream(); // Connect InboundNodeCommsInterface and OutboundNodeCommsInterface to BaseNodeService let (outbound_request_sender_service, outbound_request_stream) = reply_channel::unbounded(); - let (outbound_block_sender_service, outbound_block_stream) = mpsc::unbounded(); + let (outbound_block_sender_service, outbound_block_stream) = mpsc::unbounded_channel(); let (local_request_sender_service, local_request_stream) = reply_channel::unbounded(); let (local_block_sender_service, local_block_stream) = reply_channel::unbounded(); let outbound_nci = diff --git a/base_layer/core/src/base_node/service/service.rs b/base_layer/core/src/base_node/service/service.rs index db96b6ed9e..1d66cbf1b1 100644 --- a/base_layer/core/src/base_node/service/service.rs +++ b/base_layer/core/src/base_node/service/service.rs @@ -38,16 +38,7 @@ use crate::{ proto as shared_protos, proto::{base_node as proto, base_node::base_node_service_request::Request}, }; -use futures::{ - channel::{ - mpsc::{channel, Receiver, Sender, UnboundedReceiver}, - oneshot::Sender as OneshotSender, - }, - pin_mut, - stream::StreamExt, - SinkExt, - Stream, -}; +use futures::{pin_mut, stream::StreamExt, Stream}; use log::*; use rand::rngs::OsRng; use std::{convert::TryInto, sync::Arc, time::Duration}; @@ -64,7 +55,14 @@ use tari_comms_dht::{ use tari_crypto::tari_utilities::hex::Hex; use tari_p2p::{domain_message::DomainMessage, tari_message::TariMessageType}; use tari_service_framework::reply_channel::RequestContext; -use tokio::task; +use tokio::{ + sync::{ + mpsc, + mpsc::{Receiver, Sender, UnboundedReceiver}, + oneshot::Sender as OneshotSender, + }, + task, +}; const LOG_TARGET: &str = "c::bn::base_node_service::service"; @@ -134,7 +132,7 @@ where B: BlockchainBackend + 'static config: BaseNodeServiceConfig, state_machine_handle: StateMachineHandle, ) -> Self { - let (timeout_sender, timeout_receiver) = channel(100); + let (timeout_sender, timeout_receiver) = mpsc::channel(100); Self { outbound_message_service, inbound_nch, @@ -162,7 +160,7 @@ where B: BlockchainBackend + 'static { let outbound_request_stream = streams.outbound_request_stream.fuse(); pin_mut!(outbound_request_stream); - let outbound_block_stream = streams.outbound_block_stream.fuse(); + let outbound_block_stream = streams.outbound_block_stream; pin_mut!(outbound_block_stream); let inbound_request_stream = streams.inbound_request_stream.fuse(); pin_mut!(inbound_request_stream); @@ -177,53 +175,52 @@ where B: BlockchainBackend + 'static let timeout_receiver_stream = self .timeout_receiver_stream .take() - .expect("Base Node Service initialized without timeout_receiver_stream") - .fuse(); + .expect("Base Node Service initialized without timeout_receiver_stream"); pin_mut!(timeout_receiver_stream); loop { - futures::select! { + tokio::select! { // Outbound request messages from the OutboundNodeCommsInterface - outbound_request_context = outbound_request_stream.select_next_some() => { + Some(outbound_request_context) = outbound_request_stream.next() => { self.spawn_handle_outbound_request(outbound_request_context); }, // Outbound block messages from the OutboundNodeCommsInterface - (block, excluded_peers) = outbound_block_stream.select_next_some() => { + Some((block, excluded_peers)) = outbound_block_stream.recv() => { self.spawn_handle_outbound_block(block, excluded_peers); }, // Incoming request messages from the Comms layer - domain_msg = inbound_request_stream.select_next_some() => { + Some(domain_msg) = inbound_request_stream.next() => { self.spawn_handle_incoming_request(domain_msg); }, // Incoming response messages from the Comms layer - domain_msg = inbound_response_stream.select_next_some() => { + Some(domain_msg) = inbound_response_stream.next() => { self.spawn_handle_incoming_response(domain_msg); }, // Timeout events for waiting requests - timeout_request_key = timeout_receiver_stream.select_next_some() => { + Some(timeout_request_key) = timeout_receiver_stream.recv() => { self.spawn_handle_request_timeout(timeout_request_key); }, // Incoming block messages from the Comms layer - block_msg = inbound_block_stream.select_next_some() => { + Some(block_msg) = inbound_block_stream.next() => { self.spawn_handle_incoming_block(block_msg).await; } // Incoming local request messages from the LocalNodeCommsInterface and other local services - local_request_context = local_request_stream.select_next_some() => { + Some(local_request_context) = local_request_stream.next() => { self.spawn_handle_local_request(local_request_context); }, // Incoming local block messages from the LocalNodeCommsInterface e.g. miner and block sync - local_block_context = local_block_stream.select_next_some() => { + Some(local_block_context) = local_block_stream.next() => { self.spawn_handle_local_block(local_block_context); }, - complete => { - info!(target: LOG_TARGET, "Base Node service shutting down"); + else => { + info!(target: LOG_TARGET, "Base Node service shutting down because all streams ended"); break; } } @@ -646,9 +643,9 @@ async fn handle_request_timeout( Ok(()) } -fn spawn_request_timeout(mut timeout_sender: Sender, request_key: RequestKey, timeout: Duration) { +fn spawn_request_timeout(timeout_sender: Sender, request_key: RequestKey, timeout: Duration) { task::spawn(async move { - tokio::time::delay_for(timeout).await; + tokio::time::sleep(timeout).await; let _ = timeout_sender.send(request_key).await; }); } diff --git a/base_layer/core/src/base_node/state_machine_service/initializer.rs b/base_layer/core/src/base_node/state_machine_service/initializer.rs index a6d4c73a0c..c58d62000f 100644 --- a/base_layer/core/src/base_node/state_machine_service/initializer.rs +++ b/base_layer/core/src/base_node/state_machine_service/initializer.rs @@ -20,6 +20,14 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use std::sync::Arc; + +use log::*; +use tokio::sync::{broadcast, watch}; + +use tari_comms::{connectivity::ConnectivityRequester, PeerManager}; +use tari_service_framework::{async_trait, ServiceInitializationError, ServiceInitializer, ServiceInitializerContext}; + use crate::{ base_node::{ chain_metadata_service::ChainMetadataHandle, @@ -35,13 +43,8 @@ use crate::{ chain_storage::{async_db::AsyncBlockchainDb, BlockchainBackend}, consensus::ConsensusManager, proof_of_work::randomx_factory::RandomXFactory, - transactions::types::CryptoFactories, + transactions::CryptoFactories, }; -use log::*; -use std::sync::Arc; -use tari_comms::{connectivity::ConnectivityRequester, PeerManager}; -use tari_service_framework::{async_trait, ServiceInitializationError, ServiceInitializer, ServiceInitializerContext}; -use tokio::sync::{broadcast, watch}; const LOG_TARGET: &str = "c::bn::state_machine_service::initializer"; @@ -98,7 +101,8 @@ where B: BlockchainBackend + 'static let connectivity = handles.expect_handle::(); let peer_manager = handles.expect_handle::>(); - let sync_validators = SyncValidators::full_consensus(rules.clone(), factories); + let sync_validators = + SyncValidators::full_consensus(rules.clone(), factories, config.bypass_range_proof_verification); let max_randomx_vms = config.max_randomx_vms; let node = BaseNodeStateMachine::new( diff --git a/base_layer/core/src/base_node/state_machine_service/state_machine.rs b/base_layer/core/src/base_node/state_machine_service/state_machine.rs index 3c8e33ee06..ec99ccc989 100644 --- a/base_layer/core/src/base_node/state_machine_service/state_machine.rs +++ b/base_layer/core/src/base_node/state_machine_service/state_machine.rs @@ -52,6 +52,7 @@ pub struct BaseNodeStateMachineConfig { pub pruning_horizon: u64, pub max_randomx_vms: usize, pub blocks_behind_before_considered_lagging: u64, + pub bypass_range_proof_verification: bool, } /// A Tari full node, aka Base Node. @@ -158,7 +159,7 @@ impl BaseNodeStateMachine { state_info: self.info.clone(), }; - if let Err(e) = self.status_event_sender.broadcast(status) { + if let Err(e) = self.status_event_sender.send(status) { debug!(target: LOG_TARGET, "Error broadcasting a StatusEvent update: {}", e); } } diff --git a/base_layer/core/src/base_node/state_machine_service/states/block_sync.rs b/base_layer/core/src/base_node/state_machine_service/states/block_sync.rs index 689c9a8316..5c7710c371 100644 --- a/base_layer/core/src/base_node/state_machine_service/states/block_sync.rs +++ b/base_layer/core/src/base_node/state_machine_service/states/block_sync.rs @@ -67,7 +67,7 @@ impl BlockSync { let status_event_sender = shared.status_event_sender.clone(); let bootstrapped = shared.is_bootstrapped(); - let _ = status_event_sender.broadcast(StatusInfo { + let _ = status_event_sender.send(StatusInfo { bootstrapped, state_info: StateInfo::BlockSyncStarting, }); @@ -80,7 +80,7 @@ impl BlockSync { false.into(), )); - let _ = status_event_sender.broadcast(StatusInfo { + let _ = status_event_sender.send(StatusInfo { bootstrapped, state_info: StateInfo::BlockSync(BlockSyncInfo { tip_height: remote_tip_height, diff --git a/base_layer/core/src/base_node/state_machine_service/states/events_and_states.rs b/base_layer/core/src/base_node/state_machine_service/states/events_and_states.rs index 4e864624bc..0ef1f6e99b 100644 --- a/base_layer/core/src/base_node/state_machine_service/states/events_and_states.rs +++ b/base_layer/core/src/base_node/state_machine_service/states/events_and_states.rs @@ -160,7 +160,7 @@ impl Display for BaseNodeState { #[derive(Debug, Clone, PartialEq)] pub enum StateInfo { StartUp, - HeaderSync(BlockSyncInfo), + HeaderSync(Option), HorizonSync(HorizonSyncInfo), BlockSyncStarting, BlockSync(BlockSyncInfo), @@ -169,15 +169,12 @@ pub enum StateInfo { impl StateInfo { pub fn short_desc(&self) -> String { + use StateInfo::*; match self { - Self::StartUp => "Starting up".to_string(), - Self::HeaderSync(info) => format!( - "Syncing headers: {}/{} ({:.0}%)", - info.local_height, - info.tip_height, - info.local_height as f64 / info.tip_height as f64 * 100.0 - ), - Self::HorizonSync(info) => match info.status { + StartUp => "Starting up".to_string(), + HeaderSync(None) => "Starting header sync".to_string(), + HeaderSync(Some(info)) => format!("Syncing headers: {}", info.sync_progress_string()), + HorizonSync(info) => match info.status { HorizonSyncStatus::Starting => "Starting horizon sync".to_string(), HorizonSyncStatus::Kernels(current, total) => format!( "Syncing kernels: {}/{} ({:.0}%)", @@ -193,18 +190,16 @@ impl StateInfo { ), HorizonSyncStatus::Finalizing => "Finalizing horizon sync".to_string(), }, - Self::BlockSync(info) => format!( - "Syncing blocks with {}: {}/{} ({:.0}%) ", + BlockSync(info) => format!( + "Syncing blocks: ({}) {}", info.sync_peers .first() - .map(|s| s.short_str()) + .map(|n| n.short_str()) .unwrap_or_else(|| "".to_string()), - info.local_height, - info.tip_height, - info.local_height as f64 / info.tip_height as f64 * 100.0 + info.sync_progress_string() ), - Self::Listening(_) => "Listening".to_string(), - Self::BlockSyncStarting => "Starting block sync".to_string(), + Listening(_) => "Listening".to_string(), + BlockSyncStarting => "Starting block sync".to_string(), } } @@ -226,13 +221,15 @@ impl StateInfo { impl Display for StateInfo { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { + use StateInfo::*; match self { - Self::StartUp => write!(f, "Node starting up"), - Self::HeaderSync(info) => write!(f, "Synchronizing block headers: {}", info), - Self::HorizonSync(info) => write!(f, "Synchronizing horizon state: {}", info), - Self::BlockSync(info) => write!(f, "Synchronizing blocks: {}", info), - Self::Listening(info) => write!(f, "Listening: {}", info), - Self::BlockSyncStarting => write!(f, "Synchronizing blocks: Starting"), + StartUp => write!(f, "Node starting up"), + HeaderSync(Some(info)) => write!(f, "Synchronizing block headers: {}", info), + HeaderSync(None) => write!(f, "Synchronizing block headers: Starting"), + HorizonSync(info) => write!(f, "Synchronizing horizon state: {}", info), + BlockSync(info) => write!(f, "Synchronizing blocks: {}", info), + Listening(info) => write!(f, "Listening: {}", info), + BlockSyncStarting => write!(f, "Synchronizing blocks: Starting"), } } } @@ -282,15 +279,24 @@ impl BlockSyncInfo { sync_peers, } } + + pub fn sync_progress_string(&self) -> String { + format!( + "{}/{} ({:.0}%)", + self.local_height, + self.tip_height, + (self.local_height as f64 / self.tip_height as f64 * 100.0) + ) + } } impl Display for BlockSyncInfo { - fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { - fmt.write_str("Syncing from the following peers: \n")?; + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { + writeln!(f, "Syncing from the following peers:")?; for peer in &self.sync_peers { - fmt.write_str(&format!("{}\n", peer))?; + writeln!(f, "{}", peer)?; } - fmt.write_str(&format!("Syncing {}/{}\n", self.local_height, self.tip_height)) + writeln!(f, "Syncing {}", self.sync_progress_string()) } } diff --git a/base_layer/core/src/base_node/state_machine_service/states/header_sync.rs b/base_layer/core/src/base_node/state_machine_service/states/header_sync.rs index 2acfd53206..68f663d71f 100644 --- a/base_layer/core/src/base_node/state_machine_service/states/header_sync.rs +++ b/base_layer/core/src/base_node/state_machine_service/states/header_sync.rs @@ -74,14 +74,15 @@ impl HeaderSync { let status_event_sender = shared.status_event_sender.clone(); let bootstrapped = shared.is_bootstrapped(); - synchronizer.on_progress(move |current_height, remote_tip_height, sync_peers| { - let _ = status_event_sender.broadcast(StatusInfo { + synchronizer.on_progress(move |details, sync_peers| { + let details = details.map(|(current_height, remote_tip_height)| BlockSyncInfo { + tip_height: remote_tip_height, + local_height: current_height, + sync_peers: sync_peers.to_vec(), + }); + let _ = status_event_sender.send(StatusInfo { bootstrapped, - state_info: StateInfo::HeaderSync(BlockSyncInfo { - tip_height: remote_tip_height, - local_height: current_height, - sync_peers: sync_peers.to_vec(), - }), + state_info: StateInfo::HeaderSync(details), }); }); diff --git a/base_layer/core/src/base_node/state_machine_service/states/horizon_state_sync.rs b/base_layer/core/src/base_node/state_machine_service/states/horizon_state_sync.rs index 2fe036b19e..c1b8412028 100644 --- a/base_layer/core/src/base_node/state_machine_service/states/horizon_state_sync.rs +++ b/base_layer/core/src/base_node/state_machine_service/states/horizon_state_sync.rs @@ -25,26 +25,27 @@ // TODO: Move the horizon synchronizer to the `sync` module -mod config; - -pub use self::config::HorizonSyncConfig; - -mod error; +use log::*; pub use error::HorizonSyncError; - -mod horizon_state_synchronization; - use horizon_state_synchronization::HorizonStateSynchronization; +use tari_comms::PeerConnection; + +use crate::{base_node::BaseNodeStateMachine, chain_storage::BlockchainBackend, transactions::CryptoFactories}; use super::{ events_and_states::{HorizonSyncInfo, HorizonSyncStatus}, StateEvent, StateInfo, }; -use crate::{base_node::BaseNodeStateMachine, chain_storage::BlockchainBackend, transactions::types::CryptoFactories}; -use log::*; -use tari_comms::PeerConnection; + +pub use self::config::HorizonSyncConfig; + +mod config; + +mod error; + +mod horizon_state_synchronization; const LOG_TARGET: &str = "c::bn::state_machine_service::states::horizon_state_sync"; diff --git a/base_layer/core/src/base_node/state_machine_service/states/horizon_state_sync/horizon_state_synchronization.rs b/base_layer/core/src/base_node/state_machine_service/states/horizon_state_sync/horizon_state_synchronization.rs index 8347bcf521..3511ade206 100644 --- a/base_layer/core/src/base_node/state_machine_service/states/horizon_state_sync/horizon_state_synchronization.rs +++ b/base_layer/core/src/base_node/state_machine_service/states/horizon_state_sync/horizon_state_synchronization.rs @@ -39,10 +39,7 @@ use crate::{ SyncUtxosRequest, SyncUtxosResponse, }, - transactions::{ - transaction::{TransactionKernel, TransactionOutput}, - types::{HashDigest, RangeProofService}, - }, + transactions::transaction::{TransactionKernel, TransactionOutput}, }; use croaring::Bitmap; use futures::StreamExt; @@ -51,6 +48,7 @@ use std::{ convert::{TryFrom, TryInto}, sync::Arc, }; +use tari_common_types::types::{HashDigest, RangeProofService}; use tari_comms::PeerConnection; use tari_crypto::{ commitment::HomomorphicCommitment, diff --git a/base_layer/core/src/base_node/state_machine_service/states/listening.rs b/base_layer/core/src/base_node/state_machine_service/states/listening.rs index 0ea8157568..92349c269a 100644 --- a/base_layer/core/src/base_node/state_machine_service/states/listening.rs +++ b/base_layer/core/src/base_node/state_machine_service/states/listening.rs @@ -31,7 +31,6 @@ use crate::{ }, chain_storage::BlockchainBackend, }; -use futures::StreamExt; use log::*; use num_format::{Locale, ToFormattedString}; use serde::{Deserialize, Serialize}; @@ -118,7 +117,8 @@ impl Listening { info!(target: LOG_TARGET, "Listening for chain metadata updates"); shared.set_state_info(StateInfo::Listening(ListeningInfo::new(self.is_synced))); - while let Some(metadata_event) = shared.metadata_event_stream.next().await { + loop { + let metadata_event = shared.metadata_event_stream.recv().await; match metadata_event.as_ref().map(|v| v.deref()) { Ok(ChainMetadataEvent::PeerChainMetadataReceived(peer_metadata_list)) => { let mut peer_metadata_list = peer_metadata_list.clone(); @@ -199,16 +199,16 @@ impl Listening { if !self.is_synced { self.is_synced = true; + shared.set_state_info(StateInfo::Listening(ListeningInfo::new(true))); debug!(target: LOG_TARGET, "Initial sync achieved"); } - shared.set_state_info(StateInfo::Listening(ListeningInfo::new(true))); }, - Err(broadcast::RecvError::Lagged(n)) => { + Err(broadcast::error::RecvError::Lagged(n)) => { debug!(target: LOG_TARGET, "Metadata event subscriber lagged by {} item(s)", n); }, - Err(broadcast::RecvError::Closed) => { - // This should never happen because the while loop exits when the stream ends + Err(broadcast::error::RecvError::Closed) => { debug!(target: LOG_TARGET, "Metadata event subscriber closed"); + break; }, } } diff --git a/base_layer/core/src/base_node/state_machine_service/states/waiting.rs b/base_layer/core/src/base_node/state_machine_service/states/waiting.rs index 7ea2e7e2b0..aeaa5ab430 100644 --- a/base_layer/core/src/base_node/state_machine_service/states/waiting.rs +++ b/base_layer/core/src/base_node/state_machine_service/states/waiting.rs @@ -23,7 +23,7 @@ use crate::base_node::state_machine_service::states::{BlockSync, HeaderSync, HorizonStateSync, StateEvent}; use log::info; use std::time::Duration; -use tokio::time::delay_for; +use tokio::time::sleep; const LOG_TARGET: &str = "c::bn::state_machine_service::states::waiting"; @@ -41,7 +41,7 @@ impl Waiting { "The base node has started a WAITING state for {} seconds", self.timeout.as_secs() ); - delay_for(self.timeout).await; + sleep(self.timeout).await; info!( target: LOG_TARGET, "The base node waiting state has completed. Resuming normal operations" diff --git a/base_layer/core/src/base_node/sync/header_sync/synchronizer.rs b/base_layer/core/src/base_node/sync/header_sync/synchronizer.rs index b7483ff70e..fdd1ab6d64 100644 --- a/base_layer/core/src/base_node/sync/header_sync/synchronizer.rs +++ b/base_layer/core/src/base_node/sync/header_sync/synchronizer.rs @@ -32,7 +32,6 @@ use crate::{ base_node::{FindChainSplitRequest, SyncHeadersRequest}, }, tari_utilities::{hex::Hex, Hashable}, - transactions::types::HashOutput, validation::ValidationError, }; use futures::{future, stream::FuturesUnordered, StreamExt}; @@ -42,6 +41,7 @@ use std::{ sync::Arc, time::{Duration, Instant}, }; +use tari_common_types::types::HashOutput; use tari_comms::{ connectivity::{ConnectivityError, ConnectivityRequester, ConnectivitySelection}, peer_manager::NodeId, @@ -83,7 +83,7 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { } pub fn on_progress(&mut self, hook: H) - where H: FnMut(u64, u64, &[NodeId]) + Send + Sync + 'static { + where H: FnMut(Option<(u64, u64)>, &[NodeId]) + Send + Sync + 'static { self.hooks.add_on_progress_header_hook(hook); } @@ -94,6 +94,7 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { pub async fn synchronize(&mut self) -> Result { debug!(target: LOG_TARGET, "Starting header sync.",); + self.hooks.call_on_progress_header_hooks(None, self.sync_peers); let sync_peers = self.select_sync_peers().await?; info!( target: LOG_TARGET, @@ -261,7 +262,7 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { let latency = client.get_last_request_latency().await?; debug!( target: LOG_TARGET, - "Initiating header sync with peer `{}` (latency = {}ms)", + "Initiating header sync with peer `{}` (sync latency = {}ms)", conn.peer_node_id(), latency.unwrap_or_default().as_millis() ); @@ -272,6 +273,10 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { // We're ahead of this peer, try another peer if possible SyncStatus::Ahead => Err(BlockHeaderSyncError::NotInSync), SyncStatus::Lagging(split_info) => { + self.hooks.call_on_progress_header_hooks( + Some((split_info.local_tip_header.height(), split_info.remote_tip_height)), + self.sync_peers, + ); self.synchronize_headers(&peer, &mut client, *split_info).await?; Ok(()) }, @@ -482,46 +487,53 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { ) -> Result<(), BlockHeaderSyncError> { const COMMIT_EVERY_N_HEADERS: usize = 1000; - // Peer returned no more than the max headers. This indicates that there are no further headers to request. - if self.header_validator.valid_headers().len() <= NUM_INITIAL_HEADERS_TO_REQUEST as usize { - debug!(target: LOG_TARGET, "No further headers to download"); - if !self.pending_chain_has_higher_pow(&split_info.local_tip_header)? { - return Err(BlockHeaderSyncError::WeakerChain); - } + let mut has_switched_to_new_chain = false; + let pending_len = self.header_validator.valid_headers().len(); + // Find the hash to start syncing the rest of the headers. + // The expectation cannot fail because there has been at least one valid header returned (checked in + // determine_sync_status) + let (start_header_height, start_header_hash) = self + .header_validator + .current_valid_chain_tip_header() + .map(|h| (h.height(), h.hash().clone())) + .expect("synchronize_headers: expected there to be a valid tip header but it was None"); + + // If we already have a stronger chain at this point, switch over to it. + // just in case we happen to be exactly NUM_INITIAL_HEADERS_TO_REQUEST headers behind. + let has_better_pow = self.pending_chain_has_higher_pow(&split_info.local_tip_header)?; + if has_better_pow { debug!( target: LOG_TARGET, "Remote chain from peer {} has higher PoW. Switching", peer ); - // PoW is higher, switching over to the new chain self.switch_to_pending_chain(&split_info).await?; + has_switched_to_new_chain = true; + } + + if pending_len < NUM_INITIAL_HEADERS_TO_REQUEST as usize { + // Peer returned less than the number of requested headers. This indicates that we have all the available + // headers. + debug!(target: LOG_TARGET, "No further headers to download"); + if !has_better_pow { + return Err(BlockHeaderSyncError::WeakerChain); + } return Ok(()); } - // Find the hash to start syncing the rest of the headers. - // The expectation cannot fail because the number of headers has been checked in determine_sync_status - let start_header = - self.header_validator.valid_headers().last().expect( - "synchronize_headers: expected there to be at least one valid pending header but there were none", - ); - debug!( target: LOG_TARGET, - "Download remaining headers starting from header #{} from peer `{}`", - start_header.height(), - peer + "Download remaining headers starting from header #{} from peer `{}`", start_header_height, peer ); let request = SyncHeadersRequest { - start_hash: start_header.hash().clone(), + start_hash: start_header_hash, // To the tip! count: 0, }; let mut header_stream = client.sync_headers(request).await?; - debug!(target: LOG_TARGET, "Reading headers from peer `{}`", peer); - - let mut has_switched_to_new_chain = false; + debug!(target: LOG_TARGET, "Reading headers from peer `{}`", peer,); while let Some(header) = header_stream.next().await { let header = BlockHeader::try_from(header?).map_err(BlockHeaderSyncError::ReceivedInvalidHeader)?; @@ -563,7 +575,7 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { } self.hooks - .call_on_progress_header_hooks(current_height, split_info.remote_tip_height, self.sync_peers); + .call_on_progress_header_hooks(Some((current_height, split_info.remote_tip_height)), self.sync_peers); } if !has_switched_to_new_chain { diff --git a/base_layer/core/src/base_node/sync/header_sync/validator.rs b/base_layer/core/src/base_node/sync/header_sync/validator.rs index aff52b80fd..3c11c9a2d5 100644 --- a/base_layer/core/src/base_node/sync/header_sync/validator.rs +++ b/base_layer/core/src/base_node/sync/header_sync/validator.rs @@ -35,7 +35,6 @@ use crate::{ consensus::ConsensusManager, proof_of_work::{randomx_factory::RandomXFactory, PowAlgorithm}, tari_utilities::{epoch_time::EpochTime, hash::Hashable, hex::Hex}, - transactions::types::HashOutput, validation::helpers::{ check_header_timestamp_greater_than_median, check_pow_data, @@ -45,6 +44,7 @@ use crate::{ }; use log::*; use std::cmp::Ordering; +use tari_common_types::types::HashOutput; const LOG_TARGET: &str = "c::bn::header_sync"; @@ -115,6 +115,10 @@ impl BlockHeaderSyncValidator { Ok(()) } + pub fn current_valid_chain_tip_header(&self) -> Option<&ChainHeader> { + self.valid_headers().last() + } + pub fn validate(&mut self, header: BlockHeader) -> Result<(), BlockHeaderSyncError> { let state = self.state(); let expected_height = state.current_height + 1; @@ -283,7 +287,7 @@ mod test { mod initialize_state { use super::*; - #[tokio_macros::test_basic] + #[tokio::test] async fn it_initializes_state_to_given_header() { let (mut validator, _, tip) = setup_with_headers(1).await; validator.initialize_state(&tip.header().hash()).await.unwrap(); @@ -295,7 +299,7 @@ mod test { assert_eq!(state.current_height, 1); } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_errors_if_hash_does_not_exist() { let (mut validator, _) = setup(); let start_hash = vec![0; 32]; @@ -308,7 +312,7 @@ mod test { mod validate { use super::*; - #[tokio_macros::test_basic] + #[tokio::test] async fn it_passes_if_headers_are_valid() { let (mut validator, _, tip) = setup_with_headers(1).await; validator.initialize_state(tip.hash()).await.unwrap(); @@ -322,7 +326,7 @@ mod test { assert_eq!(validator.valid_headers().len(), 2); } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_fails_if_height_is_not_serial() { let (mut validator, _, tip) = setup_with_headers(2).await; validator.initialize_state(tip.hash()).await.unwrap(); diff --git a/base_layer/core/src/base_node/sync/hooks.rs b/base_layer/core/src/base_node/sync/hooks.rs index 71f0802926..d1e2628822 100644 --- a/base_layer/core/src/base_node/sync/hooks.rs +++ b/base_layer/core/src/base_node/sync/hooks.rs @@ -28,7 +28,7 @@ use tari_comms::peer_manager::NodeId; #[derive(Default)] pub(super) struct Hooks { - on_progress_header: Vec>, + on_progress_header: Vec, &[NodeId]) + Send + Sync>>, on_progress_block: Vec, u64, &[NodeId]) + Send + Sync>>, on_complete: Vec) + Send + Sync>>, on_rewind: Vec>) + Send + Sync>>, @@ -36,14 +36,14 @@ pub(super) struct Hooks { impl Hooks { pub fn add_on_progress_header_hook(&mut self, hook: H) - where H: FnMut(u64, u64, &[NodeId]) + Send + Sync + 'static { + where H: FnMut(Option<(u64, u64)>, &[NodeId]) + Send + Sync + 'static { self.on_progress_header.push(Box::new(hook)); } - pub fn call_on_progress_header_hooks(&mut self, height: u64, remote_tip_height: u64, sync_peers: &[NodeId]) { + pub fn call_on_progress_header_hooks(&mut self, height_vs_remote: Option<(u64, u64)>, sync_peers: &[NodeId]) { self.on_progress_header .iter_mut() - .for_each(|f| (*f)(height, remote_tip_height, sync_peers)); + .for_each(|f| (*f)(height_vs_remote, sync_peers)); } pub fn add_on_progress_block_hook(&mut self, hook: H) diff --git a/base_layer/core/src/base_node/sync/rpc/service.rs b/base_layer/core/src/base_node/sync/rpc/service.rs index 776c2b7e01..e9df7073a2 100644 --- a/base_layer/core/src/base_node/sync/rpc/service.rs +++ b/base_layer/core/src/base_node/sync/rpc/service.rs @@ -35,12 +35,14 @@ use crate::{ SyncUtxosResponse, }, }; -use futures::{channel::mpsc, stream, SinkExt}; use log::*; use std::cmp; -use tari_comms::protocol::rpc::{Request, Response, RpcStatus, Streaming}; +use tari_comms::{ + protocol::rpc::{Request, Response, RpcStatus, Streaming}, + utils, +}; use tari_crypto::tari_utilities::hex::Hex; -use tokio::task; +use tokio::{sync::mpsc, task}; use tracing::{instrument, span, Instrument, Level}; const LOG_TARGET: &str = "c::base_node::sync_rpc"; @@ -116,7 +118,7 @@ impl BaseNodeSyncService for BaseNodeSyncRpcServ // Number of blocks to load and push to the stream before loading the next batch const BATCH_SIZE: usize = 4; - let (mut tx, rx) = mpsc::channel(BATCH_SIZE); + let (tx, rx) = mpsc::channel(BATCH_SIZE); let span = span!(Level::TRACE, "sync_rpc::block_sync::inner_worker"); task::spawn( @@ -138,19 +140,16 @@ impl BaseNodeSyncService for BaseNodeSyncRpcServ break; }, Ok(blocks) => { - let mut blocks = stream::iter( - blocks - .into_iter() - .map(|hb| hb.try_into_block().map_err(RpcStatus::log_internal_error(LOG_TARGET))) - .map(|block| match block { - Ok(b) => Ok(proto::base_node::BlockBodyResponse::from(b)), - Err(err) => Err(err), - }) - .map(Ok), - ); + let blocks = blocks + .into_iter() + .map(|hb| hb.try_into_block().map_err(RpcStatus::log_internal_error(LOG_TARGET))) + .map(|block| match block { + Ok(b) => Ok(proto::base_node::BlockBodyResponse::from(b)), + Err(err) => Err(err), + }); // Ensure task stops if the peer prematurely stops their RPC session - if tx.send_all(&mut blocks).await.is_err() { + if utils::mpsc::send_all(&tx, blocks).await.is_err() { break; } }, @@ -209,7 +208,7 @@ impl BaseNodeSyncService for BaseNodeSyncRpcServ chunk_size ); - let (mut tx, rx) = mpsc::channel(chunk_size); + let (tx, rx) = mpsc::channel(chunk_size); let span = span!(Level::TRACE, "sync_rpc::sync_headers::inner_worker"); task::spawn( async move { @@ -233,10 +232,9 @@ impl BaseNodeSyncService for BaseNodeSyncRpcServ break; }, Ok(headers) => { - let mut headers = - stream::iter(headers.into_iter().map(proto::core::BlockHeader::from).map(Ok).map(Ok)); + let headers = headers.into_iter().map(proto::core::BlockHeader::from).map(Ok); // Ensure task stops if the peer prematurely stops their RPC session - if tx.send_all(&mut headers).await.is_err() { + if utils::mpsc::send_all(&tx, headers).await.is_err() { break; } }, @@ -354,7 +352,7 @@ impl BaseNodeSyncService for BaseNodeSyncRpcServ ) -> Result, RpcStatus> { let req = request.into_message(); const BATCH_SIZE: usize = 1000; - let (mut tx, rx) = mpsc::channel(BATCH_SIZE); + let (tx, rx) = mpsc::channel(BATCH_SIZE); let db = self.db(); task::spawn(async move { @@ -394,15 +392,9 @@ impl BaseNodeSyncService for BaseNodeSyncRpcServ break; }, Ok(kernels) => { - let mut kernels = stream::iter( - kernels - .into_iter() - .map(proto::types::TransactionKernel::from) - .map(Ok) - .map(Ok), - ); + let kernels = kernels.into_iter().map(proto::types::TransactionKernel::from).map(Ok); // Ensure task stops if the peer prematurely stops their RPC session - if tx.send_all(&mut kernels).await.is_err() { + if utils::mpsc::send_all(&tx, kernels).await.is_err() { break; } }, diff --git a/base_layer/core/src/base_node/sync/rpc/sync_utxos_task.rs b/base_layer/core/src/base_node/sync/rpc/sync_utxos_task.rs index ef10b41c2f..8064aaf458 100644 --- a/base_layer/core/src/base_node/sync/rpc/sync_utxos_task.rs +++ b/base_layer/core/src/base_node/sync/rpc/sync_utxos_task.rs @@ -25,11 +25,11 @@ use crate::{ proto, proto::base_node::{SyncUtxo, SyncUtxosRequest, SyncUtxosResponse}, }; -use futures::{channel::mpsc, stream, SinkExt}; use log::*; use std::{cmp, sync::Arc, time::Instant}; -use tari_comms::protocol::rpc::RpcStatus; +use tari_comms::{protocol::rpc::RpcStatus, utils}; use tari_crypto::tari_utilities::{hex::Hex, Hashable}; +use tokio::sync::mpsc; const LOG_TARGET: &str = "c::base_node::sync_rpc::sync_utxo_task"; @@ -147,8 +147,7 @@ where B: BlockchainBackend + 'static utxos.len(), deleted_diff.cardinality(), ); - let mut utxos = stream::iter( - utxos + let utxos = utxos .into_iter() .enumerate() // Only include pruned UTXOs if include_pruned_utxos is true @@ -161,12 +160,10 @@ where B: BlockchainBackend + 'static mmr_index: start + i as u64, } }) - .map(Ok) - .map(Ok), - ); + .map(Ok); // Ensure task stops if the peer prematurely stops their RPC session - if tx.send_all(&mut utxos).await.is_err() { + if utils::mpsc::send_all(&tx, utxos).await.is_err() { break; } diff --git a/base_layer/core/src/base_node/sync/rpc/tests.rs b/base_layer/core/src/base_node/sync/rpc/tests.rs index 35611a9dda..61adc1aa3c 100644 --- a/base_layer/core/src/base_node/sync/rpc/tests.rs +++ b/base_layer/core/src/base_node/sync/rpc/tests.rs @@ -89,7 +89,7 @@ mod sync_blocks { use tari_comms::protocol::rpc::RpcStatusCode; use tari_test_utils::unpack_enum; - #[tokio_macros::test_basic] + #[tokio::test] async fn it_returns_not_found_if_unknown_hash() { let mut backend = create_mock_backend(); backend.expect_fetch().times(1).returning(|_| Ok(None)); @@ -103,7 +103,7 @@ mod sync_blocks { unpack_enum!(RpcStatusCode::NotFound = err.status_code()); } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_sends_an_empty_response() { let mut backend = create_mock_backend(); @@ -136,7 +136,7 @@ mod sync_blocks { assert!(streaming.next().await.is_none()); } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_streams_blocks_until_end() { let mut backend = create_mock_backend(); diff --git a/base_layer/core/src/base_node/sync/validators.rs b/base_layer/core/src/base_node/sync/validators.rs index d65af9a972..e5282cc604 100644 --- a/base_layer/core/src/base_node/sync/validators.rs +++ b/base_layer/core/src/base_node/sync/validators.rs @@ -20,10 +20,12 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use std::{fmt, sync::Arc}; + use crate::{ chain_storage::BlockchainBackend, consensus::ConsensusManager, - transactions::types::CryptoFactories, + transactions::CryptoFactories, validation::{ block_validators::BlockValidator, CandidateBlockBodyValidation, @@ -31,7 +33,6 @@ use crate::{ FinalHorizonStateValidation, }, }; -use std::{fmt, sync::Arc}; #[derive(Clone)] pub struct SyncValidators { @@ -51,9 +52,13 @@ impl SyncValidators { } } - pub fn full_consensus(rules: ConsensusManager, factories: CryptoFactories) -> Self { + pub fn full_consensus( + rules: ConsensusManager, + factories: CryptoFactories, + bypass_range_proof_verification: bool, + ) -> Self { Self::new( - BlockValidator::new(rules.clone(), factories.clone()), + BlockValidator::new(rules.clone(), bypass_range_proof_verification, factories.clone()), ChainBalanceValidator::::new(rules, factories), ) } diff --git a/base_layer/core/src/blocks/block.rs b/base_layer/core/src/blocks/block.rs index 87124dbece..04cded3ada 100644 --- a/base_layer/core/src/blocks/block.rs +++ b/base_layer/core/src/blocks/block.rs @@ -23,6 +23,18 @@ // Portions of this file were originally copyrighted (c) 2018 The Grin Developers, issued under the Apache License, // Version 2.0, available at http://www.apache.org/licenses/LICENSE-2.0. +use std::{ + fmt, + fmt::{Display, Formatter}, +}; + +use log::*; +use serde::{Deserialize, Serialize}; +use tari_crypto::tari_utilities::Hashable; +use thiserror::Error; + +use tari_common_types::types::BlockHash; + use crate::{ blocks::BlockHeader, chain_storage::MmrTree, @@ -33,18 +45,9 @@ use crate::{ aggregated_body::AggregateBody, tari_amount::MicroTari, transaction::{Transaction, TransactionError, TransactionInput, TransactionKernel, TransactionOutput}, - types::CryptoFactories, + CryptoFactories, }, }; -use log::*; -use serde::{Deserialize, Serialize}; -use std::{ - fmt, - fmt::{Display, Formatter}, -}; -use tari_common_types::types::BlockHash; -use tari_crypto::tari_utilities::Hashable; -use thiserror::Error; #[derive(Clone, Debug, PartialEq, Error)] pub enum BlockValidationError { diff --git a/base_layer/core/src/blocks/block_header.rs b/base_layer/core/src/blocks/block_header.rs index 3346d0945d..fb7aad681e 100644 --- a/base_layer/core/src/blocks/block_header.rs +++ b/base_layer/core/src/blocks/block_header.rs @@ -39,11 +39,7 @@ #[cfg(feature = "base_node")] use crate::blocks::{BlockBuilder, NewBlockHeaderTemplate}; - -use crate::{ - proof_of_work::{PowAlgorithm, PowError, ProofOfWork}, - transactions::types::{BlindingFactor, HashDigest}, -}; +use crate::proof_of_work::{PowAlgorithm, PowError, ProofOfWork}; use chrono::{DateTime, Utc}; use digest::Digest; use serde::{ @@ -57,7 +53,7 @@ use std::{ fmt, fmt::{Display, Error, Formatter}, }; -use tari_common_types::types::{BlockHash, BLOCK_HASH_LENGTH}; +use tari_common_types::types::{BlindingFactor, BlockHash, HashDigest, BLOCK_HASH_LENGTH}; use tari_crypto::tari_utilities::{epoch_time::EpochTime, hex::Hex, ByteArray, Hashable}; use thiserror::Error; diff --git a/base_layer/core/src/blocks/genesis_block.rs b/base_layer/core/src/blocks/genesis_block.rs index a346bea095..861b3e8650 100644 --- a/base_layer/core/src/blocks/genesis_block.rs +++ b/base_layer/core/src/blocks/genesis_block.rs @@ -30,15 +30,13 @@ use crate::{ chain_storage::{BlockHeaderAccumulatedData, ChainBlock}, transactions::{ aggregated_body::AggregateBody, - bullet_rangeproofs::BulletRangeProof, tari_amount::MicroTari, transaction::{KernelFeatures, OutputFeatures, OutputFlags, TransactionKernel, TransactionOutput}, - types::{Commitment, PrivateKey, PublicKey, Signature}, }, }; use chrono::DateTime; use std::sync::Arc; -use tari_common_types::types::BLOCK_HASH_LENGTH; +use tari_common_types::types::{BulletRangeProof, Commitment, PrivateKey, PublicKey, Signature, BLOCK_HASH_LENGTH}; use tari_crypto::{ script::TariScript, tari_utilities::{hash::Hashable, hex::*}, @@ -369,10 +367,96 @@ pub fn get_ridcully_genesis_block_raw() -> Block { } } +pub fn get_igor_genesis_block() -> ChainBlock { + // lets get the block + let block = get_igor_genesis_block_raw(); + + let accumulated_data = BlockHeaderAccumulatedData { + hash: block.hash(), + total_kernel_offset: block.header.total_kernel_offset.clone(), + achieved_difficulty: 1.into(), + total_accumulated_difficulty: 1, + accumulated_monero_difficulty: 1.into(), + accumulated_sha_difficulty: 1.into(), + target_difficulty: 1.into(), + }; + ChainBlock::try_construct(Arc::new(block), accumulated_data).unwrap() +} + +#[allow(deprecated)] +pub fn get_igor_genesis_block_raw() -> Block { + let sig = Signature::new( + PublicKey::from_hex("f2139d1cdbcfa670bbb60d4d03d9d50b0a522e674b11280e8064f6dc30e84133").unwrap(), + PrivateKey::from_hex("3ff7522d9a744ebf99c7b6664c0e2c8c64d2a7b902a98b78964766f9f7f2b107").unwrap(), + ); + let mut body = AggregateBody::new( + vec![], + vec![TransactionOutput { + features: OutputFeatures { + flags: OutputFlags::COINBASE_OUTPUT, + maturity: 60, + }, + commitment: Commitment::from_hex( + "fadafb12de96d90042dcbf839985aadb7ae88baa3446d5c6a17937ef2b36783e", + ) + .unwrap(), + proof: BulletRangeProof::from_hex("845c947cbf23683f6ff6a56d0aa55fca14a618f7476d4e29348c5cbadf2bb062b8da701a0f058eb69c88492895c3f034db194f6d1b2d29ea83c1a68cbdd19a3f90ae080cfd0315bb20cd05a462c4e06e708b015da1d70c0f87e8c7413b579008e43a6c8dc1edb72b0b67612e897d251ec55798184ff35c80d18262e98034677b73f2dcc7ae25c9119900aadaf04a16068bf57b9e8b9bb694331750dc8acc6102b8961be183419dce2f96c48ced9892e4cdb091dcda0d6a0bb4ed94fc0c63ca065f25ce1e560504d49970bcaac007f33368f15ffa0dd3f56bf799b66fa684fe0fbeb882aee4a6fe05a3ca7c488a6ba22779a42f0f5d875175b8ebc517dd49df20b4f04f027b7d22b7c62cb93727f35c18a0b776d95fac4ff5405d6ed3dbb7613152178cecea4b712aa6e6701804ded71d94cf67de2e86ae401499b39de81b7344185c9eb3bd570ac6121143a690f118d9413abb894729b6b3e057f4771b2c2204285151a56695257992f2b0331f27066270718b37ab472c339d2560c1f6559f3c4ce31ec7f7e2acdbebb1715951d8177283a1ccc2f393ce292956de5db4afde419c0264d5cc4758e6e2c07b730ad43819f3761658d63794cc8071b30f9d7cd622bece4f086b0ca6a04fee888856084543a99848f06334acf48cace58e5ef8c85412017c400b4ec92481ba6d745915aef40531db73d1d84d07d7fce25737629e0fc4ee71e7d505bfd382e362cd1ac03a67c93b8f20cb4285ce240cf1e000d48332ba32e713d6cdf6266449a0a156241f7b1b36753f46f1ecb8b1836625508c5f31bc7ebc1d7cd634272be02cc109bf86983a0591bf00bacea1287233fc12324846398be07d44e8e14bd78cd548415f6de60b5a0c43a84ac29f6a8ac0b1b748dd07a8a4124625e1055b5f5b19da79c319b6e465ca5df0eb70cb4e3dc399891ce90b").unwrap(), + // For genesis block: A default script can never be spent, intentionally + script: TariScript::default(), + // Script offset never checked for coinbase, thus can use default + sender_offset_public_key: Default::default(), + // For genesis block: Metadata signature will never be checked + metadata_signature: Default::default(), + }], + vec![TransactionKernel { + features: KernelFeatures::COINBASE_KERNEL, + fee: MicroTari(0), + lock_height: 0, + excess: Commitment::from_hex( + "f472cc347a1006b7390f9c93b3c62fba334fd99f6c9c1daf9302646cd4781f61", + ) + .unwrap(), + excess_sig: sig, + }], + ); + body.sort(); + // set genesis timestamp + let genesis = DateTime::parse_from_rfc2822("27 Aug 2021 06:00:00 +0200").unwrap(); + let timestamp = genesis.timestamp() as u64; + Block { + header: BlockHeader { + version: 0, + height: 0, + prev_hash: vec![0; BLOCK_HASH_LENGTH], + timestamp: timestamp.into(), + output_mr: from_hex("dcc44f39b65e5e1e526887e7d56f7b85e2ea44bd29bc5bc195e6e015d19e1c06").unwrap(), + witness_mr: from_hex("e4d7dab49a66358379a901b9a36c10f070aa9d7bdc8ae752947b6fc4e55d255f").unwrap(), + output_mmr_size: 1, + kernel_mr: from_hex("589bc62ac5d9139f921c68b8075c32d8d130024acaf3196d1d6a89df601e2bcf").unwrap(), + kernel_mmr_size: 1, + input_mr: vec![0; BLOCK_HASH_LENGTH], + total_kernel_offset: PrivateKey::from_hex( + "0000000000000000000000000000000000000000000000000000000000000000", + ) + .unwrap(), + total_script_offset: PrivateKey::from_hex( + "0000000000000000000000000000000000000000000000000000000000000000", + ) + .unwrap(), + nonce: 0, + pow: ProofOfWork { + pow_algo: PowAlgorithm::Sha3, + pow_data: vec![], + }, + }, + body, + } +} + #[cfg(test)] mod test { use super::*; - use crate::transactions::types::CryptoFactories; + use crate::transactions::CryptoFactories; #[test] fn weatherwax_genesis_sanity_check() { diff --git a/base_layer/core/src/blocks/new_blockheader_template.rs b/base_layer/core/src/blocks/new_blockheader_template.rs index 543a22a287..7fc902fdf0 100644 --- a/base_layer/core/src/blocks/new_blockheader_template.rs +++ b/base_layer/core/src/blocks/new_blockheader_template.rs @@ -23,11 +23,10 @@ use crate::{ blocks::block_header::{hash_serializer, BlockHeader}, proof_of_work::ProofOfWork, - transactions::types::BlindingFactor, }; use serde::{Deserialize, Serialize}; use std::fmt::{Display, Error, Formatter}; -use tari_common_types::types::BlockHash; +use tari_common_types::types::{BlindingFactor, BlockHash}; use tari_crypto::tari_utilities::hex::Hex; /// The NewBlockHeaderTemplate is used for the construction of a new mineable block. It contains all the metadata for diff --git a/base_layer/core/src/chain_storage/accumulated_data.rs b/base_layer/core/src/chain_storage/accumulated_data.rs index 40f0bac01b..eb6d50e499 100644 --- a/base_layer/core/src/chain_storage/accumulated_data.rs +++ b/base_layer/core/src/chain_storage/accumulated_data.rs @@ -25,10 +25,7 @@ use crate::{ chain_storage::ChainStorageError, proof_of_work::{AchievedTargetDifficulty, Difficulty, PowAlgorithm}, tari_utilities::Hashable, - transactions::{ - aggregated_body::AggregateBody, - types::{BlindingFactor, Commitment, HashOutput}, - }, + transactions::aggregated_body::AggregateBody, }; use croaring::Bitmap; use log::*; @@ -47,6 +44,7 @@ use std::{ fmt::{Display, Formatter}, sync::Arc, }; +use tari_common_types::types::{BlindingFactor, Commitment, HashOutput}; use tari_crypto::tari_utilities::hex::Hex; use tari_mmr::{pruned_hashset::PrunedHashSet, ArrayLike}; diff --git a/base_layer/core/src/chain_storage/async_db.rs b/base_layer/core/src/chain_storage/async_db.rs index 14e4f0e5cd..ca1aea97c5 100644 --- a/base_layer/core/src/chain_storage/async_db.rs +++ b/base_layer/core/src/chain_storage/async_db.rs @@ -42,16 +42,16 @@ use crate::{ common::rolling_vec::RollingVec, proof_of_work::{PowAlgorithm, TargetDifficultyWindow}, tari_utilities::epoch_time::EpochTime, - transactions::{ - transaction::{TransactionKernel, TransactionOutput}, - types::{Commitment, HashOutput, Signature}, - }, + transactions::transaction::{TransactionKernel, TransactionOutput}, }; use croaring::Bitmap; use log::*; use rand::{rngs::OsRng, RngCore}; use std::{mem, ops::RangeBounds, sync::Arc, time::Instant}; -use tari_common_types::{chain_metadata::ChainMetadata, types::BlockHash}; +use tari_common_types::{ + chain_metadata::ChainMetadata, + types::{BlockHash, Commitment, HashOutput, Signature}, +}; use tari_mmr::pruned_hashset::PrunedHashSet; const LOG_TARGET: &str = "c::bn::async_db"; diff --git a/base_layer/core/src/chain_storage/blockchain_backend.rs b/base_layer/core/src/chain_storage/blockchain_backend.rs index f5d3b6ad36..505d25dda7 100644 --- a/base_layer/core/src/chain_storage/blockchain_backend.rs +++ b/base_layer/core/src/chain_storage/blockchain_backend.rs @@ -14,13 +14,13 @@ use crate::{ HorizonData, MmrTree, }, - transactions::{ - transaction::{TransactionInput, TransactionKernel}, - types::{Commitment, HashOutput, Signature}, - }, + transactions::transaction::{TransactionInput, TransactionKernel}, }; use croaring::Bitmap; -use tari_common_types::chain_metadata::ChainMetadata; +use tari_common_types::{ + chain_metadata::ChainMetadata, + types::{Commitment, HashOutput, Signature}, +}; use tari_mmr::Hash; /// Identify behaviour for Blockchain database backends. Implementations must support `Send` and `Sync` so that diff --git a/base_layer/core/src/chain_storage/blockchain_database.rs b/base_layer/core/src/chain_storage/blockchain_database.rs index 867e7877e1..f2bf132cf4 100644 --- a/base_layer/core/src/chain_storage/blockchain_database.rs +++ b/base_layer/core/src/chain_storage/blockchain_database.rs @@ -46,10 +46,7 @@ use crate::{ consensus::{chain_strength_comparer::ChainStrengthComparer, ConsensusConstants, ConsensusManager}, proof_of_work::{monero_rx::MoneroPowData, PowAlgorithm, TargetDifficultyWindow}, tari_utilities::epoch_time::EpochTime, - transactions::{ - transaction::TransactionKernel, - types::{Commitment, HashDigest, HashOutput, Signature}, - }, + transactions::transaction::TransactionKernel, validation::{DifficultyCalculator, HeaderValidation, OrphanValidation, PostOrphanBodyValidation, ValidationError}, }; use croaring::Bitmap; @@ -64,7 +61,10 @@ use std::{ sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard}, time::Instant, }; -use tari_common_types::{chain_metadata::ChainMetadata, types::BlockHash}; +use tari_common_types::{ + chain_metadata::ChainMetadata, + types::{BlockHash, Commitment, HashDigest, HashOutput, Signature}, +}; use tari_crypto::tari_utilities::{hex::Hex, ByteArray, Hashable}; use tari_mmr::{MerkleMountainRange, MutableMmr}; use uint::static_assertions::_core::ops::RangeBounds; diff --git a/base_layer/core/src/chain_storage/db_transaction.rs b/base_layer/core/src/chain_storage/db_transaction.rs index dfe5947543..0d9736c39b 100644 --- a/base_layer/core/src/chain_storage/db_transaction.rs +++ b/base_layer/core/src/chain_storage/db_transaction.rs @@ -22,10 +22,7 @@ use crate::{ blocks::{Block, BlockHeader}, chain_storage::{error::ChainStorageError, ChainBlock, ChainHeader, MmrTree}, - transactions::{ - transaction::{TransactionKernel, TransactionOutput}, - types::{Commitment, HashOutput}, - }, + transactions::transaction::{TransactionKernel, TransactionOutput}, }; use croaring::Bitmap; use std::{ @@ -33,7 +30,7 @@ use std::{ fmt::{Display, Error, Formatter}, sync::Arc, }; -use tari_common_types::types::BlockHash; +use tari_common_types::types::{BlockHash, Commitment, HashOutput}; use tari_crypto::tari_utilities::{ hex::{to_hex, Hex}, Hashable, diff --git a/base_layer/core/src/chain_storage/historical_block.rs b/base_layer/core/src/chain_storage/historical_block.rs index 1188f7f27d..99fd45335f 100644 --- a/base_layer/core/src/chain_storage/historical_block.rs +++ b/base_layer/core/src/chain_storage/historical_block.rs @@ -23,10 +23,10 @@ use crate::{ blocks::{Block, BlockHeader}, chain_storage::{BlockHeaderAccumulatedData, ChainBlock, ChainStorageError}, - transactions::types::HashOutput, }; use serde::{Deserialize, Serialize}; use std::{fmt, fmt::Display, sync::Arc}; +use tari_common_types::types::HashOutput; use tari_crypto::tari_utilities::hex::Hex; /// The representation of a historical block in the blockchain. It is essentially identical to a protocol-defined diff --git a/base_layer/core/src/chain_storage/horizon_data.rs b/base_layer/core/src/chain_storage/horizon_data.rs index 1e6f542142..6213d490f3 100644 --- a/base_layer/core/src/chain_storage/horizon_data.rs +++ b/base_layer/core/src/chain_storage/horizon_data.rs @@ -1,5 +1,3 @@ -use crate::transactions::types::Commitment; - // Copyright 2021. The Tari Project // // Redistribution and use in source and binary forms, with or without modification, are permitted provided that the @@ -22,6 +20,7 @@ use crate::transactions::types::Commitment; // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use serde::{Deserialize, Serialize}; +use tari_common_types::types::Commitment; use tari_crypto::tari_utilities::ByteArray; #[derive(Clone, Debug, Serialize, Deserialize)] diff --git a/base_layer/core/src/chain_storage/lmdb_db/lmdb_db.rs b/base_layer/core/src/chain_storage/lmdb_db/lmdb_db.rs index 069765edbb..e49aee1fed 100644 --- a/base_layer/core/src/chain_storage/lmdb_db/lmdb_db.rs +++ b/base_layer/core/src/chain_storage/lmdb_db/lmdb_db.rs @@ -76,7 +76,6 @@ use crate::{ transactions::{ aggregated_body::AggregateBody, transaction::{TransactionInput, TransactionKernel, TransactionOutput}, - types::{Commitment, HashDigest, HashOutput, Signature}, }, }; use croaring::Bitmap; @@ -87,7 +86,7 @@ use serde::{Deserialize, Serialize}; use std::{convert::TryFrom, fmt, fs, fs::File, ops::Deref, path::Path, sync::Arc, time::Instant}; use tari_common_types::{ chain_metadata::ChainMetadata, - types::{BlockHash, BLOCK_HASH_LENGTH}, + types::{BlockHash, Commitment, HashDigest, HashOutput, Signature, BLOCK_HASH_LENGTH}, }; use tari_crypto::tari_utilities::{hash::Hashable, hex::Hex, ByteArray}; use tari_mmr::{pruned_hashset::PrunedHashSet, Hash, MerkleMountainRange, MutableMmr}; diff --git a/base_layer/core/src/chain_storage/lmdb_db/mod.rs b/base_layer/core/src/chain_storage/lmdb_db/mod.rs index 785b0363ee..f97c1c4878 100644 --- a/base_layer/core/src/chain_storage/lmdb_db/mod.rs +++ b/base_layer/core/src/chain_storage/lmdb_db/mod.rs @@ -24,12 +24,10 @@ mod lmdb; #[allow(clippy::module_inception)] mod lmdb_db; -use crate::transactions::{ - transaction::{TransactionInput, TransactionKernel, TransactionOutput}, - types::HashOutput, -}; +use crate::transactions::transaction::{TransactionInput, TransactionKernel, TransactionOutput}; pub use lmdb_db::{create_lmdb_database, create_recovery_lmdb_database, LMDBDatabase}; use serde::{Deserialize, Serialize}; +use tari_common_types::types::HashOutput; pub const LMDB_DB_METADATA: &str = "metadata"; pub const LMDB_DB_HEADERS: &str = "headers"; diff --git a/base_layer/core/src/chain_storage/pruned_output.rs b/base_layer/core/src/chain_storage/pruned_output.rs index 957c0e8c86..8c753f30a5 100644 --- a/base_layer/core/src/chain_storage/pruned_output.rs +++ b/base_layer/core/src/chain_storage/pruned_output.rs @@ -19,7 +19,8 @@ // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::transactions::{transaction::TransactionOutput, types::HashOutput}; +use crate::transactions::transaction::TransactionOutput; +use tari_common_types::types::HashOutput; #[allow(clippy::large_enum_variant)] #[derive(Debug, PartialEq)] diff --git a/base_layer/core/src/consensus/consensus_constants.rs b/base_layer/core/src/consensus/consensus_constants.rs index 52c0620e8c..85e5027479 100644 --- a/base_layer/core/src/consensus/consensus_constants.rs +++ b/base_layer/core/src/consensus/consensus_constants.rs @@ -356,6 +356,38 @@ impl ConsensusConstants { }] } + pub fn igor() -> Vec { + let mut algos = HashMap::new(); + // seting sha3/monero to 40/60 split + algos.insert(PowAlgorithm::Sha3, PowAlgorithmConstants { + max_target_time: 1800, + min_difficulty: 60_000_000.into(), + max_difficulty: u64::MAX.into(), + target_time: 300, + }); + algos.insert(PowAlgorithm::Monero, PowAlgorithmConstants { + max_target_time: 1200, + min_difficulty: 60_000.into(), + max_difficulty: u64::MAX.into(), + target_time: 200, + }); + vec![ConsensusConstants { + effective_from_height: 0, + coinbase_lock_height: 6, + blockchain_version: 1, + future_time_limit: 540, + difficulty_block_window: 90, + max_block_transaction_weight: 19500, + median_timestamp_count: 11, + emission_initial: 5_538_846_115 * uT, + emission_decay: &EMISSION_DECAY, + emission_tail: 100.into(), + max_randomx_seed_height: std::u64::MAX, + proof_of_work: algos, + faucet_value: (5000 * 4000) * T, + }] + } + pub fn mainnet() -> Vec { // Note these values are all placeholders for final values let difficulty_block_window = 90; diff --git a/base_layer/core/src/consensus/consensus_manager.rs b/base_layer/core/src/consensus/consensus_manager.rs index 0663d153ca..5ca1610e03 100644 --- a/base_layer/core/src/consensus/consensus_manager.rs +++ b/base_layer/core/src/consensus/consensus_manager.rs @@ -23,6 +23,7 @@ use crate::{ blocks::{ genesis_block::{ + get_igor_genesis_block, get_mainnet_genesis_block, get_ridcully_genesis_block, get_stibbons_genesis_block, @@ -82,6 +83,7 @@ impl ConsensusManager { .gen_block .clone() .unwrap_or_else(get_weatherwax_genesis_block), + Network::Igor => get_igor_genesis_block(), } } diff --git a/base_layer/core/src/consensus/network.rs b/base_layer/core/src/consensus/network.rs index 55e17e3c98..0e2e598a0a 100644 --- a/base_layer/core/src/consensus/network.rs +++ b/base_layer/core/src/consensus/network.rs @@ -36,6 +36,7 @@ impl NetworkConsensus { Stibbons => ConsensusConstants::stibbons(), Weatherwax => ConsensusConstants::weatherwax(), LocalNet => ConsensusConstants::localnet(), + Igor => ConsensusConstants::igor(), } } diff --git a/base_layer/core/src/lib.rs b/base_layer/core/src/lib.rs index 2e4bdc2f49..5a93b73576 100644 --- a/base_layer/core/src/lib.rs +++ b/base_layer/core/src/lib.rs @@ -20,7 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -// Needed to make futures::select! work +// Needed to make tokio::select! work #![recursion_limit = "512"] #![feature(shrink_to)] // #![cfg_attr(not(debug_assertions), deny(unused_variables))] diff --git a/base_layer/core/src/mempool/async_mempool.rs b/base_layer/core/src/mempool/async_mempool.rs index 04c4314c26..9a9b0a5f8a 100644 --- a/base_layer/core/src/mempool/async_mempool.rs +++ b/base_layer/core/src/mempool/async_mempool.rs @@ -23,9 +23,10 @@ use crate::{ blocks::Block, mempool::{error::MempoolError, Mempool, StateResponse, StatsResponse, TxStorageResponse}, - transactions::{transaction::Transaction, types::Signature}, + transactions::transaction::Transaction, }; use std::sync::Arc; +use tari_common_types::types::Signature; macro_rules! make_async { ($fn:ident($($param1:ident:$ptype1:ty,$param2:ident:$ptype2:ty),+) -> $rtype:ty) => { diff --git a/base_layer/core/src/mempool/mempool.rs b/base_layer/core/src/mempool/mempool.rs index 97b3ceac15..865ca7b980 100644 --- a/base_layer/core/src/mempool/mempool.rs +++ b/base_layer/core/src/mempool/mempool.rs @@ -30,10 +30,11 @@ use crate::{ StatsResponse, TxStorageResponse, }, - transactions::{transaction::Transaction, types::Signature}, + transactions::transaction::Transaction, validation::MempoolTransactionValidation, }; use std::sync::{Arc, RwLock}; +use tari_common_types::types::Signature; /// The Mempool consists of an Unconfirmed Transaction Pool, Pending Pool, Orphan Pool and Reorg Pool and is responsible /// for managing and maintaining all unconfirmed transactions have not yet been included in a block, and transactions diff --git a/base_layer/core/src/mempool/mempool_storage.rs b/base_layer/core/src/mempool/mempool_storage.rs index b5c0a800c7..d2ccb38dbd 100644 --- a/base_layer/core/src/mempool/mempool_storage.rs +++ b/base_layer/core/src/mempool/mempool_storage.rs @@ -31,11 +31,12 @@ use crate::{ StatsResponse, TxStorageResponse, }, - transactions::{transaction::Transaction, types::Signature}, + transactions::transaction::Transaction, validation::{MempoolTransactionValidation, ValidationError}, }; use log::*; use std::sync::Arc; +use tari_common_types::types::Signature; use tari_crypto::tari_utilities::{hex::Hex, Hashable}; pub const LOG_TARGET: &str = "c::mp::mempool_storage"; diff --git a/base_layer/core/src/mempool/mod.rs b/base_layer/core/src/mempool/mod.rs index 1374d4f08e..afe8d5a69c 100644 --- a/base_layer/core/src/mempool/mod.rs +++ b/base_layer/core/src/mempool/mod.rs @@ -72,9 +72,10 @@ mod sync_protocol; #[cfg(feature = "base_node")] pub use sync_protocol::MempoolSyncInitializer; -use crate::transactions::{transaction::Transaction, types::Signature}; +use crate::transactions::transaction::Transaction; use core::fmt::{Display, Error, Formatter}; use serde::{Deserialize, Serialize}; +use tari_common_types::types::Signature; use tari_crypto::tari_utilities::hex::Hex; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] diff --git a/base_layer/core/src/mempool/priority/prioritized_transaction.rs b/base_layer/core/src/mempool/priority/prioritized_transaction.rs index 1080536d7b..cc82531461 100644 --- a/base_layer/core/src/mempool/priority/prioritized_transaction.rs +++ b/base_layer/core/src/mempool/priority/prioritized_transaction.rs @@ -20,11 +20,9 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{ - mempool::priority::PriorityError, - transactions::{transaction::Transaction, types::HashOutput}, -}; +use crate::{mempool::priority::PriorityError, transactions::transaction::Transaction}; use std::sync::Arc; +use tari_common_types::types::HashOutput; use tari_crypto::tari_utilities::message_format::MessageFormat; /// Create a unique unspent transaction priority based on the transaction fee, maturity of the oldest input UTXO and the diff --git a/base_layer/core/src/mempool/proto/state_response.rs b/base_layer/core/src/mempool/proto/state_response.rs index 80ab03dd0a..8b3af21ac0 100644 --- a/base_layer/core/src/mempool/proto/state_response.rs +++ b/base_layer/core/src/mempool/proto/state_response.rs @@ -23,10 +23,8 @@ use crate::mempool::{proto::mempool::StateResponse as ProtoStateResponse, StateResponse}; use std::convert::{TryFrom, TryInto}; // use crate::transactions::proto::types::Signature as ProtoSignature; -use crate::{ - mempool::proto::mempool::Signature as ProtoSignature, - transactions::types::{PrivateKey, PublicKey, Signature}, -}; +use crate::mempool::proto::mempool::Signature as ProtoSignature; +use tari_common_types::types::{PrivateKey, PublicKey, Signature}; use tari_crypto::tari_utilities::{ByteArray, ByteArrayError}; //---------------------------------- Signature --------------------------------------------// diff --git a/base_layer/core/src/mempool/reorg_pool/reorg_pool.rs b/base_layer/core/src/mempool/reorg_pool/reorg_pool.rs index 5626242474..5e0f12856e 100644 --- a/base_layer/core/src/mempool/reorg_pool/reorg_pool.rs +++ b/base_layer/core/src/mempool/reorg_pool/reorg_pool.rs @@ -26,7 +26,7 @@ use crate::{ consts::{MEMPOOL_REORG_POOL_CACHE_TTL, MEMPOOL_REORG_POOL_STORAGE_CAPACITY}, reorg_pool::{ReorgPoolError, ReorgPoolStorage}, }, - transactions::{transaction::Transaction, types::Signature}, + transactions::transaction::Transaction, }; use serde::{Deserialize, Serialize}; use std::{ @@ -34,6 +34,7 @@ use std::{ time::Duration, }; use tari_common::configuration::seconds; +use tari_common_types::types::Signature; /// Configuration for the ReorgPool #[derive(Clone, Copy, Deserialize, Serialize)] diff --git a/base_layer/core/src/mempool/reorg_pool/reorg_pool_storage.rs b/base_layer/core/src/mempool/reorg_pool/reorg_pool_storage.rs index 71fbb468ff..c178c33545 100644 --- a/base_layer/core/src/mempool/reorg_pool/reorg_pool_storage.rs +++ b/base_layer/core/src/mempool/reorg_pool/reorg_pool_storage.rs @@ -20,13 +20,10 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{ - blocks::Block, - mempool::reorg_pool::reorg_pool::ReorgPoolConfig, - transactions::{transaction::Transaction, types::Signature}, -}; +use crate::{blocks::Block, mempool::reorg_pool::reorg_pool::ReorgPoolConfig, transactions::transaction::Transaction}; use log::*; use std::sync::Arc; +use tari_common_types::types::Signature; use tari_crypto::tari_utilities::hex::Hex; use ttl_cache::TtlCache; diff --git a/base_layer/core/src/mempool/rpc/test.rs b/base_layer/core/src/mempool/rpc/test.rs index 64ad84f122..a9cbb2ee49 100644 --- a/base_layer/core/src/mempool/rpc/test.rs +++ b/base_layer/core/src/mempool/rpc/test.rs @@ -43,7 +43,7 @@ mod get_stats { use super::*; use crate::mempool::{MempoolService, StatsResponse}; - #[tokio_macros::test_basic] + #[tokio::test] async fn it_returns_the_stats() { let (service, mempool, req_mock, _tmpdir) = setup(); let expected_stats = StatsResponse { @@ -66,7 +66,7 @@ mod get_state { use super::*; use crate::mempool::{MempoolService, StateResponse}; - #[tokio_macros::test_basic] + #[tokio::test] async fn it_returns_the_state() { let (service, mempool, req_mock, _tmpdir) = setup(); let expected_state = StateResponse { @@ -94,7 +94,7 @@ mod get_tx_state_by_excess_sig { use tari_crypto::ristretto::{RistrettoPublicKey, RistrettoSecretKey}; use tari_test_utils::unpack_enum; - #[tokio_macros::test_basic] + #[tokio::test] async fn it_returns_the_storage_status() { let (service, mempool, req_mock, _tmpdir) = setup(); let expected = TxStorageResponse::UnconfirmedPool; @@ -116,7 +116,7 @@ mod get_tx_state_by_excess_sig { assert_eq!(mempool.get_call_count(), 1); } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_errors_on_invalid_signature() { let (service, _, req_mock, _tmpdir) = setup(); let status = service @@ -139,7 +139,7 @@ mod submit_transaction { use tari_crypto::ristretto::RistrettoSecretKey; use tari_test_utils::unpack_enum; - #[tokio_macros::test_basic] + #[tokio::test] async fn it_submits_transaction() { let (service, mempool, req_mock, _tmpdir) = setup(); let expected = TxStorageResponse::UnconfirmedPool; @@ -166,7 +166,7 @@ mod submit_transaction { assert_eq!(mempool.get_call_count(), 1); } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_errors_on_invalid_transaction() { let (service, _, req_mock, _tmpdir) = setup(); let status = service diff --git a/base_layer/core/src/mempool/service/handle.rs b/base_layer/core/src/mempool/service/handle.rs index 662e411f15..6eebf2b958 100644 --- a/base_layer/core/src/mempool/service/handle.rs +++ b/base_layer/core/src/mempool/service/handle.rs @@ -28,8 +28,9 @@ use crate::{ StatsResponse, TxStorageResponse, }, - transactions::{transaction::Transaction, types::Signature}, + transactions::transaction::Transaction, }; +use tari_common_types::types::Signature; use tari_service_framework::{reply_channel::TrySenderService, Service}; #[derive(Clone)] diff --git a/base_layer/core/src/mempool/service/inbound_handlers.rs b/base_layer/core/src/mempool/service/inbound_handlers.rs index a4f9aa1aee..7f3f90c44a 100644 --- a/base_layer/core/src/mempool/service/inbound_handlers.rs +++ b/base_layer/core/src/mempool/service/inbound_handlers.rs @@ -122,7 +122,7 @@ impl MempoolInboundHandlers { if tx_storage.is_stored() { debug!( target: LOG_TARGET, - "Mempool already has transaction: {}", kernel_excess_sig + "Mempool already has transaction: {}.", kernel_excess_sig ); return Ok(tx_storage); } diff --git a/base_layer/core/src/mempool/service/initializer.rs b/base_layer/core/src/mempool/service/initializer.rs index cb57d58e94..a295daf96a 100644 --- a/base_layer/core/src/mempool/service/initializer.rs +++ b/base_layer/core/src/mempool/service/initializer.rs @@ -37,7 +37,7 @@ use crate::{ proto, transactions::transaction::Transaction, }; -use futures::{channel::mpsc, future, Stream, StreamExt}; +use futures::{Stream, StreamExt}; use log::*; use std::{convert::TryFrom, sync::Arc}; use tari_comms_dht::Dht; @@ -54,7 +54,7 @@ use tari_service_framework::{ ServiceInitializer, ServiceInitializerContext, }; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc}; const LOG_TARGET: &str = "c::bn::mempool_service::initializer"; const SUBSCRIPTION_LABEL: &str = "Mempool"; @@ -148,7 +148,7 @@ impl ServiceInitializer for MempoolServiceInitializer { let mempool_handle = MempoolHandle::new(request_sender); context.register_handle(mempool_handle); - let (outbound_tx_sender, outbound_tx_stream) = mpsc::unbounded(); + let (outbound_tx_sender, outbound_tx_stream) = mpsc::unbounded_channel(); let (outbound_request_sender_service, outbound_request_stream) = reply_channel::unbounded(); let (local_request_sender_service, local_request_stream) = reply_channel::unbounded(); let (mempool_state_event_publisher, _) = broadcast::channel(100); @@ -167,7 +167,7 @@ impl ServiceInitializer for MempoolServiceInitializer { context.register_handle(outbound_mp_interface); context.register_handle(local_mp_interface); - context.spawn_when_ready(move |handles| async move { + context.spawn_until_shutdown(move |handles| { let outbound_message_service = handles.expect_handle::().outbound_requester(); let state_machine = handles.expect_handle::(); let base_node = handles.expect_handle::(); @@ -182,11 +182,7 @@ impl ServiceInitializer for MempoolServiceInitializer { block_event_stream: base_node.get_block_event_stream(), request_receiver, }; - let service = - MempoolService::new(outbound_message_service, inbound_handlers, config, state_machine).start(streams); - futures::pin_mut!(service); - future::select(service, handles.get_shutdown_signal()).await; - info!(target: LOG_TARGET, "Mempool Service shutdown"); + MempoolService::new(outbound_message_service, inbound_handlers, config, state_machine).start(streams) }); Ok(()) diff --git a/base_layer/core/src/mempool/service/local_service.rs b/base_layer/core/src/mempool/service/local_service.rs index c58af9912d..05f3ac6779 100644 --- a/base_layer/core/src/mempool/service/local_service.rs +++ b/base_layer/core/src/mempool/service/local_service.rs @@ -28,8 +28,9 @@ use crate::{ StatsResponse, TxStorageResponse, }, - transactions::{transaction::Transaction, types::Signature}, + transactions::transaction::Transaction, }; +use tari_common_types::types::Signature; use tari_service_framework::{reply_channel::SenderService, Service}; use tokio::sync::broadcast; @@ -146,7 +147,7 @@ mod test { } } - #[tokio_macros::test] + #[tokio::test] async fn mempool_stats() { let (event_publisher, _) = broadcast::channel(100); let (tx, rx) = unbounded(); @@ -157,7 +158,7 @@ mod test { assert_eq!(stats, request_stats()); } - #[tokio_macros::test] + #[tokio::test] async fn mempool_stats_from_multiple() { let (event_publisher, _) = broadcast::channel(100); let (tx, rx) = unbounded(); diff --git a/base_layer/core/src/mempool/service/outbound_interface.rs b/base_layer/core/src/mempool/service/outbound_interface.rs index 87cda226f3..8606a8ce02 100644 --- a/base_layer/core/src/mempool/service/outbound_interface.rs +++ b/base_layer/core/src/mempool/service/outbound_interface.rs @@ -26,12 +26,13 @@ use crate::{ StatsResponse, TxStorageResponse, }, - transactions::{transaction::Transaction, types::Signature}, + transactions::transaction::Transaction, }; -use futures::channel::mpsc::UnboundedSender; use log::*; +use tari_common_types::types::Signature; use tari_comms::peer_manager::NodeId; use tari_service_framework::{reply_channel::SenderService, Service}; +use tokio::sync::mpsc::UnboundedSender; pub const LOG_TARGET: &str = "c::mp::service::outbound_interface"; @@ -71,15 +72,13 @@ impl OutboundMempoolServiceInterface { transaction: Transaction, exclude_peers: Vec, ) -> Result<(), MempoolServiceError> { - self.tx_sender - .unbounded_send((transaction, exclude_peers)) - .or_else(|e| { - { - error!(target: LOG_TARGET, "Could not broadcast transaction. {:?}", e); - Err(e) - } - .map_err(|_| MempoolServiceError::BroadcastFailed) - }) + self.tx_sender.send((transaction, exclude_peers)).or_else(|e| { + { + error!(target: LOG_TARGET, "Could not broadcast transaction. {:?}", e); + Err(e) + } + .map_err(|_| MempoolServiceError::BroadcastFailed) + }) } /// Check if the specified transaction is stored in the mempool of a remote base node. diff --git a/base_layer/core/src/mempool/service/request.rs b/base_layer/core/src/mempool/service/request.rs index 8437b84a18..a6d6910024 100644 --- a/base_layer/core/src/mempool/service/request.rs +++ b/base_layer/core/src/mempool/service/request.rs @@ -20,10 +20,10 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::transactions::{transaction::Transaction, types::Signature}; +use crate::transactions::transaction::Transaction; use core::fmt::{Display, Error, Formatter}; use serde::{Deserialize, Serialize}; -use tari_common_types::waiting_requests::RequestKey; +use tari_common_types::{types::Signature, waiting_requests::RequestKey}; use tari_crypto::tari_utilities::hex::Hex; /// API Request enum for Mempool requests. diff --git a/base_layer/core/src/mempool/service/service.rs b/base_layer/core/src/mempool/service/service.rs index b8ee487b9c..137dd9d0a0 100644 --- a/base_layer/core/src/mempool/service/service.rs +++ b/base_layer/core/src/mempool/service/service.rs @@ -38,13 +38,7 @@ use crate::{ proto, transactions::transaction::Transaction, }; -use futures::{ - channel::{mpsc, oneshot::Sender as OneshotSender}, - pin_mut, - stream::StreamExt, - SinkExt, - Stream, -}; +use futures::{pin_mut, stream::StreamExt, Stream}; use log::*; use rand::rngs::OsRng; use std::{convert::TryInto, sync::Arc, time::Duration}; @@ -58,7 +52,10 @@ use tari_comms_dht::{ use tari_crypto::tari_utilities::hex::Hex; use tari_p2p::{domain_message::DomainMessage, tari_message::TariMessageType}; use tari_service_framework::{reply_channel, reply_channel::RequestContext}; -use tokio::task; +use tokio::{ + sync::{mpsc, oneshot::Sender as OneshotSender}, + task, +}; const LOG_TARGET: &str = "c::mempool::service::service"; @@ -118,7 +115,7 @@ impl MempoolService { { let outbound_request_stream = streams.outbound_request_stream.fuse(); pin_mut!(outbound_request_stream); - let mut outbound_tx_stream = streams.outbound_tx_stream.fuse(); + let mut outbound_tx_stream = streams.outbound_tx_stream; let inbound_request_stream = streams.inbound_request_stream.fuse(); pin_mut!(inbound_request_stream); let inbound_response_stream = streams.inbound_response_stream.fuse(); @@ -127,70 +124,70 @@ impl MempoolService { pin_mut!(inbound_transaction_stream); let local_request_stream = streams.local_request_stream.fuse(); pin_mut!(local_request_stream); - let mut block_event_stream = streams.block_event_stream.fuse(); + let mut block_event_stream = streams.block_event_stream; let mut timeout_receiver_stream = self .timeout_receiver_stream .take() - .expect("Mempool Service initialized without timeout_receiver_stream") - .fuse(); + .expect("Mempool Service initialized without timeout_receiver_stream"); let mut request_receiver = streams.request_receiver; loop { - futures::select! { + tokio::select! { // Requests sent from the handle - request = request_receiver.select_next_some() => { + Some(request) = request_receiver.next() => { let (request, reply) = request.split(); let _ = reply.send(self.handle_request(request).await); }, // Outbound request messages from the OutboundMempoolServiceInterface - outbound_request_context = outbound_request_stream.select_next_some() => { + Some(outbound_request_context) = outbound_request_stream.next() => { self.spawn_handle_outbound_request(outbound_request_context); }, // Outbound tx messages from the OutboundMempoolServiceInterface - (txn, excluded_peers) = outbound_tx_stream.select_next_some() => { + Some((txn, excluded_peers)) = outbound_tx_stream.recv() => { self.spawn_handle_outbound_tx(txn, excluded_peers); }, // Incoming request messages from the Comms layer - domain_msg = inbound_request_stream.select_next_some() => { + Some(domain_msg) = inbound_request_stream.next() => { self.spawn_handle_incoming_request(domain_msg); }, // Incoming response messages from the Comms layer - domain_msg = inbound_response_stream.select_next_some() => { + Some(domain_msg) = inbound_response_stream.next() => { self.spawn_handle_incoming_response(domain_msg); }, // Incoming transaction messages from the Comms layer - transaction_msg = inbound_transaction_stream.select_next_some() => { + Some(transaction_msg) = inbound_transaction_stream.next() => { self.spawn_handle_incoming_tx(transaction_msg).await; } // Incoming local request messages from the LocalMempoolServiceInterface and other local services - local_request_context = local_request_stream.select_next_some() => { + Some(local_request_context) = local_request_stream.next() => { self.spawn_handle_local_request(local_request_context); }, // Block events from local Base Node. - block_event = block_event_stream.select_next_some() => { + block_event = block_event_stream.recv() => { if let Ok(block_event) = block_event { self.spawn_handle_block_event(block_event); } }, // Timeout events for waiting requests - timeout_request_key = timeout_receiver_stream.select_next_some() => { + Some(timeout_request_key) = timeout_receiver_stream.recv() => { self.spawn_handle_request_timeout(timeout_request_key); }, - complete => { + else => { info!(target: LOG_TARGET, "Mempool service shutting down"); break; } } } + Ok(()) } @@ -490,7 +487,7 @@ async fn handle_outbound_tx( exclude_peers: Vec, ) -> Result<(), MempoolServiceError> { let result = outbound_message_service - .propagate( + .flood( NodeDestination::Unknown, OutboundEncryption::ClearText, exclude_peers, @@ -506,9 +503,9 @@ async fn handle_outbound_tx( Ok(()) } -fn spawn_request_timeout(mut timeout_sender: mpsc::Sender, request_key: RequestKey, timeout: Duration) { +fn spawn_request_timeout(timeout_sender: mpsc::Sender, request_key: RequestKey, timeout: Duration) { task::spawn(async move { - tokio::time::delay_for(timeout).await; + tokio::time::sleep(timeout).await; let _ = timeout_sender.send(request_key).await; }); } diff --git a/base_layer/core/src/mempool/sync_protocol/initializer.rs b/base_layer/core/src/mempool/sync_protocol/initializer.rs index df40de648a..9af871121d 100644 --- a/base_layer/core/src/mempool/sync_protocol/initializer.rs +++ b/base_layer/core/src/mempool/sync_protocol/initializer.rs @@ -28,13 +28,17 @@ use crate::{ MempoolServiceConfig, }, }; -use futures::channel::mpsc; +use log::*; +use std::time::Duration; use tari_comms::{ connectivity::ConnectivityRequester, protocol::{ProtocolExtension, ProtocolExtensionContext, ProtocolExtensionError, ProtocolNotification}, Substream, }; use tari_service_framework::{async_trait, ServiceInitializationError, ServiceInitializer, ServiceInitializerContext}; +use tokio::{sync::mpsc, time::sleep}; + +const LOG_TARGET: &str = "c::mempool::sync_protocol"; pub struct MempoolSyncInitializer { config: MempoolServiceConfig, @@ -70,17 +74,31 @@ impl ServiceInitializer for MempoolSyncInitializer { let mempool = self.mempool.clone(); let notif_rx = self.notif_rx.take().unwrap(); - context.spawn_when_ready(move |handles| { + context.spawn_until_shutdown(move |handles| async move { let state_machine = handles.expect_handle::(); let connectivity = handles.expect_handle::(); - MempoolSyncProtocol::new( - config, - notif_rx, - connectivity.get_event_subscription(), - mempool, - Some(state_machine), - ) - .run() + // Ensure that we get an subscription ASAP so that we don't miss any connectivity events + let connectivity_event_subscription = connectivity.get_event_subscription(); + + let mut status_watch = state_machine.get_status_info_watch(); + if !status_watch.borrow().bootstrapped { + debug!(target: LOG_TARGET, "Waiting for node to bootstrap..."); + while status_watch.changed().await.is_ok() { + if status_watch.borrow().bootstrapped { + debug!(target: LOG_TARGET, "Node bootstrapped. Starting mempool sync protocol"); + break; + } + trace!( + target: LOG_TARGET, + "Mempool sync still on hold, waiting for bootstrap to finish", + ); + sleep(Duration::from_secs(1)).await; + } + } + + MempoolSyncProtocol::new(config, notif_rx, connectivity_event_subscription, mempool) + .run() + .await; }); Ok(()) diff --git a/base_layer/core/src/mempool/sync_protocol/mod.rs b/base_layer/core/src/mempool/sync_protocol/mod.rs index 08219f2a8d..dc133f744d 100644 --- a/base_layer/core/src/mempool/sync_protocol/mod.rs +++ b/base_layer/core/src/mempool/sync_protocol/mod.rs @@ -73,12 +73,11 @@ mod initializer; pub use initializer::MempoolSyncInitializer; use crate::{ - base_node::StateMachineHandle, mempool::{async_mempool, proto, Mempool, MempoolServiceConfig}, proto as shared_proto, transactions::transaction::Transaction, }; -use futures::{stream, stream::Fuse, AsyncRead, AsyncWrite, SinkExt, Stream, StreamExt}; +use futures::{stream, SinkExt, Stream, StreamExt}; use log::*; use prost::Message; use std::{ @@ -88,7 +87,6 @@ use std::{ atomic::{AtomicUsize, Ordering}, Arc, }, - time::Duration, }; use tari_comms::{ connectivity::{ConnectivityEvent, ConnectivityEventRx}, @@ -101,7 +99,11 @@ use tari_comms::{ PeerConnection, }; use tari_crypto::tari_utilities::{hex::Hex, ByteArray}; -use tokio::{sync::Semaphore, task}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::Semaphore, + task, +}; const MAX_FRAME_SIZE: usize = 3 * 1024 * 1024; // 3 MiB const LOG_TARGET: &str = "c::mempool::sync_protocol"; @@ -111,11 +113,10 @@ pub static MEMPOOL_SYNC_PROTOCOL: Bytes = Bytes::from_static(b"t/mempool-sync/1" pub struct MempoolSyncProtocol { config: MempoolServiceConfig, protocol_notifier: ProtocolNotificationRx, - connectivity_events: Fuse, + connectivity_events: ConnectivityEventRx, mempool: Mempool, num_synched: Arc, permits: Arc, - state_machine: Option, } impl MempoolSyncProtocol @@ -126,54 +127,34 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static protocol_notifier: ProtocolNotificationRx, connectivity_events: ConnectivityEventRx, mempool: Mempool, - state_machine: Option, ) -> Self { Self { config, protocol_notifier, - connectivity_events: connectivity_events.fuse(), + connectivity_events, mempool, num_synched: Arc::new(AtomicUsize::new(0)), permits: Arc::new(Semaphore::new(1)), - state_machine, } } pub async fn run(mut self) { info!(target: LOG_TARGET, "Mempool protocol handler has started"); - while let Some(ref v) = self.state_machine { - let status_watch = v.get_status_info_watch(); - if (*status_watch.borrow()).bootstrapped { - break; - } - trace!( - target: LOG_TARGET, - "Mempool sync still on hold, waiting for bootstrap to finish", - ); - tokio::time::delay_for(Duration::from_secs(1)).await; - } + loop { - futures::select! { - event = self.connectivity_events.select_next_some() => { - if let Ok(event) = event { - self.handle_connectivity_event(&*event).await; - } + tokio::select! { + Ok(event) = self.connectivity_events.recv() => { + self.handle_connectivity_event(event).await; }, - notif = self.protocol_notifier.select_next_some() => { + Some(notif) = self.protocol_notifier.recv() => { self.handle_protocol_notification(notif); } - - // protocol_notifier and connectivity_events are closed - complete => { - info!(target: LOG_TARGET, "Mempool protocol handler is shutting down"); - break; - } } } } - async fn handle_connectivity_event(&mut self, event: &ConnectivityEvent) { + async fn handle_connectivity_event(&mut self, event: ConnectivityEvent) { match event { // If this node is connecting to a peer ConnectivityEvent::PeerConnected(conn) if conn.direction().is_outbound() => { diff --git a/base_layer/core/src/mempool/sync_protocol/test.rs b/base_layer/core/src/mempool/sync_protocol/test.rs index dd77fe3c70..30f4b096d2 100644 --- a/base_layer/core/src/mempool/sync_protocol/test.rs +++ b/base_layer/core/src/mempool/sync_protocol/test.rs @@ -30,7 +30,7 @@ use crate::{ transactions::{helpers::create_tx, tari_amount::uT, transaction::Transaction}, validation::mocks::MockValidator, }; -use futures::{channel::mpsc, Sink, SinkExt, Stream, StreamExt}; +use futures::{Sink, SinkExt, Stream, StreamExt}; use std::{fmt, io, iter::repeat_with, sync::Arc}; use tari_comms::{ connectivity::{ConnectivityEvent, ConnectivityEventTx}, @@ -44,7 +44,10 @@ use tari_comms::{ BytesMut, }; use tari_crypto::tari_utilities::ByteArray; -use tokio::{sync::broadcast, task}; +use tokio::{ + sync::{broadcast, mpsc}, + task, +}; pub fn create_transactions(n: usize) -> Vec { repeat_with(|| { @@ -82,7 +85,6 @@ fn setup( protocol_notif_rx, connectivity_events_rx, mempool.clone(), - None, ); task::spawn(protocol.run()); @@ -90,7 +92,7 @@ fn setup( (protocol_notif_tx, connectivity_events_tx, mempool, transactions) } -#[tokio_macros::test_basic] +#[tokio::test] async fn empty_set() { let (_, connectivity_events_tx, mempool1, _) = setup(0); @@ -101,7 +103,7 @@ async fn empty_set() { // This node connected to a peer, so it should open the substream connectivity_events_tx - .send(Arc::new(ConnectivityEvent::PeerConnected(node2_conn))) + .send(ConnectivityEvent::PeerConnected(node2_conn)) .unwrap(); let substream = node1_mock.next_incoming_substream().await.unwrap(); @@ -120,7 +122,7 @@ async fn empty_set() { assert_eq!(transactions.len(), 0); } -#[tokio_macros::test_basic] +#[tokio::test] async fn synchronise() { let (_, connectivity_events_tx, mempool1, transactions1) = setup(5); @@ -131,7 +133,7 @@ async fn synchronise() { // This node connected to a peer, so it should open the substream connectivity_events_tx - .send(Arc::new(ConnectivityEvent::PeerConnected(node2_conn))) + .send(ConnectivityEvent::PeerConnected(node2_conn)) .unwrap(); let substream = node1_mock.next_incoming_substream().await.unwrap(); @@ -154,7 +156,7 @@ async fn synchronise() { assert!(transactions2.iter().all(|txn| transactions.contains(&txn))); } -#[tokio_macros::test_basic] +#[tokio::test] async fn duplicate_set() { let (_, connectivity_events_tx, mempool1, transactions1) = setup(2); @@ -165,7 +167,7 @@ async fn duplicate_set() { // This node connected to a peer, so it should open the substream connectivity_events_tx - .send(Arc::new(ConnectivityEvent::PeerConnected(node2_conn))) + .send(ConnectivityEvent::PeerConnected(node2_conn)) .unwrap(); let substream = node1_mock.next_incoming_substream().await.unwrap(); @@ -189,9 +191,9 @@ async fn duplicate_set() { assert!(transactions2.iter().all(|txn| transactions.contains(&txn))); } -#[tokio_macros::test_basic] +#[tokio::test] async fn responder() { - let (mut protocol_notif, _, _, transactions1) = setup(2); + let (protocol_notif, _, _, transactions1) = setup(2); let node1 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); let node2 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -225,9 +227,9 @@ async fn responder() { // this. } -#[tokio_macros::test_basic] +#[tokio::test] async fn initiator_messages() { - let (mut protocol_notif, _, _, transactions1) = setup(2); + let (protocol_notif, _, _, transactions1) = setup(2); let node1 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -260,7 +262,7 @@ async fn initiator_messages() { assert_eq!(indexes.indexes, [0, 1]); } -#[tokio_macros::test_basic] +#[tokio::test] async fn responder_messages() { let (_, connectivity_events_tx, _, transactions1) = setup(1); @@ -271,7 +273,7 @@ async fn responder_messages() { // This node connected to a peer, so it should open the substream connectivity_events_tx - .send(Arc::new(ConnectivityEvent::PeerConnected(node2_conn))) + .send(ConnectivityEvent::PeerConnected(node2_conn)) .unwrap(); let substream = node1_mock.next_incoming_substream().await.unwrap(); diff --git a/base_layer/core/src/mempool/unconfirmed_pool/unconfirmed_pool.rs b/base_layer/core/src/mempool/unconfirmed_pool/unconfirmed_pool.rs index db4c1e95ec..0a9d7a4e64 100644 --- a/base_layer/core/src/mempool/unconfirmed_pool/unconfirmed_pool.rs +++ b/base_layer/core/src/mempool/unconfirmed_pool/unconfirmed_pool.rs @@ -20,6 +20,15 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use std::{ + collections::{BTreeMap, HashMap}, + sync::Arc, +}; + +use log::*; +use serde::{Deserialize, Serialize}; +use tari_crypto::tari_utilities::{hex::Hex, Hashable}; + use crate::{ blocks::Block, mempool::{ @@ -27,18 +36,9 @@ use crate::{ priority::{FeePriority, PrioritizedTransaction}, unconfirmed_pool::UnconfirmedPoolError, }, - transactions::{ - transaction::Transaction, - types::{HashOutput, Signature}, - }, + transactions::transaction::Transaction, }; -use log::*; -use serde::{Deserialize, Serialize}; -use std::{ - collections::{BTreeMap, HashMap}, - sync::Arc, -}; -use tari_crypto::tari_utilities::{hex::Hex, Hashable}; +use tari_common_types::types::{HashOutput, Signature}; pub const LOG_TARGET: &str = "c::mp::unconfirmed_pool::unconfirmed_pool_storage"; @@ -474,7 +474,8 @@ impl UnconfirmedPool { #[cfg(test)] mod test { - use super::*; + use tari_common::configuration::Network; + use crate::{ consensus::ConsensusManagerBuilder, test_helpers::create_orphan_block, @@ -483,12 +484,14 @@ mod test { helpers::{TestParams, UtxoTestParams}, tari_amount::MicroTari, transaction::KernelFeatures, - types::{CryptoFactories, HashDigest}, + CryptoFactories, SenderTransactionProtocol, }, tx, }; - use tari_common::configuration::Network; + use tari_common_types::types::HashDigest; + + use super::*; #[test] fn test_find_duplicate_input() { diff --git a/base_layer/core/src/proto/block.rs b/base_layer/core/src/proto/block.rs index 94a2f7fd20..50778f5a92 100644 --- a/base_layer/core/src/proto/block.rs +++ b/base_layer/core/src/proto/block.rs @@ -25,10 +25,9 @@ use crate::{ blocks::{Block, NewBlock, NewBlockHeaderTemplate, NewBlockTemplate}, chain_storage::{BlockHeaderAccumulatedData, HistoricalBlock}, proof_of_work::ProofOfWork, - transactions::types::BlindingFactor, }; use std::convert::{TryFrom, TryInto}; -use tari_common_types::types::BLOCK_HASH_LENGTH; +use tari_common_types::types::{BlindingFactor, BLOCK_HASH_LENGTH}; use tari_crypto::tari_utilities::ByteArray; //---------------------------------- Block --------------------------------------------// diff --git a/base_layer/core/src/proto/block_header.rs b/base_layer/core/src/proto/block_header.rs index 4258836106..a2ac77689e 100644 --- a/base_layer/core/src/proto/block_header.rs +++ b/base_layer/core/src/proto/block_header.rs @@ -25,9 +25,9 @@ use crate::{ blocks::BlockHeader, proof_of_work::{PowAlgorithm, ProofOfWork}, proto::utils::{datetime_to_timestamp, timestamp_to_datetime}, - transactions::types::BlindingFactor, }; use std::convert::TryFrom; +use tari_common_types::types::BlindingFactor; use tari_crypto::tari_utilities::ByteArray; //---------------------------------- BlockHeader --------------------------------------------// diff --git a/base_layer/core/src/proto/transaction.rs b/base_layer/core/src/proto/transaction.rs index cb96c11c66..d500157360 100644 --- a/base_layer/core/src/proto/transaction.rs +++ b/base_layer/core/src/proto/transaction.rs @@ -27,7 +27,6 @@ use crate::{ tari_utilities::convert::try_convert_all, transactions::{ aggregated_body::AggregateBody, - bullet_rangeproofs::BulletRangeProof, tari_amount::MicroTari, transaction::{ KernelFeatures, @@ -38,10 +37,10 @@ use crate::{ TransactionKernel, TransactionOutput, }, - types::{BlindingFactor, Commitment, PublicKey}, }, }; use std::convert::{TryFrom, TryInto}; +use tari_common_types::types::{BlindingFactor, BulletRangeProof, Commitment, PublicKey}; use tari_crypto::{ script::{ExecutionStack, TariScript}, tari_utilities::{ByteArray, ByteArrayError}, diff --git a/base_layer/core/src/proto/types_impls.rs b/base_layer/core/src/proto/types_impls.rs index e978d86724..8e865345c1 100644 --- a/base_layer/core/src/proto/types_impls.rs +++ b/base_layer/core/src/proto/types_impls.rs @@ -21,7 +21,8 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use super::types as proto; -use crate::transactions::types::{ +use std::convert::TryFrom; +use tari_common_types::types::{ BlindingFactor, ComSignature, Commitment, @@ -30,7 +31,6 @@ use crate::transactions::types::{ PublicKey, Signature, }; -use std::convert::TryFrom; use tari_crypto::tari_utilities::{ByteArray, ByteArrayError}; //---------------------------------- Commitment --------------------------------------------// diff --git a/base_layer/core/src/test_helpers/blockchain.rs b/base_layer/core/src/test_helpers/blockchain.rs index e871ed10c9..d52df6deb7 100644 --- a/base_layer/core/src/test_helpers/blockchain.rs +++ b/base_layer/core/src/test_helpers/blockchain.rs @@ -20,6 +20,19 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use std::{ + fs, + ops::Deref, + path::{Path, PathBuf}, +}; + +use croaring::Bitmap; + +use tari_common::configuration::Network; +use tari_common_types::chain_metadata::ChainMetadata; +use tari_storage::lmdb_store::LMDBConfig; +use tari_test_utils::paths::create_temporary_data_path; + use crate::{ blocks::{genesis_block::get_weatherwax_genesis_block, Block, BlockHeader}, chain_storage::{ @@ -45,7 +58,7 @@ use crate::{ consensus::{chain_strength_comparer::ChainStrengthComparerBuilder, ConsensusConstantsBuilder, ConsensusManager}, transactions::{ transaction::{TransactionInput, TransactionKernel}, - types::{Commitment, CryptoFactories, HashOutput, Signature}, + CryptoFactories, }, validation::{ block_validators::{BodyOnlyValidator, OrphanBlockValidator}, @@ -53,16 +66,7 @@ use crate::{ DifficultyCalculator, }, }; -use croaring::Bitmap; -use std::{ - fs, - ops::Deref, - path::{Path, PathBuf}, -}; -use tari_common::configuration::Network; -use tari_common_types::chain_metadata::ChainMetadata; -use tari_storage::lmdb_store::LMDBConfig; -use tari_test_utils::paths::create_temporary_data_path; +use tari_common_types::types::{Commitment, HashOutput, Signature}; /// Create a new blockchain database containing no blocks. pub fn create_new_blockchain() -> BlockchainDatabase { @@ -111,7 +115,7 @@ pub fn create_store_with_consensus(rules: ConsensusManager) -> BlockchainDatabas let validators = Validators::new( BodyOnlyValidator::default(), MockValidator::new(true), - OrphanBlockValidator::new(rules.clone(), factories), + OrphanBlockValidator::new(rules.clone(), false, factories), ); create_store_with_consensus_and_validators(rules, validators) } diff --git a/base_layer/core/src/test_helpers/mod.rs b/base_layer/core/src/test_helpers/mod.rs index a1055b75da..bbcb23da67 100644 --- a/base_layer/core/src/test_helpers/mod.rs +++ b/base_layer/core/src/test_helpers/mod.rs @@ -23,7 +23,13 @@ //! Common test helper functions that are small and useful enough to be included in the main crate, rather than the //! integration test folder. -pub mod blockchain; +use std::{iter, path::Path, sync::Arc}; + +use rand::{distributions::Alphanumeric, Rng}; + +use tari_common::configuration::Network; +use tari_comms::PeerManager; +use tari_storage::{lmdb_store::LMDBBuilder, LMDBWrapper}; use crate::{ blocks::{Block, BlockHeader}, @@ -34,15 +40,12 @@ use crate::{ transactions::{ tari_amount::T, transaction::{Transaction, UnblindedOutput}, - types::CryptoFactories, CoinbaseBuilder, + CryptoFactories, }, }; -use rand::{distributions::Alphanumeric, Rng}; -use std::{iter, path::Path, sync::Arc}; -use tari_common::configuration::Network; -use tari_comms::PeerManager; -use tari_storage::{lmdb_store::LMDBBuilder, LMDBWrapper}; + +pub mod blockchain; /// Create a partially constructed block using the provided set of transactions /// is chain_block, or rename it to `create_orphan_block` and drop the prev_block argument diff --git a/base_layer/core/src/transactions/aggregated_body.rs b/base_layer/core/src/transactions/aggregated_body.rs index d1dfb4f198..ac44b04b4d 100644 --- a/base_layer/core/src/transactions/aggregated_body.rs +++ b/base_layer/core/src/transactions/aggregated_body.rs @@ -1,3 +1,14 @@ +use std::fmt::{Display, Error, Formatter}; + +use log::*; +use serde::{Deserialize, Serialize}; +use tari_crypto::{ + commitment::HomomorphicCommitmentFactory, + keys::PublicKey as PublicKeyTrait, + ristretto::pedersen::PedersenCommitment, + tari_utilities::hex::Hex, +}; + // Copyright 2019, The Tari Project // // Redistribution and use in source and binary forms, with or without modification, are permitted provided that the @@ -19,20 +30,14 @@ // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::transactions::{ - fee::Fee, - tari_amount::*, - transaction::*, - types::{BlindingFactor, Commitment, CommitmentFactory, CryptoFactories, PrivateKey, PublicKey, RangeProofService}, -}; -use log::*; -use serde::{Deserialize, Serialize}; -use std::fmt::{Display, Error, Formatter}; -use tari_crypto::{ - commitment::HomomorphicCommitmentFactory, - keys::PublicKey as PublicKeyTrait, - ristretto::pedersen::PedersenCommitment, - tari_utilities::hex::Hex, +use crate::transactions::{crypto_factories::CryptoFactories, fee::Fee, tari_amount::*, transaction::*}; +use tari_common_types::types::{ + BlindingFactor, + Commitment, + CommitmentFactory, + PrivateKey, + PublicKey, + RangeProofService, }; pub const LOG_TARGET: &str = "c::tx::aggregated_body"; @@ -307,6 +312,7 @@ impl AggregateBody { &self, tx_offset: &BlindingFactor, script_offset: &BlindingFactor, + bypass_range_proof_verification: bool, total_reward: MicroTari, factories: &CryptoFactories, ) -> Result<(), TransactionError> { @@ -316,7 +322,9 @@ impl AggregateBody { self.verify_kernel_signatures()?; self.validate_kernel_sum(total_offset, &factories.commitment)?; - self.validate_range_proofs(&factories.range_proof)?; + if !bypass_range_proof_verification { + self.validate_range_proofs(&factories.range_proof)?; + } self.verify_metadata_signatures()?; self.validate_script_offset(script_offset_g, &factories.commitment) } diff --git a/base_layer/core/src/transactions/bullet_rangeproofs.rs b/base_layer/core/src/transactions/bullet_rangeproofs.rs index 9d96e2bb03..5ba0a05923 100644 --- a/base_layer/core/src/transactions/bullet_rangeproofs.rs +++ b/base_layer/core/src/transactions/bullet_rangeproofs.rs @@ -20,7 +20,6 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::transactions::types::HashDigest; use digest::Digest; use serde::{ de::{self, Visitor}, @@ -30,6 +29,7 @@ use serde::{ Serializer, }; use std::fmt; +use tari_common_types::types::HashDigest; use tari_crypto::tari_utilities::{byte_array::*, hash::*, hex::*}; #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord)] diff --git a/base_layer/core/src/transactions/coinbase_builder.rs b/base_layer/core/src/transactions/coinbase_builder.rs index 52cb4ac0ba..5091844e6a 100644 --- a/base_layer/core/src/transactions/coinbase_builder.rs +++ b/base_layer/core/src/transactions/coinbase_builder.rs @@ -21,12 +21,23 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // +use rand::rngs::OsRng; +use tari_crypto::{ + commitment::HomomorphicCommitmentFactory, + inputs, + keys::{PublicKey as PK, SecretKey}, + script, + script::TariScript, +}; +use thiserror::Error; + use crate::{ consensus::{ emission::{Emission, EmissionSchedule}, ConsensusConstants, }, transactions::{ + crypto_factories::CryptoFactories, tari_amount::{uT, MicroTari}, transaction::{ KernelBuilder, @@ -38,18 +49,9 @@ use crate::{ UnblindedOutput, }, transaction_protocol::{build_challenge, RewindData, TransactionMetadata}, - types::{BlindingFactor, CryptoFactories, PrivateKey, PublicKey, Signature}, }, }; -use rand::rngs::OsRng; -use tari_crypto::{ - commitment::HomomorphicCommitmentFactory, - inputs, - keys::{PublicKey as PK, SecretKey}, - script, - script::TariScript, -}; -use thiserror::Error; +use tari_common_types::types::{BlindingFactor, PrivateKey, PublicKey, Signature}; #[derive(Debug, Clone, Error, PartialEq)] pub enum CoinbaseBuildError { @@ -241,21 +243,24 @@ impl CoinbaseBuilder { #[cfg(test)] mod test { + use rand::rngs::OsRng; + use tari_crypto::{commitment::HomomorphicCommitmentFactory, keys::SecretKey as SecretKeyTrait}; + + use tari_common::configuration::Network; + use crate::{ consensus::{emission::Emission, ConsensusManager, ConsensusManagerBuilder}, transactions::{ coinbase_builder::CoinbaseBuildError, + crypto_factories::CryptoFactories, helpers::TestParams, tari_amount::uT, transaction::{KernelFeatures, OutputFeatures, OutputFlags, TransactionError}, transaction_protocol::RewindData, - types::{BlindingFactor, CryptoFactories, PrivateKey}, CoinbaseBuilder, }, }; - use rand::rngs::OsRng; - use tari_common::configuration::Network; - use tari_crypto::{commitment::HomomorphicCommitmentFactory, keys::SecretKey as SecretKeyTrait}; + use tari_common_types::types::{BlindingFactor, PrivateKey}; fn get_builder() -> (CoinbaseBuilder, ConsensusManager, CryptoFactories) { let network = Network::LocalNet; @@ -520,6 +525,7 @@ mod test { tx.body.validate_internal_consistency( &BlindingFactor::default(), &PrivateKey::default(), + false, block_reward, &factories ), diff --git a/base_layer/core/src/transactions/crypto_factories.rs b/base_layer/core/src/transactions/crypto_factories.rs new file mode 100644 index 0000000000..86270dc42d --- /dev/null +++ b/base_layer/core/src/transactions/crypto_factories.rs @@ -0,0 +1,45 @@ +use std::sync::Arc; + +use tari_common_types::types::{CommitmentFactory, RangeProofService, MAX_RANGE_PROOF_RANGE}; + +/// A convenience struct wrapping cryptographic factories that are used through-out the rest of the code base +/// Uses Arcs internally so calling clone on this is cheap, no need to wrap this in an Arc +pub struct CryptoFactories { + pub commitment: Arc, + pub range_proof: Arc, +} + +impl Default for CryptoFactories { + /// Return a default set of crypto factories based on Pedersen commitments with G and H defined in + /// [pedersen.rs](/infrastructure/crypto/src/ristretto/pedersen.rs), and an associated range proof factory with a + /// range of `[0; 2^64)`. + fn default() -> Self { + CryptoFactories::new(MAX_RANGE_PROOF_RANGE) + } +} + +impl CryptoFactories { + /// Create a new set of crypto factories. + /// + /// ## Parameters + /// + /// * `max_proof_range`: Sets the the maximum value in range proofs, where `max = 2^max_proof_range` + pub fn new(max_proof_range: usize) -> Self { + let commitment = Arc::new(CommitmentFactory::default()); + let range_proof = Arc::new(RangeProofService::new(max_proof_range, &commitment).unwrap()); + Self { + commitment, + range_proof, + } + } +} + +/// Uses Arc's internally so calling clone on this is cheap, no need to wrap this in an Arc +impl Clone for CryptoFactories { + fn clone(&self) -> Self { + Self { + commitment: self.commitment.clone(), + range_proof: self.range_proof.clone(), + } + } +} diff --git a/base_layer/core/src/transactions/helpers.rs b/base_layer/core/src/transactions/helpers.rs index 8e6c4d4c7b..54f90cebb8 100644 --- a/base_layer/core/src/transactions/helpers.rs +++ b/base_layer/core/src/transactions/helpers.rs @@ -20,7 +20,22 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use std::sync::Arc; + +use num::pow; +use rand::rngs::OsRng; +use tari_crypto::{ + commitment::HomomorphicCommitmentFactory, + common::Blake256, + inputs, + keys::{PublicKey as PK, SecretKey}, + range_proof::RangeProofService, + script, + script::{ExecutionStack, TariScript}, +}; + use crate::transactions::{ + crypto_factories::CryptoFactories, fee::Fee, tari_amount::MicroTari, transaction::{ @@ -34,21 +49,9 @@ use crate::transactions::{ UnblindedOutput, }, transaction_protocol::{build_challenge, TransactionMetadata}, - types::{Commitment, CommitmentFactory, CryptoFactories, PrivateKey, PublicKey, Signature}, SenderTransactionProtocol, }; -use num::pow; -use rand::rngs::OsRng; -use std::sync::Arc; -use tari_crypto::{ - commitment::HomomorphicCommitmentFactory, - common::Blake256, - inputs, - keys::{PublicKey as PK, SecretKey}, - range_proof::RangeProofService, - script, - script::{ExecutionStack, TariScript}, -}; +use tari_common_types::types::{Commitment, CommitmentFactory, PrivateKey, PublicKey, Signature}; pub fn create_test_input( amount: MicroTari, diff --git a/base_layer/core/src/transactions/mod.rs b/base_layer/core/src/transactions/mod.rs index 9d98bd394b..5653939bd4 100644 --- a/base_layer/core/src/transactions/mod.rs +++ b/base_layer/core/src/transactions/mod.rs @@ -1,20 +1,19 @@ pub mod aggregated_body; -pub mod bullet_rangeproofs; +mod crypto_factories; pub mod fee; pub mod tari_amount; pub mod transaction; #[allow(clippy::op_ref)] pub mod transaction_protocol; + +pub use crypto_factories::*; + pub mod types; // Re-export commonly used structs pub use transaction_protocol::{recipient::ReceiverTransactionProtocol, sender::SenderTransactionProtocol}; #[macro_use] pub mod helpers; -#[cfg(any(feature = "base_node", feature = "transactions"))] -mod coinbase_builder; -#[cfg(any(feature = "base_node", feature = "transactions"))] -pub use crate::transactions::coinbase_builder::CoinbaseBuildError; -#[cfg(any(feature = "base_node", feature = "transactions"))] -pub use crate::transactions::coinbase_builder::CoinbaseBuilder; +mod coinbase_builder; +pub use crate::transactions::coinbase_builder::{CoinbaseBuildError, CoinbaseBuilder}; diff --git a/base_layer/core/src/transactions/transaction.rs b/base_layer/core/src/transactions/transaction.rs index 81d972d90b..ce3e2c8aec 100644 --- a/base_layer/core/src/transactions/transaction.rs +++ b/base_layer/core/src/transactions/transaction.rs @@ -23,29 +23,6 @@ // Portions of this file were originally copyrighted (c) 2018 The Grin Developers, issued under the Apache License, // Version 2.0, available at http://www.apache.org/licenses/LICENSE-2.0. -use crate::transactions::{ - aggregated_body::AggregateBody, - tari_amount::{uT, MicroTari}, - transaction_protocol::{build_challenge, RewindData, TransactionMetadata}, - types::{ - BlindingFactor, - Challenge, - ComSignature, - Commitment, - CommitmentFactory, - CryptoFactories, - HashDigest, - MessageHash, - PrivateKey, - PublicKey, - RangeProof, - RangeProofService, - Signature, - }, -}; -use blake2::Digest; -use rand::rngs::OsRng; -use serde::{Deserialize, Serialize}; use std::{ cmp::{max, min, Ordering}, fmt, @@ -53,6 +30,10 @@ use std::{ hash::{Hash, Hasher}, ops::Add, }; + +use blake2::Digest; +use rand::rngs::OsRng; +use serde::{Deserialize, Serialize}; use tari_crypto::{ commitment::HomomorphicCommitmentFactory, keys::{PublicKey as PublicKeyTrait, SecretKey}, @@ -70,6 +51,27 @@ use tari_crypto::{ }; use thiserror::Error; +use crate::transactions::{ + aggregated_body::AggregateBody, + crypto_factories::CryptoFactories, + tari_amount::{uT, MicroTari}, + transaction_protocol::{build_challenge, RewindData, TransactionMetadata}, +}; +use tari_common_types::types::{ + BlindingFactor, + Challenge, + ComSignature, + Commitment, + CommitmentFactory, + HashDigest, + MessageHash, + PrivateKey, + PublicKey, + RangeProof, + RangeProofService, + Signature, +}; + // Tx_weight(inputs(12,500), outputs(500), kernels(1)) = 19,003, still well enough below block weight of 19,500 pub const MAX_TRANSACTION_INPUTS: usize = 12_500; pub const MAX_TRANSACTION_OUTPUTS: usize = 500; @@ -1109,12 +1111,18 @@ impl Transaction { #[allow(clippy::erasing_op)] // This is for 0 * uT pub fn validate_internal_consistency( &self, + bypass_range_proof_verification: bool, factories: &CryptoFactories, reward: Option, ) -> Result<(), TransactionError> { let reward = reward.unwrap_or_else(|| 0 * uT); - self.body - .validate_internal_consistency(&self.offset, &self.script_offset, reward, factories) + self.body.validate_internal_consistency( + &self.offset, + &self.script_offset, + bypass_range_proof_verification, + reward, + factories, + ) } pub fn get_body(&self) -> &AggregateBody { @@ -1264,7 +1272,7 @@ impl TransactionBuilder { if let (Some(script_offset), Some(offset)) = (self.script_offset, self.offset) { let (i, o, k) = self.body.dissolve(); let tx = Transaction::new(i, o, k, offset, script_offset); - tx.validate_internal_consistency(factories, self.reward)?; + tx.validate_internal_consistency(true, factories, self.reward)?; Ok(tx) } else { Err(TransactionError::ValidationError( @@ -1289,24 +1297,26 @@ impl Default for TransactionBuilder { #[cfg(test)] mod test { - use super::*; + use rand::{self, rngs::OsRng}; + use tari_crypto::{ + keys::{PublicKey as PublicKeyTrait, SecretKey as SecretKeyTrait}, + ristretto::pedersen::PedersenCommitmentFactory, + script, + script::ExecutionStack, + }; + use crate::{ transactions::{ helpers, helpers::{TestParams, UtxoTestParams}, tari_amount::T, transaction::OutputFeatures, - types::{BlindingFactor, PrivateKey, PublicKey, RangeProof}, }, txn_schema, }; - use rand::{self, rngs::OsRng}; - use tari_crypto::{ - keys::{PublicKey as PublicKeyTrait, SecretKey as SecretKeyTrait}, - ristretto::pedersen::PedersenCommitmentFactory, - script, - script::ExecutionStack, - }; + use tari_common_types::types::{BlindingFactor, PrivateKey, PublicKey}; + + use super::*; #[test] fn input_and_output_hash_match() { @@ -1514,7 +1524,7 @@ mod test { let (tx, _, _) = helpers::create_tx(5000.into(), 15.into(), 1, 2, 1, 4); let factories = CryptoFactories::default(); - assert!(tx.validate_internal_consistency(&factories, None).is_ok()); + assert!(tx.validate_internal_consistency(false, &factories, None).is_ok()); } #[test] @@ -1527,7 +1537,7 @@ mod test { assert_eq!(tx.body.kernels().len(), 1); let factories = CryptoFactories::default(); - assert!(tx.validate_internal_consistency(&factories, None).is_ok()); + assert!(tx.validate_internal_consistency(false, &factories, None).is_ok()); let schema = txn_schema!(from: vec![outputs[1].clone()], to: vec![1 * T, 2 * T]); let (tx2, _outputs, _) = helpers::spend_utxos(schema); @@ -1558,10 +1568,12 @@ mod test { } // Validate basis transaction where cut-through has not been applied. - assert!(tx3.validate_internal_consistency(&factories, None).is_ok()); + assert!(tx3.validate_internal_consistency(false, &factories, None).is_ok()); // tx3_cut_through has manual cut-through, it should not be possible so this should fail - assert!(tx3_cut_through.validate_internal_consistency(&factories, None).is_err()); + assert!(tx3_cut_through + .validate_internal_consistency(false, &factories, None) + .is_err()); } #[test] @@ -1598,7 +1610,7 @@ mod test { tx.body.inputs_mut()[0].input_data = stack; let factories = CryptoFactories::default(); - let err = tx.validate_internal_consistency(&factories, None).unwrap_err(); + let err = tx.validate_internal_consistency(false, &factories, None).unwrap_err(); assert!(matches!(err, TransactionError::InvalidSignatureError(_))); } diff --git a/base_layer/core/src/transactions/transaction_protocol/mod.rs b/base_layer/core/src/transactions/transaction_protocol/mod.rs index 2f66ef3643..5e140c4388 100644 --- a/base_layer/core/src/transactions/transaction_protocol/mod.rs +++ b/base_layer/core/src/transactions/transaction_protocol/mod.rs @@ -86,13 +86,11 @@ pub mod sender; pub mod single_receiver; pub mod transaction_initializer; -use crate::transactions::{ - tari_amount::*, - transaction::TransactionError, - types::{Challenge, MessageHash, PrivateKey, PublicKey}, -}; +use crate::transactions::{tari_amount::*, transaction::TransactionError}; use digest::Digest; use serde::{Deserialize, Serialize}; +use tari_common_types::types::{MessageHash, PrivateKey, PublicKey}; +use tari_comms::types::Challenge; use tari_crypto::{ range_proof::{RangeProofError, REWIND_USER_MESSAGE_LENGTH}, signatures::SchnorrSignatureError, diff --git a/base_layer/core/src/transactions/transaction_protocol/proto/recipient_signed_message.rs b/base_layer/core/src/transactions/transaction_protocol/proto/recipient_signed_message.rs index 699cf3145a..c149874ef3 100644 --- a/base_layer/core/src/transactions/transaction_protocol/proto/recipient_signed_message.rs +++ b/base_layer/core/src/transactions/transaction_protocol/proto/recipient_signed_message.rs @@ -22,8 +22,9 @@ use super::protocol as proto; -use crate::transactions::{transaction_protocol::recipient::RecipientSignedMessage, types::PublicKey}; +use crate::transactions::transaction_protocol::recipient::RecipientSignedMessage; use std::convert::{TryFrom, TryInto}; +use tari_common_types::types::PublicKey; use tari_crypto::tari_utilities::ByteArray; impl TryFrom for RecipientSignedMessage { diff --git a/base_layer/core/src/transactions/transaction_protocol/proto/transaction_sender.rs b/base_layer/core/src/transactions/transaction_protocol/proto/transaction_sender.rs index 22f0e59306..14c0f7ee2c 100644 --- a/base_layer/core/src/transactions/transaction_protocol/proto/transaction_sender.rs +++ b/base_layer/core/src/transactions/transaction_protocol/proto/transaction_sender.rs @@ -28,8 +28,8 @@ use std::convert::{TryFrom, TryInto}; use tari_crypto::tari_utilities::ByteArray; // The generated _oneof_ enum -use crate::transactions::types::PublicKey; use proto::transaction_sender_message::Message as ProtoTxnSenderMessage; +use tari_common_types::types::PublicKey; use tari_crypto::script::TariScript; impl proto::TransactionSenderMessage { diff --git a/base_layer/core/src/transactions/transaction_protocol/recipient.rs b/base_layer/core/src/transactions/transaction_protocol/recipient.rs index e9f21c3778..2518f8f1de 100644 --- a/base_layer/core/src/transactions/transaction_protocol/recipient.rs +++ b/base_layer/core/src/transactions/transaction_protocol/recipient.rs @@ -20,7 +20,12 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use std::{collections::HashMap, fmt}; + +use serde::{Deserialize, Serialize}; + use crate::transactions::{ + crypto_factories::CryptoFactories, transaction::{OutputFeatures, TransactionOutput}, transaction_protocol::{ sender::{SingleRoundSenderData as SD, TransactionSenderMessage}, @@ -28,10 +33,8 @@ use crate::transactions::{ RewindData, TransactionProtocolError, }, - types::{CryptoFactories, MessageHash, PrivateKey, PublicKey, Signature}, }; -use serde::{Deserialize, Serialize}; -use std::{collections::HashMap, fmt}; +use tari_common_types::types::{MessageHash, PrivateKey, PublicKey, Signature}; #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] #[allow(clippy::large_enum_variant)] @@ -202,9 +205,16 @@ impl ReceiverTransactionProtocol { #[cfg(test)] mod test { + use rand::rngs::OsRng; + use tari_crypto::{ + commitment::HomomorphicCommitmentFactory, + keys::{PublicKey as PK, SecretKey as SecretKeyTrait}, + }; + use crate::{ crypto::script::TariScript, transactions::{ + crypto_factories::CryptoFactories, helpers::TestParams, tari_amount::*, transaction::OutputFeatures, @@ -214,15 +224,10 @@ mod test { RewindData, TransactionMetadata, }, - types::{CryptoFactories, PrivateKey, PublicKey, Signature}, ReceiverTransactionProtocol, }, }; - use rand::rngs::OsRng; - use tari_crypto::{ - commitment::HomomorphicCommitmentFactory, - keys::{PublicKey as PK, SecretKey as SecretKeyTrait}, - }; + use tari_common_types::types::{PrivateKey, PublicKey, Signature}; #[test] fn single_round_recipient() { diff --git a/base_layer/core/src/transactions/transaction_protocol/sender.rs b/base_layer/core/src/transactions/transaction_protocol/sender.rs index 0341dcbda1..c91097d32a 100644 --- a/base_layer/core/src/transactions/transaction_protocol/sender.rs +++ b/base_layer/core/src/transactions/transaction_protocol/sender.rs @@ -20,7 +20,19 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use std::fmt; + +use digest::Digest; +use serde::{Deserialize, Serialize}; +use tari_crypto::{ + keys::PublicKey as PublicKeyTrait, + ristretto::pedersen::{PedersenCommitment, PedersenCommitmentFactory}, + script::TariScript, + tari_utilities::ByteArray, +}; + use crate::transactions::{ + crypto_factories::CryptoFactories, tari_amount::*, transaction::{ KernelBuilder, @@ -42,17 +54,8 @@ use crate::transactions::{ TransactionMetadata, TransactionProtocolError as TPE, }, - types::{BlindingFactor, ComSignature, CryptoFactories, PrivateKey, PublicKey, RangeProofService, Signature}, -}; -use digest::Digest; -use serde::{Deserialize, Serialize}; -use std::fmt; -use tari_crypto::{ - keys::PublicKey as PublicKeyTrait, - ristretto::pedersen::{PedersenCommitment, PedersenCommitmentFactory}, - script::TariScript, - tari_utilities::ByteArray, }; +use tari_common_types::types::{BlindingFactor, ComSignature, PrivateKey, PublicKey, RangeProofService, Signature}; //---------------------------------------- Local Data types ----------------------------------------------------// @@ -562,7 +565,7 @@ impl SenderTransactionProtocol { } let transaction = result.unwrap(); let result = transaction - .validate_internal_consistency(factories, None) + .validate_internal_consistency(true, factories, None) .map_err(TPE::TransactionBuildError); if let Err(e) = result { self.state = SenderState::Failed(e.clone()); @@ -705,7 +708,20 @@ impl fmt::Display for SenderState { #[cfg(test)] mod test { + use rand::rngs::OsRng; + use tari_crypto::{ + commitment::HomomorphicCommitmentFactory, + common::Blake256, + keys::{PublicKey as PublicKeyTrait, SecretKey as SecretKeyTrait}, + range_proof::RangeProofService, + ristretto::pedersen::PedersenCommitmentFactory, + script, + script::{ExecutionStack, TariScript}, + tari_utilities::{hex::Hex, ByteArray}, + }; + use crate::transactions::{ + crypto_factories::CryptoFactories, fee::Fee, helpers::{create_test_input, create_unblinded_output, TestParams}, tari_amount::*, @@ -716,19 +732,8 @@ mod test { RewindData, TransactionProtocolError, }, - types::{CryptoFactories, PrivateKey, PublicKey, RangeProof}, - }; - use rand::rngs::OsRng; - use tari_crypto::{ - commitment::HomomorphicCommitmentFactory, - common::Blake256, - keys::{PublicKey as PublicKeyTrait, SecretKey as SecretKeyTrait}, - range_proof::RangeProofService, - ristretto::pedersen::PedersenCommitmentFactory, - script, - script::{ExecutionStack, TariScript}, - tari_utilities::{hex::Hex, ByteArray}, }; + use tari_common_types::types::{PrivateKey, PublicKey, RangeProof}; #[test] fn test_metadata_signature_finalize() { @@ -965,7 +970,10 @@ mod test { assert_eq!(tx.body.inputs().len(), 1); assert_eq!(tx.body.inputs()[0], utxo); assert_eq!(tx.body.outputs().len(), 2); - assert!(tx.clone().validate_internal_consistency(&factories, None).is_ok()); + assert!(tx + .clone() + .validate_internal_consistency(false, &factories, None) + .is_ok()); } #[test] diff --git a/base_layer/core/src/transactions/transaction_protocol/single_receiver.rs b/base_layer/core/src/transactions/transaction_protocol/single_receiver.rs index 5d60e64acc..7f6060fed6 100644 --- a/base_layer/core/src/transactions/transaction_protocol/single_receiver.rs +++ b/base_layer/core/src/transactions/transaction_protocol/single_receiver.rs @@ -20,7 +20,15 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use tari_crypto::{ + commitment::HomomorphicCommitmentFactory, + keys::PublicKey as PK, + range_proof::{RangeProofError, RangeProofService as RPS}, + tari_utilities::byte_array::ByteArray, +}; + use crate::transactions::{ + crypto_factories::CryptoFactories, transaction::{OutputFeatures, TransactionOutput}, transaction_protocol::{ build_challenge, @@ -29,14 +37,8 @@ use crate::transactions::{ RewindData, TransactionProtocolError as TPE, }, - types::{CryptoFactories, PrivateKey as SK, PublicKey, RangeProof, Signature}, -}; -use tari_crypto::{ - commitment::HomomorphicCommitmentFactory, - keys::PublicKey as PK, - range_proof::{RangeProofError, RangeProofService as RPS}, - tari_utilities::byte_array::ByteArray, }; +use tari_common_types::types::{PrivateKey as SK, PublicKey, RangeProof, Signature}; /// SingleReceiverTransactionProtocol represents the actions taken by the single receiver in the one-round Tari /// transaction protocol. The procedure is straightforward. Upon receiving the sender's information, the receiver: @@ -133,7 +135,15 @@ impl SingleReceiverTransactionProtocol { #[cfg(test)] mod test { + use rand::rngs::OsRng; + use tari_crypto::{ + commitment::HomomorphicCommitmentFactory, + keys::{PublicKey as PK, SecretKey as SK}, + script::TariScript, + }; + use crate::transactions::{ + crypto_factories::CryptoFactories, tari_amount::*, transaction::OutputFeatures, transaction_protocol::{ @@ -143,14 +153,8 @@ mod test { TransactionMetadata, TransactionProtocolError, }, - types::{CryptoFactories, PrivateKey, PublicKey}, - }; - use rand::rngs::OsRng; - use tari_crypto::{ - commitment::HomomorphicCommitmentFactory, - keys::{PublicKey as PK, SecretKey as SK}, - script::TariScript, }; + use tari_common_types::types::{PrivateKey, PublicKey}; fn generate_output_parms() -> (PrivateKey, PrivateKey, OutputFeatures) { let r = PrivateKey::random(&mut OsRng); diff --git a/base_layer/core/src/transactions/transaction_protocol/transaction_initializer.rs b/base_layer/core/src/transactions/transaction_protocol/transaction_initializer.rs index 0d5beb738d..b27a1e528f 100644 --- a/base_layer/core/src/transactions/transaction_protocol/transaction_initializer.rs +++ b/base_layer/core/src/transactions/transaction_protocol/transaction_initializer.rs @@ -20,7 +20,24 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use std::{ + collections::HashMap, + fmt::{Debug, Error, Formatter}, +}; + +use digest::Digest; +use log::*; +use rand::rngs::OsRng; +use tari_crypto::{ + commitment::HomomorphicCommitmentFactory, + keys::{PublicKey as PublicKeyTrait, SecretKey}, + ristretto::pedersen::PedersenCommitmentFactory, + script::{ExecutionStack, TariScript}, + tari_utilities::fixed_set::FixedSet, +}; + use crate::transactions::{ + crypto_factories::CryptoFactories, fee::Fee, tari_amount::*, transaction::{ @@ -38,22 +55,8 @@ use crate::transactions::{ RewindData, TransactionMetadata, }, - types::{BlindingFactor, CryptoFactories, PrivateKey, PublicKey}, -}; -use digest::Digest; -use log::*; -use rand::rngs::OsRng; -use std::{ - collections::HashMap, - fmt::{Debug, Error, Formatter}, -}; -use tari_crypto::{ - commitment::HomomorphicCommitmentFactory, - keys::{PublicKey as PublicKeyTrait, SecretKey}, - ristretto::pedersen::PedersenCommitmentFactory, - script::{ExecutionStack, TariScript}, - tari_utilities::fixed_set::FixedSet, }; +use tari_common_types::types::{BlindingFactor, PrivateKey, PublicKey}; pub const LOG_TARGET: &str = "c::tx::tx_protocol::tx_initializer"; @@ -571,9 +574,18 @@ impl SenderTransactionInitializer { #[cfg(test)] mod test { + use rand::rngs::OsRng; + use tari_crypto::{ + common::Blake256, + keys::SecretKey, + script, + script::{ExecutionStack, TariScript}, + }; + use crate::{ consensus::{KERNEL_WEIGHT, WEIGHT_PER_INPUT, WEIGHT_PER_OUTPUT}, transactions::{ + crypto_factories::CryptoFactories, fee::Fee, helpers::{create_test_input, create_unblinded_output, TestParams, UtxoTestParams}, tari_amount::*, @@ -583,16 +595,9 @@ mod test { transaction_initializer::SenderTransactionInitializer, TransactionProtocolError, }, - types::{CryptoFactories, PrivateKey}, }, }; - use rand::rngs::OsRng; - use tari_crypto::{ - common::Blake256, - keys::SecretKey, - script, - script::{ExecutionStack, TariScript}, - }; + use tari_common_types::types::PrivateKey; /// One input, 2 outputs #[test] @@ -763,6 +768,7 @@ mod test { .with_output(output, p.sender_offset_private_key) .unwrap() .with_fee_per_gram(MicroTari(2)); + for _ in 0..MAX_TRANSACTION_INPUTS + 1 { let (utxo, input) = create_test_input(MicroTari(50), 0, &factories.commitment); builder.with_input(utxo, input); diff --git a/base_layer/core/src/transactions/types.rs b/base_layer/core/src/transactions/types.rs index 40b0fa625d..d051788bca 100644 --- a/base_layer/core/src/transactions/types.rs +++ b/base_layer/core/src/transactions/types.rs @@ -19,99 +19,3 @@ // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -use crate::transactions::bullet_rangeproofs::BulletRangeProof; -use std::sync::Arc; -use tari_crypto::{ - common::Blake256, - ristretto::{ - dalek_range_proof::DalekRangeProofService, - pedersen::{PedersenCommitment, PedersenCommitmentFactory}, - RistrettoComSig, - RistrettoPublicKey, - RistrettoSchnorr, - RistrettoSecretKey, - }, -}; - -/// Define the explicit Signature implementation for the Tari base layer. A different signature scheme can be -/// employed by redefining this type. -pub type Signature = RistrettoSchnorr; -/// Define the explicit Commitment Signature implementation for the Tari base layer. -pub type ComSignature = RistrettoComSig; - -/// Define the explicit Commitment implementation for the Tari base layer. -pub type Commitment = PedersenCommitment; -pub type CommitmentFactory = PedersenCommitmentFactory; - -/// Define the explicit Secret key implementation for the Tari base layer. -pub type PrivateKey = RistrettoSecretKey; -pub type BlindingFactor = RistrettoSecretKey; - -/// Define the hash function that will be used to produce a signature challenge -pub type SignatureHasher = Blake256; - -/// Define the explicit Public key implementation for the Tari base layer -pub type PublicKey = RistrettoPublicKey; - -/// Specify the Hash function for general hashing -pub type HashDigest = Blake256; - -/// Specify the digest type for signature challenges -pub type Challenge = Blake256; - -/// The type of output that `Challenge` produces -pub type MessageHash = Vec; - -/// Specify the range proof type -pub type RangeProofService = DalekRangeProofService; - -/// Specify the range proof -pub type RangeProof = BulletRangeProof; - -/// Define the data type that is used to store results of `HashDigest` -pub type HashOutput = Vec; - -pub const MAX_RANGE_PROOF_RANGE: usize = 64; // 2^64 - -/// A convenience struct wrapping cryptographic factories that are used through-out the rest of the code base -/// Uses Arcs internally so calling clone on this is cheap, no need to wrap this in an Arc -pub struct CryptoFactories { - pub commitment: Arc, - pub range_proof: Arc, -} - -impl Default for CryptoFactories { - /// Return a default set of crypto factories based on Pedersen commitments with G and H defined in - /// [pedersen.rs](/infrastructure/crypto/src/ristretto/pedersen.rs), and an associated range proof factory with a - /// range of `[0; 2^64)`. - fn default() -> Self { - CryptoFactories::new(MAX_RANGE_PROOF_RANGE) - } -} - -impl CryptoFactories { - /// Create a new set of crypto factories. - /// - /// ## Parameters - /// - /// * `max_proof_range`: Sets the the maximum value in range proofs, where `max = 2^max_proof_range` - pub fn new(max_proof_range: usize) -> Self { - let commitment = Arc::new(CommitmentFactory::default()); - let range_proof = Arc::new(RangeProofService::new(max_proof_range, &commitment).unwrap()); - Self { - commitment, - range_proof, - } - } -} - -/// Uses Arc's internally so calling clone on this is cheap, no need to wrap this in an Arc -impl Clone for CryptoFactories { - fn clone(&self) -> Self { - Self { - commitment: self.commitment.clone(), - range_proof: self.range_proof.clone(), - } - } -} diff --git a/base_layer/core/src/validation/block_validators.rs b/base_layer/core/src/validation/block_validators.rs index 0c4ee76bfd..3908c28a5f 100644 --- a/base_layer/core/src/validation/block_validators.rs +++ b/base_layer/core/src/validation/block_validators.rs @@ -1,3 +1,13 @@ +use std::marker::PhantomData; + +use log::*; +use tari_crypto::{ + commitment::HomomorphicCommitmentFactory, + tari_utilities::{hash::Hashable, hex::Hex}, +}; + +use tari_common_types::chain_metadata::ChainMetadata; + // Copyright 2019. The Tari Project // // Redistribution and use in source and binary forms, with or without modification, are permitted provided that the @@ -27,7 +37,7 @@ use crate::{ transactions::{ aggregated_body::AggregateBody, transaction::{KernelFeatures, OutputFlags, TransactionError}, - types::CryptoFactories, + CryptoFactories, }, validation::{ helpers::{check_accounting_balance, check_block_weight, check_coinbase_output, is_all_unique_and_sorted}, @@ -37,13 +47,6 @@ use crate::{ ValidationError, }, }; -use log::*; -use std::marker::PhantomData; -use tari_common_types::chain_metadata::ChainMetadata; -use tari_crypto::{ - commitment::HomomorphicCommitmentFactory, - tari_utilities::{hash::Hashable, hex::Hex}, -}; pub const LOG_TARGET: &str = "c::val::block_validators"; @@ -51,12 +54,17 @@ pub const LOG_TARGET: &str = "c::val::block_validators"; #[derive(Clone)] pub struct OrphanBlockValidator { rules: ConsensusManager, + bypass_range_proof_verification: bool, factories: CryptoFactories, } impl OrphanBlockValidator { - pub fn new(rules: ConsensusManager, factories: CryptoFactories) -> Self { - Self { rules, factories } + pub fn new(rules: ConsensusManager, bypass_range_proof_verification: bool, factories: CryptoFactories) -> Self { + Self { + rules, + bypass_range_proof_verification, + factories, + } } } @@ -101,7 +109,12 @@ impl OrphanValidation for OrphanBlockValidator { trace!(target: LOG_TARGET, "SV - Output constraints are ok for {} ", &block_id); check_coinbase_output(block, &self.rules, &self.factories)?; trace!(target: LOG_TARGET, "SV - Coinbase output is ok for {} ", &block_id); - check_accounting_balance(block, &self.rules, &self.factories)?; + check_accounting_balance( + block, + &self.rules, + self.bypass_range_proof_verification, + &self.factories, + )?; trace!(target: LOG_TARGET, "SV - accounting balance correct for {}", &block_id); debug!( target: LOG_TARGET, @@ -311,15 +324,17 @@ fn check_mmr_roots(block: &Block, db: &B) -> Result<(), Va /// the block body using the header. It is assumed that the `BlockHeader` has already been validated. pub struct BlockValidator { rules: ConsensusManager, + bypass_range_proof_verification: bool, factories: CryptoFactories, phantom_data: PhantomData, } impl BlockValidator { - pub fn new(rules: ConsensusManager, factories: CryptoFactories) -> Self { + pub fn new(rules: ConsensusManager, bypass_range_proof_verification: bool, factories: CryptoFactories) -> Self { Self { rules, factories, + bypass_range_proof_verification, phantom_data: Default::default(), } } @@ -428,7 +443,12 @@ impl CandidateBlockBodyValidation for BlockValidator self.check_inputs(block)?; self.check_outputs(block)?; - check_accounting_balance(block, &self.rules, &self.factories)?; + check_accounting_balance( + block, + &self.rules, + self.bypass_range_proof_verification, + &self.factories, + )?; trace!(target: LOG_TARGET, "SV - accounting balance correct for {}", &block_id); debug!( target: LOG_TARGET, diff --git a/base_layer/core/src/validation/chain_balance.rs b/base_layer/core/src/validation/chain_balance.rs index f620bcfdca..6efcc3cb6e 100644 --- a/base_layer/core/src/validation/chain_balance.rs +++ b/base_layer/core/src/validation/chain_balance.rs @@ -20,18 +20,18 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use std::marker::PhantomData; + +use log::*; +use tari_crypto::commitment::HomomorphicCommitmentFactory; + use crate::{ chain_storage::BlockchainBackend, consensus::ConsensusManager, - transactions::{ - tari_amount::MicroTari, - types::{Commitment, CryptoFactories, PrivateKey}, - }, + transactions::{tari_amount::MicroTari, CryptoFactories}, validation::{FinalHorizonStateValidation, ValidationError}, }; -use log::*; -use std::marker::PhantomData; -use tari_crypto::commitment::HomomorphicCommitmentFactory; +use tari_common_types::types::{Commitment, PrivateKey}; const LOG_TARGET: &str = "c::bn::state_machine_service::states::horizon_state_sync::chain_balance"; diff --git a/base_layer/core/src/validation/error.rs b/base_layer/core/src/validation/error.rs index 1a9ee2fab6..e651079fb8 100644 --- a/base_layer/core/src/validation/error.rs +++ b/base_layer/core/src/validation/error.rs @@ -24,8 +24,9 @@ use crate::{ blocks::{block_header::BlockHeaderValidationError, BlockValidationError}, chain_storage::ChainStorageError, proof_of_work::{monero_rx::MergeMineError, PowError}, - transactions::{transaction::TransactionError, types::HashOutput}, + transactions::transaction::TransactionError, }; +use tari_common_types::types::HashOutput; use thiserror::Error; #[derive(Debug, Error)] diff --git a/base_layer/core/src/validation/helpers.rs b/base_layer/core/src/validation/helpers.rs index e6405a04d2..f0d1947b1e 100644 --- a/base_layer/core/src/validation/helpers.rs +++ b/base_layer/core/src/validation/helpers.rs @@ -20,6 +20,9 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use log::*; +use tari_crypto::tari_utilities::{epoch_time::EpochTime, hash::Hashable, hex::Hex}; + use crate::{ blocks::{ block_header::{BlockHeader, BlockHeaderValidationError}, @@ -38,11 +41,9 @@ use crate::{ PowAlgorithm, PowError, }, - transactions::types::CryptoFactories, + transactions::CryptoFactories, validation::ValidationError, }; -use log::*; -use tari_crypto::tari_utilities::{epoch_time::EpochTime, hash::Hashable, hex::Hex}; pub const LOG_TARGET: &str = "c::val::helpers"; @@ -199,6 +200,7 @@ pub fn check_block_weight(block: &Block, consensus_constants: &ConsensusConstant pub fn check_accounting_balance( block: &Block, rules: &ConsensusManager, + bypass_range_proof_verification: bool, factories: &CryptoFactories, ) -> Result<(), ValidationError> { if block.header.height == 0 { @@ -210,7 +212,13 @@ pub fn check_accounting_balance( let total_coinbase = rules.calculate_coinbase_and_fees(block); block .body - .validate_internal_consistency(&offset, &script_offset, total_coinbase, factories) + .validate_internal_consistency( + &offset, + &script_offset, + bypass_range_proof_verification, + total_coinbase, + factories, + ) .map_err(|err| { warn!( target: LOG_TARGET, diff --git a/base_layer/core/src/validation/mocks.rs b/base_layer/core/src/validation/mocks.rs index 03c8951d3f..c2b3ffe5cf 100644 --- a/base_layer/core/src/validation/mocks.rs +++ b/base_layer/core/src/validation/mocks.rs @@ -24,7 +24,7 @@ use crate::{ blocks::{Block, BlockHeader}, chain_storage::{BlockchainBackend, ChainBlock}, proof_of_work::{sha3_difficulty, AchievedTargetDifficulty, Difficulty, PowAlgorithm}, - transactions::{transaction::Transaction, types::Commitment}, + transactions::transaction::Transaction, validation::{ error::ValidationError, CandidateBlockBodyValidation, @@ -40,7 +40,7 @@ use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, }; -use tari_common_types::chain_metadata::ChainMetadata; +use tari_common_types::{chain_metadata::ChainMetadata, types::Commitment}; #[derive(Clone)] pub struct MockValidator { diff --git a/base_layer/core/src/validation/test.rs b/base_layer/core/src/validation/test.rs index a5998fa3e2..e3a50914b2 100644 --- a/base_layer/core/src/validation/test.rs +++ b/base_layer/core/src/validation/test.rs @@ -20,6 +20,12 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use std::sync::Arc; + +use tari_crypto::{commitment::HomomorphicCommitment, script}; + +use tari_common::configuration::Network; + use crate::{ blocks::BlockHeader, chain_storage::{BlockHeaderAccumulatedData, ChainBlock, ChainHeader, DbTransaction}, @@ -31,13 +37,11 @@ use crate::{ helpers::{create_random_signature_from_s_key, create_utxo}, tari_amount::{uT, MicroTari}, transaction::{KernelBuilder, KernelFeatures, OutputFeatures, TransactionKernel}, - types::{Commitment, CryptoFactories}, + CryptoFactories, }, validation::{header_iter::HeaderIter, ChainBalanceValidator, FinalHorizonStateValidation}, }; -use std::sync::Arc; -use tari_common::configuration::Network; -use tari_crypto::{commitment::HomomorphicCommitment, script}; +use tari_common_types::types::Commitment; #[test] fn header_iter_empty_and_invalid_height() { diff --git a/base_layer/core/src/validation/traits.rs b/base_layer/core/src/validation/traits.rs index e7fabb449b..cc8c287b0a 100644 --- a/base_layer/core/src/validation/traits.rs +++ b/base_layer/core/src/validation/traits.rs @@ -24,10 +24,10 @@ use crate::{ blocks::{Block, BlockHeader}, chain_storage::{BlockchainBackend, ChainBlock}, proof_of_work::AchievedTargetDifficulty, - transactions::{transaction::Transaction, types::Commitment}, + transactions::transaction::Transaction, validation::{error::ValidationError, DifficultyCalculator}, }; -use tari_common_types::chain_metadata::ChainMetadata; +use tari_common_types::{chain_metadata::ChainMetadata, types::Commitment}; /// A validator that determines if a block body is valid, assuming that the header has already been /// validated diff --git a/base_layer/core/src/validation/transaction_validators.rs b/base_layer/core/src/validation/transaction_validators.rs index 59ff3bfc41..4f136aeea5 100644 --- a/base_layer/core/src/validation/transaction_validators.rs +++ b/base_layer/core/src/validation/transaction_validators.rs @@ -20,14 +20,15 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use log::*; + use crate::{ blocks::BlockValidationError, chain_storage::{BlockchainBackend, BlockchainDatabase, MmrTree}, crypto::tari_utilities::Hashable, - transactions::{transaction::Transaction, types::CryptoFactories}, + transactions::{transaction::Transaction, CryptoFactories}, validation::{MempoolTransactionValidation, ValidationError}, }; -use log::*; pub const LOG_TARGET: &str = "c::val::transaction_validators"; @@ -40,17 +41,21 @@ pub const LOG_TARGET: &str = "c::val::transaction_validators"; /// This function does NOT check that inputs come from the UTXO set pub struct TxInternalConsistencyValidator { factories: CryptoFactories, + bypass_range_proof_verification: bool, } impl TxInternalConsistencyValidator { - pub fn new(factories: CryptoFactories) -> Self { - Self { factories } + pub fn new(factories: CryptoFactories, bypass_range_proof_verification: bool) -> Self { + Self { + factories, + bypass_range_proof_verification, + } } } impl MempoolTransactionValidation for TxInternalConsistencyValidator { fn validate(&self, tx: &Transaction) -> Result<(), ValidationError> { - tx.validate_internal_consistency(&self.factories, None) + tx.validate_internal_consistency(self.bypass_range_proof_verification, &self.factories, None) .map_err(ValidationError::TransactionError)?; Ok(()) } diff --git a/base_layer/core/tests/async_db.rs b/base_layer/core/tests/async_db.rs index afedd9c7d9..a8e7902ed5 100644 --- a/base_layer/core/tests/async_db.rs +++ b/base_layer/core/tests/async_db.rs @@ -21,16 +21,17 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // -#[allow(dead_code)] -mod helpers; +use std::ops::Deref; + +use tari_crypto::{commitment::HomomorphicCommitmentFactory, tari_utilities::Hashable}; use helpers::{ block_builders::chain_block_with_new_coinbase, database::create_orphan_block, sample_blockchains::{create_blockchain_db_no_cut_through, create_new_blockchain}, }; -use std::ops::Deref; use tari_common::configuration::Network; +use tari_common_types::types::CommitmentFactory; use tari_core::{ blocks::Block, chain_storage::{async_db::AsyncBlockchainDb, BlockAddResult, PrunedOutput}, @@ -38,13 +39,15 @@ use tari_core::{ helpers::schema_to_transaction, tari_amount::T, transaction::{TransactionOutput, UnblindedOutput}, - types::{CommitmentFactory, CryptoFactories}, + CryptoFactories, }, txn_schema, }; -use tari_crypto::{commitment::HomomorphicCommitmentFactory, tari_utilities::Hashable}; use tari_test_utils::runtime::test_async; +#[allow(dead_code)] +mod helpers; + /// Finds the UTXO in a block corresponding to the unblinded output. We have to search for outputs because UTXOs get /// sorted in blocks, and so the order they were inserted in can change. fn find_utxo(output: &UnblindedOutput, block: &Block, factory: &CommitmentFactory) -> Option { diff --git a/base_layer/core/tests/base_node_rpc.rs b/base_layer/core/tests/base_node_rpc.rs index 9b96512d47..e8b627dd57 100644 --- a/base_layer/core/tests/base_node_rpc.rs +++ b/base_layer/core/tests/base_node_rpc.rs @@ -42,13 +42,10 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -mod helpers; - -use crate::helpers::{ - block_builders::{chain_block, create_genesis_block_with_coinbase_value}, - nodes::{BaseNodeBuilder, NodeInterfaces}, -}; use std::convert::TryFrom; + +use tempfile::{tempdir, TempDir}; + use tari_common::configuration::Network; use tari_comms::protocol::rpc::mock::RpcRequestMock; use tari_core::{ @@ -76,27 +73,30 @@ use tari_core::{ helpers::schema_to_transaction, tari_amount::{uT, T}, transaction::{TransactionOutput, UnblindedOutput}, - types::CryptoFactories, + CryptoFactories, }, txn_schema, }; -use tempfile::{tempdir, TempDir}; -use tokio::runtime::Runtime; -fn setup() -> ( +use crate::helpers::{ + block_builders::{chain_block, create_genesis_block_with_coinbase_value}, + nodes::{BaseNodeBuilder, NodeInterfaces}, +}; + +mod helpers; + +async fn setup() -> ( BaseNodeWalletRpcService, NodeInterfaces, RpcRequestMock, ConsensusManager, ChainBlock, UnblindedOutput, - Runtime, TempDir, ) { let network = NetworkConsensus::from(Network::LocalNet); let consensus_constants = network.create_consensus_constants(); let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let (block0, utxo0) = @@ -107,13 +107,14 @@ fn setup() -> ( let (mut base_node, _consensus_manager) = BaseNodeBuilder::new(network) .with_consensus_manager(consensus_manager.clone()) - .start(&mut runtime, temp_dir.path().to_str().unwrap()); + .start(temp_dir.path().to_str().unwrap()) + .await; base_node.mock_base_node_state_machine.publish_status(StatusInfo { bootstrapped: true, state_info: StateInfo::Listening(ListeningInfo::new(true)), }); - let request_mock = runtime.enter(|| RpcRequestMock::new(base_node.comms.peer_manager())); + let request_mock = RpcRequestMock::new(base_node.comms.peer_manager()); let service = BaseNodeWalletRpcService::new( base_node.blockchain_db.clone().into(), base_node.mempool_handle.clone(), @@ -126,16 +127,15 @@ fn setup() -> ( consensus_manager, block0, utxo0, - runtime, temp_dir, ) } -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn test_base_node_wallet_rpc() { +async fn test_base_node_wallet_rpc() { // Testing the submit_transaction() and transaction_query() rpc calls - let (service, mut base_node, request_mock, consensus_manager, block0, utxo0, mut runtime, _temp_dir) = setup(); + let (service, mut base_node, request_mock, consensus_manager, block0, utxo0, _temp_dir) = setup().await; let (txs1, utxos1) = schema_to_transaction(&[txn_schema!(from: vec![utxo0.clone()], to: vec![1 * T, 1 * T])]); let tx1 = (*txs1[0]).clone(); @@ -151,8 +151,8 @@ fn test_base_node_wallet_rpc() { // Query Tx1 let msg = SignatureProto::from(tx1_sig.clone()); let req = request_mock.request_with_context(Default::default(), msg); - let resp = - TxQueryResponse::try_from(runtime.block_on(service.transaction_query(req)).unwrap().into_message()).unwrap(); + let resp = service.transaction_query(req).await.unwrap().into_message(); + let resp = TxQueryResponse::try_from(resp).unwrap(); assert_eq!(resp.confirmations, 0); assert_eq!(resp.block_hash, None); @@ -162,13 +162,7 @@ fn test_base_node_wallet_rpc() { let msg = TransactionProto::from(tx2.clone()); let req = request_mock.request_with_context(Default::default(), msg); - let resp = TxSubmissionResponse::try_from( - runtime - .block_on(service.submit_transaction(req)) - .unwrap() - .into_message(), - ) - .unwrap(); + let resp = TxSubmissionResponse::try_from(service.submit_transaction(req).await.unwrap().into_message()).unwrap(); assert!(!resp.accepted); assert_eq!(resp.rejection_reason, TxSubmissionRejectionReason::Orphan); @@ -176,8 +170,7 @@ fn test_base_node_wallet_rpc() { // Query Tx2 to confirm it wasn't accepted let msg = SignatureProto::from(tx2_sig.clone()); let req = request_mock.request_with_context(Default::default(), msg); - let resp = - TxQueryResponse::try_from(runtime.block_on(service.transaction_query(req)).unwrap().into_message()).unwrap(); + let resp = TxQueryResponse::try_from(service.transaction_query(req).await.unwrap().into_message()).unwrap(); assert_eq!(resp.confirmations, 0); assert_eq!(resp.block_hash, None); @@ -189,24 +182,22 @@ fn test_base_node_wallet_rpc() { .prepare_block_merkle_roots(chain_block(&block0.block(), vec![tx1.clone()], &consensus_manager)) .unwrap(); - assert!(runtime - .block_on(base_node.local_nci.submit_block(block1.clone(), Broadcast::from(true))) - .is_ok()); + base_node + .local_nci + .submit_block(block1.clone(), Broadcast::from(true)) + .await + .unwrap(); // Check that subitting Tx2 will now be accepted let msg = TransactionProto::from(tx2); let req = request_mock.request_with_context(Default::default(), msg); - let resp = runtime - .block_on(service.submit_transaction(req)) - .unwrap() - .into_message(); + let resp = service.submit_transaction(req).await.unwrap().into_message(); assert!(resp.accepted); // Query Tx2 which should now be in the mempool let msg = SignatureProto::from(tx2_sig.clone()); let req = request_mock.request_with_context(Default::default(), msg); - let resp = - TxQueryResponse::try_from(runtime.block_on(service.transaction_query(req)).unwrap().into_message()).unwrap(); + let resp = TxQueryResponse::try_from(service.transaction_query(req).await.unwrap().into_message()).unwrap(); assert_eq!(resp.confirmations, 0); assert_eq!(resp.block_hash, None); @@ -215,13 +206,7 @@ fn test_base_node_wallet_rpc() { // Now if we submit Tx1 is should return as rejected as AlreadyMined as Tx1's kernel is present let msg = TransactionProto::from(tx1); let req = request_mock.request_with_context(Default::default(), msg); - let resp = TxSubmissionResponse::try_from( - runtime - .block_on(service.submit_transaction(req)) - .unwrap() - .into_message(), - ) - .unwrap(); + let resp = TxSubmissionResponse::try_from(service.submit_transaction(req).await.unwrap().into_message()).unwrap(); assert!(!resp.accepted); assert_eq!(resp.rejection_reason, TxSubmissionRejectionReason::AlreadyMined); @@ -233,13 +218,7 @@ fn test_base_node_wallet_rpc() { // Now if we submit Tx1 is should return as rejected as AlreadyMined let msg = TransactionProto::from(tx1b); let req = request_mock.request_with_context(Default::default(), msg); - let resp = TxSubmissionResponse::try_from( - runtime - .block_on(service.submit_transaction(req)) - .unwrap() - .into_message(), - ) - .unwrap(); + let resp = TxSubmissionResponse::try_from(service.submit_transaction(req).await.unwrap().into_message()).unwrap(); assert!(!resp.accepted); assert_eq!(resp.rejection_reason, TxSubmissionRejectionReason::DoubleSpend); @@ -253,15 +232,16 @@ fn test_base_node_wallet_rpc() { block2.header.output_mmr_size += 1; block2.header.kernel_mmr_size += 1; - runtime - .block_on(base_node.local_nci.submit_block(block2, Broadcast::from(true))) + base_node + .local_nci + .submit_block(block2, Broadcast::from(true)) + .await .unwrap(); // Query Tx1 which should be in block 1 with 1 confirmation let msg = SignatureProto::from(tx1_sig.clone()); let req = request_mock.request_with_context(Default::default(), msg); - let resp = - TxQueryResponse::try_from(runtime.block_on(service.transaction_query(req)).unwrap().into_message()).unwrap(); + let resp = TxQueryResponse::try_from(service.transaction_query(req).await.unwrap().into_message()).unwrap(); assert_eq!(resp.confirmations, 1); assert_eq!(resp.block_hash, Some(block1.hash())); @@ -271,10 +251,7 @@ fn test_base_node_wallet_rpc() { sigs: vec![SignatureProto::from(tx1_sig.clone()), SignatureProto::from(tx2_sig)], }; let req = request_mock.request_with_context(Default::default(), msg); - let response = runtime - .block_on(service.transaction_batch_query(req)) - .unwrap() - .into_message(); + let response = service.transaction_batch_query(req).await.unwrap().into_message(); for r in response.responses { let response = TxQueryBatchResponse::try_from(r).unwrap(); @@ -299,10 +276,7 @@ fn test_base_node_wallet_rpc() { let req = request_mock.request_with_context(Default::default(), msg); - let response = runtime - .block_on(service.fetch_matching_utxos(req)) - .unwrap() - .into_message(); + let response = service.fetch_matching_utxos(req).await.unwrap().into_message(); assert_eq!(response.outputs.len(), utxos1.len()); for output_proto in response.outputs.iter() { diff --git a/base_layer/core/tests/block_validation.rs b/base_layer/core/tests/block_validation.rs index 6961e48dfd..1a91df6012 100644 --- a/base_layer/core/tests/block_validation.rs +++ b/base_layer/core/tests/block_validation.rs @@ -20,9 +20,11 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::helpers::{block_builders::chain_block_with_new_coinbase, test_blockchain::TestBlockchain}; -use monero::blockdata::block::Block as MoneroBlock; use std::sync::Arc; + +use monero::blockdata::block::Block as MoneroBlock; +use tari_crypto::inputs; + use tari_common::configuration::Network; use tari_core::{ blocks::{Block, BlockHeaderValidationError, BlockValidationError}, @@ -38,7 +40,7 @@ use tari_core::{ transactions::{ helpers::{schema_to_transaction, TestParams, UtxoTestParams}, tari_amount::T, - types::CryptoFactories, + CryptoFactories, }, txn_schema, validation::{ @@ -50,7 +52,8 @@ use tari_core::{ ValidationError, }, }; -use tari_crypto::inputs; + +use crate::helpers::{block_builders::chain_block_with_new_coinbase, test_blockchain::TestBlockchain}; mod helpers; @@ -63,7 +66,7 @@ fn test_genesis_block() { let validators = Validators::new( BodyOnlyValidator::default(), HeaderValidator::new(rules.clone()), - OrphanBlockValidator::new(rules.clone(), factories), + OrphanBlockValidator::new(rules.clone(), false, factories), ); let db = BlockchainDatabase::new( backend, @@ -216,7 +219,7 @@ fn inputs_are_not_malleable() { input_mut.input_data = malicious_input.input_data; input_mut.script_signature = malicious_input.script_signature; - let validator = BlockValidator::new(blockchain.consensus_manager().clone(), CryptoFactories::default()); + let validator = BlockValidator::new(blockchain.consensus_manager().clone(), true, CryptoFactories::default()); let err = validator .validate_body(&block, &*blockchain.store().db_read_access().unwrap()) .unwrap_err(); diff --git a/base_layer/core/tests/chain_storage_tests/chain_storage.rs b/base_layer/core/tests/chain_storage_tests/chain_storage.rs index b716858d4c..98a3aac848 100644 --- a/base_layer/core/tests/chain_storage_tests/chain_storage.rs +++ b/base_layer/core/tests/chain_storage_tests/chain_storage.rs @@ -20,24 +20,9 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -// use crate::helpers::database::create_test_db; -// use crate::helpers::database::create_store; -use crate::helpers::{ - block_builders::{ - append_block, - chain_block, - create_chain_header, - create_genesis_block, - find_header_with_achieved_difficulty, - generate_new_block, - generate_new_block_with_achieved_difficulty, - generate_new_block_with_coinbase, - }, - database::create_orphan_block, - sample_blockchains::{create_new_blockchain, create_new_blockchain_lmdb}, - test_blockchain::TestBlockchain, -}; use rand::{rngs::OsRng, RngCore}; +use tari_crypto::{script::StackItem, tari_utilities::Hashable}; + use tari_common::configuration::Network; use tari_common_types::types::BlockHash; use tari_core::{ @@ -63,16 +48,33 @@ use tari_core::{ transactions::{ helpers::{schema_to_transaction, spend_utxos}, tari_amount::{uT, MicroTari, T}, - types::CryptoFactories, + CryptoFactories, }, tx, txn_schema, validation::{mocks::MockValidator, DifficultyCalculator, ValidationError}, }; -use tari_crypto::{script::StackItem, tari_utilities::Hashable}; use tari_storage::lmdb_store::LMDBConfig; use tari_test_utils::{paths::create_temporary_data_path, unpack_enum}; +// use crate::helpers::database::create_test_db; +// use crate::helpers::database::create_store; +use crate::helpers::{ + block_builders::{ + append_block, + chain_block, + create_chain_header, + create_genesis_block, + find_header_with_achieved_difficulty, + generate_new_block, + generate_new_block_with_achieved_difficulty, + generate_new_block_with_coinbase, + }, + database::create_orphan_block, + sample_blockchains::{create_new_blockchain, create_new_blockchain_lmdb}, + test_blockchain::TestBlockchain, +}; + #[test] fn fetch_nonexistent_header() { let network = Network::LocalNet; diff --git a/base_layer/core/tests/helpers/block_builders.rs b/base_layer/core/tests/helpers/block_builders.rs index 6ff5c2102a..12d4659d62 100644 --- a/base_layer/core/tests/helpers/block_builders.rs +++ b/base_layer/core/tests/helpers/block_builders.rs @@ -20,10 +20,18 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use std::{iter::repeat_with, sync::Arc}; + use croaring::Bitmap; use rand::{rngs::OsRng, RngCore}; -use std::{iter::repeat_with, sync::Arc}; +use tari_crypto::{ + keys::PublicKey as PublicKeyTrait, + script, + tari_utilities::{hash::Hashable, hex::Hex}, +}; + use tari_common::configuration::Network; +use tari_common_types::types::{Commitment, HashDigest, HashOutput, PublicKey}; use tari_core::{ blocks::{Block, BlockHeader, NewBlockTemplate}, chain_storage::{ @@ -57,14 +65,9 @@ use tari_core::{ TransactionOutput, UnblindedOutput, }, - types::{Commitment, CryptoFactories, HashDigest, HashOutput, PublicKey}, + CryptoFactories, }, }; -use tari_crypto::{ - keys::PublicKey as PublicKeyTrait, - script, - tari_utilities::{hash::Hashable, hex::Hex}, -}; use tari_mmr::MutableMmr; pub fn create_coinbase( diff --git a/base_layer/core/tests/helpers/database.rs b/base_layer/core/tests/helpers/database.rs index 0bafeed84b..e1132445a3 100644 --- a/base_layer/core/tests/helpers/database.rs +++ b/base_layer/core/tests/helpers/database.rs @@ -20,13 +20,14 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::helpers::block_builders::create_coinbase; use tari_core::{ blocks::{Block, BlockHeader, NewBlockTemplate}, consensus::{emission::Emission, ConsensusManager}, - transactions::{tari_amount::MicroTari, transaction::Transaction, types::CryptoFactories}, + transactions::{tari_amount::MicroTari, transaction::Transaction, CryptoFactories}, }; +use crate::helpers::block_builders::create_coinbase; + // use tari_test_utils::paths::create_temporary_data_path; /// Create a partially constructed block using the provided set of transactions diff --git a/base_layer/core/tests/helpers/event_stream.rs b/base_layer/core/tests/helpers/event_stream.rs index b79b494900..5485467f4c 100644 --- a/base_layer/core/tests/helpers/event_stream.rs +++ b/base_layer/core/tests/helpers/event_stream.rs @@ -20,16 +20,17 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::{future, future::Either, FutureExt, Stream, StreamExt}; use std::time::Duration; +use tokio::{sync::broadcast, time}; #[allow(dead_code)] -pub async fn event_stream_next(stream: &mut TStream, timeout: Duration) -> Option -where TStream: Stream + Unpin { - let either = future::select(stream.next(), tokio::time::delay_for(timeout).fuse()).await; - - match either { - Either::Left((v, _)) => v, - Either::Right(_) => None, +pub async fn event_stream_next(stream: &mut broadcast::Receiver, timeout: Duration) -> Option { + tokio::select! { + item = stream.recv() => match item { + Ok(item) => Some(item), + Err(broadcast::error::RecvError::Closed) => None, + Err(broadcast::error::RecvError::Lagged(n)) => panic!("Lagged events channel {}", n), + }, + _ = time::sleep(timeout) => None } } diff --git a/base_layer/core/tests/helpers/mock_state_machine.rs b/base_layer/core/tests/helpers/mock_state_machine.rs index 7d49f93e85..0d4b6ce512 100644 --- a/base_layer/core/tests/helpers/mock_state_machine.rs +++ b/base_layer/core/tests/helpers/mock_state_machine.rs @@ -40,7 +40,7 @@ impl MockBaseNodeStateMachine { } pub fn publish_status(&mut self, status: StatusInfo) { - let _ = self.status_sender.broadcast(status); + let _ = self.status_sender.send(status); } pub fn get_initializer(&self) -> MockBaseNodeStateMachineInitializer { diff --git a/base_layer/core/tests/helpers/nodes.rs b/base_layer/core/tests/helpers/nodes.rs index ffe69c8034..06f2d5f8e0 100644 --- a/base_layer/core/tests/helpers/nodes.rs +++ b/base_layer/core/tests/helpers/nodes.rs @@ -21,9 +21,8 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::helpers::mock_state_machine::MockBaseNodeStateMachine; -use futures::Sink; use rand::rngs::OsRng; -use std::{error::Error, path::Path, sync::Arc, time::Duration}; +use std::{path::Path, sync::Arc, time::Duration}; use tari_common::configuration::Network; use tari_comms::{ peer_manager::{NodeIdentity, PeerFeatures}, @@ -60,13 +59,12 @@ use tari_core::{ }, }; use tari_p2p::{ - comms_connector::{pubsub_connector, InboundDomainConnector, PeerMessage}, + comms_connector::{pubsub_connector, InboundDomainConnector}, initialization::initialize_local_test_comms, services::liveness::{LivenessConfig, LivenessHandle, LivenessInitializer}, }; use tari_service_framework::{RegisterHandle, StackBuilder}; use tari_shutdown::Shutdown; -use tokio::runtime::Runtime; /// The NodeInterfaces is used as a container for providing access to all the services and interfaces of a single node. pub struct NodeInterfaces { @@ -91,7 +89,7 @@ pub struct NodeInterfaces { #[allow(dead_code)] impl NodeInterfaces { pub async fn shutdown(mut self) { - self.shutdown.trigger().unwrap(); + self.shutdown.trigger(); self.comms.wait_until_shutdown().await; } } @@ -182,7 +180,7 @@ impl BaseNodeBuilder { /// Build the test base node and start its services. #[allow(clippy::redundant_closure)] - pub fn start(self, runtime: &mut Runtime, data_path: &str) -> (NodeInterfaces, ConsensusManager) { + pub async fn start(self, data_path: &str) -> (NodeInterfaces, ConsensusManager) { let validators = self.validators.unwrap_or_else(|| { Validators::new( MockValidator::new(true), @@ -199,7 +197,6 @@ impl BaseNodeBuilder { let mempool = Mempool::new(self.mempool_config.unwrap_or_default(), Arc::new(mempool_validator)); let node_identity = self.node_identity.unwrap_or_else(|| random_node_identity()); let node_interfaces = setup_base_node_services( - runtime, node_identity, self.peers.unwrap_or_default(), blockchain_db, @@ -209,17 +206,19 @@ impl BaseNodeBuilder { self.mempool_service_config.unwrap_or_default(), self.liveness_service_config.unwrap_or_default(), data_path, - ); + ) + .await; (node_interfaces, consensus_manager) } } -#[allow(dead_code)] -pub fn wait_until_online(runtime: &mut Runtime, nodes: &[&NodeInterfaces]) { +pub async fn wait_until_online(nodes: &[&NodeInterfaces]) { for node in nodes { - runtime - .block_on(node.comms.connectivity().wait_for_connectivity(Duration::from_secs(10))) + node.comms + .connectivity() + .wait_for_connectivity(Duration::from_secs(10)) + .await .map_err(|err| format!("Node '{}' failed to go online {:?}", node.node_identity.node_id(), err)) .unwrap(); } @@ -227,10 +226,7 @@ pub fn wait_until_online(runtime: &mut Runtime, nodes: &[&NodeInterfaces]) { // Creates a network with two Base Nodes where each node in the network knows the other nodes in the network. #[allow(dead_code)] -pub fn create_network_with_2_base_nodes( - runtime: &mut Runtime, - data_path: &str, -) -> (NodeInterfaces, NodeInterfaces, ConsensusManager) { +pub async fn create_network_with_2_base_nodes(data_path: &str) -> (NodeInterfaces, NodeInterfaces, ConsensusManager) { let alice_node_identity = random_node_identity(); let bob_node_identity = random_node_identity(); @@ -238,22 +234,23 @@ pub fn create_network_with_2_base_nodes( let (alice_node, consensus_manager) = BaseNodeBuilder::new(network.into()) .with_node_identity(alice_node_identity.clone()) .with_peers(vec![bob_node_identity.clone()]) - .start(runtime, data_path); + .start(data_path) + .await; let (bob_node, consensus_manager) = BaseNodeBuilder::new(network.into()) .with_node_identity(bob_node_identity) .with_peers(vec![alice_node_identity]) .with_consensus_manager(consensus_manager) - .start(runtime, data_path); + .start(data_path) + .await; - wait_until_online(runtime, &[&alice_node, &bob_node]); + wait_until_online(&[&alice_node, &bob_node]).await; (alice_node, bob_node, consensus_manager) } // Creates a network with two Base Nodes where each node in the network knows the other nodes in the network. #[allow(dead_code)] -pub fn create_network_with_2_base_nodes_with_config>( - runtime: &mut Runtime, +pub async fn create_network_with_2_base_nodes_with_config>( base_node_service_config: BaseNodeServiceConfig, mempool_service_config: MempoolServiceConfig, liveness_service_config: LivenessConfig, @@ -269,7 +266,8 @@ pub fn create_network_with_2_base_nodes_with_config>( .with_mempool_service_config(mempool_service_config) .with_liveness_service_config(liveness_service_config.clone()) .with_consensus_manager(consensus_manager) - .start(runtime, data_path.as_ref().join("alice").as_os_str().to_str().unwrap()); + .start(data_path.as_ref().join("alice").as_os_str().to_str().unwrap()) + .await; let (bob_node, consensus_manager) = BaseNodeBuilder::new(network.into()) .with_node_identity(bob_node_identity) .with_peers(vec![alice_node_identity]) @@ -277,35 +275,34 @@ pub fn create_network_with_2_base_nodes_with_config>( .with_mempool_service_config(mempool_service_config) .with_liveness_service_config(liveness_service_config) .with_consensus_manager(consensus_manager) - .start(runtime, data_path.as_ref().join("bob").as_os_str().to_str().unwrap()); + .start(data_path.as_ref().join("bob").as_os_str().to_str().unwrap()) + .await; - wait_until_online(runtime, &[&alice_node, &bob_node]); + wait_until_online(&[&alice_node, &bob_node]).await; (alice_node, bob_node, consensus_manager) } // Creates a network with three Base Nodes where each node in the network knows the other nodes in the network. #[allow(dead_code)] -pub fn create_network_with_3_base_nodes( - runtime: &mut Runtime, +pub async fn create_network_with_3_base_nodes( data_path: &str, ) -> (NodeInterfaces, NodeInterfaces, NodeInterfaces, ConsensusManager) { let network = Network::LocalNet; let consensus_manager = ConsensusManagerBuilder::new(network).build(); create_network_with_3_base_nodes_with_config( - runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager, data_path, ) + .await } // Creates a network with three Base Nodes where each node in the network knows the other nodes in the network. #[allow(dead_code)] -pub fn create_network_with_3_base_nodes_with_config>( - runtime: &mut Runtime, +pub async fn create_network_with_3_base_nodes_with_config>( base_node_service_config: BaseNodeServiceConfig, mempool_service_config: MempoolServiceConfig, liveness_service_config: LivenessConfig, @@ -329,7 +326,8 @@ pub fn create_network_with_3_base_nodes_with_config>( .with_mempool_service_config(mempool_service_config) .with_liveness_service_config(liveness_service_config.clone()) .with_consensus_manager(consensus_manager) - .start(runtime, data_path.as_ref().join("carol").as_os_str().to_str().unwrap()); + .start(data_path.as_ref().join("carol").as_os_str().to_str().unwrap()) + .await; let (bob_node, consensus_manager) = BaseNodeBuilder::new(network.into()) .with_node_identity(bob_node_identity.clone()) .with_peers(vec![carol_node_identity.clone()]) @@ -337,7 +335,8 @@ pub fn create_network_with_3_base_nodes_with_config>( .with_mempool_service_config(mempool_service_config) .with_liveness_service_config(liveness_service_config.clone()) .with_consensus_manager(consensus_manager) - .start(runtime, data_path.as_ref().join("bob").as_os_str().to_str().unwrap()); + .start(data_path.as_ref().join("bob").as_os_str().to_str().unwrap()) + .await; let (alice_node, consensus_manager) = BaseNodeBuilder::new(network.into()) .with_node_identity(alice_node_identity) .with_peers(vec![bob_node_identity, carol_node_identity]) @@ -345,9 +344,10 @@ pub fn create_network_with_3_base_nodes_with_config>( .with_mempool_service_config(mempool_service_config) .with_liveness_service_config(liveness_service_config) .with_consensus_manager(consensus_manager) - .start(runtime, data_path.as_ref().join("alice").as_os_str().to_str().unwrap()); + .start(data_path.as_ref().join("alice").as_os_str().to_str().unwrap()) + .await; - wait_until_online(runtime, &[&alice_node, &bob_node, &carol_node]); + wait_until_online(&[&alice_node, &bob_node, &carol_node]).await; (alice_node, bob_node, carol_node, consensus_manager) } @@ -365,16 +365,12 @@ pub fn random_node_identity() -> Arc { // Helper function for starting the comms stack. #[allow(dead_code)] -async fn setup_comms_services( +async fn setup_comms_services( node_identity: Arc, peers: Vec>, - publisher: InboundDomainConnector, + publisher: InboundDomainConnector, data_path: &str, -) -> (CommsNode, Dht, MessagingEventSender, Shutdown) -where - TSink: Sink> + Clone + Unpin + Send + Sync + 'static, - TSink::Error: Error + Send + Sync, -{ +) -> (CommsNode, Dht, MessagingEventSender, Shutdown) { let peers = peers.into_iter().map(|p| p.to_peer()).collect(); let shutdown = Shutdown::new(); let (comms, dht, messaging_events) = initialize_local_test_comms( @@ -393,8 +389,7 @@ where // Helper function for starting the services of the Base node. #[allow(clippy::too_many_arguments)] -fn setup_base_node_services( - runtime: &mut Runtime, +async fn setup_base_node_services( node_identity: Arc, peers: Vec>, blockchain_db: BlockchainDatabase, @@ -405,14 +400,14 @@ fn setup_base_node_services( liveness_service_config: LivenessConfig, data_path: &str, ) -> NodeInterfaces { - let (publisher, subscription_factory) = pubsub_connector(runtime.handle().clone(), 100, 20); + let (publisher, subscription_factory) = pubsub_connector(100, 20); let subscription_factory = Arc::new(subscription_factory); let (comms, dht, messaging_events, shutdown) = - runtime.block_on(setup_comms_services(node_identity.clone(), peers, publisher, data_path)); + setup_comms_services(node_identity.clone(), peers, publisher, data_path).await; let mock_state_machine = MockBaseNodeStateMachine::new(); - let fut = StackBuilder::new(shutdown.to_signal()) + let handles = StackBuilder::new(shutdown.to_signal()) .add_initializer(RegisterHandle::new(dht)) .add_initializer(RegisterHandle::new(comms.connectivity())) .add_initializer(LivenessInitializer::new( @@ -433,9 +428,9 @@ fn setup_base_node_services( )) .add_initializer(mock_state_machine.get_initializer()) .add_initializer(ChainMetadataServiceInitializer) - .build(); - - let handles = runtime.block_on(fut).expect("Service initialization failed"); + .build() + .await + .unwrap(); let outbound_nci = handles.expect_handle::(); let local_nci = handles.expect_handle::(); diff --git a/base_layer/core/tests/helpers/sample_blockchains.rs b/base_layer/core/tests/helpers/sample_blockchains.rs index ddc6398fed..108d77c248 100644 --- a/base_layer/core/tests/helpers/sample_blockchains.rs +++ b/base_layer/core/tests/helpers/sample_blockchains.rs @@ -21,8 +21,6 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // -use crate::helpers::block_builders::{create_genesis_block, generate_new_block}; - use tari_common::configuration::Network; use tari_core::{ chain_storage::{ @@ -38,12 +36,15 @@ use tari_core::{ transactions::{ tari_amount::{uT, T}, transaction::UnblindedOutput, - types::CryptoFactories, + CryptoFactories, }, txn_schema, validation::DifficultyCalculator, }; use tari_storage::lmdb_store::LMDBConfig; + +use crate::helpers::block_builders::{create_genesis_block, generate_new_block}; + // use crate::helpers::database::{TempDatabase, create_store_with_consensus}; static EMISSION: [u64; 2] = [10, 10]; diff --git a/base_layer/core/tests/helpers/test_blockchain.rs b/base_layer/core/tests/helpers/test_blockchain.rs index e961cb14de..43291d5f4e 100644 --- a/base_layer/core/tests/helpers/test_blockchain.rs +++ b/base_layer/core/tests/helpers/test_blockchain.rs @@ -21,24 +21,27 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // -use crate::helpers::{ - block_builders::{chain_block_with_new_coinbase, find_header_with_achieved_difficulty}, - block_proxy::BlockProxy, - sample_blockchains::create_new_blockchain, - test_block_builder::{TestBlockBuilder, TestBlockBuilderInner}, -}; +use std::{collections::HashMap, sync::Arc}; + use log::*; use rand::{rngs::OsRng, RngCore}; -use std::{collections::HashMap, sync::Arc}; +use tari_crypto::tari_utilities::Hashable; + use tari_common::configuration::Network; use tari_core::{ blocks::Block, chain_storage::{BlockAddResult, BlockchainDatabase, ChainStorageError}, consensus::ConsensusManager, test_helpers::blockchain::TempDatabase, - transactions::{transaction::UnblindedOutput, types::CryptoFactories}, + transactions::{transaction::UnblindedOutput, CryptoFactories}, +}; + +use crate::helpers::{ + block_builders::{chain_block_with_new_coinbase, find_header_with_achieved_difficulty}, + block_proxy::BlockProxy, + sample_blockchains::create_new_blockchain, + test_block_builder::{TestBlockBuilder, TestBlockBuilderInner}, }; -use tari_crypto::tari_utilities::Hashable; const LOG_TARGET: &str = "tari_core::tests::helpers::test_blockchain"; diff --git a/base_layer/core/tests/mempool.rs b/base_layer/core/tests/mempool.rs index 80e187a99a..ae79003806 100644 --- a/base_layer/core/tests/mempool.rs +++ b/base_layer/core/tests/mempool.rs @@ -20,8 +20,11 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -#[allow(dead_code)] -mod helpers; +// use crate::helpers::database::create_store; +use std::{ops::Deref, sync::Arc, time::Duration}; + +use tari_crypto::{keys::PublicKey as PublicKeyTrait, script}; +use tempfile::tempdir; use helpers::{ block_builders::{ @@ -35,10 +38,8 @@ use helpers::{ nodes::{create_network_with_2_base_nodes_with_config, create_network_with_3_base_nodes_with_config}, sample_blockchains::{create_new_blockchain, create_new_blockchain_with_constants}, }; -use tari_crypto::keys::PublicKey as PublicKeyTrait; -// use crate::helpers::database::create_store; -use std::{ops::Deref, sync::Arc, time::Duration}; use tari_common::configuration::Network; +use tari_common_types::types::{Commitment, PrivateKey, PublicKey, Signature}; use tari_comms_dht::domain_message::OutboundDomainMessage; use tari_core::{ base_node::{ @@ -56,21 +57,20 @@ use tari_core::{ tari_amount::{uT, MicroTari, T}, transaction::{KernelBuilder, OutputFeatures, Transaction, TransactionOutput}, transaction_protocol::{build_challenge, TransactionMetadata}, - types::{Commitment, CryptoFactories, PrivateKey, PublicKey, Signature}, + CryptoFactories, }, tx, txn_schema, validation::transaction_validators::{TxConsensusValidator, TxInputAndMaturityValidator}, }; -use tari_crypto::script; use tari_p2p::{services::liveness::LivenessConfig, tari_message::TariMessageType}; use tari_test_utils::async_assert_eventually; -use tempfile::tempdir; -use tokio::runtime::Runtime; +#[allow(dead_code)] +mod helpers; -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn test_insert_and_process_published_block() { +async fn test_insert_and_process_published_block() { let network = Network::LocalNet; let (mut store, mut blocks, mut outputs, consensus_manager) = create_new_blockchain(network); let mempool_validator = TxInputAndMaturityValidator::new(store.clone()); @@ -201,9 +201,9 @@ fn test_insert_and_process_published_block() { assert_eq!(stats.total_weight, 30); } -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn test_time_locked() { +async fn test_time_locked() { let network = Network::LocalNet; let (mut store, mut blocks, mut outputs, consensus_manager) = create_new_blockchain(network); let mempool_validator = TxInputAndMaturityValidator::new(store.clone()); @@ -245,9 +245,9 @@ fn test_time_locked() { assert_eq!(mempool.insert(tx2).unwrap(), TxStorageResponse::UnconfirmedPool); } -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn test_retrieve() { +async fn test_retrieve() { let network = Network::LocalNet; let (mut store, mut blocks, mut outputs, consensus_manager) = create_new_blockchain(network); let mempool_validator = TxInputAndMaturityValidator::new(store.clone()); @@ -331,9 +331,9 @@ fn test_retrieve() { assert!(retrieved_txs.contains(&tx2[1])); } -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn test_zero_conf() { +async fn test_zero_conf() { let network = Network::LocalNet; let (mut store, mut blocks, mut outputs, consensus_manager) = create_new_blockchain(network); let mempool_validator = TxInputAndMaturityValidator::new(store.clone()); @@ -631,9 +631,9 @@ fn test_zero_conf() { assert!(retrieved_txs.contains(&Arc::new(tx34))); } -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn test_reorg() { +async fn test_reorg() { let network = Network::LocalNet; let (mut db, mut blocks, mut outputs, consensus_manager) = create_new_blockchain(network); let mempool_validator = TxInputAndMaturityValidator::new(db.clone()); @@ -712,13 +712,12 @@ fn test_reorg() { mempool.process_reorg(vec![], vec![reorg_block4.into()]).unwrap(); } -#[test] // TODO: This test returns 0 in the unconfirmed pool, so might not catch errors. It should be updated to return better // data #[allow(clippy::identity_op)] -fn request_response_get_stats() { +#[tokio::test] +async fn request_response_get_stats() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; let consensus_constants = ConsensusConstantsBuilder::new(network) @@ -731,13 +730,13 @@ fn request_response_get_stats() { .with_block(block0) .build(); let (mut alice, bob, _consensus_manager) = create_network_with_2_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager, temp_dir.path(), - ); + ) + .await; // Create a tx spending the genesis output. Then create 2 orphan txs let (tx1, _, _) = spend_utxos(txn_schema!(from: vec![utxo], to: vec![2 * T, 2 * T, 2 * T])); @@ -759,21 +758,18 @@ fn request_response_get_stats() { assert_eq!(stats.reorg_txs, 0); assert_eq!(stats.total_weight, 0); - runtime.block_on(async { - // Alice will request mempool stats from Bob, and thus should be identical - let received_stats = alice.outbound_mp_interface.get_stats().await.unwrap(); - assert_eq!(received_stats.total_txs, 0); - assert_eq!(received_stats.unconfirmed_txs, 0); - assert_eq!(received_stats.reorg_txs, 0); - assert_eq!(received_stats.total_weight, 0); - }); + // Alice will request mempool stats from Bob, and thus should be identical + let received_stats = alice.outbound_mp_interface.get_stats().await.unwrap(); + assert_eq!(received_stats.total_txs, 0); + assert_eq!(received_stats.unconfirmed_txs, 0); + assert_eq!(received_stats.reorg_txs, 0); + assert_eq!(received_stats.total_weight, 0); } -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn request_response_get_tx_state_by_excess_sig() { +async fn request_response_get_tx_state_by_excess_sig() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; let consensus_constants = ConsensusConstantsBuilder::new(network) @@ -786,13 +782,13 @@ fn request_response_get_tx_state_by_excess_sig() { .with_block(block0) .build(); let (mut alice_node, bob_node, carol_node, _consensus_manager) = create_network_with_3_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager, temp_dir.path().to_str().unwrap(), - ); + ) + .await; let (tx, _, _) = spend_utxos(txn_schema!(from: vec![utxo.clone()], to: vec![2 * T, 2 * T, 2 * T])); let (unpublished_tx, _, _) = spend_utxos(txn_schema!(from: vec![utxo], to: vec![3 * T])); @@ -807,43 +803,40 @@ fn request_response_get_tx_state_by_excess_sig() { // Check that the transactions are in the expected pools. // Spending the coinbase utxo will be in the pending pool, because cb utxos have a maturity. // The orphan tx will be in the orphan pool, while the unadded tx won't be found - runtime.block_on(async { - let tx_excess_sig = tx.body.kernels()[0].excess_sig.clone(); - let unpublished_tx_excess_sig = unpublished_tx.body.kernels()[0].excess_sig.clone(); - let orphan_tx_excess_sig = orphan_tx.body.kernels()[0].excess_sig.clone(); - assert_eq!( - alice_node - .outbound_mp_interface - .get_tx_state_by_excess_sig(tx_excess_sig) - .await - .unwrap(), - TxStorageResponse::NotStored - ); - assert_eq!( - alice_node - .outbound_mp_interface - .get_tx_state_by_excess_sig(unpublished_tx_excess_sig) - .await - .unwrap(), - TxStorageResponse::NotStored - ); - assert_eq!( - alice_node - .outbound_mp_interface - .get_tx_state_by_excess_sig(orphan_tx_excess_sig) - .await - .unwrap(), - TxStorageResponse::NotStored - ); - }); + let tx_excess_sig = tx.body.kernels()[0].excess_sig.clone(); + let unpublished_tx_excess_sig = unpublished_tx.body.kernels()[0].excess_sig.clone(); + let orphan_tx_excess_sig = orphan_tx.body.kernels()[0].excess_sig.clone(); + assert_eq!( + alice_node + .outbound_mp_interface + .get_tx_state_by_excess_sig(tx_excess_sig) + .await + .unwrap(), + TxStorageResponse::NotStored + ); + assert_eq!( + alice_node + .outbound_mp_interface + .get_tx_state_by_excess_sig(unpublished_tx_excess_sig) + .await + .unwrap(), + TxStorageResponse::NotStored + ); + assert_eq!( + alice_node + .outbound_mp_interface + .get_tx_state_by_excess_sig(orphan_tx_excess_sig) + .await + .unwrap(), + TxStorageResponse::NotStored + ); } static EMISSION: [u64; 2] = [10, 10]; -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn receive_and_propagate_transaction() { +async fn receive_and_propagate_transaction() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; let consensus_constants = ConsensusConstantsBuilder::new(network) @@ -857,13 +850,13 @@ fn receive_and_propagate_transaction() { .build(); let (mut alice_node, mut bob_node, mut carol_node, _consensus_manager) = create_network_with_3_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager, temp_dir.path().to_str().unwrap(), - ); + ) + .await; alice_node.mock_base_node_state_machine.publish_status(StatusInfo { bootstrapped: true, state_info: StateInfo::Listening(ListeningInfo::new(true)), @@ -884,63 +877,61 @@ fn receive_and_propagate_transaction() { assert!(alice_node.mempool.insert(Arc::new(tx.clone())).is_ok()); assert!(alice_node.mempool.insert(Arc::new(orphan.clone())).is_ok()); - runtime.block_on(async { - alice_node - .outbound_message_service - .send_direct( - bob_node.node_identity.public_key().clone(), - OutboundDomainMessage::new(TariMessageType::NewTransaction, proto::types::Transaction::from(tx)), - ) - .await - .unwrap(); - alice_node - .outbound_message_service - .send_direct( - carol_node.node_identity.public_key().clone(), - OutboundDomainMessage::new(TariMessageType::NewTransaction, proto::types::Transaction::from(orphan)), - ) - .await - .unwrap(); + alice_node + .outbound_message_service + .send_direct( + bob_node.node_identity.public_key().clone(), + OutboundDomainMessage::new(TariMessageType::NewTransaction, proto::types::Transaction::from(tx)), + ) + .await + .unwrap(); + alice_node + .outbound_message_service + .send_direct( + carol_node.node_identity.public_key().clone(), + OutboundDomainMessage::new(TariMessageType::NewTransaction, proto::types::Transaction::from(orphan)), + ) + .await + .unwrap(); - async_assert_eventually!( - bob_node.mempool.has_tx_with_excess_sig(tx_excess_sig.clone()).unwrap(), - expect = TxStorageResponse::NotStored, - max_attempts = 20, - interval = Duration::from_millis(1000) - ); - async_assert_eventually!( - carol_node - .mempool - .has_tx_with_excess_sig(tx_excess_sig.clone()) - .unwrap(), - expect = TxStorageResponse::NotStored, - max_attempts = 10, - interval = Duration::from_millis(1000) - ); - // Carol got sent the orphan tx directly, so it will be in her mempool - async_assert_eventually!( - carol_node - .mempool - .has_tx_with_excess_sig(orphan_excess_sig.clone()) - .unwrap(), - expect = TxStorageResponse::NotStored, - max_attempts = 10, - interval = Duration::from_millis(1000) - ); - // It's difficult to test a negative here, but let's at least make sure that the orphan TX was not propagated - // by the time we check it - async_assert_eventually!( - bob_node - .mempool - .has_tx_with_excess_sig(orphan_excess_sig.clone()) - .unwrap(), - expect = TxStorageResponse::NotStored, - ); - }); + async_assert_eventually!( + bob_node.mempool.has_tx_with_excess_sig(tx_excess_sig.clone()).unwrap(), + expect = TxStorageResponse::NotStored, + max_attempts = 20, + interval = Duration::from_millis(1000) + ); + async_assert_eventually!( + carol_node + .mempool + .has_tx_with_excess_sig(tx_excess_sig.clone()) + .unwrap(), + expect = TxStorageResponse::NotStored, + max_attempts = 10, + interval = Duration::from_millis(1000) + ); + // Carol got sent the orphan tx directly, so it will be in her mempool + async_assert_eventually!( + carol_node + .mempool + .has_tx_with_excess_sig(orphan_excess_sig.clone()) + .unwrap(), + expect = TxStorageResponse::NotStored, + max_attempts = 10, + interval = Duration::from_millis(1000) + ); + // It's difficult to test a negative here, but let's at least make sure that the orphan TX was not propagated + // by the time we check it + async_assert_eventually!( + bob_node + .mempool + .has_tx_with_excess_sig(orphan_excess_sig.clone()) + .unwrap(), + expect = TxStorageResponse::NotStored, + ); } -#[test] -fn consensus_validation_large_tx() { +#[tokio::test] +async fn consensus_validation_large_tx() { let network = Network::LocalNet; // We dont want to compute the 19500 limit of local net, so we create smaller blocks let consensus_constants = ConsensusConstantsBuilder::new(network) @@ -1029,7 +1020,7 @@ fn consensus_validation_large_tx() { // make sure the tx was correctly made and is valid let factories = CryptoFactories::default(); - assert!(tx.validate_internal_consistency(&factories, None).is_ok()); + assert!(tx.validate_internal_consistency(true, &factories, None).is_ok()); let weight = tx.calculate_weight(); let height = blocks.len() as u64; @@ -1042,9 +1033,8 @@ fn consensus_validation_large_tx() { assert!(matches!(response, TxStorageResponse::NotStored)); } -#[test] -fn service_request_timeout() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn service_request_timeout() { let network = Network::LocalNet; let consensus_manager = ConsensusManager::builder(network).build(); let mempool_service_config = MempoolServiceConfig { @@ -1053,27 +1043,25 @@ fn service_request_timeout() { }; let temp_dir = tempdir().unwrap(); let (mut alice_node, bob_node, _consensus_manager) = create_network_with_2_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), mempool_service_config, LivenessConfig::default(), consensus_manager, temp_dir.path().to_str().unwrap(), - ); + ) + .await; - runtime.block_on(async { - bob_node.shutdown().await; + bob_node.shutdown().await; - match alice_node.outbound_mp_interface.get_stats().await { - Err(MempoolServiceError::RequestTimedOut) => {}, - _ => panic!(), - } - }); + match alice_node.outbound_mp_interface.get_stats().await { + Err(MempoolServiceError::RequestTimedOut) => {}, + _ => panic!(), + } } -#[test] +#[tokio::test] #[allow(clippy::identity_op)] -fn block_event_and_reorg_event_handling() { +async fn block_event_and_reorg_event_handling() { // This test creates 2 nodes Alice and Bob // Then creates 2 chains B1 -> B2A (diff 1) and B1 -> B2B (diff 10) // There are 5 transactions created @@ -1086,7 +1074,6 @@ fn block_event_and_reorg_event_handling() { let network = Network::LocalNet; let consensus_constants = NetworkConsensus::from(network).create_consensus_constants(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let (block0, utxos0) = create_genesis_block_with_coinbase_value(&factories, 100_000_000.into(), &consensus_constants[0]); @@ -1095,13 +1082,13 @@ fn block_event_and_reorg_event_handling() { .with_block(block0.clone()) .build(); let (mut alice, mut bob, consensus_manager) = create_network_with_2_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager, temp_dir.path().to_str().unwrap(), - ); + ) + .await; alice.mock_base_node_state_machine.publish_status(StatusInfo { bootstrapped: true, state_info: StateInfo::Listening(ListeningInfo::new(true)), @@ -1135,88 +1122,86 @@ fn block_event_and_reorg_event_handling() { .prepare_block_merkle_roots(chain_block(block0.block(), vec![], &consensus_manager)) .unwrap(); - runtime.block_on(async { - // Add one empty block, so the coinbase UTXO is no longer time-locked. - assert!(bob - .local_nci - .submit_block(empty_block.clone(), Broadcast::from(true)) - .await - .is_ok()); - assert!(alice - .local_nci - .submit_block(empty_block.clone(), Broadcast::from(true)) - .await - .is_ok()); - alice.mempool.insert(Arc::new(tx1.clone())).unwrap(); - bob.mempool.insert(Arc::new(tx1.clone())).unwrap(); - let mut block1 = bob - .blockchain_db - .prepare_block_merkle_roots(chain_block(&empty_block, vec![tx1], &consensus_manager)) - .unwrap(); - find_header_with_achieved_difficulty(&mut block1.header, Difficulty::from(1)); - // Add Block1 - tx1 will be moved to the ReorgPool. - assert!(bob - .local_nci - .submit_block(block1.clone(), Broadcast::from(true)) - .await - .is_ok()); - async_assert_eventually!( - alice.mempool.has_tx_with_excess_sig(tx1_excess_sig.clone()).unwrap(), - expect = TxStorageResponse::ReorgPool, - max_attempts = 20, - interval = Duration::from_millis(1000) - ); - alice.mempool.insert(Arc::new(tx2a.clone())).unwrap(); - alice.mempool.insert(Arc::new(tx3a.clone())).unwrap(); - alice.mempool.insert(Arc::new(tx2b.clone())).unwrap(); - alice.mempool.insert(Arc::new(tx3b.clone())).unwrap(); - bob.mempool.insert(Arc::new(tx2a.clone())).unwrap(); - bob.mempool.insert(Arc::new(tx3a.clone())).unwrap(); - bob.mempool.insert(Arc::new(tx2b.clone())).unwrap(); - bob.mempool.insert(Arc::new(tx3b.clone())).unwrap(); - - let mut block2a = bob - .blockchain_db - .prepare_block_merkle_roots(chain_block(&block1, vec![tx2a, tx3a], &consensus_manager)) - .unwrap(); - find_header_with_achieved_difficulty(&mut block2a.header, Difficulty::from(1)); - // Block2b also builds on Block1 but has a stronger PoW - let mut block2b = bob - .blockchain_db - .prepare_block_merkle_roots(chain_block(&block1, vec![tx2b, tx3b], &consensus_manager)) - .unwrap(); - find_header_with_achieved_difficulty(&mut block2b.header, Difficulty::from(10)); - - // Add Block2a - tx2b and tx3b will be discarded as double spends. - assert!(bob - .local_nci - .submit_block(block2a.clone(), Broadcast::from(true)) - .await - .is_ok()); - - async_assert_eventually!( - bob.mempool.has_tx_with_excess_sig(tx2a_excess_sig.clone()).unwrap(), - expect = TxStorageResponse::ReorgPool, - max_attempts = 20, - interval = Duration::from_millis(1000) - ); - async_assert_eventually!( - alice.mempool.has_tx_with_excess_sig(tx2a_excess_sig.clone()).unwrap(), - expect = TxStorageResponse::ReorgPool, - max_attempts = 20, - interval = Duration::from_millis(1000) - ); - assert_eq!( - alice.mempool.has_tx_with_excess_sig(tx3a_excess_sig.clone()).unwrap(), - TxStorageResponse::ReorgPool - ); - assert_eq!( - alice.mempool.has_tx_with_excess_sig(tx2b_excess_sig.clone()).unwrap(), - TxStorageResponse::ReorgPool - ); - assert_eq!( - alice.mempool.has_tx_with_excess_sig(tx3b_excess_sig.clone()).unwrap(), - TxStorageResponse::ReorgPool - ); - }); + // Add one empty block, so the coinbase UTXO is no longer time-locked. + assert!(bob + .local_nci + .submit_block(empty_block.clone(), Broadcast::from(true)) + .await + .is_ok()); + assert!(alice + .local_nci + .submit_block(empty_block.clone(), Broadcast::from(true)) + .await + .is_ok()); + alice.mempool.insert(Arc::new(tx1.clone())).unwrap(); + bob.mempool.insert(Arc::new(tx1.clone())).unwrap(); + let mut block1 = bob + .blockchain_db + .prepare_block_merkle_roots(chain_block(&empty_block, vec![tx1], &consensus_manager)) + .unwrap(); + find_header_with_achieved_difficulty(&mut block1.header, Difficulty::from(1)); + // Add Block1 - tx1 will be moved to the ReorgPool. + assert!(bob + .local_nci + .submit_block(block1.clone(), Broadcast::from(true)) + .await + .is_ok()); + async_assert_eventually!( + alice.mempool.has_tx_with_excess_sig(tx1_excess_sig.clone()).unwrap(), + expect = TxStorageResponse::ReorgPool, + max_attempts = 20, + interval = Duration::from_millis(1000) + ); + alice.mempool.insert(Arc::new(tx2a.clone())).unwrap(); + alice.mempool.insert(Arc::new(tx3a.clone())).unwrap(); + alice.mempool.insert(Arc::new(tx2b.clone())).unwrap(); + alice.mempool.insert(Arc::new(tx3b.clone())).unwrap(); + bob.mempool.insert(Arc::new(tx2a.clone())).unwrap(); + bob.mempool.insert(Arc::new(tx3a.clone())).unwrap(); + bob.mempool.insert(Arc::new(tx2b.clone())).unwrap(); + bob.mempool.insert(Arc::new(tx3b.clone())).unwrap(); + + let mut block2a = bob + .blockchain_db + .prepare_block_merkle_roots(chain_block(&block1, vec![tx2a, tx3a], &consensus_manager)) + .unwrap(); + find_header_with_achieved_difficulty(&mut block2a.header, Difficulty::from(1)); + // Block2b also builds on Block1 but has a stronger PoW + let mut block2b = bob + .blockchain_db + .prepare_block_merkle_roots(chain_block(&block1, vec![tx2b, tx3b], &consensus_manager)) + .unwrap(); + find_header_with_achieved_difficulty(&mut block2b.header, Difficulty::from(10)); + + // Add Block2a - tx2b and tx3b will be discarded as double spends. + assert!(bob + .local_nci + .submit_block(block2a.clone(), Broadcast::from(true)) + .await + .is_ok()); + + async_assert_eventually!( + bob.mempool.has_tx_with_excess_sig(tx2a_excess_sig.clone()).unwrap(), + expect = TxStorageResponse::ReorgPool, + max_attempts = 20, + interval = Duration::from_millis(1000) + ); + async_assert_eventually!( + alice.mempool.has_tx_with_excess_sig(tx2a_excess_sig.clone()).unwrap(), + expect = TxStorageResponse::ReorgPool, + max_attempts = 20, + interval = Duration::from_millis(1000) + ); + assert_eq!( + alice.mempool.has_tx_with_excess_sig(tx3a_excess_sig.clone()).unwrap(), + TxStorageResponse::ReorgPool + ); + assert_eq!( + alice.mempool.has_tx_with_excess_sig(tx2b_excess_sig.clone()).unwrap(), + TxStorageResponse::ReorgPool + ); + assert_eq!( + alice.mempool.has_tx_with_excess_sig(tx3b_excess_sig.clone()).unwrap(), + TxStorageResponse::ReorgPool + ); } diff --git a/base_layer/core/tests/node_comms_interface.rs b/base_layer/core/tests/node_comms_interface.rs index 532d102bbe..a6096b8dc9 100644 --- a/base_layer/core/tests/node_comms_interface.rs +++ b/base_layer/core/tests/node_comms_interface.rs @@ -20,13 +20,12 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -#[allow(dead_code)] -mod helpers; -use futures::{channel::mpsc, StreamExt}; -use helpers::block_builders::append_block; use std::sync::Arc; + +use futures::StreamExt; +use helpers::block_builders::append_block; use tari_common::configuration::Network; -use tari_common_types::chain_metadata::ChainMetadata; +use tari_common_types::{chain_metadata::ChainMetadata, types::PublicKey}; use tari_comms::peer_manager::NodeId; use tari_core::{ base_node::{ @@ -42,7 +41,7 @@ use tari_core::{ helpers::{create_utxo, spend_utxos}, tari_amount::MicroTari, transaction::{OutputFeatures, TransactionOutput, UnblindedOutput}, - types::{CryptoFactories, PublicKey}, + CryptoFactories, }, txn_schema, validation::{mocks::MockValidator, transaction_validators::TxInputAndMaturityValidator}, @@ -56,6 +55,10 @@ use tari_crypto::{ }; use tari_service_framework::{reply_channel, reply_channel::Receiver}; use tokio::sync::broadcast; + +use tokio::sync::mpsc; +#[allow(dead_code)] +mod helpers; // use crate::helpers::database::create_test_db; async fn test_request_responder( @@ -71,10 +74,10 @@ fn new_mempool() -> Mempool { Mempool::new(MempoolConfig::default(), Arc::new(mempool_validator)) } -#[tokio_macros::test] +#[tokio::test] async fn outbound_get_metadata() { let (request_sender, mut request_receiver) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let mut outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let metadata = ChainMetadata::new(5, vec![0u8], 3, 0, 5); @@ -86,7 +89,7 @@ async fn outbound_get_metadata() { assert_eq!(received_metadata.unwrap(), metadata); } -#[tokio_macros::test] +#[tokio::test] async fn inbound_get_metadata() { let store = create_test_blockchain_db(); let mempool = new_mempool(); @@ -95,7 +98,7 @@ async fn inbound_get_metadata() { let consensus_manager = ConsensusManager::builder(network).build(); let (block_event_sender, _) = broadcast::channel(50); let (request_sender, _) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender.clone()); let inbound_nch = InboundNodeCommsHandlers::new( block_event_sender, @@ -117,7 +120,7 @@ async fn inbound_get_metadata() { } } -#[tokio_macros::test] +#[tokio::test] async fn inbound_fetch_kernel_by_excess_sig() { let store = create_test_blockchain_db(); let mempool = new_mempool(); @@ -126,7 +129,7 @@ async fn inbound_fetch_kernel_by_excess_sig() { let consensus_manager = ConsensusManager::builder(network).build(); let (block_event_sender, _) = broadcast::channel(50); let (request_sender, _) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender.clone()); let inbound_nch = InboundNodeCommsHandlers::new( block_event_sender, @@ -149,10 +152,10 @@ async fn inbound_fetch_kernel_by_excess_sig() { } } -#[tokio_macros::test] +#[tokio::test] async fn outbound_fetch_headers() { let (request_sender, mut request_receiver) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let mut outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let mut header = BlockHeader::new(0); @@ -167,7 +170,7 @@ async fn outbound_fetch_headers() { assert_eq!(received_headers[0], header); } -#[tokio_macros::test] +#[tokio::test] async fn inbound_fetch_headers() { let store = create_test_blockchain_db(); let mempool = new_mempool(); @@ -175,7 +178,7 @@ async fn inbound_fetch_headers() { let consensus_manager = ConsensusManager::builder(network).build(); let (block_event_sender, _) = broadcast::channel(50); let (request_sender, _) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let inbound_nch = InboundNodeCommsHandlers::new( block_event_sender, @@ -197,11 +200,11 @@ async fn inbound_fetch_headers() { } } -#[tokio_macros::test] +#[tokio::test] async fn outbound_fetch_utxos() { let factories = CryptoFactories::default(); let (request_sender, mut request_receiver) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let mut outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let (utxo, _, _) = create_utxo( @@ -221,7 +224,7 @@ async fn outbound_fetch_utxos() { assert_eq!(received_utxos[0], utxo); } -#[tokio_macros::test] +#[tokio::test] async fn inbound_fetch_utxos() { let factories = CryptoFactories::default(); @@ -231,7 +234,7 @@ async fn inbound_fetch_utxos() { let consensus_manager = ConsensusManager::builder(network).build(); let (block_event_sender, _) = broadcast::channel(50); let (request_sender, _) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let inbound_nch = InboundNodeCommsHandlers::new( block_event_sender, @@ -264,11 +267,11 @@ async fn inbound_fetch_utxos() { } } -#[tokio_macros::test] +#[tokio::test] async fn outbound_fetch_txos() { let factories = CryptoFactories::default(); let (request_sender, mut request_receiver) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let mut outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let (txo1, _, _) = create_utxo( @@ -296,7 +299,7 @@ async fn outbound_fetch_txos() { assert_eq!(received_txos[1], txo2); } -#[tokio_macros::test] +#[tokio::test] async fn inbound_fetch_txos() { let factories = CryptoFactories::default(); let store = create_test_blockchain_db(); @@ -305,7 +308,7 @@ async fn inbound_fetch_txos() { let network = Network::LocalNet; let consensus_manager = ConsensusManager::builder(network).build(); let (request_sender, _) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let inbound_nch = InboundNodeCommsHandlers::new( block_event_sender, @@ -366,10 +369,10 @@ async fn inbound_fetch_txos() { } } -#[tokio_macros::test] +#[tokio::test] async fn outbound_fetch_blocks() { let (request_sender, mut request_receiver) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let mut outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let network = Network::LocalNet; let consensus_constants = NetworkConsensus::from(network).create_consensus_constants(); @@ -385,7 +388,7 @@ async fn outbound_fetch_blocks() { assert_eq!(received_blocks[0], block); } -#[tokio_macros::test] +#[tokio::test] async fn inbound_fetch_blocks() { let store = create_test_blockchain_db(); let mempool = new_mempool(); @@ -393,7 +396,7 @@ async fn inbound_fetch_blocks() { let network = Network::LocalNet; let consensus_manager = ConsensusManager::builder(network).build(); let (request_sender, _) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let inbound_nch = InboundNodeCommsHandlers::new( block_event_sender, @@ -415,7 +418,7 @@ async fn inbound_fetch_blocks() { } } -#[tokio_macros::test] +#[tokio::test] // Test needs to be updated to new pruned structure. async fn inbound_fetch_blocks_before_horizon_height() { let factories = CryptoFactories::default(); @@ -437,7 +440,7 @@ async fn inbound_fetch_blocks_before_horizon_height() { let mempool = Mempool::new(MempoolConfig::default(), Arc::new(mempool_validator)); let (block_event_sender, _) = broadcast::channel(50); let (request_sender, _) = reply_channel::unbounded(); - let (block_sender, _) = mpsc::unbounded(); + let (block_sender, _) = mpsc::unbounded_channel(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); let inbound_nch = InboundNodeCommsHandlers::new( block_event_sender, diff --git a/base_layer/core/tests/node_service.rs b/base_layer/core/tests/node_service.rs index 14f277f19b..af128e5966 100644 --- a/base_layer/core/tests/node_service.rs +++ b/base_layer/core/tests/node_service.rs @@ -23,7 +23,6 @@ #[allow(dead_code)] mod helpers; use crate::helpers::block_builders::{construct_chained_blocks, create_coinbase}; -use futures::join; use helpers::{ block_builders::{ append_block, @@ -59,7 +58,7 @@ use tari_core::{ helpers::{schema_to_transaction, spend_utxos}, tari_amount::{uT, T}, transaction::OutputFeatures, - types::CryptoFactories, + CryptoFactories, }, txn_schema, validation::{ @@ -68,15 +67,13 @@ use tari_core::{ mocks::MockValidator, }, }; -use tari_crypto::tari_utilities::hash::Hashable; +use tari_crypto::tari_utilities::Hashable; use tari_p2p::services::liveness::LivenessConfig; use tari_test_utils::unpack_enum; use tempfile::tempdir; -use tokio::runtime::Runtime; -#[test] -fn request_response_get_metadata() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn request_response_get_metadata() { let factories = CryptoFactories::default(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; @@ -89,27 +86,24 @@ fn request_response_get_metadata() { .with_block(block0) .build(); let (mut alice_node, bob_node, carol_node, _consensus_manager) = create_network_with_3_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager, temp_dir.path().to_str().unwrap(), - ); + ) + .await; - runtime.block_on(async { - let received_metadata = alice_node.outbound_nci.get_metadata().await.unwrap(); - assert_eq!(received_metadata.height_of_longest_chain(), 0); + let received_metadata = alice_node.outbound_nci.get_metadata().await.unwrap(); + assert_eq!(received_metadata.height_of_longest_chain(), 0); - alice_node.shutdown().await; - bob_node.shutdown().await; - carol_node.shutdown().await; - }); + alice_node.shutdown().await; + bob_node.shutdown().await; + carol_node.shutdown().await; } -#[test] -fn request_and_response_fetch_blocks() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn request_and_response_fetch_blocks() { let factories = CryptoFactories::default(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; @@ -122,13 +116,13 @@ fn request_and_response_fetch_blocks() { .with_block(block0.clone()) .build(); let (mut alice_node, mut bob_node, carol_node, _) = create_network_with_3_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager.clone(), temp_dir.path().to_str().unwrap(), - ); + ) + .await; let mut blocks = vec![block0]; let db = &mut bob_node.blockchain_db; @@ -147,26 +141,23 @@ fn request_and_response_fetch_blocks() { .unwrap() .assert_added(); - runtime.block_on(async { - let received_blocks = alice_node.outbound_nci.fetch_blocks(vec![0]).await.unwrap(); - assert_eq!(received_blocks.len(), 1); - assert_eq!(received_blocks[0].block(), blocks[0].block()); + let received_blocks = alice_node.outbound_nci.fetch_blocks(vec![0]).await.unwrap(); + assert_eq!(received_blocks.len(), 1); + assert_eq!(received_blocks[0].block(), blocks[0].block()); - let received_blocks = alice_node.outbound_nci.fetch_blocks(vec![0, 1]).await.unwrap(); - assert_eq!(received_blocks.len(), 2); - assert_ne!(*received_blocks[0].block(), *received_blocks[1].block()); - assert!(received_blocks[0].block() == blocks[0].block() || received_blocks[1].block() == blocks[0].block()); - assert!(received_blocks[0].block() == blocks[1].block() || received_blocks[1].block() == blocks[1].block()); + let received_blocks = alice_node.outbound_nci.fetch_blocks(vec![0, 1]).await.unwrap(); + assert_eq!(received_blocks.len(), 2); + assert_ne!(*received_blocks[0].block(), *received_blocks[1].block()); + assert!(received_blocks[0].block() == blocks[0].block() || received_blocks[1].block() == blocks[0].block()); + assert!(received_blocks[0].block() == blocks[1].block() || received_blocks[1].block() == blocks[1].block()); - alice_node.shutdown().await; - bob_node.shutdown().await; - carol_node.shutdown().await; - }); + alice_node.shutdown().await; + bob_node.shutdown().await; + carol_node.shutdown().await; } -#[test] -fn request_and_response_fetch_blocks_with_hashes() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn request_and_response_fetch_blocks_with_hashes() { let factories = CryptoFactories::default(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; @@ -179,13 +170,13 @@ fn request_and_response_fetch_blocks_with_hashes() { .with_block(block0.clone()) .build(); let (mut alice_node, mut bob_node, carol_node, _) = create_network_with_3_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager.clone(), temp_dir.path().to_str().unwrap(), - ); + ) + .await; let mut blocks = vec![block0]; let db = &mut bob_node.blockchain_db; @@ -206,34 +197,31 @@ fn request_and_response_fetch_blocks_with_hashes() { .unwrap() .assert_added(); - runtime.block_on(async { - let received_blocks = alice_node - .outbound_nci - .fetch_blocks_with_hashes(vec![block0_hash.clone()]) - .await - .unwrap(); - assert_eq!(received_blocks.len(), 1); - assert_eq!(received_blocks[0].block(), blocks[0].block()); + let received_blocks = alice_node + .outbound_nci + .fetch_blocks_with_hashes(vec![block0_hash.clone()]) + .await + .unwrap(); + assert_eq!(received_blocks.len(), 1); + assert_eq!(received_blocks[0].block(), blocks[0].block()); - let received_blocks = alice_node - .outbound_nci - .fetch_blocks_with_hashes(vec![block0_hash.clone(), block1_hash.clone()]) - .await - .unwrap(); - assert_eq!(received_blocks.len(), 2); - assert_ne!(received_blocks[0], received_blocks[1]); - assert!(received_blocks[0].block() == blocks[0].block() || received_blocks[1].block() == blocks[0].block()); - assert!(received_blocks[0].block() == blocks[1].block() || received_blocks[1].block() == blocks[1].block()); - - alice_node.shutdown().await; - bob_node.shutdown().await; - carol_node.shutdown().await; - }); + let received_blocks = alice_node + .outbound_nci + .fetch_blocks_with_hashes(vec![block0_hash.clone(), block1_hash.clone()]) + .await + .unwrap(); + assert_eq!(received_blocks.len(), 2); + assert_ne!(received_blocks[0], received_blocks[1]); + assert!(received_blocks[0].block() == blocks[0].block() || received_blocks[1].block() == blocks[0].block()); + assert!(received_blocks[0].block() == blocks[1].block() || received_blocks[1].block() == blocks[1].block()); + + alice_node.shutdown().await; + bob_node.shutdown().await; + carol_node.shutdown().await; } -#[test] -fn propagate_and_forward_many_valid_blocks() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn propagate_and_forward_many_valid_blocks() { let temp_dir = tempdir().unwrap(); let factories = CryptoFactories::default(); // Alice will propagate a number of block hashes to bob, bob will receive it, request the full block, verify and @@ -261,24 +249,28 @@ fn propagate_and_forward_many_valid_blocks() { let (mut alice_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(alice_node_identity.clone()) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("alice").to_str().unwrap()); + .start(temp_dir.path().join("alice").to_str().unwrap()) + .await; let (mut bob_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(bob_node_identity.clone()) .with_peers(vec![alice_node_identity]) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("bob").to_str().unwrap()); + .start(temp_dir.path().join("bob").to_str().unwrap()) + .await; let (mut carol_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(carol_node_identity.clone()) .with_peers(vec![bob_node_identity.clone()]) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("carol").to_str().unwrap()); + .start(temp_dir.path().join("carol").to_str().unwrap()) + .await; let (mut dan_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(dan_node_identity) .with_peers(vec![carol_node_identity, bob_node_identity]) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("dan").to_str().unwrap()); + .start(temp_dir.path().join("dan").to_str().unwrap()) + .await; - wait_until_online(&mut runtime, &[&alice_node, &bob_node, &carol_node, &dan_node]); + wait_until_online(&[&alice_node, &bob_node, &carol_node, &dan_node]).await; alice_node.mock_base_node_state_machine.publish_status(StatusInfo { bootstrapped: true, state_info: StateInfo::Listening(ListeningInfo::new(true)), @@ -302,56 +294,49 @@ fn propagate_and_forward_many_valid_blocks() { let blocks = construct_chained_blocks(&alice_node.blockchain_db, block0, &rules, 5); - runtime.block_on(async { - for block in &blocks { - alice_node - .outbound_nci - .propagate_block(NewBlock::from(block.block()), vec![]) - .await - .unwrap(); - - let bob_block_event_fut = event_stream_next(&mut bob_block_event_stream, Duration::from_millis(20000)); - let carol_block_event_fut = event_stream_next(&mut carol_block_event_stream, Duration::from_millis(20000)); - let dan_block_event_fut = event_stream_next(&mut dan_block_event_stream, Duration::from_millis(20000)); - let (bob_block_event, carol_block_event, dan_block_event) = - join!(bob_block_event_fut, carol_block_event_fut, dan_block_event_fut); - let block_hash = block.hash(); - - if let BlockEvent::ValidBlockAdded(received_block, _, _) = &*bob_block_event.unwrap().unwrap() { - assert_eq!(&received_block.hash(), block_hash); - } else { - panic!("Bob's node did not receive and validate the expected block"); - } - if let BlockEvent::ValidBlockAdded(received_block, _block_add_result, _) = - &*carol_block_event.unwrap().unwrap() - { - assert_eq!(&received_block.hash(), block_hash); - } else { - panic!("Carol's node did not receive and validate the expected block"); - } - if let BlockEvent::ValidBlockAdded(received_block, _block_add_result, _) = - &*dan_block_event.unwrap().unwrap() - { - assert_eq!(&received_block.hash(), block_hash); - } else { - panic!("Dan's node did not receive and validate the expected block"); - } + for block in &blocks { + alice_node + .outbound_nci + .propagate_block(NewBlock::from(block.block()), vec![]) + .await + .unwrap(); + + let bob_block_event_fut = event_stream_next(&mut bob_block_event_stream, Duration::from_millis(20000)); + let carol_block_event_fut = event_stream_next(&mut carol_block_event_stream, Duration::from_millis(20000)); + let dan_block_event_fut = event_stream_next(&mut dan_block_event_stream, Duration::from_millis(20000)); + let (bob_block_event, carol_block_event, dan_block_event) = + tokio::join!(bob_block_event_fut, carol_block_event_fut, dan_block_event_fut); + let block_hash = block.hash(); + + if let BlockEvent::ValidBlockAdded(received_block, _, _) = &*bob_block_event.unwrap() { + assert_eq!(&received_block.hash(), block_hash); + } else { + panic!("Bob's node did not receive and validate the expected block"); } + if let BlockEvent::ValidBlockAdded(received_block, _block_add_result, _) = &*carol_block_event.unwrap() { + assert_eq!(&received_block.hash(), block_hash); + } else { + panic!("Carol's node did not receive and validate the expected block"); + } + if let BlockEvent::ValidBlockAdded(received_block, _block_add_result, _) = &*dan_block_event.unwrap() { + assert_eq!(&received_block.hash(), block_hash); + } else { + panic!("Dan's node did not receive and validate the expected block"); + } + } - alice_node.shutdown().await; - bob_node.shutdown().await; - carol_node.shutdown().await; - dan_node.shutdown().await; - }); + alice_node.shutdown().await; + bob_node.shutdown().await; + carol_node.shutdown().await; + dan_node.shutdown().await; } static EMISSION: [u64; 2] = [10, 10]; -#[test] -fn propagate_and_forward_invalid_block_hash() { +#[tokio::test] +async fn propagate_and_forward_invalid_block_hash() { // Alice will propagate a "made up" block hash to Bob, Bob will request the block from Alice. Alice will not be able // to provide the block and so Bob will not propagate the hash further to Carol. // alice -> bob -> carol - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let factories = CryptoFactories::default(); @@ -370,19 +355,22 @@ fn propagate_and_forward_invalid_block_hash() { let (mut alice_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(alice_node_identity.clone()) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("alice").to_str().unwrap()); + .start(temp_dir.path().join("alice").to_str().unwrap()) + .await; let (mut bob_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(bob_node_identity.clone()) .with_peers(vec![alice_node_identity]) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("bob").to_str().unwrap()); + .start(temp_dir.path().join("bob").to_str().unwrap()) + .await; let (mut carol_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(carol_node_identity) .with_peers(vec![bob_node_identity]) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("carol").to_str().unwrap()); + .start(temp_dir.path().join("carol").to_str().unwrap()) + .await; - wait_until_online(&mut runtime, &[&alice_node, &bob_node, &carol_node]); + wait_until_online(&[&alice_node, &bob_node, &carol_node]).await; alice_node.mock_base_node_state_machine.publish_status(StatusInfo { bootstrapped: true, state_info: StateInfo::Listening(ListeningInfo::new(true)), @@ -409,42 +397,37 @@ fn propagate_and_forward_invalid_block_hash() { let mut bob_message_events = bob_node.messaging_events.subscribe(); let mut carol_message_events = carol_node.messaging_events.subscribe(); - runtime.block_on(async { - alice_node - .outbound_nci - .propagate_block(NewBlock::from(block1.block()), vec![]) - .await - .unwrap(); + alice_node + .outbound_nci + .propagate_block(NewBlock::from(block1.block()), vec![]) + .await + .unwrap(); - // Alice propagated to Bob - // Bob received the invalid hash - let msg_event = event_stream_next(&mut bob_message_events, Duration::from_secs(10)) - .await - .unwrap() - .unwrap(); - unpack_enum!(MessagingEvent::MessageReceived(_a, _b) = &*msg_event); - // Sent the request for the block to Alice - // Bob received a response from Alice - let msg_event = event_stream_next(&mut bob_message_events, Duration::from_secs(10)) - .await - .unwrap() - .unwrap(); - unpack_enum!(MessagingEvent::MessageReceived(node_id, _a) = &*msg_event); - assert_eq!(&*node_id, alice_node.node_identity.node_id()); - // Checking a negative: Bob should not have propagated this hash to Carol. If Bob does, this assertion will be - // flaky. - let msg_event = event_stream_next(&mut carol_message_events, Duration::from_millis(500)).await; - assert!(msg_event.is_none()); - - alice_node.shutdown().await; - bob_node.shutdown().await; - carol_node.shutdown().await; - }); + // Alice propagated to Bob + // Bob received the invalid hash + let msg_event = event_stream_next(&mut bob_message_events, Duration::from_secs(10)) + .await + .unwrap(); + unpack_enum!(MessagingEvent::MessageReceived(_a, _b) = &*msg_event); + // Sent the request for the block to Alice + // Bob received a response from Alice + let msg_event = event_stream_next(&mut bob_message_events, Duration::from_secs(10)) + .await + .unwrap(); + unpack_enum!(MessagingEvent::MessageReceived(node_id, _a) = &*msg_event); + assert_eq!(&*node_id, alice_node.node_identity.node_id()); + // Checking a negative: Bob should not have propagated this hash to Carol. If Bob does, this assertion will be + // flaky. + let msg_event = event_stream_next(&mut carol_message_events, Duration::from_millis(500)).await; + assert!(msg_event.is_none()); + + alice_node.shutdown().await; + bob_node.shutdown().await; + carol_node.shutdown().await; } -#[test] -fn propagate_and_forward_invalid_block() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn propagate_and_forward_invalid_block() { let temp_dir = tempdir().unwrap(); let factories = CryptoFactories::default(); // Alice will propagate an invalid block to Carol and Bob, they will check the received block and not propagate the @@ -467,13 +450,14 @@ fn propagate_and_forward_invalid_block() { .with_consensus_constants(consensus_constants) .with_block(block0.clone()) .build(); - let stateless_block_validator = OrphanBlockValidator::new(rules.clone(), factories); + let stateless_block_validator = OrphanBlockValidator::new(rules.clone(), true, factories); let mock_validator = MockValidator::new(false); let (mut dan_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(dan_node_identity.clone()) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("dan").to_str().unwrap()); + .start(temp_dir.path().join("dan").to_str().unwrap()) + .await; let (mut carol_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(carol_node_identity.clone()) .with_peers(vec![dan_node_identity.clone()]) @@ -483,20 +467,23 @@ fn propagate_and_forward_invalid_block() { mock_validator.clone(), stateless_block_validator.clone(), ) - .start(&mut runtime, temp_dir.path().join("carol").to_str().unwrap()); + .start(temp_dir.path().join("carol").to_str().unwrap()) + .await; let (mut bob_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(bob_node_identity.clone()) .with_peers(vec![dan_node_identity]) .with_consensus_manager(rules) .with_validators(mock_validator.clone(), mock_validator, stateless_block_validator) - .start(&mut runtime, temp_dir.path().join("bob").to_str().unwrap()); + .start(temp_dir.path().join("bob").to_str().unwrap()) + .await; let (mut alice_node, rules) = BaseNodeBuilder::new(network.into()) .with_node_identity(alice_node_identity) .with_peers(vec![bob_node_identity, carol_node_identity]) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().join("alice").to_str().unwrap()); + .start(temp_dir.path().join("alice").to_str().unwrap()) + .await; - wait_until_online(&mut runtime, &[&alice_node, &bob_node, &carol_node, &dan_node]); + wait_until_online(&[&alice_node, &bob_node, &carol_node, &dan_node]).await; alice_node.mock_base_node_state_machine.publish_status(StatusInfo { bootstrapped: true, @@ -520,45 +507,42 @@ fn propagate_and_forward_invalid_block() { let block1 = append_block(&alice_node.blockchain_db, &block0, vec![], &rules, 1.into()).unwrap(); let block1_hash = block1.hash(); - runtime.block_on(async { - let mut bob_block_event_stream = bob_node.local_nci.get_block_event_stream(); - let mut carol_block_event_stream = carol_node.local_nci.get_block_event_stream(); - let mut dan_block_event_stream = dan_node.local_nci.get_block_event_stream(); - - assert!(alice_node - .outbound_nci - .propagate_block(NewBlock::from(block1.block()), vec![]) - .await - .is_ok()); - - let bob_block_event_fut = event_stream_next(&mut bob_block_event_stream, Duration::from_millis(20000)); - let carol_block_event_fut = event_stream_next(&mut carol_block_event_stream, Duration::from_millis(20000)); - let dan_block_event_fut = event_stream_next(&mut dan_block_event_stream, Duration::from_millis(5000)); - let (bob_block_event, carol_block_event, dan_block_event) = - join!(bob_block_event_fut, carol_block_event_fut, dan_block_event_fut); - - if let BlockEvent::AddBlockFailed(received_block, _) = &*bob_block_event.unwrap().unwrap() { - assert_eq!(&received_block.hash(), block1_hash); - } else { - panic!("Bob's node should have detected an invalid block"); - } - if let BlockEvent::AddBlockFailed(received_block, _) = &*carol_block_event.unwrap().unwrap() { - assert_eq!(&received_block.hash(), block1_hash); - } else { - panic!("Carol's node should have detected an invalid block"); - } - assert!(dan_block_event.is_none()); + let mut bob_block_event_stream = bob_node.local_nci.get_block_event_stream(); + let mut carol_block_event_stream = carol_node.local_nci.get_block_event_stream(); + let mut dan_block_event_stream = dan_node.local_nci.get_block_event_stream(); - alice_node.shutdown().await; - bob_node.shutdown().await; - carol_node.shutdown().await; - dan_node.shutdown().await; - }); + assert!(alice_node + .outbound_nci + .propagate_block(NewBlock::from(block1.block()), vec![]) + .await + .is_ok()); + + let bob_block_event_fut = event_stream_next(&mut bob_block_event_stream, Duration::from_millis(20000)); + let carol_block_event_fut = event_stream_next(&mut carol_block_event_stream, Duration::from_millis(20000)); + let dan_block_event_fut = event_stream_next(&mut dan_block_event_stream, Duration::from_millis(5000)); + let (bob_block_event, carol_block_event, dan_block_event) = + tokio::join!(bob_block_event_fut, carol_block_event_fut, dan_block_event_fut); + + if let BlockEvent::AddBlockFailed(received_block, _) = &*bob_block_event.unwrap() { + assert_eq!(&received_block.hash(), block1_hash); + } else { + panic!("Bob's node should have detected an invalid block"); + } + if let BlockEvent::AddBlockFailed(received_block, _) = &*carol_block_event.unwrap() { + assert_eq!(&received_block.hash(), block1_hash); + } else { + panic!("Carol's node should have detected an invalid block"); + } + assert!(dan_block_event.is_none()); + + alice_node.shutdown().await; + bob_node.shutdown().await; + carol_node.shutdown().await; + dan_node.shutdown().await; } -#[test] -fn service_request_timeout() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn service_request_timeout() { let network = Network::LocalNet; let consensus_manager = ConsensusManager::builder(network).build(); let base_node_service_config = BaseNodeServiceConfig { @@ -569,47 +553,42 @@ fn service_request_timeout() { }; let temp_dir = tempdir().unwrap(); let (mut alice_node, bob_node, _consensus_manager) = create_network_with_2_base_nodes_with_config( - &mut runtime, base_node_service_config, MempoolServiceConfig::default(), LivenessConfig::default(), consensus_manager, temp_dir.path().to_str().unwrap(), - ); + ) + .await; - runtime.block_on(async { - // Bob should not be reachable - bob_node.shutdown().await; - unpack_enum!(CommsInterfaceError::RequestTimedOut = alice_node.outbound_nci.get_metadata().await.unwrap_err()); - alice_node.shutdown().await; - }); + // Bob should not be reachable + bob_node.shutdown().await; + unpack_enum!(CommsInterfaceError::RequestTimedOut = alice_node.outbound_nci.get_metadata().await.unwrap_err()); + alice_node.shutdown().await; } -#[test] -fn local_get_metadata() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn local_get_metadata() { let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; - let (mut node, consensus_manager) = - BaseNodeBuilder::new(network.into()).start(&mut runtime, temp_dir.path().to_str().unwrap()); + let (mut node, consensus_manager) = BaseNodeBuilder::new(network.into()) + .start(temp_dir.path().to_str().unwrap()) + .await; let db = &node.blockchain_db; let block0 = db.fetch_block(0).unwrap().try_into_chain_block().unwrap(); let block1 = append_block(db, &block0, vec![], &consensus_manager, 1.into()).unwrap(); let block2 = append_block(db, &block1, vec![], &consensus_manager, 1.into()).unwrap(); - runtime.block_on(async { - let metadata = node.local_nci.get_metadata().await.unwrap(); - assert_eq!(metadata.height_of_longest_chain(), 2); - assert_eq!(metadata.best_block(), block2.hash()); + let metadata = node.local_nci.get_metadata().await.unwrap(); + assert_eq!(metadata.height_of_longest_chain(), 2); + assert_eq!(metadata.best_block(), block2.hash()); - node.shutdown().await; - }); + node.shutdown().await; } -#[test] -fn local_get_new_block_template_and_get_new_block() { +#[tokio::test] +async fn local_get_new_block_template_and_get_new_block() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; let consensus_constants = NetworkConsensus::from(network).create_consensus_constants(); @@ -620,7 +599,8 @@ fn local_get_new_block_template_and_get_new_block() { .build(); let (mut node, _rules) = BaseNodeBuilder::new(network.into()) .with_consensus_manager(rules) - .start(&mut runtime, temp_dir.path().to_str().unwrap()); + .start(temp_dir.path().to_str().unwrap()) + .await; let schema = [ txn_schema!(from: vec![outputs[1].clone()], to: vec![10_000 * uT, 20_000 * uT]), @@ -630,29 +610,26 @@ fn local_get_new_block_template_and_get_new_block() { assert!(node.mempool.insert(txs[0].clone()).is_ok()); assert!(node.mempool.insert(txs[1].clone()).is_ok()); - runtime.block_on(async { - let block_template = node - .local_nci - .get_new_block_template(PowAlgorithm::Sha3, 0) - .await - .unwrap(); - assert_eq!(block_template.header.height, 1); - assert_eq!(block_template.body.kernels().len(), 2); + let block_template = node + .local_nci + .get_new_block_template(PowAlgorithm::Sha3, 0) + .await + .unwrap(); + assert_eq!(block_template.header.height, 1); + assert_eq!(block_template.body.kernels().len(), 2); - let block = node.local_nci.get_new_block(block_template.clone()).await.unwrap(); - assert_eq!(block.header.height, 1); - assert_eq!(block.body, block_template.body); + let block = node.local_nci.get_new_block(block_template.clone()).await.unwrap(); + assert_eq!(block.header.height, 1); + assert_eq!(block.body, block_template.body); - node.blockchain_db.add_block(block.clone().into()).unwrap(); + node.blockchain_db.add_block(block.clone().into()).unwrap(); - node.shutdown().await; - }); + node.shutdown().await; } -#[test] -fn local_get_new_block_with_zero_conf() { +#[tokio::test] +async fn local_get_new_block_with_zero_conf() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; let consensus_constants = NetworkConsensus::from(network).create_consensus_constants(); @@ -666,9 +643,10 @@ fn local_get_new_block_with_zero_conf() { .with_validators( BodyOnlyValidator::default(), HeaderValidator::new(rules.clone()), - OrphanBlockValidator::new(rules, factories.clone()), + OrphanBlockValidator::new(rules, true, factories.clone()), ) - .start(&mut runtime, temp_dir.path().to_str().unwrap()); + .start(temp_dir.path().to_str().unwrap()) + .await; let (tx01, tx01_out, _) = spend_utxos( txn_schema!(from: vec![outputs[1].clone()], to: vec![20_000 * uT], fee: 10*uT, lock: 0, features: OutputFeatures::default()), @@ -700,38 +678,35 @@ fn local_get_new_block_with_zero_conf() { TxStorageResponse::UnconfirmedPool ); - runtime.block_on(async { - let mut block_template = node - .local_nci - .get_new_block_template(PowAlgorithm::Sha3, 0) - .await - .unwrap(); - assert_eq!(block_template.header.height, 1); - assert_eq!(block_template.body.kernels().len(), 4); - let coinbase_value = rules.get_block_reward_at(1) + block_template.body.get_total_fee(); - let (output, kernel, _) = create_coinbase( - &factories, - coinbase_value, - rules.consensus_constants(1).coinbase_lock_height() + 1, - ); - block_template.body.add_kernel(kernel); - block_template.body.add_output(output); - block_template.body.sort(); - let block = node.local_nci.get_new_block(block_template.clone()).await.unwrap(); - assert_eq!(block.header.height, 1); - assert_eq!(block.body, block_template.body); - assert_eq!(block_template.body.kernels().len(), 5); - - node.blockchain_db.add_block(block.clone().into()).unwrap(); - - node.shutdown().await; - }); + let mut block_template = node + .local_nci + .get_new_block_template(PowAlgorithm::Sha3, 0) + .await + .unwrap(); + assert_eq!(block_template.header.height, 1); + assert_eq!(block_template.body.kernels().len(), 4); + let coinbase_value = rules.get_block_reward_at(1) + block_template.body.get_total_fee(); + let (output, kernel, _) = create_coinbase( + &factories, + coinbase_value, + rules.consensus_constants(1).coinbase_lock_height() + 1, + ); + block_template.body.add_kernel(kernel); + block_template.body.add_output(output); + block_template.body.sort(); + let block = node.local_nci.get_new_block(block_template.clone()).await.unwrap(); + assert_eq!(block.header.height, 1); + assert_eq!(block.body, block_template.body); + assert_eq!(block_template.body.kernels().len(), 5); + + node.blockchain_db.add_block(block.clone().into()).unwrap(); + + node.shutdown().await; } -#[test] -fn local_get_new_block_with_combined_transaction() { +#[tokio::test] +async fn local_get_new_block_with_combined_transaction() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; let consensus_constants = NetworkConsensus::from(network).create_consensus_constants(); @@ -745,9 +720,10 @@ fn local_get_new_block_with_combined_transaction() { .with_validators( BodyOnlyValidator::default(), HeaderValidator::new(rules.clone()), - OrphanBlockValidator::new(rules, factories.clone()), + OrphanBlockValidator::new(rules, true, factories.clone()), ) - .start(&mut runtime, temp_dir.path().to_str().unwrap()); + .start(temp_dir.path().to_str().unwrap()) + .await; let (tx01, tx01_out, _) = spend_utxos( txn_schema!(from: vec![outputs[1].clone()], to: vec![20_000 * uT], fee: 10*uT, lock: 0, features: OutputFeatures::default()), @@ -774,41 +750,39 @@ fn local_get_new_block_with_combined_transaction() { TxStorageResponse::UnconfirmedPool ); - runtime.block_on(async { - let mut block_template = node - .local_nci - .get_new_block_template(PowAlgorithm::Sha3, 0) - .await - .unwrap(); - assert_eq!(block_template.header.height, 1); - assert_eq!(block_template.body.kernels().len(), 4); - let coinbase_value = rules.get_block_reward_at(1) + block_template.body.get_total_fee(); - let (output, kernel, _) = create_coinbase( - &factories, - coinbase_value, - rules.consensus_constants(1).coinbase_lock_height() + 1, - ); - block_template.body.add_kernel(kernel); - block_template.body.add_output(output); - block_template.body.sort(); - let block = node.local_nci.get_new_block(block_template.clone()).await.unwrap(); - assert_eq!(block.header.height, 1); - assert_eq!(block.body, block_template.body); - assert_eq!(block_template.body.kernels().len(), 5); - - node.blockchain_db.add_block(block.clone().into()).unwrap(); - - node.shutdown().await; - }); + let mut block_template = node + .local_nci + .get_new_block_template(PowAlgorithm::Sha3, 0) + .await + .unwrap(); + assert_eq!(block_template.header.height, 1); + assert_eq!(block_template.body.kernels().len(), 4); + let coinbase_value = rules.get_block_reward_at(1) + block_template.body.get_total_fee(); + let (output, kernel, _) = create_coinbase( + &factories, + coinbase_value, + rules.consensus_constants(1).coinbase_lock_height() + 1, + ); + block_template.body.add_kernel(kernel); + block_template.body.add_output(output); + block_template.body.sort(); + let block = node.local_nci.get_new_block(block_template.clone()).await.unwrap(); + assert_eq!(block.header.height, 1); + assert_eq!(block.body, block_template.body); + assert_eq!(block_template.body.kernels().len(), 5); + + node.blockchain_db.add_block(block.clone().into()).unwrap(); + + node.shutdown().await; } -#[test] -fn local_submit_block() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn local_submit_block() { let temp_dir = tempdir().unwrap(); let network = Network::LocalNet; - let (mut node, consensus_manager) = - BaseNodeBuilder::new(network.into()).start(&mut runtime, temp_dir.path().to_str().unwrap()); + let (mut node, consensus_manager) = BaseNodeBuilder::new(network.into()) + .start(temp_dir.path().to_str().unwrap()) + .await; let db = &node.blockchain_db; let mut event_stream = node.local_nci.get_block_event_stream(); @@ -818,20 +792,18 @@ fn local_submit_block() { .unwrap(); block1.header.kernel_mmr_size += 1; block1.header.output_mmr_size += 1; - runtime.block_on(async { - node.local_nci - .submit_block(block1.clone(), Broadcast::from(true)) - .await - .unwrap(); + node.local_nci + .submit_block(block1.clone(), Broadcast::from(true)) + .await + .unwrap(); - let event = event_stream_next(&mut event_stream, Duration::from_millis(20000)).await; - if let BlockEvent::ValidBlockAdded(received_block, result, _) = &*event.unwrap().unwrap() { - assert_eq!(received_block.hash(), block1.hash()); - result.assert_added(); - } else { - panic!("Block validation failed"); - } + let event = event_stream_next(&mut event_stream, Duration::from_millis(20000)).await; + if let BlockEvent::ValidBlockAdded(received_block, result, _) = &*event.unwrap() { + assert_eq!(received_block.hash(), block1.hash()); + result.assert_added(); + } else { + panic!("Block validation failed"); + } - node.shutdown().await; - }); + node.shutdown().await; } diff --git a/base_layer/core/tests/node_state_machine.rs b/base_layer/core/tests/node_state_machine.rs index bcbeeea436..da9e5e3d95 100644 --- a/base_layer/core/tests/node_state_machine.rs +++ b/base_layer/core/tests/node_state_machine.rs @@ -20,16 +20,12 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -#[allow(dead_code)] -mod helpers; - -use futures::StreamExt; use helpers::{ block_builders::{append_block, chain_block, create_genesis_block}, chain_metadata::{random_peer_metadata, MockChainMetadata}, nodes::{create_network_with_2_base_nodes_with_config, wait_until_online, BaseNodeBuilder}, }; -use std::{thread, time::Duration}; +use std::time::Duration; use tari_common::configuration::Network; use tari_core::{ base_node::{ @@ -47,22 +43,24 @@ use tari_core::{ mempool::MempoolServiceConfig, proof_of_work::randomx_factory::RandomXFactory, test_helpers::blockchain::create_test_blockchain_db, - transactions::types::CryptoFactories, + transactions::CryptoFactories, validation::mocks::MockValidator, }; use tari_p2p::services::liveness::LivenessConfig; use tari_shutdown::Shutdown; use tempfile::tempdir; use tokio::{ - runtime::Runtime, sync::{broadcast, watch}, + task, time, }; +#[allow(dead_code)] +mod helpers; + static EMISSION: [u64; 2] = [10, 10]; -#[test] -fn test_listening_lagging() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn test_listening_lagging() { let factories = CryptoFactories::default(); let network = Network::LocalNet; let temp_dir = tempdir().unwrap(); @@ -75,7 +73,6 @@ fn test_listening_lagging() { .with_block(prev_block.clone()) .build(); let (alice_node, bob_node, consensus_manager) = create_network_with_2_base_nodes_with_config( - &mut runtime, BaseNodeServiceConfig::default(), MempoolServiceConfig::default(), LivenessConfig { @@ -84,7 +81,8 @@ fn test_listening_lagging() { }, consensus_manager, temp_dir.path().to_str().unwrap(), - ); + ) + .await; let shutdown = Shutdown::new(); let (state_change_event_publisher, _) = broadcast::channel(10); let (status_event_sender, _status_event_receiver) = watch::channel(StatusInfo::new()); @@ -103,46 +101,44 @@ fn test_listening_lagging() { consensus_manager.clone(), shutdown.to_signal(), ); - wait_until_online(&mut runtime, &[&alice_node, &bob_node]); + wait_until_online(&[&alice_node, &bob_node]).await; - let await_event_task = runtime.spawn(async move { Listening::new().next_event(&mut alice_state_machine).await }); + let await_event_task = task::spawn(async move { Listening::new().next_event(&mut alice_state_machine).await }); - runtime.block_on(async move { - let bob_db = bob_node.blockchain_db; - let mut bob_local_nci = bob_node.local_nci; + let bob_db = bob_node.blockchain_db; + let mut bob_local_nci = bob_node.local_nci; - // Bob Block 1 - no block event - let prev_block = append_block(&bob_db, &prev_block, vec![], &consensus_manager, 3.into()).unwrap(); - // Bob Block 2 - with block event and liveness service metadata update - let mut prev_block = bob_db - .prepare_block_merkle_roots(chain_block(&prev_block.block(), vec![], &consensus_manager)) - .unwrap(); - prev_block.header.output_mmr_size += 1; - prev_block.header.kernel_mmr_size += 1; - bob_local_nci - .submit_block(prev_block, Broadcast::from(true)) - .await - .unwrap(); - assert_eq!(bob_db.get_height().unwrap(), 2); + // Bob Block 1 - no block event + let prev_block = append_block(&bob_db, &prev_block, vec![], &consensus_manager, 3.into()).unwrap(); + // Bob Block 2 - with block event and liveness service metadata update + let mut prev_block = bob_db + .prepare_block_merkle_roots(chain_block(&prev_block.block(), vec![], &consensus_manager)) + .unwrap(); + prev_block.header.output_mmr_size += 1; + prev_block.header.kernel_mmr_size += 1; + bob_local_nci + .submit_block(prev_block, Broadcast::from(true)) + .await + .unwrap(); + assert_eq!(bob_db.get_height().unwrap(), 2); - let next_event = time::timeout(Duration::from_secs(10), await_event_task) - .await - .expect("Alice did not emit `StateEvent::FallenBehind` within 10 seconds") - .unwrap(); + let next_event = time::timeout(Duration::from_secs(10), await_event_task) + .await + .expect("Alice did not emit `StateEvent::FallenBehind` within 10 seconds") + .unwrap(); - match next_event { - StateEvent::InitialSync => {}, - _ => panic!(), - } - }); + match next_event { + StateEvent::InitialSync => {}, + _ => panic!(), + } } -#[test] -fn test_event_channel() { +#[tokio::test] +async fn test_event_channel() { let temp_dir = tempdir().unwrap(); - let mut runtime = Runtime::new().unwrap(); - let (node, consensus_manager) = - BaseNodeBuilder::new(Network::Weatherwax.into()).start(&mut runtime, temp_dir.path().to_str().unwrap()); + let (node, consensus_manager) = BaseNodeBuilder::new(Network::Weatherwax.into()) + .start(temp_dir.path().to_str().unwrap()) + .await; // let shutdown = Shutdown::new(); let db = create_test_blockchain_db(); let shutdown = Shutdown::new(); @@ -165,24 +161,21 @@ fn test_event_channel() { shutdown.to_signal(), ); - runtime.spawn(state_machine.run()); + task::spawn(state_machine.run()); let PeerChainMetadata { node_id, chain_metadata, } = random_peer_metadata(10, 5_000); - runtime - .block_on(mock.publish_chain_metadata(&node_id, &chain_metadata)) + mock.publish_chain_metadata(&node_id, &chain_metadata) + .await .expect("Could not publish metadata"); - thread::sleep(Duration::from_millis(50)); - runtime.block_on(async { - let event = state_change_event_subscriber.next().await; - assert_eq!(*event.unwrap().unwrap(), StateEvent::Initialized); - let event = state_change_event_subscriber.next().await; - let event = event.unwrap().unwrap(); - match event.as_ref() { - StateEvent::InitialSync => (), - _ => panic!("Unexpected state was found:{:?}", event), - } - }); + let event = state_change_event_subscriber.recv().await; + assert_eq!(*event.unwrap(), StateEvent::Initialized); + let event = state_change_event_subscriber.recv().await; + let event = event.unwrap(); + match event.as_ref() { + StateEvent::InitialSync => (), + _ => panic!("Unexpected state was found:{:?}", event), + } } diff --git a/base_layer/key_manager/Cargo.toml b/base_layer/key_manager/Cargo.toml index e57d0bf8b4..a40415ce17 100644 --- a/base_layer/key_manager/Cargo.toml +++ b/base_layer/key_manager/Cargo.toml @@ -15,7 +15,7 @@ sha2 = "0.9.5" serde = "1.0.89" serde_derive = "1.0.89" serde_json = "1.0.39" -thiserror = "1.0.20" +thiserror = "1.0.26" [features] avx2 = ["tari_crypto/avx2"] diff --git a/base_layer/mmr/Cargo.toml b/base_layer/mmr/Cargo.toml index 38b723a5f2..0d725874fb 100644 --- a/base_layer/mmr/Cargo.toml +++ b/base_layer/mmr/Cargo.toml @@ -14,7 +14,7 @@ benches = ["criterion"] [dependencies] tari_utilities = "^0.3" -thiserror = "1.0.20" +thiserror = "1.0.26" digest = "0.9.0" log = "0.4" serde = { version = "1.0.97", features = ["derive"] } diff --git a/base_layer/p2p/Cargo.toml b/base_layer/p2p/Cargo.toml index e9d3a0291c..0d7aae7390 100644 --- a/base_layer/p2p/Cargo.toml +++ b/base_layer/p2p/Cargo.toml @@ -10,37 +10,38 @@ license = "BSD-3-Clause" edition = "2018" [dependencies] -tari_comms = { version = "^0.9", path = "../../comms"} -tari_comms_dht = { version = "^0.9", path = "../../comms/dht"} -tari_common = { version= "^0.9", path = "../../common" } +tari_comms = { version = "^0.9", path = "../../comms" } +tari_comms_dht = { version = "^0.9", path = "../../comms/dht" } +tari_common = { version = "^0.9", path = "../../common" } tari_crypto = "0.11.1" -tari_service_framework = { version = "^0.9", path = "../service_framework"} -tari_shutdown = { version = "^0.9", path="../../infrastructure/shutdown" } -tari_storage = { version = "^0.9", path = "../../infrastructure/storage"} +tari_service_framework = { version = "^0.9", path = "../service_framework" } +tari_shutdown = { version = "^0.9", path = "../../infrastructure/shutdown" } +tari_storage = { version = "^0.9", path = "../../infrastructure/storage" } tari_utilities = "^0.3" anyhow = "1.0.32" bytes = "0.5" -chrono = {version = "0.4.6", features = ["serde"]} +chrono = { version = "0.4.6", features = ["serde"] } fs2 = "0.3.0" -futures = {version = "^0.3.1"} +futures = { version = "^0.3.1" } lmdb-zero = "0.4.4" log = "0.4.6" -pgp = {version = "0.7.1", optional = true} -prost = "0.6.1" +pgp = { version = "0.7.1", optional = true } +prost = "=0.8.0" rand = "0.8" -reqwest = {version = "0.10", optional = true, default-features = false} +reqwest = { version = "0.10", optional = true, default-features = false } semver = "1.0.1" serde = "1.0.90" serde_derive = "1.0.90" -thiserror = "1.0.20" -tokio = {version="0.2.10", features=["blocking"]} +thiserror = "1.0.26" +tokio = { version = "1.10", features = ["macros"] } +tokio-stream = { version = "0.1.7", default-features = false, features = ["time"] } tower = "0.3.0-alpha.2" -tower-service = { version="0.3.0-alpha.2" } -trust-dns-client = {version="0.19.5", features=["dns-over-rustls"]} +tower-service = { version = "0.3.0-alpha.2" } +trust-dns-client = { version = "0.21.0-alpha.1", features = ["dns-over-rustls"] } [dev-dependencies] -tari_test_utils = { version = "^0.9", path="../../infrastructure/test_utils" } +tari_test_utils = { version = "^0.9", path = "../../infrastructure/test_utils" } clap = "2.33.0" env_logger = "0.6.2" @@ -48,7 +49,6 @@ futures-timer = "0.3.0" lazy_static = "1.3.0" stream-cancel = "0.4.4" tempfile = "3.1.0" -tokio-macros = "0.2.4" [dev-dependencies.log4rs] version = "^0.8" @@ -56,7 +56,7 @@ features = ["console_appender", "file_appender", "file", "yaml_format"] default-features = false [build-dependencies] -tari_common = { version = "^0.9", path="../../common", features = ["build"] } +tari_common = { version = "^0.9", path = "../../common", features = ["build"] } [features] test-mocks = [] diff --git a/base_layer/p2p/examples/gen_tor_identity.rs b/base_layer/p2p/examples/gen_tor_identity.rs index 52a2e3c785..c1a7693de3 100644 --- a/base_layer/p2p/examples/gen_tor_identity.rs +++ b/base_layer/p2p/examples/gen_tor_identity.rs @@ -39,7 +39,7 @@ fn to_abs_path(path: &str) -> String { } } -#[tokio_macros::main] +#[tokio::main] async fn main() { let matches = App::new("Tor identity file generator") .version("1.0") diff --git a/base_layer/p2p/src/auto_update/dns.rs b/base_layer/p2p/src/auto_update/dns.rs index 7042b19919..63fb4ed05e 100644 --- a/base_layer/p2p/src/auto_update/dns.rs +++ b/base_layer/p2p/src/auto_update/dns.rs @@ -32,7 +32,7 @@ use std::{ use tari_common::configuration::bootstrap::ApplicationType; use tari_utilities::hex::{from_hex, Hex}; -const LOG_TARGET: &str = "p2p::auto-update:dns"; +const LOG_TARGET: &str = "p2p::auto_update::dns"; pub struct DnsSoftwareUpdate { client: DnsClient, @@ -189,19 +189,26 @@ impl Display for UpdateSpec { #[cfg(test)] mod test { use super::*; + use crate::dns::mock; use trust_dns_client::{ - proto::rr::{rdata, RData, RecordType}, + op::Query, + proto::{ + rr::{rdata, Name, RData, RecordType}, + xfer::DnsResponse, + }, rr::Record, }; - fn create_txt_record(contents: Vec<&str>) -> Record { + fn create_txt_record(contents: Vec<&str>) -> DnsResponse { + let resp_query = Query::query(Name::from_str("test.local.").unwrap(), RecordType::A); let mut record = Record::new(); record .set_record_type(RecordType::TXT) .set_rdata(RData::TXT(rdata::TXT::new( contents.into_iter().map(ToString::to_string).collect(), ))); - record + + mock::message(resp_query, vec![record], vec![], vec![]).into() } mod update_spec { @@ -220,7 +227,6 @@ mod test { mod dns_software_update { use super::*; use crate::DEFAULT_DNS_NAME_SERVER; - use std::{collections::HashMap, iter::FromIterator}; impl AutoUpdateConfig { fn get_test_defaults() -> Self { @@ -238,15 +244,15 @@ mod test { } } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_ignores_non_conforming_txt_entries() { - let records = HashMap::from_iter([("test.local.", vec![ - create_txt_record(vec![":::"]), - create_txt_record(vec!["base-node:::"]), - create_txt_record(vec!["base-node::1.0:"]), - create_txt_record(vec!["base-node:android-armv7:0.1.0:abcdef"]), - create_txt_record(vec!["base-node:linux-x86_64:1.0.0:bada55"]), - ])]); + let records = vec![ + Ok(create_txt_record(vec![":::"])), + Ok(create_txt_record(vec!["base-node:::"])), + Ok(create_txt_record(vec!["base-node::1.0:"])), + Ok(create_txt_record(vec!["base-node:android-armv7:0.1.0:abcdef"])), + Ok(create_txt_record(vec!["base-node:linux-x86_64:1.0.0:bada55"])), + ]; let updater = DnsSoftwareUpdate { client: DnsClient::connect_mock(records).await.unwrap(), config: AutoUpdateConfig::get_test_defaults(), @@ -258,12 +264,12 @@ mod test { assert!(spec.is_none()); } - #[tokio_macros::test_basic] + #[tokio::test] async fn it_returns_best_update() { - let records = HashMap::from_iter([("test.local.", vec![ - create_txt_record(vec!["base-node:linux-x86_64:1.0.0:abcdef"]), - create_txt_record(vec!["base-node:linux-x86_64:1.0.1:abcdef01"]), - ])]); + let records = vec![ + Ok(create_txt_record(vec!["base-node:linux-x86_64:1.0.0:abcdef"])), + Ok(create_txt_record(vec!["base-node:linux-x86_64:1.0.1:abcdef01"])), + ]; let updater = DnsSoftwareUpdate { client: DnsClient::connect_mock(records).await.unwrap(), config: AutoUpdateConfig::get_test_defaults(), diff --git a/base_layer/p2p/src/auto_update/mod.rs b/base_layer/p2p/src/auto_update/mod.rs index 21de03b693..8ce17b5f2b 100644 --- a/base_layer/p2p/src/auto_update/mod.rs +++ b/base_layer/p2p/src/auto_update/mod.rs @@ -46,7 +46,7 @@ use std::{ use tari_common::configuration::bootstrap::ApplicationType; use tari_utilities::hex::Hex; -const LOG_TARGET: &str = "p2p::auto-update"; +const LOG_TARGET: &str = "p2p::auto_update"; #[derive(Debug, Clone)] pub struct AutoUpdateConfig { @@ -58,6 +58,12 @@ pub struct AutoUpdateConfig { pub hashes_sig_url: String, } +impl AutoUpdateConfig { + pub fn is_update_enabled(&self) -> bool { + !self.update_uris.is_empty() + } +} + pub async fn check_for_updates( app: ApplicationType, arch: &str, diff --git a/base_layer/p2p/src/auto_update/service.rs b/base_layer/p2p/src/auto_update/service.rs index a786d84ad3..a235ec6fe0 100644 --- a/base_layer/p2p/src/auto_update/service.rs +++ b/base_layer/p2p/src/auto_update/service.rs @@ -24,19 +24,19 @@ use crate::{ auto_update, auto_update::{AutoUpdateConfig, SoftwareUpdate, Version}, }; -use futures::{ - channel::{mpsc, oneshot}, - future::Either, - stream, - SinkExt, - StreamExt, -}; +use futures::{future::Either, stream, StreamExt}; +use log::*; use std::{env::consts, time::Duration}; use tari_common::configuration::bootstrap::ApplicationType; use tari_service_framework::{async_trait, ServiceInitializationError, ServiceInitializer, ServiceInitializerContext}; -use tokio::{sync::watch, time}; +use tokio::{ + sync::{mpsc, oneshot, watch}, + time, + time::MissedTickBehavior, +}; +use tokio_stream::wrappers; -const LOG_TARGET: &str = "app:auto-update"; +const LOG_TARGET: &str = "p2p::auto_update"; /// A watch notifier that contains the latest software update, if any pub type SoftwareUpdateNotifier = watch::Receiver>; @@ -94,20 +94,25 @@ impl SoftwareUpdaterService { new_update_notification: watch::Receiver>, ) { let mut interval_or_never = match self.check_interval { - Some(interval) => Either::Left(time::interval(interval)).fuse(), - None => Either::Right(stream::empty()).fuse(), + Some(interval) => { + let mut interval = time::interval(interval); + interval.set_missed_tick_behavior(MissedTickBehavior::Skip); + Either::Left(wrappers::IntervalStream::new(interval)) + }, + None => Either::Right(stream::empty()), }; loop { let last_version = new_update_notification.borrow().clone(); - let maybe_update = futures::select! { - reply = request_rx.select_next_some() => { + let maybe_update = tokio::select! { + Some(reply) = request_rx.recv() => { let maybe_update = self.check_for_updates().await; let _ = reply.send(maybe_update.clone()); maybe_update }, - _ = interval_or_never.next() => { + + Some(_) = interval_or_never.next() => { // Periodically, check for updates if configured to do so. // If an update is found the new update notifier will be triggered and any listeners notified self.check_for_updates().await @@ -121,7 +126,7 @@ impl SoftwareUpdaterService { .map(|up| up.version() < update.version()) .unwrap_or(true) { - let _ = notifier.broadcast(Some(update.clone())); + let _ = notifier.send(Some(update.clone())); } } } @@ -133,6 +138,13 @@ impl SoftwareUpdaterService { "Checking for updates ({})...", self.config.update_uris.join(", ") ); + if !self.config.is_update_enabled() { + warn!( + target: LOG_TARGET, + "Check for updates has been called but auto update has been disabled in the config" + ); + return None; + } let arch = format!("{}-{}", consts::OS, consts::ARCH); diff --git a/base_layer/p2p/src/comms_connector/inbound_connector.rs b/base_layer/p2p/src/comms_connector/inbound_connector.rs index ed16cd578d..6feed82be7 100644 --- a/base_layer/p2p/src/comms_connector/inbound_connector.rs +++ b/base_layer/p2p/src/comms_connector/inbound_connector.rs @@ -22,45 +22,42 @@ use super::peer_message::PeerMessage; use anyhow::anyhow; -use futures::{task::Context, Future, Sink, SinkExt}; +use futures::{task::Context, Future}; use log::*; use std::{pin::Pin, sync::Arc, task::Poll}; use tari_comms::pipeline::PipelineError; use tari_comms_dht::{domain_message::MessageHeader, inbound::DecryptedDhtMessage}; +use tokio::sync::mpsc; use tower::Service; const LOG_TARGET: &str = "comms::middleware::inbound_connector"; /// This service receives DecryptedDhtMessage, deserializes the MessageHeader and /// sends a `PeerMessage` on the given sink. #[derive(Clone)] -pub struct InboundDomainConnector { - sink: TSink, +pub struct InboundDomainConnector { + sink: mpsc::Sender>, } -impl InboundDomainConnector { - pub fn new(sink: TSink) -> Self { +impl InboundDomainConnector { + pub fn new(sink: mpsc::Sender>) -> Self { Self { sink } } } -impl Service for InboundDomainConnector -where - TSink: Sink> + Unpin + Clone + 'static, - TSink::Error: std::error::Error + Send + Sync + 'static, -{ +impl Service for InboundDomainConnector { type Error = PipelineError; - type Future = Pin>>>; + type Future = Pin> + Send>>; type Response = (); - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.sink).poll_ready(cx).map_err(Into::into) + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, msg: DecryptedDhtMessage) -> Self::Future { - let mut sink = self.sink.clone(); + let sink = self.sink.clone(); let future = async move { let peer_message = Self::construct_peer_message(msg)?; - // If this fails there is something wrong with the sink and the pubsub middleware should not + // If this fails the channel has closed and the pubsub middleware should not // continue sink.send(Arc::new(peer_message)).await?; @@ -70,7 +67,7 @@ where } } -impl InboundDomainConnector { +impl InboundDomainConnector { fn construct_peer_message(mut inbound_message: DecryptedDhtMessage) -> Result { let envelope_body = inbound_message .success_mut() @@ -107,41 +104,17 @@ impl InboundDomainConnector { } } -impl Sink for InboundDomainConnector -where - TSink: Sink> + Unpin, - TSink::Error: Into + Send + Sync + 'static, -{ - type Error = PipelineError; - - fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.sink).poll_ready(cx).map_err(Into::into) - } - - fn start_send(mut self: Pin<&mut Self>, item: DecryptedDhtMessage) -> Result<(), Self::Error> { - let item = Self::construct_peer_message(item)?; - Pin::new(&mut self.sink).start_send(Arc::new(item)).map_err(Into::into) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.sink).poll_flush(cx).map_err(Into::into) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.sink).poll_close(cx).map_err(Into::into) - } -} - #[cfg(test)] mod test { use super::*; use crate::test_utils::{make_dht_inbound_message, make_node_identity}; - use futures::{channel::mpsc, executor::block_on, StreamExt}; + use futures::executor::block_on; use tari_comms::{message::MessageExt, wrap_in_envelope_body}; use tari_comms_dht::domain_message::MessageHeader; + use tokio::sync::mpsc; use tower::ServiceExt; - #[tokio_macros::test_basic] + #[tokio::test] async fn handle_message() { let (tx, mut rx) = mpsc::channel(1); let header = MessageHeader::new(123); @@ -151,12 +124,12 @@ mod test { let decrypted = DecryptedDhtMessage::succeeded(msg, None, inbound_message); InboundDomainConnector::new(tx).oneshot(decrypted).await.unwrap(); - let peer_message = block_on(rx.next()).unwrap(); + let peer_message = block_on(rx.recv()).unwrap(); assert_eq!(peer_message.message_header.message_type, 123); assert_eq!(peer_message.decode_message::().unwrap(), "my message"); } - #[tokio_macros::test_basic] + #[tokio::test] async fn send_on_sink() { let (tx, mut rx) = mpsc::channel(1); let header = MessageHeader::new(123); @@ -165,14 +138,14 @@ mod test { let inbound_message = make_dht_inbound_message(&make_node_identity(), msg.to_encoded_bytes()); let decrypted = DecryptedDhtMessage::succeeded(msg, None, inbound_message); - InboundDomainConnector::new(tx).send(decrypted).await.unwrap(); + InboundDomainConnector::new(tx).call(decrypted).await.unwrap(); - let peer_message = block_on(rx.next()).unwrap(); + let peer_message = block_on(rx.recv()).unwrap(); assert_eq!(peer_message.message_header.message_type, 123); assert_eq!(peer_message.decode_message::().unwrap(), "my message"); } - #[tokio_macros::test_basic] + #[tokio::test] async fn handle_message_fail_deserialize() { let (tx, mut rx) = mpsc::channel(1); let header = b"dodgy header".to_vec(); @@ -182,10 +155,11 @@ mod test { let decrypted = DecryptedDhtMessage::succeeded(msg, None, inbound_message); InboundDomainConnector::new(tx).oneshot(decrypted).await.unwrap_err(); - assert!(rx.try_next().unwrap().is_none()); + rx.close(); + assert!(rx.recv().await.is_none()); } - #[tokio_macros::test_basic] + #[tokio::test] async fn handle_message_fail_send() { // Drop the receiver of the channel, this is the only reason this middleware should return an error // from it's call function diff --git a/base_layer/p2p/src/comms_connector/pubsub.rs b/base_layer/p2p/src/comms_connector/pubsub.rs index ae01e8ced8..198eee63f2 100644 --- a/base_layer/p2p/src/comms_connector/pubsub.rs +++ b/base_layer/p2p/src/comms_connector/pubsub.rs @@ -22,11 +22,15 @@ use super::peer_message::PeerMessage; use crate::{comms_connector::InboundDomainConnector, tari_message::TariMessageType}; -use futures::{channel::mpsc, future, stream::Fuse, Stream, StreamExt}; +use futures::{future, Stream, StreamExt}; use log::*; use std::{cmp, fmt::Debug, sync::Arc, time::Duration}; use tari_comms::rate_limit::RateLimit; -use tokio::{runtime::Handle, sync::broadcast}; +use tokio::{ + sync::{broadcast, mpsc}, + task, +}; +use tokio_stream::wrappers; const LOG_TARGET: &str = "comms::middleware::pubsub"; @@ -35,16 +39,11 @@ const RATE_LIMIT_MIN_CAPACITY: usize = 5; const RATE_LIMIT_RESTOCK_INTERVAL: Duration = Duration::from_millis(1000); /// Alias for a pubsub-type domain connector -pub type PubsubDomainConnector = InboundDomainConnector>>; +pub type PubsubDomainConnector = InboundDomainConnector; pub type SubscriptionFactory = TopicSubscriptionFactory>; /// Connects `InboundDomainConnector` to a `tari_pubsub::TopicPublisher` through a buffered broadcast channel -pub fn pubsub_connector( - // TODO: Remove this arg in favor of task::spawn - executor: Handle, - buf_size: usize, - rate_limit: usize, -) -> (PubsubDomainConnector, SubscriptionFactory) { +pub fn pubsub_connector(buf_size: usize, rate_limit: usize) -> (PubsubDomainConnector, SubscriptionFactory) { let (publisher, subscription_factory) = pubsub_channel(buf_size); let (sender, receiver) = mpsc::channel(buf_size); trace!( @@ -55,8 +54,8 @@ pub fn pubsub_connector( ); // Spawn a task which forwards messages from the pubsub service to the TopicPublisher - executor.spawn(async move { - let forwarder = receiver + task::spawn(async move { + wrappers::ReceiverStream::new(receiver) // Rate limit the receiver; the sender will adhere to the limit .rate_limit(cmp::max(rate_limit, RATE_LIMIT_MIN_CAPACITY), RATE_LIMIT_RESTOCK_INTERVAL) // Map DomainMessage into a TopicPayload @@ -89,8 +88,7 @@ pub fn pubsub_connector( ); } future::ready(()) - }); - forwarder.await; + }).await; }); (InboundDomainConnector::new(sender), subscription_factory) } @@ -98,8 +96,8 @@ pub fn pubsub_connector( /// Create a topic-based pub-sub channel fn pubsub_channel(size: usize) -> (TopicPublisher, TopicSubscriptionFactory) where - T: Clone + Debug + Send + Eq, - M: Send + Clone, + T: Clone + Debug + Send + Eq + 'static, + M: Send + Clone + 'static, { let (publisher, _) = broadcast::channel(size); (publisher.clone(), TopicSubscriptionFactory::new(publisher)) @@ -138,8 +136,8 @@ pub struct TopicSubscriptionFactory { impl TopicSubscriptionFactory where - T: Clone + Eq + Debug + Send, - M: Clone + Send, + T: Clone + Eq + Debug + Send + 'static, + M: Clone + Send + 'static, { pub fn new(sender: broadcast::Sender>) -> Self { TopicSubscriptionFactory { sender } @@ -148,38 +146,22 @@ where /// Create a subscription stream to a particular topic. The provided label is used to identify which consumer is /// lagging. pub fn get_subscription(&self, topic: T, label: &'static str) -> impl Stream { - self.sender - .subscribe() - .filter_map({ - let topic = topic.clone(); - move |result| { - let opt = match result { - Ok(payload) => Some(payload), - Err(broadcast::RecvError::Closed) => None, - Err(broadcast::RecvError::Lagged(n)) => { - warn!( - target: LOG_TARGET, - "Subscription '{}' for topic '{:?}' lagged. {} message(s) dropped.", label, topic, n - ); - None - }, - }; - future::ready(opt) - } - }) - .filter_map(move |item| { - let opt = if item.topic() == &topic { - Some(item.message) - } else { - None + wrappers::BroadcastStream::new(self.sender.subscribe()).filter_map({ + move |result| { + let opt = match result { + Ok(payload) if *payload.topic() == topic => Some(payload.message), + Ok(_) => None, + Err(wrappers::errors::BroadcastStreamRecvError::Lagged(n)) => { + warn!( + target: LOG_TARGET, + "Subscription '{}' for topic '{:?}' lagged. {} message(s) dropped.", label, topic, n + ); + None + }, }; future::ready(opt) - }) - } - - /// Convenience function that returns a fused (`stream::Fuse`) version of the subscription stream. - pub fn get_subscription_fused(&self, topic: T, label: &'static str) -> Fuse> { - self.get_subscription(topic, label).fuse() + } + }) } } @@ -190,7 +172,7 @@ mod test { use std::time::Duration; use tari_test_utils::collect_stream; - #[tokio_macros::test_basic] + #[tokio::test] async fn topic_pub_sub() { let (publisher, subscriber_factory) = pubsub_channel(10); diff --git a/base_layer/p2p/src/dns/client.rs b/base_layer/p2p/src/dns/client.rs index 85e03f71e4..78093186ef 100644 --- a/base_layer/p2p/src/dns/client.rs +++ b/base_layer/p2p/src/dns/client.rs @@ -20,35 +20,28 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#[cfg(test)] +use crate::dns::mock::{DefaultOnSend, MockClientHandle}; + use super::DnsClientError; -use futures::future; +use futures::{future, FutureExt}; use std::{net::SocketAddr, sync::Arc}; use tari_shutdown::Shutdown; use tokio::{net::UdpSocket, task}; use trust_dns_client::{ - client::{AsyncClient, AsyncDnssecClient}, - op::{DnsResponse, Query}, - proto::{ - rr::dnssec::TrustAnchor, - udp::{UdpClientStream, UdpResponse}, - xfer::DnsRequestOptions, - DnsHandle, - }, + client::{AsyncClient, AsyncDnssecClient, ClientHandle}, + op::Query, + proto::{error::ProtoError, rr::dnssec::TrustAnchor, udp::UdpClientStream, xfer::DnsResponse, DnsHandle}, rr::{DNSClass, IntoName, RecordType}, serialize::binary::BinEncoder, }; -#[cfg(test)] -use std::collections::HashMap; -#[cfg(test)] -use trust_dns_client::{proto::xfer::DnsMultiplexerSerialResponse, rr::Record}; - #[derive(Clone)] pub enum DnsClient { - Secure(Client>), - Normal(Client>), + Secure(Client), + Normal(Client), #[cfg(test)] - Mock(Client>), + Mock(Client>), } impl DnsClient { @@ -63,18 +56,18 @@ impl DnsClient { } #[cfg(test)] - pub async fn connect_mock(records: HashMap<&'static str, Vec>) -> Result { - let client = Client::connect_mock(records).await?; + pub async fn connect_mock(messages: Vec>) -> Result { + let client = Client::connect_mock(messages).await?; Ok(DnsClient::Mock(client)) } - pub async fn lookup(&mut self, query: Query, options: DnsRequestOptions) -> Result { + pub async fn lookup(&mut self, query: Query) -> Result { use DnsClient::*; match self { - Secure(ref mut client) => client.lookup(query, options).await, - Normal(ref mut client) => client.lookup(query, options).await, + Secure(ref mut client) => client.lookup(query).await, + Normal(ref mut client) => client.lookup(query).await, #[cfg(test)] - Mock(ref mut client) => client.lookup(query, options).await, + Mock(ref mut client) => client.lookup(query).await, } } @@ -85,11 +78,11 @@ impl DnsClient { .set_query_class(DNSClass::IN) .set_query_type(RecordType::TXT); - let response = self.lookup(query, Default::default()).await?; + let responses = self.lookup(query).await?; - let records = response - .messages() - .flat_map(|msg| msg.answers()) + let records = responses + .answers() + .iter() .map(|answer| { let data = answer.rdata(); let mut buf = Vec::new(); @@ -116,7 +109,7 @@ pub struct Client { shutdown: Arc, } -impl Client> { +impl Client { pub async fn connect_secure(name_server: SocketAddr, trust_anchor: TrustAnchor) -> Result { let shutdown = Shutdown::new(); let stream = UdpClientStream::::new(name_server); @@ -124,7 +117,7 @@ impl Client> { .trust_anchor(trust_anchor) .build() .await?; - task::spawn(future::select(shutdown.to_signal(), background)); + task::spawn(future::select(shutdown.to_signal(), background.fuse())); Ok(Self { inner: client, @@ -133,12 +126,12 @@ impl Client> { } } -impl Client> { +impl Client { pub async fn connect(name_server: SocketAddr) -> Result { let shutdown = Shutdown::new(); let stream = UdpClientStream::::new(name_server); let (client, background) = AsyncClient::connect(stream).await?; - task::spawn(future::select(shutdown.to_signal(), background)); + task::spawn(future::select(shutdown.to_signal(), background.fuse())); Ok(Self { inner: client, @@ -148,87 +141,31 @@ impl Client> { } impl Client -where C: DnsHandle +where C: DnsHandle { - pub async fn lookup(&mut self, query: Query, options: DnsRequestOptions) -> Result { - let resp = self.inner.lookup(query, options).await?; - Ok(resp) + pub async fn lookup(&mut self, query: Query) -> Result { + let client_resp = self + .inner + .query(query.name().clone(), query.query_class(), query.query_type()) + .await?; + Ok(client_resp) } } #[cfg(test)] mod mock { use super::*; - use futures::{channel::mpsc, future, Stream, StreamExt}; - use std::{ - fmt, - fmt::Display, - net::SocketAddr, - pin::Pin, - sync::Arc, - task::{Context, Poll}, - }; + use crate::dns::mock::{DefaultOnSend, MockClientHandle}; + use std::sync::Arc; use tari_shutdown::Shutdown; - use tokio::task; - use trust_dns_client::{ - client::AsyncClient, - op::Message, - proto::{ - error::ProtoError, - xfer::{DnsClientStream, DnsMultiplexerSerialResponse, SerialMessage}, - StreamHandle, - }, - rr::Record, - }; - - pub struct MockStream { - receiver: mpsc::UnboundedReceiver>, - answers: HashMap<&'static str, Vec>, - } - - impl DnsClientStream for MockStream { - fn name_server_addr(&self) -> SocketAddr { - ([0u8, 0, 0, 0], 53).into() - } - } - - impl Display for MockStream { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "MockStream") - } - } - - impl Stream for MockStream { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let req = match futures::ready!(self.receiver.poll_next_unpin(cx)) { - Some(r) => r, - None => return Poll::Ready(None), - }; - let req = Message::from_vec(&req).unwrap(); - let name = req.queries()[0].name().to_string(); - let mut msg = Message::new(); - let answers = self.answers.get(name.as_str()).into_iter().flatten().cloned(); - msg.set_id(req.id()).add_answers(answers); - Poll::Ready(Some(Ok(SerialMessage::new( - msg.to_vec().unwrap(), - self.name_server_addr(), - )))) - } - } - - impl Client> { - pub async fn connect_mock(answers: HashMap<&'static str, Vec>) -> Result { - let (tx, rx) = mpsc::unbounded(); - let stream = future::ready(Ok(MockStream { receiver: rx, answers })); - let (client, background) = AsyncClient::new(stream, Box::new(StreamHandle::new(tx)), None).await?; + use trust_dns_client::proto::error::ProtoError; - let shutdown = Shutdown::new(); - task::spawn(future::select(shutdown.to_signal(), background)); + impl Client> { + pub async fn connect_mock(messages: Vec>) -> Result { + let client = MockClientHandle::mock(messages); Ok(Self { inner: client, - shutdown: Arc::new(shutdown), + shutdown: Arc::new(Shutdown::new()), }) } } diff --git a/base_layer/p2p/src/dns/mock.rs b/base_layer/p2p/src/dns/mock.rs new file mode 100644 index 0000000000..9dec5155cb --- /dev/null +++ b/base_layer/p2p/src/dns/mock.rs @@ -0,0 +1,105 @@ +// Copyright 2021, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use futures::{future, stream, Future}; +use std::{error::Error, pin::Pin, sync::Arc}; +use trust_dns_client::{ + op::{Message, Query}, + proto::{ + error::ProtoError, + xfer::{DnsHandle, DnsRequest, DnsResponse}, + }, + rr::Record, +}; + +#[derive(Clone)] +pub struct MockClientHandle { + messages: Arc>>, + on_send: O, +} + +impl MockClientHandle { + /// constructs a new MockClient which returns each Message one after the other + pub fn mock(messages: Vec>) -> Self { + println!("MockClientHandle::mock message count: {}", messages.len()); + + MockClientHandle { + messages: Arc::new(messages), + on_send: DefaultOnSend, + } + } +} + +impl DnsHandle for MockClientHandle +where E: From + Error + Clone + Send + Sync + Unpin + 'static +{ + type Error = E; + type Response = stream::Once>>; + + fn send>(&mut self, _: R) -> Self::Response { + let responses = (*self.messages) + .clone() + .into_iter() + .fold(Result::<_, E>::Ok(Message::new()), |msg, resp| { + msg.and_then(|mut msg| { + resp.map(move |resp| { + msg.add_answers(resp.answers().iter().cloned()); + msg + }) + }) + }) + .map(DnsResponse::from); + + // let stream = stream::unfold(messages, |mut msgs| async move { + // let msg = msgs.pop()?; + // Some((msg, msgs)) + // }); + + stream::once(future::ready(responses)) + } +} + +pub fn message(query: Query, answers: Vec, name_servers: Vec, additionals: Vec) -> Message { + let mut message = Message::new(); + message.add_query(query); + message.insert_answers(answers); + message.insert_name_servers(name_servers); + message.insert_additionals(additionals); + message +} + +pub trait OnSend: Clone + Send + Sync + 'static { + fn on_send( + &mut self, + response: Result, + ) -> Pin> + Send>> + where + E: From + Send + 'static, + { + Box::pin(future::ready(response)) + } +} + +#[derive(Clone)] +pub struct DefaultOnSend; + +impl OnSend for DefaultOnSend {} diff --git a/base_layer/p2p/src/dns/mod.rs b/base_layer/p2p/src/dns/mod.rs index 197b788236..8c49de1993 100644 --- a/base_layer/p2p/src/dns/mod.rs +++ b/base_layer/p2p/src/dns/mod.rs @@ -4,6 +4,9 @@ pub use client::DnsClient; mod error; pub use error::DnsClientError; +#[cfg(test)] +pub(crate) mod mock; + use trust_dns_client::proto::rr::dnssec::{public_key::Rsa, TrustAnchor}; #[inline] diff --git a/base_layer/p2p/src/initialization.rs b/base_layer/p2p/src/initialization.rs index 9264a5033c..0915479612 100644 --- a/base_layer/p2p/src/initialization.rs +++ b/base_layer/p2p/src/initialization.rs @@ -19,20 +19,20 @@ // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#![allow(dead_code)] use crate::{ - comms_connector::{InboundDomainConnector, PeerMessage, PubsubDomainConnector}, + comms_connector::{InboundDomainConnector, PubsubDomainConnector}, peer_seeds::{DnsSeedResolver, SeedPeer}, transport::{TorConfig, TransportType}, MAJOR_NETWORK_VERSION, MINOR_NETWORK_VERSION, }; use fs2::FileExt; -use futures::{channel::mpsc, future, Sink}; +use futures::future; use log::*; use rand::{distributions::Alphanumeric, thread_rng, Rng}; use std::{ - error::Error, fs::File, iter, net::SocketAddr, @@ -47,7 +47,6 @@ use tari_comms::{ multiaddr::Multiaddr, peer_manager::{NodeIdentity, Peer, PeerFeatures, PeerManagerError}, pipeline, - pipeline::SinkService, protocol::{ messaging::{MessagingEventSender, MessagingProtocolExtension}, rpc::RpcServer, @@ -71,7 +70,7 @@ use tari_storage::{ LMDBWrapper, }; use thiserror::Error; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc}; use tower::ServiceBuilder; const LOG_TARGET: &str = "p2p::initialization"; @@ -158,18 +157,14 @@ pub struct CommsConfig { } /// Initialize Tari Comms configured for tests -pub async fn initialize_local_test_comms( +pub async fn initialize_local_test_comms( node_identity: Arc, - connector: InboundDomainConnector, + connector: InboundDomainConnector, data_path: &str, discovery_request_timeout: Duration, seed_peers: Vec, shutdown_signal: ShutdownSignal, -) -> Result<(CommsNode, Dht, MessagingEventSender), CommsInitializationError> -where - TSink: Sink> + Unpin + Clone + Send + Sync + 'static, - TSink::Error: Error + Send + Sync, -{ +) -> Result<(CommsNode, Dht, MessagingEventSender), CommsInitializationError> { let peer_database_name = { let mut rng = thread_rng(); iter::repeat(()) @@ -230,7 +225,7 @@ where .with_inbound_pipeline( ServiceBuilder::new() .layer(dht.inbound_middleware_layer()) - .service(SinkService::new(connector)), + .service(connector), ) .build(); @@ -319,15 +314,11 @@ async fn initialize_hidden_service( builder.build().await } -async fn configure_comms_and_dht( +async fn configure_comms_and_dht( builder: CommsBuilder, config: &CommsConfig, - connector: InboundDomainConnector, -) -> Result<(UnspawnedCommsNode, Dht), CommsInitializationError> -where - TSink: Sink> + Unpin + Clone + Send + Sync + 'static, - TSink::Error: Error + Send + Sync, -{ + connector: InboundDomainConnector, +) -> Result<(UnspawnedCommsNode, Dht), CommsInitializationError> { let file_lock = acquire_exclusive_file_lock(&config.datastore_path)?; let datastore = LMDBBuilder::new() @@ -391,7 +382,7 @@ where .with_inbound_pipeline( ServiceBuilder::new() .layer(dht.inbound_middleware_layer()) - .service(SinkService::new(connector)), + .service(connector), ) .build(); diff --git a/base_layer/p2p/src/lib.rs b/base_layer/p2p/src/lib.rs index c21dace083..93cb45d1db 100644 --- a/base_layer/p2p/src/lib.rs +++ b/base_layer/p2p/src/lib.rs @@ -20,8 +20,6 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -// Needed to make futures::select! work -#![recursion_limit = "256"] #![cfg_attr(not(debug_assertions), deny(unused_variables))] #![cfg_attr(not(debug_assertions), deny(unused_imports))] #![cfg_attr(not(debug_assertions), deny(dead_code))] diff --git a/base_layer/p2p/src/peer_seeds.rs b/base_layer/p2p/src/peer_seeds.rs index 24ce3c4dda..4698770fa1 100644 --- a/base_layer/p2p/src/peer_seeds.rs +++ b/base_layer/p2p/src/peer_seeds.rs @@ -120,6 +120,8 @@ mod test { use super::*; use tari_utilities::hex::Hex; + const TEST_NAME: &str = "test.local."; + mod peer_seed { use super::*; @@ -182,76 +184,88 @@ mod test { mod peer_seed_resolver { use super::*; - use std::{collections::HashMap, iter::FromIterator}; - use trust_dns_client::rr::{rdata, RData, Record, RecordType}; + use crate::dns::mock; + use trust_dns_client::{ + proto::{ + op::Query, + rr::{DNSClass, Name}, + xfer::DnsResponse, + }, + rr::{rdata, RData, Record, RecordType}, + }; #[ignore = "This test requires network IO and is mostly useful during development"] - #[tokio_macros::test] - async fn it_returns_an_empty_vec_if_all_seeds_are_invalid() { + #[tokio::test] + async fn it_returns_seeds_from_real_address() { let mut resolver = DnsSeedResolver { client: DnsClient::connect("1.1.1.1:53".parse().unwrap()).await.unwrap(), }; - let seeds = resolver.resolve("tari.com").await.unwrap(); - assert!(seeds.is_empty()); + let seeds = resolver.resolve("seeds.weatherwax.tari.com").await.unwrap(); + assert!(!seeds.is_empty()); } - fn create_txt_record(contents: Vec<&str>) -> Record { + fn create_txt_record(contents: Vec<&str>) -> DnsResponse { + let mut resp_query = Query::query(Name::from_str(TEST_NAME).unwrap(), RecordType::TXT); + resp_query.set_query_class(DNSClass::IN); let mut record = Record::new(); record .set_record_type(RecordType::TXT) .set_rdata(RData::TXT(rdata::TXT::new( contents.into_iter().map(ToString::to_string).collect(), ))); - record + + mock::message(resp_query, vec![record], vec![], vec![]).into() } - #[tokio_macros::test] + #[tokio::test] async fn it_returns_peer_seeds() { - let records = HashMap::from_iter([("test.local.", vec![ + let records = vec![ // Multiple addresses(works) - create_txt_record(vec![ - "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a::/ip4/127.0.0.1/tcp/8000::/\ + Ok(create_txt_record(vec![ + "fab24c542183073996ddf3a6c73ff8b8562fed351d252ec5cb8f269d1ad92f0c::/ip4/127.0.0.1/tcp/8000::/\ onion3/bsmuof2cn4y2ysz253gzsvg3s72fcgh4f3qcm3hdlxdtcwe6al2dicyd:1234", - ]), + ])), // Misc - create_txt_record(vec!["v=spf1 include:_spf.spf.com ~all"]), + Ok(create_txt_record(vec!["v=spf1 include:_spf.spf.com ~all"])), // Single address (works) - create_txt_record(vec![ + Ok(create_txt_record(vec![ "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a::/ip4/127.0.0.1/tcp/8000", - ]), + ])), // Single address trailing delim - create_txt_record(vec![ + Ok(create_txt_record(vec![ "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a::/ip4/127.0.0.1/tcp/8000::", - ]), + ])), // Invalid public key - create_txt_record(vec![ + Ok(create_txt_record(vec![ "07e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a::/ip4/127.0.0.1/tcp/8000", - ]), + ])), // No Address with delim - create_txt_record(vec![ + Ok(create_txt_record(vec![ "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a::", - ]), + ])), // No Address no delim - create_txt_record(vec!["06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a"]), + Ok(create_txt_record(vec![ + "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a", + ])), // Invalid address - create_txt_record(vec![ + Ok(create_txt_record(vec![ "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a::/onion3/invalid:1234", - ]), - ])]); + ])), + ]; let mut resolver = DnsSeedResolver { client: DnsClient::connect_mock(records).await.unwrap(), }; - let seeds = resolver.resolve("test.local.").await.unwrap(); + let seeds = resolver.resolve(TEST_NAME).await.unwrap(); assert_eq!(seeds.len(), 2); assert_eq!( seeds[0].public_key.to_hex(), - "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a" + "fab24c542183073996ddf3a6c73ff8b8562fed351d252ec5cb8f269d1ad92f0c" ); + assert_eq!(seeds[0].addresses.len(), 2); assert_eq!( seeds[1].public_key.to_hex(), "06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a" ); - assert_eq!(seeds[0].addresses.len(), 2); assert_eq!(seeds[1].addresses.len(), 1); } } diff --git a/base_layer/p2p/src/services/liveness/mock.rs b/base_layer/p2p/src/services/liveness/mock.rs index fa5f2f72c1..470cca84c2 100644 --- a/base_layer/p2p/src/services/liveness/mock.rs +++ b/base_layer/p2p/src/services/liveness/mock.rs @@ -36,9 +36,9 @@ use std::sync::{ RwLock, }; -use tari_crypto::tari_utilities::acquire_write_lock; +use tari_crypto::tari_utilities::{acquire_read_lock, acquire_write_lock}; use tari_service_framework::{reply_channel, reply_channel::RequestContext}; -use tokio::sync::{broadcast, broadcast::SendError}; +use tokio::sync::{broadcast, broadcast::error::SendError}; const LOG_TARGET: &str = "p2p::liveness_mock"; @@ -69,7 +69,8 @@ impl LivenessMockState { } pub async fn publish_event(&self, event: LivenessEvent) -> Result<(), SendError>> { - acquire_write_lock!(self.event_publisher).send(Arc::new(event))?; + let lock = acquire_read_lock!(self.event_publisher); + lock.send(Arc::new(event))?; Ok(()) } diff --git a/base_layer/p2p/src/services/liveness/service.rs b/base_layer/p2p/src/services/liveness/service.rs index 65e85e0ea8..0f35122ea9 100644 --- a/base_layer/p2p/src/services/liveness/service.rs +++ b/base_layer/p2p/src/services/liveness/service.rs @@ -49,6 +49,7 @@ use tari_comms_dht::{ use tari_service_framework::reply_channel::RequestContext; use tari_shutdown::ShutdownSignal; use tokio::time; +use tokio_stream::wrappers; /// Service responsible for testing Liveness of Peers. pub struct LivenessService { @@ -59,7 +60,7 @@ pub struct LivenessService { connectivity: ConnectivityRequester, outbound_messaging: OutboundMessageRequester, event_publisher: LivenessEventSender, - shutdown_signal: Option, + shutdown_signal: ShutdownSignal, } impl LivenessService @@ -85,7 +86,7 @@ where connectivity, outbound_messaging, event_publisher, - shutdown_signal: Some(shutdown_signal), + shutdown_signal, config, } } @@ -100,39 +101,36 @@ where pin_mut!(request_stream); let mut ping_tick = match self.config.auto_ping_interval { - Some(interval) => Either::Left(time::interval_at((Instant::now() + interval).into(), interval)), + Some(interval) => Either::Left(wrappers::IntervalStream::new(time::interval_at( + (Instant::now() + interval).into(), + interval, + ))), None => Either::Right(futures::stream::iter(iter::empty())), - } - .fuse(); - - let mut shutdown_signal = self - .shutdown_signal - .take() - .expect("Liveness service initialized without shutdown signal"); + }; loop { - futures::select! { + tokio::select! { // Requests from the handle - request_context = request_stream.select_next_some() => { + Some(request_context) = request_stream.next() => { let (request, reply_tx) = request_context.split(); let _ = reply_tx.send(self.handle_request(request).await); }, // Tick events - _ = ping_tick.select_next_some() => { + Some(_) = ping_tick.next() => { if let Err(err) = self.start_ping_round().await { warn!(target: LOG_TARGET, "Error when pinging peers: {}", err); } }, // Incoming messages from the Comms layer - msg = ping_stream.select_next_some() => { + Some(msg) = ping_stream.next() => { if let Err(err) = self.handle_incoming_message(msg).await { warn!(target: LOG_TARGET, "Failed to handle incoming PingPong message: {}", err); } }, - _ = shutdown_signal => { + _ = self.shutdown_signal.wait() => { info!(target: LOG_TARGET, "Liveness service shutting down because the shutdown signal was received"); break; } @@ -143,11 +141,13 @@ where async fn handle_incoming_message(&mut self, msg: DomainMessage) -> Result<(), LivenessError> { let DomainMessage::<_> { source_peer, + dht_header, inner: ping_pong_msg, .. } = msg; let node_id = source_peer.node_id; let public_key = source_peer.public_key; + let message_tag = dht_header.message_tag; match ping_pong_msg.kind().ok_or(LivenessError::InvalidPingPongType)? { PingPong::Ping => { @@ -157,9 +157,10 @@ where debug!( target: LOG_TARGET, - "Received ping from peer '{}' with useragent '{}'", + "Received ping from peer '{}' with useragent '{}' (Trace: {})", node_id.short_str(), source_peer.user_agent, + message_tag, ); let ping_event = PingPongEvent::new(node_id, None, ping_pong_msg.metadata.into()); @@ -169,9 +170,10 @@ where if !self.state.is_inflight(ping_pong_msg.nonce) { debug!( target: LOG_TARGET, - "Received Pong that was not requested from '{}' with useragent {}. Ignoring it.", + "Received Pong that was not requested from '{}' with useragent {}. Ignoring it. (Trace: {})", node_id.short_str(), source_peer.user_agent, + message_tag, ); return Ok(()); } @@ -179,10 +181,11 @@ where let maybe_latency = self.state.record_pong(ping_pong_msg.nonce); debug!( target: LOG_TARGET, - "Received pong from peer '{}' with useragent '{}'. {}", + "Received pong from peer '{}' with useragent '{}'. {} (Trace: {})", node_id.short_str(), source_peer.user_agent, maybe_latency.map(|ms| format!("Latency: {}ms", ms)).unwrap_or_default(), + message_tag, ); let pong_event = PingPongEvent::new(node_id, maybe_latency, ping_pong_msg.metadata.into()); @@ -306,10 +309,7 @@ mod test { proto::liveness::MetadataKey, services::liveness::{handle::LivenessHandle, state::Metadata}, }; - use futures::{ - channel::{mpsc, oneshot}, - stream, - }; + use futures::stream; use rand::rngs::OsRng; use std::time::Duration; use tari_comms::{ @@ -325,9 +325,12 @@ mod test { use tari_crypto::keys::PublicKey; use tari_service_framework::reply_channel; use tari_shutdown::Shutdown; - use tokio::{sync::broadcast, task}; + use tokio::{ + sync::{broadcast, mpsc, oneshot}, + task, + }; - #[tokio_macros::test_basic] + #[tokio::test] async fn get_ping_pong_count() { let mut state = LivenessState::new(); state.inc_pings_received(); @@ -369,7 +372,7 @@ mod test { assert_eq!(res, 2); } - #[tokio_macros::test] + #[tokio::test] async fn send_ping() { let (connectivity, mock) = create_connectivity_mock(); mock.spawn(); @@ -401,8 +404,9 @@ mod test { let node_id = NodeId::from_key(&pk); // Receive outbound request task::spawn(async move { - match outbound_rx.select_next_some().await { - DhtOutboundRequest::SendMessage(_, _, reply_tx) => { + #[allow(clippy::single_match)] + match outbound_rx.recv().await { + Some(DhtOutboundRequest::SendMessage(_, _, reply_tx)) => { let (_, rx) = oneshot::channel(); reply_tx .send(SendMessageResponse::Queued( @@ -410,6 +414,7 @@ mod test { )) .unwrap(); }, + None => {}, } }); @@ -445,7 +450,7 @@ mod test { } } - #[tokio_macros::test] + #[tokio::test] async fn handle_message_ping() { let state = LivenessState::new(); @@ -478,10 +483,10 @@ mod test { task::spawn(service.run()); // Test oms got request to send message - unwrap_oms_send_msg!(outbound_rx.select_next_some().await); + unwrap_oms_send_msg!(outbound_rx.recv().await.unwrap()); } - #[tokio_macros::test_basic] + #[tokio::test] async fn handle_message_pong() { let mut state = LivenessState::new(); @@ -516,9 +521,9 @@ mod test { task::spawn(service.run()); // Listen for the pong event - let subscriber = publisher.subscribe(); + let mut subscriber = publisher.subscribe(); - let event = time::timeout(Duration::from_secs(10), subscriber.fuse().select_next_some()) + let event = time::timeout(Duration::from_secs(10), subscriber.recv()) .await .unwrap() .unwrap(); @@ -530,12 +535,12 @@ mod test { _ => panic!("Unexpected event"), } - shutdown.trigger().unwrap(); + shutdown.trigger(); // No further events (malicious_msg was ignored) - let mut subscriber = publisher.subscribe().fuse(); + let mut subscriber = publisher.subscribe(); drop(publisher); - let msg = subscriber.next().await; - assert!(msg.is_none()); + let msg = subscriber.recv().await; + assert!(msg.is_err()); } } diff --git a/base_layer/p2p/tests/services/liveness.rs b/base_layer/p2p/tests/services/liveness.rs index ab9a66cf59..2505c15543 100644 --- a/base_layer/p2p/tests/services/liveness.rs +++ b/base_layer/p2p/tests/services/liveness.rs @@ -35,16 +35,15 @@ use tari_p2p::{ }; use tari_service_framework::{RegisterHandle, StackBuilder}; use tari_shutdown::Shutdown; -use tari_test_utils::collect_stream; +use tari_test_utils::collect_try_recv; use tempfile::tempdir; -use tokio::runtime; pub async fn setup_liveness_service( node_identity: Arc, peers: Vec>, data_path: &str, ) -> (LivenessHandle, CommsNode, Dht, Shutdown) { - let (publisher, subscription_factory) = pubsub_connector(runtime::Handle::current(), 100, 20); + let (publisher, subscription_factory) = pubsub_connector(100, 20); let subscription_factory = Arc::new(subscription_factory); let shutdown = Shutdown::new(); let (comms, dht, _) = @@ -75,7 +74,7 @@ fn make_node_identity() -> Arc { )) } -#[tokio_macros::test_basic] +#[tokio::test] async fn end_to_end() { let node_1_identity = make_node_identity(); let node_2_identity = make_node_identity(); @@ -114,34 +113,34 @@ async fn end_to_end() { liveness1.send_ping(node_2_identity.node_id().clone()).await.unwrap(); } - let events = collect_stream!(liveness1_event_stream, take = 18, timeout = Duration::from_secs(20),); + let events = collect_try_recv!(liveness1_event_stream, take = 18, timeout = Duration::from_secs(20)); let ping_count = events .iter() - .filter(|event| matches!(**(**event).as_ref().unwrap(), LivenessEvent::ReceivedPing(_))) + .filter(|event| matches!(&***event, LivenessEvent::ReceivedPing(_))) .count(); assert_eq!(ping_count, 10); let pong_count = events .iter() - .filter(|event| matches!(**(**event).as_ref().unwrap(), LivenessEvent::ReceivedPong(_))) + .filter(|event| matches!(&***event, LivenessEvent::ReceivedPong(_))) .count(); assert_eq!(pong_count, 8); - let events = collect_stream!(liveness2_event_stream, take = 18, timeout = Duration::from_secs(10),); + let events = collect_try_recv!(liveness2_event_stream, take = 18, timeout = Duration::from_secs(10)); let ping_count = events .iter() - .filter(|event| matches!(**(**event).as_ref().unwrap(), LivenessEvent::ReceivedPing(_))) + .filter(|event| matches!(&***event, LivenessEvent::ReceivedPing(_))) .count(); assert_eq!(ping_count, 8); let pong_count = events .iter() - .filter(|event| matches!(**(**event).as_ref().unwrap(), LivenessEvent::ReceivedPong(_))) + .filter(|event| matches!(&***event, LivenessEvent::ReceivedPong(_))) .count(); assert_eq!(pong_count, 10); diff --git a/base_layer/p2p/tests/support/comms_and_services.rs b/base_layer/p2p/tests/support/comms_and_services.rs index 40cc85a710..33bc8fdef7 100644 --- a/base_layer/p2p/tests/support/comms_and_services.rs +++ b/base_layer/p2p/tests/support/comms_and_services.rs @@ -20,27 +20,19 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::Sink; -use std::{error::Error, sync::Arc, time::Duration}; +use std::{sync::Arc, time::Duration}; use tari_comms::{peer_manager::NodeIdentity, protocol::messaging::MessagingEventSender, CommsNode}; use tari_comms_dht::Dht; -use tari_p2p::{ - comms_connector::{InboundDomainConnector, PeerMessage}, - initialization::initialize_local_test_comms, -}; +use tari_p2p::{comms_connector::InboundDomainConnector, initialization::initialize_local_test_comms}; use tari_shutdown::ShutdownSignal; -pub async fn setup_comms_services( +pub async fn setup_comms_services( node_identity: Arc, peers: Vec>, - publisher: InboundDomainConnector, + publisher: InboundDomainConnector, data_path: &str, shutdown_signal: ShutdownSignal, -) -> (CommsNode, Dht, MessagingEventSender) -where - TSink: Sink> + Clone + Unpin + Send + Sync + 'static, - TSink::Error: Error + Send + Sync, -{ +) -> (CommsNode, Dht, MessagingEventSender) { let peers = peers.into_iter().map(|ni| ni.to_peer()).collect(); let (comms, dht, messaging_events) = initialize_local_test_comms( node_identity, diff --git a/base_layer/service_framework/Cargo.toml b/base_layer/service_framework/Cargo.toml index 49e829c949..ead5e311f2 100644 --- a/base_layer/service_framework/Cargo.toml +++ b/base_layer/service_framework/Cargo.toml @@ -14,15 +14,15 @@ tari_shutdown = { version = "^0.9", path="../../infrastructure/shutdown" } anyhow = "1.0.32" async-trait = "0.1.50" -futures = { version = "^0.3.1", features=["async-await"]} +futures = { version = "^0.3.16", features=["async-await"]} log = "0.4.8" -thiserror = "1.0.20" -tokio = { version = "0.2.10" } +thiserror = "1.0.26" +tokio = {version="1.10", features=["rt"]} tower-service = { version="0.3.0" } [dev-dependencies] tari_test_utils = { version = "^0.9", path="../../infrastructure/test_utils" } +tokio = {version="1.10", features=["rt-multi-thread", "macros", "time"]} futures-test = { version = "0.3.3" } -tokio-macros = "0.2.5" tower = "0.3.1" diff --git a/base_layer/service_framework/examples/services/service_a.rs b/base_layer/service_framework/examples/services/service_a.rs index c898696415..dfbd9ace93 100644 --- a/base_layer/service_framework/examples/services/service_a.rs +++ b/base_layer/service_framework/examples/services/service_a.rs @@ -69,7 +69,7 @@ impl ServiceA { pin_mut!(request_stream); loop { - futures::select! { + tokio::select! { //Incoming request request_context = request_stream.select_next_some() => { println!("Handling Service A API Request"); @@ -82,7 +82,7 @@ impl ServiceA { response.push_str(request.clone().as_str()); let _ = reply_tx.send(response); }, - _ = shutdown_signal => { + _ = shutdown_signal.wait() => { println!("Service A shutting down because the shutdown signal was received"); break; } diff --git a/base_layer/service_framework/examples/services/service_b.rs b/base_layer/service_framework/examples/services/service_b.rs index decf53ab14..8e74408077 100644 --- a/base_layer/service_framework/examples/services/service_b.rs +++ b/base_layer/service_framework/examples/services/service_b.rs @@ -31,7 +31,7 @@ use tari_service_framework::{ ServiceInitializerContext, }; use tari_shutdown::ShutdownSignal; -use tokio::time::delay_for; +use tokio::time::sleep; use tower::Service; pub struct ServiceB { @@ -67,7 +67,7 @@ impl ServiceB { pin_mut!(request_stream); loop { - futures::select! { + tokio::select! { //Incoming request request_context = request_stream.select_next_some() => { println!("Handling Service B API Request"); @@ -76,7 +76,7 @@ impl ServiceB { response.push_str(request.clone().as_str()); let _ = reply_tx.send(response); }, - _ = shutdown_signal => { + _ = shutdown_signal.wait() => { println!("Service B shutting down because the shutdown signal was received"); break; } @@ -134,7 +134,7 @@ impl ServiceInitializer for ServiceBInitializer { println!("Service B has shutdown and initializer spawned task is now ending"); }); - delay_for(Duration::from_secs(10)).await; + sleep(Duration::from_secs(10)).await; Ok(()) } } diff --git a/base_layer/service_framework/examples/stack_builder_example.rs b/base_layer/service_framework/examples/stack_builder_example.rs index 35fd785ed7..6f150796e7 100644 --- a/base_layer/service_framework/examples/stack_builder_example.rs +++ b/base_layer/service_framework/examples/stack_builder_example.rs @@ -25,9 +25,9 @@ use crate::services::{ServiceAHandle, ServiceAInitializer, ServiceBHandle, Servi use std::time::Duration; use tari_service_framework::StackBuilder; use tari_shutdown::Shutdown; -use tokio::time::delay_for; +use tokio::time::sleep; -#[tokio_macros::main] +#[tokio::main] async fn main() { let mut shutdown = Shutdown::new(); let fut = StackBuilder::new(shutdown.to_signal()) @@ -40,7 +40,7 @@ async fn main() { let mut service_a_handle = handles.expect_handle::(); let mut service_b_handle = handles.expect_handle::(); - delay_for(Duration::from_secs(1)).await; + sleep(Duration::from_secs(1)).await; println!("----------------------------------------------------"); let response_b = service_b_handle.send_msg("Hello B".to_string()).await; println!("Response from Service B: {}", response_b); @@ -51,5 +51,5 @@ async fn main() { let _ = shutdown.trigger(); - delay_for(Duration::from_secs(5)).await; + sleep(Duration::from_secs(5)).await; } diff --git a/base_layer/service_framework/src/reply_channel.rs b/base_layer/service_framework/src/reply_channel.rs index 54cef8d90f..4921ed9601 100644 --- a/base_layer/service_framework/src/reply_channel.rs +++ b/base_layer/service_framework/src/reply_channel.rs @@ -20,26 +20,15 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::{ - channel::{ - mpsc::{self, SendError}, - oneshot, - }, - ready, - stream::FusedStream, - task::Context, - Future, - FutureExt, - Stream, - StreamExt, -}; +use futures::{ready, stream::FusedStream, task::Context, Future, FutureExt, Stream}; use std::{pin::Pin, task::Poll}; use thiserror::Error; +use tokio::sync::{mpsc, oneshot}; use tower_service::Service; /// Create a new Requester/Responder pair which wraps and calls the given service pub fn unbounded() -> (SenderService, Receiver) { - let (tx, rx) = mpsc::unbounded(); + let (tx, rx) = mpsc::unbounded_channel(); (SenderService::new(tx), Receiver::new(rx)) } @@ -81,20 +70,15 @@ impl Service for SenderService { type Future = TransportResponseFuture; type Response = TRes; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.tx.poll_ready(cx).map_err(|err| { - if err.is_disconnected() { - return TransportChannelError::ChannelClosed; - } - - unreachable!("unbounded channels can never be full"); - }) + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + // An unbounded sender is always ready (i.e. will never wait to send) + Poll::Ready(Ok(())) } fn call(&mut self, request: TReq) -> Self::Future { let (tx, rx) = oneshot::channel(); - if self.tx.unbounded_send((request, tx)).is_ok() { + if self.tx.send((request, tx)).is_ok() { TransportResponseFuture::new(rx) } else { // We're not able to send (rx closed) so return a future which resolves to @@ -106,8 +90,6 @@ impl Service for SenderService { #[derive(Debug, Error, Eq, PartialEq, Clone)] pub enum TransportChannelError { - #[error("Error occurred when sending: `{0}`")] - SendError(#[from] SendError), #[error("Request was canceled")] Canceled, #[error("The response channel has closed")] @@ -188,23 +170,21 @@ impl RequestContext { } /// Receiver side of the reply channel. -/// This is functionally equivalent to `rx.map(|(req, reply_tx)| RequestContext::new(req, reply_tx))` -/// but is ergonomically better to use with the `futures::select` macro (implements FusedStream) -/// and has a short type signature. pub struct Receiver { rx: Rx, + is_closed: bool, } impl FusedStream for Receiver { fn is_terminated(&self) -> bool { - self.rx.is_terminated() + self.is_closed } } impl Receiver { // Create a new Responder pub fn new(rx: Rx) -> Self { - Self { rx } + Self { rx, is_closed: false } } pub fn close(&mut self) { @@ -216,10 +196,17 @@ impl Stream for Receiver { type Item = RequestContext; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match ready!(self.rx.poll_next_unpin(cx)) { + if self.is_terminated() { + return Poll::Ready(None); + } + + match ready!(self.rx.poll_recv(cx)) { Some((req, tx)) => Poll::Ready(Some(RequestContext::new(req, tx))), // Stream has closed, so we're done - None => Poll::Ready(None), + None => { + self.is_closed = true; + Poll::Ready(None) + }, } } } @@ -227,7 +214,7 @@ impl Stream for Receiver { #[cfg(test)] mod test { use super::*; - use futures::{executor::block_on, future}; + use futures::{executor::block_on, future, StreamExt}; use std::fmt::Debug; use tari_test_utils::unpack_enum; use tower::ServiceExt; @@ -247,7 +234,7 @@ mod test { async fn reply(mut rx: Rx, msg: TResp) where TResp: Debug { - match rx.next().await { + match rx.recv().await { Some((_, tx)) => { tx.send(msg).unwrap(); }, @@ -257,7 +244,7 @@ mod test { #[test] fn requestor_call() { - let (tx, rx) = mpsc::unbounded(); + let (tx, rx) = mpsc::unbounded_channel(); let requestor = SenderService::<_, _>::new(tx); let fut = future::join(requestor.oneshot("PING"), reply(rx, "PONG")); diff --git a/base_layer/service_framework/src/stack.rs b/base_layer/service_framework/src/stack.rs index 2489006d56..ad38a94be3 100644 --- a/base_layer/service_framework/src/stack.rs +++ b/base_layer/service_framework/src/stack.rs @@ -103,7 +103,7 @@ mod test { use tari_shutdown::Shutdown; use tower::service_fn; - #[tokio_macros::test] + #[tokio::test] async fn service_defn_simple() { // This is less of a test and more of a demo of using the short-hand implementation of ServiceInitializer let simple_initializer = |_: ServiceInitializerContext| Ok(()); @@ -155,7 +155,7 @@ mod test { } } - #[tokio_macros::test] + #[tokio::test] async fn service_stack_new() { let shared_state = Arc::new(AtomicUsize::new(0)); diff --git a/base_layer/tari_stratum_ffi/Cargo.toml b/base_layer/tari_stratum_ffi/Cargo.toml index 6598df9f06..9ed773d89f 100644 --- a/base_layer/tari_stratum_ffi/Cargo.toml +++ b/base_layer/tari_stratum_ffi/Cargo.toml @@ -14,7 +14,7 @@ tari_app_grpc = { path = "../../applications/tari_app_grpc" } tari_core = { path = "../../base_layer/core", default-features = false, features = ["transactions"]} tari_utilities = "^0.3" libc = "0.2.65" -thiserror = "1.0.20" +thiserror = "1.0.26" hex = "0.4.2" serde = { version="1.0.106", features = ["derive"] } serde_json = "1.0.57" diff --git a/base_layer/wallet/Cargo.toml b/base_layer/wallet/Cargo.toml index c2cf12ca22..639bdbd9fd 100644 --- a/base_layer/wallet/Cargo.toml +++ b/base_layer/wallet/Cargo.toml @@ -7,55 +7,53 @@ version = "0.9.5" edition = "2018" [dependencies] -tari_common_types = { version = "^0.9", path = "../../base_layer/common_types"} -tari_comms = { version = "^0.9", path = "../../comms"} +tari_common_types = { version = "^0.9", path = "../../base_layer/common_types" } +tari_comms = { version = "^0.9", path = "../../comms" } tari_comms_dht = { version = "^0.9", path = "../../comms/dht" } tari_crypto = "0.11.1" tari_key_manager = { version = "^0.9", path = "../key_manager" } tari_p2p = { version = "^0.9", path = "../p2p" } -tari_service_framework = { version = "^0.9", path = "../service_framework"} +tari_service_framework = { version = "^0.9", path = "../service_framework" } tari_shutdown = { version = "^0.9", path = "../../infrastructure/shutdown" } -tari_storage = { version = "^0.9", path = "../../infrastructure/storage"} +tari_storage = { version = "^0.9", path = "../../infrastructure/storage" } aes-gcm = "^0.8" blake2 = "0.9.0" -chrono = { version = "0.4.6", features = ["serde"]} +chrono = { version = "0.4.6", features = ["serde"] } crossbeam-channel = "0.3.8" digest = "0.9.0" -diesel = { version="1.4.7", features = ["sqlite", "serde_json", "chrono"]} +diesel = { version = "1.4.7", features = ["sqlite", "serde_json", "chrono"] } diesel_migrations = "1.4.0" -libsqlite3-sys = { version = ">=0.8.0, <0.13.0", features = ["bundled"], optional=true } +libsqlite3-sys = { version = ">=0.8.0, <0.13.0", features = ["bundled"], optional = true } fs2 = "0.3.0" -futures = { version = "^0.3.1", features =["compat", "std"]} -lazy_static = "1.4.0" +futures = { version = "^0.3.1", features = ["compat", "std"] } log = "0.4.6" -log4rs = {version = "1.0.0", features = ["console_appender", "file_appender", "yaml_format"]} +log4rs = { version = "1.0.0", features = ["console_appender", "file_appender", "yaml_format"] } lmdb-zero = "0.4.4" rand = "0.8" -serde = {version = "1.0.89", features = ["derive"] } +serde = { version = "1.0.89", features = ["derive"] } serde_json = "1.0.39" -tokio = { version = "0.2.10", features = ["blocking", "sync"]} +tokio = { version = "1.10", features = ["sync", "macros"] } tower = "0.3.0-alpha.2" tempfile = "3.1.0" -time = {version = "0.1.39"} -thiserror = "1.0.20" +time = { version = "0.1.39" } +thiserror = "1.0.26" bincode = "1.3.1" [dependencies.tari_core] path = "../../base_layer/core" version = "^0.9" default-features = false -features = ["transactions", "mempool_proto", "base_node_proto",] +features = ["transactions", "mempool_proto", "base_node_proto", ] [dev-dependencies] -tari_p2p = { version = "^0.9", path = "../p2p", features=["test-mocks"]} -tari_comms_dht = { version = "^0.9", path = "../../comms/dht", features=["test-mocks"]} +tari_p2p = { version = "^0.9", path = "../p2p", features = ["test-mocks"] } +tari_comms_dht = { version = "^0.9", path = "../../comms/dht", features = ["test-mocks"] } tari_test_utils = { version = "^0.9", path = "../../infrastructure/test_utils" } -lazy_static = "1.3.0" env_logger = "0.7.1" -prost = "0.6.1" -tokio-macros = "0.2.4" +prost = "0.8.0" [features] c_integration = [] avx2 = ["tari_crypto/avx2", "tari_core/avx2"] +bundled_sqlite = ["libsqlite3-sys"] diff --git a/base_layer/wallet/src/base_node_service/handle.rs b/base_layer/wallet/src/base_node_service/handle.rs index 4957823c72..f495479778 100644 --- a/base_layer/wallet/src/base_node_service/handle.rs +++ b/base_layer/wallet/src/base_node_service/handle.rs @@ -21,7 +21,6 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use super::{error::BaseNodeServiceError, service::BaseNodeState}; -use futures::{stream::Fuse, StreamExt}; use std::{sync::Arc, time::Duration}; use tari_common_types::chain_metadata::ChainMetadata; use tari_comms::peer_manager::Peer; @@ -72,8 +71,8 @@ impl BaseNodeServiceHandle { } } - pub fn get_event_stream_fused(&self) -> Fuse { - self.event_stream_sender.subscribe().fuse() + pub fn get_event_stream(&self) -> BaseNodeEventReceiver { + self.event_stream_sender.subscribe() } pub async fn get_chain_metadata(&mut self) -> Result, BaseNodeServiceError> { diff --git a/base_layer/wallet/src/base_node_service/mock_base_node_service.rs b/base_layer/wallet/src/base_node_service/mock_base_node_service.rs index 1bc57ed9d2..9aa981150d 100644 --- a/base_layer/wallet/src/base_node_service/mock_base_node_service.rs +++ b/base_layer/wallet/src/base_node_service/mock_base_node_service.rs @@ -20,13 +20,10 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{ - base_node_service::{ - error::BaseNodeServiceError, - handle::{BaseNodeServiceRequest, BaseNodeServiceResponse}, - service::BaseNodeState, - }, - connectivity_service::OnlineStatus, +use crate::base_node_service::{ + error::BaseNodeServiceError, + handle::{BaseNodeServiceRequest, BaseNodeServiceResponse}, + service::BaseNodeState, }; use futures::StreamExt; use tari_common_types::chain_metadata::ChainMetadata; @@ -81,30 +78,28 @@ impl MockBaseNodeService { /// Set the mock server state, either online and synced to a specific height, or offline with None pub fn set_base_node_state(&mut self, height: Option) { - let (chain_metadata, is_synced, online) = match height { + let (chain_metadata, is_synced) = match height { Some(height) => { let metadata = ChainMetadata::new(height, Vec::new(), 0, 0, 0); - (Some(metadata), Some(true), OnlineStatus::Online) + (Some(metadata), Some(true)) }, - None => (None, None, OnlineStatus::Offline), + None => (None, None), }; self.state = BaseNodeState { chain_metadata, is_synced, updated: None, latency: None, - online, } } pub fn set_default_base_node_state(&mut self) { - let metadata = ChainMetadata::new(std::u64::MAX, Vec::new(), 0, 0, 0); + let metadata = ChainMetadata::new(u64::MAX, Vec::new(), 0, 0, 0); self.state = BaseNodeState { chain_metadata: Some(metadata), is_synced: Some(true), updated: None, latency: None, - online: OnlineStatus::Online, } } diff --git a/base_layer/wallet/src/base_node_service/monitor.rs b/base_layer/wallet/src/base_node_service/monitor.rs index 5a2c3a7e76..8e0298ca27 100644 --- a/base_layer/wallet/src/base_node_service/monitor.rs +++ b/base_layer/wallet/src/base_node_service/monitor.rs @@ -25,7 +25,7 @@ use crate::{ handle::{BaseNodeEvent, BaseNodeEventSender}, service::BaseNodeState, }, - connectivity_service::{OnlineStatus, WalletConnectivityHandle}, + connectivity_service::WalletConnectivityHandle, error::WalletStorageError, storage::database::{WalletBackend, WalletDatabase}, }; @@ -33,7 +33,7 @@ use chrono::Utc; use log::*; use std::{convert::TryFrom, sync::Arc, time::Duration}; use tari_common_types::chain_metadata::ChainMetadata; -use tari_comms::{peer_manager::NodeId, protocol::rpc::RpcError}; +use tari_comms::protocol::rpc::RpcError; use tokio::{sync::RwLock, time}; const LOG_TARGET: &str = "wallet::base_node_service::chain_metadata_monitor"; @@ -78,9 +78,6 @@ impl BaseNodeMonitor { }, Err(e @ BaseNodeMonitorError::RpcFailed(_)) => { warn!(target: LOG_TARGET, "Connectivity failure to base node: {}", e); - debug!(target: LOG_TARGET, "Setting as OFFLINE and retrying...",); - - self.set_offline().await; continue; }, Err(e @ BaseNodeMonitorError::InvalidBaseNodeResponse(_)) | @@ -96,34 +93,19 @@ impl BaseNodeMonitor { ); } - async fn update_connectivity_status(&self) -> NodeId { - let mut watcher = self.wallet_connectivity.get_connectivity_status_watch(); - loop { - use OnlineStatus::*; - match watcher.recv().await.unwrap_or(Offline) { - Online => match self.wallet_connectivity.get_current_base_node_id() { - Some(node_id) => return node_id, - _ => continue, - }, - Connecting => { - self.set_connecting().await; - }, - Offline => { - self.set_offline().await; - }, - } - } - } - async fn monitor_node(&mut self) -> Result<(), BaseNodeMonitorError> { loop { - let peer_node_id = self.update_connectivity_status().await; let mut client = self .wallet_connectivity .obtain_base_node_wallet_rpc_client() .await .ok_or(BaseNodeMonitorError::NodeShuttingDown)?; + let base_node_id = match self.wallet_connectivity.get_current_base_node_id() { + Some(n) => n, + None => continue, + }; + let tip_info = client.get_tip_info().await?; let chain_metadata = tip_info @@ -138,7 +120,7 @@ impl BaseNodeMonitor { debug!( target: LOG_TARGET, "Base node {} Tip: {} ({}) Latency: {} ms", - peer_node_id, + base_node_id, chain_metadata.height_of_longest_chain(), if is_synced { "Synced" } else { "Syncing..." }, latency.as_millis() @@ -151,11 +133,10 @@ impl BaseNodeMonitor { is_synced: Some(is_synced), updated: Some(Utc::now().naive_utc()), latency: Some(latency), - online: OnlineStatus::Online, }) .await; - time::delay_for(self.interval).await + time::sleep(self.interval).await } // loop only exits on shutdown/error @@ -163,28 +144,6 @@ impl BaseNodeMonitor { Ok(()) } - async fn set_connecting(&self) { - self.map_state(|_| BaseNodeState { - chain_metadata: None, - is_synced: None, - updated: Some(Utc::now().naive_utc()), - latency: None, - online: OnlineStatus::Connecting, - }) - .await; - } - - async fn set_offline(&self) { - self.map_state(|_| BaseNodeState { - chain_metadata: None, - is_synced: None, - updated: Some(Utc::now().naive_utc()), - latency: None, - online: OnlineStatus::Offline, - }) - .await; - } - async fn map_state(&self, transform: F) where F: FnOnce(&BaseNodeState) -> BaseNodeState { let new_state = { diff --git a/base_layer/wallet/src/base_node_service/service.rs b/base_layer/wallet/src/base_node_service/service.rs index 3da987c8b1..eb2b91ebda 100644 --- a/base_layer/wallet/src/base_node_service/service.rs +++ b/base_layer/wallet/src/base_node_service/service.rs @@ -27,7 +27,7 @@ use super::{ }; use crate::{ base_node_service::monitor::BaseNodeMonitor, - connectivity_service::{OnlineStatus, WalletConnectivityHandle}, + connectivity_service::WalletConnectivityHandle, storage::database::{WalletBackend, WalletDatabase}, }; use chrono::NaiveDateTime; @@ -49,8 +49,6 @@ pub struct BaseNodeState { pub is_synced: Option, pub updated: Option, pub latency: Option, - pub online: OnlineStatus, - // pub base_node_peer: Option, } impl Default for BaseNodeState { @@ -60,7 +58,6 @@ impl Default for BaseNodeState { is_synced: None, updated: None, latency: None, - online: OnlineStatus::Connecting, } } } diff --git a/base_layer/wallet/src/config.rs b/base_layer/wallet/src/config.rs index cd17024068..3844fc13e7 100644 --- a/base_layer/wallet/src/config.rs +++ b/base_layer/wallet/src/config.rs @@ -20,14 +20,16 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use std::time::Duration; + +use tari_core::{consensus::NetworkConsensus, transactions::CryptoFactories}; +use tari_p2p::initialization::CommsConfig; + use crate::{ base_node_service::config::BaseNodeServiceConfig, output_manager_service::config::OutputManagerServiceConfig, transaction_service::config::TransactionServiceConfig, }; -use std::time::Duration; -use tari_core::{consensus::NetworkConsensus, transactions::types::CryptoFactories}; -use tari_p2p::initialization::CommsConfig; pub const KEY_MANAGER_COMMS_SECRET_KEY_BRANCH_KEY: &str = "comms"; diff --git a/base_layer/wallet/src/connectivity_service/handle.rs b/base_layer/wallet/src/connectivity_service/handle.rs index ac218edc5e..5a35696e14 100644 --- a/base_layer/wallet/src/connectivity_service/handle.rs +++ b/base_layer/wallet/src/connectivity_service/handle.rs @@ -22,16 +22,12 @@ use super::service::OnlineStatus; use crate::connectivity_service::{error::WalletConnectivityError, watch::Watch}; -use futures::{ - channel::{mpsc, oneshot}, - SinkExt, -}; use tari_comms::{ peer_manager::{NodeId, Peer}, protocol::rpc::RpcClientLease, }; use tari_core::base_node::{rpc::BaseNodeWalletRpcClient, sync::rpc::BaseNodeSyncRpcClient}; -use tokio::sync::watch; +use tokio::sync::{mpsc, oneshot, watch}; pub enum WalletConnectivityRequest { ObtainBaseNodeWalletRpcClient(oneshot::Sender>), @@ -102,8 +98,8 @@ impl WalletConnectivityHandle { reply_rx.await.ok() } - pub async fn get_connectivity_status(&mut self) -> OnlineStatus { - self.online_status_rx.recv().await.unwrap_or(OnlineStatus::Offline) + pub fn get_connectivity_status(&mut self) -> OnlineStatus { + *self.online_status_rx.borrow() } pub fn get_connectivity_status_watch(&self) -> watch::Receiver { diff --git a/base_layer/wallet/src/connectivity_service/initializer.rs b/base_layer/wallet/src/connectivity_service/initializer.rs index d0c2b94126..1610a834e3 100644 --- a/base_layer/wallet/src/connectivity_service/initializer.rs +++ b/base_layer/wallet/src/connectivity_service/initializer.rs @@ -30,8 +30,8 @@ use super::{handle::WalletConnectivityHandle, service::WalletConnectivityService, watch::Watch}; use crate::{base_node_service::config::BaseNodeServiceConfig, connectivity_service::service::OnlineStatus}; -use futures::channel::mpsc; use tari_service_framework::{async_trait, ServiceInitializationError, ServiceInitializer, ServiceInitializerContext}; +use tokio::sync::mpsc; pub struct WalletConnectivityInitializer { config: BaseNodeServiceConfig, @@ -59,8 +59,13 @@ impl ServiceInitializer for WalletConnectivityInitializer { context.spawn_until_shutdown(move |handles| { let connectivity = handles.expect_handle(); - let service = - WalletConnectivityService::new(config, receiver, base_node_watch, online_status_watch, connectivity); + let service = WalletConnectivityService::new( + config, + receiver, + base_node_watch.get_receiver(), + online_status_watch, + connectivity, + ); service.start() }); diff --git a/base_layer/wallet/src/connectivity_service/service.rs b/base_layer/wallet/src/connectivity_service/service.rs index c0cf474b96..950b9a9a72 100644 --- a/base_layer/wallet/src/connectivity_service/service.rs +++ b/base_layer/wallet/src/connectivity_service/service.rs @@ -24,15 +24,8 @@ use crate::{ base_node_service::config::BaseNodeServiceConfig, connectivity_service::{error::WalletConnectivityError, handle::WalletConnectivityRequest, watch::Watch}, }; -use core::mem; -use futures::{ - channel::{mpsc, oneshot}, - future, - future::Either, - stream::Fuse, - StreamExt, -}; use log::*; +use std::{mem, time::Duration}; use tari_comms::{ connectivity::ConnectivityRequester, peer_manager::{NodeId, Peer}, @@ -40,7 +33,11 @@ use tari_comms::{ PeerConnection, }; use tari_core::base_node::{rpc::BaseNodeWalletRpcClient, sync::rpc::BaseNodeSyncRpcClient}; -use tokio::time; +use tokio::{ + sync::{mpsc, oneshot, watch}, + time, + time::MissedTickBehavior, +}; const LOG_TARGET: &str = "wallet::connectivity"; @@ -54,9 +51,9 @@ pub enum OnlineStatus { pub struct WalletConnectivityService { config: BaseNodeServiceConfig, - request_stream: Fuse>, + request_stream: mpsc::Receiver, connectivity: ConnectivityRequester, - base_node_watch: Watch>, + base_node_watch: watch::Receiver>, pools: Option, online_status_watch: Watch, pending_requests: Vec, @@ -71,13 +68,13 @@ impl WalletConnectivityService { pub(super) fn new( config: BaseNodeServiceConfig, request_stream: mpsc::Receiver, - base_node_watch: Watch>, + base_node_watch: watch::Receiver>, online_status_watch: Watch, connectivity: ConnectivityRequester, ) -> Self { Self { config, - request_stream: request_stream.fuse(), + request_stream, connectivity, base_node_watch, pools: None, @@ -88,22 +85,41 @@ impl WalletConnectivityService { pub async fn start(mut self) { debug!(target: LOG_TARGET, "Wallet connectivity service has started."); - let mut base_node_watch_rx = self.base_node_watch.get_receiver().fuse(); + let mut check_connection = + time::interval_at(time::Instant::now() + Duration::from_secs(5), Duration::from_secs(5)); + check_connection.set_missed_tick_behavior(MissedTickBehavior::Delay); loop { - futures::select! { - req = self.request_stream.select_next_some() => { - self.handle_request(req).await; - }, - maybe_peer = base_node_watch_rx.select_next_some() => { - if maybe_peer.is_some() { + tokio::select! { + // BIASED: select branches are in order of priority + biased; + + Ok(_) = self.base_node_watch.changed() => { + if self.base_node_watch.borrow().is_some() { // This will block the rest until the connection is established. This is what we want. self.setup_base_node_connection().await; } + }, + + Some(req) = self.request_stream.recv() => { + self.handle_request(req).await; + }, + + _ = check_connection.tick() => { + self.check_connection().await; } } } } + async fn check_connection(&mut self) { + if let Some(pool) = self.pools.as_ref() { + if !pool.base_node_wallet_rpc_client.is_connected().await { + debug!(target: LOG_TARGET, "Peer connection lost. Attempting to reconnect..."); + self.setup_base_node_connection().await; + } + } + } + async fn handle_request(&mut self, request: WalletConnectivityRequest) { use WalletConnectivityRequest::*; match request { @@ -138,7 +154,6 @@ impl WalletConnectivityService { target: LOG_TARGET, "Base node connection failed: {}. Reconnecting...", e ); - self.trigger_reconnect(); self.pending_requests.push(reply.into()); }, }, @@ -169,7 +184,6 @@ impl WalletConnectivityService { target: LOG_TARGET, "Base node connection failed: {}. Reconnecting...", e ); - self.trigger_reconnect(); self.pending_requests.push(reply.into()); }, }, @@ -186,21 +200,6 @@ impl WalletConnectivityService { } } - fn trigger_reconnect(&mut self) { - let peer = self - .base_node_watch - .borrow() - .clone() - .expect("trigger_reconnect called before base node is set"); - // Trigger the watch so that a peer connection is reinitiated - self.set_base_node_peer(peer); - } - - fn set_base_node_peer(&mut self, peer: Peer) { - self.pools = None; - self.base_node_watch.broadcast(Some(peer)); - } - fn current_base_node(&self) -> Option { self.base_node_watch.borrow().as_ref().map(|p| p.node_id.clone()) } @@ -236,8 +235,8 @@ impl WalletConnectivityService { } else { self.set_online_status(OnlineStatus::Offline); } - error!(target: LOG_TARGET, "{}", e); - time::delay_for(self.config.base_node_monitor_refresh_interval).await; + warn!(target: LOG_TARGET, "{}", e); + time::sleep(self.config.base_node_monitor_refresh_interval).await; continue; }, } @@ -275,13 +274,15 @@ impl WalletConnectivityService { } async fn try_dial_peer(&mut self, peer: NodeId) -> Result, WalletConnectivityError> { - let recv_fut = self.base_node_watch.recv(); - futures::pin_mut!(recv_fut); - let dial_fut = self.connectivity.dial_peer(peer); - futures::pin_mut!(dial_fut); - match future::select(recv_fut, dial_fut).await { - Either::Left(_) => Ok(None), - Either::Right((conn, _)) => Ok(Some(conn?)), + tokio::select! { + biased; + + _ = self.base_node_watch.changed() => { + Ok(None) + } + result = self.connectivity.dial_peer(peer) => { + Ok(Some(result?)) + } } } @@ -307,8 +308,8 @@ impl ReplyOneshot { pub fn is_canceled(&self) -> bool { use ReplyOneshot::*; match self { - WalletRpc(tx) => tx.is_canceled(), - SyncRpc(tx) => tx.is_canceled(), + WalletRpc(tx) => tx.is_closed(), + SyncRpc(tx) => tx.is_closed(), } } } diff --git a/base_layer/wallet/src/connectivity_service/test.rs b/base_layer/wallet/src/connectivity_service/test.rs index 7c24ef5b46..9a8f5a2da9 100644 --- a/base_layer/wallet/src/connectivity_service/test.rs +++ b/base_layer/wallet/src/connectivity_service/test.rs @@ -23,7 +23,7 @@ use super::service::WalletConnectivityService; use crate::connectivity_service::{watch::Watch, OnlineStatus, WalletConnectivityHandle}; use core::convert; -use futures::{channel::mpsc, future}; +use futures::future; use std::{iter, sync::Arc}; use tari_comms::{ peer_manager::PeerFeatures, @@ -39,7 +39,10 @@ use tari_comms::{ }; use tari_shutdown::Shutdown; use tari_test_utils::runtime::spawn_until_shutdown; -use tokio::{sync::Barrier, task}; +use tokio::{ + sync::{mpsc, Barrier}, + task, +}; async fn setup() -> ( WalletConnectivityHandle, @@ -57,7 +60,7 @@ async fn setup() -> ( let service = WalletConnectivityService::new( Default::default(), rx, - base_node_watch, + base_node_watch.get_receiver(), online_status_watch, connectivity, ); @@ -70,7 +73,7 @@ async fn setup() -> ( (handle, mock_server, mock_state, shutdown) } -#[tokio_macros::test] +#[tokio::test] async fn it_dials_peer_when_base_node_is_set() { let (mut handle, mock_server, mock_state, _shutdown) = setup().await; let base_node_peer = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -92,7 +95,7 @@ async fn it_dials_peer_when_base_node_is_set() { assert!(rpc_client.is_connected()); } -#[tokio_macros::test] +#[tokio::test] async fn it_resolves_many_pending_rpc_session_requests() { let (mut handle, mock_server, mock_state, _shutdown) = setup().await; let base_node_peer = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -122,7 +125,7 @@ async fn it_resolves_many_pending_rpc_session_requests() { assert!(results.into_iter().map(Result::unwrap).all(convert::identity)); } -#[tokio_macros::test] +#[tokio::test] async fn it_changes_to_a_new_base_node() { let (mut handle, mock_server, mock_state, _shutdown) = setup().await; let base_node_peer1 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -138,7 +141,7 @@ async fn it_changes_to_a_new_base_node() { mock_state.await_call_count(2).await; mock_state.expect_dial_peer(base_node_peer1.node_id()).await; - assert_eq!(mock_state.count_calls_containing("AddManagedPeer").await, 2); + assert!(mock_state.count_calls_containing("AddManagedPeer").await >= 1); let _ = mock_state.take_calls().await; let rpc_client = handle.obtain_base_node_wallet_rpc_client().await.unwrap(); @@ -149,13 +152,12 @@ async fn it_changes_to_a_new_base_node() { mock_state.await_call_count(2).await; mock_state.expect_dial_peer(base_node_peer2.node_id()).await; - assert_eq!(mock_state.count_calls_containing("AddManagedPeer").await, 2); let rpc_client = handle.obtain_base_node_wallet_rpc_client().await.unwrap(); assert!(rpc_client.is_connected()); } -#[tokio_macros::test] +#[tokio::test] async fn it_gracefully_handles_connect_fail_reconnect() { let (mut handle, mock_server, mock_state, _shutdown) = setup().await; let base_node_peer = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -198,7 +200,7 @@ async fn it_gracefully_handles_connect_fail_reconnect() { pending_request.await.unwrap(); } -#[tokio_macros::test] +#[tokio::test] async fn it_gracefully_handles_multiple_connection_failures() { let (mut handle, mock_server, mock_state, _shutdown) = setup().await; let base_node_peer = build_node_identity(PeerFeatures::COMMUNICATION_NODE); diff --git a/base_layer/wallet/src/connectivity_service/watch.rs b/base_layer/wallet/src/connectivity_service/watch.rs index 1f1e868d47..4669b355f6 100644 --- a/base_layer/wallet/src/connectivity_service/watch.rs +++ b/base_layer/wallet/src/connectivity_service/watch.rs @@ -26,24 +26,19 @@ use tokio::sync::watch; #[derive(Clone)] pub struct Watch(Arc>, watch::Receiver); -impl Watch { +impl Watch { pub fn new(initial: T) -> Self { let (tx, rx) = watch::channel(initial); Self(Arc::new(tx), rx) } - #[allow(dead_code)] - pub async fn recv(&mut self) -> Option { - self.receiver_mut().recv().await - } - pub fn borrow(&self) -> watch::Ref<'_, T> { self.receiver().borrow() } pub fn broadcast(&self, item: T) { - // SAFETY: broadcast becomes infallible because the receiver is owned in Watch and so has the same lifetime - if self.sender().broadcast(item).is_err() { + // PANIC: broadcast becomes infallible because the receiver is owned in Watch and so has the same lifetime + if self.sender().send(item).is_err() { // Result::expect requires E: fmt::Debug and `watch::SendError` is not, this is equivalent panic!("watch internal receiver is dropped"); } @@ -53,10 +48,6 @@ impl Watch { &self.0 } - fn receiver_mut(&mut self) -> &mut watch::Receiver { - &mut self.1 - } - pub fn receiver(&self) -> &watch::Receiver { &self.1 } diff --git a/base_layer/wallet/src/contacts_service/service.rs b/base_layer/wallet/src/contacts_service/service.rs index f86f6b49cc..cfc8473202 100644 --- a/base_layer/wallet/src/contacts_service/service.rs +++ b/base_layer/wallet/src/contacts_service/service.rs @@ -76,8 +76,8 @@ where T: ContactsBackend + 'static info!(target: LOG_TARGET, "Contacts Service started"); loop { - futures::select! { - request_context = request_stream.select_next_some() => { + tokio::select! { + Some(request_context) = request_stream.next() => { let (request, reply_tx) = request_context.split(); let response = self.handle_request(request).await.map_err(|e| { error!(target: LOG_TARGET, "Error handling request: {:?}", e); @@ -88,14 +88,10 @@ where T: ContactsBackend + 'static e }); }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Contacts service shutting down because it received the shutdown signal"); break; } - complete => { - info!(target: LOG_TARGET, "Contacts service shutting down"); - break; - } } } info!(target: LOG_TARGET, "Contacts Service ended"); diff --git a/base_layer/wallet/src/contacts_service/storage/sqlite_db.rs b/base_layer/wallet/src/contacts_service/storage/sqlite_db.rs index 8fc8234f56..b32f798777 100644 --- a/base_layer/wallet/src/contacts_service/storage/sqlite_db.rs +++ b/base_layer/wallet/src/contacts_service/storage/sqlite_db.rs @@ -30,7 +30,7 @@ use crate::{ }; use diesel::{prelude::*, result::Error as DieselError, SqliteConnection}; use std::convert::TryFrom; -use tari_core::transactions::types::PublicKey; +use tari_common_types::types::PublicKey; use tari_crypto::tari_utilities::ByteArray; /// A Sqlite backend for the Output Manager Service. The Backend is accessed via a connection pool to the Sqlite file. @@ -192,7 +192,7 @@ mod test { use diesel::{Connection, SqliteConnection}; use rand::rngs::OsRng; use std::convert::TryFrom; - use tari_core::transactions::types::{PrivateKey, PublicKey}; + use tari_common_types::types::{PrivateKey, PublicKey}; use tari_crypto::{ keys::{PublicKey as PublicKeyTrait, SecretKey as SecretKeyTrait}, tari_utilities::ByteArray, diff --git a/base_layer/wallet/src/lib.rs b/base_layer/wallet/src/lib.rs index bc0b4c1a04..22bce8bfdb 100644 --- a/base_layer/wallet/src/lib.rs +++ b/base_layer/wallet/src/lib.rs @@ -25,8 +25,6 @@ pub mod wallet; extern crate diesel; #[macro_use] extern crate diesel_migrations; -#[macro_use] -extern crate lazy_static; mod config; pub mod schema; diff --git a/base_layer/wallet/src/output_manager_service/handle.rs b/base_layer/wallet/src/output_manager_service/handle.rs index 659fab4a42..54b082c900 100644 --- a/base_layer/wallet/src/output_manager_service/handle.rs +++ b/base_layer/wallet/src/output_manager_service/handle.rs @@ -31,14 +31,13 @@ use crate::{ types::ValidationRetryStrategy, }; use aes_gcm::Aes256Gcm; -use futures::{stream::Fuse, StreamExt}; use std::{collections::HashMap, fmt, sync::Arc, time::Duration}; +use tari_common_types::types::PublicKey; use tari_comms::types::CommsPublicKey; use tari_core::transactions::{ tari_amount::MicroTari, transaction::{Transaction, TransactionInput, TransactionOutput, UnblindedOutput}, transaction_protocol::sender::TransactionSenderMessage, - types::PublicKey, ReceiverTransactionProtocol, SenderTransactionProtocol, }; @@ -191,8 +190,8 @@ impl OutputManagerHandle { } } - pub fn get_event_stream_fused(&self) -> Fuse { - self.event_stream_sender.subscribe().fuse() + pub fn get_event_stream(&self) -> OutputManagerEventReceiver { + self.event_stream_sender.subscribe() } pub async fn add_output(&mut self, output: UnblindedOutput) -> Result<(), OutputManagerError> { diff --git a/base_layer/wallet/src/output_manager_service/master_key_manager.rs b/base_layer/wallet/src/output_manager_service/master_key_manager.rs index 7b315569dc..4f33a909cf 100644 --- a/base_layer/wallet/src/output_manager_service/master_key_manager.rs +++ b/base_layer/wallet/src/output_manager_service/master_key_manager.rs @@ -30,10 +30,8 @@ use crate::{ }; use futures::lock::Mutex; use log::*; -use tari_core::transactions::{ - transaction_protocol::RewindData, - types::{PrivateKey, PublicKey}, -}; +use tari_common_types::types::{PrivateKey, PublicKey}; +use tari_core::transactions::transaction_protocol::RewindData; use tari_crypto::{keys::PublicKey as PublicKeyTrait, range_proof::REWIND_USER_MESSAGE_LENGTH}; use tari_key_manager::{ key_manager::KeyManager, diff --git a/base_layer/wallet/src/output_manager_service/mod.rs b/base_layer/wallet/src/output_manager_service/mod.rs index ce6fd70699..80f02f2445 100644 --- a/base_layer/wallet/src/output_manager_service/mod.rs +++ b/base_layer/wallet/src/output_manager_service/mod.rs @@ -20,22 +20,15 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{ - base_node_service::handle::BaseNodeServiceHandle, - output_manager_service::{ - config::OutputManagerServiceConfig, - handle::OutputManagerHandle, - service::OutputManagerService, - storage::database::{OutputManagerBackend, OutputManagerDatabase}, - }, - transaction_service::handle::TransactionServiceHandle, -}; use futures::future; use log::*; +use tokio::sync::broadcast; + +pub(crate) use master_key_manager::MasterKeyManager; use tari_comms::{connectivity::ConnectivityRequester, types::CommsSecretKey}; use tari_core::{ consensus::{ConsensusConstantsBuilder, NetworkConsensus}, - transactions::types::CryptoFactories, + transactions::CryptoFactories, }; use tari_service_framework::{ async_trait, @@ -44,7 +37,18 @@ use tari_service_framework::{ ServiceInitializer, ServiceInitializerContext, }; -use tokio::sync::broadcast; +pub use tasks::TxoValidationType; + +use crate::{ + base_node_service::handle::BaseNodeServiceHandle, + output_manager_service::{ + config::OutputManagerServiceConfig, + handle::OutputManagerHandle, + service::OutputManagerService, + storage::database::{OutputManagerBackend, OutputManagerDatabase}, + }, + transaction_service::handle::TransactionServiceHandle, +}; pub mod config; pub mod error; @@ -57,9 +61,6 @@ pub mod service; pub mod storage; mod tasks; -pub(crate) use master_key_manager::MasterKeyManager; -pub use tasks::TxoValidationType; - const LOG_TARGET: &str = "wallet::output_manager_service::initializer"; pub type TxId = u64; diff --git a/base_layer/wallet/src/output_manager_service/recovery/standard_outputs_recoverer.rs b/base_layer/wallet/src/output_manager_service/recovery/standard_outputs_recoverer.rs index 64e4d510d2..1da885d7f1 100644 --- a/base_layer/wallet/src/output_manager_service/recovery/standard_outputs_recoverer.rs +++ b/base_layer/wallet/src/output_manager_service/recovery/standard_outputs_recoverer.rs @@ -20,6 +20,17 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use std::sync::Arc; + +use log::*; +use tari_crypto::{inputs, keys::PublicKey as PublicKeyTrait, tari_utilities::hex::Hex}; + +use tari_common_types::types::PublicKey; +use tari_core::transactions::{ + transaction::{TransactionOutput, UnblindedOutput}, + CryptoFactories, +}; + use crate::output_manager_service::{ error::OutputManagerError, storage::{ @@ -28,13 +39,6 @@ use crate::output_manager_service::{ }, MasterKeyManager, }; -use log::*; -use std::sync::Arc; -use tari_core::transactions::{ - transaction::{TransactionOutput, UnblindedOutput}, - types::{CryptoFactories, PublicKey}, -}; -use tari_crypto::{inputs, keys::PublicKey as PublicKeyTrait, tari_utilities::hex::Hex}; const LOG_TARGET: &str = "wallet::output_manager_service::recovery"; diff --git a/base_layer/wallet/src/output_manager_service/resources.rs b/base_layer/wallet/src/output_manager_service/resources.rs index d6e17b570b..f094b0b79c 100644 --- a/base_layer/wallet/src/output_manager_service/resources.rs +++ b/base_layer/wallet/src/output_manager_service/resources.rs @@ -20,6 +20,12 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use std::sync::Arc; + +use tari_comms::{connectivity::ConnectivityRequester, types::CommsPublicKey}; +use tari_core::{consensus::ConsensusConstants, transactions::CryptoFactories}; +use tari_shutdown::ShutdownSignal; + use crate::{ output_manager_service::{ config::OutputManagerServiceConfig, @@ -29,10 +35,6 @@ use crate::{ }, transaction_service::handle::TransactionServiceHandle, }; -use std::sync::Arc; -use tari_comms::{connectivity::ConnectivityRequester, types::CommsPublicKey}; -use tari_core::{consensus::ConsensusConstants, transactions::types::CryptoFactories}; -use tari_shutdown::ShutdownSignal; /// This struct is a collection of the common resources that a async task in the service requires. #[derive(Clone)] diff --git a/base_layer/wallet/src/output_manager_service/service.rs b/base_layer/wallet/src/output_manager_service/service.rs index c8946b9ff9..bb1a98f505 100644 --- a/base_layer/wallet/src/output_manager_service/service.rs +++ b/base_layer/wallet/src/output_manager_service/service.rs @@ -20,38 +20,30 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{ - base_node_service::handle::BaseNodeServiceHandle, - output_manager_service::{ - config::OutputManagerServiceConfig, - error::{OutputManagerError, OutputManagerProtocolError, OutputManagerStorageError}, - handle::{OutputManagerEventSender, OutputManagerRequest, OutputManagerResponse}, - recovery::StandardUtxoRecoverer, - resources::OutputManagerResources, - storage::{ - database::{OutputManagerBackend, OutputManagerDatabase, PendingTransactionOutputs}, - models::{DbUnblindedOutput, KnownOneSidedPaymentScript}, - }, - tasks::{TxoValidationTask, TxoValidationType}, - MasterKeyManager, - TxId, - }, - transaction_service::handle::TransactionServiceHandle, - types::{HashDigest, ValidationRetryStrategy}, +use std::{ + cmp::Ordering, + collections::HashMap, + fmt::{self, Display}, + sync::Arc, + time::Duration, }; + use blake2::Digest; use chrono::Utc; use diesel::result::{DatabaseErrorKind, Error as DieselError}; use futures::{pin_mut, StreamExt}; use log::*; use rand::{rngs::OsRng, RngCore}; -use std::{ - cmp::Ordering, - collections::HashMap, - fmt::{self, Display}, - sync::Arc, - time::Duration, +use tari_crypto::{ + inputs, + keys::{DiffieHellmanSharedSecret, PublicKey as PublicKeyTrait, SecretKey}, + script, + script::TariScript, + tari_utilities::{hex::Hex, ByteArray}, }; +use tokio::sync::broadcast; + +use tari_common_types::types::{PrivateKey, PublicKey}; use tari_comms::{ connectivity::ConnectivityRequester, types::{CommsPublicKey, CommsSecretKey}, @@ -70,22 +62,34 @@ use tari_core::{ UnblindedOutput, }, transaction_protocol::sender::TransactionSenderMessage, - types::{CryptoFactories, PrivateKey, PublicKey}, CoinbaseBuilder, + CryptoFactories, ReceiverTransactionProtocol, SenderTransactionProtocol, }, }; -use tari_crypto::{ - inputs, - keys::{DiffieHellmanSharedSecret, PublicKey as PublicKeyTrait, SecretKey}, - script, - script::TariScript, - tari_utilities::{hex::Hex, ByteArray}, -}; use tari_service_framework::reply_channel; use tari_shutdown::ShutdownSignal; -use tokio::sync::broadcast; + +use crate::{ + base_node_service::handle::BaseNodeServiceHandle, + output_manager_service::{ + config::OutputManagerServiceConfig, + error::{OutputManagerError, OutputManagerProtocolError, OutputManagerStorageError}, + handle::{OutputManagerEventSender, OutputManagerRequest, OutputManagerResponse}, + recovery::StandardUtxoRecoverer, + resources::OutputManagerResources, + storage::{ + database::{OutputManagerBackend, OutputManagerDatabase, PendingTransactionOutputs}, + models::{DbUnblindedOutput, KnownOneSidedPaymentScript}, + }, + tasks::{TxoValidationTask, TxoValidationType}, + MasterKeyManager, + TxId, + }, + transaction_service::handle::TransactionServiceHandle, + types::{HashDigest, ValidationRetryStrategy}, +}; const LOG_TARGET: &str = "wallet::output_manager_service"; const LOG_TARGET_STRESS: &str = "stress_test::output_manager_service"; @@ -166,9 +170,9 @@ where TBackend: OutputManagerBackend + 'static info!(target: LOG_TARGET, "Output Manager Service started"); loop { - futures::select! { - request_context = request_stream.select_next_some() => { - trace!(target: LOG_TARGET, "Handling Service API Request"); + tokio::select! { + Some(request_context) = request_stream.next() => { + trace!(target: LOG_TARGET, "Handling Service API Request"); let (request, reply_tx) = request_context.split(); let response = self.handle_request(request).await.map_err(|e| { warn!(target: LOG_TARGET, "Error handling request: {:?}", e); @@ -179,14 +183,10 @@ where TBackend: OutputManagerBackend + 'static e }); }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Output manager service shutting down because it received the shutdown signal"); break; } - complete => { - info!(target: LOG_TARGET, "Output manager service shutting down"); - break; - } } } info!(target: LOG_TARGET, "Output Manager Service ended"); diff --git a/base_layer/wallet/src/output_manager_service/storage/database.rs b/base_layer/wallet/src/output_manager_service/storage/database.rs index 52d552e016..7344550c63 100644 --- a/base_layer/wallet/src/output_manager_service/storage/database.rs +++ b/base_layer/wallet/src/output_manager_service/storage/database.rs @@ -35,11 +35,8 @@ use std::{ sync::Arc, time::Duration, }; -use tari_core::transactions::{ - tari_amount::MicroTari, - transaction::TransactionOutput, - types::{BlindingFactor, Commitment, PrivateKey}, -}; +use tari_common_types::types::{BlindingFactor, Commitment, PrivateKey}; +use tari_core::transactions::{tari_amount::MicroTari, transaction::TransactionOutput}; const LOG_TARGET: &str = "wallet::output_manager_service::database"; diff --git a/base_layer/wallet/src/output_manager_service/storage/models.rs b/base_layer/wallet/src/output_manager_service/storage/models.rs index dd36eb6934..e0f00a0569 100644 --- a/base_layer/wallet/src/output_manager_service/storage/models.rs +++ b/base_layer/wallet/src/output_manager_service/storage/models.rs @@ -20,17 +20,17 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::output_manager_service::error::OutputManagerStorageError; use std::cmp::Ordering; + +use tari_crypto::script::{ExecutionStack, TariScript}; + +use tari_common_types::types::{Commitment, HashOutput, PrivateKey}; use tari_core::{ tari_utilities::hash::Hashable, - transactions::{ - transaction::UnblindedOutput, - transaction_protocol::RewindData, - types::{Commitment, CryptoFactories, HashOutput, PrivateKey}, - }, + transactions::{transaction::UnblindedOutput, transaction_protocol::RewindData, CryptoFactories}, }; -use tari_crypto::script::{ExecutionStack, TariScript}; + +use crate::output_manager_service::error::OutputManagerStorageError; #[derive(Debug, Clone)] pub struct DbUnblindedOutput { diff --git a/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs b/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs index 665c408bdf..052bad580b 100644 --- a/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs +++ b/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs @@ -20,6 +20,37 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use std::{ + collections::HashMap, + convert::TryFrom, + str::from_utf8, + sync::{Arc, RwLock}, + time::Duration, +}; + +use aes_gcm::{aead::Error as AeadError, Aes256Gcm, Error}; +use chrono::{Duration as ChronoDuration, NaiveDateTime, Utc}; +use diesel::{prelude::*, result::Error as DieselError, SqliteConnection}; +use log::*; +use tari_crypto::{ + commitment::HomomorphicCommitmentFactory, + script::{ExecutionStack, TariScript}, + tari_utilities::{ + hex::{from_hex, Hex}, + ByteArray, + }, +}; + +use tari_common_types::types::{ComSignature, Commitment, PrivateKey, PublicKey}; +use tari_core::{ + tari_utilities::hash::Hashable, + transactions::{ + tari_amount::MicroTari, + transaction::{OutputFeatures, OutputFlags, TransactionOutput, UnblindedOutput}, + CryptoFactories, + }, +}; + use crate::{ output_manager_service::{ error::OutputManagerStorageError, @@ -41,33 +72,6 @@ use crate::{ storage::sqlite_utilities::WalletDbConnection, util::encryption::{decrypt_bytes_integral_nonce, encrypt_bytes_integral_nonce, Encryptable}, }; -use aes_gcm::{aead::Error as AeadError, Aes256Gcm, Error}; -use chrono::{Duration as ChronoDuration, NaiveDateTime, Utc}; -use diesel::{prelude::*, result::Error as DieselError, SqliteConnection}; -use log::*; -use std::{ - collections::HashMap, - convert::TryFrom, - str::from_utf8, - sync::{Arc, RwLock}, - time::Duration, -}; -use tari_core::{ - tari_utilities::hash::Hashable, - transactions::{ - tari_amount::MicroTari, - transaction::{OutputFeatures, OutputFlags, TransactionOutput, UnblindedOutput}, - types::{ComSignature, Commitment, CryptoFactories, PrivateKey, PublicKey}, - }, -}; -use tari_crypto::{ - commitment::HomomorphicCommitmentFactory, - script::{ExecutionStack, TariScript}, - tari_utilities::{ - hex::{from_hex, Hex}, - ByteArray, - }, -}; const LOG_TARGET: &str = "wallet::output_manager_service::database::sqlite_db"; @@ -1714,6 +1718,27 @@ impl Encryptable for KnownOneSidedPaymentScriptSql { #[cfg(test)] mod test { + use std::{convert::TryFrom, time::Duration}; + + use aes_gcm::{ + aead::{generic_array::GenericArray, NewAead}, + Aes256Gcm, + }; + use chrono::{Duration as ChronoDuration, Utc}; + use diesel::{Connection, SqliteConnection}; + use rand::{rngs::OsRng, RngCore}; + use tari_crypto::{keys::SecretKey, script}; + use tempfile::tempdir; + + use tari_common_types::types::{CommitmentFactory, PrivateKey}; + use tari_core::transactions::{ + helpers::{create_unblinded_output, TestParams as TestParamsHelpers}, + tari_amount::MicroTari, + transaction::{OutputFeatures, TransactionInput, UnblindedOutput}, + CryptoFactories, + }; + use tari_test_utils::random; + use crate::{ output_manager_service::storage::{ database::{DbKey, KeyManagerState, OutputManagerBackend}, @@ -1731,23 +1756,6 @@ mod test { storage::sqlite_utilities::WalletDbConnection, util::encryption::Encryptable, }; - use aes_gcm::{ - aead::{generic_array::GenericArray, NewAead}, - Aes256Gcm, - }; - use chrono::{Duration as ChronoDuration, Utc}; - use diesel::{Connection, SqliteConnection}; - use rand::{rngs::OsRng, RngCore}; - use std::{convert::TryFrom, time::Duration}; - use tari_core::transactions::{ - helpers::{create_unblinded_output, TestParams as TestParamsHelpers}, - tari_amount::MicroTari, - transaction::{OutputFeatures, TransactionInput, UnblindedOutput}, - types::{CommitmentFactory, CryptoFactories, PrivateKey}, - }; - use tari_crypto::{keys::SecretKey, script}; - use tari_test_utils::random; - use tempfile::tempdir; pub fn make_input(val: MicroTari) -> (TransactionInput, UnblindedOutput) { let test_params = TestParamsHelpers::new(); diff --git a/base_layer/wallet/src/output_manager_service/tasks/txo_validation_task.rs b/base_layer/wallet/src/output_manager_service/tasks/txo_validation_task.rs index e3e022e4fb..e08059e16b 100644 --- a/base_layer/wallet/src/output_manager_service/tasks/txo_validation_task.rs +++ b/base_layer/wallet/src/output_manager_service/tasks/txo_validation_task.rs @@ -30,17 +30,18 @@ use crate::{ transaction_service::storage::models::TransactionStatus, types::ValidationRetryStrategy, }; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt; use log::*; use std::{cmp, collections::HashMap, convert::TryFrom, fmt, sync::Arc, time::Duration}; +use tari_common_types::types::Signature; use tari_comms::{peer_manager::NodeId, types::CommsPublicKey, PeerConnection}; use tari_core::{ base_node::rpc::BaseNodeWalletRpcClient, proto::base_node::FetchMatchingUtxos, - transactions::{transaction::TransactionOutput, types::Signature}, + transactions::transaction::TransactionOutput, }; use tari_crypto::tari_utilities::{hash::Hashable, hex::Hex}; -use tokio::{sync::broadcast, time::delay_for}; +use tokio::{sync::broadcast, time::sleep}; const LOG_TARGET: &str = "wallet::output_manager_service::utxo_validation_task"; @@ -87,21 +88,17 @@ where TBackend: OutputManagerBackend + 'static /// The task that defines the execution of the protocol. pub async fn execute(mut self) -> Result { - let mut base_node_update_receiver = self - .base_node_update_receiver - .take() - .ok_or_else(|| { - OutputManagerProtocolError::new( - self.id, - OutputManagerError::ServiceError("A Base Node Update receiver was not provided".to_string()), - ) - })? - .fuse(); + let mut base_node_update_receiver = self.base_node_update_receiver.take().ok_or_else(|| { + OutputManagerProtocolError::new( + self.id, + OutputManagerError::ServiceError("A Base Node Update receiver was not provided".to_string()), + ) + })?; let mut shutdown = self.resources.shutdown_signal.clone(); let total_retries_str = match self.retry_strategy { - ValidationRetryStrategy::Limited(n) => format!("{}", n), + ValidationRetryStrategy::Limited(n) => n.to_string(), ValidationRetryStrategy::UntilSuccess => "∞".to_string(), }; @@ -180,14 +177,14 @@ where TBackend: OutputManagerBackend + 'static let base_node_node_id = NodeId::from_key(&self.base_node_public_key.clone()); let mut connection: Option = None; - let delay = delay_for(self.resources.config.peer_dial_retry_timeout); + let delay = sleep(self.resources.config.peer_dial_retry_timeout); debug!( target: LOG_TARGET, "Connecting to Base Node (Public Key: {})", self.base_node_public_key, ); - futures::select! { - dial_result = self.resources.connectivity_manager.dial_peer(base_node_node_id.clone()).fuse() => { + tokio::select! { + dial_result = self.resources.connectivity_manager.dial_peer(base_node_node_id.clone()) => { match dial_result { Ok(base_node_connection) => { connection = Some(base_node_connection); @@ -197,7 +194,7 @@ where TBackend: OutputManagerBackend + 'static }, } }, - new_base_node = base_node_update_receiver.select_next_some() => { + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(_) => { info!( @@ -228,7 +225,7 @@ where TBackend: OutputManagerBackend + 'static } } } - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "TXO Validation Protocol (Id: {}) shutting down because it received the shutdown signal", self.id); return Err(OutputManagerProtocolError::new(self.id, OutputManagerError::Shutdown)); }, @@ -236,7 +233,7 @@ where TBackend: OutputManagerBackend + 'static let mut base_node_connection = match connection { None => { - futures::select! { + tokio::select! { _ = delay.fuse() => { let _ = self .resources @@ -253,7 +250,7 @@ where TBackend: OutputManagerBackend + 'static retries += 1; continue; }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "TXO Validation Protocol (Id: {}) shutting down because it received the shutdown signal", self.id); return Err(OutputManagerProtocolError::new(self.id, OutputManagerError::Shutdown)); }, @@ -294,9 +291,9 @@ where TBackend: OutputManagerBackend + 'static batch_num, batch_total ); - let delay = delay_for(self.retry_delay); - futures::select! { - new_base_node = base_node_update_receiver.select_next_some() => { + let delay = sleep(self.retry_delay); + tokio::select! { + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(_bn) => { info!(target: LOG_TARGET, "TXO Validation protocol aborted due to Base Node Public key change" ); @@ -323,7 +320,7 @@ where TBackend: OutputManagerBackend + 'static } } }, - result = self.send_query_batch(batch.clone(), &mut client).fuse() => { + result = self.send_query_batch(batch.clone(), &mut client) => { match result { Ok(synced) => { self.base_node_synced = synced; @@ -374,7 +371,7 @@ where TBackend: OutputManagerBackend + 'static }, } }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "TXO Validation Protocol (Id: {}) shutting down because it received the shutdown signal", self.id); return Err(OutputManagerProtocolError::new(self.id, OutputManagerError::Shutdown)); }, diff --git a/base_layer/wallet/src/storage/database.rs b/base_layer/wallet/src/storage/database.rs index 3d5cbf8838..d23d806298 100644 --- a/base_layer/wallet/src/storage/database.rs +++ b/base_layer/wallet/src/storage/database.rs @@ -373,7 +373,7 @@ mod test { #[test] fn test_database_crud() { - let mut runtime = Runtime::new().unwrap(); + let runtime = Runtime::new().unwrap(); let db_name = format!("{}.sqlite3", string(8).as_str()); let db_folder = tempdir().unwrap().path().to_str().unwrap().to_string(); diff --git a/base_layer/wallet/src/transaction_service/error.rs b/base_layer/wallet/src/transaction_service/error.rs index c197dd2024..05e9e4af2b 100644 --- a/base_layer/wallet/src/transaction_service/error.rs +++ b/base_layer/wallet/src/transaction_service/error.rs @@ -34,7 +34,7 @@ use tari_p2p::services::liveness::error::LivenessError; use tari_service_framework::reply_channel::TransportChannelError; use thiserror::Error; use time::OutOfRangeError; -use tokio::sync::broadcast::RecvError; +use tokio::sync::broadcast::error::RecvError; #[derive(Debug, Error)] pub enum TransactionServiceError { diff --git a/base_layer/wallet/src/transaction_service/handle.rs b/base_layer/wallet/src/transaction_service/handle.rs index f34a5f667f..7a6ab48649 100644 --- a/base_layer/wallet/src/transaction_service/handle.rs +++ b/base_layer/wallet/src/transaction_service/handle.rs @@ -28,7 +28,6 @@ use crate::{ }, }; use aes_gcm::Aes256Gcm; -use futures::{stream::Fuse, StreamExt}; use std::{collections::HashMap, fmt, sync::Arc}; use tari_comms::types::CommsPublicKey; use tari_core::transactions::{tari_amount::MicroTari, transaction::Transaction}; @@ -187,8 +186,8 @@ impl TransactionServiceHandle { } } - pub fn get_event_stream_fused(&self) -> Fuse { - self.event_stream_sender.subscribe().fuse() + pub fn get_event_stream(&self) -> TransactionEventReceiver { + self.event_stream_sender.subscribe() } pub async fn send_transaction( diff --git a/base_layer/wallet/src/transaction_service/mod.rs b/base_layer/wallet/src/transaction_service/mod.rs index 3efbade3c3..541d898770 100644 --- a/base_layer/wallet/src/transaction_service/mod.rs +++ b/base_layer/wallet/src/transaction_service/mod.rs @@ -20,31 +20,17 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -pub mod config; -pub mod error; -pub mod handle; -pub mod protocols; -pub mod service; -pub mod storage; -pub mod tasks; +use std::sync::Arc; -use crate::{ - output_manager_service::handle::OutputManagerHandle, - transaction_service::{ - config::TransactionServiceConfig, - handle::TransactionServiceHandle, - service::TransactionService, - storage::database::{TransactionBackend, TransactionDatabase}, - }, -}; use futures::{Stream, StreamExt}; use log::*; -use std::sync::Arc; +use tokio::sync::broadcast; + use tari_comms::{connectivity::ConnectivityRequester, peer_manager::NodeIdentity}; use tari_comms_dht::Dht; use tari_core::{ proto::base_node as base_node_proto, - transactions::{transaction_protocol::proto, types::CryptoFactories}, + transactions::{transaction_protocol::proto, CryptoFactories}, }; use tari_p2p::{ comms_connector::SubscriptionFactory, @@ -59,7 +45,24 @@ use tari_service_framework::{ ServiceInitializer, ServiceInitializerContext, }; -use tokio::sync::broadcast; + +use crate::{ + output_manager_service::handle::OutputManagerHandle, + transaction_service::{ + config::TransactionServiceConfig, + handle::TransactionServiceHandle, + service::TransactionService, + storage::database::{TransactionBackend, TransactionDatabase}, + }, +}; + +pub mod config; +pub mod error; +pub mod handle; +pub mod protocols; +pub mod service; +pub mod storage; +pub mod tasks; const LOG_TARGET: &str = "wallet::transaction_service"; const SUBSCRIPTION_LABEL: &str = "Transaction Service"; diff --git a/base_layer/wallet/src/transaction_service/protocols/transaction_broadcast_protocol.rs b/base_layer/wallet/src/transaction_service/protocols/transaction_broadcast_protocol.rs index 05191c8f8d..4a28383226 100644 --- a/base_layer/wallet/src/transaction_service/protocols/transaction_broadcast_protocol.rs +++ b/base_layer/wallet/src/transaction_service/protocols/transaction_broadcast_protocol.rs @@ -32,19 +32,20 @@ use crate::{ }, }, }; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt; use log::*; use std::{convert::TryFrom, sync::Arc, time::Duration}; +use tari_common_types::types::Signature; use tari_comms::{peer_manager::NodeId, types::CommsPublicKey, PeerConnection}; use tari_core::{ base_node::{ proto::wallet_rpc::{TxLocation, TxQueryResponse, TxSubmissionRejectionReason, TxSubmissionResponse}, rpc::BaseNodeWalletRpcClient, }, - transactions::{transaction::Transaction, types::Signature}, + transactions::transaction::Transaction, }; use tari_crypto::tari_utilities::hex::Hex; -use tokio::{sync::broadcast, time::delay_for}; +use tokio::{sync::broadcast, time::sleep}; const LOG_TARGET: &str = "wallet::transaction_service::protocols::broadcast_protocol"; @@ -86,21 +87,13 @@ where TBackend: TransactionBackend + 'static /// The task that defines the execution of the protocol. pub async fn execute(mut self) -> Result { - let mut timeout_update_receiver = self - .timeout_update_receiver - .take() - .ok_or_else(|| { - TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) - })? - .fuse(); + let mut timeout_update_receiver = self.timeout_update_receiver.take().ok_or_else(|| { + TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) + })?; - let mut base_node_update_receiver = self - .base_node_update_receiver - .take() - .ok_or_else(|| { - TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) - })? - .fuse(); + let mut base_node_update_receiver = self.base_node_update_receiver.take().ok_or_else(|| { + TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) + })?; let mut shutdown = self.resources.shutdown_signal.clone(); // Main protocol loop @@ -108,14 +101,14 @@ where TBackend: TransactionBackend + 'static let base_node_node_id = NodeId::from_key(&self.base_node_public_key); let mut connection: Option = None; - let delay = delay_for(self.timeout); + let delay = sleep(self.timeout); debug!( target: LOG_TARGET, "Connecting to Base Node (Public Key: {})", self.base_node_public_key, ); - futures::select! { - dial_result = self.resources.connectivity_manager.dial_peer(base_node_node_id.clone()).fuse() => { + tokio::select! { + dial_result = self.resources.connectivity_manager.dial_peer(base_node_node_id.clone()) => { match dial_result { Ok(base_node_connection) => { connection = Some(base_node_connection); @@ -139,7 +132,7 @@ where TBackend: TransactionBackend + 'static }, } }, - updated_timeout = timeout_update_receiver.select_next_some() => { + updated_timeout = timeout_update_receiver.recv() => { match updated_timeout { Ok(to) => { self.timeout = to; @@ -158,7 +151,7 @@ where TBackend: TransactionBackend + 'static } } }, - new_base_node = base_node_update_receiver.select_next_some() => { + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(bn) => { self.base_node_public_key = bn; @@ -179,7 +172,7 @@ where TBackend: TransactionBackend + 'static } } } - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Broadcast Protocol (TxId: {}) shutting down because it received the shutdown signal", self.tx_id); return Err(TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::Shutdown)) }, @@ -187,11 +180,11 @@ where TBackend: TransactionBackend + 'static let mut base_node_connection = match connection { None => { - futures::select! { + tokio::select! { _ = delay.fuse() => { continue; }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Broadcast Protocol (TxId: {}) shutting down because it received the shutdown signal", self.tx_id); return Err(TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::Shutdown)) }, @@ -243,10 +236,10 @@ where TBackend: TransactionBackend + 'static }, }; - let delay = delay_for(self.timeout); + let delay = sleep(self.timeout); loop { - futures::select! { - new_base_node = base_node_update_receiver.select_next_some() => { + tokio::select! { + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(bn) => { self.base_node_public_key = bn; @@ -315,7 +308,7 @@ where TBackend: TransactionBackend + 'static delay.await; break; }, - updated_timeout = timeout_update_receiver.select_next_some() => { + updated_timeout = timeout_update_receiver.recv() => { if let Ok(to) = updated_timeout { self.timeout = to; info!( @@ -332,7 +325,7 @@ where TBackend: TransactionBackend + 'static ); } }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Broadcast Protocol (TxId: {}) shutting down because it received the shutdown signal", self.tx_id); return Err(TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::Shutdown)) }, diff --git a/base_layer/wallet/src/transaction_service/protocols/transaction_coinbase_monitoring_protocol.rs b/base_layer/wallet/src/transaction_service/protocols/transaction_coinbase_monitoring_protocol.rs index 65fb58f601..f368d33d89 100644 --- a/base_layer/wallet/src/transaction_service/protocols/transaction_coinbase_monitoring_protocol.rs +++ b/base_layer/wallet/src/transaction_service/protocols/transaction_coinbase_monitoring_protocol.rs @@ -29,19 +29,17 @@ use crate::{ storage::{database::TransactionBackend, models::CompletedTransaction}, }, }; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt; use log::*; use std::{convert::TryFrom, sync::Arc, time::Duration}; +use tari_common_types::types::Signature; use tari_comms::{peer_manager::NodeId, types::CommsPublicKey, PeerConnection}; -use tari_core::{ - base_node::{ - proto::wallet_rpc::{TxLocation, TxQueryResponse}, - rpc::BaseNodeWalletRpcClient, - }, - transactions::types::Signature, +use tari_core::base_node::{ + proto::wallet_rpc::{TxLocation, TxQueryResponse}, + rpc::BaseNodeWalletRpcClient, }; use tari_crypto::tari_utilities::{hex::Hex, Hashable}; -use tokio::{sync::broadcast, time::delay_for}; +use tokio::{sync::broadcast, time::sleep}; const LOG_TARGET: &str = "wallet::transaction_service::protocols::coinbase_monitoring"; @@ -86,21 +84,13 @@ where TBackend: TransactionBackend + 'static /// The task that defines the execution of the protocol. pub async fn execute(mut self) -> Result { - let mut base_node_update_receiver = self - .base_node_update_receiver - .take() - .ok_or_else(|| { - TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) - })? - .fuse(); + let mut base_node_update_receiver = self.base_node_update_receiver.take().ok_or_else(|| { + TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) + })?; - let mut timeout_update_receiver = self - .timeout_update_receiver - .take() - .ok_or_else(|| { - TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) - })? - .fuse(); + let mut timeout_update_receiver = self.timeout_update_receiver.take().ok_or_else(|| { + TransactionServiceProtocolError::new(self.tx_id, TransactionServiceError::InvalidStateError) + })?; trace!( target: LOG_TARGET, @@ -173,7 +163,7 @@ where TBackend: TransactionBackend + 'static self.base_node_public_key, self.tx_id, ); - futures::select! { + tokio::select! { dial_result = self.resources.connectivity_manager.dial_peer(base_node_node_id.clone()).fuse() => { match dial_result { Ok(base_node_connection) => { @@ -203,7 +193,7 @@ where TBackend: TransactionBackend + 'static }, } }, - updated_timeout = timeout_update_receiver.select_next_some() => { + updated_timeout = timeout_update_receiver.recv() => { match updated_timeout { Ok(to) => { self.timeout = to; @@ -225,7 +215,7 @@ where TBackend: TransactionBackend + 'static } } }, - new_base_node = base_node_update_receiver.select_next_some() => { + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(bn) => { self.base_node_public_key = bn; @@ -248,7 +238,7 @@ where TBackend: TransactionBackend + 'static } } } - _ = shutdown => { + _ = shutdown.wait() => { info!( target: LOG_TARGET, "Coinbase Monitoring protocol (TxId: {}) shutting down because it received the shutdown \ @@ -259,14 +249,14 @@ where TBackend: TransactionBackend + 'static }, } - let delay = delay_for(self.timeout); + let delay = sleep(self.timeout); let mut base_node_connection = match connection { None => { - futures::select! { + tokio::select! { _ = delay.fuse() => { continue; }, - _ = shutdown => { + _ = shutdown.wait() => { info!( target: LOG_TARGET, "Coinbase Monitoring Protocol (TxId: {}) shutting down because it received the \ @@ -314,10 +304,10 @@ where TBackend: TransactionBackend + 'static TransactionServiceError::InvalidCompletedTransaction, )); } - let delay = delay_for(self.timeout).fuse(); + let delay = sleep(self.timeout).fuse(); loop { - futures::select! { - new_base_node = base_node_update_receiver.select_next_some() => { + tokio::select! { + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(bn) => { self.base_node_public_key = bn; @@ -392,7 +382,7 @@ where TBackend: TransactionBackend + 'static delay.await; break; }, - updated_timeout = timeout_update_receiver.select_next_some() => { + updated_timeout = timeout_update_receiver.recv() => { if let Ok(to) = updated_timeout { self.timeout = to; info!( @@ -411,7 +401,7 @@ where TBackend: TransactionBackend + 'static ); } }, - _ = shutdown => { + _ = shutdown.wait() => { info!( target: LOG_TARGET, "Coinbase Monitoring Protocol (TxId: {}) shutting down because it received the shutdown \ diff --git a/base_layer/wallet/src/transaction_service/protocols/transaction_receive_protocol.rs b/base_layer/wallet/src/transaction_service/protocols/transaction_receive_protocol.rs index 744bb6f1fc..0a6bd89fb2 100644 --- a/base_layer/wallet/src/transaction_service/protocols/transaction_receive_protocol.rs +++ b/base_layer/wallet/src/transaction_service/protocols/transaction_receive_protocol.rs @@ -34,21 +34,18 @@ use crate::{ }, }; use chrono::Utc; -use futures::{ - channel::{mpsc, oneshot}, - future::FutureExt, - StreamExt, -}; +use futures::future::FutureExt; use log::*; use std::sync::Arc; use tari_comms::types::CommsPublicKey; +use tokio::sync::{mpsc, oneshot}; use tari_core::transactions::{ transaction::Transaction, transaction_protocol::{recipient::RecipientState, sender::TransactionSenderMessage}, }; use tari_crypto::tari_utilities::Hashable; -use tokio::time::delay_for; +use tokio::time::sleep; const LOG_TARGET: &str = "wallet::transaction_service::protocols::receive_protocol"; const LOG_TARGET_STRESS: &str = "stress_test::receive_protocol"; @@ -263,7 +260,8 @@ where TBackend: TransactionBackend + 'static }, Some(t) => t, }; - let mut timeout_delay = delay_for(timeout_duration).fuse(); + let timeout_delay = sleep(timeout_duration).fuse(); + tokio::pin!(timeout_delay); // check to see if a resend is due let resend = match inbound_tx.last_send_timestamp { @@ -310,9 +308,9 @@ where TBackend: TransactionBackend + 'static let mut incoming_finalized_transaction = None; loop { loop { - let mut resend_timeout = delay_for(self.resources.config.transaction_resend_period).fuse(); - futures::select! { - (spk, tx_id, tx) = receiver.select_next_some() => { + let resend_timeout = sleep(self.resources.config.transaction_resend_period).fuse(); + tokio::select! { + Some((spk, tx_id, tx)) = receiver.recv() => { incoming_finalized_transaction = Some(tx); if inbound_tx.source_public_key != spk { warn!( @@ -325,16 +323,14 @@ where TBackend: TransactionBackend + 'static break; } }, - result = cancellation_receiver => { - if result.is_ok() { - info!(target: LOG_TARGET, "Cancelling Transaction Receive Protocol for TxId: {}", self.id); - return Err(TransactionServiceProtocolError::new( - self.id, - TransactionServiceError::TransactionCancelled, - )); - } + Ok(_) = &mut cancellation_receiver => { + info!(target: LOG_TARGET, "Cancelling Transaction Receive Protocol for TxId: {}", self.id); + return Err(TransactionServiceProtocolError::new( + self.id, + TransactionServiceError::TransactionCancelled, + )); }, - () = resend_timeout => { + _ = resend_timeout => { match send_transaction_reply( inbound_tx.clone(), self.resources.outbound_message_service.clone(), @@ -353,10 +349,10 @@ where TBackend: TransactionBackend + 'static ), } }, - () = timeout_delay => { + _ = &mut timeout_delay => { return self.timeout_transaction().await; } - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Receive Protocol (id: {}) shutting down because it received the shutdown signal", self.id); return Err(TransactionServiceProtocolError::new(self.id, TransactionServiceError::Shutdown)) } @@ -381,7 +377,7 @@ where TBackend: TransactionBackend + 'static ); finalized_transaction - .validate_internal_consistency(&self.resources.factories, None) + .validate_internal_consistency(true, &self.resources.factories, None) .map_err(|e| TransactionServiceProtocolError::new(self.id, TransactionServiceError::from(e)))?; // Find your own output in the transaction diff --git a/base_layer/wallet/src/transaction_service/protocols/transaction_send_protocol.rs b/base_layer/wallet/src/transaction_service/protocols/transaction_send_protocol.rs index cf30a7f928..e88c9013e6 100644 --- a/base_layer/wallet/src/transaction_service/protocols/transaction_send_protocol.rs +++ b/base_layer/wallet/src/transaction_service/protocols/transaction_send_protocol.rs @@ -20,12 +20,6 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use std::sync::Arc; - -use chrono::Utc; -use futures::{channel::mpsc::Receiver, FutureExt, StreamExt}; -use log::*; - use crate::transaction_service::{ config::TransactionRoutingMechanism, error::{TransactionServiceError, TransactionServiceProtocolError}, @@ -41,7 +35,10 @@ use crate::transaction_service::{ wait_on_dial::wait_on_dial, }, }; -use futures::channel::oneshot; +use chrono::Utc; +use futures::FutureExt; +use log::*; +use std::sync::Arc; use tari_comms::{peer_manager::NodeId, types::CommsPublicKey}; use tari_comms_dht::{ domain_message::OutboundDomainMessage, @@ -55,7 +52,10 @@ use tari_core::transactions::{ }; use tari_crypto::script; use tari_p2p::tari_message::TariMessageType; -use tokio::time::delay_for; +use tokio::{ + sync::{mpsc::Receiver, oneshot}, + time::sleep, +}; const LOG_TARGET: &str = "wallet::transaction_service::protocols::send_protocol"; const LOG_TARGET_STRESS: &str = "stress_test::send_protocol"; @@ -344,7 +344,8 @@ where TBackend: TransactionBackend + 'static }, Some(t) => t, }; - let mut timeout_delay = delay_for(timeout_duration).fuse(); + let timeout_delay = sleep(timeout_duration).fuse(); + tokio::pin!(timeout_delay); // check to see if a resend is due let resend = match outbound_tx.last_send_timestamp { @@ -390,9 +391,9 @@ where TBackend: TransactionBackend + 'static #[allow(unused_assignments)] let mut reply = None; loop { - let mut resend_timeout = delay_for(self.resources.config.transaction_resend_period).fuse(); - futures::select! { - (spk, rr) = receiver.select_next_some() => { + let resend_timeout = sleep(self.resources.config.transaction_resend_period).fuse(); + tokio::select! { + Some((spk, rr)) = receiver.recv() => { let rr_tx_id = rr.tx_id; reply = Some(rr); @@ -407,7 +408,7 @@ where TBackend: TransactionBackend + 'static break; } }, - result = cancellation_receiver => { + result = &mut cancellation_receiver => { if result.is_ok() { info!(target: LOG_TARGET, "Cancelling Transaction Send Protocol (TxId: {})", self.id); let _ = send_transaction_cancelled_message(self.id,self.dest_pubkey.clone(), self.resources.outbound_message_service.clone(), ).await.map_err(|e| { @@ -441,10 +442,10 @@ where TBackend: TransactionBackend + 'static .map_err(|e| TransactionServiceProtocolError::new(self.id, TransactionServiceError::from(e)))?; } }, - () = timeout_delay => { + () = &mut timeout_delay => { return self.timeout_transaction().await; } - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Send Protocol (id: {}) shutting down because it received the shutdown signal", self.id); return Err(TransactionServiceProtocolError::new(self.id, TransactionServiceError::Shutdown)) } diff --git a/base_layer/wallet/src/transaction_service/protocols/transaction_validation_protocol.rs b/base_layer/wallet/src/transaction_service/protocols/transaction_validation_protocol.rs index d0b2f7f6ac..dcf072c272 100644 --- a/base_layer/wallet/src/transaction_service/protocols/transaction_validation_protocol.rs +++ b/base_layer/wallet/src/transaction_service/protocols/transaction_validation_protocol.rs @@ -32,7 +32,7 @@ use crate::{ }, types::ValidationRetryStrategy, }; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt; use log::*; use std::{cmp, convert::TryFrom, sync::Arc, time::Duration}; use tari_comms::{peer_manager::NodeId, types::CommsPublicKey, PeerConnection}; @@ -43,7 +43,7 @@ use tari_core::{ }, proto::{base_node::Signatures as SignaturesProto, types::Signature as SignatureProto}, }; -use tokio::{sync::broadcast, time::delay_for}; +use tokio::{sync::broadcast, time::sleep}; const LOG_TARGET: &str = "wallet::transaction_service::protocols::validation_protocol"; @@ -94,14 +94,12 @@ where TBackend: TransactionBackend + 'static let mut timeout_update_receiver = self .timeout_update_receiver .take() - .ok_or_else(|| TransactionServiceProtocolError::new(self.id, TransactionServiceError::InvalidStateError))? - .fuse(); + .ok_or_else(|| TransactionServiceProtocolError::new(self.id, TransactionServiceError::InvalidStateError))?; let mut base_node_update_receiver = self .base_node_update_receiver .take() - .ok_or_else(|| TransactionServiceProtocolError::new(self.id, TransactionServiceError::InvalidStateError))? - .fuse(); + .ok_or_else(|| TransactionServiceProtocolError::new(self.id, TransactionServiceError::InvalidStateError))?; let mut shutdown = self.resources.shutdown_signal.clone(); @@ -158,13 +156,13 @@ where TBackend: TransactionBackend + 'static let base_node_node_id = NodeId::from_key(&self.base_node_public_key); let mut connection: Option = None; - let delay = delay_for(self.timeout); + let delay = sleep(self.timeout); debug!( target: LOG_TARGET, "Connecting to Base Node (Public Key: {})", self.base_node_public_key, ); - futures::select! { + tokio::select! { dial_result = self.resources.connectivity_manager.dial_peer(base_node_node_id.clone()).fuse() => { match dial_result { Ok(base_node_connection) => { @@ -175,7 +173,7 @@ where TBackend: TransactionBackend + 'static }, } }, - new_base_node = base_node_update_receiver.select_next_some() => { + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(_) => { @@ -204,7 +202,7 @@ where TBackend: TransactionBackend + 'static } } } - updated_timeout = timeout_update_receiver.select_next_some() => { + updated_timeout = timeout_update_receiver.recv() => { match updated_timeout { Ok(to) => { self.timeout = to; @@ -223,7 +221,7 @@ where TBackend: TransactionBackend + 'static } } }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Validation Protocol shutting down because it received the shutdown signal"); return Err(TransactionServiceProtocolError::new(self.id, TransactionServiceError::Shutdown)) }, @@ -231,7 +229,7 @@ where TBackend: TransactionBackend + 'static let mut base_node_connection = match connection { None => { - futures::select! { + tokio::select! { _ = delay.fuse() => { let _ = self .resources @@ -248,7 +246,7 @@ where TBackend: TransactionBackend + 'static retries += 1; continue; }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Validation Protocol shutting down because it received the shutdown signal"); return Err(TransactionServiceProtocolError::new(self.id, TransactionServiceError::Shutdown)) }, @@ -282,9 +280,9 @@ where TBackend: TransactionBackend + 'static } else { break 'main; }; - let delay = delay_for(self.timeout); - futures::select! { - new_base_node = base_node_update_receiver.select_next_some() => { + let delay = sleep(self.timeout); + tokio::select! { + new_base_node = base_node_update_receiver.recv() => { match new_base_node { Ok(_) => { info!(target: LOG_TARGET, "Aborting Transaction Validation Protocol as new Base node is set"); @@ -372,7 +370,7 @@ where TBackend: TransactionBackend + 'static }, } }, - updated_timeout = timeout_update_receiver.select_next_some() => { + updated_timeout = timeout_update_receiver.recv() => { match updated_timeout { Ok(to) => { self.timeout = to; @@ -391,7 +389,7 @@ where TBackend: TransactionBackend + 'static } } }, - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction Validation Protocol shutting down because it received the shutdown signal"); return Err(TransactionServiceProtocolError::new(self.id, TransactionServiceError::Shutdown)) }, diff --git a/base_layer/wallet/src/transaction_service/service.rs b/base_layer/wallet/src/transaction_service/service.rs index c30fb7f412..12aa052382 100644 --- a/base_layer/wallet/src/transaction_service/service.rs +++ b/base_layer/wallet/src/transaction_service/service.rs @@ -20,49 +20,22 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{ - output_manager_service::{handle::OutputManagerHandle, TxId}, - transaction_service::{ - config::TransactionServiceConfig, - error::{TransactionServiceError, TransactionServiceProtocolError}, - handle::{TransactionEvent, TransactionEventSender, TransactionServiceRequest, TransactionServiceResponse}, - protocols::{ - transaction_broadcast_protocol::TransactionBroadcastProtocol, - transaction_coinbase_monitoring_protocol::TransactionCoinbaseMonitoringProtocol, - transaction_receive_protocol::{TransactionReceiveProtocol, TransactionReceiveProtocolStage}, - transaction_send_protocol::{TransactionSendProtocol, TransactionSendProtocolStage}, - transaction_validation_protocol::TransactionValidationProtocol, - }, - storage::{ - database::{TransactionBackend, TransactionDatabase}, - models::{CompletedTransaction, TransactionDirection, TransactionStatus}, - }, - tasks::{ - send_finalized_transaction::send_finalized_transaction_message, - send_transaction_cancelled::send_transaction_cancelled_message, - send_transaction_reply::send_transaction_reply, - }, - }, - types::{HashDigest, ValidationRetryStrategy}, -}; -use chrono::{NaiveDateTime, Utc}; -use digest::Digest; -use futures::{ - channel::{mpsc, mpsc::Sender, oneshot}, - pin_mut, - stream::FuturesUnordered, - SinkExt, - Stream, - StreamExt, -}; -use log::*; -use rand::{rngs::OsRng, RngCore}; use std::{ collections::{HashMap, HashSet}, convert::TryInto, sync::Arc, time::{Duration, Instant}, }; + +use chrono::{NaiveDateTime, Utc}; +use digest::Digest; +use futures::{pin_mut, stream::FuturesUnordered, Stream, StreamExt}; +use log::*; +use rand::{rngs::OsRng, RngCore}; +use tari_crypto::{keys::DiffieHellmanSharedSecret, script, tari_utilities::ByteArray}; +use tokio::{sync::broadcast, task::JoinHandle}; + +use tari_common_types::types::PrivateKey; use tari_comms::{connectivity::ConnectivityRequester, peer_manager::NodeIdentity, types::CommsPublicKey}; use tari_comms_dht::outbound::OutboundMessageRequester; use tari_core::{ @@ -77,15 +50,40 @@ use tari_core::{ sender::TransactionSenderMessage, RewindData, }, - types::{CryptoFactories, PrivateKey}, + CryptoFactories, ReceiverTransactionProtocol, }, }; -use tari_crypto::{keys::DiffieHellmanSharedSecret, script, tari_utilities::ByteArray}; use tari_p2p::domain_message::DomainMessage; use tari_service_framework::{reply_channel, reply_channel::Receiver}; use tari_shutdown::ShutdownSignal; -use tokio::{sync::broadcast, task::JoinHandle}; +use tokio::sync::{mpsc, mpsc::Sender, oneshot}; + +use crate::{ + output_manager_service::{handle::OutputManagerHandle, TxId}, + transaction_service::{ + config::TransactionServiceConfig, + error::{TransactionServiceError, TransactionServiceProtocolError}, + handle::{TransactionEvent, TransactionEventSender, TransactionServiceRequest, TransactionServiceResponse}, + protocols::{ + transaction_broadcast_protocol::TransactionBroadcastProtocol, + transaction_coinbase_monitoring_protocol::TransactionCoinbaseMonitoringProtocol, + transaction_receive_protocol::{TransactionReceiveProtocol, TransactionReceiveProtocolStage}, + transaction_send_protocol::{TransactionSendProtocol, TransactionSendProtocolStage}, + transaction_validation_protocol::TransactionValidationProtocol, + }, + storage::{ + database::{TransactionBackend, TransactionDatabase}, + models::{CompletedTransaction, TransactionDirection, TransactionStatus}, + }, + tasks::{ + send_finalized_transaction::send_finalized_transaction_message, + send_transaction_cancelled::send_transaction_cancelled_message, + send_transaction_reply::send_transaction_reply, + }, + }, + types::{HashDigest, ValidationRetryStrategy}, +}; const LOG_TARGET: &str = "wallet::transaction_service::service"; @@ -276,9 +274,9 @@ where info!(target: LOG_TARGET, "Transaction Service started"); loop { - futures::select! { + tokio::select! { //Incoming request - request_context = request_stream.select_next_some() => { + Some(request_context) = request_stream.next() => { // TODO: Remove time measurements; this is to aid in system testing only let start = Instant::now(); let (request, reply_tx) = request_context.split(); @@ -303,7 +301,7 @@ where ); }, // Incoming Transaction messages from the Comms layer - msg = transaction_stream.select_next_some() => { + Some(msg) = transaction_stream.next() => { // TODO: Remove time measurements; this is to aid in system testing only let start = Instant::now(); let (origin_public_key, inner_msg) = msg.clone().into_origin_and_inner(); @@ -333,7 +331,7 @@ where ); }, // Incoming Transaction Reply messages from the Comms layer - msg = transaction_reply_stream.select_next_some() => { + Some(msg) = transaction_reply_stream.next() => { // TODO: Remove time measurements; this is to aid in system testing only let start = Instant::now(); let (origin_public_key, inner_msg) = msg.clone().into_origin_and_inner(); @@ -364,7 +362,7 @@ where ); }, // Incoming Finalized Transaction messages from the Comms layer - msg = transaction_finalized_stream.select_next_some() => { + Some(msg) = transaction_finalized_stream.next() => { // TODO: Remove time measurements; this is to aid in system testing only let start = Instant::now(); let (origin_public_key, inner_msg) = msg.clone().into_origin_and_inner(); @@ -402,7 +400,7 @@ where ); }, // Incoming messages from the Comms layer - msg = base_node_response_stream.select_next_some() => { + Some(msg) = base_node_response_stream.next() => { // TODO: Remove time measurements; this is to aid in system testing only let start = Instant::now(); let (origin_public_key, inner_msg) = msg.clone().into_origin_and_inner(); @@ -421,7 +419,7 @@ where ); } // Incoming messages from the Comms layer - msg = transaction_cancelled_stream.select_next_some() => { + Some(msg) = transaction_cancelled_stream.next() => { // TODO: Remove time measurements; this is to aid in system testing only let start = Instant::now(); let (origin_public_key, inner_msg) = msg.clone().into_origin_and_inner(); @@ -436,7 +434,7 @@ where finish.duration_since(start).as_millis(), ); } - join_result = send_transaction_protocol_handles.select_next_some() => { + Some(join_result) = send_transaction_protocol_handles.next() => { trace!(target: LOG_TARGET, "Send Protocol for Transaction has ended with result {:?}", join_result); match join_result { Ok(join_result_inner) => self.complete_send_transaction_protocol( @@ -446,7 +444,7 @@ where Err(e) => error!(target: LOG_TARGET, "Error resolving Send Transaction Protocol: {:?}", e), }; } - join_result = receive_transaction_protocol_handles.select_next_some() => { + Some(join_result) = receive_transaction_protocol_handles.next() => { trace!(target: LOG_TARGET, "Receive Transaction Protocol has ended with result {:?}", join_result); match join_result { Ok(join_result_inner) => self.complete_receive_transaction_protocol( @@ -456,14 +454,14 @@ where Err(e) => error!(target: LOG_TARGET, "Error resolving Send Transaction Protocol: {:?}", e), }; } - join_result = transaction_broadcast_protocol_handles.select_next_some() => { + Some(join_result) = transaction_broadcast_protocol_handles.next() => { trace!(target: LOG_TARGET, "Transaction Broadcast protocol has ended with result {:?}", join_result); match join_result { Ok(join_result_inner) => self.complete_transaction_broadcast_protocol(join_result_inner).await, Err(e) => error!(target: LOG_TARGET, "Error resolving Broadcast Protocol: {:?}", e), }; } - join_result = coinbase_transaction_monitoring_protocol_handles.select_next_some() => { + Some(join_result) = coinbase_transaction_monitoring_protocol_handles.next() => { trace!(target: LOG_TARGET, "Coinbase transaction monitoring protocol has ended with result {:?}", join_result); match join_result { @@ -471,21 +469,17 @@ where Err(e) => error!(target: LOG_TARGET, "Error resolving Coinbase Monitoring protocol: {:?}", e), }; } - join_result = transaction_validation_protocol_handles.select_next_some() => { + Some(join_result) = transaction_validation_protocol_handles.next() => { trace!(target: LOG_TARGET, "Transaction Validation protocol has ended with result {:?}", join_result); match join_result { Ok(join_result_inner) => self.complete_transaction_validation_protocol(join_result_inner).await, Err(e) => error!(target: LOG_TARGET, "Error resolving Transaction Validation protocol: {:?}", e), }; } - _ = shutdown => { + _ = shutdown.wait() => { info!(target: LOG_TARGET, "Transaction service shutting down because it received the shutdown signal"); break; } - complete => { - info!(target: LOG_TARGET, "Transaction service shutting down"); - break; - } } } info!(target: LOG_TARGET, "Transaction service shut down"); diff --git a/base_layer/wallet/src/transaction_service/storage/database.rs b/base_layer/wallet/src/transaction_service/storage/database.rs index aaad0b0618..7cbaa52c85 100644 --- a/base_layer/wallet/src/transaction_service/storage/database.rs +++ b/base_layer/wallet/src/transaction_service/storage/database.rs @@ -43,8 +43,9 @@ use std::{ fmt::{Display, Error, Formatter}, sync::Arc, }; +use tari_common_types::types::BlindingFactor; use tari_comms::types::CommsPublicKey; -use tari_core::transactions::{tari_amount::MicroTari, transaction::Transaction, types::BlindingFactor}; +use tari_core::transactions::{tari_amount::MicroTari, transaction::Transaction}; const LOG_TARGET: &str = "wallet::transaction_service::database"; diff --git a/base_layer/wallet/src/transaction_service/storage/models.rs b/base_layer/wallet/src/transaction_service/storage/models.rs index 37f84cc3fb..4d1f57f238 100644 --- a/base_layer/wallet/src/transaction_service/storage/models.rs +++ b/base_layer/wallet/src/transaction_service/storage/models.rs @@ -27,11 +27,11 @@ use std::{ convert::TryFrom, fmt::{Display, Error, Formatter}, }; +use tari_common_types::types::PrivateKey; use tari_comms::types::CommsPublicKey; use tari_core::transactions::{ tari_amount::MicroTari, transaction::Transaction, - types::PrivateKey, ReceiverTransactionProtocol, SenderTransactionProtocol, }; diff --git a/base_layer/wallet/src/transaction_service/storage/sqlite_db.rs b/base_layer/wallet/src/transaction_service/storage/sqlite_db.rs index 0cd55fc7f2..700590562e 100644 --- a/base_layer/wallet/src/transaction_service/storage/sqlite_db.rs +++ b/base_layer/wallet/src/transaction_service/storage/sqlite_db.rs @@ -20,6 +20,26 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use std::{ + collections::HashMap, + convert::TryFrom, + str::from_utf8, + sync::{Arc, MutexGuard, RwLock}, +}; + +use aes_gcm::{self, aead::Error as AeadError, Aes256Gcm}; +use chrono::{NaiveDateTime, Utc}; +use diesel::{prelude::*, result::Error as DieselError, SqliteConnection}; +use log::*; +use tari_crypto::tari_utilities::{ + hex::{from_hex, Hex}, + ByteArray, +}; + +use tari_common_types::types::PublicKey; +use tari_comms::types::CommsPublicKey; +use tari_core::transactions::tari_amount::MicroTari; + use crate::{ output_manager_service::TxId, schema::{completed_transactions, inbound_transactions, outbound_transactions}, @@ -40,22 +60,6 @@ use crate::{ }, util::encryption::{decrypt_bytes_integral_nonce, encrypt_bytes_integral_nonce, Encryptable}, }; -use aes_gcm::{self, aead::Error as AeadError, Aes256Gcm}; -use chrono::{NaiveDateTime, Utc}; -use diesel::{prelude::*, result::Error as DieselError, SqliteConnection}; -use log::*; -use std::{ - collections::HashMap, - convert::TryFrom, - str::from_utf8, - sync::{Arc, MutexGuard, RwLock}, -}; -use tari_comms::types::CommsPublicKey; -use tari_core::transactions::{tari_amount::MicroTari, types::PublicKey}; -use tari_crypto::tari_utilities::{ - hex::{from_hex, Hex}, - ByteArray, -}; const LOG_TARGET: &str = "wallet::transaction_service::database::sqlite_db"; @@ -1650,6 +1654,34 @@ impl From for UpdateCompletedTransactionSql { #[cfg(test)] mod test { + use std::convert::TryFrom; + + use aes_gcm::{ + aead::{generic_array::GenericArray, NewAead}, + Aes256Gcm, + }; + use chrono::Utc; + use diesel::{Connection, SqliteConnection}; + use rand::rngs::OsRng; + use tari_crypto::{ + keys::{PublicKey as PublicKeyTrait, SecretKey as SecretKeyTrait}, + script, + script::{ExecutionStack, TariScript}, + }; + use tempfile::tempdir; + + use tari_common_types::types::{HashDigest, PrivateKey, PublicKey}; + use tari_core::transactions::{ + helpers::{create_unblinded_output, TestParams}, + tari_amount::MicroTari, + transaction::{OutputFeatures, Transaction}, + transaction_protocol::sender::TransactionSenderMessage, + CryptoFactories, + ReceiverTransactionProtocol, + SenderTransactionProtocol, + }; + use tari_test_utils::random::string; + use crate::{ storage::sqlite_utilities::WalletDbConnection, transaction_service::storage::{ @@ -1670,30 +1702,6 @@ mod test { }, util::encryption::Encryptable, }; - use aes_gcm::{ - aead::{generic_array::GenericArray, NewAead}, - Aes256Gcm, - }; - use chrono::Utc; - use diesel::{Connection, SqliteConnection}; - use rand::rngs::OsRng; - use std::convert::TryFrom; - use tari_core::transactions::{ - helpers::{create_unblinded_output, TestParams}, - tari_amount::MicroTari, - transaction::{OutputFeatures, Transaction}, - transaction_protocol::sender::TransactionSenderMessage, - types::{CryptoFactories, HashDigest, PrivateKey, PublicKey}, - ReceiverTransactionProtocol, - SenderTransactionProtocol, - }; - use tari_crypto::{ - keys::{PublicKey as PublicKeyTrait, SecretKey as SecretKeyTrait}, - script, - script::{ExecutionStack, TariScript}, - }; - use tari_test_utils::random::string; - use tempfile::tempdir; #[test] fn test_crud() { diff --git a/base_layer/wallet/src/transaction_service/tasks/start_transaction_validation_and_broadcast_protocols.rs b/base_layer/wallet/src/transaction_service/tasks/start_transaction_validation_and_broadcast_protocols.rs index 61b3ee7d75..522bacdbb9 100644 --- a/base_layer/wallet/src/transaction_service/tasks/start_transaction_validation_and_broadcast_protocols.rs +++ b/base_layer/wallet/src/transaction_service/tasks/start_transaction_validation_and_broadcast_protocols.rs @@ -27,8 +27,8 @@ use crate::{ }, types::ValidationRetryStrategy, }; -use futures::StreamExt; use log::*; +use tokio::sync::broadcast; const LOG_TARGET: &str = "wallet::transaction_service::tasks::start_tx_validation_and_broadcast"; @@ -36,16 +36,16 @@ pub async fn start_transaction_validation_and_broadcast_protocols( mut handle: TransactionServiceHandle, retry_strategy: ValidationRetryStrategy, ) -> Result<(), TransactionServiceError> { - let mut event_stream = handle.get_event_stream_fused(); + let mut event_stream = handle.get_event_stream(); let our_id = handle.validate_transactions(retry_strategy).await?; // Now that its started we will spawn an task to monitor the event bus and when its successful we will start the // Broadcast protocols tokio::spawn(async move { - while let Some(event_item) = event_stream.next().await { - if let Ok(event) = event_item { - match (*event).clone() { + loop { + match event_stream.recv().await { + Ok(event) => match &*event { TransactionEvent::TransactionValidationSuccess(_id) => { info!( target: LOG_TARGET, @@ -59,19 +59,28 @@ pub async fn start_transaction_validation_and_broadcast_protocols( } }, TransactionEvent::TransactionValidationFailure(id) => { - if our_id == id { + if our_id == *id { error!(target: LOG_TARGET, "Transaction Validation failed!"); break; } }, _ => (), - } - } else { - warn!( - target: LOG_TARGET, - "Error reading from Transaction Service Event Stream" - ); - break; + }, + Err(e @ broadcast::error::RecvError::Lagged(_)) => { + warn!( + target: LOG_TARGET, + "start_transaction_validation_and_broadcast_protocols: {}", e + ); + continue; + }, + Err(broadcast::error::RecvError::Closed) => { + debug!( + target: LOG_TARGET, + "start_transaction_validation_and_broadcast_protocols is exiting because the event stream \ + closed", + ); + break; + }, } } }); diff --git a/base_layer/wallet/src/util/mod.rs b/base_layer/wallet/src/util/mod.rs index 9664a0e376..7217ac5056 100644 --- a/base_layer/wallet/src/util/mod.rs +++ b/base_layer/wallet/src/util/mod.rs @@ -20,6 +20,4 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -pub mod emoji; pub mod encryption; -pub mod luhn; diff --git a/base_layer/wallet/src/utxo_scanner_service/mod.rs b/base_layer/wallet/src/utxo_scanner_service/mod.rs index b0b475b96b..956a32848b 100644 --- a/base_layer/wallet/src/utxo_scanner_service/mod.rs +++ b/base_layer/wallet/src/utxo_scanner_service/mod.rs @@ -33,7 +33,7 @@ use futures::future; use log::*; use std::{sync::Arc, time::Duration}; use tari_comms::{connectivity::ConnectivityRequester, NodeIdentity}; -use tari_core::transactions::types::CryptoFactories; +use tari_core::transactions::CryptoFactories; use tari_service_framework::{ async_trait, reply_channel, diff --git a/base_layer/wallet/src/utxo_scanner_service/utxo_scanning.rs b/base_layer/wallet/src/utxo_scanner_service/utxo_scanning.rs index 6aaa38363f..31955c981c 100644 --- a/base_layer/wallet/src/utxo_scanner_service/utxo_scanning.rs +++ b/base_layer/wallet/src/utxo_scanner_service/utxo_scanning.rs @@ -20,24 +20,6 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{ - error::WalletError, - output_manager_service::{handle::OutputManagerHandle, TxId}, - storage::{ - database::{WalletBackend, WalletDatabase}, - sqlite_db::WalletSqliteDatabase, - }, - transaction_service::handle::TransactionServiceHandle, - utxo_scanner_service::{ - error::UtxoScannerError, - handle::{UtxoScannerEvent, UtxoScannerRequest, UtxoScannerResponse}, - }, - WalletSqlite, -}; -use chrono::Utc; -use futures::{pin_mut, StreamExt}; -use log::*; -use serde::{Deserialize, Serialize}; use std::{ convert::TryFrom, sync::{ @@ -46,6 +28,14 @@ use std::{ }, time::{Duration, Instant}, }; + +use chrono::Utc; +use futures::{pin_mut, StreamExt}; +use log::*; +use serde::{Deserialize, Serialize}; +use tokio::{sync::broadcast, task, time}; + +use tari_common_types::types::HashOutput; use tari_comms::{ connectivity::ConnectivityRequester, peer_manager::NodeId, @@ -64,12 +54,27 @@ use tari_core::{ transactions::{ tari_amount::MicroTari, transaction::{TransactionOutput, UnblindedOutput}, - types::{CryptoFactories, HashOutput}, + CryptoFactories, }, }; use tari_service_framework::{reply_channel, reply_channel::SenderService}; use tari_shutdown::ShutdownSignal; -use tokio::{sync::broadcast, task, time}; + +use crate::{ + error::WalletError, + output_manager_service::{handle::OutputManagerHandle, TxId}, + storage::{ + database::{WalletBackend, WalletDatabase}, + sqlite_db::WalletSqliteDatabase, + }, + transaction_service::handle::TransactionServiceHandle, + utxo_scanner_service::{ + error::UtxoScannerError, + handle::{UtxoScannerEvent, UtxoScannerRequest, UtxoScannerResponse}, + }, + WalletSqlite, +}; +use tokio::time::MissedTickBehavior; pub const LOG_TARGET: &str = "wallet::utxo_scanning"; @@ -715,35 +720,23 @@ where TBackend: WalletBackend + 'static let mut shutdown = self.shutdown_signal.clone(); let start_at = Instant::now() + Duration::from_secs(1); - let mut work_interval = time::interval_at(start_at.into(), self.scan_for_utxo_interval).fuse(); - let mut previous = Instant::now(); + let mut work_interval = time::interval_at(start_at.into(), self.scan_for_utxo_interval); + work_interval.set_missed_tick_behavior(MissedTickBehavior::Delay); loop { - futures::select! { - _ = work_interval.select_next_some() => { - // This bit of code prevents bottled up tokio interval events to be fired successively for the edge - // case where a computer wakes up from sleep. - if start_at.elapsed() > self.scan_for_utxo_interval && - previous.elapsed() < self.scan_for_utxo_interval.mul_f32(0.9) - { - debug!( - target: LOG_TARGET, - "UTXO scanning work interval event fired too quickly, not running the task" - ); - } else { - let running_flag = self.is_running.clone(); - if !running_flag.load(Ordering::SeqCst) { - let task = self.create_task(); - debug!(target: LOG_TARGET, "UTXO scanning service starting scan for utxos"); - task::spawn(async move { - if let Err(err) = task.run().await { - error!(target: LOG_TARGET, "Error scanning UTXOs: {}", err); - } - //we make sure the flag is set to false here - running_flag.store(false, Ordering::Relaxed); - }); - } + tokio::select! { + _ = work_interval.tick() => { + let running_flag = self.is_running.clone(); + if !running_flag.load(Ordering::SeqCst) { + let task = self.create_task(); + debug!(target: LOG_TARGET, "UTXO scanning service starting scan for utxos"); + task::spawn(async move { + if let Err(err) = task.run().await { + error!(target: LOG_TARGET, "Error scanning UTXOs: {}", err); + } + //we make sure the flag is set to false here + running_flag.store(false, Ordering::Relaxed); + }); } - previous = Instant::now(); }, request_context = request_stream.select_next_some() => { trace!(target: LOG_TARGET, "Handling Service API Request"); @@ -757,7 +750,7 @@ where TBackend: WalletBackend + 'static e }); }, - _ = shutdown => { + _ = shutdown.wait() => { // this will stop the task if its running, and let that thread exit gracefully self.is_running.store(false, Ordering::Relaxed); info!(target: LOG_TARGET, "UTXO scanning service shutting down because it received the shutdown signal"); diff --git a/base_layer/wallet/src/wallet.rs b/base_layer/wallet/src/wallet.rs index 1f91f3d625..24e4573181 100644 --- a/base_layer/wallet/src/wallet.rs +++ b/base_layer/wallet/src/wallet.rs @@ -20,28 +20,8 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{ - base_node_service::{handle::BaseNodeServiceHandle, BaseNodeServiceInitializer}, - config::{WalletConfig, KEY_MANAGER_COMMS_SECRET_KEY_BRANCH_KEY}, - connectivity_service::{WalletConnectivityHandle, WalletConnectivityInitializer}, - contacts_service::{handle::ContactsServiceHandle, storage::database::ContactsBackend, ContactsServiceInitializer}, - error::WalletError, - output_manager_service::{ - error::OutputManagerError, - handle::OutputManagerHandle, - storage::{database::OutputManagerBackend, models::KnownOneSidedPaymentScript}, - OutputManagerServiceInitializer, - TxId, - }, - storage::database::{WalletBackend, WalletDatabase}, - transaction_service::{ - handle::TransactionServiceHandle, - storage::database::TransactionBackend, - TransactionServiceInitializer, - }, - types::KeyDigest, - utxo_scanner_service::{handle::UtxoScannerHandle, UtxoScannerServiceInitializer}, -}; +use std::{marker::PhantomData, sync::Arc}; + use aes_gcm::{ aead::{generic_array::GenericArray, NewAead}, Aes256Gcm, @@ -49,7 +29,17 @@ use aes_gcm::{ use digest::Digest; use log::*; use rand::rngs::OsRng; -use std::{marker::PhantomData, sync::Arc}; +use tari_crypto::{ + common::Blake256, + keys::SecretKey, + ristretto::{RistrettoPublicKey, RistrettoSchnorr, RistrettoSecretKey}, + script, + script::{ExecutionStack, TariScript}, + signatures::{SchnorrSignature, SchnorrSignatureError}, + tari_utilities::hex::Hex, +}; + +use tari_common_types::types::{ComSignature, PrivateKey, PublicKey}; use tari_comms::{ multiaddr::Multiaddr, peer_manager::{NodeId, Peer, PeerFeatures, PeerFlags}, @@ -62,22 +52,35 @@ use tari_comms_dht::{store_forward::StoreAndForwardRequester, Dht}; use tari_core::transactions::{ tari_amount::MicroTari, transaction::{OutputFeatures, UnblindedOutput}, - types::{ComSignature, CryptoFactories, PrivateKey, PublicKey}, -}; -use tari_crypto::{ - common::Blake256, - keys::SecretKey, - ristretto::{RistrettoPublicKey, RistrettoSchnorr, RistrettoSecretKey}, - script, - script::{ExecutionStack, TariScript}, - signatures::{SchnorrSignature, SchnorrSignatureError}, - tari_utilities::hex::Hex, + CryptoFactories, }; use tari_key_manager::key_manager::KeyManager; use tari_p2p::{comms_connector::pubsub_connector, initialization, initialization::P2pInitializer}; use tari_service_framework::StackBuilder; use tari_shutdown::ShutdownSignal; -use tokio::runtime; + +use crate::{ + base_node_service::{handle::BaseNodeServiceHandle, BaseNodeServiceInitializer}, + config::{WalletConfig, KEY_MANAGER_COMMS_SECRET_KEY_BRANCH_KEY}, + connectivity_service::{WalletConnectivityHandle, WalletConnectivityInitializer}, + contacts_service::{handle::ContactsServiceHandle, storage::database::ContactsBackend, ContactsServiceInitializer}, + error::WalletError, + output_manager_service::{ + error::OutputManagerError, + handle::OutputManagerHandle, + storage::{database::OutputManagerBackend, models::KnownOneSidedPaymentScript}, + OutputManagerServiceInitializer, + TxId, + }, + storage::database::{WalletBackend, WalletDatabase}, + transaction_service::{ + handle::TransactionServiceHandle, + storage::database::TransactionBackend, + TransactionServiceInitializer, + }, + types::KeyDigest, + utxo_scanner_service::{handle::UtxoScannerHandle, UtxoScannerServiceInitializer}, +}; const LOG_TARGET: &str = "wallet"; @@ -139,8 +142,7 @@ where let bn_service_db = wallet_database.clone(); let factories = config.clone().factories; - let (publisher, subscription_factory) = - pubsub_connector(runtime::Handle::current(), config.buffer_size, config.rate_limit); + let (publisher, subscription_factory) = pubsub_connector(config.buffer_size, config.rate_limit); let peer_message_subscription_factory = Arc::new(subscription_factory); let transport_type = config.comms_config.transport_type.clone(); diff --git a/base_layer/wallet/tests/contacts_service/mod.rs b/base_layer/wallet/tests/contacts_service/mod.rs index 80970c6a17..ed5ad5033c 100644 --- a/base_layer/wallet/tests/contacts_service/mod.rs +++ b/base_layer/wallet/tests/contacts_service/mod.rs @@ -22,7 +22,7 @@ use crate::support::data::get_temp_sqlite_database_connection; use rand::rngs::OsRng; -use tari_core::transactions::types::PublicKey; +use tari_common_types::types::PublicKey; use tari_crypto::keys::PublicKey as PublicKeyTrait; use tari_service_framework::StackBuilder; use tari_shutdown::Shutdown; diff --git a/base_layer/wallet/tests/output_manager_service/service.rs b/base_layer/wallet/tests/output_manager_service/service.rs index c6c23da53e..ceb4ed4c8b 100644 --- a/base_layer/wallet/tests/output_manager_service/service.rs +++ b/base_layer/wallet/tests/output_manager_service/service.rs @@ -19,18 +19,18 @@ // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - use crate::support::{ data::get_temp_sqlite_database_connection, rpc::{BaseNodeWalletRpcMockService, BaseNodeWalletRpcMockState}, utils::{make_input, make_input_with_features, TestParams}, }; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt; use rand::{rngs::OsRng, RngCore}; -use std::{sync::Arc, thread, time::Duration}; +use std::{sync::Arc, time::Duration}; +use tari_common_types::types::{PrivateKey, PublicKey}; use tari_comms::{ peer_manager::{NodeIdentity, PeerFeatures}, - protocol::rpc::{mock::MockRpcServer, NamedProtocolService, RpcStatus}, + protocol::rpc::{mock::MockRpcServer, NamedProtocolService, RpcClientConfig, RpcStatus}, test_utils::{ mocks::{create_connectivity_mock, ConnectivityManagerMockState}, node_identity::build_node_identity, @@ -51,12 +51,12 @@ use tari_core::{ sender::TransactionSenderMessage, single_receiver::SingleReceiverTransactionProtocol, }, - types::{CryptoFactories, PrivateKey, PublicKey}, + CryptoFactories, SenderTransactionProtocol, }, }; use tari_crypto::{ - hash::blake2::Blake256, + common::Blake256, inputs, keys::{PublicKey as PublicKeyTrait, SecretKey}, script, @@ -74,7 +74,7 @@ use tari_wallet::{ service::OutputManagerService, storage::{ database::{DbKey, DbKeyValuePair, DbValue, OutputManagerBackend, OutputManagerDatabase, WriteOperation}, - models::DbUnblindedOutput, + models::{DbUnblindedOutput, OutputStatus}, sqlite_db::OutputManagerSqliteDatabase, }, TxId, @@ -83,18 +83,14 @@ use tari_wallet::{ transaction_service::handle::TransactionServiceHandle, types::ValidationRetryStrategy, }; - -use tari_comms::protocol::rpc::RpcClientConfig; -use tari_wallet::output_manager_service::storage::models::OutputStatus; use tokio::{ - runtime::Runtime, sync::{broadcast, broadcast::channel}, - time::delay_for, + task, + time, }; #[allow(clippy::type_complexity)] -pub fn setup_output_manager_service( - runtime: &mut Runtime, +async fn setup_output_manager_service( backend: T, with_connection: bool, ) -> ( @@ -124,11 +120,11 @@ pub fn setup_output_manager_service( let basenode_service_handle = BaseNodeServiceHandle::new(sender, event_publisher_bns); let mut mock_base_node_service = MockBaseNodeService::new(receiver_bns, shutdown.to_signal()); mock_base_node_service.set_default_base_node_state(); - runtime.spawn(mock_base_node_service.run()); + task::spawn(mock_base_node_service.run()); let (connectivity_manager, connectivity_mock) = create_connectivity_mock(); let connectivity_mock_state = connectivity_mock.get_shared_state(); - runtime.spawn(connectivity_mock.run()); + task::spawn(connectivity_mock.run()); let service = BaseNodeWalletRpcMockService::new(); let rpc_service_state = service.get_state(); @@ -137,43 +133,39 @@ pub fn setup_output_manager_service( let protocol_name = server.as_protocol_name(); let server_node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); - let mut mock_server = runtime - .handle() - .enter(|| MockRpcServer::new(server, server_node_identity.clone())); + let mut mock_server = MockRpcServer::new(server, server_node_identity.clone()); - runtime.handle().enter(|| mock_server.serve()); + mock_server.serve(); if with_connection { - let connection = runtime.block_on(async { - mock_server - .create_connection(server_node_identity.to_peer(), protocol_name.into()) - .await - }); - runtime.block_on(connectivity_mock_state.add_active_connection(connection)); + let connection = mock_server + .create_connection(server_node_identity.to_peer(), protocol_name.into()) + .await; + connectivity_mock_state.add_active_connection(connection).await; } - let output_manager_service = runtime - .block_on(OutputManagerService::new( - OutputManagerServiceConfig { - base_node_query_timeout: Duration::from_secs(10), - max_utxo_query_size: 2, - peer_dial_retry_timeout: Duration::from_secs(5), - ..Default::default() - }, - ts_handle.clone(), - oms_request_receiver, - OutputManagerDatabase::new(backend), - oms_event_publisher.clone(), - factories, - constants, - shutdown.to_signal(), - basenode_service_handle, - connectivity_manager, - CommsSecretKey::default(), - )) - .unwrap(); + let output_manager_service = OutputManagerService::new( + OutputManagerServiceConfig { + base_node_query_timeout: Duration::from_secs(10), + max_utxo_query_size: 2, + peer_dial_retry_timeout: Duration::from_secs(5), + ..Default::default() + }, + ts_handle.clone(), + oms_request_receiver, + OutputManagerDatabase::new(backend), + oms_event_publisher.clone(), + factories, + constants, + shutdown.to_signal(), + basenode_service_handle, + connectivity_manager, + CommsSecretKey::default(), + ) + .await + .unwrap(); let output_manager_service_handle = OutputManagerHandle::new(oms_request_sender, oms_event_publisher); - runtime.spawn(async move { output_manager_service.start().await.unwrap() }); + task::spawn(async move { output_manager_service.start().await.unwrap() }); ( output_manager_service_handle, @@ -218,8 +210,7 @@ async fn complete_transaction(mut stp: SenderTransactionProtocol, mut oms: Outpu stp.get_transaction().unwrap().clone() } -pub fn setup_oms_with_bn_state( - runtime: &mut Runtime, +pub async fn setup_oms_with_bn_state( backend: T, height: Option, ) -> ( @@ -246,35 +237,35 @@ pub fn setup_oms_with_bn_state( let base_node_service_handle = BaseNodeServiceHandle::new(sender, event_publisher_bns); let mut mock_base_node_service = MockBaseNodeService::new(receiver_bns, shutdown.to_signal()); mock_base_node_service.set_base_node_state(height); - runtime.spawn(mock_base_node_service.run()); + task::spawn(mock_base_node_service.run()); let (connectivity_manager, connectivity_mock) = create_connectivity_mock(); let _connectivity_mock_state = connectivity_mock.get_shared_state(); - runtime.spawn(connectivity_mock.run()); - - let output_manager_service = runtime - .block_on(OutputManagerService::new( - OutputManagerServiceConfig { - base_node_query_timeout: Duration::from_secs(10), - max_utxo_query_size: 2, - peer_dial_retry_timeout: Duration::from_secs(5), - ..Default::default() - }, - ts_handle.clone(), - oms_request_receiver, - OutputManagerDatabase::new(backend), - oms_event_publisher.clone(), - factories, - constants, - shutdown.to_signal(), - base_node_service_handle.clone(), - connectivity_manager, - CommsSecretKey::default(), - )) - .unwrap(); + task::spawn(connectivity_mock.run()); + + let output_manager_service = OutputManagerService::new( + OutputManagerServiceConfig { + base_node_query_timeout: Duration::from_secs(10), + max_utxo_query_size: 2, + peer_dial_retry_timeout: Duration::from_secs(5), + ..Default::default() + }, + ts_handle.clone(), + oms_request_receiver, + OutputManagerDatabase::new(backend), + oms_event_publisher.clone(), + factories, + constants, + shutdown.to_signal(), + base_node_service_handle.clone(), + connectivity_manager, + CommsSecretKey::default(), + ) + .await + .unwrap(); let output_manager_service_handle = OutputManagerHandle::new(oms_request_sender, oms_event_publisher); - runtime.spawn(async move { output_manager_service.start().await.unwrap() }); + task::spawn(async move { output_manager_service.start().await.unwrap() }); ( output_manager_service_handle, @@ -321,63 +312,65 @@ fn generate_sender_transaction_message(amount: MicroTari) -> (TxId, TransactionS ) } -#[test] -fn fee_estimate() { +#[tokio::test] +async fn fee_estimate() { let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let (_, uo) = make_input(&mut OsRng.clone(), MicroTari::from(3000), &factories.commitment); - runtime.block_on(oms.add_output(uo)).unwrap(); + oms.add_output(uo).await.unwrap(); // minimum fee let fee_per_gram = MicroTari::from(1); - let fee = runtime - .block_on(oms.fee_estimate(MicroTari::from(100), fee_per_gram, 1, 1)) + let fee = oms + .fee_estimate(MicroTari::from(100), fee_per_gram, 1, 1) + .await .unwrap(); assert_eq!(fee, MicroTari::from(100)); let fee_per_gram = MicroTari::from(25); for outputs in 1..5 { - let fee = runtime - .block_on(oms.fee_estimate(MicroTari::from(100), fee_per_gram, 1, outputs)) + let fee = oms + .fee_estimate(MicroTari::from(100), fee_per_gram, 1, outputs) + .await .unwrap(); assert_eq!(fee, Fee::calculate(fee_per_gram, 1, 1, outputs as usize)); } // not enough funds - let err = runtime - .block_on(oms.fee_estimate(MicroTari::from(2750), fee_per_gram, 1, 1)) + let err = oms + .fee_estimate(MicroTari::from(2750), fee_per_gram, 1, 1) + .await .unwrap_err(); assert!(matches!(err, OutputManagerError::NotEnoughFunds)); } #[allow(clippy::identity_op)] -#[test] -fn test_utxo_selection_no_chain_metadata() { +#[tokio::test] +async fn test_utxo_selection_no_chain_metadata() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); // no chain metadata let (mut oms, _shutdown, _, _) = - setup_oms_with_bn_state(&mut runtime, OutputManagerSqliteDatabase::new(connection, None), None); + setup_oms_with_bn_state(OutputManagerSqliteDatabase::new(connection, None), None).await; // no utxos - not enough funds let amount = MicroTari::from(1000); let fee_per_gram = MicroTari::from(10); - let err = runtime - .block_on(oms.prepare_transaction_to_send( + let err = oms + .prepare_transaction_to_send( OsRng.next_u64(), amount, fee_per_gram, None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap_err(); assert!(matches!(err, OutputManagerError::NotEnoughFunds)); @@ -389,24 +382,25 @@ fn test_utxo_selection_no_chain_metadata() { &factories.commitment, Some(OutputFeatures::with_maturity(i)), ); - runtime.block_on(oms.add_output(uo.clone())).unwrap(); + oms.add_output(uo.clone()).await.unwrap(); } // but we have no chain state so the lowest maturity should be used - let stp = runtime - .block_on(oms.prepare_transaction_to_send( + let stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), amount, fee_per_gram, None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); assert!(stp.get_tx_id().is_ok()); // test that lowest 2 maturities were encumbered - let utxos = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let utxos = oms.get_unspent_outputs().await.unwrap(); assert_eq!(utxos.len(), 8); for (index, utxo) in utxos.iter().enumerate() { let i = index as u64 + 3; @@ -415,34 +409,31 @@ fn test_utxo_selection_no_chain_metadata() { } // test that we can get a fee estimate with no chain metadata - let fee = runtime.block_on(oms.fee_estimate(amount, fee_per_gram, 1, 2)).unwrap(); + let fee = oms.fee_estimate(amount, fee_per_gram, 1, 2).await.unwrap(); assert_eq!(fee, MicroTari::from(300)); // test if a fee estimate would be possible with pending funds included // at this point 52000 uT is still spendable, with pending change incoming of 1690 uT // so instead of returning "not enough funds", return "funds pending" let spendable_amount = (3..=10).sum::() * amount; - let err = runtime - .block_on(oms.fee_estimate(spendable_amount, fee_per_gram, 1, 2)) + let err = oms + .fee_estimate(spendable_amount, fee_per_gram, 1, 2) + .await .unwrap_err(); assert!(matches!(err, OutputManagerError::FundsPending)); // test not enough funds let broke_amount = spendable_amount + MicroTari::from(2000); - let err = runtime - .block_on(oms.fee_estimate(broke_amount, fee_per_gram, 1, 2)) - .unwrap_err(); + let err = oms.fee_estimate(broke_amount, fee_per_gram, 1, 2).await.unwrap_err(); assert!(matches!(err, OutputManagerError::NotEnoughFunds)); // coin split uses the "Largest" selection strategy - let (_, _, fee, utxos_total_value) = runtime - .block_on(oms.create_coin_split(amount, 5, fee_per_gram, None)) - .unwrap(); + let (_, _, fee, utxos_total_value) = oms.create_coin_split(amount, 5, fee_per_gram, None).await.unwrap(); assert_eq!(fee, MicroTari::from(820)); assert_eq!(utxos_total_value, MicroTari::from(10_000)); // test that largest utxo was encumbered - let utxos = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let utxos = oms.get_unspent_outputs().await.unwrap(); assert_eq!(utxos.len(), 7); for (index, utxo) in utxos.iter().enumerate() { let i = index as u64 + 3; @@ -452,31 +443,28 @@ fn test_utxo_selection_no_chain_metadata() { } #[allow(clippy::identity_op)] -#[test] -fn test_utxo_selection_with_chain_metadata() { +#[tokio::test] +async fn test_utxo_selection_with_chain_metadata() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); // setup with chain metadata at a height of 6 - let (mut oms, _shutdown, _, _) = setup_oms_with_bn_state( - &mut runtime, - OutputManagerSqliteDatabase::new(connection, None), - Some(6), - ); + let (mut oms, _shutdown, _, _) = + setup_oms_with_bn_state(OutputManagerSqliteDatabase::new(connection, None), Some(6)).await; // no utxos - not enough funds let amount = MicroTari::from(1000); let fee_per_gram = MicroTari::from(10); - let err = runtime - .block_on(oms.prepare_transaction_to_send( + let err = oms + .prepare_transaction_to_send( OsRng.next_u64(), amount, fee_per_gram, None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap_err(); assert!(matches!(err, OutputManagerError::NotEnoughFunds)); @@ -488,52 +476,52 @@ fn test_utxo_selection_with_chain_metadata() { &factories.commitment, Some(OutputFeatures::with_maturity(i)), ); - runtime.block_on(oms.add_output(uo.clone())).unwrap(); + oms.add_output(uo.clone()).await.unwrap(); } - let utxos = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let utxos = oms.get_unspent_outputs().await.unwrap(); assert_eq!(utxos.len(), 10); // test fee estimates - let fee = runtime.block_on(oms.fee_estimate(amount, fee_per_gram, 1, 2)).unwrap(); + let fee = oms.fee_estimate(amount, fee_per_gram, 1, 2).await.unwrap(); assert_eq!(fee, MicroTari::from(310)); // test fee estimates are maturity aware // even though we have utxos for the fee, they can't be spent because they are not mature yet let spendable_amount = (1..=6).sum::() * amount; - let err = runtime - .block_on(oms.fee_estimate(spendable_amount, fee_per_gram, 1, 2)) + let err = oms + .fee_estimate(spendable_amount, fee_per_gram, 1, 2) + .await .unwrap_err(); assert!(matches!(err, OutputManagerError::NotEnoughFunds)); // test coin split is maturity aware - let (_, _, fee, utxos_total_value) = runtime - .block_on(oms.create_coin_split(amount, 5, fee_per_gram, None)) - .unwrap(); + let (_, _, fee, utxos_total_value) = oms.create_coin_split(amount, 5, fee_per_gram, None).await.unwrap(); assert_eq!(utxos_total_value, MicroTari::from(6_000)); assert_eq!(fee, MicroTari::from(820)); // test that largest spendable utxo was encumbered - let utxos = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let utxos = oms.get_unspent_outputs().await.unwrap(); assert_eq!(utxos.len(), 9); let found = utxos.iter().any(|u| u.value == 6 * amount); assert!(!found, "An unspendable utxo was selected"); // test transactions - let stp = runtime - .block_on(oms.prepare_transaction_to_send( + let stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), amount, fee_per_gram, None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); assert!(stp.get_tx_id().is_ok()); // test that utxos with the lowest 2 maturities were encumbered - let utxos = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let utxos = oms.get_unspent_outputs().await.unwrap(); assert_eq!(utxos.len(), 7); for utxo in utxos.iter() { assert_ne!(utxo.features.maturity, 1); @@ -543,20 +531,21 @@ fn test_utxo_selection_with_chain_metadata() { } // when the amount is greater than the largest utxo, then "Largest" selection strategy is used - let stp = runtime - .block_on(oms.prepare_transaction_to_send( + let stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), 6 * amount, fee_per_gram, None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); assert!(stp.get_tx_id().is_ok()); // test that utxos with the highest spendable 2 maturities were encumbered - let utxos = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let utxos = oms.get_unspent_outputs().await.unwrap(); assert_eq!(utxos.len(), 5); for utxo in utxos.iter() { assert_ne!(utxo.features.maturity, 4); @@ -566,22 +555,21 @@ fn test_utxo_selection_with_chain_metadata() { } } -#[test] -fn sending_transaction_and_confirmation() { +#[tokio::test] +async fn sending_transaction_and_confirmation() { let factories = CryptoFactories::default(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let mut runtime = Runtime::new().unwrap(); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend.clone(), true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend.clone(), true).await; let (_ti, uo) = make_input( &mut OsRng.clone(), MicroTari::from(100 + OsRng.next_u64() % 1000), &factories.commitment, ); - runtime.block_on(oms.add_output(uo.clone())).unwrap(); - match runtime.block_on(oms.add_output(uo)) { + oms.add_output(uo.clone()).await.unwrap(); + match oms.add_output(uo).await { Err(OutputManagerError::OutputManagerStorageError(OutputManagerStorageError::DuplicateOutput)) => {}, _ => panic!("Incorrect error message"), }; @@ -592,25 +580,26 @@ fn sending_transaction_and_confirmation() { MicroTari::from(100 + OsRng.next_u64() % 1000), &factories.commitment, ); - runtime.block_on(oms.add_output(uo)).unwrap(); + oms.add_output(uo).await.unwrap(); } - let stp = runtime - .block_on(oms.prepare_transaction_to_send( + let stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), MicroTari::from(1000), MicroTari::from(20), None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); let sender_tx_id = stp.get_tx_id().unwrap(); - let tx = runtime.block_on(complete_transaction(stp, oms.clone())); + let tx = complete_transaction(stp, oms.clone()).await; - let rewind_public_keys = runtime.block_on(oms.get_rewind_public_keys()).unwrap(); + let rewind_public_keys = oms.get_rewind_public_keys().await.unwrap(); // 1 of the 2 outputs should be rewindable, there should be 2 outputs due to change but if we get unlucky enough // that there is no change we will skip this aspect of the test @@ -643,23 +632,23 @@ fn sending_transaction_and_confirmation() { assert_eq!(num_rewound, 1, "Should only be 1 rewindable output"); } - runtime - .block_on(oms.confirm_transaction(sender_tx_id, tx.body.inputs().clone(), tx.body.outputs().clone())) + oms.confirm_transaction(sender_tx_id, tx.body.inputs().clone(), tx.body.outputs().clone()) + .await .unwrap(); assert_eq!( - runtime.block_on(oms.get_pending_transactions()).unwrap().len(), + oms.get_pending_transactions().await.unwrap().len(), 0, "Should have no pending tx" ); assert_eq!( - runtime.block_on(oms.get_spent_outputs()).unwrap().len(), + oms.get_spent_outputs().await.unwrap().len(), tx.body.inputs().len(), "# Outputs should equal number of sent inputs" ); assert_eq!( - runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), - num_outputs + 1 - runtime.block_on(oms.get_spent_outputs()).unwrap().len() + tx.body.outputs().len() - 1, + oms.get_unspent_outputs().await.unwrap().len(), + num_outputs + 1 - oms.get_spent_outputs().await.unwrap().len() + tx.body.outputs().len() - 1, "Unspent outputs" ); @@ -675,16 +664,14 @@ fn sending_transaction_and_confirmation() { } } -#[test] -fn send_not_enough_funds() { +#[tokio::test] +async fn send_not_enough_funds() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); - let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let num_outputs = 20; for _i in 0..num_outputs { let (_ti, uo) = make_input( @@ -692,68 +679,70 @@ fn send_not_enough_funds() { MicroTari::from(100 + OsRng.next_u64() % 1000), &factories.commitment, ); - runtime.block_on(oms.add_output(uo)).unwrap(); + oms.add_output(uo).await.unwrap(); } - match runtime.block_on(oms.prepare_transaction_to_send( - OsRng.next_u64(), - MicroTari::from(num_outputs * 2000), - MicroTari::from(20), - None, - "".to_string(), - script!(Nop), - )) { + match oms + .prepare_transaction_to_send( + OsRng.next_u64(), + MicroTari::from(num_outputs * 2000), + MicroTari::from(20), + None, + "".to_string(), + script!(Nop), + ) + .await + { Err(OutputManagerError::NotEnoughFunds) => {}, _ => panic!(), } } -#[test] -fn send_no_change() { +#[tokio::test] +async fn send_no_change() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); - let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let fee_per_gram = MicroTari::from(20); let fee_without_change = Fee::calculate(fee_per_gram, 1, 2, 1); let value1 = 500; - runtime - .block_on(oms.add_output(create_unblinded_output( - script!(Nop), - OutputFeatures::default(), - TestParamsHelpers::new(), - MicroTari::from(value1), - ))) - .unwrap(); + oms.add_output(create_unblinded_output( + script!(Nop), + OutputFeatures::default(), + TestParamsHelpers::new(), + MicroTari::from(value1), + )) + .await + .unwrap(); let value2 = 800; - runtime - .block_on(oms.add_output(create_unblinded_output( - script!(Nop), - OutputFeatures::default(), - TestParamsHelpers::new(), - MicroTari::from(value2), - ))) - .unwrap(); + oms.add_output(create_unblinded_output( + script!(Nop), + OutputFeatures::default(), + TestParamsHelpers::new(), + MicroTari::from(value2), + )) + .await + .unwrap(); - let mut stp = runtime - .block_on(oms.prepare_transaction_to_send( + let mut stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), MicroTari::from(value1 + value2) - fee_without_change, fee_per_gram, None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); let sender_tx_id = stp.get_tx_id().unwrap(); assert_eq!(stp.get_amount_to_self().unwrap(), MicroTari::from(0)); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 1); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 1); let msg = stp.build_single_round_message().unwrap(); @@ -776,99 +765,91 @@ fn send_no_change() { let tx = stp.get_transaction().unwrap(); - runtime - .block_on(oms.confirm_transaction(sender_tx_id, tx.body.inputs().clone(), tx.body.outputs().clone())) + oms.confirm_transaction(sender_tx_id, tx.body.inputs().clone(), tx.body.outputs().clone()) + .await .unwrap(); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 0); - assert_eq!( - runtime.block_on(oms.get_spent_outputs()).unwrap().len(), - tx.body.inputs().len() - ); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), 0); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 0); + assert_eq!(oms.get_spent_outputs().await.unwrap().len(), tx.body.inputs().len()); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), 0); } -#[test] -fn send_not_enough_for_change() { - let mut runtime = Runtime::new().unwrap(); - +#[tokio::test] +async fn send_not_enough_for_change() { let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let fee_per_gram = MicroTari::from(20); let fee_without_change = Fee::calculate(fee_per_gram, 1, 2, 1); let value1 = 500; - runtime - .block_on(oms.add_output(create_unblinded_output( - TariScript::default(), - OutputFeatures::default(), - TestParamsHelpers::new(), - MicroTari::from(value1), - ))) - .unwrap(); + oms.add_output(create_unblinded_output( + TariScript::default(), + OutputFeatures::default(), + TestParamsHelpers::new(), + MicroTari::from(value1), + )) + .await + .unwrap(); let value2 = 800; - runtime - .block_on(oms.add_output(create_unblinded_output( - TariScript::default(), - OutputFeatures::default(), - TestParamsHelpers::new(), - MicroTari::from(value2), - ))) - .unwrap(); + oms.add_output(create_unblinded_output( + TariScript::default(), + OutputFeatures::default(), + TestParamsHelpers::new(), + MicroTari::from(value2), + )) + .await + .unwrap(); - match runtime.block_on(oms.prepare_transaction_to_send( - OsRng.next_u64(), - MicroTari::from(value1 + value2 + 1) - fee_without_change, - MicroTari::from(20), - None, - "".to_string(), - script!(Nop), - )) { + match oms + .prepare_transaction_to_send( + OsRng.next_u64(), + MicroTari::from(value1 + value2 + 1) - fee_without_change, + MicroTari::from(20), + None, + "".to_string(), + script!(Nop), + ) + .await + { Err(OutputManagerError::NotEnoughFunds) => {}, _ => panic!(), } } -#[test] -fn receiving_and_confirmation() { - let mut runtime = Runtime::new().unwrap(); - +#[tokio::test] +async fn receiving_and_confirmation() { let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let value = MicroTari::from(5000); let (tx_id, sender_message) = generate_sender_transaction_message(value); - let rtp = runtime.block_on(oms.get_recipient_transaction(sender_message)).unwrap(); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), 0); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 1); + let rtp = oms.get_recipient_transaction(sender_message).await.unwrap(); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), 0); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 1); let output = match rtp.state { RecipientState::Finalized(s) => s.output, RecipientState::Failed(_) => panic!("Should not be in Failed state"), }; - runtime - .block_on(oms.confirm_transaction(tx_id, vec![], vec![output])) - .unwrap(); + oms.confirm_transaction(tx_id, vec![], vec![output]).await.unwrap(); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 0); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), 1); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 0); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), 1); } -#[test] -fn cancel_transaction() { +#[tokio::test] +async fn cancel_transaction() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); - let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let num_outputs = 20; for _i in 0..num_outputs { @@ -877,46 +858,43 @@ fn cancel_transaction() { MicroTari::from(100 + OsRng.next_u64() % 1000), &factories.commitment, ); - runtime.block_on(oms.add_output(uo)).unwrap(); + oms.add_output(uo).await.unwrap(); } - let stp = runtime - .block_on(oms.prepare_transaction_to_send( + let stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), MicroTari::from(1000), MicroTari::from(20), None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); - match runtime.block_on(oms.cancel_transaction(1)) { + match oms.cancel_transaction(1).await { Err(OutputManagerError::OutputManagerStorageError(OutputManagerStorageError::ValueNotFound)) => {}, _ => panic!("Value should not exist"), } - runtime - .block_on(oms.cancel_transaction(stp.get_tx_id().unwrap())) - .unwrap(); + oms.cancel_transaction(stp.get_tx_id().unwrap()).await.unwrap(); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), num_outputs); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), num_outputs); } -#[test] -fn cancel_transaction_and_reinstate_inbound_tx() { - let mut runtime = Runtime::new().unwrap(); - +#[tokio::test] +async fn cancel_transaction_and_reinstate_inbound_tx() { let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend.clone(), true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend.clone(), true).await; let value = MicroTari::from(5000); let (tx_id, sender_message) = generate_sender_transaction_message(value); - let _rtp = runtime.block_on(oms.get_recipient_transaction(sender_message)).unwrap(); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), 0); + let _rtp = oms.get_recipient_transaction(sender_message).await.unwrap(); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), 0); - let pending_txs = runtime.block_on(oms.get_pending_transactions()).unwrap(); + let pending_txs = oms.get_pending_transactions().await.unwrap(); assert_eq!(pending_txs.len(), 1); @@ -928,7 +906,7 @@ fn cancel_transaction_and_reinstate_inbound_tx() { .unwrap() .clone(); - runtime.block_on(oms.cancel_transaction(tx_id)).unwrap(); + oms.cancel_transaction(tx_id).await.unwrap(); let cancelled_output = backend .fetch(&DbKey::OutputsByTxIdAndStatus(tx_id, OutputStatus::CancelledInbound)) @@ -942,28 +920,25 @@ fn cancel_transaction_and_reinstate_inbound_tx() { panic!("Should have found cancelled output"); } - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 0); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 0); - runtime - .block_on(oms.reinstate_cancelled_inbound_transaction(tx_id)) - .unwrap(); + oms.reinstate_cancelled_inbound_transaction(tx_id).await.unwrap(); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 1); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 1); - let balance = runtime.block_on(oms.get_balance()).unwrap(); + let balance = oms.get_balance().await.unwrap(); assert_eq!(balance.pending_incoming_balance, value); } -#[test] -fn timeout_transaction() { +#[tokio::test] +async fn timeout_transaction() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let num_outputs = 20; for _i in 0..num_outputs { @@ -972,50 +947,43 @@ fn timeout_transaction() { MicroTari::from(100 + OsRng.next_u64() % 1000), &factories.commitment, ); - runtime.block_on(oms.add_output(uo)).unwrap(); + oms.add_output(uo).await.unwrap(); } - let _stp = runtime - .block_on(oms.prepare_transaction_to_send( + let _stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), MicroTari::from(1000), MicroTari::from(20), None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); - let remaining_outputs = runtime.block_on(oms.get_unspent_outputs()).unwrap().len(); + let remaining_outputs = oms.get_unspent_outputs().await.unwrap().len(); - thread::sleep(Duration::from_millis(2)); + time::sleep(Duration::from_millis(2)).await; - runtime - .block_on(oms.timeout_transactions(Duration::from_millis(1000))) - .unwrap(); + oms.timeout_transactions(Duration::from_millis(1000)).await.unwrap(); - assert_eq!( - runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), - remaining_outputs - ); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), remaining_outputs); - runtime - .block_on(oms.timeout_transactions(Duration::from_millis(1))) - .unwrap(); + oms.timeout_transactions(Duration::from_millis(1)).await.unwrap(); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), num_outputs); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), num_outputs); } -#[test] -fn test_get_balance() { +#[tokio::test] +async fn test_get_balance() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; - let balance = runtime.block_on(oms.get_balance()).unwrap(); + let balance = oms.get_balance().await.unwrap(); assert_eq!(MicroTari::from(0), balance.available_balance); @@ -1023,63 +991,62 @@ fn test_get_balance() { let output_val = MicroTari::from(2000); let (_ti, uo) = make_input(&mut OsRng.clone(), output_val, &factories.commitment); total += uo.value; - runtime.block_on(oms.add_output(uo)).unwrap(); + oms.add_output(uo).await.unwrap(); let (_ti, uo) = make_input(&mut OsRng.clone(), output_val, &factories.commitment); total += uo.value; - runtime.block_on(oms.add_output(uo)).unwrap(); + oms.add_output(uo).await.unwrap(); let send_value = MicroTari::from(1000); - let stp = runtime - .block_on(oms.prepare_transaction_to_send( + let stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), send_value, MicroTari::from(20), None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); let change_val = stp.get_change_amount().unwrap(); let recv_value = MicroTari::from(1500); let (_tx_id, sender_message) = generate_sender_transaction_message(recv_value); - let _rtp = runtime.block_on(oms.get_recipient_transaction(sender_message)).unwrap(); + let _rtp = oms.get_recipient_transaction(sender_message).await.unwrap(); - let balance = runtime.block_on(oms.get_balance()).unwrap(); + let balance = oms.get_balance().await.unwrap(); assert_eq!(output_val, balance.available_balance); assert_eq!(recv_value + change_val, balance.pending_incoming_balance); assert_eq!(output_val, balance.pending_outgoing_balance); } -#[test] -fn test_confirming_received_output() { - let mut runtime = Runtime::new().unwrap(); - +#[tokio::test] +async fn test_confirming_received_output() { let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let value = MicroTari::from(5000); let (tx_id, sender_message) = generate_sender_transaction_message(value); - let rtp = runtime.block_on(oms.get_recipient_transaction(sender_message)).unwrap(); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), 0); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 1); + let rtp = oms.get_recipient_transaction(sender_message).await.unwrap(); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), 0); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 1); let output = match rtp.state { RecipientState::Finalized(s) => s.output, RecipientState::Failed(_) => panic!("Should not be in Failed state"), }; - runtime - .block_on(oms.confirm_transaction(tx_id, vec![], vec![output.clone()])) + oms.confirm_transaction(tx_id, vec![], vec![output.clone()]) + .await .unwrap(); - assert_eq!(runtime.block_on(oms.get_balance()).unwrap().available_balance, value); + assert_eq!(oms.get_balance().await.unwrap().available_balance, value); let factories = CryptoFactories::default(); - let rewind_public_keys = runtime.block_on(oms.get_rewind_public_keys()).unwrap(); + let rewind_public_keys = oms.get_rewind_public_keys().await.unwrap(); let rewind_result = output .rewind_range_proof_value_only( &factories.range_proof, @@ -1090,99 +1057,100 @@ fn test_confirming_received_output() { assert_eq!(rewind_result.committed_value, value); } -#[test] -fn sending_transaction_with_short_term_clear() { +#[tokio::test] +async fn sending_transaction_with_short_term_clear() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend.clone(), true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend.clone(), true).await; let available_balance = 10_000 * uT; let (_ti, uo) = make_input(&mut OsRng.clone(), available_balance, &factories.commitment); - runtime.block_on(oms.add_output(uo)).unwrap(); + oms.add_output(uo).await.unwrap(); // Check that funds are encumbered and then unencumbered if the pending tx is not confirmed before restart - let _stp = runtime - .block_on(oms.prepare_transaction_to_send( + let _stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), MicroTari::from(1000), MicroTari::from(20), None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); - let balance = runtime.block_on(oms.get_balance()).unwrap(); + let balance = oms.get_balance().await.unwrap(); let expected_change = balance.pending_incoming_balance; assert_eq!(balance.pending_outgoing_balance, available_balance); drop(oms); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend.clone(), true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend.clone(), true).await; - let balance = runtime.block_on(oms.get_balance()).unwrap(); + let balance = oms.get_balance().await.unwrap(); assert_eq!(balance.available_balance, available_balance); // Check that a unconfirm Pending Transaction can be cancelled - let stp = runtime - .block_on(oms.prepare_transaction_to_send( + let stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), MicroTari::from(1000), MicroTari::from(20), None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); let sender_tx_id = stp.get_tx_id().unwrap(); - let balance = runtime.block_on(oms.get_balance()).unwrap(); + let balance = oms.get_balance().await.unwrap(); assert_eq!(balance.pending_outgoing_balance, available_balance); - runtime.block_on(oms.cancel_transaction(sender_tx_id)).unwrap(); + oms.cancel_transaction(sender_tx_id).await.unwrap(); - let balance = runtime.block_on(oms.get_balance()).unwrap(); + let balance = oms.get_balance().await.unwrap(); assert_eq!(balance.available_balance, available_balance); // Check that is the pending tx is confirmed that the encumberance persists after restart - let stp = runtime - .block_on(oms.prepare_transaction_to_send( + let stp = oms + .prepare_transaction_to_send( OsRng.next_u64(), MicroTari::from(1000), MicroTari::from(20), None, "".to_string(), script!(Nop), - )) + ) + .await .unwrap(); let sender_tx_id = stp.get_tx_id().unwrap(); - runtime.block_on(oms.confirm_pending_transaction(sender_tx_id)).unwrap(); + oms.confirm_pending_transaction(sender_tx_id).await.unwrap(); drop(oms); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; - let balance = runtime.block_on(oms.get_balance()).unwrap(); + let balance = oms.get_balance().await.unwrap(); assert_eq!(balance.pending_outgoing_balance, available_balance); - let tx = runtime.block_on(complete_transaction(stp, oms.clone())); + let tx = complete_transaction(stp, oms.clone()).await; - runtime - .block_on(oms.confirm_transaction(sender_tx_id, tx.body.inputs().clone(), tx.body.outputs().clone())) + oms.confirm_transaction(sender_tx_id, tx.body.inputs().clone(), tx.body.outputs().clone()) + .await .unwrap(); - let balance = runtime.block_on(oms.get_balance()).unwrap(); + let balance = oms.get_balance().await.unwrap(); assert_eq!(balance.available_balance, expected_change); } -#[test] -fn coin_split_with_change() { +#[tokio::test] +async fn coin_split_with_change() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let val1 = 6_000 * uT; let val2 = 7_000 * uT; @@ -1190,14 +1158,15 @@ fn coin_split_with_change() { let (_ti, uo1) = make_input(&mut OsRng.clone(), val1, &factories.commitment); let (_ti, uo2) = make_input(&mut OsRng.clone(), val2, &factories.commitment); let (_ti, uo3) = make_input(&mut OsRng.clone(), val3, &factories.commitment); - assert!(runtime.block_on(oms.add_output(uo1)).is_ok()); - assert!(runtime.block_on(oms.add_output(uo2)).is_ok()); - assert!(runtime.block_on(oms.add_output(uo3)).is_ok()); + assert!(oms.add_output(uo1).await.is_ok()); + assert!(oms.add_output(uo2).await.is_ok()); + assert!(oms.add_output(uo3).await.is_ok()); let fee_per_gram = MicroTari::from(25); let split_count = 8; - let (_tx_id, coin_split_tx, fee, amount) = runtime - .block_on(oms.create_coin_split(1000.into(), split_count, fee_per_gram, None)) + let (_tx_id, coin_split_tx, fee, amount) = oms + .create_coin_split(1000.into(), split_count, fee_per_gram, None) + .await .unwrap(); assert_eq!(coin_split_tx.body.inputs().len(), 2); assert_eq!(coin_split_tx.body.outputs().len(), split_count + 1); @@ -1205,13 +1174,12 @@ fn coin_split_with_change() { assert_eq!(amount, val2 + val3); } -#[test] -fn coin_split_no_change() { +#[tokio::test] +async fn coin_split_no_change() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let fee_per_gram = MicroTari::from(25); let split_count = 15; @@ -1222,12 +1190,13 @@ fn coin_split_no_change() { let (_ti, uo1) = make_input(&mut OsRng.clone(), val1, &factories.commitment); let (_ti, uo2) = make_input(&mut OsRng.clone(), val2, &factories.commitment); let (_ti, uo3) = make_input(&mut OsRng.clone(), val3, &factories.commitment); - assert!(runtime.block_on(oms.add_output(uo1)).is_ok()); - assert!(runtime.block_on(oms.add_output(uo2)).is_ok()); - assert!(runtime.block_on(oms.add_output(uo3)).is_ok()); + assert!(oms.add_output(uo1).await.is_ok()); + assert!(oms.add_output(uo2).await.is_ok()); + assert!(oms.add_output(uo3).await.is_ok()); - let (_tx_id, coin_split_tx, fee, amount) = runtime - .block_on(oms.create_coin_split(1000.into(), split_count, fee_per_gram, None)) + let (_tx_id, coin_split_tx, fee, amount) = oms + .create_coin_split(1000.into(), split_count, fee_per_gram, None) + .await .unwrap(); assert_eq!(coin_split_tx.body.inputs().len(), 3); assert_eq!(coin_split_tx.body.outputs().len(), split_count); @@ -1235,13 +1204,12 @@ fn coin_split_no_change() { assert_eq!(amount, val1 + val2 + val3); } -#[test] -fn handle_coinbase() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn handle_coinbase() { let factories = CryptoFactories::default(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); - let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(&mut runtime, backend, true); + let (mut oms, _shutdown, _, _, _, _, _) = setup_output_manager_service(backend, true).await; let reward1 = MicroTari::from(1000); let fees1 = MicroTari::from(500); @@ -1253,37 +1221,25 @@ fn handle_coinbase() { let fees3 = MicroTari::from(500); let value3 = reward3 + fees3; - let _ = runtime - .block_on(oms.get_coinbase_transaction(1, reward1, fees1, 1)) - .unwrap(); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), 0); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 1); - assert_eq!( - runtime.block_on(oms.get_balance()).unwrap().pending_incoming_balance, - value1 - ); - let _tx2 = runtime - .block_on(oms.get_coinbase_transaction(2, reward2, fees2, 1)) - .unwrap(); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), 0); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 1); - assert_eq!( - runtime.block_on(oms.get_balance()).unwrap().pending_incoming_balance, - value2 - ); - let tx3 = runtime - .block_on(oms.get_coinbase_transaction(3, reward3, fees3, 2)) - .unwrap(); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), 0); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 2); + let _ = oms.get_coinbase_transaction(1, reward1, fees1, 1).await.unwrap(); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), 0); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 1); + assert_eq!(oms.get_balance().await.unwrap().pending_incoming_balance, value1); + let _tx2 = oms.get_coinbase_transaction(2, reward2, fees2, 1).await.unwrap(); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), 0); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 1); + assert_eq!(oms.get_balance().await.unwrap().pending_incoming_balance, value2); + let tx3 = oms.get_coinbase_transaction(3, reward3, fees3, 2).await.unwrap(); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), 0); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 2); assert_eq!( - runtime.block_on(oms.get_balance()).unwrap().pending_incoming_balance, + oms.get_balance().await.unwrap().pending_incoming_balance, value2 + value3 ); let output = tx3.body.outputs()[0].clone(); - let rewind_public_keys = runtime.block_on(oms.get_rewind_public_keys()).unwrap(); + let rewind_public_keys = oms.get_rewind_public_keys().await.unwrap(); let rewind_result = output .rewind_range_proof_value_only( &factories.range_proof, @@ -1293,28 +1249,22 @@ fn handle_coinbase() { .unwrap(); assert_eq!(rewind_result.committed_value, value3); - runtime - .block_on(oms.confirm_transaction(3, vec![], vec![output])) - .unwrap(); + oms.confirm_transaction(3, vec![], vec![output]).await.unwrap(); - assert_eq!(runtime.block_on(oms.get_pending_transactions()).unwrap().len(), 1); - assert_eq!(runtime.block_on(oms.get_unspent_outputs()).unwrap().len(), 1); - assert_eq!(runtime.block_on(oms.get_balance()).unwrap().available_balance, value3); + assert_eq!(oms.get_pending_transactions().await.unwrap().len(), 1); + assert_eq!(oms.get_unspent_outputs().await.unwrap().len(), 1); + assert_eq!(oms.get_balance().await.unwrap().available_balance, value3); + assert_eq!(oms.get_balance().await.unwrap().pending_incoming_balance, value2); assert_eq!( - runtime.block_on(oms.get_balance()).unwrap().pending_incoming_balance, - value2 - ); - assert_eq!( - runtime.block_on(oms.get_balance()).unwrap().pending_outgoing_balance, + oms.get_balance().await.unwrap().pending_outgoing_balance, MicroTari::from(0) ); } -#[test] -fn test_utxo_stxo_invalid_txo_validation() { +#[tokio::test] +async fn test_utxo_stxo_invalid_txo_validation() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); @@ -1374,8 +1324,8 @@ fn test_utxo_stxo_invalid_txo_validation() { .unwrap(); let (mut oms, _shutdown, _ts, _mock_rpc_server, server_node_identity, rpc_service_state, _) = - setup_output_manager_service(&mut runtime, backend, true); - let mut event_stream = oms.get_event_stream_fused(); + setup_output_manager_service(backend, true).await; + let mut event_stream = oms.get_event_stream(); let unspent_value1 = 500; let unspent_output1 = create_unblinded_output( @@ -1386,7 +1336,7 @@ fn test_utxo_stxo_invalid_txo_validation() { ); let unspent_tx_output1 = unspent_output1.as_transaction_output(&factories).unwrap(); - runtime.block_on(oms.add_output(unspent_output1.clone())).unwrap(); + oms.add_output(unspent_output1.clone()).await.unwrap(); let unspent_value2 = 800; let unspent_output2 = create_unblinded_output( @@ -1396,7 +1346,7 @@ fn test_utxo_stxo_invalid_txo_validation() { MicroTari::from(unspent_value2), ); - runtime.block_on(oms.add_output(unspent_output2)).unwrap(); + oms.add_output(unspent_output2).await.unwrap(); let unspent_value3 = 900; let unspent_output3 = create_unblinded_output( @@ -1407,7 +1357,7 @@ fn test_utxo_stxo_invalid_txo_validation() { ); let unspent_tx_output3 = unspent_output3.as_transaction_output(&factories).unwrap(); - runtime.block_on(oms.add_output(unspent_output3.clone())).unwrap(); + oms.add_output(unspent_output3.clone()).await.unwrap(); let unspent_value4 = 901; let unspent_output4 = create_unblinded_output( @@ -1418,44 +1368,42 @@ fn test_utxo_stxo_invalid_txo_validation() { ); let unspent_tx_output4 = unspent_output4.as_transaction_output(&factories).unwrap(); - runtime.block_on(oms.add_output(unspent_output4.clone())).unwrap(); + oms.add_output(unspent_output4.clone()).await.unwrap(); rpc_service_state.set_utxos(vec![invalid_output.as_transaction_output(&factories).unwrap()]); - runtime - .block_on(oms.set_base_node_public_key(server_node_identity.public_key().clone())) + oms.set_base_node_public_key(server_node_identity.public_key().clone()) + .await .unwrap(); - runtime - .block_on(oms.validate_txos(TxoValidationType::Invalid, ValidationRetryStrategy::Limited(5))) + oms.validate_txos(TxoValidationType::Invalid, ValidationRetryStrategy::Limited(5)) + .await .unwrap(); - let _fetch_utxo_calls = runtime - .block_on(rpc_service_state.wait_pop_fetch_utxos_calls(1, Duration::from_secs(60))) + let _fetch_utxo_calls = rpc_service_state + .wait_pop_fetch_utxos_calls(1, Duration::from_secs(60)) + .await .unwrap(); - runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); - let mut success = false; - loop { - futures::select! { - event = event_stream.select_next_some() => { - if let Ok(msg) = event { - if let OutputManagerEvent::TxoValidationSuccess(_,TxoValidationType::Invalid) = (*msg).clone() { - success = true; - break; - }; - } - }, - () = delay => { - break; - }, - } + let delay = time::sleep(Duration::from_secs(60)).fuse(); + tokio::pin!(delay); + let mut success = false; + loop { + tokio::select! { + Ok(event) = event_stream.recv() => { + if let OutputManagerEvent::TxoValidationSuccess(_,TxoValidationType::Invalid) = &*event { + success = true; + break; + } + }, + () = &mut delay => { + break; + }, } - assert!(success, "Did not receive validation success event"); - }); + } + assert!(success, "Did not receive validation success event"); - let outputs = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let outputs = oms.get_unspent_outputs().await.unwrap(); assert_eq!(outputs.len(), 5); @@ -1466,36 +1414,34 @@ fn test_utxo_stxo_invalid_txo_validation() { unspent_tx_output3, ]); - runtime - .block_on(oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::UntilSuccess)) + oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::UntilSuccess) + .await .unwrap(); - let _fetch_utxo_calls = runtime - .block_on(rpc_service_state.wait_pop_fetch_utxos_calls(3, Duration::from_secs(60))) + let _fetch_utxo_calls = rpc_service_state + .wait_pop_fetch_utxos_calls(3, Duration::from_secs(60)) + .await .unwrap(); - runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); - let mut success = false; - loop { - futures::select! { - event = event_stream.select_next_some() => { - if let Ok(msg) = event { - if let OutputManagerEvent::TxoValidationSuccess(_,TxoValidationType::Unspent) = (*msg).clone() { - success = true; - break; - }; - }; - }, - () = delay => { - break; - }, - } + let delay = time::sleep(Duration::from_secs(60)).fuse(); + tokio::pin!(delay); + let mut success = false; + loop { + tokio::select! { + Ok(event) = event_stream.recv() => { + if let OutputManagerEvent::TxoValidationSuccess(_,TxoValidationType::Unspent) = &*event { + success = true; + break; + } + }, + () = &mut delay => { + break; + }, } - assert!(success, "Did not receive validation success event"); - }); + } + assert!(success, "Did not receive validation success event"); - let outputs = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let outputs = oms.get_unspent_outputs().await.unwrap(); assert_eq!(outputs.len(), 4); assert!(outputs.iter().any(|o| o == &unspent_output1)); @@ -1505,46 +1451,45 @@ fn test_utxo_stxo_invalid_txo_validation() { rpc_service_state.set_utxos(vec![spent_tx_output1]); - runtime - .block_on(oms.validate_txos(TxoValidationType::Spent, ValidationRetryStrategy::UntilSuccess)) + oms.validate_txos(TxoValidationType::Spent, ValidationRetryStrategy::UntilSuccess) + .await .unwrap(); - let _fetch_utxo_calls = runtime - .block_on(rpc_service_state.wait_pop_fetch_utxos_calls(1, Duration::from_secs(60))) + let _fetch_utxo_calls = rpc_service_state + .wait_pop_fetch_utxos_calls(1, Duration::from_secs(60)) + .await .unwrap(); - runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); - let mut success = false; - loop { - futures::select! { - event = event_stream.select_next_some() => { - if let Ok(msg) = event { - if let OutputManagerEvent::TxoValidationSuccess(_, TxoValidationType::Spent) = (*msg).clone() { - success = true; - break; - }; - } - }, - () = delay => { - break; - }, - } + let delay = time::sleep(Duration::from_secs(60)).fuse(); + tokio::pin!(delay); + let mut success = false; + loop { + tokio::select! { + event = event_stream.recv() => { + if let Ok(msg) = event { + if let OutputManagerEvent::TxoValidationSuccess(_, TxoValidationType::Spent) = (*msg).clone() { + success = true; + break; + }; + } + }, + () = &mut delay => { + break; + }, } - assert!(success, "Did not receive validation success event"); - }); + } + assert!(success, "Did not receive validation success event"); - let outputs = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let outputs = oms.get_unspent_outputs().await.unwrap(); assert_eq!(outputs.len(), 5); assert!(outputs.iter().any(|o| o == &spent_output1)); } -#[test] -fn test_base_node_switch_during_validation() { +#[tokio::test] +async fn test_base_node_switch_during_validation() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); @@ -1556,8 +1501,8 @@ fn test_base_node_switch_during_validation() { server_node_identity, mut rpc_service_state, _connectivity_mock_state, - ) = setup_output_manager_service(&mut runtime, backend, true); - let mut event_stream = oms.get_event_stream_fused(); + ) = setup_output_manager_service(backend, true).await; + let mut event_stream = oms.get_event_stream(); let unspent_value1 = 500; let unspent_output1 = create_unblinded_output( @@ -1568,7 +1513,7 @@ fn test_base_node_switch_during_validation() { ); let unspent_tx_output1 = unspent_output1.as_transaction_output(&factories).unwrap(); - runtime.block_on(oms.add_output(unspent_output1)).unwrap(); + oms.add_output(unspent_output1).await.unwrap(); let unspent_value2 = 800; let unspent_output2 = create_unblinded_output( @@ -1578,7 +1523,7 @@ fn test_base_node_switch_during_validation() { MicroTari::from(unspent_value2), ); - runtime.block_on(oms.add_output(unspent_output2)).unwrap(); + oms.add_output(unspent_output2).await.unwrap(); let unspent_value3 = 900; let unspent_output3 = create_unblinded_output( @@ -1589,7 +1534,7 @@ fn test_base_node_switch_during_validation() { ); let unspent_tx_output3 = unspent_output3.as_transaction_output(&factories).unwrap(); - runtime.block_on(oms.add_output(unspent_output3)).unwrap(); + oms.add_output(unspent_output3).await.unwrap(); // First RPC server state rpc_service_state.set_utxos(vec![unspent_tx_output1, unspent_tx_output3]); @@ -1598,53 +1543,52 @@ fn test_base_node_switch_during_validation() { // New base node we will switch to let new_server_node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); - runtime - .block_on(oms.set_base_node_public_key(server_node_identity.public_key().clone())) + oms.set_base_node_public_key(server_node_identity.public_key().clone()) + .await .unwrap(); - runtime - .block_on(oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::UntilSuccess)) + oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::UntilSuccess) + .await .unwrap(); - let _fetch_utxo_calls = runtime - .block_on(rpc_service_state.wait_pop_fetch_utxos_calls(1, Duration::from_secs(60))) + let _fetch_utxo_calls = rpc_service_state + .wait_pop_fetch_utxos_calls(1, Duration::from_secs(60)) + .await .unwrap(); - runtime - .block_on(oms.set_base_node_public_key(new_server_node_identity.public_key().clone())) + oms.set_base_node_public_key(new_server_node_identity.public_key().clone()) + .await .unwrap(); - runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); - let mut abort = false; - loop { - futures::select! { - event = event_stream.select_next_some() => { - if let Ok(msg) = event { - if let OutputManagerEvent::TxoValidationAborted(_,_) = (*msg).clone() { - abort = true; - break; - } - } - }, - () = delay => { - break; - }, - } + let delay = time::sleep(Duration::from_secs(60)).fuse(); + tokio::pin!(delay); + let mut abort = false; + loop { + tokio::select! { + event = event_stream.recv() => { + if let Ok(msg) = event { + if let OutputManagerEvent::TxoValidationAborted(_,_) = (*msg).clone() { + abort = true; + break; + } + } + }, + () = &mut delay => { + break; + }, } - assert!(abort, "Did not receive validation abort"); - }); + } + assert!(abort, "Did not receive validation abort"); } -#[test] -fn test_txo_validation_connection_timeout_retries() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn test_txo_validation_connection_timeout_retries() { let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); let (mut oms, _shutdown, _ts, _mock_rpc_server, server_node_identity, _rpc_service_state, _connectivity_mock_state) = - setup_output_manager_service(&mut runtime, backend, false); - let mut event_stream = oms.get_event_stream_fused(); + setup_output_manager_service(backend, false).await; + let mut event_stream = oms.get_event_stream(); let unspent_value1 = 500; let unspent_output1 = create_unblinded_output( @@ -1654,7 +1598,7 @@ fn test_txo_validation_connection_timeout_retries() { MicroTari::from(unspent_value1), ); - runtime.block_on(oms.add_output(unspent_output1)).unwrap(); + oms.add_output(unspent_output1).await.unwrap(); let unspent_value2 = 800; let unspent_output2 = create_unblinded_output( @@ -1664,57 +1608,54 @@ fn test_txo_validation_connection_timeout_retries() { MicroTari::from(unspent_value2), ); - runtime.block_on(oms.add_output(unspent_output2)).unwrap(); + oms.add_output(unspent_output2).await.unwrap(); - runtime - .block_on(oms.set_base_node_public_key(server_node_identity.public_key().clone())) + oms.set_base_node_public_key(server_node_identity.public_key().clone()) + .await .unwrap(); - runtime - .block_on(oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::Limited(1))) + oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::Limited(1)) + .await .unwrap(); - runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); - let mut timeout = 0; - let mut failed = 0; - loop { - futures::select! { - event = event_stream.select_next_some() => { - if let Ok(msg) = event { - match (*msg).clone() { - OutputManagerEvent::TxoValidationTimedOut(_,_) => { - timeout+=1; - }, - OutputManagerEvent::TxoValidationFailure(_,_) => { - failed+=1; - }, - _ => (), - } - }; - if timeout+failed >= 3 { - break; - } - }, - () = delay => { + let delay = time::sleep(Duration::from_secs(60)); + tokio::pin!(delay); + let mut timeout = 0; + let mut failed = 0; + loop { + tokio::select! { + Ok(event) = event_stream.recv() => { + match &*event { + OutputManagerEvent::TxoValidationTimedOut(_,_) => { + timeout+=1; + }, + OutputManagerEvent::TxoValidationFailure(_,_) => { + failed+=1; + }, + _ => (), + } + + if timeout+failed >= 3 { break; - }, - } + } + }, + () = &mut delay => { + break; + }, } - assert_eq!(failed, 1); - assert_eq!(timeout, 2); - }); + } + assert_eq!(failed, 1); + assert_eq!(timeout, 2); } -#[test] -fn test_txo_validation_rpc_error_retries() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn test_txo_validation_rpc_error_retries() { let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); let (mut oms, _shutdown, _ts, _mock_rpc_server, server_node_identity, rpc_service_state, _connectivity_mock_state) = - setup_output_manager_service(&mut runtime, backend, true); - let mut event_stream = oms.get_event_stream_fused(); + setup_output_manager_service(backend, true).await; + let mut event_stream = oms.get_event_stream(); rpc_service_state.set_rpc_status_error(Some(RpcStatus::bad_request("blah".to_string()))); let unspent_value1 = 500; @@ -1725,7 +1666,7 @@ fn test_txo_validation_rpc_error_retries() { MicroTari::from(unspent_value1), ); - runtime.block_on(oms.add_output(unspent_output1)).unwrap(); + oms.add_output(unspent_output1).await.unwrap(); let unspent_value2 = 800; let unspent_output2 = create_unblinded_output( @@ -1735,44 +1676,42 @@ fn test_txo_validation_rpc_error_retries() { MicroTari::from(unspent_value2), ); - runtime.block_on(oms.add_output(unspent_output2)).unwrap(); + oms.add_output(unspent_output2).await.unwrap(); - runtime - .block_on(oms.set_base_node_public_key(server_node_identity.public_key().clone())) + oms.set_base_node_public_key(server_node_identity.public_key().clone()) + .await .unwrap(); - runtime - .block_on(oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::Limited(1))) + oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::Limited(1)) + .await .unwrap(); - runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); - let mut failed = 0; - loop { - futures::select! { - event = event_stream.select_next_some() => { - if let Ok(msg) = event { - if let OutputManagerEvent::TxoValidationFailure(_,_) = (*msg).clone() { - failed+=1; - } - } - - if failed >= 1 { - break; + let delay = time::sleep(Duration::from_secs(60)).fuse(); + tokio::pin!(delay); + let mut failed = 0; + loop { + tokio::select! { + event = event_stream.recv() => { + if let Ok(msg) = event { + if let OutputManagerEvent::TxoValidationFailure(_,_) = (*msg).clone() { + failed+=1; } - }, - () = delay => { + } + + if failed >= 1 { break; - }, - } + } + }, + () = &mut delay => { + break; + }, } - assert_eq!(failed, 1); - }); + } + assert_eq!(failed, 1); } -#[test] -fn test_txo_validation_rpc_timeout() { - let mut runtime = Runtime::new().unwrap(); +#[tokio::test] +async fn test_txo_validation_rpc_timeout() { let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); @@ -1784,8 +1723,8 @@ fn test_txo_validation_rpc_timeout() { server_node_identity, mut rpc_service_state, _connectivity_mock_state, - ) = setup_output_manager_service(&mut runtime, backend, true); - let mut event_stream = oms.get_event_stream_fused(); + ) = setup_output_manager_service(backend, true).await; + let mut event_stream = oms.get_event_stream(); rpc_service_state.set_response_delay(Some(Duration::from_secs(120))); let unspent_value1 = 500; @@ -1796,7 +1735,7 @@ fn test_txo_validation_rpc_timeout() { MicroTari::from(unspent_value1), ); - runtime.block_on(oms.add_output(unspent_output1)).unwrap(); + oms.add_output(unspent_output1).await.unwrap(); let unspent_value2 = 800; let unspent_output2 = create_unblinded_output( @@ -1806,57 +1745,51 @@ fn test_txo_validation_rpc_timeout() { MicroTari::from(unspent_value2), ); - runtime.block_on(oms.add_output(unspent_output2)).unwrap(); + oms.add_output(unspent_output2).await.unwrap(); - runtime - .block_on(oms.set_base_node_public_key(server_node_identity.public_key().clone())) + oms.set_base_node_public_key(server_node_identity.public_key().clone()) + .await .unwrap(); - runtime - .block_on(oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::Limited(1))) + oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::Limited(1)) + .await .unwrap(); - runtime.block_on(async { - let mut delay = delay_for( - RpcClientConfig::default().deadline.unwrap() + - RpcClientConfig::default().deadline_grace_period + - Duration::from_secs(30), - ) - .fuse(); - let mut failed = 0; - loop { - futures::select! { - event = event_stream.select_next_some() => { - if let Ok(msg) = event { - if let OutputManagerEvent::TxoValidationFailure(_,_) = (*msg).clone() { - failed+=1; - } + let delay = + time::sleep(RpcClientConfig::default().timeout_with_grace_period().unwrap() + Duration::from_secs(30)).fuse(); + tokio::pin!(delay); + let mut failed = 0; + loop { + tokio::select! { + event = event_stream.recv() => { + if let Ok(msg) = event { + if let OutputManagerEvent::TxoValidationFailure(_,_) = &*msg { + failed+=1; } + } - if failed >= 1 { - break; - } - }, - () = delay => { + if failed >= 1 { break; - }, - } + } + }, + () = &mut delay => { + break; + }, } - assert_eq!(failed, 1); - }); + } + assert_eq!(failed, 1); } -#[test] -fn test_txo_validation_base_node_not_synced() { +#[tokio::test] +async fn test_txo_validation_base_node_not_synced() { let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); let (mut oms, _shutdown, _ts, _mock_rpc_server, server_node_identity, rpc_service_state, _connectivity_mock_state) = - setup_output_manager_service(&mut runtime, backend, true); - let mut event_stream = oms.get_event_stream_fused(); + setup_output_manager_service(backend, true).await; + let mut event_stream = oms.get_event_stream(); rpc_service_state.set_is_synced(false); let unspent_value1 = 500; @@ -1868,7 +1801,7 @@ fn test_txo_validation_base_node_not_synced() { ); let unspent_tx_output1 = unspent_output1.as_transaction_output(&factories).unwrap(); - runtime.block_on(oms.add_output(unspent_output1.clone())).unwrap(); + oms.add_output(unspent_output1.clone()).await.unwrap(); let unspent_value2 = 800; let unspent_output2 = create_unblinded_output( @@ -1878,74 +1811,67 @@ fn test_txo_validation_base_node_not_synced() { MicroTari::from(unspent_value2), ); - runtime.block_on(oms.add_output(unspent_output2)).unwrap(); + oms.add_output(unspent_output2).await.unwrap(); - runtime - .block_on(oms.set_base_node_public_key(server_node_identity.public_key().clone())) + oms.set_base_node_public_key(server_node_identity.public_key().clone()) + .await .unwrap(); - runtime - .block_on(oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::Limited(5))) + oms.validate_txos(TxoValidationType::Unspent, ValidationRetryStrategy::Limited(5)) + .await .unwrap(); - runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); - let mut delayed = 0; - loop { - futures::select! { - event = event_stream.select_next_some() => { - if let Ok(msg) = event { - if let OutputManagerEvent::TxoValidationDelayed(_,_) = (*msg).clone() { - delayed+=1; - } - } - if delayed >= 2 { - break; - } - }, - () = delay => { + let delay = time::sleep(Duration::from_secs(60)).fuse(); + tokio::pin!(delay); + let mut delayed = 0; + loop { + tokio::select! { + Ok(event) = event_stream.recv() => { + if let OutputManagerEvent::TxoValidationDelayed(_,_) = &*event { + delayed += 1; + } + if delayed >= 2 { break; - }, - } + } + }, + () = &mut delay => { + break; + }, } - assert_eq!(delayed, 2); - }); + } + assert_eq!(delayed, 2); rpc_service_state.set_is_synced(true); rpc_service_state.set_utxos(vec![unspent_tx_output1]); - runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); - let mut success = false; - loop { - futures::select! { - event = event_stream.select_next_some() => { - if let Ok(msg) = event { - if let OutputManagerEvent::TxoValidationSuccess(_,_) = (*msg).clone() { - success = true; - break; - } - } - }, - () = delay => { - break; - }, - } + let delay = time::sleep(Duration::from_secs(60)).fuse(); + tokio::pin!(delay); + let mut success = false; + loop { + tokio::select! { + Ok(event) = event_stream.recv() => { + if let OutputManagerEvent::TxoValidationSuccess(_,_) = &*event { + success = true; + break; + } + }, + () = &mut delay => { + break; + }, } - assert!(success, "Did not receive validation success event"); - }); + } + assert!(success, "Did not receive validation success event"); - let outputs = runtime.block_on(oms.get_unspent_outputs()).unwrap(); + let outputs = oms.get_unspent_outputs().await.unwrap(); assert_eq!(outputs.len(), 1); assert!(outputs.iter().any(|o| o == &unspent_output1)); } -#[test] -fn test_oms_key_manager_discrepancy() { +#[tokio::test] +async fn test_oms_key_manager_discrepancy() { let shutdown = Shutdown::new(); let factories = CryptoFactories::default(); - let mut runtime = Runtime::new().unwrap(); let (_oms_request_sender, oms_request_receiver) = reply_channel::unbounded(); let (oms_event_publisher, _) = broadcast::channel(200); @@ -1959,7 +1885,7 @@ fn test_oms_key_manager_discrepancy() { let basenode_service_handle = BaseNodeServiceHandle::new(sender, event_publisher_bns); let mut mock_base_node_service = MockBaseNodeService::new(receiver_bns, shutdown.to_signal()); mock_base_node_service.set_default_base_node_state(); - runtime.spawn(mock_base_node_service.run()); + task::spawn(mock_base_node_service.run()); let (connectivity_manager, _connectivity_mock) = create_connectivity_mock(); @@ -1968,45 +1894,45 @@ fn test_oms_key_manager_discrepancy() { let master_key1 = CommsSecretKey::random(&mut OsRng); - let output_manager_service = runtime - .block_on(OutputManagerService::new( - OutputManagerServiceConfig::default(), - ts_handle.clone(), - oms_request_receiver, - db.clone(), - oms_event_publisher.clone(), - factories.clone(), - constants.clone(), - shutdown.to_signal(), - basenode_service_handle.clone(), - connectivity_manager.clone(), - master_key1.clone(), - )) - .unwrap(); + let output_manager_service = OutputManagerService::new( + OutputManagerServiceConfig::default(), + ts_handle.clone(), + oms_request_receiver, + db.clone(), + oms_event_publisher.clone(), + factories.clone(), + constants.clone(), + shutdown.to_signal(), + basenode_service_handle.clone(), + connectivity_manager.clone(), + master_key1.clone(), + ) + .await + .unwrap(); drop(output_manager_service); let (_oms_request_sender2, oms_request_receiver2) = reply_channel::unbounded(); - let output_manager_service2 = runtime - .block_on(OutputManagerService::new( - OutputManagerServiceConfig::default(), - ts_handle.clone(), - oms_request_receiver2, - db.clone(), - oms_event_publisher.clone(), - factories.clone(), - constants.clone(), - shutdown.to_signal(), - basenode_service_handle.clone(), - connectivity_manager.clone(), - master_key1, - )) - .expect("Should be able to make a new OMS with same master key"); + let output_manager_service2 = OutputManagerService::new( + OutputManagerServiceConfig::default(), + ts_handle.clone(), + oms_request_receiver2, + db.clone(), + oms_event_publisher.clone(), + factories.clone(), + constants.clone(), + shutdown.to_signal(), + basenode_service_handle.clone(), + connectivity_manager.clone(), + master_key1, + ) + .await + .expect("Should be able to make a new OMS with same master key"); drop(output_manager_service2); let (_oms_request_sender3, oms_request_receiver3) = reply_channel::unbounded(); let master_key2 = CommsSecretKey::random(&mut OsRng); - let output_manager_service3 = runtime.block_on(OutputManagerService::new( + let output_manager_service3 = OutputManagerService::new( OutputManagerServiceConfig::default(), ts_handle, oms_request_receiver3, @@ -2018,7 +1944,8 @@ fn test_oms_key_manager_discrepancy() { basenode_service_handle, connectivity_manager, master_key2, - )); + ) + .await; assert!(matches!( output_manager_service3, @@ -2026,26 +1953,25 @@ fn test_oms_key_manager_discrepancy() { )); } -#[test] -fn get_coinbase_tx_for_same_height() { +#[tokio::test] +async fn get_coinbase_tx_for_same_height() { let (connection, _tempdir) = get_temp_sqlite_database_connection(); - let mut runtime = Runtime::new().unwrap(); let (mut oms, _shutdown, _, _, _, _, _) = - setup_output_manager_service(&mut runtime, OutputManagerSqliteDatabase::new(connection, None), true); + setup_output_manager_service(OutputManagerSqliteDatabase::new(connection, None), true).await; - runtime - .block_on(oms.get_coinbase_transaction(1, 100_000.into(), 100.into(), 1)) + oms.get_coinbase_transaction(1, 100_000.into(), 100.into(), 1) + .await .unwrap(); - let pending_transactions = runtime.block_on(oms.get_pending_transactions()).unwrap(); + let pending_transactions = oms.get_pending_transactions().await.unwrap(); assert!(pending_transactions.values().any(|p| p.tx_id == 1)); - runtime - .block_on(oms.get_coinbase_transaction(2, 100_000.into(), 100.into(), 1)) + oms.get_coinbase_transaction(2, 100_000.into(), 100.into(), 1) + .await .unwrap(); - let pending_transactions = runtime.block_on(oms.get_pending_transactions()).unwrap(); + let pending_transactions = oms.get_pending_transactions().await.unwrap(); assert!(!pending_transactions.values().any(|p| p.tx_id == 1)); assert!(pending_transactions.values().any(|p| p.tx_id == 2)); } diff --git a/base_layer/wallet/tests/output_manager_service/storage.rs b/base_layer/wallet/tests/output_manager_service/storage.rs index c0609da64c..746d3cd9c5 100644 --- a/base_layer/wallet/tests/output_manager_service/storage.rs +++ b/base_layer/wallet/tests/output_manager_service/storage.rs @@ -20,7 +20,8 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::support::{data::get_temp_sqlite_database_connection, utils::make_input}; +use std::time::Duration; + use aes_gcm::{ aead::{generic_array::GenericArray, NewAead}, Aes256Gcm, @@ -28,14 +29,16 @@ use aes_gcm::{ use chrono::{Duration as ChronoDuration, Utc}; use diesel::result::{DatabaseErrorKind, Error::DatabaseError}; use rand::{rngs::OsRng, RngCore}; -use std::time::Duration; +use tari_crypto::{commitment::HomomorphicCommitmentFactory, keys::SecretKey, script::TariScript}; +use tokio::runtime::Runtime; + +use tari_common_types::types::PrivateKey; use tari_core::transactions::{ helpers::{create_unblinded_output, TestParams}, tari_amount::MicroTari, transaction::OutputFeatures, - types::{CryptoFactories, PrivateKey}, + CryptoFactories, }; -use tari_crypto::{commitment::HomomorphicCommitmentFactory, keys::SecretKey, script::TariScript}; use tari_wallet::output_manager_service::{ error::OutputManagerStorageError, service::Balance, @@ -46,11 +49,11 @@ use tari_wallet::output_manager_service::{ }, }; -use tokio::runtime::Runtime; +use crate::support::{data::get_temp_sqlite_database_connection, utils::make_input}; #[allow(clippy::same_item_push)] pub fn test_db_backend(backend: T) { - let mut runtime = Runtime::new().unwrap(); + let runtime = Runtime::new().unwrap(); let db = OutputManagerDatabase::new(backend); let factories = CryptoFactories::default(); @@ -392,7 +395,7 @@ pub fn test_output_manager_sqlite_db_encrypted() { #[test] pub fn test_key_manager_crud() { - let mut runtime = Runtime::new().unwrap(); + let runtime = Runtime::new().unwrap(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection, None); let db = OutputManagerDatabase::new(backend); @@ -429,7 +432,7 @@ pub fn test_key_manager_crud() { assert_eq!(read_state3.primary_key_index, 2); } -#[tokio_macros::test] +#[tokio::test] pub async fn test_short_term_encumberance() { let factories = CryptoFactories::default(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); @@ -510,7 +513,7 @@ pub async fn test_short_term_encumberance() { ); } -#[tokio_macros::test] +#[tokio::test] pub async fn test_no_duplicate_outputs() { let factories = CryptoFactories::default(); let (connection, _tempdir) = get_temp_sqlite_database_connection(); diff --git a/base_layer/wallet/tests/support/comms_and_services.rs b/base_layer/wallet/tests/support/comms_and_services.rs index f9d8010ac7..1b1243d72b 100644 --- a/base_layer/wallet/tests/support/comms_and_services.rs +++ b/base_layer/wallet/tests/support/comms_and_services.rs @@ -20,8 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::Sink; -use std::{error::Error, sync::Arc, time::Duration}; +use std::{sync::Arc, time::Duration}; use tari_comms::{ message::MessageTag, multiaddr::Multiaddr, @@ -32,7 +31,7 @@ use tari_comms::{ }; use tari_comms_dht::{envelope::DhtMessageHeader, Dht}; use tari_p2p::{ - comms_connector::{InboundDomainConnector, PeerMessage}, + comms_connector::InboundDomainConnector, domain_message::DomainMessage, initialization::initialize_local_test_comms, }; @@ -43,18 +42,14 @@ pub fn get_next_memory_address() -> Multiaddr { format!("/memory/{}", port).parse().unwrap() } -pub async fn setup_comms_services( +pub async fn setup_comms_services( node_identity: Arc, peers: Vec>, - publisher: InboundDomainConnector, + publisher: InboundDomainConnector, database_path: String, discovery_request_timeout: Duration, shutdown_signal: ShutdownSignal, -) -> (CommsNode, Dht) -where - TSink: Sink> + Clone + Unpin + Send + Sync + 'static, - TSink::Error: Error + Send + Sync, -{ +) -> (CommsNode, Dht) { let peers = peers.into_iter().map(|ni| ni.to_peer()).collect(); let (comms, dht, _) = initialize_local_test_comms( node_identity, diff --git a/base_layer/wallet/tests/support/rpc.rs b/base_layer/wallet/tests/support/rpc.rs index 0f7009bdd0..29e99e3372 100644 --- a/base_layer/wallet/tests/support/rpc.rs +++ b/base_layer/wallet/tests/support/rpc.rs @@ -25,6 +25,7 @@ use std::{ sync::{Arc, Mutex}, time::{Duration, Instant}, }; +use tari_common_types::types::Signature; use tari_comms::protocol::rpc::{Request, Response, RpcStatus}; use tari_core::{ base_node::{ @@ -52,12 +53,9 @@ use tari_core::{ }, }, tari_utilities::Hashable, - transactions::{ - transaction::{Transaction, TransactionOutput}, - types::Signature, - }, + transactions::transaction::{Transaction, TransactionOutput}, }; -use tokio::time::delay_for; +use tokio::time::sleep; /// This macro unlocks a Mutex or RwLock. If the lock is /// poisoned (i.e. panic while unlocked) the last value @@ -212,7 +210,7 @@ impl BaseNodeWalletRpcMockState { return Ok((*lock).drain(..num_calls).collect()); } drop(lock); - delay_for(Duration::from_millis(100)).await; + sleep(Duration::from_millis(100)).await; } Err(format!( "Did not receive enough calls within the timeout period, received {}, expected {}.", @@ -234,7 +232,7 @@ impl BaseNodeWalletRpcMockState { return Ok((*lock).drain(..num_calls).collect()); } drop(lock); - delay_for(Duration::from_millis(100)).await; + sleep(Duration::from_millis(100)).await; } Err(format!( "Did not receive enough calls within the timeout period, received {}, expected {}.", @@ -256,7 +254,7 @@ impl BaseNodeWalletRpcMockState { return Ok((*lock).drain(..num_calls).collect()); } drop(lock); - delay_for(Duration::from_millis(100)).await; + sleep(Duration::from_millis(100)).await; } Err(format!( "Did not receive enough calls within the timeout period, received {}, expected {}.", @@ -276,7 +274,7 @@ impl BaseNodeWalletRpcMockState { return Ok((*lock).drain(..num_calls).collect()); } drop(lock); - delay_for(Duration::from_millis(100)).await; + sleep(Duration::from_millis(100)).await; } Err("Did not receive enough calls within the timeout period".to_string()) } @@ -318,7 +316,7 @@ impl BaseNodeWalletService for BaseNodeWalletRpcMockService { ) -> Result, RpcStatus> { let delay_lock = *acquire_lock!(self.state.response_delay); if let Some(delay) = delay_lock { - delay_for(delay).await; + sleep(delay).await; } let message = request.into_message(); @@ -345,7 +343,7 @@ impl BaseNodeWalletService for BaseNodeWalletRpcMockService { ) -> Result, RpcStatus> { let delay_lock = *acquire_lock!(self.state.response_delay); if let Some(delay) = delay_lock { - delay_for(delay).await; + sleep(delay).await; } let message = request.into_message(); @@ -371,7 +369,7 @@ impl BaseNodeWalletService for BaseNodeWalletRpcMockService { ) -> Result, RpcStatus> { let delay_lock = *acquire_lock!(self.state.response_delay); if let Some(delay) = delay_lock { - delay_for(delay).await; + sleep(delay).await; } let message = request.into_message(); @@ -415,7 +413,7 @@ impl BaseNodeWalletService for BaseNodeWalletRpcMockService { ) -> Result, RpcStatus> { let delay_lock = *acquire_lock!(self.state.response_delay); if let Some(delay) = delay_lock { - delay_for(delay).await; + sleep(delay).await; } let message = request.into_message(); @@ -448,7 +446,7 @@ impl BaseNodeWalletService for BaseNodeWalletRpcMockService { async fn get_tip_info(&self, _request: Request<()>) -> Result, RpcStatus> { let delay_lock = *acquire_lock!(self.state.response_delay); if let Some(delay) = delay_lock { - delay_for(delay).await; + sleep(delay).await; } log::info!("Get tip info call received"); @@ -483,17 +481,18 @@ mod test { }; use std::convert::TryFrom; + use tari_common_types::types::BlindingFactor; use tari_core::{ base_node::{ proto::wallet_rpc::{TxSubmissionRejectionReason, TxSubmissionResponse}, rpc::{BaseNodeWalletRpcClient, BaseNodeWalletRpcServer}, }, proto::base_node::{ChainMetadata, TipInfoResponse}, - transactions::{transaction::Transaction, types::BlindingFactor}, + transactions::transaction::Transaction, }; use tokio::time::Duration; - #[tokio_macros::test] + #[tokio::test] async fn test_wallet_rpc_mock() { let server_node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); let client_node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); diff --git a/base_layer/wallet/tests/support/utils.rs b/base_layer/wallet/tests/support/utils.rs index 630b8cd6fe..034c116c08 100644 --- a/base_layer/wallet/tests/support/utils.rs +++ b/base_layer/wallet/tests/support/utils.rs @@ -22,11 +22,11 @@ use rand::{CryptoRng, Rng}; use std::{fmt::Debug, thread, time::Duration}; +use tari_common_types::types::{CommitmentFactory, PrivateKey, PublicKey}; use tari_core::transactions::{ helpers::{create_unblinded_output, TestParams as TestParamsHelpers}, tari_amount::MicroTari, transaction::{OutputFeatures, TransactionInput, UnblindedOutput}, - types::{CommitmentFactory, PrivateKey, PublicKey}, }; use tari_crypto::{ keys::{PublicKey as PublicKeyTrait, SecretKey as SecretKeyTrait}, diff --git a/base_layer/wallet/tests/transaction_service/service.rs b/base_layer/wallet/tests/transaction_service/service.rs index dc508f7321..6e1265460d 100644 --- a/base_layer/wallet/tests/transaction_service/service.rs +++ b/base_layer/wallet/tests/transaction_service/service.rs @@ -20,30 +20,48 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{ - support::{ - comms_and_services::{create_dummy_message, get_next_memory_address, setup_comms_services}, - rpc::{BaseNodeWalletRpcMockService, BaseNodeWalletRpcMockState}, - utils::{make_input, TestParams}, - }, - transaction_service::transaction_protocols::add_transaction_to_database, +use std::{ + convert::{TryFrom, TryInto}, + path::Path, + sync::Arc, + time::Duration, }; + use chrono::{Duration as ChronoDuration, Utc}; use futures::{ channel::{mpsc, mpsc::Sender}, FutureExt, SinkExt, - StreamExt, }; use prost::Message; use rand::{rngs::OsRng, RngCore}; -use std::{ - convert::{TryFrom, TryInto}, - path::Path, - sync::Arc, - time::Duration, +use tari_crypto::{ + commitment::HomomorphicCommitmentFactory, + common::Blake256, + inputs, + keys::{PublicKey as PK, SecretKey as SK}, + script, + script::{ExecutionStack, TariScript}, +}; +use tempfile::tempdir; +use tokio::{ + runtime, + runtime::{Builder, Runtime}, + sync::{broadcast, broadcast::channel}, +}; + +use crate::{ + support::{ + comms_and_services::{create_dummy_message, get_next_memory_address, setup_comms_services}, + rpc::{BaseNodeWalletRpcMockService, BaseNodeWalletRpcMockState}, + utils::{make_input, TestParams}, + }, + transaction_service::transaction_protocols::add_transaction_to_database, +}; +use tari_common_types::{ + chain_metadata::ChainMetadata, + types::{PrivateKey, PublicKey, Signature}, }; -use tari_common_types::chain_metadata::ChainMetadata; use tari_comms::{ message::EnvelopeBody, peer_manager::{NodeIdentity, PeerFeatures}, @@ -75,19 +93,11 @@ use tari_core::{ tari_amount::*, transaction::{KernelBuilder, KernelFeatures, OutputFeatures, Transaction}, transaction_protocol::{proto, recipient::RecipientSignedMessage, sender::TransactionSenderMessage}, - types::{CryptoFactories, PrivateKey, PublicKey, Signature}, + CryptoFactories, ReceiverTransactionProtocol, SenderTransactionProtocol, }, }; -use tari_crypto::{ - commitment::HomomorphicCommitmentFactory, - common::Blake256, - inputs, - keys::{PublicKey as PK, SecretKey as SK}, - script, - script::{ExecutionStack, TariScript}, -}; use tari_p2p::{comms_connector::pubsub_connector, domain_message::DomainMessage, Network}; use tari_service_framework::{reply_channel, RegisterHandle, StackBuilder}; use tari_shutdown::{Shutdown, ShutdownSignal}; @@ -137,19 +147,12 @@ use tari_wallet::{ }, types::{HashDigest, ValidationRetryStrategy}, }; -use tempfile::tempdir; -use tokio::{ - runtime, - runtime::{Builder, Runtime}, - sync::{broadcast, broadcast::channel}, - time::delay_for, -}; +use tokio::time::sleep; fn create_runtime() -> Runtime { - Builder::new() - .threaded_scheduler() + Builder::new_multi_thread() .enable_all() - .core_threads(8) + .worker_threads(8) .build() .unwrap() } @@ -172,7 +175,8 @@ pub fn setup_transaction_service< discovery_request_timeout: Duration, shutdown_signal: ShutdownSignal, ) -> (TransactionServiceHandle, OutputManagerHandle, CommsNode) { - let (publisher, subscription_factory) = pubsub_connector(runtime.handle().clone(), 100, 20); + let _enter = runtime.enter(); + let (publisher, subscription_factory) = pubsub_connector(100, 20); let subscription_factory = Arc::new(subscription_factory); let (comms, dht) = runtime.block_on(setup_comms_services( node_identity, @@ -303,11 +307,15 @@ pub fn setup_transaction_service_no_comms_and_oms_backend< let protocol_name = server.as_protocol_name(); let server_node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); - let mut mock_server = runtime - .handle() - .enter(|| MockRpcServer::new(server, server_node_identity.clone())); + let mut mock_server = { + let _enter = runtime.handle().enter(); + MockRpcServer::new(server, server_node_identity.clone()) + }; - runtime.handle().enter(|| mock_server.serve()); + { + let _enter = runtime.handle().enter(); + mock_server.serve(); + } let connection = runtime.block_on(async { mock_server @@ -504,9 +512,9 @@ fn manage_single_transaction() { .block_on(alice_ts.set_base_node_public_key(base_node_identity.public_key().clone())) .unwrap(); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); - runtime.block_on(async { delay_for(Duration::from_secs(2)).await }); + runtime.block_on(async { sleep(Duration::from_secs(2)).await }); let (mut bob_ts, mut bob_oms, bob_comms) = setup_transaction_service( &mut runtime, @@ -524,7 +532,7 @@ fn manage_single_transaction() { .block_on(bob_ts.set_base_node_public_key(base_node_identity.public_key().clone())) .unwrap(); - let mut bob_event_stream = bob_ts.get_event_stream_fused(); + let mut bob_event_stream = bob_ts.get_event_stream(); let _ = runtime.block_on( bob_comms @@ -556,18 +564,19 @@ fn manage_single_transaction() { .expect("Alice sending tx"); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(90)).fuse(); + let delay = sleep(Duration::from_secs(90)); + tokio::pin!(delay); let mut count = 0; loop { - futures::select! { - _event = alice_event_stream.select_next_some() => { + tokio::select! { + _event = alice_event_stream.recv() => { println!("alice: {:?}", &*_event.as_ref().unwrap()); count+=1; if count>=2 { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -576,18 +585,19 @@ fn manage_single_transaction() { let mut tx_id = 0u64; runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(90)).fuse(); + let delay = sleep(Duration::from_secs(90)); + tokio::pin!(delay); let mut finalized = 0; loop { - futures::select! { - event = bob_event_stream.select_next_some() => { + tokio::select! { + event = bob_event_stream.recv() => { println!("bob: {:?}", &*event.as_ref().unwrap()); if let TransactionEvent::ReceivedFinalizedTransaction(id) = &*event.unwrap() { tx_id = *id; finalized+=1; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -747,7 +757,7 @@ fn send_one_sided_transaction_to_other() { shutdown.to_signal(), ); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); runtime .block_on(alice_ts.set_base_node_public_key(base_node_identity.public_key().clone())) @@ -792,11 +802,12 @@ fn send_one_sided_transaction_to_other() { }); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let delay = sleep(Duration::from_secs(30)); + tokio::pin!(delay); let mut found = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionCompletedImmediately(id) = &*event.unwrap() { if id == &tx_id { found = true; @@ -804,7 +815,7 @@ fn send_one_sided_transaction_to_other() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -1071,9 +1082,9 @@ fn manage_multiple_transactions() { Duration::from_secs(60), shutdown.to_signal(), ); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); // Spin up Bob and Carol let (mut bob_ts, mut bob_oms, bob_comms) = setup_transaction_service( @@ -1088,8 +1099,8 @@ fn manage_multiple_transactions() { Duration::from_secs(1), shutdown.to_signal(), ); - let mut bob_event_stream = bob_ts.get_event_stream_fused(); - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + let mut bob_event_stream = bob_ts.get_event_stream(); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); let (mut carol_ts, mut carol_oms, carol_comms) = setup_transaction_service( &mut runtime, @@ -1103,18 +1114,18 @@ fn manage_multiple_transactions() { Duration::from_secs(1), shutdown.to_signal(), ); - let mut carol_event_stream = carol_ts.get_event_stream_fused(); + let mut carol_event_stream = carol_ts.get_event_stream(); // Establish some connections beforehand, to reduce the amount of work done concurrently in tests // Connect Bob and Alice - runtime.block_on(async { delay_for(Duration::from_secs(3)).await }); + runtime.block_on(async { sleep(Duration::from_secs(3)).await }); let _ = runtime.block_on( bob_comms .connectivity() .dial_peer(alice_node_identity.node_id().clone()), ); - runtime.block_on(async { delay_for(Duration::from_secs(3)).await }); + runtime.block_on(async { sleep(Duration::from_secs(3)).await }); // Connect alice to carol let _ = runtime.block_on( @@ -1182,12 +1193,13 @@ fn manage_multiple_transactions() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(90)).fuse(); + let delay = sleep(Duration::from_secs(90)); + tokio::pin!(delay); let mut tx_reply = 0; let mut finalized = 0; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { match &*event.unwrap() { TransactionEvent::ReceivedTransactionReply(_) => tx_reply+=1, TransactionEvent::ReceivedFinalizedTransaction(_) => finalized+=1, @@ -1198,7 +1210,7 @@ fn manage_multiple_transactions() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -1210,12 +1222,14 @@ fn manage_multiple_transactions() { log::trace!("Alice received all Tx messages"); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(90)).fuse(); + let delay = sleep(Duration::from_secs(90)); + + tokio::pin!(delay); let mut tx_reply = 0; let mut finalized = 0; loop { - futures::select! { - event = bob_event_stream.select_next_some() => { + tokio::select! { + event = bob_event_stream.recv() => { match &*event.unwrap() { TransactionEvent::ReceivedTransactionReply(_) => tx_reply+=1, TransactionEvent::ReceivedFinalizedTransaction(_) => finalized+=1, @@ -1225,7 +1239,7 @@ fn manage_multiple_transactions() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -1235,14 +1249,17 @@ fn manage_multiple_transactions() { }); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(90)).fuse(); + let delay = sleep(Duration::from_secs(90)); + tokio::pin!(delay); + + tokio::pin!(delay); let mut finalized = 0; loop { - futures::select! { - event = carol_event_stream.select_next_some() => { + tokio::select! { + event = carol_event_stream.recv() => { if let TransactionEvent::ReceivedFinalizedTransaction(_) = &*event.unwrap() { finalized+=1 } }, - () = delay => { + () = &mut delay => { break; }, } @@ -1264,7 +1281,7 @@ fn manage_multiple_transactions() { assert_eq!(carol_pending_inbound.len(), 0); assert_eq!(carol_completed_tx.len(), 1); - shutdown.trigger().unwrap(); + shutdown.trigger(); runtime.block_on(async move { alice_comms.wait_until_shutdown().await; bob_comms.wait_until_shutdown().await; @@ -1303,7 +1320,7 @@ fn test_accepting_unknown_tx_id_and_malformed_reply() { _, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), alice_backend, oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); let (_utxo, uo) = make_input(&mut OsRng, MicroTari(250000), &factories.commitment); @@ -1355,11 +1372,14 @@ fn test_accepting_unknown_tx_id_and_malformed_reply() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let delay = sleep(Duration::from_secs(30)); +tokio::pin!(delay); + + tokio::pin!(delay); let mut errors = 0; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { log::error!("ERROR: {:?}", event); if let TransactionEvent::Error(s) = &*event.unwrap() { if s == &"TransactionProtocolError(TransactionBuildError(InvalidSignatureError(\"Verifying kernel signature\")))".to_string() @@ -1371,7 +1391,7 @@ fn test_accepting_unknown_tx_id_and_malformed_reply() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -1415,7 +1435,7 @@ fn finalize_tx_with_incorrect_pubkey() { _, _, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), alice_backend, alice_oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); let bob_node_identity = NodeIdentity::random(&mut OsRng, get_next_memory_address(), PeerFeatures::COMMUNICATION_NODE); @@ -1488,15 +1508,18 @@ fn finalize_tx_with_incorrect_pubkey() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(15)).fuse(); + let delay = sleep(Duration::from_secs(15)); + tokio::pin!(delay); + + tokio::pin!(delay); loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedFinalizedTransaction(_) = (*event.unwrap()).clone() { panic!("Should not have received finalized event!"); } }, - () = delay => { + () = &mut delay => { break; }, } @@ -1542,7 +1565,7 @@ fn finalize_tx_with_missing_output() { _, _, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), alice_backend, alice_oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); let bob_node_identity = NodeIdentity::random(&mut OsRng, get_next_memory_address(), PeerFeatures::COMMUNICATION_NODE); @@ -1623,15 +1646,18 @@ fn finalize_tx_with_missing_output() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(15)).fuse(); + let delay = sleep(Duration::from_secs(15)); + tokio::pin!(delay); + + tokio::pin!(delay); loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedFinalizedTransaction(_) = (*event.unwrap()).clone() { panic!("Should not have received finalized event"); } }, - () = delay => { + () = &mut delay => { break; }, } @@ -1648,8 +1674,7 @@ fn discovery_async_return_test() { let db_tempdir = tempdir().unwrap(); let db_folder = db_tempdir.path(); - let mut runtime = runtime::Builder::new() - .basic_scheduler() + let mut runtime = runtime::Builder::new_current_thread() .enable_time() .thread_name("discovery_async_return_test") .build() @@ -1714,7 +1739,7 @@ fn discovery_async_return_test() { Duration::from_secs(20), shutdown.to_signal(), ); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); let (_utxo, uo1a) = make_input(&mut OsRng, MicroTari(5500), &factories.commitment); runtime.block_on(alice_oms.add_output(uo1a)).unwrap(); @@ -1741,17 +1766,20 @@ fn discovery_async_return_test() { let mut txid = 0; let mut is_success = true; runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); + + tokio::pin!(delay); loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionDirectSendResult(tx_id, result) = (*event.unwrap()).clone() { txid = tx_id; is_success = result; break; } }, - () = delay => { + () = &mut delay => { panic!("Timeout while waiting for transaction to fail sending"); }, } @@ -1772,18 +1800,21 @@ fn discovery_async_return_test() { let mut success_result = false; let mut success_tx_id = 0u64; runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); + + tokio::pin!(delay); loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionDirectSendResult(tx_id, success) = &*event.unwrap() { success_result = *success; success_tx_id = *tx_id; break; } }, - () = delay => { + () = &mut delay => { panic!("Timeout while waiting for transaction to successfully be sent"); }, } @@ -1794,24 +1825,26 @@ fn discovery_async_return_test() { assert!(success_result); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); + tokio::pin!(delay); loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedTransactionReply(tx_id) = &*event.unwrap() { if tx_id == &tx_id2 { break; } } }, - () = delay => { + () = &mut delay => { panic!("Timeout while Alice was waiting for a transaction reply"); }, } } }); - shutdown.trigger().unwrap(); + shutdown.trigger(); runtime.block_on(async move { alice_comms.wait_until_shutdown().await; carol_comms.wait_until_shutdown().await; @@ -2012,7 +2045,7 @@ fn test_transaction_cancellation() { ..Default::default() }), ); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); let alice_total_available = 250000 * uT; let (_utxo, uo) = make_input(&mut OsRng, alice_total_available, &factories.commitment); @@ -2030,15 +2063,17 @@ fn test_transaction_cancellation() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); + tokio::pin!(delay); loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionStoreForwardSendResult(_,_) = &*event.unwrap() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2054,7 +2089,7 @@ fn test_transaction_cancellation() { None => (), Some(_) => break, } - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); if i >= 12 { panic!("Pending outbound transaction should have been added by now"); } @@ -2066,17 +2101,18 @@ fn test_transaction_cancellation() { // Wait for cancellation event, in an effort to nail down where the issue is for the flakey CI test runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut cancelled = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionCancelled(_) = &*event.unwrap() { cancelled = true; break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2143,15 +2179,16 @@ fn test_transaction_cancellation() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedTransaction(_) = &*event.unwrap() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2213,15 +2250,16 @@ fn test_transaction_cancellation() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedTransaction(_) = &*event.unwrap() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2243,7 +2281,7 @@ fn test_transaction_cancellation() { ))) .unwrap(); - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); runtime .block_on(alice_ts.get_pending_inbound_transactions()) @@ -2257,17 +2295,18 @@ fn test_transaction_cancellation() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let delay = sleep(Duration::from_secs(30)).fuse(); + tokio::pin!(delay); let mut cancelled = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionCancelled(_) = &*event.unwrap() { cancelled = true; break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2386,7 +2425,7 @@ fn test_direct_vs_saf_send_of_tx_reply_and_finalize() { .try_into() .unwrap(); - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); assert_eq!(bob_outbound_service.call_count(), 0, "Should be no more calls"); let (_wallet_backend, backend, oms_backend, _, _temp_dir) = make_wallet_databases(None); @@ -2428,7 +2467,7 @@ fn test_direct_vs_saf_send_of_tx_reply_and_finalize() { .try_into() .unwrap(); - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); assert_eq!(bob2_outbound_service.call_count(), 0, "Should be no more calls"); // Test finalize is sent Direct Only. @@ -2449,7 +2488,7 @@ fn test_direct_vs_saf_send_of_tx_reply_and_finalize() { let _ = alice_outbound_service.pop_call().unwrap(); let _ = alice_outbound_service.pop_call().unwrap(); - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); assert_eq!(alice_outbound_service.call_count(), 0, "Should be no more calls"); // Now to repeat sending so we can test the SAF send of the finalize message @@ -2520,7 +2559,7 @@ fn test_direct_vs_saf_send_of_tx_reply_and_finalize() { assert_eq!(alice_outbound_service.call_count(), 1); let _ = alice_outbound_service.pop_call(); - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); assert_eq!(alice_outbound_service.call_count(), 0, "Should be no more calls2"); } @@ -2548,7 +2587,7 @@ fn test_tx_direct_send_behaviour() { _, _, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), backend, oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); let (_utxo, uo) = make_input(&mut OsRng, 1000000 * uT, &factories.commitment); runtime.block_on(alice_output_manager.add_output(uo)).unwrap(); @@ -2576,12 +2615,13 @@ fn test_tx_direct_send_behaviour() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); +tokio::pin!(delay); let mut direct_count = 0; let mut saf_count = 0; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { match &*event.unwrap() { TransactionEvent::TransactionDirectSendResult(_, result) => if !result { direct_count+=1 }, TransactionEvent::TransactionStoreForwardSendResult(_, result) => if !result { saf_count+=1}, _ => (), @@ -2591,7 +2631,7 @@ fn test_tx_direct_send_behaviour() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2619,12 +2659,13 @@ fn test_tx_direct_send_behaviour() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); +tokio::pin!(delay); let mut direct_count = 0; let mut saf_count = 0; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { match &*event.unwrap() { TransactionEvent::TransactionDirectSendResult(_, result) => if !result { direct_count+=1 }, TransactionEvent::TransactionStoreForwardSendResult(_, result) => if *result { saf_count+=1 @@ -2635,7 +2676,7 @@ fn test_tx_direct_send_behaviour() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2663,11 +2704,12 @@ fn test_tx_direct_send_behaviour() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); +tokio::pin!(delay); let mut direct_count = 0; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { match &*event.unwrap() { TransactionEvent::TransactionDirectSendResult(_, result) => if *result { direct_count+=1 }, TransactionEvent::TransactionStoreForwardSendResult(_, _) => panic!("Should be no SAF messages"), @@ -2678,7 +2720,7 @@ fn test_tx_direct_send_behaviour() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2705,11 +2747,12 @@ fn test_tx_direct_send_behaviour() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); +tokio::pin!(delay); let mut saf_count = 0; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { match &*event.unwrap() { TransactionEvent::TransactionStoreForwardSendResult(_, result) => if *result { saf_count+=1 }, TransactionEvent::TransactionDirectSendResult(_, result) => if *result { panic!( @@ -2720,7 +2763,7 @@ fn test_tx_direct_send_behaviour() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2852,7 +2895,7 @@ fn test_restarting_transaction_protocols() { // Test that Bob's node restarts the send protocol let (mut bob_ts, _bob_oms, _bob_outbound_service, _, _, mut bob_tx_reply, _, _, _, _shutdown, _, _, _) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), bob_backend, bob_oms_backend, None); - let mut bob_event_stream = bob_ts.get_event_stream_fused(); + let mut bob_event_stream = bob_ts.get_event_stream(); runtime .block_on(bob_ts.set_base_node_public_key(base_node_identity.public_key().clone())) @@ -2864,18 +2907,19 @@ fn test_restarting_transaction_protocols() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(15)).fuse(); + let delay = sleep(Duration::from_secs(15)); + tokio::pin!(delay); let mut received_reply = false; loop { - futures::select! { - event = bob_event_stream.select_next_some() => { + tokio::select! { + event = bob_event_stream.recv() => { if let TransactionEvent::ReceivedTransactionReply(id) = (*event.unwrap()).clone() { assert_eq!(id, tx_id); received_reply = true; break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -2886,7 +2930,7 @@ fn test_restarting_transaction_protocols() { // Test Alice's node restarts the receive protocol let (mut alice_ts, _alice_oms, _alice_outbound_service, _, _, _, mut alice_tx_finalized, _, _, _shutdown, _, _, _) = setup_transaction_service_no_comms(&mut runtime, factories, alice_backend, alice_oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); runtime .block_on(alice_ts.set_base_node_public_key(base_node_identity.public_key().clone())) @@ -2906,18 +2950,19 @@ fn test_restarting_transaction_protocols() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(15)).fuse(); + let delay = sleep(Duration::from_secs(15)); + tokio::pin!(delay); let mut received_finalized = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedFinalizedTransaction(id) = (*event.unwrap()).clone() { assert_eq!(id, tx_id); received_finalized = true; break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -3046,7 +3091,7 @@ fn test_coinbase_monitoring_stuck_in_mempool() { server_node_identity, mut rpc_service_state, ) = setup_transaction_service_no_comms(&mut runtime, factories, backend, oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); rpc_service_state.set_response_delay(Some(Duration::from_secs(1))); let block_height_a = 10; @@ -3131,11 +3176,12 @@ fn test_coinbase_monitoring_stuck_in_mempool() { println!(" {}", e) } runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let delay = sleep(Duration::from_secs(30)); + tokio::pin!(delay); let mut count = 0usize; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedFinalizedTransaction(tx_id) = &*event.unwrap() { if tx_id == &tx_id1 || tx_id == &tx_id2 { count += 1; @@ -3145,7 +3191,7 @@ fn test_coinbase_monitoring_stuck_in_mempool() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -3170,11 +3216,12 @@ fn test_coinbase_monitoring_stuck_in_mempool() { println!(" {}", e) } runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let delay = sleep(Duration::from_secs(30)); + tokio::pin!(delay); let mut count = 0usize; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionCancelled(tx_id) = &*event.unwrap() { if tx_id == &tx_id1 || tx_id == &tx_id2 { count += 1; @@ -3184,7 +3231,7 @@ fn test_coinbase_monitoring_stuck_in_mempool() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -3215,7 +3262,7 @@ fn test_coinbase_monitoring_with_base_node_change_and_mined() { server_node_identity, mut rpc_service_state, ) = setup_transaction_service_no_comms(&mut runtime, factories, backend, oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); rpc_service_state.set_response_delay(Some(Duration::from_secs(1))); let block_height_a = 10; @@ -3301,11 +3348,12 @@ fn test_coinbase_monitoring_with_base_node_change_and_mined() { println!(" {}", e) } runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let delay = sleep(Duration::from_secs(30)); + tokio::pin!(delay); let mut count = 0usize; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionMinedUnconfirmed(tx_id, _) = &*event.unwrap() { if tx_id == &tx_id1 || tx_id == &tx_id2 { count += 1; @@ -3316,7 +3364,7 @@ fn test_coinbase_monitoring_with_base_node_change_and_mined() { } }, - () = delay => { + () = &mut delay => { break; }, } @@ -3331,10 +3379,14 @@ fn test_coinbase_monitoring_with_base_node_change_and_mined() { rpc_service_state.set_response_delay(Some(Duration::from_secs(1))); let new_server = BaseNodeWalletRpcServer::new(service); let protocol_name = new_server.as_protocol_name(); - let mut new_mock_server = runtime - .handle() - .enter(|| MockRpcServer::new(new_server, new_server_node_identity.clone())); - runtime.handle().enter(|| new_mock_server.serve()); + let mut new_mock_server = { + let _enter = runtime.handle().enter(); + MockRpcServer::new(new_server, new_server_node_identity.clone()) + }; + { + let _enter = runtime.handle().enter(); + new_mock_server.serve(); + } let connection = runtime.block_on(new_mock_server.create_connection(new_server_node_identity.to_peer(), protocol_name.into())); runtime.block_on(connectivity_mock_state.add_active_connection(connection)); @@ -3368,11 +3420,12 @@ fn test_coinbase_monitoring_with_base_node_change_and_mined() { println!(" {}", e) } runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let delay = sleep(Duration::from_secs(30)); + tokio::pin!(delay); let mut count = 0usize; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionMined(tx_id) = &*event.unwrap() { if tx_id == &tx_id1 || tx_id == &tx_id2 { count += 1; @@ -3382,7 +3435,7 @@ fn test_coinbase_monitoring_with_base_node_change_and_mined() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -3413,7 +3466,7 @@ fn test_coinbase_monitoring_mined_not_synced() { server_node_identity, mut rpc_service_state, ) = setup_transaction_service_no_comms(&mut runtime, factories, backend, oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); rpc_service_state.set_response_delay(Some(Duration::from_secs(1))); let block_height_a = 10; @@ -3499,11 +3552,12 @@ fn test_coinbase_monitoring_mined_not_synced() { println!(" {}", e) } runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let delay = sleep(Duration::from_secs(30)); + tokio::pin!(delay); let mut count = 0usize; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedFinalizedTransaction(tx_id) = &*event.unwrap() { if tx_id == &tx_id1 || tx_id == &tx_id2 { count += 1; @@ -3513,7 +3567,7 @@ fn test_coinbase_monitoring_mined_not_synced() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -3538,11 +3592,12 @@ fn test_coinbase_monitoring_mined_not_synced() { println!(" {}", e) } runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let delay = sleep(Duration::from_secs(30)); + tokio::pin!(delay); let mut count = 0usize; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionMined(tx_id) = &*event.unwrap() { if tx_id == &tx_id1 || tx_id == &tx_id2 { count += 1; @@ -3552,7 +3607,7 @@ fn test_coinbase_monitoring_mined_not_synced() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -3760,7 +3815,7 @@ fn test_transaction_resending() { assert_eq!(bob_reply_message.tx_id, tx_id); } - runtime.block_on(async { delay_for(Duration::from_secs(2)).await }); + runtime.block_on(async { sleep(Duration::from_secs(2)).await }); // See if sending a second message too soon is ignored runtime .block_on(bob_tx_sender.send(create_dummy_message( @@ -3772,7 +3827,7 @@ fn test_transaction_resending() { assert!(bob_outbound_service.wait_call_count(1, Duration::from_secs(2)).is_err()); // Wait for the cooldown to expire but before the resend period has elapsed see if a repeat illicts a reponse. - runtime.block_on(async { delay_for(Duration::from_secs(2)).await }); + runtime.block_on(async { sleep(Duration::from_secs(2)).await }); runtime .block_on(bob_tx_sender.send(create_dummy_message( alice_sender_message.into(), @@ -3819,7 +3874,7 @@ fn test_transaction_resending() { .is_err()); // Wait for the cooldown to expire but before the resend period has elapsed see if a repeat illicts a reponse. - runtime.block_on(async { delay_for(Duration::from_secs(2)).await }); + runtime.block_on(async { sleep(Duration::from_secs(2)).await }); runtime .block_on(alice_tx_reply_sender.send(create_dummy_message( @@ -4143,7 +4198,7 @@ fn test_replying_to_cancelled_tx() { assert_eq!(data.tx_id, tx_id); } // Need a moment for Alice's wallet to finish writing to its database before cancelling - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); runtime.block_on(alice_ts.cancel_transaction(tx_id)).unwrap(); @@ -4193,7 +4248,7 @@ fn test_replying_to_cancelled_tx() { assert_eq!(bob_reply_message.tx_id, tx_id); // Wait for cooldown to expire - runtime.block_on(async { delay_for(Duration::from_secs(5)).await }); + runtime.block_on(async { sleep(Duration::from_secs(5)).await }); let _ = alice_outbound_service.take_calls(); @@ -4406,7 +4461,7 @@ fn test_transaction_timeout_cancellation() { ..Default::default() }), ); - let mut carol_event_stream = carol_ts.get_event_stream_fused(); + let mut carol_event_stream = carol_ts.get_event_stream(); runtime .block_on(carol_tx_sender.send(create_dummy_message( @@ -4431,11 +4486,12 @@ fn test_transaction_timeout_cancellation() { assert_eq!(carol_reply_message.tx_id, tx_id); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut transaction_cancelled = false; loop { - futures::select! { - event = carol_event_stream.select_next_some() => { + tokio::select! { + event = carol_event_stream.recv() => { if let TransactionEvent::TransactionCancelled(t) = &*event.unwrap() { if t == &tx_id { transaction_cancelled = true; @@ -4443,7 +4499,7 @@ fn test_transaction_timeout_cancellation() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -4481,7 +4537,7 @@ fn transaction_service_tx_broadcast() { server_node_identity, rpc_service_state, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), backend, oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); runtime .block_on(alice_ts.set_base_node_public_key(server_node_identity.public_key().clone())) @@ -4609,11 +4665,12 @@ fn transaction_service_tx_broadcast() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut tx1_received = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedTransactionReply(tx_id) = &*event.unwrap(){ if tx_id == &tx_id1 { tx1_received = true; @@ -4621,7 +4678,7 @@ fn transaction_service_tx_broadcast() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -4655,11 +4712,12 @@ fn transaction_service_tx_broadcast() { }); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut tx1_mined = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionMined(tx_id) = &*event.unwrap(){ if tx_id == &tx_id1 { tx1_mined = true; @@ -4667,7 +4725,7 @@ fn transaction_service_tx_broadcast() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -4683,11 +4741,12 @@ fn transaction_service_tx_broadcast() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut tx2_received = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedTransactionReply(tx_id) = &*event.unwrap(){ if tx_id == &tx_id2 { tx2_received = true; @@ -4695,7 +4754,7 @@ fn transaction_service_tx_broadcast() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -4733,11 +4792,12 @@ fn transaction_service_tx_broadcast() { }); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut tx2_cancelled = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionCancelled(tx_id) = &*event.unwrap(){ if tx_id == &tx_id2 { tx2_cancelled = true; @@ -4745,7 +4805,7 @@ fn transaction_service_tx_broadcast() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -4851,15 +4911,16 @@ fn broadcast_all_completed_transactions_on_startup() { assert!(runtime.block_on(alice_ts.restart_broadcast_protocols()).is_ok()); - let mut event_stream = alice_ts.get_event_stream_fused(); + let mut event_stream = alice_ts.get_event_stream(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut found1 = false; let mut found2 = false; let mut found3 = false; loop { - futures::select! { - event = event_stream.select_next_some() => { + tokio::select! { + event = event_stream.recv() => { if let TransactionEvent::TransactionBroadcast(tx_id) = (*event.unwrap()).clone() { if tx_id == 1u64 { found1 = true @@ -4876,7 +4937,7 @@ fn broadcast_all_completed_transactions_on_startup() { } }, - () = delay => { + () = &mut delay => { break; }, } @@ -4916,7 +4977,7 @@ fn transaction_service_tx_broadcast_with_base_node_change() { server_node_identity, rpc_service_state, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), backend, oms_backend, None); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream(); runtime .block_on(alice_ts.set_base_node_public_key(server_node_identity.public_key().clone())) @@ -4995,11 +5056,12 @@ fn transaction_service_tx_broadcast_with_base_node_change() { .unwrap(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut tx1_received = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::ReceivedTransactionReply(tx_id) = &*event.unwrap(){ if tx_id == &tx_id1 { tx1_received = true; @@ -5007,7 +5069,7 @@ fn transaction_service_tx_broadcast_with_base_node_change() { } } }, - () = delay => { + () = &mut delay => { break; }, } @@ -5038,11 +5100,15 @@ fn transaction_service_tx_broadcast_with_base_node_change() { let new_server = BaseNodeWalletRpcServer::new(service); let protocol_name = new_server.as_protocol_name(); - let mut new_mock_server = runtime - .handle() - .enter(|| MockRpcServer::new(new_server, new_server_node_identity.clone())); + let mut new_mock_server = { + let _enter = runtime.handle().enter(); + MockRpcServer::new(new_server, new_server_node_identity.clone()) + }; - runtime.handle().enter(|| new_mock_server.serve()); + { + let _enter = runtime.handle().enter(); + new_mock_server.serve(); + } let connection = runtime.block_on(new_mock_server.create_connection(new_server_node_identity.to_peer(), protocol_name.into())); @@ -5075,17 +5141,18 @@ fn transaction_service_tx_broadcast_with_base_node_change() { }); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut tx_mined = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => { + tokio::select! { + event = alice_event_stream.recv() => { if let TransactionEvent::TransactionMined(_) = &*event.unwrap(){ tx_mined = true; break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -5350,11 +5417,15 @@ fn start_validation_protocol_then_broadcast_protocol_change_base_node() { let new_server = BaseNodeWalletRpcServer::new(service); let protocol_name = new_server.as_protocol_name(); - let mut new_mock_server = runtime - .handle() - .enter(|| MockRpcServer::new(new_server, new_server_node_identity.clone())); + let mut new_mock_server = { + let _enter = runtime.handle().enter(); + MockRpcServer::new(new_server, new_server_node_identity.clone()) + }; - runtime.handle().enter(|| new_mock_server.serve()); + { + let _enter = runtime.handle().enter(); + new_mock_server.serve(); + } let connection = runtime.block_on(new_mock_server.create_connection(new_server_node_identity.to_peer(), protocol_name.into())); diff --git a/base_layer/wallet/tests/transaction_service/storage.rs b/base_layer/wallet/tests/transaction_service/storage.rs index 6ea420e6a4..5573bd63e5 100644 --- a/base_layer/wallet/tests/transaction_service/storage.rs +++ b/base_layer/wallet/tests/transaction_service/storage.rs @@ -26,20 +26,24 @@ use aes_gcm::{ }; use chrono::Utc; use rand::rngs::OsRng; +use tari_crypto::{ + keys::{PublicKey as PublicKeyTrait, SecretKey as SecretKeyTrait}, + script, + script::{ExecutionStack, TariScript}, +}; +use tempfile::tempdir; +use tokio::runtime::Runtime; + +use tari_common_types::types::{HashDigest, PrivateKey, PublicKey}; use tari_core::transactions::{ helpers::{create_unblinded_output, TestParams}, tari_amount::{uT, MicroTari}, transaction::{OutputFeatures, Transaction}, transaction_protocol::sender::TransactionSenderMessage, - types::{CryptoFactories, HashDigest, PrivateKey, PublicKey}, + CryptoFactories, ReceiverTransactionProtocol, SenderTransactionProtocol, }; -use tari_crypto::{ - keys::{PublicKey as PublicKeyTrait, SecretKey as SecretKeyTrait}, - script, - script::{ExecutionStack, TariScript}, -}; use tari_test_utils::random; use tari_wallet::{ storage::sqlite_utilities::run_migration_and_create_sqlite_connection, @@ -56,11 +60,8 @@ use tari_wallet::{ sqlite_db::TransactionServiceSqliteDatabase, }, }; -use tempfile::tempdir; -use tokio::runtime::Runtime; - pub fn test_db_backend(backend: T) { - let mut runtime = Runtime::new().unwrap(); + let runtime = Runtime::new().unwrap(); let mut db = TransactionDatabase::new(backend); let factories = CryptoFactories::default(); let input = create_unblinded_output( diff --git a/base_layer/wallet/tests/transaction_service/transaction_protocols.rs b/base_layer/wallet/tests/transaction_service/transaction_protocols.rs index 66079613ca..d3d8c16ca2 100644 --- a/base_layer/wallet/tests/transaction_service/transaction_protocols.rs +++ b/base_layer/wallet/tests/transaction_service/transaction_protocols.rs @@ -20,14 +20,9 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::support::{ - rpc::{BaseNodeWalletRpcMockService, BaseNodeWalletRpcMockState}, - utils::make_input, -}; use chrono::Utc; -use futures::{FutureExt, StreamExt}; +use futures::StreamExt; use rand::rngs::OsRng; -use std::{sync::Arc, thread::sleep, time::Duration}; use tari_comms::{ peer_manager::PeerFeatures, protocol::rpc::{mock::MockRpcServer, NamedProtocolService, RpcStatus}, @@ -48,7 +43,7 @@ use tari_core::{ transactions::{ helpers::schema_to_transaction, tari_amount::{uT, MicroTari, T}, - types::CryptoFactories, + CryptoFactories, }, txn_schema, }; @@ -80,7 +75,13 @@ use tari_wallet::{ types::ValidationRetryStrategy, }; use tempfile::{tempdir, TempDir}; -use tokio::{sync::broadcast, task, time::delay_for}; +use tokio::{sync::broadcast, task, time::sleep}; + +use crate::support::{ + rpc::{BaseNodeWalletRpcMockService, BaseNodeWalletRpcMockState}, + utils::make_input, +}; +use std::{sync::Arc, time::Duration}; // Just in case other options become apparent in later testing #[derive(PartialEq)] @@ -230,7 +231,7 @@ pub async fn oms_reply_channel_task( } /// A happy path test by submitting a transaction into the mempool, have it mined but unconfirmed and then confirmed. -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_submit_success_i() { let ( @@ -245,7 +246,7 @@ async fn tx_broadcast_protocol_submit_success_i() { _temp_dir, mut transaction_event_receiver, ) = setup(TxProtocolTestConfig::WithConnection).await; - let mut event_stream = resources.event_publisher.subscribe().fuse(); + let mut event_stream = resources.event_publisher.subscribe(); let (base_node_update_publisher, _) = broadcast::channel(20); let protocol = TransactionBroadcastProtocol::new( @@ -353,7 +354,8 @@ async fn tx_broadcast_protocol_submit_success_i() { .unwrap(); // lets wait for the transaction service event to notify us of a confirmed tx // We need to do this to ensure that the wallet db has been updated to "Mined" - while let Some(v) = transaction_event_receiver.next().await { + loop { + let v = transaction_event_receiver.recv().await; let event = v.unwrap(); match (*event).clone() { TransactionEvent::TransactionMined(_) => { @@ -392,13 +394,14 @@ async fn tx_broadcast_protocol_submit_success_i() { ); // Check that the appropriate events were emitted - let mut delay = delay_for(Duration::from_secs(5)).fuse(); + let delay = sleep(Duration::from_secs(5)); + tokio::pin!(delay); let mut broadcast = false; let mut unconfirmed = false; let mut confirmed = false; loop { - futures::select! { - event = event_stream.select_next_some() => { + tokio::select! { + event = event_stream.recv() => { match &*event.unwrap() { TransactionEvent::TransactionMinedUnconfirmed(_, confirmations) => if *confirmations == 1 { unconfirmed = true; @@ -412,7 +415,7 @@ async fn tx_broadcast_protocol_submit_success_i() { _ => (), } }, - () = delay => { + () = &mut delay => { break; }, } @@ -426,7 +429,7 @@ async fn tx_broadcast_protocol_submit_success_i() { } /// Test submitting a transaction that is immediately rejected -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_submit_rejection() { let ( @@ -441,7 +444,7 @@ async fn tx_broadcast_protocol_submit_rejection() { _temp_dir, _transaction_event_receiver, ) = setup(TxProtocolTestConfig::WithConnection).await; - let mut event_stream = resources.event_publisher.subscribe().fuse(); + let mut event_stream = resources.event_publisher.subscribe(); let (base_node_update_publisher, _) = broadcast::channel(20); add_transaction_to_database(1, 1 * T, true, None, resources.db.clone()).await; @@ -478,16 +481,17 @@ async fn tx_broadcast_protocol_submit_rejection() { assert!(db_completed_tx.is_err()); // Check that the appropriate events were emitted - let mut delay = delay_for(Duration::from_secs(1)).fuse(); + let delay = sleep(Duration::from_secs(1)); + tokio::pin!(delay); let mut cancelled = false; loop { - futures::select! { - event = event_stream.select_next_some() => { + tokio::select! { + event = event_stream.recv() => { if let TransactionEvent::TransactionCancelled(_) = &*event.unwrap() { cancelled = true; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -498,7 +502,7 @@ async fn tx_broadcast_protocol_submit_rejection() { /// Test restarting a protocol which means the first step is a query not a submission, detecting the Tx is not in the /// mempool, resubmit the tx and then have it mined -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_restart_protocol_as_query() { let ( @@ -585,7 +589,7 @@ async fn tx_broadcast_protocol_restart_protocol_as_query() { /// This test will submit a Tx which will be accepted and then dropped from the mempool, resulting in a resubmit which /// will be rejected and result in a cancelled transaction -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_submit_success_followed_by_rejection() { let ( @@ -600,7 +604,7 @@ async fn tx_broadcast_protocol_submit_success_followed_by_rejection() { _temp_dir, _transaction_event_receiver, ) = setup(TxProtocolTestConfig::WithConnection).await; - let mut event_stream = resources.event_publisher.subscribe().fuse(); + let mut event_stream = resources.event_publisher.subscribe(); let (base_node_update_publisher, _) = broadcast::channel(20); add_transaction_to_database(1, 1 * T, true, None, resources.db.clone()).await; @@ -666,16 +670,17 @@ async fn tx_broadcast_protocol_submit_success_followed_by_rejection() { assert!(db_completed_tx.is_err()); // Check that the appropriate events were emitted - let mut delay = delay_for(Duration::from_secs(1)).fuse(); + let delay = sleep(Duration::from_secs(1)); + tokio::pin!(delay); let mut cancelled = false; loop { - futures::select! { - event = event_stream.select_next_some() => { + tokio::select! { + event = event_stream.recv() => { if let TransactionEvent::TransactionCancelled(_) = &*event.unwrap() { cancelled = true; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -686,7 +691,7 @@ async fn tx_broadcast_protocol_submit_success_followed_by_rejection() { /// This test will submit a tx which is accepted and mined but unconfirmed, then the next query it will not exist /// resulting in a resubmission which we will let run to being mined with success -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_submit_mined_then_not_mined_resubmit_success() { let ( @@ -751,14 +756,14 @@ async fn tx_broadcast_protocol_submit_mined_then_not_mined_resubmit_success() { // Wait for the "TransactionMinedUnconfirmed" tx event to ensure that the wallet db state is "MinedUnconfirmed" let mut count = 0u16; - while let Some(v) = transaction_event_receiver.next().await { + loop { + let v = transaction_event_receiver.recv().await; let event = v.unwrap(); match (*event).clone() { TransactionEvent::TransactionMinedUnconfirmed(_, _) => { break; }, _ => { - sleep(Duration::from_millis(1000)); count += 1; if count >= 10 { break; @@ -806,7 +811,7 @@ async fn tx_broadcast_protocol_submit_mined_then_not_mined_resubmit_success() { } /// Test being unable to connect and then connection becoming available. -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_connection_problem() { let ( @@ -823,7 +828,7 @@ async fn tx_broadcast_protocol_connection_problem() { ) = setup(TxProtocolTestConfig::WithoutConnection).await; let (base_node_update_publisher, _) = broadcast::channel(20); - let mut event_stream = resources.event_publisher.subscribe().fuse(); + let mut event_stream = resources.event_publisher.subscribe(); add_transaction_to_database(1, 1 * T, true, None, resources.db.clone()).await; @@ -839,11 +844,12 @@ async fn tx_broadcast_protocol_connection_problem() { let join_handle = task::spawn(protocol.execute()); // Check that the connection problem event was emitted at least twice - let mut delay = delay_for(Duration::from_secs(10)).fuse(); + let delay = sleep(Duration::from_secs(10)); + tokio::pin!(delay); let mut connection_issues = 0; loop { - futures::select! { - event = event_stream.select_next_some() => { + tokio::select! { + event = event_stream.recv() => { if let TransactionEvent::TransactionBaseNodeConnectionProblem(_) = &*event.unwrap() { connection_issues+=1; } @@ -851,7 +857,7 @@ async fn tx_broadcast_protocol_connection_problem() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -878,7 +884,7 @@ async fn tx_broadcast_protocol_connection_problem() { } /// Submit a transaction that is Already Mined for the submission, the subsequent query should confirm the transaction -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_submit_already_mined() { let ( @@ -948,7 +954,7 @@ async fn tx_broadcast_protocol_submit_already_mined() { } /// A test to see that the broadcast protocol can handle a change to the base node address while it runs. -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_broadcast_protocol_submit_and_base_node_gets_changed() { let ( @@ -1050,7 +1056,7 @@ async fn tx_broadcast_protocol_submit_and_base_node_gets_changed() { /// Validate completed transactions, will check that valid ones stay valid and incorrectly marked invalid tx become /// valid. -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_validation_protocol_tx_becomes_valid() { let ( @@ -1148,7 +1154,7 @@ async fn tx_validation_protocol_tx_becomes_valid() { } /// Validate completed transaction, the transaction should become invalid -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_validation_protocol_tx_becomes_invalid() { let ( @@ -1213,7 +1219,7 @@ async fn tx_validation_protocol_tx_becomes_invalid() { } /// Validate completed transactions, the transaction should become invalid -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_validation_protocol_tx_becomes_unconfirmed() { let ( @@ -1285,7 +1291,7 @@ async fn tx_validation_protocol_tx_becomes_unconfirmed() { /// Test the validation protocol reacts correctly to a change in base node and redoes the full validation based on the /// new base node -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_validation_protocol_tx_ends_on_base_node_end() { let ( @@ -1302,7 +1308,7 @@ async fn tx_validation_protocol_tx_ends_on_base_node_end() { ) = setup(TxProtocolTestConfig::WithConnection).await; let (base_node_update_publisher, _) = broadcast::channel(20); let (_timeout_update_publisher, _) = broadcast::channel(20); - let mut event_stream = resources.event_publisher.subscribe().fuse(); + let mut event_stream = resources.event_publisher.subscribe(); add_transaction_to_database( 1, @@ -1398,16 +1404,17 @@ async fn tx_validation_protocol_tx_ends_on_base_node_end() { let result = join_handle.await.unwrap(); assert!(result.is_ok()); - let mut delay = delay_for(Duration::from_secs(1)).fuse(); + let delay = sleep(Duration::from_secs(1)); + tokio::pin!(delay); let mut aborted = false; loop { - futures::select! { - event = event_stream.select_next_some() => { + tokio::select! { + event = event_stream.recv() => { if let TransactionEvent::TransactionValidationAborted(_) = &*event.unwrap() { aborted = true; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -1416,7 +1423,7 @@ async fn tx_validation_protocol_tx_ends_on_base_node_end() { } /// Test the validation protocol reacts correctly when the RPC client returns an error between calls. -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_validation_protocol_rpc_client_broken_between_calls() { let ( @@ -1540,7 +1547,7 @@ async fn tx_validation_protocol_rpc_client_broken_between_calls() { /// Test the validation protocol reacts correctly when the RPC client returns an error between calls and only retry /// finite amount of times -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_validation_protocol_rpc_client_broken_finite_retries() { let ( @@ -1557,7 +1564,7 @@ async fn tx_validation_protocol_rpc_client_broken_finite_retries() { ) = setup(TxProtocolTestConfig::WithConnection).await; let (base_node_update_publisher, _) = broadcast::channel(20); let (_timeout_update_publisher, _) = broadcast::channel(20); - let mut event_stream = resources.event_publisher.subscribe().fuse(); + let mut event_stream = resources.event_publisher.subscribe(); add_transaction_to_database( 1, 1 * T, @@ -1610,12 +1617,13 @@ async fn tx_validation_protocol_rpc_client_broken_finite_retries() { assert!(result.is_err()); // Check that the connection problem event was emitted at least twice - let mut delay = delay_for(Duration::from_secs(10)).fuse(); + let delay = sleep(Duration::from_secs(10)); + tokio::pin!(delay); let mut timeouts = 0i32; let mut failures = 0i32; loop { - futures::select! { - event = event_stream.select_next_some() => { + tokio::select! { + event = event_stream.recv() => { log::error!("EVENT: {:?}", event); match &*event.unwrap() { TransactionEvent::TransactionValidationTimedOut(_) => { @@ -1630,7 +1638,7 @@ async fn tx_validation_protocol_rpc_client_broken_finite_retries() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -1641,7 +1649,7 @@ async fn tx_validation_protocol_rpc_client_broken_finite_retries() { /// Validate completed transactions, will check that valid ones stay valid and incorrectly marked invalid tx become /// valid. -#[tokio_macros::test] +#[tokio::test] #[allow(clippy::identity_op)] async fn tx_validation_protocol_base_node_not_synced() { let ( @@ -1658,7 +1666,7 @@ async fn tx_validation_protocol_base_node_not_synced() { ) = setup(TxProtocolTestConfig::WithConnection).await; let (base_node_update_publisher, _) = broadcast::channel(20); let (_timeout_update_publisher, _) = broadcast::channel(20); - let mut event_stream = resources.event_publisher.subscribe().fuse(); + let mut event_stream = resources.event_publisher.subscribe(); add_transaction_to_database( 1, @@ -1711,12 +1719,13 @@ async fn tx_validation_protocol_base_node_not_synced() { let result = join_handle.await.unwrap(); assert!(result.is_err()); - let mut delay = delay_for(Duration::from_secs(10)).fuse(); + let delay = sleep(Duration::from_secs(10)); + tokio::pin!(delay); let mut delayed = 0i32; let mut failures = 0i32; loop { - futures::select! { - event = event_stream.select_next_some() => { + tokio::select! { + event = event_stream.recv() => { match &*event.unwrap() { TransactionEvent::TransactionValidationDelayed(_) => { delayed +=1 ; @@ -1728,7 +1737,7 @@ async fn tx_validation_protocol_base_node_not_synced() { } }, - () = delay => { + () = &mut delay => { break; }, } diff --git a/base_layer/wallet/tests/wallet/mod.rs b/base_layer/wallet/tests/wallet/mod.rs index 6aa231b09c..c1320e5409 100644 --- a/base_layer/wallet/tests/wallet/mod.rs +++ b/base_layer/wallet/tests/wallet/mod.rs @@ -20,18 +20,27 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::support::{comms_and_services::get_next_memory_address, utils::make_input}; -use tari_core::transactions::transaction::OutputFeatures; +use std::{panic, path::Path, sync::Arc, time::Duration}; use aes_gcm::{ aead::{generic_array::GenericArray, NewAead}, Aes256Gcm, }; use digest::Digest; -use futures::{FutureExt, StreamExt}; use rand::rngs::OsRng; -use std::{panic, path::Path, sync::Arc, time::Duration}; -use tari_common_types::chain_metadata::ChainMetadata; +use tari_crypto::{ + common::Blake256, + inputs, + keys::{PublicKey as PublicKeyTrait, SecretKey}, + script, +}; +use tempfile::tempdir; +use tokio::runtime::Runtime; + +use tari_common_types::{ + chain_metadata::ChainMetadata, + types::{PrivateKey, PublicKey}, +}; use tari_comms::{ multiaddr::Multiaddr, peer_manager::{NodeId, NodeIdentity, Peer, PeerFeatures, PeerFlags}, @@ -41,13 +50,8 @@ use tari_comms_dht::DhtConfig; use tari_core::transactions::{ helpers::{create_unblinded_output, TestParams}, tari_amount::{uT, MicroTari}, - types::{CryptoFactories, PrivateKey, PublicKey}, -}; -use tari_crypto::{ - common::Blake256, - inputs, - keys::{PublicKey as PublicKeyTrait, SecretKey}, - script, + transaction::OutputFeatures, + CryptoFactories, }; use tari_p2p::{initialization::CommsConfig, transport::TransportType, Network, DEFAULT_DNS_NAME_SERVER}; use tari_shutdown::{Shutdown, ShutdownSignal}; @@ -70,8 +74,9 @@ use tari_wallet::{ WalletConfig, WalletSqlite, }; -use tempfile::tempdir; -use tokio::{runtime::Runtime, time::delay_for}; +use tokio::time::sleep; + +use crate::support::{comms_and_services::get_next_memory_address, utils::make_input}; fn create_peer(public_key: CommsPublicKey, net_address: Multiaddr) -> Peer { Peer::new( @@ -163,7 +168,7 @@ async fn create_wallet( .await } -#[tokio_macros::test] +#[tokio::test] async fn test_wallet() { let mut shutdown_a = Shutdown::new(); let mut shutdown_b = Shutdown::new(); @@ -227,7 +232,7 @@ async fn test_wallet() { .await .unwrap(); - let mut alice_event_stream = alice_wallet.transaction_service.get_event_stream_fused(); + let mut alice_event_stream = alice_wallet.transaction_service.get_event_stream(); let value = MicroTari::from(1000); let (_utxo, uo1) = make_input(&mut OsRng, MicroTari(2500), &factories.commitment); @@ -245,15 +250,16 @@ async fn test_wallet() { .await .unwrap(); - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut reply_count = false; loop { - futures::select! { - event = alice_event_stream.select_next_some() => if let TransactionEvent::ReceivedTransactionReply(_) = &*event.unwrap() { - reply_count = true; - break; - }, - () = delay => { + tokio::select! { + event = alice_event_stream.recv() => if let TransactionEvent::ReceivedTransactionReply(_) = &*event.unwrap() { + reply_count = true; + break; + }, + () = &mut delay => { break; }, } @@ -298,7 +304,7 @@ async fn test_wallet() { } drop(alice_event_stream); - shutdown_a.trigger().unwrap(); + shutdown_a.trigger(); alice_wallet.wait_until_shutdown().await; let connection = @@ -343,7 +349,7 @@ async fn test_wallet() { alice_wallet.remove_encryption().await.unwrap(); - shutdown_a.trigger().unwrap(); + shutdown_a.trigger(); alice_wallet.wait_until_shutdown().await; let connection = @@ -379,7 +385,7 @@ async fn test_wallet() { .await .unwrap(); - shutdown_a.trigger().unwrap(); + shutdown_a.trigger(); alice_wallet.wait_until_shutdown().await; partial_wallet_backup(current_wallet_path.clone(), backup_wallet_path.clone()) @@ -400,12 +406,12 @@ async fn test_wallet() { let master_secret_key = backup_wallet_db.get_master_secret_key().await.unwrap(); assert!(master_secret_key.is_none()); - shutdown_b.trigger().unwrap(); + shutdown_b.trigger(); bob_wallet.wait_until_shutdown().await; } -#[tokio_macros::test] +#[tokio::test] async fn test_do_not_overwrite_master_key() { let factories = CryptoFactories::default(); let dir = tempdir().unwrap(); @@ -423,7 +429,7 @@ async fn test_do_not_overwrite_master_key() { ) .await .unwrap(); - shutdown.trigger().unwrap(); + shutdown.trigger(); wallet.wait_until_shutdown().await; // try to use a new master key to create a wallet using the existing wallet database @@ -457,7 +463,7 @@ async fn test_do_not_overwrite_master_key() { .unwrap(); } -#[tokio_macros::test] +#[tokio::test] async fn test_sign_message() { let factories = CryptoFactories::default(); let dir = tempdir().unwrap(); @@ -516,9 +522,9 @@ fn test_store_and_forward_send_tx() { let bob_db_tempdir = tempdir().unwrap(); let carol_db_tempdir = tempdir().unwrap(); - let mut alice_runtime = Runtime::new().expect("Failed to initialize tokio runtime"); - let mut bob_runtime = Runtime::new().expect("Failed to initialize tokio runtime"); - let mut carol_runtime = Runtime::new().expect("Failed to initialize tokio runtime"); + let alice_runtime = Runtime::new().expect("Failed to initialize tokio runtime"); + let bob_runtime = Runtime::new().expect("Failed to initialize tokio runtime"); + let carol_runtime = Runtime::new().expect("Failed to initialize tokio runtime"); let mut alice_wallet = alice_runtime .block_on(create_wallet( @@ -554,7 +560,7 @@ fn test_store_and_forward_send_tx() { )) .unwrap(); let carol_identity = (*carol_wallet.comms.node_identity()).clone(); - shutdown_c.trigger().unwrap(); + shutdown_c.trigger(); carol_runtime.block_on(carol_wallet.wait_until_shutdown()); alice_runtime @@ -591,13 +597,13 @@ fn test_store_and_forward_send_tx() { .unwrap(); // Waiting here for a while to make sure the discovery retry is over - alice_runtime.block_on(async { delay_for(Duration::from_secs(60)).await }); + alice_runtime.block_on(async { sleep(Duration::from_secs(60)).await }); alice_runtime .block_on(alice_wallet.transaction_service.cancel_transaction(tx_id)) .unwrap(); - alice_runtime.block_on(async { delay_for(Duration::from_secs(60)).await }); + alice_runtime.block_on(async { sleep(Duration::from_secs(60)).await }); let carol_wallet = carol_runtime .block_on(create_wallet( @@ -610,7 +616,7 @@ fn test_store_and_forward_send_tx() { )) .unwrap(); - let mut carol_event_stream = carol_wallet.transaction_service.get_event_stream_fused(); + let mut carol_event_stream = carol_wallet.transaction_service.get_event_stream(); carol_runtime .block_on(carol_wallet.comms.peer_manager().add_peer(create_peer( @@ -623,13 +629,14 @@ fn test_store_and_forward_send_tx() { .unwrap(); carol_runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let delay = sleep(Duration::from_secs(60)); + tokio::pin!(delay); let mut tx_recv = false; let mut tx_cancelled = false; loop { - futures::select! { - event = carol_event_stream.select_next_some() => { + tokio::select! { + event = carol_event_stream.recv() => { match &*event.unwrap() { TransactionEvent::ReceivedTransaction(_) => tx_recv = true, TransactionEvent::TransactionCancelled(_) => tx_cancelled = true, @@ -639,7 +646,7 @@ fn test_store_and_forward_send_tx() { break; } }, - () = delay => { + () = &mut delay => { break; }, } @@ -647,15 +654,15 @@ fn test_store_and_forward_send_tx() { assert!(tx_recv, "Must have received a tx from alice"); assert!(tx_cancelled, "Must have received a cancel tx from alice"); }); - shutdown_a.trigger().unwrap(); - shutdown_b.trigger().unwrap(); - shutdown_c2.trigger().unwrap(); + shutdown_a.trigger(); + shutdown_b.trigger(); + shutdown_c2.trigger(); alice_runtime.block_on(alice_wallet.wait_until_shutdown()); bob_runtime.block_on(bob_wallet.wait_until_shutdown()); carol_runtime.block_on(carol_wallet.wait_until_shutdown()); } -#[tokio_macros::test] +#[tokio::test] async fn test_import_utxo() { let shutdown = Shutdown::new(); let factories = CryptoFactories::default(); diff --git a/base_layer/wallet_ffi/Cargo.toml b/base_layer/wallet_ffi/Cargo.toml index d7381c0d8e..b5d719981a 100644 --- a/base_layer/wallet_ffi/Cargo.toml +++ b/base_layer/wallet_ffi/Cargo.toml @@ -9,6 +9,7 @@ edition = "2018" [dependencies] tari_comms = { version = "^0.9", path = "../../comms", default-features = false} tari_comms_dht = { version = "^0.9", path = "../../comms/dht", default-features = false } +tari_common_types = {path="../common_types"} tari_crypto = "0.11.1" tari_key_manager = { version = "^0.9", path = "../key_manager" } tari_p2p = { version = "^0.9", path = "../p2p" } @@ -17,11 +18,11 @@ tari_shutdown = { version = "^0.9", path = "../../infrastructure/shutdown" } tari_utilities = "^0.3" futures = { version = "^0.3.1", features =["compat", "std"]} -tokio = "0.2.10" +tokio = "1.10.1" libc = "0.2.65" rand = "0.8" chrono = { version = "0.4.6", features = ["serde"]} -thiserror = "1.0.20" +thiserror = "1.0.26" log = "0.4.6" log4rs = {version = "1.0.0", features = ["console_appender", "file_appender", "yaml_format"]} @@ -41,4 +42,3 @@ env_logger = "0.7.1" tari_key_manager = { version = "^0.9", path = "../key_manager" } tari_common_types = { version = "^0.9", path = "../../base_layer/common_types"} tari_test_utils = { version = "^0.9", path = "../../infrastructure/test_utils"} -tokio = { version="0.2.10" } diff --git a/base_layer/wallet_ffi/src/callback_handler.rs b/base_layer/wallet_ffi/src/callback_handler.rs index 6d00799f23..c5af01c7b6 100644 --- a/base_layer/wallet_ffi/src/callback_handler.rs +++ b/base_layer/wallet_ffi/src/callback_handler.rs @@ -48,7 +48,6 @@ //! request_key is used to identify which request this callback references and a result of true means it was successful //! and false that the process timed out and new one will be started -use futures::{stream::Fuse, StreamExt}; use log::*; use tari_comms::types::CommsPublicKey; use tari_comms_dht::event::{DhtEvent, DhtEventReceiver}; @@ -96,9 +95,9 @@ where TBackend: TransactionBackend + 'static callback_transaction_validation_complete: unsafe extern "C" fn(u64, u8), callback_saf_messages_received: unsafe extern "C" fn(), db: TransactionDatabase, - transaction_service_event_stream: Fuse, - output_manager_service_event_stream: Fuse, - dht_event_stream: Fuse, + transaction_service_event_stream: TransactionEventReceiver, + output_manager_service_event_stream: OutputManagerEventReceiver, + dht_event_stream: DhtEventReceiver, shutdown_signal: Option, comms_public_key: CommsPublicKey, } @@ -109,9 +108,9 @@ where TBackend: TransactionBackend + 'static { pub fn new( db: TransactionDatabase, - transaction_service_event_stream: Fuse, - output_manager_service_event_stream: Fuse, - dht_event_stream: Fuse, + transaction_service_event_stream: TransactionEventReceiver, + output_manager_service_event_stream: OutputManagerEventReceiver, + dht_event_stream: DhtEventReceiver, shutdown_signal: ShutdownSignal, comms_public_key: CommsPublicKey, callback_received_transaction: unsafe extern "C" fn(*mut InboundTransaction), @@ -219,8 +218,8 @@ where TBackend: TransactionBackend + 'static info!(target: LOG_TARGET, "Transaction Service Callback Handler starting"); loop { - futures::select! { - result = self.transaction_service_event_stream.select_next_some() => { + tokio::select! { + result = self.transaction_service_event_stream.recv() => { match result { Ok(msg) => { trace!(target: LOG_TARGET, "Transaction Service Callback Handler event {:?}", msg); @@ -271,7 +270,7 @@ where TBackend: TransactionBackend + 'static Err(_e) => error!(target: LOG_TARGET, "Error reading from Transaction Service event broadcast channel"), } }, - result = self.output_manager_service_event_stream.select_next_some() => { + result = self.output_manager_service_event_stream.recv() => { match result { Ok(msg) => { trace!(target: LOG_TARGET, "Output Manager Service Callback Handler event {:?}", msg); @@ -295,7 +294,7 @@ where TBackend: TransactionBackend + 'static Err(_e) => error!(target: LOG_TARGET, "Error reading from Output Manager Service event broadcast channel"), } }, - result = self.dht_event_stream.select_next_some() => { + result = self.dht_event_stream.recv() => { match result { Ok(msg) => { trace!(target: LOG_TARGET, "DHT Callback Handler event {:?}", msg); @@ -306,11 +305,7 @@ where TBackend: TransactionBackend + 'static Err(_e) => error!(target: LOG_TARGET, "Error reading from DHT event broadcast channel"), } } - complete => { - info!(target: LOG_TARGET, "Callback Handler is exiting because all tasks have completed"); - break; - }, - _ = shutdown_signal => { + _ = shutdown_signal.wait() => { info!(target: LOG_TARGET, "Transaction Callback Handler shutting down because the shutdown signal was received"); break; }, @@ -585,18 +580,17 @@ where TBackend: TransactionBackend + 'static mod test { use crate::callback_handler::CallbackHandler; use chrono::Utc; - use futures::StreamExt; use rand::rngs::OsRng; use std::{ sync::{Arc, Mutex}, thread, time::Duration, }; + use tari_common_types::types::{BlindingFactor, PrivateKey, PublicKey}; use tari_comms_dht::event::DhtEvent; use tari_core::transactions::{ tari_amount::{uT, MicroTari}, transaction::Transaction, - types::{BlindingFactor, PrivateKey, PublicKey}, ReceiverTransactionProtocol, SenderTransactionProtocol, }; @@ -774,7 +768,7 @@ mod test { #[test] fn test_callback_handler() { - let mut runtime = Runtime::new().unwrap(); + let runtime = Runtime::new().unwrap(); let (_wallet_backend, backend, _oms_backend, _, _tempdir) = make_wallet_databases(None); let db = TransactionDatabase::new(backend); @@ -854,9 +848,9 @@ mod test { let shutdown_signal = Shutdown::new(); let callback_handler = CallbackHandler::new( db, - tx_receiver.fuse(), - oms_receiver.fuse(), - dht_receiver.fuse(), + tx_receiver, + oms_receiver, + dht_receiver, shutdown_signal.to_signal(), PublicKey::from_secret_key(&PrivateKey::random(&mut OsRng)), received_tx_callback, diff --git a/base_layer/wallet_ffi/src/lib.rs b/base_layer/wallet_ffi/src/lib.rs index d8b01d9b69..1f85e33d06 100644 --- a/base_layer/wallet_ffi/src/lib.rs +++ b/base_layer/wallet_ffi/src/lib.rs @@ -107,20 +107,18 @@ #[cfg(test)] #[macro_use] extern crate lazy_static; -mod callback_handler; -mod enums; -mod error; -mod tasks; -use crate::{ - callback_handler::CallbackHandler, - enums::SeedWordPushResult, - error::{InterfaceError, TransactionError}, - tasks::recovery_event_monitoring, -}; use core::ptr; -use error::LibWalletError; -use futures::StreamExt; +use std::{ + boxed::Box, + ffi::{CStr, CString}, + path::PathBuf, + slice, + str::FromStr, + sync::Arc, + time::Duration, +}; + use libc::{c_char, c_int, c_longlong, c_uchar, c_uint, c_ulonglong, c_ushort}; use log::{LevelFilter, *}; use log4rs::{ @@ -136,14 +134,19 @@ use log4rs::{ encode::pattern::PatternEncoder, }; use rand::rngs::OsRng; -use std::{ - boxed::Box, - ffi::{CStr, CString}, - path::PathBuf, - slice, - str::FromStr, - sync::Arc, - time::Duration, +use tari_crypto::{ + inputs, + keys::{PublicKey as PublicKeyTrait, SecretKey}, + script, + tari_utilities::ByteArray, +}; +use tari_utilities::{hex, hex::Hex}; +use tokio::runtime::Runtime; + +use error::LibWalletError; +use tari_common_types::{ + emoji::{emoji_set, EmojiId, EmojiIdError}, + types::{ComSignature, PublicKey}, }; use tari_comms::{ multiaddr::Multiaddr, @@ -154,23 +157,12 @@ use tari_comms::{ types::CommsSecretKey, }; use tari_comms_dht::{DbConnectionUrl, DhtConfig}; -use tari_core::transactions::{ - tari_amount::MicroTari, - transaction::OutputFeatures, - types::{ComSignature, CryptoFactories, PublicKey}, -}; -use tari_crypto::{ - inputs, - keys::{PublicKey as PublicKeyTrait, SecretKey}, - script, - tari_utilities::ByteArray, -}; +use tari_core::transactions::{tari_amount::MicroTari, transaction::OutputFeatures, CryptoFactories}; use tari_p2p::{ transport::{TorConfig, TransportType, TransportType::Tor}, Network, }; use tari_shutdown::Shutdown; -use tari_utilities::{hex, hex::Hex}; use tari_wallet::{ contacts_service::storage::database::Contact, error::{WalletError, WalletStorageError}, @@ -195,13 +187,23 @@ use tari_wallet::{ }, }, types::ValidationRetryStrategy, - util::emoji::{emoji_set, EmojiId, EmojiIdError}, utxo_scanner_service::utxo_scanning::{UtxoScannerService, RECOVERY_KEY}, Wallet, WalletConfig, WalletSqlite, }; -use tokio::runtime::Runtime; + +use crate::{ + callback_handler::CallbackHandler, + enums::SeedWordPushResult, + error::{InterfaceError, TransactionError}, + tasks::recovery_event_monitoring, +}; + +mod callback_handler; +mod enums; +mod error; +mod tasks; const LOG_TARGET: &str = "wallet_ffi"; @@ -209,7 +211,7 @@ pub type TariTransportType = tari_p2p::transport::TransportType; pub type TariPublicKey = tari_comms::types::CommsPublicKey; pub type TariPrivateKey = tari_comms::types::CommsSecretKey; pub type TariCommsConfig = tari_p2p::initialization::CommsConfig; -pub type TariExcess = tari_core::transactions::types::Commitment; +pub type TariExcess = tari_common_types::types::Commitment; pub type TariExcessPublicNonce = tari_crypto::ristretto::RistrettoPublicKey; pub type TariExcessSignature = tari_crypto::ristretto::RistrettoSecretKey; @@ -917,14 +919,14 @@ pub unsafe extern "C" fn seed_words_push_word( (*seed_words).0.push(word_string); if (*seed_words).0.len() >= 24 { - if let Err(e) = TariPrivateKey::from_mnemonic(&(*seed_words).0) { + return if let Err(e) = TariPrivateKey::from_mnemonic(&(*seed_words).0) { log::error!(target: LOG_TARGET, "Problem building private key from seed phrase"); error = LibWalletError::from(e).code; ptr::swap(error_out, &mut error as *mut c_int); - return SeedWordPushResult::InvalidSeedPhrase as u8; + SeedWordPushResult::InvalidSeedPhrase as u8 } else { - return SeedWordPushResult::SeedPhraseComplete as u8; - } + SeedWordPushResult::SeedPhraseComplete as u8 + }; } SeedWordPushResult::SuccessfulPush as u8 @@ -2858,7 +2860,7 @@ pub unsafe extern "C" fn wallet_create( match TariPrivateKey::from_mnemonic(&(*seed_words).0) { Ok(private_key) => Some(private_key), Err(e) => { - error!(target: LOG_TARGET, "Mnemonic Error for given seed words: {}", e); + error!(target: LOG_TARGET, "Mnemonic Error for given seed words: {:?}", e); error = LibWalletError::from(e).code; ptr::swap(error_out, &mut error as *mut c_int); return ptr::null_mut(); @@ -2866,7 +2868,7 @@ pub unsafe extern "C" fn wallet_create( } }; - let mut runtime = match Runtime::new() { + let runtime = match Runtime::new() { Ok(r) => r, Err(e) => { error = LibWalletError::from(InterfaceError::TokioError(e.to_string())).code; @@ -2947,15 +2949,15 @@ pub unsafe extern "C" fn wallet_create( // lets ensure the wallet tor_id is saved, this could have been changed during wallet startup if let Some(hs) = w.comms.hidden_service() { if let Err(e) = runtime.block_on(w.db.set_tor_identity(hs.tor_identity().clone())) { - warn!(target: LOG_TARGET, "Could not save tor identity to db: {}", e); + warn!(target: LOG_TARGET, "Could not save tor identity to db: {:?}", e); } } // Start Callback Handler let callback_handler = CallbackHandler::new( TransactionDatabase::new(transaction_backend), - w.transaction_service.get_event_stream_fused(), - w.output_manager_service.get_event_stream_fused(), - w.dht_service.subscribe_dht_events().fuse(), + w.transaction_service.get_event_stream(), + w.output_manager_service.get_event_stream(), + w.dht_service.subscribe_dht_events(), w.comms.shutdown_signal(), w.comms.node_identity().public_key().clone(), callback_received_transaction, @@ -5154,7 +5156,7 @@ pub unsafe extern "C" fn file_partial_backup( let runtime = Runtime::new(); match runtime { - Ok(mut runtime) => match runtime.block_on(partial_wallet_backup(original_path, backup_path)) { + Ok(runtime) => match runtime.block_on(partial_wallet_backup(original_path, backup_path)) { Ok(_) => (), Err(e) => { error = LibWalletError::from(WalletError::WalletStorageError(e)).code; @@ -5281,10 +5283,8 @@ pub unsafe extern "C" fn emoji_set_destroy(emoji_set: *mut EmojiSet) { pub unsafe extern "C" fn wallet_destroy(wallet: *mut TariWallet) { if !wallet.is_null() { let mut w = Box::from_raw(wallet); - match w.shutdown.trigger() { - Err(_) => error!(target: LOG_TARGET, "No listeners for the shutdown signal!"), - Ok(()) => w.runtime.block_on(w.wallet.wait_until_shutdown()), - } + w.shutdown.trigger(); + w.runtime.block_on(w.wallet.wait_until_shutdown()); } } @@ -5306,21 +5306,24 @@ pub unsafe extern "C" fn log_debug_message(msg: *const c_char) { #[cfg(test)] mod test { - use crate::*; - use libc::{c_char, c_uchar, c_uint}; use std::{ ffi::CString, path::Path, str::{from_utf8, FromStr}, sync::Mutex, }; + + use libc::{c_char, c_uchar, c_uint}; + use tempfile::tempdir; + + use tari_common_types::emoji; use tari_test_utils::random; use tari_wallet::{ storage::sqlite_utilities::run_migration_and_create_sqlite_connection, transaction_service::storage::models::TransactionStatus, - util::emoji, }; - use tempfile::tempdir; + + use crate::*; fn type_of(_: T) -> String { std::any::type_name::().to_string() @@ -5781,7 +5784,7 @@ mod test { error_ptr, ); - let mut runtime = Runtime::new().unwrap(); + let runtime = Runtime::new().unwrap(); let connection = run_migration_and_create_sqlite_connection(&sql_database_path).expect("Could not open Sqlite db"); diff --git a/base_layer/wallet_ffi/src/tasks.rs b/base_layer/wallet_ffi/src/tasks.rs index 9c44c94106..9e67eaa091 100644 --- a/base_layer/wallet_ffi/src/tasks.rs +++ b/base_layer/wallet_ffi/src/tasks.rs @@ -20,7 +20,6 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::StreamExt; use log::*; use tari_crypto::tari_utilities::hex::Hex; use tari_wallet::{error::WalletError, utxo_scanner_service::handle::UtxoScannerEvent}; @@ -44,8 +43,8 @@ pub async fn recovery_event_monitoring( recovery_join_handle: JoinHandle>, recovery_progress_callback: unsafe extern "C" fn(u8, u64, u64), ) { - while let Some(event) = event_stream.next().await { - match event { + loop { + match event_stream.recv().await { Ok(UtxoScannerEvent::ConnectingToBaseNode(peer)) => { unsafe { (recovery_progress_callback)(RecoveryEvent::ConnectingToBaseNode as u8, 0u64, 0u64); @@ -139,6 +138,9 @@ pub async fn recovery_event_monitoring( } warn!(target: LOG_TARGET, "UTXO Scanner failed and exited",); }, + Err(broadcast::error::RecvError::Closed) => { + break; + }, Err(e) => { // Event lagging warn!(target: LOG_TARGET, "{}", e); diff --git a/common/Cargo.toml b/common/Cargo.toml index aa9d646654..5998d0bdfd 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -22,7 +22,7 @@ dirs-next = "1.0.2" get_if_addrs = "0.5.3" log = "0.4.8" log4rs = { version = "1.0.0", default_features= false, features = ["config_parsing", "threshold_filter"]} -multiaddr={package="parity-multiaddr", version = "0.11.0"} +multiaddr={version = "0.13.0"} sha2 = "0.9.5" path-clean = "0.1.0" tari_storage = { version = "^0.9", path = "../infrastructure/storage"} @@ -36,7 +36,7 @@ opentelemetry-jaeger = { version="0.15", features=["rt-tokio"]} anyhow = { version = "1.0", optional = true } git2 = { version = "0.8", optional = true } -prost-build = { version = "0.6.1", optional = true } +prost-build = { version = "0.8.0", optional = true } toml = { version = "0.5", optional = true } [dev-dependencies] diff --git a/common/config/presets/tari_igor_config.toml b/common/config/presets/tari_igor_config.toml new file mode 100644 index 0000000000..efb6eedf47 --- /dev/null +++ b/common/config/presets/tari_igor_config.toml @@ -0,0 +1,535 @@ +######################################################################################################################## +# # +# The Tari Network Configuration File # +# # +######################################################################################################################## + +# This file carries all the configuration options for running Tari-related nodes and infrastructure in one single +# file. As you'll notice, almost all configuraton options are commented out. This is because they are either not +# needed, are for advanced users that know what they want to tweak, or are already set at their default values. If +# things are working fine, then there's no need to change anything here. +# +# Each major section is clearly marked so that you can quickly find the section you're looking for. This first +# section holds configuration options that are common to all sections. + +# A note about Logging - The logger is initialised before the configuration file is loaded. For this reason, logging +# is not configured here, but in `~/.tari/log4rs.yml` (*nix / OsX) or `%USERPROFILE%\.tari\log4rs.yml` (Windows) by +# default, or the location specified in the TARI_LOGFILE environment variable. + +[common] +# Select the network to connect to. Valid options are: +# mainnet - the "real" Tari network (default) +# igor - the Second Tari test net +network = "igor" + +# Tari is a 100% peer-to-peer network, so there are no servers to hold messages for you while you're offline. +# Instead, we rely on our peers to hold messages for us while we're offline. This settings sets maximum size of the +# message cache that for holding our peers' messages, in MB. +#message_cache_size = 10 + +# When storing messages for peers, hold onto them for at most this long before discarding them. The default is 1440 +# minutes = or 24 hrs. +#message_cache_ttl = 1440 + +# If peer nodes spam you with messages, or are otherwise badly behaved, they will be added to your denylist and banned +# You can set a time limit to release that ban (in minutes), or otherwise ban them for life (-1). The default is to +# ban them for 10 days. +#denylist_ban_period = 1440 + +# The number of liveness sessions to allow. Liveness sessions can be established by liveness monitors over TCP by +# sending a 0x50 (P) as the first byte. Any messages sent must be followed by newline message no longer than +# 50 characters. That message will be echoed back. +#liveness_max_sessions = 0 +#liveness_allowlist_cidrs = ["127.0.0.1/32"] + +# The buffer size constants for the publish/subscribe connector channel, connecting comms messages to the domain layer: +# - Buffer size for the base node (min value = 30, default value = 1500). +#buffer_size_base_node = 1500 +# - Buffer size for the console wallet (min value = 300, default value = 50000). +#buffer_size_console_wallet = 50000 +# The rate limit constants for the publish/subscribe connector channel, i.e. maximum amount of inbound messages to +# accept - any rate attemting to exceed this limit will be throttled. +# - Rate limit for the base node (min value = 5, default value = 1000). +#buffer_rate_limit_base_node = 1000 +# - Rate limit for the console wallet (min value = 5, default value = 1000). +buffer_rate_limit_console_wallet = 1000 +# The message deduplication persistent cache size - messages with these hashes in the cache will only be processed once. +# The cache will also be trimmed down to size periodically (min value = 0, default value = 2500). +dedup_cache_capacity = 25000 + +# The timeout (s) for requesting blocks from a peer during blockchain sync (min value = 10 s, default value = 150 s). +#fetch_blocks_timeout = 150 + +# The timeout (s) for requesting UTXOs from a base node (min value = 10 s, default value = 600 s). +#fetch_utxos_timeout = 600 + +# The timeout (s) for requesting other base node services (min value = 10 s, default value = 180 s). +#service_request_timeout = 180 + +# The maximum simultaneous comms RPC sessions allowed (default value = 1000). Setting this to -1 will allow unlimited +# sessions. +rpc_max_simultaneous_sessions = 10000 + +# Auto Update +# +# This interval in seconds to check for software updates. Setting this to 0 disables checking. +# auto_update.check_interval = 300 +# Customize the hosts that are used to check for updates. These hosts must contain update information in DNS TXT records. +# auto_update.dns_hosts = ["updates.tari.com"] +# Customize the location of the update SHA hashes and maintainer-signed signature. +# auto_update.hashes_url = "https://.../hashes.txt" +# auto_update.hashes_sig_url = "https://.../hashes.txt.sig" + +######################################################################################################################## +# # +# Wallet Configuration Options # +# # +######################################################################################################################## + +# If you are not running a wallet from this configuration, you can simply leave everything in this section commented out + +[wallet] +# Override common.network for wallet +# network = "igor" + +# The relative folder to store your local key data and transaction history. DO NOT EVER DELETE THIS FILE unless you +# a) have backed up your seed phrase and +# b) know what you are doing! +wallet_db_file = "wallet/wallet.dat" +console_wallet_db_file = "wallet/console-wallet.dat" + +# Console wallet password +# Should you wish to start your console wallet without typing in your password, the following options are available: +# 1. Start the console wallet with the --password=secret argument, or +# 2. Set the environment variable TARI_WALLET_PASSWORD=secret before starting the console wallet, or +# 3. Set the "password" key in this [wallet] section of the config +# password = "secret" + +# WalletNotify +# Allows you to execute a script or program when these transaction events are received by the console wallet: +# - transaction received +# - transaction sent +# - transaction cancelled +# - transaction mined but unconfirmed +# - transaction mined and confirmed +# An example script is available here: applications/tari_console_wallet/src/notifier/notify_example.sh +# notify = "/path/to/script" + +# This is the timeout period that will be used to monitor TXO queries to the base node (default = 60). Larger values +# are needed for wallets with many (>1000) TXOs to be validated. +base_node_query_timeout = 180 +# The amount of seconds added to the current time (Utc) which will then be used to check if the message has +# expired or not when processing the message (default = 10800). +#saf_expiry_duration = 10800 +# This is the number of block confirmations required for a transaction to be considered completely mined and +# confirmed. (default = 3) +#transaction_num_confirmations_required = 3 +# This is the timeout period that will be used for base node broadcast monitoring tasks (default = 60) +transaction_broadcast_monitoring_timeout = 180 +# This is the timeout period that will be used for chain monitoring tasks (default = 60) +#transaction_chain_monitoring_timeout = 60 +# This is the timeout period that will be used for sending transactions directly (default = 20) +transaction_direct_send_timeout = 180 +# This is the timeout period that will be used for sending transactions via broadcast mode (default = 60) +transaction_broadcast_send_timeout = 180 +# This is the size of the event channel used to communicate transaction status events to the wallet's UI. A busy console +# wallet doing thousands of bulk payments or used for stress testing needs a fairly big size (>10000) (default = 1000). +transaction_event_channel_size = 25000 +# This is the size of the event channel used to communicate base node events to the wallet. A busy console +# wallet doing thousands of bulk payments or used for stress testing needs a fairly big size (>3000) (default = 250). +base_node_event_channel_size = 3500 +# This is the size of the event channel used to communicate output manager events to the wallet. A busy console +# wallet doing thousands of bulk payments or used for stress testing needs a fairly big size (>3000) (default = 250). +output_manager_event_channel_size = 3500 +# This is the size of the event channel used to communicate base node update events to the wallet. A busy console +# wallet doing thousands of bulk payments or used for stress testing needs a fairly big size (>300) (default = 50). +base_node_update_publisher_channel_size = 500 +# If a large amount of tiny valued uT UTXOs are used as inputs to a transaction, the fee may be larger than +# the transaction amount. Set this value to `false` to allow spending of "dust" UTXOs for small valued +# transactions (default = true). +#prevent_fee_gt_amount = false +# This option specifies the transaction routing mechanism as being directly between wallets, making +# use of store and forward or using any combination of these. +# (options: "DirectOnly", "StoreAndForwardOnly", DirectAndStoreAndForward". default: "DirectAndStoreAndForward"). +#transaction_routing_mechanism = "DirectAndStoreAndForward" + +# UTXO scanning service interval (default = 12 hours, i.e. 60 * 60 * 12 seconds) +scan_for_utxo_interval = 180 + +# When running the console wallet in command mode, use these values to determine what "stage" and timeout to wait +# for sent transactions. +# The stages are: +# - "DirectSendOrSaf" - The transaction was initiated and was accepted via Direct Send or Store And Forward. +# - "Negotiated" - The recipient replied and the transaction was negotiated. +# - "Broadcast" - The transaction was broadcast to the base node mempool. +# - "MinedUnconfirmed" - The transaction was successfully detected as mined but unconfirmed on the blockchain. +# - "Mined" - The transaction was successfully detected as mined and confirmed on the blockchain. + +# The default values are: "Broadcast", 300 +#command_send_wait_stage = "Broadcast" +#command_send_wait_timeout = 300 + +# The base nodes that the wallet should use for service requests and tracking chain state. +# base_node_service_peers = ["public_key::net_address", ...] +# base_node_service_peers = ["e856839057aac496b9e25f10821116d02b58f20129e9b9ba681b830568e47c4d::/onion3/exe2zgehnw3tvrbef3ep6taiacr6sdyeb54be2s25fpru357r4skhtad:18141"] + +# Configuration for the wallet's base node service +# The refresh interval, defaults to 10 seconds +base_node_service_refresh_interval = 30 +# The maximum age of service requests in seconds, requests older than this are discarded +base_node_service_request_max_age = 180 + +#[base_node.transport.tor] +#control_address = "/ip4/127.0.0.1/tcp/9051" +#control_auth_type = "none" # or "password" +# Required for control_auth_type = "password" +#control_auth_password = "super-secure-password" + +# Wallet configuration options for testnet +[wallet.igor] +# -------------- Transport configuration -------------- +# Use TCP to connect to the Tari network. This transport can only communicate with TCP/IP addresses, so peers with +# e.g. tor onion addresses will not be contactable. +#transport = "tcp" +# The address and port to listen for peer connections over TCP. +#tcp_listener_address = "/ip4/0.0.0.0/tcp/18188" +# Configures a tor proxy used to connect to onion addresses. All other traffic uses direct TCP connections. +# This setting is optional however, if it is not specified, this node will not be able to connect to nodes that +# only advertise an onion address. +#tcp_tor_socks_address = "/ip4/127.0.0.1/tcp/36050" +#tcp_tor_socks_auth = "none" + +# Configures the node to run over a tor hidden service using the Tor proxy. This transport recognises ip/tcp, +# onion v2, onion v3 and dns addresses. +transport = "tor" +# Address of the tor control server +tor_control_address = "/ip4/127.0.0.1/tcp/9051" +# Authentication to use for the tor control server +tor_control_auth = "none" # or "password=xxxxxx" +# The onion port to use. +#tor_onion_port = 18141 +# The address to which traffic on the node's onion address will be forwarded +# tor_forward_address = "/ip4/127.0.0.1/tcp/0" +# Instead of attemping to get the SOCKS5 address from the tor control port, use this one. The default is to +# use the first address returned by the tor control port (GETINFO /net/listeners/socks). +#tor_socks_address_override= + +# Use a SOCKS5 proxy transport. This transport recognises any addresses supported by the proxy. +#transport = "socks5" +# The address of the SOCKS5 proxy +#socks5_proxy_address = "/ip4/127.0.0.1/tcp/9050" +# The address to which traffic will be forwarded +#socks5_listener_address = "/ip4/127.0.0.1/tcp/18188" +#socks5_auth = "none" # or "username_password=username:xxxxxxx" + +# Optionally bind an additional TCP socket for inbound Tari P2P protocol commms. +# Use cases include: +# - allowing wallets to locally connect to their base node, rather than through tor, when used in conjunction with `tor_proxy_bypass_addresses` +# - multiple P2P addresses, one public over DNS and one private over TOR +# - a "bridge" between TOR and TCP-only nodes +# auxilary_tcp_listener_address = "/ip4/127.0.0.1/tcp/9998" + +# When these addresses are encountered when dialing another peer, the tor proxy is bypassed and the connection is made +# direcly over TCP. /ip4, /ip6, /dns, /dns4 and /dns6 are supported. +# tor_proxy_bypass_addresses = ["/dns4/my-foo-base-node/tcp/9998"] + +######################################################################################################################## +# # +# Base Node Configuration Options # +# # +######################################################################################################################## + +# If you are not running a Tari Base node, you can simply leave everything in this section commented out. Base nodes +# help maintain the security of the Tari token and are the surest way to preserve your privacy and be 100% sure that +# no-one is cheating you out of your money. + +[base_node] +# Override common.network for base node +# network = "igor" + +# Configuration options for testnet +[base_node.igor] +# The type of database backend to use. Currently supported options are "memory" and "lmdb". LMDB is recommnded for +# almost all use cases. +db_type = "lmdb" + +# db config defaults +# db_init_size_mb = 1000 +# db_grow_size_mb = 500 +# db_resize_threshold_mb = 100 + +# The maximum number of orphans that can be stored in the Orphan block pool. Default value is "720". +# orphan_storage_capacity = 720 +# The size that the orphan pool will be allowed to grow before it is cleaned out, with threshold being tested every +# time before fetch and add blocks. Default value is "0", which indicates the orphan pool will not be cleaned out. +#orphan_db_clean_out_threshold = 0 +# The pruning horizon that indicates how many full blocks without pruning must be kept by the base node. Default value +# is "0", which indicates an archival node without any pruning. +#pruning_horizon = 0 + +# The amount of messages that will be permitted in the flood ban timespan of 100s (Default igor = 1000, +# default mainnet = 10000) +flood_ban_max_msg_count = 10000 + +# The relative path to store persistent data +data_dir = "igor" + +# When first logging onto the Tari network, you need to find a few peers to bootstrap the process. In the absence of +# any servers, this is a little more challenging than usual. Our best strategy is just to try and connect to the peers +# you knew about last time you ran the software. But what about when you run the software for the first time? That's +# where this allowlist comes in. It's a list of known Tari nodes that are likely to be around for a long time and that +# new nodes can use to introduce themselves to the network. +# peer_seeds = ["public_key1::address1", "public_key2::address2",... ] +peer_seeds = [ + "8e7eb81e512f3d6347bf9b1ca9cd67d2c8e29f2836fc5bd608206505cc72af34::/onion3/l4wouomx42nezhzexjdzfh7pcou5l7df24ggmwgekuih7tkv2rsaokqd:18141", + "00b35047a341401bcd336b2a3d564280a72f6dc72ec4c739d30c502acce4e803::/onion3/ojhxd7z6ga7qrvjlr3px66u7eiwasmffnuklscbh5o7g6wrbysj45vid:18141", + "40a9d8573745072534bce7d0ecafe882b1c79570375a69841c08a98dee9ecb5f::/onion3/io37fylc2pupg4cte4siqlsmuszkeythgjsxs2i3prm6jyz2dtophaad:18141", + "126c7ee64f71aca36398b977dd31fbbe9f9dad615df96473fb655bef5709c540::/onion3/6ilmgndocop7ybgmcvivbdsetzr5ggj4hhsivievoa2dx2b43wqlrlid:18141", +] + +# This allowlist provides a method to force syncing from any known nodes you may choose, for example if you have a +# couple of nodes that you always want to have in sync. +# force_sync_peers = ["public_key1::address1", "public_key2::address2",... ] +force_sync_peers = [ + #my known peer 1 + #"public_key1::address1", + #my known peer 2 + #"public_key1::address1", +] + +# DNS seeds +# The DNS records in these hostnames should provide TXT records as per https://github.com/tari-project/tari/pull/2319 +# Enter a domain name for the TXT records: seeds.tari.com +dns_seeds =["seeds.igor.tari.com"] +# The name server used to resolve DNS seeds (Default: "1.1.1.1:53") +# dns_seeds_name_server = "1.1.1.1:53" +# Set to true to only accept DNS records that pass DNSSEC validation (Default: true) +dns_seeds_use_dnssec = false + +# Determines the method of syncing blocks when the node is lagging. If you are not struggling with syncing, then +# it is recommended to leave this setting as it. Available values are ViaBestChainMetadata and ViaRandomPeer. +#block_sync_strategy="ViaBestChainMetadata" + +# Configure the maximum number of threads available for base node operation. These threads are spawned lazily, so a higher +# number is recommended. +# max_threads = 512 + +# The number of threads to spawn and keep active at all times. The default is the number of cores available on this node. +# core_threads = + +# The node's publicly-accessible hostname. This is the host name that is advertised on the network so that +# peers can find you. +# _NOTE_: If using the `tor` transport type, public_address will be ignored and an onion address will be +# automatically configured +#public_address = "/ip4/172.2.3.4/tcp/18189" + +# do we allow test addresses to be accpted like 127.0.0.1 +allow_test_addresses = false + +# Enable the gRPC server for the base node. Set this to true if you want to enable third-party wallet software +grpc_enabled = true +# The socket to expose for the gRPC base node server. This value is ignored if grpc_enabled is false. +# Valid values here are IPv4 and IPv6 TCP sockets, local unix sockets (e.g. "ipc://base-node-gprc.sock.100") +grpc_base_node_address = "127.0.0.1:18142" +# The socket to expose for the gRPC wallet server. This value is ignored if grpc_enabled is false. +# Valid values here are IPv4 and IPv6 TCP sockets, local unix sockets (e.g. "ipc://base-node-gprc.sock.100") +grpc_console_wallet_address = "127.0.0.1:18143" + +# A path to the file that stores your node identity and secret key +base_node_identity_file = "config/base_node_id.json" + +# A path to the file that stores your console wallet's node identity and secret key +console_wallet_identity_file = "config/console_wallet_id.json" + +# -------------- Transport configuration -------------- +# Use TCP to connect to the Tari network. This transport can only communicate with TCP/IP addresses, so peers with +# e.g. tor onion addresses will not be contactable. +#transport = "tcp" +# The address and port to listen for peer connections over TCP. +#tcp_listener_address = "/ip4/0.0.0.0/tcp/18189" +# Configures a tor proxy used to connect to onion addresses. All other traffic uses direct TCP connections. +# This setting is optional however, if it is not specified, this node will not be able to connect to nodes that +# only advertise an onion address. +#tcp_tor_socks_address = "/ip4/127.0.0.1/tcp/36050" +#tcp_tor_socks_auth = "none" + +# Configures the node to run over a tor hidden service using the Tor proxy. This transport recognises ip/tcp, +# onion v2, onion v3 and dns addresses. +transport = "tor" +# Address of the tor control server +tor_control_address = "/ip4/127.0.0.1/tcp/9051" +# Authentication to use for the tor control server +tor_control_auth = "none" # or "password=xxxxxx" +# The onion port to use. +#tor_onion_port = 18141 +# The address to which traffic on the node's onion address will be forwarded +# tor_forward_address = "/ip4/127.0.0.1/tcp/0" +# Instead of attemping to get the SOCKS5 address from the tor control port, use this one. The default is to +# use the first address returned by the tor control port (GETINFO /net/listeners/socks). +#tor_socks_address_override= + +# Use a SOCKS5 proxy transport. This transport recognises any addresses supported by the proxy. +#transport = "socks5" +# The address of the SOCKS5 proxy +#socks5_proxy_address = "/ip4/127.0.0.1/tcp/9050" +# The address to which traffic will be forwarded +#socks5_listener_address = "/ip4/127.0.0.1/tcp/18189" +#socks5_auth = "none" # or "username_password=username:xxxxxxx" + +# A path to the file that stores the tor hidden service private key, if using the tor transport. +base_node_tor_identity_file = "config/base_node_tor.json" + +# A path to the file that stores the console wallet's tor hidden service private key, if using the tor transport. +console_wallet_tor_identity_file = "config/console_wallet_tor.json" + +# Optionally bind an additional TCP socket for inbound Tari P2P protocol commms. +# Use cases include: +# - allowing wallets to locally connect to their base node, rather than through tor, when used in conjunction with `tor_proxy_bypass_addresses` +# - multiple P2P addresses, one public over DNS and one private over TOR +# - a "bridge" between TOR and TCP-only nodes +# auxilary_tcp_listener_address = "/ip4/127.0.0.1/tcp/9998" + +# When these addresses are encountered when dialing another peer, the tor proxy is bypassed and the connection is made +# direcly over TCP. /ip4, /ip6, /dns, /dns4 and /dns6 are supported. +# tor_proxy_bypass_addresses = ["/dns4/my-foo-base-node/tcp/9998"] + +######################################################################################################################## +# # +# Mempool Configuration Options # +# # +######################################################################################################################## +[mempool.igor] + +# The maximum number of transactions that can be stored in the Unconfirmed Transaction pool. This is the main waiting +# area in the mempool and almost all transactions will end up in this pool before being mined. It's for this reason +# that this parameter will have the greatest impact on actual memory usage by your mempool. If you are not mining, +# you can reduce this parameter to reduce memory consumption by your node, at the expense of network bandwith. For +# reference, a single block can hold about 4,000 transactions +# Default = 40,000 transactions +# unconfirmed_pool_storage_capacity = 40000 + +# The maximum number of transactions that can be stored in the Orphan Transaction pool. This pool keep transactions +# that are 'orphans', i.e. transactions with inputs that don't exist in the UTXO set. If you're not mining, and +# memory usage is a concern, this can safely be set to zero. Even so, orphan transactions do not appear that often +# (it's usually a short chain of spends that are broadcast in quick succession). The other potential source of orphan +# transactions are from DOS attacks and setting the `tx_ttl` parameter to a low value is an effective countermeasure +# in this case. Default: 250 transactions +# orphan_pool_storage_capacity = 250 + +# The maximum amount of time an orphan transaction will be permitted to stay in the mempool before being rejected. +# This should be set to a fairly long enough to allow the parent transaction to arrive; but low enough also to thwart +# DOS attacks. Default: 300 seconds +#orphan_tx_ttl = 300 + +# The maximum number of transactions that can be stored in the Pending Transaction pool. This pool holds transactions +# that are valid, but cannot be included in a block yet becuase there is a consensus rule holding it back, usually a +# time lock. Once the conditions holding the transaction in the pending pool are resolved, the transaction will move +# into the unconfirmed pool. Default: 5,000 transactions +# pending_pool_storage_capacity = 5000 + +# The ReorgPool consists of all transactions that have recently been added to blocks. +# When a potential blockchain reorganization occurs the transactions can be recovered from the ReorgPool and can be +# added back into the UnconfirmedPool. Transactions in the ReOrg pool have a limited Time-to-live and will be removed +# from the pool when the Time-to-live thresholds is reached. Also, when the capacity of the pool has been reached, the +# oldest transactions will be removed to make space for incoming transactions. The pool capacity and TTL parameters +# have the same meaning as those for the pending pool, but applied to the reorg pool; obviously. +# Defaults: 10,000 transactions and 300 seconds +#reorg_pool_storage_capacity = 10_000 +#reorg_tx_ttl = 300 + +# The maximum number of transactions that can be skipped when compiling a set of highest priority transactions, +# skipping over large transactions are performed in an attempt to fit more transactions into the remaining space. +# This parameter only affects mining nodes. You can ignore it if you are only running a base node. Even so, changing +# this parameter should not affect profitabilty in any meaningful way, since the transaction weights are selected to +# closely mirror how much block space they take up +#weight_tx_skip_count = 20 + +######################################################################################################################## +# # +# Validator Node Configuration Options # +# # +######################################################################################################################## + +# If you are not , you can simply leave everything in this section commented out. Base nodes +# help maintain the security of the Tari token and are the surest way to preserve your privacy and be 100% sure that +# no-one is cheating you out of your money. + +[validator_node] + +# Enable the gRPC server for the base node. Set this to true if you want to enable third-party wallet software +#grpc_enabled = false + +# The socket to expose for the gRPC base node server. This value is ignored if grpc_enabled is false. +# Valid values here are IPv4 and IPv6 TCP sockets, local unix sockets (e.g. "ipc://base-node-gprc.sock.100") +#grpc_address = "127.0.0.1:18042" + +######################################################################################################################## +# # +# Merge Mining Configuration Options # +# # +######################################################################################################################## + +[merge_mining_proxy.igor] + +# URL to monerod +monerod_url = "http://monero-stagenet.exan.tech:38081" # stagenet +#monerod_url = "http://18.133.59.45:28081" # testnet +#monerod_url = "http://18.132.124.81:18081" # mainnet +#monerod_url = "http://monero.exan.tech:18081" # mainnet alternative + +# Address of the tari_merge_mining_proxy application +proxy_host_address = "127.0.0.1:7878" + +# In sole merged mining, the block solution is usually submitted to the Monero blockchain +# (monerod) as well as to the Tari blockchain, then this setting should be "true". With pool +# merged mining, there is no sense in submitting the solution to the Monero blockchain as the +# pool does that, then this setting should be "false". (default = true). +proxy_submit_to_origin = true + +# If authentication is being used for curl +monerod_use_auth = false + +# Username for curl +monerod_username = "" + +# Password for curl +monerod_password = "" + +# The merge mining proxy can either wait for the base node to achieve initial sync at startup before it enables mining, +# or not. If merge mining starts before the base node has achieved initial sync, those Tari mined blocks will not be +# accepted. (Default value = true; will wait for base node initial sync). +#wait_for_initial_sync_at_startup = true + +[stratum_transcoder] + +# Address of the tari_stratum_transcoder application +transcoder_host_address = "127.0.0.1:7879" + +[mining_node] +# Number of mining threads +# Default: number of logical CPU cores +#num_mining_threads=8 + +# GRPC address of base node +# Default: value from `base_node.grpc_base_node_address` +#base_node_grpc_address = "127.0.0.1:18142" + +# GRPC address of console wallet +# Default: value from `base_node.grpc_console_wallet_address` +#wallet_grpc_address = "127.0.0.1:18143" + +# Start mining only when base node is bootstrapped +# and current block height is on the tip of network +# Default: true +#mine_on_tip_only=true + +# Will check tip with node every N seconds and restart mining +# if height already taken and option `mine_on_tip_only` is set +# to true +# Default: 30 seconds +#validate_tip_timeout_sec=30 + +# Stratum Mode configuration +# mining_pool_address = "miningcore.tarilabs.com:3052" +# mining_wallet_address = "YOUR_WALLET_PUBLIC_KEY" +# mining_worker_name = "worker1" diff --git a/common/logging/log4rs_sample_mining_node.yml b/common/logging/log4rs_sample_mining_node.yml index f0c8a965b8..16c4c43739 100644 --- a/common/logging/log4rs_sample_mining_node.yml +++ b/common/logging/log4rs_sample_mining_node.yml @@ -14,10 +14,6 @@ appenders: kind: console encoder: pattern: "{d(%Y-%m-%d %H:%M:%S.%f)} [{t}] {h({l}):5} {m}{n}" - filters: - - - kind: threshold - level: warn # An appender named "base_layer" that writes to a file with a custom pattern encoder mining_node: kind: rolling_file @@ -37,9 +33,22 @@ appenders: # Set the default logging level to "warn" and attach the "stdout" appender to the root root: - level: info + level: warn appenders: - stdout - - mining_node + +loggers: + # mining_node + tari::application: + level: debug + appenders: + - mining_node + additive: false + tari_mining_node: + level: debug + appenders: + - mining_node + - stdout + additive: false diff --git a/common/src/configuration/bootstrap.rs b/common/src/configuration/bootstrap.rs index 6e0a668541..79c98f8d97 100644 --- a/common/src/configuration/bootstrap.rs +++ b/common/src/configuration/bootstrap.rs @@ -148,6 +148,9 @@ pub struct ConfigBootstrap { pub miner_max_diff: Option, #[structopt(long, alias = "tracing")] pub tracing_enabled: bool, + /// Supply a network (overrides existing configuration) + #[structopt(long, alias = "network")] + pub network: Option, } fn normalize_path(path: PathBuf) -> PathBuf { @@ -183,6 +186,7 @@ impl Default for ConfigBootstrap { miner_min_diff: None, miner_max_diff: None, tracing_enabled: false, + network: None, } } } diff --git a/common/src/configuration/global.rs b/common/src/configuration/global.rs index 880e111cd1..be8cb38038 100644 --- a/common/src/configuration/global.rs +++ b/common/src/configuration/global.rs @@ -71,7 +71,6 @@ pub struct GlobalConfig { pub pruning_horizon: u64, pub pruned_mode_cleanup_interval: u64, pub core_threads: Option, - pub max_threads: Option, pub base_node_identity_file: PathBuf, pub public_address: Multiaddr, pub grpc_enabled: bool, @@ -137,6 +136,7 @@ pub struct GlobalConfig { pub mining_pool_address: String, pub mining_wallet_address: String, pub mining_worker_name: String, + pub base_node_bypass_range_proof_verification: bool, } impl GlobalConfig { @@ -270,10 +270,6 @@ fn convert_node_config( let core_threads = optional(cfg.get_int(&key).map(|n| n as usize)).map_err(|e| ConfigurationError::new(&key, &e.to_string()))?; - let key = config_string("base_node", &net_str, "max_threads"); - let max_threads = - optional(cfg.get_int(&key).map(|n| n as usize)).map_err(|e| ConfigurationError::new(&key, &e.to_string()))?; - // Max RandomX VMs let key = config_string("base_node", &net_str, "max_randomx_vms"); let max_randomx_vms = optional(cfg.get_int(&key).map(|n| n as usize)) @@ -376,6 +372,8 @@ fn convert_node_config( s.parse::() .map_err(|e| ConfigurationError::new(&key, &e.to_string())) })?; + let key = config_string("base_node", &net_str, "bypass_range_proof_verification"); + let base_node_bypass_range_proof_verification = cfg.get_bool(&key).unwrap_or(false); let key = config_string("base_node", &net_str, "dns_seeds_use_dnssec"); let dns_seeds_use_dnssec = cfg @@ -712,7 +710,6 @@ fn convert_node_config( pruning_horizon, pruned_mode_cleanup_interval, core_threads, - max_threads, base_node_identity_file, public_address, grpc_enabled, @@ -778,6 +775,7 @@ fn convert_node_config( mining_pool_address, mining_wallet_address, mining_worker_name, + base_node_bypass_range_proof_verification, }) } diff --git a/common/src/configuration/network.rs b/common/src/configuration/network.rs index c8a0d3fe4a..1498b1e623 100644 --- a/common/src/configuration/network.rs +++ b/common/src/configuration/network.rs @@ -37,6 +37,7 @@ pub enum Network { Ridcully = 0x21, Stibbons = 0x22, Weatherwax = 0x23, + Igor = 0x24, } impl Network { @@ -51,6 +52,7 @@ impl Network { Ridcully => "ridcully", Stibbons => "stibbons", Weatherwax => "weatherwax", + Igor => "igor", LocalNet => "localnet", } } @@ -73,6 +75,7 @@ impl FromStr for Network { "weatherwax" => Ok(Weatherwax), "mainnet" => Ok(MainNet), "localnet" => Ok(LocalNet), + "igor" => Ok(Igor), invalid => Err(ConfigurationError::new( "network", &format!("Invalid network option: {}", invalid), diff --git a/common/src/configuration/utils.rs b/common/src/configuration/utils.rs index 1f291cf0ce..3814deb8d9 100644 --- a/common/src/configuration/utils.rs +++ b/common/src/configuration/utils.rs @@ -192,7 +192,7 @@ pub fn default_config(bootstrap: &ConfigBootstrap) -> Config { .unwrap(); cfg.set_default( "base_node.weatherwax.data_dir", - default_subdir("stibbons/", Some(&bootstrap.base_path)), + default_subdir("weatherwax/", Some(&bootstrap.base_path)), ) .unwrap(); cfg.set_default( @@ -228,7 +228,6 @@ pub fn default_config(bootstrap: &ConfigBootstrap) -> Config { .unwrap(); cfg.set_default("base_node.weatherwax.grpc_console_wallet_address", "127.0.0.1:18143") .unwrap(); - cfg.set_default("base_node.weatherwax.dns_seeds_name_server", "1.1.1.1:53") .unwrap(); cfg.set_default("base_node.weatherwax.dns_seeds_use_dnssec", true) @@ -238,6 +237,28 @@ pub fn default_config(bootstrap: &ConfigBootstrap) -> Config { cfg.set_default("wallet.base_node_service_peers", Vec::::new()) .unwrap(); + //---------------------------------- Igor Defaults --------------------------------------------// + + cfg.set_default("base_node.igor.db_type", "lmdb").unwrap(); + cfg.set_default("base_node.igor.orphan_storage_capacity", 720).unwrap(); + cfg.set_default("base_node.igor.orphan_db_clean_out_threshold", 0) + .unwrap(); + cfg.set_default("base_node.igor.pruning_horizon", 0).unwrap(); + cfg.set_default("base_node.igor.pruned_mode_cleanup_interval", 50) + .unwrap(); + cfg.set_default("base_node.igor.flood_ban_max_msg_count", 1000).unwrap(); + cfg.set_default("base_node.igor.public_address", format!("{}/tcp/18141", local_ip_addr)) + .unwrap(); + cfg.set_default("base_node.igor.grpc_enabled", false).unwrap(); + cfg.set_default("base_node.igor.grpc_base_node_address", "127.0.0.1:18142") + .unwrap(); + cfg.set_default("base_node.igor.grpc_console_wallet_address", "127.0.0.1:18143") + .unwrap(); + cfg.set_default("base_node.igor.dns_seeds_name_server", "1.1.1.1:53") + .unwrap(); + cfg.set_default("base_node.igor.dns_seeds_use_dnssec", true).unwrap(); + cfg.set_default("base_node.igor.auto_ping_interval", 30).unwrap(); + set_transport_defaults(&mut cfg).unwrap(); set_merge_mining_defaults(&mut cfg); set_mining_node_defaults(&mut cfg); @@ -254,6 +275,8 @@ fn set_stratum_transcoder_defaults(cfg: &mut Config) { "127.0.0.1:7879", ) .unwrap(); + cfg.set_default("stratum_transcoder.igor.transcoder_host_address", "127.0.0.1:7879") + .unwrap(); } fn set_merge_mining_defaults(cfg: &mut Config) { @@ -289,6 +312,16 @@ fn set_merge_mining_defaults(cfg: &mut Config) { .unwrap(); cfg.set_default("merge_mining_proxy.weatherwax.wait_for_initial_sync_at_startup", true) .unwrap(); + cfg.set_default("merge_mining_proxy.igor.proxy_host_address", "127.0.0.1:7878") + .unwrap(); + cfg.set_default("merge_mining_proxy.igor.proxy_submit_to_origin", true) + .unwrap(); + cfg.set_default("merge_mining_proxy.igor.monerod_use_auth", "false") + .unwrap(); + cfg.set_default("merge_mining_proxy.igor.monerod_username", "").unwrap(); + cfg.set_default("merge_mining_proxy.igor.monerod_password", "").unwrap(); + cfg.set_default("merge_mining_proxy.igor.wait_for_initial_sync_at_startup", true) + .unwrap(); } fn set_mining_node_defaults(cfg: &mut Config) { @@ -372,6 +405,18 @@ fn set_transport_defaults(cfg: &mut Config) -> Result<(), config::ConfigError> { )?; cfg.set_default(&format!("{}.weatherwax.socks5_auth", app), "none")?; + + // igor + cfg.set_default(&format!("{}.igor.transport", app), "tor")?; + + cfg.set_default(&format!("{}.igor.tor_control_address", app), "/ip4/127.0.0.1/tcp/9051")?; + cfg.set_default(&format!("{}.igor.tor_control_auth", app), "none")?; + cfg.set_default(&format!("{}.igor.tor_forward_address", app), "/ip4/127.0.0.1/tcp/0")?; + cfg.set_default(&format!("{}.igor.tor_onion_port", app), "18141")?; + + cfg.set_default(&format!("{}.igor.socks5_proxy_address", app), "/ip4/0.0.0.0/tcp/9150")?; + + cfg.set_default(&format!("{}.igor.socks5_auth", app), "none")?; } Ok(()) } diff --git a/common/src/dns/tests.rs b/common/src/dns/tests.rs index 955f22cf97..b7dc087517 100644 --- a/common/src/dns/tests.rs +++ b/common/src/dns/tests.rs @@ -48,7 +48,7 @@ use trust_dns_client::rr::{rdata, RData, Record, RecordType}; // Ignore as this test requires network IO #[ignore] -#[tokio_macros::test] +#[tokio::test] async fn it_returns_an_empty_vec_if_all_seeds_are_invalid() { let mut resolver = PeerSeedResolver::connect("1.1.1.1:53".parse().unwrap()).await.unwrap(); let seeds = resolver.resolve("tari.com").await.unwrap(); @@ -64,7 +64,7 @@ fn create_txt_record(contents: Vec) -> Record { } #[allow(clippy::vec_init_then_push)] -#[tokio_macros::test] +#[tokio::test] async fn it_returns_peer_seeds() { let mut records = Vec::new(); // Multiple addresses(works) diff --git a/common/src/lib.rs b/common/src/lib.rs index 6f5c98a2e4..cb2c99d0c1 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -72,7 +72,7 @@ //! let config = args.load_configuration().unwrap(); //! let global = GlobalConfig::convert_from(ApplicationType::BaseNode, config).unwrap(); //! assert_eq!(global.network, Network::Weatherwax); -//! assert!(global.max_threads.is_none()); +//! assert!(global.core_threads.is_none()); //! # std::fs::remove_dir_all(temp_dir).unwrap(); //! ``` diff --git a/comms/Cargo.toml b/comms/Cargo.toml index 71b8eec12f..b781fa12bf 100644 --- a/comms/Cargo.toml +++ b/comms/Cargo.toml @@ -12,34 +12,36 @@ edition = "2018" [dependencies] tari_crypto = "0.11.1" tari_storage = { version = "^0.9", path = "../infrastructure/storage" } -tari_shutdown = { version="^0.9", path = "../infrastructure/shutdown" } +tari_shutdown = { version = "^0.9", path = "../infrastructure/shutdown" } +anyhow = "1.0.32" async-trait = "0.1.36" bitflags = "1.0.4" blake2 = "0.9.0" -bytes = { version = "0.5.x", features=["serde"] } +bytes = { version = "1", features = ["serde"] } chrono = { version = "0.4.6", features = ["serde"] } cidr = "0.1.0" clear_on_drop = "=0.2.4" data-encoding = "2.2.0" digest = "0.9.0" -futures = { version = "^0.3", features = ["async-await"]} +futures = { version = "^0.3", features = ["async-await"] } lazy_static = "1.3.0" lmdb-zero = "0.4.4" log = { version = "0.4.0", features = ["std"] } -multiaddr = {version = "=0.11.0", package = "parity-multiaddr"} -nom = {version = "5.1.0", features=["std"], default-features=false} +multiaddr = { version = "0.13.0" } +nom = { version = "5.1.0", features = ["std"], default-features = false } openssl = { version = "0.10", features = ["vendored"] } -pin-project = "0.4.17" -prost = "=0.6.1" +pin-project = "1.0.8" +prost = "=0.8.0" rand = "0.8" serde = "1.0.119" serde_derive = "1.0.119" -snow = {version="=0.8.0", features=["default-resolver"]} -thiserror = "1.0.20" -tokio = {version="~0.2.19", features=["blocking", "time", "tcp", "dns", "sync", "stream", "signal"]} -tokio-util = {version="0.3.1", features=["codec"]} -tower= "0.3.1" +snow = { version = "=0.8.0", features = ["default-resolver"] } +thiserror = "1.0.26" +tokio = { version = "1.10", features = ["rt-multi-thread", "time", "sync", "signal", "net", "macros", "io-util"] } +tokio-stream = { version = "0.1.7", features = ["sync"] } +tokio-util = { version = "0.6.7", features = ["codec", "compat"] } +tower = "0.3.1" tracing = "0.1.26" tracing-futures = "0.2.5" yamux = "=0.9.0" @@ -49,20 +51,19 @@ opentelemetry = { version = "0.16", default-features = false, features = ["trace opentelemetry-jaeger = { version="0.15", features=["rt-tokio"]} # RPC dependencies -tower-make = {version="0.3.0", optional=true} -anyhow = "1.0.32" +tower-make = { version = "0.3.0", optional = true } [dev-dependencies] -tari_test_utils = {version="^0.9", path="../infrastructure/test_utils"} -tari_comms_rpc_macros = {version="*", path="./rpc_macros"} +tari_test_utils = { version = "^0.9", path = "../infrastructure/test_utils" } +tari_comms_rpc_macros = { version = "*", path = "./rpc_macros" } env_logger = "0.7.0" serde_json = "1.0.39" -tokio-macros = "0.2.3" +#tokio = {version="1.8", features=["macros"]} tempfile = "3.1.0" [build-dependencies] -tari_common = { version = "^0.9", path="../common", features = ["build"]} +tari_common = { version = "^0.9", path = "../common", features = ["build"] } [features] avx2 = ["tari_crypto/avx2"] diff --git a/comms/dht/Cargo.toml b/comms/dht/Cargo.toml index c75e423543..77cf0d9e9d 100644 --- a/comms/dht/Cargo.toml +++ b/comms/dht/Cargo.toml @@ -10,58 +10,57 @@ license = "BSD-3-Clause" edition = "2018" [dependencies] -tari_comms = { version = "^0.9", path = "../", features = ["rpc"]} -tari_comms_rpc_macros = { version = "^0.9", path = "../rpc_macros"} +tari_comms = { version = "^0.9", path = "../", features = ["rpc"] } +tari_comms_rpc_macros = { version = "^0.9", path = "../rpc_macros" } tari_crypto = "0.11.1" -tari_utilities = { version = "^0.3" } -tari_shutdown = { version = "^0.9", path = "../../infrastructure/shutdown"} -tari_storage = { version = "^0.9", path = "../../infrastructure/storage"} +tari_utilities = { version = "^0.3" } +tari_shutdown = { version = "^0.9", path = "../../infrastructure/shutdown" } +tari_storage = { version = "^0.9", path = "../../infrastructure/storage" } anyhow = "1.0.32" bitflags = "1.2.0" -bytes = "0.4.12" +bytes = "0.5" chacha20 = "0.7.1" chrono = "0.4.9" -diesel = {version="1.4.7", features = ["sqlite", "serde_json", "chrono", "numeric"]} +diesel = { version = "1.4.7", features = ["sqlite", "serde_json", "chrono", "numeric"] } diesel_migrations = "1.4.0" -libsqlite3-sys = { version = ">=0.8.0, <0.13.0", features = ["bundled"], optional=true } +libsqlite3-sys = { version = ">=0.8.0, <0.13.0", features = ["bundled"], optional = true } digest = "0.9.0" -futures= {version= "^0.3.1"} +futures = { version = "^0.3.1" } log = "0.4.8" -prost = "=0.6.1" -prost-types = "=0.6.1" +prost = "=0.8.0" +prost-types = "=0.8.0" rand = "0.8" serde = "1.0.90" serde_derive = "1.0.90" serde_repr = "0.1.5" -thiserror = "1.0.20" -tokio = {version="0.2.10", features=["rt-threaded", "blocking"]} -tower= "0.3.1" +thiserror = "1.0.26" +tokio = { version = "1.10", features = ["rt", "macros"] } +tower = "0.3.1" ttl_cache = "0.5.1" # tower-filter dependencies pin-project = "0.4" [dev-dependencies] -tari_test_utils = { version = "^0.9", path = "../../infrastructure/test_utils"} +tari_test_utils = { version = "^0.9", path = "../../infrastructure/test_utils" } env_logger = "0.7.0" -futures-test = { version = "0.3.0-alpha.19", package = "futures-test-preview" } +futures-test = { version = "0.3.5" } lmdb-zero = "0.4.4" tempfile = "3.1.0" -tokio-macros = "0.2.3" +tokio-stream = { version = "0.1.7", features = ["sync"] } petgraph = "0.5.1" clap = "2.33.0" # tower-filter dependencies tower-test = { version = "^0.3" } -tokio-test = "^0.2" -tokio = "^0.2" +tokio-test = "^0.4.2" futures-util = "^0.3.1" lazy_static = "1.4.0" [build-dependencies] -tari_common = { version = "^0.9", path="../../common"} +tari_common = { version = "^0.9", path = "../../common" } [features] test-mocks = [] diff --git a/comms/dht/examples/graphing_utilities/utilities.rs b/comms/dht/examples/graphing_utilities/utilities.rs index bfb081eb21..3a4fea7ae9 100644 --- a/comms/dht/examples/graphing_utilities/utilities.rs +++ b/comms/dht/examples/graphing_utilities/utilities.rs @@ -32,6 +32,7 @@ use petgraph::{ }; use std::{collections::HashMap, convert::TryFrom, fs, fs::File, io::Write, path::Path, process::Command, sync::Mutex}; use tari_comms::{connectivity::ConnectivitySelection, peer_manager::NodeId}; +use tari_test_utils::streams::convert_unbounded_mpsc_to_stream; const TEMP_GRAPH_OUTPUT_DIR: &str = "/tmp/memorynet_temp"; @@ -277,7 +278,9 @@ pub enum PythonRenderType { /// This function will drain the message event queue and then build a message propagation tree assuming the first sender /// is the starting node pub async fn track_join_message_drain_messaging_events(messaging_rx: &mut NodeEventRx) -> StableGraph { - let drain_fut = DrainBurst::new(messaging_rx); + let stream = convert_unbounded_mpsc_to_stream(messaging_rx); + tokio::pin!(stream); + let drain_fut = DrainBurst::new(&mut stream); let messages = drain_fut.await; let num_messages = messages.len(); diff --git a/comms/dht/examples/memory_net/drain_burst.rs b/comms/dht/examples/memory_net/drain_burst.rs index d2f5bce2be..59ed8723f4 100644 --- a/comms/dht/examples/memory_net/drain_burst.rs +++ b/comms/dht/examples/memory_net/drain_burst.rs @@ -42,7 +42,7 @@ where St: ?Sized + Stream + Unpin let (lower_bound, upper_bound) = stream.size_hint(); Self { inner: stream, - collection: Vec::with_capacity(upper_bound.or(Some(lower_bound)).unwrap()), + collection: Vec::with_capacity(upper_bound.unwrap_or(lower_bound)), } } } @@ -70,15 +70,16 @@ where St: ?Sized + Stream + Unpin mod test { use super::*; use futures::stream; + use tari_comms::runtime; - #[tokio_macros::test_basic] + #[runtime::test] async fn drain_terminating_stream() { let mut stream = stream::iter(1..10u8); let burst = DrainBurst::new(&mut stream).await; assert_eq!(burst, (1..10u8).into_iter().collect::>()); } - #[tokio_macros::test_basic] + #[runtime::test] async fn drain_stream_with_pending() { let mut stream = stream::iter(1..10u8); let burst = DrainBurst::new(&mut stream).await; diff --git a/comms/dht/examples/memory_net/utilities.rs b/comms/dht/examples/memory_net/utilities.rs index bf1bf03ae4..0d4a675e8e 100644 --- a/comms/dht/examples/memory_net/utilities.rs +++ b/comms/dht/examples/memory_net/utilities.rs @@ -22,7 +22,7 @@ #![allow(clippy::mutex_atomic)] use crate::memory_net::DrainBurst; -use futures::{channel::mpsc, future, StreamExt}; +use futures::future; use lazy_static::lazy_static; use rand::{rngs::OsRng, Rng}; use std::{ @@ -62,8 +62,13 @@ use tari_storage::{ lmdb_store::{LMDBBuilder, LMDBConfig}, LMDBWrapper, }; -use tari_test_utils::{paths::create_temporary_data_path, random}; -use tokio::{runtime, sync::broadcast, task, time}; +use tari_test_utils::{paths::create_temporary_data_path, random, streams::convert_unbounded_mpsc_to_stream}; +use tokio::{ + runtime, + sync::{broadcast, mpsc}, + task, + time, +}; use tower::ServiceBuilder; pub type NodeEventRx = mpsc::UnboundedReceiver<(NodeId, NodeId)>; @@ -154,7 +159,7 @@ pub async fn discovery(wallets: &[TestNode], messaging_events_rx: &mut NodeEvent start.elapsed() ); - time::delay_for(Duration::from_secs(5)).await; + time::sleep(Duration::from_secs(5)).await; total_messages += drain_messaging_events(messaging_events_rx, false).await; }, Err(err) => { @@ -166,7 +171,7 @@ pub async fn discovery(wallets: &[TestNode], messaging_events_rx: &mut NodeEvent err ); - time::delay_for(Duration::from_secs(5)).await; + time::sleep(Duration::from_secs(5)).await; total_messages += drain_messaging_events(messaging_events_rx, false).await; }, } @@ -298,7 +303,7 @@ pub async fn do_network_wide_propagation(nodes: &mut [TestNode], origin_node_ind let node_name = node.name.clone(); task::spawn(async move { - let result = time::timeout(Duration::from_secs(30), ims_rx.next()).await; + let result = time::timeout(Duration::from_secs(30), ims_rx.recv()).await; let mut is_success = false; match result { Ok(Some(msg)) => { @@ -450,21 +455,23 @@ pub async fn do_store_and_forward_message_propagation( for (idx, mut s) in neighbour_subs.into_iter().enumerate() { let neighbour = neighbours[idx].name.clone(); task::spawn(async move { - let msg = time::timeout(Duration::from_secs(2), s.next()).await; + let msg = time::timeout(Duration::from_secs(2), s.recv()).await; match msg { - Ok(Some(Ok(evt))) => { + Ok(Ok(evt)) => { if let MessagingEvent::MessageReceived(_, tag) = &*evt { println!("{} received propagated SAF message ({})", neighbour, tag); } }, - Ok(_) => {}, + Ok(Err(err)) => { + println!("{}", err); + }, Err(_) => println!("{} did not receive the SAF message", neighbour), } }); } banner!("⏰ Waiting a few seconds for messages to propagate around the network..."); - time::delay_for(Duration::from_secs(5)).await; + time::sleep(Duration::from_secs(5)).await; let mut total_messages = drain_messaging_events(messaging_rx, false).await; @@ -515,7 +522,7 @@ pub async fn do_store_and_forward_message_propagation( let mut num_msgs = 0; let mut succeeded = 0; loop { - let result = time::timeout(Duration::from_secs(10), wallet.ims_rx.as_mut().unwrap().next()).await; + let result = time::timeout(Duration::from_secs(10), wallet.ims_rx.as_mut().unwrap().recv()).await; num_msgs += 1; match result { Ok(msg) => { @@ -554,7 +561,9 @@ pub async fn do_store_and_forward_message_propagation( } pub async fn drain_messaging_events(messaging_rx: &mut NodeEventRx, show_logs: bool) -> usize { - let drain_fut = DrainBurst::new(messaging_rx); + let stream = convert_unbounded_mpsc_to_stream(messaging_rx); + tokio::pin!(stream); + let drain_fut = DrainBurst::new(&mut stream); if show_logs { let messages = drain_fut.await; let num_messages = messages.len(); @@ -694,42 +703,46 @@ impl TestNode { fn spawn_event_monitor( comms: &CommsNode, - messaging_events: MessagingEventReceiver, + mut messaging_events: MessagingEventReceiver, events_tx: mpsc::Sender>, messaging_events_tx: NodeEventTx, quiet_mode: bool, ) { - let conn_man_event_sub = comms.subscribe_connection_manager_events(); + let mut conn_man_event_sub = comms.subscribe_connection_manager_events(); let executor = runtime::Handle::current(); - executor.spawn( - conn_man_event_sub - .filter(|r| future::ready(r.is_ok())) - .map(Result::unwrap) - .map(connection_manager_logger( - comms.node_identity().node_id().clone(), - quiet_mode, - )) - .map(Ok) - .forward(events_tx), - ); - let node_id = comms.node_identity().node_id().clone(); + executor.spawn(async move { + let mut logger = connection_manager_logger(node_id, quiet_mode); + loop { + match conn_man_event_sub.recv().await { + Ok(event) => { + events_tx.send(logger(event)).await.unwrap(); + }, + Err(broadcast::error::RecvError::Closed) => break, + Err(err) => log::error!("{}", err), + } + } + }); - executor.spawn( - messaging_events - .filter(|r| future::ready(r.is_ok())) - .map(Result::unwrap) - .filter_map(move |event| { - use MessagingEvent::*; - future::ready(match &*event { - MessageReceived(peer_node_id, _) => Some((Clone::clone(&*peer_node_id), node_id.clone())), - _ => None, - }) - }) - .map(Ok) - .forward(messaging_events_tx), - ); + let node_id = comms.node_identity().node_id().clone(); + executor.spawn(async move { + loop { + let event = messaging_events.recv().await; + use MessagingEvent::*; + match event.as_deref() { + Ok(MessageReceived(peer_node_id, _)) => { + messaging_events_tx + .send((Clone::clone(&*peer_node_id), node_id.clone())) + .unwrap(); + }, + Err(broadcast::error::RecvError::Closed) => { + break; + }, + _ => {}, + } + } + }); } #[inline] @@ -749,7 +762,7 @@ impl TestNode { } use ConnectionManagerEvent::*; loop { - let event = time::timeout(Duration::from_secs(30), self.conn_man_events_rx.next()) + let event = time::timeout(Duration::from_secs(30), self.conn_man_events_rx.recv()) .await .ok()??; @@ -763,7 +776,7 @@ impl TestNode { } pub async fn shutdown(mut self) { - self.shutdown.trigger().unwrap(); + self.shutdown.trigger(); self.comms.wait_until_shutdown().await; } } @@ -946,5 +959,5 @@ async fn setup_comms_dht( pub async fn take_a_break(num_nodes: usize) { banner!("Taking a break for a few seconds to let things settle..."); - time::delay_for(Duration::from_millis(num_nodes as u64 * 100)).await; + time::sleep(Duration::from_millis(num_nodes as u64 * 100)).await; } diff --git a/comms/dht/examples/memorynet.rs b/comms/dht/examples/memorynet.rs index 9cc28551bc..3ed35c05b7 100644 --- a/comms/dht/examples/memorynet.rs +++ b/comms/dht/examples/memorynet.rs @@ -49,15 +49,16 @@ use crate::memory_net::utilities::{ shutdown_all, take_a_break, }; -use futures::{channel::mpsc, future}; +use futures::future; use rand::{rngs::OsRng, Rng}; use std::{iter::repeat_with, time::Duration}; use tari_comms::peer_manager::PeerFeatures; +use tokio::sync::mpsc; // Size of network -const NUM_NODES: usize = 6; +const NUM_NODES: usize = 40; // Must be at least 2 -const NUM_WALLETS: usize = 50; +const NUM_WALLETS: usize = 5; const QUIET_MODE: bool = true; /// Number of neighbouring nodes each node should include in the connection pool const NUM_NEIGHBOURING_NODES: usize = 8; @@ -66,7 +67,7 @@ const NUM_RANDOM_NODES: usize = 4; /// The number of messages that should be propagated out const PROPAGATION_FACTOR: usize = 4; -#[tokio_macros::main] +#[tokio::main] #[allow(clippy::same_item_push)] async fn main() { env_logger::init(); @@ -77,7 +78,7 @@ async fn main() { NUM_WALLETS ); - let (node_message_tx, mut messaging_events_rx) = mpsc::unbounded(); + let (node_message_tx, mut messaging_events_rx) = mpsc::unbounded_channel(); let seed_node = vec![ make_node( diff --git a/comms/dht/examples/memorynet_graph_network_join_multiple_seeds.rs b/comms/dht/examples/memorynet_graph_network_join_multiple_seeds.rs index a0a1bddc1e..5d2759fb54 100644 --- a/comms/dht/examples/memorynet_graph_network_join_multiple_seeds.rs +++ b/comms/dht/examples/memorynet_graph_network_join_multiple_seeds.rs @@ -65,11 +65,11 @@ use crate::{ }, }; use clap::{App, Arg}; -use futures::channel::mpsc; use std::{path::Path, time::Duration}; use tari_comms::peer_manager::PeerFeatures; +use tokio::sync::mpsc; -#[tokio_macros::main] +#[tokio::main] #[allow(clippy::same_item_push)] async fn main() { env_logger::init(); @@ -96,7 +96,7 @@ async fn main() { NUM_WALLETS ); - let (messaging_events_tx, _messaging_events_rx) = mpsc::unbounded(); + let (messaging_events_tx, _messaging_events_rx) = mpsc::unbounded_channel(); let mut seed_identities = Vec::new(); for _ in 0..NUM_SEED_NODES { diff --git a/comms/dht/examples/memorynet_graph_network_track_join.rs b/comms/dht/examples/memorynet_graph_network_track_join.rs index 259358e1cb..0a0324b683 100644 --- a/comms/dht/examples/memorynet_graph_network_track_join.rs +++ b/comms/dht/examples/memorynet_graph_network_track_join.rs @@ -73,11 +73,11 @@ use crate::{ }; use clap::{App, Arg}; use env_logger::Env; -use futures::channel::mpsc; use std::{path::Path, time::Duration}; use tari_comms::peer_manager::PeerFeatures; +use tokio::sync::mpsc; -#[tokio_macros::main] +#[tokio::main] #[allow(clippy::same_item_push)] async fn main() { let _ = env_logger::from_env(Env::default()) @@ -106,7 +106,7 @@ async fn main() { NUM_WALLETS ); - let (messaging_events_tx, mut messaging_events_rx) = mpsc::unbounded(); + let (messaging_events_tx, mut messaging_events_rx) = mpsc::unbounded_channel(); let mut seed_identities = Vec::new(); for _ in 0..NUM_SEED_NODES { diff --git a/comms/dht/examples/memorynet_graph_network_track_propagation.rs b/comms/dht/examples/memorynet_graph_network_track_propagation.rs index fcc5debff3..d560b9f537 100644 --- a/comms/dht/examples/memorynet_graph_network_track_propagation.rs +++ b/comms/dht/examples/memorynet_graph_network_track_propagation.rs @@ -73,10 +73,10 @@ use crate::{ }, }; use env_logger::Env; -use futures::channel::mpsc; use tari_comms::peer_manager::PeerFeatures; +use tokio::sync::mpsc; -#[tokio_macros::main] +#[tokio::main] #[allow(clippy::same_item_push)] async fn main() { let _ = env_logger::from_env(Env::default()) @@ -105,7 +105,7 @@ async fn main() { NUM_WALLETS ); - let (messaging_events_tx, mut messaging_events_rx) = mpsc::unbounded(); + let (messaging_events_tx, mut messaging_events_rx) = mpsc::unbounded_channel(); let mut seed_identities = Vec::new(); for _ in 0..NUM_SEED_NODES { diff --git a/comms/dht/src/actor.rs b/comms/dht/src/actor.rs index c2c2d4e52a..2e453291ac 100644 --- a/comms/dht/src/actor.rs +++ b/comms/dht/src/actor.rs @@ -37,13 +37,7 @@ use crate::{ DhtConfig, }; use chrono::{DateTime, Utc}; -use futures::{ - channel::{mpsc, mpsc::SendError, oneshot}, - future::BoxFuture, - stream::{Fuse, FuturesUnordered}, - SinkExt, - StreamExt, -}; +use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; use log::*; use std::{cmp, fmt, fmt::Display, sync::Arc}; use tari_comms::{ @@ -51,10 +45,15 @@ use tari_comms::{ peer_manager::{NodeId, NodeIdentity, PeerFeatures, PeerManager, PeerManagerError, PeerQuery, PeerQuerySortBy}, types::CommsPublicKey, }; +use tari_crypto::tari_utilities::hex::Hex; use tari_shutdown::ShutdownSignal; use tari_utilities::message_format::{MessageFormat, MessageFormatError}; use thiserror::Error; -use tokio::{task, time}; +use tokio::{ + sync::{mpsc, oneshot}, + task, + time, +}; const LOG_TARGET: &str = "comms::dht::actor"; @@ -62,8 +61,6 @@ const LOG_TARGET: &str = "comms::dht::actor"; pub enum DhtActorError { #[error("MPSC channel is disconnected")] ChannelDisconnected, - #[error("MPSC sender was unable to send because the channel buffer is full")] - SendBufferFull, #[error("Reply sender canceled the request")] ReplyCanceled, #[error("PeerManagerError: {0}")] @@ -84,15 +81,9 @@ pub enum DhtActorError { ConnectivityEventStreamClosed, } -impl From for DhtActorError { - fn from(err: SendError) -> Self { - if err.is_disconnected() { - DhtActorError::ChannelDisconnected - } else if err.is_full() { - DhtActorError::SendBufferFull - } else { - unreachable!(); - } +impl From> for DhtActorError { + fn from(_: mpsc::error::SendError) -> Self { + DhtActorError::ChannelDisconnected } } @@ -101,9 +92,14 @@ impl From for DhtActorError { pub enum DhtRequest { /// Send a Join request to the network SendJoin, - /// Inserts a message signature to the msg hash cache. This operation replies with a boolean - /// which is true if the signature already exists in the cache, otherwise false - MsgHashCacheInsert(Vec, CommsPublicKey, oneshot::Sender), + /// Inserts a message signature to the msg hash cache. This operation replies with the number of times this message + /// has previously been seen (hit count) + MsgHashCacheInsert { + message_hash: Vec, + received_from: CommsPublicKey, + reply_tx: oneshot::Sender, + }, + GetMsgHashHitCount(Vec, oneshot::Sender), /// Fetch selected peers according to the broadcast strategy SelectPeers(BroadcastStrategy, oneshot::Sender>), GetMetadata(DhtMetadataKey, oneshot::Sender>, DhtActorError>>), @@ -114,12 +110,22 @@ impl Display for DhtRequest { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { use DhtRequest::*; match self { - SendJoin => f.write_str("SendJoin"), - MsgHashCacheInsert(_, _, _) => f.write_str("MsgHashCacheInsert"), - SelectPeers(s, _) => f.write_str(&format!("SelectPeers (Strategy={})", s)), - GetMetadata(key, _) => f.write_str(&format!("GetMetadata (key={})", key)), + SendJoin => write!(f, "SendJoin"), + MsgHashCacheInsert { + message_hash, + received_from, + .. + } => write!( + f, + "MsgHashCacheInsert(message hash: {}, received from: {})", + message_hash.to_hex(), + received_from.to_hex(), + ), + GetMsgHashHitCount(hash, _) => write!(f, "GetMsgHashHitCount({})", hash.to_hex()), + SelectPeers(s, _) => write!(f, "SelectPeers (Strategy={})", s), + GetMetadata(key, _) => write!(f, "GetMetadata (key={})", key), SetMetadata(key, value, _) => { - f.write_str(&format!("SetMetadata (key={}, value={} bytes)", key, value.len())) + write!(f, "SetMetadata (key={}, value={} bytes)", key, value.len()) }, } } @@ -147,14 +153,27 @@ impl DhtRequester { reply_rx.await.map_err(|_| DhtActorError::ReplyCanceled) } - pub async fn insert_message_hash( + pub async fn add_message_to_dedup_cache( &mut self, message_hash: Vec, - public_key: CommsPublicKey, - ) -> Result { + received_from: CommsPublicKey, + ) -> Result { + let (reply_tx, reply_rx) = oneshot::channel(); + self.sender + .send(DhtRequest::MsgHashCacheInsert { + message_hash, + received_from, + reply_tx, + }) + .await?; + + reply_rx.await.map_err(|_| DhtActorError::ReplyCanceled) + } + + pub async fn get_message_cache_hit_count(&mut self, message_hash: Vec) -> Result { let (reply_tx, reply_rx) = oneshot::channel(); self.sender - .send(DhtRequest::MsgHashCacheInsert(message_hash, public_key, reply_tx)) + .send(DhtRequest::GetMsgHashHitCount(message_hash, reply_tx)) .await?; reply_rx.await.map_err(|_| DhtActorError::ReplyCanceled) @@ -186,8 +205,8 @@ pub struct DhtActor { outbound_requester: OutboundMessageRequester, connectivity: ConnectivityRequester, config: DhtConfig, - shutdown_signal: Option, - request_rx: Fuse>, + shutdown_signal: ShutdownSignal, + request_rx: mpsc::Receiver, msg_hash_dedup_cache: DedupCacheDatabase, } @@ -217,8 +236,8 @@ impl DhtActor { peer_manager, connectivity, node_identity, - shutdown_signal: Some(shutdown_signal), - request_rx: request_rx.fuse(), + shutdown_signal, + request_rx, } } @@ -247,33 +266,28 @@ impl DhtActor { let mut pending_jobs = FuturesUnordered::new(); - let mut dedup_cache_trim_ticker = time::interval(self.config.dedup_cache_trim_interval).fuse(); - - let mut shutdown_signal = self - .shutdown_signal - .take() - .expect("DhtActor initialized without shutdown_signal"); + let mut dedup_cache_trim_ticker = time::interval(self.config.dedup_cache_trim_interval); loop { - futures::select! { - request = self.request_rx.select_next_some() => { + tokio::select! { + Some(request) = self.request_rx.recv() => { trace!(target: LOG_TARGET, "DhtActor received request: {}", request); pending_jobs.push(self.request_handler(request)); }, - result = pending_jobs.select_next_some() => { + Some(result) = pending_jobs.next() => { if let Err(err) = result { debug!(target: LOG_TARGET, "Error when handling DHT request message. {}", err); } }, - _ = dedup_cache_trim_ticker.select_next_some() => { - if let Err(err) = self.msg_hash_dedup_cache.truncate().await { + _ = dedup_cache_trim_ticker.tick() => { + if let Err(err) = self.msg_hash_dedup_cache.trim_entries().await { error!(target: LOG_TARGET, "Error when trimming message dedup cache: {:?}", err); } }, - _ = shutdown_signal => { + _ = self.shutdown_signal.wait() => { info!(target: LOG_TARGET, "DhtActor is shutting down because it received a shutdown signal."); self.mark_shutdown_time().await; break Ok(()); @@ -300,24 +314,36 @@ impl DhtActor { let outbound_requester = self.outbound_requester.clone(); Box::pin(Self::broadcast_join(node_identity, outbound_requester)) }, - MsgHashCacheInsert(hash, public_key, reply_tx) => { + MsgHashCacheInsert { + message_hash, + received_from, + reply_tx, + } => { let msg_hash_cache = self.msg_hash_dedup_cache.clone(); Box::pin(async move { - match msg_hash_cache.insert_body_hash_if_unique(hash, public_key).await { - Ok(already_exists) => { - let _ = reply_tx.send(already_exists).map_err(|_| DhtActorError::ReplyCanceled); + match msg_hash_cache.add_body_hash(message_hash, received_from).await { + Ok(hit_count) => { + let _ = reply_tx.send(hit_count); }, Err(err) => { warn!( target: LOG_TARGET, "Unable to update message dedup cache because {:?}", err ); - let _ = reply_tx.send(false).map_err(|_| DhtActorError::ReplyCanceled); + let _ = reply_tx.send(0); }, } Ok(()) }) }, + GetMsgHashHitCount(hash, reply_tx) => { + let msg_hash_cache = self.msg_hash_dedup_cache.clone(); + Box::pin(async move { + let hit_count = msg_hash_cache.get_hit_count(hash).await?; + let _ = reply_tx.send(hit_count); + Ok(()) + }) + }, SelectPeers(broadcast_strategy, reply_tx) => { let peer_manager = Arc::clone(&self.peer_manager); let node_identity = Arc::clone(&self.node_identity); @@ -690,11 +716,12 @@ mod test { test_utils::{build_peer_manager, make_client_identity, make_node_identity}, }; use chrono::{DateTime, Utc}; - use std::time::Duration; - use tari_comms::test_utils::mocks::{create_connectivity_mock, create_peer_connection_mock_pair}; + use tari_comms::{ + runtime, + test_utils::mocks::{create_connectivity_mock, create_peer_connection_mock_pair}, + }; use tari_shutdown::Shutdown; use tari_test_utils::random; - use tokio::time::delay_for; async fn db_connection() -> DbConnection { let conn = DbConnection::connect_memory(random::string(8)).await.unwrap(); @@ -702,7 +729,7 @@ mod test { conn } - #[tokio_macros::test_basic] + #[runtime::test] async fn send_join_request() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -727,11 +754,11 @@ mod test { actor.spawn(); requester.send_join().await.unwrap(); - let (params, _) = unwrap_oms_send_msg!(out_rx.next().await.unwrap()); + let (params, _) = unwrap_oms_send_msg!(out_rx.recv().await.unwrap()); assert_eq!(params.dht_message_type, DhtMessageType::Join); } - #[tokio_macros::test_basic] + #[runtime::test] async fn insert_message_signature() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -756,24 +783,24 @@ mod test { actor.spawn(); let signature = vec![1u8, 2, 3]; - let is_dup = requester - .insert_message_hash(signature.clone(), CommsPublicKey::default()) + let num_hits = requester + .add_message_to_dedup_cache(signature.clone(), CommsPublicKey::default()) .await .unwrap(); - assert!(!is_dup); - let is_dup = requester - .insert_message_hash(signature, CommsPublicKey::default()) + assert_eq!(num_hits, 1); + let num_hits = requester + .add_message_to_dedup_cache(signature, CommsPublicKey::default()) .await .unwrap(); - assert!(is_dup); - let is_dup = requester - .insert_message_hash(Vec::new(), CommsPublicKey::default()) + assert_eq!(num_hits, 2); + let num_hits = requester + .add_message_to_dedup_cache(Vec::new(), CommsPublicKey::default()) .await .unwrap(); - assert!(!is_dup); + assert_eq!(num_hits, 1); } - #[tokio_macros::test_basic] + #[runtime::test] async fn dedup_cache_cleanup() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -783,14 +810,12 @@ mod test { let (actor_tx, actor_rx) = mpsc::channel(1); let mut requester = DhtRequester::new(actor_tx); let outbound_requester = OutboundMessageRequester::new(out_tx); - let mut shutdown = Shutdown::new(); - let trim_interval_ms = 500; + let shutdown = Shutdown::new(); // Note: This must be equal or larger than the minimum dedup cache capacity for DedupCacheDatabase - let capacity = 120; + let capacity = 10; let actor = DhtActor::new( DhtConfig { dedup_cache_capacity: capacity, - dedup_cache_trim_interval: Duration::from_millis(trim_interval_ms), ..Default::default() }, db_connection().await, @@ -803,66 +828,64 @@ mod test { ); // Create signatures for double the dedup cache capacity - let mut signatures: Vec> = Vec::new(); - for i in 0..(capacity * 2) { - signatures.push(vec![1u8, 2, i as u8]) - } + let signatures = (0..(capacity * 2)).map(|i| vec![1u8, 2, i as u8]).collect::>(); - // Pre-populate the dedup cache; everything should be accepted due to cleanup ticker not active yet + // Pre-populate the dedup cache; everything should be accepted because the cleanup ticker has not run yet for key in &signatures { - let is_dup = actor + let num_hits = actor .msg_hash_dedup_cache - .insert_body_hash_if_unique(key.clone(), CommsPublicKey::default()) + .add_body_hash(key.clone(), CommsPublicKey::default()) .await .unwrap(); - assert!(!is_dup); + assert_eq!(num_hits, 1); } - // Try to re-insert all; everything should be marked as duplicates due to cleanup ticker not active yet + // Try to re-insert all; all hashes should have incremented their hit count for key in &signatures { - let is_dup = actor + let num_hits = actor .msg_hash_dedup_cache - .insert_body_hash_if_unique(key.clone(), CommsPublicKey::default()) + .add_body_hash(key.clone(), CommsPublicKey::default()) .await .unwrap(); - assert!(is_dup); + assert_eq!(num_hits, 2); } - // The cleanup ticker starts when the actor is spawned; the first cleanup event will fire immediately + let dedup_cache_db = actor.msg_hash_dedup_cache.clone(); + // The cleanup ticker starts when the actor is spawned; the first cleanup event will fire fairly soon after the + // task is running on a thread. To remove this race condition, we trim the cache in the test. + dedup_cache_db.trim_entries().await.unwrap(); actor.spawn(); // Verify that the last half of the signatures are still present in the cache for key in signatures.iter().take(capacity * 2).skip(capacity) { - let is_dup = requester - .insert_message_hash(key.clone(), CommsPublicKey::default()) + let num_hits = requester + .add_message_to_dedup_cache(key.clone(), CommsPublicKey::default()) .await .unwrap(); - assert!(is_dup); + assert_eq!(num_hits, 3); } // Verify that the first half of the signatures have been removed and can be re-inserted into cache for key in signatures.iter().take(capacity) { - let is_dup = requester - .insert_message_hash(key.clone(), CommsPublicKey::default()) + let num_hits = requester + .add_message_to_dedup_cache(key.clone(), CommsPublicKey::default()) .await .unwrap(); - assert!(!is_dup); + assert_eq!(num_hits, 1); } - // Let the trim period expire; this will trim the dedup cache to capacity - delay_for(Duration::from_millis(trim_interval_ms * 2)).await; + // Trim the database of excess entries + dedup_cache_db.trim_entries().await.unwrap(); // Verify that the last half of the signatures have been removed and can be re-inserted into cache for key in signatures.iter().take(capacity * 2).skip(capacity) { - let is_dup = requester - .insert_message_hash(key.clone(), CommsPublicKey::default()) + let num_hits = requester + .add_message_to_dedup_cache(key.clone(), CommsPublicKey::default()) .await .unwrap(); - assert!(!is_dup); + assert_eq!(num_hits, 1); } - - shutdown.trigger().unwrap(); } - #[tokio_macros::test_basic] + #[runtime::test] async fn select_peers() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -973,7 +996,7 @@ mod test { assert_eq!(peers.len(), 1); } - #[tokio_macros::test_basic] + #[runtime::test] async fn get_and_set_metadata() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -1029,6 +1052,6 @@ mod test { .unwrap(); assert_eq!(got_ts, ts); - shutdown.trigger().unwrap(); + shutdown.trigger(); } } diff --git a/comms/dht/src/builder.rs b/comms/dht/src/builder.rs index 249ed3d369..bd7fb2521c 100644 --- a/comms/dht/src/builder.rs +++ b/comms/dht/src/builder.rs @@ -21,13 +21,13 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{dht::DhtInitializationError, outbound::DhtOutboundRequest, DbConnectionUrl, Dht, DhtConfig}; -use futures::channel::mpsc; use std::{sync::Arc, time::Duration}; use tari_comms::{ connectivity::ConnectivityRequester, peer_manager::{NodeIdentity, PeerManager}, }; use tari_shutdown::ShutdownSignal; +use tokio::sync::mpsc; pub struct DhtBuilder { node_identity: Arc, @@ -99,6 +99,11 @@ impl DhtBuilder { self } + pub fn with_dedup_discard_hit_count(mut self, max_hit_count: usize) -> Self { + self.config.dedup_allowed_message_occurrences = max_hit_count; + self + } + pub fn with_num_random_nodes(mut self, n: usize) -> Self { self.config.num_random_nodes = n; self diff --git a/comms/dht/src/config.rs b/comms/dht/src/config.rs index 0612445dca..a1b553ebb6 100644 --- a/comms/dht/src/config.rs +++ b/comms/dht/src/config.rs @@ -72,6 +72,10 @@ pub struct DhtConfig { /// The periodic trim interval for items in the message hash cache /// Default: 300s (5 mins) pub dedup_cache_trim_interval: Duration, + /// The number of occurrences of a message is allowed to pass through the DHT pipeline before being + /// deduped/discarded + /// Default: 1 + pub dedup_allowed_message_occurrences: usize, /// The duration to wait for a peer discovery to complete before giving up. /// Default: 2 minutes pub discovery_request_timeout: Duration, @@ -136,6 +140,7 @@ impl DhtConfig { impl Default for DhtConfig { fn default() -> Self { + // NB: please remember to update field comments to reflect these defaults Self { num_neighbouring_nodes: 8, num_random_nodes: 4, @@ -151,6 +156,7 @@ impl Default for DhtConfig { saf_max_message_size: 512 * 1024, dedup_cache_capacity: 2_500, dedup_cache_trim_interval: Duration::from_secs(5 * 60), + dedup_allowed_message_occurrences: 1, database_url: DbConnectionUrl::Memory, discovery_request_timeout: Duration::from_secs(2 * 60), connectivity_update_interval: Duration::from_secs(2 * 60), diff --git a/comms/dht/src/connectivity/metrics.rs b/comms/dht/src/connectivity/metrics.rs index b7ee546d4e..ba456d9aa9 100644 --- a/comms/dht/src/connectivity/metrics.rs +++ b/comms/dht/src/connectivity/metrics.rs @@ -20,20 +20,16 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::{ - channel::{mpsc, mpsc::SendError, oneshot, oneshot::Canceled}, - future, - SinkExt, - StreamExt, -}; use log::*; use std::{ collections::{hash_map::Entry, HashMap, VecDeque}, - future::Future, time::{Duration, Instant}, }; use tari_comms::peer_manager::NodeId; -use tokio::task; +use tokio::{ + sync::{mpsc, oneshot}, + task, +}; const LOG_TARGET: &str = "comms::dht::metrics"; @@ -124,7 +120,7 @@ impl MetricsState { } pub struct MetricsCollector { - stream: Option>, + stream: mpsc::Receiver, state: MetricsState, } @@ -133,18 +129,17 @@ impl MetricsCollector { let (metrics_tx, metrics_rx) = mpsc::channel(500); let metrics_collector = MetricsCollectorHandle::new(metrics_tx); let collector = Self { - stream: Some(metrics_rx), + stream: metrics_rx, state: Default::default(), }; task::spawn(collector.run()); metrics_collector } - fn run(mut self) -> impl Future { - self.stream.take().unwrap().for_each(move |op| { + async fn run(mut self) { + while let Some(op) = self.stream.recv().await { self.handle(op); - future::ready(()) - }) + } } fn handle(&mut self, op: MetricOp) { @@ -286,7 +281,7 @@ impl MetricsCollectorHandle { match self.inner.try_send(MetricOp::Write(write)) { Ok(_) => true, Err(err) => { - warn!(target: LOG_TARGET, "Failed to write metric: {}", err.into_send_error()); + warn!(target: LOG_TARGET, "Failed to write metric: {:?}", err); false }, } @@ -338,14 +333,14 @@ pub enum MetricsError { ReplyCancelled, } -impl From for MetricsError { - fn from(_: SendError) -> Self { +impl From> for MetricsError { + fn from(_: mpsc::error::SendError) -> Self { MetricsError::ChannelClosedUnexpectedly } } -impl From for MetricsError { - fn from(_: Canceled) -> Self { +impl From for MetricsError { + fn from(_: oneshot::error::RecvError) -> Self { MetricsError::ReplyCancelled } } diff --git a/comms/dht/src/connectivity/mod.rs b/comms/dht/src/connectivity/mod.rs index c9c855b6c1..a3c4471451 100644 --- a/comms/dht/src/connectivity/mod.rs +++ b/comms/dht/src/connectivity/mod.rs @@ -27,7 +27,6 @@ mod metrics; pub use metrics::{MetricsCollector, MetricsCollectorHandle}; use crate::{connectivity::metrics::MetricsError, event::DhtEvent, DhtActorError, DhtConfig, DhtRequester}; -use futures::{stream::Fuse, StreamExt}; use log::*; use std::{sync::Arc, time::Instant}; use tari_comms::{ @@ -78,11 +77,11 @@ pub struct DhtConnectivity { /// Used to track when the random peer pool was last refreshed random_pool_last_refresh: Option, stats: Stats, - dht_events: Fuse>>, + dht_events: broadcast::Receiver>, metrics_collector: MetricsCollectorHandle, - shutdown_signal: Option, + shutdown_signal: ShutdownSignal, } impl DhtConnectivity { @@ -108,8 +107,8 @@ impl DhtConnectivity { metrics_collector, random_pool_last_refresh: None, stats: Stats::new(), - dht_events: dht_events.fuse(), - shutdown_signal: Some(shutdown_signal), + dht_events, + shutdown_signal, } } @@ -131,21 +130,15 @@ impl DhtConnectivity { }) } - pub async fn run(mut self, connectivity_events: ConnectivityEventRx) -> Result<(), DhtConnectivityError> { - let mut connectivity_events = connectivity_events.fuse(); - let mut shutdown_signal = self - .shutdown_signal - .take() - .expect("DhtConnectivity initialized without a shutdown_signal"); - + pub async fn run(mut self, mut connectivity_events: ConnectivityEventRx) -> Result<(), DhtConnectivityError> { debug!(target: LOG_TARGET, "DHT connectivity starting"); self.refresh_neighbour_pool().await?; - let mut ticker = time::interval(self.config.connectivity_update_interval).fuse(); + let mut ticker = time::interval(self.config.connectivity_update_interval); loop { - futures::select! { - event = connectivity_events.select_next_some() => { + tokio::select! { + event = connectivity_events.recv() => { if let Ok(event) = event { if let Err(err) = self.handle_connectivity_event(&event).await { debug!(target: LOG_TARGET, "Error handling connectivity event: {:?}", err); @@ -153,15 +146,13 @@ impl DhtConnectivity { } }, - event = self.dht_events.select_next_some() => { - if let Ok(event) = event { - if let Err(err) = self.handle_dht_event(&event).await { - debug!(target: LOG_TARGET, "Error handling DHT event: {:?}", err); - } - } + Ok(event) = self.dht_events.recv() => { + if let Err(err) = self.handle_dht_event(&event).await { + debug!(target: LOG_TARGET, "Error handling DHT event: {:?}", err); + } }, - _ = ticker.next() => { + _ = ticker.tick() => { if let Err(err) = self.refresh_random_pool_if_required().await { debug!(target: LOG_TARGET, "Error refreshing random peer pool: {:?}", err); } @@ -170,7 +161,7 @@ impl DhtConnectivity { } }, - _ = shutdown_signal => { + _ = self.shutdown_signal.wait() => { info!(target: LOG_TARGET, "DhtConnectivity shutting down because the shutdown signal was received"); break; } diff --git a/comms/dht/src/connectivity/test.rs b/comms/dht/src/connectivity/test.rs index 65dba8c235..d0e83c0aa5 100644 --- a/comms/dht/src/connectivity/test.rs +++ b/comms/dht/src/connectivity/test.rs @@ -30,6 +30,7 @@ use std::{iter::repeat_with, sync::Arc, time::Duration}; use tari_comms::{ connectivity::ConnectivityEvent, peer_manager::{Peer, PeerFeatures}, + runtime, test_utils::{ count_string_occurrences, mocks::{create_connectivity_mock, create_dummy_peer_connection, ConnectivityManagerMockState}, @@ -89,7 +90,7 @@ async fn setup( ) } -#[tokio_macros::test_basic] +#[runtime::test] async fn initialize() { let config = DhtConfig { num_neighbouring_nodes: 4, @@ -127,7 +128,7 @@ async fn initialize() { assert!(managed.iter().all(|n| !neighbours.contains(n))); } -#[tokio_macros::test_basic] +#[runtime::test] async fn added_neighbours() { let node_identity = make_node_identity(); let mut node_identities = @@ -173,7 +174,7 @@ async fn added_neighbours() { assert!(managed.contains(closer_peer.node_id())); } -#[tokio_macros::test_basic] +#[runtime::test] #[allow(clippy::redundant_closure)] async fn reinitialize_pools_when_offline() { let node_identity = make_node_identity(); @@ -215,7 +216,7 @@ async fn reinitialize_pools_when_offline() { assert_eq!(managed.len(), 5); } -#[tokio_macros::test_basic] +#[runtime::test] async fn insert_neighbour() { let node_identity = make_node_identity(); let node_identities = @@ -254,11 +255,13 @@ async fn insert_neighbour() { } mod metrics { + use super::*; mod collector { + use super::*; use crate::connectivity::MetricsCollector; use tari_comms::peer_manager::NodeId; - #[tokio_macros::test_basic] + #[runtime::test] async fn it_adds_message_received() { let mut metric_collector = MetricsCollector::spawn(); let node_id = NodeId::default(); @@ -273,7 +276,7 @@ mod metrics { assert_eq!(ts.count(), 100); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_clears_the_metrics() { let mut metric_collector = MetricsCollector::spawn(); let node_id = NodeId::default(); diff --git a/comms/dht/src/dedup/dedup_cache.rs b/comms/dht/src/dedup/dedup_cache.rs index f8f5f6fcbf..8364f020a0 100644 --- a/comms/dht/src/dedup/dedup_cache.rs +++ b/comms/dht/src/dedup/dedup_cache.rs @@ -24,15 +24,23 @@ use crate::{ schema::dedup_cache, storage::{DbConnection, StorageError}, }; -use chrono::Utc; -use diesel::{dsl, result::DatabaseErrorKind, ExpressionMethods, QueryDsl, RunQueryDsl}; +use chrono::{NaiveDateTime, Utc}; +use diesel::{dsl, result::DatabaseErrorKind, ExpressionMethods, OptionalExtension, QueryDsl, RunQueryDsl}; use log::*; use tari_comms::types::CommsPublicKey; -use tari_crypto::tari_utilities::{hex::Hex, ByteArray}; -use tari_utilities::hex; +use tari_crypto::tari_utilities::hex::Hex; const LOG_TARGET: &str = "comms::dht::dedup_cache"; +#[derive(Queryable, PartialEq, Debug)] +struct DedupCacheEntry { + body_hash: String, + sender_public_ke: String, + number_of_hit: i32, + stored_at: NaiveDateTime, + last_hit_at: NaiveDateTime, +} + #[derive(Clone)] pub struct DedupCacheDatabase { connection: DbConnection, @@ -48,36 +56,40 @@ impl DedupCacheDatabase { Self { connection, capacity } } - /// Inserts and returns Ok(true) if the item already existed and Ok(false) if it didn't, also updating hit stats - pub async fn insert_body_hash_if_unique( - &self, - body_hash: Vec, - public_key: CommsPublicKey, - ) -> Result { - let body_hash = hex::to_hex(&body_hash.as_bytes()); - let public_key = public_key.to_hex(); - match self - .insert_body_hash_or_update_stats(body_hash.clone(), public_key.clone()) - .await - { - Ok(val) => { - if val == 0 { - warn!( - target: LOG_TARGET, - "Unable to insert new entry into message dedup cache" - ); - } - Ok(false) - }, - Err(e) => match e { - StorageError::UniqueViolation(_) => Ok(true), - _ => Err(e), - }, + /// Adds the body hash to the cache, returning the number of hits (inclusive) that have been recorded for this body + /// hash + pub async fn add_body_hash(&self, body_hash: Vec, public_key: CommsPublicKey) -> Result { + let hit_count = self + .insert_body_hash_or_update_stats(body_hash.to_hex(), public_key.to_hex()) + .await?; + + if hit_count == 0 { + warn!( + target: LOG_TARGET, + "Unable to insert new entry into message dedup cache" + ); } + Ok(hit_count) + } + + pub async fn get_hit_count(&self, body_hash: Vec) -> Result { + let hit_count = self + .connection + .with_connection_async(move |conn| { + dedup_cache::table + .select(dedup_cache::number_of_hits) + .filter(dedup_cache::body_hash.eq(&body_hash.to_hex())) + .get_result::(conn) + .optional() + .map_err(Into::into) + }) + .await?; + + Ok(hit_count.unwrap_or(0) as u32) } /// Trims the dedup cache to the configured limit by removing the oldest entries - pub async fn truncate(&self) -> Result { + pub async fn trim_entries(&self) -> Result { let capacity = self.capacity; self.connection .with_connection_async(move |conn| { @@ -109,40 +121,46 @@ impl DedupCacheDatabase { .await } - // Insert new row into the table or update existing row in an atomic fashion; more than one thread can access this - // table at the same time. + /// Insert new row into the table or updates an existing row. Returns the number of hits for this body hash. async fn insert_body_hash_or_update_stats( &self, body_hash: String, public_key: String, - ) -> Result { + ) -> Result { self.connection .with_connection_async(move |conn| { let insert_result = diesel::insert_into(dedup_cache::table) .values(( - dedup_cache::body_hash.eq(body_hash.clone()), - dedup_cache::sender_public_key.eq(public_key.clone()), + dedup_cache::body_hash.eq(&body_hash), + dedup_cache::sender_public_key.eq(&public_key), dedup_cache::number_of_hits.eq(1), dedup_cache::last_hit_at.eq(Utc::now().naive_utc()), )) .execute(conn); match insert_result { - Ok(val) => Ok(val), + Ok(1) => Ok(1), + Ok(n) => Err(StorageError::UnexpectedResult(format!( + "Expected exactly one row to be inserted. Got {}", + n + ))), Err(diesel::result::Error::DatabaseError(kind, e_info)) => match kind { DatabaseErrorKind::UniqueViolation => { // Update hit stats for the message - let result = - diesel::update(dedup_cache::table.filter(dedup_cache::body_hash.eq(&body_hash))) - .set(( - dedup_cache::sender_public_key.eq(public_key), - dedup_cache::number_of_hits.eq(dedup_cache::number_of_hits + 1), - dedup_cache::last_hit_at.eq(Utc::now().naive_utc()), - )) - .execute(conn); - match result { - Ok(_) => Err(StorageError::UniqueViolation(body_hash)), - Err(e) => Err(e.into()), - } + diesel::update(dedup_cache::table.filter(dedup_cache::body_hash.eq(&body_hash))) + .set(( + dedup_cache::sender_public_key.eq(&public_key), + dedup_cache::number_of_hits.eq(dedup_cache::number_of_hits + 1), + dedup_cache::last_hit_at.eq(Utc::now().naive_utc()), + )) + .execute(conn)?; + // TODO: Diesel support for RETURNING statements would remove this query, but is not + // available for Diesel + SQLite yet + let hits = dedup_cache::table + .select(dedup_cache::number_of_hits) + .filter(dedup_cache::body_hash.eq(&body_hash)) + .get_result::(conn)?; + + Ok(hits as u32) }, _ => Err(diesel::result::Error::DatabaseError(kind, e_info).into()), }, diff --git a/comms/dht/src/dedup/mod.rs b/comms/dht/src/dedup/mod.rs index 5428277af0..8bea19f39b 100644 --- a/comms/dht/src/dedup/mod.rs +++ b/comms/dht/src/dedup/mod.rs @@ -47,13 +47,15 @@ fn hash_inbound_message(message: &DhtInboundMessage) -> Vec { pub struct DedupMiddleware { next_service: S, dht_requester: DhtRequester, + allowed_message_occurrences: usize, } impl DedupMiddleware { - pub fn new(service: S, dht_requester: DhtRequester) -> Self { + pub fn new(service: S, dht_requester: DhtRequester, allowed_message_occurrences: usize) -> Self { Self { next_service: service, dht_requester, + allowed_message_occurrences, } } } @@ -71,9 +73,10 @@ where Poll::Ready(Ok(())) } - fn call(&mut self, message: DhtInboundMessage) -> Self::Future { + fn call(&mut self, mut message: DhtInboundMessage) -> Self::Future { let next_service = self.next_service.clone(); let mut dht_requester = self.dht_requester.clone(); + let allowed_message_occurrences = self.allowed_message_occurrences; Box::pin(async move { let hash = hash_inbound_message(&message); trace!( @@ -83,14 +86,17 @@ where message.tag, message.dht_header.message_tag ); - if dht_requester - .insert_message_hash(hash, message.source_peer.public_key.clone()) - .await? - { + + message.dedup_hit_count = dht_requester + .add_message_to_dedup_cache(hash, message.source_peer.public_key.clone()) + .await?; + + if message.dedup_hit_count as usize > allowed_message_occurrences { trace!( target: LOG_TARGET, - "Received duplicate message {} from peer '{}' (Trace: {}). Message discarded.", + "Received duplicate message {} (hit_count = {}) from peer '{}' (Trace: {}). Message discarded.", message.tag, + message.dedup_hit_count, message.source_peer.node_id.short_str(), message.dht_header.message_tag, ); @@ -99,8 +105,9 @@ where trace!( target: LOG_TARGET, - "Passing message {} onto next service (Trace: {})", + "Passing message {} (hit_count = {}) onto next service (Trace: {})", message.tag, + message.dedup_hit_count, message.dht_header.message_tag ); next_service.oneshot(message).await @@ -110,11 +117,15 @@ where pub struct DedupLayer { dht_requester: DhtRequester, + allowed_message_occurrences: usize, } impl DedupLayer { - pub fn new(dht_requester: DhtRequester) -> Self { - Self { dht_requester } + pub fn new(dht_requester: DhtRequester, allowed_message_occurrences: usize) -> Self { + Self { + dht_requester, + allowed_message_occurrences, + } } } @@ -122,7 +133,7 @@ impl Layer for DedupLayer { type Service = DedupMiddleware; fn layer(&self, service: S) -> Self::Service { - DedupMiddleware::new(service, self.dht_requester.clone()) + DedupMiddleware::new(service, self.dht_requester.clone(), self.allowed_message_occurrences) } } @@ -138,15 +149,15 @@ mod test { #[test] fn process_message() { - let mut rt = Runtime::new().unwrap(); + let rt = Runtime::new().unwrap(); let spy = service_spy(); let (dht_requester, mock) = create_dht_actor_mock(1); let mock_state = mock.get_shared_state(); - mock_state.set_signature_cache_insert(false); + mock_state.set_number_of_message_hits(1); rt.spawn(mock.run()); - let mut dedup = DedupLayer::new(dht_requester).layer(spy.to_service::()); + let mut dedup = DedupLayer::new(dht_requester, 3).layer(spy.to_service::()); panic_context!(cx); @@ -157,7 +168,7 @@ mod test { rt.block_on(dedup.call(msg.clone())).unwrap(); assert_eq!(spy.call_count(), 1); - mock_state.set_signature_cache_insert(true); + mock_state.set_number_of_message_hits(4); rt.block_on(dedup.call(msg)).unwrap(); assert_eq!(spy.call_count(), 1); // Drop dedup so that the DhtMock will stop running diff --git a/comms/dht/src/dht.rs b/comms/dht/src/dht.rs index dcdeea5730..9d29a70d79 100644 --- a/comms/dht/src/dht.rs +++ b/comms/dht/src/dht.rs @@ -26,6 +26,7 @@ use crate::{ connectivity::{DhtConnectivity, MetricsCollector, MetricsCollectorHandle}, discovery::{DhtDiscoveryRequest, DhtDiscoveryRequester, DhtDiscoveryService}, event::{DhtEventReceiver, DhtEventSender}, + filter, inbound, inbound::{DecryptedDhtMessage, DhtInboundMessage, MetricsLayer}, logging_middleware::MessageLoggingLayer, @@ -37,12 +38,11 @@ use crate::{ storage::{DbConnection, StorageError}, store_forward, store_forward::{StoreAndForwardError, StoreAndForwardRequest, StoreAndForwardRequester, StoreAndForwardService}, - tower_filter, DedupLayer, DhtActorError, DhtConfig, }; -use futures::{channel::mpsc, future, Future}; +use futures::Future; use log::*; use std::sync::Arc; use tari_comms::{ @@ -53,7 +53,7 @@ use tari_comms::{ }; use tari_shutdown::ShutdownSignal; use thiserror::Error; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc}; use tower::{layer::Layer, Service, ServiceBuilder}; const LOG_TARGET: &str = "comms::dht"; @@ -285,13 +285,14 @@ impl Dht { S: Service + Clone + Send + Sync + 'static, S::Future: Send, { - // FIXME: There is an unresolved stack overflow issue on windows in debug mode during runtime, but not in - // release mode, related to the amount of layers. (issue #1416) ServiceBuilder::new() .layer(MetricsLayer::new(self.metrics_collector.clone())) .layer(inbound::DeserializeLayer::new(self.peer_manager.clone())) - .layer(DedupLayer::new(self.dht_requester())) - .layer(tower_filter::FilterLayer::new(self.unsupported_saf_messages_filter())) + .layer(DedupLayer::new( + self.dht_requester(), + self.config.dedup_allowed_message_occurrences, + )) + .layer(filter::FilterLayer::new(self.unsupported_saf_messages_filter())) .layer(MessageLoggingLayer::new(format!( "Inbound [{}]", self.node_identity.node_id().short_str() @@ -301,6 +302,7 @@ impl Dht { self.node_identity.clone(), self.connectivity.clone(), )) + .layer(filter::FilterLayer::new(filter_messages_to_rebroadcast)) .layer(store_forward::StoreLayer::new( self.config.clone(), Arc::clone(&self.peer_manager), @@ -363,34 +365,60 @@ impl Dht { /// Produces a filter predicate which disallows store and forward messages if that feature is not /// supported by the node. - fn unsupported_saf_messages_filter( - &self, - ) -> impl tower_filter::Predicate>> + Clone + Send - { + fn unsupported_saf_messages_filter(&self) -> impl filter::Predicate + Clone + Send { let node_identity = Arc::clone(&self.node_identity); move |msg: &DhtInboundMessage| { if node_identity.has_peer_features(PeerFeatures::DHT_STORE_FORWARD) { - return future::ready(Ok(())); + return true; } match msg.dht_header.message_type { DhtMessageType::SafRequestMessages => { // TODO: #banheuristic This is an indication of node misbehaviour - debug!( + warn!( "Received store and forward message from PublicKey={}. Store and forward feature is not \ supported by this node. Discarding message.", msg.source_peer.public_key ); - future::ready(Err(anyhow::anyhow!( - "Message filtered out because store and forward is not supported by this node", - ))) + false }, - _ => future::ready(Ok(())), + _ => true, } } } } +/// Provides the gossip filtering rules for an inbound message +fn filter_messages_to_rebroadcast(msg: &DecryptedDhtMessage) -> bool { + // Let the message through if: + // it isn't a duplicate (normal message), or + let should_continue = !msg.is_duplicate() || + ( + // it is a duplicate domain message (i.e. not DHT or SAF protocol message), and + msg.dht_header.message_type.is_domain_message() && + // it has an unknown destination (e.g complete transactions, blocks, misc. encrypted + // messages) we allow it to proceed, which in turn, re-propagates it for another round. + msg.dht_header.destination.is_unknown() + ); + + if should_continue { + // The message has been forwarded, but downstream middleware may be interested + debug!( + target: LOG_TARGET, + "[filter_messages_to_rebroadcast] Passing message {} to next service (Trace: {})", + msg.tag, + msg.dht_header.message_tag + ); + true + } else { + debug!( + target: LOG_TARGET, + "[filter_messages_to_rebroadcast] Discarding duplicate message {}", msg + ); + false + } +} + #[cfg(test)] mod test { use crate::{ @@ -404,22 +432,23 @@ mod test { make_comms_inbound_message, make_dht_envelope, make_node_identity, + service_spy, }, DhtBuilder, }; - use futures::{channel::mpsc, StreamExt}; use std::{sync::Arc, time::Duration}; use tari_comms::{ message::{MessageExt, MessageTag}, pipeline::SinkService, + runtime, test_utils::mocks::create_connectivity_mock, wrap_in_envelope_body, }; use tari_shutdown::Shutdown; - use tokio::{task, time}; + use tokio::{sync::mpsc, task, time}; use tower::{layer::Layer, Service}; - #[tokio_macros::test_basic] + #[runtime::test] async fn stack_unencrypted() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -459,7 +488,7 @@ mod test { let msg = { service.call(inbound_message).await.unwrap(); - let msg = time::timeout(Duration::from_secs(10), out_rx.next()) + let msg = time::timeout(Duration::from_secs(10), out_rx.recv()) .await .unwrap() .unwrap(); @@ -469,7 +498,7 @@ mod test { assert_eq!(msg, b"secret"); } - #[tokio_macros::test_basic] + #[runtime::test] async fn stack_encrypted() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -509,7 +538,7 @@ mod test { let msg = { service.call(inbound_message).await.unwrap(); - let msg = time::timeout(Duration::from_secs(10), out_rx.next()) + let msg = time::timeout(Duration::from_secs(10), out_rx.recv()) .await .unwrap() .unwrap(); @@ -519,7 +548,7 @@ mod test { assert_eq!(msg, b"secret"); } - #[tokio_macros::test_basic] + #[runtime::test] async fn stack_forward() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); @@ -528,7 +557,6 @@ mod test { peer_manager.add_peer(node_identity.to_peer()).await.unwrap(); let (connectivity, _) = create_connectivity_mock(); - let (next_service_tx, mut next_service_rx) = mpsc::channel(10); let (oms_requester, oms_mock) = create_outbound_service_mock(1); // Send all outbound requests to the mock @@ -545,7 +573,8 @@ mod test { let oms_mock_state = oms_mock.get_state(); task::spawn(oms_mock.run()); - let mut service = dht.inbound_middleware_layer().layer(SinkService::new(next_service_tx)); + let spy = service_spy(); + let mut service = dht.inbound_middleware_layer().layer(spy.to_service()); let msg = wrap_in_envelope_body!(b"unencrypteable".to_vec()); @@ -574,10 +603,10 @@ mod test { assert_eq!(params.dht_header.unwrap().origin_mac, origin_mac); // Check the next service was not called - assert!(next_service_rx.try_next().is_err()); + assert_eq!(spy.call_count(), 0); } - #[tokio_macros::test_basic] + #[runtime::test] async fn stack_filter_saf_message() { let node_identity = make_client_identity(); let peer_manager = build_peer_manager(); @@ -600,9 +629,8 @@ mod test { .await .unwrap(); - let (next_service_tx, mut next_service_rx) = mpsc::channel(10); - - let mut service = dht.inbound_middleware_layer().layer(SinkService::new(next_service_tx)); + let spy = service_spy(); + let mut service = dht.inbound_middleware_layer().layer(spy.to_service()); let msg = wrap_in_envelope_body!(b"secret".to_vec()); let mut dht_envelope = make_dht_envelope( @@ -619,10 +647,6 @@ mod test { let inbound_message = make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().into()); service.call(inbound_message).await.unwrap_err(); - // This seems like the best way to tell that an open channel is empty without the test blocking indefinitely - assert_eq!( - format!("{}", next_service_rx.try_next().unwrap_err()), - "receiver channel is empty" - ); + assert_eq!(spy.call_count(), 0); } } diff --git a/comms/dht/src/discovery/error.rs b/comms/dht/src/discovery/error.rs index cf98e42d9d..c2a77f0c9a 100644 --- a/comms/dht/src/discovery/error.rs +++ b/comms/dht/src/discovery/error.rs @@ -21,9 +21,9 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::outbound::{message::SendFailure, DhtOutboundError}; -use futures::channel::mpsc::SendError; use tari_comms::peer_manager::PeerManagerError; use thiserror::Error; +use tokio::sync::mpsc::error::SendError; #[derive(Debug, Error)] pub enum DhtDiscoveryError { @@ -37,8 +37,6 @@ pub enum DhtDiscoveryError { InvalidNodeId, #[error("MPSC channel is disconnected")] ChannelDisconnected, - #[error("MPSC sender was unable to send because the channel buffer is full")] - SendBufferFull, #[error("The discovery request timed out")] DiscoveryTimeout, #[error("Failed to send discovery message: {0}")] @@ -56,14 +54,8 @@ impl DhtDiscoveryError { } } -impl From for DhtDiscoveryError { - fn from(err: SendError) -> Self { - if err.is_disconnected() { - DhtDiscoveryError::ChannelDisconnected - } else if err.is_full() { - DhtDiscoveryError::SendBufferFull - } else { - unreachable!(); - } +impl From> for DhtDiscoveryError { + fn from(_: SendError) -> Self { + DhtDiscoveryError::ChannelDisconnected } } diff --git a/comms/dht/src/discovery/requester.rs b/comms/dht/src/discovery/requester.rs index a286bcc6a5..a7317f79a2 100644 --- a/comms/dht/src/discovery/requester.rs +++ b/comms/dht/src/discovery/requester.rs @@ -21,16 +21,15 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{discovery::DhtDiscoveryError, envelope::NodeDestination, proto::dht::DiscoveryResponseMessage}; -use futures::{ - channel::{mpsc, oneshot}, - SinkExt, -}; use std::{ fmt::{Display, Error, Formatter}, time::Duration, }; use tari_comms::{peer_manager::Peer, types::CommsPublicKey}; -use tokio::time; +use tokio::{ + sync::{mpsc, oneshot}, + time, +}; #[derive(Debug)] pub enum DhtDiscoveryRequest { diff --git a/comms/dht/src/discovery/service.rs b/comms/dht/src/discovery/service.rs index 258eb67cea..2cebf571b4 100644 --- a/comms/dht/src/discovery/service.rs +++ b/comms/dht/src/discovery/service.rs @@ -27,11 +27,6 @@ use crate::{ proto::dht::{DiscoveryMessage, DiscoveryResponseMessage}, DhtConfig, }; -use futures::{ - channel::{mpsc, oneshot}, - future::FutureExt, - StreamExt, -}; use log::*; use rand::{rngs::OsRng, RngCore}; use std::{ @@ -47,7 +42,11 @@ use tari_comms::{ }; use tari_shutdown::ShutdownSignal; use tari_utilities::{hex::Hex, ByteArray}; -use tokio::{task, time}; +use tokio::{ + sync::{mpsc, oneshot}, + task, + time, +}; const LOG_TARGET: &str = "comms::dht::discovery_service"; @@ -72,8 +71,8 @@ pub struct DhtDiscoveryService { node_identity: Arc, outbound_requester: OutboundMessageRequester, peer_manager: Arc, - request_rx: Option>, - shutdown_signal: Option, + request_rx: mpsc::Receiver, + shutdown_signal: ShutdownSignal, inflight_discoveries: HashMap, } @@ -91,8 +90,8 @@ impl DhtDiscoveryService { outbound_requester, node_identity, peer_manager, - shutdown_signal: Some(shutdown_signal), - request_rx: Some(request_rx), + shutdown_signal, + request_rx, inflight_discoveries: HashMap::new(), } } @@ -106,29 +105,19 @@ impl DhtDiscoveryService { pub async fn run(mut self) { info!(target: LOG_TARGET, "Dht discovery service started"); - let mut shutdown_signal = self - .shutdown_signal - .take() - .expect("DiscoveryService initialized without shutdown_signal") - .fuse(); - - let mut request_rx = self - .request_rx - .take() - .expect("DiscoveryService initialized without request_rx") - .fuse(); - loop { - futures::select! { - request = request_rx.select_next_some() => { - trace!(target: LOG_TARGET, "Received request '{}'", request); - self.handle_request(request).await; - }, + tokio::select! { + biased; - _ = shutdown_signal => { + _ = self.shutdown_signal.wait() => { info!(target: LOG_TARGET, "Discovery service is shutting down because the shutdown signal was received"); break; } + + Some(request) = self.request_rx.recv() => { + trace!(target: LOG_TARGET, "Received request '{}'", request); + self.handle_request(request).await; + }, } } } @@ -153,7 +142,7 @@ impl DhtDiscoveryService { let mut remaining_requests = HashMap::new(); for (nonce, request) in self.inflight_discoveries.drain() { // Exclude canceled requests - if request.reply_tx.is_canceled() { + if request.reply_tx.is_closed() { continue; } @@ -199,7 +188,7 @@ impl DhtDiscoveryService { ); for request in self.collect_all_discovery_requests(&public_key) { - if !reply_tx.is_canceled() { + if !reply_tx.is_closed() { let _ = request.reply_tx.send(Ok(peer.clone())); } } @@ -299,7 +288,7 @@ impl DhtDiscoveryService { self.inflight_discoveries = self .inflight_discoveries .drain() - .filter(|(_, state)| !state.reply_tx.is_canceled()) + .filter(|(_, state)| !state.reply_tx.is_closed()) .collect(); trace!( @@ -393,9 +382,10 @@ mod test { test_utils::{build_peer_manager, make_node_identity}, }; use std::time::Duration; + use tari_comms::runtime; use tari_shutdown::Shutdown; - #[tokio_macros::test_basic] + #[runtime::test] async fn send_discovery() { let node_identity = make_node_identity(); let peer_manager = build_peer_manager(); diff --git a/comms/dht/src/domain_message.rs b/comms/dht/src/domain_message.rs index 2fe7af16fe..f565882725 100644 --- a/comms/dht/src/domain_message.rs +++ b/comms/dht/src/domain_message.rs @@ -33,7 +33,7 @@ impl ToProtoEnum for i32 { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct OutboundDomainMessage { inner: T, message_type: i32, diff --git a/comms/dht/src/envelope.rs b/comms/dht/src/envelope.rs index 0b93546dbb..c8b18d9e9b 100644 --- a/comms/dht/src/envelope.rs +++ b/comms/dht/src/envelope.rs @@ -22,6 +22,8 @@ use bitflags::bitflags; use bytes::Bytes; +use chrono::{DateTime, NaiveDateTime, Utc}; +use prost_types::Timestamp; use serde::{Deserialize, Serialize}; use std::{ cmp, @@ -30,14 +32,11 @@ use std::{ fmt::Display, }; use tari_comms::{message::MessageTag, peer_manager::NodeId, types::CommsPublicKey, NodeIdentity}; -use tari_utilities::{ByteArray, ByteArrayError}; +use tari_utilities::{epoch_time::EpochTime, ByteArray, ByteArrayError}; use thiserror::Error; // Re-export applicable protos pub use crate::proto::envelope::{dht_header::Destination, DhtEnvelope, DhtHeader, DhtMessageType}; -use chrono::{DateTime, NaiveDateTime, Utc}; -use prost_types::Timestamp; -use tari_utilities::epoch_time::EpochTime; /// Utility function that converts a `chrono::DateTime` to a `prost::Timestamp` pub(crate) fn datetime_to_timestamp(datetime: DateTime) -> Timestamp { @@ -106,8 +105,12 @@ impl DhtMessageFlags { } impl DhtMessageType { + pub fn is_domain_message(self) -> bool { + matches!(self, DhtMessageType::None) + } + pub fn is_dht_message(self) -> bool { - self.is_dht_discovery() || self.is_dht_join() + self.is_dht_discovery() || matches!(self, DhtMessageType::DiscoveryResponse) || self.is_dht_join() } pub fn is_dht_discovery(self) -> bool { diff --git a/comms/dht/src/tower_filter/future.rs b/comms/dht/src/filter/future.rs similarity index 66% rename from comms/dht/src/tower_filter/future.rs rename to comms/dht/src/filter/future.rs index 78b2c613e6..4559aeaadf 100644 --- a/comms/dht/src/tower_filter/future.rs +++ b/comms/dht/src/filter/future.rs @@ -13,16 +13,15 @@ use tower::Service; /// Filtered response future #[pin_project] #[derive(Debug)] -pub struct ResponseFuture +pub struct ResponseFuture where S: Service { #[pin] /// Response future state state: State, - #[pin] - /// Predicate future - check: T, + /// Predicate result + check: bool, /// Inner service service: S, @@ -35,12 +34,10 @@ enum State { WaitResponse(#[pin] U), } -impl ResponseFuture -where - F: Future>, - S: Service, +impl ResponseFuture +where S: Service { - pub(crate) fn new(request: Request, check: F, service: S) -> Self { + pub(crate) fn new(request: Request, check: bool, service: S) -> Self { ResponseFuture { state: State::Check(Some(request)), check, @@ -49,10 +46,8 @@ where } } -impl Future for ResponseFuture -where - F: Future>, - S: Service, +impl Future for ResponseFuture +where S: Service { type Output = Result; @@ -66,15 +61,13 @@ where .take() .expect("we either give it back or leave State::Check once we take"); - // Poll predicate - match this.check.as_mut().poll(cx)? { - Poll::Ready(_) => { + match this.check { + true => { let response = this.service.call(request); this.state.set(State::WaitResponse(response)); }, - Poll::Pending => { - this.state.set(State::Check(Some(request))); - return Poll::Pending; + false => { + return Poll::Ready(Ok(())); }, } }, diff --git a/comms/dht/src/tower_filter/layer.rs b/comms/dht/src/filter/layer.rs similarity index 100% rename from comms/dht/src/tower_filter/layer.rs rename to comms/dht/src/filter/layer.rs diff --git a/comms/dht/src/tower_filter/mod.rs b/comms/dht/src/filter/mod.rs similarity index 92% rename from comms/dht/src/tower_filter/mod.rs rename to comms/dht/src/filter/mod.rs index d1df2f27a7..e7f168161b 100644 --- a/comms/dht/src/tower_filter/mod.rs +++ b/comms/dht/src/filter/mod.rs @@ -33,11 +33,11 @@ impl Filter { impl Service for Filter where - T: Service + Clone, + T: Service + Clone, U: Predicate, { type Error = PipelineError; - type Future = ResponseFuture; + type Future = ResponseFuture; type Response = T::Response; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { diff --git a/comms/dht/src/filter/predicate.rs b/comms/dht/src/filter/predicate.rs new file mode 100644 index 0000000000..024dee826d --- /dev/null +++ b/comms/dht/src/filter/predicate.rs @@ -0,0 +1,13 @@ +/// Checks a request +pub trait Predicate { + /// Check whether the given request should be forwarded. + fn check(&mut self, request: &Request) -> bool; +} + +impl Predicate for F +where F: Fn(&T) -> bool +{ + fn check(&mut self, request: &T) -> bool { + self(request) + } +} diff --git a/comms/dht/src/inbound/decryption.rs b/comms/dht/src/inbound/decryption.rs index 9e3a6bbd3e..46fc3daa47 100644 --- a/comms/dht/src/inbound/decryption.rs +++ b/comms/dht/src/inbound/decryption.rs @@ -397,7 +397,12 @@ mod test { }; use futures::{executor::block_on, future}; use std::sync::Mutex; - use tari_comms::{message::MessageExt, test_utils::mocks::create_connectivity_mock, wrap_in_envelope_body}; + use tari_comms::{ + message::MessageExt, + runtime, + test_utils::mocks::create_connectivity_mock, + wrap_in_envelope_body, + }; use tari_test_utils::{counter_context, unpack_enum}; use tower::service_fn; @@ -469,7 +474,7 @@ mod test { assert_eq!(decrypted.decryption_result.unwrap_err(), inbound_msg.body); } - #[tokio_macros::test_basic] + #[runtime::test] async fn decrypt_inbound_fail_destination() { let (connectivity, mock) = create_connectivity_mock(); mock.spawn(); diff --git a/comms/dht/src/inbound/deserialize.rs b/comms/dht/src/inbound/deserialize.rs index b28a057cb5..a73c3b3cfb 100644 --- a/comms/dht/src/inbound/deserialize.rs +++ b/comms/dht/src/inbound/deserialize.rs @@ -137,9 +137,12 @@ mod test { service_spy, }, }; - use tari_comms::message::{MessageExt, MessageTag}; + use tari_comms::{ + message::{MessageExt, MessageTag}, + runtime, + }; - #[tokio_macros::test_basic] + #[runtime::test] async fn deserialize() { let spy = service_spy(); let peer_manager = build_peer_manager(); diff --git a/comms/dht/src/inbound/dht_handler/task.rs b/comms/dht/src/inbound/dht_handler/task.rs index f45507a905..ec42bbd4fe 100644 --- a/comms/dht/src/inbound/dht_handler/task.rs +++ b/comms/dht/src/inbound/dht_handler/task.rs @@ -88,6 +88,20 @@ where S: Service return Ok(()); } + if message.is_duplicate() { + debug!( + target: LOG_TARGET, + "Received message ({}) that has already been received {} time(s). Last sent by peer '{}', passing on \ + to next service (Trace: {})", + message.tag, + message.dedup_hit_count, + message.source_peer.node_id.short_str(), + message.dht_header.message_tag, + ); + self.next_service.oneshot(message).await?; + return Ok(()); + } + trace!( target: LOG_TARGET, "Received DHT message type `{}` (Source peer: {}, Tag: {}, Trace: {})", diff --git a/comms/dht/src/inbound/message.rs b/comms/dht/src/inbound/message.rs index a49ae4b073..c9cdd103fd 100644 --- a/comms/dht/src/inbound/message.rs +++ b/comms/dht/src/inbound/message.rs @@ -43,6 +43,7 @@ pub struct DhtInboundMessage { pub dht_header: DhtMessageHeader, /// True if forwarded via store and forward, otherwise false pub is_saf_message: bool, + pub dedup_hit_count: u32, pub body: Vec, } impl DhtInboundMessage { @@ -53,6 +54,7 @@ impl DhtInboundMessage { dht_header, source_peer, is_saf_message: false, + dedup_hit_count: 0, body, } } @@ -62,11 +64,12 @@ impl Display for DhtInboundMessage { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> { write!( f, - "\n---- Inbound Message ---- \nSize: {} byte(s)\nType: {}\nPeer: {}\nHeader: {}\n{}\n----", + "\n---- Inbound Message ---- \nSize: {} byte(s)\nType: {}\nPeer: {}\nHit Count: {}\nHeader: {}\n{}\n----", self.body.len(), self.dht_header.message_type, self.source_peer, self.dht_header, + self.dedup_hit_count, self.tag, ) } @@ -86,6 +89,14 @@ pub struct DecryptedDhtMessage { pub is_saf_stored: Option, pub is_already_forwarded: bool, pub decryption_result: Result>, + pub dedup_hit_count: u32, +} + +impl DecryptedDhtMessage { + /// Returns true if this message has been received before, otherwise false if this is the first time + pub fn is_duplicate(&self) -> bool { + self.dedup_hit_count > 1 + } } impl DecryptedDhtMessage { @@ -104,6 +115,7 @@ impl DecryptedDhtMessage { is_saf_stored: None, is_already_forwarded: false, decryption_result: Ok(message_body), + dedup_hit_count: message.dedup_hit_count, } } @@ -118,6 +130,7 @@ impl DecryptedDhtMessage { is_saf_stored: None, is_already_forwarded: false, decryption_result: Err(message.body), + dedup_hit_count: message.dedup_hit_count, } } diff --git a/comms/dht/src/lib.rs b/comms/dht/src/lib.rs index cab2f8ab6f..710b354f7b 100644 --- a/comms/dht/src/lib.rs +++ b/comms/dht/src/lib.rs @@ -71,7 +71,7 @@ //! #use std::sync::Arc; //! #use tari_comms::CommsBuilder; //! #use tokio::runtime::Runtime; -//! #use futures::channel::mpsc; +//! #use tokio::sync::mpsc; //! //! let runtime = Runtime::new().unwrap(); //! // Channel from comms to inbound dht @@ -153,11 +153,11 @@ pub use storage::DbConnectionUrl; mod dedup; pub use dedup::DedupLayer; +mod filter; mod logging_middleware; mod proto; mod rpc; mod schema; -mod tower_filter; pub mod broadcast_strategy; pub mod domain_message; diff --git a/comms/dht/src/network_discovery/on_connect.rs b/comms/dht/src/network_discovery/on_connect.rs index b93657f061..cd162b3903 100644 --- a/comms/dht/src/network_discovery/on_connect.rs +++ b/comms/dht/src/network_discovery/on_connect.rs @@ -33,7 +33,7 @@ use crate::{ }; use futures::StreamExt; use log::*; -use std::{convert::TryInto, ops::Deref}; +use std::convert::TryInto; use tari_comms::{ connectivity::ConnectivityEvent, peer_manager::{NodeId, Peer}, @@ -62,8 +62,9 @@ impl OnConnect { pub async fn next_event(&mut self) -> StateEvent { let mut connectivity_events = self.context.connectivity.get_event_subscription(); - while let Some(event) = connectivity_events.next().await { - match event.as_ref().map(|e| e.deref()) { + loop { + let event = connectivity_events.recv().await; + match event { Ok(ConnectivityEvent::PeerConnected(conn)) => { if conn.peer_features().is_client() { continue; @@ -96,10 +97,10 @@ impl OnConnect { self.prev_synced.push(conn.peer_node_id().clone()); }, Ok(_) => { /* Nothing to do */ }, - Err(broadcast::RecvError::Lagged(n)) => { + Err(broadcast::error::RecvError::Lagged(n)) => { warn!(target: LOG_TARGET, "Lagged behind on {} connectivity event(s)", n) }, - Err(broadcast::RecvError::Closed) => { + Err(broadcast::error::RecvError::Closed) => { break; }, } diff --git a/comms/dht/src/network_discovery/test.rs b/comms/dht/src/network_discovery/test.rs index 54f596ee26..2f854627f1 100644 --- a/comms/dht/src/network_discovery/test.rs +++ b/comms/dht/src/network_discovery/test.rs @@ -28,12 +28,12 @@ use crate::{ test_utils::{build_peer_manager, make_node_identity}, DhtConfig, }; -use futures::StreamExt; use std::{iter, sync::Arc, time::Duration}; use tari_comms::{ connectivity::ConnectivityStatus, peer_manager::{Peer, PeerFeatures}, protocol::rpc::{mock::MockRpcServer, NamedProtocolService}, + runtime, test_utils::{ mocks::{create_connectivity_mock, ConnectivityManagerMockState}, node_identity::build_node_identity, @@ -97,7 +97,7 @@ mod state_machine { ) } - #[tokio_macros::test_basic] + #[runtime::test] #[allow(clippy::redundant_closure)] async fn it_fetches_peers() { const NUM_PEERS: usize = 3; @@ -139,7 +139,7 @@ mod state_machine { mock.get_peers.set_response(Ok(peers)).await; discovery_actor.spawn(); - let event = event_rx.next().await.unwrap().unwrap(); + let event = event_rx.recv().await.unwrap(); unpack_enum!(DhtEvent::NetworkDiscoveryPeersAdded(info) = &*event); assert!(info.has_new_neighbours()); assert_eq!(info.num_new_neighbours, NUM_PEERS); @@ -149,11 +149,11 @@ mod state_machine { assert_eq!(info.sync_peers, vec![peer_node_identity.node_id().clone()]); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_shuts_down() { let (discovery, _, _, _, _, mut shutdown) = setup(Default::default(), make_node_identity(), vec![]).await; - shutdown.trigger().unwrap(); + shutdown.trigger(); tokio::time::timeout(Duration::from_secs(5), discovery.run()) .await .unwrap(); @@ -200,7 +200,7 @@ mod discovery_ready { (node_identity, peer_manager, connectivity_mock, ready, context) } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_begins_aggressive_discovery() { let (_, pm, _, mut ready, _) = setup(Default::default()); let peers = build_many_node_identities(1, PeerFeatures::COMMUNICATION_NODE); @@ -212,14 +212,14 @@ mod discovery_ready { assert!(params.num_peers_to_request.is_none()); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_idles_if_no_sync_peers() { let (_, _, _, mut ready, _) = setup(Default::default()); let state_event = ready.next_event().await; unpack_enum!(StateEvent::Idle = state_event); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_idles_if_num_rounds_reached() { let config = NetworkDiscoveryConfig { min_desired_peers: 0, @@ -240,7 +240,7 @@ mod discovery_ready { unpack_enum!(StateEvent::Idle = state_event); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_transitions_to_on_connect() { let config = NetworkDiscoveryConfig { min_desired_peers: 0, diff --git a/comms/dht/src/network_discovery/waiting.rs b/comms/dht/src/network_discovery/waiting.rs index 73e8929ac5..f61dfc6b24 100644 --- a/comms/dht/src/network_discovery/waiting.rs +++ b/comms/dht/src/network_discovery/waiting.rs @@ -46,7 +46,7 @@ impl Waiting { target: LOG_TARGET, "Network discovery is IDLING for {:.0?}", self.duration ); - time::delay_for(self.duration).await; + time::sleep(self.duration).await; debug!(target: LOG_TARGET, "Network discovery resuming"); StateEvent::Ready } diff --git a/comms/dht/src/outbound/broadcast.rs b/comms/dht/src/outbound/broadcast.rs index 0aa9fab611..57c423df55 100644 --- a/comms/dht/src/outbound/broadcast.rs +++ b/comms/dht/src/outbound/broadcast.rs @@ -39,7 +39,6 @@ use bytes::Bytes; use chrono::{DateTime, Utc}; use digest::Digest; use futures::{ - channel::oneshot, future, future::BoxFuture, stream::{self, StreamExt}, @@ -60,6 +59,7 @@ use tari_crypto::{ tari_utilities::{message_format::MessageFormat, ByteArray}, }; use tari_utilities::hex::Hex; +use tokio::sync::oneshot; use tower::{layer::Layer, Service, ServiceExt}; const LOG_TARGET: &str = "comms::dht::outbound::broadcast_middleware"; @@ -251,11 +251,12 @@ where S: Service is_discovery_enabled, force_origin, dht_header, + tag, } = params; match self.select_peers(broadcast_strategy.clone()).await { Ok(mut peers) => { - if reply_tx.is_canceled() { + if reply_tx.is_closed() { return Err(DhtOutboundError::ReplyChannelCanceled); } @@ -320,6 +321,7 @@ where S: Service is_broadcast, body, Some(expires), + tag, ) .await { @@ -411,6 +413,7 @@ where S: Service is_broadcast: bool, body: Bytes, expires: Option>, + tag: Option, ) -> Result<(Vec, Vec), DhtOutboundError> { let dht_flags = encryption.flags() | extra_flags; @@ -424,7 +427,7 @@ where S: Service // Construct a DhtOutboundMessage for each recipient let messages = selected_peers.into_iter().map(|node_id| { let (reply_tx, reply_rx) = oneshot::channel(); - let tag = MessageTag::new(); + let tag = tag.unwrap_or_else(MessageTag::new); let send_state = MessageSendState::new(tag, reply_rx); ( DhtOutboundMessage { @@ -448,7 +451,7 @@ where S: Service Ok(messages.unzip()) } - async fn add_to_dedup_cache(&mut self, body: &[u8], public_key: CommsPublicKey) -> Result { + async fn add_to_dedup_cache(&mut self, body: &[u8], public_key: CommsPublicKey) -> Result<(), DhtOutboundError> { let hash = Challenge::new().chain(&body).finalize().to_vec(); trace!( target: LOG_TARGET, @@ -456,10 +459,19 @@ where S: Service hash.to_hex(), ); - self.dht_requester - .insert_message_hash(hash, public_key) + // Do not count messages we've broadcast towards the total hit count + let hit_count = self + .dht_requester + .get_message_cache_hit_count(hash.clone()) .await - .map_err(|_| DhtOutboundError::FailedToInsertMessageHash) + .map_err(|err| DhtOutboundError::FailedToInsertMessageHash(err.to_string()))?; + if hit_count == 0 { + self.dht_requester + .add_message_to_dedup_cache(hash, public_key) + .await + .map_err(|err| DhtOutboundError::FailedToInsertMessageHash(err.to_string()))?; + } + Ok(()) } fn process_encryption( @@ -525,19 +537,19 @@ mod test { DhtDiscoveryMockState, }, }; - use futures::channel::oneshot; use rand::rngs::OsRng; use std::time::Duration; use tari_comms::{ multiaddr::Multiaddr, peer_manager::{NodeId, Peer, PeerFeatures, PeerFlags}, + runtime, types::CommsPublicKey, }; use tari_crypto::keys::PublicKey; use tari_test_utils::unpack_enum; - use tokio::task; + use tokio::{sync::oneshot, task}; - #[tokio_macros::test_basic] + #[runtime::test] async fn send_message_flood() { let pk = CommsPublicKey::default(); let example_peer = Peer::new( @@ -601,7 +613,7 @@ mod test { assert!(requests.iter().any(|msg| msg.destination_node_id == other_peer.node_id)); } - #[tokio_macros::test_basic] + #[runtime::test] async fn send_message_direct_not_found() { // Test for issue https://github.com/tari-project/tari/issues/959 @@ -645,7 +657,7 @@ mod test { assert_eq!(spy.call_count(), 0); } - #[tokio_macros::test_basic] + #[runtime::test] async fn send_message_direct_dht_discovery() { let node_identity = NodeIdentity::random( &mut OsRng, diff --git a/comms/dht/src/outbound/error.rs b/comms/dht/src/outbound/error.rs index 3f93dab043..b919e45134 100644 --- a/comms/dht/src/outbound/error.rs +++ b/comms/dht/src/outbound/error.rs @@ -20,16 +20,16 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::outbound::message::SendFailure; -use futures::channel::mpsc::SendError; +use crate::outbound::{message::SendFailure, DhtOutboundRequest}; use tari_comms::message::MessageError; use tari_crypto::{signatures::SchnorrSignatureError, tari_utilities::message_format::MessageFormatError}; use thiserror::Error; +use tokio::sync::mpsc::error::SendError; #[derive(Debug, Error)] pub enum DhtOutboundError { - #[error("SendError: {0}")] - SendError(#[from] SendError), + #[error("`Failed to send: {0}")] + SendError(#[from] SendError), #[error("MessageSerializationError: {0}")] MessageSerializationError(#[from] MessageError), #[error("MessageFormatError: {0}")] @@ -48,8 +48,8 @@ pub enum DhtOutboundError { SendToOurselves, #[error("Discovery process failed")] DiscoveryFailed, - #[error("Failed to insert message hash")] - FailedToInsertMessageHash, + #[error("Failed to insert message hash: {0}")] + FailedToInsertMessageHash(String), #[error("Failed to send message: {0}")] SendMessageFailed(SendFailure), #[error("No messages were queued for sending")] diff --git a/comms/dht/src/outbound/message.rs b/comms/dht/src/outbound/message.rs index 52356ec364..bb782dc2e5 100644 --- a/comms/dht/src/outbound/message.rs +++ b/comms/dht/src/outbound/message.rs @@ -25,7 +25,6 @@ use crate::{ outbound::{message_params::FinalSendMessageParams, message_send_state::MessageSendStates}, }; use bytes::Bytes; -use futures::channel::oneshot; use std::{fmt, fmt::Display, sync::Arc}; use tari_comms::{ message::{MessageTag, MessagingReplyTx}, @@ -34,6 +33,7 @@ use tari_comms::{ }; use tari_utilities::hex::Hex; use thiserror::Error; +use tokio::sync::oneshot; /// Determines if an outbound message should be Encrypted and, if so, for which public key #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/comms/dht/src/outbound/message_params.rs b/comms/dht/src/outbound/message_params.rs index 0ad00bbc4e..3b38272c38 100644 --- a/comms/dht/src/outbound/message_params.rs +++ b/comms/dht/src/outbound/message_params.rs @@ -27,7 +27,7 @@ use crate::{ proto::envelope::DhtMessageType, }; use std::{fmt, fmt::Display}; -use tari_comms::{peer_manager::NodeId, types::CommsPublicKey}; +use tari_comms::{message::MessageTag, peer_manager::NodeId, types::CommsPublicKey}; /// Configuration for outbound messages. /// @@ -66,6 +66,7 @@ pub struct FinalSendMessageParams { pub dht_message_type: DhtMessageType, pub dht_message_flags: DhtMessageFlags, pub dht_header: Option, + pub tag: Option, } impl Default for FinalSendMessageParams { @@ -79,6 +80,7 @@ impl Default for FinalSendMessageParams { force_origin: false, is_discovery_enabled: false, dht_header: None, + tag: None, } } } @@ -171,6 +173,12 @@ impl SendMessageParams { self } + /// Set the message trace tag + pub fn with_tag(&mut self, tag: MessageTag) -> &mut Self { + self.params_mut().tag = Some(tag); + self + } + /// Set destination field in message header. pub fn with_destination(&mut self, destination: NodeDestination) -> &mut Self { self.params_mut().destination = destination; diff --git a/comms/dht/src/outbound/message_send_state.rs b/comms/dht/src/outbound/message_send_state.rs index 1576e87c70..b3ca43fcb2 100644 --- a/comms/dht/src/outbound/message_send_state.rs +++ b/comms/dht/src/outbound/message_send_state.rs @@ -250,9 +250,9 @@ impl Index for MessageSendStates { #[cfg(test)] mod test { use super::*; - use futures::channel::oneshot; use std::iter::repeat_with; - use tari_comms::message::MessagingReplyTx; + use tari_comms::{message::MessagingReplyTx, runtime}; + use tokio::sync::oneshot; fn create_send_state() -> (MessageSendState, MessagingReplyTx) { let (reply_tx, reply_rx) = oneshot::channel(); @@ -269,7 +269,7 @@ mod test { assert!(!states.is_empty()); } - #[tokio_macros::test_basic] + #[runtime::test] async fn wait_single() { let (state, mut reply_tx) = create_send_state(); let states = MessageSendStates::from(vec![state]); @@ -284,7 +284,7 @@ mod test { assert!(!states.wait_single().await); } - #[tokio_macros::test_basic] + #[runtime::test] #[allow(clippy::redundant_closure)] async fn wait_percentage_success() { let states = repeat_with(|| create_send_state()).take(10).collect::>(); @@ -300,7 +300,7 @@ mod test { assert_eq!(failed.len(), 4); } - #[tokio_macros::test_basic] + #[runtime::test] #[allow(clippy::redundant_closure)] async fn wait_n_timeout() { let states = repeat_with(|| create_send_state()).take(10).collect::>(); @@ -329,7 +329,7 @@ mod test { assert_eq!(failed.len(), 6); } - #[tokio_macros::test_basic] + #[runtime::test] #[allow(clippy::redundant_closure)] async fn wait_all() { let states = repeat_with(|| create_send_state()).take(10).collect::>(); diff --git a/comms/dht/src/outbound/mock.rs b/comms/dht/src/outbound/mock.rs index f5c3f30665..66e2b7258e 100644 --- a/comms/dht/src/outbound/mock.rs +++ b/comms/dht/src/outbound/mock.rs @@ -31,11 +31,6 @@ use crate::{ }, }; use bytes::Bytes; -use futures::{ - channel::{mpsc, oneshot}, - stream::Fuse, - StreamExt, -}; use log::*; use std::{ sync::{Arc, Condvar, Mutex, RwLock}, @@ -45,7 +40,10 @@ use tari_comms::{ message::{MessageTag, MessagingReplyTx}, protocol::messaging::SendFailReason, }; -use tokio::time::delay_for; +use tokio::{ + sync::{mpsc, oneshot}, + time::sleep, +}; const LOG_TARGET: &str = "mock::outbound_requester"; @@ -54,7 +52,7 @@ const LOG_TARGET: &str = "mock::outbound_requester"; /// Each time a request is expected, handle_next should be called. pub fn create_outbound_service_mock(size: usize) -> (OutboundMessageRequester, OutboundServiceMock) { let (tx, rx) = mpsc::channel(size); - (OutboundMessageRequester::new(tx), OutboundServiceMock::new(rx.fuse())) + (OutboundMessageRequester::new(tx), OutboundServiceMock::new(rx)) } #[derive(Clone, Default)] @@ -149,12 +147,12 @@ impl OutboundServiceMockState { } pub struct OutboundServiceMock { - receiver: Fuse>, + receiver: mpsc::Receiver, mock_state: OutboundServiceMockState, } impl OutboundServiceMock { - pub fn new(receiver: Fuse>) -> Self { + pub fn new(receiver: mpsc::Receiver) -> Self { Self { receiver, mock_state: OutboundServiceMockState::new(), @@ -166,7 +164,7 @@ impl OutboundServiceMock { } pub async fn run(mut self) { - while let Some(req) = self.receiver.next().await { + while let Some(req) = self.receiver.recv().await { match req { DhtOutboundRequest::SendMessage(params, body, reply_tx) => { let behaviour = self.mock_state.get_behaviour(); @@ -192,7 +190,7 @@ impl OutboundServiceMock { ResponseType::QueuedSuccessDelay(delay) => { let (response, mut inner_reply_tx) = self.add_call((*params).clone(), body); reply_tx.send(response).expect("Reply channel cancelled"); - delay_for(delay).await; + sleep(delay).await; inner_reply_tx.reply_success(); }, resp => { diff --git a/comms/dht/src/outbound/requester.rs b/comms/dht/src/outbound/requester.rs index c1957a98b0..f8536c9d9c 100644 --- a/comms/dht/src/outbound/requester.rs +++ b/comms/dht/src/outbound/requester.rs @@ -32,12 +32,9 @@ use crate::{ MessageSendStates, }, }; -use futures::{ - channel::{mpsc, oneshot}, - SinkExt, -}; use log::*; use tari_comms::{message::MessageExt, peer_manager::NodeId, types::CommsPublicKey, wrap_in_envelope_body}; +use tokio::sync::{mpsc, oneshot}; const LOG_TARGET: &str = "comms::dht::requests::outbound"; diff --git a/comms/dht/src/outbound/serialize.rs b/comms/dht/src/outbound/serialize.rs index 97c7c0df58..195d2d3d39 100644 --- a/comms/dht/src/outbound/serialize.rs +++ b/comms/dht/src/outbound/serialize.rs @@ -137,9 +137,9 @@ mod test { use super::*; use crate::test_utils::{assert_send_static_service, create_outbound_message, service_spy}; use prost::Message; - use tari_comms::peer_manager::NodeId; + use tari_comms::{peer_manager::NodeId, runtime}; - #[tokio_macros::test_basic] + #[runtime::test] async fn serialize() { let spy = service_spy(); let mut serialize = SerializeLayer.layer(spy.to_service::()); diff --git a/comms/dht/src/rpc/service.rs b/comms/dht/src/rpc/service.rs index e762ed4d79..84aef6e5ff 100644 --- a/comms/dht/src/rpc/service.rs +++ b/comms/dht/src/rpc/service.rs @@ -24,16 +24,16 @@ use crate::{ proto::rpc::{GetCloserPeersRequest, GetPeersRequest, GetPeersResponse}, rpc::DhtRpcService, }; -use futures::{channel::mpsc, stream, SinkExt}; use log::*; use std::{cmp, sync::Arc}; use tari_comms::{ peer_manager::{NodeId, Peer, PeerFeatures, PeerQuery}, protocol::rpc::{Request, RpcError, RpcStatus, Streaming}, + utils, PeerManager, }; use tari_utilities::ByteArray; -use tokio::task; +use tokio::{sync::mpsc, task}; const LOG_TARGET: &str = "comms::dht::rpc"; @@ -56,17 +56,15 @@ impl DhtRpcServiceImpl { // A maximum buffer size of 10 is selected arbitrarily and is to allow the producer/consumer some room to // buffer. - let (mut tx, rx) = mpsc::channel(cmp::min(10, peers.len() as usize)); + let (tx, rx) = mpsc::channel(cmp::min(10, peers.len() as usize)); task::spawn(async move { let iter = peers .into_iter() .map(|peer| GetPeersResponse { peer: Some(peer.into()), }) - .map(Ok) .map(Ok); - let mut stream = stream::iter(iter); - let _ = tx.send_all(&mut stream).await; + let _ = utils::mpsc::send_all(&tx, iter).await; }); Streaming::new(rx) diff --git a/comms/dht/src/rpc/test.rs b/comms/dht/src/rpc/test.rs index 764d49ba7f..cd70d65a0f 100644 --- a/comms/dht/src/rpc/test.rs +++ b/comms/dht/src/rpc/test.rs @@ -26,13 +26,16 @@ use crate::{ test_utils::build_peer_manager, }; use futures::StreamExt; -use std::{convert::TryInto, sync::Arc}; +use std::{convert::TryInto, sync::Arc, time::Duration}; use tari_comms::{ - peer_manager::{node_id::NodeDistance, PeerFeatures}, + peer_manager::{node_id::NodeDistance, NodeId, Peer, PeerFeatures}, protocol::rpc::{mock::RpcRequestMock, RpcStatusCode}, + runtime, test_utils::node_identity::{build_node_identity, ordered_node_identities_by_distance}, PeerManager, }; +use tari_test_utils::collect_recv; +use tari_utilities::ByteArray; fn setup() -> (DhtRpcServiceImpl, RpcRequestMock, Arc) { let peer_manager = build_peer_manager(); @@ -45,10 +48,8 @@ fn setup() -> (DhtRpcServiceImpl, RpcRequestMock, Arc) { // Unit tests for get_closer_peers request mod get_closer_peers { use super::*; - use tari_comms::peer_manager::{NodeId, Peer}; - use tari_utilities::ByteArray; - #[tokio_macros::test_basic] + #[runtime::test] async fn it_returns_empty_peer_stream() { let (service, mock, _) = setup(); let node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -66,7 +67,7 @@ mod get_closer_peers { assert!(next.is_none()); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_returns_closest_peers() { let (service, mock, peer_manager) = setup(); let node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -83,7 +84,7 @@ mod get_closer_peers { let req = mock.request_with_context(node_identity.node_id().clone(), req); let peers_stream = service.get_closer_peers(req).await.unwrap(); - let results = peers_stream.into_inner().collect::>().await; + let results = collect_recv!(peers_stream.into_inner(), timeout = Duration::from_secs(10)); assert_eq!(results.len(), 10); let peers = results @@ -101,7 +102,7 @@ mod get_closer_peers { } } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_returns_n_peers() { let (service, mock, peer_manager) = setup(); @@ -123,7 +124,7 @@ mod get_closer_peers { assert_eq!(results.len(), 5); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_skips_excluded_peers() { let (service, mock, peer_manager) = setup(); @@ -142,12 +143,12 @@ mod get_closer_peers { let req = mock.request_with_context(node_identity.node_id().clone(), req); let peers_stream = service.get_closer_peers(req).await.unwrap(); - let results = peers_stream.into_inner().collect::>().await; + let results = collect_recv!(peers_stream.into_inner(), timeout = Duration::from_secs(10)); let mut peers = results.into_iter().map(Result::unwrap).map(|r| r.peer.unwrap()); assert!(peers.all(|p| p.public_key != excluded_peer.public_key().as_bytes())); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_errors_if_maximum_n_exceeded() { let (service, mock, _) = setup(); let req = GetCloserPeersRequest { @@ -165,9 +166,10 @@ mod get_closer_peers { mod get_peers { use super::*; use crate::proto::rpc::GetPeersRequest; + use std::time::Duration; use tari_comms::{peer_manager::Peer, test_utils::node_identity::build_many_node_identities}; - #[tokio_macros::test_basic] + #[runtime::test] async fn it_returns_empty_peer_stream() { let (service, mock, _) = setup(); let node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -183,7 +185,7 @@ mod get_peers { assert!(next.is_none()); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_returns_all_peers() { let (service, mock, peer_manager) = setup(); let nodes = build_many_node_identities(3, PeerFeatures::COMMUNICATION_NODE); @@ -200,7 +202,7 @@ mod get_peers { .get_peers(mock.request_with_context(Default::default(), req)) .await .unwrap(); - let results = peers_stream.into_inner().collect::>().await; + let results = collect_recv!(peers_stream.into_inner(), timeout = Duration::from_secs(10)); assert_eq!(results.len(), 5); let peers = results @@ -214,7 +216,7 @@ mod get_peers { assert_eq!(peers.iter().filter(|p| p.features.is_node()).count(), 3); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_excludes_clients() { let (service, mock, peer_manager) = setup(); let nodes = build_many_node_identities(3, PeerFeatures::COMMUNICATION_NODE); @@ -231,7 +233,7 @@ mod get_peers { .get_peers(mock.request_with_context(Default::default(), req)) .await .unwrap(); - let results = peers_stream.into_inner().collect::>().await; + let results = collect_recv!(peers_stream.into_inner(), timeout = Duration::from_secs(10)); assert_eq!(results.len(), 3); let peers = results @@ -244,7 +246,7 @@ mod get_peers { assert!(peers.iter().all(|p| p.features.is_node())); } - #[tokio_macros::test_basic] + #[runtime::test] async fn it_returns_n_peers() { let (service, mock, peer_manager) = setup(); diff --git a/comms/dht/src/storage/connection.rs b/comms/dht/src/storage/connection.rs index 856a94315a..ee99f8b560 100644 --- a/comms/dht/src/storage/connection.rs +++ b/comms/dht/src/storage/connection.rs @@ -123,16 +123,17 @@ impl DbConnection { mod test { use super::*; use diesel::{expression::sql_literal::sql, sql_types::Integer, RunQueryDsl}; + use tari_comms::runtime; use tari_test_utils::random; - #[tokio_macros::test_basic] + #[runtime::test] async fn connect_and_migrate() { let conn = DbConnection::connect_memory(random::string(8)).await.unwrap(); let output = conn.migrate().await.unwrap(); assert!(output.starts_with("Running migration")); } - #[tokio_macros::test_basic] + #[runtime::test] async fn memory_connections() { let id = random::string(8); let conn = DbConnection::connect_memory(id.clone()).await.unwrap(); diff --git a/comms/dht/src/storage/error.rs b/comms/dht/src/storage/error.rs index ab9f52f78d..f5bf4f0596 100644 --- a/comms/dht/src/storage/error.rs +++ b/comms/dht/src/storage/error.rs @@ -40,4 +40,6 @@ pub enum StorageError { ResultError(#[from] diesel::result::Error), #[error("MessageFormatError: {0}")] MessageFormatError(#[from] MessageFormatError), + #[error("Unexpected result: {0}")] + UnexpectedResult(String), } diff --git a/comms/dht/src/store_forward/database/mod.rs b/comms/dht/src/store_forward/database/mod.rs index 173d00e0ef..58ee06eb9c 100644 --- a/comms/dht/src/store_forward/database/mod.rs +++ b/comms/dht/src/store_forward/database/mod.rs @@ -255,9 +255,10 @@ impl StoreAndForwardDatabase { #[cfg(test)] mod test { use super::*; + use tari_comms::runtime; use tari_test_utils::random; - #[tokio_macros::test_basic] + #[runtime::test] async fn insert_messages() { let conn = DbConnection::connect_memory(random::string(8)).await.unwrap(); conn.migrate().await.unwrap(); @@ -277,7 +278,7 @@ mod test { assert_eq!(messages[1].body_hash, msg2.body_hash); } - #[tokio_macros::test_basic] + #[runtime::test] async fn remove_messages() { let conn = DbConnection::connect_memory(random::string(8)).await.unwrap(); conn.migrate().await.unwrap(); @@ -304,7 +305,7 @@ mod test { assert_eq!(messages[0].id, msg2_id); } - #[tokio_macros::test_basic] + #[runtime::test] async fn truncate_messages() { let conn = DbConnection::connect_memory(random::string(8)).await.unwrap(); conn.migrate().await.unwrap(); diff --git a/comms/dht/src/store_forward/forward.rs b/comms/dht/src/store_forward/forward.rs index 95ce5e2500..d8de4fe048 100644 --- a/comms/dht/src/store_forward/forward.rs +++ b/comms/dht/src/store_forward/forward.rs @@ -153,7 +153,7 @@ where S: Service self.forward(&message).await?; } - // The message has been forwarded, but other middleware may be interested (i.e. StoreMiddleware) + // The message has been forwarded, but downstream middleware may be interested trace!( target: LOG_TARGET, "Passing message {} to next service (Trace: {})", @@ -205,8 +205,9 @@ where S: Service } let body = decryption_result - .clone() + .as_ref() .err() + .cloned() .expect("previous check that decryption failed"); let excluded_peers = vec![source_peer.node_id.clone()]; @@ -262,14 +263,13 @@ mod test { outbound::mock::create_outbound_service_mock, test_utils::{make_dht_inbound_message, make_node_identity, service_spy}, }; - use futures::{channel::mpsc, executor::block_on}; - use tari_comms::wrap_in_envelope_body; - use tokio::runtime::Runtime; + use tari_comms::{runtime, runtime::task, wrap_in_envelope_body}; + use tokio::sync::mpsc; - #[test] - fn decryption_succeeded() { + #[runtime::test] + async fn decryption_succeeded() { let spy = service_spy(); - let (oms_tx, mut oms_rx) = mpsc::channel(1); + let (oms_tx, _) = mpsc::channel(1); let oms = OutboundMessageRequester::new(oms_tx); let mut service = ForwardLayer::new(oms, true).layer(spy.to_service::()); @@ -280,18 +280,16 @@ mod test { Some(node_identity.public_key().clone()), inbound_msg, ); - block_on(service.call(msg)).unwrap(); + service.call(msg).await.unwrap(); assert!(spy.is_called()); - assert!(oms_rx.try_next().is_err()); } - #[test] - fn decryption_failed() { - let mut rt = Runtime::new().unwrap(); + #[runtime::test] + async fn decryption_failed() { let spy = service_spy(); let (oms_requester, oms_mock) = create_outbound_service_mock(1); let oms_mock_state = oms_mock.get_state(); - rt.spawn(oms_mock.run()); + task::spawn(oms_mock.run()); let mut service = ForwardLayer::new(oms_requester, true).layer(spy.to_service::()); @@ -304,7 +302,7 @@ mod test { ); let header = inbound_msg.dht_header.clone(); let msg = DecryptedDhtMessage::failed(inbound_msg); - rt.block_on(service.call(msg)).unwrap(); + service.call(msg).await.unwrap(); assert!(spy.is_called()); assert_eq!(oms_mock_state.call_count(), 1); diff --git a/comms/dht/src/store_forward/saf_handler/layer.rs b/comms/dht/src/store_forward/saf_handler/layer.rs index 50b6ab7839..16e2760a1e 100644 --- a/comms/dht/src/store_forward/saf_handler/layer.rs +++ b/comms/dht/src/store_forward/saf_handler/layer.rs @@ -27,9 +27,9 @@ use crate::{ outbound::OutboundMessageRequester, store_forward::StoreAndForwardRequester, }; -use futures::channel::mpsc; use std::sync::Arc; use tari_comms::peer_manager::{NodeIdentity, PeerManager}; +use tokio::sync::mpsc; use tower::layer::Layer; pub struct MessageHandlerLayer { diff --git a/comms/dht/src/store_forward/saf_handler/middleware.rs b/comms/dht/src/store_forward/saf_handler/middleware.rs index 578fc1dcbc..641950e4f1 100644 --- a/comms/dht/src/store_forward/saf_handler/middleware.rs +++ b/comms/dht/src/store_forward/saf_handler/middleware.rs @@ -28,12 +28,13 @@ use crate::{ outbound::OutboundMessageRequester, store_forward::StoreAndForwardRequester, }; -use futures::{channel::mpsc, future::BoxFuture, task::Context}; +use futures::{future::BoxFuture, task::Context}; use std::{sync::Arc, task::Poll}; use tari_comms::{ peer_manager::{NodeIdentity, PeerManager}, pipeline::PipelineError, }; +use tokio::sync::mpsc; use tower::Service; #[derive(Clone)] diff --git a/comms/dht/src/store_forward/saf_handler/task.rs b/comms/dht/src/store_forward/saf_handler/task.rs index f3ba852118..c6224a7af7 100644 --- a/comms/dht/src/store_forward/saf_handler/task.rs +++ b/comms/dht/src/store_forward/saf_handler/task.rs @@ -41,7 +41,7 @@ use crate::{ }; use chrono::{DateTime, NaiveDateTime, Utc}; use digest::Digest; -use futures::{channel::mpsc, future, stream, SinkExt, StreamExt}; +use futures::{future, stream, StreamExt}; use log::*; use prost::Message; use std::{convert::TryInto, sync::Arc}; @@ -53,6 +53,7 @@ use tari_comms::{ utils::signature, }; use tari_utilities::{convert::try_convert_all, ByteArray}; +use tokio::sync::mpsc; use tower::{Service, ServiceExt}; const LOG_TARGET: &str = "comms::dht::storeforward::handler"; @@ -103,6 +104,20 @@ where S: Service .take() .expect("DhtInboundMessageTask initialized without message"); + if message.is_duplicate() { + debug!( + target: LOG_TARGET, + "Received message ({}) that has already been received {} time(s). Last sent by peer '{}', passing on \ + (Trace: {})", + message.tag, + message.dedup_hit_count, + message.source_peer.node_id.short_str(), + message.dht_header.message_tag, + ); + self.next_service.oneshot(message).await?; + return Ok(()); + } + if message.dht_header.message_type.is_saf_message() && message.decryption_failed() { debug!( target: LOG_TARGET, @@ -460,7 +475,8 @@ where S: Service public_key: CommsPublicKey, ) -> Result<(), StoreAndForwardError> { let msg_hash = Challenge::new().chain(body).finalize().to_vec(); - if dht_requester.insert_message_hash(msg_hash, public_key).await? { + let hit_count = dht_requester.add_message_to_dedup_cache(msg_hash, public_key).await?; + if hit_count > 1 { Err(StoreAndForwardError::DuplicateMessage) } else { Ok(()) @@ -567,14 +583,13 @@ mod test { }, }; use chrono::{Duration as OldDuration, Utc}; - use futures::channel::mpsc; use prost::Message; use std::time::Duration; - use tari_comms::{message::MessageExt, wrap_in_envelope_body}; + use tari_comms::{message::MessageExt, runtime, wrap_in_envelope_body}; use tari_crypto::tari_utilities::hex; - use tari_test_utils::collect_stream; + use tari_test_utils::collect_recv; use tari_utilities::hex::Hex; - use tokio::{runtime::Handle, task, time::delay_for}; + use tokio::{runtime::Handle, sync::mpsc, task, time::sleep}; // TODO: unit tests for static functions (check_signature, etc) @@ -602,7 +617,7 @@ mod test { } } - #[tokio_macros::test] + #[tokio::test] async fn request_stored_messages() { let spy = service_spy(); let (requester, mock_state) = create_store_and_forward_mock(); @@ -662,7 +677,7 @@ mod test { if oms_mock_state.call_count() >= 1 { break; } - delay_for(Duration::from_secs(5)).await; + sleep(Duration::from_secs(5)).await; } assert_eq!(oms_mock_state.call_count(), 1); @@ -724,7 +739,7 @@ mod test { if oms_mock_state.call_count() >= 1 { break; } - delay_for(Duration::from_secs(5)).await; + sleep(Duration::from_secs(5)).await; } assert_eq!(oms_mock_state.call_count(), 1); let call = oms_mock_state.pop_call().unwrap(); @@ -750,7 +765,7 @@ mod test { assert!(stored_messages.iter().any(|s| s.body == msg2.as_bytes())); } - #[tokio_macros::test_basic] + #[runtime::test] async fn receive_stored_messages() { let rt_handle = Handle::current(); let spy = service_spy(); @@ -845,7 +860,7 @@ mod test { assert!(msgs.contains(&b"A".to_vec())); assert!(msgs.contains(&b"B".to_vec())); assert!(msgs.contains(&b"Clear".to_vec())); - let signals = collect_stream!( + let signals = collect_recv!( saf_response_signal_receiver, take = 1, timeout = Duration::from_secs(20) diff --git a/comms/dht/src/store_forward/service.rs b/comms/dht/src/store_forward/service.rs index 5d06d85d56..7c8af0b731 100644 --- a/comms/dht/src/store_forward/service.rs +++ b/comms/dht/src/store_forward/service.rs @@ -36,12 +36,6 @@ use crate::{ DhtRequester, }; use chrono::{DateTime, NaiveDateTime, Utc}; -use futures::{ - channel::{mpsc, oneshot}, - stream::Fuse, - SinkExt, - StreamExt, -}; use log::*; use std::{convert::TryFrom, sync::Arc, time::Duration}; use tari_comms::{ @@ -51,7 +45,11 @@ use tari_comms::{ PeerManager, }; use tari_shutdown::ShutdownSignal; -use tokio::{task, time}; +use tokio::{ + sync::{mpsc, oneshot}, + task, + time, +}; const LOG_TARGET: &str = "comms::dht::storeforward::actor"; /// The interval to initiate a database cleanup. @@ -167,13 +165,13 @@ pub struct StoreAndForwardService { dht_requester: DhtRequester, database: StoreAndForwardDatabase, peer_manager: Arc, - connection_events: Fuse, + connection_events: ConnectivityEventRx, outbound_requester: OutboundMessageRequester, - request_rx: Fuse>, - shutdown_signal: Option, + request_rx: mpsc::Receiver, + shutdown_signal: ShutdownSignal, num_received_saf_responses: Option, num_online_peers: Option, - saf_response_signal_rx: Fuse>, + saf_response_signal_rx: mpsc::Receiver<()>, event_publisher: DhtEventSender, } @@ -196,13 +194,13 @@ impl StoreAndForwardService { database: StoreAndForwardDatabase::new(conn), peer_manager, dht_requester, - request_rx: request_rx.fuse(), - connection_events: connectivity.get_event_subscription().fuse(), + request_rx, + connection_events: connectivity.get_event_subscription(), outbound_requester, - shutdown_signal: Some(shutdown_signal), + shutdown_signal, num_received_saf_responses: Some(0), num_online_peers: None, - saf_response_signal_rx: saf_response_signal_rx.fuse(), + saf_response_signal_rx, event_publisher, } } @@ -213,20 +211,15 @@ impl StoreAndForwardService { } async fn run(mut self) { - let mut shutdown_signal = self - .shutdown_signal - .take() - .expect("StoreAndForwardActor initialized without shutdown_signal"); - - let mut cleanup_ticker = time::interval(CLEANUP_INTERVAL).fuse(); + let mut cleanup_ticker = time::interval(CLEANUP_INTERVAL); loop { - futures::select! { - request = self.request_rx.select_next_some() => { + tokio::select! { + Some(request) = self.request_rx.recv() => { self.handle_request(request).await; }, - event = self.connection_events.select_next_some() => { + event = self.connection_events.recv() => { if let Ok(event) = event { if let Err(err) = self.handle_connectivity_event(&event).await { error!(target: LOG_TARGET, "Error handling connection manager event: {:?}", err); @@ -234,20 +227,20 @@ impl StoreAndForwardService { } }, - _ = cleanup_ticker.select_next_some() => { + _ = cleanup_ticker.tick() => { if let Err(err) = self.cleanup().await { error!(target: LOG_TARGET, "Error when performing store and forward cleanup: {:?}", err); } }, - _ = self.saf_response_signal_rx.select_next_some() => { + Some(_) = self.saf_response_signal_rx.recv() => { if let Some(n) = self.num_received_saf_responses { self.num_received_saf_responses = Some(n + 1); self.check_saf_response_threshold(); } }, - _ = shutdown_signal => { + _ = self.shutdown_signal.wait() => { info!(target: LOG_TARGET, "StoreAndForwardActor is shutting down because the shutdown signal was triggered"); break; } diff --git a/comms/dht/src/store_forward/store.rs b/comms/dht/src/store_forward/store.rs index 4393f36518..e9b88a37ad 100644 --- a/comms/dht/src/store_forward/store.rs +++ b/comms/dht/src/store_forward/store.rs @@ -122,16 +122,31 @@ where } fn call(&mut self, msg: DecryptedDhtMessage) -> Self::Future { - Box::pin( - StoreTask::new( - self.next_service.clone(), - self.config.clone(), - Arc::clone(&self.peer_manager), - Arc::clone(&self.node_identity), - self.saf_requester.clone(), + if msg.is_duplicate() { + trace!( + target: LOG_TARGET, + "Passing duplicate message {} to next service (Trace: {})", + msg.tag, + msg.dht_header.message_tag + ); + + let service = self.next_service.clone(); + Box::pin(async move { + let service = service.ready_oneshot().await?; + service.oneshot(msg).await + }) + } else { + Box::pin( + StoreTask::new( + self.next_service.clone(), + self.config.clone(), + Arc::clone(&self.peer_manager), + Arc::clone(&self.node_identity), + self.saf_requester.clone(), + ) + .handle(msg), ) - .handle(msg), - ) + } } } @@ -447,11 +462,11 @@ mod test { }; use chrono::Utc; use std::time::Duration; - use tari_comms::wrap_in_envelope_body; + use tari_comms::{runtime, wrap_in_envelope_body}; use tari_test_utils::async_assert_eventually; use tari_utilities::hex::Hex; - #[tokio_macros::test_basic] + #[runtime::test] async fn cleartext_message_no_origin() { let (requester, mock_state) = create_store_and_forward_mock(); @@ -471,7 +486,7 @@ mod test { assert_eq!(messages.len(), 0); } - #[tokio_macros::test_basic] + #[runtime::test] async fn decryption_succeeded_no_store() { let (requester, mock_state) = create_store_and_forward_mock(); @@ -499,7 +514,7 @@ mod test { assert_eq!(mock_state.call_count(), 0); } - #[tokio_macros::test_basic] + #[runtime::test] async fn decryption_failed_should_store() { let (requester, mock_state) = create_store_and_forward_mock(); let spy = service_spy(); @@ -538,7 +553,7 @@ mod test { assert!(duration.num_seconds() <= 5); } - #[tokio_macros::test_basic] + #[runtime::test] async fn decryption_failed_banned_peer() { let (requester, mock_state) = create_store_and_forward_mock(); let spy = service_spy(); diff --git a/comms/dht/src/test_utils/dht_actor_mock.rs b/comms/dht/src/test_utils/dht_actor_mock.rs index ccc53c5a1e..4cfd99f209 100644 --- a/comms/dht/src/test_utils/dht_actor_mock.rs +++ b/comms/dht/src/test_utils/dht_actor_mock.rs @@ -25,26 +25,25 @@ use crate::{ actor::{DhtRequest, DhtRequester}, storage::DhtMetadataKey, }; -use futures::{channel::mpsc, stream::Fuse, StreamExt}; use std::{ collections::HashMap, sync::{ - atomic::{AtomicBool, AtomicUsize, Ordering}, + atomic::{AtomicUsize, Ordering}, Arc, RwLock, }, }; use tari_comms::peer_manager::Peer; -use tokio::task; +use tokio::{sync::mpsc, task}; pub fn create_dht_actor_mock(buf_size: usize) -> (DhtRequester, DhtActorMock) { let (tx, rx) = mpsc::channel(buf_size); - (DhtRequester::new(tx), DhtActorMock::new(rx.fuse())) + (DhtRequester::new(tx), DhtActorMock::new(rx)) } #[derive(Default, Debug, Clone)] pub struct DhtMockState { - signature_cache_insert: Arc, + signature_cache_insert: Arc, call_count: Arc, select_peers: Arc>>, settings: Arc>>>, @@ -52,16 +51,11 @@ pub struct DhtMockState { impl DhtMockState { pub fn new() -> Self { - Self { - signature_cache_insert: Arc::new(AtomicBool::new(false)), - call_count: Arc::new(AtomicUsize::new(0)), - select_peers: Arc::new(RwLock::new(Vec::new())), - settings: Arc::new(RwLock::new(HashMap::new())), - } + Default::default() } - pub fn set_signature_cache_insert(&self, v: bool) -> &Self { - self.signature_cache_insert.store(v, Ordering::SeqCst); + pub fn set_number_of_message_hits(&self, v: u32) -> &Self { + self.signature_cache_insert.store(v as usize, Ordering::SeqCst); self } @@ -80,12 +74,12 @@ impl DhtMockState { } pub struct DhtActorMock { - receiver: Fuse>, + receiver: mpsc::Receiver, state: DhtMockState, } impl DhtActorMock { - pub fn new(receiver: Fuse>) -> Self { + pub fn new(receiver: mpsc::Receiver) -> Self { Self { receiver, state: DhtMockState::default(), @@ -101,7 +95,7 @@ impl DhtActorMock { } pub async fn run(mut self) { - while let Some(req) = self.receiver.next().await { + while let Some(req) = self.receiver.recv().await { self.handle_request(req).await; } } @@ -111,9 +105,13 @@ impl DhtActorMock { self.state.inc_call_count(); match req { SendJoin => {}, - MsgHashCacheInsert(_, _, reply_tx) => { + MsgHashCacheInsert { reply_tx, .. } => { + let v = self.state.signature_cache_insert.load(Ordering::SeqCst); + reply_tx.send(v as u32).unwrap(); + }, + GetMsgHashHitCount(_, reply_tx) => { let v = self.state.signature_cache_insert.load(Ordering::SeqCst); - reply_tx.send(v).unwrap(); + reply_tx.send(v as u32).unwrap(); }, SelectPeers(_, reply_tx) => { let lock = self.state.select_peers.read().unwrap(); diff --git a/comms/dht/src/test_utils/dht_discovery_mock.rs b/comms/dht/src/test_utils/dht_discovery_mock.rs index 70575e2ae0..fbdf8e8284 100644 --- a/comms/dht/src/test_utils/dht_discovery_mock.rs +++ b/comms/dht/src/test_utils/dht_discovery_mock.rs @@ -24,7 +24,6 @@ use crate::{ discovery::{DhtDiscoveryRequest, DhtDiscoveryRequester}, test_utils::make_peer, }; -use futures::{channel::mpsc, stream::Fuse, StreamExt}; use log::*; use std::{ sync::{ @@ -35,15 +34,13 @@ use std::{ time::Duration, }; use tari_comms::peer_manager::Peer; +use tokio::sync::mpsc; const LOG_TARGET: &str = "comms::dht::discovery_mock"; pub fn create_dht_discovery_mock(buf_size: usize, timeout: Duration) -> (DhtDiscoveryRequester, DhtDiscoveryMock) { let (tx, rx) = mpsc::channel(buf_size); - ( - DhtDiscoveryRequester::new(tx, timeout), - DhtDiscoveryMock::new(rx.fuse()), - ) + (DhtDiscoveryRequester::new(tx, timeout), DhtDiscoveryMock::new(rx)) } #[derive(Debug, Clone)] @@ -75,12 +72,12 @@ impl DhtDiscoveryMockState { } pub struct DhtDiscoveryMock { - receiver: Fuse>, + receiver: mpsc::Receiver, state: DhtDiscoveryMockState, } impl DhtDiscoveryMock { - pub fn new(receiver: Fuse>) -> Self { + pub fn new(receiver: mpsc::Receiver) -> Self { Self { receiver, state: DhtDiscoveryMockState::new(), @@ -92,7 +89,7 @@ impl DhtDiscoveryMock { } pub async fn run(mut self) { - while let Some(req) = self.receiver.next().await { + while let Some(req) = self.receiver.recv().await { self.handle_request(req).await; } } diff --git a/comms/dht/src/test_utils/store_and_forward_mock.rs b/comms/dht/src/test_utils/store_and_forward_mock.rs index 0dd464c43a..72c0861a6d 100644 --- a/comms/dht/src/test_utils/store_and_forward_mock.rs +++ b/comms/dht/src/test_utils/store_and_forward_mock.rs @@ -23,7 +23,6 @@ use crate::store_forward::{StoreAndForwardRequest, StoreAndForwardRequester, StoredMessage}; use chrono::Utc; use digest::Digest; -use futures::{channel::mpsc, stream::Fuse, StreamExt}; use log::*; use rand::{rngs::OsRng, RngCore}; use std::sync::{ @@ -32,14 +31,17 @@ use std::sync::{ }; use tari_comms::types::Challenge; use tari_utilities::hex; -use tokio::{runtime, sync::RwLock}; +use tokio::{ + runtime, + sync::{mpsc, RwLock}, +}; const LOG_TARGET: &str = "comms::dht::discovery_mock"; pub fn create_store_and_forward_mock() -> (StoreAndForwardRequester, StoreAndForwardMockState) { let (tx, rx) = mpsc::channel(10); - let mock = StoreAndForwardMock::new(rx.fuse()); + let mock = StoreAndForwardMock::new(rx); let state = mock.get_shared_state(); runtime::Handle::current().spawn(mock.run()); (StoreAndForwardRequester::new(tx), state) @@ -90,12 +92,12 @@ impl StoreAndForwardMockState { } pub struct StoreAndForwardMock { - receiver: Fuse>, + receiver: mpsc::Receiver, state: StoreAndForwardMockState, } impl StoreAndForwardMock { - pub fn new(receiver: Fuse>) -> Self { + pub fn new(receiver: mpsc::Receiver) -> Self { Self { receiver, state: StoreAndForwardMockState::new(), @@ -107,7 +109,7 @@ impl StoreAndForwardMock { } pub async fn run(mut self) { - while let Some(req) = self.receiver.next().await { + while let Some(req) = self.receiver.recv().await { self.handle_request(req).await; } } diff --git a/comms/dht/src/tower_filter/predicate.rs b/comms/dht/src/tower_filter/predicate.rs deleted file mode 100644 index f86b9cc406..0000000000 --- a/comms/dht/src/tower_filter/predicate.rs +++ /dev/null @@ -1,25 +0,0 @@ -use std::future::Future; -use tari_comms::pipeline::PipelineError; - -/// Checks a request -pub trait Predicate { - /// The future returned by `check`. - type Future: Future>; - - /// Check whether the given request should be forwarded. - /// - /// If the future resolves with `Ok`, the request is forwarded to the inner service. - fn check(&mut self, request: &Request) -> Self::Future; -} - -impl Predicate for F -where - F: Fn(&T) -> U, - U: Future>, -{ - type Future = U; - - fn check(&mut self, request: &T) -> Self::Future { - self(request) - } -} diff --git a/comms/dht/tests/dht.rs b/comms/dht/tests/dht.rs index a5aed09970..761bb1badd 100644 --- a/comms/dht/tests/dht.rs +++ b/comms/dht/tests/dht.rs @@ -20,7 +20,6 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::{channel::mpsc, StreamExt}; use rand::rngs::OsRng; use std::{sync::Arc, time::Duration}; use tari_comms::{ @@ -46,6 +45,7 @@ use tari_comms_dht::{ DbConnectionUrl, Dht, DhtBuilder, + DhtConfig, }; use tari_shutdown::{Shutdown, ShutdownSignal}; use tari_storage::{ @@ -54,16 +54,20 @@ use tari_storage::{ }; use tari_test_utils::{ async_assert_eventually, - collect_stream, + collect_try_recv, paths::create_temporary_data_path, random, streams, unpack_enum, }; -use tokio::{sync::broadcast, time}; +use tokio::{ + sync::{broadcast, mpsc}, + time, +}; use tower::ServiceBuilder; struct TestNode { + name: String, comms: CommsNode, dht: Dht, inbound_messages: mpsc::Receiver, @@ -80,12 +84,16 @@ impl TestNode { self.comms.node_identity().to_peer() } + pub fn name(&self) -> &str { + &self.name + } + pub async fn next_inbound_message(&mut self, timeout: Duration) -> Option { - time::timeout(timeout, self.inbound_messages.next()).await.ok()? + time::timeout(timeout, self.inbound_messages.recv()).await.ok()? } pub async fn shutdown(mut self) { - self.shutdown.trigger().unwrap(); + self.shutdown.trigger(); self.comms.wait_until_shutdown().await; } } @@ -113,24 +121,36 @@ fn create_peer_storage() -> CommsDatabase { LMDBWrapper::new(Arc::new(peer_database)) } -async fn make_node(features: PeerFeatures, seed_peer: Option) -> TestNode { +async fn make_node>( + name: &str, + features: PeerFeatures, + dht_config: DhtConfig, + known_peers: I, +) -> TestNode { let node_identity = make_node_identity(features); - make_node_with_node_identity(node_identity, seed_peer).await + make_node_with_node_identity(name, node_identity, dht_config, known_peers).await } -async fn make_node_with_node_identity(node_identity: Arc, seed_peer: Option) -> TestNode { +async fn make_node_with_node_identity>( + name: &str, + node_identity: Arc, + dht_config: DhtConfig, + known_peers: I, +) -> TestNode { let (tx, inbound_messages) = mpsc::channel(10); let shutdown = Shutdown::new(); let (comms, dht, messaging_events) = setup_comms_dht( node_identity, create_peer_storage(), tx, - seed_peer.into_iter().collect(), + known_peers.into_iter().collect(), + dht_config, shutdown.to_signal(), ) .await; TestNode { + name: name.to_string(), comms, dht, inbound_messages, @@ -145,6 +165,7 @@ async fn setup_comms_dht( storage: CommsDatabase, inbound_tx: mpsc::Sender, peers: Vec, + dht_config: DhtConfig, shutdown_signal: ShutdownSignal, ) -> (CommsNode, Dht, MessagingEventSender) { // Create inbound and outbound channels @@ -168,11 +189,8 @@ async fn setup_comms_dht( comms.connectivity(), comms.shutdown_signal(), ) - .local_test() - .set_auto_store_and_forward_requests(false) + .with_config(dht_config) .with_database_url(DbConnectionUrl::MemoryShared(random::string(8))) - .with_discovery_timeout(Duration::from_secs(60)) - .with_num_neighbouring_nodes(8) .build() .await .unwrap(); @@ -205,17 +223,38 @@ async fn setup_comms_dht( (comms, dht, event_tx) } -#[tokio_macros::test] +fn dht_config() -> DhtConfig { + let mut config = DhtConfig::default_local_test(); + config.allow_test_addresses = true; + config.saf_auto_request = false; + config.discovery_request_timeout = Duration::from_secs(60); + config.num_neighbouring_nodes = 8; + config +} + +#[tokio::test] #[allow(non_snake_case)] async fn dht_join_propagation() { // Create 3 nodes where only Node B knows A and C, but A and C want to talk to each other // Node C knows no one - let node_C = make_node(PeerFeatures::COMMUNICATION_NODE, None).await; + let node_C = make_node("node_C", PeerFeatures::COMMUNICATION_NODE, dht_config(), None).await; // Node B knows about Node C - let node_B = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_C.to_peer())).await; + let node_B = make_node( + "node_B", + PeerFeatures::COMMUNICATION_NODE, + dht_config(), + Some(node_C.to_peer()), + ) + .await; // Node A knows about Node B - let node_A = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_B.to_peer())).await; + let node_A = make_node( + "node_A", + PeerFeatures::COMMUNICATION_NODE, + dht_config(), + Some(node_B.to_peer()), + ) + .await; node_A .comms @@ -262,19 +301,37 @@ async fn dht_join_propagation() { node_C.shutdown().await; } -#[tokio_macros::test] +#[tokio::test] #[allow(non_snake_case)] async fn dht_discover_propagation() { // Create 4 nodes where A knows B, B knows A and C, C knows B and D, and D knows C // Node D knows no one - let node_D = make_node(PeerFeatures::COMMUNICATION_CLIENT, None).await; + let node_D = make_node("node_D", PeerFeatures::COMMUNICATION_CLIENT, dht_config(), None).await; // Node C knows about Node D - let node_C = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_D.to_peer())).await; + let node_C = make_node( + "node_C", + PeerFeatures::COMMUNICATION_NODE, + dht_config(), + Some(node_D.to_peer()), + ) + .await; // Node B knows about Node C - let node_B = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_C.to_peer())).await; + let node_B = make_node( + "node_B", + PeerFeatures::COMMUNICATION_NODE, + dht_config(), + Some(node_C.to_peer()), + ) + .await; // Node A knows about Node B - let node_A = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_B.to_peer())).await; + let node_A = make_node( + "node_A", + PeerFeatures::COMMUNICATION_NODE, + dht_config(), + Some(node_B.to_peer()), + ) + .await; log::info!( "NodeA = {}, NodeB = {}, Node C = {}, Node D = {}", node_A.node_identity().node_id().short_str(), @@ -318,14 +375,20 @@ async fn dht_discover_propagation() { assert!(node_D_peer_manager.exists(node_A.node_identity().public_key()).await); } -#[tokio_macros::test] +#[tokio::test] #[allow(non_snake_case)] async fn dht_store_forward() { let node_C_node_identity = make_node_identity(PeerFeatures::COMMUNICATION_NODE); // Node B knows about Node C - let node_B = make_node(PeerFeatures::COMMUNICATION_NODE, None).await; + let node_B = make_node("node_B", PeerFeatures::COMMUNICATION_NODE, dht_config(), None).await; // Node A knows about Node B - let node_A = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_B.to_peer())).await; + let node_A = make_node( + "node_A", + PeerFeatures::COMMUNICATION_NODE, + dht_config(), + Some(node_B.to_peer()), + ) + .await; log::info!( "NodeA = {}, NodeB = {}, Node C = {}", node_A.node_identity().node_id().short_str(), @@ -370,9 +433,10 @@ async fn dht_store_forward() { .unwrap(); // Wait for node B to receive 2 propagation messages - collect_stream!(node_B_msg_events, take = 2, timeout = Duration::from_secs(20)); + collect_try_recv!(node_B_msg_events, take = 2, timeout = Duration::from_secs(20)); - let mut node_C = make_node_with_node_identity(node_C_node_identity, Some(node_B.to_peer())).await; + let mut node_C = + make_node_with_node_identity("node_C", node_C_node_identity, dht_config(), Some(node_B.to_peer())).await; let mut node_C_dht_events = node_C.dht.subscribe_dht_events(); let mut node_C_msg_events = node_C.messaging_events.subscribe(); // Ask node B for messages @@ -389,8 +453,8 @@ async fn dht_store_forward() { .await .unwrap(); // Wait for node C to and receive a response from the SAF request - let event = collect_stream!(node_C_msg_events, take = 1, timeout = Duration::from_secs(20)); - unpack_enum!(MessagingEvent::MessageReceived(_node_id, _msg) = &**event.get(0).unwrap().as_ref().unwrap()); + let event = collect_try_recv!(node_C_msg_events, take = 1, timeout = Duration::from_secs(20)); + unpack_enum!(MessagingEvent::MessageReceived(_node_id, _msg) = &*event.get(0).unwrap().as_ref()); let msg = node_C.next_inbound_message(Duration::from_secs(5)).await.unwrap(); assert_eq!( @@ -418,25 +482,47 @@ async fn dht_store_forward() { assert!(msgs.is_empty()); // Check that Node C emitted the StoreAndForwardMessagesReceived event when it went Online - let event = collect_stream!(node_C_dht_events, take = 1, timeout = Duration::from_secs(20)); - unpack_enum!(DhtEvent::StoreAndForwardMessagesReceived = &**event.get(0).unwrap().as_ref().unwrap()); + let event = collect_try_recv!(node_C_dht_events, take = 1, timeout = Duration::from_secs(20)); + unpack_enum!(DhtEvent::StoreAndForwardMessagesReceived = &*event.get(0).unwrap().as_ref()); node_A.shutdown().await; node_B.shutdown().await; node_C.shutdown().await; } -#[tokio_macros::test] +#[tokio::test] #[allow(non_snake_case)] async fn dht_propagate_dedup() { + let mut config = dht_config(); + // For this test we want to exactly measure the path of a message, so we disable repropagation of messages (i.e + // allow 1 occurrence) + config.dedup_allowed_message_occurrences = 1; // Node D knows no one - let mut node_D = make_node(PeerFeatures::COMMUNICATION_NODE, None).await; + let mut node_D = make_node("node_D", PeerFeatures::COMMUNICATION_NODE, config.clone(), None).await; // Node C knows about Node D - let mut node_C = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_D.to_peer())).await; + let mut node_C = make_node( + "node_C", + PeerFeatures::COMMUNICATION_NODE, + config.clone(), + Some(node_D.to_peer()), + ) + .await; // Node B knows about Node C - let mut node_B = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_C.to_peer())).await; + let mut node_B = make_node( + "node_B", + PeerFeatures::COMMUNICATION_NODE, + config.clone(), + Some(node_C.to_peer()), + ) + .await; // Node A knows about Node B and C - let mut node_A = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_B.to_peer())).await; + let mut node_A = make_node( + "node_A", + PeerFeatures::COMMUNICATION_NODE, + config.clone(), + Some(node_B.to_peer()), + ) + .await; node_A.comms.peer_manager().add_peer(node_C.to_peer()).await.unwrap(); log::info!( "NodeA = {}, NodeB = {}, Node C = {}, Node D = {}", @@ -482,8 +568,7 @@ async fn dht_propagate_dedup() { .dht .outbound_requester() .propagate( - // Node D is a client node, so an destination is required for domain messages - NodeDestination::Unknown, // NodeId(Box::new(node_D.node_identity().node_id().clone())), + NodeDestination::Unknown, OutboundEncryption::EncryptFor(Box::new(node_D.node_identity().public_key().clone())), vec![], out_msg, @@ -496,6 +581,7 @@ async fn dht_propagate_dedup() { .await .expect("Node D expected an inbound message but it never arrived"); assert!(msg.decryption_succeeded()); + log::info!("Received message {}", msg.tag); let person = msg .decryption_result .unwrap() @@ -515,35 +601,150 @@ async fn dht_propagate_dedup() { node_D.shutdown().await; // Check the message flow BEFORE deduping - let received = filter_received(collect_stream!(node_A_messaging, timeout = Duration::from_secs(20))); + let received = filter_received(collect_try_recv!(node_A_messaging, timeout = Duration::from_secs(20))); // Expected race condition: If A->(B|C)->(C|B) before A->(C|B) then (C|B)->A if !received.is_empty() { assert_eq!(count_messages_received(&received, &[&node_B_id, &node_C_id]), 1); } - let received = filter_received(collect_stream!(node_B_messaging, timeout = Duration::from_secs(20))); + let received = filter_received(collect_try_recv!(node_B_messaging, timeout = Duration::from_secs(20))); let recv_count = count_messages_received(&received, &[&node_A_id, &node_C_id]); // Expected race condition: If A->B->C before A->C then C->B does not happen assert!((1..=2).contains(&recv_count)); - let received = filter_received(collect_stream!(node_C_messaging, timeout = Duration::from_secs(20))); + let received = filter_received(collect_try_recv!(node_C_messaging, timeout = Duration::from_secs(20))); let recv_count = count_messages_received(&received, &[&node_A_id, &node_B_id]); assert_eq!(recv_count, 2); assert_eq!(count_messages_received(&received, &[&node_D_id]), 0); - let received = filter_received(collect_stream!(node_D_messaging, timeout = Duration::from_secs(20))); + let received = filter_received(collect_try_recv!(node_D_messaging, timeout = Duration::from_secs(20))); assert_eq!(received.len(), 1); assert_eq!(count_messages_received(&received, &[&node_C_id]), 1); } -#[tokio_macros::test] +#[tokio::test] +#[allow(non_snake_case)] +async fn dht_repropagate() { + let mut config = dht_config(); + config.dedup_allowed_message_occurrences = 3; + let mut node_C = make_node("node_C", PeerFeatures::COMMUNICATION_NODE, config.clone(), []).await; + let mut node_B = make_node("node_B", PeerFeatures::COMMUNICATION_NODE, config.clone(), [ + node_C.to_peer() + ]) + .await; + let mut node_A = make_node("node_A", PeerFeatures::COMMUNICATION_NODE, config, [ + node_B.to_peer(), + node_C.to_peer(), + ]) + .await; + node_A.comms.peer_manager().add_peer(node_C.to_peer()).await.unwrap(); + node_B.comms.peer_manager().add_peer(node_C.to_peer()).await.unwrap(); + node_C.comms.peer_manager().add_peer(node_A.to_peer()).await.unwrap(); + node_C.comms.peer_manager().add_peer(node_B.to_peer()).await.unwrap(); + log::info!( + "NodeA = {}, NodeB = {}, Node C = {}", + node_A.node_identity().node_id().short_str(), + node_B.node_identity().node_id().short_str(), + node_C.node_identity().node_id().short_str(), + ); + + // Connect the peers that should be connected + async fn connect_nodes(node1: &mut TestNode, node2: &mut TestNode) { + node1 + .comms + .connectivity() + .dial_peer(node2.node_identity().node_id().clone()) + .await + .unwrap(); + } + // Pre-connect nodes, this helps message passing be more deterministic + connect_nodes(&mut node_A, &mut node_B).await; + connect_nodes(&mut node_A, &mut node_C).await; + connect_nodes(&mut node_B, &mut node_C).await; + + #[derive(Clone, PartialEq, ::prost::Message)] + struct Person { + #[prost(string, tag = "1")] + name: String, + #[prost(uint32, tag = "2")] + age: u32, + } + + let out_msg = OutboundDomainMessage::new(123, Person { + name: "Alan Turing".into(), + age: 41, + }); + node_A + .dht + .outbound_requester() + .propagate( + NodeDestination::Unknown, + OutboundEncryption::ClearText, + vec![], + out_msg.clone(), + ) + .await + .unwrap(); + + async fn receive_and_repropagate(node: &mut TestNode, out_msg: &OutboundDomainMessage) { + let msg = node + .next_inbound_message(Duration::from_secs(10)) + .await + .unwrap_or_else(|| panic!("{} expected an inbound message but it never arrived", node.name())); + log::info!("Received message {}", msg.tag); + + node.dht + .outbound_requester() + .send_message( + SendMessageParams::new() + .propagate(NodeDestination::Unknown, vec![]) + .with_destination(NodeDestination::Unknown) + .with_tag(msg.tag) + .finish(), + out_msg.clone(), + ) + .await + .unwrap() + .resolve() + .await + .unwrap(); + } + + // This relies on the DHT being set with .with_dedup_discard_hit_count(3) + receive_and_repropagate(&mut node_B, &out_msg).await; + receive_and_repropagate(&mut node_C, &out_msg).await; + receive_and_repropagate(&mut node_A, &out_msg).await; + receive_and_repropagate(&mut node_B, &out_msg).await; + receive_and_repropagate(&mut node_C, &out_msg).await; + receive_and_repropagate(&mut node_A, &out_msg).await; + receive_and_repropagate(&mut node_B, &out_msg).await; + receive_and_repropagate(&mut node_C, &out_msg).await; + + node_A.shutdown().await; + node_B.shutdown().await; + node_C.shutdown().await; +} + +#[tokio::test] #[allow(non_snake_case)] async fn dht_propagate_message_contents_not_malleable_ban() { - let node_C = make_node(PeerFeatures::COMMUNICATION_NODE, None).await; + let node_C = make_node("node_C", PeerFeatures::COMMUNICATION_NODE, dht_config(), None).await; // Node B knows about Node C - let mut node_B = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_C.to_peer())).await; + let mut node_B = make_node( + "node_B", + PeerFeatures::COMMUNICATION_NODE, + dht_config(), + Some(node_C.to_peer()), + ) + .await; // Node A knows about Node B - let node_A = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_B.to_peer())).await; + let node_A = make_node( + "node_A", + PeerFeatures::COMMUNICATION_NODE, + dht_config(), + Some(node_B.to_peer()), + ) + .await; node_A.comms.peer_manager().add_peer(node_C.to_peer()).await.unwrap(); log::info!( "NodeA = {}, NodeB = {}", @@ -613,10 +814,10 @@ async fn dht_propagate_message_contents_not_malleable_ban() { let node_B_node_id = node_B.node_identity().node_id().clone(); // Node C should ban node B - let banned_node_id = streams::assert_in_stream( + let banned_node_id = streams::assert_in_broadcast( &mut connectivity_events, - |r| match &*r.unwrap() { - ConnectivityEvent::PeerBanned(node_id) => Some(node_id.clone()), + |r| match r { + ConnectivityEvent::PeerBanned(node_id) => Some(node_id), _ => None, }, Duration::from_secs(10), @@ -629,12 +830,9 @@ async fn dht_propagate_message_contents_not_malleable_ban() { node_C.shutdown().await; } -fn filter_received( - events: Vec, tokio::sync::broadcast::RecvError>>, -) -> Vec> { +fn filter_received(events: Vec>) -> Vec> { events .into_iter() - .map(Result::unwrap) .filter(|e| match &**e { MessagingEvent::MessageReceived(_, _) => true, _ => unreachable!(), diff --git a/comms/examples/stress/error.rs b/comms/examples/stress/error.rs index 5cb9be1cb3..e87aae514e 100644 --- a/comms/examples/stress/error.rs +++ b/comms/examples/stress/error.rs @@ -19,10 +19,10 @@ // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -use futures::channel::{mpsc::SendError, oneshot}; use std::io; use tari_comms::{ connectivity::ConnectivityError, + message::OutboundMessage, peer_manager::PeerManagerError, tor, CommsBuilderError, @@ -30,7 +30,11 @@ use tari_comms::{ }; use tari_crypto::tari_utilities::message_format::MessageFormatError; use thiserror::Error; -use tokio::{task, time}; +use tokio::{ + sync::{mpsc::error::SendError, oneshot}, + task, + time, +}; #[derive(Debug, Error)] pub enum Error { @@ -48,12 +52,12 @@ pub enum Error { ConnectivityError(#[from] ConnectivityError), #[error("Message format error: {0}")] MessageFormatError(#[from] MessageFormatError), - #[error("Failed to send message")] - SendError(#[from] SendError), + #[error("Failed to send message: {0}")] + SendError(#[from] SendError), #[error("JoinError: {0}")] JoinError(#[from] task::JoinError), #[error("Example did not exit cleanly: `{0}`")] - WaitTimeout(#[from] time::Elapsed), + WaitTimeout(#[from] time::error::Elapsed), #[error("IO error: {0}")] Io(#[from] io::Error), #[error("User quit")] @@ -63,5 +67,5 @@ pub enum Error { #[error("Unexpected EoF")] UnexpectedEof, #[error("Internal reply canceled")] - ReplyCanceled(#[from] oneshot::Canceled), + ReplyCanceled(#[from] oneshot::error::RecvError), } diff --git a/comms/examples/stress/node.rs b/comms/examples/stress/node.rs index 45ad0919ab..d060d18071 100644 --- a/comms/examples/stress/node.rs +++ b/comms/examples/stress/node.rs @@ -21,7 +21,6 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use super::{error::Error, STRESS_PROTOCOL_NAME, TOR_CONTROL_PORT_ADDR, TOR_SOCKS_ADDR}; -use futures::channel::mpsc; use rand::rngs::OsRng; use std::{convert, net::Ipv4Addr, path::Path, sync::Arc, time::Duration}; use tari_comms::{ @@ -43,7 +42,7 @@ use tari_storage::{ lmdb_store::{LMDBBuilder, LMDBConfig}, LMDBWrapper, }; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc}; pub async fn create( node_identity: Option>, diff --git a/comms/examples/stress/service.rs b/comms/examples/stress/service.rs index 3e262cc38b..45e2bc0fd3 100644 --- a/comms/examples/stress/service.rs +++ b/comms/examples/stress/service.rs @@ -23,15 +23,7 @@ use super::error::Error; use crate::stress::{MAX_FRAME_SIZE, STRESS_PROTOCOL_NAME}; use bytes::{Buf, Bytes, BytesMut}; -use futures::{ - channel::{mpsc, oneshot}, - stream, - stream::Fuse, - AsyncReadExt, - AsyncWriteExt, - SinkExt, - StreamExt, -}; +use futures::{stream, SinkExt, StreamExt}; use rand::{rngs::OsRng, RngCore}; use std::{ iter::repeat_with, @@ -43,12 +35,19 @@ use tari_comms::{ message::{InboundMessage, OutboundMessage}, peer_manager::{NodeId, Peer}, protocol::{ProtocolEvent, ProtocolNotification}, + utils, CommsNode, PeerConnection, Substream, }; use tari_crypto::tari_utilities::hex::Hex; -use tokio::{sync::RwLock, task, task::JoinHandle, time}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + sync::{mpsc, oneshot, RwLock}, + task, + task::JoinHandle, + time, +}; pub fn start_service( comms_node: CommsNode, @@ -70,6 +69,7 @@ pub fn start_service( (task::spawn(service.start()), request_tx) } +#[derive(Debug)] pub enum StressTestServiceRequest { BeginProtocol(Peer, StressProtocol, oneshot::Sender>), Shutdown, @@ -135,9 +135,9 @@ impl StressProtocol { } struct StressTestService { - request_rx: Fuse>, + request_rx: mpsc::Receiver, comms_node: CommsNode, - protocol_notif: Fuse>>, + protocol_notif: mpsc::Receiver>, shutdown: bool, inbound_rx: Arc>>, @@ -153,9 +153,9 @@ impl StressTestService { outbound_tx: mpsc::Sender, ) -> Self { Self { - request_rx: request_rx.fuse(), + request_rx, comms_node, - protocol_notif: protocol_notif.fuse(), + protocol_notif, shutdown: false, inbound_rx: Arc::new(RwLock::new(inbound_rx)), outbound_tx, @@ -163,23 +163,21 @@ impl StressTestService { } async fn start(mut self) -> Result<(), Error> { - let mut events = self.comms_node.subscribe_connectivity_events().fuse(); + let mut events = self.comms_node.subscribe_connectivity_events(); loop { - futures::select! { - event = events.select_next_some() => { - if let Ok(event) = event { - println!("{}", event); - } + tokio::select! { + Ok(event) = events.recv() => { + println!("{}", event); }, - request = self.request_rx.select_next_some() => { + Some(request) = self.request_rx.recv() => { if let Err(err) = self.handle_request(request).await { println!("Error: {}", err); } }, - notif = self.protocol_notif.select_next_some() => { + Some(notif) = self.protocol_notif.recv() => { self.handle_protocol_notification(notif).await; }, } @@ -431,7 +429,7 @@ async fn messaging_flood( peer: NodeId, protocol: StressProtocol, inbound_rx: Arc>>, - mut outbound_tx: mpsc::Sender, + outbound_tx: mpsc::Sender, ) -> Result<(), Error> { let start = Instant::now(); let mut counter = 1u32; @@ -441,18 +439,15 @@ async fn messaging_flood( protocol.num_messages * protocol.message_size / 1024 / 1024 ); let outbound_task = task::spawn(async move { - let mut iter = stream::iter( - repeat_with(|| { - counter += 1; - - println!("Send MSG {}", counter); - OutboundMessage::new(peer.clone(), generate_message(counter, protocol.message_size as usize)) - }) - .take(protocol.num_messages as usize) - .map(Ok), - ); - outbound_tx.send_all(&mut iter).await?; - time::delay_for(Duration::from_secs(5)).await; + let iter = repeat_with(|| { + counter += 1; + + println!("Send MSG {}", counter); + OutboundMessage::new(peer.clone(), generate_message(counter, protocol.message_size as usize)) + }) + .take(protocol.num_messages as usize); + utils::mpsc::send_all(&outbound_tx, iter).await?; + time::sleep(Duration::from_secs(5)).await; outbound_tx .send(OutboundMessage::new(peer.clone(), Bytes::from_static(&[0u8; 4]))) .await?; @@ -462,7 +457,7 @@ async fn messaging_flood( let inbound_task = task::spawn(async move { let mut inbound_rx = inbound_rx.write().await; let mut msgs = vec![]; - while let Some(msg) = inbound_rx.next().await { + while let Some(msg) = inbound_rx.recv().await { let msg_id = decode_msg(msg.body); println!("GOT MSG {}", msg_id); if msgs.len() == protocol.num_messages as usize { @@ -497,6 +492,6 @@ fn generate_message(n: u32, size: usize) -> Bytes { fn decode_msg(msg: T) -> u32 { let mut buf = [0u8; 4]; - msg.bytes().copy_to_slice(&mut buf); + msg.chunk().copy_to_slice(&mut buf); u32::from_be_bytes(buf) } diff --git a/comms/examples/stress_test.rs b/comms/examples/stress_test.rs index 71c185aa3e..3a0c04f020 100644 --- a/comms/examples/stress_test.rs +++ b/comms/examples/stress_test.rs @@ -24,13 +24,13 @@ mod stress; use stress::{error::Error, prompt::user_prompt}; use crate::stress::{node, prompt::parse_from_short_str, service, service::StressTestServiceRequest}; -use futures::{channel::oneshot, future, future::Either, SinkExt}; +use futures::{future, future::Either}; use std::{env, net::Ipv4Addr, path::Path, process, sync::Arc, time::Duration}; use tari_crypto::tari_utilities::message_format::MessageFormat; use tempfile::Builder; -use tokio::time; +use tokio::{sync::oneshot, time}; -#[tokio_macros::main] +#[tokio::main] async fn main() { env_logger::init(); match run().await { @@ -99,7 +99,7 @@ async fn run() -> Result<(), Error> { } println!("Stress test service started!"); - let (handle, mut requester) = service::start_service(comms_node, protocol_notif, inbound_rx, outbound_tx); + let (handle, requester) = service::start_service(comms_node, protocol_notif, inbound_rx, outbound_tx); let mut last_peer = peer.as_ref().and_then(parse_from_short_str); diff --git a/comms/examples/tor.rs b/comms/examples/tor.rs index 734ef7718a..9186c69d43 100644 --- a/comms/examples/tor.rs +++ b/comms/examples/tor.rs @@ -1,7 +1,6 @@ use anyhow::anyhow; use bytes::Bytes; use chrono::Utc; -use futures::{channel::mpsc, SinkExt, StreamExt}; use rand::{rngs::OsRng, thread_rng, RngCore}; use std::{collections::HashMap, convert::identity, env, net::SocketAddr, path::Path, process, sync::Arc}; use tari_comms::{ @@ -21,7 +20,10 @@ use tari_storage::{ LMDBWrapper, }; use tempfile::Builder; -use tokio::{runtime, sync::broadcast}; +use tokio::{ + runtime, + sync::{broadcast, mpsc}, +}; // Tor example for tari_comms. // @@ -29,7 +31,7 @@ use tokio::{runtime, sync::broadcast}; type Error = anyhow::Error; -#[tokio_macros::main] +#[tokio::main] async fn main() { env_logger::init(); if let Err(err) = run().await { @@ -56,7 +58,7 @@ async fn run() -> Result<(), Error> { println!("Starting comms nodes...",); let temp_dir1 = Builder::new().prefix("tor-example1").tempdir().unwrap(); - let (comms_node1, inbound_rx1, mut outbound_tx1) = setup_node_with_tor( + let (comms_node1, inbound_rx1, outbound_tx1) = setup_node_with_tor( control_port_addr.clone(), temp_dir1.as_ref(), (9098u16, "127.0.0.1:0".parse::().unwrap()), @@ -208,11 +210,11 @@ async fn setup_node_with_tor>( async fn start_ping_ponger( dest_node_id: NodeId, mut inbound_rx: mpsc::Receiver, - mut outbound_tx: mpsc::Sender, + outbound_tx: mpsc::Sender, ) -> Result { let mut inflight_pings = HashMap::new(); let mut counter = 0; - while let Some(msg) = inbound_rx.next().await { + while let Some(msg) = inbound_rx.recv().await { counter += 1; let msg_str = String::from_utf8_lossy(&msg.body); diff --git a/comms/rpc_macros/Cargo.toml b/comms/rpc_macros/Cargo.toml index 3680ed81f9..5bc185bc90 100644 --- a/comms/rpc_macros/Cargo.toml +++ b/comms/rpc_macros/Cargo.toml @@ -13,16 +13,16 @@ edition = "2018" proc-macro = true [dependencies] +tari_comms = { version = "^0.9", path = "../", features = ["rpc"]} + proc-macro2 = "1.0.24" quote = "1.0.7" syn = {version = "1.0.38", features = ["fold"]} -tari_comms = { version = "^0.9", path = "../", features = ["rpc"]} [dev-dependencies] tari_test_utils = {version="^0.9", path="../../infrastructure/test_utils"} futures = "0.3.5" -prost = "0.6.1" -tokio = "0.2.22" -tokio-macros = "0.2.5" +prost = "0.8.0" +tokio = {version = "1", features = ["macros"]} tower-service = "0.3.0" diff --git a/comms/rpc_macros/src/generator.rs b/comms/rpc_macros/src/generator.rs index f3e8cbffd1..5f44066f19 100644 --- a/comms/rpc_macros/src/generator.rs +++ b/comms/rpc_macros/src/generator.rs @@ -215,8 +215,8 @@ impl RpcCodeGenerator { self.inner.ping().await } - pub fn close(&mut self) { - self.inner.close(); + pub async fn close(&mut self) { + self.inner.close().await; } }; diff --git a/comms/rpc_macros/tests/macro.rs b/comms/rpc_macros/tests/macro.rs index f3f05b4481..b41f3f9914 100644 --- a/comms/rpc_macros/tests/macro.rs +++ b/comms/rpc_macros/tests/macro.rs @@ -20,7 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::{channel::mpsc, SinkExt, StreamExt}; +use futures::StreamExt; use prost::Message; use std::{collections::HashMap, ops::AddAssign, sync::Arc}; use tari_comms::{ @@ -34,7 +34,10 @@ use tari_comms::{ }; use tari_comms_rpc_macros::tari_rpc; use tari_test_utils::unpack_enum; -use tokio::{sync::RwLock, task}; +use tokio::{ + sync::{mpsc, RwLock}, + task, +}; use tower_service::Service; #[tari_rpc(protocol_name = b"/test/protocol/123", server_struct = TestServer, client_struct = TestClient)] @@ -80,7 +83,7 @@ impl Test for TestService { async fn server_streaming(&self, _: Request) -> Result, RpcStatus> { self.add_call("server_streaming").await; - let (mut tx, rx) = mpsc::channel(1); + let (tx, rx) = mpsc::channel(1); tx.send(Ok(1)).await.unwrap(); Ok(Streaming::new(rx)) } @@ -101,7 +104,7 @@ fn it_sets_the_protocol_name() { assert_eq!(TestClient::PROTOCOL_NAME, b"/test/protocol/123"); } -#[tokio_macros::test] +#[tokio::test] async fn it_returns_the_correct_type() { let mut server = TestServer::new(TestService::default()); let resp = server @@ -112,7 +115,7 @@ async fn it_returns_the_correct_type() { assert_eq!(u32::decode(v).unwrap(), 12); } -#[tokio_macros::test] +#[tokio::test] async fn it_correctly_maps_the_method_nums() { let service = TestService::default(); let spy = service.state.clone(); @@ -135,7 +138,7 @@ async fn it_correctly_maps_the_method_nums() { assert_eq!(*spy.read().await.get("unit").unwrap(), 1); } -#[tokio_macros::test] +#[tokio::test] async fn it_returns_an_error_for_invalid_method_nums() { let service = TestService::default(); let mut server = TestServer::new(service); @@ -147,7 +150,7 @@ async fn it_returns_an_error_for_invalid_method_nums() { unpack_enum!(RpcStatusCode::UnsupportedMethod = err.status_code()); } -#[tokio_macros::test] +#[tokio::test] async fn it_generates_client_calls() { let (sock_client, sock_server) = MemorySocket::new_pair(); let client = task::spawn(TestClient::connect(framing::canonical(sock_client, 1024))); diff --git a/comms/src/bounded_executor.rs b/comms/src/bounded_executor.rs index ee65e68476..3e36edb50b 100644 --- a/comms/src/bounded_executor.rs +++ b/comms/src/bounded_executor.rs @@ -145,7 +145,15 @@ impl BoundedExecutor { F::Output: Send + 'static, { let span = span!(Level::TRACE, "bounded_executor::waiting_time"); - let permit = self.semaphore.clone().acquire_owned().instrument(span).await; + // SAFETY: acquire_owned only fails if the semaphore is closed (i.e self.semaphore.close() is called) - this + // never happens in this implementation + let permit = self + .semaphore + .clone() + .acquire_owned() + .instrument(span) + .await + .expect("semaphore closed"); self.do_spawn(permit, future) } @@ -230,9 +238,9 @@ mod test { }, time::Duration, }; - use tokio::time::delay_for; + use tokio::time::sleep; - #[runtime::test_basic] + #[runtime::test] async fn spawn() { let flag = Arc::new(AtomicBool::new(false)); let flag_cloned = flag.clone(); @@ -241,7 +249,7 @@ mod test { // Spawn 1 let task1_fut = executor .spawn(async move { - delay_for(Duration::from_millis(1)).await; + sleep(Duration::from_millis(1)).await; flag_cloned.store(true, Ordering::SeqCst); }) .await; diff --git a/comms/src/builder/comms_node.rs b/comms/src/builder/comms_node.rs index 24c856f5ba..abd71e8952 100644 --- a/comms/src/builder/comms_node.rs +++ b/comms/src/builder/comms_node.rs @@ -46,11 +46,13 @@ use crate::{ CommsBuilder, Substream, }; -use futures::{channel::mpsc, AsyncRead, AsyncWrite}; use log::*; use std::{iter, sync::Arc}; use tari_shutdown::ShutdownSignal; -use tokio::sync::broadcast; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{broadcast, mpsc}, +}; const LOG_TARGET: &str = "comms::node"; diff --git a/comms/src/builder/mod.rs b/comms/src/builder/mod.rs index f5acc151a9..5da0d51793 100644 --- a/comms/src/builder/mod.rs +++ b/comms/src/builder/mod.rs @@ -51,10 +51,9 @@ use crate::{ tor, types::CommsDatabase, }; -use futures::channel::mpsc; use std::{fs::File, sync::Arc}; use tari_shutdown::ShutdownSignal; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc}; /// The `CommsBuilder` provides a simple builder API for getting Tari comms p2p messaging up and running. pub struct CommsBuilder { diff --git a/comms/src/builder/tests.rs b/comms/src/builder/tests.rs index d1ae9a0f9a..d4fe8cca97 100644 --- a/comms/src/builder/tests.rs +++ b/comms/src/builder/tests.rs @@ -42,19 +42,16 @@ use crate::{ CommsNode, }; use bytes::Bytes; -use futures::{ - channel::{mpsc, oneshot}, - stream::FuturesUnordered, - AsyncReadExt, - AsyncWriteExt, - SinkExt, - StreamExt, -}; +use futures::stream::FuturesUnordered; use std::{collections::HashSet, convert::identity, hash::Hash, time::Duration}; use tari_shutdown::{Shutdown, ShutdownSignal}; use tari_storage::HashmapDatabase; -use tari_test_utils::{collect_stream, unpack_enum}; -use tokio::{sync::broadcast, task}; +use tari_test_utils::{collect_recv, collect_stream, unpack_enum}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + sync::{broadcast, mpsc, oneshot}, + task, +}; async fn spawn_node( protocols: Protocols, @@ -109,7 +106,7 @@ async fn spawn_node( (comms_node, inbound_rx, outbound_tx, messaging_events_sender) } -#[runtime::test_basic] +#[runtime::test] async fn peer_to_peer_custom_protocols() { static TEST_PROTOCOL: Bytes = Bytes::from_static(b"/tari/test"); static ANOTHER_TEST_PROTOCOL: Bytes = Bytes::from_static(b"/tari/test-again"); @@ -161,9 +158,9 @@ async fn peer_to_peer_custom_protocols() { // Check that both nodes get the PeerConnected event. We subscribe after the nodes are initialized // so we miss those events. - let next_event = conn_man_events2.next().await.unwrap().unwrap(); + let next_event = conn_man_events2.recv().await.unwrap(); unpack_enum!(ConnectionManagerEvent::PeerConnected(conn2) = &*next_event); - let next_event = conn_man_events1.next().await.unwrap().unwrap(); + let next_event = conn_man_events1.recv().await.unwrap(); unpack_enum!(ConnectionManagerEvent::PeerConnected(_conn) = &*next_event); // Let's speak both our test protocols @@ -176,7 +173,7 @@ async fn peer_to_peer_custom_protocols() { negotiated_substream2.stream.write_all(ANOTHER_TEST_MSG).await.unwrap(); // Read TEST_PROTOCOL message to node 2 from node 1 - let negotiation = test_protocol_rx2.next().await.unwrap(); + let negotiation = test_protocol_rx2.recv().await.unwrap(); assert_eq!(negotiation.protocol, TEST_PROTOCOL); unpack_enum!(ProtocolEvent::NewInboundSubstream(node_id, substream) = negotiation.event); assert_eq!(&node_id, node_identity1.node_id()); @@ -185,7 +182,7 @@ async fn peer_to_peer_custom_protocols() { assert_eq!(buf, TEST_MSG); // Read ANOTHER_TEST_PROTOCOL message to node 1 from node 2 - let negotiation = another_test_protocol_rx1.next().await.unwrap(); + let negotiation = another_test_protocol_rx1.recv().await.unwrap(); assert_eq!(negotiation.protocol, ANOTHER_TEST_PROTOCOL); unpack_enum!(ProtocolEvent::NewInboundSubstream(node_id, substream) = negotiation.event); assert_eq!(&node_id, node_identity2.node_id()); @@ -193,18 +190,18 @@ async fn peer_to_peer_custom_protocols() { substream.read_exact(&mut buf).await.unwrap(); assert_eq!(buf, ANOTHER_TEST_MSG); - shutdown.trigger().unwrap(); + shutdown.trigger(); comms_node1.wait_until_shutdown().await; comms_node2.wait_until_shutdown().await; } -#[runtime::test_basic] +#[runtime::test] async fn peer_to_peer_messaging() { const NUM_MSGS: usize = 100; let shutdown = Shutdown::new(); - let (comms_node1, mut inbound_rx1, mut outbound_tx1, _) = spawn_node(Protocols::new(), shutdown.to_signal()).await; - let (comms_node2, mut inbound_rx2, mut outbound_tx2, messaging_events2) = + let (comms_node1, mut inbound_rx1, outbound_tx1, _) = spawn_node(Protocols::new(), shutdown.to_signal()).await; + let (comms_node2, mut inbound_rx2, outbound_tx2, messaging_events2) = spawn_node(Protocols::new(), shutdown.to_signal()).await; let mut messaging_events2 = messaging_events2.subscribe(); @@ -238,14 +235,14 @@ async fn peer_to_peer_messaging() { outbound_tx1.send(outbound_msg).await.unwrap(); } - let messages1_to_2 = collect_stream!(inbound_rx2, take = NUM_MSGS, timeout = Duration::from_secs(10)); + let messages1_to_2 = collect_recv!(inbound_rx2, take = NUM_MSGS, timeout = Duration::from_secs(10)); let send_results = collect_stream!(replies, take = NUM_MSGS, timeout = Duration::from_secs(10)); send_results.into_iter().for_each(|r| { r.unwrap().unwrap(); }); - let events = collect_stream!(messaging_events2, take = NUM_MSGS, timeout = Duration::from_secs(10)); - events.into_iter().map(Result::unwrap).for_each(|m| { + let events = collect_recv!(messaging_events2, take = NUM_MSGS, timeout = Duration::from_secs(10)); + events.into_iter().for_each(|m| { unpack_enum!(MessagingEvent::MessageReceived(_n, _t) = &*m); }); @@ -258,7 +255,7 @@ async fn peer_to_peer_messaging() { outbound_tx2.send(outbound_msg).await.unwrap(); } - let messages2_to_1 = collect_stream!(inbound_rx1, take = NUM_MSGS, timeout = Duration::from_secs(10)); + let messages2_to_1 = collect_recv!(inbound_rx1, take = NUM_MSGS, timeout = Duration::from_secs(10)); // Check that we got all the messages let check_messages = |msgs: Vec| { @@ -279,13 +276,13 @@ async fn peer_to_peer_messaging() { comms_node2.wait_until_shutdown().await; } -#[runtime::test_basic] +#[runtime::test] async fn peer_to_peer_messaging_simultaneous() { const NUM_MSGS: usize = 10; let shutdown = Shutdown::new(); - let (comms_node1, mut inbound_rx1, mut outbound_tx1, _) = spawn_node(Protocols::new(), shutdown.to_signal()).await; - let (comms_node2, mut inbound_rx2, mut outbound_tx2, _) = spawn_node(Protocols::new(), shutdown.to_signal()).await; + let (comms_node1, mut inbound_rx1, outbound_tx1, _) = spawn_node(Protocols::new(), shutdown.to_signal()).await; + let (comms_node2, mut inbound_rx2, outbound_tx2, _) = spawn_node(Protocols::new(), shutdown.to_signal()).await; log::info!( "Peer1 = `{}`, Peer2 = `{}`", @@ -350,8 +347,8 @@ async fn peer_to_peer_messaging_simultaneous() { handle2.await.unwrap(); // Tasks are finished, let's see if all the messages made it though - let messages1_to_2 = collect_stream!(inbound_rx2, take = NUM_MSGS, timeout = Duration::from_secs(10)); - let messages2_to_1 = collect_stream!(inbound_rx1, take = NUM_MSGS, timeout = Duration::from_secs(10)); + let messages1_to_2 = collect_recv!(inbound_rx2, take = NUM_MSGS, timeout = Duration::from_secs(10)); + let messages2_to_1 = collect_recv!(inbound_rx1, take = NUM_MSGS, timeout = Duration::from_secs(10)); assert!(has_unique_elements(messages1_to_2.into_iter().map(|m| m.body))); assert!(has_unique_elements(messages2_to_1.into_iter().map(|m| m.body))); diff --git a/comms/src/compat.rs b/comms/src/compat.rs index 254876f14d..67b53c7c91 100644 --- a/comms/src/compat.rs +++ b/comms/src/compat.rs @@ -27,8 +27,9 @@ use std::{ io, pin::Pin, - task::{self, Poll}, + task::{self, Context, Poll}, }; +use tokio::io::ReadBuf; /// `IoCompat` provides a compatibility shim between the `AsyncRead`/`AsyncWrite` traits provided by /// the `futures` library and those provided by the `tokio` library since they are different and @@ -47,16 +48,16 @@ impl IoCompat { impl tokio::io::AsyncRead for IoCompat where T: futures::io::AsyncRead + Unpin { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context, buf: &mut [u8]) -> Poll> { - futures::io::AsyncRead::poll_read(Pin::new(&mut self.inner), cx, buf) + fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { + futures::io::AsyncRead::poll_read(Pin::new(&mut self.inner), cx, buf.filled_mut()) } } impl futures::io::AsyncRead for IoCompat where T: tokio::io::AsyncRead + Unpin { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context, buf: &mut [u8]) -> Poll> { - tokio::io::AsyncRead::poll_read(Pin::new(&mut self.inner), cx, buf) + fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context, buf: &mut [u8]) -> Poll> { + tokio::io::AsyncRead::poll_read(Pin::new(&mut self.inner), cx, &mut ReadBuf::new(buf)) } } diff --git a/comms/src/connection_manager/dial_state.rs b/comms/src/connection_manager/dial_state.rs index 0b07378747..07bbd1e631 100644 --- a/comms/src/connection_manager/dial_state.rs +++ b/comms/src/connection_manager/dial_state.rs @@ -24,8 +24,8 @@ use crate::{ connection_manager::{error::ConnectionManagerError, peer_connection::PeerConnection}, peer_manager::Peer, }; -use futures::channel::oneshot; use tari_shutdown::ShutdownSignal; +use tokio::sync::oneshot; /// The state of the dial request pub struct DialState { diff --git a/comms/src/connection_manager/dialer.rs b/comms/src/connection_manager/dialer.rs index 538ffe89bd..8264286db3 100644 --- a/comms/src/connection_manager/dialer.rs +++ b/comms/src/connection_manager/dialer.rs @@ -39,22 +39,22 @@ use crate::{ types::CommsPublicKey, }; use futures::{ - channel::{mpsc, oneshot}, future, future::{BoxFuture, Either, FusedFuture}, pin_mut, - stream::{Fuse, FuturesUnordered}, - AsyncRead, - AsyncWrite, - AsyncWriteExt, + stream::FuturesUnordered, FutureExt, - SinkExt, - StreamExt, }; use log::*; use std::{collections::HashMap, sync::Arc, time::Duration}; use tari_shutdown::{Shutdown, ShutdownSignal}; -use tokio::{task::JoinHandle, time}; +use tokio::{ + io::{AsyncRead, AsyncWrite, AsyncWriteExt}, + sync::{mpsc, oneshot}, + task::JoinHandle, + time, +}; +use tokio_stream::StreamExt; use tracing::{self, span, Instrument, Level}; const LOG_TARGET: &str = "comms::connection_manager::dialer"; @@ -79,7 +79,7 @@ pub struct Dialer { transport: TTransport, noise_config: NoiseConfig, backoff: Arc, - request_rx: Fuse>, + request_rx: mpsc::Receiver, cancel_signals: HashMap, conn_man_notifier: mpsc::Sender, shutdown: Option, @@ -112,7 +112,7 @@ where transport, noise_config, backoff: Arc::new(backoff), - request_rx: request_rx.fuse(), + request_rx, cancel_signals: Default::default(), conn_man_notifier, shutdown: Some(shutdown), @@ -139,16 +139,20 @@ where .expect("Establisher initialized without a shutdown"); debug!(target: LOG_TARGET, "Connection dialer started"); loop { - futures::select! { - request = self.request_rx.select_next_some() => self.handle_request(&mut pending_dials, request), - (dial_state, dial_result) = pending_dials.select_next_some() => { - self.handle_dial_result(dial_state, dial_result).await; - } - _ = shutdown => { + tokio::select! { + // Biased ordering is used because we already have the futures polled here in a fair order, and so wish to + // forgo the minor cost of the random ordering + biased; + + _ = &mut shutdown => { info!(target: LOG_TARGET, "Connection dialer shutting down because the shutdown signal was received"); self.cancel_all_dials(); break; } + Some((dial_state, dial_result)) = pending_dials.next() => { + self.handle_dial_result(dial_state, dial_result).await; + } + Some(request) = self.request_rx.recv() => self.handle_request(&mut pending_dials, request), } } } @@ -179,12 +183,7 @@ where self.cancel_signals.len() ); self.cancel_signals.drain().for_each(|(_, mut signal)| { - log_if_error_fmt!( - level: warn, - target: LOG_TARGET, - signal.trigger(), - "Shutdown trigger failed", - ); + signal.trigger(); }) } @@ -339,7 +338,7 @@ where } #[allow(clippy::too_many_arguments)] - #[tracing::instrument(skip(peer_manager, socket, conn_man_notifier, config, cancel_signal), err)] + #[tracing::instrument(skip(peer_manager, socket, conn_man_notifier, config, cancel_signal))] async fn perform_socket_upgrade_procedure( peer_manager: Arc, node_identity: Arc, @@ -352,7 +351,6 @@ where cancel_signal: ShutdownSignal, ) -> Result { static CONNECTION_DIRECTION: ConnectionDirection = ConnectionDirection::Outbound; - let mut muxer = Yamux::upgrade_connection(socket, CONNECTION_DIRECTION) .await .map_err(|err| ConnectionManagerError::YamuxUpgradeFailure(err.to_string()))?; @@ -448,9 +446,9 @@ where current_state.peer.node_id.short_str(), backoff_duration.as_secs() ); - let mut delay = time::delay_for(backoff_duration).fuse(); - let mut cancel_signal = current_state.get_cancel_signal(); - futures::select! { + let delay = time::sleep(backoff_duration).fuse(); + let cancel_signal = current_state.get_cancel_signal(); + tokio::select! { _ = delay => { debug!(target: LOG_TARGET, "[Attempt {}] Connecting to peer '{}'", current_state.num_attempts(), current_state.peer.node_id.short_str()); match Self::dial_peer(current_state, &noise_config, ¤t_transport, config.network_info.network_byte).await { @@ -544,18 +542,13 @@ where // Try the next address continue; }, - Either::Right((cancel_result, _)) => { + // Canceled + Either::Right(_) => { debug!( target: LOG_TARGET, "Dial for peer '{}' cancelled", dial_state.peer.node_id.short_str() ); - log_if_error!( - level: warn, - target: LOG_TARGET, - cancel_result, - "Cancel channel error during dial: {}", - ); Err(ConnectionManagerError::DialCancelled) }, } diff --git a/comms/src/connection_manager/error.rs b/comms/src/connection_manager/error.rs index f7bbbaa564..5645a57e62 100644 --- a/comms/src/connection_manager/error.rs +++ b/comms/src/connection_manager/error.rs @@ -21,13 +21,13 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{ + connection_manager::PeerConnectionRequest, noise, peer_manager::PeerManagerError, protocol::{IdentityProtocolError, ProtocolError}, }; -use futures::channel::mpsc; use thiserror::Error; -use tokio::{time, time::Elapsed}; +use tokio::{sync::mpsc, time::error::Elapsed}; #[derive(Debug, Error, Clone)] pub enum ConnectionManagerError { @@ -108,14 +108,14 @@ pub enum PeerConnectionError { #[error("Internal oneshot reply channel was unexpectedly cancelled")] InternalReplyCancelled, #[error("Failed to send internal request: {0}")] - InternalRequestSendFailed(#[from] mpsc::SendError), + InternalRequestSendFailed(#[from] mpsc::error::SendError), #[error("Protocol error: {0}")] ProtocolError(#[from] ProtocolError), #[error("Protocol negotiation timeout")] ProtocolNegotiationTimeout, } -impl From for PeerConnectionError { +impl From for PeerConnectionError { fn from(_: Elapsed) -> Self { PeerConnectionError::ProtocolNegotiationTimeout } diff --git a/comms/src/connection_manager/listener.rs b/comms/src/connection_manager/listener.rs index 60ae3c2d12..21c3771610 100644 --- a/comms/src/connection_manager/listener.rs +++ b/comms/src/connection_manager/listener.rs @@ -32,7 +32,6 @@ use crate::{ bounded_executor::BoundedExecutor, connection_manager::{ liveness::LivenessSession, - types::OneshotTrigger, wire_mode::{WireMode, LIVENESS_WIRE_MODE}, }, multiaddr::Multiaddr, @@ -46,17 +45,7 @@ use crate::{ utils::multiaddr::multiaddr_to_socketaddr, PeerManager, }; -use futures::{ - channel::mpsc, - future, - AsyncRead, - AsyncReadExt, - AsyncWrite, - AsyncWriteExt, - FutureExt, - SinkExt, - StreamExt, -}; +use futures::{future, FutureExt}; use log::*; use std::{ convert::TryInto, @@ -69,8 +58,13 @@ use std::{ time::Duration, }; use tari_crypto::tari_utilities::hex::Hex; -use tari_shutdown::ShutdownSignal; -use tokio::time; +use tari_shutdown::{oneshot_trigger, oneshot_trigger::OneshotTrigger, ShutdownSignal}; +use tokio::{ + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, + sync::mpsc, + time, +}; +use tokio_stream::StreamExt; use tracing::{span, Instrument, Level}; const LOG_TARGET: &str = "comms::connection_manager::listener"; @@ -118,7 +112,7 @@ where bounded_executor: BoundedExecutor::from_current(config.max_simultaneous_inbound_connects), liveness_session_count: Arc::new(AtomicUsize::new(config.liveness_max_sessions)), config, - on_listening: OneshotTrigger::new(), + on_listening: oneshot_trigger::channel(), } } @@ -128,7 +122,7 @@ where // 'static lifetime as well as to flatten the oneshot result for ergonomics pub fn on_listening(&self) -> impl Future> + 'static { let signal = self.on_listening.to_signal(); - signal.map(|r| r.map_err(|_| ConnectionManagerError::ListenerOneshotCancelled)?) + signal.map(|r| r.ok_or(ConnectionManagerError::ListenerOneshotCancelled)?) } /// Set the supported protocols of this node to send to peers during the peer identity exchange @@ -147,31 +141,30 @@ where let mut shutdown_signal = self.shutdown_signal.clone(); match self.bind().await { - Ok((inbound, address)) => { + Ok((mut inbound, address)) => { info!(target: LOG_TARGET, "Listening for peer connections on '{}'", address); - self.on_listening.trigger(Ok(address)); - - let inbound = inbound.fuse(); - futures::pin_mut!(inbound); + self.on_listening.broadcast(Ok(address)); loop { - futures::select! { - inbound_result = inbound.select_next_some() => { + tokio::select! { + biased; + + _ = &mut shutdown_signal => { + info!(target: LOG_TARGET, "PeerListener is shutting down because the shutdown signal was triggered"); + break; + }, + Some(inbound_result) = inbound.next() => { if let Some((socket, peer_addr)) = log_if_error!(target: LOG_TARGET, inbound_result, "Inbound connection failed because '{error}'",) { self.spawn_listen_task(socket, peer_addr).await; } }, - _ = shutdown_signal => { - info!(target: LOG_TARGET, "PeerListener is shutting down because the shutdown signal was triggered"); - break; - }, } } }, Err(err) => { warn!(target: LOG_TARGET, "PeerListener was unable to start because '{}'", err); - self.on_listening.trigger(Err(err)); + self.on_listening.broadcast(Err(err)); }, } } @@ -238,7 +231,7 @@ where async fn spawn_listen_task(&self, mut socket: TTransport::Output, peer_addr: Multiaddr) { let node_identity = self.node_identity.clone(); let peer_manager = self.peer_manager.clone(); - let mut conn_man_notifier = self.conn_man_notifier.clone(); + let conn_man_notifier = self.conn_man_notifier.clone(); let noise_config = self.noise_config.clone(); let config = self.config.clone(); let our_supported_protocols = self.our_supported_protocols.clone(); @@ -318,7 +311,7 @@ where "No liveness sessions available or permitted for peer address '{}'", peer_addr ); - let _ = socket.close().await; + let _ = socket.shutdown().await; } }, Err(err) => { diff --git a/comms/src/connection_manager/liveness.rs b/comms/src/connection_manager/liveness.rs index 3c06889307..75ee2db13f 100644 --- a/comms/src/connection_manager/liveness.rs +++ b/comms/src/connection_manager/liveness.rs @@ -20,15 +20,16 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::compat::IoCompat; -use futures::{AsyncRead, AsyncWrite, Future, StreamExt}; +use futures::StreamExt; +use std::future::Future; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::codec::{Framed, LinesCodec, LinesCodecError}; /// Max line length accepted by the liveness session. const MAX_LINE_LENGTH: usize = 50; pub struct LivenessSession { - framed: Framed, LinesCodec>, + framed: Framed, } impl LivenessSession @@ -36,7 +37,7 @@ where TSocket: AsyncRead + AsyncWrite + Unpin { pub fn new(socket: TSocket) -> Self { Self { - framed: Framed::new(IoCompat::new(socket), LinesCodec::new_with_max_length(MAX_LINE_LENGTH)), + framed: Framed::new(socket, LinesCodec::new_with_max_length(MAX_LINE_LENGTH)), } } @@ -52,13 +53,14 @@ mod test { use crate::{memsocket::MemorySocket, runtime}; use futures::SinkExt; use tokio::{time, time::Duration}; + use tokio_stream::StreamExt; - #[runtime::test_basic] + #[runtime::test] async fn echos() { let (inbound, outbound) = MemorySocket::new_pair(); let liveness = LivenessSession::new(inbound); let join_handle = runtime::current().spawn(liveness.run()); - let mut outbound = Framed::new(IoCompat::new(outbound), LinesCodec::new()); + let mut outbound = Framed::new(outbound, LinesCodec::new()); for _ in 0..10usize { outbound.send("ECHO".to_string()).await.unwrap() } diff --git a/comms/src/connection_manager/manager.rs b/comms/src/connection_manager/manager.rs index 928b53611f..0c6de18f59 100644 --- a/comms/src/connection_manager/manager.rs +++ b/comms/src/connection_manager/manager.rs @@ -36,20 +36,17 @@ use crate::{ transports::{TcpTransport, Transport}, PeerManager, }; -use futures::{ - channel::{mpsc, oneshot}, - stream::Fuse, - AsyncRead, - AsyncWrite, - SinkExt, - StreamExt, -}; use log::*; use multiaddr::Multiaddr; use std::{fmt, sync::Arc}; use tari_shutdown::{Shutdown, ShutdownSignal}; use time::Duration; -use tokio::{sync::broadcast, task, time}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{broadcast, mpsc, oneshot}, + task, + time, +}; use tracing::{span, Instrument, Level}; const LOG_TARGET: &str = "comms::connection_manager::manager"; @@ -156,8 +153,8 @@ impl ListenerInfo { } pub struct ConnectionManager { - request_rx: Fuse>, - internal_event_rx: Fuse>, + request_rx: mpsc::Receiver, + internal_event_rx: mpsc::Receiver, dialer_tx: mpsc::Sender, dialer: Option>, listener: Option>, @@ -230,10 +227,10 @@ where Self { shutdown_signal: Some(shutdown_signal), - request_rx: request_rx.fuse(), + request_rx, peer_manager, protocols: Protocols::new(), - internal_event_rx: internal_event_rx.fuse(), + internal_event_rx, dialer_tx, dialer: Some(dialer), listener: Some(listener), @@ -266,7 +263,7 @@ where .take() .expect("ConnectionManager initialized without a shutdown"); - // Runs the listeners, waiting for a + // Runs the listeners. Sockets are bound and ready once this resolves match self.run_listeners().await { Ok(info) => { self.listener_info = Some(info); @@ -293,16 +290,16 @@ where .join(", ") ); loop { - futures::select! { - event = self.internal_event_rx.select_next_some() => { + tokio::select! { + Some(event) = self.internal_event_rx.recv() => { self.handle_event(event).await; }, - request = self.request_rx.select_next_some() => { + Some(request) = self.request_rx.recv() => { self.handle_request(request).await; }, - _ = shutdown => { + _ = &mut shutdown => { info!(target: LOG_TARGET, "ConnectionManager is shutting down because it received the shutdown signal"); break; } diff --git a/comms/src/connection_manager/peer_connection.rs b/comms/src/connection_manager/peer_connection.rs index 6f0d90da5d..cceb6ae9bd 100644 --- a/comms/src/connection_manager/peer_connection.rs +++ b/comms/src/connection_manager/peer_connection.rs @@ -44,20 +44,21 @@ use crate::{ protocol::{ProtocolId, ProtocolNegotiation}, runtime, }; -use futures::{ - channel::{mpsc, oneshot}, - stream::Fuse, - SinkExt, - StreamExt, -}; use log::*; use multiaddr::Multiaddr; use std::{ fmt, - sync::atomic::{AtomicUsize, Ordering}, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, time::{Duration, Instant}, }; -use tokio::time; +use tokio::{ + sync::{mpsc, oneshot}, + time, +}; +use tokio_stream::StreamExt; use tracing::{self, span, Instrument, Level, Span}; const LOG_TARGET: &str = "comms::connection_manager::peer_connection"; @@ -130,7 +131,7 @@ pub struct PeerConnection { peer_node_id: NodeId, peer_features: PeerFeatures, request_tx: mpsc::Sender, - address: Multiaddr, + address: Arc, direction: ConnectionDirection, started_at: Instant, substream_counter: SubstreamCounter, @@ -151,7 +152,7 @@ impl PeerConnection { request_tx, peer_node_id, peer_features, - address, + address: Arc::new(address), direction, started_at: Instant::now(), substream_counter, @@ -190,7 +191,7 @@ impl PeerConnection { self.substream_counter.get() } - #[tracing::instrument("peer_connection::open_substream", skip(self), err)] + #[tracing::instrument("peer_connection::open_substream", skip(self))] pub async fn open_substream( &mut self, protocol_id: &ProtocolId, @@ -208,7 +209,7 @@ impl PeerConnection { .map_err(|_| PeerConnectionError::InternalReplyCancelled)? } - #[tracing::instrument("peer_connection::open_framed_substream", skip(self), err)] + #[tracing::instrument("peer_connection::open_framed_substream", skip(self))] pub async fn open_framed_substream( &mut self, protocol_id: &ProtocolId, @@ -219,14 +220,14 @@ impl PeerConnection { } #[cfg(feature = "rpc")] - #[tracing::instrument("peer_connection::connect_rpc", skip(self), fields(peer_node_id = self.peer_node_id.to_string().as_str()), err)] + #[tracing::instrument("peer_connection::connect_rpc", skip(self), fields(peer_node_id = self.peer_node_id.to_string().as_str()))] pub async fn connect_rpc(&mut self) -> Result where T: From + NamedProtocolService { self.connect_rpc_using_builder(Default::default()).await } #[cfg(feature = "rpc")] - #[tracing::instrument("peer_connection::connect_rpc_with_builder", skip(self, builder), err)] + #[tracing::instrument("peer_connection::connect_rpc_with_builder", skip(self, builder))] pub async fn connect_rpc_using_builder(&mut self, builder: RpcClientBuilder) -> Result where T: From + NamedProtocolService { let protocol = ProtocolId::from_static(T::PROTOCOL_NAME); @@ -301,9 +302,9 @@ impl PartialEq for PeerConnection { struct PeerConnectionActor { id: ConnectionId, peer_node_id: NodeId, - request_rx: Fuse>, + request_rx: mpsc::Receiver, direction: ConnectionDirection, - incoming_substreams: Fuse, + incoming_substreams: IncomingSubstreams, control: Control, event_notifier: mpsc::Sender, our_supported_protocols: Vec, @@ -327,8 +328,8 @@ impl PeerConnectionActor { peer_node_id, direction, control: connection.get_yamux_control(), - incoming_substreams: connection.incoming().fuse(), - request_rx: request_rx.fuse(), + incoming_substreams: connection.incoming(), + request_rx, event_notifier, our_supported_protocols, their_supported_protocols, @@ -337,8 +338,8 @@ impl PeerConnectionActor { pub async fn run(mut self) { loop { - futures::select! { - request = self.request_rx.select_next_some() => self.handle_request(request).await, + tokio::select! { + Some(request) = self.request_rx.recv() => self.handle_request(request).await, maybe_substream = self.incoming_substreams.next() => { match maybe_substream { @@ -362,7 +363,7 @@ impl PeerConnectionActor { } } } - self.request_rx.get_mut().close(); + self.request_rx.close(); } async fn handle_request(&mut self, request: PeerConnectionRequest) { @@ -396,7 +397,7 @@ impl PeerConnectionActor { } } - #[tracing::instrument(skip(self, stream), err, fields(comms.direction="inbound"))] + #[tracing::instrument(skip(self, stream),fields(comms.direction="inbound"))] async fn handle_incoming_substream(&mut self, mut stream: Substream) -> Result<(), PeerConnectionError> { let selected_protocol = ProtocolNegotiation::new(&mut stream) .negotiate_protocol_inbound(&self.our_supported_protocols) @@ -412,7 +413,7 @@ impl PeerConnectionActor { Ok(()) } - #[tracing::instrument(skip(self), err)] + #[tracing::instrument(skip(self))] async fn open_negotiated_protocol_stream( &mut self, protocol: ProtocolId, diff --git a/comms/src/connection_manager/requester.rs b/comms/src/connection_manager/requester.rs index 1f3f5cc887..0007e59228 100644 --- a/comms/src/connection_manager/requester.rs +++ b/comms/src/connection_manager/requester.rs @@ -25,12 +25,8 @@ use crate::{ connection_manager::manager::{ConnectionManagerEvent, ListenerInfo}, peer_manager::NodeId, }; -use futures::{ - channel::{mpsc, oneshot}, - SinkExt, -}; use std::sync::Arc; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc, oneshot}; /// Requests which are handled by the ConnectionManagerService #[derive(Debug)] @@ -78,7 +74,7 @@ impl ConnectionManagerRequester { } /// Attempt to connect to a remote peer - #[tracing::instrument(skip(self), err)] + #[tracing::instrument(skip(self))] pub async fn dial_peer(&mut self, node_id: NodeId) -> Result { let (reply_tx, reply_rx) = oneshot::channel(); self.send_dial_peer(node_id, Some(reply_tx)).await?; @@ -97,7 +93,7 @@ impl ConnectionManagerRequester { } /// Send instruction to ConnectionManager to dial a peer and return the result on the given oneshot - #[tracing::instrument(skip(self, reply_tx), err)] + #[tracing::instrument(skip(self, reply_tx))] pub(crate) async fn send_dial_peer( &mut self, node_id: NodeId, @@ -124,7 +120,7 @@ impl ConnectionManagerRequester { } /// Send instruction to ConnectionManager to dial a peer without waiting for a result. - #[tracing::instrument(skip(self), err)] + #[tracing::instrument(skip(self))] pub(crate) async fn send_dial_peer_no_reply(&mut self, node_id: NodeId) -> Result<(), ConnectionManagerError> { self.send_dial_peer(node_id, None).await?; Ok(()) diff --git a/comms/src/connection_manager/tests/listener_dialer.rs b/comms/src/connection_manager/tests/listener_dialer.rs index 586c0d5ec4..25e715bf50 100644 --- a/comms/src/connection_manager/tests/listener_dialer.rs +++ b/comms/src/connection_manager/tests/listener_dialer.rs @@ -36,20 +36,17 @@ use crate::{ test_utils::{node_identity::build_node_identity, test_node::build_peer_manager}, transports::MemoryTransport, }; -use futures::{ - channel::{mpsc, oneshot}, - AsyncReadExt, - AsyncWriteExt, - SinkExt, - StreamExt, -}; use multiaddr::Protocol; use std::{error::Error, time::Duration}; use tari_shutdown::Shutdown; use tari_test_utils::unpack_enum; -use tokio::time::timeout; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + sync::{mpsc, oneshot}, + time::timeout, +}; -#[runtime::test_basic] +#[runtime::test] async fn listen() -> Result<(), Box> { let (event_tx, _) = mpsc::channel(1); let mut shutdown = Shutdown::new(); @@ -61,7 +58,7 @@ async fn listen() -> Result<(), Box> { "/memory/0".parse()?, MemoryTransport, noise_config.clone(), - event_tx.clone(), + event_tx, peer_manager, node_identity, shutdown.to_signal(), @@ -72,12 +69,12 @@ async fn listen() -> Result<(), Box> { unpack_enum!(Protocol::Memory(port) = bind_addr.pop().unwrap()); assert!(port > 0); - shutdown.trigger().unwrap(); + shutdown.trigger(); Ok(()) } -#[runtime::test_basic] +#[runtime::test] async fn smoke() { let rt_handle = runtime::current(); // This test sets up Dialer and Listener components, uses the Dialer to dial the Listener, @@ -108,7 +105,7 @@ async fn smoke() { let node_identity2 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); let noise_config2 = NoiseConfig::new(node_identity2.clone()); - let (mut request_tx, request_rx) = mpsc::channel(1); + let (request_tx, request_rx) = mpsc::channel(1); let peer_manager2 = build_peer_manager(); let mut dialer = Dialer::new( ConnectionManagerConfig::default(), @@ -148,11 +145,11 @@ async fn smoke() { } // Read PeerConnected events - we don't know which connection is which - unpack_enum!(ConnectionManagerEvent::PeerConnected(conn1) = event_rx.next().await.unwrap()); - unpack_enum!(ConnectionManagerEvent::PeerConnected(_conn2) = event_rx.next().await.unwrap()); + unpack_enum!(ConnectionManagerEvent::PeerConnected(conn1) = event_rx.recv().await.unwrap()); + unpack_enum!(ConnectionManagerEvent::PeerConnected(_conn2) = event_rx.recv().await.unwrap()); // Next event should be a NewInboundSubstream has been received - let listen_event = event_rx.next().await.unwrap(); + let listen_event = event_rx.recv().await.unwrap(); { unpack_enum!(ConnectionManagerEvent::NewInboundSubstream(node_id, proto, in_stream) = listen_event); assert_eq!(&*node_id, node_identity2.node_id()); @@ -165,7 +162,7 @@ async fn smoke() { conn1.disconnect().await.unwrap(); - shutdown.trigger().unwrap(); + shutdown.trigger(); let peer2 = peer_manager1.find_by_node_id(node_identity2.node_id()).await.unwrap(); let peer1 = peer_manager2.find_by_node_id(node_identity1.node_id()).await.unwrap(); @@ -176,7 +173,7 @@ async fn smoke() { timeout(Duration::from_secs(5), dialer_fut).await.unwrap().unwrap(); } -#[runtime::test_basic] +#[runtime::test] async fn banned() { let rt_handle = runtime::current(); let (event_tx, mut event_rx) = mpsc::channel(10); @@ -209,7 +206,7 @@ async fn banned() { peer_manager1.add_peer(peer).await.unwrap(); let noise_config2 = NoiseConfig::new(node_identity2.clone()); - let (mut request_tx, request_rx) = mpsc::channel(1); + let (request_tx, request_rx) = mpsc::channel(1); let peer_manager2 = build_peer_manager(); let mut dialer = Dialer::new( ConnectionManagerConfig::default(), @@ -241,10 +238,10 @@ async fn banned() { let err = reply_rx.await.unwrap().unwrap_err(); unpack_enum!(ConnectionManagerError::IdentityProtocolError(_err) = err); - unpack_enum!(ConnectionManagerEvent::PeerInboundConnectFailed(err) = event_rx.next().await.unwrap()); + unpack_enum!(ConnectionManagerEvent::PeerInboundConnectFailed(err) = event_rx.recv().await.unwrap()); unpack_enum!(ConnectionManagerError::PeerBanned = err); - shutdown.trigger().unwrap(); + shutdown.trigger(); timeout(Duration::from_secs(5), dialer_fut).await.unwrap().unwrap(); } diff --git a/comms/src/connection_manager/tests/manager.rs b/comms/src/connection_manager/tests/manager.rs index bca876eff4..910280b4cf 100644 --- a/comms/src/connection_manager/tests/manager.rs +++ b/comms/src/connection_manager/tests/manager.rs @@ -41,19 +41,17 @@ use crate::{ }, transports::{MemoryTransport, TcpTransport}, }; -use futures::{ - channel::{mpsc, oneshot}, - future, - AsyncReadExt, - AsyncWriteExt, - StreamExt, -}; +use futures::future; use std::time::Duration; use tari_shutdown::Shutdown; -use tari_test_utils::{collect_stream, unpack_enum}; -use tokio::{runtime::Handle, sync::broadcast}; +use tari_test_utils::{collect_try_recv, unpack_enum}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + runtime::Handle, + sync::{broadcast, mpsc, oneshot}, +}; -#[runtime::test_basic] +#[runtime::test] async fn connect_to_nonexistent_peer() { let rt_handle = Handle::current(); let node_identity = build_node_identity(PeerFeatures::empty()); @@ -83,10 +81,10 @@ async fn connect_to_nonexistent_peer() { unpack_enum!(ConnectionManagerError::PeerManagerError(err) = err); unpack_enum!(PeerManagerError::PeerNotFoundError = err); - shutdown.trigger().unwrap(); + shutdown.trigger(); } -#[runtime::test_basic] +#[runtime::test] async fn dial_success() { static TEST_PROTO: ProtocolId = ProtocolId::from_static(b"/test/valid"); let shutdown = Shutdown::new(); @@ -159,7 +157,7 @@ async fn dial_success() { assert_eq!(peer2.supported_protocols, [&IDENTITY_PROTOCOL, &TEST_PROTO]); assert_eq!(peer2.user_agent, "node2"); - let event = subscription2.next().await.unwrap().unwrap(); + let event = subscription2.recv().await.unwrap(); unpack_enum!(ConnectionManagerEvent::PeerConnected(conn_in) = &*event); assert_eq!(conn_in.peer_node_id(), node_identity1.node_id()); @@ -179,7 +177,7 @@ async fn dial_success() { const MSG: &[u8] = b"Welease Woger!"; substream_out.stream.write_all(MSG).await.unwrap(); - let protocol_in = proto_rx2.next().await.unwrap(); + let protocol_in = proto_rx2.recv().await.unwrap(); assert_eq!(protocol_in.protocol, &TEST_PROTO); unpack_enum!(ProtocolEvent::NewInboundSubstream(node_id, substream_in) = protocol_in.event); assert_eq!(&node_id, node_identity1.node_id()); @@ -189,7 +187,7 @@ async fn dial_success() { assert_eq!(buf, MSG); } -#[runtime::test_basic] +#[runtime::test] async fn dial_success_aux_tcp_listener() { static TEST_PROTO: ProtocolId = ProtocolId::from_static(b"/test/valid"); let shutdown = Shutdown::new(); @@ -271,7 +269,7 @@ async fn dial_success_aux_tcp_listener() { const MSG: &[u8] = b"Welease Woger!"; substream_out.stream.write_all(MSG).await.unwrap(); - let protocol_in = proto_rx1.next().await.unwrap(); + let protocol_in = proto_rx1.recv().await.unwrap(); assert_eq!(protocol_in.protocol, &TEST_PROTO); unpack_enum!(ProtocolEvent::NewInboundSubstream(node_id, substream_in) = protocol_in.event); assert_eq!(&node_id, node_identity2.node_id()); @@ -281,7 +279,7 @@ async fn dial_success_aux_tcp_listener() { assert_eq!(buf, MSG); } -#[runtime::test_basic] +#[runtime::test] async fn simultaneous_dial_events() { let mut shutdown = Shutdown::new(); @@ -360,29 +358,22 @@ async fn simultaneous_dial_events() { _ => panic!("unexpected simultaneous dial result"), } - let event = subscription2.next().await.unwrap().unwrap(); + let event = subscription2.recv().await.unwrap(); assert!(count_string_occurrences(&[event], &["PeerConnected", "PeerInboundConnectFailed"]) >= 1); - shutdown.trigger().unwrap(); + shutdown.trigger(); drop(conn_man1); drop(conn_man2); - let _events1 = collect_stream!(subscription1, timeout = Duration::from_secs(5)) - .into_iter() - .map(Result::unwrap) - .collect::>(); - - let _events2 = collect_stream!(subscription2, timeout = Duration::from_secs(5)) - .into_iter() - .map(Result::unwrap) - .collect::>(); + let _events1 = collect_try_recv!(subscription1, timeout = Duration::from_secs(5)); + let _events2 = collect_try_recv!(subscription2, timeout = Duration::from_secs(5)); // TODO: Investigate why two PeerDisconnected events are sometimes received // assert!(count_string_occurrences(&events1, &["PeerDisconnected"]) >= 1); // assert!(count_string_occurrences(&events2, &["PeerDisconnected"]) >= 1); } -#[tokio_macros::test_basic] +#[runtime::test] async fn dial_cancelled() { let mut shutdown = Shutdown::new(); @@ -429,13 +420,10 @@ async fn dial_cancelled() { let err = dial_result.await.unwrap().unwrap_err(); unpack_enum!(ConnectionManagerError::DialCancelled = err); - shutdown.trigger().unwrap(); + shutdown.trigger(); drop(conn_man1); - let events1 = collect_stream!(subscription1, timeout = Duration::from_secs(5)) - .into_iter() - .map(Result::unwrap) - .collect::>(); + let events1 = collect_try_recv!(subscription1, timeout = Duration::from_secs(5)); assert_eq!(events1.len(), 1); unpack_enum!(ConnectionManagerEvent::PeerConnectFailed(node_id, err) = &*events1[0]); diff --git a/comms/src/connection_manager/types.rs b/comms/src/connection_manager/types.rs index ddb8f6de8f..c92b2b717a 100644 --- a/comms/src/connection_manager/types.rs +++ b/comms/src/connection_manager/types.rs @@ -20,11 +20,6 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::{ - channel::oneshot, - future::{Fuse, Shared}, - FutureExt, -}; use std::fmt; /// Direction of the connection relative to this node @@ -47,29 +42,3 @@ impl fmt::Display for ConnectionDirection { write!(f, "{:?}", self) } } - -pub type OneshotSignal = Shared>>; -pub struct OneshotTrigger(Option>, OneshotSignal); - -impl OneshotTrigger { - pub fn new() -> Self { - let (tx, rx) = oneshot::channel(); - Self(Some(tx), rx.fuse().shared()) - } - - pub fn to_signal(&self) -> OneshotSignal { - self.1.clone() - } - - pub fn trigger(&mut self, item: T) { - if let Some(tx) = self.0.take() { - let _ = tx.send(item); - } - } -} - -impl Default for OneshotTrigger { - fn default() -> Self { - Self::new() - } -} diff --git a/comms/src/connectivity/manager.rs b/comms/src/connectivity/manager.rs index 35f37627c4..821253e84b 100644 --- a/comms/src/connectivity/manager.rs +++ b/comms/src/connectivity/manager.rs @@ -34,6 +34,7 @@ use crate::{ ConnectionManagerEvent, ConnectionManagerRequester, }, + connectivity::ConnectivityEventTx, peer_manager::NodeId, runtime::task, utils::datetime::format_duration, @@ -41,7 +42,6 @@ use crate::{ PeerConnection, PeerManager, }; -use futures::{channel::mpsc, stream::Fuse, StreamExt}; use log::*; use nom::lib::std::collections::hash_map::Entry; use std::{ @@ -52,7 +52,7 @@ use std::{ time::{Duration, Instant}, }; use tari_shutdown::ShutdownSignal; -use tokio::{sync::broadcast, task::JoinHandle, time}; +use tokio::{sync::mpsc, task::JoinHandle, time}; use tracing::{span, Instrument, Level}; const LOG_TARGET: &str = "comms::connectivity::manager"; @@ -72,7 +72,7 @@ const LOG_TARGET: &str = "comms::connectivity::manager"; pub struct ConnectivityManager { pub config: ConnectivityConfig, pub request_rx: mpsc::Receiver, - pub event_tx: broadcast::Sender>, + pub event_tx: ConnectivityEventTx, pub connection_manager: ConnectionManagerRequester, pub peer_manager: Arc, pub node_identity: Arc, @@ -84,7 +84,7 @@ impl ConnectivityManager { ConnectivityManagerActor { config: self.config, status: ConnectivityStatus::Initializing, - request_rx: self.request_rx.fuse(), + request_rx: self.request_rx, connection_manager: self.connection_manager, peer_manager: self.peer_manager.clone(), event_tx: self.event_tx, @@ -140,12 +140,12 @@ impl fmt::Display for ConnectivityStatus { pub struct ConnectivityManagerActor { config: ConnectivityConfig, status: ConnectivityStatus, - request_rx: Fuse>, + request_rx: mpsc::Receiver, connection_manager: ConnectionManagerRequester, node_identity: Arc, shutdown_signal: Option, peer_manager: Arc, - event_tx: broadcast::Sender>, + event_tx: ConnectivityEventTx, connection_stats: HashMap, managed_peers: Vec, @@ -165,7 +165,7 @@ impl ConnectivityManagerActor { .take() .expect("ConnectivityManager initialized without a shutdown_signal"); - let mut connection_manager_events = self.connection_manager.get_event_subscription().fuse(); + let mut connection_manager_events = self.connection_manager.get_event_subscription(); let interval = self.config.connection_pool_refresh_interval; let mut ticker = time::interval_at( @@ -174,18 +174,17 @@ impl ConnectivityManagerActor { .expect("connection_pool_refresh_interval cause overflow") .into(), interval, - ) - .fuse(); + ); self.publish_event(ConnectivityEvent::ConnectivityStateInitialized); loop { - futures::select! { - req = self.request_rx.select_next_some() => { + tokio::select! { + Some(req) = self.request_rx.recv() => { self.handle_request(req).await; }, - event = connection_manager_events.select_next_some() => { + event = connection_manager_events.recv() => { if let Ok(event) = event { if let Err(err) = self.handle_connection_manager_event(&event).await { error!(target:LOG_TARGET, "Error handling connection manager event: {:?}", err); @@ -193,13 +192,13 @@ impl ConnectivityManagerActor { } }, - _ = ticker.next() => { + _ = ticker.tick() => { if let Err(err) = self.refresh_connection_pool().await { error!(target: LOG_TARGET, "Error when refreshing connection pools: {:?}", err); } }, - _ = shutdown_signal => { + _ = &mut shutdown_signal => { info!(target: LOG_TARGET, "ConnectivityManager is shutting down because it received the shutdown signal"); self.disconnect_all().await; break; @@ -823,7 +822,7 @@ impl ConnectivityManagerActor { fn publish_event(&mut self, event: ConnectivityEvent) { // A send operation can only fail if there are no subscribers, so it is safe to ignore the error - let _ = self.event_tx.send(Arc::new(event)); + let _ = self.event_tx.send(event); } async fn ban_peer( @@ -863,7 +862,7 @@ impl ConnectivityManagerActor { fn delayed_close(conn: PeerConnection, delay: Duration) { task::spawn(async move { - time::delay_for(delay).await; + time::sleep(delay).await; debug!( target: LOG_TARGET, "Closing connection from peer `{}` after delay", diff --git a/comms/src/connectivity/requester.rs b/comms/src/connectivity/requester.rs index 740c8a6c81..8fb86320b4 100644 --- a/comms/src/connectivity/requester.rs +++ b/comms/src/connectivity/requester.rs @@ -31,23 +31,21 @@ use crate::{ peer_manager::NodeId, PeerConnection, }; -use futures::{ - channel::{mpsc, oneshot}, - SinkExt, - StreamExt, -}; use log::*; use std::{ fmt, - sync::Arc, time::{Duration, Instant}, }; -use tokio::{sync::broadcast, time}; +use tokio::{ + sync::{broadcast, broadcast::error::RecvError, mpsc, oneshot}, + time, +}; + const LOG_TARGET: &str = "comms::connectivity::requester"; use tracing; -pub type ConnectivityEventRx = broadcast::Receiver>; -pub type ConnectivityEventTx = broadcast::Sender>; +pub type ConnectivityEventRx = broadcast::Receiver; +pub type ConnectivityEventTx = broadcast::Sender; #[derive(Debug, Clone)] pub enum ConnectivityEvent { @@ -128,7 +126,7 @@ impl ConnectivityRequester { self.event_tx.clone() } - #[tracing::instrument(skip(self), err)] + #[tracing::instrument(skip(self))] pub async fn dial_peer(&mut self, peer: NodeId) -> Result { let mut num_cancels = 0; loop { @@ -264,24 +262,23 @@ impl ConnectivityRequester { let mut last_known_peer_count = status.num_connected_nodes(); loop { debug!(target: LOG_TARGET, "Waiting for connectivity event"); - let recv_result = time::timeout(remaining, connectivity_events.next()) + let recv_result = time::timeout(remaining, connectivity_events.recv()) .await - .map_err(|_| ConnectivityError::OnlineWaitTimeout(last_known_peer_count))? - .ok_or(ConnectivityError::ConnectivityEventStreamClosed)?; + .map_err(|_| ConnectivityError::OnlineWaitTimeout(last_known_peer_count))?; remaining = timeout .checked_sub(start.elapsed()) .ok_or(ConnectivityError::OnlineWaitTimeout(last_known_peer_count))?; match recv_result { - Ok(event) => match &*event { + Ok(event) => match event { ConnectivityEvent::ConnectivityStateOnline(_) => { info!(target: LOG_TARGET, "Connectivity is ONLINE."); break Ok(()); }, ConnectivityEvent::ConnectivityStateDegraded(n) => { warn!(target: LOG_TARGET, "Connectivity is DEGRADED ({} peer(s))", n); - last_known_peer_count = *n; + last_known_peer_count = n; }, ConnectivityEvent::ConnectivityStateOffline => { warn!( @@ -297,14 +294,14 @@ impl ConnectivityRequester { ); }, }, - Err(broadcast::RecvError::Closed) => { + Err(RecvError::Closed) => { error!( target: LOG_TARGET, "Connectivity event stream closed unexpectedly. System may be shutting down." ); break Err(ConnectivityError::ConnectivityEventStreamClosed); }, - Err(broadcast::RecvError::Lagged(n)) => { + Err(RecvError::Lagged(n)) => { warn!(target: LOG_TARGET, "Lagging behind on {} connectivity event(s)", n); // We lagged, so could have missed the state change. Check it explicitly. let status = self.get_connectivity_status().await?; diff --git a/comms/src/connectivity/selection.rs b/comms/src/connectivity/selection.rs index 931b78163d..c300d2d353 100644 --- a/comms/src/connectivity/selection.rs +++ b/comms/src/connectivity/selection.rs @@ -135,8 +135,8 @@ mod test { peer_manager::node_id::NodeDistance, test_utils::{mocks::create_dummy_peer_connection, node_id, node_identity::build_node_identity}, }; - use futures::channel::mpsc; use std::iter::repeat_with; + use tokio::sync::mpsc; fn create_pool_with_connections(n: usize) -> (ConnectionPool, Vec>) { let mut pool = ConnectionPool::new(); diff --git a/comms/src/connectivity/test.rs b/comms/src/connectivity/test.rs index a4fec1e896..948d083e94 100644 --- a/comms/src/connectivity/test.rs +++ b/comms/src/connectivity/test.rs @@ -28,6 +28,7 @@ use super::{ }; use crate::{ connection_manager::{ConnectionManagerError, ConnectionManagerEvent}, + connectivity::ConnectivityEventRx, peer_manager::{Peer, PeerFeatures}, runtime, runtime::task, @@ -39,18 +40,18 @@ use crate::{ NodeIdentity, PeerManager, }; -use futures::{channel::mpsc, future}; +use futures::future; use std::{sync::Arc, time::Duration}; use tari_shutdown::Shutdown; -use tari_test_utils::{collect_stream, streams, unpack_enum}; -use tokio::sync::broadcast; +use tari_test_utils::{collect_try_recv, streams, unpack_enum}; +use tokio::sync::{broadcast, mpsc}; #[allow(clippy::type_complexity)] fn setup_connectivity_manager( config: ConnectivityConfig, ) -> ( ConnectivityRequester, - broadcast::Receiver>, + ConnectivityEventRx, Arc, Arc, ConnectionManagerMockState, @@ -100,7 +101,7 @@ async fn add_test_peers(peer_manager: &PeerManager, n: usize) -> Vec { peers } -#[runtime::test_basic] +#[runtime::test] async fn connecting_peers() { let (mut connectivity, mut event_stream, node_identity, peer_manager, cm_mock_state, _shutdown) = setup_connectivity_manager(Default::default()); @@ -117,15 +118,15 @@ async fn connecting_peers() { .map(|(_, _, conn, _)| conn) .collect::>(); - let mut events = collect_stream!(event_stream, take = 1, timeout = Duration::from_secs(10)); - unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = &*events.remove(0).unwrap()); + let mut events = collect_try_recv!(event_stream, take = 1, timeout = Duration::from_secs(10)); + unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = events.remove(0)); // All connections succeeded for conn in &connections { cm_mock_state.publish_event(ConnectionManagerEvent::PeerConnected(conn.clone())); } - let _events = collect_stream!(event_stream, take = 11, timeout = Duration::from_secs(10)); + let _events = collect_try_recv!(event_stream, take = 11, timeout = Duration::from_secs(10)); let connection_states = connectivity.get_all_connection_states().await.unwrap(); assert_eq!(connection_states.len(), 10); @@ -135,7 +136,7 @@ async fn connecting_peers() { } } -#[runtime::test_basic] +#[runtime::test] async fn add_many_managed_peers() { let (mut connectivity, mut event_stream, node_identity, peer_manager, cm_mock_state, _shutdown) = setup_connectivity_manager(Default::default()); @@ -156,8 +157,8 @@ async fn add_many_managed_peers() { .await .unwrap(); - let mut events = collect_stream!(event_stream, take = 1, timeout = Duration::from_secs(10)); - unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = &*events.remove(0).unwrap()); + let mut events = collect_try_recv!(event_stream, take = 1, timeout = Duration::from_secs(10)); + unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = events.remove(0)); // First 5 succeeded for conn in &connections { @@ -172,10 +173,10 @@ async fn add_many_managed_peers() { )); } - let events = collect_stream!(event_stream, take = 9, timeout = Duration::from_secs(10)); + let events = collect_try_recv!(event_stream, take = 9, timeout = Duration::from_secs(10)); let n = events .iter() - .find_map(|event| match &**event.as_ref().unwrap() { + .find_map(|event| match event { ConnectivityEvent::ConnectivityStateOnline(n) => Some(n), ConnectivityEvent::ConnectivityStateDegraded(_) => None, ConnectivityEvent::PeerConnected(_) => None, @@ -205,7 +206,7 @@ async fn add_many_managed_peers() { } } -#[runtime::test_basic] +#[runtime::test] async fn online_then_offline() { let (mut connectivity, mut event_stream, node_identity, peer_manager, cm_mock_state, _shutdown) = setup_connectivity_manager(Default::default()); @@ -244,8 +245,8 @@ async fn online_then_offline() { .await .unwrap(); - let mut events = collect_stream!(event_stream, take = 1, timeout = Duration::from_secs(10)); - unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = &*events.remove(0).unwrap()); + let mut events = collect_try_recv!(event_stream, take = 1, timeout = Duration::from_secs(10)); + unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = events.remove(0)); for conn in connections.iter().skip(1) { cm_mock_state.publish_event(ConnectionManagerEvent::PeerConnected(conn.clone())); @@ -269,9 +270,9 @@ async fn online_then_offline() { )); } - streams::assert_in_stream( + streams::assert_in_broadcast( &mut event_stream, - |item| match &*item.unwrap() { + |item| match item { ConnectivityEvent::ConnectivityStateDegraded(2) => Some(()), _ => None, }, @@ -289,9 +290,9 @@ async fn online_then_offline() { )); } - streams::assert_in_stream( + streams::assert_in_broadcast( &mut event_stream, - |item| match &*item.unwrap() { + |item| match item { ConnectivityEvent::ConnectivityStateOffline => Some(()), _ => None, }, @@ -303,20 +304,20 @@ async fn online_then_offline() { assert!(is_offline); } -#[runtime::test_basic] +#[runtime::test] async fn ban_peer() { let (mut connectivity, mut event_stream, node_identity, peer_manager, cm_mock_state, _shutdown) = setup_connectivity_manager(Default::default()); let peer = add_test_peers(&peer_manager, 1).await.pop().unwrap(); let (conn, _, _, _) = create_peer_connection_mock_pair(node_identity.to_peer(), peer.clone()).await; - let mut events = collect_stream!(event_stream, take = 1, timeout = Duration::from_secs(10)); - unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = &*events.remove(0).unwrap()); + let mut events = collect_try_recv!(event_stream, take = 1, timeout = Duration::from_secs(10)); + unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = events.remove(0)); cm_mock_state.publish_event(ConnectionManagerEvent::PeerConnected(conn.clone())); - let mut events = collect_stream!(event_stream, take = 2, timeout = Duration::from_secs(10)); - unpack_enum!(ConnectivityEvent::PeerConnected(_conn) = &*events.remove(0).unwrap()); - unpack_enum!(ConnectivityEvent::ConnectivityStateOnline(_n) = &*events.remove(0).unwrap()); + let mut events = collect_try_recv!(event_stream, take = 2, timeout = Duration::from_secs(10)); + unpack_enum!(ConnectivityEvent::PeerConnected(_conn) = events.remove(0)); + unpack_enum!(ConnectivityEvent::ConnectivityStateOnline(_n) = events.remove(0)); let conn = connectivity.get_connection(peer.node_id.clone()).await.unwrap(); assert!(conn.is_some()); @@ -329,13 +330,12 @@ async fn ban_peer() { // We can always expect a single PeerBanned because we do not publish a disconnected event from the connection // manager In a real system, peer disconnect and peer banned events may happen in any order and should always be // completely fine. - let event = collect_stream!(event_stream, take = 1, timeout = Duration::from_secs(10)) + let event = collect_try_recv!(event_stream, take = 1, timeout = Duration::from_secs(10)) .pop() - .unwrap() .unwrap(); - unpack_enum!(ConnectivityEvent::PeerBanned(node_id) = &*event); - assert_eq!(node_id, &peer.node_id); + unpack_enum!(ConnectivityEvent::PeerBanned(node_id) = event); + assert_eq!(node_id, peer.node_id); let peer = peer_manager.find_by_node_id(&peer.node_id).await.unwrap(); assert!(peer.is_banned()); @@ -344,7 +344,7 @@ async fn ban_peer() { assert!(conn.is_none()); } -#[runtime::test_basic] +#[runtime::test] async fn peer_selection() { let config = ConnectivityConfig { min_connectivity: 1.0, @@ -370,15 +370,15 @@ async fn peer_selection() { .await .unwrap(); - let mut events = collect_stream!(event_stream, take = 1, timeout = Duration::from_secs(10)); - unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = &*events.remove(0).unwrap()); + let mut events = collect_try_recv!(event_stream, take = 1, timeout = Duration::from_secs(10)); + unpack_enum!(ConnectivityEvent::ConnectivityStateInitialized = events.remove(0)); // 10 connections for conn in &connections { cm_mock_state.publish_event(ConnectionManagerEvent::PeerConnected(conn.clone())); } // Wait for all peers to be connected (i.e. for the connection manager events to be received) - let mut _events = collect_stream!(event_stream, take = 12, timeout = Duration::from_secs(10)); + let mut _events = collect_try_recv!(event_stream, take = 12, timeout = Duration::from_secs(10)); let conns = connectivity .select_connections(ConnectivitySelection::random_nodes(10, vec![connections[0] diff --git a/comms/src/framing.rs b/comms/src/framing.rs index 1e6b67691e..06ccc00c30 100644 --- a/comms/src/framing.rs +++ b/comms/src/framing.rs @@ -20,17 +20,16 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::compat::IoCompat; -use futures::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::codec::{Framed, LengthDelimitedCodec}; /// Tari comms canonical framing -pub type CanonicalFraming = Framed, LengthDelimitedCodec>; +pub type CanonicalFraming = Framed; pub fn canonical(stream: T, max_frame_len: usize) -> CanonicalFraming where T: AsyncRead + AsyncWrite + Unpin { Framed::new( - IoCompat::new(stream), + stream, LengthDelimitedCodec::builder() .max_frame_length(max_frame_len) .new_codec(), diff --git a/comms/src/lib.rs b/comms/src/lib.rs index 4f3bb178d2..de48e316e5 100644 --- a/comms/src/lib.rs +++ b/comms/src/lib.rs @@ -32,21 +32,19 @@ pub use peer_manager::{NodeIdentity, PeerManager}; pub mod framing; -mod common; -pub use common::rate_limit; +pub mod rate_limit; mod multiplexing; pub use multiplexing::Substream; mod noise; mod proto; -mod runtime; pub mod backoff; pub mod bounded_executor; -pub mod compat; pub mod memsocket; pub mod protocol; +pub mod runtime; #[macro_use] pub mod message; pub mod net_address; diff --git a/comms/src/memsocket/mod.rs b/comms/src/memsocket/mod.rs index 3d102bf8dc..ed77fc6146 100644 --- a/comms/src/memsocket/mod.rs +++ b/comms/src/memsocket/mod.rs @@ -26,17 +26,21 @@ use bytes::{Buf, Bytes}; use futures::{ channel::mpsc::{self, UnboundedReceiver, UnboundedSender}, - io::{AsyncRead, AsyncWrite, Error, ErrorKind, Result}, ready, stream::{FusedStream, Stream}, task::{Context, Poll}, }; use std::{ + cmp, collections::{hash_map::Entry, HashMap}, num::NonZeroU16, pin::Pin, sync::Mutex, }; +use tokio::{ + io, + io::{AsyncRead, AsyncWrite, ErrorKind, ReadBuf}, +}; lazy_static! { static ref SWITCHBOARD: Mutex = Mutex::new(SwitchBoard(HashMap::default(), 1)); @@ -114,6 +118,7 @@ pub fn release_memsocket_port(port: NonZeroU16) { /// use std::io::Result; /// /// use tari_comms::memsocket::{MemoryListener, MemorySocket}; +/// use tokio::io::*; /// use futures::prelude::*; /// /// async fn write_stormlight(mut stream: MemorySocket) -> Result<()> { @@ -170,7 +175,7 @@ impl MemoryListener { /// ``` /// /// [`local_addr`]: #method.local_addr - pub fn bind(port: u16) -> Result { + pub fn bind(port: u16) -> io::Result { let mut switchboard = (&*SWITCHBOARD).lock().unwrap(); // Get the port we should bind to. If 0 was given, use a random port @@ -262,11 +267,11 @@ impl MemoryListener { Incoming { inner: self } } - fn poll_accept(&mut self, context: &mut Context) -> Poll> { + fn poll_accept(&mut self, context: &mut Context) -> Poll> { match Pin::new(&mut self.incoming).poll_next(context) { Poll::Ready(Some(socket)) => Poll::Ready(Ok(socket)), Poll::Ready(None) => { - let err = Error::new(ErrorKind::Other, "MemoryListener unknown error"); + let err = io::Error::new(ErrorKind::Other, "MemoryListener unknown error"); Poll::Ready(Err(err)) }, Poll::Pending => Poll::Pending, @@ -283,7 +288,7 @@ pub struct Incoming<'a> { } impl<'a> Stream for Incoming<'a> { - type Item = Result; + type Item = io::Result; fn poll_next(mut self: Pin<&mut Self>, context: &mut Context) -> Poll> { let socket = ready!(self.inner.poll_accept(context)?); @@ -302,6 +307,7 @@ impl<'a> Stream for Incoming<'a> { /// /// ```rust, no_run /// use futures::prelude::*; +/// use tokio::io::*; /// use tari_comms::memsocket::MemorySocket; /// /// # async fn run() -> ::std::io::Result<()> { @@ -371,7 +377,7 @@ impl MemorySocket { /// let socket = MemorySocket::connect(16)?; /// # Ok(())} /// ``` - pub fn connect(port: u16) -> Result { + pub fn connect(port: u16) -> io::Result { let mut switchboard = (&*SWITCHBOARD).lock().unwrap(); // Find port to connect to @@ -399,13 +405,13 @@ impl MemorySocket { impl AsyncRead for MemorySocket { /// Attempt to read from the `AsyncRead` into `buf`. - fn poll_read(mut self: Pin<&mut Self>, mut context: &mut Context, buf: &mut [u8]) -> Poll> { + fn poll_read(mut self: Pin<&mut Self>, mut context: &mut Context, buf: &mut ReadBuf<'_>) -> Poll> { if self.incoming.is_terminated() { if self.seen_eof { return Poll::Ready(Err(ErrorKind::UnexpectedEof.into())); } else { self.seen_eof = true; - return Poll::Ready(Ok(0)); + return Poll::Ready(Ok(())); } } @@ -413,22 +419,23 @@ impl AsyncRead for MemorySocket { loop { // If we're already filled up the buffer then we can return - if bytes_read == buf.len() { - return Poll::Ready(Ok(bytes_read)); + if buf.remaining() == 0 { + return Poll::Ready(Ok(())); } match self.current_buffer { // We have data to copy to buf Some(ref mut current_buffer) if !current_buffer.is_empty() => { - let bytes_to_read = ::std::cmp::min(buf.len() - bytes_read, current_buffer.len()); - debug_assert!(bytes_to_read > 0); - - buf[bytes_read..(bytes_read + bytes_to_read)] - .copy_from_slice(current_buffer.slice(0..bytes_to_read).as_ref()); + let bytes_to_read = cmp::min(buf.remaining(), current_buffer.len()); + if bytes_to_read > 0 { + buf.initialize_unfilled_to(bytes_to_read) + .copy_from_slice(¤t_buffer.slice(..bytes_to_read)); + buf.advance(bytes_to_read); - current_buffer.advance(bytes_to_read); + current_buffer.advance(bytes_to_read); - bytes_read += bytes_to_read; + bytes_read += bytes_to_read; + } }, // Either we've exhausted our current buffer or don't have one @@ -438,13 +445,13 @@ impl AsyncRead for MemorySocket { Poll::Pending => { // If we've read anything up to this point return the bytes read if bytes_read > 0 { - return Poll::Ready(Ok(bytes_read)); + return Poll::Ready(Ok(())); } else { return Poll::Pending; } }, Poll::Ready(Some(buf)) => Some(buf), - Poll::Ready(None) => return Poll::Ready(Ok(bytes_read)), + Poll::Ready(None) => return Poll::Ready(Ok(())), } }; }, @@ -455,14 +462,14 @@ impl AsyncRead for MemorySocket { impl AsyncWrite for MemorySocket { /// Attempt to write bytes from `buf` into the outgoing channel. - fn poll_write(mut self: Pin<&mut Self>, context: &mut Context, buf: &[u8]) -> Poll> { + fn poll_write(mut self: Pin<&mut Self>, context: &mut Context, buf: &[u8]) -> Poll> { let len = buf.len(); match self.outgoing.poll_ready(context) { Poll::Ready(Ok(())) => { if let Err(e) = self.outgoing.start_send(Bytes::copy_from_slice(buf)) { if e.is_disconnected() { - return Poll::Ready(Err(Error::new(ErrorKind::BrokenPipe, e))); + return Poll::Ready(Err(io::Error::new(ErrorKind::BrokenPipe, e))); } // Unbounded channels should only ever have "Disconnected" errors @@ -471,7 +478,7 @@ impl AsyncWrite for MemorySocket { }, Poll::Ready(Err(e)) => { if e.is_disconnected() { - return Poll::Ready(Err(Error::new(ErrorKind::BrokenPipe, e))); + return Poll::Ready(Err(io::Error::new(ErrorKind::BrokenPipe, e))); } // Unbounded channels should only ever have "Disconnected" errors @@ -484,12 +491,12 @@ impl AsyncWrite for MemorySocket { } /// Attempt to flush the channel. Cannot Fail. - fn poll_flush(self: Pin<&mut Self>, _context: &mut Context) -> Poll> { + fn poll_flush(self: Pin<&mut Self>, _context: &mut Context) -> Poll> { Poll::Ready(Ok(())) } /// Attempt to close the channel. Cannot Fail. - fn poll_close(self: Pin<&mut Self>, _context: &mut Context) -> Poll> { + fn poll_shutdown(self: Pin<&mut Self>, _context: &mut Context) -> Poll> { self.outgoing.close_channel(); Poll::Ready(Ok(())) @@ -499,15 +506,12 @@ impl AsyncWrite for MemorySocket { #[cfg(test)] mod test { use super::*; - use futures::{ - executor::block_on, - io::{AsyncReadExt, AsyncWriteExt}, - stream::StreamExt, - }; - use std::io::Result; + use crate::runtime; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio_stream::StreamExt; #[test] - fn listener_bind() -> Result<()> { + fn listener_bind() -> io::Result<()> { let port = acquire_next_memsocket_port().into(); let listener = MemoryListener::bind(port)?; assert_eq!(listener.local_addr(), port); @@ -515,172 +519,187 @@ mod test { Ok(()) } - #[test] - fn simple_connect() -> Result<()> { + #[runtime::test] + async fn simple_connect() -> io::Result<()> { let port = acquire_next_memsocket_port().into(); let mut listener = MemoryListener::bind(port)?; let mut dialer = MemorySocket::connect(port)?; - let mut listener_socket = block_on(listener.incoming().next()).unwrap()?; + let mut listener_socket = listener.incoming().next().await.unwrap()?; - block_on(dialer.write_all(b"foo"))?; - block_on(dialer.flush())?; + dialer.write_all(b"foo").await?; + dialer.flush().await?; let mut buf = [0; 3]; - block_on(listener_socket.read_exact(&mut buf))?; + listener_socket.read_exact(&mut buf).await?; assert_eq!(&buf, b"foo"); Ok(()) } - #[test] - fn listen_on_port_zero() -> Result<()> { + #[runtime::test] + async fn listen_on_port_zero() -> io::Result<()> { let mut listener = MemoryListener::bind(0)?; let listener_addr = listener.local_addr(); let mut dialer = MemorySocket::connect(listener_addr)?; - let mut listener_socket = block_on(listener.incoming().next()).unwrap()?; + let mut listener_socket = listener.incoming().next().await.unwrap()?; - block_on(dialer.write_all(b"foo"))?; - block_on(dialer.flush())?; + dialer.write_all(b"foo").await?; + dialer.flush().await?; let mut buf = [0; 3]; - block_on(listener_socket.read_exact(&mut buf))?; + listener_socket.read_exact(&mut buf).await?; assert_eq!(&buf, b"foo"); - block_on(listener_socket.write_all(b"bar"))?; - block_on(listener_socket.flush())?; + listener_socket.write_all(b"bar").await?; + listener_socket.flush().await?; let mut buf = [0; 3]; - block_on(dialer.read_exact(&mut buf))?; + dialer.read_exact(&mut buf).await?; assert_eq!(&buf, b"bar"); Ok(()) } - #[test] - fn listener_correctly_frees_port_on_drop() -> Result<()> { - fn connect_on_port(port: u16) -> Result<()> { - let mut listener = MemoryListener::bind(port)?; - let mut dialer = MemorySocket::connect(port)?; - let mut listener_socket = block_on(listener.incoming().next()).unwrap()?; + #[runtime::test] + async fn listener_correctly_frees_port_on_drop() { + async fn connect_on_port(port: u16) { + let mut listener = MemoryListener::bind(port).unwrap(); + let mut dialer = MemorySocket::connect(port).unwrap(); + let mut listener_socket = listener.incoming().next().await.unwrap().unwrap(); - block_on(dialer.write_all(b"foo"))?; - block_on(dialer.flush())?; + dialer.write_all(b"foo").await.unwrap(); + dialer.flush().await.unwrap(); let mut buf = [0; 3]; - block_on(listener_socket.read_exact(&mut buf))?; + let n = listener_socket.read_exact(&mut buf).await.unwrap(); + assert_eq!(n, 3); assert_eq!(&buf, b"foo"); - - Ok(()) } let port = acquire_next_memsocket_port().into(); - connect_on_port(port)?; - connect_on_port(port)?; - - Ok(()) + connect_on_port(port).await; + connect_on_port(port).await; } - #[test] - fn simple_write_read() -> Result<()> { + #[runtime::test] + async fn simple_write_read() -> io::Result<()> { let (mut a, mut b) = MemorySocket::new_pair(); - block_on(a.write_all(b"hello world"))?; - block_on(a.flush())?; + a.write_all(b"hello world").await?; + a.flush().await?; drop(a); let mut v = Vec::new(); - block_on(b.read_to_end(&mut v))?; + b.read_to_end(&mut v).await?; assert_eq!(v, b"hello world"); Ok(()) } - #[test] - fn partial_read() -> Result<()> { + #[runtime::test] + async fn partial_read() -> io::Result<()> { let (mut a, mut b) = MemorySocket::new_pair(); - block_on(a.write_all(b"foobar"))?; - block_on(a.flush())?; + a.write_all(b"foobar").await?; + a.flush().await?; let mut buf = [0; 3]; - block_on(b.read_exact(&mut buf))?; + b.read_exact(&mut buf).await?; assert_eq!(&buf, b"foo"); - block_on(b.read_exact(&mut buf))?; + b.read_exact(&mut buf).await?; assert_eq!(&buf, b"bar"); Ok(()) } - #[test] - fn partial_read_write_both_sides() -> Result<()> { + #[runtime::test] + async fn partial_read_write_both_sides() -> io::Result<()> { let (mut a, mut b) = MemorySocket::new_pair(); - block_on(a.write_all(b"foobar"))?; - block_on(a.flush())?; - block_on(b.write_all(b"stormlight"))?; - block_on(b.flush())?; + a.write_all(b"foobar").await?; + a.flush().await?; + b.write_all(b"stormlight").await?; + b.flush().await?; let mut buf_a = [0; 5]; let mut buf_b = [0; 3]; - block_on(a.read_exact(&mut buf_a))?; + a.read_exact(&mut buf_a).await?; assert_eq!(&buf_a, b"storm"); - block_on(b.read_exact(&mut buf_b))?; + b.read_exact(&mut buf_b).await?; assert_eq!(&buf_b, b"foo"); - block_on(a.read_exact(&mut buf_a))?; + a.read_exact(&mut buf_a).await?; assert_eq!(&buf_a, b"light"); - block_on(b.read_exact(&mut buf_b))?; + b.read_exact(&mut buf_b).await?; assert_eq!(&buf_b, b"bar"); Ok(()) } - #[test] - fn many_small_writes() -> Result<()> { + #[runtime::test] + async fn many_small_writes() -> io::Result<()> { let (mut a, mut b) = MemorySocket::new_pair(); - block_on(a.write_all(b"words"))?; - block_on(a.write_all(b" "))?; - block_on(a.write_all(b"of"))?; - block_on(a.write_all(b" "))?; - block_on(a.write_all(b"radiance"))?; - block_on(a.flush())?; + a.write_all(b"words").await?; + a.write_all(b" ").await?; + a.flush().await?; + a.write_all(b"of").await?; + a.write_all(b" ").await?; + a.flush().await?; + a.write_all(b"radiance").await?; + a.flush().await?; drop(a); let mut buf = [0; 17]; - block_on(b.read_exact(&mut buf))?; + b.read_exact(&mut buf).await?; assert_eq!(&buf, b"words of radiance"); Ok(()) } - #[test] - fn read_zero_bytes() -> Result<()> { + #[runtime::test] + async fn large_writes() -> io::Result<()> { + let (mut a, mut b) = MemorySocket::new_pair(); + + let large_data = vec![123u8; 1024]; + a.write_all(&large_data).await?; + a.flush().await?; + drop(a); + + let mut buf = Vec::new(); + b.read_to_end(&mut buf).await?; + assert_eq!(buf.len(), 1024); + + Ok(()) + } + + #[runtime::test] + async fn read_zero_bytes() -> io::Result<()> { let (mut a, mut b) = MemorySocket::new_pair(); - block_on(a.write_all(b"way of kings"))?; - block_on(a.flush())?; + a.write_all(b"way of kings").await?; + a.flush().await?; let mut buf = [0; 12]; - block_on(b.read_exact(&mut buf[0..0]))?; + b.read_exact(&mut buf[0..0]).await?; assert_eq!(buf, [0; 12]); - block_on(b.read_exact(&mut buf))?; + b.read_exact(&mut buf).await?; assert_eq!(&buf, b"way of kings"); Ok(()) } - #[test] - fn read_bytes_with_large_buffer() -> Result<()> { + #[runtime::test] + async fn read_bytes_with_large_buffer() -> io::Result<()> { let (mut a, mut b) = MemorySocket::new_pair(); - block_on(a.write_all(b"way of kings"))?; - block_on(a.flush())?; + a.write_all(b"way of kings").await?; + a.flush().await?; let mut buf = [0; 20]; - let bytes_read = block_on(b.read(&mut buf))?; + let bytes_read = b.read(&mut buf).await?; assert_eq!(bytes_read, 12); assert_eq!(&buf[0..12], b"way of kings"); diff --git a/comms/src/message/outbound.rs b/comms/src/message/outbound.rs index 25f0899143..a08c9604ce 100644 --- a/comms/src/message/outbound.rs +++ b/comms/src/message/outbound.rs @@ -22,11 +22,11 @@ use crate::{message::MessageTag, peer_manager::NodeId, protocol::messaging::SendFailReason}; use bytes::Bytes; -use futures::channel::oneshot; use std::{ fmt, fmt::{Error, Formatter}, }; +use tokio::sync::oneshot; pub type MessagingReplyResult = Result<(), SendFailReason>; pub type MessagingReplyRx = oneshot::Receiver; diff --git a/comms/src/multiplexing/yamux.rs b/comms/src/multiplexing/yamux.rs index 1ba104ce04..17558133f2 100644 --- a/comms/src/multiplexing/yamux.rs +++ b/comms/src/multiplexing/yamux.rs @@ -21,17 +21,14 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{connection_manager::ConnectionDirection, runtime}; -use futures::{ - channel::mpsc, - io::{AsyncRead, AsyncWrite}, - stream::FusedStream, - task::Context, - SinkExt, - Stream, - StreamExt, -}; +use futures::{task::Context, Stream}; use std::{future::Future, io, pin::Pin, sync::Arc, task::Poll}; use tari_shutdown::{Shutdown, ShutdownSignal}; +use tokio::{ + io::{AsyncRead, AsyncWrite, ReadBuf}, + sync::mpsc, +}; +use tokio_util::compat::{Compat, FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; use tracing::{self, debug, error, event, Level}; use yamux::Mode; @@ -70,7 +67,7 @@ impl Yamux { config.set_receive_window(RECEIVE_WINDOW); let substream_counter = SubstreamCounter::new(); - let connection = yamux::Connection::new(socket, config, mode); + let connection = yamux::Connection::new(socket.compat(), config, mode); let control = Control::new(connection.control(), substream_counter.clone()); let incoming = Self::spawn_incoming_stream_worker(connection, substream_counter.clone()); @@ -88,12 +85,11 @@ impl Yamux { counter: SubstreamCounter, ) -> IncomingSubstreams where - TSocket: AsyncRead + AsyncWrite + Unpin + Send + 'static, + TSocket: futures::AsyncRead + futures::AsyncWrite + Unpin + Send + 'static, { let shutdown = Shutdown::new(); let (incoming_tx, incoming_rx) = mpsc::channel(10); - let stream = yamux::into_stream(connection).boxed(); - let incoming = IncomingWorker::new(stream, incoming_tx, shutdown.to_signal()); + let incoming = IncomingWorker::new(connection, incoming_tx, shutdown.to_signal()); runtime::task::spawn(incoming.run()); IncomingSubstreams::new(incoming_rx, counter, shutdown) } @@ -122,10 +118,6 @@ impl Yamux { pub(crate) fn substream_counter(&self) -> SubstreamCounter { self.substream_counter.clone() } - - pub fn is_terminated(&self) -> bool { - self.incoming.is_terminated() - } } #[derive(Clone)] @@ -146,7 +138,7 @@ impl Control { pub async fn open_stream(&mut self) -> Result { let stream = self.inner.open_stream().await?; Ok(Substream { - stream, + stream: stream.compat(), counter_guard: self.substream_counter.new_guard(), }) } @@ -185,19 +177,13 @@ impl IncomingSubstreams { } } -impl FusedStream for IncomingSubstreams { - fn is_terminated(&self) -> bool { - self.inner.is_terminated() - } -} - impl Stream for IncomingSubstreams { type Item = Substream; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match futures::ready!(Pin::new(&mut self.inner).poll_next(cx)) { + match futures::ready!(Pin::new(&mut self.inner).poll_recv(cx)) { Some(stream) => Poll::Ready(Some(Substream { - stream, + stream: stream.compat(), counter_guard: self.substream_counter.new_guard(), })), None => Poll::Ready(None), @@ -213,17 +199,17 @@ impl Drop for IncomingSubstreams { #[derive(Debug)] pub struct Substream { - stream: yamux::Stream, + stream: Compat, counter_guard: CounterGuard, } -impl AsyncRead for Substream { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { +impl tokio::io::AsyncRead for Substream { + fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { Pin::new(&mut self.stream).poll_read(cx, buf) } } -impl AsyncWrite for Substream { +impl tokio::io::AsyncWrite for Substream { fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { Pin::new(&mut self.stream).poll_write(cx, buf) } @@ -232,23 +218,23 @@ impl AsyncWrite for Substream { Pin::new(&mut self.stream).poll_flush(cx) } - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.stream).poll_close(cx) + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.stream).poll_shutdown(cx) } } -struct IncomingWorker { - inner: S, +struct IncomingWorker { + connection: yamux::Connection, sender: mpsc::Sender, shutdown_signal: ShutdownSignal, } -impl IncomingWorker -where S: Stream> + Unpin +impl IncomingWorker +where TSocket: futures::AsyncRead + futures::AsyncWrite + Unpin + Send + 'static /* */ { - pub fn new(stream: S, sender: IncomingTx, shutdown_signal: ShutdownSignal) -> Self { + pub fn new(connection: yamux::Connection, sender: IncomingTx, shutdown_signal: ShutdownSignal) -> Self { Self { - inner: stream, + connection, sender, shutdown_signal, } @@ -256,37 +242,55 @@ where S: Stream> + Unpin #[tracing::instrument(name = "yamux::incoming_worker::run", skip(self))] pub async fn run(mut self) { - let mut mux_stream = self.inner.take_until(&mut self.shutdown_signal); - while let Some(result) = mux_stream.next().await { - match result { - Ok(stream) => { - event!(Level::TRACE, "yamux::stream received {}", stream); - if self.sender.send(stream).await.is_err() { - debug!( - target: LOG_TARGET, - "Incoming peer substream task is shutting down because the internal stream sender channel \ - was closed" - ); - break; + loop { + tokio::select! { + biased; + + _ = &mut self.shutdown_signal => { + let mut control = self.connection.control(); + if let Err(err) = control.close().await { + error!(target: LOG_TARGET, "Failed to close yamux connection: {}", err); } - }, - Err(err) => { - event!( + break + } + + result = self.connection.next_stream() => { + match result { + Ok(Some(stream)) => { + event!(Level::TRACE, "yamux::stream received {}", stream);if self.sender.send(stream).await.is_err() { + debug!( + target: LOG_TARGET, + "Incoming peer substream task is shutting down because the internal stream sender channel \ + was closed" + ); + break; + } + }, + Ok(None) =>{ + debug!( + target: LOG_TARGET, + "Incoming peer substream completed. IncomingWorker exiting" + ); + break; + } + Err(err) => { + event!( Level::ERROR, "Incoming peer substream task received an error because '{}'", err ); error!( - target: LOG_TARGET, - "Incoming peer substream task received an error because '{}'", err - ); - break; - }, + target: LOG_TARGET, + "Incoming peer substream task received an error because '{}'", err + ); + break; + }, + } + } } } debug!(target: LOG_TARGET, "Incoming peer substream task is shutting down"); - self.sender.close_channel(); } } @@ -321,15 +325,12 @@ mod test { runtime, runtime::task, }; - use futures::{ - future, - io::{AsyncReadExt, AsyncWriteExt}, - StreamExt, - }; use std::{io, time::Duration}; use tari_test_utils::collect_stream; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio_stream::StreamExt; - #[runtime::test_basic] + #[runtime::test] async fn open_substream() -> io::Result<()> { let (dialer, listener) = MemorySocket::new_pair(); let msg = b"The Way of Kings"; @@ -344,7 +345,7 @@ mod test { substream.write_all(msg).await.unwrap(); substream.flush().await.unwrap(); - substream.close().await.unwrap(); + substream.shutdown().await.unwrap(); }); let mut listener = Yamux::upgrade_connection(listener, ConnectionDirection::Inbound) @@ -356,13 +357,16 @@ mod test { .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no substream"))?; let mut buf = Vec::new(); - let _ = future::select(substream.read_to_end(&mut buf), listener.next()).await; + tokio::select! { + _ = substream.read_to_end(&mut buf) => {}, + _ = listener.next() => {}, + }; assert_eq!(buf, msg); Ok(()) } - #[runtime::test_basic] + #[runtime::test] async fn substream_count() { const NUM_SUBSTREAMS: usize = 10; let (dialer, listener) = MemorySocket::new_pair(); @@ -396,7 +400,7 @@ mod test { assert_eq!(listener.substream_count(), 0); } - #[runtime::test_basic] + #[runtime::test] async fn close() -> io::Result<()> { let (dialer, listener) = MemorySocket::new_pair(); let msg = b"Words of Radiance"; @@ -425,7 +429,7 @@ mod test { assert_eq!(buf, msg); // Close the substream and then try to write to it - substream.close().await?; + substream.shutdown().await?; let result = substream.write_all(b"ignored message").await; match result { @@ -436,7 +440,7 @@ mod test { Ok(()) } - #[runtime::test_basic] + #[runtime::test] async fn send_big_message() -> io::Result<()> { #[allow(non_upper_case_globals)] static MiB: usize = 1 << 20; @@ -457,7 +461,7 @@ mod test { let mut buf = vec![0u8; MSG_LEN]; substream.read_exact(&mut buf).await.unwrap(); - substream.close().await.unwrap(); + substream.shutdown().await.unwrap(); assert_eq!(buf.len(), MSG_LEN); assert_eq!(buf, vec![0xAAu8; MSG_LEN]); @@ -476,7 +480,7 @@ mod test { let msg = vec![0xAAu8; MSG_LEN]; substream.write_all(msg.as_slice()).await?; - substream.close().await?; + substream.shutdown().await?; drop(substream); assert_eq!(incoming.substream_count(), 0); diff --git a/comms/src/noise/config.rs b/comms/src/noise/config.rs index 30cab07c48..23e1d34ec3 100644 --- a/comms/src/noise/config.rs +++ b/comms/src/noise/config.rs @@ -31,11 +31,11 @@ use crate::{ }, peer_manager::NodeIdentity, }; -use futures::{AsyncRead, AsyncWrite}; use log::*; use snow::{self, params::NoiseParams}; use std::sync::Arc; use tari_crypto::tari_utilities::ByteArray; +use tokio::io::{AsyncRead, AsyncWrite}; const LOG_TARGET: &str = "comms::noise"; pub(super) const NOISE_IX_PARAMETER: &str = "Noise_IX_25519_ChaChaPoly_BLAKE2b"; @@ -60,7 +60,7 @@ impl NoiseConfig { /// Upgrades the given socket to using the noise protocol. The upgraded socket and the peer's static key /// is returned. - #[tracing::instrument(name = "noise::upgrade_socket", skip(self, socket), err)] + #[tracing::instrument(name = "noise::upgrade_socket", skip(self, socket))] pub async fn upgrade_socket( &self, socket: TSocket, @@ -96,10 +96,15 @@ impl NoiseConfig { #[cfg(test)] mod test { use super::*; - use crate::{memsocket::MemorySocket, peer_manager::PeerFeatures, test_utils::node_identity::build_node_identity}; - use futures::{future, AsyncReadExt, AsyncWriteExt, FutureExt}; + use crate::{ + memsocket::MemorySocket, + peer_manager::PeerFeatures, + runtime, + test_utils::node_identity::build_node_identity, + }; + use futures::{future, FutureExt}; use snow::params::{BaseChoice, CipherChoice, DHChoice, HandshakePattern, HashChoice}; - use tokio::runtime::Runtime; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; fn check_noise_params(config: &NoiseConfig) { assert_eq!(config.parameters.hash, HashChoice::Blake2b); @@ -118,39 +123,35 @@ mod test { assert_eq!(config.node_identity.public_key(), node_identity.public_key()); } - #[test] - fn upgrade_socket() { - let mut rt = Runtime::new().unwrap(); - + #[runtime::test] + async fn upgrade_socket() { let node_identity1 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); let config1 = NoiseConfig::new(node_identity1.clone()); let node_identity2 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); let config2 = NoiseConfig::new(node_identity2.clone()); - rt.block_on(async move { - let (in_socket, out_socket) = MemorySocket::new_pair(); - let (mut socket_in, mut socket_out) = future::join( - config1.upgrade_socket(in_socket, ConnectionDirection::Inbound), - config2.upgrade_socket(out_socket, ConnectionDirection::Outbound), - ) - .map(|(s1, s2)| (s1.unwrap(), s2.unwrap())) - .await; - - let in_pubkey = socket_in.get_remote_public_key().unwrap(); - let out_pubkey = socket_out.get_remote_public_key().unwrap(); - - assert_eq!(&in_pubkey, node_identity2.public_key()); - assert_eq!(&out_pubkey, node_identity1.public_key()); - - let sample = b"Children of time"; - socket_in.write_all(sample).await.unwrap(); - socket_in.flush().await.unwrap(); - socket_in.close().await.unwrap(); - - let mut read_buf = Vec::with_capacity(16); - socket_out.read_to_end(&mut read_buf).await.unwrap(); - assert_eq!(read_buf, sample); - }); + let (in_socket, out_socket) = MemorySocket::new_pair(); + let (mut socket_in, mut socket_out) = future::join( + config1.upgrade_socket(in_socket, ConnectionDirection::Inbound), + config2.upgrade_socket(out_socket, ConnectionDirection::Outbound), + ) + .map(|(s1, s2)| (s1.unwrap(), s2.unwrap())) + .await; + + let in_pubkey = socket_in.get_remote_public_key().unwrap(); + let out_pubkey = socket_out.get_remote_public_key().unwrap(); + + assert_eq!(&in_pubkey, node_identity2.public_key()); + assert_eq!(&out_pubkey, node_identity1.public_key()); + + let sample = b"Children of time"; + socket_in.write_all(sample).await.unwrap(); + socket_in.flush().await.unwrap(); + socket_in.shutdown().await.unwrap(); + + let mut read_buf = Vec::with_capacity(16); + socket_out.read_to_end(&mut read_buf).await.unwrap(); + assert_eq!(read_buf, sample); } } diff --git a/comms/src/noise/socket.rs b/comms/src/noise/socket.rs index eaf02a60b0..33d35f89ca 100644 --- a/comms/src/noise/socket.rs +++ b/comms/src/noise/socket.rs @@ -26,27 +26,27 @@ //! Noise Socket +use crate::types::CommsPublicKey; use futures::ready; use log::*; use snow::{error::StateProblem, HandshakeState, TransportState}; use std::{ + cmp, convert::TryInto, io, pin::Pin, task::{Context, Poll}, }; -// use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use crate::types::CommsPublicKey; -use futures::{io::Error, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tari_crypto::tari_utilities::ByteArray; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; const LOG_TARGET: &str = "comms::noise::socket"; -const MAX_PAYLOAD_LENGTH: usize = u16::max_value() as usize; // 65535 +const MAX_PAYLOAD_LENGTH: usize = u16::MAX as usize; // 65535 // The maximum number of bytes that we can buffer is 16 bytes less than u16::max_value() because // encrypted messages include a tag along with the payload. -const MAX_WRITE_BUFFER_LENGTH: usize = u16::max_value() as usize - 16; // 65519 +const MAX_WRITE_BUFFER_LENGTH: usize = u16::MAX as usize - 16; // 65519 /// Collection of buffers used for buffering data during the various read/write states of a /// NoiseSocket @@ -223,7 +223,12 @@ where TSocket: AsyncRead, { loop { - let n = ready!(socket.as_mut().poll_read(&mut context, &mut buf[*offset..]))?; + let mut read_buf = ReadBuf::new(&mut buf[*offset..]); + let prev_rem = read_buf.remaining(); + ready!(socket.as_mut().poll_read(&mut context, &mut read_buf))?; + let n = prev_rem + .checked_sub(read_buf.remaining()) + .expect("buffer underflow: prev_rem < read_buf.remaining()"); trace!( target: LOG_TARGET, "poll_read_exact: read {}/{} bytes", @@ -320,7 +325,7 @@ where TSocket: AsyncRead + Unpin decrypted_len, ref mut offset, } => { - let bytes_to_copy = ::std::cmp::min(decrypted_len as usize - *offset, buf.len()); + let bytes_to_copy = cmp::min(decrypted_len as usize - *offset, buf.len()); buf[..bytes_to_copy] .copy_from_slice(&self.buffers.read_decrypted[*offset..(*offset + bytes_to_copy)]); trace!( @@ -351,8 +356,11 @@ where TSocket: AsyncRead + Unpin impl AsyncRead for NoiseSocket where TSocket: AsyncRead + Unpin { - fn poll_read(self: Pin<&mut Self>, context: &mut Context, buf: &mut [u8]) -> Poll> { - self.get_mut().poll_read(context, buf) + fn poll_read(self: Pin<&mut Self>, context: &mut Context, buf: &mut ReadBuf<'_>) -> Poll> { + let slice = buf.initialize_unfilled(); + let n = futures::ready!(self.get_mut().poll_read(context, slice))?; + buf.advance(n); + Poll::Ready(Ok(())) } } @@ -501,8 +509,8 @@ where TSocket: AsyncWrite + Unpin self.get_mut().poll_flush(cx) } - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.socket).poll_close(cx) + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.socket).poll_shutdown(cx) } } @@ -531,7 +539,7 @@ where TSocket: AsyncRead + AsyncWrite + Unpin target: LOG_TARGET, "Noise handshake failed because '{:?}'. Closing socket.", err ); - self.socket.close().await?; + self.socket.shutdown().await?; Err(err) }, } @@ -644,7 +652,6 @@ mod test { use futures::future::join; use snow::{params::NoiseParams, Builder, Error, Keypair}; use std::io; - use tokio::runtime::Runtime; async fn build_test_connection( ) -> Result<((Keypair, Handshake), (Keypair, Handshake)), Error> { @@ -707,7 +714,7 @@ mod test { dialer_socket.write_all(b" ").await?; dialer_socket.write_all(b"archive").await?; dialer_socket.flush().await?; - dialer_socket.close().await?; + dialer_socket.shutdown().await?; let mut buf = Vec::new(); listener_socket.read_to_end(&mut buf).await?; @@ -745,51 +752,60 @@ mod test { Ok(()) } - #[test] - fn u16_max_writes() -> io::Result<()> { - // Current thread runtime stack overflows, so the full tokio runtime is used here - let mut rt = Runtime::new().unwrap(); - rt.block_on(async move { - let ((_dialer_keypair, dialer), (_listener_keypair, listener)) = build_test_connection().await.unwrap(); + #[runtime::test] + async fn u16_max_writes() -> io::Result<()> { + let ((_dialer_keypair, dialer), (_listener_keypair, listener)) = build_test_connection().await.unwrap(); - let (mut a, mut b) = perform_handshake(dialer, listener).await?; + let (mut a, mut b) = perform_handshake(dialer, listener).await?; - let buf_send = [1; MAX_PAYLOAD_LENGTH]; - a.write_all(&buf_send).await?; - a.flush().await?; + let buf_send = [1; MAX_PAYLOAD_LENGTH + 1]; + a.write_all(&buf_send).await?; + a.flush().await?; - let mut buf_receive = [0; MAX_PAYLOAD_LENGTH]; - b.read_exact(&mut buf_receive).await?; - assert_eq!(&buf_receive[..], &buf_send[..]); + let mut buf_receive = [0; MAX_PAYLOAD_LENGTH + 1]; + b.read_exact(&mut buf_receive).await?; + assert_eq!(&buf_receive[..], &buf_send[..]); - Ok(()) - }) + Ok(()) } - #[test] - fn unexpected_eof() -> io::Result<()> { - // Current thread runtime stack overflows, so the full tokio runtime is used here - let mut rt = Runtime::new().unwrap(); - rt.block_on(async move { - let ((_dialer_keypair, dialer), (_listener_keypair, listener)) = build_test_connection().await.unwrap(); + #[runtime::test] + async fn larger_writes() -> io::Result<()> { + let ((_dialer_keypair, dialer), (_listener_keypair, listener)) = build_test_connection().await.unwrap(); - let (mut a, mut b) = perform_handshake(dialer, listener).await?; + let (mut a, mut b) = perform_handshake(dialer, listener).await?; - let buf_send = [1; MAX_PAYLOAD_LENGTH]; - a.write_all(&buf_send).await?; - a.flush().await?; + let buf_send = [1; MAX_PAYLOAD_LENGTH * 2 + 1024]; + a.write_all(&buf_send).await?; + a.flush().await?; - a.socket.close().await.unwrap(); - drop(a); + let mut buf_receive = [0; MAX_PAYLOAD_LENGTH * 2 + 1024]; + b.read_exact(&mut buf_receive).await?; + assert_eq!(&buf_receive[..], &buf_send[..]); - let mut buf_receive = [0; MAX_PAYLOAD_LENGTH]; - b.read_exact(&mut buf_receive).await.unwrap(); - assert_eq!(&buf_receive[..], &buf_send[..]); + Ok(()) + } + + #[runtime::test] + async fn unexpected_eof() -> io::Result<()> { + let ((_dialer_keypair, dialer), (_listener_keypair, listener)) = build_test_connection().await.unwrap(); - let err = b.read_exact(&mut buf_receive).await.unwrap_err(); - assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof); + let (mut a, mut b) = perform_handshake(dialer, listener).await?; - Ok(()) - }) + let buf_send = [1; MAX_PAYLOAD_LENGTH]; + a.write_all(&buf_send).await?; + a.flush().await?; + + a.socket.shutdown().await.unwrap(); + drop(a); + + let mut buf_receive = [0; MAX_PAYLOAD_LENGTH]; + b.read_exact(&mut buf_receive).await.unwrap(); + assert_eq!(&buf_receive[..], &buf_send[..]); + + let err = b.read_exact(&mut buf_receive).await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof); + + Ok(()) } } diff --git a/comms/src/peer_manager/manager.rs b/comms/src/peer_manager/manager.rs index 3828c82ab5..d7df51a2ea 100644 --- a/comms/src/peer_manager/manager.rs +++ b/comms/src/peer_manager/manager.rs @@ -337,7 +337,7 @@ mod test { peer } - #[runtime::test_basic] + #[runtime::test] async fn get_broadcast_identities() { // Create peer manager with random peers let peer_manager = PeerManager::new(HashmapDatabase::new(), None).unwrap(); @@ -446,7 +446,7 @@ mod test { assert_ne!(identities1, identities2); } - #[runtime::test_basic] + #[runtime::test] async fn calc_region_threshold() { let n = 5; // Create peer manager with random peers @@ -514,7 +514,7 @@ mod test { } } - #[runtime::test_basic] + #[runtime::test] async fn closest_peers() { let n = 5; // Create peer manager with random peers @@ -548,7 +548,7 @@ mod test { } } - #[runtime::test_basic] + #[runtime::test] async fn add_or_update_online_peer() { let peer_manager = PeerManager::new(HashmapDatabase::new(), None).unwrap(); let mut peer = create_test_peer(false, PeerFeatures::COMMUNICATION_NODE); diff --git a/comms/src/pipeline/builder.rs b/comms/src/pipeline/builder.rs index 40a38d10a3..9ae90fabec 100644 --- a/comms/src/pipeline/builder.rs +++ b/comms/src/pipeline/builder.rs @@ -24,8 +24,8 @@ use crate::{ message::{InboundMessage, OutboundMessage}, pipeline::SinkService, }; -use futures::channel::mpsc; use thiserror::Error; +use tokio::sync::mpsc; use tower::Service; const DEFAULT_MAX_CONCURRENT_TASKS: usize = 50; @@ -99,9 +99,7 @@ where TOutSvc: Service + Clone + Send + 'static, TInSvc: Service + Clone + Send + 'static, { - fn build_outbound( - &mut self, - ) -> Result, TOutSvc>, PipelineBuilderError> { + fn build_outbound(&mut self) -> Result, PipelineBuilderError> { let (out_sender, out_receiver) = mpsc::channel(self.outbound_buffer_size); let in_receiver = self @@ -137,9 +135,9 @@ where } } -pub struct OutboundPipelineConfig { +pub struct OutboundPipelineConfig { /// Messages read from this stream are passed to the pipeline - pub in_receiver: TInStream, + pub in_receiver: mpsc::Receiver, /// Receiver of `OutboundMessage`s coming from the pipeline pub out_receiver: mpsc::Receiver, /// The pipeline (`tower::Service`) to run for each in_stream message @@ -149,7 +147,7 @@ pub struct OutboundPipelineConfig { pub struct Config { pub max_concurrent_inbound_tasks: usize, pub inbound: TInSvc, - pub outbound: OutboundPipelineConfig, TOutSvc>, + pub outbound: OutboundPipelineConfig, } #[derive(Debug, Error)] diff --git a/comms/src/pipeline/inbound.rs b/comms/src/pipeline/inbound.rs index 0b2116bc37..1f135640a7 100644 --- a/comms/src/pipeline/inbound.rs +++ b/comms/src/pipeline/inbound.rs @@ -21,10 +21,11 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::bounded_executor::BoundedExecutor; -use futures::{future::FusedFuture, stream::FusedStream, Stream, StreamExt}; +use futures::future::FusedFuture; use log::*; use std::fmt::Display; use tari_shutdown::ShutdownSignal; +use tokio::sync::mpsc; use tower::{Service, ServiceExt}; const LOG_TARGET: &str = "comms::pipeline::inbound"; @@ -33,22 +34,26 @@ const LOG_TARGET: &str = "comms::pipeline::inbound"; /// The difference between this can ServiceExt::call_all is /// that ServicePipeline doesn't keep the result of the service /// call and that it spawns a task for each incoming item. -pub struct Inbound { +pub struct Inbound { executor: BoundedExecutor, service: TSvc, - stream: TStream, + stream: mpsc::Receiver, shutdown_signal: ShutdownSignal, } -impl Inbound +impl Inbound where - TStream: Stream + FusedStream + Unpin, - TStream::Item: Send + 'static, - TSvc: Service + Clone + Send + 'static, + TMsg: Send + 'static, + TSvc: Service + Clone + Send + 'static, TSvc::Error: Display + Send, TSvc::Future: Send, { - pub fn new(executor: BoundedExecutor, stream: TStream, service: TSvc, shutdown_signal: ShutdownSignal) -> Self { + pub fn new( + executor: BoundedExecutor, + stream: mpsc::Receiver, + service: TSvc, + shutdown_signal: ShutdownSignal, + ) -> Self { Self { executor, service, @@ -59,7 +64,7 @@ where } pub async fn run(mut self) { - while let Some(item) = self.stream.next().await { + while let Some(item) = self.stream.recv().await { // Check if the shutdown signal has been triggered. // If there are messages in the stream, drop them. Otherwise the stream is empty, // it will return None and the while loop will end. @@ -100,21 +105,25 @@ where mod test { use super::*; use crate::runtime; - use futures::{channel::mpsc, future, stream}; + use futures::future; use std::time::Duration; use tari_shutdown::Shutdown; - use tari_test_utils::collect_stream; - use tokio::{runtime::Handle, time}; + use tari_test_utils::collect_recv; + use tokio::{sync::mpsc, time}; use tower::service_fn; - #[runtime::test_basic] + #[runtime::test] async fn run() { let items = vec![1, 2, 3, 4, 5, 6]; - let stream = stream::iter(items.clone()).fuse(); + let (tx, mut stream) = mpsc::channel(items.len()); + for i in items.clone() { + tx.send(i).await.unwrap(); + } + stream.close(); - let (mut out_tx, mut out_rx) = mpsc::channel(items.len()); + let (out_tx, mut out_rx) = mpsc::channel(items.len()); - let executor = Handle::current(); + let executor = runtime::current(); let shutdown = Shutdown::new(); let pipeline = Inbound::new( BoundedExecutor::new(executor.clone(), 1), @@ -125,9 +134,10 @@ mod test { }), shutdown.to_signal(), ); + let spawned_task = executor.spawn(pipeline.run()); - let received = collect_stream!(out_rx, take = items.len(), timeout = Duration::from_secs(10)); + let received = collect_recv!(out_rx, take = items.len(), timeout = Duration::from_secs(10)); assert!(received.iter().all(|i| items.contains(i))); // Check that this task ends because the stream has closed diff --git a/comms/src/pipeline/mod.rs b/comms/src/pipeline/mod.rs index 1039374c65..a2a057354b 100644 --- a/comms/src/pipeline/mod.rs +++ b/comms/src/pipeline/mod.rs @@ -44,7 +44,4 @@ pub(crate) use inbound::Inbound; mod outbound; pub(crate) use outbound::Outbound; -mod translate_sink; -pub use translate_sink::TranslateSink; - pub type PipelineError = anyhow::Error; diff --git a/comms/src/pipeline/outbound.rs b/comms/src/pipeline/outbound.rs index c860166ad0..54facdb92b 100644 --- a/comms/src/pipeline/outbound.rs +++ b/comms/src/pipeline/outbound.rs @@ -25,34 +25,33 @@ use crate::{ pipeline::builder::OutboundPipelineConfig, protocol::messaging::MessagingRequest, }; -use futures::{channel::mpsc, future, future::Either, stream::FusedStream, SinkExt, Stream, StreamExt}; +use futures::future::Either; use log::*; use std::fmt::Display; -use tokio::runtime; +use tokio::{runtime, sync::mpsc}; use tower::{Service, ServiceExt}; const LOG_TARGET: &str = "comms::pipeline::outbound"; -pub struct Outbound { +pub struct Outbound { /// Executor used to spawn a pipeline for each received item on the stream executor: runtime::Handle, /// Outbound pipeline configuration containing the pipeline and it's in and out streams - config: OutboundPipelineConfig, + config: OutboundPipelineConfig, /// Request sender for Messaging messaging_request_tx: mpsc::Sender, } -impl Outbound +impl Outbound where - TStream: Stream + FusedStream + Unpin, - TStream::Item: Send + 'static, - TPipeline: Service + Clone + Send + 'static, + TItem: Send + 'static, + TPipeline: Service + Clone + Send + 'static, TPipeline::Error: Display + Send, TPipeline::Future: Send, { pub fn new( executor: runtime::Handle, - config: OutboundPipelineConfig, + config: OutboundPipelineConfig, messaging_request_tx: mpsc::Sender, ) -> Self { Self { @@ -64,10 +63,13 @@ where pub async fn run(mut self) { loop { - let either = future::select(self.config.in_receiver.next(), self.config.out_receiver.next()).await; + let either = tokio::select! { + next = self.config.in_receiver.recv() => Either::Left(next), + next = self.config.out_receiver.recv() => Either::Right(next) + }; match either { // Pipeline IN received a message. Spawn a new task for the pipeline - Either::Left((Some(msg), _)) => { + Either::Left(Some(msg)) => { let pipeline = self.config.pipeline.clone(); self.executor.spawn(async move { if let Err(err) = pipeline.oneshot(msg).await { @@ -76,7 +78,7 @@ where }); }, // Pipeline IN channel closed - Either::Left((None, _)) => { + Either::Left(None) => { info!( target: LOG_TARGET, "Outbound pipeline is shutting down because the in channel closed" @@ -84,7 +86,7 @@ where break; }, // Pipeline OUT received a message - Either::Right((Some(out_msg), _)) => { + Either::Right(Some(out_msg)) => { if self.messaging_request_tx.is_closed() { // MessagingRequest channel closed break; @@ -92,7 +94,7 @@ where self.send_messaging_request(out_msg).await; }, // Pipeline OUT channel closed - Either::Right((None, _)) => { + Either::Right(None) => { info!( target: LOG_TARGET, "Outbound pipeline is shutting down because the out channel closed" @@ -117,19 +119,22 @@ where #[cfg(test)] mod test { use super::*; - use crate::{pipeline::SinkService, runtime}; + use crate::{pipeline::SinkService, runtime, utils}; use bytes::Bytes; - use futures::stream; use std::time::Duration; - use tari_test_utils::{collect_stream, unpack_enum}; + use tari_test_utils::{collect_recv, unpack_enum}; use tokio::{runtime::Handle, time}; - #[runtime::test_basic] + #[runtime::test] async fn run() { const NUM_ITEMS: usize = 10; - let items = - (0..NUM_ITEMS).map(|i| OutboundMessage::new(Default::default(), Bytes::copy_from_slice(&i.to_be_bytes()))); - let stream = stream::iter(items).fuse(); + let (tx, in_receiver) = mpsc::channel(NUM_ITEMS); + utils::mpsc::send_all( + &tx, + (0..NUM_ITEMS).map(|i| OutboundMessage::new(Default::default(), Bytes::copy_from_slice(&i.to_be_bytes()))), + ) + .await + .unwrap(); let (out_tx, out_rx) = mpsc::channel(NUM_ITEMS); let (msg_tx, mut msg_rx) = mpsc::channel(NUM_ITEMS); let executor = Handle::current(); @@ -137,7 +142,7 @@ mod test { let pipeline = Outbound::new( executor.clone(), OutboundPipelineConfig { - in_receiver: stream, + in_receiver, out_receiver: out_rx, pipeline: SinkService::new(out_tx), }, @@ -146,7 +151,8 @@ mod test { let spawned_task = executor.spawn(pipeline.run()); - let requests = collect_stream!(msg_rx, take = NUM_ITEMS, timeout = Duration::from_millis(5)); + msg_rx.close(); + let requests = collect_recv!(msg_rx, timeout = Duration::from_millis(5)); for req in requests { unpack_enum!(MessagingRequest::SendMessage(_o) = req); } diff --git a/comms/src/pipeline/sink.rs b/comms/src/pipeline/sink.rs index a455aaf320..1b524e92a5 100644 --- a/comms/src/pipeline/sink.rs +++ b/comms/src/pipeline/sink.rs @@ -21,8 +21,8 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use super::PipelineError; -use futures::{future::BoxFuture, task::Context, FutureExt, Sink, SinkExt}; -use std::{pin::Pin, task::Poll}; +use futures::{future::BoxFuture, task::Context, FutureExt}; +use std::task::Poll; use tower::Service; /// A service which forwards and messages it gets to the given Sink @@ -35,22 +35,24 @@ impl SinkService { } } -impl Service for SinkService -where - T: Send + 'static, - TSink: Sink + Unpin + Clone + Send + 'static, - TSink::Error: Into + Send + 'static, +impl Service for SinkService> +where T: Send + 'static { type Error = PipelineError; type Future = BoxFuture<'static, Result>; type Response = (); - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_ready(cx).map_err(Into::into) + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, item: T) -> Self::Future { - let mut sink = self.0.clone(); - async move { sink.send(item).await.map_err(Into::into) }.boxed() + let sink = self.0.clone(); + async move { + sink.send(item) + .await + .map_err(|_| anyhow::anyhow!("sink closed in sink service")) + } + .boxed() } } diff --git a/comms/src/pipeline/translate_sink.rs b/comms/src/pipeline/translate_sink.rs index 6a2bcad56a..606c038299 100644 --- a/comms/src/pipeline/translate_sink.rs +++ b/comms/src/pipeline/translate_sink.rs @@ -93,9 +93,10 @@ where F: FnMut(I) -> Option mod test { use super::*; use crate::runtime; - use futures::{channel::mpsc, SinkExt, StreamExt}; + use futures::{SinkExt, StreamExt}; + use tokio::sync::mpsc; - #[runtime::test_basic] + #[runtime::test] async fn check_translates() { let (tx, mut rx) = mpsc::channel(1); diff --git a/comms/src/protocol/identity.rs b/comms/src/protocol/identity.rs index 2c4eba1db5..b468a946f5 100644 --- a/comms/src/protocol/identity.rs +++ b/comms/src/protocol/identity.rs @@ -20,26 +20,28 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{ - compat::IoCompat, connection_manager::ConnectionDirection, message::MessageExt, peer_manager::NodeIdentity, proto::identity::PeerIdentityMsg, protocol::{NodeNetworkInfo, ProtocolError, ProtocolId, ProtocolNegotiation}, }; -use futures::{AsyncRead, AsyncWrite, SinkExt, StreamExt}; +use futures::{SinkExt, StreamExt}; use log::*; use prost::Message; use std::{io, time::Duration}; use thiserror::Error; -use tokio::time; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + time, +}; use tokio_util::codec::{Framed, LengthDelimitedCodec}; use tracing; pub static IDENTITY_PROTOCOL: ProtocolId = ProtocolId::from_static(b"t/identity/1.0"); const LOG_TARGET: &str = "comms::protocol::identity"; -#[tracing::instrument(skip(socket, our_supported_protocols), err)] +#[tracing::instrument(skip(socket, our_supported_protocols))] pub async fn identity_exchange<'p, TSocket, P>( node_identity: &NodeIdentity, direction: ConnectionDirection, @@ -79,7 +81,7 @@ where debug_assert_eq!(proto, IDENTITY_PROTOCOL); // Create length-delimited frame codec - let framed = Framed::new(IoCompat::new(socket), LengthDelimitedCodec::new()); + let framed = Framed::new(socket, LengthDelimitedCodec::new()); let (mut sink, mut stream) = framed.split(); let supported_protocols = our_supported_protocols.into_iter().map(|p| p.to_vec()).collect(); @@ -136,8 +138,8 @@ pub enum IdentityProtocolError { ProtocolVersionMismatch, } -impl From for IdentityProtocolError { - fn from(_: time::Elapsed) -> Self { +impl From for IdentityProtocolError { + fn from(_: time::error::Elapsed) -> Self { IdentityProtocolError::Timeout } } @@ -172,7 +174,7 @@ mod test { }; use futures::{future, StreamExt}; - #[runtime::test_basic] + #[runtime::test] async fn identity_exchange() { let transport = MemoryTransport; let addr = "/memory/0".parse().unwrap(); @@ -221,7 +223,7 @@ mod test { assert_eq!(identity2.addresses, vec![node_identity2.public_address().to_vec()]); } - #[runtime::test_basic] + #[runtime::test] async fn fail_cases() { let transport = MemoryTransport; let addr = "/memory/0".parse().unwrap(); diff --git a/comms/src/protocol/messaging/error.rs b/comms/src/protocol/messaging/error.rs index 6f078cec23..91d0e786ba 100644 --- a/comms/src/protocol/messaging/error.rs +++ b/comms/src/protocol/messaging/error.rs @@ -20,10 +20,15 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{connection_manager::PeerConnectionError, peer_manager::PeerManagerError, protocol::ProtocolError}; -use futures::channel::mpsc; +use crate::{ + connection_manager::PeerConnectionError, + message::OutboundMessage, + peer_manager::PeerManagerError, + protocol::ProtocolError, +}; use std::io; use thiserror::Error; +use tokio::sync::mpsc; #[derive(Debug, Error)] pub enum InboundMessagingError { @@ -46,7 +51,7 @@ pub enum MessagingProtocolError { #[error("IO Error: {0}")] Io(#[from] io::Error), #[error("Sender error: {0}")] - SenderError(#[from] mpsc::SendError), + SenderError(#[from] mpsc::error::SendError), #[error("Stream closed due to inactivity")] Inactivity, } diff --git a/comms/src/protocol/messaging/extension.rs b/comms/src/protocol/messaging/extension.rs index 241a152a5b..f216ddd04e 100644 --- a/comms/src/protocol/messaging/extension.rs +++ b/comms/src/protocol/messaging/extension.rs @@ -34,8 +34,8 @@ use crate::{ runtime, runtime::task, }; -use futures::channel::mpsc; use std::fmt; +use tokio::sync::mpsc; use tower::Service; /// Buffer size for inbound messages from _all_ peers. This should be large enough to buffer quite a few incoming diff --git a/comms/src/protocol/messaging/forward.rs b/comms/src/protocol/messaging/forward.rs new file mode 100644 index 0000000000..ce035e8fb8 --- /dev/null +++ b/comms/src/protocol/messaging/forward.rs @@ -0,0 +1,110 @@ +// Copyright 2021, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Copied from futures rs + +use futures::{ + future::{FusedFuture, Future}, + ready, + stream::{Fuse, StreamExt}, + task::{Context, Poll}, + Sink, + Stream, + TryStream, +}; +use pin_project::pin_project; +use std::pin::Pin; + +/// Future for the [`forward`](super::StreamExt::forward) method. +#[pin_project(project = ForwardProj)] +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct Forward { + #[pin] + sink: Option, + #[pin] + stream: Fuse, + buffered_item: Option, +} + +impl Forward +where St: TryStream +{ + pub(crate) fn new(stream: St, sink: Si) -> Self { + Self { + sink: Some(sink), + stream: stream.fuse(), + buffered_item: None, + } + } +} + +impl FusedFuture for Forward +where + Si: Sink, + St: Stream>, +{ + fn is_terminated(&self) -> bool { + self.sink.is_none() + } +} + +impl Future for Forward +where + Si: Sink, + St: Stream>, +{ + type Output = Result<(), E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let ForwardProj { + mut sink, + mut stream, + buffered_item, + } = self.project(); + let mut si = sink.as_mut().as_pin_mut().expect("polled `Forward` after completion"); + + loop { + // If we've got an item buffered already, we need to write it to the + // sink before we can do anything else + if buffered_item.is_some() { + ready!(si.as_mut().poll_ready(cx))?; + si.as_mut().start_send(buffered_item.take().unwrap())?; + } + + match stream.as_mut().poll_next(cx)? { + Poll::Ready(Some(item)) => { + *buffered_item = Some(item); + }, + Poll::Ready(None) => { + ready!(si.poll_close(cx))?; + sink.set(None); + return Poll::Ready(Ok(())); + }, + Poll::Pending => { + ready!(si.poll_flush(cx))?; + return Poll::Pending; + }, + } + } + } +} diff --git a/comms/src/protocol/messaging/inbound.rs b/comms/src/protocol/messaging/inbound.rs index 643b07fc45..e65f81f8ec 100644 --- a/comms/src/protocol/messaging/inbound.rs +++ b/comms/src/protocol/messaging/inbound.rs @@ -21,15 +21,18 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{ - common::rate_limit::RateLimit, message::InboundMessage, peer_manager::NodeId, protocol::messaging::{MessagingEvent, MessagingProtocol}, + rate_limit::RateLimit, }; -use futures::{channel::mpsc, future::Either, AsyncRead, AsyncWrite, SinkExt, StreamExt}; +use futures::{future::Either, StreamExt}; use log::*; use std::{sync::Arc, time::Duration}; -use tokio::sync::broadcast; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{broadcast, mpsc}, +}; const LOG_TARGET: &str = "comms::protocol::messaging::inbound"; @@ -61,7 +64,7 @@ impl InboundMessaging { } } - pub async fn run(mut self, socket: S) + pub async fn run(self, socket: S) where S: AsyncRead + AsyncWrite + Unpin { let peer = &self.peer; debug!( @@ -70,48 +73,40 @@ impl InboundMessaging { peer.short_str() ); - let (mut sink, stream) = MessagingProtocol::framed(socket).split(); - - if let Err(err) = sink.close().await { - debug!( - target: LOG_TARGET, - "Error closing sink half for peer `{}`: {}", - peer.short_str(), - err - ); - } - let stream = stream.rate_limit(self.rate_limit_capacity, self.rate_limit_restock_interval); + let stream = + MessagingProtocol::framed(socket).rate_limit(self.rate_limit_capacity, self.rate_limit_restock_interval); - let mut stream = match self.inactivity_timeout { - Some(timeout) => Either::Left(tokio::stream::StreamExt::timeout(stream, timeout)), + let stream = match self.inactivity_timeout { + Some(timeout) => Either::Left(tokio_stream::StreamExt::timeout(stream, timeout)), None => Either::Right(stream.map(Ok)), }; + tokio::pin!(stream); while let Some(result) = stream.next().await { match result { Ok(Ok(raw_msg)) => { - let inbound_msg = InboundMessage::new(peer.clone(), raw_msg.clone().freeze()); + let msg_len = raw_msg.len(); + let inbound_msg = InboundMessage::new(peer.clone(), raw_msg.freeze()); debug!( target: LOG_TARGET, "Received message {} from peer '{}' ({} bytes)", inbound_msg.tag, peer.short_str(), - raw_msg.len() + msg_len ); let event = MessagingEvent::MessageReceived(inbound_msg.source_peer.clone(), inbound_msg.tag); if let Err(err) = self.inbound_message_tx.send(inbound_msg).await { + let tag = err.0.tag; warn!( target: LOG_TARGET, - "Failed to send InboundMessage for peer '{}' because '{}'", + "Failed to send InboundMessage {} for peer '{}' because inbound message channel closed", + tag, peer.short_str(), - err ); - if err.is_disconnected() { - break; - } + break; } let _ = self.messaging_events_tx.send(Arc::new(event)); diff --git a/comms/src/protocol/messaging/mod.rs b/comms/src/protocol/messaging/mod.rs index 88fca6af05..1fa347b3a2 100644 --- a/comms/src/protocol/messaging/mod.rs +++ b/comms/src/protocol/messaging/mod.rs @@ -27,9 +27,9 @@ mod extension; pub use extension::MessagingProtocolExtension; mod error; +mod forward; mod inbound; mod outbound; - mod protocol; pub use protocol::{ MessagingEvent, diff --git a/comms/src/protocol/messaging/outbound.rs b/comms/src/protocol/messaging/outbound.rs index 9d47895338..5ba1d36ec0 100644 --- a/comms/src/protocol/messaging/outbound.rs +++ b/comms/src/protocol/messaging/outbound.rs @@ -29,14 +29,13 @@ use crate::{ peer_manager::NodeId, protocol::messaging::protocol::MESSAGING_PROTOCOL, }; -use futures::{channel::mpsc, future::Either, SinkExt, StreamExt}; -use log::*; +use futures::{future::Either, StreamExt, TryStreamExt}; use std::{ io, time::{Duration, Instant}, }; -use tokio::stream as tokio_stream; -use tracing::{event, span, Instrument, Level}; +use tokio::sync::mpsc as tokiompsc; +use tracing::{debug, error, event, span, Instrument, Level}; const LOG_TARGET: &str = "comms::protocol::messaging::outbound"; /// The number of times to retry sending a failed message before publishing a SendMessageFailed event. @@ -46,8 +45,8 @@ const MAX_SEND_RETRIES: usize = 1; pub struct OutboundMessaging { connectivity: ConnectivityRequester, - request_rx: mpsc::UnboundedReceiver, - messaging_events_tx: mpsc::Sender, + request_rx: tokiompsc::UnboundedReceiver, + messaging_events_tx: tokiompsc::Sender, peer_node_id: NodeId, inactivity_timeout: Option, } @@ -55,8 +54,8 @@ pub struct OutboundMessaging { impl OutboundMessaging { pub fn new( connectivity: ConnectivityRequester, - messaging_events_tx: mpsc::Sender, - request_rx: mpsc::UnboundedReceiver, + messaging_events_tx: tokiompsc::Sender, + request_rx: tokiompsc::UnboundedReceiver, peer_node_id: NodeId, inactivity_timeout: Option, ) -> Self { @@ -82,7 +81,7 @@ impl OutboundMessaging { self.peer_node_id.short_str() ); let peer_node_id = self.peer_node_id.clone(); - let mut messaging_events_tx = self.messaging_events_tx.clone(); + let messaging_events_tx = self.messaging_events_tx.clone(); match self.run_inner().await { Ok(_) => { event!( @@ -107,9 +106,18 @@ impl OutboundMessaging { peer_node_id.short_str() ); }, - Err(err) => { - event!(Level::ERROR, "Outbound messaging substream failed:{}", err); - debug!(target: LOG_TARGET, "Outbound messaging substream failed: {}", err); + Err(err) => match err { + MessagingProtocolError::PeerDialFailed => { + debug!( + target: LOG_TARGET, + "Outbound messaging substream failed due to a dial fail. Most likely the peer is offline \ + or doesn't exist: {}", + err + ); + }, + _ => { + error!(target: LOG_TARGET, "Outbound messaging substream failed:{}", err); + }, }, } @@ -131,7 +139,6 @@ impl OutboundMessaging { break substream; }, Err(err) => { - event!(Level::ERROR, "Error establishing messaging protocol"); if attempts >= MAX_SEND_RETRIES { debug!( target: LOG_TARGET, @@ -265,7 +272,7 @@ impl OutboundMessaging { ); let substream = substream.stream; - let (sink, _) = MessagingProtocol::framed(substream).split(); + let framed = MessagingProtocol::framed(substream); let Self { request_rx, @@ -273,30 +280,30 @@ impl OutboundMessaging { .. } = self; + // Convert unbounded channel to a stream + let stream = futures::stream::unfold(request_rx, |mut rx| async move { + let v = rx.recv().await; + v.map(|v| (v, rx)) + }); + let stream = match inactivity_timeout { Some(timeout) => { - let s = tokio_stream::StreamExt::timeout(request_rx, timeout).map(|r| match r { - Ok(s) => Ok(s), - Err(_) => Err(io::Error::new( - io::ErrorKind::TimedOut, - MessagingProtocolError::Inactivity, - )), - }); + let s = tokio_stream::StreamExt::timeout(stream, timeout) + .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, MessagingProtocolError::Inactivity)); Either::Left(s) }, - None => Either::Right(request_rx.map(Ok)), + None => Either::Right(stream.map(Ok)), }; - stream - .map(|msg| { - msg.map(|mut out_msg| { - event!(Level::DEBUG, "Message buffered for sending {}", out_msg); - out_msg.reply_success(); - out_msg.body - }) + let stream = stream.map(|msg| { + msg.map(|mut out_msg| { + event!(Level::DEBUG, "Message buffered for sending {}", out_msg); + out_msg.reply_success(); + out_msg.body }) - .forward(sink) - .await?; + }); + + super::forward::Forward::new(stream, framed).await?; debug!( target: LOG_TARGET, @@ -310,7 +317,7 @@ impl OutboundMessaging { // Close the request channel so that we can read all the remaining messages and flush them // to a failed event self.request_rx.close(); - while let Some(mut out_msg) = self.request_rx.next().await { + while let Some(mut out_msg) = self.request_rx.recv().await { out_msg.reply_fail(reason); let _ = self .messaging_events_tx diff --git a/comms/src/protocol/messaging/protocol.rs b/comms/src/protocol/messaging/protocol.rs index 1f6fe029ab..988b4ada21 100644 --- a/comms/src/protocol/messaging/protocol.rs +++ b/comms/src/protocol/messaging/protocol.rs @@ -22,7 +22,6 @@ use super::error::MessagingProtocolError; use crate::{ - compat::IoCompat, connectivity::{ConnectivityEvent, ConnectivityRequester}, framing, message::{InboundMessage, MessageTag, OutboundMessage}, @@ -36,7 +35,6 @@ use crate::{ runtime::task, }; use bytes::Bytes; -use futures::{channel::mpsc, stream::Fuse, AsyncRead, AsyncWrite, SinkExt, StreamExt}; use log::*; use std::{ collections::{hash_map::Entry, HashMap}, @@ -46,7 +44,10 @@ use std::{ }; use tari_shutdown::{Shutdown, ShutdownSignal}; use thiserror::Error; -use tokio::sync::broadcast; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{broadcast, mpsc}, +}; use tokio_util::codec::{Framed, LengthDelimitedCodec}; const LOG_TARGET: &str = "comms::protocol::messaging"; @@ -106,13 +107,13 @@ impl fmt::Display for MessagingEvent { pub struct MessagingProtocol { config: MessagingConfig, connectivity: ConnectivityRequester, - proto_notification: Fuse>>, + proto_notification: mpsc::Receiver>, active_queues: HashMap>, - request_rx: Fuse>, + request_rx: mpsc::Receiver, messaging_events_tx: MessagingEventSender, inbound_message_tx: mpsc::Sender, internal_messaging_event_tx: mpsc::Sender, - internal_messaging_event_rx: Fuse>, + internal_messaging_event_rx: mpsc::Receiver, shutdown_signal: ShutdownSignal, complete_trigger: Shutdown, } @@ -133,11 +134,11 @@ impl MessagingProtocol { Self { config, connectivity, - proto_notification: proto_notification.fuse(), - request_rx: request_rx.fuse(), + proto_notification, + request_rx, active_queues: Default::default(), messaging_events_tx, - internal_messaging_event_rx: internal_messaging_event_rx.fuse(), + internal_messaging_event_rx, internal_messaging_event_tx, inbound_message_tx, shutdown_signal, @@ -151,15 +152,15 @@ impl MessagingProtocol { pub async fn run(mut self) { let mut shutdown_signal = self.shutdown_signal.clone(); - let mut connectivity_events = self.connectivity.get_event_subscription().fuse(); + let mut connectivity_events = self.connectivity.get_event_subscription(); loop { - futures::select! { - event = self.internal_messaging_event_rx.select_next_some() => { + tokio::select! { + Some(event) = self.internal_messaging_event_rx.recv() => { self.handle_internal_messaging_event(event).await; }, - req = self.request_rx.select_next_some() => { + Some(req) = self.request_rx.recv() => { if let Err(err) = self.handle_request(req).await { error!( target: LOG_TARGET, @@ -169,17 +170,17 @@ impl MessagingProtocol { } }, - event = connectivity_events.select_next_some() => { + event = connectivity_events.recv() => { if let Ok(event) = event { self.handle_connectivity_event(&event); } } - notification = self.proto_notification.select_next_some() => { + Some(notification) = self.proto_notification.recv() => { self.handle_protocol_notification(notification).await; }, - _ = shutdown_signal => { + _ = &mut shutdown_signal => { info!(target: LOG_TARGET, "MessagingProtocol is shutting down because the shutdown signal was triggered"); break; } @@ -188,7 +189,7 @@ impl MessagingProtocol { } #[inline] - pub fn framed(socket: TSubstream) -> Framed, LengthDelimitedCodec> + pub fn framed(socket: TSubstream) -> Framed where TSubstream: AsyncRead + AsyncWrite + Unpin { framing::canonical(socket, MAX_FRAME_LENGTH) } @@ -198,11 +199,9 @@ impl MessagingProtocol { #[allow(clippy::single_match)] match event { PeerConnectionWillClose(node_id, _) => { - // If the peer connection will close, cut off the pipe to send further messages. - // Any messages in the channel will be sent (hopefully) before the connection is disconnected. - if let Some(sender) = self.active_queues.remove(node_id) { - sender.close_channel(); - } + // If the peer connection will close, cut off the pipe to send further messages by dropping the sender. + // Any messages in the channel may be sent before the connection is disconnected. + let _ = self.active_queues.remove(node_id); }, _ => {}, } @@ -263,7 +262,7 @@ impl MessagingProtocol { let sender = Self::spawn_outbound_handler( self.connectivity.clone(), self.internal_messaging_event_tx.clone(), - peer_node_id.clone(), + peer_node_id, self.config.inactivity_timeout, ); break entry.insert(sender); @@ -273,7 +272,7 @@ impl MessagingProtocol { debug!(target: LOG_TARGET, "Sending message {}", out_msg); let tag = out_msg.tag; - match sender.send(out_msg).await { + match sender.send(out_msg) { Ok(_) => { debug!(target: LOG_TARGET, "Message ({}) dispatched to outbound handler", tag,); Ok(()) @@ -294,7 +293,7 @@ impl MessagingProtocol { peer_node_id: NodeId, inactivity_timeout: Option, ) -> mpsc::UnboundedSender { - let (msg_tx, msg_rx) = mpsc::unbounded(); + let (msg_tx, msg_rx) = mpsc::unbounded_channel(); let outbound_messaging = OutboundMessaging::new(connectivity, events_tx, msg_rx, peer_node_id, inactivity_timeout); task::spawn(outbound_messaging.run()); diff --git a/comms/src/protocol/messaging/test.rs b/comms/src/protocol/messaging/test.rs index e954af31f2..5d066eb4c2 100644 --- a/comms/src/protocol/messaging/test.rs +++ b/comms/src/protocol/messaging/test.rs @@ -49,18 +49,16 @@ use crate::{ types::{CommsDatabase, CommsPublicKey}, }; use bytes::Bytes; -use futures::{ - channel::{mpsc, oneshot}, - stream::FuturesUnordered, - SinkExt, - StreamExt, -}; +use futures::{stream::FuturesUnordered, SinkExt, StreamExt}; use rand::rngs::OsRng; use std::{io, sync::Arc, time::Duration}; use tari_crypto::keys::PublicKey; use tari_shutdown::Shutdown; -use tari_test_utils::{collect_stream, unpack_enum}; -use tokio::{sync::broadcast, time}; +use tari_test_utils::{collect_recv, collect_stream, unpack_enum}; +use tokio::{ + sync::{broadcast, mpsc, oneshot}, + time, +}; static TEST_MSG1: Bytes = Bytes::from_static(b"TEST_MSG1"); @@ -110,9 +108,9 @@ async fn spawn_messaging_protocol() -> ( ) } -#[runtime::test_basic] +#[runtime::test] async fn new_inbound_substream_handling() { - let (peer_manager, _, _, mut proto_tx, _, mut inbound_msg_rx, mut events_rx, _shutdown) = + let (peer_manager, _, _, proto_tx, _, mut inbound_msg_rx, mut events_rx, _shutdown) = spawn_messaging_protocol().await; let expected_node_id = node_id::random(); @@ -148,7 +146,7 @@ async fn new_inbound_substream_handling() { framed_theirs.send(TEST_MSG1.clone()).await.unwrap(); - let in_msg = time::timeout(Duration::from_secs(5), inbound_msg_rx.next()) + let in_msg = time::timeout(Duration::from_secs(5), inbound_msg_rx.recv()) .await .unwrap() .unwrap(); @@ -156,19 +154,18 @@ async fn new_inbound_substream_handling() { assert_eq!(in_msg.body, TEST_MSG1); let expected_tag = in_msg.tag; - let event = time::timeout(Duration::from_secs(5), events_rx.next()) + let event = time::timeout(Duration::from_secs(5), events_rx.recv()) .await .unwrap() - .unwrap() .unwrap(); unpack_enum!(MessagingEvent::MessageReceived(node_id, tag) = &*event); assert_eq!(tag, &expected_tag); assert_eq!(*node_id, expected_node_id); } -#[runtime::test_basic] +#[runtime::test] async fn send_message_request() { - let (_, node_identity, conn_man_mock, _, mut request_tx, _, _, _shutdown) = spawn_messaging_protocol().await; + let (_, node_identity, conn_man_mock, _, request_tx, _, _, _shutdown) = spawn_messaging_protocol().await; let peer_node_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -192,9 +189,9 @@ async fn send_message_request() { assert_eq!(peer_conn_mock1.call_count(), 1); } -#[runtime::test_basic] +#[runtime::test] async fn send_message_dial_failed() { - let (_, _, conn_manager_mock, _, mut request_tx, _, mut event_tx, _shutdown) = spawn_messaging_protocol().await; + let (_, _, conn_manager_mock, _, request_tx, _, mut event_tx, _shutdown) = spawn_messaging_protocol().await; let node_id = node_id::random(); let out_msg = OutboundMessage::new(node_id, TEST_MSG1.clone()); @@ -202,7 +199,7 @@ async fn send_message_dial_failed() { // Send a message to node 2 request_tx.send(MessagingRequest::SendMessage(out_msg)).await.unwrap(); - let event = event_tx.next().await.unwrap().unwrap(); + let event = event_tx.recv().await.unwrap(); unpack_enum!(MessagingEvent::SendMessageFailed(out_msg, reason) = &*event); unpack_enum!(SendFailReason::PeerDialFailed = reason); assert_eq!(out_msg.tag, expected_out_msg_tag); @@ -212,7 +209,7 @@ async fn send_message_dial_failed() { assert!(calls.iter().all(|evt| evt.starts_with("DialPeer"))); } -#[runtime::test_basic] +#[runtime::test] async fn send_message_substream_bulk_failure() { const NUM_MSGS: usize = 10; let (_, node_identity, conn_manager_mock, _, mut request_tx, _, mut events_rx, _shutdown) = @@ -258,19 +255,18 @@ async fn send_message_substream_bulk_failure() { } // Check that the outbound handler closed - let event = time::timeout(Duration::from_secs(10), events_rx.next()) + let event = time::timeout(Duration::from_secs(10), events_rx.recv()) .await .unwrap() - .unwrap() .unwrap(); unpack_enum!(MessagingEvent::OutboundProtocolExited(node_id) = &*event); assert_eq!(node_id, peer_node_id); } -#[runtime::test_basic] +#[runtime::test] async fn many_concurrent_send_message_requests() { const NUM_MSGS: usize = 100; - let (_, _, conn_man_mock, _, mut request_tx, _, _, _shutdown) = spawn_messaging_protocol().await; + let (_, _, conn_man_mock, _, request_tx, _, _, _shutdown) = spawn_messaging_protocol().await; let node_identity1 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); let node_identity2 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); @@ -315,10 +311,10 @@ async fn many_concurrent_send_message_requests() { assert_eq!(peer_conn_mock1.call_count(), 1); } -#[runtime::test_basic] +#[runtime::test] async fn many_concurrent_send_message_requests_that_fail() { const NUM_MSGS: usize = 100; - let (_, _, _, _, mut request_tx, _, mut events_rx, _shutdown) = spawn_messaging_protocol().await; + let (_, _, _, _, request_tx, _, mut events_rx, _shutdown) = spawn_messaging_protocol().await; let node_id2 = node_id::random(); @@ -339,10 +335,9 @@ async fn many_concurrent_send_message_requests_that_fail() { } // Check that we got message success events - let events = collect_stream!(events_rx, take = NUM_MSGS, timeout = Duration::from_secs(10)); + let events = collect_recv!(events_rx, take = NUM_MSGS, timeout = Duration::from_secs(10)); assert_eq!(events.len(), NUM_MSGS); for event in events { - let event = event.unwrap(); unpack_enum!(MessagingEvent::SendMessageFailed(out_msg, reason) = &*event); unpack_enum!(SendFailReason::PeerDialFailed = reason); // Assert that each tag is emitted only once @@ -357,7 +352,7 @@ async fn many_concurrent_send_message_requests_that_fail() { assert_eq!(msg_tags.len(), 0); } -#[runtime::test_basic] +#[runtime::test] async fn inactivity_timeout() { let node_identity = build_node_identity(PeerFeatures::COMMUNICATION_CLIENT); let (inbound_msg_tx, mut inbound_msg_rx) = mpsc::channel(5); @@ -381,13 +376,13 @@ async fn inactivity_timeout() { let mut framed = MessagingProtocol::framed(socket_out); for _ in 0..5u8 { framed.send(Bytes::from_static(b"some message")).await.unwrap(); - time::delay_for(Duration::from_millis(1)).await; + time::sleep(Duration::from_millis(1)).await; } - time::delay_for(Duration::from_millis(10)).await; + time::sleep(Duration::from_millis(10)).await; let err = framed.send(Bytes::from_static(b"another message")).await.unwrap_err(); assert_eq!(err.kind(), io::ErrorKind::BrokenPipe); - let _ = collect_stream!(inbound_msg_rx, take = 5, timeout = Duration::from_secs(10)); + let _ = collect_recv!(inbound_msg_rx, take = 5, timeout = Duration::from_secs(10)); } diff --git a/comms/src/protocol/negotiation.rs b/comms/src/protocol/negotiation.rs index 5178d790e6..326910dc99 100644 --- a/comms/src/protocol/negotiation.rs +++ b/comms/src/protocol/negotiation.rs @@ -23,9 +23,9 @@ use super::{ProtocolError, ProtocolId}; use bitflags::bitflags; use bytes::{Bytes, BytesMut}; -use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use log::*; use std::convert::TryInto; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; const LOG_TARGET: &str = "comms::connection_manager::protocol"; @@ -204,7 +204,7 @@ mod test { use futures::future; use tari_test_utils::unpack_enum; - #[runtime::test_basic] + #[runtime::test] async fn negotiate_success() { let (mut initiator, mut responder) = MemorySocket::new_pair(); let mut negotiate_out = ProtocolNegotiation::new(&mut initiator); @@ -229,7 +229,7 @@ mod test { assert_eq!(out_proto.unwrap(), ProtocolId::from_static(b"A")); } - #[runtime::test_basic] + #[runtime::test] async fn negotiate_fail() { let (mut initiator, mut responder) = MemorySocket::new_pair(); let mut negotiate_out = ProtocolNegotiation::new(&mut initiator); @@ -254,7 +254,7 @@ mod test { unpack_enum!(ProtocolError::ProtocolOutboundNegotiationFailed(_s) = out_proto.unwrap_err()); } - #[runtime::test_basic] + #[runtime::test] async fn negotiate_fail_max_rounds() { let (mut initiator, mut responder) = MemorySocket::new_pair(); let mut negotiate_out = ProtocolNegotiation::new(&mut initiator); @@ -279,7 +279,7 @@ mod test { unpack_enum!(ProtocolError::ProtocolNegotiationTerminatedByPeer = out_proto.unwrap_err()); } - #[runtime::test_basic] + #[runtime::test] async fn negotiate_success_optimistic() { let (mut initiator, mut responder) = MemorySocket::new_pair(); let mut negotiate_out = ProtocolNegotiation::new(&mut initiator); @@ -300,7 +300,7 @@ mod test { out_proto.unwrap(); } - #[runtime::test_basic] + #[runtime::test] async fn negotiate_fail_optimistic() { let (mut initiator, mut responder) = MemorySocket::new_pair(); let mut negotiate_out = ProtocolNegotiation::new(&mut initiator); diff --git a/comms/src/protocol/protocols.rs b/comms/src/protocol/protocols.rs index 936ef15f34..14d253196e 100644 --- a/comms/src/protocol/protocols.rs +++ b/comms/src/protocol/protocols.rs @@ -32,8 +32,8 @@ use crate::{ }, Substream, }; -use futures::{channel::mpsc, SinkExt}; use std::collections::HashMap; +use tokio::sync::mpsc; pub type ProtocolNotificationTx = mpsc::Sender>; pub type ProtocolNotificationRx = mpsc::Receiver>; @@ -143,7 +143,6 @@ impl ProtocolExtension for Protocols { mod test { use super::*; use crate::runtime; - use futures::StreamExt; use tari_test_utils::unpack_enum; #[test] @@ -160,7 +159,7 @@ mod test { assert!(protocols.get_supported_protocols().iter().all(|p| protos.contains(p))); } - #[runtime::test_basic] + #[runtime::test] async fn notify() { let (tx, mut rx) = mpsc::channel(1); let protos = [ProtocolId::from_static(b"/tari/test/1")]; @@ -172,12 +171,12 @@ mod test { .await .unwrap(); - let notification = rx.next().await.unwrap(); + let notification = rx.recv().await.unwrap(); unpack_enum!(ProtocolEvent::NewInboundSubstream(peer_id, _s) = notification.event); assert_eq!(peer_id, NodeId::new()); } - #[runtime::test_basic] + #[runtime::test] async fn notify_fail_not_registered() { let mut protocols = Protocols::<()>::new(); diff --git a/comms/src/protocol/rpc/body.rs b/comms/src/protocol/rpc/body.rs index 6079508729..e563d6483e 100644 --- a/comms/src/protocol/rpc/body.rs +++ b/comms/src/protocol/rpc/body.rs @@ -27,7 +27,6 @@ use crate::{ }; use bytes::BytesMut; use futures::{ - channel::mpsc, ready, stream::BoxStream, task::{Context, Poll}, @@ -37,6 +36,7 @@ use futures::{ use pin_project::pin_project; use prost::bytes::Buf; use std::{fmt, marker::PhantomData, pin::Pin}; +use tokio::sync::mpsc; pub trait IntoBody { fn into_body(self) -> Body; @@ -205,8 +205,8 @@ impl Buf for BodyBytes { self.0.as_ref().map(Buf::remaining).unwrap_or(0) } - fn bytes(&self) -> &[u8] { - self.0.as_ref().map(Buf::bytes).unwrap_or(&[]) + fn chunk(&self) -> &[u8] { + self.0.as_ref().map(Buf::chunk).unwrap_or(&[]) } fn advance(&mut self, cnt: usize) { @@ -227,7 +227,7 @@ impl Streaming { } pub fn empty() -> Self { - let (_, rx) = mpsc::channel(0); + let (_, rx) = mpsc::channel(1); Self { inner: rx } } @@ -240,7 +240,7 @@ impl Stream for Streaming { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match ready!(self.inner.poll_next_unpin(cx)) { + match ready!(Pin::new(&mut self.inner).poll_recv(cx)) { Some(result) => { let result = result.map(|msg| msg.to_encoded_bytes().into()); Poll::Ready(Some(result)) @@ -275,7 +275,7 @@ impl Stream for ClientStreaming { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match ready!(self.inner.poll_next_unpin(cx)) { + match ready!(Pin::new(&mut self.inner).poll_recv(cx)) { Some(Ok(resp)) => { // The streaming protocol dictates that an empty finish flag MUST be sent to indicate a terminated // stream. This empty response need not be emitted to downsteam consumers. @@ -298,7 +298,7 @@ mod test { use futures::{stream, StreamExt}; use prost::Message; - #[runtime::test_basic] + #[runtime::test] async fn single_body() { let mut body = Body::single(123u32.to_encoded_bytes()); let bytes = body.next().await.unwrap().unwrap(); @@ -306,7 +306,7 @@ mod test { assert_eq!(u32::decode(bytes).unwrap(), 123u32); } - #[runtime::test_basic] + #[runtime::test] async fn streaming_body() { let body = Body::streaming(stream::repeat(Bytes::new()).map(Ok).take(10)); let body = body.collect::>().await; diff --git a/comms/src/protocol/rpc/client.rs b/comms/src/protocol/rpc/client.rs index 2806366f9f..6c994fcaf8 100644 --- a/comms/src/protocol/rpc/client.rs +++ b/comms/src/protocol/rpc/client.rs @@ -41,11 +41,8 @@ use crate::{ }; use bytes::Bytes; use futures::{ - channel::{mpsc, oneshot}, future::{BoxFuture, Either}, task::{Context, Poll}, - AsyncRead, - AsyncWrite, FutureExt, SinkExt, StreamExt, @@ -58,9 +55,15 @@ use std::{ fmt, future::Future, marker::PhantomData, + sync::Arc, time::{Duration, Instant}, }; -use tokio::time; +use tari_shutdown::{Shutdown, ShutdownSignal}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{mpsc, oneshot, Mutex}, + time, +}; use tower::{Service, ServiceExt}; use tracing::{event, span, Instrument, Level}; @@ -82,14 +85,16 @@ impl RpcClient { TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static, { let (request_tx, request_rx) = mpsc::channel(1); - let connector = ClientConnector::new(request_tx); + let shutdown = Shutdown::new(); + let shutdown_signal = shutdown.to_signal(); + let connector = ClientConnector::new(request_tx, shutdown); let (ready_tx, ready_rx) = oneshot::channel(); let tracing_id = tracing::Span::current().id(); task::spawn({ let span = span!(Level::TRACE, "start_rpc_worker"); span.follows_from(tracing_id); - RpcClientWorker::new(config, request_rx, framed, ready_tx, protocol_name) + RpcClientWorker::new(config, request_rx, framed, ready_tx, protocol_name, shutdown_signal) .run() .instrument(span) }); @@ -110,7 +115,7 @@ impl RpcClient { let request = BaseRequest::new(method.into(), req_bytes.into()); let mut resp = self.call_inner(request).await?; - let resp = resp.next().await.ok_or(RpcError::ServerClosedRequest)??; + let resp = resp.recv().await.ok_or(RpcError::ServerClosedRequest)??; let resp = R::decode(resp.into_message())?; Ok(resp) @@ -132,8 +137,8 @@ impl RpcClient { } /// Close the RPC session. Any subsequent calls will error. - pub fn close(&mut self) { - self.connector.close() + pub async fn close(&mut self) { + self.connector.close().await; } pub fn is_connected(&self) -> bool { @@ -269,15 +274,20 @@ impl Default for RpcClientConfig { #[derive(Clone)] pub struct ClientConnector { inner: mpsc::Sender, + shutdown: Arc>, } impl ClientConnector { - pub(self) fn new(sender: mpsc::Sender) -> Self { - Self { inner: sender } + pub(self) fn new(sender: mpsc::Sender, shutdown: Shutdown) -> Self { + Self { + inner: sender, + shutdown: Arc::new(Mutex::new(shutdown)), + } } - pub fn close(&mut self) { - self.inner.close_channel(); + pub async fn close(&mut self) { + let mut lock = self.shutdown.lock().await; + lock.trigger(); } pub async fn get_last_request_latency(&mut self) -> Result, RpcError> { @@ -317,13 +327,13 @@ impl Service> for ClientConnector { type Future = BoxFuture<'static, Result>; type Response = mpsc::Receiver, RpcStatus>>; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready_unpin(cx).map_err(|_| RpcError::ClientClosed) + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, request: BaseRequest) -> Self::Future { let (reply, reply_rx) = oneshot::channel(); - let mut inner = self.inner.clone(); + let inner = self.inner.clone(); async move { inner .send(ClientRequest::SendRequest { request, reply }) @@ -346,6 +356,7 @@ pub struct RpcClientWorker { ready_tx: Option>>, last_request_latency: Option, protocol_id: ProtocolId, + shutdown_signal: ShutdownSignal, } impl RpcClientWorker @@ -357,6 +368,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send framed: CanonicalFraming, ready_tx: oneshot::Sender>, protocol_id: ProtocolId, + shutdown_signal: ShutdownSignal, ) -> Self { Self { config, @@ -366,6 +378,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send ready_tx: Some(ready_tx), last_request_latency: None, protocol_id, + shutdown_signal, } } @@ -405,26 +418,26 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send }, } - while let Some(req) = self.request_rx.next().await { - use ClientRequest::*; - match req { - SendRequest { request, reply } => { - if let Err(err) = self.do_request_response(request, reply).await { - error!(target: LOG_TARGET, "Unexpected error: {}. Worker is terminating.", err); - break; - } - }, - GetLastRequestLatency(reply) => { - let _ = reply.send(self.last_request_latency); - }, - SendPing(reply) => { - if let Err(err) = self.do_ping_pong(reply).await { - error!(target: LOG_TARGET, "Unexpected error: {}. Worker is terminating.", err); - break; + loop { + tokio::select! { + biased; + _ = &mut self.shutdown_signal => { + break; + } + req = self.request_rx.recv() => { + match req { + Some(req) => { + if let Err(err) = self.handle_request(req).await { + error!(target: LOG_TARGET, "Unexpected error: {}. Worker is terminating.", err); + break; + } + } + None => break, } - }, + } } } + if let Err(err) = self.framed.close().await { debug!(target: LOG_TARGET, "IO Error when closing substream: {}", err); } @@ -436,6 +449,22 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send ); } + async fn handle_request(&mut self, req: ClientRequest) -> Result<(), RpcError> { + use ClientRequest::*; + match req { + SendRequest { request, reply } => { + self.do_request_response(request, reply).await?; + }, + GetLastRequestLatency(reply) => { + let _ = reply.send(self.last_request_latency); + }, + SendPing(reply) => { + self.do_ping_pong(reply).await?; + }, + } + Ok(()) + } + async fn do_ping_pong(&mut self, reply: oneshot::Sender>) -> Result<(), RpcError> { let ack = proto::rpc::RpcRequest { flags: RpcMessageFlags::ACK.bits() as u32, @@ -482,7 +511,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send Ok(()) } - #[tracing::instrument(name = "rpc_do_request_response", skip(self, reply), err)] + #[tracing::instrument(name = "rpc_do_request_response", skip(self, reply))] async fn do_request_response( &mut self, request: BaseRequest, @@ -501,14 +530,30 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send debug!(target: LOG_TARGET, "Sending request: {}", req); let start = Instant::now(); + if reply.is_closed() { + event!(Level::WARN, "Client request was cancelled before request was sent"); + warn!( + target: LOG_TARGET, + "Client request was cancelled before request was sent" + ); + } self.framed.send(req.to_encoded_bytes().into()).await?; - let (mut response_tx, response_rx) = mpsc::channel(10); - if reply.send(response_rx).is_err() { - event!(Level::WARN, "Client request was cancelled"); - warn!(target: LOG_TARGET, "Client request was cancelled."); - response_tx.close_channel(); - // TODO: Should this not exit here? + let (response_tx, response_rx) = mpsc::channel(10); + if let Err(mut rx) = reply.send(response_rx) { + event!(Level::WARN, "Client request was cancelled after request was sent"); + warn!( + target: LOG_TARGET, + "Client request was cancelled after request was sent" + ); + rx.close(); + // RPC is strictly request/response + // If the client drops the RpcClient request at this point after the , we have two options: + // 1. Obey the protocol: receive the response + // 2. Close the RPC session and return an error (seems brittle and unexpected) + // Option 1 has the disadvantage when receiving large/many streamed responses. + // TODO: Detect if all handles to the client handles have been dropped. If so, + // immediately close the RPC session } loop { @@ -537,8 +582,20 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send start.elapsed() ); event!(Level::ERROR, "Response timed out"); - let _ = response_tx.send(Err(RpcStatus::timed_out("Response timed out"))).await; - response_tx.close_channel(); + if !response_tx.is_closed() { + let _ = response_tx.send(Err(RpcStatus::timed_out("Response timed out"))).await; + } + break; + }, + Err(RpcError::ClientClosed) => { + debug!( + target: LOG_TARGET, + "Request {} (method={}) was closed after {:.0?} (read_reply)", + request_id, + method, + start.elapsed() + ); + self.request_rx.close(); break; }, Err(err) => { @@ -564,7 +621,6 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send let _ = response_tx.send(Ok(resp)).await; } if is_finished { - response_tx.close_channel(); break; } }, @@ -573,7 +629,6 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send if !response_tx.is_closed() { let _ = response_tx.send(Err(err)).await; } - response_tx.close_channel(); break; }, Err(err @ RpcError::ResponseIdDidNotMatchRequest { .. }) | @@ -598,7 +653,15 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send None => Either::Right(self.framed.next().map(Ok)), }; - match next_msg_fut.await { + let result = tokio::select! { + biased; + _ = &mut self.shutdown_signal => { + return Err(RpcError::ClientClosed); + } + result = next_msg_fut => result, + }; + + match result { Ok(Some(Ok(resp))) => Ok(proto::rpc::RpcResponse::decode(resp)?), Ok(Some(Err(err))) => Err(err.into()), Ok(None) => Err(RpcError::ServerClosedRequest), diff --git a/comms/src/protocol/rpc/client_pool.rs b/comms/src/protocol/rpc/client_pool.rs index 6829b41265..7cf99ed419 100644 --- a/comms/src/protocol/rpc/client_pool.rs +++ b/comms/src/protocol/rpc/client_pool.rs @@ -61,6 +61,11 @@ where T: RpcPoolClient + From + NamedProtocolService + Clone let mut pool = self.pool.lock().await; pool.get_least_used_or_connect().await } + + pub async fn is_connected(&self) -> bool { + let pool = self.pool.lock().await; + pool.is_connected() + } } #[derive(Clone)] @@ -111,6 +116,10 @@ where T: RpcPoolClient + From + NamedProtocolService + Clone } } + pub fn is_connected(&self) -> bool { + self.connection.is_connected() + } + pub(super) fn refresh_num_active_connections(&mut self) -> usize { self.prune(); self.clients.len() diff --git a/comms/src/protocol/rpc/handshake.rs b/comms/src/protocol/rpc/handshake.rs index 3abd62cef6..b39c15e6d7 100644 --- a/comms/src/protocol/rpc/handshake.rs +++ b/comms/src/protocol/rpc/handshake.rs @@ -22,10 +22,13 @@ use crate::{framing::CanonicalFraming, message::MessageExt, proto, protocol::rpc::error::HandshakeRejectReason}; use bytes::BytesMut; -use futures::{AsyncRead, AsyncWrite, SinkExt, StreamExt}; +use futures::{SinkExt, StreamExt}; use prost::{DecodeError, Message}; use std::{io, time::Duration}; -use tokio::time; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + time, +}; use tracing::{debug, error, event, span, warn, Instrument, Level}; const LOG_TARGET: &str = "comms::rpc::handshake"; @@ -168,7 +171,7 @@ where T: AsyncRead + AsyncWrite + Unpin } #[tracing::instrument(name = "rpc::receive_handshake_reply", skip(self), err)] - async fn recv_next_frame(&mut self) -> Result>, time::Elapsed> { + async fn recv_next_frame(&mut self) -> Result>, time::error::Elapsed> { match self.timeout { Some(timeout) => time::timeout(timeout, self.framed.next()).await, None => Ok(self.framed.next().await), diff --git a/comms/src/protocol/rpc/mod.rs b/comms/src/protocol/rpc/mod.rs index d4e91fa8e4..2244979adf 100644 --- a/comms/src/protocol/rpc/mod.rs +++ b/comms/src/protocol/rpc/mod.rs @@ -80,6 +80,7 @@ pub mod __macro_reexports { }, Bytes, }; - pub use futures::{future, future::BoxFuture, AsyncRead, AsyncWrite}; + pub use futures::{future, future::BoxFuture}; + pub use tokio::io::{AsyncRead, AsyncWrite}; pub use tower::Service; } diff --git a/comms/src/protocol/rpc/server/error.rs b/comms/src/protocol/rpc/server/error.rs index 5078c6c588..6972cec60b 100644 --- a/comms/src/protocol/rpc/server/error.rs +++ b/comms/src/protocol/rpc/server/error.rs @@ -21,9 +21,9 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::protocol::rpc::handshake::RpcHandshakeError; -use futures::channel::oneshot; use prost::DecodeError; use std::io; +use tokio::sync::oneshot; #[derive(Debug, thiserror::Error)] pub enum RpcServerError { @@ -41,8 +41,8 @@ pub enum RpcServerError { ProtocolServiceNotFound(String), } -impl From for RpcServerError { - fn from(_: oneshot::Canceled) -> Self { +impl From for RpcServerError { + fn from(_: oneshot::error::RecvError) -> Self { RpcServerError::RequestCanceled } } diff --git a/comms/src/protocol/rpc/server/handle.rs b/comms/src/protocol/rpc/server/handle.rs index 89bf8dd3b9..972d91429d 100644 --- a/comms/src/protocol/rpc/server/handle.rs +++ b/comms/src/protocol/rpc/server/handle.rs @@ -21,10 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use super::RpcServerError; -use futures::{ - channel::{mpsc, oneshot}, - SinkExt, -}; +use tokio::sync::{mpsc, oneshot}; #[derive(Debug)] pub enum RpcServerRequest { diff --git a/comms/src/protocol/rpc/server/mock.rs b/comms/src/protocol/rpc/server/mock.rs index 19741a0a1a..69659ba03b 100644 --- a/comms/src/protocol/rpc/server/mock.rs +++ b/comms/src/protocol/rpc/server/mock.rs @@ -42,6 +42,7 @@ use crate::{ ProtocolNotificationTx, }, test_utils::mocks::{create_connectivity_mock, create_peer_connection_mock_pair, ConnectivityManagerMockState}, + utils, NodeIdentity, PeerConnection, PeerManager, @@ -49,7 +50,7 @@ use crate::{ }; use async_trait::async_trait; use bytes::Bytes; -use futures::{channel::mpsc, future::BoxFuture, stream, SinkExt}; +use futures::future::BoxFuture; use std::{ collections::HashMap, future, @@ -57,7 +58,7 @@ use std::{ task::{Context, Poll}, }; use tokio::{ - sync::{Mutex, RwLock}, + sync::{mpsc, Mutex, RwLock}, task, }; use tower::Service; @@ -139,9 +140,13 @@ pub trait RpcMock { { method_state.requests.write().await.push(request.into_message()); let resp = method_state.response.read().await.clone()?; - let (mut tx, rx) = mpsc::channel(resp.len()); - let mut resp = stream::iter(resp.into_iter().map(Ok).map(Ok)); - tx.send_all(&mut resp).await.unwrap(); + let (tx, rx) = mpsc::channel(resp.len()); + match utils::mpsc::send_all(&tx, resp.into_iter().map(Ok)).await { + Ok(_) => {}, + // This is done because tokio mpsc channels give the item back to you in the error, and our item doesn't + // impl Debug, so we can't use unwrap, expect etc + Err(_) => panic!("send error"), + } Ok(Streaming::new(rx)) } } @@ -234,7 +239,7 @@ where let peer_node_id = peer.node_id.clone(); let (_, our_conn_mock, peer_conn, _) = create_peer_connection_mock_pair(peer, self.our_node.to_peer()).await; - let mut protocol_tx = self.protocol_tx.clone(); + let protocol_tx = self.protocol_tx.clone(); task::spawn(async move { while let Some(substream) = our_conn_mock.next_incoming_substream().await { let proto_notif = ProtocolNotification::new( diff --git a/comms/src/protocol/rpc/server/mod.rs b/comms/src/protocol/rpc/server/mod.rs index b2b7dbf76f..88fdb7ee61 100644 --- a/comms/src/protocol/rpc/server/mod.rs +++ b/comms/src/protocol/rpc/server/mod.rs @@ -53,14 +53,19 @@ use crate::{ protocol::{ProtocolEvent, ProtocolId, ProtocolNotification, ProtocolNotificationRx}, Bytes, }; -use futures::{channel::mpsc, AsyncRead, AsyncWrite, SinkExt, StreamExt}; +use futures::SinkExt; use prost::Message; use std::{ borrow::Cow, future::Future, time::{Duration, Instant}, }; -use tokio::time; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::mpsc, + time, +}; +use tokio_stream::StreamExt; use tower::Service; use tower_make::MakeService; use tracing::{debug, error, instrument, span, trace, warn, Instrument, Level}; @@ -198,7 +203,7 @@ pub(super) struct PeerRpcServer { service: TSvc, protocol_notifications: Option>, comms_provider: TCommsProvider, - request_rx: Option>, + request_rx: mpsc::Receiver, } impl PeerRpcServer @@ -233,7 +238,7 @@ where service, protocol_notifications: Some(protocol_notifications), comms_provider, - request_rx: Some(request_rx), + request_rx, } } @@ -243,24 +248,19 @@ where .take() .expect("PeerRpcServer initialized without protocol_notifications"); - let mut requests = self - .request_rx - .take() - .expect("PeerRpcServer initialized without request_rx"); - loop { - futures::select! { - maybe_notif = protocol_notifs.next() => { - match maybe_notif { - Some(notif) => self.handle_protocol_notification(notif).await?, - // No more protocol notifications to come, so we're done - None => break, - } - } - - req = requests.select_next_some() => { + tokio::select! { + maybe_notif = protocol_notifs.recv() => { + match maybe_notif { + Some(notif) => self.handle_protocol_notification(notif).await?, + // No more protocol notifications to come, so we're done + None => break, + } + } + + Some(req) = self.request_rx.recv() => { self.handle_request(req).await; - }, + }, } } diff --git a/comms/src/protocol/rpc/server/router.rs b/comms/src/protocol/rpc/server/router.rs index 9d03c6535d..1d40988075 100644 --- a/comms/src/protocol/rpc/server/router.rs +++ b/comms/src/protocol/rpc/server/router.rs @@ -44,14 +44,15 @@ use crate::{ Bytes, }; use futures::{ - channel::mpsc, future::BoxFuture, task::{Context, Poll}, - AsyncRead, - AsyncWrite, FutureExt, }; use std::sync::Arc; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::mpsc, +}; use tower::Service; use tower_make::MakeService; @@ -329,7 +330,7 @@ mod test { } } - #[runtime::test_basic] + #[runtime::test] async fn find_route() { let server = RpcServer::new(); let mut router = Router::new(server, HelloService).add_service(GoodbyeService); diff --git a/comms/src/protocol/rpc/test/client_pool.rs b/comms/src/protocol/rpc/test/client_pool.rs index e1eb957d5f..d95e1d22b6 100644 --- a/comms/src/protocol/rpc/test/client_pool.rs +++ b/comms/src/protocol/rpc/test/client_pool.rs @@ -39,13 +39,13 @@ use crate::{ runtime::task, test_utils::mocks::{new_peer_connection_mock_pair, PeerConnectionMockState}, }; -use futures::{channel::mpsc, SinkExt}; use tari_shutdown::Shutdown; use tari_test_utils::{async_assert_eventually, unpack_enum}; +use tokio::sync::mpsc; async fn setup(num_concurrent_sessions: usize) -> (PeerConnection, PeerConnectionMockState, Shutdown) { let (conn1, conn1_state, conn2, conn2_state) = new_peer_connection_mock_pair().await; - let (mut notif_tx, notif_rx) = mpsc::channel(1); + let (notif_tx, notif_rx) = mpsc::channel(1); let shutdown = Shutdown::new(); let (context, _) = create_mocked_rpc_context(); @@ -148,15 +148,15 @@ mod lazy_pool { async fn it_prunes_disconnected_sessions() { let (conn, mock_state, _shutdown) = setup(2).await; let mut pool = LazyPool::::new(conn, 2, Default::default()); - let mut conn1 = pool.get_least_used_or_connect().await.unwrap(); + let mut client1 = pool.get_least_used_or_connect().await.unwrap(); assert_eq!(mock_state.num_open_substreams(), 1); - let _conn2 = pool.get_least_used_or_connect().await.unwrap(); + let _client2 = pool.get_least_used_or_connect().await.unwrap(); assert_eq!(mock_state.num_open_substreams(), 2); - conn1.close(); - drop(conn1); + client1.close().await; + drop(client1); async_assert_eventually!(mock_state.num_open_substreams(), expect = 1); assert_eq!(pool.refresh_num_active_connections(), 1); - let _conn3 = pool.get_least_used_or_connect().await.unwrap(); + let _client3 = pool.get_least_used_or_connect().await.unwrap(); assert_eq!(pool.refresh_num_active_connections(), 2); assert_eq!(mock_state.num_open_substreams(), 2); } diff --git a/comms/src/protocol/rpc/test/comms_integration.rs b/comms/src/protocol/rpc/test/comms_integration.rs index 9d23088f07..f43f921081 100644 --- a/comms/src/protocol/rpc/test/comms_integration.rs +++ b/comms/src/protocol/rpc/test/comms_integration.rs @@ -37,7 +37,7 @@ use crate::{ use tari_shutdown::Shutdown; use tari_test_utils::unpack_enum; -#[runtime::test_basic] +#[runtime::test] async fn run_service() { let node_identity1 = build_node_identity(Default::default()); let rpc_service = MockRpcService::new(); diff --git a/comms/src/protocol/rpc/test/greeting_service.rs b/comms/src/protocol/rpc/test/greeting_service.rs index 0e190473dd..445099e7c3 100644 --- a/comms/src/protocol/rpc/test/greeting_service.rs +++ b/comms/src/protocol/rpc/test/greeting_service.rs @@ -26,12 +26,16 @@ use crate::{ rpc::{NamedProtocolService, Request, Response, RpcError, RpcServerError, RpcStatus, Streaming}, ProtocolId, }, + utils, }; use core::iter; -use futures::{channel::mpsc, stream, SinkExt, StreamExt}; use std::{sync::Arc, time::Duration}; use tari_crypto::tari_utilities::hex::Hex; -use tokio::{sync::RwLock, task, time}; +use tokio::{ + sync::{mpsc, RwLock}, + task, + time, +}; #[async_trait] // #[tari_rpc(protocol_name = "/tari/greeting/1.0", server_struct = GreetingServer, client_struct = GreetingClient)] @@ -91,20 +95,11 @@ impl GreetingRpc for GreetingService { } async fn get_greetings(&self, request: Request) -> Result, RpcStatus> { - let (mut tx, rx) = mpsc::channel(1); + let (tx, rx) = mpsc::channel(1); let num = *request.message(); let greetings = self.greetings[..num as usize].to_vec(); task::spawn(async move { - let iter = greetings.into_iter().map(Ok); - let mut stream = stream::iter(iter) - // "Extra" Result::Ok is to satisfy send_all - .map(Ok); - match tx.send_all(&mut stream).await { - Ok(_) => {}, - Err(_err) => { - // Log error - }, - } + let _ = utils::mpsc::send_all(&tx, greetings.into_iter().map(Ok)).await; }); Ok(Streaming::new(rx)) @@ -118,7 +113,7 @@ impl GreetingRpc for GreetingService { } async fn streaming_error2(&self, _: Request<()>) -> Result, RpcStatus> { - let (mut tx, rx) = mpsc::channel(2); + let (tx, rx) = mpsc::channel(2); tx.send(Ok("This is ok".to_string())).await.unwrap(); tx.send(Err(RpcStatus::bad_request("This is a problem"))).await.unwrap(); @@ -151,7 +146,7 @@ impl SlowGreetingService { impl GreetingRpc for SlowGreetingService { async fn say_hello(&self, _: Request) -> Result, RpcStatus> { let delay = *self.delay.read().await; - time::delay_for(delay).await; + time::sleep(delay).await; Ok(Response::new(SayHelloResponse { greeting: "took a while to load".to_string(), })) @@ -376,8 +371,8 @@ impl GreetingClient { self.inner.ping().await } - pub fn close(&mut self) { - self.inner.close(); + pub async fn close(&mut self) { + self.inner.close().await; } } diff --git a/comms/src/protocol/rpc/test/handshake.rs b/comms/src/protocol/rpc/test/handshake.rs index cdd79746f2..9a21628012 100644 --- a/comms/src/protocol/rpc/test/handshake.rs +++ b/comms/src/protocol/rpc/test/handshake.rs @@ -33,7 +33,7 @@ use crate::{ }; use tari_test_utils::unpack_enum; -#[runtime::test_basic] +#[runtime::test] async fn it_performs_the_handshake() { let (client, server) = MemorySocket::new_pair(); @@ -51,7 +51,7 @@ async fn it_performs_the_handshake() { assert!(SUPPORTED_RPC_VERSIONS.contains(&v)); } -#[runtime::test_basic] +#[runtime::test] async fn it_rejects_the_handshake() { let (client, server) = MemorySocket::new_pair(); diff --git a/comms/src/protocol/rpc/test/smoke.rs b/comms/src/protocol/rpc/test/smoke.rs index a762ac4c9c..3149e794e3 100644 --- a/comms/src/protocol/rpc/test/smoke.rs +++ b/comms/src/protocol/rpc/test/smoke.rs @@ -52,12 +52,15 @@ use crate::{ test_utils::node_identity::build_node_identity, NodeIdentity, }; -use futures::{channel::mpsc, future, future::Either, SinkExt, StreamExt}; +use futures::{future, future::Either, StreamExt}; use std::{sync::Arc, time::Duration}; use tari_crypto::tari_utilities::hex::Hex; use tari_shutdown::Shutdown; use tari_test_utils::unpack_enum; -use tokio::{sync::RwLock, task}; +use tokio::{ + sync::{mpsc, RwLock}, + task, +}; pub(super) async fn setup_service( service_impl: T, @@ -85,7 +88,7 @@ pub(super) async fn setup_service( futures::pin_mut!(fut); match future::select(shutdown_signal, fut).await { - Either::Left((r, _)) => r.unwrap(), + Either::Left(_) => {}, Either::Right((r, _)) => r.unwrap(), } } @@ -97,7 +100,7 @@ pub(super) async fn setup( service_impl: T, num_concurrent_sessions: usize, ) -> (MemorySocket, task::JoinHandle<()>, Arc, Shutdown) { - let (mut notif_tx, server_hnd, context, shutdown) = setup_service(service_impl, num_concurrent_sessions).await; + let (notif_tx, server_hnd, context, shutdown) = setup_service(service_impl, num_concurrent_sessions).await; let (inbound, outbound) = MemorySocket::new_pair(); let node_identity = build_node_identity(Default::default()); @@ -114,7 +117,7 @@ pub(super) async fn setup( (outbound, server_hnd, node_identity, shutdown) } -#[runtime::test_basic] +#[runtime::test] async fn request_response_errors_and_streaming() { let (socket, server_hnd, node_identity, mut shutdown) = setup(GreetingService::default(), 1).await; @@ -171,7 +174,7 @@ async fn request_response_errors_and_streaming() { let pk_hex = client.get_public_key_hex().await.unwrap(); assert_eq!(pk_hex, node_identity.public_key().to_hex()); - client.close(); + client.close().await; let err = client .say_hello(SayHelloRequest { @@ -181,13 +184,20 @@ async fn request_response_errors_and_streaming() { .await .unwrap_err(); - unpack_enum!(RpcError::ClientClosed = err); + match err { + // Because of the race between closing the request stream and sending on that stream in the above call + // We can either get "this client was closed" or "the request you made was cancelled". + // If we delay some small time, we'll always get the former (but arbitrary delays cause flakiness and should be + // avoided) + RpcError::ClientClosed | RpcError::RequestCancelled => {}, + err => panic!("Unexpected error {:?}", err), + } - shutdown.trigger().unwrap(); + shutdown.trigger(); server_hnd.await.unwrap(); } -#[runtime::test_basic] +#[runtime::test] async fn concurrent_requests() { let (socket, _, _, _shutdown) = setup(GreetingService::default(), 1).await; @@ -227,7 +237,7 @@ async fn concurrent_requests() { assert_eq!(spawned2.await.unwrap(), GreetingService::DEFAULT_GREETINGS[..5]); } -#[runtime::test_basic] +#[runtime::test] async fn response_too_big() { let (socket, _, _, _shutdown) = setup(GreetingService::new(&[]), 1).await; @@ -248,7 +258,7 @@ async fn response_too_big() { let _ = client.reply_with_msg_of_size(max_size as u64).await.unwrap(); } -#[runtime::test_basic] +#[runtime::test] async fn ping_latency() { let (socket, _, _, _shutdown) = setup(GreetingService::new(&[]), 1).await; @@ -261,11 +271,11 @@ async fn ping_latency() { assert!(latency.as_secs() < 5); } -#[runtime::test_basic] +#[runtime::test] async fn server_shutdown_before_connect() { let (socket, _, _, mut shutdown) = setup(GreetingService::new(&[]), 1).await; let framed = framing::canonical(socket, 1024); - shutdown.trigger().unwrap(); + shutdown.trigger(); let err = GreetingClient::connect(framed).await.unwrap_err(); assert!(matches!( @@ -274,7 +284,7 @@ async fn server_shutdown_before_connect() { )); } -#[runtime::test_basic] +#[runtime::test] async fn timeout() { let delay = Arc::new(RwLock::new(Duration::from_secs(10))); let (socket, _, _, _shutdown) = setup(SlowGreetingService::new(delay.clone()), 1).await; @@ -298,9 +308,9 @@ async fn timeout() { assert_eq!(resp.greeting, "took a while to load"); } -#[runtime::test_basic] +#[runtime::test] async fn unknown_protocol() { - let (mut notif_tx, _, _, _shutdown) = setup_service(GreetingService::new(&[]), 1).await; + let (notif_tx, _, _, _shutdown) = setup_service(GreetingService::new(&[]), 1).await; let (inbound, socket) = MemorySocket::new_pair(); let node_identity = build_node_identity(Default::default()); @@ -324,7 +334,7 @@ async fn unknown_protocol() { )); } -#[runtime::test_basic] +#[runtime::test] async fn rejected_no_sessions_available() { let (socket, _, _, _shutdown) = setup(GreetingService::new(&[]), 0).await; let framed = framing::canonical(socket, 1024); diff --git a/comms/src/common/rate_limit.rs b/comms/src/rate_limit.rs similarity index 79% rename from comms/src/common/rate_limit.rs rename to comms/src/rate_limit.rs index 1705397d8e..3a06c70040 100644 --- a/comms/src/common/rate_limit.rs +++ b/comms/src/rate_limit.rs @@ -26,7 +26,7 @@ // This is slightly changed from the libra rate limiter implementation -use futures::{stream::Fuse, FutureExt, Stream, StreamExt}; +use futures::FutureExt; use pin_project::pin_project; use std::{ future::Future, @@ -36,10 +36,11 @@ use std::{ time::Duration, }; use tokio::{ - sync::{OwnedSemaphorePermit, Semaphore}, + sync::{AcquireError, OwnedSemaphorePermit, Semaphore}, time, time::Interval, }; +use tokio_stream::Stream; pub trait RateLimit: Stream { /// Consumes the stream and returns a rate-limited stream that only polls the underlying stream @@ -60,12 +61,13 @@ pub struct RateLimiter { stream: T, /// An interval stream that "restocks" the permits #[pin] - interval: Fuse, + interval: Interval, /// The maximum permits to issue capacity: usize, /// A semaphore that holds the permits permits: Arc, - permit_future: Option + Send>>>, + #[allow(clippy::type_complexity)] + permit_future: Option> + Send>>>, permit_acquired: bool, } @@ -75,7 +77,7 @@ impl RateLimiter { stream, capacity, - interval: time::interval(restock_interval).fuse(), + interval: time::interval(restock_interval), // `interval` starts immediately, so we can start with zero permits permits: Arc::new(Semaphore::new(0)), permit_future: None, @@ -89,7 +91,7 @@ impl Stream for RateLimiter { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { // "Restock" permits once interval is ready - if let Poll::Ready(Some(_)) = self.as_mut().project().interval.poll_next(cx) { + if self.as_mut().project().interval.poll_tick(cx).is_ready() { self.permits .add_permits(self.capacity - self.permits.available_permits()); } @@ -103,6 +105,8 @@ impl Stream for RateLimiter { } // Wait until a permit is acquired + // `unwrap()` is safe because acquire_owned only panics if the semaphore has closed, but we never close it + // for the lifetime of this instance let permit = futures::ready!(self .as_mut() .project() @@ -110,7 +114,8 @@ impl Stream for RateLimiter { .as_mut() .unwrap() .as_mut() - .poll(cx)); + .poll(cx)) + .unwrap(); // Don't release the permit on drop, `interval` will restock permits permit.forget(); let this = self.as_mut().project(); @@ -130,45 +135,54 @@ impl Stream for RateLimiter { mod test { use super::*; use crate::runtime; - use futures::{future::Either, stream}; + use futures::{stream, StreamExt}; - #[runtime::test_basic] + #[runtime::test] async fn rate_limit() { let repeater = stream::repeat(()); - let mut rate_limited = repeater.rate_limit(10, Duration::from_secs(100)).fuse(); + let mut rate_limited = repeater.rate_limit(10, Duration::from_secs(100)); - let mut timeout = time::delay_for(Duration::from_millis(50)).fuse(); + let timeout = time::sleep(Duration::from_millis(50)); + tokio::pin!(timeout); let mut count = 0usize; loop { - let either = futures::future::select(rate_limited.select_next_some(), timeout).await; - match either { - Either::Left((_, to)) => { + let item = tokio::select! { + biased; + _ = &mut timeout => None, + item = rate_limited.next() => item, + }; + + match item { + Some(_) => { count += 1; - timeout = to; }, - Either::Right(_) => break, + None => break, } } assert_eq!(count, 10); } - #[runtime::test_basic] + #[runtime::test] async fn rate_limit_restock() { let repeater = stream::repeat(()); - let mut rate_limited = repeater.rate_limit(10, Duration::from_millis(10)).fuse(); + let mut rate_limited = repeater.rate_limit(10, Duration::from_millis(10)); - let mut timeout = time::delay_for(Duration::from_millis(50)).fuse(); + let timeout = time::sleep(Duration::from_millis(50)); + tokio::pin!(timeout); let mut count = 0usize; loop { - let either = futures::future::select(rate_limited.select_next_some(), timeout).await; - match either { - Either::Left((_, to)) => { + let item = tokio::select! { + biased; + _ = &mut timeout => None, + item = rate_limited.next() => item, + }; + match item { + Some(_) => { count += 1; - timeout = to; }, - Either::Right(_) => break, + None => break, } } // Test that at least 1 restock happens. diff --git a/comms/src/runtime.rs b/comms/src/runtime.rs index a6f7e615f7..48752c05d0 100644 --- a/comms/src/runtime.rs +++ b/comms/src/runtime.rs @@ -25,10 +25,7 @@ use tokio::runtime; // Re-export pub use tokio::{runtime::Handle, task}; -#[cfg(test)] -pub use tokio_macros::test; -#[cfg(test)] -pub use tokio_macros::test_basic; +pub use tokio::test; /// Return the current tokio executor. Panics if the tokio runtime is not started. #[inline] diff --git a/comms/src/socks/client.rs b/comms/src/socks/client.rs index f155f2f420..7580c1dbcf 100644 --- a/comms/src/socks/client.rs +++ b/comms/src/socks/client.rs @@ -23,7 +23,6 @@ // Acknowledgement to @sticnarf for tokio-socks on which this code is based use super::error::SocksError; use data_encoding::BASE32; -use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use multiaddr::{Multiaddr, Protocol}; use std::{ borrow::Cow, @@ -31,6 +30,7 @@ use std::{ fmt::Formatter, net::{Ipv4Addr, Ipv6Addr}, }; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; pub type Result = std::result::Result; @@ -104,7 +104,7 @@ where TSocket: AsyncRead + AsyncWrite + Unpin /// Connects to a address through a SOCKS5 proxy and returns the 'upgraded' socket. This consumes the /// `Socks5Client` as once connected, the socks protocol does not recognise any further commands. - #[tracing::instrument(name = "socks::connect", skip(self), err)] + #[tracing::instrument(name = "socks::connect", skip(self))] pub async fn connect(mut self, address: &Multiaddr) -> Result<(TSocket, Multiaddr)> { let address = self.execute_command(Command::Connect, address).await?; Ok((self.protocol.socket, address)) @@ -112,7 +112,7 @@ where TSocket: AsyncRead + AsyncWrite + Unpin /// Requests the tor proxy to resolve a DNS address is resolved into an IP address. /// This operation only works with the tor SOCKS proxy. - #[tracing::instrument(name = "socks:tor_resolve", skip(self), err)] + #[tracing::instrument(name = "socks:tor_resolve", skip(self))] pub async fn tor_resolve(&mut self, address: &Multiaddr) -> Result { // Tor resolve does not return the port back let (dns, rest) = multiaddr_split_first(&address); @@ -126,7 +126,7 @@ where TSocket: AsyncRead + AsyncWrite + Unpin /// Requests the tor proxy to reverse resolve an IP address into a DNS address if it is able. /// This operation only works with the tor SOCKS proxy. - #[tracing::instrument(name = "socks::tor_resolve_ptr", skip(self), err)] + #[tracing::instrument(name = "socks::tor_resolve_ptr", skip(self))] pub async fn tor_resolve_ptr(&mut self, address: &Multiaddr) -> Result { self.execute_command(Command::TorResolvePtr, address).await } diff --git a/comms/src/test_utils/mocks/connection_manager.rs b/comms/src/test_utils/mocks/connection_manager.rs index cc489af60e..ece7224f44 100644 --- a/comms/src/test_utils/mocks/connection_manager.rs +++ b/comms/src/test_utils/mocks/connection_manager.rs @@ -31,7 +31,6 @@ use crate::{ peer_manager::NodeId, runtime::task, }; -use futures::{channel::mpsc, lock::Mutex, stream::Fuse, StreamExt}; use std::{ collections::HashMap, sync::{ @@ -39,14 +38,14 @@ use std::{ Arc, }, }; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc, Mutex}; pub fn create_connection_manager_mock() -> (ConnectionManagerRequester, ConnectionManagerMock) { let (tx, rx) = mpsc::channel(10); let (event_tx, _) = broadcast::channel(10); ( ConnectionManagerRequester::new(tx, event_tx.clone()), - ConnectionManagerMock::new(rx.fuse(), event_tx), + ConnectionManagerMock::new(rx, event_tx), ) } @@ -97,13 +96,13 @@ impl ConnectionManagerMockState { } pub struct ConnectionManagerMock { - receiver: Fuse>, + receiver: mpsc::Receiver, state: ConnectionManagerMockState, } impl ConnectionManagerMock { pub fn new( - receiver: Fuse>, + receiver: mpsc::Receiver, event_tx: broadcast::Sender>, ) -> Self { Self { @@ -121,7 +120,7 @@ impl ConnectionManagerMock { } pub async fn run(mut self) { - while let Some(req) = self.receiver.next().await { + while let Some(req) = self.receiver.recv().await { self.handle_request(req).await; } } diff --git a/comms/src/test_utils/mocks/connectivity_manager.rs b/comms/src/test_utils/mocks/connectivity_manager.rs index 122a60127b..78394f428b 100644 --- a/comms/src/test_utils/mocks/connectivity_manager.rs +++ b/comms/src/test_utils/mocks/connectivity_manager.rs @@ -22,32 +22,36 @@ use crate::{ connection_manager::{ConnectionManagerError, PeerConnection}, - connectivity::{ConnectivityEvent, ConnectivityRequest, ConnectivityRequester, ConnectivityStatus}, + connectivity::{ + ConnectivityEvent, + ConnectivityEventTx, + ConnectivityRequest, + ConnectivityRequester, + ConnectivityStatus, + }, peer_manager::NodeId, runtime::task, }; -use futures::{ - channel::{mpsc, oneshot}, - lock::Mutex, - stream::Fuse, - StreamExt, -}; +use futures::lock::Mutex; use std::{collections::HashMap, sync::Arc, time::Duration}; -use tokio::{sync::broadcast, time}; +use tokio::{ + sync::{broadcast, mpsc, oneshot}, + time, +}; pub fn create_connectivity_mock() -> (ConnectivityRequester, ConnectivityManagerMock) { let (tx, rx) = mpsc::channel(10); let (event_tx, _) = broadcast::channel(10); ( ConnectivityRequester::new(tx, event_tx.clone()), - ConnectivityManagerMock::new(rx.fuse(), event_tx), + ConnectivityManagerMock::new(rx, event_tx), ) } #[derive(Debug, Clone)] pub struct ConnectivityManagerMockState { inner: Arc>, - event_tx: broadcast::Sender>, + event_tx: ConnectivityEventTx, } #[derive(Debug, Default)] @@ -61,7 +65,7 @@ struct State { } impl ConnectivityManagerMockState { - pub fn new(event_tx: broadcast::Sender>) -> Self { + pub fn new(event_tx: ConnectivityEventTx) -> Self { Self { event_tx, inner: Default::default(), @@ -132,7 +136,7 @@ impl ConnectivityManagerMockState { count, self.call_count().await ); - time::delay_for(Duration::from_millis(100)).await; + time::sleep(Duration::from_millis(100)).await; } } @@ -156,9 +160,8 @@ impl ConnectivityManagerMockState { .await } - #[allow(dead_code)] pub fn publish_event(&self, event: ConnectivityEvent) { - self.event_tx.send(Arc::new(event)).unwrap(); + self.event_tx.send(event).unwrap(); } pub(self) async fn with_state(&self, f: F) -> R @@ -169,15 +172,12 @@ impl ConnectivityManagerMockState { } pub struct ConnectivityManagerMock { - receiver: Fuse>, + receiver: mpsc::Receiver, state: ConnectivityManagerMockState, } impl ConnectivityManagerMock { - pub fn new( - receiver: Fuse>, - event_tx: broadcast::Sender>, - ) -> Self { + pub fn new(receiver: mpsc::Receiver, event_tx: ConnectivityEventTx) -> Self { Self { receiver, state: ConnectivityManagerMockState::new(event_tx), @@ -195,7 +195,7 @@ impl ConnectivityManagerMock { } pub async fn run(mut self) { - while let Some(req) = self.receiver.next().await { + while let Some(req) = self.receiver.recv().await { self.handle_request(req).await; } } diff --git a/comms/src/test_utils/mocks/peer_connection.rs b/comms/src/test_utils/mocks/peer_connection.rs index 13b0c77bbf..bccdec9575 100644 --- a/comms/src/test_utils/mocks/peer_connection.rs +++ b/comms/src/test_utils/mocks/peer_connection.rs @@ -34,15 +34,18 @@ use crate::{ peer_manager::{NodeId, Peer, PeerFeatures}, test_utils::{node_identity::build_node_identity, transport}, }; -use futures::{channel::mpsc, lock::Mutex, StreamExt}; use std::sync::{ atomic::{AtomicUsize, Ordering}, Arc, }; -use tokio::runtime::Handle; +use tokio::{ + runtime::Handle, + sync::{mpsc, Mutex}, +}; +use tokio_stream::StreamExt; pub fn create_dummy_peer_connection(node_id: NodeId) -> (PeerConnection, mpsc::Receiver) { - let (tx, rx) = mpsc::channel(0); + let (tx, rx) = mpsc::channel(1); ( PeerConnection::new( 1, @@ -114,7 +117,7 @@ pub async fn new_peer_connection_mock_pair() -> ( create_peer_connection_mock_pair(peer1, peer2).await } -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct PeerConnectionMockState { call_count: Arc, mux_control: Arc>, @@ -181,7 +184,7 @@ impl PeerConnectionMock { } pub async fn run(mut self) { - while let Some(req) = self.receiver.next().await { + while let Some(req) = self.receiver.recv().await { self.handle_request(req).await; } } diff --git a/comms/src/test_utils/test_node.rs b/comms/src/test_utils/test_node.rs index a150a765f8..3e5d7229a0 100644 --- a/comms/src/test_utils/test_node.rs +++ b/comms/src/test_utils/test_node.rs @@ -29,12 +29,14 @@ use crate::{ protocol::Protocols, transports::Transport, }; -use futures::{channel::mpsc, AsyncRead, AsyncWrite}; use rand::rngs::OsRng; use std::{sync::Arc, time::Duration}; use tari_shutdown::ShutdownSignal; use tari_storage::HashmapDatabase; -use tokio::sync::broadcast; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{broadcast, mpsc}, +}; #[derive(Clone, Debug)] pub struct TestNodeConfig { diff --git a/comms/src/tor/control_client/client.rs b/comms/src/tor/control_client/client.rs index 5d14c770b7..573bce60c0 100644 --- a/comms/src/tor/control_client/client.rs +++ b/comms/src/tor/control_client/client.rs @@ -34,10 +34,13 @@ use crate::{ tor::control_client::{event::TorControlEvent, monitor::spawn_monitor}, transports::{TcpTransport, Transport}, }; -use futures::{channel::mpsc, AsyncRead, AsyncWrite, SinkExt, StreamExt}; use log::*; use std::{borrow::Cow, fmt, fmt::Display, num::NonZeroU16}; -use tokio::sync::broadcast; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{broadcast, mpsc}, +}; +use tokio_stream::wrappers::BroadcastStream; /// Client for the Tor control port. /// @@ -80,8 +83,8 @@ impl TorControlPortClient { &self.event_tx } - pub fn get_event_stream(&self) -> broadcast::Receiver { - self.event_tx.subscribe() + pub fn get_event_stream(&self) -> BroadcastStream { + BroadcastStream::new(self.event_tx.subscribe()) } /// Authenticate with the tor control port @@ -232,8 +235,7 @@ impl TorControlPortClient { } async fn receive_line(&mut self) -> Result { - let line = self.output_stream.next().await.ok_or(TorClientError::UnexpectedEof)?; - + let line = self.output_stream.recv().await.ok_or(TorClientError::UnexpectedEof)?; Ok(line) } } @@ -273,9 +275,11 @@ mod test { runtime, tor::control_client::{test_server, test_server::canned_responses, types::PrivateKey}, }; - use futures::{future, AsyncWriteExt}; + use futures::future; use std::net::SocketAddr; use tari_test_utils::unpack_enum; + use tokio::io::AsyncWriteExt; + use tokio_stream::StreamExt; async fn setup_test() -> (TorControlPortClient, test_server::State) { let (_, mock_state, socket) = test_server::spawn().await; @@ -298,7 +302,7 @@ mod test { let _out_sock = result_out.unwrap(); let (mut in_sock, _) = result_in.unwrap().unwrap(); in_sock.write(b"test123").await.unwrap(); - in_sock.close().await.unwrap(); + in_sock.shutdown().await.unwrap(); } #[runtime::test] diff --git a/comms/src/tor/control_client/monitor.rs b/comms/src/tor/control_client/monitor.rs index a5191b466d..72eb6b88ef 100644 --- a/comms/src/tor/control_client/monitor.rs +++ b/comms/src/tor/control_client/monitor.rs @@ -21,11 +21,14 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use super::{event::TorControlEvent, parsers, response::ResponseLine, LOG_TARGET}; -use crate::{compat::IoCompat, runtime::task}; -use futures::{channel::mpsc, future, future::Either, AsyncRead, AsyncWrite, SinkExt, Stream, StreamExt}; +use crate::runtime::task; +use futures::{future::Either, SinkExt, Stream, StreamExt}; use log::*; use std::fmt; -use tokio::sync::broadcast; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{broadcast, mpsc}, +}; use tokio_util::codec::{Framed, LinesCodec}; pub fn spawn_monitor( @@ -36,16 +39,19 @@ pub fn spawn_monitor( where TSocket: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - let (mut responses_tx, responses_rx) = mpsc::channel(100); + let (responses_tx, responses_rx) = mpsc::channel(100); task::spawn(async move { - let framed = Framed::new(IoCompat::new(socket), LinesCodec::new()); + let framed = Framed::new(socket, LinesCodec::new()); let (mut sink, mut stream) = framed.split(); loop { - let either = future::select(cmd_rx.next(), stream.next()).await; + let either = tokio::select! { + next = cmd_rx.recv() => Either::Left(next), + next = stream.next() => Either::Right(next), + }; match either { // Received a command to send to the control server - Either::Left((Some(line), _)) => { + Either::Left(Some(line)) => { trace!(target: LOG_TARGET, "Writing command of length '{}'", line.len()); if let Err(err) = sink.send(line).await { error!( @@ -56,7 +62,7 @@ where } }, // Command stream ended - Either::Left((None, _)) => { + Either::Left(None) => { debug!( target: LOG_TARGET, "Tor control server command receiver closed. Monitor is exiting." @@ -65,7 +71,7 @@ where }, // Received a line from the control server - Either::Right((Some(Ok(line)), _)) => { + Either::Right(Some(Ok(line))) => { trace!(target: LOG_TARGET, "Read line of length '{}'", line.len()); match parsers::response_line(&line) { Ok(mut line) => { @@ -95,7 +101,7 @@ where }, // Error receiving a line from the control server - Either::Right((Some(Err(err)), _)) => { + Either::Right(Some(Err(err))) => { error!( target: LOG_TARGET, "Line framing error when reading from tor control server: '{:?}'. Monitor is exiting.", err @@ -103,7 +109,7 @@ where break; }, // The control server disconnected - Either::Right((None, _)) => { + Either::Right(None) => { cmd_rx.close(); debug!( target: LOG_TARGET, diff --git a/comms/src/tor/control_client/test_server.rs b/comms/src/tor/control_client/test_server.rs index 5a5e1b3b7c..1741cfc721 100644 --- a/comms/src/tor/control_client/test_server.rs +++ b/comms/src/tor/control_client/test_server.rs @@ -20,13 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{ - compat::IoCompat, - memsocket::MemorySocket, - multiaddr::Multiaddr, - runtime, - test_utils::transport::build_connected_sockets, -}; +use crate::{memsocket::MemorySocket, multiaddr::Multiaddr, runtime, test_utils::transport::build_connected_sockets}; use futures::{lock::Mutex, stream, SinkExt, StreamExt}; use std::sync::Arc; use tokio_util::codec::{Framed, LinesCodec}; @@ -82,7 +76,7 @@ impl TorControlPortTestServer { } pub async fn run(self) { - let mut framed = Framed::new(IoCompat::new(self.socket), LinesCodec::new()); + let mut framed = Framed::new(self.socket, LinesCodec::new()); let state = self.state; while let Some(msg) = framed.next().await { state.request_lines.lock().await.push(msg.unwrap()); diff --git a/comms/src/tor/hidden_service/controller.rs b/comms/src/tor/hidden_service/controller.rs index e19818df58..74b89808fb 100644 --- a/comms/src/tor/hidden_service/controller.rs +++ b/comms/src/tor/hidden_service/controller.rs @@ -214,7 +214,7 @@ impl HiddenServiceController { "Failed to reestablish connection with tor control server because '{:?}'", err ); warn!(target: LOG_TARGET, "Will attempt again in 5 seconds..."); - time::delay_for(Duration::from_secs(5)).await; + time::sleep(Duration::from_secs(5)).await; }, Either::Right(_) => { diff --git a/comms/src/transports/dns/tor.rs b/comms/src/transports/dns/tor.rs index 737971af9a..e198efdc0b 100644 --- a/comms/src/transports/dns/tor.rs +++ b/comms/src/transports/dns/tor.rs @@ -69,7 +69,7 @@ impl DnsResolver for TorDnsResolver { let resolved = match client.tor_resolve(&addr).await { Ok(a) => a, Err(err) => { - error!(target: LOG_TARGET, "{}", err); + error!(target: LOG_TARGET, "Error resolving address: {}", err); return Err(err.into()); }, }; diff --git a/comms/src/transports/memory.rs b/comms/src/transports/memory.rs index 074e728274..8e3a1d7a91 100644 --- a/comms/src/transports/memory.rs +++ b/comms/src/transports/memory.rs @@ -129,7 +129,8 @@ impl Stream for Listener { mod test { use super::*; use crate::runtime; - use futures::{future::join, stream::StreamExt, AsyncReadExt, AsyncWriteExt}; + use futures::{future::join, stream::StreamExt}; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; #[runtime::test] async fn simple_listen_and_dial() -> Result<(), ::std::io::Error> { diff --git a/comms/src/transports/mod.rs b/comms/src/transports/mod.rs index 0c04e50a72..75b16db975 100644 --- a/comms/src/transports/mod.rs +++ b/comms/src/transports/mod.rs @@ -24,8 +24,8 @@ // Copyright (c) The Libra Core Contributors // SPDX-License-Identifier: Apache-2.0 -use futures::Stream; use multiaddr::Multiaddr; +use tokio_stream::Stream; mod dns; mod helpers; @@ -37,7 +37,7 @@ mod socks; pub use socks::{SocksConfig, SocksTransport}; mod tcp; -pub use tcp::{TcpSocket, TcpTransport}; +pub use tcp::TcpTransport; mod tcp_with_tor; pub use tcp_with_tor::TcpWithTorTransport; diff --git a/comms/src/transports/socks.rs b/comms/src/transports/socks.rs index 2027f34561..7dc87ef0e4 100644 --- a/comms/src/transports/socks.rs +++ b/comms/src/transports/socks.rs @@ -24,12 +24,13 @@ use crate::{ multiaddr::Multiaddr, socks, socks::Socks5Client, - transports::{dns::SystemDnsResolver, tcp::TcpTransport, TcpSocket, Transport}, + transports::{dns::SystemDnsResolver, tcp::TcpTransport, Transport}, }; -use std::{io, time::Duration}; +use std::io; +use tokio::net::TcpStream; -/// SO_KEEPALIVE setting for the SOCKS TCP connection -const SOCKS_SO_KEEPALIVE: Duration = Duration::from_millis(1500); +// /// SO_KEEPALIVE setting for the SOCKS TCP connection +// const SOCKS_SO_KEEPALIVE: Duration = Duration::from_millis(1500); #[derive(Clone, Debug)] pub struct SocksConfig { @@ -57,7 +58,7 @@ impl SocksTransport { pub fn create_socks_tcp_transport() -> TcpTransport { let mut tcp_transport = TcpTransport::new(); tcp_transport.set_nodelay(true); - tcp_transport.set_keepalive(Some(SOCKS_SO_KEEPALIVE)); + // .set_keepalive(Some(SOCKS_SO_KEEPALIVE)) tcp_transport.set_dns_resolver(SystemDnsResolver); tcp_transport } @@ -66,7 +67,7 @@ impl SocksTransport { tcp: TcpTransport, socks_config: SocksConfig, dest_addr: Multiaddr, - ) -> io::Result { + ) -> io::Result { // Create a new connection to the SOCKS proxy let socks_conn = tcp.dial(socks_config.proxy_address).await?; let mut client = Socks5Client::new(socks_conn); diff --git a/comms/src/transports/tcp.rs b/comms/src/transports/tcp.rs index 8112cca203..6b47e7c357 100644 --- a/comms/src/transports/tcp.rs +++ b/comms/src/transports/tcp.rs @@ -25,44 +25,42 @@ use crate::{ transports::dns::{DnsResolverRef, SystemDnsResolver}, utils::multiaddr::socketaddr_to_multiaddr, }; -use futures::{io::Error, ready, AsyncRead, AsyncWrite, Future, FutureExt, Stream}; +use futures::{ready, FutureExt}; use multiaddr::Multiaddr; use std::{ + future::Future, io, pin::Pin, sync::Arc, task::{Context, Poll}, - time::Duration, -}; -use tokio::{ - io::{AsyncRead as TokioAsyncRead, AsyncWrite as TokioAsyncWrite}, - net::{TcpListener, TcpStream}, }; +use tokio::net::{TcpListener, TcpStream}; +use tokio_stream::Stream; /// Transport implementation for TCP #[derive(Clone)] pub struct TcpTransport { - recv_buffer_size: Option, - send_buffer_size: Option, + // recv_buffer_size: Option, + // send_buffer_size: Option, ttl: Option, - #[allow(clippy::option_option)] - keepalive: Option>, + // #[allow(clippy::option_option)] + // keepalive: Option>, nodelay: Option, dns_resolver: DnsResolverRef, } impl TcpTransport { - #[doc("Sets `SO_RCVBUF` i.e the size of the receive buffer.")] - setter_mut!(set_recv_buffer_size, recv_buffer_size, Option); - - #[doc("Sets `SO_SNDBUF` i.e. the size of the send buffer.")] - setter_mut!(set_send_buffer_size, send_buffer_size, Option); + // #[doc("Sets `SO_RCVBUF` i.e the size of the receive buffer.")] + // setter_mut!(set_recv_buffer_size, recv_buffer_size, Option); + // + // #[doc("Sets `SO_SNDBUF` i.e. the size of the send buffer.")] + // setter_mut!(set_send_buffer_size, send_buffer_size, Option); #[doc("Sets `IP_TTL` i.e. the TTL of packets sent from this socket.")] setter_mut!(set_ttl, ttl, Option); - #[doc("Sets `SO_KEEPALIVE` i.e. the interval to send keepalive probes, or None to disable.")] - setter_mut!(set_keepalive, keepalive, Option>); + // #[doc("Sets `SO_KEEPALIVE` i.e. the interval to send keepalive probes, or None to disable.")] + // setter_mut!(set_keepalive, keepalive, Option>); #[doc("Sets `TCP_NODELAY` i.e disable Nagle's algorithm if set to true.")] setter_mut!(set_nodelay, nodelay, Option); @@ -81,9 +79,10 @@ impl TcpTransport { /// Apply socket options to `TcpStream`. fn configure(&self, socket: &TcpStream) -> io::Result<()> { - if let Some(keepalive) = self.keepalive { - socket.set_keepalive(keepalive)?; - } + // https://github.com/rust-lang/rust/issues/69774 + // if let Some(keepalive) = self.keepalive { + // socket.set_keepalive(keepalive)?; + // } if let Some(ttl) = self.ttl { socket.set_ttl(ttl)?; @@ -93,13 +92,13 @@ impl TcpTransport { socket.set_nodelay(nodelay)?; } - if let Some(recv_buffer_size) = self.recv_buffer_size { - socket.set_recv_buffer_size(recv_buffer_size)?; - } - - if let Some(send_buffer_size) = self.send_buffer_size { - socket.set_send_buffer_size(send_buffer_size)?; - } + // if let Some(recv_buffer_size) = self.recv_buffer_size { + // socket.set_recv_buffer_size(recv_buffer_size)?; + // } + // + // if let Some(send_buffer_size) = self.send_buffer_size { + // socket.set_send_buffer_size(send_buffer_size)?; + // } Ok(()) } @@ -108,10 +107,10 @@ impl TcpTransport { impl Default for TcpTransport { fn default() -> Self { Self { - recv_buffer_size: None, - send_buffer_size: None, + // recv_buffer_size: None, + // send_buffer_size: None, ttl: None, - keepalive: None, + // keepalive: None, nodelay: None, dns_resolver: Arc::new(SystemDnsResolver), } @@ -122,7 +121,7 @@ impl Default for TcpTransport { impl Transport for TcpTransport { type Error = io::Error; type Listener = TcpInbound; - type Output = TcpSocket; + type Output = TcpStream; async fn listen(&self, addr: Multiaddr) -> Result<(Self::Listener, Multiaddr), Self::Error> { let socket_addr = self @@ -161,12 +160,12 @@ impl TcpOutbound { impl Future for TcpOutbound where F: Future> + Unpin { - type Output = io::Result; + type Output = io::Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let socket = ready!(Pin::new(&mut self.future).poll(cx))?; - self.config.configure(&socket)?; - Poll::Ready(Ok(TcpSocket::new(socket))) + let stream = ready!(Pin::new(&mut self.future).poll(cx))?; + self.config.configure(&stream)?; + Poll::Ready(Ok(stream)) } } @@ -184,52 +183,14 @@ impl TcpInbound { } impl Stream for TcpInbound { - type Item = io::Result<(TcpSocket, Multiaddr)>; + type Item = io::Result<(TcpStream, Multiaddr)>; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let (socket, addr) = ready!(self.listener.poll_accept(cx))?; // Configure each socket self.config.configure(&socket)?; let peer_addr = socketaddr_to_multiaddr(&addr); - Poll::Ready(Some(Ok((TcpSocket::new(socket), peer_addr)))) - } -} - -/// TcpSocket is a wrapper struct for tokio `TcpStream` and implements -/// `futures-rs` AsyncRead/Write -pub struct TcpSocket { - inner: TcpStream, -} - -impl TcpSocket { - pub fn new(stream: TcpStream) -> Self { - Self { inner: stream } - } -} - -impl AsyncWrite for TcpSocket { - fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { - Pin::new(&mut self.inner).poll_write(cx, buf) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_flush(cx) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_shutdown(cx) - } -} - -impl AsyncRead for TcpSocket { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { - Pin::new(&mut self.inner).poll_read(cx, buf) - } -} - -impl From for TcpSocket { - fn from(stream: TcpStream) -> Self { - Self { inner: stream } + Poll::Ready(Some(Ok((socket, peer_addr)))) } } @@ -240,16 +201,15 @@ mod test { #[test] fn configure() { let mut tcp = TcpTransport::new(); - tcp.set_send_buffer_size(123) - .set_recv_buffer_size(456) - .set_nodelay(true) - .set_ttl(789) - .set_keepalive(Some(Duration::from_millis(100))); - - assert_eq!(tcp.send_buffer_size, Some(123)); - assert_eq!(tcp.recv_buffer_size, Some(456)); + // tcp.set_send_buffer_size(123) + // .set_recv_buffer_size(456) + tcp.set_nodelay(true).set_ttl(789); + // .set_keepalive(Some(Duration::from_millis(100))); + + // assert_eq!(tcp.send_buffer_size, Some(123)); + // assert_eq!(tcp.recv_buffer_size, Some(456)); assert_eq!(tcp.nodelay, Some(true)); assert_eq!(tcp.ttl, Some(789)); - assert_eq!(tcp.keepalive, Some(Some(Duration::from_millis(100)))); + // assert_eq!(tcp.keepalive, Some(Some(Duration::from_millis(100)))); } } diff --git a/comms/src/transports/tcp_with_tor.rs b/comms/src/transports/tcp_with_tor.rs index de54e17bb5..be800d9bfb 100644 --- a/comms/src/transports/tcp_with_tor.rs +++ b/comms/src/transports/tcp_with_tor.rs @@ -21,16 +21,10 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use super::Transport; -use crate::transports::{ - dns::TorDnsResolver, - helpers::is_onion_address, - SocksConfig, - SocksTransport, - TcpSocket, - TcpTransport, -}; +use crate::transports::{dns::TorDnsResolver, helpers::is_onion_address, SocksConfig, SocksTransport, TcpTransport}; use multiaddr::Multiaddr; use std::io; +use tokio::net::TcpStream; /// Transport implementation for TCP with Tor support #[derive(Clone, Default)] @@ -69,7 +63,7 @@ impl TcpWithTorTransport { impl Transport for TcpWithTorTransport { type Error = io::Error; type Listener = ::Listener; - type Output = TcpSocket; + type Output = TcpStream; async fn listen(&self, addr: Multiaddr) -> Result<(Self::Listener, Multiaddr), Self::Error> { self.tcp_transport.listen(addr).await diff --git a/comms/src/utils/mod.rs b/comms/src/utils/mod.rs index 697a758c1f..e543e644bb 100644 --- a/comms/src/utils/mod.rs +++ b/comms/src/utils/mod.rs @@ -22,5 +22,6 @@ pub mod cidr; pub mod datetime; +pub mod mpsc; pub mod multiaddr; pub mod signature; diff --git a/comms/src/common/mod.rs b/comms/src/utils/mpsc.rs similarity index 84% rename from comms/src/common/mod.rs rename to comms/src/utils/mpsc.rs index 2ac0cd0a6f..8ded39967f 100644 --- a/comms/src/common/mod.rs +++ b/comms/src/utils/mpsc.rs @@ -1,4 +1,4 @@ -// Copyright 2020, The Tari Project +// Copyright 2021, The Tari Project // // Redistribution and use in source and binary forms, with or without modification, are permitted provided that the // following conditions are met: @@ -20,4 +20,14 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -pub mod rate_limit; +use tokio::sync::mpsc; + +pub async fn send_all>( + sender: &mpsc::Sender, + iter: I, +) -> Result<(), mpsc::error::SendError> { + for item in iter { + sender.send(item).await?; + } + Ok(()) +} diff --git a/comms/tests/greeting_service.rs b/comms/tests/greeting_service.rs index c13ae842e4..e70b6e16ff 100644 --- a/comms/tests/greeting_service.rs +++ b/comms/tests/greeting_service.rs @@ -23,14 +23,14 @@ #![cfg(feature = "rpc")] use core::iter; -use futures::{channel::mpsc, stream, SinkExt, StreamExt}; use std::{cmp, time::Duration}; use tari_comms::{ async_trait, protocol::rpc::{Request, Response, RpcStatus, Streaming}, + utils, }; use tari_comms_rpc_macros::tari_rpc; -use tokio::{task, time}; +use tokio::{sync::mpsc, task, time}; #[tari_rpc(protocol_name = b"t/greeting/1", server_struct = GreetingServer, client_struct = GreetingClient)] pub trait GreetingRpc: Send + Sync + 'static { @@ -85,15 +85,9 @@ impl GreetingRpc for GreetingService { async fn get_greetings(&self, request: Request) -> Result, RpcStatus> { let num = *request.message(); - let (mut tx, rx) = mpsc::channel(num as usize); + let (tx, rx) = mpsc::channel(num as usize); let greetings = self.greetings[..cmp::min(num as usize + 1, self.greetings.len())].to_vec(); - task::spawn(async move { - let iter = greetings.into_iter().map(Ok); - let mut stream = stream::iter(iter) - // "Extra" Result::Ok is to satisfy send_all - .map(Ok); - tx.send_all(&mut stream).await.unwrap(); - }); + task::spawn(async move { utils::mpsc::send_all(&tx, greetings.into_iter().map(Ok)).await }); Ok(Streaming::new(rx)) } @@ -113,7 +107,7 @@ impl GreetingRpc for GreetingService { item_size, num_items, } = request.into_message(); - let (mut tx, rx) = mpsc::channel(10); + let (tx, rx) = mpsc::channel(10); let t = std::time::Instant::now(); task::spawn(async move { let item = iter::repeat(0u8).take(item_size as usize).collect::>(); @@ -136,7 +130,7 @@ impl GreetingRpc for GreetingService { } async fn slow_response(&self, request: Request) -> Result, RpcStatus> { - time::delay_for(Duration::from_secs(request.into_message())).await; + time::sleep(Duration::from_secs(request.into_message())).await; Ok(Response::new(())) } } diff --git a/comms/tests/rpc_stress.rs b/comms/tests/rpc_stress.rs index 0376a7dddc..933e158398 100644 --- a/comms/tests/rpc_stress.rs +++ b/comms/tests/rpc_stress.rs @@ -155,6 +155,7 @@ async fn run_stress_test(test_params: Params) { } future::join_all(tasks).await.into_iter().for_each(Result::unwrap); + log::info!("Stress test took {:.2?}", time.elapsed()); } @@ -259,7 +260,7 @@ async fn high_contention_high_concurrency() { .await; } -#[tokio_macros::test] +#[tokio::test] async fn run() { // let _ = env_logger::try_init(); log_timing("quick", quick()).await; diff --git a/comms/tests/substream_stress.rs b/comms/tests/substream_stress.rs index cbae6e8e52..a72eff03f3 100644 --- a/comms/tests/substream_stress.rs +++ b/comms/tests/substream_stress.rs @@ -23,7 +23,7 @@ mod helpers; use helpers::create_comms; -use futures::{channel::mpsc, future, SinkExt, StreamExt}; +use futures::{future, SinkExt, StreamExt}; use std::time::Duration; use tari_comms::{ framing, @@ -35,7 +35,7 @@ use tari_comms::{ }; use tari_shutdown::{Shutdown, ShutdownSignal}; use tari_test_utils::unpack_enum; -use tokio::{task, time::Instant}; +use tokio::{sync::mpsc, task, time::Instant}; const PROTOCOL_NAME: &[u8] = b"test/dummy/protocol"; @@ -79,7 +79,7 @@ async fn run_stress_test(num_substreams: usize, num_iterations: usize, payload_s task::spawn({ let sample = sample.clone(); async move { - while let Some(event) = notif_rx.next().await { + while let Some(event) = notif_rx.recv().await { unpack_enum!(ProtocolEvent::NewInboundSubstream(_n, remote_substream) = event.event); let mut remote_substream = framing::canonical(remote_substream, frame_size); @@ -150,7 +150,7 @@ async fn run_stress_test(num_substreams: usize, num_iterations: usize, payload_s println!("avg t = {}ms", avg); } -#[tokio_macros::test] +#[tokio::test] async fn many_at_frame_limit() { const NUM_SUBSTREAMS: usize = 20; const NUM_ITERATIONS_PER_STREAM: usize = 100; diff --git a/infrastructure/shutdown/Cargo.toml b/infrastructure/shutdown/Cargo.toml index 1102ec037a..176bace6e2 100644 --- a/infrastructure/shutdown/Cargo.toml +++ b/infrastructure/shutdown/Cargo.toml @@ -12,7 +12,7 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -futures = "^0.3.1" +futures = "^0.3" [dev-dependencies] -tokio = {version="^0.2", features=["rt-core"]} +tokio = {version="1", default-features = false, features = ["rt", "macros"]} diff --git a/infrastructure/shutdown/src/lib.rs b/infrastructure/shutdown/src/lib.rs index f0054bd843..5afd56a239 100644 --- a/infrastructure/shutdown/src/lib.rs +++ b/infrastructure/shutdown/src/lib.rs @@ -26,14 +26,16 @@ #![deny(unused_must_use)] #![deny(unreachable_patterns)] #![deny(unknown_lints)] -use futures::{ - channel::{oneshot, oneshot::Canceled}, - future::{Fuse, FusedFuture, Shared}, + +pub mod oneshot_trigger; + +use crate::oneshot_trigger::OneshotSignal; +use futures::future::FusedFuture; +use std::{ + future::Future, + pin::Pin, task::{Context, Poll}, - Future, - FutureExt, }; -use std::pin::Pin; /// Trigger for shutdowns. /// @@ -42,71 +44,69 @@ use std::pin::Pin; /// /// _Note_: This will trigger when dropped, so the `Shutdown` instance should be held as /// long as required by the application. -pub struct Shutdown { - trigger: Option>, - signal: ShutdownSignal, - on_triggered: Option>, -} - +pub struct Shutdown(oneshot_trigger::OneshotTrigger<()>); impl Shutdown { - /// Create a new Shutdown pub fn new() -> Self { - let (tx, rx) = oneshot::channel(); - Self { - trigger: Some(tx), - signal: rx.fuse().shared(), - on_triggered: None, - } + Self(oneshot_trigger::OneshotTrigger::new()) } - /// Set the on_triggered callback - pub fn on_triggered(&mut self, on_trigger: F) -> &mut Self - where F: FnOnce() + Send + Sync + 'static { - self.on_triggered = Some(Box::new(on_trigger)); - self + pub fn trigger(&mut self) { + self.0.broadcast(()); + } + + pub fn is_triggered(&self) -> bool { + self.0.is_used() } - /// Convert this into a ShutdownSignal without consuming the - /// struct. pub fn to_signal(&self) -> ShutdownSignal { - self.signal.clone() + self.0.to_signal().into() } +} - /// Trigger any listening signals - pub fn trigger(&mut self) -> Result<(), ShutdownError> { - match self.trigger.take() { - Some(trigger) => { - trigger.send(()).map_err(|_| ShutdownError)?; +impl Default for Shutdown { + fn default() -> Self { + Self::new() + } +} - if let Some(on_triggered) = self.on_triggered.take() { - on_triggered(); - } +/// Receiver end of a shutdown signal. Once received the consumer should shut down. +#[derive(Debug, Clone)] +pub struct ShutdownSignal(oneshot_trigger::OneshotSignal<()>); - Ok(()) - }, - None => Ok(()), - } +impl ShutdownSignal { + pub fn is_triggered(&self) -> bool { + self.0.is_terminated() } - pub fn is_triggered(&self) -> bool { - self.trigger.is_none() + /// Wait for the shutdown signal to trigger. + pub fn wait(&mut self) -> &mut Self { + self } } -impl Drop for Shutdown { - fn drop(&mut self) { - let _ = self.trigger(); +impl Future for ShutdownSignal { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match Pin::new(&mut self.0).poll(cx) { + // Whether `trigger()` was called Some(()), or the Shutdown dropped (None) we want to resolve this future + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Poll::Pending, + } } } -impl Default for Shutdown { - fn default() -> Self { - Self::new() +impl FusedFuture for ShutdownSignal { + fn is_terminated(&self) -> bool { + self.0.is_terminated() } } -/// Receiver end of a shutdown signal. Once received the consumer should shut down. -pub type ShutdownSignal = Shared>>; +impl From> for ShutdownSignal { + fn from(inner: OneshotSignal<()>) -> Self { + Self(inner) + } +} #[derive(Debug, Clone, Default)] pub struct OptionalShutdownSignal(Option); @@ -137,11 +137,11 @@ impl OptionalShutdownSignal { } impl Future for OptionalShutdownSignal { - type Output = Result<(), Canceled>; + type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.0.as_mut() { - Some(inner) => inner.poll_unpin(cx), + Some(inner) => Pin::new(inner).poll(cx), None => Poll::Pending, } } @@ -165,73 +165,50 @@ impl FusedFuture for OptionalShutdownSignal { } } -#[derive(Debug)] -pub struct ShutdownError; - #[cfg(test)] mod test { use super::*; - use std::sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }; - use tokio::runtime::Runtime; - - #[test] - fn trigger() { - let rt = Runtime::new().unwrap(); + use tokio::task; + + #[tokio::test] + async fn trigger() { let mut shutdown = Shutdown::new(); let signal = shutdown.to_signal(); assert!(!shutdown.is_triggered()); - rt.spawn(async move { - signal.await.unwrap(); + let fut = task::spawn(async move { + signal.await; }); - shutdown.trigger().unwrap(); + shutdown.trigger(); + assert!(shutdown.is_triggered()); // Shutdown::trigger is idempotent - shutdown.trigger().unwrap(); + shutdown.trigger(); assert!(shutdown.is_triggered()); + fut.await.unwrap(); } - #[test] - fn signal_clone() { - let rt = Runtime::new().unwrap(); + #[tokio::test] + async fn signal_clone() { let mut shutdown = Shutdown::new(); let signal = shutdown.to_signal(); let signal_clone = signal.clone(); - rt.spawn(async move { - signal_clone.await.unwrap(); - signal.await.unwrap(); + let fut = task::spawn(async move { + signal_clone.await; + signal.await; }); - shutdown.trigger().unwrap(); + shutdown.trigger(); + fut.await.unwrap(); } - #[test] - fn drop_trigger() { - let rt = Runtime::new().unwrap(); + #[tokio::test] + async fn drop_trigger() { let shutdown = Shutdown::new(); let signal = shutdown.to_signal(); let signal_clone = signal.clone(); - rt.spawn(async move { - signal_clone.await.unwrap(); - signal.await.unwrap(); + let fut = task::spawn(async move { + signal_clone.await; + signal.await; }); drop(shutdown); - } - - #[test] - fn on_trigger() { - let rt = Runtime::new().unwrap(); - let spy = Arc::new(AtomicBool::new(false)); - let spy_clone = Arc::clone(&spy); - let mut shutdown = Shutdown::new(); - shutdown.on_triggered(move || { - spy_clone.store(true, Ordering::SeqCst); - }); - let signal = shutdown.to_signal(); - rt.spawn(async move { - signal.await.unwrap(); - }); - shutdown.trigger().unwrap(); - assert!(spy.load(Ordering::SeqCst)); + fut.await.unwrap(); } } diff --git a/infrastructure/shutdown/src/oneshot_trigger.rs b/infrastructure/shutdown/src/oneshot_trigger.rs new file mode 100644 index 0000000000..7d2ee5b46a --- /dev/null +++ b/infrastructure/shutdown/src/oneshot_trigger.rs @@ -0,0 +1,106 @@ +// Copyright 2021, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use futures::{ + channel::{oneshot, oneshot::Receiver}, + future::{FusedFuture, Shared}, + FutureExt, +}; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +pub fn channel() -> OneshotTrigger { + OneshotTrigger::new() +} + +pub struct OneshotTrigger { + sender: Option>, + signal: OneshotSignal, +} + +impl OneshotTrigger { + pub fn new() -> Self { + let (tx, rx) = oneshot::channel(); + Self { + sender: Some(tx), + signal: rx.shared().into(), + } + } + + pub fn to_signal(&self) -> OneshotSignal { + self.signal.clone() + } + + pub fn broadcast(&mut self, item: T) { + if let Some(tx) = self.sender.take() { + let _ = tx.send(item); + } + } + + pub fn is_used(&self) -> bool { + self.sender.is_none() + } +} + +impl Default for OneshotTrigger { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug, Clone)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct OneshotSignal { + inner: Shared>, +} + +impl From>> for OneshotSignal { + fn from(inner: Shared>) -> Self { + Self { inner } + } +} + +impl Future for OneshotSignal { + type Output = Option; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if self.inner.is_terminated() { + return Poll::Ready(None); + } + + match Pin::new(&mut self.inner).poll(cx) { + Poll::Ready(Ok(v)) => Poll::Ready(Some(v)), + // Channel canceled + Poll::Ready(Err(_)) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +impl FusedFuture for OneshotSignal { + fn is_terminated(&self) -> bool { + self.inner.is_terminated() + } +} diff --git a/infrastructure/storage/Cargo.toml b/infrastructure/storage/Cargo.toml index 5db74c5097..3915bde787 100644 --- a/infrastructure/storage/Cargo.toml +++ b/infrastructure/storage/Cargo.toml @@ -13,13 +13,13 @@ edition = "2018" bincode = "1.1" log = "0.4.0" lmdb-zero = "0.4.4" -thiserror = "1.0.20" +thiserror = "1.0.26" rmp = "0.8.7" rmp-serde = "0.13.7" serde = "1.0.80" serde_derive = "1.0.80" tari_utilities = "^0.3" -bytes = "0.4.12" +bytes = "0.5" [dev-dependencies] rand = "0.8" diff --git a/infrastructure/test_utils/Cargo.toml b/infrastructure/test_utils/Cargo.toml index bee18392c9..c2ea538ffb 100644 --- a/infrastructure/test_utils/Cargo.toml +++ b/infrastructure/test_utils/Cargo.toml @@ -10,9 +10,10 @@ license = "BSD-3-Clause" [dependencies] tari_shutdown = {version="*", path="../shutdown"} + futures-test = { version = "^0.3.1" } futures = {version= "^0.3.1"} rand = "0.8" -tokio = {version= "0.2.10", features=["rt-threaded", "time", "io-driver"]} +tokio = {version= "1.10", features=["rt-multi-thread", "time"]} lazy_static = "1.3.0" tempfile = "3.1.0" diff --git a/infrastructure/test_utils/src/futures/async_assert_eventually.rs b/infrastructure/test_utils/src/futures/async_assert_eventually.rs index cddfb29b37..f9a1f3ee9d 100644 --- a/infrastructure/test_utils/src/futures/async_assert_eventually.rs +++ b/infrastructure/test_utils/src/futures/async_assert_eventually.rs @@ -46,7 +46,7 @@ macro_rules! async_assert_eventually { $max_attempts ); } - tokio::time::delay_for($interval).await; + tokio::time::sleep($interval).await; value = $check_expr; } }}; @@ -82,7 +82,7 @@ macro_rules! async_assert { $max_attempts ); } - tokio::time::delay_for($interval).await; + tokio::time::sleep($interval).await; } }}; ($check_expr:expr$(,)?) => {{ diff --git a/infrastructure/test_utils/src/runtime.rs b/infrastructure/test_utils/src/runtime.rs index 7b8c5faa57..22ee5962a6 100644 --- a/infrastructure/test_utils/src/runtime.rs +++ b/infrastructure/test_utils/src/runtime.rs @@ -26,12 +26,8 @@ use tari_shutdown::Shutdown; use tokio::{runtime, runtime::Runtime, task, task::JoinError}; pub fn create_runtime() -> Runtime { - tokio::runtime::Builder::new() - .threaded_scheduler() - .enable_io() - .enable_time() - .max_threads(8) - .core_threads(4) + tokio::runtime::Builder::new_multi_thread() + .enable_all() .build() .expect("Could not create runtime") } @@ -49,7 +45,7 @@ where F: Future + Send + 'static { /// Create a runtime and report if it panics. If there are tasks still running after the panic, this /// will carry on running forever. -// #[deprecated(note = "use tokio_macros::test instead")] +// #[deprecated(note = "use tokio::test instead")] pub fn test_async(f: F) where F: FnOnce(&mut TestRuntime) { let mut rt = TestRuntime::from(create_runtime()); diff --git a/infrastructure/test_utils/src/streams/mod.rs b/infrastructure/test_utils/src/streams/mod.rs index a70e588f7e..177bea14b1 100644 --- a/infrastructure/test_utils/src/streams/mod.rs +++ b/infrastructure/test_utils/src/streams/mod.rs @@ -20,8 +20,9 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::{Stream, StreamExt}; -use std::{collections::HashMap, hash::Hash, time::Duration}; +use futures::{stream, Stream, StreamExt}; +use std::{borrow::BorrowMut, collections::HashMap, hash::Hash, time::Duration}; +use tokio::sync::{broadcast, mpsc}; #[allow(dead_code)] #[allow(clippy::mutable_key_type)] // Note: Clippy Breaks with Interior Mutability Error @@ -54,7 +55,6 @@ where #[macro_export] macro_rules! collect_stream { ($stream:expr, take=$take:expr, timeout=$timeout:expr $(,)?) => {{ - use futures::{Stream, StreamExt}; use tokio::time; // Evaluate $stream once, NOT in the loop 🐛🚨 @@ -62,14 +62,17 @@ macro_rules! collect_stream { let mut items = Vec::new(); loop { - if let Some(item) = time::timeout($timeout, stream.next()).await.expect( - format!( - "Timeout before stream could collect {} item(s). Got {} item(s).", - $take, - items.len() + if let Some(item) = time::timeout($timeout, futures::stream::StreamExt::next(stream)) + .await + .expect( + format!( + "Timeout before stream could collect {} item(s). Got {} item(s).", + $take, + items.len() + ) + .as_str(), ) - .as_str(), - ) { + { items.push(item); if items.len() == $take { break items; @@ -80,11 +83,86 @@ macro_rules! collect_stream { } }}; ($stream:expr, timeout=$timeout:expr $(,)?) => {{ - use futures::StreamExt; use tokio::time; + let mut stream = &mut $stream; + let mut items = Vec::new(); + while let Some(item) = time::timeout($timeout, futures::stream::StreamExt::next($stream)) + .await + .expect(format!("Timeout before stream was closed. Got {} items.", items.len()).as_str()) + { + items.push(item); + } + items + }}; +} + +#[macro_export] +macro_rules! collect_recv { + ($stream:expr, take=$take:expr, timeout=$timeout:expr $(,)?) => {{ + use tokio::time; + + // Evaluate $stream once, NOT in the loop 🐛🚨 + let mut stream = &mut $stream; + + let mut items = Vec::new(); + loop { + let item = time::timeout($timeout, stream.recv()).await.expect(&format!( + "Timeout before stream could collect {} item(s). Got {} item(s).", + $take, + items.len() + )); + + items.push(item.expect(&format!("{}/{} recv ended early", items.len(), $take))); + if items.len() == $take { + break items; + } + } + }}; + ($stream:expr, timeout=$timeout:expr $(,)?) => {{ + use tokio::time; + + let mut stream = &mut $stream; + + let mut items = Vec::new(); + while let Some(item) = time::timeout($timeout, stream.recv()) + .await + .expect(format!("Timeout before stream was closed. Got {} items.", items.len()).as_str()) + { + items.push(item); + } + items + }}; +} + +#[macro_export] +macro_rules! collect_try_recv { + ($stream:expr, take=$take:expr, timeout=$timeout:expr $(,)?) => {{ + use tokio::time; + + // Evaluate $stream once, NOT in the loop 🐛🚨 + let mut stream = &mut $stream; + let mut items = Vec::new(); - while let Some(item) = time::timeout($timeout, $stream.next()) + loop { + let item = time::timeout($timeout, stream.recv()).await.expect(&format!( + "Timeout before stream could collect {} item(s). Got {} item(s).", + $take, + items.len() + )); + + items.push(item.expect(&format!("{}/{} recv returned unexpected result", items.len(), $take))); + if items.len() == $take { + break items; + } + } + }}; + ($stream:expr, timeout=$timeout:expr $(,)?) => {{ + use tokio::time; + + let mut stream = &mut $stream; + let mut items = Vec::new(); + while let Ok(item) = time::timeout($timeout, stream.recv()) .await .expect(format!("Timeout before stream was closed. Got {} items.", items.len()).as_str()) { @@ -102,9 +180,9 @@ macro_rules! collect_stream { /// # use std::time::Duration; /// # use tari_test_utils::collect_stream_count; /// -/// let mut rt = Runtime::new().unwrap(); +/// let rt = Runtime::new().unwrap(); /// let mut stream = stream::iter(vec![1,2,2,3,2]); -/// assert_eq!(rt.block_on(async { collect_stream_count!(stream, timeout=Duration::from_secs(1)) }).get(&2), Some(&3)); +/// assert_eq!(rt.block_on(async { collect_stream_count!(&mut stream, timeout=Duration::from_secs(1)) }).get(&2), Some(&3)); /// ``` #[macro_export] macro_rules! collect_stream_count { @@ -139,3 +217,56 @@ where } } } + +pub async fn assert_in_mpsc(rx: &mut mpsc::Receiver, mut predicate: P, timeout: Duration) -> R +where P: FnMut(T) -> Option { + loop { + if let Some(item) = tokio::time::timeout(timeout, rx.recv()) + .await + .expect("Timeout before stream emitted") + { + if let Some(r) = (predicate)(item) { + break r; + } + } else { + panic!("Predicate did not return true before the mpsc stream ended"); + } + } +} + +pub async fn assert_in_broadcast(rx: &mut broadcast::Receiver, mut predicate: P, timeout: Duration) -> R +where + P: FnMut(T) -> Option, + T: Clone, +{ + loop { + if let Ok(item) = tokio::time::timeout(timeout, rx.recv()) + .await + .expect("Timeout before stream emitted") + { + if let Some(r) = (predicate)(item) { + break r; + } + } else { + panic!("Predicate did not return true before the broadcast channel ended"); + } + } +} + +pub fn convert_mpsc_to_stream(rx: &mut mpsc::Receiver) -> impl Stream + '_ { + stream::unfold(rx, |rx| async move { rx.recv().await.map(|t| (t, rx)) }) +} + +pub fn convert_unbounded_mpsc_to_stream(rx: &mut mpsc::UnboundedReceiver) -> impl Stream + '_ { + stream::unfold(rx, |rx| async move { rx.recv().await.map(|t| (t, rx)) }) +} + +pub fn convert_broadcast_to_stream<'a, T, S>(rx: S) -> impl Stream> + 'a +where + T: Clone + Send + 'static, + S: BorrowMut> + 'a, +{ + stream::unfold(rx, |mut rx| async move { + Some(rx.borrow_mut().recv().await).map(|t| (t, rx)) + }) +} diff --git a/integration_tests/features/Mempool.feature b/integration_tests/features/Mempool.feature index 36113cd9b9..a77d837a08 100644 --- a/integration_tests/features/Mempool.feature +++ b/integration_tests/features/Mempool.feature @@ -86,6 +86,7 @@ Feature: Mempool # Collects 7 coinbases into one wallet, send 7 transactions # Stronger chain # + Given I do not expect all automated transactions to succeed Given I have a seed node SEED_A And I have a base node NODE_A1 connected to seed SEED_A And I have wallet WALLET_A1 connected to seed node SEED_A @@ -198,4 +199,4 @@ Feature: Mempool When I submit transaction TX1 to BN1 Then I wait until base node BN1 has 1 unconfirmed transactions in its mempool When I mine 1 blocks on BN1 - Then I wait until base node BN1 has 0 unconfirmed transactions in its mempool \ No newline at end of file + Then I wait until base node BN1 has 0 unconfirmed transactions in its mempool diff --git a/integration_tests/features/Reorgs.feature b/integration_tests/features/Reorgs.feature index 008df39e8e..d249f05015 100644 --- a/integration_tests/features/Reorgs.feature +++ b/integration_tests/features/Reorgs.feature @@ -96,6 +96,7 @@ Feature: Reorgs @critical @reorg Scenario: Zero-conf reorg with spending + Given I do not expect all automated transactions to succeed Given I have a base node NODE1 connected to all seed nodes Given I have a base node NODE2 connected to node NODE1 When I mine 14 blocks on NODE1 @@ -142,6 +143,7 @@ Feature: Reorgs # Chain 1a: # Mine X1 blocks (orphan_storage_capacity default set to 10) # + Given I do not expect all automated transactions to succeed Given I have a seed node SEED_A1 # Add multiple base nodes to ensure more robust comms And I have a base node NODE_A1 connected to seed SEED_A1 diff --git a/integration_tests/features/StressTest.feature b/integration_tests/features/StressTest.feature index 0425f45644..1e975b03f5 100644 --- a/integration_tests/features/StressTest.feature +++ b/integration_tests/features/StressTest.feature @@ -12,18 +12,18 @@ Feature: Stress Test And I have stress-test wallet WALLET_B connected to the seed node NODE2 with broadcast monitoring timeout # There need to be at least as many mature coinbase UTXOs in the wallet coin splits required for the number of transactions When I merge mine blocks via PROXY - Then all nodes are at current tip height + Then all nodes are on the same chain tip When I wait for wallet WALLET_A to have at least 5100000000 uT Then I coin split tari in wallet WALLET_A to produce UTXOs of 5000 uT each with fee_per_gram 20 uT When I merge mine 3 blocks via PROXY When I merge mine blocks via PROXY - Then all nodes are at current tip height + Then all nodes are on the same chain tip Then wallet WALLET_A detects all transactions as Mined_Confirmed When I send transactions of 1111 uT each from wallet WALLET_A to wallet WALLET_B at fee_per_gram 20 # Mine enough blocks for the first block of transactions to be confirmed. When I merge mine 4 blocks via PROXY - Then all nodes are at current tip height + Then all nodes are on the same chain tip # Now wait until all transactions are detected as confirmed in WALLET_A, continue to mine blocks if transactions # are not found to be confirmed as sometimes the previous mining occurs faster than transactions are submitted # to the mempool diff --git a/integration_tests/features/WalletFFI.feature b/integration_tests/features/WalletFFI.feature index 11b905051e..9432989270 100644 --- a/integration_tests/features/WalletFFI.feature +++ b/integration_tests/features/WalletFFI.feature @@ -1,129 +1,136 @@ @wallet-ffi Feature: Wallet FFI + # Increase heap memory available to nodejs if frequent crashing occurs with + # error being be similar to this: `0x1a32cd5 V8_Fatal(char const*, ...)` - Scenario: As a client I want to send Tari to a Public Key - # It's a subtest of "As a client I want to retrieve a list of transactions I have made and received" - - Scenario: As a client I want to specify a custom fee when I send tari - # It's a subtest of "As a client I want to retrieve a list of transactions I have made and received" - - Scenario: As a client I want to receive Tari via my Public Key while I am online - # It's a subtest of "As a client I want to retrieve a list of transactions I have made and received" + # It's just calling the encrypt function, we don't test if it's actually encrypted + Scenario: As a client I want to be able to protect my wallet with a passphrase + Given I have a base node BASE + And I have a ffi wallet FFI_WALLET connected to base node BASE + And I set passphrase PASSPHRASE of ffi wallet FFI_WALLET + And I stop ffi wallet FFI_WALLET - @long-running @broken - Scenario: As a client I want to receive Tari via my Public Key sent while I am offline when I come back online + Scenario: As a client I want to see my whoami info Given I have a base node BASE - And I have wallet SENDER connected to base node BASE - And I have mining node MINER connected to base node BASE and wallet SENDER - And mining node MINER mines 4 blocks - Then I wait for wallet SENDER to have at least 1000000 uT And I have a ffi wallet FFI_WALLET connected to base node BASE - And I stop wallet FFI_WALLET + Then I want to get public key of ffi wallet FFI_WALLET + And I want to get emoji id of ffi wallet FFI_WALLET + And I stop ffi wallet FFI_WALLET + + Scenario: As a client I want to be able to restore my ffi wallet from seed words + Given I have a base node BASE + And I have wallet SPECTATOR connected to base node BASE + And I have mining node MINER connected to base node BASE and wallet SPECTATOR + And mining node MINER mines 10 blocks + Then I wait for wallet SPECTATOR to have at least 1000000 uT + Then I recover wallet SPECTATOR into ffi wallet FFI_WALLET from seed words on node BASE + And I wait for ffi wallet FFI_WALLET to have at least 1000000 uT + And I stop ffi wallet FFI_WALLET + + Scenario: As a client I want to set the base node + Given I have a base node BASE1 + Given I have a base node BASE2 + And I have a ffi wallet FFI_WALLET connected to base node BASE1 + And I set base node BASE2 for ffi wallet FFI_WALLET + And I stop ffi wallet FFI_WALLET + And I stop node BASE1 And I wait 5 seconds - And I send 2000000 uT from wallet SENDER to wallet FFI_WALLET at fee 100 + And I restart ffi wallet FFI_WALLET + # Possibly check SAF messages, no way to get current connected base node peer from the library itself afaik + # Good idea just to add a fn to do this to the library. + # Then I wait for ffi wallet FFI_WALLET to receive 1 SAF message And I wait 5 seconds - And I start wallet FFI_WALLET - And wallet SENDER detects all transactions are at least Broadcast - And mining node MINER mines 10 blocks - Then I wait for ffi wallet FFI_WALLET to have at least 1000000 uT + And I stop ffi wallet FFI_WALLET - @long-running - Scenario: As a client I want to retrieve a list of transactions I have made and received + Scenario: As a client I want to cancel a transaction Given I have a base node BASE And I have wallet SENDER connected to base node BASE And I have mining node MINER connected to base node BASE and wallet SENDER - And mining node MINER mines 4 blocks + And mining node MINER mines 10 blocks Then I wait for wallet SENDER to have at least 1000000 uT And I have a ffi wallet FFI_WALLET connected to base node BASE And I send 2000000 uT from wallet SENDER to wallet FFI_WALLET at fee 100 And wallet SENDER detects all transactions are at least Broadcast And mining node MINER mines 10 blocks Then I wait for ffi wallet FFI_WALLET to have at least 1000000 uT - And Check callbacks for finished inbound tx on ffi wallet FFI_WALLET And I have wallet RECEIVER connected to base node BASE + And I stop wallet RECEIVER And I send 1000000 uT from ffi wallet FFI_WALLET to wallet RECEIVER at fee 100 - And ffi wallet FFI_WALLET has 1 broadcast transaction - And mining node MINER mines 4 blocks - Then I wait for wallet RECEIVER to have at least 1000000 uT - And Check callbacks for finished outbound tx on ffi wallet FFI_WALLET - And I have 1 received and 1 send transaction in ffi wallet FFI_WALLET - And I start STXO validation on wallet FFI_WALLET - And I start UTXO validation on wallet FFI_WALLET - - # It's just calling the encrypt function, we don't test if it's actually encrypted - Scenario: As a client I want to be able to protect my wallet with a passphrase - Given I have a base node BASE - And I have a ffi wallet FFI_WALLET connected to base node BASE - And I set passphrase PASSPHRASE of ffi wallet FFI_WALLET + Then I wait for ffi wallet FFI_WALLET to have 1 pending outbound transaction + Then I cancel all outbound transactions on ffi wallet FFI_WALLET and it will cancel 1 transaction + And I stop ffi wallet FFI_WALLET Scenario: As a client I want to manage contacts Given I have a base node BASE And I have a ffi wallet FFI_WALLET connected to base node BASE And I have wallet WALLET connected to base node BASE + And I wait 5 seconds And I add contact with alias ALIAS and pubkey WALLET to ffi wallet FFI_WALLET Then I have contact with alias ALIAS and pubkey WALLET in ffi wallet FFI_WALLET When I remove contact with alias ALIAS from ffi wallet FFI_WALLET Then I don't have contact with alias ALIAS in ffi wallet FFI_WALLET + And I stop ffi wallet FFI_WALLET - Scenario: As a client I want to set the base node (should be persisted) - Given I have a base node BASE1 - Given I have a base node BASE2 - And I have a ffi wallet FFI_WALLET connected to base node BASE1 - And I set base node BASE2 for ffi wallet FFI_WALLET - Then BASE2 is connected to FFI_WALLET - And I stop wallet FFI_WALLET - And I wait 5 seconds - And I start wallet FFI_WALLET - Then BASE2 is connected to FFI_WALLET - - Scenario: As a client I want to see my public_key, emoji ID, address (whoami) - Given I have a base node BASE - And I have a ffi wallet FFI_WALLET connected to base node BASE - Then I want to get public key of ffi wallet FFI_WALLET - And I want to get emoji id of ffi wallet FFI_WALLET - - Scenario: As a client I want to get my balance - # It's a subtest of "As a client I want to retrieve a list of transactions I have made and received" - - @long-running - Scenario: As a client I want to cancel a transaction + Scenario: As a client I want to retrieve a list of transactions I have made and received Given I have a base node BASE And I have wallet SENDER connected to base node BASE And I have mining node MINER connected to base node BASE and wallet SENDER - And mining node MINER mines 4 blocks + And mining node MINER mines 10 blocks Then I wait for wallet SENDER to have at least 1000000 uT And I have a ffi wallet FFI_WALLET connected to base node BASE And I send 2000000 uT from wallet SENDER to wallet FFI_WALLET at fee 100 - And wallet SENDER detects all transactions are at least Broadcast And mining node MINER mines 10 blocks Then I wait for ffi wallet FFI_WALLET to have at least 1000000 uT And I have wallet RECEIVER connected to base node BASE - And I stop wallet RECEIVER And I send 1000000 uT from ffi wallet FFI_WALLET to wallet RECEIVER at fee 100 - Then I wait for ffi wallet FFI_WALLET to have 1 pending outbound transaction - Then I cancel all transactions on ffi wallet FFI_WALLET and it will cancel 1 transaction + And mining node MINER mines 10 blocks + Then I wait for wallet RECEIVER to have at least 1000000 uT + And I have 1 received and 1 send transaction in ffi wallet FFI_WALLET + And I start STXO validation on ffi wallet FFI_WALLET + And I start UTXO validation on ffi wallet FFI_WALLET + And I stop ffi wallet FFI_WALLET - @long-running - Scenario: As a client I want to be able to restore my wallet from seed words + Scenario: As a client I want to receive Tari via my Public Key sent while I am offline when I come back online Given I have a base node BASE - And I have wallet WALLET connected to base node BASE - And I have mining node MINER connected to base node BASE and wallet WALLET - And mining node MINER mines 4 blocks - Then I wait for wallet WALLET to have at least 1000000 uT - Then I recover wallet WALLET into ffi wallet FFI_WALLET from seed words on node BASE - And I wait for recovery of wallet FFI_WALLET to finish - And I wait for ffi wallet FFI_WALLET to have at least 1000000 uT + And I have wallet SENDER connected to base node BASE + And I have mining node MINER connected to base node BASE and wallet SENDER + And mining node MINER mines 10 blocks + Then I wait for wallet SENDER to have at least 1000000 uT + And I have a ffi wallet FFI_WALLET connected to base node BASE + And I stop ffi wallet FFI_WALLET + And I wait 10 seconds + And I send 2000000 uT from wallet SENDER to wallet FFI_WALLET at fee 100 + And I wait 5 seconds + And I restart ffi wallet FFI_WALLET + Then I wait for ffi wallet FFI_WALLET to receive 1 transaction + Then I wait for ffi wallet FFI_WALLET to receive 1 finalization + # Assume tx will be mined to reduce time taken for test, balance is tested in later scenarios. + # And mining node MINER mines 10 blocks + # Then I wait for ffi wallet FFI_WALLET to have at least 1000000 uT + And I stop ffi wallet FFI_WALLET + + # Scenario: As a client I want to get my balance + # It's a subtest of "As a client I want to retrieve a list of transactions I have made and received" + + #Scenario: As a client I want to send Tari to a Public Key + # It's a subtest of "As a client I want to retrieve a list of transactions I have made and received" - Scenario: As a client I want to be able to initiate TXO and TX validation with the specifed base node. + #Scenario: As a client I want to specify a custom fee when I send tari # It's a subtest of "As a client I want to retrieve a list of transactions I have made and received" - Scenario: As a client I want async feedback about the progress of sending and receiving a transaction + #Scenario: As a client I want to receive Tari via my Public Key while I am online # It's a subtest of "As a client I want to retrieve a list of transactions I have made and received" - Scenario: As a client I want async feedback about my connection status to the specifed Base Node + # Scenario: As a client I want to be able to initiate TXO and TX validation with the specifed base node. + # It's a subtest of "As a client I want to retrieve a list of transactions I have made and received" + + # Scenario: As a client I want feedback about the progress of sending and receiving a transaction + # It's a subtest of "As a client I want to retrieve a list of transactions I have made and received" - Scenario: As a client I want async feedback about the wallet restoration process + # Scenario: As a client I want feedback about my connection status to the specifed Base Node + + # Scenario: As a client I want feedback about the wallet restoration process # As a client I want to be able to restore my wallet from seed words - Scenario: As a client I want async feedback about TXO and TX validation processes -# It's a subtest of "As a client I want to retrieve a list of transactions I have made and received" + # Scenario: As a client I want feedback about TXO and TX validation processes + # It's a subtest of "As a client I want to retrieve a list of transactions I have made and received" diff --git a/integration_tests/features/support/steps.js b/integration_tests/features/support/steps.js index fcedbaeeff..ef93992b5e 100644 --- a/integration_tests/features/support/steps.js +++ b/integration_tests/features/support/steps.js @@ -37,6 +37,22 @@ Given("I have {int} seed nodes", { timeout: 20 * 1000 }, async function (n) { await Promise.all(promises); }); +Given( + /I do not expect all automated transactions to succeed/, + { timeout: 20 * 1000 }, + async function () { + this.checkAutoTransactions = false; + } +); + +Given( + /I expect all automated transactions to succeed/, + { timeout: 20 * 1000 }, + async function () { + this.checkAutoTransactions = true; + } +); + Given( /I have a base node (.*) connected to all seed nodes/, { timeout: 20 * 1000 }, @@ -942,21 +958,6 @@ Then( } ); -Then( - "all nodes are at current tip height", - { timeout: 1200 * 1000 }, - async function () { - const height = parseInt(this.tipHeight); - console.log("Wait for all nodes to reach height of", height); - await this.forEachClientAsync(async (client, name) => { - await waitFor(async () => client.getTipHeight(), height, 1200 * 1000); - const currTip = await client.getTipHeight(); - console.log(`Node ${name} is at tip: ${currTip} (expected ${height})`); - expect(currTip).to.equal(height); - }); - } -); - Then( /all nodes are at the same height as node (.*)/, { timeout: 1200 * 1000 }, @@ -1106,20 +1107,24 @@ When(/I spend outputs (.*) via (.*)/, async function (inputs, node) { expect(this.lastResult.result).to.equal("ACCEPTED"); }); -Then(/(.*) has (.*) in (.*) state/, async function (node, txn, pool) { - const client = this.getClient(node); - const sig = this.transactions[txn].body.kernels[0].excess_sig; - await waitFor( - async () => await client.transactionStateResult(sig), - pool, - 1200 * 1000 - ); - this.lastResult = await this.getClient(node).transactionState( - this.transactions[txn].body.kernels[0].excess_sig - ); - console.log(`Node ${node} response is: ${this.lastResult.result}`); - expect(this.lastResult.result).to.equal(pool); -}); +Then( + /(.*) has (.*) in (.*) state/, + { timeout: 21 * 60 * 1000 }, + async function (node, txn, pool) { + const client = this.getClient(node); + const sig = this.transactions[txn].body.kernels[0].excess_sig; + await waitForPredicate( + async () => (await client.transactionStateResult(sig)) === pool, + 20 * 60 * 1000, + 1000 + ); + this.lastResult = await this.getClient(node).transactionState( + this.transactions[txn].body.kernels[0].excess_sig + ); + console.log(`Node ${node} response is: ${this.lastResult.result}`); + expect(this.lastResult.result).to.equal(pool); + } +); // The number is rounded down. E.g. if 1% can fail out of 17, that is 16.83 have to succeed. // It's means at least 16 have to succeed. @@ -1186,11 +1191,16 @@ When( /I mine a block on (.*) with coinbase (.*)/, { timeout: 600 * 1000 }, async function (name, coinbaseName) { + const tipHeight = await this.getClient(name).getTipHeight(); + let autoTransactionResult = await this.createTransactions( + name, + tipHeight + 1 + ); + expect(autoTransactionResult).to.equal(true); await this.mineBlock(name, 0, (candidate) => { this.addOutput(coinbaseName, candidate.originalTemplate.coinbase); return candidate; }); - this.tipHeight += 1; } ); @@ -1198,11 +1208,25 @@ When( /I mine (\d+) custom weight blocks on (.*) with weight (\d+)/, { timeout: -1 }, async function (numBlocks, name, weight) { + const tipHeight = await this.getClient(name).getTipHeight(); for (let i = 0; i < numBlocks; i++) { + let autoTransactionResult = await this.createTransactions( + name, + tipHeight + i + 1 + ); + expect(autoTransactionResult).to.equal(true); // If a block cannot be mined quickly enough (or the process has frozen), timeout. - await withTimeout(60 * 1000, this.mineBlock(name, parseInt(weight))); + await withTimeout( + 60 * 1000, + this.mineBlock(name, parseInt(weight), (candidate) => { + this.addTransactionOutput( + tipHeight + i + 1 + 2, + candidate.originalTemplate.coinbase + ); + return candidate; + }) + ); } - this.tipHeight += parseInt(numBlocks); } ); @@ -1244,10 +1268,24 @@ When( /I mine (\d+) blocks on (.*)/, { timeout: -1 }, async function (numBlocks, name) { + const tipHeight = await this.getClient(name).getTipHeight(); for (let i = 0; i < numBlocks; i++) { - await withTimeout(60 * 1000, this.mineBlock(name, 0)); + let autoTransactionResult = await this.createTransactions( + name, + tipHeight + i + 1 + ); + expect(autoTransactionResult).to.equal(true); + await withTimeout( + 60 * 1000, + this.mineBlock(name, 0, (candidate) => { + this.addTransactionOutput( + tipHeight + i + 1 + 2, + candidate.originalTemplate.coinbase + ); + return candidate; + }) + ); } - this.tipHeight += parseInt(numBlocks); } ); @@ -1257,7 +1295,13 @@ When( async function (numBlocks, walletName, nodeName) { const nodeClient = this.getClient(nodeName); const walletClient = await this.getWallet(walletName).connectClient(); + const tipHeight = await this.getClient(nodeName).getTipHeight(); for (let i = 0; i < numBlocks; i++) { + let autoTransactionResult = await this.createTransactions( + nodeName, + tipHeight + 1 + i + ); + expect(autoTransactionResult).to.equal(true); await nodeClient.mineBlock(walletClient); } } @@ -1270,7 +1314,6 @@ When( for (let i = 0; i < numBlocks; i++) { await this.mergeMineBlock(mmProxy); } - this.tipHeight += parseInt(numBlocks); } ); @@ -1282,7 +1325,8 @@ When( /I co-mine (.*) blocks via merge mining proxy (.*) and base node (.*) with wallet (.*)/, { timeout: 1200 * 1000 }, async function (numBlocks, mmProxy, node, wallet) { - this.lastResult = this.tipHeight; + let tipHeight = await this.getClient(node).getTipHeight(); + this.lastResult = tipHeight; const baseNodeMiningPromise = await this.baseNodeMineBlocksUntilHeightIncreasedBy( node, @@ -1295,13 +1339,13 @@ When( ); await Promise.all([baseNodeMiningPromise, mergeMiningPromise]).then( ([res1, res2]) => { - this.tipHeight = Math.max(res1, res2); - this.lastResult = this.tipHeight - this.lastResult; + tipHeight = Math.max(res1, res2); + this.lastResult = tipHeight - this.lastResult; console.log( "Co-mining", numBlocks, "blocks concluded, tip at", - this.tipHeight + tipHeight ); } ); @@ -1312,7 +1356,6 @@ When( /I co-mine (.*) blocks via merge mining proxy (.*) and mining node (.*)/, { timeout: 6000 * 1000 }, async function (numBlocks, mmProxy, miner) { - this.lastResult = this.tipHeight; const sha3MiningPromise = this.sha3MineBlocksUntilHeightIncreasedBy( miner, numBlocks, @@ -1324,13 +1367,14 @@ When( ); await Promise.all([sha3MiningPromise, mergeMiningPromise]).then( ([res1, res2]) => { - this.tipHeight = Math.max(res1, res2); - this.lastResult = this.tipHeight - this.lastResult; console.log( "Co-mining", numBlocks, - "blocks concluded, tip at", - this.tipHeight + "blocks concluded, tips at [", + res1, + ",", + res2, + "]" ); } ); @@ -1340,10 +1384,20 @@ When( When( /I mine but do not submit a block (.*) on (.*)/, async function (blockName, nodeName) { + const tipHeight = await this.getClient(nodeName).getTipHeight(); + let autoTransactionResult = await this.createTransactions( + nodeName, + tipHeight + 1 + ); + expect(autoTransactionResult).to.equal(true); await this.mineBlock( nodeName, null, (block) => { + this.addTransactionOutput( + tipHeight + 2, + block.originalTemplate.coinbase + ); this.saveBlock(blockName, block); return false; }, @@ -1362,7 +1416,15 @@ When( const client = this.getClient(node); const template = client.getPreviousBlockTemplate(atHeight); const candidate = await client.getMinedCandidateBlock(0, template); - + let autoTransactionResult = await this.createTransactions( + node, + parseInt(atHeight) + ); + expect(autoTransactionResult).to.equal(true); + this.addTransactionOutput( + parseInt(atHeight) + 1, + candidate.originalTemplate.coinbase + ); await client.submitBlock( candidate.template, (block) => { @@ -2577,8 +2639,13 @@ Then( if (await walletClient.isTransactionMinedConfirmed(txIds[i])) { return true; } else { + const tipHeight = await this.getClient(nodeName).getTipHeight(); + let autoTransactionResult = await this.createTransactions( + nodeName, + tipHeight + 1 + ); + expect(autoTransactionResult).to.equal(true); await nodeClient.mineBlock(walletClient); - this.tipHeight += 1; return false; } }, @@ -2632,7 +2699,6 @@ Then( return true; } else { await this.mergeMineBlock(mmProxy); - this.tipHeight += 1; return false; } }, @@ -3379,29 +3445,6 @@ When( } ); -When( - "I have a ffi wallet {word} connected to base node {word}", - { timeout: 20 * 1000 }, - async function (name, node) { - let wallet = await this.createAndAddFFIWallet(name); - let peer = this.nodes[node].peerAddress().split("::"); - await wallet.addBaseNodePeer(peer[0], peer[1]); - } -); - -Then( - "I want to get public key of ffi wallet {word}", - { timeout: 20 * 1000 }, - async function (name) { - let wallet = this.getWallet(name); - let public_key = await wallet.getPublicKey(); - expect(public_key.length).to.be.equal( - 64, - `Public key has wrong length : ${public_key}` - ); - } -); - Then( /I wait until base node (.*) has (.*) unconfirmed transactions in its mempool/, { timeout: 180 * 1000 }, @@ -3429,57 +3472,120 @@ Then( ); Then( - "I want to get emoji id of ffi wallet {word}", + /node (.*) lists heights (\d+) to (\d+)/, + async function (node, first, last) { + const client = this.getClient(node); + const start = first; + const end = last; + let heights = []; + + for (let i = start; i <= end; i++) { + heights.push(i); + } + const blocks = await client.getBlocks(heights); + const results = blocks.map((result) => + parseInt(result.block.header.height) + ); + let i = 0; // for ordering check + for (let height = start; height <= end; height++) { + expect(results[i]).equal(height); + i++; + } + } +); + +Then( + "I wait for recovery of wallet {word} to finish", + { timeout: 600 * 1000 }, + async function (wallet_name) { + const wallet = this.getWallet(wallet_name); + while (wallet.recoveryInProgress) { + await sleep(1000); + } + expect(wallet.recoveryProgress[1]).to.be.greaterThan(0); + expect(wallet.recoveryProgress[0]).to.be.equal(wallet.recoveryProgress[1]); + } +); + +When( + "I have {int} base nodes with pruning horizon {int} force syncing on node {word}", + { timeout: 190 * 1000 }, + async function (nodes_count, horizon, force_sync_to) { + const promises = []; + const force_sync_address = this.getNode(force_sync_to).peerAddress(); + for (let i = 0; i < nodes_count; i++) { + const base_node = this.createNode(`BaseNode${i}`, { + pruningHorizon: horizon, + }); + base_node.setPeerSeeds([force_sync_address]); + base_node.setForceSyncPeers([force_sync_address]); + promises.push( + base_node.startNew().then(() => this.addNode(`BaseNode${i}`, base_node)) + ); + } + await Promise.all(promises); + } +); + +//region FFI +When( + "I have ffi wallet {word} connected to base node {word}", { timeout: 20 * 1000 }, - async function (name) { + async function (name, node) { + let wallet = await this.createAndAddFFIWallet(name); + let peer = this.nodes[node].peerAddress().split("::"); + wallet.addBaseNodePeer(peer[0], peer[1]); + } +); + +Then( + "I want to get public key of ffi wallet {word}", + { timeout: 20 * 1000 }, + function (name) { let wallet = this.getWallet(name); - let emoji_id = await wallet.getEmojiId(); - expect(emoji_id.length).to.be.equal( - 22 * 3, // 22 emojis, 3 bytes per one emoji - `Emoji id has wrong length : ${emoji_id}` + let public_key = wallet.identify(); + expect(public_key.length).to.be.equal( + 64, + `Public key has wrong length : ${public_key}` ); } ); Then( - "I wait for ffi wallet {word} to have at least {int} uT", - { timeout: 60 * 1000 }, - async function (name, amount) { + "I want to get emoji id of ffi wallet {word}", + { timeout: 20 * 1000 }, + async function (name) { let wallet = this.getWallet(name); - let retries = 1; - let balance = 0; - const retries_limit = 12; - while (retries <= retries_limit) { - balance = await wallet.getBalance(); - if (balance >= amount) { - break; - } - await sleep(5000); - ++retries; - } - expect(balance, "Balance is not enough").to.be.greaterThanOrEqual(amount); + let emoji_id = wallet.identifyEmoji(); + console.log(emoji_id); + expect(emoji_id.length).to.be.equal( + 22 * 3, // 22 emojis, 3 bytes per one emoji + `Emoji id has wrong length : ${emoji_id}` + ); } ); When( "I send {int} uT from ffi wallet {word} to wallet {word} at fee {int}", { timeout: 20 * 1000 }, - async function (amount, sender, receiver, fee) { - await this.getWallet(sender).sendTransaction( - await this.getWalletPubkey(receiver), + function (amount, sender, receiver, fee) { + let ffi_wallet = this.getWallet(sender); + let result = ffi_wallet.sendTransaction( + this.getWalletPubkey(receiver), amount, fee, `Send from ffi ${sender} to ${receiver} at fee ${fee}` ); + console.log(result); } ); When( "I set passphrase {word} of ffi wallet {word}", { timeout: 20 * 1000 }, - async function (passphrase, name) { + function (passphrase, name) { let wallet = this.getWallet(name); - await wallet.applyEncryption(passphrase); + wallet.applyEncryption(passphrase); } ); @@ -3488,17 +3594,29 @@ Then( { timeout: 120 * 1000 }, async function (received, send, name) { let wallet = this.getWallet(name); - let [outbound, inbound] = await wallet.getCompletedTransactions(); - let retries = 1; - const retries_limit = 23; - while ( - (inbound != received || outbound != send) && - retries <= retries_limit - ) { - await sleep(5000); - [outbound, inbound] = await wallet.getCompletedTransactions(); - ++retries; + let completed = wallet.getCompletedTxs(); + let inbound = 0; + let outbound = 0; + let length = completed.getLength(); + let inboundTxs = wallet.getInboundTxs(); + inbound += inboundTxs.getLength(); + inboundTxs.destroy(); + let outboundTxs = wallet.getOutboundTxs(); + outbound += outboundTxs.getLength(); + outboundTxs.destroy(); + for (let i = 0; i < length; i++) { + { + let tx = completed.getAt(i); + if (tx.isOutbound()) { + outbound++; + } else { + inbound++; + } + tx.destroy(); + } } + completed.destroy(); + expect(outbound, "Outbound transaction count mismatch").to.be.equal(send); expect(inbound, "Inbound transaction count mismatch").to.be.equal(received); } @@ -3526,70 +3644,86 @@ Then( When( "I add contact with alias {word} and pubkey {word} to ffi wallet {word}", { timeout: 20 * 1000 }, - async function (alias, wallet_name, ffi_wallet_name) { + function (alias, wallet_name, ffi_wallet_name) { let ffi_wallet = this.getWallet(ffi_wallet_name); - await ffi_wallet.addContact(alias, await this.getWalletPubkey(wallet_name)); + ffi_wallet.addContact(alias, this.getWalletPubkey(wallet_name)); } ); Then( "I have contact with alias {word} and pubkey {word} in ffi wallet {word}", { timeout: 20 * 1000 }, - async function (alias, wallet_name, ffi_wallet_name) { + function (alias, wallet_name, ffi_wallet_name) { + let wallet = this.getWalletPubkey(wallet_name); let ffi_wallet = this.getWallet(ffi_wallet_name); - expect(await this.getWalletPubkey(wallet_name)).to.be.equal( - await ffi_wallet.getContact(alias) - ); + let contacts = ffi_wallet.getContactList(); + let length = contacts.getLength(); + let found = false; + for (let i = 0; i < length; i++) { + { + let contact = contacts.getAt(i); + let hex = contact.getPubkeyHex(); + if (wallet === hex) { + found = true; + } + contact.destroy(); + } + } + contacts.destroy(); + expect(found).to.be.equal(true); } ); When( "I remove contact with alias {word} from ffi wallet {word}", { timeout: 20 * 1000 }, - async function (alias, walllet_name) { - let wallet = this.getWallet(walllet_name); - await wallet.removeContact(alias); + function (alias, wallet_name) { + let ffi_wallet = this.getWallet(wallet_name); + let contacts = ffi_wallet.getContactList(); + let length = contacts.getLength(); + for (let i = 0; i < length; i++) { + { + let contact = contacts.getAt(i); + let calias = contact.getAlias(); + if (alias === calias) { + ffi_wallet.removeContact(contact); + } + contact.destroy(); + } + } + contacts.destroy(); } ); Then( "I don't have contact with alias {word} in ffi wallet {word}", { timeout: 20 * 1000 }, - async function (alias, wallet_name) { - let wallet = this.getWallet(wallet_name); - expect(await wallet.getContact("alias")).to.be.undefined; - } -); - -Then( - /node (.*) lists heights (\d+) to (\d+)/, - async function (node, first, last) { - const client = this.getClient(node); - const start = first; - const end = last; - let heights = []; - - for (let i = start; i <= end; i++) { - heights.push(i); - } - const blocks = await client.getBlocks(heights); - const results = blocks.map((result) => - parseInt(result.block.header.height) - ); - let i = 0; // for ordering check - for (let height = start; height <= end; height++) { - expect(results[i]).equal(height); - i++; + function (alias, wallet_name) { + let ffi_wallet = this.getWallet(wallet_name); + let contacts = ffi_wallet.getContactList(); + let length = contacts.getLength(); + let found = false; + for (let i = 0; i < length; i++) { + { + let contact = contacts.getAt(i); + let calias = contact.getAlias(); + if (alias === calias) { + found = true; + } + contact.destroy(); + } } + contacts.destroy(); + expect(found).to.be.equal(false); } ); When( "I set base node {word} for ffi wallet {word}", - async function (node, wallet_name) { + function (node, wallet_name) { let wallet = this.getWallet(wallet_name); let peer = this.nodes[node].peerAddress().split("::"); - await wallet.addBaseNodePeer(peer[0], peer[1]); + wallet.addBaseNodePeer(peer[0], peer[1]); } ); @@ -3598,26 +3732,48 @@ Then( { timeout: 120 * 1000 }, async function (wallet_name, count) { let wallet = this.getWallet(wallet_name); - let broadcast = await wallet.getOutboundTransactionsCount(); + let broadcast = wallet.getOutboundTransactions(); + let length = broadcast.getLength(); + broadcast.destroy(); let retries = 1; const retries_limit = 24; - while (broadcast != count && retries <= retries_limit) { + while (length != count && retries <= retries_limit) { await sleep(5000); - broadcast = await wallet.getOutboundTransactionsCount(); + broadcast = wallet.getOutboundTransactions(); + length = broadcast.getLength(); + broadcast.destroy(); ++retries; } - expect(broadcast, "Number of pending messages mismatch").to.be.equal(count); + expect(length, "Number of pending messages mismatch").to.be.equal(count); } ); Then( - "I cancel all transactions on ffi wallet {word} and it will cancel {int} transaction", + "I cancel all outbound transactions on ffi wallet {word} and it will cancel {int} transaction", async function (wallet_name, count) { const wallet = this.getWallet(wallet_name); - expect( - await wallet.cancelAllOutboundTransactions(), - "Number of cancelled transactions" - ).to.be.equal(count); + let txs = wallet.getOutboundTransactions(); + let cancelled = 0; + for (let i = 0; i < txs.getLength(); i++) { + let tx = txs.getAt(i); + let cancellation = wallet.cancelPendingTransaction(tx.getTransactionID()); + tx.destroy(); + if (cancellation) { + cancelled++; + } + } + txs.destroy(); + expect(cancelled).to.be.equal(count); + } +); + +Given( + /I have a ffi wallet (.*) connected to base node (.*)/, + { timeout: 20 * 1000 }, + async function (walletName, nodeName) { + let ffi_wallet = await this.createAndAddFFIWallet(walletName, null); + let peer = this.nodes[nodeName].peerAddress().split("::"); + ffi_wallet.addBaseNodePeer(peer[0], peer[1]); } ); @@ -3634,78 +3790,207 @@ Then( seed_words_text ); let peer = this.nodes[node].peerAddress().split("::"); - await ffi_wallet.addBaseNodePeer(peer[0], peer[1]); - await ffi_wallet.startRecovery(peer[0]); + ffi_wallet.addBaseNodePeer(peer[0], peer[1]); + ffi_wallet.startRecovery(peer[0]); } ); Then( - "I wait for recovery of wallet {word} to finish", - { timeout: 600 * 1000 }, + "Check callbacks for finished inbound tx on ffi wallet {word}", async function (wallet_name) { const wallet = this.getWallet(wallet_name); - while (wallet.recoveryInProgress) { - await sleep(1000); + expect(wallet.receivedTransaction).to.be.greaterThanOrEqual(1); + expect(wallet.transactionBroadcast).to.be.greaterThanOrEqual(1); + wallet.clearCallbackCounters(); + } +); + +Then( + "Check callbacks for finished outbound tx on ffi wallet {word}", + async function (wallet_name) { + const wallet = this.getWallet(wallet_name); + expect(wallet.receivedTransactionReply).to.be.greaterThanOrEqual(1); + expect(wallet.transactionBroadcast).to.be.greaterThanOrEqual(1); + wallet.clearCallbackCounters(); + } +); + +Then( + /I wait for ffi wallet (.*) to receive (.*) transaction/, + { timeout: 710 * 1000 }, + async function (wallet_name, amount) { + let wallet = this.getWallet(wallet_name); + + console.log("\n"); + console.log( + "Waiting for " + wallet_name + " to receive " + amount + " transaction(s)" + ); + + await waitFor( + async () => { + return wallet.getCounters().received >= amount; + }, + true, + 700 * 1000, + 5 * 1000, + 5 + ); + + if (!(wallet.getCounters().received >= amount)) { + console.log("Counter not adequate!"); + } else { + console.log(wallet.getCounters()); } - expect(wallet.recoveryProgress[1]).to.be.greaterThan(0); - expect(wallet.recoveryProgress[0]).to.be.equal(wallet.recoveryProgress[1]); + expect(wallet.getCounters().received >= amount).to.equal(true); } ); -Then("I start STXO validation on wallet {word}", async function (wallet_name) { - const wallet = this.getWallet(wallet_name); - await wallet.startStxoValidation(); - while (!wallet.stxo_validation_complete) { - await sleep(1000); +Then( + /I wait for ffi wallet (.*) to receive (.*) finalization/, + { timeout: 710 * 1000 }, + async function (wallet_name, amount) { + let wallet = this.getWallet(wallet_name); + + console.log("\n"); + console.log( + "Waiting for " + + wallet_name + + " to receive " + + amount + + " transaction finalization(s)" + ); + + await waitFor( + async () => { + return wallet.getCounters().finalized >= amount; + }, + true, + 700 * 1000, + 5 * 1000, + 5 + ); + + if (!(wallet.getCounters().finalized >= amount)) { + console.log("Counter not adequate!"); + } else { + console.log(wallet.getCounters()); + } + expect(wallet.getCounters().finalized >= amount).to.equal(true); } - expect(wallet.stxo_validation_result).to.be.equal(0); -}); +); -Then("I start UTXO validation on wallet {word}", async function (wallet_name) { - const wallet = this.getWallet(wallet_name); - await wallet.startUtxoValidation(); - while (!wallet.utxo_validation_complete) { - await sleep(1000); +Then( + /I wait for ffi wallet (.*) to receive (.*) SAF message/, + { timeout: 710 * 1000 }, + async function (wallet_name, amount) { + let wallet = this.getWallet(wallet_name); + + console.log("\n"); + console.log( + "Waiting for " + + wallet_name + + " to receive " + + amount + + " SAF messages(s)" + ); + + await waitFor( + async () => { + return wallet.getCounters().saf >= amount; + }, + true, + 700 * 1000, + 5 * 1000, + 5 + ); + + if (!(wallet.getCounters().saf >= amount)) { + console.log("Counter not adequate!"); + } else { + console.log(wallet.getCounters()); + } + expect(wallet.getCounters().saf >= amount).to.equal(true); } - expect(wallet.utxo_validation_result).to.be.equal(0); -}); +); Then( - "Check callbacks for finished inbound tx on ffi wallet {word}", - async function (wallet_name) { + /I wait for ffi wallet (.*) to have at least (.*) uT/, + { timeout: 710 * 1000 }, + async function (wallet_name, amount) { + let wallet = this.getWallet(wallet_name); + + console.log("\n"); + console.log( + "Waiting for " + wallet_name + " balance to be at least " + amount + " uT" + ); + + let count = 0; + + while (!(wallet.getBalance().available >= amount)) { + await sleep(1000); + count++; + if (count > 700) { + break; + } + } + + let balance = wallet.getBalance().available; + + if (!(balance >= amount)) { + console.log("Balance not adequate!"); + } else { + console.log(wallet.getBalance()); + } + expect(balance >= amount).to.equal(true); + } +); + +Then( + "I wait for recovery of ffi wallet {word} to finish", + { timeout: 600 * 1000 }, + function (wallet_name) { const wallet = this.getWallet(wallet_name); - expect(wallet.receivedTransaction).to.be.greaterThanOrEqual(1); - expect(wallet.transactionBroadcast).to.be.greaterThanOrEqual(1); - wallet.clearCallbackCounters(); + while (!wallet.recoveryFinished) { + sleep(1000).then(); + } } ); +When(/I start ffi wallet (.*)/, async function (walletName) { + let wallet = this.getWallet(walletName); + await wallet.startNew(null, null); +}); + +When(/I restart ffi wallet (.*)/, async function (walletName) { + let wallet = this.getWallet(walletName); + await wallet.restart(); +}); + +When(/I stop ffi wallet (.*)/, function (walletName) { + let wallet = this.getWallet(walletName); + wallet.stop(); + wallet.resetCounters(); +}); + Then( - "Check callbacks for finished outbound tx on ffi wallet {word}", + "I start STXO validation on ffi wallet {word}", async function (wallet_name) { const wallet = this.getWallet(wallet_name); - expect(wallet.receivedTransactionReply).to.be.greaterThanOrEqual(1); - expect(wallet.transactionBroadcast).to.be.greaterThanOrEqual(1); - wallet.clearCallbackCounters(); + await wallet.startStxoValidation(); + while (!wallet.getStxoValidationStatus().stxo_validation_complete) { + await sleep(1000); + } } ); -When( - "I have {int} base nodes with pruning horizon {int} force syncing on node {word}", - { timeout: 190 * 1000 }, - async function (nodes_count, horizon, force_sync_to) { - const promises = []; - const force_sync_address = this.getNode(force_sync_to).peerAddress(); - for (let i = 0; i < nodes_count; i++) { - const base_node = this.createNode(`BaseNode${i}`, { - pruningHorizon: horizon, - }); - base_node.setPeerSeeds([force_sync_address]); - base_node.setForceSyncPeers([force_sync_address]); - promises.push( - base_node.startNew().then(() => this.addNode(`BaseNode${i}`, base_node)) - ); +Then( + "I start UTXO validation on ffi wallet {word}", + async function (wallet_name) { + const wallet = this.getWallet(wallet_name); + await wallet.startUtxoValidation(); + while (!wallet.getUtxoValidationStatus().utxo_validation_complete) { + await sleep(1000); } - await Promise.all(promises); } ); +//endregion diff --git a/integration_tests/features/support/world.js b/integration_tests/features/support/world.js index 6ca3d0f699..91e8b45f16 100644 --- a/integration_tests/features/support/world.js +++ b/integration_tests/features/support/world.js @@ -5,6 +5,7 @@ const MergeMiningProxyProcess = require("../../helpers/mergeMiningProxyProcess") const WalletProcess = require("../../helpers/walletProcess"); const WalletFFIClient = require("../../helpers/walletFFIClient"); const MiningNodeProcess = require("../../helpers/miningNodeProcess"); +const TransactionBuilder = require("../../helpers/transactionBuilder"); const glob = require("glob"); const fs = require("fs"); const archiver = require("archiver"); @@ -12,7 +13,7 @@ class CustomWorld { constructor({ attach, parameters }) { // this.variable = 0; this.attach = attach; - + this.checkAutoTransactions = true; this.seeds = {}; this.nodes = {}; this.proxies = {}; @@ -23,6 +24,7 @@ class CustomWorld { this.clients = {}; this.headers = {}; this.outputs = {}; + this.transactionOutputs = {}; this.testrun = `run${Date.now()}`; this.lastResult = null; this.blocks = {}; @@ -30,7 +32,6 @@ class CustomWorld { this.peers = {}; this.transactionsMap = new Map(); this.resultStack = []; - this.tipHeight = 0; this.logFilePathBaseNode = parameters.logFilePathBaseNode || "./log4rs/base_node.yml"; this.logFilePathProxy = parameters.logFilePathProxy || "./log4rs/proxy.yml"; @@ -106,11 +107,11 @@ class CustomWorld { this.walletPubkeys[name] = walletInfo.public_key; } - async createAndAddFFIWallet(name, seed_words) { + async createAndAddFFIWallet(name, seed_words = null, passphrase = null) { const wallet = new WalletFFIClient(name); - await wallet.startNew(seed_words); + await wallet.startNew(seed_words, passphrase); this.walletsFFI[name] = wallet; - this.walletPubkeys[name] = await wallet.getPublicKey(); + this.walletPubkeys[name] = wallet.identify(); return wallet; } @@ -126,6 +127,47 @@ class CustomWorld { this.outputs[name] = output; } + addTransactionOutput(spendHeight, output) { + if (this.transactionOutputs[spendHeight] == null) { + this.transactionOutputs[spendHeight] = [output]; + } else { + this.transactionOutputs[spendHeight].push(output); + } + } + + async createTransactions(name, height) { + let result = true; + const txInputs = this.transactionOutputs[height]; + if (txInputs == null) { + return result; + } + let i = 0; + for (const input of txInputs) { + const txn = new TransactionBuilder(); + txn.addInput(input); + const txOutput = txn.addOutput(txn.getSpendableAmount()); + this.addTransactionOutput(height + 1, txOutput); + const completedTx = txn.build(); + const submitResult = await this.getClient(name).submitTransaction( + completedTx + ); + if (this.checkAutoTransactions && submitResult.result != "ACCEPTED") { + result = false; + } + if (submitResult.result == "ACCEPTED") { + i++; + } + if (i > 9) { + //this is to make sure the blocks stay relatively empty so that the tests don't take too long + break; + } + } + console.log( + `Created ${i} transactions for node: ${name} at height: ${height}` + ); + return result; + } + async mineBlock(name, weight, beforeSubmit, onError) { await this.clients[name].mineBlockWithoutWallet( beforeSubmit, diff --git a/integration_tests/helpers/ffi/byteVector.js b/integration_tests/helpers/ffi/byteVector.js index 51f5d338bd..245cb4320e 100644 --- a/integration_tests/helpers/ffi/byteVector.js +++ b/integration_tests/helpers/ffi/byteVector.js @@ -1,28 +1,51 @@ -const WalletFFI = require("./walletFFI"); +const InterfaceFFI = require("./ffiInterface"); class ByteVector { #byte_vector_ptr; - constructor(byte_vector_ptr) { - this.#byte_vector_ptr = byte_vector_ptr; + pointerAssign(ptr) { + // Prevent pointer from being leaked in case of re-assignment + if (this.#byte_vector_ptr) { + this.destroy(); + this.#byte_vector_ptr = ptr; + } else { + this.#byte_vector_ptr = ptr; + } } - static async fromBuffer(buffer) { - let buf = Buffer.from(buffer, "utf-8"); // get the bytes + fromBytes(input) { + let buf = Buffer.from(input, "utf-8"); // ensure encoding is utf=8, js default is utf-16 let len = buf.length; // get the length - return new ByteVector(await WalletFFI.byteVectorCreate(buf, len)); + let result = new ByteVector(); + result.pointerAssign(InterfaceFFI.byteVectorCreate(buf, len)); + return result; + } + + getBytes() { + let result = []; + for (let i = 0; i < this.getLength(); i++) { + result.push(this.getAt(i)); + } + return result; } getLength() { - return WalletFFI.byteVectorGetLength(this.#byte_vector_ptr); + return InterfaceFFI.byteVectorGetLength(this.#byte_vector_ptr); } getAt(position) { - return WalletFFI.byteVectorGetAt(this.#byte_vector_ptr, position); + return InterfaceFFI.byteVectorGetAt(this.#byte_vector_ptr, position); + } + + getPtr() { + return this.#byte_vector_ptr; } destroy() { - return WalletFFI.byteVectorDestroy(this.#byte_vector_ptr); + if (this.#byte_vector_ptr) { + InterfaceFFI.byteVectorDestroy(this.#byte_vector_ptr); + this.#byte_vector_ptr = undefined; //prevent double free segfault + } } } diff --git a/integration_tests/helpers/ffi/commsConfig.js b/integration_tests/helpers/ffi/commsConfig.js new file mode 100644 index 0000000000..9bb9ddcb7a --- /dev/null +++ b/integration_tests/helpers/ffi/commsConfig.js @@ -0,0 +1,43 @@ +const InterfaceFFI = require("./ffiInterface"); +const utf8 = require("utf8"); + +class CommsConfig { + #comms_config_ptr; + + constructor( + public_address, + transport_ptr, + database_name, + datastore_path, + discovery_timeout_in_secs, + saf_message_duration_in_secs, + network + ) { + let sanitize_address = utf8.encode(public_address); + let sanitize_db_name = utf8.encode(database_name); + let sanitize_db_path = utf8.encode(datastore_path); + let sanitize_network = utf8.encode(network); + this.#comms_config_ptr = InterfaceFFI.commsConfigCreate( + sanitize_address, + transport_ptr, + sanitize_db_name, + sanitize_db_path, + discovery_timeout_in_secs, + saf_message_duration_in_secs, + sanitize_network + ); + } + + getPtr() { + return this.#comms_config_ptr; + } + + destroy() { + if (this.#comms_config_ptr) { + InterfaceFFI.commsConfigDestroy(this.#comms_config_ptr); + this.#comms_config_ptr = undefined; //prevent double free segfault + } + } +} + +module.exports = CommsConfig; diff --git a/integration_tests/helpers/ffi/completedTransaction.js b/integration_tests/helpers/ffi/completedTransaction.js index a7a21c28cd..cc23f22ecf 100644 --- a/integration_tests/helpers/ffi/completedTransaction.js +++ b/integration_tests/helpers/ffi/completedTransaction.js @@ -1,23 +1,104 @@ -const WalletFFI = require("./walletFFI"); +const InterfaceFFI = require("./ffiInterface"); +const PublicKey = require("./publicKey"); class CompletedTransaction { #tari_completed_transaction_ptr; - constructor(tari_completed_transaction_ptr) { - this.#tari_completed_transaction_ptr = tari_completed_transaction_ptr; + pointerAssign(ptr) { + if (this.#tari_completed_transaction_ptr) { + this.destroy(); + this.#tari_completed_transaction_ptr = ptr; + } else { + this.#tari_completed_transaction_ptr = ptr; + } + } + + getPtr() { + return this.#tari_completed_transaction_ptr; } isOutbound() { - return WalletFFI.completedTransactionIsOutbound( + return InterfaceFFI.completedTransactionIsOutbound( this.#tari_completed_transaction_ptr ); } - destroy() { - return WalletFFI.completedTransactionDestroy( + getDestinationPublicKey() { + let result = new PublicKey(); + result.pointerAssign( + InterfaceFFI.completedTransactionGetDestinationPublicKey( + this.#tari_completed_transaction_ptr + ) + ); + return result; + } + + getSourcePublicKey() { + let result = new PublicKey(); + result.pointerAssign( + InterfaceFFI.completedTransactionGetSourcePublicKey( + this.#tari_completed_transaction_ptr + ) + ); + return result; + } + + getAmount() { + return InterfaceFFI.completedTransactionGetAmount( + this.#tari_completed_transaction_ptr + ); + } + + getFee() { + return InterfaceFFI.completedTransactionGetFee( + this.#tari_completed_transaction_ptr + ); + } + + getMessage() { + return InterfaceFFI.completedTransactionGetMessage( + this.#tari_completed_transaction_ptr + ); + } + + getStatus() { + return InterfaceFFI.completedTransactionGetStatus( + this.#tari_completed_transaction_ptr + ); + } + + getTransactionID() { + return InterfaceFFI.completedTransactionGetTransactionId( + this.#tari_completed_transaction_ptr + ); + } + + getTimestamp() { + return InterfaceFFI.completedTransactionGetTimestamp( + this.#tari_completed_transaction_ptr + ); + } + + isValid() { + return InterfaceFFI.completedTransactionIsValid( + this.#tari_completed_transaction_ptr + ); + } + + getConfirmations() { + return InterfaceFFI.completedTransactionGetConfirmations( this.#tari_completed_transaction_ptr ); } + + destroy() { + if (this.#tari_completed_transaction_ptr) { + InterfaceFFI.completedTransactionDestroy( + this.#tari_completed_transaction_ptr + ); + this.#tari_completed_transaction_ptr = undefined; //prevent double free segfault + } + } } module.exports = CompletedTransaction; diff --git a/integration_tests/helpers/ffi/completedTransactions.js b/integration_tests/helpers/ffi/completedTransactions.js index d2d4c96156..2b8387bb72 100644 --- a/integration_tests/helpers/ffi/completedTransactions.js +++ b/integration_tests/helpers/ffi/completedTransactions.js @@ -1,38 +1,37 @@ const CompletedTransaction = require("./completedTransaction"); -const WalletFFI = require("./walletFFI"); +const InterfaceFFI = require("./ffiInterface"); class CompletedTransactions { #tari_completed_transactions_ptr; - constructor(tari_completed_transactions_ptr) { - this.#tari_completed_transactions_ptr = tari_completed_transactions_ptr; - } - - static async fromWallet(wallet) { - return new CompletedTransactions( - await WalletFFI.walletGetCompletedTransactions(wallet) - ); + constructor(ptr) { + this.#tari_completed_transactions_ptr = ptr; } getLength() { - return WalletFFI.completedTransactionsGetLength( + return InterfaceFFI.completedTransactionsGetLength( this.#tari_completed_transactions_ptr ); } - async getAt(position) { - return new CompletedTransaction( - await WalletFFI.completedTransactionsGetAt( + getAt(position) { + let result = new CompletedTransaction(); + result.pointerAssign( + InterfaceFFI.completedTransactionsGetAt( this.#tari_completed_transactions_ptr, position ) ); + return result; } destroy() { - return WalletFFI.completedTransactionsDestroy( - this.#tari_completed_transactions_ptr - ); + if (this.#tari_completed_transactions_ptr) { + InterfaceFFI.completedTransactionsDestroy( + this.#tari_completed_transactions_ptr + ); + this.#tari_completed_transactions_ptr = undefined; //prevent double free segfault + } } } diff --git a/integration_tests/helpers/ffi/contact.js b/integration_tests/helpers/ffi/contact.js index 184c684a2b..ea72376e75 100644 --- a/integration_tests/helpers/ffi/contact.js +++ b/integration_tests/helpers/ffi/contact.js @@ -1,32 +1,52 @@ const PublicKey = require("./publicKey"); -const WalletFFI = require("./walletFFI"); +const InterfaceFFI = require("./ffiInterface"); class Contact { #tari_contact_ptr; - constructor(tari_contact_ptr) { - this.#tari_contact_ptr = tari_contact_ptr; + pointerAssign(ptr) { + // Prevent pointer from being leaked in case of re-assignment + if (this.#tari_contact_ptr) { + this.destroy(); + this.#tari_contact_ptr = ptr; + } else { + this.#tari_contact_ptr = ptr; + } } getPtr() { return this.#tari_contact_ptr; } - async getAlias() { - const alias = await WalletFFI.contactGetAlias(this.#tari_contact_ptr); + getAlias() { + const alias = InterfaceFFI.contactGetAlias(this.#tari_contact_ptr); const result = alias.readCString(); - await WalletFFI.stringDestroy(alias); + InterfaceFFI.stringDestroy(alias); return result; } - async getPubkey() { - return new PublicKey( - await WalletFFI.contactGetPublicKey(this.#tari_contact_ptr) + getPubkey() { + let result = new PublicKey(); + result.pointerAssign( + InterfaceFFI.contactGetPublicKey(this.#tari_contact_ptr) ); + return result; + } + + getPubkeyHex() { + let result = ""; + let pk = new PublicKey(); + pk.pointerAssign(InterfaceFFI.contactGetPublicKey(this.#tari_contact_ptr)); + result = pk.getHex(); + pk.destroy(); + return result; } destroy() { - return WalletFFI.contactDestroy(this.#tari_contact_ptr); + if (this.#tari_contact_ptr) { + InterfaceFFI.contactDestroy(this.#tari_contact_ptr); + this.#tari_contact_ptr = undefined; //prevent double free segfault + } } } diff --git a/integration_tests/helpers/ffi/contacts.js b/integration_tests/helpers/ffi/contacts.js index d8803874ab..1f7db81fcc 100644 --- a/integration_tests/helpers/ffi/contacts.js +++ b/integration_tests/helpers/ffi/contacts.js @@ -1,29 +1,30 @@ const Contact = require("./contact"); -const WalletFFI = require("./walletFFI"); +const InterfaceFFI = require("./ffiInterface"); class Contacts { #tari_contacts_ptr; - constructor(tari_contacts_ptr) { - this.#tari_contacts_ptr = tari_contacts_ptr; - } - - static async fromWallet(wallet) { - return new Contacts(await WalletFFI.walletGetContacts(wallet)); + constructor(ptr) { + this.#tari_contacts_ptr = ptr; } getLength() { - return WalletFFI.contactsGetLength(this.#tari_contacts_ptr); + return InterfaceFFI.contactsGetLength(this.#tari_contacts_ptr); } - async getAt(position) { - return new Contact( - await WalletFFI.contactsGetAt(this.#tari_contacts_ptr, position) + getAt(position) { + let result = new Contact(); + result.pointerAssign( + InterfaceFFI.contactsGetAt(this.#tari_contacts_ptr, position) ); + return result; } destroy() { - return WalletFFI.contactsDestroy(this.#tari_contacts_ptr); + if (this.#tari_contacts_ptr) { + InterfaceFFI.contactsDestroy(this.#tari_contacts_ptr); + this.#tari_contacts_ptr = undefined; //prevent double free segfault + } } } diff --git a/integration_tests/helpers/ffi/emojiSet.js b/integration_tests/helpers/ffi/emojiSet.js new file mode 100644 index 0000000000..f94e2ae746 --- /dev/null +++ b/integration_tests/helpers/ffi/emojiSet.js @@ -0,0 +1,36 @@ +const InterfaceFFI = require("./ffiInterface"); + +class EmojiSet { + #emoji_set_ptr; + + constructor() { + this.#emoji_set_ptr = InterfaceFFI.getEmojiSet(); + } + + getLength() { + return InterfaceFFI.emojiSetGetLength(this.#emoji_set_ptr); + } + + getAt(position) { + return InterfaceFFI.emojiSetGetAt(this.#emoji_set_ptr, position); + } + + list() { + let set = []; + for (let i = 0; i < this.getLength(); i++) { + let item = this.getAt(i); + set.push(Buffer.from(item.getBytes(), "utf-8").toString()); + item.destroy(); + } + return set; + } + + destroy() { + if (this.#emoji_set_ptr) { + InterfaceFFI.byteVectorDestroy(this.#emoji_set_ptr); + this.#emoji_set_ptr = undefined; //prevent double free segfault + } + } +} + +module.exports = EmojiSet; diff --git a/integration_tests/helpers/ffi/ffiInterface.js b/integration_tests/helpers/ffi/ffiInterface.js new file mode 100644 index 0000000000..e39632cd05 --- /dev/null +++ b/integration_tests/helpers/ffi/ffiInterface.js @@ -0,0 +1,1473 @@ +/** + * This library was AUTO-GENERATED. Do not modify manually! + */ + +const { expect } = require("chai"); +const ffi = require("ffi-napi"); +const ref = require("ref-napi"); +const dateFormat = require("dateformat"); +const { spawn } = require("child_process"); +const fs = require("fs"); + +class InterfaceFFI { + //region Compile + static compile() { + return new Promise((resolve, _reject) => { + const cmd = "cargo"; + const args = [ + "build", + "--release", + "--package", + "tari_wallet_ffi", + "-Z", + "unstable-options", + "--out-dir", + process.cwd() + "/temp/out", + ]; + const baseDir = `./temp/base_nodes/${dateFormat( + new Date(), + "yyyymmddHHMM" + )}/WalletFFI-compile`; + if (!fs.existsSync(baseDir)) { + fs.mkdirSync(baseDir, { recursive: true }); + fs.mkdirSync(baseDir + "/log", { recursive: true }); + } + const ps = spawn(cmd, args, { + cwd: baseDir, + env: { ...process.env }, + }); + ps.on("close", (_code) => { + resolve(ps); + }); + ps.stderr.on("data", (data) => { + console.log("stderr : ", data.toString()); + }); + ps.on("error", (error) => { + console.log("error : ", error.toString()); + }); + expect(ps.error).to.be.an("undefined"); + this.#ps = ps; + }); + } + //endregion + + //region Interface + static #fn; + + static #loaded = false; + static #ps = null; + + static async Init() { + if (this.#loaded) { + return; + } + + this.#loaded = true; + await this.compile(); + const outputProcess = `${process.cwd()}/temp/out/${ + process.platform === "win32" ? "" : "lib" + }tari_wallet_ffi`; + + // Load the library + this.#fn = ffi.Library(outputProcess, { + transport_memory_create: ["pointer", ["void"]], + transport_tcp_create: ["pointer", ["string", "int*"]], + transport_tor_create: [ + "pointer", + ["string", "pointer", "ushort", "string", "string", "int*"], + ], + transport_memory_get_address: ["char*", ["pointer", "int*"]], + transport_type_destroy: ["void", ["pointer"]], + string_destroy: ["void", ["string"]], + byte_vector_create: ["pointer", ["uchar*", "uint", "int*"]], + byte_vector_get_at: ["uchar", ["pointer", "uint", "int*"]], + byte_vector_get_length: ["uint", ["pointer", "int*"]], + byte_vector_destroy: ["void", ["pointer"]], + public_key_create: ["pointer", ["pointer", "int*"]], + public_key_get_bytes: ["pointer", ["pointer", "int*"]], + public_key_from_private_key: ["pointer", ["pointer", "int*"]], + public_key_from_hex: ["pointer", ["string", "int*"]], + public_key_destroy: ["void", ["pointer"]], + public_key_to_emoji_id: ["char*", ["pointer", "int*"]], + emoji_id_to_public_key: ["pointer", ["string", "int*"]], + private_key_create: ["pointer", ["pointer", "int*"]], + private_key_generate: ["pointer", ["void"]], + private_key_get_bytes: ["pointer", ["pointer", "int*"]], + private_key_from_hex: ["pointer", ["string", "int*"]], + private_key_destroy: ["void", ["pointer"]], + seed_words_create: ["pointer", ["void"]], + seed_words_get_length: ["uint", ["pointer", "int*"]], + seed_words_get_at: ["char*", ["pointer", "uint", "int*"]], + seed_words_push_word: ["uchar", ["pointer", "string", "int*"]], + seed_words_destroy: ["void", ["pointer"]], + contact_create: ["pointer", ["string", "pointer", "int*"]], + contact_get_alias: ["char*", ["pointer", "int*"]], + contact_get_public_key: ["pointer", ["pointer", "int*"]], + contact_destroy: ["void", ["pointer"]], + contacts_get_length: ["uint", ["pointer", "int*"]], + contacts_get_at: ["pointer", ["pointer", "uint", "int*"]], + contacts_destroy: ["void", ["pointer"]], + completed_transaction_get_destination_public_key: [ + "pointer", + ["pointer", "int*"], + ], + completed_transaction_get_source_public_key: [ + "pointer", + ["pointer", "int*"], + ], + completed_transaction_get_amount: ["uint64", ["pointer", "int*"]], + completed_transaction_get_fee: ["uint64", ["pointer", "int*"]], + completed_transaction_get_message: ["char*", ["pointer", "int*"]], + completed_transaction_get_status: ["int", ["pointer", "int*"]], + completed_transaction_get_transaction_id: ["uint64", ["pointer", "int*"]], + completed_transaction_get_timestamp: ["uint64", ["pointer", "int*"]], + completed_transaction_is_valid: ["bool", ["pointer", "int*"]], + completed_transaction_is_outbound: ["bool", ["pointer", "int*"]], + completed_transaction_get_confirmations: ["uint64", ["pointer", "int*"]], + completed_transaction_destroy: ["void", ["pointer"]], + //completed_transaction_get_excess: [ + //this.tari_excess_ptr, + // [this.tari_completed_transaction_ptr, "int*"], + //], + //completed_transaction_get_public_nonce: [ + // this.tari_excess_public_nonce_ptr, + // [this.tari_completed_transaction_ptr, "int*"], + //], + //completed_transaction_get_signature: [ + // this.tari_excess_signature_ptr, + // [this.tari_completed_transaction_ptr, "int*"], + //], + // excess_destroy: ["void", [this.tari_excess_ptr]], + // nonce_destroy: ["void", [this.tari_excess_public_nonce_ptr]], + // signature_destroy: ["void", [this.tari_excess_signature_ptr]], + completed_transactions_get_length: ["uint", ["pointer", "int*"]], + completed_transactions_get_at: ["pointer", ["pointer", "uint", "int*"]], + completed_transactions_destroy: ["void", ["pointer"]], + pending_outbound_transaction_get_transaction_id: [ + "uint64", + ["pointer", "int*"], + ], + pending_outbound_transaction_get_destination_public_key: [ + "pointer", + ["pointer", "int*"], + ], + pending_outbound_transaction_get_amount: ["uint64", ["pointer", "int*"]], + pending_outbound_transaction_get_fee: ["uint64", ["pointer", "int*"]], + pending_outbound_transaction_get_message: ["char*", ["pointer", "int*"]], + pending_outbound_transaction_get_timestamp: [ + "uint64", + ["pointer", "int*"], + ], + pending_outbound_transaction_get_status: ["int", ["pointer", "int*"]], + pending_outbound_transaction_destroy: ["void", ["pointer"]], + pending_outbound_transactions_get_length: ["uint", ["pointer", "int*"]], + pending_outbound_transactions_get_at: [ + "pointer", + ["pointer", "uint", "int*"], + ], + pending_outbound_transactions_destroy: ["void", ["pointer"]], + pending_inbound_transaction_get_transaction_id: [ + "uint64", + ["pointer", "int*"], + ], + pending_inbound_transaction_get_source_public_key: [ + "pointer", + ["pointer", "int*"], + ], + pending_inbound_transaction_get_message: ["char*", ["pointer", "int*"]], + pending_inbound_transaction_get_amount: ["uint64", ["pointer", "int*"]], + pending_inbound_transaction_get_timestamp: [ + "uint64", + ["pointer", "int*"], + ], + pending_inbound_transaction_get_status: ["int", ["pointer", "int*"]], + pending_inbound_transaction_destroy: ["void", ["pointer"]], + pending_inbound_transactions_get_length: ["uint", ["pointer", "int*"]], + pending_inbound_transactions_get_at: [ + "pointer", + ["pointer", "uint", "int*"], + ], + pending_inbound_transactions_destroy: ["void", ["pointer"]], + comms_config_create: [ + "pointer", + [ + "string", + "pointer", + "string", + "string", + "uint64", + "uint64", + "string", + "int*", + ], + ], + comms_config_destroy: ["void", ["pointer"]], + wallet_create: [ + "pointer", + [ + "pointer", + "string", + "uint", + "uint", + "string", + "pointer", + "pointer", + "pointer", + "pointer", + "pointer", + "pointer", + "pointer", + "pointer", + "pointer", + "pointer", + "pointer", + "pointer", + "pointer", + "pointer", + "pointer", + "bool*", + "int*", + ], + ], + wallet_sign_message: ["char*", ["pointer", "string", "int*"]], + wallet_verify_message_signature: [ + "bool", + ["pointer", "pointer", "string", "string", "int*"], + ], + wallet_add_base_node_peer: [ + "bool", + ["pointer", "pointer", "string", "int*"], + ], + wallet_upsert_contact: ["bool", ["pointer", "pointer", "int*"]], + wallet_remove_contact: ["bool", ["pointer", "pointer", "int*"]], + wallet_get_available_balance: ["uint64", ["pointer", "int*"]], + wallet_get_pending_incoming_balance: ["uint64", ["pointer", "int*"]], + wallet_get_pending_outgoing_balance: ["uint64", ["pointer", "int*"]], + wallet_get_fee_estimate: [ + "uint64", + ["pointer", "uint64", "uint64", "uint64", "uint64", "int*"], + ], + wallet_get_num_confirmations_required: ["uint64", ["pointer", "int*"]], + wallet_set_num_confirmations_required: [ + "void", + ["pointer", "uint64", "int*"], + ], + wallet_send_transaction: [ + "uint64", + ["pointer", "pointer", "uint64", "uint64", "string", "int*"], + ], + wallet_get_contacts: ["pointer", ["pointer", "int*"]], + wallet_get_completed_transactions: ["pointer", ["pointer", "int*"]], + wallet_get_pending_outbound_transactions: [ + "pointer", + ["pointer", "int*"], + ], + wallet_get_public_key: ["pointer", ["pointer", "int*"]], + wallet_get_pending_inbound_transactions: ["pointer", ["pointer", "int*"]], + wallet_get_cancelled_transactions: ["pointer", ["pointer", "int*"]], + wallet_get_completed_transaction_by_id: [ + "pointer", + ["pointer", "uint64", "int*"], + ], + wallet_get_pending_outbound_transaction_by_id: [ + "pointer", + ["pointer", "uint64", "int*"], + ], + wallet_get_pending_inbound_transaction_by_id: [ + "pointer", + ["pointer", "uint64", "int*"], + ], + wallet_get_cancelled_transaction_by_id: [ + "pointer", + ["pointer", "uint64", "int*"], + ], + wallet_import_utxo: [ + "uint64", + ["pointer", "uint64", "pointer", "pointer", "string", "int*"], + ], + wallet_start_utxo_validation: ["uint64", ["pointer", "int*"]], + wallet_start_stxo_validation: ["uint64", ["pointer", "int*"]], + wallet_start_invalid_txo_validation: ["uint64", ["pointer", "int*"]], + wallet_start_transaction_validation: ["uint64", ["pointer", "int*"]], + wallet_restart_transaction_broadcast: ["bool", ["pointer", "int*"]], + wallet_set_low_power_mode: ["void", ["pointer", "int*"]], + wallet_set_normal_power_mode: ["void", ["pointer", "int*"]], + wallet_cancel_pending_transaction: [ + "bool", + ["pointer", "uint64", "int*"], + ], + wallet_coin_split: [ + "uint64", + ["pointer", "uint64", "uint64", "uint64", "string", "uint64", "int*"], + ], + wallet_get_seed_words: ["pointer", ["pointer", "int*"]], + wallet_apply_encryption: ["void", ["pointer", "string", "int*"]], + wallet_remove_encryption: ["void", ["pointer", "int*"]], + wallet_set_key_value: ["bool", ["pointer", "string", "string", "int*"]], + wallet_get_value: ["char*", ["pointer", "string", "int*"]], + wallet_clear_value: ["bool", ["pointer", "string", "int*"]], + wallet_is_recovery_in_progress: ["bool", ["pointer", "int*"]], + wallet_start_recovery: [ + "bool", + ["pointer", "pointer", "pointer", "int*"], + ], + wallet_destroy: ["void", ["pointer"]], + file_partial_backup: ["void", ["string", "string", "int*"]], + log_debug_message: ["void", ["string"]], + get_emoji_set: ["pointer", ["void"]], + emoji_set_destroy: ["void", ["pointer"]], + emoji_set_get_at: ["pointer", ["pointer", "uint", "int*"]], + emoji_set_get_length: ["uint", ["pointer", "int*"]], + }); + } + //endregion + + static checkErrorResult(error, error_name) { + expect(error.deref()).to.equal(0, `Error in ${error_name}`); + } + + //region Helpers + static initError() { + let error = Buffer.alloc(4); + error.writeInt32LE(-1, 0); + error.type = ref.types.int; + return error; + } + + static initBool() { + let boolean = ref.alloc(ref.types.bool); + return boolean; + } + + static filePartialBackup(original_file_path, backup_file_path) { + let error = this.initError(); + let result = this.#fn.file_partial_backup( + original_file_path, + backup_file_path, + error + ); + this.checkErrorResult(error, `filePartialBackup`); + return result; + } + + static logDebugMessage(msg) { + this.#fn.log_debug_message(msg); + } + //endregion + + //region String + static stringDestroy(s) { + this.#fn.string_destroy(s); + } + //endregion + + // region ByteVector + static byteVectorCreate(byte_array, element_count) { + let error = this.initError(); + let result = this.#fn.byte_vector_create(byte_array, element_count, error); + this.checkErrorResult(error, `byteVectorCreate`); + return result; + } + + static byteVectorGetAt(ptr, i) { + let error = this.initError(); + let result = this.#fn.byte_vector_get_at(ptr, i, error); + this.checkErrorResult(error, `byteVectorGetAt`); + return result; + } + + static byteVectorGetLength(ptr) { + let error = this.initError(); + let result = this.#fn.byte_vector_get_length(ptr, error); + this.checkErrorResult(error, `byteVectorGetLength`); + return result; + } + + static byteVectorDestroy(ptr) { + this.#fn.byte_vector_destroy(ptr); + } + //endregion + + //region PrivateKey + static privateKeyCreate(ptr) { + let error = this.initError(); + let result = this.#fn.private_key_create(ptr, error); + this.checkErrorResult(error, `privateKeyCreate`); + return result; + } + + static privateKeyGenerate() { + return this.#fn.private_key_generate(); + } + + static privateKeyGetBytes(ptr) { + let error = this.initError(); + let result = this.#fn.private_key_get_bytes(ptr, error); + this.checkErrorResult(error, "privateKeyGetBytes"); + return result; + } + + static privateKeyFromHex(hex) { + let error = this.initError(); + let result = this.#fn.private_key_from_hex(hex, error); + this.checkErrorResult(error, "privateKeyFromHex"); + return result; + } + + static privateKeyDestroy(ptr) { + this.#fn.private_key_destroy(ptr); + } + + //endregion + + //region PublicKey + static publicKeyCreate(ptr) { + let error = this.initError(); + let result = this.#fn.public_key_create(ptr, error); + this.checkErrorResult(error, `publicKeyCreate`); + return result; + } + + static publicKeyGetBytes(ptr) { + let error = this.initError(); + let result = this.#fn.public_key_get_bytes(ptr, error); + this.checkErrorResult(error, `publicKeyGetBytes`); + return result; + } + + static publicKeyFromPrivateKey(ptr) { + let error = this.initError(); + let result = this.#fn.public_key_from_private_key(ptr, error); + this.checkErrorResult(error, `publicKeyFromPrivateKey`); + return result; + } + + static publicKeyFromHex(hex) { + let error = this.initError(); + let result = this.#fn.public_key_from_hex(hex, error); + this.checkErrorResult(error, `publicKeyFromHex`); + return result; + } + + static emojiIdToPublicKey(emoji) { + let error = this.initError(); + let result = this.#fn.emoji_id_to_public_key(emoji, error); + this.checkErrorResult(error, `emojiIdToPublicKey`); + return result; + } + + static publicKeyToEmojiId(ptr) { + let error = this.initError(); + let result = this.#fn.public_key_to_emoji_id(ptr, error); + this.checkErrorResult(error, `publicKeyToEmojiId`); + return result; + } + + static publicKeyDestroy(ptr) { + this.#fn.public_key_destroy(ptr); + } + //endregion + + //region TransportType + static transportMemoryCreate() { + return this.#fn.transport_memory_create(); + } + + static transportTcpCreate(listener_address) { + let error = this.initError(); + let result = this.#fn.transport_tcp_create(listener_address, error); + this.checkErrorResult(error, `transportTcpCreate`); + return result; + } + + static transportTorCreate( + control_server_address, + tor_cookie, + tor_port, + socks_username, + socks_password + ) { + let error = this.initError(); + let result = this.#fn.transport_tor_create( + control_server_address, + tor_cookie, + tor_port, + socks_username, + socks_password, + error + ); + this.checkErrorResult(error, `transportTorCreate`); + return result; + } + + static transportMemoryGetAddress(transport) { + let error = this.initError(); + let result = this.#fn.transport_memory_get_address(transport, error); + this.checkErrorResult(error, `transportMemoryGetAddress`); + return result; + } + + static transportTypeDestroy(transport) { + this.#fn.transport_type_destroy(transport); + } + //endregion + + //region EmojiSet + static getEmojiSet() { + return this.#fn.this.#fn.get_emoji_set(); + } + + static emojiSetDestroy(ptr) { + this.#fn.emoji_set_destroy(ptr); + } + + static emojiSetGetAt(ptr, position) { + let error = this.initError(); + let result = this.#fn.emoji_set_get_at(ptr, position, error); + this.checkErrorResult(error, `emojiSetGetAt`); + return result; + } + + static emojiSetGetLength(ptr) { + let error = this.initError(); + let result = this.#fn.emoji_set_get_length(ptr, error); + this.checkErrorResult(error, `emojiSetGetLength`); + return result; + } + //endregion + + //region SeedWords + static seedWordsCreate() { + return this.#fn.seed_words_create(); + } + + static seedWordsGetLength(ptr) { + let error = this.initError(); + let result = this.#fn.seed_words_get_length(ptr, error); + this.checkErrorResult(error, `emojiSetGetLength`); + return result; + } + + static seedWordsGetAt(ptr, position) { + let error = this.initError(); + let result = this.#fn.seed_words_get_at(ptr, position, error); + this.checkErrorResult(error, `seedWordsGetAt`); + return result; + } + + static seedWordsPushWord(ptr, word) { + let error = this.initError(); + let result = this.#fn.seed_words_push_word(ptr, word, error); + this.checkErrorResult(error, `seedWordsPushWord`); + return result; + } + + static seedWordsDestroy(ptr) { + this.#fn.seed_words_destroy(ptr); + } + //endregion + + //region CommsConfig + static commsConfigCreate( + public_address, + transport, + database_name, + datastore_path, + discovery_timeout_in_secs, + saf_message_duration_in_secs, + network + ) { + let error = this.initError(); + let result = this.#fn.comms_config_create( + public_address, + transport, + database_name, + datastore_path, + discovery_timeout_in_secs, + saf_message_duration_in_secs, + network, + error + ); + this.checkErrorResult(error, `commsConfigCreate`); + return result; + } + + static commsConfigDestroy(ptr) { + this.#fn.comms_config_destroy(ptr); + } + //endregion + + //region Contact + static contactCreate(alias, public_key) { + let error = this.initError(); + let result = this.#fn.contact_create(alias, public_key, error); + this.checkErrorResult(error, `contactCreate`); + return result; + } + + static contactGetAlias(ptr) { + let error = this.initError(); + let result = this.#fn.contact_get_alias(ptr, error); + this.checkErrorResult(error, `contactGetAlias`); + return result; + } + + static contactGetPublicKey(ptr) { + let error = this.initError(); + let result = this.#fn.contact_get_public_key(ptr, error); + this.checkErrorResult(error, `contactGetPublicKey`); + return result; + } + + static contactDestroy(ptr) { + this.#fn.contact_destroy(ptr); + } + //endregion + + //region Contacts (List) + static contactsGetLength(ptr) { + let error = this.initError(); + let result = this.#fn.contacts_get_length(ptr, error); + this.checkErrorResult(error, `contactsGetLength`); + return result; + } + + static contactsGetAt(ptr, position) { + let error = this.initError(); + let result = this.#fn.contacts_get_at(ptr, position, error); + this.checkErrorResult(error, `contactsGetAt`); + return result; + } + + static contactsDestroy(ptr) { + this.#fn.contacts_destroy(ptr); + } + //endregion + + //region CompletedTransaction + static completedTransactionGetDestinationPublicKey(ptr) { + let error = this.initError(); + let result = this.#fn.completed_transaction_get_destination_public_key( + ptr, + error + ); + this.checkErrorResult(error, `completedTransactionGetDestinationPublicKey`); + return result; + } + + static completedTransactionGetSourcePublicKey(ptr) { + let error = this.initError(); + let result = this.#fn.completed_transaction_get_source_public_key( + ptr, + error + ); + this.checkErrorResult(error, `completedTransactionGetSourcePublicKey`); + return result; + } + + static completedTransactionGetAmount(ptr) { + let error = this.initError(); + let result = this.#fn.completed_transaction_get_amount(ptr, error); + this.checkErrorResult(error, `completedTransactionGetAmount`); + return result; + } + + static completedTransactionGetFee(ptr) { + let error = this.initError(); + let result = this.#fn.completed_transaction_get_fee(ptr, error); + this.checkErrorResult(error, `completedTransactionGetFee`); + return result; + } + + static completedTransactionGetMessage(ptr) { + let error = this.initError(); + let result = this.#fn.completed_transaction_get_message(ptr, error); + this.checkErrorResult(error, `completedTransactionGetMessage`); + return result; + } + + static completedTransactionGetStatus(ptr) { + let error = this.initError(); + let result = this.#fn.completed_transaction_get_status(ptr, error); + this.checkErrorResult(error, `completedTransactionGetStatus`); + return result; + } + + static completedTransactionGetTransactionId(ptr) { + let error = this.initError(); + let result = this.#fn.completed_transaction_get_transaction_id(ptr, error); + this.checkErrorResult(error, `completedTransactionGetTransactionId`); + return result; + } + + static completedTransactionGetTimestamp(ptr) { + let error = this.initError(); + let result = this.#fn.completed_transaction_get_timestamp(ptr, error); + this.checkErrorResult(error, `completedTransactionGetTimestamp`); + return result; + } + + static completedTransactionIsValid(ptr) { + let error = this.initError(); + let result = this.#fn.completed_transaction_is_valid(ptr, error); + this.checkErrorResult(error, `completedTransactionIsValid`); + return result; + } + + static completedTransactionIsOutbound(ptr) { + let error = this.initError(); + let result = this.#fn.completed_transaction_is_outbound(ptr, error); + this.checkErrorResult(error, `completedTransactionGetConfirmations`); + return result; + } + + static completedTransactionGetConfirmations(ptr) { + let error = this.initError(); + let result = this.#fn.completed_transaction_get_confirmations(ptr, error); + this.checkErrorResult(error, `completedTransactionGetConfirmations`); + return result; + } + + static completedTransactionDestroy(ptr) { + this.#fn.completed_transaction_destroy(ptr); + } + + //endregion + + /* + //Flagged as design flaw in the FFI lib + + static completedTransactionGetExcess(transaction) { + return new Promise((resolve, reject) => + this.#fn.completed_transaction_get_excess.async( + transaction, + this.error, + this.checkAsyncRes(resolve, reject, "completedTransactionGetExcess") + ) + ); + } + + static completedTransactionGetPublicNonce(transaction) { + return new Promise((resolve, reject) => + this.#fn.completed_transaction_get_public_nonce.async( + transaction, + this.error, + this.checkAsyncRes( + resolve, + reject, + "completedTransactionGetPublicNonce" + ) + ) + ); + } + + static completedTransactionGetSignature(transaction) { + return new Promise((resolve, reject) => + this.#fn.completed_transaction_get_signature.async( + transaction, + this.error, + this.checkAsyncRes(resolve, reject, "completedTransactionGetSignature") + ) + ); + } + + static excessDestroy(excess) { + return new Promise((resolve, reject) => + this.#fn.excess_destroy.async( + excess, + this.checkAsyncRes(resolve, reject, "excessDestroy") + ) + ); + } + + static nonceDestroy(nonce) { + return new Promise((resolve, reject) => + this.#fn.nonce_destroy.async( + nonce, + this.checkAsyncRes(resolve, reject, "nonceDestroy") + ) + ); + } + + static signatureDestroy(signature) { + return new Promise((resolve, reject) => + this.#fn.signature_destroy.async( + signature, + this.checkAsyncRes(resolve, reject, "signatureDestroy") + ) + ); + } + */ + + //region CompletedTransactions (List) + static completedTransactionsGetLength(ptr) { + let error = this.initError(); + let result = this.#fn.completed_transactions_get_length(ptr, error); + this.checkErrorResult(error, `contactsGetAt`); + return result; + } + + static completedTransactionsGetAt(ptr, position) { + let error = this.initError(); + let result = this.#fn.completed_transactions_get_at(ptr, position, error); + this.checkErrorResult(error, `contactsGetAt`); + return result; + } + + static completedTransactionsDestroy(transactions) { + this.#fn.completed_transactions_destroy(transactions); + } + //endregion + + //region PendingOutboundTransaction + static pendingOutboundTransactionGetTransactionId(ptr) { + let error = this.initError(); + let result = this.#fn.pending_outbound_transaction_get_transaction_id( + ptr, + error + ); + this.checkErrorResult(error, `pendingOutboundTransactionGetTransactionId`); + return result; + } + + static pendingOutboundTransactionGetDestinationPublicKey(ptr) { + let error = this.initError(); + let result = + this.#fn.pending_outbound_transaction_get_destination_public_key( + ptr, + error + ); + this.checkErrorResult( + error, + `pendingOutboundTransactionGetDestinationPublicKey` + ); + return result; + } + + static pendingOutboundTransactionGetAmount(ptr) { + let error = this.initError(); + let result = this.#fn.pending_outbound_transaction_get_amount(ptr, error); + this.checkErrorResult(error, `pendingOutboundTransactionGetAmount`); + return result; + } + + static pendingOutboundTransactionGetFee(ptr) { + let error = this.initError(); + let result = this.#fn.pending_outbound_transaction_get_fee(ptr, error); + this.checkErrorResult(error, `pendingOutboundTransactionGetFee`); + return result; + } + + static pendingOutboundTransactionGetMessage(ptr) { + let error = this.initError(); + let result = this.#fn.pending_outbound_transaction_get_message(ptr, error); + this.checkErrorResult(error, `pendingOutboundTransactionGetMessage`); + return result; + } + + static pendingOutboundTransactionGetTimestamp(ptr) { + let error = this.initError(); + let result = this.#fn.pending_outbound_transaction_get_timestamp( + ptr, + error + ); + this.checkErrorResult(error, `pendingOutboundTransactionGetTimestamp`); + return result; + } + + static pendingOutboundTransactionGetStatus(ptr) { + let error = this.initError(); + let result = this.#fn.pending_outbound_transaction_get_status(ptr, error); + this.checkErrorResult(error, `pendingOutboundTransactionGetStatus`); + return result; + } + + static pendingOutboundTransactionDestroy(ptr) { + this.#fn.pending_outbound_transaction_destroy(ptr); + } + //endregion + + //region PendingOutboundTransactions (List) + static pendingOutboundTransactionsGetLength(ptr) { + let error = this.initError(); + let result = this.#fn.pending_outbound_transactions_get_length(ptr, error); + this.checkErrorResult(error, `pendingOutboundTransactionsGetLength`); + return result; + } + + static pendingOutboundTransactionsGetAt(ptr, position) { + let error = this.initError(); + let result = this.#fn.pending_outbound_transactions_get_at( + ptr, + position, + error + ); + this.checkErrorResult(error, `pendingOutboundTransactionsGetAt`); + return result; + } + + static pendingOutboundTransactionsDestroy(ptr) { + this.#fn.pending_outbound_transactions_destroy(ptr); + } + //endregion + + //region PendingInboundTransaction + static pendingInboundTransactionGetTransactionId(ptr) { + let error = this.initError(); + let result = this.#fn.pending_inbound_transaction_get_transaction_id( + ptr, + error + ); + this.checkErrorResult(error, `pendingInboundTransactionGetTransactionId`); + return result; + } + + static pendingInboundTransactionGetSourcePublicKey(ptr) { + let error = this.initError(); + let result = this.#fn.pending_inbound_transaction_get_source_public_key( + ptr, + error + ); + this.checkErrorResult(error, `pendingInboundTransactionGetSourcePublicKey`); + return result; + } + + static pendingInboundTransactionGetMessage(ptr) { + let error = this.initError(); + let result = this.#fn.pending_inbound_transaction_get_message(ptr, error); + this.checkErrorResult(error, `pendingInboundTransactionGetMessage`); + return result; + } + + static pendingInboundTransactionGetAmount(ptr) { + let error = this.initError(); + let result = this.#fn.pending_inbound_transaction_get_amount(ptr, error); + this.checkErrorResult(error, `pendingInboundTransactionGetAmount`); + return result; + } + + static pendingInboundTransactionGetTimestamp(ptr) { + let error = this.initError(); + let result = this.#fn.pending_inbound_transaction_get_timestamp(ptr, error); + this.checkErrorResult(error, `pendingInboundTransactionGetTimestamp`); + return result; + } + + static pendingInboundTransactionGetStatus(ptr) { + let error = this.initError(); + let result = this.#fn.pending_inbound_transaction_get_status(ptr, error); + this.checkErrorResult(error, `pendingInboundTransactionGetStatus`); + return result; + } + + static pendingInboundTransactionDestroy(ptr) { + this.#fn.pending_inbound_transaction_destroy(ptr); + } + //endregion + + //region PendingInboundTransactions (List) + static pendingInboundTransactionsGetLength(ptr) { + let error = this.initError(); + let result = this.#fn.pending_inbound_transactions_get_length(ptr, error); + this.checkErrorResult(error, `pendingInboundTransactionsGetLength`); + return result; + } + + static pendingInboundTransactionsGetAt(ptr, position) { + let error = this.initError(); + let result = this.#fn.pending_inbound_transactions_get_at( + ptr, + position, + error + ); + this.checkErrorResult(error, `pendingInboundTransactionsGetAt`); + return result; + } + + static pendingInboundTransactionsDestroy(ptr) { + this.#fn.pending_inbound_transactions_destroy(ptr); + } + //endregion + + //region Wallet + + //region Callbacks + static createCallbackReceivedTransaction(fn) { + return ffi.Callback("void", ["pointer"], fn); + } + + static createCallbackReceivedTransactionReply(fn) { + return ffi.Callback("void", ["pointer"], fn); + } + + static createCallbackReceivedFinalizedTransaction(fn) { + return ffi.Callback("void", ["pointer"], fn); + } + + static createCallbackTransactionBroadcast(fn) { + return ffi.Callback("void", ["pointer"], fn); + } + + static createCallbackTransactionMined(fn) { + return ffi.Callback("void", ["pointer"], fn); + } + + static createCallbackTransactionMinedUnconfirmed(fn) { + return ffi.Callback("void", ["pointer", "uint64"], fn); + } + + static createCallbackDirectSendResult(fn) { + return ffi.Callback("void", ["uint64", "bool"], fn); + } + + static createCallbackStoreAndForwardSendResult(fn) { + return ffi.Callback("void", ["uint64", "bool"], fn); + } + + static createCallbackTransactionCancellation(fn) { + return ffi.Callback("void", ["pointer"], fn); + } + static createCallbackUtxoValidationComplete(fn) { + return ffi.Callback("void", ["uint64", "uchar"], fn); + } + static createCallbackStxoValidationComplete(fn) { + return ffi.Callback("void", ["uint64", "uchar"], fn); + } + static createCallbackInvalidTxoValidationComplete(fn) { + return ffi.Callback("void", ["uint64", "uchar"], fn); + } + static createCallbackTransactionValidationComplete(fn) { + return ffi.Callback("void", ["uint64", "uchar"], fn); + } + static createCallbackSafMessageReceived(fn) { + return ffi.Callback("void", ["void"], fn); + } + static createRecoveryProgressCallback(fn) { + return ffi.Callback("void", ["uchar", "uint64", "uint64"], fn); + } + //endregion + + static walletCreate( + config, + log_path, + num_rolling_log_files, + size_per_log_file_bytes, + passphrase, + seed_words, + callback_received_transaction, + callback_received_transaction_reply, + callback_received_finalized_transaction, + callback_transaction_broadcast, + callback_transaction_mined, + callback_transaction_mined_unconfirmed, + callback_direct_send_result, + callback_store_and_forward_send_result, + callback_transaction_cancellation, + callback_utxo_validation_complete, + callback_stxo_validation_complete, + callback_invalid_txo_validation_complete, + callback_transaction_validation_complete, + callback_saf_message_received + ) { + let error = this.initError(); + let recovery_in_progress = this.initBool(); + + let result = this.#fn.wallet_create( + config, + log_path, + num_rolling_log_files, + size_per_log_file_bytes, + passphrase, + seed_words, + callback_received_transaction, + callback_received_transaction_reply, + callback_received_finalized_transaction, + callback_transaction_broadcast, + callback_transaction_mined, + callback_transaction_mined_unconfirmed, + callback_direct_send_result, + callback_store_and_forward_send_result, + callback_transaction_cancellation, + callback_utxo_validation_complete, + callback_stxo_validation_complete, + callback_invalid_txo_validation_complete, + callback_transaction_validation_complete, + callback_saf_message_received, + recovery_in_progress, + error + ); + this.checkErrorResult(error, `walletCreate`); + if (recovery_in_progress) { + console.log("Wallet recovery is in progress"); + } + return result; + } + + static walletGetPublicKey(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_get_public_key(ptr, error); + this.checkErrorResult(error, `walletGetPublicKey`); + return result; + } + + static walletSignMessage(ptr, msg) { + let error = this.initError(); + let result = this.#fn.wallet_sign_message(ptr, msg, error); + this.checkErrorResult(error, `walletSignMessage`); + return result; + } + + static walletVerifyMessageSignature(ptr, public_key_ptr, hex_sig_nonce, msg) { + let error = this.initError(); + let result = this.#fn.wallet_verify_message_signature( + ptr, + public_key_ptr, + hex_sig_nonce, + msg, + error + ); + this.checkErrorResult(error, `walletVerifyMessageSignature`); + return result; + } + + static walletAddBaseNodePeer(ptr, public_key_ptr, address) { + let error = this.initError(); + let result = this.#fn.wallet_add_base_node_peer( + ptr, + public_key_ptr, + address, + error + ); + this.checkErrorResult(error, `walletAddBaseNodePeer`); + return result; + } + + static walletUpsertContact(ptr, contact_ptr) { + let error = this.initError(); + let result = this.#fn.wallet_upsert_contact(ptr, contact_ptr, error); + this.checkErrorResult(error, `walletUpsertContact`); + return result; + } + + static walletRemoveContact(ptr, contact_ptr) { + let error = this.initError(); + let result = this.#fn.wallet_remove_contact(ptr, contact_ptr, error); + this.checkErrorResult(error, `walletRemoveContact`); + return result; + } + + static walletGetAvailableBalance(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_get_available_balance(ptr, error); + this.checkErrorResult(error, `walletGetAvailableBalance`); + return result; + } + + static walletGetPendingIncomingBalance(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_get_pending_incoming_balance(ptr, error); + this.checkErrorResult(error, `walletGetPendingIncomingBalance`); + return result; + } + + static walletGetPendingOutgoingBalance(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_get_pending_outgoing_balance(ptr, error); + this.checkErrorResult(error, `walletGetPendingOutgoingBalance`); + return result; + } + + static walletGetFeeEstimate( + ptr, + amount, + fee_per_gram, + num_kernels, + num_outputs + ) { + let error = this.initError(); + let result = this.#fn.wallet_get_fee_estimate( + ptr, + amount, + fee_per_gram, + num_kernels, + num_outputs, + error + ); + this.checkErrorResult(error, `walletGetFeeEstimate`); + return result; + } + + static walletGetNumConfirmationsRequired(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_get_num_confirmations_required(ptr, error); + this.checkErrorResult(error, `walletGetNumConfirmationsRequired`); + return result; + } + + static walletSetNumConfirmationsRequired(ptr, num) { + let error = this.initError(); + this.#fn.wallet_set_num_confirmations_required(ptr, num, error); + this.checkErrorResult(error, `walletSetNumConfirmationsRequired`); + } + + static walletSendTransaction( + ptr, + destination, + amount, + fee_per_gram, + message + ) { + let error = this.initError(); + let result = this.#fn.wallet_send_transaction( + ptr, + destination, + amount, + fee_per_gram, + message, + error + ); + this.checkErrorResult(error, `walletSendTransaction`); + return result; + } + + static walletGetContacts(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_get_contacts(ptr, error); + this.checkErrorResult(error, `walletGetContacts`); + return result; + } + + static walletGetCompletedTransactions(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_get_completed_transactions(ptr, error); + this.checkErrorResult(error, `walletGetCompletedTransactions`); + return result; + } + + static walletGetPendingOutboundTransactions(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_get_pending_outbound_transactions(ptr, error); + this.checkErrorResult(error, `walletGetPendingOutboundTransactions`); + return result; + } + + static walletGetPendingInboundTransactions(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_get_pending_inbound_transactions(ptr, error); + this.checkErrorResult(error, `walletGetPendingInboundTransactions`); + return result; + } + + static walletGetCancelledTransactions(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_get_cancelled_transactions(ptr, error); + this.checkErrorResult(error, `walletGetCancelledTransactions`); + return result; + } + + static walletGetCompletedTransactionById(ptr, transaction_id) { + let error = this.initError(); + let result = this.#fn.wallet_get_completed_transaction_by_id( + ptr, + transaction_id, + error + ); + this.checkErrorResult(error, `walletGetCompletedTransactionById`); + return result; + } + + static walletGetPendingOutboundTransactionById(ptr, transaction_id) { + let error = this.initError(); + let result = this.#fn.wallet_get_pending_outbound_transaction_by_id( + ptr, + transaction_id, + error + ); + this.checkErrorResult(error, `walletGetPendingOutboundTransactionById`); + return result; + } + + static walletGetPendingInboundTransactionById(ptr, transaction_id) { + let error = this.initError(); + let result = this.#fn.wallet_get_pending_inbound_transaction_by_id( + ptr, + transaction_id, + error + ); + this.checkErrorResult(error, `walletGetPendingInboundTransactionById`); + return result; + } + + static walletGetCancelledTransactionById(ptr, transaction_id) { + let error = this.initError(); + let result = this.#fn.wallet_get_cancelled_transaction_by_id( + ptr, + transaction_id, + error + ); + this.checkErrorResult(error, `walletGetCancelledTransactionById`); + return result; + } + + static walletImportUtxo( + ptr, + amount, + spending_key_ptr, + source_public_key_ptr, + message + ) { + let error = this.initError(); + let result = this.#fn.wallet_import_utxo( + ptr, + amount, + spending_key_ptr, + source_public_key_ptr, + message, + error + ); + this.checkErrorResult(error, `walletImportUtxo`); + return result; + } + + static walletStartUtxoValidation(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_start_utxo_validation(ptr, error); + this.checkErrorResult(error, `walletStartUtxoValidation`); + return result; + } + + static walletStartStxoValidation(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_start_stxo_validation(ptr, error); + this.checkErrorResult(error, `walletStartStxoValidation`); + return result; + } + + static walletStartInvalidTxoValidation(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_start_invalid_txo_validation(ptr, error); + this.checkErrorResult(error, `walletStartInvalidUtxoValidation`); + return result; + } + + static walletStartTransactionValidation(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_start_transaction_validation(ptr, error); + this.checkErrorResult(error, `walletStartTransactionValidation`); + return result; + } + + static walletRestartTransactionBroadcast(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_restart_transaction_broadcast(ptr, error); + this.checkErrorResult(error, `walletRestartTransactionBroadcast`); + return result; + } + + static walletSetLowPowerMode(ptr) { + let error = this.initError(); + this.#fn.wallet_set_low_power_mode(ptr, error); + this.checkErrorResult(error, `walletSetLowPowerMode`); + } + + static walletSetNormalPowerMode(ptr) { + let error = this.initError(); + this.#fn.wallet_set_normal_power_mode(ptr, error); + this.checkErrorResult(error, `walletSetNormalPowerMode`); + } + + static walletCancelPendingTransaction(ptr, transaction_id) { + let error = this.initError(); + let result = this.#fn.wallet_cancel_pending_transaction( + ptr, + transaction_id, + error + ); + this.checkErrorResult(error, `walletCancelPendingTransaction`); + return result; + } + + static walletCoinSplit(ptr, amount, count, fee, msg, lock_height) { + let error = this.initError(); + let result = this.#fn.wallet_coin_split( + ptr, + amount, + count, + fee, + msg, + lock_height, + error + ); + this.checkErrorResult(error, `walletCoinSplit`); + return result; + } + + static walletGetSeedWords(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_get_seed_words(ptr, error); + this.checkErrorResult(error, `walletGetSeedWords`); + return result; + } + + static walletApplyEncryption(ptr, passphrase) { + let error = this.initError(); + this.#fn.wallet_apply_encryption(ptr, passphrase, error); + this.checkErrorResult(error, `walletApplyEncryption`); + } + + static walletRemoveEncryption(ptr) { + let error = this.initError(); + this.#fn.wallet_remove_encryption(ptr, error); + this.checkErrorResult(error, `walletRemoveEncryption`); + } + + static walletSetKeyValue(ptr, key_ptr, value) { + let error = this.initError(); + let result = this.#fn.wallet_set_key_value(ptr, key_ptr, value, error); + this.checkErrorResult(error, `walletSetKeyValue`); + return result; + } + + static walletGetValue(ptr, key_ptr) { + let error = this.initError(); + let result = this.#fn.wallet_get_value(ptr, key_ptr, error); + this.checkErrorResult(error, `walletGetValue`); + return result; + } + + static walletClearValue(ptr, key_ptr) { + let error = this.initError(); + let result = this.#fn.wallet_clear_value(ptr, key_ptr, error); + this.checkErrorResult(error, `walletClearValue`); + return result; + } + + static walletIsRecoveryInProgress(ptr) { + let error = this.initError(); + let result = this.#fn.wallet_is_recovery_in_progress(ptr, error); + this.checkErrorResult(error, `walletIsRecoveryInProgress`); + return result; + } + + static walletStartRecovery( + ptr, + base_node_public_key_ptr, + recovery_progress_callback + ) { + let error = this.initError(); + let result = this.#fn.wallet_start_recovery( + ptr, + base_node_public_key_ptr, + recovery_progress_callback, + error + ); + this.checkErrorResult(error, `walletStartRecovery`); + return result; + } + + static walletDestroy(ptr) { + this.#fn.wallet_destroy(ptr); + } + //endregion +} +module.exports = InterfaceFFI; diff --git a/integration_tests/helpers/ffi/pendingInboundTransaction.js b/integration_tests/helpers/ffi/pendingInboundTransaction.js index 32cae202fb..dc2071e24b 100644 --- a/integration_tests/helpers/ffi/pendingInboundTransaction.js +++ b/integration_tests/helpers/ffi/pendingInboundTransaction.js @@ -1,24 +1,70 @@ -const WalletFFI = require("./walletFFI"); +const InterfaceFFI = require("./ffiInterface"); +const PublicKey = require("./publicKey"); class PendingInboundTransaction { #tari_pending_inbound_transaction_ptr; - constructor(tari_pending_inbound_transaction_ptr) { - this.#tari_pending_inbound_transaction_ptr = - tari_pending_inbound_transaction_ptr; + pointerAssign(ptr) { + if (this.#tari_pending_inbound_transaction_ptr) { + this.destroy(); + this.#tari_pending_inbound_transaction_ptr = ptr; + } else { + this.#tari_pending_inbound_transaction_ptr = ptr; + } + } + + getPtr() { + return this.#tari_pending_inbound_transaction_ptr; + } + + getSourcePublicKey() { + let result = new PublicKey(); + result.pointerAssign( + InterfaceFFI.pendingInboundTransactionGetSourcePublicKey( + this.#tari_pending_inbound_transaction_ptr + ) + ); + return result; + } + + getAmount() { + return InterfaceFFI.pendingInboundTransactionGetAmount( + this.#tari_pending_inbound_transaction_ptr + ); + } + + getMessage() { + return InterfaceFFI.pendingInboundTransactionGetMessage( + this.#tari_pending_inbound_transaction_ptr + ); } getStatus() { - return WalletFFI.pendingInboundTransactionGetStatus( + return InterfaceFFI.pendingInboundTransactionGetStatus( this.#tari_pending_inbound_transaction_ptr ); } - destroy() { - return WalletFFI.pendingInboundTransactionDestroy( + getTransactionID() { + return InterfaceFFI.pendingInboundTransactionGetTransactionId( + this.#tari_pending_inbound_transaction_ptr + ); + } + + getTimestamp() { + return InterfaceFFI.pendingInboundTransactionGetTimestamp( this.#tari_pending_inbound_transaction_ptr ); } + + destroy() { + if (this.#tari_pending_inbound_transaction_ptr) { + InterfaceFFI.pendingInboundTransactionDestroy( + this.#tari_pending_inbound_transaction_ptr + ); + this.#tari_pending_inbound_transaction_ptr = undefined; //prevent double free segfault + } + } } module.exports = PendingInboundTransaction; diff --git a/integration_tests/helpers/ffi/pendingInboundTransactions.js b/integration_tests/helpers/ffi/pendingInboundTransactions.js index 6246b03429..2500b41a04 100644 --- a/integration_tests/helpers/ffi/pendingInboundTransactions.js +++ b/integration_tests/helpers/ffi/pendingInboundTransactions.js @@ -1,39 +1,37 @@ const PendingInboundTransaction = require("./pendingInboundTransaction"); -const WalletFFI = require("./walletFFI"); +const InterfaceFFI = require("./ffiInterface"); class PendingInboundTransactions { #tari_pending_inbound_transactions_ptr; - constructor(tari_pending_inbound_transactions_ptr) { - this.#tari_pending_inbound_transactions_ptr = - tari_pending_inbound_transactions_ptr; - } - - static async fromWallet(wallet) { - return new PendingInboundTransactions( - await WalletFFI.walletGetPendingInboundTransactions(wallet) - ); + constructor(ptr) { + this.#tari_pending_inbound_transactions_ptr = ptr; } getLength() { - return WalletFFI.pendingInboundTransactionsGetLength( + return InterfaceFFI.pendingInboundTransactionsGetLength( this.#tari_pending_inbound_transactions_ptr ); } - async getAt(position) { - return new PendingInboundTransaction( - await WalletFFI.pendingInboundTransactionsGetAt( + getAt(position) { + let result = new PendingInboundTransaction(); + result.pointerAssign( + InterfaceFFI.pendingInboundTransactionsGetAt( this.#tari_pending_inbound_transactions_ptr, position ) ); + return result; } destroy() { - return WalletFFI.pendingInboundTransactionsDestroy( - this.#tari_pending_inbound_transactions_ptr - ); + if (this.#tari_pending_inbound_transactions_ptr) { + InterfaceFFI.pendingInboundTransactionsDestroy( + this.#tari_pending_inbound_transactions_ptr + ); + this.#tari_pending_inbound_transactions_ptr = undefined; //prevent double free segfault + } } } diff --git a/integration_tests/helpers/ffi/pendingOutboundTransaction.js b/integration_tests/helpers/ffi/pendingOutboundTransaction.js index eed2d722bb..0fc2ca47b9 100644 --- a/integration_tests/helpers/ffi/pendingOutboundTransaction.js +++ b/integration_tests/helpers/ffi/pendingOutboundTransaction.js @@ -1,30 +1,76 @@ -const WalletFFI = require("./walletFFI"); +const InterfaceFFI = require("./ffiInterface"); +const PublicKey = require("./publicKey"); class PendingOutboundTransaction { #tari_pending_outbound_transaction_ptr; - constructor(tari_pending_outbound_transaction_ptr) { - this.#tari_pending_outbound_transaction_ptr = - tari_pending_outbound_transaction_ptr; + pointerAssign(ptr) { + if (this.#tari_pending_outbound_transaction_ptr) { + this.#tari_pending_outbound_transaction_ptr = ptr; + this.destroy(); + } else { + this.#tari_pending_outbound_transaction_ptr = ptr; + } } - getTransactionId() { - return WalletFFI.pendingOutboundTransactionGetTransactionId( + getPtr() { + return this.#tari_pending_outbound_transaction_ptr; + } + + getDestinationPublicKey() { + let result = new PublicKey(); + result.pointerAssign( + InterfaceFFI.pendingOutboundTransactionGetDestinationPublicKey( + this.#tari_pending_outbound_transaction_ptr + ) + ); + return result; + } + + getAmount() { + return InterfaceFFI.pendingOutboundTransactionGetAmount( + this.#tari_pending_outbound_transaction_ptr + ); + } + + getFee() { + return InterfaceFFI.pendingOutboundTransactionGetFee( + this.#tari_pending_outbound_transaction_ptr + ); + } + + getMessage() { + return InterfaceFFI.pendingOutboundTransactionGetMessage( this.#tari_pending_outbound_transaction_ptr ); } getStatus() { - return WalletFFI.pendingOutboundTransactionGetStatus( + return InterfaceFFI.pendingOutboundTransactionGetStatus( this.#tari_pending_outbound_transaction_ptr ); } - destroy() { - return WalletFFI.pendingOutboundTransactionDestroy( + getTransactionID() { + return InterfaceFFI.pendingOutboundTransactionGetTransactionId( + this.#tari_pending_outbound_transaction_ptr + ); + } + + getTimestamp() { + return InterfaceFFI.pendingOutboundTransactionGetTimestamp( this.#tari_pending_outbound_transaction_ptr ); } + + destroy() { + if (this.#tari_pending_outbound_transaction_ptr) { + InterfaceFFI.pendingOutboundTransactionDestroy( + this.#tari_pending_outbound_transaction_ptr + ); + this.#tari_pending_outbound_transaction_ptr = undefined; //prevent double free segfault + } + } } module.exports = PendingOutboundTransaction; diff --git a/integration_tests/helpers/ffi/pendingOutboundTransactions.js b/integration_tests/helpers/ffi/pendingOutboundTransactions.js index 28e408563d..45de06033b 100644 --- a/integration_tests/helpers/ffi/pendingOutboundTransactions.js +++ b/integration_tests/helpers/ffi/pendingOutboundTransactions.js @@ -1,39 +1,37 @@ const PendingOutboundTransaction = require("./pendingOutboundTransaction"); -const WalletFFI = require("./walletFFI"); +const InterfaceFFI = require("./ffiInterface"); class PendingOutboundTransactions { #tari_pending_outbound_transactions_ptr; - constructor(tari_pending_outbound_transactions_ptr) { - this.#tari_pending_outbound_transactions_ptr = - tari_pending_outbound_transactions_ptr; - } - - static async fromWallet(wallet) { - return new PendingOutboundTransactions( - await WalletFFI.walletGetPendingOutboundTransactions(wallet) - ); + constructor(ptr) { + this.#tari_pending_outbound_transactions_ptr = ptr; } getLength() { - return WalletFFI.pendingOutboundTransactionsGetLength( + return InterfaceFFI.pendingOutboundTransactionsGetLength( this.#tari_pending_outbound_transactions_ptr ); } - async getAt(position) { - return new PendingOutboundTransaction( - await WalletFFI.pendingOutboundTransactionsGetAt( + getAt(position) { + let result = new PendingOutboundTransaction(); + result.pointerAssign( + InterfaceFFI.pendingOutboundTransactionsGetAt( this.#tari_pending_outbound_transactions_ptr, position ) ); + return result; } destroy() { - return WalletFFI.pendingOutboundTransactionsDestroy( - this.#tari_pending_outbound_transactions_ptr - ); + if (this.#tari_pending_outbound_transactions_ptr) { + InterfaceFFI.pendingOutboundTransactionsDestroy( + this.#tari_pending_outbound_transactions_ptr + ); + this.#tari_pending_outbound_transactions_ptr = undefined; //prevent double free segfault + } } } diff --git a/integration_tests/helpers/ffi/privateKey.js b/integration_tests/helpers/ffi/privateKey.js new file mode 100644 index 0000000000..7115ab8a1d --- /dev/null +++ b/integration_tests/helpers/ffi/privateKey.js @@ -0,0 +1,67 @@ +const InterfaceFFI = require("./ffiInterface"); +const ByteVector = require("./byteVector"); +const utf8 = require("utf8"); + +class PrivateKey { + #tari_private_key_ptr; + + pointerAssign(ptr) { + // Prevent pointer from being leaked in case of re-assignment + if (this.#tari_private_key_ptr) { + this.#tari_private_key_ptr = ptr; + this.destroy(); + } else { + this.#tari_private_key_ptr = ptr; + } + } + + generate() { + this.#tari_private_key_ptr = InterfaceFFI.privateKeyGenerate(); + } + + fromHexString(hex) { + let sanitize = utf8.encode(hex); // Make sure it's not UTF-16 encoded (JS default) + let result = new PrivateKey(); + result.pointerAssign(InterfaceFFI.privateKeyFromHex(sanitize)); + return result; + } + + fromByteVector(byte_vector) { + let result = new PrivateKey(); + result.pointerAssign(InterfaceFFI.privateKeyCreate(byte_vector.getPtr())); + return result; + } + + getPtr() { + return this.#tari_private_key_ptr; + } + + getBytes() { + let result = new ByteVector(); + result.pointerAssign( + InterfaceFFI.privateKeyGetBytes(this.#tari_private_key_ptr) + ); + return result; + } + + getHex() { + const bytes = this.getBytes(); + const length = bytes.getLength(); + let byte_array = new Uint8Array(length); + for (let i = 0; i < length; i++) { + byte_array[i] = bytes.getAt(i); + } + bytes.destroy(); + let buffer = Buffer.from(byte_array, 0); + return buffer.toString("hex"); + } + + destroy() { + if (this.#tari_private_key_ptr) { + InterfaceFFI.privateKeyDestroy(this.#tari_private_key_ptr); + this.#tari_private_key_ptr = undefined; //prevent double free segfault + } + } +} + +module.exports = PrivateKey; diff --git a/integration_tests/helpers/ffi/publicKey.js b/integration_tests/helpers/ffi/publicKey.js index 1165aa193d..7e1476c3d5 100644 --- a/integration_tests/helpers/ffi/publicKey.js +++ b/integration_tests/helpers/ffi/publicKey.js @@ -1,64 +1,82 @@ -const WalletFFI = require("./walletFFI"); +const InterfaceFFI = require("./ffiInterface"); const ByteVector = require("./byteVector"); const utf8 = require("utf8"); class PublicKey { #tari_public_key_ptr; - constructor(public_key) { - this.#tari_public_key_ptr = public_key; + pointerAssign(ptr) { + // Prevent pointer from being leaked in case of re-assignment + if (this.#tari_public_key_ptr) { + this.destroy(); + this.#tari_public_key_ptr = ptr; + } else { + this.#tari_public_key_ptr = ptr; + } } - static fromPubkey(public_key) { - return new PublicKey(public_key); + fromPrivateKey(key) { + let result = new PublicKey(); + result.pointerAssign(InterfaceFFI.publicKeyFromPrivateKey(key.getPtr())); + return result; } - static async fromWallet(wallet) { - return new PublicKey(await WalletFFI.walletGetPublicKey(wallet)); + static fromHexString(hex) { + let sanitize = utf8.encode(hex); // Make sure it's not UTF-16 encoded (JS default) + let result = new PublicKey(); + result.pointerAssign(InterfaceFFI.publicKeyFromHex(sanitize)); + return result; } - static async fromString(public_key_hex) { - let sanitize = utf8.encode(public_key_hex); // Make sure it's not UTF-16 encoded (JS default) - return new PublicKey(await WalletFFI.publicKeyFromHex(sanitize)); + fromEmojiID(emoji) { + let sanitize = utf8.encode(emoji); // Make sure it's not UTF-16 encoded (JS default) + let result = new PublicKey(); + result.pointerAssign(InterfaceFFI.emojiIdToPublicKey(sanitize)); + return result; } - static async fromBytes(bytes) { - return new PublicKey(await WalletFFI.publicKeyCreate(bytes)); + fromByteVector(byte_vector) { + let result = new PublicKey(); + result.pointerAssign(InterfaceFFI.publicKeyCreate(byte_vector.getPtr())); + return result; } getPtr() { return this.#tari_public_key_ptr; } - async getBytes() { - return new ByteVector( - await WalletFFI.publicKeyGetBytes(this.#tari_public_key_ptr) + getBytes() { + let result = new ByteVector(); + result.pointerAssign( + InterfaceFFI.publicKeyGetBytes(this.#tari_public_key_ptr) ); + return result; } - async getHex() { - const bytes = await this.getBytes(); - const length = await bytes.getLength(); + getHex() { + const bytes = this.getBytes(); + const length = bytes.getLength(); let byte_array = new Uint8Array(length); - for (let i = 0; i < length; ++i) { - byte_array[i] = await bytes.getAt(i); + for (let i = 0; i < length; i++) { + byte_array[i] = bytes.getAt(i); } - await bytes.destroy(); + bytes.destroy(); let buffer = Buffer.from(byte_array, 0); return buffer.toString("hex"); } - async getEmojiId() { - const emoji_id = await WalletFFI.publicKeyToEmojiId( - this.#tari_public_key_ptr - ); + getEmojiId() { + const emoji_id = InterfaceFFI.publicKeyToEmojiId(this.#tari_public_key_ptr); const result = emoji_id.readCString(); - await WalletFFI.stringDestroy(emoji_id); + InterfaceFFI.stringDestroy(emoji_id); return result; } destroy() { - return WalletFFI.publicKeyDestroy(this.#tari_public_key_ptr); + if (this.#tari_public_key_ptr) { + InterfaceFFI.publicKeyDestroy(this.#tari_public_key_ptr); + this.#tari_public_key_ptr = undefined; //prevent double free segfault + } } } diff --git a/integration_tests/helpers/ffi/seedWords.js b/integration_tests/helpers/ffi/seedWords.js index 86c05cab48..e191bc38a9 100644 --- a/integration_tests/helpers/ffi/seedWords.js +++ b/integration_tests/helpers/ffi/seedWords.js @@ -1,45 +1,55 @@ -const WalletFFI = require("./walletFFI"); +const InterfaceFFI = require("./ffiInterface"); +const utf8 = require("utf8"); class SeedWords { #tari_seed_words_ptr; - constructor(tari_seed_words_ptr) { - this.#tari_seed_words_ptr = tari_seed_words_ptr; + pointerAssign(ptr) { + // Prevent pointer from being leaked in case of re-assignment + if (this.#tari_seed_words_ptr) { + this.destroy(); + this.#tari_seed_words_ptr = ptr; + } else { + this.#tari_seed_words_ptr = ptr; + } } - static async fromString(seed_words_text) { - const seed_words = await WalletFFI.seedWordsCreate(); + static fromText(seed_words_text) { + const seed_words = new SeedWords(); + seed_words.pointerAssign(InterfaceFFI.seedWordsCreate()); const seed_words_list = seed_words_text.split(" "); for (const seed_word of seed_words_list) { - await WalletFFI.seedWordsPushWord(seed_words, seed_word); + InterfaceFFI.seedWordsPushWord( + seed_words.getPtr(), + utf8.encode(seed_word) + ); } - return new SeedWords(seed_words); - } - - static async fromWallet(wallet) { - return new SeedWords(await WalletFFI.walletGetSeedWords(wallet)); + return seed_words; } getLength() { - return WalletFFI.seedWordsGetLength(this.#tari_seed_words_ptr); + return InterfaceFFI.seedWordsGetLength(this.#tari_seed_words_ptr); } getPtr() { return this.#tari_seed_words_ptr; } - async getAt(position) { - const seed_word = await WalletFFI.seedWordsGetAt( + getAt(position) { + const seed_word = InterfaceFFI.seedWordsGetAt( this.#tari_seed_words_ptr, position ); const result = seed_word.readCString(); - await WalletFFI.stringDestroy(seed_word); + InterfaceFFI.stringDestroy(seed_word); return result; } destroy() { - return WalletFFI.seedWordsDestroy(this.#tari_seed_words_ptr); + if (this.#tari_seed_words_ptr) { + InterfaceFFI.seedWordsDestroy(this.#tari_seed_words_ptr); + this.#tari_seed_words_ptr = undefined; //prevent double free segfault + } } } diff --git a/integration_tests/helpers/ffi/transportType.js b/integration_tests/helpers/ffi/transportType.js new file mode 100644 index 0000000000..0826c423b5 --- /dev/null +++ b/integration_tests/helpers/ffi/transportType.js @@ -0,0 +1,85 @@ +const InterfaceFFI = require("./ffiInterface"); +const utf8 = require("utf8"); + +class TransportType { + #tari_transport_type_ptr; + #type = "None"; + + pointerAssign(ptr, type) { + // Prevent pointer from being leaked in case of re-assignment + if (this.#tari_transport_type_ptr) { + this.destroy(); + this.#tari_transport_type_ptr = ptr; + this.#type = type; + } else { + this.#tari_transport_type_ptr = ptr; + this.#type = type; + } + } + + getPtr() { + return this.#tari_transport_type_ptr; + } + + getType() { + return this.#type; + } + + static createMemory() { + let result = new TransportType(); + result.pointerAssign(InterfaceFFI.transportMemoryCreate(), "Memory"); + return result; + } + + static createTCP(listener_address) { + let sanitize = utf8.encode(listener_address); // Make sure it's not UTF-16 encoded (JS default) + let result = new TransportType(); + result.pointerAssign(InterfaceFFI.transportTcpCreate(sanitize), "TCP"); + return result; + } + + static createTor( + control_server_address, + tor_cookie, + tor_port, + socks_username, + socks_password + ) { + let sanitize_address = utf8.encode(control_server_address); + let sanitize_username = utf8.encode(socks_username); + let sanitize_password = utf8.encode(socks_password); + let result = new TransportType(); + result.pointerAssign( + InterfaceFFI.transportTorCreate( + sanitize_address, + tor_cookie.getPtr(), + tor_port, + sanitize_username, + sanitize_password + ), + "Tor" + ); + return result; + } + + getAddress() { + if (this.#type === "Memory") { + let c_address = InterfaceFFI.transportMemoryGetAddress(this.getPtr()); + let result = c_address.readCString(); + InterfaceFFI.stringDestroy(c_address); + return result; + } else { + return "N/A"; + } + } + + destroy() { + this.#type = "None"; + if (this.#tari_transport_type_ptr) { + InterfaceFFI.transportTypeDestroy(this.#tari_transport_type_ptr); + this.#tari_transport_type_ptr = undefined; //prevent double free segfault + } + } +} + +module.exports = TransportType; diff --git a/integration_tests/helpers/ffi/wallet.js b/integration_tests/helpers/ffi/wallet.js new file mode 100644 index 0000000000..fea21fe682 --- /dev/null +++ b/integration_tests/helpers/ffi/wallet.js @@ -0,0 +1,449 @@ +const InterfaceFFI = require("./ffiInterface"); +const PublicKey = require("./publicKey"); +const CompletedTransaction = require("./completedTransaction"); +const CompletedTransactions = require("./completedTransactions"); +const PendingInboundTransaction = require("./pendingInboundTransaction"); +const PendingInboundTransactions = require("./pendingInboundTransactions"); +const PendingOutboundTransactions = require("./pendingOutboundTransactions"); +const Contact = require("./contact"); +const Contacts = require("./contacts"); + +const utf8 = require("utf8"); + +class Wallet { + #wallet_ptr; + #log_path = ""; + receivedTransaction = 0; + receivedTransactionReply = 0; + transactionBroadcast = 0; + transactionMined = 0; + saf_messages = 0; + + utxo_validation_complete = false; + utxo_validation_result = 0; + stxo_validation_complete = false; + stxo_validation_result = 0; + + getUtxoValidationStatus() { + return { + utxo_validation_complete: this.utxo_validation_complete, + utxo_validation_result: this.utxo_validation_result, + }; + } + + getStxoValidationStatus() { + return { + stxo_validation_complete: this.stxo_validation_complete, + stxo_validation_result: this.stxo_validation_result, + }; + } + + clearCallbackCounters() { + this.receivedTransaction = + this.receivedTransactionReply = + this.transactionBroadcast = + this.transactionMined = + this.saf_messages = + this.cancelled = + this.minedunconfirmed = + this.finalized = + 0; + } + + getCounters() { + return { + received: this.receivedTransaction, + replyreceived: this.receivedTransactionReply, + broadcast: this.transactionBroadcast, + finalized: this.finalized, + minedunconfirmed: this.minedunconfirmed, + cancelled: this.cancelled, + mined: this.transactionMined, + saf: this.saf_messages, + }; + } + + constructor( + comms_config_ptr, + log_path, + passphrase, + seed_words_ptr, + num_rolling_log_file = 50, + log_size_bytes = 102400 + ) { + this.receivedTransaction = 0; + this.receivedTransactionReply = 0; + this.transactionBroadcast = 0; + this.transactionMined = 0; + this.saf_messages = 0; + this.cancelled = 0; + this.minedunconfirmed = 0; + this.finalized = 0; + this.recoveryFinished = true; + let sanitize = null; + let words = null; + if (passphrase) { + sanitize = utf8.encode(passphrase); + } + if (seed_words_ptr) { + words = seed_words_ptr; + } + this.#log_path = log_path; + this.#wallet_ptr = InterfaceFFI.walletCreate( + comms_config_ptr, + utf8.encode(this.#log_path), //`${this.baseDir}/log/wallet.log`, + num_rolling_log_file, + log_size_bytes, + sanitize, + words, + this.#callback_received_transaction, + this.#callback_received_transaction_reply, + this.#callback_received_finalized_transaction, + this.#callback_transaction_broadcast, + this.#callback_transaction_mined, + this.#callback_transaction_mined_unconfirmed, + this.#callback_direct_send_result, + this.#callback_store_and_forward_send_result, + this.#callback_transaction_cancellation, + this.#callback_utxo_validation_complete, + this.#callback_stxo_validation_complete, + this.#callback_invalid_txo_validation_complete, + this.#callback_transaction_validation_complete, + this.#callback_saf_message_received + ); + } + + //region Callbacks + #onReceivedTransaction = (ptr) => { + // refer to outer scope in callback function otherwise this is null + let tx = new PendingInboundTransaction(); + tx.pointerAssign(ptr); + console.log( + `${new Date().toISOString()} received Transaction with txID ${tx.getTransactionID()}` + ); + tx.destroy(); + this.receivedTransaction += 1; + }; + + #onReceivedTransactionReply = (ptr) => { + let tx = new CompletedTransaction(); + tx.pointerAssign(ptr); + console.log( + `${new Date().toISOString()} received reply for Transaction with txID ${tx.getTransactionID()}.` + ); + tx.destroy(); + this.receivedTransactionReply += 1; + }; + + #onReceivedFinalizedTransaction = (ptr) => { + let tx = new CompletedTransaction(); + tx.pointerAssign(ptr); + console.log( + `${new Date().toISOString()} received finalization for Transaction with txID ${tx.getTransactionID()}.` + ); + tx.destroy(); + this.finalized += 1; + }; + + #onTransactionBroadcast = (ptr) => { + let tx = new CompletedTransaction(); + tx.pointerAssign(ptr); + console.log( + `${new Date().toISOString()} Transaction with txID ${tx.getTransactionID()} was broadcast.` + ); + tx.destroy(); + this.transactionBroadcast += 1; + }; + + #onTransactionMined = (ptr) => { + let tx = new CompletedTransaction(); + tx.pointerAssign(ptr); + console.log( + `${new Date().toISOString()} Transaction with txID ${tx.getTransactionID()} was mined.` + ); + tx.destroy(); + this.transactionMined += 1; + }; + + #onTransactionMinedUnconfirmed = (ptr, confirmations) => { + let tx = new CompletedTransaction(); + tx.pointerAssign(ptr); + console.log( + `${new Date().toISOString()} Transaction with txID ${tx.getTransactionID()} is mined unconfirmed with ${confirmations} confirmations.` + ); + tx.destroy(); + this.minedunconfirmed += 1; + }; + + #onTransactionCancellation = (ptr) => { + let tx = new CompletedTransaction(); + tx.pointerAssign(ptr); + console.log( + `${new Date().toISOString()} Transaction with txID ${tx.getTransactionID()} was cancelled` + ); + tx.destroy(); + this.cancelled += 1; + }; + + #onDirectSendResult = (id, success) => { + console.log( + `${new Date().toISOString()} callbackDirectSendResult(${id},${success})` + ); + }; + + #onStoreAndForwardSendResult = (id, success) => { + console.log( + `${new Date().toISOString()} callbackStoreAndForwardSendResult(${id},${success})` + ); + }; + + #onUtxoValidationComplete = (request_key, validation_results) => { + console.log( + `${new Date().toISOString()} callbackUtxoValidationComplete(${request_key},${validation_results})` + ); + this.utxo_validation_complete = true; + this.utxo_validation_result = validation_results; + }; + + #onStxoValidationComplete = (request_key, validation_results) => { + console.log( + `${new Date().toISOString()} callbackStxoValidationComplete(${request_key},${validation_results})` + ); + this.stxo_validation_complete = true; + this.stxo_validation_result = validation_results; + }; + + #onInvalidTxoValidationComplete = (request_key, validation_results) => { + console.log( + `${new Date().toISOString()} callbackInvalidTxoValidationComplete(${request_key},${validation_results})` + ); + //this.invalidtxo_validation_complete = true; + //this.invalidtxo_validation_result = validation_results; + }; + + #onTransactionValidationComplete = (request_key, validation_results) => { + console.log( + `${new Date().toISOString()} callbackTransactionValidationComplete(${request_key},${validation_results})` + ); + //this.transaction_validation_complete = true; + //this.transaction_validation_result = validation_results; + }; + + #onSafMessageReceived = () => { + console.log(`${new Date().toISOString()} callbackSafMessageReceived()`); + this.saf_messages += 1; + }; + + #onRecoveryProgress = (a, b, c) => { + console.log( + `${new Date().toISOString()} recoveryProgressCallback(${a},${b},${c})` + ); + if (a === 4) { + console.log(`Recovery completed, funds recovered: ${c} uT`); + } + }; + + #callback_received_transaction = + InterfaceFFI.createCallbackReceivedTransaction(this.#onReceivedTransaction); + #callback_received_transaction_reply = + InterfaceFFI.createCallbackReceivedTransactionReply( + this.#onReceivedTransactionReply + ); + #callback_received_finalized_transaction = + InterfaceFFI.createCallbackReceivedFinalizedTransaction( + this.#onReceivedFinalizedTransaction + ); + #callback_transaction_broadcast = + InterfaceFFI.createCallbackTransactionBroadcast( + this.#onTransactionBroadcast + ); + #callback_transaction_mined = InterfaceFFI.createCallbackTransactionMined( + this.#onTransactionMined + ); + #callback_transaction_mined_unconfirmed = + InterfaceFFI.createCallbackTransactionMinedUnconfirmed( + this.#onTransactionMinedUnconfirmed + ); + #callback_direct_send_result = InterfaceFFI.createCallbackDirectSendResult( + this.#onDirectSendResult + ); + #callback_store_and_forward_send_result = + InterfaceFFI.createCallbackStoreAndForwardSendResult( + this.#onStoreAndForwardSendResult + ); + #callback_transaction_cancellation = + InterfaceFFI.createCallbackTransactionCancellation( + this.#onTransactionCancellation + ); + #callback_utxo_validation_complete = + InterfaceFFI.createCallbackUtxoValidationComplete( + this.#onUtxoValidationComplete + ); + #callback_stxo_validation_complete = + InterfaceFFI.createCallbackStxoValidationComplete( + this.#onStxoValidationComplete + ); + #callback_invalid_txo_validation_complete = + InterfaceFFI.createCallbackInvalidTxoValidationComplete( + this.#onInvalidTxoValidationComplete + ); + #callback_transaction_validation_complete = + InterfaceFFI.createCallbackTransactionValidationComplete( + this.#onTransactionValidationComplete + ); + #callback_saf_message_received = + InterfaceFFI.createCallbackSafMessageReceived(this.#onSafMessageReceived); + #recoveryProgressCallback = InterfaceFFI.createRecoveryProgressCallback( + this.#onRecoveryProgress + ); + //endregion + + startRecovery(base_node_pubkey) { + let node_pubkey = PublicKey.fromHexString(utf8.encode(base_node_pubkey)); + InterfaceFFI.walletStartRecovery( + this.#wallet_ptr, + node_pubkey.getPtr(), + this.#recoveryProgressCallback + ); + node_pubkey.destroy(); + } + + recoveryInProgress() { + return InterfaceFFI.walletIsRecoveryInProgress(this.#wallet_ptr); + } + + getPublicKey() { + let ptr = InterfaceFFI.walletGetPublicKey(this.#wallet_ptr); + let pk = new PublicKey(); + pk.pointerAssign(ptr); + let result = pk.getHex(); + pk.destroy(); + return result; + } + + getEmojiId() { + let ptr = InterfaceFFI.walletGetPublicKey(this.#wallet_ptr); + let pk = new PublicKey(); + pk.pointerAssign(ptr); + let result = pk.getEmojiId(); + pk.destroy(); + return result; + } + + getBalance() { + let available = InterfaceFFI.walletGetAvailableBalance(this.#wallet_ptr); + let pendingIncoming = InterfaceFFI.walletGetPendingIncomingBalance( + this.#wallet_ptr + ); + let pendingOutgoing = InterfaceFFI.walletGetPendingOutgoingBalance( + this.#wallet_ptr + ); + return { + pendingIn: pendingIncoming, + pendingOut: pendingOutgoing, + available: available, + }; + } + + addBaseNodePeer(public_key_hex, address) { + let public_key = PublicKey.fromHexString(utf8.encode(public_key_hex)); + let result = InterfaceFFI.walletAddBaseNodePeer( + this.#wallet_ptr, + public_key.getPtr(), + utf8.encode(address) + ); + public_key.destroy(); + return result; + } + + sendTransaction(destination, amount, fee_per_gram, message) { + let dest_public_key = PublicKey.fromHexString(utf8.encode(destination)); + let result = InterfaceFFI.walletSendTransaction( + this.#wallet_ptr, + dest_public_key.getPtr(), + amount, + fee_per_gram, + utf8.encode(message) + ); + dest_public_key.destroy(); + return result; + } + + applyEncryption(passphrase) { + InterfaceFFI.walletApplyEncryption( + this.#wallet_ptr, + utf8.encode(passphrase) + ); + } + + getCompletedTransactions() { + let list_ptr = InterfaceFFI.walletGetCompletedTransactions( + this.#wallet_ptr + ); + return new CompletedTransactions(list_ptr); + } + + getInboundTransactions() { + let list_ptr = InterfaceFFI.walletGetPendingInboundTransactions( + this.#wallet_ptr + ); + return new PendingInboundTransactions(list_ptr); + } + + getOutboundTransactions() { + let list_ptr = InterfaceFFI.walletGetPendingOutboundTransactions( + this.#wallet_ptr + ); + return new PendingOutboundTransactions(list_ptr); + } + + getContacts() { + let list_ptr = InterfaceFFI.walletGetContacts(this.#wallet_ptr); + return new Contacts(list_ptr); + } + + addContact(alias, pubkey_hex) { + let public_key = PublicKey.fromHexString(utf8.encode(pubkey_hex)); + let contact = new Contact(); + contact.pointerAssign( + InterfaceFFI.contactCreate(utf8.encode(alias), public_key.getPtr()) + ); + let result = InterfaceFFI.walletUpsertContact( + this.#wallet_ptr, + contact.getPtr() + ); + contact.destroy(); + public_key.destroy(); + return result; + } + + removeContact(contact) { + let result = InterfaceFFI.walletRemoveContact( + this.#wallet_ptr, + contact.getPtr() + ); + contact.destroy(); + return result; + } + + cancelPendingTransaction(tx_id) { + return InterfaceFFI.walletCancelPendingTransaction(this.#wallet_ptr, tx_id); + } + + startUtxoValidation() { + return InterfaceFFI.walletStartUtxoValidation(this.#wallet_ptr); + } + + startStxoValidation() { + return InterfaceFFI.walletStartStxoValidation(this.#wallet_ptr); + } + + destroy() { + if (this.#wallet_ptr) { + InterfaceFFI.walletDestroy(this.#wallet_ptr); + this.#wallet_ptr = undefined; //prevent double free segfault + } + } +} + +module.exports = Wallet; diff --git a/integration_tests/helpers/walletFFIClient.js b/integration_tests/helpers/walletFFIClient.js index 60596c8e34..8835f03dee 100644 --- a/integration_tests/helpers/walletFFIClient.js +++ b/integration_tests/helpers/walletFFIClient.js @@ -1,440 +1,187 @@ -const WalletFFI = require("./ffi/walletFFI"); +const SeedWords = require("./ffi/seedWords"); +const TransportType = require("./ffi/transportType"); +const CommsConfig = require("./ffi/commsConfig"); +const Wallet = require("./ffi/wallet"); const { getFreePort } = require("./util"); const dateFormat = require("dateformat"); -const { expect } = require("chai"); -const PublicKey = require("./ffi/publicKey"); -const CompletedTransactions = require("./ffi/completedTransactions"); -const PendingOutboundTransactions = require("./ffi/pendingOutboundTransactions"); -const Contact = require("./ffi/contact"); -const Contacts = require("./ffi/contacts"); -const SeedWords = require("./ffi/seedWords"); +const InterfaceFFI = require("./ffi/ffiInterface"); class WalletFFIClient { #name; #wallet; #comms_config; + #transport; + #seed_words; + #pass_phrase; #port; - #callback_received_transaction; - #callback_received_transaction_reply; - #callback_received_finalized_transaction; - #callback_transaction_broadcast; - #callback_transaction_mined; - #callback_transaction_mined_unconfirmed; - #callback_direct_send_result; - #callback_store_and_forward_send_result; - #callback_transaction_cancellation; - #callback_utxo_validation_complete; - #callback_stxo_validation_complete; - #callback_invalid_txo_validation_complete; - #callback_transaction_validation_complete; - #callback_saf_message_received; - #recovery_progress_callback; - - #callbackReceivedTransaction = (..._args) => { - console.log(`${new Date().toISOString()} callbackReceivedTransaction`); - this.receivedTransaction += 1; - }; - #callbackReceivedTransactionReply = (..._args) => { - console.log(`${new Date().toISOString()} callbackReceivedTransactionReply`); - this.receivedTransactionReply += 1; - }; - #callbackReceivedFinalizedTransaction = (..._args) => { - console.log( - `${new Date().toISOString()} callbackReceivedFinalizedTransaction` - ); - }; - #callbackTransactionBroadcast = (..._args) => { - console.log(`${new Date().toISOString()} callbackTransactionBroadcast`); - this.transactionBroadcast += 1; - }; - #callbackTransactionMined = (..._args) => { - console.log(`${new Date().toISOString()} callbackTransactionMined`); - this.transactionMined += 1; - }; - #callbackTransactionMinedUnconfirmed = (..._args) => { - console.log( - `${new Date().toISOString()} callbackTransactionMinedUnconfirmed` - ); - }; - #callbackDirectSendResult = (..._args) => { - console.log(`${new Date().toISOString()} callbackDirectSendResult`); - }; - #callbackStoreAndForwardSendResult = (..._args) => { - console.log( - `${new Date().toISOString()} callbackStoreAndForwardSendResult` - ); - }; - #callbackTransactionCancellation = (..._args) => { - console.log(`${new Date().toISOString()} callbackTransactionCancellation`); - }; - #callbackUtxoValidationComplete = (_request_key, validation_results) => { - console.log(`${new Date().toISOString()} callbackUtxoValidationComplete`); - this.utxo_validation_complete = true; - this.utxo_validation_result = validation_results; - }; - #callbackStxoValidationComplete = (_request_key, validation_results) => { - console.log(`${new Date().toISOString()} callbackStxoValidationComplete`); - this.stxo_validation_complete = true; - this.stxo_validation_result = validation_results; - }; - #callbackInvalidTxoValidationComplete = (..._args) => { - console.log( - `${new Date().toISOString()} callbackInvalidTxoValidationComplete` - ); - }; - #callbackTransactionValidationComplete = (..._args) => { - console.log( - `${new Date().toISOString()} callbackTransactionValidationComplete` - ); - }; - #callbackSafMessageReceived = (..._args) => { - console.log(`${new Date().toISOString()} callbackSafMessageReceived`); - }; - #recoveryProgressCallback = (a, b, c) => { - console.log(`${new Date().toISOString()} recoveryProgressCallback`); - if (a == 3) - // Progress - this.recoveryProgress = [b, c]; - if (a == 4) - // Completed - this.recoveryInProgress = false; - }; - - clearCallbackCounters() { - this.receivedTransaction = - this.receivedTransactionReply = - this.transactionBroadcast = - this.transactionMined = - 0; - } + baseDir = ""; constructor(name) { - this.#wallet = null; this.#name = name; - this.baseDir = ""; - this.clearCallbackCounters(); - - // Create the ffi callbacks - this.#callback_received_transaction = - WalletFFI.createCallbackReceivedTransaction( - this.#callbackReceivedTransaction - ); - this.#callback_received_transaction_reply = - WalletFFI.createCallbackReceivedTransactionReply( - this.#callbackReceivedTransactionReply - ); - this.#callback_received_finalized_transaction = - WalletFFI.createCallbackReceivedFinalizedTransaction( - this.#callbackReceivedFinalizedTransaction - ); - this.#callback_transaction_broadcast = - WalletFFI.createCallbackTransactionBroadcast( - this.#callbackTransactionBroadcast - ); - this.#callback_transaction_mined = WalletFFI.createCallbackTransactionMined( - this.#callbackTransactionMined - ); - this.#callback_transaction_mined_unconfirmed = - WalletFFI.createCallbackTransactionMinedUnconfirmed( - this.#callbackTransactionMinedUnconfirmed - ); - this.#callback_direct_send_result = - WalletFFI.createCallbackDirectSendResult(this.#callbackDirectSendResult); - this.#callback_store_and_forward_send_result = - WalletFFI.createCallbackStoreAndForwardSendResult( - this.#callbackStoreAndForwardSendResult - ); - this.#callback_transaction_cancellation = - WalletFFI.createCallbackTransactionCancellation( - this.#callbackTransactionCancellation - ); - this.#callback_utxo_validation_complete = - WalletFFI.createCallbackUtxoValidationComplete( - this.#callbackUtxoValidationComplete - ); - this.#callback_stxo_validation_complete = - WalletFFI.createCallbackStxoValidationComplete( - this.#callbackStxoValidationComplete - ); - this.#callback_invalid_txo_validation_complete = - WalletFFI.createCallbackInvalidTxoValidationComplete( - this.#callbackInvalidTxoValidationComplete - ); - this.#callback_transaction_validation_complete = - WalletFFI.createCallbackTransactionValidationComplete( - this.#callbackTransactionValidationComplete - ); - this.#callback_saf_message_received = - WalletFFI.createCallbackSafMessageReceived( - this.#callbackSafMessageReceived - ); - this.#recovery_progress_callback = WalletFFI.createRecoveryProgressCallback( - this.#recoveryProgressCallback - ); } static async Init() { - await WalletFFI.Init(); + await InterfaceFFI.Init(); } - async startNew(seed_words_text) { + async startNew(seed_words_text, pass_phrase) { this.#port = await getFreePort(19000, 25000); const name = `WalletFFI${this.#port}-${this.#name}`; this.baseDir = `./temp/base_nodes/${dateFormat( new Date(), "yyyymmddHHMM" )}/${name}`; - const tcp = await WalletFFI.transportTcpCreate( - `/ip4/0.0.0.0/tcp/${this.#port}` - ); - this.#comms_config = await WalletFFI.commsConfigCreate( + this.#transport = TransportType.createTCP(`/ip4/0.0.0.0/tcp/${this.#port}`); + this.#comms_config = new CommsConfig( `/ip4/0.0.0.0/tcp/${this.#port}`, - tcp, + this.#transport.getPtr(), "wallet.dat", this.baseDir, 30, 600, "localnet" ); - await this.start(seed_words_text); + this.#start(seed_words_text, pass_phrase); } - async start(seed_words_text) { - let seed_words; - let seed_words_ptr = WalletFFI.NULL; - if (seed_words_text) { - seed_words = await SeedWords.fromString(seed_words_text); - seed_words_ptr = seed_words.getPtr(); - } - this.#wallet = await WalletFFI.walletCreate( - this.#comms_config, - `${this.baseDir}/log/wallet.log`, - 50, - 102400, - WalletFFI.NULL, - seed_words_ptr, - this.#callback_received_transaction, - this.#callback_received_transaction_reply, - this.#callback_received_finalized_transaction, - this.#callback_transaction_broadcast, - this.#callback_transaction_mined, - this.#callback_transaction_mined_unconfirmed, - this.#callback_direct_send_result, - this.#callback_store_and_forward_send_result, - this.#callback_transaction_cancellation, - this.#callback_utxo_validation_complete, - this.#callback_stxo_validation_complete, - this.#callback_invalid_txo_validation_complete, - this.#callback_transaction_validation_complete, - this.#callback_saf_message_received + async restart(seed_words_text, pass_phrase) { + this.#transport = TransportType.createTCP(`/ip4/0.0.0.0/tcp/${this.#port}`); + this.#comms_config = new CommsConfig( + `/ip4/0.0.0.0/tcp/${this.#port}`, + this.#transport.getPtr(), + "wallet.dat", + this.baseDir, + 30, + 600, + "localnet" ); - if (seed_words) await seed_words.destroy(); + this.#start(seed_words_text, pass_phrase); } - async startRecovery(base_node_pubkey) { - const node_pubkey = await PublicKey.fromString(base_node_pubkey); - expect( - await WalletFFI.walletStartRecovery( - this.#wallet, - node_pubkey.getPtr(), - this.#recovery_progress_callback - ) - ).to.be.true; - node_pubkey.destroy(); - this.recoveryInProgress = true; + getStxoValidationStatus() { + return this.#wallet.getStxoValidationStatus(); } - recoveryInProgress() { - return this.recoveryInProgress; + getUtxoValidationStatus() { + return this.#wallet.getUtxoValidationStatus(); + } + identify() { + return this.#wallet.getPublicKey(); } - async stop() { - await WalletFFI.walletDestroy(this.#wallet); + identifyEmoji() { + return this.#wallet.getEmojiId(); } - async getPublicKey() { - const public_key = await PublicKey.fromWallet(this.#wallet); - const public_key_hex = public_key.getHex(); - public_key.destroy(); - return public_key_hex; + getBalance() { + return this.#wallet.getBalance(); } - async getEmojiId() { - const public_key = await PublicKey.fromWallet(this.#wallet); - const emoji_id = await public_key.getEmojiId(); - public_key.destroy(); - return emoji_id; + addBaseNodePeer(public_key_hex, address) { + return this.#wallet.addBaseNodePeer(public_key_hex, address); } - async getBalance() { - return await WalletFFI.walletGetAvailableBalance(this.#wallet); + addContact(alias, pubkey_hex) { + return this.#wallet.addContact(alias, pubkey_hex); } - async addBaseNodePeer(public_key_hex, address) { - const public_key = await PublicKey.fromString(public_key_hex); - expect( - await WalletFFI.walletAddBaseNodePeer( - this.#wallet, - public_key.getPtr(), - address - ) - ).to.be.true; - await public_key.destroy(); + getContactList() { + return this.#wallet.getContacts(); } - async sendTransaction(destination, amount, fee_per_gram, message) { - const dest_public_key = await PublicKey.fromString(destination); - const result = await WalletFFI.walletSendTransaction( - this.#wallet, - dest_public_key.getPtr(), - amount, - fee_per_gram, - message - ); - await dest_public_key.destroy(); - return result; + getCompletedTxs() { + return this.#wallet.getCompletedTransactions(); } - async applyEncryption(passphrase) { - await WalletFFI.walletApplyEncryption(this.#wallet, passphrase); + getInboundTxs() { + return this.#wallet.getInboundTransactions(); } - async getCompletedTransactions() { - const txs = await CompletedTransactions.fromWallet(this.#wallet); - const length = await txs.getLength(); - let outbound = 0; - let inbound = 0; - for (let i = 0; i < length; ++i) { - const tx = await txs.getAt(i); - if (await tx.isOutbound()) { - ++outbound; - } else { - ++inbound; - } - tx.destroy(); - } - txs.destroy(); - return [outbound, inbound]; + getOutboundTxs() { + return this.#wallet.getOutboundTransactions(); } - async getBroadcastTransactionsCount() { - let broadcast_tx_cnt = 0; - const txs = await PendingOutboundTransactions.fromWallet(this.#wallet); - const length = await txs.getLength(); - for (let i = 0; i < length; ++i) { - const tx = await txs.getAt(i); - const status = await tx.getStatus(); - tx.destroy(); - if (status === 1) { - // Broadcast - broadcast_tx_cnt += 1; - } - } - await txs.destroy(); - return broadcast_tx_cnt; + removeContact(contact) { + return this.#wallet.removeContact(contact); } - async getOutboundTransactionsCount() { - let outbound_tx_cnt = 0; - const txs = await PendingOutboundTransactions.fromWallet(this.#wallet); - const length = await txs.getLength(); - for (let i = 0; i < length; ++i) { - const tx = await txs.getAt(i); - const status = await tx.getStatus(); - if (status === 4) { - // Pending - outbound_tx_cnt += 1; - } - tx.destroy(); - } - await txs.destroy(); - return outbound_tx_cnt; + startRecovery(base_node_pubkey) { + this.#wallet.startRecovery(base_node_pubkey); } - async addContact(alias, pubkey_hex) { - const public_key = await PublicKey.fromString(pubkey_hex); - const contact = new Contact( - await WalletFFI.contactCreate(alias, public_key.getPtr()) - ); - public_key.destroy(); - expect(await WalletFFI.walletUpsertContact(this.#wallet, contact.getPtr())) - .to.be.true; - contact.destroy(); + checkRecoveryInProgress() { + return this.#wallet.recoveryInProgress(); } - async #findContact(lookup_alias) { - const contacts = await Contacts.fromWallet(this.#wallet); - const length = await contacts.getLength(); - let contact; - for (let i = 0; i < length; ++i) { - contact = await contacts.getAt(i); - const alias = await contact.getAlias(); - const found = alias === lookup_alias; - if (found) { - break; - } - contact.destroy(); - contact = undefined; - } - contacts.destroy(); - return contact; + applyEncryption(passphrase) { + this.#wallet.applyEncryption(passphrase); } - async getContact(alias) { - const contact = await this.#findContact(alias); - if (contact) { - const pubkey = await contact.getPubkey(); - const pubkey_hex = pubkey.getHex(); - pubkey.destroy(); - contact.destroy(); - return pubkey_hex; - } + startStxoValidation() { + this.#wallet.startStxoValidation(); } - async removeContact(alias) { - const contact = await this.#findContact(alias); - if (contact) { - expect( - await WalletFFI.walletRemoveContact(this.#wallet, contact.getPtr()) - ).to.be.true; - contact.destroy(); - } + startUtxoValidation() { + this.#wallet.startUtxoValidation(); + } + + getCounters() { + return this.#wallet.getCounters(); + } + resetCounters() { + this.#wallet.clearCallbackCounters(); } - async identify() { - return { - public_key: await this.getPublicKey(), - }; + sendTransaction(destination, amount, fee_per_gram, message) { + return this.#wallet.sendTransaction( + destination, + amount, + fee_per_gram, + message + ); } - async cancelAllOutboundTransactions() { - const txs = await PendingOutboundTransactions.fromWallet(this.#wallet); - const length = await txs.getLength(); - let cancelled = 0; - for (let i = 0; i < length; ++i) { - const tx = await txs.getAt(i); - if ( - await WalletFFI.walletCancelPendingTransaction( - this.#wallet, - await tx.getTransactionId() - ) - ) { - ++cancelled; - } - tx.destroy(); + #start( + seed_words_text, + pass_phrase, + rolling_log_files = 50, + byte_size_per_log = 102400 + ) { + this.#pass_phrase = pass_phrase; + if (seed_words_text) { + let seed_words = SeedWords.fromText(seed_words_text); + this.#seed_words = seed_words; } - txs.destroy(); - return cancelled; + + let log_path = `${this.baseDir}/log/wallet.log`; + this.#wallet = new Wallet( + this.#comms_config.getPtr(), + log_path, + this.#pass_phrase, + this.#seed_words ? this.#seed_words.getPtr() : null, + rolling_log_files, + byte_size_per_log + ); } - startUtxoValidation() { - this.utxo_validation_complete = false; - return WalletFFI.walletStartUtxoValidation(this.#wallet); + getOutboundTransactions() { + return this.#wallet.getOutboundTransactions(); } - startStxoValidation() { - this.stxo_validation_complete = false; - return WalletFFI.walletStartStxoValidation(this.#wallet); + cancelPendingTransaction(tx_id) { + return this.#wallet.cancelPendingTransaction(tx_id); + } + + stop() { + if (this.#wallet) { + this.#wallet.destroy(); + } + if (this.#comms_config) { + this.#comms_config.destroy(); + } + if (this.#seed_words) { + this.#seed_words.destroy(); + } } }