diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml index 66238af37..0501fa6e9 100644 --- a/.github/workflows/ci-build.yml +++ b/.github/workflows/ci-build.yml @@ -65,6 +65,8 @@ jobs: run: cargo build --profile ci --package janus_messages --no-default-features - name: Build janus_core run: cargo build --profile ci --package janus_core + - name: Test janus_client OHTTP + run: cargo test --package janus_client --features ohttp # Note: keep Build & Test steps consecutive, and match flags other than `--no-run`. - name: Build run: cargo test --profile ci --locked --all-targets --no-run @@ -132,7 +134,7 @@ jobs: - name: Clippy (all features) run: cargo clippy --profile ci --workspace --all-targets --all-features - name: Document - run: cargo doc --profile ci --workspace + run: cargo doc --profile ci --workspace --all-features - name: cargo-deny uses: EmbarkStudios/cargo-deny-action@v1.6.3 with: diff --git a/Cargo.lock b/Cargo.lock index 1224c12c0..65bbd7870 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -27,6 +27,16 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "aead" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b613b8e1e3cf911a086f53f03bf286f52fd7a7258e4fa606f0ef220d39d8877" +dependencies = [ + "generic-array", + "rand_core 0.6.4", +] + [[package]] name = "aead" version = "0.5.2" @@ -37,6 +47,18 @@ dependencies = [ "generic-array", ] +[[package]] +name = "aes" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e8b47f52ea9bae42228d07ec09eb676433d7c4ed1ebdf0f1d1c29ed446f1ab8" +dependencies = [ + "cfg-if", + "cipher 0.3.0", + "cpufeatures 0.2.12", + "opaque-debug", +] + [[package]] name = "aes" version = "0.8.4" @@ -44,8 +66,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" dependencies = [ "cfg-if", - "cipher", - "cpufeatures", + "cipher 0.4.4", + "cpufeatures 0.2.12", +] + +[[package]] +name = "aes-gcm" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc3be92e19a7ef47457b8e6f90707e12b6ac5d20c6f3866584fa3be0787d839f" +dependencies = [ + "aead 0.4.3", + "aes 0.7.5", + "cipher 0.3.0", + "ctr 0.7.0", + "ghash 0.4.4", + "subtle", ] [[package]] @@ -54,11 +90,11 @@ version = "0.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1" dependencies = [ - "aead", - "aes", - "cipher", - "ctr", - "ghash", + "aead 0.5.2", + "aes 0.8.4", + "cipher 0.4.4", + "ctr 0.9.2", + "ghash 0.5.1", "subtle", ] @@ -602,6 +638,16 @@ dependencies = [ "smallvec", ] +[[package]] +name = "bhttp" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ef06386f8f092c3419e153a657396e53cafbb901de445a5c54d96ab2ff8c7b2" +dependencies = [ + "thiserror", + "url", +] + [[package]] name = "bindgen" version = "0.69.4" @@ -652,6 +698,15 @@ dependencies = [ "wyz", ] +[[package]] +name = "block-buffer" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4152116fd6e9dadb291ae18fc1ec3575ed6d84c29642d97890f4b4a3417297e4" +dependencies = [ + "generic-array", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -782,6 +837,18 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chacha20" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fee7ad89dc1128635074c268ee661f90c3f7e83d9fd12910608c36b47d6c3412" +dependencies = [ + "cfg-if", + "cipher 0.3.0", + "cpufeatures 0.1.5", + "zeroize", +] + [[package]] name = "chacha20" version = "0.9.1" @@ -789,8 +856,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3613f74bd2eac03dad61bd53dbe620703d4371614fe0bc3b9f04dd36fe4e818" dependencies = [ "cfg-if", - "cipher", - "cpufeatures", + "cipher 0.4.4", + "cpufeatures 0.2.12", +] + +[[package]] +name = "chacha20poly1305" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1580317203210c517b6d44794abfbe600698276db18127e37ad3e69bf5e848e5" +dependencies = [ + "aead 0.4.3", + "chacha20 0.7.1", + "cipher 0.3.0", + "poly1305 0.7.2", + "zeroize", ] [[package]] @@ -799,10 +879,10 @@ version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "10cd79432192d1c0f4e1a0fef9527696cc039165d729fb41b3f4f4f354c2dc35" dependencies = [ - "aead", - "chacha20", - "cipher", - "poly1305", + "aead 0.5.2", + "chacha20 0.9.1", + "cipher 0.4.4", + "poly1305 0.8.0", "zeroize", ] @@ -827,6 +907,15 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8d18b093eba54c9aaa1e3784d4361eb2ba944cf7d0a932a830132238f483e8d8" +[[package]] +name = "cipher" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ee52072ec15386f770805afd189a01c8841be8696bed250fa2f13c4c0d6dfb7" +dependencies = [ + "generic-array", +] + [[package]] name = "cipher" version = "0.4.4" @@ -1001,6 +1090,15 @@ version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" +[[package]] +name = "cpufeatures" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66c99696f6c9dd7f35d486b9d04d7e6e202aa3e8c40d553f2fdf5e7e0c6a71ef" +dependencies = [ + "libc", +] + [[package]] name = "cpufeatures" version = "0.2.12" @@ -1106,13 +1204,32 @@ dependencies = [ "typenum", ] +[[package]] +name = "crypto-mac" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25fab6889090c8133f3deb8f73ba3c65a7f456f66436fc012a1b1e272b1e103e" +dependencies = [ + "generic-array", + "subtle", +] + +[[package]] +name = "ctr" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a232f92a03f37dd7d7dd2adc67166c77e9cd88de5b019b9a9eecfaeaf7bfd481" +dependencies = [ + "cipher 0.3.0", +] + [[package]] name = "ctr" version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" dependencies = [ - "cipher", + "cipher 0.4.4", ] [[package]] @@ -1277,7 +1394,7 @@ version = "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ - "block-buffer", + "block-buffer 0.10.4", "const-oid", "crypto-common", "subtle", @@ -1382,7 +1499,7 @@ dependencies = [ "ff", "generic-array", "group", - "hkdf", + "hkdf 0.12.4", "rand_core 0.6.4", "sec1 0.3.0", "subtle", @@ -1832,6 +1949,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "ghash" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1583cc1656d7839fd3732b80cf4f38850336cdb9b8ded1cd399ca62958de3c99" +dependencies = [ + "opaque-debug", + "polyval 0.5.3", +] + [[package]] name = "ghash" version = "0.5.1" @@ -1839,7 +1966,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1" dependencies = [ "opaque-debug", - "polyval", + "polyval 0.6.2", ] [[package]] @@ -2000,13 +2127,33 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6fe2267d4ed49bc07b63801559be28c718ea06c4738b7a03c94df7386d2cde46" +[[package]] +name = "hkdf" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01706d578d5c281058480e673ae4086a9f4710d8df1ad80a5b03e39ece5f886b" +dependencies = [ + "digest 0.9.0", + "hmac 0.11.0", +] + [[package]] name = "hkdf" version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7b5f8eb2ad728638ea2c7d47a21db23b7b58a72ed6a38256b8a1849f15fbbdf7" dependencies = [ - "hmac", + "hmac 0.12.1", +] + +[[package]] +name = "hmac" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a2a2320eb7ec0ebe8da8f744d7812d9fc4cb4d09344ac01898dbcb6a20ae69b" +dependencies = [ + "crypto-mac", + "digest 0.9.0", ] [[package]] @@ -2033,17 +2180,17 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf39e5461bfdc6ad0fbc97067519fcaf96a7a2e67f24cc0eb8a1e7c0c45af792" dependencies = [ - "aead", - "aes-gcm", + "aead 0.5.2", + "aes-gcm 0.10.3", "byteorder", - "chacha20poly1305", + "chacha20poly1305 0.10.1", "digest 0.10.7", "generic-array", - "hkdf", - "hmac", + "hkdf 0.12.4", + "hmac 0.12.1", "p256", "rand_core 0.6.4", - "sha2", + "sha2 0.10.8", "subtle", "x25519-dalek", "zeroize", @@ -2613,6 +2760,7 @@ version = "0.7.12" dependencies = [ "assert_matches", "backoff", + "bhttp", "derivative", "hex-literal", "http 1.1.0", @@ -2620,6 +2768,7 @@ dependencies = [ "janus_core", "janus_messages 0.7.12", "mockito", + "ohttp", "prio", "rand", "reqwest", @@ -2924,7 +3073,7 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ecc2af9a1119c51f12a14607e783cb977bde58bc069ff0c3da1095e635d70654" dependencies = [ - "cpufeatures", + "cpufeatures 0.2.12", ] [[package]] @@ -3346,6 +3495,29 @@ dependencies = [ "memchr", ] +[[package]] +name = "ohttp" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "578cb11a3fb5c85697ed8bb850d5ad1cbf819d3eea05c2b253cf1d240fbb10c5" +dependencies = [ + "aead 0.4.3", + "aes-gcm 0.9.2", + "byteorder", + "chacha20poly1305 0.8.0", + "hex", + "hkdf 0.11.0", + "hpke", + "lazy_static", + "log", + "rand", + "serde", + "serde_derive", + "sha2 0.9.9", + "thiserror", + "toml", +] + [[package]] name = "once_cell" version = "1.19.0" @@ -3635,7 +3807,7 @@ checksum = "934cd7631c050f4674352a6e835d5f6711ffbfb9345c2fc0107155ac495ae293" dependencies = [ "once_cell", "pest", - "sha2", + "sha2 0.10.8", ] [[package]] @@ -3757,15 +3929,38 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "poly1305" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "048aeb476be11a4b6ca432ca569e375810de9294ae78f4774e78ea98a9246ede" +dependencies = [ + "cpufeatures 0.2.12", + "opaque-debug", + "universal-hash 0.4.0", +] + [[package]] name = "poly1305" version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8159bd90725d2df49889a078b54f4f79e87f1f8a8444194cdca81d38f5393abf" dependencies = [ - "cpufeatures", + "cpufeatures 0.2.12", "opaque-debug", - "universal-hash", + "universal-hash 0.5.1", +] + +[[package]] +name = "polyval" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8419d2b623c7c0896ff2d5d96e2cb4ede590fed28fcc34934f4c33c036e620a1" +dependencies = [ + "cfg-if", + "cpufeatures 0.2.12", + "opaque-debug", + "universal-hash 0.4.0", ] [[package]] @@ -3775,9 +3970,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25" dependencies = [ "cfg-if", - "cpufeatures", + "cpufeatures 0.2.12", "opaque-debug", - "universal-hash", + "universal-hash 0.5.1", ] [[package]] @@ -3811,11 +4006,11 @@ dependencies = [ "byteorder", "bytes", "fallible-iterator", - "hmac", + "hmac 0.12.1", "md-5", "memchr", "rand", - "sha2", + "sha2 0.10.8", "stringprep", ] @@ -3874,15 +4069,15 @@ version = "0.16.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28d81f366140fb9dfdaf6f3be9b937bf598e5dbf8aafb9dbd182b1f3be196bc2" dependencies = [ - "aes", + "aes 0.8.4", "bitvec", "byteorder", - "ctr", + "ctr 0.9.2", "fiat-crypto", "fixed", "getrandom", "hex", - "hmac", + "hmac 0.12.1", "num-bigint", "num-integer", "num-iter", @@ -3893,7 +4088,7 @@ dependencies = [ "rayon", "serde", "serde_json", - "sha2", + "sha2 0.10.8", "sha3", "subtle", "thiserror", @@ -4742,10 +4937,23 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" dependencies = [ "cfg-if", - "cpufeatures", + "cpufeatures 0.2.12", "digest 0.10.7", ] +[[package]] +name = "sha2" +version = "0.9.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d58a1e1bf39749807d89cf2d98ac2dfa0ff1cb3faa38fbb64dd88ac8013d800" +dependencies = [ + "block-buffer 0.9.0", + "cfg-if", + "cpufeatures 0.2.12", + "digest 0.9.0", + "opaque-debug", +] + [[package]] name = "sha2" version = "0.10.8" @@ -4753,7 +4961,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" dependencies = [ "cfg-if", - "cpufeatures", + "cpufeatures 0.2.12", "digest 0.10.7", ] @@ -5018,7 +5226,7 @@ dependencies = [ "rustls-pemfile 1.0.4", "serde", "serde_json", - "sha2", + "sha2 0.10.8", "smallvec", "sqlformat", "thiserror", @@ -5057,7 +5265,7 @@ dependencies = [ "quote", "serde", "serde_json", - "sha2", + "sha2 0.10.8", "sqlx-core", "sqlx-mysql", "sqlx-postgres", @@ -5089,8 +5297,8 @@ dependencies = [ "futures-util", "generic-array", "hex", - "hkdf", - "hmac", + "hkdf 0.12.4", + "hmac 0.12.1", "itoa", "log", "md-5", @@ -5101,7 +5309,7 @@ dependencies = [ "rsa", "serde", "sha1", - "sha2", + "sha2 0.10.8", "smallvec", "sqlx-core", "stringprep", @@ -5128,8 +5336,8 @@ dependencies = [ "futures-io", "futures-util", "hex", - "hkdf", - "hmac", + "hkdf 0.12.4", + "hmac 0.12.1", "home", "itoa", "log", @@ -5139,7 +5347,7 @@ dependencies = [ "rand", "serde", "serde_json", - "sha2", + "sha2 0.10.8", "smallvec", "sqlx-core", "stringprep", @@ -5537,6 +5745,15 @@ dependencies = [ "tracing", ] +[[package]] +name = "toml" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4f7f0dd8d50a853a531c426359045b1998f04219d88799810762cd4ad314234" +dependencies = [ + "serde", +] + [[package]] name = "toml_datetime" version = "0.6.5" @@ -6136,6 +6353,16 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" +[[package]] +name = "universal-hash" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8326b2c654932e3e4f9196e69d08fdf7cfd718e1dc6f66b347e6024a0c961402" +dependencies = [ + "generic-array", + "subtle", +] + [[package]] name = "universal-hash" version = "0.5.1" diff --git a/Cargo.toml b/Cargo.toml index 9e929a80c..a61116902 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ assert_matches = "1" async-trait = "0.1" backoff = "0.4.0" base64 = "0.22.1" +bhttp = "0.5.1" bytes = "1" cfg-if = "1.0.0" # Disable default features to disable compatibility with the old `time` crate, and we also don't @@ -62,6 +63,7 @@ k8s-openapi = { version = "0.21.0", features = ["v1_26"] } # keep this version kube = { version = "0.90.0", default-features = false, features = ["client", "rustls-tls"] } mockito = "1.4.0" num_enum = "0.7.2" +ohttp = { version = "0.5.1", default-features = false } opentelemetry = { version = "0.22", features = ["metrics"] } opentelemetry-otlp = "0.15" opentelemetry-prometheus = "0.15" diff --git a/client/Cargo.toml b/client/Cargo.toml index 8307884d3..f9779ec64 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -9,13 +9,22 @@ repository.workspace = true rust-version.workspace = true version.workspace = true +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--cfg", "docsrs"] + +[features] +ohttp = ["dep:ohttp", "dep:bhttp"] + [dependencies] backoff = { workspace = true, features = ["tokio"] } +bhttp = { workspace = true, features = ["bhttp", "http"], optional = true } derivative.workspace = true http.workspace = true itertools.workspace = true janus_core.workspace = true janus_messages.workspace = true +ohttp = { workspace = true, default-features = false, features = ["client", "rust-hpke"], optional = true } prio.workspace = true rand.workspace = true reqwest = { workspace = true, features = ["json"] } @@ -29,5 +38,7 @@ assert_matches.workspace = true hex-literal = { workspace = true } janus_core = { workspace = true, features = ["test-util"] } mockito = { workspace = true } +ohttp = { workspace = true, default-features = true } +tokio.workspace = true tracing-log = { workspace = true } tracing-subscriber = { workspace = true, features = ["std", "env-filter", "fmt"] } diff --git a/client/src/lib.rs b/client/src/lib.rs index c470df2e2..b59a2d4cb 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -38,8 +38,12 @@ //! ``` use backoff::ExponentialBackoff; +#[cfg(feature = "ohttp")] +use bhttp::{ControlData, Message, Mode}; use derivative::Derivative; -use http::header::CONTENT_TYPE; +#[cfg(feature = "ohttp")] +use http::{header::ACCEPT, HeaderValue}; +use http::{header::CONTENT_TYPE, StatusCode}; use itertools::Itertools; use janus_core::{ hpke::{self, is_hpke_config_supported, HpkeApplicationInfo, Label}, @@ -52,13 +56,16 @@ use janus_messages::{ Duration, HpkeConfig, HpkeConfigList, InputShareAad, PlaintextInputShare, Report, ReportId, ReportMetadata, Role, TaskId, Time, }; +#[cfg(feature = "ohttp")] +use ohttp::{ClientRequest, KeyConfig}; use prio::{ codec::{Decode, Encode}, vdaf, }; use rand::random; +#[cfg(feature = "ohttp")] +use std::io::Cursor; use std::{convert::Infallible, fmt::Debug, time::SystemTimeError}; -use tokio::try_join; use url::Url; #[cfg(test)] @@ -84,6 +91,15 @@ pub enum Error { UnexpectedServerResponse(&'static str), #[error("time conversion error: {0}")] TimeConversion(#[from] SystemTimeError), + #[cfg(feature = "ohttp")] + #[error("OHTTP error: {0}")] + Ohttp(#[from] ohttp::Error), + #[cfg(feature = "ohttp")] + #[error("BHTTP error: {0}")] + Bhttp(#[from] bhttp::Error), + #[cfg(feature = "ohttp")] + #[error("No supported key configurations advertised by OHTTP gateway")] + OhttpNoSupportedKeyConfigs(Box>), } impl From for Error { @@ -92,6 +108,15 @@ impl From for Error { } } +impl From> for Error { + fn from(result: Result) -> Self { + match result { + Ok(http_error_response) => Error::Http(Box::new(http_error_response)), + Err(error) => error.into(), + } + } +} + static CLIENT_USER_AGENT: &str = concat!( env!("CARGO_PKG_NAME"), "/", @@ -100,6 +125,13 @@ static CLIENT_USER_AGENT: &str = concat!( "client" ); +#[cfg(feature = "ohttp")] +const OHTTP_KEYS_MEDIA_TYPE: &str = "application/ohttp-keys"; +#[cfg(feature = "ohttp")] +const OHTTP_REQUEST_MEDIA_TYPE: &str = "message/ohttp-req"; +#[cfg(feature = "ohttp")] +const OHTTP_RESPONSE_MEDIA_TYPE: &str = "message/ohttp-res"; + /// The DAP client's view of task parameters. #[derive(Clone, Derivative)] #[derivative(Debug)] @@ -166,21 +198,22 @@ impl ClientParameters { /// provided [`ClientParameters`]. #[tracing::instrument(err)] async fn aggregator_hpke_config( + hpke_config: Option, client_parameters: &ClientParameters, aggregator_role: &Role, http_client: &reqwest::Client, ) -> Result { + if let Some(hpke_config) = hpke_config { + return Ok(hpke_config); + } + let mut request_url = client_parameters.hpke_config_endpoint(aggregator_role)?; request_url.set_query(Some(&format!("task_id={}", client_parameters.task_id))); let hpke_config_response = retry_http_request( client_parameters.http_request_retry_parameters.clone(), || async { http_client.get(request_url.clone()).send().await }, ) - .await - .map_err(|err| match err { - Ok(http_error_response) => Error::Http(Box::new(http_error_response)), - Err(error) => error.into(), - })?; + .await?; let status = hpke_config_response.status(); if !status.is_success() { return Err(Error::Http(Box::new(HttpErrorResponse::from(status)))); @@ -211,6 +244,44 @@ async fn aggregator_hpke_config( Err(first_error.unwrap().into()) } +/// Fetches OHTTP HPKE key configurations for the provided OHTTP config. +#[tracing::instrument(err)] +#[cfg(feature = "ohttp")] +async fn ohttp_key_configs( + http_request_retry_parameters: ExponentialBackoff, + ohttp_config: &OhttpConfig, + http_client: &reqwest::Client, +) -> Result, Error> { + // TODO(#3159): store/fetch OHTTP key configs in a cache-control aware persistent cache. + let keys_response = retry_http_request(http_request_retry_parameters, || async { + http_client + .get(ohttp_config.key_configs.clone()) + .header(ACCEPT, OHTTP_KEYS_MEDIA_TYPE) + .send() + .await + }) + .await?; + + if !keys_response.status().is_success() { + return Err(Error::Http(Box::new(HttpErrorResponse::from( + keys_response.status(), + )))); + } + + if keys_response + .headers() + .get(CONTENT_TYPE) + .map(HeaderValue::as_bytes) + != Some(OHTTP_KEYS_MEDIA_TYPE.as_bytes()) + { + return Err(Error::UnexpectedServerResponse( + "content type wrong for OHTTP keys", + )); + } + + Ok(KeyConfig::decode_list(keys_response.body().as_ref())?) +} + /// Construct a [`reqwest::Client`] suitable for use in a DAP [`Client`]. pub fn default_http_client() -> Result { Ok(reqwest::Client::builder() @@ -222,10 +293,29 @@ pub fn default_http_client() -> Result { .build()?) } +/// Configuration for using Oblivious HTTP (RFC 9458). +#[derive(Clone, Debug)] +#[cfg_attr(docsrs, doc(cfg(feature = "ohttp")))] +#[cfg(feature = "ohttp")] +pub struct OhttpConfig { + /// Endpoint from which OHTTP gateway key configurations may be fetched. The key configurations + /// must be in the format specified by [RFC 9458, section 3][1]. + /// + /// [1]: https://datatracker.ietf.org/doc/html/rfc9458#name-key-configuration + pub key_configs: Url, + + /// The OHTTP relay which will relay encapsulated messages to the gateway. + pub relay: Url, +} + /// Builder for configuring a [`Client`]. pub struct ClientBuilder> { parameters: ClientParameters, vdaf: V, + leader_hpke_config: Option, + helper_hpke_config: Option, + #[cfg(feature = "ohttp")] + ohttp_config: Option, http_client: Option, } @@ -246,6 +336,10 @@ impl> ClientBuilder { time_precision, ), vdaf, + leader_hpke_config: None, + helper_hpke_config: None, + #[cfg(feature = "ohttp")] + ohttp_config: None, http_client: None, } } @@ -258,11 +352,38 @@ impl> ClientBuilder { } else { default_http_client()? }; - let (leader_hpke_config, helper_hpke_config) = try_join!( - aggregator_hpke_config(&self.parameters, &Role::Leader, &http_client), - aggregator_hpke_config(&self.parameters, &Role::Helper, &http_client) + // TODO(#3159): store/fetch HPKE configurations in a cache-control aware persistent cache + let (leader_hpke_config, helper_hpke_config) = tokio::try_join!( + aggregator_hpke_config( + self.leader_hpke_config, + &self.parameters, + &Role::Leader, + &http_client + ), + aggregator_hpke_config( + self.helper_hpke_config, + &self.parameters, + &Role::Helper, + &http_client + ), )?; + + #[cfg(feature = "ohttp")] + let ohttp_config = if let Some(ohttp_config) = self.ohttp_config { + let key_configs = ohttp_key_configs( + self.parameters.http_request_retry_parameters.clone(), + &ohttp_config, + &http_client, + ) + .await?; + Some((ohttp_config, key_configs)) + } else { + None + }; + Ok(Client { + #[cfg(feature = "ohttp")] + ohttp_config, parameters: self.parameters, vdaf: self.vdaf, http_client, @@ -273,6 +394,14 @@ impl> ClientBuilder { /// Finalize construction of a [`Client`], and provide aggregator HPKE configurations through an /// out-of-band mechanism. + /// + /// # Notes + /// + /// This method is not compatible with OHTTP . Use [`ClientBuilder::with_ohttp_config`] and then + /// [`ClientBuilder::build`] to provide OHTTP configuration. + #[deprecated( + note = "Use `ClientBuilder::with_leader_hpke_config`, `ClientBuilder::with_helper_hpke_config` and `ClientBuilder::build` instead" + )] pub fn build_with_hpke_configs( self, leader_hpke_config: HpkeConfig, @@ -286,6 +415,8 @@ impl> ClientBuilder { Ok(Client { parameters: self.parameters, vdaf: self.vdaf, + #[cfg(feature = "ohttp")] + ohttp_config: None, http_client, leader_hpke_config, helper_hpke_config, @@ -303,6 +434,61 @@ impl> ClientBuilder { self.parameters.http_request_retry_parameters = http_request_retry_parameters; self } + + /// Set the leader HPKE configuration to be used, preventing the client from fetching it from + /// the aggregator over HTTPS. + pub fn with_leader_hpke_config(mut self, hpke_config: HpkeConfig) -> Self { + self.leader_hpke_config = Some(hpke_config); + self + } + + /// Set the helper HPKE configuration to be used, preventing the client from fetching it from + /// the aggregator over HTTPS. + pub fn with_helper_hpke_config(mut self, hpke_config: HpkeConfig) -> Self { + self.helper_hpke_config = Some(hpke_config); + self + } + + /// Set the OHTTP configuration to be used when uploading reports, but not when fetching DAP + /// HPKE configurations. + /// + /// # Examples + /// + /// ```no_run + /// # use url::Url; + /// # use prio::vdaf::prio3::Prio3Count; + /// # use janus_messages::{Duration, TaskId}; + /// # use rand::random; + /// # use std::str::FromStr; + /// + /// #[tokio::main] + /// async fn main() { + /// let task_id = random(); + /// + /// let client = janus_client::Client::builder( + /// task_id, + /// Url::parse("https://leader.example.com/").unwrap(), + /// Url::parse("https://helper.example.com/").unwrap(), + /// Duration::from_seconds(1), + /// Prio3Count::new_count(2).unwrap(), + /// ) + /// .with_ohttp_config(janus_client::OhttpConfig { + /// key_configs: Url::parse("https://ohttp-keys.example.com").unwrap(), + /// relay: Url::parse("https://ohttp-relay.example.com").unwrap(), + /// }) + /// .build() + /// .await + /// .unwrap(); + /// + /// client.upload(&true).await.unwrap(); + /// } + /// ``` + #[cfg(feature = "ohttp")] + #[cfg_attr(docsrs, doc(cfg(feature = "ohttp")))] + pub fn with_ohttp_config(mut self, ohttp_config: OhttpConfig) -> Self { + self.ohttp_config = Some(ohttp_config); + self + } } /// A DAP client. @@ -310,6 +496,8 @@ impl> ClientBuilder { pub struct Client> { parameters: ClientParameters, vdaf: V, + #[cfg(feature = "ohttp")] + ohttp_config: Option<(OhttpConfig, Vec)>, http_client: reqwest::Client, leader_hpke_config: HpkeConfig, helper_hpke_config: HpkeConfig, @@ -337,6 +525,14 @@ impl> Client { /// Construct a new client, and provide the aggregator HPKE configurations through an /// out-of-band means. + /// + /// # Notes + /// + /// This method is not compatible with OHTTP. Use [`ClientBuilder::with_ohttp_config`] and then + /// [`ClientBuilder::build`] to provide OHTTP configuration. + #[deprecated( + note = "Use `ClientBuilder::with_leader_hpke_config`, `ClientBuilder::with_helper_hpke_config` and `ClientBuilder::build` instead" + )] pub fn with_hpke_configs( task_id: TaskId, leader_aggregator_endpoint: Url, @@ -346,6 +542,7 @@ impl> Client { leader_hpke_config: HpkeConfig, helper_hpke_config: HpkeConfig, ) -> Result { + #[allow(deprecated)] ClientBuilder::new( task_id, leader_aggregator_endpoint, @@ -476,28 +673,127 @@ impl> Client { let upload_endpoint = self .parameters .reports_resource_uri(&self.parameters.task_id)?; - let upload_response = retry_http_request( + + #[cfg(feature = "ohttp")] + let upload_status = self.upload_with_ohttp(&upload_endpoint, &report).await?; + #[cfg(not(feature = "ohttp"))] + let upload_status = self.put_report(&upload_endpoint, &report).await?; + + if !upload_status.is_success() { + return Err(Error::Http(Box::new(HttpErrorResponse::from( + upload_status, + )))); + } + + Ok(()) + } + + async fn put_report( + &self, + upload_endpoint: &Url, + request_body: &[u8], + ) -> Result { + Ok(retry_http_request( self.parameters.http_request_retry_parameters.clone(), || async { self.http_client .put(upload_endpoint.clone()) .header(CONTENT_TYPE, Report::MEDIA_TYPE) - .body(report.clone()) + .body(request_body.to_vec()) .send() .await }, ) - .await - .map_err(|err| match err { - Ok(http_error_response) => Error::Http(Box::new(http_error_response)), - Err(error) => error.into(), - })?; + .await? + .status()) + } + + /// Send a DAP upload request via OHTTP, if the client is configured to use it, or directly if + /// not. This is only intended for DAP uploads and so does not handle response bodies. + #[cfg(feature = "ohttp")] + #[tracing::instrument(skip(self), err)] + async fn upload_with_ohttp( + &self, + upload_endpoint: &Url, + request_body: &[u8], + ) -> Result { + let (ohttp_config, key_configs) = + if let Some((ohttp_config, key_configs)) = &self.ohttp_config { + (ohttp_config, key_configs) + } else { + return self.put_report(upload_endpoint, request_body).await; + }; + + // Construct a Message representing the upload request... + let mut message = Message::request( + "PUT".into(), + upload_endpoint.scheme().into(), + upload_endpoint.authority().into(), + upload_endpoint.path().into(), + ); + message.put_header(CONTENT_TYPE.as_str(), Report::MEDIA_TYPE); + message.write_content(request_body); + + // ...get the BHTTP encoding of the message... + let mut request_buf = Vec::new(); + message.write_bhttp(Mode::KnownLength, &mut request_buf)?; + + // ...and OHTTP encapsulate it to the gateway key config. + let ohttp_request = key_configs + .iter() + .cloned() + .find_map(|mut key_config| ClientRequest::from_config(&mut key_config).ok()) + .ok_or_else(|| Error::OhttpNoSupportedKeyConfigs(Box::new(key_configs.to_vec())))?; + + let (encapsulated_request, ohttp_response) = ohttp_request.encapsulate(&request_buf)?; + + let relay_response = retry_http_request( + self.parameters.http_request_retry_parameters.clone(), + || async { + self.http_client + .post(ohttp_config.relay.clone()) + .header(CONTENT_TYPE, OHTTP_REQUEST_MEDIA_TYPE) + .header(ACCEPT, OHTTP_RESPONSE_MEDIA_TYPE) + .body(encapsulated_request.clone()) + .send() + .await + }, + ) + .await?; + + // Check whether request to the OHTTP relay was successful, and if so, decapsulate that + // response to get the DAP aggregator's response. + if !relay_response.status().is_success() { + return Err(Error::Http(Box::new(HttpErrorResponse::from( + relay_response.status(), + )))); + } - let status = upload_response.status(); - if !status.is_success() { - return Err(Error::Http(Box::new(HttpErrorResponse::from(status)))); + if relay_response + .headers() + .get(CONTENT_TYPE) + .map(HeaderValue::as_bytes) + != Some(OHTTP_RESPONSE_MEDIA_TYPE.as_bytes()) + { + return Err(Error::UnexpectedServerResponse( + "content type wrong for OHTTP response", + )); } - Ok(()) + let decapsulated_response = ohttp_response.decapsulate(relay_response.body().as_ref())?; + let message = Message::read_bhttp(&mut Cursor::new(&decapsulated_response))?; + let status = if let ControlData::Response(status) = message.control() { + StatusCode::from_u16(*status).map_err(|_| { + Error::UnexpectedServerResponse( + "status in decapsulated response is not valid HTTP status", + ) + })? + } else { + return Err(Error::UnexpectedServerResponse( + "decapsulated response control data is not a response", + )); + }; + + Ok(status) } } diff --git a/client/src/tests/mod.rs b/client/src/tests/mod.rs index 935b1e8ce..e47972433 100644 --- a/client/src/tests/mod.rs +++ b/client/src/tests/mod.rs @@ -15,7 +15,10 @@ use prio::{ use rand::random; use url::Url; -fn setup_client>(server: &mockito::Server, vdaf: V) -> Client { +#[cfg(feature = "ohttp")] +mod ohttp; + +async fn setup_client>(server: &mockito::Server, vdaf: V) -> Client { let server_url = Url::parse(&server.url()).unwrap(); Client::builder( random(), @@ -25,10 +28,10 @@ fn setup_client>(server: &mockito::Server, vdaf: V) -> Clien vdaf, ) .with_backoff(test_http_request_exponential_backoff()) - .build_with_hpke_configs( - generate_test_hpke_config_and_private_key().config().clone(), - generate_test_hpke_config_and_private_key().config().clone(), - ) + .with_leader_hpke_config(generate_test_hpke_config_and_private_key().config().clone()) + .with_helper_hpke_config(generate_test_hpke_config_and_private_key().config().clone()) + .build() + .await .unwrap() } @@ -55,7 +58,7 @@ fn aggregator_endpoints_end_in_slash() { async fn upload_prio3_count() { install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; - let client = setup_client(&server, Prio3::new_count(2).unwrap()); + let client = setup_client(&server, Prio3::new_count(2).unwrap()).await; let mocked_upload = server .mock( @@ -78,7 +81,7 @@ async fn upload_prio3_invalid_measurement() { install_test_trace_subscriber(); let server = mockito::Server::new_async().await; let vdaf = Prio3::new_sum(2, 16).unwrap(); - let client = setup_client(&server, vdaf); + let client = setup_client(&server, vdaf).await; // 65536 is too big for a 16 bit sum and will be rejected by the VDAF. // Make sure we get the right error variant but otherwise we aren't @@ -90,7 +93,7 @@ async fn upload_prio3_invalid_measurement() { async fn upload_prio3_http_status_code() { install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; - let client = setup_client(&server, Prio3::new_count(2).unwrap()); + let client = setup_client(&server, Prio3::new_count(2).unwrap()).await; let mocked_upload = server .mock( @@ -117,7 +120,7 @@ async fn upload_prio3_http_status_code() { async fn upload_problem_details() { install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; - let client = setup_client(&server, Prio3::new_count(2).unwrap()); + let client = setup_client(&server, Prio3::new_count(2).unwrap()).await; let mocked_upload = server .mock( @@ -165,21 +168,21 @@ async fn upload_bad_time_precision() { Duration::from_seconds(0), Prio3::new_count(2).unwrap(), ) - .build_with_hpke_configs( - generate_test_hpke_config_and_private_key().config().clone(), - generate_test_hpke_config_and_private_key().config().clone(), - ) + .with_leader_hpke_config(generate_test_hpke_config_and_private_key().config().clone()) + .with_helper_hpke_config(generate_test_hpke_config_and_private_key().config().clone()) + .build() + .await .unwrap(); let result = client.upload(&true).await; assert_matches!(result, Err(Error::InvalidParameter(_))); } -#[test] -fn report_timestamp() { +#[tokio::test] +async fn report_timestamp() { install_test_trace_subscriber(); - let server = mockito::Server::new(); + let server = mockito::Server::new_async().await; let vdaf = Prio3::new_count(2).unwrap(); - let mut client = setup_client(&server, vdaf); + let mut client = setup_client(&server, vdaf).await; client.parameters.time_precision = Duration::from_seconds(100); assert_eq!( @@ -238,9 +241,21 @@ async fn aggregator_hpke() { .create_async() .await; - let got_hpke_config = aggregator_hpke_config(&client_parameters, &Role::Leader, http_client) - .await - .unwrap(); + let got_hpke_config = + aggregator_hpke_config(None, &client_parameters, &Role::Leader, http_client) + .await + .unwrap(); + assert_eq!(&got_hpke_config, keypair.config()); + + // Fetching HPKE config again should not hit the mock server + let got_hpke_config = aggregator_hpke_config( + Some(got_hpke_config), + &client_parameters, + &Role::Leader, + http_client, + ) + .await + .unwrap(); assert_eq!(&got_hpke_config, keypair.config()); mock.assert_async().await; @@ -294,9 +309,10 @@ async fn unsupported_hpke_algorithms() { .create_async() .await; - let got_hpke_config = aggregator_hpke_config(&client_parameters, &Role::Leader, http_client) - .await - .unwrap(); + let got_hpke_config = + aggregator_hpke_config(None, &client_parameters, &Role::Leader, http_client) + .await + .unwrap(); assert_eq!(got_hpke_config, good_hpke_config); mock.assert_async().await; diff --git a/client/src/tests/ohttp.rs b/client/src/tests/ohttp.rs new file mode 100644 index 000000000..f36657822 --- /dev/null +++ b/client/src/tests/ohttp.rs @@ -0,0 +1,329 @@ +use std::io::Cursor; + +use crate::{ + Client, Error, OhttpConfig, OHTTP_KEYS_MEDIA_TYPE, OHTTP_REQUEST_MEDIA_TYPE, + OHTTP_RESPONSE_MEDIA_TYPE, +}; +use assert_matches::assert_matches; +use bhttp::{Message, Mode}; +use http::header::{ACCEPT, CONTENT_TYPE}; +use janus_core::{ + hpke::test_util::generate_test_hpke_config_and_private_key, http::HttpErrorResponse, + retries::test_util::test_http_request_exponential_backoff, + test_util::install_test_trace_subscriber, +}; +use janus_messages::{Duration, Report}; +use ohttp::{ + hpke::{Aead, Kdf}, + KeyConfig, SymmetricSuite, +}; +use prio::{ + codec::Decode, + vdaf::prio3::{Prio3, Prio3Count}, +}; +use rand::random; +use url::Url; + +async fn build_client(server: &mockito::ServerGuard) -> Result, Error> { + let task_id = random(); + let server_url = Url::parse(&server.url()).unwrap(); + let keys_endpoint = Url::parse(format!("{}/ohttp-keys", server.url()).as_str()).unwrap(); + let relay = Url::parse(format!("{}/relay", server.url()).as_str()).unwrap(); + + Client::builder( + task_id, + server_url.clone(), + server_url.clone(), + Duration::from_seconds(1), + Prio3::new_count(2).unwrap(), + ) + .with_backoff(test_http_request_exponential_backoff()) + .with_leader_hpke_config(generate_test_hpke_config_and_private_key().config().clone()) + .with_helper_hpke_config(generate_test_hpke_config_and_private_key().config().clone()) + .with_ohttp_config(OhttpConfig { + key_configs: keys_endpoint, + relay, + }) + .build() + .await +} + +async fn mocked_ohttp_keys(server: &mut mockito::ServerGuard) -> (mockito::Mock, ohttp::Server) { + let key_config = KeyConfig::new( + 0, + ohttp::hpke::Kem::X25519Sha256, + vec![SymmetricSuite::new(Kdf::HkdfSha256, Aead::ChaCha20Poly1305)], + ) + .unwrap(); + let encoded_key_config = KeyConfig::encode_list(&[&key_config]).unwrap(); + + let mocked_ohttp_keys = server + .mock("GET", "/ohttp-keys") + .match_header(ACCEPT.as_str(), OHTTP_KEYS_MEDIA_TYPE) + .expect(1) + .with_status(200) + .with_header(CONTENT_TYPE.as_str(), OHTTP_KEYS_MEDIA_TYPE) + .with_body(encoded_key_config) + .create_async() + .await; + + (mocked_ohttp_keys, ohttp::Server::new(key_config).unwrap()) +} + +#[tokio::test] +async fn successful_upload() { + install_test_trace_subscriber(); + let mut server = mockito::Server::new_async().await; + + let (mocked_ohttp_keys, ohttp_server) = mocked_ohttp_keys(&mut server).await; + let client = build_client(&server).await.unwrap(); + + let mocked_ohttp_upload = server + .mock("POST", "/relay") + .match_header(ACCEPT.as_str(), OHTTP_RESPONSE_MEDIA_TYPE) + .match_header(CONTENT_TYPE.as_str(), OHTTP_REQUEST_MEDIA_TYPE) + .expect(1) + .with_status(200) + .with_header(CONTENT_TYPE.as_str(), OHTTP_RESPONSE_MEDIA_TYPE) + .with_body_from_request(move |request| { + let encapsulated_req = request.body().unwrap(); + let (decapsulated_req, server_response) = + ohttp_server.decapsulate(encapsulated_req.as_ref()).unwrap(); + let bin_request = Message::read_bhttp(&mut Cursor::new(&decapsulated_req[..])).unwrap(); + + // Check that encapsulated request is a correct DAP upload + assert_eq!( + bin_request + .control() + .method() + .map(|a| String::from_utf8(a.to_vec()).unwrap()), + Some("PUT".to_string()), + ); + assert_eq!( + bin_request + .control() + .scheme() + .map(|a| String::from_utf8(a.to_vec()).unwrap()), + Some("http".to_string()), + ); + assert_eq!( + bin_request + .control() + .authority() + .map(|a| String::from_utf8(a.to_vec()).unwrap()), + Some(Url::parse(&server.url()).unwrap().authority().to_string()), + ); + assert_eq!( + bin_request + .control() + .path() + .map(|a| String::from_utf8(a.to_vec()).unwrap()), + Some(format!("/tasks/{}/reports", client.parameters.task_id)), + ); + + assert_eq!( + bin_request + .header() + .get(b"content-type") + .map(|a| String::from_utf8(a.to_vec()).unwrap()), + Some(Report::MEDIA_TYPE.to_string()), + ); + + Report::get_decoded(bin_request.content()).unwrap(); + + // Construct a 200 OK response to the encapsulated request, then encapsulate that + let mut response = Vec::new(); + Message::response(200) + .write_bhttp(Mode::KnownLength, &mut response) + .unwrap(); + server_response.encapsulate(&response).unwrap() + }) + .create_async() + .await; + + client.upload(&true).await.unwrap(); + + mocked_ohttp_keys.assert_async().await; + mocked_ohttp_upload.assert_async().await; +} + +#[tokio::test] +async fn ohttp_keyconfigs_http_error() { + install_test_trace_subscriber(); + let mut server = mockito::Server::new_async().await; + + let mocked_ohttp_keys = server + .mock("GET", "/ohttp-keys") + .match_header(ACCEPT.as_str(), OHTTP_KEYS_MEDIA_TYPE) + .expect(1) + .with_status(400) + .with_header(CONTENT_TYPE.as_str(), OHTTP_KEYS_MEDIA_TYPE) + .create_async() + .await; + + let error = build_client(&server).await.unwrap_err(); + assert_matches!(error, Error::Http(_)); + + mocked_ohttp_keys.assert_async().await; +} + +#[tokio::test] +async fn ohttp_keyconfigs_malformed_response_body() { + install_test_trace_subscriber(); + let mut server = mockito::Server::new_async().await; + + let mocked_ohttp_keys = server + .mock("GET", "/ohttp-keys") + .match_header(ACCEPT.as_str(), OHTTP_KEYS_MEDIA_TYPE) + .expect(1) + .with_status(200) + .with_header(CONTENT_TYPE.as_str(), OHTTP_KEYS_MEDIA_TYPE) + // Not a valid KeyConfigList + .with_body(vec![0, 1, 2, 3]) + .create_async() + .await; + + let error = build_client(&server).await.unwrap_err(); + + assert_matches!(error, Error::Ohttp(_)); + + mocked_ohttp_keys.assert_async().await; +} + +#[tokio::test] +async fn ohttp_keyconfigs_wrong_content_type() { + install_test_trace_subscriber(); + let mut server = mockito::Server::new_async().await; + + let key_config = KeyConfig::new( + 0, + ohttp::hpke::Kem::X25519Sha256, + vec![SymmetricSuite::new(Kdf::HkdfSha256, Aead::ChaCha20Poly1305)], + ) + .unwrap(); + let encoded_key_config = KeyConfig::encode_list(&[&key_config]).unwrap(); + + let mocked_ohttp_keys = server + .mock("GET", "/ohttp-keys") + .match_header(ACCEPT.as_str(), OHTTP_KEYS_MEDIA_TYPE) + .expect(1) + .with_status(200) + .with_header(CONTENT_TYPE.as_str(), "application/wrong-type") + .with_body(encoded_key_config) + .create_async() + .await; + + let error = build_client(&server).await.unwrap_err(); + + assert_matches!(error, Error::UnexpectedServerResponse(_)); + + mocked_ohttp_keys.assert_async().await; +} + +#[tokio::test] +async fn http_client_error_from_relay() { + install_test_trace_subscriber(); + let mut server = mockito::Server::new_async().await; + + let (mocked_ohttp_keys, _) = mocked_ohttp_keys(&mut server).await; + let client = build_client(&server).await.unwrap(); + + let mocked_ohttp_upload = server + .mock("POST", "/relay") + .match_header(ACCEPT.as_str(), OHTTP_RESPONSE_MEDIA_TYPE) + .match_header(CONTENT_TYPE.as_str(), OHTTP_REQUEST_MEDIA_TYPE) + .expect(1) + .with_status(400) + .create_async() + .await; + + let error = client.upload(&true).await.unwrap_err(); + + assert_matches!(error, Error::Http(boxed) => { + assert_matches!(boxed.as_ref(), HttpErrorResponse {..}); + }); + + mocked_ohttp_keys.assert_async().await; + mocked_ohttp_upload.assert_async().await; +} + +#[tokio::test] +async fn http_client_error_from_target() { + install_test_trace_subscriber(); + let mut server = mockito::Server::new_async().await; + + let (mocked_ohttp_keys, ohttp_server) = mocked_ohttp_keys(&mut server).await; + let client = build_client(&server).await.unwrap(); + + let mocked_ohttp_upload = server + .mock("POST", "/relay") + .match_header(ACCEPT.as_str(), OHTTP_RESPONSE_MEDIA_TYPE) + .match_header(CONTENT_TYPE.as_str(), OHTTP_REQUEST_MEDIA_TYPE) + .expect(1) + .with_status(200) + .with_header(CONTENT_TYPE.as_str(), OHTTP_RESPONSE_MEDIA_TYPE) + .with_body_from_request(move |request| { + let encapsulated_req = request.body().unwrap(); + let (_, server_response) = ohttp_server.decapsulate(encapsulated_req.as_ref()).unwrap(); + + // Construct a 400 Client Error response to the encapsulated request, then encapsulate that + let mut response = Vec::new(); + Message::response(400) + .write_bhttp(Mode::KnownLength, &mut response) + .unwrap(); + server_response.encapsulate(&response).unwrap() + }) + .create_async() + .await; + + let error = client.upload(&true).await.unwrap_err(); + + assert_matches!(error, Error::Http(boxed) => { + assert_matches!(boxed.as_ref(), HttpErrorResponse {..}); + }); + + mocked_ohttp_keys.assert_async().await; + mocked_ohttp_upload.assert_async().await; +} + +#[tokio::test] +async fn encapsulated_server_message_is_http_request() { + install_test_trace_subscriber(); + let mut server = mockito::Server::new_async().await; + + let (mocked_ohttp_keys, ohttp_server) = mocked_ohttp_keys(&mut server).await; + let client = build_client(&server).await.unwrap(); + + let mocked_ohttp_upload = server + .mock("POST", "/relay") + .match_header(ACCEPT.as_str(), OHTTP_RESPONSE_MEDIA_TYPE) + .match_header(CONTENT_TYPE.as_str(), OHTTP_REQUEST_MEDIA_TYPE) + .expect(1) + .with_status(200) + .with_header(CONTENT_TYPE.as_str(), OHTTP_RESPONSE_MEDIA_TYPE) + .with_body_from_request(move |request| { + let encapsulated_req = request.body().unwrap(); + let (_, server_response) = ohttp_server.decapsulate(encapsulated_req.as_ref()).unwrap(); + + // Construct an encapsulated *request* + let mut response = Vec::new(); + Message::request( + b"GET".to_vec(), + b"http".to_vec(), + b"example.com".to_vec(), + b"/something".to_vec(), + ) + .write_bhttp(Mode::KnownLength, &mut response) + .unwrap(); + server_response.encapsulate(&response).unwrap() + }) + .create_async() + .await; + + let error = client.upload(&true).await.unwrap_err(); + + assert_matches!(error, Error::UnexpectedServerResponse(_)); + + mocked_ohttp_keys.assert_async().await; + mocked_ohttp_upload.assert_async().await; +}