From e1c7cf8bfed946f69839711f6c659ba49857aad3 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Wed, 20 Jul 2022 19:33:47 -0700 Subject: [PATCH] feat: implement testing utilities --- .github/workflows/sqlx.yml | 17 +- Cargo.lock | 444 +++++++++++++++++- Cargo.toml | 16 + .../axum-social-with-tests/Cargo.toml | 33 ++ .../postgres/axum-social-with-tests/README.md | 10 + .../migrations/1_user.sql | 6 + .../migrations/2_post.sql | 8 + .../migrations/3_comment.sql | 9 + .../axum-social-with-tests/src/http/error.rs | 75 +++ .../axum-social-with-tests/src/http/mod.rs | 26 + .../src/http/post/comment.rs | 100 ++++ .../src/http/post/mod.rs | 93 ++++ .../axum-social-with-tests/src/http/user.rs | 95 ++++ .../axum-social-with-tests/src/lib.rs | 3 + .../axum-social-with-tests/src/main.rs | 19 + .../axum-social-with-tests/src/password.rs | 34 ++ .../axum-social-with-tests/tests/comment.rs | 152 ++++++ .../axum-social-with-tests/tests/common.rs | 72 +++ .../tests/fixtures/comments.sql | 12 + .../tests/fixtures/posts.sql | 8 + .../tests/fixtures/users.sql | 10 + .../axum-social-with-tests/tests/post.rs | 120 +++++ .../axum-social-with-tests/tests/user.rs | 89 ++++ sqlx-bench/test.db | Bin 794624 -> 794624 bytes sqlx-core/Cargo.toml | 2 + sqlx-core/src/lib.rs | 6 + sqlx-core/src/migrate/migrator.rs | 8 + sqlx-core/src/mysql/connection/establish.rs | 28 +- sqlx-core/src/mysql/mod.rs | 3 + sqlx-core/src/mysql/testing/mod.rs | 223 +++++++++ sqlx-core/src/net/socket.rs | 9 + sqlx-core/src/pool/inner.rs | 143 ++++-- sqlx-core/src/pool/options.rs | 16 + sqlx-core/src/postgres/mod.rs | 3 + sqlx-core/src/postgres/testing/mod.rs | 204 ++++++++ sqlx-core/src/sqlite/mod.rs | 3 + sqlx-core/src/sqlite/testing/mod.rs | 81 ++++ sqlx-core/src/testing/fixtures.rs | 280 +++++++++++ sqlx-core/src/testing/mod.rs | 262 +++++++++++ sqlx-macros/src/common.rs | 4 +- sqlx-macros/src/lib.rs | 47 +- sqlx-macros/src/migrate.rs | 17 +- sqlx-macros/src/test_attr.rs | 217 +++++++++ sqlx-rt/src/lib.rs | 120 +---- sqlx-rt/src/rt_async_std.rs | 24 + sqlx-rt/src/rt_tokio.rs | 47 ++ src/lib.rs | 12 + src/{macros.rs => macros/mod.rs} | 0 src/macros/test.md | 218 +++++++++ tests/docker-compose.yml | 47 +- tests/mysql/fixtures/comments.sql | 16 + tests/mysql/fixtures/posts.sql | 9 + tests/mysql/fixtures/users.sql | 2 + tests/mysql/migrations/1_user.sql | 7 + tests/mysql/migrations/2_post.sql | 10 + tests/mysql/migrations/3_comment.sql | 10 + tests/mysql/test-attr.rs | 96 ++++ tests/postgres/fixtures/comments.sql | 16 + tests/postgres/fixtures/posts.sql | 14 + tests/postgres/fixtures/users.sql | 2 + tests/postgres/migrations/0_setup.sql | 2 + tests/postgres/migrations/1_user.sql | 5 + tests/postgres/migrations/2_post.sql | 8 + tests/postgres/migrations/3_comment.sql | 9 + tests/postgres/test-attr.rs | 94 ++++ tests/postgres/types.rs | 4 +- tests/sqlite/fixtures/comments.sql | 16 + tests/sqlite/fixtures/posts.sql | 9 + tests/sqlite/fixtures/users.sql | 2 + tests/sqlite/migrations/1_user.sql | 6 + tests/sqlite/migrations/2_post.sql | 10 + tests/sqlite/migrations/3_comment.sql | 10 + tests/sqlite/test-attr.rs | 99 ++++ tests/x.py | 6 +- 74 files changed, 3701 insertions(+), 236 deletions(-) create mode 100644 examples/postgres/axum-social-with-tests/Cargo.toml create mode 100644 examples/postgres/axum-social-with-tests/README.md create mode 100644 examples/postgres/axum-social-with-tests/migrations/1_user.sql create mode 100644 examples/postgres/axum-social-with-tests/migrations/2_post.sql create mode 100644 examples/postgres/axum-social-with-tests/migrations/3_comment.sql create mode 100644 examples/postgres/axum-social-with-tests/src/http/error.rs create mode 100644 examples/postgres/axum-social-with-tests/src/http/mod.rs create mode 100644 examples/postgres/axum-social-with-tests/src/http/post/comment.rs create mode 100644 examples/postgres/axum-social-with-tests/src/http/post/mod.rs create mode 100644 examples/postgres/axum-social-with-tests/src/http/user.rs create mode 100644 examples/postgres/axum-social-with-tests/src/lib.rs create mode 100644 examples/postgres/axum-social-with-tests/src/main.rs create mode 100644 examples/postgres/axum-social-with-tests/src/password.rs create mode 100644 examples/postgres/axum-social-with-tests/tests/comment.rs create mode 100644 examples/postgres/axum-social-with-tests/tests/common.rs create mode 100644 examples/postgres/axum-social-with-tests/tests/fixtures/comments.sql create mode 100644 examples/postgres/axum-social-with-tests/tests/fixtures/posts.sql create mode 100644 examples/postgres/axum-social-with-tests/tests/fixtures/users.sql create mode 100644 examples/postgres/axum-social-with-tests/tests/post.rs create mode 100644 examples/postgres/axum-social-with-tests/tests/user.rs create mode 100644 sqlx-core/src/mysql/testing/mod.rs create mode 100644 sqlx-core/src/postgres/testing/mod.rs create mode 100644 sqlx-core/src/sqlite/testing/mod.rs create mode 100644 sqlx-core/src/testing/fixtures.rs create mode 100644 sqlx-core/src/testing/mod.rs create mode 100644 sqlx-macros/src/test_attr.rs create mode 100644 sqlx-rt/src/rt_async_std.rs create mode 100644 sqlx-rt/src/rt_tokio.rs rename src/{macros.rs => macros/mod.rs} (100%) create mode 100644 src/macros/test.md create mode 100644 tests/mysql/fixtures/comments.sql create mode 100644 tests/mysql/fixtures/posts.sql create mode 100644 tests/mysql/fixtures/users.sql create mode 100644 tests/mysql/migrations/1_user.sql create mode 100644 tests/mysql/migrations/2_post.sql create mode 100644 tests/mysql/migrations/3_comment.sql create mode 100644 tests/mysql/test-attr.rs create mode 100644 tests/postgres/fixtures/comments.sql create mode 100644 tests/postgres/fixtures/posts.sql create mode 100644 tests/postgres/fixtures/users.sql create mode 100644 tests/postgres/migrations/0_setup.sql create mode 100644 tests/postgres/migrations/1_user.sql create mode 100644 tests/postgres/migrations/2_post.sql create mode 100644 tests/postgres/migrations/3_comment.sql create mode 100644 tests/postgres/test-attr.rs create mode 100644 tests/sqlite/fixtures/comments.sql create mode 100644 tests/sqlite/fixtures/posts.sql create mode 100644 tests/sqlite/fixtures/users.sql create mode 100644 tests/sqlite/migrations/1_user.sql create mode 100644 tests/sqlite/migrations/2_post.sql create mode 100644 tests/sqlite/migrations/3_comment.sql create mode 100644 tests/sqlite/test-attr.rs diff --git a/.github/workflows/sqlx.yml b/.github/workflows/sqlx.yml index 6f8f49ee2a..a9a805002f 100644 --- a/.github/workflows/sqlx.yml +++ b/.github/workflows/sqlx.yml @@ -170,7 +170,7 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - postgres: [14, 9_6] + postgres: [14, 10] runtime: [async-std, tokio, actix] tls: [native-tls, rustls] needs: check @@ -230,7 +230,7 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - mysql: [8, 5_6] + mysql: [8, 5_7] runtime: [async-std, tokio, actix] tls: [native-tls, rustls] needs: check @@ -257,6 +257,17 @@ jobs: - run: sleep 60 - uses: actions-rs/cargo@v1 + with: + command: test + args: > + --no-default-features + --features any,mysql,macros,migrate,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} + env: + DATABASE_URL: mysql://root:password@localhost:3306/sqlx?ssl-mode=disabled + + # MySQL 5.7 supports TLS but not TLSv1.3 as required by RusTLS. + - uses: actions-rs/cargo@v1 + if: ${{ !(matrix.mysql == '5_7' && matrix.tls == 'rustls') }} with: command: test args: > @@ -270,7 +281,7 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - mariadb: [10_6, 10_2] + mariadb: [10_6, 10_3] runtime: [async-std, tokio, actix] tls: [native-tls, rustls] needs: check diff --git a/Cargo.lock b/Cargo.lock index d62660ecfa..03ef45abee 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -37,6 +37,17 @@ version = "1.0.58" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bb07d2053ccdbe10e2af2995a2f116c1330396493dc1269f6a91d0ae82e19704" +[[package]] +name = "argon2" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db4ce4441f99dbd377ca8a8f57b698c44d0d6e712d8329b5040da5a64aa1ce73" +dependencies = [ + "base64ct", + "blake2", + "password-hash", +] + [[package]] name = "arrayvec" version = "0.7.2" @@ -228,6 +239,64 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "axum" +version = "0.5.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b9496f0c1d1afb7a2af4338bbe1d969cddfead41d87a9fb3aaa6d0bbc7af648" +dependencies = [ + "async-trait", + "axum-core", + "axum-macros", + "bitflags", + "bytes", + "futures-util", + "http", + "http-body", + "hyper", + "itoa 1.0.2", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-http", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-core" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4f44a0e6200e9d11a1cdc989e4b358f6e3d354fbf48478f345a17f4e43f8635" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http", + "http-body", + "mime", +] + +[[package]] +name = "axum-macros" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6293dae2ec708e679da6736e857cf8532886ef258e92930f38279c12641628b8" +dependencies = [ + "heck 0.4.0", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "backoff" version = "0.4.0" @@ -277,6 +346,15 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "blake2" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9cf849ee05b2ee5fba5e36f97ff8ec2533916700fc0758d40d92136a42f3388" +dependencies = [ + "digest", +] + [[package]] name = "block-buffer" version = "0.10.2" @@ -406,6 +484,7 @@ dependencies = [ "libc", "num-integer", "num-traits", + "serde", "time 0.1.44", "winapi", ] @@ -688,6 +767,41 @@ dependencies = [ "syn", ] +[[package]] +name = "darling" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4529658bdda7fd6769b8614be250cdcfc3aeb0ee72fe66f9e41e5e5eb73eac02" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "649c91bc01e8b1eac09fb91e8dbc7d517684ca6be8ebc75bb9cafc894f9fdb6f" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim 0.10.0", + "syn", +] + +[[package]] +name = "darling_macro" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddfc69c5bfcbd2fc09a0f38451d2daf0e372e367986a83906d1b0dbc88134fb5" +dependencies = [ + "darling_core", + "quote", + "syn", +] + [[package]] name = "der" version = "0.5.1" @@ -915,6 +1029,12 @@ dependencies = [ "spin 0.9.3", ] +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + [[package]] name = "foreign-types" version = "0.3.2" @@ -1199,12 +1319,81 @@ dependencies = [ "digest", ] +[[package]] +name = "http" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75f43d41e26995c17e71ee126451dd3941010b0514a81a9d11f3b341debc2399" +dependencies = [ + "bytes", + "fnv", + "itoa 1.0.2", +] + +[[package]] +name = "http-body" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" +dependencies = [ + "bytes", + "http", + "pin-project-lite", +] + +[[package]] +name = "http-range-header" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29" + +[[package]] +name = "httparse" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "496ce29bb5a52785b44e0f7ca2847ae0bb839c9bd28f69acac9b99d461c0c04c" + +[[package]] +name = "httpdate" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421" + [[package]] name = "humantime" version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" +[[package]] +name = "hyper" +version = "0.14.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02c929dc5c39e335a03c405292728118860721b10190d98c2a0f0efd5baafbac" +dependencies = [ + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "http", + "http-body", + "httparse", + "httpdate", + "itoa 1.0.2", + "pin-project-lite", + "socket2", + "tokio", + "tower-service", + "tracing", + "want", +] + +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "0.2.3" @@ -1216,6 +1405,12 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "if_chain" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb56e1aa765b4b4f3aadfab769793b7087bb03a4ea4920644a6d238e2df5b9ed" + [[package]] name = "indexmap" version = "1.9.1" @@ -1224,6 +1419,7 @@ checksum = "10a35a97730320ffe8e2d410b5d3b69279b98d2c14bdb8b70ea89ecf7888d41e" dependencies = [ "autocfg", "hashbrown", + "serde", ] [[package]] @@ -1407,6 +1603,12 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a3e378b66a060d48947b590737b30a1be76706c8dd7b8ba0f2fe3989c68a853f" +[[package]] +name = "matchit" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb" + [[package]] name = "md-5" version = "0.10.1" @@ -1431,6 +1633,12 @@ dependencies = [ "autocfg", ] +[[package]] +name = "mime" +version = "0.3.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d" + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -1737,6 +1945,17 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "password-hash" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700" +dependencies = [ + "base64ct", + "rand_core", + "subtle", +] + [[package]] name = "paste" version = "1.0.7" @@ -2273,9 +2492,9 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.139" +version = "1.0.140" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0171ebb889e45aa68b44aee0859b3eede84c6f5f5c228e6f140c0b2a0a46cad6" +checksum = "fc855a42c7967b7c369eb5860f7164ef1f6f81c20c7cc1141f2a604e18723b03" dependencies = [ "serde_derive", ] @@ -2292,9 +2511,9 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.139" +version = "1.0.140" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc1d3230c1de7932af58ad8ffbe1d784bd55efd5a9d84ac24f69c72d83543dfb" +checksum = "6f2122636b9fe3b81f1cb25099fcf2d3f542cdb1d45940d56c713158884a05da" dependencies = [ "proc-macro2", "quote", @@ -2312,6 +2531,46 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa 1.0.2", + "ryu", + "serde", +] + +[[package]] +name = "serde_with" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89df7a26519371a3cce44fbb914c2819c84d9b897890987fa3ab096491cc0ea8" +dependencies = [ + "base64", + "chrono", + "hex", + "indexmap", + "serde", + "serde_json", + "serde_with_macros", + "time 0.3.11", +] + +[[package]] +name = "serde_with_macros" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de337f322382fcdfbb21a014f7c224ee041a23785651db67b9827403178f698f" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "sha-1" version = "0.10.0" @@ -2491,6 +2750,7 @@ dependencies = [ "crossbeam-queue", "digest", "dirs", + "dotenvy", "either", "encoding_rs", "event-listener", @@ -2555,6 +2815,30 @@ dependencies = [ "tokio", ] +[[package]] +name = "sqlx-example-postgres-axum-social" +version = "0.1.0" +dependencies = [ + "anyhow", + "argon2", + "axum", + "dotenvy", + "once_cell", + "rand", + "regex", + "serde", + "serde_json", + "serde_with", + "sqlx", + "thiserror", + "time 0.3.11", + "tokio", + "tower", + "tracing", + "uuid", + "validator", +] + [[package]] name = "sqlx-example-postgres-listen" version = "0.1.0" @@ -2725,6 +3009,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "sync_wrapper" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20518fe4a4c9acf048008599e464deb21beeae3d3578418951a189c235a7a9a8" + [[package]] name = "tempfile" version = "3.3.0" @@ -2819,6 +3109,7 @@ dependencies = [ "itoa 1.0.2", "libc", "num_threads", + "serde", "time-macros", ] @@ -2855,9 +3146,9 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" [[package]] name = "tokio" -version = "1.20.0" +version = "1.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57aec3cfa4c296db7255446efb4928a6be304b431a806216105542a67b6ca82e" +checksum = "7a8325f63a7d4774dd041e363b2409ed1c5cbbd0f867795e661df066b2b0a581" dependencies = [ "autocfg", "bytes", @@ -2926,6 +3217,92 @@ dependencies = [ "serde", ] +[[package]] +name = "tower" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +dependencies = [ + "futures-core", + "futures-util", + "pin-project", + "pin-project-lite", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-http" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c530c8675c1dbf98facee631536fa116b5fb6382d7dd6dc1b118d970eafe3ba" +dependencies = [ + "bitflags", + "bytes", + "futures-core", + "futures-util", + "http", + "http-body", + "http-range-header", + "pin-project-lite", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-layer" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "343bc9466d3fe6b0f960ef45960509f84480bf4fd96f92901afe7ff3df9d3a62" + +[[package]] +name = "tower-service" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" + +[[package]] +name = "tracing" +version = "0.1.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a400e31aa60b9d44a52a8ee0343b5b18566b03a8321e0d321f695cf56e940160" +dependencies = [ + "cfg-if", + "log", + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11c75893af559bc8e10716548bdef5cb2b983f8e637db9d0e15126b61b484ee2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tracing-core" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b7358be39f2f274f322d2aaed611acc57f382e8eb1e5b48cb9ae30933495ce7" +dependencies = [ + "once_cell", +] + +[[package]] +name = "try-lock" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642" + [[package]] name = "trybuild" version = "1.0.63" @@ -3015,6 +3392,51 @@ name = "uuid" version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd6469f4314d5f1ffec476e05f17cc9a78bc7a27a6a857842170bdf8d6f98d2f" +dependencies = [ + "serde", +] + +[[package]] +name = "validator" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32ad5bf234c7d3ad1042e5252b7eddb2c4669ee23f32c7dd0e9b7705f07ef591" +dependencies = [ + "idna", + "lazy_static", + "regex", + "serde", + "serde_derive", + "serde_json", + "url", + "validator_derive", +] + +[[package]] +name = "validator_derive" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc44ca3088bb3ba384d9aecf40c6a23a676ce23e09bdaca2073d99c207f864af" +dependencies = [ + "if_chain", + "lazy_static", + "proc-macro-error", + "proc-macro2", + "quote", + "regex", + "syn", + "validator_types", +] + +[[package]] +name = "validator_types" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "111abfe30072511849c5910134e8baf8dc05de4c0e5903d681cbd5c9c4d611e3" +dependencies = [ + "proc-macro2", + "syn", +] [[package]] name = "value-bag" @@ -3061,6 +3483,16 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "want" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ce8a968cb1cd110d136ff8b819a556d6fb6d919363c61534f6860c7eb172ba0" +dependencies = [ + "log", + "try-lock", +] + [[package]] name = "wasi" version = "0.10.0+wasi-snapshot-preview1" diff --git a/Cargo.toml b/Cargo.toml index f79e04bcd4..0ad5d0df3a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ members = [ "sqlx-cli", "sqlx-bench", "examples/mysql/todos", + "examples/postgres/axum-social-with-tests", "examples/postgres/files", "examples/postgres/json", "examples/postgres/listen", @@ -197,6 +198,11 @@ name = "sqlite-derives" path = "tests/sqlite/derives.rs" required-features = ["sqlite", "macros"] +[[test]] +name = "sqlite-test-attr" +path = "tests/sqlite/test-attr.rs" +required-features = ["sqlite", "macros", "migrate"] + # # MySQL # @@ -221,6 +227,11 @@ name = "mysql-macros" path = "tests/mysql/macros.rs" required-features = ["mysql", "macros"] +[[test]] +name = "mysql-test-attr" +path = "tests/mysql/test-attr.rs" +required-features = ["mysql", "macros", "migrate"] + # # PostgreSQL # @@ -250,6 +261,11 @@ name = "postgres-derives" path = "tests/postgres/derives.rs" required-features = ["postgres", "macros"] +[[test]] +name = "postgres-test-attr" +path = "tests/postgres/test-attr.rs" +required-features = ["postgres", "macros", "migrate"] + # # Microsoft SQL Server (MSSQL) # diff --git a/examples/postgres/axum-social-with-tests/Cargo.toml b/examples/postgres/axum-social-with-tests/Cargo.toml new file mode 100644 index 0000000000..95eb242b85 --- /dev/null +++ b/examples/postgres/axum-social-with-tests/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "sqlx-example-postgres-axum-social" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +# Primary crates +axum = { version = "0.5.13", features = ["macros"] } +sqlx = { version = "0.6.0", path = "../../../", features = ["runtime-tokio-rustls", "postgres", "time", "uuid"] } +tokio = { version = "1.20.1", features = ["rt-multi-thread", "macros"] } + +# Important secondary crates +argon2 = "0.4.1" +rand = "0.8.5" +regex = "1.6.0" +serde = "1.0.140" +serde_with = { version = "2.0.0", features = ["time_0_3"] } +time = "0.3.11" +uuid = { version = "1.1.2", features = ["serde"] } +validator = { version = "0.16.0", features = ["derive"] } + +# Auxilliary crates +anyhow = "1.0.58" +dotenvy = "0.15.1" +once_cell = "1.13.0" +thiserror = "1.0.31" +tracing = "0.1.35" + +[dev-dependencies] +serde_json = "1.0.82" +tower = "0.4.13" diff --git a/examples/postgres/axum-social-with-tests/README.md b/examples/postgres/axum-social-with-tests/README.md new file mode 100644 index 0000000000..dfccd84764 --- /dev/null +++ b/examples/postgres/axum-social-with-tests/README.md @@ -0,0 +1,10 @@ +This example demonstrates how to write integration tests for an API build with [Axum] and SQLx using `#[sqlx::test]`. + +See also: https://github.com/tokio-rs/axum/blob/main/examples/testing + +# Warning + +For the sake of brevity, this project omits numerous critical security precautions. You can use it as a starting point, +but deploy to production at your own risk! + +[Axum]: https://github.com/tokio-rs/axum diff --git a/examples/postgres/axum-social-with-tests/migrations/1_user.sql b/examples/postgres/axum-social-with-tests/migrations/1_user.sql new file mode 100644 index 0000000000..62d42bceb2 --- /dev/null +++ b/examples/postgres/axum-social-with-tests/migrations/1_user.sql @@ -0,0 +1,6 @@ +create table "user" +( + user_id uuid primary key default gen_random_uuid(), + username text unique not null, + password_hash text not null +); diff --git a/examples/postgres/axum-social-with-tests/migrations/2_post.sql b/examples/postgres/axum-social-with-tests/migrations/2_post.sql new file mode 100644 index 0000000000..b91705fa85 --- /dev/null +++ b/examples/postgres/axum-social-with-tests/migrations/2_post.sql @@ -0,0 +1,8 @@ +create table post ( + post_id uuid primary key default gen_random_uuid(), + user_id uuid not null references "user"(user_id), + content text not null, + created_at timestamptz not null default now() +); + +create index on post(created_at desc); diff --git a/examples/postgres/axum-social-with-tests/migrations/3_comment.sql b/examples/postgres/axum-social-with-tests/migrations/3_comment.sql new file mode 100644 index 0000000000..d76cce1b2f --- /dev/null +++ b/examples/postgres/axum-social-with-tests/migrations/3_comment.sql @@ -0,0 +1,9 @@ +create table comment ( + comment_id uuid primary key default gen_random_uuid(), + post_id uuid not null references post(post_id), + user_id uuid not null references "user"(user_id), + content text not null, + created_at timestamptz not null default now() +); + +create index on comment(post_id, created_at); diff --git a/examples/postgres/axum-social-with-tests/src/http/error.rs b/examples/postgres/axum-social-with-tests/src/http/error.rs new file mode 100644 index 0000000000..effb518ba4 --- /dev/null +++ b/examples/postgres/axum-social-with-tests/src/http/error.rs @@ -0,0 +1,75 @@ +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; +use axum::Json; + +use serde_with::DisplayFromStr; +use validator::ValidationErrors; + +/// An API-friendly error type. +#[derive(thiserror::Error, Debug)] +pub enum Error { + /// A SQLx call returned an error. + /// + /// The exact error contents are not reported to the user in order to avoid leaking + /// information about databse internals. + #[error("an internal database error occurred")] + Sqlx(#[from] sqlx::Error), + + /// Similarly, we don't want to report random `anyhow` errors to the user. + #[error("an internal server error occurred")] + Anyhow(#[from] anyhow::Error), + + #[error("validation error in request body")] + InvalidEntity(#[from] ValidationErrors), + + #[error("{0}")] + UnprocessableEntity(String), + + #[error("{0}")] + Conflict(String), +} + +impl IntoResponse for Error { + fn into_response(self) -> Response { + #[serde_with::serde_as] + #[serde_with::skip_serializing_none] + #[derive(serde::Serialize)] + struct ErrorResponse<'a> { + // Serialize the `Display` output as the error message + #[serde_as(as = "DisplayFromStr")] + message: &'a Error, + + errors: Option<&'a ValidationErrors>, + } + + let errors = match &self { + Error::InvalidEntity(errors) => Some(errors), + _ => None, + }; + + // Normally you wouldn't just print this, but it's useful for debugging without + // using a logging framework. + println!("API error: {:?}", self); + + ( + self.status_code(), + Json(ErrorResponse { + message: &self, + errors, + }), + ) + .into_response() + } +} + +impl Error { + fn status_code(&self) -> StatusCode { + use Error::*; + + match self { + Sqlx(_) | Anyhow(_) => StatusCode::INTERNAL_SERVER_ERROR, + InvalidEntity(_) | UnprocessableEntity(_) => StatusCode::UNPROCESSABLE_ENTITY, + Conflict(_) => StatusCode::CONFLICT, + } + } +} diff --git a/examples/postgres/axum-social-with-tests/src/http/mod.rs b/examples/postgres/axum-social-with-tests/src/http/mod.rs new file mode 100644 index 0000000000..a871a93d7e --- /dev/null +++ b/examples/postgres/axum-social-with-tests/src/http/mod.rs @@ -0,0 +1,26 @@ +use anyhow::Context; +use axum::{Extension, Router}; +use sqlx::PgPool; + +mod error; + +mod post; +mod user; + +pub use self::error::Error; + +pub type Result = ::std::result::Result; + +pub fn app(db: PgPool) -> Router { + Router::new() + .merge(user::router()) + .merge(post::router()) + .layer(Extension(db)) +} + +pub async fn serve(db: PgPool) -> anyhow::Result<()> { + axum::Server::bind(&"0.0.0.0:8080".parse().unwrap()) + .serve(app(db).into_make_service()) + .await + .context("failed to serve API") +} diff --git a/examples/postgres/axum-social-with-tests/src/http/post/comment.rs b/examples/postgres/axum-social-with-tests/src/http/post/comment.rs new file mode 100644 index 0000000000..630dedaa21 --- /dev/null +++ b/examples/postgres/axum-social-with-tests/src/http/post/comment.rs @@ -0,0 +1,100 @@ +use axum::extract::Path; +use axum::{Extension, Json, Router}; + +use axum::routing::get; + +use serde::{Deserialize, Serialize}; +use time::OffsetDateTime; + +use crate::http::user::UserAuth; +use sqlx::PgPool; +use validator::Validate; + +use crate::http::Result; + +use time::format_description::well_known::Rfc3339; +use uuid::Uuid; + +pub fn router() -> Router { + Router::new().route( + "/v1/post/:postId/comment", + get(get_post_comments).post(create_post_comment), + ) +} + +#[derive(Deserialize, Validate)] +#[serde(rename_all = "camelCase")] +struct CreateCommentRequest { + auth: UserAuth, + #[validate(length(min = 1, max = 1000))] + content: String, +} + +#[serde_with::serde_as] +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct Comment { + comment_id: Uuid, + username: String, + content: String, + // `OffsetDateTime`'s default serialization format is not standard. + #[serde_as(as = "Rfc3339")] + created_at: OffsetDateTime, +} + +// #[axum::debug_handler] // very useful! +async fn create_post_comment( + db: Extension, + Path(post_id): Path, + Json(req): Json, +) -> Result> { + req.validate()?; + let user_id = req.auth.verify(&*db).await?; + + let comment = sqlx::query_as!( + Comment, + // language=PostgreSQL + r#" + with inserted_comment as ( + insert into comment(user_id, post_id, content) + values ($1, $2, $3) + returning comment_id, user_id, content, created_at + ) + select comment_id, username, content, created_at + from inserted_comment + inner join "user" using (user_id) + "#, + user_id, + post_id, + req.content + ) + .fetch_one(&*db) + .await?; + + Ok(Json(comment)) +} + +/// Returns comments in ascending chronological order. +async fn get_post_comments( + db: Extension, + Path(post_id): Path, +) -> Result>> { + // Note: normally you'd want to put a `LIMIT` on this as well, + // though that would also necessitate implementing pagination. + let comments = sqlx::query_as!( + Comment, + // language=PostgreSQL + r#" + select comment_id, username, content, created_at + from comment + inner join "user" using (user_id) + where post_id = $1 + order by created_at + "#, + post_id + ) + .fetch_all(&*db) + .await?; + + Ok(Json(comments)) +} diff --git a/examples/postgres/axum-social-with-tests/src/http/post/mod.rs b/examples/postgres/axum-social-with-tests/src/http/post/mod.rs new file mode 100644 index 0000000000..09c2fa44bb --- /dev/null +++ b/examples/postgres/axum-social-with-tests/src/http/post/mod.rs @@ -0,0 +1,93 @@ +use axum::{Extension, Json, Router}; + +use axum::routing::get; + +use serde::{Deserialize, Serialize}; +use time::OffsetDateTime; + +use crate::http::user::UserAuth; +use sqlx::PgPool; +use validator::Validate; + +use crate::http::Result; + +use time::format_description::well_known::Rfc3339; +use uuid::Uuid; + +mod comment; + +pub fn router() -> Router { + Router::new() + .route("/v1/post", get(get_posts).post(create_post)) + .merge(comment::router()) +} + +#[derive(Deserialize, Validate)] +#[serde(rename_all = "camelCase")] +struct CreatePostRequest { + auth: UserAuth, + #[validate(length(min = 1, max = 1000))] + content: String, +} + +#[serde_with::serde_as] +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct Post { + post_id: Uuid, + username: String, + content: String, + // `OffsetDateTime`'s default serialization format is not standard. + #[serde_as(as = "Rfc3339")] + created_at: OffsetDateTime, +} + +// #[axum::debug_handler] // very useful! +async fn create_post( + db: Extension, + Json(req): Json, +) -> Result> { + req.validate()?; + let user_id = req.auth.verify(&*db).await?; + + let post = sqlx::query_as!( + Post, + // language=PostgreSQL + r#" + with inserted_post as ( + insert into post(user_id, content) + values ($1, $2) + returning post_id, user_id, content, created_at + ) + select post_id, username, content, created_at + from inserted_post + inner join "user" using (user_id) + "#, + user_id, + req.content + ) + .fetch_one(&*db) + .await?; + + Ok(Json(post)) +} + +/// Returns posts in descending chronological order. +async fn get_posts(db: Extension) -> Result>> { + // Note: normally you'd want to put a `LIMIT` on this as well, + // though that would also necessitate implementing pagination. + let posts = sqlx::query_as!( + Post, + // language=PostgreSQL + r#" + select post_id, username, content, created_at + from post + inner join "user" using (user_id) + order by created_at desc + "# + ) + .fetch_all(&*db) + .await?; + + Ok(Json(posts)) +} diff --git a/examples/postgres/axum-social-with-tests/src/http/user.rs b/examples/postgres/axum-social-with-tests/src/http/user.rs new file mode 100644 index 0000000000..55f7f05bab --- /dev/null +++ b/examples/postgres/axum-social-with-tests/src/http/user.rs @@ -0,0 +1,95 @@ +use axum::http::StatusCode; +use axum::{routing::post, Extension, Json, Router}; +use once_cell::sync::Lazy; +use rand::Rng; +use regex::Regex; +use std::time::Duration; + +use serde::Deserialize; +use sqlx::{PgExecutor, PgPool}; +use uuid::Uuid; +use validator::Validate; + +use crate::http::{Error, Result}; + +pub type UserId = Uuid; + +pub fn router() -> Router { + Router::new().route("/v1/user", post(create_user)) +} + +static USERNAME_REGEX: Lazy = Lazy::new(|| Regex::new(r"^[0-9A-Za-z_]+$").unwrap()); + +// CREATE USER + +#[derive(Deserialize, Validate)] +#[serde(rename_all = "camelCase")] +pub struct UserAuth { + #[validate(length(min = 3, max = 16), regex = "USERNAME_REGEX")] + username: String, + #[validate(length(min = 8, max = 32))] + password: String, +} + +// WARNING: this API has none of the checks that a normal user signup flow implements, +// such as email or phone verification. +async fn create_user(db: Extension, Json(req): Json) -> Result { + req.validate()?; + + let UserAuth { username, password } = req; + + // It would be irresponsible to store passwords in plaintext, however. + let password_hash = crate::password::hash(password).await?; + + sqlx::query!( + // language=PostgreSQL + r#" + insert into "user"(username, password_hash) + values ($1, $2) + "#, + username, + password_hash + ) + .execute(&*db) + .await + .map_err(|e| match e { + sqlx::Error::Database(dbe) if dbe.constraint() == Some("user_username_key") => { + Error::Conflict("username taken".into()) + } + _ => e.into(), + })?; + + Ok(StatusCode::NO_CONTENT) +} + +impl UserAuth { + // NOTE: normally we wouldn't want to verify the username and password every time, + // but persistent sessions would have complicated the example. + pub async fn verify(self, db: impl PgExecutor<'_> + Send) -> Result { + self.validate()?; + + let maybe_user = sqlx::query!( + r#"select user_id, password_hash from "user" where username = $1"#, + self.username + ) + .fetch_optional(db) + .await?; + + if let Some(user) = maybe_user { + let verified = crate::password::verify(self.password, user.password_hash).await?; + + if verified { + return Ok(user.user_id); + } + } + + // Sleep a random amount of time to avoid leaking existence of a user in timing. + let sleep_duration = + rand::thread_rng().gen_range(Duration::from_millis(100)..=Duration::from_millis(500)); + tokio::time::sleep(sleep_duration).await; + + Err(Error::UnprocessableEntity( + "invalid username/password".into(), + )) + } +} diff --git a/examples/postgres/axum-social-with-tests/src/lib.rs b/examples/postgres/axum-social-with-tests/src/lib.rs new file mode 100644 index 0000000000..76f3b5b7ec --- /dev/null +++ b/examples/postgres/axum-social-with-tests/src/lib.rs @@ -0,0 +1,3 @@ +pub mod http; + +mod password; diff --git a/examples/postgres/axum-social-with-tests/src/main.rs b/examples/postgres/axum-social-with-tests/src/main.rs new file mode 100644 index 0000000000..6b8f223dec --- /dev/null +++ b/examples/postgres/axum-social-with-tests/src/main.rs @@ -0,0 +1,19 @@ +use anyhow::Context; +use sqlx::postgres::PgPoolOptions; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let database_url = dotenvy::var("DATABASE_URL") + // The error from `var()` doesn't mention the environment variable. + .context("DATABASE_URL must be set")?; + + let db = PgPoolOptions::new() + .max_connections(20) + .connect(&database_url) + .await + .context("failed to connect to DATABASE_URL")?; + + sqlx::migrate!().run(&db).await?; + + sqlx_example_postgres_axum_social::http::serve(db).await +} diff --git a/examples/postgres/axum-social-with-tests/src/password.rs b/examples/postgres/axum-social-with-tests/src/password.rs new file mode 100644 index 0000000000..44f1551f85 --- /dev/null +++ b/examples/postgres/axum-social-with-tests/src/password.rs @@ -0,0 +1,34 @@ +use anyhow::{anyhow, Context}; +use tokio::task; + +use argon2::password_hash::SaltString; +use argon2::{password_hash, Argon2, PasswordHash, PasswordHasher, PasswordVerifier}; + +pub async fn hash(password: String) -> anyhow::Result { + task::spawn_blocking(move || { + let salt = SaltString::generate(rand::thread_rng()); + Ok(Argon2::default() + .hash_password(password.as_bytes(), &salt) + .map_err(|e| anyhow!(e).context("failed to hash password"))? + .to_string()) + }) + .await + .context("panic in hash()")? +} + +pub async fn verify(password: String, hash: String) -> anyhow::Result { + task::spawn_blocking(move || { + let hash = PasswordHash::new(&hash) + .map_err(|e| anyhow!(e).context("BUG: password hash invalid"))?; + + let res = Argon2::default().verify_password(password.as_bytes(), &hash); + + match res { + Ok(()) => Ok(true), + Err(password_hash::Error::Password) => Ok(false), + Err(e) => Err(anyhow!(e).context("failed to verify password")), + } + }) + .await + .context("panic in verify()")? +} diff --git a/examples/postgres/axum-social-with-tests/tests/comment.rs b/examples/postgres/axum-social-with-tests/tests/comment.rs new file mode 100644 index 0000000000..84eb96d0e6 --- /dev/null +++ b/examples/postgres/axum-social-with-tests/tests/comment.rs @@ -0,0 +1,152 @@ +use sqlx::PgPool; + +use sqlx_example_postgres_axum_social::http; + +use axum::http::{Request, StatusCode}; +use tower::ServiceExt; + +use std::borrow::BorrowMut; + +use common::{expect_rfc3339_timestamp, expect_uuid, response_json, RequestBuilderExt}; + +use serde_json::json; + +mod common; + +#[sqlx::test(fixtures("users", "posts"))] +async fn test_create_comment(db: PgPool) { + let mut app = http::app(db); + + // Happy path! + let mut resp1 = app + .borrow_mut() + .oneshot( + Request::post("/v1/post/d9ca2672-24c5-4442-b32f-cd717adffbaa/comment").json(json! { + { + "auth": { + "username": "bob", + "password": "pro gamer 1990" + }, + "content": "lol bet ur still bad, 1v1 me" + } + }), + ) + .await + .unwrap(); + + assert_eq!(resp1.status(), StatusCode::OK); + + let resp1_json = response_json(&mut resp1).await; + + assert_eq!(resp1_json["username"], "bob"); + assert_eq!(resp1_json["content"], "lol bet ur still bad, 1v1 me"); + + let _comment_id = expect_uuid(&resp1_json["commentId"]); + + let _created_at = expect_rfc3339_timestamp(&resp1_json["createdAt"]); + + // Incorrect username + let mut resp2 = app + .borrow_mut() + .oneshot( + Request::post("/v1/post/d9ca2672-24c5-4442-b32f-cd717adffbaa/comment").json(json! { + { + "auth": { + "username": "bobbbbbb", + "password": "pro gamer 1990" + }, + "content": "lol bet ur still bad, 1v1 me" + } + }), + ) + .await + .unwrap(); + + assert_eq!(resp2.status(), StatusCode::UNPROCESSABLE_ENTITY); + + let resp2_json = response_json(&mut resp2).await; + assert_eq!(resp2_json["message"], "invalid username/password"); + + // Incorrect password + let mut resp3 = app + .borrow_mut() + .oneshot( + Request::post("/v1/post/d9ca2672-24c5-4442-b32f-cd717adffbaa/comment").json(json! { + { + "auth": { + "username": "bob", + "password": "pro gamer 1990" + }, + "content": "lol bet ur still bad, 1v1 me" + } + }), + ) + .await + .unwrap(); + + assert_eq!(resp3.status(), StatusCode::UNPROCESSABLE_ENTITY); + + let resp3_json = response_json(&mut resp3).await; + assert_eq!(resp3_json["message"], "invalid username/password"); +} + +#[sqlx::test(fixtures("users", "posts", "comments"))] +async fn test_list_comments(db: PgPool) { + let mut app = http::app(db); + + // This only has the happy path. + let mut resp = app + .borrow_mut() + .oneshot(Request::get("/v1/post/d9ca2672-24c5-4442-b32f-cd717adffbaa/comment").empty_body()) + .await + .unwrap(); + + assert_eq!(resp.status(), StatusCode::OK); + + let resp_json = response_json(&mut resp).await; + + let comments = resp_json + .as_array() + .expect("expected request to return an array"); + + assert_eq!(comments.len(), 2); + + assert_eq!(comments[0]["username"], "bob"); + assert_eq!(comments[0]["content"], "lol bet ur still bad, 1v1 me"); + + let _comment_id = expect_uuid(&comments[0]["commentId"]); + let created_at_0 = expect_rfc3339_timestamp(&comments[0]["createdAt"]); + + assert_eq!(comments[1]["username"], "alice"); + assert_eq!(comments[1]["content"], "you're on!"); + + let _comment_id = expect_uuid(&comments[1]["commentId"]); + let created_at_1 = expect_rfc3339_timestamp(&comments[1]["createdAt"]); + + assert!( + created_at_0 < created_at_1, + "comments must be assorted in ascending order" + ); + + let mut resp = app + .borrow_mut() + .oneshot(Request::get("/v1/post/7e3d4d16-a35e-46ba-8223-b4f1debbfbfe/comment").empty_body()) + .await + .unwrap(); + + assert_eq!(resp.status(), StatusCode::OK); + + let resp_json = response_json(&mut resp).await; + + let comments = resp_json + .as_array() + .expect("expected request to return an array"); + + assert_eq!(comments.len(), 1); + + assert_eq!(comments[0]["username"], "alice"); + assert_eq!(comments[0]["content"], "lol you're just mad you lost :P"); + + let _comment_id = expect_uuid(&comments[0]["commentId"]); + let _created_at = expect_rfc3339_timestamp(&comments[0]["createdAt"]); +} diff --git a/examples/postgres/axum-social-with-tests/tests/common.rs b/examples/postgres/axum-social-with-tests/tests/common.rs new file mode 100644 index 0000000000..41abd3af10 --- /dev/null +++ b/examples/postgres/axum-social-with-tests/tests/common.rs @@ -0,0 +1,72 @@ +// This is imported by different tests that use different functions. +#![allow(dead_code)] + +use axum::body::{Body, BoxBody, HttpBody}; +use axum::http::header::CONTENT_TYPE; +use axum::http::{request, Request}; +use axum::response::Response; +use time::format_description::well_known::Rfc3339; +use time::OffsetDateTime; +use uuid::Uuid; + +pub trait RequestBuilderExt { + fn json(self, json: serde_json::Value) -> Request; + + fn empty_body(self) -> Request; +} + +impl RequestBuilderExt for request::Builder { + fn json(self, json: serde_json::Value) -> Request { + self.header("Content-Type", "application/json") + .body(Body::from(json.to_string())) + .expect("failed to build request") + } + + fn empty_body(self) -> Request { + self.body(Body::empty()).expect("failed to build request") + } +} + +#[track_caller] +pub async fn response_json(resp: &mut Response) -> serde_json::Value { + assert_eq!( + resp.headers() + .get(CONTENT_TYPE) + .expect("expected Content-Type"), + "application/json" + ); + + let body = resp.body_mut(); + + let mut bytes = Vec::new(); + + while let Some(res) = body.data().await { + let chunk = res.expect("error reading response body"); + + bytes.extend_from_slice(&chunk[..]); + } + + serde_json::from_slice(&bytes).expect("failed to read response body as json") +} + +#[track_caller] +pub fn expect_string(value: &serde_json::Value) -> &str { + value + .as_str() + .unwrap_or_else(|| panic!("expected string, got {:?}", value)) +} + +#[track_caller] +pub fn expect_uuid(value: &serde_json::Value) -> Uuid { + expect_string(value) + .parse::() + .unwrap_or_else(|e| panic!("failed to parse UUID from {:?}: {}", value, e)) +} + +#[track_caller] +pub fn expect_rfc3339_timestamp(value: &serde_json::Value) -> OffsetDateTime { + let s = expect_string(value); + + OffsetDateTime::parse(s, &Rfc3339) + .unwrap_or_else(|e| panic!("failed to parse RFC-3339 timestamp from {:?}: {}", value, e)) +} diff --git a/examples/postgres/axum-social-with-tests/tests/fixtures/comments.sql b/examples/postgres/axum-social-with-tests/tests/fixtures/comments.sql new file mode 100644 index 0000000000..b53a09035c --- /dev/null +++ b/examples/postgres/axum-social-with-tests/tests/fixtures/comments.sql @@ -0,0 +1,12 @@ +INSERT INTO public.comment (comment_id, post_id, user_id, content, created_at) +VALUES + -- from: bob + ('3a86b8f8-827b-4f14-94a2-34517b4b5bde', 'd9ca2672-24c5-4442-b32f-cd717adffbaa', + 'c994b839-84f4-4509-ad49-59119133d6f5', 'lol bet ur still bad, 1v1 me', '2022-07-29 01:52:31.167673'), + -- from: alice + ('d6f862b5-2b87-4af4-b15e-6b3398729e6d', 'd9ca2672-24c5-4442-b32f-cd717adffbaa', + '51b374f1-93ae-4c5c-89dd-611bda8412ce', 'you''re on!', '2022-07-29 01:53:53.115782'), + -- from: alice + ('1eed85ae-adae-473c-8d05-b1dae0a1df63', '7e3d4d16-a35e-46ba-8223-b4f1debbfbfe', + '51b374f1-93ae-4c5c-89dd-611bda8412ce', 'lol you''re just mad you lost :P', '2022-07-29 01:55:50.116119'); + diff --git a/examples/postgres/axum-social-with-tests/tests/fixtures/posts.sql b/examples/postgres/axum-social-with-tests/tests/fixtures/posts.sql new file mode 100644 index 0000000000..8d875e553a --- /dev/null +++ b/examples/postgres/axum-social-with-tests/tests/fixtures/posts.sql @@ -0,0 +1,8 @@ +INSERT INTO public.post (post_id, user_id, content, created_at) +VALUES + -- from: alice + ('d9ca2672-24c5-4442-b32f-cd717adffbaa', '51b374f1-93ae-4c5c-89dd-611bda8412ce', + 'This new computer is blazing fast!', '2022-07-29 01:36:24.679082'), + -- from: bob + ('7e3d4d16-a35e-46ba-8223-b4f1debbfbfe', 'c994b839-84f4-4509-ad49-59119133d6f5', '@alice is a haxxor', + '2022-07-29 01:54:45.823523'); diff --git a/examples/postgres/axum-social-with-tests/tests/fixtures/users.sql b/examples/postgres/axum-social-with-tests/tests/fixtures/users.sql new file mode 100644 index 0000000000..5c29415e96 --- /dev/null +++ b/examples/postgres/axum-social-with-tests/tests/fixtures/users.sql @@ -0,0 +1,10 @@ +INSERT INTO public."user" (user_id, username, password_hash) +VALUES + -- username: "alice"; password: "rustacean since 2015" + ('51b374f1-93ae-4c5c-89dd-611bda8412ce', 'alice', + '$argon2id$v=19$m=4096,t=3,p=1$3v3ats/tYTXAYs3q9RycDw$ZltwjS3oQwPuNmL9f6DNb+sH5N81dTVZhVNbUQzmmVU'), + -- username: "bob"; password: "pro gamer 1990" + ('c994b839-84f4-4509-ad49-59119133d6f5', 'bob', + '$argon2id$v=19$m=4096,t=3,p=1$1zbkRinUH9WHzkyu8C1Vlg$70pu5Cca/s3d0nh5BYQGkN7+s9cqlNxTE7rFZaUaP4c'); + + diff --git a/examples/postgres/axum-social-with-tests/tests/post.rs b/examples/postgres/axum-social-with-tests/tests/post.rs new file mode 100644 index 0000000000..c23220acdc --- /dev/null +++ b/examples/postgres/axum-social-with-tests/tests/post.rs @@ -0,0 +1,120 @@ +use sqlx::PgPool; + +use sqlx_example_postgres_axum_social::http; + +use axum::http::{Request, StatusCode}; +use tower::ServiceExt; + +use std::borrow::BorrowMut; + +use common::{expect_rfc3339_timestamp, expect_uuid, response_json, RequestBuilderExt}; + +use serde_json::json; + +mod common; + +#[sqlx::test(fixtures("users"))] +async fn test_create_post(db: PgPool) { + let mut app = http::app(db); + + // Happy path! + let mut resp1 = app + .borrow_mut() + .oneshot(Request::post("/v1/post").json(json! { + { + "auth": { + "username": "alice", + "password": "rustacean since 2015" + }, + "content": "This new computer is blazing fast!" + } + })) + .await + .unwrap(); + + assert_eq!(resp1.status(), StatusCode::OK); + + let resp1_json = response_json(&mut resp1).await; + + assert_eq!(resp1_json["username"], "alice"); + assert_eq!(resp1_json["content"], "This new computer is blazing fast!"); + + let _post_id = expect_uuid(&resp1_json["postId"]); + let _created_at = expect_rfc3339_timestamp(&resp1_json["createdAt"]); + + // Incorrect username + let mut resp2 = app + .borrow_mut() + .oneshot(Request::post("/v1/post").json(json! { + { + "auth": { + "username": "aliceee", + "password": "rustacean since 2015" + }, + "content": "This new computer is blazing fast!" + } + })) + .await + .unwrap(); + + assert_eq!(resp2.status(), StatusCode::UNPROCESSABLE_ENTITY); + + let resp2_json = response_json(&mut resp2).await; + assert_eq!(resp2_json["message"], "invalid username/password"); + + // Incorrect password + let mut resp3 = app + .borrow_mut() + .oneshot(Request::post("/v1/post").json(json! { + { + "auth": { + "username": "alice", + "password": "rustaceansince2015" + }, + "content": "This new computer is blazing fast!" + } + })) + .await + .unwrap(); + + assert_eq!(resp3.status(), StatusCode::UNPROCESSABLE_ENTITY); + + let resp3_json = response_json(&mut resp3).await; + assert_eq!(resp3_json["message"], "invalid username/password"); +} + +#[sqlx::test(fixtures("users", "posts"))] +async fn test_list_posts(db: PgPool) { + // This only has the happy path. + let mut resp = http::app(db) + .oneshot(Request::get("/v1/post").empty_body()) + .await + .unwrap(); + + assert_eq!(resp.status(), StatusCode::OK); + + let resp_json = response_json(&mut resp).await; + + let posts = resp_json + .as_array() + .expect("expected GET /v1/post to return an array"); + + assert_eq!(posts.len(), 2); + + assert_eq!(posts[0]["username"], "bob"); + assert_eq!(posts[0]["content"], "@alice is a haxxor"); + + let _post_id = expect_uuid(&posts[0]["postId"]); + let created_at_0 = expect_rfc3339_timestamp(&posts[0]["createdAt"]); + + assert_eq!(posts[1]["username"], "alice"); + assert_eq!(posts[1]["content"], "This new computer is blazing fast!"); + + let _post_id = expect_uuid(&posts[1]["postId"]); + let created_at_1 = expect_rfc3339_timestamp(&posts[1]["createdAt"]); + + assert!( + created_at_0 > created_at_1, + "posts must be sorted in descending order" + ); +} diff --git a/examples/postgres/axum-social-with-tests/tests/user.rs b/examples/postgres/axum-social-with-tests/tests/user.rs new file mode 100644 index 0000000000..cfc642a050 --- /dev/null +++ b/examples/postgres/axum-social-with-tests/tests/user.rs @@ -0,0 +1,89 @@ +use sqlx::PgPool; + +use sqlx_example_postgres_axum_social::http; + +use axum::http::{Request, StatusCode}; +use tower::ServiceExt; + +use std::borrow::BorrowMut; + +use common::{response_json, RequestBuilderExt}; + +use serde_json::json; + +mod common; + +#[sqlx::test] +async fn test_create_user(db: PgPool) { + let mut app = http::app(db); + + // Happy path! + let resp1 = app + .borrow_mut() + // We handle JSON objects directly to sanity check the serialization and deserialization + .oneshot(Request::post("/v1/user").json(json! {{ + "username": "alice", + "password": "rustacean since 2015" + }})) + .await + .unwrap(); + + assert_eq!(resp1.status(), StatusCode::NO_CONTENT); + + // Username taken + let mut resp2 = app + .borrow_mut() + .oneshot(Request::post("/v1/user").json(json! {{ + "username": "alice", + "password": "uhhh i forgot" + }})) + .await + .unwrap(); + + assert_eq!(resp2.status(), StatusCode::CONFLICT); + + let resp2_json = response_json(&mut resp2).await; + assert_eq!(resp2_json["message"], "username taken"); + + // Invalid username + let mut resp3 = app + .borrow_mut() + .oneshot(Request::post("/v1/user").json(json! {{ + "username": "definitely an invalid username", + "password": "password" + }})) + .await + .unwrap(); + + assert_eq!(resp3.status(), StatusCode::UNPROCESSABLE_ENTITY); + + let resp3_json = response_json(&mut resp3).await; + + assert_eq!(resp3_json["message"], "validation error in request body"); + assert!( + resp3_json["errors"]["username"].is_array(), + "errors.username is not an array: {:?}", + resp3_json + ); + + // Invalid password + let mut resp4 = app + .borrow_mut() + .oneshot(Request::post("/v1/user").json(json! {{ + "username": "bobby123", + "password": "" + }})) + .await + .unwrap(); + + assert_eq!(resp4.status(), StatusCode::UNPROCESSABLE_ENTITY); + + let resp4_json = response_json(&mut resp4).await; + + assert_eq!(resp4_json["message"], "validation error in request body"); + assert!( + resp4_json["errors"]["password"].is_array(), + "errors.password is not an array: {:?}", + resp4_json + ); +} diff --git a/sqlx-bench/test.db b/sqlx-bench/test.db index f98f0ea542ede772b873f759fa3b28c6a53c22f2..6d62cec842bda223f62bd0d39180d5c843087d31 100644 GIT binary patch delta 2421 zcmXxj32c*f9LDi=eMdJsTqi{lP!6jcf(};Y@BpPCr-IaDfHJT!@Zu^cUgdvn!h{K{ z%sF++suczdwn~91o2+qBq8L~pVTm)cNR$W$8c0~ee4ZvQ`F*}`^Cs{6@6zAK=Ddx~ zc^~x8?UR<4b`10XU(U9X`L;d>3Ttz0_J(ZN*`A?{;>}OuZja*PIf$cJ&P6iKji-<- zbA2At(_Hlnk2I*z~eipgO{P`S`WBw>WT;{j=NUk}%0J+(mUWoKIC!a^$ z=9d?cKIW%I$SvlFLgdzvIkp(>YmP2KZZk(-MEaS70i?gFUy2Mc`<5ZMo83jo9pI{$m4@r+<4B8LiL0g^bas-$wHF$#)R1 z{^ecde*IGjc|fBxB9Xh$aKxP?5Z8$ucAL%Oa9^02OTAdl#1CNfdi zWFeDubx-6`UFAd`(-qmsWL@41@#(UgkSV$}2bro%T*%{E|SK4D!mU`rb_aVBBe$n%VqZ%WQFX?M_!WGyvWP) z%KgYI^3nsytMbBF;d(y*$X-6baTlNe zZG_MN?(;1ENRL+W1x-8ng68df!S`GEf@7Qb0!tUv=hdbc?7{V3%sxyVWn>z#7jKHO)Qek1f8UfBuI=z zNSJ66Bmv?lKH?=F;wBbJI(R}y>V&v-Bpj&|I&dpO!bFoG2@pT=5iju&H?c^v2k0bi zBtc>%Lc&CoAPEpZ@ewca5I3<%GJ_|~NSzRujzmW4gc-OMAz`9PkOYXI_=uNyh?`g> zX@O4CMiL}OA|y;S36cQu6Cd#s4{;NVB<(z*J#|7{IuiEO3GKKQAz`9PkOYXI_=uNy zh?`g>nGQNh8%dBDiI6bSBuE0pPkh8nJj6|`bQ~^e>vBfzzuR1K`>!@=HQPRPM=N7E zXJv$hi6%i3Ab#Q_Ug9BcV%c0Zl}St|rVW$8#4r&|7^5*kOaSA@_+0TycXpBUKh?I0 AB>(^b delta 2422 zcmXxj4Umj=0LSrrdG2;^@3_TA@)B8jNgNi*OK3}a@w$|kI4mlJT18&g|Lv|^T-@e5 zSDoBux8MbqcOFm6S|H8cXCf#1b96ECwK?KP>dfIK$N_U`DYD-jScZIM_PvHg z%-+|LedhBwkT1>dohy;OX8W7S7iQ}!WREFXjno*q2Kijye+&6c2k*U& zeyZ;TkZOHvEwWqx^A1v_ufL0YqW@fnRO&z0BfIqF_mGeE#SKVEpWld7=(C%Uo%-}< zWQYDOhgMVn<4KZ5~ph zYaEEt)%l3ll?8~>6>Siy%lRKHbXg&?MVC5}pjPdW&8nq6vPm^}KsKtoE@Xqc{Q&Zw zy4ev~uWoeWo32yW9z@<%S34u`s4EX4Yt^MLNI+faioC7Pbwl1#XSyS6)TxJ&)#_vq zWR*J66M0iLK7y=N4ZV;RD(*&>t7vcJ4ORar^17<)gS@69eUW9VwqFLGzf{%q$Cpb~ z^<#)%RSrNFtBS{wMXG!tvQU)`LW)&s5wbw3!N`2sG6Z>5HV;MS$-5q8uDm@AnImrw zM`p_#Bam0*wI`65<<*hMEP3TgWTw0{3jbOeK6zm@zMLV?jX|c%Gh>ly^3+qvRC#h7 zGDV&kk4%=06Ofl=!_&x%GVVoQkkMz5=Vkq~$Rt@e5qVBVovgUc@8Cm@T z;+2&zB2UYTmyqBDSw0yZFUvT?ak7*%d`c?LaI9$I49AFO&TzE2%NdRmw>iTn#ZAs| zq`1KuJ|V7gh9ksP&TzQ6!Wj+|mpDU@xWE|>73Vm^A>s^YI9QxofD{GA$zpVnIKdeX z6pftWlOObw}ZW+>7L^#7fqLwp!RMc>Wy+t)==oXcnVJ}g^ z89pM)Im4c!j5F*ZN;$)ah2jjmw*;AP&HZ54JA8#MxA;T&(BJ(2JO9N6JecA?ey1zz zVaH2@;REOSP}kX>u*2{Cj@$p*4z^27g3hK)SlBQUwmri8w>ixF7aZjM^Y`=qjxg_^ zwJiDB&1LZKb1$86I3q5K%OhQCEobjrFV|jkuDt5FK#CLy9pfcS}zc!`I&iHleyZ3iimBnc8DVG<(R?u^?5 z_Ljnk&5w^ZpFI=5>*Xa6aT6D@NZJNcBuNq^M#3aSGzkzt@ewca5I1oVi==ZwiX=&b z#7LNgh|c93=C-~eKH7Y_t#8On9^xi0Vv%$XNRcE-kQfP*5YZ$+{KQAR#6#S~MJ$rG zK#C+ug2YIegow8IhF0qv;-k%HwZ0)Qd5D|1h(*%bAVrcSL1H9KLPV1Q@e?2M5)W|` z7qLh>3#3SrBuI>eNr>nyzF}7D8{(tQm(}`)yyPKn;vyDF^Jkn1qNX0pce< z;w2v9CN5%SI&1A|oD@zHCxH{g3FCxtv@>cCv=r9b{LZKiw~5-l9Er2w-@^X@u?l=o diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 5b6dfa0582..ce0d4bcb58 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -175,6 +175,8 @@ indexmap = "1.6.0" hkdf = { version = "0.12.0", optional = true } event-listener = "2.5.2" +dotenvy = "0.15" + [dev-dependencies] sqlx = { version = "0.6.0", path = "..", features = ["postgres", "sqlite", "mysql"] } tokio = { version = "1", features = ["rt"] } diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index b1143bf97d..83e0203b93 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -103,6 +103,12 @@ pub mod mysql; #[cfg_attr(docsrs, doc(cfg(feature = "mssql")))] pub mod mssql; +// Implements test support with automatic DB management. +#[cfg(feature = "migrate")] +pub mod testing; + +pub use sqlx_rt::test_block_on; + /// sqlx uses ahash for increased performance, at the cost of reduced DoS resistance. use ahash::AHashMap as HashMap; //type HashMap = std::collections::HashMap; diff --git a/sqlx-core/src/migrate/migrator.rs b/sqlx-core/src/migrate/migrator.rs index 405d9232ff..f392e57a53 100644 --- a/sqlx-core/src/migrate/migrator.rs +++ b/sqlx-core/src/migrate/migrator.rs @@ -93,7 +93,15 @@ impl Migrator { ::Target: Migrate, { let mut conn = migrator.acquire().await?; + self.run_direct(&mut *conn).await + } + // Getting around the annoying "implementation of `Acquire` is not general enough" error + #[doc(hidden)] + pub async fn run_direct(&self, conn: &mut C) -> Result<(), MigrateError> + where + C: Migrate, + { // lock the database for exclusive access by the migrator conn.lock().await?; diff --git a/sqlx-core/src/mysql/connection/establish.rs b/sqlx-core/src/mysql/connection/establish.rs index 9e84cc4fc1..5352b1a10c 100644 --- a/sqlx-core/src/mysql/connection/establish.rs +++ b/sqlx-core/src/mysql/connection/establish.rs @@ -60,7 +60,33 @@ impl MySqlConnection { } // Upgrade to TLS if we were asked to and the server supports it - tls::maybe_upgrade(&mut stream, options).await?; + + #[cfg(feature = "_tls-rustls")] + { + // To aid in debugging: https://github.com/rustls/rustls/issues/893 + + let local_addr = stream.local_addr(); + + match tls::maybe_upgrade(&mut stream, options).await { + Ok(()) => (), + #[cfg(feature = "_tls-rustls")] + Err(Error::Io(ioe)) => { + if let Some(&rustls::Error::CorruptMessage) = + ioe.get_ref().and_then(|e| e.downcast_ref()) + { + log::trace!("got corrupt message on socket {:?}", local_addr); + } + + return Err(Error::Io(ioe)); + } + Err(e) => return Err(e), + } + } + + #[cfg(not(feature = "_tls-rustls"))] + { + tls::maybe_upgrade(&mut stream, options).await? + } let auth_response = if let (Some(plugin), Some(password)) = (plugin, &options.password) { Some(plugin.scramble(&mut stream, password, &nonce).await?) diff --git a/sqlx-core/src/mysql/mod.rs b/sqlx-core/src/mysql/mod.rs index e108e8591f..874b893bda 100644 --- a/sqlx-core/src/mysql/mod.rs +++ b/sqlx-core/src/mysql/mod.rs @@ -22,6 +22,9 @@ mod value; #[cfg(feature = "migrate")] mod migrate; +#[cfg(feature = "migrate")] +mod testing; + pub use arguments::MySqlArguments; pub use column::MySqlColumn; pub use connection::MySqlConnection; diff --git a/sqlx-core/src/mysql/testing/mod.rs b/sqlx-core/src/mysql/testing/mod.rs new file mode 100644 index 0000000000..1ea7e09d01 --- /dev/null +++ b/sqlx-core/src/mysql/testing/mod.rs @@ -0,0 +1,223 @@ +use std::fmt::Write; +use std::str::FromStr; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::time::{Duration}; + +use futures_core::future::BoxFuture; + +use once_cell::sync::{OnceCell}; + +use crate::connection::Connection; + +use crate::error::Error; +use crate::executor::Executor; +use crate::mysql::{MySql, MySqlConnectOptions, MySqlConnection}; +use crate::pool::{Pool, PoolOptions}; +use crate::query::query; +use crate::query_builder::QueryBuilder; +use crate::query_scalar::query_scalar; +use crate::testing::{FixtureSnapshot, TestArgs, TestContext, TestSupport}; + +// Using a blocking `OnceCell` here because the critical sections are short. +static MASTER_POOL: OnceCell> = OnceCell::new(); +// Automatically delete any databases created before the start of the test binary. +static DO_CLEANUP: AtomicBool = AtomicBool::new(true); + +impl TestSupport for MySql { + fn test_context(args: &TestArgs) -> BoxFuture<'_, Result, Error>> { + Box::pin(async move { + let res = test_context(args).await; + res + }) + } + + fn cleanup_test(db_name: &str) -> BoxFuture<'_, Result<(), Error>> { + Box::pin(async move { + let mut conn = MASTER_POOL + .get() + .expect("cleanup_test() invoked outside `#[sqlx::test]") + .acquire() + .await?; + + let db_id = db_id(db_name); + + conn.execute(&format!("drop database if exists {};", db_name)[..]) + .await?; + + query("delete from _sqlx_test_databases where db_id = ?") + .bind(&db_id) + .execute(&mut conn) + .await?; + + Ok(()) + }) + } + + fn cleanup_test_dbs() -> BoxFuture<'static, Result, Error>> { + Box::pin(async move { + let url = dotenvy::var("DATABASE_URL").expect("DATABASE_URL must be set"); + + let mut conn = MySqlConnection::connect(&url).await?; + let num_deleted = do_cleanup(&mut conn).await?; + let _ = conn.close().await; + Ok(Some(num_deleted)) + }) + } + + fn snapshot( + _conn: &mut Self::Connection, + ) -> BoxFuture<'_, Result, Error>> { + // TODO: I want to get the testing feature out the door so this will have to wait, + // but I'm keeping the code around for now because I plan to come back to it. + todo!() + } +} + +async fn test_context(args: &TestArgs) -> Result, Error> { + let url = dotenvy::var("DATABASE_URL").expect("DATABASE_URL must be set"); + + let master_opts = MySqlConnectOptions::from_str(&url).expect("failed to parse DATABASE_URL"); + + let pool = PoolOptions::new() + // MySql's normal connection limit is 150 plus 1 superuser connection + // We don't want to use the whole cap and there may be fuzziness here due to + // concurrently running tests anyway. + .max_connections(20) + // Immediately close master connections. Tokio's I/O streams don't like hopping runtimes. + .after_release(|_conn, _| Box::pin(async move { Ok(false) })) + .connect_lazy_with(master_opts); + + let master_pool = match MASTER_POOL.try_insert(pool) { + Ok(inserted) => inserted, + Err((existing, pool)) => { + // Sanity checks. + assert_eq!( + existing.connect_options().host, + pool.connect_options().host, + "DATABASE_URL changed at runtime, host differs" + ); + + assert_eq!( + existing.connect_options().database, + pool.connect_options().database, + "DATABASE_URL changed at runtime, database differs" + ); + + existing + } + }; + + let mut conn = master_pool.acquire().await?; + + // language=MySQL + conn.execute( + r#" + create table if not exists _sqlx_test_databases ( + db_id bigint unsigned primary key auto_increment, + test_path text not null, + created_at timestamp not null default current_timestamp + ); + "#, + ) + .await?; + + // Only run cleanup if the test binary just started. + if DO_CLEANUP.swap(false, Ordering::SeqCst) { + do_cleanup(&mut conn).await?; + } + + query("insert into _sqlx_test_databases(test_path) values (?)") + .bind(&args.test_path) + .execute(&mut conn) + .await?; + + // MySQL doesn't have `INSERT ... RETURNING` + let new_db_id: u64 = query_scalar("select last_insert_id()") + .fetch_one(&mut conn) + .await?; + + let new_db_name = db_name(new_db_id); + + conn.execute(&format!("create database {}", new_db_name)[..]) + .await?; + + eprintln!("created database {}", new_db_name); + + Ok(TestContext { + pool_opts: PoolOptions::new() + // Don't allow a single test to take all the connections. + // Most tests shouldn't require more than 5 connections concurrently, + // or else they're likely doing too much in one test. + .max_connections(5) + // Close connections ASAP if left in the idle queue. + .idle_timeout(Some(Duration::from_secs(1))) + .parent(master_pool.clone()), + connect_opts: master_pool.connect_options().clone().database(&new_db_name), + db_name: new_db_name, + }) +} + +async fn do_cleanup(conn: &mut MySqlConnection) -> Result { + let delete_db_ids: Vec = query_scalar( + "select db_id from _sqlx_test_databases where created_at < current_timestamp()", + ) + .fetch_all(&mut *conn) + .await?; + + if delete_db_ids.is_empty() { + return Ok(0); + } + + let mut deleted_db_ids = Vec::with_capacity(delete_db_ids.len()); + + let mut command = String::new(); + + for db_id in delete_db_ids { + command.clear(); + + let db_name = db_name(db_id); + + writeln!(command, "drop database if exists {}", db_name).ok(); + match conn.execute(&*command).await { + Ok(_deleted) => { + deleted_db_ids.push(db_id); + } + // Assume a database error just means the DB is still in use. + Err(Error::Database(dbe)) => { + eprintln!("could not clean test database {:?}: {}", db_id, dbe) + } + // Bubble up other errors + Err(e) => return Err(e), + } + } + + let mut query = QueryBuilder::new("delete from _sqlx_test_databases where db_id in ("); + + let mut separated = query.separated(","); + + for db_id in &deleted_db_ids { + separated.push_bind(db_id); + } + + drop(separated); + + query.push(")").build().execute(&mut *conn).await?; + + Ok(deleted_db_ids.len()) +} + +fn db_name(id: u64) -> String { + format!("_sqlx_test_database_{}", id) +} + +fn db_id(name: &str) -> u64 { + name.trim_start_matches("_sqlx_test_database_") + .parse() + .unwrap_or_else(|_1| panic!("failed to parse ID from database name {:?}", name)) +} + +#[test] +fn test_db_name_id() { + assert_eq!(db_name(12345), "_sqlx_test_database_12345"); + assert_eq!(db_id("_sqlx_test_database_12345"), 12345); +} diff --git a/sqlx-core/src/net/socket.rs b/sqlx-core/src/net/socket.rs index df345147ff..622a1a22ce 100644 --- a/sqlx-core/src/net/socket.rs +++ b/sqlx-core/src/net/socket.rs @@ -1,6 +1,7 @@ #![allow(dead_code)] use std::io; +use std::net::SocketAddr; use std::path::Path; use std::pin::Pin; use std::task::{Context, Poll}; @@ -30,6 +31,14 @@ impl Socket { .map(Socket::Unix) } + pub fn local_addr(&self) -> Option { + match self { + Self::Tcp(tcp) => tcp.local_addr().ok(), + #[cfg(unix)] + Self::Unix(_) => None, + } + } + #[cfg(not(unix))] pub async fn connect_uds(_: impl AsRef) -> io::Result { Err(io::Error::new( diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index 86c0ec5d8c..7bfae7fc78 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -3,7 +3,7 @@ use crate::connection::ConnectOptions; use crate::connection::Connection; use crate::database::Database; use crate::error::Error; -use crate::pool::{deadline_as_timeout, CloseEvent, PoolOptions}; +use crate::pool::{deadline_as_timeout, CloseEvent, Pool, PoolOptions}; use crossbeam_queue::ArrayQueue; use futures_intrusive::sync::{Semaphore, SemaphoreReleaser}; @@ -12,16 +12,13 @@ use std::cmp; use std::future::Future; use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering}; use std::sync::Arc; +use std::task::Poll; use crate::pool::options::PoolConnectionMetadata; +use futures_util::future::{self}; +use futures_util::FutureExt; use std::time::{Duration, Instant}; -/// Ihe number of permits to release to wake all waiters, such as on `PoolInner::close()`. -/// -/// This should be large enough to realistically wake all tasks waiting on the pool without -/// potentially overflowing the permits count in the semaphore itself. -const WAKE_ALL_PERMITS: usize = usize::MAX / 2; - pub(crate) struct PoolInner { pub(super) connect_options: ::Options, pub(super) idle_conns: ArrayQueue>, @@ -40,16 +37,19 @@ impl PoolInner { ) -> Arc { let capacity = options.max_connections as usize; - // ensure the permit count won't overflow if we release `WAKE_ALL_PERMITS` - // this assert should never fire on 64-bit targets as `max_connections` is a u32 - let _ = capacity - .checked_add(WAKE_ALL_PERMITS) - .expect("max_connections exceeds max capacity of the pool"); + let semaphore_capacity = if let Some(parent) = &options.parent_pool { + assert!(options.max_connections <= parent.options().max_connections); + assert_eq!(options.fair, parent.options().fair); + // The child pool must steal permits from the parent + 0 + } else { + capacity + }; let pool = Self { connect_options, idle_conns: ArrayQueue::new(capacity), - semaphore: Semaphore::new(options.fair, capacity), + semaphore: Semaphore::new(options.fair, semaphore_capacity), size: AtomicU32::new(0), num_idle: AtomicUsize::new(0), is_closed: AtomicBool::new(false), @@ -82,31 +82,22 @@ impl PoolInner { } pub(super) fn close<'a>(self: &'a Arc) -> impl Future + 'a { - let already_closed = self.is_closed.swap(true, Ordering::AcqRel); - - if !already_closed { - // if we were the one to mark this closed, release enough permits to wake all waiters - // we can't just do `usize::MAX` because that would overflow - // and we can't do this more than once cause that would _also_ overflow - self.semaphore.release(WAKE_ALL_PERMITS); - self.on_closed.notify(usize::MAX); - } + self.is_closed.store(true, Ordering::Release); + self.on_closed.notify(usize::MAX); async move { - // Close any currently idle connections in the pool. - while let Some(idle) = self.idle_conns.pop() { - let _ = idle.live.float((*self).clone()).close().await; - } + for permits in 1..=self.options.max_connections as usize { + // Close any currently idle connections in the pool. + while let Some(idle) = self.idle_conns.pop() { + let _ = idle.live.float((*self).clone()).close().await; + } - // Wait for all permits to be released. - let _permits = self - .semaphore - .acquire(WAKE_ALL_PERMITS + (self.options.max_connections as usize)) - .await; + if self.size() == 0 { + break; + } - // Clean up any remaining connections. - while let Some(idle) = self.idle_conns.pop() { - let _ = idle.live.float((*self).clone()).close().await; + // Wait for all permits to be released. + let _permits = self.semaphore.acquire(permits).await; } } } @@ -117,6 +108,67 @@ impl PoolInner { } } + /// Attempt to pull a permit from `self.semaphore` or steal one from the parent. + /// + /// If we steal a permit from the parent but *don't* open a connection, + /// it should be returned to the parent. + async fn acquire_permit<'a>(self: &'a Arc) -> Result, Error> { + let parent = self + .parent() + // If we're already at the max size, we shouldn't try to steal from the parent. + // This is just going to cause unnecessary churn in `acquire()`. + .filter(|_| self.size() < self.options.max_connections); + + let acquire_self = self.semaphore.acquire(1).fuse(); + let mut close_event = self.close_event(); + + if let Some(parent) = parent { + let acquire_parent = parent.0.semaphore.acquire(1); + let parent_close_event = parent.0.close_event(); + + futures_util::pin_mut!( + acquire_parent, + acquire_self, + close_event, + parent_close_event + ); + + let mut poll_parent = false; + + future::poll_fn(|cx| { + if close_event.as_mut().poll(cx).is_ready() { + return Poll::Ready(Err(Error::PoolClosed)); + } + + if parent_close_event.as_mut().poll(cx).is_ready() { + // Propagate the parent's close event to the child. + let _ = self.close(); + return Poll::Ready(Err(Error::PoolClosed)); + } + + if let Poll::Ready(permit) = acquire_self.as_mut().poll(cx) { + return Poll::Ready(Ok(permit)); + } + + // Don't try the parent right away. + if poll_parent { + acquire_parent.as_mut().poll(cx).map(Ok) + } else { + poll_parent = true; + cx.waker().wake_by_ref(); + Poll::Pending + } + }) + .await + } else { + close_event.do_until(acquire_self).await + } + } + + fn parent(&self) -> Option<&Pool> { + self.options.parent_pool.as_ref() + } + #[inline] pub(super) fn try_acquire(self: &Arc) -> Option>> { if self.is_closed() { @@ -124,6 +176,7 @@ impl PoolInner { } let permit = self.semaphore.try_acquire(1)?; + self.pop_idle(permit).ok() } @@ -184,11 +237,9 @@ impl PoolInner { self.options.acquire_timeout, async { loop { - let permit = self.semaphore.acquire(1).await; + // Handles the close-event internally + let permit = self.acquire_permit().await?; - if self.is_closed() { - return Err(Error::PoolClosed); - } // First attempt to pop a connection from the idle queue. let guard = match self.pop_idle(permit) { @@ -207,7 +258,12 @@ impl PoolInner { // we can open a new connection guard } else { + // This can happen for a child pool that's at its connection limit. log::debug!("woke but was unable to acquire idle connection or open new one; retrying"); + // If so, we're likely in the current-thread runtime if it's Tokio + // and so we should yield to let any spawned release_to_pool() tasks + // execute. + sqlx_rt::yield_now().await; continue; } }; @@ -334,6 +390,15 @@ impl PoolInner { } } +impl Drop for PoolInner { + fn drop(&mut self) { + if let Some(parent) = &self.options.parent_pool { + // Release the stolen permits. + parent.0.semaphore.release(self.semaphore.permits()); + } + } +} + /// Returns `true` if the connection has exceeded `options.max_lifetime` if set, `false` otherwise. fn is_beyond_max_lifetime(live: &Live, options: &PoolOptions) -> bool { options @@ -486,6 +551,8 @@ impl DecrementSizeGuard { } /// Release the semaphore permit without decreasing the pool size. + /// + /// If the permit was stolen from the pool's parent, it will be returned to the child's semaphore. fn release_permit(self) { self.pool.semaphore.release(1); self.cancel(); diff --git a/sqlx-core/src/pool/options.rs b/sqlx-core/src/pool/options.rs index 6cd6282bf0..18f8106886 100644 --- a/sqlx-core/src/pool/options.rs +++ b/sqlx-core/src/pool/options.rs @@ -80,6 +80,8 @@ pub struct PoolOptions { pub(crate) max_lifetime: Option, pub(crate) idle_timeout: Option, pub(crate) fair: bool, + + pub(crate) parent_pool: Option>, } /// Metadata for the connection being processed by a [`PoolOptions`] callback. @@ -125,6 +127,7 @@ impl PoolOptions { idle_timeout: Some(Duration::from_secs(10 * 60)), max_lifetime: Some(Duration::from_secs(30 * 60)), fair: true, + parent_pool: None, } } @@ -400,6 +403,19 @@ impl PoolOptions { self } + /// Set the parent `Pool` from which the new pool will inherit its semaphore. + /// + /// This is currently an internal-only API. + /// + /// ### Panics + /// If `self.max_connections` is greater than the setting the given pool was created with, + /// or `self.fair` differs from the setting the given pool was created with. + #[doc(hidden)] + pub fn parent(mut self, pool: Pool) -> Self { + self.parent_pool = Some(pool); + self + } + /// Create a new pool from this `PoolOptions` and immediately open at least one connection. /// /// This ensures the configuration is correct. diff --git a/sqlx-core/src/postgres/mod.rs b/sqlx-core/src/postgres/mod.rs index 0f6bd7c0be..00abc9c967 100644 --- a/sqlx-core/src/postgres/mod.rs +++ b/sqlx-core/src/postgres/mod.rs @@ -24,6 +24,9 @@ mod value; #[cfg(feature = "migrate")] mod migrate; +#[cfg(feature = "migrate")] +mod testing; + pub use advisory_lock::{PgAdvisoryLock, PgAdvisoryLockGuard, PgAdvisoryLockKey}; pub use arguments::{PgArgumentBuffer, PgArguments}; pub use column::PgColumn; diff --git a/sqlx-core/src/postgres/testing/mod.rs b/sqlx-core/src/postgres/testing/mod.rs new file mode 100644 index 0000000000..254312e669 --- /dev/null +++ b/sqlx-core/src/postgres/testing/mod.rs @@ -0,0 +1,204 @@ +use std::fmt::Write; +use std::str::FromStr; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::time::{Duration}; + +use futures_core::future::BoxFuture; + +use once_cell::sync::{OnceCell}; + +use crate::connection::Connection; + +use crate::error::Error; +use crate::executor::Executor; +use crate::pool::{Pool, PoolOptions}; +use crate::postgres::{PgConnectOptions, PgConnection, Postgres}; +use crate::query::query; +use crate::query_scalar::query_scalar; +use crate::testing::{FixtureSnapshot, TestArgs, TestContext, TestSupport}; + +// Using a blocking `OnceCell` here because the critical sections are short. +static MASTER_POOL: OnceCell> = OnceCell::new(); +// Automatically delete any databases created before the start of the test binary. +static DO_CLEANUP: AtomicBool = AtomicBool::new(true); + +impl TestSupport for Postgres { + fn test_context(args: &TestArgs) -> BoxFuture<'_, Result, Error>> { + Box::pin(async move { + let res = test_context(args).await; + res + }) + } + + fn cleanup_test(db_name: &str) -> BoxFuture<'_, Result<(), Error>> { + Box::pin(async move { + let mut conn = MASTER_POOL + .get() + .expect("cleanup_test() invoked outside `#[sqlx::test]") + .acquire() + .await?; + + conn.execute(&format!("drop database if exists {0:?};", db_name)[..]) + .await?; + + query("delete from _sqlx_test.databases where db_name = $1") + .bind(&db_name) + .execute(&mut conn) + .await?; + + Ok(()) + }) + } + + fn cleanup_test_dbs() -> BoxFuture<'static, Result, Error>> { + Box::pin(async move { + let url = dotenvy::var("DATABASE_URL").expect("DATABASE_URL must be set"); + + let mut conn = PgConnection::connect(&url).await?; + let num_deleted = do_cleanup(&mut conn).await?; + let _ = conn.close().await; + Ok(Some(num_deleted)) + }) + } + + fn snapshot( + _conn: &mut Self::Connection, + ) -> BoxFuture<'_, Result, Error>> { + // TODO: I want to get the testing feature out the door so this will have to wait, + // but I'm keeping the code around for now because I plan to come back to it. + todo!() + } +} + +async fn test_context(args: &TestArgs) -> Result, Error> { + let url = dotenvy::var("DATABASE_URL").expect("DATABASE_URL must be set"); + + let master_opts = PgConnectOptions::from_str(&url).expect("failed to parse DATABASE_URL"); + + let pool = PoolOptions::new() + // Postgres' normal connection limit is 100 plus 3 superuser connections + // We don't want to use the whole cap and there may be fuzziness here due to + // concurrently running tests anyway. + .max_connections(20) + // Immediately close master connections. Tokio's I/O streams don't like hopping runtimes. + .after_release(|_conn, _| Box::pin(async move { Ok(false) })) + .connect_lazy_with(master_opts); + + let master_pool = match MASTER_POOL.try_insert(pool) { + Ok(inserted) => inserted, + Err((existing, pool)) => { + // Sanity checks. + assert_eq!( + existing.connect_options().host, + pool.connect_options().host, + "DATABASE_URL changed at runtime, host differs" + ); + + assert_eq!( + existing.connect_options().database, + pool.connect_options().database, + "DATABASE_URL changed at runtime, database differs" + ); + + existing + } + }; + + let mut conn = master_pool.acquire().await?; + + // language=PostgreSQL + conn.execute( + // Explicit lock avoids this latent bug: https://stackoverflow.com/a/29908840 + // I couldn't find a bug on the mailing list for `CREATE SCHEMA` specifically, + // but a clearly related bug with `CREATE TABLE` has been known since 2007: + // https://www.postgresql.org/message-id/200710222037.l9MKbCJZ098744%40wwwmaster.postgresql.org + r#" + lock table pg_catalog.pg_namespace in share row exclusive mode; + + create schema if not exists _sqlx_test; + + create table if not exists _sqlx_test.databases ( + db_name text primary key, + test_path text not null, + created_at timestamptz not null default now() + ); + + create index if not exists databases_created_at + on _sqlx_test.databases(created_at); + + create sequence if not exists _sqlx_test.database_ids; + "#, + ) + .await?; + + // Only run cleanup if the test binary just started. + if DO_CLEANUP.swap(false, Ordering::SeqCst) { + do_cleanup(&mut conn).await?; + } + + let new_db_name: String = query_scalar( + r#" + insert into _sqlx_test.databases(db_name, test_path) + select '_sqlx_test_' || nextval('_sqlx_test.database_ids'), $1 + returning db_name + "#, + ) + .bind(&args.test_path) + .fetch_one(&mut conn) + .await?; + + conn.execute(&format!("create database {:?}", new_db_name)[..]) + .await?; + + Ok(TestContext { + pool_opts: PoolOptions::new() + // Don't allow a single test to take all the connections. + // Most tests shouldn't require more than 5 connections concurrently, + // or else they're likely doing too much in one test. + .max_connections(5) + // Close connections ASAP if left in the idle queue. + .idle_timeout(Some(Duration::from_secs(1))) + .parent(master_pool.clone()), + connect_opts: master_pool.connect_options().clone().database(&new_db_name), + db_name: new_db_name, + }) +} + +async fn do_cleanup(conn: &mut PgConnection) -> Result { + let delete_db_names: Vec = + query_scalar("select db_name from _sqlx_test.databases where created_at < now()") + .fetch_all(&mut *conn) + .await?; + + if delete_db_names.is_empty() { + return Ok(0); + } + + let mut deleted_db_names = Vec::with_capacity(delete_db_names.len()); + let delete_db_names = delete_db_names.into_iter(); + + let mut command = String::new(); + + for db_name in delete_db_names { + command.clear(); + writeln!(command, "drop database if exists {:?};", db_name).ok(); + match conn.execute(&*command).await { + Ok(_deleted) => { + deleted_db_names.push(db_name); + } + // Assume a database error just means the DB is still in use. + Err(Error::Database(dbe)) => { + eprintln!("could not clean test database {:?}: {}", db_name, dbe) + } + // Bubble up other errors + Err(e) => return Err(e), + } + } + + query("delete from _sqlx_test.databases where db_name = any($1::text[])") + .bind(&deleted_db_names) + .execute(&mut *conn) + .await?; + + Ok(deleted_db_names.len()) +} diff --git a/sqlx-core/src/sqlite/mod.rs b/sqlx-core/src/sqlite/mod.rs index 291394204f..d12c9e5f02 100644 --- a/sqlx-core/src/sqlite/mod.rs +++ b/sqlx-core/src/sqlite/mod.rs @@ -43,6 +43,9 @@ mod value; #[cfg(feature = "migrate")] mod migrate; +#[cfg(feature = "migrate")] +mod testing; + /// An alias for [`Pool`][crate::pool::Pool], specialized for SQLite. pub type SqlitePool = crate::pool::Pool; diff --git a/sqlx-core/src/sqlite/testing/mod.rs b/sqlx-core/src/sqlite/testing/mod.rs new file mode 100644 index 0000000000..f3e48e6b7c --- /dev/null +++ b/sqlx-core/src/sqlite/testing/mod.rs @@ -0,0 +1,81 @@ +use crate::error::Error; +use crate::pool::PoolOptions; +use crate::sqlite::{Sqlite, SqliteConnectOptions}; +use crate::testing::{FixtureSnapshot, TestArgs, TestContext, TestSupport}; +use futures_core::future::BoxFuture; +use std::path::{Path, PathBuf}; + +const BASE_PATH: &str = "target/sqlx/test-dbs"; + +impl TestSupport for Sqlite { + fn test_context(args: &TestArgs) -> BoxFuture<'_, Result, Error>> { + Box::pin(async move { + let res = test_context(args).await; + res + }) + } + + fn cleanup_test(db_name: &str) -> BoxFuture<'_, Result<(), Error>> { + Box::pin(async move { Ok(sqlx_rt::fs::remove_file(db_name).await?) }) + } + + fn cleanup_test_dbs() -> BoxFuture<'static, Result, Error>> { + Box::pin(async move { + sqlx_rt::fs::remove_dir_all(BASE_PATH).await?; + Ok(None) + }) + } + + fn snapshot( + _conn: &mut Self::Connection, + ) -> BoxFuture<'_, Result, Error>> { + todo!() + } +} + +async fn test_context(args: &TestArgs) -> Result, Error> { + let db_path = convert_path(args.test_path); + + if let Some(parent_path) = Path::parent(db_path.as_ref()) { + sqlx_rt::fs::create_dir_all(parent_path) + .await + .expect("failed to create folders"); + } + + if Path::exists(db_path.as_ref()) { + sqlx_rt::fs::remove_file(&db_path) + .await + .expect("failed to remove database from previous test run"); + } + + Ok(TestContext { + connect_opts: SqliteConnectOptions::new() + .filename(&db_path) + .create_if_missing(true), + // This doesn't really matter for SQLite as the databases are independent of each other. + // The main limitation is going to be the number of concurrent running tests. + pool_opts: PoolOptions::new().max_connections(1000), + db_name: db_path, + }) +} + +fn convert_path(test_path: &str) -> String { + let mut path = PathBuf::from(BASE_PATH); + + for segment in test_path.split("::") { + path.push(segment); + } + + path.set_extension("sqlite"); + + path.into_os_string() + .into_string() + .expect("path should be UTF-8") +} + +#[test] +fn test_convert_path() { + let path = convert_path("foo::bar::baz::quux"); + + assert_eq!(path, "target/sqlx/test-dbs/foo/bar/baz/quux.sqlite"); +} diff --git a/sqlx-core/src/testing/fixtures.rs b/sqlx-core/src/testing/fixtures.rs new file mode 100644 index 0000000000..2cb7043eb0 --- /dev/null +++ b/sqlx-core/src/testing/fixtures.rs @@ -0,0 +1,280 @@ +//! TODO: automatic test fixture capture + +use crate::database::{Database, HasArguments}; + +use crate::query_builder::QueryBuilder; + +use indexmap::set::IndexSet; +use std::cmp; +use std::collections::{BTreeMap, HashMap}; +use std::marker::PhantomData; +use std::sync::Arc; + +pub type Result = std::result::Result; + +/// A snapshot of the current state of the database. +/// +/// Can be used to generate an `INSERT` fixture for populating an empty database, +/// or in the future it may be possible to generate a fixture from the difference between +/// two snapshots. +pub struct FixtureSnapshot { + tables: BTreeMap, + db: PhantomData, +} + +#[derive(Debug, thiserror::Error)] +#[error("could not create fixture: {0}")] +pub struct FixtureError(String); + +pub struct Fixture { + ops: Vec, + db: PhantomData, +} + +enum FixtureOp { + Insert { + table: TableName, + columns: Vec, + rows: Vec>, + }, + // TODO: handle updates and deletes by diffing two snapshots +} + +type TableName = Arc; +type ColumnName = Arc; +type Value = String; + +struct Table { + name: TableName, + columns: IndexSet, + rows: Vec>, + foreign_keys: HashMap, +} + +macro_rules! fixture_assert ( + ($cond:expr, $msg:literal $($arg:tt)*) => { + if !($cond) { + return Err(FixtureError(format!($msg $($arg)*))) + } + } +); + +impl FixtureSnapshot { + /// Generate a fixture to reproduce this snapshot from an empty database using `INSERT`s. + /// + /// Note that this doesn't take into account any triggers that might modify the data before + /// it's stored. + /// + /// The `INSERT` statements are ordered on a best-effort basis to satisfy any foreign key + /// constraints (data from tables with no foreign keys are inserted first, then the tables + /// that reference those tables, and so on). + /// + /// If a cycle in foreign-key constraints is detected, this returns with an error. + pub fn additive_fixture(&self) -> Result> { + let visit_order = self.calculate_visit_order()?; + + let mut ops = Vec::new(); + + for table_name in visit_order { + let table = self.tables.get(&table_name).unwrap(); + + ops.push(FixtureOp::Insert { + table: table_name, + columns: table.columns.iter().cloned().collect(), + rows: table.rows.clone(), + }); + } + + Ok(Fixture { ops, db: self.db }) + } + + /// Determine an order for outputting `INSERTS` for each table by calculating the max + /// length of all its foreign key chains. + /// + /// This should hopefully ensure that there are no foreign-key errors. + fn calculate_visit_order(&self) -> Result> { + let mut table_depths = HashMap::with_capacity(self.tables.len()); + let mut visited_set = IndexSet::with_capacity(self.tables.len()); + + for table in self.tables.values() { + foreign_key_depth(&self.tables, table, &mut table_depths, &mut visited_set)?; + visited_set.clear(); + } + + let mut table_names: Vec = table_depths.keys().cloned().collect(); + table_names.sort_by_key(|name| table_depths.get(name).unwrap()); + Ok(table_names) + } +} + +/// Implements `ToString` but not `Display` because it uses [`QueryBuilder`] internally, +/// which appends to an internal string. +impl ToString for Fixture +where + for<'a> >::Arguments: Default, +{ + fn to_string(&self) -> String { + let mut query = QueryBuilder::::new(""); + + for op in &self.ops { + match op { + FixtureOp::Insert { + table, + columns, + rows, + } => { + // Sanity check, empty tables shouldn't appear in snapshots anyway. + if columns.is_empty() || rows.is_empty() { + continue; + } + + query.push(format_args!("INSERT INTO {} (", table)); + + let mut separated = query.separated(", "); + + for column in columns { + separated.push(column); + } + + query.push(")\n"); + + query.push_values(rows, |mut separated, row| { + for value in row { + separated.push(value); + } + }); + + query.push(";\n"); + } + } + } + + query.into_sql() + } +} + +fn foreign_key_depth( + tables: &BTreeMap, + table: &Table, + depths: &mut HashMap, + visited_set: &mut IndexSet, +) -> Result { + if let Some(&depth) = depths.get(&table.name) { + return Ok(depth); + } + + // This keeps us from looping forever. + fixture_assert!( + visited_set.insert(table.name.clone()), + "foreign key cycle detected: {:?} -> {:?}", + visited_set, + table.name + ); + + let mut refdepth = 0; + + for (colname, (refname, refcol)) in &table.foreign_keys { + let referenced = tables.get(refname).ok_or_else(|| { + FixtureError(format!( + "table {:?} in foreign key `{}.{} references {}.{}` does not exist", + refname, table.name, colname, refname, refcol + )) + })?; + + refdepth = cmp::max( + refdepth, + foreign_key_depth(tables, referenced, depths, visited_set)?, + ); + } + + let depth = refdepth + 1; + + depths.insert(table.name.clone(), depth); + + Ok(depth) +} + +#[test] +#[cfg(feature = "postgres")] +fn test_additive_fixture() -> Result<()> { + use crate::postgres::Postgres; + + let mut snapshot = FixtureSnapshot { + tables: BTreeMap::new(), + db: PhantomData::, + }; + + snapshot.tables.insert( + "foo".into(), + Table { + name: "foo".into(), + columns: ["foo_id", "foo_a", "foo_b"] + .into_iter() + .map(Arc::::from) + .collect(), + rows: vec![vec!["1".into(), "'asdf'".into(), "true".into()]], + foreign_keys: HashMap::new(), + }, + ); + + // foreign-keyed to `foo` + // since `tables` is a `BTreeMap` we would expect a naive algorithm to visit this first. + snapshot.tables.insert( + "bar".into(), + Table { + name: "bar".into(), + columns: ["bar_id", "foo_id", "bar_a", "bar_b"] + .into_iter() + .map(Arc::::from) + .collect(), + rows: vec![vec![ + "1234".into(), + "1".into(), + "'2022-07-22 23:27:48.775113301+00:00'".into(), + "3.14".into(), + ]], + foreign_keys: [("foo_id".into(), ("foo".into(), "foo_id".into()))] + .into_iter() + .collect(), + }, + ); + + // foreign-keyed to both `foo` and `bar` + snapshot.tables.insert( + "baz".into(), + Table { + name: "baz".into(), + columns: ["baz_id", "bar_id", "foo_id", "baz_a", "baz_b"] + .into_iter() + .map(Arc::::from) + .collect(), + rows: vec![vec![ + "5678".into(), + "1234".into(), + "1".into(), + "'2022-07-22 23:27:48.775113301+00:00'".into(), + "3.14".into(), + ]], + foreign_keys: [ + ("foo_id".into(), ("foo".into(), "foo_id".into())), + ("bar_id".into(), ("bar".into(), "bar_id".into())), + ] + .into_iter() + .collect(), + }, + ); + + let fixture = snapshot.additive_fixture()?; + + assert_eq!( + fixture.to_string(), + "INSERT INTO foo (foo_id, foo_a, foo_b)\n\ + VALUES (1, 'asdf', true);\n\ + INSERT INTO bar (bar_id, foo_id, bar_a, bar_b)\n\ + VALUES (1234, 1, '2022-07-22 23:27:48.775113301+00:00', 3.14);\n\ + INSERT INTO baz (baz_id, bar_id, foo_id, baz_a, baz_b)\n\ + VALUES (5678, 1234, 1, '2022-07-22 23:27:48.775113301+00:00', 3.14);\n" + ); + + Ok(()) +} diff --git a/sqlx-core/src/testing/mod.rs b/sqlx-core/src/testing/mod.rs new file mode 100644 index 0000000000..5183c914a9 --- /dev/null +++ b/sqlx-core/src/testing/mod.rs @@ -0,0 +1,262 @@ +use std::future::Future; +use std::time::Duration; + +use futures_core::future::BoxFuture; + +pub use fixtures::FixtureSnapshot; +use sqlx_rt::test_block_on; + +use crate::connection::{ConnectOptions, Connection}; +use crate::database::Database; +use crate::error::Error; +use crate::executor::Executor; +use crate::migrate::{Migrate, Migrator}; +use crate::pool::{Pool, PoolConnection, PoolOptions}; + +mod fixtures; + +pub trait TestSupport: Database { + /// Get parameters to construct a `Pool` suitable for testing. + /// + /// This `Pool` instance will behave somewhat specially: + /// * all handles share a single global semaphore to avoid exceeding the connection limit + /// on the database server. + /// * each invocation results in a different temporary database. + /// + /// The implementation may require `DATABASE_URL` to be set in order to manage databases. + /// The user credentials it contains must have the privilege to create and drop databases. + fn test_context(args: &TestArgs) -> BoxFuture<'_, Result, Error>>; + + fn cleanup_test(db_name: &str) -> BoxFuture<'_, Result<(), Error>>; + + /// Cleanup any test databases that are no longer in-use. + /// + /// Returns a count of the databases deleted, if possible. + /// + /// The implementation may require `DATABASE_URL` to be set in order to manage databases. + /// The user credentials it contains must have the privilege to create and drop databases. + fn cleanup_test_dbs() -> BoxFuture<'static, Result, Error>>; + + /// Take a snapshot of the current state of the database (data only). + /// + /// This snapshot can then be used to generate test fixtures. + fn snapshot(conn: &mut Self::Connection) + -> BoxFuture<'_, Result, Error>>; +} + +pub struct TestFixture { + pub path: &'static str, + pub contents: &'static str, +} + +pub struct TestArgs { + pub test_path: &'static str, + pub migrator: Option<&'static Migrator>, + pub fixtures: &'static [TestFixture], +} + +pub trait TestFn { + type Output; + + fn run_test(self, args: TestArgs) -> Self::Output; +} + +pub trait TestTermination { + fn is_success(&self) -> bool; +} + +pub struct TestContext { + pub pool_opts: PoolOptions, + pub connect_opts: ::Options, + pub db_name: String, +} + +impl TestFn for fn(Pool) -> Fut +where + DB: TestSupport + Database, + DB::Connection: Migrate, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + Fut: Future, + Fut::Output: TestTermination, +{ + type Output = Fut::Output; + + fn run_test(self, args: TestArgs) -> Self::Output { + run_test_with_pool(args, self) + } +} + +impl TestFn for fn(PoolConnection) -> Fut +where + DB: TestSupport + Database, + DB::Connection: Migrate, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + Fut: Future, + Fut::Output: TestTermination, +{ + type Output = Fut::Output; + + fn run_test(self, args: TestArgs) -> Self::Output { + run_test_with_pool(args, |pool| async move { + let conn = pool + .acquire() + .await + .expect("failed to acquire test pool connection"); + let res = (self)(conn).await; + pool.close().await; + res + }) + } +} + +impl TestFn for fn(PoolOptions, ::Options) -> Fut +where + DB: Database + TestSupport, + DB::Connection: Migrate, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + Fut: Future, + Fut::Output: TestTermination, +{ + type Output = Fut::Output; + + fn run_test(self, args: TestArgs) -> Self::Output { + run_test(args, self) + } +} + +impl TestFn for fn() -> Fut +where + Fut: Future, +{ + type Output = Fut::Output; + + fn run_test(self, args: TestArgs) -> Self::Output { + assert!( + args.fixtures.is_empty(), + "fixtures cannot be applied for a bare function" + ); + test_block_on(self()) + } +} + +impl TestArgs { + pub fn new(test_path: &'static str) -> Self { + TestArgs { + test_path, + migrator: None, + fixtures: &[], + } + } + + pub fn migrator(&mut self, migrator: &'static Migrator) { + self.migrator = Some(migrator); + } + + pub fn fixtures(&mut self, fixtures: &'static [TestFixture]) { + self.fixtures = fixtures; + } +} + +impl TestTermination for () { + fn is_success(&self) -> bool { + true + } +} + +impl TestTermination for Result { + fn is_success(&self) -> bool { + self.is_ok() + } +} + +fn run_test_with_pool(args: TestArgs, test_fn: F) -> Fut::Output +where + DB: TestSupport, + DB::Connection: Migrate, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + F: FnOnce(Pool) -> Fut, + Fut: Future, + Fut::Output: TestTermination, +{ + let test_path = args.test_path; + run_test::(args, |pool_opts, connect_opts| async move { + let pool = pool_opts + .connect_with(connect_opts) + .await + .expect("failed to connect test pool"); + + let res = test_fn(pool.clone()).await; + + let close_timed_out = sqlx_rt::timeout(Duration::from_secs(10), pool.close()) + .await + .is_err(); + + if close_timed_out { + eprintln!("test {} held onto Pool after exiting", test_path); + } + + res + }) +} + +fn run_test(args: TestArgs, test_fn: F) -> Fut::Output +where + DB: TestSupport, + DB::Connection: Migrate, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + F: FnOnce(PoolOptions, ::Options) -> Fut, + Fut: Future, + Fut::Output: TestTermination, +{ + test_block_on(async move { + let test_context = DB::test_context(&args) + .await + .expect("failed to connect to setup test database"); + + setup_test_db::(&test_context.connect_opts, &args).await; + + let res = test_fn(test_context.pool_opts, test_context.connect_opts).await; + + if res.is_success() { + if let Err(e) = DB::cleanup_test(&test_context.db_name).await { + eprintln!( + "failed to delete database {:?}: {}", + test_context.db_name, e + ); + } + } + + res + }) +} + +async fn setup_test_db( + copts: &::Options, + args: &TestArgs, +) where + DB::Connection: Migrate + Sized, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, +{ + let mut conn = copts + .connect() + .await + .expect("failed to connect to test database"); + + if let Some(migrator) = args.migrator { + migrator + .run_direct(&mut conn) + .await + .expect("failed to apply migrations"); + } + + for fixture in args.fixtures { + (&mut conn) + .execute(fixture.contents) + .await + .unwrap_or_else(|e| panic!("failed to apply test fixture {:?}: {:?}", fixture.path, e)); + } + + conn.close() + .await + .expect("failed to close setup connection"); +} diff --git a/sqlx-macros/src/common.rs b/sqlx-macros/src/common.rs index 1e4dc37435..fab09b7cae 100644 --- a/sqlx-macros/src/common.rs +++ b/sqlx-macros/src/common.rs @@ -2,8 +2,8 @@ use proc_macro2::Span; use std::env; use std::path::{Path, PathBuf}; -pub(crate) fn resolve_path(path: &str, err_span: Span) -> syn::Result { - let path = Path::new(path); +pub(crate) fn resolve_path(path: impl AsRef, err_span: Span) -> syn::Result { + let path = path.as_ref(); if path.is_absolute() { return Err(syn::Error::new( diff --git a/sqlx-macros/src/lib.rs b/sqlx-macros/src/lib.rs index 459879a431..c858b204dd 100644 --- a/sqlx-macros/src/lib.rs +++ b/sqlx-macros/src/lib.rs @@ -21,6 +21,9 @@ mod database; mod derives; mod query; +// The compiler gives misleading help messages about `#[cfg(test)]` when this is just named `test`. +mod test_attr; + #[cfg(feature = "migrate")] mod migrate; @@ -84,7 +87,7 @@ pub fn migrate(input: TokenStream) -> TokenStream { use syn::LitStr; let input = syn::parse_macro_input!(input as LitStr); - match migrate::expand_migrator_from_dir(input) { + match migrate::expand_migrator_from_lit_dir(input) { Ok(ts) => ts.into(), Err(e) => { if let Some(parse_err) = e.downcast_ref::() { @@ -97,40 +100,20 @@ pub fn migrate(input: TokenStream) -> TokenStream { } } -#[doc(hidden)] #[proc_macro_attribute] -pub fn test(_attr: TokenStream, input: TokenStream) -> TokenStream { +pub fn test(args: TokenStream, input: TokenStream) -> TokenStream { + let args = syn::parse_macro_input!(args as syn::AttributeArgs); let input = syn::parse_macro_input!(input as syn::ItemFn); - let ret = &input.sig.output; - let name = &input.sig.ident; - let body = &input.block; - let attrs = &input.attrs; - - let result = if cfg!(feature = "_rt-tokio") { - quote! { - #[test] - #(#attrs)* - fn #name() #ret { - ::sqlx_rt::tokio::runtime::Builder::new_multi_thread() - .enable_io() - .enable_time() - .build() - .unwrap() - .block_on(async { #body }) - } - } - } else if cfg!(feature = "_rt-async-std") { - quote! { - #[test] - #(#attrs)* - fn #name() #ret { - ::sqlx_rt::async_std::task::block_on(async { #body }) + match test_attr::expand(args, input) { + Ok(ts) => ts.into(), + Err(e) => { + if let Some(parse_err) = e.downcast_ref::() { + parse_err.to_compile_error().into() + } else { + let msg = e.to_string(); + quote!(::std::compile_error!(#msg)).into() } } - } else { - panic!("one of 'runtime-actix', 'runtime-async-std' or 'runtime-tokio' features must be enabled"); - }; - - result.into() + } } diff --git a/sqlx-macros/src/migrate.rs b/sqlx-macros/src/migrate.rs index 018ba1b41e..a463b9dc4b 100644 --- a/sqlx-macros/src/migrate.rs +++ b/sqlx-macros/src/migrate.rs @@ -3,6 +3,7 @@ use quote::{quote, ToTokens, TokenStreamExt}; use sha2::{Digest, Sha384}; use sqlx_core::migrate::MigrationType; use std::fs; +use std::path::Path; use syn::LitStr; pub struct QuotedMigrationType(MigrationType); @@ -56,8 +57,20 @@ impl ToTokens for QuotedMigration { } // mostly copied from sqlx-core/src/migrate/source.rs -pub(crate) fn expand_migrator_from_dir(dir: LitStr) -> crate::Result { - let path = crate::common::resolve_path(&dir.value(), dir.span())?; +pub(crate) fn expand_migrator_from_lit_dir(dir: LitStr) -> crate::Result { + expand_migrator_from_dir(&dir.value(), dir.span()) +} + +pub(crate) fn expand_migrator_from_dir( + dir: &str, + err_span: proc_macro2::Span, +) -> crate::Result { + let path = crate::common::resolve_path(dir, err_span)?; + + expand_migrator(&path) +} + +pub(crate) fn expand_migrator(path: &Path) -> crate::Result { let mut migrations = Vec::new(); for entry in fs::read_dir(&path)? { diff --git a/sqlx-macros/src/test_attr.rs b/sqlx-macros/src/test_attr.rs new file mode 100644 index 0000000000..149a225e51 --- /dev/null +++ b/sqlx-macros/src/test_attr.rs @@ -0,0 +1,217 @@ +use proc_macro2::{Span, TokenStream}; +use quote::quote; +use syn::LitStr; + +struct Args { + fixtures: Vec, + migrations: MigrationsOpt, +} + +enum MigrationsOpt { + InferredPath, + ExplicitPath(LitStr), + ExplicitMigrator(syn::Path), + Disabled, +} + +pub fn expand(args: syn::AttributeArgs, input: syn::ItemFn) -> crate::Result { + if input.sig.inputs.is_empty() { + if !args.is_empty() { + if cfg!(feature = "migrate") { + return Err(syn::Error::new_spanned( + args.first().unwrap(), + "control attributes are not allowed unless \ + the `migrate` feature is enabled and \ + automatic test DB management is used; see docs", + ) + .into()); + } + + return Err(syn::Error::new_spanned( + args.first().unwrap(), + "control attributes are not allowed unless \ + automatic test DB management is used; see docs", + ) + .into()); + } + + return Ok(expand_simple(input)); + } + + #[cfg(feature = "migrate")] + return expand_advanced(args, input); + + #[cfg(not(feature = "migrate"))] + return Err(syn::Error::new_spanned(input, "`migrate` feature required").into()); +} + +fn expand_simple(input: syn::ItemFn) -> TokenStream { + let ret = &input.sig.output; + let name = &input.sig.ident; + let body = &input.block; + let attrs = &input.attrs; + + quote! { + #[test] + #(#attrs)* + fn #name() #ret { + ::sqlx::test_block_on(async { #body }) + } + } +} + +#[cfg(feature = "migrate")] +fn expand_advanced(args: syn::AttributeArgs, input: syn::ItemFn) -> crate::Result { + let ret = &input.sig.output; + let name = &input.sig.ident; + let inputs = &input.sig.inputs; + let body = &input.block; + let attrs = &input.attrs; + + let args = parse_args(args)?; + + let fn_arg_types = inputs.iter().map(|_| quote! { _ }); + + let fixtures = args.fixtures.into_iter().map(|fixture| { + let path = format!("fixtures/{}.sql", fixture.value()); + + quote! { + ::sqlx::testing::TestFixture { + path: #path, + contents: include_str!(#path), + } + } + }); + + let migrations = match args.migrations { + MigrationsOpt::ExplicitPath(path) => { + let migrator = crate::migrate::expand_migrator_from_lit_dir(path)?; + quote! { args.migrator(&#migrator); } + } + MigrationsOpt::InferredPath if !inputs.is_empty() => { + let migrations_path = crate::common::resolve_path("./migrations", Span::call_site())?; + + if migrations_path.is_dir() { + let migrator = crate::migrate::expand_migrator(&migrations_path)?; + quote! { args.migrator(&#migrator); } + } else { + quote! {} + } + } + MigrationsOpt::ExplicitMigrator(path) => { + quote! { args.migrator(&#path); } + } + _ => quote! {}, + }; + + Ok(quote! { + #[test] + #(#attrs)* + fn #name() #ret { + async fn inner(#inputs) #ret { + #body + } + + let mut args = ::sqlx::testing::TestArgs::new(concat!(module_path!(), "::", stringify!(#name))); + + #migrations + + args.fixtures(&[#(#fixtures),*]); + + // We need to give a coercion site or else we get "unimplemented trait" errors. + let f: fn(#(#fn_arg_types),*) -> _ = inner; + + ::sqlx::testing::TestFn::run_test(f, args) + } + }) +} + +#[cfg(feature = "migrate")] +fn parse_args(attr_args: syn::AttributeArgs) -> syn::Result { + let mut fixtures = vec![]; + let mut migrations = MigrationsOpt::InferredPath; + + for arg in attr_args { + match arg { + syn::NestedMeta::Meta(syn::Meta::List(list)) if list.path.is_ident("fixtures") => { + if !fixtures.is_empty() { + return Err(syn::Error::new_spanned(list, "duplicate `fixtures` arg")); + } + + for nested in list.nested { + match nested { + syn::NestedMeta::Lit(syn::Lit::Str(litstr)) => fixtures.push(litstr), + other => { + return Err(syn::Error::new_spanned(other, "expected string literal")) + } + } + } + } + syn::NestedMeta::Meta(syn::Meta::NameValue(namevalue)) + if namevalue.path.is_ident("migrations") => + { + if !matches!(migrations, MigrationsOpt::InferredPath) { + return Err(syn::Error::new_spanned( + namevalue, + "cannot have more than one `migrations` or `migrator` arg", + )); + } + + migrations = match namevalue.lit { + syn::Lit::Bool(litbool) => { + if !litbool.value { + // migrations = false + MigrationsOpt::Disabled + } else { + // migrations = true + return Err(syn::Error::new_spanned( + litbool, + "`migrations = true` is redundant", + )); + } + } + // migrations = "" + syn::Lit::Str(litstr) => MigrationsOpt::ExplicitPath(litstr), + _ => { + return Err(syn::Error::new_spanned( + namevalue, + "expected string or `false`", + )) + } + }; + } + syn::NestedMeta::Meta(syn::Meta::NameValue(namevalue)) + if namevalue.path.is_ident("migrator") => + { + if !matches!(migrations, MigrationsOpt::InferredPath) { + return Err(syn::Error::new_spanned( + namevalue, + "cannot have more than one `migrations` or `migrator` arg", + )); + } + + migrations = match namevalue.lit { + // migrator = "" + syn::Lit::Str(litstr) => MigrationsOpt::ExplicitMigrator(litstr.parse()?), + _ => { + return Err(syn::Error::new_spanned( + namevalue, + "expected string", + )) + } + }; + } + other => { + return Err(syn::Error::new_spanned( + other, + "expected `fixtures(\"\", ...)` or `migrations = \"\" | false` or `migrator = \"\"`", + )) + } + } + } + + Ok(Args { + fixtures, + migrations, + }) +} diff --git a/sqlx-rt/src/lib.rs b/sqlx-rt/src/lib.rs index 7df8139407..a0aac5b8ea 100644 --- a/sqlx-rt/src/lib.rs +++ b/sqlx-rt/src/lib.rs @@ -1,6 +1,10 @@ +//! Core runtime support for SQLx. **Semver-exempt**, not for general use. + #[cfg(not(any( + feature = "runtime-actix-native-tls", feature = "runtime-async-std-native-tls", feature = "runtime-tokio-native-tls", + feature = "runtime-actix-rustls", feature = "runtime-async-std-rustls", feature = "runtime-tokio-rustls", )))] @@ -11,6 +15,8 @@ compile_error!( ); #[cfg(any( + all(feature = "_rt-actix", feature = "_rt-async-std"), + all(feature = "_rt-actix", feature = "_rt-tokio"), all(feature = "_rt-async-std", feature = "_rt-tokio"), all(feature = "_tls-native-tls", feature = "_tls-rustls"), ))] @@ -20,116 +26,24 @@ compile_error!( 'runtime-tokio-rustls'] can be enabled" ); -#[cfg(all(feature = "_tls-native-tls"))] -pub use native_tls; - -// -// Tokio -// - -#[cfg(all(feature = "_rt-tokio", not(feature = "_rt-async-std")))] -pub use tokio::{ - self, fs, io::AsyncRead, io::AsyncReadExt, io::AsyncWrite, io::AsyncWriteExt, io::ReadBuf, - net::TcpStream, runtime::Handle, sync::Mutex as AsyncMutex, task::spawn, task::yield_now, - time::sleep, time::timeout, -}; +#[cfg(feature = "_rt-async-std")] +mod rt_async_std; -#[cfg(all(unix, feature = "_rt-tokio", not(feature = "_rt-async-std")))] -pub use tokio::net::UnixStream; +#[cfg(any(feature = "_rt-tokio", feature = "_rt-actix"))] +mod rt_tokio; -#[cfg(all(feature = "_rt-tokio", not(feature = "_rt-async-std")))] -pub use tokio_runtime::{block_on, enter_runtime}; - -#[cfg(feature = "_rt-tokio")] -mod tokio_runtime { - use once_cell::sync::Lazy; - use tokio::runtime::{self, Runtime}; - - // lazily initialize a global runtime once for multiple invocations of the macros - static RUNTIME: Lazy = Lazy::new(|| { - runtime::Builder::new_multi_thread() - .enable_io() - .enable_time() - .build() - .expect("failed to initialize Tokio runtime") - }); - - pub fn block_on(future: F) -> F::Output { - RUNTIME.block_on(future) - } - - pub fn enter_runtime(f: F) -> R - where - F: FnOnce() -> R, - { - let _rt = RUNTIME.enter(); - f() - } -} - -#[cfg(all( - feature = "_tls-native-tls", - feature = "_rt-tokio", - not(any(feature = "_tls-rustls", feature = "_rt-async-std")), -))] -pub use tokio_native_tls::{TlsConnector, TlsStream}; - -#[cfg(all( - feature = "_tls-rustls", - feature = "_rt-tokio", - not(any(feature = "_tls-native-tls", feature = "_rt-async-std")), -))] -pub use tokio_rustls::{client::TlsStream, TlsConnector}; - -#[cfg(all(feature = "_rt-tokio", not(feature = "_rt-async-std")))] -#[macro_export] -macro_rules! blocking { - ($($expr:tt)*) => { - $crate::tokio::task::spawn_blocking(move || { $($expr)* }) - .await.expect("Blocking task failed to complete.") - }; -} +#[cfg(all(feature = "_tls-native-tls"))] +pub use native_tls; // -// async-std +// Actix *OR* Tokio // -#[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))] -pub use async_std::{ - self, fs, future::timeout, io::prelude::ReadExt as AsyncReadExt, - io::prelude::WriteExt as AsyncWriteExt, io::Read as AsyncRead, io::Write as AsyncWrite, - net::TcpStream, sync::Mutex as AsyncMutex, task::sleep, task::spawn, task::yield_now, -}; - -#[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))] -#[macro_export] -macro_rules! blocking { - ($($expr:tt)*) => { - $crate::async_std::task::spawn_blocking(move || { $($expr)* }).await - }; -} - -#[cfg(all(unix, feature = "_rt-async-std", not(feature = "_rt-tokio")))] -pub use async_std::os::unix::net::UnixStream; - -#[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))] -pub use async_std::task::block_on; - -#[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))] -pub fn enter_runtime(f: F) -> R -where - F: FnOnce() -> R, -{ - // no-op for async-std - f() -} - -#[cfg(all(feature = "async-native-tls", not(feature = "tokio-native-tls")))] -pub use async_native_tls::{TlsConnector, TlsStream}; +#[cfg(all(any(feature = "_rt-tokio", feature = "_rt-actix"),))] +pub use rt_tokio::*; #[cfg(all( - feature = "_tls-rustls", feature = "_rt-async-std", - not(any(feature = "_tls-native-tls", feature = "_rt-tokio")), + not(any(feature = "_rt-tokio", feature = "_rt-actix")) ))] -pub use futures_rustls::{client::TlsStream, TlsConnector}; +pub use rt_async_std::*; diff --git a/sqlx-rt/src/rt_async_std.rs b/sqlx-rt/src/rt_async_std.rs new file mode 100644 index 0000000000..aeca8541ab --- /dev/null +++ b/sqlx-rt/src/rt_async_std.rs @@ -0,0 +1,24 @@ +pub use async_std::{ + self, fs, future::timeout, io::prelude::ReadExt as AsyncReadExt, + io::prelude::WriteExt as AsyncWriteExt, io::Read as AsyncRead, io::Write as AsyncWrite, + net::TcpStream, sync::Mutex as AsyncMutex, task::sleep, task::spawn, task::yield_now, +}; + +#[cfg(unix)] +pub use async_std::os::unix::net::UnixStream; + +#[cfg(all(feature = "_tls-native-tls", not(feature = "_tls-rustls")))] +pub use async_native_tls::{TlsConnector, TlsStream}; + +#[cfg(all(feature = "_tls-rustls", not(feature = "_tls-native-tls")))] +pub use futures_rustls::{client::TlsStream, TlsConnector}; + +pub use async_std::task::{block_on, block_on as test_block_on}; + +pub fn enter_runtime(f: F) -> R +where + F: FnOnce() -> R, +{ + // no-op for async-std + f() +} diff --git a/sqlx-rt/src/rt_tokio.rs b/sqlx-rt/src/rt_tokio.rs new file mode 100644 index 0000000000..b1d3bc8149 --- /dev/null +++ b/sqlx-rt/src/rt_tokio.rs @@ -0,0 +1,47 @@ +pub use tokio::{ + self, fs, io::AsyncRead, io::AsyncReadExt, io::AsyncWrite, io::AsyncWriteExt, io::ReadBuf, + net::TcpStream, runtime::Handle, sync::Mutex as AsyncMutex, task::spawn, task::yield_now, + time::sleep, time::timeout, +}; + +#[cfg(unix)] +pub use tokio::net::UnixStream; + +use once_cell::sync::Lazy; +use tokio::runtime::{self, Runtime}; + +#[cfg(all(feature = "_tls-native-tls", not(feature = "_tls-rustls")))] +pub use tokio_native_tls::{TlsConnector, TlsStream}; + +#[cfg(all(feature = "_tls-rustls", not(feature = "_tls-native-tls")))] +pub use tokio_rustls::{client::TlsStream, TlsConnector}; + +// lazily initialize a global runtime once for multiple invocations of the macros +static RUNTIME: Lazy = Lazy::new(|| { + runtime::Builder::new_multi_thread() + .enable_io() + .enable_time() + .build() + .expect("failed to initialize Tokio runtime") +}); + +pub fn block_on(future: F) -> F::Output { + RUNTIME.block_on(future) +} + +pub fn enter_runtime(f: F) -> R +where + F: FnOnce() -> R, +{ + let _rt = RUNTIME.enter(); + f() +} + +pub fn test_block_on(future: F) -> F::Output { + // For tests, we want a single runtime per thread for isolation. + runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("failed to initialize Tokio test runtime") + .block_on(future) +} diff --git a/src/lib.rs b/src/lib.rs index a017f323dc..e6487d1c10 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -71,6 +71,18 @@ pub extern crate sqlx_macros; #[doc(hidden)] pub use sqlx_macros::{FromRow, Type}; +// We can't do our normal facade approach with an attribute, but thankfully we can now +// have docs out-of-line quite easily. +#[doc = include_str!("macros/test.md")] +pub use sqlx_macros::test; + +#[doc(hidden)] +#[cfg(feature = "migrate")] +pub use sqlx_core::testing; + +#[doc(hidden)] +pub use sqlx_core::test_block_on; + #[cfg(feature = "macros")] mod macros; diff --git a/src/macros.rs b/src/macros/mod.rs similarity index 100% rename from src/macros.rs rename to src/macros/mod.rs diff --git a/src/macros/test.md b/src/macros/test.md new file mode 100644 index 0000000000..a1ec1cad61 --- /dev/null +++ b/src/macros/test.md @@ -0,0 +1,218 @@ +Mark an `async fn` as a test with SQLx support. + +The test will automatically be executed in the async runtime according to the chosen +`runtime-{async-std, tokio}-{native-tls, rustls}` feature. + +By default, this behaves identically to `#[tokio::test]`1 or `#[async_std::test]`: + +```rust,norun +# // Note if reading these examples directly in `test.md`: +# // lines prefixed with `#` are not meant to be shown; +# // they are supporting code to help the examples to compile successfully. +# #[cfg(feature = "_rt-tokio")] +#[sqlx::test] +async fn test_async_fn() { + tokio::task::yield_now().await; +} +``` + +However, several advanced features are also supported as shown in the next section. + +1`#[sqlx::test]` does not recognize any of the control arguments supported by `#[tokio::test]` +as that would have complicated the implementation. If your use case requires any of those, feel free to open an issue. + +### Automatic Test Database Management (requires `migrate` feature) + +`#[sqlx::test]` can automatically create test databases for you and provide live connections to your test. + +For every annotated function, a new test database is created so tests can run against a live database +but are isolated from each other. + +This feature is activated by changing the signature of your test function. The following signatures are supported: + +* `async fn(Pool) -> Ret` + * the `Pool`s used by all running tests share a single connection limit to avoid exceeding the server's limit. +* `async fn(PoolConnection) -> Ret` + * `PoolConnection`, etc. +* `async fn(PoolOptions, impl ConnectOptions) -> Ret` + * Where `impl ConnectOptions` is, e.g, `PgConnectOptions`, `MySqlConnectOptions`, etc. + * If your test wants to create its own `Pool` (for example, to set pool callbacks or to modify `ConnectOptions`), + you can use this signature. + +Where `DB` is a supported `Database` type and `Ret` is `()` or `Result<_, _>`. + +##### Supported Databases + +Most of these will require you to set `DATABASE_URL` as an environment variable +or in a `.env` file like `sqlx::query!()` _et al_, to give the test driver a superuser connection with which +to manage test databases. + + +| Database | Requires `DATABASE_URL` | +| --- | --- | +| Postgres | Yes | +| MySQL | Yes | +| SQLite | No2 | + +Test databases are automatically cleaned up as tests succeed, but failed tests will leave their databases in-place +to facilitate debugging. Note that to simplify the implementation, panics are _always_ considered to be failures, +even for `#[should_panic]` tests. + +If you have `sqlx-cli` installed, you can run `sqlx test-db cleanup` to delete all test databases. +Old test databases will also be deleted the next time a test binary using `#[sqlx::test]` is run. + +```rust,no_run +# #[cfg(all(feature = "migrate", feature = "postgres"))] +# mod example { +use sqlx::PgPool; + +#[sqlx::test] +async fn basic_test(pool: PgPool) -> sqlx::Result<()> { + let mut conn = pool.acquire().await?; + + sqlx::query("SELECT * FROM foo") + .fetch_one(&mut conn) + .await?; + + assert_eq!(foo.get::("bar"), "foobar!"); + + Ok(()) +} +# } +``` + +2 SQLite defaults to `target/sqlx/test-dbs/.sqlite` where `` is the path of the test function +converted to a filesystem path (`::` replaced with `/`). + +### Automatic Migrations (requires `migrate` feature) + +To ensure a straightforward test implementation against a fresh test database, migrations are automatically applied if a +`migrations` folder is found in the same directory as `CARGO_MANIFEST_DIR` (the directory where the current crate's +`Cargo.toml` resides). + +You can override the resolved path relative to `CARGO_MANIFEST_DIR` in the attribute (global overrides are not currently +supported): + +```rust,ignore +# #[cfg(all(feature = "migrate", feature = "postgres"))] +# mod example { +use sqlx::PgPool; + +#[sqlx::test(migrations = "foo_migrations")] +async fn basic_test(pool: PgPool) -> sqlx::Result<()> { + let mut conn = pool.acquire().await?; + + sqlx::query("SELECT * FROM foo") + .fetch_one(&mut conn) + .await?; + + assert_eq!(foo.get::("bar"), "foobar!"); + + Ok(()) +} +# } +``` + +Or if you're already embedding migrations in your main crate, you can reference them directly: + +`foo_crate/lib.rs` +```rust,ignore +pub static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("foo_migrations"); +``` + +`foo_crate/tests/foo_test.rs` +```rust,no_run +# #[cfg(all(feature = "migrate", feature = "postgres"))] +# mod example { +use sqlx::PgPool; + +# // This is standing in for the main crate since doc examples don't support multiple crates. +# mod foo_crate { +# use std::borrow::Cow; +# static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate::Migrator { +# migrations: Cow::Borrowed(&[]), +# ignore_missing: false, +# }; +# } + +// You could also do `use foo_crate::MIGRATOR` and just refer to it as `MIGRATOR` here. +#[sqlx::test(migrator = "foo_crate::MIGRATOR")] +async fn basic_test(pool: PgPool) -> sqlx::Result<()> { + let mut conn = pool.acquire().await?; + + sqlx::query("SELECT * FROM foo") + .fetch_one(&mut conn) + .await?; + + assert_eq!(foo.get::("bar"), "foobar!"); + + Ok(()) +} +# } +``` + +Or disable migrations processing entirely: + +```rust,no_run +# #[cfg(all(feature = "migrate", feature = "postgres"))] +# mod example { +use sqlx::PgPool; + +#[sqlx::test(migrations = false)] +async fn basic_test(pool: PgPool) -> sqlx::Result<()> { + let mut conn = pool.acquire().await?; + + conn.execute("CREATE TABLE foo(bar text)").await?; + + sqlx::query("SELECT * FROM foo") + .fetch_one(&mut conn) + .await?; + + assert_eq!(foo.get::("bar"), "foobar!"); + + Ok(()) +} +# } +``` + +### Automatic Fixture Application (requires `migrate` feature) + +Since tests are isolated from each other but may require data to already exist in the database to keep from growing +exponentially in complexity, `#[sqlx::test]` also supports applying test fixtures, which are SQL scripts that function +similarly to migrations but are solely intended to insert test data and be arbitrarily composable. + +Imagine a basic social app that has users, posts and comments. To test the comment routes, you'd want +the database to already have users and posts in it so the comments tests don't have to duplicate that work. + +You can pass a list of fixture names to the attribute like so, and they will be applied in the given order3: + +```rust,no_run +# #[cfg(all(feature = "migrate", feature = "postgres"))] +# mod example { +# struct App {} +# fn create_app(pool: PgPool) -> App { App {} } +use sqlx::PgPool; +use serde_json::json; + +#[sqlx::test(fixtures("users", "posts"))] +async fn test_create_comment(pool: PgPool) -> sqlx::Result<()> { + // See examples/postgres/social-axum-with-tests for a more in-depth example. + let mut app = create_app(pool); + + let comment = test_request( + &mut app, "POST", "/v1/comment", json! { "postId": "1234" } + ).await?; + + assert_eq!(comment["postId"], "1234"); + + Ok(()) +} +# } +``` + +Fixtures are resolved relative to the current file as `./fixtures/{name}.sql`. + +3Ordering for test fixtures is entirely up to the application, and each test may choose which fixtures to +apply and which to omit. However, since each fixture is applied separately (sent as a single command string, so wrapped +in an implicit `BEGIN` and `COMMIT`), you will want to make sure to order the fixtures such that foreign key +requirements are always satisfied, or else you might get errors. diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml index d67c1e1dd3..bd8ce46297 100644 --- a/tests/docker-compose.yml +++ b/tests/docker-compose.yml @@ -2,7 +2,7 @@ version: "3" services: # - # MySQL 8.x, 5.7.x, 5.6.x + # MySQL 8.x, 5.7.x # https://www.mysql.com/support/supportedplatforms/database.html # @@ -29,20 +29,8 @@ services: MYSQL_ROOT_HOST: '%' MYSQL_ROOT_PASSWORD: password MYSQL_DATABASE: sqlx - - mysql_5_6: - image: mysql:5.6 - volumes: - - "./mysql/setup.sql:/docker-entrypoint-initdb.d/setup.sql" - ports: - - 3306 - environment: - MYSQL_ROOT_HOST: '%' - MYSQL_ROOT_PASSWORD: password - MYSQL_DATABASE: sqlx - # - # MariaDB 10.6, 10.5, 10.4, 10.3, 10.2 + # MariaDB 10.6, 10.5, 10.4, 10.3 # https://mariadb.org/about/#maintenance-policy # @@ -86,18 +74,8 @@ services: MYSQL_ROOT_PASSWORD: password MYSQL_DATABASE: sqlx - mariadb_10_2: - image: mariadb:10.2 - volumes: - - "./mysql/setup.sql:/docker-entrypoint-initdb.d/setup.sql" - ports: - - 3306 - environment: - MYSQL_ROOT_PASSWORD: password - MYSQL_DATABASE: sqlx - # - # PostgreSQL 14.x, 13.x, 12.x, 11.x 10.x, 9.6.x + # PostgreSQL 14.x, 13.x, 12.x, 11.x 10.x # https://www.postgresql.org/support/versioning/ # @@ -195,25 +173,6 @@ services: - "./postgres/setup.sql:/docker-entrypoint-initdb.d/setup.sql" command: > -c ssl=on -c ssl_cert_file=/var/lib/postgresql/server.crt -c ssl_key_file=/var/lib/postgresql/server.key - - postgres_9_6: - build: - context: . - dockerfile: postgres/Dockerfile - args: - VERSION: 9.6 - ports: - - 5432 - environment: - POSTGRES_DB: sqlx - POSTGRES_USER: postgres - POSTGRES_PASSWORD: password - POSTGRES_HOST_AUTH_METHOD: md5 - volumes: - - "./postgres/setup.sql:/docker-entrypoint-initdb.d/setup.sql" - command: > - -c ssl=on -c ssl_cert_file=/var/lib/postgresql/server.crt -c ssl_key_file=/var/lib/postgresql/server.key - # # Microsoft SQL Server (MSSQL) # https://hub.docker.com/_/microsoft-mssql-server diff --git a/tests/mysql/fixtures/comments.sql b/tests/mysql/fixtures/comments.sql new file mode 100644 index 0000000000..7255de5e25 --- /dev/null +++ b/tests/mysql/fixtures/comments.sql @@ -0,0 +1,16 @@ +insert into comment(comment_id, post_id, user_id, content, created_at) +values (1, + 1, + 2, + 'lol bet ur still bad, 1v1 me', + timestamp(now(), '-0:50:00')), + (2, + 1, + 1, + 'you''re on!', + timestamp(now(), '-0:45:00')), + (3, + 2, + 1, + 'lol you''re just mad you lost :P', + timestamp(now(), '-0:15:00')); diff --git a/tests/mysql/fixtures/posts.sql b/tests/mysql/fixtures/posts.sql new file mode 100644 index 0000000000..d692f3a1bd --- /dev/null +++ b/tests/mysql/fixtures/posts.sql @@ -0,0 +1,9 @@ +insert into post(post_id, user_id, content, created_at) +values (1, + 1, + 'This new computer is lightning-fast!', + timestamp(now(), '-1:00:00')), + (2, + 2, + '@alice is a haxxor :(', + timestamp(now(), '-0:30:00')); diff --git a/tests/mysql/fixtures/users.sql b/tests/mysql/fixtures/users.sql new file mode 100644 index 0000000000..9c4813c027 --- /dev/null +++ b/tests/mysql/fixtures/users.sql @@ -0,0 +1,2 @@ +insert into user(user_id, username) +values (1, 'alice'), (2, 'bob'); diff --git a/tests/mysql/migrations/1_user.sql b/tests/mysql/migrations/1_user.sql new file mode 100644 index 0000000000..0fc2b61d88 --- /dev/null +++ b/tests/mysql/migrations/1_user.sql @@ -0,0 +1,7 @@ +create table user +( + -- integer primary keys are the most efficient in SQLite + user_id integer primary key auto_increment, + -- indexed text values have to have a max length + username varchar(16) unique not null +); diff --git a/tests/mysql/migrations/2_post.sql b/tests/mysql/migrations/2_post.sql new file mode 100644 index 0000000000..3863f3bc11 --- /dev/null +++ b/tests/mysql/migrations/2_post.sql @@ -0,0 +1,10 @@ +create table post +( + post_id integer primary key auto_increment, + user_id integer not null references user (user_id), + content text not null, + -- Defaults have to be wrapped in parenthesis + created_at datetime default current_timestamp +); + +create index post_created_at on post (created_at desc); diff --git a/tests/mysql/migrations/3_comment.sql b/tests/mysql/migrations/3_comment.sql new file mode 100644 index 0000000000..c0fe7ea235 --- /dev/null +++ b/tests/mysql/migrations/3_comment.sql @@ -0,0 +1,10 @@ +create table comment +( + comment_id integer primary key, + post_id integer not null references post (post_id), + user_id integer not null references user (user_id), + content text not null, + created_at datetime default current_timestamp +); + +create index comment_created_at on comment (created_at desc); diff --git a/tests/mysql/test-attr.rs b/tests/mysql/test-attr.rs new file mode 100644 index 0000000000..5b96609cba --- /dev/null +++ b/tests/mysql/test-attr.rs @@ -0,0 +1,96 @@ +// The no-arg variant is covered by other tests already. + +use sqlx::{MySqlPool, Row}; + +const MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("tests/mysql/migrations"); + +#[sqlx::test] +async fn it_gets_a_pool(pool: MySqlPool) -> sqlx::Result<()> { + let mut conn = pool.acquire().await?; + + let db_name: String = sqlx::query_scalar("select database()") + .fetch_one(&mut conn) + .await?; + + assert!( + db_name.starts_with("_sqlx_test_database_"), + "db_name: {:?}", + db_name + ); + + Ok(()) +} + +// This should apply migrations and then `fixtures/users.sql` +#[sqlx::test(migrations = "tests/mysql/migrations", fixtures("users"))] +async fn it_gets_users(pool: MySqlPool) -> sqlx::Result<()> { + let usernames: Vec = + sqlx::query_scalar(r#"SELECT username FROM user ORDER BY username"#) + .fetch_all(&pool) + .await?; + + assert_eq!(usernames, ["alice", "bob"]); + + let post_exists: bool = sqlx::query_scalar("SELECT exists(SELECT 1 FROM post)") + .fetch_one(&pool) + .await?; + + assert!(!post_exists); + + let comment_exists: bool = sqlx::query_scalar("SELECT exists(SELECT 1 FROM comment)") + .fetch_one(&pool) + .await?; + + assert!(!comment_exists); + + Ok(()) +} + +#[sqlx::test(migrations = "tests/mysql/migrations", fixtures("users", "posts"))] +async fn it_gets_posts(pool: MySqlPool) -> sqlx::Result<()> { + let post_contents: Vec = + sqlx::query_scalar("SELECT content FROM post ORDER BY created_at") + .fetch_all(&pool) + .await?; + + assert_eq!( + post_contents, + [ + "This new computer is lightning-fast!", + "@alice is a haxxor :(" + ] + ); + + let comment_exists: bool = sqlx::query_scalar("SELECT exists(SELECT 1 FROM comment)") + .fetch_one(&pool) + .await?; + + assert!(!comment_exists); + + Ok(()) +} + +// Try `migrator` +#[sqlx::test(migrator = "MIGRATOR", fixtures("users", "posts", "comments"))] +async fn it_gets_comments(pool: MySqlPool) -> sqlx::Result<()> { + let post_1_comments: Vec = + sqlx::query_scalar("SELECT content FROM comment WHERE post_id = ? ORDER BY created_at") + .bind(&1) + .fetch_all(&pool) + .await?; + + assert_eq!( + post_1_comments, + ["lol bet ur still bad, 1v1 me", "you're on!"] + ); + + let post_2_comments: Vec = + sqlx::query_scalar("SELECT content FROM comment WHERE post_id = ? ORDER BY created_at") + .bind(&2) + .fetch_all(&pool) + .await?; + + assert_eq!(post_2_comments, ["lol you're just mad you lost :P"]); + + Ok(()) +} diff --git a/tests/postgres/fixtures/comments.sql b/tests/postgres/fixtures/comments.sql new file mode 100644 index 0000000000..a982bd98d7 --- /dev/null +++ b/tests/postgres/fixtures/comments.sql @@ -0,0 +1,16 @@ +insert into comment(comment_id, post_id, user_id, content, created_at) +values ('fbbbb7dc-dc6f-4649-b663-8d3636035164', + '252c1d98-a9b0-4f18-8298-e59058bdfe16', + '297923c5-a83c-4052-bab0-030887154e52', + 'lol bet ur still bad, 1v1 me', + now() + '50 minutes ago'::interval), + ('cb7612a2-cff4-4e3e-a768-055f01f25dc4', + '252c1d98-a9b0-4f18-8298-e59058bdfe16', + '297923c5-a83c-4052-bab0-030887154e52', + 'you''re on!', + now() + '45 minutes ago'::interval), + ('f2164fcc-a770-4f52-8714-d9cc6a1c89cf', + '844265f7-2472-4689-9a2e-b21f40dbf401', + '297923c5-a83c-4052-bab0-030887154e52', + 'lol you''re just mad you lost :P', + now() + '15 minutes ago'::interval); diff --git a/tests/postgres/fixtures/posts.sql b/tests/postgres/fixtures/posts.sql new file mode 100644 index 0000000000..b563ec0839 --- /dev/null +++ b/tests/postgres/fixtures/posts.sql @@ -0,0 +1,14 @@ +insert into post(post_id, user_id, content, created_at) +values + ( + '252c1d98-a9b0-4f18-8298-e59058bdfe16', + '6592b7c0-b531-4613-ace5-94246b7ce0c3', + 'This new computer is lightning-fast!', + now() + '1 hour ago'::interval + ), + ( + '844265f7-2472-4689-9a2e-b21f40dbf401', + '6592b7c0-b531-4613-ace5-94246b7ce0c3', + '@alice is a haxxor :(', + now() + '30 minutes ago'::interval + ); diff --git a/tests/postgres/fixtures/users.sql b/tests/postgres/fixtures/users.sql new file mode 100644 index 0000000000..571fb829ed --- /dev/null +++ b/tests/postgres/fixtures/users.sql @@ -0,0 +1,2 @@ +insert into "user"(user_id, username) +values ('6592b7c0-b531-4613-ace5-94246b7ce0c3', 'alice'), ('297923c5-a83c-4052-bab0-030887154e52', 'bob'); diff --git a/tests/postgres/migrations/0_setup.sql b/tests/postgres/migrations/0_setup.sql new file mode 100644 index 0000000000..b0138f0c31 --- /dev/null +++ b/tests/postgres/migrations/0_setup.sql @@ -0,0 +1,2 @@ +-- `gen_random_uuid()` wasn't added until Postgres 13 +create extension if not exists "uuid-ossp"; diff --git a/tests/postgres/migrations/1_user.sql b/tests/postgres/migrations/1_user.sql new file mode 100644 index 0000000000..4ea11dd5a7 --- /dev/null +++ b/tests/postgres/migrations/1_user.sql @@ -0,0 +1,5 @@ +create table "user" +( + user_id uuid primary key default uuid_generate_v1mc(), + username text unique not null +); diff --git a/tests/postgres/migrations/2_post.sql b/tests/postgres/migrations/2_post.sql new file mode 100644 index 0000000000..ff929746ce --- /dev/null +++ b/tests/postgres/migrations/2_post.sql @@ -0,0 +1,8 @@ +create table post ( + post_id uuid primary key default uuid_generate_v1mc(), + user_id uuid not null references "user"(user_id), + content text not null, + created_at timestamptz default now() +); + +create index on post(created_at desc); diff --git a/tests/postgres/migrations/3_comment.sql b/tests/postgres/migrations/3_comment.sql new file mode 100644 index 0000000000..f841e4f666 --- /dev/null +++ b/tests/postgres/migrations/3_comment.sql @@ -0,0 +1,9 @@ +create table comment ( + comment_id uuid primary key default uuid_generate_v1mc(), + post_id uuid not null references post(post_id), + user_id uuid not null references "user"(user_id), + content text not null, + created_at timestamptz not null default now() +); + +create index on comment(created_at desc); diff --git a/tests/postgres/test-attr.rs b/tests/postgres/test-attr.rs new file mode 100644 index 0000000000..1ebf0ff989 --- /dev/null +++ b/tests/postgres/test-attr.rs @@ -0,0 +1,94 @@ +// The no-arg variant is covered by other tests already. + +use sqlx::PgPool; + +const MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("tests/postgres/migrations"); + +#[sqlx::test] +async fn it_gets_a_pool(pool: PgPool) -> sqlx::Result<()> { + let mut conn = pool.acquire().await?; + + let db_name: String = sqlx::query_scalar("SELECT current_database()") + .fetch_one(&mut conn) + .await?; + + assert!(db_name.starts_with("_sqlx_test"), "dbname: {db_name:?}"); + + Ok(()) +} + +// This should apply migrations and then `fixtures/users.sql` +#[sqlx::test(migrations = "tests/postgres/migrations", fixtures("users"))] +async fn it_gets_users(pool: PgPool) -> sqlx::Result<()> { + let usernames: Vec = + sqlx::query_scalar(r#"SELECT username FROM "user" ORDER BY username"#) + .fetch_all(&pool) + .await?; + + assert_eq!(usernames, ["alice", "bob"]); + + let post_exists: bool = sqlx::query_scalar("SELECT exists(SELECT 1 FROM post)") + .fetch_one(&pool) + .await?; + + assert!(!post_exists); + + let comment_exists: bool = sqlx::query_scalar("SELECT exists(SELECT 1 FROM comment)") + .fetch_one(&pool) + .await?; + + assert!(!comment_exists); + + Ok(()) +} + +#[sqlx::test(migrations = "tests/postgres/migrations", fixtures("users", "posts"))] +async fn it_gets_posts(pool: PgPool) -> sqlx::Result<()> { + let post_contents: Vec = + sqlx::query_scalar("SELECT content FROM post ORDER BY created_at") + .fetch_all(&pool) + .await?; + + assert_eq!( + post_contents, + [ + "This new computer is lightning-fast!", + "@alice is a haxxor :(" + ] + ); + + let comment_exists: bool = sqlx::query_scalar("SELECT exists(SELECT 1 FROM comment)") + .fetch_one(&pool) + .await?; + + assert!(!comment_exists); + + Ok(()) +} + +// Try `migrator` +#[sqlx::test(migrator = "MIGRATOR", fixtures("users", "posts", "comments"))] +async fn it_gets_comments(pool: PgPool) -> sqlx::Result<()> { + let post_1_comments: Vec = sqlx::query_scalar( + "SELECT content FROM comment WHERE post_id = $1::uuid ORDER BY created_at", + ) + .bind(&"252c1d98-a9b0-4f18-8298-e59058bdfe16") + .fetch_all(&pool) + .await?; + + assert_eq!( + post_1_comments, + ["lol bet ur still bad, 1v1 me", "you're on!"] + ); + + let post_2_comments: Vec = sqlx::query_scalar( + "SELECT content FROM comment WHERE post_id = $1::uuid ORDER BY created_at", + ) + .bind(&"844265f7-2472-4689-9a2e-b21f40dbf401") + .fetch_all(&pool) + .await?; + + assert_eq!(post_2_comments, ["lol you're just mad you lost :P"]); + + Ok(()) +} diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index c6e6a4db3c..a05fdd0130 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -547,7 +547,7 @@ test_prepared_type!(money_vec>(Postgres, "array[123.45,420.00,666.66]::money[]" == vec![PgMoney(12345), PgMoney(42000), PgMoney(66666)], )); -// FIXME: needed to disable `ltree` tests in Postgres 9.6 +// FIXME: needed to disable `ltree` tests in version that don't have a binary format for it // but `PgLTree` should just fall back to text format #[cfg(postgres_14)] test_type!(ltree(Postgres, @@ -555,7 +555,7 @@ test_type!(ltree(Postgres, "'Alpha.Beta.Delta.Gamma'::ltree" == sqlx::postgres::types::PgLTree::from_iter(["Alpha", "Beta", "Delta", "Gamma"]).unwrap(), )); -// FIXME: needed to disable `ltree` tests in Postgres 9.6 +// FIXME: needed to disable `ltree` tests in version that don't have a binary format for it // but `PgLTree` should just fall back to text format #[cfg(postgres_14)] test_type!(ltree_vec>(Postgres, diff --git a/tests/sqlite/fixtures/comments.sql b/tests/sqlite/fixtures/comments.sql new file mode 100644 index 0000000000..d6b2188128 --- /dev/null +++ b/tests/sqlite/fixtures/comments.sql @@ -0,0 +1,16 @@ +insert into comment(comment_id, post_id, user_id, content, created_at) +values (1, + 1, + 2, + 'lol bet ur still bad, 1v1 me', + datetime('now', '-50 minutes')), + (2, + 1, + 1, + 'you''re on!', + datetime('now', '-45 minutes')), + (3, + 2, + 1, + 'lol you''re just mad you lost :P', + datetime('now', '-15 minutes')); diff --git a/tests/sqlite/fixtures/posts.sql b/tests/sqlite/fixtures/posts.sql new file mode 100644 index 0000000000..e48975f84e --- /dev/null +++ b/tests/sqlite/fixtures/posts.sql @@ -0,0 +1,9 @@ +insert into post(post_id, user_id, content, created_at) +values (1, + 1, + 'This new computer is lightning-fast!', + datetime('now', '-1 hour')), + (2, + 2, + '@alice is a haxxor :(', + datetime('now', '-30 minutes')); diff --git a/tests/sqlite/fixtures/users.sql b/tests/sqlite/fixtures/users.sql new file mode 100644 index 0000000000..627f7d9b3c --- /dev/null +++ b/tests/sqlite/fixtures/users.sql @@ -0,0 +1,2 @@ +insert into "user"(user_id, username) +values (1, 'alice'), (2, 'bob'); diff --git a/tests/sqlite/migrations/1_user.sql b/tests/sqlite/migrations/1_user.sql new file mode 100644 index 0000000000..100b750f19 --- /dev/null +++ b/tests/sqlite/migrations/1_user.sql @@ -0,0 +1,6 @@ +create table user +( + -- integer primary keys are the most efficient in SQLite + user_id integer primary key, + username text unique not null +); diff --git a/tests/sqlite/migrations/2_post.sql b/tests/sqlite/migrations/2_post.sql new file mode 100644 index 0000000000..74d2460596 --- /dev/null +++ b/tests/sqlite/migrations/2_post.sql @@ -0,0 +1,10 @@ +create table post +( + post_id integer primary key, + user_id integer not null references user (user_id), + content text not null, + -- Defaults have to be wrapped in parenthesis + created_at datetime default (datetime('now')) +); + +create index post_created_at on post (created_at desc); diff --git a/tests/sqlite/migrations/3_comment.sql b/tests/sqlite/migrations/3_comment.sql new file mode 100644 index 0000000000..a98b2628fc --- /dev/null +++ b/tests/sqlite/migrations/3_comment.sql @@ -0,0 +1,10 @@ +create table comment +( + comment_id integer primary key, + post_id integer not null references post (post_id), + user_id integer not null references "user" (user_id), + content text not null, + created_at datetime default (datetime('now')) +); + +create index comment_created_at on comment (created_at desc); diff --git a/tests/sqlite/test-attr.rs b/tests/sqlite/test-attr.rs new file mode 100644 index 0000000000..dcc5a4d756 --- /dev/null +++ b/tests/sqlite/test-attr.rs @@ -0,0 +1,99 @@ +// The no-arg variant is covered by other tests already. + +use sqlx::{Row, SqlitePool}; + +const MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("tests/sqlite/migrations"); + +#[sqlx::test] +async fn it_gets_a_pool(pool: SqlitePool) -> sqlx::Result<()> { + let mut conn = pool.acquire().await?; + + // https://www.sqlite.org/pragma.html#pragma_database_list + let db = sqlx::query("PRAGMA database_list") + .fetch_one(&mut conn) + .await?; + + let db_name = db.get::(2); + + assert!( + db_name.ends_with("target/sqlx/test-dbs/sqlite_test_attr/it_gets_a_pool.sqlite"), + "db_name: {:?}", + db_name + ); + + Ok(()) +} + +// This should apply migrations and then `fixtures/users.sql` +#[sqlx::test(migrations = "tests/sqlite/migrations", fixtures("users"))] +async fn it_gets_users(pool: SqlitePool) -> sqlx::Result<()> { + let usernames: Vec = + sqlx::query_scalar(r#"SELECT username FROM "user" ORDER BY username"#) + .fetch_all(&pool) + .await?; + + assert_eq!(usernames, ["alice", "bob"]); + + let post_exists: bool = sqlx::query_scalar("SELECT exists(SELECT 1 FROM post)") + .fetch_one(&pool) + .await?; + + assert!(!post_exists); + + let comment_exists: bool = sqlx::query_scalar("SELECT exists(SELECT 1 FROM comment)") + .fetch_one(&pool) + .await?; + + assert!(!comment_exists); + + Ok(()) +} + +#[sqlx::test(migrations = "tests/sqlite/migrations", fixtures("users", "posts"))] +async fn it_gets_posts(pool: SqlitePool) -> sqlx::Result<()> { + let post_contents: Vec = + sqlx::query_scalar("SELECT content FROM post ORDER BY created_at") + .fetch_all(&pool) + .await?; + + assert_eq!( + post_contents, + [ + "This new computer is lightning-fast!", + "@alice is a haxxor :(" + ] + ); + + let comment_exists: bool = sqlx::query_scalar("SELECT exists(SELECT 1 FROM comment)") + .fetch_one(&pool) + .await?; + + assert!(!comment_exists); + + Ok(()) +} + +// Try `migrator` +#[sqlx::test(migrator = "MIGRATOR", fixtures("users", "posts", "comments"))] +async fn it_gets_comments(pool: SqlitePool) -> sqlx::Result<()> { + let post_1_comments: Vec = + sqlx::query_scalar("SELECT content FROM comment WHERE post_id = ? ORDER BY created_at") + .bind(&1) + .fetch_all(&pool) + .await?; + + assert_eq!( + post_1_comments, + ["lol bet ur still bad, 1v1 me", "you're on!"] + ); + + let post_2_comments: Vec = + sqlx::query_scalar("SELECT content FROM comment WHERE post_id = ? ORDER BY created_at") + .bind(&2) + .fetch_all(&pool) + .await?; + + assert_eq!(post_2_comments, ["lol you're just mad you lost :P"]); + + Ok(()) +} diff --git a/tests/x.py b/tests/x.py index 33e5ffce9e..6b8785d83f 100755 --- a/tests/x.py +++ b/tests/x.py @@ -130,7 +130,7 @@ def run(command, comment=None, env=None, service=None, tag=None, args=None, data # postgres # - for version in ["14", "13", "12", "11", "10", "9_6"]: + for version in ["14", "13", "12", "11", "10"]: run( f"cargo test --no-default-features --features macros,offline,any,all-types,postgres,runtime-{runtime}-{tls}", comment=f"test postgres {version}", @@ -139,7 +139,7 @@ def run(command, comment=None, env=None, service=None, tag=None, args=None, data ) ## +ssl - for version in ["14", "13", "12", "11", "10", "9_6"]: + for version in ["14", "13", "12", "11", "10"]: run( f"cargo test --no-default-features --features macros,offline,any,all-types,postgres,runtime-{runtime}-{tls}", comment=f"test postgres {version} ssl", @@ -152,7 +152,7 @@ def run(command, comment=None, env=None, service=None, tag=None, args=None, data # mysql # - for version in ["8", "5_7", "5_6"]: + for version in ["8", "5_7"]: run( f"cargo test --no-default-features --features macros,offline,any,all-types,mysql,runtime-{runtime}-{tls}", comment=f"test mysql {version}",