Skip to content

Commit a2bafc0

Browse files
apply CR for pss implementation
1 parent f3e99b4 commit a2bafc0

File tree

1 file changed

+11
-18
lines changed

1 file changed

+11
-18
lines changed

src/pss.rs

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -57,19 +57,17 @@ fn sign_pss_with_salt<T: Rng, SK: PrivateKey>(
5757
digest: &mut dyn DynDigest,
5858
) -> Result<Vec<u8>> {
5959
let em_bits = priv_key.n().bits() - 1;
60-
let mut em = vec![0; (em_bits + 7) / 8];
61-
emsa_pss_encode(&mut em, hashed, em_bits, salt, digest)?;
60+
let em = emsa_pss_encode(hashed, em_bits, salt, digest)?;
6261

6362
priv_key.raw_decryption_primitive(blind_rng, &em, priv_key.size())
6463
}
6564

6665
fn emsa_pss_encode(
67-
em: &mut [u8],
6866
m_hash: &[u8],
6967
em_bits: usize,
7068
salt: &[u8],
7169
hash: &mut dyn DynDigest,
72-
) -> Result<()> {
70+
) -> Result<Vec<u8>> {
7371
// See [1], section 9.1.1
7472
let h_len = hash.output_size();
7573
let s_len = salt.len();
@@ -90,11 +88,9 @@ fn emsa_pss_encode(
9088
return Err(Error::Internal);
9189
}
9290

93-
if em.len() != em_len {
94-
return Err(Error::Internal);
95-
}
91+
let mut em = vec![0; em_len];
9692

97-
let (db, h) = em.split_at_mut(em_len - s_len - h_len - 2 + 1 + s_len);
93+
let (db, h) = em.split_at_mut(em_len - h_len - 1);
9894
let h = &mut h[..(em_len - 1) - db.len()];
9995

10096
// 4. Generate a random octet string salt of length s_len; if s_len = 0,
@@ -136,7 +132,7 @@ fn emsa_pss_encode(
136132
// 12. Let EM = maskedDB || H || 0xbc.
137133
em[em_len - 1] = 0xBC;
138134

139-
return Ok(());
135+
Ok(em)
140136
}
141137

142138
fn emsa_pss_verify(
@@ -158,7 +154,7 @@ fn emsa_pss_verify(
158154

159155
// 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop.
160156
let em_len = em.len(); //(em_bits + 7) / 8;
161-
if em_len < h_len + 2 {
157+
if em_len < h_len + s_len.unwrap_or_default() + 2 {
162158
return Err(Error::Verification);
163159
}
164160

@@ -171,7 +167,7 @@ fn emsa_pss_verify(
171167
// 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and
172168
// let H be the next hLen octets.
173169
let (db, h) = em.split_at_mut(em_len - h_len - 1);
174-
let h = &mut h[..(em_len - 1) - (em_len - h_len - 1)];
170+
let h = &mut h[..h_len];
175171

176172
// 6. If the leftmost 8 * em_len - em_bits bits of the leftmost octet in
177173
// maskedDB are not all equal to zero, output "inconsistent" and
@@ -204,14 +200,11 @@ fn emsa_pss_verify(
204200
// or if the octet at position emLen - hLen - sLen - 1 (the leftmost
205201
// position is "position 1") does not have hexadecimal value 0x01,
206202
// output "inconsistent" and stop.
207-
for e in &db[..em_len - h_len - s_len - 2] {
208-
if *e != 0x00 {
209-
return Err(Error::Verification);
210-
}
211-
}
212-
if db[em_len - h_len - s_len - 2] != 0x01 {
203+
let (zeroes, rest) = db.split_at(em_len - h_len - s_len - 2);
204+
if zeroes.iter().any(|e| *e != 0x00) || rest[0] != 0x01 {
213205
return Err(Error::Verification);
214206
}
207+
215208
s_len
216209
}
217210
};
@@ -233,7 +226,7 @@ fn emsa_pss_verify(
233226
let h0 = hash.finalize_reset();
234227

235228
// 14. If H = H', output "consistent." Otherwise, output "inconsistent."
236-
if Into::<bool>::into(h0.ct_eq(h)) {
229+
if h0.ct_eq(h).into() {
237230
Ok(())
238231
} else {
239232
Err(Error::Verification)

0 commit comments

Comments
 (0)