Skip to content

Commit d354308

Browse files
committed
Add new sign_pss/verify_pss
1 parent 7319d9d commit d354308

File tree

2 files changed

+69
-44
lines changed

2 files changed

+69
-44
lines changed

src/key.rs

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@ use rand::Rng;
99
#[cfg(feature = "serde1")]
1010
use serde::{Deserialize, Serialize};
1111
use zeroize::{Zeroize, ZeroizeOnDrop};
12+
use digest::Digest;
1213

1314
use crate::algorithms::generate_multi_prime_key;
1415
use crate::errors::{Error, Result as RsaResult};
15-
use crate::hash::{Hash, Hashes};
16-
use crate::padding::PaddingScheme;
16+
use crate::hash::Hash;
1717
use crate::pkcs1v15;
1818
use crate::pss;
1919

@@ -115,7 +115,7 @@ pub(crate) struct CRTValue {
115115

116116
impl From<RSAPrivateKey> for RSAPublicKey {
117117
fn from(mut private_key: RSAPrivateKey) -> Self {
118-
let mut broken_key = RSAPublicKey {
118+
let broken_key = RSAPublicKey {
119119
// Fast, no-allocation creation of a biguint.
120120
n: BigUint::new_native(Default::default()),
121121
e: BigUint::new_native(Default::default())
@@ -169,6 +169,20 @@ impl RSAPublicKey {
169169
) -> RsaResult<()> {
170170
pkcs1v15::verify(self, hash, hashed, sig)
171171
}
172+
173+
/// Verify that the given signature is valid using the PSS padding scheme.
174+
///
175+
/// The first parameter should be a pre-hashed message, using D as the
176+
/// hashing scheme.
177+
///
178+
/// The salt length is auto-detected.
179+
pub fn verify_pss<D: Digest>(
180+
&self,
181+
hashed: &[u8],
182+
sig: &[u8]
183+
) -> RsaResult<()> {
184+
pss::verify::<D>(self, hashed, sig)
185+
}
172186
}
173187

174188
impl RSAPrivateKey {
@@ -333,6 +347,25 @@ impl RSAPrivateKey {
333347
) -> RsaResult<Vec<u8>> {
334348
pkcs1v15::sign(Some(rng), self, hash, digest)
335349
}
350+
351+
/// Sign the given pre-hashed message using the PSS padding scheme. The
352+
/// message should be hashed using the Digest algorithm passed as a generic
353+
/// argument.
354+
///
355+
/// RNG is used for PSS salt generation, and if `blind` is true, it will
356+
/// also be used to blind the RSA encryption.
357+
///
358+
/// The length of the salt can be controlled with the salt_len parameter. If
359+
/// it is None, then it will be calculated to be as large as possible.
360+
pub fn sign_pss<D: Digest, R: Rng>(
361+
&self,
362+
rng: &mut R,
363+
digest: &[u8],
364+
salt_len: Option<usize>,
365+
blind: bool
366+
) -> RsaResult<Vec<u8>> {
367+
pss::sign::<R, D>(rng, self, digest, salt_len, blind)
368+
}
336369
}
337370

338371
/// Check that the public key is well formed and has an exponent within acceptable bounds.

src/pss.rs

Lines changed: 33 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
11
use crate::pkcs1v15::copy_with_left_pad;
22
use crate::internals;
3-
use crate::hash::{Hashes, Hash};
4-
use crate::key::{RSAPrivateKey, PublicKey};
3+
use crate::key::{RSAPrivateKey, RSAPublicKey};
54
use crate::errors::{Error, Result};
65

76
use alloc::vec::Vec;
87
use num_bigint::BigUint;
98
use subtle::ConstantTimeEq;
10-
use sha2::{Digest, Sha256};
11-
use sha1::Sha1;
9+
use digest::Digest;
1210
use rand::Rng;
1311

14-
pub fn verify<K: PublicKey>(
15-
pub_key: &K,
16-
hash: &Hashes,
12+
pub fn verify<H: Digest>(
13+
pub_key: &RSAPublicKey,
1714
hashed: &[u8],
1815
sig: &[u8]) -> Result<()>
1916
{
@@ -33,60 +30,51 @@ pub fn verify<K: PublicKey>(
3330
let mut em = vec![0; em_len];
3431
copy_with_left_pad(&mut em, &m);
3532

36-
match hash {
37-
Hashes::SHA1 => {
38-
emsa_pss_verify(hashed, &mut em, em_bits, None, Sha1::new())
39-
},
40-
Hashes::SHA2_256 => {
41-
emsa_pss_verify(hashed, &mut em, em_bits, None, Sha256::new())
42-
},
43-
_ => unimplemented!()
44-
}
33+
emsa_pss_verify::<H>(hashed, &mut em, em_bits, None)
4534
}
4635

4736

4837
/// SignPSS calculates the signature of hashed using RSASSA-PSS [1].
4938
/// Note that hashed must be the result of hashing the input message using the
5039
/// given hash function. The opts argument may be nil, in which case sensible
5140
/// defaults are used.
52-
pub fn sign<T: Rng>(rng: &mut T, priv_key: &RSAPrivateKey, hash: &Hashes, hashed: &[u8], salt_len: Option<usize>) -> Result<Vec<u8>> {
41+
pub fn sign<T: Rng, H: Digest>(rng: &mut T, priv_key: &RSAPrivateKey, hashed: &[u8], salt_len: Option<usize>, blind: bool) -> Result<Vec<u8>> {
5342
let salt_len = salt_len.unwrap_or_else(|| {
54-
(priv_key.n().bits() + 7) / 8 - 2 - hash.size()
43+
(priv_key.n().bits() + 7) / 8 - 2 - H::output_size()
5544
});
5645

5746
let mut salt = vec![0; salt_len];
5847
rng.fill(&mut salt[..]);
5948

60-
return sign_pss_with_salt(rng, priv_key, hash, hashed, &salt)
49+
return sign_pss_with_salt::<_, H>(rng, priv_key, hashed, &salt, blind)
6150
}
6251

6352

6453
// signPSSWithSalt calculates the signature of hashed using PSS [1] with specified salt.
6554
// Note that hashed must be the result of hashing the input message using the
6655
// given hash function. salt is a random sequence of bytes whose length will be
6756
// later used to verify the signature.
68-
fn sign_pss_with_salt<T: Rng>(rng: &mut T, priv_key: &RSAPrivateKey, hash: &Hashes, hashed: &[u8], salt: &[u8]) -> Result<Vec<u8>> {
57+
fn sign_pss_with_salt<T: Rng, H: Digest>(rng: &mut T, priv_key: &RSAPrivateKey, hashed: &[u8], salt: &[u8], blind: bool) -> Result<Vec<u8>> {
6958
let n_bits = priv_key.n().bits();
7059
let mut em = vec![0; ((n_bits - 1) + 7) / 8];
71-
match hash {
72-
Hashes::SHA1 => {
73-
emsa_pss_encode(&mut em, hashed, n_bits - 1, salt, Sha1::new())?;
74-
},
75-
Hashes::SHA2_256 => {
76-
emsa_pss_encode(&mut em, hashed, n_bits - 1, salt, Sha256::new())?;
77-
},
78-
_ => unimplemented!()
79-
}
60+
emsa_pss_encode::<H>(&mut em, hashed, n_bits - 1, salt)?;
8061

81-
let mut m = BigUint::from_bytes_be(&em);
82-
let mut c = internals::decrypt_and_check(Some(rng), priv_key, &m)?.to_bytes_be();
62+
let m = BigUint::from_bytes_be(&em);
63+
64+
let blind_rng = if blind {
65+
Some(rng)
66+
} else {
67+
None
68+
};
69+
70+
let c = internals::decrypt_and_check(blind_rng, priv_key, &m)?.to_bytes_be();
8371

8472
let mut s = vec![0; (n_bits + 7) / 8];
8573
copy_with_left_pad(&mut s, &c);
8674
return Ok(s)
8775
}
8876

89-
fn emsa_pss_encode<H: Digest>(em: &mut [u8], m_hash: &[u8], em_bits: usize, salt: &[u8], mut hash: H) -> Result<()> {
77+
fn emsa_pss_encode<H: Digest>(em: &mut [u8], m_hash: &[u8], em_bits: usize, salt: &[u8]) -> Result<()> {
9078
// See [1], section 9.1.1
9179
let h_len = H::output_size();
9280
let s_len = salt.len();
@@ -125,11 +113,13 @@ fn emsa_pss_encode<H: Digest>(em: &mut [u8], m_hash: &[u8], em_bits: usize, salt
125113
//
126114
// 6. Let H = Hash(M'), an octet string of length h_len.
127115
let prefix = [0u8; 8];
116+
let mut hash = H::new();
117+
128118
hash.input(&prefix);
129119
hash.input(m_hash);
130120
hash.input(salt);
131121

132-
let hashed = hash.result_reset();
122+
let hashed = hash.result();
133123
h.copy_from_slice(&hashed);
134124

135125
// 7. Generate an octet string PS consisting of em_len - s_len - h_len - 2
@@ -143,7 +133,7 @@ fn emsa_pss_encode<H: Digest>(em: &mut [u8], m_hash: &[u8], em_bits: usize, salt
143133
// 9. Let dbMask = MGF(H, emLen - hLen - 1).
144134
//
145135
// 10. Let maskedDB = DB \xor dbMask.
146-
mgf1_xor(db, &mut hash, &h);
136+
mgf1_xor(db, &mut H::new(), &h);
147137

148138
// 11. Set the leftmost 8 * em_len - em_bits bits of the leftmost octet in
149139
// maskedDB to zero.
@@ -155,7 +145,7 @@ fn emsa_pss_encode<H: Digest>(em: &mut [u8], m_hash: &[u8], em_bits: usize, salt
155145
return Ok(())
156146
}
157147

158-
fn emsa_pss_verify<H: Digest>(m_hash: &[u8], em: &mut [u8], em_bits: usize, s_len: Option<usize>, mut hash: H) -> Result<()> {
148+
fn emsa_pss_verify<H: Digest>(m_hash: &[u8], em: &mut [u8], em_bits: usize, s_len: Option<usize>) -> Result<()> {
159149
// 1. If the length of M is greater than the input limitation for the
160150
// hash function (2^61 - 1 octets for SHA-1), output "inconsistent"
161151
// and stop.
@@ -193,7 +183,7 @@ fn emsa_pss_verify<H: Digest>(m_hash: &[u8], em: &mut [u8], em_bits: usize, s_le
193183
// 7. Let dbMask = MGF(H, em_len - h_len - 1)
194184
//
195185
// 8. Let DB = maskedDB \xor dbMask
196-
mgf1_xor(db, &mut hash, &*h);
186+
mgf1_xor(db, &mut H::new(), &*h);
197187

198188

199189
// 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB
@@ -237,6 +227,7 @@ fn emsa_pss_verify<H: Digest>(m_hash: &[u8], em: &mut [u8], em_bits: usize, s_le
237227
// 13. Let H' = Hash(M'), an octet string of length hLen.
238228
let prefix = [0u8; 8];
239229

230+
let mut hash = H::new();
240231
hash.input(prefix);
241232
hash.input(m_hash);
242233
hash.input(salt);
@@ -283,6 +274,8 @@ fn inc_counter(counter: &mut [u8]) {
283274
}
284275

285276
/// Mask generation function
277+
///
278+
/// Will reset the Digest before returning.
286279
fn mgf1_xor<T: Digest>(out: &mut [u8], digest: &mut T, seed: &[u8]) {
287280
let mut counter = vec![0u8; 4];
288281
let mut i = 0;
@@ -310,8 +303,7 @@ fn mgf1_xor<T: Digest>(out: &mut [u8], digest: &mut T, seed: &[u8]) {
310303

311304
#[cfg(test)]
312305
mod test {
313-
use crate::{PaddingScheme, RSAPrivateKey, RSAPublicKey, PublicKey};
314-
use crate::hash::Hashes;
306+
use crate::{RSAPrivateKey, RSAPublicKey};
315307

316308
use num_bigint::BigUint;
317309
use num_traits::{FromPrimitive, Num};
@@ -355,7 +347,7 @@ mod test {
355347
let sig = hex::decode(test[1]).unwrap();
356348

357349
pub_key
358-
.verify(PaddingScheme::PSS, Some(&Hashes::SHA1), &digest, &sig)
350+
.verify_pss::<Sha1>(&digest, &sig)
359351
.expect("failed to verify");
360352
}
361353
}
@@ -369,11 +361,11 @@ mod test {
369361
for test in &tests {
370362
let digest = Sha1::digest(test.as_bytes()).to_vec();
371363
let sig = priv_key
372-
.sign_blinded(&mut thread_rng(), PaddingScheme::PSS, Some(&Hashes::SHA1), &digest)
364+
.sign_pss::<Sha1, _>(&mut thread_rng(), &digest, None, true)
373365
.expect("failed to sign");
374366

375367
priv_key
376-
.verify(PaddingScheme::PSS, Some(&Hashes::SHA1), &digest, &sig)
368+
.verify_pss::<Sha1>(&digest, &sig)
377369
.expect("failed to verify");
378370
}
379371
}

0 commit comments

Comments
 (0)