diff --git a/.circleci/config.yml b/.circleci/config.yml index 1a67d4e18c..15915bbb1e 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -48,13 +48,13 @@ commands: - run: name: Rust Clippy MySQL command: | - cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/mysql -- -D warnings + cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/mysql --features=py_verifier -- -D warnings rust-clippy-spanner: steps: - run: name: Rust Clippy Spanner command: | - cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/spanner -- -D warnings + cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/spanner --features=py_verifier -- -D warnings cargo-build: steps: - run: diff --git a/Cargo.lock b/Cargo.lock index f2f438d108..e6bbf70c9d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1068,8 +1068,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi", + "wasm-bindgen", ] [[package]] @@ -1439,6 +1441,19 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "jsonwebtoken" +version = "9.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c7ea04a7c5c055c175f189b6dc6ba036fd62306b58c66c9f6389036c503a3f4" +dependencies = [ + "base64", + "js-sys", + "ring", + "serde 1.0.195", + "serde_json", +] + [[package]] name = "language-tags" version = "0.3.2" @@ -3032,14 +3047,23 @@ name = "tokenserver-auth" version = "0.14.4" dependencies = [ "async-trait", + "base64", "dyn-clone", "futures 0.3.30", + "hex", + "hkdf", + "hmac", + "jsonwebtoken", "mockito", "pyo3", "reqwest", + "ring", "serde 1.0.195", "serde_json", + "sha2", + "slog-scope", "syncserver-common", + "thiserror", "tokenserver-common", "tokenserver-settings", "tokio", @@ -3051,6 +3075,7 @@ version = "0.14.4" dependencies = [ "actix-web", "backtrace", + "jsonwebtoken", "serde 1.0.195", "serde_json", "syncserver-common", @@ -3086,6 +3111,7 @@ dependencies = [ name = "tokenserver-settings" version = "0.14.4" dependencies = [ + "jsonwebtoken", "serde 1.0.195", "tokenserver-common", ] diff --git a/Cargo.toml b/Cargo.toml index fd875a0709..44107ceeec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,7 +38,10 @@ docopt = "1.1" env_logger = "0.10" futures = { version = "0.3", features = ["compat"] } hex = "0.4" +hkdf = "0.12" +hmac = "0.12" http = "0.2" +jsonwebtoken = { version = "9.2", default-features = false } lazy_static = "1.4" protobuf = "=2.25.2" # pin to 2.25.2 to prevent side updating rand = "0.8" @@ -62,6 +65,7 @@ slog-scope = "4.3" slog-stdlog = "4.1" slog-term = "2.6" tokio = "1" +thiserror = "1.0.26" [profile.release] # Enables line numbers in Sentry reporting diff --git a/Dockerfile b/Dockerfile index 8f71ce7f34..2abdc5696a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -19,7 +19,7 @@ RUN \ apt-get -q install -y --no-install-recommends libmysqlclient-dev cmake COPY --from=planner /app/recipe.json recipe.json -RUN cargo chef cook --release --no-default-features --features=syncstorage-db/$DATABASE_BACKEND --recipe-path recipe.json +RUN cargo chef cook --release --no-default-features --features=syncstorage-db/$DATABASE_BACKEND --features=py_verifier --recipe-path recipe.json FROM chef as builder ARG DATABASE_BACKEND=spanner @@ -46,7 +46,7 @@ ENV PATH=$PATH:/root/.cargo/bin RUN \ cargo --version && \ rustc --version && \ - cargo install --path ./syncserver --no-default-features --features=syncstorage-db/$DATABASE_BACKEND --locked --root /app && \ + cargo install --path ./syncserver --no-default-features --features=syncstorage-db/$DATABASE_BACKEND --features=py_verifier --locked --root /app && \ if [ "$DATABASE_BACKEND" = "spanner" ] ; then cargo install --path ./syncstorage-spanner --locked --root /app --bin purge_ttl ; fi FROM docker.io/library/debian:bullseye-slim @@ -56,11 +56,12 @@ COPY --from=builder /app/requirements.txt /app # have to set this env var to prevent the cryptography package from building # with Rust. See this link for more information: # https://pythonshowcase.com/question/problem-installing-cryptography-on-raspberry-pi +ENV CRYPTOGRAPHY_DONT_BUILD_RUST=1 RUN \ apt-get -q update && apt-get -qy install wget -ENV CRYPTOGRAPHY_DONT_BUILD_RUST=1 + RUN \ groupadd --gid 10001 app && \ useradd --uid 10001 --gid 10001 --home /app --create-home app && \ diff --git a/Makefile b/Makefile index a0540f6b53..73b95c7171 100644 --- a/Makefile +++ b/Makefile @@ -15,11 +15,11 @@ PYTHON_SITE_PACKGES = $(shell $(SRC_ROOT)/venv/bin/python -c "from distutils.sys clippy_mysql: # Matches what's run in circleci - cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/mysql -- -D warnings + cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/mysql --features=py_verifier -- -D warnings clippy_spanner: # Matches what's run in circleci - cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/spanner -- -D warnings + cargo clippy --workspace --all-targets --no-default-features --features=syncstorage-db/spanner --features=py_verifier -- -D warnings clean: cargo clean @@ -47,14 +47,15 @@ python: python3 -m venv venv venv/bin/python -m pip install -r requirements.txt + run_mysql: python PATH="./venv/bin:$(PATH)" \ # See https://github.com/PyO3/pyo3/issues/1741 for discussion re: why we need to set the # below env var PYTHONPATH=$(PYTHON_SITE_PACKGES) \ - RUST_LOG=debug \ + RUST_LOG=debug \ RUST_BACKTRACE=full \ - cargo run --no-default-features --features=syncstorage-db/mysql -- --config config/local.toml + cargo run --no-default-features --features=syncstorage-db/mysql --features=py_verifier -- --config config/local.toml run_spanner: python GOOGLE_APPLICATION_CREDENTIALS=$(PATH_TO_SYNC_SPANNER_KEYS) \ @@ -65,7 +66,7 @@ run_spanner: python PATH="./venv/bin:$(PATH)" \ RUST_LOG=debug \ RUST_BACKTRACE=full \ - cargo run --no-default-features --features=syncstorage-db/spanner -- --config config/local.toml + cargo run --no-default-features --features=syncstorage-db/spanner --features=py_verifier -- --config config/local.toml test: SYNC_SYNCSTORAGE__DATABASE_URL=mysql://sample_user:sample_password@localhost/syncstorage_rs \ diff --git a/syncserver-common/Cargo.toml b/syncserver-common/Cargo.toml index d3a3f44355..1f5f507c24 100644 --- a/syncserver-common/Cargo.toml +++ b/syncserver-common/Cargo.toml @@ -14,5 +14,5 @@ serde_json.workspace = true slog.workspace = true slog-scope.workspace = true actix-web.workspace = true +hkdf.workspace = true -hkdf = "0.12" diff --git a/syncserver-db-common/Cargo.toml b/syncserver-db-common/Cargo.toml index c9a1783994..2b43ccbd18 100644 --- a/syncserver-db-common/Cargo.toml +++ b/syncserver-db-common/Cargo.toml @@ -9,9 +9,9 @@ edition.workspace=true backtrace.workspace=true futures.workspace=true http.workspace=true +thiserror.workspace=true deadpool = { git = "https://github.com/mozilla-services/deadpool", tag = "deadpool-v0.7.0" } diesel = { version = "1.4", features = ["mysql", "r2d2"] } diesel_migrations = { version = "1.4.0", features = ["mysql"] } syncserver-common = { path = "../syncserver-common" } -thiserror = "1.0.26" diff --git a/syncserver/Cargo.toml b/syncserver/Cargo.toml index 76f191d897..28623de282 100644 --- a/syncserver/Cargo.toml +++ b/syncserver/Cargo.toml @@ -32,6 +32,8 @@ slog-mozlog-json.workspace = true slog-scope.workspace = true slog-stdlog.workspace = true slog-term.workspace = true +hmac.workspace = true +thiserror.workspace = true actix-http = "3" actix-rt = "2" @@ -40,7 +42,6 @@ async-trait = "0.1.40" dyn-clone = "1.0.4" hostname = "0.3.1" hawk = "5.0" -hmac = "0.12" mime = "0.3" reqwest = { workspace = true, features = [ "json", @@ -53,8 +54,7 @@ syncserver-settings = { path = "../syncserver-settings" } syncstorage-db = { path = "../syncstorage-db" } syncstorage-settings = { path = "../syncstorage-settings" } time = "^0.3" -thiserror = "1.0.26" -tokenserver-auth = { path = "../tokenserver-auth" } +tokenserver-auth = { path = "../tokenserver-auth", default-features = false} tokenserver-common = { path = "../tokenserver-common" } tokenserver-db = { path = "../tokenserver-db" } tokenserver-settings = { path = "../tokenserver-settings" } @@ -65,7 +65,8 @@ validator_derive = "0.16" woothee = "0.13" [features] -default = ["mysql"] +default = ["mysql", "py_verifier"] no_auth = [] +py_verifier = ["tokenserver-auth/py"] mysql = ["syncstorage-db/mysql"] spanner = ["syncstorage-db/spanner"] diff --git a/syncserver/src/tokenserver/extractors.rs b/syncserver/src/tokenserver/extractors.rs index e3fe816371..a250812aff 100644 --- a/syncserver/src/tokenserver/extractors.rs +++ b/syncserver/src/tokenserver/extractors.rs @@ -457,7 +457,8 @@ impl FromRequest for AuthData { let mut tags = HashMap::default(); tags.insert("token_type".to_owned(), "BrowserID".to_owned()); metrics.start_timer("token_verification", Some(tags)); - let verify_output = state.browserid_verifier.verify(assertion).await?; + let verify_output = + state.browserid_verifier.verify(assertion, &metrics).await?; // For requests using BrowserID, the client state is embedded in the // X-Client-State header, and the generation and keys_changed_at are extracted @@ -487,7 +488,7 @@ impl FromRequest for AuthData { let mut tags = HashMap::default(); tags.insert("token_type".to_owned(), "OAuth".to_owned()); metrics.start_timer("token_verification", Some(tags)); - let verify_output = state.oauth_verifier.verify(token).await?; + let verify_output = state.oauth_verifier.verify(token, &metrics).await?; // For requests using OAuth, the keys_changed_at and client state are embedded // in the X-KeyID header. diff --git a/syncserver/src/tokenserver/mod.rs b/syncserver/src/tokenserver/mod.rs index f7dd727d46..d7d5dc55d8 100644 --- a/syncserver/src/tokenserver/mod.rs +++ b/syncserver/src/tokenserver/mod.rs @@ -9,6 +9,8 @@ use serde::{ Serialize, }; use syncserver_common::{BlockingThreadpool, Metrics}; +#[cfg(not(feature = "py_verifier"))] +use tokenserver_auth::JWTVerifierImpl; use tokenserver_auth::{browserid, oauth, VerifyToken}; use tokenserver_common::NodeType; use tokenserver_db::{params, DbPool, TokenserverPool}; @@ -40,6 +42,32 @@ impl ServerState { metrics: Arc, blocking_threadpool: Arc, ) -> Result { + #[cfg(not(feature = "py_verifier"))] + let oauth_verifier = { + let mut jwk_verifiers: Vec = Vec::new(); + if let Some(primary) = &settings.fxa_oauth_primary_jwk { + jwk_verifiers.push( + primary + .clone() + .try_into() + .expect("Invalid primary key, should either be fixed or removed"), + ) + } + if let Some(secondary) = &settings.fxa_oauth_secondary_jwk { + jwk_verifiers.push( + secondary + .clone() + .try_into() + .expect("Invalid secondary key, should either be fixed or removed"), + ); + } + Box::new( + oauth::Verifier::new(settings, jwk_verifiers) + .expect("failed to create Tokenserver OAuth verifier"), + ) + }; + + #[cfg(feature = "py_verifier")] let oauth_verifier = Box::new( oauth::Verifier::new(settings, blocking_threadpool.clone()) .expect("failed to create Tokenserver OAuth verifier"), diff --git a/syncstorage-db-common/Cargo.toml b/syncstorage-db-common/Cargo.toml index 68655bcf78..d55380c1d6 100644 --- a/syncstorage-db-common/Cargo.toml +++ b/syncstorage-db-common/Cargo.toml @@ -13,10 +13,10 @@ lazy_static.workspace=true http.workspace=true serde.workspace=true serde_json.workspace=true +thiserror.workspace=true async-trait = "0.1.40" diesel = { version = "1.4", features = ["mysql", "r2d2"] } diesel_migrations = { version = "1.4.0", features = ["mysql"] } syncserver-common = { path = "../syncserver-common" } syncserver-db-common = { path = "../syncserver-db-common" } -thiserror = "1.0.26" diff --git a/syncstorage-mysql/Cargo.toml b/syncstorage-mysql/Cargo.toml index 80f372c827..297b83ed6f 100644 --- a/syncstorage-mysql/Cargo.toml +++ b/syncstorage-mysql/Cargo.toml @@ -11,6 +11,7 @@ base64.workspace=true futures.workspace=true http.workspace=true slog-scope.workspace=true +thiserror.workspace=true async-trait = "0.1.40" diesel = { version = "1.4", features = ["mysql", "r2d2"] } @@ -20,7 +21,6 @@ syncserver-common = { path = "../syncserver-common" } syncserver-db-common = { path = "../syncserver-db-common" } syncstorage-db-common = { path = "../syncstorage-db-common" } syncstorage-settings = { path = "../syncstorage-settings" } -thiserror = "1.0.26" url = "2.1" [dev-dependencies] diff --git a/syncstorage-spanner/Cargo.toml b/syncstorage-spanner/Cargo.toml index c4babd6964..212efcbe7f 100644 --- a/syncstorage-spanner/Cargo.toml +++ b/syncstorage-spanner/Cargo.toml @@ -12,6 +12,7 @@ env_logger.workspace = true futures.workspace = true http.workspace = true slog-scope.workspace = true +thiserror.workspace = true async-trait = "0.1.40" google-cloud-rust-raw = { version = "0.16.1", features = ["spanner"] } @@ -30,7 +31,6 @@ syncserver-common = { path = "../syncserver-common" } syncserver-db-common = { path = "../syncserver-db-common" } syncstorage-db-common = { path = "../syncstorage-db-common" } syncstorage-settings = { path = "../syncstorage-settings" } -thiserror = "1.0.26" tokio = { workspace = true, features = [ "macros", "sync", diff --git a/tokenserver-auth/Cargo.toml b/tokenserver-auth/Cargo.toml index 731a613d55..5286fede6f 100644 --- a/tokenserver-auth/Cargo.toml +++ b/tokenserver-auth/Cargo.toml @@ -11,15 +11,30 @@ edition.workspace = true futures.workspace = true serde.workspace = true serde_json.workspace = true +hex.workspace = true +hkdf.workspace = true +hmac.workspace = true +jsonwebtoken.workspace = true +base64.workspace = true +sha2.workspace = true +thiserror.workspace = true +slog-scope.workspace = true async-trait = "0.1.40" dyn-clone = "1.0.4" -pyo3 = { version = "0.20", features = ["auto-initialize"] } reqwest = { workspace = true, features = ["json", "rustls-tls"] } +ring = "0.17" syncserver-common = { path = "../syncserver-common" } tokenserver-common = { path = "../tokenserver-common" } tokenserver-settings = { path = "../tokenserver-settings" } tokio = { workspace = true } +pyo3 = { version = "0.20", features = ["auto-initialize"], optional = true} + [dev-dependencies] mockito = "0.30.0" +tokio = { workspace = true, features = ["macros"]} + +[features] +default = ["py"] +py = ["pyo3"] diff --git a/tokenserver-auth/src/browserid.rs b/tokenserver-auth/src/browserid.rs index 2757cf421f..9b3711c7ae 100644 --- a/tokenserver-auth/src/browserid.rs +++ b/tokenserver-auth/src/browserid.rs @@ -1,6 +1,7 @@ use async_trait::async_trait; use reqwest::{Client as ReqwestClient, StatusCode}; use serde::{de::Deserializer, Deserialize, Serialize}; +use syncserver_common::Metrics; use tokenserver_common::{ErrorLocation, TokenType, TokenserverError}; use tokenserver_settings::Settings; @@ -52,7 +53,11 @@ impl VerifyToken for Verifier { /// Verifies a BrowserID assertion. Returns `VerifyOutput` for valid assertions and a /// `TokenserverError` for invalid assertions. - async fn verify(&self, assertion: String) -> Result { + async fn verify( + &self, + assertion: String, + _metrics: &Metrics, + ) -> Result { let response = self .request_client .post(&self.fxa_verifier_url) @@ -313,7 +318,10 @@ mod tests { }) .unwrap(); - let result = verifier.verify("test".to_owned()).await.unwrap(); + let result = verifier + .verify("test".to_owned(), &Default::default()) + .await + .unwrap(); mock.assert(); let expected_result = VerifyOutput { @@ -345,7 +353,10 @@ mod tests { .with_header("content-type", "application/json") .create(); - let error = verifier.verify(assertion.to_owned()).await.unwrap_err(); + let error = verifier + .verify(assertion.to_owned(), &Default::default()) + .await + .unwrap_err(); mock.assert(); let expected_error = TokenserverError { @@ -363,7 +374,10 @@ mod tests { .with_body("

Server Error

") .create(); - let error = verifier.verify(assertion.to_owned()).await.unwrap_err(); + let error = verifier + .verify(assertion.to_owned(), &Default::default()) + .await + .unwrap_err(); mock.assert(); let expected_error = TokenserverError { @@ -381,7 +395,10 @@ mod tests { .with_body("{\"status\": \"error\"}") .create(); - let error = verifier.verify(assertion.to_owned()).await.unwrap_err(); + let error = verifier + .verify(assertion.to_owned(), &Default::default()) + .await + .unwrap_err(); mock.assert(); let expected_error = TokenserverError { @@ -399,7 +416,10 @@ mod tests { .with_body("{\"status\": \"potato\"}") .create(); - let error = verifier.verify(assertion.to_owned()).await.unwrap_err(); + let error = verifier + .verify(assertion.to_owned(), &Default::default()) + .await + .unwrap_err(); mock.assert(); let expected_error = TokenserverError { @@ -417,7 +437,10 @@ mod tests { .with_body("{\"status\": \"failure\", \"reason\": \"something broke\"}") .create(); - let error = verifier.verify(assertion.to_owned()).await.unwrap_err(); + let error = verifier + .verify(assertion.to_owned(), &Default::default()) + .await + .unwrap_err(); mock.assert(); let expected_error = TokenserverError { @@ -434,7 +457,10 @@ mod tests { .with_body("{\"status\": \"failure\"}") .create(); - let error = verifier.verify(assertion.to_owned()).await.unwrap_err(); + let error = verifier + .verify(assertion.to_owned(), &Default::default()) + .await + .unwrap_err(); mock.assert(); let expected_error = TokenserverError { @@ -481,7 +507,10 @@ mod tests { { let mock = mock("login.persona.org"); - let error = verifier.verify(assertion.clone()).await.unwrap_err(); + let error = verifier + .verify(assertion.clone(), &Default::default()) + .await + .unwrap_err(); mock.assert(); assert_eq!(expected_error, error); @@ -489,7 +518,10 @@ mod tests { { let mock = mock(ISSUER); - let result = verifier.verify(assertion.clone()).await.unwrap(); + let result = verifier + .verify(assertion.clone(), &Default::default()) + .await + .unwrap(); let expected_result = VerifyOutput { device_id: None, email: "test@example.com".to_owned(), @@ -503,7 +535,10 @@ mod tests { { let mock = mock("accounts.firefox.org"); - let error = verifier.verify(assertion.clone()).await.unwrap_err(); + let error = verifier + .verify(assertion.clone(), &Default::default()) + .await + .unwrap_err(); mock.assert(); assert_eq!(expected_error, error); @@ -511,7 +546,10 @@ mod tests { { let mock = mock("http://accounts.firefox.com"); - let error = verifier.verify(assertion.clone()).await.unwrap_err(); + let error = verifier + .verify(assertion.clone(), &Default::default()) + .await + .unwrap_err(); mock.assert(); assert_eq!(expected_error, error); @@ -519,7 +557,10 @@ mod tests { { let mock = mock("accounts.firefox.co"); - let error = verifier.verify(assertion.clone()).await.unwrap_err(); + let error = verifier + .verify(assertion.clone(), &Default::default()) + .await + .unwrap_err(); mock.assert(); assert_eq!(expected_error, error); @@ -536,7 +577,10 @@ mod tests { .with_header("content-type", "application/json") .with_body(body.to_string()) .create(); - let error = verifier.verify(assertion.clone()).await.unwrap_err(); + let error = verifier + .verify(assertion.clone(), &Default::default()) + .await + .unwrap_err(); mock.assert(); let expected_error = TokenserverError { @@ -558,7 +602,10 @@ mod tests { .with_header("content-type", "application/json") .with_body(body.to_string()) .create(); - let error = verifier.verify(assertion.clone()).await.unwrap_err(); + let error = verifier + .verify(assertion.clone(), &Default::default()) + .await + .unwrap_err(); mock.assert(); let expected_error = TokenserverError { @@ -579,7 +626,10 @@ mod tests { .with_header("content-type", "application/json") .with_body(body.to_string()) .create(); - let error = verifier.verify(assertion).await.unwrap_err(); + let error = verifier + .verify(assertion, &Default::default()) + .await + .unwrap_err(); mock.assert(); let expected_error = TokenserverError { diff --git a/tokenserver-auth/src/crypto.rs b/tokenserver-auth/src/crypto.rs new file mode 100644 index 0000000000..4e3a882468 --- /dev/null +++ b/tokenserver-auth/src/crypto.rs @@ -0,0 +1,172 @@ +use hkdf::Hkdf; +use hmac::{Hmac, Mac}; +use jsonwebtoken::{errors::ErrorKind, jwk::Jwk, Algorithm, DecodingKey, Validation}; +use ring::rand::{SecureRandom, SystemRandom}; +use serde::de::DeserializeOwned; +use sha2::Sha256; +use tokenserver_common::TokenserverError; +pub const SHA256_OUTPUT_LEN: usize = 32; +/// A triat representing all the required cryptographic operations by the token server +pub trait Crypto { + type Error; + /// HKDF key derivation + /// + /// This expands `info` into a 32 byte value using `secret` and the optional `salt`. + /// Salt is normally specified, except when this function is called in [syncserver-settings::Secrets::new] or when deriving + /// a key to be used to sign the tokenserver tokens, so both syncserver and tokenserver can + /// sign and validate the signatures + fn hkdf(&self, secret: &str, salt: Option<&[u8]>, info: &[u8]) -> Result, Self::Error>; + + /// HMAC signiture + /// + /// Signs the `payload` using HMAC given the `key` + fn hmac_sign(&self, key: &[u8], payload: &[u8]) -> Result, Self::Error>; + + /// Verify an HMAC signature on a payload given a shared key + fn hmac_verify(&self, key: &[u8], payload: &[u8], signature: &[u8]) -> Result<(), Self::Error>; + + /// Generates random bytes using a cryptographic random number generator + /// and fills `output` with those bytes + fn rand_bytes(&self, output: &mut [u8]) -> Result<(), Self::Error>; +} + +/// An implementation for the needed cryptographic using +/// the hmac crate for hmac and hkdf crate for hkdf +/// it uses ring for the random number generation +pub struct CryptoImpl {} + +impl Crypto for CryptoImpl { + type Error = TokenserverError; + fn hkdf(&self, secret: &str, salt: Option<&[u8]>, info: &[u8]) -> Result, Self::Error> { + let hk = Hkdf::::new(salt, secret.as_bytes()); + let mut okm = [0u8; SHA256_OUTPUT_LEN]; + hk.expand(info, &mut okm) + .map_err(|_| TokenserverError::internal_error())?; + Ok(okm.to_vec()) + } + + fn hmac_sign(&self, key: &[u8], payload: &[u8]) -> Result, Self::Error> { + let mut mac: Hmac = + Hmac::new_from_slice(key).map_err(|_| TokenserverError::internal_error())?; + mac.update(payload); + Ok(mac.finalize().into_bytes().to_vec()) + } + + fn hmac_verify(&self, key: &[u8], payload: &[u8], signature: &[u8]) -> Result<(), Self::Error> { + let mut mac: Hmac = + Hmac::new_from_slice(key).map_err(|_| TokenserverError::internal_error())?; + mac.update(payload); + mac.verify_slice(signature) + .map_err(|_| TokenserverError::internal_error())?; + Ok(()) + } + + fn rand_bytes(&self, output: &mut [u8]) -> Result<(), Self::Error> { + let rng = SystemRandom::new(); + rng.fill(output) + .map_err(|_| TokenserverError::internal_error())?; + Ok(()) + } +} + +/// OAuthVerifyError captures the errors possible while verifing an OAuth JWT access token +#[derive(Debug, thiserror::Error)] +pub enum OAuthVerifyError { + #[error("The signature has expired")] + ExpiredSignature, + #[error("Untrusted token")] + TrustError, + #[error("Invalid Key")] + InvalidKey, + #[error("Error decoding JWT")] + DecodingError, + #[error("The key was well formatted, but the signature was invalid")] + InvalidSignature, +} + +impl OAuthVerifyError { + pub fn metric_label(&self) -> &'static str { + match self { + Self::ExpiredSignature => "oauth.error.expired_signature", + Self::TrustError => "oauth.error.trust_error", + Self::InvalidKey => "oauth.error.invalid_key", + Self::InvalidSignature => "oauth.error.invalid_signature", + Self::DecodingError => "oauth.error.decoding_error", + } + } + + pub fn is_reportable_err(&self) -> bool { + matches!(self, Self::InvalidKey | Self::DecodingError) + } +} + +impl From for OAuthVerifyError { + fn from(value: jsonwebtoken::errors::Error) -> Self { + match value.kind() { + ErrorKind::InvalidKeyFormat => OAuthVerifyError::InvalidKey, + ErrorKind::InvalidSignature => OAuthVerifyError::InvalidSignature, + ErrorKind::ExpiredSignature => OAuthVerifyError::ExpiredSignature, + _ => OAuthVerifyError::DecodingError, + } + } +} + +/// A trait representing a JSON Web Token verifier +pub trait JWTVerifier: TryFrom + Sync + Send + Clone { + type Key: DeserializeOwned; + + fn verify(&self, token: &str) -> Result; +} + +/// An implementation of the JWT verifier using the jsonwebtoken crate +#[derive(Clone)] +pub struct JWTVerifierImpl { + key: DecodingKey, + validation: Validation, +} + +impl JWTVerifier for JWTVerifierImpl { + type Key = Jwk; + + fn verify(&self, token: &str) -> Result { + let token_data = jsonwebtoken::decode::(token, &self.key, &self.validation)?; + token_data + .header + .typ + .ok_or(OAuthVerifyError::TrustError) + .and_then(|typ| { + // Ref https://tools.ietf.org/html/rfc7515#section-4.1.9 the `typ` header + // is lowercase and has an implicit default `application/` prefix. + let typ = if !typ.contains('/') { + format!("application/{}", typ) + } else { + typ + }; + if typ.to_lowercase() != "application/at+jwt" { + return Err(OAuthVerifyError::TrustError); + } + Ok(typ) + })?; + Ok(token_data.claims) + } +} + +impl TryFrom for JWTVerifierImpl { + type Error = OAuthVerifyError; + fn try_from(value: Jwk) -> Result { + let decoding_key = + DecodingKey::from_jwk(&value).map_err(|_| OAuthVerifyError::InvalidKey)?; + let mut validation = Validation::new(Algorithm::RS256); + // The FxA OAuth ecosystem currently doesn't make good use of aud, and + // instead relies on scope for restricting which services can accept + // which tokens. So there's no value in checking it here, and in fact if + // we check it here, it fails because the right audience isn't being + // requested. + validation.validate_aud = false; + + Ok(Self { + key: decoding_key, + validation, + }) + } +} diff --git a/tokenserver-auth/src/lib.rs b/tokenserver-auth/src/lib.rs index 9cc2fce82a..f020cd313a 100644 --- a/tokenserver-auth/src/lib.rs +++ b/tokenserver-auth/src/lib.rs @@ -1,17 +1,21 @@ pub mod browserid; + +#[cfg(not(feature = "py"))] +mod crypto; + +#[cfg(not(feature = "py"))] +pub use crypto::{JWTVerifier, JWTVerifierImpl}; pub mod oauth; +mod token; +use syncserver_common::Metrics; +pub use token::Tokenlib; use std::fmt; use async_trait::async_trait; use dyn_clone::{self, DynClone}; -use pyo3::{ - prelude::{IntoPy, PyErr, PyModule, PyObject, Python}, - types::IntoPyDict, -}; use serde::{Deserialize, Serialize}; use tokenserver_common::TokenserverError; - /// Represents the origin of the token used by Sync clients to access their data. #[derive(Clone, Copy, Default, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] #[serde(rename_all = "lowercase")] @@ -33,7 +37,7 @@ impl fmt::Display for TokenserverOrigin { } /// The plaintext needed to build a token. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)] pub struct MakeTokenPlaintext { pub node: String, pub fxa_kid: String, @@ -45,69 +49,6 @@ pub struct MakeTokenPlaintext { pub tokenserver_origin: TokenserverOrigin, } -impl IntoPy for MakeTokenPlaintext { - fn into_py(self, py: Python<'_>) -> PyObject { - let dict = [ - ("node", self.node), - ("fxa_kid", self.fxa_kid), - ("fxa_uid", self.fxa_uid), - ("hashed_device_id", self.hashed_device_id), - ("hashed_fxa_uid", self.hashed_fxa_uid), - ("tokenserver_origin", self.tokenserver_origin.to_string()), - ] - .into_py_dict(py); - - // These need to be set separately since they aren't strings, and - // Rust doesn't support heterogeneous arrays - dict.set_item("expires", self.expires).unwrap(); - dict.set_item("uid", self.uid).unwrap(); - - dict.into() - } -} - -/// An adapter to the tokenlib Python library. -pub struct Tokenlib; - -impl Tokenlib { - /// Builds the token and derived secret to be returned by Tokenserver. - pub fn get_token_and_derived_secret( - plaintext: MakeTokenPlaintext, - shared_secret: &str, - ) -> Result<(String, String), TokenserverError> { - Python::with_gil(|py| { - // `import tokenlib` - let module = PyModule::import(py, "tokenlib").map_err(|e| { - e.print_and_set_sys_last_vars(py); - e - })?; - // `kwargs = { 'secret': shared_secret }` - let kwargs = [("secret", shared_secret)].into_py_dict(py); - // `token = tokenlib.make_token(plaintext, **kwargs)` - let token = module - .getattr("make_token")? - .call((plaintext,), Some(kwargs)) - .map_err(|e| { - e.print_and_set_sys_last_vars(py); - e - }) - .and_then(|x| x.extract())?; - // `derived_secret = tokenlib.get_derived_secret(token, **kwargs)` - let derived_secret = module - .getattr("get_derived_secret")? - .call((&token,), Some(kwargs)) - .map_err(|e| { - e.print_and_set_sys_last_vars(py); - e - }) - .and_then(|x| x.extract())?; - // `return (token, derived_secret)` - Ok((token, derived_secret)) - }) - .map_err(pyerr_to_tokenserver_error) - } -} - /// Implementers of this trait can be used to verify tokens for Tokenserver. #[async_trait] pub trait VerifyToken: DynClone + Sync + Send { @@ -115,7 +56,11 @@ pub trait VerifyToken: DynClone + Sync + Send { /// Verifies the given token. This function is async because token verification often involves /// making a request to a remote server. - async fn verify(&self, token: String) -> Result; + async fn verify( + &self, + token: String, + metrics: &Metrics, + ) -> Result; } dyn_clone::clone_trait_object!( VerifyToken); @@ -131,16 +76,9 @@ pub struct MockVerifier { impl VerifyToken for MockVerifier { type Output = T; - async fn verify(&self, _token: String) -> Result { + async fn verify(&self, _token: String, _metrics: &Metrics) -> Result { self.valid .then(|| self.verify_output.clone()) .ok_or_else(|| TokenserverError::invalid_credentials("Unauthorized".to_owned())) } } - -fn pyerr_to_tokenserver_error(e: PyErr) -> TokenserverError { - TokenserverError { - context: e.to_string(), - ..TokenserverError::internal_error() - } -} diff --git a/tokenserver-auth/src/oauth.rs b/tokenserver-auth/src/oauth.rs index fb61eb5faa..fd91fbbed3 100644 --- a/tokenserver-auth/src/oauth.rs +++ b/tokenserver-auth/src/oauth.rs @@ -1,18 +1,15 @@ -use async_trait::async_trait; -use pyo3::{ - prelude::{Py, PyAny, PyErr, PyModule, Python}, - types::{IntoPyDict, PyString}, -}; use serde::{Deserialize, Serialize}; -use serde_json; -use syncserver_common::BlockingThreadpool; -use tokenserver_common::TokenserverError; -use tokenserver_settings::{Jwk, Settings}; -use tokio::time; -use super::VerifyToken; +#[cfg(not(feature = "py"))] +mod native; +#[cfg(feature = "py")] +mod py; -use std::{sync::Arc, time::Duration}; +#[cfg(feature = "py")] +pub type Verifier = py::Verifier; + +#[cfg(not(feature = "py"))] +pub type Verifier = native::Verifier; /// The information extracted from a valid OAuth token. #[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)] @@ -21,148 +18,3 @@ pub struct VerifyOutput { pub fxa_uid: String, pub generation: Option, } - -/// The verifier used to verify OAuth tokens. -#[derive(Clone)] -pub struct Verifier { - // Note that we do not need to use an Arc here, since Py is already a reference-counted - // pointer - inner: Py, - timeout: u64, - blocking_threadpool: Arc, -} - -impl Verifier { - const FILENAME: &'static str = "verify.py"; - - pub fn new( - settings: &Settings, - blocking_threadpool: Arc, - ) -> Result { - let inner: Py = Python::with_gil::<_, Result, PyErr>>(|py| { - let code = include_str!("verify.py"); - let module = PyModule::from_code(py, code, Self::FILENAME, Self::FILENAME)?; - let kwargs = { - let dict = [("server_url", &settings.fxa_oauth_server_url)].into_py_dict(py); - let parse_jwk = |jwk: &Jwk| { - let dict = [ - ("kty", &jwk.kty), - ("alg", &jwk.alg), - ("kid", &jwk.kid), - ("use", &jwk.use_of_key), - ("n", &jwk.n), - ("e", &jwk.e), - ] - .into_py_dict(py); - dict.set_item("fxa-createdAt", jwk.fxa_created_at).unwrap(); - - dict - }; - - let jwks = match ( - &settings.fxa_oauth_primary_jwk, - &settings.fxa_oauth_secondary_jwk, - ) { - (Some(primary_jwk), Some(secondary_jwk)) => { - Some(vec![parse_jwk(primary_jwk), parse_jwk(secondary_jwk)]) - } - (Some(jwk), None) | (None, Some(jwk)) => Some(vec![parse_jwk(jwk)]), - (None, None) => None, - }; - dict.set_item("jwks", jwks).unwrap(); - dict - }; - let object: Py = module - .getattr("FxaOAuthClient")? - .call((), Some(kwargs)) - .map_err(|e| { - e.print_and_set_sys_last_vars(py); - e - })? - .into(); - - Ok(object) - }) - .map_err(super::pyerr_to_tokenserver_error)?; - - Ok(Self { - inner, - timeout: settings.fxa_oauth_request_timeout, - blocking_threadpool, - }) - } -} - -#[async_trait] -impl VerifyToken for Verifier { - type Output = VerifyOutput; - - /// Verifies an OAuth token. Returns `VerifyOutput` for valid tokens and a `TokenserverError` - /// for invalid tokens. - async fn verify(&self, token: String) -> Result { - // We don't want to move `self` into the body of the closure here because we'd need to - // clone it. Cloning it is only necessary if we need to verify the token remotely via FxA, - // since that would require passing `self` to a separate thread. Passing &Self to a closure - // gives us the flexibility to clone only when necessary. - let verify_inner = |verifier: &Self| { - let maybe_verify_output_string = Python::with_gil(|py| { - let client = verifier.inner.as_ref(py); - // `client.verify_token(token)` - let result: &PyAny = client - .getattr("verify_token")? - .call((token,), None) - .map_err(|e| { - e.print_and_set_sys_last_vars(py); - e - })?; - - if result.is_none() { - Ok(None) - } else { - let verify_output_python_string = result.downcast::()?; - verify_output_python_string.extract::().map(Some) - } - }) - .map_err(|e| TokenserverError { - context: format!("pyo3 error in OAuth verifier: {}", e), - ..TokenserverError::invalid_credentials("Unauthorized".to_owned()) - })?; - - match maybe_verify_output_string { - Some(verify_output_string) => { - serde_json::from_str::(&verify_output_string).map_err(|e| { - TokenserverError { - context: format!("Invalid OAuth verify output: {}", e), - ..TokenserverError::invalid_credentials("Unauthorized".to_owned()) - } - }) - } - None => Err(TokenserverError { - context: "Invalid OAuth token".to_owned(), - ..TokenserverError::invalid_credentials("Unauthorized".to_owned()) - }), - } - }; - - let verifier = self.clone(); - - // If the JWK is not cached or if the token is not a JWT/wasn't signed by a known key - // type, PyFxA will make a request to FxA to retrieve it, blocking this thread. To - // improve performance, we make the request on a thread in a threadpool specifically - // used for blocking operations. The JWK should _always_ be cached in production to - // maximize performance. - let fut = self - .blocking_threadpool - .spawn(move || verify_inner(&verifier)); - - // The PyFxA OAuth client does not offer a way to set a request timeout, so we set one here - // by timing out the future if the verification process blocks this thread for longer - // than the specified number of seconds. - time::timeout(Duration::from_secs(self.timeout), fut) - .await - .map_err(|_| TokenserverError { - context: "OAuth verification timeout".to_owned(), - ..TokenserverError::resource_unavailable() - })? - } -} diff --git a/tokenserver-auth/src/oauth/native.rs b/tokenserver-auth/src/oauth/native.rs new file mode 100644 index 0000000000..67ec88aad1 --- /dev/null +++ b/tokenserver-auth/src/oauth/native.rs @@ -0,0 +1,557 @@ +use super::VerifyOutput; +pub use crate::crypto::JWTVerifier; +use crate::crypto::OAuthVerifyError; +use crate::VerifyToken; +use async_trait::async_trait; +use reqwest::Url; +use serde::{Deserialize, Serialize}; +use std::{borrow::Cow, time::Duration}; +use syncserver_common::Metrics; +use tokenserver_common::TokenserverError; +use tokenserver_settings::Settings; + +const SYNC_SCOPE: &str = "https://identity.mozilla.com/apps/oldsync"; + +#[derive(Serialize, Deserialize, Debug)] +struct TokenClaims { + #[serde(rename = "sub")] + user: String, + scope: String, + #[serde(rename = "fxa-generation")] + generation: Option, +} + +impl TokenClaims { + fn validate(self) -> Result { + if !self.scope.split(',').any(|scope| scope == SYNC_SCOPE) { + return Err(TokenserverError::invalid_credentials( + "Unauthorized".to_string(), + )); + } + Ok(self.into()) + } +} + +impl From for VerifyOutput { + fn from(value: TokenClaims) -> Self { + Self { + fxa_uid: value.user, + generation: value.generation, + } + } +} + +/// The verifier used to verify OAuth tokens. +#[derive(Clone)] +pub struct Verifier { + verify_url: Url, + jwks_url: Url, + jwk_verifiers: Vec, + http_client: reqwest::Client, +} + +impl Verifier +where + J: JWTVerifier, +{ + pub fn new(settings: &Settings, jwk_verifiers: Vec) -> Result { + let base_url = Url::parse(&settings.fxa_oauth_server_url) + .map_err(|_| TokenserverError::internal_error())?; + let verify_url = base_url + .join("v1/verify") + .map_err(|_| TokenserverError::internal_error())?; + let jwks_url = base_url + .join("v1/jwks") + .map_err(|_| TokenserverError::internal_error())?; + let http_client = reqwest::Client::builder() + .timeout(Duration::from_secs(settings.fxa_oauth_request_timeout)) + .use_rustls_tls() + .build() + .map_err(|_| TokenserverError::internal_error())?; + + Ok(Self { + verify_url, + jwks_url, + jwk_verifiers, + http_client, + }) + } + + async fn remote_verify_token(&self, token: &str) -> Result { + #[derive(Serialize)] + struct VerifyRequest<'a> { + token: &'a str, + } + + #[derive(Serialize, Deserialize)] + struct VerifyResponse { + user: String, + scope: Vec, + generation: Option, + } + + impl From for TokenClaims { + fn from(value: VerifyResponse) -> Self { + Self { + user: value.user, + scope: value.scope.join(","), + generation: value.generation, + } + } + } + + Ok(self + .http_client + .post(self.verify_url.clone()) + .json(&VerifyRequest { token }) + .send() + .await + .map_err(unauthorized_err_with_ctx) + .and_then(|res| { + if !res.status().is_success() { + Err(unauthorized_err_with_ctx(format!( + "Got verify status code: {}", + res.status() + ))) + } else { + Ok(res) + } + })? + .json::() + .await + .map_err(unauthorized_err_with_ctx)? + .into()) + } + + async fn get_remote_jwks(&self) -> Result, TokenserverError> { + #[derive(Deserialize)] + struct KeysResponse { + keys: Vec, + } + self.http_client + .get(self.jwks_url.clone()) + .send() + .await + .map_err(internal_err_with_ctx)? + .json::>() + .await + .map_err(internal_err_with_ctx)? + .keys + .into_iter() + .map(|key| key.try_into().map_err(internal_err_with_ctx)) + .collect() + } + + fn verify_jwt_locally( + &self, + verifiers: &[Cow<'_, J>], + token: &str, + ) -> Result { + if verifiers.is_empty() { + return Err(OAuthVerifyError::InvalidKey); + } + + verifiers + .iter() + .find_map(|verifier| { + match verifier.verify::(token) { + // If it's an invalid signature, it means our key was well formatted, + // but the signature was incorrect. Lets try another key if we have any + Err(OAuthVerifyError::InvalidSignature) => None, + res => Some(res), + } + }) + // If there is nothing, it means all of our keys were well formatted, but none of them + // were able to verify the signature, lets erturn a TrustError + .ok_or(OAuthVerifyError::TrustError)? + } +} + +#[async_trait] +impl VerifyToken for Verifier +where + J: JWTVerifier, +{ + type Output = VerifyOutput; + + /// Verifies an OAuth token. Returns `VerifyOutput` for valid tokens and a `TokenserverError` + /// for invalid tokens. + /// + /// The verifier will first attempt to verify the token using FxA's public keys, which were + /// provided as environment variables. + /// + /// If FxA's public keys were not supplied, then the verifier will query FxA's /v1/jwks + /// endpoint to get the latest public keys. + /// + /// If verifying the tokens fails because the keys are + /// invalid, or because the keys were valid but the tokens have changed their structure, then + /// the verifier will fallback to hitting fxa's /v1/verify endpoint to verify instead. All + /// other failures will be recorded as invalid credentials and will returns a generic "Unauthorized" message + /// to the user + async fn verify( + &self, + token: String, + metrics: &Metrics, + ) -> Result { + let mut verifiers = self + .jwk_verifiers + .iter() + .map(Cow::Borrowed) + .collect::>(); + if self.jwk_verifiers.is_empty() { + verifiers = self + .get_remote_jwks() + .await + .unwrap_or_else(|e| { + slog_scope::warn!("Error requesting remote jwks: {}", e); + vec![] + }) + .into_iter() + .map(Cow::Owned) + .collect(); + } + + let claims = match self.verify_jwt_locally(&verifiers, &token) { + Ok(res) => res, + Err(e) => { + if e.is_reportable_err() { + metrics.incr(e.metric_label()) + } + match e { + OAuthVerifyError::DecodingError | OAuthVerifyError::InvalidKey => { + self.remote_verify_token(&token).await? + } + e => return Err(unauthorized_err_with_ctx(e)), + } + } + }; + claims.validate() + } +} + +fn unauthorized_err_with_ctx(err: E) -> TokenserverError { + TokenserverError { + context: err.to_string(), + ..TokenserverError::invalid_credentials("Unauthorized".to_string()) + } +} + +fn internal_err_with_ctx(err: E) -> TokenserverError { + TokenserverError { + context: err.to_string(), + ..TokenserverError::internal_error() + } +} + +#[cfg(test)] +mod tests { + use crate::crypto::{JWTVerifierImpl, OAuthVerifyError}; + use serde_json::json; + + use super::*; + #[derive(Deserialize)] + struct MockJWK {} + + macro_rules! mock_jwk_verifier { + ($im:expr) => { + mock_jwk_verifier!(_token, $im); + }; + ($token:ident, $im:expr) => { + #[derive(Clone, Debug)] + struct MockJWTVerifier {} + impl TryFrom for MockJWTVerifier { + type Error = OAuthVerifyError; + fn try_from(_value: MockJWK) -> Result { + Ok(Self {}) + } + } + + impl JWTVerifier for MockJWTVerifier { + type Key = MockJWK; + fn verify( + &self, + $token: &str, + ) -> Result { + $im + } + } + }; + } + + #[tokio::test] + async fn test_no_keys_in_verifier_fallsback_to_fxa() -> Result<(), TokenserverError> { + let mock_jwks = mockito::mock("GET", "/v1/jwks").with_status(500).create(); + + let body = json!({ + "user": "fxa_id", + "scope": [SYNC_SCOPE], + "generation": 123 + }); + let mock_verify = mockito::mock("POST", "/v1/verify") + .with_header("content-type", "application/json") + .with_status(200) + .with_body(body.to_string()) + .create(); + + let settings = Settings { + fxa_oauth_server_url: mockito::server_url(), + ..Default::default() + }; + let verifer: Verifier = Verifier::new(&settings, vec![])?; + let res = verifer + .verify("a token fxa will validate".to_string(), &Default::default()) + .await?; + mock_jwks.expect(1); + mock_verify.expect(1); + assert_eq!(res.generation.unwrap(), 123); + assert_eq!(res.fxa_uid, "fxa_id"); + Ok(()) + } + + #[tokio::test] + async fn test_expired_signature_fails() -> Result<(), TokenserverError> { + let mock = mockito::mock("POST", "/v1/verify").create(); + mock_jwk_verifier!(Err(OAuthVerifyError::InvalidSignature)); + + let jwk_verifiers = vec![MockJWTVerifier {}]; + let settings = Settings { + fxa_oauth_server_url: mockito::server_url(), + ..Settings::default() + }; + + let verifier: Verifier = Verifier::new(&settings, jwk_verifiers)?; + + let err = verifier + .verify("An expired token".to_string(), &Default::default()) + .await + .unwrap_err(); + // We also make sure we didn't try to hit the server + mock.expect(0); + assert_eq!(err.status, "invalid-credentials"); + assert_eq!(err.http_status, 401); + assert_eq!(err.description, "Unauthorized"); + + Ok(()) + } + + #[tokio::test] + async fn test_verifier_attempts_all_keys_if_invalid_signature() -> Result<(), TokenserverError> + { + let mock = mockito::mock("POST", "/v1/verify").create(); + #[derive(Debug, Clone)] + struct MockJWTVerifier { + id: u8, + } + + impl TryFrom for MockJWTVerifier { + type Error = OAuthVerifyError; + fn try_from(_value: MockJWK) -> Result { + Ok(Self { id: 0 }) + } + } + + impl JWTVerifier for MockJWTVerifier { + type Key = MockJWK; + fn verify( + &self, + token: &str, + ) -> Result { + if self.id == 0 { + Err(OAuthVerifyError::InvalidSignature) + } else { + Ok(serde_json::from_str(token).unwrap()) + } + } + } + + let jwk_verifiers = vec![MockJWTVerifier { id: 0 }, MockJWTVerifier { id: 1 }]; + let settings = Settings { + fxa_oauth_server_url: mockito::server_url(), + ..Settings::default() + }; + let verifier: Verifier = Verifier::new(&settings, jwk_verifiers).unwrap(); + + let token_claims = TokenClaims { + user: "fxa_id".to_string(), + scope: SYNC_SCOPE.to_string(), + generation: Some(124), + }; + + let res = verifier + .verify( + serde_json::to_string(&token_claims).unwrap(), + &Default::default(), + ) + .await?; + assert_eq!(res.fxa_uid, "fxa_id"); + assert_eq!(res.generation.unwrap(), 124); + mock.expect(0); // We shouldn't have hit the server + Ok(()) + } + + #[tokio::test] + async fn test_verifier_all_signature_failures_fails() -> Result<(), TokenserverError> { + let mock_verify = mockito::mock("POST", "/v1/verify").create(); + mock_jwk_verifier!(Err(OAuthVerifyError::InvalidSignature)); + + let jwk_verifiers = vec![MockJWTVerifier {}, MockJWTVerifier {}]; + let settings = Settings { + fxa_oauth_server_url: mockito::server_url(), + ..Settings::default() + }; + let verifier: Verifier = Verifier::new(&settings, jwk_verifiers).unwrap(); + let err = verifier + .verify( + "a token with an invalid signature".to_string(), + &Default::default(), + ) + .await + .unwrap_err(); + assert_eq!(err.status, "invalid-credentials"); + assert_eq!(err.http_status, 401); + assert_eq!(err.description, "Unauthorized"); + + mock_verify.expect(0); + Ok(()) + } + + #[tokio::test] + async fn test_verifier_fallsback_if_decode_error() -> Result<(), TokenserverError> { + let body = json!({ + "user": "fxa_id", + "scope": [SYNC_SCOPE], + "generation": 123 + }); + let mock_verify = mockito::mock("POST", "/v1/verify") + .with_header("content-type", "application/json") + .with_status(200) + .with_body(body.to_string()) + .create(); + + mock_jwk_verifier!(Err(OAuthVerifyError::DecodingError)); + + let jwk_verifiers = vec![MockJWTVerifier {}]; + let settings = Settings { + fxa_oauth_server_url: mockito::server_url(), + ..Settings::default() + }; + let verifier: Verifier = Verifier::new(&settings, jwk_verifiers).unwrap(); + + let res = verifier + .verify( + "invalid token that can't be decoded".to_string(), + &Default::default(), + ) + .await?; + assert_eq!(res.fxa_uid, "fxa_id"); + assert_eq!(res.generation.unwrap(), 123); + mock_verify.expect(1); // We would have have hit the server + Ok(()) + } + + #[tokio::test] + async fn test_no_sync_scope_fails() -> Result<(), TokenserverError> { + let token_claims = TokenClaims { + user: "fxa_id".to_string(), + scope: "some other scope".to_string(), + generation: Some(124), + }; + mock_jwk_verifier!(token, Ok(serde_json::from_str(token).unwrap())); + let jwk_verifiers = vec![MockJWTVerifier {}]; + let settings = Settings { + fxa_oauth_server_url: mockito::server_url(), + ..Settings::default() + }; + let verifier: Verifier = Verifier::new(&settings, jwk_verifiers).unwrap(); + let err = verifier + .verify( + serde_json::to_string(&token_claims).unwrap(), + &Default::default(), + ) + .await + .unwrap_err(); + assert_eq!(err.status, "invalid-credentials"); + assert_eq!(err.http_status, 401); + assert_eq!(err.description, "Unauthorized"); + + Ok(()) + } + + #[tokio::test] + async fn test_fxa_rejects_token_no_matter_the_body() -> Result<(), TokenserverError> { + let body = json!({ + "user": "fxa_id", + "scope": [SYNC_SCOPE], + "generation": 123 + }); + let mock_verify = mockito::mock("POST", "/v1/verify") + .with_header("content-type", "application/json") + .with_status(401) + // Even though the body is fine, if FxA returns a none-200, we automatically + // return a credential error + .with_body(body.to_string()) + .create(); + let settings = Settings { + fxa_oauth_server_url: mockito::server_url(), + ..Settings::default() + }; + + mock_jwk_verifier!(Err(OAuthVerifyError::DecodingError)); + let jwk_verifiers = vec![]; + + let verifier: Verifier = Verifier::new(&settings, jwk_verifiers).unwrap(); + + let err = verifier + .verify( + "A token that we will ask FxA about".to_string(), + &Default::default(), + ) + .await + .unwrap_err(); + assert_eq!(err.status, "invalid-credentials"); + assert_eq!(err.http_status, 401); + assert_eq!(err.description, "Unauthorized"); + mock_verify.expect(1); + + Ok(()) + } + + #[tokio::test] + async fn test_fxa_accepts_token_but_bad_body() -> Result<(), TokenserverError> { + let body = json!({ + "bad_key": "foo", + "scope": [SYNC_SCOPE], + "bad_genreation": 123 + }); + let mock_verify = mockito::mock("POST", "/v1/verify") + .with_header("content-type", "application/json") + .with_status(200) + // Even though the body is valid json, it doesn't match our expectation so we'll error + // out + .with_body(body.to_string()) + .create(); + let settings = Settings { + fxa_oauth_server_url: mockito::server_url(), + ..Settings::default() + }; + + mock_jwk_verifier!(Err(OAuthVerifyError::DecodingError)); + let jwk_verifiers = vec![]; + + let verifier: Verifier = Verifier::new(&settings, jwk_verifiers).unwrap(); + + let err = verifier + .verify( + "A token that we will ask FxA about".to_string(), + &Default::default(), + ) + .await + .unwrap_err(); + assert_eq!(err.status, "invalid-credentials"); + assert_eq!(err.http_status, 401); + assert_eq!(err.description, "Unauthorized"); + mock_verify.expect(1); + + Ok(()) + } +} diff --git a/tokenserver-auth/src/oauth/py.rs b/tokenserver-auth/src/oauth/py.rs new file mode 100644 index 0000000000..5e3d816a9f --- /dev/null +++ b/tokenserver-auth/src/oauth/py.rs @@ -0,0 +1,195 @@ +use async_trait::async_trait; +use jsonwebtoken::jwk::{AlgorithmParameters, Jwk, PublicKeyUse, RSAKeyParameters}; +use pyo3::{ + prelude::{Py, PyAny, PyErr, PyModule, Python}, + types::{IntoPyDict, PyString}, +}; +use serde_json; +use syncserver_common::{BlockingThreadpool, Metrics}; +use tokenserver_common::TokenserverError; +use tokenserver_settings::Settings; +use tokio::time; + +use super::VerifyOutput; +use crate::VerifyToken; + +use std::{sync::Arc, time::Duration}; + +/// The verifier used to verify OAuth tokens. +#[derive(Clone)] +pub struct Verifier { + // Note that we do not need to use an Arc here, since Py is already a reference-counted + // pointer + inner: Py, + timeout: u64, + blocking_threadpool: Arc, +} + +impl Verifier { + const FILENAME: &'static str = "verify.py"; + + pub fn new( + settings: &Settings, + blocking_threadpool: Arc, + ) -> Result { + let inner: Py = Python::with_gil::<_, Result, TokenserverError>>(|py| { + let code = include_str!("verify.py"); + let module = PyModule::from_code(py, code, Self::FILENAME, Self::FILENAME) + .map_err(pyerr_to_tokenserver_error)?; + let kwargs = { + let dict = [("server_url", &settings.fxa_oauth_server_url)].into_py_dict(py); + let parse_jwk = |jwk: &Jwk| { + let (n, e) = match &jwk.algorithm { + AlgorithmParameters::RSA(RSAKeyParameters { key_type: _, n, e }) => (n, e), + _ => return Err(TokenserverError::internal_error()), + }; + let alg = jwk + .common + .key_algorithm + .ok_or_else(TokenserverError::internal_error)? + .to_string(); + let kid = jwk + .common + .key_id + .as_ref() + .ok_or_else(TokenserverError::internal_error)?; + if !matches!( + jwk.common + .public_key_use + .as_ref() + .ok_or_else(TokenserverError::internal_error)?, + PublicKeyUse::Signature + ) { + return Err(TokenserverError::internal_error()); + } + + let dict = [ + ("kty", "RSA"), + ("alg", &alg), + ("kid", kid), + ("use", "sig"), + ("n", &n), + ("e", e), + ] + .into_py_dict(py); + Ok(dict) + }; + + let jwks = match ( + &settings.fxa_oauth_primary_jwk, + &settings.fxa_oauth_secondary_jwk, + ) { + (Some(primary_jwk), Some(secondary_jwk)) => { + Some(vec![parse_jwk(primary_jwk)?, parse_jwk(secondary_jwk)?]) + } + (Some(jwk), None) | (None, Some(jwk)) => Some(vec![parse_jwk(jwk)?]), + (None, None) => None, + }; + dict.set_item("jwks", jwks).unwrap(); + dict + }; + let object: Py = module + .getattr("FxaOAuthClient") + .map_err(pyerr_to_tokenserver_error)? + .call((), Some(kwargs)) + .map_err(|e| { + e.print_and_set_sys_last_vars(py); + pyerr_to_tokenserver_error(e) + })? + .into(); + + Ok(object) + })?; + + Ok(Self { + inner, + timeout: settings.fxa_oauth_request_timeout, + blocking_threadpool, + }) + } +} + +#[async_trait] +impl VerifyToken for Verifier { + type Output = VerifyOutput; + + /// Verifies an OAuth token. Returns `VerifyOutput` for valid tokens and a `TokenserverError` + /// for invalid tokens. + async fn verify( + &self, + token: String, + _metrics: &Metrics, + ) -> Result { + // We don't want to move `self` into the body of the closure here because we'd need to + // clone it. Cloning it is only necessary if we need to verify the token remotely via FxA, + // since that would require passing `self` to a separate thread. Passing &Self to a closure + // gives us the flexibility to clone only when necessary. + let verify_inner = |verifier: &Self| { + let maybe_verify_output_string = Python::with_gil(|py| { + let client = verifier.inner.as_ref(py); + // `client.verify_token(token)` + let result: &PyAny = client + .getattr("verify_token")? + .call((token,), None) + .map_err(|e| { + e.print_and_set_sys_last_vars(py); + e + })?; + + if result.is_none() { + Ok(None) + } else { + let verify_output_python_string = result.downcast::()?; + verify_output_python_string.extract::().map(Some) + } + }) + .map_err(|e| TokenserverError { + context: format!("pyo3 error in OAuth verifier: {}", e), + ..TokenserverError::invalid_credentials("Unauthorized".to_owned()) + })?; + + match maybe_verify_output_string { + Some(verify_output_string) => { + serde_json::from_str::(&verify_output_string).map_err(|e| { + TokenserverError { + context: format!("Invalid OAuth verify output: {}", e), + ..TokenserverError::invalid_credentials("Unauthorized".to_owned()) + } + }) + } + None => Err(TokenserverError { + context: "Invalid OAuth token".to_owned(), + ..TokenserverError::invalid_credentials("Unauthorized".to_owned()) + }), + } + }; + + let verifier = self.clone(); + + // If the JWK is not cached or if the token is not a JWT/wasn't signed by a known key + // type, PyFxA will make a request to FxA to retrieve it, blocking this thread. To + // improve performance, we make the request on a thread in a threadpool specifically + // used for blocking operations. The JWK should _always_ be cached in production to + // maximize performance. + let fut = self + .blocking_threadpool + .spawn(move || verify_inner(&verifier)); + + // The PyFxA OAuth client does not offer a way to set a request timeout, so we set one here + // by timing out the future if the verification process blocks this thread for longer + // than the specified number of seconds. + time::timeout(Duration::from_secs(self.timeout), fut) + .await + .map_err(|_| TokenserverError { + context: "OAuth verification timeout".to_owned(), + ..TokenserverError::resource_unavailable() + })? + } +} + +fn pyerr_to_tokenserver_error(e: PyErr) -> TokenserverError { + TokenserverError { + context: e.to_string(), + ..TokenserverError::internal_error() + } +} diff --git a/tokenserver-auth/src/verify.py b/tokenserver-auth/src/oauth/verify.py similarity index 100% rename from tokenserver-auth/src/verify.py rename to tokenserver-auth/src/oauth/verify.py diff --git a/tokenserver-auth/src/token.rs b/tokenserver-auth/src/token.rs new file mode 100644 index 0000000000..0d5877379f --- /dev/null +++ b/tokenserver-auth/src/token.rs @@ -0,0 +1,11 @@ +#[cfg(not(feature = "py"))] +mod native; + +#[cfg(feature = "py")] +mod py; + +#[cfg(feature = "py")] +pub type Tokenlib = py::PyTokenlib; + +#[cfg(not(feature = "py"))] +pub type Tokenlib = native::Tokenlib; diff --git a/tokenserver-auth/src/token/native.rs b/tokenserver-auth/src/token/native.rs new file mode 100644 index 0000000000..496255c754 --- /dev/null +++ b/tokenserver-auth/src/token/native.rs @@ -0,0 +1,113 @@ +use crate::{ + crypto::{Crypto, CryptoImpl}, + MakeTokenPlaintext, +}; +use base64::Engine; +use serde::{Deserialize, Serialize}; +use tokenserver_common::TokenserverError; +// Those two constants were pulled directly from +// https://github.com/mozilla-services/tokenlib/blob/91ec9e2c922e55306eddba1394590a88f3b10602/tokenlib/__init__.py#L43-L45 +// We could change them, but we'd want to make sure that we also change them syncstorage, however +// that would cause temporary auth issues for anyone with an old pre-new-value token +const HKDF_SIGNING_INFO: &[u8] = b"services.mozilla.com/tokenlib/v1/signing"; +const HKDF_INFO_DERIVE: &[u8] = b"services.mozilla.com/tokenlib/v1/derive/"; + +pub struct Tokenlib {} + +#[derive(Debug, Serialize, Deserialize)] +struct Token<'a> { + #[serde(flatten)] + plaintext: MakeTokenPlaintext, + salt: &'a str, +} + +impl Tokenlib { + pub fn get_token_and_derived_secret( + plaintext: MakeTokenPlaintext, + shared_secret: &str, + ) -> Result<(String, String), TokenserverError> { + // First we make the token itself, the code blow was ported from: + // https://github.com/mozilla-services/tokenlib/blob/91ec9e2c922e55306eddba1394590a88f3b10602/tokenlib/__init__.py#L96-L97 + let crypto_lib = CryptoImpl {}; + let mut salt_bytes = [0u8; 3]; + crypto_lib.rand_bytes(&mut salt_bytes)?; + let salt = hex::encode(salt_bytes); + let token_str = serde_json::to_string(&Token { + plaintext, + salt: &salt, + }) + .map_err(|_| TokenserverError::internal_error())?; + let hmac_key = crypto_lib.hkdf(shared_secret, None, HKDF_SIGNING_INFO)?; + let signature = crypto_lib.hmac_sign(&hmac_key, token_str.as_bytes())?; + let mut token_bytes = Vec::with_capacity(token_str.len() + signature.len()); + token_bytes.extend_from_slice(token_str.as_bytes()); + token_bytes.extend_from_slice(&signature); + let token = base64::engine::general_purpose::URL_SAFE.encode(token_bytes); + // Now that we finialized the token, lets generate our per token secret + // The code below was ported from: + // https://github.com/mozilla-services/tokenlib/blob/91ec9e2c922e55306eddba1394590a88f3b10602/tokenlib/__init__.py#L158-L159 + let mut info = Vec::with_capacity(HKDF_INFO_DERIVE.len() + token.as_bytes().len()); + info.extend_from_slice(HKDF_INFO_DERIVE); + info.extend_from_slice(token.as_bytes()); + + let per_token_secret = crypto_lib.hkdf(shared_secret, Some(salt.as_bytes()), &info)?; + let per_token_secret = base64::engine::general_purpose::URL_SAFE.encode(per_token_secret); + Ok((token, per_token_secret)) + } +} + +#[cfg(test)] +mod tests { + use crate::{crypto::SHA256_OUTPUT_LEN, TokenserverOrigin}; + + use super::*; + + #[test] + fn test_generate_valid_token_and_per_token_secret() -> Result<(), TokenserverError> { + // First we verify that the token we generated has a valid + // and correct HMAC signature if signed using the same key + let plaintext = MakeTokenPlaintext { + node: "https://www.example.com".to_string(), + fxa_kid: "kid".to_string(), + fxa_uid: "user uid".to_string(), + hashed_fxa_uid: "hased uid".to_string(), + hashed_device_id: "hashed device id".to_string(), + expires: 1031, + uid: 13, + tokenserver_origin: TokenserverOrigin::Rust, + }; + let secret = "foobar"; + let crypto_impl = CryptoImpl {}; + let hmac_key = crypto_impl.hkdf(secret, None, HKDF_SIGNING_INFO).unwrap(); + let (b64_token, per_token_secret) = + Tokenlib::get_token_and_derived_secret(plaintext.clone(), secret).unwrap(); + let token = base64::engine::general_purpose::URL_SAFE + .decode(&b64_token) + .unwrap(); + let token_size = token.len(); + let signature = &token[token_size - SHA256_OUTPUT_LEN..]; + let payload = &token[..token_size - SHA256_OUTPUT_LEN]; + crypto_impl + .hmac_verify(&hmac_key, payload, signature) + .unwrap(); + // Then we verify that the payload value we signed, is a valid + // Token represented by our Token struct, and has exactly the same + // plain_text values + let token_data = serde_json::from_slice::>(payload).unwrap(); + assert_eq!(token_data.plaintext, plaintext); + // Finally, we verify that the same per_token_secret can be derived given the payload + // and the shared secret + let mut info = Vec::with_capacity(HKDF_INFO_DERIVE.len() + b64_token.as_bytes().len()); + info.extend_from_slice(HKDF_INFO_DERIVE); + info.extend_from_slice(b64_token.as_bytes()); + + let expected_per_token_secret = + crypto_impl.hkdf(secret, Some(token_data.salt.as_bytes()), &info)?; + let expected_per_token_secret = + base64::engine::general_purpose::URL_SAFE.encode(expected_per_token_secret); + + assert_eq!(expected_per_token_secret, per_token_secret); + + Ok(()) + } +} diff --git a/tokenserver-auth/src/token/py.rs b/tokenserver-auth/src/token/py.rs new file mode 100644 index 0000000000..ccc91fde30 --- /dev/null +++ b/tokenserver-auth/src/token/py.rs @@ -0,0 +1,71 @@ +use crate::{MakeTokenPlaintext, TokenserverError}; +use pyo3::{ + prelude::{IntoPy, PyErr, PyModule, PyObject, Python}, + types::IntoPyDict, +}; + +pub struct PyTokenlib {} +impl IntoPy for MakeTokenPlaintext { + fn into_py(self, py: Python<'_>) -> PyObject { + let dict = [ + ("node", self.node), + ("fxa_kid", self.fxa_kid), + ("fxa_uid", self.fxa_uid), + ("hashed_device_id", self.hashed_device_id), + ("hashed_fxa_uid", self.hashed_fxa_uid), + ("tokenserver_origin", self.tokenserver_origin.to_string()), + ] + .into_py_dict(py); + + // These need to be set separately since they aren't strings, and + // Rust doesn't support heterogeneous arrays + dict.set_item("expires", self.expires).unwrap(); + dict.set_item("uid", self.uid).unwrap(); + + dict.into() + } +} +impl PyTokenlib { + pub fn get_token_and_derived_secret( + plaintext: MakeTokenPlaintext, + shared_secret: &str, + ) -> Result<(String, String), TokenserverError> { + Python::with_gil(|py| { + // `import tokenlib` + let module = PyModule::import(py, "tokenlib").map_err(|e| { + e.print_and_set_sys_last_vars(py); + e + })?; + // `kwargs = { 'secret': shared_secret }` + let kwargs = [("secret", shared_secret)].into_py_dict(py); + // `token = tokenlib.make_token(plaintext, **kwargs)` + let token = module + .getattr("make_token")? + .call((plaintext,), Some(kwargs)) + .map_err(|e| { + e.print_and_set_sys_last_vars(py); + e + }) + .and_then(|x| x.extract())?; + // `derived_secret = tokenlib.get_derived_secret(token, **kwargs)` + let derived_secret = module + .getattr("get_derived_secret")? + .call((&token,), Some(kwargs)) + .map_err(|e| { + e.print_and_set_sys_last_vars(py); + e + }) + .and_then(|x| x.extract())?; + // `return (token, derived_secret)` + Ok((token, derived_secret)) + }) + .map_err(pyerr_to_tokenserver_error) + } +} + +fn pyerr_to_tokenserver_error(e: PyErr) -> TokenserverError { + TokenserverError { + context: e.to_string(), + ..TokenserverError::internal_error() + } +} diff --git a/tokenserver-common/Cargo.toml b/tokenserver-common/Cargo.toml index 8d2b67b927..85066f4e0e 100644 --- a/tokenserver-common/Cargo.toml +++ b/tokenserver-common/Cargo.toml @@ -10,6 +10,7 @@ actix-web.workspace = true backtrace.workspace = true serde.workspace = true serde_json.workspace = true +jsonwebtoken.workspace = true +thiserror.workspace = true syncserver-common = { path = "../syncserver-common" } -thiserror = "1.0.26" diff --git a/tokenserver-db/Cargo.toml b/tokenserver-db/Cargo.toml index 43ff525dbf..98d9fe3804 100644 --- a/tokenserver-db/Cargo.toml +++ b/tokenserver-db/Cargo.toml @@ -13,6 +13,7 @@ serde.workspace = true serde_derive.workspace = true serde_json.workspace = true slog-scope.workspace = true +thiserror.workspace = true async-trait = "0.1.40" diesel = { version = "1.4", features = ["mysql", "r2d2"] } @@ -20,7 +21,6 @@ diesel_logger = "0.1.1" diesel_migrations = { version = "1.4.0", features = ["mysql"] } syncserver-common = { path = "../syncserver-common" } syncserver-db-common = { path = "../syncserver-db-common" } -thiserror = "1.0.26" tokenserver-common = { path = "../tokenserver-common" } tokenserver-settings = { path = "../tokenserver-settings" } tokio = { workspace = true, features = ["macros", "sync"] } diff --git a/tokenserver-settings/Cargo.toml b/tokenserver-settings/Cargo.toml index 52b3eb335a..dc9b13e047 100644 --- a/tokenserver-settings/Cargo.toml +++ b/tokenserver-settings/Cargo.toml @@ -7,5 +7,6 @@ edition.workspace=true [dependencies] serde.workspace=true +jsonwebtoken.workspace=true tokenserver-common = { path = "../tokenserver-common" } diff --git a/tokenserver-settings/src/lib.rs b/tokenserver-settings/src/lib.rs index ab25475009..7b8ab41fd5 100644 --- a/tokenserver-settings/src/lib.rs +++ b/tokenserver-settings/src/lib.rs @@ -1,3 +1,4 @@ +use jsonwebtoken::jwk::Jwk; use serde::Deserialize; use tokenserver_common::NodeType; @@ -69,18 +70,6 @@ pub struct Settings { pub token_duration: u64, } -#[derive(Clone, Debug, Deserialize)] -pub struct Jwk { - pub kty: String, - pub alg: String, - pub kid: String, - pub fxa_created_at: u64, - #[serde(rename = "use")] - pub use_of_key: String, - pub n: String, - pub e: String, -} - impl Default for Settings { fn default() -> Settings { Settings { diff --git a/tools/integration_tests/tokenserver/test_e2e.py b/tools/integration_tests/tokenserver/test_e2e.py index 90899e20c1..85eab417cc 100644 --- a/tools/integration_tests/tokenserver/test_e2e.py +++ b/tools/integration_tests/tokenserver/test_e2e.py @@ -215,14 +215,14 @@ def test_valid_oauth_request(self): raw = urlsafe_b64decode(res.json['id']) payload = raw[:-32] signature = raw[-32:] - payload_dict = json.loads(payload.decode('utf-8')) + payload_str = payload.decode('utf-8') + payload_dict = json.loads(payload_str) # The `id` payload should include a field indicating the origin of the # token self.assertEqual(payload_dict['tokenserver_origin'], 'rust') signing_secret = self.TOKEN_SIGNING_SECRET - expected_token = tokenlib.make_token(payload_dict, - secret=signing_secret) - expected_signature = urlsafe_b64decode(expected_token)[-32:] + tm = tokenlib.TokenManager(secret=signing_secret) + expected_signature = tm._get_signature(payload_str.encode('utf8')) # Using the #compare_digest method here is not strictly necessary, as # this is not a security-sensitive situation, but it's good practice self.assertTrue(hmac.compare_digest(expected_signature, signature)) @@ -271,12 +271,11 @@ def test_valid_browserid_request(self): raw = urlsafe_b64decode(res.json['id']) payload = raw[:-32] signature = raw[-32:] - payload_dict = json.loads(payload.decode('utf-8')) + payload_str = payload.decode('utf-8') signing_secret = self.TOKEN_SIGNING_SECRET - expected_token = tokenlib.make_token(payload_dict, - secret=signing_secret) - expected_signature = urlsafe_b64decode(expected_token)[-32:] + tm = tokenlib.TokenManager(secret=signing_secret) + expected_signature = tm._get_signature(payload_str.encode('utf8')) # Using the #compare_digest method here is not strictly necessary, as # this is not a security-sensitive situation, but it's good practice self.assertTrue(hmac.compare_digest(expected_signature, signature))