Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix pss sign/verify when key length is multiple of 8 + 1 bits. #263

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions src/pss.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,10 @@ pub(crate) fn verify<PK: PublicKey>(

let em_bits = pub_key.n().bits() - 1;
let em_len = (em_bits + 7) / 8;
let mut em = pub_key.raw_encryption_primitive(sig, em_len)?;
let key_len = pub_key.size();
let mut em = pub_key.raw_encryption_primitive(sig, key_len)?;

emsa_pss_verify(hashed, &mut em, em_bits, None, digest)
emsa_pss_verify(hashed, &mut em[key_len - em_len ..], em_bits, None, digest)
}

pub(crate) fn verify_digest<PK, D>(pub_key: &PK, hashed: &[u8], sig: &[u8]) -> Result<()>
Expand All @@ -212,9 +213,10 @@ where

let em_bits = pub_key.n().bits() - 1;
let em_len = (em_bits + 7) / 8;
let mut em = pub_key.raw_encryption_primitive(sig, em_len)?;
let key_len = pub_key.size();
let mut em = pub_key.raw_encryption_primitive(sig, key_len)?;

emsa_pss_verify_digest::<D>(hashed, &mut em, em_bits, None)
emsa_pss_verify_digest::<D>(hashed, &mut em[key_len - em_len ..], em_bits, None)
}

/// SignPSS calculates the signature of hashed using RSASSA-PSS.
Expand Down Expand Up @@ -257,7 +259,9 @@ fn generate_salt<T: CryptoRngCore + ?Sized, SK: PrivateKey>(
salt_len: Option<usize>,
digest_size: usize,
) -> Vec<u8> {
let salt_len = salt_len.unwrap_or_else(|| priv_key.size() - 2 - digest_size);
let em_bits = priv_key.n().bits() - 1;
let em_len = (em_bits + 7) / 8;
let salt_len = salt_len.unwrap_or_else(|| em_len - 2 - digest_size);

let mut salt = vec![0; salt_len];
rng.fill_bytes(&mut salt[..]);
Expand Down Expand Up @@ -481,7 +485,7 @@ fn emsa_pss_verify_pre<'a>(
// 6. If the leftmost 8 * em_len - em_bits bits of the leftmost octet in
// maskedDB are not all equal to zero, output "inconsistent" and
// stop.
if db[0] & (0xFF << /*uint*/(8 - (8 * em_len - em_bits))) != 0 {
if db[0] & (0xFF_u8.checked_shl(8 - (8 * em_len - em_bits) as u32).unwrap_or(0)) != 0 {
return Err(Error::Verification);
}

Expand Down Expand Up @@ -1344,4 +1348,21 @@ mod test {
.expect("failed to verify");
}
}

#[test]
// Tests the corner case where the key is multiple of 8 + 1 bits long
fn test_sign_and_verify_2049bit_key() {
let plaintext = "Hello\n";
let rng = ChaCha8Rng::from_seed([42; 32]);
let priv_key = RsaPrivateKey::new(&mut rng.clone(), 2049).unwrap();

let digest = Sha1::digest(plaintext.as_bytes()).to_vec();
let sig = priv_key
.sign_with_rng(&mut rng.clone(), Pss::new::<Sha1>(), &digest)
.expect("failed to sign");

priv_key
.verify(Pss::new::<Sha1>(), &digest, &sig)
.expect("failed to verify");
}
}