Skip to content

Commit 347957a

Browse files
Oppenentropidelic
andauthored
feat: turn Poseidon parameters into trait (#706)
* feat: turn Poseidon parameters into trait Rather than having the parameters as internal state, we consider them to be compile time constants, and implement them via stateless structures with associated constants, which in turn implement the `PermutationParameters` trait. We use a seal trait for the actual code to avoid overriding methods by mistake. This is the `Poseidon` trait. It's automatically implemented for all implementors of `PermutationParameters`. * Format files --------- Co-authored-by: Mariano Nicolini <[email protected]>
1 parent f5208ff commit 347957a

File tree

8 files changed

+613
-833
lines changed

8 files changed

+613
-833
lines changed

crypto/src/hash/poseidon/mod.rs

+234
Original file line numberDiff line numberDiff line change
@@ -1 +1,235 @@
1+
use lambdaworks_math::field::element::FieldElement as FE;
2+
3+
pub mod parameters;
14
pub mod starknet;
5+
6+
use parameters::PermutationParameters;
7+
8+
mod private {
9+
use super::*;
10+
11+
pub trait Sealed {}
12+
13+
impl<P: PermutationParameters> Sealed for P {}
14+
}
15+
16+
pub trait Poseidon: PermutationParameters + self::private::Sealed {
17+
fn hades_permutation(state: &mut [FE<Self::F>]);
18+
fn full_round(state: &mut [FE<Self::F>], round_number: usize);
19+
fn partial_round(state: &mut [FE<Self::F>], round_number: usize);
20+
fn mix(state: &mut [FE<Self::F>]);
21+
fn hash(x: &FE<Self::F>, y: &FE<Self::F>) -> FE<Self::F>;
22+
fn hash_single(x: &FE<Self::F>) -> FE<Self::F>;
23+
fn hash_many(inputs: &[FE<Self::F>]) -> FE<Self::F>;
24+
}
25+
26+
impl<P: PermutationParameters> Poseidon for P {
27+
fn hades_permutation(state: &mut [FE<Self::F>]) {
28+
let mut round_number = 0;
29+
for _ in 0..P::N_FULL_ROUNDS / 2 {
30+
Self::full_round(state, round_number);
31+
round_number += 1;
32+
}
33+
for _ in 0..P::N_PARTIAL_ROUNDS {
34+
Self::partial_round(state, round_number);
35+
round_number += 1;
36+
}
37+
for _ in 0..P::N_FULL_ROUNDS / 2 {
38+
Self::full_round(state, round_number);
39+
round_number += 1;
40+
}
41+
}
42+
43+
fn full_round(state: &mut [FE<Self::F>], round_number: usize) {
44+
for (i, value) in state.iter_mut().enumerate() {
45+
*value = &(*value) + &P::ROUND_CONSTANTS[round_number * P::N_ROUND_CONSTANTS_COLS + i];
46+
*value = value.pow(P::ALPHA);
47+
}
48+
Self::mix(state);
49+
}
50+
51+
fn partial_round(state: &mut [FE<Self::F>], round_number: usize) {
52+
for (i, value) in state.iter_mut().enumerate() {
53+
*value = &(*value) + &P::ROUND_CONSTANTS[round_number * P::N_ROUND_CONSTANTS_COLS + i];
54+
}
55+
56+
state[P::STATE_SIZE - 1] = state[P::STATE_SIZE - 1].pow(P::ALPHA);
57+
58+
Self::mix(state);
59+
}
60+
61+
fn mix(state: &mut [FE<Self::F>]) {
62+
let mut new_state: Vec<FE<Self::F>> = Vec::with_capacity(P::STATE_SIZE);
63+
for i in 0..P::STATE_SIZE {
64+
let mut new_e = FE::zero();
65+
for (j, current_state) in state.iter().enumerate() {
66+
let mut mij = P::MDS_MATRIX[i * P::N_MDS_MATRIX_COLS + j].clone();
67+
mij = mij * current_state;
68+
new_e += mij;
69+
}
70+
new_state.push(new_e);
71+
}
72+
state.clone_from_slice(&new_state[0..P::STATE_SIZE]);
73+
}
74+
75+
fn hash(x: &FE<Self::F>, y: &FE<Self::F>) -> FE<Self::F> {
76+
let mut state: Vec<FE<Self::F>> = vec![x.clone(), y.clone(), FE::from(2)];
77+
Self::hades_permutation(&mut state);
78+
let x = &state[0];
79+
x.clone()
80+
}
81+
82+
fn hash_single(x: &FE<Self::F>) -> FE<Self::F> {
83+
let mut state: Vec<FE<Self::F>> = vec![x.clone(), FE::zero(), FE::from(1)];
84+
Self::hades_permutation(&mut state);
85+
let x = &state[0];
86+
x.clone()
87+
}
88+
89+
fn hash_many(inputs: &[FE<Self::F>]) -> FE<Self::F> {
90+
let r = P::RATE; // chunk size
91+
let m = P::STATE_SIZE; // state size
92+
93+
// Pad input with 1 followed by 0's (if necessary).
94+
let mut values = inputs.to_owned();
95+
values.push(FE::from(1));
96+
values.resize(((values.len() + r - 1) / r) * r, FE::zero());
97+
98+
assert!(values.len() % r == 0);
99+
let mut state: Vec<FE<Self::F>> = vec![FE::zero(); m];
100+
101+
// Process each block
102+
for block in values.chunks(r) {
103+
let mut block_state: Vec<FE<Self::F>> =
104+
state[0..r].iter().zip(block).map(|(s, b)| s + b).collect();
105+
block_state.extend_from_slice(&state[r..]);
106+
107+
Self::hades_permutation(&mut block_state);
108+
state = block_state;
109+
}
110+
111+
state[0].clone()
112+
}
113+
}
114+
115+
#[cfg(test)]
116+
mod tests {
117+
use super::*;
118+
use crate::hash::poseidon::starknet::PoseidonCairoStark252;
119+
use lambdaworks_math::field::{
120+
element::FieldElement, fields::fft_friendly::stark_252_prime_field::Stark252PrimeField,
121+
};
122+
123+
#[test]
124+
fn test_hades_permutation() {
125+
// Initialize a state to test. The exact contents will depend on your specific use case.
126+
let mut state: Vec<FieldElement<Stark252PrimeField>> = vec![
127+
FieldElement::<Stark252PrimeField>::from_hex("0x9").unwrap(),
128+
FieldElement::<Stark252PrimeField>::from_hex("0xb").unwrap(),
129+
FieldElement::<Stark252PrimeField>::from_hex("0x2").unwrap(),
130+
];
131+
132+
PoseidonCairoStark252::hades_permutation(&mut state);
133+
134+
// Compare the result to the expected output. You will need to know the expected output for your specific test case.
135+
let expected_state0 = FieldElement::<Stark252PrimeField>::from_hex(
136+
"0x510f3a3faf4084e3b1e95fd44c30746271b48723f7ea9c8be6a9b6b5408e7e6",
137+
)
138+
.unwrap();
139+
let expected_state1 = FieldElement::<Stark252PrimeField>::from_hex(
140+
"0x4f511749bd4101266904288021211333fb0a514cb15381af087462fa46e6bd9",
141+
)
142+
.unwrap();
143+
let expected_state2 = FieldElement::<Stark252PrimeField>::from_hex(
144+
"0x186f6dd1a6e79cb1b66d505574c349272cd35c07c223351a0990410798bb9d8",
145+
)
146+
.unwrap();
147+
148+
assert_eq!(state[0], expected_state0);
149+
assert_eq!(state[1], expected_state1);
150+
assert_eq!(state[2], expected_state2);
151+
}
152+
#[test]
153+
fn test_hash() {
154+
let x = FieldElement::<Stark252PrimeField>::from_hex("0x123456").unwrap();
155+
let y = FieldElement::<Stark252PrimeField>::from_hex("0x789101").unwrap();
156+
157+
let z = PoseidonCairoStark252::hash(&x, &y);
158+
159+
// Compare the result to the expected output. You will need to know the expected output for your specific test case.
160+
let expected_state0 = FieldElement::<Stark252PrimeField>::from_hex(
161+
"0x2fb6e1e8838d4b850877944f0a13340dd5810f01f5d4361c54b22b4abda3248",
162+
)
163+
.unwrap();
164+
165+
assert_eq!(z, expected_state0);
166+
}
167+
168+
#[test]
169+
fn test_hash_single() {
170+
let x = FieldElement::<Stark252PrimeField>::from_hex("0x9").unwrap();
171+
172+
let z = PoseidonCairoStark252::hash_single(&x);
173+
174+
// Compare the result to the expected output. You will need to know the expected output for your specific test case.
175+
let expected_state0 = FieldElement::<Stark252PrimeField>::from_hex(
176+
"0x3bb3b91c714cb47003947f36dadc98326176963c434cd0a10320b8146c948b3",
177+
)
178+
.unwrap();
179+
180+
assert_eq!(z, expected_state0);
181+
}
182+
183+
#[test]
184+
fn test_hash_many() {
185+
let a = FieldElement::<Stark252PrimeField>::from_hex("0x1").unwrap();
186+
let b = FieldElement::<Stark252PrimeField>::from_hex("0x2").unwrap();
187+
let c = FieldElement::<Stark252PrimeField>::from_hex("0x3").unwrap();
188+
let d = FieldElement::<Stark252PrimeField>::from_hex("0x4").unwrap();
189+
let e = FieldElement::<Stark252PrimeField>::from_hex("0x5").unwrap();
190+
let f = FieldElement::<Stark252PrimeField>::from_hex("0x6").unwrap();
191+
192+
let ins = vec![a, b, c, d, e, f];
193+
let z = PoseidonCairoStark252::hash_many(&ins);
194+
195+
// Compare the result to the expected output. You will need to know the expected output for your specific test case.
196+
let expected_state0 = FieldElement::<Stark252PrimeField>::from_hex(
197+
"0xf50993f0797e4cc05734a47daeb214fde2d444ef6619a7c1f7c8e0924feb0b",
198+
)
199+
.unwrap();
200+
assert_eq!(z, expected_state0);
201+
202+
let ins = vec![a];
203+
let z = PoseidonCairoStark252::hash_many(&ins);
204+
let expected_state0 = FieldElement::<Stark252PrimeField>::from_hex(
205+
"0x579e8877c7755365d5ec1ec7d3a94a457eff5d1f40482bbe9729c064cdead2",
206+
)
207+
.unwrap();
208+
assert_eq!(z, expected_state0);
209+
210+
let ins = vec![a, b];
211+
let z = PoseidonCairoStark252::hash_many(&ins);
212+
let expected_state0 = FieldElement::<Stark252PrimeField>::from_hex(
213+
"0x371cb6995ea5e7effcd2e174de264b5b407027a75a231a70c2c8d196107f0e7",
214+
)
215+
.unwrap();
216+
assert_eq!(z, expected_state0);
217+
218+
let ins = vec![a, b, c];
219+
let z = PoseidonCairoStark252::hash_many(&ins);
220+
let expected_state0 = FieldElement::<Stark252PrimeField>::from_hex(
221+
"0x2f0d8840bcf3bc629598d8a6cc80cb7c0d9e52d93dab244bbf9cd0dca0ad082",
222+
)
223+
.unwrap();
224+
assert_eq!(z, expected_state0);
225+
226+
let ins = vec![a, b, c, d];
227+
let z = PoseidonCairoStark252::hash_many(&ins);
228+
let expected_state0 = FieldElement::<Stark252PrimeField>::from_hex(
229+
"0x26e3ad8b876e02bc8a4fc43dad40a8f81a6384083cabffa190bcf40d512ae1d",
230+
)
231+
.unwrap();
232+
233+
assert_eq!(z, expected_state0);
234+
}
235+
}
+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
use lambdaworks_math::field::{element::FieldElement as FE, traits::IsPrimeField};
2+
3+
/// Parameters for Poseidon
4+
/// MDS constants and rounds constants are stored as references to slices
5+
/// representing matrices of `N_MDS_MATRIX_ROWS * N_MDS_MATRIX_COLS` and
6+
/// `N_ROUND_CONSTANTS_ROWS * N_ROUND_CONSTANTS_COLS` respectively.
7+
/// We use this representation rather than an array because we can't use the
8+
/// associated constants for dimension, requiring many generic parameters
9+
/// otherwise.
10+
pub trait PermutationParameters {
11+
type F: IsPrimeField + 'static;
12+
13+
const RATE: usize;
14+
const CAPACITY: usize;
15+
const ALPHA: u32;
16+
const N_FULL_ROUNDS: usize;
17+
const N_PARTIAL_ROUNDS: usize;
18+
const STATE_SIZE: usize = Self::RATE + Self::CAPACITY;
19+
20+
const MDS_MATRIX: &'static [FE<Self::F>];
21+
const N_MDS_MATRIX_ROWS: usize;
22+
const N_MDS_MATRIX_COLS: usize;
23+
24+
const ROUND_CONSTANTS: &'static [FE<Self::F>];
25+
const N_ROUND_CONSTANTS_ROWS: usize;
26+
const N_ROUND_CONSTANTS_COLS: usize;
27+
}

0 commit comments

Comments
 (0)