diff --git a/src/pss.rs b/src/pss.rs index f58580fb..6b0ccbf5 100644 --- a/src/pss.rs +++ b/src/pss.rs @@ -196,9 +196,10 @@ pub(crate) fn verify( let em_bits = pub_key.n().bits() - 1; let em_len = (em_bits + 7) / 8; - let mut em = pub_key.raw_encryption_primitive(sig, em_len)?; + let key_len = pub_key.size(); + let mut em = pub_key.raw_encryption_primitive(sig, key_len)?; - emsa_pss_verify(hashed, &mut em, em_bits, None, digest) + emsa_pss_verify(hashed, &mut em[key_len - em_len ..], em_bits, None, digest) } pub(crate) fn verify_digest(pub_key: &PK, hashed: &[u8], sig: &[u8]) -> Result<()> @@ -212,9 +213,10 @@ where let em_bits = pub_key.n().bits() - 1; let em_len = (em_bits + 7) / 8; - let mut em = pub_key.raw_encryption_primitive(sig, em_len)?; + let key_len = pub_key.size(); + let mut em = pub_key.raw_encryption_primitive(sig, key_len)?; - emsa_pss_verify_digest::(hashed, &mut em, em_bits, None) + emsa_pss_verify_digest::(hashed, &mut em[key_len - em_len ..], em_bits, None) } /// SignPSS calculates the signature of hashed using RSASSA-PSS. @@ -257,7 +259,9 @@ fn generate_salt( salt_len: Option, digest_size: usize, ) -> Vec { - let salt_len = salt_len.unwrap_or_else(|| priv_key.size() - 2 - digest_size); + let em_bits = priv_key.n().bits() - 1; + let em_len = (em_bits + 7) / 8; + let salt_len = salt_len.unwrap_or_else(|| em_len - 2 - digest_size); let mut salt = vec![0; salt_len]; rng.fill_bytes(&mut salt[..]); @@ -481,7 +485,7 @@ fn emsa_pss_verify_pre<'a>( // 6. If the leftmost 8 * em_len - em_bits bits of the leftmost octet in // maskedDB are not all equal to zero, output "inconsistent" and // stop. - if db[0] & (0xFF << /*uint*/(8 - (8 * em_len - em_bits))) != 0 { + if db[0] & (0xFF_u8.checked_shl(8 - (8 * em_len - em_bits) as u32).unwrap_or(0)) != 0 { return Err(Error::Verification); } @@ -1344,4 +1348,21 @@ mod test { .expect("failed to verify"); } } + + #[test] + // Tests the corner case where the key is multiple of 8 + 1 bits long + fn test_sign_and_verify_2049bit_key() { + let plaintext = "Hello\n"; + let rng = ChaCha8Rng::from_seed([42; 32]); + let priv_key = RsaPrivateKey::new(&mut rng.clone(), 2049).unwrap(); + + let digest = Sha1::digest(plaintext.as_bytes()).to_vec(); + let sig = priv_key + .sign_with_rng(&mut rng.clone(), Pss::new::(), &digest) + .expect("failed to sign"); + + priv_key + .verify(Pss::new::(), &digest, &sig) + .expect("failed to verify"); + } }